Merge pull request #1399 from Roasbeef/zlib-decoding
lnwire: implement cautious zlib decoding for channel range queries
Commit: 1219e14955
@@ -651,9 +651,14 @@ func TestLightningWireProtocol(t *testing.T) {
             v[0] = reflect.ValueOf(req)
         },
         MsgQueryShortChanIDs: func(v []reflect.Value, r *rand.Rand) {
-            req := QueryShortChanIDs{
-                // TODO(roasbeef): later alternate encoding types
-                EncodingType: EncodingSortedPlain,
-            }
+            req := QueryShortChanIDs{}
+
+            // With a 50/50 chance, we'll either use zlib encoding,
+            // or regular encoding.
+            if r.Int31()%2 == 0 {
+                req.EncodingType = EncodingSortedZlib
+            } else {
+                req.EncodingType = EncodingSortedPlain
+            }
 
             if _, err := rand.Read(req.ChainHash[:]); err != nil {
@@ -687,8 +692,13 @@ func TestLightningWireProtocol(t *testing.T) {
 
             req.Complete = uint8(r.Int31n(2))
 
-            // TODO(roasbeef): later alternate encoding types
-            req.EncodingType = EncodingSortedPlain
+            // With a 50/50 chance, we'll either use zlib encoding,
+            // or regular encoding.
+            if r.Int31()%2 == 0 {
+                req.EncodingType = EncodingSortedZlib
+            } else {
+                req.EncodingType = EncodingSortedPlain
+            }
 
             numChanIDs := rand.Int31n(5000)
 
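Since the generators above now flip a coin between the two encodings, every fuzz run exercises the zlib path end to end. The round trip the harness relies on can be sketched in isolation; this is a hypothetical helper against the exported lnwire API (QueryShortChanIDs, NewShortChanIDFromInt), not part of this change, and assumes bytes, rand, reflect, and testing are imported as in the existing test file:

// roundTripQuery is a hypothetical sketch: pick an encoding at random,
// encode the query, decode it back, and require an exact match.
func roundTripQuery(t *testing.T, r *rand.Rand) {
    req := QueryShortChanIDs{}
    if r.Int31()%2 == 0 {
        req.EncodingType = EncodingSortedZlib
    } else {
        req.EncodingType = EncodingSortedPlain
    }

    // Use strictly increasing IDs so the zlib decoder's monotonicity
    // check is satisfied on the way back.
    for i := uint64(1); i <= 10; i++ {
        req.ShortChanIDs = append(req.ShortChanIDs, NewShortChanIDFromInt(i))
    }

    var b bytes.Buffer
    if err := req.Encode(&b, 0); err != nil {
        t.Fatalf("unable to encode: %v", err)
    }

    var got QueryShortChanIDs
    if err := got.Decode(&b, 0); err != nil {
        t.Fatalf("unable to decode: %v", err)
    }
    if !reflect.DeepEqual(req, got) {
        t.Fatalf("query mismatch: %v vs %v", req, got)
    }
}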
@@ -2,9 +2,11 @@ package lnwire
 
 import (
     "bytes"
+    "compress/zlib"
     "fmt"
     "io"
     "sort"
+    "sync"
 
     "github.com/roasbeef/btcd/chaincfg/chainhash"
 )
@@ -20,10 +22,24 @@ const (
     // encoded using the regular encoding, in a sorted order.
     EncodingSortedPlain ShortChanIDEncoding = 0
 
-    // TODO(roasbeef): list max number of short chan id's that are able to
-    // use
+    // EncodingSortedZlib signals that the set of short channel ID's is
+    // encoded by first sorting the set of channel ID's, and then
+    // compressing them using zlib.
+    EncodingSortedZlib ShortChanIDEncoding = 1
 )
+
+const (
+    // maxZlibBufSize is the max number of bytes that we'll accept from a
+    // zlib decoding instance. We do this in order to limit the total
+    // amount of memory allocated during a decoding instance.
+    maxZlibBufSize = 67413630
+)
+
+// zlibDecodeMtx is a package level mutex that we'll use in order to ensure
+// that we'll only attempt a single zlib decoding instance at a time. This
+// allows us to also further bound our memory usage.
+var zlibDecodeMtx sync.Mutex
 
 // ErrUnknownShortChanIDEncoding is a parametrized error that indicates that we
 // came across an unknown short channel ID encoding, and therefore were unable
 // to continue parsing.
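Taken together, maxZlibBufSize and zlibDecodeMtx implement a simple guard: only one zlib decode runs at a time, with a hard cap on the compressed bytes the decompressor may consume, which per the commit's comments bounds the memory allocated per decoding instance. A standalone sketch of the same pattern, with illustrative names rather than the lnwire code:

package zlibguard

import (
    "bytes"
    "compress/zlib"
    "fmt"
    "io"
    "io/ioutil"
    "sync"
)

var decodeMtx sync.Mutex

// boundedInflate mirrors the commit's guard: the mutex serializes decode
// instances across peers, and the io.LimitedReader caps how many bytes of
// compressed input the zlib reader may pull from the payload.
func boundedInflate(payload []byte, maxBytes int64) ([]byte, error) {
    decodeMtx.Lock()
    defer decodeMtx.Unlock()

    zr, err := zlib.NewReader(&io.LimitedReader{
        R: bytes.NewReader(payload),
        N: maxBytes,
    })
    if err != nil {
        return nil, fmt.Errorf("unable to create zlib reader: %v", err)
    }
    defer zr.Close()

    return ioutil.ReadAll(zr)
}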
@@ -144,6 +160,71 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, error) {
 
         return encodingType, shortChanIDs, nil
 
+    // In this encoding, we'll use zlib to decode the compressed payload.
+    // However, we'll pay attention to ensure that we don't open ourselves
+    // up to a memory exhaustion attack.
+    case EncodingSortedZlib:
+        // We'll obtain and ultimately release the zlib decode mutex.
+        // This guards us against allocating too much memory to decode
+        // each instance from concurrent peers.
+        zlibDecodeMtx.Lock()
+        defer zlibDecodeMtx.Unlock()
+
+        // Before we start to decode, we'll create a limit reader over
+        // the current reader. This will ensure that we can control how
+        // much memory we're allocating during the decoding process.
+        limitedDecompressor, err := zlib.NewReader(&io.LimitedReader{
+            R: bytes.NewReader(queryBody),
+            N: maxZlibBufSize,
+        })
+        if err != nil {
+            return 0, nil, fmt.Errorf("unable to create zlib reader: %v", err)
+        }
+
+        var (
+            shortChanIDs []ShortChannelID
+            lastChanID   ShortChannelID
+        )
+        for {
+            // We'll now attempt to read the next short channel ID
+            // encoded in the payload.
+            var cid ShortChannelID
+            err := readElements(limitedDecompressor, &cid)
+
+            switch {
+            // If we get an EOF error, then that either means we've
+            // read all that's contained in the buffer, or have hit
+            // our limit on the number of bytes we'll read. In
+            // either case, we'll return what we have so far.
+            case err == io.ErrUnexpectedEOF || err == io.EOF:
+                return encodingType, shortChanIDs, nil
+
+            // Otherwise, we hit some other sort of error, possibly
+            // an invalid payload, so we'll exit early with the
+            // error.
+            case err != nil:
+                return 0, nil, fmt.Errorf("unable to "+
+                    "inflate next short chan "+
+                    "ID: %v", err)
+            }
+
+            // We successfully read the next ID, so we'll collect
+            // that in the set of final ID's to return.
+            shortChanIDs = append(shortChanIDs, cid)
+
+            // Finally, we'll ensure that this short chan ID is
+            // greater than the last one. This is a requirement
+            // within the encoding, and if violated can aid us in
+            // detecting malicious payloads.
+            if cid.ToUint64() <= lastChanID.ToUint64() {
+                return 0, nil, fmt.Errorf("current sid of %v "+
+                    "isn't greater than last sid of %v", cid,
+                    lastChanID)
+            }
+
+            lastChanID = cid
+        }
+
     default:
         // If we've been sent an encoding type that we don't know of,
         // then we'll return a parsing error as we can't continue if
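Two details of the loop above do the defensive work: EOF (including a truncated final element after hitting the read limit) ends decoding normally, and every ID must be strictly greater than its predecessor. A self-contained sketch of that loop over raw uint64 IDs; this is a hypothetical helper, whereas lnwire's readElements handles the actual wire format:

package sids

import (
    "encoding/binary"
    "fmt"
    "io"
)

// readSortedIDs reads 8-byte big-endian IDs from r until EOF. As in the
// decoder above, EOF and a truncated final element terminate the loop
// normally, while any non-increasing ID is treated as a malformed or
// malicious payload. Note that, mirroring the decoder's check, a leading
// zero ID is rejected too, since last starts at 0.
func readSortedIDs(r io.Reader) ([]uint64, error) {
    var (
        ids  []uint64
        last uint64
    )
    for {
        var buf [8]byte
        _, err := io.ReadFull(r, buf[:])
        if err == io.EOF || err == io.ErrUnexpectedEOF {
            return ids, nil
        }
        if err != nil {
            return nil, err
        }

        id := binary.BigEndian.Uint64(buf[:])
        if id <= last {
            return nil, fmt.Errorf("id %d not greater than %d", id, last)
        }

        ids = append(ids, id)
        last = id
    }
}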
@@ -173,6 +254,13 @@ func (q *QueryShortChanIDs) Encode(w io.Writer, pver uint32) error {
 func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding,
     shortChanIDs []ShortChannelID) error {
 
+    // For both of the current encoding types, the channel ID's are to be
+    // sorted in place, so we'll do that now.
+    sort.Slice(shortChanIDs, func(i, j int) bool {
+        return shortChanIDs[i].ToUint64() <
+            shortChanIDs[j].ToUint64()
+    })
+
     switch encodingType {
 
     // In this encoding, we'll simply write a sorted array of encoded short
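Hoisting the sort above the switch is safe because ToUint64 gives a single total order that both encodings share: block height first, then transaction index, then output index. A hypothetical illustration with made-up values, inside package lnwire where sort is already imported:

// sortOrderExample illustrates the ordering induced by ToUint64.
func sortOrderExample() []ShortChannelID {
    ids := []ShortChannelID{
        {BlockHeight: 600, TxIndex: 1, TxPosition: 0},
        {BlockHeight: 500, TxIndex: 9, TxPosition: 1},
        {BlockHeight: 500, TxIndex: 2, TxPosition: 0},
    }
    sort.Slice(ids, func(i, j int) bool {
        return ids[i].ToUint64() < ids[j].ToUint64()
    })
    // ids is now ordered 500/2/0, 500/9/1, 600/1/0.
    return ids
}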
@@ -192,13 +280,6 @@ func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding,
             return err
         }
 
-        // Next, we'll ensure that the set of short channel ID's is
-        // properly sorted in place.
-        sort.Slice(shortChanIDs, func(i, j int) bool {
-            return shortChanIDs[i].ToUint64() <
-                shortChanIDs[j].ToUint64()
-        })
-
         // Now that we know they're sorted, we can write out each short
         // channel ID to the buffer.
         for _, chanID := range shortChanIDs {
@@ -210,6 +291,54 @@ func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding,
 
         return nil
 
+    // For this encoding we'll first write out a serialized version of all
+    // the channel ID's into a buffer, then zlib encode that. The final
+    // payload is what we'll write out to the passed io.Writer.
+    //
+    // TODO(roasbeef): assumes the caller knows the proper chunk size to
+    // pass to avoid bin-packing here
+    case EncodingSortedZlib:
+        // We'll make a new buffer, then wrap that with a zlib writer
+        // so we can write directly to the buffer and encode in a
+        // streaming manner.
+        var buf bytes.Buffer
+        zlibWriter := zlib.NewWriter(&buf)
+
+        // Next, we'll write out all the channel ID's directly into the
+        // zlib writer, which will do compressing on the fly.
+        for _, chanID := range shortChanIDs {
+            err := writeElements(zlibWriter, chanID)
+            if err != nil {
+                return fmt.Errorf("unable to write short chan "+
+                    "ID: %v", err)
+            }
+        }
+
+        // Now that we've written all the elements, we'll ensure the
+        // compressed stream is written to the underlying buffer.
+        if err := zlibWriter.Close(); err != nil {
+            return fmt.Errorf("unable to finalize "+
+                "compression: %v", err)
+        }
+
+        // Now that we have all the items compressed, we can compute
+        // what the total payload size will be. We add one to account
+        // for the byte to encode the type.
+        compressedPayload := buf.Bytes()
+        numBytesBody := len(compressedPayload) + 1
+
+        // Finally, we can write out the number of bytes, the
+        // compression type, and finally the buffer itself.
+        if err := writeElements(w, uint16(numBytesBody)); err != nil {
+            return err
+        }
+        if err := writeElements(w, encodingType); err != nil {
+            return err
+        }
+
+        _, err := w.Write(compressedPayload)
+        return err
+
     default:
         // If we're trying to encode with an encoding type that we
         // don't know of, then we'll return a parsing error as we can't
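With both halves in place, the helpers round-trip. A hypothetical sketch from inside the package, since both functions are unexported, with error handling abbreviated:

// zlibRoundTrip encodes a set of IDs with the zlib encoding, then parses
// them back through the guarded decoder.
func zlibRoundTrip() error {
    ids := []ShortChannelID{
        NewShortChanIDFromInt(3),
        NewShortChanIDFromInt(1),
        NewShortChanIDFromInt(2),
    }

    var b bytes.Buffer
    if err := encodeShortChanIDs(&b, EncodingSortedZlib, ids); err != nil {
        return err
    }

    // encodeShortChanIDs sorted ids in place before compressing, so the
    // decoder's strictly-increasing check passes on the way back out.
    encoding, decoded, err := decodeShortChanIDs(&b)
    if err != nil {
        return err
    }
    fmt.Printf("encoding=%v, got %d IDs back\n", encoding, len(decoded))
    return nil
}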