diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index a0e6a47c..0ba0d575 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -339,7 +339,7 @@ func (c *ChainArbitrator) Start() error { c.activeWatchers[channel.FundingOutpoint] = chainWatcher channelArb, err := newActiveChannelArbitrator( - channel, c, chainWatcher.SubscribeChannelEvents(), + channel, c, chainWatcher.SubscribeChannelEvents(false), ) if err != nil { return err @@ -667,7 +667,7 @@ func (c *ChainArbitrator) WatchNewChannel(newChan *channeldb.OpenChannel) error // We'll also create a new channel arbitrator instance using this new // channel, and our internal state. channelArb, err := newActiveChannelArbitrator( - newChan, c, chainWatcher.SubscribeChannelEvents(), + newChan, c, chainWatcher.SubscribeChannelEvents(false), ) if err != nil { return err @@ -687,12 +687,15 @@ func (c *ChainArbitrator) WatchNewChannel(newChan *channeldb.OpenChannel) error // SubscribeChannelEvents returns a new active subscription for the set of // possible on-chain events for a particular channel. The struct can be used by // callers to be notified whenever an event that changes the state of the -// channel on-chain occurs. +// channel on-chain occurs. If syncDispatch is true, then the sender of the +// notification will wait until an error is sent over the ProcessACK channel +// before modifying any database state. This allows callers to request a +// reliable handoff. // // TODO(roasbeef): can be used later to provide RPC hook for all channel // lifetimes func (c *ChainArbitrator) SubscribeChannelEvents( - chanPoint wire.OutPoint) (*ChainEventSubscription, error) { + chanPoint wire.OutPoint, syncDispatch bool) (*ChainEventSubscription, error) { // First, we'll attempt to look up the active watcher for this channel. // If we can't find it, then we'll return an error back to the caller. 
@@ -704,7 +707,7 @@ func (c *ChainArbitrator) SubscribeChannelEvents( // With the watcher located, we'll request for it to create a new chain // event subscription client. - return watcher.SubscribeChannelEvents(), nil + return watcher.SubscribeChannelEvents(syncDispatch), nil } // BeginCoopChanClose allows the initiator or responder to a cooperative diff --git a/contractcourt/chain_watcher.go b/contractcourt/chain_watcher.go index 940debbb..0371490c 100644 --- a/contractcourt/chain_watcher.go +++ b/contractcourt/chain_watcher.go @@ -45,7 +45,10 @@ type ChainEventSubscription struct { // synchronize dispatch and processing of the notification with the act // of updating the state of the channel on disk. This ensures that the // event can be reliably handed off. - ProcessACK chan struct{} + // + // NOTE: This channel will only be used if the syncDispatch arg passed + // into the constructor is true. + ProcessACK chan error // Cancel cancels the subscription to the event stream for a particular // channel. This method should be called once the caller no longer needs to @@ -89,7 +92,7 @@ type chainWatcher struct { signer lnwallet.Signer // All the fields below are protected by this mutex. - sync.RWMutex + sync.Mutex // clientID is an ephemeral counter used to keep track of each // individual client subscription. @@ -207,13 +210,17 @@ func (c *chainWatcher) Stop() error { // SubscribeChannelEvents returns a n active subscription to the set of channel // events for the channel watched by this chain watcher. Once clients no longer // require the subscription, they should call the Cancel() method to allow the -// watcher to regain those committed resources. -func (c *chainWatcher) SubscribeChannelEvents() *ChainEventSubscription { - c.Lock() - defer c.Unlock() +// watcher to regain those committed resources. The syncDispatch bool indicates +// if the caller would like a synchronous dispatch of the notification. 
This +// means that the main chain watcher goroutine won't proceed with +// post-processing after the notification until the ProcessACK channel is sent +// upon. +func (c *chainWatcher) SubscribeChannelEvents(syncDispatch bool) *ChainEventSubscription { + c.Lock() clientID := c.clientID c.clientID++ + c.Unlock() log.Debugf("New ChainEventSubscription(id=%v) for ChannelPoint(%v)", clientID, c.chanState.FundingOutpoint) @@ -231,7 +238,13 @@ func (c *chainWatcher) SubscribeChannelEvents() *ChainEventSubscription { }, } + if syncDispatch { + sub.ProcessACK = make(chan error, 1) + } + + c.Lock() c.clientSubscriptions[clientID] = sub + c.Unlock() return sub } @@ -547,7 +560,6 @@ func (c *chainWatcher) dispatchContractBreach(spendEvent *chainntnfs.SpendDetail } var ( - broadcastStateNum = remoteCommit.CommitHeight commitTxBroadcast = spendEvent.SpendingTx spendHeight = uint32(spendEvent.SpendingHeight) ) @@ -578,17 +590,20 @@ func (c *chainWatcher) dispatchContractBreach(spendEvent *chainntnfs.SpendDetail } // Wait for the breach arbiter to ACK the handoff before - // marking the channel as pending force closed in channeldb. - select { - case <-sub.ProcessACK: - // Bail if the handoff failed. - if err != nil { - return fmt.Errorf("unable to handoff "+ - "retribution info: %v", err) - } + // marking the channel as pending force closed in channeldb, + // but only if the client requested a sync dispatch. + if sub.ProcessACK != nil { + select { + case err := <-sub.ProcessACK: + // Bail if the handoff failed. 
+ if err != nil { + return fmt.Errorf("unable to handoff "+ + "retribution info: %v", err) + } - case <-c.quit: - return fmt.Errorf("quitting") + case <-c.quit: + return fmt.Errorf("quitting") + } } } c.Unlock() diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index aab4661d..ebd9ecd5 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -293,7 +293,7 @@ func (c *ChannelArbitrator) Stop() error { log.Debugf("Stopping ChannelArbitrator(%v)", c.cfg.ChanPoint) if c.cfg.ChainEvents.Cancel != nil { - c.cfg.ChainEvents.Cancel() + go c.cfg.ChainEvents.Cancel() } for _, activeResolver := range c.activeResolvers { @@ -1319,7 +1319,6 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32, // state, so we'll get the most up to date signals to we can // properly do our job. case signalUpdate := <-c.signalUpdates: - log.Tracef("ChannelArbitrator(%v) got new signal "+ "update!", c.cfg.ChanPoint) diff --git a/peer.go b/peer.go index 7436bc20..e142d426 100644 --- a/peer.go +++ b/peer.go @@ -370,7 +370,9 @@ func (p *peer) loadActiveChannels(chans []*channeldb.OpenChannel) error { // Register this new channel link with the HTLC Switch. This is // necessary to properly route multi-hop payments, and forward // new payments triggered by RPC clients. 
- chainEvents, err := p.server.chainArb.SubscribeChannelEvents(*chanPoint) + chainEvents, err := p.server.chainArb.SubscribeChannelEvents( + *chanPoint, false, + ) if err != nil { return err } @@ -1259,7 +1261,9 @@ out: peerLog.Errorf("unable to get best block: %v", err) continue } - chainEvents, err := p.server.chainArb.SubscribeChannelEvents(*chanPoint) + chainEvents, err := p.server.chainArb.SubscribeChannelEvents( + *chanPoint, false, + ) if err != nil { peerLog.Errorf("unable to subscribe to chain "+ "events: %v", err) diff --git a/server.go b/server.go index 321a9c88..4a4ed899 100644 --- a/server.go +++ b/server.go @@ -417,11 +417,16 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl, GenSweepScript: func() ([]byte, error) { return newSweepPkScript(cc.wallet) }, - Notifier: cc.chainNotifier, - PublishTransaction: cc.wallet.PublishTransaction, - SubscribeChannelEvents: s.chainArb.SubscribeChannelEvents, - Signer: cc.wallet.Cfg.Signer, - Store: newRetributionStore(chanDB), + Notifier: cc.chainNotifier, + PublishTransaction: cc.wallet.PublishTransaction, + SubscribeChannelEvents: func(chanPoint wire.OutPoint) (*contractcourt.ChainEventSubscription, error) { + // We'll request a sync dispatch to ensure that the channel + // is only marked as closed *after* we update our internal + // state. + return s.chainArb.SubscribeChannelEvents(chanPoint, true) + }, + Signer: cc.wallet.Cfg.Signer, + Store: newRetributionStore(chanDB), }) // Create the connection manager which will be responsible for