diff --git a/config.go b/config.go index 7e3f4dee..232c90be 100644 --- a/config.go +++ b/config.go @@ -347,6 +347,8 @@ type Config struct { GcCanceledInvoicesOnTheFly bool `long:"gc-canceled-invoices-on-the-fly" description:"If true, we'll delete newly canceled invoices on the fly."` + Invoices *lncfg.Invoices `group:"invoices" namespace:"invoices"` + Routing *lncfg.Routing `group:"routing" namespace:"routing"` Gossip *lncfg.Gossip `group:"gossip" namespace:"gossip"` @@ -532,6 +534,9 @@ func DefaultConfig() Config { MaxChannelUpdateBurst: discovery.DefaultMaxChannelUpdateBurst, ChannelUpdateInterval: discovery.DefaultChannelUpdateInterval, }, + Invoices: &lncfg.Invoices{ + HoldExpiryDelta: lncfg.DefaultHoldInvoiceExpiryDelta, + }, MaxOutgoingCltvExpiry: htlcswitch.DefaultMaxOutgoingCltvExpiry, MaxChannelFeeAllocation: htlcswitch.DefaultMaxLinkFeeAllocation, MaxCommitFeeRateAnchors: lnwallet.DefaultAnchorsCommitMaxFeeRateSatPerVByte, @@ -1389,6 +1394,18 @@ func ValidateConfig(cfg Config, usageMessage string, return nil, err } + // Log a warning if our expiry delta is not greater than our incoming + // broadcast delta. We do not fail here because this value may be set + // to zero to intentionally keep lnd's behavior unchanged from when we + // didn't auto-cancel these invoices. + if cfg.Invoices.HoldExpiryDelta <= lncfg.DefaultIncomingBroadcastDelta { + ltndLog.Warnf("Invoice hold expiry delta: %v <= incoming "+ + "delta: %v, accepted hold invoices will force close "+ + "channels if they are not canceled manually", + cfg.Invoices.HoldExpiryDelta, + lncfg.DefaultIncomingBroadcastDelta) + } + // Validate the subconfigs for workers, caches, and the tower client. err = lncfg.Validate( cfg.Workers, diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index d3e9167e..13872a45 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -797,6 +797,20 @@ type mockInvoiceRegistry struct { cleanup func() } +type mockChainNotifier struct { + chainntnfs.ChainNotifier +} + +// RegisterBlockEpochNtfn mocks a successful call to register block +// notifications. +func (m *mockChainNotifier) RegisterBlockEpochNtfn(*chainntnfs.BlockEpoch) ( + *chainntnfs.BlockEpochEvent, error) { + + return &chainntnfs.BlockEpochEvent{ + Cancel: func() {}, + }, nil +} + func newMockRegistry(minDelta uint32) *mockInvoiceRegistry { cdb, cleanup, err := newDB() if err != nil { @@ -805,7 +819,10 @@ func newMockRegistry(minDelta uint32) *mockInvoiceRegistry { registry := invoices.NewRegistry( cdb, - invoices.NewInvoiceExpiryWatcher(clock.NewDefaultClock()), + invoices.NewInvoiceExpiryWatcher( + clock.NewDefaultClock(), 0, 0, nil, + &mockChainNotifier{}, + ), &invoices.RegistryConfig{ FinalCltvRejectDelta: 5, }, diff --git a/invoices/invoice_expiry_watcher.go b/invoices/invoice_expiry_watcher.go index 14257581..70d73608 100644 --- a/invoices/invoice_expiry_watcher.go +++ b/invoices/invoice_expiry_watcher.go @@ -5,6 +5,8 @@ import ( "sync" "time" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lntypes" @@ -34,6 +36,28 @@ func (e invoiceExpiryTs) Less(other queue.PriorityQueueItem) bool { return e.Expiry.Before(other.(*invoiceExpiryTs).Expiry) } +// Compile time assertion that invoiceExpiryHeight implements invoiceExpiry. +var _ invoiceExpiry = (*invoiceExpiryHeight)(nil) + +// invoiceExpiryHeight holds information about an invoice which can be used to +// cancel it based on its expiry height. +type invoiceExpiryHeight struct { + paymentHash lntypes.Hash + expiryHeight uint32 +} + +// Less implements PriorityQueueItem.Less such that the top item in the +// priority queue is the lowest block height. +func (b invoiceExpiryHeight) Less(other queue.PriorityQueueItem) bool { + return b.expiryHeight < other.(*invoiceExpiryHeight).expiryHeight +} + +// expired returns a boolean that indicates whether this entry has expired, +// taking our expiry delta into account. +func (b invoiceExpiryHeight) expired(currentHeight, delta uint32) bool { + return currentHeight+delta >= b.expiryHeight +} + // InvoiceExpiryWatcher handles automatic invoice cancellation of expried // invoices. Upon start InvoiceExpiryWatcher will retrieve all pending (not yet // settled or canceled) invoices invoices to its watcing queue. When a new @@ -49,6 +73,21 @@ type InvoiceExpiryWatcher struct { // It is useful for testing. clock clock.Clock + // notifier provides us with block height updates. + notifier chainntnfs.ChainNotifier + + // blockExpiryDelta is the number of blocks before a htlc's expiry that + // we expire the invoice based on expiry height. We use a delta because + // we will go to some delta before our expiry, so we want to cancel + // before this to prevent force closes. + blockExpiryDelta uint32 + + // currentHeight is the current block height. + currentHeight uint32 + + // currentHash is the block hash for our current height. + currentHash *chainhash.Hash + // cancelInvoice is a template method that cancels an expired invoice. cancelInvoice func(lntypes.Hash, bool) error @@ -56,6 +95,15 @@ type InvoiceExpiryWatcher struct { // the next invoice to expire. timestampExpiryQueue queue.PriorityQueue + // blockExpiryQueue holds blockExpiry items and is used to find the + // next invoice to expire based on block height. Only hold invoices + // with active htlcs are added to this queue, because they require + // manual cancellation when the hltc is going to time out. Items in + // this queue may already be in the timestampExpiryQueue, this is ok + // because they will not be expired based on timestamp if they have + // active htlcs. + blockExpiryQueue queue.PriorityQueue + // newInvoices channel is used to wake up the main loop when a new // invoices is added. newInvoices chan []invoiceExpiry @@ -67,11 +115,18 @@ type InvoiceExpiryWatcher struct { } // NewInvoiceExpiryWatcher creates a new InvoiceExpiryWatcher instance. -func NewInvoiceExpiryWatcher(clock clock.Clock) *InvoiceExpiryWatcher { +func NewInvoiceExpiryWatcher(clock clock.Clock, + expiryDelta, startHeight uint32, startHash *chainhash.Hash, + notifier chainntnfs.ChainNotifier) *InvoiceExpiryWatcher { + return &InvoiceExpiryWatcher{ - clock: clock, - newInvoices: make(chan []invoiceExpiry), - quit: make(chan struct{}), + clock: clock, + notifier: notifier, + blockExpiryDelta: expiryDelta, + currentHeight: startHeight, + currentHash: startHash, + newInvoices: make(chan []invoiceExpiry), + quit: make(chan struct{}), } } @@ -91,8 +146,17 @@ func (ew *InvoiceExpiryWatcher) Start( ew.started = true ew.cancelInvoice = cancelInvoice + + ntfn, err := ew.notifier.RegisterBlockEpochNtfn(&chainntnfs.BlockEpoch{ + Height: int32(ew.currentHeight), + Hash: ew.currentHash, + }) + if err != nil { + return err + } + ew.wg.Add(1) - go ew.mainLoop() + go ew.mainLoop(ntfn) return nil } @@ -122,6 +186,32 @@ func makeInvoiceExpiry(paymentHash lntypes.Hash, case channeldb.ContractOpen: return makeTimestampExpiry(paymentHash, invoice) + // If an invoice has active htlcs, we want to expire it based on block + // height. We only do this for hodl invoices, since regular invoices + // should resolve themselves automatically. + case channeldb.ContractAccepted: + if !invoice.HodlInvoice { + log.Debugf("Invoice in accepted state not added to "+ + "expiry watcher: %v", paymentHash) + + return nil + } + + var minHeight uint32 + for _, htlc := range invoice.Htlcs { + // We only care about accepted htlcs, since they will + // trigger force-closes. + if htlc.State != channeldb.HtlcStateAccepted { + continue + } + + if minHeight == 0 || htlc.Expiry < minHeight { + minHeight = htlc.Expiry + } + } + + return makeHeightExpiry(paymentHash, minHeight) + default: log.Debugf("Invoice not added to expiry watcher: %v", paymentHash) @@ -151,18 +241,36 @@ func makeTimestampExpiry(paymentHash lntypes.Hash, } } +// makeHeightExpiry creates height-based expiry for an invoice based on its +// lowest htlc expiry height. +func makeHeightExpiry(paymentHash lntypes.Hash, + minHeight uint32) *invoiceExpiryHeight { + + if minHeight == 0 { + log.Warnf("make height expiry called with 0 height") + return nil + } + + return &invoiceExpiryHeight{ + paymentHash: paymentHash, + expiryHeight: minHeight, + } +} + // AddInvoices adds invoices to the InvoiceExpiryWatcher. func (ew *InvoiceExpiryWatcher) AddInvoices(invoices ...invoiceExpiry) { - if len(invoices) > 0 { - select { - case ew.newInvoices <- invoices: - log.Debugf("Added %d invoices to the expiry watcher", - len(invoices)) + if len(invoices) == 0 { + return + } - // Select on quit too so that callers won't get blocked in case - // of concurrent shutdown. - case <-ew.quit: - } + select { + case ew.newInvoices <- invoices: + log.Debugf("Added %d invoices to the expiry watcher", + len(invoices)) + + // Select on quit too so that callers won't get blocked in case + // of concurrent shutdown. + case <-ew.quit: } } @@ -178,6 +286,23 @@ func (ew *InvoiceExpiryWatcher) nextTimestampExpiry() <-chan time.Time { return nil } +// nextHeightExpiry returns a channel that will immediately be read from if +// the top item on our queue has expired. +func (ew *InvoiceExpiryWatcher) nextHeightExpiry() <-chan uint32 { + if ew.blockExpiryQueue.Empty() { + return nil + } + + top := ew.blockExpiryQueue.Top().(*invoiceExpiryHeight) + if !top.expired(ew.currentHeight, ew.blockExpiryDelta) { + return nil + } + + blockChan := make(chan uint32, 1) + blockChan <- top.expiryHeight + return blockChan +} + // cancelNextExpiredInvoice will cancel the next expired invoice and removes // it from the expiry queue. func (ew *InvoiceExpiryWatcher) cancelNextExpiredInvoice() { @@ -198,6 +323,25 @@ func (ew *InvoiceExpiryWatcher) cancelNextExpiredInvoice() { } } +// cancelNextHeightExpiredInvoice looks at our height based queue and expires +// the next invoice if we have reached its expiry block. +func (ew *InvoiceExpiryWatcher) cancelNextHeightExpiredInvoice() { + if ew.blockExpiryQueue.Empty() { + return + } + + top := ew.blockExpiryQueue.Top().(*invoiceExpiryHeight) + if !top.expired(ew.currentHeight, ew.blockExpiryDelta) { + return + } + + // We always force-cancel block-based expiry so that we can + // cancel invoices that have been accepted but not yet resolved. + // This helps us avoid force closes. + ew.expireInvoice(top.paymentHash, true) + ew.blockExpiryQueue.Pop() +} + // expireInvoice attempts to expire an invoice and logs an error if we get an // unexpected error. func (ew *InvoiceExpiryWatcher) expireInvoice(hash lntypes.Hash, force bool) { @@ -226,6 +370,11 @@ func (ew *InvoiceExpiryWatcher) pushInvoices(invoices []invoiceExpiry) { ew.timestampExpiryQueue.Push(expiry) } + case *invoiceExpiryHeight: + if expiry != nil { + ew.blockExpiryQueue.Push(expiry) + } + default: log.Errorf("unexpected queue item: %T", inv) } @@ -234,12 +383,20 @@ func (ew *InvoiceExpiryWatcher) pushInvoices(invoices []invoiceExpiry) { // mainLoop is a goroutine that receives new invoices and handles cancellation // of expired invoices. -func (ew *InvoiceExpiryWatcher) mainLoop() { - defer ew.wg.Done() +func (ew *InvoiceExpiryWatcher) mainLoop(blockNtfns *chainntnfs.BlockEpochEvent) { + defer func() { + blockNtfns.Cancel() + ew.wg.Done() + }() + + // We have two different queues, so we use a different cancel method + // depending on which expiry condition we have hit. Starting with time + // based expiry is an arbitrary choice to start off. + cancelNext := ew.cancelNextExpiredInvoice for { // Cancel any invoices that may have expired. - ew.cancelNextExpiredInvoice() + cancelNext() select { @@ -252,13 +409,29 @@ func (ew *InvoiceExpiryWatcher) mainLoop() { default: select { + // Wait until the next invoice expires. case <-ew.nextTimestampExpiry(): - // Wait until the next invoice expires. + cancelNext = ew.cancelNextExpiredInvoice + continue + + case <-ew.nextHeightExpiry(): + cancelNext = ew.cancelNextHeightExpiredInvoice continue case newInvoices := <-ew.newInvoices: ew.pushInvoices(newInvoices) + // Consume new blocks. + case block, ok := <-blockNtfns.Epochs: + if !ok { + log.Debugf("block notifications " + + "canceled") + return + } + + ew.currentHeight = uint32(block.Height) + ew.currentHash = block.Hash + case <-ew.quit: return } diff --git a/invoices/invoice_expiry_watcher_test.go b/invoices/invoice_expiry_watcher_test.go index e2c7ea82..63ddfb92 100644 --- a/invoices/invoice_expiry_watcher_test.go +++ b/invoices/invoice_expiry_watcher_test.go @@ -5,6 +5,8 @@ import ( "testing" "time" + "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lntypes" ) @@ -19,13 +21,40 @@ type invoiceExpiryWatcherTest struct { canceledInvoices []lntypes.Hash } +type mockChainNotifier struct { + chainntnfs.ChainNotifier + + blockChan chan *chainntnfs.BlockEpoch +} + +func newMockNotifier() *mockChainNotifier { + return &mockChainNotifier{ + blockChan: make(chan *chainntnfs.BlockEpoch), + } +} + +// RegisterBlockEpochNtfn mocks a block epoch notification, using the mock's +// block channel to deliver blocks to the client. +func (m *mockChainNotifier) RegisterBlockEpochNtfn(*chainntnfs.BlockEpoch) ( + *chainntnfs.BlockEpochEvent, error) { + + return &chainntnfs.BlockEpochEvent{ + Epochs: m.blockChan, + Cancel: func() {}, + }, nil +} + // newInvoiceExpiryWatcherTest creates a new InvoiceExpiryWatcher test fixture // and sets up the test environment. func newInvoiceExpiryWatcherTest(t *testing.T, now time.Time, numExpiredInvoices, numPendingInvoices int) *invoiceExpiryWatcherTest { + mockNotifier := newMockNotifier() test := &invoiceExpiryWatcherTest{ - watcher: NewInvoiceExpiryWatcher(clock.NewTestClock(testTime)), + watcher: NewInvoiceExpiryWatcher( + clock.NewTestClock(testTime), 0, + uint32(testCurrentHeight), nil, mockNotifier, + ), testData: generateInvoiceExpiryTestData( t, now, 0, numExpiredInvoices, numPendingInvoices, ), @@ -84,7 +113,10 @@ func (t *invoiceExpiryWatcherTest) checkExpectations() { // Tests that InvoiceExpiryWatcher can be started and stopped. func TestInvoiceExpiryWatcherStartStop(t *testing.T) { - watcher := NewInvoiceExpiryWatcher(clock.NewTestClock(testTime)) + watcher := NewInvoiceExpiryWatcher( + clock.NewTestClock(testTime), 0, uint32(testCurrentHeight), nil, + newMockNotifier(), + ) cancel := func(lntypes.Hash, bool) error { t.Fatalf("unexpected call") return nil @@ -172,3 +204,115 @@ func TestInvoiceExpiryWhenAddingMultipleInvoices(t *testing.T) { test.watcher.Stop() test.checkExpectations() } + +// TestExpiredHodlInv tests expiration of an already-expired hodl invoice +// which has no htlcs. +func TestExpiredHodlInv(t *testing.T) { + t.Parallel() + + creationDate := testTime.Add(time.Hour * -24) + expiry := time.Hour + + test := setupHodlExpiry( + t, creationDate, expiry, 0, channeldb.ContractOpen, nil, + ) + + test.assertCanceled(t, test.hash) + test.watcher.Stop() +} + +// TestAcceptedHodlNotExpired tests that hodl invoices which are in an accepted +// state are not expired once their time-based expiry elapses, using a regular +// invoice that expires at the same time as a control to ensure that invoices +// with that timestamp would otherwise be expired. +func TestAcceptedHodlNotExpired(t *testing.T) { + t.Parallel() + + creationDate := testTime + expiry := time.Hour + + test := setupHodlExpiry( + t, creationDate, expiry, 0, channeldb.ContractAccepted, nil, + ) + defer test.watcher.Stop() + + // Add another invoice that will expire at our expiry time as a control + // value. + tsExpires := &invoiceExpiryTs{ + PaymentHash: lntypes.Hash{1, 2, 3}, + Expiry: creationDate.Add(expiry), + Keysend: true, + } + test.watcher.AddInvoices(tsExpires) + + test.mockClock.SetTime(creationDate.Add(expiry + 1)) + + // Assert that only the ts expiry invoice is expired. + test.assertCanceled(t, tsExpires.PaymentHash) +} + +// TestHeightAlreadyExpired tests the case where we add an invoice with htlcs +// that have already expired to the expiry watcher. +func TestHeightAlreadyExpired(t *testing.T) { + t.Parallel() + + expiredHtlc := []*channeldb.InvoiceHTLC{ + { + State: channeldb.HtlcStateAccepted, + Expiry: uint32(testCurrentHeight), + }, + } + + test := setupHodlExpiry( + t, testTime, time.Hour, 0, channeldb.ContractAccepted, + expiredHtlc, + ) + defer test.watcher.Stop() + + test.assertCanceled(t, test.hash) +} + +// TestExpiryHeightArrives tests the case where we add a hodl invoice to the +// expiry watcher when it has no htlcs, htlcs are added and then they finally +// expire. We use a non-zero delta for this test to check that we expire with +// sufficient buffer. +func TestExpiryHeightArrives(t *testing.T) { + var ( + creationDate = testTime + expiry = time.Hour * 2 + delta uint32 = 1 + ) + + // Start out with a hodl invoice that is open, and has no htlcs. + test := setupHodlExpiry( + t, creationDate, expiry, delta, channeldb.ContractOpen, nil, + ) + defer test.watcher.Stop() + + htlc1 := uint32(testCurrentHeight + 10) + expiry1 := makeHeightExpiry(test.hash, htlc1) + + // Add htlcs to our invoice and progress its state to accepted. + test.watcher.AddInvoices(expiry1) + test.setState(channeldb.ContractAccepted) + + // Progress time so that our expiry has elapsed. We no longer expect + // this invoice to be canceled because it has been accepted. + test.mockClock.SetTime(creationDate.Add(expiry)) + + // Tick our mock block subscription with the next block, we don't + // expect anything to happen. + currentHeight := uint32(testCurrentHeight + 1) + test.announceBlock(t, currentHeight) + + // Now, we add another htlc to the invoice. This one has a lower expiry + // height than our current ones. + htlc2 := currentHeight + 5 + expiry2 := makeHeightExpiry(test.hash, htlc2) + test.watcher.AddInvoices(expiry2) + + // Announce our lowest htlc expiry block minus our delta, the invoice + // should be expired now. + test.announceBlock(t, htlc2-delta) + test.assertCanceled(t, test.hash) +} diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index da4c2136..5af1d05b 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -1118,6 +1118,16 @@ func (i *InvoiceRegistry) notifyExitHopHtlcLocked( } + // If we have fully accepted the set of htlcs for this invoice, + // we can now add it to our invoice expiry watcher. We do not + // add invoices before they are fully accepted, because it is + // possible that we MppTimeout the htlcs, and then our relevant + // expiry height could change. + if res.outcome == resultAccepted { + expiry := makeInvoiceExpiry(ctx.hash, invoice) + i.expiryWatcher.AddInvoices(expiry) + } + i.hodlSubscribe(hodlChan, ctx.circuitKey) default: diff --git a/invoices/invoiceregistry_test.go b/invoices/invoiceregistry_test.go index c049bb53..a7573778 100644 --- a/invoices/invoiceregistry_test.go +++ b/invoices/invoiceregistry_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/lightningnetwork/lnd/amp" + "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lntypes" @@ -352,7 +353,11 @@ func TestSettleHoldInvoice(t *testing.T) { FinalCltvRejectDelta: testFinalCltvRejectDelta, Clock: clock.NewTestClock(testTime), } - registry := NewRegistry(cdb, NewInvoiceExpiryWatcher(cfg.Clock), &cfg) + + expiryWatcher := NewInvoiceExpiryWatcher( + cfg.Clock, 0, uint32(testCurrentHeight), nil, newMockNotifier(), + ) + registry := NewRegistry(cdb, expiryWatcher, &cfg) err = registry.Start() if err != nil { @@ -521,7 +526,10 @@ func TestCancelHoldInvoice(t *testing.T) { FinalCltvRejectDelta: testFinalCltvRejectDelta, Clock: clock.NewTestClock(testTime), } - registry := NewRegistry(cdb, NewInvoiceExpiryWatcher(cfg.Clock), &cfg) + expiryWatcher := NewInvoiceExpiryWatcher( + cfg.Clock, 0, uint32(testCurrentHeight), nil, newMockNotifier(), + ) + registry := NewRegistry(cdb, expiryWatcher, &cfg) err = registry.Start() if err != nil { @@ -946,7 +954,9 @@ func TestInvoiceExpiryWithRegistry(t *testing.T) { Clock: testClock, } - expiryWatcher := NewInvoiceExpiryWatcher(cfg.Clock) + expiryWatcher := NewInvoiceExpiryWatcher( + cfg.Clock, 0, uint32(testCurrentHeight), nil, newMockNotifier(), + ) registry := NewRegistry(cdb, expiryWatcher, &cfg) // First prefill the Channel DB with some pre-existing invoices, @@ -1049,7 +1059,9 @@ func TestOldInvoiceRemovalOnStart(t *testing.T) { GcCanceledInvoicesOnStartup: true, } - expiryWatcher := NewInvoiceExpiryWatcher(cfg.Clock) + expiryWatcher := NewInvoiceExpiryWatcher( + cfg.Clock, 0, uint32(testCurrentHeight), nil, newMockNotifier(), + ) registry := NewRegistry(cdb, expiryWatcher, &cfg) // First prefill the Channel DB with some pre-existing expired invoices. @@ -1107,6 +1119,222 @@ func TestOldInvoiceRemovalOnStart(t *testing.T) { require.Equal(t, expected, response.Invoices) } +// TestHeightExpiryWithRegistry tests our height-based invoice expiry for +// invoices paid with single and multiple htlcs, testing the case where the +// invoice is settled before expiry (and thus not canceled), and the case +// where the invoice is expired. +func TestHeightExpiryWithRegistry(t *testing.T) { + t.Run("single shot settled before expiry", func(t *testing.T) { + testHeightExpiryWithRegistry(t, 1, true) + }) + + t.Run("single shot expires", func(t *testing.T) { + testHeightExpiryWithRegistry(t, 1, false) + }) + + t.Run("mpp settled before expiry", func(t *testing.T) { + testHeightExpiryWithRegistry(t, 2, true) + }) + + t.Run("mpp expires", func(t *testing.T) { + testHeightExpiryWithRegistry(t, 2, false) + }) +} + +func testHeightExpiryWithRegistry(t *testing.T, numParts int, settle bool) { + t.Parallel() + defer timeout()() + + ctx := newTestContext(t) + defer ctx.cleanup() + + require.Greater(t, numParts, 0, "test requires at least one part") + + // Add a hold invoice, we set a non-nil payment request so that this + // invoice is not considered a keysend by the expiry watcher. + invoice := *testInvoice + invoice.HodlInvoice = true + invoice.PaymentRequest = []byte{1, 2, 3} + + _, err := ctx.registry.AddInvoice(&invoice, testInvoicePaymentHash) + require.NoError(t, err) + + payLoad := testPayload + if numParts > 1 { + payLoad = &mockPayload{ + mpp: record.NewMPP(testInvoiceAmt, [32]byte{}), + } + } + + htlcAmt := invoice.Terms.Value / lnwire.MilliSatoshi(numParts) + hodlChan := make(chan interface{}, numParts) + for i := 0; i < numParts; i++ { + // We bump our expiry height for each htlc so that we can test + // that the lowest expiry height is used. + expiry := testHtlcExpiry + uint32(i) + + resolution, err := ctx.registry.NotifyExitHopHtlc( + testInvoicePaymentHash, htlcAmt, expiry, + testCurrentHeight, getCircuitKey(uint64(i)), hodlChan, + payLoad, + ) + require.NoError(t, err) + require.Nil(t, resolution, "did not expect direct resolution") + } + + require.Eventually(t, func() bool { + inv, err := ctx.registry.LookupInvoice(testInvoicePaymentHash) + require.NoError(t, err) + + return inv.State == channeldb.ContractAccepted + }, time.Second, time.Millisecond*100) + + // Now that we've added our htlc(s), we tick our test clock to our + // invoice expiry time. We don't expect the invoice to be canceled + // based on its expiry time now that we have active htlcs. + ctx.clock.SetTime(invoice.CreationDate.Add(invoice.Terms.Expiry + 1)) + + // The expiry watcher loop takes some time to process the new clock + // time. We mine the block before our expiry height, our mock will block + // until the expiry watcher consumes this height, so we can be sure + // that the expiry loop has run at least once after this block is + // consumed. + ctx.notifier.blockChan <- &chainntnfs.BlockEpoch{ + Height: int32(testHtlcExpiry - 1), + } + + // If we want to settle our invoice in this test, we do so now. + if settle { + err = ctx.registry.SettleHodlInvoice(testInvoicePreimage) + require.NoError(t, err) + + for i := 0; i < numParts; i++ { + htlcResolution := (<-hodlChan).(HtlcResolution) + require.NotNil(t, htlcResolution) + settleResolution := checkSettleResolution( + t, htlcResolution, testInvoicePreimage, + ) + require.Equal(t, ResultSettled, settleResolution.Outcome) + } + } + + // Now we mine our htlc's expiry height. + ctx.notifier.blockChan <- &chainntnfs.BlockEpoch{ + Height: int32(testHtlcExpiry), + } + + // If we did not settle the invoice before its expiry, we now expect + // a cancelation. + expectedState := channeldb.ContractSettled + if !settle { + expectedState = channeldb.ContractCanceled + + htlcResolution := (<-hodlChan).(HtlcResolution) + require.NotNil(t, htlcResolution) + checkFailResolution( + t, htlcResolution, ResultCanceled, + ) + } + + // Finally, lookup the invoice and assert that we have the state we + // expect. + inv, err := ctx.registry.LookupInvoice(testInvoicePaymentHash) + require.NoError(t, err) + require.Equal(t, expectedState, inv.State, "expected "+ + "hold invoice: %v, got: %v", expectedState, inv.State) +} + +// TestMultipleSetHeightExpiry pays a hold invoice with two mpp sets, testing +// that the invoice expiry watcher only uses the expiry height of the second, +// successful set to cancel the invoice, and does not cancel early using the +// expiry height of the first set that was canceled back due to mpp timeout. +func TestMultipleSetHeightExpiry(t *testing.T) { + t.Parallel() + defer timeout()() + + ctx := newTestContext(t) + defer ctx.cleanup() + + // Add a hold invoice. + invoice := *testInvoice + invoice.HodlInvoice = true + + _, err := ctx.registry.AddInvoice(&invoice, testInvoicePaymentHash) + require.NoError(t, err) + + mppPayload := &mockPayload{ + mpp: record.NewMPP(testInvoiceAmt, [32]byte{}), + } + + // Send htlc 1. + hodlChan1 := make(chan interface{}, 1) + resolution, err := ctx.registry.NotifyExitHopHtlc( + testInvoicePaymentHash, invoice.Terms.Value/2, + testHtlcExpiry, + testCurrentHeight, getCircuitKey(10), hodlChan1, mppPayload, + ) + require.NoError(t, err) + require.Nil(t, resolution, "did not expect direct resolution") + + // Simulate mpp timeout releasing htlc 1. + ctx.clock.SetTime(testTime.Add(30 * time.Second)) + + htlcResolution := (<-hodlChan1).(HtlcResolution) + failResolution, ok := htlcResolution.(*HtlcFailResolution) + require.True(t, ok, "expected fail resolution, got: %T", resolution) + require.Equal(t, ResultMppTimeout, failResolution.Outcome, + "expected MPP Timeout, got: %v", failResolution.Outcome) + + // Notify the expiry height for our first htlc. We don't expect the + // invoice to be expired based on block height because the htlc set + // was never completed. + ctx.notifier.blockChan <- &chainntnfs.BlockEpoch{ + Height: int32(testHtlcExpiry), + } + + // Now we will send a full set of htlcs for the invoice with a higher + // expiry height. We expect the invoice to move into the accepted state. + expiry := testHtlcExpiry + 5 + + // Send htlc 2. + hodlChan2 := make(chan interface{}, 1) + resolution, err = ctx.registry.NotifyExitHopHtlc( + testInvoicePaymentHash, invoice.Terms.Value/2, expiry, + testCurrentHeight, getCircuitKey(11), hodlChan2, mppPayload, + ) + require.NoError(t, err) + require.Nil(t, resolution, "did not expect direct resolution") + + // Send htlc 3. + hodlChan3 := make(chan interface{}, 1) + resolution, err = ctx.registry.NotifyExitHopHtlc( + testInvoicePaymentHash, invoice.Terms.Value/2, expiry, + testCurrentHeight, getCircuitKey(12), hodlChan3, mppPayload, + ) + require.NoError(t, err) + require.Nil(t, resolution, "did not expect direct resolution") + + // Assert that we've reached an accepted state because the invoice has + // been paid with a complete set. + inv, err := ctx.registry.LookupInvoice(testInvoicePaymentHash) + require.NoError(t, err) + require.Equal(t, channeldb.ContractAccepted, inv.State, "expected "+ + "hold invoice accepted") + + // Now we will notify the expiry height for the new set of htlcs. We + // expect the invoice to be canceled by the expiry watcher. + ctx.notifier.blockChan <- &chainntnfs.BlockEpoch{ + Height: int32(expiry), + } + + require.Eventuallyf(t, func() bool { + inv, err := ctx.registry.LookupInvoice(testInvoicePaymentHash) + require.NoError(t, err) + + return inv.State == channeldb.ContractCanceled + }, testTimeout, time.Millisecond*100, "invoice not canceled") +} + // TestSettleInvoicePaymentAddrRequired tests that if an incoming payment has // an invoice that requires the payment addr bit to be set, and the incoming // payment doesn't include an mpp payload, then the payment is rejected. diff --git a/invoices/test_utils_test.go b/invoices/test_utils_test.go index 3e49a957..f013f0d5 100644 --- a/invoices/test_utils_test.go +++ b/invoices/test_utils_test.go @@ -8,11 +8,13 @@ import ( "io/ioutil" "os" "runtime/pprof" + "sync" "testing" "time" "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg" + "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lntypes" @@ -179,6 +181,7 @@ func newTestChannelDB(clock clock.Clock) (*channeldb.DB, func(), error) { type testContext struct { cdb *channeldb.DB registry *InvoiceRegistry + notifier *mockChainNotifier clock *clock.TestClock cleanup func() @@ -193,7 +196,11 @@ func newTestContext(t *testing.T) *testContext { t.Fatal(err) } - expiryWatcher := NewInvoiceExpiryWatcher(clock) + notifier := newMockNotifier() + + expiryWatcher := NewInvoiceExpiryWatcher( + clock, 0, uint32(testCurrentHeight), nil, notifier, + ) // Instantiate and start the invoice ctx.registry. cfg := RegistryConfig{ @@ -212,6 +219,7 @@ func newTestContext(t *testing.T) *testContext { ctx := testContext{ cdb: cdb, registry: registry, + notifier: notifier, clock: clock, t: t, cleanup: func() { @@ -365,3 +373,111 @@ func checkFailResolution(t *testing.T, res HtlcResolution, return failResolution } + +type hodlExpiryTest struct { + hash lntypes.Hash + state channeldb.ContractState + stateLock sync.Mutex + mockNotifier *mockChainNotifier + mockClock *clock.TestClock + cancelChan chan lntypes.Hash + watcher *InvoiceExpiryWatcher +} + +func (h *hodlExpiryTest) setState(state channeldb.ContractState) { + h.stateLock.Lock() + defer h.stateLock.Unlock() + + h.state = state +} + +func (h *hodlExpiryTest) announceBlock(t *testing.T, height uint32) { + select { + case h.mockNotifier.blockChan <- &chainntnfs.BlockEpoch{ + Height: int32(height), + }: + + case <-time.After(testTimeout): + t.Fatalf("block %v not consumed", height) + } +} + +func (h *hodlExpiryTest) assertCanceled(t *testing.T, expected lntypes.Hash) { + select { + case actual := <-h.cancelChan: + require.Equal(t, expected, actual) + + case <-time.After(testTimeout): + t.Fatalf("invoice: %v not canceled", h.hash) + } +} + +// setupHodlExpiry creates a hodl invoice in our expiry watcher and runs an +// arbitrary update function which advances the invoices's state. +func setupHodlExpiry(t *testing.T, creationDate time.Time, + expiry time.Duration, heightDelta uint32, + startState channeldb.ContractState, + startHtlcs []*channeldb.InvoiceHTLC) *hodlExpiryTest { + + mockNotifier := newMockNotifier() + mockClock := clock.NewTestClock(testTime) + + test := &hodlExpiryTest{ + state: startState, + watcher: NewInvoiceExpiryWatcher( + mockClock, heightDelta, uint32(testCurrentHeight), nil, + mockNotifier, + ), + cancelChan: make(chan lntypes.Hash), + mockNotifier: mockNotifier, + mockClock: mockClock, + } + + // Use an unbuffered channel to block on cancel calls so that the test + // does not exit before we've processed all the invoices we expect. + cancelImpl := func(paymentHash lntypes.Hash, force bool) error { + test.stateLock.Lock() + currentState := test.state + test.stateLock.Unlock() + + if currentState != channeldb.ContractOpen && !force { + return nil + } + + select { + case test.cancelChan <- paymentHash: + case <-time.After(testTimeout): + } + + return nil + } + + require.NoError(t, test.watcher.Start(cancelImpl)) + + // We set preimage and hash so that we can use our existing test + // helpers. In practice we would only have the hash, but this does not + // affect what we're testing at all. + preimage := lntypes.Preimage{1} + test.hash = preimage.Hash() + + invoice := newTestInvoice(t, preimage, creationDate, expiry) + invoice.State = startState + invoice.HodlInvoice = true + invoice.Htlcs = make(map[channeldb.CircuitKey]*channeldb.InvoiceHTLC) + + // If we have any htlcs, add them with unique circult keys. + for i, htlc := range startHtlcs { + key := channeldb.CircuitKey{ + HtlcID: uint64(i), + } + + invoice.Htlcs[key] = htlc + } + + // Create an expiry entry for our invoice in its starting state. This + // mimics adding invoices to the watcher on start. + entry := makeInvoiceExpiry(test.hash, invoice) + test.watcher.AddInvoices(entry) + + return test +} diff --git a/lncfg/invoices.go b/lncfg/invoices.go new file mode 100644 index 00000000..16a52d88 --- /dev/null +++ b/lncfg/invoices.go @@ -0,0 +1,12 @@ +package lncfg + +// DefaultHoldInvoiceExpiryDelta defines the number of blocks before the expiry +// height of a hold invoice's htlc that lnd will automatically cancel the +// invoice to prevent the channel from force closing. This value *must* be +// greater than DefaultIncomingBroadcastDelta to prevent force closes. +const DefaultHoldInvoiceExpiryDelta = DefaultIncomingBroadcastDelta + 2 + +// Invoices holds the configuration options for invoices. +type Invoices struct { + HoldExpiryDelta uint32 `long:"holdexpirydelta" description:"The number of blocks before a hold invoice's htlc expires that the invoice should be canceled to prevent a force close. Force closes will not be prevented if this value is not greater than DefaultIncomingBroadcastDelta."` +} diff --git a/lntest/itest/lnd_hold_invoice_force_test.go b/lntest/itest/lnd_hold_invoice_force_test.go index 00831000..7a71ac74 100644 --- a/lntest/itest/lnd_hold_invoice_force_test.go +++ b/lntest/itest/lnd_hold_invoice_force_test.go @@ -14,9 +14,8 @@ import ( "github.com/stretchr/testify/require" ) -// testHoldInvoiceForceClose demonstrates that recipients of hold invoices -// will not release active htlcs for their own invoices when they expire, -// resulting in a force close of their channel. +// testHoldInvoiceForceClose tests cancelation of accepted hold invoices which +// would otherwise trigger force closes when they expire. func testHoldInvoiceForceClose(net *lntest.NetworkHarness, t *harnessTest) { ctxb, cancel := context.WithCancel(context.Background()) defer cancel() @@ -94,38 +93,43 @@ func testHoldInvoiceForceClose(net *lntest.NetworkHarness, t *harnessTest) { require.NoError(t.t, net.Alice.WaitForBlockchainSync(ctxb)) require.NoError(t.t, net.Bob.WaitForBlockchainSync(ctxb)) - // Alice should have a waiting-close channel because she has force - // closed to time out the htlc. - assertNumPendingChannels(t, net.Alice, 1, 0) - - // We should have our force close tx in the mempool. - mineBlocks(t, net, 1, 1) - - // Ensure alice and bob are synced to chain after we've mined our force - // close. - require.NoError(t.t, net.Alice.WaitForBlockchainSync(ctxb)) - require.NoError(t.t, net.Bob.WaitForBlockchainSync(ctxb)) - - // At this point, Bob's channel should be resolved because his htlc is - // expired, so no further action is required. Alice will still have a - // pending force close channel because she needs to resolve the htlc. - assertNumPendingChannels(t, net.Alice, 0, 1) - assertNumPendingChannels(t, net.Bob, 0, 0) - + // Our channel should not have been force closed, instead we expect our + // channel to still be open and our invoice to have been canceled before + // expiry. ctxt, _ = context.WithTimeout(ctxb, defaultTimeout) - err = waitForNumChannelPendingForceClose(ctxt, net.Alice, 1, - func(channel *lnrpcForceCloseChannel) error { - numHtlcs := len(channel.PendingHtlcs) - if numHtlcs != 1 { - return fmt.Errorf("expected 1 htlc, got: "+ - "%v", numHtlcs) - } - - return nil - }, - ) + chanInfo, err := getChanInfo(ctxt, net.Alice) require.NoError(t.t, err) - // Cleanup Alice's force close. - cleanupForceClose(t, net, net.Alice, chanPoint) + fundingTxID, err := lnrpc.GetChanPointFundingTxid(chanPoint) + require.NoError(t.t, err) + chanStr := fmt.Sprintf("%v:%v", fundingTxID, chanPoint.OutputIndex) + require.Equal(t.t, chanStr, chanInfo.ChannelPoint) + + err = wait.NoError(func() error { + inv, err := net.Bob.LookupInvoice(ctxt, &lnrpc.PaymentHash{ + RHash: payHash[:], + }) + if err != nil { + return err + } + + if inv.State != lnrpc.Invoice_CANCELED { + return fmt.Errorf("expected canceled invoice, got: %v", + inv.State) + } + + for _, htlc := range inv.Htlcs { + if htlc.State != lnrpc.InvoiceHTLCState_CANCELED { + return fmt.Errorf("expected htlc canceled, "+ + "got: %v", htlc.State) + } + } + + return nil + }, defaultTimeout) + require.NoError(t.t, err, "expected canceled invoice") + + // Clean up the channel. + ctxt, _ = context.WithTimeout(ctxb, channelCloseTimeout) + closeChannelAndAssert(ctxt, t, net, net.Alice, chanPoint, false) } diff --git a/lntest/itest/lnd_multi-hop_remote_force_close_on_chain_htlc_timeout_test.go b/lntest/itest/lnd_multi-hop_remote_force_close_on_chain_htlc_timeout_test.go index 68547328..cf237c89 100644 --- a/lntest/itest/lnd_multi-hop_remote_force_close_on_chain_htlc_timeout_test.go +++ b/lntest/itest/lnd_multi-hop_remote_force_close_on_chain_htlc_timeout_test.go @@ -178,8 +178,8 @@ func testMultiHopRemoteForceCloseOnChainHtlcTimeout(net *lntest.NetworkHarness, err = waitForNumChannelPendingForceClose(ctxt, bob, 0, nil) require.NoError(t.t, err) - // While we're here, we demonstrate some bugs in our handling of - // invoices that timeout on chain. + // While we're here, we assert that our expired invoice's state is + // correctly updated, and can no longer be settled. assertOnChainInvoiceState(ctxb, t, carol, preimage) // We'll close out the test by closing the channel from Alice to Bob, @@ -191,12 +191,8 @@ func testMultiHopRemoteForceCloseOnChainHtlcTimeout(net *lntest.NetworkHarness, ) } -// assertOnChainInvoiceState asserts that we have some bugs with how we handle -// hold invoices that are expired on-chain. -// - htlcs accepted: despite being timed out, our htlcs are still in accepted -// state -// - can settle: our invoice that has expired on-chain can still be settled -// even though we don't claim any htlcs. +// assertOnChainInvoiceState asserts that we have the correct state for a hold +// invoice that has expired on chain, and that it can't be settled. func assertOnChainInvoiceState(ctx context.Context, t *harnessTest, node *lntest.HarnessNode, preimage lntypes.Preimage) { @@ -207,22 +203,12 @@ func assertOnChainInvoiceState(ctx context.Context, t *harnessTest, require.NoError(t.t, err) for _, htlc := range inv.Htlcs { - require.Equal(t.t, lnrpc.InvoiceHTLCState_ACCEPTED, htlc.State) + require.Equal(t.t, lnrpc.InvoiceHTLCState_CANCELED, htlc.State) } + require.Equal(t.t, lnrpc.Invoice_CANCELED, inv.State) _, err = node.SettleInvoice(ctx, &invoicesrpc.SettleInvoiceMsg{ Preimage: preimage[:], }) - require.NoError(t.t, err, "expected erroneous invoice settle") - - inv, err = node.LookupInvoice(ctx, &lnrpc.PaymentHash{ - RHash: hash[:], - }) - require.NoError(t.t, err) - - require.True(t.t, inv.Settled, "expected erroneously settled invoice") // nolint:staticcheck - for _, htlc := range inv.Htlcs { - require.Equal(t.t, lnrpc.InvoiceHTLCState_SETTLED, htlc.State, - "expected htlcs to be erroneously settled") - } + require.Error(t.t, err, "should not be able to settle invoice") } diff --git a/lntest/itest/log_error_whitelist.txt b/lntest/itest/log_error_whitelist.txt index 665b3266..31610f3d 100644 --- a/lntest/itest/log_error_whitelist.txt +++ b/lntest/itest/log_error_whitelist.txt @@ -279,3 +279,5 @@