lnd.xprv/lnwire/message.go
Olaoluwa Osuntokun fcff17c336
multi: change all imports to roasbeef's forks
This commit will allow the general public to build lnd without jumping
through hoops setting up their local git branches nicely with all of
our forks.
2016-05-15 17:22:37 +03:00

266 lines
6.8 KiB
Go

// Code derived from https://github.com/btcsuite/btcd/blob/master/wire/message.go
package lnwire
import (
"bytes"
"fmt"
"io"
"github.com/roasbeef/btcd/wire"
)
const (
	// MessageHeaderSize is the length in bytes of the header that precedes
	// every payload: a 4-byte network magic, a 4-byte command ID and a
	// 4-byte payload length.
	MessageHeaderSize = 4 + 4 + 4

	// MaxMessagePayload is the largest payload, in bytes (32 MiB), that
	// WriteMessage will emit or ReadMessage will accept, regardless of the
	// message type.
	MaxMessagePayload = 32 << 20
)
// Command IDs for every wire message type. The ID travels in the message
// header and selects the concrete Go type in makeEmptyMessage.
const (
	// Funding channel open.
	CmdFundingRequest      uint32 = 200
	CmdFundingResponse     uint32 = 210
	CmdFundingSignAccept   uint32 = 220
	CmdFundingSignComplete uint32 = 230

	// Close channel.
	CmdCloseRequest  uint32 = 300
	CmdCloseComplete uint32 = 310

	// HTLC payment.
	// TODO: renumber this group into the 1100 range.
	CmdHTLCAddRequest uint32 = 1000
	CmdHTLCAddAccept  uint32 = 1010
	CmdHTLCAddReject  uint32 = 1020

	// HTLC settlement.
	// TODO: renumber this group into the 1200 range.
	CmdHTLCSettleRequest uint32 = 1100
	CmdHTLCSettleAccept  uint32 = 1110

	// HTLC timeout.
	CmdHTLCTimeoutRequest uint32 = 1300
	CmdHTLCTimeoutAccept  uint32 = 1310

	// Commitments.
	CmdCommitSignature  uint32 = 2000
	CmdCommitRevocation uint32 = 2010

	// Error.
	CmdErrorGeneric uint32 = 4000
)
// Message is implemented by every lnwire message type. ReadMessage and
// WriteMessage work with messages only through these methods.
type Message interface {
	// Decode populates the message from r using protocol version pver.
	Decode(r io.Reader, pver uint32) error

	// Encode serializes the message to w using protocol version pver.
	Encode(w io.Writer, pver uint32) error

	// Command returns the message ID written into the header (see the
	// Cmd* constants).
	Command() uint32

	// MaxPayloadLength returns the largest encoded payload, in bytes,
	// allowed for this message type under protocol version pver.
	MaxPayloadLength(pver uint32) uint32

	// Validate checks the message's data; ReadMessage calls it after a
	// successful Decode.
	Validate() error

	// String returns a human-readable rendering of the message.
	String() string
}
// makeEmptyMessage returns a freshly allocated, zero-valued message whose
// concrete type corresponds to command, ready to be filled in by Decode.
// An unrecognized command yields a nil Message and an error.
func makeEmptyMessage(command uint32) (Message, error) {
	switch command {
	// Funding channel open.
	case CmdFundingRequest:
		return &FundingRequest{}, nil
	case CmdFundingResponse:
		return &FundingResponse{}, nil
	case CmdFundingSignAccept:
		return &FundingSignAccept{}, nil
	case CmdFundingSignComplete:
		return &FundingSignComplete{}, nil

	// Close channel.
	case CmdCloseRequest:
		return &CloseRequest{}, nil
	case CmdCloseComplete:
		return &CloseComplete{}, nil

	// HTLC payment.
	case CmdHTLCAddRequest:
		return &HTLCAddRequest{}, nil
	case CmdHTLCAddAccept:
		return &HTLCAddAccept{}, nil
	case CmdHTLCAddReject:
		return &HTLCAddReject{}, nil

	// HTLC settlement.
	case CmdHTLCSettleRequest:
		return &HTLCSettleRequest{}, nil
	case CmdHTLCSettleAccept:
		return &HTLCSettleAccept{}, nil

	// HTLC timeout.
	case CmdHTLCTimeoutRequest:
		return &HTLCTimeoutRequest{}, nil
	case CmdHTLCTimeoutAccept:
		return &HTLCTimeoutAccept{}, nil

	// Commitments.
	case CmdCommitSignature:
		return &CommitSignature{}, nil
	case CmdCommitRevocation:
		return &CommitRevocation{}, nil

	// Error.
	case CmdErrorGeneric:
		return &ErrorGeneric{}, nil

	default:
		return nil, fmt.Errorf("unhandled command [%d]", command)
	}
}
// messageHeader is the fixed MessageHeaderSize-byte prefix that precedes
// every payload. readMessageHeader and WriteMessage (de)serialize its fields
// in declaration order: magic, command, length.
type messageHeader struct {
	// magic identifies the network the sender is on; ReadMessage rejects
	// messages whose magic differs from the expected network.
	//
	// NOTE(j): Overlap with bitcoin's own magic values should not matter
	// because this traffic is carried inside encrypted comms, but XOR-ing
	// with wire.TestNet3 could be considered as an extra precaution.
	magic wire.BitcoinNet

	// command is the message ID; length is the size in bytes of the
	// payload that follows the header.
	command, length uint32
}
// readMessageHeader reads exactly MessageHeaderSize bytes from r and parses
// them into a messageHeader (magic, command, length, in that order).
//
// It returns the number of bytes consumed from r, which may be non-zero even
// when an error is returned. On error the header pointer is nil.
func readMessageHeader(r io.Reader) (int, *messageHeader, error) {
	var raw [MessageHeaderSize]byte
	read, err := io.ReadFull(r, raw[:])
	if err != nil {
		return read, nil, err
	}

	// Parse from an in-memory copy so r is touched by exactly one read.
	hdr := new(messageHeader)
	if err = readElements(bytes.NewReader(raw[:]), &hdr.magic, &hdr.command, &hdr.length); err != nil {
		return read, nil, err
	}
	return read, hdr, nil
}
// discardInput reads n bytes from r and throws them away. It is used to skip
// the payload of a message that is rejected after its header has been read.
// Reads happen in chunks of at most 10 KiB, so a forged header length cannot
// force a large allocation.
//
// Discarding is best-effort. If r returns an error (e.g. at EOF), the
// remaining bytes are abandoned, because further reads would only fail again.
func discardInput(r io.Reader, n uint32) {
	const maxChunk = 10 * 1024 // 10k at a time

	// Size the scratch buffer to what is actually needed, capped at
	// maxChunk, rather than always allocating a full chunk for small discards.
	bufSize := n
	if bufSize > maxChunk {
		bufSize = maxChunk
	}
	if bufSize == 0 {
		return
	}
	buf := make([]byte, bufSize)

	for remaining := n; remaining > 0; {
		chunk := remaining
		if chunk > bufSize {
			chunk = bufSize
		}
		if _, err := io.ReadFull(r, buf[:chunk]); err != nil {
			return
		}
		remaining -= chunk
	}
}
// WriteMessage encodes msg with protocol version pver and writes it to w as a
// MessageHeaderSize-byte header (btcnet, the message's command ID, the
// payload length) followed by the payload itself.
//
// The payload must not exceed MaxMessagePayload or msg.MaxPayloadLength(pver).
// The returned count is the number of bytes actually written to w, which may
// be non-zero even when an error is returned (e.g. if the payload write fails
// after the header was written).
func WriteMessage(w io.Writer, msg Message, pver uint32, btcnet wire.BitcoinNet) (int, error) {
	totalBytes := 0
	cmd := msg.Command()

	// Encode the payload first so its length is known when building the
	// header.
	var bw bytes.Buffer
	if err := msg.Encode(&bw, pver); err != nil {
		return totalBytes, err
	}
	payload := bw.Bytes()
	lenp := len(payload)

	// Enforce the global payload ceiling.
	if lenp > MaxMessagePayload {
		return totalBytes, fmt.Errorf("message payload is too large - encoded %d bytes, but maximum message payload is %d bytes", lenp, MaxMessagePayload)
	}

	// Enforce the per-type payload ceiling.
	mpl := msg.MaxPayloadLength(pver)
	if uint32(lenp) > mpl {
		return totalBytes, fmt.Errorf("message payload is too large - encoded %d bytes, but maximum message payload of type %x is %d bytes", lenp, cmd, mpl)
	}

	hdr := messageHeader{
		magic:   btcnet,
		command: cmd,
		length:  uint32(lenp),
	}

	// Serialize the header into a buffer rather than directly to w, since
	// writeElements doesn't report how many bytes it wrote. Its error must
	// be checked: otherwise a failed encode would put a truncated (or empty)
	// header on the wire ahead of the payload and desynchronize the peer.
	hw := bytes.NewBuffer(make([]byte, 0, MessageHeaderSize))
	if err := writeElements(hw, hdr.magic, hdr.command, hdr.length); err != nil {
		return totalBytes, err
	}

	// Write header.
	n, err := w.Write(hw.Bytes())
	totalBytes += n
	if err != nil {
		return totalBytes, err
	}

	// Write payload.
	n, err = w.Write(payload)
	totalBytes += n
	if err != nil {
		return totalBytes, err
	}

	return totalBytes, nil
}
// ReadMessage reads a single message from r.
//
// It does the following, in order:
//   - consumes a header;
//   - checks the header's declared length against MaxMessagePayload;
//   - checks the header's network against btcnet;
//   - builds an empty message of the type named by the header's command;
//   - checks the declared length against that type's limit for protocol
//     version pver;
//   - reads, decodes and validates the payload.
//
// On success it returns the number of bytes read, the decoded message and
// the raw payload. The byte count covers only the header and payload reads
// performed here; bytes skipped via discardInput are not counted.
//
// For a foreign network, an unknown command or an over-long payload for the
// message type, the announced payload is discarded before the error is
// returned. A declared length above MaxMessagePayload is rejected without
// discarding anything.
func ReadMessage(r io.Reader, pver uint32, btcnet wire.BitcoinNet) (int, Message, []byte, error) {
	read, hdr, err := readMessageHeader(r)
	if err != nil {
		return read, nil, nil, err
	}

	// Reject absurd lengths up front; nothing is drained in this case.
	if hdr.length > MaxMessagePayload {
		return read, nil, nil, fmt.Errorf("message payload is too large - header indicates %d bytes, but max message payload is %d bytes.", hdr.length, MaxMessagePayload)
	}

	// Foreign network: skip the payload and report it.
	if hdr.magic != btcnet {
		discardInput(r, hdr.length)
		return read, nil, nil, fmt.Errorf("message from other network [%v]", hdr.magic)
	}

	msg, err := makeEmptyMessage(hdr.command)
	if err != nil {
		discardInput(r, hdr.length)
		return read, nil, nil, fmt.Errorf("ReadMessage %s", err)
	}

	// Per-type length limit.
	if limit := msg.MaxPayloadLength(pver); hdr.length > limit {
		discardInput(r, hdr.length)
		return read, nil, nil, fmt.Errorf("payload exceeds max length. indicates %v bytes, but max of message type %v is %v.", hdr.length, hdr.command, limit)
	}

	payload := make([]byte, hdr.length)
	got, err := io.ReadFull(r, payload)
	read += got
	if err != nil {
		return read, nil, nil, err
	}

	if err = msg.Decode(bytes.NewBuffer(payload), pver); err != nil {
		return read, nil, nil, err
	}
	if err = msg.Validate(); err != nil {
		return read, nil, nil, err
	}

	return read, msg, payload, nil
}