channeldb: make payhash on InvoiceRef optional

Currently we support queries by payHash or payHash+payAddr. For handling
of AMP HTLCs, we only need to support querying by payAddr.
This commit is contained in:
Conner Fromknecht 2021-03-24 19:48:32 -07:00
parent be6698447e
commit 174d577524
No known key found for this signature in database
GPG Key ID: E7D737B67FA592C7
4 changed files with 93 additions and 35 deletions

@ -1208,21 +1208,21 @@ func TestInvoiceRef(t *testing.T) {
// An InvoiceRef by hash should return the provided hash and a nil // An InvoiceRef by hash should return the provided hash and a nil
// payment addr. // payment addr.
refByHash := InvoiceRefByHash(payHash) refByHash := InvoiceRefByHash(payHash)
require.Equal(t, payHash, refByHash.PayHash()) require.Equal(t, &payHash, refByHash.PayHash())
require.Equal(t, (*[32]byte)(nil), refByHash.PayAddr()) require.Equal(t, (*[32]byte)(nil), refByHash.PayAddr())
require.Equal(t, (*[32]byte)(nil), refByHash.SetID()) require.Equal(t, (*[32]byte)(nil), refByHash.SetID())
// An InvoiceRef by hash and addr should return the payment hash and // An InvoiceRef by hash and addr should return the payment hash and
// payment addr passed to the constructor. // payment addr passed to the constructor.
refByHashAndAddr := InvoiceRefByHashAndAddr(payHash, payAddr) refByHashAndAddr := InvoiceRefByHashAndAddr(payHash, payAddr)
require.Equal(t, payHash, refByHashAndAddr.PayHash()) require.Equal(t, &payHash, refByHashAndAddr.PayHash())
require.Equal(t, &payAddr, refByHashAndAddr.PayAddr()) require.Equal(t, &payAddr, refByHashAndAddr.PayAddr())
require.Equal(t, (*[32]byte)(nil), refByHashAndAddr.SetID()) require.Equal(t, (*[32]byte)(nil), refByHashAndAddr.SetID())
// An InvoiceRef by set id should return an empty pay hash, a nil pay // An InvoiceRef by set id should return an empty pay hash, a nil pay
// addr, and a reference to the given set id. // addr, and a reference to the given set id.
refBySetID := InvoiceRefBySetID(setID) refBySetID := InvoiceRefBySetID(setID)
require.Equal(t, lntypes.Hash{}, refBySetID.PayHash()) require.Equal(t, (*lntypes.Hash)(nil), refBySetID.PayHash())
require.Equal(t, (*[32]byte)(nil), refBySetID.PayAddr()) require.Equal(t, (*[32]byte)(nil), refBySetID.PayAddr())
require.Equal(t, &setID, refBySetID.SetID()) require.Equal(t, &setID, refBySetID.SetID())
@ -1533,6 +1533,38 @@ func getUpdateInvoiceAMPSettle(setID *[32]byte) InvoiceUpdateCallback {
} }
} }
// TestUnexpectedInvoicePreimage asserts that legacy or MPP invoices cannot be
// settled when referenced by payment address only. Since regular or MPP
// payments do not store the payment hash explicitly (it is stored in the
// index), this enforces that they can only be updated using a InvoiceRefByHash
// or InvoiceRefByHashOrAddr.
func TestUnexpectedInvoicePreimage(t *testing.T) {
t.Parallel()
db, cleanup, err := MakeTestDB()
defer cleanup()
require.NoError(t, err, "unable to make test db")
invoice, err := randInvoice(lnwire.MilliSatoshi(100))
require.NoError(t, err)
// Add a random invoice indexed by payment hash and payment addr.
paymentHash := invoice.Terms.PaymentPreimage.Hash()
_, err = db.AddInvoice(invoice, paymentHash)
require.NoError(t, err)
// Attempt to update the invoice by pay addr only. This will fail since,
// in order to settle an MPP invoice, the InvoiceRef must present a
// payment hash against which to validate the preimage.
_, err = db.UpdateInvoice(
InvoiceRefByAddr(invoice.Terms.PaymentAddr),
getUpdateInvoice(invoice.Terms.Value),
)
//Assert that we get ErrUnexpectedInvoicePreimage.
require.Error(t, ErrUnexpectedInvoicePreimage, err)
}
// TestDeleteInvoices tests that deleting a list of invoices will succeed // TestDeleteInvoices tests that deleting a list of invoices will succeed
// if all delete references are valid, or will fail otherwise. // if all delete references are valid, or will fail otherwise.
func TestDeleteInvoices(t *testing.T) { func TestDeleteInvoices(t *testing.T) {

@ -122,6 +122,13 @@ var (
// ErrEmptyHTLCSet is returned when attempting to accept or settle and // ErrEmptyHTLCSet is returned when attempting to accept or settle and
// HTLC set that has no HTLCs. // HTLC set that has no HTLCs.
ErrEmptyHTLCSet = errors.New("cannot settle/accept empty HTLC set") ErrEmptyHTLCSet = errors.New("cannot settle/accept empty HTLC set")
// ErrUnexpectedInvoicePreimage is returned when an invoice-level
// preimage is provided when trying to settle an invoice that shouldn't
// have one, e.g. an AMP invoice.
ErrUnexpectedInvoicePreimage = errors.New(
"unexpected invoice preimage provided on settle",
)
) )
// ErrDuplicateSetID is an error returned when attempting to adding an AMP HTLC // ErrDuplicateSetID is an error returned when attempting to adding an AMP HTLC
@ -198,7 +205,7 @@ type InvoiceRef struct {
// payHash is the payment hash of the target invoice. All invoices are // payHash is the payment hash of the target invoice. All invoices are
// currently indexed by payment hash. This value will be used as a // currently indexed by payment hash. This value will be used as a
// fallback when no payment address is known. // fallback when no payment address is known.
payHash lntypes.Hash payHash *lntypes.Hash
// payAddr is the payment addr of the target invoice. Newer invoices // payAddr is the payment addr of the target invoice. Newer invoices
// (0.11 and up) are indexed by payment address in addition to payment // (0.11 and up) are indexed by payment address in addition to payment
@ -220,7 +227,7 @@ type InvoiceRef struct {
// its payment hash. // its payment hash.
func InvoiceRefByHash(payHash lntypes.Hash) InvoiceRef { func InvoiceRefByHash(payHash lntypes.Hash) InvoiceRef {
return InvoiceRef{ return InvoiceRef{
payHash: payHash, payHash: &payHash,
} }
} }
@ -231,7 +238,7 @@ func InvoiceRefByHashAndAddr(payHash lntypes.Hash,
payAddr [32]byte) InvoiceRef { payAddr [32]byte) InvoiceRef {
return InvoiceRef{ return InvoiceRef{
payHash: payHash, payHash: &payHash,
payAddr: &payAddr, payAddr: &payAddr,
} }
} }
@ -253,9 +260,15 @@ func InvoiceRefBySetID(setID [32]byte) InvoiceRef {
} }
} }
// PayHash returns the target invoice's payment hash. // PayHash returns the optional payment hash of the target invoice.
func (r InvoiceRef) PayHash() lntypes.Hash { //
return r.payHash // NOTE: This value may be nil.
func (r InvoiceRef) PayHash() *lntypes.Hash {
if r.payHash != nil {
hash := *r.payHash
return &hash
}
return nil
} }
// PayAddr returns the optional payment address of the target invoice. // PayAddr returns the optional payment address of the target invoice.
@ -887,19 +900,27 @@ func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex, setIDIndex kvdb.RBucket,
payHash := ref.PayHash() payHash := ref.PayHash()
payAddr := ref.PayAddr() payAddr := ref.PayAddr()
var ( getInvoiceNumByHash := func() []byte {
invoiceNumByHash = invoiceIndex.Get(payHash[:]) if payHash != nil {
invoiceNumByAddr []byte return invoiceIndex.Get(payHash[:])
)
if payAddr != nil {
// Only allow lookups for payment address if it is not a blank
// payment address, which is a special-cased value for legacy
// keysend invoices.
if *payAddr != BlankPayAddr {
invoiceNumByAddr = payAddrIndex.Get(payAddr[:])
} }
return nil
} }
getInvoiceNumByAddr := func() []byte {
if payAddr != nil {
// Only allow lookups for payment address if it is not a
// blank payment address, which is a special-cased value
// for legacy keysend invoices.
if *payAddr != BlankPayAddr {
return payAddrIndex.Get(payAddr[:])
}
}
return nil
}
invoiceNumByHash := getInvoiceNumByHash()
invoiceNumByAddr := getInvoiceNumByAddr()
switch { switch {
// If payment address and payment hash both reference an existing // If payment address and payment hash both reference an existing
@ -1745,7 +1766,7 @@ func copyInvoice(src *Invoice) *Invoice {
// updateInvoice fetches the invoice, obtains the update descriptor from the // updateInvoice fetches the invoice, obtains the update descriptor from the
// callback and applies the updates in a single db transaction. // callback and applies the updates in a single db transaction.
func (d *DB) updateInvoice(hash lntypes.Hash, invoices, func (d *DB) updateInvoice(hash *lntypes.Hash, invoices,
settleIndex, setIDIndex kvdb.RwBucket, invoiceNum []byte, settleIndex, setIDIndex kvdb.RwBucket, invoiceNum []byte,
callback InvoiceUpdateCallback) (*Invoice, error) { callback InvoiceUpdateCallback) (*Invoice, error) {
@ -1913,7 +1934,7 @@ func (d *DB) updateInvoice(hash lntypes.Hash, invoices,
} }
// updateInvoiceState validates and processes an invoice state update. // updateInvoiceState validates and processes an invoice state update.
func updateInvoiceState(invoice *Invoice, hash lntypes.Hash, func updateInvoiceState(invoice *Invoice, hash *lntypes.Hash,
update InvoiceStateUpdateDesc) error { update InvoiceStateUpdateDesc) error {
// Returning to open is never allowed from any state. // Returning to open is never allowed from any state.
@ -1962,9 +1983,14 @@ func updateInvoiceState(invoice *Invoice, hash lntypes.Hash,
switch { switch {
// If an invoice-level preimage was supplied, but the InvoiceRef
// doesn't specify a hash (e.g. AMP invoices) we fail.
case update.Preimage != nil && hash == nil:
return ErrUnexpectedInvoicePreimage
// Validate the supplied preimage for non-AMP invoices. // Validate the supplied preimage for non-AMP invoices.
case update.Preimage != nil: case update.Preimage != nil:
if update.Preimage.Hash() != hash { if update.Preimage.Hash() != *hash {
return ErrInvoicePreimageMismatch return ErrInvoicePreimageMismatch
} }
invoice.Terms.PaymentPreimage = update.Preimage invoice.Terms.PaymentPreimage = update.Preimage

@ -377,7 +377,8 @@ func (i *InvoiceRegistry) invoiceEventLoop() {
func (i *InvoiceRegistry) dispatchToSingleClients(event *invoiceEvent) { func (i *InvoiceRegistry) dispatchToSingleClients(event *invoiceEvent) {
// Dispatch to single invoice subscribers. // Dispatch to single invoice subscribers.
for _, client := range i.singleNotificationClients { for _, client := range i.singleNotificationClients {
if client.invoiceRef.PayHash() != event.hash { payHash := client.invoiceRef.PayHash()
if payHash == nil || *payHash != event.hash {
continue continue
} }
@ -524,8 +525,13 @@ func (i *InvoiceRegistry) deliverSingleBacklogEvents(
return err return err
} }
payHash := client.invoiceRef.PayHash()
if payHash == nil {
return nil
}
err = client.notify(&invoiceEvent{ err = client.notify(&invoiceEvent{
hash: client.invoiceRef.PayHash(), hash: *payHash,
invoice: &invoice, invoice: &invoice,
}) })
if err != nil { if err != nil {

@ -29,9 +29,7 @@ func TestSettleInvoice(t *testing.T) {
} }
defer subscription.Cancel() defer subscription.Cancel()
if subscription.invoiceRef.PayHash() != testInvoicePaymentHash { require.Equal(t, subscription.invoiceRef.PayHash(), &testInvoicePaymentHash)
t.Fatalf("expected subscription for provided hash")
}
// Add the invoice. // Add the invoice.
addIdx, err := ctx.registry.AddInvoice(testInvoice, testInvoicePaymentHash) addIdx, err := ctx.registry.AddInvoice(testInvoice, testInvoicePaymentHash)
@ -244,9 +242,7 @@ func testCancelInvoice(t *testing.T, gc bool) {
} }
defer subscription.Cancel() defer subscription.Cancel()
if subscription.invoiceRef.PayHash() != testInvoicePaymentHash { require.Equal(t, subscription.invoiceRef.PayHash(), &testInvoicePaymentHash)
t.Fatalf("expected subscription for provided hash")
}
// Add the invoice. // Add the invoice.
amt := lnwire.MilliSatoshi(100000) amt := lnwire.MilliSatoshi(100000)
@ -404,9 +400,7 @@ func TestSettleHoldInvoice(t *testing.T) {
} }
defer subscription.Cancel() defer subscription.Cancel()
if subscription.invoiceRef.PayHash() != testInvoicePaymentHash { require.Equal(t, subscription.invoiceRef.PayHash(), &testInvoicePaymentHash)
t.Fatalf("expected subscription for provided hash")
}
// Add the invoice. // Add the invoice.
_, err = registry.AddInvoice(testHodlInvoice, testInvoicePaymentHash) _, err = registry.AddInvoice(testHodlInvoice, testInvoicePaymentHash)
@ -1211,7 +1205,7 @@ func TestSettleInvoicePaymentAddrRequired(t *testing.T) {
defer subscription.Cancel() defer subscription.Cancel()
require.Equal( require.Equal(
t, subscription.invoiceRef.PayHash(), testInvoicePaymentHash, t, subscription.invoiceRef.PayHash(), &testInvoicePaymentHash,
) )
// Add the invoice, which requires the MPP payload to always be // Add the invoice, which requires the MPP payload to always be
@ -1287,7 +1281,7 @@ func TestSettleInvoicePaymentAddrRequiredOptionalGrace(t *testing.T) {
defer subscription.Cancel() defer subscription.Cancel()
require.Equal( require.Equal(
t, subscription.invoiceRef.PayHash(), testInvoicePaymentHash, t, subscription.invoiceRef.PayHash(), &testInvoicePaymentHash,
) )
// Add the invoice, which requires the MPP payload to always be // Add the invoice, which requires the MPP payload to always be