lnwire: fix decoding for initial zero sid

This fixes a decoding error when the list of short channel ids within a
QueryShortChanIDs message started with a zero sid.

BOLT-0007 specifies that lists of short channel ids should be sorted in
ascending order. Previously, this was checked within lnwire by comparing
two consecutive sids in the list, starting at the empty (zero) sid.

This meant that a list that started with a zero sid couldn't be decoded
since the first element would _not_ be greater than the last one
(namely: also zero).

Given that one can only check for ordering starting at the second
element, we add a check to ensure the proper behavior.

A unit test is also added to ensure no future regressions on this
behavior.
This commit is contained in:
Matheus Degiovani 2020-06-18 14:04:39 -03:00
parent 87880c0d56
commit 44555a70ed
2 changed files with 57 additions and 3 deletions

View File

@ -184,8 +184,13 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err
"short chan ID: %v", err) "short chan ID: %v", err)
} }
// We'll ensure that this short chan ID is greater than
// the last one. This is a requirement within the
// encoding, and if violated can aide us in detecting
// malicious payloads. This can only be true starting
// at the second chanID.
cid := shortChanIDs[i] cid := shortChanIDs[i]
if cid.ToUint64() <= lastChanID.ToUint64() { if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() {
return 0, nil, ErrUnsortedSIDs{lastChanID, cid} return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
} }
lastChanID = cid lastChanID = cid
@ -224,6 +229,7 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err
var ( var (
shortChanIDs []ShortChannelID shortChanIDs []ShortChannelID
lastChanID ShortChannelID lastChanID ShortChannelID
i int
) )
for { for {
// We'll now attempt to read the next short channel ID // We'll now attempt to read the next short channel ID
@ -255,12 +261,14 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err
// Finally, we'll ensure that this short chan ID is // Finally, we'll ensure that this short chan ID is
// greater than the last one. This is a requirement // greater than the last one. This is a requirement
// within the encoding, and if violated can aide us in // within the encoding, and if violated can aide us in
// detecting malicious payloads. // detecting malicious payloads. This can only be true
if cid.ToUint64() <= lastChanID.ToUint64() { // starting at the second chanID.
if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() {
return 0, nil, ErrUnsortedSIDs{lastChanID, cid} return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
} }
lastChanID = cid lastChanID = cid
i++
} }
default: default:

View File

@ -49,6 +49,7 @@ var (
// TestQueryShortChanIDsUnsorted tests that decoding a QueryShortChanID request // TestQueryShortChanIDsUnsorted tests that decoding a QueryShortChanID request
// that contains duplicate or unsorted ids returns an ErrUnsortedSIDs failure. // that contains duplicate or unsorted ids returns an ErrUnsortedSIDs failure.
func TestQueryShortChanIDsUnsorted(t *testing.T) { func TestQueryShortChanIDsUnsorted(t *testing.T) {
for _, test := range unsortedSidTests { for _, test := range unsortedSidTests {
test := test test := test
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
@ -73,3 +74,48 @@ func TestQueryShortChanIDsUnsorted(t *testing.T) {
}) })
} }
} }
// TestQueryShortChanIDsZero ensures that decoding of a list of short chan ids
// still works as expected when the first element of the list is zero.
func TestQueryShortChanIDsZero(t *testing.T) {
testCases := []struct {
name string
encoding ShortChanIDEncoding
}{
{
name: "plain",
encoding: EncodingSortedPlain,
}, {
name: "zlib",
encoding: EncodingSortedZlib,
},
}
testSids := []ShortChannelID{
NewShortChanIDFromInt(0),
NewShortChanIDFromInt(10),
}
for _, test := range testCases {
test := test
t.Run(test.name, func(t *testing.T) {
req := &QueryShortChanIDs{
EncodingType: test.encoding,
ShortChanIDs: testSids,
noSort: true,
}
var b bytes.Buffer
err := req.Encode(&b, 0)
if err != nil {
t.Fatalf("unable to encode req: %v", err)
}
var req2 QueryShortChanIDs
err = req2.Decode(bytes.NewReader(b.Bytes()), 0)
if err != nil {
t.Fatalf("unexpected decoding error: %v", err)
}
})
}
}