lnd.xprv/invoices/invoice_expiry_watcher_test.go
carla 34de5922ed
multi: add height-based invoice expiry
This commit adds height-based invoice expiry for hodl invoices
that have active htlcs. This allows us to cancel our intentionally
held htlcs before channels are force closed. We only add this for
hodl invoices because we expect regular invoices to automatically
be resolved.

We still keep hodl invoices in the time-based expiry queue,
because we want to expire open invoices that reach their timeout
before any htlcs are added. Since htlcs are added after the
invoice is created, we add new htlcs as they arrive in the
invoice registry. In this commit, we allow duplicate entries for
an invoice to be added to the expiry queue as each htlc arrives,
to keep the implementation simple. Our cancellation
logic can already handle the case where an entry is already
canceled, so this is ok.
2021-05-11 08:45:29 +02:00

206 lines
5.1 KiB
Go

package invoices
import (
"sync"
"testing"
"time"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lntypes"
)
// invoiceExpiryWatcherTest holds a test fixture and implements checks
// for InvoiceExpiryWatcher tests.
type invoiceExpiryWatcherTest struct {
// t is the test handle that the fixture's check methods report failures on.
t *testing.T
// wg counts the cancellations that are still expected; waitForFinish
// blocks on it and the cancel callback decrements it.
wg sync.WaitGroup
// watcher is the InvoiceExpiryWatcher under test.
watcher *InvoiceExpiryWatcher
// testData holds the generated expired and pending invoices.
testData invoiceExpiryTestData
// canceledInvoices records the payment hashes passed to the cancel
// callback, in the order the callback was invoked.
canceledInvoices []lntypes.Hash
}
// mockChainNotifier is a chain notifier stand-in for tests. It embeds
// chainntnfs.ChainNotifier and overrides RegisterBlockEpochNtfn so that block
// epochs are served from a channel owned by the mock.
type mockChainNotifier struct {
// The embedded ChainNotifier is left at its zero value by newMockNotifier;
// only RegisterBlockEpochNtfn is overridden in this file.
// NOTE(review): presumably ChainNotifier is an interface, in which case any
// other method call on the mock would panic on the nil value — confirm.
chainntnfs.ChainNotifier
// blockChan is the channel handed out as the Epochs stream of every block
// epoch registration.
blockChan chan *chainntnfs.BlockEpoch
}
// newMockNotifier returns a mockChainNotifier whose block channel is
// unbuffered. The embedded ChainNotifier is left unset.
func newMockNotifier() *mockChainNotifier {
	notifier := new(mockChainNotifier)
	notifier.blockChan = make(chan *chainntnfs.BlockEpoch)

	return notifier
}
// RegisterBlockEpochNtfn is the mock's block epoch registration. The returned
// event streams blocks from the mock's own block channel, its Cancel function
// is a no-op, and the call never fails. The argument is ignored.
func (m *mockChainNotifier) RegisterBlockEpochNtfn(*chainntnfs.BlockEpoch) (
	*chainntnfs.BlockEpochEvent, error) {

	// There is nothing to tear down for the mock, so cancelling does
	// nothing.
	noopCancel := func() {}

	event := &chainntnfs.BlockEpochEvent{
		Epochs: m.blockChan,
		Cancel: noopCancel,
	}

	return event, nil
}
// newInvoiceExpiryWatcherTest creates a new InvoiceExpiryWatcher test fixture
// and sets up the test environment.
//
// The returned fixture wraps a started watcher whose cancel callback records
// every payment hash it is called with. The fixture's wait group is primed
// with numExpiredInvoices, so waitForFinish returns once that many
// cancellations have been observed. The watcher's clock is created from
// testTime, while now is passed on to generateInvoiceExpiryTestData together
// with the requested number of expired and pending invoices.
//
// The test is failed immediately if the watcher cannot be started.
func newInvoiceExpiryWatcherTest(t *testing.T, now time.Time,
	numExpiredInvoices, numPendingInvoices int) *invoiceExpiryWatcherTest {

	t.Helper()

	mockNotifier := newMockNotifier()
	test := &invoiceExpiryWatcherTest{
		// Store the test handle: waitForFinish and checkExpectations
		// report failures through it and would otherwise dereference a
		// nil *testing.T instead of failing the test cleanly.
		t: t,
		watcher: NewInvoiceExpiryWatcher(
			clock.NewTestClock(testTime), 0,
			uint32(testCurrentHeight), nil, mockNotifier,
		),
		testData: generateInvoiceExpiryTestData(
			t, now, 0, numExpiredInvoices, numPendingInvoices,
		),
	}

	// One wait group slot per invoice that is expected to be canceled.
	// This must happen before Start so that the callback's Done can never
	// run ahead of the Add.
	test.wg.Add(numExpiredInvoices)

	err := test.watcher.Start(func(paymentHash lntypes.Hash,
		force bool) error {

		test.canceledInvoices = append(
			test.canceledInvoices, paymentHash,
		)
		test.wg.Done()

		return nil
	})
	if err != nil {
		t.Fatalf("cannot start InvoiceExpiryWatcher: %v", err)
	}

	return test
}
// waitForFinish blocks until the fixture's wait group has drained, i.e. until
// every expected cancellation has been reported. If that does not happen
// within timeout, the test is failed.
func (t *invoiceExpiryWatcherTest) waitForFinish(timeout time.Duration) {
	allCanceled := make(chan struct{})

	// Bridge the wait group to a channel so that it can be raced against
	// the timeout below.
	go func() {
		t.wg.Wait()
		close(allCanceled)
	}()

	deadline := time.After(timeout)

	select {
	case <-deadline:
		t.t.Fatalf("test timeout")

	case <-allCanceled:
	}
}
// checkExpectations verifies that the invoices canceled during the test are
// the ones generated as expired: the number of recorded cancellations must
// equal the number of expired invoices, and every canceled payment hash must
// be present in the expired set.
func (t *invoiceExpiryWatcherTest) checkExpectations() {
	expired := t.testData.expiredInvoices

	want, got := len(expired), len(t.canceledInvoices)
	if got != want {
		t.t.Fatalf("expected %v cancellations, got %v", want, got)
	}

	for _, hash := range t.canceledInvoices {
		if _, found := expired[hash]; found {
			continue
		}

		t.t.Fatalf("wrong invoice canceled")
	}
}
// Tests that InvoiceExpiryWatcher can be started and stopped.
//
// A first Start must succeed, a second Start on the running watcher must
// return an error, and after Stop the watcher must be startable again.
func TestInvoiceExpiryWatcherStartStop(t *testing.T) {
	watcher := NewInvoiceExpiryWatcher(
		clock.NewTestClock(testTime), 0, uint32(testCurrentHeight), nil,
		newMockNotifier(),
	)

	// No invoices are added in this test, so the cancel callback must
	// never fire. The callback is handed to the watcher and is presumably
	// invoked from a goroutine the watcher owns rather than the test
	// goroutine, where t.Fatalf (FailNow) must not be called. Report the
	// failure with t.Errorf instead.
	cancel := func(lntypes.Hash, bool) error {
		t.Errorf("unexpected call")
		return nil
	}

	if err := watcher.Start(cancel); err != nil {
		t.Fatalf("unexpected error upon start: %v", err)
	}

	// Starting an already running watcher is an error.
	if err := watcher.Start(cancel); err == nil {
		t.Fatalf("expected error upon second start")
	}

	watcher.Stop()

	// A stopped watcher can be started again.
	if err := watcher.Start(cancel); err != nil {
		t.Fatalf("unexpected error upon start: %v", err)
	}
}
// Tests that no invoices will expire from an empty InvoiceExpiryWatcher.
//
// The fixture is built with zero expired and zero pending invoices and
// nothing is added to the watcher, so no cancellation is expected.
func TestInvoiceExpiryWithNoInvoices(t *testing.T) {
	t.Parallel()

	const (
		numExpired = 0
		numPending = 0
	)

	fixture := newInvoiceExpiryWatcherTest(
		t, testTime, numExpired, numPending,
	)

	fixture.waitForFinish(testTimeout)
	fixture.watcher.Stop()
	fixture.checkExpectations()
}
// Tests that if all added invoices are expired, then all invoices
// will be canceled.
func TestInvoiceExpiryWithOnlyExpiredInvoices(t *testing.T) {
	t.Parallel()

	// Generate expired invoices only. The fixture then expects exactly one
	// cancellation per expired invoice, which is what this test is meant
	// to exercise. Generating and adding only pending invoices would
	// expect zero cancellations and never trigger an expiry at all.
	test := newInvoiceExpiryWatcherTest(t, testTime, 5, 0)

	for paymentHash, invoice := range test.testData.expiredInvoices {
		test.watcher.AddInvoices(makeInvoiceExpiry(paymentHash, invoice))
	}

	// Wait until every expired invoice has been canceled, then stop the
	// watcher before inspecting the recorded cancellations.
	test.waitForFinish(testTimeout)
	test.watcher.Stop()
	test.checkExpectations()
}
// Tests that if some invoices are expired, then those invoices
// will be canceled.
//
// Five expired and five pending invoices are added to the watcher one at a
// time, expired ones first.
func TestInvoiceExpiryWithPendingAndExpiredInvoices(t *testing.T) {
	t.Parallel()

	fixture := newInvoiceExpiryWatcherTest(t, testTime, 5, 5)
	data := fixture.testData

	for hash, inv := range data.expiredInvoices {
		entry := makeInvoiceExpiry(hash, inv)
		fixture.watcher.AddInvoices(entry)
	}

	for hash, inv := range data.pendingInvoices {
		entry := makeInvoiceExpiry(hash, inv)
		fixture.watcher.AddInvoices(entry)
	}

	fixture.waitForFinish(testTimeout)
	fixture.watcher.Stop()
	fixture.checkExpectations()
}
// Tests adding multiple invoices at once.
//
// Five expired and five pending invoices are collected into a single batch
// that is handed to the watcher in one AddInvoices call.
func TestInvoiceExpiryWhenAddingMultipleInvoices(t *testing.T) {
	t.Parallel()

	fixture := newInvoiceExpiryWatcherTest(t, testTime, 5, 5)
	expired := fixture.testData.expiredInvoices
	pending := fixture.testData.pendingInvoices

	batch := make([]invoiceExpiry, 0, len(expired)+len(pending))
	for h, inv := range expired {
		batch = append(batch, makeInvoiceExpiry(h, inv))
	}
	for h, inv := range pending {
		batch = append(batch, makeInvoiceExpiry(h, inv))
	}

	fixture.watcher.AddInvoices(batch...)

	fixture.waitForFinish(testTimeout)
	fixture.watcher.Stop()
	fixture.checkExpectations()
}