diff --git a/breacharbiter.go b/breacharbiter.go index 420ca665..a7ad89a4 100644 --- a/breacharbiter.go +++ b/breacharbiter.go @@ -46,17 +46,19 @@ var ( // ContractBreachEvent is an event the breachArbiter will receive in case a // contract breach is observed on-chain. It contains the necessary information -// to handle the breach, and a ProcessACK channel we will use to ACK the event +// to handle the breach, and a ProcessACK closure we will use to ACK the event // when we have safely stored all the necessary information. type ContractBreachEvent struct { // ChanPoint is the channel point of the breached channel. ChanPoint wire.OutPoint - // ProcessACK is an error channel where a nil error should be sent - // iff the breach retribution info is safely stored in the retribution + // ProcessACK is an closure that should be called with a nil error iff + // the breach retribution info is safely stored in the retribution // store. In case storing the information to the store fails, a non-nil - // error should be sent. - ProcessACK chan error + // error should be used. When this closure returns, it means that the + // contract court has marked the channel pending close in the DB, and + // it is safe for the BreachArbiter to carry on its duty. + ProcessACK func(error) // BreachRetribution is the information needed to act on this contract // breach. @@ -745,10 +747,8 @@ func (b *breachArbiter) handleBreachHandoff(breachEvent *ContractBreachEvent) { b.Unlock() brarLog.Errorf("Unable to check breach info in DB: %v", err) - select { - case breachEvent.ProcessACK <- err: - case <-b.quit: - } + // Notify about the failed lookup and return. + breachEvent.ProcessACK(err) return } @@ -757,11 +757,7 @@ func (b *breachArbiter) handleBreachHandoff(breachEvent *ContractBreachEvent) { // case we can safely ACK the handoff, and return. if breached { b.Unlock() - - select { - case breachEvent.ProcessACK <- nil: - case <-b.quit: - } + breachEvent.ProcessACK(nil) return } @@ -782,14 +778,10 @@ func (b *breachArbiter) handleBreachHandoff(breachEvent *ContractBreachEvent) { // acknowledgment back to the close observer with the error. If // the ack is successful, the close observer will mark the // channel as pending-closed in the channeldb. - select { - case breachEvent.ProcessACK <- err: - // Bail if we failed to persist retribution info. - if err != nil { - return - } + breachEvent.ProcessACK(err) - case <-b.quit: + // Bail if we failed to persist retribution info. + if err != nil { return } diff --git a/breacharbiter_test.go b/breacharbiter_test.go index ce1ae42c..0abdab2e 100644 --- a/breacharbiter_test.go +++ b/breacharbiter_test.go @@ -1059,9 +1059,12 @@ func TestBreachHandoffSuccess(t *testing.T) { // Signal a spend of the funding transaction and wait for the close // observer to exit. + processACK := make(chan error) breach := &ContractBreachEvent{ - ChanPoint: *chanPoint, - ProcessACK: make(chan error, 1), + ChanPoint: *chanPoint, + ProcessACK: func(brarErr error) { + processACK <- brarErr + }, BreachRetribution: &lnwallet.BreachRetribution{ BreachTransaction: bobClose.CloseTx, LocalOutputSignDesc: &input.SignDescriptor{ @@ -1075,7 +1078,7 @@ func TestBreachHandoffSuccess(t *testing.T) { // We'll also wait to consume the ACK back from the breach arbiter. select { - case err := <-breach.ProcessACK: + case err := <-processACK: if err != nil { t.Fatalf("handoff failed: %v", err) } @@ -1092,8 +1095,10 @@ func TestBreachHandoffSuccess(t *testing.T) { // already ACKed, the breach arbiter should immediately ACK and ignore // this event. breach = &ContractBreachEvent{ - ChanPoint: *chanPoint, - ProcessACK: make(chan error, 1), + ChanPoint: *chanPoint, + ProcessACK: func(brarErr error) { + processACK <- brarErr + }, BreachRetribution: &lnwallet.BreachRetribution{ BreachTransaction: bobClose.CloseTx, LocalOutputSignDesc: &input.SignDescriptor{ @@ -1108,7 +1113,7 @@ func TestBreachHandoffSuccess(t *testing.T) { // We'll also wait to consume the ACK back from the breach arbiter. select { - case err := <-breach.ProcessACK: + case err := <-processACK: if err != nil { t.Fatalf("handoff failed: %v", err) } @@ -1140,9 +1145,12 @@ func TestBreachHandoffFail(t *testing.T) { // Signal the notifier to dispatch spend notifications of the funding // transaction using the transaction from bob's closing summary. chanPoint := alice.ChanPoint + processACK := make(chan error) breach := &ContractBreachEvent{ - ChanPoint: *chanPoint, - ProcessACK: make(chan error, 1), + ChanPoint: *chanPoint, + ProcessACK: func(brarErr error) { + processACK <- brarErr + }, BreachRetribution: &lnwallet.BreachRetribution{ BreachTransaction: bobClose.CloseTx, LocalOutputSignDesc: &input.SignDescriptor{ @@ -1156,7 +1164,7 @@ func TestBreachHandoffFail(t *testing.T) { // We'll also wait to consume the ACK back from the breach arbiter. select { - case err := <-breach.ProcessACK: + case err := <-processACK: if err == nil { t.Fatalf("breach write should have failed") } @@ -1181,8 +1189,10 @@ func TestBreachHandoffFail(t *testing.T) { // Signal a spend of the funding transaction and wait for the close // observer to exit. This time we are allowing the handoff to succeed. breach = &ContractBreachEvent{ - ChanPoint: *chanPoint, - ProcessACK: make(chan error, 1), + ChanPoint: *chanPoint, + ProcessACK: func(brarErr error) { + processACK <- brarErr + }, BreachRetribution: &lnwallet.BreachRetribution{ BreachTransaction: bobClose.CloseTx, LocalOutputSignDesc: &input.SignDescriptor{ @@ -1196,7 +1206,7 @@ func TestBreachHandoffFail(t *testing.T) { contractBreaches <- breach select { - case err := <-breach.ProcessACK: + case err := <-processACK: if err != nil { t.Fatalf("handoff failed: %v", err) } @@ -1399,16 +1409,19 @@ func testBreachSpends(t *testing.T, test breachTest) { t.Fatalf("unable to create breach retribution: %v", err) } + processACK := make(chan error) breach := &ContractBreachEvent{ - ChanPoint: *chanPoint, - ProcessACK: make(chan error, 1), + ChanPoint: *chanPoint, + ProcessACK: func(brarErr error) { + processACK <- brarErr + }, BreachRetribution: retribution, } contractBreaches <- breach // We'll also wait to consume the ACK back from the breach arbiter. select { - case err := <-breach.ProcessACK: + case err := <-processACK: if err != nil { t.Fatalf("handoff failed: %v", err) } diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index e4127cf4..3df3fa24 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -100,11 +100,14 @@ type ChainArbitratorConfig struct { MarkLinkInactive func(wire.OutPoint) error // ContractBreach is a function closure that the ChainArbitrator will - // use to notify the breachArbiter about a contract breach. It should - // only return a non-nil error when the breachArbiter has preserved the - // necessary breach info for this channel point, and it is safe to mark - // the channel as pending close in the database. - ContractBreach func(wire.OutPoint, *lnwallet.BreachRetribution) error + // use to notify the breachArbiter about a contract breach. A callback + // should be passed that when called will mark the channel pending + // close in the databae. It should only return a non-nil error when the + // breachArbiter has preserved the necessary breach info for this + // channel point, and the callback has succeeded, meaning it is safe to + // stop watching the channel. + ContractBreach func(wire.OutPoint, *lnwallet.BreachRetribution, + func() error) error // IsOurAddress is a function that returns true if the passed address // is known to the underlying wallet. Otherwise, false should be @@ -488,8 +491,12 @@ func (c *ChainArbitrator) Start() error { notifier: c.cfg.Notifier, signer: c.cfg.Signer, isOurAddr: c.cfg.IsOurAddress, - contractBreach: func(retInfo *lnwallet.BreachRetribution) error { - return c.cfg.ContractBreach(chanPoint, retInfo) + contractBreach: func(retInfo *lnwallet.BreachRetribution, + markClosed func() error) error { + + return c.cfg.ContractBreach( + chanPoint, retInfo, markClosed, + ) }, extractStateNumHint: lnwallet.GetStateNumHint, }, @@ -1078,8 +1085,12 @@ func (c *ChainArbitrator) WatchNewChannel(newChan *channeldb.OpenChannel) error notifier: c.cfg.Notifier, signer: c.cfg.Signer, isOurAddr: c.cfg.IsOurAddress, - contractBreach: func(retInfo *lnwallet.BreachRetribution) error { - return c.cfg.ContractBreach(chanPoint, retInfo) + contractBreach: func(retInfo *lnwallet.BreachRetribution, + markClosed func() error) error { + + return c.cfg.ContractBreach( + chanPoint, retInfo, markClosed, + ) }, extractStateNumHint: lnwallet.GetStateNumHint, }, diff --git a/contractcourt/chain_watcher.go b/contractcourt/chain_watcher.go index 6cd3c9dc..ebe7e5fd 100644 --- a/contractcourt/chain_watcher.go +++ b/contractcourt/chain_watcher.go @@ -150,10 +150,13 @@ type chainWatcherConfig struct { signer input.Signer // contractBreach is a method that will be called by the watcher if it - // detects that a contract breach transaction has been confirmed. Only - // when this method returns with a non-nil error it will be safe to mark - // the channel as pending close in the database. - contractBreach func(*lnwallet.BreachRetribution) error + // detects that a contract breach transaction has been confirmed. A + // callback should be passed that when called will mark the channel + // pending close in the database. It will only return a non-nil error + // when the breachArbiter has preserved the necessary breach info for + // this channel point, and the callback has succeeded, meaning it is + // safe to stop watching the channel. + contractBreach func(*lnwallet.BreachRetribution, func() error) error // isOurAddr is a function that returns true if the passed address is // known to us. @@ -1121,19 +1124,6 @@ func (c *chainWatcher) dispatchContractBreach(spendEvent *chainntnfs.SpendDetail return spew.Sdump(retribution) })) - // Hand the retribution info over to the breach arbiter. - if err := c.cfg.contractBreach(retribution); err != nil { - log.Errorf("unable to hand breached contract off to "+ - "breachArbiter: %v", err) - return err - } - - // At this point, we've successfully received an ack for the breach - // close. We now construct and persist the close summary, marking the - // channel as pending force closed. - // - // TODO(roasbeef): instead mark we got all the monies? - // TODO(halseth): move responsibility to breach arbiter? settledBalance := remoteCommit.LocalBalance.ToSatoshis() closeSummary := channeldb.ChannelCloseSummary{ ChanPoint: c.cfg.chanState.FundingOutpoint, @@ -1160,14 +1150,31 @@ func (c *chainWatcher) dispatchContractBreach(spendEvent *chainntnfs.SpendDetail closeSummary.LastChanSyncMsg = chanSync } - if err := c.cfg.chanState.CloseChannel( - &closeSummary, channeldb.ChanStatusRemoteCloseInitiator, - ); err != nil { - return err + // We create a function closure that will mark the channel as pending + // close in the database. We pass it to the contracBreach method such + // that it can ensure safe handoff of the breach before we close the + // channel. + markClosed := func() error { + // At this point, we've successfully received an ack for the + // breach close, and we can mark the channel as pending force + // closed. + if err := c.cfg.chanState.CloseChannel( + &closeSummary, channeldb.ChanStatusRemoteCloseInitiator, + ); err != nil { + return err + } + + log.Infof("Breached channel=%v marked pending-closed", + c.cfg.chanState.FundingOutpoint) + return nil } - log.Infof("Breached channel=%v marked pending-closed", - c.cfg.chanState.FundingOutpoint) + // Hand the retribution info over to the breach arbiter. + if err := c.cfg.contractBreach(retribution, markClosed); err != nil { + log.Errorf("unable to hand breached contract off to "+ + "breachArbiter: %v", err) + return err + } // With the event processed and channel closed, we'll now notify all // subscribers of the event. diff --git a/server.go b/server.go index 9d913be6..eddec220 100644 --- a/server.go +++ b/server.go @@ -949,10 +949,26 @@ func newServer(cfg *Config, listenAddrs []net.Addr, }, IsOurAddress: cc.Wallet.IsOurAddress, ContractBreach: func(chanPoint wire.OutPoint, - breachRet *lnwallet.BreachRetribution) error { + breachRet *lnwallet.BreachRetribution, + markClosed func() error) error { + + // processACK will handle the breachArbiter ACKing the + // event. + finalErr := make(chan error, 1) + processACK := func(brarErr error) { + if brarErr != nil { + finalErr <- brarErr + return + } + + // If the breachArbiter successfully handled + // the event, we can mark the channel closed. + finalErr <- markClosed() + } + event := &ContractBreachEvent{ ChanPoint: chanPoint, - ProcessACK: make(chan error, 1), + ProcessACK: processACK, BreachRetribution: breachRet, } @@ -963,9 +979,11 @@ func newServer(cfg *Config, listenAddrs []net.Addr, return ErrServerShuttingDown } - // Wait for the breachArbiter to ACK the event. + // We'll wait for a final error to be available, either + // from the breachArbiter or from our markClosed + // function closure. select { - case err := <-event.ProcessACK: + case err := <-finalErr: return err case <-s.quit: return ErrServerShuttingDown