diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index 66043ff0..e53d7da8 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -148,9 +148,13 @@ func NewRegistry(cdb *channeldb.DB, expiryWatcher *InvoiceExpiryWatcher, } // scanInvoicesOnStart will scan all invoices on start and add active invoices -// to the invoice expiry watcher. +// to the invoice expirt watcher while also attempting to delete all canceled +// invoices. func (i *InvoiceRegistry) scanInvoicesOnStart() error { - var pending map[lntypes.Hash]*channeldb.Invoice + var ( + pending map[lntypes.Hash]*channeldb.Invoice + removable []channeldb.InvoiceDeleteRef + ) reset := func() { // Zero out our results on start and if the scan is ever run @@ -159,6 +163,7 @@ func (i *InvoiceRegistry) scanInvoicesOnStart() error { // using the etcd driver, where all transactions are allowed // to retry for serializability). pending = make(map[lntypes.Hash]*channeldb.Invoice) + removable = make([]channeldb.InvoiceDeleteRef, 0) } scanFunc := func( @@ -166,8 +171,23 @@ func (i *InvoiceRegistry) scanInvoicesOnStart() error { if invoice.IsPending() { pending[paymentHash] = invoice - } + } else if invoice.State == channeldb.ContractCanceled { + // Consider invoice for removal if it is already + // canceled. Invoices that are expired but not yet + // canceled, will be queued up for cancellation after + // startup and will be deleted afterwards. + ref := channeldb.InvoiceDeleteRef{ + PayHash: paymentHash, + AddIndex: invoice.AddIndex, + SettleIndex: invoice.SettleIndex, + } + if invoice.Terms.PaymentAddr != channeldb.BlankPayAddr { + ref.PayAddr = &invoice.Terms.PaymentAddr + } + + removable = append(removable, ref) + } return nil } @@ -180,6 +200,10 @@ func (i *InvoiceRegistry) scanInvoicesOnStart() error { len(pending)) i.expiryWatcher.AddInvoices(pending) + if err := i.cdb.DeleteInvoice(removable); err != nil { + log.Warnf("Deleting old invoices failed: %v", err) + } + return nil } diff --git a/invoices/invoiceregistry_test.go b/invoices/invoiceregistry_test.go index c77b38ed..0da260a2 100644 --- a/invoices/invoiceregistry_test.go +++ b/invoices/invoiceregistry_test.go @@ -1,6 +1,7 @@ package invoices import ( + "math" "testing" "time" @@ -1077,3 +1078,77 @@ func TestInvoiceExpiryWithRegistry(t *testing.T) { } } } + +// TestOldInvoiceRemovalOnStart tests that we'll attempt to remove old canceled +// invoices upon start while keeping all settled ones. +func TestOldInvoiceRemovalOnStart(t *testing.T) { + t.Parallel() + + testClock := clock.NewTestClock(testTime) + cdb, cleanup, err := newTestChannelDB(testClock) + defer cleanup() + + require.NoError(t, err) + + cfg := RegistryConfig{ + FinalCltvRejectDelta: testFinalCltvRejectDelta, + Clock: testClock, + } + + expiryWatcher := NewInvoiceExpiryWatcher(cfg.Clock) + registry := NewRegistry(cdb, expiryWatcher, &cfg) + + // First prefill the Channel DB with some pre-existing expired invoices. + const numExpired = 5 + const numPending = 0 + existingInvoices := generateInvoiceExpiryTestData( + t, testTime, 0, numExpired, numPending, + ) + + i := 0 + for paymentHash, invoice := range existingInvoices.expiredInvoices { + // Mark half of the invoices as settled, the other hald as + // canceled. + if i%2 == 0 { + invoice.State = channeldb.ContractSettled + } else { + invoice.State = channeldb.ContractCanceled + } + + _, err := cdb.AddInvoice(invoice, paymentHash) + require.NoError(t, err) + i++ + } + + // Collect all settled invoices for our expectation set. + var expected []channeldb.Invoice + + // Perform a scan query to collect all invoices. + query := channeldb.InvoiceQuery{ + IndexOffset: 0, + NumMaxInvoices: math.MaxUint64, + } + + response, err := cdb.QueryInvoices(query) + require.NoError(t, err) + + // Save all settled invoices for our expectation set. + for _, invoice := range response.Invoices { + if invoice.State == channeldb.ContractSettled { + expected = append(expected, invoice) + } + } + + // Start the registry which should collect and delete all canceled + // invoices upon start. + err = registry.Start() + require.NoError(t, err, "cannot start the registry") + + // Perform a scan query to collect all invoices. + response, err = cdb.QueryInvoices(query) + require.NoError(t, err) + + // Check that we really only kept the settled invoices after the + // registry start. + require.Equal(t, expected, response.Invoices) +}