Merge pull request #3694 from bhandras/i3448

invoices+channeldb: reject payments to expired invoices
This commit is contained in:
Olaoluwa Osuntokun 2019-12-13 19:53:53 -08:00 committed by GitHub
commit e34bc3d645
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 1147 additions and 239 deletions

@ -2,6 +2,7 @@ package channeldb
import ( import (
"crypto/rand" "crypto/rand"
mrand "math/rand"
"reflect" "reflect"
"testing" "testing"
"time" "time"
@ -393,6 +394,114 @@ func TestInvoiceAddTimeSeries(t *testing.T) {
} }
} }
// Tests that FetchAllInvoicesWithPaymentHash returns all invoices with their
// corresponding payment hashes.
func TestFetchAllInvoicesWithPaymentHash(t *testing.T) {
t.Parallel()
db, cleanup, err := makeTestDB()
defer cleanup()
if err != nil {
t.Fatalf("unable to make test db: %v", err)
}
// With an empty DB we expect to return no error and an empty list.
empty, err := db.FetchAllInvoicesWithPaymentHash(false)
if err != nil {
t.Fatalf("failed to call FetchAllInvoicesWithPaymentHash on empty DB: %v",
err)
}
if len(empty) != 0 {
t.Fatalf("expected empty list as a result, got: %v", empty)
}
// Now populate the DB and check if we can get all invoices with their
// payment hashes as expected.
const numInvoices = 20
testPendingInvoices := make(map[lntypes.Hash]*Invoice)
testAllInvoices := make(map[lntypes.Hash]*Invoice)
states := []ContractState{
ContractOpen, ContractSettled, ContractCanceled, ContractAccepted,
}
for i := lnwire.MilliSatoshi(1); i <= numInvoices; i++ {
invoice, err := randInvoice(i)
if err != nil {
t.Fatalf("unable to create invoice: %v", err)
}
invoice.State = states[mrand.Intn(len(states))]
paymentHash := invoice.Terms.PaymentPreimage.Hash()
if invoice.State != ContractSettled && invoice.State != ContractCanceled {
testPendingInvoices[paymentHash] = invoice
}
testAllInvoices[paymentHash] = invoice
if _, err := db.AddInvoice(invoice, paymentHash); err != nil {
t.Fatalf("unable to add invoice: %v", err)
}
}
pendingInvoices, err := db.FetchAllInvoicesWithPaymentHash(true)
if err != nil {
t.Fatalf("can't fetch invoices with payment hash: %v", err)
}
if len(testPendingInvoices) != len(pendingInvoices) {
t.Fatalf("expected %v pending invoices, got: %v",
len(testPendingInvoices), len(pendingInvoices))
}
allInvoices, err := db.FetchAllInvoicesWithPaymentHash(false)
if err != nil {
t.Fatalf("can't fetch invoices with payment hash: %v", err)
}
if len(testAllInvoices) != len(allInvoices) {
t.Fatalf("expected %v invoices, got: %v",
len(testAllInvoices), len(allInvoices))
}
for i := range pendingInvoices {
expected, ok := testPendingInvoices[pendingInvoices[i].PaymentHash]
if !ok {
t.Fatalf("coulnd't find invoice with hash: %v",
pendingInvoices[i].PaymentHash)
}
// Zero out add index to not confuse DeepEqual.
pendingInvoices[i].Invoice.AddIndex = 0
expected.AddIndex = 0
if !reflect.DeepEqual(*expected, pendingInvoices[i].Invoice) {
t.Fatalf("expected: %v, got: %v",
spew.Sdump(expected), spew.Sdump(pendingInvoices[i].Invoice))
}
}
for i := range allInvoices {
expected, ok := testAllInvoices[allInvoices[i].PaymentHash]
if !ok {
t.Fatalf("coulnd't find invoice with hash: %v",
allInvoices[i].PaymentHash)
}
// Zero out add index to not confuse DeepEqual.
allInvoices[i].Invoice.AddIndex = 0
expected.AddIndex = 0
if !reflect.DeepEqual(*expected, allInvoices[i].Invoice) {
t.Fatalf("expected: %v, got: %v",
spew.Sdump(expected), spew.Sdump(allInvoices[i].Invoice))
}
}
}
// TestDuplicateSettleInvoice tests that if we add a new invoice and settle it // TestDuplicateSettleInvoice tests that if we add a new invoice and settle it
// twice, then the second time we also receive the invoice that we settled as a // twice, then the second time we also receive the invoice that we settled as a
// return argument. // return argument.

@ -565,6 +565,83 @@ func (d *DB) LookupInvoice(paymentHash [32]byte) (Invoice, error) {
return invoice, nil return invoice, nil
} }
// InvoiceWithPaymentHash is used to store an invoice and its corresponding
// payment hash. This struct is only used to store results of
// ChannelDB.FetchAllInvoicesWithPaymentHash() call.
type InvoiceWithPaymentHash struct {
// Invoice holds the invoice as selected from the invoices bucket.
Invoice Invoice
// PaymentHash is the payment hash for the Invoice.
PaymentHash lntypes.Hash
}
// FetchAllInvoicesWithPaymentHash returns all invoices and their payment hashes
// currently stored within the database. If the pendingOnly param is true, then
// only unsettled invoices and their payment hashes will be returned, skipping
// all invoices that are fully settled or canceled. Note that the returned
// array is not ordered by add index.
func (d *DB) FetchAllInvoicesWithPaymentHash(pendingOnly bool) (
[]InvoiceWithPaymentHash, error) {
var result []InvoiceWithPaymentHash
err := d.View(func(tx *bbolt.Tx) error {
invoices := tx.Bucket(invoiceBucket)
if invoices == nil {
return ErrNoInvoicesCreated
}
invoiceIndex := invoices.Bucket(invoiceIndexBucket)
if invoiceIndex == nil {
// Mask the error if there's no invoice
// index as that simply means there are no
// invoices added yet to the DB. In this case
// we simply return an empty list.
return nil
}
return invoiceIndex.ForEach(func(k, v []byte) error {
// Skip the special numInvoicesKey as that does not
// point to a valid invoice.
if bytes.Equal(k, numInvoicesKey) {
return nil
}
if v == nil {
return nil
}
invoice, err := fetchInvoice(v, invoices)
if err != nil {
return err
}
if pendingOnly &&
(invoice.State == ContractSettled ||
invoice.State == ContractCanceled) {
return nil
}
invoiceWithPaymentHash := InvoiceWithPaymentHash{
Invoice: invoice,
}
copy(invoiceWithPaymentHash.PaymentHash[:], k)
result = append(result, invoiceWithPaymentHash)
return nil
})
})
if err != nil {
return nil, err
}
return result, nil
}
// FetchAllInvoices returns all invoices currently stored within the database. // FetchAllInvoices returns all invoices currently stored within the database.
// If the pendingOnly param is true, then only unsettled invoices will be // If the pendingOnly param is true, then only unsettled invoices will be
// returned, skipping all invoices that are fully settled. // returned, skipping all invoices that are fully settled.

24
clock/default_clock.go Normal file

@ -0,0 +1,24 @@
package clock
import (
"time"
)
// DefaultClock implements Clock interface by simply calling the appropriate
// time functions.
type DefaultClock struct{}
// NewDefaultClock constructs a new DefaultClock.
func NewDefaultClock() Clock {
return &DefaultClock{}
}
// Now simply returns time.Now().
func (DefaultClock) Now() time.Time {
return time.Now()
}
// TickAfter simply wraps time.After().
func (DefaultClock) TickAfter(duration time.Duration) <-chan time.Time {
return time.After(duration)
}

16
clock/interface.go Normal file

@ -0,0 +1,16 @@
package clock
import (
"time"
)
// Clock is an interface that provides a time functions for LND packages.
// This is useful during testing when a concrete time reference is needed.
type Clock interface {
// Now returns the current local time (as defined by the Clock).
Now() time.Time
// TickAfter returns a channel that will receive a tick after the specified
// duration has passed.
TickAfter(duration time.Duration) <-chan time.Time
}

@ -1,42 +1,40 @@
package invoices package clock
import ( import (
"sync" "sync"
"time" "time"
) )
// testClock can be used in tests to mock time. // TestClock can be used in tests to mock time.
type testClock struct { type TestClock struct {
currentTime time.Time currentTime time.Time
timeChanMap map[time.Time][]chan time.Time timeChanMap map[time.Time][]chan time.Time
timeLock sync.Mutex timeLock sync.Mutex
} }
// newTestClock returns a new test clock. // NewTestClock returns a new test clock.
func newTestClock(startTime time.Time) *testClock { func NewTestClock(startTime time.Time) *TestClock {
return &testClock{ return &TestClock{
currentTime: startTime, currentTime: startTime,
timeChanMap: make(map[time.Time][]chan time.Time), timeChanMap: make(map[time.Time][]chan time.Time),
} }
} }
// now returns the current (test) time. // Now returns the current (test) time.
func (c *testClock) now() time.Time { func (c *TestClock) Now() time.Time {
c.timeLock.Lock() c.timeLock.Lock()
defer c.timeLock.Unlock() defer c.timeLock.Unlock()
return c.currentTime return c.currentTime
} }
// tickAfter returns a channel that will receive a tick at the specified time. // TickAfter returns a channel that will receive a tick after the specified
func (c *testClock) tickAfter(duration time.Duration) <-chan time.Time { // duration has passed passed by the user set test time.
func (c *TestClock) TickAfter(duration time.Duration) <-chan time.Time {
c.timeLock.Lock() c.timeLock.Lock()
defer c.timeLock.Unlock() defer c.timeLock.Unlock()
triggerTime := c.currentTime.Add(duration) triggerTime := c.currentTime.Add(duration)
log.Debugf("tickAfter called: duration=%v, trigger_time=%v",
duration, triggerTime)
ch := make(chan time.Time, 1) ch := make(chan time.Time, 1)
// If already expired, tick immediately. // If already expired, tick immediately.
@ -53,8 +51,8 @@ func (c *testClock) tickAfter(duration time.Duration) <-chan time.Time {
return ch return ch
} }
// setTime sets the (test) time and triggers tick channels when they expire. // SetTime sets the (test) time and triggers tick channels when they expire.
func (c *testClock) setTime(now time.Time) { func (c *TestClock) SetTime(now time.Time) {
c.timeLock.Lock() c.timeLock.Lock()
defer c.timeLock.Unlock() defer c.timeLock.Unlock()

63
clock/test_clock_test.go Normal file

@ -0,0 +1,63 @@
package clock
import (
"testing"
"time"
)
var (
testTime = time.Date(2009, time.January, 3, 12, 0, 0, 0, time.UTC)
)
func TestNow(t *testing.T) {
c := NewTestClock(testTime)
now := c.Now()
if now != testTime {
t.Fatalf("expected: %v, got: %v", testTime, now)
}
now = now.Add(time.Hour)
c.SetTime(now)
if c.Now() != now {
t.Fatalf("epected: %v, got: %v", now, c.Now())
}
}
func TestTickAfter(t *testing.T) {
c := NewTestClock(testTime)
// Should be ticking immediately.
ticker0 := c.TickAfter(0)
// Both should be ticking after SetTime
ticker1 := c.TickAfter(time.Hour)
ticker2 := c.TickAfter(time.Hour)
// We don't expect this one to tick.
ticker3 := c.TickAfter(2 * time.Hour)
tickOrTimeOut := func(ticker <-chan time.Time, expectTick bool) {
tick := false
select {
case <-ticker:
tick = true
case <-time.After(time.Millisecond):
}
if tick != expectTick {
t.Fatalf("expected tick: %v, ticked: %v", expectTick, tick)
}
}
tickOrTimeOut(ticker0, true)
tickOrTimeOut(ticker1, false)
tickOrTimeOut(ticker2, false)
tickOrTimeOut(ticker3, false)
c.SetTime(c.Now().Add(time.Hour))
tickOrTimeOut(ticker1, true)
tickOrTimeOut(ticker2, true)
tickOrTimeOut(ticker3, false)
}

@ -22,6 +22,7 @@ import (
sphinx "github.com/lightningnetwork/lightning-onion" sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
@ -792,6 +793,7 @@ func newMockRegistry(minDelta uint32) *mockInvoiceRegistry {
registry := invoices.NewRegistry( registry := invoices.NewRegistry(
cdb, cdb,
invoices.NewInvoiceExpiryWatcher(clock.NewDefaultClock()),
&invoices.RegistryConfig{ &invoices.RegistryConfig{
FinalCltvRejectDelta: 5, FinalCltvRejectDelta: 5,
}, },

@ -0,0 +1,191 @@
package invoices
import (
"fmt"
"sync"
"time"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/queue"
"github.com/lightningnetwork/lnd/zpay32"
)
// invoiceExpiry holds and invoice's payment hash and its expiry. This
// is used to order invoices by their expiry for cancellation.
type invoiceExpiry struct {
PaymentHash lntypes.Hash
Expiry time.Time
}
// Less implements PriorityQueueItem.Less such that the top item in the
// priorty queue will be the one that expires next.
func (e invoiceExpiry) Less(other queue.PriorityQueueItem) bool {
return e.Expiry.Before(other.(*invoiceExpiry).Expiry)
}
// InvoiceExpiryWatcher handles automatic invoice cancellation of expried
// invoices. Upon start InvoiceExpiryWatcher will retrieve all pending (not yet
// settled or canceled) invoices invoices to its watcing queue. When a new
// invoice is added to the InvoiceRegistry, it'll be forarded to the
// InvoiceExpiryWatcher and will end up in the watching queue as well.
// If any of the watched invoices expire, they'll be removed from the watching
// queue and will be cancelled through InvoiceRegistry.CancelInvoice().
type InvoiceExpiryWatcher struct {
sync.Mutex
started bool
// clock is the clock implementation that InvoiceExpiryWatcher uses.
// It is useful for testing.
clock clock.Clock
// cancelInvoice is a template method that cancels an expired invoice.
cancelInvoice func(lntypes.Hash) error
// expiryQueue holds invoiceExpiry items and is used to find the next
// invoice to expire.
expiryQueue queue.PriorityQueue
// newInvoices channel is used to wake up the main loop when a new invoices
// is added.
newInvoices chan *invoiceExpiry
wg sync.WaitGroup
// quit signals InvoiceExpiryWatcher to stop.
quit chan struct{}
}
// NewInvoiceExpiryWatcher creates a new InvoiceExpiryWatcher instance.
func NewInvoiceExpiryWatcher(clock clock.Clock) *InvoiceExpiryWatcher {
return &InvoiceExpiryWatcher{
clock: clock,
newInvoices: make(chan *invoiceExpiry),
quit: make(chan struct{}),
}
}
// Start starts the the subscription handler and the main loop. Start() will
// return with error if InvoiceExpiryWatcher is already started. Start()
// expects a cancellation function passed that will be use to cancel expired
// invoices by their payment hash.
func (ew *InvoiceExpiryWatcher) Start(
cancelInvoice func(lntypes.Hash) error) error {
ew.Lock()
defer ew.Unlock()
if ew.started {
return fmt.Errorf("InvoiceExpiryWatcher already started")
}
ew.started = true
ew.cancelInvoice = cancelInvoice
ew.wg.Add(1)
go ew.mainLoop()
return nil
}
// Stop quits the expiry handler loop and waits for InvoiceExpiryWatcher to
// fully stop.
func (ew *InvoiceExpiryWatcher) Stop() {
ew.Lock()
defer ew.Unlock()
if ew.started {
// Signal subscriptionHandler to quit and wait for it to return.
close(ew.quit)
ew.wg.Wait()
ew.started = false
}
}
// AddInvoice adds a new invoice to the InvoiceExpiryWatcher. This won't check
// if the invoice is already added and will only add invoices with ContractOpen
// state.
func (ew *InvoiceExpiryWatcher) AddInvoice(
paymentHash lntypes.Hash, invoice *channeldb.Invoice) {
if invoice.State != channeldb.ContractOpen {
log.Debugf("Invoice not added to expiry watcher: %v", invoice)
return
}
realExpiry := invoice.Terms.Expiry
if realExpiry == 0 {
realExpiry = zpay32.DefaultInvoiceExpiry
}
expiry := invoice.CreationDate.Add(realExpiry)
log.Debugf("Adding invoice '%v' to expiry watcher, expiration: %v",
paymentHash, expiry)
newInvoiceExpiry := &invoiceExpiry{
PaymentHash: paymentHash,
Expiry: expiry,
}
select {
case ew.newInvoices <- newInvoiceExpiry:
case <-ew.quit:
// Select on quit too so that callers won't get blocked in case
// of concurrent shutdown.
}
}
// nextExpiry returns a Time chan to wait on until the next invoice expires.
// If there are no active invoices, then it'll simply wait indefinitely.
func (ew *InvoiceExpiryWatcher) nextExpiry() <-chan time.Time {
if !ew.expiryQueue.Empty() {
top := ew.expiryQueue.Top().(*invoiceExpiry)
return ew.clock.TickAfter(top.Expiry.Sub(ew.clock.Now()))
}
return nil
}
// cancelExpiredInvoices will cancel all expired invoices and removes them from
// the expiry queue.
func (ew *InvoiceExpiryWatcher) cancelExpiredInvoices() {
for !ew.expiryQueue.Empty() {
top := ew.expiryQueue.Top().(*invoiceExpiry)
if !top.Expiry.Before(ew.clock.Now()) {
break
}
err := ew.cancelInvoice(top.PaymentHash)
if err != nil && err != channeldb.ErrInvoiceAlreadySettled &&
err != channeldb.ErrInvoiceAlreadyCanceled {
log.Errorf("Unable to cancel invoice: %v", top.PaymentHash)
}
ew.expiryQueue.Pop()
}
}
// mainLoop is a goroutine that receives new invoices and handles cancellation
// of expired invoices.
func (ew *InvoiceExpiryWatcher) mainLoop() {
defer ew.wg.Done()
for {
// Cancel any invoices that may have expired.
ew.cancelExpiredInvoices()
select {
case <-ew.nextExpiry():
// Wait until the next invoice expires, then cancel expired invoices.
continue
case newInvoiceExpiry := <-ew.newInvoices:
ew.expiryQueue.Push(newInvoiceExpiry)
case <-ew.quit:
return
}
}
}

@ -0,0 +1,125 @@
package invoices
import (
"testing"
"time"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lntypes"
)
// invoiceExpiryWatcherTest holds a test fixture and implements checks
// for InvoiceExpiryWatcher tests.
type invoiceExpiryWatcherTest struct {
t *testing.T
watcher *InvoiceExpiryWatcher
testData invoiceExpiryTestData
canceledInvoices []lntypes.Hash
}
// newInvoiceExpiryWatcherTest creates a new InvoiceExpiryWatcher test fixture
// and sets up the test environment.
func newInvoiceExpiryWatcherTest(t *testing.T, now time.Time,
numExpiredInvoices, numPendingInvoices int) *invoiceExpiryWatcherTest {
test := &invoiceExpiryWatcherTest{
watcher: NewInvoiceExpiryWatcher(clock.NewTestClock(testTime)),
testData: generateInvoiceExpiryTestData(
t, now, 0, numExpiredInvoices, numPendingInvoices,
),
}
err := test.watcher.Start(func(paymentHash lntypes.Hash) error {
test.canceledInvoices = append(test.canceledInvoices, paymentHash)
return nil
})
if err != nil {
t.Fatalf("cannot start InvoiceExpiryWatcher: %v", err)
}
return test
}
func (t *invoiceExpiryWatcherTest) checkExpectations() {
// Check that invoices that got canceled during the test are the ones
// that expired.
if len(t.canceledInvoices) != len(t.testData.expiredInvoices) {
t.t.Fatalf("expected %v cancellations, got %v",
len(t.testData.expiredInvoices), len(t.canceledInvoices))
}
for i := range t.canceledInvoices {
if _, ok := t.testData.expiredInvoices[t.canceledInvoices[i]]; !ok {
t.t.Fatalf("wrong invoice canceled")
}
}
}
// Tests that InvoiceExpiryWatcher can be started and stopped.
func TestInvoiceExpiryWatcherStartStop(t *testing.T) {
watcher := NewInvoiceExpiryWatcher(clock.NewTestClock(testTime))
cancel := func(lntypes.Hash) error {
t.Fatalf("unexpected call")
return nil
}
if err := watcher.Start(cancel); err != nil {
t.Fatalf("unexpected error upon start: %v", err)
}
if err := watcher.Start(cancel); err == nil {
t.Fatalf("expected error upon second start")
}
watcher.Stop()
if err := watcher.Start(cancel); err != nil {
t.Fatalf("unexpected error upon start: %v", err)
}
}
// Tests that no invoices will expire from an empty InvoiceExpiryWatcher.
func TestInvoiceExpiryWithNoInvoices(t *testing.T) {
t.Parallel()
test := newInvoiceExpiryWatcherTest(t, testTime, 0, 0)
time.Sleep(testTimeout)
test.watcher.Stop()
test.checkExpectations()
}
// Tests that if all add invoices are expired, then all invoices
// will be canceled.
func TestInvoiceExpiryWithOnlyExpiredInvoices(t *testing.T) {
t.Parallel()
test := newInvoiceExpiryWatcherTest(t, testTime, 0, 5)
for paymentHash, invoice := range test.testData.pendingInvoices {
test.watcher.AddInvoice(paymentHash, invoice)
}
time.Sleep(testTimeout)
test.watcher.Stop()
test.checkExpectations()
}
// Tests that if some invoices are expired, then those invoices
// will be canceled.
func TestInvoiceExpiryWithPendingAndExpiredInvoices(t *testing.T) {
t.Parallel()
test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5)
for paymentHash, invoice := range test.testData.expiredInvoices {
test.watcher.AddInvoice(paymentHash, invoice)
}
for paymentHash, invoice := range test.testData.pendingInvoices {
test.watcher.AddInvoice(paymentHash, invoice)
}
time.Sleep(testTimeout)
test.watcher.Stop()
test.checkExpectations()
}

@ -9,6 +9,7 @@ import (
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/queue" "github.com/lightningnetwork/lnd/queue"
@ -62,12 +63,10 @@ type RegistryConfig struct {
// waiting for the other set members to arrive. // waiting for the other set members to arrive.
HtlcHoldDuration time.Duration HtlcHoldDuration time.Duration
// Now returns the current time. // Clock holds the clock implementation that is used to provide
Now func() time.Time // Now() and TickAfter() and is useful to stub out the clock functions
// during testing.
// TickAfter returns a channel that is sent on after the specified Clock clock.Clock
// duration as passed.
TickAfter func(duration time.Duration) <-chan time.Time
} }
// htlcReleaseEvent describes an htlc auto-release event. It is used to release // htlcReleaseEvent describes an htlc auto-release event. It is used to release
@ -126,6 +125,8 @@ type InvoiceRegistry struct {
// auto-released. // auto-released.
htlcAutoReleaseChan chan *htlcReleaseEvent htlcAutoReleaseChan chan *htlcReleaseEvent
expiryWatcher *InvoiceExpiryWatcher
wg sync.WaitGroup wg sync.WaitGroup
quit chan struct{} quit chan struct{}
} }
@ -134,7 +135,9 @@ type InvoiceRegistry struct {
// wraps the persistent on-disk invoice storage with an additional in-memory // wraps the persistent on-disk invoice storage with an additional in-memory
// layer. The in-memory layer is in place such that debug invoices can be added // layer. The in-memory layer is in place such that debug invoices can be added
// which are volatile yet available system wide within the daemon. // which are volatile yet available system wide within the daemon.
func NewRegistry(cdb *channeldb.DB, cfg *RegistryConfig) *InvoiceRegistry { func NewRegistry(cdb *channeldb.DB, expiryWatcher *InvoiceExpiryWatcher,
cfg *RegistryConfig) *InvoiceRegistry {
return &InvoiceRegistry{ return &InvoiceRegistry{
cdb: cdb, cdb: cdb,
notificationClients: make(map[uint32]*InvoiceSubscription), notificationClients: make(map[uint32]*InvoiceSubscription),
@ -146,21 +149,62 @@ func NewRegistry(cdb *channeldb.DB, cfg *RegistryConfig) *InvoiceRegistry {
hodlReverseSubscriptions: make(map[chan<- interface{}]map[channeldb.CircuitKey]struct{}), hodlReverseSubscriptions: make(map[chan<- interface{}]map[channeldb.CircuitKey]struct{}),
cfg: cfg, cfg: cfg,
htlcAutoReleaseChan: make(chan *htlcReleaseEvent), htlcAutoReleaseChan: make(chan *htlcReleaseEvent),
expiryWatcher: expiryWatcher,
quit: make(chan struct{}), quit: make(chan struct{}),
} }
} }
// populateExpiryWatcher fetches all active invoices and their corresponding
// payment hashes from ChannelDB and adds them to the expiry watcher.
func (i *InvoiceRegistry) populateExpiryWatcher() error {
pendingOnly := true
pendingInvoices, err := i.cdb.FetchAllInvoicesWithPaymentHash(pendingOnly)
if err != nil && err != channeldb.ErrNoInvoicesCreated {
log.Errorf(
"Error while prefetching active invoices from the database: %v", err,
)
return err
}
for idx := range pendingInvoices {
i.expiryWatcher.AddInvoice(
pendingInvoices[idx].PaymentHash, &pendingInvoices[idx].Invoice,
)
}
return nil
}
// Start starts the registry and all goroutines it needs to carry out its task. // Start starts the registry and all goroutines it needs to carry out its task.
func (i *InvoiceRegistry) Start() error { func (i *InvoiceRegistry) Start() error {
i.wg.Add(1) // Start InvoiceExpiryWatcher and prepopulate it with existing active
// invoices.
err := i.expiryWatcher.Start(func(paymentHash lntypes.Hash) error {
cancelIfAccepted := false
return i.cancelInvoiceImpl(paymentHash, cancelIfAccepted)
})
if err != nil {
return err
}
i.wg.Add(1)
go i.invoiceEventLoop() go i.invoiceEventLoop()
// Now prefetch all pending invoices to the expiry watcher.
err = i.populateExpiryWatcher()
if err != nil {
i.Stop()
return err
}
return nil return nil
} }
// Stop signals the registry for a graceful shutdown. // Stop signals the registry for a graceful shutdown.
func (i *InvoiceRegistry) Stop() { func (i *InvoiceRegistry) Stop() {
i.expiryWatcher.Stop()
close(i.quit) close(i.quit)
i.wg.Wait() i.wg.Wait()
@ -177,8 +221,8 @@ type invoiceEvent struct {
// tickAt returns a channel that ticks at the specified time. If the time has // tickAt returns a channel that ticks at the specified time. If the time has
// already passed, it will tick immediately. // already passed, it will tick immediately.
func (i *InvoiceRegistry) tickAt(t time.Time) <-chan time.Time { func (i *InvoiceRegistry) tickAt(t time.Time) <-chan time.Time {
now := i.cfg.Now() now := i.cfg.Clock.Now()
return i.cfg.TickAfter(t.Sub(now)) return i.cfg.Clock.TickAfter(t.Sub(now))
} }
// invoiceEventLoop is the dedicated goroutine responsible for accepting // invoiceEventLoop is the dedicated goroutine responsible for accepting
@ -471,7 +515,6 @@ func (i *InvoiceRegistry) AddInvoice(invoice *channeldb.Invoice,
paymentHash lntypes.Hash) (uint64, error) { paymentHash lntypes.Hash) (uint64, error) {
i.Lock() i.Lock()
defer i.Unlock()
log.Debugf("Invoice(%v): added %v", paymentHash, log.Debugf("Invoice(%v): added %v", paymentHash,
newLogClosure(func() string { newLogClosure(func() string {
@ -481,12 +524,19 @@ func (i *InvoiceRegistry) AddInvoice(invoice *channeldb.Invoice,
addIndex, err := i.cdb.AddInvoice(invoice, paymentHash) addIndex, err := i.cdb.AddInvoice(invoice, paymentHash)
if err != nil { if err != nil {
i.Unlock()
return 0, err return 0, err
} }
// Now that we've added the invoice, we'll send dispatch a message to // Now that we've added the invoice, we'll send dispatch a message to
// notify the clients of this new invoice. // notify the clients of this new invoice.
i.notifyClients(paymentHash, invoice, channeldb.ContractOpen) i.notifyClients(paymentHash, invoice, channeldb.ContractOpen)
i.Unlock()
// InvoiceExpiryWatcher.AddInvoice must not be locked by InvoiceRegistry
// to avoid deadlock when a new invoice is added while an other is being
// canceled.
i.expiryWatcher.AddInvoice(paymentHash, invoice)
return addIndex, nil return addIndex, nil
} }
@ -818,6 +868,15 @@ func (i *InvoiceRegistry) SettleHodlInvoice(preimage lntypes.Preimage) error {
// CancelInvoice attempts to cancel the invoice corresponding to the passed // CancelInvoice attempts to cancel the invoice corresponding to the passed
// payment hash. // payment hash.
func (i *InvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error { func (i *InvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error {
return i.cancelInvoiceImpl(payHash, true)
}
// cancelInvoice attempts to cancel the invoice corresponding to the passed
// payment hash. Accepted invoices will only be canceled if explicitly
// requested to do so.
func (i *InvoiceRegistry) cancelInvoiceImpl(payHash lntypes.Hash,
cancelAccepted bool) error {
i.Lock() i.Lock()
defer i.Unlock() defer i.Unlock()
@ -826,6 +885,12 @@ func (i *InvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error {
updateInvoice := func(invoice *channeldb.Invoice) ( updateInvoice := func(invoice *channeldb.Invoice) (
*channeldb.InvoiceUpdateDesc, error) { *channeldb.InvoiceUpdateDesc, error) {
// Only cancel the invoice in ContractAccepted state if explicitly
// requested to do so.
if invoice.State == channeldb.ContractAccepted && !cancelAccepted {
return nil, nil
}
// Move invoice to the canceled state. Rely on validation in // Move invoice to the canceled state. Rely on validation in
// channeldb to return an error if the invoice is already // channeldb to return an error if the invoice is already
// settled or canceled. // settled or canceled.
@ -848,6 +913,13 @@ func (i *InvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error {
return err return err
} }
// Return without cancellation if the invoice state is ContractAccepted.
if invoice.State == channeldb.ContractAccepted {
log.Debugf("Invoice(%v): remains accepted as cancel wasn't"+
"explicitly requested.", payHash)
return nil
}
log.Debugf("Invoice(%v): canceled", payHash) log.Debugf("Invoice(%v): canceled", payHash)
// In the callback, some htlcs may have been moved to the canceled // In the callback, some htlcs may have been moved to the canceled

@ -1,117 +1,16 @@
package invoices package invoices
import ( import (
"io/ioutil"
"os"
"testing" "testing"
"time" "time"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/record"
) )
var (
testTimeout = 5 * time.Second
testTime = time.Date(2018, time.February, 2, 14, 0, 0, 0, time.UTC)
preimage = lntypes.Preimage{
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
}
hash = preimage.Hash()
testHtlcExpiry = uint32(5)
testInvoiceCltvDelta = uint32(4)
testFinalCltvRejectDelta = int32(4)
testCurrentHeight = int32(1)
testFeatures = lnwire.NewFeatureVector(
nil, lnwire.Features,
)
testPayload = &mockPayload{}
)
var (
testInvoiceAmt = lnwire.MilliSatoshi(100000)
testInvoice = &channeldb.Invoice{
Terms: channeldb.ContractTerm{
PaymentPreimage: preimage,
Value: lnwire.MilliSatoshi(100000),
Features: testFeatures,
},
}
testHodlInvoice = &channeldb.Invoice{
Terms: channeldb.ContractTerm{
PaymentPreimage: channeldb.UnknownPreimage,
Value: testInvoiceAmt,
Features: testFeatures,
},
}
)
type testContext struct {
registry *InvoiceRegistry
clock *testClock
cleanup func()
t *testing.T
}
func newTestContext(t *testing.T) *testContext {
clock := newTestClock(testTime)
cdb, cleanup, err := newDB()
if err != nil {
t.Fatal(err)
}
cdb.Now = clock.now
// Instantiate and start the invoice ctx.registry.
cfg := RegistryConfig{
FinalCltvRejectDelta: testFinalCltvRejectDelta,
HtlcHoldDuration: 30 * time.Second,
Now: clock.now,
TickAfter: clock.tickAfter,
}
registry := NewRegistry(cdb, &cfg)
err = registry.Start()
if err != nil {
cleanup()
t.Fatal(err)
}
ctx := testContext{
registry: registry,
clock: clock,
t: t,
cleanup: func() {
registry.Stop()
cleanup()
},
}
return &ctx
}
func getCircuitKey(htlcID uint64) channeldb.CircuitKey {
return channeldb.CircuitKey{
ChanID: lnwire.ShortChannelID{
BlockHeight: 1, TxIndex: 2, TxPosition: 3,
},
HtlcID: htlcID,
}
}
// TestSettleInvoice tests settling of an invoice and related notifications. // TestSettleInvoice tests settling of an invoice and related notifications.
func TestSettleInvoice(t *testing.T) { func TestSettleInvoice(t *testing.T) {
ctx := newTestContext(t) ctx := newTestContext(t)
@ -121,18 +20,18 @@ func TestSettleInvoice(t *testing.T) {
defer allSubscriptions.Cancel() defer allSubscriptions.Cancel()
// Subscribe to the not yet existing invoice. // Subscribe to the not yet existing invoice.
subscription, err := ctx.registry.SubscribeSingleInvoice(hash) subscription, err := ctx.registry.SubscribeSingleInvoice(testInvoicePaymentHash)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer subscription.Cancel() defer subscription.Cancel()
if subscription.hash != hash { if subscription.hash != testInvoicePaymentHash {
t.Fatalf("expected subscription for provided hash") t.Fatalf("expected subscription for provided hash")
} }
// Add the invoice. // Add the invoice.
addIdx, err := ctx.registry.AddInvoice(testInvoice, hash) addIdx, err := ctx.registry.AddInvoice(testInvoice, testInvoicePaymentHash)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -168,7 +67,7 @@ func TestSettleInvoice(t *testing.T) {
// Try to settle invoice with an htlc that expires too soon. // Try to settle invoice with an htlc that expires too soon.
event, err := ctx.registry.NotifyExitHopHtlc( event, err := ctx.registry.NotifyExitHopHtlc(
hash, testInvoice.Terms.Value, testInvoicePaymentHash, testInvoice.Terms.Value,
uint32(testCurrentHeight)+testInvoiceCltvDelta-1, uint32(testCurrentHeight)+testInvoiceCltvDelta-1,
testCurrentHeight, getCircuitKey(10), hodlChan, testPayload, testCurrentHeight, getCircuitKey(10), hodlChan, testPayload,
) )
@ -186,7 +85,7 @@ func TestSettleInvoice(t *testing.T) {
// Settle invoice with a slightly higher amount. // Settle invoice with a slightly higher amount.
amtPaid := lnwire.MilliSatoshi(100500) amtPaid := lnwire.MilliSatoshi(100500)
_, err = ctx.registry.NotifyExitHopHtlc( _, err = ctx.registry.NotifyExitHopHtlc(
hash, amtPaid, testHtlcExpiry, testCurrentHeight, testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight,
getCircuitKey(0), hodlChan, testPayload, getCircuitKey(0), hodlChan, testPayload,
) )
if err != nil { if err != nil {
@ -222,7 +121,7 @@ func TestSettleInvoice(t *testing.T) {
// Try to settle again with the same htlc id. We need this idempotent // Try to settle again with the same htlc id. We need this idempotent
// behaviour after a restart. // behaviour after a restart.
event, err = ctx.registry.NotifyExitHopHtlc( event, err = ctx.registry.NotifyExitHopHtlc(
hash, amtPaid, testHtlcExpiry, testCurrentHeight, testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight,
getCircuitKey(0), hodlChan, testPayload, getCircuitKey(0), hodlChan, testPayload,
) )
if err != nil { if err != nil {
@ -236,7 +135,7 @@ func TestSettleInvoice(t *testing.T) {
// should also be accepted, to prevent any change in behaviour for a // should also be accepted, to prevent any change in behaviour for a
// paid invoice that may open up a probe vector. // paid invoice that may open up a probe vector.
event, err = ctx.registry.NotifyExitHopHtlc( event, err = ctx.registry.NotifyExitHopHtlc(
hash, amtPaid+600, testHtlcExpiry, testCurrentHeight, testInvoicePaymentHash, amtPaid+600, testHtlcExpiry, testCurrentHeight,
getCircuitKey(1), hodlChan, testPayload, getCircuitKey(1), hodlChan, testPayload,
) )
if err != nil { if err != nil {
@ -249,7 +148,7 @@ func TestSettleInvoice(t *testing.T) {
// Try to settle again with a lower amount. This should fail just as it // Try to settle again with a lower amount. This should fail just as it
// would have failed if it were the first payment. // would have failed if it were the first payment.
event, err = ctx.registry.NotifyExitHopHtlc( event, err = ctx.registry.NotifyExitHopHtlc(
hash, amtPaid-600, testHtlcExpiry, testCurrentHeight, testInvoicePaymentHash, amtPaid-600, testHtlcExpiry, testCurrentHeight,
getCircuitKey(2), hodlChan, testPayload, getCircuitKey(2), hodlChan, testPayload,
) )
if err != nil { if err != nil {
@ -261,7 +160,7 @@ func TestSettleInvoice(t *testing.T) {
// Check that settled amount is equal to the sum of values of the htlcs // Check that settled amount is equal to the sum of values of the htlcs
// 0 and 1. // 0 and 1.
inv, err := ctx.registry.LookupInvoice(hash) inv, err := ctx.registry.LookupInvoice(testInvoicePaymentHash)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -270,7 +169,7 @@ func TestSettleInvoice(t *testing.T) {
} }
// Try to cancel. // Try to cancel.
err = ctx.registry.CancelInvoice(hash) err = ctx.registry.CancelInvoice(testInvoicePaymentHash)
if err != channeldb.ErrInvoiceAlreadySettled { if err != channeldb.ErrInvoiceAlreadySettled {
t.Fatal("expected cancelation of a settled invoice to fail") t.Fatal("expected cancelation of a settled invoice to fail")
} }
@ -292,25 +191,25 @@ func TestCancelInvoice(t *testing.T) {
defer allSubscriptions.Cancel() defer allSubscriptions.Cancel()
// Try to cancel the not yet existing invoice. This should fail. // Try to cancel the not yet existing invoice. This should fail.
err := ctx.registry.CancelInvoice(hash) err := ctx.registry.CancelInvoice(testInvoicePaymentHash)
if err != channeldb.ErrInvoiceNotFound { if err != channeldb.ErrInvoiceNotFound {
t.Fatalf("expected ErrInvoiceNotFound, but got %v", err) t.Fatalf("expected ErrInvoiceNotFound, but got %v", err)
} }
// Subscribe to the not yet existing invoice. // Subscribe to the not yet existing invoice.
subscription, err := ctx.registry.SubscribeSingleInvoice(hash) subscription, err := ctx.registry.SubscribeSingleInvoice(testInvoicePaymentHash)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer subscription.Cancel() defer subscription.Cancel()
if subscription.hash != hash { if subscription.hash != testInvoicePaymentHash {
t.Fatalf("expected subscription for provided hash") t.Fatalf("expected subscription for provided hash")
} }
// Add the invoice. // Add the invoice.
amt := lnwire.MilliSatoshi(100000) amt := lnwire.MilliSatoshi(100000)
_, err = ctx.registry.AddInvoice(testInvoice, hash) _, err = ctx.registry.AddInvoice(testInvoice, testInvoicePaymentHash)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -342,7 +241,7 @@ func TestCancelInvoice(t *testing.T) {
} }
// Cancel invoice. // Cancel invoice.
err = ctx.registry.CancelInvoice(hash) err = ctx.registry.CancelInvoice(testInvoicePaymentHash)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -365,7 +264,7 @@ func TestCancelInvoice(t *testing.T) {
// subscribers (backwards compatibility). // subscribers (backwards compatibility).
// Try to cancel again. // Try to cancel again.
err = ctx.registry.CancelInvoice(hash) err = ctx.registry.CancelInvoice(testInvoicePaymentHash)
if err != nil { if err != nil {
t.Fatal("expected cancelation of a canceled invoice to succeed") t.Fatal("expected cancelation of a canceled invoice to succeed")
} }
@ -374,7 +273,7 @@ func TestCancelInvoice(t *testing.T) {
// result in a cancel event. // result in a cancel event.
hodlChan := make(chan interface{}) hodlChan := make(chan interface{})
event, err := ctx.registry.NotifyExitHopHtlc( event, err := ctx.registry.NotifyExitHopHtlc(
hash, amt, testHtlcExpiry, testCurrentHeight, testInvoicePaymentHash, amt, testHtlcExpiry, testCurrentHeight,
getCircuitKey(0), hodlChan, testPayload, getCircuitKey(0), hodlChan, testPayload,
) )
if err != nil { if err != nil {
@ -393,9 +292,9 @@ func TestCancelInvoice(t *testing.T) {
// TestSettleHoldInvoice tests settling of a hold invoice and related // TestSettleHoldInvoice tests settling of a hold invoice and related
// notifications. // notifications.
func TestSettleHoldInvoice(t *testing.T) { func TestSettleHoldInvoice(t *testing.T) {
defer timeout(t)() defer timeout()()
cdb, cleanup, err := newDB() cdb, cleanup, err := newTestChannelDB()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -404,8 +303,9 @@ func TestSettleHoldInvoice(t *testing.T) {
// Instantiate and start the invoice ctx.registry. // Instantiate and start the invoice ctx.registry.
cfg := RegistryConfig{ cfg := RegistryConfig{
FinalCltvRejectDelta: testFinalCltvRejectDelta, FinalCltvRejectDelta: testFinalCltvRejectDelta,
Clock: clock.NewTestClock(testTime),
} }
registry := NewRegistry(cdb, &cfg) registry := NewRegistry(cdb, NewInvoiceExpiryWatcher(cfg.Clock), &cfg)
err = registry.Start() err = registry.Start()
if err != nil { if err != nil {
@ -417,18 +317,18 @@ func TestSettleHoldInvoice(t *testing.T) {
defer allSubscriptions.Cancel() defer allSubscriptions.Cancel()
// Subscribe to the not yet existing invoice. // Subscribe to the not yet existing invoice.
subscription, err := registry.SubscribeSingleInvoice(hash) subscription, err := registry.SubscribeSingleInvoice(testInvoicePaymentHash)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer subscription.Cancel() defer subscription.Cancel()
if subscription.hash != hash { if subscription.hash != testInvoicePaymentHash {
t.Fatalf("expected subscription for provided hash") t.Fatalf("expected subscription for provided hash")
} }
// Add the invoice. // Add the invoice.
_, err = registry.AddInvoice(testHodlInvoice, hash) _, err = registry.AddInvoice(testHodlInvoice, testInvoicePaymentHash)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -455,7 +355,7 @@ func TestSettleHoldInvoice(t *testing.T) {
// NotifyExitHopHtlc without a preimage present in the invoice registry // NotifyExitHopHtlc without a preimage present in the invoice registry
// should be possible. // should be possible.
event, err := registry.NotifyExitHopHtlc( event, err := registry.NotifyExitHopHtlc(
hash, amtPaid, testHtlcExpiry, testCurrentHeight, testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight,
getCircuitKey(0), hodlChan, testPayload, getCircuitKey(0), hodlChan, testPayload,
) )
if err != nil { if err != nil {
@ -467,7 +367,7 @@ func TestSettleHoldInvoice(t *testing.T) {
// Test idempotency. // Test idempotency.
event, err = registry.NotifyExitHopHtlc( event, err = registry.NotifyExitHopHtlc(
hash, amtPaid, testHtlcExpiry, testCurrentHeight, testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight,
getCircuitKey(0), hodlChan, testPayload, getCircuitKey(0), hodlChan, testPayload,
) )
if err != nil { if err != nil {
@ -480,7 +380,7 @@ func TestSettleHoldInvoice(t *testing.T) {
// Test replay at a higher height. We expect the same result because it // Test replay at a higher height. We expect the same result because it
// is a replay. // is a replay.
event, err = registry.NotifyExitHopHtlc( event, err = registry.NotifyExitHopHtlc(
hash, amtPaid, testHtlcExpiry, testCurrentHeight+10, testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight+10,
getCircuitKey(0), hodlChan, testPayload, getCircuitKey(0), hodlChan, testPayload,
) )
if err != nil { if err != nil {
@ -493,7 +393,7 @@ func TestSettleHoldInvoice(t *testing.T) {
// Test a new htlc coming in that doesn't meet the final cltv delta // Test a new htlc coming in that doesn't meet the final cltv delta
// requirement. It should be rejected. // requirement. It should be rejected.
event, err = registry.NotifyExitHopHtlc( event, err = registry.NotifyExitHopHtlc(
hash, amtPaid, 1, testCurrentHeight, testInvoicePaymentHash, amtPaid, 1, testCurrentHeight,
getCircuitKey(1), hodlChan, testPayload, getCircuitKey(1), hodlChan, testPayload,
) )
if err != nil { if err != nil {
@ -516,13 +416,13 @@ func TestSettleHoldInvoice(t *testing.T) {
} }
// Settling with preimage should succeed. // Settling with preimage should succeed.
err = registry.SettleHodlInvoice(preimage) err = registry.SettleHodlInvoice(testInvoicePreimage)
if err != nil { if err != nil {
t.Fatal("expected set preimage to succeed") t.Fatal("expected set preimage to succeed")
} }
hodlEvent := (<-hodlChan).(HodlEvent) hodlEvent := (<-hodlChan).(HodlEvent)
if *hodlEvent.Preimage != preimage { if *hodlEvent.Preimage != testInvoicePreimage {
t.Fatal("unexpected preimage in hodl event") t.Fatal("unexpected preimage in hodl event")
} }
if hodlEvent.AcceptHeight != testCurrentHeight { if hodlEvent.AcceptHeight != testCurrentHeight {
@ -549,13 +449,13 @@ func TestSettleHoldInvoice(t *testing.T) {
} }
// Idempotency. // Idempotency.
err = registry.SettleHodlInvoice(preimage) err = registry.SettleHodlInvoice(testInvoicePreimage)
if err != channeldb.ErrInvoiceAlreadySettled { if err != channeldb.ErrInvoiceAlreadySettled {
t.Fatalf("expected ErrInvoiceAlreadySettled but got %v", err) t.Fatalf("expected ErrInvoiceAlreadySettled but got %v", err)
} }
// Try to cancel. // Try to cancel.
err = registry.CancelInvoice(hash) err = registry.CancelInvoice(testInvoicePaymentHash)
if err == nil { if err == nil {
t.Fatal("expected cancelation of a settled invoice to fail") t.Fatal("expected cancelation of a settled invoice to fail")
} }
@ -564,9 +464,9 @@ func TestSettleHoldInvoice(t *testing.T) {
// TestCancelHoldInvoice tests canceling of a hold invoice and related // TestCancelHoldInvoice tests canceling of a hold invoice and related
// notifications. // notifications.
func TestCancelHoldInvoice(t *testing.T) { func TestCancelHoldInvoice(t *testing.T) {
defer timeout(t)() defer timeout()()
cdb, cleanup, err := newDB() cdb, cleanup, err := newTestChannelDB()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -575,8 +475,9 @@ func TestCancelHoldInvoice(t *testing.T) {
// Instantiate and start the invoice ctx.registry. // Instantiate and start the invoice ctx.registry.
cfg := RegistryConfig{ cfg := RegistryConfig{
FinalCltvRejectDelta: testFinalCltvRejectDelta, FinalCltvRejectDelta: testFinalCltvRejectDelta,
Clock: clock.NewTestClock(testTime),
} }
registry := NewRegistry(cdb, &cfg) registry := NewRegistry(cdb, NewInvoiceExpiryWatcher(cfg.Clock), &cfg)
err = registry.Start() err = registry.Start()
if err != nil { if err != nil {
@ -585,7 +486,7 @@ func TestCancelHoldInvoice(t *testing.T) {
defer registry.Stop() defer registry.Stop()
// Add the invoice. // Add the invoice.
_, err = registry.AddInvoice(testHodlInvoice, hash) _, err = registry.AddInvoice(testHodlInvoice, testInvoicePaymentHash)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -596,7 +497,7 @@ func TestCancelHoldInvoice(t *testing.T) {
// NotifyExitHopHtlc without a preimage present in the invoice registry // NotifyExitHopHtlc without a preimage present in the invoice registry
// should be possible. // should be possible.
event, err := registry.NotifyExitHopHtlc( event, err := registry.NotifyExitHopHtlc(
hash, amtPaid, testHtlcExpiry, testCurrentHeight, testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight,
getCircuitKey(0), hodlChan, testPayload, getCircuitKey(0), hodlChan, testPayload,
) )
if err != nil { if err != nil {
@ -607,7 +508,7 @@ func TestCancelHoldInvoice(t *testing.T) {
} }
// Cancel invoice. // Cancel invoice.
err = registry.CancelInvoice(hash) err = registry.CancelInvoice(testInvoicePaymentHash)
if err != nil { if err != nil {
t.Fatal("cancel invoice failed") t.Fatal("cancel invoice failed")
} }
@ -621,7 +522,7 @@ func TestCancelHoldInvoice(t *testing.T) {
// in a rejection. The accept height is expected to be the original // in a rejection. The accept height is expected to be the original
// accept height. // accept height.
event, err = registry.NotifyExitHopHtlc( event, err = registry.NotifyExitHopHtlc(
hash, amtPaid, testHtlcExpiry, testCurrentHeight+1, testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight+1,
getCircuitKey(0), hodlChan, testPayload, getCircuitKey(0), hodlChan, testPayload,
) )
if err != nil { if err != nil {
@ -636,29 +537,6 @@ func TestCancelHoldInvoice(t *testing.T) {
} }
} }
func newDB() (*channeldb.DB, func(), error) {
// First, create a temporary directory to be used for the duration of
// this test.
tempDirName, err := ioutil.TempDir("", "channeldb")
if err != nil {
return nil, nil, err
}
// Next, create channeldb for the first time.
cdb, err := channeldb.Open(tempDirName)
if err != nil {
os.RemoveAll(tempDirName)
return nil, nil, err
}
cleanUp := func() {
cdb.Close()
os.RemoveAll(tempDirName)
}
return cdb, cleanUp, nil
}
// TestUnknownInvoice tests that invoice registry returns an error when the // TestUnknownInvoice tests that invoice registry returns an error when the
// invoice is unknown. This is to guard against returning a cancel hodl event // invoice is unknown. This is to guard against returning a cancel hodl event
// for forwarded htlcs. In the link, NotifyExitHopHtlc is only called if we are // for forwarded htlcs. In the link, NotifyExitHopHtlc is only called if we are
@ -673,7 +551,7 @@ func TestUnknownInvoice(t *testing.T) {
hodlChan := make(chan interface{}) hodlChan := make(chan interface{})
amt := lnwire.MilliSatoshi(100000) amt := lnwire.MilliSatoshi(100000)
_, err := ctx.registry.NotifyExitHopHtlc( _, err := ctx.registry.NotifyExitHopHtlc(
hash, amt, testHtlcExpiry, testCurrentHeight, testInvoicePaymentHash, amt, testHtlcExpiry, testCurrentHeight,
getCircuitKey(0), hodlChan, testPayload, getCircuitKey(0), hodlChan, testPayload,
) )
if err != channeldb.ErrInvoiceNotFound { if err != channeldb.ErrInvoiceNotFound {
@ -681,27 +559,15 @@ func TestUnknownInvoice(t *testing.T) {
} }
} }
type mockPayload struct {
mpp *record.MPP
}
func (p *mockPayload) MultiPath() *record.MPP {
return p.mpp
}
func (p *mockPayload) CustomRecords() record.CustomSet {
return make(record.CustomSet)
}
// TestSettleMpp tests settling of an invoice with multiple partial payments. // TestSettleMpp tests settling of an invoice with multiple partial payments.
func TestSettleMpp(t *testing.T) { func TestSettleMpp(t *testing.T) {
defer timeout(t)() defer timeout()()
ctx := newTestContext(t) ctx := newTestContext(t)
defer ctx.cleanup() defer ctx.cleanup()
// Add the invoice. // Add the invoice.
_, err := ctx.registry.AddInvoice(testInvoice, hash) _, err := ctx.registry.AddInvoice(testInvoice, testInvoicePaymentHash)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -713,7 +579,7 @@ func TestSettleMpp(t *testing.T) {
// Send htlc 1. // Send htlc 1.
hodlChan1 := make(chan interface{}, 1) hodlChan1 := make(chan interface{}, 1)
event, err := ctx.registry.NotifyExitHopHtlc( event, err := ctx.registry.NotifyExitHopHtlc(
hash, testInvoice.Terms.Value/2, testInvoicePaymentHash, testInvoice.Terms.Value/2,
testHtlcExpiry, testHtlcExpiry,
testCurrentHeight, getCircuitKey(10), hodlChan1, mppPayload, testCurrentHeight, getCircuitKey(10), hodlChan1, mppPayload,
) )
@ -725,7 +591,7 @@ func TestSettleMpp(t *testing.T) {
} }
// Simulate mpp timeout releasing htlc 1. // Simulate mpp timeout releasing htlc 1.
ctx.clock.setTime(testTime.Add(30 * time.Second)) ctx.clock.SetTime(testTime.Add(30 * time.Second))
hodlEvent := (<-hodlChan1).(HodlEvent) hodlEvent := (<-hodlChan1).(HodlEvent)
if hodlEvent.Preimage != nil { if hodlEvent.Preimage != nil {
@ -735,7 +601,7 @@ func TestSettleMpp(t *testing.T) {
// Send htlc 2. // Send htlc 2.
hodlChan2 := make(chan interface{}, 1) hodlChan2 := make(chan interface{}, 1)
event, err = ctx.registry.NotifyExitHopHtlc( event, err = ctx.registry.NotifyExitHopHtlc(
hash, testInvoice.Terms.Value/2, testInvoicePaymentHash, testInvoice.Terms.Value/2,
testHtlcExpiry, testHtlcExpiry,
testCurrentHeight, getCircuitKey(11), hodlChan2, mppPayload, testCurrentHeight, getCircuitKey(11), hodlChan2, mppPayload,
) )
@ -749,7 +615,7 @@ func TestSettleMpp(t *testing.T) {
// Send htlc 3. // Send htlc 3.
hodlChan3 := make(chan interface{}, 1) hodlChan3 := make(chan interface{}, 1)
event, err = ctx.registry.NotifyExitHopHtlc( event, err = ctx.registry.NotifyExitHopHtlc(
hash, testInvoice.Terms.Value/2, testInvoicePaymentHash, testInvoice.Terms.Value/2,
testHtlcExpiry, testHtlcExpiry,
testCurrentHeight, getCircuitKey(12), hodlChan3, mppPayload, testCurrentHeight, getCircuitKey(12), hodlChan3, mppPayload,
) )
@ -762,7 +628,7 @@ func TestSettleMpp(t *testing.T) {
// Check that settled amount is equal to the sum of values of the htlcs // Check that settled amount is equal to the sum of values of the htlcs
// 0 and 1. // 0 and 1.
inv, err := ctx.registry.LookupInvoice(hash) inv, err := ctx.registry.LookupInvoice(testInvoicePaymentHash)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -774,3 +640,105 @@ func TestSettleMpp(t *testing.T) {
testInvoice.Terms.Value, inv.AmtPaid) testInvoice.Terms.Value, inv.AmtPaid)
} }
} }
// Tests that invoices are canceled after expiration.
func TestInvoiceExpiryWithRegistry(t *testing.T) {
t.Parallel()
cdb, cleanup, err := newTestChannelDB()
defer cleanup()
if err != nil {
t.Fatal(err)
}
testClock := clock.NewTestClock(testTime)
cfg := RegistryConfig{
FinalCltvRejectDelta: testFinalCltvRejectDelta,
Clock: testClock,
}
expiryWatcher := NewInvoiceExpiryWatcher(cfg.Clock)
registry := NewRegistry(cdb, expiryWatcher, &cfg)
// First prefill the Channel DB with some pre-existing invoices,
// half of them still pending, half of them expired.
const numExpired = 5
const numPending = 5
existingInvoices := generateInvoiceExpiryTestData(
t, testTime, 0, numExpired, numPending,
)
var expectedCancellations []lntypes.Hash
for paymentHash, expiredInvoice := range existingInvoices.expiredInvoices {
if _, err := cdb.AddInvoice(expiredInvoice, paymentHash); err != nil {
t.Fatalf("cannot add invoice to channel db: %v", err)
}
expectedCancellations = append(expectedCancellations, paymentHash)
}
for paymentHash, pendingInvoice := range existingInvoices.pendingInvoices {
if _, err := cdb.AddInvoice(pendingInvoice, paymentHash); err != nil {
t.Fatalf("cannot add invoice to channel db: %v", err)
}
}
if err = registry.Start(); err != nil {
t.Fatalf("cannot start registry: %v", err)
}
// Now generate pending and invoices and add them to the registry while
// it is up and running. We'll manipulate the clock to let them expire.
newInvoices := generateInvoiceExpiryTestData(
t, testTime, numExpired+numPending, 0, numPending,
)
var invoicesThatWillCancel []lntypes.Hash
for paymentHash, pendingInvoice := range newInvoices.pendingInvoices {
_, err := registry.AddInvoice(pendingInvoice, paymentHash)
invoicesThatWillCancel = append(invoicesThatWillCancel, paymentHash)
if err != nil {
t.Fatal(err)
}
}
// Check that they are really not canceled until before the clock is
// advanced.
for i := range invoicesThatWillCancel {
invoice, err := registry.LookupInvoice(invoicesThatWillCancel[i])
if err != nil {
t.Fatalf("cannot find invoice: %v", err)
}
if invoice.State == channeldb.ContractCanceled {
t.Fatalf("expected pending invoice, got canceled")
}
}
// Fwd time 1 day.
testClock.SetTime(testTime.Add(24 * time.Hour))
// Give some time to the watcher to cancel everything.
time.Sleep(testTimeout)
registry.Stop()
// Create the expected cancellation set before the final check.
expectedCancellations = append(
expectedCancellations, invoicesThatWillCancel...,
)
// Retrospectively check that all invoices that were expected to be canceled
// are indeed canceled.
for i := range expectedCancellations {
invoice, err := registry.LookupInvoice(expectedCancellations[i])
if err != nil {
t.Fatalf("cannot find invoice: %v", err)
}
if invoice.State != channeldb.ContractCanceled {
t.Fatalf("expected canceled invoice, got: %v", invoice.State)
}
}
}

280
invoices/test_utils_test.go Normal file

@ -0,0 +1,280 @@
package invoices
import (
"encoding/binary"
"encoding/hex"
"fmt"
"io/ioutil"
"os"
"runtime/pprof"
"testing"
"time"
"github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/chaincfg"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/zpay32"
)
type mockPayload struct {
mpp *record.MPP
}
func (p *mockPayload) MultiPath() *record.MPP {
return p.mpp
}
func (p *mockPayload) CustomRecords() record.CustomSet {
return make(record.CustomSet)
}
var (
testTimeout = 5 * time.Second
testTime = time.Date(2018, time.February, 2, 14, 0, 0, 0, time.UTC)
testInvoicePreimage = lntypes.Preimage{1}
testInvoicePaymentHash = testInvoicePreimage.Hash()
testHtlcExpiry = uint32(5)
testInvoiceCltvDelta = uint32(4)
testFinalCltvRejectDelta = int32(4)
testCurrentHeight = int32(1)
testPrivKeyBytes, _ = hex.DecodeString(
"e126f68f7eafcc8b74f54d269fe206be715000f94dac067d1c04a8ca3b2db734")
testPrivKey, _ = btcec.PrivKeyFromBytes(
btcec.S256(), testPrivKeyBytes)
testInvoiceDescription = "coffee"
testInvoiceAmount = lnwire.MilliSatoshi(100000)
testNetParams = &chaincfg.MainNetParams
testMessageSigner = zpay32.MessageSigner{
SignCompact: func(hash []byte) ([]byte, error) {
sig, err := btcec.SignCompact(btcec.S256(), testPrivKey, hash, true)
if err != nil {
return nil, fmt.Errorf("can't sign the message: %v", err)
}
return sig, nil
},
}
testFeatures = lnwire.NewFeatureVector(
nil, lnwire.Features,
)
testPayload = &mockPayload{}
testInvoiceCreationDate = testTime
)
var (
testInvoiceAmt = lnwire.MilliSatoshi(100000)
testInvoice = &channeldb.Invoice{
Terms: channeldb.ContractTerm{
PaymentPreimage: testInvoicePreimage,
Value: testInvoiceAmt,
Expiry: time.Hour,
Features: testFeatures,
},
CreationDate: testInvoiceCreationDate,
}
testHodlInvoice = &channeldb.Invoice{
Terms: channeldb.ContractTerm{
PaymentPreimage: channeldb.UnknownPreimage,
Value: testInvoiceAmt,
Expiry: time.Hour,
Features: testFeatures,
},
CreationDate: testInvoiceCreationDate,
}
)
func newTestChannelDB() (*channeldb.DB, func(), error) {
// First, create a temporary directory to be used for the duration of
// this test.
tempDirName, err := ioutil.TempDir("", "channeldb")
if err != nil {
return nil, nil, err
}
// Next, create channeldb for the first time.
cdb, err := channeldb.Open(tempDirName)
if err != nil {
os.RemoveAll(tempDirName)
return nil, nil, err
}
cleanUp := func() {
cdb.Close()
os.RemoveAll(tempDirName)
}
return cdb, cleanUp, nil
}
type testContext struct {
cdb *channeldb.DB
registry *InvoiceRegistry
clock *clock.TestClock
cleanup func()
t *testing.T
}
func newTestContext(t *testing.T) *testContext {
clock := clock.NewTestClock(testTime)
cdb, cleanup, err := newTestChannelDB()
if err != nil {
t.Fatal(err)
}
cdb.Now = clock.Now
expiryWatcher := NewInvoiceExpiryWatcher(clock)
// Instantiate and start the invoice ctx.registry.
cfg := RegistryConfig{
FinalCltvRejectDelta: testFinalCltvRejectDelta,
HtlcHoldDuration: 30 * time.Second,
Clock: clock,
}
registry := NewRegistry(cdb, expiryWatcher, &cfg)
err = registry.Start()
if err != nil {
cleanup()
t.Fatal(err)
}
ctx := testContext{
cdb: cdb,
registry: registry,
clock: clock,
t: t,
cleanup: func() {
registry.Stop()
cleanup()
},
}
return &ctx
}
func getCircuitKey(htlcID uint64) channeldb.CircuitKey {
return channeldb.CircuitKey{
ChanID: lnwire.ShortChannelID{
BlockHeight: 1, TxIndex: 2, TxPosition: 3,
},
HtlcID: htlcID,
}
}
func newTestInvoice(t *testing.T, preimage lntypes.Preimage,
timestamp time.Time, expiry time.Duration) *channeldb.Invoice {
if expiry == 0 {
expiry = time.Hour
}
rawInvoice, err := zpay32.NewInvoice(
testNetParams,
preimage.Hash(),
timestamp,
zpay32.Amount(testInvoiceAmount),
zpay32.Description(testInvoiceDescription),
zpay32.Expiry(expiry))
if err != nil {
t.Fatalf("Error while creating new invoice: %v", err)
}
paymentRequest, err := rawInvoice.Encode(testMessageSigner)
if err != nil {
t.Fatalf("Error while encoding payment request: %v", err)
}
return &channeldb.Invoice{
Terms: channeldb.ContractTerm{
PaymentPreimage: preimage,
Value: testInvoiceAmount,
Expiry: expiry,
Features: testFeatures,
},
PaymentRequest: []byte(paymentRequest),
CreationDate: timestamp,
}
}
// timeout implements a test level timeout.
func timeout() func() {
done := make(chan struct{})
go func() {
select {
case <-time.After(5 * time.Second):
err := pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
if err != nil {
panic(fmt.Sprintf("error writing to std out after timeout: %v", err))
}
panic("timeout")
case <-done:
}
}()
return func() {
close(done)
}
}
// invoiceExpiryTestData simply holds generated expired and pending invoices.
type invoiceExpiryTestData struct {
expiredInvoices map[lntypes.Hash]*channeldb.Invoice
pendingInvoices map[lntypes.Hash]*channeldb.Invoice
}
// generateInvoiceExpiryTestData generates the specified number of fake expired
// and pending invoices anchored to the passed now timestamp.
func generateInvoiceExpiryTestData(
t *testing.T, now time.Time,
offset, numExpired, numPending int) invoiceExpiryTestData {
var testData invoiceExpiryTestData
testData.expiredInvoices = make(map[lntypes.Hash]*channeldb.Invoice)
testData.pendingInvoices = make(map[lntypes.Hash]*channeldb.Invoice)
expiredCreationDate := now.Add(-24 * time.Hour)
for i := 1; i <= numExpired; i++ {
var preimage lntypes.Preimage
binary.BigEndian.PutUint32(preimage[:4], uint32(offset+i))
expiry := time.Duration((i+offset)%24) * time.Hour
invoice := newTestInvoice(t, preimage, expiredCreationDate, expiry)
testData.expiredInvoices[preimage.Hash()] = invoice
}
for i := 1; i <= numPending; i++ {
var preimage lntypes.Preimage
binary.BigEndian.PutUint32(preimage[4:], uint32(offset+i))
expiry := time.Duration((i+offset)%24) * time.Hour
invoice := newTestInvoice(t, preimage, now, expiry)
testData.pendingInvoices[preimage.Hash()] = invoice
}
return testData
}

@ -1,26 +0,0 @@
package invoices
import (
"os"
"runtime/pprof"
"testing"
"time"
)
// timeout implements a test level timeout.
func timeout(t *testing.T) func() {
done := make(chan struct{})
go func() {
select {
case <-time.After(5 * time.Second):
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
panic("test timeout")
case <-done:
}
}()
return func() {
close(done)
}
}

@ -9,6 +9,8 @@ import (
// PriorityQueue will be able to use that to build and restore an underlying // PriorityQueue will be able to use that to build and restore an underlying
// heap. // heap.
type PriorityQueueItem interface { type PriorityQueueItem interface {
// Less must return true if this item is ordered before other and false
// otherwise.
Less(other PriorityQueueItem) bool Less(other PriorityQueueItem) bool
} }
@ -43,7 +45,7 @@ func (pq *priorityQueue) Pop() interface{} {
return item return item
} }
// Priority wrap a standard heap in a more object-oriented structure. // PriorityQueue wraps a standard heap into a self contained class.
type PriorityQueue struct { type PriorityQueue struct {
queue priorityQueue queue priorityQueue
} }

@ -33,6 +33,7 @@ import (
"github.com/lightningnetwork/lnd/chanfitness" "github.com/lightningnetwork/lnd/chanfitness"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/channelnotifier"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/discovery" "github.com/lightningnetwork/lnd/discovery"
"github.com/lightningnetwork/lnd/feature" "github.com/lightningnetwork/lnd/feature"
@ -381,8 +382,7 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB,
registryConfig := invoices.RegistryConfig{ registryConfig := invoices.RegistryConfig{
FinalCltvRejectDelta: defaultFinalCltvRejectDelta, FinalCltvRejectDelta: defaultFinalCltvRejectDelta,
HtlcHoldDuration: invoices.DefaultHtlcHoldDuration, HtlcHoldDuration: invoices.DefaultHtlcHoldDuration,
Now: time.Now, Clock: clock.NewDefaultClock(),
TickAfter: time.After,
} }
s := &server{ s := &server{
@ -393,7 +393,10 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB,
readPool: readPool, readPool: readPool,
chansToRestore: chansToRestore, chansToRestore: chansToRestore,
invoices: invoices.NewRegistry(chanDB, &registryConfig), invoices: invoices.NewRegistry(
chanDB, invoices.NewInvoiceExpiryWatcher(clock.NewDefaultClock()),
&registryConfig,
),
channelNotifier: channelnotifier.New(chanDB), channelNotifier: channelnotifier.New(chanDB),

@ -82,6 +82,10 @@ const (
// This is chosen to be the maximum number of bytes that can fit into a // This is chosen to be the maximum number of bytes that can fit into a
// single QR code: https://en.wikipedia.org/wiki/QR_code#Storage // single QR code: https://en.wikipedia.org/wiki/QR_code#Storage
maxInvoiceLength = 7089 maxInvoiceLength = 7089
// DefaultInvoiceExpiry is the default expiry duration from the creation
// timestamp if expiry is set to zero.
DefaultInvoiceExpiry = time.Hour
) )
var ( var (