lnwire: implement zlib encode/decode for channel range queries

In this commit, we implement zlib encoding and decoding for the channel
range queries. Notably, we utilize an io.LimitedReader to ensure that we
can enforce a hard cap on the total number of bytes we'll ever allocate
in a decoding attempt.
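As a minimal, self-contained sketch of that guard pattern (the cap value is taken from the diff below; the demo payload and main function are illustrative only):

    package main

    import (
        "bytes"
        "compress/zlib"
        "fmt"
        "io"
    )

    // maxZlibBufSize mirrors the cap introduced in this commit.
    const maxZlibBufSize = 67413630

    func main() {
        // Build a small zlib-compressed payload to decode.
        var buf bytes.Buffer
        zw := zlib.NewWriter(&buf)
        zw.Write([]byte("hello"))
        zw.Close()

        // The guard: zlib.NewReader pulls its compressed input through
        // an io.LimitedReader, so at most maxZlibBufSize bytes are ever
        // read from the source during a decoding attempt.
        zr, err := zlib.NewReader(&io.LimitedReader{
            R: bytes.NewReader(buf.Bytes()),
            N: maxZlibBufSize,
        })
        if err != nil {
            panic(err)
        }
        defer zr.Close()

        decoded, err := io.ReadAll(zr)
        if err != nil {
            panic(err)
        }
        fmt.Printf("%s\n", decoded)
    }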
Olaoluwa Osuntokun 2018-06-15 18:31:23 -07:00
parent c1c4b84757
commit a0e2f8dbd1

@@ -2,9 +2,11 @@ package lnwire
import (
"bytes"
"compress/zlib"
"fmt"
"io"
"sort"
"sync"
"github.com/roasbeef/btcd/chaincfg/chainhash"
)
@@ -20,8 +22,17 @@ const (
// encoded using the regular encoding, in a sorted order.
EncodingSortedPlain ShortChanIDEncoding = 0
// TODO(roasbeef): list max number of short chan id's that are able to
// use
// EncodingSortedZlib signals that the set of short channel ID's is
// encoded by first sorting the set of channel ID's, and then
// compressing them using zlib.
EncodingSortedZlib ShortChanIDEncoding = 1
)
const (
// maxZlibBufSize is the max number of bytes that we'll accept from a
// zlib decoding instance. We do this in order to limit the total
// amount of memory allocated during a decoding attempt.
maxZlibBufSize = 67413630
)
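As an aside, when the cap trips, reads through the limited reader simply begin returning EOF, so the decompressor surfaces a truncated-stream error rather than allocating without bound. A hypothetical demonstration with a deliberately tiny cap in place of maxZlibBufSize:

    package main

    import (
        "bytes"
        "compress/zlib"
        "fmt"
        "io"
    )

    func main() {
        // Compress far more data than 10 compressed bytes can hold.
        var buf bytes.Buffer
        zw := zlib.NewWriter(&buf)
        zw.Write(bytes.Repeat([]byte("abcdefgh"), 1024))
        zw.Close()

        // A tiny cap so the limit trips almost immediately.
        zr, err := zlib.NewReader(&io.LimitedReader{
            R: bytes.NewReader(buf.Bytes()),
            N: 10,
        })
        if err != nil {
            fmt.Println("header error:", err)
            return
        }
        defer zr.Close()

        _, err = io.ReadAll(zr)
        // Prints an EOF-style error: the stream looks truncated.
        fmt.Println("read error:", err)
    }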
// ErrUnknownShortChanIDEncoding is a parametrized error that indicates that we
@@ -144,6 +155,50 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err
return encodingType, shortChanIDs, nil
// In this encoding, we'll use zlib to decode the compressed payload.
// However, we'll pay attention to ensure that we don't open ourselves
// up to a memory exhaustion attack.
case EncodingSortedZlib:
// Before we start to decode, we'll create a limit reader over
// the current reader. This will ensure that we can control how
// much memory we're allocating during the decoding process.
limitedDecompressor, err := zlib.NewReader(&io.LimitedReader{
R: bytes.NewReader(queryBody),
N: maxZlibBufSize,
})
if err != nil {
return 0, nil, fmt.Errorf("unable to create zlib reader: %v", err)
}
var shortChanIDs []ShortChannelID
for {
// We'll now attempt to read the next short channel ID
// encoded in the payload.
var cid ShortChannelID
err := readElements(limitedDecompressor, &cid)
switch {
// If we get an EOF error, then that either means we've
// read all that's contained in the buffer, or have hit
// our limit on the number of bytes we'll read. In
// either case, we'll return what we have so far.
case err == io.ErrUnexpectedEOF || err == io.EOF:
return encodingType, shortChanIDs, nil
// Otherwise, we hit some other sort of error, possibly
// an invalid payload, so we'll exit early with the
// error.
case err != nil:
return 0, nil, fmt.Errorf("unable to "+
"inflate next short chan "+
"ID: %v", err)
}
// We successfully read the next ID, so we'll collect
// that in the set of final ID's to return.
shortChanIDs = append(shortChanIDs, cid)
}
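As a standalone illustration of the loop's termination logic, here is a hypothetical helper that drains fixed-width 8-byte IDs from a reader until EOF, using plain uint64 values in place of ShortChannelID (the helper name and representation are illustrative, not part of the package):

    import (
        "encoding/binary"
        "io"
    )

    // readAllIDs mirrors the loop above: read until the stream is
    // exhausted or the byte limit is hit, returning what was
    // accumulated so far.
    func readAllIDs(r io.Reader) ([]uint64, error) {
        var ids []uint64
        for {
            var raw [8]byte
            _, err := io.ReadFull(r, raw[:])
            switch {
            // io.ReadFull yields io.EOF on a clean end of stream and
            // io.ErrUnexpectedEOF on a partial record; treat both as
            // end-of-input.
            case err == io.EOF || err == io.ErrUnexpectedEOF:
                return ids, nil
            case err != nil:
                return nil, err
            }
            ids = append(ids, binary.BigEndian.Uint64(raw[:]))
        }
    }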
default:
// If we've been sent an encoding type that we don't know of,
// then we'll return a parsing error as we can't continue if
@@ -173,6 +228,13 @@ func (q *QueryShortChanIDs) Encode(w io.Writer, pver uint32) error {
func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding,
shortChanIDs []ShortChannelID) error {
// For both of the current encoding types, the channel ID's are to be
// sorted in place, so we'll do that now.
sort.Slice(shortChanIDs, func(i, j int) bool {
return shortChanIDs[i].ToUint64() <
shortChanIDs[j].ToUint64()
})
switch encodingType {
// In this encoding, we'll simply write a sorted array of encoded short
@@ -192,13 +254,6 @@ func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding,
return err
}
// Next, we'll ensure that the set of short channel ID's is
// properly sorted in place.
sort.Slice(shortChanIDs, func(i, j int) bool {
return shortChanIDs[i].ToUint64() <
shortChanIDs[j].ToUint64()
})
// Now that we know they're sorted, we can write out each short
// channel ID to the buffer.
for _, chanID := range shortChanIDs {
@@ -210,6 +265,54 @@ func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding,
return nil
// For this encoding we'll first write out a serialized version of all
// the channel ID's into a buffer, then zlib encode that. The final
// payload is what we'll write out to the passed io.Writer.
//
// TODO(roasbeef): assumes the caller knows the proper chunk size to
// pass to avoid bin-packing here
case EncodingSortedZlib:
// We'll make a new buffer, then wrap that with a zlib writer
// so we can write directly to the buffer and encode in a
// streaming manner.
var buf bytes.Buffer
zlibWriter := zlib.NewWriter(&buf)
// Next, we'll write out all the channel ID's directly into the
// zlib writer, which will compress on the fly.
for _, chanID := range shortChanIDs {
err := writeElements(zlibWriter, chanID)
if err != nil {
return fmt.Errorf("unable to write short chan "+
"ID: %v", err)
}
}
// Now that we've written all the elements, we'll ensure the
// compressed stream is written to the underlying buffer.
if err := zlibWriter.Close(); err != nil {
return fmt.Errorf("unable to finalize "+
"compression: %v", err)
}
// Now that we have all the items compressed, we can compute
// what the total payload size will be. We add one to account
// for the byte to encode the type.
compressedPayload := buf.Bytes()
numBytesBody := len(compressedPayload) + 1
// Finally, we can write out the number of bytes, the
// compression type, and then the buffer itself.
if err := writeElements(w, uint16(numBytesBody)); err != nil {
return err
}
if err := writeElements(w, encodingType); err != nil {
return err
}
_, err := w.Write(compressedPayload)
return err
default:
// If we're trying to encode with an encoding type that we
// don't know of, then we'll return a parsing error as we can't