diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index aeedf47c..aa77cbaa 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -14,10 +14,12 @@ import ( ) type mockArbitratorLog struct { - state ArbitratorState - newStates chan ArbitratorState - failLog bool - failFetch error + state ArbitratorState + newStates chan ArbitratorState + failLog bool + failFetch error + failCommit bool + failCommitState ArbitratorState } // A compile time check to ensure mockArbitratorLog meets the ArbitratorLog @@ -29,6 +31,10 @@ func (b *mockArbitratorLog) CurrentState() (ArbitratorState, error) { } func (b *mockArbitratorLog) CommitState(s ArbitratorState) error { + if b.failCommit && s == b.failCommitState { + return fmt.Errorf("intentional commit error at state %v", + b.failCommitState) + } b.state = s b.newStates <- s return nil @@ -732,3 +738,89 @@ func TestChannelArbitratorPersistence(t *testing.T) { t.Fatalf("contract was not resolved") } } + +// TestChannelArbitratorCommitFailure tests that the channel arbitrator is able +// to recover from a failed CommitState call at restart. +func TestChannelArbitratorCommitFailure(t *testing.T) { + // Start out with a log that will fail committing to StateContractClosed. + log := &mockArbitratorLog{ + state: StateDefault, + newStates: make(chan ArbitratorState, 5), + failCommit: true, + failCommitState: StateContractClosed, + } + + chanArb, resolved, err := createTestChannelArbitrator(log) + if err != nil { + t.Fatalf("unable to create ChannelArbitrator: %v", err) + } + + if err := chanArb.Start(); err != nil { + t.Fatalf("unable to start ChannelArbitrator: %v", err) + } + + // It should start in StateDefault. + assertState(t, chanArb, StateDefault) + + closed := make(chan struct{}) + chanArb.cfg.MarkChannelClosed = func(*channeldb.ChannelCloseSummary) error { + close(closed) + return nil + } + + // Send a remote force close event. + commitSpend := &chainntnfs.SpendDetail{ + SpenderTxHash: &chainhash.Hash{}, + } + + uniClose := &lnwallet.UnilateralCloseSummary{ + SpendDetail: commitSpend, + HtlcResolutions: &lnwallet.HtlcResolutions{}, + } + chanArb.cfg.ChainEvents.RemoteUnilateralClosure <- uniClose + + select { + case <-closed: + case <-time.After(5 * time.Second): + t.Fatalf("channel was not marked closed") + } + + // Since the channel was marked closed in the database, but the commit + // to the next state failed, the state should still be StateDefault. + time.Sleep(100 * time.Millisecond) + if log.state != StateDefault { + t.Fatalf("expected to stay in StateDefault") + } + chanArb.Stop() + + // Start the arbitrator again, with IsPendingClose reporting the + // channel closed in the database. + chanArb, resolved, err = createTestChannelArbitrator(log) + if err != nil { + t.Fatalf("unable to create ChannelArbitrator: %v", err) + } + + log.failCommit = false + + chanArb.cfg.IsPendingClose = true + chanArb.cfg.ClosingHeight = 100 + chanArb.cfg.CloseType = channeldb.RemoteForceClose + + if err := chanArb.Start(); err != nil { + t.Fatalf("unable to start ChannelArbitrator: %v", err) + } + + // Since the channel is marked closed in the database, it should + // advance to StateContractClosed and StateFullyResolved. + assertStateTransitions( + t, log.newStates, StateContractClosed, StateFullyResolved, + ) + + // It should also mark the channel as resolved. + select { + case <-resolved: + // Expected. + case <-time.After(5 * time.Second): + t.Fatalf("contract was not resolved") + } +}