tlv/stream: parse entire stream to find all required types

This commit is contained in:
Conner Fromknecht 2019-10-30 21:20:29 -07:00
parent d08e8ddd61
commit e85aaa45f6
No known key found for this signature in database
GPG Key ID: E7D737B67FA592C7
2 changed files with 104 additions and 24 deletions

@ -162,6 +162,7 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
var ( var (
typ Type typ Type
min Type min Type
firstFail *Type
recordIdx int recordIdx int
overflow bool overflow bool
) )
@ -176,7 +177,10 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
// We'll silence an EOF when zero bytes remain, meaning the // We'll silence an EOF when zero bytes remain, meaning the
// stream was cleanly encoded. // stream was cleanly encoded.
case err == io.EOF: case err == io.EOF:
return parsedTypes, nil if firstFail == nil {
return parsedTypes, nil
}
return parsedTypes, ErrUnknownRequiredType(*firstFail)
// Other unexpected errors. // Other unexpected errors.
case err != nil: case err != nil:
@ -243,7 +247,27 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
// This record type is unknown to the stream, fail if the type // This record type is unknown to the stream, fail if the type
// is even meaning that we are required to understand it. // is even meaning that we are required to understand it.
case typ%2 == 0: case typ%2 == 0:
return nil, ErrUnknownRequiredType(typ) // We'll fail immediately in the case that we aren't
// tracking the set of parsed types.
if parsedTypes == nil {
return nil, ErrUnknownRequiredType(typ)
}
// Otherwise, we'll track the first such failure and
// allow parsing to continue. If no other types of
// errors are encountered, the first failure will be
// returned as an ErrUnknownRequiredType so that the
// full set of included types can be returned.
if firstFail == nil {
failTyp := typ
firstFail = &failTyp
}
// With the failure type recorded, we'll simply discard
// the remainder of the record as if it were optional.
// The first failure will be returned after reaching the
// stopping condition.
fallthrough
// Otherwise, the record type is unknown and is odd, discard the // Otherwise, the record type is unknown and is odd, discard the
// number of bytes specified by length. // number of bytes specified by length.

@ -2,50 +2,106 @@ package tlv_test
import ( import (
"bytes" "bytes"
"reflect"
"testing" "testing"
"github.com/lightningnetwork/lnd/tlv" "github.com/lightningnetwork/lnd/tlv"
) )
type parsedTypeTest struct {
name string
encode []tlv.Type
decode []tlv.Type
expErr error
}
// TestParsedTypes asserts that a Stream will properly return the set of types // TestParsedTypes asserts that a Stream will properly return the set of types
// that it encounters when the type is known-and-decoded or unknown-and-ignored. // that it encounters when the type is known-and-decoded or unknown-and-ignored.
func TestParsedTypes(t *testing.T) { func TestParsedTypes(t *testing.T) {
const ( const (
knownType = 1 firstReqType = 0
unknownType = 3 knownType = 1
unknownType = 3
secondReqType = 4
) )
// Construct a stream that will encode two types, one that will be known tests := []parsedTypeTest{
// to the decoder and another that will be unknown. {
encStream := tlv.MustNewStream( name: "known optional and unknown optional",
tlv.MakePrimitiveRecord(knownType, new(uint64)), encode: []tlv.Type{knownType, unknownType},
tlv.MakePrimitiveRecord(unknownType, new(uint64)), decode: []tlv.Type{knownType},
) },
{
name: "unknown required and known optional",
encode: []tlv.Type{firstReqType, knownType},
decode: []tlv.Type{knownType},
expErr: tlv.ErrUnknownRequiredType(firstReqType),
},
{
name: "unknown required and unknown optional",
encode: []tlv.Type{unknownType, secondReqType},
expErr: tlv.ErrUnknownRequiredType(secondReqType),
},
{
name: "unknown required and known required",
encode: []tlv.Type{firstReqType, secondReqType},
decode: []tlv.Type{secondReqType},
expErr: tlv.ErrUnknownRequiredType(firstReqType),
},
{
name: "two unknown required",
encode: []tlv.Type{firstReqType, secondReqType},
expErr: tlv.ErrUnknownRequiredType(firstReqType),
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
testParsedTypes(t, test)
})
}
}
func testParsedTypes(t *testing.T, test parsedTypeTest) {
encRecords := make([]tlv.Record, 0, len(test.encode))
for _, typ := range test.encode {
encRecords = append(
encRecords, tlv.MakePrimitiveRecord(typ, new(uint64)),
)
}
decRecords := make([]tlv.Record, 0, len(test.decode))
for _, typ := range test.decode {
decRecords = append(
decRecords, tlv.MakePrimitiveRecord(typ, new(uint64)),
)
}
// Construct a stream that will encode the test's set of types.
encStream := tlv.MustNewStream(encRecords...)
var b bytes.Buffer var b bytes.Buffer
if err := encStream.Encode(&b); err != nil { if err := encStream.Encode(&b); err != nil {
t.Fatalf("unable to encode stream: %v", err) t.Fatalf("unable to encode stream: %v", err)
} }
// Create a stream that will parse only the known type. // Create a stream that will parse a subset of the test's types.
decStream := tlv.MustNewStream( decStream := tlv.MustNewStream(decRecords...)
tlv.MakePrimitiveRecord(knownType, new(uint64)),
)
parsedTypes, err := decStream.DecodeWithParsedTypes( parsedTypes, err := decStream.DecodeWithParsedTypes(
bytes.NewReader(b.Bytes()), bytes.NewReader(b.Bytes()),
) )
if err != nil { if !reflect.DeepEqual(err, test.expErr) {
t.Fatalf("unable to decode stream: %v", err) t.Fatalf("error mismatch, want: %v got: %v", err, test.expErr)
} }
// Assert that both the known and unknown types are included in the set // Assert that all encoded types are included in the set of parsed
// of parsed types. // types.
if _, ok := parsedTypes[knownType]; !ok { for _, typ := range test.encode {
t.Fatalf("known type %d should be in parsed types", knownType) if _, ok := parsedTypes[typ]; !ok {
} t.Fatalf("encoded type %d should be in parsed types",
if _, ok := parsedTypes[unknownType]; !ok { typ)
t.Fatalf("unknown type %d should be in parsed types", }
unknownType)
} }
} }