tlv/stream: parse entire stream to find all required types

This commit is contained in:
Conner Fromknecht 2019-10-30 21:20:29 -07:00
parent d08e8ddd61
commit e85aaa45f6
No known key found for this signature in database
GPG Key ID: E7D737B67FA592C7
2 changed files with 104 additions and 24 deletions

@ -162,6 +162,7 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
var (
typ Type
min Type
firstFail *Type
recordIdx int
overflow bool
)
@ -176,7 +177,10 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
// We'll silence an EOF when zero bytes remain, meaning the
// stream was cleanly encoded.
case err == io.EOF:
if firstFail == nil {
return parsedTypes, nil
}
return parsedTypes, ErrUnknownRequiredType(*firstFail)
// Other unexpected errors.
case err != nil:
@ -243,7 +247,27 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
// This record type is unknown to the stream, fail if the type
// is even meaning that we are required to understand it.
case typ%2 == 0:
// We'll fail immediately in the case that we aren't
// tracking the set of parsed types.
if parsedTypes == nil {
return nil, ErrUnknownRequiredType(typ)
}
// Otherwise, we'll track the first such failure and
// allow parsing to continue. If no other types of
// errors are encountered, the first failure will be
// returned as an ErrUnknownRequiredType so that the
// full set of included types can be returned.
if firstFail == nil {
failTyp := typ
firstFail = &failTyp
}
// With the failure type recorded, we'll simply discard
// the remainder of the record as if it were optional.
// The first failure will be returned after reaching the
// stopping condition.
fallthrough
// Otherwise, the record type is unknown and is odd, discard the
// number of bytes specified by length.

@ -2,50 +2,106 @@ package tlv_test
import (
"bytes"
"reflect"
"testing"
"github.com/lightningnetwork/lnd/tlv"
)
type parsedTypeTest struct {
name string
encode []tlv.Type
decode []tlv.Type
expErr error
}
// TestParsedTypes asserts that a Stream will properly return the set of types
// that it encounters when the type is known-and-decoded or unknown-and-ignored.
func TestParsedTypes(t *testing.T) {
const (
firstReqType = 0
knownType = 1
unknownType = 3
secondReqType = 4
)
// Construct a stream that will encode two types, one that will be known
// to the decoder and another that will be unknown.
encStream := tlv.MustNewStream(
tlv.MakePrimitiveRecord(knownType, new(uint64)),
tlv.MakePrimitiveRecord(unknownType, new(uint64)),
tests := []parsedTypeTest{
{
name: "known optional and unknown optional",
encode: []tlv.Type{knownType, unknownType},
decode: []tlv.Type{knownType},
},
{
name: "unknown required and known optional",
encode: []tlv.Type{firstReqType, knownType},
decode: []tlv.Type{knownType},
expErr: tlv.ErrUnknownRequiredType(firstReqType),
},
{
name: "unknown required and unknown optional",
encode: []tlv.Type{unknownType, secondReqType},
expErr: tlv.ErrUnknownRequiredType(secondReqType),
},
{
name: "unknown required and known required",
encode: []tlv.Type{firstReqType, secondReqType},
decode: []tlv.Type{secondReqType},
expErr: tlv.ErrUnknownRequiredType(firstReqType),
},
{
name: "two unknown required",
encode: []tlv.Type{firstReqType, secondReqType},
expErr: tlv.ErrUnknownRequiredType(firstReqType),
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
testParsedTypes(t, test)
})
}
}
func testParsedTypes(t *testing.T, test parsedTypeTest) {
encRecords := make([]tlv.Record, 0, len(test.encode))
for _, typ := range test.encode {
encRecords = append(
encRecords, tlv.MakePrimitiveRecord(typ, new(uint64)),
)
}
decRecords := make([]tlv.Record, 0, len(test.decode))
for _, typ := range test.decode {
decRecords = append(
decRecords, tlv.MakePrimitiveRecord(typ, new(uint64)),
)
}
// Construct a stream that will encode the test's set of types.
encStream := tlv.MustNewStream(encRecords...)
var b bytes.Buffer
if err := encStream.Encode(&b); err != nil {
t.Fatalf("unable to encode stream: %v", err)
}
// Create a stream that will parse only the known type.
decStream := tlv.MustNewStream(
tlv.MakePrimitiveRecord(knownType, new(uint64)),
)
// Create a stream that will parse a subset of the test's types.
decStream := tlv.MustNewStream(decRecords...)
parsedTypes, err := decStream.DecodeWithParsedTypes(
bytes.NewReader(b.Bytes()),
)
if err != nil {
t.Fatalf("unable to decode stream: %v", err)
if !reflect.DeepEqual(err, test.expErr) {
t.Fatalf("error mismatch, want: %v got: %v", err, test.expErr)
}
// Assert that both the known and unknown types are included in the set
// of parsed types.
if _, ok := parsedTypes[knownType]; !ok {
t.Fatalf("known type %d should be in parsed types", knownType)
}
if _, ok := parsedTypes[unknownType]; !ok {
t.Fatalf("unknown type %d should be in parsed types",
unknownType)
// Assert that all encoded types are included in the set of parsed
// types.
for _, typ := range test.encode {
if _, ok := parsedTypes[typ]; !ok {
t.Fatalf("encoded type %d should be in parsed types",
typ)
}
}
}