diff --git a/channeldb/graph.go b/channeldb/graph.go index 2333bc94..2eaf08ed 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -10,6 +10,7 @@ import ( "io" "math" "net" + "sort" "sync" "time" @@ -1704,12 +1705,25 @@ func (c *ChannelGraph) FilterKnownChanIDs(chanIDs []uint64) ([]uint64, error) { return newChanIDs, nil } +// BlockChannelRange represents a range of channels for a given block height. +type BlockChannelRange struct { + // Height is the height of the block all of the channels below were + // included in. + Height uint32 + + // Channels is the list of channels identified by their short ID + // representation known to us that were included in the block height + // above. + Channels []lnwire.ShortChannelID +} + // FilterChannelRange returns the channel ID's of all known channels which were -// mined in a block height within the passed range. This method can be used to -// quickly share with a peer the set of channels we know of within a particular -// range to catch them up after a period of time offline. -func (c *ChannelGraph) FilterChannelRange(startHeight, endHeight uint32) ([]uint64, error) { - var chanIDs []uint64 +// mined in a block height within the passed range. The channel IDs are grouped +// by their common block height. This method can be used to quickly share with a +// peer the set of channels we know of within a particular range to catch them +// up after a period of time offline. +func (c *ChannelGraph) FilterChannelRange(startHeight, + endHeight uint32) ([]BlockChannelRange, error) { startChanID := &lnwire.ShortChannelID{ BlockHeight: startHeight, @@ -1728,6 +1742,7 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, endHeight uint32) ([]uint byteOrder.PutUint64(chanIDStart[:], startChanID.ToUint64()) byteOrder.PutUint64(chanIDEnd[:], endChanID.ToUint64()) + var channelsPerBlock map[uint32][]lnwire.ShortChannelID err := kvdb.View(c.db, func(tx kvdb.RTx) error { edges := tx.ReadBucket(edgeBucket) if edges == nil { @@ -1742,33 +1757,51 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, endHeight uint32) ([]uint // We'll now iterate through the database, and find each // channel ID that resides within the specified range. - var cid uint64 for k, _ := cursor.Seek(chanIDStart[:]); k != nil && bytes.Compare(k, chanIDEnd[:]) <= 0; k, _ = cursor.Next() { // This channel ID rests within the target range, so - // we'll convert it into an integer and add it to our - // returned set. - cid = byteOrder.Uint64(k) - chanIDs = append(chanIDs, cid) + // we'll add it to our returned set. + rawCid := byteOrder.Uint64(k) + cid := lnwire.NewShortChanIDFromInt(rawCid) + channelsPerBlock[cid.BlockHeight] = append( + channelsPerBlock[cid.BlockHeight], cid, + ) } return nil }, func() { - chanIDs = nil + channelsPerBlock = make(map[uint32][]lnwire.ShortChannelID) }) switch { // If we don't know of any channels yet, then there's nothing to // filter, so we'll return an empty slice. - case err == ErrGraphNoEdgesFound: - return chanIDs, nil + case err == ErrGraphNoEdgesFound || len(channelsPerBlock) == 0: + return nil, nil case err != nil: return nil, err } - return chanIDs, nil + // Return the channel ranges in ascending block height order. 
+	blocks := make([]uint32, 0, len(channelsPerBlock))
+	for block := range channelsPerBlock {
+		blocks = append(blocks, block)
+	}
+	sort.Slice(blocks, func(i, j int) bool {
+		return blocks[i] < blocks[j]
+	})
+
+	channelRanges := make([]BlockChannelRange, 0, len(channelsPerBlock))
+	for _, block := range blocks {
+		channelRanges = append(channelRanges, BlockChannelRange{
+			Height:   block,
+			Channels: channelsPerBlock[block],
+		})
+	}
+
+	return channelRanges, nil
 }
 
 // FetchChanInfos returns the set of channel edges that correspond to the passed
diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go
index 2abdcc8e..331d1769 100644
--- a/channeldb/graph_test.go
+++ b/channeldb/graph_test.go
@@ -1848,24 +1848,32 @@ func TestFilterChannelRange(t *testing.T) {
 		t.Fatalf("expected zero chans, instead got %v", len(resp))
 	}
 
-	// To start, we'll create a set of channels, each mined in a block 10
+	// To start, we'll create a set of channels, two per block, each block 10
 	// blocks after the prior one.
 	startHeight := uint32(100)
 	endHeight := startHeight
 	const numChans = 10
-	chanIDs := make([]uint64, 0, numChans)
-	for i := 0; i < numChans; i++ {
+	channelRanges := make([]BlockChannelRange, 0, numChans/2)
+	for i := 0; i < numChans/2; i++ {
 		chanHeight := endHeight
-		channel, chanID := createEdge(
-			uint32(chanHeight), uint32(i+1), 0, 0, node1, node2,
+		channel1, chanID1 := createEdge(
+			chanHeight, uint32(i+1), 0, 0, node1, node2,
 		)
-
-		if err := graph.AddChannelEdge(&channel); err != nil {
+		if err := graph.AddChannelEdge(&channel1); err != nil {
 			t.Fatalf("unable to create channel edge: %v", err)
 		}
 
-		chanIDs = append(chanIDs, chanID.ToUint64())
+		channel2, chanID2 := createEdge(
+			chanHeight, uint32(i+2), 0, 0, node1, node2,
+		)
+		if err := graph.AddChannelEdge(&channel2); err != nil {
+			t.Fatalf("unable to create channel edge: %v", err)
+		}
+
+		channelRanges = append(channelRanges, BlockChannelRange{
+			Height:   chanHeight,
+			Channels: []lnwire.ShortChannelID{chanID1, chanID2},
+		})
 
 		endHeight += 10
 	}
@@ -1876,7 +1884,7 @@ func TestFilterChannelRange(t *testing.T) {
 		startHeight uint32
 		endHeight   uint32
 
-		resp []uint64
+		resp []BlockChannelRange
 	}{
 		// If we query for the entire range, then we should get the same
 		// set of short channel IDs back.
 		{
 			startHeight: startHeight,
 			endHeight:   endHeight,
 
-			resp: chanIDs,
+			resp: channelRanges,
 		},
 
 		// If we query for just the last height, we should only get a
 		// single block range back (the last one).
 		{
 			startHeight: endHeight - 10,
 			endHeight:   endHeight - 10,
 
-			resp: chanIDs[9:],
+			resp: channelRanges[4:],
 		},
 
 		// If we query for just the first height, we should only get a
 		// single block range back (the first one).
 		{
 			startHeight: startHeight,
 			endHeight:   startHeight,
 
-			resp: chanIDs[:1],
+			resp: channelRanges[:1],
 		},
+
+		{
+			startHeight: startHeight + 10,
+			endHeight:   endHeight - 10,
+
+			resp: channelRanges[1:5],
+		},
 	}
 	for i, queryCase := range queryCases {
diff --git a/discovery/chan_series.go b/discovery/chan_series.go
index ffb59b4e..42ebe888 100644
--- a/discovery/chan_series.go
+++ b/discovery/chan_series.go
@@ -39,10 +39,11 @@ type ChannelGraphTimeSeries interface {
 		superSet []lnwire.ShortChannelID) ([]lnwire.ShortChannelID, error)
 
 	// FilterChannelRange returns the set of channels that we created
-	// between the start height and the end height. We'll use this to to a
-	// remote peer's QueryChannelRange message.
+	// between the start height and the end height. The channel IDs are
+	// grouped by their common block height. We'll use this to respond to a
+	// remote peer's QueryChannelRange message.
 	FilterChannelRange(chain chainhash.Hash,
-		startHeight, endHeight uint32) ([]lnwire.ShortChannelID, error)
+		startHeight, endHeight uint32) ([]channeldb.BlockChannelRange, error)
 
 	// FetchChanAnns returns a full set of channel announcements as well as
 	// their updates that match the set of specified short channel ID's.
@@ -203,26 +204,15 @@ func (c *ChanSeries) FilterKnownChanIDs(chain chainhash.Hash,
 }
 
 // FilterChannelRange returns the set of channels that we created between the
-// start height and the end height. We'll use this respond to a remote peer's
-// QueryChannelRange message.
+// start height and the end height. The channel IDs are grouped by their common
+// block height. We'll use this to respond to a remote peer's QueryChannelRange
+// message.
 //
 // NOTE: This is part of the ChannelGraphTimeSeries interface.
 func (c *ChanSeries) FilterChannelRange(chain chainhash.Hash,
-	startHeight, endHeight uint32) ([]lnwire.ShortChannelID, error) {
+	startHeight, endHeight uint32) ([]channeldb.BlockChannelRange, error) {
 
-	chansInRange, err := c.graph.FilterChannelRange(startHeight, endHeight)
-	if err != nil {
-		return nil, err
-	}
-
-	chanResp := make([]lnwire.ShortChannelID, 0, len(chansInRange))
-	for _, chanID := range chansInRange {
-		chanResp = append(
-			chanResp, lnwire.NewShortChanIDFromInt(chanID),
-		)
-	}
-
-	return chanResp, nil
+	return c.graph.FilterChannelRange(startHeight, endHeight)
 }
 
 // FetchChanAnns returns a full set of channel announcements as well as their
diff --git a/discovery/sync_manager.go b/discovery/sync_manager.go
index d81a905e..a0e73b06 100644
--- a/discovery/sync_manager.go
+++ b/discovery/sync_manager.go
@@ -426,6 +426,7 @@ func (m *SyncManager) createGossipSyncer(peer lnpeer.Peer) *GossipSyncer {
 		maxUndelayedQueryReplies:  DefaultMaxUndelayedQueryReplies,
 		delayedQueryReplyInterval: DefaultDelayedQueryReplyInterval,
 		bestHeight:                m.cfg.BestHeight,
+		maxQueryChanRangeReplies:  maxQueryChanRangeReplies,
 	})
 
 	// Gossip syncers are initialized by default in a PassiveSync type
diff --git a/discovery/sync_manager_test.go b/discovery/sync_manager_test.go
index d3c0ba46..ac721868 100644
--- a/discovery/sync_manager_test.go
+++ b/discovery/sync_manager_test.go
@@ -11,6 +11,7 @@ import (
 	"github.com/lightningnetwork/lnd/lntest/wait"
 	"github.com/lightningnetwork/lnd/lnwire"
 	"github.com/lightningnetwork/lnd/ticker"
+	"github.com/stretchr/testify/require"
 )
 
 // randPeer creates a random peer.
@@ -537,10 +538,14 @@ func assertTransitionToChansSynced(t *testing.T, s *GossipSyncer, peer *mockPeer
 	}
 	assertMsgSent(t, peer, query)
 
-	s.ProcessQueryMsg(&lnwire.ReplyChannelRange{
+	require.Eventually(t, func() bool {
+		return s.syncState() == waitingQueryRangeReply
+	}, time.Second, 500*time.Millisecond)
+
+	require.NoError(t, s.ProcessQueryMsg(&lnwire.ReplyChannelRange{
 		QueryChannelRange: *query,
 		Complete:          1,
-	}, nil)
+	}, nil))
 
 	chanSeries := s.cfg.channelSeries.(*mockChannelGraphTimeSeries)
diff --git a/discovery/syncer.go b/discovery/syncer.go
index 4f6c4256..04a722f2 100644
--- a/discovery/syncer.go
+++ b/discovery/syncer.go
@@ -4,6 +4,8 @@ import (
 	"errors"
 	"fmt"
 	"math"
+	"math/rand"
+	"sort"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -128,6 +130,14 @@ const (
 	// maxUndelayedQueryReplies queries.
 	DefaultDelayedQueryReplyInterval = 5 * time.Second
 
+	// maxQueryChanRangeReplies specifies the default limit of replies to
+	// process for a single QueryChannelRange request.
+	maxQueryChanRangeReplies = 500
+
+	// maxQueryChanRangeRepliesZlibFactor specifies the factor applied to
+	// the maximum number of replies allowed for zlib encoded replies.
+	maxQueryChanRangeRepliesZlibFactor = 4
+
 	// chanRangeQueryBuffer is the number of blocks back that we'll go when
 	// asking the remote peer for their any channels they know of beyond
 	// our highest known channel ID.
@@ -240,6 +250,10 @@ type gossipSyncerCfg struct {
 
 	// bestHeight returns the latest height known of the chain.
 	bestHeight func() uint32
+
+	// maxQueryChanRangeReplies is the maximum number of replies we'll allow
+	// for a single QueryChannelRange request.
+	maxQueryChanRangeReplies uint32
 }
 
 // GossipSyncer is a struct that handles synchronizing the channel graph state
@@ -316,6 +330,11 @@ type GossipSyncer struct {
 	// buffer all the chunked response to our query.
 	bufferedChanRangeReplies []lnwire.ShortChannelID
 
+	// numChanRangeRepliesRcvd is used to track the number of replies
+	// received as part of a QueryChannelRange. This field is primarily used
+	// within the waitingQueryRangeReply state.
+	numChanRangeRepliesRcvd uint32
+
 	// newChansToQuery is used to pass the set of channels we should query
 	// for from the waitingQueryChanReply state to the queryNewChannels
 	// state.
@@ -741,17 +760,27 @@ func (g *GossipSyncer) processChanRangeReply(msg *lnwire.ReplyChannelRange) erro
 	g.bufferedChanRangeReplies = append(
 		g.bufferedChanRangeReplies, msg.ShortChanIDs...,
 	)
+	switch g.cfg.encodingType {
+	case lnwire.EncodingSortedPlain:
+		g.numChanRangeRepliesRcvd++
+	case lnwire.EncodingSortedZlib:
+		g.numChanRangeRepliesRcvd += maxQueryChanRangeRepliesZlibFactor
+	default:
+		return fmt.Errorf("unhandled encoding type %v", g.cfg.encodingType)
+	}
 
 	log.Infof("GossipSyncer(%x): buffering chan range reply of size=%v",
 		g.cfg.peerPub[:], len(msg.ShortChanIDs))
 
-	// If this isn't the last response, then we can exit as we've already
-	// buffered the latest portion of the streaming reply.
+	// If this isn't the last response and we can continue to receive more,
+	// then we can exit as we've already buffered the latest portion of the
+	// streaming reply.
+	maxReplies := g.cfg.maxQueryChanRangeReplies
 	switch {
 	// If we're communicating with a legacy node, we'll need to look at the
 	// complete field.
 	case isLegacyReplyChannelRange(g.curQueryRangeMsg, msg):
-		if msg.Complete == 0 {
+		if msg.Complete == 0 && g.numChanRangeRepliesRcvd < maxReplies {
 			return nil
 		}
 
@@ -763,7 +792,8 @@ func (g *GossipSyncer) processChanRangeReply(msg *lnwire.ReplyChannelRange) erro
 		// TODO(wilmer): This might require some padding if the remote
 		// node is not aware of the last height we sent them, i.e., is
 		// behind a few blocks from us.
-		if replyLastHeight < queryLastHeight {
+		if replyLastHeight < queryLastHeight &&
+			g.numChanRangeRepliesRcvd < maxReplies {
 			return nil
 		}
 	}
@@ -786,6 +816,7 @@ func (g *GossipSyncer) processChanRangeReply(msg *lnwire.ReplyChannelRange) erro
 	g.curQueryRangeMsg = nil
 	g.prevReplyChannelRange = nil
 	g.bufferedChanRangeReplies = nil
+	g.numChanRangeRepliesRcvd = 0
 
 	// If there aren't any channels that we don't know of, then we can
 	// switch straight to our terminal state.
@@ -930,7 +961,7 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro
 	// channel ID's that match their query.
 	startBlock := query.FirstBlockHeight
 	endBlock := query.LastBlockHeight()
-	channelRange, err := g.cfg.channelSeries.FilterChannelRange(
+	channelRanges, err := g.cfg.channelSeries.FilterChannelRange(
 		query.ChainHash, startBlock, endBlock,
 	)
 	if err != nil {
 		return err
 	}
@@ -940,102 +971,98 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro
 
 	// TODO(roasbeef): means can't send max uint above?
 	//  * or make internal 64
 
-	// In the base case (no actual response) the first block and last block
-	// will match those of the query. In the loop below, we'll update these
-	// two variables incrementally with each chunk to properly compute the
-	// starting block for each response and the number of blocks in a
-	// response.
-	firstBlockHeight := startBlock
-	lastBlockHeight := endBlock
+	// We'll send our response in a streaming manner, chunk-by-chunk. We do
+	// this as there's a transport message size limit which we'll need to
+	// adhere to. We also need to make sure all of our replies cover the
+	// expected range of the query.
+	sendReplyForChunk := func(channelChunk []lnwire.ShortChannelID,
+		firstHeight, lastHeight uint32, finalChunk bool) error {
 
-	numChannels := int32(len(channelRange))
-	numChansSent := int32(0)
-	for {
-		// We'll send our this response in a streaming manner,
-		// chunk-by-chunk. We do this as there's a transport message
-		// size limit which we'll need to adhere to.
-		var channelChunk []lnwire.ShortChannelID
-
-		// We know this is the final chunk, if the difference between
-		// the total number of channels, and the number of channels
-		// we've sent is less-than-or-equal to the chunk size.
-		isFinalChunk := (numChannels - numChansSent) <= g.cfg.chunkSize
-
-		// If this is indeed the last chunk, then we'll send the
-		// remainder of the channels.
-		if isFinalChunk {
-			channelChunk = channelRange[numChansSent:]
-
-			log.Infof("GossipSyncer(%x): sending final chan "+
-				"range chunk, size=%v", g.cfg.peerPub[:],
-				len(channelChunk))
-		} else {
-			// Otherwise, we'll only send off a fragment exactly
-			// sized to the proper chunk size.
-			channelChunk = channelRange[numChansSent : numChansSent+g.cfg.chunkSize]
-
-			log.Infof("GossipSyncer(%x): sending range chunk of "+
-				"size=%v", g.cfg.peerPub[:], len(channelChunk))
-		}
-
-		// If we have any channels at all to return, then we need to
-		// update our pointers to the first and last blocks for each
-		// response.
-		if len(channelChunk) > 0 {
-			// If this is the first response we'll send, we'll point
-			// the first block to the first block in the query.
-			// Otherwise, we'll continue from the block we left off
-			// at.
-			if numChansSent == 0 {
-				firstBlockHeight = startBlock
-			} else {
-				firstBlockHeight = lastBlockHeight
-			}
-
-			// If this is the last response we'll send, we'll point
-			// the last block to the last block of the query.
-			// Otherwise, we'll set it to the height of the last
-			// channel in the chunk.
-			if isFinalChunk {
-				lastBlockHeight = endBlock
-			} else {
-				lastBlockHeight = channelChunk[len(channelChunk)-1].BlockHeight
-			}
-		}
-
-		// The number of blocks contained in this response (the total
-		// span) is the difference between the last channel ID and the
-		// first in the range. We add one as even if all channels
+		// The number of blocks contained in the current chunk (the
+		// total span) is the difference between the last and first
+		// heights in the range. We add one as even if all channels
 		// returned are in the same block, we need to count that.
-		numBlocksInResp := lastBlockHeight - firstBlockHeight + 1
+		numBlocks := lastHeight - firstHeight + 1
+		complete := uint8(0)
+		if finalChunk {
+			complete = 1
+		}
 
-		// With our chunk assembled, we'll now send to the remote peer
-		// the current chunk.
-		replyChunk := lnwire.ReplyChannelRange{
+		return g.cfg.sendToPeerSync(&lnwire.ReplyChannelRange{
 			QueryChannelRange: lnwire.QueryChannelRange{
 				ChainHash:        query.ChainHash,
-				NumBlocks:        numBlocksInResp,
-				FirstBlockHeight: firstBlockHeight,
+				NumBlocks:        numBlocks,
+				FirstBlockHeight: firstHeight,
 			},
-			Complete:     0,
+			Complete:     complete,
 			EncodingType: g.cfg.encodingType,
 			ShortChanIDs: channelChunk,
+		})
+	}
+
+	var (
+		firstHeight  = query.FirstBlockHeight
+		lastHeight   uint32
+		channelChunk []lnwire.ShortChannelID
+	)
+	for _, channelRange := range channelRanges {
+		channels := channelRange.Channels
+		numChannels := int32(len(channels))
+		numLeftToAdd := g.cfg.chunkSize - int32(len(channelChunk))
+
+		// Include the current block in the ongoing chunk if it can fit
+		// and move on to the next block.
+		if numChannels <= numLeftToAdd {
+			channelChunk = append(channelChunk, channels...)
+			continue
 		}
-		if isFinalChunk {
-			replyChunk.Complete = 1
-		}
-		if err := g.cfg.sendToPeerSync(&replyChunk); err != nil {
+
+		// Otherwise, we need to send our existing channel chunk as is,
+		// as its own reply, and start a new one for the current block.
+		// We'll mark the end of our current chunk as the height before
+		// the current block to ensure the whole query range is replied
+		// to.
+		log.Infof("GossipSyncer(%x): sending range chunk of size=%v",
+			g.cfg.peerPub[:], len(channelChunk))
+		lastHeight = channelRange.Height - 1
+		err := sendReplyForChunk(
+			channelChunk, firstHeight, lastHeight, false,
+		)
+		if err != nil {
 			return err
 		}
 
-		// If this was the final chunk, then we'll exit now as our
-		// response is now complete.
-		if isFinalChunk {
-			return nil
+		// With the reply constructed, we'll start tallying channels for
+		// our next one, keeping in mind our chunk size. This may result
+		// in channels for this block being left out of the reply, but
+		// this isn't an issue since we'll randomly shuffle them and we
+		// assume a historical gossip sync is performed at a later time.
+		firstHeight = channelRange.Height
+		chunkSize := numChannels
+		exceedsChunkSize := numChannels > g.cfg.chunkSize
+		if exceedsChunkSize {
+			rand.Shuffle(len(channels), func(i, j int) {
+				channels[i], channels[j] = channels[j], channels[i]
+			})
+			chunkSize = g.cfg.chunkSize
 		}
+		channelChunk = channels[:chunkSize]
 
-		numChansSent += int32(len(channelChunk))
+		// Sort the chunk once again if we had to shuffle it.
+		if exceedsChunkSize {
+			sort.Slice(channelChunk, func(i, j int) bool {
+				return channelChunk[i].ToUint64() <
+					channelChunk[j].ToUint64()
+			})
+		}
 	}
+
+	// Send the remaining chunk as the final reply.
+	log.Infof("GossipSyncer(%x): sending final chan range chunk, size=%v",
+		g.cfg.peerPub[:], len(channelChunk))
+	return sendReplyForChunk(
+		channelChunk, firstHeight, query.LastBlockHeight(), true,
+	)
 }
 
 // replyShortChanIDs will be dispatched in response to a query by the remote
@@ -1285,11 +1312,23 @@ func (g *GossipSyncer) FilterGossipMsgs(msgs ...msgWithSenders) {
 // ProcessQueryMsg is used by outside callers to pass new channel time series
 // queries to the internal processing goroutine.
-func (g *GossipSyncer) ProcessQueryMsg(msg lnwire.Message, peerQuit <-chan struct{}) { +func (g *GossipSyncer) ProcessQueryMsg(msg lnwire.Message, peerQuit <-chan struct{}) error { var msgChan chan lnwire.Message switch msg.(type) { case *lnwire.QueryChannelRange, *lnwire.QueryShortChanIDs: msgChan = g.queryMsgs + + // Reply messages should only be expected in states where we're waiting + // for a reply. + case *lnwire.ReplyChannelRange, *lnwire.ReplyShortChanIDsEnd: + syncState := g.syncState() + if syncState != waitingQueryRangeReply && + syncState != waitingQueryChanReply { + return fmt.Errorf("received unexpected query reply "+ + "message %T", msg) + } + msgChan = g.gossipMsgs + default: msgChan = g.gossipMsgs } @@ -1299,6 +1338,8 @@ func (g *GossipSyncer) ProcessQueryMsg(msg lnwire.Message, peerQuit <-chan struc case <-peerQuit: case <-g.quit: } + + return nil } // setSyncState sets the gossip syncer's state to the given state. diff --git a/discovery/syncer_test.go b/discovery/syncer_test.go index 6d687bf9..d9da9382 100644 --- a/discovery/syncer_test.go +++ b/discovery/syncer_test.go @@ -5,6 +5,7 @@ import ( "fmt" "math" "reflect" + "sort" "sync" "testing" "time" @@ -12,7 +13,9 @@ import ( "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/require" ) const ( @@ -95,11 +98,36 @@ func (m *mockChannelGraphTimeSeries) FilterKnownChanIDs(chain chainhash.Hash, return <-m.filterResp, nil } func (m *mockChannelGraphTimeSeries) FilterChannelRange(chain chainhash.Hash, - startHeight, endHeight uint32) ([]lnwire.ShortChannelID, error) { + startHeight, endHeight uint32) ([]channeldb.BlockChannelRange, error) { m.filterRangeReqs <- filterRangeReq{startHeight, endHeight} + reply := <-m.filterRangeResp - return <-m.filterRangeResp, nil + channelsPerBlock := make(map[uint32][]lnwire.ShortChannelID) + for _, cid := range reply { + channelsPerBlock[cid.BlockHeight] = append( + channelsPerBlock[cid.BlockHeight], cid, + ) + } + + // Return the channel ranges in ascending block height order. + blocks := make([]uint32, 0, len(channelsPerBlock)) + for block := range channelsPerBlock { + blocks = append(blocks, block) + } + sort.Slice(blocks, func(i, j int) bool { + return blocks[i] < blocks[j] + }) + + channelRanges := make([]channeldb.BlockChannelRange, 0, len(channelsPerBlock)) + for _, block := range blocks { + channelRanges = append(channelRanges, channeldb.BlockChannelRange{ + Height: block, + Channels: channelsPerBlock[block], + }) + } + + return channelRanges, nil } func (m *mockChannelGraphTimeSeries) FetchChanAnns(chain chainhash.Hash, shortChanIDs []lnwire.ShortChannelID) ([]lnwire.Message, error) { @@ -161,6 +189,7 @@ func newTestSyncer(hID lnwire.ShortChannelID, bestHeight: func() uint32 { return latestKnownHeight }, + maxQueryChanRangeReplies: maxQueryChanRangeReplies, } syncer := newGossipSyncer(cfg) @@ -828,6 +857,7 @@ func TestGossipSyncerReplyChanRangeQuery(t *testing.T) { // reply. We should get three sets of messages as two of them should be // full, while the other is the final fragment. const numExpectedChunks = 3 + var prevResp *lnwire.ReplyChannelRange respMsgs := make([]lnwire.ShortChannelID, 0, 5) for i := 0; i < numExpectedChunks; i++ { select { @@ -855,14 +885,14 @@ func TestGossipSyncerReplyChanRangeQuery(t *testing.T) { // channels. 
case i == 0: expectedFirstBlockHeight = startingBlockHeight - expectedNumBlocks = chunkSize + 1 + expectedNumBlocks = 4 // The last reply should range starting from the next // block of our previous reply up until the ending // height of the query. It should also have the Complete // bit set. case i == numExpectedChunks-1: - expectedFirstBlockHeight = respMsgs[len(respMsgs)-1].BlockHeight + expectedFirstBlockHeight = prevResp.LastBlockHeight() + 1 expectedNumBlocks = endingBlockHeight - expectedFirstBlockHeight + 1 expectedComplete = 1 @@ -870,8 +900,8 @@ func TestGossipSyncerReplyChanRangeQuery(t *testing.T) { // the next block of our previous reply up until it // reaches its maximum capacity of channels. default: - expectedFirstBlockHeight = respMsgs[len(respMsgs)-1].BlockHeight - expectedNumBlocks = 5 + expectedFirstBlockHeight = prevResp.LastBlockHeight() + 1 + expectedNumBlocks = 4 } switch { @@ -889,9 +919,10 @@ func TestGossipSyncerReplyChanRangeQuery(t *testing.T) { case rangeResp.Complete != expectedComplete: t.Fatalf("Complete in resp #%d incorrect: "+ "expected %v, got %v", i+1, - expectedNumBlocks, rangeResp.Complete) + expectedComplete, rangeResp.Complete) } + prevResp = rangeResp respMsgs = append(respMsgs, rangeResp.ShortChanIDs...) } } @@ -1498,10 +1529,12 @@ func TestGossipSyncerDelayDOS(t *testing.T) { // inherently disjoint. var syncer2Chans []lnwire.ShortChannelID for i := 0; i < numTotalChans; i++ { - syncer2Chans = append(syncer2Chans, lnwire.ShortChannelID{ - BlockHeight: highestID.BlockHeight - 1, - TxIndex: uint32(i), - }) + syncer2Chans = append([]lnwire.ShortChannelID{ + { + BlockHeight: highestID.BlockHeight - uint32(i) - 1, + TxIndex: uint32(i), + }, + }, syncer2Chans...) } // We'll kick off the test by asserting syncer1 sends over the @@ -2305,3 +2338,80 @@ func TestGossipSyncerSyncedSignal(t *testing.T) { t.Fatal("expected to receive chansSynced signal") } } + +// TestGossipSyncerMaxChannelRangeReplies ensures that a gossip syncer +// transitions its state after receiving the maximum possible number of replies +// for a single QueryChannelRange message, and that any further replies after +// said limit are not processed. +func TestGossipSyncerMaxChannelRangeReplies(t *testing.T) { + t.Parallel() + + msgChan, syncer, chanSeries := newTestSyncer( + lnwire.ShortChannelID{BlockHeight: latestKnownHeight}, + defaultEncoding, defaultChunkSize, + ) + + // We'll tune the maxQueryChanRangeReplies to a more sensible value for + // the sake of testing. + syncer.cfg.maxQueryChanRangeReplies = 100 + + syncer.Start() + defer syncer.Stop() + + // Upon initialization, the syncer should submit a QueryChannelRange + // request. + var query *lnwire.QueryChannelRange + select { + case msgs := <-msgChan: + require.Len(t, msgs, 1) + require.IsType(t, &lnwire.QueryChannelRange{}, msgs[0]) + query = msgs[0].(*lnwire.QueryChannelRange) + + case <-time.After(time.Second): + t.Fatal("expected query channel range request msg") + } + + // We'll send the maximum number of replies allowed to a + // QueryChannelRange request with each reply consuming only one block in + // order to transition the syncer's state. 
+ for i := uint32(0); i < syncer.cfg.maxQueryChanRangeReplies; i++ { + reply := &lnwire.ReplyChannelRange{ + QueryChannelRange: *query, + ShortChanIDs: []lnwire.ShortChannelID{ + { + BlockHeight: query.FirstBlockHeight + i, + }, + }, + } + reply.FirstBlockHeight = query.FirstBlockHeight + i + reply.NumBlocks = 1 + require.NoError(t, syncer.ProcessQueryMsg(reply, nil)) + } + + // We should receive a filter request for the syncer's local channels + // after processing all of the replies. We'll send back a nil response + // indicating that no new channels need to be synced, so it should + // transition to its final chansSynced state. + select { + case <-chanSeries.filterReq: + case <-time.After(time.Second): + t.Fatal("expected local filter request of known channels") + } + select { + case chanSeries.filterResp <- nil: + case <-time.After(time.Second): + t.Fatal("timed out sending filter response") + } + assertSyncerStatus(t, syncer, chansSynced, ActiveSync) + + // Finally, attempting to process another reply for the same query + // should result in an error. + require.Error(t, syncer.ProcessQueryMsg(&lnwire.ReplyChannelRange{ + QueryChannelRange: *query, + ShortChanIDs: []lnwire.ShortChannelID{ + { + BlockHeight: query.LastBlockHeight() + 1, + }, + }, + }, nil)) +}
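
To make the new splitting behavior in replyChanRangeQuery easier to verify, here is a minimal, self-contained sketch of the invariant the patch enforces: replies are only cut on block boundaries, each reply's span begins where the previous one ended, and the final reply always runs to the query's last height, so the replies tile the queried range exactly. The blockRange type and replySpans function below are invented stand-ins for this illustration (channeldb.BlockChannelRange and the in-syncer loop are the real counterparts); this is not code from the patch itself.

package main

import "fmt"

// blockRange loosely mirrors channeldb.BlockChannelRange: a block height and
// the number of channels known at that height.
type blockRange struct {
	height   uint32
	numChans int32
}

// replySpans returns one (firstHeight, numBlocks) pair per reply a syncer
// would send for the given query range, cutting chunks only on block
// boundaries just like the new replyChanRangeQuery loop.
func replySpans(queryFirst, queryLast uint32, ranges []blockRange,
	chunkSize int32) [][2]uint32 {

	var (
		spans       [][2]uint32
		firstHeight = queryFirst
		chunkLen    int32
	)
	for _, r := range ranges {
		// The whole block still fits into the ongoing chunk.
		if r.numChans <= chunkSize-chunkLen {
			chunkLen += r.numChans
			continue
		}

		// Otherwise, close the chunk at the height just before this
		// block: numBlocks = (r.height - 1) - firstHeight + 1.
		spans = append(spans, [2]uint32{
			firstHeight, r.height - firstHeight,
		})

		firstHeight = r.height
		chunkLen = r.numChans
		if chunkLen > chunkSize {
			// An oversized block is truncated to the chunk size,
			// mirroring the shuffle-and-truncate step in the diff.
			chunkLen = chunkSize
		}
	}

	// The final reply always extends to the last height of the query.
	spans = append(spans, [2]uint32{
		firstHeight, queryLast - firstHeight + 1,
	})

	return spans
}

func main() {
	// Five blocks at heights 100..140 with two channels each, a chunk
	// size of four channels, and a query spanning heights 100..150.
	spans := replySpans(100, 150, []blockRange{
		{100, 2}, {110, 2}, {120, 2}, {130, 2}, {140, 2},
	}, 4)

	// Prints spans covering heights [100-119], [120-139] and [140-150]:
	// every height in the query is covered by exactly one reply.
	for _, s := range spans {
		fmt.Printf("first=%d numBlocks=%d\n", s[0], s[1])
	}
}

The same arithmetic makes the new reply cap concrete: each plain-encoded reply consumes one unit of the budget while each zlib-encoded reply consumes maxQueryChanRangeRepliesZlibFactor (4) units, so with the default maxQueryChanRangeReplies of 500 a syncer stops waiting for further chunks after 500 plain or 125 zlib replies.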