diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index 98ece2c6..922239e5 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -2,6 +2,7 @@ package channeldb import ( "crypto/rand" + "math" "reflect" "testing" "time" @@ -51,6 +52,21 @@ func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) { return i, nil } +// settleTestInvoice settles a test invoice. +func settleTestInvoice(invoice *Invoice, settleIndex uint64) { + invoice.SettleDate = testNow + invoice.AmtPaid = invoice.Terms.Value + invoice.State = ContractSettled + invoice.Htlcs[CircuitKey{}] = &InvoiceHTLC{ + Amt: invoice.Terms.Value, + AcceptTime: testNow, + ResolveTime: testNow, + State: HtlcStateSettled, + CustomRecords: make(record.CustomSet), + } + invoice.SettleIndex = settleIndex +} + // Tests that pending invoices are those which are either in ContractOpen or // in ContractAccepted state. func TestInvoiceIsPending(t *testing.T) { @@ -173,7 +189,7 @@ func TestInvoiceWorkflow(t *testing.T) { amt := lnwire.NewMSatFromSatoshis(1000) invoices := make([]*Invoice, numInvoices+1) invoices[0] = &dbInvoice2 - for i := 1; i < len(invoices)-1; i++ { + for i := 1; i < len(invoices); i++ { invoice, err := randInvoice(amt) if err != nil { t.Fatalf("unable to create invoice: %v", err) @@ -188,20 +204,26 @@ func TestInvoiceWorkflow(t *testing.T) { } // Perform a scan to collect all the active invoices. - dbInvoices, err := db.FetchAllInvoices(false) + query := InvoiceQuery{ + IndexOffset: 0, + NumMaxInvoices: math.MaxUint64, + PendingOnly: false, + } + + response, err := db.QueryInvoices(query) if err != nil { - t.Fatalf("unable to fetch all invoices: %v", err) + t.Fatalf("invoice query failed: %v", err) } // The retrieve list of invoices should be identical as since we're // using big endian, the invoices should be retrieved in ascending // order (and the primary key should be incremented with each // insertion). - for i := 0; i < len(invoices)-1; i++ { - if !reflect.DeepEqual(*invoices[i], dbInvoices[i]) { + for i := 0; i < len(invoices); i++ { + if !reflect.DeepEqual(*invoices[i], response.Invoices[i]) { t.Fatalf("retrieved invoices don't match %v vs %v", spew.Sdump(invoices[i]), - spew.Sdump(dbInvoices[i])) + spew.Sdump(response.Invoices[i])) } } } @@ -351,6 +373,8 @@ func TestInvoiceAddTimeSeries(t *testing.T) { } } + var settledInvoices []Invoice + var settleIndex uint64 = 1 // We'll now only settle the latter half of each of those invoices. for i := 10; i < len(invoices); i++ { invoice := &invoices[i] @@ -358,21 +382,18 @@ func TestInvoiceAddTimeSeries(t *testing.T) { paymentHash := invoice.Terms.PaymentPreimage.Hash() _, err := db.UpdateInvoice( - paymentHash, getUpdateInvoice(0), + paymentHash, getUpdateInvoice(invoice.Terms.Value), ) if err != nil { t.Fatalf("unable to settle invoice: %v", err) } - } - invoices, err = db.FetchAllInvoices(false) - if err != nil { - t.Fatalf("unable to fetch invoices: %v", err) - } + // Create the settled invoice for the expectation set. + settleTestInvoice(invoice, settleIndex) + settleIndex++ - // We'll slice off the first 10 invoices, as we only settled the last - // 10. - invoices = invoices[10:] + settledInvoices = append(settledInvoices, *invoice) + } // We'll now prepare an additional set of queries to ensure the settle // time series has properly been maintained in the database. @@ -397,7 +418,7 @@ func TestInvoiceAddTimeSeries(t *testing.T) { // being returned, as we only settled those. { sinceSettleIndex: 1, - resp: invoices[1:], + resp: settledInvoices[1:], }, } @@ -600,69 +621,6 @@ func TestDuplicateSettleInvoice(t *testing.T) { } } -// TestFetchAllInvoices tests that FetchAllInvoices works as expected. -func TestFetchAllInvoices(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test db: %v", err) - } - - contractStates := []ContractState{ - ContractOpen, ContractSettled, ContractCanceled, ContractAccepted, - } - - numInvoices := len(contractStates) * 2 - - var expectedPendingInvoices []Invoice - var expectedAllInvoices []Invoice - - for i := 1; i <= numInvoices; i++ { - invoice, err := randInvoice(lnwire.MilliSatoshi(i)) - - if err != nil { - t.Fatalf("unable to create invoice: %v", err) - } - - invoice.AddIndex = uint64(i) - // Set the contract state of the next invoice such that there's an equal - // number for all possbile states. - invoice.State = contractStates[i%len(contractStates)] - - paymentHash := invoice.Terms.PaymentPreimage.Hash() - if invoice.IsPending() { - expectedPendingInvoices = append(expectedPendingInvoices, *invoice) - } - expectedAllInvoices = append(expectedAllInvoices, *invoice) - - if _, err := db.AddInvoice(invoice, paymentHash); err != nil { - t.Fatalf("unable to add invoice: %v", err) - } - } - - pendingInvoices, err := db.FetchAllInvoices(true) - if err != nil { - t.Fatalf("unable to fetch all pending invoices: %v", err) - } - - allInvoices, err := db.FetchAllInvoices(false) - if err != nil { - t.Fatalf("unable to fetch all non pending invoices: %v", err) - } - - if !reflect.DeepEqual(pendingInvoices, expectedPendingInvoices) { - t.Fatalf("pending invoices: %v\n != \n expected einvoices: %v", - spew.Sdump(pendingInvoices), spew.Sdump(expectedPendingInvoices)) - } - - if !reflect.DeepEqual(allInvoices, expectedAllInvoices) { - t.Fatalf("pending + non pending: %v\n != \n expected: %v", - spew.Sdump(allInvoices), spew.Sdump(expectedAllInvoices)) - } -} - // TestQueryInvoices ensures that we can properly query the invoice database for // invoices using different types of queries. func TestQueryInvoices(t *testing.T) { @@ -679,8 +637,13 @@ func TestQueryInvoices(t *testing.T) { // assume that the index of the invoice within the database is the same // as the amount of the invoice itself. const numInvoices = 50 - for i := lnwire.MilliSatoshi(1); i <= numInvoices; i++ { - invoice, err := randInvoice(i) + var settleIndex uint64 = 1 + var invoices []Invoice + var pendingInvoices []Invoice + + for i := 1; i <= numInvoices; i++ { + amt := lnwire.MilliSatoshi(i) + invoice, err := randInvoice(amt) if err != nil { t.Fatalf("unable to create invoice: %v", err) } @@ -694,24 +657,20 @@ func TestQueryInvoices(t *testing.T) { // We'll only settle half of all invoices created. if i%2 == 0 { _, err := db.UpdateInvoice( - paymentHash, getUpdateInvoice(i), + paymentHash, getUpdateInvoice(amt), ) if err != nil { t.Fatalf("unable to settle invoice: %v", err) } - } - } - // We'll then retrieve the set of all invoices and pending invoices. - // This will serve useful when comparing the expected responses of the - // query with the actual ones. - invoices, err := db.FetchAllInvoices(false) - if err != nil { - t.Fatalf("unable to retrieve invoices: %v", err) - } - pendingInvoices, err := db.FetchAllInvoices(true) - if err != nil { - t.Fatalf("unable to retrieve pending invoices: %v", err) + // Create the settled invoice for the expectation set. + settleTestInvoice(invoice, settleIndex) + settleIndex++ + } else { + pendingInvoices = append(pendingInvoices, *invoice) + } + + invoices = append(invoices, *invoice) } // The test will consist of several queries along with their respective diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 62c767a4..006d5e68 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -642,49 +642,6 @@ func (d *DB) FetchAllInvoicesWithPaymentHash(pendingOnly bool) ( return result, nil } -// FetchAllInvoices returns all invoices currently stored within the database. -// If the pendingOnly param is set to true, then only invoices in open or -// accepted state will be returned, skipping all invoices that are fully -// settled or canceled. -func (d *DB) FetchAllInvoices(pendingOnly bool) ([]Invoice, error) { - var invoices []Invoice - - err := d.View(func(tx *bbolt.Tx) error { - invoiceB := tx.Bucket(invoiceBucket) - if invoiceB == nil { - return ErrNoInvoicesCreated - } - - // Iterate through the entire key space of the top-level - // invoice bucket. If key with a non-nil value stores the next - // invoice ID which maps to the corresponding invoice. - return invoiceB.ForEach(func(k, v []byte) error { - if v == nil { - return nil - } - - invoiceReader := bytes.NewReader(v) - invoice, err := deserializeInvoice(invoiceReader) - if err != nil { - return err - } - - if pendingOnly && !invoice.IsPending() { - return nil - } - - invoices = append(invoices, invoice) - - return nil - }) - }) - if err != nil { - return nil, err - } - - return invoices, nil -} - // InvoiceQuery represents a query to the invoice database. The query allows a // caller to retrieve all invoices starting from a particular add index and // limit the number of results returned.