package invoices import ( "sync" "testing" "time" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lntypes" ) // invoiceExpiryWatcherTest holds a test fixture and implements checks // for InvoiceExpiryWatcher tests. type invoiceExpiryWatcherTest struct { t *testing.T wg sync.WaitGroup watcher *InvoiceExpiryWatcher testData invoiceExpiryTestData canceledInvoices []lntypes.Hash } type mockChainNotifier struct { chainntnfs.ChainNotifier blockChan chan *chainntnfs.BlockEpoch } func newMockNotifier() *mockChainNotifier { return &mockChainNotifier{ blockChan: make(chan *chainntnfs.BlockEpoch), } } // RegisterBlockEpochNtfn mocks a block epoch notification, using the mock's // block channel to deliver blocks to the client. func (m *mockChainNotifier) RegisterBlockEpochNtfn(*chainntnfs.BlockEpoch) ( *chainntnfs.BlockEpochEvent, error) { return &chainntnfs.BlockEpochEvent{ Epochs: m.blockChan, Cancel: func() {}, }, nil } // newInvoiceExpiryWatcherTest creates a new InvoiceExpiryWatcher test fixture // and sets up the test environment. func newInvoiceExpiryWatcherTest(t *testing.T, now time.Time, numExpiredInvoices, numPendingInvoices int) *invoiceExpiryWatcherTest { mockNotifier := newMockNotifier() test := &invoiceExpiryWatcherTest{ watcher: NewInvoiceExpiryWatcher( clock.NewTestClock(testTime), 0, uint32(testCurrentHeight), nil, mockNotifier, ), testData: generateInvoiceExpiryTestData( t, now, 0, numExpiredInvoices, numPendingInvoices, ), } test.wg.Add(numExpiredInvoices) err := test.watcher.Start(func(paymentHash lntypes.Hash, force bool) error { test.canceledInvoices = append( test.canceledInvoices, paymentHash, ) test.wg.Done() return nil }) if err != nil { t.Fatalf("cannot start InvoiceExpiryWatcher: %v", err) } return test } func (t *invoiceExpiryWatcherTest) waitForFinish(timeout time.Duration) { done := make(chan struct{}) // Wait for all cancels. go func() { t.wg.Wait() close(done) }() select { case <-done: case <-time.After(timeout): t.t.Fatalf("test timeout") } } func (t *invoiceExpiryWatcherTest) checkExpectations() { // Check that invoices that got canceled during the test are the ones // that expired. if len(t.canceledInvoices) != len(t.testData.expiredInvoices) { t.t.Fatalf("expected %v cancellations, got %v", len(t.testData.expiredInvoices), len(t.canceledInvoices)) } for i := range t.canceledInvoices { if _, ok := t.testData.expiredInvoices[t.canceledInvoices[i]]; !ok { t.t.Fatalf("wrong invoice canceled") } } } // Tests that InvoiceExpiryWatcher can be started and stopped. func TestInvoiceExpiryWatcherStartStop(t *testing.T) { watcher := NewInvoiceExpiryWatcher( clock.NewTestClock(testTime), 0, uint32(testCurrentHeight), nil, newMockNotifier(), ) cancel := func(lntypes.Hash, bool) error { t.Fatalf("unexpected call") return nil } if err := watcher.Start(cancel); err != nil { t.Fatalf("unexpected error upon start: %v", err) } if err := watcher.Start(cancel); err == nil { t.Fatalf("expected error upon second start") } watcher.Stop() if err := watcher.Start(cancel); err != nil { t.Fatalf("unexpected error upon start: %v", err) } } // Tests that no invoices will expire from an empty InvoiceExpiryWatcher. func TestInvoiceExpiryWithNoInvoices(t *testing.T) { t.Parallel() test := newInvoiceExpiryWatcherTest(t, testTime, 0, 0) test.waitForFinish(testTimeout) test.watcher.Stop() test.checkExpectations() } // Tests that if all add invoices are expired, then all invoices // will be canceled. func TestInvoiceExpiryWithOnlyExpiredInvoices(t *testing.T) { t.Parallel() test := newInvoiceExpiryWatcherTest(t, testTime, 0, 5) for paymentHash, invoice := range test.testData.pendingInvoices { test.watcher.AddInvoices(makeInvoiceExpiry(paymentHash, invoice)) } test.waitForFinish(testTimeout) test.watcher.Stop() test.checkExpectations() } // Tests that if some invoices are expired, then those invoices // will be canceled. func TestInvoiceExpiryWithPendingAndExpiredInvoices(t *testing.T) { t.Parallel() test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5) for paymentHash, invoice := range test.testData.expiredInvoices { test.watcher.AddInvoices(makeInvoiceExpiry(paymentHash, invoice)) } for paymentHash, invoice := range test.testData.pendingInvoices { test.watcher.AddInvoices(makeInvoiceExpiry(paymentHash, invoice)) } test.waitForFinish(testTimeout) test.watcher.Stop() test.checkExpectations() } // Tests adding multiple invoices at once. func TestInvoiceExpiryWhenAddingMultipleInvoices(t *testing.T) { t.Parallel() test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5) var invoices []invoiceExpiry for hash, invoice := range test.testData.expiredInvoices { invoices = append(invoices, makeInvoiceExpiry(hash, invoice)) } for hash, invoice := range test.testData.pendingInvoices { invoices = append(invoices, makeInvoiceExpiry(hash, invoice)) } test.watcher.AddInvoices(invoices...) test.waitForFinish(testTimeout) test.watcher.Stop() test.checkExpectations() }