diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index aca845d5..4f3ee5dd 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -70,7 +70,7 @@ func TestInvoiceWorkflow(t *testing.T) { // Add the invoice to the database, this should succeed as there aren't // any existing invoices within the database with the same payment // hash. - if err := db.AddInvoice(fakeInvoice); err != nil { + if _, err := db.AddInvoice(fakeInvoice); err != nil { t.Fatalf("unable to find invoice: %v", err) } @@ -126,7 +126,7 @@ func TestInvoiceWorkflow(t *testing.T) { // Attempt to insert generated above again, this should fail as // duplicates are rejected by the processing logic. - if err := db.AddInvoice(fakeInvoice); err != ErrDuplicateInvoice { + if _, err := db.AddInvoice(fakeInvoice); err != ErrDuplicateInvoice { t.Fatalf("invoice insertion should fail due to duplication, "+ "instead %v", err) } @@ -149,7 +149,7 @@ func TestInvoiceWorkflow(t *testing.T) { t.Fatalf("unable to create invoice: %v", err) } - if err := db.AddInvoice(invoice); err != nil { + if _, err := db.AddInvoice(invoice); err != nil { t.Fatalf("unable to add invoice %v", err) } @@ -198,7 +198,7 @@ func TestInvoiceAddTimeSeries(t *testing.T) { t.Fatalf("unable to create invoice: %v", err) } - if err := db.AddInvoice(invoice); err != nil { + if _, err := db.AddInvoice(invoice); err != nil { t.Fatalf("unable to add invoice %v", err) } diff --git a/channeldb/invoices.go b/channeldb/invoices.go index a6fe0df9..a05f57df 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -179,11 +179,12 @@ func validateInvoice(i *Invoice) error { // has *any* payment hashes which already exists within the database, then the // insertion will be aborted and rejected due to the strict policy banning any // duplicate payment hashes. -func (d *DB) AddInvoice(newInvoice *Invoice) error { +func (d *DB) AddInvoice(newInvoice *Invoice) (uint64, error) { if err := validateInvoice(newInvoice); err != nil { - return err + return 0, err } + var invoiceAddIndex uint64 err := d.Update(func(tx *bolt.Tx) error { invoices, err := tx.CreateBucketIfNotExists(invoiceBucket) if err != nil { @@ -227,15 +228,21 @@ func (d *DB) AddInvoice(newInvoice *Invoice) error { invoiceNum = byteOrder.Uint32(invoiceCounter) } - return putInvoice( + newIndex, err := putInvoice( invoices, invoiceIndex, addIndex, newInvoice, invoiceNum, ) + if err != nil { + return err + } + + invoiceAddIndex = newIndex + return nil }) if err != nil { - return err + return 0, err } - return err + return invoiceAddIndex, err } // InvoicesAddedSince can be used by callers to seek into the event time series @@ -501,7 +508,7 @@ func (d *DB) InvoicesSettledSince(sinceSettleIndex uint64) ([]Invoice, error) { } func putInvoice(invoices, invoiceIndex, addIndex *bolt.Bucket, - i *Invoice, invoiceNum uint32) error { + i *Invoice, invoiceNum uint32) (uint64, error) { // Create the invoice key which is just the big-endian representation // of the invoice number. @@ -514,7 +521,7 @@ func putInvoice(invoices, invoiceIndex, addIndex *bolt.Bucket, invoiceCounter := invoiceNum + 1 byteOrder.PutUint32(scratch[:], invoiceCounter) if err := invoiceIndex.Put(numInvoicesKey, scratch[:]); err != nil { - return err + return 0, err } // Add the payment hash to the invoice index. This will let us quickly @@ -523,7 +530,7 @@ func putInvoice(invoices, invoiceIndex, addIndex *bolt.Bucket, paymentHash := sha256.Sum256(i.Terms.PaymentPreimage[:]) err := invoiceIndex.Put(paymentHash[:], invoiceKey[:]) if err != nil { - return err + return 0, err } // Next, we'll obtain the next add invoice index (sequence @@ -531,7 +538,7 @@ func putInvoice(invoices, invoiceIndex, addIndex *bolt.Bucket, // event stream. nextAddSeqNo, err := addIndex.NextSequence() if err != nil { - return err + return 0, err } // With the next sequence obtained, we'll updating the event series in @@ -540,7 +547,7 @@ func putInvoice(invoices, invoiceIndex, addIndex *bolt.Bucket, var seqNoBytes [8]byte byteOrder.PutUint64(seqNoBytes[:], nextAddSeqNo) if err := addIndex.Put(seqNoBytes[:], invoiceKey[:]); err != nil { - return err + return 0, err } i.AddIndex = nextAddSeqNo @@ -548,10 +555,14 @@ func putInvoice(invoices, invoiceIndex, addIndex *bolt.Bucket, // Finally, serialize the invoice itself to be written to the disk. var buf bytes.Buffer if err := serializeInvoice(&buf, i); err != nil { - return nil + return 0, nil } - return invoices.Put(invoiceKey[:], buf.Bytes()) + if err := invoices.Put(invoiceKey[:], buf.Bytes()); err != nil { + return 0, err + } + + return nextAddSeqNo, nil } func serializeInvoice(w io.Writer, i *Invoice) error {