htlcswitch/test: hodl invoice restart test

This commit adds a test that covers the hodl invoice behaviour after a
link restart.
This commit is contained in:
Joost Jager 2019-04-08 13:10:32 +02:00
parent e5ead599cc
commit 570f9ca57e
No known key found for this signature in database
GPG Key ID: A61B9D4C393C59C7
2 changed files with 143 additions and 37 deletions

@ -4026,16 +4026,10 @@ func (h *persistentLinkHarness) restart(restartSwitch bool,
// First, remove the link from the switch. // First, remove the link from the switch.
h.coreLink.cfg.Switch.RemoveLink(h.link.ChanID()) h.coreLink.cfg.Switch.RemoveLink(h.link.ChanID())
var htlcSwitch *Switch
if restartSwitch { if restartSwitch {
// If a switch restart is requested, we will stop it and // If a switch restart is requested, we will stop it. It will be
// leave htlcSwitch nil, which will trigger the creation // reinstantiated in restartLink.
// of a fresh instance in restartLink.
h.coreLink.cfg.Switch.Stop() h.coreLink.cfg.Switch.Stop()
} else {
// Otherwise, we capture the switch's reference so that
// it can be carried over to the restarted link.
htlcSwitch = h.coreLink.cfg.Switch
} }
// Since our in-memory state may have diverged from our persistent // Since our in-memory state may have diverged from our persistent
@ -4051,8 +4045,8 @@ func (h *persistentLinkHarness) restart(restartSwitch bool,
// adding the link to an existing switch, or creating a new one using // adding the link to an existing switch, or creating a new one using
// the database owned by the link. // the database owned by the link.
var cleanUp func() var cleanUp func()
h.link, h.batchTicker, cleanUp, err = restartLink( h.link, h.batchTicker, cleanUp, err = h.restartLink(
h.channel, htlcSwitch, hodlFlags, h.channel, restartSwitch, hodlFlags,
) )
if err != nil { if err != nil {
h.t.Fatalf("unable to restart alicelink: %v", err) h.t.Fatalf("unable to restart alicelink: %v", err)
@ -4128,8 +4122,10 @@ func (h *persistentLinkHarness) trySignNextCommitment() {
// restartLink creates a new channel link from the given channel state, and adds // restartLink creates a new channel link from the given channel state, and adds
// to an htlcswitch. If none is provided by the caller, a new one will be // to an htlcswitch. If none is provided by the caller, a new one will be
// created using Alice's database. // created using Alice's database.
func restartLink(aliceChannel *lnwallet.LightningChannel, aliceSwitch *Switch, func (h *persistentLinkHarness) restartLink(
hodlFlags []hodl.Flag) (ChannelLink, chan time.Time, func(), error) { aliceChannel *lnwallet.LightningChannel, restartSwitch bool,
hodlFlags []hodl.Flag) (
ChannelLink, chan time.Time, func(), error) {
var ( var (
decoder = newMockIteratorDecoder() decoder = newMockIteratorDecoder()
@ -4145,14 +4141,12 @@ func restartLink(aliceChannel *lnwallet.LightningChannel, aliceSwitch *Switch,
TimeLockDelta: 6, TimeLockDelta: 6,
} }
invoiceRegistry = newMockRegistry(globalPolicy.TimeLockDelta)
pCache = newMockPreimageCache() pCache = newMockPreimageCache()
) )
aliceDb := aliceChannel.State().Db aliceDb := aliceChannel.State().Db
aliceSwitch := h.coreLink.cfg.Switch
if aliceSwitch == nil { if restartSwitch {
var err error var err error
aliceSwitch, err = initSwitchWithDB(testStartingHeight, aliceDb) aliceSwitch, err = initSwitchWithDB(testStartingHeight, aliceDb)
if err != nil { if err != nil {
@ -4182,7 +4176,7 @@ func restartLink(aliceChannel *lnwallet.LightningChannel, aliceSwitch *Switch,
UpdateContractSignals: func(*contractcourt.ContractSignals) error { UpdateContractSignals: func(*contractcourt.ContractSignals) error {
return nil return nil
}, },
Registry: invoiceRegistry, Registry: h.coreLink.cfg.Registry,
ChainEvents: &contractcourt.ChainEventSubscription{}, ChainEvents: &contractcourt.ChainEventSubscription{},
BatchTicker: bticker, BatchTicker: bticker,
FwdPkgGCTicker: ticker.New(5 * time.Second), FwdPkgGCTicker: ticker.New(5 * time.Second),
@ -4247,12 +4241,13 @@ func generateHtlcAndInvoice(t *testing.T,
t.Helper() t.Helper()
htlcAmt := lnwire.NewMSatFromSatoshis(10000) htlcAmt := lnwire.NewMSatFromSatoshis(10000)
htlcExpiry := testStartingHeight + testInvoiceCltvExpiry
hops := []ForwardingInfo{ hops := []ForwardingInfo{
{ {
Network: BitcoinHop, Network: BitcoinHop,
NextHop: exitHop, NextHop: exitHop,
AmountToForward: htlcAmt, AmountToForward: htlcAmt,
OutgoingCTLV: 144, OutgoingCTLV: uint32(htlcExpiry),
}, },
} }
blob, err := generateRoute(hops...) blob, err := generateRoute(hops...)
@ -4260,8 +4255,9 @@ func generateHtlcAndInvoice(t *testing.T,
t.Fatalf("unable to generate route: %v", err) t.Fatalf("unable to generate route: %v", err)
} }
invoice, htlc, err := generatePayment(htlcAmt, htlcAmt, 144, invoice, htlc, err := generatePayment(
blob) htlcAmt, htlcAmt, uint32(htlcExpiry), blob,
)
if err != nil { if err != nil {
t.Fatalf("unable to create payment: %v", err) t.Fatalf("unable to create payment: %v", err)
} }
@ -5641,6 +5637,8 @@ type hodlInvoiceTestCtx struct {
amount lnwire.MilliSatoshi amount lnwire.MilliSatoshi
errChan chan error errChan chan error
restoreBob func() (*lnwallet.LightningChannel, error)
cleanUp func() cleanUp func()
} }
@ -5720,6 +5718,7 @@ func newHodlInvoiceTestCtx(t *testing.T) (*hodlInvoiceTestCtx, error) {
hash: hash, hash: hash,
amount: amount, amount: amount,
errChan: errChan, errChan: errChan,
restoreBob: bob.restore,
cleanUp: func() { cleanUp: func() {
cleanUp() cleanUp()
@ -5732,6 +5731,8 @@ func newHodlInvoiceTestCtx(t *testing.T) (*hodlInvoiceTestCtx, error) {
func TestChannelLinkHoldInvoiceSettle(t *testing.T) { func TestChannelLinkHoldInvoiceSettle(t *testing.T) {
t.Parallel() t.Parallel()
defer timeout(t)()
ctx, err := newHodlInvoiceTestCtx(t) ctx, err := newHodlInvoiceTestCtx(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -5744,13 +5745,9 @@ func TestChannelLinkHoldInvoiceSettle(t *testing.T) {
} }
// Wait for payment to succeed. // Wait for payment to succeed.
select { err = <-ctx.errChan
case err := <-ctx.errChan: if err != nil {
if err != nil { t.Fatal(err)
t.Fatal(err)
}
case <-time.After(5 * time.Second):
t.Fatal("timeout")
} }
// Wait for Bob to receive the revocation. // Wait for Bob to receive the revocation.
@ -5774,6 +5771,8 @@ func TestChannelLinkHoldInvoiceSettle(t *testing.T) {
func TestChannelLinkHoldInvoiceCancel(t *testing.T) { func TestChannelLinkHoldInvoiceCancel(t *testing.T) {
t.Parallel() t.Parallel()
defer timeout(t)()
ctx, err := newHodlInvoiceTestCtx(t) ctx, err := newHodlInvoiceTestCtx(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -5786,15 +5785,102 @@ func TestChannelLinkHoldInvoiceCancel(t *testing.T) {
} }
// Wait for payment to succeed. // Wait for payment to succeed.
select { err = <-ctx.errChan
case err := <-ctx.errChan: if !strings.Contains(err.Error(),
if !strings.Contains(err.Error(), lnwire.CodeUnknownPaymentHash.String()) {
lnwire.CodeUnknownPaymentHash.String()) {
t.Fatal("expected unknown payment hash") t.Fatal("expected unknown payment hash")
} }
case <-time.After(5 * time.Second): }
t.Fatal("timeout")
// TestChannelLinkHoldInvoiceRestart asserts hodl htlcs are held after blocks
// are mined and the link is restarted. The initial expiry checks should not
// apply to hodl htlcs after restart.
func TestChannelLinkHoldInvoiceRestart(t *testing.T) {
t.Parallel()
defer timeout(t)()
const (
chanAmt = btcutil.SatoshiPerBitcoin * 5
)
// We'll start by creating a new link with our chanAmt (5 BTC). We will
// only be testing Alice's behavior, so the reference to Bob's channel
// state is unnecessary.
aliceLink, bobChannel, _, start, cleanUp, restore, err :=
newSingleLinkTestHarness(chanAmt, 0)
if err != nil {
t.Fatalf("unable to create link: %v", err)
}
defer cleanUp()
alice := newPersistentLinkHarness(
t, aliceLink, nil, restore,
)
if err := start(); err != nil {
t.Fatalf("unable to start test harness: %v", err)
}
var (
coreLink = alice.coreLink
registry = coreLink.cfg.Registry.(*mockInvoiceRegistry)
)
registry.settleChan = make(chan lntypes.Hash)
htlc, invoice := generateHtlcAndInvoice(t, 0)
// Convert into a hodl invoice and save the preimage for later.
preimage := invoice.Terms.PaymentPreimage
invoice.Terms.PaymentPreimage = channeldb.UnknownPreimage
// We must add the invoice to the registry, such that Alice
// expects this payment.
err = registry.AddInvoice(
*invoice, htlc.PaymentHash,
)
if err != nil {
t.Fatalf("unable to add invoice to registry: %v", err)
}
// Lock in htlc paying the hodl invoice.
sendHtlcBobToAlice(t, alice.link, bobChannel, htlc)
sendCommitSigBobToAlice(t, alice.link, bobChannel, 1)
receiveRevAndAckAliceToBob(t, alice.msgs, alice.link, bobChannel)
receiveCommitSigAliceToBob(t, alice.msgs, alice.link, bobChannel, 1)
sendRevAndAckBobToAlice(t, alice.link, bobChannel)
// We expect a call to the invoice registry to notify the arrival of the
// htlc.
<-registry.settleChan
// Increase block height. This height will be retrieved by the link
// after restart.
coreLink.cfg.Switch.bestHeight++
// Restart link.
alice.restart(false)
// Expect htlc to be reprocessed.
<-registry.settleChan
// Settle the invoice with the preimage.
registry.SettleHodlInvoice(preimage)
// Expect alice to send a settle and commitsig message to bob.
receiveSettleAliceToBob(t, alice.msgs, alice.link, bobChannel)
receiveCommitSigAliceToBob(t, alice.msgs, alice.link, bobChannel, 0)
// Stop the link
alice.link.Stop()
// Check that no unexpected messages were sent.
select {
case msg := <-alice.msgs:
t.Fatalf("did not expect message %T", msg)
default:
} }
} }

@ -11,6 +11,7 @@ import (
"net" "net"
"os" "os"
"runtime" "runtime"
"runtime/pprof"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
@ -91,6 +92,8 @@ var (
}, },
LockTime: 5, LockTime: 5,
} }
testBatchTimeout = 50 * time.Millisecond
) )
var idSeqNum uint64 var idSeqNum uint64
@ -1051,7 +1054,6 @@ func (h *hopNetwork) createChannelLink(server, peer *mockServer,
decoder *mockIteratorDecoder) (ChannelLink, error) { decoder *mockIteratorDecoder) (ChannelLink, error) {
const ( const (
batchTimeout = 50 * time.Millisecond
fwdPkgTimeout = 15 * time.Second fwdPkgTimeout = 15 * time.Second
minFeeUpdateTimeout = 30 * time.Minute minFeeUpdateTimeout = 30 * time.Minute
maxFeeUpdateTimeout = 40 * time.Minute maxFeeUpdateTimeout = 40 * time.Minute
@ -1079,7 +1081,7 @@ func (h *hopNetwork) createChannelLink(server, peer *mockServer,
ChainEvents: &contractcourt.ChainEventSubscription{}, ChainEvents: &contractcourt.ChainEventSubscription{},
SyncStates: true, SyncStates: true,
BatchSize: 10, BatchSize: 10,
BatchTicker: ticker.NewForce(batchTimeout), BatchTicker: ticker.NewForce(testBatchTimeout),
FwdPkgGCTicker: ticker.NewForce(fwdPkgTimeout), FwdPkgGCTicker: ticker.NewForce(fwdPkgTimeout),
MinFeeUpdateTimeout: minFeeUpdateTimeout, MinFeeUpdateTimeout: minFeeUpdateTimeout,
MaxFeeUpdateTimeout: maxFeeUpdateTimeout, MaxFeeUpdateTimeout: maxFeeUpdateTimeout,
@ -1256,3 +1258,21 @@ func (n *twoHopNetwork) makeHoldPayment(sendingPeer, receivingPeer lnpeer.Peer,
return paymentErr return paymentErr
} }
// timeout implements a test level timeout.
func timeout(t *testing.T) func() {
done := make(chan struct{})
go func() {
select {
case <-time.After(5 * time.Second):
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
panic("test timeout")
case <-done:
}
}()
return func() {
close(done)
}
}