tlv/stream: parse entire stream to find all required types
This commit is contained in:
parent
d08e8ddd61
commit
e85aaa45f6
@ -162,6 +162,7 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
|
|||||||
var (
|
var (
|
||||||
typ Type
|
typ Type
|
||||||
min Type
|
min Type
|
||||||
|
firstFail *Type
|
||||||
recordIdx int
|
recordIdx int
|
||||||
overflow bool
|
overflow bool
|
||||||
)
|
)
|
||||||
@ -176,7 +177,10 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
|
|||||||
// We'll silence an EOF when zero bytes remain, meaning the
|
// We'll silence an EOF when zero bytes remain, meaning the
|
||||||
// stream was cleanly encoded.
|
// stream was cleanly encoded.
|
||||||
case err == io.EOF:
|
case err == io.EOF:
|
||||||
return parsedTypes, nil
|
if firstFail == nil {
|
||||||
|
return parsedTypes, nil
|
||||||
|
}
|
||||||
|
return parsedTypes, ErrUnknownRequiredType(*firstFail)
|
||||||
|
|
||||||
// Other unexpected errors.
|
// Other unexpected errors.
|
||||||
case err != nil:
|
case err != nil:
|
||||||
@ -243,7 +247,27 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
|
|||||||
// This record type is unknown to the stream, fail if the type
|
// This record type is unknown to the stream, fail if the type
|
||||||
// is even meaning that we are required to understand it.
|
// is even meaning that we are required to understand it.
|
||||||
case typ%2 == 0:
|
case typ%2 == 0:
|
||||||
return nil, ErrUnknownRequiredType(typ)
|
// We'll fail immediately in the case that we aren't
|
||||||
|
// tracking the set of parsed types.
|
||||||
|
if parsedTypes == nil {
|
||||||
|
return nil, ErrUnknownRequiredType(typ)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, we'll track the first such failure and
|
||||||
|
// allow parsing to continue. If no other types of
|
||||||
|
// errors are encountered, the first failure will be
|
||||||
|
// returned as an ErrUnknownRequiredType so that the
|
||||||
|
// full set of included types can be returned.
|
||||||
|
if firstFail == nil {
|
||||||
|
failTyp := typ
|
||||||
|
firstFail = &failTyp
|
||||||
|
}
|
||||||
|
|
||||||
|
// With the failure type recorded, we'll simply discard
|
||||||
|
// the remainder of the record as if it were optional.
|
||||||
|
// The first failure will be returned after reaching the
|
||||||
|
// stopping condition.
|
||||||
|
fallthrough
|
||||||
|
|
||||||
// Otherwise, the record type is unknown and is odd, discard the
|
// Otherwise, the record type is unknown and is odd, discard the
|
||||||
// number of bytes specified by length.
|
// number of bytes specified by length.
|
||||||
|
@ -2,50 +2,106 @@ package tlv_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/lightningnetwork/lnd/tlv"
|
"github.com/lightningnetwork/lnd/tlv"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type parsedTypeTest struct {
|
||||||
|
name string
|
||||||
|
encode []tlv.Type
|
||||||
|
decode []tlv.Type
|
||||||
|
expErr error
|
||||||
|
}
|
||||||
|
|
||||||
// TestParsedTypes asserts that a Stream will properly return the set of types
|
// TestParsedTypes asserts that a Stream will properly return the set of types
|
||||||
// that it encounters when the type is known-and-decoded or unknown-and-ignored.
|
// that it encounters when the type is known-and-decoded or unknown-and-ignored.
|
||||||
func TestParsedTypes(t *testing.T) {
|
func TestParsedTypes(t *testing.T) {
|
||||||
const (
|
const (
|
||||||
knownType = 1
|
firstReqType = 0
|
||||||
unknownType = 3
|
knownType = 1
|
||||||
|
unknownType = 3
|
||||||
|
secondReqType = 4
|
||||||
)
|
)
|
||||||
|
|
||||||
// Construct a stream that will encode two types, one that will be known
|
tests := []parsedTypeTest{
|
||||||
// to the decoder and another that will be unknown.
|
{
|
||||||
encStream := tlv.MustNewStream(
|
name: "known optional and unknown optional",
|
||||||
tlv.MakePrimitiveRecord(knownType, new(uint64)),
|
encode: []tlv.Type{knownType, unknownType},
|
||||||
tlv.MakePrimitiveRecord(unknownType, new(uint64)),
|
decode: []tlv.Type{knownType},
|
||||||
)
|
},
|
||||||
|
{
|
||||||
|
name: "unknown required and known optional",
|
||||||
|
encode: []tlv.Type{firstReqType, knownType},
|
||||||
|
decode: []tlv.Type{knownType},
|
||||||
|
expErr: tlv.ErrUnknownRequiredType(firstReqType),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown required and unknown optional",
|
||||||
|
encode: []tlv.Type{unknownType, secondReqType},
|
||||||
|
expErr: tlv.ErrUnknownRequiredType(secondReqType),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown required and known required",
|
||||||
|
encode: []tlv.Type{firstReqType, secondReqType},
|
||||||
|
decode: []tlv.Type{secondReqType},
|
||||||
|
expErr: tlv.ErrUnknownRequiredType(firstReqType),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two unknown required",
|
||||||
|
encode: []tlv.Type{firstReqType, secondReqType},
|
||||||
|
expErr: tlv.ErrUnknownRequiredType(firstReqType),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
test := test
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
testParsedTypes(t, test)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testParsedTypes(t *testing.T, test parsedTypeTest) {
|
||||||
|
encRecords := make([]tlv.Record, 0, len(test.encode))
|
||||||
|
for _, typ := range test.encode {
|
||||||
|
encRecords = append(
|
||||||
|
encRecords, tlv.MakePrimitiveRecord(typ, new(uint64)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
decRecords := make([]tlv.Record, 0, len(test.decode))
|
||||||
|
for _, typ := range test.decode {
|
||||||
|
decRecords = append(
|
||||||
|
decRecords, tlv.MakePrimitiveRecord(typ, new(uint64)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct a stream that will encode the test's set of types.
|
||||||
|
encStream := tlv.MustNewStream(encRecords...)
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
if err := encStream.Encode(&b); err != nil {
|
if err := encStream.Encode(&b); err != nil {
|
||||||
t.Fatalf("unable to encode stream: %v", err)
|
t.Fatalf("unable to encode stream: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a stream that will parse only the known type.
|
// Create a stream that will parse a subset of the test's types.
|
||||||
decStream := tlv.MustNewStream(
|
decStream := tlv.MustNewStream(decRecords...)
|
||||||
tlv.MakePrimitiveRecord(knownType, new(uint64)),
|
|
||||||
)
|
|
||||||
|
|
||||||
parsedTypes, err := decStream.DecodeWithParsedTypes(
|
parsedTypes, err := decStream.DecodeWithParsedTypes(
|
||||||
bytes.NewReader(b.Bytes()),
|
bytes.NewReader(b.Bytes()),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if !reflect.DeepEqual(err, test.expErr) {
|
||||||
t.Fatalf("unable to decode stream: %v", err)
|
t.Fatalf("error mismatch, want: %v got: %v", err, test.expErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Assert that both the known and unknown types are included in the set
|
// Assert that all encoded types are included in the set of parsed
|
||||||
// of parsed types.
|
// types.
|
||||||
if _, ok := parsedTypes[knownType]; !ok {
|
for _, typ := range test.encode {
|
||||||
t.Fatalf("known type %d should be in parsed types", knownType)
|
if _, ok := parsedTypes[typ]; !ok {
|
||||||
}
|
t.Fatalf("encoded type %d should be in parsed types",
|
||||||
if _, ok := parsedTypes[unknownType]; !ok {
|
typ)
|
||||||
t.Fatalf("unknown type %d should be in parsed types",
|
}
|
||||||
unknownType)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user