Merge pull request #4847 from carlaKC/4800-peerfeaturedowngrade

peer: do not require static remote for peers with legacy channels
This commit is contained in:
Johan T. Halseth 2020-12-10 12:14:07 +01:00 committed by GitHub
commit de66d35a5b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 223 additions and 22 deletions

@ -446,9 +446,36 @@ func (p *Brontide) Start() error {
peerLog.Tracef("Peer %v starting", p) peerLog.Tracef("Peer %v starting", p)
// Fetch and then load all the active channels we have with this remote
// peer from the database.
activeChans, err := p.cfg.ChannelDB.FetchOpenChannels(
p.cfg.Addr.IdentityKey,
)
if err != nil {
peerLog.Errorf("Unable to fetch active chans "+
"for peer %v: %v", p, err)
return err
}
if len(activeChans) == 0 {
p.cfg.PrunePersistentPeerConnection(p.cfg.PubKeyBytes)
}
// Quickly check if we have any existing legacy channels with this
// peer.
haveLegacyChan := false
for _, c := range activeChans {
if c.ChanType.IsTweakless() {
continue
}
haveLegacyChan = true
break
}
// Exchange local and global features, the init message should be very // Exchange local and global features, the init message should be very
// first between two nodes. // first between two nodes.
if err := p.sendInitMsg(); err != nil { if err := p.sendInitMsg(haveLegacyChan); err != nil {
return fmt.Errorf("unable to send init msg: %v", err) return fmt.Errorf("unable to send init msg: %v", err)
} }
@ -496,19 +523,6 @@ func (p *Brontide) Start() error {
"must be init message") "must be init message")
} }
// Fetch and then load all the active channels we have with this remote
// peer from the database.
activeChans, err := p.cfg.ChannelDB.FetchOpenChannels(p.cfg.Addr.IdentityKey)
if err != nil {
peerLog.Errorf("unable to fetch active chans "+
"for peer %v: %v", p, err)
return err
}
if len(activeChans) == 0 {
p.cfg.PrunePersistentPeerConnection(p.cfg.PubKeyBytes)
}
// Next, load all the active channels we have with this peer, // Next, load all the active channels we have with this peer,
// registering them with the switch and launching the necessary // registering them with the switch and launching the necessary
// goroutines required to operate them. // goroutines required to operate them.
@ -2752,12 +2766,28 @@ func (p *Brontide) RemoteFeatures() *lnwire.FeatureVector {
return p.remoteFeatures return p.remoteFeatures
} }
// sendInitMsg sends the Init message to the remote peer. This message contains our // sendInitMsg sends the Init message to the remote peer. This message contains
// currently supported local and global features. // our currently supported local and global features.
func (p *Brontide) sendInitMsg() error { func (p *Brontide) sendInitMsg(legacyChan bool) error {
features := p.cfg.Features.Clone()
// If we have a legacy channel open with a peer, we downgrade static
// remote required to optional in case the peer does not understand the
// required feature bit. If we do not do this, the peer will reject our
// connection because it does not understand a required feature bit, and
// our channel will be unusable.
if legacyChan && features.RequiresFeature(lnwire.StaticRemoteKeyRequired) {
peerLog.Infof("Legacy channel open with peer: %x, "+
"downgrading static remote required feature bit to "+
"optional", p.PubKey())
features.Unset(lnwire.StaticRemoteKeyRequired)
features.Set(lnwire.StaticRemoteKeyOptional)
}
msg := lnwire.NewInitMessage( msg := lnwire.NewInitMessage(
p.cfg.LegacyFeatures.RawFeatureVector, p.cfg.LegacyFeatures.RawFeatureVector,
p.cfg.Features.RawFeatureVector, features.RawFeatureVector,
) )
return p.writeMessage(msg) return p.writeMessage(msg)

@ -15,6 +15,8 @@ import (
"github.com/lightningnetwork/lnd/lntest/mock" "github.com/lightningnetwork/lnd/lntest/mock"
"github.com/lightningnetwork/lnd/lnwallet/chancloser" "github.com/lightningnetwork/lnd/lnwallet/chancloser"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/pool"
"github.com/stretchr/testify/require"
) )
var ( var (
@ -23,10 +25,6 @@ var (
// p2wshAddress is a valid pay to witness script hash address. // p2wshAddress is a valid pay to witness script hash address.
p2wshAddress = "bc1qrp33g0q5c5txsp9arysrx4k6zdkfs4nce4xj0gdcccefvpysxf3qccfmv3" p2wshAddress = "bc1qrp33g0q5c5txsp9arysrx4k6zdkfs4nce4xj0gdcccefvpysxf3qccfmv3"
// timeout is a timeout value to use for tests which need ot wait for
// a return value on a channel.
timeout = time.Second * 5
) )
// TestPeerChannelClosureAcceptFeeResponder tests the shutdown responder's // TestPeerChannelClosureAcceptFeeResponder tests the shutdown responder's
@ -862,6 +860,120 @@ func TestCustomShutdownScript(t *testing.T) {
} }
} }
// TestStaticRemoteDowngrade tests that we downgrade our static remote feature
// bit to optional if we have legacy channels with a peer. This ensures that
// we can stay connected to peers that don't support the feature bit that we
// have channels with.
func TestStaticRemoteDowngrade(t *testing.T) {
t.Parallel()
var (
// We set the same legacy feature bits for all tests, since
// these are not relevant to our test scenario
rawLegacy = lnwire.NewRawFeatureVector(
lnwire.UpfrontShutdownScriptOptional,
)
legacy = lnwire.NewFeatureVector(rawLegacy, nil)
rawFeatureOptional = lnwire.NewRawFeatureVector(
lnwire.StaticRemoteKeyOptional,
)
featureOptional = lnwire.NewFeatureVector(
rawFeatureOptional, nil,
)
rawFeatureRequired = lnwire.NewRawFeatureVector(
lnwire.StaticRemoteKeyRequired,
)
featureRequired = lnwire.NewFeatureVector(
rawFeatureRequired, nil,
)
)
tests := []struct {
name string
legacy bool
features *lnwire.FeatureVector
expectedInit *lnwire.Init
}{
{
name: "no legacy channel, static optional",
legacy: false,
features: featureOptional,
expectedInit: &lnwire.Init{
GlobalFeatures: rawLegacy,
Features: rawFeatureOptional,
},
},
{
name: "legacy channel, static optional",
legacy: true,
features: featureOptional,
expectedInit: &lnwire.Init{
GlobalFeatures: rawLegacy,
Features: rawFeatureOptional,
},
},
{
name: "no legacy channel, static required",
legacy: false,
features: featureRequired,
expectedInit: &lnwire.Init{
GlobalFeatures: rawLegacy,
Features: rawFeatureRequired,
},
},
{
name: "legacy channel, static required",
legacy: true,
features: featureRequired,
expectedInit: &lnwire.Init{
GlobalFeatures: rawLegacy,
Features: rawFeatureOptional,
},
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
writeBufferPool := pool.NewWriteBuffer(
pool.DefaultWriteBufferGCInterval,
pool.DefaultWriteBufferExpiryInterval,
)
writePool := pool.NewWrite(
writeBufferPool, 1, timeout,
)
require.NoError(t, writePool.Start())
mockConn := newMockConn(t, 1)
p := Brontide{
cfg: Config{
LegacyFeatures: legacy,
Features: test.features,
Conn: mockConn,
WritePool: writePool,
},
}
var b bytes.Buffer
_, err := lnwire.WriteMessage(&b, test.expectedInit, 0)
require.NoError(t, err)
// Send our init message, assert that we write our expected message
// and shutdown our write pool.
require.NoError(t, p.sendInitMsg(test.legacy))
mockConn.assertWrite(b.Bytes())
require.NoError(t, writePool.Stop())
})
}
}
// genScript creates a script paying out to the address provided, which must // genScript creates a script paying out to the address provided, which must
// be a valid address. // be a valid address.
func genScript(t *testing.T, address string) lnwire.DeliveryAddress { func genScript(t *testing.T, address string) lnwire.DeliveryAddress {

@ -9,6 +9,7 @@ import (
"math/rand" "math/rand"
"net" "net"
"os" "os"
"testing"
"time" "time"
"github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/btcec"
@ -28,10 +29,15 @@ import (
"github.com/lightningnetwork/lnd/queue" "github.com/lightningnetwork/lnd/queue"
"github.com/lightningnetwork/lnd/shachain" "github.com/lightningnetwork/lnd/shachain"
"github.com/lightningnetwork/lnd/ticker" "github.com/lightningnetwork/lnd/ticker"
"github.com/stretchr/testify/require"
) )
const ( const (
broadcastHeight = 100 broadcastHeight = 100
// timeout is a timeout value to use for tests which need to wait for
// a return value on a channel.
timeout = time.Second * 5
) )
var ( var (
@ -443,3 +449,56 @@ func createTestPeer(notifier chainntnfs.ChainNotifier,
return alicePeer, channelBob, cleanUpFunc, nil return alicePeer, channelBob, cleanUpFunc, nil
} }
type mockMessageConn struct {
t *testing.T
// MessageConn embeds our interface so that the mock does not need to
// implement every function. The mock will panic if an unspecified function
// is called.
MessageConn
// writtenMessages is a channel that our mock pushes written messages into.
writtenMessages chan []byte
}
func newMockConn(t *testing.T, expectedMessages int) *mockMessageConn {
return &mockMessageConn{
t: t,
writtenMessages: make(chan []byte, expectedMessages),
}
}
// SetWriteDeadline mocks setting write deadline for our conn.
func (m *mockMessageConn) SetWriteDeadline(time.Time) error {
return nil
}
// Flush mocks a message conn flush.
func (m *mockMessageConn) Flush() (int, error) {
return 0, nil
}
// WriteMessage mocks sending of a message on our connection. It will push
// the bytes sent into the mock's writtenMessages channel.
func (m *mockMessageConn) WriteMessage(msg []byte) error {
select {
case m.writtenMessages <- msg:
case <-time.After(timeout):
m.t.Fatalf("timeout sending message: %v", msg)
}
return nil
}
// assertWrite asserts that our mock as had WriteMessage called with the byte
// slice we expect.
func (m *mockMessageConn) assertWrite(expected []byte) {
select {
case actual := <-m.writtenMessages:
require.Equal(m.t, expected, actual)
case <-time.After(timeout):
m.t.Fatalf("timeout waiting for write: %v", expected)
}
}