Merge pull request #3694 from bhandras/i3448
invoices+channeldb: reject payments to expired invoices
This commit is contained in:
commit
e34bc3d645
@ -2,6 +2,7 @@ package channeldb
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
mrand "math/rand"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
@ -393,6 +394,114 @@ func TestInvoiceAddTimeSeries(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Tests that FetchAllInvoicesWithPaymentHash returns all invoices with their
|
||||
// corresponding payment hashes.
|
||||
func TestFetchAllInvoicesWithPaymentHash(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, cleanup, err := makeTestDB()
|
||||
defer cleanup()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test db: %v", err)
|
||||
}
|
||||
|
||||
// With an empty DB we expect to return no error and an empty list.
|
||||
empty, err := db.FetchAllInvoicesWithPaymentHash(false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to call FetchAllInvoicesWithPaymentHash on empty DB: %v",
|
||||
err)
|
||||
}
|
||||
|
||||
if len(empty) != 0 {
|
||||
t.Fatalf("expected empty list as a result, got: %v", empty)
|
||||
}
|
||||
|
||||
// Now populate the DB and check if we can get all invoices with their
|
||||
// payment hashes as expected.
|
||||
const numInvoices = 20
|
||||
testPendingInvoices := make(map[lntypes.Hash]*Invoice)
|
||||
testAllInvoices := make(map[lntypes.Hash]*Invoice)
|
||||
|
||||
states := []ContractState{
|
||||
ContractOpen, ContractSettled, ContractCanceled, ContractAccepted,
|
||||
}
|
||||
|
||||
for i := lnwire.MilliSatoshi(1); i <= numInvoices; i++ {
|
||||
invoice, err := randInvoice(i)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create invoice: %v", err)
|
||||
}
|
||||
|
||||
invoice.State = states[mrand.Intn(len(states))]
|
||||
paymentHash := invoice.Terms.PaymentPreimage.Hash()
|
||||
|
||||
if invoice.State != ContractSettled && invoice.State != ContractCanceled {
|
||||
testPendingInvoices[paymentHash] = invoice
|
||||
}
|
||||
|
||||
testAllInvoices[paymentHash] = invoice
|
||||
|
||||
if _, err := db.AddInvoice(invoice, paymentHash); err != nil {
|
||||
t.Fatalf("unable to add invoice: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
pendingInvoices, err := db.FetchAllInvoicesWithPaymentHash(true)
|
||||
if err != nil {
|
||||
t.Fatalf("can't fetch invoices with payment hash: %v", err)
|
||||
}
|
||||
|
||||
if len(testPendingInvoices) != len(pendingInvoices) {
|
||||
t.Fatalf("expected %v pending invoices, got: %v",
|
||||
len(testPendingInvoices), len(pendingInvoices))
|
||||
}
|
||||
|
||||
allInvoices, err := db.FetchAllInvoicesWithPaymentHash(false)
|
||||
if err != nil {
|
||||
t.Fatalf("can't fetch invoices with payment hash: %v", err)
|
||||
}
|
||||
|
||||
if len(testAllInvoices) != len(allInvoices) {
|
||||
t.Fatalf("expected %v invoices, got: %v",
|
||||
len(testAllInvoices), len(allInvoices))
|
||||
}
|
||||
|
||||
for i := range pendingInvoices {
|
||||
expected, ok := testPendingInvoices[pendingInvoices[i].PaymentHash]
|
||||
if !ok {
|
||||
t.Fatalf("coulnd't find invoice with hash: %v",
|
||||
pendingInvoices[i].PaymentHash)
|
||||
}
|
||||
|
||||
// Zero out add index to not confuse DeepEqual.
|
||||
pendingInvoices[i].Invoice.AddIndex = 0
|
||||
expected.AddIndex = 0
|
||||
|
||||
if !reflect.DeepEqual(*expected, pendingInvoices[i].Invoice) {
|
||||
t.Fatalf("expected: %v, got: %v",
|
||||
spew.Sdump(expected), spew.Sdump(pendingInvoices[i].Invoice))
|
||||
}
|
||||
}
|
||||
|
||||
for i := range allInvoices {
|
||||
expected, ok := testAllInvoices[allInvoices[i].PaymentHash]
|
||||
if !ok {
|
||||
t.Fatalf("coulnd't find invoice with hash: %v",
|
||||
allInvoices[i].PaymentHash)
|
||||
}
|
||||
|
||||
// Zero out add index to not confuse DeepEqual.
|
||||
allInvoices[i].Invoice.AddIndex = 0
|
||||
expected.AddIndex = 0
|
||||
|
||||
if !reflect.DeepEqual(*expected, allInvoices[i].Invoice) {
|
||||
t.Fatalf("expected: %v, got: %v",
|
||||
spew.Sdump(expected), spew.Sdump(allInvoices[i].Invoice))
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// TestDuplicateSettleInvoice tests that if we add a new invoice and settle it
|
||||
// twice, then the second time we also receive the invoice that we settled as a
|
||||
// return argument.
|
||||
|
@ -565,6 +565,83 @@ func (d *DB) LookupInvoice(paymentHash [32]byte) (Invoice, error) {
|
||||
return invoice, nil
|
||||
}
|
||||
|
||||
// InvoiceWithPaymentHash is used to store an invoice and its corresponding
|
||||
// payment hash. This struct is only used to store results of
|
||||
// ChannelDB.FetchAllInvoicesWithPaymentHash() call.
|
||||
type InvoiceWithPaymentHash struct {
|
||||
// Invoice holds the invoice as selected from the invoices bucket.
|
||||
Invoice Invoice
|
||||
|
||||
// PaymentHash is the payment hash for the Invoice.
|
||||
PaymentHash lntypes.Hash
|
||||
}
|
||||
|
||||
// FetchAllInvoicesWithPaymentHash returns all invoices and their payment hashes
|
||||
// currently stored within the database. If the pendingOnly param is true, then
|
||||
// only unsettled invoices and their payment hashes will be returned, skipping
|
||||
// all invoices that are fully settled or canceled. Note that the returned
|
||||
// array is not ordered by add index.
|
||||
func (d *DB) FetchAllInvoicesWithPaymentHash(pendingOnly bool) (
|
||||
[]InvoiceWithPaymentHash, error) {
|
||||
|
||||
var result []InvoiceWithPaymentHash
|
||||
|
||||
err := d.View(func(tx *bbolt.Tx) error {
|
||||
invoices := tx.Bucket(invoiceBucket)
|
||||
if invoices == nil {
|
||||
return ErrNoInvoicesCreated
|
||||
}
|
||||
|
||||
invoiceIndex := invoices.Bucket(invoiceIndexBucket)
|
||||
if invoiceIndex == nil {
|
||||
// Mask the error if there's no invoice
|
||||
// index as that simply means there are no
|
||||
// invoices added yet to the DB. In this case
|
||||
// we simply return an empty list.
|
||||
return nil
|
||||
}
|
||||
|
||||
return invoiceIndex.ForEach(func(k, v []byte) error {
|
||||
// Skip the special numInvoicesKey as that does not
|
||||
// point to a valid invoice.
|
||||
if bytes.Equal(k, numInvoicesKey) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
invoice, err := fetchInvoice(v, invoices)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if pendingOnly &&
|
||||
(invoice.State == ContractSettled ||
|
||||
invoice.State == ContractCanceled) {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
invoiceWithPaymentHash := InvoiceWithPaymentHash{
|
||||
Invoice: invoice,
|
||||
}
|
||||
|
||||
copy(invoiceWithPaymentHash.PaymentHash[:], k)
|
||||
result = append(result, invoiceWithPaymentHash)
|
||||
|
||||
return nil
|
||||
})
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FetchAllInvoices returns all invoices currently stored within the database.
|
||||
// If the pendingOnly param is true, then only unsettled invoices will be
|
||||
// returned, skipping all invoices that are fully settled.
|
||||
|
24
clock/default_clock.go
Normal file
24
clock/default_clock.go
Normal file
@ -0,0 +1,24 @@
|
||||
package clock
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// DefaultClock implements Clock interface by simply calling the appropriate
|
||||
// time functions.
|
||||
type DefaultClock struct{}
|
||||
|
||||
// NewDefaultClock constructs a new DefaultClock.
|
||||
func NewDefaultClock() Clock {
|
||||
return &DefaultClock{}
|
||||
}
|
||||
|
||||
// Now simply returns time.Now().
|
||||
func (DefaultClock) Now() time.Time {
|
||||
return time.Now()
|
||||
}
|
||||
|
||||
// TickAfter simply wraps time.After().
|
||||
func (DefaultClock) TickAfter(duration time.Duration) <-chan time.Time {
|
||||
return time.After(duration)
|
||||
}
|
16
clock/interface.go
Normal file
16
clock/interface.go
Normal file
@ -0,0 +1,16 @@
|
||||
package clock
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Clock is an interface that provides a time functions for LND packages.
|
||||
// This is useful during testing when a concrete time reference is needed.
|
||||
type Clock interface {
|
||||
// Now returns the current local time (as defined by the Clock).
|
||||
Now() time.Time
|
||||
|
||||
// TickAfter returns a channel that will receive a tick after the specified
|
||||
// duration has passed.
|
||||
TickAfter(duration time.Duration) <-chan time.Time
|
||||
}
|
@ -1,42 +1,40 @@
|
||||
package invoices
|
||||
package clock
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// testClock can be used in tests to mock time.
|
||||
type testClock struct {
|
||||
// TestClock can be used in tests to mock time.
|
||||
type TestClock struct {
|
||||
currentTime time.Time
|
||||
timeChanMap map[time.Time][]chan time.Time
|
||||
timeLock sync.Mutex
|
||||
}
|
||||
|
||||
// newTestClock returns a new test clock.
|
||||
func newTestClock(startTime time.Time) *testClock {
|
||||
return &testClock{
|
||||
// NewTestClock returns a new test clock.
|
||||
func NewTestClock(startTime time.Time) *TestClock {
|
||||
return &TestClock{
|
||||
currentTime: startTime,
|
||||
timeChanMap: make(map[time.Time][]chan time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
// now returns the current (test) time.
|
||||
func (c *testClock) now() time.Time {
|
||||
// Now returns the current (test) time.
|
||||
func (c *TestClock) Now() time.Time {
|
||||
c.timeLock.Lock()
|
||||
defer c.timeLock.Unlock()
|
||||
|
||||
return c.currentTime
|
||||
}
|
||||
|
||||
// tickAfter returns a channel that will receive a tick at the specified time.
|
||||
func (c *testClock) tickAfter(duration time.Duration) <-chan time.Time {
|
||||
// TickAfter returns a channel that will receive a tick after the specified
|
||||
// duration has passed passed by the user set test time.
|
||||
func (c *TestClock) TickAfter(duration time.Duration) <-chan time.Time {
|
||||
c.timeLock.Lock()
|
||||
defer c.timeLock.Unlock()
|
||||
|
||||
triggerTime := c.currentTime.Add(duration)
|
||||
log.Debugf("tickAfter called: duration=%v, trigger_time=%v",
|
||||
duration, triggerTime)
|
||||
|
||||
ch := make(chan time.Time, 1)
|
||||
|
||||
// If already expired, tick immediately.
|
||||
@ -53,8 +51,8 @@ func (c *testClock) tickAfter(duration time.Duration) <-chan time.Time {
|
||||
return ch
|
||||
}
|
||||
|
||||
// setTime sets the (test) time and triggers tick channels when they expire.
|
||||
func (c *testClock) setTime(now time.Time) {
|
||||
// SetTime sets the (test) time and triggers tick channels when they expire.
|
||||
func (c *TestClock) SetTime(now time.Time) {
|
||||
c.timeLock.Lock()
|
||||
defer c.timeLock.Unlock()
|
||||
|
63
clock/test_clock_test.go
Normal file
63
clock/test_clock_test.go
Normal file
@ -0,0 +1,63 @@
|
||||
package clock
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
testTime = time.Date(2009, time.January, 3, 12, 0, 0, 0, time.UTC)
|
||||
)
|
||||
|
||||
func TestNow(t *testing.T) {
|
||||
c := NewTestClock(testTime)
|
||||
now := c.Now()
|
||||
|
||||
if now != testTime {
|
||||
t.Fatalf("expected: %v, got: %v", testTime, now)
|
||||
}
|
||||
|
||||
now = now.Add(time.Hour)
|
||||
c.SetTime(now)
|
||||
if c.Now() != now {
|
||||
t.Fatalf("epected: %v, got: %v", now, c.Now())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTickAfter(t *testing.T) {
|
||||
c := NewTestClock(testTime)
|
||||
|
||||
// Should be ticking immediately.
|
||||
ticker0 := c.TickAfter(0)
|
||||
|
||||
// Both should be ticking after SetTime
|
||||
ticker1 := c.TickAfter(time.Hour)
|
||||
ticker2 := c.TickAfter(time.Hour)
|
||||
|
||||
// We don't expect this one to tick.
|
||||
ticker3 := c.TickAfter(2 * time.Hour)
|
||||
|
||||
tickOrTimeOut := func(ticker <-chan time.Time, expectTick bool) {
|
||||
tick := false
|
||||
select {
|
||||
case <-ticker:
|
||||
tick = true
|
||||
case <-time.After(time.Millisecond):
|
||||
}
|
||||
|
||||
if tick != expectTick {
|
||||
t.Fatalf("expected tick: %v, ticked: %v", expectTick, tick)
|
||||
}
|
||||
}
|
||||
|
||||
tickOrTimeOut(ticker0, true)
|
||||
tickOrTimeOut(ticker1, false)
|
||||
tickOrTimeOut(ticker2, false)
|
||||
tickOrTimeOut(ticker3, false)
|
||||
|
||||
c.SetTime(c.Now().Add(time.Hour))
|
||||
|
||||
tickOrTimeOut(ticker1, true)
|
||||
tickOrTimeOut(ticker2, true)
|
||||
tickOrTimeOut(ticker3, false)
|
||||
}
|
@ -22,6 +22,7 @@ import (
|
||||
sphinx "github.com/lightningnetwork/lightning-onion"
|
||||
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/clock"
|
||||
"github.com/lightningnetwork/lnd/contractcourt"
|
||||
"github.com/lightningnetwork/lnd/htlcswitch/hop"
|
||||
"github.com/lightningnetwork/lnd/input"
|
||||
@ -792,6 +793,7 @@ func newMockRegistry(minDelta uint32) *mockInvoiceRegistry {
|
||||
|
||||
registry := invoices.NewRegistry(
|
||||
cdb,
|
||||
invoices.NewInvoiceExpiryWatcher(clock.NewDefaultClock()),
|
||||
&invoices.RegistryConfig{
|
||||
FinalCltvRejectDelta: 5,
|
||||
},
|
||||
|
191
invoices/invoice_expiry_watcher.go
Normal file
191
invoices/invoice_expiry_watcher.go
Normal file
@ -0,0 +1,191 @@
|
||||
package invoices
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/clock"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/queue"
|
||||
"github.com/lightningnetwork/lnd/zpay32"
|
||||
)
|
||||
|
||||
// invoiceExpiry holds and invoice's payment hash and its expiry. This
|
||||
// is used to order invoices by their expiry for cancellation.
|
||||
type invoiceExpiry struct {
|
||||
PaymentHash lntypes.Hash
|
||||
Expiry time.Time
|
||||
}
|
||||
|
||||
// Less implements PriorityQueueItem.Less such that the top item in the
|
||||
// priorty queue will be the one that expires next.
|
||||
func (e invoiceExpiry) Less(other queue.PriorityQueueItem) bool {
|
||||
return e.Expiry.Before(other.(*invoiceExpiry).Expiry)
|
||||
}
|
||||
|
||||
// InvoiceExpiryWatcher handles automatic invoice cancellation of expried
|
||||
// invoices. Upon start InvoiceExpiryWatcher will retrieve all pending (not yet
|
||||
// settled or canceled) invoices invoices to its watcing queue. When a new
|
||||
// invoice is added to the InvoiceRegistry, it'll be forarded to the
|
||||
// InvoiceExpiryWatcher and will end up in the watching queue as well.
|
||||
// If any of the watched invoices expire, they'll be removed from the watching
|
||||
// queue and will be cancelled through InvoiceRegistry.CancelInvoice().
|
||||
type InvoiceExpiryWatcher struct {
|
||||
sync.Mutex
|
||||
started bool
|
||||
|
||||
// clock is the clock implementation that InvoiceExpiryWatcher uses.
|
||||
// It is useful for testing.
|
||||
clock clock.Clock
|
||||
|
||||
// cancelInvoice is a template method that cancels an expired invoice.
|
||||
cancelInvoice func(lntypes.Hash) error
|
||||
|
||||
// expiryQueue holds invoiceExpiry items and is used to find the next
|
||||
// invoice to expire.
|
||||
expiryQueue queue.PriorityQueue
|
||||
|
||||
// newInvoices channel is used to wake up the main loop when a new invoices
|
||||
// is added.
|
||||
newInvoices chan *invoiceExpiry
|
||||
|
||||
wg sync.WaitGroup
|
||||
|
||||
// quit signals InvoiceExpiryWatcher to stop.
|
||||
quit chan struct{}
|
||||
}
|
||||
|
||||
// NewInvoiceExpiryWatcher creates a new InvoiceExpiryWatcher instance.
|
||||
func NewInvoiceExpiryWatcher(clock clock.Clock) *InvoiceExpiryWatcher {
|
||||
return &InvoiceExpiryWatcher{
|
||||
clock: clock,
|
||||
newInvoices: make(chan *invoiceExpiry),
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the the subscription handler and the main loop. Start() will
|
||||
// return with error if InvoiceExpiryWatcher is already started. Start()
|
||||
// expects a cancellation function passed that will be use to cancel expired
|
||||
// invoices by their payment hash.
|
||||
func (ew *InvoiceExpiryWatcher) Start(
|
||||
cancelInvoice func(lntypes.Hash) error) error {
|
||||
|
||||
ew.Lock()
|
||||
defer ew.Unlock()
|
||||
|
||||
if ew.started {
|
||||
return fmt.Errorf("InvoiceExpiryWatcher already started")
|
||||
}
|
||||
|
||||
ew.started = true
|
||||
ew.cancelInvoice = cancelInvoice
|
||||
ew.wg.Add(1)
|
||||
go ew.mainLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop quits the expiry handler loop and waits for InvoiceExpiryWatcher to
|
||||
// fully stop.
|
||||
func (ew *InvoiceExpiryWatcher) Stop() {
|
||||
ew.Lock()
|
||||
defer ew.Unlock()
|
||||
|
||||
if ew.started {
|
||||
// Signal subscriptionHandler to quit and wait for it to return.
|
||||
close(ew.quit)
|
||||
ew.wg.Wait()
|
||||
ew.started = false
|
||||
}
|
||||
}
|
||||
|
||||
// AddInvoice adds a new invoice to the InvoiceExpiryWatcher. This won't check
|
||||
// if the invoice is already added and will only add invoices with ContractOpen
|
||||
// state.
|
||||
func (ew *InvoiceExpiryWatcher) AddInvoice(
|
||||
paymentHash lntypes.Hash, invoice *channeldb.Invoice) {
|
||||
|
||||
if invoice.State != channeldb.ContractOpen {
|
||||
log.Debugf("Invoice not added to expiry watcher: %v", invoice)
|
||||
return
|
||||
}
|
||||
|
||||
realExpiry := invoice.Terms.Expiry
|
||||
if realExpiry == 0 {
|
||||
realExpiry = zpay32.DefaultInvoiceExpiry
|
||||
}
|
||||
|
||||
expiry := invoice.CreationDate.Add(realExpiry)
|
||||
|
||||
log.Debugf("Adding invoice '%v' to expiry watcher, expiration: %v",
|
||||
paymentHash, expiry)
|
||||
|
||||
newInvoiceExpiry := &invoiceExpiry{
|
||||
PaymentHash: paymentHash,
|
||||
Expiry: expiry,
|
||||
}
|
||||
|
||||
select {
|
||||
case ew.newInvoices <- newInvoiceExpiry:
|
||||
case <-ew.quit:
|
||||
// Select on quit too so that callers won't get blocked in case
|
||||
// of concurrent shutdown.
|
||||
}
|
||||
}
|
||||
|
||||
// nextExpiry returns a Time chan to wait on until the next invoice expires.
|
||||
// If there are no active invoices, then it'll simply wait indefinitely.
|
||||
func (ew *InvoiceExpiryWatcher) nextExpiry() <-chan time.Time {
|
||||
if !ew.expiryQueue.Empty() {
|
||||
top := ew.expiryQueue.Top().(*invoiceExpiry)
|
||||
return ew.clock.TickAfter(top.Expiry.Sub(ew.clock.Now()))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cancelExpiredInvoices will cancel all expired invoices and removes them from
|
||||
// the expiry queue.
|
||||
func (ew *InvoiceExpiryWatcher) cancelExpiredInvoices() {
|
||||
for !ew.expiryQueue.Empty() {
|
||||
top := ew.expiryQueue.Top().(*invoiceExpiry)
|
||||
if !top.Expiry.Before(ew.clock.Now()) {
|
||||
break
|
||||
}
|
||||
|
||||
err := ew.cancelInvoice(top.PaymentHash)
|
||||
if err != nil && err != channeldb.ErrInvoiceAlreadySettled &&
|
||||
err != channeldb.ErrInvoiceAlreadyCanceled {
|
||||
|
||||
log.Errorf("Unable to cancel invoice: %v", top.PaymentHash)
|
||||
}
|
||||
|
||||
ew.expiryQueue.Pop()
|
||||
}
|
||||
}
|
||||
|
||||
// mainLoop is a goroutine that receives new invoices and handles cancellation
|
||||
// of expired invoices.
|
||||
func (ew *InvoiceExpiryWatcher) mainLoop() {
|
||||
defer ew.wg.Done()
|
||||
|
||||
for {
|
||||
// Cancel any invoices that may have expired.
|
||||
ew.cancelExpiredInvoices()
|
||||
|
||||
select {
|
||||
case <-ew.nextExpiry():
|
||||
// Wait until the next invoice expires, then cancel expired invoices.
|
||||
continue
|
||||
|
||||
case newInvoiceExpiry := <-ew.newInvoices:
|
||||
ew.expiryQueue.Push(newInvoiceExpiry)
|
||||
|
||||
case <-ew.quit:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
125
invoices/invoice_expiry_watcher_test.go
Normal file
125
invoices/invoice_expiry_watcher_test.go
Normal file
@ -0,0 +1,125 @@
|
||||
package invoices
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lightningnetwork/lnd/clock"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
)
|
||||
|
||||
// invoiceExpiryWatcherTest holds a test fixture and implements checks
|
||||
// for InvoiceExpiryWatcher tests.
|
||||
type invoiceExpiryWatcherTest struct {
|
||||
t *testing.T
|
||||
watcher *InvoiceExpiryWatcher
|
||||
testData invoiceExpiryTestData
|
||||
canceledInvoices []lntypes.Hash
|
||||
}
|
||||
|
||||
// newInvoiceExpiryWatcherTest creates a new InvoiceExpiryWatcher test fixture
|
||||
// and sets up the test environment.
|
||||
func newInvoiceExpiryWatcherTest(t *testing.T, now time.Time,
|
||||
numExpiredInvoices, numPendingInvoices int) *invoiceExpiryWatcherTest {
|
||||
|
||||
test := &invoiceExpiryWatcherTest{
|
||||
watcher: NewInvoiceExpiryWatcher(clock.NewTestClock(testTime)),
|
||||
testData: generateInvoiceExpiryTestData(
|
||||
t, now, 0, numExpiredInvoices, numPendingInvoices,
|
||||
),
|
||||
}
|
||||
|
||||
err := test.watcher.Start(func(paymentHash lntypes.Hash) error {
|
||||
test.canceledInvoices = append(test.canceledInvoices, paymentHash)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("cannot start InvoiceExpiryWatcher: %v", err)
|
||||
}
|
||||
|
||||
return test
|
||||
}
|
||||
|
||||
func (t *invoiceExpiryWatcherTest) checkExpectations() {
|
||||
// Check that invoices that got canceled during the test are the ones
|
||||
// that expired.
|
||||
if len(t.canceledInvoices) != len(t.testData.expiredInvoices) {
|
||||
t.t.Fatalf("expected %v cancellations, got %v",
|
||||
len(t.testData.expiredInvoices), len(t.canceledInvoices))
|
||||
}
|
||||
|
||||
for i := range t.canceledInvoices {
|
||||
if _, ok := t.testData.expiredInvoices[t.canceledInvoices[i]]; !ok {
|
||||
t.t.Fatalf("wrong invoice canceled")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tests that InvoiceExpiryWatcher can be started and stopped.
|
||||
func TestInvoiceExpiryWatcherStartStop(t *testing.T) {
|
||||
watcher := NewInvoiceExpiryWatcher(clock.NewTestClock(testTime))
|
||||
cancel := func(lntypes.Hash) error {
|
||||
t.Fatalf("unexpected call")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := watcher.Start(cancel); err != nil {
|
||||
t.Fatalf("unexpected error upon start: %v", err)
|
||||
}
|
||||
|
||||
if err := watcher.Start(cancel); err == nil {
|
||||
t.Fatalf("expected error upon second start")
|
||||
}
|
||||
|
||||
watcher.Stop()
|
||||
|
||||
if err := watcher.Start(cancel); err != nil {
|
||||
t.Fatalf("unexpected error upon start: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Tests that no invoices will expire from an empty InvoiceExpiryWatcher.
|
||||
func TestInvoiceExpiryWithNoInvoices(t *testing.T) {
|
||||
t.Parallel()
|
||||
test := newInvoiceExpiryWatcherTest(t, testTime, 0, 0)
|
||||
|
||||
time.Sleep(testTimeout)
|
||||
test.watcher.Stop()
|
||||
test.checkExpectations()
|
||||
}
|
||||
|
||||
// Tests that if all add invoices are expired, then all invoices
|
||||
// will be canceled.
|
||||
func TestInvoiceExpiryWithOnlyExpiredInvoices(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
test := newInvoiceExpiryWatcherTest(t, testTime, 0, 5)
|
||||
|
||||
for paymentHash, invoice := range test.testData.pendingInvoices {
|
||||
test.watcher.AddInvoice(paymentHash, invoice)
|
||||
}
|
||||
|
||||
time.Sleep(testTimeout)
|
||||
test.watcher.Stop()
|
||||
test.checkExpectations()
|
||||
}
|
||||
|
||||
// Tests that if some invoices are expired, then those invoices
|
||||
// will be canceled.
|
||||
func TestInvoiceExpiryWithPendingAndExpiredInvoices(t *testing.T) {
|
||||
t.Parallel()
|
||||
test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5)
|
||||
|
||||
for paymentHash, invoice := range test.testData.expiredInvoices {
|
||||
test.watcher.AddInvoice(paymentHash, invoice)
|
||||
}
|
||||
|
||||
for paymentHash, invoice := range test.testData.pendingInvoices {
|
||||
test.watcher.AddInvoice(paymentHash, invoice)
|
||||
}
|
||||
|
||||
time.Sleep(testTimeout)
|
||||
test.watcher.Stop()
|
||||
test.checkExpectations()
|
||||
}
|
@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/clock"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/queue"
|
||||
@ -62,12 +63,10 @@ type RegistryConfig struct {
|
||||
// waiting for the other set members to arrive.
|
||||
HtlcHoldDuration time.Duration
|
||||
|
||||
// Now returns the current time.
|
||||
Now func() time.Time
|
||||
|
||||
// TickAfter returns a channel that is sent on after the specified
|
||||
// duration as passed.
|
||||
TickAfter func(duration time.Duration) <-chan time.Time
|
||||
// Clock holds the clock implementation that is used to provide
|
||||
// Now() and TickAfter() and is useful to stub out the clock functions
|
||||
// during testing.
|
||||
Clock clock.Clock
|
||||
}
|
||||
|
||||
// htlcReleaseEvent describes an htlc auto-release event. It is used to release
|
||||
@ -126,6 +125,8 @@ type InvoiceRegistry struct {
|
||||
// auto-released.
|
||||
htlcAutoReleaseChan chan *htlcReleaseEvent
|
||||
|
||||
expiryWatcher *InvoiceExpiryWatcher
|
||||
|
||||
wg sync.WaitGroup
|
||||
quit chan struct{}
|
||||
}
|
||||
@ -134,7 +135,9 @@ type InvoiceRegistry struct {
|
||||
// wraps the persistent on-disk invoice storage with an additional in-memory
|
||||
// layer. The in-memory layer is in place such that debug invoices can be added
|
||||
// which are volatile yet available system wide within the daemon.
|
||||
func NewRegistry(cdb *channeldb.DB, cfg *RegistryConfig) *InvoiceRegistry {
|
||||
func NewRegistry(cdb *channeldb.DB, expiryWatcher *InvoiceExpiryWatcher,
|
||||
cfg *RegistryConfig) *InvoiceRegistry {
|
||||
|
||||
return &InvoiceRegistry{
|
||||
cdb: cdb,
|
||||
notificationClients: make(map[uint32]*InvoiceSubscription),
|
||||
@ -146,21 +149,62 @@ func NewRegistry(cdb *channeldb.DB, cfg *RegistryConfig) *InvoiceRegistry {
|
||||
hodlReverseSubscriptions: make(map[chan<- interface{}]map[channeldb.CircuitKey]struct{}),
|
||||
cfg: cfg,
|
||||
htlcAutoReleaseChan: make(chan *htlcReleaseEvent),
|
||||
expiryWatcher: expiryWatcher,
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// populateExpiryWatcher fetches all active invoices and their corresponding
|
||||
// payment hashes from ChannelDB and adds them to the expiry watcher.
|
||||
func (i *InvoiceRegistry) populateExpiryWatcher() error {
|
||||
pendingOnly := true
|
||||
pendingInvoices, err := i.cdb.FetchAllInvoicesWithPaymentHash(pendingOnly)
|
||||
if err != nil && err != channeldb.ErrNoInvoicesCreated {
|
||||
log.Errorf(
|
||||
"Error while prefetching active invoices from the database: %v", err,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
for idx := range pendingInvoices {
|
||||
i.expiryWatcher.AddInvoice(
|
||||
pendingInvoices[idx].PaymentHash, &pendingInvoices[idx].Invoice,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start starts the registry and all goroutines it needs to carry out its task.
|
||||
func (i *InvoiceRegistry) Start() error {
|
||||
i.wg.Add(1)
|
||||
// Start InvoiceExpiryWatcher and prepopulate it with existing active
|
||||
// invoices.
|
||||
err := i.expiryWatcher.Start(func(paymentHash lntypes.Hash) error {
|
||||
cancelIfAccepted := false
|
||||
return i.cancelInvoiceImpl(paymentHash, cancelIfAccepted)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
i.wg.Add(1)
|
||||
go i.invoiceEventLoop()
|
||||
|
||||
// Now prefetch all pending invoices to the expiry watcher.
|
||||
err = i.populateExpiryWatcher()
|
||||
if err != nil {
|
||||
i.Stop()
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop signals the registry for a graceful shutdown.
|
||||
func (i *InvoiceRegistry) Stop() {
|
||||
i.expiryWatcher.Stop()
|
||||
|
||||
close(i.quit)
|
||||
|
||||
i.wg.Wait()
|
||||
@ -177,8 +221,8 @@ type invoiceEvent struct {
|
||||
// tickAt returns a channel that ticks at the specified time. If the time has
|
||||
// already passed, it will tick immediately.
|
||||
func (i *InvoiceRegistry) tickAt(t time.Time) <-chan time.Time {
|
||||
now := i.cfg.Now()
|
||||
return i.cfg.TickAfter(t.Sub(now))
|
||||
now := i.cfg.Clock.Now()
|
||||
return i.cfg.Clock.TickAfter(t.Sub(now))
|
||||
}
|
||||
|
||||
// invoiceEventLoop is the dedicated goroutine responsible for accepting
|
||||
@ -471,7 +515,6 @@ func (i *InvoiceRegistry) AddInvoice(invoice *channeldb.Invoice,
|
||||
paymentHash lntypes.Hash) (uint64, error) {
|
||||
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
log.Debugf("Invoice(%v): added %v", paymentHash,
|
||||
newLogClosure(func() string {
|
||||
@ -481,12 +524,19 @@ func (i *InvoiceRegistry) AddInvoice(invoice *channeldb.Invoice,
|
||||
|
||||
addIndex, err := i.cdb.AddInvoice(invoice, paymentHash)
|
||||
if err != nil {
|
||||
i.Unlock()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Now that we've added the invoice, we'll send dispatch a message to
|
||||
// notify the clients of this new invoice.
|
||||
i.notifyClients(paymentHash, invoice, channeldb.ContractOpen)
|
||||
i.Unlock()
|
||||
|
||||
// InvoiceExpiryWatcher.AddInvoice must not be locked by InvoiceRegistry
|
||||
// to avoid deadlock when a new invoice is added while an other is being
|
||||
// canceled.
|
||||
i.expiryWatcher.AddInvoice(paymentHash, invoice)
|
||||
|
||||
return addIndex, nil
|
||||
}
|
||||
@ -818,6 +868,15 @@ func (i *InvoiceRegistry) SettleHodlInvoice(preimage lntypes.Preimage) error {
|
||||
// CancelInvoice attempts to cancel the invoice corresponding to the passed
|
||||
// payment hash.
|
||||
func (i *InvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error {
|
||||
return i.cancelInvoiceImpl(payHash, true)
|
||||
}
|
||||
|
||||
// cancelInvoice attempts to cancel the invoice corresponding to the passed
|
||||
// payment hash. Accepted invoices will only be canceled if explicitly
|
||||
// requested to do so.
|
||||
func (i *InvoiceRegistry) cancelInvoiceImpl(payHash lntypes.Hash,
|
||||
cancelAccepted bool) error {
|
||||
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
@ -826,6 +885,12 @@ func (i *InvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error {
|
||||
updateInvoice := func(invoice *channeldb.Invoice) (
|
||||
*channeldb.InvoiceUpdateDesc, error) {
|
||||
|
||||
// Only cancel the invoice in ContractAccepted state if explicitly
|
||||
// requested to do so.
|
||||
if invoice.State == channeldb.ContractAccepted && !cancelAccepted {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Move invoice to the canceled state. Rely on validation in
|
||||
// channeldb to return an error if the invoice is already
|
||||
// settled or canceled.
|
||||
@ -848,6 +913,13 @@ func (i *InvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Return without cancellation if the invoice state is ContractAccepted.
|
||||
if invoice.State == channeldb.ContractAccepted {
|
||||
log.Debugf("Invoice(%v): remains accepted as cancel wasn't"+
|
||||
"explicitly requested.", payHash)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debugf("Invoice(%v): canceled", payHash)
|
||||
|
||||
// In the callback, some htlcs may have been moved to the canceled
|
||||
|
@ -1,117 +1,16 @@
|
||||
package invoices
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/clock"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
)
|
||||
|
||||
var (
|
||||
testTimeout = 5 * time.Second
|
||||
|
||||
testTime = time.Date(2018, time.February, 2, 14, 0, 0, 0, time.UTC)
|
||||
|
||||
preimage = lntypes.Preimage{
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
|
||||
}
|
||||
|
||||
hash = preimage.Hash()
|
||||
|
||||
testHtlcExpiry = uint32(5)
|
||||
|
||||
testInvoiceCltvDelta = uint32(4)
|
||||
|
||||
testFinalCltvRejectDelta = int32(4)
|
||||
|
||||
testCurrentHeight = int32(1)
|
||||
|
||||
testFeatures = lnwire.NewFeatureVector(
|
||||
nil, lnwire.Features,
|
||||
)
|
||||
|
||||
testPayload = &mockPayload{}
|
||||
)
|
||||
|
||||
var (
|
||||
testInvoiceAmt = lnwire.MilliSatoshi(100000)
|
||||
testInvoice = &channeldb.Invoice{
|
||||
Terms: channeldb.ContractTerm{
|
||||
PaymentPreimage: preimage,
|
||||
Value: lnwire.MilliSatoshi(100000),
|
||||
Features: testFeatures,
|
||||
},
|
||||
}
|
||||
|
||||
testHodlInvoice = &channeldb.Invoice{
|
||||
Terms: channeldb.ContractTerm{
|
||||
PaymentPreimage: channeldb.UnknownPreimage,
|
||||
Value: testInvoiceAmt,
|
||||
Features: testFeatures,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
type testContext struct {
|
||||
registry *InvoiceRegistry
|
||||
clock *testClock
|
||||
|
||||
cleanup func()
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func newTestContext(t *testing.T) *testContext {
|
||||
clock := newTestClock(testTime)
|
||||
|
||||
cdb, cleanup, err := newDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cdb.Now = clock.now
|
||||
|
||||
// Instantiate and start the invoice ctx.registry.
|
||||
cfg := RegistryConfig{
|
||||
FinalCltvRejectDelta: testFinalCltvRejectDelta,
|
||||
HtlcHoldDuration: 30 * time.Second,
|
||||
Now: clock.now,
|
||||
TickAfter: clock.tickAfter,
|
||||
}
|
||||
registry := NewRegistry(cdb, &cfg)
|
||||
|
||||
err = registry.Start()
|
||||
if err != nil {
|
||||
cleanup()
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx := testContext{
|
||||
registry: registry,
|
||||
clock: clock,
|
||||
t: t,
|
||||
cleanup: func() {
|
||||
registry.Stop()
|
||||
cleanup()
|
||||
},
|
||||
}
|
||||
|
||||
return &ctx
|
||||
}
|
||||
|
||||
func getCircuitKey(htlcID uint64) channeldb.CircuitKey {
|
||||
return channeldb.CircuitKey{
|
||||
ChanID: lnwire.ShortChannelID{
|
||||
BlockHeight: 1, TxIndex: 2, TxPosition: 3,
|
||||
},
|
||||
HtlcID: htlcID,
|
||||
}
|
||||
}
|
||||
|
||||
// TestSettleInvoice tests settling of an invoice and related notifications.
|
||||
func TestSettleInvoice(t *testing.T) {
|
||||
ctx := newTestContext(t)
|
||||
@ -121,18 +20,18 @@ func TestSettleInvoice(t *testing.T) {
|
||||
defer allSubscriptions.Cancel()
|
||||
|
||||
// Subscribe to the not yet existing invoice.
|
||||
subscription, err := ctx.registry.SubscribeSingleInvoice(hash)
|
||||
subscription, err := ctx.registry.SubscribeSingleInvoice(testInvoicePaymentHash)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer subscription.Cancel()
|
||||
|
||||
if subscription.hash != hash {
|
||||
if subscription.hash != testInvoicePaymentHash {
|
||||
t.Fatalf("expected subscription for provided hash")
|
||||
}
|
||||
|
||||
// Add the invoice.
|
||||
addIdx, err := ctx.registry.AddInvoice(testInvoice, hash)
|
||||
addIdx, err := ctx.registry.AddInvoice(testInvoice, testInvoicePaymentHash)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -168,7 +67,7 @@ func TestSettleInvoice(t *testing.T) {
|
||||
|
||||
// Try to settle invoice with an htlc that expires too soon.
|
||||
event, err := ctx.registry.NotifyExitHopHtlc(
|
||||
hash, testInvoice.Terms.Value,
|
||||
testInvoicePaymentHash, testInvoice.Terms.Value,
|
||||
uint32(testCurrentHeight)+testInvoiceCltvDelta-1,
|
||||
testCurrentHeight, getCircuitKey(10), hodlChan, testPayload,
|
||||
)
|
||||
@ -186,7 +85,7 @@ func TestSettleInvoice(t *testing.T) {
|
||||
// Settle invoice with a slightly higher amount.
|
||||
amtPaid := lnwire.MilliSatoshi(100500)
|
||||
_, err = ctx.registry.NotifyExitHopHtlc(
|
||||
hash, amtPaid, testHtlcExpiry, testCurrentHeight,
|
||||
testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight,
|
||||
getCircuitKey(0), hodlChan, testPayload,
|
||||
)
|
||||
if err != nil {
|
||||
@ -222,7 +121,7 @@ func TestSettleInvoice(t *testing.T) {
|
||||
// Try to settle again with the same htlc id. We need this idempotent
|
||||
// behaviour after a restart.
|
||||
event, err = ctx.registry.NotifyExitHopHtlc(
|
||||
hash, amtPaid, testHtlcExpiry, testCurrentHeight,
|
||||
testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight,
|
||||
getCircuitKey(0), hodlChan, testPayload,
|
||||
)
|
||||
if err != nil {
|
||||
@ -236,7 +135,7 @@ func TestSettleInvoice(t *testing.T) {
|
||||
// should also be accepted, to prevent any change in behaviour for a
|
||||
// paid invoice that may open up a probe vector.
|
||||
event, err = ctx.registry.NotifyExitHopHtlc(
|
||||
hash, amtPaid+600, testHtlcExpiry, testCurrentHeight,
|
||||
testInvoicePaymentHash, amtPaid+600, testHtlcExpiry, testCurrentHeight,
|
||||
getCircuitKey(1), hodlChan, testPayload,
|
||||
)
|
||||
if err != nil {
|
||||
@ -249,7 +148,7 @@ func TestSettleInvoice(t *testing.T) {
|
||||
// Try to settle again with a lower amount. This should fail just as it
|
||||
// would have failed if it were the first payment.
|
||||
event, err = ctx.registry.NotifyExitHopHtlc(
|
||||
hash, amtPaid-600, testHtlcExpiry, testCurrentHeight,
|
||||
testInvoicePaymentHash, amtPaid-600, testHtlcExpiry, testCurrentHeight,
|
||||
getCircuitKey(2), hodlChan, testPayload,
|
||||
)
|
||||
if err != nil {
|
||||
@ -261,7 +160,7 @@ func TestSettleInvoice(t *testing.T) {
|
||||
|
||||
// Check that settled amount is equal to the sum of values of the htlcs
|
||||
// 0 and 1.
|
||||
inv, err := ctx.registry.LookupInvoice(hash)
|
||||
inv, err := ctx.registry.LookupInvoice(testInvoicePaymentHash)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -270,7 +169,7 @@ func TestSettleInvoice(t *testing.T) {
|
||||
}
|
||||
|
||||
// Try to cancel.
|
||||
err = ctx.registry.CancelInvoice(hash)
|
||||
err = ctx.registry.CancelInvoice(testInvoicePaymentHash)
|
||||
if err != channeldb.ErrInvoiceAlreadySettled {
|
||||
t.Fatal("expected cancelation of a settled invoice to fail")
|
||||
}
|
||||
@ -292,25 +191,25 @@ func TestCancelInvoice(t *testing.T) {
|
||||
defer allSubscriptions.Cancel()
|
||||
|
||||
// Try to cancel the not yet existing invoice. This should fail.
|
||||
err := ctx.registry.CancelInvoice(hash)
|
||||
err := ctx.registry.CancelInvoice(testInvoicePaymentHash)
|
||||
if err != channeldb.ErrInvoiceNotFound {
|
||||
t.Fatalf("expected ErrInvoiceNotFound, but got %v", err)
|
||||
}
|
||||
|
||||
// Subscribe to the not yet existing invoice.
|
||||
subscription, err := ctx.registry.SubscribeSingleInvoice(hash)
|
||||
subscription, err := ctx.registry.SubscribeSingleInvoice(testInvoicePaymentHash)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer subscription.Cancel()
|
||||
|
||||
if subscription.hash != hash {
|
||||
if subscription.hash != testInvoicePaymentHash {
|
||||
t.Fatalf("expected subscription for provided hash")
|
||||
}
|
||||
|
||||
// Add the invoice.
|
||||
amt := lnwire.MilliSatoshi(100000)
|
||||
_, err = ctx.registry.AddInvoice(testInvoice, hash)
|
||||
_, err = ctx.registry.AddInvoice(testInvoice, testInvoicePaymentHash)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -342,7 +241,7 @@ func TestCancelInvoice(t *testing.T) {
|
||||
}
|
||||
|
||||
// Cancel invoice.
|
||||
err = ctx.registry.CancelInvoice(hash)
|
||||
err = ctx.registry.CancelInvoice(testInvoicePaymentHash)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -365,7 +264,7 @@ func TestCancelInvoice(t *testing.T) {
|
||||
// subscribers (backwards compatibility).
|
||||
|
||||
// Try to cancel again.
|
||||
err = ctx.registry.CancelInvoice(hash)
|
||||
err = ctx.registry.CancelInvoice(testInvoicePaymentHash)
|
||||
if err != nil {
|
||||
t.Fatal("expected cancelation of a canceled invoice to succeed")
|
||||
}
|
||||
@ -374,7 +273,7 @@ func TestCancelInvoice(t *testing.T) {
|
||||
// result in a cancel event.
|
||||
hodlChan := make(chan interface{})
|
||||
event, err := ctx.registry.NotifyExitHopHtlc(
|
||||
hash, amt, testHtlcExpiry, testCurrentHeight,
|
||||
testInvoicePaymentHash, amt, testHtlcExpiry, testCurrentHeight,
|
||||
getCircuitKey(0), hodlChan, testPayload,
|
||||
)
|
||||
if err != nil {
|
||||
@ -393,9 +292,9 @@ func TestCancelInvoice(t *testing.T) {
|
||||
// TestSettleHoldInvoice tests settling of a hold invoice and related
|
||||
// notifications.
|
||||
func TestSettleHoldInvoice(t *testing.T) {
|
||||
defer timeout(t)()
|
||||
defer timeout()()
|
||||
|
||||
cdb, cleanup, err := newDB()
|
||||
cdb, cleanup, err := newTestChannelDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -404,8 +303,9 @@ func TestSettleHoldInvoice(t *testing.T) {
|
||||
// Instantiate and start the invoice ctx.registry.
|
||||
cfg := RegistryConfig{
|
||||
FinalCltvRejectDelta: testFinalCltvRejectDelta,
|
||||
Clock: clock.NewTestClock(testTime),
|
||||
}
|
||||
registry := NewRegistry(cdb, &cfg)
|
||||
registry := NewRegistry(cdb, NewInvoiceExpiryWatcher(cfg.Clock), &cfg)
|
||||
|
||||
err = registry.Start()
|
||||
if err != nil {
|
||||
@ -417,18 +317,18 @@ func TestSettleHoldInvoice(t *testing.T) {
|
||||
defer allSubscriptions.Cancel()
|
||||
|
||||
// Subscribe to the not yet existing invoice.
|
||||
subscription, err := registry.SubscribeSingleInvoice(hash)
|
||||
subscription, err := registry.SubscribeSingleInvoice(testInvoicePaymentHash)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer subscription.Cancel()
|
||||
|
||||
if subscription.hash != hash {
|
||||
if subscription.hash != testInvoicePaymentHash {
|
||||
t.Fatalf("expected subscription for provided hash")
|
||||
}
|
||||
|
||||
// Add the invoice.
|
||||
_, err = registry.AddInvoice(testHodlInvoice, hash)
|
||||
_, err = registry.AddInvoice(testHodlInvoice, testInvoicePaymentHash)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -455,7 +355,7 @@ func TestSettleHoldInvoice(t *testing.T) {
|
||||
// NotifyExitHopHtlc without a preimage present in the invoice registry
|
||||
// should be possible.
|
||||
event, err := registry.NotifyExitHopHtlc(
|
||||
hash, amtPaid, testHtlcExpiry, testCurrentHeight,
|
||||
testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight,
|
||||
getCircuitKey(0), hodlChan, testPayload,
|
||||
)
|
||||
if err != nil {
|
||||
@ -467,7 +367,7 @@ func TestSettleHoldInvoice(t *testing.T) {
|
||||
|
||||
// Test idempotency.
|
||||
event, err = registry.NotifyExitHopHtlc(
|
||||
hash, amtPaid, testHtlcExpiry, testCurrentHeight,
|
||||
testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight,
|
||||
getCircuitKey(0), hodlChan, testPayload,
|
||||
)
|
||||
if err != nil {
|
||||
@ -480,7 +380,7 @@ func TestSettleHoldInvoice(t *testing.T) {
|
||||
// Test replay at a higher height. We expect the same result because it
|
||||
// is a replay.
|
||||
event, err = registry.NotifyExitHopHtlc(
|
||||
hash, amtPaid, testHtlcExpiry, testCurrentHeight+10,
|
||||
testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight+10,
|
||||
getCircuitKey(0), hodlChan, testPayload,
|
||||
)
|
||||
if err != nil {
|
||||
@ -493,7 +393,7 @@ func TestSettleHoldInvoice(t *testing.T) {
|
||||
// Test a new htlc coming in that doesn't meet the final cltv delta
|
||||
// requirement. It should be rejected.
|
||||
event, err = registry.NotifyExitHopHtlc(
|
||||
hash, amtPaid, 1, testCurrentHeight,
|
||||
testInvoicePaymentHash, amtPaid, 1, testCurrentHeight,
|
||||
getCircuitKey(1), hodlChan, testPayload,
|
||||
)
|
||||
if err != nil {
|
||||
@ -516,13 +416,13 @@ func TestSettleHoldInvoice(t *testing.T) {
|
||||
}
|
||||
|
||||
// Settling with preimage should succeed.
|
||||
err = registry.SettleHodlInvoice(preimage)
|
||||
err = registry.SettleHodlInvoice(testInvoicePreimage)
|
||||
if err != nil {
|
||||
t.Fatal("expected set preimage to succeed")
|
||||
}
|
||||
|
||||
hodlEvent := (<-hodlChan).(HodlEvent)
|
||||
if *hodlEvent.Preimage != preimage {
|
||||
if *hodlEvent.Preimage != testInvoicePreimage {
|
||||
t.Fatal("unexpected preimage in hodl event")
|
||||
}
|
||||
if hodlEvent.AcceptHeight != testCurrentHeight {
|
||||
@ -549,13 +449,13 @@ func TestSettleHoldInvoice(t *testing.T) {
|
||||
}
|
||||
|
||||
// Idempotency.
|
||||
err = registry.SettleHodlInvoice(preimage)
|
||||
err = registry.SettleHodlInvoice(testInvoicePreimage)
|
||||
if err != channeldb.ErrInvoiceAlreadySettled {
|
||||
t.Fatalf("expected ErrInvoiceAlreadySettled but got %v", err)
|
||||
}
|
||||
|
||||
// Try to cancel.
|
||||
err = registry.CancelInvoice(hash)
|
||||
err = registry.CancelInvoice(testInvoicePaymentHash)
|
||||
if err == nil {
|
||||
t.Fatal("expected cancelation of a settled invoice to fail")
|
||||
}
|
||||
@ -564,9 +464,9 @@ func TestSettleHoldInvoice(t *testing.T) {
|
||||
// TestCancelHoldInvoice tests canceling of a hold invoice and related
|
||||
// notifications.
|
||||
func TestCancelHoldInvoice(t *testing.T) {
|
||||
defer timeout(t)()
|
||||
defer timeout()()
|
||||
|
||||
cdb, cleanup, err := newDB()
|
||||
cdb, cleanup, err := newTestChannelDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -575,8 +475,9 @@ func TestCancelHoldInvoice(t *testing.T) {
|
||||
// Instantiate and start the invoice ctx.registry.
|
||||
cfg := RegistryConfig{
|
||||
FinalCltvRejectDelta: testFinalCltvRejectDelta,
|
||||
Clock: clock.NewTestClock(testTime),
|
||||
}
|
||||
registry := NewRegistry(cdb, &cfg)
|
||||
registry := NewRegistry(cdb, NewInvoiceExpiryWatcher(cfg.Clock), &cfg)
|
||||
|
||||
err = registry.Start()
|
||||
if err != nil {
|
||||
@ -585,7 +486,7 @@ func TestCancelHoldInvoice(t *testing.T) {
|
||||
defer registry.Stop()
|
||||
|
||||
// Add the invoice.
|
||||
_, err = registry.AddInvoice(testHodlInvoice, hash)
|
||||
_, err = registry.AddInvoice(testHodlInvoice, testInvoicePaymentHash)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -596,7 +497,7 @@ func TestCancelHoldInvoice(t *testing.T) {
|
||||
// NotifyExitHopHtlc without a preimage present in the invoice registry
|
||||
// should be possible.
|
||||
event, err := registry.NotifyExitHopHtlc(
|
||||
hash, amtPaid, testHtlcExpiry, testCurrentHeight,
|
||||
testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight,
|
||||
getCircuitKey(0), hodlChan, testPayload,
|
||||
)
|
||||
if err != nil {
|
||||
@ -607,7 +508,7 @@ func TestCancelHoldInvoice(t *testing.T) {
|
||||
}
|
||||
|
||||
// Cancel invoice.
|
||||
err = registry.CancelInvoice(hash)
|
||||
err = registry.CancelInvoice(testInvoicePaymentHash)
|
||||
if err != nil {
|
||||
t.Fatal("cancel invoice failed")
|
||||
}
|
||||
@ -621,7 +522,7 @@ func TestCancelHoldInvoice(t *testing.T) {
|
||||
// in a rejection. The accept height is expected to be the original
|
||||
// accept height.
|
||||
event, err = registry.NotifyExitHopHtlc(
|
||||
hash, amtPaid, testHtlcExpiry, testCurrentHeight+1,
|
||||
testInvoicePaymentHash, amtPaid, testHtlcExpiry, testCurrentHeight+1,
|
||||
getCircuitKey(0), hodlChan, testPayload,
|
||||
)
|
||||
if err != nil {
|
||||
@ -636,29 +537,6 @@ func TestCancelHoldInvoice(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func newDB() (*channeldb.DB, func(), error) {
|
||||
// First, create a temporary directory to be used for the duration of
|
||||
// this test.
|
||||
tempDirName, err := ioutil.TempDir("", "channeldb")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Next, create channeldb for the first time.
|
||||
cdb, err := channeldb.Open(tempDirName)
|
||||
if err != nil {
|
||||
os.RemoveAll(tempDirName)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
cleanUp := func() {
|
||||
cdb.Close()
|
||||
os.RemoveAll(tempDirName)
|
||||
}
|
||||
|
||||
return cdb, cleanUp, nil
|
||||
}
|
||||
|
||||
// TestUnknownInvoice tests that invoice registry returns an error when the
|
||||
// invoice is unknown. This is to guard against returning a cancel hodl event
|
||||
// for forwarded htlcs. In the link, NotifyExitHopHtlc is only called if we are
|
||||
@ -673,7 +551,7 @@ func TestUnknownInvoice(t *testing.T) {
|
||||
hodlChan := make(chan interface{})
|
||||
amt := lnwire.MilliSatoshi(100000)
|
||||
_, err := ctx.registry.NotifyExitHopHtlc(
|
||||
hash, amt, testHtlcExpiry, testCurrentHeight,
|
||||
testInvoicePaymentHash, amt, testHtlcExpiry, testCurrentHeight,
|
||||
getCircuitKey(0), hodlChan, testPayload,
|
||||
)
|
||||
if err != channeldb.ErrInvoiceNotFound {
|
||||
@ -681,27 +559,15 @@ func TestUnknownInvoice(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type mockPayload struct {
|
||||
mpp *record.MPP
|
||||
}
|
||||
|
||||
func (p *mockPayload) MultiPath() *record.MPP {
|
||||
return p.mpp
|
||||
}
|
||||
|
||||
func (p *mockPayload) CustomRecords() record.CustomSet {
|
||||
return make(record.CustomSet)
|
||||
}
|
||||
|
||||
// TestSettleMpp tests settling of an invoice with multiple partial payments.
|
||||
func TestSettleMpp(t *testing.T) {
|
||||
defer timeout(t)()
|
||||
defer timeout()()
|
||||
|
||||
ctx := newTestContext(t)
|
||||
defer ctx.cleanup()
|
||||
|
||||
// Add the invoice.
|
||||
_, err := ctx.registry.AddInvoice(testInvoice, hash)
|
||||
_, err := ctx.registry.AddInvoice(testInvoice, testInvoicePaymentHash)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -713,7 +579,7 @@ func TestSettleMpp(t *testing.T) {
|
||||
// Send htlc 1.
|
||||
hodlChan1 := make(chan interface{}, 1)
|
||||
event, err := ctx.registry.NotifyExitHopHtlc(
|
||||
hash, testInvoice.Terms.Value/2,
|
||||
testInvoicePaymentHash, testInvoice.Terms.Value/2,
|
||||
testHtlcExpiry,
|
||||
testCurrentHeight, getCircuitKey(10), hodlChan1, mppPayload,
|
||||
)
|
||||
@ -725,7 +591,7 @@ func TestSettleMpp(t *testing.T) {
|
||||
}
|
||||
|
||||
// Simulate mpp timeout releasing htlc 1.
|
||||
ctx.clock.setTime(testTime.Add(30 * time.Second))
|
||||
ctx.clock.SetTime(testTime.Add(30 * time.Second))
|
||||
|
||||
hodlEvent := (<-hodlChan1).(HodlEvent)
|
||||
if hodlEvent.Preimage != nil {
|
||||
@ -735,7 +601,7 @@ func TestSettleMpp(t *testing.T) {
|
||||
// Send htlc 2.
|
||||
hodlChan2 := make(chan interface{}, 1)
|
||||
event, err = ctx.registry.NotifyExitHopHtlc(
|
||||
hash, testInvoice.Terms.Value/2,
|
||||
testInvoicePaymentHash, testInvoice.Terms.Value/2,
|
||||
testHtlcExpiry,
|
||||
testCurrentHeight, getCircuitKey(11), hodlChan2, mppPayload,
|
||||
)
|
||||
@ -749,7 +615,7 @@ func TestSettleMpp(t *testing.T) {
|
||||
// Send htlc 3.
|
||||
hodlChan3 := make(chan interface{}, 1)
|
||||
event, err = ctx.registry.NotifyExitHopHtlc(
|
||||
hash, testInvoice.Terms.Value/2,
|
||||
testInvoicePaymentHash, testInvoice.Terms.Value/2,
|
||||
testHtlcExpiry,
|
||||
testCurrentHeight, getCircuitKey(12), hodlChan3, mppPayload,
|
||||
)
|
||||
@ -762,7 +628,7 @@ func TestSettleMpp(t *testing.T) {
|
||||
|
||||
// Check that settled amount is equal to the sum of values of the htlcs
|
||||
// 0 and 1.
|
||||
inv, err := ctx.registry.LookupInvoice(hash)
|
||||
inv, err := ctx.registry.LookupInvoice(testInvoicePaymentHash)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -774,3 +640,105 @@ func TestSettleMpp(t *testing.T) {
|
||||
testInvoice.Terms.Value, inv.AmtPaid)
|
||||
}
|
||||
}
|
||||
|
||||
// Tests that invoices are canceled after expiration.
|
||||
func TestInvoiceExpiryWithRegistry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cdb, cleanup, err := newTestChannelDB()
|
||||
defer cleanup()
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testClock := clock.NewTestClock(testTime)
|
||||
|
||||
cfg := RegistryConfig{
|
||||
FinalCltvRejectDelta: testFinalCltvRejectDelta,
|
||||
Clock: testClock,
|
||||
}
|
||||
|
||||
expiryWatcher := NewInvoiceExpiryWatcher(cfg.Clock)
|
||||
registry := NewRegistry(cdb, expiryWatcher, &cfg)
|
||||
|
||||
// First prefill the Channel DB with some pre-existing invoices,
|
||||
// half of them still pending, half of them expired.
|
||||
const numExpired = 5
|
||||
const numPending = 5
|
||||
existingInvoices := generateInvoiceExpiryTestData(
|
||||
t, testTime, 0, numExpired, numPending,
|
||||
)
|
||||
|
||||
var expectedCancellations []lntypes.Hash
|
||||
|
||||
for paymentHash, expiredInvoice := range existingInvoices.expiredInvoices {
|
||||
if _, err := cdb.AddInvoice(expiredInvoice, paymentHash); err != nil {
|
||||
t.Fatalf("cannot add invoice to channel db: %v", err)
|
||||
}
|
||||
expectedCancellations = append(expectedCancellations, paymentHash)
|
||||
}
|
||||
|
||||
for paymentHash, pendingInvoice := range existingInvoices.pendingInvoices {
|
||||
if _, err := cdb.AddInvoice(pendingInvoice, paymentHash); err != nil {
|
||||
t.Fatalf("cannot add invoice to channel db: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err = registry.Start(); err != nil {
|
||||
t.Fatalf("cannot start registry: %v", err)
|
||||
}
|
||||
|
||||
// Now generate pending and invoices and add them to the registry while
|
||||
// it is up and running. We'll manipulate the clock to let them expire.
|
||||
newInvoices := generateInvoiceExpiryTestData(
|
||||
t, testTime, numExpired+numPending, 0, numPending,
|
||||
)
|
||||
|
||||
var invoicesThatWillCancel []lntypes.Hash
|
||||
for paymentHash, pendingInvoice := range newInvoices.pendingInvoices {
|
||||
_, err := registry.AddInvoice(pendingInvoice, paymentHash)
|
||||
invoicesThatWillCancel = append(invoicesThatWillCancel, paymentHash)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Check that they are really not canceled until before the clock is
|
||||
// advanced.
|
||||
for i := range invoicesThatWillCancel {
|
||||
invoice, err := registry.LookupInvoice(invoicesThatWillCancel[i])
|
||||
if err != nil {
|
||||
t.Fatalf("cannot find invoice: %v", err)
|
||||
}
|
||||
|
||||
if invoice.State == channeldb.ContractCanceled {
|
||||
t.Fatalf("expected pending invoice, got canceled")
|
||||
}
|
||||
}
|
||||
|
||||
// Fwd time 1 day.
|
||||
testClock.SetTime(testTime.Add(24 * time.Hour))
|
||||
|
||||
// Give some time to the watcher to cancel everything.
|
||||
time.Sleep(testTimeout)
|
||||
registry.Stop()
|
||||
|
||||
// Create the expected cancellation set before the final check.
|
||||
expectedCancellations = append(
|
||||
expectedCancellations, invoicesThatWillCancel...,
|
||||
)
|
||||
|
||||
// Retrospectively check that all invoices that were expected to be canceled
|
||||
// are indeed canceled.
|
||||
for i := range expectedCancellations {
|
||||
invoice, err := registry.LookupInvoice(expectedCancellations[i])
|
||||
if err != nil {
|
||||
t.Fatalf("cannot find invoice: %v", err)
|
||||
}
|
||||
|
||||
if invoice.State != channeldb.ContractCanceled {
|
||||
t.Fatalf("expected canceled invoice, got: %v", invoice.State)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
280
invoices/test_utils_test.go
Normal file
280
invoices/test_utils_test.go
Normal file
@ -0,0 +1,280 @@
|
||||
package invoices
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"runtime/pprof"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/btcsuite/btcd/chaincfg"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/clock"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
"github.com/lightningnetwork/lnd/zpay32"
|
||||
)
|
||||
|
||||
type mockPayload struct {
|
||||
mpp *record.MPP
|
||||
}
|
||||
|
||||
func (p *mockPayload) MultiPath() *record.MPP {
|
||||
return p.mpp
|
||||
}
|
||||
|
||||
func (p *mockPayload) CustomRecords() record.CustomSet {
|
||||
return make(record.CustomSet)
|
||||
}
|
||||
|
||||
var (
|
||||
testTimeout = 5 * time.Second
|
||||
|
||||
testTime = time.Date(2018, time.February, 2, 14, 0, 0, 0, time.UTC)
|
||||
|
||||
testInvoicePreimage = lntypes.Preimage{1}
|
||||
|
||||
testInvoicePaymentHash = testInvoicePreimage.Hash()
|
||||
|
||||
testHtlcExpiry = uint32(5)
|
||||
|
||||
testInvoiceCltvDelta = uint32(4)
|
||||
|
||||
testFinalCltvRejectDelta = int32(4)
|
||||
|
||||
testCurrentHeight = int32(1)
|
||||
|
||||
testPrivKeyBytes, _ = hex.DecodeString(
|
||||
"e126f68f7eafcc8b74f54d269fe206be715000f94dac067d1c04a8ca3b2db734")
|
||||
|
||||
testPrivKey, _ = btcec.PrivKeyFromBytes(
|
||||
btcec.S256(), testPrivKeyBytes)
|
||||
|
||||
testInvoiceDescription = "coffee"
|
||||
|
||||
testInvoiceAmount = lnwire.MilliSatoshi(100000)
|
||||
|
||||
testNetParams = &chaincfg.MainNetParams
|
||||
|
||||
testMessageSigner = zpay32.MessageSigner{
|
||||
SignCompact: func(hash []byte) ([]byte, error) {
|
||||
sig, err := btcec.SignCompact(btcec.S256(), testPrivKey, hash, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't sign the message: %v", err)
|
||||
}
|
||||
return sig, nil
|
||||
},
|
||||
}
|
||||
|
||||
testFeatures = lnwire.NewFeatureVector(
|
||||
nil, lnwire.Features,
|
||||
)
|
||||
|
||||
testPayload = &mockPayload{}
|
||||
|
||||
testInvoiceCreationDate = testTime
|
||||
)
|
||||
|
||||
var (
|
||||
testInvoiceAmt = lnwire.MilliSatoshi(100000)
|
||||
testInvoice = &channeldb.Invoice{
|
||||
Terms: channeldb.ContractTerm{
|
||||
PaymentPreimage: testInvoicePreimage,
|
||||
Value: testInvoiceAmt,
|
||||
Expiry: time.Hour,
|
||||
Features: testFeatures,
|
||||
},
|
||||
CreationDate: testInvoiceCreationDate,
|
||||
}
|
||||
|
||||
testHodlInvoice = &channeldb.Invoice{
|
||||
Terms: channeldb.ContractTerm{
|
||||
PaymentPreimage: channeldb.UnknownPreimage,
|
||||
Value: testInvoiceAmt,
|
||||
Expiry: time.Hour,
|
||||
Features: testFeatures,
|
||||
},
|
||||
CreationDate: testInvoiceCreationDate,
|
||||
}
|
||||
)
|
||||
|
||||
func newTestChannelDB() (*channeldb.DB, func(), error) {
|
||||
// First, create a temporary directory to be used for the duration of
|
||||
// this test.
|
||||
tempDirName, err := ioutil.TempDir("", "channeldb")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Next, create channeldb for the first time.
|
||||
cdb, err := channeldb.Open(tempDirName)
|
||||
if err != nil {
|
||||
os.RemoveAll(tempDirName)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
cleanUp := func() {
|
||||
cdb.Close()
|
||||
os.RemoveAll(tempDirName)
|
||||
}
|
||||
|
||||
return cdb, cleanUp, nil
|
||||
}
|
||||
|
||||
type testContext struct {
|
||||
cdb *channeldb.DB
|
||||
registry *InvoiceRegistry
|
||||
clock *clock.TestClock
|
||||
|
||||
cleanup func()
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func newTestContext(t *testing.T) *testContext {
|
||||
clock := clock.NewTestClock(testTime)
|
||||
|
||||
cdb, cleanup, err := newTestChannelDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cdb.Now = clock.Now
|
||||
|
||||
expiryWatcher := NewInvoiceExpiryWatcher(clock)
|
||||
|
||||
// Instantiate and start the invoice ctx.registry.
|
||||
cfg := RegistryConfig{
|
||||
FinalCltvRejectDelta: testFinalCltvRejectDelta,
|
||||
HtlcHoldDuration: 30 * time.Second,
|
||||
Clock: clock,
|
||||
}
|
||||
registry := NewRegistry(cdb, expiryWatcher, &cfg)
|
||||
|
||||
err = registry.Start()
|
||||
if err != nil {
|
||||
cleanup()
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx := testContext{
|
||||
cdb: cdb,
|
||||
registry: registry,
|
||||
clock: clock,
|
||||
t: t,
|
||||
cleanup: func() {
|
||||
registry.Stop()
|
||||
cleanup()
|
||||
},
|
||||
}
|
||||
|
||||
return &ctx
|
||||
}
|
||||
|
||||
func getCircuitKey(htlcID uint64) channeldb.CircuitKey {
|
||||
return channeldb.CircuitKey{
|
||||
ChanID: lnwire.ShortChannelID{
|
||||
BlockHeight: 1, TxIndex: 2, TxPosition: 3,
|
||||
},
|
||||
HtlcID: htlcID,
|
||||
}
|
||||
}
|
||||
|
||||
func newTestInvoice(t *testing.T, preimage lntypes.Preimage,
|
||||
timestamp time.Time, expiry time.Duration) *channeldb.Invoice {
|
||||
|
||||
if expiry == 0 {
|
||||
expiry = time.Hour
|
||||
}
|
||||
|
||||
rawInvoice, err := zpay32.NewInvoice(
|
||||
testNetParams,
|
||||
preimage.Hash(),
|
||||
timestamp,
|
||||
zpay32.Amount(testInvoiceAmount),
|
||||
zpay32.Description(testInvoiceDescription),
|
||||
zpay32.Expiry(expiry))
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Error while creating new invoice: %v", err)
|
||||
}
|
||||
|
||||
paymentRequest, err := rawInvoice.Encode(testMessageSigner)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Error while encoding payment request: %v", err)
|
||||
}
|
||||
|
||||
return &channeldb.Invoice{
|
||||
Terms: channeldb.ContractTerm{
|
||||
PaymentPreimage: preimage,
|
||||
Value: testInvoiceAmount,
|
||||
Expiry: expiry,
|
||||
Features: testFeatures,
|
||||
},
|
||||
PaymentRequest: []byte(paymentRequest),
|
||||
CreationDate: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// timeout implements a test level timeout.
|
||||
func timeout() func() {
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-time.After(5 * time.Second):
|
||||
err := pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("error writing to std out after timeout: %v", err))
|
||||
}
|
||||
panic("timeout")
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
|
||||
return func() {
|
||||
close(done)
|
||||
}
|
||||
}
|
||||
|
||||
// invoiceExpiryTestData simply holds generated expired and pending invoices.
|
||||
type invoiceExpiryTestData struct {
|
||||
expiredInvoices map[lntypes.Hash]*channeldb.Invoice
|
||||
pendingInvoices map[lntypes.Hash]*channeldb.Invoice
|
||||
}
|
||||
|
||||
// generateInvoiceExpiryTestData generates the specified number of fake expired
|
||||
// and pending invoices anchored to the passed now timestamp.
|
||||
func generateInvoiceExpiryTestData(
|
||||
t *testing.T, now time.Time,
|
||||
offset, numExpired, numPending int) invoiceExpiryTestData {
|
||||
|
||||
var testData invoiceExpiryTestData
|
||||
|
||||
testData.expiredInvoices = make(map[lntypes.Hash]*channeldb.Invoice)
|
||||
testData.pendingInvoices = make(map[lntypes.Hash]*channeldb.Invoice)
|
||||
|
||||
expiredCreationDate := now.Add(-24 * time.Hour)
|
||||
|
||||
for i := 1; i <= numExpired; i++ {
|
||||
var preimage lntypes.Preimage
|
||||
binary.BigEndian.PutUint32(preimage[:4], uint32(offset+i))
|
||||
expiry := time.Duration((i+offset)%24) * time.Hour
|
||||
invoice := newTestInvoice(t, preimage, expiredCreationDate, expiry)
|
||||
testData.expiredInvoices[preimage.Hash()] = invoice
|
||||
}
|
||||
|
||||
for i := 1; i <= numPending; i++ {
|
||||
var preimage lntypes.Preimage
|
||||
binary.BigEndian.PutUint32(preimage[4:], uint32(offset+i))
|
||||
expiry := time.Duration((i+offset)%24) * time.Hour
|
||||
invoice := newTestInvoice(t, preimage, now, expiry)
|
||||
testData.pendingInvoices[preimage.Hash()] = invoice
|
||||
}
|
||||
|
||||
return testData
|
||||
}
|
@ -1,26 +0,0 @@
|
||||
package invoices
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime/pprof"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// timeout implements a test level timeout.
|
||||
func timeout(t *testing.T) func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-time.After(5 * time.Second):
|
||||
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
|
||||
|
||||
panic("test timeout")
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
|
||||
return func() {
|
||||
close(done)
|
||||
}
|
||||
}
|
@ -9,6 +9,8 @@ import (
|
||||
// PriorityQueue will be able to use that to build and restore an underlying
|
||||
// heap.
|
||||
type PriorityQueueItem interface {
|
||||
// Less must return true if this item is ordered before other and false
|
||||
// otherwise.
|
||||
Less(other PriorityQueueItem) bool
|
||||
}
|
||||
|
||||
@ -43,7 +45,7 @@ func (pq *priorityQueue) Pop() interface{} {
|
||||
return item
|
||||
}
|
||||
|
||||
// Priority wrap a standard heap in a more object-oriented structure.
|
||||
// PriorityQueue wraps a standard heap into a self contained class.
|
||||
type PriorityQueue struct {
|
||||
queue priorityQueue
|
||||
}
|
||||
|
@ -33,6 +33,7 @@ import (
|
||||
"github.com/lightningnetwork/lnd/chanfitness"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/channelnotifier"
|
||||
"github.com/lightningnetwork/lnd/clock"
|
||||
"github.com/lightningnetwork/lnd/contractcourt"
|
||||
"github.com/lightningnetwork/lnd/discovery"
|
||||
"github.com/lightningnetwork/lnd/feature"
|
||||
@ -381,8 +382,7 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB,
|
||||
registryConfig := invoices.RegistryConfig{
|
||||
FinalCltvRejectDelta: defaultFinalCltvRejectDelta,
|
||||
HtlcHoldDuration: invoices.DefaultHtlcHoldDuration,
|
||||
Now: time.Now,
|
||||
TickAfter: time.After,
|
||||
Clock: clock.NewDefaultClock(),
|
||||
}
|
||||
|
||||
s := &server{
|
||||
@ -393,7 +393,10 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB,
|
||||
readPool: readPool,
|
||||
chansToRestore: chansToRestore,
|
||||
|
||||
invoices: invoices.NewRegistry(chanDB, ®istryConfig),
|
||||
invoices: invoices.NewRegistry(
|
||||
chanDB, invoices.NewInvoiceExpiryWatcher(clock.NewDefaultClock()),
|
||||
®istryConfig,
|
||||
),
|
||||
|
||||
channelNotifier: channelnotifier.New(chanDB),
|
||||
|
||||
|
@ -82,6 +82,10 @@ const (
|
||||
// This is chosen to be the maximum number of bytes that can fit into a
|
||||
// single QR code: https://en.wikipedia.org/wiki/QR_code#Storage
|
||||
maxInvoiceLength = 7089
|
||||
|
||||
// DefaultInvoiceExpiry is the default expiry duration from the creation
|
||||
// timestamp if expiry is set to zero.
|
||||
DefaultInvoiceExpiry = time.Hour
|
||||
)
|
||||
|
||||
var (
|
||||
|
Loading…
Reference in New Issue
Block a user