Merge pull request #1824 from cfromknecht/gossip-delay-reply

discovery/syncer: delay replies after initial sync to prevent DOS
This commit is contained in:
Olaoluwa Osuntokun 2018-09-09 15:28:22 -07:00 committed by GitHub
commit 1941353fb2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 477 additions and 23 deletions

@ -3,6 +3,7 @@ package discovery
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"runtime" "runtime"
"sync" "sync"
@ -14,7 +15,6 @@ import (
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/coreos/bbolt" "github.com/coreos/bbolt"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lnpeer"
@ -1522,7 +1522,7 @@ func (d *AuthenticatedGossiper) processRejectedEdge(
} }
err = ValidateChannelAnn(chanAnn) err = ValidateChannelAnn(chanAnn)
if err != nil { if err != nil {
err := errors.Errorf("assembled channel announcement proof "+ err := fmt.Errorf("assembled channel announcement proof "+
"for shortChanID=%v isn't valid: %v", "for shortChanID=%v isn't valid: %v",
chanAnnMsg.ShortChannelID, err) chanAnnMsg.ShortChannelID, err)
log.Error(err) log.Error(err)
@ -1533,7 +1533,7 @@ func (d *AuthenticatedGossiper) processRejectedEdge(
// to the database. // to the database.
err = d.cfg.Router.AddProof(chanAnnMsg.ShortChannelID, proof) err = d.cfg.Router.AddProof(chanAnnMsg.ShortChannelID, proof)
if err != nil { if err != nil {
err := errors.Errorf("unable add proof to shortChanID=%v: %v", err := fmt.Errorf("unable add proof to shortChanID=%v: %v",
chanAnnMsg.ShortChannelID, err) chanAnnMsg.ShortChannelID, err)
log.Error(err) log.Error(err)
return nil, err return nil, err
@ -1599,7 +1599,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
} }
if err := ValidateNodeAnn(msg); err != nil { if err := ValidateNodeAnn(msg); err != nil {
err := errors.Errorf("unable to validate "+ err := fmt.Errorf("unable to validate "+
"node announcement: %v", err) "node announcement: %v", err)
log.Error(err) log.Error(err)
nMsg.err <- err nMsg.err <- err
@ -1702,7 +1702,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
var proof *channeldb.ChannelAuthProof var proof *channeldb.ChannelAuthProof
if nMsg.isRemote { if nMsg.isRemote {
if err := ValidateChannelAnn(msg); err != nil { if err := ValidateChannelAnn(msg); err != nil {
err := errors.Errorf("unable to validate "+ err := fmt.Errorf("unable to validate "+
"announcement: %v", err) "announcement: %v", err)
d.rejectMtx.Lock() d.rejectMtx.Lock()
d.recentRejects[msg.ShortChannelID.ToUint64()] = struct{}{} d.recentRejects[msg.ShortChannelID.ToUint64()] = struct{}{}
@ -1966,7 +1966,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
return nil return nil
default: default:
err := errors.Errorf("unable to validate "+ err := fmt.Errorf("unable to validate "+
"channel update short_chan_id=%v: %v", "channel update short_chan_id=%v: %v",
shortChanID, err) shortChanID, err)
log.Error(err) log.Error(err)
@ -1994,7 +1994,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
// key, In the case of an invalid channel , we'll return an // key, In the case of an invalid channel , we'll return an
// error to the caller and exit early. // error to the caller and exit early.
if err := ValidateChannelUpdateAnn(pubKey, msg); err != nil { if err := ValidateChannelUpdateAnn(pubKey, msg); err != nil {
rErr := errors.Errorf("unable to validate channel "+ rErr := fmt.Errorf("unable to validate channel "+
"update announcement for short_chan_id=%v: %v", "update announcement for short_chan_id=%v: %v",
spew.Sdump(msg.ShortChannelID), err) spew.Sdump(msg.ShortChannelID), err)
@ -2130,7 +2130,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
// node might rewrite the waiting proof. // node might rewrite the waiting proof.
proof := channeldb.NewWaitingProof(nMsg.isRemote, msg) proof := channeldb.NewWaitingProof(nMsg.isRemote, msg)
if err := d.waitingProofs.Add(proof); err != nil { if err := d.waitingProofs.Add(proof); err != nil {
err := errors.Errorf("unable to store "+ err := fmt.Errorf("unable to store "+
"the proof for short_chan_id=%v: %v", "the proof for short_chan_id=%v: %v",
shortChanID, err) shortChanID, err)
log.Error(err) log.Error(err)
@ -2152,7 +2152,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
// Ensure that channel that was retrieved belongs to the peer // Ensure that channel that was retrieved belongs to the peer
// which sent the proof announcement. // which sent the proof announcement.
if !(isFirstNode || isSecondNode) { if !(isFirstNode || isSecondNode) {
err := errors.Errorf("channel that was received not "+ err := fmt.Errorf("channel that was received not "+
"belongs to the peer which sent the proof, "+ "belongs to the peer which sent the proof, "+
"short_chan_id=%v", shortChanID) "short_chan_id=%v", shortChanID)
log.Error(err) log.Error(err)
@ -2176,7 +2176,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
// deliver the proof when it comes online. // deliver the proof when it comes online.
err := d.sendAnnSigReliably(msg, remotePeer) err := d.sendAnnSigReliably(msg, remotePeer)
if err != nil { if err != nil {
err := errors.Errorf("unable to send reliably "+ err := fmt.Errorf("unable to send reliably "+
"to remote for short_chan_id=%v: %v", "to remote for short_chan_id=%v: %v",
shortChanID, err) shortChanID, err)
log.Error(err) log.Error(err)
@ -2245,7 +2245,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
proof := channeldb.NewWaitingProof(nMsg.isRemote, msg) proof := channeldb.NewWaitingProof(nMsg.isRemote, msg)
oppositeProof, err := d.waitingProofs.Get(proof.OppositeKey()) oppositeProof, err := d.waitingProofs.Get(proof.OppositeKey())
if err != nil && err != channeldb.ErrWaitingProofNotFound { if err != nil && err != channeldb.ErrWaitingProofNotFound {
err := errors.Errorf("unable to get "+ err := fmt.Errorf("unable to get "+
"the opposite proof for short_chan_id=%v: %v", "the opposite proof for short_chan_id=%v: %v",
shortChanID, err) shortChanID, err)
log.Error(err) log.Error(err)
@ -2255,7 +2255,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
if err == channeldb.ErrWaitingProofNotFound { if err == channeldb.ErrWaitingProofNotFound {
if err := d.waitingProofs.Add(proof); err != nil { if err := d.waitingProofs.Add(proof); err != nil {
err := errors.Errorf("unable to store "+ err := fmt.Errorf("unable to store "+
"the proof for short_chan_id=%v: %v", "the proof for short_chan_id=%v: %v",
shortChanID, err) shortChanID, err)
log.Error(err) log.Error(err)
@ -2298,7 +2298,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
// With all the necessary components assembled validate the // With all the necessary components assembled validate the
// full channel announcement proof. // full channel announcement proof.
if err := ValidateChannelAnn(chanAnn); err != nil { if err := ValidateChannelAnn(chanAnn); err != nil {
err := errors.Errorf("channel announcement proof "+ err := fmt.Errorf("channel announcement proof "+
"for short_chan_id=%v isn't valid: %v", "for short_chan_id=%v isn't valid: %v",
shortChanID, err) shortChanID, err)
@ -2316,7 +2316,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
// can announce it on peer connect. // can announce it on peer connect.
err = d.cfg.Router.AddProof(msg.ShortChannelID, &dbProof) err = d.cfg.Router.AddProof(msg.ShortChannelID, &dbProof)
if err != nil { if err != nil {
err := errors.Errorf("unable add proof to the "+ err := fmt.Errorf("unable add proof to the "+
"channel chanID=%v: %v", msg.ChannelID, err) "channel chanID=%v: %v", msg.ChannelID, err)
log.Error(err) log.Error(err)
nMsg.err <- err nMsg.err <- err
@ -2325,7 +2325,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
err = d.waitingProofs.Remove(proof.OppositeKey()) err = d.waitingProofs.Remove(proof.OppositeKey())
if err != nil { if err != nil {
err := errors.Errorf("unable remove opposite proof "+ err := fmt.Errorf("unable remove opposite proof "+
"for the channel with chanID=%v: %v", "for the channel with chanID=%v: %v",
msg.ChannelID, err) msg.ChannelID, err)
log.Error(err) log.Error(err)

@ -1,6 +1,7 @@
package discovery package discovery
import ( import (
"errors"
"fmt" "fmt"
"math" "math"
"sync" "sync"
@ -52,6 +53,17 @@ const (
chansSynced chansSynced
) )
const (
// DefaultMaxUndelayedQueryReplies specifies how many gossip queries we
// will respond to immediately before starting to delay responses.
DefaultMaxUndelayedQueryReplies = 5
// DefaultDelayedQueryReplyInterval is the length of time we will wait
// before responding to gossip queries after replying to
// maxUndelayedQueryReplies queries.
DefaultDelayedQueryReplyInterval = 30 * time.Second
)
// String returns a human readable string describing the target syncerState. // String returns a human readable string describing the target syncerState.
func (s syncerState) String() string { func (s syncerState) String() string {
switch s { switch s {
@ -82,6 +94,9 @@ var (
encodingTypeToChunkSize = map[lnwire.ShortChanIDEncoding]int32{ encodingTypeToChunkSize = map[lnwire.ShortChanIDEncoding]int32{
lnwire.EncodingSortedPlain: 8000, lnwire.EncodingSortedPlain: 8000,
} }
// ErrGossipSyncerExiting signals that the syncer has been killed.
ErrGossipSyncerExiting = errors.New("gossip syncer exiting")
) )
const ( const (
@ -167,6 +182,15 @@ type gossipSyncerCfg struct {
// targeted messages to the peer we've been assigned to sync the graph // targeted messages to the peer we've been assigned to sync the graph
// state from. // state from.
sendToPeer func(...lnwire.Message) error sendToPeer func(...lnwire.Message) error
// maxUndelayedQueryReplies specifies how many gossip queries we will
// respond to immediately before starting to delay responses.
maxUndelayedQueryReplies int
// delayedQueryReplyInterval is the length of time we will wait before
// responding to gossip queries after replying to
// maxUndelayedQueryReplies queries.
delayedQueryReplyInterval time.Duration
} }
// gossipSyncer is a struct that handles synchronizing the channel graph state // gossipSyncer is a struct that handles synchronizing the channel graph state
@ -214,6 +238,11 @@ type gossipSyncer struct {
cfg gossipSyncerCfg cfg gossipSyncerCfg
// replyCount records how many query replies we've responded to. This is
// used to determine when to start delaying responses to peers to
// prevent DOS vulnerabilities.
replyCount int
sync.Mutex sync.Mutex
quit chan struct{} quit chan struct{}
@ -223,6 +252,18 @@ type gossipSyncer struct {
// newGossiperSyncer returns a new instance of the gossipSyncer populated using // newGossiperSyncer returns a new instance of the gossipSyncer populated using
// the passed config. // the passed config.
func newGossiperSyncer(cfg gossipSyncerCfg) *gossipSyncer { func newGossiperSyncer(cfg gossipSyncerCfg) *gossipSyncer {
// If no parameter was specified for max undelayed query replies, set it
// to the default of 5 queries.
if cfg.maxUndelayedQueryReplies <= 0 {
cfg.maxUndelayedQueryReplies = DefaultMaxUndelayedQueryReplies
}
// If no parameter was specified for delayed query reply interval, set
// to the default of 30 seconds.
if cfg.delayedQueryReplyInterval <= 0 {
cfg.delayedQueryReplyInterval = DefaultDelayedQueryReplyInterval
}
return &gossipSyncer{ return &gossipSyncer{
cfg: cfg, cfg: cfg,
gossipMsgs: make(chan lnwire.Message, 100), gossipMsgs: make(chan lnwire.Message, 100),
@ -332,7 +373,7 @@ func (g *gossipSyncer) channelGraphSyncer() {
// Otherwise, it's the remote peer performing a // Otherwise, it's the remote peer performing a
// query, which we'll attempt to reply to. // query, which we'll attempt to reply to.
err := g.replyPeerQueries(msg) err := g.replyPeerQueries(msg)
if err != nil { if err != nil && err != ErrGossipSyncerExiting {
log.Errorf("unable to reply to peer "+ log.Errorf("unable to reply to peer "+
"query: %v", err) "query: %v", err)
} }
@ -386,7 +427,7 @@ func (g *gossipSyncer) channelGraphSyncer() {
// Otherwise, it's the remote peer performing a // Otherwise, it's the remote peer performing a
// query, which we'll attempt to deploy to. // query, which we'll attempt to deploy to.
err := g.replyPeerQueries(msg) err := g.replyPeerQueries(msg)
if err != nil { if err != nil && err != ErrGossipSyncerExiting {
log.Errorf("unable to reply to peer "+ log.Errorf("unable to reply to peer "+
"query: %v", err) "query: %v", err)
} }
@ -430,7 +471,7 @@ func (g *gossipSyncer) channelGraphSyncer() {
select { select {
case msg := <-g.gossipMsgs: case msg := <-g.gossipMsgs:
err := g.replyPeerQueries(msg) err := g.replyPeerQueries(msg)
if err != nil { if err != nil && err != ErrGossipSyncerExiting {
log.Errorf("unable to reply to peer "+ log.Errorf("unable to reply to peer "+
"query: %v", err) "query: %v", err)
} }
@ -588,6 +629,24 @@ func (g *gossipSyncer) genChanRangeQuery() (*lnwire.QueryChannelRange, error) {
// replyPeerQueries is called in response to any query by the remote peer. // replyPeerQueries is called in response to any query by the remote peer.
// We'll examine our state and send back our best response. // We'll examine our state and send back our best response.
func (g *gossipSyncer) replyPeerQueries(msg lnwire.Message) error { func (g *gossipSyncer) replyPeerQueries(msg lnwire.Message) error {
// If we've already replied a handful of times, we will start to delay
// responses back to the remote peer. This can help prevent DOS attacks
// where the remote peer spams us endlessly.
switch {
case g.replyCount == g.cfg.maxUndelayedQueryReplies:
log.Infof("gossipSyncer(%x): entering delayed gossip replies",
g.peerPub[:])
fallthrough
case g.replyCount > g.cfg.maxUndelayedQueryReplies:
select {
case <-time.After(g.cfg.delayedQueryReplyInterval):
case <-g.quit:
return ErrGossipSyncerExiting
}
}
g.replyCount++
switch msg := msg.(type) { switch msg := msg.(type) {
// In this state, we'll also handle any incoming channel range queries // In this state, we'll also handle any incoming channel range queries

@ -1,7 +1,6 @@
package discovery package discovery
import ( import (
"fmt"
"math" "math"
"reflect" "reflect"
"testing" "testing"
@ -49,7 +48,9 @@ type mockChannelGraphTimeSeries struct {
updateResp chan []*lnwire.ChannelUpdate updateResp chan []*lnwire.ChannelUpdate
} }
func newMockChannelGraphTimeSeries(hID lnwire.ShortChannelID) *mockChannelGraphTimeSeries { func newMockChannelGraphTimeSeries(
hID lnwire.ShortChannelID) *mockChannelGraphTimeSeries {
return &mockChannelGraphTimeSeries{ return &mockChannelGraphTimeSeries{
highestID: hID, highestID: hID,
@ -127,6 +128,7 @@ func newTestSyncer(hID lnwire.ShortChannelID,
msgChan <- msgs msgChan <- msgs
return nil return nil
}, },
delayedQueryReplyInterval: 2 * time.Second,
} }
syncer := newGossiperSyncer(cfg) syncer := newGossiperSyncer(cfg)
@ -810,9 +812,6 @@ func TestGossipSyncerProcessChanRangeReply(t *testing.T) {
// We should get a request for the entire range of short // We should get a request for the entire range of short
// chan ID's. // chan ID's.
if !reflect.DeepEqual(expectedReq, req) { if !reflect.DeepEqual(expectedReq, req) {
fmt.Printf("wrong request: expected %v, got %v\n",
expectedReq, req)
t.Fatalf("wrong request: expected %v, got %v", t.Fatalf("wrong request: expected %v, got %v",
expectedReq, req) expectedReq, req)
} }
@ -983,6 +982,402 @@ func TestGossipSyncerSynchronizeChanIDs(t *testing.T) {
} }
} }
// TestGossipSyncerDelayDOS tests that the gossip syncer will begin delaying
// queries after its prescribed allotment of undelayed query responses. Once
// this happens, all query replies should be delayed by the configurated
// interval.
func TestGossipSyncerDelayDOS(t *testing.T) {
t.Parallel()
// We'll modify the chunk size to be a smaller value, since we'll be
// sending a modest number of queries. After exhausting our undelayed
// gossip queries, we'll send two extra queries and ensure that they are
// delayed properly.
const chunkSize = 2
const numDelayedQueries = 2
const delayTolerance = time.Millisecond * 200
// First, we'll create two gossipSyncer instances with a canned
// sendToPeer message to allow us to intercept their potential sends.
startHeight := lnwire.ShortChannelID{
BlockHeight: 1144,
}
msgChan1, syncer1, chanSeries1 := newTestSyncer(
startHeight, defaultEncoding, chunkSize,
)
syncer1.Start()
defer syncer1.Stop()
msgChan2, syncer2, chanSeries2 := newTestSyncer(
startHeight, defaultEncoding, chunkSize,
)
syncer2.Start()
defer syncer2.Stop()
// Record the delayed query reply interval used by each syncer.
delayedQueryInterval := syncer1.cfg.delayedQueryReplyInterval
// Record the number of undelayed queries allowed by the syncers.
numUndelayedQueries := syncer1.cfg.maxUndelayedQueryReplies
// We will send enough queries to exhaust the undelayed responses, and
// then send two more queries which should be delayed.
numQueryResponses := numUndelayedQueries + numDelayedQueries
// The total number of responses must include the initial reply each
// syner will make to QueryChannelRange.
numTotalQueries := 1 + numQueryResponses
// The total number of channels each syncer needs to request must be
// scaled by the chunk size being used.
numTotalChans := numQueryResponses * chunkSize
// Although both nodes are at the same height, they'll have a
// completely disjoint set of chan ID's that they know of.
var syncer1Chans []lnwire.ShortChannelID
for i := 0; i < numTotalChans; i++ {
syncer1Chans = append(
syncer1Chans, lnwire.NewShortChanIDFromInt(uint64(i)),
)
}
var syncer2Chans []lnwire.ShortChannelID
for i := numTotalChans; i < numTotalChans+numTotalChans; i++ {
syncer2Chans = append(
syncer2Chans, lnwire.NewShortChanIDFromInt(uint64(i)),
)
}
// We'll kick off the test by passing over the QueryChannelRange
// messages from one node to the other.
select {
case <-time.After(time.Second * 2):
t.Fatalf("didn't get msg from syncer1")
case msgs := <-msgChan1:
for _, msg := range msgs {
// The message MUST be a QueryChannelRange message.
_, ok := msg.(*lnwire.QueryChannelRange)
if !ok {
t.Fatalf("wrong message: expected "+
"QueryChannelRange for %T", msg)
}
select {
case <-time.After(time.Second * 2):
t.Fatalf("node 2 didn't read msg")
case syncer2.gossipMsgs <- msg:
}
}
}
select {
case <-time.After(time.Second * 2):
t.Fatalf("didn't get msg from syncer2")
case msgs := <-msgChan2:
for _, msg := range msgs {
// The message MUST be a QueryChannelRange message.
_, ok := msg.(*lnwire.QueryChannelRange)
if !ok {
t.Fatalf("wrong message: expected "+
"QueryChannelRange for %T", msg)
}
select {
case <-time.After(time.Second * 2):
t.Fatalf("node 2 didn't read msg")
case syncer1.gossipMsgs <- msg:
}
}
}
// At this point, we'll need to send responses to both nodes from their
// respective channel series. Both nodes will simply request the entire
// set of channels from the other. This will count as the first
// undelayed response for each syncer.
select {
case <-time.After(time.Second * 2):
t.Fatalf("no query recvd")
case <-chanSeries1.filterRangeReqs:
// We'll send all the channels that it should know of.
chanSeries1.filterRangeResp <- syncer1Chans
}
select {
case <-time.After(time.Second * 2):
t.Fatalf("no query recvd")
case <-chanSeries2.filterRangeReqs:
// We'll send back all the channels that it should know of.
chanSeries2.filterRangeResp <- syncer2Chans
}
// At this point, we'll forward the ReplyChannelRange messages to both
// parties. After receiving the set of channels known to the remote peer
for i := 0; i < numQueryResponses; i++ {
select {
case <-time.After(time.Second * 2):
t.Fatalf("didn't get msg from syncer1")
case msgs := <-msgChan1:
for _, msg := range msgs {
// The message MUST be a ReplyChannelRange message.
_, ok := msg.(*lnwire.ReplyChannelRange)
if !ok {
t.Fatalf("wrong message: expected "+
"QueryChannelRange for %T", msg)
}
select {
case <-time.After(time.Second * 2):
t.Fatalf("node 2 didn't read msg")
case syncer2.gossipMsgs <- msg:
}
}
}
select {
case <-time.After(time.Second * 2):
t.Fatalf("didn't get msg from syncer2")
case msgs := <-msgChan2:
for _, msg := range msgs {
// The message MUST be a ReplyChannelRange message.
_, ok := msg.(*lnwire.ReplyChannelRange)
if !ok {
t.Fatalf("wrong message: expected "+
"QueryChannelRange for %T", msg)
}
select {
case <-time.After(time.Second * 2):
t.Fatalf("node 2 didn't read msg")
case syncer1.gossipMsgs <- msg:
}
}
}
}
// We'll now send back a chunked response for both parties of the known
// short chan ID's.
select {
case <-time.After(time.Second * 2):
t.Fatalf("no query recvd")
case <-chanSeries1.filterReq:
chanSeries1.filterResp <- syncer2Chans
}
select {
case <-time.After(time.Second * 2):
t.Fatalf("no query recvd")
case <-chanSeries2.filterReq:
chanSeries2.filterResp <- syncer1Chans
}
// At this point, both parties should start to send out initial
// requests to query the chan IDs of the remote party. We'll keep track
// of the number of queries made using the iterated value, which starts
// at one due the initial contribution of the QueryChannelRange msgs.
for i := 1; i < numTotalQueries; i++ {
// Both parties should now have sent out the initial requests
// to query the chan IDs of the other party.
select {
case <-time.After(time.Second * 2):
t.Fatalf("didn't get msg from syncer1")
case msgs := <-msgChan1:
for _, msg := range msgs {
// The message MUST be a QueryShortChanIDs message.
_, ok := msg.(*lnwire.QueryShortChanIDs)
if !ok {
t.Fatalf("wrong message: expected "+
"QueryShortChanIDs for %T", msg)
}
select {
case <-time.After(time.Second * 2):
t.Fatalf("node 2 didn't read msg")
case syncer2.gossipMsgs <- msg:
}
}
}
select {
case <-time.After(time.Second * 2):
t.Fatalf("didn't get msg from syncer2")
case msgs := <-msgChan2:
for _, msg := range msgs {
// The message MUST be a QueryShortChanIDs message.
_, ok := msg.(*lnwire.QueryShortChanIDs)
if !ok {
t.Fatalf("wrong message: expected "+
"QueryShortChanIDs for %T", msg)
}
select {
case <-time.After(time.Second * 2):
t.Fatalf("node 2 didn't read msg")
case syncer1.gossipMsgs <- msg:
}
}
}
// We'll then respond to both parties with an empty set of
// replies (as it doesn't affect the test).
switch {
// If this query has surpassed the undelayed query threshold, we
// will impose stricter timing constraints on the response
// times. We'll first test that the peers don't immediately
// receive a query, and then check that both queries haven't
// gone unanswered entirely.
case i >= numUndelayedQueries:
// Create a before and after timeout to test, our test
// will ensure the messages are delivered to the peers
// in this timeframe.
before := time.After(
delayedQueryInterval - delayTolerance,
)
after := time.After(
delayedQueryInterval + delayTolerance,
)
// First, ensure neither peer tries to respond up until
// the before time fires.
select {
case <-before:
// Queries are delayed, proceed.
case <-chanSeries1.annReq:
t.Fatalf("DOSy query was not delayed")
case <-chanSeries2.annReq:
t.Fatalf("DOSy query was not delayed")
}
// Next, we'll need to test that both queries are
// received before the after timer expires. To account
// for ordering, we will try to pull a message from both
// peers, and then test that the opposite peer also
// receives the message promptly.
var (
firstChanSeries *mockChannelGraphTimeSeries
laterChanSeries *mockChannelGraphTimeSeries
)
// If neither peer attempts a response within the
// allowed interval, then the messages are probably
// lost. Otherwise, process the message and record the
// induced ordering.
select {
case <-after:
t.Fatalf("no delayed query received")
case <-chanSeries1.annReq:
chanSeries1.annResp <- []lnwire.Message{}
firstChanSeries = chanSeries1
laterChanSeries = chanSeries2
case <-chanSeries2.annReq:
chanSeries2.annResp <- []lnwire.Message{}
firstChanSeries = chanSeries2
laterChanSeries = chanSeries1
}
// Finally, using the same interval timeout as before,
// ensure the later peer also responds promptly. We also
// assert that the first peer doesn't attempt another
// response.
select {
case <-after:
t.Fatalf("no delayed query received")
case <-firstChanSeries.annReq:
t.Fatalf("spurious undelayed response")
case <-laterChanSeries.annReq:
laterChanSeries.annResp <- []lnwire.Message{}
}
// Otherwise, we still haven't exceeded our undelayed query
// limit. Assert that both peers promptly attempt a response to
// the queries.
default:
select {
case <-time.After(50 * time.Millisecond):
t.Fatalf("no query recvd")
case <-chanSeries1.annReq:
chanSeries1.annResp <- []lnwire.Message{}
}
select {
case <-time.After(50 * time.Millisecond):
t.Fatalf("no query recvd")
case <-chanSeries2.annReq:
chanSeries2.annResp <- []lnwire.Message{}
}
}
// Finally, both sides should then receive a
// ReplyShortChanIDsEnd as the first chunk has been replied to.
select {
case <-time.After(50 * time.Millisecond):
t.Fatalf("didn't get msg from syncer1")
case msgs := <-msgChan1:
for _, msg := range msgs {
// The message MUST be a ReplyShortChanIDsEnd message.
_, ok := msg.(*lnwire.ReplyShortChanIDsEnd)
if !ok {
t.Fatalf("wrong message: expected "+
"QueryChannelRange for %T", msg)
}
select {
case <-time.After(time.Second * 2):
t.Fatalf("node 2 didn't read msg")
case syncer2.gossipMsgs <- msg:
}
}
}
select {
case <-time.After(50 * time.Millisecond):
t.Fatalf("didn't get msg from syncer2")
case msgs := <-msgChan2:
for _, msg := range msgs {
// The message MUST be a ReplyShortChanIDsEnd message.
_, ok := msg.(*lnwire.ReplyShortChanIDsEnd)
if !ok {
t.Fatalf("wrong message: expected "+
"ReplyShortChanIDsEnd for %T", msg)
}
select {
case <-time.After(time.Second * 2):
t.Fatalf("node 2 didn't read msg")
case syncer1.gossipMsgs <- msg:
}
}
}
}
}
// TestGossipSyncerRoutineSync tests all state transitions of the main syncer // TestGossipSyncerRoutineSync tests all state transitions of the main syncer
// goroutine. This ensures that given an encounter with a peer that has a set // goroutine. This ensures that given an encounter with a peer that has a set
// of distinct channels, then we'll properly synchronize our channel state with // of distinct channels, then we'll properly synchronize our channel state with