From ad3522f1a6cb79a9001ebe082b54cca41eabc590 Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Fri, 9 Aug 2019 13:40:34 +0200 Subject: [PATCH] channeldb+invoices: move invoice accept or settle logic into registry As the logic around invoice mutations gets more complex, the friction caused by having this logic split between invoice registry and channeldb becomes more apparent. This commit brings a clearer separation of concerns by centralizing the accept/settle logic in the invoice registry. The original AcceptOrSettle method is renamed to UpdateInvoice because the update to perform is controlled by the callback. --- channeldb/invoice_test.go | 39 +++++++------ channeldb/invoices.go | 111 ++++++++++++++++++++++++++---------- invoices/invoiceregistry.go | 43 +++++++++----- 3 files changed, 132 insertions(+), 61 deletions(-) diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index 580fe3d3..7ea19afd 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -123,9 +123,7 @@ func TestInvoiceWorkflow(t *testing.T) { // now have the settled bit toggle to true and a non-default // SettledDate payAmt := fakeInvoice.Terms.Value * 2 - _, err = db.AcceptOrSettleInvoice( - paymentHash, payAmt, checkHtlcParameters, - ) + _, err = db.UpdateInvoice(paymentHash, getUpdateInvoice(payAmt)) if err != nil { t.Fatalf("unable to settle invoice: %v", err) } @@ -288,8 +286,8 @@ func TestInvoiceAddTimeSeries(t *testing.T) { paymentHash := invoice.Terms.PaymentPreimage.Hash() - _, err := db.AcceptOrSettleInvoice( - paymentHash, 0, checkHtlcParameters, + _, err := db.UpdateInvoice( + paymentHash, getUpdateInvoice(0), ) if err != nil { t.Fatalf("unable to settle invoice: %v", err) @@ -371,8 +369,8 @@ func TestDuplicateSettleInvoice(t *testing.T) { } // With the invoice in the DB, we'll now attempt to settle the invoice. - dbInvoice, err := db.AcceptOrSettleInvoice( - payHash, amt, checkHtlcParameters, + dbInvoice, err := db.UpdateInvoice( + payHash, getUpdateInvoice(amt), ) if err != nil { t.Fatalf("unable to settle invoice: %v", err) @@ -393,8 +391,8 @@ func TestDuplicateSettleInvoice(t *testing.T) { // If we try to settle the invoice again, then we should get the very // same invoice back, but with an error this time. - dbInvoice, err = db.AcceptOrSettleInvoice( - payHash, amt, checkHtlcParameters, + dbInvoice, err = db.UpdateInvoice( + payHash, getUpdateInvoice(amt), ) if err != ErrInvoiceAlreadySettled { t.Fatalf("expected ErrInvoiceAlreadySettled") @@ -440,8 +438,8 @@ func TestQueryInvoices(t *testing.T) { // We'll only settle half of all invoices created. if i%2 == 0 { - _, err := db.AcceptOrSettleInvoice( - paymentHash, i, checkHtlcParameters, + _, err := db.UpdateInvoice( + paymentHash, getUpdateInvoice(i), ) if err != nil { t.Fatalf("unable to settle invoice: %v", err) @@ -685,10 +683,19 @@ func TestQueryInvoices(t *testing.T) { } } -func checkHtlcParameters(invoice *Invoice) error { - if invoice.Terms.State == ContractSettled { - return ErrInvoiceAlreadySettled - } +// getUpdateInvoice returns an invoice update callback that, when called, +// settles the invoice with the given amount. +func getUpdateInvoice(amt lnwire.MilliSatoshi) InvoiceUpdateCallback { + return func(invoice *Invoice) (*InvoiceUpdateDesc, error) { + if invoice.Terms.State == ContractSettled { + return nil, ErrInvoiceAlreadySettled + } - return nil + update := &InvoiceUpdateDesc{ + State: ContractSettled, + AmtPaid: amt, + } + + return update, nil + } } diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 18dbf75f..195d82ff 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -276,6 +276,20 @@ type InvoiceHTLC struct { State HtlcState } +// InvoiceUpdateDesc describes the changes that should be applied to the +// invoice. +type InvoiceUpdateDesc struct { + // State is the new state that this invoice should progress to. + State ContractState + + // AmtPaid is the updated amount that has been paid to this invoice. + AmtPaid lnwire.MilliSatoshi +} + +// InvoiceUpdateCallback is a callback used in the db transaction to update the +// invoice. +type InvoiceUpdateCallback = func(invoice *Invoice) (*InvoiceUpdateDesc, error) + func validateInvoice(i *Invoice) error { if len(i.Memo) > MaxMemoSize { return fmt.Errorf("max length a memo is %v, and invoice "+ @@ -689,21 +703,17 @@ func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) { return resp, nil } -// AcceptOrSettleInvoice attempts to mark an invoice corresponding to the passed -// payment hash as settled. If an invoice matching the passed payment hash -// doesn't existing within the database, then the action will fail with a "not -// found" error. +// UpdateInvoice attempts to update an invoice corresponding to the passed +// payment hash. If an invoice matching the passed payment hash doesn't exist +// within the database, then the action will fail with a "not found" error. // -// When the preimage for the invoice is unknown (hold invoice), the invoice is -// marked as accepted. -// -// TODO: Store invoice cltv as separate field in database so that it doesn't -// need to be decoded from the payment request. -func (d *DB) AcceptOrSettleInvoice(paymentHash [32]byte, - amtPaid lnwire.MilliSatoshi, - checkHtlcParameters func(invoice *Invoice) error) (*Invoice, error) { +// The update is performed inside the same database transaction that fetches the +// invoice and is therefore atomic. The fields to update are controlled by the +// supplied callback. +func (d *DB) UpdateInvoice(paymentHash lntypes.Hash, + callback InvoiceUpdateCallback) (*Invoice, error) { - var settledInvoice *Invoice + var updatedInvoice *Invoice err := d.Update(func(tx *bbolt.Tx) error { invoices, err := tx.CreateBucketIfNotExists(invoiceBucket) if err != nil { @@ -729,15 +739,14 @@ func (d *DB) AcceptOrSettleInvoice(paymentHash [32]byte, return ErrInvoiceNotFound } - settledInvoice, err = acceptOrSettleInvoice( - invoices, settleIndex, invoiceNum, amtPaid, - checkHtlcParameters, + updatedInvoice, err = updateInvoice( + invoices, settleIndex, invoiceNum, callback, ) return err }) - return settledInvoice, err + return updatedInvoice, err } // SettleHoldInvoice sets the preimage of a hodl invoice and marks the invoice @@ -1200,35 +1209,75 @@ func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) { return htlcs, nil } -func acceptOrSettleInvoice(invoices, settleIndex *bbolt.Bucket, - invoiceNum []byte, amtPaid lnwire.MilliSatoshi, - checkHtlcParameters func(invoice *Invoice) error) ( - *Invoice, error) { +// copySlice allocates a new slice and copies the source into it. +func copySlice(src []byte) []byte { + dest := make([]byte, len(src)) + copy(dest, src) + return dest +} + +// copyInvoice makes a deep copy of the supplied invoice. +func copyInvoice(src *Invoice) *Invoice { + dest := Invoice{ + Memo: copySlice(src.Memo), + Receipt: copySlice(src.Receipt), + PaymentRequest: copySlice(src.PaymentRequest), + FinalCltvDelta: src.FinalCltvDelta, + CreationDate: src.CreationDate, + SettleDate: src.SettleDate, + Terms: src.Terms, + AddIndex: src.AddIndex, + SettleIndex: src.SettleIndex, + AmtPaid: src.AmtPaid, + Htlcs: make( + map[CircuitKey]*InvoiceHTLC, len(src.Htlcs), + ), + } + + for k, v := range src.Htlcs { + dest.Htlcs[k] = v + } + + return &dest +} + +// updateInvoice fetches the invoice, obtains the update descriptor from the +// callback and applies the updates in a single db transaction. +func updateInvoice(invoices, settleIndex *bbolt.Bucket, invoiceNum []byte, + callback InvoiceUpdateCallback) (*Invoice, error) { invoice, err := fetchInvoice(invoiceNum, invoices) if err != nil { return nil, err } - // If the invoice is still open, check the htlc parameters. - if err := checkHtlcParameters(&invoice); err != nil { + preUpdateState := invoice.Terms.State + + // Create deep copy to prevent any accidental modification in the + // callback. + copy := copyInvoice(&invoice) + + // Call the callback and obtain the update descriptor. + update, err := callback(copy) + if err != nil { return &invoice, err } - // Check to see if we can settle or this is an hold invoice and we need - // to wait for the preimage. - holdInvoice := invoice.Terms.PaymentPreimage == UnknownPreimage - if holdInvoice { - invoice.Terms.State = ContractAccepted - } else { + // Update invoice state and amount. + invoice.Terms.State = update.State + invoice.AmtPaid = update.AmtPaid + + // If invoice moved to the settled state, update settle index and settle + // time. + if preUpdateState != invoice.Terms.State && + invoice.Terms.State == ContractSettled { + err := setSettleFields(settleIndex, invoiceNum, &invoice) if err != nil { return nil, err } } - invoice.AmtPaid = amtPaid - var buf bytes.Buffer if err := serializeInvoice(&buf, &invoice); err != nil { return nil, err diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index cfe8cb60..d2729d15 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -409,22 +409,23 @@ func (i *InvoiceRegistry) LookupInvoice(rHash lntypes.Hash) (channeldb.Invoice, return i.cdb.LookupInvoice(rHash) } -// checkHtlcParameters is a callback used inside invoice db transactions to +// updateInvoice is a callback used inside invoice db transactions to // atomically check-and-update an invoice. -func (i *InvoiceRegistry) checkHtlcParameters(invoice *channeldb.Invoice, - amtPaid lnwire.MilliSatoshi, htlcExpiry uint32, currentHeight int32) error { +func (i *InvoiceRegistry) updateInvoice(invoice *channeldb.Invoice, + amtPaid lnwire.MilliSatoshi, htlcExpiry uint32, currentHeight int32) ( + *channeldb.InvoiceUpdateDesc, error) { // If the invoice is already canceled, there is no further checking to // do. if invoice.Terms.State == channeldb.ContractCanceled { - return channeldb.ErrInvoiceAlreadyCanceled + return nil, channeldb.ErrInvoiceAlreadyCanceled } // If an invoice amount is specified, check that enough is paid. Also // check this for duplicate payments if the invoice is already settled // or accepted. if invoice.Terms.Value > 0 && amtPaid < invoice.Terms.Value { - return ErrInvoiceAmountTooLow + return nil, ErrInvoiceAmountTooLow } // Return early in case the invoice was already accepted or settled. We @@ -432,20 +433,32 @@ func (i *InvoiceRegistry) checkHtlcParameters(invoice *channeldb.Invoice, // just restarting. switch invoice.Terms.State { case channeldb.ContractAccepted: - return channeldb.ErrInvoiceAlreadyAccepted + return nil, channeldb.ErrInvoiceAlreadyAccepted case channeldb.ContractSettled: - return channeldb.ErrInvoiceAlreadySettled + return nil, channeldb.ErrInvoiceAlreadySettled } if htlcExpiry < uint32(currentHeight+i.finalCltvRejectDelta) { - return ErrInvoiceExpiryTooSoon + return nil, ErrInvoiceExpiryTooSoon } if htlcExpiry < uint32(currentHeight+invoice.FinalCltvDelta) { - return ErrInvoiceExpiryTooSoon + return nil, ErrInvoiceExpiryTooSoon } - return nil + update := channeldb.InvoiceUpdateDesc{ + AmtPaid: amtPaid, + } + + // Check to see if we can settle or this is an hold invoice and we need + // to wait for the preimage. + holdInvoice := invoice.Terms.PaymentPreimage == channeldb.UnknownPreimage + if holdInvoice { + update.State = channeldb.ContractAccepted + } else { + update.State = channeldb.ContractSettled + } + return &update, nil } // NotifyExitHopHtlc attempts to mark an invoice as settled. If the invoice is a @@ -474,10 +487,12 @@ func (i *InvoiceRegistry) NotifyExitHopHtlc(rHash lntypes.Hash, // If this isn't a debug invoice, then we'll attempt to settle an // invoice matching this rHash on disk (if one exists). - invoice, err := i.cdb.AcceptOrSettleInvoice( - rHash, amtPaid, - func(inv *channeldb.Invoice) error { - return i.checkHtlcParameters( + invoice, err := i.cdb.UpdateInvoice( + rHash, + func(inv *channeldb.Invoice) (*channeldb.InvoiceUpdateDesc, + error) { + + return i.updateInvoice( inv, amtPaid, expiry, currentHeight, ) },