61c2493b7d
In this commit, we alter the behavior of decoding the regular short channel id encoding, such that it returns a nil slice if the decoded number of elements is 0. This is done so that it matches the behavior of the zlib decompression, allowing us to test both using the same corpus.
368 lines
12 KiB
Go
368 lines
12 KiB
Go
package lnwire
|
|
|
|
import (
|
|
"bytes"
|
|
"compress/zlib"
|
|
"fmt"
|
|
"io"
|
|
"sort"
|
|
"sync"
|
|
|
|
"github.com/roasbeef/btcd/chaincfg/chainhash"
|
|
)
|
|
|
|
// ShortChanIDEncoding is an enum-like type that describes how a set of short
// channel ID's is encoded on the wire. Having more than one encoding lets us
// exploit the structure of a list of short channel ID's in order to achieve a
// high degree of compression.
type ShortChanIDEncoding uint8

const (
	// EncodingSortedPlain signals that the set of short channel ID's is
	// sorted, and then written out using the regular encoding.
	EncodingSortedPlain = ShortChanIDEncoding(0)

	// EncodingSortedZlib signals that the set of short channel ID's is
	// first sorted, and then compressed using zlib.
	EncodingSortedZlib = ShortChanIDEncoding(1)
)
|
|
|
|
// maxZlibBufSize is the max number of bytes that we'll accept from a zlib
// decoding instance. Capping this lets us limit the total amount of memory
// allocated during a single decoding instance.
const maxZlibBufSize = 67413630

// zlibDecodeMtx is a package level mutex used to ensure that only a single
// zlib decoding instance is attempted at a time, which further bounds our
// memory usage.
var zlibDecodeMtx sync.Mutex
|
|
|
|
// ErrUnknownShortChanIDEncoding is a parametrized error that indicates that we
|
|
// came across an unknown short channel ID encoding, and therefore were unable
|
|
// to continue parsing.
|
|
func ErrUnknownShortChanIDEncoding(encoding ShortChanIDEncoding) error {
|
|
return fmt.Errorf("unknown short chan id encoding: %v", encoding)
|
|
}
|
|
|
|
// QueryShortChanIDs is a message that allows the sender to query a set of
// channel announcement and channel update messages that correspond to the set
// of encoded short channel ID's. The encoding of the short channel ID's is
// detailed in the query message, ensuring that the receiver knows how to
// properly decode each encoded short channel ID, which may be encoded using a
// compression format. The receiver should respond with a series of channel
// announcement and channel updates, finally sending a ReplyShortChanIDsEnd
// message.
type QueryShortChanIDs struct {
	// ChainHash denotes the target chain that we're querying for the
	// channel ID's of.
	ChainHash chainhash.Hash

	// EncodingType is a signal to the receiver of the message that
	// indicates exactly how the set of short channel ID's that follow have
	// been encoded.
	EncodingType ShortChanIDEncoding

	// ShortChanIDs is a slice of decoded short channel ID's.
	ShortChanIDs []ShortChannelID
}
|
|
|
|
// NewQueryShortChanIDs creates a new QueryShortChanIDs message.
|
|
func NewQueryShortChanIDs(h chainhash.Hash, e ShortChanIDEncoding,
|
|
s []ShortChannelID) *QueryShortChanIDs {
|
|
|
|
return &QueryShortChanIDs{
|
|
ChainHash: h,
|
|
EncodingType: e,
|
|
ShortChanIDs: s,
|
|
}
|
|
}
|
|
|
|
// A compile time check to ensure that *QueryShortChanIDs implements the
// lnwire.Message interface. The assignment is to the blank identifier, so no
// value is retained.
var _ Message = (*QueryShortChanIDs)(nil)
|
|
|
|
// Decode deserializes a serialized QueryShortChanIDs message stored in the
|
|
// passed io.Reader observing the specified protocol version.
|
|
//
|
|
// This is part of the lnwire.Message interface.
|
|
func (q *QueryShortChanIDs) Decode(r io.Reader, pver uint32) error {
|
|
err := readElements(r, q.ChainHash[:])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
q.EncodingType, q.ShortChanIDs, err = decodeShortChanIDs(r)
|
|
|
|
return err
|
|
}
|
|
|
|
// decodeShortChanIDs decodes a set of short channel ID's that have been
|
|
// encoded. The first byte of the body details how the short chan ID's were
|
|
// encoded. We'll use this type to govern exactly how we go about encoding the
|
|
// set of short channel ID's.
|
|
func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, error) {
|
|
// First, we'll attempt to read the number of bytes in the body of the
|
|
// set of encoded short channel ID's.
|
|
var numBytesResp uint16
|
|
err := readElements(r, &numBytesResp)
|
|
if err != nil {
|
|
return 0, nil, err
|
|
}
|
|
|
|
queryBody := make([]byte, numBytesResp)
|
|
if _, err := io.ReadFull(r, queryBody); err != nil {
|
|
return 0, nil, err
|
|
}
|
|
|
|
// The first byte is the encoding type, so we'll extract that so we can
|
|
// continue our parsing.
|
|
encodingType := ShortChanIDEncoding(queryBody[0])
|
|
|
|
// Before continuing, we'll snip off the first byte of the query body
|
|
// as that was just the encoding type.
|
|
queryBody = queryBody[1:]
|
|
|
|
// Otherwise, depending on the encoding type, we'll decode the encode
|
|
// short channel ID's in a different manner.
|
|
switch encodingType {
|
|
|
|
// In this encoding, we'll simply read a sort array of encoded short
|
|
// channel ID's from the buffer.
|
|
case EncodingSortedPlain:
|
|
// If after extracting the encoding type, then number of
|
|
// remaining bytes instead a whole multiple of the size of an
|
|
// encoded short channel ID (8 bytes), then we'll return a
|
|
// parsing error.
|
|
if len(queryBody)%8 != 0 {
|
|
return 0, nil, fmt.Errorf("whole number of short "+
|
|
"chan ID's cannot be encoded in len=%v",
|
|
len(queryBody))
|
|
}
|
|
|
|
// As each short channel ID is encoded as 8 bytes, we can
|
|
// compute the number of bytes encoded based on the size of the
|
|
// query body.
|
|
numShortChanIDs := len(queryBody) / 8
|
|
if numShortChanIDs == 0 {
|
|
return encodingType, nil, nil
|
|
}
|
|
|
|
// Finally, we'll read out the exact number of short channel
|
|
// ID's to conclude our parsing.
|
|
shortChanIDs := make([]ShortChannelID, numShortChanIDs)
|
|
bodyReader := bytes.NewReader(queryBody)
|
|
for i := 0; i < numShortChanIDs; i++ {
|
|
if err := readElements(bodyReader, &shortChanIDs[i]); err != nil {
|
|
return 0, nil, fmt.Errorf("unable to parse "+
|
|
"short chan ID: %v", err)
|
|
}
|
|
}
|
|
|
|
return encodingType, shortChanIDs, nil
|
|
|
|
// In this encoding, we'll use zlib to decode the compressed payload.
|
|
// However, we'll pay attention to ensure that we don't open our selves
|
|
// up to a memory exhaustion attack.
|
|
case EncodingSortedZlib:
|
|
// We'll obtain an ultimately release the zlib decode mutex.
|
|
// This guards us against allocating too much memory to decode
|
|
// each instance from concurrent peers.
|
|
zlibDecodeMtx.Lock()
|
|
defer zlibDecodeMtx.Unlock()
|
|
|
|
// Before we start to decode, we'll create a limit reader over
|
|
// the current reader. This will ensure that we can control how
|
|
// much memory we're allocating during the decoding process.
|
|
limitedDecompressor, err := zlib.NewReader(&io.LimitedReader{
|
|
R: bytes.NewReader(queryBody),
|
|
N: maxZlibBufSize,
|
|
})
|
|
if err != nil {
|
|
return 0, nil, fmt.Errorf("unable to create zlib reader: %v", err)
|
|
}
|
|
|
|
var (
|
|
shortChanIDs []ShortChannelID
|
|
lastChanID ShortChannelID
|
|
)
|
|
for {
|
|
// We'll now attempt to read the next short channel ID
|
|
// encoded in the payload.
|
|
var cid ShortChannelID
|
|
err := readElements(limitedDecompressor, &cid)
|
|
|
|
switch {
|
|
// If we get an EOF error, then that either means we've
|
|
// read all that's contained in the buffer, or have hit
|
|
// our limit on the number of bytes we'll read. In
|
|
// either case, we'll return what we have so far.
|
|
case err == io.ErrUnexpectedEOF || err == io.EOF:
|
|
return encodingType, shortChanIDs, nil
|
|
|
|
// Otherwise, we hit some other sort of error, possibly
|
|
// an invalid payload, so we'll exit early with the
|
|
// error.
|
|
case err != nil:
|
|
return 0, nil, fmt.Errorf("unable to "+
|
|
"deflate next short chan "+
|
|
"ID: %v", err)
|
|
}
|
|
|
|
// We successfully read the next ID, so well collect
|
|
// that in the set of final ID's to return.
|
|
shortChanIDs = append(shortChanIDs, cid)
|
|
|
|
// Finally, we'll ensure that this short chan ID is
|
|
// greater than the last one. This is a requirement
|
|
// within the encoding, and if violated can aide us in
|
|
// detecting malicious payloads.
|
|
if cid.ToUint64() <= lastChanID.ToUint64() {
|
|
return 0, nil, fmt.Errorf("current sid of %v "+
|
|
"isn't greater than last sid of %v", cid,
|
|
lastChanID)
|
|
}
|
|
|
|
lastChanID = cid
|
|
}
|
|
|
|
default:
|
|
// If we've been sent an encoding type that we don't know of,
|
|
// then we'll return a parsing error as we can't continue if
|
|
// we're unable to encode them.
|
|
return 0, nil, ErrUnknownShortChanIDEncoding(encodingType)
|
|
}
|
|
}
|
|
|
|
// Encode serializes the target QueryShortChanIDs into the passed io.Writer
|
|
// observing the protocol version specified.
|
|
//
|
|
// This is part of the lnwire.Message interface.
|
|
func (q *QueryShortChanIDs) Encode(w io.Writer, pver uint32) error {
|
|
// First, we'll write out the chain hash.
|
|
err := writeElements(w, q.ChainHash[:])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Base on our encoding type, we'll write out the set of short channel
|
|
// ID's.
|
|
return encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs)
|
|
}
|
|
|
|
// encodeShortChanIDs encodes the passed short channel ID's into the passed
|
|
// io.Writer, respecting the specified encoding type.
|
|
func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding,
|
|
shortChanIDs []ShortChannelID) error {
|
|
|
|
// For both of the current encoding types, the channel ID's are to be
|
|
// sorted in place, so we'll do that now.
|
|
sort.Slice(shortChanIDs, func(i, j int) bool {
|
|
return shortChanIDs[i].ToUint64() <
|
|
shortChanIDs[j].ToUint64()
|
|
})
|
|
|
|
switch encodingType {
|
|
|
|
// In this encoding, we'll simply write a sorted array of encoded short
|
|
// channel ID's from the buffer.
|
|
case EncodingSortedPlain:
|
|
// First, we'll write out the number of bytes of the query
|
|
// body. We add 1 as the response will have the encoding type
|
|
// prepended to it.
|
|
numBytesBody := uint16(len(shortChanIDs)*8) + 1
|
|
if err := writeElements(w, numBytesBody); err != nil {
|
|
return err
|
|
}
|
|
|
|
// We'll then write out the encoding that that follows the
|
|
// actual encoded short channel ID's.
|
|
if err := writeElements(w, encodingType); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Now that we know they're sorted, we can write out each short
|
|
// channel ID to the buffer.
|
|
for _, chanID := range shortChanIDs {
|
|
if err := writeElements(w, chanID); err != nil {
|
|
return fmt.Errorf("unable to write short chan "+
|
|
"ID: %v", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
|
|
// For this encoding we'll first write out a serialized version of all
|
|
// the channel ID's into a buffer, then zlib encode that. The final
|
|
// payload is what we'll write out to the passed io.Writer.
|
|
//
|
|
// TODO(roasbeef): assumes the caller knows the proper chunk size to
|
|
// pass to avoid bin-packing here
|
|
case EncodingSortedZlib:
|
|
// We'll make a new buffer, then wrap that with a zlib writer
|
|
// so we can write directly to the buffer and encode in a
|
|
// streaming manner.
|
|
var buf bytes.Buffer
|
|
zlibWriter := zlib.NewWriter(&buf)
|
|
|
|
// Next, we'll write out all the channel ID's directly into the
|
|
// zlib writer, which will do compressing on the fly.
|
|
for _, chanID := range shortChanIDs {
|
|
err := writeElements(zlibWriter, chanID)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to write short chan "+
|
|
"ID: %v", err)
|
|
}
|
|
}
|
|
|
|
// Now that we've written all the elements, we'll ensure the
|
|
// compressed stream is written to the underlying buffer.
|
|
if err := zlibWriter.Close(); err != nil {
|
|
return fmt.Errorf("unable to finalize "+
|
|
"compression: %v", err)
|
|
}
|
|
|
|
// Now that we have all the items compressed, we can compute
|
|
// what the total payload size will be. We add one to account
|
|
// for the byte to encode the type.
|
|
compressedPayload := buf.Bytes()
|
|
numBytesBody := len(compressedPayload) + 1
|
|
|
|
// Finally, we can write out the number of bytes, the
|
|
// compression type, and finally the buffer itself.
|
|
if err := writeElements(w, uint16(numBytesBody)); err != nil {
|
|
return err
|
|
}
|
|
if err := writeElements(w, encodingType); err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err := w.Write(compressedPayload)
|
|
return err
|
|
|
|
default:
|
|
// If we're trying to encode with an encoding type that we
|
|
// don't know of, then we'll return a parsing error as we can't
|
|
// continue if we're unable to encode them.
|
|
return ErrUnknownShortChanIDEncoding(encodingType)
|
|
}
|
|
}
|
|
|
|
// MsgType returns the integer uniquely identifying this message type on the
// wire, which for this message is always MsgQueryShortChanIDs.
//
// This is part of the lnwire.Message interface.
func (q *QueryShortChanIDs) MsgType() MessageType {
	return MsgQueryShortChanIDs
}
|
|
|
|
// MaxPayloadLength returns the maximum allowed payload size for a complete
// QueryShortChanIDs message. The protocol version argument is ignored, and
// MaxMessagePayload is always returned.
//
// This is part of the lnwire.Message interface.
func (q *QueryShortChanIDs) MaxPayloadLength(uint32) uint32 {
	return MaxMessagePayload
}
|