lnd.xprv/tlv/stream.go
Conner Fromknecht aefec9b10f
tlv: return parsed types from DecodeWithParsedTypes
This commit adds an additional return value to
Stream.DecodeWithParsedTypes, which returns the set of types that were
encountered during decoding. The set will contain all known types that
were decoded, as well as unknown odd types that were ignored.

The rationale for the return value (rather than an internal member) is
so that the stream remains stateless.

This return value can be used by callers during decoding to make
assertions as to whether specific types were included in the stream.
This is needed, for example, when parsing onion payloads where certain
fields must be included/omitted depending on the hop type.

The original Decode method would incur the additional performance hit of
needing to track the parsed types, so we can selectively enable this
functionality when a decoder requires it by using a helper which
conditionally tracks the parsed types.
2019-09-05 09:30:37 -07:00

320 lines
9.6 KiB
Go

package tlv
import (
"errors"
"fmt"
"io"
"io/ioutil"
"math"
)
// MaxRecordSize is the largest record length, in bytes, that a stream decoder
// will agree to parse. It is currently chosen to be equal to the maximum
// message size permitted by BOLT 1, as no record should be bigger than an
// entire message.
const MaxRecordSize = 65535 // 64 KiB - 1

var (
	// ErrStreamNotCanonical signals that a stream's records are not sorted
	// by strictly increasing type.
	ErrStreamNotCanonical = errors.New("tlv stream is not canonical")

	// ErrRecordTooLarge signals that a decoded record declares a length
	// that is too long to parse.
	ErrRecordTooLarge = errors.New("record is too large")
)
// ErrUnknownRequiredType is the error produced when a Stream encounters a
// record whose type is even and not known to the Stream. The error value
// itself carries the offending type.
type ErrUnknownRequiredType Type

// Error renders the unknown required type as a human-readable message.
func (e ErrUnknownRequiredType) Error() string {
	// Format the plain numeric value of the type in decimal.
	return fmt.Sprintf("unknown required type: %d", uint64(e))
}
// Stream defines a TLV stream that can be used for encoding or decoding a set
// of TLV Records.
type Stream struct {
// records holds the Records known to this stream. NewStream rejects any
// set that is not in strictly ascending order of type.
records []Record
// buf is an 8-byte scratch buffer handed to the varint helpers and to
// each record's encoder and decoder.
buf [8]byte
}
// NewStream creates a new TLV Stream from a set of known records.
//
// The records must be supplied in strictly ascending order of type; if they
// are not, ErrStreamNotCanonical is returned and no Stream is created. Any
// record that has no encoder is assigned ENOP, and any record that has no
// decoder is assigned DNOP, so that Encode and Decode never invoke a nil
// function.
//
// NOTE: The defaults are written into the passed records in place. A caller
// that expands an existing slice into the arguments will therefore observe
// them, and the returned Stream keeps referring to that same slice.
func NewStream(records ...Record) (*Stream, error) {
	// Assert that the ordering of the Records is canonical and appear in
	// ascending order of type.
	var (
		min      Type
		overflow bool
	)
	for i := range records {
		// Take a pointer into the slice so that the defaults assigned
		// below are stored in the stream's records. Ranging by value
		// would only modify a loop-local copy that is then discarded.
		record := &records[i]

		if overflow || record.typ < min {
			return nil, ErrStreamNotCanonical
		}
		if record.encoder == nil {
			record.encoder = ENOP
		}
		if record.decoder == nil {
			record.decoder = DNOP
		}

		// Once the largest possible type has been seen, min would wrap
		// around to zero, so any further record must be rejected.
		if record.typ == math.MaxUint64 {
			overflow = true
		}
		min = record.typ + 1
	}

	return &Stream{
		records: records,
	}, nil
}
// MustNewStream behaves like NewStream for the given set of known records,
// except that it panics with the error's message instead of returning the
// error when the stream cannot be created.
func MustNewStream(records ...Record) *Stream {
	s, err := NewStream(records...)
	if err == nil {
		return s
	}

	panic(err.Error())
}
// Encode serializes the Stream to the passed io.Writer. The Records known to
// the Stream are written in the order they are stored, which NewStream asserts
// to be ascending by type, so the output is canonical by construction.
//
// The stream is the concatenation of the individual serialized Records, each
// laid out as:
//
//	[varint: type]
//	[varint: length]
//	[length: value]
//
// Encoding stops at, and returns, the first error reported while writing a
// type, a length, or a record's value.
func (s *Stream) Encode(w io.Writer) error {
	for idx := range s.records {
		// Work on the stored record itself rather than a copy.
		r := &s.records[idx]

		// Type, as a varint.
		if err := WriteVarInt(w, uint64(r.typ), &s.buf); err != nil {
			return err
		}

		// Length, as a varint.
		if err := WriteVarInt(w, r.Size(), &s.buf); err != nil {
			return err
		}

		// Value, produced by the record's own encoder.
		if err := r.encoder(w, r.value, &s.buf); err != nil {
			return err
		}
	}

	return nil
}
// Decode deserializes a TLV Stream from the passed io.Reader. Each parsed
// record is matched against the Records known to the Stream; a match is decoded
// with that Record's decoder, while an unknown record with an odd type has its
// bytes discarded before moving on to the next record. An unknown record with
// an even type fails decoding with ErrUnknownRequiredType.
//
// Each record has the following format:
//
//	[varint: type]
//	[varint: length]
//	[length: value]
//
// A stream is a concatenation of zero or more records. This example contains
// two:
//
//	(t: 0x01, l: 0x04, v: 0xff, 0xff, 0xff, 0xff)
//	(t: 0x02, l: 0x01, v: 0x01)
//
// The byte stream must be canonical: every record type is unique and records
// appear in ascending order of type. ErrStreamNotCanonical is returned
// otherwise.
//
// An io.EOF is tolerated only while reading a record's type, where it signals
// that the previous record was the last one. An io.EOF at any later point of a
// record is reported as io.ErrUnexpectedEOF.
func (s *Stream) Decode(r io.Reader) error {
	// No TypeSet is supplied, so parsed types are not tracked.
	if _, err := s.decode(r, nil); err != nil {
		return err
	}

	return nil
}
// DecodeWithParsedTypes decodes exactly as Decode does and, on success,
// additionally returns a TypeSet holding the type of every record that was
// either decoded or ignored. On failure the returned TypeSet is nil.
func (s *Stream) DecodeWithParsedTypes(r io.Reader) (TypeSet, error) {
	// Supplying a non-nil set is what switches on type tracking.
	parsed := make(TypeSet)

	return s.decode(r, parsed)
}
// decode implements the core of stream decoding shared by Decode and
// DecodeWithParsedTypes. When parsedTypes is non-nil, the type of every record
// that is decoded or skipped is added to it and it is returned on success.
// When parsedTypes is nil nothing is tracked and the returned TypeSet is nil.
// A nil TypeSet is always returned alongside a non-nil error.
func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
	var (
		// nextMin is the smallest type the next record may carry.
		nextMin Type

		// searchIdx is where the next lookup into the known records
		// resumes.
		searchIdx int

		// sawMax is set once the largest possible type has been read,
		// after which no further record is acceptable.
		sawMax bool
	)

	for {
		// Read the next record's type. Running out of input exactly
		// here means the previous record ended the stream cleanly.
		rawType, err := ReadVarInt(r, &s.buf)
		if err == io.EOF {
			return parsedTypes, nil
		}
		if err != nil {
			return nil, err
		}
		typ := Type(rawType)

		// Types must be strictly increasing, which rules out both
		// duplicated and unsorted records. Anything following the
		// maximum type is rejected as well, since nextMin has wrapped.
		if sawMax || typ < nextMin {
			return nil, ErrStreamNotCanonical
		}

		// Read the record's length. From here on an EOF means the
		// record was truncated.
		length, err := ReadVarInt(r, &s.buf)
		if err == io.EOF {
			return nil, io.ErrUnexpectedEOF
		}
		if err != nil {
			return nil, err
		}

		// Refuse records above the size limit before reading the value,
		// so that a hostile encoder cannot make us allocate an
		// unbounded amount of memory for a variable-sized field.
		if length > MaxRecordSize {
			return nil, ErrRecordTooLarge
		}

		// Look the type up among the known records, resuming from
		// where the previous lookup stopped.
		rec, nextIdx, known := s.getRecord(typ, searchIdx)

		// Unknown even types are required to be understood, so they
		// abort decoding without touching the value bytes.
		if !known && typ%2 == 0 {
			return nil, ErrUnknownRequiredType(typ)
		}

		// Known types are handed to their decoder along with the
		// declared length. Unknown odd types are skipped by discarding
		// length bytes.
		if known {
			err = rec.decoder(r, rec.value, &s.buf, length)
		} else {
			_, err = io.CopyN(ioutil.Discard, r, int64(length))
		}
		if err == io.EOF {
			return nil, io.ErrUnexpectedEOF
		}
		if err != nil {
			return nil, err
		}

		// Note the type as handled if the caller asked for tracking.
		if parsedTypes != nil {
			parsedTypes[typ] = struct{}{}
		}

		searchIdx = nextIdx

		// After the largest possible type, typ + 1 wraps to zero. The
		// loop still has to run once more to confirm that the stream
		// is empty, so remember the overflow explicitly.
		if typ == math.MaxUint64 {
			sawMax = true
		}
		nextMin = typ + 1
	}
}
// getRecord looks up the record of the given type among the records known to
// the stream, scanning forward from idx. It returns the matching record, the
// index from which the following call should resume, and whether a match was
// found. The first call should always pass an idx of 0.
//
// On a match the returned index points just past the matching record. When
// there is no match it points at the first record whose type is larger than
// typ, or equals len(s.records) if the known records are exhausted.
func (s *Stream) getRecord(typ Type, idx int) (Record, int, bool) {
	// Step over every known record whose type is below the target.
	for idx < len(s.records) && s.records[idx].typ < typ {
		idx++
	}

	// The scan stopped either at the end of the records or at the first
	// record whose type is not below the target; only an exact match
	// counts as found.
	if idx < len(s.records) && s.records[idx].typ == typ {
		return s.records[idx], idx + 1, true
	}

	return Record{}, idx, false
}