Merge pull request #1512 from cfromknecht/wtwire

[watchtower/wtwire]: Watchtower Wire Messages
This commit is contained in:
Olaoluwa Osuntokun 2018-10-24 20:17:02 -07:00 committed by GitHub
commit 463d352fa8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 1022 additions and 0 deletions

@ -0,0 +1,78 @@
package wtwire
import (
"io"
"github.com/lightningnetwork/lnd/lnwallet"
)
// CreateSession is sent from a client to tower when to negotiate a session, which
// specifies the total number of updates that can be made, as well as fee rates.
// An update is consumed by uploading an encrypted blob that contains
// information required to sweep a revoked commitment transaction.
type CreateSession struct {
// BlobVersion specifies the blob format that must be used by all
// updates sent under the session key used to negotiate this session.
BlobVersion uint16
// MaxUpdates is the maximum number of updates the watchtower will honor
// for this session.
MaxUpdates uint16
// RewardRate is the fraction of the total balance of the revoked
// commitment that the watchtower is entitled to. This value is
// expressed in millionths of the total balance.
RewardRate uint32
// SweepFeeRate expresses the intended fee rate to be used when
// constructing the justice transaction. All sweep transactions created
// for this session must use this value during construction, and the
// signatures must implicitly commit to the resulting output values.
SweepFeeRate lnwallet.SatPerKWeight
}
// A compile time check to ensure CreateSession implements the wtwire.Message
// interface.
var _ Message = (*CreateSession)(nil)
// Decode deserializes a serialized CreateSession message stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the wtwire.Message interface.
func (m *CreateSession) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&m.BlobVersion,
&m.MaxUpdates,
&m.RewardRate,
&m.SweepFeeRate,
)
}
// Encode serializes the target CreateSession into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the wtwire.Message interface.
func (m *CreateSession) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
m.BlobVersion,
m.MaxUpdates,
m.RewardRate,
m.SweepFeeRate,
)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the wtwire.Message interface.
func (m *CreateSession) MsgType() MessageType {
return MsgCreateSession
}
// MaxPayloadLength returns the maximum allowed payload size for a CreateSession
// complete message observing the specified protocol version.
//
// This is part of the wtwire.Message interface.
func (m *CreateSession) MaxPayloadLength(uint32) uint32 {
return 16
}

@ -0,0 +1,91 @@
package wtwire
import "io"
// CreateSessionCode is an error code returned by a watchtower in response to a
// CreateSession message. The code directs the client in interpreting the payload
// in the reply.
type CreateSessionCode = ErrorCode
const (
// CreateSessionCodeAlreadyExists is returned when a session is already
// active for the public key used to connect to the watchtower. The
// response includes the serialized reward address in case the original
// reply was never received and/or processed by the client.
CreateSessionCodeAlreadyExists CreateSessionCode = 60
// CreateSessionCodeRejectRejectMaxUpdates the tower rejected the maximum
// number of state updates proposed by the client.
CreateSessionCodeRejectRejectMaxUpdates CreateSessionCode = 61
// CreateSessionCodeRejectRewardRate the tower rejected the reward rate
// proposed by the client.
CreateSessionCodeRejectRewardRate CreateSessionCode = 62
// CreateSessionCodeRejectSweepFeeRate the tower rejected the sweep fee
// rate proposed by the client.
CreateSessionCodeRejectSweepFeeRate CreateSessionCode = 63
)
// MaxCreateSessionReplyDataLength is the maximum size of the Data payload
// returned in a CreateSessionReply message. This does not include the length of
// the Data field, which is a varint up to 3 bytes in size.
const MaxCreateSessionReplyDataLength = 1024
// CreateSessionReply is a message sent from watchtower to client in response to a
// CreateSession message, and signals either an acceptance or rejection of the
// proposed session parameters.
type CreateSessionReply struct {
// Code will be non-zero if the watchtower rejected the session init.
Code CreateSessionCode
// Data is a byte slice returned the caller of the message, and is to be
// interpreted according to the error Code. When the response is
// CreateSessionCodeOK, data encodes the reward address to be included in
// any sweep transactions if the reward is not dusty. Otherwise, it may
// encode the watchtowers configured parameters for any policy
// rejections.
Data []byte
}
// A compile time check to ensure CreateSessionReply implements the wtwire.Message
// interface.
var _ Message = (*CreateSessionReply)(nil)
// Decode deserializes a serialized CreateSessionReply message stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the wtwire.Message interface.
func (m *CreateSessionReply) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&m.Code,
&m.Data,
)
}
// Encode serializes the target CreateSessionReply into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the wtwire.Message interface.
func (m *CreateSessionReply) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
m.Code,
m.Data,
)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the wtwire.Message interface.
func (m *CreateSessionReply) MsgType() MessageType {
return MsgCreateSessionReply
}
// MaxPayloadLength returns the maximum allowed payload size for a CreateSessionReply
// complete message observing the specified protocol version.
//
// This is part of the wtwire.Message interface.
func (m *CreateSessionReply) MaxPayloadLength(uint32) uint32 {
return 2 + 3 + MaxCreateSessionReplyDataLength
}

@ -0,0 +1,62 @@
package wtwire
import "io"
// Error is a generic error message that can be sent to a client if a request
// fails outside of prescribed protocol errors. Typically this would be followed
// by the server disconnecting the client, and so can be useful to transfering
// the exact reason.
type Error struct {
// Code specifies the error code encountered by the server.
Code ErrorCode
// Data encodes a payload whose contents can be interpreted by the
// client in response to the error code.
Data []byte
}
// NewError returns an freshly-initialized Error message.
func NewError() *Error {
return &Error{}
}
// A compile time check to ensure Error implements the wtwire.Message interface.
var _ Message = (*Error)(nil)
// Decode deserializes a serialized Error message stored in the passed io.Reader
// observing the specified protocol version.
//
// This is part of the wtwire.Message interface.
func (e *Error) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&e.Code,
&e.Data,
)
}
// Encode serializes the target Error into the passed io.Writer observing the
// protocol version specified.
//
// This is part of the wtwire.Message interface.
func (e *Error) Encode(w io.Writer, prver uint32) error {
return WriteElements(w,
e.Code,
e.Data,
)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the wtwire.Message interface.
func (e *Error) MsgType() MessageType {
return MsgError
}
// MaxPayloadLength returns the maximum allowed payload size for a Error
// complete message observing the specified protocol version.
//
// This is part of the wtwire.Message interface.
func (e *Error) MaxPayloadLength(uint32) uint32 {
return MaxMessagePayload
}

@ -0,0 +1,20 @@
package wtwire
// ErrorCode represents a generic error code used when replying to watchtower
// clients. Specific reply messages may extend the ErrorCode primitive and add
// custom codes, so long as they don't collide with the generic error codes..
type ErrorCode uint16
const (
// CodeOK signals that the request was successfully processed by the
// watchtower.
CodeOK ErrorCode = 0
// CodeTemporaryFailure alerts the client that the watchtower is
// temporarily unavailable, but that it may try again at a later time.
CodeTemporaryFailure ErrorCode = 40
// CodePermanentFailure alerts the client that the watchtower has
// permanently failed, and further communication should be avoided.
CodePermanentFailure ErrorCode = 50
)

@ -0,0 +1,26 @@
package wtwire
import "github.com/lightningnetwork/lnd/lnwire"
// GlobalFeatures holds the globally advertised feature bits understood by
// watchtower implementations.
var GlobalFeatures map[lnwire.FeatureBit]string
// LocalFeatures holds the locally advertised feature bits understood by
// watchtower implementations.
var LocalFeatures = map[lnwire.FeatureBit]string{
WtSessionsRequired: "wt-sessions-required",
WtSessionsOptional: "wt-sessions-optional",
}
const (
// WtSessionsRequired specifies that the advertising node requires the
// remote party to understand the protocol for creating and updating
// watchtower sessions.
WtSessionsRequired lnwire.FeatureBit = 8
// WtSessionsOptional specifies that the advertising node can support
// a remote party who understand the protocol for creating and updating
// watchtower sessions.
WtSessionsOptional lnwire.FeatureBit = 9
)

30
watchtower/wtwire/init.go Normal file

@ -0,0 +1,30 @@
package wtwire
import "github.com/lightningnetwork/lnd/lnwire"
// Init is the first message sent over the watchtower wire protocol, and
// specifies features and level of requiredness maintained by the sending node.
// The watchtower Init message is identical to the LN Init message, except it
// uses a different message type to ensure the two are not conflated.
type Init struct {
*lnwire.Init
}
// NewInitMessage generates a new Init message from raw global and local feature
// vectors.
func NewInitMessage(gf, lf *lnwire.RawFeatureVector) *Init {
return &Init{
Init: lnwire.NewInitMessage(gf, lf),
}
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the wtwire.Message interface.
func (msg *Init) MsgType() MessageType {
return MsgInit
}
// A compile-time constraint to ensure Init implements the Message interface.
var _ Message = (*Init)(nil)

@ -0,0 +1,179 @@
package wtwire
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"github.com/lightningnetwork/lnd/lnwire"
)
// MaxMessagePayload is the maximum bytes a message can be regardless of other
// individual limits imposed by messages themselves.
const MaxMessagePayload = 65535 // 65KB
// MessageType is the unique 2 byte big-endian integer that indicates the type
// of message on the wire. All messages have a very simple header which
// consists simply of 2-byte message type. We omit a length field, and checksum
// as the Watchtower Protocol is intended to be encapsulated within a
// confidential+authenticated cryptographic messaging protocol.
type MessageType uint16
// The currently defined message types within this current version of the
// Watchtower protocol.
const (
// MsgInit identifies an encoded Init message.
MsgInit MessageType = 300
// MsgError identifies an encoded Error message.
MsgError = 301
// MsgCreateSession identifies an encoded CreateSession message.
MsgCreateSession MessageType = 302
// MsgCreateSessionReply identifies an encoded CreateSessionReply message.
MsgCreateSessionReply MessageType = 303
// MsgStateUpdate identifies an encoded StateUpdate message.
MsgStateUpdate MessageType = 304
// MsgStateUpdateReply identifies an encoded StateUpdateReply message.
MsgStateUpdateReply MessageType = 305
)
// String returns a human readable description of the message type.
func (m MessageType) String() string {
switch m {
case MsgInit:
return "Init"
case MsgCreateSession:
return "MsgCreateSession"
case MsgCreateSessionReply:
return "MsgCreateSessionReply"
case MsgStateUpdate:
return "MsgStateUpdate"
case MsgStateUpdateReply:
return "MsgStateUpdateReply"
case MsgError:
return "Error"
default:
return "<unknown>"
}
}
// Serializable is an interface which defines a lightning wire serializable
// object.
type Serializable = lnwire.Serializable
// Message is an interface that defines a lightning wire protocol message. The
// interface is general in order to allow implementing types full control over
// the representation of its data.
type Message interface {
Serializable
// MsgType returns a MessageType that uniquely identifies the message to
// be encoded.
MsgType() MessageType
// MaxMessagePayload is the maximum serialized length that a particular
// message type can take.
MaxPayloadLength(uint32) uint32
}
// makeEmptyMessage creates a new empty message of the proper concrete type
// based on the passed message type.
func makeEmptyMessage(msgType MessageType) (Message, error) {
var msg Message
switch msgType {
case MsgInit:
msg = &Init{&lnwire.Init{}}
case MsgCreateSession:
msg = &CreateSession{}
case MsgCreateSessionReply:
msg = &CreateSessionReply{}
case MsgStateUpdate:
msg = &StateUpdate{}
case MsgStateUpdateReply:
msg = &StateUpdateReply{}
case MsgError:
msg = &Error{}
default:
return nil, fmt.Errorf("unknown message type [%d]", msgType)
}
return msg, nil
}
// WriteMessage writes a lightning Message to w including the necessary header
// information and returns the number of bytes written.
func WriteMessage(w io.Writer, msg Message, pver uint32) (int, error) {
totalBytes := 0
// Encode the message payload itself into a temporary buffer.
// TODO(roasbeef): create buffer pool
var bw bytes.Buffer
if err := msg.Encode(&bw, pver); err != nil {
return totalBytes, err
}
payload := bw.Bytes()
lenp := len(payload)
// Enforce maximum overall message payload.
if lenp > MaxMessagePayload {
return totalBytes, fmt.Errorf("message payload is too large - "+
"encoded %d bytes, but maximum message payload is %d bytes",
lenp, MaxMessagePayload)
}
// Enforce maximum message payload on the message type.
mpl := msg.MaxPayloadLength(pver)
if uint32(lenp) > mpl {
return totalBytes, fmt.Errorf("message payload is too large - "+
"encoded %d bytes, but maximum message payload of "+
"type %v is %d bytes", lenp, msg.MsgType(), mpl)
}
// With the initial sanity checks complete, we'll now write out the
// message type itself.
var mType [2]byte
binary.BigEndian.PutUint16(mType[:], uint16(msg.MsgType()))
n, err := w.Write(mType[:])
totalBytes += n
if err != nil {
return totalBytes, err
}
// With the message type written, we'll now write out the raw payload
// itself.
n, err = w.Write(payload)
totalBytes += n
return totalBytes, err
}
// ReadMessage reads, validates, and parses the next Watchtower message from r
// for the provided protocol version.
func ReadMessage(r io.Reader, pver uint32) (Message, error) {
// First, we'll read out the first two bytes of the message so we can
// create the proper empty message.
var mType [2]byte
if _, err := io.ReadFull(r, mType[:]); err != nil {
return nil, err
}
msgType := MessageType(binary.BigEndian.Uint16(mType[:]))
// Now that we know the target message type, we can create the proper
// empty message type and decode the message into it.
msg, err := makeEmptyMessage(msgType)
if err != nil {
return nil, err
}
if err := msg.Decode(r, pver); err != nil {
return nil, err
}
return msg, nil
}

@ -0,0 +1,88 @@
package wtwire
import "io"
// StateUpdate transmits an encrypted state update from the client to the
// watchtower. Each state update is tied to particular session, identified by
// the client's brontine key used to make the request.
type StateUpdate struct {
// SeqNum is a 1-indexed, monotonically incrementing sequence number.
// This number represents to the client's expected sequence number when
// sending updates sent to the watchtower. This value must always be
// less or equal than the negotiated MaxUpdates for the session, and
// greater than the LastApplied sent in the same message.
SeqNum uint16
// LastApplied echos the LastApplied value returned from watchtower,
// allowing the tower to detect faulty clients. This allow provides a
// feedback mechanism for the tower if updates are allowed to stream in
// an async fashion.
LastApplied uint16
// IsComplete is 1 if the watchtower should close the connection after
// responding, and 0 otherwise.
IsComplete uint8
// Hint is the 16-byte prefix of the revoked commitment transaction ID
// for which the encrypted blob can exact justice.
Hint [16]byte
// EncryptedBlob is the serialized ciphertext containing all necessary
// information to sweep the commitment transaction corresponding to the
// Hint. The ciphertext is to be encrypted using the full transaction ID
// of the revoked commitment transaction.
//
// The plaintext MUST be encoded using the negotiated Version for
// this session. In addition, the signatures must be computed over a
// sweep transaction honoring the decided SweepFeeRate, RewardRate, and
// (possibly) reward address returned in the SessionInitReply.
EncryptedBlob []byte
}
// A compile time check to ensure StateUpdate implements the wtwire.Message
// interface.
var _ Message = (*StateUpdate)(nil)
// Decode deserializes a serialized StateUpdate message stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the wtwire.Message interface.
func (m *StateUpdate) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&m.SeqNum,
&m.LastApplied,
&m.IsComplete,
&m.Hint,
&m.EncryptedBlob,
)
}
// Encode serializes the target StateUpdate into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the wtwire.Message interface.
func (m *StateUpdate) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
m.SeqNum,
m.LastApplied,
m.IsComplete,
m.Hint,
m.EncryptedBlob,
)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the wtwire.Message interface.
func (m *StateUpdate) MsgType() MessageType {
return MsgStateUpdate
}
// MaxPayloadLength returns the maximum allowed payload size for a StateUpdate
// complete message observing the specified protocol version.
//
// This is part of the wtwire.Message interface.
func (m *StateUpdate) MaxPayloadLength(uint32) uint32 {
return MaxMessagePayload
}

@ -0,0 +1,86 @@
package wtwire
import "io"
// StateUpdateCode is an error code returned by a watchtower in response to a
// StateUpdate message.
type StateUpdateCode = ErrorCode
const (
// StateUpdateCodeClientBehind signals that the client's sequence number
// is behind what the watchtower expects based on its LastApplied. This
// error should cause the client to record the LastApplied field in the
// response, and initiate another attempt with the proper sequence
// number.
//
// NOTE: Repeated occurrences of this could be interpreted as an attempt
// to siphon state updates from the client. If the client believes it
// is not violating the protocol, this could be grounds to blacklist
// this tower from future session negotiation.
StateUpdateCodeClientBehind StateUpdateCode = 70
// StateUpdateCodeMaxUpdatesExceeded signals that the client tried to
// send a sequence number beyond the negotiated MaxUpdates of the
// session.
StateUpdateCodeMaxUpdatesExceeded StateUpdateCode = 71
// StateUpdateCodeSeqNumOutOfOrder signals the client sent an update
// that does not follow the required incremental monotonicity required
// by the tower.
StateUpdateCodeSeqNumOutOfOrder StateUpdateCode = 72
)
// StateUpdateReply is a message sent from watchtower to client in response to a
// StateUpdate message, and signals either an acceptance or rejection of the
// proposed state update.
type StateUpdateReply struct {
// Code will be non-zero if the watchtower rejected the state update.
Code StateUpdateCode
// LastApplied returns the sequence number of the last accepted update
// known to the watchtower. If the update was successful, this value
// should be the sequence number of the last update sent.
LastApplied uint16
}
// A compile time check to ensure StateUpdateReply implements the wtwire.Message
// interface.
var _ Message = (*StateUpdateReply)(nil)
// Decode deserializes a serialized StateUpdateReply message stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the wtwire.Message interface.
func (t *StateUpdateReply) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&t.Code,
&t.LastApplied,
)
}
// Encode serializes the target StateUpdateReply into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the wtwire.Message interface.
func (t *StateUpdateReply) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
t.Code,
t.LastApplied,
)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the wtwire.Message interface.
func (t *StateUpdateReply) MsgType() MessageType {
return MsgStateUpdateReply
}
// MaxPayloadLength returns the maximum allowed payload size for a
// StateUpdateReply complete message observing the specified protocol version.
//
// This is part of the wtwire.Message interface.
func (t *StateUpdateReply) MaxPayloadLength(uint32) uint32 {
return 4
}

210
watchtower/wtwire/wtwire.go Normal file

@ -0,0 +1,210 @@
package wtwire
import (
"encoding/binary"
"fmt"
"io"
"github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/lnwallet"
)
// WriteElement is a one-stop shop to write the big endian representation of
// any element which is to be serialized for the wire protocol. The passed
// io.Writer should be backed by an appropriately sized byte slice, or be able
// to dynamically expand to accommodate additional data.
func WriteElement(w io.Writer, element interface{}) error {
switch e := element.(type) {
case uint8:
var b [1]byte
b[0] = e
if _, err := w.Write(b[:]); err != nil {
return err
}
case uint16:
var b [2]byte
binary.BigEndian.PutUint16(b[:], e)
if _, err := w.Write(b[:]); err != nil {
return err
}
case uint32:
var b [4]byte
binary.BigEndian.PutUint32(b[:], e)
if _, err := w.Write(b[:]); err != nil {
return err
}
case uint64:
var b [8]byte
binary.BigEndian.PutUint64(b[:], e)
if _, err := w.Write(b[:]); err != nil {
return err
}
case [16]byte:
if _, err := w.Write(e[:]); err != nil {
return err
}
case [32]byte:
if _, err := w.Write(e[:]); err != nil {
return err
}
case [33]byte:
if _, err := w.Write(e[:]); err != nil {
return err
}
case []byte:
if err := wire.WriteVarBytes(w, 0, e); err != nil {
return err
}
case lnwallet.SatPerKWeight:
var b [8]byte
binary.BigEndian.PutUint64(b[:], uint64(e))
if _, err := w.Write(b[:]); err != nil {
return err
}
case ErrorCode:
var b [2]byte
binary.BigEndian.PutUint16(b[:], uint16(e))
if _, err := w.Write(b[:]); err != nil {
return err
}
case *btcec.PublicKey:
if e == nil {
return fmt.Errorf("cannot write nil pubkey")
}
var b [33]byte
serializedPubkey := e.SerializeCompressed()
copy(b[:], serializedPubkey)
if _, err := w.Write(b[:]); err != nil {
return err
}
default:
return fmt.Errorf("Unknown type in WriteElement: %T", e)
}
return nil
}
// WriteElements is writes each element in the elements slice to the passed
// io.Writer using WriteElement.
func WriteElements(w io.Writer, elements ...interface{}) error {
for _, element := range elements {
err := WriteElement(w, element)
if err != nil {
return err
}
}
return nil
}
// ReadElement is a one-stop utility function to deserialize any datastructure
// encoded using the serialization format of lnwire.
func ReadElement(r io.Reader, element interface{}) error {
switch e := element.(type) {
case *uint8:
var b [1]uint8
if _, err := r.Read(b[:]); err != nil {
return err
}
*e = b[0]
case *uint16:
var b [2]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
*e = binary.BigEndian.Uint16(b[:])
case *uint32:
var b [4]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
*e = binary.BigEndian.Uint32(b[:])
case *uint64:
var b [8]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
*e = binary.BigEndian.Uint64(b[:])
case *[16]byte:
if _, err := io.ReadFull(r, e[:]); err != nil {
return err
}
case *[32]byte:
if _, err := io.ReadFull(r, e[:]); err != nil {
}
case *[33]byte:
if _, err := io.ReadFull(r, e[:]); err != nil {
}
case *[]byte:
bytes, err := wire.ReadVarBytes(r, 0, 66000, "[]byte")
if err != nil {
return err
}
*e = bytes
case *lnwallet.SatPerKWeight:
var b [8]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
*e = lnwallet.SatPerKWeight(binary.BigEndian.Uint64(b[:]))
case *ErrorCode:
var b [2]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
*e = ErrorCode(binary.BigEndian.Uint16(b[:]))
case **btcec.PublicKey:
var b [btcec.PubKeyBytesLenCompressed]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
pubKey, err := btcec.ParsePubKey(b[:], btcec.S256())
if err != nil {
return err
}
*e = pubKey
default:
return fmt.Errorf("Unknown type in ReadElement: %T", e)
}
return nil
}
// ReadElements deserializes a variable number of elements into the passed
// io.Reader, with each element being deserialized according to the ReadElement
// function.
func ReadElements(r io.Reader, elements ...interface{}) error {
for _, element := range elements {
err := ReadElement(r, element)
if err != nil {
return err
}
}
return nil
}

@ -0,0 +1,152 @@
package wtwire_test
import (
"bytes"
"math/rand"
"reflect"
"testing"
"testing/quick"
"time"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/wtwire"
)
func randRawFeatureVector(r *rand.Rand) *lnwire.RawFeatureVector {
featureVec := lnwire.NewRawFeatureVector()
for i := 0; i < 10000; i++ {
if r.Int31n(2) == 0 {
featureVec.Set(lnwire.FeatureBit(i))
}
}
return featureVec
}
// TestWatchtowerWireProtocol uses the testing/quick package to create a series
// of fuzz tests to attempt to break a primary scenario which is implemented as
// property based testing scenario.
func TestWatchtowerWireProtocol(t *testing.T) {
t.Parallel()
// mainScenario is the primary test that will programmatically be
// executed for all registered wire messages. The quick-checker within
// testing/quick will attempt to find an input to this function, s.t
// the function returns false, if so then we've found an input that
// violates our model of the system.
mainScenario := func(msg wtwire.Message) bool {
// Give a new message, we'll serialize the message into a new
// bytes buffer.
var b bytes.Buffer
if _, err := wtwire.WriteMessage(&b, msg, 0); err != nil {
t.Fatalf("unable to write msg: %v", err)
return false
}
// Next, we'll ensure that the serialized payload (subtracting
// the 2 bytes for the message type) is _below_ the specified
// max payload size for this message.
payloadLen := uint32(b.Len()) - 2
if payloadLen > msg.MaxPayloadLength(0) {
t.Fatalf("msg payload constraint violated: %v > %v",
payloadLen, msg.MaxPayloadLength(0))
return false
}
// Finally, we'll deserialize the message from the written
// buffer, and finally assert that the messages are equal.
newMsg, err := wtwire.ReadMessage(&b, 0)
if err != nil {
t.Fatalf("unable to read msg: %v", err)
return false
}
if !reflect.DeepEqual(msg, newMsg) {
t.Fatalf("messages don't match after re-encoding: %v "+
"vs %v", spew.Sdump(msg), spew.Sdump(newMsg))
return false
}
return true
}
customTypeGen := map[wtwire.MessageType]func([]reflect.Value, *rand.Rand){
wtwire.MsgInit: func(v []reflect.Value, r *rand.Rand) {
req := wtwire.NewInitMessage(
randRawFeatureVector(r),
randRawFeatureVector(r),
)
v[0] = reflect.ValueOf(*req)
},
}
// With the above types defined, we'll now generate a slice of
// scenarios to feed into quick.Check. The function scans in input
// space of the target function under test, so we'll need to create a
// series of wrapper functions to force it to iterate over the target
// types, but re-use the mainScenario defined above.
tests := []struct {
msgType wtwire.MessageType
scenario interface{}
}{
{
msgType: wtwire.MsgInit,
scenario: func(m wtwire.Init) bool {
return mainScenario(&m)
},
},
{
msgType: wtwire.MsgCreateSession,
scenario: func(m wtwire.CreateSession) bool {
return mainScenario(&m)
},
},
{
msgType: wtwire.MsgCreateSessionReply,
scenario: func(m wtwire.CreateSessionReply) bool {
return mainScenario(&m)
},
},
{
msgType: wtwire.MsgStateUpdate,
scenario: func(m wtwire.StateUpdate) bool {
return mainScenario(&m)
},
},
{
msgType: wtwire.MsgStateUpdateReply,
scenario: func(m wtwire.StateUpdateReply) bool {
return mainScenario(&m)
},
},
{
msgType: wtwire.MsgError,
scenario: func(m wtwire.Error) bool {
return mainScenario(&m)
},
},
}
for _, test := range tests {
var config *quick.Config
// If the type defined is within the custom type gen map above,
// the we'll modify the default config to use this Value
// function that knows how to generate the proper types.
if valueGen, ok := customTypeGen[test.msgType]; ok {
config = &quick.Config{
Values: valueGen,
}
}
t.Logf("Running fuzz tests for msgType=%v", test.msgType)
if err := quick.Check(test.scenario, config); err != nil {
t.Fatalf("fuzz checks for msg=%v failed: %v",
test.msgType, err)
}
}
}
func init() {
rand.Seed(time.Now().Unix())
}