tlv: return parsed types from DecodeWithParsedTypes
This commit adds an additional return value to Stream.DecodeWithParsedTypes, which returns the set of types that were encountered during decoding. The set will contain all known types that were decoded, as well as unknown odd types that were ignored. The rationale for the return value (rather than an internal member) is so that the stream remains stateless. This return value can be used by callers during decoding to make assertions as to whether specific types were included in the stream. This is need, for example, when parsing onion payloads where certain fields must be included/omitted depending on the hop type. The original Decode method would incur the additional performance hit of needing to track the parsed types, so we can selectively enable this functionality when a decoder requires it by using a helper which conditionally tracks the parsed types.
This commit is contained in:
parent
fb565bcd5d
commit
aefec9b10f
@ -53,7 +53,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tlvStream.Decode(r)
|
_, err = tlvStream.DecodeWithParsedTypes(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -12,6 +12,9 @@ import (
|
|||||||
// Type is an 64-bit identifier for a TLV Record.
|
// Type is an 64-bit identifier for a TLV Record.
|
||||||
type Type uint64
|
type Type uint64
|
||||||
|
|
||||||
|
// TypeSet is an unordered set of Types.
|
||||||
|
type TypeSet map[Type]struct{}
|
||||||
|
|
||||||
// Encoder is a signature for methods that can encode TLV values. An error
|
// Encoder is a signature for methods that can encode TLV values. An error
|
||||||
// should be returned if the Encoder cannot support the underlying type of val.
|
// should be returned if the Encoder cannot support the underlying type of val.
|
||||||
// The provided scratch buffer must be non-nil.
|
// The provided scratch buffer must be non-nil.
|
||||||
|
@ -144,6 +144,21 @@ func (s *Stream) Encode(w io.Writer) error {
|
|||||||
// the last record was read cleanly and we should stop parsing. All other io.EOF
|
// the last record was read cleanly and we should stop parsing. All other io.EOF
|
||||||
// or io.ErrUnexpectedEOF errors are returned.
|
// or io.ErrUnexpectedEOF errors are returned.
|
||||||
func (s *Stream) Decode(r io.Reader) error {
|
func (s *Stream) Decode(r io.Reader) error {
|
||||||
|
_, err := s.decode(r, nil)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeWithParsedTypes is identical to Decode, but if successful, returns a
|
||||||
|
// TypeSet containing the types of all records that were decoded or ignored from
|
||||||
|
// the stream.
|
||||||
|
func (s *Stream) DecodeWithParsedTypes(r io.Reader) (TypeSet, error) {
|
||||||
|
return s.decode(r, make(TypeSet))
|
||||||
|
}
|
||||||
|
|
||||||
|
// decode is a helper function that performs the basis of stream decoding. If
|
||||||
|
// the caller needs the set of parsed types, it must provide an initialized
|
||||||
|
// parsedTypes, otherwise the returned TypeSet will be nil.
|
||||||
|
func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
|
||||||
var (
|
var (
|
||||||
typ Type
|
typ Type
|
||||||
min Type
|
min Type
|
||||||
@ -161,11 +176,11 @@ func (s *Stream) Decode(r io.Reader) error {
|
|||||||
// We'll silence an EOF when zero bytes remain, meaning the
|
// We'll silence an EOF when zero bytes remain, meaning the
|
||||||
// stream was cleanly encoded.
|
// stream was cleanly encoded.
|
||||||
case err == io.EOF:
|
case err == io.EOF:
|
||||||
return nil
|
return parsedTypes, nil
|
||||||
|
|
||||||
// Other unexpected errors.
|
// Other unexpected errors.
|
||||||
case err != nil:
|
case err != nil:
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
typ = Type(t)
|
typ = Type(t)
|
||||||
@ -176,7 +191,7 @@ func (s *Stream) Decode(r io.Reader) error {
|
|||||||
// encodings that have duplicate records or from accepting an
|
// encodings that have duplicate records or from accepting an
|
||||||
// unsorted series.
|
// unsorted series.
|
||||||
if overflow || typ < min {
|
if overflow || typ < min {
|
||||||
return ErrStreamNotCanonical
|
return nil, ErrStreamNotCanonical
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read the varint length.
|
// Read the varint length.
|
||||||
@ -186,11 +201,11 @@ func (s *Stream) Decode(r io.Reader) error {
|
|||||||
// We'll convert any EOFs to ErrUnexpectedEOF, since this
|
// We'll convert any EOFs to ErrUnexpectedEOF, since this
|
||||||
// results in an invalid record.
|
// results in an invalid record.
|
||||||
case err == io.EOF:
|
case err == io.EOF:
|
||||||
return io.ErrUnexpectedEOF
|
return nil, io.ErrUnexpectedEOF
|
||||||
|
|
||||||
// Other unexpected errors.
|
// Other unexpected errors.
|
||||||
case err != nil:
|
case err != nil:
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Place a soft limit on the size of a sane record, which
|
// Place a soft limit on the size of a sane record, which
|
||||||
@ -198,7 +213,7 @@ func (s *Stream) Decode(r io.Reader) error {
|
|||||||
// unbounded amount of memory when decoding variable-sized
|
// unbounded amount of memory when decoding variable-sized
|
||||||
// fields.
|
// fields.
|
||||||
if length > MaxRecordSize {
|
if length > MaxRecordSize {
|
||||||
return ErrRecordTooLarge
|
return nil, ErrRecordTooLarge
|
||||||
}
|
}
|
||||||
|
|
||||||
// Search the records known to the stream for this type. We'll
|
// Search the records known to the stream for this type. We'll
|
||||||
@ -218,17 +233,17 @@ func (s *Stream) Decode(r io.Reader) error {
|
|||||||
// We'll convert any EOFs to ErrUnexpectedEOF, since this
|
// We'll convert any EOFs to ErrUnexpectedEOF, since this
|
||||||
// results in an invalid record.
|
// results in an invalid record.
|
||||||
case err == io.EOF:
|
case err == io.EOF:
|
||||||
return io.ErrUnexpectedEOF
|
return nil, io.ErrUnexpectedEOF
|
||||||
|
|
||||||
// Other unexpected errors.
|
// Other unexpected errors.
|
||||||
case err != nil:
|
case err != nil:
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// This record type is unknown to the stream, fail if the type
|
// This record type is unknown to the stream, fail if the type
|
||||||
// is even meaning that we are required to understand it.
|
// is even meaning that we are required to understand it.
|
||||||
case typ%2 == 0:
|
case typ%2 == 0:
|
||||||
return ErrUnknownRequiredType(typ)
|
return nil, ErrUnknownRequiredType(typ)
|
||||||
|
|
||||||
// Otherwise, the record type is unknown and is odd, discard the
|
// Otherwise, the record type is unknown and is odd, discard the
|
||||||
// number of bytes specified by length.
|
// number of bytes specified by length.
|
||||||
@ -239,14 +254,20 @@ func (s *Stream) Decode(r io.Reader) error {
|
|||||||
// We'll convert any EOFs to ErrUnexpectedEOF, since this
|
// We'll convert any EOFs to ErrUnexpectedEOF, since this
|
||||||
// results in an invalid record.
|
// results in an invalid record.
|
||||||
case err == io.EOF:
|
case err == io.EOF:
|
||||||
return io.ErrUnexpectedEOF
|
return nil, io.ErrUnexpectedEOF
|
||||||
|
|
||||||
// Other unexpected errors.
|
// Other unexpected errors.
|
||||||
case err != nil:
|
case err != nil:
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Record the successfully decoded or ignored type if the
|
||||||
|
// caller provided an initialized TypeSet.
|
||||||
|
if parsedTypes != nil {
|
||||||
|
parsedTypes[typ] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
// Update our record index so that we can begin our next search
|
// Update our record index so that we can begin our next search
|
||||||
// from where we left off.
|
// from where we left off.
|
||||||
recordIdx = newIdx
|
recordIdx = newIdx
|
||||||
|
51
tlv/stream_test.go
Normal file
51
tlv/stream_test.go
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
package tlv_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/tlv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestParsedTypes asserts that a Stream will properly return the set of types
|
||||||
|
// that it encounters when the type is known-and-decoded or unknown-and-ignored.
|
||||||
|
func TestParsedTypes(t *testing.T) {
|
||||||
|
const (
|
||||||
|
knownType = 1
|
||||||
|
unknownType = 3
|
||||||
|
)
|
||||||
|
|
||||||
|
// Construct a stream that will encode two types, one that will be known
|
||||||
|
// to the decoder and another that will be unknown.
|
||||||
|
encStream := tlv.MustNewStream(
|
||||||
|
tlv.MakePrimitiveRecord(knownType, new(uint64)),
|
||||||
|
tlv.MakePrimitiveRecord(unknownType, new(uint64)),
|
||||||
|
)
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := encStream.Encode(&b); err != nil {
|
||||||
|
t.Fatalf("unable to encode stream: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a stream that will parse only the known type.
|
||||||
|
decStream := tlv.MustNewStream(
|
||||||
|
tlv.MakePrimitiveRecord(knownType, new(uint64)),
|
||||||
|
)
|
||||||
|
|
||||||
|
parsedTypes, err := decStream.DecodeWithParsedTypes(
|
||||||
|
bytes.NewReader(b.Bytes()),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to decode stream: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert that both the known and unknown types are included in the set
|
||||||
|
// of parsed types.
|
||||||
|
if _, ok := parsedTypes[knownType]; !ok {
|
||||||
|
t.Fatalf("known type %d should be in parsed types", knownType)
|
||||||
|
}
|
||||||
|
if _, ok := parsedTypes[unknownType]; !ok {
|
||||||
|
t.Fatalf("unknown type %d should be in parsed types",
|
||||||
|
unknownType)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user