aefec9b10f
This commit adds an additional return value to Stream.DecodeWithParsedTypes, which returns the set of types that were encountered during decoding. The set will contain all known types that were decoded, as well as unknown odd types that were ignored. The rationale for the return value (rather than an internal member) is so that the stream remains stateless. This return value can be used by callers during decoding to make assertions as to whether specific types were included in the stream. This is needed, for example, when parsing onion payloads where certain fields must be included/omitted depending on the hop type. The original Decode method would incur the additional performance hit of needing to track the parsed types, so we can selectively enable this functionality when a decoder requires it by using a helper which conditionally tracks the parsed types.
234 lines
5.7 KiB
Go
234 lines
5.7 KiB
Go
package tlv
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"sort"
|
|
|
|
"github.com/btcsuite/btcd/btcec"
|
|
)
|
|
|
|
// Type is a 64-bit identifier for a TLV Record.
|
|
type Type uint64
|
|
|
|
// TypeSet is an unordered set of Types. The empty-struct values carry no
// data; only the presence of a key is meaningful.
|
|
type TypeSet map[Type]struct{}
|
|
|
|
// Encoder is a signature for methods that can encode TLV values. An error
|
|
// should be returned if the Encoder cannot support the underlying type of val.
|
|
// The provided scratch buffer must be non-nil.
//
// w is the destination writer, val is the value to encode, and buf is the
// caller-supplied 8-byte scratch buffer.
|
|
type Encoder func(w io.Writer, val interface{}, buf *[8]byte) error
|
|
|
|
// Decoder is a signature for methods that can decode TLV values. An error
|
|
// should be returned if the Decoder cannot support the underlying type of val.
|
|
// The provided scratch buffer must be non-nil.
//
// r is the source reader, val is the destination for the decoded value, and
// buf is the caller-supplied 8-byte scratch buffer.
// NOTE(review): l is presumably the encoded length of the value being read;
// confirm against the stream decoder that invokes Decoders.
|
|
type Decoder func(r io.Reader, val interface{}, buf *[8]byte, l uint64) error
|
|
|
|
// ENOP is a no-op Encoder: it ignores all of its arguments, leaves the
// io.Writer untouched, and always reports success.
func ENOP(_ io.Writer, _ interface{}, _ *[8]byte) error {
	return nil
}
|
|
|
|
// DNOP is a no-op Decoder: it ignores all of its arguments, reads nothing
// from the io.Reader, leaves the destination value untouched, and always
// reports success.
func DNOP(_ io.Reader, _ interface{}, _ *[8]byte, _ uint64) error {
	return nil
}
|
|
|
|
// SizeFunc is a function that can compute the length of a given field. Since
|
|
// the size of the underlying field can change, this allows the size of the
|
|
// field to be evaluated at the time of encoding.
//
// When a Record carries a SizeFunc, Record.Size calls it instead of returning
// the static size.
|
|
type SizeFunc func() uint64
|
|
|
|
// SizeVarBytes returns a SizeFunc that can compute the length of a byte slice.
|
|
func SizeVarBytes(e *[]byte) SizeFunc {
|
|
return func() uint64 {
|
|
return uint64(len(*e))
|
|
}
|
|
}
|
|
|
|
// Record holds the required information to encode or decode a TLV record.
|
|
type Record struct {
|
|
// value is handed to the encoder as its val argument when the record is
// encoded.
value interface{}
|
|
// typ is the TLV type returned by Type.
typ Type
|
|
// staticSize is the size returned by Size when sizeFunc is nil.
staticSize uint64
|
|
// sizeFunc, when non-nil, is called by Size in place of staticSize.
sizeFunc SizeFunc
|
|
// encoder is invoked by Encode to serialize value.
encoder Encoder
|
|
// decoder is the record's Decoder. It may be nil for records that are
// only ever encoded (MapToRecords builds such records).
decoder Decoder
|
|
}
|
|
|
|
// Size returns the size of the Record's value. If no static size is known, the
|
|
// dynamic size will be evaluated.
|
|
func (f *Record) Size() uint64 {
|
|
if f.sizeFunc == nil {
|
|
return f.staticSize
|
|
}
|
|
|
|
return f.sizeFunc()
|
|
}
|
|
|
|
// Type returns the type of the underlying TLV record.
|
|
func (f *Record) Type() Type {
|
|
return f.typ
|
|
}
|
|
|
|
// Encode writes out the TLV record to the passed writer. This is useful when a
|
|
// caller wants to obtain the raw encoding of a *single* TLV record, outside
|
|
// the context of the Stream struct.
|
|
func (f *Record) Encode(w io.Writer) error {
|
|
var b [8]byte
|
|
|
|
return f.encoder(w, f.value, &b)
|
|
}
|
|
|
|
// MakePrimitiveRecord creates a record for common types.
|
|
func MakePrimitiveRecord(typ Type, val interface{}) Record {
|
|
var (
|
|
staticSize uint64
|
|
sizeFunc SizeFunc
|
|
encoder Encoder
|
|
decoder Decoder
|
|
)
|
|
switch e := val.(type) {
|
|
case *uint8:
|
|
staticSize = 1
|
|
encoder = EUint8
|
|
decoder = DUint8
|
|
|
|
case *uint16:
|
|
staticSize = 2
|
|
encoder = EUint16
|
|
decoder = DUint16
|
|
|
|
case *uint32:
|
|
staticSize = 4
|
|
encoder = EUint32
|
|
decoder = DUint32
|
|
|
|
case *uint64:
|
|
staticSize = 8
|
|
encoder = EUint64
|
|
decoder = DUint64
|
|
|
|
case *[32]byte:
|
|
staticSize = 32
|
|
encoder = EBytes32
|
|
decoder = DBytes32
|
|
|
|
case *[33]byte:
|
|
staticSize = 33
|
|
encoder = EBytes33
|
|
decoder = DBytes33
|
|
|
|
case **btcec.PublicKey:
|
|
staticSize = 33
|
|
encoder = EPubKey
|
|
decoder = DPubKey
|
|
|
|
case *[64]byte:
|
|
staticSize = 64
|
|
encoder = EBytes64
|
|
decoder = DBytes64
|
|
|
|
case *[]byte:
|
|
sizeFunc = SizeVarBytes(e)
|
|
encoder = EVarBytes
|
|
decoder = DVarBytes
|
|
|
|
default:
|
|
panic(fmt.Sprintf("unknown primitive type: %T", val))
|
|
}
|
|
|
|
return Record{
|
|
value: val,
|
|
typ: typ,
|
|
staticSize: staticSize,
|
|
sizeFunc: sizeFunc,
|
|
encoder: encoder,
|
|
decoder: decoder,
|
|
}
|
|
}
|
|
|
|
// MakeStaticRecord creates a record for a field of fixed-size
|
|
func MakeStaticRecord(typ Type, val interface{}, size uint64, encoder Encoder,
|
|
decoder Decoder) Record {
|
|
|
|
return Record{
|
|
value: val,
|
|
typ: typ,
|
|
staticSize: size,
|
|
encoder: encoder,
|
|
decoder: decoder,
|
|
}
|
|
}
|
|
|
|
// MakeDynamicRecord creates a record whose size may vary, and will be
|
|
// determined at the time of encoding via sizeFunc.
|
|
func MakeDynamicRecord(typ Type, val interface{}, sizeFunc SizeFunc,
|
|
encoder Encoder, decoder Decoder) Record {
|
|
|
|
return Record{
|
|
value: val,
|
|
typ: typ,
|
|
sizeFunc: sizeFunc,
|
|
encoder: encoder,
|
|
decoder: decoder,
|
|
}
|
|
}
|
|
|
|
// RecordsToMap encodes a series of TLV records as raw key-value pairs in the
|
|
// form of a map.
|
|
func RecordsToMap(records []Record) (map[uint64][]byte, error) {
|
|
tlvMap := make(map[uint64][]byte, len(records))
|
|
|
|
for _, record := range records {
|
|
var b bytes.Buffer
|
|
if err := record.Encode(&b); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
tlvMap[uint64(record.Type())] = b.Bytes()
|
|
}
|
|
|
|
return tlvMap, nil
|
|
}
|
|
|
|
// StubEncoder is a factory function that makes a stub tlv.Encoder out of a raw
|
|
// value. We can use this to make a record that can be encoded when we don't
|
|
// actually know it's true underlying value, and only it serialization.
|
|
func StubEncoder(v []byte) Encoder {
|
|
return func(w io.Writer, val interface{}, buf *[8]byte) error {
|
|
_, err := w.Write(v)
|
|
return err
|
|
}
|
|
}
|
|
|
|
// MapToRecords encodes the passed TLV map as a series of regular tlv.Record
|
|
// instances. The resulting set of records will be returned in sorted order by
|
|
// their type.
|
|
func MapToRecords(tlvMap map[uint64][]byte) ([]Record, error) {
|
|
records := make([]Record, 0, len(tlvMap))
|
|
for k, v := range tlvMap {
|
|
// We don't pass in a decoder here since we don't actually know
|
|
// the type, and only expect this Record to be used for display
|
|
// and encoding purposes.
|
|
record := MakeStaticRecord(
|
|
Type(k), nil, uint64(len(v)), StubEncoder(v), nil,
|
|
)
|
|
|
|
records = append(records, record)
|
|
}
|
|
|
|
SortRecords(records)
|
|
|
|
return records, nil
|
|
}
|
|
|
|
// SortRecords is a helper function that will sort a slice of records in place
|
|
// according to their type.
|
|
func SortRecords(records []Record) {
|
|
if len(records) == 0 {
|
|
return
|
|
}
|
|
|
|
sort.Slice(records, func(i, j int) bool {
|
|
return records[i].Type() < records[j].Type()
|
|
})
|
|
}
|