From 2417f40532c239b6c4f8f2bae6adce389aa29995 Mon Sep 17 00:00:00 2001 From: "Johan T. Halseth" Date: Thu, 23 May 2019 20:05:26 +0200 Subject: [PATCH] channeldb: put payment status in new bucket We move the payment status to a new bucket hierarchy. Old buckets and fetch methods are kept around for migration purposes. --- channeldb/control_tower.go | 70 ++++++++++++++++++++++----------- channeldb/control_tower_test.go | 22 +++++++++-- channeldb/payments.go | 33 +++++----------- channeldb/payments_test.go | 48 ---------------------- 4 files changed, 76 insertions(+), 97 deletions(-) diff --git a/channeldb/control_tower.go b/channeldb/control_tower.go index 9fe93736..82e3c1ea 100644 --- a/channeldb/control_tower.go +++ b/channeldb/control_tower.go @@ -4,6 +4,7 @@ import ( "errors" "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" ) @@ -81,14 +82,14 @@ func NewPaymentControl(strict bool, db *DB) ControlTower { func (p *paymentControl) ClearForTakeoff(htlc *lnwire.UpdateAddHTLC) error { var takeoffErr error err := p.db.Batch(func(tx *bbolt.Tx) error { - // Retrieve current status of payment from local database. - paymentStatus, err := FetchPaymentStatusTx( - tx, htlc.PaymentHash, - ) + bucket, err := fetchPaymentBucket(tx, htlc.PaymentHash) if err != nil { return err } + // Get the existing status of this payment, if any. + paymentStatus := fetchPaymentStatus(bucket) + // Reset the takeoff error, to avoid carrying over an error // from a previous execution of the batched db transaction. takeoffErr = nil @@ -98,11 +99,9 @@ func (p *paymentControl) ClearForTakeoff(htlc *lnwire.UpdateAddHTLC) error { case StatusGrounded: // It is safe to reattempt a payment if we know that we // haven't left one in flight. Since this one is - // grounded, Transition the payment status to InFlight - // to prevent others. - return UpdatePaymentStatusTx( - tx, htlc.PaymentHash, StatusInFlight, - ) + // grounded or failed, transition the payment status + // to InFlight to prevent others. + return bucket.Put(paymentStatusKey, StatusInFlight.Bytes()) case StatusInFlight: // We already have an InFlight payment on the network. We will @@ -133,13 +132,14 @@ func (p *paymentControl) ClearForTakeoff(htlc *lnwire.UpdateAddHTLC) error { func (p *paymentControl) Success(paymentHash [32]byte) error { var updateErr error err := p.db.Batch(func(tx *bbolt.Tx) error { - paymentStatus, err := FetchPaymentStatusTx( - tx, paymentHash, - ) + bucket, err := fetchPaymentBucket(tx, paymentHash) if err != nil { return err } + // Get the existing status, if any. + paymentStatus := fetchPaymentStatus(bucket) + // Reset the update error, to avoid carrying over an error // from a previous execution of the batched db transaction. updateErr = nil @@ -162,9 +162,7 @@ func (p *paymentControl) Success(paymentHash [32]byte) error { // A successful response was received for an InFlight // payment, mark it as completed to prevent sending to // this payment hash again. - return UpdatePaymentStatusTx( - tx, paymentHash, StatusCompleted, - ) + return bucket.Put(paymentStatusKey, StatusCompleted.Bytes()) case paymentStatus == StatusCompleted: // The payment was completed previously, alert the @@ -190,13 +188,13 @@ func (p *paymentControl) Success(paymentHash [32]byte) error { func (p *paymentControl) Fail(paymentHash [32]byte) error { var updateErr error err := p.db.Batch(func(tx *bbolt.Tx) error { - paymentStatus, err := FetchPaymentStatusTx( - tx, paymentHash, - ) + bucket, err := fetchPaymentBucket(tx, paymentHash) if err != nil { return err } + paymentStatus := fetchPaymentStatus(bucket) + // Reset the update error, to avoid carrying over an error // from a previous execution of the batched db transaction. updateErr = nil @@ -217,11 +215,9 @@ func (p *paymentControl) Fail(paymentHash [32]byte) error { case paymentStatus == StatusInFlight: // A failed response was received for an InFlight - // payment, mark it as Grounded again to allow - // subsequent attempts. - return UpdatePaymentStatusTx( - tx, paymentHash, StatusGrounded, - ) + // payment, mark it as Failed to allow subsequent + // attempts. + return bucket.Put(paymentStatusKey, StatusGrounded.Bytes()) case paymentStatus == StatusCompleted: // The payment was completed previously, and we are now @@ -242,3 +238,31 @@ func (p *paymentControl) Fail(paymentHash [32]byte) error { return updateErr } + +// fetchPaymentBucket fetches or creates the sub-bucket assigned to this +// payment hash. +func fetchPaymentBucket(tx *bbolt.Tx, paymentHash lntypes.Hash) ( + *bbolt.Bucket, error) { + + payments, err := tx.CreateBucketIfNotExists(paymentsRootBucket) + if err != nil { + return nil, err + } + + return payments.CreateBucketIfNotExists(paymentHash[:]) +} + +// fetchPaymentStatus fetches the payment status from the bucket. If the +// status isn't found, it will default to "StatusGrounded". +func fetchPaymentStatus(bucket *bbolt.Bucket) PaymentStatus { + // The default status for all payments that aren't recorded in + // database. + var paymentStatus = StatusGrounded + + paymentStatusBytes := bucket.Get(paymentStatusKey) + if paymentStatusBytes != nil { + paymentStatus.FromBytes(paymentStatusBytes) + } + + return paymentStatus +} diff --git a/channeldb/control_tower_test.go b/channeldb/control_tower_test.go index c4d97711..42df8bf4 100644 --- a/channeldb/control_tower_test.go +++ b/channeldb/control_tower_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/btcsuite/fastsha256" + "github.com/coreos/bbolt" "github.com/lightningnetwork/lnd/lnwire" ) @@ -363,13 +364,28 @@ func assertPaymentStatus(t *testing.T, db *DB, t.Helper() - pStatus, err := db.FetchPaymentStatus(hash) + var paymentStatus = StatusGrounded + err := db.View(func(tx *bbolt.Tx) error { + payments := tx.Bucket(paymentsRootBucket) + if payments == nil { + return nil + } + + bucket := payments.Bucket(hash[:]) + if bucket == nil { + return nil + } + + // Get the existing status of this payment, if any. + paymentStatus = fetchPaymentStatus(bucket) + return nil + }) if err != nil { t.Fatalf("unable to fetch payment status: %v", err) } - if pStatus != expStatus { + if paymentStatus != expStatus { t.Fatalf("payment status mismatch: expected %v, got %v", - expStatus, pStatus) + expStatus, paymentStatus) } } diff --git a/channeldb/payments.go b/channeldb/payments.go index 08c9e02f..89997f45 100644 --- a/channeldb/payments.go +++ b/channeldb/payments.go @@ -12,6 +12,16 @@ import ( ) var ( + // paymentsRootBucket is the name of the top-level bucket within the + // database that stores all data related to payments. Within this + // bucket, each payment hash its own sub-bucket keyed by its payment + // hash. + paymentsRootBucket = []byte("payments-root-bucket") + + // paymentStatusKey is a key used in the payment's sub-bucket to store + // the status of the payment. + paymentStatusKey = []byte("payment-status-key") + // paymentBucket is the name of the bucket within the database that // stores all data related to payments. // @@ -188,29 +198,6 @@ func (db *DB) DeleteAllPayments() error { }) } -// UpdatePaymentStatus sets the payment status for outgoing/finished payments in -// local database. -func (db *DB) UpdatePaymentStatus(paymentHash [32]byte, status PaymentStatus) error { - return db.Batch(func(tx *bbolt.Tx) error { - return UpdatePaymentStatusTx(tx, paymentHash, status) - }) -} - -// UpdatePaymentStatusTx is a helper method that sets the payment status for -// outgoing/finished payments in the local database. This method accepts a -// boltdb transaction such that the operation can be composed into other -// database transactions. -func UpdatePaymentStatusTx(tx *bbolt.Tx, - paymentHash [32]byte, status PaymentStatus) error { - - paymentStatuses, err := tx.CreateBucketIfNotExists(paymentStatusBucket) - if err != nil { - return err - } - - return paymentStatuses.Put(paymentHash[:], status.Bytes()) -} - // FetchPaymentStatus returns the payment status for outgoing payment. // If status of the payment isn't found, it will default to "StatusGrounded". func (db *DB) FetchPaymentStatus(paymentHash [32]byte) (PaymentStatus, error) { diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go index 80a61be2..bab94e75 100644 --- a/channeldb/payments_test.go +++ b/channeldb/payments_test.go @@ -228,54 +228,6 @@ func TestOutgoingPaymentWorkflow(t *testing.T) { } } -func TestPaymentStatusWorkflow(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test db: %v", err) - } - - testCases := []struct { - paymentHash [32]byte - status PaymentStatus - }{ - { - paymentHash: makeFakePaymentHash(), - status: StatusGrounded, - }, - { - paymentHash: makeFakePaymentHash(), - status: StatusInFlight, - }, - { - paymentHash: makeFakePaymentHash(), - status: StatusCompleted, - }, - } - - for _, testCase := range testCases { - err := db.UpdatePaymentStatus(testCase.paymentHash, testCase.status) - if err != nil { - t.Fatalf("unable to put payment in DB: %v", err) - } - - status, err := db.FetchPaymentStatus(testCase.paymentHash) - if err != nil { - t.Fatalf("unable to fetch payments from DB: %v", err) - } - - if status != testCase.status { - t.Fatalf("Wrong payments status after reading from DB."+ - "Got %v, want %v", - spew.Sdump(status), - spew.Sdump(testCase.status), - ) - } - } -} - func TestRouteSerialization(t *testing.T) { t.Parallel()