From aefec9b10faf8ba9ee858ab568d7a17339ae4026 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Wed, 4 Sep 2019 11:46:28 -0700 Subject: [PATCH] tlv: return parsed types from DecodeWithParsedTypes This commit adds an additional return value to Stream.DecodeWithParsedTypes, which returns the set of types that were encountered during decoding. The set will contain all known types that were decoded, as well as unknown odd types that were ignored. The rationale for the return value (rather than an internal member) is so that the stream remains stateless. This return value can be used by callers during decoding to make assertions as to whether specific types were included in the stream. This is need, for example, when parsing onion payloads where certain fields must be included/omitted depending on the hop type. The original Decode method would incur the additional performance hit of needing to track the parsed types, so we can selectively enable this functionality when a decoder requires it by using a helper which conditionally tracks the parsed types. --- htlcswitch/hop/payload.go | 2 +- tlv/record.go | 3 +++ tlv/stream.go | 43 ++++++++++++++++++++++++--------- tlv/stream_test.go | 51 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 87 insertions(+), 12 deletions(-) create mode 100644 tlv/stream_test.go diff --git a/htlcswitch/hop/payload.go b/htlcswitch/hop/payload.go index 1c7e563b..5577d76d 100644 --- a/htlcswitch/hop/payload.go +++ b/htlcswitch/hop/payload.go @@ -53,7 +53,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { return nil, err } - err = tlvStream.Decode(r) + _, err = tlvStream.DecodeWithParsedTypes(r) if err != nil { return nil, err } diff --git a/tlv/record.go b/tlv/record.go index 610ab6c1..75647e03 100644 --- a/tlv/record.go +++ b/tlv/record.go @@ -12,6 +12,9 @@ import ( // Type is an 64-bit identifier for a TLV Record. type Type uint64 +// TypeSet is an unordered set of Types. +type TypeSet map[Type]struct{} + // Encoder is a signature for methods that can encode TLV values. An error // should be returned if the Encoder cannot support the underlying type of val. // The provided scratch buffer must be non-nil. diff --git a/tlv/stream.go b/tlv/stream.go index 159301fd..60b40919 100644 --- a/tlv/stream.go +++ b/tlv/stream.go @@ -144,6 +144,21 @@ func (s *Stream) Encode(w io.Writer) error { // the last record was read cleanly and we should stop parsing. All other io.EOF // or io.ErrUnexpectedEOF errors are returned. func (s *Stream) Decode(r io.Reader) error { + _, err := s.decode(r, nil) + return err +} + +// DecodeWithParsedTypes is identical to Decode, but if successful, returns a +// TypeSet containing the types of all records that were decoded or ignored from +// the stream. +func (s *Stream) DecodeWithParsedTypes(r io.Reader) (TypeSet, error) { + return s.decode(r, make(TypeSet)) +} + +// decode is a helper function that performs the basis of stream decoding. If +// the caller needs the set of parsed types, it must provide an initialized +// parsedTypes, otherwise the returned TypeSet will be nil. +func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) { var ( typ Type min Type @@ -161,11 +176,11 @@ func (s *Stream) Decode(r io.Reader) error { // We'll silence an EOF when zero bytes remain, meaning the // stream was cleanly encoded. case err == io.EOF: - return nil + return parsedTypes, nil // Other unexpected errors. case err != nil: - return err + return nil, err } typ = Type(t) @@ -176,7 +191,7 @@ func (s *Stream) Decode(r io.Reader) error { // encodings that have duplicate records or from accepting an // unsorted series. if overflow || typ < min { - return ErrStreamNotCanonical + return nil, ErrStreamNotCanonical } // Read the varint length. @@ -186,11 +201,11 @@ func (s *Stream) Decode(r io.Reader) error { // We'll convert any EOFs to ErrUnexpectedEOF, since this // results in an invalid record. case err == io.EOF: - return io.ErrUnexpectedEOF + return nil, io.ErrUnexpectedEOF // Other unexpected errors. case err != nil: - return err + return nil, err } // Place a soft limit on the size of a sane record, which @@ -198,7 +213,7 @@ func (s *Stream) Decode(r io.Reader) error { // unbounded amount of memory when decoding variable-sized // fields. if length > MaxRecordSize { - return ErrRecordTooLarge + return nil, ErrRecordTooLarge } // Search the records known to the stream for this type. We'll @@ -218,17 +233,17 @@ func (s *Stream) Decode(r io.Reader) error { // We'll convert any EOFs to ErrUnexpectedEOF, since this // results in an invalid record. case err == io.EOF: - return io.ErrUnexpectedEOF + return nil, io.ErrUnexpectedEOF // Other unexpected errors. case err != nil: - return err + return nil, err } // This record type is unknown to the stream, fail if the type // is even meaning that we are required to understand it. case typ%2 == 0: - return ErrUnknownRequiredType(typ) + return nil, ErrUnknownRequiredType(typ) // Otherwise, the record type is unknown and is odd, discard the // number of bytes specified by length. @@ -239,14 +254,20 @@ func (s *Stream) Decode(r io.Reader) error { // We'll convert any EOFs to ErrUnexpectedEOF, since this // results in an invalid record. case err == io.EOF: - return io.ErrUnexpectedEOF + return nil, io.ErrUnexpectedEOF // Other unexpected errors. case err != nil: - return err + return nil, err } } + // Record the successfully decoded or ignored type if the + // caller provided an initialized TypeSet. + if parsedTypes != nil { + parsedTypes[typ] = struct{}{} + } + // Update our record index so that we can begin our next search // from where we left off. recordIdx = newIdx diff --git a/tlv/stream_test.go b/tlv/stream_test.go new file mode 100644 index 00000000..e9970e74 --- /dev/null +++ b/tlv/stream_test.go @@ -0,0 +1,51 @@ +package tlv_test + +import ( + "bytes" + "testing" + + "github.com/lightningnetwork/lnd/tlv" +) + +// TestParsedTypes asserts that a Stream will properly return the set of types +// that it encounters when the type is known-and-decoded or unknown-and-ignored. +func TestParsedTypes(t *testing.T) { + const ( + knownType = 1 + unknownType = 3 + ) + + // Construct a stream that will encode two types, one that will be known + // to the decoder and another that will be unknown. + encStream := tlv.MustNewStream( + tlv.MakePrimitiveRecord(knownType, new(uint64)), + tlv.MakePrimitiveRecord(unknownType, new(uint64)), + ) + + var b bytes.Buffer + if err := encStream.Encode(&b); err != nil { + t.Fatalf("unable to encode stream: %v", err) + } + + // Create a stream that will parse only the known type. + decStream := tlv.MustNewStream( + tlv.MakePrimitiveRecord(knownType, new(uint64)), + ) + + parsedTypes, err := decStream.DecodeWithParsedTypes( + bytes.NewReader(b.Bytes()), + ) + if err != nil { + t.Fatalf("unable to decode stream: %v", err) + } + + // Assert that both the known and unknown types are included in the set + // of parsed types. + if _, ok := parsedTypes[knownType]; !ok { + t.Fatalf("known type %d should be in parsed types", knownType) + } + if _, ok := parsedTypes[unknownType]; !ok { + t.Fatalf("unknown type %d should be in parsed types", + unknownType) + } +}