You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
434 lines
14 KiB
434 lines
14 KiB
package lnwire |
|
|
|
import ( |
|
"bytes" |
|
"compress/zlib" |
|
"fmt" |
|
"io" |
|
"sort" |
|
"sync" |
|
|
|
"github.com/btcsuite/btcd/chaincfg/chainhash" |
|
) |
|
|
|
// ShortChanIDEncoding is an enum-like type that represents exactly how a set
// of short channel ID's is encoded on the wire. The set of encodings allows us
// to take advantage of the structure of a list of short channel ID's to
// achieve a high degree of compression.
//
// The value occupies a single byte: decodeShortChanIDs reads it from the
// first byte of the encoded body.
type ShortChanIDEncoding uint8
|
|
|
const (
	// EncodingSortedPlain signals that the set of short channel ID's is
	// encoded using the regular encoding, in a sorted order. Each ID takes
	// up 8 bytes of the body.
	EncodingSortedPlain ShortChanIDEncoding = 0

	// EncodingSortedZlib signals that the set of short channel ID's is
	// encoded by first sorting the set of channel ID's, and then
	// compressing them using zlib.
	EncodingSortedZlib ShortChanIDEncoding = 1
)
|
|
|
const (
	// maxZlibBufSize is the max number of bytes that we'll accept from a
	// zlib decoding instance. We do this in order to limit the total
	// amount of memory allocated during a decoding instance.
	//
	// It is used as the byte budget of the limited reader that
	// decodeShortChanIDs builds for the zlib encoding.
	maxZlibBufSize = 67413630
)
|
|
|
// ErrUnsortedSIDs is returned when decoding a QueryShortChannelID request whose |
|
// items were not sorted. |
|
type ErrUnsortedSIDs struct { |
|
prevSID ShortChannelID |
|
curSID ShortChannelID |
|
} |
|
|
|
// Error returns a human-readable description of the error. |
|
func (e ErrUnsortedSIDs) Error() string { |
|
return fmt.Sprintf("current sid: %v isn't greater than last sid: %v", |
|
e.curSID, e.prevSID) |
|
} |
|
|
|
// zlibDecodeMtx is a package level mutex that we'll use in order to ensure
// that we'll only attempt a single zlib decoding instance at a time. This
// allows us to also further bound our memory usage.
//
// decodeShortChanIDs acquires it when it encounters the zlib encoding type
// and holds it until that call returns.
var zlibDecodeMtx sync.Mutex
|
|
|
// ErrUnknownShortChanIDEncoding is a parametrized error that indicates that we |
|
// came across an unknown short channel ID encoding, and therefore were unable |
|
// to continue parsing. |
|
func ErrUnknownShortChanIDEncoding(encoding ShortChanIDEncoding) error { |
|
return fmt.Errorf("unknown short chan id encoding: %v", encoding) |
|
} |
|
|
|
// QueryShortChanIDs is a message that allows the sender to query a set of
// channel announcement and channel update messages that correspond to the set
// of encoded short channel ID's. The encoding of the short channel ID's is
// detailed in the query message ensuring that the receiver knows how to
// properly decode each encoded short channel ID which may be encoded using a
// compression format. The receiver should respond with a series of channel
// announcement and channel updates, finally sending a ReplyShortChanIDsEnd
// message.
type QueryShortChanIDs struct {
	// ChainHash denotes the target chain that we're querying for the
	// channel ID's of.
	ChainHash chainhash.Hash

	// EncodingType is a signal to the receiver of the message that
	// indicates exactly how the set of short channel ID's that follow have
	// been encoded.
	EncodingType ShortChanIDEncoding

	// ShortChanIDs is a slice of decoded short channel ID's.
	//
	// NOTE: Unless noSort is set, Encode sorts this slice in place before
	// writing it out.
	ShortChanIDs []ShortChannelID

	// ExtraData is the set of data that was appended to this message to
	// fill out the full maximum transport message size. These fields can
	// be used to specify optional data such as custom TLV fields.
	ExtraData ExtraOpaqueData

	// noSort indicates whether or not to sort the short channel ids before
	// writing them out. Encode forwards it to encodeShortChanIDs, which
	// skips the sort when it is true.
	//
	// NOTE: This should only be used during testing.
	noSort bool
}
|
|
|
// NewQueryShortChanIDs creates a new QueryShortChanIDs message. |
|
func NewQueryShortChanIDs(h chainhash.Hash, e ShortChanIDEncoding, |
|
s []ShortChannelID) *QueryShortChanIDs { |
|
|
|
return &QueryShortChanIDs{ |
|
ChainHash: h, |
|
EncodingType: e, |
|
ShortChanIDs: s, |
|
} |
|
} |
|
|
|
// A compile time check to ensure QueryShortChanIDs implements the
// lnwire.Message interface. The value is assigned to the blank identifier, so
// it exists purely for the type check.
var _ Message = (*QueryShortChanIDs)(nil)
|
|
|
// Decode deserializes a serialized QueryShortChanIDs message stored in the |
|
// passed io.Reader observing the specified protocol version. |
|
// |
|
// This is part of the lnwire.Message interface. |
|
func (q *QueryShortChanIDs) Decode(r io.Reader, pver uint32) error { |
|
err := ReadElements(r, q.ChainHash[:]) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
q.EncodingType, q.ShortChanIDs, err = decodeShortChanIDs(r) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
return q.ExtraData.Decode(r) |
|
} |
|
|
|
// decodeShortChanIDs decodes a set of short channel ID's that have been |
|
// encoded. The first byte of the body details how the short chan ID's were |
|
// encoded. We'll use this type to govern exactly how we go about encoding the |
|
// set of short channel ID's. |
|
func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, error) { |
|
// First, we'll attempt to read the number of bytes in the body of the |
|
// set of encoded short channel ID's. |
|
var numBytesResp uint16 |
|
err := ReadElements(r, &numBytesResp) |
|
if err != nil { |
|
return 0, nil, err |
|
} |
|
|
|
if numBytesResp == 0 { |
|
return 0, nil, nil |
|
} |
|
|
|
queryBody := make([]byte, numBytesResp) |
|
if _, err := io.ReadFull(r, queryBody); err != nil { |
|
return 0, nil, err |
|
} |
|
|
|
// The first byte is the encoding type, so we'll extract that so we can |
|
// continue our parsing. |
|
encodingType := ShortChanIDEncoding(queryBody[0]) |
|
|
|
// Before continuing, we'll snip off the first byte of the query body |
|
// as that was just the encoding type. |
|
queryBody = queryBody[1:] |
|
|
|
// Otherwise, depending on the encoding type, we'll decode the encode |
|
// short channel ID's in a different manner. |
|
switch encodingType { |
|
|
|
// In this encoding, we'll simply read a sort array of encoded short |
|
// channel ID's from the buffer. |
|
case EncodingSortedPlain: |
|
// If after extracting the encoding type, the number of |
|
// remaining bytes is not a whole multiple of the size of an |
|
// encoded short channel ID (8 bytes), then we'll return a |
|
// parsing error. |
|
if len(queryBody)%8 != 0 { |
|
return 0, nil, fmt.Errorf("whole number of short "+ |
|
"chan ID's cannot be encoded in len=%v", |
|
len(queryBody)) |
|
} |
|
|
|
// As each short channel ID is encoded as 8 bytes, we can |
|
// compute the number of bytes encoded based on the size of the |
|
// query body. |
|
numShortChanIDs := len(queryBody) / 8 |
|
if numShortChanIDs == 0 { |
|
return encodingType, nil, nil |
|
} |
|
|
|
// Finally, we'll read out the exact number of short channel |
|
// ID's to conclude our parsing. |
|
shortChanIDs := make([]ShortChannelID, numShortChanIDs) |
|
bodyReader := bytes.NewReader(queryBody) |
|
var lastChanID ShortChannelID |
|
for i := 0; i < numShortChanIDs; i++ { |
|
if err := ReadElements(bodyReader, &shortChanIDs[i]); err != nil { |
|
return 0, nil, fmt.Errorf("unable to parse "+ |
|
"short chan ID: %v", err) |
|
} |
|
|
|
// We'll ensure that this short chan ID is greater than |
|
// the last one. This is a requirement within the |
|
// encoding, and if violated can aide us in detecting |
|
// malicious payloads. This can only be true starting |
|
// at the second chanID. |
|
cid := shortChanIDs[i] |
|
if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() { |
|
return 0, nil, ErrUnsortedSIDs{lastChanID, cid} |
|
} |
|
lastChanID = cid |
|
} |
|
|
|
return encodingType, shortChanIDs, nil |
|
|
|
// In this encoding, we'll use zlib to decode the compressed payload. |
|
// However, we'll pay attention to ensure that we don't open our selves |
|
// up to a memory exhaustion attack. |
|
case EncodingSortedZlib: |
|
// We'll obtain an ultimately release the zlib decode mutex. |
|
// This guards us against allocating too much memory to decode |
|
// each instance from concurrent peers. |
|
zlibDecodeMtx.Lock() |
|
defer zlibDecodeMtx.Unlock() |
|
|
|
// At this point, if there's no body remaining, then only the encoding |
|
// type was specified, meaning that there're no further bytes to be |
|
// parsed. |
|
if len(queryBody) == 0 { |
|
return encodingType, nil, nil |
|
} |
|
|
|
// Before we start to decode, we'll create a limit reader over |
|
// the current reader. This will ensure that we can control how |
|
// much memory we're allocating during the decoding process. |
|
limitedDecompressor, err := zlib.NewReader(&io.LimitedReader{ |
|
R: bytes.NewReader(queryBody), |
|
N: maxZlibBufSize, |
|
}) |
|
if err != nil { |
|
return 0, nil, fmt.Errorf("unable to create zlib reader: %v", err) |
|
} |
|
|
|
var ( |
|
shortChanIDs []ShortChannelID |
|
lastChanID ShortChannelID |
|
i int |
|
) |
|
for { |
|
// We'll now attempt to read the next short channel ID |
|
// encoded in the payload. |
|
var cid ShortChannelID |
|
err := ReadElements(limitedDecompressor, &cid) |
|
|
|
switch { |
|
// If we get an EOF error, then that either means we've |
|
// read all that's contained in the buffer, or have hit |
|
// our limit on the number of bytes we'll read. In |
|
// either case, we'll return what we have so far. |
|
case err == io.ErrUnexpectedEOF || err == io.EOF: |
|
return encodingType, shortChanIDs, nil |
|
|
|
// Otherwise, we hit some other sort of error, possibly |
|
// an invalid payload, so we'll exit early with the |
|
// error. |
|
case err != nil: |
|
return 0, nil, fmt.Errorf("unable to "+ |
|
"deflate next short chan "+ |
|
"ID: %v", err) |
|
} |
|
|
|
// We successfully read the next ID, so we'll collect |
|
// that in the set of final ID's to return. |
|
shortChanIDs = append(shortChanIDs, cid) |
|
|
|
// Finally, we'll ensure that this short chan ID is |
|
// greater than the last one. This is a requirement |
|
// within the encoding, and if violated can aide us in |
|
// detecting malicious payloads. This can only be true |
|
// starting at the second chanID. |
|
if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() { |
|
return 0, nil, ErrUnsortedSIDs{lastChanID, cid} |
|
} |
|
|
|
lastChanID = cid |
|
i++ |
|
} |
|
|
|
default: |
|
// If we've been sent an encoding type that we don't know of, |
|
// then we'll return a parsing error as we can't continue if |
|
// we're unable to encode them. |
|
return 0, nil, ErrUnknownShortChanIDEncoding(encodingType) |
|
} |
|
} |
|
|
|
// Encode serializes the target QueryShortChanIDs into the passed io.Writer |
|
// observing the protocol version specified. |
|
// |
|
// This is part of the lnwire.Message interface. |
|
func (q *QueryShortChanIDs) Encode(w io.Writer, pver uint32) error { |
|
// First, we'll write out the chain hash. |
|
err := WriteElements(w, q.ChainHash[:]) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
// Base on our encoding type, we'll write out the set of short channel |
|
// ID's. |
|
err = encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs, q.noSort) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
return q.ExtraData.Encode(w) |
|
} |
|
|
|
// encodeShortChanIDs encodes the passed short channel ID's into the passed |
|
// io.Writer, respecting the specified encoding type. |
|
func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding, |
|
shortChanIDs []ShortChannelID, noSort bool) error { |
|
|
|
// For both of the current encoding types, the channel ID's are to be |
|
// sorted in place, so we'll do that now. The sorting is applied unless |
|
// we were specifically requested not to for testing purposes. |
|
if !noSort { |
|
sort.Slice(shortChanIDs, func(i, j int) bool { |
|
return shortChanIDs[i].ToUint64() < |
|
shortChanIDs[j].ToUint64() |
|
}) |
|
} |
|
|
|
switch encodingType { |
|
|
|
// In this encoding, we'll simply write a sorted array of encoded short |
|
// channel ID's from the buffer. |
|
case EncodingSortedPlain: |
|
// First, we'll write out the number of bytes of the query |
|
// body. We add 1 as the response will have the encoding type |
|
// prepended to it. |
|
numBytesBody := uint16(len(shortChanIDs)*8) + 1 |
|
if err := WriteElements(w, numBytesBody); err != nil { |
|
return err |
|
} |
|
|
|
// We'll then write out the encoding that that follows the |
|
// actual encoded short channel ID's. |
|
if err := WriteElements(w, encodingType); err != nil { |
|
return err |
|
} |
|
|
|
// Now that we know they're sorted, we can write out each short |
|
// channel ID to the buffer. |
|
for _, chanID := range shortChanIDs { |
|
if err := WriteElements(w, chanID); err != nil { |
|
return fmt.Errorf("unable to write short chan "+ |
|
"ID: %v", err) |
|
} |
|
} |
|
|
|
return nil |
|
|
|
// For this encoding we'll first write out a serialized version of all |
|
// the channel ID's into a buffer, then zlib encode that. The final |
|
// payload is what we'll write out to the passed io.Writer. |
|
// |
|
// TODO(roasbeef): assumes the caller knows the proper chunk size to |
|
// pass to avoid bin-packing here |
|
case EncodingSortedZlib: |
|
// We'll make a new buffer, then wrap that with a zlib writer |
|
// so we can write directly to the buffer and encode in a |
|
// streaming manner. |
|
var buf bytes.Buffer |
|
zlibWriter := zlib.NewWriter(&buf) |
|
|
|
// If we don't have anything at all to write, then we'll write |
|
// an empty payload so we don't include things like the zlib |
|
// header when the remote party is expecting no actual short |
|
// channel IDs. |
|
var compressedPayload []byte |
|
if len(shortChanIDs) > 0 { |
|
// Next, we'll write out all the channel ID's directly |
|
// into the zlib writer, which will do compressing on |
|
// the fly. |
|
for _, chanID := range shortChanIDs { |
|
err := WriteElements(zlibWriter, chanID) |
|
if err != nil { |
|
return fmt.Errorf("unable to write short chan "+ |
|
"ID: %v", err) |
|
} |
|
} |
|
|
|
// Now that we've written all the elements, we'll |
|
// ensure the compressed stream is written to the |
|
// underlying buffer. |
|
if err := zlibWriter.Close(); err != nil { |
|
return fmt.Errorf("unable to finalize "+ |
|
"compression: %v", err) |
|
} |
|
|
|
compressedPayload = buf.Bytes() |
|
} |
|
|
|
// Now that we have all the items compressed, we can compute |
|
// what the total payload size will be. We add one to account |
|
// for the byte to encode the type. |
|
// |
|
// If we don't have any actual bytes to write, then we'll end |
|
// up emitting one byte for the length, followed by the |
|
// encoding type, and nothing more. The spec isn't 100% clear |
|
// in this area, but we do this as this is what most of the |
|
// other implementations do. |
|
numBytesBody := len(compressedPayload) + 1 |
|
|
|
// Finally, we can write out the number of bytes, the |
|
// compression type, and finally the buffer itself. |
|
if err := WriteElements(w, uint16(numBytesBody)); err != nil { |
|
return err |
|
} |
|
if err := WriteElements(w, encodingType); err != nil { |
|
return err |
|
} |
|
|
|
_, err := w.Write(compressedPayload) |
|
return err |
|
|
|
default: |
|
// If we're trying to encode with an encoding type that we |
|
// don't know of, then we'll return a parsing error as we can't |
|
// continue if we're unable to encode them. |
|
return ErrUnknownShortChanIDEncoding(encodingType) |
|
} |
|
} |
|
|
|
// MsgType returns the integer uniquely identifying this message type on the |
|
// wire. |
|
// |
|
// This is part of the lnwire.Message interface. |
|
func (q *QueryShortChanIDs) MsgType() MessageType { |
|
return MsgQueryShortChanIDs |
|
}
|
|
|