From 3522f09a087f4e6fbe3fcc40fd055b42cc36c389 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Thu, 21 May 2020 15:37:10 -0700 Subject: [PATCH] channeldb+invoices: track invoices by InvoiceRef --- channeldb/invoice_test.go | 64 ++++++++++++++++---------- channeldb/invoices.go | 78 +++++++++++++++++++++++++------- invoices/invoiceregistry.go | 69 +++++++++++++++------------- invoices/invoiceregistry_test.go | 6 +-- invoices/update.go | 10 +++- 5 files changed, 151 insertions(+), 76 deletions(-) diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index 26807894..bd1e6a76 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -11,6 +11,7 @@ import ( "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" + "github.com/stretchr/testify/assert" ) var ( @@ -112,6 +113,7 @@ func TestInvoiceWorkflow(t *testing.T) { fakeInvoice.Terms.Features = emptyFeatures paymentHash := fakeInvoice.Terms.PaymentPreimage.Hash() + ref := InvoiceRefByHash(paymentHash) // Add the invoice to the database, this should succeed as there aren't // any existing invoices within the database with the same payment @@ -123,7 +125,7 @@ func TestInvoiceWorkflow(t *testing.T) { // Attempt to retrieve the invoice which was just added to the // database. It should be found, and the invoice returned should be // identical to the one created above. - dbInvoice, err := db.LookupInvoice(paymentHash) + dbInvoice, err := db.LookupInvoice(ref) if err != nil { t.Fatalf("unable to find invoice: %v", err) } @@ -144,11 +146,11 @@ func TestInvoiceWorkflow(t *testing.T) { // now have the settled bit toggle to true and a non-default // SettledDate payAmt := fakeInvoice.Terms.Value * 2 - _, err = db.UpdateInvoice(paymentHash, getUpdateInvoice(payAmt)) + _, err = db.UpdateInvoice(ref, getUpdateInvoice(payAmt)) if err != nil { t.Fatalf("unable to settle invoice: %v", err) } - dbInvoice2, err := db.LookupInvoice(paymentHash) + dbInvoice2, err := db.LookupInvoice(ref) if err != nil { t.Fatalf("unable to fetch invoice: %v", err) } @@ -180,7 +182,9 @@ func TestInvoiceWorkflow(t *testing.T) { // Attempt to look up a non-existent invoice, this should also fail but // with a "not found" error. var fakeHash [32]byte - if _, err := db.LookupInvoice(fakeHash); err != ErrInvoiceNotFound { + fakeRef := InvoiceRefByHash(fakeHash) + _, err = db.LookupInvoice(fakeRef) + if err != ErrInvoiceNotFound { t.Fatalf("lookup should have failed, instead %v", err) } @@ -256,7 +260,9 @@ func TestInvoiceCancelSingleHtlc(t *testing.T) { Amt: 500, CustomRecords: make(record.CustomSet), } - invoice, err := db.UpdateInvoice(paymentHash, + + ref := InvoiceRefByHash(paymentHash) + invoice, err := db.UpdateInvoice(ref, func(invoice *Invoice) (*InvoiceUpdateDesc, error) { return &InvoiceUpdateDesc{ AddHtlcs: map[CircuitKey]*HtlcAcceptDesc{ @@ -275,13 +281,14 @@ func TestInvoiceCancelSingleHtlc(t *testing.T) { } // Cancel the htlc again. - invoice, err = db.UpdateInvoice(paymentHash, func(invoice *Invoice) (*InvoiceUpdateDesc, error) { - return &InvoiceUpdateDesc{ - CancelHtlcs: map[CircuitKey]struct{}{ - key: {}, - }, - }, nil - }) + invoice, err = db.UpdateInvoice(ref, + func(invoice *Invoice) (*InvoiceUpdateDesc, error) { + return &InvoiceUpdateDesc{ + CancelHtlcs: map[CircuitKey]struct{}{ + key: {}, + }, + }, nil + }) if err != nil { t.Fatalf("unable to cancel htlc: %v", err) } @@ -380,8 +387,9 @@ func TestInvoiceAddTimeSeries(t *testing.T) { paymentHash := invoice.Terms.PaymentPreimage.Hash() + ref := InvoiceRefByHash(paymentHash) _, err := db.UpdateInvoice( - paymentHash, getUpdateInvoice(invoice.Terms.Value), + ref, getUpdateInvoice(invoice.Terms.Value), ) if err != nil { t.Fatalf("unable to settle invoice: %v", err) @@ -570,9 +578,8 @@ func TestDuplicateSettleInvoice(t *testing.T) { } // With the invoice in the DB, we'll now attempt to settle the invoice. - dbInvoice, err := db.UpdateInvoice( - payHash, getUpdateInvoice(amt), - ) + ref := InvoiceRefByHash(payHash) + dbInvoice, err := db.UpdateInvoice(ref, getUpdateInvoice(amt)) if err != nil { t.Fatalf("unable to settle invoice: %v", err) } @@ -601,9 +608,7 @@ func TestDuplicateSettleInvoice(t *testing.T) { // If we try to settle the invoice again, then we should get the very // same invoice back, but with an error this time. - dbInvoice, err = db.UpdateInvoice( - payHash, getUpdateInvoice(amt), - ) + dbInvoice, err = db.UpdateInvoice(ref, getUpdateInvoice(amt)) if err != ErrInvoiceAlreadySettled { t.Fatalf("expected ErrInvoiceAlreadySettled") } @@ -653,9 +658,8 @@ func TestQueryInvoices(t *testing.T) { // We'll only settle half of all invoices created. if i%2 == 0 { - _, err := db.UpdateInvoice( - paymentHash, getUpdateInvoice(amt), - ) + ref := InvoiceRefByHash(paymentHash) + _, err := db.UpdateInvoice(ref, getUpdateInvoice(amt)) if err != nil { t.Fatalf("unable to settle invoice: %v", err) } @@ -951,7 +955,8 @@ func TestCustomRecords(t *testing.T) { 100001: []byte{1, 2}, } - _, err = db.UpdateInvoice(paymentHash, + ref := InvoiceRefByHash(paymentHash) + _, err = db.UpdateInvoice(ref, func(invoice *Invoice) (*InvoiceUpdateDesc, error) { return &InvoiceUpdateDesc{ AddHtlcs: map[CircuitKey]*HtlcAcceptDesc{ @@ -969,7 +974,7 @@ func TestCustomRecords(t *testing.T) { // Retrieve the invoice from that database and verify that the custom // records are present. - dbInvoice, err := db.LookupInvoice(paymentHash) + dbInvoice, err := db.LookupInvoice(ref) if err != nil { t.Fatalf("unable to lookup invoice: %v", err) } @@ -981,3 +986,14 @@ func TestCustomRecords(t *testing.T) { t.Fatalf("invalid custom records") } } + +// TestInvoiceRef asserts that the proper identifiers are returned from an +// InvoiceRef depending on the constructor used. +func TestInvoiceRef(t *testing.T) { + payHash := lntypes.Hash{0x01} + + // An InvoiceRef by hash should return the provided hash and a nil + // payment addr. + refByHash := InvoiceRefByHash(payHash) + assert.Equal(t, payHash, refByHash.PayHash()) +} diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 23c10dc6..83c8de13 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -142,6 +142,32 @@ const ( amtPaidType tlv.Type = 13 ) +// InvoiceRef is an identifier for invoices supporting queries by payment hash. +type InvoiceRef struct { + // payHash is the payment hash of the target invoice. All invoices are + // currently indexed by payment hash. This value will be used as a + // fallback when no payment address is known. + payHash lntypes.Hash +} + +// InvoiceRefByHash creates an InvoiceRef that queries for an invoice only by +// its payment hash. +func InvoiceRefByHash(payHash lntypes.Hash) InvoiceRef { + return InvoiceRef{ + payHash: payHash, + } +} + +// PayHash returns the target invoice's payment hash. +func (r InvoiceRef) PayHash() lntypes.Hash { + return r.payHash +} + +// String returns a human-readable representation of an InvoiceRef. +func (r InvoiceRef) String() string { + return fmt.Sprintf("(pay_hash=%v)", r.payHash) +} + // ContractState describes the state the invoice is in. type ContractState uint8 @@ -538,7 +564,7 @@ func (d *DB) InvoicesAddedSince(sinceAddIndex uint64) ([]Invoice, error) { // full invoice is returned. Before setting the incoming HTLC, the values // SHOULD be checked to ensure the payer meets the agreed upon contractual // terms of the payment. -func (d *DB) LookupInvoice(paymentHash [32]byte) (Invoice, error) { +func (d *DB) LookupInvoice(ref InvoiceRef) (Invoice, error) { var invoice Invoice err := kvdb.View(d, func(tx kvdb.ReadTx) error { invoices := tx.ReadBucket(invoiceBucket) @@ -550,15 +576,17 @@ func (d *DB) LookupInvoice(paymentHash [32]byte) (Invoice, error) { return ErrNoInvoicesCreated } - // Check the invoice index to see if an invoice paying to this - // hash exists within the DB. - invoiceNum := invoiceIndex.Get(paymentHash[:]) - if invoiceNum == nil { - return ErrInvoiceNotFound + // Retrieve the invoice number for this invoice using the + // provided invoice reference. + invoiceNum, err := fetchInvoiceNumByRef( + invoiceIndex, ref, + ) + if err != nil { + return err } - // An invoice matching the payment hash has been found, so - // retrieve the record of the invoice itself. + // An invoice was found, retrieve the remainder of the invoice + // body. i, err := fetchInvoice(invoiceNum, invoices) if err != nil { return err @@ -574,6 +602,21 @@ func (d *DB) LookupInvoice(paymentHash [32]byte) (Invoice, error) { return invoice, nil } +// fetchInvoiceNumByRef retrieve the invoice number for the provided invoice +// reference. +func fetchInvoiceNumByRef(invoiceIndex kvdb.ReadBucket, + ref InvoiceRef) ([]byte, error) { + + payHash := ref.PayHash() + + invoiceNum := invoiceIndex.Get(payHash[:]) + if invoiceNum == nil { + return nil, ErrInvoiceNotFound + } + + return invoiceNum, nil +} + // InvoiceWithPaymentHash is used to store an invoice and its corresponding // payment hash. This struct is only used to store results of // ChannelDB.FetchAllInvoicesWithPaymentHash() call. @@ -824,7 +867,7 @@ func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) { // The update is performed inside the same database transaction that fetches the // invoice and is therefore atomic. The fields to update are controlled by the // supplied callback. -func (d *DB) UpdateInvoice(paymentHash lntypes.Hash, +func (d *DB) UpdateInvoice(ref InvoiceRef, callback InvoiceUpdateCallback) (*Invoice, error) { var updatedInvoice *Invoice @@ -846,15 +889,18 @@ func (d *DB) UpdateInvoice(paymentHash lntypes.Hash, return err } - // Check the invoice index to see if an invoice paying to this - // hash exists within the DB. - invoiceNum := invoiceIndex.Get(paymentHash[:]) - if invoiceNum == nil { - return ErrInvoiceNotFound - } + // Retrieve the invoice number for this invoice using the + // provided invoice reference. + invoiceNum, err := fetchInvoiceNumByRef( + invoiceIndex, ref, + ) + if err != nil { + return err + } + payHash := ref.PayHash() updatedInvoice, err = d.updateInvoice( - paymentHash, invoices, settleIndex, invoiceNum, + payHash, invoices, settleIndex, invoiceNum, callback, ) diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index 4c1d9b2a..61e2e56b 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -61,8 +61,8 @@ type RegistryConfig struct { // htlcReleaseEvent describes an htlc auto-release event. It is used to release // mpp htlcs for which the complete set didn't arrive in time. type htlcReleaseEvent struct { - // hash is the payment hash of the htlc to release. - hash lntypes.Hash + // invoiceRef identifiers the invoice this htlc belongs to. + invoiceRef channeldb.InvoiceRef // key is the circuit key of the htlc to release. key channeldb.CircuitKey @@ -289,7 +289,8 @@ func (i *InvoiceRegistry) invoiceEventLoop() { // the subscriber. case *SingleInvoiceSubscription: log.Infof("New single invoice subscription "+ - "client: id=%v, hash=%v", e.id, e.hash) + "client: id=%v, ref=%v", e.id, + e.invoiceRef) i.singleNotificationClients[e.id] = e } @@ -297,8 +298,8 @@ func (i *InvoiceRegistry) invoiceEventLoop() { // A new htlc came in for auto-release. case event := <-i.htlcAutoReleaseChan: log.Debugf("Scheduling auto-release for htlc: "+ - "hash=%v, key=%v at %v", - event.hash, event.key, event.releaseTime) + "ref=%v, key=%v at %v", + event.invoiceRef, event.key, event.releaseTime) // We use an independent timer for every htlc rather // than a set timer that is reset with every htlc coming @@ -311,7 +312,7 @@ func (i *InvoiceRegistry) invoiceEventLoop() { case <-nextReleaseTick: event := autoReleaseHeap.Pop().(*htlcReleaseEvent) err := i.cancelSingleHtlc( - event.hash, event.key, ResultMppTimeout, + event.invoiceRef, event.key, ResultMppTimeout, ) if err != nil { log.Errorf("HTLC timer: %v", err) @@ -328,7 +329,7 @@ func (i *InvoiceRegistry) invoiceEventLoop() { func (i *InvoiceRegistry) dispatchToSingleClients(event *invoiceEvent) { // Dispatch to single invoice subscribers. for _, client := range i.singleNotificationClients { - if client.hash != event.hash { + if client.invoiceRef.PayHash() != event.hash { continue } @@ -465,7 +466,7 @@ func (i *InvoiceRegistry) deliverBacklogEvents(client *InvoiceSubscription) erro func (i *InvoiceRegistry) deliverSingleBacklogEvents( client *SingleInvoiceSubscription) error { - invoice, err := i.cdb.LookupInvoice(client.hash) + invoice, err := i.cdb.LookupInvoice(client.invoiceRef) // It is possible that the invoice does not exist yet, but the client is // already watching it in anticipation. @@ -479,7 +480,7 @@ func (i *InvoiceRegistry) deliverSingleBacklogEvents( } err = client.notify(&invoiceEvent{ - hash: client.hash, + hash: client.invoiceRef.PayHash(), invoice: &invoice, }) if err != nil { @@ -502,8 +503,8 @@ func (i *InvoiceRegistry) AddInvoice(invoice *channeldb.Invoice, i.Lock() - log.Debugf("Invoice(%v): added with terms %v", paymentHash, - invoice.Terms) + ref := channeldb.InvoiceRefByHash(paymentHash) + log.Debugf("Invoice%v: added with terms %v", ref, invoice.Terms) addIndex, err := i.cdb.AddInvoice(invoice, paymentHash) if err != nil { @@ -533,17 +534,18 @@ func (i *InvoiceRegistry) LookupInvoice(rHash lntypes.Hash) (channeldb.Invoice, // We'll check the database to see if there's an existing matching // invoice. - return i.cdb.LookupInvoice(rHash) + ref := channeldb.InvoiceRefByHash(rHash) + return i.cdb.LookupInvoice(ref) } // startHtlcTimer starts a new timer via the invoice registry main loop that // cancels a single htlc on an invoice when the htlc hold duration has passed. -func (i *InvoiceRegistry) startHtlcTimer(hash lntypes.Hash, +func (i *InvoiceRegistry) startHtlcTimer(invoiceRef channeldb.InvoiceRef, key channeldb.CircuitKey, acceptTime time.Time) error { releaseTime := acceptTime.Add(i.cfg.HtlcHoldDuration) event := &htlcReleaseEvent{ - hash: hash, + invoiceRef: invoiceRef, key: key, releaseTime: releaseTime, } @@ -560,7 +562,7 @@ func (i *InvoiceRegistry) startHtlcTimer(hash lntypes.Hash, // cancelSingleHtlc cancels a single accepted htlc on an invoice. It takes // a resolution result which will be used to notify subscribed links and // resolvers of the details of the htlc cancellation. -func (i *InvoiceRegistry) cancelSingleHtlc(hash lntypes.Hash, +func (i *InvoiceRegistry) cancelSingleHtlc(invoiceRef channeldb.InvoiceRef, key channeldb.CircuitKey, result FailResolutionResult) error { i.Lock() @@ -572,7 +574,7 @@ func (i *InvoiceRegistry) cancelSingleHtlc(hash lntypes.Hash, // Only allow individual htlc cancelation on open invoices. if invoice.State != channeldb.ContractOpen { log.Debugf("cancelSingleHtlc: invoice %v no longer "+ - "open", hash) + "open", invoiceRef) return nil, nil } @@ -587,13 +589,13 @@ func (i *InvoiceRegistry) cancelSingleHtlc(hash lntypes.Hash, // resolved. if htlc.State != channeldb.HtlcStateAccepted { log.Debugf("cancelSingleHtlc: htlc %v on invoice %v "+ - "is already resolved", key, hash) + "is already resolved", key, invoiceRef) return nil, nil } log.Debugf("cancelSingleHtlc: cancelling htlc %v on invoice %v", - key, hash) + key, invoiceRef) // Return an update descriptor that cancels htlc and keeps // invoice open. @@ -610,7 +612,7 @@ func (i *InvoiceRegistry) cancelSingleHtlc(hash lntypes.Hash, // Intercept the update descriptor to set the local updated variable. If // no invoice update is performed, we can return early. var updated bool - invoice, err := i.cdb.UpdateInvoice(hash, + invoice, err := i.cdb.UpdateInvoice(invoiceRef, func(invoice *channeldb.Invoice) ( *channeldb.InvoiceUpdateDesc, error) { @@ -774,7 +776,9 @@ func (i *InvoiceRegistry) NotifyExitHopHtlc(rHash lntypes.Hash, // main event loop. case *htlcAcceptResolution: if r.autoRelease { - err := i.startHtlcTimer(rHash, circuitKey, r.acceptTime) + err := i.startHtlcTimer( + ctx.invoiceRef(), circuitKey, r.acceptTime, + ) if err != nil { return nil, err } @@ -808,7 +812,7 @@ func (i *InvoiceRegistry) notifyExitHopHtlcLocked( updateSubscribers bool ) invoice, err := i.cdb.UpdateInvoice( - ctx.hash, + ctx.invoiceRef(), func(inv *channeldb.Invoice) ( *channeldb.InvoiceUpdateDesc, error) { @@ -962,7 +966,8 @@ func (i *InvoiceRegistry) SettleHodlInvoice(preimage lntypes.Preimage) error { } hash := preimage.Hash() - invoice, err := i.cdb.UpdateInvoice(hash, updateInvoice) + invoiceRef := channeldb.InvoiceRefByHash(hash) + invoice, err := i.cdb.UpdateInvoice(invoiceRef, updateInvoice) if err != nil { log.Errorf("SettleHodlInvoice with preimage %v: %v", preimage, err) @@ -970,7 +975,7 @@ func (i *InvoiceRegistry) SettleHodlInvoice(preimage lntypes.Preimage) error { return err } - log.Debugf("Invoice(%v): settled with preimage %v", hash, + log.Debugf("Invoice%v: settled with preimage %v", invoiceRef, invoice.Terms.PaymentPreimage) // In the callback, we marked the invoice as settled. UpdateInvoice will @@ -1011,7 +1016,8 @@ func (i *InvoiceRegistry) cancelInvoiceImpl(payHash lntypes.Hash, i.Lock() defer i.Unlock() - log.Debugf("Invoice(%v): canceling invoice", payHash) + ref := channeldb.InvoiceRefByHash(payHash) + log.Debugf("Invoice%v: canceling invoice", ref) updateInvoice := func(invoice *channeldb.Invoice) ( *channeldb.InvoiceUpdateDesc, error) { @@ -1032,12 +1038,13 @@ func (i *InvoiceRegistry) cancelInvoiceImpl(payHash lntypes.Hash, }, nil } - invoice, err := i.cdb.UpdateInvoice(payHash, updateInvoice) + invoiceRef := channeldb.InvoiceRefByHash(payHash) + invoice, err := i.cdb.UpdateInvoice(invoiceRef, updateInvoice) // Implement idempotency by returning success if the invoice was already // canceled. if err == channeldb.ErrInvoiceAlreadyCanceled { - log.Debugf("Invoice(%v): already canceled", payHash) + log.Debugf("Invoice%v: already canceled", ref) return nil } if err != nil { @@ -1046,12 +1053,12 @@ func (i *InvoiceRegistry) cancelInvoiceImpl(payHash lntypes.Hash, // Return without cancellation if the invoice state is ContractAccepted. if invoice.State == channeldb.ContractAccepted { - log.Debugf("Invoice(%v): remains accepted as cancel wasn't"+ - "explicitly requested.", payHash) + log.Debugf("Invoice%v: remains accepted as cancel wasn't"+ + "explicitly requested.", ref) return nil } - log.Debugf("Invoice(%v): canceled", payHash) + log.Debugf("Invoice%v: canceled", ref) // In the callback, some htlcs may have been moved to the canceled // state. We now go through all of these and notify links and resolvers @@ -1140,7 +1147,7 @@ type InvoiceSubscription struct { type SingleInvoiceSubscription struct { invoiceSubscriptionKit - hash lntypes.Hash + invoiceRef channeldb.InvoiceRef // Updates is a channel that we'll use to send all invoice events for // the invoice that is subscribed to. @@ -1269,7 +1276,7 @@ func (i *InvoiceRegistry) SubscribeSingleInvoice( ntfnQueue: queue.NewConcurrentQueue(20), cancelChan: make(chan struct{}), }, - hash: hash, + invoiceRef: channeldb.InvoiceRefByHash(hash), } client.ntfnQueue.Start() diff --git a/invoices/invoiceregistry_test.go b/invoices/invoiceregistry_test.go index 319c30cf..fa672a36 100644 --- a/invoices/invoiceregistry_test.go +++ b/invoices/invoiceregistry_test.go @@ -26,7 +26,7 @@ func TestSettleInvoice(t *testing.T) { } defer subscription.Cancel() - if subscription.hash != testInvoicePaymentHash { + if subscription.invoiceRef.PayHash() != testInvoicePaymentHash { t.Fatalf("expected subscription for provided hash") } @@ -237,7 +237,7 @@ func TestCancelInvoice(t *testing.T) { } defer subscription.Cancel() - if subscription.hash != testInvoicePaymentHash { + if subscription.invoiceRef.PayHash() != testInvoicePaymentHash { t.Fatalf("expected subscription for provided hash") } @@ -362,7 +362,7 @@ func TestSettleHoldInvoice(t *testing.T) { } defer subscription.Cancel() - if subscription.hash != testInvoicePaymentHash { + if subscription.invoiceRef.PayHash() != testInvoicePaymentHash { t.Fatalf("expected subscription for provided hash") } diff --git a/invoices/update.go b/invoices/update.go index 3226779c..4680b3cd 100644 --- a/invoices/update.go +++ b/invoices/update.go @@ -22,10 +22,16 @@ type invoiceUpdateCtx struct { mpp *record.MPP } +// invoiceRef returns an identifier that can be used to lookup or update the +// invoice this HTLC is targeting. +func (i *invoiceUpdateCtx) invoiceRef() channeldb.InvoiceRef { + return channeldb.InvoiceRefByHash(i.hash) +} + // log logs a message specific to this update context. func (i *invoiceUpdateCtx) log(s string) { - log.Debugf("Invoice(%x): %v, amt=%v, expiry=%v, circuit=%v, mpp=%v", - i.hash[:], s, i.amtPaid, i.expiry, i.circuitKey, i.mpp) + log.Debugf("Invoice%v: %v, amt=%v, expiry=%v, circuit=%v, mpp=%v", + i.invoiceRef, s, i.amtPaid, i.expiry, i.circuitKey, i.mpp) } // failRes is a helper function which creates a failure resolution with