watchtower/wtwire/message: define wtwire message interface
This commit is contained in:
parent
4325d9ec1e
commit
49b2a3bdb5
179
watchtower/wtwire/message.go
Normal file
179
watchtower/wtwire/message.go
Normal file
@ -0,0 +1,179 @@
|
||||
package wtwire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
)
|
||||
|
||||
// MaxMessagePayload is the maximum bytes a message can be regardless of other
|
||||
// individual limits imposed by messages themselves.
|
||||
const MaxMessagePayload = 65535 // 65KB
|
||||
|
||||
// MessageType is the unique 2 byte big-endian integer that indicates the type
|
||||
// of message on the wire. All messages have a very simple header which
|
||||
// consists simply of 2-byte message type. We omit a length field, and checksum
|
||||
// as the Watchtower Protocol is intended to be encapsulated within a
|
||||
// confidential+authenticated cryptographic messaging protocol.
|
||||
type MessageType uint16
|
||||
|
||||
// The currently defined message types within this current version of the
|
||||
// Watchtower protocol.
|
||||
const (
|
||||
// MsgInit identifies an encoded Init message.
|
||||
MsgInit MessageType = 300
|
||||
|
||||
// MsgError identifies an encoded Error message.
|
||||
MsgError = 301
|
||||
|
||||
// MsgCreateSession identifies an encoded CreateSession message.
|
||||
MsgCreateSession MessageType = 302
|
||||
|
||||
// MsgCreateSessionReply identifies an encoded CreateSessionReply message.
|
||||
MsgCreateSessionReply MessageType = 303
|
||||
|
||||
// MsgStateUpdate identifies an encoded StateUpdate message.
|
||||
MsgStateUpdate MessageType = 304
|
||||
|
||||
// MsgStateUpdateReply identifies an encoded StateUpdateReply message.
|
||||
MsgStateUpdateReply MessageType = 305
|
||||
)
|
||||
|
||||
// String returns a human readable description of the message type.
|
||||
func (m MessageType) String() string {
|
||||
switch m {
|
||||
case MsgInit:
|
||||
return "Init"
|
||||
case MsgCreateSession:
|
||||
return "MsgCreateSession"
|
||||
case MsgCreateSessionReply:
|
||||
return "MsgCreateSessionReply"
|
||||
case MsgStateUpdate:
|
||||
return "MsgStateUpdate"
|
||||
case MsgStateUpdateReply:
|
||||
return "MsgStateUpdateReply"
|
||||
case MsgError:
|
||||
return "Error"
|
||||
default:
|
||||
return "<unknown>"
|
||||
}
|
||||
}
|
||||
|
||||
// Serializable is an interface which defines a lightning wire serializable
|
||||
// object.
|
||||
type Serializable = lnwire.Serializable
|
||||
|
||||
// Message is an interface that defines a lightning wire protocol message. The
|
||||
// interface is general in order to allow implementing types full control over
|
||||
// the representation of its data.
|
||||
type Message interface {
|
||||
Serializable
|
||||
|
||||
// MsgType returns a MessageType that uniquely identifies the message to
|
||||
// be encoded.
|
||||
MsgType() MessageType
|
||||
|
||||
// MaxMessagePayload is the maximum serialized length that a particular
|
||||
// message type can take.
|
||||
MaxPayloadLength(uint32) uint32
|
||||
}
|
||||
|
||||
// makeEmptyMessage creates a new empty message of the proper concrete type
|
||||
// based on the passed message type.
|
||||
func makeEmptyMessage(msgType MessageType) (Message, error) {
|
||||
var msg Message
|
||||
|
||||
switch msgType {
|
||||
case MsgInit:
|
||||
msg = &Init{&lnwire.Init{}}
|
||||
case MsgCreateSession:
|
||||
msg = &CreateSession{}
|
||||
case MsgCreateSessionReply:
|
||||
msg = &CreateSessionReply{}
|
||||
case MsgStateUpdate:
|
||||
msg = &StateUpdate{}
|
||||
case MsgStateUpdateReply:
|
||||
msg = &StateUpdateReply{}
|
||||
case MsgError:
|
||||
msg = &Error{}
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown message type [%d]", msgType)
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// WriteMessage writes a lightning Message to w including the necessary header
|
||||
// information and returns the number of bytes written.
|
||||
func WriteMessage(w io.Writer, msg Message, pver uint32) (int, error) {
|
||||
totalBytes := 0
|
||||
|
||||
// Encode the message payload itself into a temporary buffer.
|
||||
// TODO(roasbeef): create buffer pool
|
||||
var bw bytes.Buffer
|
||||
if err := msg.Encode(&bw, pver); err != nil {
|
||||
return totalBytes, err
|
||||
}
|
||||
payload := bw.Bytes()
|
||||
lenp := len(payload)
|
||||
|
||||
// Enforce maximum overall message payload.
|
||||
if lenp > MaxMessagePayload {
|
||||
return totalBytes, fmt.Errorf("message payload is too large - "+
|
||||
"encoded %d bytes, but maximum message payload is %d bytes",
|
||||
lenp, MaxMessagePayload)
|
||||
}
|
||||
|
||||
// Enforce maximum message payload on the message type.
|
||||
mpl := msg.MaxPayloadLength(pver)
|
||||
if uint32(lenp) > mpl {
|
||||
return totalBytes, fmt.Errorf("message payload is too large - "+
|
||||
"encoded %d bytes, but maximum message payload of "+
|
||||
"type %v is %d bytes", lenp, msg.MsgType(), mpl)
|
||||
}
|
||||
|
||||
// With the initial sanity checks complete, we'll now write out the
|
||||
// message type itself.
|
||||
var mType [2]byte
|
||||
binary.BigEndian.PutUint16(mType[:], uint16(msg.MsgType()))
|
||||
n, err := w.Write(mType[:])
|
||||
totalBytes += n
|
||||
if err != nil {
|
||||
return totalBytes, err
|
||||
}
|
||||
|
||||
// With the message type written, we'll now write out the raw payload
|
||||
// itself.
|
||||
n, err = w.Write(payload)
|
||||
totalBytes += n
|
||||
|
||||
return totalBytes, err
|
||||
}
|
||||
|
||||
// ReadMessage reads, validates, and parses the next Watchtower message from r
|
||||
// for the provided protocol version.
|
||||
func ReadMessage(r io.Reader, pver uint32) (Message, error) {
|
||||
// First, we'll read out the first two bytes of the message so we can
|
||||
// create the proper empty message.
|
||||
var mType [2]byte
|
||||
if _, err := io.ReadFull(r, mType[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msgType := MessageType(binary.BigEndian.Uint16(mType[:]))
|
||||
|
||||
// Now that we know the target message type, we can create the proper
|
||||
// empty message type and decode the message into it.
|
||||
msg, err := makeEmptyMessage(msgType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := msg.Decode(r, pver); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
Loading…
Reference in New Issue
Block a user