From 92f3b0a30c5fba4e9a4e5ee32131f95e20d0684a Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Tue, 28 Jul 2020 21:22:23 +0200 Subject: [PATCH] channeldb+invoices: add ScanInvoices and integrate with InvoiceRegistry This commit adds channeldb.ScanInvoices to scan through all invoices in the database. The new call will also replace the already existing channeldb.FetchAllInvoicesWithPaymentHash call in preparation to collect invoices we'd like to delete and watch for expiry in one scan in later commits. --- channeldb/invoice_test.go | 115 ++++++++---------------- channeldb/invoices.go | 53 ++++------- invoices/invoice_expiry_watcher.go | 9 +- invoices/invoice_expiry_watcher_test.go | 16 +--- invoices/invoiceregistry.go | 45 +++++++--- 5 files changed, 90 insertions(+), 148 deletions(-) diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index 64e2dbe6..9d5aba36 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -622,9 +622,9 @@ func TestInvoiceAddTimeSeries(t *testing.T) { } } -// Tests that FetchAllInvoicesWithPaymentHash returns all invoices with their -// corresponding payment hashes. -func TestFetchAllInvoicesWithPaymentHash(t *testing.T) { +// TestScanInvoices tests that ScanInvoices scans trough all stored invoices +// correctly. +func TestScanInvoices(t *testing.T) { t.Parallel() db, cleanup, err := MakeTestDB() @@ -633,97 +633,54 @@ func TestFetchAllInvoicesWithPaymentHash(t *testing.T) { t.Fatalf("unable to make test db: %v", err) } - // With an empty DB we expect to return no error and an empty list. - empty, err := db.FetchAllInvoicesWithPaymentHash(false) - if err != nil { - t.Fatalf("failed to call FetchAllInvoicesWithPaymentHash on empty DB: %v", - err) + var invoices map[lntypes.Hash]*Invoice + callCount := 0 + resetCount := 0 + + // reset is used to reset/initialize results and is called once + // upon calling ScanInvoices and when the underlying transaction is + // retried. + reset := func() { + invoices = make(map[lntypes.Hash]*Invoice) + callCount = 0 + resetCount++ + } - if len(empty) != 0 { - t.Fatalf("expected empty list as a result, got: %v", empty) + scanFunc := func(paymentHash lntypes.Hash, invoice *Invoice) error { + invoices[paymentHash] = invoice + callCount++ + + return nil } - states := []ContractState{ - ContractOpen, ContractSettled, ContractCanceled, ContractAccepted, - } + // With an empty DB we expect to not scan any invoices. + require.NoError(t, db.ScanInvoices(scanFunc, reset)) + require.Equal(t, 0, len(invoices)) + require.Equal(t, 0, callCount) + require.Equal(t, 1, resetCount) - numInvoices := len(states) * 2 - testPendingInvoices := make(map[lntypes.Hash]*Invoice) - testAllInvoices := make(map[lntypes.Hash]*Invoice) + numInvoices := 5 + testInvoices := make(map[lntypes.Hash]*Invoice) // Now populate the DB and check if we can get all invoices with their // payment hashes as expected. for i := 1; i <= numInvoices; i++ { invoice, err := randInvoice(lnwire.MilliSatoshi(i)) - if err != nil { - t.Fatalf("unable to create invoice: %v", err) - } + require.NoError(t, err) - // Set the contract state of the next invoice such that there's an equal - // number for all possbile states. - invoice.State = states[i%len(states)] paymentHash := invoice.Terms.PaymentPreimage.Hash() + testInvoices[paymentHash] = invoice - if invoice.IsPending() { - testPendingInvoices[paymentHash] = invoice - } - - testAllInvoices[paymentHash] = invoice - - if _, err := db.AddInvoice(invoice, paymentHash); err != nil { - t.Fatalf("unable to add invoice: %v", err) - } - } - - pendingInvoices, err := db.FetchAllInvoicesWithPaymentHash(true) - if err != nil { - t.Fatalf("can't fetch invoices with payment hash: %v", err) - } - - if len(testPendingInvoices) != len(pendingInvoices) { - t.Fatalf("expected %v pending invoices, got: %v", - len(testPendingInvoices), len(pendingInvoices)) - } - - allInvoices, err := db.FetchAllInvoicesWithPaymentHash(false) - if err != nil { - t.Fatalf("can't fetch invoices with payment hash: %v", err) - } - - if len(testAllInvoices) != len(allInvoices) { - t.Fatalf("expected %v invoices, got: %v", - len(testAllInvoices), len(allInvoices)) - } - - for i := range pendingInvoices { - expected, ok := testPendingInvoices[pendingInvoices[i].PaymentHash] - if !ok { - t.Fatalf("coulnd't find invoice with hash: %v", - pendingInvoices[i].PaymentHash) - } - - // Zero out add index to not confuse require.Equal. - pendingInvoices[i].Invoice.AddIndex = 0 - expected.AddIndex = 0 - - require.Equal(t, *expected, pendingInvoices[i].Invoice) - } - - for i := range allInvoices { - expected, ok := testAllInvoices[allInvoices[i].PaymentHash] - if !ok { - t.Fatalf("coulnd't find invoice with hash: %v", - allInvoices[i].PaymentHash) - } - - // Zero out add index to not confuse require.Equal. - allInvoices[i].Invoice.AddIndex = 0 - expected.AddIndex = 0 - - require.Equal(t, *expected, allInvoices[i].Invoice) + _, err = db.AddInvoice(invoice, paymentHash) + require.NoError(t, err) } + resetCount = 0 + require.NoError(t, db.ScanInvoices(scanFunc, reset)) + require.Equal(t, numInvoices, callCount) + require.Equal(t, testInvoices, invoices) + require.Equal(t, 1, resetCount) } // TestDuplicateSettleInvoice tests that if we add a new invoice and settle it diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 436f194e..a7ece3c3 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -723,28 +723,21 @@ func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex kvdb.RBucket, } } -// InvoiceWithPaymentHash is used to store an invoice and its corresponding -// payment hash. This struct is only used to store results of -// ChannelDB.FetchAllInvoicesWithPaymentHash() call. -type InvoiceWithPaymentHash struct { - // Invoice holds the invoice as selected from the invoices bucket. - Invoice Invoice +// ScanInvoices scans trough all invoices and calls the passed scanFunc for +// for each invoice with its respective payment hash. Additionally a reset() +// closure is passed which is used to reset/initialize partial results and also +// to signal if the kvdb.View transaction has been retried. +func (d *DB) ScanInvoices( + scanFunc func(lntypes.Hash, *Invoice) error, reset func()) error { - // PaymentHash is the payment hash for the Invoice. - PaymentHash lntypes.Hash -} + return kvdb.View(d, func(tx kvdb.RTx) error { + // Reset partial results. As transaction commit success is not + // guaranteed when using etcd, we need to be prepared to redo + // the whole view transaction. In order to be able to do that + // we need a way to reset existing results. This is also done + // upon first run for initialization. + reset() -// FetchAllInvoicesWithPaymentHash returns all invoices and their payment hashes -// currently stored within the database. If the pendingOnly param is true, then -// only open or accepted invoices and their payment hashes will be returned, -// skipping all invoices that are fully settled or canceled. Note that the -// returned array is not ordered by add index. -func (d *DB) FetchAllInvoicesWithPaymentHash(pendingOnly bool) ( - []InvoiceWithPaymentHash, error) { - - var result []InvoiceWithPaymentHash - - err := kvdb.View(d, func(tx kvdb.RTx) error { invoices := tx.ReadBucket(invoiceBucket) if invoices == nil { return ErrNoInvoicesCreated @@ -775,26 +768,12 @@ func (d *DB) FetchAllInvoicesWithPaymentHash(pendingOnly bool) ( return err } - if pendingOnly && !invoice.IsPending() { - return nil - } + var paymentHash lntypes.Hash + copy(paymentHash[:], k) - invoiceWithPaymentHash := InvoiceWithPaymentHash{ - Invoice: invoice, - } - - copy(invoiceWithPaymentHash.PaymentHash[:], k) - result = append(result, invoiceWithPaymentHash) - - return nil + return scanFunc(paymentHash, &invoice) }) }) - - if err != nil { - return nil, err - } - - return result, nil } // InvoiceQuery represents a query to the invoice database. The query allows a diff --git a/invoices/invoice_expiry_watcher.go b/invoices/invoice_expiry_watcher.go index 9df6ca74..a46f27f5 100644 --- a/invoices/invoice_expiry_watcher.go +++ b/invoices/invoice_expiry_watcher.go @@ -129,14 +129,11 @@ func (ew *InvoiceExpiryWatcher) prepareInvoice( // AddInvoices adds multiple invoices to the InvoiceExpiryWatcher. func (ew *InvoiceExpiryWatcher) AddInvoices( - invoices []channeldb.InvoiceWithPaymentHash) { + invoices map[lntypes.Hash]*channeldb.Invoice) { invoicesWithExpiry := make([]*invoiceExpiry, 0, len(invoices)) - for _, invoiceWithPaymentHash := range invoices { - newInvoiceExpiry := ew.prepareInvoice( - invoiceWithPaymentHash.PaymentHash, - &invoiceWithPaymentHash.Invoice, - ) + for paymentHash, invoice := range invoices { + newInvoiceExpiry := ew.prepareInvoice(paymentHash, invoice) if newInvoiceExpiry != nil { invoicesWithExpiry = append( invoicesWithExpiry, newInvoiceExpiry, diff --git a/invoices/invoice_expiry_watcher_test.go b/invoices/invoice_expiry_watcher_test.go index 58d6e2d8..67ea2525 100644 --- a/invoices/invoice_expiry_watcher_test.go +++ b/invoices/invoice_expiry_watcher_test.go @@ -158,24 +158,14 @@ func TestInvoiceExpiryWhenAddingMultipleInvoices(t *testing.T) { t.Parallel() test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5) - var invoices []channeldb.InvoiceWithPaymentHash + invoices := make(map[lntypes.Hash]*channeldb.Invoice) for hash, invoice := range test.testData.expiredInvoices { - invoices = append(invoices, - channeldb.InvoiceWithPaymentHash{ - Invoice: *invoice, - PaymentHash: hash, - }, - ) + invoices[hash] = invoice } for hash, invoice := range test.testData.pendingInvoices { - invoices = append(invoices, - channeldb.InvoiceWithPaymentHash{ - Invoice: *invoice, - PaymentHash: hash, - }, - ) + invoices[hash] = invoice } test.watcher.AddInvoices(invoices) diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index 84d64617..66043ff0 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -147,21 +147,39 @@ func NewRegistry(cdb *channeldb.DB, expiryWatcher *InvoiceExpiryWatcher, } } -// populateExpiryWatcher fetches all active invoices and their corresponding -// payment hashes from ChannelDB and adds them to the expiry watcher. -func (i *InvoiceRegistry) populateExpiryWatcher() error { - pendingOnly := true - pendingInvoices, err := i.cdb.FetchAllInvoicesWithPaymentHash(pendingOnly) - if err != nil && err != channeldb.ErrNoInvoicesCreated { - log.Errorf( - "Error while prefetching active invoices from the database: %v", err, - ) +// scanInvoicesOnStart will scan all invoices on start and add active invoices +// to the invoice expiry watcher. +func (i *InvoiceRegistry) scanInvoicesOnStart() error { + var pending map[lntypes.Hash]*channeldb.Invoice + + reset := func() { + // Zero out our results on start and if the scan is ever run + // more than once. This latter case can happen if the kvdb + // layer needs to retry the View transaction underneath (eg. + // using the etcd driver, where all transactions are allowed + // to retry for serializability). + pending = make(map[lntypes.Hash]*channeldb.Invoice) + } + + scanFunc := func( + paymentHash lntypes.Hash, invoice *channeldb.Invoice) error { + + if invoice.IsPending() { + pending[paymentHash] = invoice + } + + return nil + } + + err := i.cdb.ScanInvoices(scanFunc, reset) + if err != nil { return err } log.Debugf("Adding %d pending invoices to the expiry watcher", - len(pendingInvoices)) - i.expiryWatcher.AddInvoices(pendingInvoices) + len(pending)) + i.expiryWatcher.AddInvoices(pending) + return nil } @@ -178,8 +196,9 @@ func (i *InvoiceRegistry) Start() error { i.wg.Add(1) go i.invoiceEventLoop() - // Now prefetch all pending invoices to the expiry watcher. - err = i.populateExpiryWatcher() + // Now scan all pending and removable invoices to the expiry watcher or + // delete them. + err = i.scanInvoicesOnStart() if err != nil { i.Stop() return err