diff --git a/discovery/syncer.go b/discovery/syncer.go index 8348abc5..a154f990 100644 --- a/discovery/syncer.go +++ b/discovery/syncer.go @@ -894,6 +894,23 @@ func (g *GossipSyncer) replyPeerQueries(msg lnwire.Message) error { // ensure that our final fragment carries the "complete" bit to indicate the // end of our streaming response. func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) error { + // Before responding, we'll check to ensure that the remote peer is + // querying for the same chain that we're on. If not, we'll send back a + // response with a complete value of zero to indicate we're on a + // different chain. + if g.cfg.chainHash != query.ChainHash { + log.Warnf("Remote peer requested QueryChannelRange for "+ + "chain=%v, we're on chain=%v", query.ChainHash, + g.cfg.chainHash) + + return g.cfg.sendToPeerSync(&lnwire.ReplyChannelRange{ + QueryChannelRange: *query, + Complete: 0, + EncodingType: g.cfg.encodingType, + ShortChanIDs: nil, + }) + } + log.Infof("GossipSyncer(%x): filtering chan range: start_height=%v, "+ "num_blocks=%v", g.cfg.peerPub[:], query.FirstBlockHeight, query.NumBlocks) diff --git a/discovery/syncer_test.go b/discovery/syncer_test.go index 1b1f8890..606fc062 100644 --- a/discovery/syncer_test.go +++ b/discovery/syncer_test.go @@ -533,6 +533,61 @@ func TestGossipSyncerApplyGossipFilter(t *testing.T) { } } +// TestGossipSyncerQueryChannelRangeWrongChainHash tests that if we receive a +// channel range query for the wrong chain, then we send back a response with no +// channels and complete=0. +func TestGossipSyncerQueryChannelRangeWrongChainHash(t *testing.T) { + t.Parallel() + + // First, we'll create a GossipSyncer instance with a canned sendToPeer + // message to allow us to intercept their potential sends. + msgChan, syncer, _ := newTestSyncer( + lnwire.NewShortChanIDFromInt(10), defaultEncoding, + defaultChunkSize, + ) + + // We'll now ask the syncer to reply to a channel range query, but for a + // chain that it isn't aware of. + query := &lnwire.QueryChannelRange{ + ChainHash: *chaincfg.SimNetParams.GenesisHash, + FirstBlockHeight: 0, + NumBlocks: math.MaxUint32, + } + err := syncer.replyChanRangeQuery(query) + if err != nil { + t.Fatalf("unable to process short chan ID's: %v", err) + } + + select { + case <-time.After(time.Second * 15): + t.Fatalf("no msgs received") + + case msgs := <-msgChan: + // We should get back exactly one message, that's a + // ReplyChannelRange with a matching query, and a complete value + // of zero. + if len(msgs) != 1 { + t.Fatalf("wrong messages: expected %v, got %v", + 1, len(msgs)) + } + + msg, ok := msgs[0].(*lnwire.ReplyChannelRange) + if !ok { + t.Fatalf("expected lnwire.ReplyChannelRange, got %T", msg) + } + + if msg.QueryChannelRange != *query { + t.Fatalf("wrong query channel range in reply: "+ + "expected: %v\ngot: %v", spew.Sdump(*query), + spew.Sdump(msg.QueryChannelRange)) + } + if msg.Complete != 0 { + t.Fatalf("expected complete set to 0, got %v", + msg.Complete) + } + } +} + // TestGossipSyncerReplyShortChanIDsWrongChainHash tests that if we get a chan // ID query for the wrong chain, then we send back only a short ID end with // complete=0.