channeldb+invoices: move invoice accept or settle logic into registry

As the logic around invoice mutations gets more complex, the friction
caused by having this logic split between invoice registry and channeldb
becomes more apparent. This commit brings a clearer separation of
concerns by centralizing the accept/settle logic in the invoice
registry.

The original AcceptOrSettle method is renamed to UpdateInvoice because
the update to perform is controlled by the callback.
This commit is contained in:
Joost Jager 2019-08-09 13:40:34 +02:00
parent c1345a4117
commit ad3522f1a6
No known key found for this signature in database
GPG Key ID: A61B9D4C393C59C7
3 changed files with 132 additions and 61 deletions

@ -123,9 +123,7 @@ func TestInvoiceWorkflow(t *testing.T) {
// now have the settled bit toggle to true and a non-default // now have the settled bit toggle to true and a non-default
// SettledDate // SettledDate
payAmt := fakeInvoice.Terms.Value * 2 payAmt := fakeInvoice.Terms.Value * 2
_, err = db.AcceptOrSettleInvoice( _, err = db.UpdateInvoice(paymentHash, getUpdateInvoice(payAmt))
paymentHash, payAmt, checkHtlcParameters,
)
if err != nil { if err != nil {
t.Fatalf("unable to settle invoice: %v", err) t.Fatalf("unable to settle invoice: %v", err)
} }
@ -288,8 +286,8 @@ func TestInvoiceAddTimeSeries(t *testing.T) {
paymentHash := invoice.Terms.PaymentPreimage.Hash() paymentHash := invoice.Terms.PaymentPreimage.Hash()
_, err := db.AcceptOrSettleInvoice( _, err := db.UpdateInvoice(
paymentHash, 0, checkHtlcParameters, paymentHash, getUpdateInvoice(0),
) )
if err != nil { if err != nil {
t.Fatalf("unable to settle invoice: %v", err) t.Fatalf("unable to settle invoice: %v", err)
@ -371,8 +369,8 @@ func TestDuplicateSettleInvoice(t *testing.T) {
} }
// With the invoice in the DB, we'll now attempt to settle the invoice. // With the invoice in the DB, we'll now attempt to settle the invoice.
dbInvoice, err := db.AcceptOrSettleInvoice( dbInvoice, err := db.UpdateInvoice(
payHash, amt, checkHtlcParameters, payHash, getUpdateInvoice(amt),
) )
if err != nil { if err != nil {
t.Fatalf("unable to settle invoice: %v", err) t.Fatalf("unable to settle invoice: %v", err)
@ -393,8 +391,8 @@ func TestDuplicateSettleInvoice(t *testing.T) {
// If we try to settle the invoice again, then we should get the very // If we try to settle the invoice again, then we should get the very
// same invoice back, but with an error this time. // same invoice back, but with an error this time.
dbInvoice, err = db.AcceptOrSettleInvoice( dbInvoice, err = db.UpdateInvoice(
payHash, amt, checkHtlcParameters, payHash, getUpdateInvoice(amt),
) )
if err != ErrInvoiceAlreadySettled { if err != ErrInvoiceAlreadySettled {
t.Fatalf("expected ErrInvoiceAlreadySettled") t.Fatalf("expected ErrInvoiceAlreadySettled")
@ -440,8 +438,8 @@ func TestQueryInvoices(t *testing.T) {
// We'll only settle half of all invoices created. // We'll only settle half of all invoices created.
if i%2 == 0 { if i%2 == 0 {
_, err := db.AcceptOrSettleInvoice( _, err := db.UpdateInvoice(
paymentHash, i, checkHtlcParameters, paymentHash, getUpdateInvoice(i),
) )
if err != nil { if err != nil {
t.Fatalf("unable to settle invoice: %v", err) t.Fatalf("unable to settle invoice: %v", err)
@ -685,10 +683,19 @@ func TestQueryInvoices(t *testing.T) {
} }
} }
func checkHtlcParameters(invoice *Invoice) error { // getUpdateInvoice returns an invoice update callback that, when called,
// settles the invoice with the given amount.
func getUpdateInvoice(amt lnwire.MilliSatoshi) InvoiceUpdateCallback {
return func(invoice *Invoice) (*InvoiceUpdateDesc, error) {
if invoice.Terms.State == ContractSettled { if invoice.Terms.State == ContractSettled {
return ErrInvoiceAlreadySettled return nil, ErrInvoiceAlreadySettled
} }
return nil update := &InvoiceUpdateDesc{
State: ContractSettled,
AmtPaid: amt,
}
return update, nil
}
} }

@ -276,6 +276,20 @@ type InvoiceHTLC struct {
State HtlcState State HtlcState
} }
// InvoiceUpdateDesc describes the changes that should be applied to the
// invoice.
type InvoiceUpdateDesc struct {
// State is the new state that this invoice should progress to.
State ContractState
// AmtPaid is the updated amount that has been paid to this invoice.
AmtPaid lnwire.MilliSatoshi
}
// InvoiceUpdateCallback is a callback used in the db transaction to update the
// invoice.
type InvoiceUpdateCallback = func(invoice *Invoice) (*InvoiceUpdateDesc, error)
func validateInvoice(i *Invoice) error { func validateInvoice(i *Invoice) error {
if len(i.Memo) > MaxMemoSize { if len(i.Memo) > MaxMemoSize {
return fmt.Errorf("max length a memo is %v, and invoice "+ return fmt.Errorf("max length a memo is %v, and invoice "+
@ -689,21 +703,17 @@ func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) {
return resp, nil return resp, nil
} }
// AcceptOrSettleInvoice attempts to mark an invoice corresponding to the passed // UpdateInvoice attempts to update an invoice corresponding to the passed
// payment hash as settled. If an invoice matching the passed payment hash // payment hash. If an invoice matching the passed payment hash doesn't exist
// doesn't existing within the database, then the action will fail with a "not // within the database, then the action will fail with a "not found" error.
// found" error.
// //
// When the preimage for the invoice is unknown (hold invoice), the invoice is // The update is performed inside the same database transaction that fetches the
// marked as accepted. // invoice and is therefore atomic. The fields to update are controlled by the
// // supplied callback.
// TODO: Store invoice cltv as separate field in database so that it doesn't func (d *DB) UpdateInvoice(paymentHash lntypes.Hash,
// need to be decoded from the payment request. callback InvoiceUpdateCallback) (*Invoice, error) {
func (d *DB) AcceptOrSettleInvoice(paymentHash [32]byte,
amtPaid lnwire.MilliSatoshi,
checkHtlcParameters func(invoice *Invoice) error) (*Invoice, error) {
var settledInvoice *Invoice var updatedInvoice *Invoice
err := d.Update(func(tx *bbolt.Tx) error { err := d.Update(func(tx *bbolt.Tx) error {
invoices, err := tx.CreateBucketIfNotExists(invoiceBucket) invoices, err := tx.CreateBucketIfNotExists(invoiceBucket)
if err != nil { if err != nil {
@ -729,15 +739,14 @@ func (d *DB) AcceptOrSettleInvoice(paymentHash [32]byte,
return ErrInvoiceNotFound return ErrInvoiceNotFound
} }
settledInvoice, err = acceptOrSettleInvoice( updatedInvoice, err = updateInvoice(
invoices, settleIndex, invoiceNum, amtPaid, invoices, settleIndex, invoiceNum, callback,
checkHtlcParameters,
) )
return err return err
}) })
return settledInvoice, err return updatedInvoice, err
} }
// SettleHoldInvoice sets the preimage of a hodl invoice and marks the invoice // SettleHoldInvoice sets the preimage of a hodl invoice and marks the invoice
@ -1200,35 +1209,75 @@ func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) {
return htlcs, nil return htlcs, nil
} }
func acceptOrSettleInvoice(invoices, settleIndex *bbolt.Bucket, // copySlice allocates a new slice and copies the source into it.
invoiceNum []byte, amtPaid lnwire.MilliSatoshi, func copySlice(src []byte) []byte {
checkHtlcParameters func(invoice *Invoice) error) ( dest := make([]byte, len(src))
*Invoice, error) { copy(dest, src)
return dest
}
// copyInvoice makes a deep copy of the supplied invoice.
func copyInvoice(src *Invoice) *Invoice {
dest := Invoice{
Memo: copySlice(src.Memo),
Receipt: copySlice(src.Receipt),
PaymentRequest: copySlice(src.PaymentRequest),
FinalCltvDelta: src.FinalCltvDelta,
CreationDate: src.CreationDate,
SettleDate: src.SettleDate,
Terms: src.Terms,
AddIndex: src.AddIndex,
SettleIndex: src.SettleIndex,
AmtPaid: src.AmtPaid,
Htlcs: make(
map[CircuitKey]*InvoiceHTLC, len(src.Htlcs),
),
}
for k, v := range src.Htlcs {
dest.Htlcs[k] = v
}
return &dest
}
// updateInvoice fetches the invoice, obtains the update descriptor from the
// callback and applies the updates in a single db transaction.
func updateInvoice(invoices, settleIndex *bbolt.Bucket, invoiceNum []byte,
callback InvoiceUpdateCallback) (*Invoice, error) {
invoice, err := fetchInvoice(invoiceNum, invoices) invoice, err := fetchInvoice(invoiceNum, invoices)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// If the invoice is still open, check the htlc parameters. preUpdateState := invoice.Terms.State
if err := checkHtlcParameters(&invoice); err != nil {
// Create deep copy to prevent any accidental modification in the
// callback.
copy := copyInvoice(&invoice)
// Call the callback and obtain the update descriptor.
update, err := callback(copy)
if err != nil {
return &invoice, err return &invoice, err
} }
// Check to see if we can settle or this is an hold invoice and we need // Update invoice state and amount.
// to wait for the preimage. invoice.Terms.State = update.State
holdInvoice := invoice.Terms.PaymentPreimage == UnknownPreimage invoice.AmtPaid = update.AmtPaid
if holdInvoice {
invoice.Terms.State = ContractAccepted // If invoice moved to the settled state, update settle index and settle
} else { // time.
if preUpdateState != invoice.Terms.State &&
invoice.Terms.State == ContractSettled {
err := setSettleFields(settleIndex, invoiceNum, &invoice) err := setSettleFields(settleIndex, invoiceNum, &invoice)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
invoice.AmtPaid = amtPaid
var buf bytes.Buffer var buf bytes.Buffer
if err := serializeInvoice(&buf, &invoice); err != nil { if err := serializeInvoice(&buf, &invoice); err != nil {
return nil, err return nil, err

@ -409,22 +409,23 @@ func (i *InvoiceRegistry) LookupInvoice(rHash lntypes.Hash) (channeldb.Invoice,
return i.cdb.LookupInvoice(rHash) return i.cdb.LookupInvoice(rHash)
} }
// checkHtlcParameters is a callback used inside invoice db transactions to // updateInvoice is a callback used inside invoice db transactions to
// atomically check-and-update an invoice. // atomically check-and-update an invoice.
func (i *InvoiceRegistry) checkHtlcParameters(invoice *channeldb.Invoice, func (i *InvoiceRegistry) updateInvoice(invoice *channeldb.Invoice,
amtPaid lnwire.MilliSatoshi, htlcExpiry uint32, currentHeight int32) error { amtPaid lnwire.MilliSatoshi, htlcExpiry uint32, currentHeight int32) (
*channeldb.InvoiceUpdateDesc, error) {
// If the invoice is already canceled, there is no further checking to // If the invoice is already canceled, there is no further checking to
// do. // do.
if invoice.Terms.State == channeldb.ContractCanceled { if invoice.Terms.State == channeldb.ContractCanceled {
return channeldb.ErrInvoiceAlreadyCanceled return nil, channeldb.ErrInvoiceAlreadyCanceled
} }
// If an invoice amount is specified, check that enough is paid. Also // If an invoice amount is specified, check that enough is paid. Also
// check this for duplicate payments if the invoice is already settled // check this for duplicate payments if the invoice is already settled
// or accepted. // or accepted.
if invoice.Terms.Value > 0 && amtPaid < invoice.Terms.Value { if invoice.Terms.Value > 0 && amtPaid < invoice.Terms.Value {
return ErrInvoiceAmountTooLow return nil, ErrInvoiceAmountTooLow
} }
// Return early in case the invoice was already accepted or settled. We // Return early in case the invoice was already accepted or settled. We
@ -432,20 +433,32 @@ func (i *InvoiceRegistry) checkHtlcParameters(invoice *channeldb.Invoice,
// just restarting. // just restarting.
switch invoice.Terms.State { switch invoice.Terms.State {
case channeldb.ContractAccepted: case channeldb.ContractAccepted:
return channeldb.ErrInvoiceAlreadyAccepted return nil, channeldb.ErrInvoiceAlreadyAccepted
case channeldb.ContractSettled: case channeldb.ContractSettled:
return channeldb.ErrInvoiceAlreadySettled return nil, channeldb.ErrInvoiceAlreadySettled
} }
if htlcExpiry < uint32(currentHeight+i.finalCltvRejectDelta) { if htlcExpiry < uint32(currentHeight+i.finalCltvRejectDelta) {
return ErrInvoiceExpiryTooSoon return nil, ErrInvoiceExpiryTooSoon
} }
if htlcExpiry < uint32(currentHeight+invoice.FinalCltvDelta) { if htlcExpiry < uint32(currentHeight+invoice.FinalCltvDelta) {
return ErrInvoiceExpiryTooSoon return nil, ErrInvoiceExpiryTooSoon
} }
return nil update := channeldb.InvoiceUpdateDesc{
AmtPaid: amtPaid,
}
// Check to see if we can settle or this is an hold invoice and we need
// to wait for the preimage.
holdInvoice := invoice.Terms.PaymentPreimage == channeldb.UnknownPreimage
if holdInvoice {
update.State = channeldb.ContractAccepted
} else {
update.State = channeldb.ContractSettled
}
return &update, nil
} }
// NotifyExitHopHtlc attempts to mark an invoice as settled. If the invoice is a // NotifyExitHopHtlc attempts to mark an invoice as settled. If the invoice is a
@ -474,10 +487,12 @@ func (i *InvoiceRegistry) NotifyExitHopHtlc(rHash lntypes.Hash,
// If this isn't a debug invoice, then we'll attempt to settle an // If this isn't a debug invoice, then we'll attempt to settle an
// invoice matching this rHash on disk (if one exists). // invoice matching this rHash on disk (if one exists).
invoice, err := i.cdb.AcceptOrSettleInvoice( invoice, err := i.cdb.UpdateInvoice(
rHash, amtPaid, rHash,
func(inv *channeldb.Invoice) error { func(inv *channeldb.Invoice) (*channeldb.InvoiceUpdateDesc,
return i.checkHtlcParameters( error) {
return i.updateInvoice(
inv, amtPaid, expiry, currentHeight, inv, amtPaid, expiry, currentHeight,
) )
}, },