diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index e0ec2191..10148917 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -313,6 +313,48 @@ func TestAddDuplicatePayAddr(t *testing.T) { require.Error(t, err, ErrDuplicatePayAddr) } +// TestAddDuplicateKeysendPayAddr asserts that we permit duplicate payment +// addresses to be inserted if they are blank to support JIT legacy keysend +// invoices. +func TestAddDuplicateKeysendPayAddr(t *testing.T) { + db, cleanUp, err := makeTestDB() + defer cleanUp() + require.NoError(t, err) + + // Create two invoices with the same _blank_ payment addr. + invoice1, err := randInvoice(1000) + require.NoError(t, err) + invoice1.Terms.PaymentAddr = BlankPayAddr + + invoice2, err := randInvoice(20000) + require.NoError(t, err) + invoice2.Terms.PaymentAddr = BlankPayAddr + + // Inserting both should succeed without a duplicate payment address + // failure. + inv1Hash := invoice1.Terms.PaymentPreimage.Hash() + _, err = db.AddInvoice(invoice1, inv1Hash) + require.NoError(t, err) + + inv2Hash := invoice2.Terms.PaymentPreimage.Hash() + _, err = db.AddInvoice(invoice2, inv2Hash) + require.NoError(t, err) + + // Querying for each should succeed. Here we use hash+addr refs since + // the lookup will fail if the hash and addr point to different + // invoices, so if both succeed we can be assured they aren't included + // in the payment address index. + ref1 := InvoiceRefByHashAndAddr(inv1Hash, BlankPayAddr) + dbInv1, err := db.LookupInvoice(ref1) + require.NoError(t, err) + require.Equal(t, invoice1, &dbInv1) + + ref2 := InvoiceRefByHashAndAddr(inv2Hash, BlankPayAddr) + dbInv2, err := db.LookupInvoice(ref2) + require.NoError(t, err) + require.Equal(t, invoice2, &dbInv2) +} + // TestInvRefEquivocation asserts that retrieving or updating an invoice using // an equivocating InvoiceRef results in ErrInvRefEquivocation. func TestInvRefEquivocation(t *testing.T) { diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 07de2add..436f194e 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -21,6 +21,12 @@ var ( // preimage for this invoice is not yet known. unknownPreimage lntypes.Preimage + // BlankPayAddr is a sentinel payment address for legacy invoices. + // Invoices with this payment address are special-cased in the insertion + // logic to prevent being indexed in the payment address index, + // otherwise they would cause collisions after the first insertion. + BlankPayAddr [32]byte + // invoiceBucket is the name of the bucket within the database that // stores all data related to invoices no matter their final state. // Within the invoice bucket, each invoice is keyed by its invoice ID @@ -519,9 +525,16 @@ func (d *DB) AddInvoice(newInvoice *Invoice, paymentHash lntypes.Hash) ( return ErrDuplicateInvoice } + // Check that we aren't inserting an invoice with a duplicate + // payment address. The all-zeros payment address is + // special-cased to support legacy keysend invoices which don't + // assign one. This is safe since later we also will avoid + // indexing them and avoid collisions. payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket) - if payAddrIndex.Get(newInvoice.Terms.PaymentAddr[:]) != nil { - return ErrDuplicatePayAddr + if newInvoice.Terms.PaymentAddr != BlankPayAddr { + if payAddrIndex.Get(newInvoice.Terms.PaymentAddr[:]) != nil { + return ErrDuplicatePayAddr + } } // If the current running payment ID counter hasn't yet been @@ -679,7 +692,12 @@ func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex kvdb.RBucket, invoiceNumByAddr []byte ) if payAddr != nil { - invoiceNumByAddr = payAddrIndex.Get(payAddr[:]) + // Only allow lookups for payment address if it is not a blank + // payment address, which is a special-cased value for legacy + // keysend invoices. + if *payAddr != BlankPayAddr { + invoiceNumByAddr = payAddrIndex.Get(payAddr[:]) + } } switch { @@ -1047,9 +1065,15 @@ func putInvoice(invoices, invoiceIndex, payAddrIndex, addIndex kvdb.RwBucket, if err != nil { return 0, err } - err = payAddrIndex.Put(i.Terms.PaymentAddr[:], invoiceKey[:]) - if err != nil { - return 0, err + // Add the invoice to the payment address index, but only if the invoice + // has a non-zero payment address. The all-zero payment address is still + // in use by legacy keysend, so we special-case here to avoid + // collisions. + if i.Terms.PaymentAddr != BlankPayAddr { + err = payAddrIndex.Put(i.Terms.PaymentAddr[:], invoiceKey[:]) + if err != nil { + return 0, err + } } // Next, we'll obtain the next add invoice index (sequence diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index abb9f779..e2af66a9 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -676,6 +676,15 @@ func (i *InvoiceRegistry) processKeySend(ctx invoiceUpdateCtx) error { return errors.New("final expiry too soon") } + // The invoice database indexes all invoices by payment address, however + // legacy keysend payment do not have one. In order to avoid a new + // payment type on-disk wrt. to indexing, we'll continue to insert a + // blank payment address which is special cased in the insertion logic + // to not be indexed. In the future, once AMP is merged, this should be + // replaced by generating a random payment address on the behalf of the + // sender. + payAddr := channeldb.BlankPayAddr + // Create placeholder invoice. invoice := &channeldb.Invoice{ CreationDate: i.cfg.Clock.Now(), @@ -683,6 +692,7 @@ func (i *InvoiceRegistry) processKeySend(ctx invoiceUpdateCtx) error { FinalCltvDelta: finalCltvDelta, Value: amt, PaymentPreimage: &preimage, + PaymentAddr: payAddr, Features: features, }, } diff --git a/invoices/invoiceregistry_test.go b/invoices/invoiceregistry_test.go index fa08d5db..a9131f8a 100644 --- a/invoices/invoiceregistry_test.go +++ b/invoices/invoiceregistry_test.go @@ -725,29 +725,59 @@ func testKeySend(t *testing.T, keySendEnabled bool) { return } - // Otherwise we expect no error and a settle resolution for the htlc. - settleResolution, ok := resolution.(*HtlcSettleResolution) - if !ok { - t.Fatalf("expected settle resolution, got: %T", - resolution) + checkResolution := func(res HtlcResolution, pimg lntypes.Preimage) { + // Otherwise we expect no error and a settle res for the htlc. + settleResolution, ok := res.(*HtlcSettleResolution) + assert.True(t, ok) + assert.Equal(t, settleResolution.Preimage, pimg) } - if settleResolution.Preimage != preimage { - t.Fatalf("expected settle with matching preimage") + checkSubscription := func() { + // We expect a new invoice notification to be sent out. + newInvoice := <-allSubscriptions.NewInvoices + assert.Equal(t, newInvoice.State, channeldb.ContractOpen) + + // We expect a settled notification to be sent out. + settledInvoice := <-allSubscriptions.SettledInvoices + assert.Equal(t, settledInvoice.State, channeldb.ContractSettled) } - // We expect a new invoice notification to be sent out. - newInvoice := <-allSubscriptions.NewInvoices - if newInvoice.State != channeldb.ContractOpen { - t.Fatalf("expected state ContractOpen, but got %v", - newInvoice.State) + checkResolution(resolution, preimage) + checkSubscription() + + // Replay the same keysend payment. We expect an identical resolution, + // but no event should be generated. + resolution, err = ctx.registry.NotifyExitHopHtlc( + hash, amt, expiry, + testCurrentHeight, getCircuitKey(10), hodlChan, keySendPayload, + ) + assert.Nil(t, err) + checkResolution(resolution, preimage) + + select { + case <-allSubscriptions.NewInvoices: + t.Fatalf("replayed keysend should not generate event") + case <-time.After(time.Second): } - // We expect a settled notification to be sent out. - settledInvoice := <-allSubscriptions.SettledInvoices - if settledInvoice.State != channeldb.ContractSettled { - t.Fatalf("expected state ContractOpen, but got %v", - settledInvoice.State) + // Finally, test that we can properly fulfill a second keysend payment + // with a unique preiamge. + preimage2 := lntypes.Preimage{1, 2, 3, 4} + hash2 := preimage2.Hash() + + keySendPayload2 := &mockPayload{ + customRecords: map[uint64][]byte{ + record.KeySendType: preimage2[:], + }, } + + resolution, err = ctx.registry.NotifyExitHopHtlc( + hash2, amt, expiry, + testCurrentHeight, getCircuitKey(20), hodlChan, keySendPayload2, + ) + assert.Nil(t, err) + + checkResolution(resolution, preimage2) + checkSubscription() } // TestMppPayment tests settling of an invoice with multiple partial payments.