From 6c4a1f4f99ece18c009a02d5959b5f1d8e18ddd8 Mon Sep 17 00:00:00 2001 From: carla Date: Wed, 10 Jun 2020 12:34:27 +0200 Subject: [PATCH 1/6] channeldb: update TestQueryPayments to cover duplicate payments Update our current tests to include lookup of duplicate payments. We do so in preparation for changing our lookup to be based on a new payments index. We add an append duplicate function which will add a duplicate payment with the minimum information required to successfully read it from disk in tests. --- channeldb/payments_test.go | 112 ++++++++++++++++++++++++++++++++++++- 1 file changed, 110 insertions(+), 2 deletions(-) diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go index 2f0d88bc..118de16d 100644 --- a/channeldb/payments_test.go +++ b/channeldb/payments_test.go @@ -9,11 +9,13 @@ import ( "time" "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcwallet/walletdb" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb/kvdb" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" + "github.com/stretchr/testify/require" ) var ( @@ -188,6 +190,10 @@ func deletePayment(t *testing.T, db *DB, paymentHash lntypes.Hash) { func TestQueryPayments(t *testing.T) { // Define table driven test for QueryPayments. // Test payments have sequence indices [1, 3, 4, 5, 6, 7]. + // Note that the payment with index 7 has the same payment hash as 6, + // and is stored in a nested bucket within payment 6 rather than being + // its own entry in the payments bucket. We do this to test retrieval + // of legacy payments. tests := []struct { name string query PaymentsQuery @@ -359,10 +365,16 @@ func TestQueryPayments(t *testing.T) { } // Populate the database with a set of test payments. - numberOfPayments := 7 + // We create 6 original payments, deleting the payment + // at index 2 so that we cover the case where sequence + // numbers are missing. We also add a duplicate payment + // to the last payment added to test the legacy case + // where we have duplicates in the nested duplicates + // bucket. + nonDuplicatePayments := 6 pControl := NewPaymentControl(db) - for i := 0; i < numberOfPayments; i++ { + for i := 0; i < nonDuplicatePayments; i++ { // Generate a test payment. info, _, _, err := genInfo() if err != nil { @@ -381,6 +393,22 @@ func TestQueryPayments(t *testing.T) { if i == 1 { deletePayment(t, db, info.PaymentHash) } + + // If we are on the last payment entry, add a + // duplicate payment with sequence number equal + // to the parent payment + 1. + if i == (nonDuplicatePayments - 1) { + pmt, err := pControl.FetchPayment( + info.PaymentHash, + ) + require.NoError(t, err) + + appendDuplicatePayment( + t, pControl.db, + info.PaymentHash, + pmt.SequenceNum+1, + ) + } } // Fetch all payments in the database. @@ -424,3 +452,83 @@ func TestQueryPayments(t *testing.T) { }) } } + +// appendDuplicatePayment adds a duplicate payment to an existing payment. Note +// that this function requires a unique sequence number. +// +// This code is *only* intended to replicate legacy duplicate payments in lnd, +// our current schema does not allow duplicates. +func appendDuplicatePayment(t *testing.T, db *DB, paymentHash lntypes.Hash, + seqNr uint64) { + + err := kvdb.Update(db, func(tx walletdb.ReadWriteTx) error { + bucket, err := fetchPaymentBucketUpdate( + tx, paymentHash, + ) + if err != nil { + return err + } + + // Create the duplicates bucket if it is not + // present. + dup, err := bucket.CreateBucketIfNotExists( + duplicatePaymentsBucket, + ) + if err != nil { + return err + } + + var sequenceKey [8]byte + byteOrder.PutUint64(sequenceKey[:], seqNr) + + // Create duplicate payments for the two dup + // sequence numbers we've setup. + putDuplicatePayment(t, dup, sequenceKey[:], paymentHash) + + return nil + }) + if err != nil { + t.Fatalf("could not create payment: %v", err) + } +} + +// putDuplicatePayment creates a duplicate payment in the duplicates bucket +// provided with the minimal information required for successful reading. +func putDuplicatePayment(t *testing.T, duplicateBucket kvdb.RwBucket, + sequenceKey []byte, paymentHash lntypes.Hash) { + + paymentBucket, err := duplicateBucket.CreateBucketIfNotExists( + sequenceKey, + ) + require.NoError(t, err) + + err = paymentBucket.Put(duplicatePaymentSequenceKey, sequenceKey) + require.NoError(t, err) + + // Generate fake information for the duplicate payment. + info, _, _, err := genInfo() + require.NoError(t, err) + + // Write the payment info to disk under the creation info key. This code + // is copied rather than using serializePaymentCreationInfo to ensure + // we always write in the legacy format used by duplicate payments. + var b bytes.Buffer + var scratch [8]byte + _, err = b.Write(paymentHash[:]) + require.NoError(t, err) + + byteOrder.PutUint64(scratch[:], uint64(info.Value)) + _, err = b.Write(scratch[:]) + require.NoError(t, err) + + err = serializeTime(&b, info.CreationTime) + require.NoError(t, err) + + byteOrder.PutUint32(scratch[:4], 0) + _, err = b.Write(scratch[:4]) + require.NoError(t, err) + + // Get the PaymentCreationInfo. + err = paymentBucket.Put(duplicatePaymentCreationInfoKey, b.Bytes()) + require.NoError(t, err) +} From c8d11285f3350cb93969d61a81046ca1df63e513 Mon Sep 17 00:00:00 2001 From: carla Date: Wed, 10 Jun 2020 12:34:27 +0200 Subject: [PATCH 2/6] channeldb: index payments by sequence number Add an entry to a payments index bucket which maps sequence number to payment hash when we initiate payments. This allows for more efficient paginated queries. We create the top level bucket in its own migration so that we do not need to create it on the fly. When we retry payments and provide them with a new sequence number, we delete the index for their existing payment so that we do not have an index that points to a non-existent payment. If we delete a payment, we also delete its index entry. This prevents us from looking up entries from indexes to payments that do not exist. --- channeldb/codec.go | 10 ++ channeldb/db.go | 15 ++ channeldb/log.go | 2 + channeldb/migration16/log.go | 14 ++ channeldb/migration16/migration.go | 191 ++++++++++++++++++++++++ channeldb/migration16/migration_test.go | 144 ++++++++++++++++++ channeldb/payment_control.go | 79 ++++++++++ channeldb/payment_control_test.go | 128 ++++++++++++++-- channeldb/payments.go | 70 ++++++++- channeldb/payments_test.go | 5 + 10 files changed, 647 insertions(+), 11 deletions(-) create mode 100644 channeldb/migration16/log.go create mode 100644 channeldb/migration16/migration.go create mode 100644 channeldb/migration16/migration_test.go diff --git a/channeldb/codec.go b/channeldb/codec.go index 78d61694..f6903175 100644 --- a/channeldb/codec.go +++ b/channeldb/codec.go @@ -192,6 +192,11 @@ func WriteElement(w io.Writer, element interface{}) error { return err } + case paymentIndexType: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + case lnwire.FundingFlag: if err := binary.Write(w, byteOrder, e); err != nil { return err @@ -406,6 +411,11 @@ func ReadElement(r io.Reader, element interface{}) error { return err } + case *paymentIndexType: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + case *lnwire.FundingFlag: if err := binary.Read(r, byteOrder, e); err != nil { return err diff --git a/channeldb/db.go b/channeldb/db.go index 1347a58f..564dbb7b 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -16,6 +16,7 @@ import ( mig "github.com/lightningnetwork/lnd/channeldb/migration" "github.com/lightningnetwork/lnd/channeldb/migration12" "github.com/lightningnetwork/lnd/channeldb/migration13" + "github.com/lightningnetwork/lnd/channeldb/migration16" "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lnwire" @@ -144,6 +145,19 @@ var ( number: 14, migration: mig.CreateTLB(payAddrIndexBucket), }, + { + // Initialize payment index bucket which will be used + // to index payments by sequence number. This index will + // be used to allow more efficient ListPayments queries. + number: 15, + migration: mig.CreateTLB(paymentsIndexBucket), + }, + { + // Add our existing payments to the index bucket created + // in migration 15. + number: 16, + migration: migration16.MigrateSequenceIndex, + }, } // Big endian is the preferred byte order, due to cursor scans over @@ -257,6 +271,7 @@ var topLevelBuckets = [][]byte{ fwdPackagesKey, invoiceBucket, payAddrIndexBucket, + paymentsIndexBucket, nodeInfoBucket, nodeBucket, edgeBucket, diff --git a/channeldb/log.go b/channeldb/log.go index f59426f0..75ba2a5f 100644 --- a/channeldb/log.go +++ b/channeldb/log.go @@ -6,6 +6,7 @@ import ( mig "github.com/lightningnetwork/lnd/channeldb/migration" "github.com/lightningnetwork/lnd/channeldb/migration12" "github.com/lightningnetwork/lnd/channeldb/migration13" + "github.com/lightningnetwork/lnd/channeldb/migration16" "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11" ) @@ -33,4 +34,5 @@ func UseLogger(logger btclog.Logger) { migration_01_to_11.UseLogger(logger) migration12.UseLogger(logger) migration13.UseLogger(logger) + migration16.UseLogger(logger) } diff --git a/channeldb/migration16/log.go b/channeldb/migration16/log.go new file mode 100644 index 00000000..cb946854 --- /dev/null +++ b/channeldb/migration16/log.go @@ -0,0 +1,14 @@ +package migration16 + +import ( + "github.com/btcsuite/btclog" +) + +// log is a logger that is initialized as disabled. This means the package will +// not perform any logging by default until a logger is set. +var log = btclog.Disabled + +// UseLogger uses a specified Logger to output package logging info. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/channeldb/migration16/migration.go b/channeldb/migration16/migration.go new file mode 100644 index 00000000..b984f083 --- /dev/null +++ b/channeldb/migration16/migration.go @@ -0,0 +1,191 @@ +package migration16 + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/channeldb/kvdb" +) + +var ( + paymentsRootBucket = []byte("payments-root-bucket") + + paymentSequenceKey = []byte("payment-sequence-key") + + duplicatePaymentsBucket = []byte("payment-duplicate-bucket") + + paymentsIndexBucket = []byte("payments-index-bucket") + + byteOrder = binary.BigEndian +) + +// paymentIndexType indicates the type of index we have recorded in the payment +// indexes bucket. +type paymentIndexType uint8 + +// paymentIndexTypeHash is a payment index type which indicates that we have +// created an index of payment sequence number to payment hash. +const paymentIndexTypeHash paymentIndexType = 0 + +// paymentIndex stores all the information we require to create an index by +// sequence number for a payment. +type paymentIndex struct { + // paymentHash is the hash of the payment, which is its key in the + // payment root bucket. + paymentHash []byte + + // sequenceNumbers is the set of sequence numbers associated with this + // payment hash. There will be more than one sequence number in the + // case where duplicate payments are present. + sequenceNumbers [][]byte +} + +// MigrateSequenceIndex migrates the payments db to contain a new bucket which +// provides an index from sequence number to payment hash. This is required +// for more efficient sequential lookup of payments, which are keyed by payment +// hash before this migration. +func MigrateSequenceIndex(tx kvdb.RwTx) error { + log.Infof("Migrating payments to add sequence number index") + + // Get a list of indices we need to write. + indexList, err := getPaymentIndexList(tx) + if err != nil { + return err + } + + // Create the top level bucket that we will use to index payments in. + bucket, err := tx.CreateTopLevelBucket(paymentsIndexBucket) + if err != nil { + return err + } + + // Write an index for each of our payments. + for _, index := range indexList { + // Write indexes for each of our sequence numbers. + for _, seqNr := range index.sequenceNumbers { + err := putIndex(bucket, seqNr, index.paymentHash) + if err != nil { + return err + } + } + } + + return nil +} + +// putIndex performs a sanity check that ensures we are not writing duplicate +// indexes to disk then creates the index provided. +func putIndex(bucket kvdb.RwBucket, sequenceNr, paymentHash []byte) error { + // Add a sanity check that we do not already have an entry with + // this sequence number. + existingEntry := bucket.Get(sequenceNr) + if existingEntry != nil { + return fmt.Errorf("sequence number: %x duplicated", + sequenceNr) + } + + bytes, err := serializePaymentIndexEntry(paymentHash) + if err != nil { + return err + } + + return bucket.Put(sequenceNr, bytes) +} + +// serializePaymentIndexEntry serializes a payment hash typed index. The value +// produced contains a payment index type (which can be used in future to +// signal different payment index types) and the payment hash. +func serializePaymentIndexEntry(hash []byte) ([]byte, error) { + var b bytes.Buffer + + err := binary.Write(&b, byteOrder, paymentIndexTypeHash) + if err != nil { + return nil, err + } + + if err := wire.WriteVarBytes(&b, 0, hash); err != nil { + return nil, err + } + + return b.Bytes(), nil +} + +// getPaymentIndexList gets a list of indices we need to write for our current +// set of payments. +func getPaymentIndexList(tx kvdb.RTx) ([]paymentIndex, error) { + // Iterate over all payments and store their indexing keys. This is + // needed, because no modifications are allowed inside a Bucket.ForEach + // loop. + paymentsBucket := tx.ReadBucket(paymentsRootBucket) + if paymentsBucket == nil { + return nil, nil + } + + var indexList []paymentIndex + err := paymentsBucket.ForEach(func(k, v []byte) error { + // Get the bucket which contains the payment, fail if the key + // does not have a bucket. + bucket := paymentsBucket.NestedReadBucket(k) + if bucket == nil { + return fmt.Errorf("non bucket element in " + + "payments bucket") + } + seqBytes := bucket.Get(paymentSequenceKey) + if seqBytes == nil { + return fmt.Errorf("nil sequence number bytes") + } + + seqNrs, err := fetchSequenceNumbers(bucket) + if err != nil { + return err + } + + // Create an index object with our payment hash and sequence + // numbers and append it to our set of indexes. + index := paymentIndex{ + paymentHash: k, + sequenceNumbers: seqNrs, + } + + indexList = append(indexList, index) + return nil + }) + if err != nil { + return nil, err + } + + return indexList, nil +} + +// fetchSequenceNumbers fetches all the sequence numbers associated with a +// payment, including those belonging to any duplicate payments. +func fetchSequenceNumbers(paymentBucket kvdb.RBucket) ([][]byte, error) { + seqNum := paymentBucket.Get(paymentSequenceKey) + if seqNum == nil { + return nil, errors.New("expected sequence number") + } + + sequenceNumbers := [][]byte{seqNum} + + // Get the duplicate payments bucket, if it has no duplicates, just + // return early with the payment sequence number. + duplicates := paymentBucket.NestedReadBucket(duplicatePaymentsBucket) + if duplicates == nil { + return sequenceNumbers, nil + } + + // If we do have duplicated, they are keyed by sequence number, so we + // iterate through the duplicates bucket and add them to our set of + // sequence numbers. + if err := duplicates.ForEach(func(k, v []byte) error { + sequenceNumbers = append(sequenceNumbers, k) + return nil + }); err != nil { + return nil, err + } + + return sequenceNumbers, nil +} diff --git a/channeldb/migration16/migration_test.go b/channeldb/migration16/migration_test.go new file mode 100644 index 00000000..626bedcb --- /dev/null +++ b/channeldb/migration16/migration_test.go @@ -0,0 +1,144 @@ +package migration16 + +import ( + "encoding/hex" + "testing" + + "github.com/lightningnetwork/lnd/channeldb/kvdb" + "github.com/lightningnetwork/lnd/channeldb/migtest" +) + +var ( + hexStr = migtest.Hex + + hash1Str = "02acee76ebd53d00824410cf6adecad4f50334dac702bd5a2d3ba01b91709f0e" + hash1 = hexStr(hash1Str) + paymentID1 = hexStr("0000000000000001") + + hash2Str = "62eb3f0a48f954e495d0c14ac63df04a67cefa59dafdbcd3d5046d1f5647840c" + hash2 = hexStr(hash2Str) + paymentID2 = hexStr("0000000000000002") + + paymentID3 = hexStr("0000000000000003") + + // pre is the data in the payments root bucket in database version 13 format. + pre = map[string]interface{}{ + // A payment without duplicates. + hash1: map[string]interface{}{ + "payment-sequence-key": paymentID1, + }, + + // A payment with a duplicate. + hash2: map[string]interface{}{ + "payment-sequence-key": paymentID2, + "payment-duplicate-bucket": map[string]interface{}{ + paymentID3: map[string]interface{}{ + "payment-sequence-key": paymentID3, + }, + }, + }, + } + + preFails = map[string]interface{}{ + // A payment without duplicates. + hash1: map[string]interface{}{ + "payment-sequence-key": paymentID1, + "payment-duplicate-bucket": map[string]interface{}{ + paymentID1: map[string]interface{}{ + "payment-sequence-key": paymentID1, + }, + }, + }, + } + + // post is the expected data after migration. + post = map[string]interface{}{ + paymentID1: paymentHashIndex(hash1Str), + paymentID2: paymentHashIndex(hash2Str), + paymentID3: paymentHashIndex(hash2Str), + } +) + +// paymentHashIndex produces a string that represents the value we expect for +// our payment indexes from a hex encoded payment hash string. +func paymentHashIndex(hashStr string) string { + hash, err := hex.DecodeString(hashStr) + if err != nil { + panic(err) + } + + bytes, err := serializePaymentIndexEntry(hash) + if err != nil { + panic(err) + } + + return string(bytes) +} + +// MigrateSequenceIndex asserts that the database is properly migrated to +// contain a payments index. +func TestMigrateSequenceIndex(t *testing.T) { + tests := []struct { + name string + shouldFail bool + pre map[string]interface{} + post map[string]interface{} + }{ + { + name: "migration ok", + shouldFail: false, + pre: pre, + post: post, + }, + { + name: "duplicate sequence number", + shouldFail: true, + pre: preFails, + post: post, + }, + { + name: "no payments", + shouldFail: false, + pre: nil, + post: nil, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + // Before the migration we have a payments bucket. + before := func(tx kvdb.RwTx) error { + return migtest.RestoreDB( + tx, paymentsRootBucket, test.pre, + ) + } + + // After the migration, we should have an untouched + // payments bucket and a new index bucket. + after := func(tx kvdb.RwTx) error { + if err := migtest.VerifyDB( + tx, paymentsRootBucket, test.pre, + ); err != nil { + return err + } + + // If we expect our migration to fail, we don't + // expect an index bucket. + if test.shouldFail { + return nil + } + + return migtest.VerifyDB( + tx, paymentsIndexBucket, test.post, + ) + } + + migtest.ApplyMigration( + t, before, after, MigrateSequenceIndex, + test.shouldFail, + ) + }) + } +} diff --git a/channeldb/payment_control.go b/channeldb/payment_control.go index 99d2000c..5a538134 100644 --- a/channeldb/payment_control.go +++ b/channeldb/payment_control.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "fmt" + "io" "github.com/lightningnetwork/lnd/channeldb/kvdb" "github.com/lightningnetwork/lnd/lntypes" @@ -74,6 +75,11 @@ var ( // errNoAttemptInfo is returned when no attempt info is stored yet. errNoAttemptInfo = errors.New("unable to find attempt info for " + "inflight payment") + + // errNoSequenceNrIndex is returned when an attempt to lookup a payment + // index is made for a sequence number that is not indexed. + errNoSequenceNrIndex = errors.New("payment sequence number index " + + "does not exist") ) // PaymentControl implements persistence for payments and payment attempts. @@ -152,6 +158,27 @@ func (p *PaymentControl) InitPayment(paymentHash lntypes.Hash, return err } + // Before we set our new sequence number, we check whether this + // payment has a previously set sequence number and remove its + // index entry if it exists. This happens in the case where we + // have a previously attempted payment which was left in a state + // where we can retry. + seqBytes := bucket.Get(paymentSequenceKey) + if seqBytes != nil { + indexBucket := tx.ReadWriteBucket(paymentsIndexBucket) + if err := indexBucket.Delete(seqBytes); err != nil { + return err + } + } + + // Once we have obtained a sequence number, we add an entry + // to our index bucket which will map the sequence number to + // our payment hash. + err = createPaymentIndexEntry(tx, sequenceNum, info.PaymentHash) + if err != nil { + return err + } + err = bucket.Put(paymentSequenceKey, sequenceNum) if err != nil { return err @@ -183,6 +210,58 @@ func (p *PaymentControl) InitPayment(paymentHash lntypes.Hash, return updateErr } +// paymentIndexTypeHash is a payment index type which indicates that we have +// created an index of payment sequence number to payment hash. +type paymentIndexType uint8 + +// paymentIndexTypeHash is a payment index type which indicates that we have +// created an index of payment sequence number to payment hash. +const paymentIndexTypeHash paymentIndexType = 0 + +// createPaymentIndexEntry creates a payment hash typed index for a payment. The +// index produced contains a payment index type (which can be used in future to +// signal different payment index types) and the payment hash. +func createPaymentIndexEntry(tx kvdb.RwTx, sequenceNumber []byte, + hash lntypes.Hash) error { + + var b bytes.Buffer + if err := WriteElements(&b, paymentIndexTypeHash, hash[:]); err != nil { + return err + } + + indexes := tx.ReadWriteBucket(paymentsIndexBucket) + return indexes.Put(sequenceNumber, b.Bytes()) +} + +// deserializePaymentIndex deserializes a payment index entry. This function +// currently only supports deserialization of payment hash indexes, and will +// fail for other types. +func deserializePaymentIndex(r io.Reader) (lntypes.Hash, error) { + var ( + indexType paymentIndexType + paymentHash []byte + ) + + if err := ReadElements(r, &indexType, &paymentHash); err != nil { + return lntypes.Hash{}, err + } + + // While we only have on payment index type, we do not need to use our + // index type to deserialize the index. However, we sanity check that + // this type is as expected, since we had to read it out anyway. + if indexType != paymentIndexTypeHash { + return lntypes.Hash{}, fmt.Errorf("unknown payment index "+ + "type: %v", indexType) + } + + hash, err := lntypes.MakeHash(paymentHash) + if err != nil { + return lntypes.Hash{}, err + } + + return hash, nil +} + // RegisterAttempt atomically records the provided HTLCAttemptInfo to the // DB. func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash, diff --git a/channeldb/payment_control_test.go b/channeldb/payment_control_test.go index c470a8f5..147e5452 100644 --- a/channeldb/payment_control_test.go +++ b/channeldb/payment_control_test.go @@ -1,6 +1,7 @@ package channeldb import ( + "bytes" "crypto/rand" "crypto/sha256" "fmt" @@ -9,9 +10,13 @@ import ( "testing" "time" + "github.com/btcsuite/btcwallet/walletdb" "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/channeldb/kvdb" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/record" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func genPreimage() ([32]byte, error) { @@ -70,6 +75,7 @@ func TestPaymentControlSwitchFail(t *testing.T) { t.Fatalf("unable to send htlc message: %v", err) } + assertPaymentIndex(t, pControl, info.PaymentHash) assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) assertPaymentInfo( t, pControl, info.PaymentHash, info, nil, nil, @@ -88,6 +94,11 @@ func TestPaymentControlSwitchFail(t *testing.T) { t, pControl, info.PaymentHash, info, &failReason, nil, ) + // Lookup the payment so we can get its old sequence number before it is + // overwritten. + payment, err := pControl.FetchPayment(info.PaymentHash) + assert.NoError(t, err) + // Sends the htlc again, which should succeed since the prior payment // failed. err = pControl.InitPayment(info.PaymentHash, info) @@ -95,6 +106,11 @@ func TestPaymentControlSwitchFail(t *testing.T) { t.Fatalf("unable to send htlc message: %v", err) } + // Check that our index has been updated, and the old index has been + // removed. + assertPaymentIndex(t, pControl, info.PaymentHash) + assertNoIndex(t, pControl, payment.SequenceNum) + assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) assertPaymentInfo( t, pControl, info.PaymentHash, info, nil, nil, @@ -145,7 +161,6 @@ func TestPaymentControlSwitchFail(t *testing.T) { // Settle the attempt and verify that status was changed to // StatusSucceeded. - var payment *MPPayment payment, err = pControl.SettleAttempt( info.PaymentHash, attempt.AttemptID, &HTLCSettleInfo{ @@ -209,6 +224,7 @@ func TestPaymentControlSwitchDoubleSend(t *testing.T) { t.Fatalf("unable to send htlc message: %v", err) } + assertPaymentIndex(t, pControl, info.PaymentHash) assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) assertPaymentInfo( t, pControl, info.PaymentHash, info, nil, nil, @@ -326,7 +342,7 @@ func TestPaymentControlFailsWithoutInFlight(t *testing.T) { assertPaymentStatus(t, pControl, info.PaymentHash, StatusUnknown) } -// TestPaymentControlDeleteNonInFlight checks that calling DeletaPayments only +// TestPaymentControlDeleteNonInFlight checks that calling DeletePayments only // deletes payments from the database that are not in-flight. func TestPaymentControlDeleteNonInFligt(t *testing.T) { t.Parallel() @@ -338,23 +354,37 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) { t.Fatalf("unable to init db: %v", err) } + // Create a sequence number for duplicate payments that will not collide + // with the sequence numbers for the payments we create. These values + // start at 1, so 9999 is a safe bet for this test. + var duplicateSeqNr = 9999 + pControl := NewPaymentControl(db) payments := []struct { - failed bool - success bool + failed bool + success bool + hasDuplicate bool }{ { - failed: true, - success: false, + failed: true, + success: false, + hasDuplicate: false, }, { - failed: false, - success: true, + failed: false, + success: true, + hasDuplicate: false, }, { - failed: false, - success: false, + failed: false, + success: false, + hasDuplicate: false, + }, + { + failed: false, + success: true, + hasDuplicate: true, }, } @@ -430,6 +460,16 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) { t, pControl, info.PaymentHash, info, nil, htlc, ) } + + // If the payment is intended to have a duplicate payment, we + // add one. + if p.hasDuplicate { + appendDuplicatePayment( + t, pControl.db, info.PaymentHash, + uint64(duplicateSeqNr), + ) + duplicateSeqNr++ + } } // Delete payments. @@ -451,6 +491,21 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) { if status != StatusInFlight { t.Fatalf("expected in-fligth status, got %v", status) } + + // Finally, check that we only have a single index left in the payment + // index bucket. + var indexCount int + err = kvdb.View(db, func(tx walletdb.ReadTx) error { + index := tx.ReadBucket(paymentsIndexBucket) + + return index.ForEach(func(k, v []byte) error { + indexCount++ + return nil + }) + }) + require.NoError(t, err) + + require.Equal(t, 1, indexCount) } // TestPaymentControlMultiShard checks the ability of payment control to @@ -495,6 +550,7 @@ func TestPaymentControlMultiShard(t *testing.T) { t.Fatalf("unable to send htlc message: %v", err) } + assertPaymentIndex(t, pControl, info.PaymentHash) assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) assertPaymentInfo( t, pControl, info.PaymentHash, info, nil, nil, @@ -910,3 +966,55 @@ func assertPaymentInfo(t *testing.T, p *PaymentControl, hash lntypes.Hash, t.Fatal("expected no settle info") } } + +// fetchPaymentIndexEntry gets the payment hash for the sequence number provided +// from our payment indexes bucket. +func fetchPaymentIndexEntry(_ *testing.T, p *PaymentControl, + sequenceNumber uint64) (*lntypes.Hash, error) { + + var hash lntypes.Hash + + if err := kvdb.View(p.db, func(tx walletdb.ReadTx) error { + indexBucket := tx.ReadBucket(paymentsIndexBucket) + key := make([]byte, 8) + byteOrder.PutUint64(key, sequenceNumber) + + indexValue := indexBucket.Get(key) + if indexValue == nil { + return errNoSequenceNrIndex + } + + r := bytes.NewReader(indexValue) + + var err error + hash, err = deserializePaymentIndex(r) + return err + + }); err != nil { + return nil, err + } + + return &hash, nil +} + +// assertPaymentIndex looks up the index for a payment in the db and checks +// that its payment hash matches the expected hash passed in. +func assertPaymentIndex(t *testing.T, p *PaymentControl, + expectedHash lntypes.Hash) { + + // Lookup the payment so that we have its sequence number and check + // that is has correctly been indexed in the payment indexes bucket. + pmt, err := p.FetchPayment(expectedHash) + require.NoError(t, err) + + hash, err := fetchPaymentIndexEntry(t, p, pmt.SequenceNum) + require.NoError(t, err) + assert.Equal(t, expectedHash, *hash) +} + +// assertNoIndex checks that an index for the sequence number provided does not +// exist. +func assertNoIndex(t *testing.T, p *PaymentControl, seqNr uint64) { + _, err := fetchPaymentIndexEntry(t, p, seqNr) + require.Equal(t, errNoSequenceNrIndex, err) +} diff --git a/channeldb/payments.go b/channeldb/payments.go index d1ec6070..3229c093 100644 --- a/channeldb/payments.go +++ b/channeldb/payments.go @@ -3,6 +3,7 @@ package channeldb import ( "bytes" "encoding/binary" + "errors" "fmt" "io" "math" @@ -92,6 +93,15 @@ var ( // paymentFailInfoKey is a key used in the payment's sub-bucket to // store information about the reason a payment failed. paymentFailInfoKey = []byte("payment-fail-info") + + // paymentsIndexBucket is the name of the top-level bucket within the + // database that stores an index of payment sequence numbers to its + // payment hash. + // payments-sequence-index-bucket + // |--: + // |--... + // |--: + paymentsIndexBucket = []byte("payments-index-bucket") ) // FailureReason encodes the reason a payment ultimately failed. @@ -566,7 +576,15 @@ func (db *DB) DeletePayments() error { return nil } - var deleteBuckets [][]byte + var ( + // deleteBuckets is the set of payment buckets we need + // to delete. + deleteBuckets [][]byte + + // deleteIndexes is the set of indexes pointing to these + // payments that need to be deleted. + deleteIndexes [][]byte + ) err := payments.ForEach(func(k, _ []byte) error { bucket := payments.NestedReadWriteBucket(k) if bucket == nil { @@ -589,7 +607,18 @@ func (db *DB) DeletePayments() error { return nil } + // Add the bucket to the set of buckets we can delete. deleteBuckets = append(deleteBuckets, k) + + // Get all the sequence number associated with the + // payment, including duplicates. + seqNrs, err := fetchSequenceNumbers(bucket) + if err != nil { + return err + } + + deleteIndexes = append(deleteIndexes, seqNrs...) + return nil }) if err != nil { @@ -602,10 +631,49 @@ func (db *DB) DeletePayments() error { } } + // Get our index bucket and delete all indexes pointing to the + // payments we are deleting. + indexBucket := tx.ReadWriteBucket(paymentsIndexBucket) + for _, k := range deleteIndexes { + if err := indexBucket.Delete(k); err != nil { + return err + } + } + return nil }) } +// fetchSequenceNumbers fetches all the sequence numbers associated with a +// payment, including those belonging to any duplicate payments. +func fetchSequenceNumbers(paymentBucket kvdb.RBucket) ([][]byte, error) { + seqNum := paymentBucket.Get(paymentSequenceKey) + if seqNum == nil { + return nil, errors.New("expected sequence number") + } + + sequenceNumbers := [][]byte{seqNum} + + // Get the duplicate payments bucket, if it has no duplicates, just + // return early with the payment sequence number. + duplicates := paymentBucket.NestedReadBucket(duplicatePaymentsBucket) + if duplicates == nil { + return sequenceNumbers, nil + } + + // If we do have duplicated, they are keyed by sequence number, so we + // iterate through the duplicates bucket and add them to our set of + // sequence numbers. + if err := duplicates.ForEach(func(k, v []byte) error { + sequenceNumbers = append(sequenceNumbers, k) + return nil + }); err != nil { + return nil, err + } + + return sequenceNumbers, nil +} + // nolint: dupl func serializePaymentCreationInfo(w io.Writer, c *PaymentCreationInfo) error { var scratch [8]byte diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go index 118de16d..706060a2 100644 --- a/channeldb/payments_test.go +++ b/channeldb/payments_test.go @@ -485,6 +485,11 @@ func appendDuplicatePayment(t *testing.T, db *DB, paymentHash lntypes.Hash, // sequence numbers we've setup. putDuplicatePayment(t, dup, sequenceKey[:], paymentHash) + // Finally, once we have created our entry we add an index for + // it. + err = createPaymentIndexEntry(tx, sequenceKey[:], paymentHash) + require.NoError(t, err) + return nil }) if err != nil { From f4933c67fd1cc68fc321f188f894411a865caabb Mon Sep 17 00:00:00 2001 From: carla Date: Wed, 10 Jun 2020 12:34:27 +0200 Subject: [PATCH 3/6] channeldb: add test case for index offset greater than index In our current invoice pagination logic, we would not return any invoices if our offset index was more than 1 off our last index and we were paginating backwards. This commit adds a test case for this behaviour before fixing it in the next commit. --- channeldb/invoice_test.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index eea9df03..d19c349b 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -1007,6 +1007,18 @@ func TestQueryInvoices(t *testing.T) { // still pending. expected: pendingInvoices[len(pendingInvoices)-15:], }, + // Fetch all invoices paginating backwards, with an index offset + // that is beyond our last offset. We currently do not return + // anything if our index is greater than our last index. + { + query: InvoiceQuery{ + IndexOffset: numInvoices * 2, + PendingOnly: false, + Reversed: true, + NumMaxInvoices: numInvoices, + }, + expected: nil, + }, } for i, testCase := range testCases { From eea871b5831b01c77c27d4f37e99f67217a3f2d8 Mon Sep 17 00:00:00 2001 From: carla Date: Wed, 10 Jun 2020 12:34:27 +0200 Subject: [PATCH 4/6] channeldb: add a paginator struct to process generic pagination We now use the same method of pagination for invoices and payments. Rather than duplicate logic across calls, we add a pagnator struct which can have query specific logic plugged into it. This commit also addresses an existing issue where a reverse query for invoices with an offset larger than our last offset would not return any invoices. We update this behaviour to act more like c.Seek and just start from the last entry. This behaviour change is covered by a unit test that previously checked for the lack of invoices. --- channeldb/invoice_test.go | 6 +- channeldb/invoices.go | 90 +++++++----------------- channeldb/paginate.go | 140 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 169 insertions(+), 67 deletions(-) create mode 100644 channeldb/paginate.go diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index d19c349b..e0ec2191 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -1008,8 +1008,8 @@ func TestQueryInvoices(t *testing.T) { expected: pendingInvoices[len(pendingInvoices)-15:], }, // Fetch all invoices paginating backwards, with an index offset - // that is beyond our last offset. We currently do not return - // anything if our index is greater than our last index. + // that is beyond our last offset. We expect all invoices to be + // returned. { query: InvoiceQuery{ IndexOffset: numInvoices * 2, @@ -1017,7 +1017,7 @@ func TestQueryInvoices(t *testing.T) { Reversed: true, NumMaxInvoices: numInvoices, }, - expected: nil, + expected: invoices, }, } diff --git a/channeldb/invoices.go b/channeldb/invoices.go index a7ed4324..07de2add 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -839,85 +839,47 @@ func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) { if invoices == nil { return ErrNoInvoicesCreated } + + // Get the add index bucket which we will use to iterate through + // our indexed invoices. invoiceAddIndex := invoices.NestedReadBucket(addIndexBucket) if invoiceAddIndex == nil { return ErrNoInvoicesCreated } - // keyForIndex is a helper closure that retrieves the invoice - // key for the given add index of an invoice. - keyForIndex := func(c kvdb.RCursor, index uint64) []byte { - var keyIndex [8]byte - byteOrder.PutUint64(keyIndex[:], index) - _, invoiceKey := c.Seek(keyIndex[:]) - return invoiceKey - } + // Create a paginator which reads from our add index bucket with + // the parameters provided by the invoice query. + paginator := newPaginator( + invoiceAddIndex.ReadCursor(), q.Reversed, q.IndexOffset, + q.NumMaxInvoices, + ) - // nextKey is a helper closure to determine what the next - // invoice key is when iterating over the invoice add index. - nextKey := func(c kvdb.RCursor) ([]byte, []byte) { - if q.Reversed { - return c.Prev() - } - return c.Next() - } - - // We'll be using a cursor to seek into the database and return - // a slice of invoices. We'll need to determine where to start - // our cursor depending on the parameters set within the query. - c := invoiceAddIndex.ReadCursor() - invoiceKey := keyForIndex(c, q.IndexOffset+1) - - // If the query is specifying reverse iteration, then we must - // handle a few offset cases. - if q.Reversed { - switch q.IndexOffset { - - // This indicates the default case, where no offset was - // specified. In that case we just start from the last - // invoice. - case 0: - _, invoiceKey = c.Last() - - // This indicates the offset being set to the very - // first invoice. Since there are no invoices before - // this offset, and the direction is reversed, we can - // return without adding any invoices to the response. - case 1: - return nil - - // Otherwise we start iteration at the invoice prior to - // the offset. - default: - invoiceKey = keyForIndex(c, q.IndexOffset-1) - } - } - - // If we know that a set of invoices exists, then we'll begin - // our seek through the bucket in order to satisfy the query. - // We'll continue until either we reach the end of the range, or - // reach our max number of invoices. - for ; invoiceKey != nil; _, invoiceKey = nextKey(c) { - // If our current return payload exceeds the max number - // of invoices, then we'll exit now. - if uint64(len(resp.Invoices)) >= q.NumMaxInvoices { - break - } - - invoice, err := fetchInvoice(invoiceKey, invoices) + // accumulateInvoices looks up an invoice based on the index we + // are given, adds it to our set of invoices if it has the right + // characteristics for our query and returns the number of items + // we have added to our set of invoices. + accumulateInvoices := func(_, indexValue []byte) (bool, error) { + invoice, err := fetchInvoice(indexValue, invoices) if err != nil { - return err + return false, err } - // Skip any settled or canceled invoices if the caller is - // only interested in pending ones. + // Skip any settled or canceled invoices if the caller + // is only interested in pending ones. if q.PendingOnly && !invoice.IsPending() { - continue + return false, nil } // At this point, we've exhausted the offset, so we'll // begin collecting invoices found within the range. resp.Invoices = append(resp.Invoices, invoice) + return true, nil + } + + // Query our paginator using accumulateInvoices to build up a + // set of invoices. + if err := paginator.query(accumulateInvoices); err != nil { + return err } // If we iterated through the add index in reverse order, then diff --git a/channeldb/paginate.go b/channeldb/paginate.go new file mode 100644 index 00000000..22ec4fb4 --- /dev/null +++ b/channeldb/paginate.go @@ -0,0 +1,140 @@ +package channeldb + +import "github.com/lightningnetwork/lnd/channeldb/kvdb" + +type paginator struct { + // cursor is the cursor which we are using to iterate through a bucket. + cursor kvdb.RCursor + + // reversed indicates whether we are paginating forwards or backwards. + reversed bool + + // indexOffset is the index from which we will begin querying. + indexOffset uint64 + + // totalItems is the total number of items we allow in our response. + totalItems uint64 +} + +// newPaginator returns a struct which can be used to query an indexed bucket +// in pages. +func newPaginator(c kvdb.RCursor, reversed bool, + indexOffset, totalItems uint64) paginator { + + return paginator{ + cursor: c, + reversed: reversed, + indexOffset: indexOffset, + totalItems: totalItems, + } +} + +// keyValueForIndex seeks our cursor to a given index and returns the key and +// value at that position. +func (p paginator) keyValueForIndex(index uint64) ([]byte, []byte) { + var keyIndex [8]byte + byteOrder.PutUint64(keyIndex[:], index) + return p.cursor.Seek(keyIndex[:]) +} + +// lastIndex returns the last value in our index, if our index is empty it +// returns 0. +func (p paginator) lastIndex() uint64 { + keyIndex, _ := p.cursor.Last() + if keyIndex == nil { + return 0 + } + + return byteOrder.Uint64(keyIndex) +} + +// nextKey is a helper closure to determine what key we should use next when +// we are iterating, depending on whether we are iterating forwards or in +// reverse. +func (p paginator) nextKey() ([]byte, []byte) { + if p.reversed { + return p.cursor.Prev() + } + return p.cursor.Next() +} + +// cursorStart gets the index key and value for the first item we are looking +// up, taking into account that we may be paginating in reverse. The index +// offset provided is *excusive* so we will start with the item after the offset +// for forwards queries, and the item before the index for backwards queries. +func (p paginator) cursorStart() ([]byte, []byte) { + indexKey, indexValue := p.keyValueForIndex(p.indexOffset + 1) + + // If the query is specifying reverse iteration, then we must + // handle a few offset cases. + if p.reversed { + switch { + + // This indicates the default case, where no offset was + // specified. In that case we just start from the last + // entry. + case p.indexOffset == 0: + indexKey, indexValue = p.cursor.Last() + + // This indicates the offset being set to the very + // first entry. Since there are no entries before + // this offset, and the direction is reversed, we can + // return without adding any invoices to the response. + case p.indexOffset == 1: + return nil, nil + + // If we have been given an index offset that is beyond our last + // index value, we just return the last indexed value in our set + // since we are querying in reverse. We do not cover the case + // where our index offset equals our last index value, because + // index offset is exclusive, so we would want to start at the + // value before our last index. + case p.indexOffset > p.lastIndex(): + return p.cursor.Last() + + // Otherwise we have an index offset which is within our set of + // indexed keys, and we want to start at the item before our + // offset. We seek to our index offset, then return the element + // before it. We do this rather than p.indexOffset-1 to account + // for indexes that have gaps. + default: + p.keyValueForIndex(p.indexOffset) + indexKey, indexValue = p.cursor.Prev() + } + } + + return indexKey, indexValue +} + +// query gets the start point for our index offset and iterates through keys +// in our index until we reach the total number of items required for the query +// or we run out of cursor values. This function takes a fetchAndAppend function +// which is responsible for looking up the entry at that index, adding the entry +// to its set of return items (if desired) and return a boolean which indicates +// whether the item was added. This is required to allow the paginator to +// determine when the response has the maximum number of required items. +func (p paginator) query(fetchAndAppend func(k, v []byte) (bool, error)) error { + indexKey, indexValue := p.cursorStart() + + var totalItems int + for ; indexKey != nil; indexKey, indexValue = p.nextKey() { + // If our current return payload exceeds the max number + // of invoices, then we'll exit now. + if uint64(totalItems) >= p.totalItems { + break + } + + added, err := fetchAndAppend(indexKey, indexValue) + if err != nil { + return err + } + + // If we added an item to our set in the latest fetch and append + // we increment our total count. + if added { + totalItems++ + } + } + + return nil +} From 38624e861279fe3f5e278b2485398cdc4bacdce7 Mon Sep 17 00:00:00 2001 From: carla Date: Wed, 10 Jun 2020 12:34:28 +0200 Subject: [PATCH 5/6] channeldb: add fetchPaymentWithSequenceNumber lookup and test With our new index of sequence number to index, it is possible for more than one sequence number to point to the same hash because legacy lnd allowed duplicate payments under the same hash. We now store these payments in a nested bucket within the payments database. To allow lookup of the correct payment from an index, we require matching of the payment hash and sequence number. --- channeldb/payments.go | 97 +++++++++++++++++++++++++++++ channeldb/payments_test.go | 122 +++++++++++++++++++++++++++++++++++++ 2 files changed, 219 insertions(+) diff --git a/channeldb/payments.go b/channeldb/payments.go index 3229c093..73624902 100644 --- a/channeldb/payments.go +++ b/channeldb/payments.go @@ -104,6 +104,26 @@ var ( paymentsIndexBucket = []byte("payments-index-bucket") ) +var ( + // ErrNoSequenceNumber is returned if we lookup a payment which does + // not have a sequence number. + ErrNoSequenceNumber = errors.New("sequence number not found") + + // ErrDuplicateNotFound is returned when we lookup a payment by its + // index and cannot find a payment with a matching sequence number. + ErrDuplicateNotFound = errors.New("duplicate payment not found") + + // ErrNoDuplicateBucket is returned when we expect to find duplicates + // when looking up a payment from its index, but the payment does not + // have any. + ErrNoDuplicateBucket = errors.New("expected duplicate bucket") + + // ErrNoDuplicateNestedBucket is returned if we do not find duplicate + // payments in their own sub-bucket. + ErrNoDuplicateNestedBucket = errors.New("nested duplicate bucket not " + + "found") +) + // FailureReason encodes the reason a payment ultimately failed. type FailureReason byte @@ -568,6 +588,83 @@ func (db *DB) QueryPayments(query PaymentsQuery) (PaymentsResponse, error) { return resp, err } +// fetchPaymentWithSequenceNumber get the payment which matches the payment hash +// *and* sequence number provided from the database. This is required because +// we previously had more than one payment per hash, so we have multiple indexes +// pointing to a single payment; we want to retrieve the correct one. +func fetchPaymentWithSequenceNumber(tx kvdb.RTx, paymentHash lntypes.Hash, + sequenceNumber []byte) (*MPPayment, error) { + + // We can now lookup the payment keyed by its hash in + // the payments root bucket. + bucket, err := fetchPaymentBucket(tx, paymentHash) + if err != nil { + return nil, err + } + + // A single payment hash can have multiple payments associated with it. + // We lookup our sequence number first, to determine whether this is + // the payment we are actually looking for. + seqBytes := bucket.Get(paymentSequenceKey) + if seqBytes == nil { + return nil, ErrNoSequenceNumber + } + + // If this top level payment has the sequence number we are looking for, + // return it. + if bytes.Equal(seqBytes, sequenceNumber) { + return fetchPayment(bucket) + } + + // If we were not looking for the top level payment, we are looking for + // one of our duplicate payments. We need to iterate through the seq + // numbers in this bucket to find the correct payments. If we do not + // find a duplicate payments bucket here, something is wrong. + dup := bucket.NestedReadBucket(duplicatePaymentsBucket) + if dup == nil { + return nil, ErrNoDuplicateBucket + } + + var duplicatePayment *MPPayment + err = dup.ForEach(func(k, v []byte) error { + subBucket := dup.NestedReadBucket(k) + if subBucket == nil { + // We one bucket for each duplicate to be found. + return ErrNoDuplicateNestedBucket + } + + seqBytes := subBucket.Get(duplicatePaymentSequenceKey) + if seqBytes == nil { + return err + } + + // If this duplicate payment is not the sequence number we are + // looking for, we can continue. + if !bytes.Equal(seqBytes, sequenceNumber) { + return nil + } + + duplicatePayment, err = fetchDuplicatePayment(subBucket) + if err != nil { + return err + } + + return nil + }) + if err != nil { + return nil, err + } + + // If none of the duplicate payments matched our sequence number, we + // failed to find the payment with this sequence number; something is + // wrong. + if duplicatePayment == nil { + return nil, ErrDuplicateNotFound + } + + return duplicatePayment, nil +} + // DeletePayments deletes all completed and failed payments from the DB. func (db *DB) DeletePayments() error { return kvdb.Update(db, func(tx kvdb.RwTx) error { diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go index 706060a2..532003bb 100644 --- a/channeldb/payments_test.go +++ b/channeldb/payments_test.go @@ -453,6 +453,128 @@ func TestQueryPayments(t *testing.T) { } } +// TestFetchPaymentWithSequenceNumber tests lookup of payments with their +// sequence number. It sets up one payment with no duplicates, and another with +// two duplicates in its duplicates bucket then uses these payments to test the +// case where a specific duplicate is not found and the duplicates bucket is not +// present when we expect it to be. +func TestFetchPaymentWithSequenceNumber(t *testing.T) { + db, cleanup, err := makeTestDB() + require.NoError(t, err) + + defer cleanup() + + pControl := NewPaymentControl(db) + + // Generate a test payment which does not have duplicates. + noDuplicates, _, _, err := genInfo() + require.NoError(t, err) + + // Create a new payment entry in the database. + err = pControl.InitPayment(noDuplicates.PaymentHash, noDuplicates) + require.NoError(t, err) + + // Fetch the payment so we can get its sequence nr. + noDuplicatesPayment, err := pControl.FetchPayment( + noDuplicates.PaymentHash, + ) + require.NoError(t, err) + + // Generate a test payment which we will add duplicates to. + hasDuplicates, _, _, err := genInfo() + require.NoError(t, err) + + // Create a new payment entry in the database. + err = pControl.InitPayment(hasDuplicates.PaymentHash, hasDuplicates) + require.NoError(t, err) + + // Fetch the payment so we can get its sequence nr. + hasDuplicatesPayment, err := pControl.FetchPayment( + hasDuplicates.PaymentHash, + ) + require.NoError(t, err) + + // We declare the sequence numbers used here so that we can reference + // them in tests. + var ( + duplicateOneSeqNr = hasDuplicatesPayment.SequenceNum + 1 + duplicateTwoSeqNr = hasDuplicatesPayment.SequenceNum + 2 + ) + + // Add two duplicates to our second payment. + appendDuplicatePayment( + t, db, hasDuplicates.PaymentHash, duplicateOneSeqNr, + ) + appendDuplicatePayment( + t, db, hasDuplicates.PaymentHash, duplicateTwoSeqNr, + ) + + tests := []struct { + name string + paymentHash lntypes.Hash + sequenceNumber uint64 + expectedErr error + }{ + { + name: "lookup payment without duplicates", + paymentHash: noDuplicates.PaymentHash, + sequenceNumber: noDuplicatesPayment.SequenceNum, + expectedErr: nil, + }, + { + name: "lookup payment with duplicates", + paymentHash: hasDuplicates.PaymentHash, + sequenceNumber: hasDuplicatesPayment.SequenceNum, + expectedErr: nil, + }, + { + name: "lookup first duplicate", + paymentHash: hasDuplicates.PaymentHash, + sequenceNumber: duplicateOneSeqNr, + expectedErr: nil, + }, + { + name: "lookup second duplicate", + paymentHash: hasDuplicates.PaymentHash, + sequenceNumber: duplicateTwoSeqNr, + expectedErr: nil, + }, + { + name: "lookup non-existent duplicate", + paymentHash: hasDuplicates.PaymentHash, + sequenceNumber: 999999, + expectedErr: ErrDuplicateNotFound, + }, + { + name: "lookup duplicate, no duplicates bucket", + paymentHash: noDuplicates.PaymentHash, + sequenceNumber: duplicateTwoSeqNr, + expectedErr: ErrNoDuplicateBucket, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + err := kvdb.Update(db, + func(tx walletdb.ReadWriteTx) error { + + var seqNrBytes [8]byte + byteOrder.PutUint64( + seqNrBytes[:], test.sequenceNumber, + ) + + _, err := fetchPaymentWithSequenceNumber( + tx, test.paymentHash, seqNrBytes[:], + ) + return err + }) + require.Equal(t, test.expectedErr, err) + }) + } +} + // appendDuplicatePayment adds a duplicate payment to an existing payment. Note // that this function requires a unique sequence number. // From ab594ea57b52cb013cff78adc6a213b71c7940e0 Mon Sep 17 00:00:00 2001 From: carla Date: Wed, 10 Jun 2020 12:34:28 +0200 Subject: [PATCH 6/6] channeldb: update QueryPayments to use sequence nr index and paginator Use the new paginatior strcut for payments. Add some tests which will specifically test cases on and around the missing index we force in our test to ensure that we properly handle this case. We also add a sanity check in the test that checks that we can query when we have no payments. --- channeldb/payments.go | 121 ++++++++++++++++++++----------------- channeldb/payments_test.go | 63 +++++++++++++++++-- 2 files changed, 122 insertions(+), 62 deletions(-) diff --git a/channeldb/payments.go b/channeldb/payments.go index 73624902..5c2475bd 100644 --- a/channeldb/payments.go +++ b/channeldb/payments.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io" - "math" "sort" "time" @@ -511,64 +510,72 @@ type PaymentsResponse struct { func (db *DB) QueryPayments(query PaymentsQuery) (PaymentsResponse, error) { var resp PaymentsResponse - allPayments, err := db.FetchPayments() - if err != nil { + if err := kvdb.View(db, func(tx kvdb.RTx) error { + // Get the root payments bucket. + paymentsBucket := tx.ReadBucket(paymentsRootBucket) + if paymentsBucket == nil { + return nil + } + + // Get the index bucket which maps sequence number -> payment + // hash and duplicate bool. If we have a payments bucket, we + // should have an indexes bucket as well. + indexes := tx.ReadBucket(paymentsIndexBucket) + if indexes == nil { + return fmt.Errorf("index bucket does not exist") + } + + // accumulatePayments gets payments with the sequence number + // and hash provided and adds them to our list of payments if + // they meet the criteria of our query. It returns the number + // of payments that were added. + accumulatePayments := func(sequenceKey, hash []byte) (bool, + error) { + + r := bytes.NewReader(hash) + paymentHash, err := deserializePaymentIndex(r) + if err != nil { + return false, err + } + + payment, err := fetchPaymentWithSequenceNumber( + tx, paymentHash, sequenceKey, + ) + if err != nil { + return false, err + } + + // To keep compatibility with the old API, we only + // return non-succeeded payments if requested. + if payment.Status != StatusSucceeded && + !query.IncludeIncomplete { + + return false, err + } + + // At this point, we've exhausted the offset, so we'll + // begin collecting invoices found within the range. + resp.Payments = append(resp.Payments, payment) + return true, nil + } + + // Create a paginator which reads from our sequence index bucket + // with the parameters provided by the payments query. + paginator := newPaginator( + indexes.ReadCursor(), query.Reversed, query.IndexOffset, + query.MaxPayments, + ) + + // Run a paginated query, adding payments to our response. + if err := paginator.query(accumulatePayments); err != nil { + return err + } + + return nil + }); err != nil { return resp, err } - if len(allPayments) == 0 { - return resp, nil - } - - indexExclusiveLimit := query.IndexOffset - // In backward pagination, if the index limit is the default 0 value, - // we set our limit to maxint to include all payments from the highest - // sequence number on. - if query.Reversed && indexExclusiveLimit == 0 { - indexExclusiveLimit = math.MaxInt64 - } - - for i := range allPayments { - var payment *MPPayment - - // If we have the max number of payments we want, exit. - if uint64(len(resp.Payments)) == query.MaxPayments { - break - } - - if query.Reversed { - payment = allPayments[len(allPayments)-1-i] - - // In the reversed direction, skip over all payments - // that have sequence numbers greater than or equal to - // the index offset. We skip payments with equal index - // because the offset is exclusive. - if payment.SequenceNum >= indexExclusiveLimit { - continue - } - } else { - payment = allPayments[i] - - // In the forward direction, skip over all payments that - // have sequence numbers less than or equal to the index - // offset. We skip payments with equal indexes because - // the index offset is exclusive. - if payment.SequenceNum <= indexExclusiveLimit { - continue - } - } - - // To keep compatibility with the old API, we only return - // non-succeeded payments if requested. - if payment.Status != StatusSucceeded && - !query.IncludeIncomplete { - - continue - } - - resp.Payments = append(resp.Payments, payment) - } - // Need to swap the payments slice order if reversed order. if query.Reversed { for l, r := 0, len(resp.Payments)-1; l < r; l, r = l+1, r-1 { @@ -585,7 +592,7 @@ func (db *DB) QueryPayments(query PaymentsQuery) (PaymentsResponse, error) { resp.Payments[len(resp.Payments)-1].SequenceNum } - return resp, err + return resp, nil } // fetchPaymentWithSequenceNumber get the payment which matches the payment hash diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go index 532003bb..9e790c3e 100644 --- a/channeldb/payments_test.go +++ b/channeldb/payments_test.go @@ -165,18 +165,24 @@ func TestRouteSerialization(t *testing.T) { } // deletePayment removes a payment with paymentHash from the payments database. -func deletePayment(t *testing.T, db *DB, paymentHash lntypes.Hash) { +func deletePayment(t *testing.T, db *DB, paymentHash lntypes.Hash, seqNr uint64) { t.Helper() err := kvdb.Update(db, func(tx kvdb.RwTx) error { payments := tx.ReadWriteBucket(paymentsRootBucket) + // Delete the payment bucket. err := payments.DeleteNestedBucket(paymentHash[:]) if err != nil { return err } - return nil + key := make([]byte, 8) + byteOrder.PutUint64(key, seqNr) + + // Delete the index that references this payment. + indexes := tx.ReadWriteBucket(paymentsIndexBucket) + return indexes.Delete(key) }) if err != nil { @@ -350,6 +356,42 @@ func TestQueryPayments(t *testing.T) { lastIndex: 7, expectedSeqNrs: []uint64{3, 4, 5, 6, 7}, }, + { + name: "query payments reverse before index gap", + query: PaymentsQuery{ + IndexOffset: 3, + MaxPayments: 7, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 1, + expectedSeqNrs: []uint64{1}, + }, + { + name: "query payments reverse on index gap", + query: PaymentsQuery{ + IndexOffset: 2, + MaxPayments: 7, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 1, + expectedSeqNrs: []uint64{1}, + }, + { + name: "query payments forward on index gap", + query: PaymentsQuery{ + IndexOffset: 2, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 3, + lastIndex: 4, + expectedSeqNrs: []uint64{3, 4}, + }, } for _, tt := range tests { @@ -358,11 +400,16 @@ func TestQueryPayments(t *testing.T) { t.Parallel() db, cleanup, err := makeTestDB() - defer cleanup() - if err != nil { t.Fatalf("unable to init db: %v", err) } + defer cleanup() + + // Make a preliminary query to make sure it's ok to + // query when we have no payments. + resp, err := db.QueryPayments(tt.query) + require.NoError(t, err) + require.Len(t, resp.Payments, 0) // Populate the database with a set of test payments. // We create 6 original payments, deleting the payment @@ -391,7 +438,13 @@ func TestQueryPayments(t *testing.T) { // Immediately delete the payment with index 2. if i == 1 { - deletePayment(t, db, info.PaymentHash) + pmt, err := pControl.FetchPayment( + info.PaymentHash, + ) + require.NoError(t, err) + + deletePayment(t, db, info.PaymentHash, + pmt.SequenceNum) } // If we are on the last payment entry, add a