discovery: adhere to proper channel chunk splitting for ReplyChannelRange

This commit is contained in:
Wilmer Paulino 2020-12-02 15:15:44 -08:00
parent c5fc7334a4
commit a4f33ae63c
No known key found for this signature in database
GPG Key ID: 6DF57B9F9514972F
7 changed files with 340 additions and 145 deletions

@ -10,6 +10,7 @@ import (
"io" "io"
"math" "math"
"net" "net"
"sort"
"sync" "sync"
"time" "time"
@ -1704,12 +1705,25 @@ func (c *ChannelGraph) FilterKnownChanIDs(chanIDs []uint64) ([]uint64, error) {
return newChanIDs, nil return newChanIDs, nil
} }
// BlockChannelRange represents a range of channels for a given block height.
type BlockChannelRange struct {
// Height is the height of the block all of the channels below were
// included in.
Height uint32
// Channels is the list of channels identified by their short ID
// representation known to us that were included in the block height
// above.
Channels []lnwire.ShortChannelID
}
// FilterChannelRange returns the channel ID's of all known channels which were // FilterChannelRange returns the channel ID's of all known channels which were
// mined in a block height within the passed range. This method can be used to // mined in a block height within the passed range. The channel IDs are grouped
// quickly share with a peer the set of channels we know of within a particular // by their common block height. This method can be used to quickly share with a
// range to catch them up after a period of time offline. // peer the set of channels we know of within a particular range to catch them
func (c *ChannelGraph) FilterChannelRange(startHeight, endHeight uint32) ([]uint64, error) { // up after a period of time offline.
var chanIDs []uint64 func (c *ChannelGraph) FilterChannelRange(startHeight,
endHeight uint32) ([]BlockChannelRange, error) {
startChanID := &lnwire.ShortChannelID{ startChanID := &lnwire.ShortChannelID{
BlockHeight: startHeight, BlockHeight: startHeight,
@ -1728,6 +1742,7 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, endHeight uint32) ([]uint
byteOrder.PutUint64(chanIDStart[:], startChanID.ToUint64()) byteOrder.PutUint64(chanIDStart[:], startChanID.ToUint64())
byteOrder.PutUint64(chanIDEnd[:], endChanID.ToUint64()) byteOrder.PutUint64(chanIDEnd[:], endChanID.ToUint64())
var channelsPerBlock map[uint32][]lnwire.ShortChannelID
err := kvdb.View(c.db, func(tx kvdb.RTx) error { err := kvdb.View(c.db, func(tx kvdb.RTx) error {
edges := tx.ReadBucket(edgeBucket) edges := tx.ReadBucket(edgeBucket)
if edges == nil { if edges == nil {
@ -1742,33 +1757,51 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, endHeight uint32) ([]uint
// We'll now iterate through the database, and find each // We'll now iterate through the database, and find each
// channel ID that resides within the specified range. // channel ID that resides within the specified range.
var cid uint64
for k, _ := cursor.Seek(chanIDStart[:]); k != nil && for k, _ := cursor.Seek(chanIDStart[:]); k != nil &&
bytes.Compare(k, chanIDEnd[:]) <= 0; k, _ = cursor.Next() { bytes.Compare(k, chanIDEnd[:]) <= 0; k, _ = cursor.Next() {
// This channel ID rests within the target range, so // This channel ID rests within the target range, so
// we'll convert it into an integer and add it to our // we'll add it to our returned set.
// returned set. rawCid := byteOrder.Uint64(k)
cid = byteOrder.Uint64(k) cid := lnwire.NewShortChanIDFromInt(rawCid)
chanIDs = append(chanIDs, cid) channelsPerBlock[cid.BlockHeight] = append(
channelsPerBlock[cid.BlockHeight], cid,
)
} }
return nil return nil
}, func() { }, func() {
chanIDs = nil channelsPerBlock = make(map[uint32][]lnwire.ShortChannelID)
}) })
switch { switch {
// If we don't know of any channels yet, then there's nothing to // If we don't know of any channels yet, then there's nothing to
// filter, so we'll return an empty slice. // filter, so we'll return an empty slice.
case err == ErrGraphNoEdgesFound: case err == ErrGraphNoEdgesFound || len(channelsPerBlock) == 0:
return chanIDs, nil return nil, nil
case err != nil: case err != nil:
return nil, err return nil, err
} }
return chanIDs, nil // Return the channel ranges in ascending block height order.
blocks := make([]uint32, 0, len(channelsPerBlock))
for block := range channelsPerBlock {
blocks = append(blocks, block)
}
sort.Slice(blocks, func(i, j int) bool {
return blocks[i] < blocks[j]
})
channelRanges := make([]BlockChannelRange, 0, len(channelsPerBlock))
for _, block := range blocks {
channelRanges = append(channelRanges, BlockChannelRange{
Height: block,
Channels: channelsPerBlock[block],
})
}
return channelRanges, nil
} }
// FetchChanInfos returns the set of channel edges that correspond to the passed // FetchChanInfos returns the set of channel edges that correspond to the passed

@ -1848,24 +1848,32 @@ func TestFilterChannelRange(t *testing.T) {
t.Fatalf("expected zero chans, instead got %v", len(resp)) t.Fatalf("expected zero chans, instead got %v", len(resp))
} }
// To start, we'll create a set of channels, each mined in a block 10 // To start, we'll create a set of channels, two mined in a block 10
// blocks after the prior one. // blocks after the prior one.
startHeight := uint32(100) startHeight := uint32(100)
endHeight := startHeight endHeight := startHeight
const numChans = 10 const numChans = 10
chanIDs := make([]uint64, 0, numChans) channelRanges := make([]BlockChannelRange, 0, numChans/2)
for i := 0; i < numChans; i++ { for i := 0; i < numChans/2; i++ {
chanHeight := endHeight chanHeight := endHeight
channel, chanID := createEdge( channel1, chanID1 := createEdge(
uint32(chanHeight), uint32(i+1), 0, 0, node1, node2, chanHeight, uint32(i+1), 0, 0, node1, node2,
) )
if err := graph.AddChannelEdge(&channel1); err != nil {
if err := graph.AddChannelEdge(&channel); err != nil {
t.Fatalf("unable to create channel edge: %v", err) t.Fatalf("unable to create channel edge: %v", err)
} }
chanIDs = append(chanIDs, chanID.ToUint64()) channel2, chanID2 := createEdge(
chanHeight, uint32(i+2), 0, 0, node1, node2,
)
if err := graph.AddChannelEdge(&channel2); err != nil {
t.Fatalf("unable to create channel edge: %v", err)
}
channelRanges = append(channelRanges, BlockChannelRange{
Height: chanHeight,
Channels: []lnwire.ShortChannelID{chanID1, chanID2},
})
endHeight += 10 endHeight += 10
} }
@ -1876,7 +1884,7 @@ func TestFilterChannelRange(t *testing.T) {
startHeight uint32 startHeight uint32
endHeight uint32 endHeight uint32
resp []uint64 resp []BlockChannelRange
}{ }{
// If we query for the entire range, then we should get the same // If we query for the entire range, then we should get the same
// set of short channel IDs back. // set of short channel IDs back.
@ -1884,7 +1892,7 @@ func TestFilterChannelRange(t *testing.T) {
startHeight: startHeight, startHeight: startHeight,
endHeight: endHeight, endHeight: endHeight,
resp: chanIDs, resp: channelRanges,
}, },
// If we query for a range of channels right before our range, we // If we query for a range of channels right before our range, we
@ -1900,7 +1908,7 @@ func TestFilterChannelRange(t *testing.T) {
startHeight: endHeight - 10, startHeight: endHeight - 10,
endHeight: endHeight - 10, endHeight: endHeight - 10,
resp: chanIDs[9:], resp: channelRanges[4:],
}, },
// If we query for just the first height, we should only get a // If we query for just the first height, we should only get a
@ -1909,7 +1917,14 @@ func TestFilterChannelRange(t *testing.T) {
startHeight: startHeight, startHeight: startHeight,
endHeight: startHeight, endHeight: startHeight,
resp: chanIDs[:1], resp: channelRanges[:1],
},
{
startHeight: startHeight + 10,
endHeight: endHeight - 10,
resp: channelRanges[1:5],
}, },
} }
for i, queryCase := range queryCases { for i, queryCase := range queryCases {

@ -39,10 +39,11 @@ type ChannelGraphTimeSeries interface {
superSet []lnwire.ShortChannelID) ([]lnwire.ShortChannelID, error) superSet []lnwire.ShortChannelID) ([]lnwire.ShortChannelID, error)
// FilterChannelRange returns the set of channels that we created // FilterChannelRange returns the set of channels that we created
// between the start height and the end height. We'll use this to to a // between the start height and the end height. The channel IDs are
// remote peer's QueryChannelRange message. // grouped by their common block height. We'll use this to to a remote
// peer's QueryChannelRange message.
FilterChannelRange(chain chainhash.Hash, FilterChannelRange(chain chainhash.Hash,
startHeight, endHeight uint32) ([]lnwire.ShortChannelID, error) startHeight, endHeight uint32) ([]channeldb.BlockChannelRange, error)
// FetchChanAnns returns a full set of channel announcements as well as // FetchChanAnns returns a full set of channel announcements as well as
// their updates that match the set of specified short channel ID's. // their updates that match the set of specified short channel ID's.
@ -203,26 +204,15 @@ func (c *ChanSeries) FilterKnownChanIDs(chain chainhash.Hash,
} }
// FilterChannelRange returns the set of channels that we created between the // FilterChannelRange returns the set of channels that we created between the
// start height and the end height. We'll use this respond to a remote peer's // start height and the end height. The channel IDs are grouped by their common
// QueryChannelRange message. // block height. We'll use this respond to a remote peer's QueryChannelRange
// message.
// //
// NOTE: This is part of the ChannelGraphTimeSeries interface. // NOTE: This is part of the ChannelGraphTimeSeries interface.
func (c *ChanSeries) FilterChannelRange(chain chainhash.Hash, func (c *ChanSeries) FilterChannelRange(chain chainhash.Hash,
startHeight, endHeight uint32) ([]lnwire.ShortChannelID, error) { startHeight, endHeight uint32) ([]channeldb.BlockChannelRange, error) {
chansInRange, err := c.graph.FilterChannelRange(startHeight, endHeight) return c.graph.FilterChannelRange(startHeight, endHeight)
if err != nil {
return nil, err
}
chanResp := make([]lnwire.ShortChannelID, 0, len(chansInRange))
for _, chanID := range chansInRange {
chanResp = append(
chanResp, lnwire.NewShortChanIDFromInt(chanID),
)
}
return chanResp, nil
} }
// FetchChanAnns returns a full set of channel announcements as well as their // FetchChanAnns returns a full set of channel announcements as well as their

@ -426,6 +426,7 @@ func (m *SyncManager) createGossipSyncer(peer lnpeer.Peer) *GossipSyncer {
maxUndelayedQueryReplies: DefaultMaxUndelayedQueryReplies, maxUndelayedQueryReplies: DefaultMaxUndelayedQueryReplies,
delayedQueryReplyInterval: DefaultDelayedQueryReplyInterval, delayedQueryReplyInterval: DefaultDelayedQueryReplyInterval,
bestHeight: m.cfg.BestHeight, bestHeight: m.cfg.BestHeight,
maxQueryChanRangeReplies: maxQueryChanRangeReplies,
}) })
// Gossip syncers are initialized by default in a PassiveSync type // Gossip syncers are initialized by default in a PassiveSync type

@ -11,6 +11,7 @@ import (
"github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lntest/wait"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/ticker" "github.com/lightningnetwork/lnd/ticker"
"github.com/stretchr/testify/require"
) )
// randPeer creates a random peer. // randPeer creates a random peer.
@ -537,10 +538,14 @@ func assertTransitionToChansSynced(t *testing.T, s *GossipSyncer, peer *mockPeer
} }
assertMsgSent(t, peer, query) assertMsgSent(t, peer, query)
s.ProcessQueryMsg(&lnwire.ReplyChannelRange{ require.Eventually(t, func() bool {
return s.syncState() == waitingQueryRangeReply
}, time.Second, 500*time.Millisecond)
require.NoError(t, s.ProcessQueryMsg(&lnwire.ReplyChannelRange{
QueryChannelRange: *query, QueryChannelRange: *query,
Complete: 1, Complete: 1,
}, nil) }, nil))
chanSeries := s.cfg.channelSeries.(*mockChannelGraphTimeSeries) chanSeries := s.cfg.channelSeries.(*mockChannelGraphTimeSeries)

@ -4,6 +4,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"math" "math"
"math/rand"
"sort"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -128,6 +130,14 @@ const (
// maxUndelayedQueryReplies queries. // maxUndelayedQueryReplies queries.
DefaultDelayedQueryReplyInterval = 5 * time.Second DefaultDelayedQueryReplyInterval = 5 * time.Second
// maxQueryChanRangeReplies specifies the default limit of replies to
// process for a single QueryChannelRange request.
maxQueryChanRangeReplies = 500
// maxQueryChanRangeRepliesZlibFactor specifies the factor applied to
// the maximum number of replies allowed for zlib encoded replies.
maxQueryChanRangeRepliesZlibFactor = 4
// chanRangeQueryBuffer is the number of blocks back that we'll go when // chanRangeQueryBuffer is the number of blocks back that we'll go when
// asking the remote peer for their any channels they know of beyond // asking the remote peer for their any channels they know of beyond
// our highest known channel ID. // our highest known channel ID.
@ -240,6 +250,10 @@ type gossipSyncerCfg struct {
// bestHeight returns the latest height known of the chain. // bestHeight returns the latest height known of the chain.
bestHeight func() uint32 bestHeight func() uint32
// maxQueryChanRangeReplies is the maximum number of replies we'll allow
// for a single QueryChannelRange request.
maxQueryChanRangeReplies uint32
} }
// GossipSyncer is a struct that handles synchronizing the channel graph state // GossipSyncer is a struct that handles synchronizing the channel graph state
@ -316,6 +330,11 @@ type GossipSyncer struct {
// buffer all the chunked response to our query. // buffer all the chunked response to our query.
bufferedChanRangeReplies []lnwire.ShortChannelID bufferedChanRangeReplies []lnwire.ShortChannelID
// numChanRangeRepliesRcvd is used to track the number of replies
// received as part of a QueryChannelRange. This field is primarily used
// within the waitingQueryChanReply state.
numChanRangeRepliesRcvd uint32
// newChansToQuery is used to pass the set of channels we should query // newChansToQuery is used to pass the set of channels we should query
// for from the waitingQueryChanReply state to the queryNewChannels // for from the waitingQueryChanReply state to the queryNewChannels
// state. // state.
@ -741,17 +760,27 @@ func (g *GossipSyncer) processChanRangeReply(msg *lnwire.ReplyChannelRange) erro
g.bufferedChanRangeReplies = append( g.bufferedChanRangeReplies = append(
g.bufferedChanRangeReplies, msg.ShortChanIDs..., g.bufferedChanRangeReplies, msg.ShortChanIDs...,
) )
switch g.cfg.encodingType {
case lnwire.EncodingSortedPlain:
g.numChanRangeRepliesRcvd++
case lnwire.EncodingSortedZlib:
g.numChanRangeRepliesRcvd += maxQueryChanRangeRepliesZlibFactor
default:
return fmt.Errorf("unhandled encoding type %v", g.cfg.encodingType)
}
log.Infof("GossipSyncer(%x): buffering chan range reply of size=%v", log.Infof("GossipSyncer(%x): buffering chan range reply of size=%v",
g.cfg.peerPub[:], len(msg.ShortChanIDs)) g.cfg.peerPub[:], len(msg.ShortChanIDs))
// If this isn't the last response, then we can exit as we've already // If this isn't the last response and we can continue to receive more,
// buffered the latest portion of the streaming reply. // then we can exit as we've already buffered the latest portion of the
// streaming reply.
maxReplies := g.cfg.maxQueryChanRangeReplies
switch { switch {
// If we're communicating with a legacy node, we'll need to look at the // If we're communicating with a legacy node, we'll need to look at the
// complete field. // complete field.
case isLegacyReplyChannelRange(g.curQueryRangeMsg, msg): case isLegacyReplyChannelRange(g.curQueryRangeMsg, msg):
if msg.Complete == 0 { if msg.Complete == 0 && g.numChanRangeRepliesRcvd < maxReplies {
return nil return nil
} }
@ -763,7 +792,8 @@ func (g *GossipSyncer) processChanRangeReply(msg *lnwire.ReplyChannelRange) erro
// TODO(wilmer): This might require some padding if the remote // TODO(wilmer): This might require some padding if the remote
// node is not aware of the last height we sent them, i.e., is // node is not aware of the last height we sent them, i.e., is
// behind a few blocks from us. // behind a few blocks from us.
if replyLastHeight < queryLastHeight { if replyLastHeight < queryLastHeight &&
g.numChanRangeRepliesRcvd < maxReplies {
return nil return nil
} }
} }
@ -786,6 +816,7 @@ func (g *GossipSyncer) processChanRangeReply(msg *lnwire.ReplyChannelRange) erro
g.curQueryRangeMsg = nil g.curQueryRangeMsg = nil
g.prevReplyChannelRange = nil g.prevReplyChannelRange = nil
g.bufferedChanRangeReplies = nil g.bufferedChanRangeReplies = nil
g.numChanRangeRepliesRcvd = 0
// If there aren't any channels that we don't know of, then we can // If there aren't any channels that we don't know of, then we can
// switch straight to our terminal state. // switch straight to our terminal state.
@ -930,7 +961,7 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro
// channel ID's that match their query. // channel ID's that match their query.
startBlock := query.FirstBlockHeight startBlock := query.FirstBlockHeight
endBlock := query.LastBlockHeight() endBlock := query.LastBlockHeight()
channelRange, err := g.cfg.channelSeries.FilterChannelRange( channelRanges, err := g.cfg.channelSeries.FilterChannelRange(
query.ChainHash, startBlock, endBlock, query.ChainHash, startBlock, endBlock,
) )
if err != nil { if err != nil {
@ -940,102 +971,98 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro
// TODO(roasbeef): means can't send max uint above? // TODO(roasbeef): means can't send max uint above?
// * or make internal 64 // * or make internal 64
// In the base case (no actual response) the first block and last block // We'll send our response in a streaming manner, chunk-by-chunk. We do
// will match those of the query. In the loop below, we'll update these // this as there's a transport message size limit which we'll need to
// two variables incrementally with each chunk to properly compute the // adhere to. We also need to make sure all of our replies cover the
// starting block for each response and the number of blocks in a // expected range of the query.
// response. sendReplyForChunk := func(channelChunk []lnwire.ShortChannelID,
firstBlockHeight := startBlock firstHeight, lastHeight uint32, finalChunk bool) error {
lastBlockHeight := endBlock
numChannels := int32(len(channelRange)) // The number of blocks contained in the current chunk (the
numChansSent := int32(0) // total span) is the difference between the last channel ID and
for { // the first in the range. We add one as even if all channels
// We'll send our this response in a streaming manner,
// chunk-by-chunk. We do this as there's a transport message
// size limit which we'll need to adhere to.
var channelChunk []lnwire.ShortChannelID
// We know this is the final chunk, if the difference between
// the total number of channels, and the number of channels
// we've sent is less-than-or-equal to the chunk size.
isFinalChunk := (numChannels - numChansSent) <= g.cfg.chunkSize
// If this is indeed the last chunk, then we'll send the
// remainder of the channels.
if isFinalChunk {
channelChunk = channelRange[numChansSent:]
log.Infof("GossipSyncer(%x): sending final chan "+
"range chunk, size=%v", g.cfg.peerPub[:],
len(channelChunk))
} else {
// Otherwise, we'll only send off a fragment exactly
// sized to the proper chunk size.
channelChunk = channelRange[numChansSent : numChansSent+g.cfg.chunkSize]
log.Infof("GossipSyncer(%x): sending range chunk of "+
"size=%v", g.cfg.peerPub[:], len(channelChunk))
}
// If we have any channels at all to return, then we need to
// update our pointers to the first and last blocks for each
// response.
if len(channelChunk) > 0 {
// If this is the first response we'll send, we'll point
// the first block to the first block in the query.
// Otherwise, we'll continue from the block we left off
// at.
if numChansSent == 0 {
firstBlockHeight = startBlock
} else {
firstBlockHeight = lastBlockHeight
}
// If this is the last response we'll send, we'll point
// the last block to the last block of the query.
// Otherwise, we'll set it to the height of the last
// channel in the chunk.
if isFinalChunk {
lastBlockHeight = endBlock
} else {
lastBlockHeight = channelChunk[len(channelChunk)-1].BlockHeight
}
}
// The number of blocks contained in this response (the total
// span) is the difference between the last channel ID and the
// first in the range. We add one as even if all channels
// returned are in the same block, we need to count that. // returned are in the same block, we need to count that.
numBlocksInResp := lastBlockHeight - firstBlockHeight + 1 numBlocks := lastHeight - firstHeight + 1
complete := uint8(0)
if finalChunk {
complete = 1
}
// With our chunk assembled, we'll now send to the remote peer return g.cfg.sendToPeerSync(&lnwire.ReplyChannelRange{
// the current chunk.
replyChunk := lnwire.ReplyChannelRange{
QueryChannelRange: lnwire.QueryChannelRange{ QueryChannelRange: lnwire.QueryChannelRange{
ChainHash: query.ChainHash, ChainHash: query.ChainHash,
NumBlocks: numBlocksInResp, NumBlocks: numBlocks,
FirstBlockHeight: firstBlockHeight, FirstBlockHeight: firstHeight,
}, },
Complete: 0, Complete: complete,
EncodingType: g.cfg.encodingType, EncodingType: g.cfg.encodingType,
ShortChanIDs: channelChunk, ShortChanIDs: channelChunk,
})
} }
if isFinalChunk {
replyChunk.Complete = 1 var (
firstHeight = query.FirstBlockHeight
lastHeight uint32
channelChunk []lnwire.ShortChannelID
)
for _, channelRange := range channelRanges {
channels := channelRange.Channels
numChannels := int32(len(channels))
numLeftToAdd := g.cfg.chunkSize - int32(len(channelChunk))
// Include the current block in the ongoing chunk if it can fit
// and move on to the next block.
if numChannels <= numLeftToAdd {
channelChunk = append(channelChunk, channels...)
continue
} }
if err := g.cfg.sendToPeerSync(&replyChunk); err != nil {
// Otherwise, we need to send our existing channel chunk as is
// as its own reply and start a new one for the current block.
// We'll mark the end of our current chunk as the height before
// the current block to ensure the whole query range is replied
// to.
log.Infof("GossipSyncer(%x): sending range chunk of size=%v",
g.cfg.peerPub[:], len(channelChunk))
lastHeight = channelRange.Height - 1
err := sendReplyForChunk(
channelChunk, firstHeight, lastHeight, false,
)
if err != nil {
return err return err
} }
// If this was the final chunk, then we'll exit now as our // With the reply constructed, we'll start tallying channels for
// response is now complete. // our next one keeping in mind our chunk size. This may result
if isFinalChunk { // in channels for this block being left out from the reply, but
return nil // this isn't an issue since we'll randomly shuffle them and we
// assume a historical gossip sync is performed at a later time.
firstHeight = channelRange.Height
chunkSize := numChannels
exceedsChunkSize := numChannels > g.cfg.chunkSize
if exceedsChunkSize {
rand.Shuffle(len(channels), func(i, j int) {
channels[i], channels[j] = channels[j], channels[i]
})
chunkSize = g.cfg.chunkSize
}
channelChunk = channels[:chunkSize]
// Sort the chunk once again if we had to shuffle it.
if exceedsChunkSize {
sort.Slice(channelChunk, func(i, j int) bool {
return channelChunk[i].ToUint64() <
channelChunk[j].ToUint64()
})
}
} }
numChansSent += int32(len(channelChunk)) // Send the remaining chunk as the final reply.
} log.Infof("GossipSyncer(%x): sending final chan range chunk, size=%v",
g.cfg.peerPub[:], len(channelChunk))
return sendReplyForChunk(
channelChunk, firstHeight, query.LastBlockHeight(), true,
)
} }
// replyShortChanIDs will be dispatched in response to a query by the remote // replyShortChanIDs will be dispatched in response to a query by the remote
@ -1285,11 +1312,23 @@ func (g *GossipSyncer) FilterGossipMsgs(msgs ...msgWithSenders) {
// ProcessQueryMsg is used by outside callers to pass new channel time series // ProcessQueryMsg is used by outside callers to pass new channel time series
// queries to the internal processing goroutine. // queries to the internal processing goroutine.
func (g *GossipSyncer) ProcessQueryMsg(msg lnwire.Message, peerQuit <-chan struct{}) { func (g *GossipSyncer) ProcessQueryMsg(msg lnwire.Message, peerQuit <-chan struct{}) error {
var msgChan chan lnwire.Message var msgChan chan lnwire.Message
switch msg.(type) { switch msg.(type) {
case *lnwire.QueryChannelRange, *lnwire.QueryShortChanIDs: case *lnwire.QueryChannelRange, *lnwire.QueryShortChanIDs:
msgChan = g.queryMsgs msgChan = g.queryMsgs
// Reply messages should only be expected in states where we're waiting
// for a reply.
case *lnwire.ReplyChannelRange, *lnwire.ReplyShortChanIDsEnd:
syncState := g.syncState()
if syncState != waitingQueryRangeReply &&
syncState != waitingQueryChanReply {
return fmt.Errorf("received unexpected query reply "+
"message %T", msg)
}
msgChan = g.gossipMsgs
default: default:
msgChan = g.gossipMsgs msgChan = g.gossipMsgs
} }
@ -1299,6 +1338,8 @@ func (g *GossipSyncer) ProcessQueryMsg(msg lnwire.Message, peerQuit <-chan struc
case <-peerQuit: case <-peerQuit:
case <-g.quit: case <-g.quit:
} }
return nil
} }
// setSyncState sets the gossip syncer's state to the given state. // setSyncState sets the gossip syncer's state to the given state.

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"math" "math"
"reflect" "reflect"
"sort"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -12,7 +13,9 @@ import (
"github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/stretchr/testify/require"
) )
const ( const (
@ -95,11 +98,36 @@ func (m *mockChannelGraphTimeSeries) FilterKnownChanIDs(chain chainhash.Hash,
return <-m.filterResp, nil return <-m.filterResp, nil
} }
func (m *mockChannelGraphTimeSeries) FilterChannelRange(chain chainhash.Hash, func (m *mockChannelGraphTimeSeries) FilterChannelRange(chain chainhash.Hash,
startHeight, endHeight uint32) ([]lnwire.ShortChannelID, error) { startHeight, endHeight uint32) ([]channeldb.BlockChannelRange, error) {
m.filterRangeReqs <- filterRangeReq{startHeight, endHeight} m.filterRangeReqs <- filterRangeReq{startHeight, endHeight}
reply := <-m.filterRangeResp
return <-m.filterRangeResp, nil channelsPerBlock := make(map[uint32][]lnwire.ShortChannelID)
for _, cid := range reply {
channelsPerBlock[cid.BlockHeight] = append(
channelsPerBlock[cid.BlockHeight], cid,
)
}
// Return the channel ranges in ascending block height order.
blocks := make([]uint32, 0, len(channelsPerBlock))
for block := range channelsPerBlock {
blocks = append(blocks, block)
}
sort.Slice(blocks, func(i, j int) bool {
return blocks[i] < blocks[j]
})
channelRanges := make([]channeldb.BlockChannelRange, 0, len(channelsPerBlock))
for _, block := range blocks {
channelRanges = append(channelRanges, channeldb.BlockChannelRange{
Height: block,
Channels: channelsPerBlock[block],
})
}
return channelRanges, nil
} }
func (m *mockChannelGraphTimeSeries) FetchChanAnns(chain chainhash.Hash, func (m *mockChannelGraphTimeSeries) FetchChanAnns(chain chainhash.Hash,
shortChanIDs []lnwire.ShortChannelID) ([]lnwire.Message, error) { shortChanIDs []lnwire.ShortChannelID) ([]lnwire.Message, error) {
@ -161,6 +189,7 @@ func newTestSyncer(hID lnwire.ShortChannelID,
bestHeight: func() uint32 { bestHeight: func() uint32 {
return latestKnownHeight return latestKnownHeight
}, },
maxQueryChanRangeReplies: maxQueryChanRangeReplies,
} }
syncer := newGossipSyncer(cfg) syncer := newGossipSyncer(cfg)
@ -828,6 +857,7 @@ func TestGossipSyncerReplyChanRangeQuery(t *testing.T) {
// reply. We should get three sets of messages as two of them should be // reply. We should get three sets of messages as two of them should be
// full, while the other is the final fragment. // full, while the other is the final fragment.
const numExpectedChunks = 3 const numExpectedChunks = 3
var prevResp *lnwire.ReplyChannelRange
respMsgs := make([]lnwire.ShortChannelID, 0, 5) respMsgs := make([]lnwire.ShortChannelID, 0, 5)
for i := 0; i < numExpectedChunks; i++ { for i := 0; i < numExpectedChunks; i++ {
select { select {
@ -855,14 +885,14 @@ func TestGossipSyncerReplyChanRangeQuery(t *testing.T) {
// channels. // channels.
case i == 0: case i == 0:
expectedFirstBlockHeight = startingBlockHeight expectedFirstBlockHeight = startingBlockHeight
expectedNumBlocks = chunkSize + 1 expectedNumBlocks = 4
// The last reply should range starting from the next // The last reply should range starting from the next
// block of our previous reply up until the ending // block of our previous reply up until the ending
// height of the query. It should also have the Complete // height of the query. It should also have the Complete
// bit set. // bit set.
case i == numExpectedChunks-1: case i == numExpectedChunks-1:
expectedFirstBlockHeight = respMsgs[len(respMsgs)-1].BlockHeight expectedFirstBlockHeight = prevResp.LastBlockHeight() + 1
expectedNumBlocks = endingBlockHeight - expectedFirstBlockHeight + 1 expectedNumBlocks = endingBlockHeight - expectedFirstBlockHeight + 1
expectedComplete = 1 expectedComplete = 1
@ -870,8 +900,8 @@ func TestGossipSyncerReplyChanRangeQuery(t *testing.T) {
// the next block of our previous reply up until it // the next block of our previous reply up until it
// reaches its maximum capacity of channels. // reaches its maximum capacity of channels.
default: default:
expectedFirstBlockHeight = respMsgs[len(respMsgs)-1].BlockHeight expectedFirstBlockHeight = prevResp.LastBlockHeight() + 1
expectedNumBlocks = 5 expectedNumBlocks = 4
} }
switch { switch {
@ -889,9 +919,10 @@ func TestGossipSyncerReplyChanRangeQuery(t *testing.T) {
case rangeResp.Complete != expectedComplete: case rangeResp.Complete != expectedComplete:
t.Fatalf("Complete in resp #%d incorrect: "+ t.Fatalf("Complete in resp #%d incorrect: "+
"expected %v, got %v", i+1, "expected %v, got %v", i+1,
expectedNumBlocks, rangeResp.Complete) expectedComplete, rangeResp.Complete)
} }
prevResp = rangeResp
respMsgs = append(respMsgs, rangeResp.ShortChanIDs...) respMsgs = append(respMsgs, rangeResp.ShortChanIDs...)
} }
} }
@ -1498,10 +1529,12 @@ func TestGossipSyncerDelayDOS(t *testing.T) {
// inherently disjoint. // inherently disjoint.
var syncer2Chans []lnwire.ShortChannelID var syncer2Chans []lnwire.ShortChannelID
for i := 0; i < numTotalChans; i++ { for i := 0; i < numTotalChans; i++ {
syncer2Chans = append(syncer2Chans, lnwire.ShortChannelID{ syncer2Chans = append([]lnwire.ShortChannelID{
BlockHeight: highestID.BlockHeight - 1, {
BlockHeight: highestID.BlockHeight - uint32(i) - 1,
TxIndex: uint32(i), TxIndex: uint32(i),
}) },
}, syncer2Chans...)
} }
// We'll kick off the test by asserting syncer1 sends over the // We'll kick off the test by asserting syncer1 sends over the
@ -2305,3 +2338,80 @@ func TestGossipSyncerSyncedSignal(t *testing.T) {
t.Fatal("expected to receive chansSynced signal") t.Fatal("expected to receive chansSynced signal")
} }
} }
// TestGossipSyncerMaxChannelRangeReplies ensures that a gossip syncer
// transitions its state after receiving the maximum possible number of replies
// for a single QueryChannelRange message, and that any further replies after
// said limit are not processed.
func TestGossipSyncerMaxChannelRangeReplies(t *testing.T) {
t.Parallel()
msgChan, syncer, chanSeries := newTestSyncer(
lnwire.ShortChannelID{BlockHeight: latestKnownHeight},
defaultEncoding, defaultChunkSize,
)
// We'll tune the maxQueryChanRangeReplies to a more sensible value for
// the sake of testing.
syncer.cfg.maxQueryChanRangeReplies = 100
syncer.Start()
defer syncer.Stop()
// Upon initialization, the syncer should submit a QueryChannelRange
// request.
var query *lnwire.QueryChannelRange
select {
case msgs := <-msgChan:
require.Len(t, msgs, 1)
require.IsType(t, &lnwire.QueryChannelRange{}, msgs[0])
query = msgs[0].(*lnwire.QueryChannelRange)
case <-time.After(time.Second):
t.Fatal("expected query channel range request msg")
}
// We'll send the maximum number of replies allowed to a
// QueryChannelRange request with each reply consuming only one block in
// order to transition the syncer's state.
for i := uint32(0); i < syncer.cfg.maxQueryChanRangeReplies; i++ {
reply := &lnwire.ReplyChannelRange{
QueryChannelRange: *query,
ShortChanIDs: []lnwire.ShortChannelID{
{
BlockHeight: query.FirstBlockHeight + i,
},
},
}
reply.FirstBlockHeight = query.FirstBlockHeight + i
reply.NumBlocks = 1
require.NoError(t, syncer.ProcessQueryMsg(reply, nil))
}
// We should receive a filter request for the syncer's local channels
// after processing all of the replies. We'll send back a nil response
// indicating that no new channels need to be synced, so it should
// transition to its final chansSynced state.
select {
case <-chanSeries.filterReq:
case <-time.After(time.Second):
t.Fatal("expected local filter request of known channels")
}
select {
case chanSeries.filterResp <- nil:
case <-time.After(time.Second):
t.Fatal("timed out sending filter response")
}
assertSyncerStatus(t, syncer, chansSynced, ActiveSync)
// Finally, attempting to process another reply for the same query
// should result in an error.
require.Error(t, syncer.ProcessQueryMsg(&lnwire.ReplyChannelRange{
QueryChannelRange: *query,
ShortChanIDs: []lnwire.ShortChannelID{
{
BlockHeight: query.LastBlockHeight() + 1,
},
},
}, nil))
}