diff --git a/channeldb/payment_control.go b/channeldb/payment_control.go index 13c09369..63be23a2 100644 --- a/channeldb/payment_control.go +++ b/channeldb/payment_control.go @@ -8,7 +8,6 @@ import ( "github.com/coreos/bbolt" "github.com/lightningnetwork/lnd/lntypes" - "github.com/lightningnetwork/lnd/routing/route" ) var ( @@ -192,16 +191,17 @@ func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash, // duplicate payments to the same payment hash. The provided preimage is // atomically saved to the DB for record keeping. func (p *PaymentControl) Success(paymentHash lntypes.Hash, - preimage lntypes.Preimage) (*route.Route, error) { + preimage lntypes.Preimage) (*MPPayment, error) { var ( updateErr error - route *route.Route + payment *MPPayment ) err := p.db.Batch(func(tx *bbolt.Tx) error { // Reset the update error, to avoid carrying over an error // from a previous execution of the batched db transaction. updateErr = nil + payment = nil bucket, err := fetchPaymentBucket(tx, paymentHash) if err == ErrPaymentNotInitiated { @@ -225,20 +225,14 @@ func (p *PaymentControl) Success(paymentHash lntypes.Hash, } // Retrieve attempt info for the notification. - attempt, err := fetchPaymentAttempt(bucket) - if err != nil { - return err - } - - route = &attempt.Route - - return nil + payment, err = fetchPayment(bucket) + return err }) if err != nil { return nil, err } - return route, updateErr + return payment, updateErr } // Fail transitions a payment into the Failed state, and records the reason the @@ -246,16 +240,17 @@ func (p *PaymentControl) Success(paymentHash lntypes.Hash, // its next call for this payment hash, allowing the switch to make a // subsequent payment. func (p *PaymentControl) Fail(paymentHash lntypes.Hash, - reason FailureReason) (*route.Route, error) { + reason FailureReason) (*MPPayment, error) { var ( updateErr error - route *route.Route + payment *MPPayment ) err := p.db.Batch(func(tx *bbolt.Tx) error { // Reset the update error, to avoid carrying over an error // from a previous execution of the batched db transaction. updateErr = nil + payment = nil bucket, err := fetchPaymentBucket(tx, paymentHash) if err == ErrPaymentNotInitiated { @@ -279,28 +274,21 @@ func (p *PaymentControl) Fail(paymentHash lntypes.Hash, } // Retrieve attempt info for the notification, if available. - attempt, err := fetchPaymentAttempt(bucket) - if err != nil && err != errNoAttemptInfo { - return err - } - if err != errNoAttemptInfo { - route = &attempt.Route - } - - return nil + payment, err = fetchPayment(bucket) + return err }) if err != nil { return nil, err } - return route, updateErr + return payment, updateErr } // FetchPayment returns information about a payment from the database. func (p *PaymentControl) FetchPayment(paymentHash lntypes.Hash) ( - *Payment, error) { + *MPPayment, error) { - var payment *Payment + var payment *MPPayment err := p.db.View(func(tx *bbolt.Tx) error { bucket, err := fetchPaymentBucket(tx, paymentHash) if err != nil { diff --git a/channeldb/payment_control_test.go b/channeldb/payment_control_test.go index 479e6898..7be04c6b 100644 --- a/channeldb/payment_control_test.go +++ b/channeldb/payment_control_test.go @@ -14,7 +14,6 @@ import ( "github.com/coreos/bbolt" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/lntypes" - "github.com/lightningnetwork/lnd/routing/route" ) func initDB() (*DB, error) { @@ -132,16 +131,22 @@ func TestPaymentControlSwitchFail(t *testing.T) { ) // Verifies that status was changed to StatusSucceeded. - var route *route.Route - route, err = pControl.Success(info.PaymentHash, preimg) + var payment *MPPayment + payment, err = pControl.Success(info.PaymentHash, preimg) if err != nil { t.Fatalf("error shouldn't have been received, got: %v", err) } - err = assertRouteEqual(route, &attempt.Route) + if len(payment.HTLCs) != 1 { + t.Fatalf("payment should have one htlc, got: %d", + len(payment.HTLCs)) + } + + err = assertRouteEqual(&payment.HTLCs[0].Route, &attempt.Route) if err != nil { t.Fatalf("unexpected route returned: %v vs %v: %v", - spew.Sdump(attempt.Route), spew.Sdump(*route), err) + spew.Sdump(attempt.Route), + spew.Sdump(payment.HTLCs[0].Route), err) } assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded) diff --git a/channeldb/payments.go b/channeldb/payments.go index c1ff79e9..7a541543 100644 --- a/channeldb/payments.go +++ b/channeldb/payments.go @@ -337,7 +337,7 @@ func (db *DB) FetchPayments() ([]*MPPayment, error) { return err } - payments = append(payments, p.ToMPPayment()) + payments = append(payments, p) // For older versions of lnd, duplicate payments to a // payment has was possible. These will be found in a @@ -362,7 +362,7 @@ func (db *DB) FetchPayments() ([]*MPPayment, error) { return err } - payments = append(payments, p.ToMPPayment()) + payments = append(payments, p) return nil }) }) @@ -379,7 +379,7 @@ func (db *DB) FetchPayments() ([]*MPPayment, error) { return payments, nil } -func fetchPayment(bucket *bbolt.Bucket) (*Payment, error) { +func fetchPayment(bucket *bbolt.Bucket) (*MPPayment, error) { var ( err error p = &Payment{} @@ -434,7 +434,7 @@ func fetchPayment(bucket *bbolt.Bucket) (*Payment, error) { p.Failure = &reason } - return p, nil + return p.ToMPPayment(), nil } // DeletePayments deletes all completed and failed payments from the DB. diff --git a/lnrpc/routerrpc/router_server.go b/lnrpc/routerrpc/router_server.go index 5cf84dbc..3432ecd8 100644 --- a/lnrpc/routerrpc/router_server.go +++ b/lnrpc/routerrpc/router_server.go @@ -584,19 +584,12 @@ func (s *Server) trackPayment(paymentHash lntypes.Hash, case result := <-resultChan: // Marshall result to rpc type. var status PaymentStatus - if result.Success { log.Debugf("Payment %v successfully completed", paymentHash) status.State = PaymentState_SUCCEEDED status.Preimage = result.Preimage[:] - status.Route, err = router.MarshallRoute( - result.Route, - ) - if err != nil { - return err - } } else { state, err := marshallFailureReason( result.FailureReason, @@ -605,15 +598,49 @@ func (s *Server) trackPayment(paymentHash lntypes.Hash, return err } status.State = state - if result.Route != nil { - status.Route, err = router.MarshallRoute( - result.Route, - ) - if err != nil { - return err - } + } + + // Extract the last route from the given list of HTLCs. This + // will populate the legacy route field for backwards + // compatibility. + // + // NOTE: For now there will be at most one HTLC, this code + // should be revisted or the field removed when multiple HTLCs + // are permitted. + var legacyRoute *route.Route + for _, htlc := range result.HTLCs { + switch { + case htlc.Settle != nil: + legacyRoute = &htlc.Route + + // Only display the route for failed payments if we got + // an incorrect payment details error, so that it can be + // used for probing or fee estimation. + case htlc.Failure != nil && result.FailureReason == + channeldb.FailureReasonPaymentDetails: + + legacyRoute = &htlc.Route } } + if legacyRoute != nil { + status.Route, err = router.MarshallRoute(legacyRoute) + if err != nil { + return err + } + } + + // Marshal our list of HTLCs that have been tried for this + // payment. + htlcs := make([]*lnrpc.HTLCAttempt, 0, len(result.HTLCs)) + for _, dbHtlc := range result.HTLCs { + htlc, err := router.MarshalHTLCAttempt(dbHtlc) + if err != nil { + return err + } + + htlcs = append(htlcs, htlc) + } + status.Htlcs = htlcs // Send event to the client. err = stream.Send(&status) diff --git a/routing/control_tower.go b/routing/control_tower.go index ee8d75c8..4cf6839f 100644 --- a/routing/control_tower.go +++ b/routing/control_tower.go @@ -6,7 +6,6 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lntypes" - "github.com/lightningnetwork/lnd/routing/route" ) // ControlTower tracks all outgoing payments made, whose primary purpose is to @@ -52,10 +51,6 @@ type PaymentResult struct { // Success indicates whether the payment was successful. Success bool - // Route is the (last) route attempted to send the HTLC. It is only set - // for successful payments. - Route *route.Route - // Preimage is the preimage of a successful payment. This serves as a // proof of payment. It is only set for successful payments. Preimage lntypes.Preimage @@ -63,6 +58,10 @@ type PaymentResult struct { // FailureReason is a failure reason code indicating the reason the // payment failed. It is only set for failed payments. FailureReason channeldb.FailureReason + + // HTLCs is a list of HTLCs that have been attempted in order to settle + // the payment. + HTLCs []channeldb.HTLCAttempt } // controlTower is persistent implementation of ControlTower to restrict @@ -107,46 +106,46 @@ func (p *controlTower) RegisterAttempt(paymentHash lntypes.Hash, func (p *controlTower) Success(paymentHash lntypes.Hash, preimage lntypes.Preimage) error { - route, err := p.db.Success(paymentHash, preimage) + payment, err := p.db.Success(paymentHash, preimage) if err != nil { return err } // Notify subscribers of success event. p.notifyFinalEvent( - paymentHash, createSuccessResult(route, preimage), + paymentHash, createSuccessResult(payment.HTLCs), ) return nil } // createSuccessResult creates a success result to send to subscribers. -func createSuccessResult(rt *route.Route, - preimage lntypes.Preimage) *PaymentResult { +func createSuccessResult(htlcs []channeldb.HTLCAttempt) *PaymentResult { + // Extract any preimage from the list of HTLCs. + var preimage lntypes.Preimage + for _, htlc := range htlcs { + if htlc.Settle != nil { + preimage = htlc.Settle.Preimage + break + } + } return &PaymentResult{ Success: true, Preimage: preimage, - Route: rt, + HTLCs: htlcs, } } // createFailResult creates a failed result to send to subscribers. -func createFailedResult(rt *route.Route, +func createFailedResult(htlcs []channeldb.HTLCAttempt, reason channeldb.FailureReason) *PaymentResult { - result := &PaymentResult{ + return &PaymentResult{ Success: false, FailureReason: reason, + HTLCs: htlcs, } - - // In case of incorrect payment details, set the route. This can be used - // for probing and to extract a fee estimate from the route. - if reason == channeldb.FailureReasonPaymentDetails { - result.Route = rt - } - - return result } // Fail transitions a payment into the Failed state, and records the reason the @@ -156,14 +155,16 @@ func createFailedResult(rt *route.Route, func (p *controlTower) Fail(paymentHash lntypes.Hash, reason channeldb.FailureReason) error { - route, err := p.db.Fail(paymentHash, reason) + payment, err := p.db.Fail(paymentHash, reason) if err != nil { return err } // Notify subscribers of fail event. p.notifyFinalEvent( - paymentHash, createFailedResult(route, reason), + paymentHash, createFailedResult( + payment.HTLCs, reason, + ), ) return nil @@ -213,20 +214,14 @@ func (p *controlTower) SubscribePayment(paymentHash lntypes.Hash) ( // a subscriber, because we can send the result on the channel // immediately. case channeldb.StatusSucceeded: - event = *createSuccessResult( - &payment.Attempt.Route, *payment.Preimage, - ) + event = *createSuccessResult(payment.HTLCs) // Payment already failed. It is not necessary to register as a // subscriber, because we can send the result on the channel // immediately. case channeldb.StatusFailed: - var route *route.Route - if payment.Attempt != nil { - route = &payment.Attempt.Route - } event = *createFailedResult( - route, *payment.Failure, + payment.HTLCs, *payment.FailureReason, ) default: diff --git a/routing/control_tower_test.go b/routing/control_tower_test.go index 49cc6d43..e326e98c 100644 --- a/routing/control_tower_test.go +++ b/routing/control_tower_test.go @@ -144,10 +144,13 @@ func TestControlTowerSubscribeSuccess(t *testing.T) { if result.Preimage != preimg { t.Fatal("unexpected preimage") } - - if !reflect.DeepEqual(result.Route, &attempt.Route) { - t.Fatalf("unexpected route: %v vs %v", - spew.Sdump(result.Route), + if len(result.HTLCs) != 1 { + t.Fatalf("expected one htlc, got %d", len(result.HTLCs)) + } + htlc := result.HTLCs[0] + if !reflect.DeepEqual(htlc.Route, attempt.Route) { + t.Fatalf("unexpected htlc route: %v vs %v", + spew.Sdump(htlc.Route), spew.Sdump(attempt.Route)) } @@ -168,6 +171,15 @@ func TestControlTowerSubscribeSuccess(t *testing.T) { func TestPaymentControlSubscribeFail(t *testing.T) { t.Parallel() + t.Run("register attempt", func(t *testing.T) { + testPaymentControlSubscribeFail(t, true) + }) + t.Run("no register attempt", func(t *testing.T) { + testPaymentControlSubscribeFail(t, false) + }) +} + +func testPaymentControlSubscribeFail(t *testing.T, registerAttempt bool) { db, err := initDB() if err != nil { t.Fatalf("unable to init db: %v", err) @@ -176,7 +188,7 @@ func TestPaymentControlSubscribeFail(t *testing.T) { pControl := NewControlTower(channeldb.NewPaymentControl(db)) // Initiate a payment. - info, _, _, err := genInfo() + info, attempt, _, err := genInfo() if err != nil { t.Fatal(err) } @@ -192,6 +204,17 @@ func TestPaymentControlSubscribeFail(t *testing.T) { t.Fatalf("expected subscribe to succeed, but got: %v", err) } + // Conditionally register the attempt based on the test type. This + // allows us to simulate failing after attempting with an htlc or before + // making any attempts at all. + if registerAttempt { + // Register an attempt. + err = pControl.RegisterAttempt(info.PaymentHash, attempt) + if err != nil { + t.Fatal(err) + } + } + // Mark the payment as failed. if err := pControl.Fail(info.PaymentHash, channeldb.FailureReasonTimeout); err != nil { t.Fatal(err) @@ -223,9 +246,28 @@ func TestPaymentControlSubscribeFail(t *testing.T) { if result.Success { t.Fatal("unexpected payment state") } - if result.Route != nil { - t.Fatal("expected no route") + + // There will either be one or zero htlcs depending on whether + // or not the attempt was registered. Assert the correct number + // is present, and the route taken if the attempt was + // registered. + if registerAttempt { + if len(result.HTLCs) != 1 { + t.Fatalf("expected 1 htlc, got: %d", + len(result.HTLCs)) + } + + htlc := result.HTLCs[0] + if !reflect.DeepEqual(htlc.Route, testRoute) { + t.Fatalf("unexpected htlc route: %v vs %v", + spew.Sdump(htlc.Route), + spew.Sdump(testRoute)) + } + } else if len(result.HTLCs) != 0 { + t.Fatalf("expected 0 htlcs, got: %d", + len(result.HTLCs)) } + if result.FailureReason != channeldb.FailureReasonTimeout { t.Fatal("unexpected failure reason") }