Merge pull request #3061 from cfromknecht/wire-tlv
tlv: add library for new message/payload serialization format
This commit is contained in:
commit
ea77ff91c2
161
tlv/bench_test.go
Normal file
161
tlv/bench_test.go
Normal file
@ -0,0 +1,161 @@
|
||||
package tlv_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
|
||||
"github.com/lightningnetwork/lnd/lnwallet"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
"github.com/lightningnetwork/lnd/watchtower/blob"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtwire"
|
||||
)
|
||||
|
||||
// CreateSessionTLV mirrors the wtwire.CreateSession message, but uses TLV for
|
||||
// encoding/decoding.
|
||||
type CreateSessionTLV struct {
|
||||
BlobType blob.Type
|
||||
MaxUpdates uint16
|
||||
RewardBase uint32
|
||||
RewardRate uint32
|
||||
SweepFeeRate lnwallet.SatPerKWeight
|
||||
|
||||
tlvStream *tlv.Stream
|
||||
}
|
||||
|
||||
// EBlobType is an encoder for blob.Type.
|
||||
func EBlobType(w io.Writer, val interface{}, buf *[8]byte) error {
|
||||
if t, ok := val.(*blob.Type); ok {
|
||||
return tlv.EUint16T(w, uint16(*t), buf)
|
||||
}
|
||||
return tlv.NewTypeForEncodingErr(val, "blob.Type")
|
||||
}
|
||||
|
||||
// EBlobType is an decoder for blob.Type.
|
||||
func DBlobType(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
|
||||
if typ, ok := val.(*blob.Type); ok {
|
||||
var t uint16
|
||||
err := tlv.DUint16(r, &t, buf, l)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*typ = blob.Type(t)
|
||||
return nil
|
||||
}
|
||||
return tlv.NewTypeForDecodingErr(val, "blob.Type", l, 2)
|
||||
}
|
||||
|
||||
// ESatPerKW is an encoder for lnwallet.SatPerKWeight.
|
||||
func ESatPerKW(w io.Writer, val interface{}, buf *[8]byte) error {
|
||||
if v, ok := val.(*lnwallet.SatPerKWeight); ok {
|
||||
return tlv.EUint64(w, uint64(*v), buf)
|
||||
}
|
||||
return tlv.NewTypeForEncodingErr(val, "lnwallet.SatPerKWeight")
|
||||
}
|
||||
|
||||
// DSatPerKW is an decoder for lnwallet.SatPerKWeight.
|
||||
func DSatPerKW(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
|
||||
if v, ok := val.(*lnwallet.SatPerKWeight); ok {
|
||||
var sat uint64
|
||||
err := tlv.DUint64(r, &sat, buf, l)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*v = lnwallet.SatPerKWeight(sat)
|
||||
return nil
|
||||
}
|
||||
return tlv.NewTypeForDecodingErr(val, "lnwallet.SatPerKWeight", l, 8)
|
||||
}
|
||||
|
||||
// NewCreateSessionTLV initializes a new CreateSessionTLV message.
|
||||
func NewCreateSessionTLV() *CreateSessionTLV {
|
||||
m := &CreateSessionTLV{}
|
||||
m.tlvStream = tlv.MustNewStream(
|
||||
tlv.MakeStaticRecord(0, &m.BlobType, 2, EBlobType, DBlobType),
|
||||
tlv.MakePrimitiveRecord(1, &m.MaxUpdates),
|
||||
tlv.MakePrimitiveRecord(2, &m.RewardBase),
|
||||
tlv.MakePrimitiveRecord(3, &m.RewardRate),
|
||||
tlv.MakeStaticRecord(4, &m.SweepFeeRate, 8, ESatPerKW, DSatPerKW),
|
||||
)
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// Encode writes the CreateSessionTLV to the passed io.Writer.
|
||||
func (c *CreateSessionTLV) Encode(w io.Writer) error {
|
||||
return c.tlvStream.Encode(w)
|
||||
}
|
||||
|
||||
// Decode reads the CreateSessionTLV from the passed io.Reader.
|
||||
func (c *CreateSessionTLV) Decode(r io.Reader) error {
|
||||
return c.tlvStream.Decode(r)
|
||||
}
|
||||
|
||||
// BenchmarkEncodeCreateSession benchmarks encoding of the non-TLV
|
||||
// CreateSession.
|
||||
func BenchmarkEncodeCreateSession(t *testing.B) {
|
||||
m := &wtwire.CreateSession{}
|
||||
|
||||
t.ReportAllocs()
|
||||
t.ResetTimer()
|
||||
|
||||
var err error
|
||||
for i := 0; i < t.N; i++ {
|
||||
err = m.Encode(ioutil.Discard, 0)
|
||||
}
|
||||
_ = err
|
||||
}
|
||||
|
||||
// BenchmarkEncodeCreateSessionTLV benchmarks encoding of the TLV CreateSession.
|
||||
func BenchmarkEncodeCreateSessionTLV(t *testing.B) {
|
||||
m := NewCreateSessionTLV()
|
||||
|
||||
t.ReportAllocs()
|
||||
t.ResetTimer()
|
||||
|
||||
var err error
|
||||
for i := 0; i < t.N; i++ {
|
||||
err = m.Encode(ioutil.Discard)
|
||||
}
|
||||
_ = err
|
||||
}
|
||||
|
||||
// BenchmarkDecodeCreateSession benchmarks encoding of the non-TLV
|
||||
// CreateSession.
|
||||
func BenchmarkDecodeCreateSession(t *testing.B) {
|
||||
m := &wtwire.CreateSession{}
|
||||
|
||||
var b bytes.Buffer
|
||||
m.Encode(&b, 0)
|
||||
r := bytes.NewReader(b.Bytes())
|
||||
|
||||
t.ReportAllocs()
|
||||
t.ResetTimer()
|
||||
|
||||
var err error
|
||||
for i := 0; i < t.N; i++ {
|
||||
r.Seek(0, 0)
|
||||
err = m.Decode(r, 0)
|
||||
}
|
||||
_ = err
|
||||
}
|
||||
|
||||
// BenchmarkDecodeCreateSessionTLV benchmarks decoding of the TLV CreateSession.
|
||||
func BenchmarkDecodeCreateSessionTLV(t *testing.B) {
|
||||
m := NewCreateSessionTLV()
|
||||
|
||||
var b bytes.Buffer
|
||||
var err error
|
||||
m.Encode(&b)
|
||||
r := bytes.NewReader(b.Bytes())
|
||||
|
||||
t.ReportAllocs()
|
||||
t.ResetTimer()
|
||||
|
||||
for i := 0; i < t.N; i++ {
|
||||
r.Seek(0, 0)
|
||||
err = m.Decode(r)
|
||||
}
|
||||
_ = err
|
||||
}
|
309
tlv/primitive.go
Normal file
309
tlv/primitive.go
Normal file
@ -0,0 +1,309 @@
|
||||
package tlv
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
)
|
||||
|
||||
// ErrTypeForEncoding signals that an incorrect type was passed to an Encoder.
|
||||
type ErrTypeForEncoding struct {
|
||||
val interface{}
|
||||
expType string
|
||||
}
|
||||
|
||||
// NewTypeForEncodingErr creates a new ErrTypeForEncoding given the incorrect
|
||||
// val and the expected type.
|
||||
func NewTypeForEncodingErr(val interface{}, expType string) ErrTypeForEncoding {
|
||||
return ErrTypeForEncoding{
|
||||
val: val,
|
||||
expType: expType,
|
||||
}
|
||||
}
|
||||
|
||||
// Error returns a human-readable description of the type mismatch.
|
||||
func (e ErrTypeForEncoding) Error() string {
|
||||
return fmt.Sprintf("ErrTypeForEncoding want (type: *%s), "+
|
||||
"got (type: %T)", e.expType, e.val)
|
||||
}
|
||||
|
||||
// ErrTypeForDecoding signals that an incorrect type was passed to a Decoder or
|
||||
// that the expected length of the encoding is different from that required by
|
||||
// the expected type.
|
||||
type ErrTypeForDecoding struct {
|
||||
val interface{}
|
||||
expType string
|
||||
valLength uint64
|
||||
expLength uint64
|
||||
}
|
||||
|
||||
// NewTypeForDecodingErr creates a new ErrTypeForDecoding given the incorrect
|
||||
// val and expected type, or the mismatch in their expected lengths.
|
||||
func NewTypeForDecodingErr(val interface{}, expType string,
|
||||
valLength, expLength uint64) ErrTypeForDecoding {
|
||||
|
||||
return ErrTypeForDecoding{
|
||||
val: val,
|
||||
expType: expType,
|
||||
valLength: valLength,
|
||||
expLength: expLength,
|
||||
}
|
||||
}
|
||||
|
||||
// Error returns a human-readable description of the type mismatch.
|
||||
func (e ErrTypeForDecoding) Error() string {
|
||||
return fmt.Sprintf("ErrTypeForDecoding want (type: *%s, length: %v), "+
|
||||
"got (type: %T, length: %v)", e.expType, e.expLength, e.val,
|
||||
e.valLength)
|
||||
}
|
||||
|
||||
var (
|
||||
byteOrder = binary.BigEndian
|
||||
)
|
||||
|
||||
// EUint8 is an Encoder for uint8 values. An error is returned if val is not a
|
||||
// *uint8.
|
||||
func EUint8(w io.Writer, val interface{}, buf *[8]byte) error {
|
||||
if i, ok := val.(*uint8); ok {
|
||||
buf[0] = *i
|
||||
_, err := w.Write(buf[:1])
|
||||
return err
|
||||
}
|
||||
return ErrTypeForEncoding{val, "uint8"}
|
||||
}
|
||||
|
||||
// EUint8T encodes a uint8 val to the provided io.Writer. This method is exposed
|
||||
// so that encodings for custom uint8-like types can be created without
|
||||
// incurring an extra heap allocation.
|
||||
func EUint8T(w io.Writer, val uint8, buf *[8]byte) error {
|
||||
buf[0] = val
|
||||
_, err := w.Write(buf[:1])
|
||||
return err
|
||||
}
|
||||
|
||||
// EUint16 is an Encoder for uint16 values. An error is returned if val is not a
|
||||
// *uint16.
|
||||
func EUint16(w io.Writer, val interface{}, buf *[8]byte) error {
|
||||
if i, ok := val.(*uint16); ok {
|
||||
byteOrder.PutUint16(buf[:2], *i)
|
||||
_, err := w.Write(buf[:2])
|
||||
return err
|
||||
}
|
||||
return ErrTypeForEncoding{val, "uint16"}
|
||||
}
|
||||
|
||||
// EUint16T encodes a uint16 val to the provided io.Writer. This method is
|
||||
// exposed so that encodings for custom uint16-like types can be created without
|
||||
// incurring an extra heap allocation.
|
||||
func EUint16T(w io.Writer, val uint16, buf *[8]byte) error {
|
||||
byteOrder.PutUint16(buf[:2], val)
|
||||
_, err := w.Write(buf[:2])
|
||||
return err
|
||||
}
|
||||
|
||||
// EUint32 is an Encoder for uint32 values. An error is returned if val is not a
|
||||
// *uint32.
|
||||
func EUint32(w io.Writer, val interface{}, buf *[8]byte) error {
|
||||
if i, ok := val.(*uint32); ok {
|
||||
byteOrder.PutUint32(buf[:4], *i)
|
||||
_, err := w.Write(buf[:4])
|
||||
return err
|
||||
}
|
||||
return ErrTypeForEncoding{val, "uint32"}
|
||||
}
|
||||
|
||||
// EUint32T encodes a uint32 val to the provided io.Writer. This method is
|
||||
// exposed so that encodings for custom uint32-like types can be created without
|
||||
// incurring an extra heap allocation.
|
||||
func EUint32T(w io.Writer, val uint32, buf *[8]byte) error {
|
||||
byteOrder.PutUint32(buf[:4], val)
|
||||
_, err := w.Write(buf[:4])
|
||||
return err
|
||||
}
|
||||
|
||||
// EUint64 is an Encoder for uint64 values. An error is returned if val is not a
|
||||
// *uint64.
|
||||
func EUint64(w io.Writer, val interface{}, buf *[8]byte) error {
|
||||
if i, ok := val.(*uint64); ok {
|
||||
byteOrder.PutUint64(buf[:], *i)
|
||||
_, err := w.Write(buf[:])
|
||||
return err
|
||||
}
|
||||
return ErrTypeForEncoding{val, "uint64"}
|
||||
}
|
||||
|
||||
// EUint64T encodes a uint64 val to the provided io.Writer. This method is
|
||||
// exposed so that encodings for custom uint64-like types can be created without
|
||||
// incurring an extra heap allocation.
|
||||
func EUint64T(w io.Writer, val uint64, buf *[8]byte) error {
|
||||
byteOrder.PutUint64(buf[:], val)
|
||||
_, err := w.Write(buf[:])
|
||||
return err
|
||||
}
|
||||
|
||||
// DUint8 is a Decoder for uint8 values. An error is returned if val is not a
|
||||
// *uint8.
|
||||
func DUint8(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
|
||||
if i, ok := val.(*uint8); ok && l == 1 {
|
||||
if _, err := io.ReadFull(r, buf[:1]); err != nil {
|
||||
return err
|
||||
}
|
||||
*i = buf[0]
|
||||
return nil
|
||||
}
|
||||
return ErrTypeForDecoding{val, "uint8", l, 1}
|
||||
}
|
||||
|
||||
// DUint16 is a Decoder for uint16 values. An error is returned if val is not a
|
||||
// *uint16.
|
||||
func DUint16(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
|
||||
if i, ok := val.(*uint16); ok && l == 2 {
|
||||
if _, err := io.ReadFull(r, buf[:2]); err != nil {
|
||||
return err
|
||||
}
|
||||
*i = byteOrder.Uint16(buf[:2])
|
||||
return nil
|
||||
}
|
||||
return ErrTypeForDecoding{val, "uint16", l, 2}
|
||||
}
|
||||
|
||||
// DUint32 is a Decoder for uint32 values. An error is returned if val is not a
|
||||
// *uint32.
|
||||
func DUint32(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
|
||||
if i, ok := val.(*uint32); ok && l == 4 {
|
||||
if _, err := io.ReadFull(r, buf[:4]); err != nil {
|
||||
return err
|
||||
}
|
||||
*i = byteOrder.Uint32(buf[:4])
|
||||
return nil
|
||||
}
|
||||
return ErrTypeForDecoding{val, "uint32", l, 4}
|
||||
}
|
||||
|
||||
// DUint64 is a Decoder for uint64 values. An error is returned if val is not a
|
||||
// *uint64.
|
||||
func DUint64(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
|
||||
if i, ok := val.(*uint64); ok && l == 8 {
|
||||
if _, err := io.ReadFull(r, buf[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
*i = byteOrder.Uint64(buf[:])
|
||||
return nil
|
||||
}
|
||||
return ErrTypeForDecoding{val, "uint64", l, 8}
|
||||
}
|
||||
|
||||
// EBytes32 is an Encoder for 32-byte arrays. An error is returned if val is not
|
||||
// a *[32]byte.
|
||||
func EBytes32(w io.Writer, val interface{}, _ *[8]byte) error {
|
||||
if b, ok := val.(*[32]byte); ok {
|
||||
_, err := w.Write(b[:])
|
||||
return err
|
||||
}
|
||||
return ErrTypeForEncoding{val, "[32]byte"}
|
||||
}
|
||||
|
||||
// DBytes32 is a Decoder for 32-byte arrays. An error is returned if val is not
|
||||
// a *[32]byte.
|
||||
func DBytes32(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
|
||||
if b, ok := val.(*[32]byte); ok && l == 32 {
|
||||
_, err := io.ReadFull(r, b[:])
|
||||
return err
|
||||
}
|
||||
return ErrTypeForDecoding{val, "[32]byte", l, 32}
|
||||
}
|
||||
|
||||
// EBytes33 is an Encoder for 33-byte arrays. An error is returned if val is not
|
||||
// a *[33]byte.
|
||||
func EBytes33(w io.Writer, val interface{}, _ *[8]byte) error {
|
||||
if b, ok := val.(*[33]byte); ok {
|
||||
_, err := w.Write(b[:])
|
||||
return err
|
||||
}
|
||||
return ErrTypeForEncoding{val, "[33]byte"}
|
||||
}
|
||||
|
||||
// DBytes33 is a Decoder for 33-byte arrays. An error is returned if val is not
|
||||
// a *[33]byte.
|
||||
func DBytes33(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
|
||||
if b, ok := val.(*[33]byte); ok {
|
||||
_, err := io.ReadFull(r, b[:])
|
||||
return err
|
||||
}
|
||||
return ErrTypeForDecoding{val, "[33]byte", l, 33}
|
||||
}
|
||||
|
||||
// EBytes64 is an Encoder for 64-byte arrays. An error is returned if val is not
|
||||
// a *[64]byte.
|
||||
func EBytes64(w io.Writer, val interface{}, _ *[8]byte) error {
|
||||
if b, ok := val.(*[64]byte); ok {
|
||||
_, err := w.Write(b[:])
|
||||
return err
|
||||
}
|
||||
return ErrTypeForEncoding{val, "[64]byte"}
|
||||
}
|
||||
|
||||
// DBytes64 is an Decoder for 64-byte arrays. An error is returned if val is not
|
||||
// a *[64]byte.
|
||||
func DBytes64(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
|
||||
if b, ok := val.(*[64]byte); ok && l == 64 {
|
||||
_, err := io.ReadFull(r, b[:])
|
||||
return err
|
||||
}
|
||||
return ErrTypeForDecoding{val, "[64]byte", l, 64}
|
||||
}
|
||||
|
||||
// EPubKey is an Encoder for *btcec.PublicKey values. An error is returned if
|
||||
// val is not a **btcec.PublicKey.
|
||||
func EPubKey(w io.Writer, val interface{}, _ *[8]byte) error {
|
||||
if pk, ok := val.(**btcec.PublicKey); ok {
|
||||
_, err := w.Write((*pk).SerializeCompressed())
|
||||
return err
|
||||
}
|
||||
return ErrTypeForEncoding{val, "*btcec.PublicKey"}
|
||||
}
|
||||
|
||||
// DPubKey is a Decoder for *btcec.PublicKey values. An error is returned if val
|
||||
// is not a **btcec.PublicKey.
|
||||
func DPubKey(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
|
||||
if pk, ok := val.(**btcec.PublicKey); ok && l == 33 {
|
||||
var b [33]byte
|
||||
_, err := io.ReadFull(r, b[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p, err := btcec.ParsePubKey(b[:], btcec.S256())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*pk = p
|
||||
|
||||
return nil
|
||||
}
|
||||
return ErrTypeForDecoding{val, "*btcec.PublicKey", l, 33}
|
||||
}
|
||||
|
||||
// EVarBytes is an Encoder for variable byte slices. An error is returned if val
|
||||
// is not *[]byte.
|
||||
func EVarBytes(w io.Writer, val interface{}, _ *[8]byte) error {
|
||||
if b, ok := val.(*[]byte); ok {
|
||||
_, err := w.Write(*b)
|
||||
return err
|
||||
}
|
||||
return ErrTypeForEncoding{val, "[]byte"}
|
||||
}
|
||||
|
||||
// DVarBytes is a Decoder for variable byte slices. An error is returned if val
|
||||
// is not *[]byte.
|
||||
func DVarBytes(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
|
||||
if b, ok := val.(*[]byte); ok {
|
||||
*b = make([]byte, l)
|
||||
_, err := io.ReadFull(r, *b)
|
||||
return err
|
||||
}
|
||||
return ErrTypeForDecoding{val, "[]byte", l, l}
|
||||
}
|
153
tlv/record.go
Normal file
153
tlv/record.go
Normal file
@ -0,0 +1,153 @@
|
||||
package tlv
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
)
|
||||
|
||||
// Type is an 64-bit identifier for a TLV Record.
|
||||
type Type uint64
|
||||
|
||||
// Encoder is a signature for methods that can encode TLV values. An error
|
||||
// should be returned if the Encoder cannot support the underlying type of val.
|
||||
// The provided scratch buffer must be non-nil.
|
||||
type Encoder func(w io.Writer, val interface{}, buf *[8]byte) error
|
||||
|
||||
// Decoder is a signature for methods that can decode TLV values. An error
|
||||
// should be returned if the Decoder cannot support the underlying type of val.
|
||||
// The provided scratch buffer must be non-nil.
|
||||
type Decoder func(r io.Reader, val interface{}, buf *[8]byte, l uint64) error
|
||||
|
||||
// ENOP is an encoder that doesn't modify the io.Writer and never fails.
|
||||
func ENOP(io.Writer, interface{}, *[8]byte) error { return nil }
|
||||
|
||||
// DNOP is an encoder that doesn't modify the io.Reader and never fails.
|
||||
func DNOP(io.Reader, interface{}, *[8]byte, uint64) error { return nil }
|
||||
|
||||
// SizeFunc is a function that can compute the length of a given field. Since
|
||||
// the size of the underlying field can change, this allows the size of the
|
||||
// field to be evaluated at the time of encoding.
|
||||
type SizeFunc func() uint64
|
||||
|
||||
// SizeVarBytes returns a SizeFunc that can compute the length of a byte slice.
|
||||
func SizeVarBytes(e *[]byte) SizeFunc {
|
||||
return func() uint64 {
|
||||
return uint64(len(*e))
|
||||
}
|
||||
}
|
||||
|
||||
// Record holds the required information to encode or decode a TLV record.
|
||||
type Record struct {
|
||||
value interface{}
|
||||
typ Type
|
||||
staticSize uint64
|
||||
sizeFunc SizeFunc
|
||||
encoder Encoder
|
||||
decoder Decoder
|
||||
}
|
||||
|
||||
// Size returns the size of the Record's value. If no static size is known, the
|
||||
// dynamic size will be evaluated.
|
||||
func (f *Record) Size() uint64 {
|
||||
if f.sizeFunc == nil {
|
||||
return f.staticSize
|
||||
}
|
||||
|
||||
return f.sizeFunc()
|
||||
}
|
||||
|
||||
// MakePrimitiveRecord creates a record for common types.
|
||||
func MakePrimitiveRecord(typ Type, val interface{}) Record {
|
||||
var (
|
||||
staticSize uint64
|
||||
sizeFunc SizeFunc
|
||||
encoder Encoder
|
||||
decoder Decoder
|
||||
)
|
||||
switch e := val.(type) {
|
||||
case *uint8:
|
||||
staticSize = 1
|
||||
encoder = EUint8
|
||||
decoder = DUint8
|
||||
|
||||
case *uint16:
|
||||
staticSize = 2
|
||||
encoder = EUint16
|
||||
decoder = DUint16
|
||||
|
||||
case *uint32:
|
||||
staticSize = 4
|
||||
encoder = EUint32
|
||||
decoder = DUint32
|
||||
|
||||
case *uint64:
|
||||
staticSize = 8
|
||||
encoder = EUint64
|
||||
decoder = DUint64
|
||||
|
||||
case *[32]byte:
|
||||
staticSize = 32
|
||||
encoder = EBytes32
|
||||
decoder = DBytes32
|
||||
|
||||
case *[33]byte:
|
||||
staticSize = 33
|
||||
encoder = EBytes33
|
||||
decoder = DBytes33
|
||||
|
||||
case **btcec.PublicKey:
|
||||
staticSize = 33
|
||||
encoder = EPubKey
|
||||
decoder = DPubKey
|
||||
|
||||
case *[64]byte:
|
||||
staticSize = 64
|
||||
encoder = EBytes64
|
||||
decoder = DBytes64
|
||||
|
||||
case *[]byte:
|
||||
sizeFunc = SizeVarBytes(e)
|
||||
encoder = EVarBytes
|
||||
decoder = DVarBytes
|
||||
|
||||
default:
|
||||
panic("unknown primitive type")
|
||||
}
|
||||
|
||||
return Record{
|
||||
value: val,
|
||||
typ: typ,
|
||||
staticSize: staticSize,
|
||||
sizeFunc: sizeFunc,
|
||||
encoder: encoder,
|
||||
decoder: decoder,
|
||||
}
|
||||
}
|
||||
|
||||
// MakeStaticRecord creates a record for a field of fixed-size
|
||||
func MakeStaticRecord(typ Type, val interface{}, size uint64, encoder Encoder,
|
||||
decoder Decoder) Record {
|
||||
|
||||
return Record{
|
||||
value: val,
|
||||
typ: typ,
|
||||
staticSize: size,
|
||||
encoder: encoder,
|
||||
decoder: decoder,
|
||||
}
|
||||
}
|
||||
|
||||
// MakeDynamicRecord creates a record whose size may vary, and will be
|
||||
// determined at the time of encoding via sizeFunc.
|
||||
func MakeDynamicRecord(typ Type, val interface{}, sizeFunc SizeFunc,
|
||||
encoder Encoder, decoder Decoder) Record {
|
||||
|
||||
return Record{
|
||||
value: val,
|
||||
typ: typ,
|
||||
sizeFunc: sizeFunc,
|
||||
encoder: encoder,
|
||||
decoder: decoder,
|
||||
}
|
||||
}
|
280
tlv/stream.go
Normal file
280
tlv/stream.go
Normal file
@ -0,0 +1,280 @@
|
||||
package tlv
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
)
|
||||
|
||||
// ErrStreamNotCanonical signals that a decoded stream does not contain records
|
||||
// sorting by monotonically-increasing type.
|
||||
var ErrStreamNotCanonical = errors.New("tlv stream is not canonical")
|
||||
|
||||
// ErrUnknownRequiredType is an error returned when decoding an unknown and even
|
||||
// type from a Stream.
|
||||
type ErrUnknownRequiredType Type
|
||||
|
||||
// Error returns a human-readable description of unknown required type.
|
||||
func (t ErrUnknownRequiredType) Error() string {
|
||||
return fmt.Sprintf("unknown required type: %d", t)
|
||||
}
|
||||
|
||||
// Stream defines a TLV stream that can be used for encoding or decoding a set
|
||||
// of TLV Records.
|
||||
type Stream struct {
|
||||
records []Record
|
||||
buf [8]byte
|
||||
}
|
||||
|
||||
// NewStream creates a new TLV Stream given an encoding codec, a decoding codec,
|
||||
// and a set of known records.
|
||||
func NewStream(records ...Record) (*Stream, error) {
|
||||
// Assert that the ordering of the Records is canonical and appear in
|
||||
// ascending order of type.
|
||||
var (
|
||||
min Type
|
||||
overflow bool
|
||||
)
|
||||
for _, record := range records {
|
||||
if overflow || record.typ < min {
|
||||
return nil, ErrStreamNotCanonical
|
||||
}
|
||||
if record.encoder == nil {
|
||||
record.encoder = ENOP
|
||||
}
|
||||
if record.decoder == nil {
|
||||
record.decoder = DNOP
|
||||
}
|
||||
if record.typ == math.MaxUint64 {
|
||||
overflow = true
|
||||
}
|
||||
min = record.typ + 1
|
||||
}
|
||||
|
||||
return &Stream{
|
||||
records: records,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// MustNewStream creates a new TLV Stream given an encoding codec, a decoding
|
||||
// codec, and a set of known records. If an error is encountered in creating the
|
||||
// stream, this method will panic instead of returning the error.
|
||||
func MustNewStream(records ...Record) *Stream {
|
||||
stream, err := NewStream(records...)
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
return stream
|
||||
}
|
||||
|
||||
// Encode writes a Stream to the passed io.Writer. Each of the Records known to
|
||||
// the Stream is written in ascending order of their type so as to be canonical.
|
||||
//
|
||||
// The stream is constructed by concatenating the individual, serialized Records
|
||||
// where each record has the following format:
|
||||
// [varint: type]
|
||||
// [varint: length]
|
||||
// [length: value]
|
||||
//
|
||||
// An error is returned if the io.Writer fails to accept bytes from the
|
||||
// encoding, and nothing else. The ordering of the Records is asserted upon the
|
||||
// creation of a Stream, and thus the output will be by definition canonical.
|
||||
func (s *Stream) Encode(w io.Writer) error {
|
||||
// Iterate through all known records, if any, serializing each record's
|
||||
// type, length and value.
|
||||
for i := range s.records {
|
||||
rec := &s.records[i]
|
||||
|
||||
// Write the record's type as a varint.
|
||||
err := WriteVarInt(w, uint64(rec.typ), &s.buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write the record's length as a varint.
|
||||
err = WriteVarInt(w, rec.Size(), &s.buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Encode the current record's value using the stream's codec.
|
||||
err = rec.encoder(w, rec.value, &s.buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode deserializes TLV Stream from the passed io.Reader. The Stream will
|
||||
// inspect each record that is parsed and check to see if it has a corresponding
|
||||
// Record to facilitate deserialization of that field. If the record is unknown,
|
||||
// the Stream will discard the record's bytes and proceed to the subsequent
|
||||
// record.
|
||||
//
|
||||
// Each record has the following format:
|
||||
// [varint: type]
|
||||
// [varint: length]
|
||||
// [length: value]
|
||||
//
|
||||
// A series of (possibly zero) records are concatenated into a stream, this
|
||||
// example contains two records:
|
||||
//
|
||||
// (t: 0x01, l: 0x04, v: 0xff, 0xff, 0xff, 0xff)
|
||||
// (t: 0x02, l: 0x01, v: 0x01)
|
||||
//
|
||||
// This method asserts that the byte stream is canonical, namely that each
|
||||
// record is unique and that all records are sorted in ascending order. An
|
||||
// ErrNotCanonicalStream error is returned if the encoded TLV stream is not.
|
||||
//
|
||||
// We permit an io.EOF error only when reading the type byte which signals that
|
||||
// the last record was read cleanly and we should stop parsing. All other io.EOF
|
||||
// or io.ErrUnexpectedEOF errors are returned.
|
||||
func (s *Stream) Decode(r io.Reader) error {
|
||||
var (
|
||||
typ Type
|
||||
min Type
|
||||
recordIdx int
|
||||
overflow bool
|
||||
)
|
||||
|
||||
// Iterate through all possible type identifiers. As types are read from
|
||||
// the io.Reader, min will skip forward to the last read type.
|
||||
for {
|
||||
// Read the next varint type.
|
||||
t, err := ReadVarInt(r, &s.buf)
|
||||
switch {
|
||||
|
||||
// We'll silence an EOF when zero bytes remain, meaning the
|
||||
// stream was cleanly encoded.
|
||||
case err == io.EOF:
|
||||
return nil
|
||||
|
||||
// Other unexpected errors.
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
typ = Type(t)
|
||||
|
||||
// Assert that this type is greater than any previously read.
|
||||
// If we've already overflowed and we parsed another type, the
|
||||
// stream is not canonical. This check prevents us from accepts
|
||||
// encodings that have duplicate records or from accepting an
|
||||
// unsorted series.
|
||||
if overflow || typ < min {
|
||||
return ErrStreamNotCanonical
|
||||
}
|
||||
|
||||
// Read the varint length.
|
||||
length, err := ReadVarInt(r, &s.buf)
|
||||
switch {
|
||||
|
||||
// We'll convert any EOFs to ErrUnexpectedEOF, since this
|
||||
// results in an invalid record.
|
||||
case err == io.EOF:
|
||||
return io.ErrUnexpectedEOF
|
||||
|
||||
// Other unexpected errors.
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
// Search the records known to the stream for this type. We'll
|
||||
// begin the search and recordIdx and walk forward until we find
|
||||
// it or the next record's type is larger.
|
||||
rec, newIdx, ok := s.getRecord(typ, recordIdx)
|
||||
switch {
|
||||
|
||||
// We know of this record type, proceed to decode the value.
|
||||
// This method asserts that length bytes are read in the
|
||||
// process, and returns an error if the number of bytes is not
|
||||
// exactly length.
|
||||
case ok:
|
||||
err := rec.decoder(r, rec.value, &s.buf, length)
|
||||
switch {
|
||||
|
||||
// We'll convert any EOFs to ErrUnexpectedEOF, since this
|
||||
// results in an invalid record.
|
||||
case err == io.EOF:
|
||||
return io.ErrUnexpectedEOF
|
||||
|
||||
// Other unexpected errors.
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
// This record type is unknown to the stream, fail if the type
|
||||
// is even meaning that we are required to understand it.
|
||||
case typ%2 == 0:
|
||||
return ErrUnknownRequiredType(typ)
|
||||
|
||||
// Otherwise, the record type is unknown and is odd, discard the
|
||||
// number of bytes specified by length.
|
||||
default:
|
||||
_, err := io.CopyN(ioutil.Discard, r, int64(length))
|
||||
switch {
|
||||
|
||||
// We'll convert any EOFs to ErrUnexpectedEOF, since this
|
||||
// results in an invalid record.
|
||||
case err == io.EOF:
|
||||
return io.ErrUnexpectedEOF
|
||||
|
||||
// Other unexpected errors.
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Update our record index so that we can begin our next search
|
||||
// from where we left off.
|
||||
recordIdx = newIdx
|
||||
|
||||
// If we've parsed the largest possible type, the next loop will
|
||||
// overflow back to zero. However, we need to attempt parsing
|
||||
// the next type to ensure that the stream is empty.
|
||||
if typ == math.MaxUint64 {
|
||||
overflow = true
|
||||
}
|
||||
|
||||
// Finally, set our lower bound on the next accepted type.
|
||||
min = typ + 1
|
||||
}
|
||||
}
|
||||
|
||||
// getRecord searches for a record matching typ known to the stream. The boolean
|
||||
// return value indicates whether the record is known to the stream. The integer
|
||||
// return value carries the index from where getRecord should be invoked on the
|
||||
// subsequent call. The first call to getRecord should always use an idx of 0.
|
||||
func (s *Stream) getRecord(typ Type, idx int) (Record, int, bool) {
|
||||
for idx < len(s.records) {
|
||||
record := s.records[idx]
|
||||
switch {
|
||||
|
||||
// Found target record, return it to the caller. The next index
|
||||
// returned points to the immediately following record.
|
||||
case record.typ == typ:
|
||||
return record, idx + 1, true
|
||||
|
||||
// This record's type is lower than the target. Advance our
|
||||
// index and continue to the next record which will have a
|
||||
// strictly higher type.
|
||||
case record.typ < typ:
|
||||
idx++
|
||||
continue
|
||||
|
||||
// This record's type is larger than the target, hence we have
|
||||
// no record matching the current type. Return the current index
|
||||
// so that we can start our search from here when processing the
|
||||
// next tlv record.
|
||||
default:
|
||||
return Record{}, idx, false
|
||||
}
|
||||
}
|
||||
|
||||
// All known records are exhausted.
|
||||
return Record{}, idx, false
|
||||
}
|
559
tlv/tlv_test.go
Normal file
559
tlv/tlv_test.go
Normal file
@ -0,0 +1,559 @@
|
||||
package tlv_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
type nodeAmts struct {
|
||||
nodeID *btcec.PublicKey
|
||||
amt1 uint64
|
||||
amt2 uint64
|
||||
}
|
||||
|
||||
func ENodeAmts(w io.Writer, val interface{}, buf *[8]byte) error {
|
||||
if t, ok := val.(*nodeAmts); ok {
|
||||
if err := tlv.EPubKey(w, &t.nodeID, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tlv.EUint64T(w, t.amt1, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
return tlv.EUint64T(w, t.amt2, buf)
|
||||
}
|
||||
return tlv.NewTypeForEncodingErr(val, "nodeAmts")
|
||||
}
|
||||
|
||||
func DNodeAmts(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
|
||||
if t, ok := val.(*nodeAmts); ok && l == 49 {
|
||||
if err := tlv.DPubKey(r, &t.nodeID, buf, 33); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tlv.DUint64(r, &t.amt1, buf, 8); err != nil {
|
||||
return err
|
||||
}
|
||||
return tlv.DUint64(r, &t.amt2, buf, 8)
|
||||
}
|
||||
return tlv.NewTypeForDecodingErr(val, "nodeAmts", l, 49)
|
||||
}
|
||||
|
||||
type N1 struct {
|
||||
amt uint64
|
||||
scid uint64
|
||||
nodeAmts nodeAmts
|
||||
cltvDelta uint16
|
||||
|
||||
stream *tlv.Stream
|
||||
}
|
||||
|
||||
func (n *N1) sizeAmt() uint64 {
|
||||
return tlv.SizeTUint64(n.amt)
|
||||
}
|
||||
|
||||
func NewN1() *N1 {
|
||||
n := new(N1)
|
||||
|
||||
n.stream = tlv.MustNewStream(
|
||||
tlv.MakeDynamicRecord(
|
||||
1, &n.amt, n.sizeAmt, tlv.ETUint64, tlv.DTUint64,
|
||||
),
|
||||
tlv.MakePrimitiveRecord(2, &n.scid),
|
||||
tlv.MakeStaticRecord(3, &n.nodeAmts, 49, ENodeAmts, DNodeAmts),
|
||||
tlv.MakePrimitiveRecord(254, &n.cltvDelta),
|
||||
)
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
func (n *N1) Encode(w io.Writer) error {
|
||||
return n.stream.Encode(w)
|
||||
}
|
||||
|
||||
func (n *N1) Decode(r io.Reader) error {
|
||||
return n.stream.Decode(r)
|
||||
}
|
||||
|
||||
type N2 struct {
|
||||
amt uint64
|
||||
cltvExpiry uint32
|
||||
|
||||
stream *tlv.Stream
|
||||
}
|
||||
|
||||
func (n *N2) sizeAmt() uint64 {
|
||||
return tlv.SizeTUint64(n.amt)
|
||||
}
|
||||
|
||||
func (n *N2) sizeCltv() uint64 {
|
||||
return tlv.SizeTUint32(n.cltvExpiry)
|
||||
}
|
||||
|
||||
func NewN2() *N2 {
|
||||
n := new(N2)
|
||||
|
||||
n.stream = tlv.MustNewStream(
|
||||
tlv.MakeDynamicRecord(
|
||||
0, &n.amt, n.sizeAmt, tlv.ETUint64, tlv.DTUint64,
|
||||
),
|
||||
tlv.MakeDynamicRecord(
|
||||
11, &n.cltvExpiry, n.sizeCltv, tlv.ETUint32, tlv.DTUint32,
|
||||
),
|
||||
)
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
func (n *N2) Encode(w io.Writer) error {
|
||||
return n.stream.Encode(w)
|
||||
}
|
||||
|
||||
func (n *N2) Decode(r io.Reader) error {
|
||||
return n.stream.Decode(r)
|
||||
}
|
||||
|
||||
var tlvDecodingFailureTests = []struct {
|
||||
name string
|
||||
bytes []byte
|
||||
expErr error
|
||||
|
||||
// skipN2 if true, will cause the test to only be executed on N1.
|
||||
skipN2 bool
|
||||
}{
|
||||
{
|
||||
name: "type truncated",
|
||||
bytes: []byte{0xfd},
|
||||
expErr: io.ErrUnexpectedEOF,
|
||||
},
|
||||
{
|
||||
name: "type truncated",
|
||||
bytes: []byte{0xfd, 0x01},
|
||||
expErr: io.ErrUnexpectedEOF,
|
||||
},
|
||||
{
|
||||
name: "not minimally encoded type",
|
||||
bytes: []byte{0xfd, 0x00, 0x01}, // spec has trailing 0x00
|
||||
expErr: tlv.ErrVarIntNotCanonical,
|
||||
},
|
||||
{
|
||||
name: "missing length",
|
||||
bytes: []byte{0xfd, 0x01, 0x01},
|
||||
expErr: io.ErrUnexpectedEOF,
|
||||
},
|
||||
{
|
||||
name: "length truncated",
|
||||
bytes: []byte{0x0f, 0xfd},
|
||||
expErr: io.ErrUnexpectedEOF,
|
||||
},
|
||||
{
|
||||
name: "length truncated",
|
||||
bytes: []byte{0x0f, 0xfd, 0x26},
|
||||
expErr: io.ErrUnexpectedEOF,
|
||||
},
|
||||
{
|
||||
name: "missing value",
|
||||
bytes: []byte{0x0f, 0xfd, 0x26, 0x02},
|
||||
expErr: io.ErrUnexpectedEOF,
|
||||
},
|
||||
{
|
||||
name: "not minimally encoded length",
|
||||
bytes: []byte{0x0f, 0xfd, 0x00, 0x01}, // spec has trailing 0x00
|
||||
expErr: tlv.ErrVarIntNotCanonical,
|
||||
},
|
||||
{
|
||||
name: "value truncated",
|
||||
bytes: []byte{0x0f, 0xfd, 0x02, 0x01,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
},
|
||||
expErr: io.ErrUnexpectedEOF,
|
||||
},
|
||||
{
|
||||
name: "unknown even type",
|
||||
bytes: []byte{0x12, 0x00},
|
||||
expErr: tlv.ErrUnknownRequiredType(0x12),
|
||||
},
|
||||
{
|
||||
name: "unknown even type",
|
||||
bytes: []byte{0xfd, 0x01, 0x02, 0x00},
|
||||
expErr: tlv.ErrUnknownRequiredType(0x102),
|
||||
},
|
||||
{
|
||||
name: "unknown even type",
|
||||
bytes: []byte{0xfe, 0x01, 0x00, 0x00, 0x02, 0x00},
|
||||
expErr: tlv.ErrUnknownRequiredType(0x01000002),
|
||||
},
|
||||
{
|
||||
name: "unknown even type",
|
||||
bytes: []byte{0xff, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00},
|
||||
expErr: tlv.ErrUnknownRequiredType(0x0100000000000002),
|
||||
},
|
||||
{
|
||||
name: "greater than encoding length for n1's amt",
|
||||
bytes: []byte{0x01, 0x09, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
|
||||
expErr: tlv.NewTypeForDecodingErr(new(uint64), "uint64", 9, 8),
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "encoding for n1's amt is not minimal",
|
||||
bytes: []byte{0x01, 0x01, 0x00},
|
||||
expErr: tlv.ErrTUintNotMinimal,
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "encoding for n1's amt is not minimal",
|
||||
bytes: []byte{0x01, 0x02, 0x00, 0x01},
|
||||
expErr: tlv.ErrTUintNotMinimal,
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "encoding for n1's amt is not minimal",
|
||||
bytes: []byte{0x01, 0x03, 0x00, 0x01, 0x00},
|
||||
expErr: tlv.ErrTUintNotMinimal,
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "encoding for n1's amt is not minimal",
|
||||
bytes: []byte{0x01, 0x04, 0x00, 0x01, 0x00, 0x00},
|
||||
expErr: tlv.ErrTUintNotMinimal,
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "encoding for n1's amt is not minimal",
|
||||
bytes: []byte{0x01, 0x05, 0x00, 0x01, 0x00, 0x00, 0x00},
|
||||
expErr: tlv.ErrTUintNotMinimal,
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "encoding for n1's amt is not minimal",
|
||||
bytes: []byte{0x01, 0x06, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
|
||||
expErr: tlv.ErrTUintNotMinimal,
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "encoding for n1's amt is not minimal",
|
||||
bytes: []byte{0x01, 0x07, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00},
|
||||
expErr: tlv.ErrTUintNotMinimal,
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "encoding for n1's amt is not minimal",
|
||||
bytes: []byte{0x01, 0x08, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
|
||||
expErr: tlv.ErrTUintNotMinimal,
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "less than encoding length for n1's scid",
|
||||
bytes: []byte{0x02, 0x07, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01},
|
||||
expErr: tlv.NewTypeForDecodingErr(new(uint64), "uint64", 7, 8),
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "less than encoding length for n1's scid",
|
||||
bytes: []byte{0x02, 0x09, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01},
|
||||
expErr: tlv.NewTypeForDecodingErr(new(uint64), "uint64", 9, 8),
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "less than encoding length for n1's nodeAmts",
|
||||
bytes: []byte{0x03, 0x29,
|
||||
0x02, 0x3d, 0xa0, 0x92, 0xf6, 0x98, 0x0e, 0x58, 0xd2,
|
||||
0xc0, 0x37, 0x17, 0x31, 0x80, 0xe9, 0xa4, 0x65, 0x47,
|
||||
0x60, 0x26, 0xee, 0x50, 0xf9, 0x66, 0x95, 0x96, 0x3e,
|
||||
0x8e, 0xfe, 0x43, 0x6f, 0x54, 0xeb, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x01,
|
||||
},
|
||||
expErr: tlv.NewTypeForDecodingErr(new(nodeAmts), "nodeAmts", 41, 49),
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "less than encoding length for n1's nodeAmts",
|
||||
bytes: []byte{0x03, 0x30,
|
||||
0x02, 0x3d, 0xa0, 0x92, 0xf6, 0x98, 0x0e, 0x58, 0xd2,
|
||||
0xc0, 0x37, 0x17, 0x31, 0x80, 0xe9, 0xa4, 0x65, 0x47,
|
||||
0x60, 0x26, 0xee, 0x50, 0xf9, 0x66, 0x95, 0x96, 0x3e,
|
||||
0x8e, 0xfe, 0x43, 0x6f, 0x54, 0xeb, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x01,
|
||||
},
|
||||
expErr: tlv.NewTypeForDecodingErr(new(nodeAmts), "nodeAmts", 48, 49),
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "n1's node_id is not a valid point",
|
||||
bytes: []byte{0x03, 0x31,
|
||||
0x04, 0x3d, 0xa0, 0x92, 0xf6, 0x98, 0x0e, 0x58, 0xd2,
|
||||
0xc0, 0x37, 0x17, 0x31, 0x80, 0xe9, 0xa4, 0x65, 0x47,
|
||||
0x60, 0x26, 0xee, 0x50, 0xf9, 0x66, 0x95, 0x96, 0x3e,
|
||||
0x8e, 0xfe, 0x43, 0x6f, 0x54, 0xeb, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x02,
|
||||
},
|
||||
expErr: errors.New("invalid magic in compressed pubkey string: 4"),
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "greater than encoding length for n1's nodeAmts",
|
||||
bytes: []byte{0x03, 0x32,
|
||||
0x02, 0x3d, 0xa0, 0x92, 0xf6, 0x98, 0x0e, 0x58, 0xd2,
|
||||
0xc0, 0x37, 0x17, 0x31, 0x80, 0xe9, 0xa4, 0x65, 0x47,
|
||||
0x60, 0x26, 0xee, 0x50, 0xf9, 0x66, 0x95, 0x96, 0x3e,
|
||||
0x8e, 0xfe, 0x43, 0x6f, 0x54, 0xeb, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
|
||||
},
|
||||
expErr: tlv.NewTypeForDecodingErr(new(nodeAmts), "nodeAmts", 50, 49),
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "unknown required type or n1",
|
||||
bytes: []byte{0x00, 0x00},
|
||||
expErr: tlv.ErrUnknownRequiredType(0x00),
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "less than encoding length for n1's cltvDelta",
|
||||
bytes: []byte{0xfd, 0x00, 0x0fe, 0x00},
|
||||
expErr: tlv.NewTypeForDecodingErr(new(uint16), "uint16", 0, 2),
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "less than encoding length for n1's cltvDelta",
|
||||
bytes: []byte{0xfd, 0x00, 0xfe, 0x01, 0x01},
|
||||
expErr: tlv.NewTypeForDecodingErr(new(uint16), "uint16", 1, 2),
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "greater than encoding length for n1's cltvDelta",
|
||||
bytes: []byte{0xfd, 0x00, 0xfe, 0x03, 0x01, 0x01, 0x01},
|
||||
expErr: tlv.NewTypeForDecodingErr(new(uint16), "uint16", 3, 2),
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "unknown even field for n1's namespace",
|
||||
bytes: []byte{0x0a, 0x00},
|
||||
expErr: tlv.ErrUnknownRequiredType(0x0a),
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "valid records but invalid ordering",
|
||||
bytes: []byte{0x02, 0x08,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x26, 0x01,
|
||||
0x01, 0x2a,
|
||||
},
|
||||
expErr: tlv.ErrStreamNotCanonical,
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "duplicate tlv type",
|
||||
bytes: []byte{0x02, 0x08,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x31, 0x02,
|
||||
0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x51,
|
||||
},
|
||||
expErr: tlv.ErrStreamNotCanonical,
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "duplicate ignored tlv type",
|
||||
bytes: []byte{0x1f, 0x00, 0x1f, 0x01, 0x2a},
|
||||
expErr: tlv.ErrStreamNotCanonical,
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "type wraparound",
|
||||
bytes: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00},
|
||||
expErr: tlv.ErrStreamNotCanonical,
|
||||
},
|
||||
}
|
||||
|
||||
// TestTLVDecodingSuccess asserts that the TLV parser fails to decode invalid
|
||||
// TLV streams.
|
||||
func TestTLVDecodingFailures(t *testing.T) {
|
||||
for _, test := range tlvDecodingFailureTests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
n1 := NewN1()
|
||||
r := bytes.NewReader(test.bytes)
|
||||
|
||||
err := n1.Decode(r)
|
||||
if !reflect.DeepEqual(err, test.expErr) {
|
||||
t.Fatalf("expected N1 decoding failure: %v, "+
|
||||
"got: %v", test.expErr, err)
|
||||
}
|
||||
|
||||
if test.skipN2 {
|
||||
return
|
||||
}
|
||||
|
||||
n2 := NewN2()
|
||||
r = bytes.NewReader(test.bytes)
|
||||
|
||||
err = n2.Decode(r)
|
||||
if !reflect.DeepEqual(err, test.expErr) {
|
||||
t.Fatalf("expected N2 decoding failure: %v, "+
|
||||
"got: %v", test.expErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var tlvDecodingSuccessTests = []struct {
|
||||
name string
|
||||
bytes []byte
|
||||
skipN2 bool
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
},
|
||||
{
|
||||
name: "unknown odd type",
|
||||
bytes: []byte{0x21, 0x00},
|
||||
},
|
||||
{
|
||||
name: "unknown odd type",
|
||||
bytes: []byte{0xfd, 0x02, 0x01, 0x00},
|
||||
},
|
||||
{
|
||||
name: "unknown odd type",
|
||||
bytes: []byte{0xfd, 0x00, 0xfd, 0x00},
|
||||
},
|
||||
{
|
||||
name: "unknown odd type",
|
||||
bytes: []byte{0xfd, 0x00, 0xff, 0x00},
|
||||
},
|
||||
{
|
||||
name: "unknown odd type",
|
||||
bytes: []byte{0xfe, 0x02, 0x00, 0x00, 0x01, 0x00},
|
||||
},
|
||||
{
|
||||
name: "unknown odd type",
|
||||
bytes: []byte{0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00},
|
||||
},
|
||||
{
|
||||
name: "N1 amt=0",
|
||||
bytes: []byte{0x01, 0x00},
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "N1 amt=1",
|
||||
bytes: []byte{0x01, 0x01, 0x01},
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "N1 amt=256",
|
||||
bytes: []byte{0x01, 0x02, 0x01, 0x00},
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "N1 amt=65536",
|
||||
bytes: []byte{0x01, 0x03, 0x01, 0x00, 0x00},
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "N1 amt=16777216",
|
||||
bytes: []byte{0x01, 0x04, 0x01, 0x00, 0x00, 0x00},
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "N1 amt=4294967296",
|
||||
bytes: []byte{0x01, 0x05, 0x01, 0x00, 0x00, 0x00, 0x00},
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "N1 amt=1099511627776",
|
||||
bytes: []byte{0x01, 0x06, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00},
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "N1 amt=281474976710656",
|
||||
bytes: []byte{0x01, 0x07, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "N1 amt=72057594037927936",
|
||||
bytes: []byte{0x01, 0x08, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "N1 scid=0x0x550",
|
||||
bytes: []byte{0x02, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x26},
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "N1 node_id=023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb amount_msat_1=1 amount_msat_2=2",
|
||||
bytes: []byte{0x03, 0x31,
|
||||
0x02, 0x3d, 0xa0, 0x92, 0xf6, 0x98, 0x0e, 0x58, 0xd2,
|
||||
0xc0, 0x37, 0x17, 0x31, 0x80, 0xe9, 0xa4, 0x65, 0x47,
|
||||
0x60, 0x26, 0xee, 0x50, 0xf9, 0x66, 0x95, 0x96, 0x3e,
|
||||
0x8e, 0xfe, 0x43, 0x6f, 0x54, 0xeb, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x02},
|
||||
skipN2: true,
|
||||
},
|
||||
{
|
||||
name: "N1 cltv_delta=550",
|
||||
bytes: []byte{0xfd, 0x00, 0xfe, 0x02, 0x02, 0x26},
|
||||
skipN2: true,
|
||||
},
|
||||
}
|
||||
|
||||
// TestTLVDecodingSuccess asserts that the TLV parser is able to successfully
|
||||
// decode valid TLV streams.
|
||||
func TestTLVDecodingSuccess(t *testing.T) {
|
||||
for _, test := range tlvDecodingSuccessTests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
n1 := NewN1()
|
||||
r := bytes.NewReader(test.bytes)
|
||||
|
||||
err := n1.Decode(r)
|
||||
if err != nil {
|
||||
t.Fatalf("expected N1 decoding success, got: %v",
|
||||
err)
|
||||
}
|
||||
|
||||
if test.skipN2 {
|
||||
return
|
||||
}
|
||||
|
||||
n2 := NewN2()
|
||||
r = bytes.NewReader(test.bytes)
|
||||
|
||||
err = n2.Decode(r)
|
||||
if err != nil {
|
||||
t.Fatalf("expected N2 decoding succes, got: %v",
|
||||
err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
180
tlv/truncated.go
Normal file
180
tlv/truncated.go
Normal file
@ -0,0 +1,180 @@
|
||||
package tlv
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
// ErrTUintNotMinimal signals that decoding a truncated uint failed because the
|
||||
// value was not minimally encoded.
|
||||
var ErrTUintNotMinimal = errors.New("truncated uint not minimally encoded")
|
||||
|
||||
// numLeadingZeroBytes16 computes the number of leading zeros for a uint16.
|
||||
func numLeadingZeroBytes16(v uint16) uint64 {
|
||||
switch {
|
||||
case v == 0:
|
||||
return 2
|
||||
case v&0xff00 == 0:
|
||||
return 1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// SizeTUint16 returns the number of bytes remaining in a uint16 after
|
||||
// truncating the leading zeros.
|
||||
func SizeTUint16(v uint16) uint64 {
|
||||
return 2 - numLeadingZeroBytes16(v)
|
||||
}
|
||||
|
||||
// ETUint16 is an Encoder for truncated uint16 values, where leading zeros will
|
||||
// be omitted. An error is returned if val is not a *uint16.
|
||||
func ETUint16(w io.Writer, val interface{}, buf *[8]byte) error {
|
||||
if t, ok := val.(*uint16); ok {
|
||||
binary.BigEndian.PutUint16(buf[:2], *t)
|
||||
numZeros := numLeadingZeroBytes16(*t)
|
||||
_, err := w.Write(buf[numZeros:2])
|
||||
return err
|
||||
}
|
||||
return NewTypeForEncodingErr(val, "uint16")
|
||||
}
|
||||
|
||||
// DTUint16 is an Decoder for truncated uint16 values, where leading zeros will
|
||||
// be resurrected. An error is returned if val is not a *uint16.
|
||||
func DTUint16(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
|
||||
if t, ok := val.(*uint16); ok && l <= 2 {
|
||||
_, err := io.ReadFull(r, buf[2-l:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
zero(buf[:2-l])
|
||||
*t = binary.BigEndian.Uint16(buf[:2])
|
||||
if 2-numLeadingZeroBytes16(*t) != l {
|
||||
return ErrTUintNotMinimal
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return NewTypeForDecodingErr(val, "uint16", l, 2)
|
||||
}
|
||||
|
||||
// numLeadingZeroBytes16 computes the number of leading zeros for a uint32.
|
||||
func numLeadingZeroBytes32(v uint32) uint64 {
|
||||
switch {
|
||||
case v == 0:
|
||||
return 4
|
||||
case v&0xffffff00 == 0:
|
||||
return 3
|
||||
case v&0xffff0000 == 0:
|
||||
return 2
|
||||
case v&0xff000000 == 0:
|
||||
return 1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// SizeTUint32 returns the number of bytes remaining in a uint32 after
|
||||
// truncating the leading zeros.
|
||||
func SizeTUint32(v uint32) uint64 {
|
||||
return 4 - numLeadingZeroBytes32(v)
|
||||
}
|
||||
|
||||
// ETUint32 is an Encoder for truncated uint32 values, where leading zeros will
|
||||
// be omitted. An error is returned if val is not a *uint32.
|
||||
func ETUint32(w io.Writer, val interface{}, buf *[8]byte) error {
|
||||
if t, ok := val.(*uint32); ok {
|
||||
binary.BigEndian.PutUint32(buf[:4], *t)
|
||||
numZeros := numLeadingZeroBytes32(*t)
|
||||
_, err := w.Write(buf[numZeros:4])
|
||||
return err
|
||||
}
|
||||
return NewTypeForEncodingErr(val, "uint32")
|
||||
}
|
||||
|
||||
// DTUint32 is an Decoder for truncated uint32 values, where leading zeros will
|
||||
// be resurrected. An error is returned if val is not a *uint32.
|
||||
func DTUint32(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
|
||||
if t, ok := val.(*uint32); ok && l <= 4 {
|
||||
_, err := io.ReadFull(r, buf[4-l:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
zero(buf[:4-l])
|
||||
*t = binary.BigEndian.Uint32(buf[:4])
|
||||
if 4-numLeadingZeroBytes32(*t) != l {
|
||||
return ErrTUintNotMinimal
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return NewTypeForDecodingErr(val, "uint32", l, 4)
|
||||
}
|
||||
|
||||
// numLeadingZeroBytes64 computes the number of leading zeros for a uint32.
|
||||
//
|
||||
// TODO(conner): optimize using unrolled binary search
|
||||
func numLeadingZeroBytes64(v uint64) uint64 {
|
||||
switch {
|
||||
case v == 0:
|
||||
return 8
|
||||
case v&0xffffffffffffff00 == 0:
|
||||
return 7
|
||||
case v&0xffffffffffff0000 == 0:
|
||||
return 6
|
||||
case v&0xffffffffff000000 == 0:
|
||||
return 5
|
||||
case v&0xffffffff00000000 == 0:
|
||||
return 4
|
||||
case v&0xffffff0000000000 == 0:
|
||||
return 3
|
||||
case v&0xffff000000000000 == 0:
|
||||
return 2
|
||||
case v&0xff00000000000000 == 0:
|
||||
return 1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// SizeTUint64 returns the number of bytes remaining in a uint64 after
|
||||
// truncating the leading zeros.
|
||||
func SizeTUint64(v uint64) uint64 {
|
||||
return 8 - numLeadingZeroBytes64(v)
|
||||
}
|
||||
|
||||
// ETUint64 is an Encoder for truncated uint64 values, where leading zeros will
|
||||
// be omitted. An error is returned if val is not a *uint64.
|
||||
func ETUint64(w io.Writer, val interface{}, buf *[8]byte) error {
|
||||
if t, ok := val.(*uint64); ok {
|
||||
binary.BigEndian.PutUint64(buf[:], *t)
|
||||
numZeros := numLeadingZeroBytes64(*t)
|
||||
_, err := w.Write(buf[numZeros:])
|
||||
return err
|
||||
}
|
||||
return NewTypeForEncodingErr(val, "uint64")
|
||||
}
|
||||
|
||||
// DTUint64 is an Decoder for truncated uint64 values, where leading zeros will
|
||||
// be resurrected. An error is returned if val is not a *uint64.
|
||||
func DTUint64(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
|
||||
if t, ok := val.(*uint64); ok && l <= 8 {
|
||||
_, err := io.ReadFull(r, buf[8-l:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
zero(buf[:8-l])
|
||||
*t = binary.BigEndian.Uint64(buf[:])
|
||||
if 8-numLeadingZeroBytes64(*t) != l {
|
||||
return ErrTUintNotMinimal
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return NewTypeForDecodingErr(val, "uint64", l, 8)
|
||||
}
|
||||
|
||||
// zero clears the passed byte slice.
|
||||
func zero(b []byte) {
|
||||
for i := range b {
|
||||
b[i] = 0x00
|
||||
}
|
||||
}
|
109
tlv/varint.go
Normal file
109
tlv/varint.go
Normal file
@ -0,0 +1,109 @@
|
||||
package tlv
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
// ErrVarIntNotCanonical signals that the decoded varint was not minimally encoded.
|
||||
var ErrVarIntNotCanonical = errors.New("decoded varint is not canonical")
|
||||
|
||||
// ReadVarInt reads a variable length integer from r and returns it as a uint64.
|
||||
func ReadVarInt(r io.Reader, buf *[8]byte) (uint64, error) {
|
||||
_, err := io.ReadFull(r, buf[:1])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
discriminant := buf[0]
|
||||
|
||||
var rv uint64
|
||||
switch {
|
||||
case discriminant < 0xfd:
|
||||
rv = uint64(discriminant)
|
||||
|
||||
case discriminant == 0xfd:
|
||||
_, err := io.ReadFull(r, buf[:2])
|
||||
switch {
|
||||
case err == io.EOF:
|
||||
return 0, io.ErrUnexpectedEOF
|
||||
case err != nil:
|
||||
return 0, err
|
||||
}
|
||||
rv = uint64(binary.BigEndian.Uint16(buf[:2]))
|
||||
|
||||
// The encoding is not canonical if the value could have been
|
||||
// encoded using fewer bytes.
|
||||
if rv < 0xfd {
|
||||
return 0, ErrVarIntNotCanonical
|
||||
}
|
||||
|
||||
case discriminant == 0xfe:
|
||||
_, err := io.ReadFull(r, buf[:4])
|
||||
switch {
|
||||
case err == io.EOF:
|
||||
return 0, io.ErrUnexpectedEOF
|
||||
case err != nil:
|
||||
return 0, err
|
||||
}
|
||||
rv = uint64(binary.BigEndian.Uint32(buf[:4]))
|
||||
|
||||
// The encoding is not canonical if the value could have been
|
||||
// encoded using fewer bytes.
|
||||
if rv <= 0xffff {
|
||||
return 0, ErrVarIntNotCanonical
|
||||
}
|
||||
|
||||
default:
|
||||
_, err := io.ReadFull(r, buf[:])
|
||||
switch {
|
||||
case err == io.EOF:
|
||||
return 0, io.ErrUnexpectedEOF
|
||||
case err != nil:
|
||||
return 0, err
|
||||
}
|
||||
rv = binary.BigEndian.Uint64(buf[:])
|
||||
|
||||
// The encoding is not canonical if the value could have been
|
||||
// encoded using fewer bytes.
|
||||
if rv <= 0xffffffff {
|
||||
return 0, ErrVarIntNotCanonical
|
||||
}
|
||||
}
|
||||
|
||||
return rv, nil
|
||||
}
|
||||
|
||||
// WriteVarInt serializes val to w using a variable number of bytes depending
|
||||
// on its value.
|
||||
func WriteVarInt(w io.Writer, val uint64, buf *[8]byte) error {
|
||||
var length int
|
||||
switch {
|
||||
case val < 0xfd:
|
||||
buf[0] = uint8(val)
|
||||
length = 1
|
||||
|
||||
case val <= 0xffff:
|
||||
buf[0] = uint8(0xfd)
|
||||
binary.BigEndian.PutUint16(buf[1:3], uint16(val))
|
||||
length = 3
|
||||
|
||||
case val <= 0xffffffff:
|
||||
buf[0] = uint8(0xfe)
|
||||
binary.BigEndian.PutUint32(buf[1:5], uint32(val))
|
||||
length = 5
|
||||
|
||||
default:
|
||||
buf[0] = uint8(0xff)
|
||||
_, err := w.Write(buf[:1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint64(buf[:], uint64(val))
|
||||
length = 8
|
||||
}
|
||||
|
||||
_, err := w.Write(buf[:length])
|
||||
return err
|
||||
}
|
217
tlv/varint_test.go
Normal file
217
tlv/varint_test.go
Normal file
@ -0,0 +1,217 @@
|
||||
package tlv_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
type varIntTest struct {
|
||||
Name string
|
||||
Value uint64
|
||||
Bytes []byte
|
||||
ExpErr error
|
||||
}
|
||||
|
||||
var writeVarIntTests = []varIntTest{
|
||||
{
|
||||
Name: "zero",
|
||||
Value: 0x00,
|
||||
Bytes: []byte{0x00},
|
||||
},
|
||||
{
|
||||
Name: "one byte high",
|
||||
Value: 0xfc,
|
||||
Bytes: []byte{0xfc},
|
||||
},
|
||||
{
|
||||
Name: "two byte low",
|
||||
Value: 0xfd,
|
||||
Bytes: []byte{0xfd, 0x00, 0xfd},
|
||||
},
|
||||
{
|
||||
Name: "two byte high",
|
||||
Value: 0xffff,
|
||||
Bytes: []byte{0xfd, 0xff, 0xff},
|
||||
},
|
||||
{
|
||||
Name: "four byte low",
|
||||
Value: 0x10000,
|
||||
Bytes: []byte{0xfe, 0x00, 0x01, 0x00, 0x00},
|
||||
},
|
||||
{
|
||||
Name: "four byte high",
|
||||
Value: 0xffffffff,
|
||||
Bytes: []byte{0xfe, 0xff, 0xff, 0xff, 0xff},
|
||||
},
|
||||
{
|
||||
Name: "eight byte low",
|
||||
Value: 0x100000000,
|
||||
Bytes: []byte{0xff, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
|
||||
},
|
||||
{
|
||||
Name: "eight byte high",
|
||||
Value: math.MaxUint64,
|
||||
Bytes: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
|
||||
},
|
||||
}
|
||||
|
||||
// TestWriteVarInt asserts the behavior of tlv.WriteVarInt under various
|
||||
// positive and negative test cases.
|
||||
func TestWriteVarInt(t *testing.T) {
|
||||
for _, test := range writeVarIntTests {
|
||||
t.Run(test.Name, func(t *testing.T) {
|
||||
testWriteVarInt(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testWriteVarInt(t *testing.T, test varIntTest) {
|
||||
var (
|
||||
w bytes.Buffer
|
||||
buf [8]byte
|
||||
)
|
||||
err := tlv.WriteVarInt(&w, test.Value, &buf)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to encode %d as varint: %v",
|
||||
test.Value, err)
|
||||
}
|
||||
|
||||
if bytes.Compare(w.Bytes(), test.Bytes) != 0 {
|
||||
t.Fatalf("expected bytes: %v, got %v",
|
||||
test.Bytes, w.Bytes())
|
||||
}
|
||||
}
|
||||
|
||||
var readVarIntTests = []varIntTest{
|
||||
{
|
||||
Name: "zero",
|
||||
Value: 0x00,
|
||||
Bytes: []byte{0x00},
|
||||
},
|
||||
{
|
||||
Name: "one byte high",
|
||||
Value: 0xfc,
|
||||
Bytes: []byte{0xfc},
|
||||
},
|
||||
{
|
||||
Name: "two byte low",
|
||||
Value: 0xfd,
|
||||
Bytes: []byte{0xfd, 0x00, 0xfd},
|
||||
},
|
||||
{
|
||||
Name: "two byte high",
|
||||
Value: 0xffff,
|
||||
Bytes: []byte{0xfd, 0xff, 0xff},
|
||||
},
|
||||
{
|
||||
Name: "four byte low",
|
||||
Value: 0x10000,
|
||||
Bytes: []byte{0xfe, 0x00, 0x01, 0x00, 0x00},
|
||||
},
|
||||
{
|
||||
Name: "four byte high",
|
||||
Value: 0xffffffff,
|
||||
Bytes: []byte{0xfe, 0xff, 0xff, 0xff, 0xff},
|
||||
},
|
||||
{
|
||||
Name: "eight byte low",
|
||||
Value: 0x100000000,
|
||||
Bytes: []byte{0xff, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
|
||||
},
|
||||
{
|
||||
Name: "eight byte high",
|
||||
Value: math.MaxUint64,
|
||||
Bytes: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
|
||||
},
|
||||
{
|
||||
Name: "two byte not canonical",
|
||||
Bytes: []byte{0xfd, 0x00, 0xfc},
|
||||
ExpErr: tlv.ErrVarIntNotCanonical,
|
||||
},
|
||||
{
|
||||
Name: "four byte not canonical",
|
||||
Bytes: []byte{0xfe, 0x00, 0x00, 0xff, 0xff},
|
||||
ExpErr: tlv.ErrVarIntNotCanonical,
|
||||
},
|
||||
{
|
||||
Name: "eight byte not canonical",
|
||||
Bytes: []byte{0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff},
|
||||
ExpErr: tlv.ErrVarIntNotCanonical,
|
||||
},
|
||||
{
|
||||
Name: "two byte short read",
|
||||
Bytes: []byte{0xfd, 0x00},
|
||||
ExpErr: io.ErrUnexpectedEOF,
|
||||
},
|
||||
{
|
||||
Name: "four byte short read",
|
||||
Bytes: []byte{0xfe, 0xff, 0xff},
|
||||
ExpErr: io.ErrUnexpectedEOF,
|
||||
},
|
||||
{
|
||||
Name: "eight byte short read",
|
||||
Bytes: []byte{0xff, 0xff, 0xff, 0xff, 0xff},
|
||||
ExpErr: io.ErrUnexpectedEOF,
|
||||
},
|
||||
{
|
||||
Name: "one byte no read",
|
||||
Bytes: []byte{},
|
||||
ExpErr: io.EOF,
|
||||
},
|
||||
// The following cases are the reason for needing to make a custom
|
||||
// version of the varint for the tlv package. For the varint encodings
|
||||
// in btcd's wire package these would return io.EOF, since it is
|
||||
// actually a composite of two calls to io.ReadFull. In TLV, we need to
|
||||
// be able to distinguish whether no bytes were read at all from no
|
||||
// Bytes being read on the second read as the latter is not a proper TLV
|
||||
// stream. We handle this by returning io.ErrUnexpectedEOF if we
|
||||
// encounter io.EOF on any of these secondary reads for larger values.
|
||||
{
|
||||
Name: "two byte no read",
|
||||
Bytes: []byte{0xfd},
|
||||
ExpErr: io.ErrUnexpectedEOF,
|
||||
},
|
||||
{
|
||||
Name: "four byte no read",
|
||||
Bytes: []byte{0xfe},
|
||||
ExpErr: io.ErrUnexpectedEOF,
|
||||
},
|
||||
{
|
||||
Name: "eight byte no read",
|
||||
Bytes: []byte{0xff},
|
||||
ExpErr: io.ErrUnexpectedEOF,
|
||||
},
|
||||
}
|
||||
|
||||
// TestReadVarInt asserts the behavior of tlv.ReadVarInt under various positive
|
||||
// and negative test cases.
|
||||
func TestReadVarInt(t *testing.T) {
|
||||
for _, test := range readVarIntTests {
|
||||
t.Run(test.Name, func(t *testing.T) {
|
||||
testReadVarInt(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testReadVarInt(t *testing.T, test varIntTest) {
|
||||
var buf [8]byte
|
||||
r := bytes.NewReader(test.Bytes)
|
||||
val, err := tlv.ReadVarInt(r, &buf)
|
||||
if err != nil && err != test.ExpErr {
|
||||
t.Fatalf("expected decoding error: %v, got: %v",
|
||||
test.ExpErr, err)
|
||||
}
|
||||
|
||||
// If we expected a decoding error, there's no point checking the value.
|
||||
if test.ExpErr != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if val != test.Value {
|
||||
t.Fatalf("expected value: %d, got %d", test.Value, val)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user