Merge pull request #2615 from cfromknecht/wtwire-check-remote-init
wtwire: add CheckRemoteInit helper
This commit is contained in:
commit
e9fb6100f2
@ -6,6 +6,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/btcsuite/btcd/btcec"
|
"github.com/btcsuite/btcd/btcec"
|
||||||
|
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||||
"github.com/btcsuite/btcd/wire"
|
"github.com/btcsuite/btcd/wire"
|
||||||
"github.com/btcsuite/btcutil"
|
"github.com/btcsuite/btcutil"
|
||||||
"github.com/lightningnetwork/lnd/tor"
|
"github.com/lightningnetwork/lnd/tor"
|
||||||
@ -35,6 +36,11 @@ var (
|
|||||||
// All nil-able elements with the Config must be set in order for the Watchtower
|
// All nil-able elements with the Config must be set in order for the Watchtower
|
||||||
// to function properly.
|
// to function properly.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
|
// ChainHash identifies the chain that the watchtower will be monitoring
|
||||||
|
// for breaches and that will be advertised in the server's Init message
|
||||||
|
// to inbound clients.
|
||||||
|
ChainHash chainhash.Hash
|
||||||
|
|
||||||
// BlockFetcher supports the ability to fetch blocks from the network by
|
// BlockFetcher supports the ability to fetch blocks from the network by
|
||||||
// hash.
|
// hash.
|
||||||
BlockFetcher lookout.BlockFetcher
|
BlockFetcher lookout.BlockFetcher
|
||||||
|
@ -78,6 +78,7 @@ func New(cfg *Config) (*Standalone, error) {
|
|||||||
|
|
||||||
// Initialize the server with its required resources.
|
// Initialize the server with its required resources.
|
||||||
server, err := wtserver.New(&wtserver.Config{
|
server, err := wtserver.New(&wtserver.Config{
|
||||||
|
ChainHash: cfg.ChainHash,
|
||||||
DB: cfg.DB,
|
DB: cfg.DB,
|
||||||
NodePrivKey: cfg.NodePrivKey,
|
NodePrivKey: cfg.NodePrivKey,
|
||||||
Listeners: listeners,
|
Listeners: listeners,
|
||||||
|
@ -71,8 +71,7 @@ type Server struct {
|
|||||||
clientMtx sync.RWMutex
|
clientMtx sync.RWMutex
|
||||||
clients map[wtdb.SessionID]Peer
|
clients map[wtdb.SessionID]Peer
|
||||||
|
|
||||||
globalFeatures *lnwire.RawFeatureVector
|
localInit *wtwire.Init
|
||||||
connFeatures *lnwire.RawFeatureVector
|
|
||||||
|
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
quit chan struct{}
|
quit chan struct{}
|
||||||
@ -82,16 +81,16 @@ type Server struct {
|
|||||||
// clients connecting to the listener addresses, and allows them to open
|
// clients connecting to the listener addresses, and allows them to open
|
||||||
// sessions and send state updates.
|
// sessions and send state updates.
|
||||||
func New(cfg *Config) (*Server, error) {
|
func New(cfg *Config) (*Server, error) {
|
||||||
connFeatures := lnwire.NewRawFeatureVector(
|
localInit := wtwire.NewInitMessage(
|
||||||
wtwire.WtSessionsOptional,
|
lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional),
|
||||||
|
cfg.ChainHash,
|
||||||
)
|
)
|
||||||
|
|
||||||
s := &Server{
|
s := &Server{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
clients: make(map[wtdb.SessionID]Peer),
|
clients: make(map[wtdb.SessionID]Peer),
|
||||||
globalFeatures: lnwire.NewRawFeatureVector(),
|
localInit: localInit,
|
||||||
connFeatures: connFeatures,
|
quit: make(chan struct{}),
|
||||||
quit: make(chan struct{}),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
connMgr, err := connmgr.New(&connmgr.Config{
|
connMgr, err := connmgr.New(&connmgr.Config{
|
||||||
@ -209,17 +208,14 @@ func (s *Server) handleClient(peer Peer) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
localInit := wtwire.NewInitMessage(
|
err = s.sendMessage(peer, s.localInit)
|
||||||
s.connFeatures, s.cfg.ChainHash,
|
|
||||||
)
|
|
||||||
|
|
||||||
err = s.sendMessage(peer, localInit)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Unable to send Init msg to %s: %v", id, err)
|
log.Errorf("Unable to send Init msg to %s: %v", id, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = s.handleInit(localInit, remoteInit); err != nil {
|
err = s.localInit.CheckRemoteInit(remoteInit, wtwire.FeatureNames)
|
||||||
|
if err != nil {
|
||||||
log.Errorf("Cannot support client %s: %v", id, err)
|
log.Errorf("Cannot support client %s: %v", id, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -297,27 +293,6 @@ func (s *Server) handleClient(peer Peer) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleInit accepts the local and remote Init messages, and verifies that the
|
|
||||||
// client is not requesting any required features that are unknown to the tower.
|
|
||||||
func (s *Server) handleInit(localInit, remoteInit *wtwire.Init) error {
|
|
||||||
if localInit.ChainHash != remoteInit.ChainHash {
|
|
||||||
return fmt.Errorf("Peer chain hash unknown: %x",
|
|
||||||
remoteInit.ChainHash)
|
|
||||||
}
|
|
||||||
|
|
||||||
remoteConnFeatures := lnwire.NewFeatureVector(
|
|
||||||
remoteInit.ConnFeatures, wtwire.LocalFeatures,
|
|
||||||
)
|
|
||||||
|
|
||||||
unknownLocalFeatures := remoteConnFeatures.UnknownRequiredFeatures()
|
|
||||||
if len(unknownLocalFeatures) > 0 {
|
|
||||||
return fmt.Errorf("Peer set unknown local feature bits: %v",
|
|
||||||
unknownLocalFeatures)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleCreateSession processes a CreateSession message from the peer, and returns
|
// handleCreateSession processes a CreateSession message from the peer, and returns
|
||||||
// a CreateSessionReply in response. This method will only succeed if no existing
|
// a CreateSessionReply in response. This method will only succeed if no existing
|
||||||
// session info is known about the session id. If an existing session is found,
|
// session info is known about the session id. If an existing session is found,
|
||||||
|
@ -2,13 +2,9 @@ package wtwire
|
|||||||
|
|
||||||
import "github.com/lightningnetwork/lnd/lnwire"
|
import "github.com/lightningnetwork/lnd/lnwire"
|
||||||
|
|
||||||
// GlobalFeatures holds the globally advertised feature bits understood by
|
// FeatureNames holds a mapping from each feature bit understood by this
|
||||||
// watchtower implementations.
|
// implementation to its common name.
|
||||||
var GlobalFeatures map[lnwire.FeatureBit]string
|
var FeatureNames = map[lnwire.FeatureBit]string{
|
||||||
|
|
||||||
// LocalFeatures holds the locally advertised feature bits understood by
|
|
||||||
// watchtower implementations.
|
|
||||||
var LocalFeatures = map[lnwire.FeatureBit]string{
|
|
||||||
WtSessionsRequired: "wt-sessions",
|
WtSessionsRequired: "wt-sessions",
|
||||||
WtSessionsOptional: "wt-sessions",
|
WtSessionsOptional: "wt-sessions",
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package wtwire
|
package wtwire
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||||
@ -72,3 +73,67 @@ func (msg *Init) MaxPayloadLength(uint32) uint32 {
|
|||||||
|
|
||||||
// A compile-time constraint to ensure Init implements the Message interface.
|
// A compile-time constraint to ensure Init implements the Message interface.
|
||||||
var _ Message = (*Init)(nil)
|
var _ Message = (*Init)(nil)
|
||||||
|
|
||||||
|
// CheckRemoteInit performs basic validation of the remote party's Init message.
|
||||||
|
// This method checks that the remote Init's chain hash matches our advertised
|
||||||
|
// chain hash and that the remote Init does not contain any required feature
|
||||||
|
// bits that we don't understand.
|
||||||
|
func (msg *Init) CheckRemoteInit(remoteInit *Init,
|
||||||
|
featureNames map[lnwire.FeatureBit]string) error {
|
||||||
|
|
||||||
|
// Check that the remote peer is on the same chain.
|
||||||
|
if msg.ChainHash != remoteInit.ChainHash {
|
||||||
|
return NewErrUnknownChainHash(remoteInit.ChainHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
remoteConnFeatures := lnwire.NewFeatureVector(
|
||||||
|
remoteInit.ConnFeatures, featureNames,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Check that the remote peer doesn't have any required connection
|
||||||
|
// feature bits that we ourselves are unaware of.
|
||||||
|
unknownConnFeatures := remoteConnFeatures.UnknownRequiredFeatures()
|
||||||
|
if len(unknownConnFeatures) > 0 {
|
||||||
|
return NewErrUnknownRequiredFeatures(unknownConnFeatures...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrUnknownChainHash signals that the remote Init has a different chain hash
|
||||||
|
// from the one we advertised.
|
||||||
|
type ErrUnknownChainHash struct {
|
||||||
|
hash chainhash.Hash
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewErrUnknownChainHash creates an ErrUnknownChainHash using the remote Init's
|
||||||
|
// chain hash.
|
||||||
|
func NewErrUnknownChainHash(hash chainhash.Hash) *ErrUnknownChainHash {
|
||||||
|
return &ErrUnknownChainHash{hash}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error returns a human-readable error displaying the unknown chain hash.
|
||||||
|
func (e *ErrUnknownChainHash) Error() string {
|
||||||
|
return fmt.Sprintf("remote init has unknown chain hash: %s", e.hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrUnknownRequiredFeatures signals that the remote Init has required feature
|
||||||
|
// bits that were unknown to us.
|
||||||
|
type ErrUnknownRequiredFeatures struct {
|
||||||
|
unknownFeatures []lnwire.FeatureBit
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewErrUnknownRequiredFeatures creates an ErrUnknownRequiredFeatures using the
|
||||||
|
// remote Init's required features that were unknown to us.
|
||||||
|
func NewErrUnknownRequiredFeatures(
|
||||||
|
unknownFeatures ...lnwire.FeatureBit) *ErrUnknownRequiredFeatures {
|
||||||
|
|
||||||
|
return &ErrUnknownRequiredFeatures{unknownFeatures}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error returns a human-readable error displaying the unknown required feature
|
||||||
|
// bits.
|
||||||
|
func (e *ErrUnknownRequiredFeatures) Error() string {
|
||||||
|
return fmt.Sprintf("remote init has unknown required features: %v",
|
||||||
|
e.unknownFeatures)
|
||||||
|
}
|
||||||
|
106
watchtower/wtwire/init_test.go
Normal file
106
watchtower/wtwire/init_test.go
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
package wtwire_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/chaincfg"
|
||||||
|
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||||
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower/wtwire"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
testnetChainHash = *chaincfg.TestNet3Params.GenesisHash
|
||||||
|
mainnetChainHash = *chaincfg.MainNetParams.GenesisHash
|
||||||
|
)
|
||||||
|
|
||||||
|
type checkRemoteInitTest struct {
|
||||||
|
name string
|
||||||
|
lFeatures *lnwire.RawFeatureVector
|
||||||
|
lHash chainhash.Hash
|
||||||
|
rFeatures *lnwire.RawFeatureVector
|
||||||
|
rHash chainhash.Hash
|
||||||
|
expErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
var checkRemoteInitTests = []checkRemoteInitTest{
|
||||||
|
{
|
||||||
|
name: "same chain, local-optional remote-required",
|
||||||
|
lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional),
|
||||||
|
lHash: testnetChainHash,
|
||||||
|
rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired),
|
||||||
|
rHash: testnetChainHash,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "same chain, local-required remote-optional",
|
||||||
|
lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired),
|
||||||
|
lHash: testnetChainHash,
|
||||||
|
rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional),
|
||||||
|
rHash: testnetChainHash,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "different chain, local-optional remote-required",
|
||||||
|
lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional),
|
||||||
|
lHash: testnetChainHash,
|
||||||
|
rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired),
|
||||||
|
rHash: mainnetChainHash,
|
||||||
|
expErr: wtwire.NewErrUnknownChainHash(mainnetChainHash),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "different chain, local-required remote-optional",
|
||||||
|
lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional),
|
||||||
|
lHash: testnetChainHash,
|
||||||
|
rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired),
|
||||||
|
rHash: mainnetChainHash,
|
||||||
|
expErr: wtwire.NewErrUnknownChainHash(mainnetChainHash),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "same chain, remote-unknown-required",
|
||||||
|
lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional),
|
||||||
|
lHash: testnetChainHash,
|
||||||
|
rFeatures: lnwire.NewRawFeatureVector(lnwire.GossipQueriesRequired),
|
||||||
|
rHash: testnetChainHash,
|
||||||
|
expErr: wtwire.NewErrUnknownRequiredFeatures(
|
||||||
|
lnwire.GossipQueriesRequired,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCheckRemoteInit asserts the behavior of CheckRemoteInit when called with
|
||||||
|
// the remote party's Init message and the default wtwire.Features. We assert
|
||||||
|
// the validity of advertised features from the perspective of both client and
|
||||||
|
// server, as well as failure cases such as differing chain hashes or unknown
|
||||||
|
// required features.
|
||||||
|
func TestCheckRemoteInit(t *testing.T) {
|
||||||
|
for _, test := range checkRemoteInitTests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
testCheckRemoteInit(t, test)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testCheckRemoteInit(t *testing.T, test checkRemoteInitTest) {
|
||||||
|
localInit := wtwire.NewInitMessage(test.lFeatures, test.lHash)
|
||||||
|
remoteInit := wtwire.NewInitMessage(test.rFeatures, test.rHash)
|
||||||
|
|
||||||
|
err := localInit.CheckRemoteInit(remoteInit, wtwire.FeatureNames)
|
||||||
|
switch {
|
||||||
|
|
||||||
|
// Both non-nil, pass.
|
||||||
|
case err == nil && test.expErr == nil:
|
||||||
|
return
|
||||||
|
|
||||||
|
// One is nil and one is non-nil, fail.
|
||||||
|
default:
|
||||||
|
t.Fatalf("error mismatch, want: %v, got: %v", test.expErr, err)
|
||||||
|
|
||||||
|
// Both non-nil, assert same error type.
|
||||||
|
case err != nil && test.expErr != nil:
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compare error strings to assert same type.
|
||||||
|
if err.Error() != test.expErr.Error() {
|
||||||
|
t.Fatalf("error mismatch, want: %v, got: %v",
|
||||||
|
test.expErr.Error(), err.Error())
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user