invoices: adding InvoiceExpryWatcher to cancel expired invoices

This commit adds InvoiceExpryWatcher which is a separate class that
receives new invoices (and existing ones upon restart) from InvoiceRegistry
and actively watches their expiry. When an invoice is expired
InvoiceExpiryWatcher will call into InvoiceRegistry to cancel the
invoice and by that notify all subscribers about the state change.
This commit is contained in:
Andras Banki-Horvath 2019-12-09 17:50:11 +01:00
parent 7024f36a76
commit 44f13d1d60
9 changed files with 643 additions and 13 deletions

@ -565,6 +565,83 @@ func (d *DB) LookupInvoice(paymentHash [32]byte) (Invoice, error) {
return invoice, nil
}
// InvoiceWithPaymentHash is used to store an invoice and its corresponding
// payment hash. This struct is only used to store results of
// ChannelDB.FetchAllInvoicesWithPaymentHash() call.
type InvoiceWithPaymentHash struct {
// Invoice holds the invoice as selected from the invoices bucket.
Invoice Invoice
// PaymentHash is the payment hash for the Invoice.
PaymentHash lntypes.Hash
}
// FetchAllInvoicesWithPaymentHash returns all invoices and their payment hashes
// currently stored within the database. If the pendingOnly param is true, then
// only unsettled invoices and their payment hashes will be returned, skipping
// all invoices that are fully settled or canceled. Note that the returned
// array is not ordered by add index.
func (d *DB) FetchAllInvoicesWithPaymentHash(pendingOnly bool) (
[]InvoiceWithPaymentHash, error) {
var result []InvoiceWithPaymentHash
err := d.View(func(tx *bbolt.Tx) error {
invoices := tx.Bucket(invoiceBucket)
if invoices == nil {
return ErrNoInvoicesCreated
}
invoiceIndex := invoices.Bucket(invoiceIndexBucket)
if invoiceIndex == nil {
// Mask the error if there's no invoice
// index as that simply means there are no
// invoices added yet to the DB. In this case
// we simply return an empty list.
return nil
}
return invoiceIndex.ForEach(func(k, v []byte) error {
// Skip the special numInvoicesKey as that does not
// point to a valid invoice.
if bytes.Equal(k, numInvoicesKey) {
return nil
}
if v == nil {
return nil
}
invoice, err := fetchInvoice(v, invoices)
if err != nil {
return err
}
if pendingOnly &&
(invoice.State == ContractSettled ||
invoice.State == ContractCanceled) {
return nil
}
invoiceWithPaymentHash := InvoiceWithPaymentHash{
Invoice: invoice,
}
copy(invoiceWithPaymentHash.PaymentHash[:], k)
result = append(result, invoiceWithPaymentHash)
return nil
})
})
if err != nil {
return nil, err
}
return result, nil
}
// FetchAllInvoices returns all invoices currently stored within the database.
// If the pendingOnly param is true, then only unsettled invoices will be
// returned, skipping all invoices that are fully settled.

@ -22,6 +22,7 @@ import (
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/input"
@ -792,6 +793,7 @@ func newMockRegistry(minDelta uint32) *mockInvoiceRegistry {
registry := invoices.NewRegistry(
cdb,
invoices.NewInvoiceExpiryWatcher(clock.NewDefaultClock()),
&invoices.RegistryConfig{
FinalCltvRejectDelta: 5,
},

@ -0,0 +1,191 @@
package invoices
import (
"fmt"
"sync"
"time"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/queue"
"github.com/lightningnetwork/lnd/zpay32"
)
// invoiceExpiry holds and invoice's payment hash and its expiry. This
// is used to order invoices by their expiry for cancellation.
type invoiceExpiry struct {
PaymentHash lntypes.Hash
Expiry time.Time
}
// Less implements PriorityQueueItem.Less such that the top item in the
// priorty queue will be the one that expires next.
func (e invoiceExpiry) Less(other queue.PriorityQueueItem) bool {
return e.Expiry.Before(other.(*invoiceExpiry).Expiry)
}
// InvoiceExpiryWatcher handles automatic invoice cancellation of expried
// invoices. Upon start InvoiceExpiryWatcher will retrieve all pending (not yet
// settled or canceled) invoices invoices to its watcing queue. When a new
// invoice is added to the InvoiceRegistry, it'll be forarded to the
// InvoiceExpiryWatcher and will end up in the watching queue as well.
// If any of the watched invoices expire, they'll be removed from the watching
// queue and will be cancelled through InvoiceRegistry.CancelInvoice().
type InvoiceExpiryWatcher struct {
sync.Mutex
started bool
// clock is the clock implementation that InvoiceExpiryWatcher uses.
// It is useful for testing.
clock clock.Clock
// cancelInvoice is a template method that cancels an expired invoice.
cancelInvoice func(lntypes.Hash) error
// expiryQueue holds invoiceExpiry items and is used to find the next
// invoice to expire.
expiryQueue queue.PriorityQueue
// newInvoices channel is used to wake up the main loop when a new invoices
// is added.
newInvoices chan *invoiceExpiry
wg sync.WaitGroup
// quit signals InvoiceExpiryWatcher to stop.
quit chan struct{}
}
// NewInvoiceExpiryWatcher creates a new InvoiceExpiryWatcher instance.
func NewInvoiceExpiryWatcher(clock clock.Clock) *InvoiceExpiryWatcher {
return &InvoiceExpiryWatcher{
clock: clock,
newInvoices: make(chan *invoiceExpiry),
quit: make(chan struct{}),
}
}
// Start starts the the subscription handler and the main loop. Start() will
// return with error if InvoiceExpiryWatcher is already started. Start()
// expects a cancellation function passed that will be use to cancel expired
// invoices by their payment hash.
func (ew *InvoiceExpiryWatcher) Start(
cancelInvoice func(lntypes.Hash) error) error {
ew.Lock()
defer ew.Unlock()
if ew.started {
return fmt.Errorf("InvoiceExpiryWatcher already started")
}
ew.started = true
ew.cancelInvoice = cancelInvoice
ew.wg.Add(1)
go ew.mainLoop()
return nil
}
// Stop quits the expiry handler loop and waits for InvoiceExpiryWatcher to
// fully stop.
func (ew *InvoiceExpiryWatcher) Stop() {
ew.Lock()
defer ew.Unlock()
if ew.started {
// Signal subscriptionHandler to quit and wait for it to return.
close(ew.quit)
ew.wg.Wait()
ew.started = false
}
}
// AddInvoice adds a new invoice to the InvoiceExpiryWatcher. This won't check
// if the invoice is already added and will only add invoices with ContractOpen
// state.
func (ew *InvoiceExpiryWatcher) AddInvoice(
paymentHash lntypes.Hash, invoice *channeldb.Invoice) {
if invoice.State != channeldb.ContractOpen {
log.Debugf("Invoice not added to expiry watcher: %v", invoice)
return
}
realExpiry := invoice.Terms.Expiry
if realExpiry == 0 {
realExpiry = zpay32.DefaultInvoiceExpiry
}
expiry := invoice.CreationDate.Add(realExpiry)
log.Debugf("Adding invoice '%v' to expiry watcher, expiration: %v",
paymentHash, expiry)
newInvoiceExpiry := &invoiceExpiry{
PaymentHash: paymentHash,
Expiry: expiry,
}
select {
case ew.newInvoices <- newInvoiceExpiry:
case <-ew.quit:
// Select on quit too so that callers won't get blocked in case
// of concurrent shutdown.
}
}
// nextExpiry returns a Time chan to wait on until the next invoice expires.
// If there are no active invoices, then it'll simply wait indefinitely.
func (ew *InvoiceExpiryWatcher) nextExpiry() <-chan time.Time {
if !ew.expiryQueue.Empty() {
top := ew.expiryQueue.Top().(*invoiceExpiry)
return ew.clock.TickAfter(top.Expiry.Sub(ew.clock.Now()))
}
return nil
}
// cancelExpiredInvoices will cancel all expired invoices and removes them from
// the expiry queue.
func (ew *InvoiceExpiryWatcher) cancelExpiredInvoices() {
for !ew.expiryQueue.Empty() {
top := ew.expiryQueue.Top().(*invoiceExpiry)
if !top.Expiry.Before(ew.clock.Now()) {
break
}
err := ew.cancelInvoice(top.PaymentHash)
if err != nil && err != channeldb.ErrInvoiceAlreadySettled &&
err != channeldb.ErrInvoiceAlreadyCanceled {
log.Errorf("Unable to cancel invoice: %v", top.PaymentHash)
}
ew.expiryQueue.Pop()
}
}
// mainLoop is a goroutine that receives new invoices and handles cancellation
// of expired invoices.
func (ew *InvoiceExpiryWatcher) mainLoop() {
defer ew.wg.Done()
for {
// Cancel any invoices that may have expired.
ew.cancelExpiredInvoices()
select {
case <-ew.nextExpiry():
// Wait until the next invoice expires, then cancel expired invoices.
continue
case newInvoiceExpiry := <-ew.newInvoices:
ew.expiryQueue.Push(newInvoiceExpiry)
case <-ew.quit:
return
}
}
}

@ -0,0 +1,125 @@
package invoices
import (
"testing"
"time"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lntypes"
)
// invoiceExpiryWatcherTest holds a test fixture and implements checks
// for InvoiceExpiryWatcher tests.
type invoiceExpiryWatcherTest struct {
t *testing.T
watcher *InvoiceExpiryWatcher
testData invoiceExpiryTestData
canceledInvoices []lntypes.Hash
}
// newInvoiceExpiryWatcherTest creates a new InvoiceExpiryWatcher test fixture
// and sets up the test environment.
func newInvoiceExpiryWatcherTest(t *testing.T, now time.Time,
numExpiredInvoices, numPendingInvoices int) *invoiceExpiryWatcherTest {
test := &invoiceExpiryWatcherTest{
watcher: NewInvoiceExpiryWatcher(clock.NewTestClock(testTime)),
testData: generateInvoiceExpiryTestData(
t, now, 0, numExpiredInvoices, numPendingInvoices,
),
}
err := test.watcher.Start(func(paymentHash lntypes.Hash) error {
test.canceledInvoices = append(test.canceledInvoices, paymentHash)
return nil
})
if err != nil {
t.Fatalf("cannot start InvoiceExpiryWatcher: %v", err)
}
return test
}
func (t *invoiceExpiryWatcherTest) checkExpectations() {
// Check that invoices that got canceled during the test are the ones
// that expired.
if len(t.canceledInvoices) != len(t.testData.expiredInvoices) {
t.t.Fatalf("expected %v cancellations, got %v",
len(t.testData.expiredInvoices), len(t.canceledInvoices))
}
for i := range t.canceledInvoices {
if _, ok := t.testData.expiredInvoices[t.canceledInvoices[i]]; !ok {
t.t.Fatalf("wrong invoice canceled")
}
}
}
// Tests that InvoiceExpiryWatcher can be started and stopped.
func TestInvoiceExpiryWatcherStartStop(t *testing.T) {
watcher := NewInvoiceExpiryWatcher(clock.NewTestClock(testTime))
cancel := func(lntypes.Hash) error {
t.Fatalf("unexpected call")
return nil
}
if err := watcher.Start(cancel); err != nil {
t.Fatalf("unexpected error upon start: %v", err)
}
if err := watcher.Start(cancel); err == nil {
t.Fatalf("expected error upon second start")
}
watcher.Stop()
if err := watcher.Start(cancel); err != nil {
t.Fatalf("unexpected error upon start: %v", err)
}
}
// Tests that no invoices will expire from an empty InvoiceExpiryWatcher.
func TestInvoiceExpiryWithNoInvoices(t *testing.T) {
t.Parallel()
test := newInvoiceExpiryWatcherTest(t, testTime, 0, 0)
time.Sleep(testTimeout)
test.watcher.Stop()
test.checkExpectations()
}
// Tests that if all add invoices are expired, then all invoices
// will be canceled.
func TestInvoiceExpiryWithOnlyExpiredInvoices(t *testing.T) {
t.Parallel()
test := newInvoiceExpiryWatcherTest(t, testTime, 0, 5)
for paymentHash, invoice := range test.testData.pendingInvoices {
test.watcher.AddInvoice(paymentHash, invoice)
}
time.Sleep(testTimeout)
test.watcher.Stop()
test.checkExpectations()
}
// Tests that if some invoices are expired, then those invoices
// will be canceled.
func TestInvoiceExpiryWithPendingAndExpiredInvoices(t *testing.T) {
t.Parallel()
test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5)
for paymentHash, invoice := range test.testData.expiredInvoices {
test.watcher.AddInvoice(paymentHash, invoice)
}
for paymentHash, invoice := range test.testData.pendingInvoices {
test.watcher.AddInvoice(paymentHash, invoice)
}
time.Sleep(testTimeout)
test.watcher.Stop()
test.checkExpectations()
}

@ -125,6 +125,8 @@ type InvoiceRegistry struct {
// auto-released.
htlcAutoReleaseChan chan *htlcReleaseEvent
expiryWatcher *InvoiceExpiryWatcher
wg sync.WaitGroup
quit chan struct{}
}
@ -133,7 +135,9 @@ type InvoiceRegistry struct {
// wraps the persistent on-disk invoice storage with an additional in-memory
// layer. The in-memory layer is in place such that debug invoices can be added
// which are volatile yet available system wide within the daemon.
func NewRegistry(cdb *channeldb.DB, cfg *RegistryConfig) *InvoiceRegistry {
func NewRegistry(cdb *channeldb.DB, expiryWatcher *InvoiceExpiryWatcher,
cfg *RegistryConfig) *InvoiceRegistry {
return &InvoiceRegistry{
cdb: cdb,
notificationClients: make(map[uint32]*InvoiceSubscription),
@ -145,21 +149,62 @@ func NewRegistry(cdb *channeldb.DB, cfg *RegistryConfig) *InvoiceRegistry {
hodlReverseSubscriptions: make(map[chan<- interface{}]map[channeldb.CircuitKey]struct{}),
cfg: cfg,
htlcAutoReleaseChan: make(chan *htlcReleaseEvent),
expiryWatcher: expiryWatcher,
quit: make(chan struct{}),
}
}
// populateExpiryWatcher fetches all active invoices and their corresponding
// payment hashes from ChannelDB and adds them to the expiry watcher.
func (i *InvoiceRegistry) populateExpiryWatcher() error {
pendingOnly := true
pendingInvoices, err := i.cdb.FetchAllInvoicesWithPaymentHash(pendingOnly)
if err != nil && err != channeldb.ErrNoInvoicesCreated {
log.Errorf(
"Error while prefetching active invoices from the database: %v", err,
)
return err
}
for idx := range pendingInvoices {
i.expiryWatcher.AddInvoice(
pendingInvoices[idx].PaymentHash, &pendingInvoices[idx].Invoice,
)
}
return nil
}
// Start starts the registry and all goroutines it needs to carry out its task.
func (i *InvoiceRegistry) Start() error {
i.wg.Add(1)
// Start InvoiceExpiryWatcher and prepopulate it with existing active
// invoices.
err := i.expiryWatcher.Start(func(paymentHash lntypes.Hash) error {
cancelIfAccepted := false
return i.cancelInvoiceImpl(paymentHash, cancelIfAccepted)
})
if err != nil {
return err
}
i.wg.Add(1)
go i.invoiceEventLoop()
// Now prefetch all pending invoices to the expiry watcher.
err = i.populateExpiryWatcher()
if err != nil {
i.Stop()
return err
}
return nil
}
// Stop signals the registry for a graceful shutdown.
func (i *InvoiceRegistry) Stop() {
i.expiryWatcher.Stop()
close(i.quit)
i.wg.Wait()
@ -470,7 +515,6 @@ func (i *InvoiceRegistry) AddInvoice(invoice *channeldb.Invoice,
paymentHash lntypes.Hash) (uint64, error) {
i.Lock()
defer i.Unlock()
log.Debugf("Invoice(%v): added %v", paymentHash,
newLogClosure(func() string {
@ -480,12 +524,19 @@ func (i *InvoiceRegistry) AddInvoice(invoice *channeldb.Invoice,
addIndex, err := i.cdb.AddInvoice(invoice, paymentHash)
if err != nil {
i.Unlock()
return 0, err
}
// Now that we've added the invoice, we'll send dispatch a message to
// notify the clients of this new invoice.
i.notifyClients(paymentHash, invoice, channeldb.ContractOpen)
i.Unlock()
// InvoiceExpiryWatcher.AddInvoice must not be locked by InvoiceRegistry
// to avoid deadlock when a new invoice is added while an other is being
// canceled.
i.expiryWatcher.AddInvoice(paymentHash, invoice)
return addIndex, nil
}
@ -817,6 +868,15 @@ func (i *InvoiceRegistry) SettleHodlInvoice(preimage lntypes.Preimage) error {
// CancelInvoice attempts to cancel the invoice corresponding to the passed
// payment hash.
func (i *InvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error {
return i.cancelInvoiceImpl(payHash, true)
}
// cancelInvoice attempts to cancel the invoice corresponding to the passed
// payment hash. Accepted invoices will only be canceled if explicitly
// requested to do so.
func (i *InvoiceRegistry) cancelInvoiceImpl(payHash lntypes.Hash,
cancelAccepted bool) error {
i.Lock()
defer i.Unlock()
@ -825,6 +885,12 @@ func (i *InvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error {
updateInvoice := func(invoice *channeldb.Invoice) (
*channeldb.InvoiceUpdateDesc, error) {
// Only cancel the invoice in ContractAccepted state if explicitly
// requested to do so.
if invoice.State == channeldb.ContractAccepted && !cancelAccepted {
return nil, nil
}
// Move invoice to the canceled state. Rely on validation in
// channeldb to return an error if the invoice is already
// settled or canceled.
@ -847,6 +913,13 @@ func (i *InvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error {
return err
}
// Return without cancellation if the invoice state is ContractAccepted.
if invoice.State == channeldb.ContractAccepted {
log.Debugf("Invoice(%v): remains accepted as cancel wasn't"+
"explicitly requested.", payHash)
return nil
}
log.Debugf("Invoice(%v): canceled", payHash)
// In the callback, some htlcs may have been moved to the canceled

@ -5,6 +5,8 @@ import (
"time"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
)
@ -301,8 +303,9 @@ func TestSettleHoldInvoice(t *testing.T) {
// Instantiate and start the invoice ctx.registry.
cfg := RegistryConfig{
FinalCltvRejectDelta: testFinalCltvRejectDelta,
Clock: clock.NewTestClock(testTime),
}
registry := NewRegistry(cdb, &cfg)
registry := NewRegistry(cdb, NewInvoiceExpiryWatcher(cfg.Clock), &cfg)
err = registry.Start()
if err != nil {
@ -461,7 +464,7 @@ func TestSettleHoldInvoice(t *testing.T) {
// TestCancelHoldInvoice tests canceling of a hold invoice and related
// notifications.
func TestCancelHoldInvoice(t *testing.T) {
defer timeout()
defer timeout()()
cdb, cleanup, err := newTestChannelDB()
if err != nil {
@ -472,8 +475,9 @@ func TestCancelHoldInvoice(t *testing.T) {
// Instantiate and start the invoice ctx.registry.
cfg := RegistryConfig{
FinalCltvRejectDelta: testFinalCltvRejectDelta,
Clock: clock.NewTestClock(testTime),
}
registry := NewRegistry(cdb, &cfg)
registry := NewRegistry(cdb, NewInvoiceExpiryWatcher(cfg.Clock), &cfg)
err = registry.Start()
if err != nil {
@ -557,7 +561,7 @@ func TestUnknownInvoice(t *testing.T) {
// TestSettleMpp tests settling of an invoice with multiple partial payments.
func TestSettleMpp(t *testing.T) {
defer timeout()
defer timeout()()
ctx := newTestContext(t)
defer ctx.cleanup()
@ -636,3 +640,105 @@ func TestSettleMpp(t *testing.T) {
testInvoice.Terms.Value, inv.AmtPaid)
}
}
// Tests that invoices are canceled after expiration.
func TestInvoiceExpiryWithRegistry(t *testing.T) {
t.Parallel()
cdb, cleanup, err := newTestChannelDB()
defer cleanup()
if err != nil {
t.Fatal(err)
}
testClock := clock.NewTestClock(testTime)
cfg := RegistryConfig{
FinalCltvRejectDelta: testFinalCltvRejectDelta,
Clock: testClock,
}
expiryWatcher := NewInvoiceExpiryWatcher(cfg.Clock)
registry := NewRegistry(cdb, expiryWatcher, &cfg)
// First prefill the Channel DB with some pre-existing invoices,
// half of them still pending, half of them expired.
const numExpired = 5
const numPending = 5
existingInvoices := generateInvoiceExpiryTestData(
t, testTime, 0, numExpired, numPending,
)
var expectedCancellations []lntypes.Hash
for paymentHash, expiredInvoice := range existingInvoices.expiredInvoices {
if _, err := cdb.AddInvoice(expiredInvoice, paymentHash); err != nil {
t.Fatalf("cannot add invoice to channel db: %v", err)
}
expectedCancellations = append(expectedCancellations, paymentHash)
}
for paymentHash, pendingInvoice := range existingInvoices.pendingInvoices {
if _, err := cdb.AddInvoice(pendingInvoice, paymentHash); err != nil {
t.Fatalf("cannot add invoice to channel db: %v", err)
}
}
if err = registry.Start(); err != nil {
t.Fatalf("cannot start registry: %v", err)
}
// Now generate pending and invoices and add them to the registry while
// it is up and running. We'll manipulate the clock to let them expire.
newInvoices := generateInvoiceExpiryTestData(
t, testTime, numExpired+numPending, 0, numPending,
)
var invoicesThatWillCancel []lntypes.Hash
for paymentHash, pendingInvoice := range newInvoices.pendingInvoices {
_, err := registry.AddInvoice(pendingInvoice, paymentHash)
invoicesThatWillCancel = append(invoicesThatWillCancel, paymentHash)
if err != nil {
t.Fatal(err)
}
}
// Check that they are really not canceled until before the clock is
// advanced.
for i := range invoicesThatWillCancel {
invoice, err := registry.LookupInvoice(invoicesThatWillCancel[i])
if err != nil {
t.Fatalf("cannot find invoice: %v", err)
}
if invoice.State == channeldb.ContractCanceled {
t.Fatalf("expected pending invoice, got canceled")
}
}
// Fwd time 1 day.
testClock.SetTime(testTime.Add(24 * time.Hour))
// Give some time to the watcher to cancel everything.
time.Sleep(testTimeout)
registry.Stop()
// Create the expected cancellation set before the final check.
expectedCancellations = append(
expectedCancellations, invoicesThatWillCancel...,
)
// Retrospectively check that all invoices that were expected to be canceled
// are indeed canceled.
for i := range expectedCancellations {
invoice, err := registry.LookupInvoice(expectedCancellations[i])
if err != nil {
t.Fatalf("cannot find invoice: %v", err)
}
if invoice.State != channeldb.ContractCanceled {
t.Fatalf("expected canceled invoice, got: %v", invoice.State)
}
}
}

@ -1,6 +1,7 @@
package invoices
import (
"encoding/binary"
"encoding/hex"
"fmt"
"io/ioutil"
@ -51,7 +52,7 @@ var (
testPrivKeyBytes, _ = hex.DecodeString(
"e126f68f7eafcc8b74f54d269fe206be715000f94dac067d1c04a8ca3b2db734")
testPrivKey, testPubKey = btcec.PrivKeyFromBytes(
testPrivKey, _ = btcec.PrivKeyFromBytes(
btcec.S256(), testPrivKeyBytes)
testInvoiceDescription = "coffee"
@ -75,6 +76,8 @@ var (
)
testPayload = &mockPayload{}
testInvoiceCreationDate = testTime
)
var (
@ -83,16 +86,20 @@ var (
Terms: channeldb.ContractTerm{
PaymentPreimage: testInvoicePreimage,
Value: testInvoiceAmt,
Expiry: time.Hour,
Features: testFeatures,
},
CreationDate: testInvoiceCreationDate,
}
testHodlInvoice = &channeldb.Invoice{
Terms: channeldb.ContractTerm{
PaymentPreimage: channeldb.UnknownPreimage,
Value: testInvoiceAmt,
Expiry: time.Hour,
Features: testFeatures,
},
CreationDate: testInvoiceCreationDate,
}
)
@ -120,6 +127,7 @@ func newTestChannelDB() (*channeldb.DB, func(), error) {
}
type testContext struct {
cdb *channeldb.DB
registry *InvoiceRegistry
clock *clock.TestClock
@ -136,13 +144,15 @@ func newTestContext(t *testing.T) *testContext {
}
cdb.Now = clock.Now
expiryWatcher := NewInvoiceExpiryWatcher(clock)
// Instantiate and start the invoice ctx.registry.
cfg := RegistryConfig{
FinalCltvRejectDelta: testFinalCltvRejectDelta,
HtlcHoldDuration: 30 * time.Second,
Clock: clock,
}
registry := NewRegistry(cdb, &cfg)
registry := NewRegistry(cdb, expiryWatcher, &cfg)
err = registry.Start()
if err != nil {
@ -151,6 +161,7 @@ func newTestContext(t *testing.T) *testContext {
}
ctx := testContext{
cdb: cdb,
registry: registry,
clock: clock,
t: t,
@ -172,7 +183,7 @@ func getCircuitKey(htlcID uint64) channeldb.CircuitKey {
}
}
func newTestInvoice(t *testing.T,
func newTestInvoice(t *testing.T, preimage lntypes.Preimage,
timestamp time.Time, expiry time.Duration) *channeldb.Invoice {
if expiry == 0 {
@ -181,7 +192,7 @@ func newTestInvoice(t *testing.T,
rawInvoice, err := zpay32.NewInvoice(
testNetParams,
testInvoicePaymentHash,
preimage.Hash(),
timestamp,
zpay32.Amount(testInvoiceAmount),
zpay32.Description(testInvoiceDescription),
@ -199,7 +210,7 @@ func newTestInvoice(t *testing.T,
return &channeldb.Invoice{
Terms: channeldb.ContractTerm{
PaymentPreimage: testInvoicePreimage,
PaymentPreimage: preimage,
Value: testInvoiceAmount,
Expiry: expiry,
Features: testFeatures,
@ -229,3 +240,41 @@ func timeout() func() {
close(done)
}
}
// invoiceExpiryTestData simply holds generated expired and pending invoices.
type invoiceExpiryTestData struct {
expiredInvoices map[lntypes.Hash]*channeldb.Invoice
pendingInvoices map[lntypes.Hash]*channeldb.Invoice
}
// generateInvoiceExpiryTestData generates the specified number of fake expired
// and pending invoices anchored to the passed now timestamp.
func generateInvoiceExpiryTestData(
t *testing.T, now time.Time,
offset, numExpired, numPending int) invoiceExpiryTestData {
var testData invoiceExpiryTestData
testData.expiredInvoices = make(map[lntypes.Hash]*channeldb.Invoice)
testData.pendingInvoices = make(map[lntypes.Hash]*channeldb.Invoice)
expiredCreationDate := now.Add(-24 * time.Hour)
for i := 1; i <= numExpired; i++ {
var preimage lntypes.Preimage
binary.BigEndian.PutUint32(preimage[:4], uint32(offset+i))
expiry := time.Duration((i+offset)%24) * time.Hour
invoice := newTestInvoice(t, preimage, expiredCreationDate, expiry)
testData.expiredInvoices[preimage.Hash()] = invoice
}
for i := 1; i <= numPending; i++ {
var preimage lntypes.Preimage
binary.BigEndian.PutUint32(preimage[4:], uint32(offset+i))
expiry := time.Duration((i+offset)%24) * time.Hour
invoice := newTestInvoice(t, preimage, now, expiry)
testData.pendingInvoices[preimage.Hash()] = invoice
}
return testData
}

@ -393,7 +393,10 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB,
readPool: readPool,
chansToRestore: chansToRestore,
invoices: invoices.NewRegistry(chanDB, &registryConfig),
invoices: invoices.NewRegistry(
chanDB, invoices.NewInvoiceExpiryWatcher(clock.NewDefaultClock()),
&registryConfig,
),
channelNotifier: channelnotifier.New(chanDB),

@ -82,6 +82,10 @@ const (
// This is chosen to be the maximum number of bytes that can fit into a
// single QR code: https://en.wikipedia.org/wiki/QR_code#Storage
maxInvoiceLength = 7089
// DefaultInvoiceExpiry is the default expiry duration from the creation
// timestamp if expiry is set to zero.
DefaultInvoiceExpiry = time.Hour
)
var (