From d343575104e1dadbe0a0296d4b61f7660fb5ff8a Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Thu, 21 May 2020 15:35:51 -0700 Subject: [PATCH 1/8] channeldb/migtest: log migration failure instead of failing This line was incorrectly moved when the migtest package was created for migration 12. This PR introduces a negative test for CreateTLB which surfaced this. --- channeldb/migtest/migtest.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/channeldb/migtest/migtest.go b/channeldb/migtest/migtest.go index 0b8e14f0..09edc033 100644 --- a/channeldb/migtest/migtest.go +++ b/channeldb/migtest/migtest.go @@ -74,7 +74,7 @@ func ApplyMigration(t *testing.T, // Apply migration. err = kvdb.Update(cdb, migrationFunc) if err != nil { - t.Fatal(err) + t.Logf("migration error: %v", err) } } From 24cce7a6ec3edad7601f51b9923ee20571c1b9f9 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Thu, 21 May 2020 15:36:16 -0700 Subject: [PATCH 2/8] channeldb: consolidate top-level bucket create/wipe --- channeldb/db.go | 99 +++++++++++++------------------------------------ 1 file changed, 26 insertions(+), 73 deletions(-) diff --git a/channeldb/db.go b/channeldb/db.go index b6c7daf5..9df3bccf 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -242,48 +242,31 @@ func (d *DB) Path() string { return d.dbPath } +var topLevelBuckets = [][]byte{ + openChannelBucket, + closedChannelBucket, + forwardingLogBucket, + fwdPackagesKey, + invoiceBucket, + nodeInfoBucket, + nodeBucket, + edgeBucket, + edgeIndexBucket, + graphMetaBucket, + metaBucket, +} + // Wipe completely deletes all saved state within all used buckets within the // database. The deletion is done in a single transaction, therefore this // operation is fully atomic. func (d *DB) Wipe() error { return kvdb.Update(d, func(tx kvdb.RwTx) error { - err := tx.DeleteTopLevelBucket(openChannelBucket) - if err != nil && err != kvdb.ErrBucketNotFound { - return err + for _, tlb := range topLevelBuckets { + err := tx.DeleteTopLevelBucket(tlb) + if err != nil && err != kvdb.ErrBucketNotFound { + return err + } } - - err = tx.DeleteTopLevelBucket(closedChannelBucket) - if err != nil && err != kvdb.ErrBucketNotFound { - return err - } - - err = tx.DeleteTopLevelBucket(invoiceBucket) - if err != nil && err != kvdb.ErrBucketNotFound { - return err - } - - err = tx.DeleteTopLevelBucket(nodeInfoBucket) - if err != nil && err != kvdb.ErrBucketNotFound { - return err - } - - err = tx.DeleteTopLevelBucket(nodeBucket) - if err != nil && err != kvdb.ErrBucketNotFound { - return err - } - err = tx.DeleteTopLevelBucket(edgeBucket) - if err != nil && err != kvdb.ErrBucketNotFound { - return err - } - err = tx.DeleteTopLevelBucket(edgeIndexBucket) - if err != nil && err != kvdb.ErrBucketNotFound { - return err - } - err = tx.DeleteTopLevelBucket(graphMetaBucket) - if err != nil && err != kvdb.ErrBucketNotFound { - return err - } - return nil }) } @@ -301,33 +284,13 @@ func initChannelDB(db kvdb.Backend) error { return nil } - if _, err := tx.CreateTopLevelBucket(openChannelBucket); err != nil { - return err - } - if _, err := tx.CreateTopLevelBucket(closedChannelBucket); err != nil { - return err + for _, tlb := range topLevelBuckets { + if _, err := tx.CreateTopLevelBucket(tlb); err != nil { + return err + } } - if _, err := tx.CreateTopLevelBucket(forwardingLogBucket); err != nil { - return err - } - - if _, err := tx.CreateTopLevelBucket(fwdPackagesKey); err != nil { - return err - } - - if _, err := tx.CreateTopLevelBucket(invoiceBucket); err != nil { - return err - } - - if _, err := tx.CreateTopLevelBucket(nodeInfoBucket); err != nil { - return err - } - - nodes, err := tx.CreateTopLevelBucket(nodeBucket) - if err != nil { - return err - } + nodes := tx.ReadWriteBucket(nodeBucket) _, err = nodes.CreateBucket(aliasIndexBucket) if err != nil { return err @@ -337,10 +300,7 @@ func initChannelDB(db kvdb.Backend) error { return err } - edges, err := tx.CreateTopLevelBucket(edgeBucket) - if err != nil { - return err - } + edges := tx.ReadWriteBucket(edgeBucket) if _, err := edges.CreateBucket(edgeIndexBucket); err != nil { return err } @@ -354,19 +314,12 @@ func initChannelDB(db kvdb.Backend) error { return err } - graphMeta, err := tx.CreateTopLevelBucket(graphMetaBucket) - if err != nil { - return err - } + graphMeta := tx.ReadWriteBucket(graphMetaBucket) _, err = graphMeta.CreateBucket(pruneLogBucket) if err != nil { return err } - if _, err := tx.CreateTopLevelBucket(metaBucket); err != nil { - return err - } - meta.DbVersionNumber = getLatestDBVersion(dbVersions) return putMeta(meta, tx) }) From e80e21d1a8102c45223ad7a0f6fe6d41176ca49e Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Thu, 21 May 2020 15:36:26 -0700 Subject: [PATCH 3/8] channeldb/migration: add generic CreateTLB migration This commit creates a generic migration for creating top-level buckets. --- channeldb/log.go | 2 + channeldb/migration/create_tlb.go | 27 ++++++++++++ channeldb/migration/create_tlb_test.go | 57 ++++++++++++++++++++++++++ channeldb/migration/log.go | 12 ++++++ 4 files changed, 98 insertions(+) create mode 100644 channeldb/migration/create_tlb.go create mode 100644 channeldb/migration/create_tlb_test.go create mode 100644 channeldb/migration/log.go diff --git a/channeldb/log.go b/channeldb/log.go index 7490c6bf..f59426f0 100644 --- a/channeldb/log.go +++ b/channeldb/log.go @@ -3,6 +3,7 @@ package channeldb import ( "github.com/btcsuite/btclog" "github.com/lightningnetwork/lnd/build" + mig "github.com/lightningnetwork/lnd/channeldb/migration" "github.com/lightningnetwork/lnd/channeldb/migration12" "github.com/lightningnetwork/lnd/channeldb/migration13" "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11" @@ -28,6 +29,7 @@ func DisableLog() { // using btclog. func UseLogger(logger btclog.Logger) { log = logger + mig.UseLogger(logger) migration_01_to_11.UseLogger(logger) migration12.UseLogger(logger) migration13.UseLogger(logger) diff --git a/channeldb/migration/create_tlb.go b/channeldb/migration/create_tlb.go new file mode 100644 index 00000000..7c31ec92 --- /dev/null +++ b/channeldb/migration/create_tlb.go @@ -0,0 +1,27 @@ +package migration + +import ( + "fmt" + + "github.com/lightningnetwork/lnd/channeldb/kvdb" +) + +// CreateTLB creates a new top-level bucket with the passed bucket identifier. +func CreateTLB(bucket []byte) func(kvdb.RwTx) error { + return func(tx kvdb.RwTx) error { + log.Infof("Creating top-level bucket: \"%s\" ...", bucket) + + if tx.ReadBucket(bucket) != nil { + return fmt.Errorf("top-level bucket \"%s\" "+ + "already exists", bucket) + } + + _, err := tx.CreateTopLevelBucket(bucket) + if err != nil { + return err + } + + log.Infof("Created top-level bucket: \"%s\"", bucket) + return nil + } +} diff --git a/channeldb/migration/create_tlb_test.go b/channeldb/migration/create_tlb_test.go new file mode 100644 index 00000000..2f422136 --- /dev/null +++ b/channeldb/migration/create_tlb_test.go @@ -0,0 +1,57 @@ +package migration_test + +import ( + "fmt" + "testing" + + "github.com/lightningnetwork/lnd/channeldb/kvdb" + "github.com/lightningnetwork/lnd/channeldb/migration" + "github.com/lightningnetwork/lnd/channeldb/migtest" +) + +// TestCreateTLB asserts that a CreateTLB properly initializes a new top-level +// bucket, and that it succeeds even if the bucket already exists. It would +// probably be better if the latter failed, but the kvdb abstraction doesn't +// support this. +func TestCreateTLB(t *testing.T) { + newBucket := []byte("hello") + + tests := []struct { + name string + beforeMigration func(kvdb.RwTx) error + shouldFail bool + }{ + { + name: "already exists", + beforeMigration: func(tx kvdb.RwTx) error { + _, err := tx.CreateTopLevelBucket(newBucket) + return err + }, + shouldFail: true, + }, + { + name: "does not exist", + beforeMigration: func(_ kvdb.RwTx) error { return nil }, + shouldFail: false, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + migtest.ApplyMigration( + t, + test.beforeMigration, + func(tx kvdb.RwTx) error { + if tx.ReadBucket(newBucket) != nil { + return nil + } + return fmt.Errorf("bucket \"%s\" not "+ + "created", newBucket) + }, + migration.CreateTLB(newBucket), + test.shouldFail, + ) + }) + } +} diff --git a/channeldb/migration/log.go b/channeldb/migration/log.go new file mode 100644 index 00000000..5085596d --- /dev/null +++ b/channeldb/migration/log.go @@ -0,0 +1,12 @@ +package migration + +import "github.com/btcsuite/btclog" + +// log is a logger that is initialized as disabled. This means the package will +// not perform any logging by default until a logger is set. +var log = btclog.Disabled + +// UseLogger uses a specified Logger to output package logging info. +func UseLogger(logger btclog.Logger) { + log = logger +} From 2799202fd968c5eca45591c6b1dda93df95e875d Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Thu, 21 May 2020 15:36:44 -0700 Subject: [PATCH 4/8] invoices/invoiceregistry: rename updateCtx to ctx --- invoices/invoiceregistry.go | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index 662a8a82..4c1d9b2a 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -733,11 +733,9 @@ func (i *InvoiceRegistry) NotifyExitHopHtlc(rHash lntypes.Hash, circuitKey channeldb.CircuitKey, hodlChan chan<- interface{}, payload Payload) (HtlcResolution, error) { - mpp := payload.MultiPath() - // Create the update context containing the relevant details of the // incoming htlc. - updateCtx := invoiceUpdateCtx{ + ctx := invoiceUpdateCtx{ hash: rHash, circuitKey: circuitKey, amtPaid: amtPaid, @@ -745,16 +743,16 @@ func (i *InvoiceRegistry) NotifyExitHopHtlc(rHash lntypes.Hash, currentHeight: currentHeight, finalCltvRejectDelta: i.cfg.FinalCltvRejectDelta, customRecords: payload.CustomRecords(), - mpp: mpp, + mpp: payload.MultiPath(), } // Process keysend if present. Do this outside of the lock, because // AddInvoice obtains its own lock. This is no problem, because the // operation is idempotent. if i.cfg.AcceptKeySend { - err := i.processKeySend(updateCtx) + err := i.processKeySend(ctx) if err != nil { - updateCtx.log(fmt.Sprintf("keysend error: %v", err)) + ctx.log(fmt.Sprintf("keysend error: %v", err)) return NewFailResolution( circuitKey, currentHeight, ResultKeySendError, @@ -764,7 +762,7 @@ func (i *InvoiceRegistry) NotifyExitHopHtlc(rHash lntypes.Hash, // Execute locked notify exit hop logic. i.Lock() - resolution, err := i.notifyExitHopHtlcLocked(&updateCtx, hodlChan) + resolution, err := i.notifyExitHopHtlcLocked(&ctx, hodlChan) i.Unlock() if err != nil { return nil, err From 3522f09a087f4e6fbe3fcc40fd055b42cc36c389 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Thu, 21 May 2020 15:37:10 -0700 Subject: [PATCH 5/8] channeldb+invoices: track invoices by InvoiceRef --- channeldb/invoice_test.go | 64 ++++++++++++++++---------- channeldb/invoices.go | 78 +++++++++++++++++++++++++------- invoices/invoiceregistry.go | 69 +++++++++++++++------------- invoices/invoiceregistry_test.go | 6 +-- invoices/update.go | 10 +++- 5 files changed, 151 insertions(+), 76 deletions(-) diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index 26807894..bd1e6a76 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -11,6 +11,7 @@ import ( "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" + "github.com/stretchr/testify/assert" ) var ( @@ -112,6 +113,7 @@ func TestInvoiceWorkflow(t *testing.T) { fakeInvoice.Terms.Features = emptyFeatures paymentHash := fakeInvoice.Terms.PaymentPreimage.Hash() + ref := InvoiceRefByHash(paymentHash) // Add the invoice to the database, this should succeed as there aren't // any existing invoices within the database with the same payment @@ -123,7 +125,7 @@ func TestInvoiceWorkflow(t *testing.T) { // Attempt to retrieve the invoice which was just added to the // database. It should be found, and the invoice returned should be // identical to the one created above. - dbInvoice, err := db.LookupInvoice(paymentHash) + dbInvoice, err := db.LookupInvoice(ref) if err != nil { t.Fatalf("unable to find invoice: %v", err) } @@ -144,11 +146,11 @@ func TestInvoiceWorkflow(t *testing.T) { // now have the settled bit toggle to true and a non-default // SettledDate payAmt := fakeInvoice.Terms.Value * 2 - _, err = db.UpdateInvoice(paymentHash, getUpdateInvoice(payAmt)) + _, err = db.UpdateInvoice(ref, getUpdateInvoice(payAmt)) if err != nil { t.Fatalf("unable to settle invoice: %v", err) } - dbInvoice2, err := db.LookupInvoice(paymentHash) + dbInvoice2, err := db.LookupInvoice(ref) if err != nil { t.Fatalf("unable to fetch invoice: %v", err) } @@ -180,7 +182,9 @@ func TestInvoiceWorkflow(t *testing.T) { // Attempt to look up a non-existent invoice, this should also fail but // with a "not found" error. var fakeHash [32]byte - if _, err := db.LookupInvoice(fakeHash); err != ErrInvoiceNotFound { + fakeRef := InvoiceRefByHash(fakeHash) + _, err = db.LookupInvoice(fakeRef) + if err != ErrInvoiceNotFound { t.Fatalf("lookup should have failed, instead %v", err) } @@ -256,7 +260,9 @@ func TestInvoiceCancelSingleHtlc(t *testing.T) { Amt: 500, CustomRecords: make(record.CustomSet), } - invoice, err := db.UpdateInvoice(paymentHash, + + ref := InvoiceRefByHash(paymentHash) + invoice, err := db.UpdateInvoice(ref, func(invoice *Invoice) (*InvoiceUpdateDesc, error) { return &InvoiceUpdateDesc{ AddHtlcs: map[CircuitKey]*HtlcAcceptDesc{ @@ -275,13 +281,14 @@ func TestInvoiceCancelSingleHtlc(t *testing.T) { } // Cancel the htlc again. - invoice, err = db.UpdateInvoice(paymentHash, func(invoice *Invoice) (*InvoiceUpdateDesc, error) { - return &InvoiceUpdateDesc{ - CancelHtlcs: map[CircuitKey]struct{}{ - key: {}, - }, - }, nil - }) + invoice, err = db.UpdateInvoice(ref, + func(invoice *Invoice) (*InvoiceUpdateDesc, error) { + return &InvoiceUpdateDesc{ + CancelHtlcs: map[CircuitKey]struct{}{ + key: {}, + }, + }, nil + }) if err != nil { t.Fatalf("unable to cancel htlc: %v", err) } @@ -380,8 +387,9 @@ func TestInvoiceAddTimeSeries(t *testing.T) { paymentHash := invoice.Terms.PaymentPreimage.Hash() + ref := InvoiceRefByHash(paymentHash) _, err := db.UpdateInvoice( - paymentHash, getUpdateInvoice(invoice.Terms.Value), + ref, getUpdateInvoice(invoice.Terms.Value), ) if err != nil { t.Fatalf("unable to settle invoice: %v", err) @@ -570,9 +578,8 @@ func TestDuplicateSettleInvoice(t *testing.T) { } // With the invoice in the DB, we'll now attempt to settle the invoice. - dbInvoice, err := db.UpdateInvoice( - payHash, getUpdateInvoice(amt), - ) + ref := InvoiceRefByHash(payHash) + dbInvoice, err := db.UpdateInvoice(ref, getUpdateInvoice(amt)) if err != nil { t.Fatalf("unable to settle invoice: %v", err) } @@ -601,9 +608,7 @@ func TestDuplicateSettleInvoice(t *testing.T) { // If we try to settle the invoice again, then we should get the very // same invoice back, but with an error this time. - dbInvoice, err = db.UpdateInvoice( - payHash, getUpdateInvoice(amt), - ) + dbInvoice, err = db.UpdateInvoice(ref, getUpdateInvoice(amt)) if err != ErrInvoiceAlreadySettled { t.Fatalf("expected ErrInvoiceAlreadySettled") } @@ -653,9 +658,8 @@ func TestQueryInvoices(t *testing.T) { // We'll only settle half of all invoices created. if i%2 == 0 { - _, err := db.UpdateInvoice( - paymentHash, getUpdateInvoice(amt), - ) + ref := InvoiceRefByHash(paymentHash) + _, err := db.UpdateInvoice(ref, getUpdateInvoice(amt)) if err != nil { t.Fatalf("unable to settle invoice: %v", err) } @@ -951,7 +955,8 @@ func TestCustomRecords(t *testing.T) { 100001: []byte{1, 2}, } - _, err = db.UpdateInvoice(paymentHash, + ref := InvoiceRefByHash(paymentHash) + _, err = db.UpdateInvoice(ref, func(invoice *Invoice) (*InvoiceUpdateDesc, error) { return &InvoiceUpdateDesc{ AddHtlcs: map[CircuitKey]*HtlcAcceptDesc{ @@ -969,7 +974,7 @@ func TestCustomRecords(t *testing.T) { // Retrieve the invoice from that database and verify that the custom // records are present. - dbInvoice, err := db.LookupInvoice(paymentHash) + dbInvoice, err := db.LookupInvoice(ref) if err != nil { t.Fatalf("unable to lookup invoice: %v", err) } @@ -981,3 +986,14 @@ func TestCustomRecords(t *testing.T) { t.Fatalf("invalid custom records") } } + +// TestInvoiceRef asserts that the proper identifiers are returned from an +// InvoiceRef depending on the constructor used. +func TestInvoiceRef(t *testing.T) { + payHash := lntypes.Hash{0x01} + + // An InvoiceRef by hash should return the provided hash and a nil + // payment addr. + refByHash := InvoiceRefByHash(payHash) + assert.Equal(t, payHash, refByHash.PayHash()) +} diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 23c10dc6..83c8de13 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -142,6 +142,32 @@ const ( amtPaidType tlv.Type = 13 ) +// InvoiceRef is an identifier for invoices supporting queries by payment hash. +type InvoiceRef struct { + // payHash is the payment hash of the target invoice. All invoices are + // currently indexed by payment hash. This value will be used as a + // fallback when no payment address is known. + payHash lntypes.Hash +} + +// InvoiceRefByHash creates an InvoiceRef that queries for an invoice only by +// its payment hash. +func InvoiceRefByHash(payHash lntypes.Hash) InvoiceRef { + return InvoiceRef{ + payHash: payHash, + } +} + +// PayHash returns the target invoice's payment hash. +func (r InvoiceRef) PayHash() lntypes.Hash { + return r.payHash +} + +// String returns a human-readable representation of an InvoiceRef. +func (r InvoiceRef) String() string { + return fmt.Sprintf("(pay_hash=%v)", r.payHash) +} + // ContractState describes the state the invoice is in. type ContractState uint8 @@ -538,7 +564,7 @@ func (d *DB) InvoicesAddedSince(sinceAddIndex uint64) ([]Invoice, error) { // full invoice is returned. Before setting the incoming HTLC, the values // SHOULD be checked to ensure the payer meets the agreed upon contractual // terms of the payment. -func (d *DB) LookupInvoice(paymentHash [32]byte) (Invoice, error) { +func (d *DB) LookupInvoice(ref InvoiceRef) (Invoice, error) { var invoice Invoice err := kvdb.View(d, func(tx kvdb.ReadTx) error { invoices := tx.ReadBucket(invoiceBucket) @@ -550,15 +576,17 @@ func (d *DB) LookupInvoice(paymentHash [32]byte) (Invoice, error) { return ErrNoInvoicesCreated } - // Check the invoice index to see if an invoice paying to this - // hash exists within the DB. - invoiceNum := invoiceIndex.Get(paymentHash[:]) - if invoiceNum == nil { - return ErrInvoiceNotFound + // Retrieve the invoice number for this invoice using the + // provided invoice reference. + invoiceNum, err := fetchInvoiceNumByRef( + invoiceIndex, ref, + ) + if err != nil { + return err } - // An invoice matching the payment hash has been found, so - // retrieve the record of the invoice itself. + // An invoice was found, retrieve the remainder of the invoice + // body. i, err := fetchInvoice(invoiceNum, invoices) if err != nil { return err @@ -574,6 +602,21 @@ func (d *DB) LookupInvoice(paymentHash [32]byte) (Invoice, error) { return invoice, nil } +// fetchInvoiceNumByRef retrieve the invoice number for the provided invoice +// reference. +func fetchInvoiceNumByRef(invoiceIndex kvdb.ReadBucket, + ref InvoiceRef) ([]byte, error) { + + payHash := ref.PayHash() + + invoiceNum := invoiceIndex.Get(payHash[:]) + if invoiceNum == nil { + return nil, ErrInvoiceNotFound + } + + return invoiceNum, nil +} + // InvoiceWithPaymentHash is used to store an invoice and its corresponding // payment hash. This struct is only used to store results of // ChannelDB.FetchAllInvoicesWithPaymentHash() call. @@ -824,7 +867,7 @@ func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) { // The update is performed inside the same database transaction that fetches the // invoice and is therefore atomic. The fields to update are controlled by the // supplied callback. -func (d *DB) UpdateInvoice(paymentHash lntypes.Hash, +func (d *DB) UpdateInvoice(ref InvoiceRef, callback InvoiceUpdateCallback) (*Invoice, error) { var updatedInvoice *Invoice @@ -846,15 +889,18 @@ func (d *DB) UpdateInvoice(paymentHash lntypes.Hash, return err } - // Check the invoice index to see if an invoice paying to this - // hash exists within the DB. - invoiceNum := invoiceIndex.Get(paymentHash[:]) - if invoiceNum == nil { - return ErrInvoiceNotFound - } + // Retrieve the invoice number for this invoice using the + // provided invoice reference. + invoiceNum, err := fetchInvoiceNumByRef( + invoiceIndex, ref, + ) + if err != nil { + return err + } + payHash := ref.PayHash() updatedInvoice, err = d.updateInvoice( - paymentHash, invoices, settleIndex, invoiceNum, + payHash, invoices, settleIndex, invoiceNum, callback, ) diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index 4c1d9b2a..61e2e56b 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -61,8 +61,8 @@ type RegistryConfig struct { // htlcReleaseEvent describes an htlc auto-release event. It is used to release // mpp htlcs for which the complete set didn't arrive in time. type htlcReleaseEvent struct { - // hash is the payment hash of the htlc to release. - hash lntypes.Hash + // invoiceRef identifiers the invoice this htlc belongs to. + invoiceRef channeldb.InvoiceRef // key is the circuit key of the htlc to release. key channeldb.CircuitKey @@ -289,7 +289,8 @@ func (i *InvoiceRegistry) invoiceEventLoop() { // the subscriber. case *SingleInvoiceSubscription: log.Infof("New single invoice subscription "+ - "client: id=%v, hash=%v", e.id, e.hash) + "client: id=%v, ref=%v", e.id, + e.invoiceRef) i.singleNotificationClients[e.id] = e } @@ -297,8 +298,8 @@ func (i *InvoiceRegistry) invoiceEventLoop() { // A new htlc came in for auto-release. case event := <-i.htlcAutoReleaseChan: log.Debugf("Scheduling auto-release for htlc: "+ - "hash=%v, key=%v at %v", - event.hash, event.key, event.releaseTime) + "ref=%v, key=%v at %v", + event.invoiceRef, event.key, event.releaseTime) // We use an independent timer for every htlc rather // than a set timer that is reset with every htlc coming @@ -311,7 +312,7 @@ func (i *InvoiceRegistry) invoiceEventLoop() { case <-nextReleaseTick: event := autoReleaseHeap.Pop().(*htlcReleaseEvent) err := i.cancelSingleHtlc( - event.hash, event.key, ResultMppTimeout, + event.invoiceRef, event.key, ResultMppTimeout, ) if err != nil { log.Errorf("HTLC timer: %v", err) @@ -328,7 +329,7 @@ func (i *InvoiceRegistry) invoiceEventLoop() { func (i *InvoiceRegistry) dispatchToSingleClients(event *invoiceEvent) { // Dispatch to single invoice subscribers. for _, client := range i.singleNotificationClients { - if client.hash != event.hash { + if client.invoiceRef.PayHash() != event.hash { continue } @@ -465,7 +466,7 @@ func (i *InvoiceRegistry) deliverBacklogEvents(client *InvoiceSubscription) erro func (i *InvoiceRegistry) deliverSingleBacklogEvents( client *SingleInvoiceSubscription) error { - invoice, err := i.cdb.LookupInvoice(client.hash) + invoice, err := i.cdb.LookupInvoice(client.invoiceRef) // It is possible that the invoice does not exist yet, but the client is // already watching it in anticipation. @@ -479,7 +480,7 @@ func (i *InvoiceRegistry) deliverSingleBacklogEvents( } err = client.notify(&invoiceEvent{ - hash: client.hash, + hash: client.invoiceRef.PayHash(), invoice: &invoice, }) if err != nil { @@ -502,8 +503,8 @@ func (i *InvoiceRegistry) AddInvoice(invoice *channeldb.Invoice, i.Lock() - log.Debugf("Invoice(%v): added with terms %v", paymentHash, - invoice.Terms) + ref := channeldb.InvoiceRefByHash(paymentHash) + log.Debugf("Invoice%v: added with terms %v", ref, invoice.Terms) addIndex, err := i.cdb.AddInvoice(invoice, paymentHash) if err != nil { @@ -533,17 +534,18 @@ func (i *InvoiceRegistry) LookupInvoice(rHash lntypes.Hash) (channeldb.Invoice, // We'll check the database to see if there's an existing matching // invoice. - return i.cdb.LookupInvoice(rHash) + ref := channeldb.InvoiceRefByHash(rHash) + return i.cdb.LookupInvoice(ref) } // startHtlcTimer starts a new timer via the invoice registry main loop that // cancels a single htlc on an invoice when the htlc hold duration has passed. -func (i *InvoiceRegistry) startHtlcTimer(hash lntypes.Hash, +func (i *InvoiceRegistry) startHtlcTimer(invoiceRef channeldb.InvoiceRef, key channeldb.CircuitKey, acceptTime time.Time) error { releaseTime := acceptTime.Add(i.cfg.HtlcHoldDuration) event := &htlcReleaseEvent{ - hash: hash, + invoiceRef: invoiceRef, key: key, releaseTime: releaseTime, } @@ -560,7 +562,7 @@ func (i *InvoiceRegistry) startHtlcTimer(hash lntypes.Hash, // cancelSingleHtlc cancels a single accepted htlc on an invoice. It takes // a resolution result which will be used to notify subscribed links and // resolvers of the details of the htlc cancellation. -func (i *InvoiceRegistry) cancelSingleHtlc(hash lntypes.Hash, +func (i *InvoiceRegistry) cancelSingleHtlc(invoiceRef channeldb.InvoiceRef, key channeldb.CircuitKey, result FailResolutionResult) error { i.Lock() @@ -572,7 +574,7 @@ func (i *InvoiceRegistry) cancelSingleHtlc(hash lntypes.Hash, // Only allow individual htlc cancelation on open invoices. if invoice.State != channeldb.ContractOpen { log.Debugf("cancelSingleHtlc: invoice %v no longer "+ - "open", hash) + "open", invoiceRef) return nil, nil } @@ -587,13 +589,13 @@ func (i *InvoiceRegistry) cancelSingleHtlc(hash lntypes.Hash, // resolved. if htlc.State != channeldb.HtlcStateAccepted { log.Debugf("cancelSingleHtlc: htlc %v on invoice %v "+ - "is already resolved", key, hash) + "is already resolved", key, invoiceRef) return nil, nil } log.Debugf("cancelSingleHtlc: cancelling htlc %v on invoice %v", - key, hash) + key, invoiceRef) // Return an update descriptor that cancels htlc and keeps // invoice open. @@ -610,7 +612,7 @@ func (i *InvoiceRegistry) cancelSingleHtlc(hash lntypes.Hash, // Intercept the update descriptor to set the local updated variable. If // no invoice update is performed, we can return early. var updated bool - invoice, err := i.cdb.UpdateInvoice(hash, + invoice, err := i.cdb.UpdateInvoice(invoiceRef, func(invoice *channeldb.Invoice) ( *channeldb.InvoiceUpdateDesc, error) { @@ -774,7 +776,9 @@ func (i *InvoiceRegistry) NotifyExitHopHtlc(rHash lntypes.Hash, // main event loop. case *htlcAcceptResolution: if r.autoRelease { - err := i.startHtlcTimer(rHash, circuitKey, r.acceptTime) + err := i.startHtlcTimer( + ctx.invoiceRef(), circuitKey, r.acceptTime, + ) if err != nil { return nil, err } @@ -808,7 +812,7 @@ func (i *InvoiceRegistry) notifyExitHopHtlcLocked( updateSubscribers bool ) invoice, err := i.cdb.UpdateInvoice( - ctx.hash, + ctx.invoiceRef(), func(inv *channeldb.Invoice) ( *channeldb.InvoiceUpdateDesc, error) { @@ -962,7 +966,8 @@ func (i *InvoiceRegistry) SettleHodlInvoice(preimage lntypes.Preimage) error { } hash := preimage.Hash() - invoice, err := i.cdb.UpdateInvoice(hash, updateInvoice) + invoiceRef := channeldb.InvoiceRefByHash(hash) + invoice, err := i.cdb.UpdateInvoice(invoiceRef, updateInvoice) if err != nil { log.Errorf("SettleHodlInvoice with preimage %v: %v", preimage, err) @@ -970,7 +975,7 @@ func (i *InvoiceRegistry) SettleHodlInvoice(preimage lntypes.Preimage) error { return err } - log.Debugf("Invoice(%v): settled with preimage %v", hash, + log.Debugf("Invoice%v: settled with preimage %v", invoiceRef, invoice.Terms.PaymentPreimage) // In the callback, we marked the invoice as settled. UpdateInvoice will @@ -1011,7 +1016,8 @@ func (i *InvoiceRegistry) cancelInvoiceImpl(payHash lntypes.Hash, i.Lock() defer i.Unlock() - log.Debugf("Invoice(%v): canceling invoice", payHash) + ref := channeldb.InvoiceRefByHash(payHash) + log.Debugf("Invoice%v: canceling invoice", ref) updateInvoice := func(invoice *channeldb.Invoice) ( *channeldb.InvoiceUpdateDesc, error) { @@ -1032,12 +1038,13 @@ func (i *InvoiceRegistry) cancelInvoiceImpl(payHash lntypes.Hash, }, nil } - invoice, err := i.cdb.UpdateInvoice(payHash, updateInvoice) + invoiceRef := channeldb.InvoiceRefByHash(payHash) + invoice, err := i.cdb.UpdateInvoice(invoiceRef, updateInvoice) // Implement idempotency by returning success if the invoice was already // canceled. if err == channeldb.ErrInvoiceAlreadyCanceled { - log.Debugf("Invoice(%v): already canceled", payHash) + log.Debugf("Invoice%v: already canceled", ref) return nil } if err != nil { @@ -1046,12 +1053,12 @@ func (i *InvoiceRegistry) cancelInvoiceImpl(payHash lntypes.Hash, // Return without cancellation if the invoice state is ContractAccepted. if invoice.State == channeldb.ContractAccepted { - log.Debugf("Invoice(%v): remains accepted as cancel wasn't"+ - "explicitly requested.", payHash) + log.Debugf("Invoice%v: remains accepted as cancel wasn't"+ + "explicitly requested.", ref) return nil } - log.Debugf("Invoice(%v): canceled", payHash) + log.Debugf("Invoice%v: canceled", ref) // In the callback, some htlcs may have been moved to the canceled // state. We now go through all of these and notify links and resolvers @@ -1140,7 +1147,7 @@ type InvoiceSubscription struct { type SingleInvoiceSubscription struct { invoiceSubscriptionKit - hash lntypes.Hash + invoiceRef channeldb.InvoiceRef // Updates is a channel that we'll use to send all invoice events for // the invoice that is subscribed to. @@ -1269,7 +1276,7 @@ func (i *InvoiceRegistry) SubscribeSingleInvoice( ntfnQueue: queue.NewConcurrentQueue(20), cancelChan: make(chan struct{}), }, - hash: hash, + invoiceRef: channeldb.InvoiceRefByHash(hash), } client.ntfnQueue.Start() diff --git a/invoices/invoiceregistry_test.go b/invoices/invoiceregistry_test.go index 319c30cf..fa672a36 100644 --- a/invoices/invoiceregistry_test.go +++ b/invoices/invoiceregistry_test.go @@ -26,7 +26,7 @@ func TestSettleInvoice(t *testing.T) { } defer subscription.Cancel() - if subscription.hash != testInvoicePaymentHash { + if subscription.invoiceRef.PayHash() != testInvoicePaymentHash { t.Fatalf("expected subscription for provided hash") } @@ -237,7 +237,7 @@ func TestCancelInvoice(t *testing.T) { } defer subscription.Cancel() - if subscription.hash != testInvoicePaymentHash { + if subscription.invoiceRef.PayHash() != testInvoicePaymentHash { t.Fatalf("expected subscription for provided hash") } @@ -362,7 +362,7 @@ func TestSettleHoldInvoice(t *testing.T) { } defer subscription.Cancel() - if subscription.hash != testInvoicePaymentHash { + if subscription.invoiceRef.PayHash() != testInvoicePaymentHash { t.Fatalf("expected subscription for provided hash") } diff --git a/invoices/update.go b/invoices/update.go index 3226779c..4680b3cd 100644 --- a/invoices/update.go +++ b/invoices/update.go @@ -22,10 +22,16 @@ type invoiceUpdateCtx struct { mpp *record.MPP } +// invoiceRef returns an identifier that can be used to lookup or update the +// invoice this HTLC is targeting. +func (i *invoiceUpdateCtx) invoiceRef() channeldb.InvoiceRef { + return channeldb.InvoiceRefByHash(i.hash) +} + // log logs a message specific to this update context. func (i *invoiceUpdateCtx) log(s string) { - log.Debugf("Invoice(%x): %v, amt=%v, expiry=%v, circuit=%v, mpp=%v", - i.hash[:], s, i.amtPaid, i.expiry, i.circuitKey, i.mpp) + log.Debugf("Invoice%v: %v, amt=%v, expiry=%v, circuit=%v, mpp=%v", + i.invoiceRef, s, i.amtPaid, i.expiry, i.circuitKey, i.mpp) } // failRes is a helper function which creates a failure resolution with From cbf71b5452fa1d3036a43309e490787c5f7f08dc Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Thu, 21 May 2020 15:37:39 -0700 Subject: [PATCH 6/8] channeldb+invoices: use payment addr as primary index --- channeldb/db.go | 9 +++ channeldb/error.go | 8 ++ channeldb/invoice_test.go | 155 ++++++++++++++++++++++++++++++++---- channeldb/invoices.go | 108 ++++++++++++++++++++++--- htlcswitch/test_utils.go | 24 ++++-- invoices/test_utils_test.go | 12 ++- invoices/update.go | 4 + 7 files changed, 286 insertions(+), 34 deletions(-) diff --git a/channeldb/db.go b/channeldb/db.go index 9df3bccf..fe2dc149 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcwallet/walletdb" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb/kvdb" + mig "github.com/lightningnetwork/lnd/channeldb/migration" "github.com/lightningnetwork/lnd/channeldb/migration12" "github.com/lightningnetwork/lnd/channeldb/migration13" "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11" @@ -136,6 +137,13 @@ var ( number: 13, migration: migration13.MigrateMPP, }, + { + // Initialize payment address index and begin using it + // as the default index, falling back to payment hash + // index. + number: 14, + migration: mig.CreateTLB(payAddrIndexBucket), + }, } // Big endian is the preferred byte order, due to cursor scans over @@ -248,6 +256,7 @@ var topLevelBuckets = [][]byte{ forwardingLogBucket, fwdPackagesKey, invoiceBucket, + payAddrIndexBucket, nodeInfoBucket, nodeBucket, edgeBucket, diff --git a/channeldb/error.go b/channeldb/error.go index b1364fb4..97e06a14 100644 --- a/channeldb/error.go +++ b/channeldb/error.go @@ -43,6 +43,14 @@ var ( // payment hash already exists. ErrDuplicateInvoice = fmt.Errorf("invoice with payment hash already exists") + // ErrDuplicatePayAddr is returned when an invoice with the target + // payment addr already exists. + ErrDuplicatePayAddr = fmt.Errorf("invoice with payemnt addr already exists") + + // ErrInvRefEquivocation is returned when an InvoiceRef targets + // multiple, distinct invoices. + ErrInvRefEquivocation = errors.New("inv ref matches multiple invoices") + // ErrNoPaymentsCreated is returned when bucket of payments hasn't been // created. ErrNoPaymentsCreated = fmt.Errorf("there are no existing payments") diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index bd1e6a76..626a039b 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -20,16 +20,20 @@ var ( ) func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) { - var pre [32]byte + var pre, payAddr [32]byte if _, err := rand.Read(pre[:]); err != nil { return nil, err } + if _, err := rand.Read(payAddr[:]); err != nil { + return nil, err + } i := &Invoice{ CreationDate: testNow, Terms: ContractTerm{ Expiry: 4000, PaymentPreimage: pre, + PaymentAddr: payAddr, Value: value, Features: emptyFeatures, }, @@ -91,9 +95,45 @@ func TestInvoiceIsPending(t *testing.T) { } } +type invWorkflowTest struct { + name string + queryPayHash bool + queryPayAddr bool +} + +var invWorkflowTests = []invWorkflowTest{ + { + name: "unknown", + queryPayHash: false, + queryPayAddr: false, + }, + { + name: "only payhash known", + queryPayHash: true, + queryPayAddr: false, + }, + { + name: "payaddr and payhash known", + queryPayHash: true, + queryPayAddr: true, + }, +} + +// TestInvoiceWorkflow asserts the basic process of inserting, fetching, and +// updating an invoice. We assert that the flow is successful using when +// querying with various combinations of payment hash and payment address. func TestInvoiceWorkflow(t *testing.T) { t.Parallel() + for _, test := range invWorkflowTests { + test := test + t.Run(test.name, func(t *testing.T) { + testInvoiceWorkflow(t, test) + }) + } +} + +func testInvoiceWorkflow(t *testing.T, test invWorkflowTest) { db, cleanUp, err := makeTestDB() defer cleanUp() if err != nil { @@ -102,23 +142,33 @@ func TestInvoiceWorkflow(t *testing.T) { // Create a fake invoice which we'll use several times in the tests // below. - fakeInvoice := &Invoice{ - CreationDate: testNow, - Htlcs: map[CircuitKey]*InvoiceHTLC{}, + fakeInvoice, err := randInvoice(10000) + if err != nil { + t.Fatalf("unable to create invoice: %v", err) } - fakeInvoice.Memo = []byte("memo") - fakeInvoice.PaymentRequest = []byte("") - copy(fakeInvoice.Terms.PaymentPreimage[:], rev[:]) - fakeInvoice.Terms.Value = lnwire.NewMSatFromSatoshis(10000) - fakeInvoice.Terms.Features = emptyFeatures + invPayHash := fakeInvoice.Terms.PaymentPreimage.Hash() - paymentHash := fakeInvoice.Terms.PaymentPreimage.Hash() - ref := InvoiceRefByHash(paymentHash) + // Select the payment hash and payment address we will use to lookup or + // update the invoice for the remainder of the test. + var ( + payHash lntypes.Hash + payAddr *[32]byte + ref InvoiceRef + ) + switch { + case test.queryPayHash && test.queryPayAddr: + payHash = invPayHash + payAddr = &fakeInvoice.Terms.PaymentAddr + ref = InvoiceRefByHashAndAddr(payHash, *payAddr) + case test.queryPayHash: + payHash = invPayHash + ref = InvoiceRefByHash(payHash) + } // Add the invoice to the database, this should succeed as there aren't // any existing invoices within the database with the same payment // hash. - if _, err := db.AddInvoice(fakeInvoice, paymentHash); err != nil { + if _, err := db.AddInvoice(fakeInvoice, invPayHash); err != nil { t.Fatalf("unable to find invoice: %v", err) } @@ -126,8 +176,11 @@ func TestInvoiceWorkflow(t *testing.T) { // database. It should be found, and the invoice returned should be // identical to the one created above. dbInvoice, err := db.LookupInvoice(ref) - if err != nil { - t.Fatalf("unable to find invoice: %v", err) + if !test.queryPayAddr && !test.queryPayHash { + if err != ErrInvoiceNotFound { + t.Fatalf("invoice should not exist: %v", err) + } + return } if !reflect.DeepEqual(*fakeInvoice, dbInvoice) { t.Fatalf("invoice fetched from db doesn't match original %v vs %v", @@ -174,7 +227,7 @@ func TestInvoiceWorkflow(t *testing.T) { // Attempt to insert generated above again, this should fail as // duplicates are rejected by the processing logic. - if _, err := db.AddInvoice(fakeInvoice, paymentHash); err != ErrDuplicateInvoice { + if _, err := db.AddInvoice(fakeInvoice, payHash); err != ErrDuplicateInvoice { t.Fatalf("invoice insertion should fail due to duplication, "+ "instead %v", err) } @@ -232,6 +285,70 @@ func TestInvoiceWorkflow(t *testing.T) { } } +// TestAddDuplicatePayAddr asserts that the payment addresses of inserted +// invoices are unique. +func TestAddDuplicatePayAddr(t *testing.T) { + db, cleanUp, err := makeTestDB() + defer cleanUp() + assert.Nil(t, err) + + // Create two invoices with the same payment addr. + invoice1, err := randInvoice(1000) + assert.Nil(t, err) + + invoice2, err := randInvoice(20000) + assert.Nil(t, err) + invoice2.Terms.PaymentAddr = invoice1.Terms.PaymentAddr + + // First insert should succeed. + inv1Hash := invoice1.Terms.PaymentPreimage.Hash() + _, err = db.AddInvoice(invoice1, inv1Hash) + assert.Nil(t, err) + + // Second insert should fail with duplicate payment addr. + inv2Hash := invoice2.Terms.PaymentPreimage.Hash() + _, err = db.AddInvoice(invoice2, inv2Hash) + assert.Equal(t, ErrDuplicatePayAddr, err) +} + +// TestInvRefEquivocation asserts that retrieving or updating an invoice using +// an equivocating InvoiceRef results in ErrInvRefEquivocation. +func TestInvRefEquivocation(t *testing.T) { + db, cleanUp, err := makeTestDB() + defer cleanUp() + assert.Nil(t, err) + + // Add two random invoices. + invoice1, err := randInvoice(1000) + assert.Nil(t, err) + + inv1Hash := invoice1.Terms.PaymentPreimage.Hash() + _, err = db.AddInvoice(invoice1, inv1Hash) + assert.Nil(t, err) + + invoice2, err := randInvoice(2000) + assert.Nil(t, err) + + inv2Hash := invoice2.Terms.PaymentPreimage.Hash() + _, err = db.AddInvoice(invoice2, inv2Hash) + assert.Nil(t, err) + + // Now, query using invoice 1's payment address, but invoice 2's payment + // hash. We expect an error since the invref points to multiple + // invoices. + ref := InvoiceRefByHashAndAddr(inv2Hash, invoice1.Terms.PaymentAddr) + _, err = db.LookupInvoice(ref) + assert.Equal(t, ErrInvRefEquivocation, err) + + // The same error should be returned when updating an equivocating + // reference. + nop := func(_ *Invoice) (*InvoiceUpdateDesc, error) { + return nil, nil + } + _, err = db.UpdateInvoice(ref, nop) + assert.Equal(t, ErrInvRefEquivocation, err) +} + // TestInvoiceCancelSingleHtlc tests that a single htlc can be canceled on the // invoice. func TestInvoiceCancelSingleHtlc(t *testing.T) { @@ -991,9 +1108,17 @@ func TestCustomRecords(t *testing.T) { // InvoiceRef depending on the constructor used. func TestInvoiceRef(t *testing.T) { payHash := lntypes.Hash{0x01} + payAddr := [32]byte{0x02} // An InvoiceRef by hash should return the provided hash and a nil // payment addr. refByHash := InvoiceRefByHash(payHash) assert.Equal(t, payHash, refByHash.PayHash()) + assert.Equal(t, (*[32]byte)(nil), refByHash.PayAddr()) + + // An InvoiceRef by hash and addr should return the payment hash and + // payment addr passed to the constructor. + refByHashAndAddr := InvoiceRefByHashAndAddr(payHash, payAddr) + assert.Equal(t, payHash, refByHashAndAddr.PayHash()) + assert.Equal(t, &payAddr, refByHashAndAddr.PayAddr()) } diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 83c8de13..3bb005f0 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -37,6 +37,16 @@ var ( // maps: payHash => invoiceKey invoiceIndexBucket = []byte("paymenthashes") + // payAddrIndexBucket is the name of the top-level bucket that maps + // payment addresses to their invoice number. This can be used + // to efficiently query or update non-legacy invoices. Note that legacy + // invoices will not be included in this index since they all have the + // same, all-zero payment address, however all newly generated invoices + // will end up in this index. + // + // maps: payAddr => invoiceKey + payAddrIndexBucket = []byte("pay-addr-index") + // numInvoicesKey is the name of key which houses the auto-incrementing // invoice ID which is essentially used as a primary key. With each // invoice inserted, the primary key is incremented by one. This key is @@ -142,12 +152,23 @@ const ( amtPaidType tlv.Type = 13 ) -// InvoiceRef is an identifier for invoices supporting queries by payment hash. +// InvoiceRef is a composite identifier for invoices. Invoices can be referenced +// by various combinations of payment hash and payment addr, in certain contexts +// only some of these are known. An InvoiceRef and its constructors thus +// encapsulate the valid combinations of query parameters that can be supplied +// to LookupInvoice and UpdateInvoice. type InvoiceRef struct { // payHash is the payment hash of the target invoice. All invoices are // currently indexed by payment hash. This value will be used as a // fallback when no payment address is known. payHash lntypes.Hash + + // payAddr is the payment addr of the target invoice. Newer invoices + // (0.11 and up) are indexed by payment address in addition to payment + // hash, but pre 0.8 invoices do not have one at all. When this value is + // known it will be used as the primary identifier, falling back to + // payHash if no value is known. + payAddr *[32]byte } // InvoiceRefByHash creates an InvoiceRef that queries for an invoice only by @@ -158,13 +179,39 @@ func InvoiceRefByHash(payHash lntypes.Hash) InvoiceRef { } } +// InvoiceRefByHashAndAddr creates an InvoiceRef that first queries for an +// invoice by the provided payment address, falling back to the payment hash if +// the payment address is unknown. +func InvoiceRefByHashAndAddr(payHash lntypes.Hash, + payAddr [32]byte) InvoiceRef { + + return InvoiceRef{ + payHash: payHash, + payAddr: &payAddr, + } +} + // PayHash returns the target invoice's payment hash. func (r InvoiceRef) PayHash() lntypes.Hash { return r.payHash } +// PayAddr returns the optional payment address of the target invoice. +// +// NOTE: This value may be nil. +func (r InvoiceRef) PayAddr() *[32]byte { + if r.payAddr != nil { + addr := *r.payAddr + return &addr + } + return nil +} + // String returns a human-readable representation of an InvoiceRef. func (r InvoiceRef) String() string { + if r.payAddr != nil { + return fmt.Sprintf("(pay_hash=%v, pay_addr=%x)", r.payHash, *r.payAddr) + } return fmt.Sprintf("(pay_hash=%v)", r.payHash) } @@ -458,6 +505,11 @@ func (d *DB) AddInvoice(newInvoice *Invoice, paymentHash lntypes.Hash) ( return ErrDuplicateInvoice } + payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket) + if payAddrIndex.Get(newInvoice.Terms.PaymentAddr[:]) != nil { + return ErrDuplicatePayAddr + } + // If the current running payment ID counter hasn't yet been // created, then create it now. var invoiceNum uint32 @@ -474,8 +526,8 @@ func (d *DB) AddInvoice(newInvoice *Invoice, paymentHash lntypes.Hash) ( } newIndex, err := putInvoice( - invoices, invoiceIndex, addIndex, newInvoice, invoiceNum, - paymentHash, + invoices, invoiceIndex, payAddrIndex, addIndex, + newInvoice, invoiceNum, paymentHash, ) if err != nil { return err @@ -575,11 +627,12 @@ func (d *DB) LookupInvoice(ref InvoiceRef) (Invoice, error) { if invoiceIndex == nil { return ErrNoInvoicesCreated } + payAddrIndex := tx.ReadBucket(payAddrIndexBucket) // Retrieve the invoice number for this invoice using the // provided invoice reference. invoiceNum, err := fetchInvoiceNumByRef( - invoiceIndex, ref, + invoiceIndex, payAddrIndex, ref, ) if err != nil { return err @@ -603,18 +656,44 @@ func (d *DB) LookupInvoice(ref InvoiceRef) (Invoice, error) { } // fetchInvoiceNumByRef retrieve the invoice number for the provided invoice -// reference. -func fetchInvoiceNumByRef(invoiceIndex kvdb.ReadBucket, +// reference. The payment address will be treated as the primary key, falling +// back to the payment hash if nothing is found for the payment address. An +// error is returned if the invoice is not found. +func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex kvdb.ReadBucket, ref InvoiceRef) ([]byte, error) { payHash := ref.PayHash() + payAddr := ref.PayAddr() - invoiceNum := invoiceIndex.Get(payHash[:]) - if invoiceNum == nil { - return nil, ErrInvoiceNotFound + var ( + invoiceNumByHash = invoiceIndex.Get(payHash[:]) + invoiceNumByAddr []byte + ) + if payAddr != nil { + invoiceNumByAddr = payAddrIndex.Get(payAddr[:]) } - return invoiceNum, nil + switch { + + // If payment address and payment hash both reference an existing + // invoice, ensure they reference the _same_ invoice. + case invoiceNumByAddr != nil && invoiceNumByHash != nil: + if !bytes.Equal(invoiceNumByAddr, invoiceNumByHash) { + return nil, ErrInvRefEquivocation + } + + return invoiceNumByAddr, nil + + // If we were only able to reference the invoice by hash, return the + // corresponding invoice number. This can happen when no payment address + // was provided, or if it didn't match anything in our records. + case invoiceNumByHash != nil: + return invoiceNumByHash, nil + + // Otherwise we don't know of the target invoice. + default: + return nil, ErrInvoiceNotFound + } } // InvoiceWithPaymentHash is used to store an invoice and its corresponding @@ -888,11 +967,12 @@ func (d *DB) UpdateInvoice(ref InvoiceRef, if err != nil { return err } + payAddrIndex := tx.ReadBucket(payAddrIndexBucket) // Retrieve the invoice number for this invoice using the // provided invoice reference. invoiceNum, err := fetchInvoiceNumByRef( - invoiceIndex, ref, + invoiceIndex, payAddrIndex, ref, ) if err != nil { return err @@ -971,7 +1051,7 @@ func (d *DB) InvoicesSettledSince(sinceSettleIndex uint64) ([]Invoice, error) { return settledInvoices, nil } -func putInvoice(invoices, invoiceIndex, addIndex kvdb.RwBucket, +func putInvoice(invoices, invoiceIndex, payAddrIndex, addIndex kvdb.RwBucket, i *Invoice, invoiceNum uint32, paymentHash lntypes.Hash) ( uint64, error) { @@ -996,6 +1076,10 @@ func putInvoice(invoices, invoiceIndex, addIndex kvdb.RwBucket, if err != nil { return 0, err } + err = payAddrIndex.Put(i.Terms.PaymentAddr[:], invoiceKey[:]) + if err != nil { + return 0, err + } // Next, we'll obtain the next add invoice index (sequence // number), so we can properly place this invoice within this diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index 2da6e18b..429963d2 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -2,7 +2,7 @@ package htlcswitch import ( "bytes" - "crypto/rand" + crand "crypto/rand" "crypto/sha256" "encoding/binary" "fmt" @@ -137,7 +137,7 @@ func generateRandomBytes(n int) ([]byte, error) { // TODO(roasbeef): should use counter in tests (atomic) rather than // this - _, err := rand.Read(b[:]) + _, err := crand.Read(b) // Note that Err == nil only if we read len(b) bytes. if err != nil { return nil, err @@ -547,7 +547,7 @@ func getChanID(msg lnwire.Message) (lnwire.ChannelID, error) { // invoice which should be added by destination peer. func generatePaymentWithPreimage(invoiceAmt, htlcAmt lnwire.MilliSatoshi, timelock uint32, blob [lnwire.OnionPacketSize]byte, - preimage, rhash [32]byte) (*channeldb.Invoice, *lnwire.UpdateAddHTLC, + preimage, rhash, payAddr [32]byte) (*channeldb.Invoice, *lnwire.UpdateAddHTLC, uint64, error) { // Create the db invoice. Normally the payment requests needs to be set, @@ -562,6 +562,7 @@ func generatePaymentWithPreimage(invoiceAmt, htlcAmt lnwire.MilliSatoshi, FinalCltvDelta: testInvoiceCltvExpiry, Value: invoiceAmt, PaymentPreimage: preimage, + PaymentAddr: payAddr, Features: lnwire.NewFeatureVector( nil, lnwire.Features, ), @@ -598,8 +599,16 @@ func generatePayment(invoiceAmt, htlcAmt lnwire.MilliSatoshi, timelock uint32, copy(preimage[:], r) rhash := sha256.Sum256(preimage[:]) + + var payAddr [sha256.Size]byte + r, err = generateRandomBytes(sha256.Size) + if err != nil { + return nil, nil, 0, err + } + copy(payAddr[:], r) + return generatePaymentWithPreimage( - invoiceAmt, htlcAmt, timelock, blob, preimage, rhash, + invoiceAmt, htlcAmt, timelock, blob, preimage, rhash, payAddr, ) } @@ -1328,10 +1337,15 @@ func (n *twoHopNetwork) makeHoldPayment(sendingPeer, receivingPeer lnpeer.Peer, rhash := preimage.Hash() + var payAddr [32]byte + if _, err := crand.Read(payAddr[:]); err != nil { + panic(err) + } + // Generate payment: invoice and htlc. invoice, htlc, pid, err := generatePaymentWithPreimage( invoiceAmt, htlcAmt, timelock, blob, - channeldb.UnknownPreimage, rhash, + channeldb.UnknownPreimage, rhash, payAddr, ) if err != nil { paymentErr <- err diff --git a/invoices/test_utils_test.go b/invoices/test_utils_test.go index cf0f14ea..8d98b132 100644 --- a/invoices/test_utils_test.go +++ b/invoices/test_utils_test.go @@ -1,6 +1,7 @@ package invoices import ( + "crypto/rand" "encoding/binary" "encoding/hex" "fmt" @@ -198,14 +199,20 @@ func newTestInvoice(t *testing.T, preimage lntypes.Preimage, expiry = time.Hour } + var payAddr [32]byte + if _, err := rand.Read(payAddr[:]); err != nil { + t.Fatalf("unable to generate payment addr: %v", err) + } + rawInvoice, err := zpay32.NewInvoice( testNetParams, preimage.Hash(), timestamp, zpay32.Amount(testInvoiceAmount), zpay32.Description(testInvoiceDescription), - zpay32.Expiry(expiry)) - + zpay32.Expiry(expiry), + zpay32.PaymentAddr(payAddr), + ) if err != nil { t.Fatalf("Error while creating new invoice: %v", err) } @@ -219,6 +226,7 @@ func newTestInvoice(t *testing.T, preimage lntypes.Preimage, return &channeldb.Invoice{ Terms: channeldb.ContractTerm{ PaymentPreimage: preimage, + PaymentAddr: payAddr, Value: testInvoiceAmount, Expiry: expiry, Features: testFeatures, diff --git a/invoices/update.go b/invoices/update.go index 4680b3cd..62522378 100644 --- a/invoices/update.go +++ b/invoices/update.go @@ -25,6 +25,10 @@ type invoiceUpdateCtx struct { // invoiceRef returns an identifier that can be used to lookup or update the // invoice this HTLC is targeting. func (i *invoiceUpdateCtx) invoiceRef() channeldb.InvoiceRef { + if i.mpp != nil { + payAddr := i.mpp.PaymentAddr() + return channeldb.InvoiceRefByHashAndAddr(i.hash, payAddr) + } return channeldb.InvoiceRefByHash(i.hash) } From 5c4ab4b7cfd28586db0106dd47dcf0b3210920bd Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Tue, 26 May 2020 17:56:16 -0700 Subject: [PATCH 7/8] lntest: update error whitelist --- lntest/itest/log_error_whitelist.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lntest/itest/log_error_whitelist.txt b/lntest/itest/log_error_whitelist.txt index a65ce6ce..d51c2142 100644 --- a/lntest/itest/log_error_whitelist.txt +++ b/lntest/itest/log_error_whitelist.txt @@ -181,3 +181,6 @@