channeldb+invoices: add ScanInvoices and integrate with InvoiceRegistry

This commit adds channeldb.ScanInvoices to scan through all invoices in
the database. The new call will also replace the already existing
channeldb.FetchAllInvoicesWithPaymentHash call in preparation to collect
invoices we'd like to delete and watch for expiry in one scan in later
commits.
This commit is contained in:
Andras Banki-Horvath 2020-07-28 21:22:23 +02:00
parent ba3c65bfd6
commit 92f3b0a30c
No known key found for this signature in database
GPG Key ID: 80E5375C094198D8
5 changed files with 90 additions and 148 deletions

@ -622,9 +622,9 @@ func TestInvoiceAddTimeSeries(t *testing.T) {
} }
} }
// Tests that FetchAllInvoicesWithPaymentHash returns all invoices with their // TestScanInvoices tests that ScanInvoices scans trough all stored invoices
// corresponding payment hashes. // correctly.
func TestFetchAllInvoicesWithPaymentHash(t *testing.T) { func TestScanInvoices(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanup, err := MakeTestDB() db, cleanup, err := MakeTestDB()
@ -633,97 +633,54 @@ func TestFetchAllInvoicesWithPaymentHash(t *testing.T) {
t.Fatalf("unable to make test db: %v", err) t.Fatalf("unable to make test db: %v", err)
} }
// With an empty DB we expect to return no error and an empty list. var invoices map[lntypes.Hash]*Invoice
empty, err := db.FetchAllInvoicesWithPaymentHash(false) callCount := 0
if err != nil { resetCount := 0
t.Fatalf("failed to call FetchAllInvoicesWithPaymentHash on empty DB: %v",
err) // reset is used to reset/initialize results and is called once
// upon calling ScanInvoices and when the underlying transaction is
// retried.
reset := func() {
invoices = make(map[lntypes.Hash]*Invoice)
callCount = 0
resetCount++
} }
if len(empty) != 0 { scanFunc := func(paymentHash lntypes.Hash, invoice *Invoice) error {
t.Fatalf("expected empty list as a result, got: %v", empty) invoices[paymentHash] = invoice
callCount++
return nil
} }
states := []ContractState{ // With an empty DB we expect to not scan any invoices.
ContractOpen, ContractSettled, ContractCanceled, ContractAccepted, require.NoError(t, db.ScanInvoices(scanFunc, reset))
} require.Equal(t, 0, len(invoices))
require.Equal(t, 0, callCount)
require.Equal(t, 1, resetCount)
numInvoices := len(states) * 2 numInvoices := 5
testPendingInvoices := make(map[lntypes.Hash]*Invoice) testInvoices := make(map[lntypes.Hash]*Invoice)
testAllInvoices := make(map[lntypes.Hash]*Invoice)
// Now populate the DB and check if we can get all invoices with their // Now populate the DB and check if we can get all invoices with their
// payment hashes as expected. // payment hashes as expected.
for i := 1; i <= numInvoices; i++ { for i := 1; i <= numInvoices; i++ {
invoice, err := randInvoice(lnwire.MilliSatoshi(i)) invoice, err := randInvoice(lnwire.MilliSatoshi(i))
if err != nil { require.NoError(t, err)
t.Fatalf("unable to create invoice: %v", err)
}
// Set the contract state of the next invoice such that there's an equal
// number for all possbile states.
invoice.State = states[i%len(states)]
paymentHash := invoice.Terms.PaymentPreimage.Hash() paymentHash := invoice.Terms.PaymentPreimage.Hash()
testInvoices[paymentHash] = invoice
if invoice.IsPending() { _, err = db.AddInvoice(invoice, paymentHash)
testPendingInvoices[paymentHash] = invoice require.NoError(t, err)
}
testAllInvoices[paymentHash] = invoice
if _, err := db.AddInvoice(invoice, paymentHash); err != nil {
t.Fatalf("unable to add invoice: %v", err)
}
}
pendingInvoices, err := db.FetchAllInvoicesWithPaymentHash(true)
if err != nil {
t.Fatalf("can't fetch invoices with payment hash: %v", err)
}
if len(testPendingInvoices) != len(pendingInvoices) {
t.Fatalf("expected %v pending invoices, got: %v",
len(testPendingInvoices), len(pendingInvoices))
}
allInvoices, err := db.FetchAllInvoicesWithPaymentHash(false)
if err != nil {
t.Fatalf("can't fetch invoices with payment hash: %v", err)
}
if len(testAllInvoices) != len(allInvoices) {
t.Fatalf("expected %v invoices, got: %v",
len(testAllInvoices), len(allInvoices))
}
for i := range pendingInvoices {
expected, ok := testPendingInvoices[pendingInvoices[i].PaymentHash]
if !ok {
t.Fatalf("coulnd't find invoice with hash: %v",
pendingInvoices[i].PaymentHash)
}
// Zero out add index to not confuse require.Equal.
pendingInvoices[i].Invoice.AddIndex = 0
expected.AddIndex = 0
require.Equal(t, *expected, pendingInvoices[i].Invoice)
}
for i := range allInvoices {
expected, ok := testAllInvoices[allInvoices[i].PaymentHash]
if !ok {
t.Fatalf("coulnd't find invoice with hash: %v",
allInvoices[i].PaymentHash)
}
// Zero out add index to not confuse require.Equal.
allInvoices[i].Invoice.AddIndex = 0
expected.AddIndex = 0
require.Equal(t, *expected, allInvoices[i].Invoice)
} }
resetCount = 0
require.NoError(t, db.ScanInvoices(scanFunc, reset))
require.Equal(t, numInvoices, callCount)
require.Equal(t, testInvoices, invoices)
require.Equal(t, 1, resetCount)
} }
// TestDuplicateSettleInvoice tests that if we add a new invoice and settle it // TestDuplicateSettleInvoice tests that if we add a new invoice and settle it

@ -723,28 +723,21 @@ func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex kvdb.RBucket,
} }
} }
// InvoiceWithPaymentHash is used to store an invoice and its corresponding // ScanInvoices scans trough all invoices and calls the passed scanFunc for
// payment hash. This struct is only used to store results of // for each invoice with its respective payment hash. Additionally a reset()
// ChannelDB.FetchAllInvoicesWithPaymentHash() call. // closure is passed which is used to reset/initialize partial results and also
type InvoiceWithPaymentHash struct { // to signal if the kvdb.View transaction has been retried.
// Invoice holds the invoice as selected from the invoices bucket. func (d *DB) ScanInvoices(
Invoice Invoice scanFunc func(lntypes.Hash, *Invoice) error, reset func()) error {
// PaymentHash is the payment hash for the Invoice. return kvdb.View(d, func(tx kvdb.RTx) error {
PaymentHash lntypes.Hash // Reset partial results. As transaction commit success is not
} // guaranteed when using etcd, we need to be prepared to redo
// the whole view transaction. In order to be able to do that
// we need a way to reset existing results. This is also done
// upon first run for initialization.
reset()
// FetchAllInvoicesWithPaymentHash returns all invoices and their payment hashes
// currently stored within the database. If the pendingOnly param is true, then
// only open or accepted invoices and their payment hashes will be returned,
// skipping all invoices that are fully settled or canceled. Note that the
// returned array is not ordered by add index.
func (d *DB) FetchAllInvoicesWithPaymentHash(pendingOnly bool) (
[]InvoiceWithPaymentHash, error) {
var result []InvoiceWithPaymentHash
err := kvdb.View(d, func(tx kvdb.RTx) error {
invoices := tx.ReadBucket(invoiceBucket) invoices := tx.ReadBucket(invoiceBucket)
if invoices == nil { if invoices == nil {
return ErrNoInvoicesCreated return ErrNoInvoicesCreated
@ -775,26 +768,12 @@ func (d *DB) FetchAllInvoicesWithPaymentHash(pendingOnly bool) (
return err return err
} }
if pendingOnly && !invoice.IsPending() { var paymentHash lntypes.Hash
return nil copy(paymentHash[:], k)
}
invoiceWithPaymentHash := InvoiceWithPaymentHash{ return scanFunc(paymentHash, &invoice)
Invoice: invoice,
}
copy(invoiceWithPaymentHash.PaymentHash[:], k)
result = append(result, invoiceWithPaymentHash)
return nil
}) })
}) })
if err != nil {
return nil, err
}
return result, nil
} }
// InvoiceQuery represents a query to the invoice database. The query allows a // InvoiceQuery represents a query to the invoice database. The query allows a

@ -129,14 +129,11 @@ func (ew *InvoiceExpiryWatcher) prepareInvoice(
// AddInvoices adds multiple invoices to the InvoiceExpiryWatcher. // AddInvoices adds multiple invoices to the InvoiceExpiryWatcher.
func (ew *InvoiceExpiryWatcher) AddInvoices( func (ew *InvoiceExpiryWatcher) AddInvoices(
invoices []channeldb.InvoiceWithPaymentHash) { invoices map[lntypes.Hash]*channeldb.Invoice) {
invoicesWithExpiry := make([]*invoiceExpiry, 0, len(invoices)) invoicesWithExpiry := make([]*invoiceExpiry, 0, len(invoices))
for _, invoiceWithPaymentHash := range invoices { for paymentHash, invoice := range invoices {
newInvoiceExpiry := ew.prepareInvoice( newInvoiceExpiry := ew.prepareInvoice(paymentHash, invoice)
invoiceWithPaymentHash.PaymentHash,
&invoiceWithPaymentHash.Invoice,
)
if newInvoiceExpiry != nil { if newInvoiceExpiry != nil {
invoicesWithExpiry = append( invoicesWithExpiry = append(
invoicesWithExpiry, newInvoiceExpiry, invoicesWithExpiry, newInvoiceExpiry,

@ -158,24 +158,14 @@ func TestInvoiceExpiryWhenAddingMultipleInvoices(t *testing.T) {
t.Parallel() t.Parallel()
test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5) test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5)
var invoices []channeldb.InvoiceWithPaymentHash invoices := make(map[lntypes.Hash]*channeldb.Invoice)
for hash, invoice := range test.testData.expiredInvoices { for hash, invoice := range test.testData.expiredInvoices {
invoices = append(invoices, invoices[hash] = invoice
channeldb.InvoiceWithPaymentHash{
Invoice: *invoice,
PaymentHash: hash,
},
)
} }
for hash, invoice := range test.testData.pendingInvoices { for hash, invoice := range test.testData.pendingInvoices {
invoices = append(invoices, invoices[hash] = invoice
channeldb.InvoiceWithPaymentHash{
Invoice: *invoice,
PaymentHash: hash,
},
)
} }
test.watcher.AddInvoices(invoices) test.watcher.AddInvoices(invoices)

@ -147,21 +147,39 @@ func NewRegistry(cdb *channeldb.DB, expiryWatcher *InvoiceExpiryWatcher,
} }
} }
// populateExpiryWatcher fetches all active invoices and their corresponding // scanInvoicesOnStart will scan all invoices on start and add active invoices
// payment hashes from ChannelDB and adds them to the expiry watcher. // to the invoice expiry watcher.
func (i *InvoiceRegistry) populateExpiryWatcher() error { func (i *InvoiceRegistry) scanInvoicesOnStart() error {
pendingOnly := true var pending map[lntypes.Hash]*channeldb.Invoice
pendingInvoices, err := i.cdb.FetchAllInvoicesWithPaymentHash(pendingOnly)
if err != nil && err != channeldb.ErrNoInvoicesCreated { reset := func() {
log.Errorf( // Zero out our results on start and if the scan is ever run
"Error while prefetching active invoices from the database: %v", err, // more than once. This latter case can happen if the kvdb
) // layer needs to retry the View transaction underneath (eg.
// using the etcd driver, where all transactions are allowed
// to retry for serializability).
pending = make(map[lntypes.Hash]*channeldb.Invoice)
}
scanFunc := func(
paymentHash lntypes.Hash, invoice *channeldb.Invoice) error {
if invoice.IsPending() {
pending[paymentHash] = invoice
}
return nil
}
err := i.cdb.ScanInvoices(scanFunc, reset)
if err != nil {
return err return err
} }
log.Debugf("Adding %d pending invoices to the expiry watcher", log.Debugf("Adding %d pending invoices to the expiry watcher",
len(pendingInvoices)) len(pending))
i.expiryWatcher.AddInvoices(pendingInvoices) i.expiryWatcher.AddInvoices(pending)
return nil return nil
} }
@ -178,8 +196,9 @@ func (i *InvoiceRegistry) Start() error {
i.wg.Add(1) i.wg.Add(1)
go i.invoiceEventLoop() go i.invoiceEventLoop()
// Now prefetch all pending invoices to the expiry watcher. // Now scan all pending and removable invoices to the expiry watcher or
err = i.populateExpiryWatcher() // delete them.
err = i.scanInvoicesOnStart()
if err != nil { if err != nil {
i.Stop() i.Stop()
return err return err