Merge pull request #3743 from cfromknecht/in-order-sids

lnwire: assert sorted short channel ids
This commit is contained in:
Johan T. Halseth 2019-12-04 08:56:02 +01:00 committed by GitHub
commit a6ef03c777
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 153 additions and 11 deletions

View File

@ -35,6 +35,19 @@ const (
maxZlibBufSize = 67413630
)
// ErrUnsortedSIDs is returned when decoding a QueryShortChannelID request whose
// items were not sorted.
type ErrUnsortedSIDs struct {
prevSID ShortChannelID
curSID ShortChannelID
}
// Error returns a human-readable description of the error.
func (e ErrUnsortedSIDs) Error() string {
return fmt.Sprintf("current sid: %v isn't greater than last sid: %v",
e.curSID, e.prevSID)
}
// zlibDecodeMtx is a package level mutex that we'll use in order to ensure
// that we'll only attempt a single zlib decoding instance at a time. This
// allows us to also further bound our memory usage.
@ -67,6 +80,12 @@ type QueryShortChanIDs struct {
// ShortChanIDs is a slice of decoded short channel ID's.
ShortChanIDs []ShortChannelID
// noSort indicates whether or not to sort the short channel ids before
// writing them out.
//
// NOTE: This should only be used during testing.
noSort bool
}
// NewQueryShortChanIDs creates a new QueryShortChanIDs message.
@ -158,11 +177,18 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err
// ID's to conclude our parsing.
shortChanIDs := make([]ShortChannelID, numShortChanIDs)
bodyReader := bytes.NewReader(queryBody)
var lastChanID ShortChannelID
for i := 0; i < numShortChanIDs; i++ {
if err := ReadElements(bodyReader, &shortChanIDs[i]); err != nil {
return 0, nil, fmt.Errorf("unable to parse "+
"short chan ID: %v", err)
}
cid := shortChanIDs[i]
if cid.ToUint64() <= lastChanID.ToUint64() {
return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
}
lastChanID = cid
}
return encodingType, shortChanIDs, nil
@ -224,9 +250,7 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err
// within the encoding, and if violated can aide us in
// detecting malicious payloads.
if cid.ToUint64() <= lastChanID.ToUint64() {
return 0, nil, fmt.Errorf("current sid of %v "+
"isn't greater than last sid of %v", cid,
lastChanID)
return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
}
lastChanID = cid
@ -253,20 +277,23 @@ func (q *QueryShortChanIDs) Encode(w io.Writer, pver uint32) error {
// Base on our encoding type, we'll write out the set of short channel
// ID's.
return encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs)
return encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs, q.noSort)
}
// encodeShortChanIDs encodes the passed short channel ID's into the passed
// io.Writer, respecting the specified encoding type.
func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding,
shortChanIDs []ShortChannelID) error {
shortChanIDs []ShortChannelID, noSort bool) error {
// For both of the current encoding types, the channel ID's are to be
// sorted in place, so we'll do that now.
sort.Slice(shortChanIDs, func(i, j int) bool {
return shortChanIDs[i].ToUint64() <
shortChanIDs[j].ToUint64()
})
// sorted in place, so we'll do that now. The sorting is applied unless
// we were specifically requested not to for testing purposes.
if !noSort {
sort.Slice(shortChanIDs, func(i, j int) bool {
return shortChanIDs[i].ToUint64() <
shortChanIDs[j].ToUint64()
})
}
switch encodingType {

View File

@ -0,0 +1,75 @@
package lnwire
import (
"bytes"
"testing"
)
type unsortedSidTest struct {
name string
encType ShortChanIDEncoding
sids []ShortChannelID
}
var (
unsortedSids = []ShortChannelID{
NewShortChanIDFromInt(4),
NewShortChanIDFromInt(3),
}
duplicateSids = []ShortChannelID{
NewShortChanIDFromInt(3),
NewShortChanIDFromInt(3),
}
unsortedSidTests = []unsortedSidTest{
{
name: "plain unsorted",
encType: EncodingSortedPlain,
sids: unsortedSids,
},
{
name: "plain duplicate",
encType: EncodingSortedPlain,
sids: duplicateSids,
},
{
name: "zlib unsorted",
encType: EncodingSortedZlib,
sids: unsortedSids,
},
{
name: "zlib duplicate",
encType: EncodingSortedZlib,
sids: duplicateSids,
},
}
)
// TestQueryShortChanIDsUnsorted tests that decoding a QueryShortChanID request
// that contains duplicate or unsorted ids returns an ErrUnsortedSIDs failure.
func TestQueryShortChanIDsUnsorted(t *testing.T) {
for _, test := range unsortedSidTests {
test := test
t.Run(test.name, func(t *testing.T) {
req := &QueryShortChanIDs{
EncodingType: test.encType,
ShortChanIDs: test.sids,
noSort: true,
}
var b bytes.Buffer
err := req.Encode(&b, 0)
if err != nil {
t.Fatalf("unable to encode req: %v", err)
}
var req2 QueryShortChanIDs
err = req2.Decode(bytes.NewReader(b.Bytes()), 0)
if _, ok := err.(ErrUnsortedSIDs); !ok {
t.Fatalf("expected ErrUnsortedSIDs, got: %T",
err)
}
})
}
}

View File

@ -21,6 +21,12 @@ type ReplyChannelRange struct {
// ShortChanIDs is a slice of decoded short channel ID's.
ShortChanIDs []ShortChannelID
// noSort indicates whether or not to sort the short channel ids before
// writing them out.
//
// NOTE: This should only be used for testing.
noSort bool
}
// NewReplyChannelRange creates a new empty ReplyChannelRange message.
@ -64,7 +70,7 @@ func (c *ReplyChannelRange) Encode(w io.Writer, pver uint32) error {
return err
}
return encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs)
return encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, c.noSort)
}
// MsgType returns the integer uniquely identifying this message type on the

View File

@ -0,0 +1,34 @@
package lnwire
import (
"bytes"
"testing"
)
// TestReplyChannelRangeUnsorted tests that decoding a ReplyChannelRange request
// that contains duplicate or unsorted ids returns an ErrUnsortedSIDs failure.
func TestReplyChannelRangeUnsorted(t *testing.T) {
for _, test := range unsortedSidTests {
test := test
t.Run(test.name, func(t *testing.T) {
req := &ReplyChannelRange{
EncodingType: test.encType,
ShortChanIDs: test.sids,
noSort: true,
}
var b bytes.Buffer
err := req.Encode(&b, 0)
if err != nil {
t.Fatalf("unable to encode req: %v", err)
}
var req2 ReplyChannelRange
err = req2.Decode(bytes.NewReader(b.Bytes()), 0)
if _, ok := err.(ErrUnsortedSIDs); !ok {
t.Fatalf("expected ErrUnsortedSIDs, got: %T",
err)
}
})
}
}