From e85aaa45f688f12b55f533eaa98d3c7a2cba438f Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Wed, 30 Oct 2019 21:20:29 -0700 Subject: [PATCH] tlv/stream: parse entire stream to find all required types --- tlv/stream.go | 28 ++++++++++++- tlv/stream_test.go | 100 +++++++++++++++++++++++++++++++++++---------- 2 files changed, 104 insertions(+), 24 deletions(-) diff --git a/tlv/stream.go b/tlv/stream.go index 60b40919..d104a206 100644 --- a/tlv/stream.go +++ b/tlv/stream.go @@ -162,6 +162,7 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) { var ( typ Type min Type + firstFail *Type recordIdx int overflow bool ) @@ -176,7 +177,10 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) { // We'll silence an EOF when zero bytes remain, meaning the // stream was cleanly encoded. case err == io.EOF: - return parsedTypes, nil + if firstFail == nil { + return parsedTypes, nil + } + return parsedTypes, ErrUnknownRequiredType(*firstFail) // Other unexpected errors. case err != nil: @@ -243,7 +247,27 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) { // This record type is unknown to the stream, fail if the type // is even meaning that we are required to understand it. case typ%2 == 0: - return nil, ErrUnknownRequiredType(typ) + // We'll fail immediately in the case that we aren't + // tracking the set of parsed types. + if parsedTypes == nil { + return nil, ErrUnknownRequiredType(typ) + } + + // Otherwise, we'll track the first such failure and + // allow parsing to continue. If no other types of + // errors are encountered, the first failure will be + // returned as an ErrUnknownRequiredType so that the + // full set of included types can be returned. + if firstFail == nil { + failTyp := typ + firstFail = &failTyp + } + + // With the failure type recorded, we'll simply discard + // the remainder of the record as if it were optional. + // The first failure will be returned after reaching the + // stopping condition. + fallthrough // Otherwise, the record type is unknown and is odd, discard the // number of bytes specified by length. diff --git a/tlv/stream_test.go b/tlv/stream_test.go index e9970e74..f7d54f73 100644 --- a/tlv/stream_test.go +++ b/tlv/stream_test.go @@ -2,50 +2,106 @@ package tlv_test import ( "bytes" + "reflect" "testing" "github.com/lightningnetwork/lnd/tlv" ) +type parsedTypeTest struct { + name string + encode []tlv.Type + decode []tlv.Type + expErr error +} + // TestParsedTypes asserts that a Stream will properly return the set of types // that it encounters when the type is known-and-decoded or unknown-and-ignored. func TestParsedTypes(t *testing.T) { const ( - knownType = 1 - unknownType = 3 + firstReqType = 0 + knownType = 1 + unknownType = 3 + secondReqType = 4 ) - // Construct a stream that will encode two types, one that will be known - // to the decoder and another that will be unknown. - encStream := tlv.MustNewStream( - tlv.MakePrimitiveRecord(knownType, new(uint64)), - tlv.MakePrimitiveRecord(unknownType, new(uint64)), - ) + tests := []parsedTypeTest{ + { + name: "known optional and unknown optional", + encode: []tlv.Type{knownType, unknownType}, + decode: []tlv.Type{knownType}, + }, + { + name: "unknown required and known optional", + encode: []tlv.Type{firstReqType, knownType}, + decode: []tlv.Type{knownType}, + expErr: tlv.ErrUnknownRequiredType(firstReqType), + }, + { + name: "unknown required and unknown optional", + encode: []tlv.Type{unknownType, secondReqType}, + expErr: tlv.ErrUnknownRequiredType(secondReqType), + }, + { + name: "unknown required and known required", + encode: []tlv.Type{firstReqType, secondReqType}, + decode: []tlv.Type{secondReqType}, + expErr: tlv.ErrUnknownRequiredType(firstReqType), + }, + { + name: "two unknown required", + encode: []tlv.Type{firstReqType, secondReqType}, + expErr: tlv.ErrUnknownRequiredType(firstReqType), + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + testParsedTypes(t, test) + }) + } +} + +func testParsedTypes(t *testing.T, test parsedTypeTest) { + encRecords := make([]tlv.Record, 0, len(test.encode)) + for _, typ := range test.encode { + encRecords = append( + encRecords, tlv.MakePrimitiveRecord(typ, new(uint64)), + ) + } + + decRecords := make([]tlv.Record, 0, len(test.decode)) + for _, typ := range test.decode { + decRecords = append( + decRecords, tlv.MakePrimitiveRecord(typ, new(uint64)), + ) + } + + // Construct a stream that will encode the test's set of types. + encStream := tlv.MustNewStream(encRecords...) var b bytes.Buffer if err := encStream.Encode(&b); err != nil { t.Fatalf("unable to encode stream: %v", err) } - // Create a stream that will parse only the known type. - decStream := tlv.MustNewStream( - tlv.MakePrimitiveRecord(knownType, new(uint64)), - ) + // Create a stream that will parse a subset of the test's types. + decStream := tlv.MustNewStream(decRecords...) parsedTypes, err := decStream.DecodeWithParsedTypes( bytes.NewReader(b.Bytes()), ) - if err != nil { - t.Fatalf("unable to decode stream: %v", err) + if !reflect.DeepEqual(err, test.expErr) { + t.Fatalf("error mismatch, want: %v got: %v", err, test.expErr) } - // Assert that both the known and unknown types are included in the set - // of parsed types. - if _, ok := parsedTypes[knownType]; !ok { - t.Fatalf("known type %d should be in parsed types", knownType) - } - if _, ok := parsedTypes[unknownType]; !ok { - t.Fatalf("unknown type %d should be in parsed types", - unknownType) + // Assert that all encoded types are included in the set of parsed + // types. + for _, typ := range test.encode { + if _, ok := parsedTypes[typ]; !ok { + t.Fatalf("encoded type %d should be in parsed types", + typ) + } } }