channeldb+invoices: add ScanInvoices and integrate with InvoiceRegistry

This commit adds channeldb.ScanInvoices to scan through all invoices in
the database. The new call will also replace the already existing
channeldb.FetchAllInvoicesWithPaymentHash call in preparation to collect
invoices we'd like to delete and watch for expiry in one scan in later
commits.
This commit is contained in:
Andras Banki-Horvath 2020-07-28 21:22:23 +02:00
parent ba3c65bfd6
commit 92f3b0a30c
No known key found for this signature in database
GPG Key ID: 80E5375C094198D8
5 changed files with 90 additions and 148 deletions

View File

@ -622,9 +622,9 @@ func TestInvoiceAddTimeSeries(t *testing.T) {
}
}
// Tests that FetchAllInvoicesWithPaymentHash returns all invoices with their
// corresponding payment hashes.
func TestFetchAllInvoicesWithPaymentHash(t *testing.T) {
// TestScanInvoices tests that ScanInvoices scans trough all stored invoices
// correctly.
func TestScanInvoices(t *testing.T) {
t.Parallel()
db, cleanup, err := MakeTestDB()
@ -633,97 +633,54 @@ func TestFetchAllInvoicesWithPaymentHash(t *testing.T) {
t.Fatalf("unable to make test db: %v", err)
}
// With an empty DB we expect to return no error and an empty list.
empty, err := db.FetchAllInvoicesWithPaymentHash(false)
if err != nil {
t.Fatalf("failed to call FetchAllInvoicesWithPaymentHash on empty DB: %v",
err)
var invoices map[lntypes.Hash]*Invoice
callCount := 0
resetCount := 0
// reset is used to reset/initialize results and is called once
// upon calling ScanInvoices and when the underlying transaction is
// retried.
reset := func() {
invoices = make(map[lntypes.Hash]*Invoice)
callCount = 0
resetCount++
}
if len(empty) != 0 {
t.Fatalf("expected empty list as a result, got: %v", empty)
scanFunc := func(paymentHash lntypes.Hash, invoice *Invoice) error {
invoices[paymentHash] = invoice
callCount++
return nil
}
states := []ContractState{
ContractOpen, ContractSettled, ContractCanceled, ContractAccepted,
}
// With an empty DB we expect to not scan any invoices.
require.NoError(t, db.ScanInvoices(scanFunc, reset))
require.Equal(t, 0, len(invoices))
require.Equal(t, 0, callCount)
require.Equal(t, 1, resetCount)
numInvoices := len(states) * 2
testPendingInvoices := make(map[lntypes.Hash]*Invoice)
testAllInvoices := make(map[lntypes.Hash]*Invoice)
numInvoices := 5
testInvoices := make(map[lntypes.Hash]*Invoice)
// Now populate the DB and check if we can get all invoices with their
// payment hashes as expected.
for i := 1; i <= numInvoices; i++ {
invoice, err := randInvoice(lnwire.MilliSatoshi(i))
if err != nil {
t.Fatalf("unable to create invoice: %v", err)
}
require.NoError(t, err)
// Set the contract state of the next invoice such that there's an equal
// number for all possbile states.
invoice.State = states[i%len(states)]
paymentHash := invoice.Terms.PaymentPreimage.Hash()
testInvoices[paymentHash] = invoice
if invoice.IsPending() {
testPendingInvoices[paymentHash] = invoice
}
testAllInvoices[paymentHash] = invoice
if _, err := db.AddInvoice(invoice, paymentHash); err != nil {
t.Fatalf("unable to add invoice: %v", err)
}
}
pendingInvoices, err := db.FetchAllInvoicesWithPaymentHash(true)
if err != nil {
t.Fatalf("can't fetch invoices with payment hash: %v", err)
}
if len(testPendingInvoices) != len(pendingInvoices) {
t.Fatalf("expected %v pending invoices, got: %v",
len(testPendingInvoices), len(pendingInvoices))
}
allInvoices, err := db.FetchAllInvoicesWithPaymentHash(false)
if err != nil {
t.Fatalf("can't fetch invoices with payment hash: %v", err)
}
if len(testAllInvoices) != len(allInvoices) {
t.Fatalf("expected %v invoices, got: %v",
len(testAllInvoices), len(allInvoices))
}
for i := range pendingInvoices {
expected, ok := testPendingInvoices[pendingInvoices[i].PaymentHash]
if !ok {
t.Fatalf("coulnd't find invoice with hash: %v",
pendingInvoices[i].PaymentHash)
}
// Zero out add index to not confuse require.Equal.
pendingInvoices[i].Invoice.AddIndex = 0
expected.AddIndex = 0
require.Equal(t, *expected, pendingInvoices[i].Invoice)
}
for i := range allInvoices {
expected, ok := testAllInvoices[allInvoices[i].PaymentHash]
if !ok {
t.Fatalf("coulnd't find invoice with hash: %v",
allInvoices[i].PaymentHash)
}
// Zero out add index to not confuse require.Equal.
allInvoices[i].Invoice.AddIndex = 0
expected.AddIndex = 0
require.Equal(t, *expected, allInvoices[i].Invoice)
_, err = db.AddInvoice(invoice, paymentHash)
require.NoError(t, err)
}
resetCount = 0
require.NoError(t, db.ScanInvoices(scanFunc, reset))
require.Equal(t, numInvoices, callCount)
require.Equal(t, testInvoices, invoices)
require.Equal(t, 1, resetCount)
}
// TestDuplicateSettleInvoice tests that if we add a new invoice and settle it

View File

@ -723,28 +723,21 @@ func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex kvdb.RBucket,
}
}
// InvoiceWithPaymentHash is used to store an invoice and its corresponding
// payment hash. This struct is only used to store results of
// ChannelDB.FetchAllInvoicesWithPaymentHash() call.
type InvoiceWithPaymentHash struct {
// Invoice holds the invoice as selected from the invoices bucket.
Invoice Invoice
// ScanInvoices scans trough all invoices and calls the passed scanFunc for
// for each invoice with its respective payment hash. Additionally a reset()
// closure is passed which is used to reset/initialize partial results and also
// to signal if the kvdb.View transaction has been retried.
func (d *DB) ScanInvoices(
scanFunc func(lntypes.Hash, *Invoice) error, reset func()) error {
// PaymentHash is the payment hash for the Invoice.
PaymentHash lntypes.Hash
}
return kvdb.View(d, func(tx kvdb.RTx) error {
// Reset partial results. As transaction commit success is not
// guaranteed when using etcd, we need to be prepared to redo
// the whole view transaction. In order to be able to do that
// we need a way to reset existing results. This is also done
// upon first run for initialization.
reset()
// FetchAllInvoicesWithPaymentHash returns all invoices and their payment hashes
// currently stored within the database. If the pendingOnly param is true, then
// only open or accepted invoices and their payment hashes will be returned,
// skipping all invoices that are fully settled or canceled. Note that the
// returned array is not ordered by add index.
func (d *DB) FetchAllInvoicesWithPaymentHash(pendingOnly bool) (
[]InvoiceWithPaymentHash, error) {
var result []InvoiceWithPaymentHash
err := kvdb.View(d, func(tx kvdb.RTx) error {
invoices := tx.ReadBucket(invoiceBucket)
if invoices == nil {
return ErrNoInvoicesCreated
@ -775,26 +768,12 @@ func (d *DB) FetchAllInvoicesWithPaymentHash(pendingOnly bool) (
return err
}
if pendingOnly && !invoice.IsPending() {
return nil
}
var paymentHash lntypes.Hash
copy(paymentHash[:], k)
invoiceWithPaymentHash := InvoiceWithPaymentHash{
Invoice: invoice,
}
copy(invoiceWithPaymentHash.PaymentHash[:], k)
result = append(result, invoiceWithPaymentHash)
return nil
return scanFunc(paymentHash, &invoice)
})
})
if err != nil {
return nil, err
}
return result, nil
}
// InvoiceQuery represents a query to the invoice database. The query allows a

View File

@ -129,14 +129,11 @@ func (ew *InvoiceExpiryWatcher) prepareInvoice(
// AddInvoices adds multiple invoices to the InvoiceExpiryWatcher.
func (ew *InvoiceExpiryWatcher) AddInvoices(
invoices []channeldb.InvoiceWithPaymentHash) {
invoices map[lntypes.Hash]*channeldb.Invoice) {
invoicesWithExpiry := make([]*invoiceExpiry, 0, len(invoices))
for _, invoiceWithPaymentHash := range invoices {
newInvoiceExpiry := ew.prepareInvoice(
invoiceWithPaymentHash.PaymentHash,
&invoiceWithPaymentHash.Invoice,
)
for paymentHash, invoice := range invoices {
newInvoiceExpiry := ew.prepareInvoice(paymentHash, invoice)
if newInvoiceExpiry != nil {
invoicesWithExpiry = append(
invoicesWithExpiry, newInvoiceExpiry,

View File

@ -158,24 +158,14 @@ func TestInvoiceExpiryWhenAddingMultipleInvoices(t *testing.T) {
t.Parallel()
test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5)
var invoices []channeldb.InvoiceWithPaymentHash
invoices := make(map[lntypes.Hash]*channeldb.Invoice)
for hash, invoice := range test.testData.expiredInvoices {
invoices = append(invoices,
channeldb.InvoiceWithPaymentHash{
Invoice: *invoice,
PaymentHash: hash,
},
)
invoices[hash] = invoice
}
for hash, invoice := range test.testData.pendingInvoices {
invoices = append(invoices,
channeldb.InvoiceWithPaymentHash{
Invoice: *invoice,
PaymentHash: hash,
},
)
invoices[hash] = invoice
}
test.watcher.AddInvoices(invoices)

View File

@ -147,21 +147,39 @@ func NewRegistry(cdb *channeldb.DB, expiryWatcher *InvoiceExpiryWatcher,
}
}
// populateExpiryWatcher fetches all active invoices and their corresponding
// payment hashes from ChannelDB and adds them to the expiry watcher.
func (i *InvoiceRegistry) populateExpiryWatcher() error {
pendingOnly := true
pendingInvoices, err := i.cdb.FetchAllInvoicesWithPaymentHash(pendingOnly)
if err != nil && err != channeldb.ErrNoInvoicesCreated {
log.Errorf(
"Error while prefetching active invoices from the database: %v", err,
)
// scanInvoicesOnStart will scan all invoices on start and add active invoices
// to the invoice expiry watcher.
func (i *InvoiceRegistry) scanInvoicesOnStart() error {
var pending map[lntypes.Hash]*channeldb.Invoice
reset := func() {
// Zero out our results on start and if the scan is ever run
// more than once. This latter case can happen if the kvdb
// layer needs to retry the View transaction underneath (eg.
// using the etcd driver, where all transactions are allowed
// to retry for serializability).
pending = make(map[lntypes.Hash]*channeldb.Invoice)
}
scanFunc := func(
paymentHash lntypes.Hash, invoice *channeldb.Invoice) error {
if invoice.IsPending() {
pending[paymentHash] = invoice
}
return nil
}
err := i.cdb.ScanInvoices(scanFunc, reset)
if err != nil {
return err
}
log.Debugf("Adding %d pending invoices to the expiry watcher",
len(pendingInvoices))
i.expiryWatcher.AddInvoices(pendingInvoices)
len(pending))
i.expiryWatcher.AddInvoices(pending)
return nil
}
@ -178,8 +196,9 @@ func (i *InvoiceRegistry) Start() error {
i.wg.Add(1)
go i.invoiceEventLoop()
// Now prefetch all pending invoices to the expiry watcher.
err = i.populateExpiryWatcher()
// Now scan all pending and removable invoices to the expiry watcher or
// delete them.
err = i.scanInvoicesOnStart()
if err != nil {
i.Stop()
return err