lnwire: convert message parsing to use the new minimal type header
This commit abandons our old bitcoin inspired message header and replaces it with the bare type-only message headers that’s currently used within the draft specification. As a result the message header now consists of only 2-bytes for the message type, then actual payload itself. With this change, the daemon will now need to switch to a purely message based wire protocol in order to be able to handle the extra data that can be extended to arbitrary messages.
This commit is contained in:
parent
6f2d3b3cc5
commit
febc8c399a
@ -9,13 +9,6 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MessageHeaderSize is the number of bytes in a lightning message header.
|
|
||||||
// The bytes are allocated as follows: network magic 4 bytes + command 4
|
|
||||||
// bytes + payload length 4 bytes. Note that a checksum is omitted as lightning
|
|
||||||
// messages are assumed to be transmitted over an AEAD secured connection which
|
|
||||||
// provides integrity over the entire message.
|
|
||||||
const MessageHeaderSize = 12
|
|
||||||
|
|
||||||
// MaxMessagePayload is the maximum bytes a message can be regardless of other
|
// MaxMessagePayload is the maximum bytes a message can be regardless of other
|
||||||
// individual limits imposed by messages themselves.
|
// individual limits imposed by messages themselves.
|
||||||
const MaxMessagePayload = 65535 // 65KB
|
const MaxMessagePayload = 65535 // 65KB
|
||||||
@ -127,182 +120,82 @@ func makeEmptyMessage(msgType MessageType) (Message, error) {
|
|||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// messageHeader represents the header structure for all lightning protocol
|
|
||||||
// messages.
|
|
||||||
type messageHeader struct {
|
|
||||||
// magic represents Which Blockchain Technology(TM) to use.
|
|
||||||
// NOTE(j): We don't need to worry about the magic overlapping with
|
|
||||||
// bitcoin since this is inside encrypted comms anyway, but maybe we
|
|
||||||
// should use the XOR (^wire.TestNet3) just in case???
|
|
||||||
magic wire.BitcoinNet // 4 bytes
|
|
||||||
command uint32 // 4 bytes
|
|
||||||
length uint32 // 4 bytes
|
|
||||||
}
|
|
||||||
|
|
||||||
// readMessageHeader reads a lightning protocol message header from r.
|
|
||||||
func readMessageHeader(r io.Reader) (int, *messageHeader, error) {
|
|
||||||
// As the message header is a fixed size structure, read bytes for the
|
|
||||||
// entire header at once.
|
|
||||||
var headerBytes [MessageHeaderSize]byte
|
|
||||||
n, err := io.ReadFull(r, headerBytes[:])
|
|
||||||
if err != nil {
|
|
||||||
return n, nil, err
|
|
||||||
}
|
|
||||||
hr := bytes.NewReader(headerBytes[:])
|
|
||||||
|
|
||||||
// Create and populate the message header from the raw header bytes.
|
|
||||||
hdr := messageHeader{}
|
|
||||||
err = readElements(hr,
|
|
||||||
&hdr.magic,
|
|
||||||
&hdr.command,
|
|
||||||
&hdr.length)
|
|
||||||
if err != nil {
|
|
||||||
return n, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return n, &hdr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// discardInput reads n bytes from reader r in chunks and discards the read
|
|
||||||
// bytes. This is used to skip payloads when various errors occur and helps
|
|
||||||
// prevent rogue nodes from causing massive memory allocation through forging
|
|
||||||
// header length.
|
|
||||||
func discardInput(r io.Reader, n uint32) {
|
|
||||||
maxSize := uint32(10 * 1024) // 10k at a time
|
|
||||||
numReads := n / maxSize
|
|
||||||
bytesRemaining := n % maxSize
|
|
||||||
if n > 0 {
|
|
||||||
buf := make([]byte, maxSize)
|
|
||||||
for i := uint32(0); i < numReads; i++ {
|
|
||||||
io.ReadFull(r, buf)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if bytesRemaining > 0 {
|
|
||||||
buf := make([]byte, bytesRemaining)
|
|
||||||
io.ReadFull(r, buf)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteMessage writes a lightning Message to w including the necessary header
|
// WriteMessage writes a lightning Message to w including the necessary header
|
||||||
// information and returns the number of bytes written.
|
// information and returns the number of bytes written.
|
||||||
func WriteMessage(w io.Writer, msg Message, pver uint32, btcnet wire.BitcoinNet) (int, error) {
|
func WriteMessage(w io.Writer, msg Message, pver uint32) (int, error) {
|
||||||
totalBytes := 0
|
totalBytes := 0
|
||||||
|
|
||||||
cmd := msg.Command()
|
// Encode the message payload itself into a temporary buffer.
|
||||||
|
// TODO(roasbeef): create buffer pool
|
||||||
// Encode the message payload
|
|
||||||
var bw bytes.Buffer
|
var bw bytes.Buffer
|
||||||
err := msg.Encode(&bw, pver)
|
if err := msg.Encode(&bw, pver); err != nil {
|
||||||
if err != nil {
|
|
||||||
return totalBytes, err
|
return totalBytes, err
|
||||||
}
|
}
|
||||||
payload := bw.Bytes()
|
payload := bw.Bytes()
|
||||||
lenp := len(payload)
|
lenp := len(payload)
|
||||||
|
|
||||||
// Enforce maximum overall message payload
|
// Enforce maximum overall message payload.
|
||||||
if lenp > MaxMessagePayload {
|
if lenp > MaxMessagePayload {
|
||||||
return totalBytes, fmt.Errorf("message payload is too large - "+
|
return totalBytes, fmt.Errorf("message payload is too large - "+
|
||||||
"encoded %d bytes, but maximum message payload is %d bytes",
|
"encoded %d bytes, but maximum message payload is %d bytes",
|
||||||
lenp, MaxMessagePayload)
|
lenp, MaxMessagePayload)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enforce maximum message payload on the message type
|
// Enforce maximum message payload on the message type.
|
||||||
mpl := msg.MaxPayloadLength(pver)
|
mpl := msg.MaxPayloadLength(pver)
|
||||||
if uint32(lenp) > mpl {
|
if uint32(lenp) > mpl {
|
||||||
return totalBytes, fmt.Errorf("message payload is too large - "+
|
return totalBytes, fmt.Errorf("message payload is too large - "+
|
||||||
"encoded %d bytes, but maximum message payload of "+
|
"encoded %d bytes, but maximum message payload of "+
|
||||||
"type %x is %d bytes", lenp, cmd, mpl)
|
"type %x is %d bytes", lenp, msg.MsgType(), mpl)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create header for the message.
|
// With the initial sanity checks complete, we'll now write out the
|
||||||
hdr := messageHeader{magic: btcnet, command: cmd, length: uint32(lenp)}
|
// message type itself.
|
||||||
|
var mType [2]byte
|
||||||
// Encode the header for the message. This is done to a buffer
|
binary.BigEndian.PutUint16(mType[:], uint16(msg.MsgType()))
|
||||||
// rather than directly to the writer since writeElements doesn't
|
n, err := w.Write(mType[:])
|
||||||
// return the number of bytes written.
|
|
||||||
hw := bytes.NewBuffer(make([]byte, 0, MessageHeaderSize))
|
|
||||||
if err := writeElements(hw, hdr.magic, hdr.command, hdr.length); err != nil {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write the header first.
|
|
||||||
n, err := w.Write(hw.Bytes())
|
|
||||||
totalBytes += n
|
totalBytes += n
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return totalBytes, err
|
return totalBytes, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write payload the payload itself after the header.
|
// With the message type written, we'll now write out the raw payload
|
||||||
|
// itself.
|
||||||
n, err = w.Write(payload)
|
n, err = w.Write(payload)
|
||||||
totalBytes += n
|
totalBytes += n
|
||||||
|
|
||||||
return totalBytes, err
|
return totalBytes, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadMessage reads, validates, and parses the next bitcoin Message from r for
|
// ReadMessage reads, validates, and parses the next bitcoin Message from r for
|
||||||
// the provided protocol version and bitcoin network. It returns the number of
|
// the provided protocol version. It returns the number of bytes read in
|
||||||
// bytes read in addition to the parsed Message and raw bytes which comprise the
|
// addition to the parsed Message and raw bytes which comprise the message.
|
||||||
// message. This function is the same as ReadMessage except it also returns the
|
func ReadMessage(r io.Reader, pver uint32) (int, Message, error) {
|
||||||
// number of bytes read.
|
// TODO(roasbeef): need to explicitly enforce max message payload, or
|
||||||
func ReadMessage(r io.Reader, pver uint32, btcnet wire.BitcoinNet) (int, Message, []byte, error) {
|
// just allow it to be done by the MaxPayloadLength?
|
||||||
totalBytes := 0
|
totalBytes := 0
|
||||||
n, hdr, err := readMessageHeader(r)
|
|
||||||
|
// First, we'll read out the first two bytes of the message so we can
|
||||||
|
// create the proper empty message.
|
||||||
|
var mType [2]byte
|
||||||
|
n, err := io.ReadFull(r, mType[:])
|
||||||
totalBytes += n
|
totalBytes += n
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return totalBytes, nil, nil, err
|
return totalBytes, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enforce maximum message payload
|
msgType := MessageType(binary.BigEndian.Uint16(mType[:]))
|
||||||
if hdr.length > MaxMessagePayload {
|
|
||||||
return totalBytes, nil, nil, fmt.Errorf("message payload is "+
|
|
||||||
"too large - header indicates %d bytes, but max "+
|
|
||||||
"message payload is %d bytes.", hdr.length,
|
|
||||||
MaxMessagePayload)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for messages in the wrong network.
|
// Now that we know the target message type, we can create the proper
|
||||||
if hdr.magic != btcnet {
|
// empty message type and decode the message into it.
|
||||||
discardInput(r, hdr.length)
|
msg, err := makeEmptyMessage(msgType)
|
||||||
return totalBytes, nil, nil, fmt.Errorf("message from other "+
|
|
||||||
"network [%v]", hdr.magic)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create struct of appropriate message type based on the command.
|
|
||||||
command := hdr.command
|
|
||||||
msg, err := makeEmptyMessage(command)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
discardInput(r, hdr.length)
|
return totalBytes, nil, err
|
||||||
return totalBytes, nil, nil, &UnknownMessage{
|
|
||||||
messageType: command,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
if err := msg.Decode(r, pver); err != nil {
|
||||||
// Check for maximum length based on the message type.
|
return totalBytes, nil, err
|
||||||
mpl := msg.MaxPayloadLength(pver)
|
|
||||||
if hdr.length > mpl {
|
|
||||||
discardInput(r, hdr.length)
|
|
||||||
return totalBytes, nil, nil, fmt.Errorf("payload exceeds max "+
|
|
||||||
"length. indicates %v bytes, but max of message type %v is %v.",
|
|
||||||
hdr.length, command, mpl)
|
|
||||||
}
|
}
|
||||||
|
totalBytes += int(msg.MaxPayloadLength(pver))
|
||||||
|
|
||||||
// Read payload.
|
return totalBytes, msg, nil
|
||||||
payload := make([]byte, hdr.length)
|
|
||||||
n, err = io.ReadFull(r, payload)
|
|
||||||
totalBytes += n
|
|
||||||
if err != nil {
|
|
||||||
return totalBytes, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unmarshal message.
|
|
||||||
pr := bytes.NewBuffer(payload)
|
|
||||||
if err = msg.Decode(pr, pver); err != nil {
|
|
||||||
return totalBytes, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate the data.
|
|
||||||
if err = msg.Validate(); err != nil {
|
|
||||||
return totalBytes, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return totalBytes, msg, payload, nil
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user