diff --git a/watchtower/wtserver/server.go b/watchtower/wtserver/server.go index 957b2202..81b8b5fe 100644 --- a/watchtower/wtserver/server.go +++ b/watchtower/wtserver/server.go @@ -10,6 +10,7 @@ import ( "time" "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/connmgr" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcutil" @@ -51,6 +52,9 @@ type Config struct { // NewAddress is used to generate reward addresses, where a cut of // successfully sent funds can be received. NewAddress func() (btcutil.Address, error) + + // ChainHash identifies the network that the server is watching. + ChainHash chainhash.Hash } // Server houses the state required to handle watchtower peers. It's primary job @@ -68,7 +72,7 @@ type Server struct { clients map[wtdb.SessionID]Peer globalFeatures *lnwire.RawFeatureVector - localFeatures *lnwire.RawFeatureVector + connFeatures *lnwire.RawFeatureVector wg sync.WaitGroup quit chan struct{} @@ -78,7 +82,7 @@ type Server struct { // clients connecting to the listener addresses, and allows them to open // sessions and send state updates. func New(cfg *Config) (*Server, error) { - localFeatures := lnwire.NewRawFeatureVector( + connFeatures := lnwire.NewRawFeatureVector( wtwire.WtSessionsOptional, ) @@ -86,7 +90,7 @@ func New(cfg *Config) (*Server, error) { cfg: cfg, clients: make(map[wtdb.SessionID]Peer), globalFeatures: lnwire.NewRawFeatureVector(), - localFeatures: localFeatures, + connFeatures: connFeatures, quit: make(chan struct{}), } @@ -206,7 +210,7 @@ func (s *Server) handleClient(peer Peer) { } localInit := wtwire.NewInitMessage( - s.localFeatures, s.globalFeatures, + s.connFeatures, s.cfg.ChainHash, ) err = s.sendMessage(peer, localInit) @@ -296,25 +300,19 @@ func (s *Server) handleClient(peer Peer) { // handleInit accepts the local and remote Init messages, and verifies that the // client is not requesting any required features that are unknown to the tower. func (s *Server) handleInit(localInit, remoteInit *wtwire.Init) error { - remoteLocalFeatures := lnwire.NewFeatureVector( - remoteInit.LocalFeatures, wtwire.LocalFeatures, - ) - remoteGlobalFeatures := lnwire.NewFeatureVector( - remoteInit.GlobalFeatures, wtwire.GlobalFeatures, - ) - - unknownLocalFeatures := remoteLocalFeatures.UnknownRequiredFeatures() - if len(unknownLocalFeatures) > 0 { - err := fmt.Errorf("Peer set unknown local feature bits: %v", - unknownLocalFeatures) - return err + if localInit.ChainHash != remoteInit.ChainHash { + return fmt.Errorf("Peer chain hash unknown: %x", + remoteInit.ChainHash) } - unknownGlobalFeatures := remoteGlobalFeatures.UnknownRequiredFeatures() - if len(unknownGlobalFeatures) > 0 { - err := fmt.Errorf("Peer set unknown global feature bits: %v", - unknownGlobalFeatures) - return err + remoteConnFeatures := lnwire.NewFeatureVector( + remoteInit.ConnFeatures, wtwire.LocalFeatures, + ) + + unknownLocalFeatures := remoteConnFeatures.UnknownRequiredFeatures() + if len(unknownLocalFeatures) > 0 { + return fmt.Errorf("Peer set unknown local feature bits: %v", + unknownLocalFeatures) } return nil diff --git a/watchtower/wtserver/server_test.go b/watchtower/wtserver/server_test.go index 5b97985b..4a6ee27e 100644 --- a/watchtower/wtserver/server_test.go +++ b/watchtower/wtserver/server_test.go @@ -27,6 +27,8 @@ var ( ) addrScript, _ = txscript.PayToAddrScript(addr) + + testnetChainHash = *chaincfg.TestNet3Params.GenesisHash ) // randPubKey generates a new secp keypair, and returns the public key. @@ -59,6 +61,7 @@ func initServer(t *testing.T, db wtserver.DB, NewAddress: func() (btcutil.Address, error) { return addr, nil }, + ChainHash: testnetChainHash, }) if err != nil { t.Fatalf("unable to create server: %v", err) @@ -91,7 +94,7 @@ func TestServerOnlyAcceptOnePeer(t *testing.T) { // Serialize a Init message to be sent by both peers. init := wtwire.NewInitMessage( - lnwire.NewRawFeatureVector(), lnwire.NewRawFeatureVector(), + lnwire.NewRawFeatureVector(), testnetChainHash, ) var b bytes.Buffer @@ -159,7 +162,7 @@ var createSessionTests = []createSessionTestCase{ name: "reject duplicate session create", initMsg: wtwire.NewInitMessage( lnwire.NewRawFeatureVector(), - lnwire.NewRawFeatureVector(), + testnetChainHash, ), createMsg: &wtwire.CreateSession{ BlobType: blob.TypeDefault, @@ -181,7 +184,7 @@ var createSessionTests = []createSessionTestCase{ name: "reject unsupported blob type", initMsg: wtwire.NewInitMessage( lnwire.NewRawFeatureVector(), - lnwire.NewRawFeatureVector(), + testnetChainHash, ), createMsg: &wtwire.CreateSession{ BlobType: 0, @@ -279,10 +282,10 @@ var stateUpdateTests = []stateUpdateTestCase{ // Valid update sequence, send seqnum == lastapplied as last update. { name: "perm fail after sending seqnum equal lastapplied", - initMsg: &wtwire.Init{&lnwire.Init{ - LocalFeatures: lnwire.NewRawFeatureVector(), - GlobalFeatures: lnwire.NewRawFeatureVector(), - }}, + initMsg: wtwire.NewInitMessage( + lnwire.NewRawFeatureVector(), + testnetChainHash, + ), createMsg: &wtwire.CreateSession{ BlobType: blob.TypeDefault, MaxUpdates: 3, @@ -309,10 +312,10 @@ var stateUpdateTests = []stateUpdateTestCase{ // Send update that skips next expected sequence number. { name: "skip sequence number", - initMsg: &wtwire.Init{&lnwire.Init{ - LocalFeatures: lnwire.NewRawFeatureVector(), - GlobalFeatures: lnwire.NewRawFeatureVector(), - }}, + initMsg: wtwire.NewInitMessage( + lnwire.NewRawFeatureVector(), + testnetChainHash, + ), createMsg: &wtwire.CreateSession{ BlobType: blob.TypeDefault, MaxUpdates: 4, @@ -333,10 +336,10 @@ var stateUpdateTests = []stateUpdateTestCase{ // Send update that reverts to older sequence number. { name: "revert to older seqnum", - initMsg: &wtwire.Init{&lnwire.Init{ - LocalFeatures: lnwire.NewRawFeatureVector(), - GlobalFeatures: lnwire.NewRawFeatureVector(), - }}, + initMsg: wtwire.NewInitMessage( + lnwire.NewRawFeatureVector(), + testnetChainHash, + ), createMsg: &wtwire.CreateSession{ BlobType: blob.TypeDefault, MaxUpdates: 4, @@ -361,10 +364,10 @@ var stateUpdateTests = []stateUpdateTestCase{ // Send update echoing a last applied that is lower than previous value. { name: "revert to older lastapplied", - initMsg: &wtwire.Init{&lnwire.Init{ - LocalFeatures: lnwire.NewRawFeatureVector(), - GlobalFeatures: lnwire.NewRawFeatureVector(), - }}, + initMsg: wtwire.NewInitMessage( + lnwire.NewRawFeatureVector(), + testnetChainHash, + ), createMsg: &wtwire.CreateSession{ BlobType: blob.TypeDefault, MaxUpdates: 4, @@ -389,10 +392,10 @@ var stateUpdateTests = []stateUpdateTestCase{ // Client echos last applied as they are received. { name: "resume after disconnect", - initMsg: &wtwire.Init{&lnwire.Init{ - LocalFeatures: lnwire.NewRawFeatureVector(), - GlobalFeatures: lnwire.NewRawFeatureVector(), - }}, + initMsg: wtwire.NewInitMessage( + lnwire.NewRawFeatureVector(), + testnetChainHash, + ), createMsg: &wtwire.CreateSession{ BlobType: blob.TypeDefault, MaxUpdates: 4, @@ -419,10 +422,10 @@ var stateUpdateTests = []stateUpdateTestCase{ // Client doesn't echo last applied until last message. { name: "resume after disconnect lagging lastapplied", - initMsg: &wtwire.Init{&lnwire.Init{ - LocalFeatures: lnwire.NewRawFeatureVector(), - GlobalFeatures: lnwire.NewRawFeatureVector(), - }}, + initMsg: wtwire.NewInitMessage( + lnwire.NewRawFeatureVector(), + testnetChainHash, + ), createMsg: &wtwire.CreateSession{ BlobType: blob.TypeDefault, MaxUpdates: 4, @@ -448,10 +451,10 @@ var stateUpdateTests = []stateUpdateTestCase{ // Send update with sequence number that exceeds MaxUpdates. { name: "seqnum exceed maxupdates", - initMsg: &wtwire.Init{&lnwire.Init{ - LocalFeatures: lnwire.NewRawFeatureVector(), - GlobalFeatures: lnwire.NewRawFeatureVector(), - }}, + initMsg: wtwire.NewInitMessage( + lnwire.NewRawFeatureVector(), + testnetChainHash, + ), createMsg: &wtwire.CreateSession{ BlobType: blob.TypeDefault, MaxUpdates: 3, @@ -478,10 +481,10 @@ var stateUpdateTests = []stateUpdateTestCase{ // Ensure sequence number 0 causes permanent failure. { name: "perm fail after seqnum 0", - initMsg: &wtwire.Init{&lnwire.Init{ - LocalFeatures: lnwire.NewRawFeatureVector(), - GlobalFeatures: lnwire.NewRawFeatureVector(), - }}, + initMsg: wtwire.NewInitMessage( + lnwire.NewRawFeatureVector(), + testnetChainHash, + ), createMsg: &wtwire.CreateSession{ BlobType: blob.TypeDefault, MaxUpdates: 3, diff --git a/watchtower/wtwire/init.go b/watchtower/wtwire/init.go index d9056e2b..ff901855 100644 --- a/watchtower/wtwire/init.go +++ b/watchtower/wtwire/init.go @@ -1,23 +1,59 @@ package wtwire -import "github.com/lightningnetwork/lnd/lnwire" +import ( + "io" + + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/lnwire" +) // Init is the first message sent over the watchtower wire protocol, and -// specifies features and level of requiredness maintained by the sending node. -// The watchtower Init message is identical to the LN Init message, except it -// uses a different message type to ensure the two are not conflated. +// specifies connection features bits and level of requiredness maintained by +// the sending node. The Init message also sends the chain hash identifying the +// network that the sender is on. type Init struct { - *lnwire.Init + // ConnFeatures are the feature bits being advertised for the duration + // of a single connection with a peer. + ConnFeatures *lnwire.RawFeatureVector + + // ChainHash is the genesis hash of the chain that the advertiser claims + // to be on. + ChainHash chainhash.Hash } -// NewInitMessage generates a new Init message from raw global and local feature -// vectors. -func NewInitMessage(gf, lf *lnwire.RawFeatureVector) *Init { +// NewInitMessage generates a new Init message from a raw connection feature +// vector and chain hash. +func NewInitMessage(connFeatures *lnwire.RawFeatureVector, + chainHash chainhash.Hash) *Init { + return &Init{ - Init: lnwire.NewInitMessage(gf, lf), + ConnFeatures: connFeatures, + ChainHash: chainHash, } } +// Encode serializes the target Init into the passed io.Writer observing the +// protocol version specified. +// +// This is part of the wtwire.Message interface. +func (msg *Init) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, + msg.ConnFeatures, + msg.ChainHash, + ) +} + +// Decode deserializes a serialized Init message stored in the passed io.Reader +// observing the specified protocol version. +// +// This is part of the wtwire.Message interface. +func (msg *Init) Decode(r io.Reader, pver uint32) error { + return ReadElements(r, + &msg.ConnFeatures, + &msg.ChainHash, + ) +} + // MsgType returns the integer uniquely identifying this message type on the // wire. // @@ -26,5 +62,13 @@ func (msg *Init) MsgType() MessageType { return MsgInit } +// MaxPayloadLength returns the maximum allowed payload size for an Init +// complete message observing the specified protocol version. +// +// This is part of the wtwire.Message interface. +func (msg *Init) MaxPayloadLength(uint32) uint32 { + return MaxMessagePayload +} + // A compile-time constraint to ensure Init implements the Message interface. var _ Message = (*Init)(nil) diff --git a/watchtower/wtwire/message.go b/watchtower/wtwire/message.go index 5daaa6de..364a2dab 100644 --- a/watchtower/wtwire/message.go +++ b/watchtower/wtwire/message.go @@ -88,7 +88,7 @@ func makeEmptyMessage(msgType MessageType) (Message, error) { switch msgType { case MsgInit: - msg = &Init{&lnwire.Init{}} + msg = &Init{} case MsgCreateSession: msg = &CreateSession{} case MsgCreateSessionReply: diff --git a/watchtower/wtwire/wtwire.go b/watchtower/wtwire/wtwire.go index 2582c704..1af9433d 100644 --- a/watchtower/wtwire/wtwire.go +++ b/watchtower/wtwire/wtwire.go @@ -6,8 +6,10 @@ import ( "io" "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/lnwallet" + "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/watchtower/blob" ) @@ -86,6 +88,20 @@ func WriteElement(w io.Writer, element interface{}) error { return err } + case chainhash.Hash: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case *lnwire.RawFeatureVector: + if e == nil { + return fmt.Errorf("cannot write nil feature vector") + } + + if err := e.Encode(w); err != nil { + return err + } + case *btcec.PublicKey: if e == nil { return fmt.Errorf("cannot write nil pubkey") @@ -192,6 +208,20 @@ func ReadElement(r io.Reader, element interface{}) error { } *e = ErrorCode(binary.BigEndian.Uint16(b[:])) + case *chainhash.Hash: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case **lnwire.RawFeatureVector: + f := lnwire.NewRawFeatureVector() + err := f.Decode(r) + if err != nil { + return err + } + + *e = f + case **btcec.PublicKey: var b [btcec.PubKeyBytesLenCompressed]byte if _, err := io.ReadFull(r, b[:]); err != nil { diff --git a/watchtower/wtwire/wtwire_test.go b/watchtower/wtwire/wtwire_test.go index 56adb7cd..1dfef1a2 100644 --- a/watchtower/wtwire/wtwire_test.go +++ b/watchtower/wtwire/wtwire_test.go @@ -8,6 +8,7 @@ import ( "testing/quick" "time" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/watchtower/wtwire" @@ -23,6 +24,12 @@ func randRawFeatureVector(r *rand.Rand) *lnwire.RawFeatureVector { return featureVec } +func randChainHash(r *rand.Rand) chainhash.Hash { + var hash chainhash.Hash + r.Read(hash[:]) + return hash +} + // TestWatchtowerWireProtocol uses the testing/quick package to create a series // of fuzz tests to attempt to break a primary scenario which is implemented as // property based testing scenario. @@ -73,7 +80,7 @@ func TestWatchtowerWireProtocol(t *testing.T) { wtwire.MsgInit: func(v []reflect.Value, r *rand.Rand) { req := wtwire.NewInitMessage( randRawFeatureVector(r), - randRawFeatureVector(r), + randChainHash(r), ) v[0] = reflect.ValueOf(*req)