Merge pull request #4847 from carlaKC/4800-peerfeaturedowngrade
peer: do not require static remote for peers with legacy channels
This commit is contained in:
commit
de66d35a5b
@ -446,9 +446,36 @@ func (p *Brontide) Start() error {
|
|||||||
|
|
||||||
peerLog.Tracef("Peer %v starting", p)
|
peerLog.Tracef("Peer %v starting", p)
|
||||||
|
|
||||||
|
// Fetch and then load all the active channels we have with this remote
|
||||||
|
// peer from the database.
|
||||||
|
activeChans, err := p.cfg.ChannelDB.FetchOpenChannels(
|
||||||
|
p.cfg.Addr.IdentityKey,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
peerLog.Errorf("Unable to fetch active chans "+
|
||||||
|
"for peer %v: %v", p, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(activeChans) == 0 {
|
||||||
|
p.cfg.PrunePersistentPeerConnection(p.cfg.PubKeyBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Quickly check if we have any existing legacy channels with this
|
||||||
|
// peer.
|
||||||
|
haveLegacyChan := false
|
||||||
|
for _, c := range activeChans {
|
||||||
|
if c.ChanType.IsTweakless() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
haveLegacyChan = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
// Exchange local and global features, the init message should be very
|
// Exchange local and global features, the init message should be very
|
||||||
// first between two nodes.
|
// first between two nodes.
|
||||||
if err := p.sendInitMsg(); err != nil {
|
if err := p.sendInitMsg(haveLegacyChan); err != nil {
|
||||||
return fmt.Errorf("unable to send init msg: %v", err)
|
return fmt.Errorf("unable to send init msg: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -496,19 +523,6 @@ func (p *Brontide) Start() error {
|
|||||||
"must be init message")
|
"must be init message")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch and then load all the active channels we have with this remote
|
|
||||||
// peer from the database.
|
|
||||||
activeChans, err := p.cfg.ChannelDB.FetchOpenChannels(p.cfg.Addr.IdentityKey)
|
|
||||||
if err != nil {
|
|
||||||
peerLog.Errorf("unable to fetch active chans "+
|
|
||||||
"for peer %v: %v", p, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(activeChans) == 0 {
|
|
||||||
p.cfg.PrunePersistentPeerConnection(p.cfg.PubKeyBytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, load all the active channels we have with this peer,
|
// Next, load all the active channels we have with this peer,
|
||||||
// registering them with the switch and launching the necessary
|
// registering them with the switch and launching the necessary
|
||||||
// goroutines required to operate them.
|
// goroutines required to operate them.
|
||||||
@ -2752,12 +2766,28 @@ func (p *Brontide) RemoteFeatures() *lnwire.FeatureVector {
|
|||||||
return p.remoteFeatures
|
return p.remoteFeatures
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendInitMsg sends the Init message to the remote peer. This message contains our
|
// sendInitMsg sends the Init message to the remote peer. This message contains
|
||||||
// currently supported local and global features.
|
// our currently supported local and global features.
|
||||||
func (p *Brontide) sendInitMsg() error {
|
func (p *Brontide) sendInitMsg(legacyChan bool) error {
|
||||||
|
features := p.cfg.Features.Clone()
|
||||||
|
|
||||||
|
// If we have a legacy channel open with a peer, we downgrade static
|
||||||
|
// remote required to optional in case the peer does not understand the
|
||||||
|
// required feature bit. If we do not do this, the peer will reject our
|
||||||
|
// connection because it does not understand a required feature bit, and
|
||||||
|
// our channel will be unusable.
|
||||||
|
if legacyChan && features.RequiresFeature(lnwire.StaticRemoteKeyRequired) {
|
||||||
|
peerLog.Infof("Legacy channel open with peer: %x, "+
|
||||||
|
"downgrading static remote required feature bit to "+
|
||||||
|
"optional", p.PubKey())
|
||||||
|
|
||||||
|
features.Unset(lnwire.StaticRemoteKeyRequired)
|
||||||
|
features.Set(lnwire.StaticRemoteKeyOptional)
|
||||||
|
}
|
||||||
|
|
||||||
msg := lnwire.NewInitMessage(
|
msg := lnwire.NewInitMessage(
|
||||||
p.cfg.LegacyFeatures.RawFeatureVector,
|
p.cfg.LegacyFeatures.RawFeatureVector,
|
||||||
p.cfg.Features.RawFeatureVector,
|
features.RawFeatureVector,
|
||||||
)
|
)
|
||||||
|
|
||||||
return p.writeMessage(msg)
|
return p.writeMessage(msg)
|
||||||
|
@ -15,6 +15,8 @@ import (
|
|||||||
"github.com/lightningnetwork/lnd/lntest/mock"
|
"github.com/lightningnetwork/lnd/lntest/mock"
|
||||||
"github.com/lightningnetwork/lnd/lnwallet/chancloser"
|
"github.com/lightningnetwork/lnd/lnwallet/chancloser"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
|
"github.com/lightningnetwork/lnd/pool"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -23,10 +25,6 @@ var (
|
|||||||
|
|
||||||
// p2wshAddress is a valid pay to witness script hash address.
|
// p2wshAddress is a valid pay to witness script hash address.
|
||||||
p2wshAddress = "bc1qrp33g0q5c5txsp9arysrx4k6zdkfs4nce4xj0gdcccefvpysxf3qccfmv3"
|
p2wshAddress = "bc1qrp33g0q5c5txsp9arysrx4k6zdkfs4nce4xj0gdcccefvpysxf3qccfmv3"
|
||||||
|
|
||||||
// timeout is a timeout value to use for tests which need ot wait for
|
|
||||||
// a return value on a channel.
|
|
||||||
timeout = time.Second * 5
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestPeerChannelClosureAcceptFeeResponder tests the shutdown responder's
|
// TestPeerChannelClosureAcceptFeeResponder tests the shutdown responder's
|
||||||
@ -862,6 +860,120 @@ func TestCustomShutdownScript(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestStaticRemoteDowngrade tests that we downgrade our static remote feature
|
||||||
|
// bit to optional if we have legacy channels with a peer. This ensures that
|
||||||
|
// we can stay connected to peers that don't support the feature bit that we
|
||||||
|
// have channels with.
|
||||||
|
func TestStaticRemoteDowngrade(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var (
|
||||||
|
// We set the same legacy feature bits for all tests, since
|
||||||
|
// these are not relevant to our test scenario
|
||||||
|
rawLegacy = lnwire.NewRawFeatureVector(
|
||||||
|
lnwire.UpfrontShutdownScriptOptional,
|
||||||
|
)
|
||||||
|
legacy = lnwire.NewFeatureVector(rawLegacy, nil)
|
||||||
|
|
||||||
|
rawFeatureOptional = lnwire.NewRawFeatureVector(
|
||||||
|
lnwire.StaticRemoteKeyOptional,
|
||||||
|
)
|
||||||
|
|
||||||
|
featureOptional = lnwire.NewFeatureVector(
|
||||||
|
rawFeatureOptional, nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
rawFeatureRequired = lnwire.NewRawFeatureVector(
|
||||||
|
lnwire.StaticRemoteKeyRequired,
|
||||||
|
)
|
||||||
|
|
||||||
|
featureRequired = lnwire.NewFeatureVector(
|
||||||
|
rawFeatureRequired, nil,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
legacy bool
|
||||||
|
features *lnwire.FeatureVector
|
||||||
|
expectedInit *lnwire.Init
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no legacy channel, static optional",
|
||||||
|
legacy: false,
|
||||||
|
features: featureOptional,
|
||||||
|
expectedInit: &lnwire.Init{
|
||||||
|
GlobalFeatures: rawLegacy,
|
||||||
|
Features: rawFeatureOptional,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "legacy channel, static optional",
|
||||||
|
legacy: true,
|
||||||
|
features: featureOptional,
|
||||||
|
expectedInit: &lnwire.Init{
|
||||||
|
GlobalFeatures: rawLegacy,
|
||||||
|
Features: rawFeatureOptional,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no legacy channel, static required",
|
||||||
|
legacy: false,
|
||||||
|
features: featureRequired,
|
||||||
|
expectedInit: &lnwire.Init{
|
||||||
|
GlobalFeatures: rawLegacy,
|
||||||
|
Features: rawFeatureRequired,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "legacy channel, static required",
|
||||||
|
legacy: true,
|
||||||
|
features: featureRequired,
|
||||||
|
expectedInit: &lnwire.Init{
|
||||||
|
GlobalFeatures: rawLegacy,
|
||||||
|
Features: rawFeatureOptional,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
test := test
|
||||||
|
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
writeBufferPool := pool.NewWriteBuffer(
|
||||||
|
pool.DefaultWriteBufferGCInterval,
|
||||||
|
pool.DefaultWriteBufferExpiryInterval,
|
||||||
|
)
|
||||||
|
|
||||||
|
writePool := pool.NewWrite(
|
||||||
|
writeBufferPool, 1, timeout,
|
||||||
|
)
|
||||||
|
require.NoError(t, writePool.Start())
|
||||||
|
|
||||||
|
mockConn := newMockConn(t, 1)
|
||||||
|
|
||||||
|
p := Brontide{
|
||||||
|
cfg: Config{
|
||||||
|
LegacyFeatures: legacy,
|
||||||
|
Features: test.features,
|
||||||
|
Conn: mockConn,
|
||||||
|
WritePool: writePool,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
_, err := lnwire.WriteMessage(&b, test.expectedInit, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Send our init message, assert that we write our expected message
|
||||||
|
// and shutdown our write pool.
|
||||||
|
require.NoError(t, p.sendInitMsg(test.legacy))
|
||||||
|
mockConn.assertWrite(b.Bytes())
|
||||||
|
require.NoError(t, writePool.Stop())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// genScript creates a script paying out to the address provided, which must
|
// genScript creates a script paying out to the address provided, which must
|
||||||
// be a valid address.
|
// be a valid address.
|
||||||
func genScript(t *testing.T, address string) lnwire.DeliveryAddress {
|
func genScript(t *testing.T, address string) lnwire.DeliveryAddress {
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/btcsuite/btcd/btcec"
|
"github.com/btcsuite/btcd/btcec"
|
||||||
@ -28,10 +29,15 @@ import (
|
|||||||
"github.com/lightningnetwork/lnd/queue"
|
"github.com/lightningnetwork/lnd/queue"
|
||||||
"github.com/lightningnetwork/lnd/shachain"
|
"github.com/lightningnetwork/lnd/shachain"
|
||||||
"github.com/lightningnetwork/lnd/ticker"
|
"github.com/lightningnetwork/lnd/ticker"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
broadcastHeight = 100
|
broadcastHeight = 100
|
||||||
|
|
||||||
|
// timeout is a timeout value to use for tests which need to wait for
|
||||||
|
// a return value on a channel.
|
||||||
|
timeout = time.Second * 5
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -443,3 +449,56 @@ func createTestPeer(notifier chainntnfs.ChainNotifier,
|
|||||||
|
|
||||||
return alicePeer, channelBob, cleanUpFunc, nil
|
return alicePeer, channelBob, cleanUpFunc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockMessageConn struct {
|
||||||
|
t *testing.T
|
||||||
|
|
||||||
|
// MessageConn embeds our interface so that the mock does not need to
|
||||||
|
// implement every function. The mock will panic if an unspecified function
|
||||||
|
// is called.
|
||||||
|
MessageConn
|
||||||
|
|
||||||
|
// writtenMessages is a channel that our mock pushes written messages into.
|
||||||
|
writtenMessages chan []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockConn(t *testing.T, expectedMessages int) *mockMessageConn {
|
||||||
|
return &mockMessageConn{
|
||||||
|
t: t,
|
||||||
|
writtenMessages: make(chan []byte, expectedMessages),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetWriteDeadline mocks setting write deadline for our conn.
|
||||||
|
func (m *mockMessageConn) SetWriteDeadline(time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush mocks a message conn flush.
|
||||||
|
func (m *mockMessageConn) Flush() (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteMessage mocks sending of a message on our connection. It will push
|
||||||
|
// the bytes sent into the mock's writtenMessages channel.
|
||||||
|
func (m *mockMessageConn) WriteMessage(msg []byte) error {
|
||||||
|
select {
|
||||||
|
case m.writtenMessages <- msg:
|
||||||
|
case <-time.After(timeout):
|
||||||
|
m.t.Fatalf("timeout sending message: %v", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// assertWrite asserts that our mock as had WriteMessage called with the byte
|
||||||
|
// slice we expect.
|
||||||
|
func (m *mockMessageConn) assertWrite(expected []byte) {
|
||||||
|
select {
|
||||||
|
case actual := <-m.writtenMessages:
|
||||||
|
require.Equal(m.t, expected, actual)
|
||||||
|
|
||||||
|
case <-time.After(timeout):
|
||||||
|
m.t.Fatalf("timeout waiting for write: %v", expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user