diff --git a/invoices/invoice_expiry_watcher.go b/invoices/invoice_expiry_watcher.go index db9d9cd5..07c21780 100644 --- a/invoices/invoice_expiry_watcher.go +++ b/invoices/invoice_expiry_watcher.go @@ -49,7 +49,7 @@ type InvoiceExpiryWatcher struct { // newInvoices channel is used to wake up the main loop when a new invoices // is added. - newInvoices chan *invoiceExpiry + newInvoices chan []*invoiceExpiry wg sync.WaitGroup @@ -61,7 +61,7 @@ type InvoiceExpiryWatcher struct { func NewInvoiceExpiryWatcher(clock clock.Clock) *InvoiceExpiryWatcher { return &InvoiceExpiryWatcher{ clock: clock, - newInvoices: make(chan *invoiceExpiry), + newInvoices: make(chan []*invoiceExpiry), quit: make(chan struct{}), } } @@ -102,15 +102,14 @@ func (ew *InvoiceExpiryWatcher) Stop() { } } -// AddInvoice adds a new invoice to the InvoiceExpiryWatcher. This won't check -// if the invoice is already added and will only add invoices with ContractOpen -// state. -func (ew *InvoiceExpiryWatcher) AddInvoice( - paymentHash lntypes.Hash, invoice *channeldb.Invoice) { +// prepareInvoice checks if the passed invoice may be canceled and calculates +// the expiry time. +func (ew *InvoiceExpiryWatcher) prepareInvoice( + paymentHash lntypes.Hash, invoice *channeldb.Invoice) *invoiceExpiry { if invoice.State != channeldb.ContractOpen { log.Debugf("Invoice not added to expiry watcher: %v", invoice) - return + return nil } realExpiry := invoice.Terms.Expiry @@ -119,20 +118,55 @@ func (ew *InvoiceExpiryWatcher) AddInvoice( } expiry := invoice.CreationDate.Add(realExpiry) - - log.Debugf("Adding invoice '%v' to expiry watcher, expiration: %v", - paymentHash, expiry) - - newInvoiceExpiry := &invoiceExpiry{ + return &invoiceExpiry{ PaymentHash: paymentHash, Expiry: expiry, } +} - select { - case ew.newInvoices <- newInvoiceExpiry: - case <-ew.quit: +// AddInvoices adds multiple invoices to the InvoiceExpiryWatcher. +func (ew *InvoiceExpiryWatcher) AddInvoices( + invoices []channeldb.InvoiceWithPaymentHash) { + + invoicesWithExpiry := make([]*invoiceExpiry, 0, len(invoices)) + for _, invoiceWithPaymentHash := range invoices { + newInvoiceExpiry := ew.prepareInvoice( + invoiceWithPaymentHash.PaymentHash, &invoiceWithPaymentHash.Invoice, + ) + if newInvoiceExpiry != nil { + invoicesWithExpiry = append(invoicesWithExpiry, newInvoiceExpiry) + } + } + + if len(invoicesWithExpiry) > 0 { + log.Debugf("Added %v invoices to the expiry watcher: %v", + len(invoicesWithExpiry)) + select { + case ew.newInvoices <- invoicesWithExpiry: // Select on quit too so that callers won't get blocked in case // of concurrent shutdown. + case <-ew.quit: + } + } +} + +// AddInvoice adds a new invoice to the InvoiceExpiryWatcher. This won't check +// if the invoice is already added and will only add invoices with ContractOpen +// state. +func (ew *InvoiceExpiryWatcher) AddInvoice( + paymentHash lntypes.Hash, invoice *channeldb.Invoice) { + + newInvoiceExpiry := ew.prepareInvoice(paymentHash, invoice) + if newInvoiceExpiry != nil { + log.Debugf("Adding invoice '%v' to expiry watcher, expiration: %v", + paymentHash, newInvoiceExpiry.Expiry) + + select { + case ew.newInvoices <- []*invoiceExpiry{newInvoiceExpiry}: + // Select on quit too so that callers won't get blocked in case + // of concurrent shutdown. + case <-ew.quit: + } } } @@ -147,13 +181,13 @@ func (ew *InvoiceExpiryWatcher) nextExpiry() <-chan time.Time { return nil } -// cancelExpiredInvoices will cancel all expired invoices and removes them from -// the expiry queue. -func (ew *InvoiceExpiryWatcher) cancelExpiredInvoices() { - for !ew.expiryQueue.Empty() { +// cancelNextExpiredInvoice will cancel the next expired invoice and removes +// it from the expiry queue. +func (ew *InvoiceExpiryWatcher) cancelNextExpiredInvoice() { + if !ew.expiryQueue.Empty() { top := ew.expiryQueue.Top().(*invoiceExpiry) if !top.Expiry.Before(ew.clock.Now()) { - break + return } err := ew.cancelInvoice(top.PaymentHash) @@ -174,18 +208,33 @@ func (ew *InvoiceExpiryWatcher) mainLoop() { for { // Cancel any invoices that may have expired. - ew.cancelExpiredInvoices() + ew.cancelNextExpiredInvoice() select { - case <-ew.nextExpiry(): - // Wait until the next invoice expires, then cancel expired invoices. + + case invoicesWithExpiry := <-ew.newInvoices: + // Take newly forwarded invoices with higher priority + // in order to not block the newInvoices channel. + for _, invoiceWithExpiry := range invoicesWithExpiry { + ew.expiryQueue.Push(invoiceWithExpiry) + } continue - case newInvoiceExpiry := <-ew.newInvoices: - ew.expiryQueue.Push(newInvoiceExpiry) + default: + select { - case <-ew.quit: - return + case <-ew.nextExpiry(): + // Wait until the next invoice expires. + continue + + case invoicesWithExpiry := <-ew.newInvoices: + for _, invoiceWithExpiry := range invoicesWithExpiry { + ew.expiryQueue.Push(invoiceWithExpiry) + } + + case <-ew.quit: + return + } } } } diff --git a/invoices/invoice_expiry_watcher_test.go b/invoices/invoice_expiry_watcher_test.go index 8bfdfd69..8940063b 100644 --- a/invoices/invoice_expiry_watcher_test.go +++ b/invoices/invoice_expiry_watcher_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lntypes" ) @@ -123,3 +124,31 @@ func TestInvoiceExpiryWithPendingAndExpiredInvoices(t *testing.T) { test.watcher.Stop() test.checkExpectations() } + +// Tests adding multiple invoices at once. +func TestInvoiceExpiryWhenAddingMultipleInvoices(t *testing.T) { + t.Parallel() + test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5) + var invoices []channeldb.InvoiceWithPaymentHash + for hash, invoice := range test.testData.expiredInvoices { + invoices = append(invoices, + channeldb.InvoiceWithPaymentHash{ + Invoice: *invoice, + PaymentHash: hash, + }, + ) + } + for hash, invoice := range test.testData.pendingInvoices { + invoices = append(invoices, + channeldb.InvoiceWithPaymentHash{ + Invoice: *invoice, + PaymentHash: hash, + }, + ) + } + + test.watcher.AddInvoices(invoices) + time.Sleep(testTimeout) + test.watcher.Stop() + test.checkExpectations() +} diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index ea10750b..4b3a1b2d 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -198,12 +198,8 @@ func (i *InvoiceRegistry) populateExpiryWatcher() error { return err } - for idx := range pendingInvoices { - i.expiryWatcher.AddInvoice( - pendingInvoices[idx].PaymentHash, &pendingInvoices[idx].Invoice, - ) - } - + log.Debugf("Adding %v pending invoices to the expiry watcher") + i.expiryWatcher.AddInvoices(pendingInvoices) return nil }