From a0e2f8dbd145d9fb6e0ce5142eac980421811c2e Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Fri, 15 Jun 2018 18:31:23 -0700 Subject: [PATCH 1/4] lnwire: implement zlib encode/decode for channel range queries In this commit, we implement zlib encoding and decoding for the channel range queries. Notably, we utilize an io.LimitedReader to ensure that we can enforce a hard cap on the total number of bytes we'll ever allocate in a decoding attempt. --- lnwire/query_short_chan_ids.go | 121 ++++++++++++++++++++++++++++++--- 1 file changed, 112 insertions(+), 9 deletions(-) diff --git a/lnwire/query_short_chan_ids.go b/lnwire/query_short_chan_ids.go index 1f4c1d35..f9ff58b5 100644 --- a/lnwire/query_short_chan_ids.go +++ b/lnwire/query_short_chan_ids.go @@ -2,9 +2,11 @@ package lnwire import ( "bytes" + "compress/zlib" "fmt" "io" "sort" + "sync" "github.com/roasbeef/btcd/chaincfg/chainhash" ) @@ -20,8 +22,17 @@ const ( // encoded using the regular encoding, in a sorted order. EncodingSortedPlain ShortChanIDEncoding = 0 - // TODO(roasbeef): list max number of short chan id's that are able to - // use + // EncodingSortedZlib signals that the set of short channel ID's is + // encoded by first sorting the set of channel ID's, as then + // compressing them using zlib. + EncodingSortedZlib ShortChanIDEncoding = 1 +) + +const ( + // maxZlibBufSize is the max number of bytes that we'll accept from a + // zlib decoding instance. We do this in order to limit the total + // amount of memory allocated during a decoding instance. + maxZlibBufSize = 67413630 ) // ErrUnknownShortChanIDEncoding is a parametrized error that indicates that we @@ -144,6 +155,50 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err return encodingType, shortChanIDs, nil + // In this encoding, we'll use zlib to decode the compressed payload. + // However, we'll pay attention to ensure that we don't open our selves + // up to a memory exhaustion attack. + case EncodingSortedZlib: + // Before we start to decode, we'll create a limit reader over + // the current reader. This will ensure that we can control how + // much memory we're allocating during the decoding process. + limitedDecompressor, err := zlib.NewReader(&io.LimitedReader{ + R: bytes.NewReader(queryBody), + N: maxZlibBufSize, + }) + if err != nil { + return 0, nil, fmt.Errorf("unable to create zlib reader: %v", err) + } + + var shortChanIDs []ShortChannelID + for { + // We'll now attempt to read the next short channel ID + // encoded in the payload. + var cid ShortChannelID + err := readElements(limitedDecompressor, &cid) + + switch { + // If we get an EOF error, then that either means we've + // read all that's contained in the buffer, or have hit + // our limit on the number of bytes we'll read. In + // either case, we'll return what we have so far. + case err == io.ErrUnexpectedEOF || err == io.EOF: + return encodingType, shortChanIDs, nil + + // Otherwise, we hit some other sort of error, possibly + // an invalid payload, so we'll exit early with the + // error. + case err != nil: + return 0, nil, fmt.Errorf("unable to "+ + "deflate next short chan "+ + "ID: %v", err) + } + + // We successfully read the next ID, so well collect + // that in the set of final ID's to return. + shortChanIDs = append(shortChanIDs, cid) + } + default: // If we've been sent an encoding type that we don't know of, // then we'll return a parsing error as we can't continue if @@ -173,6 +228,13 @@ func (q *QueryShortChanIDs) Encode(w io.Writer, pver uint32) error { func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding, shortChanIDs []ShortChannelID) error { + // For both of the current encoding types, the channel ID's are to be + // sorted in place, so we'll do that now. + sort.Slice(shortChanIDs, func(i, j int) bool { + return shortChanIDs[i].ToUint64() < + shortChanIDs[j].ToUint64() + }) + switch encodingType { // In this encoding, we'll simply write a sorted array of encoded short @@ -192,13 +254,6 @@ func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding, return err } - // Next, we'll ensure that the set of short channel ID's is - // properly sorted in place. - sort.Slice(shortChanIDs, func(i, j int) bool { - return shortChanIDs[i].ToUint64() < - shortChanIDs[j].ToUint64() - }) - // Now that we know they're sorted, we can write out each short // channel ID to the buffer. for _, chanID := range shortChanIDs { @@ -210,6 +265,54 @@ func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding, return nil + // For this encoding we'll first write out a serialized version of all + // the channel ID's into a buffer, then zlib encode that. The final + // payload is what we'll write out to the passed io.Writer. + // + // TODO(roasbeef): assumes the caller knows the proper chunk size to + // pass to avoid bin-packing here + case EncodingSortedZlib: + // We'll make a new buffer, then wrap that with a zlib writer + // so we can write directly to the buffer and encode in a + // streaming manner. + var buf bytes.Buffer + zlibWriter := zlib.NewWriter(&buf) + + // Next, we'll write out all the channel ID's directly into the + // zlib writer, which will do compressing on the fly. + for _, chanID := range shortChanIDs { + err := writeElements(zlibWriter, chanID) + if err != nil { + return fmt.Errorf("unable to write short chan "+ + "ID: %v", err) + } + } + + // Now that we've written all the elements, we'll ensure the + // compressed stream is written to the underlying buffer. + if err := zlibWriter.Close(); err != nil { + return fmt.Errorf("unable to finalize "+ + "compression: %v", err) + } + + // Now that we have all the items compressed, we can compute + // what the total payload size will be. We add one to account + // for the byte to encode the type. + compressedPayload := buf.Bytes() + numBytesBody := len(compressedPayload) + 1 + + // Finally, we can write out the number of bytes, the + // compression type, and finally the buffer itself. + if err := writeElements(w, uint16(numBytesBody)); err != nil { + return err + } + if err := writeElements(w, encodingType); err != nil { + return err + } + + _, err := w.Write(compressedPayload) + return err + default: // If we're trying to encode with an encoding type that we // don't know of, then we'll return a parsing error as we can't From 5caf3d73105b9e068eecc38cfc689f1413075741 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Fri, 15 Jun 2018 18:33:04 -0700 Subject: [PATCH 2/4] lnwire: add new package level mutex to limit # of concurrent zlib decodings In this commit, we add a new package level mutex. Each time we decode a new set of chan IDs w/ zlib, we also grab this mutex. The purpose here is to ensure that we only EVER allocate the maxZlibBufSize globally across all peers. Otherwise, it may be possible for us to allocate up to 64 MB for _each_ peer, exposing an easy OOM attack vector. --- lnwire/query_short_chan_ids.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/lnwire/query_short_chan_ids.go b/lnwire/query_short_chan_ids.go index f9ff58b5..a0e28f4d 100644 --- a/lnwire/query_short_chan_ids.go +++ b/lnwire/query_short_chan_ids.go @@ -35,6 +35,11 @@ const ( maxZlibBufSize = 67413630 ) +// zlibDecodeMtx is a package level mutex that we'll use in order to ensure +// that we'll only attempt a single zlib decoding instance at a time. This +// allows us to also further bound our memory usage. +var zlibDecodeMtx sync.Mutex + // ErrUnknownShortChanIDEncoding is a parametrized error that indicates that we // came across an unknown short channel ID encoding, and therefore were unable // to continue parsing. @@ -159,6 +164,12 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err // However, we'll pay attention to ensure that we don't open our selves // up to a memory exhaustion attack. case EncodingSortedZlib: + // We'll obtain an ultimately release the zlib decode mutex. + // This guards us against allocating too much memory to decode + // each instance from concurrent peers. + zlibDecodeMtx.Lock() + defer zlibDecodeMtx.Unlock() + // Before we start to decode, we'll create a limit reader over // the current reader. This will ensure that we can control how // much memory we're allocating during the decoding process. From 940b95aad73f01fb76c319308ba7ea1a6a8a9e0d Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Fri, 15 Jun 2018 18:34:11 -0700 Subject: [PATCH 3/4] lnwire: update testing.Quick tests to alternate between encoding types --- lnwire/lnwire_test.go | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 0633e6b0..f1311c62 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -651,9 +651,14 @@ func TestLightningWireProtocol(t *testing.T) { v[0] = reflect.ValueOf(req) }, MsgQueryShortChanIDs: func(v []reflect.Value, r *rand.Rand) { - req := QueryShortChanIDs{ - // TODO(roasbeef): later alternate encoding types - EncodingType: EncodingSortedPlain, + req := QueryShortChanIDs{} + + // With a 50/50 change, we'll either use zlib encoding, + // or regular encoding. + if r.Int31()%2 == 0 { + req.EncodingType = EncodingSortedZlib + } else { + req.EncodingType = EncodingSortedPlain } if _, err := rand.Read(req.ChainHash[:]); err != nil { @@ -687,8 +692,13 @@ func TestLightningWireProtocol(t *testing.T) { req.Complete = uint8(r.Int31n(2)) - // TODO(roasbeef): later alternate encoding types - req.EncodingType = EncodingSortedPlain + // With a 50/50 change, we'll either use zlib encoding, + // or regular encoding. + if r.Int31()%2 == 0 { + req.EncodingType = EncodingSortedZlib + } else { + req.EncodingType = EncodingSortedPlain + } numChanIDs := rand.Int31n(5000) From 23b1678266fc026540f3c31f3f94c4c9b36e1e39 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 25 Jun 2018 16:15:30 -0700 Subject: [PATCH 4/4] lnwire: ensure zlib short chan id's are sorted --- lnwire/query_short_chan_ids.go | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/lnwire/query_short_chan_ids.go b/lnwire/query_short_chan_ids.go index a0e28f4d..fd959c43 100644 --- a/lnwire/query_short_chan_ids.go +++ b/lnwire/query_short_chan_ids.go @@ -181,7 +181,10 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err return 0, nil, fmt.Errorf("unable to create zlib reader: %v", err) } - var shortChanIDs []ShortChannelID + var ( + shortChanIDs []ShortChannelID + lastChanID ShortChannelID + ) for { // We'll now attempt to read the next short channel ID // encoded in the payload. @@ -208,6 +211,18 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err // We successfully read the next ID, so well collect // that in the set of final ID's to return. shortChanIDs = append(shortChanIDs, cid) + + // Finally, we'll ensure that this short chan ID is + // greater than the last one. This is a requirement + // within the encoding, and if violated can aide us in + // detecting malicious payloads. + if cid.ToUint64() <= lastChanID.ToUint64() { + return 0, nil, fmt.Errorf("current sid of %v "+ + "isn't greater than last sid of %v", cid, + lastChanID) + } + + lastChanID = cid } default: