Merge pull request #3480 from Roasbeef/proper-resolution-supplements

contractcourt: supplement resolvers with confirmed commit set HTLCs
This commit is contained in:
Olaoluwa Osuntokun 2019-09-25 17:08:35 -07:00 committed by GitHub
commit c57bb9d86b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 369 additions and 225 deletions

@ -45,6 +45,16 @@ linters:
# trigger funlen problems that we may not want to solve at that time. # trigger funlen problems that we may not want to solve at that time.
- funlen - funlen
# Disable for now as we haven't yet tuned the sensitivity to our codebase
# yet. Enabling by default for example, would also force new contributors to
# potentially extensively refactor code, when they want to smaller change to
# land.
- gocyclo
# Instances of table driven tests that don't pre-allocate shouldn't trigger
# the linter.
- prealloc
issues: issues:
# Only show newly introduced problems. # Only show newly introduced problems.
new-from-rev: 01f696afce2f9c0d4ed854edefa3846891d01d8a new-from-rev: 01f696afce2f9c0d4ed854edefa3846891d01d8a

@ -415,7 +415,19 @@ func (c *ChannelArbitrator) Start() error {
if startingState == StateWaitingFullResolution && if startingState == StateWaitingFullResolution &&
nextState == StateWaitingFullResolution { nextState == StateWaitingFullResolution {
if err := c.relaunchResolvers(); err != nil { // In order to relaunch the resolvers, we'll need to fetch the
// set of HTLCs that were present in the commitment transaction
// at the time it was confirmed. commitSet.ConfCommitKey can't
// be nil at this point since we're in
// StateWaitingFullResolution. We can only be in
// StateWaitingFullResolution after we've transitioned from
// StateContractClosed which can only be triggered by the local
// or remote close trigger. This trigger is only fired when we
// receive a chain event from the chain watcher than the
// commitment has been confirmed on chain, and before we
// advance our state step, we call InsertConfirmedCommitSet.
confCommitSet := commitSet.HtlcSets[*commitSet.ConfCommitKey]
if err := c.relaunchResolvers(confCommitSet); err != nil {
c.cfg.BlockEpochs.Cancel() c.cfg.BlockEpochs.Cancel()
return err return err
} }
@ -431,7 +443,7 @@ func (c *ChannelArbitrator) Start() error {
// starting the ChannelArbitrator. This information should ideally be stored in // starting the ChannelArbitrator. This information should ideally be stored in
// the database, so this only serves as a intermediate work-around to prevent a // the database, so this only serves as a intermediate work-around to prevent a
// migration. // migration.
func (c *ChannelArbitrator) relaunchResolvers() error { func (c *ChannelArbitrator) relaunchResolvers(confirmedHTLCs []channeldb.HTLC) error {
// We'll now query our log to see if there are any active // We'll now query our log to see if there are any active
// unresolved contracts. If this is the case, then we'll // unresolved contracts. If this is the case, then we'll
// relaunch all contract resolvers. // relaunch all contract resolvers.
@ -456,31 +468,22 @@ func (c *ChannelArbitrator) relaunchResolvers() error {
// to prevent a db migration. We use all available htlc sets here in // to prevent a db migration. We use all available htlc sets here in
// order to ensure we have complete coverage. // order to ensure we have complete coverage.
htlcMap := make(map[wire.OutPoint]*channeldb.HTLC) htlcMap := make(map[wire.OutPoint]*channeldb.HTLC)
for _, htlcs := range c.activeHTLCs { for _, htlc := range confirmedHTLCs {
for _, htlc := range htlcs.incomingHTLCs { htlc := htlc
htlc := htlc outpoint := wire.OutPoint{
outpoint := wire.OutPoint{ Hash: commitHash,
Hash: commitHash, Index: uint32(htlc.OutputIndex),
Index: uint32(htlc.OutputIndex),
}
htlcMap[outpoint] = &htlc
}
for _, htlc := range htlcs.outgoingHTLCs {
htlc := htlc
outpoint := wire.OutPoint{
Hash: commitHash,
Index: uint32(htlc.OutputIndex),
}
htlcMap[outpoint] = &htlc
} }
htlcMap[outpoint] = &htlc
} }
log.Infof("ChannelArbitrator(%v): relaunching %v contract "+ log.Infof("ChannelArbitrator(%v): relaunching %v contract "+
"resolvers", c.cfg.ChanPoint, len(unresolvedContracts)) "resolvers", c.cfg.ChanPoint, len(unresolvedContracts))
for _, resolver := range unresolvedContracts { for _, resolver := range unresolvedContracts {
c.supplementResolver(resolver, htlcMap) if err := c.supplementResolver(resolver, htlcMap); err != nil {
return err
}
} }
c.launchResolvers(unresolvedContracts) c.launchResolvers(unresolvedContracts)

@ -3,12 +3,16 @@ package contractcourt
import ( import (
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"os"
"path/filepath"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/coreos/bbolt"
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
@ -127,6 +131,25 @@ func (b *mockArbitratorLog) WipeHistory() error {
return nil return nil
} }
// testArbLog is a wrapper around an existing (ideally fully concrete
// ArbitratorLog) that lets us intercept certain calls like transitioning to a
// new state.
type testArbLog struct {
ArbitratorLog
newStates chan ArbitratorState
}
func (t *testArbLog) CommitState(s ArbitratorState) error {
if err := t.ArbitratorLog.CommitState(s); err != nil {
return err
}
t.newStates <- s
return nil
}
type mockChainIO struct{} type mockChainIO struct{}
var _ lnwallet.BlockChainIO = (*mockChainIO)(nil) var _ lnwallet.BlockChainIO = (*mockChainIO)(nil)
@ -148,9 +171,101 @@ func (*mockChainIO) GetBlock(blockHash *chainhash.Hash) (*wire.MsgBlock, error)
return nil, nil return nil, nil
} }
func createTestChannelArbitrator(log ArbitratorLog) (*ChannelArbitrator, type chanArbTestCtx struct {
chan struct{}, chan []ResolutionMsg, chan *chainntnfs.BlockEpoch, error) { t *testing.T
chanArb *ChannelArbitrator
cleanUp func()
resolvedChan chan struct{}
blockEpochs chan *chainntnfs.BlockEpoch
incubationRequests chan struct{}
resolutions chan []ResolutionMsg
log ArbitratorLog
}
func (c *chanArbTestCtx) CleanUp() {
if err := c.chanArb.Stop(); err != nil {
c.t.Fatalf("unable to stop chan arb: %v", err)
}
if c.cleanUp != nil {
c.cleanUp()
}
}
// AssertStateTransitions asserts that the state machine steps through the
// passed states in order.
func (c *chanArbTestCtx) AssertStateTransitions(expectedStates ...ArbitratorState) {
c.t.Helper()
var newStatesChan chan ArbitratorState
switch log := c.log.(type) {
case *mockArbitratorLog:
newStatesChan = log.newStates
case *testArbLog:
newStatesChan = log.newStates
default:
c.t.Fatalf("unable to assert state transitions with %T", log)
}
for _, exp := range expectedStates {
var state ArbitratorState
select {
case state = <-newStatesChan:
case <-time.After(5 * time.Second):
c.t.Fatalf("new state not received")
}
if state != exp {
c.t.Fatalf("expected new state %v, got %v", exp, state)
}
}
}
// AssertState checks that the ChannelArbitrator is in the state we expect it
// to be.
func (c *chanArbTestCtx) AssertState(expected ArbitratorState) {
if c.chanArb.state != expected {
c.t.Fatalf("expected state %v, was %v", expected, c.chanArb.state)
}
}
// Restart simulates a clean restart of the channel arbitrator, forcing it to
// walk through it's recovery logic. If this function returns nil, then a
// restart was successful. Note that the restart process keeps the log in
// place, in order to simulate proper persistence of the log. The caller can
// optionally provide a restart closure which will be executed before the
// resolver is started again, but after it is created.
func (c *chanArbTestCtx) Restart(restartClosure func(*chanArbTestCtx)) (*chanArbTestCtx, error) {
if err := c.chanArb.Stop(); err != nil {
return nil, err
}
newCtx, err := createTestChannelArbitrator(c.t, c.log)
if err != nil {
return nil, err
}
if restartClosure != nil {
restartClosure(newCtx)
}
if err := newCtx.chanArb.Start(); err != nil {
return nil, err
}
return newCtx, nil
}
func createTestChannelArbitrator(t *testing.T, log ArbitratorLog) (*chanArbTestCtx, error) {
blockEpochs := make(chan *chainntnfs.BlockEpoch) blockEpochs := make(chan *chainntnfs.BlockEpoch)
blockEpoch := &chainntnfs.BlockEpochEvent{ blockEpoch := &chainntnfs.BlockEpochEvent{
Epochs: blockEpochs, Epochs: blockEpochs,
@ -167,6 +282,7 @@ func createTestChannelArbitrator(log ArbitratorLog) (*ChannelArbitrator,
} }
resolutionChan := make(chan []ResolutionMsg, 1) resolutionChan := make(chan []ResolutionMsg, 1)
incubateChan := make(chan struct{})
chainIO := &mockChainIO{} chainIO := &mockChainIO{}
chainArbCfg := ChainArbitratorConfig{ chainArbCfg := ChainArbitratorConfig{
@ -188,6 +304,8 @@ func createTestChannelArbitrator(log ArbitratorLog) (*ChannelArbitrator,
IncubateOutputs: func(wire.OutPoint, *lnwallet.CommitOutputResolution, IncubateOutputs: func(wire.OutPoint, *lnwallet.CommitOutputResolution,
*lnwallet.OutgoingHtlcResolution, *lnwallet.OutgoingHtlcResolution,
*lnwallet.IncomingHtlcResolution, uint32) error { *lnwallet.IncomingHtlcResolution, uint32) error {
incubateChan <- struct{}{}
return nil return nil
}, },
} }
@ -224,17 +342,49 @@ func createTestChannelArbitrator(log ArbitratorLog) (*ChannelArbitrator,
ChainEvents: chanEvents, ChainEvents: chanEvents,
} }
htlcSets := make(map[HtlcSetKey]htlcSet) var cleanUp func()
return NewChannelArbitrator(arbCfg, htlcSets, log), resolvedChan, if log == nil {
resolutionChan, blockEpochs, nil dbDir, err := ioutil.TempDir("", "chanArb")
} if err != nil {
return nil, err
}
dbPath := filepath.Join(dbDir, "testdb")
db, err := bbolt.Open(dbPath, 0600, nil)
if err != nil {
return nil, err
}
// assertState checks that the ChannelArbitrator is in the state we expect it backingLog, err := newBoltArbitratorLog(
// to be. db, arbCfg, chainhash.Hash{}, chanPoint,
func assertState(t *testing.T, c *ChannelArbitrator, expected ArbitratorState) { )
if c.state != expected { if err != nil {
t.Fatalf("expected state %v, was %v", expected, c.state) return nil, err
}
cleanUp = func() {
db.Close()
os.RemoveAll(dbDir)
}
log = &testArbLog{
ArbitratorLog: backingLog,
newStates: make(chan ArbitratorState),
}
} }
htlcSets := make(map[HtlcSetKey]htlcSet)
chanArb := NewChannelArbitrator(arbCfg, htlcSets, log)
return &chanArbTestCtx{
t: t,
chanArb: chanArb,
cleanUp: cleanUp,
resolvedChan: resolvedChan,
resolutions: resolutionChan,
blockEpochs: blockEpochs,
log: log,
incubationRequests: incubateChan,
}, nil
} }
// TestChannelArbitratorCooperativeClose tests that the ChannelArbitertor // TestChannelArbitratorCooperativeClose tests that the ChannelArbitertor
@ -246,22 +396,26 @@ func TestChannelArbitratorCooperativeClose(t *testing.T) {
newStates: make(chan ArbitratorState, 5), newStates: make(chan ArbitratorState, 5),
} }
chanArb, resolved, _, _, err := createTestChannelArbitrator(log) chanArbCtx, err := createTestChannelArbitrator(t, log)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
if err := chanArb.Start(); err != nil { if err := chanArbCtx.chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err) t.Fatalf("unable to start ChannelArbitrator: %v", err)
} }
defer chanArb.Stop() defer func() {
if err := chanArbCtx.chanArb.Stop(); err != nil {
t.Fatalf("unable to stop chan arb: %v", err)
}
}()
// It should start out in the default state. // It should start out in the default state.
assertState(t, chanArb, StateDefault) chanArbCtx.AssertState(StateDefault)
// We set up a channel to detect when MarkChannelClosed is called. // We set up a channel to detect when MarkChannelClosed is called.
closeInfos := make(chan *channeldb.ChannelCloseSummary) closeInfos := make(chan *channeldb.ChannelCloseSummary)
chanArb.cfg.MarkChannelClosed = func( chanArbCtx.chanArb.cfg.MarkChannelClosed = func(
closeInfo *channeldb.ChannelCloseSummary) error { closeInfo *channeldb.ChannelCloseSummary) error {
closeInfos <- closeInfo closeInfos <- closeInfo
return nil return nil
@ -272,7 +426,7 @@ func TestChannelArbitratorCooperativeClose(t *testing.T) {
closeInfo := &CooperativeCloseInfo{ closeInfo := &CooperativeCloseInfo{
&channeldb.ChannelCloseSummary{}, &channeldb.ChannelCloseSummary{},
} }
chanArb.cfg.ChainEvents.CooperativeClosure <- closeInfo chanArbCtx.chanArb.cfg.ChainEvents.CooperativeClosure <- closeInfo
select { select {
case c := <-closeInfos: case c := <-closeInfos:
@ -285,31 +439,13 @@ func TestChannelArbitratorCooperativeClose(t *testing.T) {
// It should mark the channel as resolved. // It should mark the channel as resolved.
select { select {
case <-resolved: case <-chanArbCtx.resolvedChan:
// Expected. // Expected.
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
t.Fatalf("contract was not resolved") t.Fatalf("contract was not resolved")
} }
} }
func assertStateTransitions(t *testing.T, newStates <-chan ArbitratorState,
expectedStates ...ArbitratorState) {
t.Helper()
for _, exp := range expectedStates {
var state ArbitratorState
select {
case state = <-newStates:
case <-time.After(5 * time.Second):
t.Fatalf("new state not received")
}
if state != exp {
t.Fatalf("expected new state %v, got %v", exp, state)
}
}
}
// TestChannelArbitratorRemoteForceClose checks that the ChannelArbitrator goes // TestChannelArbitratorRemoteForceClose checks that the ChannelArbitrator goes
// through the expected states if a remote force close is observed in the // through the expected states if a remote force close is observed in the
// chain. // chain.
@ -319,10 +455,11 @@ func TestChannelArbitratorRemoteForceClose(t *testing.T) {
newStates: make(chan ArbitratorState, 5), newStates: make(chan ArbitratorState, 5),
} }
chanArb, resolved, _, _, err := createTestChannelArbitrator(log) chanArbCtx, err := createTestChannelArbitrator(t, log)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
chanArb := chanArbCtx.chanArb
if err := chanArb.Start(); err != nil { if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err) t.Fatalf("unable to start ChannelArbitrator: %v", err)
@ -330,7 +467,7 @@ func TestChannelArbitratorRemoteForceClose(t *testing.T) {
defer chanArb.Stop() defer chanArb.Stop()
// It should start out in the default state. // It should start out in the default state.
assertState(t, chanArb, StateDefault) chanArbCtx.AssertState(StateDefault)
// Send a remote force close event. // Send a remote force close event.
commitSpend := &chainntnfs.SpendDetail{ commitSpend := &chainntnfs.SpendDetail{
@ -351,13 +488,13 @@ func TestChannelArbitratorRemoteForceClose(t *testing.T) {
// It should transition StateDefault -> StateContractClosed -> // It should transition StateDefault -> StateContractClosed ->
// StateFullyResolved. // StateFullyResolved.
assertStateTransitions( chanArbCtx.AssertStateTransitions(
t, log.newStates, StateContractClosed, StateFullyResolved, StateContractClosed, StateFullyResolved,
) )
// It should also mark the channel as resolved. // It should also mark the channel as resolved.
select { select {
case <-resolved: case <-chanArbCtx.resolvedChan:
// Expected. // Expected.
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
t.Fatalf("contract was not resolved") t.Fatalf("contract was not resolved")
@ -373,10 +510,11 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) {
newStates: make(chan ArbitratorState, 5), newStates: make(chan ArbitratorState, 5),
} }
chanArb, resolved, _, _, err := createTestChannelArbitrator(log) chanArbCtx, err := createTestChannelArbitrator(t, log)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
chanArb := chanArbCtx.chanArb
if err := chanArb.Start(); err != nil { if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err) t.Fatalf("unable to start ChannelArbitrator: %v", err)
@ -384,7 +522,7 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) {
defer chanArb.Stop() defer chanArb.Stop()
// It should start out in the default state. // It should start out in the default state.
assertState(t, chanArb, StateDefault) chanArbCtx.AssertState(StateDefault)
// We create a channel we can use to pause the ChannelArbitrator at the // We create a channel we can use to pause the ChannelArbitrator at the
// point where it broadcasts the close tx, and check its state. // point where it broadcasts the close tx, and check its state.
@ -411,7 +549,7 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) {
} }
// It should transition to StateBroadcastCommit. // It should transition to StateBroadcastCommit.
assertStateTransitions(t, log.newStates, StateBroadcastCommit) chanArbCtx.AssertStateTransitions(StateBroadcastCommit)
// When it is broadcasting the force close, its state should be // When it is broadcasting the force close, its state should be
// StateBroadcastCommit. // StateBroadcastCommit.
@ -426,7 +564,7 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) {
// After broadcasting, transition should be to // After broadcasting, transition should be to
// StateCommitmentBroadcasted. // StateCommitmentBroadcasted.
assertStateTransitions(t, log.newStates, StateCommitmentBroadcasted) chanArbCtx.AssertStateTransitions(StateCommitmentBroadcasted)
select { select {
case <-respChan: case <-respChan:
@ -445,7 +583,7 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) {
// After broadcasting the close tx, it should be in state // After broadcasting the close tx, it should be in state
// StateCommitmentBroadcasted. // StateCommitmentBroadcasted.
assertState(t, chanArb, StateCommitmentBroadcasted) chanArbCtx.AssertState(StateCommitmentBroadcasted)
// Now notify about the local force close getting confirmed. // Now notify about the local force close getting confirmed.
chanArb.cfg.ChainEvents.LocalUnilateralClosure <- &LocalUnilateralCloseInfo{ chanArb.cfg.ChainEvents.LocalUnilateralClosure <- &LocalUnilateralCloseInfo{
@ -458,12 +596,11 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) {
} }
// It should transition StateContractClosed -> StateFullyResolved. // It should transition StateContractClosed -> StateFullyResolved.
assertStateTransitions(t, log.newStates, StateContractClosed, chanArbCtx.AssertStateTransitions(StateContractClosed, StateFullyResolved)
StateFullyResolved)
// It should also mark the channel as resolved. // It should also mark the channel as resolved.
select { select {
case <-resolved: case <-chanArbCtx.resolvedChan:
// Expected. // Expected.
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
t.Fatalf("contract was not resolved") t.Fatalf("contract was not resolved")
@ -479,10 +616,11 @@ func TestChannelArbitratorBreachClose(t *testing.T) {
newStates: make(chan ArbitratorState, 5), newStates: make(chan ArbitratorState, 5),
} }
chanArb, resolved, _, _, err := createTestChannelArbitrator(log) chanArbCtx, err := createTestChannelArbitrator(t, log)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
chanArb := chanArbCtx.chanArb
if err := chanArb.Start(); err != nil { if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err) t.Fatalf("unable to start ChannelArbitrator: %v", err)
@ -494,19 +632,19 @@ func TestChannelArbitratorBreachClose(t *testing.T) {
}() }()
// It should start out in the default state. // It should start out in the default state.
assertState(t, chanArb, StateDefault) chanArbCtx.AssertState(StateDefault)
// Send a breach close event. // Send a breach close event.
chanArb.cfg.ChainEvents.ContractBreach <- &lnwallet.BreachRetribution{} chanArb.cfg.ChainEvents.ContractBreach <- &lnwallet.BreachRetribution{}
// It should transition StateDefault -> StateFullyResolved. // It should transition StateDefault -> StateFullyResolved.
assertStateTransitions( chanArbCtx.AssertStateTransitions(
t, log.newStates, StateFullyResolved, StateFullyResolved,
) )
// It should also mark the channel as resolved. // It should also mark the channel as resolved.
select { select {
case <-resolved: case <-chanArbCtx.resolvedChan:
// Expected. // Expected.
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
t.Fatalf("contract was not resolved") t.Fatalf("contract was not resolved")
@ -517,29 +655,15 @@ func TestChannelArbitratorBreachClose(t *testing.T) {
// ChannelArbitrator goes through the expected states in case we request it to // ChannelArbitrator goes through the expected states in case we request it to
// force close a channel that still has an HTLC pending. // force close a channel that still has an HTLC pending.
func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) {
arbLog := &mockArbitratorLog{ // We create a new test context for this channel arb, notice that we
state: StateDefault, // pass in a nil ArbitratorLog which means that a default one backed by
newStates: make(chan ArbitratorState, 5), // a real DB will be created. We need this for our test as we want to
resolvers: make(map[ContractResolver]struct{}), // test proper restart recovery and resolver population.
} chanArbCtx, err := createTestChannelArbitrator(t, nil)
chanArb, resolved, resolutions, _, err := createTestChannelArbitrator(
arbLog,
)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
chanArb := chanArbCtx.chanArb
incubateChan := make(chan struct{})
chanArb.cfg.IncubateOutputs = func(_ wire.OutPoint,
_ *lnwallet.CommitOutputResolution,
_ *lnwallet.OutgoingHtlcResolution,
_ *lnwallet.IncomingHtlcResolution, _ uint32) error {
incubateChan <- struct{}{}
return nil
}
chanArb.cfg.PreimageDB = newMockWitnessBeacon() chanArb.cfg.PreimageDB = newMockWitnessBeacon()
chanArb.cfg.Registry = &mockRegistry{} chanArb.cfg.Registry = &mockRegistry{}
@ -558,9 +682,10 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) {
chanArb.UpdateContractSignals(signals) chanArb.UpdateContractSignals(signals)
// Add HTLC to channel arbitrator. // Add HTLC to channel arbitrator.
htlcAmt := 10000
htlc := channeldb.HTLC{ htlc := channeldb.HTLC{
Incoming: false, Incoming: false,
Amt: 10000, Amt: lnwire.MilliSatoshi(htlcAmt),
HtlcIndex: 99, HtlcIndex: 99,
} }
@ -599,8 +724,8 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) {
// The force close request should trigger broadcast of the commitment // The force close request should trigger broadcast of the commitment
// transaction. // transaction.
assertStateTransitions( chanArbCtx.AssertStateTransitions(
t, arbLog.newStates, StateBroadcastCommit, StateBroadcastCommit,
StateCommitmentBroadcasted, StateCommitmentBroadcasted,
) )
select { select {
@ -636,8 +761,8 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) {
Index: 0, Index: 0,
} }
// Set up the outgoing resolution. Populate SignedTimeoutTx because // Set up the outgoing resolution. Populate SignedTimeoutTx because our
// our commitment transaction got confirmed. // commitment transaction got confirmed.
outgoingRes := lnwallet.OutgoingHtlcResolution{ outgoingRes := lnwallet.OutgoingHtlcResolution{
Expiry: 10, Expiry: 10,
SweepSignDesc: input.SignDescriptor{ SweepSignDesc: input.SignDescriptor{
@ -675,15 +800,15 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) {
}, },
} }
assertStateTransitions( chanArbCtx.AssertStateTransitions(
t, arbLog.newStates, StateContractClosed, StateContractClosed,
StateWaitingFullResolution, StateWaitingFullResolution,
) )
// We expect an immediate resolution message for the outgoing dust htlc. // We expect an immediate resolution message for the outgoing dust htlc.
// It is not resolvable on-chain. // It is not resolvable on-chain.
select { select {
case msgs := <-resolutions: case msgs := <-chanArbCtx.resolutions:
if len(msgs) != 1 { if len(msgs) != 1 {
t.Fatalf("expected 1 message, instead got %v", len(msgs)) t.Fatalf("expected 1 message, instead got %v", len(msgs))
} }
@ -696,34 +821,76 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) {
t.Fatalf("resolution msgs not sent") t.Fatalf("resolution msgs not sent")
} }
// We'll grab the old notifier here as our resolvers are still holding
// a reference to this instance, and a new one will be created when we
// restart the channel arb below.
oldNotifier := chanArb.cfg.Notifier.(*mockNotifier)
// At this point, in order to simulate a restart, we'll re-create the
// channel arbitrator. We do this to ensure that all information
// required to properly resolve this HTLC are populated.
if err := chanArb.Stop(); err != nil {
t.Fatalf("unable to stop chan arb: %v", err)
}
// We'll no re-create the resolver, notice that we use the existing
// arbLog so it carries over the same on-disk state.
chanArbCtxNew, err := chanArbCtx.Restart(nil)
if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err)
}
chanArb = chanArbCtxNew.chanArb
defer chanArbCtxNew.CleanUp()
// Post restart, it should be the case that our resolver was properly
// supplemented, and we only have a single resolver in the final set.
if len(chanArb.activeResolvers) != 1 {
t.Fatalf("expected single resolver, instead got: %v",
len(chanArb.activeResolvers))
}
// We'll now examine the in-memory state of the active resolvers to
// ensure t hey were populated properly.
resolver := chanArb.activeResolvers[0]
outgoingResolver, ok := resolver.(*htlcOutgoingContestResolver)
if !ok {
t.Fatalf("expected outgoing contest resolver, got %vT",
resolver)
}
// The resolver should have its htlcAmt field populated as it.
if int64(outgoingResolver.htlcAmt) != int64(htlcAmt) {
t.Fatalf("wrong htlc amount: expected %v, got %v,",
htlcAmt, int64(outgoingResolver.htlcAmt))
}
// htlcOutgoingContestResolver is now active and waiting for the HTLC to // htlcOutgoingContestResolver is now active and waiting for the HTLC to
// expire. It should not yet have passed it on for incubation. // expire. It should not yet have passed it on for incubation.
select { select {
case <-incubateChan: case <-chanArbCtx.incubationRequests:
t.Fatalf("contract should not be incubated yet") t.Fatalf("contract should not be incubated yet")
default: default:
} }
// Send a notification that the expiry height has been reached. // Send a notification that the expiry height has been reached.
notifier := chanArb.cfg.Notifier.(*mockNotifier) oldNotifier.epochChan <- &chainntnfs.BlockEpoch{Height: 10}
notifier.epochChan <- &chainntnfs.BlockEpoch{Height: 10}
// htlcOutgoingContestResolver is now transforming into a // htlcOutgoingContestResolver is now transforming into a
// htlcTimeoutResolver and should send the contract off for incubation. // htlcTimeoutResolver and should send the contract off for incubation.
select { select {
case <-incubateChan: case <-chanArbCtx.incubationRequests:
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
t.Fatalf("no response received") t.Fatalf("no response received")
} }
// Notify resolver that the HTLC output of the commitment has been // Notify resolver that the HTLC output of the commitment has been
// spent. // spent.
notifier.spendChan <- &chainntnfs.SpendDetail{SpendingTx: closeTx} oldNotifier.spendChan <- &chainntnfs.SpendDetail{SpendingTx: closeTx}
// Finally, we should also receive a resolution message instructing the // Finally, we should also receive a resolution message instructing the
// switch to cancel back the HTLC. // switch to cancel back the HTLC.
select { select {
case msgs := <-resolutions: case msgs := <-chanArbCtx.resolutions:
if len(msgs) != 1 { if len(msgs) != 1 {
t.Fatalf("expected 1 message, instead got %v", len(msgs)) t.Fatalf("expected 1 message, instead got %v", len(msgs))
} }
@ -740,18 +907,18 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) {
// to the second level. Channel arbitrator should still not be marked // to the second level. Channel arbitrator should still not be marked
// as resolved. // as resolved.
select { select {
case <-resolved: case <-chanArbCtxNew.resolvedChan:
t.Fatalf("channel resolved prematurely") t.Fatalf("channel resolved prematurely")
default: default:
} }
// Notify resolver that the second level transaction is spent. // Notify resolver that the second level transaction is spent.
notifier.spendChan <- &chainntnfs.SpendDetail{SpendingTx: closeTx} oldNotifier.spendChan <- &chainntnfs.SpendDetail{SpendingTx: closeTx}
// At this point channel should be marked as resolved. // At this point channel should be marked as resolved.
assertStateTransitions(t, arbLog.newStates, StateFullyResolved) chanArbCtxNew.AssertStateTransitions(StateFullyResolved)
select { select {
case <-resolved: case <-chanArbCtxNew.resolvedChan:
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
t.Fatalf("contract was not resolved") t.Fatalf("contract was not resolved")
} }
@ -766,10 +933,11 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) {
newStates: make(chan ArbitratorState, 5), newStates: make(chan ArbitratorState, 5),
} }
chanArb, resolved, _, _, err := createTestChannelArbitrator(log) chanArbCtx, err := createTestChannelArbitrator(t, log)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
chanArb := chanArbCtx.chanArb
if err := chanArb.Start(); err != nil { if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err) t.Fatalf("unable to start ChannelArbitrator: %v", err)
@ -777,7 +945,7 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) {
defer chanArb.Stop() defer chanArb.Stop()
// It should start out in the default state. // It should start out in the default state.
assertState(t, chanArb, StateDefault) chanArbCtx.AssertState(StateDefault)
// Create a channel we can use to assert the state when it publishes // Create a channel we can use to assert the state when it publishes
// the close tx. // the close tx.
@ -804,7 +972,7 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) {
} }
// It should transition to StateBroadcastCommit. // It should transition to StateBroadcastCommit.
assertStateTransitions(t, log.newStates, StateBroadcastCommit) chanArbCtx.AssertStateTransitions(StateBroadcastCommit)
// We expect it to be in state StateBroadcastCommit when publishing // We expect it to be in state StateBroadcastCommit when publishing
// the force close. // the force close.
@ -819,7 +987,7 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) {
// After broadcasting, transition should be to // After broadcasting, transition should be to
// StateCommitmentBroadcasted. // StateCommitmentBroadcasted.
assertStateTransitions(t, log.newStates, StateCommitmentBroadcasted) chanArbCtx.AssertStateTransitions(StateCommitmentBroadcasted)
// Wait for a response to the force close. // Wait for a response to the force close.
select { select {
@ -838,7 +1006,7 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) {
} }
// The state should be StateCommitmentBroadcasted. // The state should be StateCommitmentBroadcasted.
assertState(t, chanArb, StateCommitmentBroadcasted) chanArbCtx.AssertState(StateCommitmentBroadcasted)
// Now notify about the _REMOTE_ commitment getting confirmed. // Now notify about the _REMOTE_ commitment getting confirmed.
commitSpend := &chainntnfs.SpendDetail{ commitSpend := &chainntnfs.SpendDetail{
@ -853,12 +1021,11 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) {
} }
// It should transition StateContractClosed -> StateFullyResolved. // It should transition StateContractClosed -> StateFullyResolved.
assertStateTransitions(t, log.newStates, StateContractClosed, chanArbCtx.AssertStateTransitions(StateContractClosed, StateFullyResolved)
StateFullyResolved)
// It should resolve. // It should resolve.
select { select {
case <-resolved: case <-chanArbCtx.resolvedChan:
// Expected. // Expected.
case <-time.After(15 * time.Second): case <-time.After(15 * time.Second):
t.Fatalf("contract was not resolved") t.Fatalf("contract was not resolved")
@ -875,10 +1042,11 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) {
newStates: make(chan ArbitratorState, 5), newStates: make(chan ArbitratorState, 5),
} }
chanArb, resolved, _, _, err := createTestChannelArbitrator(log) chanArbCtx, err := createTestChannelArbitrator(t, log)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
chanArb := chanArbCtx.chanArb
if err := chanArb.Start(); err != nil { if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err) t.Fatalf("unable to start ChannelArbitrator: %v", err)
@ -886,7 +1054,7 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) {
defer chanArb.Stop() defer chanArb.Stop()
// It should start out in the default state. // It should start out in the default state.
assertState(t, chanArb, StateDefault) chanArbCtx.AssertState(StateDefault)
// Return ErrDoubleSpend when attempting to publish the tx. // Return ErrDoubleSpend when attempting to publish the tx.
stateChan := make(chan ArbitratorState) stateChan := make(chan ArbitratorState)
@ -912,7 +1080,7 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) {
} }
// It should transition to StateBroadcastCommit. // It should transition to StateBroadcastCommit.
assertStateTransitions(t, log.newStates, StateBroadcastCommit) chanArbCtx.AssertStateTransitions(StateBroadcastCommit)
// We expect it to be in state StateBroadcastCommit when publishing // We expect it to be in state StateBroadcastCommit when publishing
// the force close. // the force close.
@ -927,7 +1095,7 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) {
// After broadcasting, transition should be to // After broadcasting, transition should be to
// StateCommitmentBroadcasted. // StateCommitmentBroadcasted.
assertStateTransitions(t, log.newStates, StateCommitmentBroadcasted) chanArbCtx.AssertStateTransitions(StateCommitmentBroadcasted)
// Wait for a response to the force close. // Wait for a response to the force close.
select { select {
@ -946,7 +1114,7 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) {
} }
// The state should be StateCommitmentBroadcasted. // The state should be StateCommitmentBroadcasted.
assertState(t, chanArb, StateCommitmentBroadcasted) chanArbCtx.AssertState(StateCommitmentBroadcasted)
// Now notify about the _REMOTE_ commitment getting confirmed. // Now notify about the _REMOTE_ commitment getting confirmed.
commitSpend := &chainntnfs.SpendDetail{ commitSpend := &chainntnfs.SpendDetail{
@ -961,12 +1129,11 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) {
} }
// It should transition StateContractClosed -> StateFullyResolved. // It should transition StateContractClosed -> StateFullyResolved.
assertStateTransitions(t, log.newStates, StateContractClosed, chanArbCtx.AssertStateTransitions(StateContractClosed, StateFullyResolved)
StateFullyResolved)
// It should resolve. // It should resolve.
select { select {
case <-resolved: case <-chanArbCtx.resolvedChan:
// Expected. // Expected.
case <-time.After(15 * time.Second): case <-time.After(15 * time.Second):
t.Fatalf("contract was not resolved") t.Fatalf("contract was not resolved")
@ -983,17 +1150,18 @@ func TestChannelArbitratorPersistence(t *testing.T) {
failLog: true, failLog: true,
} }
chanArb, resolved, _, _, err := createTestChannelArbitrator(log) chanArbCtx, err := createTestChannelArbitrator(t, log)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
chanArb := chanArbCtx.chanArb
if err := chanArb.Start(); err != nil { if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err) t.Fatalf("unable to start ChannelArbitrator: %v", err)
} }
// It should start in StateDefault. // It should start in StateDefault.
assertState(t, chanArb, StateDefault) chanArbCtx.AssertState(StateDefault)
// Send a remote force close event. // Send a remote force close event.
commitSpend := &chainntnfs.SpendDetail{ commitSpend := &chainntnfs.SpendDetail{
@ -1014,20 +1182,17 @@ func TestChannelArbitratorPersistence(t *testing.T) {
if log.state != StateDefault { if log.state != StateDefault {
t.Fatalf("expected to stay in StateDefault") t.Fatalf("expected to stay in StateDefault")
} }
chanArb.Stop()
// Create a new arbitrator with the same log. // Restart the channel arb, this'll use the same long and prior
chanArb, resolved, _, _, err = createTestChannelArbitrator(log) // context.
chanArbCtx, err = chanArbCtx.Restart(nil)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to restart channel arb: %v", err)
}
if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
} }
chanArb = chanArbCtx.chanArb
// Again, it should start up in the default state. // Again, it should start up in the default state.
assertState(t, chanArb, StateDefault) chanArbCtx.AssertState(StateDefault)
// Now we make the log succeed writing the resolutions, but fail when // Now we make the log succeed writing the resolutions, but fail when
// attempting to close the channel. // attempting to close the channel.
@ -1047,20 +1212,16 @@ func TestChannelArbitratorPersistence(t *testing.T) {
if log.state != StateDefault { if log.state != StateDefault {
t.Fatalf("expected to stay in StateDefault") t.Fatalf("expected to stay in StateDefault")
} }
chanArb.Stop()
// Create yet another arbitrator with the same log. // Restart once again to simulate yet another restart.
chanArb, resolved, _, _, err = createTestChannelArbitrator(log) chanArbCtx, err = chanArbCtx.Restart(nil)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to restart channel arb: %v", err)
}
if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
} }
chanArb = chanArbCtx.chanArb
// Starts out in StateDefault. // Starts out in StateDefault.
assertState(t, chanArb, StateDefault) chanArbCtx.AssertState(StateDefault)
// Now make fetching the resolutions fail. // Now make fetching the resolutions fail.
log.failFetch = fmt.Errorf("intentional fetch failure") log.failFetch = fmt.Errorf("intentional fetch failure")
@ -1070,9 +1231,7 @@ func TestChannelArbitratorPersistence(t *testing.T) {
// Since logging the resolutions and closing the channel now succeeds, // Since logging the resolutions and closing the channel now succeeds,
// it should advance to StateContractClosed. // it should advance to StateContractClosed.
assertStateTransitions( chanArbCtx.AssertStateTransitions(StateContractClosed)
t, log.newStates, StateContractClosed,
)
// It should not advance further, however, as fetching resolutions // It should not advance further, however, as fetching resolutions
// failed. // failed.
@ -1084,24 +1243,18 @@ func TestChannelArbitratorPersistence(t *testing.T) {
// Create a new arbitrator, and now make fetching resolutions succeed. // Create a new arbitrator, and now make fetching resolutions succeed.
log.failFetch = nil log.failFetch = nil
chanArb, resolved, _, _, err = createTestChannelArbitrator(log) chanArbCtx, err = chanArbCtx.Restart(nil)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to restart channel arb: %v", err)
} }
defer chanArbCtx.CleanUp()
if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
}
defer chanArb.Stop()
// Finally it should advance to StateFullyResolved. // Finally it should advance to StateFullyResolved.
assertStateTransitions( chanArbCtx.AssertStateTransitions(StateFullyResolved)
t, log.newStates, StateFullyResolved,
)
// It should also mark the channel as resolved. // It should also mark the channel as resolved.
select { select {
case <-resolved: case <-chanArbCtx.resolvedChan:
// Expected. // Expected.
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
t.Fatalf("contract was not resolved") t.Fatalf("contract was not resolved")
@ -1119,17 +1272,18 @@ func TestChannelArbitratorForceCloseBreachedChannel(t *testing.T) {
newStates: make(chan ArbitratorState, 5), newStates: make(chan ArbitratorState, 5),
} }
chanArb, _, _, _, err := createTestChannelArbitrator(log) chanArbCtx, err := createTestChannelArbitrator(t, log)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
chanArb := chanArbCtx.chanArb
if err := chanArb.Start(); err != nil { if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err) t.Fatalf("unable to start ChannelArbitrator: %v", err)
} }
// It should start in StateDefault. // It should start in StateDefault.
assertState(t, chanArb, StateDefault) chanArbCtx.AssertState(StateDefault)
// We start by attempting a local force close. We'll return an // We start by attempting a local force close. We'll return an
// unexpected publication error, causing the state machine to halt. // unexpected publication error, causing the state machine to halt.
@ -1157,7 +1311,7 @@ func TestChannelArbitratorForceCloseBreachedChannel(t *testing.T) {
} }
// It should transition to StateBroadcastCommit. // It should transition to StateBroadcastCommit.
assertStateTransitions(t, log.newStates, StateBroadcastCommit) chanArbCtx.AssertStateTransitions(StateBroadcastCommit)
// We expect it to be in state StateBroadcastCommit when attempting // We expect it to be in state StateBroadcastCommit when attempting
// the force close. // the force close.
@ -1181,43 +1335,25 @@ func TestChannelArbitratorForceCloseBreachedChannel(t *testing.T) {
t.Fatalf("no response received") t.Fatalf("no response received")
} }
// Stop the channel abitrator.
if err := chanArb.Stop(); err != nil {
t.Fatal(err)
}
// We mimic that the channel is breached while the channel arbitrator // We mimic that the channel is breached while the channel arbitrator
// is down. This means that on restart it will be started with a // is down. This means that on restart it will be started with a
// pending close channel, of type BreachClose. // pending close channel, of type BreachClose.
chanArb, resolved, _, _, err := createTestChannelArbitrator(log) chanArbCtx, err = chanArbCtx.Restart(func(c *chanArbTestCtx) {
c.chanArb.cfg.IsPendingClose = true
c.chanArb.cfg.ClosingHeight = 100
c.chanArb.cfg.CloseType = channeldb.BreachClose
})
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
defer chanArbCtx.CleanUp()
chanArb.cfg.IsPendingClose = true
chanArb.cfg.ClosingHeight = 100
chanArb.cfg.CloseType = channeldb.BreachClose
// Start the channel abitrator again, and make sure it goes straight to
// state fully resolved, as in case of breach there is nothing to
// handle.
if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
}
defer func() {
if err := chanArb.Stop(); err != nil {
t.Fatal(err)
}
}()
// Finally it should advance to StateFullyResolved. // Finally it should advance to StateFullyResolved.
assertStateTransitions( chanArbCtx.AssertStateTransitions(StateFullyResolved)
t, log.newStates, StateFullyResolved,
)
// It should also mark the channel as resolved. // It should also mark the channel as resolved.
select { select {
case <-resolved: case <-chanArbCtx.resolvedChan:
// Expected. // Expected.
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
t.Fatalf("contract was not resolved") t.Fatalf("contract was not resolved")
@ -1286,6 +1422,8 @@ func TestChannelArbitratorCommitFailure(t *testing.T) {
} }
for _, test := range testCases { for _, test := range testCases {
test := test
log := &mockArbitratorLog{ log := &mockArbitratorLog{
state: StateDefault, state: StateDefault,
newStates: make(chan ArbitratorState, 5), newStates: make(chan ArbitratorState, 5),
@ -1296,17 +1434,18 @@ func TestChannelArbitratorCommitFailure(t *testing.T) {
failCommitState: test.expectedStates[0], failCommitState: test.expectedStates[0],
} }
chanArb, resolved, _, _, err := createTestChannelArbitrator(log) chanArbCtx, err := createTestChannelArbitrator(t, log)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
chanArb := chanArbCtx.chanArb
if err := chanArb.Start(); err != nil { if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err) t.Fatalf("unable to start ChannelArbitrator: %v", err)
} }
// It should start in StateDefault. // It should start in StateDefault.
assertState(t, chanArb, StateDefault) chanArbCtx.AssertState(StateDefault)
closed := make(chan struct{}) closed := make(chan struct{})
chanArb.cfg.MarkChannelClosed = func( chanArb.cfg.MarkChannelClosed = func(
@ -1336,30 +1475,23 @@ func TestChannelArbitratorCommitFailure(t *testing.T) {
// Start the arbitrator again, with IsPendingClose reporting // Start the arbitrator again, with IsPendingClose reporting
// the channel closed in the database. // the channel closed in the database.
chanArb, resolved, _, _, err = createTestChannelArbitrator(log) log.failCommit = false
chanArbCtx, err = chanArbCtx.Restart(func(c *chanArbTestCtx) {
c.chanArb.cfg.IsPendingClose = true
c.chanArb.cfg.ClosingHeight = 100
c.chanArb.cfg.CloseType = test.closeType
})
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
log.failCommit = false
chanArb.cfg.IsPendingClose = true
chanArb.cfg.ClosingHeight = 100
chanArb.cfg.CloseType = test.closeType
if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
}
// Since the channel is marked closed in the database, it // Since the channel is marked closed in the database, it
// should advance to the expected states. // should advance to the expected states.
assertStateTransitions( chanArbCtx.AssertStateTransitions(test.expectedStates...)
t, log.newStates, test.expectedStates...,
)
// It should also mark the channel as resolved. // It should also mark the channel as resolved.
select { select {
case <-resolved: case <-chanArbCtx.resolvedChan:
// Expected. // Expected.
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
t.Fatalf("contract was not resolved") t.Fatalf("contract was not resolved")
@ -1382,11 +1514,12 @@ func TestChannelArbitratorEmptyResolutions(t *testing.T) {
failFetch: errNoResolutions, failFetch: errNoResolutions,
} }
chanArb, _, _, _, err := createTestChannelArbitrator(log) chanArbCtx, err := createTestChannelArbitrator(t, log)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
chanArb := chanArbCtx.chanArb
chanArb.cfg.IsPendingClose = true chanArb.cfg.IsPendingClose = true
chanArb.cfg.ClosingHeight = 100 chanArb.cfg.ClosingHeight = 100
chanArb.cfg.CloseType = channeldb.RemoteForceClose chanArb.cfg.CloseType = channeldb.RemoteForceClose
@ -1397,9 +1530,7 @@ func TestChannelArbitratorEmptyResolutions(t *testing.T) {
// It should not advance its state beyond StateContractClosed, since // It should not advance its state beyond StateContractClosed, since
// fetching resolutions fails. // fetching resolutions fails.
assertStateTransitions( chanArbCtx.AssertStateTransitions(StateContractClosed)
t, log.newStates, StateContractClosed,
)
// It should not advance further, however, as fetching resolutions // It should not advance further, however, as fetching resolutions
// failed. // failed.
@ -1420,10 +1551,11 @@ func TestChannelArbitratorAlreadyForceClosed(t *testing.T) {
log := &mockArbitratorLog{ log := &mockArbitratorLog{
state: StateCommitmentBroadcasted, state: StateCommitmentBroadcasted,
} }
chanArb, _, _, _, err := createTestChannelArbitrator(log) chanArbCtx, err := createTestChannelArbitrator(t, log)
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
chanArb := chanArbCtx.chanArb
if err := chanArb.Start(); err != nil { if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err) t.Fatalf("unable to start ChannelArbitrator: %v", err)
} }
@ -1515,12 +1647,13 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) {
resolvers: make(map[ContractResolver]struct{}), resolvers: make(map[ContractResolver]struct{}),
} }
chanArb, _, resolutions, blockEpochs, err := createTestChannelArbitrator( chanArbCtx, err := createTestChannelArbitrator(
arbLog, t, arbLog,
) )
if err != nil { if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err) t.Fatalf("unable to create ChannelArbitrator: %v", err)
} }
chanArb := chanArbCtx.chanArb
if err := chanArb.Start(); err != nil { if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err) t.Fatalf("unable to start ChannelArbitrator: %v", err)
} }
@ -1568,7 +1701,7 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) {
// now mine a block (height 5), which is 5 blocks away // now mine a block (height 5), which is 5 blocks away
// (our grace delta) from the expiry of that HTLC. // (our grace delta) from the expiry of that HTLC.
case testCase.htlcExpired: case testCase.htlcExpired:
blockEpochs <- &chainntnfs.BlockEpoch{Height: 5} chanArbCtx.blockEpochs <- &chainntnfs.BlockEpoch{Height: 5}
// Otherwise, we'll just trigger a regular force close // Otherwise, we'll just trigger a regular force close
// request. // request.
@ -1584,8 +1717,8 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) {
// determined that it needs to go to chain in order to // determined that it needs to go to chain in order to
// block off the redemption path so it can cancel the // block off the redemption path so it can cancel the
// incoming HTLC. // incoming HTLC.
assertStateTransitions( chanArbCtx.AssertStateTransitions(
t, arbLog.newStates, StateBroadcastCommit, StateBroadcastCommit,
StateCommitmentBroadcasted, StateCommitmentBroadcasted,
) )
@ -1646,15 +1779,15 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) {
// The channel arb should now transition to waiting // The channel arb should now transition to waiting
// until the HTLCs have been fully resolved. // until the HTLCs have been fully resolved.
assertStateTransitions( chanArbCtx.AssertStateTransitions(
t, arbLog.newStates, StateContractClosed, StateContractClosed,
StateWaitingFullResolution, StateWaitingFullResolution,
) )
// Now that we've sent this signal, we should have that // Now that we've sent this signal, we should have that
// HTLC be cancelled back immediately. // HTLC be cancelled back immediately.
select { select {
case msgs := <-resolutions: case msgs := <-chanArbCtx.resolutions:
if len(msgs) != 1 { if len(msgs) != 1 {
t.Fatalf("expected 1 message, "+ t.Fatalf("expected 1 message, "+
"instead got %v", len(msgs)) "instead got %v", len(msgs))
@ -1672,10 +1805,8 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) {
// so instead, we'll mine another block which'll cause // so instead, we'll mine another block which'll cause
// it to re-examine its state and realize there're no // it to re-examine its state and realize there're no
// more HTLCs. // more HTLCs.
blockEpochs <- &chainntnfs.BlockEpoch{Height: 6} chanArbCtx.blockEpochs <- &chainntnfs.BlockEpoch{Height: 6}
assertStateTransitions( chanArbCtx.AssertStateTransitions(StateFullyResolved)
t, arbLog.newStates, StateFullyResolved,
)
}) })
} }
} }