lnd.xprv/invoices/invoice_expiry_watcher_test.go

181 lines
4.5 KiB
Go
Raw Normal View History

package invoices
import (
"sync"
"testing"
"time"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lntypes"
)
// invoiceExpiryWatcherTest is a test fixture that bundles an
// InvoiceExpiryWatcher with generated invoice data and records which
// invoices the watcher cancels, so expectations can be checked at the end
// of a test.
type invoiceExpiryWatcherTest struct {
	// t is the test that the fixture's helpers report failures to.
	t *testing.T

	// watcher is the InvoiceExpiryWatcher under test.
	watcher *InvoiceExpiryWatcher

	// testData holds the generated expired and pending invoices.
	testData invoiceExpiryTestData

	// wg counts outstanding expected cancellations; the cancel callback
	// installed by the fixture marks one as done per invocation.
	wg sync.WaitGroup

	// canceledInvoices collects the payment hashes passed to the cancel
	// callback, in the order they were received.
	canceledInvoices []lntypes.Hash
}
// newInvoiceExpiryWatcherTest creates a new InvoiceExpiryWatcher test fixture
// and sets up the test environment.
//
// Both the watcher's test clock and the generated invoice data use now as
// their reference time. The fixture's wait group is primed with
// numExpiredInvoices. The cancel callback handed to the watcher records each
// payment hash it receives and marks one expected cancellation as done.
func newInvoiceExpiryWatcherTest(t *testing.T, now time.Time,
	numExpiredInvoices, numPendingInvoices int) *invoiceExpiryWatcherTest {

	test := &invoiceExpiryWatcherTest{
		// Keep the testing.T so the fixture's helper methods can
		// report failures instead of dereferencing a nil pointer.
		t: t,

		// Use the same reference time for the clock as for the
		// generated invoices so "expired" and "pending" agree.
		watcher: NewInvoiceExpiryWatcher(clock.NewTestClock(now)),
		testData: generateInvoiceExpiryTestData(
			t, now, 0, numExpiredInvoices, numPendingInvoices,
		),
	}

	test.wg.Add(numExpiredInvoices)

	err := test.watcher.Start(func(paymentHash lntypes.Hash) error {
		test.canceledInvoices = append(test.canceledInvoices, paymentHash)
		test.wg.Done()
		return nil
	})
	if err != nil {
		t.Fatalf("cannot start InvoiceExpiryWatcher: %v", err)
	}

	return test
}
// waitForFinish blocks until the fixture's wait group reaches zero, i.e.
// every expected cancellation has been recorded. It fails the test if that
// does not happen within timeout.
func (t *invoiceExpiryWatcherTest) waitForFinish(timeout time.Duration) {
	// sync.WaitGroup has no timed wait, so bridge it to a channel from a
	// helper goroutine.
	allCanceled := make(chan struct{})
	go func() {
		defer close(allCanceled)
		t.wg.Wait()
	}()

	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case <-allCanceled:
		return
	case <-timer.C:
	}

	t.t.Fatalf("test timeout")
}
// checkExpectations verifies that the invoices canceled by the watcher are
// exactly the expired invoices. The counts must match, every canceled hash
// must belong to an expired invoice, and no invoice may be canceled twice.
func (t *invoiceExpiryWatcherTest) checkExpectations() {
	// Check that invoices that got canceled during the test are the ones
	// that expired.
	if len(t.canceledInvoices) != len(t.testData.expiredInvoices) {
		t.t.Fatalf("expected %v cancellations, got %v",
			len(t.testData.expiredInvoices), len(t.canceledInvoices))
	}

	// With equal counts, rejecting both unknown hashes and duplicates
	// guarantees every expired invoice was canceled exactly once. A
	// membership check alone would accept one invoice canceled twice and
	// another never canceled.
	seen := make(map[lntypes.Hash]struct{}, len(t.canceledInvoices))
	for _, hash := range t.canceledInvoices {
		if _, ok := t.testData.expiredInvoices[hash]; !ok {
			t.t.Fatalf("wrong invoice canceled: %v", hash)
		}
		if _, dup := seen[hash]; dup {
			t.t.Fatalf("invoice canceled more than once: %v", hash)
		}
		seen[hash] = struct{}{}
	}
}
// TestInvoiceExpiryWatcherStartStop tests that an InvoiceExpiryWatcher can be
// started, rejects a second Start while running, and can be started again
// after Stop.
func TestInvoiceExpiryWatcherStartStop(t *testing.T) {
	watcher := NewInvoiceExpiryWatcher(clock.NewTestClock(testTime))

	// No invoices are ever added, so the cancel callback must not fire.
	// The callback is presumably invoked from the watcher's own goroutine,
	// where t.Fatalf (FailNow) must not be called, so report via t.Errorf.
	cancel := func(lntypes.Hash) error {
		t.Errorf("unexpected call")
		return nil
	}

	if err := watcher.Start(cancel); err != nil {
		t.Fatalf("unexpected error upon start: %v", err)
	}

	if err := watcher.Start(cancel); err == nil {
		t.Fatalf("expected error upon second start")
	}

	watcher.Stop()

	if err := watcher.Start(cancel); err != nil {
		t.Fatalf("unexpected error upon start: %v", err)
	}

	// Stop the restarted watcher so it does not outlive the test.
	watcher.Stop()
}
// TestInvoiceExpiryWithNoInvoices tests that an InvoiceExpiryWatcher that
// never has any invoices added to it cancels nothing.
func TestInvoiceExpiryWithNoInvoices(t *testing.T) {
	t.Parallel()

	const numExpired, numPending = 0, 0
	test := newInvoiceExpiryWatcherTest(t, testTime, numExpired, numPending)

	// No cancellations are expected, so the wait group starts at zero and
	// this returns without blocking.
	test.waitForFinish(testTimeout)
	test.watcher.Stop()
	test.checkExpectations()
}
// TestInvoiceExpiryWithOnlyExpiredInvoices tests that if all added invoices
// are expired, then all of them will be canceled.
func TestInvoiceExpiryWithOnlyExpiredInvoices(t *testing.T) {
	t.Parallel()

	// Generate only expired invoices. The fixture then waits for every one
	// of them to be canceled. With zero expired invoices the wait would
	// return immediately and expiry would never be exercised.
	test := newInvoiceExpiryWatcherTest(t, testTime, 5, 0)

	for paymentHash, invoice := range test.testData.expiredInvoices {
		test.watcher.AddInvoice(paymentHash, invoice)
	}

	test.waitForFinish(testTimeout)
	test.watcher.Stop()
	test.checkExpectations()
}
// TestInvoiceExpiryWithPendingAndExpiredInvoices tests that when a mix of
// expired and pending invoices is added one at a time, exactly the expired
// ones are canceled.
func TestInvoiceExpiryWithPendingAndExpiredInvoices(t *testing.T) {
	t.Parallel()

	const (
		numExpired = 5
		numPending = 5
	)
	test := newInvoiceExpiryWatcherTest(t, testTime, numExpired, numPending)
	data := test.testData

	// Register the expired invoices first, then the pending ones.
	for hash, inv := range data.expiredInvoices {
		test.watcher.AddInvoice(hash, inv)
	}
	for hash, inv := range data.pendingInvoices {
		test.watcher.AddInvoice(hash, inv)
	}

	test.waitForFinish(testTimeout)
	test.watcher.Stop()
	test.checkExpectations()
}
// TestInvoiceExpiryWhenAddingMultipleInvoices tests registering expired and
// pending invoices in a single AddInvoices batch, expecting exactly the
// expired ones to be canceled.
func TestInvoiceExpiryWhenAddingMultipleInvoices(t *testing.T) {
	t.Parallel()

	test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5)
	expired := test.testData.expiredInvoices
	pending := test.testData.pendingInvoices

	// Flatten both invoice maps into one batch, expired invoices first.
	batch := make(
		[]channeldb.InvoiceWithPaymentHash, 0, len(expired)+len(pending),
	)
	for hash, inv := range expired {
		batch = append(batch, channeldb.InvoiceWithPaymentHash{
			Invoice:     *inv,
			PaymentHash: hash,
		})
	}
	for hash, inv := range pending {
		batch = append(batch, channeldb.InvoiceWithPaymentHash{
			Invoice:     *inv,
			PaymentHash: hash,
		})
	}

	test.watcher.AddInvoices(batch)
	test.waitForFinish(testTimeout)
	test.watcher.Stop()
	test.checkExpectations()
}