diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 602933df..a2d47593 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -2046,7 +2046,9 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) { // We must add the invoice to the registry, such that Alice expects // this payment. - err = coreLink.cfg.Registry.(*mockInvoiceRegistry).AddInvoice(*invoice) + err = coreLink.cfg.Registry.(*mockInvoiceRegistry).AddInvoice( + *invoice, htlc.PaymentHash, + ) if err != nil { t.Fatalf("unable to add invoice to registry: %v", err) } @@ -2148,7 +2150,9 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) { if err != nil { t.Fatalf("unable to create payment: %v", err) } - err = coreLink.cfg.Registry.(*mockInvoiceRegistry).AddInvoice(*invoice) + err = coreLink.cfg.Registry.(*mockInvoiceRegistry).AddInvoice( + *invoice, htlc.PaymentHash, + ) if err != nil { t.Fatalf("unable to add invoice to registry: %v", err) } @@ -3804,7 +3808,9 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) { if err != nil { t.Fatal(err) } - if err := n.carolServer.registry.AddInvoice(*invoice); err != nil { + + err = n.carolServer.registry.AddInvoice(*invoice, htlc.PaymentHash) + if err != nil { t.Fatalf("unable to add invoice in carol registry: %v", err) } @@ -4182,7 +4188,8 @@ func generateHtlc(t *testing.T, coreLink *channelLink, // We must add the invoice to the registry, such that Alice // expects this payment. err := coreLink.cfg.Registry.(*mockInvoiceRegistry).AddInvoice( - *invoice) + *invoice, htlc.PaymentHash, + ) if err != nil { t.Fatalf("unable to add invoice to registry: %v", err) } diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 1d797fb1..3a064bc8 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -8,6 +8,7 @@ import ( "io" "io/ioutil" "net" + "os" "sync" "sync/atomic" "testing" @@ -23,6 +24,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" @@ -703,82 +705,83 @@ func (f *mockChannelLink) UpdateShortChanID() (lnwire.ShortChannelID, error) { var _ ChannelLink = (*mockChannelLink)(nil) -type mockInvoiceRegistry struct { - sync.Mutex +func newDB() (*channeldb.DB, func(), error) { + // First, create a temporary directory to be used for the duration of + // this test. + tempDirName, err := ioutil.TempDir("", "channeldb") + if err != nil { + return nil, nil, err + } - invoices map[lntypes.Hash]channeldb.Invoice - finalDelta uint32 + // Next, create channeldb for the first time. + cdb, err := channeldb.Open(tempDirName) + if err != nil { + os.RemoveAll(tempDirName) + return nil, nil, err + } + + cleanUp := func() { + cdb.Close() + os.RemoveAll(tempDirName) + } + + return cdb, cleanUp, nil +} + +type mockInvoiceRegistry struct { + settleChan chan lntypes.Hash + + registry *invoices.InvoiceRegistry + + cleanup func() } func newMockRegistry(minDelta uint32) *mockInvoiceRegistry { + cdb, cleanup, err := newDB() + if err != nil { + panic(err) + } + + decodeExpiry := func(invoice string) (uint32, error) { + return 3, nil + } + + registry := invoices.NewRegistry(cdb, decodeExpiry) + registry.Start() + return &mockInvoiceRegistry{ - finalDelta: minDelta, - invoices: make(map[lntypes.Hash]channeldb.Invoice), + registry: registry, + cleanup: cleanup, } } func (i *mockInvoiceRegistry) LookupInvoice(rHash lntypes.Hash) (channeldb.Invoice, uint32, error) { - i.Lock() - defer i.Unlock() - - invoice, ok := i.invoices[rHash] - if !ok { - return channeldb.Invoice{}, 0, fmt.Errorf("can't find mock "+ - "invoice: %x", rHash[:]) - } - - return invoice, i.finalDelta, nil + return i.registry.LookupInvoice(rHash) } func (i *mockInvoiceRegistry) SettleInvoice(rhash lntypes.Hash, amt lnwire.MilliSatoshi) error { - i.Lock() - defer i.Unlock() - - invoice, ok := i.invoices[rhash] - if !ok { - return fmt.Errorf("can't find mock invoice: %x", rhash[:]) + err := i.registry.SettleInvoice(rhash, amt) + if err != nil { + return err } - - if invoice.Terms.State == channeldb.ContractSettled { - return nil + if i.settleChan != nil { + i.settleChan <- rhash } - invoice.Terms.State = channeldb.ContractSettled - invoice.AmtPaid = amt - i.invoices[rhash] = invoice - return nil } func (i *mockInvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error { - i.Lock() - defer i.Unlock() - - invoice, ok := i.invoices[payHash] - if !ok { - return channeldb.ErrInvoiceNotFound - } - - if invoice.Terms.State == channeldb.ContractCanceled { - return nil - } - - invoice.Terms.State = channeldb.ContractCanceled - i.invoices[payHash] = invoice - - return nil + return i.registry.CancelInvoice(payHash) } -func (i *mockInvoiceRegistry) AddInvoice(invoice channeldb.Invoice) error { - i.Lock() - defer i.Unlock() +func (i *mockInvoiceRegistry) AddInvoice(invoice channeldb.Invoice, + paymentHash lntypes.Hash) error { - rhash := invoice.Terms.PaymentPreimage.Hash() - i.invoices[rhash] = invoice - - return nil + _, err := i.registry.AddInvoice(&invoice, paymentHash) + return err } var _ InvoiceDatabase = (*mockInvoiceRegistry)(nil) diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index 46e0dac3..97603238 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -742,7 +742,7 @@ func preparePayment(sendingPeer, receivingPeer lnpeer.Peer, } // Check who is last in the route and add invoice to server registry. - if err := receiver.registry.AddInvoice(*invoice); err != nil { + if err := receiver.registry.AddInvoice(*invoice, htlc.PaymentHash); err != nil { return nil, nil, err }