watchtower/multi: send connection features + chain hash in Init

This commit is contained in:
Conner Fromknecht 2019-02-06 20:06:44 -08:00
parent eaea92e2cf
commit 4dbade64dd
No known key found for this signature in database
GPG Key ID: E7D737B67FA592C7
6 changed files with 149 additions and 67 deletions

@ -10,6 +10,7 @@ import (
"time"
"github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/connmgr"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcutil"
@ -51,6 +52,9 @@ type Config struct {
// NewAddress is used to generate reward addresses, where a cut of
// successfully sent funds can be received.
NewAddress func() (btcutil.Address, error)
// ChainHash identifies the network that the server is watching.
ChainHash chainhash.Hash
}
// Server houses the state required to handle watchtower peers. It's primary job
@ -68,7 +72,7 @@ type Server struct {
clients map[wtdb.SessionID]Peer
globalFeatures *lnwire.RawFeatureVector
localFeatures *lnwire.RawFeatureVector
connFeatures *lnwire.RawFeatureVector
wg sync.WaitGroup
quit chan struct{}
@ -78,7 +82,7 @@ type Server struct {
// clients connecting to the listener addresses, and allows them to open
// sessions and send state updates.
func New(cfg *Config) (*Server, error) {
localFeatures := lnwire.NewRawFeatureVector(
connFeatures := lnwire.NewRawFeatureVector(
wtwire.WtSessionsOptional,
)
@ -86,7 +90,7 @@ func New(cfg *Config) (*Server, error) {
cfg: cfg,
clients: make(map[wtdb.SessionID]Peer),
globalFeatures: lnwire.NewRawFeatureVector(),
localFeatures: localFeatures,
connFeatures: connFeatures,
quit: make(chan struct{}),
}
@ -206,7 +210,7 @@ func (s *Server) handleClient(peer Peer) {
}
localInit := wtwire.NewInitMessage(
s.localFeatures, s.globalFeatures,
s.connFeatures, s.cfg.ChainHash,
)
err = s.sendMessage(peer, localInit)
@ -296,25 +300,19 @@ func (s *Server) handleClient(peer Peer) {
// handleInit accepts the local and remote Init messages, and verifies that the
// client is not requesting any required features that are unknown to the tower.
func (s *Server) handleInit(localInit, remoteInit *wtwire.Init) error {
remoteLocalFeatures := lnwire.NewFeatureVector(
remoteInit.LocalFeatures, wtwire.LocalFeatures,
)
remoteGlobalFeatures := lnwire.NewFeatureVector(
remoteInit.GlobalFeatures, wtwire.GlobalFeatures,
)
unknownLocalFeatures := remoteLocalFeatures.UnknownRequiredFeatures()
if len(unknownLocalFeatures) > 0 {
err := fmt.Errorf("Peer set unknown local feature bits: %v",
unknownLocalFeatures)
return err
if localInit.ChainHash != remoteInit.ChainHash {
return fmt.Errorf("Peer chain hash unknown: %x",
remoteInit.ChainHash)
}
unknownGlobalFeatures := remoteGlobalFeatures.UnknownRequiredFeatures()
if len(unknownGlobalFeatures) > 0 {
err := fmt.Errorf("Peer set unknown global feature bits: %v",
unknownGlobalFeatures)
return err
remoteConnFeatures := lnwire.NewFeatureVector(
remoteInit.ConnFeatures, wtwire.LocalFeatures,
)
unknownLocalFeatures := remoteConnFeatures.UnknownRequiredFeatures()
if len(unknownLocalFeatures) > 0 {
return fmt.Errorf("Peer set unknown local feature bits: %v",
unknownLocalFeatures)
}
return nil

@ -27,6 +27,8 @@ var (
)
addrScript, _ = txscript.PayToAddrScript(addr)
testnetChainHash = *chaincfg.TestNet3Params.GenesisHash
)
// randPubKey generates a new secp keypair, and returns the public key.
@ -59,6 +61,7 @@ func initServer(t *testing.T, db wtserver.DB,
NewAddress: func() (btcutil.Address, error) {
return addr, nil
},
ChainHash: testnetChainHash,
})
if err != nil {
t.Fatalf("unable to create server: %v", err)
@ -91,7 +94,7 @@ func TestServerOnlyAcceptOnePeer(t *testing.T) {
// Serialize a Init message to be sent by both peers.
init := wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(), lnwire.NewRawFeatureVector(),
lnwire.NewRawFeatureVector(), testnetChainHash,
)
var b bytes.Buffer
@ -159,7 +162,7 @@ var createSessionTests = []createSessionTestCase{
name: "reject duplicate session create",
initMsg: wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(),
lnwire.NewRawFeatureVector(),
testnetChainHash,
),
createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault,
@ -181,7 +184,7 @@ var createSessionTests = []createSessionTestCase{
name: "reject unsupported blob type",
initMsg: wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(),
lnwire.NewRawFeatureVector(),
testnetChainHash,
),
createMsg: &wtwire.CreateSession{
BlobType: 0,
@ -279,10 +282,10 @@ var stateUpdateTests = []stateUpdateTestCase{
// Valid update sequence, send seqnum == lastapplied as last update.
{
name: "perm fail after sending seqnum equal lastapplied",
initMsg: &wtwire.Init{&lnwire.Init{
LocalFeatures: lnwire.NewRawFeatureVector(),
GlobalFeatures: lnwire.NewRawFeatureVector(),
}},
initMsg: wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(),
testnetChainHash,
),
createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault,
MaxUpdates: 3,
@ -309,10 +312,10 @@ var stateUpdateTests = []stateUpdateTestCase{
// Send update that skips next expected sequence number.
{
name: "skip sequence number",
initMsg: &wtwire.Init{&lnwire.Init{
LocalFeatures: lnwire.NewRawFeatureVector(),
GlobalFeatures: lnwire.NewRawFeatureVector(),
}},
initMsg: wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(),
testnetChainHash,
),
createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault,
MaxUpdates: 4,
@ -333,10 +336,10 @@ var stateUpdateTests = []stateUpdateTestCase{
// Send update that reverts to older sequence number.
{
name: "revert to older seqnum",
initMsg: &wtwire.Init{&lnwire.Init{
LocalFeatures: lnwire.NewRawFeatureVector(),
GlobalFeatures: lnwire.NewRawFeatureVector(),
}},
initMsg: wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(),
testnetChainHash,
),
createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault,
MaxUpdates: 4,
@ -361,10 +364,10 @@ var stateUpdateTests = []stateUpdateTestCase{
// Send update echoing a last applied that is lower than previous value.
{
name: "revert to older lastapplied",
initMsg: &wtwire.Init{&lnwire.Init{
LocalFeatures: lnwire.NewRawFeatureVector(),
GlobalFeatures: lnwire.NewRawFeatureVector(),
}},
initMsg: wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(),
testnetChainHash,
),
createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault,
MaxUpdates: 4,
@ -389,10 +392,10 @@ var stateUpdateTests = []stateUpdateTestCase{
// Client echos last applied as they are received.
{
name: "resume after disconnect",
initMsg: &wtwire.Init{&lnwire.Init{
LocalFeatures: lnwire.NewRawFeatureVector(),
GlobalFeatures: lnwire.NewRawFeatureVector(),
}},
initMsg: wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(),
testnetChainHash,
),
createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault,
MaxUpdates: 4,
@ -419,10 +422,10 @@ var stateUpdateTests = []stateUpdateTestCase{
// Client doesn't echo last applied until last message.
{
name: "resume after disconnect lagging lastapplied",
initMsg: &wtwire.Init{&lnwire.Init{
LocalFeatures: lnwire.NewRawFeatureVector(),
GlobalFeatures: lnwire.NewRawFeatureVector(),
}},
initMsg: wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(),
testnetChainHash,
),
createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault,
MaxUpdates: 4,
@ -448,10 +451,10 @@ var stateUpdateTests = []stateUpdateTestCase{
// Send update with sequence number that exceeds MaxUpdates.
{
name: "seqnum exceed maxupdates",
initMsg: &wtwire.Init{&lnwire.Init{
LocalFeatures: lnwire.NewRawFeatureVector(),
GlobalFeatures: lnwire.NewRawFeatureVector(),
}},
initMsg: wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(),
testnetChainHash,
),
createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault,
MaxUpdates: 3,
@ -478,10 +481,10 @@ var stateUpdateTests = []stateUpdateTestCase{
// Ensure sequence number 0 causes permanent failure.
{
name: "perm fail after seqnum 0",
initMsg: &wtwire.Init{&lnwire.Init{
LocalFeatures: lnwire.NewRawFeatureVector(),
GlobalFeatures: lnwire.NewRawFeatureVector(),
}},
initMsg: wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(),
testnetChainHash,
),
createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault,
MaxUpdates: 3,

@ -1,23 +1,59 @@
package wtwire
import "github.com/lightningnetwork/lnd/lnwire"
import (
"io"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/lnwire"
)
// Init is the first message sent over the watchtower wire protocol, and
// specifies features and level of requiredness maintained by the sending node.
// The watchtower Init message is identical to the LN Init message, except it
// uses a different message type to ensure the two are not conflated.
// specifies connection features bits and level of requiredness maintained by
// the sending node. The Init message also sends the chain hash identifying the
// network that the sender is on.
type Init struct {
*lnwire.Init
// ConnFeatures are the feature bits being advertised for the duration
// of a single connection with a peer.
ConnFeatures *lnwire.RawFeatureVector
// ChainHash is the genesis hash of the chain that the advertiser claims
// to be on.
ChainHash chainhash.Hash
}
// NewInitMessage generates a new Init message from raw global and local feature
// vectors.
func NewInitMessage(gf, lf *lnwire.RawFeatureVector) *Init {
// NewInitMessage generates a new Init message from a raw connection feature
// vector and chain hash.
func NewInitMessage(connFeatures *lnwire.RawFeatureVector,
chainHash chainhash.Hash) *Init {
return &Init{
Init: lnwire.NewInitMessage(gf, lf),
ConnFeatures: connFeatures,
ChainHash: chainHash,
}
}
// Encode serializes the target Init into the passed io.Writer observing the
// protocol version specified.
//
// This is part of the wtwire.Message interface.
func (msg *Init) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
msg.ConnFeatures,
msg.ChainHash,
)
}
// Decode deserializes a serialized Init message stored in the passed io.Reader
// observing the specified protocol version.
//
// This is part of the wtwire.Message interface.
func (msg *Init) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&msg.ConnFeatures,
&msg.ChainHash,
)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
@ -26,5 +62,13 @@ func (msg *Init) MsgType() MessageType {
return MsgInit
}
// MaxPayloadLength returns the maximum allowed payload size for an Init
// complete message observing the specified protocol version.
//
// This is part of the wtwire.Message interface.
func (msg *Init) MaxPayloadLength(uint32) uint32 {
return MaxMessagePayload
}
// A compile-time constraint to ensure Init implements the Message interface.
var _ Message = (*Init)(nil)

@ -88,7 +88,7 @@ func makeEmptyMessage(msgType MessageType) (Message, error) {
switch msgType {
case MsgInit:
msg = &Init{&lnwire.Init{}}
msg = &Init{}
case MsgCreateSession:
msg = &CreateSession{}
case MsgCreateSessionReply:

@ -6,8 +6,10 @@ import (
"io"
"github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/blob"
)
@ -86,6 +88,20 @@ func WriteElement(w io.Writer, element interface{}) error {
return err
}
case chainhash.Hash:
if _, err := w.Write(e[:]); err != nil {
return err
}
case *lnwire.RawFeatureVector:
if e == nil {
return fmt.Errorf("cannot write nil feature vector")
}
if err := e.Encode(w); err != nil {
return err
}
case *btcec.PublicKey:
if e == nil {
return fmt.Errorf("cannot write nil pubkey")
@ -192,6 +208,20 @@ func ReadElement(r io.Reader, element interface{}) error {
}
*e = ErrorCode(binary.BigEndian.Uint16(b[:]))
case *chainhash.Hash:
if _, err := io.ReadFull(r, e[:]); err != nil {
return err
}
case **lnwire.RawFeatureVector:
f := lnwire.NewRawFeatureVector()
err := f.Decode(r)
if err != nil {
return err
}
*e = f
case **btcec.PublicKey:
var b [btcec.PubKeyBytesLenCompressed]byte
if _, err := io.ReadFull(r, b[:]); err != nil {

@ -8,6 +8,7 @@ import (
"testing/quick"
"time"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/wtwire"
@ -23,6 +24,12 @@ func randRawFeatureVector(r *rand.Rand) *lnwire.RawFeatureVector {
return featureVec
}
func randChainHash(r *rand.Rand) chainhash.Hash {
var hash chainhash.Hash
r.Read(hash[:])
return hash
}
// TestWatchtowerWireProtocol uses the testing/quick package to create a series
// of fuzz tests to attempt to break a primary scenario which is implemented as
// property based testing scenario.
@ -73,7 +80,7 @@ func TestWatchtowerWireProtocol(t *testing.T) {
wtwire.MsgInit: func(v []reflect.Value, r *rand.Rand) {
req := wtwire.NewInitMessage(
randRawFeatureVector(r),
randRawFeatureVector(r),
randChainHash(r),
)
v[0] = reflect.ValueOf(*req)