diff --git a/watchtower/config.go b/watchtower/config.go index 733ebb0b..b8136682 100644 --- a/watchtower/config.go +++ b/watchtower/config.go @@ -6,6 +6,7 @@ import ( "time" "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/lightningnetwork/lnd/tor" @@ -35,6 +36,11 @@ var ( // All nil-able elements with the Config must be set in order for the Watchtower // to function properly. type Config struct { + // ChainHash identifies the chain that the watchtower will be monitoring + // for breaches and that will be advertised in the server's Init message + // to inbound clients. + ChainHash chainhash.Hash + // BlockFetcher supports the ability to fetch blocks from the network by // hash. BlockFetcher lookout.BlockFetcher diff --git a/watchtower/standalone.go b/watchtower/standalone.go index ebb3646d..f55f44cb 100644 --- a/watchtower/standalone.go +++ b/watchtower/standalone.go @@ -78,6 +78,7 @@ func New(cfg *Config) (*Standalone, error) { // Initialize the server with its required resources. server, err := wtserver.New(&wtserver.Config{ + ChainHash: cfg.ChainHash, DB: cfg.DB, NodePrivKey: cfg.NodePrivKey, Listeners: listeners, diff --git a/watchtower/wtserver/server.go b/watchtower/wtserver/server.go index 81b8b5fe..98848fbc 100644 --- a/watchtower/wtserver/server.go +++ b/watchtower/wtserver/server.go @@ -71,8 +71,7 @@ type Server struct { clientMtx sync.RWMutex clients map[wtdb.SessionID]Peer - globalFeatures *lnwire.RawFeatureVector - connFeatures *lnwire.RawFeatureVector + localInit *wtwire.Init wg sync.WaitGroup quit chan struct{} @@ -82,16 +81,16 @@ type Server struct { // clients connecting to the listener addresses, and allows them to open // sessions and send state updates. func New(cfg *Config) (*Server, error) { - connFeatures := lnwire.NewRawFeatureVector( - wtwire.WtSessionsOptional, + localInit := wtwire.NewInitMessage( + lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional), + cfg.ChainHash, ) s := &Server{ - cfg: cfg, - clients: make(map[wtdb.SessionID]Peer), - globalFeatures: lnwire.NewRawFeatureVector(), - connFeatures: connFeatures, - quit: make(chan struct{}), + cfg: cfg, + clients: make(map[wtdb.SessionID]Peer), + localInit: localInit, + quit: make(chan struct{}), } connMgr, err := connmgr.New(&connmgr.Config{ @@ -209,17 +208,14 @@ func (s *Server) handleClient(peer Peer) { return } - localInit := wtwire.NewInitMessage( - s.connFeatures, s.cfg.ChainHash, - ) - - err = s.sendMessage(peer, localInit) + err = s.sendMessage(peer, s.localInit) if err != nil { log.Errorf("Unable to send Init msg to %s: %v", id, err) return } - if err = s.handleInit(localInit, remoteInit); err != nil { + err = s.localInit.CheckRemoteInit(remoteInit, wtwire.FeatureNames) + if err != nil { log.Errorf("Cannot support client %s: %v", id, err) return } @@ -297,27 +293,6 @@ func (s *Server) handleClient(peer Peer) { } } -// handleInit accepts the local and remote Init messages, and verifies that the -// client is not requesting any required features that are unknown to the tower. -func (s *Server) handleInit(localInit, remoteInit *wtwire.Init) error { - if localInit.ChainHash != remoteInit.ChainHash { - return fmt.Errorf("Peer chain hash unknown: %x", - remoteInit.ChainHash) - } - - remoteConnFeatures := lnwire.NewFeatureVector( - remoteInit.ConnFeatures, wtwire.LocalFeatures, - ) - - unknownLocalFeatures := remoteConnFeatures.UnknownRequiredFeatures() - if len(unknownLocalFeatures) > 0 { - return fmt.Errorf("Peer set unknown local feature bits: %v", - unknownLocalFeatures) - } - - return nil -} - // handleCreateSession processes a CreateSession message from the peer, and returns // a CreateSessionReply in response. This method will only succeed if no existing // session info is known about the session id. If an existing session is found, diff --git a/watchtower/wtwire/features.go b/watchtower/wtwire/features.go index 07c3eb7f..e407c96e 100644 --- a/watchtower/wtwire/features.go +++ b/watchtower/wtwire/features.go @@ -2,13 +2,9 @@ package wtwire import "github.com/lightningnetwork/lnd/lnwire" -// GlobalFeatures holds the globally advertised feature bits understood by -// watchtower implementations. -var GlobalFeatures map[lnwire.FeatureBit]string - -// LocalFeatures holds the locally advertised feature bits understood by -// watchtower implementations. -var LocalFeatures = map[lnwire.FeatureBit]string{ +// FeatureNames holds a mapping from each feature bit understood by this +// implementation to its common name. +var FeatureNames = map[lnwire.FeatureBit]string{ WtSessionsRequired: "wt-sessions", WtSessionsOptional: "wt-sessions", } diff --git a/watchtower/wtwire/init.go b/watchtower/wtwire/init.go index ff901855..79a5fbf8 100644 --- a/watchtower/wtwire/init.go +++ b/watchtower/wtwire/init.go @@ -1,6 +1,7 @@ package wtwire import ( + "fmt" "io" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -72,3 +73,67 @@ func (msg *Init) MaxPayloadLength(uint32) uint32 { // A compile-time constraint to ensure Init implements the Message interface. var _ Message = (*Init)(nil) + +// CheckRemoteInit performs basic validation of the remote party's Init message. +// This method checks that the remote Init's chain hash matches our advertised +// chain hash and that the remote Init does not contain any required feature +// bits that we don't understand. +func (msg *Init) CheckRemoteInit(remoteInit *Init, + featureNames map[lnwire.FeatureBit]string) error { + + // Check that the remote peer is on the same chain. + if msg.ChainHash != remoteInit.ChainHash { + return NewErrUnknownChainHash(remoteInit.ChainHash) + } + + remoteConnFeatures := lnwire.NewFeatureVector( + remoteInit.ConnFeatures, featureNames, + ) + + // Check that the remote peer doesn't have any required connection + // feature bits that we ourselves are unaware of. + unknownConnFeatures := remoteConnFeatures.UnknownRequiredFeatures() + if len(unknownConnFeatures) > 0 { + return NewErrUnknownRequiredFeatures(unknownConnFeatures...) + } + + return nil +} + +// ErrUnknownChainHash signals that the remote Init has a different chain hash +// from the one we advertised. +type ErrUnknownChainHash struct { + hash chainhash.Hash +} + +// NewErrUnknownChainHash creates an ErrUnknownChainHash using the remote Init's +// chain hash. +func NewErrUnknownChainHash(hash chainhash.Hash) *ErrUnknownChainHash { + return &ErrUnknownChainHash{hash} +} + +// Error returns a human-readable error displaying the unknown chain hash. +func (e *ErrUnknownChainHash) Error() string { + return fmt.Sprintf("remote init has unknown chain hash: %s", e.hash) +} + +// ErrUnknownRequiredFeatures signals that the remote Init has required feature +// bits that were unknown to us. +type ErrUnknownRequiredFeatures struct { + unknownFeatures []lnwire.FeatureBit +} + +// NewErrUnknownRequiredFeatures creates an ErrUnknownRequiredFeatures using the +// remote Init's required features that were unknown to us. +func NewErrUnknownRequiredFeatures( + unknownFeatures ...lnwire.FeatureBit) *ErrUnknownRequiredFeatures { + + return &ErrUnknownRequiredFeatures{unknownFeatures} +} + +// Error returns a human-readable error displaying the unknown required feature +// bits. +func (e *ErrUnknownRequiredFeatures) Error() string { + return fmt.Sprintf("remote init has unknown required features: %v", + e.unknownFeatures) +} diff --git a/watchtower/wtwire/init_test.go b/watchtower/wtwire/init_test.go new file mode 100644 index 00000000..337c1de2 --- /dev/null +++ b/watchtower/wtwire/init_test.go @@ -0,0 +1,106 @@ +package wtwire_test + +import ( + "testing" + + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/watchtower/wtwire" +) + +var ( + testnetChainHash = *chaincfg.TestNet3Params.GenesisHash + mainnetChainHash = *chaincfg.MainNetParams.GenesisHash +) + +type checkRemoteInitTest struct { + name string + lFeatures *lnwire.RawFeatureVector + lHash chainhash.Hash + rFeatures *lnwire.RawFeatureVector + rHash chainhash.Hash + expErr error +} + +var checkRemoteInitTests = []checkRemoteInitTest{ + { + name: "same chain, local-optional remote-required", + lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional), + lHash: testnetChainHash, + rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired), + rHash: testnetChainHash, + }, + { + name: "same chain, local-required remote-optional", + lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired), + lHash: testnetChainHash, + rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional), + rHash: testnetChainHash, + }, + { + name: "different chain, local-optional remote-required", + lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional), + lHash: testnetChainHash, + rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired), + rHash: mainnetChainHash, + expErr: wtwire.NewErrUnknownChainHash(mainnetChainHash), + }, + { + name: "different chain, local-required remote-optional", + lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional), + lHash: testnetChainHash, + rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired), + rHash: mainnetChainHash, + expErr: wtwire.NewErrUnknownChainHash(mainnetChainHash), + }, + { + name: "same chain, remote-unknown-required", + lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional), + lHash: testnetChainHash, + rFeatures: lnwire.NewRawFeatureVector(lnwire.GossipQueriesRequired), + rHash: testnetChainHash, + expErr: wtwire.NewErrUnknownRequiredFeatures( + lnwire.GossipQueriesRequired, + ), + }, +} + +// TestCheckRemoteInit asserts the behavior of CheckRemoteInit when called with +// the remote party's Init message and the default wtwire.Features. We assert +// the validity of advertised features from the perspective of both client and +// server, as well as failure cases such as differing chain hashes or unknown +// required features. +func TestCheckRemoteInit(t *testing.T) { + for _, test := range checkRemoteInitTests { + t.Run(test.name, func(t *testing.T) { + testCheckRemoteInit(t, test) + }) + } +} + +func testCheckRemoteInit(t *testing.T, test checkRemoteInitTest) { + localInit := wtwire.NewInitMessage(test.lFeatures, test.lHash) + remoteInit := wtwire.NewInitMessage(test.rFeatures, test.rHash) + + err := localInit.CheckRemoteInit(remoteInit, wtwire.FeatureNames) + switch { + + // Both non-nil, pass. + case err == nil && test.expErr == nil: + return + + // One is nil and one is non-nil, fail. + default: + t.Fatalf("error mismatch, want: %v, got: %v", test.expErr, err) + + // Both non-nil, assert same error type. + case err != nil && test.expErr != nil: + } + + // Compare error strings to assert same type. + if err.Error() != test.expErr.Error() { + t.Fatalf("error mismatch, want: %v, got: %v", + test.expErr.Error(), err.Error()) + } +}