diff --git a/invoices/invoice_expiry_watcher_test.go b/invoices/invoice_expiry_watcher_test.go index 99a99dfc..63ddfb92 100644 --- a/invoices/invoice_expiry_watcher_test.go +++ b/invoices/invoice_expiry_watcher_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lntypes" ) @@ -203,3 +204,115 @@ func TestInvoiceExpiryWhenAddingMultipleInvoices(t *testing.T) { test.watcher.Stop() test.checkExpectations() } + +// TestExpiredHodlInv tests expiration of an already-expired hodl invoice +// which has no htlcs. +func TestExpiredHodlInv(t *testing.T) { + t.Parallel() + + creationDate := testTime.Add(time.Hour * -24) + expiry := time.Hour + + test := setupHodlExpiry( + t, creationDate, expiry, 0, channeldb.ContractOpen, nil, + ) + + test.assertCanceled(t, test.hash) + test.watcher.Stop() +} + +// TestAcceptedHodlNotExpired tests that hodl invoices which are in an accepted +// state are not expired once their time-based expiry elapses, using a regular +// invoice that expires at the same time as a control to ensure that invoices +// with that timestamp would otherwise be expired. +func TestAcceptedHodlNotExpired(t *testing.T) { + t.Parallel() + + creationDate := testTime + expiry := time.Hour + + test := setupHodlExpiry( + t, creationDate, expiry, 0, channeldb.ContractAccepted, nil, + ) + defer test.watcher.Stop() + + // Add another invoice that will expire at our expiry time as a control + // value. + tsExpires := &invoiceExpiryTs{ + PaymentHash: lntypes.Hash{1, 2, 3}, + Expiry: creationDate.Add(expiry), + Keysend: true, + } + test.watcher.AddInvoices(tsExpires) + + test.mockClock.SetTime(creationDate.Add(expiry + 1)) + + // Assert that only the ts expiry invoice is expired. + test.assertCanceled(t, tsExpires.PaymentHash) +} + +// TestHeightAlreadyExpired tests the case where we add an invoice with htlcs +// that have already expired to the expiry watcher. +func TestHeightAlreadyExpired(t *testing.T) { + t.Parallel() + + expiredHtlc := []*channeldb.InvoiceHTLC{ + { + State: channeldb.HtlcStateAccepted, + Expiry: uint32(testCurrentHeight), + }, + } + + test := setupHodlExpiry( + t, testTime, time.Hour, 0, channeldb.ContractAccepted, + expiredHtlc, + ) + defer test.watcher.Stop() + + test.assertCanceled(t, test.hash) +} + +// TestExpiryHeightArrives tests the case where we add a hodl invoice to the +// expiry watcher when it has no htlcs, htlcs are added and then they finally +// expire. We use a non-zero delta for this test to check that we expire with +// sufficient buffer. +func TestExpiryHeightArrives(t *testing.T) { + var ( + creationDate = testTime + expiry = time.Hour * 2 + delta uint32 = 1 + ) + + // Start out with a hodl invoice that is open, and has no htlcs. + test := setupHodlExpiry( + t, creationDate, expiry, delta, channeldb.ContractOpen, nil, + ) + defer test.watcher.Stop() + + htlc1 := uint32(testCurrentHeight + 10) + expiry1 := makeHeightExpiry(test.hash, htlc1) + + // Add htlcs to our invoice and progress its state to accepted. + test.watcher.AddInvoices(expiry1) + test.setState(channeldb.ContractAccepted) + + // Progress time so that our expiry has elapsed. We no longer expect + // this invoice to be canceled because it has been accepted. + test.mockClock.SetTime(creationDate.Add(expiry)) + + // Tick our mock block subscription with the next block, we don't + // expect anything to happen. + currentHeight := uint32(testCurrentHeight + 1) + test.announceBlock(t, currentHeight) + + // Now, we add another htlc to the invoice. This one has a lower expiry + // height than our current ones. + htlc2 := currentHeight + 5 + expiry2 := makeHeightExpiry(test.hash, htlc2) + test.watcher.AddInvoices(expiry2) + + // Announce our lowest htlc expiry block minus our delta, the invoice + // should be expired now. + test.announceBlock(t, htlc2-delta) + test.assertCanceled(t, test.hash) +} diff --git a/invoices/test_utils_test.go b/invoices/test_utils_test.go index b78c06aa..51f41fc9 100644 --- a/invoices/test_utils_test.go +++ b/invoices/test_utils_test.go @@ -8,11 +8,13 @@ import ( "io/ioutil" "os" "runtime/pprof" + "sync" "testing" "time" "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg" + "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lntypes" @@ -367,3 +369,111 @@ func checkFailResolution(t *testing.T, res HtlcResolution, return failResolution } + +type hodlExpiryTest struct { + hash lntypes.Hash + state channeldb.ContractState + stateLock sync.Mutex + mockNotifier *mockChainNotifier + mockClock *clock.TestClock + cancelChan chan lntypes.Hash + watcher *InvoiceExpiryWatcher +} + +func (h *hodlExpiryTest) setState(state channeldb.ContractState) { + h.stateLock.Lock() + defer h.stateLock.Unlock() + + h.state = state +} + +func (h *hodlExpiryTest) announceBlock(t *testing.T, height uint32) { + select { + case h.mockNotifier.blockChan <- &chainntnfs.BlockEpoch{ + Height: int32(height), + }: + + case <-time.After(testTimeout): + t.Fatalf("block %v not consumed", height) + } +} + +func (h *hodlExpiryTest) assertCanceled(t *testing.T, expected lntypes.Hash) { + select { + case actual := <-h.cancelChan: + require.Equal(t, expected, actual) + + case <-time.After(testTimeout): + t.Fatalf("invoice: %v not canceled", h.hash) + } +} + +// setupHodlExpiry creates a hodl invoice in our expiry watcher and runs an +// arbitrary update function which advances the invoices's state. +func setupHodlExpiry(t *testing.T, creationDate time.Time, + expiry time.Duration, heightDelta uint32, + startState channeldb.ContractState, + startHtlcs []*channeldb.InvoiceHTLC) *hodlExpiryTest { + + mockNotifier := newMockNotifier() + mockClock := clock.NewTestClock(testTime) + + test := &hodlExpiryTest{ + state: startState, + watcher: NewInvoiceExpiryWatcher( + mockClock, heightDelta, uint32(testCurrentHeight), nil, + mockNotifier, + ), + cancelChan: make(chan lntypes.Hash), + mockNotifier: mockNotifier, + mockClock: mockClock, + } + + // Use an unbuffered channel to block on cancel calls so that the test + // does not exit before we've processed all the invoices we expect. + cancelImpl := func(paymentHash lntypes.Hash, force bool) error { + test.stateLock.Lock() + currentState := test.state + test.stateLock.Unlock() + + if currentState != channeldb.ContractOpen && !force { + return nil + } + + select { + case test.cancelChan <- paymentHash: + case <-time.After(testTimeout): + } + + return nil + } + + require.NoError(t, test.watcher.Start(cancelImpl)) + + // We set preimage and hash so that we can use our existing test + // helpers. In practice we would only have the hash, but this does not + // affect what we're testing at all. + preimage := lntypes.Preimage{1} + test.hash = preimage.Hash() + + invoice := newTestInvoice(t, preimage, creationDate, expiry) + invoice.State = startState + invoice.HodlInvoice = true + invoice.Htlcs = make(map[channeldb.CircuitKey]*channeldb.InvoiceHTLC) + + // If we have any htlcs, add them with unique circult keys. + for i, htlc := range startHtlcs { + key := channeldb.CircuitKey{ + HtlcID: uint64(i), + } + + invoice.Htlcs[key] = htlc + } + + // Create an expiry entry for our invoice in its starting state. This + // mimics adding invoices to the watcher on start. + entry := makeInvoiceExpiry(test.hash, invoice) + test.watcher.AddInvoices(entry) + + return test +}