diff --git a/peer/brontide.go b/peer/brontide.go index 578d7133..c0669cdf 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -446,9 +446,36 @@ func (p *Brontide) Start() error { peerLog.Tracef("Peer %v starting", p) + // Fetch and then load all the active channels we have with this remote + // peer from the database. + activeChans, err := p.cfg.ChannelDB.FetchOpenChannels( + p.cfg.Addr.IdentityKey, + ) + if err != nil { + peerLog.Errorf("Unable to fetch active chans "+ + "for peer %v: %v", p, err) + return err + } + + if len(activeChans) == 0 { + p.cfg.PrunePersistentPeerConnection(p.cfg.PubKeyBytes) + } + + // Quickly check if we have any existing legacy channels with this + // peer. + haveLegacyChan := false + for _, c := range activeChans { + if c.ChanType.IsTweakless() { + continue + } + + haveLegacyChan = true + break + } + // Exchange local and global features, the init message should be very // first between two nodes. - if err := p.sendInitMsg(); err != nil { + if err := p.sendInitMsg(haveLegacyChan); err != nil { return fmt.Errorf("unable to send init msg: %v", err) } @@ -496,19 +523,6 @@ func (p *Brontide) Start() error { "must be init message") } - // Fetch and then load all the active channels we have with this remote - // peer from the database. - activeChans, err := p.cfg.ChannelDB.FetchOpenChannels(p.cfg.Addr.IdentityKey) - if err != nil { - peerLog.Errorf("unable to fetch active chans "+ - "for peer %v: %v", p, err) - return err - } - - if len(activeChans) == 0 { - p.cfg.PrunePersistentPeerConnection(p.cfg.PubKeyBytes) - } - // Next, load all the active channels we have with this peer, // registering them with the switch and launching the necessary // goroutines required to operate them. @@ -2752,12 +2766,28 @@ func (p *Brontide) RemoteFeatures() *lnwire.FeatureVector { return p.remoteFeatures } -// sendInitMsg sends the Init message to the remote peer. This message contains our -// currently supported local and global features. -func (p *Brontide) sendInitMsg() error { +// sendInitMsg sends the Init message to the remote peer. This message contains +// our currently supported local and global features. +func (p *Brontide) sendInitMsg(legacyChan bool) error { + features := p.cfg.Features.Clone() + + // If we have a legacy channel open with a peer, we downgrade static + // remote required to optional in case the peer does not understand the + // required feature bit. If we do not do this, the peer will reject our + // connection because it does not understand a required feature bit, and + // our channel will be unusable. + if legacyChan && features.RequiresFeature(lnwire.StaticRemoteKeyRequired) { + peerLog.Infof("Legacy channel open with peer: %x, "+ + "downgrading static remote required feature bit to "+ + "optional", p.PubKey()) + + features.Unset(lnwire.StaticRemoteKeyRequired) + features.Set(lnwire.StaticRemoteKeyOptional) + } + msg := lnwire.NewInitMessage( p.cfg.LegacyFeatures.RawFeatureVector, - p.cfg.Features.RawFeatureVector, + features.RawFeatureVector, ) return p.writeMessage(msg) diff --git a/peer/brontide_test.go b/peer/brontide_test.go index 6150fe27..77734d66 100644 --- a/peer/brontide_test.go +++ b/peer/brontide_test.go @@ -15,6 +15,8 @@ import ( "github.com/lightningnetwork/lnd/lntest/mock" "github.com/lightningnetwork/lnd/lnwallet/chancloser" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/pool" + "github.com/stretchr/testify/require" ) var ( @@ -23,10 +25,6 @@ var ( // p2wshAddress is a valid pay to witness script hash address. p2wshAddress = "bc1qrp33g0q5c5txsp9arysrx4k6zdkfs4nce4xj0gdcccefvpysxf3qccfmv3" - - // timeout is a timeout value to use for tests which need ot wait for - // a return value on a channel. - timeout = time.Second * 5 ) // TestPeerChannelClosureAcceptFeeResponder tests the shutdown responder's @@ -862,6 +860,120 @@ func TestCustomShutdownScript(t *testing.T) { } } +// TestStaticRemoteDowngrade tests that we downgrade our static remote feature +// bit to optional if we have legacy channels with a peer. This ensures that +// we can stay connected to peers that don't support the feature bit that we +// have channels with. +func TestStaticRemoteDowngrade(t *testing.T) { + t.Parallel() + + var ( + // We set the same legacy feature bits for all tests, since + // these are not relevant to our test scenario + rawLegacy = lnwire.NewRawFeatureVector( + lnwire.UpfrontShutdownScriptOptional, + ) + legacy = lnwire.NewFeatureVector(rawLegacy, nil) + + rawFeatureOptional = lnwire.NewRawFeatureVector( + lnwire.StaticRemoteKeyOptional, + ) + + featureOptional = lnwire.NewFeatureVector( + rawFeatureOptional, nil, + ) + + rawFeatureRequired = lnwire.NewRawFeatureVector( + lnwire.StaticRemoteKeyRequired, + ) + + featureRequired = lnwire.NewFeatureVector( + rawFeatureRequired, nil, + ) + ) + + tests := []struct { + name string + legacy bool + features *lnwire.FeatureVector + expectedInit *lnwire.Init + }{ + { + name: "no legacy channel, static optional", + legacy: false, + features: featureOptional, + expectedInit: &lnwire.Init{ + GlobalFeatures: rawLegacy, + Features: rawFeatureOptional, + }, + }, + { + name: "legacy channel, static optional", + legacy: true, + features: featureOptional, + expectedInit: &lnwire.Init{ + GlobalFeatures: rawLegacy, + Features: rawFeatureOptional, + }, + }, + { + name: "no legacy channel, static required", + legacy: false, + features: featureRequired, + expectedInit: &lnwire.Init{ + GlobalFeatures: rawLegacy, + Features: rawFeatureRequired, + }, + }, + { + name: "legacy channel, static required", + legacy: true, + features: featureRequired, + expectedInit: &lnwire.Init{ + GlobalFeatures: rawLegacy, + Features: rawFeatureOptional, + }, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + writeBufferPool := pool.NewWriteBuffer( + pool.DefaultWriteBufferGCInterval, + pool.DefaultWriteBufferExpiryInterval, + ) + + writePool := pool.NewWrite( + writeBufferPool, 1, timeout, + ) + require.NoError(t, writePool.Start()) + + mockConn := newMockConn(t, 1) + + p := Brontide{ + cfg: Config{ + LegacyFeatures: legacy, + Features: test.features, + Conn: mockConn, + WritePool: writePool, + }, + } + + var b bytes.Buffer + _, err := lnwire.WriteMessage(&b, test.expectedInit, 0) + require.NoError(t, err) + + // Send our init message, assert that we write our expected message + // and shutdown our write pool. + require.NoError(t, p.sendInitMsg(test.legacy)) + mockConn.assertWrite(b.Bytes()) + require.NoError(t, writePool.Stop()) + }) + } +} + // genScript creates a script paying out to the address provided, which must // be a valid address. func genScript(t *testing.T, address string) lnwire.DeliveryAddress { diff --git a/peer/test_utils.go b/peer/test_utils.go index 08c4557d..4039ae8a 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -9,6 +9,7 @@ import ( "math/rand" "net" "os" + "testing" "time" "github.com/btcsuite/btcd/btcec" @@ -28,10 +29,15 @@ import ( "github.com/lightningnetwork/lnd/queue" "github.com/lightningnetwork/lnd/shachain" "github.com/lightningnetwork/lnd/ticker" + "github.com/stretchr/testify/require" ) const ( broadcastHeight = 100 + + // timeout is a timeout value to use for tests which need to wait for + // a return value on a channel. + timeout = time.Second * 5 ) var ( @@ -443,3 +449,56 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, return alicePeer, channelBob, cleanUpFunc, nil } + +type mockMessageConn struct { + t *testing.T + + // MessageConn embeds our interface so that the mock does not need to + // implement every function. The mock will panic if an unspecified function + // is called. + MessageConn + + // writtenMessages is a channel that our mock pushes written messages into. + writtenMessages chan []byte +} + +func newMockConn(t *testing.T, expectedMessages int) *mockMessageConn { + return &mockMessageConn{ + t: t, + writtenMessages: make(chan []byte, expectedMessages), + } +} + +// SetWriteDeadline mocks setting write deadline for our conn. +func (m *mockMessageConn) SetWriteDeadline(time.Time) error { + return nil +} + +// Flush mocks a message conn flush. +func (m *mockMessageConn) Flush() (int, error) { + return 0, nil +} + +// WriteMessage mocks sending of a message on our connection. It will push +// the bytes sent into the mock's writtenMessages channel. +func (m *mockMessageConn) WriteMessage(msg []byte) error { + select { + case m.writtenMessages <- msg: + case <-time.After(timeout): + m.t.Fatalf("timeout sending message: %v", msg) + } + + return nil +} + +// assertWrite asserts that our mock as had WriteMessage called with the byte +// slice we expect. +func (m *mockMessageConn) assertWrite(expected []byte) { + select { + case actual := <-m.writtenMessages: + require.Equal(m.t, expected, actual) + + case <-time.After(timeout): + m.t.Fatalf("timeout waiting for write: %v", expected) + } +}