Merge pull request #2606 from cfromknecht/wtwire-init-connection-features
watchtower/multi: send connection features + chain hash in Init
This commit is contained in:
commit
f4dfcc35aa
@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/btcsuite/btcd/connmgr"
|
||||
"github.com/btcsuite/btcd/txscript"
|
||||
"github.com/btcsuite/btcutil"
|
||||
@ -51,6 +52,9 @@ type Config struct {
|
||||
// NewAddress is used to generate reward addresses, where a cut of
|
||||
// successfully sent funds can be received.
|
||||
NewAddress func() (btcutil.Address, error)
|
||||
|
||||
// ChainHash identifies the network that the server is watching.
|
||||
ChainHash chainhash.Hash
|
||||
}
|
||||
|
||||
// Server houses the state required to handle watchtower peers. It's primary job
|
||||
@ -68,7 +72,7 @@ type Server struct {
|
||||
clients map[wtdb.SessionID]Peer
|
||||
|
||||
globalFeatures *lnwire.RawFeatureVector
|
||||
localFeatures *lnwire.RawFeatureVector
|
||||
connFeatures *lnwire.RawFeatureVector
|
||||
|
||||
wg sync.WaitGroup
|
||||
quit chan struct{}
|
||||
@ -78,7 +82,7 @@ type Server struct {
|
||||
// clients connecting to the listener addresses, and allows them to open
|
||||
// sessions and send state updates.
|
||||
func New(cfg *Config) (*Server, error) {
|
||||
localFeatures := lnwire.NewRawFeatureVector(
|
||||
connFeatures := lnwire.NewRawFeatureVector(
|
||||
wtwire.WtSessionsOptional,
|
||||
)
|
||||
|
||||
@ -86,7 +90,7 @@ func New(cfg *Config) (*Server, error) {
|
||||
cfg: cfg,
|
||||
clients: make(map[wtdb.SessionID]Peer),
|
||||
globalFeatures: lnwire.NewRawFeatureVector(),
|
||||
localFeatures: localFeatures,
|
||||
connFeatures: connFeatures,
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
|
||||
@ -206,7 +210,7 @@ func (s *Server) handleClient(peer Peer) {
|
||||
}
|
||||
|
||||
localInit := wtwire.NewInitMessage(
|
||||
s.localFeatures, s.globalFeatures,
|
||||
s.connFeatures, s.cfg.ChainHash,
|
||||
)
|
||||
|
||||
err = s.sendMessage(peer, localInit)
|
||||
@ -296,25 +300,19 @@ func (s *Server) handleClient(peer Peer) {
|
||||
// handleInit accepts the local and remote Init messages, and verifies that the
|
||||
// client is not requesting any required features that are unknown to the tower.
|
||||
func (s *Server) handleInit(localInit, remoteInit *wtwire.Init) error {
|
||||
remoteLocalFeatures := lnwire.NewFeatureVector(
|
||||
remoteInit.LocalFeatures, wtwire.LocalFeatures,
|
||||
)
|
||||
remoteGlobalFeatures := lnwire.NewFeatureVector(
|
||||
remoteInit.GlobalFeatures, wtwire.GlobalFeatures,
|
||||
)
|
||||
|
||||
unknownLocalFeatures := remoteLocalFeatures.UnknownRequiredFeatures()
|
||||
if len(unknownLocalFeatures) > 0 {
|
||||
err := fmt.Errorf("Peer set unknown local feature bits: %v",
|
||||
unknownLocalFeatures)
|
||||
return err
|
||||
if localInit.ChainHash != remoteInit.ChainHash {
|
||||
return fmt.Errorf("Peer chain hash unknown: %x",
|
||||
remoteInit.ChainHash)
|
||||
}
|
||||
|
||||
unknownGlobalFeatures := remoteGlobalFeatures.UnknownRequiredFeatures()
|
||||
if len(unknownGlobalFeatures) > 0 {
|
||||
err := fmt.Errorf("Peer set unknown global feature bits: %v",
|
||||
unknownGlobalFeatures)
|
||||
return err
|
||||
remoteConnFeatures := lnwire.NewFeatureVector(
|
||||
remoteInit.ConnFeatures, wtwire.LocalFeatures,
|
||||
)
|
||||
|
||||
unknownLocalFeatures := remoteConnFeatures.UnknownRequiredFeatures()
|
||||
if len(unknownLocalFeatures) > 0 {
|
||||
return fmt.Errorf("Peer set unknown local feature bits: %v",
|
||||
unknownLocalFeatures)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -27,6 +27,8 @@ var (
|
||||
)
|
||||
|
||||
addrScript, _ = txscript.PayToAddrScript(addr)
|
||||
|
||||
testnetChainHash = *chaincfg.TestNet3Params.GenesisHash
|
||||
)
|
||||
|
||||
// randPubKey generates a new secp keypair, and returns the public key.
|
||||
@ -59,6 +61,7 @@ func initServer(t *testing.T, db wtserver.DB,
|
||||
NewAddress: func() (btcutil.Address, error) {
|
||||
return addr, nil
|
||||
},
|
||||
ChainHash: testnetChainHash,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create server: %v", err)
|
||||
@ -91,7 +94,7 @@ func TestServerOnlyAcceptOnePeer(t *testing.T) {
|
||||
|
||||
// Serialize a Init message to be sent by both peers.
|
||||
init := wtwire.NewInitMessage(
|
||||
lnwire.NewRawFeatureVector(), lnwire.NewRawFeatureVector(),
|
||||
lnwire.NewRawFeatureVector(), testnetChainHash,
|
||||
)
|
||||
|
||||
var b bytes.Buffer
|
||||
@ -159,7 +162,7 @@ var createSessionTests = []createSessionTestCase{
|
||||
name: "reject duplicate session create",
|
||||
initMsg: wtwire.NewInitMessage(
|
||||
lnwire.NewRawFeatureVector(),
|
||||
lnwire.NewRawFeatureVector(),
|
||||
testnetChainHash,
|
||||
),
|
||||
createMsg: &wtwire.CreateSession{
|
||||
BlobType: blob.TypeDefault,
|
||||
@ -181,7 +184,7 @@ var createSessionTests = []createSessionTestCase{
|
||||
name: "reject unsupported blob type",
|
||||
initMsg: wtwire.NewInitMessage(
|
||||
lnwire.NewRawFeatureVector(),
|
||||
lnwire.NewRawFeatureVector(),
|
||||
testnetChainHash,
|
||||
),
|
||||
createMsg: &wtwire.CreateSession{
|
||||
BlobType: 0,
|
||||
@ -279,10 +282,10 @@ var stateUpdateTests = []stateUpdateTestCase{
|
||||
// Valid update sequence, send seqnum == lastapplied as last update.
|
||||
{
|
||||
name: "perm fail after sending seqnum equal lastapplied",
|
||||
initMsg: &wtwire.Init{&lnwire.Init{
|
||||
LocalFeatures: lnwire.NewRawFeatureVector(),
|
||||
GlobalFeatures: lnwire.NewRawFeatureVector(),
|
||||
}},
|
||||
initMsg: wtwire.NewInitMessage(
|
||||
lnwire.NewRawFeatureVector(),
|
||||
testnetChainHash,
|
||||
),
|
||||
createMsg: &wtwire.CreateSession{
|
||||
BlobType: blob.TypeDefault,
|
||||
MaxUpdates: 3,
|
||||
@ -309,10 +312,10 @@ var stateUpdateTests = []stateUpdateTestCase{
|
||||
// Send update that skips next expected sequence number.
|
||||
{
|
||||
name: "skip sequence number",
|
||||
initMsg: &wtwire.Init{&lnwire.Init{
|
||||
LocalFeatures: lnwire.NewRawFeatureVector(),
|
||||
GlobalFeatures: lnwire.NewRawFeatureVector(),
|
||||
}},
|
||||
initMsg: wtwire.NewInitMessage(
|
||||
lnwire.NewRawFeatureVector(),
|
||||
testnetChainHash,
|
||||
),
|
||||
createMsg: &wtwire.CreateSession{
|
||||
BlobType: blob.TypeDefault,
|
||||
MaxUpdates: 4,
|
||||
@ -333,10 +336,10 @@ var stateUpdateTests = []stateUpdateTestCase{
|
||||
// Send update that reverts to older sequence number.
|
||||
{
|
||||
name: "revert to older seqnum",
|
||||
initMsg: &wtwire.Init{&lnwire.Init{
|
||||
LocalFeatures: lnwire.NewRawFeatureVector(),
|
||||
GlobalFeatures: lnwire.NewRawFeatureVector(),
|
||||
}},
|
||||
initMsg: wtwire.NewInitMessage(
|
||||
lnwire.NewRawFeatureVector(),
|
||||
testnetChainHash,
|
||||
),
|
||||
createMsg: &wtwire.CreateSession{
|
||||
BlobType: blob.TypeDefault,
|
||||
MaxUpdates: 4,
|
||||
@ -361,10 +364,10 @@ var stateUpdateTests = []stateUpdateTestCase{
|
||||
// Send update echoing a last applied that is lower than previous value.
|
||||
{
|
||||
name: "revert to older lastapplied",
|
||||
initMsg: &wtwire.Init{&lnwire.Init{
|
||||
LocalFeatures: lnwire.NewRawFeatureVector(),
|
||||
GlobalFeatures: lnwire.NewRawFeatureVector(),
|
||||
}},
|
||||
initMsg: wtwire.NewInitMessage(
|
||||
lnwire.NewRawFeatureVector(),
|
||||
testnetChainHash,
|
||||
),
|
||||
createMsg: &wtwire.CreateSession{
|
||||
BlobType: blob.TypeDefault,
|
||||
MaxUpdates: 4,
|
||||
@ -389,10 +392,10 @@ var stateUpdateTests = []stateUpdateTestCase{
|
||||
// Client echos last applied as they are received.
|
||||
{
|
||||
name: "resume after disconnect",
|
||||
initMsg: &wtwire.Init{&lnwire.Init{
|
||||
LocalFeatures: lnwire.NewRawFeatureVector(),
|
||||
GlobalFeatures: lnwire.NewRawFeatureVector(),
|
||||
}},
|
||||
initMsg: wtwire.NewInitMessage(
|
||||
lnwire.NewRawFeatureVector(),
|
||||
testnetChainHash,
|
||||
),
|
||||
createMsg: &wtwire.CreateSession{
|
||||
BlobType: blob.TypeDefault,
|
||||
MaxUpdates: 4,
|
||||
@ -419,10 +422,10 @@ var stateUpdateTests = []stateUpdateTestCase{
|
||||
// Client doesn't echo last applied until last message.
|
||||
{
|
||||
name: "resume after disconnect lagging lastapplied",
|
||||
initMsg: &wtwire.Init{&lnwire.Init{
|
||||
LocalFeatures: lnwire.NewRawFeatureVector(),
|
||||
GlobalFeatures: lnwire.NewRawFeatureVector(),
|
||||
}},
|
||||
initMsg: wtwire.NewInitMessage(
|
||||
lnwire.NewRawFeatureVector(),
|
||||
testnetChainHash,
|
||||
),
|
||||
createMsg: &wtwire.CreateSession{
|
||||
BlobType: blob.TypeDefault,
|
||||
MaxUpdates: 4,
|
||||
@ -448,10 +451,10 @@ var stateUpdateTests = []stateUpdateTestCase{
|
||||
// Send update with sequence number that exceeds MaxUpdates.
|
||||
{
|
||||
name: "seqnum exceed maxupdates",
|
||||
initMsg: &wtwire.Init{&lnwire.Init{
|
||||
LocalFeatures: lnwire.NewRawFeatureVector(),
|
||||
GlobalFeatures: lnwire.NewRawFeatureVector(),
|
||||
}},
|
||||
initMsg: wtwire.NewInitMessage(
|
||||
lnwire.NewRawFeatureVector(),
|
||||
testnetChainHash,
|
||||
),
|
||||
createMsg: &wtwire.CreateSession{
|
||||
BlobType: blob.TypeDefault,
|
||||
MaxUpdates: 3,
|
||||
@ -478,10 +481,10 @@ var stateUpdateTests = []stateUpdateTestCase{
|
||||
// Ensure sequence number 0 causes permanent failure.
|
||||
{
|
||||
name: "perm fail after seqnum 0",
|
||||
initMsg: &wtwire.Init{&lnwire.Init{
|
||||
LocalFeatures: lnwire.NewRawFeatureVector(),
|
||||
GlobalFeatures: lnwire.NewRawFeatureVector(),
|
||||
}},
|
||||
initMsg: wtwire.NewInitMessage(
|
||||
lnwire.NewRawFeatureVector(),
|
||||
testnetChainHash,
|
||||
),
|
||||
createMsg: &wtwire.CreateSession{
|
||||
BlobType: blob.TypeDefault,
|
||||
MaxUpdates: 3,
|
||||
|
@ -1,23 +1,59 @@
|
||||
package wtwire
|
||||
|
||||
import "github.com/lightningnetwork/lnd/lnwire"
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
)
|
||||
|
||||
// Init is the first message sent over the watchtower wire protocol, and
|
||||
// specifies features and level of requiredness maintained by the sending node.
|
||||
// The watchtower Init message is identical to the LN Init message, except it
|
||||
// uses a different message type to ensure the two are not conflated.
|
||||
// specifies connection features bits and level of requiredness maintained by
|
||||
// the sending node. The Init message also sends the chain hash identifying the
|
||||
// network that the sender is on.
|
||||
type Init struct {
|
||||
*lnwire.Init
|
||||
// ConnFeatures are the feature bits being advertised for the duration
|
||||
// of a single connection with a peer.
|
||||
ConnFeatures *lnwire.RawFeatureVector
|
||||
|
||||
// ChainHash is the genesis hash of the chain that the advertiser claims
|
||||
// to be on.
|
||||
ChainHash chainhash.Hash
|
||||
}
|
||||
|
||||
// NewInitMessage generates a new Init message from raw global and local feature
|
||||
// vectors.
|
||||
func NewInitMessage(gf, lf *lnwire.RawFeatureVector) *Init {
|
||||
// NewInitMessage generates a new Init message from a raw connection feature
|
||||
// vector and chain hash.
|
||||
func NewInitMessage(connFeatures *lnwire.RawFeatureVector,
|
||||
chainHash chainhash.Hash) *Init {
|
||||
|
||||
return &Init{
|
||||
Init: lnwire.NewInitMessage(gf, lf),
|
||||
ConnFeatures: connFeatures,
|
||||
ChainHash: chainHash,
|
||||
}
|
||||
}
|
||||
|
||||
// Encode serializes the target Init into the passed io.Writer observing the
|
||||
// protocol version specified.
|
||||
//
|
||||
// This is part of the wtwire.Message interface.
|
||||
func (msg *Init) Encode(w io.Writer, pver uint32) error {
|
||||
return WriteElements(w,
|
||||
msg.ConnFeatures,
|
||||
msg.ChainHash,
|
||||
)
|
||||
}
|
||||
|
||||
// Decode deserializes a serialized Init message stored in the passed io.Reader
|
||||
// observing the specified protocol version.
|
||||
//
|
||||
// This is part of the wtwire.Message interface.
|
||||
func (msg *Init) Decode(r io.Reader, pver uint32) error {
|
||||
return ReadElements(r,
|
||||
&msg.ConnFeatures,
|
||||
&msg.ChainHash,
|
||||
)
|
||||
}
|
||||
|
||||
// MsgType returns the integer uniquely identifying this message type on the
|
||||
// wire.
|
||||
//
|
||||
@ -26,5 +62,13 @@ func (msg *Init) MsgType() MessageType {
|
||||
return MsgInit
|
||||
}
|
||||
|
||||
// MaxPayloadLength returns the maximum allowed payload size for an Init
|
||||
// complete message observing the specified protocol version.
|
||||
//
|
||||
// This is part of the wtwire.Message interface.
|
||||
func (msg *Init) MaxPayloadLength(uint32) uint32 {
|
||||
return MaxMessagePayload
|
||||
}
|
||||
|
||||
// A compile-time constraint to ensure Init implements the Message interface.
|
||||
var _ Message = (*Init)(nil)
|
||||
|
@ -88,7 +88,7 @@ func makeEmptyMessage(msgType MessageType) (Message, error) {
|
||||
|
||||
switch msgType {
|
||||
case MsgInit:
|
||||
msg = &Init{&lnwire.Init{}}
|
||||
msg = &Init{}
|
||||
case MsgCreateSession:
|
||||
msg = &CreateSession{}
|
||||
case MsgCreateSessionReply:
|
||||
|
@ -6,8 +6,10 @@ import (
|
||||
"io"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/lightningnetwork/lnd/lnwallet"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/watchtower/blob"
|
||||
)
|
||||
|
||||
@ -86,6 +88,20 @@ func WriteElement(w io.Writer, element interface{}) error {
|
||||
return err
|
||||
}
|
||||
|
||||
case chainhash.Hash:
|
||||
if _, err := w.Write(e[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case *lnwire.RawFeatureVector:
|
||||
if e == nil {
|
||||
return fmt.Errorf("cannot write nil feature vector")
|
||||
}
|
||||
|
||||
if err := e.Encode(w); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case *btcec.PublicKey:
|
||||
if e == nil {
|
||||
return fmt.Errorf("cannot write nil pubkey")
|
||||
@ -192,6 +208,20 @@ func ReadElement(r io.Reader, element interface{}) error {
|
||||
}
|
||||
*e = ErrorCode(binary.BigEndian.Uint16(b[:]))
|
||||
|
||||
case *chainhash.Hash:
|
||||
if _, err := io.ReadFull(r, e[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case **lnwire.RawFeatureVector:
|
||||
f := lnwire.NewRawFeatureVector()
|
||||
err := f.Decode(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*e = f
|
||||
|
||||
case **btcec.PublicKey:
|
||||
var b [btcec.PubKeyBytesLenCompressed]byte
|
||||
if _, err := io.ReadFull(r, b[:]); err != nil {
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"testing/quick"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtwire"
|
||||
@ -23,6 +24,12 @@ func randRawFeatureVector(r *rand.Rand) *lnwire.RawFeatureVector {
|
||||
return featureVec
|
||||
}
|
||||
|
||||
func randChainHash(r *rand.Rand) chainhash.Hash {
|
||||
var hash chainhash.Hash
|
||||
r.Read(hash[:])
|
||||
return hash
|
||||
}
|
||||
|
||||
// TestWatchtowerWireProtocol uses the testing/quick package to create a series
|
||||
// of fuzz tests to attempt to break a primary scenario which is implemented as
|
||||
// property based testing scenario.
|
||||
@ -73,7 +80,7 @@ func TestWatchtowerWireProtocol(t *testing.T) {
|
||||
wtwire.MsgInit: func(v []reflect.Value, r *rand.Rand) {
|
||||
req := wtwire.NewInitMessage(
|
||||
randRawFeatureVector(r),
|
||||
randRawFeatureVector(r),
|
||||
randChainHash(r),
|
||||
)
|
||||
|
||||
v[0] = reflect.ValueOf(*req)
|
||||
|
Loading…
Reference in New Issue
Block a user