diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index e4127cf4..3df3fa24 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -100,11 +100,14 @@ type ChainArbitratorConfig struct { MarkLinkInactive func(wire.OutPoint) error // ContractBreach is a function closure that the ChainArbitrator will - // use to notify the breachArbiter about a contract breach. It should - // only return a non-nil error when the breachArbiter has preserved the - // necessary breach info for this channel point, and it is safe to mark - // the channel as pending close in the database. - ContractBreach func(wire.OutPoint, *lnwallet.BreachRetribution) error + // use to notify the breachArbiter about a contract breach. A callback + // should be passed that when called will mark the channel pending + // close in the databae. It should only return a non-nil error when the + // breachArbiter has preserved the necessary breach info for this + // channel point, and the callback has succeeded, meaning it is safe to + // stop watching the channel. + ContractBreach func(wire.OutPoint, *lnwallet.BreachRetribution, + func() error) error // IsOurAddress is a function that returns true if the passed address // is known to the underlying wallet. Otherwise, false should be @@ -488,8 +491,12 @@ func (c *ChainArbitrator) Start() error { notifier: c.cfg.Notifier, signer: c.cfg.Signer, isOurAddr: c.cfg.IsOurAddress, - contractBreach: func(retInfo *lnwallet.BreachRetribution) error { - return c.cfg.ContractBreach(chanPoint, retInfo) + contractBreach: func(retInfo *lnwallet.BreachRetribution, + markClosed func() error) error { + + return c.cfg.ContractBreach( + chanPoint, retInfo, markClosed, + ) }, extractStateNumHint: lnwallet.GetStateNumHint, }, @@ -1078,8 +1085,12 @@ func (c *ChainArbitrator) WatchNewChannel(newChan *channeldb.OpenChannel) error notifier: c.cfg.Notifier, signer: c.cfg.Signer, isOurAddr: c.cfg.IsOurAddress, - contractBreach: func(retInfo *lnwallet.BreachRetribution) error { - return c.cfg.ContractBreach(chanPoint, retInfo) + contractBreach: func(retInfo *lnwallet.BreachRetribution, + markClosed func() error) error { + + return c.cfg.ContractBreach( + chanPoint, retInfo, markClosed, + ) }, extractStateNumHint: lnwallet.GetStateNumHint, }, diff --git a/contractcourt/chain_watcher.go b/contractcourt/chain_watcher.go index 6cd3c9dc..ebe7e5fd 100644 --- a/contractcourt/chain_watcher.go +++ b/contractcourt/chain_watcher.go @@ -150,10 +150,13 @@ type chainWatcherConfig struct { signer input.Signer // contractBreach is a method that will be called by the watcher if it - // detects that a contract breach transaction has been confirmed. Only - // when this method returns with a non-nil error it will be safe to mark - // the channel as pending close in the database. - contractBreach func(*lnwallet.BreachRetribution) error + // detects that a contract breach transaction has been confirmed. A + // callback should be passed that when called will mark the channel + // pending close in the database. It will only return a non-nil error + // when the breachArbiter has preserved the necessary breach info for + // this channel point, and the callback has succeeded, meaning it is + // safe to stop watching the channel. + contractBreach func(*lnwallet.BreachRetribution, func() error) error // isOurAddr is a function that returns true if the passed address is // known to us. @@ -1121,19 +1124,6 @@ func (c *chainWatcher) dispatchContractBreach(spendEvent *chainntnfs.SpendDetail return spew.Sdump(retribution) })) - // Hand the retribution info over to the breach arbiter. - if err := c.cfg.contractBreach(retribution); err != nil { - log.Errorf("unable to hand breached contract off to "+ - "breachArbiter: %v", err) - return err - } - - // At this point, we've successfully received an ack for the breach - // close. We now construct and persist the close summary, marking the - // channel as pending force closed. - // - // TODO(roasbeef): instead mark we got all the monies? - // TODO(halseth): move responsibility to breach arbiter? settledBalance := remoteCommit.LocalBalance.ToSatoshis() closeSummary := channeldb.ChannelCloseSummary{ ChanPoint: c.cfg.chanState.FundingOutpoint, @@ -1160,14 +1150,31 @@ func (c *chainWatcher) dispatchContractBreach(spendEvent *chainntnfs.SpendDetail closeSummary.LastChanSyncMsg = chanSync } - if err := c.cfg.chanState.CloseChannel( - &closeSummary, channeldb.ChanStatusRemoteCloseInitiator, - ); err != nil { - return err + // We create a function closure that will mark the channel as pending + // close in the database. We pass it to the contracBreach method such + // that it can ensure safe handoff of the breach before we close the + // channel. + markClosed := func() error { + // At this point, we've successfully received an ack for the + // breach close, and we can mark the channel as pending force + // closed. + if err := c.cfg.chanState.CloseChannel( + &closeSummary, channeldb.ChanStatusRemoteCloseInitiator, + ); err != nil { + return err + } + + log.Infof("Breached channel=%v marked pending-closed", + c.cfg.chanState.FundingOutpoint) + return nil } - log.Infof("Breached channel=%v marked pending-closed", - c.cfg.chanState.FundingOutpoint) + // Hand the retribution info over to the breach arbiter. + if err := c.cfg.contractBreach(retribution, markClosed); err != nil { + log.Errorf("unable to hand breached contract off to "+ + "breachArbiter: %v", err) + return err + } // With the event processed and channel closed, we'll now notify all // subscribers of the event. diff --git a/server.go b/server.go index 844e16d6..eddec220 100644 --- a/server.go +++ b/server.go @@ -949,7 +949,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr, }, IsOurAddress: cc.Wallet.IsOurAddress, ContractBreach: func(chanPoint wire.OutPoint, - breachRet *lnwallet.BreachRetribution) error { + breachRet *lnwallet.BreachRetribution, + markClosed func() error) error { // processACK will handle the breachArbiter ACKing the // event. @@ -960,7 +961,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr, return } - finalErr <- nil + // If the breachArbiter successfully handled + // the event, we can mark the channel closed. + finalErr <- markClosed() } event := &ContractBreachEvent{ @@ -976,7 +979,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr, return ErrServerShuttingDown } - // Wait for the breachArbiter to ACK the event. + // We'll wait for a final error to be available, either + // from the breachArbiter or from our markClosed + // function closure. select { case err := <-finalErr: return err