diff --git a/record/mpp.go b/record/mpp.go new file mode 100644 index 00000000..b28d1085 --- /dev/null +++ b/record/mpp.go @@ -0,0 +1,98 @@ +package record + +import ( + "io" + + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" +) + +// MPPOnionType is the type used in the onion to reference the MPP fields: +// total_amt and payment_addr. +const MPPOnionType tlv.Type = 8 + +// MPP is a record that encodes the fields necessary for multi-path payments. +type MPP struct { + // paymentAddr is a random, receiver-generated value used to avoid + // collisions with concurrent payers. + paymentAddr [32]byte + + // totalMsat is the total value of the payment, potentially spread + // across more than one HTLC. + totalMsat lnwire.MilliSatoshi +} + +// NewMPP generates a new MPP record with the given total and payment address. +func NewMPP(total lnwire.MilliSatoshi, addr [32]byte) *MPP { + return &MPP{ + paymentAddr: addr, + totalMsat: total, + } +} + +// PaymentAddr returns the payment address contained in the MPP record. +func (r *MPP) PaymentAddr() [32]byte { + return r.paymentAddr +} + +// TotalMsat returns the total value of an MPP payment in msats. +func (r *MPP) TotalMsat() lnwire.MilliSatoshi { + return r.totalMsat +} + +// MPPEncoder writes the MPP record to the provided io.Writer. +func MPPEncoder(w io.Writer, val interface{}, buf *[8]byte) error { + if v, ok := val.(*MPP); ok { + err := tlv.EBytes32(w, &v.paymentAddr, buf) + if err != nil { + return err + } + + return tlv.ETUint64T(w, uint64(v.totalMsat), buf) + } + return tlv.NewTypeForEncodingErr(val, "MPP") +} + +const ( + // minMPPLength is the minimum length of a serialized MPP TLV record, + // which occurs when the truncated encoding of total_amt_msat takes 0 + // bytes, leaving only the payment_addr. + minMPPLength = 32 + + // maxMPPLength is the maximum length of a serialized MPP TLV record, + // which occurs when the truncated encoding of total_amt_msat takes 8 + // bytes. + maxMPPLength = 40 +) + +// MPPDecoder reads the MPP record to the provided io.Reader. +func MPPDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if v, ok := val.(*MPP); ok && minMPPLength <= l && l <= maxMPPLength { + if err := tlv.DBytes32(r, &v.paymentAddr, buf, 32); err != nil { + return err + } + + var total uint64 + if err := tlv.DTUint64(r, &total, buf, l-32); err != nil { + return err + } + v.totalMsat = lnwire.MilliSatoshi(total) + + return nil + + } + return tlv.NewTypeForDecodingErr(val, "MPP", l, maxMPPLength) +} + +// Record returns a tlv.Record that can be used to encode or decode this record. +func (r *MPP) Record() tlv.Record { + // Fixed-size, 32 byte payment address followed by truncated 64-bit + // total msat. + size := func() uint64 { + return 32 + tlv.SizeTUint64(uint64(r.totalMsat)) + } + + return tlv.MakeDynamicRecord( + MPPOnionType, r, size, MPPEncoder, MPPDecoder, + ) +} diff --git a/record/record_test.go b/record/record_test.go new file mode 100644 index 00000000..052e2f1f --- /dev/null +++ b/record/record_test.go @@ -0,0 +1,73 @@ +package record_test + +import ( + "bytes" + "testing" + + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" + "github.com/lightningnetwork/lnd/tlv" +) + +type recordEncDecTest struct { + name string + encRecord func() tlv.RecordProducer + decRecord func() tlv.RecordProducer + assert func(*testing.T, interface{}) +} + +var ( + testTotal = lnwire.MilliSatoshi(45) + testAddr = [32]byte{0x01, 0x02} +) + +var recordEncDecTests = []recordEncDecTest{ + { + name: "mpp", + encRecord: func() tlv.RecordProducer { + return record.NewMPP(testTotal, testAddr) + }, + decRecord: func() tlv.RecordProducer { + return new(record.MPP) + }, + assert: func(t *testing.T, r interface{}) { + mpp := r.(*record.MPP) + if mpp.TotalMsat() != testTotal { + t.Fatal("incorrect total msat") + } + if mpp.PaymentAddr() != testAddr { + t.Fatal("incorrect payment addr") + } + }, + }, +} + +// TestRecordEncodeDecode is a generic test framework for custom TLV records. It +// asserts that records can encode and decode themselves, and that the value of +// the original record matches the decoded record. +func TestRecordEncodeDecode(t *testing.T) { + for _, test := range recordEncDecTests { + test := test + t.Run(test.name, func(t *testing.T) { + r := test.encRecord() + r2 := test.decRecord() + encStream := tlv.MustNewStream(r.Record()) + decStream := tlv.MustNewStream(r2.Record()) + + test.assert(t, r) + + var b bytes.Buffer + err := encStream.Encode(&b) + if err != nil { + t.Fatalf("unable to encode record: %v", err) + } + + err = decStream.Decode(bytes.NewReader(b.Bytes())) + if err != nil { + t.Fatalf("unable to decode record: %v", err) + } + + test.assert(t, r2) + }) + } +} diff --git a/tlv/record.go b/tlv/record.go index 75647e03..66ae8f1c 100644 --- a/tlv/record.go +++ b/tlv/record.go @@ -43,6 +43,14 @@ func SizeVarBytes(e *[]byte) SizeFunc { } } +// RecorderProducer is an interface for objects that can produce a Record object +// capable of encoding and/or decoding the RecordProducer as a Record. +type RecordProducer interface { + // Record returns a Record that can be used to encode or decode the + // backing object. + Record() Record +} + // Record holds the required information to encode or decode a TLV record. type Record struct { value interface{}