From c8d11285f3350cb93969d61a81046ca1df63e513 Mon Sep 17 00:00:00 2001 From: carla Date: Wed, 10 Jun 2020 12:34:27 +0200 Subject: [PATCH] channeldb: index payments by sequence number Add an entry to a payments index bucket which maps sequence number to payment hash when we initiate payments. This allows for more efficient paginated queries. We create the top level bucket in its own migration so that we do not need to create it on the fly. When we retry payments and provide them with a new sequence number, we delete the index for their existing payment so that we do not have an index that points to a non-existent payment. If we delete a payment, we also delete its index entry. This prevents us from looking up entries from indexes to payments that do not exist. --- channeldb/codec.go | 10 ++ channeldb/db.go | 15 ++ channeldb/log.go | 2 + channeldb/migration16/log.go | 14 ++ channeldb/migration16/migration.go | 191 ++++++++++++++++++++++++ channeldb/migration16/migration_test.go | 144 ++++++++++++++++++ channeldb/payment_control.go | 79 ++++++++++ channeldb/payment_control_test.go | 128 ++++++++++++++-- channeldb/payments.go | 70 ++++++++- channeldb/payments_test.go | 5 + 10 files changed, 647 insertions(+), 11 deletions(-) create mode 100644 channeldb/migration16/log.go create mode 100644 channeldb/migration16/migration.go create mode 100644 channeldb/migration16/migration_test.go diff --git a/channeldb/codec.go b/channeldb/codec.go index 78d61694..f6903175 100644 --- a/channeldb/codec.go +++ b/channeldb/codec.go @@ -192,6 +192,11 @@ func WriteElement(w io.Writer, element interface{}) error { return err } + case paymentIndexType: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + case lnwire.FundingFlag: if err := binary.Write(w, byteOrder, e); err != nil { return err @@ -406,6 +411,11 @@ func ReadElement(r io.Reader, element interface{}) error { return err } + case *paymentIndexType: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + case *lnwire.FundingFlag: if err := binary.Read(r, byteOrder, e); err != nil { return err diff --git a/channeldb/db.go b/channeldb/db.go index 1347a58f..564dbb7b 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -16,6 +16,7 @@ import ( mig "github.com/lightningnetwork/lnd/channeldb/migration" "github.com/lightningnetwork/lnd/channeldb/migration12" "github.com/lightningnetwork/lnd/channeldb/migration13" + "github.com/lightningnetwork/lnd/channeldb/migration16" "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lnwire" @@ -144,6 +145,19 @@ var ( number: 14, migration: mig.CreateTLB(payAddrIndexBucket), }, + { + // Initialize payment index bucket which will be used + // to index payments by sequence number. This index will + // be used to allow more efficient ListPayments queries. + number: 15, + migration: mig.CreateTLB(paymentsIndexBucket), + }, + { + // Add our existing payments to the index bucket created + // in migration 15. + number: 16, + migration: migration16.MigrateSequenceIndex, + }, } // Big endian is the preferred byte order, due to cursor scans over @@ -257,6 +271,7 @@ var topLevelBuckets = [][]byte{ fwdPackagesKey, invoiceBucket, payAddrIndexBucket, + paymentsIndexBucket, nodeInfoBucket, nodeBucket, edgeBucket, diff --git a/channeldb/log.go b/channeldb/log.go index f59426f0..75ba2a5f 100644 --- a/channeldb/log.go +++ b/channeldb/log.go @@ -6,6 +6,7 @@ import ( mig "github.com/lightningnetwork/lnd/channeldb/migration" "github.com/lightningnetwork/lnd/channeldb/migration12" "github.com/lightningnetwork/lnd/channeldb/migration13" + "github.com/lightningnetwork/lnd/channeldb/migration16" "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11" ) @@ -33,4 +34,5 @@ func UseLogger(logger btclog.Logger) { migration_01_to_11.UseLogger(logger) migration12.UseLogger(logger) migration13.UseLogger(logger) + migration16.UseLogger(logger) } diff --git a/channeldb/migration16/log.go b/channeldb/migration16/log.go new file mode 100644 index 00000000..cb946854 --- /dev/null +++ b/channeldb/migration16/log.go @@ -0,0 +1,14 @@ +package migration16 + +import ( + "github.com/btcsuite/btclog" +) + +// log is a logger that is initialized as disabled. This means the package will +// not perform any logging by default until a logger is set. +var log = btclog.Disabled + +// UseLogger uses a specified Logger to output package logging info. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/channeldb/migration16/migration.go b/channeldb/migration16/migration.go new file mode 100644 index 00000000..b984f083 --- /dev/null +++ b/channeldb/migration16/migration.go @@ -0,0 +1,191 @@ +package migration16 + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/channeldb/kvdb" +) + +var ( + paymentsRootBucket = []byte("payments-root-bucket") + + paymentSequenceKey = []byte("payment-sequence-key") + + duplicatePaymentsBucket = []byte("payment-duplicate-bucket") + + paymentsIndexBucket = []byte("payments-index-bucket") + + byteOrder = binary.BigEndian +) + +// paymentIndexType indicates the type of index we have recorded in the payment +// indexes bucket. +type paymentIndexType uint8 + +// paymentIndexTypeHash is a payment index type which indicates that we have +// created an index of payment sequence number to payment hash. +const paymentIndexTypeHash paymentIndexType = 0 + +// paymentIndex stores all the information we require to create an index by +// sequence number for a payment. +type paymentIndex struct { + // paymentHash is the hash of the payment, which is its key in the + // payment root bucket. + paymentHash []byte + + // sequenceNumbers is the set of sequence numbers associated with this + // payment hash. There will be more than one sequence number in the + // case where duplicate payments are present. + sequenceNumbers [][]byte +} + +// MigrateSequenceIndex migrates the payments db to contain a new bucket which +// provides an index from sequence number to payment hash. This is required +// for more efficient sequential lookup of payments, which are keyed by payment +// hash before this migration. +func MigrateSequenceIndex(tx kvdb.RwTx) error { + log.Infof("Migrating payments to add sequence number index") + + // Get a list of indices we need to write. + indexList, err := getPaymentIndexList(tx) + if err != nil { + return err + } + + // Create the top level bucket that we will use to index payments in. + bucket, err := tx.CreateTopLevelBucket(paymentsIndexBucket) + if err != nil { + return err + } + + // Write an index for each of our payments. + for _, index := range indexList { + // Write indexes for each of our sequence numbers. + for _, seqNr := range index.sequenceNumbers { + err := putIndex(bucket, seqNr, index.paymentHash) + if err != nil { + return err + } + } + } + + return nil +} + +// putIndex performs a sanity check that ensures we are not writing duplicate +// indexes to disk then creates the index provided. +func putIndex(bucket kvdb.RwBucket, sequenceNr, paymentHash []byte) error { + // Add a sanity check that we do not already have an entry with + // this sequence number. + existingEntry := bucket.Get(sequenceNr) + if existingEntry != nil { + return fmt.Errorf("sequence number: %x duplicated", + sequenceNr) + } + + bytes, err := serializePaymentIndexEntry(paymentHash) + if err != nil { + return err + } + + return bucket.Put(sequenceNr, bytes) +} + +// serializePaymentIndexEntry serializes a payment hash typed index. The value +// produced contains a payment index type (which can be used in future to +// signal different payment index types) and the payment hash. +func serializePaymentIndexEntry(hash []byte) ([]byte, error) { + var b bytes.Buffer + + err := binary.Write(&b, byteOrder, paymentIndexTypeHash) + if err != nil { + return nil, err + } + + if err := wire.WriteVarBytes(&b, 0, hash); err != nil { + return nil, err + } + + return b.Bytes(), nil +} + +// getPaymentIndexList gets a list of indices we need to write for our current +// set of payments. +func getPaymentIndexList(tx kvdb.RTx) ([]paymentIndex, error) { + // Iterate over all payments and store their indexing keys. This is + // needed, because no modifications are allowed inside a Bucket.ForEach + // loop. + paymentsBucket := tx.ReadBucket(paymentsRootBucket) + if paymentsBucket == nil { + return nil, nil + } + + var indexList []paymentIndex + err := paymentsBucket.ForEach(func(k, v []byte) error { + // Get the bucket which contains the payment, fail if the key + // does not have a bucket. + bucket := paymentsBucket.NestedReadBucket(k) + if bucket == nil { + return fmt.Errorf("non bucket element in " + + "payments bucket") + } + seqBytes := bucket.Get(paymentSequenceKey) + if seqBytes == nil { + return fmt.Errorf("nil sequence number bytes") + } + + seqNrs, err := fetchSequenceNumbers(bucket) + if err != nil { + return err + } + + // Create an index object with our payment hash and sequence + // numbers and append it to our set of indexes. + index := paymentIndex{ + paymentHash: k, + sequenceNumbers: seqNrs, + } + + indexList = append(indexList, index) + return nil + }) + if err != nil { + return nil, err + } + + return indexList, nil +} + +// fetchSequenceNumbers fetches all the sequence numbers associated with a +// payment, including those belonging to any duplicate payments. +func fetchSequenceNumbers(paymentBucket kvdb.RBucket) ([][]byte, error) { + seqNum := paymentBucket.Get(paymentSequenceKey) + if seqNum == nil { + return nil, errors.New("expected sequence number") + } + + sequenceNumbers := [][]byte{seqNum} + + // Get the duplicate payments bucket, if it has no duplicates, just + // return early with the payment sequence number. + duplicates := paymentBucket.NestedReadBucket(duplicatePaymentsBucket) + if duplicates == nil { + return sequenceNumbers, nil + } + + // If we do have duplicated, they are keyed by sequence number, so we + // iterate through the duplicates bucket and add them to our set of + // sequence numbers. + if err := duplicates.ForEach(func(k, v []byte) error { + sequenceNumbers = append(sequenceNumbers, k) + return nil + }); err != nil { + return nil, err + } + + return sequenceNumbers, nil +} diff --git a/channeldb/migration16/migration_test.go b/channeldb/migration16/migration_test.go new file mode 100644 index 00000000..626bedcb --- /dev/null +++ b/channeldb/migration16/migration_test.go @@ -0,0 +1,144 @@ +package migration16 + +import ( + "encoding/hex" + "testing" + + "github.com/lightningnetwork/lnd/channeldb/kvdb" + "github.com/lightningnetwork/lnd/channeldb/migtest" +) + +var ( + hexStr = migtest.Hex + + hash1Str = "02acee76ebd53d00824410cf6adecad4f50334dac702bd5a2d3ba01b91709f0e" + hash1 = hexStr(hash1Str) + paymentID1 = hexStr("0000000000000001") + + hash2Str = "62eb3f0a48f954e495d0c14ac63df04a67cefa59dafdbcd3d5046d1f5647840c" + hash2 = hexStr(hash2Str) + paymentID2 = hexStr("0000000000000002") + + paymentID3 = hexStr("0000000000000003") + + // pre is the data in the payments root bucket in database version 13 format. + pre = map[string]interface{}{ + // A payment without duplicates. + hash1: map[string]interface{}{ + "payment-sequence-key": paymentID1, + }, + + // A payment with a duplicate. + hash2: map[string]interface{}{ + "payment-sequence-key": paymentID2, + "payment-duplicate-bucket": map[string]interface{}{ + paymentID3: map[string]interface{}{ + "payment-sequence-key": paymentID3, + }, + }, + }, + } + + preFails = map[string]interface{}{ + // A payment without duplicates. + hash1: map[string]interface{}{ + "payment-sequence-key": paymentID1, + "payment-duplicate-bucket": map[string]interface{}{ + paymentID1: map[string]interface{}{ + "payment-sequence-key": paymentID1, + }, + }, + }, + } + + // post is the expected data after migration. + post = map[string]interface{}{ + paymentID1: paymentHashIndex(hash1Str), + paymentID2: paymentHashIndex(hash2Str), + paymentID3: paymentHashIndex(hash2Str), + } +) + +// paymentHashIndex produces a string that represents the value we expect for +// our payment indexes from a hex encoded payment hash string. +func paymentHashIndex(hashStr string) string { + hash, err := hex.DecodeString(hashStr) + if err != nil { + panic(err) + } + + bytes, err := serializePaymentIndexEntry(hash) + if err != nil { + panic(err) + } + + return string(bytes) +} + +// MigrateSequenceIndex asserts that the database is properly migrated to +// contain a payments index. +func TestMigrateSequenceIndex(t *testing.T) { + tests := []struct { + name string + shouldFail bool + pre map[string]interface{} + post map[string]interface{} + }{ + { + name: "migration ok", + shouldFail: false, + pre: pre, + post: post, + }, + { + name: "duplicate sequence number", + shouldFail: true, + pre: preFails, + post: post, + }, + { + name: "no payments", + shouldFail: false, + pre: nil, + post: nil, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + // Before the migration we have a payments bucket. + before := func(tx kvdb.RwTx) error { + return migtest.RestoreDB( + tx, paymentsRootBucket, test.pre, + ) + } + + // After the migration, we should have an untouched + // payments bucket and a new index bucket. + after := func(tx kvdb.RwTx) error { + if err := migtest.VerifyDB( + tx, paymentsRootBucket, test.pre, + ); err != nil { + return err + } + + // If we expect our migration to fail, we don't + // expect an index bucket. + if test.shouldFail { + return nil + } + + return migtest.VerifyDB( + tx, paymentsIndexBucket, test.post, + ) + } + + migtest.ApplyMigration( + t, before, after, MigrateSequenceIndex, + test.shouldFail, + ) + }) + } +} diff --git a/channeldb/payment_control.go b/channeldb/payment_control.go index 99d2000c..5a538134 100644 --- a/channeldb/payment_control.go +++ b/channeldb/payment_control.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "fmt" + "io" "github.com/lightningnetwork/lnd/channeldb/kvdb" "github.com/lightningnetwork/lnd/lntypes" @@ -74,6 +75,11 @@ var ( // errNoAttemptInfo is returned when no attempt info is stored yet. errNoAttemptInfo = errors.New("unable to find attempt info for " + "inflight payment") + + // errNoSequenceNrIndex is returned when an attempt to lookup a payment + // index is made for a sequence number that is not indexed. + errNoSequenceNrIndex = errors.New("payment sequence number index " + + "does not exist") ) // PaymentControl implements persistence for payments and payment attempts. @@ -152,6 +158,27 @@ func (p *PaymentControl) InitPayment(paymentHash lntypes.Hash, return err } + // Before we set our new sequence number, we check whether this + // payment has a previously set sequence number and remove its + // index entry if it exists. This happens in the case where we + // have a previously attempted payment which was left in a state + // where we can retry. + seqBytes := bucket.Get(paymentSequenceKey) + if seqBytes != nil { + indexBucket := tx.ReadWriteBucket(paymentsIndexBucket) + if err := indexBucket.Delete(seqBytes); err != nil { + return err + } + } + + // Once we have obtained a sequence number, we add an entry + // to our index bucket which will map the sequence number to + // our payment hash. + err = createPaymentIndexEntry(tx, sequenceNum, info.PaymentHash) + if err != nil { + return err + } + err = bucket.Put(paymentSequenceKey, sequenceNum) if err != nil { return err @@ -183,6 +210,58 @@ func (p *PaymentControl) InitPayment(paymentHash lntypes.Hash, return updateErr } +// paymentIndexTypeHash is a payment index type which indicates that we have +// created an index of payment sequence number to payment hash. +type paymentIndexType uint8 + +// paymentIndexTypeHash is a payment index type which indicates that we have +// created an index of payment sequence number to payment hash. +const paymentIndexTypeHash paymentIndexType = 0 + +// createPaymentIndexEntry creates a payment hash typed index for a payment. The +// index produced contains a payment index type (which can be used in future to +// signal different payment index types) and the payment hash. +func createPaymentIndexEntry(tx kvdb.RwTx, sequenceNumber []byte, + hash lntypes.Hash) error { + + var b bytes.Buffer + if err := WriteElements(&b, paymentIndexTypeHash, hash[:]); err != nil { + return err + } + + indexes := tx.ReadWriteBucket(paymentsIndexBucket) + return indexes.Put(sequenceNumber, b.Bytes()) +} + +// deserializePaymentIndex deserializes a payment index entry. This function +// currently only supports deserialization of payment hash indexes, and will +// fail for other types. +func deserializePaymentIndex(r io.Reader) (lntypes.Hash, error) { + var ( + indexType paymentIndexType + paymentHash []byte + ) + + if err := ReadElements(r, &indexType, &paymentHash); err != nil { + return lntypes.Hash{}, err + } + + // While we only have on payment index type, we do not need to use our + // index type to deserialize the index. However, we sanity check that + // this type is as expected, since we had to read it out anyway. + if indexType != paymentIndexTypeHash { + return lntypes.Hash{}, fmt.Errorf("unknown payment index "+ + "type: %v", indexType) + } + + hash, err := lntypes.MakeHash(paymentHash) + if err != nil { + return lntypes.Hash{}, err + } + + return hash, nil +} + // RegisterAttempt atomically records the provided HTLCAttemptInfo to the // DB. func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash, diff --git a/channeldb/payment_control_test.go b/channeldb/payment_control_test.go index c470a8f5..147e5452 100644 --- a/channeldb/payment_control_test.go +++ b/channeldb/payment_control_test.go @@ -1,6 +1,7 @@ package channeldb import ( + "bytes" "crypto/rand" "crypto/sha256" "fmt" @@ -9,9 +10,13 @@ import ( "testing" "time" + "github.com/btcsuite/btcwallet/walletdb" "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/channeldb/kvdb" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/record" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func genPreimage() ([32]byte, error) { @@ -70,6 +75,7 @@ func TestPaymentControlSwitchFail(t *testing.T) { t.Fatalf("unable to send htlc message: %v", err) } + assertPaymentIndex(t, pControl, info.PaymentHash) assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) assertPaymentInfo( t, pControl, info.PaymentHash, info, nil, nil, @@ -88,6 +94,11 @@ func TestPaymentControlSwitchFail(t *testing.T) { t, pControl, info.PaymentHash, info, &failReason, nil, ) + // Lookup the payment so we can get its old sequence number before it is + // overwritten. + payment, err := pControl.FetchPayment(info.PaymentHash) + assert.NoError(t, err) + // Sends the htlc again, which should succeed since the prior payment // failed. err = pControl.InitPayment(info.PaymentHash, info) @@ -95,6 +106,11 @@ func TestPaymentControlSwitchFail(t *testing.T) { t.Fatalf("unable to send htlc message: %v", err) } + // Check that our index has been updated, and the old index has been + // removed. + assertPaymentIndex(t, pControl, info.PaymentHash) + assertNoIndex(t, pControl, payment.SequenceNum) + assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) assertPaymentInfo( t, pControl, info.PaymentHash, info, nil, nil, @@ -145,7 +161,6 @@ func TestPaymentControlSwitchFail(t *testing.T) { // Settle the attempt and verify that status was changed to // StatusSucceeded. - var payment *MPPayment payment, err = pControl.SettleAttempt( info.PaymentHash, attempt.AttemptID, &HTLCSettleInfo{ @@ -209,6 +224,7 @@ func TestPaymentControlSwitchDoubleSend(t *testing.T) { t.Fatalf("unable to send htlc message: %v", err) } + assertPaymentIndex(t, pControl, info.PaymentHash) assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) assertPaymentInfo( t, pControl, info.PaymentHash, info, nil, nil, @@ -326,7 +342,7 @@ func TestPaymentControlFailsWithoutInFlight(t *testing.T) { assertPaymentStatus(t, pControl, info.PaymentHash, StatusUnknown) } -// TestPaymentControlDeleteNonInFlight checks that calling DeletaPayments only +// TestPaymentControlDeleteNonInFlight checks that calling DeletePayments only // deletes payments from the database that are not in-flight. func TestPaymentControlDeleteNonInFligt(t *testing.T) { t.Parallel() @@ -338,23 +354,37 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) { t.Fatalf("unable to init db: %v", err) } + // Create a sequence number for duplicate payments that will not collide + // with the sequence numbers for the payments we create. These values + // start at 1, so 9999 is a safe bet for this test. + var duplicateSeqNr = 9999 + pControl := NewPaymentControl(db) payments := []struct { - failed bool - success bool + failed bool + success bool + hasDuplicate bool }{ { - failed: true, - success: false, + failed: true, + success: false, + hasDuplicate: false, }, { - failed: false, - success: true, + failed: false, + success: true, + hasDuplicate: false, }, { - failed: false, - success: false, + failed: false, + success: false, + hasDuplicate: false, + }, + { + failed: false, + success: true, + hasDuplicate: true, }, } @@ -430,6 +460,16 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) { t, pControl, info.PaymentHash, info, nil, htlc, ) } + + // If the payment is intended to have a duplicate payment, we + // add one. + if p.hasDuplicate { + appendDuplicatePayment( + t, pControl.db, info.PaymentHash, + uint64(duplicateSeqNr), + ) + duplicateSeqNr++ + } } // Delete payments. @@ -451,6 +491,21 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) { if status != StatusInFlight { t.Fatalf("expected in-fligth status, got %v", status) } + + // Finally, check that we only have a single index left in the payment + // index bucket. + var indexCount int + err = kvdb.View(db, func(tx walletdb.ReadTx) error { + index := tx.ReadBucket(paymentsIndexBucket) + + return index.ForEach(func(k, v []byte) error { + indexCount++ + return nil + }) + }) + require.NoError(t, err) + + require.Equal(t, 1, indexCount) } // TestPaymentControlMultiShard checks the ability of payment control to @@ -495,6 +550,7 @@ func TestPaymentControlMultiShard(t *testing.T) { t.Fatalf("unable to send htlc message: %v", err) } + assertPaymentIndex(t, pControl, info.PaymentHash) assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) assertPaymentInfo( t, pControl, info.PaymentHash, info, nil, nil, @@ -910,3 +966,55 @@ func assertPaymentInfo(t *testing.T, p *PaymentControl, hash lntypes.Hash, t.Fatal("expected no settle info") } } + +// fetchPaymentIndexEntry gets the payment hash for the sequence number provided +// from our payment indexes bucket. +func fetchPaymentIndexEntry(_ *testing.T, p *PaymentControl, + sequenceNumber uint64) (*lntypes.Hash, error) { + + var hash lntypes.Hash + + if err := kvdb.View(p.db, func(tx walletdb.ReadTx) error { + indexBucket := tx.ReadBucket(paymentsIndexBucket) + key := make([]byte, 8) + byteOrder.PutUint64(key, sequenceNumber) + + indexValue := indexBucket.Get(key) + if indexValue == nil { + return errNoSequenceNrIndex + } + + r := bytes.NewReader(indexValue) + + var err error + hash, err = deserializePaymentIndex(r) + return err + + }); err != nil { + return nil, err + } + + return &hash, nil +} + +// assertPaymentIndex looks up the index for a payment in the db and checks +// that its payment hash matches the expected hash passed in. +func assertPaymentIndex(t *testing.T, p *PaymentControl, + expectedHash lntypes.Hash) { + + // Lookup the payment so that we have its sequence number and check + // that is has correctly been indexed in the payment indexes bucket. + pmt, err := p.FetchPayment(expectedHash) + require.NoError(t, err) + + hash, err := fetchPaymentIndexEntry(t, p, pmt.SequenceNum) + require.NoError(t, err) + assert.Equal(t, expectedHash, *hash) +} + +// assertNoIndex checks that an index for the sequence number provided does not +// exist. +func assertNoIndex(t *testing.T, p *PaymentControl, seqNr uint64) { + _, err := fetchPaymentIndexEntry(t, p, seqNr) + require.Equal(t, errNoSequenceNrIndex, err) +} diff --git a/channeldb/payments.go b/channeldb/payments.go index d1ec6070..3229c093 100644 --- a/channeldb/payments.go +++ b/channeldb/payments.go @@ -3,6 +3,7 @@ package channeldb import ( "bytes" "encoding/binary" + "errors" "fmt" "io" "math" @@ -92,6 +93,15 @@ var ( // paymentFailInfoKey is a key used in the payment's sub-bucket to // store information about the reason a payment failed. paymentFailInfoKey = []byte("payment-fail-info") + + // paymentsIndexBucket is the name of the top-level bucket within the + // database that stores an index of payment sequence numbers to its + // payment hash. + // payments-sequence-index-bucket + // |--: + // |--... + // |--: + paymentsIndexBucket = []byte("payments-index-bucket") ) // FailureReason encodes the reason a payment ultimately failed. @@ -566,7 +576,15 @@ func (db *DB) DeletePayments() error { return nil } - var deleteBuckets [][]byte + var ( + // deleteBuckets is the set of payment buckets we need + // to delete. + deleteBuckets [][]byte + + // deleteIndexes is the set of indexes pointing to these + // payments that need to be deleted. + deleteIndexes [][]byte + ) err := payments.ForEach(func(k, _ []byte) error { bucket := payments.NestedReadWriteBucket(k) if bucket == nil { @@ -589,7 +607,18 @@ func (db *DB) DeletePayments() error { return nil } + // Add the bucket to the set of buckets we can delete. deleteBuckets = append(deleteBuckets, k) + + // Get all the sequence number associated with the + // payment, including duplicates. + seqNrs, err := fetchSequenceNumbers(bucket) + if err != nil { + return err + } + + deleteIndexes = append(deleteIndexes, seqNrs...) + return nil }) if err != nil { @@ -602,10 +631,49 @@ func (db *DB) DeletePayments() error { } } + // Get our index bucket and delete all indexes pointing to the + // payments we are deleting. + indexBucket := tx.ReadWriteBucket(paymentsIndexBucket) + for _, k := range deleteIndexes { + if err := indexBucket.Delete(k); err != nil { + return err + } + } + return nil }) } +// fetchSequenceNumbers fetches all the sequence numbers associated with a +// payment, including those belonging to any duplicate payments. +func fetchSequenceNumbers(paymentBucket kvdb.RBucket) ([][]byte, error) { + seqNum := paymentBucket.Get(paymentSequenceKey) + if seqNum == nil { + return nil, errors.New("expected sequence number") + } + + sequenceNumbers := [][]byte{seqNum} + + // Get the duplicate payments bucket, if it has no duplicates, just + // return early with the payment sequence number. + duplicates := paymentBucket.NestedReadBucket(duplicatePaymentsBucket) + if duplicates == nil { + return sequenceNumbers, nil + } + + // If we do have duplicated, they are keyed by sequence number, so we + // iterate through the duplicates bucket and add them to our set of + // sequence numbers. + if err := duplicates.ForEach(func(k, v []byte) error { + sequenceNumbers = append(sequenceNumbers, k) + return nil + }); err != nil { + return nil, err + } + + return sequenceNumbers, nil +} + // nolint: dupl func serializePaymentCreationInfo(w io.Writer, c *PaymentCreationInfo) error { var scratch [8]byte diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go index 118de16d..706060a2 100644 --- a/channeldb/payments_test.go +++ b/channeldb/payments_test.go @@ -485,6 +485,11 @@ func appendDuplicatePayment(t *testing.T, db *DB, paymentHash lntypes.Hash, // sequence numbers we've setup. putDuplicatePayment(t, dup, sequenceKey[:], paymentHash) + // Finally, once we have created our entry we add an index for + // it. + err = createPaymentIndexEntry(tx, sequenceKey[:], paymentHash) + require.NoError(t, err) + return nil }) if err != nil {