From ec6c0689ef6d61c060f912781b1d1423f7ba6e7f Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Fri, 3 Jan 2020 15:53:24 +0100 Subject: [PATCH 1/3] channeldb: fix channeldb.InvoiceHTLC deep copy This commit fixes deep copy of chaneldb.InvoiceHTLC, where previously the map holding the custom record set wasn't properly copied. --- channeldb/invoices.go | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/channeldb/invoices.go b/channeldb/invoices.go index ff916a0a..62c767a4 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -1327,6 +1327,19 @@ func copySlice(src []byte) []byte { return dest } +// copyInvoiceHTLC makes a deep copy of the supplied invoice HTLC. +func copyInvoiceHTLC(src *InvoiceHTLC) *InvoiceHTLC { + result := *src + + // Make a copy of the CustomSet map. + result.CustomRecords = make(record.CustomSet) + for k, v := range src.CustomRecords { + result.CustomRecords[k] = v + } + + return &result +} + // copyInvoice makes a deep copy of the supplied invoice. func copyInvoice(src *Invoice) *Invoice { dest := Invoice{ @@ -1347,7 +1360,7 @@ func copyInvoice(src *Invoice) *Invoice { dest.Terms.Features = src.Terms.Features.Clone() for k, v := range src.Htlcs { - dest.Htlcs[k] = v + dest.Htlcs[k] = copyInvoiceHTLC(v) } return &dest From 4136b18e3d737b188f2f91a25a12d75c034ee387 Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Fri, 3 Jan 2020 16:19:49 +0100 Subject: [PATCH 2/3] channeldb: remove time.Now() from tests --- channeldb/invoice_test.go | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index 00de1326..98ece2c6 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -14,6 +14,7 @@ import ( var ( emptyFeatures = lnwire.NewFeatureVector(nil, lnwire.Features) + testNow = time.Unix(1, 0) ) func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) { @@ -23,9 +24,7 @@ func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) { } i := &Invoice{ - // Use single second precision to avoid false positive test - // failures due to the monotonic time component. - CreationDate: time.Unix(time.Now().Unix(), 0), + CreationDate: testNow, Terms: ContractTerm{ Expiry: 4000, PaymentPreimage: pre, @@ -87,9 +86,7 @@ func TestInvoiceWorkflow(t *testing.T) { // Create a fake invoice which we'll use several times in the tests // below. fakeInvoice := &Invoice{ - // Use single second precision to avoid false positive test - // failures due to the monotonic time component. - CreationDate: time.Unix(time.Now().Unix(), 0), + CreationDate: testNow, Htlcs: map[CircuitKey]*InvoiceHTLC{}, } fakeInvoice.Memo = []byte("memo") @@ -285,6 +282,7 @@ func TestInvoiceAddTimeSeries(t *testing.T) { if err != nil { t.Fatalf("unable to make test db: %v", err) } + db.Now = func() time.Time { return testNow } // We'll start off by creating 20 random invoices, and inserting them // into the database. @@ -537,7 +535,7 @@ func TestDuplicateSettleInvoice(t *testing.T) { if err != nil { t.Fatalf("unable to make test db: %v", err) } - db.Now = func() time.Time { return time.Unix(1, 0) } + db.Now = func() time.Time { return testNow } // We'll start out by creating an invoice and writing it to the DB. amt := lnwire.NewMSatFromSatoshis(1000) @@ -675,6 +673,7 @@ func TestQueryInvoices(t *testing.T) { if err != nil { t.Fatalf("unable to make test db: %v", err) } + db.Now = func() time.Time { return testNow } // To begin the test, we'll add 50 invoices to the database. We'll // assume that the index of the invoice within the database is the same From 8af5d8bc73fbaa0ea57c137fa2644cde7f605be4 Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Fri, 3 Jan 2020 16:21:40 +0100 Subject: [PATCH 3/3] channeldb: remove unused, test only FetchAllInvoices function This commit removes channeldb.FetchAllInvoices and changes tests such that expectation sets are prepared in the test case instead of selected from the DB. --- channeldb/invoice_test.go | 147 ++++++++++++++------------------------ channeldb/invoices.go | 43 ----------- 2 files changed, 53 insertions(+), 137 deletions(-) diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index 98ece2c6..922239e5 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -2,6 +2,7 @@ package channeldb import ( "crypto/rand" + "math" "reflect" "testing" "time" @@ -51,6 +52,21 @@ func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) { return i, nil } +// settleTestInvoice settles a test invoice. +func settleTestInvoice(invoice *Invoice, settleIndex uint64) { + invoice.SettleDate = testNow + invoice.AmtPaid = invoice.Terms.Value + invoice.State = ContractSettled + invoice.Htlcs[CircuitKey{}] = &InvoiceHTLC{ + Amt: invoice.Terms.Value, + AcceptTime: testNow, + ResolveTime: testNow, + State: HtlcStateSettled, + CustomRecords: make(record.CustomSet), + } + invoice.SettleIndex = settleIndex +} + // Tests that pending invoices are those which are either in ContractOpen or // in ContractAccepted state. func TestInvoiceIsPending(t *testing.T) { @@ -173,7 +189,7 @@ func TestInvoiceWorkflow(t *testing.T) { amt := lnwire.NewMSatFromSatoshis(1000) invoices := make([]*Invoice, numInvoices+1) invoices[0] = &dbInvoice2 - for i := 1; i < len(invoices)-1; i++ { + for i := 1; i < len(invoices); i++ { invoice, err := randInvoice(amt) if err != nil { t.Fatalf("unable to create invoice: %v", err) @@ -188,20 +204,26 @@ func TestInvoiceWorkflow(t *testing.T) { } // Perform a scan to collect all the active invoices. - dbInvoices, err := db.FetchAllInvoices(false) + query := InvoiceQuery{ + IndexOffset: 0, + NumMaxInvoices: math.MaxUint64, + PendingOnly: false, + } + + response, err := db.QueryInvoices(query) if err != nil { - t.Fatalf("unable to fetch all invoices: %v", err) + t.Fatalf("invoice query failed: %v", err) } // The retrieve list of invoices should be identical as since we're // using big endian, the invoices should be retrieved in ascending // order (and the primary key should be incremented with each // insertion). - for i := 0; i < len(invoices)-1; i++ { - if !reflect.DeepEqual(*invoices[i], dbInvoices[i]) { + for i := 0; i < len(invoices); i++ { + if !reflect.DeepEqual(*invoices[i], response.Invoices[i]) { t.Fatalf("retrieved invoices don't match %v vs %v", spew.Sdump(invoices[i]), - spew.Sdump(dbInvoices[i])) + spew.Sdump(response.Invoices[i])) } } } @@ -351,6 +373,8 @@ func TestInvoiceAddTimeSeries(t *testing.T) { } } + var settledInvoices []Invoice + var settleIndex uint64 = 1 // We'll now only settle the latter half of each of those invoices. for i := 10; i < len(invoices); i++ { invoice := &invoices[i] @@ -358,21 +382,18 @@ func TestInvoiceAddTimeSeries(t *testing.T) { paymentHash := invoice.Terms.PaymentPreimage.Hash() _, err := db.UpdateInvoice( - paymentHash, getUpdateInvoice(0), + paymentHash, getUpdateInvoice(invoice.Terms.Value), ) if err != nil { t.Fatalf("unable to settle invoice: %v", err) } - } - invoices, err = db.FetchAllInvoices(false) - if err != nil { - t.Fatalf("unable to fetch invoices: %v", err) - } + // Create the settled invoice for the expectation set. + settleTestInvoice(invoice, settleIndex) + settleIndex++ - // We'll slice off the first 10 invoices, as we only settled the last - // 10. - invoices = invoices[10:] + settledInvoices = append(settledInvoices, *invoice) + } // We'll now prepare an additional set of queries to ensure the settle // time series has properly been maintained in the database. @@ -397,7 +418,7 @@ func TestInvoiceAddTimeSeries(t *testing.T) { // being returned, as we only settled those. { sinceSettleIndex: 1, - resp: invoices[1:], + resp: settledInvoices[1:], }, } @@ -600,69 +621,6 @@ func TestDuplicateSettleInvoice(t *testing.T) { } } -// TestFetchAllInvoices tests that FetchAllInvoices works as expected. -func TestFetchAllInvoices(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test db: %v", err) - } - - contractStates := []ContractState{ - ContractOpen, ContractSettled, ContractCanceled, ContractAccepted, - } - - numInvoices := len(contractStates) * 2 - - var expectedPendingInvoices []Invoice - var expectedAllInvoices []Invoice - - for i := 1; i <= numInvoices; i++ { - invoice, err := randInvoice(lnwire.MilliSatoshi(i)) - - if err != nil { - t.Fatalf("unable to create invoice: %v", err) - } - - invoice.AddIndex = uint64(i) - // Set the contract state of the next invoice such that there's an equal - // number for all possbile states. - invoice.State = contractStates[i%len(contractStates)] - - paymentHash := invoice.Terms.PaymentPreimage.Hash() - if invoice.IsPending() { - expectedPendingInvoices = append(expectedPendingInvoices, *invoice) - } - expectedAllInvoices = append(expectedAllInvoices, *invoice) - - if _, err := db.AddInvoice(invoice, paymentHash); err != nil { - t.Fatalf("unable to add invoice: %v", err) - } - } - - pendingInvoices, err := db.FetchAllInvoices(true) - if err != nil { - t.Fatalf("unable to fetch all pending invoices: %v", err) - } - - allInvoices, err := db.FetchAllInvoices(false) - if err != nil { - t.Fatalf("unable to fetch all non pending invoices: %v", err) - } - - if !reflect.DeepEqual(pendingInvoices, expectedPendingInvoices) { - t.Fatalf("pending invoices: %v\n != \n expected einvoices: %v", - spew.Sdump(pendingInvoices), spew.Sdump(expectedPendingInvoices)) - } - - if !reflect.DeepEqual(allInvoices, expectedAllInvoices) { - t.Fatalf("pending + non pending: %v\n != \n expected: %v", - spew.Sdump(allInvoices), spew.Sdump(expectedAllInvoices)) - } -} - // TestQueryInvoices ensures that we can properly query the invoice database for // invoices using different types of queries. func TestQueryInvoices(t *testing.T) { @@ -679,8 +637,13 @@ func TestQueryInvoices(t *testing.T) { // assume that the index of the invoice within the database is the same // as the amount of the invoice itself. const numInvoices = 50 - for i := lnwire.MilliSatoshi(1); i <= numInvoices; i++ { - invoice, err := randInvoice(i) + var settleIndex uint64 = 1 + var invoices []Invoice + var pendingInvoices []Invoice + + for i := 1; i <= numInvoices; i++ { + amt := lnwire.MilliSatoshi(i) + invoice, err := randInvoice(amt) if err != nil { t.Fatalf("unable to create invoice: %v", err) } @@ -694,24 +657,20 @@ func TestQueryInvoices(t *testing.T) { // We'll only settle half of all invoices created. if i%2 == 0 { _, err := db.UpdateInvoice( - paymentHash, getUpdateInvoice(i), + paymentHash, getUpdateInvoice(amt), ) if err != nil { t.Fatalf("unable to settle invoice: %v", err) } - } - } - // We'll then retrieve the set of all invoices and pending invoices. - // This will serve useful when comparing the expected responses of the - // query with the actual ones. - invoices, err := db.FetchAllInvoices(false) - if err != nil { - t.Fatalf("unable to retrieve invoices: %v", err) - } - pendingInvoices, err := db.FetchAllInvoices(true) - if err != nil { - t.Fatalf("unable to retrieve pending invoices: %v", err) + // Create the settled invoice for the expectation set. + settleTestInvoice(invoice, settleIndex) + settleIndex++ + } else { + pendingInvoices = append(pendingInvoices, *invoice) + } + + invoices = append(invoices, *invoice) } // The test will consist of several queries along with their respective diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 62c767a4..006d5e68 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -642,49 +642,6 @@ func (d *DB) FetchAllInvoicesWithPaymentHash(pendingOnly bool) ( return result, nil } -// FetchAllInvoices returns all invoices currently stored within the database. -// If the pendingOnly param is set to true, then only invoices in open or -// accepted state will be returned, skipping all invoices that are fully -// settled or canceled. -func (d *DB) FetchAllInvoices(pendingOnly bool) ([]Invoice, error) { - var invoices []Invoice - - err := d.View(func(tx *bbolt.Tx) error { - invoiceB := tx.Bucket(invoiceBucket) - if invoiceB == nil { - return ErrNoInvoicesCreated - } - - // Iterate through the entire key space of the top-level - // invoice bucket. If key with a non-nil value stores the next - // invoice ID which maps to the corresponding invoice. - return invoiceB.ForEach(func(k, v []byte) error { - if v == nil { - return nil - } - - invoiceReader := bytes.NewReader(v) - invoice, err := deserializeInvoice(invoiceReader) - if err != nil { - return err - } - - if pendingOnly && !invoice.IsPending() { - return nil - } - - invoices = append(invoices, invoice) - - return nil - }) - }) - if err != nil { - return nil, err - } - - return invoices, nil -} - // InvoiceQuery represents a query to the invoice database. The query allows a // caller to retrieve all invoices starting from a particular add index and // limit the number of results returned.