166 lines
3.9 KiB
Go
166 lines
3.9 KiB
Go
package tlv_test
|
|
|
|
import (
|
|
"bytes"
|
|
"io"
|
|
"io/ioutil"
|
|
"testing"
|
|
|
|
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
|
|
"github.com/lightningnetwork/lnd/tlv"
|
|
"github.com/lightningnetwork/lnd/watchtower/blob"
|
|
"github.com/lightningnetwork/lnd/watchtower/wtwire"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// CreateSessionTLV mirrors the wtwire.CreateSession message, but uses TLV for
|
|
// encoding/decoding.
|
|
type CreateSessionTLV struct {
|
|
BlobType blob.Type
|
|
MaxUpdates uint16
|
|
RewardBase uint32
|
|
RewardRate uint32
|
|
SweepFeeRate chainfee.SatPerKWeight
|
|
|
|
tlvStream *tlv.Stream
|
|
}
|
|
|
|
// EBlobType is an encoder for blob.Type.
|
|
func EBlobType(w io.Writer, val interface{}, buf *[8]byte) error {
|
|
if t, ok := val.(*blob.Type); ok {
|
|
return tlv.EUint16T(w, uint16(*t), buf)
|
|
}
|
|
return tlv.NewTypeForEncodingErr(val, "blob.Type")
|
|
}
|
|
|
|
// EBlobType is an decoder for blob.Type.
|
|
func DBlobType(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
|
|
if typ, ok := val.(*blob.Type); ok {
|
|
var t uint16
|
|
err := tlv.DUint16(r, &t, buf, l)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
*typ = blob.Type(t)
|
|
return nil
|
|
}
|
|
return tlv.NewTypeForDecodingErr(val, "blob.Type", l, 2)
|
|
}
|
|
|
|
// ESatPerKW is an encoder for lnwallet.SatPerKWeight.
|
|
func ESatPerKW(w io.Writer, val interface{}, buf *[8]byte) error {
|
|
if v, ok := val.(*chainfee.SatPerKWeight); ok {
|
|
v64 := uint64(*v)
|
|
return tlv.EUint64(w, &v64, buf)
|
|
}
|
|
return tlv.NewTypeForEncodingErr(val, "chainfee.SatPerKWeight")
|
|
}
|
|
|
|
// DSatPerKW is an decoder for lnwallet.SatPerKWeight.
|
|
func DSatPerKW(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
|
|
if v, ok := val.(*chainfee.SatPerKWeight); ok {
|
|
var sat uint64
|
|
err := tlv.DUint64(r, &sat, buf, l)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
*v = chainfee.SatPerKWeight(sat)
|
|
return nil
|
|
}
|
|
return tlv.NewTypeForDecodingErr(val, "chainfee.SatPerKWeight", l, 8)
|
|
}
|
|
|
|
// NewCreateSessionTLV initializes a new CreateSessionTLV message.
|
|
func NewCreateSessionTLV() *CreateSessionTLV {
|
|
m := &CreateSessionTLV{}
|
|
m.tlvStream = tlv.MustNewStream(
|
|
tlv.MakeStaticRecord(0, &m.BlobType, 2, EBlobType, DBlobType),
|
|
tlv.MakePrimitiveRecord(1, &m.MaxUpdates),
|
|
tlv.MakePrimitiveRecord(2, &m.RewardBase),
|
|
tlv.MakePrimitiveRecord(3, &m.RewardRate),
|
|
tlv.MakeStaticRecord(4, &m.SweepFeeRate, 8, ESatPerKW, DSatPerKW),
|
|
)
|
|
|
|
return m
|
|
}
|
|
|
|
// Encode writes the CreateSessionTLV to the passed io.Writer.
|
|
func (c *CreateSessionTLV) Encode(w io.Writer) error {
|
|
return c.tlvStream.Encode(w)
|
|
}
|
|
|
|
// Decode reads the CreateSessionTLV from the passed io.Reader.
|
|
func (c *CreateSessionTLV) Decode(r io.Reader) error {
|
|
return c.tlvStream.Decode(r)
|
|
}
|
|
|
|
// BenchmarkEncodeCreateSession benchmarks encoding of the non-TLV
|
|
// CreateSession.
|
|
func BenchmarkEncodeCreateSession(t *testing.B) {
|
|
m := &wtwire.CreateSession{}
|
|
|
|
t.ReportAllocs()
|
|
t.ResetTimer()
|
|
|
|
var err error
|
|
for i := 0; i < t.N; i++ {
|
|
err = m.Encode(ioutil.Discard, 0)
|
|
}
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// BenchmarkEncodeCreateSessionTLV benchmarks encoding of the TLV CreateSession.
|
|
func BenchmarkEncodeCreateSessionTLV(t *testing.B) {
|
|
m := NewCreateSessionTLV()
|
|
|
|
t.ReportAllocs()
|
|
t.ResetTimer()
|
|
|
|
var err error
|
|
for i := 0; i < t.N; i++ {
|
|
err = m.Encode(ioutil.Discard)
|
|
}
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// BenchmarkDecodeCreateSession benchmarks encoding of the non-TLV
|
|
// CreateSession.
|
|
func BenchmarkDecodeCreateSession(t *testing.B) {
|
|
m := &wtwire.CreateSession{}
|
|
|
|
var b bytes.Buffer
|
|
err := m.Encode(&b, 0)
|
|
require.NoError(t, err)
|
|
|
|
r := bytes.NewReader(b.Bytes())
|
|
|
|
t.ReportAllocs()
|
|
t.ResetTimer()
|
|
|
|
for i := 0; i < t.N; i++ {
|
|
r.Seek(0, 0)
|
|
err = m.Decode(r, 0)
|
|
}
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// BenchmarkDecodeCreateSessionTLV benchmarks decoding of the TLV CreateSession.
|
|
func BenchmarkDecodeCreateSessionTLV(t *testing.B) {
|
|
m := NewCreateSessionTLV()
|
|
|
|
var b bytes.Buffer
|
|
err := m.Encode(&b)
|
|
require.NoError(t, err)
|
|
|
|
r := bytes.NewReader(b.Bytes())
|
|
|
|
t.ReportAllocs()
|
|
t.ResetTimer()
|
|
|
|
for i := 0; i < t.N; i++ {
|
|
r.Seek(0, 0)
|
|
err = m.Decode(r)
|
|
}
|
|
require.NoError(t, err)
|
|
}
|