From 174d5775243ff95153b6cdd985ee5d5fdcc22411 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Wed, 24 Mar 2021 19:48:32 -0700 Subject: [PATCH] channeldb: make payhash on InvoiceRef optional Currently we support queries by payHash or payHash+payAddr. For handling of AMP HTLCs, we only need to support querying by payAddr. --- channeldb/invoice_test.go | 38 +++++++++++++++++-- channeldb/invoices.go | 64 ++++++++++++++++++++++---------- invoices/invoiceregistry.go | 10 ++++- invoices/invoiceregistry_test.go | 16 +++----- 4 files changed, 93 insertions(+), 35 deletions(-) diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index 7132ba27..af1d967b 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -1208,21 +1208,21 @@ func TestInvoiceRef(t *testing.T) { // An InvoiceRef by hash should return the provided hash and a nil // payment addr. refByHash := InvoiceRefByHash(payHash) - require.Equal(t, payHash, refByHash.PayHash()) + require.Equal(t, &payHash, refByHash.PayHash()) require.Equal(t, (*[32]byte)(nil), refByHash.PayAddr()) require.Equal(t, (*[32]byte)(nil), refByHash.SetID()) // An InvoiceRef by hash and addr should return the payment hash and // payment addr passed to the constructor. refByHashAndAddr := InvoiceRefByHashAndAddr(payHash, payAddr) - require.Equal(t, payHash, refByHashAndAddr.PayHash()) + require.Equal(t, &payHash, refByHashAndAddr.PayHash()) require.Equal(t, &payAddr, refByHashAndAddr.PayAddr()) require.Equal(t, (*[32]byte)(nil), refByHashAndAddr.SetID()) // An InvoiceRef by set id should return an empty pay hash, a nil pay // addr, and a reference to the given set id. refBySetID := InvoiceRefBySetID(setID) - require.Equal(t, lntypes.Hash{}, refBySetID.PayHash()) + require.Equal(t, (*lntypes.Hash)(nil), refBySetID.PayHash()) require.Equal(t, (*[32]byte)(nil), refBySetID.PayAddr()) require.Equal(t, &setID, refBySetID.SetID()) @@ -1533,6 +1533,38 @@ func getUpdateInvoiceAMPSettle(setID *[32]byte) InvoiceUpdateCallback { } } +// TestUnexpectedInvoicePreimage asserts that legacy or MPP invoices cannot be +// settled when referenced by payment address only. Since regular or MPP +// payments do not store the payment hash explicitly (it is stored in the +// index), this enforces that they can only be updated using a InvoiceRefByHash +// or InvoiceRefByHashOrAddr. +func TestUnexpectedInvoicePreimage(t *testing.T) { + t.Parallel() + + db, cleanup, err := MakeTestDB() + defer cleanup() + require.NoError(t, err, "unable to make test db") + + invoice, err := randInvoice(lnwire.MilliSatoshi(100)) + require.NoError(t, err) + + // Add a random invoice indexed by payment hash and payment addr. + paymentHash := invoice.Terms.PaymentPreimage.Hash() + _, err = db.AddInvoice(invoice, paymentHash) + require.NoError(t, err) + + // Attempt to update the invoice by pay addr only. This will fail since, + // in order to settle an MPP invoice, the InvoiceRef must present a + // payment hash against which to validate the preimage. + _, err = db.UpdateInvoice( + InvoiceRefByAddr(invoice.Terms.PaymentAddr), + getUpdateInvoice(invoice.Terms.Value), + ) + + //Assert that we get ErrUnexpectedInvoicePreimage. + require.Error(t, ErrUnexpectedInvoicePreimage, err) +} + // TestDeleteInvoices tests that deleting a list of invoices will succeed // if all delete references are valid, or will fail otherwise. func TestDeleteInvoices(t *testing.T) { diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 444602ae..613960bf 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -122,6 +122,13 @@ var ( // ErrEmptyHTLCSet is returned when attempting to accept or settle and // HTLC set that has no HTLCs. ErrEmptyHTLCSet = errors.New("cannot settle/accept empty HTLC set") + + // ErrUnexpectedInvoicePreimage is returned when an invoice-level + // preimage is provided when trying to settle an invoice that shouldn't + // have one, e.g. an AMP invoice. + ErrUnexpectedInvoicePreimage = errors.New( + "unexpected invoice preimage provided on settle", + ) ) // ErrDuplicateSetID is an error returned when attempting to adding an AMP HTLC @@ -198,7 +205,7 @@ type InvoiceRef struct { // payHash is the payment hash of the target invoice. All invoices are // currently indexed by payment hash. This value will be used as a // fallback when no payment address is known. - payHash lntypes.Hash + payHash *lntypes.Hash // payAddr is the payment addr of the target invoice. Newer invoices // (0.11 and up) are indexed by payment address in addition to payment @@ -220,7 +227,7 @@ type InvoiceRef struct { // its payment hash. func InvoiceRefByHash(payHash lntypes.Hash) InvoiceRef { return InvoiceRef{ - payHash: payHash, + payHash: &payHash, } } @@ -231,7 +238,7 @@ func InvoiceRefByHashAndAddr(payHash lntypes.Hash, payAddr [32]byte) InvoiceRef { return InvoiceRef{ - payHash: payHash, + payHash: &payHash, payAddr: &payAddr, } } @@ -253,9 +260,15 @@ func InvoiceRefBySetID(setID [32]byte) InvoiceRef { } } -// PayHash returns the target invoice's payment hash. -func (r InvoiceRef) PayHash() lntypes.Hash { - return r.payHash +// PayHash returns the optional payment hash of the target invoice. +// +// NOTE: This value may be nil. +func (r InvoiceRef) PayHash() *lntypes.Hash { + if r.payHash != nil { + hash := *r.payHash + return &hash + } + return nil } // PayAddr returns the optional payment address of the target invoice. @@ -887,19 +900,27 @@ func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex, setIDIndex kvdb.RBucket, payHash := ref.PayHash() payAddr := ref.PayAddr() - var ( - invoiceNumByHash = invoiceIndex.Get(payHash[:]) - invoiceNumByAddr []byte - ) - if payAddr != nil { - // Only allow lookups for payment address if it is not a blank - // payment address, which is a special-cased value for legacy - // keysend invoices. - if *payAddr != BlankPayAddr { - invoiceNumByAddr = payAddrIndex.Get(payAddr[:]) + getInvoiceNumByHash := func() []byte { + if payHash != nil { + return invoiceIndex.Get(payHash[:]) } + return nil } + getInvoiceNumByAddr := func() []byte { + if payAddr != nil { + // Only allow lookups for payment address if it is not a + // blank payment address, which is a special-cased value + // for legacy keysend invoices. + if *payAddr != BlankPayAddr { + return payAddrIndex.Get(payAddr[:]) + } + } + return nil + } + + invoiceNumByHash := getInvoiceNumByHash() + invoiceNumByAddr := getInvoiceNumByAddr() switch { // If payment address and payment hash both reference an existing @@ -1745,7 +1766,7 @@ func copyInvoice(src *Invoice) *Invoice { // updateInvoice fetches the invoice, obtains the update descriptor from the // callback and applies the updates in a single db transaction. -func (d *DB) updateInvoice(hash lntypes.Hash, invoices, +func (d *DB) updateInvoice(hash *lntypes.Hash, invoices, settleIndex, setIDIndex kvdb.RwBucket, invoiceNum []byte, callback InvoiceUpdateCallback) (*Invoice, error) { @@ -1913,7 +1934,7 @@ func (d *DB) updateInvoice(hash lntypes.Hash, invoices, } // updateInvoiceState validates and processes an invoice state update. -func updateInvoiceState(invoice *Invoice, hash lntypes.Hash, +func updateInvoiceState(invoice *Invoice, hash *lntypes.Hash, update InvoiceStateUpdateDesc) error { // Returning to open is never allowed from any state. @@ -1962,9 +1983,14 @@ func updateInvoiceState(invoice *Invoice, hash lntypes.Hash, switch { + // If an invoice-level preimage was supplied, but the InvoiceRef + // doesn't specify a hash (e.g. AMP invoices) we fail. + case update.Preimage != nil && hash == nil: + return ErrUnexpectedInvoicePreimage + // Validate the supplied preimage for non-AMP invoices. case update.Preimage != nil: - if update.Preimage.Hash() != hash { + if update.Preimage.Hash() != *hash { return ErrInvoicePreimageMismatch } invoice.Terms.PaymentPreimage = update.Preimage diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index e6e4093b..4f594ee4 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -377,7 +377,8 @@ func (i *InvoiceRegistry) invoiceEventLoop() { func (i *InvoiceRegistry) dispatchToSingleClients(event *invoiceEvent) { // Dispatch to single invoice subscribers. for _, client := range i.singleNotificationClients { - if client.invoiceRef.PayHash() != event.hash { + payHash := client.invoiceRef.PayHash() + if payHash == nil || *payHash != event.hash { continue } @@ -524,8 +525,13 @@ func (i *InvoiceRegistry) deliverSingleBacklogEvents( return err } + payHash := client.invoiceRef.PayHash() + if payHash == nil { + return nil + } + err = client.notify(&invoiceEvent{ - hash: client.invoiceRef.PayHash(), + hash: *payHash, invoice: &invoice, }) if err != nil { diff --git a/invoices/invoiceregistry_test.go b/invoices/invoiceregistry_test.go index 9e465eaa..2056a901 100644 --- a/invoices/invoiceregistry_test.go +++ b/invoices/invoiceregistry_test.go @@ -29,9 +29,7 @@ func TestSettleInvoice(t *testing.T) { } defer subscription.Cancel() - if subscription.invoiceRef.PayHash() != testInvoicePaymentHash { - t.Fatalf("expected subscription for provided hash") - } + require.Equal(t, subscription.invoiceRef.PayHash(), &testInvoicePaymentHash) // Add the invoice. addIdx, err := ctx.registry.AddInvoice(testInvoice, testInvoicePaymentHash) @@ -244,9 +242,7 @@ func testCancelInvoice(t *testing.T, gc bool) { } defer subscription.Cancel() - if subscription.invoiceRef.PayHash() != testInvoicePaymentHash { - t.Fatalf("expected subscription for provided hash") - } + require.Equal(t, subscription.invoiceRef.PayHash(), &testInvoicePaymentHash) // Add the invoice. amt := lnwire.MilliSatoshi(100000) @@ -404,9 +400,7 @@ func TestSettleHoldInvoice(t *testing.T) { } defer subscription.Cancel() - if subscription.invoiceRef.PayHash() != testInvoicePaymentHash { - t.Fatalf("expected subscription for provided hash") - } + require.Equal(t, subscription.invoiceRef.PayHash(), &testInvoicePaymentHash) // Add the invoice. _, err = registry.AddInvoice(testHodlInvoice, testInvoicePaymentHash) @@ -1211,7 +1205,7 @@ func TestSettleInvoicePaymentAddrRequired(t *testing.T) { defer subscription.Cancel() require.Equal( - t, subscription.invoiceRef.PayHash(), testInvoicePaymentHash, + t, subscription.invoiceRef.PayHash(), &testInvoicePaymentHash, ) // Add the invoice, which requires the MPP payload to always be @@ -1287,7 +1281,7 @@ func TestSettleInvoicePaymentAddrRequiredOptionalGrace(t *testing.T) { defer subscription.Cancel() require.Equal( - t, subscription.invoiceRef.PayHash(), testInvoicePaymentHash, + t, subscription.invoiceRef.PayHash(), &testInvoicePaymentHash, ) // Add the invoice, which requires the MPP payload to always be