Merge pull request #3957 from cfromknecht/amp-record
record+routing/route: add AMP record
This commit is contained in:
commit
07977a2bf0
107
record/amp.go
Normal file
107
record/amp.go
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
package record
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/tlv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AMPOnionType is the type used in the onion to reference the AMP fields:
|
||||||
|
// root_share, set_id, and child_index.
|
||||||
|
const AMPOnionType tlv.Type = 10
|
||||||
|
|
||||||
|
// AMP is a record that encodes the fields necessary for atomic multi-path
|
||||||
|
// payments.
|
||||||
|
type AMP struct {
|
||||||
|
rootShare [32]byte
|
||||||
|
setID [32]byte
|
||||||
|
childIndex uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAMP generate a new AMP record with the given root_share, set_id, and
|
||||||
|
// child_index.
|
||||||
|
func NewAMP(rootShare, setID [32]byte, childIndex uint16) *AMP {
|
||||||
|
return &{
|
||||||
|
rootShare: rootShare,
|
||||||
|
setID: setID,
|
||||||
|
childIndex: childIndex,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RootShare returns the root share contained in the AMP record.
|
||||||
|
func (a *AMP) RootShare() [32]byte {
|
||||||
|
return a.rootShare
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetID returns the set id contained in the AMP record.
|
||||||
|
func (a *AMP) SetID() [32]byte {
|
||||||
|
return a.setID
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChildIndex returns the child index contained in the AMP record.
|
||||||
|
func (a *AMP) ChildIndex() uint16 {
|
||||||
|
return a.childIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
// AMPEncoder writes the AMP record to the provided io.Writer.
|
||||||
|
func AMPEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
|
||||||
|
if v, ok := val.(*AMP); ok {
|
||||||
|
if err := tlv.EBytes32(w, &v.rootShare, buf); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tlv.EBytes32(w, &v.setID, buf); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tlv.ETUint16T(w, v.childIndex, buf)
|
||||||
|
}
|
||||||
|
return tlv.NewTypeForEncodingErr(val, "AMP")
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// minAMPLength is the minimum length of a serialized AMP TLV record,
|
||||||
|
// which occurs when the truncated encoding of child_index takes 0
|
||||||
|
// bytes, leaving only the root_share and set_id.
|
||||||
|
minAMPLength = 64
|
||||||
|
|
||||||
|
// maxAMPLength is the maximum legnth of a serialized AMP TLV record,
|
||||||
|
// which occurs when the truncated endoing of a child_index takes 2
|
||||||
|
// bytes.
|
||||||
|
maxAMPLength = 66
|
||||||
|
)
|
||||||
|
|
||||||
|
// AMPDecoder reads the AMP record from the provided io.Reader.
|
||||||
|
func AMPDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
|
||||||
|
if v, ok := val.(*AMP); ok && minAMPLength <= l && l <= maxAMPLength {
|
||||||
|
if err := tlv.DBytes32(r, &v.rootShare, buf, 32); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tlv.DBytes32(r, &v.setID, buf, 32); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tlv.DTUint16(r, &v.childIndex, buf, l-64)
|
||||||
|
}
|
||||||
|
return tlv.NewTypeForDecodingErr(val, "AMP", l, maxAMPLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record returns a tlv.Record that can be used to encode or decode this record.
|
||||||
|
func (a *AMP) Record() tlv.Record {
|
||||||
|
return tlv.MakeDynamicRecord(
|
||||||
|
AMPOnionType, a, a.PayloadSize, AMPEncoder, AMPDecoder,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PayloadSize returns the size this record takes up in encoded form.
|
||||||
|
func (a *AMP) PayloadSize() uint64 {
|
||||||
|
return 32 + 32 + tlv.SizeTUint16(a.childIndex)
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns a human-readble description of the amp payload fields.
|
||||||
|
func (a *AMP) String() string {
|
||||||
|
return fmt.Sprintf("root_share=%x set_id=%x child_index=%d",
|
||||||
|
a.rootShare, a.setID, a.childIndex)
|
||||||
|
}
|
@ -17,8 +17,11 @@ type recordEncDecTest struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
testTotal = lnwire.MilliSatoshi(45)
|
testTotal = lnwire.MilliSatoshi(45)
|
||||||
testAddr = [32]byte{0x01, 0x02}
|
testAddr = [32]byte{0x01, 0x02}
|
||||||
|
testShare = [32]byte{0x03, 0x04}
|
||||||
|
testSetID = [32]byte{0x05, 0x06}
|
||||||
|
testChildIndex = uint16(17)
|
||||||
)
|
)
|
||||||
|
|
||||||
var recordEncDecTests = []recordEncDecTest{
|
var recordEncDecTests = []recordEncDecTest{
|
||||||
@ -40,6 +43,29 @@ var recordEncDecTests = []recordEncDecTest{
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "amp",
|
||||||
|
encRecord: func() tlv.RecordProducer {
|
||||||
|
return record.NewAMP(
|
||||||
|
testShare, testSetID, testChildIndex,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
decRecord: func() tlv.RecordProducer {
|
||||||
|
return new(record.AMP)
|
||||||
|
},
|
||||||
|
assert: func(t *testing.T, r interface{}) {
|
||||||
|
amp := r.(*record.AMP)
|
||||||
|
if amp.RootShare() != testShare {
|
||||||
|
t.Fatal("incorrect root share")
|
||||||
|
}
|
||||||
|
if amp.SetID() != testSetID {
|
||||||
|
t.Fatal("incorrect set id")
|
||||||
|
}
|
||||||
|
if amp.ChildIndex() != testChildIndex {
|
||||||
|
t.Fatal("incorrect child index")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestRecordEncodeDecode is a generic test framework for custom TLV records. It
|
// TestRecordEncodeDecode is a generic test framework for custom TLV records. It
|
||||||
|
@ -34,6 +34,10 @@ var (
|
|||||||
// record to an intermediate hop, only final hops can receive MPP
|
// record to an intermediate hop, only final hops can receive MPP
|
||||||
// records.
|
// records.
|
||||||
ErrIntermediateMPPHop = errors.New("cannot send MPP to intermediate")
|
ErrIntermediateMPPHop = errors.New("cannot send MPP to intermediate")
|
||||||
|
|
||||||
|
// ErrAMPMissingMPP is returned when the caller tries to attach an AMP
|
||||||
|
// record but no MPP record is presented for the final hop.
|
||||||
|
ErrAMPMissingMPP = errors.New("cannot send AMP without MPP record")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Vertex is a simple alias for the serialization of a compressed Bitcoin
|
// Vertex is a simple alias for the serialization of a compressed Bitcoin
|
||||||
@ -111,6 +115,10 @@ type Hop struct {
|
|||||||
// only be set for the final hop.
|
// only be set for the final hop.
|
||||||
MPP *record.MPP
|
MPP *record.MPP
|
||||||
|
|
||||||
|
// AMP encapsulates the data required for option_amp. This field should
|
||||||
|
// only be set for the final hop.
|
||||||
|
AMP *record.AMP
|
||||||
|
|
||||||
// CustomRecords if non-nil are a set of additional TLV records that
|
// CustomRecords if non-nil are a set of additional TLV records that
|
||||||
// should be included in the forwarding instructions for this node.
|
// should be included in the forwarding instructions for this node.
|
||||||
CustomRecords record.CustomSet
|
CustomRecords record.CustomSet
|
||||||
@ -168,6 +176,18 @@ func (h *Hop) PackHopPayload(w io.Writer, nextChanID uint64) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If an AMP record is destined for this hop, ensure that we only ever
|
||||||
|
// attach it if we also have an MPP record. We can infer that this is
|
||||||
|
// already a final hop if MPP is non-nil otherwise we would have exited
|
||||||
|
// above.
|
||||||
|
if h.AMP != nil {
|
||||||
|
if h.MPP != nil {
|
||||||
|
records = append(records, h.AMP.Record())
|
||||||
|
} else {
|
||||||
|
return ErrAMPMissingMPP
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Append any custom types destined for this hop.
|
// Append any custom types destined for this hop.
|
||||||
tlvRecords := tlv.MapToRecords(h.CustomRecords)
|
tlvRecords := tlv.MapToRecords(h.CustomRecords)
|
||||||
records = append(records, tlvRecords...)
|
records = append(records, tlvRecords...)
|
||||||
@ -217,6 +237,11 @@ func (h *Hop) PayloadSize(nextChanID uint64) uint64 {
|
|||||||
addRecord(record.MPPOnionType, h.MPP.PayloadSize())
|
addRecord(record.MPPOnionType, h.MPP.PayloadSize())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add amp if present.
|
||||||
|
if h.AMP != nil {
|
||||||
|
addRecord(record.AMPOnionType, h.AMP.PayloadSize())
|
||||||
|
}
|
||||||
|
|
||||||
// Add custom records.
|
// Add custom records.
|
||||||
for k, v := range h.CustomRecords {
|
for k, v := range h.CustomRecords {
|
||||||
addRecord(tlv.Type(k), uint64(len(v)))
|
addRecord(tlv.Type(k), uint64(len(v)))
|
||||||
|
@ -71,8 +71,8 @@ var (
|
|||||||
testAddr = [32]byte{0x01, 0x02}
|
testAddr = [32]byte{0x01, 0x02}
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestMPPHop asserts that a Hop will encode a non-nil to final nodes, and fail
|
// TestMPPHop asserts that a Hop will encode a non-nil MPP to final nodes, and
|
||||||
// when trying to send to intermediaries.
|
// fail when trying to send to intermediaries.
|
||||||
func TestMPPHop(t *testing.T) {
|
func TestMPPHop(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@ -101,6 +101,47 @@ func TestMPPHop(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestAMPHop asserts that a Hop will encode a non-nil AMP to final nodes of an
|
||||||
|
// MPP record is also present, and fail otherwise.
|
||||||
|
func TestAMPHop(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
hop := Hop{
|
||||||
|
ChannelID: 1,
|
||||||
|
OutgoingTimeLock: 44,
|
||||||
|
AmtToForward: testAmt,
|
||||||
|
LegacyPayload: false,
|
||||||
|
AMP: record.NewAMP([32]byte{}, [32]byte{}, 3),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encoding an AMP record to an intermediate hop w/o an MPP record
|
||||||
|
// should result in a failure.
|
||||||
|
var b bytes.Buffer
|
||||||
|
err := hop.PackHopPayload(&b, 2)
|
||||||
|
if err != ErrAMPMissingMPP {
|
||||||
|
t.Fatalf("expected err: %v, got: %v",
|
||||||
|
ErrAMPMissingMPP, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encoding an AMP record to a final hop w/o an MPP record should result
|
||||||
|
// in a failure.
|
||||||
|
b.Reset()
|
||||||
|
err = hop.PackHopPayload(&b, 0)
|
||||||
|
if err != ErrAMPMissingMPP {
|
||||||
|
t.Fatalf("expected err: %v, got: %v",
|
||||||
|
ErrAMPMissingMPP, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encoding an AMP record to a final hop w/ an MPP record should be
|
||||||
|
// successful.
|
||||||
|
hop.MPP = record.NewMPP(testAmt, testAddr)
|
||||||
|
b.Reset()
|
||||||
|
err = hop.PackHopPayload(&b, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected err: %v, got: %v", nil, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestPayloadSize tests the payload size calculation that is provided by Hop
|
// TestPayloadSize tests the payload size calculation that is provided by Hop
|
||||||
// structs.
|
// structs.
|
||||||
func TestPayloadSize(t *testing.T) {
|
func TestPayloadSize(t *testing.T) {
|
||||||
@ -123,6 +164,7 @@ func TestPayloadSize(t *testing.T) {
|
|||||||
AmtToForward: 1200,
|
AmtToForward: 1200,
|
||||||
OutgoingTimeLock: 700000,
|
OutgoingTimeLock: 700000,
|
||||||
MPP: record.NewMPP(500, [32]byte{}),
|
MPP: record.NewMPP(500, [32]byte{}),
|
||||||
|
AMP: record.NewAMP([32]byte{}, [32]byte{}, 8),
|
||||||
CustomRecords: map[uint64][]byte{
|
CustomRecords: map[uint64][]byte{
|
||||||
100000: {1, 2, 3},
|
100000: {1, 2, 3},
|
||||||
1000000: {4, 5},
|
1000000: {4, 5},
|
||||||
|
Loading…
Reference in New Issue
Block a user