channeldb+invoices: move invoice accept or settle logic into registry
As the logic around invoice mutations gets more complex, the friction caused by having this logic split between invoice registry and channeldb becomes more apparent. This commit brings a clearer separation of concerns by centralizing the accept/settle logic in the invoice registry. The original AcceptOrSettle method is renamed to UpdateInvoice because the update to perform is controlled by the callback.
This commit is contained in:
parent
c1345a4117
commit
ad3522f1a6
@ -123,9 +123,7 @@ func TestInvoiceWorkflow(t *testing.T) {
|
|||||||
// now have the settled bit toggle to true and a non-default
|
// now have the settled bit toggle to true and a non-default
|
||||||
// SettledDate
|
// SettledDate
|
||||||
payAmt := fakeInvoice.Terms.Value * 2
|
payAmt := fakeInvoice.Terms.Value * 2
|
||||||
_, err = db.AcceptOrSettleInvoice(
|
_, err = db.UpdateInvoice(paymentHash, getUpdateInvoice(payAmt))
|
||||||
paymentHash, payAmt, checkHtlcParameters,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to settle invoice: %v", err)
|
t.Fatalf("unable to settle invoice: %v", err)
|
||||||
}
|
}
|
||||||
@ -288,8 +286,8 @@ func TestInvoiceAddTimeSeries(t *testing.T) {
|
|||||||
|
|
||||||
paymentHash := invoice.Terms.PaymentPreimage.Hash()
|
paymentHash := invoice.Terms.PaymentPreimage.Hash()
|
||||||
|
|
||||||
_, err := db.AcceptOrSettleInvoice(
|
_, err := db.UpdateInvoice(
|
||||||
paymentHash, 0, checkHtlcParameters,
|
paymentHash, getUpdateInvoice(0),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to settle invoice: %v", err)
|
t.Fatalf("unable to settle invoice: %v", err)
|
||||||
@ -371,8 +369,8 @@ func TestDuplicateSettleInvoice(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// With the invoice in the DB, we'll now attempt to settle the invoice.
|
// With the invoice in the DB, we'll now attempt to settle the invoice.
|
||||||
dbInvoice, err := db.AcceptOrSettleInvoice(
|
dbInvoice, err := db.UpdateInvoice(
|
||||||
payHash, amt, checkHtlcParameters,
|
payHash, getUpdateInvoice(amt),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to settle invoice: %v", err)
|
t.Fatalf("unable to settle invoice: %v", err)
|
||||||
@ -393,8 +391,8 @@ func TestDuplicateSettleInvoice(t *testing.T) {
|
|||||||
|
|
||||||
// If we try to settle the invoice again, then we should get the very
|
// If we try to settle the invoice again, then we should get the very
|
||||||
// same invoice back, but with an error this time.
|
// same invoice back, but with an error this time.
|
||||||
dbInvoice, err = db.AcceptOrSettleInvoice(
|
dbInvoice, err = db.UpdateInvoice(
|
||||||
payHash, amt, checkHtlcParameters,
|
payHash, getUpdateInvoice(amt),
|
||||||
)
|
)
|
||||||
if err != ErrInvoiceAlreadySettled {
|
if err != ErrInvoiceAlreadySettled {
|
||||||
t.Fatalf("expected ErrInvoiceAlreadySettled")
|
t.Fatalf("expected ErrInvoiceAlreadySettled")
|
||||||
@ -440,8 +438,8 @@ func TestQueryInvoices(t *testing.T) {
|
|||||||
|
|
||||||
// We'll only settle half of all invoices created.
|
// We'll only settle half of all invoices created.
|
||||||
if i%2 == 0 {
|
if i%2 == 0 {
|
||||||
_, err := db.AcceptOrSettleInvoice(
|
_, err := db.UpdateInvoice(
|
||||||
paymentHash, i, checkHtlcParameters,
|
paymentHash, getUpdateInvoice(i),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to settle invoice: %v", err)
|
t.Fatalf("unable to settle invoice: %v", err)
|
||||||
@ -685,10 +683,19 @@ func TestQueryInvoices(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkHtlcParameters(invoice *Invoice) error {
|
// getUpdateInvoice returns an invoice update callback that, when called,
|
||||||
if invoice.Terms.State == ContractSettled {
|
// settles the invoice with the given amount.
|
||||||
return ErrInvoiceAlreadySettled
|
func getUpdateInvoice(amt lnwire.MilliSatoshi) InvoiceUpdateCallback {
|
||||||
}
|
return func(invoice *Invoice) (*InvoiceUpdateDesc, error) {
|
||||||
|
if invoice.Terms.State == ContractSettled {
|
||||||
|
return nil, ErrInvoiceAlreadySettled
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
update := &InvoiceUpdateDesc{
|
||||||
|
State: ContractSettled,
|
||||||
|
AmtPaid: amt,
|
||||||
|
}
|
||||||
|
|
||||||
|
return update, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -276,6 +276,20 @@ type InvoiceHTLC struct {
|
|||||||
State HtlcState
|
State HtlcState
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InvoiceUpdateDesc describes the changes that should be applied to the
|
||||||
|
// invoice.
|
||||||
|
type InvoiceUpdateDesc struct {
|
||||||
|
// State is the new state that this invoice should progress to.
|
||||||
|
State ContractState
|
||||||
|
|
||||||
|
// AmtPaid is the updated amount that has been paid to this invoice.
|
||||||
|
AmtPaid lnwire.MilliSatoshi
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvoiceUpdateCallback is a callback used in the db transaction to update the
|
||||||
|
// invoice.
|
||||||
|
type InvoiceUpdateCallback = func(invoice *Invoice) (*InvoiceUpdateDesc, error)
|
||||||
|
|
||||||
func validateInvoice(i *Invoice) error {
|
func validateInvoice(i *Invoice) error {
|
||||||
if len(i.Memo) > MaxMemoSize {
|
if len(i.Memo) > MaxMemoSize {
|
||||||
return fmt.Errorf("max length a memo is %v, and invoice "+
|
return fmt.Errorf("max length a memo is %v, and invoice "+
|
||||||
@ -689,21 +703,17 @@ func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) {
|
|||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AcceptOrSettleInvoice attempts to mark an invoice corresponding to the passed
|
// UpdateInvoice attempts to update an invoice corresponding to the passed
|
||||||
// payment hash as settled. If an invoice matching the passed payment hash
|
// payment hash. If an invoice matching the passed payment hash doesn't exist
|
||||||
// doesn't existing within the database, then the action will fail with a "not
|
// within the database, then the action will fail with a "not found" error.
|
||||||
// found" error.
|
|
||||||
//
|
//
|
||||||
// When the preimage for the invoice is unknown (hold invoice), the invoice is
|
// The update is performed inside the same database transaction that fetches the
|
||||||
// marked as accepted.
|
// invoice and is therefore atomic. The fields to update are controlled by the
|
||||||
//
|
// supplied callback.
|
||||||
// TODO: Store invoice cltv as separate field in database so that it doesn't
|
func (d *DB) UpdateInvoice(paymentHash lntypes.Hash,
|
||||||
// need to be decoded from the payment request.
|
callback InvoiceUpdateCallback) (*Invoice, error) {
|
||||||
func (d *DB) AcceptOrSettleInvoice(paymentHash [32]byte,
|
|
||||||
amtPaid lnwire.MilliSatoshi,
|
|
||||||
checkHtlcParameters func(invoice *Invoice) error) (*Invoice, error) {
|
|
||||||
|
|
||||||
var settledInvoice *Invoice
|
var updatedInvoice *Invoice
|
||||||
err := d.Update(func(tx *bbolt.Tx) error {
|
err := d.Update(func(tx *bbolt.Tx) error {
|
||||||
invoices, err := tx.CreateBucketIfNotExists(invoiceBucket)
|
invoices, err := tx.CreateBucketIfNotExists(invoiceBucket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -729,15 +739,14 @@ func (d *DB) AcceptOrSettleInvoice(paymentHash [32]byte,
|
|||||||
return ErrInvoiceNotFound
|
return ErrInvoiceNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
settledInvoice, err = acceptOrSettleInvoice(
|
updatedInvoice, err = updateInvoice(
|
||||||
invoices, settleIndex, invoiceNum, amtPaid,
|
invoices, settleIndex, invoiceNum, callback,
|
||||||
checkHtlcParameters,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
|
|
||||||
return settledInvoice, err
|
return updatedInvoice, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// SettleHoldInvoice sets the preimage of a hodl invoice and marks the invoice
|
// SettleHoldInvoice sets the preimage of a hodl invoice and marks the invoice
|
||||||
@ -1200,35 +1209,75 @@ func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) {
|
|||||||
return htlcs, nil
|
return htlcs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func acceptOrSettleInvoice(invoices, settleIndex *bbolt.Bucket,
|
// copySlice allocates a new slice and copies the source into it.
|
||||||
invoiceNum []byte, amtPaid lnwire.MilliSatoshi,
|
func copySlice(src []byte) []byte {
|
||||||
checkHtlcParameters func(invoice *Invoice) error) (
|
dest := make([]byte, len(src))
|
||||||
*Invoice, error) {
|
copy(dest, src)
|
||||||
|
return dest
|
||||||
|
}
|
||||||
|
|
||||||
|
// copyInvoice makes a deep copy of the supplied invoice.
|
||||||
|
func copyInvoice(src *Invoice) *Invoice {
|
||||||
|
dest := Invoice{
|
||||||
|
Memo: copySlice(src.Memo),
|
||||||
|
Receipt: copySlice(src.Receipt),
|
||||||
|
PaymentRequest: copySlice(src.PaymentRequest),
|
||||||
|
FinalCltvDelta: src.FinalCltvDelta,
|
||||||
|
CreationDate: src.CreationDate,
|
||||||
|
SettleDate: src.SettleDate,
|
||||||
|
Terms: src.Terms,
|
||||||
|
AddIndex: src.AddIndex,
|
||||||
|
SettleIndex: src.SettleIndex,
|
||||||
|
AmtPaid: src.AmtPaid,
|
||||||
|
Htlcs: make(
|
||||||
|
map[CircuitKey]*InvoiceHTLC, len(src.Htlcs),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range src.Htlcs {
|
||||||
|
dest.Htlcs[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
return &dest
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateInvoice fetches the invoice, obtains the update descriptor from the
|
||||||
|
// callback and applies the updates in a single db transaction.
|
||||||
|
func updateInvoice(invoices, settleIndex *bbolt.Bucket, invoiceNum []byte,
|
||||||
|
callback InvoiceUpdateCallback) (*Invoice, error) {
|
||||||
|
|
||||||
invoice, err := fetchInvoice(invoiceNum, invoices)
|
invoice, err := fetchInvoice(invoiceNum, invoices)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the invoice is still open, check the htlc parameters.
|
preUpdateState := invoice.Terms.State
|
||||||
if err := checkHtlcParameters(&invoice); err != nil {
|
|
||||||
|
// Create deep copy to prevent any accidental modification in the
|
||||||
|
// callback.
|
||||||
|
copy := copyInvoice(&invoice)
|
||||||
|
|
||||||
|
// Call the callback and obtain the update descriptor.
|
||||||
|
update, err := callback(copy)
|
||||||
|
if err != nil {
|
||||||
return &invoice, err
|
return &invoice, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check to see if we can settle or this is an hold invoice and we need
|
// Update invoice state and amount.
|
||||||
// to wait for the preimage.
|
invoice.Terms.State = update.State
|
||||||
holdInvoice := invoice.Terms.PaymentPreimage == UnknownPreimage
|
invoice.AmtPaid = update.AmtPaid
|
||||||
if holdInvoice {
|
|
||||||
invoice.Terms.State = ContractAccepted
|
// If invoice moved to the settled state, update settle index and settle
|
||||||
} else {
|
// time.
|
||||||
|
if preUpdateState != invoice.Terms.State &&
|
||||||
|
invoice.Terms.State == ContractSettled {
|
||||||
|
|
||||||
err := setSettleFields(settleIndex, invoiceNum, &invoice)
|
err := setSettleFields(settleIndex, invoiceNum, &invoice)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
invoice.AmtPaid = amtPaid
|
|
||||||
|
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
if err := serializeInvoice(&buf, &invoice); err != nil {
|
if err := serializeInvoice(&buf, &invoice); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -409,22 +409,23 @@ func (i *InvoiceRegistry) LookupInvoice(rHash lntypes.Hash) (channeldb.Invoice,
|
|||||||
return i.cdb.LookupInvoice(rHash)
|
return i.cdb.LookupInvoice(rHash)
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkHtlcParameters is a callback used inside invoice db transactions to
|
// updateInvoice is a callback used inside invoice db transactions to
|
||||||
// atomically check-and-update an invoice.
|
// atomically check-and-update an invoice.
|
||||||
func (i *InvoiceRegistry) checkHtlcParameters(invoice *channeldb.Invoice,
|
func (i *InvoiceRegistry) updateInvoice(invoice *channeldb.Invoice,
|
||||||
amtPaid lnwire.MilliSatoshi, htlcExpiry uint32, currentHeight int32) error {
|
amtPaid lnwire.MilliSatoshi, htlcExpiry uint32, currentHeight int32) (
|
||||||
|
*channeldb.InvoiceUpdateDesc, error) {
|
||||||
|
|
||||||
// If the invoice is already canceled, there is no further checking to
|
// If the invoice is already canceled, there is no further checking to
|
||||||
// do.
|
// do.
|
||||||
if invoice.Terms.State == channeldb.ContractCanceled {
|
if invoice.Terms.State == channeldb.ContractCanceled {
|
||||||
return channeldb.ErrInvoiceAlreadyCanceled
|
return nil, channeldb.ErrInvoiceAlreadyCanceled
|
||||||
}
|
}
|
||||||
|
|
||||||
// If an invoice amount is specified, check that enough is paid. Also
|
// If an invoice amount is specified, check that enough is paid. Also
|
||||||
// check this for duplicate payments if the invoice is already settled
|
// check this for duplicate payments if the invoice is already settled
|
||||||
// or accepted.
|
// or accepted.
|
||||||
if invoice.Terms.Value > 0 && amtPaid < invoice.Terms.Value {
|
if invoice.Terms.Value > 0 && amtPaid < invoice.Terms.Value {
|
||||||
return ErrInvoiceAmountTooLow
|
return nil, ErrInvoiceAmountTooLow
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return early in case the invoice was already accepted or settled. We
|
// Return early in case the invoice was already accepted or settled. We
|
||||||
@ -432,20 +433,32 @@ func (i *InvoiceRegistry) checkHtlcParameters(invoice *channeldb.Invoice,
|
|||||||
// just restarting.
|
// just restarting.
|
||||||
switch invoice.Terms.State {
|
switch invoice.Terms.State {
|
||||||
case channeldb.ContractAccepted:
|
case channeldb.ContractAccepted:
|
||||||
return channeldb.ErrInvoiceAlreadyAccepted
|
return nil, channeldb.ErrInvoiceAlreadyAccepted
|
||||||
case channeldb.ContractSettled:
|
case channeldb.ContractSettled:
|
||||||
return channeldb.ErrInvoiceAlreadySettled
|
return nil, channeldb.ErrInvoiceAlreadySettled
|
||||||
}
|
}
|
||||||
|
|
||||||
if htlcExpiry < uint32(currentHeight+i.finalCltvRejectDelta) {
|
if htlcExpiry < uint32(currentHeight+i.finalCltvRejectDelta) {
|
||||||
return ErrInvoiceExpiryTooSoon
|
return nil, ErrInvoiceExpiryTooSoon
|
||||||
}
|
}
|
||||||
|
|
||||||
if htlcExpiry < uint32(currentHeight+invoice.FinalCltvDelta) {
|
if htlcExpiry < uint32(currentHeight+invoice.FinalCltvDelta) {
|
||||||
return ErrInvoiceExpiryTooSoon
|
return nil, ErrInvoiceExpiryTooSoon
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
update := channeldb.InvoiceUpdateDesc{
|
||||||
|
AmtPaid: amtPaid,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check to see if we can settle or this is an hold invoice and we need
|
||||||
|
// to wait for the preimage.
|
||||||
|
holdInvoice := invoice.Terms.PaymentPreimage == channeldb.UnknownPreimage
|
||||||
|
if holdInvoice {
|
||||||
|
update.State = channeldb.ContractAccepted
|
||||||
|
} else {
|
||||||
|
update.State = channeldb.ContractSettled
|
||||||
|
}
|
||||||
|
return &update, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NotifyExitHopHtlc attempts to mark an invoice as settled. If the invoice is a
|
// NotifyExitHopHtlc attempts to mark an invoice as settled. If the invoice is a
|
||||||
@ -474,10 +487,12 @@ func (i *InvoiceRegistry) NotifyExitHopHtlc(rHash lntypes.Hash,
|
|||||||
|
|
||||||
// If this isn't a debug invoice, then we'll attempt to settle an
|
// If this isn't a debug invoice, then we'll attempt to settle an
|
||||||
// invoice matching this rHash on disk (if one exists).
|
// invoice matching this rHash on disk (if one exists).
|
||||||
invoice, err := i.cdb.AcceptOrSettleInvoice(
|
invoice, err := i.cdb.UpdateInvoice(
|
||||||
rHash, amtPaid,
|
rHash,
|
||||||
func(inv *channeldb.Invoice) error {
|
func(inv *channeldb.Invoice) (*channeldb.InvoiceUpdateDesc,
|
||||||
return i.checkHtlcParameters(
|
error) {
|
||||||
|
|
||||||
|
return i.updateInvoice(
|
||||||
inv, amtPaid, expiry, currentHeight,
|
inv, amtPaid, expiry, currentHeight,
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
|
Loading…
Reference in New Issue
Block a user