Merge pull request #3061 from cfromknecht/wire-tlv

tlv: add library for new message/payload serialization format
This commit is contained in:
Olaoluwa Osuntokun 2019-08-07 15:51:26 -07:00 committed by GitHub
commit ea77ff91c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 1968 additions and 0 deletions

161
tlv/bench_test.go Normal file

@ -0,0 +1,161 @@
package tlv_test
import (
"bytes"
"io"
"io/ioutil"
"testing"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/tlv"
"github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtwire"
)
// CreateSessionTLV mirrors the wtwire.CreateSession message, but uses TLV for
// encoding/decoding.
type CreateSessionTLV struct {
BlobType blob.Type
MaxUpdates uint16
RewardBase uint32
RewardRate uint32
SweepFeeRate lnwallet.SatPerKWeight
tlvStream *tlv.Stream
}
// EBlobType is an encoder for blob.Type.
func EBlobType(w io.Writer, val interface{}, buf *[8]byte) error {
if t, ok := val.(*blob.Type); ok {
return tlv.EUint16T(w, uint16(*t), buf)
}
return tlv.NewTypeForEncodingErr(val, "blob.Type")
}
// EBlobType is an decoder for blob.Type.
func DBlobType(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if typ, ok := val.(*blob.Type); ok {
var t uint16
err := tlv.DUint16(r, &t, buf, l)
if err != nil {
return err
}
*typ = blob.Type(t)
return nil
}
return tlv.NewTypeForDecodingErr(val, "blob.Type", l, 2)
}
// ESatPerKW is an encoder for lnwallet.SatPerKWeight.
func ESatPerKW(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*lnwallet.SatPerKWeight); ok {
return tlv.EUint64(w, uint64(*v), buf)
}
return tlv.NewTypeForEncodingErr(val, "lnwallet.SatPerKWeight")
}
// DSatPerKW is an decoder for lnwallet.SatPerKWeight.
func DSatPerKW(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if v, ok := val.(*lnwallet.SatPerKWeight); ok {
var sat uint64
err := tlv.DUint64(r, &sat, buf, l)
if err != nil {
return err
}
*v = lnwallet.SatPerKWeight(sat)
return nil
}
return tlv.NewTypeForDecodingErr(val, "lnwallet.SatPerKWeight", l, 8)
}
// NewCreateSessionTLV initializes a new CreateSessionTLV message.
func NewCreateSessionTLV() *CreateSessionTLV {
m := &CreateSessionTLV{}
m.tlvStream = tlv.MustNewStream(
tlv.MakeStaticRecord(0, &m.BlobType, 2, EBlobType, DBlobType),
tlv.MakePrimitiveRecord(1, &m.MaxUpdates),
tlv.MakePrimitiveRecord(2, &m.RewardBase),
tlv.MakePrimitiveRecord(3, &m.RewardRate),
tlv.MakeStaticRecord(4, &m.SweepFeeRate, 8, ESatPerKW, DSatPerKW),
)
return m
}
// Encode writes the CreateSessionTLV to the passed io.Writer.
func (c *CreateSessionTLV) Encode(w io.Writer) error {
return c.tlvStream.Encode(w)
}
// Decode reads the CreateSessionTLV from the passed io.Reader.
func (c *CreateSessionTLV) Decode(r io.Reader) error {
return c.tlvStream.Decode(r)
}
// BenchmarkEncodeCreateSession benchmarks encoding of the non-TLV
// CreateSession.
func BenchmarkEncodeCreateSession(t *testing.B) {
m := &wtwire.CreateSession{}
t.ReportAllocs()
t.ResetTimer()
var err error
for i := 0; i < t.N; i++ {
err = m.Encode(ioutil.Discard, 0)
}
_ = err
}
// BenchmarkEncodeCreateSessionTLV benchmarks encoding of the TLV CreateSession.
func BenchmarkEncodeCreateSessionTLV(t *testing.B) {
m := NewCreateSessionTLV()
t.ReportAllocs()
t.ResetTimer()
var err error
for i := 0; i < t.N; i++ {
err = m.Encode(ioutil.Discard)
}
_ = err
}
// BenchmarkDecodeCreateSession benchmarks encoding of the non-TLV
// CreateSession.
func BenchmarkDecodeCreateSession(t *testing.B) {
m := &wtwire.CreateSession{}
var b bytes.Buffer
m.Encode(&b, 0)
r := bytes.NewReader(b.Bytes())
t.ReportAllocs()
t.ResetTimer()
var err error
for i := 0; i < t.N; i++ {
r.Seek(0, 0)
err = m.Decode(r, 0)
}
_ = err
}
// BenchmarkDecodeCreateSessionTLV benchmarks decoding of the TLV CreateSession.
func BenchmarkDecodeCreateSessionTLV(t *testing.B) {
m := NewCreateSessionTLV()
var b bytes.Buffer
var err error
m.Encode(&b)
r := bytes.NewReader(b.Bytes())
t.ReportAllocs()
t.ResetTimer()
for i := 0; i < t.N; i++ {
r.Seek(0, 0)
err = m.Decode(r)
}
_ = err
}

309
tlv/primitive.go Normal file

@ -0,0 +1,309 @@
package tlv
import (
"encoding/binary"
"fmt"
"io"
"github.com/btcsuite/btcd/btcec"
)
// ErrTypeForEncoding signals that an incorrect type was passed to an Encoder.
type ErrTypeForEncoding struct {
val interface{}
expType string
}
// NewTypeForEncodingErr creates a new ErrTypeForEncoding given the incorrect
// val and the expected type.
func NewTypeForEncodingErr(val interface{}, expType string) ErrTypeForEncoding {
return ErrTypeForEncoding{
val: val,
expType: expType,
}
}
// Error returns a human-readable description of the type mismatch.
func (e ErrTypeForEncoding) Error() string {
return fmt.Sprintf("ErrTypeForEncoding want (type: *%s), "+
"got (type: %T)", e.expType, e.val)
}
// ErrTypeForDecoding signals that an incorrect type was passed to a Decoder or
// that the expected length of the encoding is different from that required by
// the expected type.
type ErrTypeForDecoding struct {
val interface{}
expType string
valLength uint64
expLength uint64
}
// NewTypeForDecodingErr creates a new ErrTypeForDecoding given the incorrect
// val and expected type, or the mismatch in their expected lengths.
func NewTypeForDecodingErr(val interface{}, expType string,
valLength, expLength uint64) ErrTypeForDecoding {
return ErrTypeForDecoding{
val: val,
expType: expType,
valLength: valLength,
expLength: expLength,
}
}
// Error returns a human-readable description of the type mismatch.
func (e ErrTypeForDecoding) Error() string {
return fmt.Sprintf("ErrTypeForDecoding want (type: *%s, length: %v), "+
"got (type: %T, length: %v)", e.expType, e.expLength, e.val,
e.valLength)
}
var (
byteOrder = binary.BigEndian
)
// EUint8 is an Encoder for uint8 values. An error is returned if val is not a
// *uint8.
func EUint8(w io.Writer, val interface{}, buf *[8]byte) error {
if i, ok := val.(*uint8); ok {
buf[0] = *i
_, err := w.Write(buf[:1])
return err
}
return ErrTypeForEncoding{val, "uint8"}
}
// EUint8T encodes a uint8 val to the provided io.Writer. This method is exposed
// so that encodings for custom uint8-like types can be created without
// incurring an extra heap allocation.
func EUint8T(w io.Writer, val uint8, buf *[8]byte) error {
buf[0] = val
_, err := w.Write(buf[:1])
return err
}
// EUint16 is an Encoder for uint16 values. An error is returned if val is not a
// *uint16.
func EUint16(w io.Writer, val interface{}, buf *[8]byte) error {
if i, ok := val.(*uint16); ok {
byteOrder.PutUint16(buf[:2], *i)
_, err := w.Write(buf[:2])
return err
}
return ErrTypeForEncoding{val, "uint16"}
}
// EUint16T encodes a uint16 val to the provided io.Writer. This method is
// exposed so that encodings for custom uint16-like types can be created without
// incurring an extra heap allocation.
func EUint16T(w io.Writer, val uint16, buf *[8]byte) error {
byteOrder.PutUint16(buf[:2], val)
_, err := w.Write(buf[:2])
return err
}
// EUint32 is an Encoder for uint32 values. An error is returned if val is not a
// *uint32.
func EUint32(w io.Writer, val interface{}, buf *[8]byte) error {
if i, ok := val.(*uint32); ok {
byteOrder.PutUint32(buf[:4], *i)
_, err := w.Write(buf[:4])
return err
}
return ErrTypeForEncoding{val, "uint32"}
}
// EUint32T encodes a uint32 val to the provided io.Writer. This method is
// exposed so that encodings for custom uint32-like types can be created without
// incurring an extra heap allocation.
func EUint32T(w io.Writer, val uint32, buf *[8]byte) error {
byteOrder.PutUint32(buf[:4], val)
_, err := w.Write(buf[:4])
return err
}
// EUint64 is an Encoder for uint64 values. An error is returned if val is not a
// *uint64.
func EUint64(w io.Writer, val interface{}, buf *[8]byte) error {
if i, ok := val.(*uint64); ok {
byteOrder.PutUint64(buf[:], *i)
_, err := w.Write(buf[:])
return err
}
return ErrTypeForEncoding{val, "uint64"}
}
// EUint64T encodes a uint64 val to the provided io.Writer. This method is
// exposed so that encodings for custom uint64-like types can be created without
// incurring an extra heap allocation.
func EUint64T(w io.Writer, val uint64, buf *[8]byte) error {
byteOrder.PutUint64(buf[:], val)
_, err := w.Write(buf[:])
return err
}
// DUint8 is a Decoder for uint8 values. An error is returned if val is not a
// *uint8.
func DUint8(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if i, ok := val.(*uint8); ok && l == 1 {
if _, err := io.ReadFull(r, buf[:1]); err != nil {
return err
}
*i = buf[0]
return nil
}
return ErrTypeForDecoding{val, "uint8", l, 1}
}
// DUint16 is a Decoder for uint16 values. An error is returned if val is not a
// *uint16.
func DUint16(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if i, ok := val.(*uint16); ok && l == 2 {
if _, err := io.ReadFull(r, buf[:2]); err != nil {
return err
}
*i = byteOrder.Uint16(buf[:2])
return nil
}
return ErrTypeForDecoding{val, "uint16", l, 2}
}
// DUint32 is a Decoder for uint32 values. An error is returned if val is not a
// *uint32.
func DUint32(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if i, ok := val.(*uint32); ok && l == 4 {
if _, err := io.ReadFull(r, buf[:4]); err != nil {
return err
}
*i = byteOrder.Uint32(buf[:4])
return nil
}
return ErrTypeForDecoding{val, "uint32", l, 4}
}
// DUint64 is a Decoder for uint64 values. An error is returned if val is not a
// *uint64.
func DUint64(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if i, ok := val.(*uint64); ok && l == 8 {
if _, err := io.ReadFull(r, buf[:]); err != nil {
return err
}
*i = byteOrder.Uint64(buf[:])
return nil
}
return ErrTypeForDecoding{val, "uint64", l, 8}
}
// EBytes32 is an Encoder for 32-byte arrays. An error is returned if val is not
// a *[32]byte.
func EBytes32(w io.Writer, val interface{}, _ *[8]byte) error {
if b, ok := val.(*[32]byte); ok {
_, err := w.Write(b[:])
return err
}
return ErrTypeForEncoding{val, "[32]byte"}
}
// DBytes32 is a Decoder for 32-byte arrays. An error is returned if val is not
// a *[32]byte.
func DBytes32(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
if b, ok := val.(*[32]byte); ok && l == 32 {
_, err := io.ReadFull(r, b[:])
return err
}
return ErrTypeForDecoding{val, "[32]byte", l, 32}
}
// EBytes33 is an Encoder for 33-byte arrays. An error is returned if val is not
// a *[33]byte.
func EBytes33(w io.Writer, val interface{}, _ *[8]byte) error {
if b, ok := val.(*[33]byte); ok {
_, err := w.Write(b[:])
return err
}
return ErrTypeForEncoding{val, "[33]byte"}
}
// DBytes33 is a Decoder for 33-byte arrays. An error is returned if val is not
// a *[33]byte.
func DBytes33(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
if b, ok := val.(*[33]byte); ok {
_, err := io.ReadFull(r, b[:])
return err
}
return ErrTypeForDecoding{val, "[33]byte", l, 33}
}
// EBytes64 is an Encoder for 64-byte arrays. An error is returned if val is not
// a *[64]byte.
func EBytes64(w io.Writer, val interface{}, _ *[8]byte) error {
if b, ok := val.(*[64]byte); ok {
_, err := w.Write(b[:])
return err
}
return ErrTypeForEncoding{val, "[64]byte"}
}
// DBytes64 is an Decoder for 64-byte arrays. An error is returned if val is not
// a *[64]byte.
func DBytes64(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
if b, ok := val.(*[64]byte); ok && l == 64 {
_, err := io.ReadFull(r, b[:])
return err
}
return ErrTypeForDecoding{val, "[64]byte", l, 64}
}
// EPubKey is an Encoder for *btcec.PublicKey values. An error is returned if
// val is not a **btcec.PublicKey.
func EPubKey(w io.Writer, val interface{}, _ *[8]byte) error {
if pk, ok := val.(**btcec.PublicKey); ok {
_, err := w.Write((*pk).SerializeCompressed())
return err
}
return ErrTypeForEncoding{val, "*btcec.PublicKey"}
}
// DPubKey is a Decoder for *btcec.PublicKey values. An error is returned if val
// is not a **btcec.PublicKey.
func DPubKey(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
if pk, ok := val.(**btcec.PublicKey); ok && l == 33 {
var b [33]byte
_, err := io.ReadFull(r, b[:])
if err != nil {
return err
}
p, err := btcec.ParsePubKey(b[:], btcec.S256())
if err != nil {
return err
}
*pk = p
return nil
}
return ErrTypeForDecoding{val, "*btcec.PublicKey", l, 33}
}
// EVarBytes is an Encoder for variable byte slices. An error is returned if val
// is not *[]byte.
func EVarBytes(w io.Writer, val interface{}, _ *[8]byte) error {
if b, ok := val.(*[]byte); ok {
_, err := w.Write(*b)
return err
}
return ErrTypeForEncoding{val, "[]byte"}
}
// DVarBytes is a Decoder for variable byte slices. An error is returned if val
// is not *[]byte.
func DVarBytes(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
if b, ok := val.(*[]byte); ok {
*b = make([]byte, l)
_, err := io.ReadFull(r, *b)
return err
}
return ErrTypeForDecoding{val, "[]byte", l, l}
}

153
tlv/record.go Normal file

@ -0,0 +1,153 @@
package tlv
import (
"io"
"github.com/btcsuite/btcd/btcec"
)
// Type is an 64-bit identifier for a TLV Record.
type Type uint64
// Encoder is a signature for methods that can encode TLV values. An error
// should be returned if the Encoder cannot support the underlying type of val.
// The provided scratch buffer must be non-nil.
type Encoder func(w io.Writer, val interface{}, buf *[8]byte) error
// Decoder is a signature for methods that can decode TLV values. An error
// should be returned if the Decoder cannot support the underlying type of val.
// The provided scratch buffer must be non-nil.
type Decoder func(r io.Reader, val interface{}, buf *[8]byte, l uint64) error
// ENOP is an encoder that doesn't modify the io.Writer and never fails.
func ENOP(io.Writer, interface{}, *[8]byte) error { return nil }
// DNOP is an encoder that doesn't modify the io.Reader and never fails.
func DNOP(io.Reader, interface{}, *[8]byte, uint64) error { return nil }
// SizeFunc is a function that can compute the length of a given field. Since
// the size of the underlying field can change, this allows the size of the
// field to be evaluated at the time of encoding.
type SizeFunc func() uint64
// SizeVarBytes returns a SizeFunc that can compute the length of a byte slice.
func SizeVarBytes(e *[]byte) SizeFunc {
return func() uint64 {
return uint64(len(*e))
}
}
// Record holds the required information to encode or decode a TLV record.
type Record struct {
value interface{}
typ Type
staticSize uint64
sizeFunc SizeFunc
encoder Encoder
decoder Decoder
}
// Size returns the size of the Record's value. If no static size is known, the
// dynamic size will be evaluated.
func (f *Record) Size() uint64 {
if f.sizeFunc == nil {
return f.staticSize
}
return f.sizeFunc()
}
// MakePrimitiveRecord creates a record for common types.
func MakePrimitiveRecord(typ Type, val interface{}) Record {
var (
staticSize uint64
sizeFunc SizeFunc
encoder Encoder
decoder Decoder
)
switch e := val.(type) {
case *uint8:
staticSize = 1
encoder = EUint8
decoder = DUint8
case *uint16:
staticSize = 2
encoder = EUint16
decoder = DUint16
case *uint32:
staticSize = 4
encoder = EUint32
decoder = DUint32
case *uint64:
staticSize = 8
encoder = EUint64
decoder = DUint64
case *[32]byte:
staticSize = 32
encoder = EBytes32
decoder = DBytes32
case *[33]byte:
staticSize = 33
encoder = EBytes33
decoder = DBytes33
case **btcec.PublicKey:
staticSize = 33
encoder = EPubKey
decoder = DPubKey
case *[64]byte:
staticSize = 64
encoder = EBytes64
decoder = DBytes64
case *[]byte:
sizeFunc = SizeVarBytes(e)
encoder = EVarBytes
decoder = DVarBytes
default:
panic("unknown primitive type")
}
return Record{
value: val,
typ: typ,
staticSize: staticSize,
sizeFunc: sizeFunc,
encoder: encoder,
decoder: decoder,
}
}
// MakeStaticRecord creates a record for a field of fixed-size
func MakeStaticRecord(typ Type, val interface{}, size uint64, encoder Encoder,
decoder Decoder) Record {
return Record{
value: val,
typ: typ,
staticSize: size,
encoder: encoder,
decoder: decoder,
}
}
// MakeDynamicRecord creates a record whose size may vary, and will be
// determined at the time of encoding via sizeFunc.
func MakeDynamicRecord(typ Type, val interface{}, sizeFunc SizeFunc,
encoder Encoder, decoder Decoder) Record {
return Record{
value: val,
typ: typ,
sizeFunc: sizeFunc,
encoder: encoder,
decoder: decoder,
}
}

280
tlv/stream.go Normal file

@ -0,0 +1,280 @@
package tlv
import (
"errors"
"fmt"
"io"
"io/ioutil"
"math"
)
// ErrStreamNotCanonical signals that a decoded stream does not contain records
// sorting by monotonically-increasing type.
var ErrStreamNotCanonical = errors.New("tlv stream is not canonical")
// ErrUnknownRequiredType is an error returned when decoding an unknown and even
// type from a Stream.
type ErrUnknownRequiredType Type
// Error returns a human-readable description of unknown required type.
func (t ErrUnknownRequiredType) Error() string {
return fmt.Sprintf("unknown required type: %d", t)
}
// Stream defines a TLV stream that can be used for encoding or decoding a set
// of TLV Records.
type Stream struct {
records []Record
buf [8]byte
}
// NewStream creates a new TLV Stream given an encoding codec, a decoding codec,
// and a set of known records.
func NewStream(records ...Record) (*Stream, error) {
// Assert that the ordering of the Records is canonical and appear in
// ascending order of type.
var (
min Type
overflow bool
)
for _, record := range records {
if overflow || record.typ < min {
return nil, ErrStreamNotCanonical
}
if record.encoder == nil {
record.encoder = ENOP
}
if record.decoder == nil {
record.decoder = DNOP
}
if record.typ == math.MaxUint64 {
overflow = true
}
min = record.typ + 1
}
return &Stream{
records: records,
}, nil
}
// MustNewStream creates a new TLV Stream given an encoding codec, a decoding
// codec, and a set of known records. If an error is encountered in creating the
// stream, this method will panic instead of returning the error.
func MustNewStream(records ...Record) *Stream {
stream, err := NewStream(records...)
if err != nil {
panic(err.Error())
}
return stream
}
// Encode writes a Stream to the passed io.Writer. Each of the Records known to
// the Stream is written in ascending order of their type so as to be canonical.
//
// The stream is constructed by concatenating the individual, serialized Records
// where each record has the following format:
// [varint: type]
// [varint: length]
// [length: value]
//
// An error is returned if the io.Writer fails to accept bytes from the
// encoding, and nothing else. The ordering of the Records is asserted upon the
// creation of a Stream, and thus the output will be by definition canonical.
func (s *Stream) Encode(w io.Writer) error {
// Iterate through all known records, if any, serializing each record's
// type, length and value.
for i := range s.records {
rec := &s.records[i]
// Write the record's type as a varint.
err := WriteVarInt(w, uint64(rec.typ), &s.buf)
if err != nil {
return err
}
// Write the record's length as a varint.
err = WriteVarInt(w, rec.Size(), &s.buf)
if err != nil {
return err
}
// Encode the current record's value using the stream's codec.
err = rec.encoder(w, rec.value, &s.buf)
if err != nil {
return err
}
}
return nil
}
// Decode deserializes TLV Stream from the passed io.Reader. The Stream will
// inspect each record that is parsed and check to see if it has a corresponding
// Record to facilitate deserialization of that field. If the record is unknown,
// the Stream will discard the record's bytes and proceed to the subsequent
// record.
//
// Each record has the following format:
// [varint: type]
// [varint: length]
// [length: value]
//
// A series of (possibly zero) records are concatenated into a stream, this
// example contains two records:
//
// (t: 0x01, l: 0x04, v: 0xff, 0xff, 0xff, 0xff)
// (t: 0x02, l: 0x01, v: 0x01)
//
// This method asserts that the byte stream is canonical, namely that each
// record is unique and that all records are sorted in ascending order. An
// ErrNotCanonicalStream error is returned if the encoded TLV stream is not.
//
// We permit an io.EOF error only when reading the type byte which signals that
// the last record was read cleanly and we should stop parsing. All other io.EOF
// or io.ErrUnexpectedEOF errors are returned.
func (s *Stream) Decode(r io.Reader) error {
var (
typ Type
min Type
recordIdx int
overflow bool
)
// Iterate through all possible type identifiers. As types are read from
// the io.Reader, min will skip forward to the last read type.
for {
// Read the next varint type.
t, err := ReadVarInt(r, &s.buf)
switch {
// We'll silence an EOF when zero bytes remain, meaning the
// stream was cleanly encoded.
case err == io.EOF:
return nil
// Other unexpected errors.
case err != nil:
return err
}
typ = Type(t)
// Assert that this type is greater than any previously read.
// If we've already overflowed and we parsed another type, the
// stream is not canonical. This check prevents us from accepts
// encodings that have duplicate records or from accepting an
// unsorted series.
if overflow || typ < min {
return ErrStreamNotCanonical
}
// Read the varint length.
length, err := ReadVarInt(r, &s.buf)
switch {
// We'll convert any EOFs to ErrUnexpectedEOF, since this
// results in an invalid record.
case err == io.EOF:
return io.ErrUnexpectedEOF
// Other unexpected errors.
case err != nil:
return err
}
// Search the records known to the stream for this type. We'll
// begin the search and recordIdx and walk forward until we find
// it or the next record's type is larger.
rec, newIdx, ok := s.getRecord(typ, recordIdx)
switch {
// We know of this record type, proceed to decode the value.
// This method asserts that length bytes are read in the
// process, and returns an error if the number of bytes is not
// exactly length.
case ok:
err := rec.decoder(r, rec.value, &s.buf, length)
switch {
// We'll convert any EOFs to ErrUnexpectedEOF, since this
// results in an invalid record.
case err == io.EOF:
return io.ErrUnexpectedEOF
// Other unexpected errors.
case err != nil:
return err
}
// This record type is unknown to the stream, fail if the type
// is even meaning that we are required to understand it.
case typ%2 == 0:
return ErrUnknownRequiredType(typ)
// Otherwise, the record type is unknown and is odd, discard the
// number of bytes specified by length.
default:
_, err := io.CopyN(ioutil.Discard, r, int64(length))
switch {
// We'll convert any EOFs to ErrUnexpectedEOF, since this
// results in an invalid record.
case err == io.EOF:
return io.ErrUnexpectedEOF
// Other unexpected errors.
case err != nil:
return err
}
}
// Update our record index so that we can begin our next search
// from where we left off.
recordIdx = newIdx
// If we've parsed the largest possible type, the next loop will
// overflow back to zero. However, we need to attempt parsing
// the next type to ensure that the stream is empty.
if typ == math.MaxUint64 {
overflow = true
}
// Finally, set our lower bound on the next accepted type.
min = typ + 1
}
}
// getRecord searches for a record matching typ known to the stream. The boolean
// return value indicates whether the record is known to the stream. The integer
// return value carries the index from where getRecord should be invoked on the
// subsequent call. The first call to getRecord should always use an idx of 0.
func (s *Stream) getRecord(typ Type, idx int) (Record, int, bool) {
for idx < len(s.records) {
record := s.records[idx]
switch {
// Found target record, return it to the caller. The next index
// returned points to the immediately following record.
case record.typ == typ:
return record, idx + 1, true
// This record's type is lower than the target. Advance our
// index and continue to the next record which will have a
// strictly higher type.
case record.typ < typ:
idx++
continue
// This record's type is larger than the target, hence we have
// no record matching the current type. Return the current index
// so that we can start our search from here when processing the
// next tlv record.
default:
return Record{}, idx, false
}
}
// All known records are exhausted.
return Record{}, idx, false
}

559
tlv/tlv_test.go Normal file

@ -0,0 +1,559 @@
package tlv_test
import (
"bytes"
"errors"
"io"
"reflect"
"testing"
"github.com/btcsuite/btcd/btcec"
"github.com/lightningnetwork/lnd/tlv"
)
type nodeAmts struct {
nodeID *btcec.PublicKey
amt1 uint64
amt2 uint64
}
func ENodeAmts(w io.Writer, val interface{}, buf *[8]byte) error {
if t, ok := val.(*nodeAmts); ok {
if err := tlv.EPubKey(w, &t.nodeID, buf); err != nil {
return err
}
if err := tlv.EUint64T(w, t.amt1, buf); err != nil {
return err
}
return tlv.EUint64T(w, t.amt2, buf)
}
return tlv.NewTypeForEncodingErr(val, "nodeAmts")
}
func DNodeAmts(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if t, ok := val.(*nodeAmts); ok && l == 49 {
if err := tlv.DPubKey(r, &t.nodeID, buf, 33); err != nil {
return err
}
if err := tlv.DUint64(r, &t.amt1, buf, 8); err != nil {
return err
}
return tlv.DUint64(r, &t.amt2, buf, 8)
}
return tlv.NewTypeForDecodingErr(val, "nodeAmts", l, 49)
}
type N1 struct {
amt uint64
scid uint64
nodeAmts nodeAmts
cltvDelta uint16
stream *tlv.Stream
}
func (n *N1) sizeAmt() uint64 {
return tlv.SizeTUint64(n.amt)
}
func NewN1() *N1 {
n := new(N1)
n.stream = tlv.MustNewStream(
tlv.MakeDynamicRecord(
1, &n.amt, n.sizeAmt, tlv.ETUint64, tlv.DTUint64,
),
tlv.MakePrimitiveRecord(2, &n.scid),
tlv.MakeStaticRecord(3, &n.nodeAmts, 49, ENodeAmts, DNodeAmts),
tlv.MakePrimitiveRecord(254, &n.cltvDelta),
)
return n
}
func (n *N1) Encode(w io.Writer) error {
return n.stream.Encode(w)
}
func (n *N1) Decode(r io.Reader) error {
return n.stream.Decode(r)
}
type N2 struct {
amt uint64
cltvExpiry uint32
stream *tlv.Stream
}
func (n *N2) sizeAmt() uint64 {
return tlv.SizeTUint64(n.amt)
}
func (n *N2) sizeCltv() uint64 {
return tlv.SizeTUint32(n.cltvExpiry)
}
func NewN2() *N2 {
n := new(N2)
n.stream = tlv.MustNewStream(
tlv.MakeDynamicRecord(
0, &n.amt, n.sizeAmt, tlv.ETUint64, tlv.DTUint64,
),
tlv.MakeDynamicRecord(
11, &n.cltvExpiry, n.sizeCltv, tlv.ETUint32, tlv.DTUint32,
),
)
return n
}
func (n *N2) Encode(w io.Writer) error {
return n.stream.Encode(w)
}
func (n *N2) Decode(r io.Reader) error {
return n.stream.Decode(r)
}
var tlvDecodingFailureTests = []struct {
name string
bytes []byte
expErr error
// skipN2 if true, will cause the test to only be executed on N1.
skipN2 bool
}{
{
name: "type truncated",
bytes: []byte{0xfd},
expErr: io.ErrUnexpectedEOF,
},
{
name: "type truncated",
bytes: []byte{0xfd, 0x01},
expErr: io.ErrUnexpectedEOF,
},
{
name: "not minimally encoded type",
bytes: []byte{0xfd, 0x00, 0x01}, // spec has trailing 0x00
expErr: tlv.ErrVarIntNotCanonical,
},
{
name: "missing length",
bytes: []byte{0xfd, 0x01, 0x01},
expErr: io.ErrUnexpectedEOF,
},
{
name: "length truncated",
bytes: []byte{0x0f, 0xfd},
expErr: io.ErrUnexpectedEOF,
},
{
name: "length truncated",
bytes: []byte{0x0f, 0xfd, 0x26},
expErr: io.ErrUnexpectedEOF,
},
{
name: "missing value",
bytes: []byte{0x0f, 0xfd, 0x26, 0x02},
expErr: io.ErrUnexpectedEOF,
},
{
name: "not minimally encoded length",
bytes: []byte{0x0f, 0xfd, 0x00, 0x01}, // spec has trailing 0x00
expErr: tlv.ErrVarIntNotCanonical,
},
{
name: "value truncated",
bytes: []byte{0x0f, 0xfd, 0x02, 0x01,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
},
expErr: io.ErrUnexpectedEOF,
},
{
name: "unknown even type",
bytes: []byte{0x12, 0x00},
expErr: tlv.ErrUnknownRequiredType(0x12),
},
{
name: "unknown even type",
bytes: []byte{0xfd, 0x01, 0x02, 0x00},
expErr: tlv.ErrUnknownRequiredType(0x102),
},
{
name: "unknown even type",
bytes: []byte{0xfe, 0x01, 0x00, 0x00, 0x02, 0x00},
expErr: tlv.ErrUnknownRequiredType(0x01000002),
},
{
name: "unknown even type",
bytes: []byte{0xff, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00},
expErr: tlv.ErrUnknownRequiredType(0x0100000000000002),
},
{
name: "greater than encoding length for n1's amt",
bytes: []byte{0x01, 0x09, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
expErr: tlv.NewTypeForDecodingErr(new(uint64), "uint64", 9, 8),
skipN2: true,
},
{
name: "encoding for n1's amt is not minimal",
bytes: []byte{0x01, 0x01, 0x00},
expErr: tlv.ErrTUintNotMinimal,
skipN2: true,
},
{
name: "encoding for n1's amt is not minimal",
bytes: []byte{0x01, 0x02, 0x00, 0x01},
expErr: tlv.ErrTUintNotMinimal,
skipN2: true,
},
{
name: "encoding for n1's amt is not minimal",
bytes: []byte{0x01, 0x03, 0x00, 0x01, 0x00},
expErr: tlv.ErrTUintNotMinimal,
skipN2: true,
},
{
name: "encoding for n1's amt is not minimal",
bytes: []byte{0x01, 0x04, 0x00, 0x01, 0x00, 0x00},
expErr: tlv.ErrTUintNotMinimal,
skipN2: true,
},
{
name: "encoding for n1's amt is not minimal",
bytes: []byte{0x01, 0x05, 0x00, 0x01, 0x00, 0x00, 0x00},
expErr: tlv.ErrTUintNotMinimal,
skipN2: true,
},
{
name: "encoding for n1's amt is not minimal",
bytes: []byte{0x01, 0x06, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
expErr: tlv.ErrTUintNotMinimal,
skipN2: true,
},
{
name: "encoding for n1's amt is not minimal",
bytes: []byte{0x01, 0x07, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00},
expErr: tlv.ErrTUintNotMinimal,
skipN2: true,
},
{
name: "encoding for n1's amt is not minimal",
bytes: []byte{0x01, 0x08, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
expErr: tlv.ErrTUintNotMinimal,
skipN2: true,
},
{
name: "less than encoding length for n1's scid",
bytes: []byte{0x02, 0x07, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01},
expErr: tlv.NewTypeForDecodingErr(new(uint64), "uint64", 7, 8),
skipN2: true,
},
{
name: "less than encoding length for n1's scid",
bytes: []byte{0x02, 0x09, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01},
expErr: tlv.NewTypeForDecodingErr(new(uint64), "uint64", 9, 8),
skipN2: true,
},
{
name: "less than encoding length for n1's nodeAmts",
bytes: []byte{0x03, 0x29,
0x02, 0x3d, 0xa0, 0x92, 0xf6, 0x98, 0x0e, 0x58, 0xd2,
0xc0, 0x37, 0x17, 0x31, 0x80, 0xe9, 0xa4, 0x65, 0x47,
0x60, 0x26, 0xee, 0x50, 0xf9, 0x66, 0x95, 0x96, 0x3e,
0x8e, 0xfe, 0x43, 0x6f, 0x54, 0xeb, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x01,
},
expErr: tlv.NewTypeForDecodingErr(new(nodeAmts), "nodeAmts", 41, 49),
skipN2: true,
},
{
name: "less than encoding length for n1's nodeAmts",
bytes: []byte{0x03, 0x30,
0x02, 0x3d, 0xa0, 0x92, 0xf6, 0x98, 0x0e, 0x58, 0xd2,
0xc0, 0x37, 0x17, 0x31, 0x80, 0xe9, 0xa4, 0x65, 0x47,
0x60, 0x26, 0xee, 0x50, 0xf9, 0x66, 0x95, 0x96, 0x3e,
0x8e, 0xfe, 0x43, 0x6f, 0x54, 0xeb, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x01,
},
expErr: tlv.NewTypeForDecodingErr(new(nodeAmts), "nodeAmts", 48, 49),
skipN2: true,
},
{
name: "n1's node_id is not a valid point",
bytes: []byte{0x03, 0x31,
0x04, 0x3d, 0xa0, 0x92, 0xf6, 0x98, 0x0e, 0x58, 0xd2,
0xc0, 0x37, 0x17, 0x31, 0x80, 0xe9, 0xa4, 0x65, 0x47,
0x60, 0x26, 0xee, 0x50, 0xf9, 0x66, 0x95, 0x96, 0x3e,
0x8e, 0xfe, 0x43, 0x6f, 0x54, 0xeb, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x02,
},
expErr: errors.New("invalid magic in compressed pubkey string: 4"),
skipN2: true,
},
{
name: "greater than encoding length for n1's nodeAmts",
bytes: []byte{0x03, 0x32,
0x02, 0x3d, 0xa0, 0x92, 0xf6, 0x98, 0x0e, 0x58, 0xd2,
0xc0, 0x37, 0x17, 0x31, 0x80, 0xe9, 0xa4, 0x65, 0x47,
0x60, 0x26, 0xee, 0x50, 0xf9, 0x66, 0x95, 0x96, 0x3e,
0x8e, 0xfe, 0x43, 0x6f, 0x54, 0xeb, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
},
expErr: tlv.NewTypeForDecodingErr(new(nodeAmts), "nodeAmts", 50, 49),
skipN2: true,
},
{
name: "unknown required type or n1",
bytes: []byte{0x00, 0x00},
expErr: tlv.ErrUnknownRequiredType(0x00),
skipN2: true,
},
{
name: "less than encoding length for n1's cltvDelta",
bytes: []byte{0xfd, 0x00, 0x0fe, 0x00},
expErr: tlv.NewTypeForDecodingErr(new(uint16), "uint16", 0, 2),
skipN2: true,
},
{
name: "less than encoding length for n1's cltvDelta",
bytes: []byte{0xfd, 0x00, 0xfe, 0x01, 0x01},
expErr: tlv.NewTypeForDecodingErr(new(uint16), "uint16", 1, 2),
skipN2: true,
},
{
name: "greater than encoding length for n1's cltvDelta",
bytes: []byte{0xfd, 0x00, 0xfe, 0x03, 0x01, 0x01, 0x01},
expErr: tlv.NewTypeForDecodingErr(new(uint16), "uint16", 3, 2),
skipN2: true,
},
{
name: "unknown even field for n1's namespace",
bytes: []byte{0x0a, 0x00},
expErr: tlv.ErrUnknownRequiredType(0x0a),
skipN2: true,
},
{
name: "valid records but invalid ordering",
bytes: []byte{0x02, 0x08,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x26, 0x01,
0x01, 0x2a,
},
expErr: tlv.ErrStreamNotCanonical,
skipN2: true,
},
{
name: "duplicate tlv type",
bytes: []byte{0x02, 0x08,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x31, 0x02,
0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x51,
},
expErr: tlv.ErrStreamNotCanonical,
skipN2: true,
},
{
name: "duplicate ignored tlv type",
bytes: []byte{0x1f, 0x00, 0x1f, 0x01, 0x2a},
expErr: tlv.ErrStreamNotCanonical,
skipN2: true,
},
{
name: "type wraparound",
bytes: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00},
expErr: tlv.ErrStreamNotCanonical,
},
}
// TestTLVDecodingSuccess asserts that the TLV parser fails to decode invalid
// TLV streams.
func TestTLVDecodingFailures(t *testing.T) {
for _, test := range tlvDecodingFailureTests {
t.Run(test.name, func(t *testing.T) {
n1 := NewN1()
r := bytes.NewReader(test.bytes)
err := n1.Decode(r)
if !reflect.DeepEqual(err, test.expErr) {
t.Fatalf("expected N1 decoding failure: %v, "+
"got: %v", test.expErr, err)
}
if test.skipN2 {
return
}
n2 := NewN2()
r = bytes.NewReader(test.bytes)
err = n2.Decode(r)
if !reflect.DeepEqual(err, test.expErr) {
t.Fatalf("expected N2 decoding failure: %v, "+
"got: %v", test.expErr, err)
}
})
}
}
var tlvDecodingSuccessTests = []struct {
name string
bytes []byte
skipN2 bool
}{
{
name: "empty",
},
{
name: "unknown odd type",
bytes: []byte{0x21, 0x00},
},
{
name: "unknown odd type",
bytes: []byte{0xfd, 0x02, 0x01, 0x00},
},
{
name: "unknown odd type",
bytes: []byte{0xfd, 0x00, 0xfd, 0x00},
},
{
name: "unknown odd type",
bytes: []byte{0xfd, 0x00, 0xff, 0x00},
},
{
name: "unknown odd type",
bytes: []byte{0xfe, 0x02, 0x00, 0x00, 0x01, 0x00},
},
{
name: "unknown odd type",
bytes: []byte{0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00},
},
{
name: "N1 amt=0",
bytes: []byte{0x01, 0x00},
skipN2: true,
},
{
name: "N1 amt=1",
bytes: []byte{0x01, 0x01, 0x01},
skipN2: true,
},
{
name: "N1 amt=256",
bytes: []byte{0x01, 0x02, 0x01, 0x00},
skipN2: true,
},
{
name: "N1 amt=65536",
bytes: []byte{0x01, 0x03, 0x01, 0x00, 0x00},
skipN2: true,
},
{
name: "N1 amt=16777216",
bytes: []byte{0x01, 0x04, 0x01, 0x00, 0x00, 0x00},
skipN2: true,
},
{
name: "N1 amt=4294967296",
bytes: []byte{0x01, 0x05, 0x01, 0x00, 0x00, 0x00, 0x00},
skipN2: true,
},
{
name: "N1 amt=1099511627776",
bytes: []byte{0x01, 0x06, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00},
skipN2: true,
},
{
name: "N1 amt=281474976710656",
bytes: []byte{0x01, 0x07, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
skipN2: true,
},
{
name: "N1 amt=72057594037927936",
bytes: []byte{0x01, 0x08, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
skipN2: true,
},
{
name: "N1 scid=0x0x550",
bytes: []byte{0x02, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x26},
skipN2: true,
},
{
name: "N1 node_id=023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb amount_msat_1=1 amount_msat_2=2",
bytes: []byte{0x03, 0x31,
0x02, 0x3d, 0xa0, 0x92, 0xf6, 0x98, 0x0e, 0x58, 0xd2,
0xc0, 0x37, 0x17, 0x31, 0x80, 0xe9, 0xa4, 0x65, 0x47,
0x60, 0x26, 0xee, 0x50, 0xf9, 0x66, 0x95, 0x96, 0x3e,
0x8e, 0xfe, 0x43, 0x6f, 0x54, 0xeb, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x02},
skipN2: true,
},
{
name: "N1 cltv_delta=550",
bytes: []byte{0xfd, 0x00, 0xfe, 0x02, 0x02, 0x26},
skipN2: true,
},
}
// TestTLVDecodingSuccess asserts that the TLV parser is able to successfully
// decode valid TLV streams.
func TestTLVDecodingSuccess(t *testing.T) {
for _, test := range tlvDecodingSuccessTests {
t.Run(test.name, func(t *testing.T) {
n1 := NewN1()
r := bytes.NewReader(test.bytes)
err := n1.Decode(r)
if err != nil {
t.Fatalf("expected N1 decoding success, got: %v",
err)
}
if test.skipN2 {
return
}
n2 := NewN2()
r = bytes.NewReader(test.bytes)
err = n2.Decode(r)
if err != nil {
t.Fatalf("expected N2 decoding succes, got: %v",
err)
}
})
}
}

180
tlv/truncated.go Normal file

@ -0,0 +1,180 @@
package tlv
import (
"encoding/binary"
"errors"
"io"
)
// ErrTUintNotMinimal signals that decoding a truncated uint failed because the
// value was not minimally encoded.
var ErrTUintNotMinimal = errors.New("truncated uint not minimally encoded")
// numLeadingZeroBytes16 computes the number of leading zeros for a uint16.
func numLeadingZeroBytes16(v uint16) uint64 {
switch {
case v == 0:
return 2
case v&0xff00 == 0:
return 1
default:
return 0
}
}
// SizeTUint16 returns the number of bytes remaining in a uint16 after
// truncating the leading zeros.
func SizeTUint16(v uint16) uint64 {
return 2 - numLeadingZeroBytes16(v)
}
// ETUint16 is an Encoder for truncated uint16 values, where leading zeros will
// be omitted. An error is returned if val is not a *uint16.
func ETUint16(w io.Writer, val interface{}, buf *[8]byte) error {
if t, ok := val.(*uint16); ok {
binary.BigEndian.PutUint16(buf[:2], *t)
numZeros := numLeadingZeroBytes16(*t)
_, err := w.Write(buf[numZeros:2])
return err
}
return NewTypeForEncodingErr(val, "uint16")
}
// DTUint16 is an Decoder for truncated uint16 values, where leading zeros will
// be resurrected. An error is returned if val is not a *uint16.
func DTUint16(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if t, ok := val.(*uint16); ok && l <= 2 {
_, err := io.ReadFull(r, buf[2-l:])
if err != nil {
return err
}
zero(buf[:2-l])
*t = binary.BigEndian.Uint16(buf[:2])
if 2-numLeadingZeroBytes16(*t) != l {
return ErrTUintNotMinimal
}
return nil
}
return NewTypeForDecodingErr(val, "uint16", l, 2)
}
// numLeadingZeroBytes16 computes the number of leading zeros for a uint32.
func numLeadingZeroBytes32(v uint32) uint64 {
switch {
case v == 0:
return 4
case v&0xffffff00 == 0:
return 3
case v&0xffff0000 == 0:
return 2
case v&0xff000000 == 0:
return 1
default:
return 0
}
}
// SizeTUint32 returns the number of bytes remaining in a uint32 after
// truncating the leading zeros.
func SizeTUint32(v uint32) uint64 {
return 4 - numLeadingZeroBytes32(v)
}
// ETUint32 is an Encoder for truncated uint32 values, where leading zeros will
// be omitted. An error is returned if val is not a *uint32.
func ETUint32(w io.Writer, val interface{}, buf *[8]byte) error {
if t, ok := val.(*uint32); ok {
binary.BigEndian.PutUint32(buf[:4], *t)
numZeros := numLeadingZeroBytes32(*t)
_, err := w.Write(buf[numZeros:4])
return err
}
return NewTypeForEncodingErr(val, "uint32")
}
// DTUint32 is an Decoder for truncated uint32 values, where leading zeros will
// be resurrected. An error is returned if val is not a *uint32.
func DTUint32(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if t, ok := val.(*uint32); ok && l <= 4 {
_, err := io.ReadFull(r, buf[4-l:])
if err != nil {
return err
}
zero(buf[:4-l])
*t = binary.BigEndian.Uint32(buf[:4])
if 4-numLeadingZeroBytes32(*t) != l {
return ErrTUintNotMinimal
}
return nil
}
return NewTypeForDecodingErr(val, "uint32", l, 4)
}
// numLeadingZeroBytes64 computes the number of leading zeros for a uint32.
//
// TODO(conner): optimize using unrolled binary search
func numLeadingZeroBytes64(v uint64) uint64 {
switch {
case v == 0:
return 8
case v&0xffffffffffffff00 == 0:
return 7
case v&0xffffffffffff0000 == 0:
return 6
case v&0xffffffffff000000 == 0:
return 5
case v&0xffffffff00000000 == 0:
return 4
case v&0xffffff0000000000 == 0:
return 3
case v&0xffff000000000000 == 0:
return 2
case v&0xff00000000000000 == 0:
return 1
default:
return 0
}
}
// SizeTUint64 returns the number of bytes remaining in a uint64 after
// truncating the leading zeros.
func SizeTUint64(v uint64) uint64 {
return 8 - numLeadingZeroBytes64(v)
}
// ETUint64 is an Encoder for truncated uint64 values, where leading zeros will
// be omitted. An error is returned if val is not a *uint64.
func ETUint64(w io.Writer, val interface{}, buf *[8]byte) error {
if t, ok := val.(*uint64); ok {
binary.BigEndian.PutUint64(buf[:], *t)
numZeros := numLeadingZeroBytes64(*t)
_, err := w.Write(buf[numZeros:])
return err
}
return NewTypeForEncodingErr(val, "uint64")
}
// DTUint64 is an Decoder for truncated uint64 values, where leading zeros will
// be resurrected. An error is returned if val is not a *uint64.
func DTUint64(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if t, ok := val.(*uint64); ok && l <= 8 {
_, err := io.ReadFull(r, buf[8-l:])
if err != nil {
return err
}
zero(buf[:8-l])
*t = binary.BigEndian.Uint64(buf[:])
if 8-numLeadingZeroBytes64(*t) != l {
return ErrTUintNotMinimal
}
return nil
}
return NewTypeForDecodingErr(val, "uint64", l, 8)
}
// zero clears the passed byte slice.
func zero(b []byte) {
for i := range b {
b[i] = 0x00
}
}

109
tlv/varint.go Normal file

@ -0,0 +1,109 @@
package tlv
import (
"encoding/binary"
"errors"
"io"
)
// ErrVarIntNotCanonical signals that the decoded varint was not minimally encoded.
var ErrVarIntNotCanonical = errors.New("decoded varint is not canonical")
// ReadVarInt reads a variable length integer from r and returns it as a uint64.
func ReadVarInt(r io.Reader, buf *[8]byte) (uint64, error) {
_, err := io.ReadFull(r, buf[:1])
if err != nil {
return 0, err
}
discriminant := buf[0]
var rv uint64
switch {
case discriminant < 0xfd:
rv = uint64(discriminant)
case discriminant == 0xfd:
_, err := io.ReadFull(r, buf[:2])
switch {
case err == io.EOF:
return 0, io.ErrUnexpectedEOF
case err != nil:
return 0, err
}
rv = uint64(binary.BigEndian.Uint16(buf[:2]))
// The encoding is not canonical if the value could have been
// encoded using fewer bytes.
if rv < 0xfd {
return 0, ErrVarIntNotCanonical
}
case discriminant == 0xfe:
_, err := io.ReadFull(r, buf[:4])
switch {
case err == io.EOF:
return 0, io.ErrUnexpectedEOF
case err != nil:
return 0, err
}
rv = uint64(binary.BigEndian.Uint32(buf[:4]))
// The encoding is not canonical if the value could have been
// encoded using fewer bytes.
if rv <= 0xffff {
return 0, ErrVarIntNotCanonical
}
default:
_, err := io.ReadFull(r, buf[:])
switch {
case err == io.EOF:
return 0, io.ErrUnexpectedEOF
case err != nil:
return 0, err
}
rv = binary.BigEndian.Uint64(buf[:])
// The encoding is not canonical if the value could have been
// encoded using fewer bytes.
if rv <= 0xffffffff {
return 0, ErrVarIntNotCanonical
}
}
return rv, nil
}
// WriteVarInt serializes val to w using a variable number of bytes depending
// on its value.
func WriteVarInt(w io.Writer, val uint64, buf *[8]byte) error {
var length int
switch {
case val < 0xfd:
buf[0] = uint8(val)
length = 1
case val <= 0xffff:
buf[0] = uint8(0xfd)
binary.BigEndian.PutUint16(buf[1:3], uint16(val))
length = 3
case val <= 0xffffffff:
buf[0] = uint8(0xfe)
binary.BigEndian.PutUint32(buf[1:5], uint32(val))
length = 5
default:
buf[0] = uint8(0xff)
_, err := w.Write(buf[:1])
if err != nil {
return err
}
binary.BigEndian.PutUint64(buf[:], uint64(val))
length = 8
}
_, err := w.Write(buf[:length])
return err
}

217
tlv/varint_test.go Normal file

@ -0,0 +1,217 @@
package tlv_test
import (
"bytes"
"io"
"math"
"testing"
"github.com/lightningnetwork/lnd/tlv"
)
type varIntTest struct {
Name string
Value uint64
Bytes []byte
ExpErr error
}
var writeVarIntTests = []varIntTest{
{
Name: "zero",
Value: 0x00,
Bytes: []byte{0x00},
},
{
Name: "one byte high",
Value: 0xfc,
Bytes: []byte{0xfc},
},
{
Name: "two byte low",
Value: 0xfd,
Bytes: []byte{0xfd, 0x00, 0xfd},
},
{
Name: "two byte high",
Value: 0xffff,
Bytes: []byte{0xfd, 0xff, 0xff},
},
{
Name: "four byte low",
Value: 0x10000,
Bytes: []byte{0xfe, 0x00, 0x01, 0x00, 0x00},
},
{
Name: "four byte high",
Value: 0xffffffff,
Bytes: []byte{0xfe, 0xff, 0xff, 0xff, 0xff},
},
{
Name: "eight byte low",
Value: 0x100000000,
Bytes: []byte{0xff, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
},
{
Name: "eight byte high",
Value: math.MaxUint64,
Bytes: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
},
}
// TestWriteVarInt asserts the behavior of tlv.WriteVarInt under various
// positive and negative test cases.
func TestWriteVarInt(t *testing.T) {
for _, test := range writeVarIntTests {
t.Run(test.Name, func(t *testing.T) {
testWriteVarInt(t, test)
})
}
}
func testWriteVarInt(t *testing.T, test varIntTest) {
var (
w bytes.Buffer
buf [8]byte
)
err := tlv.WriteVarInt(&w, test.Value, &buf)
if err != nil {
t.Fatalf("unable to encode %d as varint: %v",
test.Value, err)
}
if bytes.Compare(w.Bytes(), test.Bytes) != 0 {
t.Fatalf("expected bytes: %v, got %v",
test.Bytes, w.Bytes())
}
}
var readVarIntTests = []varIntTest{
{
Name: "zero",
Value: 0x00,
Bytes: []byte{0x00},
},
{
Name: "one byte high",
Value: 0xfc,
Bytes: []byte{0xfc},
},
{
Name: "two byte low",
Value: 0xfd,
Bytes: []byte{0xfd, 0x00, 0xfd},
},
{
Name: "two byte high",
Value: 0xffff,
Bytes: []byte{0xfd, 0xff, 0xff},
},
{
Name: "four byte low",
Value: 0x10000,
Bytes: []byte{0xfe, 0x00, 0x01, 0x00, 0x00},
},
{
Name: "four byte high",
Value: 0xffffffff,
Bytes: []byte{0xfe, 0xff, 0xff, 0xff, 0xff},
},
{
Name: "eight byte low",
Value: 0x100000000,
Bytes: []byte{0xff, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
},
{
Name: "eight byte high",
Value: math.MaxUint64,
Bytes: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
},
{
Name: "two byte not canonical",
Bytes: []byte{0xfd, 0x00, 0xfc},
ExpErr: tlv.ErrVarIntNotCanonical,
},
{
Name: "four byte not canonical",
Bytes: []byte{0xfe, 0x00, 0x00, 0xff, 0xff},
ExpErr: tlv.ErrVarIntNotCanonical,
},
{
Name: "eight byte not canonical",
Bytes: []byte{0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff},
ExpErr: tlv.ErrVarIntNotCanonical,
},
{
Name: "two byte short read",
Bytes: []byte{0xfd, 0x00},
ExpErr: io.ErrUnexpectedEOF,
},
{
Name: "four byte short read",
Bytes: []byte{0xfe, 0xff, 0xff},
ExpErr: io.ErrUnexpectedEOF,
},
{
Name: "eight byte short read",
Bytes: []byte{0xff, 0xff, 0xff, 0xff, 0xff},
ExpErr: io.ErrUnexpectedEOF,
},
{
Name: "one byte no read",
Bytes: []byte{},
ExpErr: io.EOF,
},
// The following cases are the reason for needing to make a custom
// version of the varint for the tlv package. For the varint encodings
// in btcd's wire package these would return io.EOF, since it is
// actually a composite of two calls to io.ReadFull. In TLV, we need to
// be able to distinguish whether no bytes were read at all from no
// Bytes being read on the second read as the latter is not a proper TLV
// stream. We handle this by returning io.ErrUnexpectedEOF if we
// encounter io.EOF on any of these secondary reads for larger values.
{
Name: "two byte no read",
Bytes: []byte{0xfd},
ExpErr: io.ErrUnexpectedEOF,
},
{
Name: "four byte no read",
Bytes: []byte{0xfe},
ExpErr: io.ErrUnexpectedEOF,
},
{
Name: "eight byte no read",
Bytes: []byte{0xff},
ExpErr: io.ErrUnexpectedEOF,
},
}
// TestReadVarInt asserts the behavior of tlv.ReadVarInt under various positive
// and negative test cases.
func TestReadVarInt(t *testing.T) {
for _, test := range readVarIntTests {
t.Run(test.Name, func(t *testing.T) {
testReadVarInt(t, test)
})
}
}
func testReadVarInt(t *testing.T, test varIntTest) {
var buf [8]byte
r := bytes.NewReader(test.Bytes)
val, err := tlv.ReadVarInt(r, &buf)
if err != nil && err != test.ExpErr {
t.Fatalf("expected decoding error: %v, got: %v",
test.ExpErr, err)
}
// If we expected a decoding error, there's no point checking the value.
if test.ExpErr != nil {
return
}
if val != test.Value {
t.Fatalf("expected value: %d, got %d", test.Value, val)
}
}