diff --git a/lnwire/query_short_chan_ids.go b/lnwire/query_short_chan_ids.go index 64521ce3..afa01946 100644 --- a/lnwire/query_short_chan_ids.go +++ b/lnwire/query_short_chan_ids.go @@ -35,6 +35,19 @@ const ( maxZlibBufSize = 67413630 ) +// ErrUnsortedSIDs is returned when decoding a QueryShortChannelID request whose +// items were not sorted. +type ErrUnsortedSIDs struct { + prevSID ShortChannelID + curSID ShortChannelID +} + +// Error returns a human-readable description of the error. +func (e ErrUnsortedSIDs) Error() string { + return fmt.Sprintf("current sid: %v isn't greater than last sid: %v", + e.curSID, e.prevSID) +} + // zlibDecodeMtx is a package level mutex that we'll use in order to ensure // that we'll only attempt a single zlib decoding instance at a time. This // allows us to also further bound our memory usage. @@ -67,6 +80,12 @@ type QueryShortChanIDs struct { // ShortChanIDs is a slice of decoded short channel ID's. ShortChanIDs []ShortChannelID + + // noSort indicates whether or not to sort the short channel ids before + // writing them out. + // + // NOTE: This should only be used during testing. + noSort bool } // NewQueryShortChanIDs creates a new QueryShortChanIDs message. @@ -158,11 +177,18 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err // ID's to conclude our parsing. shortChanIDs := make([]ShortChannelID, numShortChanIDs) bodyReader := bytes.NewReader(queryBody) + var lastChanID ShortChannelID for i := 0; i < numShortChanIDs; i++ { if err := ReadElements(bodyReader, &shortChanIDs[i]); err != nil { return 0, nil, fmt.Errorf("unable to parse "+ "short chan ID: %v", err) } + + cid := shortChanIDs[i] + if cid.ToUint64() <= lastChanID.ToUint64() { + return 0, nil, ErrUnsortedSIDs{lastChanID, cid} + } + lastChanID = cid } return encodingType, shortChanIDs, nil @@ -224,9 +250,7 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err // within the encoding, and if violated can aide us in // detecting malicious payloads. if cid.ToUint64() <= lastChanID.ToUint64() { - return 0, nil, fmt.Errorf("current sid of %v "+ - "isn't greater than last sid of %v", cid, - lastChanID) + return 0, nil, ErrUnsortedSIDs{lastChanID, cid} } lastChanID = cid @@ -253,20 +277,23 @@ func (q *QueryShortChanIDs) Encode(w io.Writer, pver uint32) error { // Base on our encoding type, we'll write out the set of short channel // ID's. - return encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs) + return encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs, q.noSort) } // encodeShortChanIDs encodes the passed short channel ID's into the passed // io.Writer, respecting the specified encoding type. func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding, - shortChanIDs []ShortChannelID) error { + shortChanIDs []ShortChannelID, noSort bool) error { // For both of the current encoding types, the channel ID's are to be - // sorted in place, so we'll do that now. - sort.Slice(shortChanIDs, func(i, j int) bool { - return shortChanIDs[i].ToUint64() < - shortChanIDs[j].ToUint64() - }) + // sorted in place, so we'll do that now. The sorting is applied unless + // we were specifically requested not to for testing purposes. + if !noSort { + sort.Slice(shortChanIDs, func(i, j int) bool { + return shortChanIDs[i].ToUint64() < + shortChanIDs[j].ToUint64() + }) + } switch encodingType { diff --git a/lnwire/query_short_chan_ids_test.go b/lnwire/query_short_chan_ids_test.go new file mode 100644 index 00000000..7d0538f5 --- /dev/null +++ b/lnwire/query_short_chan_ids_test.go @@ -0,0 +1,75 @@ +package lnwire + +import ( + "bytes" + "testing" +) + +type unsortedSidTest struct { + name string + encType ShortChanIDEncoding + sids []ShortChannelID +} + +var ( + unsortedSids = []ShortChannelID{ + NewShortChanIDFromInt(4), + NewShortChanIDFromInt(3), + } + + duplicateSids = []ShortChannelID{ + NewShortChanIDFromInt(3), + NewShortChanIDFromInt(3), + } + + unsortedSidTests = []unsortedSidTest{ + { + name: "plain unsorted", + encType: EncodingSortedPlain, + sids: unsortedSids, + }, + { + name: "plain duplicate", + encType: EncodingSortedPlain, + sids: duplicateSids, + }, + { + name: "zlib unsorted", + encType: EncodingSortedZlib, + sids: unsortedSids, + }, + { + name: "zlib duplicate", + encType: EncodingSortedZlib, + sids: duplicateSids, + }, + } +) + +// TestQueryShortChanIDsUnsorted tests that decoding a QueryShortChanID request +// that contains duplicate or unsorted ids returns an ErrUnsortedSIDs failure. +func TestQueryShortChanIDsUnsorted(t *testing.T) { + for _, test := range unsortedSidTests { + test := test + t.Run(test.name, func(t *testing.T) { + req := &QueryShortChanIDs{ + EncodingType: test.encType, + ShortChanIDs: test.sids, + noSort: true, + } + + var b bytes.Buffer + err := req.Encode(&b, 0) + if err != nil { + t.Fatalf("unable to encode req: %v", err) + } + + var req2 QueryShortChanIDs + err = req2.Decode(bytes.NewReader(b.Bytes()), 0) + if _, ok := err.(ErrUnsortedSIDs); !ok { + t.Fatalf("expected ErrUnsortedSIDs, got: %T", + err) + } + }) + } +} diff --git a/lnwire/reply_channel_range.go b/lnwire/reply_channel_range.go index 5765191d..18432392 100644 --- a/lnwire/reply_channel_range.go +++ b/lnwire/reply_channel_range.go @@ -64,7 +64,7 @@ func (c *ReplyChannelRange) Encode(w io.Writer, pver uint32) error { return err } - return encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs) + return encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, false) } // MsgType returns the integer uniquely identifying this message type on the