diff --git a/lnwire/query_short_chan_ids.go b/lnwire/query_short_chan_ids.go index 1f4c1d35..f9ff58b5 100644 --- a/lnwire/query_short_chan_ids.go +++ b/lnwire/query_short_chan_ids.go @@ -2,9 +2,11 @@ package lnwire import ( "bytes" + "compress/zlib" "fmt" "io" "sort" + "sync" "github.com/roasbeef/btcd/chaincfg/chainhash" ) @@ -20,8 +22,17 @@ const ( // encoded using the regular encoding, in a sorted order. EncodingSortedPlain ShortChanIDEncoding = 0 - // TODO(roasbeef): list max number of short chan id's that are able to - // use + // EncodingSortedZlib signals that the set of short channel ID's is + // encoded by first sorting the set of channel ID's, as then + // compressing them using zlib. + EncodingSortedZlib ShortChanIDEncoding = 1 +) + +const ( + // maxZlibBufSize is the max number of bytes that we'll accept from a + // zlib decoding instance. We do this in order to limit the total + // amount of memory allocated during a decoding instance. + maxZlibBufSize = 67413630 ) // ErrUnknownShortChanIDEncoding is a parametrized error that indicates that we @@ -144,6 +155,50 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err return encodingType, shortChanIDs, nil + // In this encoding, we'll use zlib to decode the compressed payload. + // However, we'll pay attention to ensure that we don't open our selves + // up to a memory exhaustion attack. + case EncodingSortedZlib: + // Before we start to decode, we'll create a limit reader over + // the current reader. This will ensure that we can control how + // much memory we're allocating during the decoding process. + limitedDecompressor, err := zlib.NewReader(&io.LimitedReader{ + R: bytes.NewReader(queryBody), + N: maxZlibBufSize, + }) + if err != nil { + return 0, nil, fmt.Errorf("unable to create zlib reader: %v", err) + } + + var shortChanIDs []ShortChannelID + for { + // We'll now attempt to read the next short channel ID + // encoded in the payload. + var cid ShortChannelID + err := readElements(limitedDecompressor, &cid) + + switch { + // If we get an EOF error, then that either means we've + // read all that's contained in the buffer, or have hit + // our limit on the number of bytes we'll read. In + // either case, we'll return what we have so far. + case err == io.ErrUnexpectedEOF || err == io.EOF: + return encodingType, shortChanIDs, nil + + // Otherwise, we hit some other sort of error, possibly + // an invalid payload, so we'll exit early with the + // error. + case err != nil: + return 0, nil, fmt.Errorf("unable to "+ + "deflate next short chan "+ + "ID: %v", err) + } + + // We successfully read the next ID, so well collect + // that in the set of final ID's to return. + shortChanIDs = append(shortChanIDs, cid) + } + default: // If we've been sent an encoding type that we don't know of, // then we'll return a parsing error as we can't continue if @@ -173,6 +228,13 @@ func (q *QueryShortChanIDs) Encode(w io.Writer, pver uint32) error { func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding, shortChanIDs []ShortChannelID) error { + // For both of the current encoding types, the channel ID's are to be + // sorted in place, so we'll do that now. + sort.Slice(shortChanIDs, func(i, j int) bool { + return shortChanIDs[i].ToUint64() < + shortChanIDs[j].ToUint64() + }) + switch encodingType { // In this encoding, we'll simply write a sorted array of encoded short @@ -192,13 +254,6 @@ func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding, return err } - // Next, we'll ensure that the set of short channel ID's is - // properly sorted in place. - sort.Slice(shortChanIDs, func(i, j int) bool { - return shortChanIDs[i].ToUint64() < - shortChanIDs[j].ToUint64() - }) - // Now that we know they're sorted, we can write out each short // channel ID to the buffer. for _, chanID := range shortChanIDs { @@ -210,6 +265,54 @@ func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding, return nil + // For this encoding we'll first write out a serialized version of all + // the channel ID's into a buffer, then zlib encode that. The final + // payload is what we'll write out to the passed io.Writer. + // + // TODO(roasbeef): assumes the caller knows the proper chunk size to + // pass to avoid bin-packing here + case EncodingSortedZlib: + // We'll make a new buffer, then wrap that with a zlib writer + // so we can write directly to the buffer and encode in a + // streaming manner. + var buf bytes.Buffer + zlibWriter := zlib.NewWriter(&buf) + + // Next, we'll write out all the channel ID's directly into the + // zlib writer, which will do compressing on the fly. + for _, chanID := range shortChanIDs { + err := writeElements(zlibWriter, chanID) + if err != nil { + return fmt.Errorf("unable to write short chan "+ + "ID: %v", err) + } + } + + // Now that we've written all the elements, we'll ensure the + // compressed stream is written to the underlying buffer. + if err := zlibWriter.Close(); err != nil { + return fmt.Errorf("unable to finalize "+ + "compression: %v", err) + } + + // Now that we have all the items compressed, we can compute + // what the total payload size will be. We add one to account + // for the byte to encode the type. + compressedPayload := buf.Bytes() + numBytesBody := len(compressedPayload) + 1 + + // Finally, we can write out the number of bytes, the + // compression type, and finally the buffer itself. + if err := writeElements(w, uint16(numBytesBody)); err != nil { + return err + } + if err := writeElements(w, encodingType); err != nil { + return err + } + + _, err := w.Write(compressedPayload) + return err + default: // If we're trying to encode with an encoding type that we // don't know of, then we'll return a parsing error as we can't