diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index c12796af..fbad1465 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -357,6 +357,9 @@ func (c *ChannelArbitrator) Start() error { case channeldb.CooperativeClose: trigger = coopCloseTrigger + case channeldb.BreachClose: + trigger = breachCloseTrigger + case channeldb.LocalForceClose: trigger = localCloseTrigger diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index b1183260..b785b02a 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -1,6 +1,7 @@ package contractcourt import ( + "errors" "fmt" "sync" "testing" @@ -1107,6 +1108,122 @@ func TestChannelArbitratorPersistence(t *testing.T) { } } +// TestChannelArbitratorForceCloseBreachedChannel tests that the channel +// arbitrator is able to handle a channel in the process of being force closed +// is breached by the remote node. In these cases we expect the +// ChannelArbitrator to gracefully exit, as the breach is handled by other +// subsystems. +func TestChannelArbitratorForceCloseBreachedChannel(t *testing.T) { + log := &mockArbitratorLog{ + state: StateDefault, + newStates: make(chan ArbitratorState, 5), + } + + chanArb, _, _, _, err := createTestChannelArbitrator(log) + if err != nil { + t.Fatalf("unable to create ChannelArbitrator: %v", err) + } + + if err := chanArb.Start(); err != nil { + t.Fatalf("unable to start ChannelArbitrator: %v", err) + } + + // It should start in StateDefault. + assertState(t, chanArb, StateDefault) + + // We start by attempting a local force close. We'll return an + // unexpected publication error, causing the state machine to halt. + expErr := errors.New("intentional publication error") + stateChan := make(chan ArbitratorState) + chanArb.cfg.PublishTx = func(*wire.MsgTx) error { + // When the force close tx is being broadcasted, check that the + // state is correct at that point. + select { + case stateChan <- chanArb.state: + case <-chanArb.quit: + return fmt.Errorf("exiting") + } + return expErr + } + + errChan := make(chan error, 1) + respChan := make(chan *wire.MsgTx, 1) + + // With the channel found, and the request crafted, we'll send over a + // force close request to the arbitrator that watches this channel. + chanArb.forceCloseReqs <- &forceCloseReq{ + errResp: errChan, + closeTx: respChan, + } + + // It should transition to StateBroadcastCommit. + assertStateTransitions(t, log.newStates, StateBroadcastCommit) + + // We expect it to be in state StateBroadcastCommit when attempting + // the force close. + select { + case state := <-stateChan: + if state != StateBroadcastCommit { + t.Fatalf("state during PublishTx was %v", state) + } + case <-time.After(15 * time.Second): + t.Fatalf("no state update received") + } + + // Make sure we get the expected error. + select { + case err := <-errChan: + if err != expErr { + t.Fatalf("unexpected error force closing channel: %v", + err) + } + case <-time.After(5 * time.Second): + t.Fatalf("no response received") + } + + // Stop the channel abitrator. + if err := chanArb.Stop(); err != nil { + t.Fatal(err) + } + + // We mimic that the channel is breached while the channel arbitrator + // is down. This means that on restart it will be started with a + // pending close channel, of type BreachClose. + chanArb, resolved, _, _, err := createTestChannelArbitrator(log) + if err != nil { + t.Fatalf("unable to create ChannelArbitrator: %v", err) + } + + chanArb.cfg.IsPendingClose = true + chanArb.cfg.ClosingHeight = 100 + chanArb.cfg.CloseType = channeldb.BreachClose + + // Start the channel abitrator again, and make sure it goes straight to + // state fully resolved, as in case of breach there is nothing to + // handle. + if err := chanArb.Start(); err != nil { + t.Fatalf("unable to start ChannelArbitrator: %v", err) + } + defer func() { + if err := chanArb.Stop(); err != nil { + t.Fatal(err) + } + }() + + // Finally it should advance to StateFullyResolved. + assertStateTransitions( + t, log.newStates, StateFullyResolved, + ) + + // It should also mark the channel as resolved. + select { + case <-resolved: + // Expected. + case <-time.After(5 * time.Second): + t.Fatalf("contract was not resolved") + } +} + // TestChannelArbitratorCommitFailure tests that the channel arbitrator is able // to recover from a failed CommitState call at restart. func TestChannelArbitratorCommitFailure(t *testing.T) {