lnd.xprv/tlv/record.go

231 lines
5.6 KiB
Go
Raw Normal View History

package tlv
import (
"bytes"
"fmt"
"io"
"sort"
"github.com/btcsuite/btcd/btcec"
)
// Type is a 64-bit identifier for a TLV Record.
type Type uint64
// Encoder is a signature for methods that can encode TLV values. An error
// should be returned if the Encoder cannot support the underlying type of val.
// The provided scratch buffer must be non-nil.
//
// w is the destination the encoding is written to, val is the value to be
// encoded, and buf is an 8-byte scratch buffer supplied by the caller.
type Encoder func(w io.Writer, val interface{}, buf *[8]byte) error
// Decoder is a signature for methods that can decode TLV values. An error
// should be returned if the Decoder cannot support the underlying type of val.
// The provided scratch buffer must be non-nil.
//
// r is the source the encoding is read from, val identifies the value being
// decoded, and buf is an 8-byte scratch buffer supplied by the caller.
// NOTE(review): l is presumably the length in bytes of the encoded value --
// confirm against the decoder implementations.
type Decoder func(r io.Reader, val interface{}, buf *[8]byte, l uint64) error
// ENOP is a no-op encoder: it ignores every argument, leaves the io.Writer
// untouched, and always reports success by returning a nil error.
func ENOP(_ io.Writer, _ interface{}, _ *[8]byte) error {
	return nil
}
// DNOP is a no-op decoder: it ignores every argument, leaves the io.Reader
// untouched, and always reports success by returning a nil error.
func DNOP(_ io.Reader, _ interface{}, _ *[8]byte, _ uint64) error {
	return nil
}
// SizeFunc is a function that can compute the length of a given field. Since
// the size of the underlying field can change, this allows the size of the
// field to be evaluated at the time of encoding.
//
// It takes no arguments and returns the length as a uint64. Record.Size calls
// it whenever a Record has one set.
type SizeFunc func() uint64
// SizeVarBytes builds a SizeFunc bound to the byte slice that e points to. The
// slice is read through the pointer on every call, so the returned function
// reports the slice's current length rather than its length at the time
// SizeVarBytes was called.
func SizeVarBytes(e *[]byte) SizeFunc {
	size := func() uint64 {
		current := *e
		return uint64(len(current))
	}
	return size
}
// Record holds the required information to encode or decode a TLV record.
type Record struct {
// value is the opaque value handed to the encoder when Encode is called.
value interface{}
// typ is the TLV type identifier reported by Type.
typ Type
// staticSize is the size reported by Size when sizeFunc is nil.
staticSize uint64
// sizeFunc, when non-nil, is called by Size to compute the size on demand.
sizeFunc SizeFunc
// encoder is invoked by Encode to write value to the target writer.
encoder Encoder
// decoder is the Decoder associated with the record. It may be nil:
// MapToRecords builds records without one.
decoder Decoder
}
// Size reports the length of the Record's value. When the record carries a
// sizeFunc the length is computed on demand by calling it; otherwise the fixed
// staticSize is returned.
func (f *Record) Size() uint64 {
	if dynamic := f.sizeFunc; dynamic != nil {
		return dynamic()
	}
	return f.staticSize
}
// Type returns the type of the underlying TLV record, i.e. the identifier
// stored in the record's typ field.
func (f *Record) Type() Type {
return f.typ
}
// Encode serializes this record's value to w by invoking the record's encoder
// with a freshly allocated 8-byte scratch buffer, and returns whatever error
// the encoder reports. It lets a caller obtain the raw encoding of a *single*
// TLV record without going through the Stream struct.
func (f *Record) Encode(w io.Writer) error {
	var scratch [8]byte
	encode := f.encoder
	return encode(w, f.value, &scratch)
}
// MakePrimitiveRecord creates a record for common types, selecting the size
// and the encoder/decoder pair from the dynamic type of val.
//
// Supported types for val are *uint8, *uint16, *uint32, *uint64, *[32]byte,
// *[33]byte, **btcec.PublicKey, *[64]byte and *[]byte. All of them get a
// static size except *[]byte, whose size is computed at encoding time via
// SizeVarBytes. The function panics if val has any other type.
func MakePrimitiveRecord(typ Type, val interface{}) Record {
	rec := Record{
		value: val,
		typ:   typ,
	}

	switch v := val.(type) {
	case *uint8:
		rec.staticSize, rec.encoder, rec.decoder = 1, EUint8, DUint8

	case *uint16:
		rec.staticSize, rec.encoder, rec.decoder = 2, EUint16, DUint16

	case *uint32:
		rec.staticSize, rec.encoder, rec.decoder = 4, EUint32, DUint32

	case *uint64:
		rec.staticSize, rec.encoder, rec.decoder = 8, EUint64, DUint64

	case *[32]byte:
		rec.staticSize, rec.encoder, rec.decoder = 32, EBytes32, DBytes32

	case *[33]byte:
		rec.staticSize, rec.encoder, rec.decoder = 33, EBytes33, DBytes33

	case **btcec.PublicKey:
		rec.staticSize, rec.encoder, rec.decoder = 33, EPubKey, DPubKey

	case *[64]byte:
		rec.staticSize, rec.encoder, rec.decoder = 64, EBytes64, DBytes64

	case *[]byte:
		// Variable-length value: the size is looked up when encoding.
		rec.sizeFunc = SizeVarBytes(v)
		rec.encoder, rec.decoder = EVarBytes, DVarBytes

	default:
		panic(fmt.Sprintf("unknown primitive type: %T", val))
	}

	return rec
}
// MakeStaticRecord creates a record for a field of fixed size. The returned
// Record reports size from Size unconditionally, because no sizeFunc is set.
// typ, val, encoder and decoder are stored in the record unchanged.
func MakeStaticRecord(typ Type, val interface{}, size uint64, encoder Encoder,
	decoder Decoder) Record {

	var rec Record
	rec.typ = typ
	rec.value = val
	rec.staticSize = size
	rec.encoder = encoder
	rec.decoder = decoder
	return rec
}
// MakeDynamicRecord creates a record whose size may vary. No static size is
// set; instead the given sizeFunc is stored so that the size can be determined
// at the time of encoding. typ, val, encoder and decoder are stored in the
// record unchanged.
func MakeDynamicRecord(typ Type, val interface{}, sizeFunc SizeFunc,
	encoder Encoder, decoder Decoder) Record {

	var rec Record
	rec.typ = typ
	rec.value = val
	rec.sizeFunc = sizeFunc
	rec.encoder = encoder
	rec.decoder = decoder
	return rec
}
// RecordsToMap encodes a series of TLV records as raw key-value pairs in the
// form of a map: each key is a record's type as a uint64 and each value is the
// byte encoding produced by that record's Encode method.
//
// The first encoding error aborts the conversion and is returned together
// with a nil map. If several records share a type, the last one encountered
// is the one kept in the map.
func RecordsToMap(records []Record) (map[uint64][]byte, error) {
	encoded := make(map[uint64][]byte, len(records))

	for i := range records {
		rec := records[i]

		// Each record gets its own buffer so the stored slices never
		// share backing memory.
		buf := new(bytes.Buffer)
		err := rec.Encode(buf)
		if err != nil {
			return nil, err
		}

		key := uint64(rec.Type())
		encoded[key] = buf.Bytes()
	}

	return encoded, nil
}
// StubEncoder is a factory function that makes a stub tlv.Encoder out of a raw
// value. We can use this to make a record that can be encoded when we don't
// actually know its true underlying value, and only its serialization.
//
// The returned Encoder ignores its val and buf arguments and simply writes v
// to the supplied writer, returning any error from that write.
func StubEncoder(v []byte) Encoder {
	return func(w io.Writer, _ interface{}, _ *[8]byte) error {
		if _, err := w.Write(v); err != nil {
			return err
		}
		return nil
	}
}
// MapToRecords converts the passed TLV map into a series of regular tlv.Record
// instances, one per map entry, and returns them sorted by type. Each record
// has a static size equal to the length of its raw value and an encoder that
// writes that raw value back out. The returned error is always nil.
func MapToRecords(tlvMap map[uint64][]byte) ([]Record, error) {
	out := make([]Record, 0, len(tlvMap))

	for rawType, rawValue := range tlvMap {
		size := uint64(len(rawValue))

		// The value and decoder are left nil: the real type behind the
		// raw bytes is unknown, so these records are only expected to
		// be used for display and encoding purposes.
		out = append(out, MakeStaticRecord(
			Type(rawType), nil, size, StubEncoder(rawValue), nil,
		))
	}

	// Map iteration order is random, so impose a stable order by type.
	SortRecords(out)

	return out, nil
}
// SortRecords is a helper function that reorders the given slice of records
// in place so that their types are in ascending order.
func SortRecords(records []Record) {
	// Zero or one record is already in order.
	if len(records) < 2 {
		return
	}

	byType := func(a, b int) bool {
		return records[a].Type() < records[b].Type()
	}
	sort.Slice(records, byType)
}