invoices: extract invoice decoding from registry

Previously it was difficult to use the invoice registry in unit tests,
because it used zpay32 to decode the invoice. For that to succeed, a
valid signature is required on the payment request.

This commit injects the decode dependency on a different level so that
it is easier to mock.
This commit is contained in:
Joost Jager 2019-02-20 11:44:47 +01:00
parent c23bb5b3f1
commit 3b5c2f44c6
No known key found for this signature in database
GPG Key ID: A61B9D4C393C59C7
3 changed files with 29 additions and 13 deletions

@ -7,14 +7,12 @@ import (
"sync/atomic"
"time"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcutil"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/queue"
"github.com/lightningnetwork/lnd/zpay32"
)
var (
@ -52,7 +50,9 @@ type InvoiceRegistry struct {
// that *all* nodes are able to fully settle.
debugInvoices map[lntypes.Hash]*channeldb.Invoice
activeNetParams *chaincfg.Params
// decodeFinalCltvExpiry is a function used to decode the final expiry
// value from the payment request.
decodeFinalCltvExpiry func(invoice string) (uint32, error)
wg sync.WaitGroup
quit chan struct{}
@ -62,8 +62,8 @@ type InvoiceRegistry struct {
// wraps the persistent on-disk invoice storage with an additional in-memory
// layer. The in-memory layer is in place such that debug invoices can be added
// which are volatile yet available system wide within the daemon.
func NewRegistry(cdb *channeldb.DB,
activeNetParams *chaincfg.Params) *InvoiceRegistry {
func NewRegistry(cdb *channeldb.DB, decodeFinalCltvExpiry func(invoice string) (
uint32, error)) *InvoiceRegistry {
return &InvoiceRegistry{
cdb: cdb,
@ -74,7 +74,7 @@ func NewRegistry(cdb *channeldb.DB,
newSingleSubscriptions: make(chan *SingleInvoiceSubscription),
subscriptionCancels: make(chan uint32),
invoiceEvents: make(chan *invoiceEvent, 100),
activeNetParams: activeNetParams,
decodeFinalCltvExpiry: decodeFinalCltvExpiry,
quit: make(chan struct{}),
}
}
@ -430,14 +430,12 @@ func (i *InvoiceRegistry) LookupInvoice(rHash lntypes.Hash) (channeldb.Invoice,
return channeldb.Invoice{}, 0, err
}
payReq, err := zpay32.Decode(
string(invoice.PaymentRequest), i.activeNetParams,
)
expiry, err := i.decodeFinalCltvExpiry(string(invoice.PaymentRequest))
if err != nil {
return channeldb.Invoice{}, 0, err
}
return invoice, uint32(payReq.MinFinalCLTVExpiry()), nil
return invoice, expiry, nil
}
// SettleInvoice attempts to mark an invoice as settled. If the invoice is a

@ -10,6 +10,7 @@ import (
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/zpay32"
)
var (
@ -28,6 +29,14 @@ var (
testPayReq = "lnbc500u1pwywxzwpp5nd2u9xzq02t0tuf2654as7vma42lwkcjptx4yzfq0umq4swpa7cqdqqcqzysmlpc9ewnydr8rr8dnltyxphdyf6mcqrsd6dml8zajtyhwe6a45d807kxtmzayuf0hh2d9tn478ecxkecdg7c5g85pntupug5kakm7xcpn63zqk"
)
func decodeExpiry(payReq string) (uint32, error) {
invoice, err := zpay32.Decode(payReq, &chaincfg.MainNetParams)
if err != nil {
return 0, err
}
return uint32(invoice.MinFinalCLTVExpiry()), nil
}
// TestSettleInvoice tests settling of an invoice and related notifications.
func TestSettleInvoice(t *testing.T) {
cdb, cleanup, err := newDB()
@ -37,7 +46,7 @@ func TestSettleInvoice(t *testing.T) {
defer cleanup()
// Instantiate and start the invoice registry.
registry := NewRegistry(cdb, &chaincfg.MainNetParams)
registry := NewRegistry(cdb, decodeExpiry)
err = registry.Start()
if err != nil {
@ -167,7 +176,7 @@ func TestCancelInvoice(t *testing.T) {
defer cleanup()
// Instantiate and start the invoice registry.
registry := NewRegistry(cdb, &chaincfg.MainNetParams)
registry := NewRegistry(cdb, decodeExpiry)
err = registry.Start()
if err != nil {

@ -46,6 +46,7 @@ import (
"github.com/lightningnetwork/lnd/sweep"
"github.com/lightningnetwork/lnd/ticker"
"github.com/lightningnetwork/lnd/tor"
"github.com/lightningnetwork/lnd/zpay32"
)
const (
@ -284,6 +285,14 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl,
readBufferPool, runtime.NumCPU(), pool.DefaultWorkerTimeout,
)
decodeFinalCltvExpiry := func(payReq string) (uint32, error) {
invoice, err := zpay32.Decode(payReq, activeNetParams.Params)
if err != nil {
return 0, err
}
return uint32(invoice.MinFinalCLTVExpiry()), nil
}
s := &server{
chanDB: chanDB,
cc: cc,
@ -291,7 +300,7 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl,
writePool: writePool,
readPool: readPool,
invoices: invoices.NewRegistry(chanDB, activeNetParams.Params),
invoices: invoices.NewRegistry(chanDB, decodeFinalCltvExpiry),
channelNotifier: channelnotifier.New(chanDB),