channeldb+invoices: track invoices by InvoiceRef

This commit is contained in:
Conner Fromknecht 2020-05-21 15:37:10 -07:00
parent 2799202fd9
commit 3522f09a08
No known key found for this signature in database
GPG Key ID: E7D737B67FA592C7
5 changed files with 151 additions and 76 deletions

@ -11,6 +11,7 @@ import (
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/record"
"github.com/stretchr/testify/assert"
) )
var ( var (
@ -112,6 +113,7 @@ func TestInvoiceWorkflow(t *testing.T) {
fakeInvoice.Terms.Features = emptyFeatures fakeInvoice.Terms.Features = emptyFeatures
paymentHash := fakeInvoice.Terms.PaymentPreimage.Hash() paymentHash := fakeInvoice.Terms.PaymentPreimage.Hash()
ref := InvoiceRefByHash(paymentHash)
// Add the invoice to the database, this should succeed as there aren't // Add the invoice to the database, this should succeed as there aren't
// any existing invoices within the database with the same payment // any existing invoices within the database with the same payment
@ -123,7 +125,7 @@ func TestInvoiceWorkflow(t *testing.T) {
// Attempt to retrieve the invoice which was just added to the // Attempt to retrieve the invoice which was just added to the
// database. It should be found, and the invoice returned should be // database. It should be found, and the invoice returned should be
// identical to the one created above. // identical to the one created above.
dbInvoice, err := db.LookupInvoice(paymentHash) dbInvoice, err := db.LookupInvoice(ref)
if err != nil { if err != nil {
t.Fatalf("unable to find invoice: %v", err) t.Fatalf("unable to find invoice: %v", err)
} }
@ -144,11 +146,11 @@ func TestInvoiceWorkflow(t *testing.T) {
// now have the settled bit toggle to true and a non-default // now have the settled bit toggle to true and a non-default
// SettledDate // SettledDate
payAmt := fakeInvoice.Terms.Value * 2 payAmt := fakeInvoice.Terms.Value * 2
_, err = db.UpdateInvoice(paymentHash, getUpdateInvoice(payAmt)) _, err = db.UpdateInvoice(ref, getUpdateInvoice(payAmt))
if err != nil { if err != nil {
t.Fatalf("unable to settle invoice: %v", err) t.Fatalf("unable to settle invoice: %v", err)
} }
dbInvoice2, err := db.LookupInvoice(paymentHash) dbInvoice2, err := db.LookupInvoice(ref)
if err != nil { if err != nil {
t.Fatalf("unable to fetch invoice: %v", err) t.Fatalf("unable to fetch invoice: %v", err)
} }
@ -180,7 +182,9 @@ func TestInvoiceWorkflow(t *testing.T) {
// Attempt to look up a non-existent invoice, this should also fail but // Attempt to look up a non-existent invoice, this should also fail but
// with a "not found" error. // with a "not found" error.
var fakeHash [32]byte var fakeHash [32]byte
if _, err := db.LookupInvoice(fakeHash); err != ErrInvoiceNotFound { fakeRef := InvoiceRefByHash(fakeHash)
_, err = db.LookupInvoice(fakeRef)
if err != ErrInvoiceNotFound {
t.Fatalf("lookup should have failed, instead %v", err) t.Fatalf("lookup should have failed, instead %v", err)
} }
@ -256,7 +260,9 @@ func TestInvoiceCancelSingleHtlc(t *testing.T) {
Amt: 500, Amt: 500,
CustomRecords: make(record.CustomSet), CustomRecords: make(record.CustomSet),
} }
invoice, err := db.UpdateInvoice(paymentHash,
ref := InvoiceRefByHash(paymentHash)
invoice, err := db.UpdateInvoice(ref,
func(invoice *Invoice) (*InvoiceUpdateDesc, error) { func(invoice *Invoice) (*InvoiceUpdateDesc, error) {
return &InvoiceUpdateDesc{ return &InvoiceUpdateDesc{
AddHtlcs: map[CircuitKey]*HtlcAcceptDesc{ AddHtlcs: map[CircuitKey]*HtlcAcceptDesc{
@ -275,13 +281,14 @@ func TestInvoiceCancelSingleHtlc(t *testing.T) {
} }
// Cancel the htlc again. // Cancel the htlc again.
invoice, err = db.UpdateInvoice(paymentHash, func(invoice *Invoice) (*InvoiceUpdateDesc, error) { invoice, err = db.UpdateInvoice(ref,
return &InvoiceUpdateDesc{ func(invoice *Invoice) (*InvoiceUpdateDesc, error) {
CancelHtlcs: map[CircuitKey]struct{}{ return &InvoiceUpdateDesc{
key: {}, CancelHtlcs: map[CircuitKey]struct{}{
}, key: {},
}, nil },
}) }, nil
})
if err != nil { if err != nil {
t.Fatalf("unable to cancel htlc: %v", err) t.Fatalf("unable to cancel htlc: %v", err)
} }
@ -380,8 +387,9 @@ func TestInvoiceAddTimeSeries(t *testing.T) {
paymentHash := invoice.Terms.PaymentPreimage.Hash() paymentHash := invoice.Terms.PaymentPreimage.Hash()
ref := InvoiceRefByHash(paymentHash)
_, err := db.UpdateInvoice( _, err := db.UpdateInvoice(
paymentHash, getUpdateInvoice(invoice.Terms.Value), ref, getUpdateInvoice(invoice.Terms.Value),
) )
if err != nil { if err != nil {
t.Fatalf("unable to settle invoice: %v", err) t.Fatalf("unable to settle invoice: %v", err)
@ -570,9 +578,8 @@ func TestDuplicateSettleInvoice(t *testing.T) {
} }
// With the invoice in the DB, we'll now attempt to settle the invoice. // With the invoice in the DB, we'll now attempt to settle the invoice.
dbInvoice, err := db.UpdateInvoice( ref := InvoiceRefByHash(payHash)
payHash, getUpdateInvoice(amt), dbInvoice, err := db.UpdateInvoice(ref, getUpdateInvoice(amt))
)
if err != nil { if err != nil {
t.Fatalf("unable to settle invoice: %v", err) t.Fatalf("unable to settle invoice: %v", err)
} }
@ -601,9 +608,7 @@ func TestDuplicateSettleInvoice(t *testing.T) {
// If we try to settle the invoice again, then we should get the very // If we try to settle the invoice again, then we should get the very
// same invoice back, but with an error this time. // same invoice back, but with an error this time.
dbInvoice, err = db.UpdateInvoice( dbInvoice, err = db.UpdateInvoice(ref, getUpdateInvoice(amt))
payHash, getUpdateInvoice(amt),
)
if err != ErrInvoiceAlreadySettled { if err != ErrInvoiceAlreadySettled {
t.Fatalf("expected ErrInvoiceAlreadySettled") t.Fatalf("expected ErrInvoiceAlreadySettled")
} }
@ -653,9 +658,8 @@ func TestQueryInvoices(t *testing.T) {
// We'll only settle half of all invoices created. // We'll only settle half of all invoices created.
if i%2 == 0 { if i%2 == 0 {
_, err := db.UpdateInvoice( ref := InvoiceRefByHash(paymentHash)
paymentHash, getUpdateInvoice(amt), _, err := db.UpdateInvoice(ref, getUpdateInvoice(amt))
)
if err != nil { if err != nil {
t.Fatalf("unable to settle invoice: %v", err) t.Fatalf("unable to settle invoice: %v", err)
} }
@ -951,7 +955,8 @@ func TestCustomRecords(t *testing.T) {
100001: []byte{1, 2}, 100001: []byte{1, 2},
} }
_, err = db.UpdateInvoice(paymentHash, ref := InvoiceRefByHash(paymentHash)
_, err = db.UpdateInvoice(ref,
func(invoice *Invoice) (*InvoiceUpdateDesc, error) { func(invoice *Invoice) (*InvoiceUpdateDesc, error) {
return &InvoiceUpdateDesc{ return &InvoiceUpdateDesc{
AddHtlcs: map[CircuitKey]*HtlcAcceptDesc{ AddHtlcs: map[CircuitKey]*HtlcAcceptDesc{
@ -969,7 +974,7 @@ func TestCustomRecords(t *testing.T) {
// Retrieve the invoice from that database and verify that the custom // Retrieve the invoice from that database and verify that the custom
// records are present. // records are present.
dbInvoice, err := db.LookupInvoice(paymentHash) dbInvoice, err := db.LookupInvoice(ref)
if err != nil { if err != nil {
t.Fatalf("unable to lookup invoice: %v", err) t.Fatalf("unable to lookup invoice: %v", err)
} }
@ -981,3 +986,14 @@ func TestCustomRecords(t *testing.T) {
t.Fatalf("invalid custom records") t.Fatalf("invalid custom records")
} }
} }
// TestInvoiceRef asserts that the proper identifiers are returned from an
// InvoiceRef depending on the constructor used.
func TestInvoiceRef(t *testing.T) {
payHash := lntypes.Hash{0x01}
// An InvoiceRef by hash should return the provided hash and a nil
// payment addr.
refByHash := InvoiceRefByHash(payHash)
assert.Equal(t, payHash, refByHash.PayHash())
}

@ -142,6 +142,32 @@ const (
amtPaidType tlv.Type = 13 amtPaidType tlv.Type = 13
) )
// InvoiceRef is an identifier for invoices supporting queries by payment hash.
type InvoiceRef struct {
// payHash is the payment hash of the target invoice. All invoices are
// currently indexed by payment hash. This value will be used as a
// fallback when no payment address is known.
payHash lntypes.Hash
}
// InvoiceRefByHash creates an InvoiceRef that queries for an invoice only by
// its payment hash.
func InvoiceRefByHash(payHash lntypes.Hash) InvoiceRef {
return InvoiceRef{
payHash: payHash,
}
}
// PayHash returns the target invoice's payment hash.
func (r InvoiceRef) PayHash() lntypes.Hash {
return r.payHash
}
// String returns a human-readable representation of an InvoiceRef.
func (r InvoiceRef) String() string {
return fmt.Sprintf("(pay_hash=%v)", r.payHash)
}
// ContractState describes the state the invoice is in. // ContractState describes the state the invoice is in.
type ContractState uint8 type ContractState uint8
@ -538,7 +564,7 @@ func (d *DB) InvoicesAddedSince(sinceAddIndex uint64) ([]Invoice, error) {
// full invoice is returned. Before setting the incoming HTLC, the values // full invoice is returned. Before setting the incoming HTLC, the values
// SHOULD be checked to ensure the payer meets the agreed upon contractual // SHOULD be checked to ensure the payer meets the agreed upon contractual
// terms of the payment. // terms of the payment.
func (d *DB) LookupInvoice(paymentHash [32]byte) (Invoice, error) { func (d *DB) LookupInvoice(ref InvoiceRef) (Invoice, error) {
var invoice Invoice var invoice Invoice
err := kvdb.View(d, func(tx kvdb.ReadTx) error { err := kvdb.View(d, func(tx kvdb.ReadTx) error {
invoices := tx.ReadBucket(invoiceBucket) invoices := tx.ReadBucket(invoiceBucket)
@ -550,15 +576,17 @@ func (d *DB) LookupInvoice(paymentHash [32]byte) (Invoice, error) {
return ErrNoInvoicesCreated return ErrNoInvoicesCreated
} }
// Check the invoice index to see if an invoice paying to this // Retrieve the invoice number for this invoice using the
// hash exists within the DB. // provided invoice reference.
invoiceNum := invoiceIndex.Get(paymentHash[:]) invoiceNum, err := fetchInvoiceNumByRef(
if invoiceNum == nil { invoiceIndex, ref,
return ErrInvoiceNotFound )
if err != nil {
return err
} }
// An invoice matching the payment hash has been found, so // An invoice was found, retrieve the remainder of the invoice
// retrieve the record of the invoice itself. // body.
i, err := fetchInvoice(invoiceNum, invoices) i, err := fetchInvoice(invoiceNum, invoices)
if err != nil { if err != nil {
return err return err
@ -574,6 +602,21 @@ func (d *DB) LookupInvoice(paymentHash [32]byte) (Invoice, error) {
return invoice, nil return invoice, nil
} }
// fetchInvoiceNumByRef retrieve the invoice number for the provided invoice
// reference.
func fetchInvoiceNumByRef(invoiceIndex kvdb.ReadBucket,
ref InvoiceRef) ([]byte, error) {
payHash := ref.PayHash()
invoiceNum := invoiceIndex.Get(payHash[:])
if invoiceNum == nil {
return nil, ErrInvoiceNotFound
}
return invoiceNum, nil
}
// InvoiceWithPaymentHash is used to store an invoice and its corresponding // InvoiceWithPaymentHash is used to store an invoice and its corresponding
// payment hash. This struct is only used to store results of // payment hash. This struct is only used to store results of
// ChannelDB.FetchAllInvoicesWithPaymentHash() call. // ChannelDB.FetchAllInvoicesWithPaymentHash() call.
@ -824,7 +867,7 @@ func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) {
// The update is performed inside the same database transaction that fetches the // The update is performed inside the same database transaction that fetches the
// invoice and is therefore atomic. The fields to update are controlled by the // invoice and is therefore atomic. The fields to update are controlled by the
// supplied callback. // supplied callback.
func (d *DB) UpdateInvoice(paymentHash lntypes.Hash, func (d *DB) UpdateInvoice(ref InvoiceRef,
callback InvoiceUpdateCallback) (*Invoice, error) { callback InvoiceUpdateCallback) (*Invoice, error) {
var updatedInvoice *Invoice var updatedInvoice *Invoice
@ -846,15 +889,18 @@ func (d *DB) UpdateInvoice(paymentHash lntypes.Hash,
return err return err
} }
// Check the invoice index to see if an invoice paying to this // Retrieve the invoice number for this invoice using the
// hash exists within the DB. // provided invoice reference.
invoiceNum := invoiceIndex.Get(paymentHash[:]) invoiceNum, err := fetchInvoiceNumByRef(
if invoiceNum == nil { invoiceIndex, ref,
return ErrInvoiceNotFound )
} if err != nil {
return err
}
payHash := ref.PayHash()
updatedInvoice, err = d.updateInvoice( updatedInvoice, err = d.updateInvoice(
paymentHash, invoices, settleIndex, invoiceNum, payHash, invoices, settleIndex, invoiceNum,
callback, callback,
) )

@ -61,8 +61,8 @@ type RegistryConfig struct {
// htlcReleaseEvent describes an htlc auto-release event. It is used to release // htlcReleaseEvent describes an htlc auto-release event. It is used to release
// mpp htlcs for which the complete set didn't arrive in time. // mpp htlcs for which the complete set didn't arrive in time.
type htlcReleaseEvent struct { type htlcReleaseEvent struct {
// hash is the payment hash of the htlc to release. // invoiceRef identifiers the invoice this htlc belongs to.
hash lntypes.Hash invoiceRef channeldb.InvoiceRef
// key is the circuit key of the htlc to release. // key is the circuit key of the htlc to release.
key channeldb.CircuitKey key channeldb.CircuitKey
@ -289,7 +289,8 @@ func (i *InvoiceRegistry) invoiceEventLoop() {
// the subscriber. // the subscriber.
case *SingleInvoiceSubscription: case *SingleInvoiceSubscription:
log.Infof("New single invoice subscription "+ log.Infof("New single invoice subscription "+
"client: id=%v, hash=%v", e.id, e.hash) "client: id=%v, ref=%v", e.id,
e.invoiceRef)
i.singleNotificationClients[e.id] = e i.singleNotificationClients[e.id] = e
} }
@ -297,8 +298,8 @@ func (i *InvoiceRegistry) invoiceEventLoop() {
// A new htlc came in for auto-release. // A new htlc came in for auto-release.
case event := <-i.htlcAutoReleaseChan: case event := <-i.htlcAutoReleaseChan:
log.Debugf("Scheduling auto-release for htlc: "+ log.Debugf("Scheduling auto-release for htlc: "+
"hash=%v, key=%v at %v", "ref=%v, key=%v at %v",
event.hash, event.key, event.releaseTime) event.invoiceRef, event.key, event.releaseTime)
// We use an independent timer for every htlc rather // We use an independent timer for every htlc rather
// than a set timer that is reset with every htlc coming // than a set timer that is reset with every htlc coming
@ -311,7 +312,7 @@ func (i *InvoiceRegistry) invoiceEventLoop() {
case <-nextReleaseTick: case <-nextReleaseTick:
event := autoReleaseHeap.Pop().(*htlcReleaseEvent) event := autoReleaseHeap.Pop().(*htlcReleaseEvent)
err := i.cancelSingleHtlc( err := i.cancelSingleHtlc(
event.hash, event.key, ResultMppTimeout, event.invoiceRef, event.key, ResultMppTimeout,
) )
if err != nil { if err != nil {
log.Errorf("HTLC timer: %v", err) log.Errorf("HTLC timer: %v", err)
@ -328,7 +329,7 @@ func (i *InvoiceRegistry) invoiceEventLoop() {
func (i *InvoiceRegistry) dispatchToSingleClients(event *invoiceEvent) { func (i *InvoiceRegistry) dispatchToSingleClients(event *invoiceEvent) {
// Dispatch to single invoice subscribers. // Dispatch to single invoice subscribers.
for _, client := range i.singleNotificationClients { for _, client := range i.singleNotificationClients {
if client.hash != event.hash { if client.invoiceRef.PayHash() != event.hash {
continue continue
} }
@ -465,7 +466,7 @@ func (i *InvoiceRegistry) deliverBacklogEvents(client *InvoiceSubscription) erro
func (i *InvoiceRegistry) deliverSingleBacklogEvents( func (i *InvoiceRegistry) deliverSingleBacklogEvents(
client *SingleInvoiceSubscription) error { client *SingleInvoiceSubscription) error {
invoice, err := i.cdb.LookupInvoice(client.hash) invoice, err := i.cdb.LookupInvoice(client.invoiceRef)
// It is possible that the invoice does not exist yet, but the client is // It is possible that the invoice does not exist yet, but the client is
// already watching it in anticipation. // already watching it in anticipation.
@ -479,7 +480,7 @@ func (i *InvoiceRegistry) deliverSingleBacklogEvents(
} }
err = client.notify(&invoiceEvent{ err = client.notify(&invoiceEvent{
hash: client.hash, hash: client.invoiceRef.PayHash(),
invoice: &invoice, invoice: &invoice,
}) })
if err != nil { if err != nil {
@ -502,8 +503,8 @@ func (i *InvoiceRegistry) AddInvoice(invoice *channeldb.Invoice,
i.Lock() i.Lock()
log.Debugf("Invoice(%v): added with terms %v", paymentHash, ref := channeldb.InvoiceRefByHash(paymentHash)
invoice.Terms) log.Debugf("Invoice%v: added with terms %v", ref, invoice.Terms)
addIndex, err := i.cdb.AddInvoice(invoice, paymentHash) addIndex, err := i.cdb.AddInvoice(invoice, paymentHash)
if err != nil { if err != nil {
@ -533,17 +534,18 @@ func (i *InvoiceRegistry) LookupInvoice(rHash lntypes.Hash) (channeldb.Invoice,
// We'll check the database to see if there's an existing matching // We'll check the database to see if there's an existing matching
// invoice. // invoice.
return i.cdb.LookupInvoice(rHash) ref := channeldb.InvoiceRefByHash(rHash)
return i.cdb.LookupInvoice(ref)
} }
// startHtlcTimer starts a new timer via the invoice registry main loop that // startHtlcTimer starts a new timer via the invoice registry main loop that
// cancels a single htlc on an invoice when the htlc hold duration has passed. // cancels a single htlc on an invoice when the htlc hold duration has passed.
func (i *InvoiceRegistry) startHtlcTimer(hash lntypes.Hash, func (i *InvoiceRegistry) startHtlcTimer(invoiceRef channeldb.InvoiceRef,
key channeldb.CircuitKey, acceptTime time.Time) error { key channeldb.CircuitKey, acceptTime time.Time) error {
releaseTime := acceptTime.Add(i.cfg.HtlcHoldDuration) releaseTime := acceptTime.Add(i.cfg.HtlcHoldDuration)
event := &htlcReleaseEvent{ event := &htlcReleaseEvent{
hash: hash, invoiceRef: invoiceRef,
key: key, key: key,
releaseTime: releaseTime, releaseTime: releaseTime,
} }
@ -560,7 +562,7 @@ func (i *InvoiceRegistry) startHtlcTimer(hash lntypes.Hash,
// cancelSingleHtlc cancels a single accepted htlc on an invoice. It takes // cancelSingleHtlc cancels a single accepted htlc on an invoice. It takes
// a resolution result which will be used to notify subscribed links and // a resolution result which will be used to notify subscribed links and
// resolvers of the details of the htlc cancellation. // resolvers of the details of the htlc cancellation.
func (i *InvoiceRegistry) cancelSingleHtlc(hash lntypes.Hash, func (i *InvoiceRegistry) cancelSingleHtlc(invoiceRef channeldb.InvoiceRef,
key channeldb.CircuitKey, result FailResolutionResult) error { key channeldb.CircuitKey, result FailResolutionResult) error {
i.Lock() i.Lock()
@ -572,7 +574,7 @@ func (i *InvoiceRegistry) cancelSingleHtlc(hash lntypes.Hash,
// Only allow individual htlc cancelation on open invoices. // Only allow individual htlc cancelation on open invoices.
if invoice.State != channeldb.ContractOpen { if invoice.State != channeldb.ContractOpen {
log.Debugf("cancelSingleHtlc: invoice %v no longer "+ log.Debugf("cancelSingleHtlc: invoice %v no longer "+
"open", hash) "open", invoiceRef)
return nil, nil return nil, nil
} }
@ -587,13 +589,13 @@ func (i *InvoiceRegistry) cancelSingleHtlc(hash lntypes.Hash,
// resolved. // resolved.
if htlc.State != channeldb.HtlcStateAccepted { if htlc.State != channeldb.HtlcStateAccepted {
log.Debugf("cancelSingleHtlc: htlc %v on invoice %v "+ log.Debugf("cancelSingleHtlc: htlc %v on invoice %v "+
"is already resolved", key, hash) "is already resolved", key, invoiceRef)
return nil, nil return nil, nil
} }
log.Debugf("cancelSingleHtlc: cancelling htlc %v on invoice %v", log.Debugf("cancelSingleHtlc: cancelling htlc %v on invoice %v",
key, hash) key, invoiceRef)
// Return an update descriptor that cancels htlc and keeps // Return an update descriptor that cancels htlc and keeps
// invoice open. // invoice open.
@ -610,7 +612,7 @@ func (i *InvoiceRegistry) cancelSingleHtlc(hash lntypes.Hash,
// Intercept the update descriptor to set the local updated variable. If // Intercept the update descriptor to set the local updated variable. If
// no invoice update is performed, we can return early. // no invoice update is performed, we can return early.
var updated bool var updated bool
invoice, err := i.cdb.UpdateInvoice(hash, invoice, err := i.cdb.UpdateInvoice(invoiceRef,
func(invoice *channeldb.Invoice) ( func(invoice *channeldb.Invoice) (
*channeldb.InvoiceUpdateDesc, error) { *channeldb.InvoiceUpdateDesc, error) {
@ -774,7 +776,9 @@ func (i *InvoiceRegistry) NotifyExitHopHtlc(rHash lntypes.Hash,
// main event loop. // main event loop.
case *htlcAcceptResolution: case *htlcAcceptResolution:
if r.autoRelease { if r.autoRelease {
err := i.startHtlcTimer(rHash, circuitKey, r.acceptTime) err := i.startHtlcTimer(
ctx.invoiceRef(), circuitKey, r.acceptTime,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -808,7 +812,7 @@ func (i *InvoiceRegistry) notifyExitHopHtlcLocked(
updateSubscribers bool updateSubscribers bool
) )
invoice, err := i.cdb.UpdateInvoice( invoice, err := i.cdb.UpdateInvoice(
ctx.hash, ctx.invoiceRef(),
func(inv *channeldb.Invoice) ( func(inv *channeldb.Invoice) (
*channeldb.InvoiceUpdateDesc, error) { *channeldb.InvoiceUpdateDesc, error) {
@ -962,7 +966,8 @@ func (i *InvoiceRegistry) SettleHodlInvoice(preimage lntypes.Preimage) error {
} }
hash := preimage.Hash() hash := preimage.Hash()
invoice, err := i.cdb.UpdateInvoice(hash, updateInvoice) invoiceRef := channeldb.InvoiceRefByHash(hash)
invoice, err := i.cdb.UpdateInvoice(invoiceRef, updateInvoice)
if err != nil { if err != nil {
log.Errorf("SettleHodlInvoice with preimage %v: %v", log.Errorf("SettleHodlInvoice with preimage %v: %v",
preimage, err) preimage, err)
@ -970,7 +975,7 @@ func (i *InvoiceRegistry) SettleHodlInvoice(preimage lntypes.Preimage) error {
return err return err
} }
log.Debugf("Invoice(%v): settled with preimage %v", hash, log.Debugf("Invoice%v: settled with preimage %v", invoiceRef,
invoice.Terms.PaymentPreimage) invoice.Terms.PaymentPreimage)
// In the callback, we marked the invoice as settled. UpdateInvoice will // In the callback, we marked the invoice as settled. UpdateInvoice will
@ -1011,7 +1016,8 @@ func (i *InvoiceRegistry) cancelInvoiceImpl(payHash lntypes.Hash,
i.Lock() i.Lock()
defer i.Unlock() defer i.Unlock()
log.Debugf("Invoice(%v): canceling invoice", payHash) ref := channeldb.InvoiceRefByHash(payHash)
log.Debugf("Invoice%v: canceling invoice", ref)
updateInvoice := func(invoice *channeldb.Invoice) ( updateInvoice := func(invoice *channeldb.Invoice) (
*channeldb.InvoiceUpdateDesc, error) { *channeldb.InvoiceUpdateDesc, error) {
@ -1032,12 +1038,13 @@ func (i *InvoiceRegistry) cancelInvoiceImpl(payHash lntypes.Hash,
}, nil }, nil
} }
invoice, err := i.cdb.UpdateInvoice(payHash, updateInvoice) invoiceRef := channeldb.InvoiceRefByHash(payHash)
invoice, err := i.cdb.UpdateInvoice(invoiceRef, updateInvoice)
// Implement idempotency by returning success if the invoice was already // Implement idempotency by returning success if the invoice was already
// canceled. // canceled.
if err == channeldb.ErrInvoiceAlreadyCanceled { if err == channeldb.ErrInvoiceAlreadyCanceled {
log.Debugf("Invoice(%v): already canceled", payHash) log.Debugf("Invoice%v: already canceled", ref)
return nil return nil
} }
if err != nil { if err != nil {
@ -1046,12 +1053,12 @@ func (i *InvoiceRegistry) cancelInvoiceImpl(payHash lntypes.Hash,
// Return without cancellation if the invoice state is ContractAccepted. // Return without cancellation if the invoice state is ContractAccepted.
if invoice.State == channeldb.ContractAccepted { if invoice.State == channeldb.ContractAccepted {
log.Debugf("Invoice(%v): remains accepted as cancel wasn't"+ log.Debugf("Invoice%v: remains accepted as cancel wasn't"+
"explicitly requested.", payHash) "explicitly requested.", ref)
return nil return nil
} }
log.Debugf("Invoice(%v): canceled", payHash) log.Debugf("Invoice%v: canceled", ref)
// In the callback, some htlcs may have been moved to the canceled // In the callback, some htlcs may have been moved to the canceled
// state. We now go through all of these and notify links and resolvers // state. We now go through all of these and notify links and resolvers
@ -1140,7 +1147,7 @@ type InvoiceSubscription struct {
type SingleInvoiceSubscription struct { type SingleInvoiceSubscription struct {
invoiceSubscriptionKit invoiceSubscriptionKit
hash lntypes.Hash invoiceRef channeldb.InvoiceRef
// Updates is a channel that we'll use to send all invoice events for // Updates is a channel that we'll use to send all invoice events for
// the invoice that is subscribed to. // the invoice that is subscribed to.
@ -1269,7 +1276,7 @@ func (i *InvoiceRegistry) SubscribeSingleInvoice(
ntfnQueue: queue.NewConcurrentQueue(20), ntfnQueue: queue.NewConcurrentQueue(20),
cancelChan: make(chan struct{}), cancelChan: make(chan struct{}),
}, },
hash: hash, invoiceRef: channeldb.InvoiceRefByHash(hash),
} }
client.ntfnQueue.Start() client.ntfnQueue.Start()

@ -26,7 +26,7 @@ func TestSettleInvoice(t *testing.T) {
} }
defer subscription.Cancel() defer subscription.Cancel()
if subscription.hash != testInvoicePaymentHash { if subscription.invoiceRef.PayHash() != testInvoicePaymentHash {
t.Fatalf("expected subscription for provided hash") t.Fatalf("expected subscription for provided hash")
} }
@ -237,7 +237,7 @@ func TestCancelInvoice(t *testing.T) {
} }
defer subscription.Cancel() defer subscription.Cancel()
if subscription.hash != testInvoicePaymentHash { if subscription.invoiceRef.PayHash() != testInvoicePaymentHash {
t.Fatalf("expected subscription for provided hash") t.Fatalf("expected subscription for provided hash")
} }
@ -362,7 +362,7 @@ func TestSettleHoldInvoice(t *testing.T) {
} }
defer subscription.Cancel() defer subscription.Cancel()
if subscription.hash != testInvoicePaymentHash { if subscription.invoiceRef.PayHash() != testInvoicePaymentHash {
t.Fatalf("expected subscription for provided hash") t.Fatalf("expected subscription for provided hash")
} }

@ -22,10 +22,16 @@ type invoiceUpdateCtx struct {
mpp *record.MPP mpp *record.MPP
} }
// invoiceRef returns an identifier that can be used to lookup or update the
// invoice this HTLC is targeting.
func (i *invoiceUpdateCtx) invoiceRef() channeldb.InvoiceRef {
return channeldb.InvoiceRefByHash(i.hash)
}
// log logs a message specific to this update context. // log logs a message specific to this update context.
func (i *invoiceUpdateCtx) log(s string) { func (i *invoiceUpdateCtx) log(s string) {
log.Debugf("Invoice(%x): %v, amt=%v, expiry=%v, circuit=%v, mpp=%v", log.Debugf("Invoice%v: %v, amt=%v, expiry=%v, circuit=%v, mpp=%v",
i.hash[:], s, i.amtPaid, i.expiry, i.circuitKey, i.mpp) i.invoiceRef, s, i.amtPaid, i.expiry, i.circuitKey, i.mpp)
} }
// failRes is a helper function which creates a failure resolution with // failRes is a helper function which creates a failure resolution with