diff --git a/invoices/invoice_expiry_watcher.go b/invoices/invoice_expiry_watcher.go index 4c7f8f6e..14257581 100644 --- a/invoices/invoice_expiry_watcher.go +++ b/invoices/invoice_expiry_watcher.go @@ -12,9 +12,17 @@ import ( "github.com/lightningnetwork/lnd/zpay32" ) -// invoiceExpiry holds and invoice's payment hash and its expiry. This -// is used to order invoices by their expiry for cancellation. -type invoiceExpiry struct { +// invoiceExpiry is a vanity interface for different invoice expiry types +// which implement the priority queue item interface, used to improve code +// readability. +type invoiceExpiry queue.PriorityQueueItem + +// Compile time assertion that invoiceExpiryTs implements invoiceExpiry. +var _ invoiceExpiry = (*invoiceExpiryTs)(nil) + +// invoiceExpiryTs holds and invoice's payment hash and its expiry. This +// is used to order invoices by their expiry time for cancellation. +type invoiceExpiryTs struct { PaymentHash lntypes.Hash Expiry time.Time Keysend bool @@ -22,8 +30,8 @@ type invoiceExpiry struct { // Less implements PriorityQueueItem.Less such that the top item in the // priorty queue will be the one that expires next. -func (e invoiceExpiry) Less(other queue.PriorityQueueItem) bool { - return e.Expiry.Before(other.(*invoiceExpiry).Expiry) +func (e invoiceExpiryTs) Less(other queue.PriorityQueueItem) bool { + return e.Expiry.Before(other.(*invoiceExpiryTs).Expiry) } // InvoiceExpiryWatcher handles automatic invoice cancellation of expried @@ -44,13 +52,13 @@ type InvoiceExpiryWatcher struct { // cancelInvoice is a template method that cancels an expired invoice. cancelInvoice func(lntypes.Hash, bool) error - // expiryQueue holds invoiceExpiry items and is used to find the next - // invoice to expire. - expiryQueue queue.PriorityQueue + // timestampExpiryQueue holds invoiceExpiry items and is used to find + // the next invoice to expire. + timestampExpiryQueue queue.PriorityQueue // newInvoices channel is used to wake up the main loop when a new // invoices is added. - newInvoices chan []*invoiceExpiry + newInvoices chan []invoiceExpiry wg sync.WaitGroup @@ -62,7 +70,7 @@ type InvoiceExpiryWatcher struct { func NewInvoiceExpiryWatcher(clock clock.Clock) *InvoiceExpiryWatcher { return &InvoiceExpiryWatcher{ clock: clock, - newInvoices: make(chan []*invoiceExpiry), + newInvoices: make(chan []invoiceExpiry), quit: make(chan struct{}), } } @@ -104,14 +112,29 @@ func (ew *InvoiceExpiryWatcher) Stop() { } // makeInvoiceExpiry checks if the passed invoice may be canceled and calculates -// the expiry time and creates a slimmer invoiceExpiry object with the hash and -// expiry time. +// the expiry time and creates a slimmer invoiceExpiry implementation. func makeInvoiceExpiry(paymentHash lntypes.Hash, - invoice *channeldb.Invoice) *invoiceExpiry { + invoice *channeldb.Invoice) invoiceExpiry { - if invoice.State != channeldb.ContractOpen { + switch invoice.State { + // If we have an open invoice with no htlcs, we want to expire the + // invoice based on timestamp + case channeldb.ContractOpen: + return makeTimestampExpiry(paymentHash, invoice) + + default: log.Debugf("Invoice not added to expiry watcher: %v", paymentHash) + + return nil + } +} + +// makeTimestampExpiry creates a timestamp-based expiry entry. +func makeTimestampExpiry(paymentHash lntypes.Hash, + invoice *channeldb.Invoice) *invoiceExpiryTs { + + if invoice.State != channeldb.ContractOpen { return nil } @@ -121,7 +144,7 @@ func makeInvoiceExpiry(paymentHash lntypes.Hash, } expiry := invoice.CreationDate.Add(realExpiry) - return &invoiceExpiry{ + return &invoiceExpiryTs{ PaymentHash: paymentHash, Expiry: expiry, Keysend: len(invoice.PaymentRequest) == 0, @@ -129,7 +152,7 @@ func makeInvoiceExpiry(paymentHash lntypes.Hash, } // AddInvoices adds invoices to the InvoiceExpiryWatcher. -func (ew *InvoiceExpiryWatcher) AddInvoices(invoices ...*invoiceExpiry) { +func (ew *InvoiceExpiryWatcher) AddInvoices(invoices ...invoiceExpiry) { if len(invoices) > 0 { select { case ew.newInvoices <- invoices: @@ -143,11 +166,12 @@ func (ew *InvoiceExpiryWatcher) AddInvoices(invoices ...*invoiceExpiry) { } } -// nextExpiry returns a Time chan to wait on until the next invoice expires. -// If there are no active invoices, then it'll simply wait indefinitely. -func (ew *InvoiceExpiryWatcher) nextExpiry() <-chan time.Time { - if !ew.expiryQueue.Empty() { - top := ew.expiryQueue.Top().(*invoiceExpiry) +// nextTimestampExpiry returns a Time chan to wait on until the next invoice +// expires. If there are no active invoices, then it'll simply wait +// indefinitely. +func (ew *InvoiceExpiryWatcher) nextTimestampExpiry() <-chan time.Time { + if !ew.timestampExpiryQueue.Empty() { + top := ew.timestampExpiryQueue.Top().(*invoiceExpiryTs) return ew.clock.TickAfter(top.Expiry.Sub(ew.clock.Now())) } @@ -157,8 +181,8 @@ func (ew *InvoiceExpiryWatcher) nextExpiry() <-chan time.Time { // cancelNextExpiredInvoice will cancel the next expired invoice and removes // it from the expiry queue. func (ew *InvoiceExpiryWatcher) cancelNextExpiredInvoice() { - if !ew.expiryQueue.Empty() { - top := ew.expiryQueue.Top().(*invoiceExpiry) + if !ew.timestampExpiryQueue.Empty() { + top := ew.timestampExpiryQueue.Top().(*invoiceExpiryTs) if !top.Expiry.Before(ew.clock.Now()) { return } @@ -169,15 +193,42 @@ func (ew *InvoiceExpiryWatcher) cancelNextExpiredInvoice() { // field would never be used. Enabling cancellation for accepted // keysend invoices creates a safety mechanism that can prevents // channel force-closes. - err := ew.cancelInvoice(top.PaymentHash, top.Keysend) - if err != nil && err != channeldb.ErrInvoiceAlreadySettled && - err != channeldb.ErrInvoiceAlreadyCanceled { + ew.expireInvoice(top.PaymentHash, top.Keysend) + ew.timestampExpiryQueue.Pop() + } +} - log.Errorf("Unable to cancel invoice: %v", - top.PaymentHash) +// expireInvoice attempts to expire an invoice and logs an error if we get an +// unexpected error. +func (ew *InvoiceExpiryWatcher) expireInvoice(hash lntypes.Hash, force bool) { + err := ew.cancelInvoice(hash, force) + switch err { + case nil: + + case channeldb.ErrInvoiceAlreadyCanceled: + + case channeldb.ErrInvoiceAlreadySettled: + + default: + log.Errorf("Unable to cancel invoice: %v: %v", hash, err) + } +} + +// pushInvoices adds invoices to be expired to their relevant queue. +func (ew *InvoiceExpiryWatcher) pushInvoices(invoices []invoiceExpiry) { + for _, inv := range invoices { + // Switch on the type of entry we have. We need to check nil + // on the implementation of the interface because the interface + // itself is non-nil. + switch expiry := inv.(type) { + case *invoiceExpiryTs: + if expiry != nil { + ew.timestampExpiryQueue.Push(expiry) + } + + default: + log.Errorf("unexpected queue item: %T", inv) } - - ew.expiryQueue.Pop() } } @@ -190,32 +241,23 @@ func (ew *InvoiceExpiryWatcher) mainLoop() { // Cancel any invoices that may have expired. ew.cancelNextExpiredInvoice() - pushInvoices := func(invoicesWithExpiry []*invoiceExpiry) { - for _, invoiceWithExpiry := range invoicesWithExpiry { - // Avoid pushing nil object to the heap. - if invoiceWithExpiry != nil { - ew.expiryQueue.Push(invoiceWithExpiry) - } - } - } - select { - case invoicesWithExpiry := <-ew.newInvoices: + case newInvoices := <-ew.newInvoices: // Take newly forwarded invoices with higher priority // in order to not block the newInvoices channel. - pushInvoices(invoicesWithExpiry) + ew.pushInvoices(newInvoices) continue default: select { - case <-ew.nextExpiry(): + case <-ew.nextTimestampExpiry(): // Wait until the next invoice expires. continue - case invoicesWithExpiry := <-ew.newInvoices: - pushInvoices(invoicesWithExpiry) + case newInvoices := <-ew.newInvoices: + ew.pushInvoices(newInvoices) case <-ew.quit: return diff --git a/invoices/invoice_expiry_watcher_test.go b/invoices/invoice_expiry_watcher_test.go index a06bde53..e2c7ea82 100644 --- a/invoices/invoice_expiry_watcher_test.go +++ b/invoices/invoice_expiry_watcher_test.go @@ -157,7 +157,7 @@ func TestInvoiceExpiryWhenAddingMultipleInvoices(t *testing.T) { t.Parallel() test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5) - var invoices []*invoiceExpiry + var invoices []invoiceExpiry for hash, invoice := range test.testData.expiredInvoices { invoices = append(invoices, makeInvoiceExpiry(hash, invoice)) diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index bb336be7..c457acd5 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -160,7 +160,7 @@ func NewRegistry(cdb *channeldb.DB, expiryWatcher *InvoiceExpiryWatcher, // invoices. func (i *InvoiceRegistry) scanInvoicesOnStart() error { var ( - pending []*invoiceExpiry + pending []invoiceExpiry removable []channeldb.InvoiceDeleteRef ) @@ -1176,6 +1176,20 @@ func (i *InvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error { return i.cancelInvoiceImpl(payHash, true) } +// shouldCancel examines the state of an invoice and whether we want to +// cancel already accepted invoices, taking our force cancel boolean into +// account. This is pulled out into its own function so that tests that mock +// cancelInvoiceImpl can reuse this logic. +func shouldCancel(state channeldb.ContractState, cancelAccepted bool) bool { + if state != channeldb.ContractAccepted { + return true + } + + // If the invoice is accepted, we should only cancel if we want to + // force cancelation of accepted invoices. + return cancelAccepted +} + // cancelInvoice attempts to cancel the invoice corresponding to the passed // payment hash. Accepted invoices will only be canceled if explicitly // requested to do so. It notifies subscribing links and resolvers that @@ -1192,9 +1206,7 @@ func (i *InvoiceRegistry) cancelInvoiceImpl(payHash lntypes.Hash, updateInvoice := func(invoice *channeldb.Invoice) ( *channeldb.InvoiceUpdateDesc, error) { - // Only cancel the invoice in ContractAccepted state if explicitly - // requested to do so. - if invoice.State == channeldb.ContractAccepted && !cancelAccepted { + if !shouldCancel(invoice.State, cancelAccepted) { return nil, nil } diff --git a/lntest/itest/lnd_hold_invoice_force_test.go b/lntest/itest/lnd_hold_invoice_force_test.go new file mode 100644 index 00000000..00831000 --- /dev/null +++ b/lntest/itest/lnd_hold_invoice_force_test.go @@ -0,0 +1,131 @@ +package itest + +import ( + "context" + "fmt" + + "github.com/lightningnetwork/lnd/lncfg" + "github.com/lightningnetwork/lnd/lnrpc" + "github.com/lightningnetwork/lnd/lnrpc/invoicesrpc" + "github.com/lightningnetwork/lnd/lnrpc/routerrpc" + "github.com/lightningnetwork/lnd/lntest" + "github.com/lightningnetwork/lnd/lntest/wait" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/stretchr/testify/require" +) + +// testHoldInvoiceForceClose demonstrates that recipients of hold invoices +// will not release active htlcs for their own invoices when they expire, +// resulting in a force close of their channel. +func testHoldInvoiceForceClose(net *lntest.NetworkHarness, t *harnessTest) { + ctxb, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Open a channel between alice and bob. + chanReq := lntest.OpenChannelParams{ + Amt: 300000, + } + + ctxt, _ := context.WithTimeout(ctxb, channelOpenTimeout) + chanPoint := openChannelAndAssert(ctxt, t, net, net.Alice, net.Bob, chanReq) + + // Create a non-dust hold invoice for bob. + var ( + preimage = lntypes.Preimage{1, 2, 3} + payHash = preimage.Hash() + ) + invoiceReq := &invoicesrpc.AddHoldInvoiceRequest{ + Value: 30000, + CltvExpiry: 40, + Hash: payHash[:], + } + + ctxt, _ = context.WithTimeout(ctxb, defaultTimeout) + bobInvoice, err := net.Bob.AddHoldInvoice(ctxt, invoiceReq) + require.NoError(t.t, err) + + // Pay this invoice from Alice -> Bob, we should achieve this with a + // single htlc. + _, err = net.Alice.RouterClient.SendPaymentV2( + ctxb, &routerrpc.SendPaymentRequest{ + PaymentRequest: bobInvoice.PaymentRequest, + TimeoutSeconds: 60, + FeeLimitMsat: noFeeLimitMsat, + }, + ) + require.NoError(t.t, err) + + waitForInvoiceAccepted(t, net.Bob, payHash) + + // Once the HTLC has cleared, alice and bob should both have a single + // htlc locked in. + nodes := []*lntest.HarnessNode{net.Alice, net.Bob} + err = wait.NoError(func() error { + return assertActiveHtlcs(nodes, payHash[:]) + }, defaultTimeout) + require.NoError(t.t, err) + + // Get our htlc expiry height and current block height so that we + // can mine the exact number of blocks required to expire the htlc. + chans, err := net.Alice.ListChannels(ctxb, &lnrpc.ListChannelsRequest{}) + require.NoError(t.t, err) + require.Len(t.t, chans.Channels, 1) + require.Len(t.t, chans.Channels[0].PendingHtlcs, 1) + activeHtlc := chans.Channels[0].PendingHtlcs[0] + + info, err := net.Alice.GetInfo(ctxb, &lnrpc.GetInfoRequest{}) + require.NoError(t.t, err) + + // Now we will mine blocks until the htlc expires, and wait for each + // node to sync to our latest height. Sanity check that we won't + // underflow. + require.Greater(t.t, activeHtlc.ExpirationHeight, info.BlockHeight, + "expected expiry after current height") + blocksTillExpiry := activeHtlc.ExpirationHeight - info.BlockHeight + + // Alice will go to chain with some delta, sanity check that we won't + // underflow and subtract this from our mined blocks. + require.Greater(t.t, blocksTillExpiry, + uint32(lncfg.DefaultOutgoingBroadcastDelta)) + blocksTillForce := blocksTillExpiry - lncfg.DefaultOutgoingBroadcastDelta + + mineBlocks(t, net, blocksTillForce, 0) + + require.NoError(t.t, net.Alice.WaitForBlockchainSync(ctxb)) + require.NoError(t.t, net.Bob.WaitForBlockchainSync(ctxb)) + + // Alice should have a waiting-close channel because she has force + // closed to time out the htlc. + assertNumPendingChannels(t, net.Alice, 1, 0) + + // We should have our force close tx in the mempool. + mineBlocks(t, net, 1, 1) + + // Ensure alice and bob are synced to chain after we've mined our force + // close. + require.NoError(t.t, net.Alice.WaitForBlockchainSync(ctxb)) + require.NoError(t.t, net.Bob.WaitForBlockchainSync(ctxb)) + + // At this point, Bob's channel should be resolved because his htlc is + // expired, so no further action is required. Alice will still have a + // pending force close channel because she needs to resolve the htlc. + assertNumPendingChannels(t, net.Alice, 0, 1) + assertNumPendingChannels(t, net.Bob, 0, 0) + + ctxt, _ = context.WithTimeout(ctxb, defaultTimeout) + err = waitForNumChannelPendingForceClose(ctxt, net.Alice, 1, + func(channel *lnrpcForceCloseChannel) error { + numHtlcs := len(channel.PendingHtlcs) + if numHtlcs != 1 { + return fmt.Errorf("expected 1 htlc, got: "+ + "%v", numHtlcs) + } + + return nil + }, + ) + require.NoError(t.t, err) + + // Cleanup Alice's force close. + cleanupForceClose(t, net, net.Alice, chanPoint) +} diff --git a/lntest/itest/lnd_multi-hop_htlc_aggregation_test.go b/lntest/itest/lnd_multi-hop_htlc_aggregation_test.go index b1261c15..3c91eed1 100644 --- a/lntest/itest/lnd_multi-hop_htlc_aggregation_test.go +++ b/lntest/itest/lnd_multi-hop_htlc_aggregation_test.go @@ -169,6 +169,12 @@ func testMultiHopHtlcAggregation(net *lntest.NetworkHarness, t *harnessTest, // be cpfp'ed. net.SetFeeEstimate(30000) + // We want Carol's htlcs to expire off-chain to demonstrate bob's force + // close. However, Carol will cancel her invoices to prevent force + // closes, so we shut her down for now. + restartCarol, err := net.SuspendNode(carol) + require.NoError(t.t, err) + // We'll now mine enough blocks to trigger Bob's broadcast of his // commitment transaction due to the fact that the Carol's HTLCs are // about to timeout. With the default outgoing broadcast delta of zero, @@ -225,6 +231,9 @@ func testMultiHopHtlcAggregation(net *lntest.NetworkHarness, t *harnessTest, } } + // Once bob has force closed, we can restart carol. + require.NoError(t.t, restartCarol()) + // Mine a block to confirm the closing transaction. mineBlocks(t, net, 1, expectedTxes) diff --git a/lntest/itest/lnd_multi-hop_remote_force_close_on_chain_htlc_timeout_test.go b/lntest/itest/lnd_multi-hop_remote_force_close_on_chain_htlc_timeout_test.go index e79cf7c3..b8dee822 100644 --- a/lntest/itest/lnd_multi-hop_remote_force_close_on_chain_htlc_timeout_test.go +++ b/lntest/itest/lnd_multi-hop_remote_force_close_on_chain_htlc_timeout_test.go @@ -5,9 +5,12 @@ import ( "fmt" "github.com/btcsuite/btcutil" + "github.com/lightningnetwork/lnd/lnrpc" + "github.com/lightningnetwork/lnd/lnrpc/invoicesrpc" "github.com/lightningnetwork/lnd/lnrpc/routerrpc" "github.com/lightningnetwork/lnd/lntest" "github.com/lightningnetwork/lnd/lntest/wait" + "github.com/lightningnetwork/lnd/lntypes" "github.com/stretchr/testify/require" ) @@ -43,14 +46,21 @@ func testMultiHopRemoteForceCloseOnChainHtlcTimeout(net *lntest.NetworkHarness, defer cancel() // We'll now send a single HTLC across our multi-hop network. - carolPubKey := carol.PubKey[:] - payHash := makeFakePayHash(t) - _, err := alice.RouterClient.SendPaymentV2( + preimage := lntypes.Preimage{1, 2, 3} + payHash := preimage.Hash() + invoiceReq := &invoicesrpc.AddHoldInvoiceRequest{ + Value: int64(htlcAmt), + CltvExpiry: 40, + Hash: payHash[:], + } + + ctxt, _ := context.WithTimeout(ctxb, defaultTimeout) + carolInvoice, err := carol.AddHoldInvoice(ctxt, invoiceReq) + require.NoError(t.t, err) + + _, err = alice.RouterClient.SendPaymentV2( ctx, &routerrpc.SendPaymentRequest{ - Dest: carolPubKey, - Amt: int64(htlcAmt), - PaymentHash: payHash, - FinalCltvDelta: finalCltvDelta, + PaymentRequest: carolInvoice.PaymentRequest, TimeoutSeconds: 60, FeeLimitMsat: noFeeLimitMsat, }, @@ -61,7 +71,7 @@ func testMultiHopRemoteForceCloseOnChainHtlcTimeout(net *lntest.NetworkHarness, // show that the HTLC has been locked in. nodes := []*lntest.HarnessNode{alice, bob, carol} err = wait.NoError(func() error { - return assertActiveHtlcs(nodes, payHash) + return assertActiveHtlcs(nodes, payHash[:]) }, defaultTimeout) require.NoError(t.t, err) @@ -73,7 +83,7 @@ func testMultiHopRemoteForceCloseOnChainHtlcTimeout(net *lntest.NetworkHarness, // transaction. This will let us exercise that Bob is able to sweep the // expired HTLC on Carol's version of the commitment transaction. If // Carol has an anchor, it will be swept too. - ctxt, _ := context.WithTimeout(ctxb, channelCloseTimeout) + ctxt, _ = context.WithTimeout(ctxb, channelCloseTimeout) closeChannelAndAssertType( ctxt, t, net, carol, bobChanPoint, c == commitTypeAnchors, true, @@ -168,6 +178,10 @@ func testMultiHopRemoteForceCloseOnChainHtlcTimeout(net *lntest.NetworkHarness, err = waitForNumChannelPendingForceClose(ctxt, bob, 0, nil) require.NoError(t.t, err) + // While we're here, we demonstrate some bugs in our handling of + // invoices that timeout on chain. + assertOnChainInvoiceState(ctxb, t, carol, preimage) + // We'll close out the test by closing the channel from Alice to Bob, // and then shutting down the new node we created as its no longer // needed. Coop close, no anchors. @@ -176,3 +190,39 @@ func testMultiHopRemoteForceCloseOnChainHtlcTimeout(net *lntest.NetworkHarness, ctxt, t, net, alice, aliceChanPoint, false, false, ) } + +// assertOnChainInvoiceState asserts that we have some bugs with how we handle +// hold invoices that are expired on-chain. +// - htlcs accepted: despite being timed out, our htlcs are still in accepted +// state +// - can settle: our invoice that has expired on-chain can still be settled +// even though we don't claim any htlcs. +func assertOnChainInvoiceState(ctx context.Context, t *harnessTest, + node *lntest.HarnessNode, preimage lntypes.Preimage) { + + hash := preimage.Hash() + inv, err := node.LookupInvoice(ctx, &lnrpc.PaymentHash{ + RHash: hash[:], + }) + require.NoError(t.t, err) + + for _, htlc := range inv.Htlcs { + require.Equal(t.t, lnrpc.InvoiceHTLCState_ACCEPTED, htlc.State) + } + + _, err = node.SettleInvoice(ctx, &invoicesrpc.SettleInvoiceMsg{ + Preimage: preimage[:], + }) + require.NoError(t.t, err, "expected erroneous invoice settle") + + inv, err = node.LookupInvoice(ctx, &lnrpc.PaymentHash{ + RHash: hash[:], + }) + require.NoError(t.t, err) + + require.True(t.t, inv.Settled, "expected erroneously settled invoice") + for _, htlc := range inv.Htlcs { + require.Equal(t.t, lnrpc.InvoiceHTLCState_SETTLED, htlc.State, + "expected htlcs to be erroneously settled") + } +} diff --git a/lntest/itest/lnd_test.go b/lntest/itest/lnd_test.go index 666b6641..bacc4bea 100644 --- a/lntest/itest/lnd_test.go +++ b/lntest/itest/lnd_test.go @@ -440,8 +440,9 @@ func waitForNumChannelPendingForceClose(ctx context.Context, forceCloseChans := resp.PendingForceClosingChannels if len(forceCloseChans) != expectedNum { - return fmt.Errorf("bob should have %d pending "+ - "force close channels but has %d", expectedNum, + return fmt.Errorf("%v should have %d pending "+ + "force close channels but has %d", + node.Cfg.Name, expectedNum, len(forceCloseChans)) } diff --git a/lntest/itest/lnd_test_list_on_test.go b/lntest/itest/lnd_test_list_on_test.go index 7f630c0d..679c572f 100644 --- a/lntest/itest/lnd_test_list_on_test.go +++ b/lntest/itest/lnd_test_list_on_test.go @@ -230,6 +230,10 @@ var allTestCases = []*testCase{ name: "hold invoice sender persistence", test: testHoldInvoicePersistence, }, + { + name: "hold invoice force close", + test: testHoldInvoiceForceClose, + }, { name: "cpfp", test: testCPFP,