diff --git a/channeldb/codec.go b/channeldb/codec.go index 78d61694..f6903175 100644 --- a/channeldb/codec.go +++ b/channeldb/codec.go @@ -192,6 +192,11 @@ func WriteElement(w io.Writer, element interface{}) error { return err } + case paymentIndexType: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + case lnwire.FundingFlag: if err := binary.Write(w, byteOrder, e); err != nil { return err @@ -406,6 +411,11 @@ func ReadElement(r io.Reader, element interface{}) error { return err } + case *paymentIndexType: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + case *lnwire.FundingFlag: if err := binary.Read(r, byteOrder, e); err != nil { return err diff --git a/channeldb/db.go b/channeldb/db.go index 1347a58f..564dbb7b 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -16,6 +16,7 @@ import ( mig "github.com/lightningnetwork/lnd/channeldb/migration" "github.com/lightningnetwork/lnd/channeldb/migration12" "github.com/lightningnetwork/lnd/channeldb/migration13" + "github.com/lightningnetwork/lnd/channeldb/migration16" "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lnwire" @@ -144,6 +145,19 @@ var ( number: 14, migration: mig.CreateTLB(payAddrIndexBucket), }, + { + // Initialize payment index bucket which will be used + // to index payments by sequence number. This index will + // be used to allow more efficient ListPayments queries. + number: 15, + migration: mig.CreateTLB(paymentsIndexBucket), + }, + { + // Add our existing payments to the index bucket created + // in migration 15. + number: 16, + migration: migration16.MigrateSequenceIndex, + }, } // Big endian is the preferred byte order, due to cursor scans over @@ -257,6 +271,7 @@ var topLevelBuckets = [][]byte{ fwdPackagesKey, invoiceBucket, payAddrIndexBucket, + paymentsIndexBucket, nodeInfoBucket, nodeBucket, edgeBucket, diff --git a/channeldb/log.go b/channeldb/log.go index f59426f0..75ba2a5f 100644 --- a/channeldb/log.go +++ b/channeldb/log.go @@ -6,6 +6,7 @@ import ( mig "github.com/lightningnetwork/lnd/channeldb/migration" "github.com/lightningnetwork/lnd/channeldb/migration12" "github.com/lightningnetwork/lnd/channeldb/migration13" + "github.com/lightningnetwork/lnd/channeldb/migration16" "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11" ) @@ -33,4 +34,5 @@ func UseLogger(logger btclog.Logger) { migration_01_to_11.UseLogger(logger) migration12.UseLogger(logger) migration13.UseLogger(logger) + migration16.UseLogger(logger) } diff --git a/channeldb/migration16/log.go b/channeldb/migration16/log.go new file mode 100644 index 00000000..cb946854 --- /dev/null +++ b/channeldb/migration16/log.go @@ -0,0 +1,14 @@ +package migration16 + +import ( + "github.com/btcsuite/btclog" +) + +// log is a logger that is initialized as disabled. This means the package will +// not perform any logging by default until a logger is set. +var log = btclog.Disabled + +// UseLogger uses a specified Logger to output package logging info. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/channeldb/migration16/migration.go b/channeldb/migration16/migration.go new file mode 100644 index 00000000..b984f083 --- /dev/null +++ b/channeldb/migration16/migration.go @@ -0,0 +1,191 @@ +package migration16 + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/channeldb/kvdb" +) + +var ( + paymentsRootBucket = []byte("payments-root-bucket") + + paymentSequenceKey = []byte("payment-sequence-key") + + duplicatePaymentsBucket = []byte("payment-duplicate-bucket") + + paymentsIndexBucket = []byte("payments-index-bucket") + + byteOrder = binary.BigEndian +) + +// paymentIndexType indicates the type of index we have recorded in the payment +// indexes bucket. +type paymentIndexType uint8 + +// paymentIndexTypeHash is a payment index type which indicates that we have +// created an index of payment sequence number to payment hash. +const paymentIndexTypeHash paymentIndexType = 0 + +// paymentIndex stores all the information we require to create an index by +// sequence number for a payment. +type paymentIndex struct { + // paymentHash is the hash of the payment, which is its key in the + // payment root bucket. + paymentHash []byte + + // sequenceNumbers is the set of sequence numbers associated with this + // payment hash. There will be more than one sequence number in the + // case where duplicate payments are present. + sequenceNumbers [][]byte +} + +// MigrateSequenceIndex migrates the payments db to contain a new bucket which +// provides an index from sequence number to payment hash. This is required +// for more efficient sequential lookup of payments, which are keyed by payment +// hash before this migration. +func MigrateSequenceIndex(tx kvdb.RwTx) error { + log.Infof("Migrating payments to add sequence number index") + + // Get a list of indices we need to write. + indexList, err := getPaymentIndexList(tx) + if err != nil { + return err + } + + // Create the top level bucket that we will use to index payments in. + bucket, err := tx.CreateTopLevelBucket(paymentsIndexBucket) + if err != nil { + return err + } + + // Write an index for each of our payments. + for _, index := range indexList { + // Write indexes for each of our sequence numbers. + for _, seqNr := range index.sequenceNumbers { + err := putIndex(bucket, seqNr, index.paymentHash) + if err != nil { + return err + } + } + } + + return nil +} + +// putIndex performs a sanity check that ensures we are not writing duplicate +// indexes to disk then creates the index provided. +func putIndex(bucket kvdb.RwBucket, sequenceNr, paymentHash []byte) error { + // Add a sanity check that we do not already have an entry with + // this sequence number. + existingEntry := bucket.Get(sequenceNr) + if existingEntry != nil { + return fmt.Errorf("sequence number: %x duplicated", + sequenceNr) + } + + bytes, err := serializePaymentIndexEntry(paymentHash) + if err != nil { + return err + } + + return bucket.Put(sequenceNr, bytes) +} + +// serializePaymentIndexEntry serializes a payment hash typed index. The value +// produced contains a payment index type (which can be used in future to +// signal different payment index types) and the payment hash. +func serializePaymentIndexEntry(hash []byte) ([]byte, error) { + var b bytes.Buffer + + err := binary.Write(&b, byteOrder, paymentIndexTypeHash) + if err != nil { + return nil, err + } + + if err := wire.WriteVarBytes(&b, 0, hash); err != nil { + return nil, err + } + + return b.Bytes(), nil +} + +// getPaymentIndexList gets a list of indices we need to write for our current +// set of payments. +func getPaymentIndexList(tx kvdb.RTx) ([]paymentIndex, error) { + // Iterate over all payments and store their indexing keys. This is + // needed, because no modifications are allowed inside a Bucket.ForEach + // loop. + paymentsBucket := tx.ReadBucket(paymentsRootBucket) + if paymentsBucket == nil { + return nil, nil + } + + var indexList []paymentIndex + err := paymentsBucket.ForEach(func(k, v []byte) error { + // Get the bucket which contains the payment, fail if the key + // does not have a bucket. + bucket := paymentsBucket.NestedReadBucket(k) + if bucket == nil { + return fmt.Errorf("non bucket element in " + + "payments bucket") + } + seqBytes := bucket.Get(paymentSequenceKey) + if seqBytes == nil { + return fmt.Errorf("nil sequence number bytes") + } + + seqNrs, err := fetchSequenceNumbers(bucket) + if err != nil { + return err + } + + // Create an index object with our payment hash and sequence + // numbers and append it to our set of indexes. + index := paymentIndex{ + paymentHash: k, + sequenceNumbers: seqNrs, + } + + indexList = append(indexList, index) + return nil + }) + if err != nil { + return nil, err + } + + return indexList, nil +} + +// fetchSequenceNumbers fetches all the sequence numbers associated with a +// payment, including those belonging to any duplicate payments. +func fetchSequenceNumbers(paymentBucket kvdb.RBucket) ([][]byte, error) { + seqNum := paymentBucket.Get(paymentSequenceKey) + if seqNum == nil { + return nil, errors.New("expected sequence number") + } + + sequenceNumbers := [][]byte{seqNum} + + // Get the duplicate payments bucket, if it has no duplicates, just + // return early with the payment sequence number. + duplicates := paymentBucket.NestedReadBucket(duplicatePaymentsBucket) + if duplicates == nil { + return sequenceNumbers, nil + } + + // If we do have duplicated, they are keyed by sequence number, so we + // iterate through the duplicates bucket and add them to our set of + // sequence numbers. + if err := duplicates.ForEach(func(k, v []byte) error { + sequenceNumbers = append(sequenceNumbers, k) + return nil + }); err != nil { + return nil, err + } + + return sequenceNumbers, nil +} diff --git a/channeldb/migration16/migration_test.go b/channeldb/migration16/migration_test.go new file mode 100644 index 00000000..626bedcb --- /dev/null +++ b/channeldb/migration16/migration_test.go @@ -0,0 +1,144 @@ +package migration16 + +import ( + "encoding/hex" + "testing" + + "github.com/lightningnetwork/lnd/channeldb/kvdb" + "github.com/lightningnetwork/lnd/channeldb/migtest" +) + +var ( + hexStr = migtest.Hex + + hash1Str = "02acee76ebd53d00824410cf6adecad4f50334dac702bd5a2d3ba01b91709f0e" + hash1 = hexStr(hash1Str) + paymentID1 = hexStr("0000000000000001") + + hash2Str = "62eb3f0a48f954e495d0c14ac63df04a67cefa59dafdbcd3d5046d1f5647840c" + hash2 = hexStr(hash2Str) + paymentID2 = hexStr("0000000000000002") + + paymentID3 = hexStr("0000000000000003") + + // pre is the data in the payments root bucket in database version 13 format. + pre = map[string]interface{}{ + // A payment without duplicates. + hash1: map[string]interface{}{ + "payment-sequence-key": paymentID1, + }, + + // A payment with a duplicate. + hash2: map[string]interface{}{ + "payment-sequence-key": paymentID2, + "payment-duplicate-bucket": map[string]interface{}{ + paymentID3: map[string]interface{}{ + "payment-sequence-key": paymentID3, + }, + }, + }, + } + + preFails = map[string]interface{}{ + // A payment without duplicates. + hash1: map[string]interface{}{ + "payment-sequence-key": paymentID1, + "payment-duplicate-bucket": map[string]interface{}{ + paymentID1: map[string]interface{}{ + "payment-sequence-key": paymentID1, + }, + }, + }, + } + + // post is the expected data after migration. + post = map[string]interface{}{ + paymentID1: paymentHashIndex(hash1Str), + paymentID2: paymentHashIndex(hash2Str), + paymentID3: paymentHashIndex(hash2Str), + } +) + +// paymentHashIndex produces a string that represents the value we expect for +// our payment indexes from a hex encoded payment hash string. +func paymentHashIndex(hashStr string) string { + hash, err := hex.DecodeString(hashStr) + if err != nil { + panic(err) + } + + bytes, err := serializePaymentIndexEntry(hash) + if err != nil { + panic(err) + } + + return string(bytes) +} + +// MigrateSequenceIndex asserts that the database is properly migrated to +// contain a payments index. +func TestMigrateSequenceIndex(t *testing.T) { + tests := []struct { + name string + shouldFail bool + pre map[string]interface{} + post map[string]interface{} + }{ + { + name: "migration ok", + shouldFail: false, + pre: pre, + post: post, + }, + { + name: "duplicate sequence number", + shouldFail: true, + pre: preFails, + post: post, + }, + { + name: "no payments", + shouldFail: false, + pre: nil, + post: nil, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + // Before the migration we have a payments bucket. + before := func(tx kvdb.RwTx) error { + return migtest.RestoreDB( + tx, paymentsRootBucket, test.pre, + ) + } + + // After the migration, we should have an untouched + // payments bucket and a new index bucket. + after := func(tx kvdb.RwTx) error { + if err := migtest.VerifyDB( + tx, paymentsRootBucket, test.pre, + ); err != nil { + return err + } + + // If we expect our migration to fail, we don't + // expect an index bucket. + if test.shouldFail { + return nil + } + + return migtest.VerifyDB( + tx, paymentsIndexBucket, test.post, + ) + } + + migtest.ApplyMigration( + t, before, after, MigrateSequenceIndex, + test.shouldFail, + ) + }) + } +} diff --git a/channeldb/payment_control.go b/channeldb/payment_control.go index 99d2000c..5a538134 100644 --- a/channeldb/payment_control.go +++ b/channeldb/payment_control.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "fmt" + "io" "github.com/lightningnetwork/lnd/channeldb/kvdb" "github.com/lightningnetwork/lnd/lntypes" @@ -74,6 +75,11 @@ var ( // errNoAttemptInfo is returned when no attempt info is stored yet. errNoAttemptInfo = errors.New("unable to find attempt info for " + "inflight payment") + + // errNoSequenceNrIndex is returned when an attempt to lookup a payment + // index is made for a sequence number that is not indexed. + errNoSequenceNrIndex = errors.New("payment sequence number index " + + "does not exist") ) // PaymentControl implements persistence for payments and payment attempts. @@ -152,6 +158,27 @@ func (p *PaymentControl) InitPayment(paymentHash lntypes.Hash, return err } + // Before we set our new sequence number, we check whether this + // payment has a previously set sequence number and remove its + // index entry if it exists. This happens in the case where we + // have a previously attempted payment which was left in a state + // where we can retry. + seqBytes := bucket.Get(paymentSequenceKey) + if seqBytes != nil { + indexBucket := tx.ReadWriteBucket(paymentsIndexBucket) + if err := indexBucket.Delete(seqBytes); err != nil { + return err + } + } + + // Once we have obtained a sequence number, we add an entry + // to our index bucket which will map the sequence number to + // our payment hash. + err = createPaymentIndexEntry(tx, sequenceNum, info.PaymentHash) + if err != nil { + return err + } + err = bucket.Put(paymentSequenceKey, sequenceNum) if err != nil { return err @@ -183,6 +210,58 @@ func (p *PaymentControl) InitPayment(paymentHash lntypes.Hash, return updateErr } +// paymentIndexTypeHash is a payment index type which indicates that we have +// created an index of payment sequence number to payment hash. +type paymentIndexType uint8 + +// paymentIndexTypeHash is a payment index type which indicates that we have +// created an index of payment sequence number to payment hash. +const paymentIndexTypeHash paymentIndexType = 0 + +// createPaymentIndexEntry creates a payment hash typed index for a payment. The +// index produced contains a payment index type (which can be used in future to +// signal different payment index types) and the payment hash. +func createPaymentIndexEntry(tx kvdb.RwTx, sequenceNumber []byte, + hash lntypes.Hash) error { + + var b bytes.Buffer + if err := WriteElements(&b, paymentIndexTypeHash, hash[:]); err != nil { + return err + } + + indexes := tx.ReadWriteBucket(paymentsIndexBucket) + return indexes.Put(sequenceNumber, b.Bytes()) +} + +// deserializePaymentIndex deserializes a payment index entry. This function +// currently only supports deserialization of payment hash indexes, and will +// fail for other types. +func deserializePaymentIndex(r io.Reader) (lntypes.Hash, error) { + var ( + indexType paymentIndexType + paymentHash []byte + ) + + if err := ReadElements(r, &indexType, &paymentHash); err != nil { + return lntypes.Hash{}, err + } + + // While we only have on payment index type, we do not need to use our + // index type to deserialize the index. However, we sanity check that + // this type is as expected, since we had to read it out anyway. + if indexType != paymentIndexTypeHash { + return lntypes.Hash{}, fmt.Errorf("unknown payment index "+ + "type: %v", indexType) + } + + hash, err := lntypes.MakeHash(paymentHash) + if err != nil { + return lntypes.Hash{}, err + } + + return hash, nil +} + // RegisterAttempt atomically records the provided HTLCAttemptInfo to the // DB. func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash, diff --git a/channeldb/payment_control_test.go b/channeldb/payment_control_test.go index c470a8f5..147e5452 100644 --- a/channeldb/payment_control_test.go +++ b/channeldb/payment_control_test.go @@ -1,6 +1,7 @@ package channeldb import ( + "bytes" "crypto/rand" "crypto/sha256" "fmt" @@ -9,9 +10,13 @@ import ( "testing" "time" + "github.com/btcsuite/btcwallet/walletdb" "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/channeldb/kvdb" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/record" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func genPreimage() ([32]byte, error) { @@ -70,6 +75,7 @@ func TestPaymentControlSwitchFail(t *testing.T) { t.Fatalf("unable to send htlc message: %v", err) } + assertPaymentIndex(t, pControl, info.PaymentHash) assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) assertPaymentInfo( t, pControl, info.PaymentHash, info, nil, nil, @@ -88,6 +94,11 @@ func TestPaymentControlSwitchFail(t *testing.T) { t, pControl, info.PaymentHash, info, &failReason, nil, ) + // Lookup the payment so we can get its old sequence number before it is + // overwritten. + payment, err := pControl.FetchPayment(info.PaymentHash) + assert.NoError(t, err) + // Sends the htlc again, which should succeed since the prior payment // failed. err = pControl.InitPayment(info.PaymentHash, info) @@ -95,6 +106,11 @@ func TestPaymentControlSwitchFail(t *testing.T) { t.Fatalf("unable to send htlc message: %v", err) } + // Check that our index has been updated, and the old index has been + // removed. + assertPaymentIndex(t, pControl, info.PaymentHash) + assertNoIndex(t, pControl, payment.SequenceNum) + assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) assertPaymentInfo( t, pControl, info.PaymentHash, info, nil, nil, @@ -145,7 +161,6 @@ func TestPaymentControlSwitchFail(t *testing.T) { // Settle the attempt and verify that status was changed to // StatusSucceeded. - var payment *MPPayment payment, err = pControl.SettleAttempt( info.PaymentHash, attempt.AttemptID, &HTLCSettleInfo{ @@ -209,6 +224,7 @@ func TestPaymentControlSwitchDoubleSend(t *testing.T) { t.Fatalf("unable to send htlc message: %v", err) } + assertPaymentIndex(t, pControl, info.PaymentHash) assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) assertPaymentInfo( t, pControl, info.PaymentHash, info, nil, nil, @@ -326,7 +342,7 @@ func TestPaymentControlFailsWithoutInFlight(t *testing.T) { assertPaymentStatus(t, pControl, info.PaymentHash, StatusUnknown) } -// TestPaymentControlDeleteNonInFlight checks that calling DeletaPayments only +// TestPaymentControlDeleteNonInFlight checks that calling DeletePayments only // deletes payments from the database that are not in-flight. func TestPaymentControlDeleteNonInFligt(t *testing.T) { t.Parallel() @@ -338,23 +354,37 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) { t.Fatalf("unable to init db: %v", err) } + // Create a sequence number for duplicate payments that will not collide + // with the sequence numbers for the payments we create. These values + // start at 1, so 9999 is a safe bet for this test. + var duplicateSeqNr = 9999 + pControl := NewPaymentControl(db) payments := []struct { - failed bool - success bool + failed bool + success bool + hasDuplicate bool }{ { - failed: true, - success: false, + failed: true, + success: false, + hasDuplicate: false, }, { - failed: false, - success: true, + failed: false, + success: true, + hasDuplicate: false, }, { - failed: false, - success: false, + failed: false, + success: false, + hasDuplicate: false, + }, + { + failed: false, + success: true, + hasDuplicate: true, }, } @@ -430,6 +460,16 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) { t, pControl, info.PaymentHash, info, nil, htlc, ) } + + // If the payment is intended to have a duplicate payment, we + // add one. + if p.hasDuplicate { + appendDuplicatePayment( + t, pControl.db, info.PaymentHash, + uint64(duplicateSeqNr), + ) + duplicateSeqNr++ + } } // Delete payments. @@ -451,6 +491,21 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) { if status != StatusInFlight { t.Fatalf("expected in-fligth status, got %v", status) } + + // Finally, check that we only have a single index left in the payment + // index bucket. + var indexCount int + err = kvdb.View(db, func(tx walletdb.ReadTx) error { + index := tx.ReadBucket(paymentsIndexBucket) + + return index.ForEach(func(k, v []byte) error { + indexCount++ + return nil + }) + }) + require.NoError(t, err) + + require.Equal(t, 1, indexCount) } // TestPaymentControlMultiShard checks the ability of payment control to @@ -495,6 +550,7 @@ func TestPaymentControlMultiShard(t *testing.T) { t.Fatalf("unable to send htlc message: %v", err) } + assertPaymentIndex(t, pControl, info.PaymentHash) assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) assertPaymentInfo( t, pControl, info.PaymentHash, info, nil, nil, @@ -910,3 +966,55 @@ func assertPaymentInfo(t *testing.T, p *PaymentControl, hash lntypes.Hash, t.Fatal("expected no settle info") } } + +// fetchPaymentIndexEntry gets the payment hash for the sequence number provided +// from our payment indexes bucket. +func fetchPaymentIndexEntry(_ *testing.T, p *PaymentControl, + sequenceNumber uint64) (*lntypes.Hash, error) { + + var hash lntypes.Hash + + if err := kvdb.View(p.db, func(tx walletdb.ReadTx) error { + indexBucket := tx.ReadBucket(paymentsIndexBucket) + key := make([]byte, 8) + byteOrder.PutUint64(key, sequenceNumber) + + indexValue := indexBucket.Get(key) + if indexValue == nil { + return errNoSequenceNrIndex + } + + r := bytes.NewReader(indexValue) + + var err error + hash, err = deserializePaymentIndex(r) + return err + + }); err != nil { + return nil, err + } + + return &hash, nil +} + +// assertPaymentIndex looks up the index for a payment in the db and checks +// that its payment hash matches the expected hash passed in. +func assertPaymentIndex(t *testing.T, p *PaymentControl, + expectedHash lntypes.Hash) { + + // Lookup the payment so that we have its sequence number and check + // that is has correctly been indexed in the payment indexes bucket. + pmt, err := p.FetchPayment(expectedHash) + require.NoError(t, err) + + hash, err := fetchPaymentIndexEntry(t, p, pmt.SequenceNum) + require.NoError(t, err) + assert.Equal(t, expectedHash, *hash) +} + +// assertNoIndex checks that an index for the sequence number provided does not +// exist. +func assertNoIndex(t *testing.T, p *PaymentControl, seqNr uint64) { + _, err := fetchPaymentIndexEntry(t, p, seqNr) + require.Equal(t, errNoSequenceNrIndex, err) +} diff --git a/channeldb/payments.go b/channeldb/payments.go index d1ec6070..3229c093 100644 --- a/channeldb/payments.go +++ b/channeldb/payments.go @@ -3,6 +3,7 @@ package channeldb import ( "bytes" "encoding/binary" + "errors" "fmt" "io" "math" @@ -92,6 +93,15 @@ var ( // paymentFailInfoKey is a key used in the payment's sub-bucket to // store information about the reason a payment failed. paymentFailInfoKey = []byte("payment-fail-info") + + // paymentsIndexBucket is the name of the top-level bucket within the + // database that stores an index of payment sequence numbers to its + // payment hash. + // payments-sequence-index-bucket + // |--: + // |--... + // |--: + paymentsIndexBucket = []byte("payments-index-bucket") ) // FailureReason encodes the reason a payment ultimately failed. @@ -566,7 +576,15 @@ func (db *DB) DeletePayments() error { return nil } - var deleteBuckets [][]byte + var ( + // deleteBuckets is the set of payment buckets we need + // to delete. + deleteBuckets [][]byte + + // deleteIndexes is the set of indexes pointing to these + // payments that need to be deleted. + deleteIndexes [][]byte + ) err := payments.ForEach(func(k, _ []byte) error { bucket := payments.NestedReadWriteBucket(k) if bucket == nil { @@ -589,7 +607,18 @@ func (db *DB) DeletePayments() error { return nil } + // Add the bucket to the set of buckets we can delete. deleteBuckets = append(deleteBuckets, k) + + // Get all the sequence number associated with the + // payment, including duplicates. + seqNrs, err := fetchSequenceNumbers(bucket) + if err != nil { + return err + } + + deleteIndexes = append(deleteIndexes, seqNrs...) + return nil }) if err != nil { @@ -602,10 +631,49 @@ func (db *DB) DeletePayments() error { } } + // Get our index bucket and delete all indexes pointing to the + // payments we are deleting. + indexBucket := tx.ReadWriteBucket(paymentsIndexBucket) + for _, k := range deleteIndexes { + if err := indexBucket.Delete(k); err != nil { + return err + } + } + return nil }) } +// fetchSequenceNumbers fetches all the sequence numbers associated with a +// payment, including those belonging to any duplicate payments. +func fetchSequenceNumbers(paymentBucket kvdb.RBucket) ([][]byte, error) { + seqNum := paymentBucket.Get(paymentSequenceKey) + if seqNum == nil { + return nil, errors.New("expected sequence number") + } + + sequenceNumbers := [][]byte{seqNum} + + // Get the duplicate payments bucket, if it has no duplicates, just + // return early with the payment sequence number. + duplicates := paymentBucket.NestedReadBucket(duplicatePaymentsBucket) + if duplicates == nil { + return sequenceNumbers, nil + } + + // If we do have duplicated, they are keyed by sequence number, so we + // iterate through the duplicates bucket and add them to our set of + // sequence numbers. + if err := duplicates.ForEach(func(k, v []byte) error { + sequenceNumbers = append(sequenceNumbers, k) + return nil + }); err != nil { + return nil, err + } + + return sequenceNumbers, nil +} + // nolint: dupl func serializePaymentCreationInfo(w io.Writer, c *PaymentCreationInfo) error { var scratch [8]byte diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go index 118de16d..706060a2 100644 --- a/channeldb/payments_test.go +++ b/channeldb/payments_test.go @@ -485,6 +485,11 @@ func appendDuplicatePayment(t *testing.T, db *DB, paymentHash lntypes.Hash, // sequence numbers we've setup. putDuplicatePayment(t, dup, sequenceKey[:], paymentHash) + // Finally, once we have created our entry we add an index for + // it. + err = createPaymentIndexEntry(tx, sequenceKey[:], paymentHash) + require.NoError(t, err) + return nil }) if err != nil {