Merge pull request #3743 from cfromknecht/in-order-sids
lnwire: assert sorted short channel ids
This commit is contained in:
commit
a6ef03c777
@ -35,6 +35,19 @@ const (
|
|||||||
maxZlibBufSize = 67413630
|
maxZlibBufSize = 67413630
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrUnsortedSIDs is returned when decoding a QueryShortChannelID request whose
|
||||||
|
// items were not sorted.
|
||||||
|
type ErrUnsortedSIDs struct {
|
||||||
|
prevSID ShortChannelID
|
||||||
|
curSID ShortChannelID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error returns a human-readable description of the error.
|
||||||
|
func (e ErrUnsortedSIDs) Error() string {
|
||||||
|
return fmt.Sprintf("current sid: %v isn't greater than last sid: %v",
|
||||||
|
e.curSID, e.prevSID)
|
||||||
|
}
|
||||||
|
|
||||||
// zlibDecodeMtx is a package level mutex that we'll use in order to ensure
|
// zlibDecodeMtx is a package level mutex that we'll use in order to ensure
|
||||||
// that we'll only attempt a single zlib decoding instance at a time. This
|
// that we'll only attempt a single zlib decoding instance at a time. This
|
||||||
// allows us to also further bound our memory usage.
|
// allows us to also further bound our memory usage.
|
||||||
@ -67,6 +80,12 @@ type QueryShortChanIDs struct {
|
|||||||
|
|
||||||
// ShortChanIDs is a slice of decoded short channel ID's.
|
// ShortChanIDs is a slice of decoded short channel ID's.
|
||||||
ShortChanIDs []ShortChannelID
|
ShortChanIDs []ShortChannelID
|
||||||
|
|
||||||
|
// noSort indicates whether or not to sort the short channel ids before
|
||||||
|
// writing them out.
|
||||||
|
//
|
||||||
|
// NOTE: This should only be used during testing.
|
||||||
|
noSort bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewQueryShortChanIDs creates a new QueryShortChanIDs message.
|
// NewQueryShortChanIDs creates a new QueryShortChanIDs message.
|
||||||
@ -158,11 +177,18 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err
|
|||||||
// ID's to conclude our parsing.
|
// ID's to conclude our parsing.
|
||||||
shortChanIDs := make([]ShortChannelID, numShortChanIDs)
|
shortChanIDs := make([]ShortChannelID, numShortChanIDs)
|
||||||
bodyReader := bytes.NewReader(queryBody)
|
bodyReader := bytes.NewReader(queryBody)
|
||||||
|
var lastChanID ShortChannelID
|
||||||
for i := 0; i < numShortChanIDs; i++ {
|
for i := 0; i < numShortChanIDs; i++ {
|
||||||
if err := ReadElements(bodyReader, &shortChanIDs[i]); err != nil {
|
if err := ReadElements(bodyReader, &shortChanIDs[i]); err != nil {
|
||||||
return 0, nil, fmt.Errorf("unable to parse "+
|
return 0, nil, fmt.Errorf("unable to parse "+
|
||||||
"short chan ID: %v", err)
|
"short chan ID: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cid := shortChanIDs[i]
|
||||||
|
if cid.ToUint64() <= lastChanID.ToUint64() {
|
||||||
|
return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
|
||||||
|
}
|
||||||
|
lastChanID = cid
|
||||||
}
|
}
|
||||||
|
|
||||||
return encodingType, shortChanIDs, nil
|
return encodingType, shortChanIDs, nil
|
||||||
@ -224,9 +250,7 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err
|
|||||||
// within the encoding, and if violated can aide us in
|
// within the encoding, and if violated can aide us in
|
||||||
// detecting malicious payloads.
|
// detecting malicious payloads.
|
||||||
if cid.ToUint64() <= lastChanID.ToUint64() {
|
if cid.ToUint64() <= lastChanID.ToUint64() {
|
||||||
return 0, nil, fmt.Errorf("current sid of %v "+
|
return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
|
||||||
"isn't greater than last sid of %v", cid,
|
|
||||||
lastChanID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
lastChanID = cid
|
lastChanID = cid
|
||||||
@ -253,20 +277,23 @@ func (q *QueryShortChanIDs) Encode(w io.Writer, pver uint32) error {
|
|||||||
|
|
||||||
// Base on our encoding type, we'll write out the set of short channel
|
// Base on our encoding type, we'll write out the set of short channel
|
||||||
// ID's.
|
// ID's.
|
||||||
return encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs)
|
return encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs, q.noSort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// encodeShortChanIDs encodes the passed short channel ID's into the passed
|
// encodeShortChanIDs encodes the passed short channel ID's into the passed
|
||||||
// io.Writer, respecting the specified encoding type.
|
// io.Writer, respecting the specified encoding type.
|
||||||
func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding,
|
func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding,
|
||||||
shortChanIDs []ShortChannelID) error {
|
shortChanIDs []ShortChannelID, noSort bool) error {
|
||||||
|
|
||||||
// For both of the current encoding types, the channel ID's are to be
|
// For both of the current encoding types, the channel ID's are to be
|
||||||
// sorted in place, so we'll do that now.
|
// sorted in place, so we'll do that now. The sorting is applied unless
|
||||||
sort.Slice(shortChanIDs, func(i, j int) bool {
|
// we were specifically requested not to for testing purposes.
|
||||||
return shortChanIDs[i].ToUint64() <
|
if !noSort {
|
||||||
shortChanIDs[j].ToUint64()
|
sort.Slice(shortChanIDs, func(i, j int) bool {
|
||||||
})
|
return shortChanIDs[i].ToUint64() <
|
||||||
|
shortChanIDs[j].ToUint64()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
switch encodingType {
|
switch encodingType {
|
||||||
|
|
||||||
|
75
lnwire/query_short_chan_ids_test.go
Normal file
75
lnwire/query_short_chan_ids_test.go
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
package lnwire
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type unsortedSidTest struct {
|
||||||
|
name string
|
||||||
|
encType ShortChanIDEncoding
|
||||||
|
sids []ShortChannelID
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
unsortedSids = []ShortChannelID{
|
||||||
|
NewShortChanIDFromInt(4),
|
||||||
|
NewShortChanIDFromInt(3),
|
||||||
|
}
|
||||||
|
|
||||||
|
duplicateSids = []ShortChannelID{
|
||||||
|
NewShortChanIDFromInt(3),
|
||||||
|
NewShortChanIDFromInt(3),
|
||||||
|
}
|
||||||
|
|
||||||
|
unsortedSidTests = []unsortedSidTest{
|
||||||
|
{
|
||||||
|
name: "plain unsorted",
|
||||||
|
encType: EncodingSortedPlain,
|
||||||
|
sids: unsortedSids,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "plain duplicate",
|
||||||
|
encType: EncodingSortedPlain,
|
||||||
|
sids: duplicateSids,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zlib unsorted",
|
||||||
|
encType: EncodingSortedZlib,
|
||||||
|
sids: unsortedSids,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zlib duplicate",
|
||||||
|
encType: EncodingSortedZlib,
|
||||||
|
sids: duplicateSids,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestQueryShortChanIDsUnsorted tests that decoding a QueryShortChanID request
|
||||||
|
// that contains duplicate or unsorted ids returns an ErrUnsortedSIDs failure.
|
||||||
|
func TestQueryShortChanIDsUnsorted(t *testing.T) {
|
||||||
|
for _, test := range unsortedSidTests {
|
||||||
|
test := test
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
req := &QueryShortChanIDs{
|
||||||
|
EncodingType: test.encType,
|
||||||
|
ShortChanIDs: test.sids,
|
||||||
|
noSort: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
err := req.Encode(&b, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to encode req: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var req2 QueryShortChanIDs
|
||||||
|
err = req2.Decode(bytes.NewReader(b.Bytes()), 0)
|
||||||
|
if _, ok := err.(ErrUnsortedSIDs); !ok {
|
||||||
|
t.Fatalf("expected ErrUnsortedSIDs, got: %T",
|
||||||
|
err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -21,6 +21,12 @@ type ReplyChannelRange struct {
|
|||||||
|
|
||||||
// ShortChanIDs is a slice of decoded short channel ID's.
|
// ShortChanIDs is a slice of decoded short channel ID's.
|
||||||
ShortChanIDs []ShortChannelID
|
ShortChanIDs []ShortChannelID
|
||||||
|
|
||||||
|
// noSort indicates whether or not to sort the short channel ids before
|
||||||
|
// writing them out.
|
||||||
|
//
|
||||||
|
// NOTE: This should only be used for testing.
|
||||||
|
noSort bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewReplyChannelRange creates a new empty ReplyChannelRange message.
|
// NewReplyChannelRange creates a new empty ReplyChannelRange message.
|
||||||
@ -64,7 +70,7 @@ func (c *ReplyChannelRange) Encode(w io.Writer, pver uint32) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs)
|
return encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, c.noSort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MsgType returns the integer uniquely identifying this message type on the
|
// MsgType returns the integer uniquely identifying this message type on the
|
||||||
|
34
lnwire/reply_channel_range_test.go
Normal file
34
lnwire/reply_channel_range_test.go
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
package lnwire
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestReplyChannelRangeUnsorted tests that decoding a ReplyChannelRange request
|
||||||
|
// that contains duplicate or unsorted ids returns an ErrUnsortedSIDs failure.
|
||||||
|
func TestReplyChannelRangeUnsorted(t *testing.T) {
|
||||||
|
for _, test := range unsortedSidTests {
|
||||||
|
test := test
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
req := &ReplyChannelRange{
|
||||||
|
EncodingType: test.encType,
|
||||||
|
ShortChanIDs: test.sids,
|
||||||
|
noSort: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
err := req.Encode(&b, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to encode req: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var req2 ReplyChannelRange
|
||||||
|
err = req2.Decode(bytes.NewReader(b.Bytes()), 0)
|
||||||
|
if _, ok := err.(ErrUnsortedSIDs); !ok {
|
||||||
|
t.Fatalf("expected ErrUnsortedSIDs, got: %T",
|
||||||
|
err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user