Merge pull request #3743 from cfromknecht/in-order-sids
lnwire: assert sorted short channel ids
This commit is contained in:
commit
a6ef03c777
@ -35,6 +35,19 @@ const (
|
||||
maxZlibBufSize = 67413630
|
||||
)
|
||||
|
||||
// ErrUnsortedSIDs is returned when decoding a QueryShortChannelID request whose
|
||||
// items were not sorted.
|
||||
type ErrUnsortedSIDs struct {
|
||||
prevSID ShortChannelID
|
||||
curSID ShortChannelID
|
||||
}
|
||||
|
||||
// Error returns a human-readable description of the error.
|
||||
func (e ErrUnsortedSIDs) Error() string {
|
||||
return fmt.Sprintf("current sid: %v isn't greater than last sid: %v",
|
||||
e.curSID, e.prevSID)
|
||||
}
|
||||
|
||||
// zlibDecodeMtx is a package level mutex that we'll use in order to ensure
|
||||
// that we'll only attempt a single zlib decoding instance at a time. This
|
||||
// allows us to also further bound our memory usage.
|
||||
@ -67,6 +80,12 @@ type QueryShortChanIDs struct {
|
||||
|
||||
// ShortChanIDs is a slice of decoded short channel ID's.
|
||||
ShortChanIDs []ShortChannelID
|
||||
|
||||
// noSort indicates whether or not to sort the short channel ids before
|
||||
// writing them out.
|
||||
//
|
||||
// NOTE: This should only be used during testing.
|
||||
noSort bool
|
||||
}
|
||||
|
||||
// NewQueryShortChanIDs creates a new QueryShortChanIDs message.
|
||||
@ -158,11 +177,18 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err
|
||||
// ID's to conclude our parsing.
|
||||
shortChanIDs := make([]ShortChannelID, numShortChanIDs)
|
||||
bodyReader := bytes.NewReader(queryBody)
|
||||
var lastChanID ShortChannelID
|
||||
for i := 0; i < numShortChanIDs; i++ {
|
||||
if err := ReadElements(bodyReader, &shortChanIDs[i]); err != nil {
|
||||
return 0, nil, fmt.Errorf("unable to parse "+
|
||||
"short chan ID: %v", err)
|
||||
}
|
||||
|
||||
cid := shortChanIDs[i]
|
||||
if cid.ToUint64() <= lastChanID.ToUint64() {
|
||||
return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
|
||||
}
|
||||
lastChanID = cid
|
||||
}
|
||||
|
||||
return encodingType, shortChanIDs, nil
|
||||
@ -224,9 +250,7 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err
|
||||
// within the encoding, and if violated can aide us in
|
||||
// detecting malicious payloads.
|
||||
if cid.ToUint64() <= lastChanID.ToUint64() {
|
||||
return 0, nil, fmt.Errorf("current sid of %v "+
|
||||
"isn't greater than last sid of %v", cid,
|
||||
lastChanID)
|
||||
return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
|
||||
}
|
||||
|
||||
lastChanID = cid
|
||||
@ -253,20 +277,23 @@ func (q *QueryShortChanIDs) Encode(w io.Writer, pver uint32) error {
|
||||
|
||||
// Base on our encoding type, we'll write out the set of short channel
|
||||
// ID's.
|
||||
return encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs)
|
||||
return encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs, q.noSort)
|
||||
}
|
||||
|
||||
// encodeShortChanIDs encodes the passed short channel ID's into the passed
|
||||
// io.Writer, respecting the specified encoding type.
|
||||
func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding,
|
||||
shortChanIDs []ShortChannelID) error {
|
||||
shortChanIDs []ShortChannelID, noSort bool) error {
|
||||
|
||||
// For both of the current encoding types, the channel ID's are to be
|
||||
// sorted in place, so we'll do that now.
|
||||
sort.Slice(shortChanIDs, func(i, j int) bool {
|
||||
return shortChanIDs[i].ToUint64() <
|
||||
shortChanIDs[j].ToUint64()
|
||||
})
|
||||
// sorted in place, so we'll do that now. The sorting is applied unless
|
||||
// we were specifically requested not to for testing purposes.
|
||||
if !noSort {
|
||||
sort.Slice(shortChanIDs, func(i, j int) bool {
|
||||
return shortChanIDs[i].ToUint64() <
|
||||
shortChanIDs[j].ToUint64()
|
||||
})
|
||||
}
|
||||
|
||||
switch encodingType {
|
||||
|
||||
|
75
lnwire/query_short_chan_ids_test.go
Normal file
75
lnwire/query_short_chan_ids_test.go
Normal file
@ -0,0 +1,75 @@
|
||||
package lnwire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type unsortedSidTest struct {
|
||||
name string
|
||||
encType ShortChanIDEncoding
|
||||
sids []ShortChannelID
|
||||
}
|
||||
|
||||
var (
|
||||
unsortedSids = []ShortChannelID{
|
||||
NewShortChanIDFromInt(4),
|
||||
NewShortChanIDFromInt(3),
|
||||
}
|
||||
|
||||
duplicateSids = []ShortChannelID{
|
||||
NewShortChanIDFromInt(3),
|
||||
NewShortChanIDFromInt(3),
|
||||
}
|
||||
|
||||
unsortedSidTests = []unsortedSidTest{
|
||||
{
|
||||
name: "plain unsorted",
|
||||
encType: EncodingSortedPlain,
|
||||
sids: unsortedSids,
|
||||
},
|
||||
{
|
||||
name: "plain duplicate",
|
||||
encType: EncodingSortedPlain,
|
||||
sids: duplicateSids,
|
||||
},
|
||||
{
|
||||
name: "zlib unsorted",
|
||||
encType: EncodingSortedZlib,
|
||||
sids: unsortedSids,
|
||||
},
|
||||
{
|
||||
name: "zlib duplicate",
|
||||
encType: EncodingSortedZlib,
|
||||
sids: duplicateSids,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// TestQueryShortChanIDsUnsorted tests that decoding a QueryShortChanID request
|
||||
// that contains duplicate or unsorted ids returns an ErrUnsortedSIDs failure.
|
||||
func TestQueryShortChanIDsUnsorted(t *testing.T) {
|
||||
for _, test := range unsortedSidTests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
req := &QueryShortChanIDs{
|
||||
EncodingType: test.encType,
|
||||
ShortChanIDs: test.sids,
|
||||
noSort: true,
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
err := req.Encode(&b, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to encode req: %v", err)
|
||||
}
|
||||
|
||||
var req2 QueryShortChanIDs
|
||||
err = req2.Decode(bytes.NewReader(b.Bytes()), 0)
|
||||
if _, ok := err.(ErrUnsortedSIDs); !ok {
|
||||
t.Fatalf("expected ErrUnsortedSIDs, got: %T",
|
||||
err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -21,6 +21,12 @@ type ReplyChannelRange struct {
|
||||
|
||||
// ShortChanIDs is a slice of decoded short channel ID's.
|
||||
ShortChanIDs []ShortChannelID
|
||||
|
||||
// noSort indicates whether or not to sort the short channel ids before
|
||||
// writing them out.
|
||||
//
|
||||
// NOTE: This should only be used for testing.
|
||||
noSort bool
|
||||
}
|
||||
|
||||
// NewReplyChannelRange creates a new empty ReplyChannelRange message.
|
||||
@ -64,7 +70,7 @@ func (c *ReplyChannelRange) Encode(w io.Writer, pver uint32) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs)
|
||||
return encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, c.noSort)
|
||||
}
|
||||
|
||||
// MsgType returns the integer uniquely identifying this message type on the
|
||||
|
34
lnwire/reply_channel_range_test.go
Normal file
34
lnwire/reply_channel_range_test.go
Normal file
@ -0,0 +1,34 @@
|
||||
package lnwire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestReplyChannelRangeUnsorted tests that decoding a ReplyChannelRange request
|
||||
// that contains duplicate or unsorted ids returns an ErrUnsortedSIDs failure.
|
||||
func TestReplyChannelRangeUnsorted(t *testing.T) {
|
||||
for _, test := range unsortedSidTests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
req := &ReplyChannelRange{
|
||||
EncodingType: test.encType,
|
||||
ShortChanIDs: test.sids,
|
||||
noSort: true,
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
err := req.Encode(&b, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to encode req: %v", err)
|
||||
}
|
||||
|
||||
var req2 ReplyChannelRange
|
||||
err = req2.Decode(bytes.NewReader(b.Bytes()), 0)
|
||||
if _, ok := err.(ErrUnsortedSIDs); !ok {
|
||||
t.Fatalf("expected ErrUnsortedSIDs, got: %T",
|
||||
err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user