htlcswitch/test: use real invoice registry with temp db as mock
In further commits the behaviour of invoice registry becomes more intrinsically connected to the link. This commit prepares for that by allowing link and registry to be tested as a single unit.
This commit is contained in:
parent
e464ed18c7
commit
aeb35d9898
@ -2046,7 +2046,9 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {
|
||||
|
||||
// We must add the invoice to the registry, such that Alice expects
|
||||
// this payment.
|
||||
err = coreLink.cfg.Registry.(*mockInvoiceRegistry).AddInvoice(*invoice)
|
||||
err = coreLink.cfg.Registry.(*mockInvoiceRegistry).AddInvoice(
|
||||
*invoice, htlc.PaymentHash,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to add invoice to registry: %v", err)
|
||||
}
|
||||
@ -2148,7 +2150,9 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create payment: %v", err)
|
||||
}
|
||||
err = coreLink.cfg.Registry.(*mockInvoiceRegistry).AddInvoice(*invoice)
|
||||
err = coreLink.cfg.Registry.(*mockInvoiceRegistry).AddInvoice(
|
||||
*invoice, htlc.PaymentHash,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to add invoice to registry: %v", err)
|
||||
}
|
||||
@ -3804,7 +3808,9 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := n.carolServer.registry.AddInvoice(*invoice); err != nil {
|
||||
|
||||
err = n.carolServer.registry.AddInvoice(*invoice, htlc.PaymentHash)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to add invoice in carol registry: %v", err)
|
||||
}
|
||||
|
||||
@ -4182,7 +4188,8 @@ func generateHtlc(t *testing.T, coreLink *channelLink,
|
||||
// We must add the invoice to the registry, such that Alice
|
||||
// expects this payment.
|
||||
err := coreLink.cfg.Registry.(*mockInvoiceRegistry).AddInvoice(
|
||||
*invoice)
|
||||
*invoice, htlc.PaymentHash,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to add invoice to registry: %v", err)
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
@ -23,6 +24,7 @@ import (
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/contractcourt"
|
||||
"github.com/lightningnetwork/lnd/input"
|
||||
"github.com/lightningnetwork/lnd/invoices"
|
||||
"github.com/lightningnetwork/lnd/lnpeer"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/lnwallet"
|
||||
@ -703,82 +705,83 @@ func (f *mockChannelLink) UpdateShortChanID() (lnwire.ShortChannelID, error) {
|
||||
|
||||
var _ ChannelLink = (*mockChannelLink)(nil)
|
||||
|
||||
type mockInvoiceRegistry struct {
|
||||
sync.Mutex
|
||||
func newDB() (*channeldb.DB, func(), error) {
|
||||
// First, create a temporary directory to be used for the duration of
|
||||
// this test.
|
||||
tempDirName, err := ioutil.TempDir("", "channeldb")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
invoices map[lntypes.Hash]channeldb.Invoice
|
||||
finalDelta uint32
|
||||
// Next, create channeldb for the first time.
|
||||
cdb, err := channeldb.Open(tempDirName)
|
||||
if err != nil {
|
||||
os.RemoveAll(tempDirName)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
cleanUp := func() {
|
||||
cdb.Close()
|
||||
os.RemoveAll(tempDirName)
|
||||
}
|
||||
|
||||
return cdb, cleanUp, nil
|
||||
}
|
||||
|
||||
type mockInvoiceRegistry struct {
|
||||
settleChan chan lntypes.Hash
|
||||
|
||||
registry *invoices.InvoiceRegistry
|
||||
|
||||
cleanup func()
|
||||
}
|
||||
|
||||
func newMockRegistry(minDelta uint32) *mockInvoiceRegistry {
|
||||
cdb, cleanup, err := newDB()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
decodeExpiry := func(invoice string) (uint32, error) {
|
||||
return 3, nil
|
||||
}
|
||||
|
||||
registry := invoices.NewRegistry(cdb, decodeExpiry)
|
||||
registry.Start()
|
||||
|
||||
return &mockInvoiceRegistry{
|
||||
finalDelta: minDelta,
|
||||
invoices: make(map[lntypes.Hash]channeldb.Invoice),
|
||||
registry: registry,
|
||||
cleanup: cleanup,
|
||||
}
|
||||
}
|
||||
|
||||
func (i *mockInvoiceRegistry) LookupInvoice(rHash lntypes.Hash) (channeldb.Invoice, uint32, error) {
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
invoice, ok := i.invoices[rHash]
|
||||
if !ok {
|
||||
return channeldb.Invoice{}, 0, fmt.Errorf("can't find mock "+
|
||||
"invoice: %x", rHash[:])
|
||||
}
|
||||
|
||||
return invoice, i.finalDelta, nil
|
||||
return i.registry.LookupInvoice(rHash)
|
||||
}
|
||||
|
||||
func (i *mockInvoiceRegistry) SettleInvoice(rhash lntypes.Hash,
|
||||
amt lnwire.MilliSatoshi) error {
|
||||
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
invoice, ok := i.invoices[rhash]
|
||||
if !ok {
|
||||
return fmt.Errorf("can't find mock invoice: %x", rhash[:])
|
||||
err := i.registry.SettleInvoice(rhash, amt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if invoice.Terms.State == channeldb.ContractSettled {
|
||||
return nil
|
||||
if i.settleChan != nil {
|
||||
i.settleChan <- rhash
|
||||
}
|
||||
|
||||
invoice.Terms.State = channeldb.ContractSettled
|
||||
invoice.AmtPaid = amt
|
||||
i.invoices[rhash] = invoice
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *mockInvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error {
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
invoice, ok := i.invoices[payHash]
|
||||
if !ok {
|
||||
return channeldb.ErrInvoiceNotFound
|
||||
return i.registry.CancelInvoice(payHash)
|
||||
}
|
||||
|
||||
if invoice.Terms.State == channeldb.ContractCanceled {
|
||||
return nil
|
||||
}
|
||||
func (i *mockInvoiceRegistry) AddInvoice(invoice channeldb.Invoice,
|
||||
paymentHash lntypes.Hash) error {
|
||||
|
||||
invoice.Terms.State = channeldb.ContractCanceled
|
||||
i.invoices[payHash] = invoice
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *mockInvoiceRegistry) AddInvoice(invoice channeldb.Invoice) error {
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
rhash := invoice.Terms.PaymentPreimage.Hash()
|
||||
i.invoices[rhash] = invoice
|
||||
|
||||
return nil
|
||||
_, err := i.registry.AddInvoice(&invoice, paymentHash)
|
||||
return err
|
||||
}
|
||||
|
||||
var _ InvoiceDatabase = (*mockInvoiceRegistry)(nil)
|
||||
|
@ -742,7 +742,7 @@ func preparePayment(sendingPeer, receivingPeer lnpeer.Peer,
|
||||
}
|
||||
|
||||
// Check who is last in the route and add invoice to server registry.
|
||||
if err := receiver.registry.AddInvoice(*invoice); err != nil {
|
||||
if err := receiver.registry.AddInvoice(*invoice, htlc.PaymentHash); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user