diff --git a/routing/mock_test.go b/routing/mock_test.go index 6332bea8..64a1e4f6 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -10,7 +10,6 @@ import ( "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" - "github.com/lightningnetwork/lnd/zpay32" ) type mockPaymentAttemptDispatcher struct { @@ -78,8 +77,8 @@ type mockPaymentSessionSource struct { var _ PaymentSessionSource = (*mockPaymentSessionSource)(nil) -func (m *mockPaymentSessionSource) NewPaymentSession(routeHints [][]zpay32.HopHint, - target route.Vertex) (PaymentSession, error) { +func (m *mockPaymentSessionSource) NewPaymentSession( + _ *LightningPayment) (PaymentSession, error) { return &mockPaymentSession{m.routes}, nil } @@ -123,9 +122,7 @@ type mockPaymentSession struct { var _ PaymentSession = (*mockPaymentSession)(nil) -func (m *mockPaymentSession) RequestRoute(payment *LightningPayment, - height uint32, finalCltvDelta uint16) (*route.Route, error) { - +func (m *mockPaymentSession) RequestRoute(height uint32) (*route.Route, error) { if len(m.routes) == 0 { return nil, fmt.Errorf("no routes") } diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 96441d17..58bafc7d 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -29,15 +29,14 @@ func (e errNoRoute) Error() string { // paymentLifecycle holds all information about the current state of a payment // needed to resume if from any point. type paymentLifecycle struct { - router *ChannelRouter - payment *LightningPayment - paySession PaymentSession - timeoutChan <-chan time.Time - currentHeight int32 - finalCLTVDelta uint16 - attempt *channeldb.HTLCAttemptInfo - circuit *sphinx.Circuit - lastError error + router *ChannelRouter + payment *LightningPayment + paySession PaymentSession + timeoutChan <-chan time.Time + currentHeight int32 + attempt *channeldb.HTLCAttemptInfo + circuit *sphinx.Circuit + lastError error } // resumePayment resumes the paymentLifecycle from the current state. @@ -267,9 +266,7 @@ func (p *paymentLifecycle) createNewPaymentAttempt() (lnwire.ShortChannelID, } // Create a new payment attempt from the given payment session. - rt, err := p.paySession.RequestRoute( - p.payment, uint32(p.currentHeight), p.finalCLTVDelta, - ) + rt, err := p.paySession.RequestRoute(uint32(p.currentHeight)) if err != nil { log.Warnf("Failed to find route for payment %x: %v", p.payment.PaymentHash, err) diff --git a/routing/payment_session.go b/routing/payment_session.go index 47732e01..fda93536 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -24,8 +24,7 @@ var ( type PaymentSession interface { // RequestRoute returns the next route to attempt for routing the // specified HTLC payment to the target node. - RequestRoute(payment *LightningPayment, - height uint32, finalCltvDelta uint16) (*route.Route, error) + RequestRoute(height uint32) (*route.Route, error) } // paymentSession is used during an HTLC routings session to prune the local @@ -43,6 +42,8 @@ type paymentSession struct { sessionSource *SessionSource + payment *LightningPayment + preBuiltRoute *route.Route preBuiltRouteTried bool @@ -58,8 +59,7 @@ type paymentSession struct { // // NOTE: This function is safe for concurrent access. // NOTE: Part of the PaymentSession interface. -func (p *paymentSession) RequestRoute(payment *LightningPayment, - height uint32, finalCltvDelta uint16) (*route.Route, error) { +func (p *paymentSession) RequestRoute(height uint32) (*route.Route, error) { switch { @@ -77,12 +77,13 @@ func (p *paymentSession) RequestRoute(payment *LightningPayment, // Add BlockPadding to the finalCltvDelta so that the receiving node // does not reject the HTLC if some blocks are mined while it's in-flight. + finalCltvDelta := p.payment.FinalCLTVDelta finalCltvDelta += BlockPadding // We need to subtract the final delta before passing it into path // finding. The optimal path is independent of the final cltv delta and // the path finding algorithm is unaware of this value. - cltvLimit := payment.CltvLimit - uint32(finalCltvDelta) + cltvLimit := p.payment.CltvLimit - uint32(finalCltvDelta) // TODO(roasbeef): sync logic amongst dist sys @@ -93,13 +94,13 @@ func (p *paymentSession) RequestRoute(payment *LightningPayment, restrictions := &RestrictParams{ ProbabilitySource: ss.MissionControl.GetProbability, - FeeLimit: payment.FeeLimit, - OutgoingChannelID: payment.OutgoingChannelID, - LastHop: payment.LastHop, + FeeLimit: p.payment.FeeLimit, + OutgoingChannelID: p.payment.OutgoingChannelID, + LastHop: p.payment.LastHop, CltvLimit: cltvLimit, - DestCustomRecords: payment.DestCustomRecords, - DestFeatures: payment.DestFeatures, - PaymentAddr: payment.PaymentAddr, + DestCustomRecords: p.payment.DestCustomRecords, + DestFeatures: p.payment.DestFeatures, + PaymentAddr: p.payment.PaymentAddr, } // We'll also obtain a set of bandwidthHints from the lower layer for @@ -122,8 +123,8 @@ func (p *paymentSession) RequestRoute(payment *LightningPayment, bandwidthHints: bandwidthHints, }, restrictions, &ss.PathFindingConfig, - ss.SelfNode.PubKeyBytes, payment.Target, - payment.Amount, finalHtlcExpiry, + ss.SelfNode.PubKeyBytes, p.payment.Target, + p.payment.Amount, finalHtlcExpiry, ) if err != nil { return nil, err @@ -135,10 +136,10 @@ func (p *paymentSession) RequestRoute(payment *LightningPayment, route, err := newRoute( sourceVertex, path, height, finalHopParams{ - amt: payment.Amount, + amt: p.payment.Amount, cltvDelta: finalCltvDelta, - records: payment.DestCustomRecords, - paymentAddr: payment.PaymentAddr, + records: p.payment.DestCustomRecords, + paymentAddr: p.payment.PaymentAddr, }, ) if err != nil { diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index 05295b58..f3fd968e 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -47,10 +47,10 @@ type SessionSource struct { // view from Mission Control. An optional set of routing hints can be provided // in order to populate additional edges to explore when finding a path to the // payment's destination. -func (m *SessionSource) NewPaymentSession(routeHints [][]zpay32.HopHint, - target route.Vertex) (PaymentSession, error) { +func (m *SessionSource) NewPaymentSession(p *LightningPayment) ( + PaymentSession, error) { - edges, err := RouteHintsToEdges(routeHints, target) + edges, err := RouteHintsToEdges(p.RouteHints, p.Target) if err != nil { return nil, err } @@ -70,6 +70,7 @@ func (m *SessionSource) NewPaymentSession(routeHints [][]zpay32.HopHint, additionalEdges: edges, getBandwidthHints: getBandwidthHints, sessionSource: m, + payment: p, pathFinder: findPath, }, nil } diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index 55549c44..6d795b89 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -44,16 +44,6 @@ func TestRequestRoute(t *testing.T) { }, } - session := &paymentSession{ - getBandwidthHints: func() (map[uint64]lnwire.MilliSatoshi, - error) { - - return nil, nil - }, - sessionSource: sessionSource, - pathFinder: findPath, - } - cltvLimit := uint32(30) finalCltvDelta := uint16(8) @@ -62,7 +52,18 @@ func TestRequestRoute(t *testing.T) { FinalCLTVDelta: finalCltvDelta, } - route, err := session.RequestRoute(payment, height, finalCltvDelta) + session := &paymentSession{ + getBandwidthHints: func() (map[uint64]lnwire.MilliSatoshi, + error) { + + return nil, nil + }, + sessionSource: sessionSource, + payment: payment, + pathFinder: findPath, + } + + route, err := session.RequestRoute(height) if err != nil { t.Fatal(err) } diff --git a/routing/router.go b/routing/router.go index ca7a20f8..5ef4dcaa 100644 --- a/routing/router.go +++ b/routing/router.go @@ -159,8 +159,7 @@ type PaymentSessionSource interface { // routes to the given target. An optional set of routing hints can be // provided in order to populate additional edges to explore when // finding a path to the payment's destination. - NewPaymentSession(routeHints [][]zpay32.HopHint, - target route.Vertex) (PaymentSession, error) + NewPaymentSession(p *LightningPayment) (PaymentSession, error) // NewPaymentSessionForRoute creates a new paymentSession instance that // is just used for failure reporting to missioncontrol, and will only @@ -1677,9 +1676,7 @@ func (r *ChannelRouter) preparePayment(payment *LightningPayment) ( // Before starting the HTLC routing attempt, we'll create a fresh // payment session which will report our errors back to mission // control. - paySession, err := r.cfg.SessionSource.NewPaymentSession( - payment.RouteHints, payment.Target, - ) + paySession, err := r.cfg.SessionSource.NewPaymentSession(payment) if err != nil { return nil, err } @@ -1813,14 +1810,13 @@ func (r *ChannelRouter) sendPayment( // Now set up a paymentLifecycle struct with these params, such that we // can resume the payment from the current state. p := &paymentLifecycle{ - router: r, - payment: payment, - paySession: paySession, - currentHeight: currentHeight, - finalCLTVDelta: uint16(payment.FinalCLTVDelta), - attempt: existingAttempt, - circuit: nil, - lastError: nil, + router: r, + payment: payment, + paySession: paySession, + currentHeight: currentHeight, + attempt: existingAttempt, + circuit: nil, + lastError: nil, } // If a timeout is specified, create a timeout channel. If no timeout is