diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index e0ec2191..10148917 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -313,6 +313,48 @@ func TestAddDuplicatePayAddr(t *testing.T) { require.Error(t, err, ErrDuplicatePayAddr) } +// TestAddDuplicateKeysendPayAddr asserts that we permit duplicate payment +// addresses to be inserted if they are blank to support JIT legacy keysend +// invoices. +func TestAddDuplicateKeysendPayAddr(t *testing.T) { + db, cleanUp, err := makeTestDB() + defer cleanUp() + require.NoError(t, err) + + // Create two invoices with the same _blank_ payment addr. + invoice1, err := randInvoice(1000) + require.NoError(t, err) + invoice1.Terms.PaymentAddr = BlankPayAddr + + invoice2, err := randInvoice(20000) + require.NoError(t, err) + invoice2.Terms.PaymentAddr = BlankPayAddr + + // Inserting both should succeed without a duplicate payment address + // failure. + inv1Hash := invoice1.Terms.PaymentPreimage.Hash() + _, err = db.AddInvoice(invoice1, inv1Hash) + require.NoError(t, err) + + inv2Hash := invoice2.Terms.PaymentPreimage.Hash() + _, err = db.AddInvoice(invoice2, inv2Hash) + require.NoError(t, err) + + // Querying for each should succeed. Here we use hash+addr refs since + // the lookup will fail if the hash and addr point to different + // invoices, so if both succeed we can be assured they aren't included + // in the payment address index. + ref1 := InvoiceRefByHashAndAddr(inv1Hash, BlankPayAddr) + dbInv1, err := db.LookupInvoice(ref1) + require.NoError(t, err) + require.Equal(t, invoice1, &dbInv1) + + ref2 := InvoiceRefByHashAndAddr(inv2Hash, BlankPayAddr) + dbInv2, err := db.LookupInvoice(ref2) + require.NoError(t, err) + require.Equal(t, invoice2, &dbInv2) +} + // TestInvRefEquivocation asserts that retrieving or updating an invoice using // an equivocating InvoiceRef results in ErrInvRefEquivocation. func TestInvRefEquivocation(t *testing.T) { diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 07de2add..436f194e 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -21,6 +21,12 @@ var ( // preimage for this invoice is not yet known. unknownPreimage lntypes.Preimage + // BlankPayAddr is a sentinel payment address for legacy invoices. + // Invoices with this payment address are special-cased in the insertion + // logic to prevent being indexed in the payment address index, + // otherwise they would cause collisions after the first insertion. + BlankPayAddr [32]byte + // invoiceBucket is the name of the bucket within the database that // stores all data related to invoices no matter their final state. // Within the invoice bucket, each invoice is keyed by its invoice ID @@ -519,9 +525,16 @@ func (d *DB) AddInvoice(newInvoice *Invoice, paymentHash lntypes.Hash) ( return ErrDuplicateInvoice } + // Check that we aren't inserting an invoice with a duplicate + // payment address. The all-zeros payment address is + // special-cased to support legacy keysend invoices which don't + // assign one. This is safe since later we also will avoid + // indexing them and avoid collisions. payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket) - if payAddrIndex.Get(newInvoice.Terms.PaymentAddr[:]) != nil { - return ErrDuplicatePayAddr + if newInvoice.Terms.PaymentAddr != BlankPayAddr { + if payAddrIndex.Get(newInvoice.Terms.PaymentAddr[:]) != nil { + return ErrDuplicatePayAddr + } } // If the current running payment ID counter hasn't yet been @@ -679,7 +692,12 @@ func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex kvdb.RBucket, invoiceNumByAddr []byte ) if payAddr != nil { - invoiceNumByAddr = payAddrIndex.Get(payAddr[:]) + // Only allow lookups for payment address if it is not a blank + // payment address, which is a special-cased value for legacy + // keysend invoices. + if *payAddr != BlankPayAddr { + invoiceNumByAddr = payAddrIndex.Get(payAddr[:]) + } } switch { @@ -1047,9 +1065,15 @@ func putInvoice(invoices, invoiceIndex, payAddrIndex, addIndex kvdb.RwBucket, if err != nil { return 0, err } - err = payAddrIndex.Put(i.Terms.PaymentAddr[:], invoiceKey[:]) - if err != nil { - return 0, err + // Add the invoice to the payment address index, but only if the invoice + // has a non-zero payment address. The all-zero payment address is still + // in use by legacy keysend, so we special-case here to avoid + // collisions. + if i.Terms.PaymentAddr != BlankPayAddr { + err = payAddrIndex.Put(i.Terms.PaymentAddr[:], invoiceKey[:]) + if err != nil { + return 0, err + } } // Next, we'll obtain the next add invoice index (sequence