diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 8ded522f..14e6c451 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -565,6 +565,83 @@ func (d *DB) LookupInvoice(paymentHash [32]byte) (Invoice, error) { return invoice, nil } +// InvoiceWithPaymentHash is used to store an invoice and its corresponding +// payment hash. This struct is only used to store results of +// ChannelDB.FetchAllInvoicesWithPaymentHash() call. +type InvoiceWithPaymentHash struct { + // Invoice holds the invoice as selected from the invoices bucket. + Invoice Invoice + + // PaymentHash is the payment hash for the Invoice. + PaymentHash lntypes.Hash +} + +// FetchAllInvoicesWithPaymentHash returns all invoices and their payment hashes +// currently stored within the database. If the pendingOnly param is true, then +// only unsettled invoices and their payment hashes will be returned, skipping +// all invoices that are fully settled or canceled. Note that the returned +// array is not ordered by add index. +func (d *DB) FetchAllInvoicesWithPaymentHash(pendingOnly bool) ( + []InvoiceWithPaymentHash, error) { + + var result []InvoiceWithPaymentHash + + err := d.View(func(tx *bbolt.Tx) error { + invoices := tx.Bucket(invoiceBucket) + if invoices == nil { + return ErrNoInvoicesCreated + } + + invoiceIndex := invoices.Bucket(invoiceIndexBucket) + if invoiceIndex == nil { + // Mask the error if there's no invoice + // index as that simply means there are no + // invoices added yet to the DB. In this case + // we simply return an empty list. + return nil + } + + return invoiceIndex.ForEach(func(k, v []byte) error { + // Skip the special numInvoicesKey as that does not + // point to a valid invoice. + if bytes.Equal(k, numInvoicesKey) { + return nil + } + + if v == nil { + return nil + } + + invoice, err := fetchInvoice(v, invoices) + if err != nil { + return err + } + + if pendingOnly && + (invoice.State == ContractSettled || + invoice.State == ContractCanceled) { + + return nil + } + + invoiceWithPaymentHash := InvoiceWithPaymentHash{ + Invoice: invoice, + } + + copy(invoiceWithPaymentHash.PaymentHash[:], k) + result = append(result, invoiceWithPaymentHash) + + return nil + }) + }) + + if err != nil { + return nil, err + } + + return result, nil +} + // FetchAllInvoices returns all invoices currently stored within the database. // If the pendingOnly param is true, then only unsettled invoices will be // returned, skipping all invoices that are fully settled. diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index d3d61ca9..1d0731e8 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -22,6 +22,7 @@ import ( sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/input" @@ -792,6 +793,7 @@ func newMockRegistry(minDelta uint32) *mockInvoiceRegistry { registry := invoices.NewRegistry( cdb, + invoices.NewInvoiceExpiryWatcher(clock.NewDefaultClock()), &invoices.RegistryConfig{ FinalCltvRejectDelta: 5, }, diff --git a/invoices/invoice_expiry_watcher.go b/invoices/invoice_expiry_watcher.go new file mode 100644 index 00000000..db9d9cd5 --- /dev/null +++ b/invoices/invoice_expiry_watcher.go @@ -0,0 +1,191 @@ +package invoices + +import ( + "fmt" + "sync" + "time" + + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/queue" + "github.com/lightningnetwork/lnd/zpay32" +) + +// invoiceExpiry holds and invoice's payment hash and its expiry. This +// is used to order invoices by their expiry for cancellation. +type invoiceExpiry struct { + PaymentHash lntypes.Hash + Expiry time.Time +} + +// Less implements PriorityQueueItem.Less such that the top item in the +// priorty queue will be the one that expires next. +func (e invoiceExpiry) Less(other queue.PriorityQueueItem) bool { + return e.Expiry.Before(other.(*invoiceExpiry).Expiry) +} + +// InvoiceExpiryWatcher handles automatic invoice cancellation of expried +// invoices. Upon start InvoiceExpiryWatcher will retrieve all pending (not yet +// settled or canceled) invoices invoices to its watcing queue. When a new +// invoice is added to the InvoiceRegistry, it'll be forarded to the +// InvoiceExpiryWatcher and will end up in the watching queue as well. +// If any of the watched invoices expire, they'll be removed from the watching +// queue and will be cancelled through InvoiceRegistry.CancelInvoice(). +type InvoiceExpiryWatcher struct { + sync.Mutex + started bool + + // clock is the clock implementation that InvoiceExpiryWatcher uses. + // It is useful for testing. + clock clock.Clock + + // cancelInvoice is a template method that cancels an expired invoice. + cancelInvoice func(lntypes.Hash) error + + // expiryQueue holds invoiceExpiry items and is used to find the next + // invoice to expire. + expiryQueue queue.PriorityQueue + + // newInvoices channel is used to wake up the main loop when a new invoices + // is added. + newInvoices chan *invoiceExpiry + + wg sync.WaitGroup + + // quit signals InvoiceExpiryWatcher to stop. + quit chan struct{} +} + +// NewInvoiceExpiryWatcher creates a new InvoiceExpiryWatcher instance. +func NewInvoiceExpiryWatcher(clock clock.Clock) *InvoiceExpiryWatcher { + return &InvoiceExpiryWatcher{ + clock: clock, + newInvoices: make(chan *invoiceExpiry), + quit: make(chan struct{}), + } +} + +// Start starts the the subscription handler and the main loop. Start() will +// return with error if InvoiceExpiryWatcher is already started. Start() +// expects a cancellation function passed that will be use to cancel expired +// invoices by their payment hash. +func (ew *InvoiceExpiryWatcher) Start( + cancelInvoice func(lntypes.Hash) error) error { + + ew.Lock() + defer ew.Unlock() + + if ew.started { + return fmt.Errorf("InvoiceExpiryWatcher already started") + } + + ew.started = true + ew.cancelInvoice = cancelInvoice + ew.wg.Add(1) + go ew.mainLoop() + + return nil +} + +// Stop quits the expiry handler loop and waits for InvoiceExpiryWatcher to +// fully stop. +func (ew *InvoiceExpiryWatcher) Stop() { + ew.Lock() + defer ew.Unlock() + + if ew.started { + // Signal subscriptionHandler to quit and wait for it to return. + close(ew.quit) + ew.wg.Wait() + ew.started = false + } +} + +// AddInvoice adds a new invoice to the InvoiceExpiryWatcher. This won't check +// if the invoice is already added and will only add invoices with ContractOpen +// state. +func (ew *InvoiceExpiryWatcher) AddInvoice( + paymentHash lntypes.Hash, invoice *channeldb.Invoice) { + + if invoice.State != channeldb.ContractOpen { + log.Debugf("Invoice not added to expiry watcher: %v", invoice) + return + } + + realExpiry := invoice.Terms.Expiry + if realExpiry == 0 { + realExpiry = zpay32.DefaultInvoiceExpiry + } + + expiry := invoice.CreationDate.Add(realExpiry) + + log.Debugf("Adding invoice '%v' to expiry watcher, expiration: %v", + paymentHash, expiry) + + newInvoiceExpiry := &invoiceExpiry{ + PaymentHash: paymentHash, + Expiry: expiry, + } + + select { + case ew.newInvoices <- newInvoiceExpiry: + case <-ew.quit: + // Select on quit too so that callers won't get blocked in case + // of concurrent shutdown. + } +} + +// nextExpiry returns a Time chan to wait on until the next invoice expires. +// If there are no active invoices, then it'll simply wait indefinitely. +func (ew *InvoiceExpiryWatcher) nextExpiry() <-chan time.Time { + if !ew.expiryQueue.Empty() { + top := ew.expiryQueue.Top().(*invoiceExpiry) + return ew.clock.TickAfter(top.Expiry.Sub(ew.clock.Now())) + } + + return nil +} + +// cancelExpiredInvoices will cancel all expired invoices and removes them from +// the expiry queue. +func (ew *InvoiceExpiryWatcher) cancelExpiredInvoices() { + for !ew.expiryQueue.Empty() { + top := ew.expiryQueue.Top().(*invoiceExpiry) + if !top.Expiry.Before(ew.clock.Now()) { + break + } + + err := ew.cancelInvoice(top.PaymentHash) + if err != nil && err != channeldb.ErrInvoiceAlreadySettled && + err != channeldb.ErrInvoiceAlreadyCanceled { + + log.Errorf("Unable to cancel invoice: %v", top.PaymentHash) + } + + ew.expiryQueue.Pop() + } +} + +// mainLoop is a goroutine that receives new invoices and handles cancellation +// of expired invoices. +func (ew *InvoiceExpiryWatcher) mainLoop() { + defer ew.wg.Done() + + for { + // Cancel any invoices that may have expired. + ew.cancelExpiredInvoices() + + select { + case <-ew.nextExpiry(): + // Wait until the next invoice expires, then cancel expired invoices. + continue + + case newInvoiceExpiry := <-ew.newInvoices: + ew.expiryQueue.Push(newInvoiceExpiry) + + case <-ew.quit: + return + } + } +} diff --git a/invoices/invoice_expiry_watcher_test.go b/invoices/invoice_expiry_watcher_test.go new file mode 100644 index 00000000..8bfdfd69 --- /dev/null +++ b/invoices/invoice_expiry_watcher_test.go @@ -0,0 +1,125 @@ +package invoices + +import ( + "testing" + "time" + + "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/lntypes" +) + +// invoiceExpiryWatcherTest holds a test fixture and implements checks +// for InvoiceExpiryWatcher tests. +type invoiceExpiryWatcherTest struct { + t *testing.T + watcher *InvoiceExpiryWatcher + testData invoiceExpiryTestData + canceledInvoices []lntypes.Hash +} + +// newInvoiceExpiryWatcherTest creates a new InvoiceExpiryWatcher test fixture +// and sets up the test environment. +func newInvoiceExpiryWatcherTest(t *testing.T, now time.Time, + numExpiredInvoices, numPendingInvoices int) *invoiceExpiryWatcherTest { + + test := &invoiceExpiryWatcherTest{ + watcher: NewInvoiceExpiryWatcher(clock.NewTestClock(testTime)), + testData: generateInvoiceExpiryTestData( + t, now, 0, numExpiredInvoices, numPendingInvoices, + ), + } + + err := test.watcher.Start(func(paymentHash lntypes.Hash) error { + test.canceledInvoices = append(test.canceledInvoices, paymentHash) + return nil + }) + + if err != nil { + t.Fatalf("cannot start InvoiceExpiryWatcher: %v", err) + } + + return test +} + +func (t *invoiceExpiryWatcherTest) checkExpectations() { + // Check that invoices that got canceled during the test are the ones + // that expired. + if len(t.canceledInvoices) != len(t.testData.expiredInvoices) { + t.t.Fatalf("expected %v cancellations, got %v", + len(t.testData.expiredInvoices), len(t.canceledInvoices)) + } + + for i := range t.canceledInvoices { + if _, ok := t.testData.expiredInvoices[t.canceledInvoices[i]]; !ok { + t.t.Fatalf("wrong invoice canceled") + } + } +} + +// Tests that InvoiceExpiryWatcher can be started and stopped. +func TestInvoiceExpiryWatcherStartStop(t *testing.T) { + watcher := NewInvoiceExpiryWatcher(clock.NewTestClock(testTime)) + cancel := func(lntypes.Hash) error { + t.Fatalf("unexpected call") + return nil + } + + if err := watcher.Start(cancel); err != nil { + t.Fatalf("unexpected error upon start: %v", err) + } + + if err := watcher.Start(cancel); err == nil { + t.Fatalf("expected error upon second start") + } + + watcher.Stop() + + if err := watcher.Start(cancel); err != nil { + t.Fatalf("unexpected error upon start: %v", err) + } +} + +// Tests that no invoices will expire from an empty InvoiceExpiryWatcher. +func TestInvoiceExpiryWithNoInvoices(t *testing.T) { + t.Parallel() + test := newInvoiceExpiryWatcherTest(t, testTime, 0, 0) + + time.Sleep(testTimeout) + test.watcher.Stop() + test.checkExpectations() +} + +// Tests that if all add invoices are expired, then all invoices +// will be canceled. +func TestInvoiceExpiryWithOnlyExpiredInvoices(t *testing.T) { + t.Parallel() + + test := newInvoiceExpiryWatcherTest(t, testTime, 0, 5) + + for paymentHash, invoice := range test.testData.pendingInvoices { + test.watcher.AddInvoice(paymentHash, invoice) + } + + time.Sleep(testTimeout) + test.watcher.Stop() + test.checkExpectations() +} + +// Tests that if some invoices are expired, then those invoices +// will be canceled. +func TestInvoiceExpiryWithPendingAndExpiredInvoices(t *testing.T) { + t.Parallel() + test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5) + + for paymentHash, invoice := range test.testData.expiredInvoices { + test.watcher.AddInvoice(paymentHash, invoice) + } + + for paymentHash, invoice := range test.testData.pendingInvoices { + test.watcher.AddInvoice(paymentHash, invoice) + } + + time.Sleep(testTimeout) + test.watcher.Stop() + test.checkExpectations() +} diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index 16f4f400..d4e77ad5 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -125,6 +125,8 @@ type InvoiceRegistry struct { // auto-released. htlcAutoReleaseChan chan *htlcReleaseEvent + expiryWatcher *InvoiceExpiryWatcher + wg sync.WaitGroup quit chan struct{} } @@ -133,7 +135,9 @@ type InvoiceRegistry struct { // wraps the persistent on-disk invoice storage with an additional in-memory // layer. The in-memory layer is in place such that debug invoices can be added // which are volatile yet available system wide within the daemon. -func NewRegistry(cdb *channeldb.DB, cfg *RegistryConfig) *InvoiceRegistry { +func NewRegistry(cdb *channeldb.DB, expiryWatcher *InvoiceExpiryWatcher, + cfg *RegistryConfig) *InvoiceRegistry { + return &InvoiceRegistry{ cdb: cdb, notificationClients: make(map[uint32]*InvoiceSubscription), @@ -145,21 +149,62 @@ func NewRegistry(cdb *channeldb.DB, cfg *RegistryConfig) *InvoiceRegistry { hodlReverseSubscriptions: make(map[chan<- interface{}]map[channeldb.CircuitKey]struct{}), cfg: cfg, htlcAutoReleaseChan: make(chan *htlcReleaseEvent), + expiryWatcher: expiryWatcher, quit: make(chan struct{}), } } +// populateExpiryWatcher fetches all active invoices and their corresponding +// payment hashes from ChannelDB and adds them to the expiry watcher. +func (i *InvoiceRegistry) populateExpiryWatcher() error { + pendingOnly := true + pendingInvoices, err := i.cdb.FetchAllInvoicesWithPaymentHash(pendingOnly) + if err != nil && err != channeldb.ErrNoInvoicesCreated { + log.Errorf( + "Error while prefetching active invoices from the database: %v", err, + ) + return err + } + + for idx := range pendingInvoices { + i.expiryWatcher.AddInvoice( + pendingInvoices[idx].PaymentHash, &pendingInvoices[idx].Invoice, + ) + } + + return nil +} + // Start starts the registry and all goroutines it needs to carry out its task. func (i *InvoiceRegistry) Start() error { - i.wg.Add(1) + // Start InvoiceExpiryWatcher and prepopulate it with existing active + // invoices. + err := i.expiryWatcher.Start(func(paymentHash lntypes.Hash) error { + cancelIfAccepted := false + return i.cancelInvoiceImpl(paymentHash, cancelIfAccepted) + }) + if err != nil { + return err + } + + i.wg.Add(1) go i.invoiceEventLoop() + // Now prefetch all pending invoices to the expiry watcher. + err = i.populateExpiryWatcher() + if err != nil { + i.Stop() + return err + } + return nil } // Stop signals the registry for a graceful shutdown. func (i *InvoiceRegistry) Stop() { + i.expiryWatcher.Stop() + close(i.quit) i.wg.Wait() @@ -470,7 +515,6 @@ func (i *InvoiceRegistry) AddInvoice(invoice *channeldb.Invoice, paymentHash lntypes.Hash) (uint64, error) { i.Lock() - defer i.Unlock() log.Debugf("Invoice(%v): added %v", paymentHash, newLogClosure(func() string { @@ -480,12 +524,19 @@ func (i *InvoiceRegistry) AddInvoice(invoice *channeldb.Invoice, addIndex, err := i.cdb.AddInvoice(invoice, paymentHash) if err != nil { + i.Unlock() return 0, err } // Now that we've added the invoice, we'll send dispatch a message to // notify the clients of this new invoice. i.notifyClients(paymentHash, invoice, channeldb.ContractOpen) + i.Unlock() + + // InvoiceExpiryWatcher.AddInvoice must not be locked by InvoiceRegistry + // to avoid deadlock when a new invoice is added while an other is being + // canceled. + i.expiryWatcher.AddInvoice(paymentHash, invoice) return addIndex, nil } @@ -817,6 +868,15 @@ func (i *InvoiceRegistry) SettleHodlInvoice(preimage lntypes.Preimage) error { // CancelInvoice attempts to cancel the invoice corresponding to the passed // payment hash. func (i *InvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error { + return i.cancelInvoiceImpl(payHash, true) +} + +// cancelInvoice attempts to cancel the invoice corresponding to the passed +// payment hash. Accepted invoices will only be canceled if explicitly +// requested to do so. +func (i *InvoiceRegistry) cancelInvoiceImpl(payHash lntypes.Hash, + cancelAccepted bool) error { + i.Lock() defer i.Unlock() @@ -825,6 +885,12 @@ func (i *InvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error { updateInvoice := func(invoice *channeldb.Invoice) ( *channeldb.InvoiceUpdateDesc, error) { + // Only cancel the invoice in ContractAccepted state if explicitly + // requested to do so. + if invoice.State == channeldb.ContractAccepted && !cancelAccepted { + return nil, nil + } + // Move invoice to the canceled state. Rely on validation in // channeldb to return an error if the invoice is already // settled or canceled. @@ -847,6 +913,13 @@ func (i *InvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error { return err } + // Return without cancellation if the invoice state is ContractAccepted. + if invoice.State == channeldb.ContractAccepted { + log.Debugf("Invoice(%v): remains accepted as cancel wasn't"+ + "explicitly requested.", payHash) + return nil + } + log.Debugf("Invoice(%v): canceled", payHash) // In the callback, some htlcs may have been moved to the canceled diff --git a/invoices/invoiceregistry_test.go b/invoices/invoiceregistry_test.go index fbd9a260..e47eaa37 100644 --- a/invoices/invoiceregistry_test.go +++ b/invoices/invoiceregistry_test.go @@ -5,6 +5,8 @@ import ( "time" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" ) @@ -301,8 +303,9 @@ func TestSettleHoldInvoice(t *testing.T) { // Instantiate and start the invoice ctx.registry. cfg := RegistryConfig{ FinalCltvRejectDelta: testFinalCltvRejectDelta, + Clock: clock.NewTestClock(testTime), } - registry := NewRegistry(cdb, &cfg) + registry := NewRegistry(cdb, NewInvoiceExpiryWatcher(cfg.Clock), &cfg) err = registry.Start() if err != nil { @@ -461,7 +464,7 @@ func TestSettleHoldInvoice(t *testing.T) { // TestCancelHoldInvoice tests canceling of a hold invoice and related // notifications. func TestCancelHoldInvoice(t *testing.T) { - defer timeout() + defer timeout()() cdb, cleanup, err := newTestChannelDB() if err != nil { @@ -472,8 +475,9 @@ func TestCancelHoldInvoice(t *testing.T) { // Instantiate and start the invoice ctx.registry. cfg := RegistryConfig{ FinalCltvRejectDelta: testFinalCltvRejectDelta, + Clock: clock.NewTestClock(testTime), } - registry := NewRegistry(cdb, &cfg) + registry := NewRegistry(cdb, NewInvoiceExpiryWatcher(cfg.Clock), &cfg) err = registry.Start() if err != nil { @@ -557,7 +561,7 @@ func TestUnknownInvoice(t *testing.T) { // TestSettleMpp tests settling of an invoice with multiple partial payments. func TestSettleMpp(t *testing.T) { - defer timeout() + defer timeout()() ctx := newTestContext(t) defer ctx.cleanup() @@ -636,3 +640,105 @@ func TestSettleMpp(t *testing.T) { testInvoice.Terms.Value, inv.AmtPaid) } } + +// Tests that invoices are canceled after expiration. +func TestInvoiceExpiryWithRegistry(t *testing.T) { + t.Parallel() + + cdb, cleanup, err := newTestChannelDB() + defer cleanup() + + if err != nil { + t.Fatal(err) + } + + testClock := clock.NewTestClock(testTime) + + cfg := RegistryConfig{ + FinalCltvRejectDelta: testFinalCltvRejectDelta, + Clock: testClock, + } + + expiryWatcher := NewInvoiceExpiryWatcher(cfg.Clock) + registry := NewRegistry(cdb, expiryWatcher, &cfg) + + // First prefill the Channel DB with some pre-existing invoices, + // half of them still pending, half of them expired. + const numExpired = 5 + const numPending = 5 + existingInvoices := generateInvoiceExpiryTestData( + t, testTime, 0, numExpired, numPending, + ) + + var expectedCancellations []lntypes.Hash + + for paymentHash, expiredInvoice := range existingInvoices.expiredInvoices { + if _, err := cdb.AddInvoice(expiredInvoice, paymentHash); err != nil { + t.Fatalf("cannot add invoice to channel db: %v", err) + } + expectedCancellations = append(expectedCancellations, paymentHash) + } + + for paymentHash, pendingInvoice := range existingInvoices.pendingInvoices { + if _, err := cdb.AddInvoice(pendingInvoice, paymentHash); err != nil { + t.Fatalf("cannot add invoice to channel db: %v", err) + } + } + + if err = registry.Start(); err != nil { + t.Fatalf("cannot start registry: %v", err) + } + + // Now generate pending and invoices and add them to the registry while + // it is up and running. We'll manipulate the clock to let them expire. + newInvoices := generateInvoiceExpiryTestData( + t, testTime, numExpired+numPending, 0, numPending, + ) + + var invoicesThatWillCancel []lntypes.Hash + for paymentHash, pendingInvoice := range newInvoices.pendingInvoices { + _, err := registry.AddInvoice(pendingInvoice, paymentHash) + invoicesThatWillCancel = append(invoicesThatWillCancel, paymentHash) + if err != nil { + t.Fatal(err) + } + } + + // Check that they are really not canceled until before the clock is + // advanced. + for i := range invoicesThatWillCancel { + invoice, err := registry.LookupInvoice(invoicesThatWillCancel[i]) + if err != nil { + t.Fatalf("cannot find invoice: %v", err) + } + + if invoice.State == channeldb.ContractCanceled { + t.Fatalf("expected pending invoice, got canceled") + } + } + + // Fwd time 1 day. + testClock.SetTime(testTime.Add(24 * time.Hour)) + + // Give some time to the watcher to cancel everything. + time.Sleep(testTimeout) + registry.Stop() + + // Create the expected cancellation set before the final check. + expectedCancellations = append( + expectedCancellations, invoicesThatWillCancel..., + ) + + // Retrospectively check that all invoices that were expected to be canceled + // are indeed canceled. + for i := range expectedCancellations { + invoice, err := registry.LookupInvoice(expectedCancellations[i]) + if err != nil { + t.Fatalf("cannot find invoice: %v", err) + } + + if invoice.State != channeldb.ContractCanceled { + t.Fatalf("expected canceled invoice, got: %v", invoice.State) + } + } +} diff --git a/invoices/test_utils_test.go b/invoices/test_utils_test.go index 52e6044f..d6298543 100644 --- a/invoices/test_utils_test.go +++ b/invoices/test_utils_test.go @@ -1,6 +1,7 @@ package invoices import ( + "encoding/binary" "encoding/hex" "fmt" "io/ioutil" @@ -51,7 +52,7 @@ var ( testPrivKeyBytes, _ = hex.DecodeString( "e126f68f7eafcc8b74f54d269fe206be715000f94dac067d1c04a8ca3b2db734") - testPrivKey, testPubKey = btcec.PrivKeyFromBytes( + testPrivKey, _ = btcec.PrivKeyFromBytes( btcec.S256(), testPrivKeyBytes) testInvoiceDescription = "coffee" @@ -75,6 +76,8 @@ var ( ) testPayload = &mockPayload{} + + testInvoiceCreationDate = testTime ) var ( @@ -83,16 +86,20 @@ var ( Terms: channeldb.ContractTerm{ PaymentPreimage: testInvoicePreimage, Value: testInvoiceAmt, + Expiry: time.Hour, Features: testFeatures, }, + CreationDate: testInvoiceCreationDate, } testHodlInvoice = &channeldb.Invoice{ Terms: channeldb.ContractTerm{ PaymentPreimage: channeldb.UnknownPreimage, Value: testInvoiceAmt, + Expiry: time.Hour, Features: testFeatures, }, + CreationDate: testInvoiceCreationDate, } ) @@ -120,6 +127,7 @@ func newTestChannelDB() (*channeldb.DB, func(), error) { } type testContext struct { + cdb *channeldb.DB registry *InvoiceRegistry clock *clock.TestClock @@ -136,13 +144,15 @@ func newTestContext(t *testing.T) *testContext { } cdb.Now = clock.Now + expiryWatcher := NewInvoiceExpiryWatcher(clock) + // Instantiate and start the invoice ctx.registry. cfg := RegistryConfig{ FinalCltvRejectDelta: testFinalCltvRejectDelta, HtlcHoldDuration: 30 * time.Second, Clock: clock, } - registry := NewRegistry(cdb, &cfg) + registry := NewRegistry(cdb, expiryWatcher, &cfg) err = registry.Start() if err != nil { @@ -151,6 +161,7 @@ func newTestContext(t *testing.T) *testContext { } ctx := testContext{ + cdb: cdb, registry: registry, clock: clock, t: t, @@ -172,7 +183,7 @@ func getCircuitKey(htlcID uint64) channeldb.CircuitKey { } } -func newTestInvoice(t *testing.T, +func newTestInvoice(t *testing.T, preimage lntypes.Preimage, timestamp time.Time, expiry time.Duration) *channeldb.Invoice { if expiry == 0 { @@ -181,7 +192,7 @@ func newTestInvoice(t *testing.T, rawInvoice, err := zpay32.NewInvoice( testNetParams, - testInvoicePaymentHash, + preimage.Hash(), timestamp, zpay32.Amount(testInvoiceAmount), zpay32.Description(testInvoiceDescription), @@ -199,7 +210,7 @@ func newTestInvoice(t *testing.T, return &channeldb.Invoice{ Terms: channeldb.ContractTerm{ - PaymentPreimage: testInvoicePreimage, + PaymentPreimage: preimage, Value: testInvoiceAmount, Expiry: expiry, Features: testFeatures, @@ -229,3 +240,41 @@ func timeout() func() { close(done) } } + +// invoiceExpiryTestData simply holds generated expired and pending invoices. +type invoiceExpiryTestData struct { + expiredInvoices map[lntypes.Hash]*channeldb.Invoice + pendingInvoices map[lntypes.Hash]*channeldb.Invoice +} + +// generateInvoiceExpiryTestData generates the specified number of fake expired +// and pending invoices anchored to the passed now timestamp. +func generateInvoiceExpiryTestData( + t *testing.T, now time.Time, + offset, numExpired, numPending int) invoiceExpiryTestData { + + var testData invoiceExpiryTestData + + testData.expiredInvoices = make(map[lntypes.Hash]*channeldb.Invoice) + testData.pendingInvoices = make(map[lntypes.Hash]*channeldb.Invoice) + + expiredCreationDate := now.Add(-24 * time.Hour) + + for i := 1; i <= numExpired; i++ { + var preimage lntypes.Preimage + binary.BigEndian.PutUint32(preimage[:4], uint32(offset+i)) + expiry := time.Duration((i+offset)%24) * time.Hour + invoice := newTestInvoice(t, preimage, expiredCreationDate, expiry) + testData.expiredInvoices[preimage.Hash()] = invoice + } + + for i := 1; i <= numPending; i++ { + var preimage lntypes.Preimage + binary.BigEndian.PutUint32(preimage[4:], uint32(offset+i)) + expiry := time.Duration((i+offset)%24) * time.Hour + invoice := newTestInvoice(t, preimage, now, expiry) + testData.pendingInvoices[preimage.Hash()] = invoice + } + + return testData +} diff --git a/server.go b/server.go index edf90e76..320c3e68 100644 --- a/server.go +++ b/server.go @@ -393,7 +393,10 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, readPool: readPool, chansToRestore: chansToRestore, - invoices: invoices.NewRegistry(chanDB, ®istryConfig), + invoices: invoices.NewRegistry( + chanDB, invoices.NewInvoiceExpiryWatcher(clock.NewDefaultClock()), + ®istryConfig, + ), channelNotifier: channelnotifier.New(chanDB), diff --git a/zpay32/invoice.go b/zpay32/invoice.go index ab871125..3afb4cda 100644 --- a/zpay32/invoice.go +++ b/zpay32/invoice.go @@ -82,6 +82,10 @@ const ( // This is chosen to be the maximum number of bytes that can fit into a // single QR code: https://en.wikipedia.org/wiki/QR_code#Storage maxInvoiceLength = 7089 + + // DefaultInvoiceExpiry is the default expiry duration from the creation + // timestamp if expiry is set to zero. + DefaultInvoiceExpiry = time.Hour ) var (