channeldb: AddInvoice now returns the addIndex of the new invoice

This commit is contained in:
Olaoluwa Osuntokun 2018-06-29 18:05:51 -07:00
parent 2dcc2d63a6
commit 7aeed0b58f
No known key found for this signature in database
GPG Key ID: 964EA263DD637C21
2 changed files with 27 additions and 16 deletions

@ -70,7 +70,7 @@ func TestInvoiceWorkflow(t *testing.T) {
// Add the invoice to the database, this should succeed as there aren't // Add the invoice to the database, this should succeed as there aren't
// any existing invoices within the database with the same payment // any existing invoices within the database with the same payment
// hash. // hash.
if err := db.AddInvoice(fakeInvoice); err != nil { if _, err := db.AddInvoice(fakeInvoice); err != nil {
t.Fatalf("unable to find invoice: %v", err) t.Fatalf("unable to find invoice: %v", err)
} }
@ -126,7 +126,7 @@ func TestInvoiceWorkflow(t *testing.T) {
// Attempt to insert generated above again, this should fail as // Attempt to insert generated above again, this should fail as
// duplicates are rejected by the processing logic. // duplicates are rejected by the processing logic.
if err := db.AddInvoice(fakeInvoice); err != ErrDuplicateInvoice { if _, err := db.AddInvoice(fakeInvoice); err != ErrDuplicateInvoice {
t.Fatalf("invoice insertion should fail due to duplication, "+ t.Fatalf("invoice insertion should fail due to duplication, "+
"instead %v", err) "instead %v", err)
} }
@ -149,7 +149,7 @@ func TestInvoiceWorkflow(t *testing.T) {
t.Fatalf("unable to create invoice: %v", err) t.Fatalf("unable to create invoice: %v", err)
} }
if err := db.AddInvoice(invoice); err != nil { if _, err := db.AddInvoice(invoice); err != nil {
t.Fatalf("unable to add invoice %v", err) t.Fatalf("unable to add invoice %v", err)
} }
@ -198,7 +198,7 @@ func TestInvoiceAddTimeSeries(t *testing.T) {
t.Fatalf("unable to create invoice: %v", err) t.Fatalf("unable to create invoice: %v", err)
} }
if err := db.AddInvoice(invoice); err != nil { if _, err := db.AddInvoice(invoice); err != nil {
t.Fatalf("unable to add invoice %v", err) t.Fatalf("unable to add invoice %v", err)
} }

@ -179,11 +179,12 @@ func validateInvoice(i *Invoice) error {
// has *any* payment hashes which already exists within the database, then the // has *any* payment hashes which already exists within the database, then the
// insertion will be aborted and rejected due to the strict policy banning any // insertion will be aborted and rejected due to the strict policy banning any
// duplicate payment hashes. // duplicate payment hashes.
func (d *DB) AddInvoice(newInvoice *Invoice) error { func (d *DB) AddInvoice(newInvoice *Invoice) (uint64, error) {
if err := validateInvoice(newInvoice); err != nil { if err := validateInvoice(newInvoice); err != nil {
return err return 0, err
} }
var invoiceAddIndex uint64
err := d.Update(func(tx *bolt.Tx) error { err := d.Update(func(tx *bolt.Tx) error {
invoices, err := tx.CreateBucketIfNotExists(invoiceBucket) invoices, err := tx.CreateBucketIfNotExists(invoiceBucket)
if err != nil { if err != nil {
@ -227,15 +228,21 @@ func (d *DB) AddInvoice(newInvoice *Invoice) error {
invoiceNum = byteOrder.Uint32(invoiceCounter) invoiceNum = byteOrder.Uint32(invoiceCounter)
} }
return putInvoice( newIndex, err := putInvoice(
invoices, invoiceIndex, addIndex, newInvoice, invoiceNum, invoices, invoiceIndex, addIndex, newInvoice, invoiceNum,
) )
if err != nil {
return err
}
invoiceAddIndex = newIndex
return nil
}) })
if err != nil { if err != nil {
return err return 0, err
} }
return err return invoiceAddIndex, err
} }
// InvoicesAddedSince can be used by callers to seek into the event time series // InvoicesAddedSince can be used by callers to seek into the event time series
@ -501,7 +508,7 @@ func (d *DB) InvoicesSettledSince(sinceSettleIndex uint64) ([]Invoice, error) {
} }
func putInvoice(invoices, invoiceIndex, addIndex *bolt.Bucket, func putInvoice(invoices, invoiceIndex, addIndex *bolt.Bucket,
i *Invoice, invoiceNum uint32) error { i *Invoice, invoiceNum uint32) (uint64, error) {
// Create the invoice key which is just the big-endian representation // Create the invoice key which is just the big-endian representation
// of the invoice number. // of the invoice number.
@ -514,7 +521,7 @@ func putInvoice(invoices, invoiceIndex, addIndex *bolt.Bucket,
invoiceCounter := invoiceNum + 1 invoiceCounter := invoiceNum + 1
byteOrder.PutUint32(scratch[:], invoiceCounter) byteOrder.PutUint32(scratch[:], invoiceCounter)
if err := invoiceIndex.Put(numInvoicesKey, scratch[:]); err != nil { if err := invoiceIndex.Put(numInvoicesKey, scratch[:]); err != nil {
return err return 0, err
} }
// Add the payment hash to the invoice index. This will let us quickly // Add the payment hash to the invoice index. This will let us quickly
@ -523,7 +530,7 @@ func putInvoice(invoices, invoiceIndex, addIndex *bolt.Bucket,
paymentHash := sha256.Sum256(i.Terms.PaymentPreimage[:]) paymentHash := sha256.Sum256(i.Terms.PaymentPreimage[:])
err := invoiceIndex.Put(paymentHash[:], invoiceKey[:]) err := invoiceIndex.Put(paymentHash[:], invoiceKey[:])
if err != nil { if err != nil {
return err return 0, err
} }
// Next, we'll obtain the next add invoice index (sequence // Next, we'll obtain the next add invoice index (sequence
@ -531,7 +538,7 @@ func putInvoice(invoices, invoiceIndex, addIndex *bolt.Bucket,
// event stream. // event stream.
nextAddSeqNo, err := addIndex.NextSequence() nextAddSeqNo, err := addIndex.NextSequence()
if err != nil { if err != nil {
return err return 0, err
} }
// With the next sequence obtained, we'll updating the event series in // With the next sequence obtained, we'll updating the event series in
@ -540,7 +547,7 @@ func putInvoice(invoices, invoiceIndex, addIndex *bolt.Bucket,
var seqNoBytes [8]byte var seqNoBytes [8]byte
byteOrder.PutUint64(seqNoBytes[:], nextAddSeqNo) byteOrder.PutUint64(seqNoBytes[:], nextAddSeqNo)
if err := addIndex.Put(seqNoBytes[:], invoiceKey[:]); err != nil { if err := addIndex.Put(seqNoBytes[:], invoiceKey[:]); err != nil {
return err return 0, err
} }
i.AddIndex = nextAddSeqNo i.AddIndex = nextAddSeqNo
@ -548,10 +555,14 @@ func putInvoice(invoices, invoiceIndex, addIndex *bolt.Bucket,
// Finally, serialize the invoice itself to be written to the disk. // Finally, serialize the invoice itself to be written to the disk.
var buf bytes.Buffer var buf bytes.Buffer
if err := serializeInvoice(&buf, i); err != nil { if err := serializeInvoice(&buf, i); err != nil {
return nil return 0, nil
} }
return invoices.Put(invoiceKey[:], buf.Bytes()) if err := invoices.Put(invoiceKey[:], buf.Bytes()); err != nil {
return 0, err
}
return nextAddSeqNo, nil
} }
func serializeInvoice(w io.Writer, i *Invoice) error { func serializeInvoice(w io.Writer, i *Invoice) error {