diff --git a/invoices/invoice_expiry_watcher.go b/invoices/invoice_expiry_watcher.go index 21bb5d08..daac4aed 100644 --- a/invoices/invoice_expiry_watcher.go +++ b/invoices/invoice_expiry_watcher.go @@ -12,6 +12,14 @@ import ( "github.com/lightningnetwork/lnd/zpay32" ) +// invoiceExpiry is a vanity interface for different invoice expiry types +// which implement the priority queue item interface, used to improve code +// readability. +type invoiceExpiry queue.PriorityQueueItem + +// Compile time assertion that invoiceExpiryTs implements invoiceExpiry. +var _ invoiceExpiry = (*invoiceExpiryTs)(nil) + // invoiceExpiryTs holds and invoice's payment hash and its expiry. This // is used to order invoices by their expiry time for cancellation. type invoiceExpiryTs struct { @@ -50,7 +58,7 @@ type InvoiceExpiryWatcher struct { // newInvoices channel is used to wake up the main loop when a new // invoices is added. - newInvoices chan []*invoiceExpiryTs + newInvoices chan []invoiceExpiry wg sync.WaitGroup @@ -62,7 +70,7 @@ type InvoiceExpiryWatcher struct { func NewInvoiceExpiryWatcher(clock clock.Clock) *InvoiceExpiryWatcher { return &InvoiceExpiryWatcher{ clock: clock, - newInvoices: make(chan []*invoiceExpiryTs), + newInvoices: make(chan []invoiceExpiry), quit: make(chan struct{}), } } @@ -104,10 +112,9 @@ func (ew *InvoiceExpiryWatcher) Stop() { } // makeInvoiceExpiry checks if the passed invoice may be canceled and calculates -// the expiry time and creates a slimmer invoiceExpiry object with the hash and -// expiry time. +// the expiry time and creates a slimmer invoiceExpiry implementation. func makeInvoiceExpiry(paymentHash lntypes.Hash, - invoice *channeldb.Invoice) *invoiceExpiryTs { + invoice *channeldb.Invoice) invoiceExpiry { if invoice.State != channeldb.ContractOpen { log.Debugf("Invoice not added to expiry watcher: %v", @@ -129,7 +136,7 @@ func makeInvoiceExpiry(paymentHash lntypes.Hash, } // AddInvoices adds invoices to the InvoiceExpiryWatcher. -func (ew *InvoiceExpiryWatcher) AddInvoices(invoices ...*invoiceExpiryTs) { +func (ew *InvoiceExpiryWatcher) AddInvoices(invoices ...invoiceExpiry) { if len(invoices) > 0 { select { case ew.newInvoices <- invoices: @@ -181,6 +188,24 @@ func (ew *InvoiceExpiryWatcher) cancelNextExpiredInvoice() { } } +// pushInvoices adds invoices to be expired to their relevant queue. +func (ew *InvoiceExpiryWatcher) pushInvoices(invoices []invoiceExpiry) { + for _, inv := range invoices { + // Switch on the type of entry we have. We need to check nil + // on the implementation of the interface because the interface + // itself is non-nil. + switch expiry := inv.(type) { + case *invoiceExpiryTs: + if expiry != nil { + ew.timestampExpiryQueue.Push(expiry) + } + + default: + log.Errorf("unexpected queue item: %T", inv) + } + } +} + // mainLoop is a goroutine that receives new invoices and handles cancellation // of expired invoices. func (ew *InvoiceExpiryWatcher) mainLoop() { @@ -190,23 +215,12 @@ func (ew *InvoiceExpiryWatcher) mainLoop() { // Cancel any invoices that may have expired. ew.cancelNextExpiredInvoice() - pushInvoices := func(invoicesWithExpiry []*invoiceExpiryTs) { - for _, invoiceWithExpiry := range invoicesWithExpiry { - // Avoid pushing nil object to the heap. - if invoiceWithExpiry != nil { - ew.timestampExpiryQueue.Push( - invoiceWithExpiry, - ) - } - } - } - select { case invoicesWithExpiry := <-ew.newInvoices: // Take newly forwarded invoices with higher priority // in order to not block the newInvoices channel. - pushInvoices(invoicesWithExpiry) + ew.pushInvoices(invoicesWithExpiry) continue default: @@ -217,7 +231,7 @@ func (ew *InvoiceExpiryWatcher) mainLoop() { continue case invoicesWithExpiry := <-ew.newInvoices: - pushInvoices(invoicesWithExpiry) + ew.pushInvoices(invoicesWithExpiry) case <-ew.quit: return diff --git a/invoices/invoice_expiry_watcher_test.go b/invoices/invoice_expiry_watcher_test.go index c5b5e518..e2c7ea82 100644 --- a/invoices/invoice_expiry_watcher_test.go +++ b/invoices/invoice_expiry_watcher_test.go @@ -157,7 +157,7 @@ func TestInvoiceExpiryWhenAddingMultipleInvoices(t *testing.T) { t.Parallel() test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5) - var invoices []*invoiceExpiryTs + var invoices []invoiceExpiry for hash, invoice := range test.testData.expiredInvoices { invoices = append(invoices, makeInvoiceExpiry(hash, invoice)) diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index 24eb6ef9..3ef83a98 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -160,7 +160,7 @@ func NewRegistry(cdb *channeldb.DB, expiryWatcher *InvoiceExpiryWatcher, // invoices. func (i *InvoiceRegistry) scanInvoicesOnStart() error { var ( - pending []*invoiceExpiryTs + pending []invoiceExpiry removable []channeldb.InvoiceDeleteRef )