108 lines
2.6 KiB
Go
108 lines
2.6 KiB
Go
package tlv_test
|
|
|
|
import (
|
|
"bytes"
|
|
"reflect"
|
|
"testing"
|
|
|
|
"github.com/lightningnetwork/lnd/tlv"
|
|
)
|
|
|
|
type parsedTypeTest struct {
|
|
name string
|
|
encode []tlv.Type
|
|
decode []tlv.Type
|
|
expErr error
|
|
}
|
|
|
|
// TestParsedTypes asserts that a Stream will properly return the set of types
|
|
// that it encounters when the type is known-and-decoded or unknown-and-ignored.
|
|
func TestParsedTypes(t *testing.T) {
|
|
const (
|
|
firstReqType = 0
|
|
knownType = 1
|
|
unknownType = 3
|
|
secondReqType = 4
|
|
)
|
|
|
|
tests := []parsedTypeTest{
|
|
{
|
|
name: "known optional and unknown optional",
|
|
encode: []tlv.Type{knownType, unknownType},
|
|
decode: []tlv.Type{knownType},
|
|
},
|
|
{
|
|
name: "unknown required and known optional",
|
|
encode: []tlv.Type{firstReqType, knownType},
|
|
decode: []tlv.Type{knownType},
|
|
expErr: tlv.ErrUnknownRequiredType(firstReqType),
|
|
},
|
|
{
|
|
name: "unknown required and unknown optional",
|
|
encode: []tlv.Type{unknownType, secondReqType},
|
|
expErr: tlv.ErrUnknownRequiredType(secondReqType),
|
|
},
|
|
{
|
|
name: "unknown required and known required",
|
|
encode: []tlv.Type{firstReqType, secondReqType},
|
|
decode: []tlv.Type{secondReqType},
|
|
expErr: tlv.ErrUnknownRequiredType(firstReqType),
|
|
},
|
|
{
|
|
name: "two unknown required",
|
|
encode: []tlv.Type{firstReqType, secondReqType},
|
|
expErr: tlv.ErrUnknownRequiredType(firstReqType),
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
test := test
|
|
t.Run(test.name, func(t *testing.T) {
|
|
testParsedTypes(t, test)
|
|
})
|
|
}
|
|
}
|
|
|
|
func testParsedTypes(t *testing.T, test parsedTypeTest) {
|
|
encRecords := make([]tlv.Record, 0, len(test.encode))
|
|
for _, typ := range test.encode {
|
|
encRecords = append(
|
|
encRecords, tlv.MakePrimitiveRecord(typ, new(uint64)),
|
|
)
|
|
}
|
|
|
|
decRecords := make([]tlv.Record, 0, len(test.decode))
|
|
for _, typ := range test.decode {
|
|
decRecords = append(
|
|
decRecords, tlv.MakePrimitiveRecord(typ, new(uint64)),
|
|
)
|
|
}
|
|
|
|
// Construct a stream that will encode the test's set of types.
|
|
encStream := tlv.MustNewStream(encRecords...)
|
|
|
|
var b bytes.Buffer
|
|
if err := encStream.Encode(&b); err != nil {
|
|
t.Fatalf("unable to encode stream: %v", err)
|
|
}
|
|
|
|
// Create a stream that will parse a subset of the test's types.
|
|
decStream := tlv.MustNewStream(decRecords...)
|
|
|
|
parsedTypes, err := decStream.DecodeWithParsedTypes(
|
|
bytes.NewReader(b.Bytes()),
|
|
)
|
|
if !reflect.DeepEqual(err, test.expErr) {
|
|
t.Fatalf("error mismatch, want: %v got: %v", err, test.expErr)
|
|
}
|
|
|
|
// Assert that all encoded types are included in the set of parsed
|
|
// types.
|
|
for _, typ := range test.encode {
|
|
if _, ok := parsedTypes[typ]; !ok {
|
|
t.Fatalf("encoded type %d should be in parsed types",
|
|
typ)
|
|
}
|
|
}
|
|
}
|