diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index 8d139aa3..14553398 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -12,6 +12,65 @@ import ( "github.com/lightningnetwork/lnd/lnwire" ) +type mockArbitratorLog struct { + state ArbitratorState + newStates chan ArbitratorState +} + +// A compile time check to ensure mockArbitratorLog meets the ArbitratorLog +// interface. +var _ ArbitratorLog = (*mockArbitratorLog)(nil) + +func (b *mockArbitratorLog) CurrentState() (ArbitratorState, error) { + return b.state, nil +} + +func (b *mockArbitratorLog) CommitState(s ArbitratorState) error { + b.state = s + b.newStates <- s + return nil +} + +func (b *mockArbitratorLog) FetchUnresolvedContracts() ([]ContractResolver, error) { + var contracts []ContractResolver + return contracts, nil +} + +func (b *mockArbitratorLog) InsertUnresolvedContracts(resolvers ...ContractResolver) error { + return nil +} + +func (b *mockArbitratorLog) SwapContract(oldContract, newContract ContractResolver) error { + return nil +} + +func (b *mockArbitratorLog) ResolveContract(res ContractResolver) error { + return nil +} + +func (b *mockArbitratorLog) LogContractResolutions(c *ContractResolutions) error { + return nil +} + +func (b *mockArbitratorLog) FetchContractResolutions() (*ContractResolutions, error) { + c := &ContractResolutions{} + + return c, nil +} + +func (b *mockArbitratorLog) LogChainActions(actions ChainActionMap) error { + return nil +} + +func (b *mockArbitratorLog) FetchChainActions() (ChainActionMap, error) { + actionsMap := make(ChainActionMap) + return actionsMap, nil +} + +func (b *mockArbitratorLog) WipeHistory() error { + return nil +} + type mockChainIO struct{} func (*mockChainIO) GetBestBlock() (*chainhash.Hash, int32, error) { @@ -31,7 +90,8 @@ func (*mockChainIO) GetBlock(blockHash *chainhash.Hash) (*wire.MsgBlock, error) return nil, nil } -func createTestChannelArbitrator() (*ChannelArbitrator, chan struct{}, func(), error) { +func createTestChannelArbitrator(log ArbitratorLog) (*ChannelArbitrator, + chan struct{}, error) { blockEpoch := &chainntnfs.BlockEpochEvent{ Cancel: func() {}, } @@ -81,16 +141,8 @@ func createTestChannelArbitrator() (*ChannelArbitrator, chan struct{}, func(), e ChainArbitratorConfig: chainArbCfg, ChainEvents: chanEvents, } - testLog, cleanUp, err := newTestBoltArbLog( - testChainHash, testChanPoint1, - ) - if err != nil { - return nil, nil, nil, fmt.Errorf("unable to create test log: %v", - err) - } - return NewChannelArbitrator(arbCfg, nil, testLog), - resolvedChan, cleanUp, nil + return NewChannelArbitrator(arbCfg, nil, log), resolvedChan, nil } // assertState checks that the ChannelArbitrator is in the state we expect it @@ -104,11 +156,15 @@ func assertState(t *testing.T, c *ChannelArbitrator, expected ArbitratorState) { // TestChannelArbitratorCooperativeClose tests that the ChannelArbitertor // correctly does nothing in case a cooperative close is confirmed. func TestChannelArbitratorCooperativeClose(t *testing.T) { - chanArb, _, cleanUp, err := createTestChannelArbitrator() + log := &mockArbitratorLog{ + state: StateDefault, + newStates: make(chan ArbitratorState, 5), + } + + chanArb, _, err := createTestChannelArbitrator(log) if err != nil { t.Fatalf("unable to create ChannelArbitrator: %v", err) } - defer cleanUp() if err := chanArb.Start(); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) @@ -124,15 +180,37 @@ func TestChannelArbitratorCooperativeClose(t *testing.T) { assertState(t, chanArb, StateDefault) } -// TestChannelArbitratorRemoteForceClose checks that the ChannelArbitrotor goes +func assertStateTransitions(t *testing.T, newStates <-chan ArbitratorState, + expectedStates ...ArbitratorState) { + t.Helper() + + for _, exp := range expectedStates { + var state ArbitratorState + select { + case state = <-newStates: + case <-time.After(5 * time.Second): + t.Fatalf("new state not received") + } + + if state != exp { + t.Fatalf("expected new state %v, got %v", exp, state) + } + } +} + +// TestChannelArbitratorRemoteForceClose checks that the ChannelArbitrator goes // through the expected states if a remote force close is observed in the // chain. func TestChannelArbitratorRemoteForceClose(t *testing.T) { - chanArb, resolved, cleanUp, err := createTestChannelArbitrator() + log := &mockArbitratorLog{ + state: StateDefault, + newStates: make(chan ArbitratorState, 5), + } + + chanArb, resolved, err := createTestChannelArbitrator(log) if err != nil { t.Fatalf("unable to create ChannelArbitrator: %v", err) } - defer cleanUp() if err := chanArb.Start(); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) @@ -153,28 +231,34 @@ func TestChannelArbitratorRemoteForceClose(t *testing.T) { } chanArb.cfg.ChainEvents.RemoteUnilateralClosure <- uniClose - // It should mark the channel as resolved. + // It should transition StateDefault -> StateContractClosed -> + // StateFullyResolved. + assertStateTransitions( + t, log.newStates, StateContractClosed, StateFullyResolved, + ) + + // It should alos mark the channel as resolved. select { case <-resolved: // Expected. case <-time.After(5 * time.Second): t.Fatalf("contract was not resolved") } - - // TODO: intermediate states. - // We expect the ChannelArbitrator to end up in the the resolved state. - assertState(t, chanArb, StateFullyResolved) } // TestChannelArbitratorLocalForceClose tests that the ChannelArbitrator goes // through the expected states in case we request it to force close the channel, // and the local force close event is observed in chain. func TestChannelArbitratorLocalForceClose(t *testing.T) { - chanArb, resolved, cleanUp, err := createTestChannelArbitrator() + log := &mockArbitratorLog{ + state: StateDefault, + newStates: make(chan ArbitratorState, 5), + } + + chanArb, resolved, err := createTestChannelArbitrator(log) if err != nil { t.Fatalf("unable to create ChannelArbitrator: %v", err) } - defer cleanUp() if err := chanArb.Start(); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) @@ -208,6 +292,9 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) { closeTx: respChan, } + // It should transition to StateBroadcastCommit. + assertStateTransitions(t, log.newStates, StateBroadcastCommit) + // When it is broadcasting the force close, its state should be // StateBroadcastCommit. select { @@ -219,6 +306,10 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) { t.Fatalf("did not get state update") } + // After broadcasting, transition should be to + // StateCommitmentBroadcasted. + assertStateTransitions(t, log.newStates, StateCommitmentBroadcasted) + select { case <-respChan: case <-time.After(5 * time.Second): @@ -246,28 +337,33 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) { HtlcResolutions: &lnwallet.HtlcResolutions{}, }, } - // It should mark the channel as resolved. + + // It should transition StateContractClosed -> StateFullyResolved. + assertStateTransitions(t, log.newStates, StateContractClosed, + StateFullyResolved) + + // It should also mark the channel as resolved. select { case <-resolved: // Expected. case <-time.After(5 * time.Second): t.Fatalf("contract was not resolved") } - - // And end up in the StateFullyResolved state. - // TODO: intermediate states as well. - assertState(t, chanArb, StateFullyResolved) } // TestChannelArbitratorLocalForceCloseRemoteConfiremd tests that the // ChannelArbitrator behaves as expected in the case where we request a local // force close, but a remote commitment ends up being confirmed in chain. func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) { - chanArb, resolved, cleanUp, err := createTestChannelArbitrator() + log := &mockArbitratorLog{ + state: StateDefault, + newStates: make(chan ArbitratorState, 5), + } + + chanArb, resolved, err := createTestChannelArbitrator(log) if err != nil { t.Fatalf("unable to create ChannelArbitrator: %v", err) } - defer cleanUp() if err := chanArb.Start(); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) @@ -301,6 +397,9 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) { closeTx: respChan, } + // It should transition to StateBroadcastCommit. + assertStateTransitions(t, log.newStates, StateBroadcastCommit) + // We expect it to be in state StateBroadcastCommit when publishing // the force close. select { @@ -312,6 +411,10 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) { t.Fatalf("no state update received") } + // After broadcasting, transition should be to + // StateCommitmentBroadcasted. + assertStateTransitions(t, log.newStates, StateCommitmentBroadcasted) + // Wait for a response to the force close. select { case <-respChan: @@ -341,6 +444,10 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) { } chanArb.cfg.ChainEvents.RemoteUnilateralClosure <- uniClose + // It should transition StateContractClosed -> StateFullyResolved. + assertStateTransitions(t, log.newStates, StateContractClosed, + StateFullyResolved) + // It should resolve. select { case <-resolved: @@ -348,10 +455,6 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) { case <-time.After(15 * time.Second): t.Fatalf("contract was not resolved") } - - // And we expect it to end up in StateFullyResolved. - // TODO: intermediate states as well. - assertState(t, chanArb, StateFullyResolved) } // TestChannelArbitratorLocalForceCloseDoubleSpend tests that the @@ -359,11 +462,15 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) { // force close, but we fail broadcasting our commitment because a remote // commitment has already been published. func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) { - chanArb, resolved, cleanUp, err := createTestChannelArbitrator() + log := &mockArbitratorLog{ + state: StateDefault, + newStates: make(chan ArbitratorState, 5), + } + + chanArb, resolved, err := createTestChannelArbitrator(log) if err != nil { t.Fatalf("unable to create ChannelArbitrator: %v", err) } - defer cleanUp() if err := chanArb.Start(); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) @@ -396,6 +503,9 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) { closeTx: respChan, } + // It should transition to StateBroadcastCommit. + assertStateTransitions(t, log.newStates, StateBroadcastCommit) + // We expect it to be in state StateBroadcastCommit when publishing // the force close. select { @@ -407,6 +517,10 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) { t.Fatalf("no state update received") } + // After broadcasting, transition should be to + // StateCommitmentBroadcasted. + assertStateTransitions(t, log.newStates, StateCommitmentBroadcasted) + // Wait for a response to the force close. select { case <-respChan: @@ -436,6 +550,10 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) { } chanArb.cfg.ChainEvents.RemoteUnilateralClosure <- uniClose + // It should transition StateContractClosed -> StateFullyResolved. + assertStateTransitions(t, log.newStates, StateContractClosed, + StateFullyResolved) + // It should resolve. select { case <-resolved: @@ -443,8 +561,4 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) { case <-time.After(15 * time.Second): t.Fatalf("contract was not resolved") } - - // And we expect it to end up in StateFullyResolved. - // TODO: intermediate states as well. - assertState(t, chanArb, StateFullyResolved) }