diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index b38c5cad..467d7a89 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -2,6 +2,7 @@ package channeldb import ( "crypto/rand" + mrand "math/rand" "reflect" "testing" "time" @@ -393,6 +394,114 @@ func TestInvoiceAddTimeSeries(t *testing.T) { } } +// Tests that FetchAllInvoicesWithPaymentHash returns all invoices with their +// corresponding payment hashes. +func TestFetchAllInvoicesWithPaymentHash(t *testing.T) { + t.Parallel() + + db, cleanup, err := makeTestDB() + defer cleanup() + if err != nil { + t.Fatalf("unable to make test db: %v", err) + } + + // With an empty DB we expect to return no error and an empty list. + empty, err := db.FetchAllInvoicesWithPaymentHash(false) + if err != nil { + t.Fatalf("failed to call FetchAllInvoicesWithPaymentHash on empty DB: %v", + err) + } + + if len(empty) != 0 { + t.Fatalf("expected empty list as a result, got: %v", empty) + } + + // Now populate the DB and check if we can get all invoices with their + // payment hashes as expected. + const numInvoices = 20 + testPendingInvoices := make(map[lntypes.Hash]*Invoice) + testAllInvoices := make(map[lntypes.Hash]*Invoice) + + states := []ContractState{ + ContractOpen, ContractSettled, ContractCanceled, ContractAccepted, + } + + for i := lnwire.MilliSatoshi(1); i <= numInvoices; i++ { + invoice, err := randInvoice(i) + if err != nil { + t.Fatalf("unable to create invoice: %v", err) + } + + invoice.State = states[mrand.Intn(len(states))] + paymentHash := invoice.Terms.PaymentPreimage.Hash() + + if invoice.State != ContractSettled && invoice.State != ContractCanceled { + testPendingInvoices[paymentHash] = invoice + } + + testAllInvoices[paymentHash] = invoice + + if _, err := db.AddInvoice(invoice, paymentHash); err != nil { + t.Fatalf("unable to add invoice: %v", err) + } + } + + pendingInvoices, err := db.FetchAllInvoicesWithPaymentHash(true) + if err != nil { + t.Fatalf("can't fetch invoices with payment hash: %v", err) + } + + if len(testPendingInvoices) != len(pendingInvoices) { + t.Fatalf("expected %v pending invoices, got: %v", + len(testPendingInvoices), len(pendingInvoices)) + } + + allInvoices, err := db.FetchAllInvoicesWithPaymentHash(false) + if err != nil { + t.Fatalf("can't fetch invoices with payment hash: %v", err) + } + + if len(testAllInvoices) != len(allInvoices) { + t.Fatalf("expected %v invoices, got: %v", + len(testAllInvoices), len(allInvoices)) + } + + for i := range pendingInvoices { + expected, ok := testPendingInvoices[pendingInvoices[i].PaymentHash] + if !ok { + t.Fatalf("coulnd't find invoice with hash: %v", + pendingInvoices[i].PaymentHash) + } + + // Zero out add index to not confuse DeepEqual. + pendingInvoices[i].Invoice.AddIndex = 0 + expected.AddIndex = 0 + + if !reflect.DeepEqual(*expected, pendingInvoices[i].Invoice) { + t.Fatalf("expected: %v, got: %v", + spew.Sdump(expected), spew.Sdump(pendingInvoices[i].Invoice)) + } + } + + for i := range allInvoices { + expected, ok := testAllInvoices[allInvoices[i].PaymentHash] + if !ok { + t.Fatalf("coulnd't find invoice with hash: %v", + allInvoices[i].PaymentHash) + } + + // Zero out add index to not confuse DeepEqual. + allInvoices[i].Invoice.AddIndex = 0 + expected.AddIndex = 0 + + if !reflect.DeepEqual(*expected, allInvoices[i].Invoice) { + t.Fatalf("expected: %v, got: %v", + spew.Sdump(expected), spew.Sdump(allInvoices[i].Invoice)) + } + } + +} + // TestDuplicateSettleInvoice tests that if we add a new invoice and settle it // twice, then the second time we also receive the invoice that we settled as a // return argument. diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 8ded522f..14e6c451 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -565,6 +565,83 @@ func (d *DB) LookupInvoice(paymentHash [32]byte) (Invoice, error) { return invoice, nil } +// InvoiceWithPaymentHash is used to store an invoice and its corresponding +// payment hash. This struct is only used to store results of +// ChannelDB.FetchAllInvoicesWithPaymentHash() call. +type InvoiceWithPaymentHash struct { + // Invoice holds the invoice as selected from the invoices bucket. + Invoice Invoice + + // PaymentHash is the payment hash for the Invoice. + PaymentHash lntypes.Hash +} + +// FetchAllInvoicesWithPaymentHash returns all invoices and their payment hashes +// currently stored within the database. If the pendingOnly param is true, then +// only unsettled invoices and their payment hashes will be returned, skipping +// all invoices that are fully settled or canceled. Note that the returned +// array is not ordered by add index. +func (d *DB) FetchAllInvoicesWithPaymentHash(pendingOnly bool) ( + []InvoiceWithPaymentHash, error) { + + var result []InvoiceWithPaymentHash + + err := d.View(func(tx *bbolt.Tx) error { + invoices := tx.Bucket(invoiceBucket) + if invoices == nil { + return ErrNoInvoicesCreated + } + + invoiceIndex := invoices.Bucket(invoiceIndexBucket) + if invoiceIndex == nil { + // Mask the error if there's no invoice + // index as that simply means there are no + // invoices added yet to the DB. In this case + // we simply return an empty list. + return nil + } + + return invoiceIndex.ForEach(func(k, v []byte) error { + // Skip the special numInvoicesKey as that does not + // point to a valid invoice. + if bytes.Equal(k, numInvoicesKey) { + return nil + } + + if v == nil { + return nil + } + + invoice, err := fetchInvoice(v, invoices) + if err != nil { + return err + } + + if pendingOnly && + (invoice.State == ContractSettled || + invoice.State == ContractCanceled) { + + return nil + } + + invoiceWithPaymentHash := InvoiceWithPaymentHash{ + Invoice: invoice, + } + + copy(invoiceWithPaymentHash.PaymentHash[:], k) + result = append(result, invoiceWithPaymentHash) + + return nil + }) + }) + + if err != nil { + return nil, err + } + + return result, nil +} + // FetchAllInvoices returns all invoices currently stored within the database. // If the pendingOnly param is true, then only unsettled invoices will be // returned, skipping all invoices that are fully settled. diff --git a/clock/default_clock.go b/clock/default_clock.go new file mode 100644 index 00000000..3a4f8df3 --- /dev/null +++ b/clock/default_clock.go @@ -0,0 +1,24 @@ +package clock + +import ( + "time" +) + +// DefaultClock implements Clock interface by simply calling the appropriate +// time functions. +type DefaultClock struct{} + +// NewDefaultClock constructs a new DefaultClock. +func NewDefaultClock() Clock { + return &DefaultClock{} +} + +// Now simply returns time.Now(). +func (DefaultClock) Now() time.Time { + return time.Now() +} + +// TickAfter simply wraps time.After(). +func (DefaultClock) TickAfter(duration time.Duration) <-chan time.Time { + return time.After(duration) +} diff --git a/clock/interface.go b/clock/interface.go new file mode 100644 index 00000000..0450410e --- /dev/null +++ b/clock/interface.go @@ -0,0 +1,16 @@ +package clock + +import ( + "time" +) + +// Clock is an interface that provides a time functions for LND packages. +// This is useful during testing when a concrete time reference is needed. +type Clock interface { + // Now returns the current local time (as defined by the Clock). + Now() time.Time + + // TickAfter returns a channel that will receive a tick after the specified + // duration has passed. + TickAfter(duration time.Duration) <-chan time.Time +} diff --git a/invoices/clock_test.go b/clock/test_clock.go similarity index 65% rename from invoices/clock_test.go rename to clock/test_clock.go index 41dd4991..f4319cee 100644 --- a/invoices/clock_test.go +++ b/clock/test_clock.go @@ -1,42 +1,40 @@ -package invoices +package clock import ( "sync" "time" ) -// testClock can be used in tests to mock time. -type testClock struct { +// TestClock can be used in tests to mock time. +type TestClock struct { currentTime time.Time timeChanMap map[time.Time][]chan time.Time timeLock sync.Mutex } -// newTestClock returns a new test clock. -func newTestClock(startTime time.Time) *testClock { - return &testClock{ +// NewTestClock returns a new test clock. +func NewTestClock(startTime time.Time) *TestClock { + return &TestClock{ currentTime: startTime, timeChanMap: make(map[time.Time][]chan time.Time), } } -// now returns the current (test) time. -func (c *testClock) now() time.Time { +// Now returns the current (test) time. +func (c *TestClock) Now() time.Time { c.timeLock.Lock() defer c.timeLock.Unlock() return c.currentTime } -// tickAfter returns a channel that will receive a tick at the specified time. -func (c *testClock) tickAfter(duration time.Duration) <-chan time.Time { +// TickAfter returns a channel that will receive a tick after the specified +// duration has passed passed by the user set test time. +func (c *TestClock) TickAfter(duration time.Duration) <-chan time.Time { c.timeLock.Lock() defer c.timeLock.Unlock() triggerTime := c.currentTime.Add(duration) - log.Debugf("tickAfter called: duration=%v, trigger_time=%v", - duration, triggerTime) - ch := make(chan time.Time, 1) // If already expired, tick immediately. @@ -53,8 +51,8 @@ func (c *testClock) tickAfter(duration time.Duration) <-chan time.Time { return ch } -// setTime sets the (test) time and triggers tick channels when they expire. -func (c *testClock) setTime(now time.Time) { +// SetTime sets the (test) time and triggers tick channels when they expire. +func (c *TestClock) SetTime(now time.Time) { c.timeLock.Lock() defer c.timeLock.Unlock() diff --git a/clock/test_clock_test.go b/clock/test_clock_test.go new file mode 100644 index 00000000..879cc8fd --- /dev/null +++ b/clock/test_clock_test.go @@ -0,0 +1,63 @@ +package clock + +import ( + "testing" + "time" +) + +var ( + testTime = time.Date(2009, time.January, 3, 12, 0, 0, 0, time.UTC) +) + +func TestNow(t *testing.T) { + c := NewTestClock(testTime) + now := c.Now() + + if now != testTime { + t.Fatalf("expected: %v, got: %v", testTime, now) + } + + now = now.Add(time.Hour) + c.SetTime(now) + if c.Now() != now { + t.Fatalf("epected: %v, got: %v", now, c.Now()) + } +} + +func TestTickAfter(t *testing.T) { + c := NewTestClock(testTime) + + // Should be ticking immediately. + ticker0 := c.TickAfter(0) + + // Both should be ticking after SetTime + ticker1 := c.TickAfter(time.Hour) + ticker2 := c.TickAfter(time.Hour) + + // We don't expect this one to tick. + ticker3 := c.TickAfter(2 * time.Hour) + + tickOrTimeOut := func(ticker <-chan time.Time, expectTick bool) { + tick := false + select { + case <-ticker: + tick = true + case <-time.After(time.Millisecond): + } + + if tick != expectTick { + t.Fatalf("expected tick: %v, ticked: %v", expectTick, tick) + } + } + + tickOrTimeOut(ticker0, true) + tickOrTimeOut(ticker1, false) + tickOrTimeOut(ticker2, false) + tickOrTimeOut(ticker3, false) + + c.SetTime(c.Now().Add(time.Hour)) + + tickOrTimeOut(ticker1, true) + tickOrTimeOut(ticker2, true) + tickOrTimeOut(ticker3, false) +} diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index d3d61ca9..1d0731e8 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -22,6 +22,7 @@ import ( sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/input" @@ -792,6 +793,7 @@ func newMockRegistry(minDelta uint32) *mockInvoiceRegistry { registry := invoices.NewRegistry( cdb, + invoices.NewInvoiceExpiryWatcher(clock.NewDefaultClock()), &invoices.RegistryConfig{ FinalCltvRejectDelta: 5, }, diff --git a/invoices/invoice_expiry_watcher.go b/invoices/invoice_expiry_watcher.go new file mode 100644 index 00000000..db9d9cd5 --- /dev/null +++ b/invoices/invoice_expiry_watcher.go @@ -0,0 +1,191 @@ +package invoices + +import ( + "fmt" + "sync" + "time" + + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/queue" + "github.com/lightningnetwork/lnd/zpay32" +) + +// invoiceExpiry holds and invoice's payment hash and its expiry. This +// is used to order invoices by their expiry for cancellation. +type invoiceExpiry struct { + PaymentHash lntypes.Hash + Expiry time.Time +} + +// Less implements PriorityQueueItem.Less such that the top item in the +// priorty queue will be the one that expires next. +func (e invoiceExpiry) Less(other queue.PriorityQueueItem) bool { + return e.Expiry.Before(other.(*invoiceExpiry).Expiry) +} + +// InvoiceExpiryWatcher handles automatic invoice cancellation of expried +// invoices. Upon start InvoiceExpiryWatcher will retrieve all pending (not yet +// settled or canceled) invoices invoices to its watcing queue. When a new +// invoice is added to the InvoiceRegistry, it'll be forarded to the +// InvoiceExpiryWatcher and will end up in the watching queue as well. +// If any of the watched invoices expire, they'll be removed from the watching +// queue and will be cancelled through InvoiceRegistry.CancelInvoice(). +type InvoiceExpiryWatcher struct { + sync.Mutex + started bool + + // clock is the clock implementation that InvoiceExpiryWatcher uses. + // It is useful for testing. + clock clock.Clock + + // cancelInvoice is a template method that cancels an expired invoice. + cancelInvoice func(lntypes.Hash) error + + // expiryQueue holds invoiceExpiry items and is used to find the next + // invoice to expire. + expiryQueue queue.PriorityQueue + + // newInvoices channel is used to wake up the main loop when a new invoices + // is added. + newInvoices chan *invoiceExpiry + + wg sync.WaitGroup + + // quit signals InvoiceExpiryWatcher to stop. + quit chan struct{} +} + +// NewInvoiceExpiryWatcher creates a new InvoiceExpiryWatcher instance. +func NewInvoiceExpiryWatcher(clock clock.Clock) *InvoiceExpiryWatcher { + return &InvoiceExpiryWatcher{ + clock: clock, + newInvoices: make(chan *invoiceExpiry), + quit: make(chan struct{}), + } +} + +// Start starts the the subscription handler and the main loop. Start() will +// return with error if InvoiceExpiryWatcher is already started. Start() +// expects a cancellation function passed that will be use to cancel expired +// invoices by their payment hash. +func (ew *InvoiceExpiryWatcher) Start( + cancelInvoice func(lntypes.Hash) error) error { + + ew.Lock() + defer ew.Unlock() + + if ew.started { + return fmt.Errorf("InvoiceExpiryWatcher already started") + } + + ew.started = true + ew.cancelInvoice = cancelInvoice + ew.wg.Add(1) + go ew.mainLoop() + + return nil +} + +// Stop quits the expiry handler loop and waits for InvoiceExpiryWatcher to +// fully stop. +func (ew *InvoiceExpiryWatcher) Stop() { + ew.Lock() + defer ew.Unlock() + + if ew.started { + // Signal subscriptionHandler to quit and wait for it to return. + close(ew.quit) + ew.wg.Wait() + ew.started = false + } +} + +// AddInvoice adds a new invoice to the InvoiceExpiryWatcher. This won't check +// if the invoice is already added and will only add invoices with ContractOpen +// state. +func (ew *InvoiceExpiryWatcher) AddInvoice( + paymentHash lntypes.Hash, invoice *channeldb.Invoice) { + + if invoice.State != channeldb.ContractOpen { + log.Debugf("Invoice not added to expiry watcher: %v", invoice) + return + } + + realExpiry := invoice.Terms.Expiry + if realExpiry == 0 { + realExpiry = zpay32.DefaultInvoiceExpiry + } + + expiry := invoice.CreationDate.Add(realExpiry) + + log.Debugf("Adding invoice '%v' to expiry watcher, expiration: %v", + paymentHash, expiry) + + newInvoiceExpiry := &invoiceExpiry{ + PaymentHash: paymentHash, + Expiry: expiry, + } + + select { + case ew.newInvoices <- newInvoiceExpiry: + case <-ew.quit: + // Select on quit too so that callers won't get blocked in case + // of concurrent shutdown. + } +} + +// nextExpiry returns a Time chan to wait on until the next invoice expires. +// If there are no active invoices, then it'll simply wait indefinitely. +func (ew *InvoiceExpiryWatcher) nextExpiry() <-chan time.Time { + if !ew.expiryQueue.Empty() { + top := ew.expiryQueue.Top().(*invoiceExpiry) + return ew.clock.TickAfter(top.Expiry.Sub(ew.clock.Now())) + } + + return nil +} + +// cancelExpiredInvoices will cancel all expired invoices and removes them from +// the expiry queue. +func (ew *InvoiceExpiryWatcher) cancelExpiredInvoices() { + for !ew.expiryQueue.Empty() { + top := ew.expiryQueue.Top().(*invoiceExpiry) + if !top.Expiry.Before(ew.clock.Now()) { + break + } + + err := ew.cancelInvoice(top.PaymentHash) + if err != nil && err != channeldb.ErrInvoiceAlreadySettled && + err != channeldb.ErrInvoiceAlreadyCanceled { + + log.Errorf("Unable to cancel invoice: %v", top.PaymentHash) + } + + ew.expiryQueue.Pop() + } +} + +// mainLoop is a goroutine that receives new invoices and handles cancellation +// of expired invoices. +func (ew *InvoiceExpiryWatcher) mainLoop() { + defer ew.wg.Done() + + for { + // Cancel any invoices that may have expired. + ew.cancelExpiredInvoices() + + select { + case <-ew.nextExpiry(): + // Wait until the next invoice expires, then cancel expired invoices. + continue + + case newInvoiceExpiry := <-ew.newInvoices: + ew.expiryQueue.Push(newInvoiceExpiry) + + case <-ew.quit: + return + } + } +} diff --git a/invoices/invoice_expiry_watcher_test.go b/invoices/invoice_expiry_watcher_test.go new file mode 100644 index 00000000..8bfdfd69 --- /dev/null +++ b/invoices/invoice_expiry_watcher_test.go @@ -0,0 +1,125 @@ +package invoices + +import ( + "testing" + "time" + + "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/lntypes" +) + +// invoiceExpiryWatcherTest holds a test fixture and implements checks +// for InvoiceExpiryWatcher tests. +type invoiceExpiryWatcherTest struct { + t *testing.T + watcher *InvoiceExpiryWatcher + testData invoiceExpiryTestData + canceledInvoices []lntypes.Hash +} + +// newInvoiceExpiryWatcherTest creates a new InvoiceExpiryWatcher test fixture +// and sets up the test environment. +func newInvoiceExpiryWatcherTest(t *testing.T, now time.Time, + numExpiredInvoices, numPendingInvoices int) *invoiceExpiryWatcherTest { + + test := &invoiceExpiryWatcherTest{ + watcher: NewInvoiceExpiryWatcher(clock.NewTestClock(testTime)), + testData: generateInvoiceExpiryTestData( + t, now, 0, numExpiredInvoices, numPendingInvoices, + ), + } + + err := test.watcher.Start(func(paymentHash lntypes.Hash) error { + test.canceledInvoices = append(test.canceledInvoices, paymentHash) + return nil + }) + + if err != nil { + t.Fatalf("cannot start InvoiceExpiryWatcher: %v", err) + } + + return test +} + +func (t *invoiceExpiryWatcherTest) checkExpectations() { + // Check that invoices that got canceled during the test are the ones + // that expired. + if len(t.canceledInvoices) != len(t.testData.expiredInvoices) { + t.t.Fatalf("expected %v cancellations, got %v", + len(t.testData.expiredInvoices), len(t.canceledInvoices)) + } + + for i := range t.canceledInvoices { + if _, ok := t.testData.expiredInvoices[t.canceledInvoices[i]]; !ok { + t.t.Fatalf("wrong invoice canceled") + } + } +} + +// Tests that InvoiceExpiryWatcher can be started and stopped. +func TestInvoiceExpiryWatcherStartStop(t *testing.T) { + watcher := NewInvoiceExpiryWatcher(clock.NewTestClock(testTime)) + cancel := func(lntypes.Hash) error { + t.Fatalf("unexpected call") + return nil + } + + if err := watcher.Start(cancel); err != nil { + t.Fatalf("unexpected error upon start: %v", err) + } + + if err := watcher.Start(cancel); err == nil { + t.Fatalf("expected error upon second start") + } + + watcher.Stop() + + if err := watcher.Start(cancel); err != nil { + t.Fatalf("unexpected error upon start: %v", err) + } +} + +// Tests that no invoices will expire from an empty InvoiceExpiryWatcher. +func TestInvoiceExpiryWithNoInvoices(t *testing.T) { + t.Parallel() + test := newInvoiceExpiryWatcherTest(t, testTime, 0, 0) + + time.Sleep(testTimeout) + test.watcher.Stop() + test.checkExpectations() +} + +// Tests that if all add invoices are expired, then all invoices +// will be canceled. +func TestInvoiceExpiryWithOnlyExpiredInvoices(t *testing.T) { + t.Parallel() + + test := newInvoiceExpiryWatcherTest(t, testTime, 0, 5) + + for paymentHash, invoice := range test.testData.pendingInvoices { + test.watcher.AddInvoice(paymentHash, invoice) + } + + time.Sleep(testTimeout) + test.watcher.Stop() + test.checkExpectations() +} + +// Tests that if some invoices are expired, then those invoices +// will be canceled. +func TestInvoiceExpiryWithPendingAndExpiredInvoices(t *testing.T) { + t.Parallel() + test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5) + + for paymentHash, invoice := range test.testData.expiredInvoices { + test.watcher.AddInvoice(paymentHash, invoice) + } + + for paymentHash, invoice := range test.testData.pendingInvoices { + test.watcher.AddInvoice(paymentHash, invoice) + } + + time.Sleep(testTimeout) + test.watcher.Stop() + test.checkExpectations() +} diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index bfdc54a6..d4e77ad5 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -9,6 +9,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/queue" @@ -62,12 +63,10 @@ type RegistryConfig struct { // waiting for the other set members to arrive. HtlcHoldDuration time.Duration - // Now returns the current time. - Now func() time.Time - - // TickAfter returns a channel that is sent on after the specified - // duration as passed. - TickAfter func(duration time.Duration) <-chan time.Time + // Clock holds the clock implementation that is used to provide + // Now() and TickAfter() and is useful to stub out the clock functions + // during testing. + Clock clock.Clock } // htlcReleaseEvent describes an htlc auto-release event. It is used to release @@ -126,6 +125,8 @@ type InvoiceRegistry struct { // auto-released. htlcAutoReleaseChan chan *htlcReleaseEvent + expiryWatcher *InvoiceExpiryWatcher + wg sync.WaitGroup quit chan struct{} } @@ -134,7 +135,9 @@ type InvoiceRegistry struct { // wraps the persistent on-disk invoice storage with an additional in-memory // layer. The in-memory layer is in place such that debug invoices can be added // which are volatile yet available system wide within the daemon. -func NewRegistry(cdb *channeldb.DB, cfg *RegistryConfig) *InvoiceRegistry { +func NewRegistry(cdb *channeldb.DB, expiryWatcher *InvoiceExpiryWatcher, + cfg *RegistryConfig) *InvoiceRegistry { + return &InvoiceRegistry{ cdb: cdb, notificationClients: make(map[uint32]*InvoiceSubscription), @@ -146,21 +149,62 @@ func NewRegistry(cdb *channeldb.DB, cfg *RegistryConfig) *InvoiceRegistry { hodlReverseSubscriptions: make(map[chan<- interface{}]map[channeldb.CircuitKey]struct{}), cfg: cfg, htlcAutoReleaseChan: make(chan *htlcReleaseEvent), + expiryWatcher: expiryWatcher, quit: make(chan struct{}), } } +// populateExpiryWatcher fetches all active invoices and their corresponding +// payment hashes from ChannelDB and adds them to the expiry watcher. +func (i *InvoiceRegistry) populateExpiryWatcher() error { + pendingOnly := true + pendingInvoices, err := i.cdb.FetchAllInvoicesWithPaymentHash(pendingOnly) + if err != nil && err != channeldb.ErrNoInvoicesCreated { + log.Errorf( + "Error while prefetching active invoices from the database: %v", err, + ) + return err + } + + for idx := range pendingInvoices { + i.expiryWatcher.AddInvoice( + pendingInvoices[idx].PaymentHash, &pendingInvoices[idx].Invoice, + ) + } + + return nil +} + // Start starts the registry and all goroutines it needs to carry out its task. func (i *InvoiceRegistry) Start() error { - i.wg.Add(1) + // Start InvoiceExpiryWatcher and prepopulate it with existing active + // invoices. + err := i.expiryWatcher.Start(func(paymentHash lntypes.Hash) error { + cancelIfAccepted := false + return i.cancelInvoiceImpl(paymentHash, cancelIfAccepted) + }) + if err != nil { + return err + } + + i.wg.Add(1) go i.invoiceEventLoop() + // Now prefetch all pending invoices to the expiry watcher. + err = i.populateExpiryWatcher() + if err != nil { + i.Stop() + return err + } + return nil } // Stop signals the registry for a graceful shutdown. func (i *InvoiceRegistry) Stop() { + i.expiryWatcher.Stop() + close(i.quit) i.wg.Wait() @@ -177,8 +221,8 @@ type invoiceEvent struct { // tickAt returns a channel that ticks at the specified time. If the time has // already passed, it will tick immediately. func (i *InvoiceRegistry) tickAt(t time.Time) <-chan time.Time { - now := i.cfg.Now() - return i.cfg.TickAfter(t.Sub(now)) + now := i.cfg.Clock.Now() + return i.cfg.Clock.TickAfter(t.Sub(now)) } // invoiceEventLoop is the dedicated goroutine responsible for accepting @@ -471,7 +515,6 @@ func (i *InvoiceRegistry) AddInvoice(invoice *channeldb.Invoice, paymentHash lntypes.Hash) (uint64, error) { i.Lock() - defer i.Unlock() log.Debugf("Invoice(%v): added %v", paymentHash, newLogClosure(func() string { @@ -481,12 +524,19 @@ func (i *InvoiceRegistry) AddInvoice(invoice *channeldb.Invoice, addIndex, err := i.cdb.AddInvoice(invoice, paymentHash) if err != nil { + i.Unlock() return 0, err } // Now that we've added the invoice, we'll send dispatch a message to // notify the clients of this new invoice. i.notifyClients(paymentHash, invoice, channeldb.ContractOpen) + i.Unlock() + + // InvoiceExpiryWatcher.AddInvoice must not be locked by InvoiceRegistry + // to avoid deadlock when a new invoice is added while an other is being + // canceled. + i.expiryWatcher.AddInvoice(paymentHash, invoice) return addIndex, nil } @@ -818,6 +868,15 @@ func (i *InvoiceRegistry) SettleHodlInvoice(preimage lntypes.Preimage) error { // CancelInvoice attempts to cancel the invoice corresponding to the passed // payment hash. func (i *InvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error { + return i.cancelInvoiceImpl(payHash, true) +} + +// cancelInvoice attempts to cancel the invoice corresponding to the passed +// payment hash. Accepted invoices will only be canceled if explicitly +// requested to do so. +func (i *InvoiceRegistry) cancelInvoiceImpl(payHash lntypes.Hash, + cancelAccepted bool) error { + i.Lock() defer i.Unlock() @@ -826,6 +885,12 @@ func (i *InvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error { updateInvoice := func(invoice *channeldb.Invoice) ( *channeldb.InvoiceUpdateDesc, error) { + // Only cancel the invoice in ContractAccepted state if explicitly + // requested to do so. + if invoice.State == channeldb.ContractAccepted && !cancelAccepted { + return nil, nil + } + // Move invoice to the canceled state. Rely on validation in // channeldb to return an error if the invoice is already // settled or canceled. @@ -848,6 +913,13 @@ func (i *InvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error { return err } + // Return without cancellation if the invoice state is ContractAccepted. + if invoice.State == channeldb.ContractAccepted { + log.Debugf("Invoice(%v): remains accepted as cancel wasn't"+ + "explicitly requested.", payHash) + return nil + } + log.Debugf("Invoice(%v): canceled", payHash) // In the callback, some htlcs may have been moved to the canceled diff --git a/invoices/invoiceregistry_test.go b/invoices/invoiceregistry_test.go index eabb2ba2..e47eaa37 100644 --- a/invoices/invoiceregistry_test.go +++ b/invoices/invoiceregistry_test.go @@ -1,117 +1,16 @@ package invoices import ( - "io/ioutil" - "os" "testing" "time" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" ) -var ( - testTimeout = 5 * time.Second - - testTime = time.Date(2018, time.February, 2, 14, 0, 0, 0, time.UTC) - - preimage = lntypes.Preimage{ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, - } - - hash = preimage.Hash() - - testHtlcExpiry = uint32(5) - - testInvoiceCltvDelta = uint32(4) - - testFinalCltvRejectDelta = int32(4) - - testCurrentHeight = int32(1) - - testFeatures = lnwire.NewFeatureVector( - nil, lnwire.Features, - ) - - testPayload = &mockPayload{} -) - -var ( - testInvoiceAmt = lnwire.MilliSatoshi(100000) - testInvoice = &channeldb.Invoice{ - Terms: channeldb.ContractTerm{ - PaymentPreimage: preimage, - Value: lnwire.MilliSatoshi(100000), - Features: testFeatures, - }, - } - - testHodlInvoice = &channeldb.Invoice{ - Terms: channeldb.ContractTerm{ - PaymentPreimage: channeldb.UnknownPreimage, - Value: testInvoiceAmt, - Features: testFeatures, - }, - } -) - -type testContext struct { - registry *InvoiceRegistry - clock *testClock - - cleanup func() - t *testing.T -} - -func newTestContext(t *testing.T) *testContext { - clock := newTestClock(testTime) - - cdb, cleanup, err := newDB() - if err != nil { - t.Fatal(err) - } - cdb.Now = clock.now - - // Instantiate and start the invoice ctx.registry. - cfg := RegistryConfig{ - FinalCltvRejectDelta: testFinalCltvRejectDelta, - HtlcHoldDuration: 30 * time.Second, - Now: clock.now, - TickAfter: clock.tickAfter, - } - registry := NewRegistry(cdb, &cfg) - - err = registry.Start() - if err != nil { - cleanup() - t.Fatal(err) - } - - ctx := testContext{ - registry: registry, - clock: clock, - t: t, - cleanup: func() { - registry.Stop() - cleanup() - }, - } - - return &ctx -} - -func getCircuitKey(htlcID uint64) channeldb.CircuitKey { - return channeldb.CircuitKey{ - ChanID: lnwire.ShortChannelID{ - BlockHeight: 1, TxIndex: 2, TxPosition: 3, - }, - HtlcID: htlcID, - } -} - // TestSettleInvoice tests settling of an invoice and related notifications. func TestSettleInvoice(t *testing.T) { ctx := newTestContext(t) @@ -121,18 +20,18 @@ func TestSettleInvoice(t *testing.T) { defer allSubscriptions.Cancel() // Subscribe to the not yet existing invoice. - subscription, err := ctx.registry.SubscribeSingleInvoice(hash) + subscription, err := ctx.registry.SubscribeSingleInvoice(testInvoicePaymentHash) if err != nil { t.Fatal(err) } defer subscription.Cancel() - if subscription.hash != hash { + if subscription.hash != testInvoicePaymentHash { t.Fatalf("expected subscription for provided hash") } // Add the invoice. - addIdx, err := ctx.registry.AddInvoice(testInvoice, hash) + addIdx, err := ctx.registry.AddInvoice(testInvoice, testInvoicePaymentHash) if err != nil { t.Fatal(err) } @@ -168,7 +67,7 @@ func TestSettleInvoice(t *testing.T) { // Try to settle invoice with an htlc that expires too soon. event, err := ctx.registry.NotifyExitHopHtlc( - hash, testInvoice.Terms.Value, + testInvoicePaymentHash, testInvoice.Terms.Value, uint32(testCurrentHeight)+testInvoiceCltvDelta-1, testCurrentHeight, getCircuitKey(10), hodlChan, testPayload, ) @@ -186,7 +85,7 @@ func TestSettleInvoice(t *testing.T) { // Settle invoice with a slightly higher amount. amtPaid := lnwire.MilliSatoshi(100500) _, err = ctx.registry.NotifyExitHopHtlc( - hash, amtPaid, testHtlcExpiry, testCurrentHeight, + testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight, getCircuitKey(0), hodlChan, testPayload, ) if err != nil { @@ -222,7 +121,7 @@ func TestSettleInvoice(t *testing.T) { // Try to settle again with the same htlc id. We need this idempotent // behaviour after a restart. event, err = ctx.registry.NotifyExitHopHtlc( - hash, amtPaid, testHtlcExpiry, testCurrentHeight, + testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight, getCircuitKey(0), hodlChan, testPayload, ) if err != nil { @@ -236,7 +135,7 @@ func TestSettleInvoice(t *testing.T) { // should also be accepted, to prevent any change in behaviour for a // paid invoice that may open up a probe vector. event, err = ctx.registry.NotifyExitHopHtlc( - hash, amtPaid+600, testHtlcExpiry, testCurrentHeight, + testInvoicePaymentHash, amtPaid+600, testHtlcExpiry, testCurrentHeight, getCircuitKey(1), hodlChan, testPayload, ) if err != nil { @@ -249,7 +148,7 @@ func TestSettleInvoice(t *testing.T) { // Try to settle again with a lower amount. This should fail just as it // would have failed if it were the first payment. event, err = ctx.registry.NotifyExitHopHtlc( - hash, amtPaid-600, testHtlcExpiry, testCurrentHeight, + testInvoicePaymentHash, amtPaid-600, testHtlcExpiry, testCurrentHeight, getCircuitKey(2), hodlChan, testPayload, ) if err != nil { @@ -261,7 +160,7 @@ func TestSettleInvoice(t *testing.T) { // Check that settled amount is equal to the sum of values of the htlcs // 0 and 1. - inv, err := ctx.registry.LookupInvoice(hash) + inv, err := ctx.registry.LookupInvoice(testInvoicePaymentHash) if err != nil { t.Fatal(err) } @@ -270,7 +169,7 @@ func TestSettleInvoice(t *testing.T) { } // Try to cancel. - err = ctx.registry.CancelInvoice(hash) + err = ctx.registry.CancelInvoice(testInvoicePaymentHash) if err != channeldb.ErrInvoiceAlreadySettled { t.Fatal("expected cancelation of a settled invoice to fail") } @@ -292,25 +191,25 @@ func TestCancelInvoice(t *testing.T) { defer allSubscriptions.Cancel() // Try to cancel the not yet existing invoice. This should fail. - err := ctx.registry.CancelInvoice(hash) + err := ctx.registry.CancelInvoice(testInvoicePaymentHash) if err != channeldb.ErrInvoiceNotFound { t.Fatalf("expected ErrInvoiceNotFound, but got %v", err) } // Subscribe to the not yet existing invoice. - subscription, err := ctx.registry.SubscribeSingleInvoice(hash) + subscription, err := ctx.registry.SubscribeSingleInvoice(testInvoicePaymentHash) if err != nil { t.Fatal(err) } defer subscription.Cancel() - if subscription.hash != hash { + if subscription.hash != testInvoicePaymentHash { t.Fatalf("expected subscription for provided hash") } // Add the invoice. amt := lnwire.MilliSatoshi(100000) - _, err = ctx.registry.AddInvoice(testInvoice, hash) + _, err = ctx.registry.AddInvoice(testInvoice, testInvoicePaymentHash) if err != nil { t.Fatal(err) } @@ -342,7 +241,7 @@ func TestCancelInvoice(t *testing.T) { } // Cancel invoice. - err = ctx.registry.CancelInvoice(hash) + err = ctx.registry.CancelInvoice(testInvoicePaymentHash) if err != nil { t.Fatal(err) } @@ -365,7 +264,7 @@ func TestCancelInvoice(t *testing.T) { // subscribers (backwards compatibility). // Try to cancel again. - err = ctx.registry.CancelInvoice(hash) + err = ctx.registry.CancelInvoice(testInvoicePaymentHash) if err != nil { t.Fatal("expected cancelation of a canceled invoice to succeed") } @@ -374,7 +273,7 @@ func TestCancelInvoice(t *testing.T) { // result in a cancel event. hodlChan := make(chan interface{}) event, err := ctx.registry.NotifyExitHopHtlc( - hash, amt, testHtlcExpiry, testCurrentHeight, + testInvoicePaymentHash, amt, testHtlcExpiry, testCurrentHeight, getCircuitKey(0), hodlChan, testPayload, ) if err != nil { @@ -393,9 +292,9 @@ func TestCancelInvoice(t *testing.T) { // TestSettleHoldInvoice tests settling of a hold invoice and related // notifications. func TestSettleHoldInvoice(t *testing.T) { - defer timeout(t)() + defer timeout()() - cdb, cleanup, err := newDB() + cdb, cleanup, err := newTestChannelDB() if err != nil { t.Fatal(err) } @@ -404,8 +303,9 @@ func TestSettleHoldInvoice(t *testing.T) { // Instantiate and start the invoice ctx.registry. cfg := RegistryConfig{ FinalCltvRejectDelta: testFinalCltvRejectDelta, + Clock: clock.NewTestClock(testTime), } - registry := NewRegistry(cdb, &cfg) + registry := NewRegistry(cdb, NewInvoiceExpiryWatcher(cfg.Clock), &cfg) err = registry.Start() if err != nil { @@ -417,18 +317,18 @@ func TestSettleHoldInvoice(t *testing.T) { defer allSubscriptions.Cancel() // Subscribe to the not yet existing invoice. - subscription, err := registry.SubscribeSingleInvoice(hash) + subscription, err := registry.SubscribeSingleInvoice(testInvoicePaymentHash) if err != nil { t.Fatal(err) } defer subscription.Cancel() - if subscription.hash != hash { + if subscription.hash != testInvoicePaymentHash { t.Fatalf("expected subscription for provided hash") } // Add the invoice. - _, err = registry.AddInvoice(testHodlInvoice, hash) + _, err = registry.AddInvoice(testHodlInvoice, testInvoicePaymentHash) if err != nil { t.Fatal(err) } @@ -455,7 +355,7 @@ func TestSettleHoldInvoice(t *testing.T) { // NotifyExitHopHtlc without a preimage present in the invoice registry // should be possible. event, err := registry.NotifyExitHopHtlc( - hash, amtPaid, testHtlcExpiry, testCurrentHeight, + testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight, getCircuitKey(0), hodlChan, testPayload, ) if err != nil { @@ -467,7 +367,7 @@ func TestSettleHoldInvoice(t *testing.T) { // Test idempotency. event, err = registry.NotifyExitHopHtlc( - hash, amtPaid, testHtlcExpiry, testCurrentHeight, + testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight, getCircuitKey(0), hodlChan, testPayload, ) if err != nil { @@ -480,7 +380,7 @@ func TestSettleHoldInvoice(t *testing.T) { // Test replay at a higher height. We expect the same result because it // is a replay. event, err = registry.NotifyExitHopHtlc( - hash, amtPaid, testHtlcExpiry, testCurrentHeight+10, + testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight+10, getCircuitKey(0), hodlChan, testPayload, ) if err != nil { @@ -493,7 +393,7 @@ func TestSettleHoldInvoice(t *testing.T) { // Test a new htlc coming in that doesn't meet the final cltv delta // requirement. It should be rejected. event, err = registry.NotifyExitHopHtlc( - hash, amtPaid, 1, testCurrentHeight, + testInvoicePaymentHash, amtPaid, 1, testCurrentHeight, getCircuitKey(1), hodlChan, testPayload, ) if err != nil { @@ -516,13 +416,13 @@ func TestSettleHoldInvoice(t *testing.T) { } // Settling with preimage should succeed. - err = registry.SettleHodlInvoice(preimage) + err = registry.SettleHodlInvoice(testInvoicePreimage) if err != nil { t.Fatal("expected set preimage to succeed") } hodlEvent := (<-hodlChan).(HodlEvent) - if *hodlEvent.Preimage != preimage { + if *hodlEvent.Preimage != testInvoicePreimage { t.Fatal("unexpected preimage in hodl event") } if hodlEvent.AcceptHeight != testCurrentHeight { @@ -549,13 +449,13 @@ func TestSettleHoldInvoice(t *testing.T) { } // Idempotency. - err = registry.SettleHodlInvoice(preimage) + err = registry.SettleHodlInvoice(testInvoicePreimage) if err != channeldb.ErrInvoiceAlreadySettled { t.Fatalf("expected ErrInvoiceAlreadySettled but got %v", err) } // Try to cancel. - err = registry.CancelInvoice(hash) + err = registry.CancelInvoice(testInvoicePaymentHash) if err == nil { t.Fatal("expected cancelation of a settled invoice to fail") } @@ -564,9 +464,9 @@ func TestSettleHoldInvoice(t *testing.T) { // TestCancelHoldInvoice tests canceling of a hold invoice and related // notifications. func TestCancelHoldInvoice(t *testing.T) { - defer timeout(t)() + defer timeout()() - cdb, cleanup, err := newDB() + cdb, cleanup, err := newTestChannelDB() if err != nil { t.Fatal(err) } @@ -575,8 +475,9 @@ func TestCancelHoldInvoice(t *testing.T) { // Instantiate and start the invoice ctx.registry. cfg := RegistryConfig{ FinalCltvRejectDelta: testFinalCltvRejectDelta, + Clock: clock.NewTestClock(testTime), } - registry := NewRegistry(cdb, &cfg) + registry := NewRegistry(cdb, NewInvoiceExpiryWatcher(cfg.Clock), &cfg) err = registry.Start() if err != nil { @@ -585,7 +486,7 @@ func TestCancelHoldInvoice(t *testing.T) { defer registry.Stop() // Add the invoice. - _, err = registry.AddInvoice(testHodlInvoice, hash) + _, err = registry.AddInvoice(testHodlInvoice, testInvoicePaymentHash) if err != nil { t.Fatal(err) } @@ -596,7 +497,7 @@ func TestCancelHoldInvoice(t *testing.T) { // NotifyExitHopHtlc without a preimage present in the invoice registry // should be possible. event, err := registry.NotifyExitHopHtlc( - hash, amtPaid, testHtlcExpiry, testCurrentHeight, + testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight, getCircuitKey(0), hodlChan, testPayload, ) if err != nil { @@ -607,7 +508,7 @@ func TestCancelHoldInvoice(t *testing.T) { } // Cancel invoice. - err = registry.CancelInvoice(hash) + err = registry.CancelInvoice(testInvoicePaymentHash) if err != nil { t.Fatal("cancel invoice failed") } @@ -621,7 +522,7 @@ func TestCancelHoldInvoice(t *testing.T) { // in a rejection. The accept height is expected to be the original // accept height. event, err = registry.NotifyExitHopHtlc( - hash, amtPaid, testHtlcExpiry, testCurrentHeight+1, + testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight+1, getCircuitKey(0), hodlChan, testPayload, ) if err != nil { @@ -636,29 +537,6 @@ func TestCancelHoldInvoice(t *testing.T) { } } -func newDB() (*channeldb.DB, func(), error) { - // First, create a temporary directory to be used for the duration of - // this test. - tempDirName, err := ioutil.TempDir("", "channeldb") - if err != nil { - return nil, nil, err - } - - // Next, create channeldb for the first time. - cdb, err := channeldb.Open(tempDirName) - if err != nil { - os.RemoveAll(tempDirName) - return nil, nil, err - } - - cleanUp := func() { - cdb.Close() - os.RemoveAll(tempDirName) - } - - return cdb, cleanUp, nil -} - // TestUnknownInvoice tests that invoice registry returns an error when the // invoice is unknown. This is to guard against returning a cancel hodl event // for forwarded htlcs. In the link, NotifyExitHopHtlc is only called if we are @@ -673,7 +551,7 @@ func TestUnknownInvoice(t *testing.T) { hodlChan := make(chan interface{}) amt := lnwire.MilliSatoshi(100000) _, err := ctx.registry.NotifyExitHopHtlc( - hash, amt, testHtlcExpiry, testCurrentHeight, + testInvoicePaymentHash, amt, testHtlcExpiry, testCurrentHeight, getCircuitKey(0), hodlChan, testPayload, ) if err != channeldb.ErrInvoiceNotFound { @@ -681,27 +559,15 @@ func TestUnknownInvoice(t *testing.T) { } } -type mockPayload struct { - mpp *record.MPP -} - -func (p *mockPayload) MultiPath() *record.MPP { - return p.mpp -} - -func (p *mockPayload) CustomRecords() record.CustomSet { - return make(record.CustomSet) -} - // TestSettleMpp tests settling of an invoice with multiple partial payments. func TestSettleMpp(t *testing.T) { - defer timeout(t)() + defer timeout()() ctx := newTestContext(t) defer ctx.cleanup() // Add the invoice. - _, err := ctx.registry.AddInvoice(testInvoice, hash) + _, err := ctx.registry.AddInvoice(testInvoice, testInvoicePaymentHash) if err != nil { t.Fatal(err) } @@ -713,7 +579,7 @@ func TestSettleMpp(t *testing.T) { // Send htlc 1. hodlChan1 := make(chan interface{}, 1) event, err := ctx.registry.NotifyExitHopHtlc( - hash, testInvoice.Terms.Value/2, + testInvoicePaymentHash, testInvoice.Terms.Value/2, testHtlcExpiry, testCurrentHeight, getCircuitKey(10), hodlChan1, mppPayload, ) @@ -725,7 +591,7 @@ func TestSettleMpp(t *testing.T) { } // Simulate mpp timeout releasing htlc 1. - ctx.clock.setTime(testTime.Add(30 * time.Second)) + ctx.clock.SetTime(testTime.Add(30 * time.Second)) hodlEvent := (<-hodlChan1).(HodlEvent) if hodlEvent.Preimage != nil { @@ -735,7 +601,7 @@ func TestSettleMpp(t *testing.T) { // Send htlc 2. hodlChan2 := make(chan interface{}, 1) event, err = ctx.registry.NotifyExitHopHtlc( - hash, testInvoice.Terms.Value/2, + testInvoicePaymentHash, testInvoice.Terms.Value/2, testHtlcExpiry, testCurrentHeight, getCircuitKey(11), hodlChan2, mppPayload, ) @@ -749,7 +615,7 @@ func TestSettleMpp(t *testing.T) { // Send htlc 3. hodlChan3 := make(chan interface{}, 1) event, err = ctx.registry.NotifyExitHopHtlc( - hash, testInvoice.Terms.Value/2, + testInvoicePaymentHash, testInvoice.Terms.Value/2, testHtlcExpiry, testCurrentHeight, getCircuitKey(12), hodlChan3, mppPayload, ) @@ -762,7 +628,7 @@ func TestSettleMpp(t *testing.T) { // Check that settled amount is equal to the sum of values of the htlcs // 0 and 1. - inv, err := ctx.registry.LookupInvoice(hash) + inv, err := ctx.registry.LookupInvoice(testInvoicePaymentHash) if err != nil { t.Fatal(err) } @@ -774,3 +640,105 @@ func TestSettleMpp(t *testing.T) { testInvoice.Terms.Value, inv.AmtPaid) } } + +// Tests that invoices are canceled after expiration. +func TestInvoiceExpiryWithRegistry(t *testing.T) { + t.Parallel() + + cdb, cleanup, err := newTestChannelDB() + defer cleanup() + + if err != nil { + t.Fatal(err) + } + + testClock := clock.NewTestClock(testTime) + + cfg := RegistryConfig{ + FinalCltvRejectDelta: testFinalCltvRejectDelta, + Clock: testClock, + } + + expiryWatcher := NewInvoiceExpiryWatcher(cfg.Clock) + registry := NewRegistry(cdb, expiryWatcher, &cfg) + + // First prefill the Channel DB with some pre-existing invoices, + // half of them still pending, half of them expired. + const numExpired = 5 + const numPending = 5 + existingInvoices := generateInvoiceExpiryTestData( + t, testTime, 0, numExpired, numPending, + ) + + var expectedCancellations []lntypes.Hash + + for paymentHash, expiredInvoice := range existingInvoices.expiredInvoices { + if _, err := cdb.AddInvoice(expiredInvoice, paymentHash); err != nil { + t.Fatalf("cannot add invoice to channel db: %v", err) + } + expectedCancellations = append(expectedCancellations, paymentHash) + } + + for paymentHash, pendingInvoice := range existingInvoices.pendingInvoices { + if _, err := cdb.AddInvoice(pendingInvoice, paymentHash); err != nil { + t.Fatalf("cannot add invoice to channel db: %v", err) + } + } + + if err = registry.Start(); err != nil { + t.Fatalf("cannot start registry: %v", err) + } + + // Now generate pending and invoices and add them to the registry while + // it is up and running. We'll manipulate the clock to let them expire. + newInvoices := generateInvoiceExpiryTestData( + t, testTime, numExpired+numPending, 0, numPending, + ) + + var invoicesThatWillCancel []lntypes.Hash + for paymentHash, pendingInvoice := range newInvoices.pendingInvoices { + _, err := registry.AddInvoice(pendingInvoice, paymentHash) + invoicesThatWillCancel = append(invoicesThatWillCancel, paymentHash) + if err != nil { + t.Fatal(err) + } + } + + // Check that they are really not canceled until before the clock is + // advanced. + for i := range invoicesThatWillCancel { + invoice, err := registry.LookupInvoice(invoicesThatWillCancel[i]) + if err != nil { + t.Fatalf("cannot find invoice: %v", err) + } + + if invoice.State == channeldb.ContractCanceled { + t.Fatalf("expected pending invoice, got canceled") + } + } + + // Fwd time 1 day. + testClock.SetTime(testTime.Add(24 * time.Hour)) + + // Give some time to the watcher to cancel everything. + time.Sleep(testTimeout) + registry.Stop() + + // Create the expected cancellation set before the final check. + expectedCancellations = append( + expectedCancellations, invoicesThatWillCancel..., + ) + + // Retrospectively check that all invoices that were expected to be canceled + // are indeed canceled. + for i := range expectedCancellations { + invoice, err := registry.LookupInvoice(expectedCancellations[i]) + if err != nil { + t.Fatalf("cannot find invoice: %v", err) + } + + if invoice.State != channeldb.ContractCanceled { + t.Fatalf("expected canceled invoice, got: %v", invoice.State) + } + } +} diff --git a/invoices/test_utils_test.go b/invoices/test_utils_test.go new file mode 100644 index 00000000..d6298543 --- /dev/null +++ b/invoices/test_utils_test.go @@ -0,0 +1,280 @@ +package invoices + +import ( + "encoding/binary" + "encoding/hex" + "fmt" + "io/ioutil" + "os" + "runtime/pprof" + "testing" + "time" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" + "github.com/lightningnetwork/lnd/zpay32" +) + +type mockPayload struct { + mpp *record.MPP +} + +func (p *mockPayload) MultiPath() *record.MPP { + return p.mpp +} + +func (p *mockPayload) CustomRecords() record.CustomSet { + return make(record.CustomSet) +} + +var ( + testTimeout = 5 * time.Second + + testTime = time.Date(2018, time.February, 2, 14, 0, 0, 0, time.UTC) + + testInvoicePreimage = lntypes.Preimage{1} + + testInvoicePaymentHash = testInvoicePreimage.Hash() + + testHtlcExpiry = uint32(5) + + testInvoiceCltvDelta = uint32(4) + + testFinalCltvRejectDelta = int32(4) + + testCurrentHeight = int32(1) + + testPrivKeyBytes, _ = hex.DecodeString( + "e126f68f7eafcc8b74f54d269fe206be715000f94dac067d1c04a8ca3b2db734") + + testPrivKey, _ = btcec.PrivKeyFromBytes( + btcec.S256(), testPrivKeyBytes) + + testInvoiceDescription = "coffee" + + testInvoiceAmount = lnwire.MilliSatoshi(100000) + + testNetParams = &chaincfg.MainNetParams + + testMessageSigner = zpay32.MessageSigner{ + SignCompact: func(hash []byte) ([]byte, error) { + sig, err := btcec.SignCompact(btcec.S256(), testPrivKey, hash, true) + if err != nil { + return nil, fmt.Errorf("can't sign the message: %v", err) + } + return sig, nil + }, + } + + testFeatures = lnwire.NewFeatureVector( + nil, lnwire.Features, + ) + + testPayload = &mockPayload{} + + testInvoiceCreationDate = testTime +) + +var ( + testInvoiceAmt = lnwire.MilliSatoshi(100000) + testInvoice = &channeldb.Invoice{ + Terms: channeldb.ContractTerm{ + PaymentPreimage: testInvoicePreimage, + Value: testInvoiceAmt, + Expiry: time.Hour, + Features: testFeatures, + }, + CreationDate: testInvoiceCreationDate, + } + + testHodlInvoice = &channeldb.Invoice{ + Terms: channeldb.ContractTerm{ + PaymentPreimage: channeldb.UnknownPreimage, + Value: testInvoiceAmt, + Expiry: time.Hour, + Features: testFeatures, + }, + CreationDate: testInvoiceCreationDate, + } +) + +func newTestChannelDB() (*channeldb.DB, func(), error) { + // First, create a temporary directory to be used for the duration of + // this test. + tempDirName, err := ioutil.TempDir("", "channeldb") + if err != nil { + return nil, nil, err + } + + // Next, create channeldb for the first time. + cdb, err := channeldb.Open(tempDirName) + if err != nil { + os.RemoveAll(tempDirName) + return nil, nil, err + } + + cleanUp := func() { + cdb.Close() + os.RemoveAll(tempDirName) + } + + return cdb, cleanUp, nil +} + +type testContext struct { + cdb *channeldb.DB + registry *InvoiceRegistry + clock *clock.TestClock + + cleanup func() + t *testing.T +} + +func newTestContext(t *testing.T) *testContext { + clock := clock.NewTestClock(testTime) + + cdb, cleanup, err := newTestChannelDB() + if err != nil { + t.Fatal(err) + } + cdb.Now = clock.Now + + expiryWatcher := NewInvoiceExpiryWatcher(clock) + + // Instantiate and start the invoice ctx.registry. + cfg := RegistryConfig{ + FinalCltvRejectDelta: testFinalCltvRejectDelta, + HtlcHoldDuration: 30 * time.Second, + Clock: clock, + } + registry := NewRegistry(cdb, expiryWatcher, &cfg) + + err = registry.Start() + if err != nil { + cleanup() + t.Fatal(err) + } + + ctx := testContext{ + cdb: cdb, + registry: registry, + clock: clock, + t: t, + cleanup: func() { + registry.Stop() + cleanup() + }, + } + + return &ctx +} + +func getCircuitKey(htlcID uint64) channeldb.CircuitKey { + return channeldb.CircuitKey{ + ChanID: lnwire.ShortChannelID{ + BlockHeight: 1, TxIndex: 2, TxPosition: 3, + }, + HtlcID: htlcID, + } +} + +func newTestInvoice(t *testing.T, preimage lntypes.Preimage, + timestamp time.Time, expiry time.Duration) *channeldb.Invoice { + + if expiry == 0 { + expiry = time.Hour + } + + rawInvoice, err := zpay32.NewInvoice( + testNetParams, + preimage.Hash(), + timestamp, + zpay32.Amount(testInvoiceAmount), + zpay32.Description(testInvoiceDescription), + zpay32.Expiry(expiry)) + + if err != nil { + t.Fatalf("Error while creating new invoice: %v", err) + } + + paymentRequest, err := rawInvoice.Encode(testMessageSigner) + + if err != nil { + t.Fatalf("Error while encoding payment request: %v", err) + } + + return &channeldb.Invoice{ + Terms: channeldb.ContractTerm{ + PaymentPreimage: preimage, + Value: testInvoiceAmount, + Expiry: expiry, + Features: testFeatures, + }, + PaymentRequest: []byte(paymentRequest), + CreationDate: timestamp, + } +} + +// timeout implements a test level timeout. +func timeout() func() { + done := make(chan struct{}) + + go func() { + select { + case <-time.After(5 * time.Second): + err := pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) + if err != nil { + panic(fmt.Sprintf("error writing to std out after timeout: %v", err)) + } + panic("timeout") + case <-done: + } + }() + + return func() { + close(done) + } +} + +// invoiceExpiryTestData simply holds generated expired and pending invoices. +type invoiceExpiryTestData struct { + expiredInvoices map[lntypes.Hash]*channeldb.Invoice + pendingInvoices map[lntypes.Hash]*channeldb.Invoice +} + +// generateInvoiceExpiryTestData generates the specified number of fake expired +// and pending invoices anchored to the passed now timestamp. +func generateInvoiceExpiryTestData( + t *testing.T, now time.Time, + offset, numExpired, numPending int) invoiceExpiryTestData { + + var testData invoiceExpiryTestData + + testData.expiredInvoices = make(map[lntypes.Hash]*channeldb.Invoice) + testData.pendingInvoices = make(map[lntypes.Hash]*channeldb.Invoice) + + expiredCreationDate := now.Add(-24 * time.Hour) + + for i := 1; i <= numExpired; i++ { + var preimage lntypes.Preimage + binary.BigEndian.PutUint32(preimage[:4], uint32(offset+i)) + expiry := time.Duration((i+offset)%24) * time.Hour + invoice := newTestInvoice(t, preimage, expiredCreationDate, expiry) + testData.expiredInvoices[preimage.Hash()] = invoice + } + + for i := 1; i <= numPending; i++ { + var preimage lntypes.Preimage + binary.BigEndian.PutUint32(preimage[4:], uint32(offset+i)) + expiry := time.Duration((i+offset)%24) * time.Hour + invoice := newTestInvoice(t, preimage, now, expiry) + testData.pendingInvoices[preimage.Hash()] = invoice + } + + return testData +} diff --git a/invoices/utils_test.go b/invoices/utils_test.go deleted file mode 100644 index f33a93e8..00000000 --- a/invoices/utils_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package invoices - -import ( - "os" - "runtime/pprof" - "testing" - "time" -) - -// timeout implements a test level timeout. -func timeout(t *testing.T) func() { - done := make(chan struct{}) - go func() { - select { - case <-time.After(5 * time.Second): - pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) - - panic("test timeout") - case <-done: - } - }() - - return func() { - close(done) - } -} diff --git a/queue/priority_queue.go b/queue/priority_queue.go index aae7b423..06485e53 100644 --- a/queue/priority_queue.go +++ b/queue/priority_queue.go @@ -9,6 +9,8 @@ import ( // PriorityQueue will be able to use that to build and restore an underlying // heap. type PriorityQueueItem interface { + // Less must return true if this item is ordered before other and false + // otherwise. Less(other PriorityQueueItem) bool } @@ -43,7 +45,7 @@ func (pq *priorityQueue) Pop() interface{} { return item } -// Priority wrap a standard heap in a more object-oriented structure. +// PriorityQueue wraps a standard heap into a self contained class. type PriorityQueue struct { queue priorityQueue } diff --git a/server.go b/server.go index 88dce9de..320c3e68 100644 --- a/server.go +++ b/server.go @@ -33,6 +33,7 @@ import ( "github.com/lightningnetwork/lnd/chanfitness" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" + "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/discovery" "github.com/lightningnetwork/lnd/feature" @@ -381,8 +382,7 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, registryConfig := invoices.RegistryConfig{ FinalCltvRejectDelta: defaultFinalCltvRejectDelta, HtlcHoldDuration: invoices.DefaultHtlcHoldDuration, - Now: time.Now, - TickAfter: time.After, + Clock: clock.NewDefaultClock(), } s := &server{ @@ -393,7 +393,10 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, readPool: readPool, chansToRestore: chansToRestore, - invoices: invoices.NewRegistry(chanDB, ®istryConfig), + invoices: invoices.NewRegistry( + chanDB, invoices.NewInvoiceExpiryWatcher(clock.NewDefaultClock()), + ®istryConfig, + ), channelNotifier: channelnotifier.New(chanDB), diff --git a/zpay32/invoice.go b/zpay32/invoice.go index ab871125..3afb4cda 100644 --- a/zpay32/invoice.go +++ b/zpay32/invoice.go @@ -82,6 +82,10 @@ const ( // This is chosen to be the maximum number of bytes that can fit into a // single QR code: https://en.wikipedia.org/wiki/QR_code#Storage maxInvoiceLength = 7089 + + // DefaultInvoiceExpiry is the default expiry duration from the creation + // timestamp if expiry is set to zero. + DefaultInvoiceExpiry = time.Hour ) var (