You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
165 lines
3.9 KiB
165 lines
3.9 KiB
package tlv_test |
|
|
|
import ( |
|
"bytes" |
|
"io" |
|
"io/ioutil" |
|
"testing" |
|
|
|
"github.com/lightningnetwork/lnd/lnwallet/chainfee" |
|
"github.com/lightningnetwork/lnd/tlv" |
|
"github.com/lightningnetwork/lnd/watchtower/blob" |
|
"github.com/lightningnetwork/lnd/watchtower/wtwire" |
|
"github.com/stretchr/testify/require" |
|
) |
|
|
|
// CreateSessionTLV mirrors the wtwire.CreateSession message, but uses TLV for |
|
// encoding/decoding. |
|
type CreateSessionTLV struct { |
|
BlobType blob.Type |
|
MaxUpdates uint16 |
|
RewardBase uint32 |
|
RewardRate uint32 |
|
SweepFeeRate chainfee.SatPerKWeight |
|
|
|
tlvStream *tlv.Stream |
|
} |
|
|
|
// EBlobType is an encoder for blob.Type. |
|
func EBlobType(w io.Writer, val interface{}, buf *[8]byte) error { |
|
if t, ok := val.(*blob.Type); ok { |
|
return tlv.EUint16T(w, uint16(*t), buf) |
|
} |
|
return tlv.NewTypeForEncodingErr(val, "blob.Type") |
|
} |
|
|
|
// EBlobType is an decoder for blob.Type. |
|
func DBlobType(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { |
|
if typ, ok := val.(*blob.Type); ok { |
|
var t uint16 |
|
err := tlv.DUint16(r, &t, buf, l) |
|
if err != nil { |
|
return err |
|
} |
|
*typ = blob.Type(t) |
|
return nil |
|
} |
|
return tlv.NewTypeForDecodingErr(val, "blob.Type", l, 2) |
|
} |
|
|
|
// ESatPerKW is an encoder for lnwallet.SatPerKWeight. |
|
func ESatPerKW(w io.Writer, val interface{}, buf *[8]byte) error { |
|
if v, ok := val.(*chainfee.SatPerKWeight); ok { |
|
v64 := uint64(*v) |
|
return tlv.EUint64(w, &v64, buf) |
|
} |
|
return tlv.NewTypeForEncodingErr(val, "chainfee.SatPerKWeight") |
|
} |
|
|
|
// DSatPerKW is an decoder for lnwallet.SatPerKWeight. |
|
func DSatPerKW(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { |
|
if v, ok := val.(*chainfee.SatPerKWeight); ok { |
|
var sat uint64 |
|
err := tlv.DUint64(r, &sat, buf, l) |
|
if err != nil { |
|
return err |
|
} |
|
*v = chainfee.SatPerKWeight(sat) |
|
return nil |
|
} |
|
return tlv.NewTypeForDecodingErr(val, "chainfee.SatPerKWeight", l, 8) |
|
} |
|
|
|
// NewCreateSessionTLV initializes a new CreateSessionTLV message. |
|
func NewCreateSessionTLV() *CreateSessionTLV { |
|
m := &CreateSessionTLV{} |
|
m.tlvStream = tlv.MustNewStream( |
|
tlv.MakeStaticRecord(0, &m.BlobType, 2, EBlobType, DBlobType), |
|
tlv.MakePrimitiveRecord(1, &m.MaxUpdates), |
|
tlv.MakePrimitiveRecord(2, &m.RewardBase), |
|
tlv.MakePrimitiveRecord(3, &m.RewardRate), |
|
tlv.MakeStaticRecord(4, &m.SweepFeeRate, 8, ESatPerKW, DSatPerKW), |
|
) |
|
|
|
return m |
|
} |
|
|
|
// Encode writes the CreateSessionTLV to the passed io.Writer. |
|
func (c *CreateSessionTLV) Encode(w io.Writer) error { |
|
return c.tlvStream.Encode(w) |
|
} |
|
|
|
// Decode reads the CreateSessionTLV from the passed io.Reader. |
|
func (c *CreateSessionTLV) Decode(r io.Reader) error { |
|
return c.tlvStream.Decode(r) |
|
} |
|
|
|
// BenchmarkEncodeCreateSession benchmarks encoding of the non-TLV |
|
// CreateSession. |
|
func BenchmarkEncodeCreateSession(t *testing.B) { |
|
m := &wtwire.CreateSession{} |
|
|
|
t.ReportAllocs() |
|
t.ResetTimer() |
|
|
|
var err error |
|
for i := 0; i < t.N; i++ { |
|
err = m.Encode(ioutil.Discard, 0) |
|
} |
|
require.NoError(t, err) |
|
} |
|
|
|
// BenchmarkEncodeCreateSessionTLV benchmarks encoding of the TLV CreateSession. |
|
func BenchmarkEncodeCreateSessionTLV(t *testing.B) { |
|
m := NewCreateSessionTLV() |
|
|
|
t.ReportAllocs() |
|
t.ResetTimer() |
|
|
|
var err error |
|
for i := 0; i < t.N; i++ { |
|
err = m.Encode(ioutil.Discard) |
|
} |
|
require.NoError(t, err) |
|
} |
|
|
|
// BenchmarkDecodeCreateSession benchmarks encoding of the non-TLV |
|
// CreateSession. |
|
func BenchmarkDecodeCreateSession(t *testing.B) { |
|
m := &wtwire.CreateSession{} |
|
|
|
var b bytes.Buffer |
|
err := m.Encode(&b, 0) |
|
require.NoError(t, err) |
|
|
|
r := bytes.NewReader(b.Bytes()) |
|
|
|
t.ReportAllocs() |
|
t.ResetTimer() |
|
|
|
for i := 0; i < t.N; i++ { |
|
r.Seek(0, 0) |
|
err = m.Decode(r, 0) |
|
} |
|
require.NoError(t, err) |
|
} |
|
|
|
// BenchmarkDecodeCreateSessionTLV benchmarks decoding of the TLV CreateSession. |
|
func BenchmarkDecodeCreateSessionTLV(t *testing.B) { |
|
m := NewCreateSessionTLV() |
|
|
|
var b bytes.Buffer |
|
err := m.Encode(&b) |
|
require.NoError(t, err) |
|
|
|
r := bytes.NewReader(b.Bytes()) |
|
|
|
t.ReportAllocs() |
|
t.ResetTimer() |
|
|
|
for i := 0; i < t.N; i++ { |
|
r.Seek(0, 0) |
|
err = m.Decode(r) |
|
} |
|
require.NoError(t, err) |
|
}
|
|
|