diff --git a/routing/pathfind.go b/routing/pathfind.go index 297392a0..ea5315b6 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -99,11 +99,12 @@ type edgePolicyWithSource struct { // finalHopParams encapsulates various parameters for route construction that // apply to the final hop in a route. These features include basic payment data // such as amounts and cltvs, as well as more complex features like destination -// custom records. +// custom records and payment address. type finalHopParams struct { - amt lnwire.MilliSatoshi - cltvDelta uint16 - records record.CustomSet + amt lnwire.MilliSatoshi + cltvDelta uint16 + records record.CustomSet + paymentAddr *[32]byte } // newRoute constructs a route using the provided path and final hop constraints. @@ -152,6 +153,7 @@ func newRoute(sourceVertex route.Vertex, outgoingTimeLock uint32 tlvPayload bool customRecords record.CustomSet + mpp *record.MPP ) // Define a helper function that checks this edge's feature @@ -191,6 +193,21 @@ func newRoute(sourceVertex route.Vertex, "custom records") } customRecords = finalHop.records + + // If we're attaching a payment addr but the receiver + // doesn't support both TLV and payment addrs, fail. + payAddr := supports(lnwire.PaymentAddrOptional) + if !payAddr && finalHop.paymentAddr != nil { + return nil, errors.New("cannot attach " + + "payment addr") + } + + // Otherwise attach the mpp record if it exists. + if finalHop.paymentAddr != nil { + mpp = record.NewMPP( + finalHop.amt, *finalHop.paymentAddr, + ) + } } else { // The amount that the current hop needs to forward is // equal to the incoming amount of the next hop. @@ -220,6 +237,7 @@ func newRoute(sourceVertex route.Vertex, OutgoingTimeLock: outgoingTimeLock, LegacyPayload: !tlvPayload, CustomRecords: customRecords, + MPP: mpp, } hops = append([]*route.Hop{currentHop}, hops...) diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index f7607166..acff0d2b 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -13,6 +13,7 @@ import ( "math/big" "net" "os" + "reflect" "strings" "testing" "time" @@ -1061,6 +1062,8 @@ func TestNewRoute(t *testing.T) { var sourceKey [33]byte sourceVertex := route.Vertex(sourceKey) + testPaymentAddr := [32]byte{0x01, 0x02, 0x03} + const ( startingHeight = 100 finalHopCLTV = 1 @@ -1099,6 +1102,8 @@ func TestNewRoute(t *testing.T) { // overwrite the final hop's feature vector in the graph. destFeatures *lnwire.FeatureVector + paymentAddr *[32]byte + // expectedFees is a list of fees that every hop is expected // to charge for forwarding. expectedFees []lnwire.MilliSatoshi @@ -1129,6 +1134,8 @@ func TestNewRoute(t *testing.T) { expectedErrorCode errorCode expectedTLVPayload bool + + expectedMPP *record.MPP }{ { // For a single hop payment, no fees are expected to be paid. @@ -1171,6 +1178,26 @@ func TestNewRoute(t *testing.T) { expectedTotalAmount: 100130, expectedTotalTimeLock: 6, expectedTLVPayload: true, + }, { + // For a two hop payment, only the fee for the first hop + // needs to be paid. The destination hop does not require + // a fee to receive the payment. + name: "two hop single shot mpp", + destFeatures: tlvPayAddrFeatures, + paymentAddr: &testPaymentAddr, + paymentAmount: 100000, + hops: []*channeldb.ChannelEdgePolicy{ + createHop(0, 1000, 1000000, 10), + createHop(30, 1000, 1000000, 5), + }, + expectedFees: []lnwire.MilliSatoshi{130, 0}, + expectedTimeLocks: []uint32{1, 1}, + expectedTotalAmount: 100130, + expectedTotalTimeLock: 6, + expectedTLVPayload: true, + expectedMPP: record.NewMPP( + 100000, testPaymentAddr, + ), }, { // A three hop payment where the first and second hop // will both charge 1 msat. The fee for the first hop @@ -1284,20 +1311,29 @@ func TestNewRoute(t *testing.T) { if !finalHop.LegacyPayload != testCase.expectedTLVPayload { - t.Errorf("Expected tlv payload: %t, "+ + t.Errorf("Expected final hop tlv payload: %t, "+ "but got: %t instead", testCase.expectedTLVPayload, !finalHop.LegacyPayload) } + + if !reflect.DeepEqual( + finalHop.MPP, testCase.expectedMPP, + ) { + t.Errorf("Expected final hop mpp field: %v, "+ + " but got: %v instead", + testCase.expectedMPP, finalHop.MPP) + } } t.Run(testCase.name, func(t *testing.T) { route, err := newRoute( sourceVertex, testCase.hops, startingHeight, finalHopParams{ - amt: testCase.paymentAmount, - cltvDelta: finalHopCLTV, - records: nil, + amt: testCase.paymentAmount, + cltvDelta: finalHopCLTV, + records: nil, + paymentAddr: testCase.paymentAddr, }, )