diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index b38c5cad..467d7a89 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -2,6 +2,7 @@ package channeldb import ( "crypto/rand" + mrand "math/rand" "reflect" "testing" "time" @@ -393,6 +394,114 @@ func TestInvoiceAddTimeSeries(t *testing.T) { } } +// Tests that FetchAllInvoicesWithPaymentHash returns all invoices with their +// corresponding payment hashes. +func TestFetchAllInvoicesWithPaymentHash(t *testing.T) { + t.Parallel() + + db, cleanup, err := makeTestDB() + defer cleanup() + if err != nil { + t.Fatalf("unable to make test db: %v", err) + } + + // With an empty DB we expect to return no error and an empty list. + empty, err := db.FetchAllInvoicesWithPaymentHash(false) + if err != nil { + t.Fatalf("failed to call FetchAllInvoicesWithPaymentHash on empty DB: %v", + err) + } + + if len(empty) != 0 { + t.Fatalf("expected empty list as a result, got: %v", empty) + } + + // Now populate the DB and check if we can get all invoices with their + // payment hashes as expected. + const numInvoices = 20 + testPendingInvoices := make(map[lntypes.Hash]*Invoice) + testAllInvoices := make(map[lntypes.Hash]*Invoice) + + states := []ContractState{ + ContractOpen, ContractSettled, ContractCanceled, ContractAccepted, + } + + for i := lnwire.MilliSatoshi(1); i <= numInvoices; i++ { + invoice, err := randInvoice(i) + if err != nil { + t.Fatalf("unable to create invoice: %v", err) + } + + invoice.State = states[mrand.Intn(len(states))] + paymentHash := invoice.Terms.PaymentPreimage.Hash() + + if invoice.State != ContractSettled && invoice.State != ContractCanceled { + testPendingInvoices[paymentHash] = invoice + } + + testAllInvoices[paymentHash] = invoice + + if _, err := db.AddInvoice(invoice, paymentHash); err != nil { + t.Fatalf("unable to add invoice: %v", err) + } + } + + pendingInvoices, err := db.FetchAllInvoicesWithPaymentHash(true) + if err != nil { + t.Fatalf("can't fetch invoices with payment hash: %v", err) + } + + if len(testPendingInvoices) != len(pendingInvoices) { + t.Fatalf("expected %v pending invoices, got: %v", + len(testPendingInvoices), len(pendingInvoices)) + } + + allInvoices, err := db.FetchAllInvoicesWithPaymentHash(false) + if err != nil { + t.Fatalf("can't fetch invoices with payment hash: %v", err) + } + + if len(testAllInvoices) != len(allInvoices) { + t.Fatalf("expected %v invoices, got: %v", + len(testAllInvoices), len(allInvoices)) + } + + for i := range pendingInvoices { + expected, ok := testPendingInvoices[pendingInvoices[i].PaymentHash] + if !ok { + t.Fatalf("coulnd't find invoice with hash: %v", + pendingInvoices[i].PaymentHash) + } + + // Zero out add index to not confuse DeepEqual. + pendingInvoices[i].Invoice.AddIndex = 0 + expected.AddIndex = 0 + + if !reflect.DeepEqual(*expected, pendingInvoices[i].Invoice) { + t.Fatalf("expected: %v, got: %v", + spew.Sdump(expected), spew.Sdump(pendingInvoices[i].Invoice)) + } + } + + for i := range allInvoices { + expected, ok := testAllInvoices[allInvoices[i].PaymentHash] + if !ok { + t.Fatalf("coulnd't find invoice with hash: %v", + allInvoices[i].PaymentHash) + } + + // Zero out add index to not confuse DeepEqual. + allInvoices[i].Invoice.AddIndex = 0 + expected.AddIndex = 0 + + if !reflect.DeepEqual(*expected, allInvoices[i].Invoice) { + t.Fatalf("expected: %v, got: %v", + spew.Sdump(expected), spew.Sdump(allInvoices[i].Invoice)) + } + } + +} + // TestDuplicateSettleInvoice tests that if we add a new invoice and settle it // twice, then the second time we also receive the invoice that we settled as a // return argument.