Merge pull request #3512 from halseth/chanarb-breach-close

[contractcourt] Gracefully advance channel arbitrator state machine on breach
This commit is contained in:
Olaoluwa Osuntokun 2019-09-18 21:42:15 -07:00 committed by GitHub
commit e6cd88e1bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 212 additions and 21 deletions

@ -357,6 +357,9 @@ func (c *ChannelArbitrator) Start() error {
case channeldb.CooperativeClose: case channeldb.CooperativeClose:
trigger = coopCloseTrigger trigger = coopCloseTrigger
case channeldb.BreachClose:
trigger = breachCloseTrigger
case channeldb.LocalForceClose: case channeldb.LocalForceClose:
trigger = localCloseTrigger trigger = localCloseTrigger
@ -418,8 +421,6 @@ func (c *ChannelArbitrator) Start() error {
} }
} }
// TODO(roasbeef): cancel if breached
c.wg.Add(1) c.wg.Add(1)
go c.channelAttendant(bestHeight) go c.channelAttendant(bestHeight)
return nil return nil
@ -649,6 +650,11 @@ const (
// coopCloseTrigger is a transition trigger driven by a cooperative // coopCloseTrigger is a transition trigger driven by a cooperative
// close transaction being confirmed. // close transaction being confirmed.
coopCloseTrigger coopCloseTrigger
// breachCloseTrigger is a transition trigger driven by a remote breach
// being confirmed. In this case the channel arbitrator won't have to
// do anything, so we'll just clean up and exit gracefully.
breachCloseTrigger
) )
// String returns a human readable string describing the passed // String returns a human readable string describing the passed
@ -670,6 +676,9 @@ func (t transitionTrigger) String() string {
case coopCloseTrigger: case coopCloseTrigger:
return "coopCloseTrigger" return "coopCloseTrigger"
case breachCloseTrigger:
return "breachCloseTrigger"
default: default:
return "unknown trigger" return "unknown trigger"
} }
@ -748,8 +757,9 @@ func (c *ChannelArbitrator) stateStep(
// If the trigger is a cooperative close being confirmed, then // If the trigger is a cooperative close being confirmed, then
// we can go straight to StateFullyResolved, as there won't be // we can go straight to StateFullyResolved, as there won't be
// any contracts to resolve. // any contracts to resolve. The same is true in the case of a
case coopCloseTrigger: // breach.
case coopCloseTrigger, breachCloseTrigger:
nextState = StateFullyResolved nextState = StateFullyResolved
// Otherwise, if this state advance was triggered by a // Otherwise, if this state advance was triggered by a
@ -773,7 +783,7 @@ func (c *ChannelArbitrator) stateStep(
// StateBroadcastCommit via a user or chain trigger. On restart, // StateBroadcastCommit via a user or chain trigger. On restart,
// this state may be reexecuted after closing the channel, but // this state may be reexecuted after closing the channel, but
// failing to commit to StateContractClosed or // failing to commit to StateContractClosed or
// StateFullyResolved. In that case, one of the three close // StateFullyResolved. In that case, one of the four close
// triggers will be presented, signifying that we should skip // triggers will be presented, signifying that we should skip
// rebroadcasting, and go straight to resolving the on-chain // rebroadcasting, and go straight to resolving the on-chain
// contract or marking the channel resolved. // contract or marking the channel resolved.
@ -785,7 +795,7 @@ func (c *ChannelArbitrator) stateStep(
c.cfg.ChanPoint, trigger, StateContractClosed) c.cfg.ChanPoint, trigger, StateContractClosed)
return StateContractClosed, closeTx, nil return StateContractClosed, closeTx, nil
case coopCloseTrigger: case coopCloseTrigger, breachCloseTrigger:
log.Infof("ChannelArbitrator(%v): detected %s "+ log.Infof("ChannelArbitrator(%v): detected %s "+
"close after closing channel, fast-forwarding "+ "close after closing channel, fast-forwarding "+
"to %s to resolve contract", "to %s to resolve contract",
@ -861,7 +871,9 @@ func (c *ChannelArbitrator) stateStep(
c.cfg.ChanPoint, trigger) c.cfg.ChanPoint, trigger)
nextState = StateContractClosed nextState = StateContractClosed
case coopCloseTrigger: // If a coop close or breach was confirmed, jump straight to
// the fully resolved state.
case coopCloseTrigger, breachCloseTrigger:
log.Infof("ChannelArbitrator(%v): trigger %v, "+ log.Infof("ChannelArbitrator(%v): trigger %v, "+
" going to StateFullyResolved", " going to StateFullyResolved",
c.cfg.ChanPoint, trigger) c.cfg.ChanPoint, trigger)
@ -2026,7 +2038,7 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) {
uint32(bestHeight), chainTrigger, nil, uint32(bestHeight), chainTrigger, nil,
) )
if err != nil { if err != nil {
log.Errorf("unable to advance state: %v", err) log.Errorf("Unable to advance state: %v", err)
} }
// If as a result of this trigger, the contract is // If as a result of this trigger, the contract is
@ -2081,7 +2093,7 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) {
closeInfo.ChannelCloseSummary, closeInfo.ChannelCloseSummary,
) )
if err != nil { if err != nil {
log.Errorf("unable to mark channel closed: "+ log.Errorf("Unable to mark channel closed: "+
"%v", err) "%v", err)
return return
} }
@ -2092,7 +2104,7 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) {
closeInfo.CloseHeight, coopCloseTrigger, nil, closeInfo.CloseHeight, coopCloseTrigger, nil,
) )
if err != nil { if err != nil {
log.Errorf("unable to advance state: %v", err) log.Errorf("Unable to advance state: %v", err)
return return
} }
@ -2123,7 +2135,7 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) {
// actions on restart. // actions on restart.
err := c.log.LogContractResolutions(contractRes) err := c.log.LogContractResolutions(contractRes)
if err != nil { if err != nil {
log.Errorf("unable to write resolutions: %v", log.Errorf("Unable to write resolutions: %v",
err) err)
return return
} }
@ -2131,7 +2143,8 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) {
&closeInfo.CommitSet, &closeInfo.CommitSet,
) )
if err != nil { if err != nil {
log.Errorf("unable to write commit set: %v", err) log.Errorf("Unable to write commit set: %v",
err)
return return
} }
@ -2149,7 +2162,7 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) {
closeInfo.ChannelCloseSummary, closeInfo.ChannelCloseSummary,
) )
if err != nil { if err != nil {
log.Errorf("unable to mark "+ log.Errorf("Unable to mark "+
"channel closed: %v", err) "channel closed: %v", err)
return return
} }
@ -2161,7 +2174,7 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) {
localCloseTrigger, &closeInfo.CommitSet, localCloseTrigger, &closeInfo.CommitSet,
) )
if err != nil { if err != nil {
log.Errorf("unable to advance state: %v", err) log.Errorf("Unable to advance state: %v", err)
} }
// The remote party has broadcast the commitment on-chain. // The remote party has broadcast the commitment on-chain.
@ -2188,7 +2201,7 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) {
// actions on restart. // actions on restart.
err := c.log.LogContractResolutions(contractRes) err := c.log.LogContractResolutions(contractRes)
if err != nil { if err != nil {
log.Errorf("unable to write resolutions: %v", log.Errorf("Unable to write resolutions: %v",
err) err)
return return
} }
@ -2196,7 +2209,8 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) {
&uniClosure.CommitSet, &uniClosure.CommitSet,
) )
if err != nil { if err != nil {
log.Errorf("unable to write commit set: %v", err) log.Errorf("Unable to write commit set: %v",
err)
return return
} }
@ -2213,7 +2227,7 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) {
closeSummary := &uniClosure.ChannelCloseSummary closeSummary := &uniClosure.ChannelCloseSummary
err = c.cfg.MarkChannelClosed(closeSummary) err = c.cfg.MarkChannelClosed(closeSummary)
if err != nil { if err != nil {
log.Errorf("unable to mark channel closed: %v", log.Errorf("Unable to mark channel closed: %v",
err) err)
return return
} }
@ -2225,7 +2239,24 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) {
remoteCloseTrigger, &uniClosure.CommitSet, remoteCloseTrigger, &uniClosure.CommitSet,
) )
if err != nil { if err != nil {
log.Errorf("unable to advance state: %v", err) log.Errorf("Unable to advance state: %v", err)
}
// The remote has breached the channel. As this is handled by
// the ChainWatcher and BreachArbiter, we don't have to do
// anything in particular, so just advance our state and
// gracefully exit.
case <-c.cfg.ChainEvents.ContractBreach:
log.Infof("ChannelArbitrator(%v): remote party has "+
"breached channel!", c.cfg.ChanPoint)
// We'll advance our state machine until it reaches a
// terminal state.
_, _, err := c.advanceState(
uint32(bestHeight), breachCloseTrigger, nil,
)
if err != nil {
log.Errorf("Unable to advance state: %v", err)
} }
// A new contract has just been resolved, we'll now check our // A new contract has just been resolved, we'll now check our
@ -2239,7 +2270,7 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) {
uint32(bestHeight), chainTrigger, nil, uint32(bestHeight), chainTrigger, nil,
) )
if err != nil { if err != nil {
log.Errorf("unable to advance state: %v", err) log.Errorf("Unable to advance state: %v", err)
} }
// If we don't have anything further to do after // If we don't have anything further to do after
@ -2273,7 +2304,7 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) {
uint32(bestHeight), userTrigger, nil, uint32(bestHeight), userTrigger, nil,
) )
if err != nil { if err != nil {
log.Errorf("unable to advance state: %v", err) log.Errorf("Unable to advance state: %v", err)
} }
select { select {

@ -1,6 +1,7 @@
package contractcourt package contractcourt
import ( import (
"errors"
"fmt" "fmt"
"sync" "sync"
"testing" "testing"
@ -354,7 +355,7 @@ func TestChannelArbitratorRemoteForceClose(t *testing.T) {
t, log.newStates, StateContractClosed, StateFullyResolved, t, log.newStates, StateContractClosed, StateFullyResolved,
) )
// It should alos mark the channel as resolved. // It should also mark the channel as resolved.
select { select {
case <-resolved: case <-resolved:
// Expected. // Expected.
@ -469,6 +470,49 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) {
} }
} }
// TestChannelArbitratorBreachClose tests that the ChannelArbitrator goes
// through the expected states in case we notice a breach in the chain, and
// gracefully exits.
func TestChannelArbitratorBreachClose(t *testing.T) {
log := &mockArbitratorLog{
state: StateDefault,
newStates: make(chan ArbitratorState, 5),
}
chanArb, resolved, _, _, err := createTestChannelArbitrator(log)
if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err)
}
if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
}
defer func() {
if err := chanArb.Stop(); err != nil {
t.Fatal(err)
}
}()
// It should start out in the default state.
assertState(t, chanArb, StateDefault)
// Send a breach close event.
chanArb.cfg.ChainEvents.ContractBreach <- &lnwallet.BreachRetribution{}
// It should transition StateDefault -> StateFullyResolved.
assertStateTransitions(
t, log.newStates, StateFullyResolved,
)
// It should also mark the channel as resolved.
select {
case <-resolved:
// Expected.
case <-time.After(5 * time.Second):
t.Fatalf("contract was not resolved")
}
}
// TestChannelArbitratorLocalForceClosePendingHtlc tests that the // TestChannelArbitratorLocalForceClosePendingHtlc tests that the
// ChannelArbitrator goes through the expected states in case we request it to // ChannelArbitrator goes through the expected states in case we request it to
// force close a channel that still has an HTLC pending. // force close a channel that still has an HTLC pending.
@ -1064,6 +1108,122 @@ func TestChannelArbitratorPersistence(t *testing.T) {
} }
} }
// TestChannelArbitratorForceCloseBreachedChannel tests that the channel
// arbitrator is able to handle a channel in the process of being force closed
// is breached by the remote node. In these cases we expect the
// ChannelArbitrator to gracefully exit, as the breach is handled by other
// subsystems.
func TestChannelArbitratorForceCloseBreachedChannel(t *testing.T) {
log := &mockArbitratorLog{
state: StateDefault,
newStates: make(chan ArbitratorState, 5),
}
chanArb, _, _, _, err := createTestChannelArbitrator(log)
if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err)
}
if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
}
// It should start in StateDefault.
assertState(t, chanArb, StateDefault)
// We start by attempting a local force close. We'll return an
// unexpected publication error, causing the state machine to halt.
expErr := errors.New("intentional publication error")
stateChan := make(chan ArbitratorState)
chanArb.cfg.PublishTx = func(*wire.MsgTx) error {
// When the force close tx is being broadcasted, check that the
// state is correct at that point.
select {
case stateChan <- chanArb.state:
case <-chanArb.quit:
return fmt.Errorf("exiting")
}
return expErr
}
errChan := make(chan error, 1)
respChan := make(chan *wire.MsgTx, 1)
// With the channel found, and the request crafted, we'll send over a
// force close request to the arbitrator that watches this channel.
chanArb.forceCloseReqs <- &forceCloseReq{
errResp: errChan,
closeTx: respChan,
}
// It should transition to StateBroadcastCommit.
assertStateTransitions(t, log.newStates, StateBroadcastCommit)
// We expect it to be in state StateBroadcastCommit when attempting
// the force close.
select {
case state := <-stateChan:
if state != StateBroadcastCommit {
t.Fatalf("state during PublishTx was %v", state)
}
case <-time.After(15 * time.Second):
t.Fatalf("no state update received")
}
// Make sure we get the expected error.
select {
case err := <-errChan:
if err != expErr {
t.Fatalf("unexpected error force closing channel: %v",
err)
}
case <-time.After(5 * time.Second):
t.Fatalf("no response received")
}
// Stop the channel abitrator.
if err := chanArb.Stop(); err != nil {
t.Fatal(err)
}
// We mimic that the channel is breached while the channel arbitrator
// is down. This means that on restart it will be started with a
// pending close channel, of type BreachClose.
chanArb, resolved, _, _, err := createTestChannelArbitrator(log)
if err != nil {
t.Fatalf("unable to create ChannelArbitrator: %v", err)
}
chanArb.cfg.IsPendingClose = true
chanArb.cfg.ClosingHeight = 100
chanArb.cfg.CloseType = channeldb.BreachClose
// Start the channel abitrator again, and make sure it goes straight to
// state fully resolved, as in case of breach there is nothing to
// handle.
if err := chanArb.Start(); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
}
defer func() {
if err := chanArb.Stop(); err != nil {
t.Fatal(err)
}
}()
// Finally it should advance to StateFullyResolved.
assertStateTransitions(
t, log.newStates, StateFullyResolved,
)
// It should also mark the channel as resolved.
select {
case <-resolved:
// Expected.
case <-time.After(5 * time.Second):
t.Fatalf("contract was not resolved")
}
}
// TestChannelArbitratorCommitFailure tests that the channel arbitrator is able // TestChannelArbitratorCommitFailure tests that the channel arbitrator is able
// to recover from a failed CommitState call at restart. // to recover from a failed CommitState call at restart.
func TestChannelArbitratorCommitFailure(t *testing.T) { func TestChannelArbitratorCommitFailure(t *testing.T) {