diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index 64e2dbe6..bb118f71 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -622,9 +622,9 @@ func TestInvoiceAddTimeSeries(t *testing.T) { } } -// Tests that FetchAllInvoicesWithPaymentHash returns all invoices with their -// corresponding payment hashes. -func TestFetchAllInvoicesWithPaymentHash(t *testing.T) { +// TestScanInvoices tests that ScanInvoices scans trough all stored invoices +// correctly. +func TestScanInvoices(t *testing.T) { t.Parallel() db, cleanup, err := MakeTestDB() @@ -633,97 +633,54 @@ func TestFetchAllInvoicesWithPaymentHash(t *testing.T) { t.Fatalf("unable to make test db: %v", err) } - // With an empty DB we expect to return no error and an empty list. - empty, err := db.FetchAllInvoicesWithPaymentHash(false) - if err != nil { - t.Fatalf("failed to call FetchAllInvoicesWithPaymentHash on empty DB: %v", - err) + var invoices map[lntypes.Hash]*Invoice + callCount := 0 + resetCount := 0 + + // reset is used to reset/initialize results and is called once + // upon calling ScanInvoices and when the underlying transaction is + // retried. + reset := func() { + invoices = make(map[lntypes.Hash]*Invoice) + callCount = 0 + resetCount++ + } - if len(empty) != 0 { - t.Fatalf("expected empty list as a result, got: %v", empty) + scanFunc := func(paymentHash lntypes.Hash, invoice *Invoice) error { + invoices[paymentHash] = invoice + callCount++ + + return nil } - states := []ContractState{ - ContractOpen, ContractSettled, ContractCanceled, ContractAccepted, - } + // With an empty DB we expect to not scan any invoices. + require.NoError(t, db.ScanInvoices(scanFunc, reset)) + require.Equal(t, 0, len(invoices)) + require.Equal(t, 0, callCount) + require.Equal(t, 1, resetCount) - numInvoices := len(states) * 2 - testPendingInvoices := make(map[lntypes.Hash]*Invoice) - testAllInvoices := make(map[lntypes.Hash]*Invoice) + numInvoices := 5 + testInvoices := make(map[lntypes.Hash]*Invoice) // Now populate the DB and check if we can get all invoices with their // payment hashes as expected. for i := 1; i <= numInvoices; i++ { invoice, err := randInvoice(lnwire.MilliSatoshi(i)) - if err != nil { - t.Fatalf("unable to create invoice: %v", err) - } + require.NoError(t, err) - // Set the contract state of the next invoice such that there's an equal - // number for all possbile states. - invoice.State = states[i%len(states)] paymentHash := invoice.Terms.PaymentPreimage.Hash() + testInvoices[paymentHash] = invoice - if invoice.IsPending() { - testPendingInvoices[paymentHash] = invoice - } - - testAllInvoices[paymentHash] = invoice - - if _, err := db.AddInvoice(invoice, paymentHash); err != nil { - t.Fatalf("unable to add invoice: %v", err) - } - } - - pendingInvoices, err := db.FetchAllInvoicesWithPaymentHash(true) - if err != nil { - t.Fatalf("can't fetch invoices with payment hash: %v", err) - } - - if len(testPendingInvoices) != len(pendingInvoices) { - t.Fatalf("expected %v pending invoices, got: %v", - len(testPendingInvoices), len(pendingInvoices)) - } - - allInvoices, err := db.FetchAllInvoicesWithPaymentHash(false) - if err != nil { - t.Fatalf("can't fetch invoices with payment hash: %v", err) - } - - if len(testAllInvoices) != len(allInvoices) { - t.Fatalf("expected %v invoices, got: %v", - len(testAllInvoices), len(allInvoices)) - } - - for i := range pendingInvoices { - expected, ok := testPendingInvoices[pendingInvoices[i].PaymentHash] - if !ok { - t.Fatalf("coulnd't find invoice with hash: %v", - pendingInvoices[i].PaymentHash) - } - - // Zero out add index to not confuse require.Equal. - pendingInvoices[i].Invoice.AddIndex = 0 - expected.AddIndex = 0 - - require.Equal(t, *expected, pendingInvoices[i].Invoice) - } - - for i := range allInvoices { - expected, ok := testAllInvoices[allInvoices[i].PaymentHash] - if !ok { - t.Fatalf("coulnd't find invoice with hash: %v", - allInvoices[i].PaymentHash) - } - - // Zero out add index to not confuse require.Equal. - allInvoices[i].Invoice.AddIndex = 0 - expected.AddIndex = 0 - - require.Equal(t, *expected, allInvoices[i].Invoice) + _, err = db.AddInvoice(invoice, paymentHash) + require.NoError(t, err) } + resetCount = 0 + require.NoError(t, db.ScanInvoices(scanFunc, reset)) + require.Equal(t, numInvoices, callCount) + require.Equal(t, testInvoices, invoices) + require.Equal(t, 1, resetCount) } // TestDuplicateSettleInvoice tests that if we add a new invoice and settle it @@ -1194,3 +1151,96 @@ func TestInvoiceRef(t *testing.T) { require.Equal(t, payHash, refByHashAndAddr.PayHash()) require.Equal(t, &payAddr, refByHashAndAddr.PayAddr()) } + +// TestDeleteInvoices tests that deleting a list of invoices will succeed +// if all delete references are valid, or will fail otherwise. +func TestDeleteInvoices(t *testing.T) { + t.Parallel() + + db, cleanup, err := MakeTestDB() + defer cleanup() + require.NoError(t, err, "unable to make test db") + + // Add some invoices to the test db. + numInvoices := 3 + invoicesToDelete := make([]InvoiceDeleteRef, numInvoices) + + for i := 0; i < numInvoices; i++ { + invoice, err := randInvoice(lnwire.MilliSatoshi(i + 1)) + require.NoError(t, err) + + paymentHash := invoice.Terms.PaymentPreimage.Hash() + addIndex, err := db.AddInvoice(invoice, paymentHash) + require.NoError(t, err) + + // Settle the second invoice. + if i == 1 { + invoice, err = db.UpdateInvoice( + InvoiceRefByHash(paymentHash), + getUpdateInvoice(invoice.Terms.Value), + ) + require.NoError(t, err, "unable to settle invoice") + } + + // store the delete ref for later. + invoicesToDelete[i] = InvoiceDeleteRef{ + PayHash: paymentHash, + PayAddr: &invoice.Terms.PaymentAddr, + AddIndex: addIndex, + SettleIndex: invoice.SettleIndex, + } + } + + // assertInvoiceCount asserts that the number of invoices equals + // to the passed count. + assertInvoiceCount := func(count int) { + // Query to collect all invoices. + query := InvoiceQuery{ + IndexOffset: 0, + NumMaxInvoices: math.MaxUint64, + } + + // Check that we really have 3 invoices. + response, err := db.QueryInvoices(query) + require.NoError(t, err) + require.Equal(t, count, len(response.Invoices)) + } + + // XOR one byte of one of the references' hash and attempt to delete. + invoicesToDelete[0].PayHash[2] ^= 3 + require.Error(t, db.DeleteInvoice(invoicesToDelete)) + assertInvoiceCount(3) + + // Restore the hash. + invoicesToDelete[0].PayHash[2] ^= 3 + + // XOR one byte of one of the references' payment address and attempt + // to delete. + invoicesToDelete[1].PayAddr[5] ^= 7 + require.Error(t, db.DeleteInvoice(invoicesToDelete)) + assertInvoiceCount(3) + + // Restore the payment address. + invoicesToDelete[1].PayAddr[5] ^= 7 + + // XOR the second invoice's payment settle index as it is settled, and + // attempt to delete. + invoicesToDelete[1].SettleIndex ^= 11 + require.Error(t, db.DeleteInvoice(invoicesToDelete)) + assertInvoiceCount(3) + + // Restore the settle index. + invoicesToDelete[1].SettleIndex ^= 11 + + // XOR the add index for one of the references and attempt to delete. + invoicesToDelete[2].AddIndex ^= 13 + require.Error(t, db.DeleteInvoice(invoicesToDelete)) + assertInvoiceCount(3) + + // Restore the add index. + invoicesToDelete[2].AddIndex ^= 13 + + // Delete should succeed with all the valid references. + require.NoError(t, db.DeleteInvoice(invoicesToDelete)) + assertInvoiceCount(0) +} diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 436f194e..5f7b6462 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -723,28 +723,21 @@ func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex kvdb.RBucket, } } -// InvoiceWithPaymentHash is used to store an invoice and its corresponding -// payment hash. This struct is only used to store results of -// ChannelDB.FetchAllInvoicesWithPaymentHash() call. -type InvoiceWithPaymentHash struct { - // Invoice holds the invoice as selected from the invoices bucket. - Invoice Invoice +// ScanInvoices scans trough all invoices and calls the passed scanFunc for +// for each invoice with its respective payment hash. Additionally a reset() +// closure is passed which is used to reset/initialize partial results and also +// to signal if the kvdb.View transaction has been retried. +func (d *DB) ScanInvoices( + scanFunc func(lntypes.Hash, *Invoice) error, reset func()) error { - // PaymentHash is the payment hash for the Invoice. - PaymentHash lntypes.Hash -} + return kvdb.View(d, func(tx kvdb.RTx) error { + // Reset partial results. As transaction commit success is not + // guaranteed when using etcd, we need to be prepared to redo + // the whole view transaction. In order to be able to do that + // we need a way to reset existing results. This is also done + // upon first run for initialization. + reset() -// FetchAllInvoicesWithPaymentHash returns all invoices and their payment hashes -// currently stored within the database. If the pendingOnly param is true, then -// only open or accepted invoices and their payment hashes will be returned, -// skipping all invoices that are fully settled or canceled. Note that the -// returned array is not ordered by add index. -func (d *DB) FetchAllInvoicesWithPaymentHash(pendingOnly bool) ( - []InvoiceWithPaymentHash, error) { - - var result []InvoiceWithPaymentHash - - err := kvdb.View(d, func(tx kvdb.RTx) error { invoices := tx.ReadBucket(invoiceBucket) if invoices == nil { return ErrNoInvoicesCreated @@ -775,26 +768,12 @@ func (d *DB) FetchAllInvoicesWithPaymentHash(pendingOnly bool) ( return err } - if pendingOnly && !invoice.IsPending() { - return nil - } + var paymentHash lntypes.Hash + copy(paymentHash[:], k) - invoiceWithPaymentHash := InvoiceWithPaymentHash{ - Invoice: invoice, - } - - copy(invoiceWithPaymentHash.PaymentHash[:], k) - result = append(result, invoiceWithPaymentHash) - - return nil + return scanFunc(paymentHash, &invoice) }) }) - - if err != nil { - return nil, err - } - - return result, nil } // InvoiceQuery represents a query to the invoice database. The query allows a @@ -1761,3 +1740,134 @@ func setSettleMetaFields(settleIndex kvdb.RwBucket, invoiceNum []byte, return nil } + +// InvoiceDeleteRef holds a refererence to an invoice to be deleted. +type InvoiceDeleteRef struct { + // PayHash is the payment hash of the target invoice. All invoices are + // currently indexed by payment hash. + PayHash lntypes.Hash + + // PayAddr is the payment addr of the target invoice. Newer invoices + // (0.11 and up) are indexed by payment address in addition to payment + // hash, but pre 0.8 invoices do not have one at all. + PayAddr *[32]byte + + // AddIndex is the add index of the invoice. + AddIndex uint64 + + // SettleIndex is the settle index of the invoice. + SettleIndex uint64 +} + +// DeleteInvoice attempts to delete the passed invoices from the database in +// one transaction. The passed delete references hold all keys required to +// delete the invoices without also needing to deserialze them. +func (d *DB) DeleteInvoice(invoicesToDelete []InvoiceDeleteRef) error { + err := kvdb.Update(d, func(tx kvdb.RwTx) error { + invoices := tx.ReadWriteBucket(invoiceBucket) + if invoices == nil { + return ErrNoInvoicesCreated + } + + invoiceIndex := invoices.NestedReadWriteBucket( + invoiceIndexBucket, + ) + if invoiceIndex == nil { + return ErrNoInvoicesCreated + } + + invoiceAddIndex := invoices.NestedReadWriteBucket( + addIndexBucket, + ) + if invoiceAddIndex == nil { + return ErrNoInvoicesCreated + } + // settleIndex can be nil, as the bucket is created lazily + // when the first invoice is settled. + settleIndex := invoices.NestedReadWriteBucket(settleIndexBucket) + + payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket) + + for _, ref := range invoicesToDelete { + // Fetch the invoice key for using it to check for + // consistency and also to delete from the invoice index. + invoiceKey := invoiceIndex.Get(ref.PayHash[:]) + if invoiceKey == nil { + return ErrInvoiceNotFound + } + + err := invoiceIndex.Delete(ref.PayHash[:]) + if err != nil { + return err + } + + // Delete payment address index reference if there's a + // valid payment address passed. + if ref.PayAddr != nil { + // To ensure consistency check that the already + // fetched invoice key matches the one in the + // payment address index. + key := payAddrIndex.Get(ref.PayAddr[:]) + if !bytes.Equal(key, invoiceKey) { + return fmt.Errorf("unknown invoice") + } + + // Delete from the payment address index. + err := payAddrIndex.Delete(ref.PayAddr[:]) + if err != nil { + return err + } + } + + var addIndexKey [8]byte + byteOrder.PutUint64(addIndexKey[:], ref.AddIndex) + + // To ensure consistency check that the key stored in + // the add index also matches the previously fetched + // invoice key. + key := invoiceAddIndex.Get(addIndexKey[:]) + if !bytes.Equal(key, invoiceKey) { + return fmt.Errorf("unknown invoice") + } + + // Remove from the add index. + err = invoiceAddIndex.Delete(addIndexKey[:]) + if err != nil { + return err + } + + // Remove from the settle index if available and + // if the invoice is settled. + if settleIndex != nil && ref.SettleIndex > 0 { + var settleIndexKey [8]byte + byteOrder.PutUint64( + settleIndexKey[:], ref.SettleIndex, + ) + + // To ensure consistency check that the already + // fetched invoice key matches the one in the + // settle index + key := settleIndex.Get(settleIndexKey[:]) + if !bytes.Equal(key, invoiceKey) { + return fmt.Errorf("unknown invoice") + } + + err = settleIndex.Delete(settleIndexKey[:]) + if err != nil { + return err + } + } + + // Finally remove the serialized invoice from the + // invoice bucket. + err = invoices.Delete(invoiceKey) + if err != nil { + return err + } + } + + return nil + }) + + return err +} diff --git a/config.go b/config.go index 78f76589..06f74979 100644 --- a/config.go +++ b/config.go @@ -245,6 +245,10 @@ type Config struct { KeysendHoldTime time.Duration `long:"keysend-hold-time" description:"If non-zero, keysend payments are accepted but not immediately settled. If the payment isn't settled manually after the specified time, it is canceled automatically. [experimental]"` + GcCanceledInvoicesOnStartup bool `long:"gc-canceled-invoices-on-startup" description:"If true, we'll attempt to garbage collect canceled invoices upon start."` + + GcCanceledInvoicesOnTheFly bool `long:"gc-canceled-invoices-on-the-fly" description:"If true, we'll delete newly canceled invoices on the fly."` + Routing *routing.Conf `group:"routing" namespace:"routing"` Workers *lncfg.Workers `group:"workers" namespace:"workers"` diff --git a/invoices/invoice_expiry_watcher.go b/invoices/invoice_expiry_watcher.go index f0db08d1..a46f27f5 100644 --- a/invoices/invoice_expiry_watcher.go +++ b/invoices/invoice_expiry_watcher.go @@ -48,8 +48,8 @@ type InvoiceExpiryWatcher struct { // invoice to expire. expiryQueue queue.PriorityQueue - // newInvoices channel is used to wake up the main loop when a new invoices - // is added. + // newInvoices channel is used to wake up the main loop when a new + // invoices is added. newInvoices chan []*invoiceExpiry wg sync.WaitGroup @@ -109,7 +109,8 @@ func (ew *InvoiceExpiryWatcher) prepareInvoice( paymentHash lntypes.Hash, invoice *channeldb.Invoice) *invoiceExpiry { if invoice.State != channeldb.ContractOpen { - log.Debugf("Invoice not added to expiry watcher: %v", paymentHash) + log.Debugf("Invoice not added to expiry watcher: %v", + paymentHash) return nil } @@ -128,15 +129,15 @@ func (ew *InvoiceExpiryWatcher) prepareInvoice( // AddInvoices adds multiple invoices to the InvoiceExpiryWatcher. func (ew *InvoiceExpiryWatcher) AddInvoices( - invoices []channeldb.InvoiceWithPaymentHash) { + invoices map[lntypes.Hash]*channeldb.Invoice) { invoicesWithExpiry := make([]*invoiceExpiry, 0, len(invoices)) - for _, invoiceWithPaymentHash := range invoices { - newInvoiceExpiry := ew.prepareInvoice( - invoiceWithPaymentHash.PaymentHash, &invoiceWithPaymentHash.Invoice, - ) + for paymentHash, invoice := range invoices { + newInvoiceExpiry := ew.prepareInvoice(paymentHash, invoice) if newInvoiceExpiry != nil { - invoicesWithExpiry = append(invoicesWithExpiry, newInvoiceExpiry) + invoicesWithExpiry = append( + invoicesWithExpiry, newInvoiceExpiry, + ) } } @@ -160,8 +161,8 @@ func (ew *InvoiceExpiryWatcher) AddInvoice( newInvoiceExpiry := ew.prepareInvoice(paymentHash, invoice) if newInvoiceExpiry != nil { - log.Debugf("Adding invoice '%v' to expiry watcher, expiration: %v", - paymentHash, newInvoiceExpiry.Expiry) + log.Debugf("Adding invoice '%v' to expiry watcher,"+ + "expiration: %v", paymentHash, newInvoiceExpiry.Expiry) select { case ew.newInvoices <- []*invoiceExpiry{newInvoiceExpiry}: @@ -202,7 +203,8 @@ func (ew *InvoiceExpiryWatcher) cancelNextExpiredInvoice() { if err != nil && err != channeldb.ErrInvoiceAlreadySettled && err != channeldb.ErrInvoiceAlreadyCanceled { - log.Errorf("Unable to cancel invoice: %v", top.PaymentHash) + log.Errorf("Unable to cancel invoice: %v", + top.PaymentHash) } ew.expiryQueue.Pop() @@ -236,8 +238,8 @@ func (ew *InvoiceExpiryWatcher) mainLoop() { continue case invoicesWithExpiry := <-ew.newInvoices: - for _, invoiceWithExpiry := range invoicesWithExpiry { - ew.expiryQueue.Push(invoiceWithExpiry) + for _, invoice := range invoicesWithExpiry { + ew.expiryQueue.Push(invoice) } case <-ew.quit: diff --git a/invoices/invoice_expiry_watcher_test.go b/invoices/invoice_expiry_watcher_test.go index 2aa0f87b..67ea2525 100644 --- a/invoices/invoice_expiry_watcher_test.go +++ b/invoices/invoice_expiry_watcher_test.go @@ -37,7 +37,9 @@ func newInvoiceExpiryWatcherTest(t *testing.T, now time.Time, err := test.watcher.Start(func(paymentHash lntypes.Hash, force bool) error { - test.canceledInvoices = append(test.canceledInvoices, paymentHash) + test.canceledInvoices = append( + test.canceledInvoices, paymentHash, + ) test.wg.Done() return nil }) @@ -70,7 +72,8 @@ func (t *invoiceExpiryWatcherTest) checkExpectations() { // that expired. if len(t.canceledInvoices) != len(t.testData.expiredInvoices) { t.t.Fatalf("expected %v cancellations, got %v", - len(t.testData.expiredInvoices), len(t.canceledInvoices)) + len(t.testData.expiredInvoices), + len(t.canceledInvoices)) } for i := range t.canceledInvoices { @@ -155,24 +158,14 @@ func TestInvoiceExpiryWhenAddingMultipleInvoices(t *testing.T) { t.Parallel() test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5) - var invoices []channeldb.InvoiceWithPaymentHash + invoices := make(map[lntypes.Hash]*channeldb.Invoice) for hash, invoice := range test.testData.expiredInvoices { - invoices = append(invoices, - channeldb.InvoiceWithPaymentHash{ - Invoice: *invoice, - PaymentHash: hash, - }, - ) + invoices[hash] = invoice } for hash, invoice := range test.testData.pendingInvoices { - invoices = append(invoices, - channeldb.InvoiceWithPaymentHash{ - Invoice: *invoice, - PaymentHash: hash, - }, - ) + invoices[hash] = invoice } test.watcher.AddInvoices(invoices) diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index 84d64617..c827f144 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -57,6 +57,14 @@ type RegistryConfig struct { // send payments. AcceptKeySend bool + // GcCanceledInvoicesOnStartup if set, we'll attempt to garbage collect + // all canceled invoices upon start. + GcCanceledInvoicesOnStartup bool + + // GcCanceledInvoicesOnTheFly if set, we'll garbage collect all newly + // canceled invoices on the fly. + GcCanceledInvoicesOnTheFly bool + // KeysendHoldTime indicates for how long we want to accept and hold // spontaneous keysend payments. KeysendHoldTime time.Duration @@ -147,21 +155,65 @@ func NewRegistry(cdb *channeldb.DB, expiryWatcher *InvoiceExpiryWatcher, } } -// populateExpiryWatcher fetches all active invoices and their corresponding -// payment hashes from ChannelDB and adds them to the expiry watcher. -func (i *InvoiceRegistry) populateExpiryWatcher() error { - pendingOnly := true - pendingInvoices, err := i.cdb.FetchAllInvoicesWithPaymentHash(pendingOnly) - if err != nil && err != channeldb.ErrNoInvoicesCreated { - log.Errorf( - "Error while prefetching active invoices from the database: %v", err, - ) +// scanInvoicesOnStart will scan all invoices on start and add active invoices +// to the invoice expirt watcher while also attempting to delete all canceled +// invoices. +func (i *InvoiceRegistry) scanInvoicesOnStart() error { + var ( + pending map[lntypes.Hash]*channeldb.Invoice + removable []channeldb.InvoiceDeleteRef + ) + + reset := func() { + // Zero out our results on start and if the scan is ever run + // more than once. This latter case can happen if the kvdb + // layer needs to retry the View transaction underneath (eg. + // using the etcd driver, where all transactions are allowed + // to retry for serializability). + pending = make(map[lntypes.Hash]*channeldb.Invoice) + removable = make([]channeldb.InvoiceDeleteRef, 0) + } + + scanFunc := func( + paymentHash lntypes.Hash, invoice *channeldb.Invoice) error { + + if invoice.IsPending() { + pending[paymentHash] = invoice + } else if i.cfg.GcCanceledInvoicesOnStartup && + invoice.State == channeldb.ContractCanceled { + + // Consider invoice for removal if it is already + // canceled. Invoices that are expired but not yet + // canceled, will be queued up for cancellation after + // startup and will be deleted afterwards. + ref := channeldb.InvoiceDeleteRef{ + PayHash: paymentHash, + AddIndex: invoice.AddIndex, + SettleIndex: invoice.SettleIndex, + } + + if invoice.Terms.PaymentAddr != channeldb.BlankPayAddr { + ref.PayAddr = &invoice.Terms.PaymentAddr + } + + removable = append(removable, ref) + } + return nil + } + + err := i.cdb.ScanInvoices(scanFunc, reset) + if err != nil { return err } log.Debugf("Adding %d pending invoices to the expiry watcher", - len(pendingInvoices)) - i.expiryWatcher.AddInvoices(pendingInvoices) + len(pending)) + i.expiryWatcher.AddInvoices(pending) + + if err := i.cdb.DeleteInvoice(removable); err != nil { + log.Warnf("Deleting old invoices failed: %v", err) + } + return nil } @@ -178,8 +230,9 @@ func (i *InvoiceRegistry) Start() error { i.wg.Add(1) go i.invoiceEventLoop() - // Now prefetch all pending invoices to the expiry watcher. - err = i.populateExpiryWatcher() + // Now scan all pending and removable invoices to the expiry watcher or + // delete them. + err = i.scanInvoicesOnStart() if err != nil { i.Stop() return err @@ -1075,6 +1128,32 @@ func (i *InvoiceRegistry) cancelInvoiceImpl(payHash lntypes.Hash, } i.notifyClients(payHash, invoice, channeldb.ContractCanceled) + // Attempt to also delete the invoice if requested through the registry + // config. + if i.cfg.GcCanceledInvoicesOnTheFly { + // Assemble the delete reference and attempt to delete through + // the invocice from the DB. + deleteRef := channeldb.InvoiceDeleteRef{ + PayHash: payHash, + AddIndex: invoice.AddIndex, + SettleIndex: invoice.SettleIndex, + } + if invoice.Terms.PaymentAddr != channeldb.BlankPayAddr { + deleteRef.PayAddr = &invoice.Terms.PaymentAddr + } + + err = i.cdb.DeleteInvoice( + []channeldb.InvoiceDeleteRef{deleteRef}, + ) + // If by any chance deletion failed, then log it instead of + // returning the error, as the invoice itsels has already been + // canceled. + if err != nil { + log.Warnf("Invoice%v could not be deleted: %v", + ref, err) + } + } + return nil } diff --git a/invoices/invoiceregistry_test.go b/invoices/invoiceregistry_test.go index c77b38ed..cb916aea 100644 --- a/invoices/invoiceregistry_test.go +++ b/invoices/invoiceregistry_test.go @@ -1,6 +1,7 @@ package invoices import ( + "math" "testing" "time" @@ -219,11 +220,14 @@ func TestSettleInvoice(t *testing.T) { } } -// TestCancelInvoice tests cancelation of an invoice and related notifications. -func TestCancelInvoice(t *testing.T) { +func testCancelInvoice(t *testing.T, gc bool) { ctx := newTestContext(t) defer ctx.cleanup() + // If set to true, then also delete the invoice from the DB after + // cancellation. + ctx.registry.cfg.GcCanceledInvoicesOnTheFly = gc + allSubscriptions, err := ctx.registry.SubscribeNotifications(0, 0) assert.Nil(t, err) defer allSubscriptions.Cancel() @@ -298,13 +302,26 @@ func TestCancelInvoice(t *testing.T) { t.Fatal("no update received") } + if gc { + // Check that the invoice has been deleted from the db. + _, err = ctx.cdb.LookupInvoice( + channeldb.InvoiceRefByHash(testInvoicePaymentHash), + ) + require.Error(t, err) + } + // We expect no cancel notification to be sent to all invoice // subscribers (backwards compatibility). - // Try to cancel again. + // Try to cancel again. Expect that we report ErrInvoiceNotFound if the + // invoice has been garbage collected (since the invoice has been + // deleted when it was canceled), and no error otherwise. err = ctx.registry.CancelInvoice(testInvoicePaymentHash) - if err != nil { - t.Fatal("expected cancelation of a canceled invoice to succeed") + + if gc { + require.Error(t, err, channeldb.ErrInvoiceNotFound) + } else { + require.NoError(t, err) } // Notify arrival of a new htlc paying to this invoice. This should @@ -326,12 +343,33 @@ func TestCancelInvoice(t *testing.T) { t.Fatalf("expected acceptHeight %v, but got %v", testCurrentHeight, failResolution.AcceptHeight) } - if failResolution.Outcome != ResultInvoiceAlreadyCanceled { - t.Fatalf("expected expiry too soon, got: %v", - failResolution.Outcome) + + // If the invoice has been deleted (or not present) then we expect the + // outcome to be ResultInvoiceNotFound instead of when the invoice is + // in our database in which case we expect ResultInvoiceAlreadyCanceled. + if gc { + require.Equal(t, failResolution.Outcome, ResultInvoiceNotFound) + } else { + require.Equal(t, + failResolution.Outcome, + ResultInvoiceAlreadyCanceled, + ) } } +// TestCancelInvoice tests cancelation of an invoice and related notifications. +func TestCancelInvoice(t *testing.T) { + // Test cancellation both with garbage collection (meaning that canceled + // invoice will be deleted) and without (meain it'll be kept). + t.Run("garbage collect", func(t *testing.T) { + testCancelInvoice(t, true) + }) + + t.Run("no garbage collect", func(t *testing.T) { + testCancelInvoice(t, false) + }) +} + // TestSettleHoldInvoice tests settling of a hold invoice and related // notifications. func TestSettleHoldInvoice(t *testing.T) { @@ -1077,3 +1115,78 @@ func TestInvoiceExpiryWithRegistry(t *testing.T) { } } } + +// TestOldInvoiceRemovalOnStart tests that we'll attempt to remove old canceled +// invoices upon start while keeping all settled ones. +func TestOldInvoiceRemovalOnStart(t *testing.T) { + t.Parallel() + + testClock := clock.NewTestClock(testTime) + cdb, cleanup, err := newTestChannelDB(testClock) + defer cleanup() + + require.NoError(t, err) + + cfg := RegistryConfig{ + FinalCltvRejectDelta: testFinalCltvRejectDelta, + Clock: testClock, + GcCanceledInvoicesOnStartup: true, + } + + expiryWatcher := NewInvoiceExpiryWatcher(cfg.Clock) + registry := NewRegistry(cdb, expiryWatcher, &cfg) + + // First prefill the Channel DB with some pre-existing expired invoices. + const numExpired = 5 + const numPending = 0 + existingInvoices := generateInvoiceExpiryTestData( + t, testTime, 0, numExpired, numPending, + ) + + i := 0 + for paymentHash, invoice := range existingInvoices.expiredInvoices { + // Mark half of the invoices as settled, the other hald as + // canceled. + if i%2 == 0 { + invoice.State = channeldb.ContractSettled + } else { + invoice.State = channeldb.ContractCanceled + } + + _, err := cdb.AddInvoice(invoice, paymentHash) + require.NoError(t, err) + i++ + } + + // Collect all settled invoices for our expectation set. + var expected []channeldb.Invoice + + // Perform a scan query to collect all invoices. + query := channeldb.InvoiceQuery{ + IndexOffset: 0, + NumMaxInvoices: math.MaxUint64, + } + + response, err := cdb.QueryInvoices(query) + require.NoError(t, err) + + // Save all settled invoices for our expectation set. + for _, invoice := range response.Invoices { + if invoice.State == channeldb.ContractSettled { + expected = append(expected, invoice) + } + } + + // Start the registry which should collect and delete all canceled + // invoices upon start. + err = registry.Start() + require.NoError(t, err, "cannot start the registry") + + // Perform a scan query to collect all invoices. + response, err = cdb.QueryInvoices(query) + require.NoError(t, err) + + // Check that we really only kept the settled invoices after the + // registry start. + require.Equal(t, expected, response.Invoices) +} diff --git a/server.go b/server.go index b8c54ef2..5db6e7be 100644 --- a/server.go +++ b/server.go @@ -402,11 +402,13 @@ func newServer(cfg *Config, listenAddrs []net.Addr, } registryConfig := invoices.RegistryConfig{ - FinalCltvRejectDelta: lncfg.DefaultFinalCltvRejectDelta, - HtlcHoldDuration: invoices.DefaultHtlcHoldDuration, - Clock: clock.NewDefaultClock(), - AcceptKeySend: cfg.AcceptKeySend, - KeysendHoldTime: cfg.KeysendHoldTime, + FinalCltvRejectDelta: lncfg.DefaultFinalCltvRejectDelta, + HtlcHoldDuration: invoices.DefaultHtlcHoldDuration, + Clock: clock.NewDefaultClock(), + AcceptKeySend: cfg.AcceptKeySend, + GcCanceledInvoicesOnStartup: cfg.GcCanceledInvoicesOnStartup, + GcCanceledInvoicesOnTheFly: cfg.GcCanceledInvoicesOnTheFly, + KeysendHoldTime: cfg.KeysendHoldTime, } s := &server{