watchtower/multi: send connection features + chain hash in Init

This commit is contained in:
Conner Fromknecht 2019-02-06 20:06:44 -08:00
parent eaea92e2cf
commit 4dbade64dd
No known key found for this signature in database
GPG Key ID: E7D737B67FA592C7
6 changed files with 149 additions and 67 deletions

@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/connmgr" "github.com/btcsuite/btcd/connmgr"
"github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
@ -51,6 +52,9 @@ type Config struct {
// NewAddress is used to generate reward addresses, where a cut of // NewAddress is used to generate reward addresses, where a cut of
// successfully sent funds can be received. // successfully sent funds can be received.
NewAddress func() (btcutil.Address, error) NewAddress func() (btcutil.Address, error)
// ChainHash identifies the network that the server is watching.
ChainHash chainhash.Hash
} }
// Server houses the state required to handle watchtower peers. It's primary job // Server houses the state required to handle watchtower peers. It's primary job
@ -68,7 +72,7 @@ type Server struct {
clients map[wtdb.SessionID]Peer clients map[wtdb.SessionID]Peer
globalFeatures *lnwire.RawFeatureVector globalFeatures *lnwire.RawFeatureVector
localFeatures *lnwire.RawFeatureVector connFeatures *lnwire.RawFeatureVector
wg sync.WaitGroup wg sync.WaitGroup
quit chan struct{} quit chan struct{}
@ -78,7 +82,7 @@ type Server struct {
// clients connecting to the listener addresses, and allows them to open // clients connecting to the listener addresses, and allows them to open
// sessions and send state updates. // sessions and send state updates.
func New(cfg *Config) (*Server, error) { func New(cfg *Config) (*Server, error) {
localFeatures := lnwire.NewRawFeatureVector( connFeatures := lnwire.NewRawFeatureVector(
wtwire.WtSessionsOptional, wtwire.WtSessionsOptional,
) )
@ -86,7 +90,7 @@ func New(cfg *Config) (*Server, error) {
cfg: cfg, cfg: cfg,
clients: make(map[wtdb.SessionID]Peer), clients: make(map[wtdb.SessionID]Peer),
globalFeatures: lnwire.NewRawFeatureVector(), globalFeatures: lnwire.NewRawFeatureVector(),
localFeatures: localFeatures, connFeatures: connFeatures,
quit: make(chan struct{}), quit: make(chan struct{}),
} }
@ -206,7 +210,7 @@ func (s *Server) handleClient(peer Peer) {
} }
localInit := wtwire.NewInitMessage( localInit := wtwire.NewInitMessage(
s.localFeatures, s.globalFeatures, s.connFeatures, s.cfg.ChainHash,
) )
err = s.sendMessage(peer, localInit) err = s.sendMessage(peer, localInit)
@ -296,25 +300,19 @@ func (s *Server) handleClient(peer Peer) {
// handleInit accepts the local and remote Init messages, and verifies that the // handleInit accepts the local and remote Init messages, and verifies that the
// client is not requesting any required features that are unknown to the tower. // client is not requesting any required features that are unknown to the tower.
func (s *Server) handleInit(localInit, remoteInit *wtwire.Init) error { func (s *Server) handleInit(localInit, remoteInit *wtwire.Init) error {
remoteLocalFeatures := lnwire.NewFeatureVector( if localInit.ChainHash != remoteInit.ChainHash {
remoteInit.LocalFeatures, wtwire.LocalFeatures, return fmt.Errorf("Peer chain hash unknown: %x",
) remoteInit.ChainHash)
remoteGlobalFeatures := lnwire.NewFeatureVector(
remoteInit.GlobalFeatures, wtwire.GlobalFeatures,
)
unknownLocalFeatures := remoteLocalFeatures.UnknownRequiredFeatures()
if len(unknownLocalFeatures) > 0 {
err := fmt.Errorf("Peer set unknown local feature bits: %v",
unknownLocalFeatures)
return err
} }
unknownGlobalFeatures := remoteGlobalFeatures.UnknownRequiredFeatures() remoteConnFeatures := lnwire.NewFeatureVector(
if len(unknownGlobalFeatures) > 0 { remoteInit.ConnFeatures, wtwire.LocalFeatures,
err := fmt.Errorf("Peer set unknown global feature bits: %v", )
unknownGlobalFeatures)
return err unknownLocalFeatures := remoteConnFeatures.UnknownRequiredFeatures()
if len(unknownLocalFeatures) > 0 {
return fmt.Errorf("Peer set unknown local feature bits: %v",
unknownLocalFeatures)
} }
return nil return nil

@ -27,6 +27,8 @@ var (
) )
addrScript, _ = txscript.PayToAddrScript(addr) addrScript, _ = txscript.PayToAddrScript(addr)
testnetChainHash = *chaincfg.TestNet3Params.GenesisHash
) )
// randPubKey generates a new secp keypair, and returns the public key. // randPubKey generates a new secp keypair, and returns the public key.
@ -59,6 +61,7 @@ func initServer(t *testing.T, db wtserver.DB,
NewAddress: func() (btcutil.Address, error) { NewAddress: func() (btcutil.Address, error) {
return addr, nil return addr, nil
}, },
ChainHash: testnetChainHash,
}) })
if err != nil { if err != nil {
t.Fatalf("unable to create server: %v", err) t.Fatalf("unable to create server: %v", err)
@ -91,7 +94,7 @@ func TestServerOnlyAcceptOnePeer(t *testing.T) {
// Serialize a Init message to be sent by both peers. // Serialize a Init message to be sent by both peers.
init := wtwire.NewInitMessage( init := wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(), lnwire.NewRawFeatureVector(), lnwire.NewRawFeatureVector(), testnetChainHash,
) )
var b bytes.Buffer var b bytes.Buffer
@ -159,7 +162,7 @@ var createSessionTests = []createSessionTestCase{
name: "reject duplicate session create", name: "reject duplicate session create",
initMsg: wtwire.NewInitMessage( initMsg: wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(), lnwire.NewRawFeatureVector(),
lnwire.NewRawFeatureVector(), testnetChainHash,
), ),
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault, BlobType: blob.TypeDefault,
@ -181,7 +184,7 @@ var createSessionTests = []createSessionTestCase{
name: "reject unsupported blob type", name: "reject unsupported blob type",
initMsg: wtwire.NewInitMessage( initMsg: wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(), lnwire.NewRawFeatureVector(),
lnwire.NewRawFeatureVector(), testnetChainHash,
), ),
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobType: 0, BlobType: 0,
@ -279,10 +282,10 @@ var stateUpdateTests = []stateUpdateTestCase{
// Valid update sequence, send seqnum == lastapplied as last update. // Valid update sequence, send seqnum == lastapplied as last update.
{ {
name: "perm fail after sending seqnum equal lastapplied", name: "perm fail after sending seqnum equal lastapplied",
initMsg: &wtwire.Init{&lnwire.Init{ initMsg: wtwire.NewInitMessage(
LocalFeatures: lnwire.NewRawFeatureVector(), lnwire.NewRawFeatureVector(),
GlobalFeatures: lnwire.NewRawFeatureVector(), testnetChainHash,
}}, ),
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault, BlobType: blob.TypeDefault,
MaxUpdates: 3, MaxUpdates: 3,
@ -309,10 +312,10 @@ var stateUpdateTests = []stateUpdateTestCase{
// Send update that skips next expected sequence number. // Send update that skips next expected sequence number.
{ {
name: "skip sequence number", name: "skip sequence number",
initMsg: &wtwire.Init{&lnwire.Init{ initMsg: wtwire.NewInitMessage(
LocalFeatures: lnwire.NewRawFeatureVector(), lnwire.NewRawFeatureVector(),
GlobalFeatures: lnwire.NewRawFeatureVector(), testnetChainHash,
}}, ),
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault, BlobType: blob.TypeDefault,
MaxUpdates: 4, MaxUpdates: 4,
@ -333,10 +336,10 @@ var stateUpdateTests = []stateUpdateTestCase{
// Send update that reverts to older sequence number. // Send update that reverts to older sequence number.
{ {
name: "revert to older seqnum", name: "revert to older seqnum",
initMsg: &wtwire.Init{&lnwire.Init{ initMsg: wtwire.NewInitMessage(
LocalFeatures: lnwire.NewRawFeatureVector(), lnwire.NewRawFeatureVector(),
GlobalFeatures: lnwire.NewRawFeatureVector(), testnetChainHash,
}}, ),
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault, BlobType: blob.TypeDefault,
MaxUpdates: 4, MaxUpdates: 4,
@ -361,10 +364,10 @@ var stateUpdateTests = []stateUpdateTestCase{
// Send update echoing a last applied that is lower than previous value. // Send update echoing a last applied that is lower than previous value.
{ {
name: "revert to older lastapplied", name: "revert to older lastapplied",
initMsg: &wtwire.Init{&lnwire.Init{ initMsg: wtwire.NewInitMessage(
LocalFeatures: lnwire.NewRawFeatureVector(), lnwire.NewRawFeatureVector(),
GlobalFeatures: lnwire.NewRawFeatureVector(), testnetChainHash,
}}, ),
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault, BlobType: blob.TypeDefault,
MaxUpdates: 4, MaxUpdates: 4,
@ -389,10 +392,10 @@ var stateUpdateTests = []stateUpdateTestCase{
// Client echos last applied as they are received. // Client echos last applied as they are received.
{ {
name: "resume after disconnect", name: "resume after disconnect",
initMsg: &wtwire.Init{&lnwire.Init{ initMsg: wtwire.NewInitMessage(
LocalFeatures: lnwire.NewRawFeatureVector(), lnwire.NewRawFeatureVector(),
GlobalFeatures: lnwire.NewRawFeatureVector(), testnetChainHash,
}}, ),
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault, BlobType: blob.TypeDefault,
MaxUpdates: 4, MaxUpdates: 4,
@ -419,10 +422,10 @@ var stateUpdateTests = []stateUpdateTestCase{
// Client doesn't echo last applied until last message. // Client doesn't echo last applied until last message.
{ {
name: "resume after disconnect lagging lastapplied", name: "resume after disconnect lagging lastapplied",
initMsg: &wtwire.Init{&lnwire.Init{ initMsg: wtwire.NewInitMessage(
LocalFeatures: lnwire.NewRawFeatureVector(), lnwire.NewRawFeatureVector(),
GlobalFeatures: lnwire.NewRawFeatureVector(), testnetChainHash,
}}, ),
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault, BlobType: blob.TypeDefault,
MaxUpdates: 4, MaxUpdates: 4,
@ -448,10 +451,10 @@ var stateUpdateTests = []stateUpdateTestCase{
// Send update with sequence number that exceeds MaxUpdates. // Send update with sequence number that exceeds MaxUpdates.
{ {
name: "seqnum exceed maxupdates", name: "seqnum exceed maxupdates",
initMsg: &wtwire.Init{&lnwire.Init{ initMsg: wtwire.NewInitMessage(
LocalFeatures: lnwire.NewRawFeatureVector(), lnwire.NewRawFeatureVector(),
GlobalFeatures: lnwire.NewRawFeatureVector(), testnetChainHash,
}}, ),
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault, BlobType: blob.TypeDefault,
MaxUpdates: 3, MaxUpdates: 3,
@ -478,10 +481,10 @@ var stateUpdateTests = []stateUpdateTestCase{
// Ensure sequence number 0 causes permanent failure. // Ensure sequence number 0 causes permanent failure.
{ {
name: "perm fail after seqnum 0", name: "perm fail after seqnum 0",
initMsg: &wtwire.Init{&lnwire.Init{ initMsg: wtwire.NewInitMessage(
LocalFeatures: lnwire.NewRawFeatureVector(), lnwire.NewRawFeatureVector(),
GlobalFeatures: lnwire.NewRawFeatureVector(), testnetChainHash,
}}, ),
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault, BlobType: blob.TypeDefault,
MaxUpdates: 3, MaxUpdates: 3,

@ -1,23 +1,59 @@
package wtwire package wtwire
import "github.com/lightningnetwork/lnd/lnwire" import (
"io"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/lnwire"
)
// Init is the first message sent over the watchtower wire protocol, and // Init is the first message sent over the watchtower wire protocol, and
// specifies features and level of requiredness maintained by the sending node. // specifies connection features bits and level of requiredness maintained by
// The watchtower Init message is identical to the LN Init message, except it // the sending node. The Init message also sends the chain hash identifying the
// uses a different message type to ensure the two are not conflated. // network that the sender is on.
type Init struct { type Init struct {
*lnwire.Init // ConnFeatures are the feature bits being advertised for the duration
// of a single connection with a peer.
ConnFeatures *lnwire.RawFeatureVector
// ChainHash is the genesis hash of the chain that the advertiser claims
// to be on.
ChainHash chainhash.Hash
} }
// NewInitMessage generates a new Init message from raw global and local feature // NewInitMessage generates a new Init message from a raw connection feature
// vectors. // vector and chain hash.
func NewInitMessage(gf, lf *lnwire.RawFeatureVector) *Init { func NewInitMessage(connFeatures *lnwire.RawFeatureVector,
chainHash chainhash.Hash) *Init {
return &Init{ return &Init{
Init: lnwire.NewInitMessage(gf, lf), ConnFeatures: connFeatures,
ChainHash: chainHash,
} }
} }
// Encode serializes the target Init into the passed io.Writer observing the
// protocol version specified.
//
// This is part of the wtwire.Message interface.
func (msg *Init) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
msg.ConnFeatures,
msg.ChainHash,
)
}
// Decode deserializes a serialized Init message stored in the passed io.Reader
// observing the specified protocol version.
//
// This is part of the wtwire.Message interface.
func (msg *Init) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&msg.ConnFeatures,
&msg.ChainHash,
)
}
// MsgType returns the integer uniquely identifying this message type on the // MsgType returns the integer uniquely identifying this message type on the
// wire. // wire.
// //
@ -26,5 +62,13 @@ func (msg *Init) MsgType() MessageType {
return MsgInit return MsgInit
} }
// MaxPayloadLength returns the maximum allowed payload size for an Init
// complete message observing the specified protocol version.
//
// This is part of the wtwire.Message interface.
func (msg *Init) MaxPayloadLength(uint32) uint32 {
return MaxMessagePayload
}
// A compile-time constraint to ensure Init implements the Message interface. // A compile-time constraint to ensure Init implements the Message interface.
var _ Message = (*Init)(nil) var _ Message = (*Init)(nil)

@ -88,7 +88,7 @@ func makeEmptyMessage(msgType MessageType) (Message, error) {
switch msgType { switch msgType {
case MsgInit: case MsgInit:
msg = &Init{&lnwire.Init{}} msg = &Init{}
case MsgCreateSession: case MsgCreateSession:
msg = &CreateSession{} msg = &CreateSession{}
case MsgCreateSessionReply: case MsgCreateSessionReply:

@ -6,8 +6,10 @@ import (
"io" "io"
"github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/blob"
) )
@ -86,6 +88,20 @@ func WriteElement(w io.Writer, element interface{}) error {
return err return err
} }
case chainhash.Hash:
if _, err := w.Write(e[:]); err != nil {
return err
}
case *lnwire.RawFeatureVector:
if e == nil {
return fmt.Errorf("cannot write nil feature vector")
}
if err := e.Encode(w); err != nil {
return err
}
case *btcec.PublicKey: case *btcec.PublicKey:
if e == nil { if e == nil {
return fmt.Errorf("cannot write nil pubkey") return fmt.Errorf("cannot write nil pubkey")
@ -192,6 +208,20 @@ func ReadElement(r io.Reader, element interface{}) error {
} }
*e = ErrorCode(binary.BigEndian.Uint16(b[:])) *e = ErrorCode(binary.BigEndian.Uint16(b[:]))
case *chainhash.Hash:
if _, err := io.ReadFull(r, e[:]); err != nil {
return err
}
case **lnwire.RawFeatureVector:
f := lnwire.NewRawFeatureVector()
err := f.Decode(r)
if err != nil {
return err
}
*e = f
case **btcec.PublicKey: case **btcec.PublicKey:
var b [btcec.PubKeyBytesLenCompressed]byte var b [btcec.PubKeyBytesLenCompressed]byte
if _, err := io.ReadFull(r, b[:]); err != nil { if _, err := io.ReadFull(r, b[:]); err != nil {

@ -8,6 +8,7 @@ import (
"testing/quick" "testing/quick"
"time" "time"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/wtwire" "github.com/lightningnetwork/lnd/watchtower/wtwire"
@ -23,6 +24,12 @@ func randRawFeatureVector(r *rand.Rand) *lnwire.RawFeatureVector {
return featureVec return featureVec
} }
func randChainHash(r *rand.Rand) chainhash.Hash {
var hash chainhash.Hash
r.Read(hash[:])
return hash
}
// TestWatchtowerWireProtocol uses the testing/quick package to create a series // TestWatchtowerWireProtocol uses the testing/quick package to create a series
// of fuzz tests to attempt to break a primary scenario which is implemented as // of fuzz tests to attempt to break a primary scenario which is implemented as
// property based testing scenario. // property based testing scenario.
@ -73,7 +80,7 @@ func TestWatchtowerWireProtocol(t *testing.T) {
wtwire.MsgInit: func(v []reflect.Value, r *rand.Rand) { wtwire.MsgInit: func(v []reflect.Value, r *rand.Rand) {
req := wtwire.NewInitMessage( req := wtwire.NewInitMessage(
randRawFeatureVector(r), randRawFeatureVector(r),
randRawFeatureVector(r), randChainHash(r),
) )
v[0] = reflect.ValueOf(*req) v[0] = reflect.ValueOf(*req)