tlv: return parsed types from DecodeWithParsedTypes

This commit adds an additional return value to
Stream.DecodeWithParsedTypes, which returns the set of types that were
encountered during decoding. The set will contain all known types that
were decoded, as well as unknown odd types that were ignored.

The rationale for the return value (rather than an internal member) is
so that the stream remains stateless.

This return value can be used by callers during decoding to make
assertions as to whether specific types were included in the stream.
This is need, for example, when parsing onion payloads where certain
fields must be included/omitted depending on the hop type.

The original Decode method would incur the additional performance hit of
needing to track the parsed types, so we can selectively enable this
functionality when a decoder requires it by using a helper which
conditionally tracks the parsed types.
This commit is contained in:
Conner Fromknecht 2019-09-04 11:46:28 -07:00
parent fb565bcd5d
commit aefec9b10f
No known key found for this signature in database
GPG Key ID: E7D737B67FA592C7
4 changed files with 87 additions and 12 deletions

View File

@ -53,7 +53,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
return nil, err
}
err = tlvStream.Decode(r)
_, err = tlvStream.DecodeWithParsedTypes(r)
if err != nil {
return nil, err
}

View File

@ -12,6 +12,9 @@ import (
// Type is an 64-bit identifier for a TLV Record.
type Type uint64
// TypeSet is an unordered set of Types.
type TypeSet map[Type]struct{}
// Encoder is a signature for methods that can encode TLV values. An error
// should be returned if the Encoder cannot support the underlying type of val.
// The provided scratch buffer must be non-nil.

View File

@ -144,6 +144,21 @@ func (s *Stream) Encode(w io.Writer) error {
// the last record was read cleanly and we should stop parsing. All other io.EOF
// or io.ErrUnexpectedEOF errors are returned.
func (s *Stream) Decode(r io.Reader) error {
_, err := s.decode(r, nil)
return err
}
// DecodeWithParsedTypes is identical to Decode, but if successful, returns a
// TypeSet containing the types of all records that were decoded or ignored from
// the stream.
func (s *Stream) DecodeWithParsedTypes(r io.Reader) (TypeSet, error) {
return s.decode(r, make(TypeSet))
}
// decode is a helper function that performs the basis of stream decoding. If
// the caller needs the set of parsed types, it must provide an initialized
// parsedTypes, otherwise the returned TypeSet will be nil.
func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
var (
typ Type
min Type
@ -161,11 +176,11 @@ func (s *Stream) Decode(r io.Reader) error {
// We'll silence an EOF when zero bytes remain, meaning the
// stream was cleanly encoded.
case err == io.EOF:
return nil
return parsedTypes, nil
// Other unexpected errors.
case err != nil:
return err
return nil, err
}
typ = Type(t)
@ -176,7 +191,7 @@ func (s *Stream) Decode(r io.Reader) error {
// encodings that have duplicate records or from accepting an
// unsorted series.
if overflow || typ < min {
return ErrStreamNotCanonical
return nil, ErrStreamNotCanonical
}
// Read the varint length.
@ -186,11 +201,11 @@ func (s *Stream) Decode(r io.Reader) error {
// We'll convert any EOFs to ErrUnexpectedEOF, since this
// results in an invalid record.
case err == io.EOF:
return io.ErrUnexpectedEOF
return nil, io.ErrUnexpectedEOF
// Other unexpected errors.
case err != nil:
return err
return nil, err
}
// Place a soft limit on the size of a sane record, which
@ -198,7 +213,7 @@ func (s *Stream) Decode(r io.Reader) error {
// unbounded amount of memory when decoding variable-sized
// fields.
if length > MaxRecordSize {
return ErrRecordTooLarge
return nil, ErrRecordTooLarge
}
// Search the records known to the stream for this type. We'll
@ -218,17 +233,17 @@ func (s *Stream) Decode(r io.Reader) error {
// We'll convert any EOFs to ErrUnexpectedEOF, since this
// results in an invalid record.
case err == io.EOF:
return io.ErrUnexpectedEOF
return nil, io.ErrUnexpectedEOF
// Other unexpected errors.
case err != nil:
return err
return nil, err
}
// This record type is unknown to the stream, fail if the type
// is even meaning that we are required to understand it.
case typ%2 == 0:
return ErrUnknownRequiredType(typ)
return nil, ErrUnknownRequiredType(typ)
// Otherwise, the record type is unknown and is odd, discard the
// number of bytes specified by length.
@ -239,14 +254,20 @@ func (s *Stream) Decode(r io.Reader) error {
// We'll convert any EOFs to ErrUnexpectedEOF, since this
// results in an invalid record.
case err == io.EOF:
return io.ErrUnexpectedEOF
return nil, io.ErrUnexpectedEOF
// Other unexpected errors.
case err != nil:
return err
return nil, err
}
}
// Record the successfully decoded or ignored type if the
// caller provided an initialized TypeSet.
if parsedTypes != nil {
parsedTypes[typ] = struct{}{}
}
// Update our record index so that we can begin our next search
// from where we left off.
recordIdx = newIdx

51
tlv/stream_test.go Normal file
View File

@ -0,0 +1,51 @@
package tlv_test
import (
"bytes"
"testing"
"github.com/lightningnetwork/lnd/tlv"
)
// TestParsedTypes asserts that a Stream will properly return the set of types
// that it encounters when the type is known-and-decoded or unknown-and-ignored.
func TestParsedTypes(t *testing.T) {
const (
knownType = 1
unknownType = 3
)
// Construct a stream that will encode two types, one that will be known
// to the decoder and another that will be unknown.
encStream := tlv.MustNewStream(
tlv.MakePrimitiveRecord(knownType, new(uint64)),
tlv.MakePrimitiveRecord(unknownType, new(uint64)),
)
var b bytes.Buffer
if err := encStream.Encode(&b); err != nil {
t.Fatalf("unable to encode stream: %v", err)
}
// Create a stream that will parse only the known type.
decStream := tlv.MustNewStream(
tlv.MakePrimitiveRecord(knownType, new(uint64)),
)
parsedTypes, err := decStream.DecodeWithParsedTypes(
bytes.NewReader(b.Bytes()),
)
if err != nil {
t.Fatalf("unable to decode stream: %v", err)
}
// Assert that both the known and unknown types are included in the set
// of parsed types.
if _, ok := parsedTypes[knownType]; !ok {
t.Fatalf("known type %d should be in parsed types", knownType)
}
if _, ok := parsedTypes[unknownType]; !ok {
t.Fatalf("unknown type %d should be in parsed types",
unknownType)
}
}