diff --git a/watchtower/wtwire/create_session.go b/watchtower/wtwire/create_session.go new file mode 100644 index 00000000..8ee2c906 --- /dev/null +++ b/watchtower/wtwire/create_session.go @@ -0,0 +1,78 @@ +package wtwire + +import ( + "io" + + "github.com/lightningnetwork/lnd/lnwallet" +) + +// CreateSession is sent from a client to tower when to negotiate a session, which +// specifies the total number of updates that can be made, as well as fee rates. +// An update is consumed by uploading an encrypted blob that contains +// information required to sweep a revoked commitment transaction. +type CreateSession struct { + // BlobVersion specifies the blob format that must be used by all + // updates sent under the session key used to negotiate this session. + BlobVersion uint16 + + // MaxUpdates is the maximum number of updates the watchtower will honor + // for this session. + MaxUpdates uint16 + + // RewardRate is the fraction of the total balance of the revoked + // commitment that the watchtower is entitled to. This value is + // expressed in millionths of the total balance. + RewardRate uint32 + + // SweepFeeRate expresses the intended fee rate to be used when + // constructing the justice transaction. All sweep transactions created + // for this session must use this value during construction, and the + // signatures must implicitly commit to the resulting output values. + SweepFeeRate lnwallet.SatPerKWeight +} + +// A compile time check to ensure CreateSession implements the wtwire.Message +// interface. +var _ Message = (*CreateSession)(nil) + +// Decode deserializes a serialized CreateSession message stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the wtwire.Message interface. +func (m *CreateSession) Decode(r io.Reader, pver uint32) error { + return ReadElements(r, + &m.BlobVersion, + &m.MaxUpdates, + &m.RewardRate, + &m.SweepFeeRate, + ) +} + +// Encode serializes the target CreateSession into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the wtwire.Message interface. +func (m *CreateSession) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, + m.BlobVersion, + m.MaxUpdates, + m.RewardRate, + m.SweepFeeRate, + ) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the wtwire.Message interface. +func (m *CreateSession) MsgType() MessageType { + return MsgCreateSession +} + +// MaxPayloadLength returns the maximum allowed payload size for a CreateSession +// complete message observing the specified protocol version. +// +// This is part of the wtwire.Message interface. +func (m *CreateSession) MaxPayloadLength(uint32) uint32 { + return 16 +} diff --git a/watchtower/wtwire/create_session_reply.go b/watchtower/wtwire/create_session_reply.go new file mode 100644 index 00000000..b1224cb8 --- /dev/null +++ b/watchtower/wtwire/create_session_reply.go @@ -0,0 +1,91 @@ +package wtwire + +import "io" + +// CreateSessionCode is an error code returned by a watchtower in response to a +// CreateSession message. The code directs the client in interpreting the payload +// in the reply. +type CreateSessionCode = ErrorCode + +const ( + // CreateSessionCodeAlreadyExists is returned when a session is already + // active for the public key used to connect to the watchtower. The + // response includes the serialized reward address in case the original + // reply was never received and/or processed by the client. + CreateSessionCodeAlreadyExists CreateSessionCode = 60 + + // CreateSessionCodeRejectRejectMaxUpdates the tower rejected the maximum + // number of state updates proposed by the client. + CreateSessionCodeRejectRejectMaxUpdates CreateSessionCode = 61 + + // CreateSessionCodeRejectRewardRate the tower rejected the reward rate + // proposed by the client. + CreateSessionCodeRejectRewardRate CreateSessionCode = 62 + + // CreateSessionCodeRejectSweepFeeRate the tower rejected the sweep fee + // rate proposed by the client. + CreateSessionCodeRejectSweepFeeRate CreateSessionCode = 63 +) + +// MaxCreateSessionReplyDataLength is the maximum size of the Data payload +// returned in a CreateSessionReply message. This does not include the length of +// the Data field, which is a varint up to 3 bytes in size. +const MaxCreateSessionReplyDataLength = 1024 + +// CreateSessionReply is a message sent from watchtower to client in response to a +// CreateSession message, and signals either an acceptance or rejection of the +// proposed session parameters. +type CreateSessionReply struct { + // Code will be non-zero if the watchtower rejected the session init. + Code CreateSessionCode + + // Data is a byte slice returned the caller of the message, and is to be + // interpreted according to the error Code. When the response is + // CreateSessionCodeOK, data encodes the reward address to be included in + // any sweep transactions if the reward is not dusty. Otherwise, it may + // encode the watchtowers configured parameters for any policy + // rejections. + Data []byte +} + +// A compile time check to ensure CreateSessionReply implements the wtwire.Message +// interface. +var _ Message = (*CreateSessionReply)(nil) + +// Decode deserializes a serialized CreateSessionReply message stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the wtwire.Message interface. +func (m *CreateSessionReply) Decode(r io.Reader, pver uint32) error { + return ReadElements(r, + &m.Code, + &m.Data, + ) +} + +// Encode serializes the target CreateSessionReply into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the wtwire.Message interface. +func (m *CreateSessionReply) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, + m.Code, + m.Data, + ) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the wtwire.Message interface. +func (m *CreateSessionReply) MsgType() MessageType { + return MsgCreateSessionReply +} + +// MaxPayloadLength returns the maximum allowed payload size for a CreateSessionReply +// complete message observing the specified protocol version. +// +// This is part of the wtwire.Message interface. +func (m *CreateSessionReply) MaxPayloadLength(uint32) uint32 { + return 2 + 3 + MaxCreateSessionReplyDataLength +} diff --git a/watchtower/wtwire/error.go b/watchtower/wtwire/error.go new file mode 100644 index 00000000..c8d61fa0 --- /dev/null +++ b/watchtower/wtwire/error.go @@ -0,0 +1,62 @@ +package wtwire + +import "io" + +// Error is a generic error message that can be sent to a client if a request +// fails outside of prescribed protocol errors. Typically this would be followed +// by the server disconnecting the client, and so can be useful to transfering +// the exact reason. +type Error struct { + // Code specifies the error code encountered by the server. + Code ErrorCode + + // Data encodes a payload whose contents can be interpreted by the + // client in response to the error code. + Data []byte +} + +// NewError returns an freshly-initialized Error message. +func NewError() *Error { + return &Error{} +} + +// A compile time check to ensure Error implements the wtwire.Message interface. +var _ Message = (*Error)(nil) + +// Decode deserializes a serialized Error message stored in the passed io.Reader +// observing the specified protocol version. +// +// This is part of the wtwire.Message interface. +func (e *Error) Decode(r io.Reader, pver uint32) error { + return ReadElements(r, + &e.Code, + &e.Data, + ) +} + +// Encode serializes the target Error into the passed io.Writer observing the +// protocol version specified. +// +// This is part of the wtwire.Message interface. +func (e *Error) Encode(w io.Writer, prver uint32) error { + return WriteElements(w, + e.Code, + e.Data, + ) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the wtwire.Message interface. +func (e *Error) MsgType() MessageType { + return MsgError +} + +// MaxPayloadLength returns the maximum allowed payload size for a Error +// complete message observing the specified protocol version. +// +// This is part of the wtwire.Message interface. +func (e *Error) MaxPayloadLength(uint32) uint32 { + return MaxMessagePayload +} diff --git a/watchtower/wtwire/error_code.go b/watchtower/wtwire/error_code.go new file mode 100644 index 00000000..6a441784 --- /dev/null +++ b/watchtower/wtwire/error_code.go @@ -0,0 +1,20 @@ +package wtwire + +// ErrorCode represents a generic error code used when replying to watchtower +// clients. Specific reply messages may extend the ErrorCode primitive and add +// custom codes, so long as they don't collide with the generic error codes.. +type ErrorCode uint16 + +const ( + // CodeOK signals that the request was successfully processed by the + // watchtower. + CodeOK ErrorCode = 0 + + // CodeTemporaryFailure alerts the client that the watchtower is + // temporarily unavailable, but that it may try again at a later time. + CodeTemporaryFailure ErrorCode = 40 + + // CodePermanentFailure alerts the client that the watchtower has + // permanently failed, and further communication should be avoided. + CodePermanentFailure ErrorCode = 50 +) diff --git a/watchtower/wtwire/features.go b/watchtower/wtwire/features.go new file mode 100644 index 00000000..327a8886 --- /dev/null +++ b/watchtower/wtwire/features.go @@ -0,0 +1,26 @@ +package wtwire + +import "github.com/lightningnetwork/lnd/lnwire" + +// GlobalFeatures holds the globally advertised feature bits understood by +// watchtower implementations. +var GlobalFeatures map[lnwire.FeatureBit]string + +// LocalFeatures holds the locally advertised feature bits understood by +// watchtower implementations. +var LocalFeatures = map[lnwire.FeatureBit]string{ + WtSessionsRequired: "wt-sessions-required", + WtSessionsOptional: "wt-sessions-optional", +} + +const ( + // WtSessionsRequired specifies that the advertising node requires the + // remote party to understand the protocol for creating and updating + // watchtower sessions. + WtSessionsRequired lnwire.FeatureBit = 8 + + // WtSessionsOptional specifies that the advertising node can support + // a remote party who understand the protocol for creating and updating + // watchtower sessions. + WtSessionsOptional lnwire.FeatureBit = 9 +) diff --git a/watchtower/wtwire/init.go b/watchtower/wtwire/init.go new file mode 100644 index 00000000..d9056e2b --- /dev/null +++ b/watchtower/wtwire/init.go @@ -0,0 +1,30 @@ +package wtwire + +import "github.com/lightningnetwork/lnd/lnwire" + +// Init is the first message sent over the watchtower wire protocol, and +// specifies features and level of requiredness maintained by the sending node. +// The watchtower Init message is identical to the LN Init message, except it +// uses a different message type to ensure the two are not conflated. +type Init struct { + *lnwire.Init +} + +// NewInitMessage generates a new Init message from raw global and local feature +// vectors. +func NewInitMessage(gf, lf *lnwire.RawFeatureVector) *Init { + return &Init{ + Init: lnwire.NewInitMessage(gf, lf), + } +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the wtwire.Message interface. +func (msg *Init) MsgType() MessageType { + return MsgInit +} + +// A compile-time constraint to ensure Init implements the Message interface. +var _ Message = (*Init)(nil) diff --git a/watchtower/wtwire/message.go b/watchtower/wtwire/message.go new file mode 100644 index 00000000..5daaa6de --- /dev/null +++ b/watchtower/wtwire/message.go @@ -0,0 +1,179 @@ +package wtwire + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/lightningnetwork/lnd/lnwire" +) + +// MaxMessagePayload is the maximum bytes a message can be regardless of other +// individual limits imposed by messages themselves. +const MaxMessagePayload = 65535 // 65KB + +// MessageType is the unique 2 byte big-endian integer that indicates the type +// of message on the wire. All messages have a very simple header which +// consists simply of 2-byte message type. We omit a length field, and checksum +// as the Watchtower Protocol is intended to be encapsulated within a +// confidential+authenticated cryptographic messaging protocol. +type MessageType uint16 + +// The currently defined message types within this current version of the +// Watchtower protocol. +const ( + // MsgInit identifies an encoded Init message. + MsgInit MessageType = 300 + + // MsgError identifies an encoded Error message. + MsgError = 301 + + // MsgCreateSession identifies an encoded CreateSession message. + MsgCreateSession MessageType = 302 + + // MsgCreateSessionReply identifies an encoded CreateSessionReply message. + MsgCreateSessionReply MessageType = 303 + + // MsgStateUpdate identifies an encoded StateUpdate message. + MsgStateUpdate MessageType = 304 + + // MsgStateUpdateReply identifies an encoded StateUpdateReply message. + MsgStateUpdateReply MessageType = 305 +) + +// String returns a human readable description of the message type. +func (m MessageType) String() string { + switch m { + case MsgInit: + return "Init" + case MsgCreateSession: + return "MsgCreateSession" + case MsgCreateSessionReply: + return "MsgCreateSessionReply" + case MsgStateUpdate: + return "MsgStateUpdate" + case MsgStateUpdateReply: + return "MsgStateUpdateReply" + case MsgError: + return "Error" + default: + return "" + } +} + +// Serializable is an interface which defines a lightning wire serializable +// object. +type Serializable = lnwire.Serializable + +// Message is an interface that defines a lightning wire protocol message. The +// interface is general in order to allow implementing types full control over +// the representation of its data. +type Message interface { + Serializable + + // MsgType returns a MessageType that uniquely identifies the message to + // be encoded. + MsgType() MessageType + + // MaxMessagePayload is the maximum serialized length that a particular + // message type can take. + MaxPayloadLength(uint32) uint32 +} + +// makeEmptyMessage creates a new empty message of the proper concrete type +// based on the passed message type. +func makeEmptyMessage(msgType MessageType) (Message, error) { + var msg Message + + switch msgType { + case MsgInit: + msg = &Init{&lnwire.Init{}} + case MsgCreateSession: + msg = &CreateSession{} + case MsgCreateSessionReply: + msg = &CreateSessionReply{} + case MsgStateUpdate: + msg = &StateUpdate{} + case MsgStateUpdateReply: + msg = &StateUpdateReply{} + case MsgError: + msg = &Error{} + default: + return nil, fmt.Errorf("unknown message type [%d]", msgType) + } + + return msg, nil +} + +// WriteMessage writes a lightning Message to w including the necessary header +// information and returns the number of bytes written. +func WriteMessage(w io.Writer, msg Message, pver uint32) (int, error) { + totalBytes := 0 + + // Encode the message payload itself into a temporary buffer. + // TODO(roasbeef): create buffer pool + var bw bytes.Buffer + if err := msg.Encode(&bw, pver); err != nil { + return totalBytes, err + } + payload := bw.Bytes() + lenp := len(payload) + + // Enforce maximum overall message payload. + if lenp > MaxMessagePayload { + return totalBytes, fmt.Errorf("message payload is too large - "+ + "encoded %d bytes, but maximum message payload is %d bytes", + lenp, MaxMessagePayload) + } + + // Enforce maximum message payload on the message type. + mpl := msg.MaxPayloadLength(pver) + if uint32(lenp) > mpl { + return totalBytes, fmt.Errorf("message payload is too large - "+ + "encoded %d bytes, but maximum message payload of "+ + "type %v is %d bytes", lenp, msg.MsgType(), mpl) + } + + // With the initial sanity checks complete, we'll now write out the + // message type itself. + var mType [2]byte + binary.BigEndian.PutUint16(mType[:], uint16(msg.MsgType())) + n, err := w.Write(mType[:]) + totalBytes += n + if err != nil { + return totalBytes, err + } + + // With the message type written, we'll now write out the raw payload + // itself. + n, err = w.Write(payload) + totalBytes += n + + return totalBytes, err +} + +// ReadMessage reads, validates, and parses the next Watchtower message from r +// for the provided protocol version. +func ReadMessage(r io.Reader, pver uint32) (Message, error) { + // First, we'll read out the first two bytes of the message so we can + // create the proper empty message. + var mType [2]byte + if _, err := io.ReadFull(r, mType[:]); err != nil { + return nil, err + } + + msgType := MessageType(binary.BigEndian.Uint16(mType[:])) + + // Now that we know the target message type, we can create the proper + // empty message type and decode the message into it. + msg, err := makeEmptyMessage(msgType) + if err != nil { + return nil, err + } + if err := msg.Decode(r, pver); err != nil { + return nil, err + } + + return msg, nil +} diff --git a/watchtower/wtwire/state_update.go b/watchtower/wtwire/state_update.go new file mode 100644 index 00000000..ab51ed38 --- /dev/null +++ b/watchtower/wtwire/state_update.go @@ -0,0 +1,88 @@ +package wtwire + +import "io" + +// StateUpdate transmits an encrypted state update from the client to the +// watchtower. Each state update is tied to particular session, identified by +// the client's brontine key used to make the request. +type StateUpdate struct { + // SeqNum is a 1-indexed, monotonically incrementing sequence number. + // This number represents to the client's expected sequence number when + // sending updates sent to the watchtower. This value must always be + // less or equal than the negotiated MaxUpdates for the session, and + // greater than the LastApplied sent in the same message. + SeqNum uint16 + + // LastApplied echos the LastApplied value returned from watchtower, + // allowing the tower to detect faulty clients. This allow provides a + // feedback mechanism for the tower if updates are allowed to stream in + // an async fashion. + LastApplied uint16 + + // IsComplete is 1 if the watchtower should close the connection after + // responding, and 0 otherwise. + IsComplete uint8 + + // Hint is the 16-byte prefix of the revoked commitment transaction ID + // for which the encrypted blob can exact justice. + Hint [16]byte + + // EncryptedBlob is the serialized ciphertext containing all necessary + // information to sweep the commitment transaction corresponding to the + // Hint. The ciphertext is to be encrypted using the full transaction ID + // of the revoked commitment transaction. + // + // The plaintext MUST be encoded using the negotiated Version for + // this session. In addition, the signatures must be computed over a + // sweep transaction honoring the decided SweepFeeRate, RewardRate, and + // (possibly) reward address returned in the SessionInitReply. + EncryptedBlob []byte +} + +// A compile time check to ensure StateUpdate implements the wtwire.Message +// interface. +var _ Message = (*StateUpdate)(nil) + +// Decode deserializes a serialized StateUpdate message stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the wtwire.Message interface. +func (m *StateUpdate) Decode(r io.Reader, pver uint32) error { + return ReadElements(r, + &m.SeqNum, + &m.LastApplied, + &m.IsComplete, + &m.Hint, + &m.EncryptedBlob, + ) +} + +// Encode serializes the target StateUpdate into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the wtwire.Message interface. +func (m *StateUpdate) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, + m.SeqNum, + m.LastApplied, + m.IsComplete, + m.Hint, + m.EncryptedBlob, + ) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the wtwire.Message interface. +func (m *StateUpdate) MsgType() MessageType { + return MsgStateUpdate +} + +// MaxPayloadLength returns the maximum allowed payload size for a StateUpdate +// complete message observing the specified protocol version. +// +// This is part of the wtwire.Message interface. +func (m *StateUpdate) MaxPayloadLength(uint32) uint32 { + return MaxMessagePayload +} diff --git a/watchtower/wtwire/state_update_reply.go b/watchtower/wtwire/state_update_reply.go new file mode 100644 index 00000000..b2258d0c --- /dev/null +++ b/watchtower/wtwire/state_update_reply.go @@ -0,0 +1,86 @@ +package wtwire + +import "io" + +// StateUpdateCode is an error code returned by a watchtower in response to a +// StateUpdate message. +type StateUpdateCode = ErrorCode + +const ( + // StateUpdateCodeClientBehind signals that the client's sequence number + // is behind what the watchtower expects based on its LastApplied. This + // error should cause the client to record the LastApplied field in the + // response, and initiate another attempt with the proper sequence + // number. + // + // NOTE: Repeated occurrences of this could be interpreted as an attempt + // to siphon state updates from the client. If the client believes it + // is not violating the protocol, this could be grounds to blacklist + // this tower from future session negotiation. + StateUpdateCodeClientBehind StateUpdateCode = 70 + + // StateUpdateCodeMaxUpdatesExceeded signals that the client tried to + // send a sequence number beyond the negotiated MaxUpdates of the + // session. + StateUpdateCodeMaxUpdatesExceeded StateUpdateCode = 71 + + // StateUpdateCodeSeqNumOutOfOrder signals the client sent an update + // that does not follow the required incremental monotonicity required + // by the tower. + StateUpdateCodeSeqNumOutOfOrder StateUpdateCode = 72 +) + +// StateUpdateReply is a message sent from watchtower to client in response to a +// StateUpdate message, and signals either an acceptance or rejection of the +// proposed state update. +type StateUpdateReply struct { + // Code will be non-zero if the watchtower rejected the state update. + Code StateUpdateCode + + // LastApplied returns the sequence number of the last accepted update + // known to the watchtower. If the update was successful, this value + // should be the sequence number of the last update sent. + LastApplied uint16 +} + +// A compile time check to ensure StateUpdateReply implements the wtwire.Message +// interface. +var _ Message = (*StateUpdateReply)(nil) + +// Decode deserializes a serialized StateUpdateReply message stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the wtwire.Message interface. +func (t *StateUpdateReply) Decode(r io.Reader, pver uint32) error { + return ReadElements(r, + &t.Code, + &t.LastApplied, + ) +} + +// Encode serializes the target StateUpdateReply into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the wtwire.Message interface. +func (t *StateUpdateReply) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, + t.Code, + t.LastApplied, + ) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the wtwire.Message interface. +func (t *StateUpdateReply) MsgType() MessageType { + return MsgStateUpdateReply +} + +// MaxPayloadLength returns the maximum allowed payload size for a +// StateUpdateReply complete message observing the specified protocol version. +// +// This is part of the wtwire.Message interface. +func (t *StateUpdateReply) MaxPayloadLength(uint32) uint32 { + return 4 +} diff --git a/watchtower/wtwire/wtwire.go b/watchtower/wtwire/wtwire.go new file mode 100644 index 00000000..b77700ee --- /dev/null +++ b/watchtower/wtwire/wtwire.go @@ -0,0 +1,210 @@ +package wtwire + +import ( + "encoding/binary" + "fmt" + "io" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/lnwallet" +) + +// WriteElement is a one-stop shop to write the big endian representation of +// any element which is to be serialized for the wire protocol. The passed +// io.Writer should be backed by an appropriately sized byte slice, or be able +// to dynamically expand to accommodate additional data. +func WriteElement(w io.Writer, element interface{}) error { + switch e := element.(type) { + case uint8: + var b [1]byte + b[0] = e + if _, err := w.Write(b[:]); err != nil { + return err + } + + case uint16: + var b [2]byte + binary.BigEndian.PutUint16(b[:], e) + if _, err := w.Write(b[:]); err != nil { + return err + } + + case uint32: + var b [4]byte + binary.BigEndian.PutUint32(b[:], e) + if _, err := w.Write(b[:]); err != nil { + return err + } + + case uint64: + var b [8]byte + binary.BigEndian.PutUint64(b[:], e) + if _, err := w.Write(b[:]); err != nil { + return err + } + + case [16]byte: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case [32]byte: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case [33]byte: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case []byte: + if err := wire.WriteVarBytes(w, 0, e); err != nil { + return err + } + + case lnwallet.SatPerKWeight: + var b [8]byte + binary.BigEndian.PutUint64(b[:], uint64(e)) + if _, err := w.Write(b[:]); err != nil { + return err + } + + case ErrorCode: + var b [2]byte + binary.BigEndian.PutUint16(b[:], uint16(e)) + if _, err := w.Write(b[:]); err != nil { + return err + } + + case *btcec.PublicKey: + if e == nil { + return fmt.Errorf("cannot write nil pubkey") + } + + var b [33]byte + serializedPubkey := e.SerializeCompressed() + copy(b[:], serializedPubkey) + if _, err := w.Write(b[:]); err != nil { + return err + } + + default: + return fmt.Errorf("Unknown type in WriteElement: %T", e) + } + + return nil +} + +// WriteElements is writes each element in the elements slice to the passed +// io.Writer using WriteElement. +func WriteElements(w io.Writer, elements ...interface{}) error { + for _, element := range elements { + err := WriteElement(w, element) + if err != nil { + return err + } + } + return nil +} + +// ReadElement is a one-stop utility function to deserialize any datastructure +// encoded using the serialization format of lnwire. +func ReadElement(r io.Reader, element interface{}) error { + switch e := element.(type) { + case *uint8: + var b [1]uint8 + if _, err := r.Read(b[:]); err != nil { + return err + } + *e = b[0] + + case *uint16: + var b [2]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return err + } + *e = binary.BigEndian.Uint16(b[:]) + + case *uint32: + var b [4]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return err + } + *e = binary.BigEndian.Uint32(b[:]) + + case *uint64: + var b [8]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return err + } + *e = binary.BigEndian.Uint64(b[:]) + + case *[16]byte: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *[32]byte: + if _, err := io.ReadFull(r, e[:]); err != nil { + + } + + case *[33]byte: + if _, err := io.ReadFull(r, e[:]); err != nil { + + } + + case *[]byte: + bytes, err := wire.ReadVarBytes(r, 0, 66000, "[]byte") + if err != nil { + return err + } + *e = bytes + + case *lnwallet.SatPerKWeight: + var b [8]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return err + } + *e = lnwallet.SatPerKWeight(binary.BigEndian.Uint64(b[:])) + + case *ErrorCode: + var b [2]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return err + } + *e = ErrorCode(binary.BigEndian.Uint16(b[:])) + + case **btcec.PublicKey: + var b [btcec.PubKeyBytesLenCompressed]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return err + } + + pubKey, err := btcec.ParsePubKey(b[:], btcec.S256()) + if err != nil { + return err + } + *e = pubKey + + default: + return fmt.Errorf("Unknown type in ReadElement: %T", e) + } + + return nil +} + +// ReadElements deserializes a variable number of elements into the passed +// io.Reader, with each element being deserialized according to the ReadElement +// function. +func ReadElements(r io.Reader, elements ...interface{}) error { + for _, element := range elements { + err := ReadElement(r, element) + if err != nil { + return err + } + } + return nil +} diff --git a/watchtower/wtwire/wtwire_test.go b/watchtower/wtwire/wtwire_test.go new file mode 100644 index 00000000..56adb7cd --- /dev/null +++ b/watchtower/wtwire/wtwire_test.go @@ -0,0 +1,152 @@ +package wtwire_test + +import ( + "bytes" + "math/rand" + "reflect" + "testing" + "testing/quick" + "time" + + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/watchtower/wtwire" +) + +func randRawFeatureVector(r *rand.Rand) *lnwire.RawFeatureVector { + featureVec := lnwire.NewRawFeatureVector() + for i := 0; i < 10000; i++ { + if r.Int31n(2) == 0 { + featureVec.Set(lnwire.FeatureBit(i)) + } + } + return featureVec +} + +// TestWatchtowerWireProtocol uses the testing/quick package to create a series +// of fuzz tests to attempt to break a primary scenario which is implemented as +// property based testing scenario. +func TestWatchtowerWireProtocol(t *testing.T) { + t.Parallel() + + // mainScenario is the primary test that will programmatically be + // executed for all registered wire messages. The quick-checker within + // testing/quick will attempt to find an input to this function, s.t + // the function returns false, if so then we've found an input that + // violates our model of the system. + mainScenario := func(msg wtwire.Message) bool { + // Give a new message, we'll serialize the message into a new + // bytes buffer. + var b bytes.Buffer + if _, err := wtwire.WriteMessage(&b, msg, 0); err != nil { + t.Fatalf("unable to write msg: %v", err) + return false + } + + // Next, we'll ensure that the serialized payload (subtracting + // the 2 bytes for the message type) is _below_ the specified + // max payload size for this message. + payloadLen := uint32(b.Len()) - 2 + if payloadLen > msg.MaxPayloadLength(0) { + t.Fatalf("msg payload constraint violated: %v > %v", + payloadLen, msg.MaxPayloadLength(0)) + return false + } + + // Finally, we'll deserialize the message from the written + // buffer, and finally assert that the messages are equal. + newMsg, err := wtwire.ReadMessage(&b, 0) + if err != nil { + t.Fatalf("unable to read msg: %v", err) + return false + } + if !reflect.DeepEqual(msg, newMsg) { + t.Fatalf("messages don't match after re-encoding: %v "+ + "vs %v", spew.Sdump(msg), spew.Sdump(newMsg)) + return false + } + + return true + } + + customTypeGen := map[wtwire.MessageType]func([]reflect.Value, *rand.Rand){ + wtwire.MsgInit: func(v []reflect.Value, r *rand.Rand) { + req := wtwire.NewInitMessage( + randRawFeatureVector(r), + randRawFeatureVector(r), + ) + + v[0] = reflect.ValueOf(*req) + }, + } + + // With the above types defined, we'll now generate a slice of + // scenarios to feed into quick.Check. The function scans in input + // space of the target function under test, so we'll need to create a + // series of wrapper functions to force it to iterate over the target + // types, but re-use the mainScenario defined above. + tests := []struct { + msgType wtwire.MessageType + scenario interface{} + }{ + { + msgType: wtwire.MsgInit, + scenario: func(m wtwire.Init) bool { + return mainScenario(&m) + }, + }, + { + msgType: wtwire.MsgCreateSession, + scenario: func(m wtwire.CreateSession) bool { + return mainScenario(&m) + }, + }, + { + msgType: wtwire.MsgCreateSessionReply, + scenario: func(m wtwire.CreateSessionReply) bool { + return mainScenario(&m) + }, + }, + { + msgType: wtwire.MsgStateUpdate, + scenario: func(m wtwire.StateUpdate) bool { + return mainScenario(&m) + }, + }, + { + msgType: wtwire.MsgStateUpdateReply, + scenario: func(m wtwire.StateUpdateReply) bool { + return mainScenario(&m) + }, + }, + { + msgType: wtwire.MsgError, + scenario: func(m wtwire.Error) bool { + return mainScenario(&m) + }, + }, + } + for _, test := range tests { + var config *quick.Config + + // If the type defined is within the custom type gen map above, + // the we'll modify the default config to use this Value + // function that knows how to generate the proper types. + if valueGen, ok := customTypeGen[test.msgType]; ok { + config = &quick.Config{ + Values: valueGen, + } + } + + t.Logf("Running fuzz tests for msgType=%v", test.msgType) + if err := quick.Check(test.scenario, config); err != nil { + t.Fatalf("fuzz checks for msg=%v failed: %v", + test.msgType, err) + } + } + +} + +func init() { + rand.Seed(time.Now().Unix()) +}