htlcswitch: face race condition in unit tests by returning invoice
In this commit we modify the primary InvoiceRegistry interface within the package to instead return a direct value for LookupInvoice rather than a pointer. This fixes an existing race condition wherein a caller could modify or read the value of the returned invoice.
This commit is contained in:
parent
010815e280
commit
b6f64932c2
@ -12,7 +12,7 @@ import (
|
|||||||
type InvoiceDatabase interface {
|
type InvoiceDatabase interface {
|
||||||
// LookupInvoice attempts to look up an invoice according to it's 32
|
// LookupInvoice attempts to look up an invoice according to it's 32
|
||||||
// byte payment hash.
|
// byte payment hash.
|
||||||
LookupInvoice(chainhash.Hash) (*channeldb.Invoice, error)
|
LookupInvoice(chainhash.Hash) (channeldb.Invoice, error)
|
||||||
|
|
||||||
// SettleInvoice attempts to mark an invoice corresponding to the
|
// SettleInvoice attempts to mark an invoice corresponding to the
|
||||||
// passed payment hash as fully settled.
|
// passed payment hash as fully settled.
|
||||||
|
@ -978,7 +978,7 @@ func TestChannelLinkMultiHopUnknownPaymentHash(t *testing.T) {
|
|||||||
invoice.Terms.PaymentPreimage[0] ^= byte(255)
|
invoice.Terms.PaymentPreimage[0] ^= byte(255)
|
||||||
|
|
||||||
// Check who is last in the route and add invoice to server registry.
|
// Check who is last in the route and add invoice to server registry.
|
||||||
if err := n.carolServer.registry.AddInvoice(invoice); err != nil {
|
if err := n.carolServer.registry.AddInvoice(*invoice); err != nil {
|
||||||
t.Fatalf("unable to add invoice in carol registry: %v", err)
|
t.Fatalf("unable to add invoice in carol registry: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1955,7 +1955,7 @@ func TestChannelRetransmission(t *testing.T) {
|
|||||||
// TODO(andrew.shvv) Will be removed if we move the notification center
|
// TODO(andrew.shvv) Will be removed if we move the notification center
|
||||||
// to the channel link itself.
|
// to the channel link itself.
|
||||||
|
|
||||||
var invoice *channeldb.Invoice
|
var invoice channeldb.Invoice
|
||||||
for i := 0; i < 20; i++ {
|
for i := 0; i < 20; i++ {
|
||||||
select {
|
select {
|
||||||
case <-time.After(time.Millisecond * 200):
|
case <-time.After(time.Millisecond * 200):
|
||||||
|
@ -397,22 +397,22 @@ var _ ChannelLink = (*mockChannelLink)(nil)
|
|||||||
|
|
||||||
type mockInvoiceRegistry struct {
|
type mockInvoiceRegistry struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
invoices map[chainhash.Hash]*channeldb.Invoice
|
invoices map[chainhash.Hash]channeldb.Invoice
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMockRegistry() *mockInvoiceRegistry {
|
func newMockRegistry() *mockInvoiceRegistry {
|
||||||
return &mockInvoiceRegistry{
|
return &mockInvoiceRegistry{
|
||||||
invoices: make(map[chainhash.Hash]*channeldb.Invoice),
|
invoices: make(map[chainhash.Hash]channeldb.Invoice),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *mockInvoiceRegistry) LookupInvoice(rHash chainhash.Hash) (*channeldb.Invoice, error) {
|
func (i *mockInvoiceRegistry) LookupInvoice(rHash chainhash.Hash) (channeldb.Invoice, error) {
|
||||||
i.Lock()
|
i.Lock()
|
||||||
defer i.Unlock()
|
defer i.Unlock()
|
||||||
|
|
||||||
invoice, ok := i.invoices[rHash]
|
invoice, ok := i.invoices[rHash]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("can't find mock invoice")
|
return channeldb.Invoice{}, errors.New("can't find mock invoice")
|
||||||
}
|
}
|
||||||
|
|
||||||
return invoice, nil
|
return invoice, nil
|
||||||
@ -428,11 +428,12 @@ func (i *mockInvoiceRegistry) SettleInvoice(rhash chainhash.Hash) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
invoice.Terms.Settled = true
|
invoice.Terms.Settled = true
|
||||||
|
i.invoices[rhash] = invoice
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *mockInvoiceRegistry) AddInvoice(invoice *channeldb.Invoice) error {
|
func (i *mockInvoiceRegistry) AddInvoice(invoice channeldb.Invoice) error {
|
||||||
i.Lock()
|
i.Lock()
|
||||||
defer i.Unlock()
|
defer i.Unlock()
|
||||||
|
|
||||||
|
@ -549,7 +549,7 @@ func (n *threeHopNetwork) makePayment(sendingPeer, receivingPeer Peer,
|
|||||||
rhash = fastsha256.Sum256(invoice.Terms.PaymentPreimage[:])
|
rhash = fastsha256.Sum256(invoice.Terms.PaymentPreimage[:])
|
||||||
|
|
||||||
// Check who is last in the route and add invoice to server registry.
|
// Check who is last in the route and add invoice to server registry.
|
||||||
if err := receiver.registry.AddInvoice(invoice); err != nil {
|
if err := receiver.registry.AddInvoice(*invoice); err != nil {
|
||||||
paymentErr <- err
|
paymentErr <- err
|
||||||
return &paymentResponse{
|
return &paymentResponse{
|
||||||
rhash: rhash,
|
rhash: rhash,
|
||||||
|
@ -98,7 +98,7 @@ func (i *invoiceRegistry) AddInvoice(invoice *channeldb.Invoice) error {
|
|||||||
// lookupInvoice looks up an invoice by its payment hash (R-Hash), if found
|
// lookupInvoice looks up an invoice by its payment hash (R-Hash), if found
|
||||||
// then we're able to pull the funds pending within an HTLC.
|
// then we're able to pull the funds pending within an HTLC.
|
||||||
// TODO(roasbeef): ignore if settled?
|
// TODO(roasbeef): ignore if settled?
|
||||||
func (i *invoiceRegistry) LookupInvoice(rHash chainhash.Hash) (*channeldb.Invoice, error) {
|
func (i *invoiceRegistry) LookupInvoice(rHash chainhash.Hash) (channeldb.Invoice, error) {
|
||||||
// First check the in-memory debug invoice index to see if this is an
|
// First check the in-memory debug invoice index to see if this is an
|
||||||
// existing invoice added for debugging.
|
// existing invoice added for debugging.
|
||||||
i.RLock()
|
i.RLock()
|
||||||
@ -107,12 +107,17 @@ func (i *invoiceRegistry) LookupInvoice(rHash chainhash.Hash) (*channeldb.Invoic
|
|||||||
|
|
||||||
// If found, then simply return the invoice directly.
|
// If found, then simply return the invoice directly.
|
||||||
if ok {
|
if ok {
|
||||||
return invoice, nil
|
return *invoice, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Otherwise, we'll check the database to see if there's an existing
|
// Otherwise, we'll check the database to see if there's an existing
|
||||||
// matching invoice.
|
// matching invoice.
|
||||||
return i.cdb.LookupInvoice(rHash)
|
invoice, err := i.cdb.LookupInvoice(rHash)
|
||||||
|
if err != nil {
|
||||||
|
return channeldb.Invoice{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return *invoice, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SettleInvoice attempts to mark an invoice as settled. If the invoice is a
|
// SettleInvoice attempts to mark an invoice as settled. If the invoice is a
|
||||||
|
@ -3359,7 +3359,10 @@ func TestChanSyncUnableToSync(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestChanAvailableBandwidth...
|
// TestChanAvailableBandwidth tests the accuracy of the AvailableBalance()
|
||||||
|
// method. The value returned from this message should reflect the value
|
||||||
|
// returned within the commitment state of a channel after the transition is
|
||||||
|
// initiated.
|
||||||
func TestChanAvailableBandwidth(t *testing.T) {
|
func TestChanAvailableBandwidth(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
@ -2029,7 +2029,7 @@ func (r *rpcServer) LookupInvoice(ctx context.Context,
|
|||||||
return spew.Sdump(invoice)
|
return spew.Sdump(invoice)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
rpcInvoice, err := createRPCInvoice(invoice)
|
rpcInvoice, err := createRPCInvoice(&invoice)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user