lnd.xprv/invoices/test_utils_test.go
Andras Banki-Horvath 7024f36a76 general: adding the Clock interface to aid testing
This commit adds Clock and DefaultClock and moves the private
invoices.testClock under the clock package while adding basic
unit tests for it.
Clock is an interface currently encapsulating Now() and TickAfter().
It can be added as an external dependency to any class. This way
tests can stub out time.Now() or time.After().

The DefaultClock class simply returns the real time.Now() and
time.After().
2019-12-13 16:52:22 +01:00

232 lines
4.7 KiB
Go

package invoices
import (
"encoding/hex"
"fmt"
"io/ioutil"
"os"
"runtime/pprof"
"testing"
"time"
"github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/chaincfg"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/zpay32"
)
type mockPayload struct {
mpp *record.MPP
}
func (p *mockPayload) MultiPath() *record.MPP {
return p.mpp
}
func (p *mockPayload) CustomRecords() record.CustomSet {
return make(record.CustomSet)
}
var (
testTimeout = 5 * time.Second
testTime = time.Date(2018, time.February, 2, 14, 0, 0, 0, time.UTC)
testInvoicePreimage = lntypes.Preimage{1}
testInvoicePaymentHash = testInvoicePreimage.Hash()
testHtlcExpiry = uint32(5)
testInvoiceCltvDelta = uint32(4)
testFinalCltvRejectDelta = int32(4)
testCurrentHeight = int32(1)
testPrivKeyBytes, _ = hex.DecodeString(
"e126f68f7eafcc8b74f54d269fe206be715000f94dac067d1c04a8ca3b2db734")
testPrivKey, testPubKey = btcec.PrivKeyFromBytes(
btcec.S256(), testPrivKeyBytes)
testInvoiceDescription = "coffee"
testInvoiceAmount = lnwire.MilliSatoshi(100000)
testNetParams = &chaincfg.MainNetParams
testMessageSigner = zpay32.MessageSigner{
SignCompact: func(hash []byte) ([]byte, error) {
sig, err := btcec.SignCompact(btcec.S256(), testPrivKey, hash, true)
if err != nil {
return nil, fmt.Errorf("can't sign the message: %v", err)
}
return sig, nil
},
}
testFeatures = lnwire.NewFeatureVector(
nil, lnwire.Features,
)
testPayload = &mockPayload{}
)
var (
testInvoiceAmt = lnwire.MilliSatoshi(100000)
testInvoice = &channeldb.Invoice{
Terms: channeldb.ContractTerm{
PaymentPreimage: testInvoicePreimage,
Value: testInvoiceAmt,
Features: testFeatures,
},
}
testHodlInvoice = &channeldb.Invoice{
Terms: channeldb.ContractTerm{
PaymentPreimage: channeldb.UnknownPreimage,
Value: testInvoiceAmt,
Features: testFeatures,
},
}
)
func newTestChannelDB() (*channeldb.DB, func(), error) {
// First, create a temporary directory to be used for the duration of
// this test.
tempDirName, err := ioutil.TempDir("", "channeldb")
if err != nil {
return nil, nil, err
}
// Next, create channeldb for the first time.
cdb, err := channeldb.Open(tempDirName)
if err != nil {
os.RemoveAll(tempDirName)
return nil, nil, err
}
cleanUp := func() {
cdb.Close()
os.RemoveAll(tempDirName)
}
return cdb, cleanUp, nil
}
type testContext struct {
registry *InvoiceRegistry
clock *clock.TestClock
cleanup func()
t *testing.T
}
func newTestContext(t *testing.T) *testContext {
clock := clock.NewTestClock(testTime)
cdb, cleanup, err := newTestChannelDB()
if err != nil {
t.Fatal(err)
}
cdb.Now = clock.Now
// Instantiate and start the invoice ctx.registry.
cfg := RegistryConfig{
FinalCltvRejectDelta: testFinalCltvRejectDelta,
HtlcHoldDuration: 30 * time.Second,
Clock: clock,
}
registry := NewRegistry(cdb, &cfg)
err = registry.Start()
if err != nil {
cleanup()
t.Fatal(err)
}
ctx := testContext{
registry: registry,
clock: clock,
t: t,
cleanup: func() {
registry.Stop()
cleanup()
},
}
return &ctx
}
func getCircuitKey(htlcID uint64) channeldb.CircuitKey {
return channeldb.CircuitKey{
ChanID: lnwire.ShortChannelID{
BlockHeight: 1, TxIndex: 2, TxPosition: 3,
},
HtlcID: htlcID,
}
}
func newTestInvoice(t *testing.T,
timestamp time.Time, expiry time.Duration) *channeldb.Invoice {
if expiry == 0 {
expiry = time.Hour
}
rawInvoice, err := zpay32.NewInvoice(
testNetParams,
testInvoicePaymentHash,
timestamp,
zpay32.Amount(testInvoiceAmount),
zpay32.Description(testInvoiceDescription),
zpay32.Expiry(expiry))
if err != nil {
t.Fatalf("Error while creating new invoice: %v", err)
}
paymentRequest, err := rawInvoice.Encode(testMessageSigner)
if err != nil {
t.Fatalf("Error while encoding payment request: %v", err)
}
return &channeldb.Invoice{
Terms: channeldb.ContractTerm{
PaymentPreimage: testInvoicePreimage,
Value: testInvoiceAmount,
Expiry: expiry,
Features: testFeatures,
},
PaymentRequest: []byte(paymentRequest),
CreationDate: timestamp,
}
}
// timeout implements a test level timeout.
func timeout() func() {
done := make(chan struct{})
go func() {
select {
case <-time.After(5 * time.Second):
err := pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
if err != nil {
panic(fmt.Sprintf("error writing to std out after timeout: %v", err))
}
panic("timeout")
case <-done:
}
}()
return func() {
close(done)
}
}