diff --git a/invoices/invoice_expiry_watcher.go b/invoices/invoice_expiry_watcher.go index a46f27f5..4c7f8f6e 100644 --- a/invoices/invoice_expiry_watcher.go +++ b/invoices/invoice_expiry_watcher.go @@ -103,10 +103,11 @@ func (ew *InvoiceExpiryWatcher) Stop() { } } -// prepareInvoice checks if the passed invoice may be canceled and calculates -// the expiry time. -func (ew *InvoiceExpiryWatcher) prepareInvoice( - paymentHash lntypes.Hash, invoice *channeldb.Invoice) *invoiceExpiry { +// makeInvoiceExpiry checks if the passed invoice may be canceled and calculates +// the expiry time and creates a slimmer invoiceExpiry object with the hash and +// expiry time. +func makeInvoiceExpiry(paymentHash lntypes.Hash, + invoice *channeldb.Invoice) *invoiceExpiry { if invoice.State != channeldb.ContractOpen { log.Debugf("Invoice not added to expiry watcher: %v", @@ -127,45 +128,14 @@ func (ew *InvoiceExpiryWatcher) prepareInvoice( } } -// AddInvoices adds multiple invoices to the InvoiceExpiryWatcher. -func (ew *InvoiceExpiryWatcher) AddInvoices( - invoices map[lntypes.Hash]*channeldb.Invoice) { - - invoicesWithExpiry := make([]*invoiceExpiry, 0, len(invoices)) - for paymentHash, invoice := range invoices { - newInvoiceExpiry := ew.prepareInvoice(paymentHash, invoice) - if newInvoiceExpiry != nil { - invoicesWithExpiry = append( - invoicesWithExpiry, newInvoiceExpiry, - ) - } - } - - if len(invoicesWithExpiry) > 0 { - log.Debugf("Added %d invoices to the expiry watcher", - len(invoicesWithExpiry)) +// AddInvoices adds invoices to the InvoiceExpiryWatcher. +func (ew *InvoiceExpiryWatcher) AddInvoices(invoices ...*invoiceExpiry) { + if len(invoices) > 0 { select { - case ew.newInvoices <- invoicesWithExpiry: - // Select on quit too so that callers won't get blocked in case - // of concurrent shutdown. - case <-ew.quit: - } - } -} + case ew.newInvoices <- invoices: + log.Debugf("Added %d invoices to the expiry watcher", + len(invoices)) -// AddInvoice adds a new invoice to the InvoiceExpiryWatcher. This won't check -// if the invoice is already added and will only add invoices with ContractOpen -// state. -func (ew *InvoiceExpiryWatcher) AddInvoice( - paymentHash lntypes.Hash, invoice *channeldb.Invoice) { - - newInvoiceExpiry := ew.prepareInvoice(paymentHash, invoice) - if newInvoiceExpiry != nil { - log.Debugf("Adding invoice '%v' to expiry watcher,"+ - "expiration: %v", paymentHash, newInvoiceExpiry.Expiry) - - select { - case ew.newInvoices <- []*invoiceExpiry{newInvoiceExpiry}: // Select on quit too so that callers won't get blocked in case // of concurrent shutdown. case <-ew.quit: @@ -220,14 +190,21 @@ func (ew *InvoiceExpiryWatcher) mainLoop() { // Cancel any invoices that may have expired. ew.cancelNextExpiredInvoice() + pushInvoices := func(invoicesWithExpiry []*invoiceExpiry) { + for _, invoiceWithExpiry := range invoicesWithExpiry { + // Avoid pushing nil object to the heap. + if invoiceWithExpiry != nil { + ew.expiryQueue.Push(invoiceWithExpiry) + } + } + } + select { case invoicesWithExpiry := <-ew.newInvoices: // Take newly forwarded invoices with higher priority // in order to not block the newInvoices channel. - for _, invoiceWithExpiry := range invoicesWithExpiry { - ew.expiryQueue.Push(invoiceWithExpiry) - } + pushInvoices(invoicesWithExpiry) continue default: @@ -238,9 +215,7 @@ func (ew *InvoiceExpiryWatcher) mainLoop() { continue case invoicesWithExpiry := <-ew.newInvoices: - for _, invoice := range invoicesWithExpiry { - ew.expiryQueue.Push(invoice) - } + pushInvoices(invoicesWithExpiry) case <-ew.quit: return diff --git a/invoices/invoice_expiry_watcher_test.go b/invoices/invoice_expiry_watcher_test.go index 67ea2525..a06bde53 100644 --- a/invoices/invoice_expiry_watcher_test.go +++ b/invoices/invoice_expiry_watcher_test.go @@ -5,7 +5,6 @@ import ( "testing" "time" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lntypes" ) @@ -125,7 +124,7 @@ func TestInvoiceExpiryWithOnlyExpiredInvoices(t *testing.T) { test := newInvoiceExpiryWatcherTest(t, testTime, 0, 5) for paymentHash, invoice := range test.testData.pendingInvoices { - test.watcher.AddInvoice(paymentHash, invoice) + test.watcher.AddInvoices(makeInvoiceExpiry(paymentHash, invoice)) } test.waitForFinish(testTimeout) @@ -141,11 +140,11 @@ func TestInvoiceExpiryWithPendingAndExpiredInvoices(t *testing.T) { test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5) for paymentHash, invoice := range test.testData.expiredInvoices { - test.watcher.AddInvoice(paymentHash, invoice) + test.watcher.AddInvoices(makeInvoiceExpiry(paymentHash, invoice)) } for paymentHash, invoice := range test.testData.pendingInvoices { - test.watcher.AddInvoice(paymentHash, invoice) + test.watcher.AddInvoices(makeInvoiceExpiry(paymentHash, invoice)) } test.waitForFinish(testTimeout) @@ -158,17 +157,17 @@ func TestInvoiceExpiryWhenAddingMultipleInvoices(t *testing.T) { t.Parallel() test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5) - invoices := make(map[lntypes.Hash]*channeldb.Invoice) + var invoices []*invoiceExpiry for hash, invoice := range test.testData.expiredInvoices { - invoices[hash] = invoice + invoices = append(invoices, makeInvoiceExpiry(hash, invoice)) } for hash, invoice := range test.testData.pendingInvoices { - invoices[hash] = invoice + invoices = append(invoices, makeInvoiceExpiry(hash, invoice)) } - test.watcher.AddInvoices(invoices) + test.watcher.AddInvoices(invoices...) test.waitForFinish(testTimeout) test.watcher.Stop() test.checkExpectations() diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index c827f144..cd7ebe8c 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -160,7 +160,7 @@ func NewRegistry(cdb *channeldb.DB, expiryWatcher *InvoiceExpiryWatcher, // invoices. func (i *InvoiceRegistry) scanInvoicesOnStart() error { var ( - pending map[lntypes.Hash]*channeldb.Invoice + pending []*invoiceExpiry removable []channeldb.InvoiceDeleteRef ) @@ -170,7 +170,7 @@ func (i *InvoiceRegistry) scanInvoicesOnStart() error { // layer needs to retry the View transaction underneath (eg. // using the etcd driver, where all transactions are allowed // to retry for serializability). - pending = make(map[lntypes.Hash]*channeldb.Invoice) + pending = nil removable = make([]channeldb.InvoiceDeleteRef, 0) } @@ -178,7 +178,10 @@ func (i *InvoiceRegistry) scanInvoicesOnStart() error { paymentHash lntypes.Hash, invoice *channeldb.Invoice) error { if invoice.IsPending() { - pending[paymentHash] = invoice + expiryRef := makeInvoiceExpiry(paymentHash, invoice) + if expiryRef != nil { + pending = append(pending, expiryRef) + } } else if i.cfg.GcCanceledInvoicesOnStartup && invoice.State == channeldb.ContractCanceled { @@ -208,7 +211,7 @@ func (i *InvoiceRegistry) scanInvoicesOnStart() error { log.Debugf("Adding %d pending invoices to the expiry watcher", len(pending)) - i.expiryWatcher.AddInvoices(pending) + i.expiryWatcher.AddInvoices(pending...) if err := i.cdb.DeleteInvoice(removable); err != nil { log.Warnf("Deleting old invoices failed: %v", err) @@ -562,7 +565,10 @@ func (i *InvoiceRegistry) AddInvoice(invoice *channeldb.Invoice, // InvoiceExpiryWatcher.AddInvoice must not be locked by InvoiceRegistry // to avoid deadlock when a new invoice is added while an other is being // canceled. - i.expiryWatcher.AddInvoice(paymentHash, invoice) + invoiceExpiryRef := makeInvoiceExpiry(paymentHash, invoice) + if invoiceExpiryRef != nil { + i.expiryWatcher.AddInvoices(invoiceExpiryRef) + } return addIndex, nil }