diff --git a/lnwire/query_short_chan_ids.go b/lnwire/query_short_chan_ids.go index 76729364..cb24178b 100644 --- a/lnwire/query_short_chan_ids.go +++ b/lnwire/query_short_chan_ids.go @@ -184,8 +184,13 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err "short chan ID: %v", err) } + // We'll ensure that this short chan ID is greater than + // the last one. This is a requirement within the + // encoding, and if violated can aide us in detecting + // malicious payloads. This can only be true starting + // at the second chanID. cid := shortChanIDs[i] - if cid.ToUint64() <= lastChanID.ToUint64() { + if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() { return 0, nil, ErrUnsortedSIDs{lastChanID, cid} } lastChanID = cid @@ -224,6 +229,7 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err var ( shortChanIDs []ShortChannelID lastChanID ShortChannelID + i int ) for { // We'll now attempt to read the next short channel ID @@ -255,12 +261,14 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err // Finally, we'll ensure that this short chan ID is // greater than the last one. This is a requirement // within the encoding, and if violated can aide us in - // detecting malicious payloads. - if cid.ToUint64() <= lastChanID.ToUint64() { + // detecting malicious payloads. This can only be true + // starting at the second chanID. + if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() { return 0, nil, ErrUnsortedSIDs{lastChanID, cid} } lastChanID = cid + i++ } default: diff --git a/lnwire/query_short_chan_ids_test.go b/lnwire/query_short_chan_ids_test.go index 7d0538f5..6a2ecb62 100644 --- a/lnwire/query_short_chan_ids_test.go +++ b/lnwire/query_short_chan_ids_test.go @@ -49,6 +49,7 @@ var ( // TestQueryShortChanIDsUnsorted tests that decoding a QueryShortChanID request // that contains duplicate or unsorted ids returns an ErrUnsortedSIDs failure. func TestQueryShortChanIDsUnsorted(t *testing.T) { + for _, test := range unsortedSidTests { test := test t.Run(test.name, func(t *testing.T) { @@ -73,3 +74,48 @@ func TestQueryShortChanIDsUnsorted(t *testing.T) { }) } } + +// TestQueryShortChanIDsZero ensures that decoding of a list of short chan ids +// still works as expected when the first element of the list is zero. +func TestQueryShortChanIDsZero(t *testing.T) { + testCases := []struct { + name string + encoding ShortChanIDEncoding + }{ + { + name: "plain", + encoding: EncodingSortedPlain, + }, { + name: "zlib", + encoding: EncodingSortedZlib, + }, + } + + testSids := []ShortChannelID{ + NewShortChanIDFromInt(0), + NewShortChanIDFromInt(10), + } + + for _, test := range testCases { + test := test + t.Run(test.name, func(t *testing.T) { + req := &QueryShortChanIDs{ + EncodingType: test.encoding, + ShortChanIDs: testSids, + noSort: true, + } + + var b bytes.Buffer + err := req.Encode(&b, 0) + if err != nil { + t.Fatalf("unable to encode req: %v", err) + } + + var req2 QueryShortChanIDs + err = req2.Decode(bytes.NewReader(b.Bytes()), 0) + if err != nil { + t.Fatalf("unexpected decoding error: %v", err) + } + }) + } +}