From cbf71b5452fa1d3036a43309e490787c5f7f08dc Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Thu, 21 May 2020 15:37:39 -0700 Subject: [PATCH] channeldb+invoices: use payment addr as primary index --- channeldb/db.go | 9 +++ channeldb/error.go | 8 ++ channeldb/invoice_test.go | 155 ++++++++++++++++++++++++++++++++---- channeldb/invoices.go | 108 ++++++++++++++++++++++--- htlcswitch/test_utils.go | 24 ++++-- invoices/test_utils_test.go | 12 ++- invoices/update.go | 4 + 7 files changed, 286 insertions(+), 34 deletions(-) diff --git a/channeldb/db.go b/channeldb/db.go index 9df3bccf..fe2dc149 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcwallet/walletdb" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb/kvdb" + mig "github.com/lightningnetwork/lnd/channeldb/migration" "github.com/lightningnetwork/lnd/channeldb/migration12" "github.com/lightningnetwork/lnd/channeldb/migration13" "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11" @@ -136,6 +137,13 @@ var ( number: 13, migration: migration13.MigrateMPP, }, + { + // Initialize payment address index and begin using it + // as the default index, falling back to payment hash + // index. + number: 14, + migration: mig.CreateTLB(payAddrIndexBucket), + }, } // Big endian is the preferred byte order, due to cursor scans over @@ -248,6 +256,7 @@ var topLevelBuckets = [][]byte{ forwardingLogBucket, fwdPackagesKey, invoiceBucket, + payAddrIndexBucket, nodeInfoBucket, nodeBucket, edgeBucket, diff --git a/channeldb/error.go b/channeldb/error.go index b1364fb4..97e06a14 100644 --- a/channeldb/error.go +++ b/channeldb/error.go @@ -43,6 +43,14 @@ var ( // payment hash already exists. ErrDuplicateInvoice = fmt.Errorf("invoice with payment hash already exists") + // ErrDuplicatePayAddr is returned when an invoice with the target + // payment addr already exists. + ErrDuplicatePayAddr = fmt.Errorf("invoice with payemnt addr already exists") + + // ErrInvRefEquivocation is returned when an InvoiceRef targets + // multiple, distinct invoices. + ErrInvRefEquivocation = errors.New("inv ref matches multiple invoices") + // ErrNoPaymentsCreated is returned when bucket of payments hasn't been // created. ErrNoPaymentsCreated = fmt.Errorf("there are no existing payments") diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index bd1e6a76..626a039b 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -20,16 +20,20 @@ var ( ) func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) { - var pre [32]byte + var pre, payAddr [32]byte if _, err := rand.Read(pre[:]); err != nil { return nil, err } + if _, err := rand.Read(payAddr[:]); err != nil { + return nil, err + } i := &Invoice{ CreationDate: testNow, Terms: ContractTerm{ Expiry: 4000, PaymentPreimage: pre, + PaymentAddr: payAddr, Value: value, Features: emptyFeatures, }, @@ -91,9 +95,45 @@ func TestInvoiceIsPending(t *testing.T) { } } +type invWorkflowTest struct { + name string + queryPayHash bool + queryPayAddr bool +} + +var invWorkflowTests = []invWorkflowTest{ + { + name: "unknown", + queryPayHash: false, + queryPayAddr: false, + }, + { + name: "only payhash known", + queryPayHash: true, + queryPayAddr: false, + }, + { + name: "payaddr and payhash known", + queryPayHash: true, + queryPayAddr: true, + }, +} + +// TestInvoiceWorkflow asserts the basic process of inserting, fetching, and +// updating an invoice. We assert that the flow is successful using when +// querying with various combinations of payment hash and payment address. func TestInvoiceWorkflow(t *testing.T) { t.Parallel() + for _, test := range invWorkflowTests { + test := test + t.Run(test.name, func(t *testing.T) { + testInvoiceWorkflow(t, test) + }) + } +} + +func testInvoiceWorkflow(t *testing.T, test invWorkflowTest) { db, cleanUp, err := makeTestDB() defer cleanUp() if err != nil { @@ -102,23 +142,33 @@ func TestInvoiceWorkflow(t *testing.T) { // Create a fake invoice which we'll use several times in the tests // below. - fakeInvoice := &Invoice{ - CreationDate: testNow, - Htlcs: map[CircuitKey]*InvoiceHTLC{}, + fakeInvoice, err := randInvoice(10000) + if err != nil { + t.Fatalf("unable to create invoice: %v", err) } - fakeInvoice.Memo = []byte("memo") - fakeInvoice.PaymentRequest = []byte("") - copy(fakeInvoice.Terms.PaymentPreimage[:], rev[:]) - fakeInvoice.Terms.Value = lnwire.NewMSatFromSatoshis(10000) - fakeInvoice.Terms.Features = emptyFeatures + invPayHash := fakeInvoice.Terms.PaymentPreimage.Hash() - paymentHash := fakeInvoice.Terms.PaymentPreimage.Hash() - ref := InvoiceRefByHash(paymentHash) + // Select the payment hash and payment address we will use to lookup or + // update the invoice for the remainder of the test. + var ( + payHash lntypes.Hash + payAddr *[32]byte + ref InvoiceRef + ) + switch { + case test.queryPayHash && test.queryPayAddr: + payHash = invPayHash + payAddr = &fakeInvoice.Terms.PaymentAddr + ref = InvoiceRefByHashAndAddr(payHash, *payAddr) + case test.queryPayHash: + payHash = invPayHash + ref = InvoiceRefByHash(payHash) + } // Add the invoice to the database, this should succeed as there aren't // any existing invoices within the database with the same payment // hash. - if _, err := db.AddInvoice(fakeInvoice, paymentHash); err != nil { + if _, err := db.AddInvoice(fakeInvoice, invPayHash); err != nil { t.Fatalf("unable to find invoice: %v", err) } @@ -126,8 +176,11 @@ func TestInvoiceWorkflow(t *testing.T) { // database. It should be found, and the invoice returned should be // identical to the one created above. dbInvoice, err := db.LookupInvoice(ref) - if err != nil { - t.Fatalf("unable to find invoice: %v", err) + if !test.queryPayAddr && !test.queryPayHash { + if err != ErrInvoiceNotFound { + t.Fatalf("invoice should not exist: %v", err) + } + return } if !reflect.DeepEqual(*fakeInvoice, dbInvoice) { t.Fatalf("invoice fetched from db doesn't match original %v vs %v", @@ -174,7 +227,7 @@ func TestInvoiceWorkflow(t *testing.T) { // Attempt to insert generated above again, this should fail as // duplicates are rejected by the processing logic. - if _, err := db.AddInvoice(fakeInvoice, paymentHash); err != ErrDuplicateInvoice { + if _, err := db.AddInvoice(fakeInvoice, payHash); err != ErrDuplicateInvoice { t.Fatalf("invoice insertion should fail due to duplication, "+ "instead %v", err) } @@ -232,6 +285,70 @@ func TestInvoiceWorkflow(t *testing.T) { } } +// TestAddDuplicatePayAddr asserts that the payment addresses of inserted +// invoices are unique. +func TestAddDuplicatePayAddr(t *testing.T) { + db, cleanUp, err := makeTestDB() + defer cleanUp() + assert.Nil(t, err) + + // Create two invoices with the same payment addr. + invoice1, err := randInvoice(1000) + assert.Nil(t, err) + + invoice2, err := randInvoice(20000) + assert.Nil(t, err) + invoice2.Terms.PaymentAddr = invoice1.Terms.PaymentAddr + + // First insert should succeed. + inv1Hash := invoice1.Terms.PaymentPreimage.Hash() + _, err = db.AddInvoice(invoice1, inv1Hash) + assert.Nil(t, err) + + // Second insert should fail with duplicate payment addr. + inv2Hash := invoice2.Terms.PaymentPreimage.Hash() + _, err = db.AddInvoice(invoice2, inv2Hash) + assert.Equal(t, ErrDuplicatePayAddr, err) +} + +// TestInvRefEquivocation asserts that retrieving or updating an invoice using +// an equivocating InvoiceRef results in ErrInvRefEquivocation. +func TestInvRefEquivocation(t *testing.T) { + db, cleanUp, err := makeTestDB() + defer cleanUp() + assert.Nil(t, err) + + // Add two random invoices. + invoice1, err := randInvoice(1000) + assert.Nil(t, err) + + inv1Hash := invoice1.Terms.PaymentPreimage.Hash() + _, err = db.AddInvoice(invoice1, inv1Hash) + assert.Nil(t, err) + + invoice2, err := randInvoice(2000) + assert.Nil(t, err) + + inv2Hash := invoice2.Terms.PaymentPreimage.Hash() + _, err = db.AddInvoice(invoice2, inv2Hash) + assert.Nil(t, err) + + // Now, query using invoice 1's payment address, but invoice 2's payment + // hash. We expect an error since the invref points to multiple + // invoices. + ref := InvoiceRefByHashAndAddr(inv2Hash, invoice1.Terms.PaymentAddr) + _, err = db.LookupInvoice(ref) + assert.Equal(t, ErrInvRefEquivocation, err) + + // The same error should be returned when updating an equivocating + // reference. + nop := func(_ *Invoice) (*InvoiceUpdateDesc, error) { + return nil, nil + } + _, err = db.UpdateInvoice(ref, nop) + assert.Equal(t, ErrInvRefEquivocation, err) +} + // TestInvoiceCancelSingleHtlc tests that a single htlc can be canceled on the // invoice. func TestInvoiceCancelSingleHtlc(t *testing.T) { @@ -991,9 +1108,17 @@ func TestCustomRecords(t *testing.T) { // InvoiceRef depending on the constructor used. func TestInvoiceRef(t *testing.T) { payHash := lntypes.Hash{0x01} + payAddr := [32]byte{0x02} // An InvoiceRef by hash should return the provided hash and a nil // payment addr. refByHash := InvoiceRefByHash(payHash) assert.Equal(t, payHash, refByHash.PayHash()) + assert.Equal(t, (*[32]byte)(nil), refByHash.PayAddr()) + + // An InvoiceRef by hash and addr should return the payment hash and + // payment addr passed to the constructor. + refByHashAndAddr := InvoiceRefByHashAndAddr(payHash, payAddr) + assert.Equal(t, payHash, refByHashAndAddr.PayHash()) + assert.Equal(t, &payAddr, refByHashAndAddr.PayAddr()) } diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 83c8de13..3bb005f0 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -37,6 +37,16 @@ var ( // maps: payHash => invoiceKey invoiceIndexBucket = []byte("paymenthashes") + // payAddrIndexBucket is the name of the top-level bucket that maps + // payment addresses to their invoice number. This can be used + // to efficiently query or update non-legacy invoices. Note that legacy + // invoices will not be included in this index since they all have the + // same, all-zero payment address, however all newly generated invoices + // will end up in this index. + // + // maps: payAddr => invoiceKey + payAddrIndexBucket = []byte("pay-addr-index") + // numInvoicesKey is the name of key which houses the auto-incrementing // invoice ID which is essentially used as a primary key. With each // invoice inserted, the primary key is incremented by one. This key is @@ -142,12 +152,23 @@ const ( amtPaidType tlv.Type = 13 ) -// InvoiceRef is an identifier for invoices supporting queries by payment hash. +// InvoiceRef is a composite identifier for invoices. Invoices can be referenced +// by various combinations of payment hash and payment addr, in certain contexts +// only some of these are known. An InvoiceRef and its constructors thus +// encapsulate the valid combinations of query parameters that can be supplied +// to LookupInvoice and UpdateInvoice. type InvoiceRef struct { // payHash is the payment hash of the target invoice. All invoices are // currently indexed by payment hash. This value will be used as a // fallback when no payment address is known. payHash lntypes.Hash + + // payAddr is the payment addr of the target invoice. Newer invoices + // (0.11 and up) are indexed by payment address in addition to payment + // hash, but pre 0.8 invoices do not have one at all. When this value is + // known it will be used as the primary identifier, falling back to + // payHash if no value is known. + payAddr *[32]byte } // InvoiceRefByHash creates an InvoiceRef that queries for an invoice only by @@ -158,13 +179,39 @@ func InvoiceRefByHash(payHash lntypes.Hash) InvoiceRef { } } +// InvoiceRefByHashAndAddr creates an InvoiceRef that first queries for an +// invoice by the provided payment address, falling back to the payment hash if +// the payment address is unknown. +func InvoiceRefByHashAndAddr(payHash lntypes.Hash, + payAddr [32]byte) InvoiceRef { + + return InvoiceRef{ + payHash: payHash, + payAddr: &payAddr, + } +} + // PayHash returns the target invoice's payment hash. func (r InvoiceRef) PayHash() lntypes.Hash { return r.payHash } +// PayAddr returns the optional payment address of the target invoice. +// +// NOTE: This value may be nil. +func (r InvoiceRef) PayAddr() *[32]byte { + if r.payAddr != nil { + addr := *r.payAddr + return &addr + } + return nil +} + // String returns a human-readable representation of an InvoiceRef. func (r InvoiceRef) String() string { + if r.payAddr != nil { + return fmt.Sprintf("(pay_hash=%v, pay_addr=%x)", r.payHash, *r.payAddr) + } return fmt.Sprintf("(pay_hash=%v)", r.payHash) } @@ -458,6 +505,11 @@ func (d *DB) AddInvoice(newInvoice *Invoice, paymentHash lntypes.Hash) ( return ErrDuplicateInvoice } + payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket) + if payAddrIndex.Get(newInvoice.Terms.PaymentAddr[:]) != nil { + return ErrDuplicatePayAddr + } + // If the current running payment ID counter hasn't yet been // created, then create it now. var invoiceNum uint32 @@ -474,8 +526,8 @@ func (d *DB) AddInvoice(newInvoice *Invoice, paymentHash lntypes.Hash) ( } newIndex, err := putInvoice( - invoices, invoiceIndex, addIndex, newInvoice, invoiceNum, - paymentHash, + invoices, invoiceIndex, payAddrIndex, addIndex, + newInvoice, invoiceNum, paymentHash, ) if err != nil { return err @@ -575,11 +627,12 @@ func (d *DB) LookupInvoice(ref InvoiceRef) (Invoice, error) { if invoiceIndex == nil { return ErrNoInvoicesCreated } + payAddrIndex := tx.ReadBucket(payAddrIndexBucket) // Retrieve the invoice number for this invoice using the // provided invoice reference. invoiceNum, err := fetchInvoiceNumByRef( - invoiceIndex, ref, + invoiceIndex, payAddrIndex, ref, ) if err != nil { return err @@ -603,18 +656,44 @@ func (d *DB) LookupInvoice(ref InvoiceRef) (Invoice, error) { } // fetchInvoiceNumByRef retrieve the invoice number for the provided invoice -// reference. -func fetchInvoiceNumByRef(invoiceIndex kvdb.ReadBucket, +// reference. The payment address will be treated as the primary key, falling +// back to the payment hash if nothing is found for the payment address. An +// error is returned if the invoice is not found. +func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex kvdb.ReadBucket, ref InvoiceRef) ([]byte, error) { payHash := ref.PayHash() + payAddr := ref.PayAddr() - invoiceNum := invoiceIndex.Get(payHash[:]) - if invoiceNum == nil { - return nil, ErrInvoiceNotFound + var ( + invoiceNumByHash = invoiceIndex.Get(payHash[:]) + invoiceNumByAddr []byte + ) + if payAddr != nil { + invoiceNumByAddr = payAddrIndex.Get(payAddr[:]) } - return invoiceNum, nil + switch { + + // If payment address and payment hash both reference an existing + // invoice, ensure they reference the _same_ invoice. + case invoiceNumByAddr != nil && invoiceNumByHash != nil: + if !bytes.Equal(invoiceNumByAddr, invoiceNumByHash) { + return nil, ErrInvRefEquivocation + } + + return invoiceNumByAddr, nil + + // If we were only able to reference the invoice by hash, return the + // corresponding invoice number. This can happen when no payment address + // was provided, or if it didn't match anything in our records. + case invoiceNumByHash != nil: + return invoiceNumByHash, nil + + // Otherwise we don't know of the target invoice. + default: + return nil, ErrInvoiceNotFound + } } // InvoiceWithPaymentHash is used to store an invoice and its corresponding @@ -888,11 +967,12 @@ func (d *DB) UpdateInvoice(ref InvoiceRef, if err != nil { return err } + payAddrIndex := tx.ReadBucket(payAddrIndexBucket) // Retrieve the invoice number for this invoice using the // provided invoice reference. invoiceNum, err := fetchInvoiceNumByRef( - invoiceIndex, ref, + invoiceIndex, payAddrIndex, ref, ) if err != nil { return err @@ -971,7 +1051,7 @@ func (d *DB) InvoicesSettledSince(sinceSettleIndex uint64) ([]Invoice, error) { return settledInvoices, nil } -func putInvoice(invoices, invoiceIndex, addIndex kvdb.RwBucket, +func putInvoice(invoices, invoiceIndex, payAddrIndex, addIndex kvdb.RwBucket, i *Invoice, invoiceNum uint32, paymentHash lntypes.Hash) ( uint64, error) { @@ -996,6 +1076,10 @@ func putInvoice(invoices, invoiceIndex, addIndex kvdb.RwBucket, if err != nil { return 0, err } + err = payAddrIndex.Put(i.Terms.PaymentAddr[:], invoiceKey[:]) + if err != nil { + return 0, err + } // Next, we'll obtain the next add invoice index (sequence // number), so we can properly place this invoice within this diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index 2da6e18b..429963d2 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -2,7 +2,7 @@ package htlcswitch import ( "bytes" - "crypto/rand" + crand "crypto/rand" "crypto/sha256" "encoding/binary" "fmt" @@ -137,7 +137,7 @@ func generateRandomBytes(n int) ([]byte, error) { // TODO(roasbeef): should use counter in tests (atomic) rather than // this - _, err := rand.Read(b[:]) + _, err := crand.Read(b) // Note that Err == nil only if we read len(b) bytes. if err != nil { return nil, err @@ -547,7 +547,7 @@ func getChanID(msg lnwire.Message) (lnwire.ChannelID, error) { // invoice which should be added by destination peer. func generatePaymentWithPreimage(invoiceAmt, htlcAmt lnwire.MilliSatoshi, timelock uint32, blob [lnwire.OnionPacketSize]byte, - preimage, rhash [32]byte) (*channeldb.Invoice, *lnwire.UpdateAddHTLC, + preimage, rhash, payAddr [32]byte) (*channeldb.Invoice, *lnwire.UpdateAddHTLC, uint64, error) { // Create the db invoice. Normally the payment requests needs to be set, @@ -562,6 +562,7 @@ func generatePaymentWithPreimage(invoiceAmt, htlcAmt lnwire.MilliSatoshi, FinalCltvDelta: testInvoiceCltvExpiry, Value: invoiceAmt, PaymentPreimage: preimage, + PaymentAddr: payAddr, Features: lnwire.NewFeatureVector( nil, lnwire.Features, ), @@ -598,8 +599,16 @@ func generatePayment(invoiceAmt, htlcAmt lnwire.MilliSatoshi, timelock uint32, copy(preimage[:], r) rhash := sha256.Sum256(preimage[:]) + + var payAddr [sha256.Size]byte + r, err = generateRandomBytes(sha256.Size) + if err != nil { + return nil, nil, 0, err + } + copy(payAddr[:], r) + return generatePaymentWithPreimage( - invoiceAmt, htlcAmt, timelock, blob, preimage, rhash, + invoiceAmt, htlcAmt, timelock, blob, preimage, rhash, payAddr, ) } @@ -1328,10 +1337,15 @@ func (n *twoHopNetwork) makeHoldPayment(sendingPeer, receivingPeer lnpeer.Peer, rhash := preimage.Hash() + var payAddr [32]byte + if _, err := crand.Read(payAddr[:]); err != nil { + panic(err) + } + // Generate payment: invoice and htlc. invoice, htlc, pid, err := generatePaymentWithPreimage( invoiceAmt, htlcAmt, timelock, blob, - channeldb.UnknownPreimage, rhash, + channeldb.UnknownPreimage, rhash, payAddr, ) if err != nil { paymentErr <- err diff --git a/invoices/test_utils_test.go b/invoices/test_utils_test.go index cf0f14ea..8d98b132 100644 --- a/invoices/test_utils_test.go +++ b/invoices/test_utils_test.go @@ -1,6 +1,7 @@ package invoices import ( + "crypto/rand" "encoding/binary" "encoding/hex" "fmt" @@ -198,14 +199,20 @@ func newTestInvoice(t *testing.T, preimage lntypes.Preimage, expiry = time.Hour } + var payAddr [32]byte + if _, err := rand.Read(payAddr[:]); err != nil { + t.Fatalf("unable to generate payment addr: %v", err) + } + rawInvoice, err := zpay32.NewInvoice( testNetParams, preimage.Hash(), timestamp, zpay32.Amount(testInvoiceAmount), zpay32.Description(testInvoiceDescription), - zpay32.Expiry(expiry)) - + zpay32.Expiry(expiry), + zpay32.PaymentAddr(payAddr), + ) if err != nil { t.Fatalf("Error while creating new invoice: %v", err) } @@ -219,6 +226,7 @@ func newTestInvoice(t *testing.T, preimage lntypes.Preimage, return &channeldb.Invoice{ Terms: channeldb.ContractTerm{ PaymentPreimage: preimage, + PaymentAddr: payAddr, Value: testInvoiceAmount, Expiry: expiry, Features: testFeatures, diff --git a/invoices/update.go b/invoices/update.go index 4680b3cd..62522378 100644 --- a/invoices/update.go +++ b/invoices/update.go @@ -25,6 +25,10 @@ type invoiceUpdateCtx struct { // invoiceRef returns an identifier that can be used to lookup or update the // invoice this HTLC is targeting. func (i *invoiceUpdateCtx) invoiceRef() channeldb.InvoiceRef { + if i.mpp != nil { + payAddr := i.mpp.PaymentAddr() + return channeldb.InvoiceRefByHashAndAddr(i.hash, payAddr) + } return channeldb.InvoiceRefByHash(i.hash) }