lnwire/reply_channel_range: assert sorted encodings

This commit is contained in:
Conner Fromknecht 2019-11-20 01:57:59 -08:00
parent 3cc235a349
commit 2a6e41236c
No known key found for this signature in database
GPG Key ID: E7D737B67FA592C7
2 changed files with 41 additions and 1 deletions

@ -21,6 +21,12 @@ type ReplyChannelRange struct {
// ShortChanIDs is a slice of decoded short channel ID's. // ShortChanIDs is a slice of decoded short channel ID's.
ShortChanIDs []ShortChannelID ShortChanIDs []ShortChannelID
// noSort indicates whether or not to sort the short channel ids before
// writing them out.
//
// NOTE: This should only be used for testing.
noSort bool
} }
// NewReplyChannelRange creates a new empty ReplyChannelRange message. // NewReplyChannelRange creates a new empty ReplyChannelRange message.
@ -64,7 +70,7 @@ func (c *ReplyChannelRange) Encode(w io.Writer, pver uint32) error {
return err return err
} }
return encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, false) return encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, c.noSort)
} }
// MsgType returns the integer uniquely identifying this message type on the // MsgType returns the integer uniquely identifying this message type on the

@ -0,0 +1,34 @@
package lnwire
import (
"bytes"
"testing"
)
// TestReplyChannelRangeUnsorted tests that decoding a ReplyChannelRange request
// that contains duplicate or unsorted ids returns an ErrUnsortedSIDs failure.
func TestReplyChannelRangeUnsorted(t *testing.T) {
for _, test := range unsortedSidTests {
test := test
t.Run(test.name, func(t *testing.T) {
req := &ReplyChannelRange{
EncodingType: test.encType,
ShortChanIDs: test.sids,
noSort: true,
}
var b bytes.Buffer
err := req.Encode(&b, 0)
if err != nil {
t.Fatalf("unable to encode req: %v", err)
}
var req2 ReplyChannelRange
err = req2.Decode(bytes.NewReader(b.Bytes()), 0)
if _, ok := err.(ErrUnsortedSIDs); !ok {
t.Fatalf("expected ErrUnsortedSIDs, got: %T",
err)
}
})
}
}