From de88a4b17427eb489c57d316d3073e6cbc11597f Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Tue, 28 Jan 2020 06:43:07 -0800 Subject: [PATCH 1/3] record: add AMP record and encode/decode methods --- record/amp.go | 107 ++++++++++++++++++++++++++++++++++++++++++ record/record_test.go | 30 +++++++++++- 2 files changed, 135 insertions(+), 2 deletions(-) create mode 100644 record/amp.go diff --git a/record/amp.go b/record/amp.go new file mode 100644 index 00000000..72b4cbf4 --- /dev/null +++ b/record/amp.go @@ -0,0 +1,107 @@ +package record + +import ( + "fmt" + "io" + + "github.com/lightningnetwork/lnd/tlv" +) + +// AMPOnionType is the type used in the onion to reference the AMP fields: +// root_share, set_id, and child_index. +const AMPOnionType tlv.Type = 10 + +// AMP is a record that encodes the fields necessary for atomic multi-path +// payments. +type AMP struct { + rootShare [32]byte + setID [32]byte + childIndex uint16 +} + +// NewAMP generate a new AMP record with the given root_share, set_id, and +// child_index. +func NewAMP(rootShare, setID [32]byte, childIndex uint16) *AMP { + return &{ + rootShare: rootShare, + setID: setID, + childIndex: childIndex, + } +} + +// RootShare returns the root share contained in the AMP record. +func (a *AMP) RootShare() [32]byte { + return a.rootShare +} + +// SetID returns the set id contained in the AMP record. +func (a *AMP) SetID() [32]byte { + return a.setID +} + +// ChildIndex returns the child index contained in the AMP record. +func (a *AMP) ChildIndex() uint16 { + return a.childIndex +} + +// AMPEncoder writes the AMP record to the provided io.Writer. +func AMPEncoder(w io.Writer, val interface{}, buf *[8]byte) error { + if v, ok := val.(*AMP); ok { + if err := tlv.EBytes32(w, &v.rootShare, buf); err != nil { + return err + } + + if err := tlv.EBytes32(w, &v.setID, buf); err != nil { + return err + } + + return tlv.ETUint16T(w, v.childIndex, buf) + } + return tlv.NewTypeForEncodingErr(val, "AMP") +} + +const ( + // minAMPLength is the minimum length of a serialized AMP TLV record, + // which occurs when the truncated encoding of child_index takes 0 + // bytes, leaving only the root_share and set_id. + minAMPLength = 64 + + // maxAMPLength is the maximum legnth of a serialized AMP TLV record, + // which occurs when the truncated endoing of a child_index takes 2 + // bytes. + maxAMPLength = 66 +) + +// AMPDecoder reads the AMP record from the provided io.Reader. +func AMPDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if v, ok := val.(*AMP); ok && minAMPLength <= l && l <= maxAMPLength { + if err := tlv.DBytes32(r, &v.rootShare, buf, 32); err != nil { + return err + } + + if err := tlv.DBytes32(r, &v.setID, buf, 32); err != nil { + return err + } + + return tlv.DTUint16(r, &v.childIndex, buf, l-64) + } + return tlv.NewTypeForDecodingErr(val, "AMP", l, maxAMPLength) +} + +// Record returns a tlv.Record that can be used to encode or decode this record. +func (a *AMP) Record() tlv.Record { + return tlv.MakeDynamicRecord( + AMPOnionType, a, a.PayloadSize, AMPEncoder, AMPDecoder, + ) +} + +// PayloadSize returns the size this record takes up in encoded form. +func (a *AMP) PayloadSize() uint64 { + return 32 + 32 + tlv.SizeTUint16(a.childIndex) +} + +// String returns a human-readble description of the amp payload fields. +func (a *AMP) String() string { + return fmt.Sprintf("root_share=%x set_id=%x child_index=%d", + a.rootShare, a.setID, a.childIndex) +} diff --git a/record/record_test.go b/record/record_test.go index 052e2f1f..8c39790e 100644 --- a/record/record_test.go +++ b/record/record_test.go @@ -17,8 +17,11 @@ type recordEncDecTest struct { } var ( - testTotal = lnwire.MilliSatoshi(45) - testAddr = [32]byte{0x01, 0x02} + testTotal = lnwire.MilliSatoshi(45) + testAddr = [32]byte{0x01, 0x02} + testShare = [32]byte{0x03, 0x04} + testSetID = [32]byte{0x05, 0x06} + testChildIndex = uint16(17) ) var recordEncDecTests = []recordEncDecTest{ @@ -40,6 +43,29 @@ var recordEncDecTests = []recordEncDecTest{ } }, }, + { + name: "amp", + encRecord: func() tlv.RecordProducer { + return record.NewAMP( + testShare, testSetID, testChildIndex, + ) + }, + decRecord: func() tlv.RecordProducer { + return new(record.AMP) + }, + assert: func(t *testing.T, r interface{}) { + amp := r.(*record.AMP) + if amp.RootShare() != testShare { + t.Fatal("incorrect root share") + } + if amp.SetID() != testSetID { + t.Fatal("incorrect set id") + } + if amp.ChildIndex() != testChildIndex { + t.Fatal("incorrect child index") + } + }, + }, } // TestRecordEncodeDecode is a generic test framework for custom TLV records. It From 0cb27151e50d2d6ea7a6144d962db3223a110fa7 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Tue, 28 Jan 2020 06:43:34 -0800 Subject: [PATCH 2/3] routing/route: add AMP record to payload size calcs --- routing/route/route.go | 25 +++++++++++++++++++++++++ routing/route/route_test.go | 5 +++-- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/routing/route/route.go b/routing/route/route.go index 90a407e3..f4de728f 100644 --- a/routing/route/route.go +++ b/routing/route/route.go @@ -34,6 +34,10 @@ var ( // record to an intermediate hop, only final hops can receive MPP // records. ErrIntermediateMPPHop = errors.New("cannot send MPP to intermediate") + + // ErrAMPMissingMPP is returned when the caller tries to attach an AMP + // record but no MPP record is presented for the final hop. + ErrAMPMissingMPP = errors.New("cannot send AMP without MPP record") ) // Vertex is a simple alias for the serialization of a compressed Bitcoin @@ -111,6 +115,10 @@ type Hop struct { // only be set for the final hop. MPP *record.MPP + // AMP encapsulates the data required for option_amp. This field should + // only be set for the final hop. + AMP *record.AMP + // CustomRecords if non-nil are a set of additional TLV records that // should be included in the forwarding instructions for this node. CustomRecords record.CustomSet @@ -168,6 +176,18 @@ func (h *Hop) PackHopPayload(w io.Writer, nextChanID uint64) error { } } + // If an AMP record is destined for this hop, ensure that we only ever + // attach it if we also have an MPP record. We can infer that this is + // already a final hop if MPP is non-nil otherwise we would have exited + // above. + if h.AMP != nil { + if h.MPP != nil { + records = append(records, h.AMP.Record()) + } else { + return ErrAMPMissingMPP + } + } + // Append any custom types destined for this hop. tlvRecords := tlv.MapToRecords(h.CustomRecords) records = append(records, tlvRecords...) @@ -217,6 +237,11 @@ func (h *Hop) PayloadSize(nextChanID uint64) uint64 { addRecord(record.MPPOnionType, h.MPP.PayloadSize()) } + // Add amp if present. + if h.AMP != nil { + addRecord(record.AMPOnionType, h.AMP.PayloadSize()) + } + // Add custom records. for k, v := range h.CustomRecords { addRecord(tlv.Type(k), uint64(len(v))) diff --git a/routing/route/route_test.go b/routing/route/route_test.go index 6c32d8be..df1b9fc3 100644 --- a/routing/route/route_test.go +++ b/routing/route/route_test.go @@ -71,8 +71,8 @@ var ( testAddr = [32]byte{0x01, 0x02} ) -// TestMPPHop asserts that a Hop will encode a non-nil to final nodes, and fail -// when trying to send to intermediaries. +// TestMPPHop asserts that a Hop will encode a non-nil MPP to final nodes, and +// fail when trying to send to intermediaries. func TestMPPHop(t *testing.T) { t.Parallel() @@ -123,6 +123,7 @@ func TestPayloadSize(t *testing.T) { AmtToForward: 1200, OutgoingTimeLock: 700000, MPP: record.NewMPP(500, [32]byte{}), + AMP: record.NewAMP([32]byte{}, [32]byte{}, 8), CustomRecords: map[uint64][]byte{ 100000: {1, 2, 3}, 1000000: {4, 5}, From 9fc197d8b191de51cb4afc42bf8ffbb2217469b5 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Tue, 28 Jan 2020 06:43:44 -0800 Subject: [PATCH 3/3] routing/route: fix TestMPPHop comment --- routing/route/route_test.go | 41 +++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/routing/route/route_test.go b/routing/route/route_test.go index df1b9fc3..2095430b 100644 --- a/routing/route/route_test.go +++ b/routing/route/route_test.go @@ -101,6 +101,47 @@ func TestMPPHop(t *testing.T) { } } +// TestAMPHop asserts that a Hop will encode a non-nil AMP to final nodes of an +// MPP record is also present, and fail otherwise. +func TestAMPHop(t *testing.T) { + t.Parallel() + + hop := Hop{ + ChannelID: 1, + OutgoingTimeLock: 44, + AmtToForward: testAmt, + LegacyPayload: false, + AMP: record.NewAMP([32]byte{}, [32]byte{}, 3), + } + + // Encoding an AMP record to an intermediate hop w/o an MPP record + // should result in a failure. + var b bytes.Buffer + err := hop.PackHopPayload(&b, 2) + if err != ErrAMPMissingMPP { + t.Fatalf("expected err: %v, got: %v", + ErrAMPMissingMPP, err) + } + + // Encoding an AMP record to a final hop w/o an MPP record should result + // in a failure. + b.Reset() + err = hop.PackHopPayload(&b, 0) + if err != ErrAMPMissingMPP { + t.Fatalf("expected err: %v, got: %v", + ErrAMPMissingMPP, err) + } + + // Encoding an AMP record to a final hop w/ an MPP record should be + // successful. + hop.MPP = record.NewMPP(testAmt, testAddr) + b.Reset() + err = hop.PackHopPayload(&b, 0) + if err != nil { + t.Fatalf("expected err: %v, got: %v", nil, err) + } +} + // TestPayloadSize tests the payload size calculation that is provided by Hop // structs. func TestPayloadSize(t *testing.T) {