2016-01-17 04:14:35 +03:00
// Code derived from https://github.com/btcsuite/btcd/blob/master/wire/message.go
2015-12-28 14:24:16 +03:00
package lnwire
import (
"bytes"
"fmt"
"io"
2015-12-31 19:38:33 +03:00
2016-05-15 17:17:44 +03:00
"github.com/roasbeef/btcd/wire"
2015-12-31 19:38:33 +03:00
)
2016-01-17 04:14:35 +03:00
// Wire framing limits.
const (
	// MessageHeaderSize is the size in bytes of every message header:
	// 4-byte network magic + 4-byte message ID + 4-byte payload length.
	MessageHeaderSize = 4 + 4 + 4

	// MaxMessagePayload is the largest payload, in bytes, that any single
	// message may carry, regardless of its type.
	MaxMessagePayload = 32 << 20 // 32MB
)
2015-12-28 14:24:16 +03:00
// Command IDs carried in the message header, one per wire message type.
// IDs are grouped by protocol phase in blocks of hundreds.
const (
	// Funding channel open.
	CmdFundingRequest      uint32 = 200
	CmdFundingResponse     uint32 = 210
	CmdFundingSignAccept   uint32 = 220
	CmdFundingSignComplete uint32 = 230

	// Close channel.
	CmdCloseRequest  uint32 = 300
	CmdCloseComplete uint32 = 310

	// HTLC payment.
	// TODO: Renumber to 1100.
	CmdHTLCAddRequest uint32 = 1000
	CmdHTLCAddAccept  uint32 = 1010
	CmdHTLCAddReject  uint32 = 1020

	// HTLC settlement.
	// TODO: Renumber to 1200.
	CmdHTLCSettleRequest uint32 = 1100
	CmdHTLCSettleAccept  uint32 = 1110

	// HTLC timeout.
	CmdHTLCTimeoutRequest uint32 = 1300
	CmdHTLCTimeoutAccept  uint32 = 1310

	// Commitments.
	CmdCommitSignature  uint32 = 2000
	CmdCommitRevocation uint32 = 2010

	// Error.
	CmdErrorGeneric uint32 = 4000
)
2016-01-17 04:14:35 +03:00
// Message is implemented by every lnwire message type. ReadMessage and
// WriteMessage use it to frame, size-check, encode, decode and validate
// messages generically.
type Message interface {
	// Decode reads the message payload from r using protocol version pver.
	Decode(r io.Reader, pver uint32) error

	// Encode writes the message payload to w using protocol version pver.
	Encode(w io.Writer, pver uint32) error

	// Command returns the command ID that identifies the message type.
	Command() uint32

	// MaxPayloadLength returns the maximum payload size, in bytes, allowed
	// for this message type under protocol version pver.
	MaxPayloadLength(pver uint32) uint32

	// Validate checks the message's data fields; ReadMessage calls it
	// after a successful Decode.
	Validate() error

	// String returns a human-readable representation of the message.
	String() string
}
func makeEmptyMessage ( command uint32 ) ( Message , error ) {
var msg Message
switch command {
case CmdFundingRequest :
msg = & FundingRequest { }
2015-12-30 16:38:57 +03:00
case CmdFundingResponse :
msg = & FundingResponse { }
case CmdFundingSignAccept :
msg = & FundingSignAccept { }
case CmdFundingSignComplete :
msg = & FundingSignComplete { }
2015-12-31 13:42:25 +03:00
case CmdCloseRequest :
msg = & CloseRequest { }
case CmdCloseComplete :
msg = & CloseComplete { }
2016-01-05 19:19:22 +03:00
case CmdHTLCAddRequest :
msg = & HTLCAddRequest { }
case CmdHTLCAddAccept :
msg = & HTLCAddAccept { }
case CmdHTLCAddReject :
msg = & HTLCAddReject { }
case CmdHTLCSettleRequest :
msg = & HTLCSettleRequest { }
case CmdHTLCSettleAccept :
msg = & HTLCSettleAccept { }
case CmdHTLCTimeoutRequest :
msg = & HTLCTimeoutRequest { }
case CmdHTLCTimeoutAccept :
msg = & HTLCTimeoutAccept { }
case CmdCommitSignature :
msg = & CommitSignature { }
case CmdCommitRevocation :
msg = & CommitRevocation { }
2016-01-05 19:53:42 +03:00
case CmdErrorGeneric :
msg = & ErrorGeneric { }
2015-12-28 14:24:16 +03:00
default :
2015-12-30 16:38:57 +03:00
return nil , fmt . Errorf ( "unhandled command [%d]" , command )
2015-12-28 14:24:16 +03:00
}
return msg , nil
}
type messageHeader struct {
2016-01-17 04:14:35 +03:00
// NOTE(j): We don't need to worry about the magic overlapping with
// bitcoin since this is inside encrypted comms anyway, but maybe we
// should use the XOR (^wire.TestNet3) just in case???
magic wire . BitcoinNet // which Blockchain Technology(TM) to use
2015-12-28 14:24:16 +03:00
command uint32
length uint32
}
func readMessageHeader ( r io . Reader ) ( int , * messageHeader , error ) {
var headerBytes [ MessageHeaderSize ] byte
n , err := io . ReadFull ( r , headerBytes [ : ] )
if err != nil {
return n , nil , err
}
hr := bytes . NewReader ( headerBytes [ : ] )
hdr := messageHeader { }
2015-12-31 10:34:40 +03:00
err = readElements ( hr ,
2015-12-28 14:24:16 +03:00
& hdr . magic ,
& hdr . command ,
& hdr . length )
if err != nil {
return n , nil , err
}
return n , & hdr , nil
}
2016-01-17 04:14:35 +03:00
// discardInput reads and discards up to n bytes from r, at most 10KiB at a
// time. It is used to skip a payload after its header has been rejected, so
// the stream stays aligned on the next message without allocating a buffer
// proportional to the (peer-supplied) header length.
//
// Discarding stops at the first read error (e.g. EOF or a closed
// connection): once r has failed, further reads cannot make progress. The
// function has no error return, so the failure will typically surface on
// the caller's next read from r.
func discardInput(r io.Reader, n uint32) {
	const maxChunk = 10 * 1024 // 10k at a time

	if n == 0 {
		return
	}

	// Size the scratch buffer to min(n, maxChunk) so small discards don't
	// allocate a full chunk, and reuse it for the final partial chunk.
	bufSize := uint32(maxChunk)
	if n < bufSize {
		bufSize = n
	}
	buf := make([]byte, bufSize)

	for n > 0 {
		chunk := buf
		if n < uint32(len(chunk)) {
			chunk = chunk[:n]
		}
		got, err := io.ReadFull(r, chunk)
		n -= uint32(got) // got <= len(chunk) <= n, so this cannot underflow
		if err != nil {
			return
		}
	}
}
func WriteMessage ( w io . Writer , msg Message , pver uint32 , btcnet wire . BitcoinNet ) ( int , error ) {
totalBytes := 0
cmd := msg . Command ( )
2016-01-17 04:14:35 +03:00
// Encode the message payload
2015-12-28 14:24:16 +03:00
var bw bytes . Buffer
err := msg . Encode ( & bw , pver )
if err != nil {
return totalBytes , err
}
payload := bw . Bytes ( )
lenp := len ( payload )
2016-01-17 04:14:35 +03:00
// Enforce maximum overall message payload
2015-12-28 14:24:16 +03:00
if lenp > MaxMessagePayload {
return totalBytes , fmt . Errorf ( "message payload is too large - encoded %d bytes, but maximum message payload is %d bytes" , lenp , MaxMessagePayload )
}
2016-01-17 04:14:35 +03:00
// Enforce maximum message payload on the message type
2015-12-28 14:24:16 +03:00
mpl := msg . MaxPayloadLength ( pver )
if uint32 ( lenp ) > mpl {
return totalBytes , fmt . Errorf ( "message payload is too large - encoded %d bytes, but maximum message payload of type %x is %d bytes" , lenp , cmd , mpl )
}
2016-01-17 04:14:35 +03:00
// Create header for the message
2015-12-28 14:24:16 +03:00
hdr := messageHeader { }
hdr . magic = btcnet
hdr . command = cmd
hdr . length = uint32 ( lenp )
2016-01-17 04:14:35 +03:00
// Encode the header for the message. This is done to a buffer
// rather than directly to the writer since writeElements doesn't
// return the number of bytes written.
2015-12-28 14:24:16 +03:00
hw := bytes . NewBuffer ( make ( [ ] byte , 0 , MessageHeaderSize ) )
2015-12-31 10:34:40 +03:00
writeElements ( hw , hdr . magic , hdr . command , hdr . length )
2015-12-28 14:24:16 +03:00
2016-01-17 04:14:35 +03:00
// Write header
2015-12-28 14:24:16 +03:00
n , err := w . Write ( hw . Bytes ( ) )
totalBytes += n
if err != nil {
return totalBytes , err
}
2016-01-17 04:14:35 +03:00
// Write payload
2015-12-28 14:24:16 +03:00
n , err = w . Write ( payload )
totalBytes += n
if err != nil {
return totalBytes , err
}
return totalBytes , nil
}
func ReadMessage ( r io . Reader , pver uint32 , btcnet wire . BitcoinNet ) ( int , Message , [ ] byte , error ) {
totalBytes := 0
n , hdr , err := readMessageHeader ( r )
totalBytes += n
if err != nil {
return totalBytes , nil , nil , err
}
2016-01-17 04:14:35 +03:00
// Enforce maximum message payload
2015-12-28 14:24:16 +03:00
if hdr . length > MaxMessagePayload {
return totalBytes , nil , nil , fmt . Errorf ( "message payload is too large - header indicates %d bytes, but max message payload is %d bytes." , hdr . length , MaxMessagePayload )
}
2016-01-17 04:14:35 +03:00
// Check for messages in the wrong bitcoin network
2015-12-28 14:24:16 +03:00
if hdr . magic != btcnet {
discardInput ( r , hdr . length )
return totalBytes , nil , nil , fmt . Errorf ( "message from other network [%v]" , hdr . magic )
}
2016-01-17 04:14:35 +03:00
// Create struct of appropriate message type based on the command
2015-12-28 14:24:16 +03:00
command := hdr . command
msg , err := makeEmptyMessage ( command )
if err != nil {
discardInput ( r , hdr . length )
return totalBytes , nil , nil , fmt . Errorf ( "ReadMessage %s" , err . Error ( ) )
}
2016-01-17 04:14:35 +03:00
// Check for maximum length based on the message type
2015-12-28 14:24:16 +03:00
mpl := msg . MaxPayloadLength ( pver )
if hdr . length > mpl {
discardInput ( r , hdr . length )
return totalBytes , nil , nil , fmt . Errorf ( "payload exceeds max length. indicates %v bytes, but max of message type %v is %v." , hdr . length , command , mpl )
}
2016-01-17 04:14:35 +03:00
// Read payload
2015-12-28 14:24:16 +03:00
payload := make ( [ ] byte , hdr . length )
n , err = io . ReadFull ( r , payload )
totalBytes += n
if err != nil {
return totalBytes , nil , nil , err
}
2016-01-17 04:14:35 +03:00
// Unmarshal message
2015-12-28 14:24:16 +03:00
pr := bytes . NewBuffer ( payload )
err = msg . Decode ( pr , pver )
if err != nil {
return totalBytes , nil , nil , err
}
2016-01-17 04:14:35 +03:00
// Validate the data
2015-12-30 16:38:57 +03:00
err = msg . Validate ( )
2015-12-28 14:24:16 +03:00
if err != nil {
return totalBytes , nil , nil , err
}
2016-01-17 04:14:35 +03:00
// We're good!
2015-12-28 14:24:16 +03:00
return totalBytes , msg , payload , nil
}