diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index 0cf0e0e6..9398668b 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -1,6 +1,7 @@ package channeldb import ( + "crypto/rand" "reflect" "testing" "time" @@ -10,6 +11,26 @@ import ( "github.com/roasbeef/btcutil" ) +func randInvoice(value btcutil.Amount) (*Invoice, error) { + + var pre [32]byte + if _, err := rand.Read(pre[:]); err != nil { + return nil, err + } + + i := &Invoice{ + CreationDate: time.Now(), + Terms: ContractTerm{ + PaymentPreimage: pre, + Value: value, + }, + } + copy(i.Memo[:], []byte("memo")) + copy(i.Receipt[:], []byte("recipt")) + + return i, nil +} + func TestInvoiceWorkflow(t *testing.T) { db, cleanUp, err := makeTestDB() if err != nil { @@ -73,4 +94,40 @@ func TestInvoiceWorkflow(t *testing.T) { if _, err := db.LookupInvoice(fakeHash); err != ErrInvoiceNotFound { t.Fatalf("lookup should have failed, instead %v", err) } + + // Add 100 random invoices. + const numInvoices = 10 + amt := btcutil.Amount(1000) + invoices := make([]*Invoice, numInvoices+1) + invoices[0] = dbInvoice2 + for i := 1; i < len(invoices)-1; i++ { + invoice, err := randInvoice(amt) + if err != nil { + t.Fatalf("unable to create invoice: %v", err) + } + + if err := db.AddInvoice(invoice); err != nil { + t.Fatalf("unable to add invoice %v", err) + } + + invoices[i] = invoice + } + + // Perform a scan to collect all the active invoices. + dbInvoices, err := db.FetchAllInvoices(false) + if err != nil { + t.Fatalf("unable to fetch all invoices: %v", err) + } + + // The retrieve list of invoices should be identical as since we're + // using big endian, the invoices should be retrieved in asecending + // order (and the primary key should be incremented with each + // insertion). + for i := 0; i < len(invoices)-1; i++ { + if !reflect.DeepEqual(invoices[i], dbInvoices[i]) { + t.Fatalf("retrived invoices don't match %v vs %v", + spew.Sdump(invoices[i]), + spew.Sdump(dbInvoices[i])) + } + } } diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 073568a8..d27ca23f 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -180,6 +180,48 @@ func (d *DB) LookupInvoice(paymentHash [32]byte) (*Invoice, error) { return invoice, nil } +// FetchAllInvoices returns all invoices currently stored within the database. +// If the pendingOnly param is true, then only unsettled invoices will be +// returned, skipping all invoices that are fully settled. +func (d *DB) FetchAllInvoices(pendingOnly bool) ([]*Invoice, error) { + var invoices []*Invoice + + err := d.store.View(func(tx *bolt.Tx) error { + invoiceB := tx.Bucket(invoiceBucket) + if invoiceB == nil { + return ErrNoInvoicesCreated + } + + // Iterate through the entire key space of the top-level + // invoice bucket. If key with a non-nil value stores the next + // invoice ID which maps to the corresponding invoice. + return invoiceB.ForEach(func(k, v []byte) error { + if v == nil { + return nil + } + + invoiceReader := bytes.NewReader(v) + invoice, err := deserializeInvoice(invoiceReader) + if err != nil { + return err + } + + if pendingOnly && invoice.Terms.Settled { + return nil + } + + invoices = append(invoices, invoice) + + return nil + }) + }) + if err != nil { + return nil, err + } + + return invoices, nil +} + // SettleInvoice attempts to mark an invoice corresponding to the passed // payment hash as fully settled. If an invoice matching the passed payment // hash doesn't existing within the database, then the action will fail with a