diff --git a/discovery/sync_manager_test.go b/discovery/sync_manager_test.go index a6bcb707..f2113ff5 100644 --- a/discovery/sync_manager_test.go +++ b/discovery/sync_manager_test.go @@ -629,8 +629,10 @@ func assertTransitionToChansSynced(t *testing.T, s *GossipSyncer, peer *mockPeer }, time.Second, 500*time.Millisecond) require.NoError(t, s.ProcessQueryMsg(&lnwire.ReplyChannelRange{ - QueryChannelRange: *query, - Complete: 1, + ChainHash: query.ChainHash, + FirstBlockHeight: query.FirstBlockHeight, + NumBlocks: query.NumBlocks, + Complete: 1, }, nil)) chanSeries := s.cfg.channelSeries.(*mockChannelGraphTimeSeries) diff --git a/discovery/syncer.go b/discovery/syncer.go index 36031f31..de6821e3 100644 --- a/discovery/syncer.go +++ b/discovery/syncer.go @@ -753,7 +753,9 @@ func (g *GossipSyncer) synchronizeChanIDs() (bool, error) { func isLegacyReplyChannelRange(query *lnwire.QueryChannelRange, reply *lnwire.ReplyChannelRange) bool { - return reply.QueryChannelRange == *query + return (reply.ChainHash == query.ChainHash && + reply.FirstBlockHeight == query.FirstBlockHeight && + reply.NumBlocks == query.NumBlocks) } // processChanRangeReply is called each time the GossipSyncer receives a new @@ -773,7 +775,7 @@ func (g *GossipSyncer) processChanRangeReply(msg *lnwire.ReplyChannelRange) erro // The last block should also be. We don't need to check the // intermediate ones because they should already be in sorted // order. - replyLastHeight := msg.QueryChannelRange.LastBlockHeight() + replyLastHeight := msg.LastBlockHeight() queryLastHeight := g.curQueryRangeMsg.LastBlockHeight() if replyLastHeight > queryLastHeight { return fmt.Errorf("reply includes channels for height "+ @@ -832,7 +834,7 @@ func (g *GossipSyncer) processChanRangeReply(msg *lnwire.ReplyChannelRange) erro // Otherwise, we'll look at the reply's height range. default: - replyLastHeight := msg.QueryChannelRange.LastBlockHeight() + replyLastHeight := msg.LastBlockHeight() queryLastHeight := g.curQueryRangeMsg.LastBlockHeight() // TODO(wilmer): This might require some padding if the remote @@ -997,10 +999,12 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro g.cfg.chainHash) return g.cfg.sendToPeerSync(&lnwire.ReplyChannelRange{ - QueryChannelRange: *query, - Complete: 0, - EncodingType: g.cfg.encodingType, - ShortChanIDs: nil, + ChainHash: query.ChainHash, + FirstBlockHeight: query.FirstBlockHeight, + NumBlocks: query.NumBlocks, + Complete: 0, + EncodingType: g.cfg.encodingType, + ShortChanIDs: nil, }) } @@ -1040,14 +1044,12 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro } return g.cfg.sendToPeerSync(&lnwire.ReplyChannelRange{ - QueryChannelRange: lnwire.QueryChannelRange{ - ChainHash: query.ChainHash, - NumBlocks: numBlocks, - FirstBlockHeight: firstHeight, - }, - Complete: complete, - EncodingType: g.cfg.encodingType, - ShortChanIDs: channelChunk, + ChainHash: query.ChainHash, + NumBlocks: numBlocks, + FirstBlockHeight: firstHeight, + Complete: complete, + EncodingType: g.cfg.encodingType, + ShortChanIDs: channelChunk, }) } diff --git a/discovery/syncer_test.go b/discovery/syncer_test.go index c3ad04f5..40e759f9 100644 --- a/discovery/syncer_test.go +++ b/discovery/syncer_test.go @@ -609,10 +609,9 @@ func TestGossipSyncerQueryChannelRangeWrongChainHash(t *testing.T) { t.Fatalf("expected lnwire.ReplyChannelRange, got %T", msg) } - if msg.QueryChannelRange != *query { - t.Fatalf("wrong query channel range in reply: "+ - "expected: %v\ngot: %v", spew.Sdump(*query), - spew.Sdump(msg.QueryChannelRange)) + if msg.ChainHash != query.ChainHash { + t.Fatalf("wrong chain hash: expected %v got %v", + query.ChainHash, msg.ChainHash) } if msg.Complete != 0 { t.Fatalf("expected complete set to 0, got %v", @@ -1227,34 +1226,13 @@ func testGossipSyncerProcessChanRangeReply(t *testing.T, legacy bool) { t.Fatalf("unable to generate channel range query: %v", err) } - var replyQueries []*lnwire.QueryChannelRange - if legacy { - // Each reply query is the same as the original query in the - // legacy mode. - replyQueries = []*lnwire.QueryChannelRange{query, query, query} - } else { - // When interpreting block ranges, the first reply should start - // from our requested first block, and the last should end at - // our requested last block. - replyQueries = []*lnwire.QueryChannelRange{ - { - FirstBlockHeight: 0, - NumBlocks: 11, - }, - { - FirstBlockHeight: 11, - NumBlocks: 1, - }, - { - FirstBlockHeight: 12, - NumBlocks: query.NumBlocks - 12, - }, - } - } - + // When interpreting block ranges, the first reply should start from + // our requested first block, and the last should end at our requested + // last block. replies := []*lnwire.ReplyChannelRange{ { - QueryChannelRange: *replyQueries[0], + FirstBlockHeight: 0, + NumBlocks: 11, ShortChanIDs: []lnwire.ShortChannelID{ { BlockHeight: 10, @@ -1262,7 +1240,8 @@ func testGossipSyncerProcessChanRangeReply(t *testing.T, legacy bool) { }, }, { - QueryChannelRange: *replyQueries[1], + FirstBlockHeight: 11, + NumBlocks: 1, ShortChanIDs: []lnwire.ShortChannelID{ { BlockHeight: 11, @@ -1270,8 +1249,9 @@ func testGossipSyncerProcessChanRangeReply(t *testing.T, legacy bool) { }, }, { - QueryChannelRange: *replyQueries[2], - Complete: 1, + FirstBlockHeight: 12, + NumBlocks: query.NumBlocks - 12, + Complete: 1, ShortChanIDs: []lnwire.ShortChannelID{ { BlockHeight: 12, @@ -1280,6 +1260,19 @@ func testGossipSyncerProcessChanRangeReply(t *testing.T, legacy bool) { }, } + // Each reply query is the same as the original query in the legacy + // mode. + if legacy { + replies[0].FirstBlockHeight = query.FirstBlockHeight + replies[0].NumBlocks = query.NumBlocks + + replies[1].FirstBlockHeight = query.FirstBlockHeight + replies[1].NumBlocks = query.NumBlocks + + replies[2].FirstBlockHeight = query.FirstBlockHeight + replies[2].NumBlocks = query.NumBlocks + } + // We'll begin by sending the syncer a set of non-complete channel // range replies. if err := syncer.processChanRangeReply(replies[0]); err != nil { @@ -2377,7 +2370,9 @@ func TestGossipSyncerMaxChannelRangeReplies(t *testing.T) { // order to transition the syncer's state. for i := uint32(0); i < syncer.cfg.maxQueryChanRangeReplies; i++ { reply := &lnwire.ReplyChannelRange{ - QueryChannelRange: *query, + ChainHash: query.ChainHash, + FirstBlockHeight: query.FirstBlockHeight, + NumBlocks: query.NumBlocks, ShortChanIDs: []lnwire.ShortChannelID{ { BlockHeight: query.FirstBlockHeight + i, @@ -2408,7 +2403,9 @@ func TestGossipSyncerMaxChannelRangeReplies(t *testing.T) { // Finally, attempting to process another reply for the same query // should result in an error. require.Error(t, syncer.ProcessQueryMsg(&lnwire.ReplyChannelRange{ - QueryChannelRange: *query, + ChainHash: query.ChainHash, + FirstBlockHeight: query.FirstBlockHeight, + NumBlocks: query.NumBlocks, ShortChanIDs: []lnwire.ShortChannelID{ { BlockHeight: query.LastBlockHeight() + 1, diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 02023b02..ea90f3c0 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -810,10 +810,8 @@ func TestLightningWireProtocol(t *testing.T) { }, MsgReplyChannelRange: func(v []reflect.Value, r *rand.Rand) { req := ReplyChannelRange{ - QueryChannelRange: QueryChannelRange{ - FirstBlockHeight: uint32(r.Int31()), - NumBlocks: uint32(r.Int31()), - }, + FirstBlockHeight: uint32(r.Int31()), + NumBlocks: uint32(r.Int31()), } if _, err := rand.Read(req.ChainHash[:]); err != nil { diff --git a/lnwire/reply_channel_range.go b/lnwire/reply_channel_range.go index 43060602..3ff5dd4b 100644 --- a/lnwire/reply_channel_range.go +++ b/lnwire/reply_channel_range.go @@ -1,14 +1,29 @@ package lnwire -import "io" +import ( + "io" + "math" + + "github.com/btcsuite/btcd/chaincfg/chainhash" +) // ReplyChannelRange is the response to the QueryChannelRange message. It // includes the original query, and the next streaming chunk of encoded short // channel ID's as the response. We'll also include a byte that indicates if // this is the last query in the message. type ReplyChannelRange struct { - // QueryChannelRange is the corresponding query to this response. - QueryChannelRange + // ChainHash denotes the target chain that we're trying to synchronize + // channel graph state for. + ChainHash chainhash.Hash + + // FirstBlockHeight is the first block in the query range. The + // responder should send all new short channel IDs from this block + // until this block plus the specified number of blocks. + FirstBlockHeight uint32 + + // NumBlocks is the number of blocks beyond the first block that short + // channel ID's should be sent for. + NumBlocks uint32 // Complete denotes if this is the conclusion of the set of streaming // responses to the original query. @@ -43,17 +58,21 @@ var _ Message = (*ReplyChannelRange)(nil) // // This is part of the lnwire.Message interface. func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error { - err := c.QueryChannelRange.Decode(r, pver) + err := ReadElements(r, + c.ChainHash[:], + &c.FirstBlockHeight, + &c.NumBlocks, + &c.Complete, + ) if err != nil { return err } - if err := ReadElements(r, &c.Complete); err != nil { + c.EncodingType, c.ShortChanIDs, err = decodeShortChanIDs(r) + if err != nil { return err } - c.EncodingType, c.ShortChanIDs, err = decodeShortChanIDs(r) - return err } @@ -62,15 +81,22 @@ func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error { // // This is part of the lnwire.Message interface. func (c *ReplyChannelRange) Encode(w io.Writer, pver uint32) error { - if err := c.QueryChannelRange.Encode(w, pver); err != nil { + err := WriteElements(w, + c.ChainHash[:], + c.FirstBlockHeight, + c.NumBlocks, + c.Complete, + ) + if err != nil { return err } - if err := WriteElements(w, c.Complete); err != nil { + err = encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, c.noSort) + if err != nil { return err } - return encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, c.noSort) + return nil } // MsgType returns the integer uniquely identifying this message type on the @@ -88,3 +114,14 @@ func (c *ReplyChannelRange) MsgType() MessageType { func (c *ReplyChannelRange) MaxPayloadLength(uint32) uint32 { return MaxMessagePayload } + +// LastBlockHeight returns the last block height covered by the range of a +// QueryChannelRange message. +func (c *ReplyChannelRange) LastBlockHeight() uint32 { + // Handle overflows by casting to uint64. + lastBlockHeight := uint64(c.FirstBlockHeight) + uint64(c.NumBlocks) - 1 + if lastBlockHeight > math.MaxUint32 { + return math.MaxUint32 + } + return uint32(lastBlockHeight) +} diff --git a/lnwire/reply_channel_range_test.go b/lnwire/reply_channel_range_test.go index d2c8df68..d656db55 100644 --- a/lnwire/reply_channel_range_test.go +++ b/lnwire/reply_channel_range_test.go @@ -30,7 +30,7 @@ func TestReplyChannelRangeUnsorted(t *testing.T) { var req2 ReplyChannelRange err = req2.Decode(bytes.NewReader(b.Bytes()), 0) if _, ok := err.(ErrUnsortedSIDs); !ok { - t.Fatalf("expected ErrUnsortedSIDs, got: %T", + t.Fatalf("expected ErrUnsortedSIDs, got: %v", err) } }) @@ -67,13 +67,11 @@ func TestReplyChannelRangeEmpty(t *testing.T) { test := test t.Run(test.name, func(t *testing.T) { req := ReplyChannelRange{ - QueryChannelRange: QueryChannelRange{ - FirstBlockHeight: 1, - NumBlocks: 2, - }, - Complete: 1, - EncodingType: test.encType, - ShortChanIDs: nil, + FirstBlockHeight: 1, + NumBlocks: 2, + Complete: 1, + EncodingType: test.encType, + ShortChanIDs: nil, } // First decode the hex string in the test case into a