Merge pull request #3694 from bhandras/i3448
invoices+channeldb: reject payments to expired invoices
This commit is contained in:
commit
e34bc3d645
@ -2,6 +2,7 @@ package channeldb
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
mrand "math/rand"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -393,6 +394,114 @@ func TestInvoiceAddTimeSeries(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tests that FetchAllInvoicesWithPaymentHash returns all invoices with their
|
||||||
|
// corresponding payment hashes.
|
||||||
|
func TestFetchAllInvoicesWithPaymentHash(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
db, cleanup, err := makeTestDB()
|
||||||
|
defer cleanup()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to make test db: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// With an empty DB we expect to return no error and an empty list.
|
||||||
|
empty, err := db.FetchAllInvoicesWithPaymentHash(false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to call FetchAllInvoicesWithPaymentHash on empty DB: %v",
|
||||||
|
err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(empty) != 0 {
|
||||||
|
t.Fatalf("expected empty list as a result, got: %v", empty)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now populate the DB and check if we can get all invoices with their
|
||||||
|
// payment hashes as expected.
|
||||||
|
const numInvoices = 20
|
||||||
|
testPendingInvoices := make(map[lntypes.Hash]*Invoice)
|
||||||
|
testAllInvoices := make(map[lntypes.Hash]*Invoice)
|
||||||
|
|
||||||
|
states := []ContractState{
|
||||||
|
ContractOpen, ContractSettled, ContractCanceled, ContractAccepted,
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := lnwire.MilliSatoshi(1); i <= numInvoices; i++ {
|
||||||
|
invoice, err := randInvoice(i)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to create invoice: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
invoice.State = states[mrand.Intn(len(states))]
|
||||||
|
paymentHash := invoice.Terms.PaymentPreimage.Hash()
|
||||||
|
|
||||||
|
if invoice.State != ContractSettled && invoice.State != ContractCanceled {
|
||||||
|
testPendingInvoices[paymentHash] = invoice
|
||||||
|
}
|
||||||
|
|
||||||
|
testAllInvoices[paymentHash] = invoice
|
||||||
|
|
||||||
|
if _, err := db.AddInvoice(invoice, paymentHash); err != nil {
|
||||||
|
t.Fatalf("unable to add invoice: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pendingInvoices, err := db.FetchAllInvoicesWithPaymentHash(true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("can't fetch invoices with payment hash: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(testPendingInvoices) != len(pendingInvoices) {
|
||||||
|
t.Fatalf("expected %v pending invoices, got: %v",
|
||||||
|
len(testPendingInvoices), len(pendingInvoices))
|
||||||
|
}
|
||||||
|
|
||||||
|
allInvoices, err := db.FetchAllInvoicesWithPaymentHash(false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("can't fetch invoices with payment hash: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(testAllInvoices) != len(allInvoices) {
|
||||||
|
t.Fatalf("expected %v invoices, got: %v",
|
||||||
|
len(testAllInvoices), len(allInvoices))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range pendingInvoices {
|
||||||
|
expected, ok := testPendingInvoices[pendingInvoices[i].PaymentHash]
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("coulnd't find invoice with hash: %v",
|
||||||
|
pendingInvoices[i].PaymentHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Zero out add index to not confuse DeepEqual.
|
||||||
|
pendingInvoices[i].Invoice.AddIndex = 0
|
||||||
|
expected.AddIndex = 0
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(*expected, pendingInvoices[i].Invoice) {
|
||||||
|
t.Fatalf("expected: %v, got: %v",
|
||||||
|
spew.Sdump(expected), spew.Sdump(pendingInvoices[i].Invoice))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range allInvoices {
|
||||||
|
expected, ok := testAllInvoices[allInvoices[i].PaymentHash]
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("coulnd't find invoice with hash: %v",
|
||||||
|
allInvoices[i].PaymentHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Zero out add index to not confuse DeepEqual.
|
||||||
|
allInvoices[i].Invoice.AddIndex = 0
|
||||||
|
expected.AddIndex = 0
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(*expected, allInvoices[i].Invoice) {
|
||||||
|
t.Fatalf("expected: %v, got: %v",
|
||||||
|
spew.Sdump(expected), spew.Sdump(allInvoices[i].Invoice))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
// TestDuplicateSettleInvoice tests that if we add a new invoice and settle it
|
// TestDuplicateSettleInvoice tests that if we add a new invoice and settle it
|
||||||
// twice, then the second time we also receive the invoice that we settled as a
|
// twice, then the second time we also receive the invoice that we settled as a
|
||||||
// return argument.
|
// return argument.
|
||||||
|
@ -565,6 +565,83 @@ func (d *DB) LookupInvoice(paymentHash [32]byte) (Invoice, error) {
|
|||||||
return invoice, nil
|
return invoice, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InvoiceWithPaymentHash is used to store an invoice and its corresponding
|
||||||
|
// payment hash. This struct is only used to store results of
|
||||||
|
// ChannelDB.FetchAllInvoicesWithPaymentHash() call.
|
||||||
|
type InvoiceWithPaymentHash struct {
|
||||||
|
// Invoice holds the invoice as selected from the invoices bucket.
|
||||||
|
Invoice Invoice
|
||||||
|
|
||||||
|
// PaymentHash is the payment hash for the Invoice.
|
||||||
|
PaymentHash lntypes.Hash
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchAllInvoicesWithPaymentHash returns all invoices and their payment hashes
|
||||||
|
// currently stored within the database. If the pendingOnly param is true, then
|
||||||
|
// only unsettled invoices and their payment hashes will be returned, skipping
|
||||||
|
// all invoices that are fully settled or canceled. Note that the returned
|
||||||
|
// array is not ordered by add index.
|
||||||
|
func (d *DB) FetchAllInvoicesWithPaymentHash(pendingOnly bool) (
|
||||||
|
[]InvoiceWithPaymentHash, error) {
|
||||||
|
|
||||||
|
var result []InvoiceWithPaymentHash
|
||||||
|
|
||||||
|
err := d.View(func(tx *bbolt.Tx) error {
|
||||||
|
invoices := tx.Bucket(invoiceBucket)
|
||||||
|
if invoices == nil {
|
||||||
|
return ErrNoInvoicesCreated
|
||||||
|
}
|
||||||
|
|
||||||
|
invoiceIndex := invoices.Bucket(invoiceIndexBucket)
|
||||||
|
if invoiceIndex == nil {
|
||||||
|
// Mask the error if there's no invoice
|
||||||
|
// index as that simply means there are no
|
||||||
|
// invoices added yet to the DB. In this case
|
||||||
|
// we simply return an empty list.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return invoiceIndex.ForEach(func(k, v []byte) error {
|
||||||
|
// Skip the special numInvoicesKey as that does not
|
||||||
|
// point to a valid invoice.
|
||||||
|
if bytes.Equal(k, numInvoicesKey) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if v == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
invoice, err := fetchInvoice(v, invoices)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if pendingOnly &&
|
||||||
|
(invoice.State == ContractSettled ||
|
||||||
|
invoice.State == ContractCanceled) {
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
invoiceWithPaymentHash := InvoiceWithPaymentHash{
|
||||||
|
Invoice: invoice,
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(invoiceWithPaymentHash.PaymentHash[:], k)
|
||||||
|
result = append(result, invoiceWithPaymentHash)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
// FetchAllInvoices returns all invoices currently stored within the database.
|
// FetchAllInvoices returns all invoices currently stored within the database.
|
||||||
// If the pendingOnly param is true, then only unsettled invoices will be
|
// If the pendingOnly param is true, then only unsettled invoices will be
|
||||||
// returned, skipping all invoices that are fully settled.
|
// returned, skipping all invoices that are fully settled.
|
||||||
|
24
clock/default_clock.go
Normal file
24
clock/default_clock.go
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
package clock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultClock implements Clock interface by simply calling the appropriate
|
||||||
|
// time functions.
|
||||||
|
type DefaultClock struct{}
|
||||||
|
|
||||||
|
// NewDefaultClock constructs a new DefaultClock.
|
||||||
|
func NewDefaultClock() Clock {
|
||||||
|
return &DefaultClock{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now simply returns time.Now().
|
||||||
|
func (DefaultClock) Now() time.Time {
|
||||||
|
return time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TickAfter simply wraps time.After().
|
||||||
|
func (DefaultClock) TickAfter(duration time.Duration) <-chan time.Time {
|
||||||
|
return time.After(duration)
|
||||||
|
}
|
16
clock/interface.go
Normal file
16
clock/interface.go
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
package clock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Clock is an interface that provides a time functions for LND packages.
|
||||||
|
// This is useful during testing when a concrete time reference is needed.
|
||||||
|
type Clock interface {
|
||||||
|
// Now returns the current local time (as defined by the Clock).
|
||||||
|
Now() time.Time
|
||||||
|
|
||||||
|
// TickAfter returns a channel that will receive a tick after the specified
|
||||||
|
// duration has passed.
|
||||||
|
TickAfter(duration time.Duration) <-chan time.Time
|
||||||
|
}
|
@ -1,42 +1,40 @@
|
|||||||
package invoices
|
package clock
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// testClock can be used in tests to mock time.
|
// TestClock can be used in tests to mock time.
|
||||||
type testClock struct {
|
type TestClock struct {
|
||||||
currentTime time.Time
|
currentTime time.Time
|
||||||
timeChanMap map[time.Time][]chan time.Time
|
timeChanMap map[time.Time][]chan time.Time
|
||||||
timeLock sync.Mutex
|
timeLock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// newTestClock returns a new test clock.
|
// NewTestClock returns a new test clock.
|
||||||
func newTestClock(startTime time.Time) *testClock {
|
func NewTestClock(startTime time.Time) *TestClock {
|
||||||
return &testClock{
|
return &TestClock{
|
||||||
currentTime: startTime,
|
currentTime: startTime,
|
||||||
timeChanMap: make(map[time.Time][]chan time.Time),
|
timeChanMap: make(map[time.Time][]chan time.Time),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// now returns the current (test) time.
|
// Now returns the current (test) time.
|
||||||
func (c *testClock) now() time.Time {
|
func (c *TestClock) Now() time.Time {
|
||||||
c.timeLock.Lock()
|
c.timeLock.Lock()
|
||||||
defer c.timeLock.Unlock()
|
defer c.timeLock.Unlock()
|
||||||
|
|
||||||
return c.currentTime
|
return c.currentTime
|
||||||
}
|
}
|
||||||
|
|
||||||
// tickAfter returns a channel that will receive a tick at the specified time.
|
// TickAfter returns a channel that will receive a tick after the specified
|
||||||
func (c *testClock) tickAfter(duration time.Duration) <-chan time.Time {
|
// duration has passed passed by the user set test time.
|
||||||
|
func (c *TestClock) TickAfter(duration time.Duration) <-chan time.Time {
|
||||||
c.timeLock.Lock()
|
c.timeLock.Lock()
|
||||||
defer c.timeLock.Unlock()
|
defer c.timeLock.Unlock()
|
||||||
|
|
||||||
triggerTime := c.currentTime.Add(duration)
|
triggerTime := c.currentTime.Add(duration)
|
||||||
log.Debugf("tickAfter called: duration=%v, trigger_time=%v",
|
|
||||||
duration, triggerTime)
|
|
||||||
|
|
||||||
ch := make(chan time.Time, 1)
|
ch := make(chan time.Time, 1)
|
||||||
|
|
||||||
// If already expired, tick immediately.
|
// If already expired, tick immediately.
|
||||||
@ -53,8 +51,8 @@ func (c *testClock) tickAfter(duration time.Duration) <-chan time.Time {
|
|||||||
return ch
|
return ch
|
||||||
}
|
}
|
||||||
|
|
||||||
// setTime sets the (test) time and triggers tick channels when they expire.
|
// SetTime sets the (test) time and triggers tick channels when they expire.
|
||||||
func (c *testClock) setTime(now time.Time) {
|
func (c *TestClock) SetTime(now time.Time) {
|
||||||
c.timeLock.Lock()
|
c.timeLock.Lock()
|
||||||
defer c.timeLock.Unlock()
|
defer c.timeLock.Unlock()
|
||||||
|
|
63
clock/test_clock_test.go
Normal file
63
clock/test_clock_test.go
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
package clock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
testTime = time.Date(2009, time.January, 3, 12, 0, 0, 0, time.UTC)
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNow(t *testing.T) {
|
||||||
|
c := NewTestClock(testTime)
|
||||||
|
now := c.Now()
|
||||||
|
|
||||||
|
if now != testTime {
|
||||||
|
t.Fatalf("expected: %v, got: %v", testTime, now)
|
||||||
|
}
|
||||||
|
|
||||||
|
now = now.Add(time.Hour)
|
||||||
|
c.SetTime(now)
|
||||||
|
if c.Now() != now {
|
||||||
|
t.Fatalf("epected: %v, got: %v", now, c.Now())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTickAfter(t *testing.T) {
|
||||||
|
c := NewTestClock(testTime)
|
||||||
|
|
||||||
|
// Should be ticking immediately.
|
||||||
|
ticker0 := c.TickAfter(0)
|
||||||
|
|
||||||
|
// Both should be ticking after SetTime
|
||||||
|
ticker1 := c.TickAfter(time.Hour)
|
||||||
|
ticker2 := c.TickAfter(time.Hour)
|
||||||
|
|
||||||
|
// We don't expect this one to tick.
|
||||||
|
ticker3 := c.TickAfter(2 * time.Hour)
|
||||||
|
|
||||||
|
tickOrTimeOut := func(ticker <-chan time.Time, expectTick bool) {
|
||||||
|
tick := false
|
||||||
|
select {
|
||||||
|
case <-ticker:
|
||||||
|
tick = true
|
||||||
|
case <-time.After(time.Millisecond):
|
||||||
|
}
|
||||||
|
|
||||||
|
if tick != expectTick {
|
||||||
|
t.Fatalf("expected tick: %v, ticked: %v", expectTick, tick)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tickOrTimeOut(ticker0, true)
|
||||||
|
tickOrTimeOut(ticker1, false)
|
||||||
|
tickOrTimeOut(ticker2, false)
|
||||||
|
tickOrTimeOut(ticker3, false)
|
||||||
|
|
||||||
|
c.SetTime(c.Now().Add(time.Hour))
|
||||||
|
|
||||||
|
tickOrTimeOut(ticker1, true)
|
||||||
|
tickOrTimeOut(ticker2, true)
|
||||||
|
tickOrTimeOut(ticker3, false)
|
||||||
|
}
|
@ -22,6 +22,7 @@ import (
|
|||||||
sphinx "github.com/lightningnetwork/lightning-onion"
|
sphinx "github.com/lightningnetwork/lightning-onion"
|
||||||
"github.com/lightningnetwork/lnd/chainntnfs"
|
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||||
"github.com/lightningnetwork/lnd/channeldb"
|
"github.com/lightningnetwork/lnd/channeldb"
|
||||||
|
"github.com/lightningnetwork/lnd/clock"
|
||||||
"github.com/lightningnetwork/lnd/contractcourt"
|
"github.com/lightningnetwork/lnd/contractcourt"
|
||||||
"github.com/lightningnetwork/lnd/htlcswitch/hop"
|
"github.com/lightningnetwork/lnd/htlcswitch/hop"
|
||||||
"github.com/lightningnetwork/lnd/input"
|
"github.com/lightningnetwork/lnd/input"
|
||||||
@ -792,6 +793,7 @@ func newMockRegistry(minDelta uint32) *mockInvoiceRegistry {
|
|||||||
|
|
||||||
registry := invoices.NewRegistry(
|
registry := invoices.NewRegistry(
|
||||||
cdb,
|
cdb,
|
||||||
|
invoices.NewInvoiceExpiryWatcher(clock.NewDefaultClock()),
|
||||||
&invoices.RegistryConfig{
|
&invoices.RegistryConfig{
|
||||||
FinalCltvRejectDelta: 5,
|
FinalCltvRejectDelta: 5,
|
||||||
},
|
},
|
||||||
|
191
invoices/invoice_expiry_watcher.go
Normal file
191
invoices/invoice_expiry_watcher.go
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
package invoices
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/channeldb"
|
||||||
|
"github.com/lightningnetwork/lnd/clock"
|
||||||
|
"github.com/lightningnetwork/lnd/lntypes"
|
||||||
|
"github.com/lightningnetwork/lnd/queue"
|
||||||
|
"github.com/lightningnetwork/lnd/zpay32"
|
||||||
|
)
|
||||||
|
|
||||||
|
// invoiceExpiry holds and invoice's payment hash and its expiry. This
|
||||||
|
// is used to order invoices by their expiry for cancellation.
|
||||||
|
type invoiceExpiry struct {
|
||||||
|
PaymentHash lntypes.Hash
|
||||||
|
Expiry time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// Less implements PriorityQueueItem.Less such that the top item in the
|
||||||
|
// priorty queue will be the one that expires next.
|
||||||
|
func (e invoiceExpiry) Less(other queue.PriorityQueueItem) bool {
|
||||||
|
return e.Expiry.Before(other.(*invoiceExpiry).Expiry)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvoiceExpiryWatcher handles automatic invoice cancellation of expried
|
||||||
|
// invoices. Upon start InvoiceExpiryWatcher will retrieve all pending (not yet
|
||||||
|
// settled or canceled) invoices invoices to its watcing queue. When a new
|
||||||
|
// invoice is added to the InvoiceRegistry, it'll be forarded to the
|
||||||
|
// InvoiceExpiryWatcher and will end up in the watching queue as well.
|
||||||
|
// If any of the watched invoices expire, they'll be removed from the watching
|
||||||
|
// queue and will be cancelled through InvoiceRegistry.CancelInvoice().
|
||||||
|
type InvoiceExpiryWatcher struct {
|
||||||
|
sync.Mutex
|
||||||
|
started bool
|
||||||
|
|
||||||
|
// clock is the clock implementation that InvoiceExpiryWatcher uses.
|
||||||
|
// It is useful for testing.
|
||||||
|
clock clock.Clock
|
||||||
|
|
||||||
|
// cancelInvoice is a template method that cancels an expired invoice.
|
||||||
|
cancelInvoice func(lntypes.Hash) error
|
||||||
|
|
||||||
|
// expiryQueue holds invoiceExpiry items and is used to find the next
|
||||||
|
// invoice to expire.
|
||||||
|
expiryQueue queue.PriorityQueue
|
||||||
|
|
||||||
|
// newInvoices channel is used to wake up the main loop when a new invoices
|
||||||
|
// is added.
|
||||||
|
newInvoices chan *invoiceExpiry
|
||||||
|
|
||||||
|
wg sync.WaitGroup
|
||||||
|
|
||||||
|
// quit signals InvoiceExpiryWatcher to stop.
|
||||||
|
quit chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewInvoiceExpiryWatcher creates a new InvoiceExpiryWatcher instance.
|
||||||
|
func NewInvoiceExpiryWatcher(clock clock.Clock) *InvoiceExpiryWatcher {
|
||||||
|
return &InvoiceExpiryWatcher{
|
||||||
|
clock: clock,
|
||||||
|
newInvoices: make(chan *invoiceExpiry),
|
||||||
|
quit: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the the subscription handler and the main loop. Start() will
|
||||||
|
// return with error if InvoiceExpiryWatcher is already started. Start()
|
||||||
|
// expects a cancellation function passed that will be use to cancel expired
|
||||||
|
// invoices by their payment hash.
|
||||||
|
func (ew *InvoiceExpiryWatcher) Start(
|
||||||
|
cancelInvoice func(lntypes.Hash) error) error {
|
||||||
|
|
||||||
|
ew.Lock()
|
||||||
|
defer ew.Unlock()
|
||||||
|
|
||||||
|
if ew.started {
|
||||||
|
return fmt.Errorf("InvoiceExpiryWatcher already started")
|
||||||
|
}
|
||||||
|
|
||||||
|
ew.started = true
|
||||||
|
ew.cancelInvoice = cancelInvoice
|
||||||
|
ew.wg.Add(1)
|
||||||
|
go ew.mainLoop()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop quits the expiry handler loop and waits for InvoiceExpiryWatcher to
|
||||||
|
// fully stop.
|
||||||
|
func (ew *InvoiceExpiryWatcher) Stop() {
|
||||||
|
ew.Lock()
|
||||||
|
defer ew.Unlock()
|
||||||
|
|
||||||
|
if ew.started {
|
||||||
|
// Signal subscriptionHandler to quit and wait for it to return.
|
||||||
|
close(ew.quit)
|
||||||
|
ew.wg.Wait()
|
||||||
|
ew.started = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddInvoice adds a new invoice to the InvoiceExpiryWatcher. This won't check
|
||||||
|
// if the invoice is already added and will only add invoices with ContractOpen
|
||||||
|
// state.
|
||||||
|
func (ew *InvoiceExpiryWatcher) AddInvoice(
|
||||||
|
paymentHash lntypes.Hash, invoice *channeldb.Invoice) {
|
||||||
|
|
||||||
|
if invoice.State != channeldb.ContractOpen {
|
||||||
|
log.Debugf("Invoice not added to expiry watcher: %v", invoice)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
realExpiry := invoice.Terms.Expiry
|
||||||
|
if realExpiry == 0 {
|
||||||
|
realExpiry = zpay32.DefaultInvoiceExpiry
|
||||||
|
}
|
||||||
|
|
||||||
|
expiry := invoice.CreationDate.Add(realExpiry)
|
||||||
|
|
||||||
|
log.Debugf("Adding invoice '%v' to expiry watcher, expiration: %v",
|
||||||
|
paymentHash, expiry)
|
||||||
|
|
||||||
|
newInvoiceExpiry := &invoiceExpiry{
|
||||||
|
PaymentHash: paymentHash,
|
||||||
|
Expiry: expiry,
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case ew.newInvoices <- newInvoiceExpiry:
|
||||||
|
case <-ew.quit:
|
||||||
|
// Select on quit too so that callers won't get blocked in case
|
||||||
|
// of concurrent shutdown.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// nextExpiry returns a Time chan to wait on until the next invoice expires.
|
||||||
|
// If there are no active invoices, then it'll simply wait indefinitely.
|
||||||
|
func (ew *InvoiceExpiryWatcher) nextExpiry() <-chan time.Time {
|
||||||
|
if !ew.expiryQueue.Empty() {
|
||||||
|
top := ew.expiryQueue.Top().(*invoiceExpiry)
|
||||||
|
return ew.clock.TickAfter(top.Expiry.Sub(ew.clock.Now()))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cancelExpiredInvoices will cancel all expired invoices and removes them from
|
||||||
|
// the expiry queue.
|
||||||
|
func (ew *InvoiceExpiryWatcher) cancelExpiredInvoices() {
|
||||||
|
for !ew.expiryQueue.Empty() {
|
||||||
|
top := ew.expiryQueue.Top().(*invoiceExpiry)
|
||||||
|
if !top.Expiry.Before(ew.clock.Now()) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
err := ew.cancelInvoice(top.PaymentHash)
|
||||||
|
if err != nil && err != channeldb.ErrInvoiceAlreadySettled &&
|
||||||
|
err != channeldb.ErrInvoiceAlreadyCanceled {
|
||||||
|
|
||||||
|
log.Errorf("Unable to cancel invoice: %v", top.PaymentHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
ew.expiryQueue.Pop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mainLoop is a goroutine that receives new invoices and handles cancellation
|
||||||
|
// of expired invoices.
|
||||||
|
func (ew *InvoiceExpiryWatcher) mainLoop() {
|
||||||
|
defer ew.wg.Done()
|
||||||
|
|
||||||
|
for {
|
||||||
|
// Cancel any invoices that may have expired.
|
||||||
|
ew.cancelExpiredInvoices()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ew.nextExpiry():
|
||||||
|
// Wait until the next invoice expires, then cancel expired invoices.
|
||||||
|
continue
|
||||||
|
|
||||||
|
case newInvoiceExpiry := <-ew.newInvoices:
|
||||||
|
ew.expiryQueue.Push(newInvoiceExpiry)
|
||||||
|
|
||||||
|
case <-ew.quit:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
125
invoices/invoice_expiry_watcher_test.go
Normal file
125
invoices/invoice_expiry_watcher_test.go
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
package invoices
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/clock"
|
||||||
|
"github.com/lightningnetwork/lnd/lntypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
// invoiceExpiryWatcherTest holds a test fixture and implements checks
|
||||||
|
// for InvoiceExpiryWatcher tests.
|
||||||
|
type invoiceExpiryWatcherTest struct {
|
||||||
|
t *testing.T
|
||||||
|
watcher *InvoiceExpiryWatcher
|
||||||
|
testData invoiceExpiryTestData
|
||||||
|
canceledInvoices []lntypes.Hash
|
||||||
|
}
|
||||||
|
|
||||||
|
// newInvoiceExpiryWatcherTest creates a new InvoiceExpiryWatcher test fixture
|
||||||
|
// and sets up the test environment.
|
||||||
|
func newInvoiceExpiryWatcherTest(t *testing.T, now time.Time,
|
||||||
|
numExpiredInvoices, numPendingInvoices int) *invoiceExpiryWatcherTest {
|
||||||
|
|
||||||
|
test := &invoiceExpiryWatcherTest{
|
||||||
|
watcher: NewInvoiceExpiryWatcher(clock.NewTestClock(testTime)),
|
||||||
|
testData: generateInvoiceExpiryTestData(
|
||||||
|
t, now, 0, numExpiredInvoices, numPendingInvoices,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := test.watcher.Start(func(paymentHash lntypes.Hash) error {
|
||||||
|
test.canceledInvoices = append(test.canceledInvoices, paymentHash)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("cannot start InvoiceExpiryWatcher: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return test
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *invoiceExpiryWatcherTest) checkExpectations() {
|
||||||
|
// Check that invoices that got canceled during the test are the ones
|
||||||
|
// that expired.
|
||||||
|
if len(t.canceledInvoices) != len(t.testData.expiredInvoices) {
|
||||||
|
t.t.Fatalf("expected %v cancellations, got %v",
|
||||||
|
len(t.testData.expiredInvoices), len(t.canceledInvoices))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range t.canceledInvoices {
|
||||||
|
if _, ok := t.testData.expiredInvoices[t.canceledInvoices[i]]; !ok {
|
||||||
|
t.t.Fatalf("wrong invoice canceled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests that InvoiceExpiryWatcher can be started and stopped.
|
||||||
|
func TestInvoiceExpiryWatcherStartStop(t *testing.T) {
|
||||||
|
watcher := NewInvoiceExpiryWatcher(clock.NewTestClock(testTime))
|
||||||
|
cancel := func(lntypes.Hash) error {
|
||||||
|
t.Fatalf("unexpected call")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := watcher.Start(cancel); err != nil {
|
||||||
|
t.Fatalf("unexpected error upon start: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := watcher.Start(cancel); err == nil {
|
||||||
|
t.Fatalf("expected error upon second start")
|
||||||
|
}
|
||||||
|
|
||||||
|
watcher.Stop()
|
||||||
|
|
||||||
|
if err := watcher.Start(cancel); err != nil {
|
||||||
|
t.Fatalf("unexpected error upon start: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests that no invoices will expire from an empty InvoiceExpiryWatcher.
|
||||||
|
func TestInvoiceExpiryWithNoInvoices(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
test := newInvoiceExpiryWatcherTest(t, testTime, 0, 0)
|
||||||
|
|
||||||
|
time.Sleep(testTimeout)
|
||||||
|
test.watcher.Stop()
|
||||||
|
test.checkExpectations()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests that if all add invoices are expired, then all invoices
|
||||||
|
// will be canceled.
|
||||||
|
func TestInvoiceExpiryWithOnlyExpiredInvoices(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
test := newInvoiceExpiryWatcherTest(t, testTime, 0, 5)
|
||||||
|
|
||||||
|
for paymentHash, invoice := range test.testData.pendingInvoices {
|
||||||
|
test.watcher.AddInvoice(paymentHash, invoice)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(testTimeout)
|
||||||
|
test.watcher.Stop()
|
||||||
|
test.checkExpectations()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests that if some invoices are expired, then those invoices
|
||||||
|
// will be canceled.
|
||||||
|
func TestInvoiceExpiryWithPendingAndExpiredInvoices(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5)
|
||||||
|
|
||||||
|
for paymentHash, invoice := range test.testData.expiredInvoices {
|
||||||
|
test.watcher.AddInvoice(paymentHash, invoice)
|
||||||
|
}
|
||||||
|
|
||||||
|
for paymentHash, invoice := range test.testData.pendingInvoices {
|
||||||
|
test.watcher.AddInvoice(paymentHash, invoice)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(testTimeout)
|
||||||
|
test.watcher.Stop()
|
||||||
|
test.checkExpectations()
|
||||||
|
}
|
@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/davecgh/go-spew/spew"
|
"github.com/davecgh/go-spew/spew"
|
||||||
"github.com/lightningnetwork/lnd/channeldb"
|
"github.com/lightningnetwork/lnd/channeldb"
|
||||||
|
"github.com/lightningnetwork/lnd/clock"
|
||||||
"github.com/lightningnetwork/lnd/lntypes"
|
"github.com/lightningnetwork/lnd/lntypes"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
"github.com/lightningnetwork/lnd/queue"
|
"github.com/lightningnetwork/lnd/queue"
|
||||||
@ -62,12 +63,10 @@ type RegistryConfig struct {
|
|||||||
// waiting for the other set members to arrive.
|
// waiting for the other set members to arrive.
|
||||||
HtlcHoldDuration time.Duration
|
HtlcHoldDuration time.Duration
|
||||||
|
|
||||||
// Now returns the current time.
|
// Clock holds the clock implementation that is used to provide
|
||||||
Now func() time.Time
|
// Now() and TickAfter() and is useful to stub out the clock functions
|
||||||
|
// during testing.
|
||||||
// TickAfter returns a channel that is sent on after the specified
|
Clock clock.Clock
|
||||||
// duration as passed.
|
|
||||||
TickAfter func(duration time.Duration) <-chan time.Time
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// htlcReleaseEvent describes an htlc auto-release event. It is used to release
|
// htlcReleaseEvent describes an htlc auto-release event. It is used to release
|
||||||
@ -126,6 +125,8 @@ type InvoiceRegistry struct {
|
|||||||
// auto-released.
|
// auto-released.
|
||||||
htlcAutoReleaseChan chan *htlcReleaseEvent
|
htlcAutoReleaseChan chan *htlcReleaseEvent
|
||||||
|
|
||||||
|
expiryWatcher *InvoiceExpiryWatcher
|
||||||
|
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
quit chan struct{}
|
quit chan struct{}
|
||||||
}
|
}
|
||||||
@ -134,7 +135,9 @@ type InvoiceRegistry struct {
|
|||||||
// wraps the persistent on-disk invoice storage with an additional in-memory
|
// wraps the persistent on-disk invoice storage with an additional in-memory
|
||||||
// layer. The in-memory layer is in place such that debug invoices can be added
|
// layer. The in-memory layer is in place such that debug invoices can be added
|
||||||
// which are volatile yet available system wide within the daemon.
|
// which are volatile yet available system wide within the daemon.
|
||||||
func NewRegistry(cdb *channeldb.DB, cfg *RegistryConfig) *InvoiceRegistry {
|
func NewRegistry(cdb *channeldb.DB, expiryWatcher *InvoiceExpiryWatcher,
|
||||||
|
cfg *RegistryConfig) *InvoiceRegistry {
|
||||||
|
|
||||||
return &InvoiceRegistry{
|
return &InvoiceRegistry{
|
||||||
cdb: cdb,
|
cdb: cdb,
|
||||||
notificationClients: make(map[uint32]*InvoiceSubscription),
|
notificationClients: make(map[uint32]*InvoiceSubscription),
|
||||||
@ -146,21 +149,62 @@ func NewRegistry(cdb *channeldb.DB, cfg *RegistryConfig) *InvoiceRegistry {
|
|||||||
hodlReverseSubscriptions: make(map[chan<- interface{}]map[channeldb.CircuitKey]struct{}),
|
hodlReverseSubscriptions: make(map[chan<- interface{}]map[channeldb.CircuitKey]struct{}),
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
htlcAutoReleaseChan: make(chan *htlcReleaseEvent),
|
htlcAutoReleaseChan: make(chan *htlcReleaseEvent),
|
||||||
|
expiryWatcher: expiryWatcher,
|
||||||
quit: make(chan struct{}),
|
quit: make(chan struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// populateExpiryWatcher fetches all active invoices and their corresponding
|
||||||
|
// payment hashes from ChannelDB and adds them to the expiry watcher.
|
||||||
|
func (i *InvoiceRegistry) populateExpiryWatcher() error {
|
||||||
|
pendingOnly := true
|
||||||
|
pendingInvoices, err := i.cdb.FetchAllInvoicesWithPaymentHash(pendingOnly)
|
||||||
|
if err != nil && err != channeldb.ErrNoInvoicesCreated {
|
||||||
|
log.Errorf(
|
||||||
|
"Error while prefetching active invoices from the database: %v", err,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx := range pendingInvoices {
|
||||||
|
i.expiryWatcher.AddInvoice(
|
||||||
|
pendingInvoices[idx].PaymentHash, &pendingInvoices[idx].Invoice,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Start starts the registry and all goroutines it needs to carry out its task.
|
// Start starts the registry and all goroutines it needs to carry out its task.
|
||||||
func (i *InvoiceRegistry) Start() error {
|
func (i *InvoiceRegistry) Start() error {
|
||||||
i.wg.Add(1)
|
// Start InvoiceExpiryWatcher and prepopulate it with existing active
|
||||||
|
// invoices.
|
||||||
|
err := i.expiryWatcher.Start(func(paymentHash lntypes.Hash) error {
|
||||||
|
cancelIfAccepted := false
|
||||||
|
return i.cancelInvoiceImpl(paymentHash, cancelIfAccepted)
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
i.wg.Add(1)
|
||||||
go i.invoiceEventLoop()
|
go i.invoiceEventLoop()
|
||||||
|
|
||||||
|
// Now prefetch all pending invoices to the expiry watcher.
|
||||||
|
err = i.populateExpiryWatcher()
|
||||||
|
if err != nil {
|
||||||
|
i.Stop()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop signals the registry for a graceful shutdown.
|
// Stop signals the registry for a graceful shutdown.
|
||||||
func (i *InvoiceRegistry) Stop() {
|
func (i *InvoiceRegistry) Stop() {
|
||||||
|
i.expiryWatcher.Stop()
|
||||||
|
|
||||||
close(i.quit)
|
close(i.quit)
|
||||||
|
|
||||||
i.wg.Wait()
|
i.wg.Wait()
|
||||||
@ -177,8 +221,8 @@ type invoiceEvent struct {
|
|||||||
// tickAt returns a channel that ticks at the specified time. If the time has
|
// tickAt returns a channel that ticks at the specified time. If the time has
|
||||||
// already passed, it will tick immediately.
|
// already passed, it will tick immediately.
|
||||||
func (i *InvoiceRegistry) tickAt(t time.Time) <-chan time.Time {
|
func (i *InvoiceRegistry) tickAt(t time.Time) <-chan time.Time {
|
||||||
now := i.cfg.Now()
|
now := i.cfg.Clock.Now()
|
||||||
return i.cfg.TickAfter(t.Sub(now))
|
return i.cfg.Clock.TickAfter(t.Sub(now))
|
||||||
}
|
}
|
||||||
|
|
||||||
// invoiceEventLoop is the dedicated goroutine responsible for accepting
|
// invoiceEventLoop is the dedicated goroutine responsible for accepting
|
||||||
@ -471,7 +515,6 @@ func (i *InvoiceRegistry) AddInvoice(invoice *channeldb.Invoice,
|
|||||||
paymentHash lntypes.Hash) (uint64, error) {
|
paymentHash lntypes.Hash) (uint64, error) {
|
||||||
|
|
||||||
i.Lock()
|
i.Lock()
|
||||||
defer i.Unlock()
|
|
||||||
|
|
||||||
log.Debugf("Invoice(%v): added %v", paymentHash,
|
log.Debugf("Invoice(%v): added %v", paymentHash,
|
||||||
newLogClosure(func() string {
|
newLogClosure(func() string {
|
||||||
@ -481,12 +524,19 @@ func (i *InvoiceRegistry) AddInvoice(invoice *channeldb.Invoice,
|
|||||||
|
|
||||||
addIndex, err := i.cdb.AddInvoice(invoice, paymentHash)
|
addIndex, err := i.cdb.AddInvoice(invoice, paymentHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
i.Unlock()
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now that we've added the invoice, we'll send dispatch a message to
|
// Now that we've added the invoice, we'll send dispatch a message to
|
||||||
// notify the clients of this new invoice.
|
// notify the clients of this new invoice.
|
||||||
i.notifyClients(paymentHash, invoice, channeldb.ContractOpen)
|
i.notifyClients(paymentHash, invoice, channeldb.ContractOpen)
|
||||||
|
i.Unlock()
|
||||||
|
|
||||||
|
// InvoiceExpiryWatcher.AddInvoice must not be locked by InvoiceRegistry
|
||||||
|
// to avoid deadlock when a new invoice is added while an other is being
|
||||||
|
// canceled.
|
||||||
|
i.expiryWatcher.AddInvoice(paymentHash, invoice)
|
||||||
|
|
||||||
return addIndex, nil
|
return addIndex, nil
|
||||||
}
|
}
|
||||||
@ -818,6 +868,15 @@ func (i *InvoiceRegistry) SettleHodlInvoice(preimage lntypes.Preimage) error {
|
|||||||
// CancelInvoice attempts to cancel the invoice corresponding to the passed
|
// CancelInvoice attempts to cancel the invoice corresponding to the passed
|
||||||
// payment hash.
|
// payment hash.
|
||||||
func (i *InvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error {
|
func (i *InvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error {
|
||||||
|
return i.cancelInvoiceImpl(payHash, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// cancelInvoice attempts to cancel the invoice corresponding to the passed
|
||||||
|
// payment hash. Accepted invoices will only be canceled if explicitly
|
||||||
|
// requested to do so.
|
||||||
|
func (i *InvoiceRegistry) cancelInvoiceImpl(payHash lntypes.Hash,
|
||||||
|
cancelAccepted bool) error {
|
||||||
|
|
||||||
i.Lock()
|
i.Lock()
|
||||||
defer i.Unlock()
|
defer i.Unlock()
|
||||||
|
|
||||||
@ -826,6 +885,12 @@ func (i *InvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error {
|
|||||||
updateInvoice := func(invoice *channeldb.Invoice) (
|
updateInvoice := func(invoice *channeldb.Invoice) (
|
||||||
*channeldb.InvoiceUpdateDesc, error) {
|
*channeldb.InvoiceUpdateDesc, error) {
|
||||||
|
|
||||||
|
// Only cancel the invoice in ContractAccepted state if explicitly
|
||||||
|
// requested to do so.
|
||||||
|
if invoice.State == channeldb.ContractAccepted && !cancelAccepted {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Move invoice to the canceled state. Rely on validation in
|
// Move invoice to the canceled state. Rely on validation in
|
||||||
// channeldb to return an error if the invoice is already
|
// channeldb to return an error if the invoice is already
|
||||||
// settled or canceled.
|
// settled or canceled.
|
||||||
@ -848,6 +913,13 @@ func (i *InvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return without cancellation if the invoice state is ContractAccepted.
|
||||||
|
if invoice.State == channeldb.ContractAccepted {
|
||||||
|
log.Debugf("Invoice(%v): remains accepted as cancel wasn't"+
|
||||||
|
"explicitly requested.", payHash)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
log.Debugf("Invoice(%v): canceled", payHash)
|
log.Debugf("Invoice(%v): canceled", payHash)
|
||||||
|
|
||||||
// In the callback, some htlcs may have been moved to the canceled
|
// In the callback, some htlcs may have been moved to the canceled
|
||||||
|
@ -1,117 +1,16 @@
|
|||||||
package invoices
|
package invoices
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io/ioutil"
|
|
||||||
"os"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lightningnetwork/lnd/channeldb"
|
"github.com/lightningnetwork/lnd/channeldb"
|
||||||
|
"github.com/lightningnetwork/lnd/clock"
|
||||||
"github.com/lightningnetwork/lnd/lntypes"
|
"github.com/lightningnetwork/lnd/lntypes"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
"github.com/lightningnetwork/lnd/record"
|
"github.com/lightningnetwork/lnd/record"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
testTimeout = 5 * time.Second
|
|
||||||
|
|
||||||
testTime = time.Date(2018, time.February, 2, 14, 0, 0, 0, time.UTC)
|
|
||||||
|
|
||||||
preimage = lntypes.Preimage{
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
hash = preimage.Hash()
|
|
||||||
|
|
||||||
testHtlcExpiry = uint32(5)
|
|
||||||
|
|
||||||
testInvoiceCltvDelta = uint32(4)
|
|
||||||
|
|
||||||
testFinalCltvRejectDelta = int32(4)
|
|
||||||
|
|
||||||
testCurrentHeight = int32(1)
|
|
||||||
|
|
||||||
testFeatures = lnwire.NewFeatureVector(
|
|
||||||
nil, lnwire.Features,
|
|
||||||
)
|
|
||||||
|
|
||||||
testPayload = &mockPayload{}
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
testInvoiceAmt = lnwire.MilliSatoshi(100000)
|
|
||||||
testInvoice = &channeldb.Invoice{
|
|
||||||
Terms: channeldb.ContractTerm{
|
|
||||||
PaymentPreimage: preimage,
|
|
||||||
Value: lnwire.MilliSatoshi(100000),
|
|
||||||
Features: testFeatures,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
testHodlInvoice = &channeldb.Invoice{
|
|
||||||
Terms: channeldb.ContractTerm{
|
|
||||||
PaymentPreimage: channeldb.UnknownPreimage,
|
|
||||||
Value: testInvoiceAmt,
|
|
||||||
Features: testFeatures,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
type testContext struct {
|
|
||||||
registry *InvoiceRegistry
|
|
||||||
clock *testClock
|
|
||||||
|
|
||||||
cleanup func()
|
|
||||||
t *testing.T
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTestContext(t *testing.T) *testContext {
|
|
||||||
clock := newTestClock(testTime)
|
|
||||||
|
|
||||||
cdb, cleanup, err := newDB()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
cdb.Now = clock.now
|
|
||||||
|
|
||||||
// Instantiate and start the invoice ctx.registry.
|
|
||||||
cfg := RegistryConfig{
|
|
||||||
FinalCltvRejectDelta: testFinalCltvRejectDelta,
|
|
||||||
HtlcHoldDuration: 30 * time.Second,
|
|
||||||
Now: clock.now,
|
|
||||||
TickAfter: clock.tickAfter,
|
|
||||||
}
|
|
||||||
registry := NewRegistry(cdb, &cfg)
|
|
||||||
|
|
||||||
err = registry.Start()
|
|
||||||
if err != nil {
|
|
||||||
cleanup()
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := testContext{
|
|
||||||
registry: registry,
|
|
||||||
clock: clock,
|
|
||||||
t: t,
|
|
||||||
cleanup: func() {
|
|
||||||
registry.Stop()
|
|
||||||
cleanup()
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ctx
|
|
||||||
}
|
|
||||||
|
|
||||||
func getCircuitKey(htlcID uint64) channeldb.CircuitKey {
|
|
||||||
return channeldb.CircuitKey{
|
|
||||||
ChanID: lnwire.ShortChannelID{
|
|
||||||
BlockHeight: 1, TxIndex: 2, TxPosition: 3,
|
|
||||||
},
|
|
||||||
HtlcID: htlcID,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestSettleInvoice tests settling of an invoice and related notifications.
|
// TestSettleInvoice tests settling of an invoice and related notifications.
|
||||||
func TestSettleInvoice(t *testing.T) {
|
func TestSettleInvoice(t *testing.T) {
|
||||||
ctx := newTestContext(t)
|
ctx := newTestContext(t)
|
||||||
@ -121,18 +20,18 @@ func TestSettleInvoice(t *testing.T) {
|
|||||||
defer allSubscriptions.Cancel()
|
defer allSubscriptions.Cancel()
|
||||||
|
|
||||||
// Subscribe to the not yet existing invoice.
|
// Subscribe to the not yet existing invoice.
|
||||||
subscription, err := ctx.registry.SubscribeSingleInvoice(hash)
|
subscription, err := ctx.registry.SubscribeSingleInvoice(testInvoicePaymentHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer subscription.Cancel()
|
defer subscription.Cancel()
|
||||||
|
|
||||||
if subscription.hash != hash {
|
if subscription.hash != testInvoicePaymentHash {
|
||||||
t.Fatalf("expected subscription for provided hash")
|
t.Fatalf("expected subscription for provided hash")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the invoice.
|
// Add the invoice.
|
||||||
addIdx, err := ctx.registry.AddInvoice(testInvoice, hash)
|
addIdx, err := ctx.registry.AddInvoice(testInvoice, testInvoicePaymentHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -168,7 +67,7 @@ func TestSettleInvoice(t *testing.T) {
|
|||||||
|
|
||||||
// Try to settle invoice with an htlc that expires too soon.
|
// Try to settle invoice with an htlc that expires too soon.
|
||||||
event, err := ctx.registry.NotifyExitHopHtlc(
|
event, err := ctx.registry.NotifyExitHopHtlc(
|
||||||
hash, testInvoice.Terms.Value,
|
testInvoicePaymentHash, testInvoice.Terms.Value,
|
||||||
uint32(testCurrentHeight)+testInvoiceCltvDelta-1,
|
uint32(testCurrentHeight)+testInvoiceCltvDelta-1,
|
||||||
testCurrentHeight, getCircuitKey(10), hodlChan, testPayload,
|
testCurrentHeight, getCircuitKey(10), hodlChan, testPayload,
|
||||||
)
|
)
|
||||||
@ -186,7 +85,7 @@ func TestSettleInvoice(t *testing.T) {
|
|||||||
// Settle invoice with a slightly higher amount.
|
// Settle invoice with a slightly higher amount.
|
||||||
amtPaid := lnwire.MilliSatoshi(100500)
|
amtPaid := lnwire.MilliSatoshi(100500)
|
||||||
_, err = ctx.registry.NotifyExitHopHtlc(
|
_, err = ctx.registry.NotifyExitHopHtlc(
|
||||||
hash, amtPaid, testHtlcExpiry, testCurrentHeight,
|
testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight,
|
||||||
getCircuitKey(0), hodlChan, testPayload,
|
getCircuitKey(0), hodlChan, testPayload,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -222,7 +121,7 @@ func TestSettleInvoice(t *testing.T) {
|
|||||||
// Try to settle again with the same htlc id. We need this idempotent
|
// Try to settle again with the same htlc id. We need this idempotent
|
||||||
// behaviour after a restart.
|
// behaviour after a restart.
|
||||||
event, err = ctx.registry.NotifyExitHopHtlc(
|
event, err = ctx.registry.NotifyExitHopHtlc(
|
||||||
hash, amtPaid, testHtlcExpiry, testCurrentHeight,
|
testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight,
|
||||||
getCircuitKey(0), hodlChan, testPayload,
|
getCircuitKey(0), hodlChan, testPayload,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -236,7 +135,7 @@ func TestSettleInvoice(t *testing.T) {
|
|||||||
// should also be accepted, to prevent any change in behaviour for a
|
// should also be accepted, to prevent any change in behaviour for a
|
||||||
// paid invoice that may open up a probe vector.
|
// paid invoice that may open up a probe vector.
|
||||||
event, err = ctx.registry.NotifyExitHopHtlc(
|
event, err = ctx.registry.NotifyExitHopHtlc(
|
||||||
hash, amtPaid+600, testHtlcExpiry, testCurrentHeight,
|
testInvoicePaymentHash, amtPaid+600, testHtlcExpiry, testCurrentHeight,
|
||||||
getCircuitKey(1), hodlChan, testPayload,
|
getCircuitKey(1), hodlChan, testPayload,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -249,7 +148,7 @@ func TestSettleInvoice(t *testing.T) {
|
|||||||
// Try to settle again with a lower amount. This should fail just as it
|
// Try to settle again with a lower amount. This should fail just as it
|
||||||
// would have failed if it were the first payment.
|
// would have failed if it were the first payment.
|
||||||
event, err = ctx.registry.NotifyExitHopHtlc(
|
event, err = ctx.registry.NotifyExitHopHtlc(
|
||||||
hash, amtPaid-600, testHtlcExpiry, testCurrentHeight,
|
testInvoicePaymentHash, amtPaid-600, testHtlcExpiry, testCurrentHeight,
|
||||||
getCircuitKey(2), hodlChan, testPayload,
|
getCircuitKey(2), hodlChan, testPayload,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -261,7 +160,7 @@ func TestSettleInvoice(t *testing.T) {
|
|||||||
|
|
||||||
// Check that settled amount is equal to the sum of values of the htlcs
|
// Check that settled amount is equal to the sum of values of the htlcs
|
||||||
// 0 and 1.
|
// 0 and 1.
|
||||||
inv, err := ctx.registry.LookupInvoice(hash)
|
inv, err := ctx.registry.LookupInvoice(testInvoicePaymentHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -270,7 +169,7 @@ func TestSettleInvoice(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Try to cancel.
|
// Try to cancel.
|
||||||
err = ctx.registry.CancelInvoice(hash)
|
err = ctx.registry.CancelInvoice(testInvoicePaymentHash)
|
||||||
if err != channeldb.ErrInvoiceAlreadySettled {
|
if err != channeldb.ErrInvoiceAlreadySettled {
|
||||||
t.Fatal("expected cancelation of a settled invoice to fail")
|
t.Fatal("expected cancelation of a settled invoice to fail")
|
||||||
}
|
}
|
||||||
@ -292,25 +191,25 @@ func TestCancelInvoice(t *testing.T) {
|
|||||||
defer allSubscriptions.Cancel()
|
defer allSubscriptions.Cancel()
|
||||||
|
|
||||||
// Try to cancel the not yet existing invoice. This should fail.
|
// Try to cancel the not yet existing invoice. This should fail.
|
||||||
err := ctx.registry.CancelInvoice(hash)
|
err := ctx.registry.CancelInvoice(testInvoicePaymentHash)
|
||||||
if err != channeldb.ErrInvoiceNotFound {
|
if err != channeldb.ErrInvoiceNotFound {
|
||||||
t.Fatalf("expected ErrInvoiceNotFound, but got %v", err)
|
t.Fatalf("expected ErrInvoiceNotFound, but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Subscribe to the not yet existing invoice.
|
// Subscribe to the not yet existing invoice.
|
||||||
subscription, err := ctx.registry.SubscribeSingleInvoice(hash)
|
subscription, err := ctx.registry.SubscribeSingleInvoice(testInvoicePaymentHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer subscription.Cancel()
|
defer subscription.Cancel()
|
||||||
|
|
||||||
if subscription.hash != hash {
|
if subscription.hash != testInvoicePaymentHash {
|
||||||
t.Fatalf("expected subscription for provided hash")
|
t.Fatalf("expected subscription for provided hash")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the invoice.
|
// Add the invoice.
|
||||||
amt := lnwire.MilliSatoshi(100000)
|
amt := lnwire.MilliSatoshi(100000)
|
||||||
_, err = ctx.registry.AddInvoice(testInvoice, hash)
|
_, err = ctx.registry.AddInvoice(testInvoice, testInvoicePaymentHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -342,7 +241,7 @@ func TestCancelInvoice(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Cancel invoice.
|
// Cancel invoice.
|
||||||
err = ctx.registry.CancelInvoice(hash)
|
err = ctx.registry.CancelInvoice(testInvoicePaymentHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -365,7 +264,7 @@ func TestCancelInvoice(t *testing.T) {
|
|||||||
// subscribers (backwards compatibility).
|
// subscribers (backwards compatibility).
|
||||||
|
|
||||||
// Try to cancel again.
|
// Try to cancel again.
|
||||||
err = ctx.registry.CancelInvoice(hash)
|
err = ctx.registry.CancelInvoice(testInvoicePaymentHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("expected cancelation of a canceled invoice to succeed")
|
t.Fatal("expected cancelation of a canceled invoice to succeed")
|
||||||
}
|
}
|
||||||
@ -374,7 +273,7 @@ func TestCancelInvoice(t *testing.T) {
|
|||||||
// result in a cancel event.
|
// result in a cancel event.
|
||||||
hodlChan := make(chan interface{})
|
hodlChan := make(chan interface{})
|
||||||
event, err := ctx.registry.NotifyExitHopHtlc(
|
event, err := ctx.registry.NotifyExitHopHtlc(
|
||||||
hash, amt, testHtlcExpiry, testCurrentHeight,
|
testInvoicePaymentHash, amt, testHtlcExpiry, testCurrentHeight,
|
||||||
getCircuitKey(0), hodlChan, testPayload,
|
getCircuitKey(0), hodlChan, testPayload,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -393,9 +292,9 @@ func TestCancelInvoice(t *testing.T) {
|
|||||||
// TestSettleHoldInvoice tests settling of a hold invoice and related
|
// TestSettleHoldInvoice tests settling of a hold invoice and related
|
||||||
// notifications.
|
// notifications.
|
||||||
func TestSettleHoldInvoice(t *testing.T) {
|
func TestSettleHoldInvoice(t *testing.T) {
|
||||||
defer timeout(t)()
|
defer timeout()()
|
||||||
|
|
||||||
cdb, cleanup, err := newDB()
|
cdb, cleanup, err := newTestChannelDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -404,8 +303,9 @@ func TestSettleHoldInvoice(t *testing.T) {
|
|||||||
// Instantiate and start the invoice ctx.registry.
|
// Instantiate and start the invoice ctx.registry.
|
||||||
cfg := RegistryConfig{
|
cfg := RegistryConfig{
|
||||||
FinalCltvRejectDelta: testFinalCltvRejectDelta,
|
FinalCltvRejectDelta: testFinalCltvRejectDelta,
|
||||||
|
Clock: clock.NewTestClock(testTime),
|
||||||
}
|
}
|
||||||
registry := NewRegistry(cdb, &cfg)
|
registry := NewRegistry(cdb, NewInvoiceExpiryWatcher(cfg.Clock), &cfg)
|
||||||
|
|
||||||
err = registry.Start()
|
err = registry.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -417,18 +317,18 @@ func TestSettleHoldInvoice(t *testing.T) {
|
|||||||
defer allSubscriptions.Cancel()
|
defer allSubscriptions.Cancel()
|
||||||
|
|
||||||
// Subscribe to the not yet existing invoice.
|
// Subscribe to the not yet existing invoice.
|
||||||
subscription, err := registry.SubscribeSingleInvoice(hash)
|
subscription, err := registry.SubscribeSingleInvoice(testInvoicePaymentHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer subscription.Cancel()
|
defer subscription.Cancel()
|
||||||
|
|
||||||
if subscription.hash != hash {
|
if subscription.hash != testInvoicePaymentHash {
|
||||||
t.Fatalf("expected subscription for provided hash")
|
t.Fatalf("expected subscription for provided hash")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the invoice.
|
// Add the invoice.
|
||||||
_, err = registry.AddInvoice(testHodlInvoice, hash)
|
_, err = registry.AddInvoice(testHodlInvoice, testInvoicePaymentHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -455,7 +355,7 @@ func TestSettleHoldInvoice(t *testing.T) {
|
|||||||
// NotifyExitHopHtlc without a preimage present in the invoice registry
|
// NotifyExitHopHtlc without a preimage present in the invoice registry
|
||||||
// should be possible.
|
// should be possible.
|
||||||
event, err := registry.NotifyExitHopHtlc(
|
event, err := registry.NotifyExitHopHtlc(
|
||||||
hash, amtPaid, testHtlcExpiry, testCurrentHeight,
|
testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight,
|
||||||
getCircuitKey(0), hodlChan, testPayload,
|
getCircuitKey(0), hodlChan, testPayload,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -467,7 +367,7 @@ func TestSettleHoldInvoice(t *testing.T) {
|
|||||||
|
|
||||||
// Test idempotency.
|
// Test idempotency.
|
||||||
event, err = registry.NotifyExitHopHtlc(
|
event, err = registry.NotifyExitHopHtlc(
|
||||||
hash, amtPaid, testHtlcExpiry, testCurrentHeight,
|
testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight,
|
||||||
getCircuitKey(0), hodlChan, testPayload,
|
getCircuitKey(0), hodlChan, testPayload,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -480,7 +380,7 @@ func TestSettleHoldInvoice(t *testing.T) {
|
|||||||
// Test replay at a higher height. We expect the same result because it
|
// Test replay at a higher height. We expect the same result because it
|
||||||
// is a replay.
|
// is a replay.
|
||||||
event, err = registry.NotifyExitHopHtlc(
|
event, err = registry.NotifyExitHopHtlc(
|
||||||
hash, amtPaid, testHtlcExpiry, testCurrentHeight+10,
|
testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight+10,
|
||||||
getCircuitKey(0), hodlChan, testPayload,
|
getCircuitKey(0), hodlChan, testPayload,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -493,7 +393,7 @@ func TestSettleHoldInvoice(t *testing.T) {
|
|||||||
// Test a new htlc coming in that doesn't meet the final cltv delta
|
// Test a new htlc coming in that doesn't meet the final cltv delta
|
||||||
// requirement. It should be rejected.
|
// requirement. It should be rejected.
|
||||||
event, err = registry.NotifyExitHopHtlc(
|
event, err = registry.NotifyExitHopHtlc(
|
||||||
hash, amtPaid, 1, testCurrentHeight,
|
testInvoicePaymentHash, amtPaid, 1, testCurrentHeight,
|
||||||
getCircuitKey(1), hodlChan, testPayload,
|
getCircuitKey(1), hodlChan, testPayload,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -516,13 +416,13 @@ func TestSettleHoldInvoice(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Settling with preimage should succeed.
|
// Settling with preimage should succeed.
|
||||||
err = registry.SettleHodlInvoice(preimage)
|
err = registry.SettleHodlInvoice(testInvoicePreimage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("expected set preimage to succeed")
|
t.Fatal("expected set preimage to succeed")
|
||||||
}
|
}
|
||||||
|
|
||||||
hodlEvent := (<-hodlChan).(HodlEvent)
|
hodlEvent := (<-hodlChan).(HodlEvent)
|
||||||
if *hodlEvent.Preimage != preimage {
|
if *hodlEvent.Preimage != testInvoicePreimage {
|
||||||
t.Fatal("unexpected preimage in hodl event")
|
t.Fatal("unexpected preimage in hodl event")
|
||||||
}
|
}
|
||||||
if hodlEvent.AcceptHeight != testCurrentHeight {
|
if hodlEvent.AcceptHeight != testCurrentHeight {
|
||||||
@ -549,13 +449,13 @@ func TestSettleHoldInvoice(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Idempotency.
|
// Idempotency.
|
||||||
err = registry.SettleHodlInvoice(preimage)
|
err = registry.SettleHodlInvoice(testInvoicePreimage)
|
||||||
if err != channeldb.ErrInvoiceAlreadySettled {
|
if err != channeldb.ErrInvoiceAlreadySettled {
|
||||||
t.Fatalf("expected ErrInvoiceAlreadySettled but got %v", err)
|
t.Fatalf("expected ErrInvoiceAlreadySettled but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to cancel.
|
// Try to cancel.
|
||||||
err = registry.CancelInvoice(hash)
|
err = registry.CancelInvoice(testInvoicePaymentHash)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected cancelation of a settled invoice to fail")
|
t.Fatal("expected cancelation of a settled invoice to fail")
|
||||||
}
|
}
|
||||||
@ -564,9 +464,9 @@ func TestSettleHoldInvoice(t *testing.T) {
|
|||||||
// TestCancelHoldInvoice tests canceling of a hold invoice and related
|
// TestCancelHoldInvoice tests canceling of a hold invoice and related
|
||||||
// notifications.
|
// notifications.
|
||||||
func TestCancelHoldInvoice(t *testing.T) {
|
func TestCancelHoldInvoice(t *testing.T) {
|
||||||
defer timeout(t)()
|
defer timeout()()
|
||||||
|
|
||||||
cdb, cleanup, err := newDB()
|
cdb, cleanup, err := newTestChannelDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -575,8 +475,9 @@ func TestCancelHoldInvoice(t *testing.T) {
|
|||||||
// Instantiate and start the invoice ctx.registry.
|
// Instantiate and start the invoice ctx.registry.
|
||||||
cfg := RegistryConfig{
|
cfg := RegistryConfig{
|
||||||
FinalCltvRejectDelta: testFinalCltvRejectDelta,
|
FinalCltvRejectDelta: testFinalCltvRejectDelta,
|
||||||
|
Clock: clock.NewTestClock(testTime),
|
||||||
}
|
}
|
||||||
registry := NewRegistry(cdb, &cfg)
|
registry := NewRegistry(cdb, NewInvoiceExpiryWatcher(cfg.Clock), &cfg)
|
||||||
|
|
||||||
err = registry.Start()
|
err = registry.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -585,7 +486,7 @@ func TestCancelHoldInvoice(t *testing.T) {
|
|||||||
defer registry.Stop()
|
defer registry.Stop()
|
||||||
|
|
||||||
// Add the invoice.
|
// Add the invoice.
|
||||||
_, err = registry.AddInvoice(testHodlInvoice, hash)
|
_, err = registry.AddInvoice(testHodlInvoice, testInvoicePaymentHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -596,7 +497,7 @@ func TestCancelHoldInvoice(t *testing.T) {
|
|||||||
// NotifyExitHopHtlc without a preimage present in the invoice registry
|
// NotifyExitHopHtlc without a preimage present in the invoice registry
|
||||||
// should be possible.
|
// should be possible.
|
||||||
event, err := registry.NotifyExitHopHtlc(
|
event, err := registry.NotifyExitHopHtlc(
|
||||||
hash, amtPaid, testHtlcExpiry, testCurrentHeight,
|
testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight,
|
||||||
getCircuitKey(0), hodlChan, testPayload,
|
getCircuitKey(0), hodlChan, testPayload,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -607,7 +508,7 @@ func TestCancelHoldInvoice(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Cancel invoice.
|
// Cancel invoice.
|
||||||
err = registry.CancelInvoice(hash)
|
err = registry.CancelInvoice(testInvoicePaymentHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("cancel invoice failed")
|
t.Fatal("cancel invoice failed")
|
||||||
}
|
}
|
||||||
@ -621,7 +522,7 @@ func TestCancelHoldInvoice(t *testing.T) {
|
|||||||
// in a rejection. The accept height is expected to be the original
|
// in a rejection. The accept height is expected to be the original
|
||||||
// accept height.
|
// accept height.
|
||||||
event, err = registry.NotifyExitHopHtlc(
|
event, err = registry.NotifyExitHopHtlc(
|
||||||
hash, amtPaid, testHtlcExpiry, testCurrentHeight+1,
|
testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight+1,
|
||||||
getCircuitKey(0), hodlChan, testPayload,
|
getCircuitKey(0), hodlChan, testPayload,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -636,29 +537,6 @@ func TestCancelHoldInvoice(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDB() (*channeldb.DB, func(), error) {
|
|
||||||
// First, create a temporary directory to be used for the duration of
|
|
||||||
// this test.
|
|
||||||
tempDirName, err := ioutil.TempDir("", "channeldb")
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, create channeldb for the first time.
|
|
||||||
cdb, err := channeldb.Open(tempDirName)
|
|
||||||
if err != nil {
|
|
||||||
os.RemoveAll(tempDirName)
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
cleanUp := func() {
|
|
||||||
cdb.Close()
|
|
||||||
os.RemoveAll(tempDirName)
|
|
||||||
}
|
|
||||||
|
|
||||||
return cdb, cleanUp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestUnknownInvoice tests that invoice registry returns an error when the
|
// TestUnknownInvoice tests that invoice registry returns an error when the
|
||||||
// invoice is unknown. This is to guard against returning a cancel hodl event
|
// invoice is unknown. This is to guard against returning a cancel hodl event
|
||||||
// for forwarded htlcs. In the link, NotifyExitHopHtlc is only called if we are
|
// for forwarded htlcs. In the link, NotifyExitHopHtlc is only called if we are
|
||||||
@ -673,7 +551,7 @@ func TestUnknownInvoice(t *testing.T) {
|
|||||||
hodlChan := make(chan interface{})
|
hodlChan := make(chan interface{})
|
||||||
amt := lnwire.MilliSatoshi(100000)
|
amt := lnwire.MilliSatoshi(100000)
|
||||||
_, err := ctx.registry.NotifyExitHopHtlc(
|
_, err := ctx.registry.NotifyExitHopHtlc(
|
||||||
hash, amt, testHtlcExpiry, testCurrentHeight,
|
testInvoicePaymentHash, amt, testHtlcExpiry, testCurrentHeight,
|
||||||
getCircuitKey(0), hodlChan, testPayload,
|
getCircuitKey(0), hodlChan, testPayload,
|
||||||
)
|
)
|
||||||
if err != channeldb.ErrInvoiceNotFound {
|
if err != channeldb.ErrInvoiceNotFound {
|
||||||
@ -681,27 +559,15 @@ func TestUnknownInvoice(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockPayload struct {
|
|
||||||
mpp *record.MPP
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *mockPayload) MultiPath() *record.MPP {
|
|
||||||
return p.mpp
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *mockPayload) CustomRecords() record.CustomSet {
|
|
||||||
return make(record.CustomSet)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestSettleMpp tests settling of an invoice with multiple partial payments.
|
// TestSettleMpp tests settling of an invoice with multiple partial payments.
|
||||||
func TestSettleMpp(t *testing.T) {
|
func TestSettleMpp(t *testing.T) {
|
||||||
defer timeout(t)()
|
defer timeout()()
|
||||||
|
|
||||||
ctx := newTestContext(t)
|
ctx := newTestContext(t)
|
||||||
defer ctx.cleanup()
|
defer ctx.cleanup()
|
||||||
|
|
||||||
// Add the invoice.
|
// Add the invoice.
|
||||||
_, err := ctx.registry.AddInvoice(testInvoice, hash)
|
_, err := ctx.registry.AddInvoice(testInvoice, testInvoicePaymentHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -713,7 +579,7 @@ func TestSettleMpp(t *testing.T) {
|
|||||||
// Send htlc 1.
|
// Send htlc 1.
|
||||||
hodlChan1 := make(chan interface{}, 1)
|
hodlChan1 := make(chan interface{}, 1)
|
||||||
event, err := ctx.registry.NotifyExitHopHtlc(
|
event, err := ctx.registry.NotifyExitHopHtlc(
|
||||||
hash, testInvoice.Terms.Value/2,
|
testInvoicePaymentHash, testInvoice.Terms.Value/2,
|
||||||
testHtlcExpiry,
|
testHtlcExpiry,
|
||||||
testCurrentHeight, getCircuitKey(10), hodlChan1, mppPayload,
|
testCurrentHeight, getCircuitKey(10), hodlChan1, mppPayload,
|
||||||
)
|
)
|
||||||
@ -725,7 +591,7 @@ func TestSettleMpp(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Simulate mpp timeout releasing htlc 1.
|
// Simulate mpp timeout releasing htlc 1.
|
||||||
ctx.clock.setTime(testTime.Add(30 * time.Second))
|
ctx.clock.SetTime(testTime.Add(30 * time.Second))
|
||||||
|
|
||||||
hodlEvent := (<-hodlChan1).(HodlEvent)
|
hodlEvent := (<-hodlChan1).(HodlEvent)
|
||||||
if hodlEvent.Preimage != nil {
|
if hodlEvent.Preimage != nil {
|
||||||
@ -735,7 +601,7 @@ func TestSettleMpp(t *testing.T) {
|
|||||||
// Send htlc 2.
|
// Send htlc 2.
|
||||||
hodlChan2 := make(chan interface{}, 1)
|
hodlChan2 := make(chan interface{}, 1)
|
||||||
event, err = ctx.registry.NotifyExitHopHtlc(
|
event, err = ctx.registry.NotifyExitHopHtlc(
|
||||||
hash, testInvoice.Terms.Value/2,
|
testInvoicePaymentHash, testInvoice.Terms.Value/2,
|
||||||
testHtlcExpiry,
|
testHtlcExpiry,
|
||||||
testCurrentHeight, getCircuitKey(11), hodlChan2, mppPayload,
|
testCurrentHeight, getCircuitKey(11), hodlChan2, mppPayload,
|
||||||
)
|
)
|
||||||
@ -749,7 +615,7 @@ func TestSettleMpp(t *testing.T) {
|
|||||||
// Send htlc 3.
|
// Send htlc 3.
|
||||||
hodlChan3 := make(chan interface{}, 1)
|
hodlChan3 := make(chan interface{}, 1)
|
||||||
event, err = ctx.registry.NotifyExitHopHtlc(
|
event, err = ctx.registry.NotifyExitHopHtlc(
|
||||||
hash, testInvoice.Terms.Value/2,
|
testInvoicePaymentHash, testInvoice.Terms.Value/2,
|
||||||
testHtlcExpiry,
|
testHtlcExpiry,
|
||||||
testCurrentHeight, getCircuitKey(12), hodlChan3, mppPayload,
|
testCurrentHeight, getCircuitKey(12), hodlChan3, mppPayload,
|
||||||
)
|
)
|
||||||
@ -762,7 +628,7 @@ func TestSettleMpp(t *testing.T) {
|
|||||||
|
|
||||||
// Check that settled amount is equal to the sum of values of the htlcs
|
// Check that settled amount is equal to the sum of values of the htlcs
|
||||||
// 0 and 1.
|
// 0 and 1.
|
||||||
inv, err := ctx.registry.LookupInvoice(hash)
|
inv, err := ctx.registry.LookupInvoice(testInvoicePaymentHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -774,3 +640,105 @@ func TestSettleMpp(t *testing.T) {
|
|||||||
testInvoice.Terms.Value, inv.AmtPaid)
|
testInvoice.Terms.Value, inv.AmtPaid)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tests that invoices are canceled after expiration.
|
||||||
|
func TestInvoiceExpiryWithRegistry(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
cdb, cleanup, err := newTestChannelDB()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
testClock := clock.NewTestClock(testTime)
|
||||||
|
|
||||||
|
cfg := RegistryConfig{
|
||||||
|
FinalCltvRejectDelta: testFinalCltvRejectDelta,
|
||||||
|
Clock: testClock,
|
||||||
|
}
|
||||||
|
|
||||||
|
expiryWatcher := NewInvoiceExpiryWatcher(cfg.Clock)
|
||||||
|
registry := NewRegistry(cdb, expiryWatcher, &cfg)
|
||||||
|
|
||||||
|
// First prefill the Channel DB with some pre-existing invoices,
|
||||||
|
// half of them still pending, half of them expired.
|
||||||
|
const numExpired = 5
|
||||||
|
const numPending = 5
|
||||||
|
existingInvoices := generateInvoiceExpiryTestData(
|
||||||
|
t, testTime, 0, numExpired, numPending,
|
||||||
|
)
|
||||||
|
|
||||||
|
var expectedCancellations []lntypes.Hash
|
||||||
|
|
||||||
|
for paymentHash, expiredInvoice := range existingInvoices.expiredInvoices {
|
||||||
|
if _, err := cdb.AddInvoice(expiredInvoice, paymentHash); err != nil {
|
||||||
|
t.Fatalf("cannot add invoice to channel db: %v", err)
|
||||||
|
}
|
||||||
|
expectedCancellations = append(expectedCancellations, paymentHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
for paymentHash, pendingInvoice := range existingInvoices.pendingInvoices {
|
||||||
|
if _, err := cdb.AddInvoice(pendingInvoice, paymentHash); err != nil {
|
||||||
|
t.Fatalf("cannot add invoice to channel db: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = registry.Start(); err != nil {
|
||||||
|
t.Fatalf("cannot start registry: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now generate pending and invoices and add them to the registry while
|
||||||
|
// it is up and running. We'll manipulate the clock to let them expire.
|
||||||
|
newInvoices := generateInvoiceExpiryTestData(
|
||||||
|
t, testTime, numExpired+numPending, 0, numPending,
|
||||||
|
)
|
||||||
|
|
||||||
|
var invoicesThatWillCancel []lntypes.Hash
|
||||||
|
for paymentHash, pendingInvoice := range newInvoices.pendingInvoices {
|
||||||
|
_, err := registry.AddInvoice(pendingInvoice, paymentHash)
|
||||||
|
invoicesThatWillCancel = append(invoicesThatWillCancel, paymentHash)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that they are really not canceled until before the clock is
|
||||||
|
// advanced.
|
||||||
|
for i := range invoicesThatWillCancel {
|
||||||
|
invoice, err := registry.LookupInvoice(invoicesThatWillCancel[i])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("cannot find invoice: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if invoice.State == channeldb.ContractCanceled {
|
||||||
|
t.Fatalf("expected pending invoice, got canceled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fwd time 1 day.
|
||||||
|
testClock.SetTime(testTime.Add(24 * time.Hour))
|
||||||
|
|
||||||
|
// Give some time to the watcher to cancel everything.
|
||||||
|
time.Sleep(testTimeout)
|
||||||
|
registry.Stop()
|
||||||
|
|
||||||
|
// Create the expected cancellation set before the final check.
|
||||||
|
expectedCancellations = append(
|
||||||
|
expectedCancellations, invoicesThatWillCancel...,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Retrospectively check that all invoices that were expected to be canceled
|
||||||
|
// are indeed canceled.
|
||||||
|
for i := range expectedCancellations {
|
||||||
|
invoice, err := registry.LookupInvoice(expectedCancellations[i])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("cannot find invoice: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if invoice.State != channeldb.ContractCanceled {
|
||||||
|
t.Fatalf("expected canceled invoice, got: %v", invoice.State)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
280
invoices/test_utils_test.go
Normal file
280
invoices/test_utils_test.go
Normal file
@ -0,0 +1,280 @@
|
|||||||
|
package invoices
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"runtime/pprof"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/btcec"
|
||||||
|
"github.com/btcsuite/btcd/chaincfg"
|
||||||
|
"github.com/lightningnetwork/lnd/channeldb"
|
||||||
|
"github.com/lightningnetwork/lnd/clock"
|
||||||
|
"github.com/lightningnetwork/lnd/lntypes"
|
||||||
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
|
"github.com/lightningnetwork/lnd/record"
|
||||||
|
"github.com/lightningnetwork/lnd/zpay32"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockPayload struct {
|
||||||
|
mpp *record.MPP
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mockPayload) MultiPath() *record.MPP {
|
||||||
|
return p.mpp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mockPayload) CustomRecords() record.CustomSet {
|
||||||
|
return make(record.CustomSet)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
testTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
testTime = time.Date(2018, time.February, 2, 14, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
testInvoicePreimage = lntypes.Preimage{1}
|
||||||
|
|
||||||
|
testInvoicePaymentHash = testInvoicePreimage.Hash()
|
||||||
|
|
||||||
|
testHtlcExpiry = uint32(5)
|
||||||
|
|
||||||
|
testInvoiceCltvDelta = uint32(4)
|
||||||
|
|
||||||
|
testFinalCltvRejectDelta = int32(4)
|
||||||
|
|
||||||
|
testCurrentHeight = int32(1)
|
||||||
|
|
||||||
|
testPrivKeyBytes, _ = hex.DecodeString(
|
||||||
|
"e126f68f7eafcc8b74f54d269fe206be715000f94dac067d1c04a8ca3b2db734")
|
||||||
|
|
||||||
|
testPrivKey, _ = btcec.PrivKeyFromBytes(
|
||||||
|
btcec.S256(), testPrivKeyBytes)
|
||||||
|
|
||||||
|
testInvoiceDescription = "coffee"
|
||||||
|
|
||||||
|
testInvoiceAmount = lnwire.MilliSatoshi(100000)
|
||||||
|
|
||||||
|
testNetParams = &chaincfg.MainNetParams
|
||||||
|
|
||||||
|
testMessageSigner = zpay32.MessageSigner{
|
||||||
|
SignCompact: func(hash []byte) ([]byte, error) {
|
||||||
|
sig, err := btcec.SignCompact(btcec.S256(), testPrivKey, hash, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("can't sign the message: %v", err)
|
||||||
|
}
|
||||||
|
return sig, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testFeatures = lnwire.NewFeatureVector(
|
||||||
|
nil, lnwire.Features,
|
||||||
|
)
|
||||||
|
|
||||||
|
testPayload = &mockPayload{}
|
||||||
|
|
||||||
|
testInvoiceCreationDate = testTime
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
testInvoiceAmt = lnwire.MilliSatoshi(100000)
|
||||||
|
testInvoice = &channeldb.Invoice{
|
||||||
|
Terms: channeldb.ContractTerm{
|
||||||
|
PaymentPreimage: testInvoicePreimage,
|
||||||
|
Value: testInvoiceAmt,
|
||||||
|
Expiry: time.Hour,
|
||||||
|
Features: testFeatures,
|
||||||
|
},
|
||||||
|
CreationDate: testInvoiceCreationDate,
|
||||||
|
}
|
||||||
|
|
||||||
|
testHodlInvoice = &channeldb.Invoice{
|
||||||
|
Terms: channeldb.ContractTerm{
|
||||||
|
PaymentPreimage: channeldb.UnknownPreimage,
|
||||||
|
Value: testInvoiceAmt,
|
||||||
|
Expiry: time.Hour,
|
||||||
|
Features: testFeatures,
|
||||||
|
},
|
||||||
|
CreationDate: testInvoiceCreationDate,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestChannelDB() (*channeldb.DB, func(), error) {
|
||||||
|
// First, create a temporary directory to be used for the duration of
|
||||||
|
// this test.
|
||||||
|
tempDirName, err := ioutil.TempDir("", "channeldb")
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next, create channeldb for the first time.
|
||||||
|
cdb, err := channeldb.Open(tempDirName)
|
||||||
|
if err != nil {
|
||||||
|
os.RemoveAll(tempDirName)
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanUp := func() {
|
||||||
|
cdb.Close()
|
||||||
|
os.RemoveAll(tempDirName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cdb, cleanUp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type testContext struct {
|
||||||
|
cdb *channeldb.DB
|
||||||
|
registry *InvoiceRegistry
|
||||||
|
clock *clock.TestClock
|
||||||
|
|
||||||
|
cleanup func()
|
||||||
|
t *testing.T
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestContext(t *testing.T) *testContext {
|
||||||
|
clock := clock.NewTestClock(testTime)
|
||||||
|
|
||||||
|
cdb, cleanup, err := newTestChannelDB()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
cdb.Now = clock.Now
|
||||||
|
|
||||||
|
expiryWatcher := NewInvoiceExpiryWatcher(clock)
|
||||||
|
|
||||||
|
// Instantiate and start the invoice ctx.registry.
|
||||||
|
cfg := RegistryConfig{
|
||||||
|
FinalCltvRejectDelta: testFinalCltvRejectDelta,
|
||||||
|
HtlcHoldDuration: 30 * time.Second,
|
||||||
|
Clock: clock,
|
||||||
|
}
|
||||||
|
registry := NewRegistry(cdb, expiryWatcher, &cfg)
|
||||||
|
|
||||||
|
err = registry.Start()
|
||||||
|
if err != nil {
|
||||||
|
cleanup()
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := testContext{
|
||||||
|
cdb: cdb,
|
||||||
|
registry: registry,
|
||||||
|
clock: clock,
|
||||||
|
t: t,
|
||||||
|
cleanup: func() {
|
||||||
|
registry.Stop()
|
||||||
|
cleanup()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCircuitKey(htlcID uint64) channeldb.CircuitKey {
|
||||||
|
return channeldb.CircuitKey{
|
||||||
|
ChanID: lnwire.ShortChannelID{
|
||||||
|
BlockHeight: 1, TxIndex: 2, TxPosition: 3,
|
||||||
|
},
|
||||||
|
HtlcID: htlcID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestInvoice(t *testing.T, preimage lntypes.Preimage,
|
||||||
|
timestamp time.Time, expiry time.Duration) *channeldb.Invoice {
|
||||||
|
|
||||||
|
if expiry == 0 {
|
||||||
|
expiry = time.Hour
|
||||||
|
}
|
||||||
|
|
||||||
|
rawInvoice, err := zpay32.NewInvoice(
|
||||||
|
testNetParams,
|
||||||
|
preimage.Hash(),
|
||||||
|
timestamp,
|
||||||
|
zpay32.Amount(testInvoiceAmount),
|
||||||
|
zpay32.Description(testInvoiceDescription),
|
||||||
|
zpay32.Expiry(expiry))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error while creating new invoice: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
paymentRequest, err := rawInvoice.Encode(testMessageSigner)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error while encoding payment request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &channeldb.Invoice{
|
||||||
|
Terms: channeldb.ContractTerm{
|
||||||
|
PaymentPreimage: preimage,
|
||||||
|
Value: testInvoiceAmount,
|
||||||
|
Expiry: expiry,
|
||||||
|
Features: testFeatures,
|
||||||
|
},
|
||||||
|
PaymentRequest: []byte(paymentRequest),
|
||||||
|
CreationDate: timestamp,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// timeout implements a test level timeout.
|
||||||
|
func timeout() func() {
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
err := pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("error writing to std out after timeout: %v", err))
|
||||||
|
}
|
||||||
|
panic("timeout")
|
||||||
|
case <-done:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return func() {
|
||||||
|
close(done)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// invoiceExpiryTestData simply holds generated expired and pending invoices.
|
||||||
|
type invoiceExpiryTestData struct {
|
||||||
|
expiredInvoices map[lntypes.Hash]*channeldb.Invoice
|
||||||
|
pendingInvoices map[lntypes.Hash]*channeldb.Invoice
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateInvoiceExpiryTestData generates the specified number of fake expired
|
||||||
|
// and pending invoices anchored to the passed now timestamp.
|
||||||
|
func generateInvoiceExpiryTestData(
|
||||||
|
t *testing.T, now time.Time,
|
||||||
|
offset, numExpired, numPending int) invoiceExpiryTestData {
|
||||||
|
|
||||||
|
var testData invoiceExpiryTestData
|
||||||
|
|
||||||
|
testData.expiredInvoices = make(map[lntypes.Hash]*channeldb.Invoice)
|
||||||
|
testData.pendingInvoices = make(map[lntypes.Hash]*channeldb.Invoice)
|
||||||
|
|
||||||
|
expiredCreationDate := now.Add(-24 * time.Hour)
|
||||||
|
|
||||||
|
for i := 1; i <= numExpired; i++ {
|
||||||
|
var preimage lntypes.Preimage
|
||||||
|
binary.BigEndian.PutUint32(preimage[:4], uint32(offset+i))
|
||||||
|
expiry := time.Duration((i+offset)%24) * time.Hour
|
||||||
|
invoice := newTestInvoice(t, preimage, expiredCreationDate, expiry)
|
||||||
|
testData.expiredInvoices[preimage.Hash()] = invoice
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 1; i <= numPending; i++ {
|
||||||
|
var preimage lntypes.Preimage
|
||||||
|
binary.BigEndian.PutUint32(preimage[4:], uint32(offset+i))
|
||||||
|
expiry := time.Duration((i+offset)%24) * time.Hour
|
||||||
|
invoice := newTestInvoice(t, preimage, now, expiry)
|
||||||
|
testData.pendingInvoices[preimage.Hash()] = invoice
|
||||||
|
}
|
||||||
|
|
||||||
|
return testData
|
||||||
|
}
|
@ -1,26 +0,0 @@
|
|||||||
package invoices
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"runtime/pprof"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// timeout implements a test level timeout.
|
|
||||||
func timeout(t *testing.T) func() {
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case <-time.After(5 * time.Second):
|
|
||||||
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
|
|
||||||
|
|
||||||
panic("test timeout")
|
|
||||||
case <-done:
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return func() {
|
|
||||||
close(done)
|
|
||||||
}
|
|
||||||
}
|
|
@ -9,6 +9,8 @@ import (
|
|||||||
// PriorityQueue will be able to use that to build and restore an underlying
|
// PriorityQueue will be able to use that to build and restore an underlying
|
||||||
// heap.
|
// heap.
|
||||||
type PriorityQueueItem interface {
|
type PriorityQueueItem interface {
|
||||||
|
// Less must return true if this item is ordered before other and false
|
||||||
|
// otherwise.
|
||||||
Less(other PriorityQueueItem) bool
|
Less(other PriorityQueueItem) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -43,7 +45,7 @@ func (pq *priorityQueue) Pop() interface{} {
|
|||||||
return item
|
return item
|
||||||
}
|
}
|
||||||
|
|
||||||
// Priority wrap a standard heap in a more object-oriented structure.
|
// PriorityQueue wraps a standard heap into a self contained class.
|
||||||
type PriorityQueue struct {
|
type PriorityQueue struct {
|
||||||
queue priorityQueue
|
queue priorityQueue
|
||||||
}
|
}
|
||||||
|
@ -33,6 +33,7 @@ import (
|
|||||||
"github.com/lightningnetwork/lnd/chanfitness"
|
"github.com/lightningnetwork/lnd/chanfitness"
|
||||||
"github.com/lightningnetwork/lnd/channeldb"
|
"github.com/lightningnetwork/lnd/channeldb"
|
||||||
"github.com/lightningnetwork/lnd/channelnotifier"
|
"github.com/lightningnetwork/lnd/channelnotifier"
|
||||||
|
"github.com/lightningnetwork/lnd/clock"
|
||||||
"github.com/lightningnetwork/lnd/contractcourt"
|
"github.com/lightningnetwork/lnd/contractcourt"
|
||||||
"github.com/lightningnetwork/lnd/discovery"
|
"github.com/lightningnetwork/lnd/discovery"
|
||||||
"github.com/lightningnetwork/lnd/feature"
|
"github.com/lightningnetwork/lnd/feature"
|
||||||
@ -381,8 +382,7 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB,
|
|||||||
registryConfig := invoices.RegistryConfig{
|
registryConfig := invoices.RegistryConfig{
|
||||||
FinalCltvRejectDelta: defaultFinalCltvRejectDelta,
|
FinalCltvRejectDelta: defaultFinalCltvRejectDelta,
|
||||||
HtlcHoldDuration: invoices.DefaultHtlcHoldDuration,
|
HtlcHoldDuration: invoices.DefaultHtlcHoldDuration,
|
||||||
Now: time.Now,
|
Clock: clock.NewDefaultClock(),
|
||||||
TickAfter: time.After,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s := &server{
|
s := &server{
|
||||||
@ -393,7 +393,10 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB,
|
|||||||
readPool: readPool,
|
readPool: readPool,
|
||||||
chansToRestore: chansToRestore,
|
chansToRestore: chansToRestore,
|
||||||
|
|
||||||
invoices: invoices.NewRegistry(chanDB, ®istryConfig),
|
invoices: invoices.NewRegistry(
|
||||||
|
chanDB, invoices.NewInvoiceExpiryWatcher(clock.NewDefaultClock()),
|
||||||
|
®istryConfig,
|
||||||
|
),
|
||||||
|
|
||||||
channelNotifier: channelnotifier.New(chanDB),
|
channelNotifier: channelnotifier.New(chanDB),
|
||||||
|
|
||||||
|
@ -82,6 +82,10 @@ const (
|
|||||||
// This is chosen to be the maximum number of bytes that can fit into a
|
// This is chosen to be the maximum number of bytes that can fit into a
|
||||||
// single QR code: https://en.wikipedia.org/wiki/QR_code#Storage
|
// single QR code: https://en.wikipedia.org/wiki/QR_code#Storage
|
||||||
maxInvoiceLength = 7089
|
maxInvoiceLength = 7089
|
||||||
|
|
||||||
|
// DefaultInvoiceExpiry is the default expiry duration from the creation
|
||||||
|
// timestamp if expiry is set to zero.
|
||||||
|
DefaultInvoiceExpiry = time.Hour
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
Loading…
Reference in New Issue
Block a user