diff --git a/htlcswitch/hop/payload.go b/htlcswitch/hop/payload.go index c6a25ba6..e233537f 100644 --- a/htlcswitch/hop/payload.go +++ b/htlcswitch/hop/payload.go @@ -119,6 +119,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { amt uint64 cltv uint32 mpp = &record.MPP{} + amp = &record.AMP{} ) tlvStream, err := tlv.NewStream( @@ -126,6 +127,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { record.NewLockTimeRecord(&cltv), record.NewNextHopIDRecord(&cid), mpp.Record(), + amp.Record(), ) if err != nil { return nil, err @@ -160,6 +162,12 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { mpp = nil } + // If no AMP field was parsed, set the MPP field on the resulting + // payload to nil. + if _, ok := parsedTypes[record.AMPOnionType]; !ok { + amp = nil + } + // Filter out the custom records. customRecords := NewCustomRecords(parsedTypes) @@ -171,6 +179,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { OutgoingCTLV: cltv, }, MPP: mpp, + AMP: amp, customRecords: customRecords, }, nil } @@ -207,6 +216,7 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeMap, _, hasLockTime := parsedTypes[record.LockTimeOnionType] _, hasNextHop := parsedTypes[record.NextHopOnionType] _, hasMPP := parsedTypes[record.MPPOnionType] + _, hasAMP := parsedTypes[record.AMPOnionType] switch { @@ -243,6 +253,14 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeMap, Violation: IncludedViolation, FinalHop: isFinalHop, } + + // Intermediate nodes should never receive AMP fields. + case !isFinalHop && hasAMP: + return ErrInvalidPayload{ + Type: record.AMPOnionType, + Violation: IncludedViolation, + FinalHop: isFinalHop, + } } return nil diff --git a/htlcswitch/hop/payload_test.go b/htlcswitch/hop/payload_test.go index bfe2c134..c7abc9fa 100644 --- a/htlcswitch/hop/payload_test.go +++ b/htlcswitch/hop/payload_test.go @@ -8,6 +8,7 @@ import ( "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" + "github.com/stretchr/testify/require" ) const testUnknownRequiredType = 0x10 @@ -18,6 +19,7 @@ type decodePayloadTest struct { expErr error expCustomRecords map[uint64][]byte shouldHaveMPP bool + shouldHaveAMP bool } var decodePayloadTests = []decodePayloadTest{ @@ -183,6 +185,37 @@ var decodePayloadTests = []decodePayloadTest{ FinalHop: false, }, }, + { + name: "intermediate hop with amp", + payload: []byte{ + // amount + 0x02, 0x00, + // cltv + 0x04, 0x00, + // next hop id + 0x06, 0x08, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // amp + 0x0e, 0x41, + // amp.root_share + 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, + 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, + 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, + 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, + // amp.set_id + 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, + 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, + 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, + 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, + // amp.child_index + 0x09, + }, + expErr: hop.ErrInvalidPayload{ + Type: record.AMPOnionType, + Violation: hop.IncludedViolation, + FinalHop: false, + }, + }, { name: "final hop with mpp", payload: []byte{ @@ -201,6 +234,30 @@ var decodePayloadTests = []decodePayloadTest{ expErr: nil, shouldHaveMPP: true, }, + { + name: "final hop with amp", + payload: []byte{ + // amount + 0x02, 0x00, + // cltv + 0x04, 0x00, + // amp + 0x0e, 0x41, + // amp.root_share + 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, + 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, + 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, + 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, + // amp.set_id + 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, + 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, + 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, + 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, + // amp.child_index + 0x09, + }, + shouldHaveAMP: true, + }, } // TestDecodeHopPayloadRecordValidation asserts that parsing the payloads in the @@ -223,6 +280,20 @@ func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) { 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, } + + testRootShare = [32]byte{ + 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, + 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, + 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, + 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, + } + testSetID = [32]byte{ + 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, + 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, + 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, + 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, + } + testChildIndex = uint32(9) ) p, err := hop.NewPayloadFromReader(bytes.NewReader(test.payload)) @@ -249,6 +320,17 @@ func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) { t.Fatalf("unexpected MPP payload") } + if test.shouldHaveAMP { + if p.AMP == nil { + t.Fatalf("payload should have AMP record") + } + require.Equal(t, testRootShare, p.AMP.RootShare()) + require.Equal(t, testSetID, p.AMP.SetID()) + require.Equal(t, testChildIndex, p.AMP.ChildIndex()) + } else if p.AMP != nil { + t.Fatalf("unexpected AMP payload") + } + // Convert expected nil map to empty map, because we always expect an // initiated map from the payload. expCustomRecords := make(record.CustomSet)