diff --git a/htlcswitch/hop/payload.go b/htlcswitch/hop/payload.go index 1c7e563b..5577d76d 100644 --- a/htlcswitch/hop/payload.go +++ b/htlcswitch/hop/payload.go @@ -53,7 +53,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { return nil, err } - err = tlvStream.Decode(r) + _, err = tlvStream.DecodeWithParsedTypes(r) if err != nil { return nil, err } diff --git a/tlv/record.go b/tlv/record.go index 610ab6c1..75647e03 100644 --- a/tlv/record.go +++ b/tlv/record.go @@ -12,6 +12,9 @@ import ( // Type is an 64-bit identifier for a TLV Record. type Type uint64 +// TypeSet is an unordered set of Types. +type TypeSet map[Type]struct{} + // Encoder is a signature for methods that can encode TLV values. An error // should be returned if the Encoder cannot support the underlying type of val. // The provided scratch buffer must be non-nil. diff --git a/tlv/stream.go b/tlv/stream.go index 159301fd..60b40919 100644 --- a/tlv/stream.go +++ b/tlv/stream.go @@ -144,6 +144,21 @@ func (s *Stream) Encode(w io.Writer) error { // the last record was read cleanly and we should stop parsing. All other io.EOF // or io.ErrUnexpectedEOF errors are returned. func (s *Stream) Decode(r io.Reader) error { + _, err := s.decode(r, nil) + return err +} + +// DecodeWithParsedTypes is identical to Decode, but if successful, returns a +// TypeSet containing the types of all records that were decoded or ignored from +// the stream. +func (s *Stream) DecodeWithParsedTypes(r io.Reader) (TypeSet, error) { + return s.decode(r, make(TypeSet)) +} + +// decode is a helper function that performs the basis of stream decoding. If +// the caller needs the set of parsed types, it must provide an initialized +// parsedTypes, otherwise the returned TypeSet will be nil. +func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) { var ( typ Type min Type @@ -161,11 +176,11 @@ func (s *Stream) Decode(r io.Reader) error { // We'll silence an EOF when zero bytes remain, meaning the // stream was cleanly encoded. case err == io.EOF: - return nil + return parsedTypes, nil // Other unexpected errors. case err != nil: - return err + return nil, err } typ = Type(t) @@ -176,7 +191,7 @@ func (s *Stream) Decode(r io.Reader) error { // encodings that have duplicate records or from accepting an // unsorted series. if overflow || typ < min { - return ErrStreamNotCanonical + return nil, ErrStreamNotCanonical } // Read the varint length. @@ -186,11 +201,11 @@ func (s *Stream) Decode(r io.Reader) error { // We'll convert any EOFs to ErrUnexpectedEOF, since this // results in an invalid record. case err == io.EOF: - return io.ErrUnexpectedEOF + return nil, io.ErrUnexpectedEOF // Other unexpected errors. case err != nil: - return err + return nil, err } // Place a soft limit on the size of a sane record, which @@ -198,7 +213,7 @@ func (s *Stream) Decode(r io.Reader) error { // unbounded amount of memory when decoding variable-sized // fields. if length > MaxRecordSize { - return ErrRecordTooLarge + return nil, ErrRecordTooLarge } // Search the records known to the stream for this type. We'll @@ -218,17 +233,17 @@ func (s *Stream) Decode(r io.Reader) error { // We'll convert any EOFs to ErrUnexpectedEOF, since this // results in an invalid record. case err == io.EOF: - return io.ErrUnexpectedEOF + return nil, io.ErrUnexpectedEOF // Other unexpected errors. case err != nil: - return err + return nil, err } // This record type is unknown to the stream, fail if the type // is even meaning that we are required to understand it. case typ%2 == 0: - return ErrUnknownRequiredType(typ) + return nil, ErrUnknownRequiredType(typ) // Otherwise, the record type is unknown and is odd, discard the // number of bytes specified by length. @@ -239,14 +254,20 @@ func (s *Stream) Decode(r io.Reader) error { // We'll convert any EOFs to ErrUnexpectedEOF, since this // results in an invalid record. case err == io.EOF: - return io.ErrUnexpectedEOF + return nil, io.ErrUnexpectedEOF // Other unexpected errors. case err != nil: - return err + return nil, err } } + // Record the successfully decoded or ignored type if the + // caller provided an initialized TypeSet. + if parsedTypes != nil { + parsedTypes[typ] = struct{}{} + } + // Update our record index so that we can begin our next search // from where we left off. recordIdx = newIdx diff --git a/tlv/stream_test.go b/tlv/stream_test.go new file mode 100644 index 00000000..e9970e74 --- /dev/null +++ b/tlv/stream_test.go @@ -0,0 +1,51 @@ +package tlv_test + +import ( + "bytes" + "testing" + + "github.com/lightningnetwork/lnd/tlv" +) + +// TestParsedTypes asserts that a Stream will properly return the set of types +// that it encounters when the type is known-and-decoded or unknown-and-ignored. +func TestParsedTypes(t *testing.T) { + const ( + knownType = 1 + unknownType = 3 + ) + + // Construct a stream that will encode two types, one that will be known + // to the decoder and another that will be unknown. + encStream := tlv.MustNewStream( + tlv.MakePrimitiveRecord(knownType, new(uint64)), + tlv.MakePrimitiveRecord(unknownType, new(uint64)), + ) + + var b bytes.Buffer + if err := encStream.Encode(&b); err != nil { + t.Fatalf("unable to encode stream: %v", err) + } + + // Create a stream that will parse only the known type. + decStream := tlv.MustNewStream( + tlv.MakePrimitiveRecord(knownType, new(uint64)), + ) + + parsedTypes, err := decStream.DecodeWithParsedTypes( + bytes.NewReader(b.Bytes()), + ) + if err != nil { + t.Fatalf("unable to decode stream: %v", err) + } + + // Assert that both the known and unknown types are included in the set + // of parsed types. + if _, ok := parsedTypes[knownType]; !ok { + t.Fatalf("known type %d should be in parsed types", knownType) + } + if _, ok := parsedTypes[unknownType]; !ok { + t.Fatalf("unknown type %d should be in parsed types", + unknownType) + } +}