Merge pull request #2615 from cfromknecht/wtwire-check-remote-init

wtwire: add CheckRemoteInit helper
This commit is contained in:
Olaoluwa Osuntokun 2019-02-11 19:59:08 -08:00 committed by GitHub
commit e9fb6100f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 192 additions and 43 deletions

@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/lightningnetwork/lnd/tor" "github.com/lightningnetwork/lnd/tor"
@ -35,6 +36,11 @@ var (
// All nil-able elements with the Config must be set in order for the Watchtower // All nil-able elements with the Config must be set in order for the Watchtower
// to function properly. // to function properly.
type Config struct { type Config struct {
// ChainHash identifies the chain that the watchtower will be monitoring
// for breaches and that will be advertised in the server's Init message
// to inbound clients.
ChainHash chainhash.Hash
// BlockFetcher supports the ability to fetch blocks from the network by // BlockFetcher supports the ability to fetch blocks from the network by
// hash. // hash.
BlockFetcher lookout.BlockFetcher BlockFetcher lookout.BlockFetcher

@ -78,6 +78,7 @@ func New(cfg *Config) (*Standalone, error) {
// Initialize the server with its required resources. // Initialize the server with its required resources.
server, err := wtserver.New(&wtserver.Config{ server, err := wtserver.New(&wtserver.Config{
ChainHash: cfg.ChainHash,
DB: cfg.DB, DB: cfg.DB,
NodePrivKey: cfg.NodePrivKey, NodePrivKey: cfg.NodePrivKey,
Listeners: listeners, Listeners: listeners,

@ -71,8 +71,7 @@ type Server struct {
clientMtx sync.RWMutex clientMtx sync.RWMutex
clients map[wtdb.SessionID]Peer clients map[wtdb.SessionID]Peer
globalFeatures *lnwire.RawFeatureVector localInit *wtwire.Init
connFeatures *lnwire.RawFeatureVector
wg sync.WaitGroup wg sync.WaitGroup
quit chan struct{} quit chan struct{}
@ -82,16 +81,16 @@ type Server struct {
// clients connecting to the listener addresses, and allows them to open // clients connecting to the listener addresses, and allows them to open
// sessions and send state updates. // sessions and send state updates.
func New(cfg *Config) (*Server, error) { func New(cfg *Config) (*Server, error) {
connFeatures := lnwire.NewRawFeatureVector( localInit := wtwire.NewInitMessage(
wtwire.WtSessionsOptional, lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional),
cfg.ChainHash,
) )
s := &Server{ s := &Server{
cfg: cfg, cfg: cfg,
clients: make(map[wtdb.SessionID]Peer), clients: make(map[wtdb.SessionID]Peer),
globalFeatures: lnwire.NewRawFeatureVector(), localInit: localInit,
connFeatures: connFeatures, quit: make(chan struct{}),
quit: make(chan struct{}),
} }
connMgr, err := connmgr.New(&connmgr.Config{ connMgr, err := connmgr.New(&connmgr.Config{
@ -209,17 +208,14 @@ func (s *Server) handleClient(peer Peer) {
return return
} }
localInit := wtwire.NewInitMessage( err = s.sendMessage(peer, s.localInit)
s.connFeatures, s.cfg.ChainHash,
)
err = s.sendMessage(peer, localInit)
if err != nil { if err != nil {
log.Errorf("Unable to send Init msg to %s: %v", id, err) log.Errorf("Unable to send Init msg to %s: %v", id, err)
return return
} }
if err = s.handleInit(localInit, remoteInit); err != nil { err = s.localInit.CheckRemoteInit(remoteInit, wtwire.FeatureNames)
if err != nil {
log.Errorf("Cannot support client %s: %v", id, err) log.Errorf("Cannot support client %s: %v", id, err)
return return
} }
@ -297,27 +293,6 @@ func (s *Server) handleClient(peer Peer) {
} }
} }
// handleInit accepts the local and remote Init messages, and verifies that the
// client is not requesting any required features that are unknown to the tower.
func (s *Server) handleInit(localInit, remoteInit *wtwire.Init) error {
if localInit.ChainHash != remoteInit.ChainHash {
return fmt.Errorf("Peer chain hash unknown: %x",
remoteInit.ChainHash)
}
remoteConnFeatures := lnwire.NewFeatureVector(
remoteInit.ConnFeatures, wtwire.LocalFeatures,
)
unknownLocalFeatures := remoteConnFeatures.UnknownRequiredFeatures()
if len(unknownLocalFeatures) > 0 {
return fmt.Errorf("Peer set unknown local feature bits: %v",
unknownLocalFeatures)
}
return nil
}
// handleCreateSession processes a CreateSession message from the peer, and returns // handleCreateSession processes a CreateSession message from the peer, and returns
// a CreateSessionReply in response. This method will only succeed if no existing // a CreateSessionReply in response. This method will only succeed if no existing
// session info is known about the session id. If an existing session is found, // session info is known about the session id. If an existing session is found,

@ -2,13 +2,9 @@ package wtwire
import "github.com/lightningnetwork/lnd/lnwire" import "github.com/lightningnetwork/lnd/lnwire"
// GlobalFeatures holds the globally advertised feature bits understood by // FeatureNames holds a mapping from each feature bit understood by this
// watchtower implementations. // implementation to its common name.
var GlobalFeatures map[lnwire.FeatureBit]string var FeatureNames = map[lnwire.FeatureBit]string{
// LocalFeatures holds the locally advertised feature bits understood by
// watchtower implementations.
var LocalFeatures = map[lnwire.FeatureBit]string{
WtSessionsRequired: "wt-sessions", WtSessionsRequired: "wt-sessions",
WtSessionsOptional: "wt-sessions", WtSessionsOptional: "wt-sessions",
} }

@ -1,6 +1,7 @@
package wtwire package wtwire
import ( import (
"fmt"
"io" "io"
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
@ -72,3 +73,67 @@ func (msg *Init) MaxPayloadLength(uint32) uint32 {
// A compile-time constraint to ensure Init implements the Message interface. // A compile-time constraint to ensure Init implements the Message interface.
var _ Message = (*Init)(nil) var _ Message = (*Init)(nil)
// CheckRemoteInit performs basic validation of the remote party's Init message.
// This method checks that the remote Init's chain hash matches our advertised
// chain hash and that the remote Init does not contain any required feature
// bits that we don't understand.
func (msg *Init) CheckRemoteInit(remoteInit *Init,
featureNames map[lnwire.FeatureBit]string) error {
// Check that the remote peer is on the same chain.
if msg.ChainHash != remoteInit.ChainHash {
return NewErrUnknownChainHash(remoteInit.ChainHash)
}
remoteConnFeatures := lnwire.NewFeatureVector(
remoteInit.ConnFeatures, featureNames,
)
// Check that the remote peer doesn't have any required connection
// feature bits that we ourselves are unaware of.
unknownConnFeatures := remoteConnFeatures.UnknownRequiredFeatures()
if len(unknownConnFeatures) > 0 {
return NewErrUnknownRequiredFeatures(unknownConnFeatures...)
}
return nil
}
// ErrUnknownChainHash signals that the remote Init has a different chain hash
// from the one we advertised.
type ErrUnknownChainHash struct {
hash chainhash.Hash
}
// NewErrUnknownChainHash creates an ErrUnknownChainHash using the remote Init's
// chain hash.
func NewErrUnknownChainHash(hash chainhash.Hash) *ErrUnknownChainHash {
return &ErrUnknownChainHash{hash}
}
// Error returns a human-readable error displaying the unknown chain hash.
func (e *ErrUnknownChainHash) Error() string {
return fmt.Sprintf("remote init has unknown chain hash: %s", e.hash)
}
// ErrUnknownRequiredFeatures signals that the remote Init has required feature
// bits that were unknown to us.
type ErrUnknownRequiredFeatures struct {
unknownFeatures []lnwire.FeatureBit
}
// NewErrUnknownRequiredFeatures creates an ErrUnknownRequiredFeatures using the
// remote Init's required features that were unknown to us.
func NewErrUnknownRequiredFeatures(
unknownFeatures ...lnwire.FeatureBit) *ErrUnknownRequiredFeatures {
return &ErrUnknownRequiredFeatures{unknownFeatures}
}
// Error returns a human-readable error displaying the unknown required feature
// bits.
func (e *ErrUnknownRequiredFeatures) Error() string {
return fmt.Sprintf("remote init has unknown required features: %v",
e.unknownFeatures)
}

@ -0,0 +1,106 @@
package wtwire_test
import (
"testing"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/wtwire"
)
var (
testnetChainHash = *chaincfg.TestNet3Params.GenesisHash
mainnetChainHash = *chaincfg.MainNetParams.GenesisHash
)
type checkRemoteInitTest struct {
name string
lFeatures *lnwire.RawFeatureVector
lHash chainhash.Hash
rFeatures *lnwire.RawFeatureVector
rHash chainhash.Hash
expErr error
}
var checkRemoteInitTests = []checkRemoteInitTest{
{
name: "same chain, local-optional remote-required",
lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional),
lHash: testnetChainHash,
rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired),
rHash: testnetChainHash,
},
{
name: "same chain, local-required remote-optional",
lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired),
lHash: testnetChainHash,
rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional),
rHash: testnetChainHash,
},
{
name: "different chain, local-optional remote-required",
lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional),
lHash: testnetChainHash,
rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired),
rHash: mainnetChainHash,
expErr: wtwire.NewErrUnknownChainHash(mainnetChainHash),
},
{
name: "different chain, local-required remote-optional",
lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional),
lHash: testnetChainHash,
rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired),
rHash: mainnetChainHash,
expErr: wtwire.NewErrUnknownChainHash(mainnetChainHash),
},
{
name: "same chain, remote-unknown-required",
lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional),
lHash: testnetChainHash,
rFeatures: lnwire.NewRawFeatureVector(lnwire.GossipQueriesRequired),
rHash: testnetChainHash,
expErr: wtwire.NewErrUnknownRequiredFeatures(
lnwire.GossipQueriesRequired,
),
},
}
// TestCheckRemoteInit asserts the behavior of CheckRemoteInit when called with
// the remote party's Init message and the default wtwire.Features. We assert
// the validity of advertised features from the perspective of both client and
// server, as well as failure cases such as differing chain hashes or unknown
// required features.
func TestCheckRemoteInit(t *testing.T) {
for _, test := range checkRemoteInitTests {
t.Run(test.name, func(t *testing.T) {
testCheckRemoteInit(t, test)
})
}
}
func testCheckRemoteInit(t *testing.T, test checkRemoteInitTest) {
localInit := wtwire.NewInitMessage(test.lFeatures, test.lHash)
remoteInit := wtwire.NewInitMessage(test.rFeatures, test.rHash)
err := localInit.CheckRemoteInit(remoteInit, wtwire.FeatureNames)
switch {
// Both non-nil, pass.
case err == nil && test.expErr == nil:
return
// One is nil and one is non-nil, fail.
default:
t.Fatalf("error mismatch, want: %v, got: %v", test.expErr, err)
// Both non-nil, assert same error type.
case err != nil && test.expErr != nil:
}
// Compare error strings to assert same type.
if err.Error() != test.expErr.Error() {
t.Fatalf("error mismatch, want: %v, got: %v",
test.expErr.Error(), err.Error())
}
}