Merge pull request #2615 from cfromknecht/wtwire-check-remote-init
wtwire: add CheckRemoteInit helper
This commit is contained in:
commit
e9fb6100f2
@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/btcsuite/btcutil"
|
||||
"github.com/lightningnetwork/lnd/tor"
|
||||
@ -35,6 +36,11 @@ var (
|
||||
// All nil-able elements with the Config must be set in order for the Watchtower
|
||||
// to function properly.
|
||||
type Config struct {
|
||||
// ChainHash identifies the chain that the watchtower will be monitoring
|
||||
// for breaches and that will be advertised in the server's Init message
|
||||
// to inbound clients.
|
||||
ChainHash chainhash.Hash
|
||||
|
||||
// BlockFetcher supports the ability to fetch blocks from the network by
|
||||
// hash.
|
||||
BlockFetcher lookout.BlockFetcher
|
||||
|
@ -78,6 +78,7 @@ func New(cfg *Config) (*Standalone, error) {
|
||||
|
||||
// Initialize the server with its required resources.
|
||||
server, err := wtserver.New(&wtserver.Config{
|
||||
ChainHash: cfg.ChainHash,
|
||||
DB: cfg.DB,
|
||||
NodePrivKey: cfg.NodePrivKey,
|
||||
Listeners: listeners,
|
||||
|
@ -71,8 +71,7 @@ type Server struct {
|
||||
clientMtx sync.RWMutex
|
||||
clients map[wtdb.SessionID]Peer
|
||||
|
||||
globalFeatures *lnwire.RawFeatureVector
|
||||
connFeatures *lnwire.RawFeatureVector
|
||||
localInit *wtwire.Init
|
||||
|
||||
wg sync.WaitGroup
|
||||
quit chan struct{}
|
||||
@ -82,16 +81,16 @@ type Server struct {
|
||||
// clients connecting to the listener addresses, and allows them to open
|
||||
// sessions and send state updates.
|
||||
func New(cfg *Config) (*Server, error) {
|
||||
connFeatures := lnwire.NewRawFeatureVector(
|
||||
wtwire.WtSessionsOptional,
|
||||
localInit := wtwire.NewInitMessage(
|
||||
lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional),
|
||||
cfg.ChainHash,
|
||||
)
|
||||
|
||||
s := &Server{
|
||||
cfg: cfg,
|
||||
clients: make(map[wtdb.SessionID]Peer),
|
||||
globalFeatures: lnwire.NewRawFeatureVector(),
|
||||
connFeatures: connFeatures,
|
||||
quit: make(chan struct{}),
|
||||
cfg: cfg,
|
||||
clients: make(map[wtdb.SessionID]Peer),
|
||||
localInit: localInit,
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
|
||||
connMgr, err := connmgr.New(&connmgr.Config{
|
||||
@ -209,17 +208,14 @@ func (s *Server) handleClient(peer Peer) {
|
||||
return
|
||||
}
|
||||
|
||||
localInit := wtwire.NewInitMessage(
|
||||
s.connFeatures, s.cfg.ChainHash,
|
||||
)
|
||||
|
||||
err = s.sendMessage(peer, localInit)
|
||||
err = s.sendMessage(peer, s.localInit)
|
||||
if err != nil {
|
||||
log.Errorf("Unable to send Init msg to %s: %v", id, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err = s.handleInit(localInit, remoteInit); err != nil {
|
||||
err = s.localInit.CheckRemoteInit(remoteInit, wtwire.FeatureNames)
|
||||
if err != nil {
|
||||
log.Errorf("Cannot support client %s: %v", id, err)
|
||||
return
|
||||
}
|
||||
@ -297,27 +293,6 @@ func (s *Server) handleClient(peer Peer) {
|
||||
}
|
||||
}
|
||||
|
||||
// handleInit accepts the local and remote Init messages, and verifies that the
|
||||
// client is not requesting any required features that are unknown to the tower.
|
||||
func (s *Server) handleInit(localInit, remoteInit *wtwire.Init) error {
|
||||
if localInit.ChainHash != remoteInit.ChainHash {
|
||||
return fmt.Errorf("Peer chain hash unknown: %x",
|
||||
remoteInit.ChainHash)
|
||||
}
|
||||
|
||||
remoteConnFeatures := lnwire.NewFeatureVector(
|
||||
remoteInit.ConnFeatures, wtwire.LocalFeatures,
|
||||
)
|
||||
|
||||
unknownLocalFeatures := remoteConnFeatures.UnknownRequiredFeatures()
|
||||
if len(unknownLocalFeatures) > 0 {
|
||||
return fmt.Errorf("Peer set unknown local feature bits: %v",
|
||||
unknownLocalFeatures)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleCreateSession processes a CreateSession message from the peer, and returns
|
||||
// a CreateSessionReply in response. This method will only succeed if no existing
|
||||
// session info is known about the session id. If an existing session is found,
|
||||
|
@ -2,13 +2,9 @@ package wtwire
|
||||
|
||||
import "github.com/lightningnetwork/lnd/lnwire"
|
||||
|
||||
// GlobalFeatures holds the globally advertised feature bits understood by
|
||||
// watchtower implementations.
|
||||
var GlobalFeatures map[lnwire.FeatureBit]string
|
||||
|
||||
// LocalFeatures holds the locally advertised feature bits understood by
|
||||
// watchtower implementations.
|
||||
var LocalFeatures = map[lnwire.FeatureBit]string{
|
||||
// FeatureNames holds a mapping from each feature bit understood by this
|
||||
// implementation to its common name.
|
||||
var FeatureNames = map[lnwire.FeatureBit]string{
|
||||
WtSessionsRequired: "wt-sessions",
|
||||
WtSessionsOptional: "wt-sessions",
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package wtwire
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
@ -72,3 +73,67 @@ func (msg *Init) MaxPayloadLength(uint32) uint32 {
|
||||
|
||||
// A compile-time constraint to ensure Init implements the Message interface.
|
||||
var _ Message = (*Init)(nil)
|
||||
|
||||
// CheckRemoteInit performs basic validation of the remote party's Init message.
|
||||
// This method checks that the remote Init's chain hash matches our advertised
|
||||
// chain hash and that the remote Init does not contain any required feature
|
||||
// bits that we don't understand.
|
||||
func (msg *Init) CheckRemoteInit(remoteInit *Init,
|
||||
featureNames map[lnwire.FeatureBit]string) error {
|
||||
|
||||
// Check that the remote peer is on the same chain.
|
||||
if msg.ChainHash != remoteInit.ChainHash {
|
||||
return NewErrUnknownChainHash(remoteInit.ChainHash)
|
||||
}
|
||||
|
||||
remoteConnFeatures := lnwire.NewFeatureVector(
|
||||
remoteInit.ConnFeatures, featureNames,
|
||||
)
|
||||
|
||||
// Check that the remote peer doesn't have any required connection
|
||||
// feature bits that we ourselves are unaware of.
|
||||
unknownConnFeatures := remoteConnFeatures.UnknownRequiredFeatures()
|
||||
if len(unknownConnFeatures) > 0 {
|
||||
return NewErrUnknownRequiredFeatures(unknownConnFeatures...)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ErrUnknownChainHash signals that the remote Init has a different chain hash
|
||||
// from the one we advertised.
|
||||
type ErrUnknownChainHash struct {
|
||||
hash chainhash.Hash
|
||||
}
|
||||
|
||||
// NewErrUnknownChainHash creates an ErrUnknownChainHash using the remote Init's
|
||||
// chain hash.
|
||||
func NewErrUnknownChainHash(hash chainhash.Hash) *ErrUnknownChainHash {
|
||||
return &ErrUnknownChainHash{hash}
|
||||
}
|
||||
|
||||
// Error returns a human-readable error displaying the unknown chain hash.
|
||||
func (e *ErrUnknownChainHash) Error() string {
|
||||
return fmt.Sprintf("remote init has unknown chain hash: %s", e.hash)
|
||||
}
|
||||
|
||||
// ErrUnknownRequiredFeatures signals that the remote Init has required feature
|
||||
// bits that were unknown to us.
|
||||
type ErrUnknownRequiredFeatures struct {
|
||||
unknownFeatures []lnwire.FeatureBit
|
||||
}
|
||||
|
||||
// NewErrUnknownRequiredFeatures creates an ErrUnknownRequiredFeatures using the
|
||||
// remote Init's required features that were unknown to us.
|
||||
func NewErrUnknownRequiredFeatures(
|
||||
unknownFeatures ...lnwire.FeatureBit) *ErrUnknownRequiredFeatures {
|
||||
|
||||
return &ErrUnknownRequiredFeatures{unknownFeatures}
|
||||
}
|
||||
|
||||
// Error returns a human-readable error displaying the unknown required feature
|
||||
// bits.
|
||||
func (e *ErrUnknownRequiredFeatures) Error() string {
|
||||
return fmt.Sprintf("remote init has unknown required features: %v",
|
||||
e.unknownFeatures)
|
||||
}
|
||||
|
106
watchtower/wtwire/init_test.go
Normal file
106
watchtower/wtwire/init_test.go
Normal file
@ -0,0 +1,106 @@
|
||||
package wtwire_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcd/chaincfg"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtwire"
|
||||
)
|
||||
|
||||
var (
|
||||
testnetChainHash = *chaincfg.TestNet3Params.GenesisHash
|
||||
mainnetChainHash = *chaincfg.MainNetParams.GenesisHash
|
||||
)
|
||||
|
||||
type checkRemoteInitTest struct {
|
||||
name string
|
||||
lFeatures *lnwire.RawFeatureVector
|
||||
lHash chainhash.Hash
|
||||
rFeatures *lnwire.RawFeatureVector
|
||||
rHash chainhash.Hash
|
||||
expErr error
|
||||
}
|
||||
|
||||
var checkRemoteInitTests = []checkRemoteInitTest{
|
||||
{
|
||||
name: "same chain, local-optional remote-required",
|
||||
lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional),
|
||||
lHash: testnetChainHash,
|
||||
rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired),
|
||||
rHash: testnetChainHash,
|
||||
},
|
||||
{
|
||||
name: "same chain, local-required remote-optional",
|
||||
lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired),
|
||||
lHash: testnetChainHash,
|
||||
rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional),
|
||||
rHash: testnetChainHash,
|
||||
},
|
||||
{
|
||||
name: "different chain, local-optional remote-required",
|
||||
lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional),
|
||||
lHash: testnetChainHash,
|
||||
rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired),
|
||||
rHash: mainnetChainHash,
|
||||
expErr: wtwire.NewErrUnknownChainHash(mainnetChainHash),
|
||||
},
|
||||
{
|
||||
name: "different chain, local-required remote-optional",
|
||||
lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional),
|
||||
lHash: testnetChainHash,
|
||||
rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired),
|
||||
rHash: mainnetChainHash,
|
||||
expErr: wtwire.NewErrUnknownChainHash(mainnetChainHash),
|
||||
},
|
||||
{
|
||||
name: "same chain, remote-unknown-required",
|
||||
lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional),
|
||||
lHash: testnetChainHash,
|
||||
rFeatures: lnwire.NewRawFeatureVector(lnwire.GossipQueriesRequired),
|
||||
rHash: testnetChainHash,
|
||||
expErr: wtwire.NewErrUnknownRequiredFeatures(
|
||||
lnwire.GossipQueriesRequired,
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
// TestCheckRemoteInit asserts the behavior of CheckRemoteInit when called with
|
||||
// the remote party's Init message and the default wtwire.Features. We assert
|
||||
// the validity of advertised features from the perspective of both client and
|
||||
// server, as well as failure cases such as differing chain hashes or unknown
|
||||
// required features.
|
||||
func TestCheckRemoteInit(t *testing.T) {
|
||||
for _, test := range checkRemoteInitTests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
testCheckRemoteInit(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testCheckRemoteInit(t *testing.T, test checkRemoteInitTest) {
|
||||
localInit := wtwire.NewInitMessage(test.lFeatures, test.lHash)
|
||||
remoteInit := wtwire.NewInitMessage(test.rFeatures, test.rHash)
|
||||
|
||||
err := localInit.CheckRemoteInit(remoteInit, wtwire.FeatureNames)
|
||||
switch {
|
||||
|
||||
// Both non-nil, pass.
|
||||
case err == nil && test.expErr == nil:
|
||||
return
|
||||
|
||||
// One is nil and one is non-nil, fail.
|
||||
default:
|
||||
t.Fatalf("error mismatch, want: %v, got: %v", test.expErr, err)
|
||||
|
||||
// Both non-nil, assert same error type.
|
||||
case err != nil && test.expErr != nil:
|
||||
}
|
||||
|
||||
// Compare error strings to assert same type.
|
||||
if err.Error() != test.expErr.Error() {
|
||||
t.Fatalf("error mismatch, want: %v, got: %v",
|
||||
test.expErr.Error(), err.Error())
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user