contractcourt: create new channel arb test context struct

In this commit, we create a new channel arb test context struct as the
current `createTestChannelArbitrator` has several return parameters, and
upcoming changes will likely at first glance need to add one or more
additional parameters. Rather than extend the existing set of return
parameters, we opt to instead create this struct that wraps the existing
state.

Along the way we add several new utility methods to this context, and
use them in the existing tests where applicable:
  * `AssertStateTransitions`
  * `AssertState`
  * `Restart`
  * `CleanUp`
This commit is contained in:
Olaoluwa Osuntokun 2019-09-24 19:19:53 -07:00
parent 46e0117a4f
commit c3bf8d2054
No known key found for this signature in database
GPG Key ID: BC13F65E2DC84465

@ -3,12 +3,16 @@ package contractcourt
import ( import (
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"os"
"path/filepath"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/coreos/bbolt"
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
@ -127,6 +131,25 @@ func (b *mockArbitratorLog) WipeHistory() error {
return nil return nil
} }
// testArbLog is a wrapper around an existing (ideally fully concrete
// ArbitratorLog) that lets us intercept certain calls like transitioning to a
// new state.
type testArbLog struct {
ArbitratorLog
newStates chan ArbitratorState
}
func (t *testArbLog) CommitState(s ArbitratorState) error {
if err := t.ArbitratorLog.CommitState(s); err != nil {
return err
}
t.newStates <- s
return nil
}
type mockChainIO struct{} type mockChainIO struct{}
var _ lnwallet.BlockChainIO = (*mockChainIO)(nil) var _ lnwallet.BlockChainIO = (*mockChainIO)(nil)
@ -148,9 +171,101 @@ func (*mockChainIO) GetBlock(blockHash *chainhash.Hash) (*wire.MsgBlock, error)
return nil, nil return nil, nil
} }
func createTestChannelArbitrator(log ArbitratorLog) (*ChannelArbitrator, type chanArbTestCtx struct {
chan struct{}, chan []ResolutionMsg, chan *chainntnfs.BlockEpoch, error) { t *testing.T
chanArb *ChannelArbitrator
cleanUp func()
resolvedChan chan struct{}
blockEpochs chan *chainntnfs.BlockEpoch
incubationRequests chan struct{}
resolutions chan []ResolutionMsg
log ArbitratorLog
}
func (c *chanArbTestCtx) CleanUp() {
if err := c.chanArb.Stop(); err != nil {
c.t.Fatalf("unable to stop chan arb: %v", err)
}
if c.cleanUp != nil {
c.cleanUp()
}
}
// AssertStateTransitions asserts that the state machine steps through the
// passed states in order.
func (c *chanArbTestCtx) AssertStateTransitions(expectedStates ...ArbitratorState) {
c.t.Helper()
var newStatesChan chan ArbitratorState
switch log := c.log.(type) {
case *mockArbitratorLog:
newStatesChan = log.newStates
case *testArbLog:
newStatesChan = log.newStates
default:
c.t.Fatalf("unable to assert state transitions with %T", log)
}
for _, exp := range expectedStates {
var state ArbitratorState
select {
case state = <-newStatesChan:
case <-time.After(5 * time.Second):
c.t.Fatalf("new state not received")
}
if state != exp {
c.t.Fatalf("expected new state %v, got %v", exp, state)
}
}
}
// AssertState checks that the ChannelArbitrator is in the state we expect it
// to be.
func (c *chanArbTestCtx) AssertState(expected ArbitratorState) {
if c.chanArb.state != expected {
c.t.Fatalf("expected state %v, was %v", expected, c.chanArb.state)
}
}
// Restart simulates a clean restart of the channel arbitrator, forcing it to
// walk through it's recovery logic. If this function returns nil, then a
// restart was successful. Note that the restart process keeps the log in
// place, in order to simulate proper persistence of the log. The caller can
// optionally provide a restart closure which will be executed before the
// resolver is started again, but after it is created.
func (c *chanArbTestCtx) Restart(restartClosure func(*chanArbTestCtx)) (*chanArbTestCtx, error) {
if err := c.chanArb.Stop(); err != nil {
return nil, err
}
newCtx, err := createTestChannelArbitrator(c.t, c.log)
if err != nil {
return nil, err
}
if restartClosure != nil {
restartClosure(newCtx)
}
if err := newCtx.chanArb.Start(); err != nil {
return nil, err
}
return newCtx, nil
}
func createTestChannelArbitrator(t *testing.T, log ArbitratorLog) (*chanArbTestCtx, error) {
blockEpochs := make(chan *chainntnfs.BlockEpoch) blockEpochs := make(chan *chainntnfs.BlockEpoch)
blockEpoch := &chainntnfs.BlockEpochEvent{ blockEpoch := &chainntnfs.BlockEpochEvent{
Epochs: blockEpochs, Epochs: blockEpochs,
@ -167,6 +282,7 @@ func createTestChannelArbitrator(log ArbitratorLog) (*ChannelArbitrator,
} }
resolutionChan := make(chan []ResolutionMsg, 1) resolutionChan := make(chan []ResolutionMsg, 1)
incubateChan := make(chan struct{})
chainIO := &mockChainIO{} chainIO := &mockChainIO{}
chainArbCfg := ChainArbitratorConfig{ chainArbCfg := ChainArbitratorConfig{
@ -188,6 +304,8 @@ func createTestChannelArbitrator(log ArbitratorLog) (*ChannelArbitrator,
IncubateOutputs: func(wire.OutPoint, *lnwallet.CommitOutputResolution, IncubateOutputs: func(wire.OutPoint, *lnwallet.CommitOutputResolution,
*lnwallet.OutgoingHtlcResolution, *lnwallet.OutgoingHtlcResolution,
*lnwallet.IncomingHtlcResolution, uint32) error { *lnwallet.IncomingHtlcResolution, uint32) error {
incubateChan <- struct{}{}
return nil return nil
}, },
} }
@ -224,17 +342,49 @@ func createTestChannelArbitrator(log ArbitratorLog) (*ChannelArbitrator,
ChainEvents: chanEvents, ChainEvents: chanEvents,
} }
htlcSets := make(map[HtlcSetKey]htlcSet) var cleanUp func()
return NewChannelArbitrator(arbCfg, htlcSets, log), resolvedChan, if log == nil {
resolutionChan, blockEpochs, nil dbDir, err := ioutil.TempDir("", "chanArb")
} if err != nil {
return nil, err
}
dbPath := filepath.Join(dbDir, "testdb")
db, err := bbolt.Open(dbPath, 0600, nil)
if err != nil {
return nil, err
}
// assertState checks that the ChannelArbitrator is in the state we expect it backingLog, err := newBoltArbitratorLog(
// to be. db, arbCfg, chainhash.Hash{}, chanPoint,
func assertState(t *testing.T, c *ChannelArbitrator, expected ArbitratorState) { )
if c.state != expected { if err != nil {
t.Fatalf("expected state %v, was %v", expected, c.state) return nil, err
}
cleanUp = func() {
db.Close()
os.RemoveAll(dbDir)
}
log = &testArbLog{
ArbitratorLog: backingLog,
newStates: make(chan ArbitratorState),
}
} }
htlcSets := make(map[HtlcSetKey]htlcSet)
chanArb := NewChannelArbitrator(arbCfg, htlcSets, log)
return &chanArbTestCtx{
t: t,
chanArb: chanArb,
cleanUp: cleanUp,
resolvedChan: resolvedChan,
resolutions: resolutionChan,
blockEpochs: blockEpochs,
log: log,
incubationRequests: incubateChan,
}, nil
} }
// TestChannelArbitratorCooperativeClose tests that the ChannelArbitertor // TestChannelArbitratorCooperativeClose tests that the ChannelArbitertor
@ -246,22 +396,26 @@ func TestChannelArbitratorCooperativeClose(t *testing.T) {
newStates: make(chan ArbitratorState, 5), newStates: make(chan ArbitratorState, 5),
} }
chanArb, resolved, _, _, err := createTestChannelArbitrator(log) chanArbCtx, err := createTestChannelArbitrator(t, log)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
if err := chanArb.Start(); err != nil { if err := chanArbCtx.chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err) t.Fatalf("unable to start ChannelArbitrator: %v", err)
} }
defer chanArb.Stop() defer func() {
if err := chanArbCtx.chanArb.Stop(); err != nil {
t.Fatalf("unable to stop chan arb: %v", err)
}
}()
// It should start out in the default state. // It should start out in the default state.
assertState(t, chanArb, StateDefault) chanArbCtx.AssertState(StateDefault)
// We set up a channel to detect when MarkChannelClosed is called. // We set up a channel to detect when MarkChannelClosed is called.
closeInfos := make(chan *channeldb.ChannelCloseSummary) closeInfos := make(chan *channeldb.ChannelCloseSummary)
chanArb.cfg.MarkChannelClosed = func( chanArbCtx.chanArb.cfg.MarkChannelClosed = func(
closeInfo *channeldb.ChannelCloseSummary) error { closeInfo *channeldb.ChannelCloseSummary) error {
closeInfos <- closeInfo closeInfos <- closeInfo
return nil return nil
@ -272,7 +426,7 @@ func TestChannelArbitratorCooperativeClose(t *testing.T) {
closeInfo := &CooperativeCloseInfo{ closeInfo := &CooperativeCloseInfo{
&channeldb.ChannelCloseSummary{}, &channeldb.ChannelCloseSummary{},
} }
chanArb.cfg.ChainEvents.CooperativeClosure <- closeInfo chanArbCtx.chanArb.cfg.ChainEvents.CooperativeClosure <- closeInfo
select { select {
case c := <-closeInfos: case c := <-closeInfos:
@ -285,31 +439,13 @@ func TestChannelArbitratorCooperativeClose(t *testing.T) {
// It should mark the channel as resolved. // It should mark the channel as resolved.
select { select {
case <-resolved: case <-chanArbCtx.resolvedChan:
// Expected. // Expected.
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
t.Fatalf("contract was not resolved") t.Fatalf("contract was not resolved")
} }
} }
func assertStateTransitions(t *testing.T, newStates <-chan ArbitratorState,
expectedStates ...ArbitratorState) {
t.Helper()
for _, exp := range expectedStates {
var state ArbitratorState
select {
case state = <-newStates:
case <-time.After(5 * time.Second):
t.Fatalf("new state not received")
}
if state != exp {
t.Fatalf("expected new state %v, got %v", exp, state)
}
}
}
// TestChannelArbitratorRemoteForceClose checks that the ChannelArbitrator goes // TestChannelArbitratorRemoteForceClose checks that the ChannelArbitrator goes
// through the expected states if a remote force close is observed in the // through the expected states if a remote force close is observed in the
// chain. // chain.
@ -319,10 +455,11 @@ func TestChannelArbitratorRemoteForceClose(t *testing.T) {
newStates: make(chan ArbitratorState, 5), newStates: make(chan ArbitratorState, 5),
} }
chanArb, resolved, _, _, err := createTestChannelArbitrator(log) chanArbCtx, err := createTestChannelArbitrator(t, log)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
chanArb := chanArbCtx.chanArb
if err := chanArb.Start(); err != nil { if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err) t.Fatalf("unable to start ChannelArbitrator: %v", err)
@ -330,7 +467,7 @@ func TestChannelArbitratorRemoteForceClose(t *testing.T) {
defer chanArb.Stop() defer chanArb.Stop()
// It should start out in the default state. // It should start out in the default state.
assertState(t, chanArb, StateDefault) chanArbCtx.AssertState(StateDefault)
// Send a remote force close event. // Send a remote force close event.
commitSpend := &chainntnfs.SpendDetail{ commitSpend := &chainntnfs.SpendDetail{
@ -351,13 +488,13 @@ func TestChannelArbitratorRemoteForceClose(t *testing.T) {
// It should transition StateDefault -> StateContractClosed -> // It should transition StateDefault -> StateContractClosed ->
// StateFullyResolved. // StateFullyResolved.
assertStateTransitions( chanArbCtx.AssertStateTransitions(
t, log.newStates, StateContractClosed, StateFullyResolved, StateContractClosed, StateFullyResolved,
) )
// It should also mark the channel as resolved. // It should also mark the channel as resolved.
select { select {
case <-resolved: case <-chanArbCtx.resolvedChan:
// Expected. // Expected.
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
t.Fatalf("contract was not resolved") t.Fatalf("contract was not resolved")
@ -373,10 +510,11 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) {
newStates: make(chan ArbitratorState, 5), newStates: make(chan ArbitratorState, 5),
} }
chanArb, resolved, _, _, err := createTestChannelArbitrator(log) chanArbCtx, err := createTestChannelArbitrator(t, log)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
chanArb := chanArbCtx.chanArb
if err := chanArb.Start(); err != nil { if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err) t.Fatalf("unable to start ChannelArbitrator: %v", err)
@ -384,7 +522,7 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) {
defer chanArb.Stop() defer chanArb.Stop()
// It should start out in the default state. // It should start out in the default state.
assertState(t, chanArb, StateDefault) chanArbCtx.AssertState(StateDefault)
// We create a channel we can use to pause the ChannelArbitrator at the // We create a channel we can use to pause the ChannelArbitrator at the
// point where it broadcasts the close tx, and check its state. // point where it broadcasts the close tx, and check its state.
@ -411,7 +549,7 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) {
} }
// It should transition to StateBroadcastCommit. // It should transition to StateBroadcastCommit.
assertStateTransitions(t, log.newStates, StateBroadcastCommit) chanArbCtx.AssertStateTransitions(StateBroadcastCommit)
// When it is broadcasting the force close, its state should be // When it is broadcasting the force close, its state should be
// StateBroadcastCommit. // StateBroadcastCommit.
@ -426,7 +564,7 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) {
// After broadcasting, transition should be to // After broadcasting, transition should be to
// StateCommitmentBroadcasted. // StateCommitmentBroadcasted.
assertStateTransitions(t, log.newStates, StateCommitmentBroadcasted) chanArbCtx.AssertStateTransitions(StateCommitmentBroadcasted)
select { select {
case <-respChan: case <-respChan:
@ -445,7 +583,7 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) {
// After broadcasting the close tx, it should be in state // After broadcasting the close tx, it should be in state
// StateCommitmentBroadcasted. // StateCommitmentBroadcasted.
assertState(t, chanArb, StateCommitmentBroadcasted) chanArbCtx.AssertState(StateCommitmentBroadcasted)
// Now notify about the local force close getting confirmed. // Now notify about the local force close getting confirmed.
chanArb.cfg.ChainEvents.LocalUnilateralClosure <- &LocalUnilateralCloseInfo{ chanArb.cfg.ChainEvents.LocalUnilateralClosure <- &LocalUnilateralCloseInfo{
@ -458,12 +596,11 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) {
} }
// It should transition StateContractClosed -> StateFullyResolved. // It should transition StateContractClosed -> StateFullyResolved.
assertStateTransitions(t, log.newStates, StateContractClosed, chanArbCtx.AssertStateTransitions(StateContractClosed, StateFullyResolved)
StateFullyResolved)
// It should also mark the channel as resolved. // It should also mark the channel as resolved.
select { select {
case <-resolved: case <-chanArbCtx.resolvedChan:
// Expected. // Expected.
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
t.Fatalf("contract was not resolved") t.Fatalf("contract was not resolved")
@ -479,10 +616,11 @@ func TestChannelArbitratorBreachClose(t *testing.T) {
newStates: make(chan ArbitratorState, 5), newStates: make(chan ArbitratorState, 5),
} }
chanArb, resolved, _, _, err := createTestChannelArbitrator(log) chanArbCtx, err := createTestChannelArbitrator(t, log)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
chanArb := chanArbCtx.chanArb
if err := chanArb.Start(); err != nil { if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err) t.Fatalf("unable to start ChannelArbitrator: %v", err)
@ -494,19 +632,19 @@ func TestChannelArbitratorBreachClose(t *testing.T) {
}() }()
// It should start out in the default state. // It should start out in the default state.
assertState(t, chanArb, StateDefault) chanArbCtx.AssertState(StateDefault)
// Send a breach close event. // Send a breach close event.
chanArb.cfg.ChainEvents.ContractBreach <- &lnwallet.BreachRetribution{} chanArb.cfg.ChainEvents.ContractBreach <- &lnwallet.BreachRetribution{}
// It should transition StateDefault -> StateFullyResolved. // It should transition StateDefault -> StateFullyResolved.
assertStateTransitions( chanArbCtx.AssertStateTransitions(
t, log.newStates, StateFullyResolved, StateFullyResolved,
) )
// It should also mark the channel as resolved. // It should also mark the channel as resolved.
select { select {
case <-resolved: case <-chanArbCtx.resolvedChan:
// Expected. // Expected.
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
t.Fatalf("contract was not resolved") t.Fatalf("contract was not resolved")
@ -523,14 +661,15 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) {
resolvers: make(map[ContractResolver]struct{}), resolvers: make(map[ContractResolver]struct{}),
} }
chanArb, resolved, resolutions, _, err := createTestChannelArbitrator( chanArbCtx, err := createTestChannelArbitrator(
arbLog, t, arbLog,
) )
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
incubateChan := make(chan struct{}) incubateChan := make(chan struct{})
chanArb := chanArbCtx.chanArb
chanArb.cfg.IncubateOutputs = func(_ wire.OutPoint, chanArb.cfg.IncubateOutputs = func(_ wire.OutPoint,
_ *lnwallet.CommitOutputResolution, _ *lnwallet.CommitOutputResolution,
_ *lnwallet.OutgoingHtlcResolution, _ *lnwallet.OutgoingHtlcResolution,
@ -599,8 +738,8 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) {
// The force close request should trigger broadcast of the commitment // The force close request should trigger broadcast of the commitment
// transaction. // transaction.
assertStateTransitions( chanArbCtx.AssertStateTransitions(
t, arbLog.newStates, StateBroadcastCommit, StateBroadcastCommit,
StateCommitmentBroadcasted, StateCommitmentBroadcasted,
) )
select { select {
@ -675,15 +814,15 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) {
}, },
} }
assertStateTransitions( chanArbCtx.AssertStateTransitions(
t, arbLog.newStates, StateContractClosed, StateContractClosed,
StateWaitingFullResolution, StateWaitingFullResolution,
) )
// We expect an immediate resolution message for the outgoing dust htlc. // We expect an immediate resolution message for the outgoing dust htlc.
// It is not resolvable on-chain. // It is not resolvable on-chain.
select { select {
case msgs := <-resolutions: case msgs := <-chanArbCtx.resolutions:
if len(msgs) != 1 { if len(msgs) != 1 {
t.Fatalf("expected 1 message, instead got %v", len(msgs)) t.Fatalf("expected 1 message, instead got %v", len(msgs))
} }
@ -723,7 +862,7 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) {
// Finally, we should also receive a resolution message instructing the // Finally, we should also receive a resolution message instructing the
// switch to cancel back the HTLC. // switch to cancel back the HTLC.
select { select {
case msgs := <-resolutions: case msgs := <-chanArbCtx.resolutions:
if len(msgs) != 1 { if len(msgs) != 1 {
t.Fatalf("expected 1 message, instead got %v", len(msgs)) t.Fatalf("expected 1 message, instead got %v", len(msgs))
} }
@ -740,7 +879,7 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) {
// to the second level. Channel arbitrator should still not be marked // to the second level. Channel arbitrator should still not be marked
// as resolved. // as resolved.
select { select {
case <-resolved: case <-chanArbCtx.resolvedChan:
t.Fatalf("channel resolved prematurely") t.Fatalf("channel resolved prematurely")
default: default:
} }
@ -749,9 +888,9 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) {
notifier.spendChan <- &chainntnfs.SpendDetail{SpendingTx: closeTx} notifier.spendChan <- &chainntnfs.SpendDetail{SpendingTx: closeTx}
// At this point channel should be marked as resolved. // At this point channel should be marked as resolved.
assertStateTransitions(t, arbLog.newStates, StateFullyResolved) chanArbCtx.AssertStateTransitions(StateFullyResolved)
select { select {
case <-resolved: case <-chanArbCtx.resolvedChan:
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
t.Fatalf("contract was not resolved") t.Fatalf("contract was not resolved")
} }
@ -766,10 +905,11 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) {
newStates: make(chan ArbitratorState, 5), newStates: make(chan ArbitratorState, 5),
} }
chanArb, resolved, _, _, err := createTestChannelArbitrator(log) chanArbCtx, err := createTestChannelArbitrator(t, log)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
chanArb := chanArbCtx.chanArb
if err := chanArb.Start(); err != nil { if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err) t.Fatalf("unable to start ChannelArbitrator: %v", err)
@ -777,7 +917,7 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) {
defer chanArb.Stop() defer chanArb.Stop()
// It should start out in the default state. // It should start out in the default state.
assertState(t, chanArb, StateDefault) chanArbCtx.AssertState(StateDefault)
// Create a channel we can use to assert the state when it publishes // Create a channel we can use to assert the state when it publishes
// the close tx. // the close tx.
@ -804,7 +944,7 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) {
} }
// It should transition to StateBroadcastCommit. // It should transition to StateBroadcastCommit.
assertStateTransitions(t, log.newStates, StateBroadcastCommit) chanArbCtx.AssertStateTransitions(StateBroadcastCommit)
// We expect it to be in state StateBroadcastCommit when publishing // We expect it to be in state StateBroadcastCommit when publishing
// the force close. // the force close.
@ -819,7 +959,7 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) {
// After broadcasting, transition should be to // After broadcasting, transition should be to
// StateCommitmentBroadcasted. // StateCommitmentBroadcasted.
assertStateTransitions(t, log.newStates, StateCommitmentBroadcasted) chanArbCtx.AssertStateTransitions(StateCommitmentBroadcasted)
// Wait for a response to the force close. // Wait for a response to the force close.
select { select {
@ -838,7 +978,7 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) {
} }
// The state should be StateCommitmentBroadcasted. // The state should be StateCommitmentBroadcasted.
assertState(t, chanArb, StateCommitmentBroadcasted) chanArbCtx.AssertState(StateCommitmentBroadcasted)
// Now notify about the _REMOTE_ commitment getting confirmed. // Now notify about the _REMOTE_ commitment getting confirmed.
commitSpend := &chainntnfs.SpendDetail{ commitSpend := &chainntnfs.SpendDetail{
@ -853,12 +993,11 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) {
} }
// It should transition StateContractClosed -> StateFullyResolved. // It should transition StateContractClosed -> StateFullyResolved.
assertStateTransitions(t, log.newStates, StateContractClosed, chanArbCtx.AssertStateTransitions(StateContractClosed, StateFullyResolved)
StateFullyResolved)
// It should resolve. // It should resolve.
select { select {
case <-resolved: case <-chanArbCtx.resolvedChan:
// Expected. // Expected.
case <-time.After(15 * time.Second): case <-time.After(15 * time.Second):
t.Fatalf("contract was not resolved") t.Fatalf("contract was not resolved")
@ -875,10 +1014,11 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) {
newStates: make(chan ArbitratorState, 5), newStates: make(chan ArbitratorState, 5),
} }
chanArb, resolved, _, _, err := createTestChannelArbitrator(log) chanArbCtx, err := createTestChannelArbitrator(t, log)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
chanArb := chanArbCtx.chanArb
if err := chanArb.Start(); err != nil { if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err) t.Fatalf("unable to start ChannelArbitrator: %v", err)
@ -886,7 +1026,7 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) {
defer chanArb.Stop() defer chanArb.Stop()
// It should start out in the default state. // It should start out in the default state.
assertState(t, chanArb, StateDefault) chanArbCtx.AssertState(StateDefault)
// Return ErrDoubleSpend when attempting to publish the tx. // Return ErrDoubleSpend when attempting to publish the tx.
stateChan := make(chan ArbitratorState) stateChan := make(chan ArbitratorState)
@ -912,7 +1052,7 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) {
} }
// It should transition to StateBroadcastCommit. // It should transition to StateBroadcastCommit.
assertStateTransitions(t, log.newStates, StateBroadcastCommit) chanArbCtx.AssertStateTransitions(StateBroadcastCommit)
// We expect it to be in state StateBroadcastCommit when publishing // We expect it to be in state StateBroadcastCommit when publishing
// the force close. // the force close.
@ -927,7 +1067,7 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) {
// After broadcasting, transition should be to // After broadcasting, transition should be to
// StateCommitmentBroadcasted. // StateCommitmentBroadcasted.
assertStateTransitions(t, log.newStates, StateCommitmentBroadcasted) chanArbCtx.AssertStateTransitions(StateCommitmentBroadcasted)
// Wait for a response to the force close. // Wait for a response to the force close.
select { select {
@ -946,7 +1086,7 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) {
} }
// The state should be StateCommitmentBroadcasted. // The state should be StateCommitmentBroadcasted.
assertState(t, chanArb, StateCommitmentBroadcasted) chanArbCtx.AssertState(StateCommitmentBroadcasted)
// Now notify about the _REMOTE_ commitment getting confirmed. // Now notify about the _REMOTE_ commitment getting confirmed.
commitSpend := &chainntnfs.SpendDetail{ commitSpend := &chainntnfs.SpendDetail{
@ -961,12 +1101,11 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) {
} }
// It should transition StateContractClosed -> StateFullyResolved. // It should transition StateContractClosed -> StateFullyResolved.
assertStateTransitions(t, log.newStates, StateContractClosed, chanArbCtx.AssertStateTransitions(StateContractClosed, StateFullyResolved)
StateFullyResolved)
// It should resolve. // It should resolve.
select { select {
case <-resolved: case <-chanArbCtx.resolvedChan:
// Expected. // Expected.
case <-time.After(15 * time.Second): case <-time.After(15 * time.Second):
t.Fatalf("contract was not resolved") t.Fatalf("contract was not resolved")
@ -983,17 +1122,18 @@ func TestChannelArbitratorPersistence(t *testing.T) {
failLog: true, failLog: true,
} }
chanArb, resolved, _, _, err := createTestChannelArbitrator(log) chanArbCtx, err := createTestChannelArbitrator(t, log)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
chanArb := chanArbCtx.chanArb
if err := chanArb.Start(); err != nil { if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err) t.Fatalf("unable to start ChannelArbitrator: %v", err)
} }
// It should start in StateDefault. // It should start in StateDefault.
assertState(t, chanArb, StateDefault) chanArbCtx.AssertState(StateDefault)
// Send a remote force close event. // Send a remote force close event.
commitSpend := &chainntnfs.SpendDetail{ commitSpend := &chainntnfs.SpendDetail{
@ -1014,20 +1154,17 @@ func TestChannelArbitratorPersistence(t *testing.T) {
if log.state != StateDefault { if log.state != StateDefault {
t.Fatalf("expected to stay in StateDefault") t.Fatalf("expected to stay in StateDefault")
} }
chanArb.Stop()
// Create a new arbitrator with the same log. // Restart the channel arb, this'll use the same long and prior
chanArb, resolved, _, _, err = createTestChannelArbitrator(log) // context.
chanArbCtx, err = chanArbCtx.Restart(nil)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to restart channel arb: %v", err)
}
if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
} }
chanArb = chanArbCtx.chanArb
// Again, it should start up in the default state. // Again, it should start up in the default state.
assertState(t, chanArb, StateDefault) chanArbCtx.AssertState(StateDefault)
// Now we make the log succeed writing the resolutions, but fail when // Now we make the log succeed writing the resolutions, but fail when
// attempting to close the channel. // attempting to close the channel.
@ -1047,20 +1184,16 @@ func TestChannelArbitratorPersistence(t *testing.T) {
if log.state != StateDefault { if log.state != StateDefault {
t.Fatalf("expected to stay in StateDefault") t.Fatalf("expected to stay in StateDefault")
} }
chanArb.Stop()
// Create yet another arbitrator with the same log. // Restart once again to simulate yet another restart.
chanArb, resolved, _, _, err = createTestChannelArbitrator(log) chanArbCtx, err = chanArbCtx.Restart(nil)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to restart channel arb: %v", err)
}
if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
} }
chanArb = chanArbCtx.chanArb
// Starts out in StateDefault. // Starts out in StateDefault.
assertState(t, chanArb, StateDefault) chanArbCtx.AssertState(StateDefault)
// Now make fetching the resolutions fail. // Now make fetching the resolutions fail.
log.failFetch = fmt.Errorf("intentional fetch failure") log.failFetch = fmt.Errorf("intentional fetch failure")
@ -1070,9 +1203,7 @@ func TestChannelArbitratorPersistence(t *testing.T) {
// Since logging the resolutions and closing the channel now succeeds, // Since logging the resolutions and closing the channel now succeeds,
// it should advance to StateContractClosed. // it should advance to StateContractClosed.
assertStateTransitions( chanArbCtx.AssertStateTransitions(StateContractClosed)
t, log.newStates, StateContractClosed,
)
// It should not advance further, however, as fetching resolutions // It should not advance further, however, as fetching resolutions
// failed. // failed.
@ -1084,24 +1215,18 @@ func TestChannelArbitratorPersistence(t *testing.T) {
// Create a new arbitrator, and now make fetching resolutions succeed. // Create a new arbitrator, and now make fetching resolutions succeed.
log.failFetch = nil log.failFetch = nil
chanArb, resolved, _, _, err = createTestChannelArbitrator(log) chanArbCtx, err = chanArbCtx.Restart(nil)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to restart channel arb: %v", err)
} }
defer chanArbCtx.CleanUp()
if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
}
defer chanArb.Stop()
// Finally it should advance to StateFullyResolved. // Finally it should advance to StateFullyResolved.
assertStateTransitions( chanArbCtx.AssertStateTransitions(StateFullyResolved)
t, log.newStates, StateFullyResolved,
)
// It should also mark the channel as resolved. // It should also mark the channel as resolved.
select { select {
case <-resolved: case <-chanArbCtx.resolvedChan:
// Expected. // Expected.
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
t.Fatalf("contract was not resolved") t.Fatalf("contract was not resolved")
@ -1119,17 +1244,18 @@ func TestChannelArbitratorForceCloseBreachedChannel(t *testing.T) {
newStates: make(chan ArbitratorState, 5), newStates: make(chan ArbitratorState, 5),
} }
chanArb, _, _, _, err := createTestChannelArbitrator(log) chanArbCtx, err := createTestChannelArbitrator(t, log)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
chanArb := chanArbCtx.chanArb
if err := chanArb.Start(); err != nil { if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err) t.Fatalf("unable to start ChannelArbitrator: %v", err)
} }
// It should start in StateDefault. // It should start in StateDefault.
assertState(t, chanArb, StateDefault) chanArbCtx.AssertState(StateDefault)
// We start by attempting a local force close. We'll return an // We start by attempting a local force close. We'll return an
// unexpected publication error, causing the state machine to halt. // unexpected publication error, causing the state machine to halt.
@ -1157,7 +1283,7 @@ func TestChannelArbitratorForceCloseBreachedChannel(t *testing.T) {
} }
// It should transition to StateBroadcastCommit. // It should transition to StateBroadcastCommit.
assertStateTransitions(t, log.newStates, StateBroadcastCommit) chanArbCtx.AssertStateTransitions(StateBroadcastCommit)
// We expect it to be in state StateBroadcastCommit when attempting // We expect it to be in state StateBroadcastCommit when attempting
// the force close. // the force close.
@ -1181,43 +1307,25 @@ func TestChannelArbitratorForceCloseBreachedChannel(t *testing.T) {
t.Fatalf("no response received") t.Fatalf("no response received")
} }
// Stop the channel abitrator.
if err := chanArb.Stop(); err != nil {
t.Fatal(err)
}
// We mimic that the channel is breached while the channel arbitrator // We mimic that the channel is breached while the channel arbitrator
// is down. This means that on restart it will be started with a // is down. This means that on restart it will be started with a
// pending close channel, of type BreachClose. // pending close channel, of type BreachClose.
chanArb, resolved, _, _, err := createTestChannelArbitrator(log) chanArbCtx, err = chanArbCtx.Restart(func(c *chanArbTestCtx) {
c.chanArb.cfg.IsPendingClose = true
c.chanArb.cfg.ClosingHeight = 100
c.chanArb.cfg.CloseType = channeldb.BreachClose
})
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
defer chanArbCtx.CleanUp()
chanArb.cfg.IsPendingClose = true
chanArb.cfg.ClosingHeight = 100
chanArb.cfg.CloseType = channeldb.BreachClose
// Start the channel abitrator again, and make sure it goes straight to
// state fully resolved, as in case of breach there is nothing to
// handle.
if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
}
defer func() {
if err := chanArb.Stop(); err != nil {
t.Fatal(err)
}
}()
// Finally it should advance to StateFullyResolved. // Finally it should advance to StateFullyResolved.
assertStateTransitions( chanArbCtx.AssertStateTransitions(StateFullyResolved)
t, log.newStates, StateFullyResolved,
)
// It should also mark the channel as resolved. // It should also mark the channel as resolved.
select { select {
case <-resolved: case <-chanArbCtx.resolvedChan:
// Expected. // Expected.
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
t.Fatalf("contract was not resolved") t.Fatalf("contract was not resolved")
@ -1286,6 +1394,8 @@ func TestChannelArbitratorCommitFailure(t *testing.T) {
} }
for _, test := range testCases { for _, test := range testCases {
test := test
log := &mockArbitratorLog{ log := &mockArbitratorLog{
state: StateDefault, state: StateDefault,
newStates: make(chan ArbitratorState, 5), newStates: make(chan ArbitratorState, 5),
@ -1296,17 +1406,18 @@ func TestChannelArbitratorCommitFailure(t *testing.T) {
failCommitState: test.expectedStates[0], failCommitState: test.expectedStates[0],
} }
chanArb, resolved, _, _, err := createTestChannelArbitrator(log) chanArbCtx, err := createTestChannelArbitrator(t, log)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
chanArb := chanArbCtx.chanArb
if err := chanArb.Start(); err != nil { if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err) t.Fatalf("unable to start ChannelArbitrator: %v", err)
} }
// It should start in StateDefault. // It should start in StateDefault.
assertState(t, chanArb, StateDefault) chanArbCtx.AssertState(StateDefault)
closed := make(chan struct{}) closed := make(chan struct{})
chanArb.cfg.MarkChannelClosed = func( chanArb.cfg.MarkChannelClosed = func(
@ -1336,30 +1447,23 @@ func TestChannelArbitratorCommitFailure(t *testing.T) {
// Start the arbitrator again, with IsPendingClose reporting // Start the arbitrator again, with IsPendingClose reporting
// the channel closed in the database. // the channel closed in the database.
chanArb, resolved, _, _, err = createTestChannelArbitrator(log) log.failCommit = false
chanArbCtx, err = chanArbCtx.Restart(func(c *chanArbTestCtx) {
c.chanArb.cfg.IsPendingClose = true
c.chanArb.cfg.ClosingHeight = 100
c.chanArb.cfg.CloseType = test.closeType
})
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
log.failCommit = false
chanArb.cfg.IsPendingClose = true
chanArb.cfg.ClosingHeight = 100
chanArb.cfg.CloseType = test.closeType
if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
}
// Since the channel is marked closed in the database, it // Since the channel is marked closed in the database, it
// should advance to the expected states. // should advance to the expected states.
assertStateTransitions( chanArbCtx.AssertStateTransitions(test.expectedStates...)
t, log.newStates, test.expectedStates...,
)
// It should also mark the channel as resolved. // It should also mark the channel as resolved.
select { select {
case <-resolved: case <-chanArbCtx.resolvedChan:
// Expected. // Expected.
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
t.Fatalf("contract was not resolved") t.Fatalf("contract was not resolved")
@ -1382,11 +1486,12 @@ func TestChannelArbitratorEmptyResolutions(t *testing.T) {
failFetch: errNoResolutions, failFetch: errNoResolutions,
} }
chanArb, _, _, _, err := createTestChannelArbitrator(log) chanArbCtx, err := createTestChannelArbitrator(t, log)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
chanArb := chanArbCtx.chanArb
chanArb.cfg.IsPendingClose = true chanArb.cfg.IsPendingClose = true
chanArb.cfg.ClosingHeight = 100 chanArb.cfg.ClosingHeight = 100
chanArb.cfg.CloseType = channeldb.RemoteForceClose chanArb.cfg.CloseType = channeldb.RemoteForceClose
@ -1397,9 +1502,7 @@ func TestChannelArbitratorEmptyResolutions(t *testing.T) {
// It should not advance its state beyond StateContractClosed, since // It should not advance its state beyond StateContractClosed, since
// fetching resolutions fails. // fetching resolutions fails.
assertStateTransitions( chanArbCtx.AssertStateTransitions(StateContractClosed)
t, log.newStates, StateContractClosed,
)
// It should not advance further, however, as fetching resolutions // It should not advance further, however, as fetching resolutions
// failed. // failed.
@ -1420,10 +1523,11 @@ func TestChannelArbitratorAlreadyForceClosed(t *testing.T) {
log := &mockArbitratorLog{ log := &mockArbitratorLog{
state: StateCommitmentBroadcasted, state: StateCommitmentBroadcasted,
} }
chanArb, _, _, _, err := createTestChannelArbitrator(log) chanArbCtx, err := createTestChannelArbitrator(t, log)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
chanArb := chanArbCtx.chanArb
if err := chanArb.Start(); err != nil { if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err) t.Fatalf("unable to start ChannelArbitrator: %v", err)
} }
@ -1515,12 +1619,13 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) {
resolvers: make(map[ContractResolver]struct{}), resolvers: make(map[ContractResolver]struct{}),
} }
chanArb, _, resolutions, blockEpochs, err := createTestChannelArbitrator( chanArbCtx, err := createTestChannelArbitrator(
arbLog, t, arbLog,
) )
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
chanArb := chanArbCtx.chanArb
if err := chanArb.Start(); err != nil { if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err) t.Fatalf("unable to start ChannelArbitrator: %v", err)
} }
@ -1568,7 +1673,7 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) {
// now mine a block (height 5), which is 5 blocks away // now mine a block (height 5), which is 5 blocks away
// (our grace delta) from the expiry of that HTLC. // (our grace delta) from the expiry of that HTLC.
case testCase.htlcExpired: case testCase.htlcExpired:
blockEpochs <- &chainntnfs.BlockEpoch{Height: 5} chanArbCtx.blockEpochs <- &chainntnfs.BlockEpoch{Height: 5}
// Otherwise, we'll just trigger a regular force close // Otherwise, we'll just trigger a regular force close
// request. // request.
@ -1584,8 +1689,8 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) {
// determined that it needs to go to chain in order to // determined that it needs to go to chain in order to
// block off the redemption path so it can cancel the // block off the redemption path so it can cancel the
// incoming HTLC. // incoming HTLC.
assertStateTransitions( chanArbCtx.AssertStateTransitions(
t, arbLog.newStates, StateBroadcastCommit, StateBroadcastCommit,
StateCommitmentBroadcasted, StateCommitmentBroadcasted,
) )
@ -1646,15 +1751,15 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) {
// The channel arb should now transition to waiting // The channel arb should now transition to waiting
// until the HTLCs have been fully resolved. // until the HTLCs have been fully resolved.
assertStateTransitions( chanArbCtx.AssertStateTransitions(
t, arbLog.newStates, StateContractClosed, StateContractClosed,
StateWaitingFullResolution, StateWaitingFullResolution,
) )
// Now that we've sent this signal, we should have that // Now that we've sent this signal, we should have that
// HTLC be cancelled back immediately. // HTLC be cancelled back immediately.
select { select {
case msgs := <-resolutions: case msgs := <-chanArbCtx.resolutions:
if len(msgs) != 1 { if len(msgs) != 1 {
t.Fatalf("expected 1 message, "+ t.Fatalf("expected 1 message, "+
"instead got %v", len(msgs)) "instead got %v", len(msgs))
@ -1672,10 +1777,8 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) {
// so instead, we'll mine another block which'll cause // so instead, we'll mine another block which'll cause
// it to re-examine its state and realize there're no // it to re-examine its state and realize there're no
// more HTLCs. // more HTLCs.
blockEpochs <- &chainntnfs.BlockEpoch{Height: 6} chanArbCtx.blockEpochs <- &chainntnfs.BlockEpoch{Height: 6}
assertStateTransitions( chanArbCtx.AssertStateTransitions(StateFullyResolved)
t, arbLog.newStates, StateFullyResolved,
)
}) })
} }
} }