diff --git a/peer/brontide_test.go b/peer/brontide_test.go index 6150fe27..77734d66 100644 --- a/peer/brontide_test.go +++ b/peer/brontide_test.go @@ -15,6 +15,8 @@ import ( "github.com/lightningnetwork/lnd/lntest/mock" "github.com/lightningnetwork/lnd/lnwallet/chancloser" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/pool" + "github.com/stretchr/testify/require" ) var ( @@ -23,10 +25,6 @@ var ( // p2wshAddress is a valid pay to witness script hash address. p2wshAddress = "bc1qrp33g0q5c5txsp9arysrx4k6zdkfs4nce4xj0gdcccefvpysxf3qccfmv3" - - // timeout is a timeout value to use for tests which need ot wait for - // a return value on a channel. - timeout = time.Second * 5 ) // TestPeerChannelClosureAcceptFeeResponder tests the shutdown responder's @@ -862,6 +860,120 @@ func TestCustomShutdownScript(t *testing.T) { } } +// TestStaticRemoteDowngrade tests that we downgrade our static remote feature +// bit to optional if we have legacy channels with a peer. This ensures that +// we can stay connected to peers that don't support the feature bit that we +// have channels with. +func TestStaticRemoteDowngrade(t *testing.T) { + t.Parallel() + + var ( + // We set the same legacy feature bits for all tests, since + // these are not relevant to our test scenario + rawLegacy = lnwire.NewRawFeatureVector( + lnwire.UpfrontShutdownScriptOptional, + ) + legacy = lnwire.NewFeatureVector(rawLegacy, nil) + + rawFeatureOptional = lnwire.NewRawFeatureVector( + lnwire.StaticRemoteKeyOptional, + ) + + featureOptional = lnwire.NewFeatureVector( + rawFeatureOptional, nil, + ) + + rawFeatureRequired = lnwire.NewRawFeatureVector( + lnwire.StaticRemoteKeyRequired, + ) + + featureRequired = lnwire.NewFeatureVector( + rawFeatureRequired, nil, + ) + ) + + tests := []struct { + name string + legacy bool + features *lnwire.FeatureVector + expectedInit *lnwire.Init + }{ + { + name: "no legacy channel, static optional", + legacy: false, + features: featureOptional, + expectedInit: &lnwire.Init{ + GlobalFeatures: rawLegacy, + Features: rawFeatureOptional, + }, + }, + { + name: "legacy channel, static optional", + legacy: true, + features: featureOptional, + expectedInit: &lnwire.Init{ + GlobalFeatures: rawLegacy, + Features: rawFeatureOptional, + }, + }, + { + name: "no legacy channel, static required", + legacy: false, + features: featureRequired, + expectedInit: &lnwire.Init{ + GlobalFeatures: rawLegacy, + Features: rawFeatureRequired, + }, + }, + { + name: "legacy channel, static required", + legacy: true, + features: featureRequired, + expectedInit: &lnwire.Init{ + GlobalFeatures: rawLegacy, + Features: rawFeatureOptional, + }, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + writeBufferPool := pool.NewWriteBuffer( + pool.DefaultWriteBufferGCInterval, + pool.DefaultWriteBufferExpiryInterval, + ) + + writePool := pool.NewWrite( + writeBufferPool, 1, timeout, + ) + require.NoError(t, writePool.Start()) + + mockConn := newMockConn(t, 1) + + p := Brontide{ + cfg: Config{ + LegacyFeatures: legacy, + Features: test.features, + Conn: mockConn, + WritePool: writePool, + }, + } + + var b bytes.Buffer + _, err := lnwire.WriteMessage(&b, test.expectedInit, 0) + require.NoError(t, err) + + // Send our init message, assert that we write our expected message + // and shutdown our write pool. + require.NoError(t, p.sendInitMsg(test.legacy)) + mockConn.assertWrite(b.Bytes()) + require.NoError(t, writePool.Stop()) + }) + } +} + // genScript creates a script paying out to the address provided, which must // be a valid address. func genScript(t *testing.T, address string) lnwire.DeliveryAddress { diff --git a/peer/test_utils.go b/peer/test_utils.go index 08c4557d..4039ae8a 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -9,6 +9,7 @@ import ( "math/rand" "net" "os" + "testing" "time" "github.com/btcsuite/btcd/btcec" @@ -28,10 +29,15 @@ import ( "github.com/lightningnetwork/lnd/queue" "github.com/lightningnetwork/lnd/shachain" "github.com/lightningnetwork/lnd/ticker" + "github.com/stretchr/testify/require" ) const ( broadcastHeight = 100 + + // timeout is a timeout value to use for tests which need to wait for + // a return value on a channel. + timeout = time.Second * 5 ) var ( @@ -443,3 +449,56 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, return alicePeer, channelBob, cleanUpFunc, nil } + +type mockMessageConn struct { + t *testing.T + + // MessageConn embeds our interface so that the mock does not need to + // implement every function. The mock will panic if an unspecified function + // is called. + MessageConn + + // writtenMessages is a channel that our mock pushes written messages into. + writtenMessages chan []byte +} + +func newMockConn(t *testing.T, expectedMessages int) *mockMessageConn { + return &mockMessageConn{ + t: t, + writtenMessages: make(chan []byte, expectedMessages), + } +} + +// SetWriteDeadline mocks setting write deadline for our conn. +func (m *mockMessageConn) SetWriteDeadline(time.Time) error { + return nil +} + +// Flush mocks a message conn flush. +func (m *mockMessageConn) Flush() (int, error) { + return 0, nil +} + +// WriteMessage mocks sending of a message on our connection. It will push +// the bytes sent into the mock's writtenMessages channel. +func (m *mockMessageConn) WriteMessage(msg []byte) error { + select { + case m.writtenMessages <- msg: + case <-time.After(timeout): + m.t.Fatalf("timeout sending message: %v", msg) + } + + return nil +} + +// assertWrite asserts that our mock as had WriteMessage called with the byte +// slice we expect. +func (m *mockMessageConn) assertWrite(expected []byte) { + select { + case actual := <-m.writtenMessages: + require.Equal(m.t, expected, actual) + + case <-time.After(timeout): + m.t.Fatalf("timeout waiting for write: %v", expected) + } +}