Merge pull request #1824 from cfromknecht/gossip-delay-reply
discovery/syncer: delay replies after initial sync to prevent DOS
commit 1941353fb2
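At its core, the change is a small rate-limiting pattern: each syncer counts how many queries it has answered, and once the count passes a threshold every further reply blocks on a timer (or on shutdown). A condensed, runnable sketch of that pattern follows; `delayedReplier`, `errExiting`, and the field names are illustrative stand-ins for the config fields added in the diff below, not the real API.

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

// errExiting stands in for ErrGossipSyncerExiting from the diff below.
var errExiting = errors.New("replier exiting")

// delayedReplier answers a fixed budget of queries immediately, then
// forces every later reply to wait out a fixed interval.
type delayedReplier struct {
	replyCount   int           // replies sent so far
	maxUndelayed int           // budget of immediate replies
	interval     time.Duration // delay applied past the budget
	quit         chan struct{} // closed on shutdown
}

func (r *delayedReplier) reply(query string) error {
	// Past the budget, block on the delay timer, but stay responsive
	// to shutdown so a spamming peer can't hold us open.
	if r.replyCount >= r.maxUndelayed {
		select {
		case <-time.After(r.interval):
		case <-r.quit:
			return errExiting
		}
	}
	r.replyCount++
	fmt.Println("replied to:", query)
	return nil
}

func main() {
	r := &delayedReplier{
		maxUndelayed: 5,                // DefaultMaxUndelayedQueryReplies
		interval:     time.Millisecond, // 30s in the real default
		quit:         make(chan struct{}),
	}
	for i := 0; i < 7; i++ {
		_ = r.reply(fmt.Sprintf("query %d", i)) // last two waits are delayed
	}
}
```

The real syncer applies this in `replyPeerQueries` and reports the shutdown case as the sentinel `ErrGossipSyncerExiting`, both shown in the diff that follows.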
discovery/gossiper.go

@@ -3,6 +3,7 @@ package discovery
 import (
 	"bytes"
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"runtime"
 	"sync"
@@ -14,7 +15,6 @@ import (
 	"github.com/btcsuite/btcd/wire"
 	"github.com/coreos/bbolt"
 	"github.com/davecgh/go-spew/spew"
-	"github.com/go-errors/errors"
 	"github.com/lightningnetwork/lnd/chainntnfs"
 	"github.com/lightningnetwork/lnd/channeldb"
 	"github.com/lightningnetwork/lnd/lnpeer"
@@ -1522,7 +1522,7 @@ func (d *AuthenticatedGossiper) processRejectedEdge(
 	}
 	err = ValidateChannelAnn(chanAnn)
 	if err != nil {
-		err := errors.Errorf("assembled channel announcement proof "+
+		err := fmt.Errorf("assembled channel announcement proof "+
 			"for shortChanID=%v isn't valid: %v",
 			chanAnnMsg.ShortChannelID, err)
 		log.Error(err)
@@ -1533,7 +1533,7 @@ func (d *AuthenticatedGossiper) processRejectedEdge(
 	// to the database.
 	err = d.cfg.Router.AddProof(chanAnnMsg.ShortChannelID, proof)
 	if err != nil {
-		err := errors.Errorf("unable add proof to shortChanID=%v: %v",
+		err := fmt.Errorf("unable add proof to shortChanID=%v: %v",
 			chanAnnMsg.ShortChannelID, err)
 		log.Error(err)
 		return nil, err
@@ -1599,7 +1599,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
 		}

 		if err := ValidateNodeAnn(msg); err != nil {
-			err := errors.Errorf("unable to validate "+
+			err := fmt.Errorf("unable to validate "+
 				"node announcement: %v", err)
 			log.Error(err)
 			nMsg.err <- err
@@ -1702,7 +1702,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
 		var proof *channeldb.ChannelAuthProof
 		if nMsg.isRemote {
 			if err := ValidateChannelAnn(msg); err != nil {
-				err := errors.Errorf("unable to validate "+
+				err := fmt.Errorf("unable to validate "+
 					"announcement: %v", err)
 				d.rejectMtx.Lock()
 				d.recentRejects[msg.ShortChannelID.ToUint64()] = struct{}{}
@@ -1966,7 +1966,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
 			return nil

 		default:
-			err := errors.Errorf("unable to validate "+
+			err := fmt.Errorf("unable to validate "+
 				"channel update short_chan_id=%v: %v",
 				shortChanID, err)
 			log.Error(err)
@@ -1994,7 +1994,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
 		// key, In the case of an invalid channel , we'll return an
 		// error to the caller and exit early.
 		if err := ValidateChannelUpdateAnn(pubKey, msg); err != nil {
-			rErr := errors.Errorf("unable to validate channel "+
+			rErr := fmt.Errorf("unable to validate channel "+
 				"update announcement for short_chan_id=%v: %v",
 				spew.Sdump(msg.ShortChannelID), err)

@@ -2130,7 +2130,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
 			// node might rewrite the waiting proof.
 			proof := channeldb.NewWaitingProof(nMsg.isRemote, msg)
 			if err := d.waitingProofs.Add(proof); err != nil {
-				err := errors.Errorf("unable to store "+
+				err := fmt.Errorf("unable to store "+
 					"the proof for short_chan_id=%v: %v",
 					shortChanID, err)
 				log.Error(err)
@@ -2152,7 +2152,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
 		// Ensure that channel that was retrieved belongs to the peer
 		// which sent the proof announcement.
 		if !(isFirstNode || isSecondNode) {
-			err := errors.Errorf("channel that was received not "+
+			err := fmt.Errorf("channel that was received not "+
 				"belongs to the peer which sent the proof, "+
 				"short_chan_id=%v", shortChanID)
 			log.Error(err)
@@ -2176,7 +2176,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
 			// deliver the proof when it comes online.
 			err := d.sendAnnSigReliably(msg, remotePeer)
 			if err != nil {
-				err := errors.Errorf("unable to send reliably "+
+				err := fmt.Errorf("unable to send reliably "+
 					"to remote for short_chan_id=%v: %v",
 					shortChanID, err)
 				log.Error(err)
@@ -2245,7 +2245,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
 		proof := channeldb.NewWaitingProof(nMsg.isRemote, msg)
 		oppositeProof, err := d.waitingProofs.Get(proof.OppositeKey())
 		if err != nil && err != channeldb.ErrWaitingProofNotFound {
-			err := errors.Errorf("unable to get "+
+			err := fmt.Errorf("unable to get "+
 				"the opposite proof for short_chan_id=%v: %v",
 				shortChanID, err)
 			log.Error(err)
@@ -2255,7 +2255,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(

 		if err == channeldb.ErrWaitingProofNotFound {
 			if err := d.waitingProofs.Add(proof); err != nil {
-				err := errors.Errorf("unable to store "+
+				err := fmt.Errorf("unable to store "+
 					"the proof for short_chan_id=%v: %v",
 					shortChanID, err)
 				log.Error(err)
@@ -2298,7 +2298,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
 		// With all the necessary components assembled validate the
 		// full channel announcement proof.
 		if err := ValidateChannelAnn(chanAnn); err != nil {
-			err := errors.Errorf("channel announcement proof "+
+			err := fmt.Errorf("channel announcement proof "+
 				"for short_chan_id=%v isn't valid: %v",
 				shortChanID, err)

@@ -2316,7 +2316,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
 		// can announce it on peer connect.
 		err = d.cfg.Router.AddProof(msg.ShortChannelID, &dbProof)
 		if err != nil {
-			err := errors.Errorf("unable add proof to the "+
+			err := fmt.Errorf("unable add proof to the "+
 				"channel chanID=%v: %v", msg.ChannelID, err)
 			log.Error(err)
 			nMsg.err <- err
@@ -2325,7 +2325,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(

 		err = d.waitingProofs.Remove(proof.OppositeKey())
 		if err != nil {
-			err := errors.Errorf("unable remove opposite proof "+
+			err := fmt.Errorf("unable remove opposite proof "+
 				"for the channel with chanID=%v: %v",
 				msg.ChannelID, err)
 			log.Error(err)
discovery/syncer.go

@@ -1,6 +1,7 @@
 package discovery

 import (
+	"errors"
 	"fmt"
 	"math"
 	"sync"
@@ -52,6 +53,17 @@ const (
 	chansSynced
 )

+const (
+	// DefaultMaxUndelayedQueryReplies specifies how many gossip queries we
+	// will respond to immediately before starting to delay responses.
+	DefaultMaxUndelayedQueryReplies = 5
+
+	// DefaultDelayedQueryReplyInterval is the length of time we will wait
+	// before responding to gossip queries after replying to
+	// maxUndelayedQueryReplies queries.
+	DefaultDelayedQueryReplyInterval = 30 * time.Second
+)
+
 // String returns a human readable string describing the target syncerState.
 func (s syncerState) String() string {
 	switch s {
@@ -82,6 +94,9 @@ var (
 	encodingTypeToChunkSize = map[lnwire.ShortChanIDEncoding]int32{
 		lnwire.EncodingSortedPlain: 8000,
 	}
+
+	// ErrGossipSyncerExiting signals that the syncer has been killed.
+	ErrGossipSyncerExiting = errors.New("gossip syncer exiting")
 )

 const (
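The new `ErrGossipSyncerExiting` sentinel lets the syncer's main loop tell an orderly shutdown apart from a real failure with a plain equality check, which is how the `channelGraphSyncer` hunks further below filter it out of the error log. A minimal sketch of the sentinel-error idiom, using hypothetical names rather than the real syncer types:

```go
package main

import (
	"errors"
	"fmt"
)

// errExiting mirrors the role of ErrGossipSyncerExiting: a sentinel
// value that means "we are shutting down", not a genuine failure.
var errExiting = errors.New("worker exiting")

func replyOnce(shuttingDown bool) error {
	if shuttingDown {
		return errExiting
	}
	return nil
}

func main() {
	err := replyOnce(true)
	// Only log genuine failures; a shutdown is expected and silent,
	// matching the `err != ErrGossipSyncerExiting` checks in the diff.
	if err != nil && err != errExiting {
		fmt.Println("unable to reply:", err)
	}
}
```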
@@ -167,6 +182,15 @@ type gossipSyncerCfg struct {
 	// targeted messages to the peer we've been assigned to sync the graph
 	// state from.
 	sendToPeer func(...lnwire.Message) error
+
+	// maxUndelayedQueryReplies specifies how many gossip queries we will
+	// respond to immediately before starting to delay responses.
+	maxUndelayedQueryReplies int
+
+	// delayedQueryReplyInterval is the length of time we will wait before
+	// responding to gossip queries after replying to
+	// maxUndelayedQueryReplies queries.
+	delayedQueryReplyInterval time.Duration
 }

 // gossipSyncer is a struct that handles synchronizing the channel graph state
@@ -214,6 +238,11 @@ type gossipSyncer struct {

 	cfg gossipSyncerCfg

+	// replyCount records how many query replies we've responded to. This is
+	// used to determine when to start delaying responses to peers to
+	// prevent DOS vulnerabilities.
+	replyCount int
+
 	sync.Mutex

 	quit chan struct{}
@@ -223,6 +252,18 @@ type gossipSyncer struct {
 // newGossiperSyncer returns a new instance of the gossipSyncer populated using
 // the passed config.
 func newGossiperSyncer(cfg gossipSyncerCfg) *gossipSyncer {
+	// If no parameter was specified for max undelayed query replies, set it
+	// to the default of 5 queries.
+	if cfg.maxUndelayedQueryReplies <= 0 {
+		cfg.maxUndelayedQueryReplies = DefaultMaxUndelayedQueryReplies
+	}
+
+	// If no parameter was specified for delayed query reply interval, set
+	// to the default of 30 seconds.
+	if cfg.delayedQueryReplyInterval <= 0 {
+		cfg.delayedQueryReplyInterval = DefaultDelayedQueryReplyInterval
+	}
+
 	return &gossipSyncer{
 		cfg:        cfg,
 		gossipMsgs: make(chan lnwire.Message, 100),
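Because the constructor backfills zero values, existing call sites pick up the new throttling without any changes, while tests can still tighten the interval (as `newTestSyncer` does below with 2 seconds). A small standalone sketch of this zero-value defaulting idiom; `syncCfg` and its fields are illustrative stand-ins, not the real config type:

```go
package main

import (
	"fmt"
	"time"
)

// Illustrative defaults mirroring DefaultMaxUndelayedQueryReplies and
// DefaultDelayedQueryReplyInterval from the diff above.
const (
	defaultMaxUndelayed = 5
	defaultInterval     = 30 * time.Second
)

// syncCfg stands in for the relevant gossipSyncerCfg fields.
type syncCfg struct {
	maxUndelayed int
	interval     time.Duration
}

// applyDefaults mirrors the backfill in newGossiperSyncer: any knob left
// at its zero value (or negative) is replaced with the package default.
func applyDefaults(cfg syncCfg) syncCfg {
	if cfg.maxUndelayed <= 0 {
		cfg.maxUndelayed = defaultMaxUndelayed
	}
	if cfg.interval <= 0 {
		cfg.interval = defaultInterval
	}
	return cfg
}

func main() {
	fmt.Println(applyDefaults(syncCfg{}))                          // {5 30s}
	fmt.Println(applyDefaults(syncCfg{interval: 2 * time.Second})) // {5 2s}
}
```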
@@ -332,7 +373,7 @@ func (g *gossipSyncer) channelGraphSyncer() {
 				// Otherwise, it's the remote peer performing a
 				// query, which we'll attempt to reply to.
 				err := g.replyPeerQueries(msg)
-				if err != nil {
+				if err != nil && err != ErrGossipSyncerExiting {
 					log.Errorf("unable to reply to peer "+
 						"query: %v", err)
 				}
@@ -386,7 +427,7 @@ func (g *gossipSyncer) channelGraphSyncer() {
 			// Otherwise, it's the remote peer performing a
 			// query, which we'll attempt to deploy to.
 			err := g.replyPeerQueries(msg)
-			if err != nil {
+			if err != nil && err != ErrGossipSyncerExiting {
 				log.Errorf("unable to reply to peer "+
 					"query: %v", err)
 			}
@@ -430,7 +471,7 @@ func (g *gossipSyncer) channelGraphSyncer() {
 		select {
 		case msg := <-g.gossipMsgs:
 			err := g.replyPeerQueries(msg)
-			if err != nil {
+			if err != nil && err != ErrGossipSyncerExiting {
 				log.Errorf("unable to reply to peer "+
 					"query: %v", err)
 			}
@@ -588,6 +629,24 @@ func (g *gossipSyncer) genChanRangeQuery() (*lnwire.QueryChannelRange, error) {
 // replyPeerQueries is called in response to any query by the remote peer.
 // We'll examine our state and send back our best response.
 func (g *gossipSyncer) replyPeerQueries(msg lnwire.Message) error {
+	// If we've already replied a handful of times, we will start to delay
+	// responses back to the remote peer. This can help prevent DOS attacks
+	// where the remote peer spams us endlessly.
+	switch {
+	case g.replyCount == g.cfg.maxUndelayedQueryReplies:
+		log.Infof("gossipSyncer(%x): entering delayed gossip replies",
+			g.peerPub[:])
+		fallthrough
+
+	case g.replyCount > g.cfg.maxUndelayedQueryReplies:
+		select {
+		case <-time.After(g.cfg.delayedQueryReplyInterval):
+		case <-g.quit:
+			return ErrGossipSyncerExiting
+		}
+	}
+	g.replyCount++
+
 	switch msg := msg.(type) {

 	// In this state, we'll also handle any incoming channel range queries
discovery/syncer_test.go

@@ -1,7 +1,6 @@
 package discovery

 import (
-	"fmt"
 	"math"
 	"reflect"
 	"testing"
@@ -49,7 +48,9 @@ type mockChannelGraphTimeSeries struct {
 	updateResp chan []*lnwire.ChannelUpdate
 }

-func newMockChannelGraphTimeSeries(hID lnwire.ShortChannelID) *mockChannelGraphTimeSeries {
+func newMockChannelGraphTimeSeries(
+	hID lnwire.ShortChannelID) *mockChannelGraphTimeSeries {
+
 	return &mockChannelGraphTimeSeries{
 		highestID: hID,

@@ -127,6 +128,7 @@ func newTestSyncer(hID lnwire.ShortChannelID,
 			msgChan <- msgs
 			return nil
 		},
+		delayedQueryReplyInterval: 2 * time.Second,
 	}
 	syncer := newGossiperSyncer(cfg)

@@ -810,9 +812,6 @@ func TestGossipSyncerProcessChanRangeReply(t *testing.T) {
 		// We should get a request for the entire range of short
 		// chan ID's.
 		if !reflect.DeepEqual(expectedReq, req) {
-			fmt.Printf("wrong request: expected %v, got %v\n",
-				expectedReq, req)
-
 			t.Fatalf("wrong request: expected %v, got %v",
 				expectedReq, req)
 		}
@@ -983,6 +982,402 @@ func TestGossipSyncerSynchronizeChanIDs(t *testing.T) {
 	}
 }

+// TestGossipSyncerDelayDOS tests that the gossip syncer will begin delaying
+// queries after its prescribed allotment of undelayed query responses. Once
+// this happens, all query replies should be delayed by the configured
+// interval.
+func TestGossipSyncerDelayDOS(t *testing.T) {
+	t.Parallel()
+
+	// We'll modify the chunk size to be a smaller value, since we'll be
+	// sending a modest number of queries. After exhausting our undelayed
+	// gossip queries, we'll send two extra queries and ensure that they are
+	// delayed properly.
+	const chunkSize = 2
+	const numDelayedQueries = 2
+	const delayTolerance = time.Millisecond * 200
+
+	// First, we'll create two gossipSyncer instances with a canned
+	// sendToPeer message to allow us to intercept their potential sends.
+	startHeight := lnwire.ShortChannelID{
+		BlockHeight: 1144,
+	}
+	msgChan1, syncer1, chanSeries1 := newTestSyncer(
+		startHeight, defaultEncoding, chunkSize,
+	)
+	syncer1.Start()
+	defer syncer1.Stop()
+
+	msgChan2, syncer2, chanSeries2 := newTestSyncer(
+		startHeight, defaultEncoding, chunkSize,
+	)
+	syncer2.Start()
+	defer syncer2.Stop()
+
+	// Record the delayed query reply interval used by each syncer.
+	delayedQueryInterval := syncer1.cfg.delayedQueryReplyInterval
+
+	// Record the number of undelayed queries allowed by the syncers.
+	numUndelayedQueries := syncer1.cfg.maxUndelayedQueryReplies
+
+	// We will send enough queries to exhaust the undelayed responses, and
+	// then send two more queries which should be delayed.
+	numQueryResponses := numUndelayedQueries + numDelayedQueries
+
+	// The total number of responses must include the initial reply each
+	// syncer will make to QueryChannelRange.
+	numTotalQueries := 1 + numQueryResponses
+
+	// The total number of channels each syncer needs to request must be
+	// scaled by the chunk size being used.
+	numTotalChans := numQueryResponses * chunkSize
+
+	// Although both nodes are at the same height, they'll have a
+	// completely disjoint set of chan ID's that they know of.
+	var syncer1Chans []lnwire.ShortChannelID
+	for i := 0; i < numTotalChans; i++ {
+		syncer1Chans = append(
+			syncer1Chans, lnwire.NewShortChanIDFromInt(uint64(i)),
+		)
+	}
+	var syncer2Chans []lnwire.ShortChannelID
+	for i := numTotalChans; i < numTotalChans+numTotalChans; i++ {
+		syncer2Chans = append(
+			syncer2Chans, lnwire.NewShortChanIDFromInt(uint64(i)),
+		)
+	}
+
+	// We'll kick off the test by passing over the QueryChannelRange
+	// messages from one node to the other.
+	select {
+	case <-time.After(time.Second * 2):
+		t.Fatalf("didn't get msg from syncer1")
+
+	case msgs := <-msgChan1:
+		for _, msg := range msgs {
+			// The message MUST be a QueryChannelRange message.
+			_, ok := msg.(*lnwire.QueryChannelRange)
+			if !ok {
+				t.Fatalf("wrong message: expected "+
+					"QueryChannelRange for %T", msg)
+			}
+
+			select {
+			case <-time.After(time.Second * 2):
+				t.Fatalf("node 2 didn't read msg")
+
+			case syncer2.gossipMsgs <- msg:
+
+			}
+		}
+	}
+	select {
+	case <-time.After(time.Second * 2):
+		t.Fatalf("didn't get msg from syncer2")
+
+	case msgs := <-msgChan2:
+		for _, msg := range msgs {
+			// The message MUST be a QueryChannelRange message.
+			_, ok := msg.(*lnwire.QueryChannelRange)
+			if !ok {
+				t.Fatalf("wrong message: expected "+
+					"QueryChannelRange for %T", msg)
+			}
+
+			select {
+			case <-time.After(time.Second * 2):
+				t.Fatalf("node 2 didn't read msg")
+
+			case syncer1.gossipMsgs <- msg:
+
+			}
+		}
+	}
+
+	// At this point, we'll need to send responses to both nodes from their
+	// respective channel series. Both nodes will simply request the entire
+	// set of channels from the other. This will count as the first
+	// undelayed response for each syncer.
+	select {
+	case <-time.After(time.Second * 2):
+		t.Fatalf("no query recvd")
+
+	case <-chanSeries1.filterRangeReqs:
+		// We'll send all the channels that it should know of.
+		chanSeries1.filterRangeResp <- syncer1Chans
+	}
+	select {
+	case <-time.After(time.Second * 2):
+		t.Fatalf("no query recvd")
+
+	case <-chanSeries2.filterRangeReqs:
+		// We'll send back all the channels that it should know of.
+		chanSeries2.filterRangeResp <- syncer2Chans
+	}
+
+	// At this point, we'll forward the ReplyChannelRange messages to both
+	// parties. After receiving the set of channels known to the remote peer
+	for i := 0; i < numQueryResponses; i++ {
+		select {
+		case <-time.After(time.Second * 2):
+			t.Fatalf("didn't get msg from syncer1")
+
+		case msgs := <-msgChan1:
+			for _, msg := range msgs {
+				// The message MUST be a ReplyChannelRange message.
+				_, ok := msg.(*lnwire.ReplyChannelRange)
+				if !ok {
+					t.Fatalf("wrong message: expected "+
+						"QueryChannelRange for %T", msg)
+				}
+
+				select {
+				case <-time.After(time.Second * 2):
+					t.Fatalf("node 2 didn't read msg")
+
+				case syncer2.gossipMsgs <- msg:
+				}
+			}
+		}
+
+		select {
+		case <-time.After(time.Second * 2):
+			t.Fatalf("didn't get msg from syncer2")
+
+		case msgs := <-msgChan2:
+			for _, msg := range msgs {
+				// The message MUST be a ReplyChannelRange message.
+				_, ok := msg.(*lnwire.ReplyChannelRange)
+				if !ok {
+					t.Fatalf("wrong message: expected "+
+						"QueryChannelRange for %T", msg)
+				}
+
+				select {
+				case <-time.After(time.Second * 2):
+					t.Fatalf("node 2 didn't read msg")
+
+				case syncer1.gossipMsgs <- msg:
+				}
+			}
+		}
+	}
+
+	// We'll now send back a chunked response for both parties of the known
+	// short chan ID's.
+	select {
+	case <-time.After(time.Second * 2):
+		t.Fatalf("no query recvd")
+
+	case <-chanSeries1.filterReq:
+		chanSeries1.filterResp <- syncer2Chans
+	}
+	select {
+	case <-time.After(time.Second * 2):
+		t.Fatalf("no query recvd")
+
+	case <-chanSeries2.filterReq:
+		chanSeries2.filterResp <- syncer1Chans
+	}
+
+	// At this point, both parties should start to send out initial
+	// requests to query the chan IDs of the remote party. We'll keep track
+	// of the number of queries made using the iterated value, which starts
+	// at one due the initial contribution of the QueryChannelRange msgs.
+	for i := 1; i < numTotalQueries; i++ {
+		// Both parties should now have sent out the initial requests
+		// to query the chan IDs of the other party.
+		select {
+		case <-time.After(time.Second * 2):
+			t.Fatalf("didn't get msg from syncer1")
+
+		case msgs := <-msgChan1:
+			for _, msg := range msgs {
+				// The message MUST be a QueryShortChanIDs message.
+				_, ok := msg.(*lnwire.QueryShortChanIDs)
+				if !ok {
+					t.Fatalf("wrong message: expected "+
+						"QueryShortChanIDs for %T", msg)
+				}
+
+				select {
+				case <-time.After(time.Second * 2):
+					t.Fatalf("node 2 didn't read msg")
+
+				case syncer2.gossipMsgs <- msg:
+
+				}
+			}
+		}
+		select {
+		case <-time.After(time.Second * 2):
+			t.Fatalf("didn't get msg from syncer2")
+
+		case msgs := <-msgChan2:
+			for _, msg := range msgs {
+				// The message MUST be a QueryShortChanIDs message.
+				_, ok := msg.(*lnwire.QueryShortChanIDs)
+				if !ok {
+					t.Fatalf("wrong message: expected "+
+						"QueryShortChanIDs for %T", msg)
+				}
+
+				select {
+				case <-time.After(time.Second * 2):
+					t.Fatalf("node 2 didn't read msg")
+
+				case syncer1.gossipMsgs <- msg:
+
+				}
+			}
+		}
+
+		// We'll then respond to both parties with an empty set of
+		// replies (as it doesn't affect the test).
+		switch {
+
+		// If this query has surpassed the undelayed query threshold, we
+		// will impose stricter timing constraints on the response
+		// times. We'll first test that the peers don't immediately
+		// receive a query, and then check that both queries haven't
+		// gone unanswered entirely.
+		case i >= numUndelayedQueries:
+			// Create a before and after timeout to test, our test
+			// will ensure the messages are delivered to the peers
+			// in this timeframe.
+			before := time.After(
+				delayedQueryInterval - delayTolerance,
+			)
+			after := time.After(
+				delayedQueryInterval + delayTolerance,
+			)
+
+			// First, ensure neither peer tries to respond up until
+			// the before time fires.
+			select {
+			case <-before:
+				// Queries are delayed, proceed.
+
+			case <-chanSeries1.annReq:
+				t.Fatalf("DOSy query was not delayed")
+
+			case <-chanSeries2.annReq:
+				t.Fatalf("DOSy query was not delayed")
+			}
+
+			// Next, we'll need to test that both queries are
+			// received before the after timer expires. To account
+			// for ordering, we will try to pull a message from both
+			// peers, and then test that the opposite peer also
+			// receives the message promptly.
+			var (
+				firstChanSeries *mockChannelGraphTimeSeries
+				laterChanSeries *mockChannelGraphTimeSeries
+			)
+
+			// If neither peer attempts a response within the
+			// allowed interval, then the messages are probably
+			// lost. Otherwise, process the message and record the
+			// induced ordering.
+			select {
+			case <-after:
+				t.Fatalf("no delayed query received")
+
+			case <-chanSeries1.annReq:
+				chanSeries1.annResp <- []lnwire.Message{}
+				firstChanSeries = chanSeries1
+				laterChanSeries = chanSeries2
+
+			case <-chanSeries2.annReq:
+				chanSeries2.annResp <- []lnwire.Message{}
+				firstChanSeries = chanSeries2
+				laterChanSeries = chanSeries1
+			}
+
+			// Finally, using the same interval timeout as before,
+			// ensure the later peer also responds promptly. We also
+			// assert that the first peer doesn't attempt another
+			// response.
+			select {
+			case <-after:
+				t.Fatalf("no delayed query received")
+
+			case <-firstChanSeries.annReq:
+				t.Fatalf("spurious undelayed response")
+
+			case <-laterChanSeries.annReq:
+				laterChanSeries.annResp <- []lnwire.Message{}
+			}
+
+		// Otherwise, we still haven't exceeded our undelayed query
+		// limit. Assert that both peers promptly attempt a response to
+		// the queries.
+		default:
+			select {
+			case <-time.After(50 * time.Millisecond):
+				t.Fatalf("no query recvd")
+
+			case <-chanSeries1.annReq:
+				chanSeries1.annResp <- []lnwire.Message{}
+			}
+			select {
+			case <-time.After(50 * time.Millisecond):
+				t.Fatalf("no query recvd")
+
+			case <-chanSeries2.annReq:
+				chanSeries2.annResp <- []lnwire.Message{}
+			}
+		}
+
+		// Finally, both sides should then receive a
+		// ReplyShortChanIDsEnd as the first chunk has been replied to.
+		select {
+		case <-time.After(50 * time.Millisecond):
+			t.Fatalf("didn't get msg from syncer1")
+
+		case msgs := <-msgChan1:
+			for _, msg := range msgs {
+				// The message MUST be a ReplyShortChanIDsEnd message.
+				_, ok := msg.(*lnwire.ReplyShortChanIDsEnd)
+				if !ok {
+					t.Fatalf("wrong message: expected "+
+						"QueryChannelRange for %T", msg)
+				}
+
+				select {
+				case <-time.After(time.Second * 2):
+					t.Fatalf("node 2 didn't read msg")
+
+				case syncer2.gossipMsgs <- msg:
+
+				}
+			}
+		}
+		select {
+		case <-time.After(50 * time.Millisecond):
+			t.Fatalf("didn't get msg from syncer2")
+
+		case msgs := <-msgChan2:
+			for _, msg := range msgs {
+				// The message MUST be a ReplyShortChanIDsEnd message.
+				_, ok := msg.(*lnwire.ReplyShortChanIDsEnd)
+				if !ok {
+					t.Fatalf("wrong message: expected "+
+						"ReplyShortChanIDsEnd for %T", msg)
+				}
+
+				select {
+				case <-time.After(time.Second * 2):
+					t.Fatalf("node 2 didn't read msg")
+
+				case syncer1.gossipMsgs <- msg:
+
+				}
+			}
+		}
+	}
+}
+
 // TestGossipSyncerRoutineSync tests all state transitions of the main syncer
 // goroutine. This ensures that given an encounter with a peer that has a set
 // of distinct channels, then we'll properly synchronize our channel state with