lnwire: explicitly handle empty list when encoding short chan IDs
Before this commit, both writing and reading an encoded empty set of short channel IDs from the wire would fail. Prior to this commit, we treated decoding an empty set as a caller error, and failed to write out the zlib encoding of an empty set in a way that us and the other implementations were able to read. To fix this, rather than giving zlib an empty buffer to write out (which results in an encoding with the zlib header data and the rest), we just write a blank slice. When decoding, if we have an empty query body, then we'll return a `nil` slice. With the above changes, we'll now always write out an empty short channel ID set as: ``` 0001 (1 byte follows) || <encoding_type> ``` A new test has also been added to exercise this case for both known encoding types.
This commit is contained in:
parent
83ff6a59d4
commit
17200afc57
@ -132,7 +132,7 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
if numBytesResp == 0 {
|
if numBytesResp == 0 {
|
||||||
return 0, nil, fmt.Errorf("No encoding type specified")
|
return 0, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
queryBody := make([]byte, numBytesResp)
|
queryBody := make([]byte, numBytesResp)
|
||||||
@ -148,6 +148,13 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err
|
|||||||
// as that was just the encoding type.
|
// as that was just the encoding type.
|
||||||
queryBody = queryBody[1:]
|
queryBody = queryBody[1:]
|
||||||
|
|
||||||
|
// At this point, if there's no body remaining, then only the encoding
|
||||||
|
// type was specified, meaning that there're no further bytes to be
|
||||||
|
// parsed.
|
||||||
|
if len(queryBody) == 0 {
|
||||||
|
return encodingType, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Otherwise, depending on the encoding type, we'll decode the encode
|
// Otherwise, depending on the encoding type, we'll decode the encode
|
||||||
// short channel ID's in a different manner.
|
// short channel ID's in a different manner.
|
||||||
switch encodingType {
|
switch encodingType {
|
||||||
@ -338,27 +345,43 @@ func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding,
|
|||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
zlibWriter := zlib.NewWriter(&buf)
|
zlibWriter := zlib.NewWriter(&buf)
|
||||||
|
|
||||||
// Next, we'll write out all the channel ID's directly into the
|
// If we don't have anything at all to write, then we'll write
|
||||||
// zlib writer, which will do compressing on the fly.
|
// an empty payload so we don't include things like the zlib
|
||||||
for _, chanID := range shortChanIDs {
|
// header when the remote party is expecting no actual short
|
||||||
err := WriteElements(zlibWriter, chanID)
|
// channel IDs.
|
||||||
if err != nil {
|
var compressedPayload []byte
|
||||||
return fmt.Errorf("unable to write short chan "+
|
if len(shortChanIDs) > 0 {
|
||||||
"ID: %v", err)
|
// Next, we'll write out all the channel ID's directly
|
||||||
|
// into the zlib writer, which will do compressing on
|
||||||
|
// the fly.
|
||||||
|
for _, chanID := range shortChanIDs {
|
||||||
|
err := WriteElements(zlibWriter, chanID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to write short chan "+
|
||||||
|
"ID: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Now that we've written all the elements, we'll ensure the
|
// Now that we've written all the elements, we'll
|
||||||
// compressed stream is written to the underlying buffer.
|
// ensure the compressed stream is written to the
|
||||||
if err := zlibWriter.Close(); err != nil {
|
// underlying buffer.
|
||||||
return fmt.Errorf("unable to finalize "+
|
if err := zlibWriter.Close(); err != nil {
|
||||||
"compression: %v", err)
|
return fmt.Errorf("unable to finalize "+
|
||||||
|
"compression: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
compressedPayload = buf.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now that we have all the items compressed, we can compute
|
// Now that we have all the items compressed, we can compute
|
||||||
// what the total payload size will be. We add one to account
|
// what the total payload size will be. We add one to account
|
||||||
// for the byte to encode the type.
|
// for the byte to encode the type.
|
||||||
compressedPayload := buf.Bytes()
|
//
|
||||||
|
// If we don't have any actual bytes to write, then we'll end
|
||||||
|
// up emitting one byte for the length, followed by the
|
||||||
|
// encoding type, and nothing more. The spec isn't 100% clear
|
||||||
|
// in this area, but we do this as this is what most of the
|
||||||
|
// other implementations do.
|
||||||
numBytesBody := len(compressedPayload) + 1
|
numBytesBody := len(compressedPayload) + 1
|
||||||
|
|
||||||
// Finally, we can write out the number of bytes, the
|
// Finally, we can write out the number of bytes, the
|
||||||
|
@ -2,7 +2,11 @@ package lnwire
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/hex"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/davecgh/go-spew/spew"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestReplyChannelRangeUnsorted tests that decoding a ReplyChannelRange request
|
// TestReplyChannelRangeUnsorted tests that decoding a ReplyChannelRange request
|
||||||
@ -32,3 +36,72 @@ func TestReplyChannelRangeUnsorted(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestReplyChannelRangeEmpty tests encoding and decoding a ReplyChannelRange
|
||||||
|
// that doesn't contain any channel results.
|
||||||
|
func TestReplyChannelRangeEmpty(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
emptyChannelsTests := []struct {
|
||||||
|
name string
|
||||||
|
encType ShortChanIDEncoding
|
||||||
|
encodedHex string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty plain encoding",
|
||||||
|
encType: EncodingSortedPlain,
|
||||||
|
encodedHex: "000000000000000000000000000000000000000" +
|
||||||
|
"00000000000000000000000000000000100000002" +
|
||||||
|
"01000100",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty zlib encoding",
|
||||||
|
encType: EncodingSortedZlib,
|
||||||
|
encodedHex: "00000000000000000000000000000000000000" +
|
||||||
|
"0000000000000000000000000000000001000000" +
|
||||||
|
"0201000101",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range emptyChannelsTests {
|
||||||
|
test := test
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
req := ReplyChannelRange{
|
||||||
|
QueryChannelRange: QueryChannelRange{
|
||||||
|
FirstBlockHeight: 1,
|
||||||
|
NumBlocks: 2,
|
||||||
|
},
|
||||||
|
Complete: 1,
|
||||||
|
EncodingType: test.encType,
|
||||||
|
ShortChanIDs: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
// First decode the hex string in the test case into a
|
||||||
|
// new ReplyChannelRange message. It should be
|
||||||
|
// identical to the one created above.
|
||||||
|
var req2 ReplyChannelRange
|
||||||
|
b, _ := hex.DecodeString(test.encodedHex)
|
||||||
|
err := req2.Decode(bytes.NewReader(b), 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to decode req: %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(req, req2) {
|
||||||
|
t.Fatalf("requests don't match: expected %v got %v",
|
||||||
|
spew.Sdump(req), spew.Sdump(req2))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next, we go in the reverse direction: encode the
|
||||||
|
// request created above, and assert that it matches
|
||||||
|
// the raw byte encoding.
|
||||||
|
var b2 bytes.Buffer
|
||||||
|
err = req.Encode(&b2, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to encode req: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(b, b2.Bytes()) {
|
||||||
|
t.Fatalf("encoded requests don't match: expected %x got %x",
|
||||||
|
b, b2.Bytes())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user