diff --git a/discovery/gossiper.go b/discovery/gossiper.go
index 5c4cb2a9..87c98be2 100644
--- a/discovery/gossiper.go
+++ b/discovery/gossiper.go
@@ -3,6 +3,7 @@ package discovery
 import (
 	"bytes"
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"runtime"
 	"sync"
@@ -14,7 +15,6 @@ import (
 	"github.com/btcsuite/btcd/wire"
 	"github.com/coreos/bbolt"
 	"github.com/davecgh/go-spew/spew"
-	"github.com/go-errors/errors"
 	"github.com/lightningnetwork/lnd/chainntnfs"
 	"github.com/lightningnetwork/lnd/channeldb"
 	"github.com/lightningnetwork/lnd/lnpeer"
@@ -1522,7 +1522,7 @@ func (d *AuthenticatedGossiper) processRejectedEdge(
 	}
 	err = ValidateChannelAnn(chanAnn)
 	if err != nil {
-		err := errors.Errorf("assembled channel announcement proof "+
+		err := fmt.Errorf("assembled channel announcement proof "+
 			"for shortChanID=%v isn't valid: %v",
 			chanAnnMsg.ShortChannelID, err)
 		log.Error(err)
@@ -1533,7 +1533,7 @@ func (d *AuthenticatedGossiper) processRejectedEdge(
 	// to the database.
 	err = d.cfg.Router.AddProof(chanAnnMsg.ShortChannelID, proof)
 	if err != nil {
-		err := errors.Errorf("unable add proof to shortChanID=%v: %v",
+		err := fmt.Errorf("unable to add proof to shortChanID=%v: %v",
 			chanAnnMsg.ShortChannelID, err)
 		log.Error(err)
 		return nil, err
@@ -1599,7 +1599,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
 	}

 	if err := ValidateNodeAnn(msg); err != nil {
-		err := errors.Errorf("unable to validate "+
+		err := fmt.Errorf("unable to validate "+
 			"node announcement: %v", err)
 		log.Error(err)
 		nMsg.err <- err
@@ -1702,7 +1702,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
 		var proof *channeldb.ChannelAuthProof
 		if nMsg.isRemote {
 			if err := ValidateChannelAnn(msg); err != nil {
-				err := errors.Errorf("unable to validate "+
+				err := fmt.Errorf("unable to validate "+
 					"announcement: %v", err)
 				d.rejectMtx.Lock()
 				d.recentRejects[msg.ShortChannelID.ToUint64()] = struct{}{}
@@ -1966,7 +1966,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
 			return nil

 		default:
-			err := errors.Errorf("unable to validate "+
+			err := fmt.Errorf("unable to validate "+
 				"channel update short_chan_id=%v: %v",
 				shortChanID, err)
 			log.Error(err)
@@ -1994,7 +1994,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
 		// key, In the case of an invalid channel , we'll return an
 		// error to the caller and exit early.
 		if err := ValidateChannelUpdateAnn(pubKey, msg); err != nil {
-			rErr := errors.Errorf("unable to validate channel "+
+			rErr := fmt.Errorf("unable to validate channel "+
 				"update announcement for short_chan_id=%v: %v",
 				spew.Sdump(msg.ShortChannelID), err)

@@ -2130,7 +2130,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
 		// node might rewrite the waiting proof.
 		proof := channeldb.NewWaitingProof(nMsg.isRemote, msg)
 		if err := d.waitingProofs.Add(proof); err != nil {
-			err := errors.Errorf("unable to store "+
+			err := fmt.Errorf("unable to store "+
 				"the proof for short_chan_id=%v: %v",
 				shortChanID, err)
 			log.Error(err)
@@ -2152,7 +2152,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
 		// Ensure that channel that was retrieved belongs to the peer
 		// which sent the proof announcement.
 		if !(isFirstNode || isSecondNode) {
-			err := errors.Errorf("channel that was received not "+
+			err := fmt.Errorf("channel that was received not "+
 				"belongs to the peer which sent the proof, "+
 				"short_chan_id=%v", shortChanID)
 			log.Error(err)
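The gossiper.go changes above are a mechanical migration from github.com/go-errors/errors to the standard library: errors.Errorf, which also captures a stack trace, is replaced by fmt.Errorf, which only formats the message. A minimal standalone sketch of the resulting pattern (the shortChanID value and errProofInvalid sentinel are illustrative, not from the patch):

package main

import (
	"errors"
	"fmt"
)

// errProofInvalid stands in for a validation failure returned by
// ValidateChannelAnn in the real code.
var errProofInvalid = errors.New("invalid announcement signature")

func main() {
	shortChanID := uint64(1144)

	// Previously: errors.Errorf(...) from github.com/go-errors/errors,
	// which allocated a stack trace alongside the message. The patch
	// swaps in the stdlib equivalent:
	err := fmt.Errorf("unable to add proof to shortChanID=%v: %v",
		shortChanID, errProofInvalid)

	fmt.Println(err)
}

Note that %v flattens the cause into text; on Go 1.13+ the %w verb would keep it recoverable via errors.Is, but the patch uses %v throughout.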
@@ -2176,7 +2176,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
 		// deliver the proof when it comes online.
 		err := d.sendAnnSigReliably(msg, remotePeer)
 		if err != nil {
-			err := errors.Errorf("unable to send reliably "+
+			err := fmt.Errorf("unable to send reliably "+
 				"to remote for short_chan_id=%v: %v",
 				shortChanID, err)
 			log.Error(err)
@@ -2245,7 +2245,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
 		proof := channeldb.NewWaitingProof(nMsg.isRemote, msg)
 		oppositeProof, err := d.waitingProofs.Get(proof.OppositeKey())
 		if err != nil && err != channeldb.ErrWaitingProofNotFound {
-			err := errors.Errorf("unable to get "+
+			err := fmt.Errorf("unable to get "+
 				"the opposite proof for short_chan_id=%v: %v",
 				shortChanID, err)
 			log.Error(err)
@@ -2255,7 +2255,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(

 		if err == channeldb.ErrWaitingProofNotFound {
 			if err := d.waitingProofs.Add(proof); err != nil {
-				err := errors.Errorf("unable to store "+
+				err := fmt.Errorf("unable to store "+
 					"the proof for short_chan_id=%v: %v",
 					shortChanID, err)
 				log.Error(err)
@@ -2298,7 +2298,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
 		// With all the necessary components assembled validate the
 		// full channel announcement proof.
 		if err := ValidateChannelAnn(chanAnn); err != nil {
-			err := errors.Errorf("channel announcement proof "+
+			err := fmt.Errorf("channel announcement proof "+
 				"for short_chan_id=%v isn't valid: %v",
 				shortChanID, err)

@@ -2316,7 +2316,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
 		// can announce it on peer connect.
 		err = d.cfg.Router.AddProof(msg.ShortChannelID, &dbProof)
 		if err != nil {
-			err := errors.Errorf("unable add proof to the "+
+			err := fmt.Errorf("unable to add proof to the "+
 				"channel chanID=%v: %v", msg.ChannelID, err)
 			log.Error(err)
 			nMsg.err <- err
@@ -2325,7 +2325,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(

 		err = d.waitingProofs.Remove(proof.OppositeKey())
 		if err != nil {
-			err := errors.Errorf("unable remove opposite proof "+
+			err := fmt.Errorf("unable to remove opposite proof "+
 				"for the channel with chanID=%v: %v",
 				msg.ChannelID, err)
 			log.Error(err)
diff --git a/discovery/syncer.go b/discovery/syncer.go
index 4041b1fd..dd353fde 100644
--- a/discovery/syncer.go
+++ b/discovery/syncer.go
@@ -1,6 +1,7 @@
 package discovery

 import (
+	"errors"
 	"fmt"
 	"math"
 	"sync"
@@ -52,6 +53,17 @@ const (
 	chansSynced
 )

+const (
+	// DefaultMaxUndelayedQueryReplies specifies how many gossip queries we
+	// will respond to immediately before starting to delay responses.
+	DefaultMaxUndelayedQueryReplies = 5
+
+	// DefaultDelayedQueryReplyInterval is the length of time we will wait
+	// before responding to gossip queries after replying to
+	// maxUndelayedQueryReplies queries.
+	DefaultDelayedQueryReplyInterval = 30 * time.Second
+)
+
 // String returns a human readable string describing the target syncerState.
 func (s syncerState) String() string {
 	switch s {
@@ -82,6 +94,9 @@ var (
 	encodingTypeToChunkSize = map[lnwire.ShortChanIDEncoding]int32{
 		lnwire.EncodingSortedPlain: 8000,
 	}
+
+	// ErrGossipSyncerExiting signals that the syncer has been killed.
+	ErrGossipSyncerExiting = errors.New("gossip syncer exiting")
 )

 const (
@@ -167,6 +182,15 @@ type gossipSyncerCfg struct {
 	// targeted messages to the peer we've been assigned to sync the graph
 	// state from.
 	sendToPeer func(...lnwire.Message) error
+
+	// maxUndelayedQueryReplies specifies how many gossip queries we will
+	// respond to immediately before starting to delay responses.
+	maxUndelayedQueryReplies int
+
+	// delayedQueryReplyInterval is the length of time we will wait before
+	// responding to gossip queries after replying to
+	// maxUndelayedQueryReplies queries.
+	delayedQueryReplyInterval time.Duration
 }

 // gossipSyncer is a struct that handles synchronizing the channel graph state
@@ -214,6 +238,11 @@ type gossipSyncer struct {
 	cfg gossipSyncerCfg

+	// replyCount records how many query replies we've responded to. This is
+	// used to determine when to start delaying responses to peers to
+	// prevent DOS vulnerabilities.
+	replyCount int
+
 	sync.Mutex

 	quit chan struct{}

@@ -223,6 +252,18 @@ type gossipSyncer struct {
 // newGossiperSyncer returns a new instance of the gossipSyncer populated using
 // the passed config.
 func newGossiperSyncer(cfg gossipSyncerCfg) *gossipSyncer {
+	// If no parameter was specified for max undelayed query replies, set it
+	// to the default of 5 queries.
+	if cfg.maxUndelayedQueryReplies <= 0 {
+		cfg.maxUndelayedQueryReplies = DefaultMaxUndelayedQueryReplies
+	}
+
+	// If no parameter was specified for delayed query reply interval, set it
+	// to the default of 30 seconds.
+	if cfg.delayedQueryReplyInterval <= 0 {
+		cfg.delayedQueryReplyInterval = DefaultDelayedQueryReplyInterval
+	}
+
 	return &gossipSyncer{
 		cfg:        cfg,
 		gossipMsgs: make(chan lnwire.Message, 100),
@@ -332,7 +373,7 @@ func (g *gossipSyncer) channelGraphSyncer() {
 			// Otherwise, it's the remote peer performing a
 			// query, which we'll attempt to reply to.
 			err := g.replyPeerQueries(msg)
-			if err != nil {
+			if err != nil && err != ErrGossipSyncerExiting {
 				log.Errorf("unable to reply to peer "+
 					"query: %v", err)
 			}
@@ -386,7 +427,7 @@ func (g *gossipSyncer) channelGraphSyncer() {
 			// Otherwise, it's the remote peer performing a
 			// query, which we'll attempt to deploy to.
 			err := g.replyPeerQueries(msg)
-			if err != nil {
+			if err != nil && err != ErrGossipSyncerExiting {
 				log.Errorf("unable to reply to peer "+
 					"query: %v", err)
 			}
@@ -430,7 +471,7 @@ func (g *gossipSyncer) channelGraphSyncer() {
 		select {
 		case msg := <-g.gossipMsgs:
 			err := g.replyPeerQueries(msg)
-			if err != nil {
+			if err != nil && err != ErrGossipSyncerExiting {
 				log.Errorf("unable to reply to peer "+
 					"query: %v", err)
 			}
@@ -588,6 +629,24 @@ func (g *gossipSyncer) genChanRangeQuery() (*lnwire.QueryChannelRange, error) {
 // replyPeerQueries is called in response to any query by the remote peer.
 // We'll examine our state and send back our best response.
 func (g *gossipSyncer) replyPeerQueries(msg lnwire.Message) error {
+	// If we've already replied a handful of times, we will start to delay
+	// responses back to the remote peer. This can help prevent DOS attacks
+	// where the remote peer spams us endlessly.
+	switch {
+	case g.replyCount == g.cfg.maxUndelayedQueryReplies:
+		log.Infof("gossipSyncer(%x): entering delayed gossip replies",
+			g.peerPub[:])
+		fallthrough
+
+	case g.replyCount > g.cfg.maxUndelayedQueryReplies:
+		select {
+		case <-time.After(g.cfg.delayedQueryReplyInterval):
+		case <-g.quit:
+			return ErrGossipSyncerExiting
+		}
+	}
+	g.replyCount++
+
 	switch msg := msg.(type) {

 	// In this state, we'll also handle any incoming channel range queries
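The throttling added to replyPeerQueries above reduces to: serve the first maxUndelayedQueryReplies queries immediately, then block for delayedQueryReplyInterval before each further reply, bailing out with ErrGossipSyncerExiting if the syncer is torn down mid-wait. A self-contained sketch of that pattern with the same default budget but a shortened delay so it runs quickly (an illustration of the scheme, not the lnd type):

package main

import (
	"errors"
	"fmt"
	"time"
)

// ErrGossipSyncerExiting mirrors the sentinel added in syncer.go.
var ErrGossipSyncerExiting = errors.New("gossip syncer exiting")

// replyThrottle captures the delay logic from replyPeerQueries in
// isolation: a budget of undelayed replies, then a fixed wait per reply.
type replyThrottle struct {
	maxUndelayed int
	delay        time.Duration
	replyCount   int
	quit         chan struct{}
}

// wait blocks before the next reply once the undelayed budget is spent,
// aborting early if the throttle's owner is shutting down.
func (r *replyThrottle) wait() error {
	if r.replyCount >= r.maxUndelayed {
		select {
		case <-time.After(r.delay):
		case <-r.quit:
			return ErrGossipSyncerExiting
		}
	}
	r.replyCount++
	return nil
}

func main() {
	throttle := &replyThrottle{
		maxUndelayed: 5,                      // DefaultMaxUndelayedQueryReplies
		delay:        100 * time.Millisecond, // 30s in production
		quit:         make(chan struct{}),
	}

	// The first five replies return immediately; later ones block for
	// the configured delay.
	for i := 1; i <= 7; i++ {
		start := time.Now()
		if err := throttle.wait(); err != nil {
			fmt.Println(err)
			return
		}
		fmt.Printf("reply %d permitted after %v\n", i, time.Since(start))
	}
}

The real code splits the condition into == and > cases joined by fallthrough purely so the "entering delayed gossip replies" log line fires exactly once.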
diff --git a/discovery/syncer_test.go b/discovery/syncer_test.go
index 2ae8435b..01afe7d9 100644
--- a/discovery/syncer_test.go
+++ b/discovery/syncer_test.go
@@ -1,7 +1,6 @@
 package discovery

 import (
-	"fmt"
 	"math"
 	"reflect"
 	"testing"
@@ -49,7 +48,9 @@ type mockChannelGraphTimeSeries struct {
 	updateResp chan []*lnwire.ChannelUpdate
 }

-func newMockChannelGraphTimeSeries(hID lnwire.ShortChannelID) *mockChannelGraphTimeSeries {
+func newMockChannelGraphTimeSeries(
+	hID lnwire.ShortChannelID) *mockChannelGraphTimeSeries {
+
 	return &mockChannelGraphTimeSeries{
 		highestID: hID,
@@ -127,6 +128,7 @@ func newTestSyncer(hID lnwire.ShortChannelID,
 			msgChan <- msgs
 			return nil
 		},
+		delayedQueryReplyInterval: 2 * time.Second,
 	}
 	syncer := newGossiperSyncer(cfg)

@@ -810,9 +812,6 @@ func TestGossipSyncerProcessChanRangeReply(t *testing.T) {
 	// We should get a request for the entire range of short
 	// chan ID's.
 	if !reflect.DeepEqual(expectedReq, req) {
-		fmt.Printf("wrong request: expected %v, got %v\n",
-			expectedReq, req)
-
 		t.Fatalf("wrong request: expected %v, got %v",
 			expectedReq, req)
 	}
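Two details of the harness matter for the new test below: newTestSyncer only overrides delayedQueryReplyInterval, cutting the production default of 30 seconds down to 2 seconds so the test finishes quickly, while maxUndelayedQueryReplies is left at zero and is therefore clamped to DefaultMaxUndelayedQueryReplies (5) by newGossiperSyncer. Combined with the 200ms delayTolerance used in the test, every throttled reply must arrive between 1.8s and 2.2s after the query.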
@@ -983,6 +982,402 @@ func TestGossipSyncerSynchronizeChanIDs(t *testing.T) {
 	}
 }

+// TestGossipSyncerDelayDOS tests that the gossip syncer will begin delaying
+// queries after its prescribed allotment of undelayed query responses. Once
+// this happens, all query replies should be delayed by the configured
+// interval.
+func TestGossipSyncerDelayDOS(t *testing.T) {
+	t.Parallel()
+
+	// We'll modify the chunk size to be a smaller value, since we'll be
+	// sending a modest number of queries. After exhausting our undelayed
+	// gossip queries, we'll send two extra queries and ensure that they are
+	// delayed properly.
+	const chunkSize = 2
+	const numDelayedQueries = 2
+	const delayTolerance = time.Millisecond * 200
+
+	// First, we'll create two gossipSyncer instances with a canned
+	// sendToPeer message to allow us to intercept their potential sends.
+	startHeight := lnwire.ShortChannelID{
+		BlockHeight: 1144,
+	}
+	msgChan1, syncer1, chanSeries1 := newTestSyncer(
+		startHeight, defaultEncoding, chunkSize,
+	)
+	syncer1.Start()
+	defer syncer1.Stop()
+
+	msgChan2, syncer2, chanSeries2 := newTestSyncer(
+		startHeight, defaultEncoding, chunkSize,
+	)
+	syncer2.Start()
+	defer syncer2.Stop()
+
+	// Record the delayed query reply interval used by each syncer.
+	delayedQueryInterval := syncer1.cfg.delayedQueryReplyInterval
+
+	// Record the number of undelayed queries allowed by the syncers.
+	numUndelayedQueries := syncer1.cfg.maxUndelayedQueryReplies
+
+	// We will send enough queries to exhaust the undelayed responses, and
+	// then send two more queries which should be delayed.
+	numQueryResponses := numUndelayedQueries + numDelayedQueries
+
+	// The total number of responses must include the initial reply each
+	// syncer will make to QueryChannelRange.
+	numTotalQueries := 1 + numQueryResponses
+
+	// The total number of channels each syncer needs to request must be
+	// scaled by the chunk size being used.
+	numTotalChans := numQueryResponses * chunkSize
+
+	// Although both nodes are at the same height, they'll have a
+	// completely disjoint set of chan ID's that they know of.
+	var syncer1Chans []lnwire.ShortChannelID
+	for i := 0; i < numTotalChans; i++ {
+		syncer1Chans = append(
+			syncer1Chans, lnwire.NewShortChanIDFromInt(uint64(i)),
+		)
+	}
+	var syncer2Chans []lnwire.ShortChannelID
+	for i := numTotalChans; i < numTotalChans+numTotalChans; i++ {
+		syncer2Chans = append(
+			syncer2Chans, lnwire.NewShortChanIDFromInt(uint64(i)),
+		)
+	}
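Plugging in the defaults makes the bookkeeping concrete: numUndelayedQueries = 5 and numDelayedQueries = 2 give numQueryResponses = 7 and numTotalQueries = 8 (one QueryChannelRange reply plus seven QueryShortChanIDs rounds), and numTotalChans = 7 * 2 = 14, so syncer1 is seeded with short channel IDs 0 through 13 and syncer2 with 14 through 27. Since the QueryChannelRange reply already consumes one undelayed slot, the throttle engages from loop iteration i = 5 onward.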
+
+	// We'll kick off the test by passing over the QueryChannelRange
+	// messages from one node to the other.
+	select {
+	case <-time.After(time.Second * 2):
+		t.Fatalf("didn't get msg from syncer1")
+
+	case msgs := <-msgChan1:
+		for _, msg := range msgs {
+			// The message MUST be a QueryChannelRange message.
+			_, ok := msg.(*lnwire.QueryChannelRange)
+			if !ok {
+				t.Fatalf("wrong message: expected "+
+					"QueryChannelRange for %T", msg)
+			}
+
+			select {
+			case <-time.After(time.Second * 2):
+				t.Fatalf("node 2 didn't read msg")
+
+			case syncer2.gossipMsgs <- msg:
+
+			}
+		}
+	}
+	select {
+	case <-time.After(time.Second * 2):
+		t.Fatalf("didn't get msg from syncer2")
+
+	case msgs := <-msgChan2:
+		for _, msg := range msgs {
+			// The message MUST be a QueryChannelRange message.
+			_, ok := msg.(*lnwire.QueryChannelRange)
+			if !ok {
+				t.Fatalf("wrong message: expected "+
+					"QueryChannelRange for %T", msg)
+			}
+
+			select {
+			case <-time.After(time.Second * 2):
+				t.Fatalf("node 1 didn't read msg")
+
+			case syncer1.gossipMsgs <- msg:
+
+			}
+		}
+	}
+
+	// At this point, we'll need to send responses to both nodes from their
+	// respective channel series. Both nodes will simply request the entire
+	// set of channels from the other. This will count as the first
+	// undelayed response for each syncer.
+	select {
+	case <-time.After(time.Second * 2):
+		t.Fatalf("no query recvd")
+
+	case <-chanSeries1.filterRangeReqs:
+		// We'll send all the channels that it should know of.
+		chanSeries1.filterRangeResp <- syncer1Chans
+	}
+	select {
+	case <-time.After(time.Second * 2):
+		t.Fatalf("no query recvd")
+
+	case <-chanSeries2.filterRangeReqs:
+		// We'll send back all the channels that it should know of.
+		chanSeries2.filterRangeResp <- syncer2Chans
+	}
+
+	// At this point, we'll forward the ReplyChannelRange messages to both
+	// parties. After receiving the set of channels known to the remote
+	// peer, each syncer can request the channels it is missing.
+	for i := 0; i < numQueryResponses; i++ {
+		select {
+		case <-time.After(time.Second * 2):
+			t.Fatalf("didn't get msg from syncer1")
+
+		case msgs := <-msgChan1:
+			for _, msg := range msgs {
+				// The message MUST be a ReplyChannelRange message.
+				_, ok := msg.(*lnwire.ReplyChannelRange)
+				if !ok {
+					t.Fatalf("wrong message: expected "+
+						"ReplyChannelRange for %T", msg)
+				}
+
+				select {
+				case <-time.After(time.Second * 2):
+					t.Fatalf("node 2 didn't read msg")
+
+				case syncer2.gossipMsgs <- msg:
+				}
+			}
+		}
+
+		select {
+		case <-time.After(time.Second * 2):
+			t.Fatalf("didn't get msg from syncer2")
+
+		case msgs := <-msgChan2:
+			for _, msg := range msgs {
+				// The message MUST be a ReplyChannelRange message.
+				_, ok := msg.(*lnwire.ReplyChannelRange)
+				if !ok {
+					t.Fatalf("wrong message: expected "+
+						"ReplyChannelRange for %T", msg)
+				}
+
+				select {
+				case <-time.After(time.Second * 2):
+					t.Fatalf("node 1 didn't read msg")
+
+				case syncer1.gossipMsgs <- msg:
+				}
+			}
+		}
+	}
+
+	// We'll now send back a chunked response to both parties of the known
+	// short chan ID's.
+	select {
+	case <-time.After(time.Second * 2):
+		t.Fatalf("no query recvd")
+
+	case <-chanSeries1.filterReq:
+		chanSeries1.filterResp <- syncer2Chans
+	}
+	select {
+	case <-time.After(time.Second * 2):
+		t.Fatalf("no query recvd")
+
+	case <-chanSeries2.filterReq:
+		chanSeries2.filterResp <- syncer1Chans
+	}
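Note the correspondence here: each side advertises 14 channels and the chunk size is 2, so the loop above relays exactly numQueryResponses = 7 ReplyChannelRange chunks per direction, and the loop below will relay the matching 7 QueryShortChanIDs batches of 2 IDs each.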
+
+	// At this point, both parties should start to send out initial
+	// requests to query the chan IDs of the remote party. We'll keep track
+	// of the number of queries made using the iterated value, which starts
+	// at one due to the initial contribution of the QueryChannelRange msgs.
+	for i := 1; i < numTotalQueries; i++ {
+		// Both parties should now have sent out the initial requests
+		// to query the chan IDs of the other party.
+		select {
+		case <-time.After(time.Second * 2):
+			t.Fatalf("didn't get msg from syncer1")
+
+		case msgs := <-msgChan1:
+			for _, msg := range msgs {
+				// The message MUST be a QueryShortChanIDs message.
+				_, ok := msg.(*lnwire.QueryShortChanIDs)
+				if !ok {
+					t.Fatalf("wrong message: expected "+
+						"QueryShortChanIDs for %T", msg)
+				}
+
+				select {
+				case <-time.After(time.Second * 2):
+					t.Fatalf("node 2 didn't read msg")
+
+				case syncer2.gossipMsgs <- msg:
+
+				}
+			}
+		}
+		select {
+		case <-time.After(time.Second * 2):
+			t.Fatalf("didn't get msg from syncer2")
+
+		case msgs := <-msgChan2:
+			for _, msg := range msgs {
+				// The message MUST be a QueryShortChanIDs message.
+				_, ok := msg.(*lnwire.QueryShortChanIDs)
+				if !ok {
+					t.Fatalf("wrong message: expected "+
+						"QueryShortChanIDs for %T", msg)
+				}
+
+				select {
+				case <-time.After(time.Second * 2):
+					t.Fatalf("node 1 didn't read msg")
+
+				case syncer1.gossipMsgs <- msg:
+
+				}
+			}
+		}
+
+		// We'll then respond to both parties with an empty set of
+		// replies (as it doesn't affect the test).
+		switch {
+
+		// If this query has surpassed the undelayed query threshold, we
+		// will impose stricter timing constraints on the response
+		// times. We'll first test that the peers don't immediately
+		// receive a query, and then check that both queries haven't
+		// gone unanswered entirely.
+		case i >= numUndelayedQueries:
+			// Create a before and after timeout; the test will
+			// ensure the messages are delivered to the peers
+			// within this timeframe.
+			before := time.After(
+				delayedQueryInterval - delayTolerance,
+			)
+			after := time.After(
+				delayedQueryInterval + delayTolerance,
+			)
+
+			// First, ensure neither peer tries to respond up until
+			// the before timer fires.
+			select {
+			case <-before:
+				// Queries are delayed, proceed.
+
+			case <-chanSeries1.annReq:
+				t.Fatalf("DOSy query was not delayed")
+
+			case <-chanSeries2.annReq:
+				t.Fatalf("DOSy query was not delayed")
+			}
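The before/after pair above implements a timing-window assertion: nothing may arrive while the lower bound is still running, and something must arrive before the upper bound lapses. If more cases like this were needed, the pattern could be factored into a helper along these lines (hypothetical, not part of the patch; it assumes the test package's imports of testing, time, and lnwire, and that the request channel carries []lnwire.ShortChannelID like the mock's annReq):

// assertDelayedArrival fails t if a request fires before min elapses, or
// if no request fires within max. It mirrors the before/after selects in
// TestGossipSyncerDelayDOS.
func assertDelayedArrival(t *testing.T, req <-chan []lnwire.ShortChannelID,
	min, max time.Duration) {

	t.Helper()

	// Nothing may arrive while the minimum delay is still pending.
	select {
	case <-time.After(min):
	case <-req:
		t.Fatalf("query was not delayed")
	}

	// The delayed request must then land before the remainder of the
	// window expires.
	select {
	case <-req:
	case <-time.After(max - min):
		t.Fatalf("no delayed query received")
	}
}

Unlike the inline version, this helper watches a single peer and consumes the request without replying, so a caller would still push the canned annResp afterward.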
+
+			// Next, we'll need to test that both queries are
+			// received before the after timer expires. To account
+			// for ordering, we will try to pull a message from both
+			// peers, and then test that the opposite peer also
+			// receives the message promptly.
+			var (
+				firstChanSeries *mockChannelGraphTimeSeries
+				laterChanSeries *mockChannelGraphTimeSeries
+			)
+
+			// If neither peer attempts a response within the
+			// allowed interval, then the messages are probably
+			// lost. Otherwise, process the message and record the
+			// induced ordering.
+			select {
+			case <-after:
+				t.Fatalf("no delayed query received")
+
+			case <-chanSeries1.annReq:
+				chanSeries1.annResp <- []lnwire.Message{}
+				firstChanSeries = chanSeries1
+				laterChanSeries = chanSeries2
+
+			case <-chanSeries2.annReq:
+				chanSeries2.annResp <- []lnwire.Message{}
+				firstChanSeries = chanSeries2
+				laterChanSeries = chanSeries1
+			}
+
+			// Finally, using the same interval timeout as before,
+			// ensure the later peer also responds promptly. We also
+			// assert that the first peer doesn't attempt another
+			// response.
+			select {
+			case <-after:
+				t.Fatalf("no delayed query received")
+
+			case <-firstChanSeries.annReq:
+				t.Fatalf("spurious undelayed response")
+
+			case <-laterChanSeries.annReq:
+				laterChanSeries.annResp <- []lnwire.Message{}
+			}
+
+		// Otherwise, we still haven't exceeded our undelayed query
+		// limit. Assert that both peers promptly attempt a response to
+		// the queries.
+		default:
+			select {
+			case <-time.After(50 * time.Millisecond):
+				t.Fatalf("no query recvd")
+
+			case <-chanSeries1.annReq:
+				chanSeries1.annResp <- []lnwire.Message{}
+			}
+			select {
+			case <-time.After(50 * time.Millisecond):
+				t.Fatalf("no query recvd")
+
+			case <-chanSeries2.annReq:
+				chanSeries2.annResp <- []lnwire.Message{}
+			}
+		}
+
+		// Finally, both sides should then receive a
+		// ReplyShortChanIDsEnd as the first chunk has been replied to.
+		select {
+		case <-time.After(50 * time.Millisecond):
+			t.Fatalf("didn't get msg from syncer1")
+
+		case msgs := <-msgChan1:
+			for _, msg := range msgs {
+				// The message MUST be a ReplyShortChanIDsEnd message.
+				_, ok := msg.(*lnwire.ReplyShortChanIDsEnd)
+				if !ok {
+					t.Fatalf("wrong message: expected "+
+						"ReplyShortChanIDsEnd for %T",
+						msg)
+				}
+
+				select {
+				case <-time.After(time.Second * 2):
+					t.Fatalf("node 2 didn't read msg")
+
+				case syncer2.gossipMsgs <- msg:
+
+				}
+			}
+		}
+		select {
+		case <-time.After(50 * time.Millisecond):
+			t.Fatalf("didn't get msg from syncer2")
+
+		case msgs := <-msgChan2:
+			for _, msg := range msgs {
+				// The message MUST be a ReplyShortChanIDsEnd message.
+				_, ok := msg.(*lnwire.ReplyShortChanIDsEnd)
+				if !ok {
+					t.Fatalf("wrong message: expected "+
+						"ReplyShortChanIDsEnd for %T",
+						msg)
+				}
+
+				select {
+				case <-time.After(time.Second * 2):
+					t.Fatalf("node 1 didn't read msg")
+
+				case syncer1.gossipMsgs <- msg:
+
+				}
+			}
+		}
+	}
+}
+
 // TestGossipSyncerRoutineSync tests all state transitions of the main syncer
 // goroutine. This ensures that given an encounter with a peer that has a set
 // of distinct channels, then we'll properly synchronize our channel state with