invoices: refactor - add interface for expiry items

In preparation for having more than one expiry type, we
alias the queue.PrioirtyQueueItem interface for readability.
This commit is contained in:
carla 2021-04-23 08:19:54 +02:00
parent 9c6e83b15f
commit 4cd48c52ea
No known key found for this signature in database
GPG Key ID: 4CA7FE54A6213C91
3 changed files with 35 additions and 21 deletions

@ -12,6 +12,14 @@ import (
"github.com/lightningnetwork/lnd/zpay32" "github.com/lightningnetwork/lnd/zpay32"
) )
// invoiceExpiry is a vanity interface for different invoice expiry types
// which implement the priority queue item interface, used to improve code
// readability.
type invoiceExpiry queue.PriorityQueueItem
// Compile time assertion that invoiceExpiryTs implements invoiceExpiry.
var _ invoiceExpiry = (*invoiceExpiryTs)(nil)
// invoiceExpiryTs holds and invoice's payment hash and its expiry. This // invoiceExpiryTs holds and invoice's payment hash and its expiry. This
// is used to order invoices by their expiry time for cancellation. // is used to order invoices by their expiry time for cancellation.
type invoiceExpiryTs struct { type invoiceExpiryTs struct {
@ -50,7 +58,7 @@ type InvoiceExpiryWatcher struct {
// newInvoices channel is used to wake up the main loop when a new // newInvoices channel is used to wake up the main loop when a new
// invoices is added. // invoices is added.
newInvoices chan []*invoiceExpiryTs newInvoices chan []invoiceExpiry
wg sync.WaitGroup wg sync.WaitGroup
@ -62,7 +70,7 @@ type InvoiceExpiryWatcher struct {
func NewInvoiceExpiryWatcher(clock clock.Clock) *InvoiceExpiryWatcher { func NewInvoiceExpiryWatcher(clock clock.Clock) *InvoiceExpiryWatcher {
return &InvoiceExpiryWatcher{ return &InvoiceExpiryWatcher{
clock: clock, clock: clock,
newInvoices: make(chan []*invoiceExpiryTs), newInvoices: make(chan []invoiceExpiry),
quit: make(chan struct{}), quit: make(chan struct{}),
} }
} }
@ -104,10 +112,9 @@ func (ew *InvoiceExpiryWatcher) Stop() {
} }
// makeInvoiceExpiry checks if the passed invoice may be canceled and calculates // makeInvoiceExpiry checks if the passed invoice may be canceled and calculates
// the expiry time and creates a slimmer invoiceExpiry object with the hash and // the expiry time and creates a slimmer invoiceExpiry implementation.
// expiry time.
func makeInvoiceExpiry(paymentHash lntypes.Hash, func makeInvoiceExpiry(paymentHash lntypes.Hash,
invoice *channeldb.Invoice) *invoiceExpiryTs { invoice *channeldb.Invoice) invoiceExpiry {
if invoice.State != channeldb.ContractOpen { if invoice.State != channeldb.ContractOpen {
log.Debugf("Invoice not added to expiry watcher: %v", log.Debugf("Invoice not added to expiry watcher: %v",
@ -129,7 +136,7 @@ func makeInvoiceExpiry(paymentHash lntypes.Hash,
} }
// AddInvoices adds invoices to the InvoiceExpiryWatcher. // AddInvoices adds invoices to the InvoiceExpiryWatcher.
func (ew *InvoiceExpiryWatcher) AddInvoices(invoices ...*invoiceExpiryTs) { func (ew *InvoiceExpiryWatcher) AddInvoices(invoices ...invoiceExpiry) {
if len(invoices) > 0 { if len(invoices) > 0 {
select { select {
case ew.newInvoices <- invoices: case ew.newInvoices <- invoices:
@ -181,6 +188,24 @@ func (ew *InvoiceExpiryWatcher) cancelNextExpiredInvoice() {
} }
} }
// pushInvoices adds invoices to be expired to their relevant queue.
func (ew *InvoiceExpiryWatcher) pushInvoices(invoices []invoiceExpiry) {
for _, inv := range invoices {
// Switch on the type of entry we have. We need to check nil
// on the implementation of the interface because the interface
// itself is non-nil.
switch expiry := inv.(type) {
case *invoiceExpiryTs:
if expiry != nil {
ew.timestampExpiryQueue.Push(expiry)
}
default:
log.Errorf("unexpected queue item: %T", inv)
}
}
}
// mainLoop is a goroutine that receives new invoices and handles cancellation // mainLoop is a goroutine that receives new invoices and handles cancellation
// of expired invoices. // of expired invoices.
func (ew *InvoiceExpiryWatcher) mainLoop() { func (ew *InvoiceExpiryWatcher) mainLoop() {
@ -190,23 +215,12 @@ func (ew *InvoiceExpiryWatcher) mainLoop() {
// Cancel any invoices that may have expired. // Cancel any invoices that may have expired.
ew.cancelNextExpiredInvoice() ew.cancelNextExpiredInvoice()
pushInvoices := func(invoicesWithExpiry []*invoiceExpiryTs) {
for _, invoiceWithExpiry := range invoicesWithExpiry {
// Avoid pushing nil object to the heap.
if invoiceWithExpiry != nil {
ew.timestampExpiryQueue.Push(
invoiceWithExpiry,
)
}
}
}
select { select {
case invoicesWithExpiry := <-ew.newInvoices: case invoicesWithExpiry := <-ew.newInvoices:
// Take newly forwarded invoices with higher priority // Take newly forwarded invoices with higher priority
// in order to not block the newInvoices channel. // in order to not block the newInvoices channel.
pushInvoices(invoicesWithExpiry) ew.pushInvoices(invoicesWithExpiry)
continue continue
default: default:
@ -217,7 +231,7 @@ func (ew *InvoiceExpiryWatcher) mainLoop() {
continue continue
case invoicesWithExpiry := <-ew.newInvoices: case invoicesWithExpiry := <-ew.newInvoices:
pushInvoices(invoicesWithExpiry) ew.pushInvoices(invoicesWithExpiry)
case <-ew.quit: case <-ew.quit:
return return

@ -157,7 +157,7 @@ func TestInvoiceExpiryWhenAddingMultipleInvoices(t *testing.T) {
t.Parallel() t.Parallel()
test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5) test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5)
var invoices []*invoiceExpiryTs var invoices []invoiceExpiry
for hash, invoice := range test.testData.expiredInvoices { for hash, invoice := range test.testData.expiredInvoices {
invoices = append(invoices, makeInvoiceExpiry(hash, invoice)) invoices = append(invoices, makeInvoiceExpiry(hash, invoice))

@ -160,7 +160,7 @@ func NewRegistry(cdb *channeldb.DB, expiryWatcher *InvoiceExpiryWatcher,
// invoices. // invoices.
func (i *InvoiceRegistry) scanInvoicesOnStart() error { func (i *InvoiceRegistry) scanInvoicesOnStart() error {
var ( var (
pending []*invoiceExpiryTs pending []invoiceExpiry
removable []channeldb.InvoiceDeleteRef removable []channeldb.InvoiceDeleteRef
) )