Merge pull request #4391 from matheusdtech/lnwire-zero-sid

lnwire: fix decoding for initial zero sid
This commit is contained in:
Conner Fromknecht 2020-06-18 17:21:43 -07:00 committed by GitHub
commit 60a6f2ddd1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 57 additions and 3 deletions

@ -184,8 +184,13 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err
"short chan ID: %v", err)
}
// We'll ensure that this short chan ID is greater than
// the last one. This is a requirement within the
// encoding, and if violated can aide us in detecting
// malicious payloads. This can only be true starting
// at the second chanID.
cid := shortChanIDs[i]
if cid.ToUint64() <= lastChanID.ToUint64() {
if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() {
return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
}
lastChanID = cid
@ -224,6 +229,7 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err
var (
shortChanIDs []ShortChannelID
lastChanID ShortChannelID
i int
)
for {
// We'll now attempt to read the next short channel ID
@ -255,12 +261,14 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err
// Finally, we'll ensure that this short chan ID is
// greater than the last one. This is a requirement
// within the encoding, and if violated can aide us in
// detecting malicious payloads.
if cid.ToUint64() <= lastChanID.ToUint64() {
// detecting malicious payloads. This can only be true
// starting at the second chanID.
if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() {
return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
}
lastChanID = cid
i++
}
default:

@ -49,6 +49,7 @@ var (
// TestQueryShortChanIDsUnsorted tests that decoding a QueryShortChanID request
// that contains duplicate or unsorted ids returns an ErrUnsortedSIDs failure.
func TestQueryShortChanIDsUnsorted(t *testing.T) {
for _, test := range unsortedSidTests {
test := test
t.Run(test.name, func(t *testing.T) {
@ -73,3 +74,48 @@ func TestQueryShortChanIDsUnsorted(t *testing.T) {
})
}
}
// TestQueryShortChanIDsZero ensures that decoding of a list of short chan ids
// still works as expected when the first element of the list is zero.
func TestQueryShortChanIDsZero(t *testing.T) {
testCases := []struct {
name string
encoding ShortChanIDEncoding
}{
{
name: "plain",
encoding: EncodingSortedPlain,
}, {
name: "zlib",
encoding: EncodingSortedZlib,
},
}
testSids := []ShortChannelID{
NewShortChanIDFromInt(0),
NewShortChanIDFromInt(10),
}
for _, test := range testCases {
test := test
t.Run(test.name, func(t *testing.T) {
req := &QueryShortChanIDs{
EncodingType: test.encoding,
ShortChanIDs: testSids,
noSort: true,
}
var b bytes.Buffer
err := req.Encode(&b, 0)
if err != nil {
t.Fatalf("unable to encode req: %v", err)
}
var req2 QueryShortChanIDs
err = req2.Decode(bytes.NewReader(b.Bytes()), 0)
if err != nil {
t.Fatalf("unexpected decoding error: %v", err)
}
})
}
}