diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index 64e2dbe6..9d5aba36 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -622,9 +622,9 @@ func TestInvoiceAddTimeSeries(t *testing.T) { } } -// Tests that FetchAllInvoicesWithPaymentHash returns all invoices with their -// corresponding payment hashes. -func TestFetchAllInvoicesWithPaymentHash(t *testing.T) { +// TestScanInvoices tests that ScanInvoices scans trough all stored invoices +// correctly. +func TestScanInvoices(t *testing.T) { t.Parallel() db, cleanup, err := MakeTestDB() @@ -633,97 +633,54 @@ func TestFetchAllInvoicesWithPaymentHash(t *testing.T) { t.Fatalf("unable to make test db: %v", err) } - // With an empty DB we expect to return no error and an empty list. - empty, err := db.FetchAllInvoicesWithPaymentHash(false) - if err != nil { - t.Fatalf("failed to call FetchAllInvoicesWithPaymentHash on empty DB: %v", - err) + var invoices map[lntypes.Hash]*Invoice + callCount := 0 + resetCount := 0 + + // reset is used to reset/initialize results and is called once + // upon calling ScanInvoices and when the underlying transaction is + // retried. + reset := func() { + invoices = make(map[lntypes.Hash]*Invoice) + callCount = 0 + resetCount++ + } - if len(empty) != 0 { - t.Fatalf("expected empty list as a result, got: %v", empty) + scanFunc := func(paymentHash lntypes.Hash, invoice *Invoice) error { + invoices[paymentHash] = invoice + callCount++ + + return nil } - states := []ContractState{ - ContractOpen, ContractSettled, ContractCanceled, ContractAccepted, - } + // With an empty DB we expect to not scan any invoices. + require.NoError(t, db.ScanInvoices(scanFunc, reset)) + require.Equal(t, 0, len(invoices)) + require.Equal(t, 0, callCount) + require.Equal(t, 1, resetCount) - numInvoices := len(states) * 2 - testPendingInvoices := make(map[lntypes.Hash]*Invoice) - testAllInvoices := make(map[lntypes.Hash]*Invoice) + numInvoices := 5 + testInvoices := make(map[lntypes.Hash]*Invoice) // Now populate the DB and check if we can get all invoices with their // payment hashes as expected. for i := 1; i <= numInvoices; i++ { invoice, err := randInvoice(lnwire.MilliSatoshi(i)) - if err != nil { - t.Fatalf("unable to create invoice: %v", err) - } + require.NoError(t, err) - // Set the contract state of the next invoice such that there's an equal - // number for all possbile states. - invoice.State = states[i%len(states)] paymentHash := invoice.Terms.PaymentPreimage.Hash() + testInvoices[paymentHash] = invoice - if invoice.IsPending() { - testPendingInvoices[paymentHash] = invoice - } - - testAllInvoices[paymentHash] = invoice - - if _, err := db.AddInvoice(invoice, paymentHash); err != nil { - t.Fatalf("unable to add invoice: %v", err) - } - } - - pendingInvoices, err := db.FetchAllInvoicesWithPaymentHash(true) - if err != nil { - t.Fatalf("can't fetch invoices with payment hash: %v", err) - } - - if len(testPendingInvoices) != len(pendingInvoices) { - t.Fatalf("expected %v pending invoices, got: %v", - len(testPendingInvoices), len(pendingInvoices)) - } - - allInvoices, err := db.FetchAllInvoicesWithPaymentHash(false) - if err != nil { - t.Fatalf("can't fetch invoices with payment hash: %v", err) - } - - if len(testAllInvoices) != len(allInvoices) { - t.Fatalf("expected %v invoices, got: %v", - len(testAllInvoices), len(allInvoices)) - } - - for i := range pendingInvoices { - expected, ok := testPendingInvoices[pendingInvoices[i].PaymentHash] - if !ok { - t.Fatalf("coulnd't find invoice with hash: %v", - pendingInvoices[i].PaymentHash) - } - - // Zero out add index to not confuse require.Equal. - pendingInvoices[i].Invoice.AddIndex = 0 - expected.AddIndex = 0 - - require.Equal(t, *expected, pendingInvoices[i].Invoice) - } - - for i := range allInvoices { - expected, ok := testAllInvoices[allInvoices[i].PaymentHash] - if !ok { - t.Fatalf("coulnd't find invoice with hash: %v", - allInvoices[i].PaymentHash) - } - - // Zero out add index to not confuse require.Equal. - allInvoices[i].Invoice.AddIndex = 0 - expected.AddIndex = 0 - - require.Equal(t, *expected, allInvoices[i].Invoice) + _, err = db.AddInvoice(invoice, paymentHash) + require.NoError(t, err) } + resetCount = 0 + require.NoError(t, db.ScanInvoices(scanFunc, reset)) + require.Equal(t, numInvoices, callCount) + require.Equal(t, testInvoices, invoices) + require.Equal(t, 1, resetCount) } // TestDuplicateSettleInvoice tests that if we add a new invoice and settle it diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 436f194e..a7ece3c3 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -723,28 +723,21 @@ func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex kvdb.RBucket, } } -// InvoiceWithPaymentHash is used to store an invoice and its corresponding -// payment hash. This struct is only used to store results of -// ChannelDB.FetchAllInvoicesWithPaymentHash() call. -type InvoiceWithPaymentHash struct { - // Invoice holds the invoice as selected from the invoices bucket. - Invoice Invoice +// ScanInvoices scans trough all invoices and calls the passed scanFunc for +// for each invoice with its respective payment hash. Additionally a reset() +// closure is passed which is used to reset/initialize partial results and also +// to signal if the kvdb.View transaction has been retried. +func (d *DB) ScanInvoices( + scanFunc func(lntypes.Hash, *Invoice) error, reset func()) error { - // PaymentHash is the payment hash for the Invoice. - PaymentHash lntypes.Hash -} + return kvdb.View(d, func(tx kvdb.RTx) error { + // Reset partial results. As transaction commit success is not + // guaranteed when using etcd, we need to be prepared to redo + // the whole view transaction. In order to be able to do that + // we need a way to reset existing results. This is also done + // upon first run for initialization. + reset() -// FetchAllInvoicesWithPaymentHash returns all invoices and their payment hashes -// currently stored within the database. If the pendingOnly param is true, then -// only open or accepted invoices and their payment hashes will be returned, -// skipping all invoices that are fully settled or canceled. Note that the -// returned array is not ordered by add index. -func (d *DB) FetchAllInvoicesWithPaymentHash(pendingOnly bool) ( - []InvoiceWithPaymentHash, error) { - - var result []InvoiceWithPaymentHash - - err := kvdb.View(d, func(tx kvdb.RTx) error { invoices := tx.ReadBucket(invoiceBucket) if invoices == nil { return ErrNoInvoicesCreated @@ -775,26 +768,12 @@ func (d *DB) FetchAllInvoicesWithPaymentHash(pendingOnly bool) ( return err } - if pendingOnly && !invoice.IsPending() { - return nil - } + var paymentHash lntypes.Hash + copy(paymentHash[:], k) - invoiceWithPaymentHash := InvoiceWithPaymentHash{ - Invoice: invoice, - } - - copy(invoiceWithPaymentHash.PaymentHash[:], k) - result = append(result, invoiceWithPaymentHash) - - return nil + return scanFunc(paymentHash, &invoice) }) }) - - if err != nil { - return nil, err - } - - return result, nil } // InvoiceQuery represents a query to the invoice database. The query allows a diff --git a/invoices/invoice_expiry_watcher.go b/invoices/invoice_expiry_watcher.go index 9df6ca74..a46f27f5 100644 --- a/invoices/invoice_expiry_watcher.go +++ b/invoices/invoice_expiry_watcher.go @@ -129,14 +129,11 @@ func (ew *InvoiceExpiryWatcher) prepareInvoice( // AddInvoices adds multiple invoices to the InvoiceExpiryWatcher. func (ew *InvoiceExpiryWatcher) AddInvoices( - invoices []channeldb.InvoiceWithPaymentHash) { + invoices map[lntypes.Hash]*channeldb.Invoice) { invoicesWithExpiry := make([]*invoiceExpiry, 0, len(invoices)) - for _, invoiceWithPaymentHash := range invoices { - newInvoiceExpiry := ew.prepareInvoice( - invoiceWithPaymentHash.PaymentHash, - &invoiceWithPaymentHash.Invoice, - ) + for paymentHash, invoice := range invoices { + newInvoiceExpiry := ew.prepareInvoice(paymentHash, invoice) if newInvoiceExpiry != nil { invoicesWithExpiry = append( invoicesWithExpiry, newInvoiceExpiry, diff --git a/invoices/invoice_expiry_watcher_test.go b/invoices/invoice_expiry_watcher_test.go index 58d6e2d8..67ea2525 100644 --- a/invoices/invoice_expiry_watcher_test.go +++ b/invoices/invoice_expiry_watcher_test.go @@ -158,24 +158,14 @@ func TestInvoiceExpiryWhenAddingMultipleInvoices(t *testing.T) { t.Parallel() test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5) - var invoices []channeldb.InvoiceWithPaymentHash + invoices := make(map[lntypes.Hash]*channeldb.Invoice) for hash, invoice := range test.testData.expiredInvoices { - invoices = append(invoices, - channeldb.InvoiceWithPaymentHash{ - Invoice: *invoice, - PaymentHash: hash, - }, - ) + invoices[hash] = invoice } for hash, invoice := range test.testData.pendingInvoices { - invoices = append(invoices, - channeldb.InvoiceWithPaymentHash{ - Invoice: *invoice, - PaymentHash: hash, - }, - ) + invoices[hash] = invoice } test.watcher.AddInvoices(invoices) diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index 84d64617..66043ff0 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -147,21 +147,39 @@ func NewRegistry(cdb *channeldb.DB, expiryWatcher *InvoiceExpiryWatcher, } } -// populateExpiryWatcher fetches all active invoices and their corresponding -// payment hashes from ChannelDB and adds them to the expiry watcher. -func (i *InvoiceRegistry) populateExpiryWatcher() error { - pendingOnly := true - pendingInvoices, err := i.cdb.FetchAllInvoicesWithPaymentHash(pendingOnly) - if err != nil && err != channeldb.ErrNoInvoicesCreated { - log.Errorf( - "Error while prefetching active invoices from the database: %v", err, - ) +// scanInvoicesOnStart will scan all invoices on start and add active invoices +// to the invoice expiry watcher. +func (i *InvoiceRegistry) scanInvoicesOnStart() error { + var pending map[lntypes.Hash]*channeldb.Invoice + + reset := func() { + // Zero out our results on start and if the scan is ever run + // more than once. This latter case can happen if the kvdb + // layer needs to retry the View transaction underneath (eg. + // using the etcd driver, where all transactions are allowed + // to retry for serializability). + pending = make(map[lntypes.Hash]*channeldb.Invoice) + } + + scanFunc := func( + paymentHash lntypes.Hash, invoice *channeldb.Invoice) error { + + if invoice.IsPending() { + pending[paymentHash] = invoice + } + + return nil + } + + err := i.cdb.ScanInvoices(scanFunc, reset) + if err != nil { return err } log.Debugf("Adding %d pending invoices to the expiry watcher", - len(pendingInvoices)) - i.expiryWatcher.AddInvoices(pendingInvoices) + len(pending)) + i.expiryWatcher.AddInvoices(pending) + return nil } @@ -178,8 +196,9 @@ func (i *InvoiceRegistry) Start() error { i.wg.Add(1) go i.invoiceEventLoop() - // Now prefetch all pending invoices to the expiry watcher. - err = i.populateExpiryWatcher() + // Now scan all pending and removable invoices to the expiry watcher or + // delete them. + err = i.scanInvoicesOnStart() if err != nil { i.Stop() return err