diff --git a/lntest/itest/lnd_multi-hop-error-propagation_test.go b/lntest/itest/lnd_multi-hop-error-propagation_test.go index e0a02280..6fc62921 100644 --- a/lntest/itest/lnd_multi-hop-error-propagation_test.go +++ b/lntest/itest/lnd_multi-hop-error-propagation_test.go @@ -204,6 +204,7 @@ out: FinalCltvDelta: int32(carolPayReq.CltvExpiry), TimeoutSeconds: 60, FeeLimitMsat: noFeeLimitMsat, + MaxParts: 1, } sendAndAssertFailure( t, net.Alice, @@ -240,6 +241,7 @@ out: FinalCltvDelta: int32(carolPayReq.CltvExpiry), TimeoutSeconds: 60, FeeLimitMsat: noFeeLimitMsat, + MaxParts: 1, } sendAndAssertFailure( t, net.Alice, @@ -300,6 +302,7 @@ out: PaymentRequest: carolInvoice2.PaymentRequest, TimeoutSeconds: 60, FeeLimitMsat: noFeeLimitMsat, + MaxParts: 1, }, ) @@ -332,6 +335,7 @@ out: PaymentRequest: carolInvoice3.PaymentRequest, TimeoutSeconds: 60, FeeLimitMsat: noFeeLimitMsat, + MaxParts: 1, } sendAndAssertFailure( t, net.Alice, @@ -381,6 +385,7 @@ out: PaymentRequest: carolInvoice.PaymentRequest, TimeoutSeconds: 60, FeeLimitMsat: noFeeLimitMsat, + MaxParts: 1, }, lnrpc.PaymentFailureReason_FAILURE_REASON_NO_ROUTE, ) diff --git a/routing/integrated_routing_context_test.go b/routing/integrated_routing_context_test.go index ddbc1036..c2ebae48 100644 --- a/routing/integrated_routing_context_test.go +++ b/routing/integrated_routing_context_test.go @@ -28,6 +28,7 @@ type integratedRoutingContext struct { target *mockNode amt lnwire.MilliSatoshi + maxShardAmt *lnwire.MilliSatoshi finalExpiry int32 mcCfg MissionControlConfig @@ -151,6 +152,10 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32, MaxParts: maxParts, } + if c.maxShardAmt != nil { + payment.MaxShardAmt = c.maxShardAmt + } + session, err := newPaymentSession( &payment, getBandwidthHints, func() (routingGraph, func(), error) { diff --git a/routing/integrated_routing_test.go b/routing/integrated_routing_test.go index 6764da5c..80fefaf9 100644 --- a/routing/integrated_routing_test.go +++ b/routing/integrated_routing_test.go @@ -89,6 +89,7 @@ type mppSendTestCase struct { graph func(g *mockGraph) expectedFailure bool maxParts uint32 + maxShardSize btcutil.Amount } const ( @@ -208,6 +209,33 @@ var mppTestCases = []mppSendTestCase{ expectedFailure: true, maxParts: 10, }, + + // Test that if maxShardSize is set, then all attempts are below the + // max shard size, yet still sum up to the total payment amount. A + // payment of 30k satoshis with a max shard size of 10k satoshis should + // produce 3 payments of 10k sats each. + { + name: "max shard size clamping", + graph: onePathGraph, + amt: 30_000, + expectedAttempts: 3, + expectedSuccesses: []expectedHtlcSuccess{ + { + amt: 10_000, + chans: []uint64{chanSourceIm1, chanIm1Target}, + }, + { + amt: 10_000, + chans: []uint64{chanSourceIm1, chanIm1Target}, + }, + { + amt: 10_000, + chans: []uint64{chanSourceIm1, chanIm1Target}, + }, + }, + maxParts: 1000, + maxShardSize: 10_000, + }, } // TestMppSend tests that a payment can be completed using multiple shards. @@ -229,6 +257,11 @@ func testMppSend(t *testing.T, testCase *mppSendTestCase) { ctx.amt = lnwire.NewMSatFromSatoshis(testCase.amt) + if testCase.maxShardSize != 0 { + shardAmt := lnwire.NewMSatFromSatoshis(testCase.maxShardSize) + ctx.maxShardAmt = &shardAmt + } + attempts, err := ctx.testPayment(testCase.maxParts) switch { case err == nil && testCase.expectedFailure: diff --git a/routing/payment_session.go b/routing/payment_session.go index 4b766d4c..58666583 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -230,6 +230,18 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, finalHtlcExpiry := int32(height) + int32(finalCltvDelta) + // Before we enter the loop below, we'll make sure to respect the max + // payment shard size (if it's set), which is effectively our + // client-side MTU that we'll attempt to respect at all times. + maxShardActive := p.payment.MaxShardAmt != nil + if maxShardActive && maxAmt > *p.payment.MaxShardAmt { + p.log.Debug("Clamping payment attempt from %v to %v due to "+ + "max shard size of %v", maxAmt, + *p.payment.MaxShardAmt, maxAmt) + + maxAmt = *p.payment.MaxShardAmt + } + for { // We'll also obtain a set of bandwidthHints from the lower // layer for each of our outbound channels. This will allow the @@ -279,7 +291,8 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, } if !p.payment.DestFeatures.HasFeature(lnwire.MPPOptional) { - p.log.Debug("not splitting because destination doesn't declare MPP") + p.log.Debug("not splitting because " + + "destination doesn't declare MPP") return nil, errNoPathFound }