diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index 8be4fc0e..e4127cf4 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -10,6 +10,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" + "github.com/btcsuite/btcwallet/walletdb" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/kvdb" @@ -597,21 +598,63 @@ func (c *ChainArbitrator) Start() error { close(watcherErrs) }() + // stopAndLog is a helper function which shuts down the chain arb and + // logs errors if they occur. + stopAndLog := func() { + if err := c.Stop(); err != nil { + log.Errorf("ChainArbitrator could not shutdown: %v", err) + } + } + // Handle all errors returned from spawning our chain watchers. If any // of them failed, we will stop the chain arb to shutdown any active // goroutines. for err := range watcherErrs { if err != nil { - c.Stop() + stopAndLog() return err } } + // Before we start all of our arbitrators, we do a preliminary state + // lookup so that we can combine all of these lookups in a single db + // transaction. + var startStates map[wire.OutPoint]*chanArbStartState + + err = kvdb.View(c.chanSource, func(tx walletdb.ReadTx) error { + for _, arbitrator := range c.activeChannels { + startState, err := arbitrator.getStartState(tx) + if err != nil { + return err + } + + startStates[arbitrator.cfg.ChanPoint] = startState + } + + return nil + }, func() { + startStates = make( + map[wire.OutPoint]*chanArbStartState, + len(c.activeChannels), + ) + }) + if err != nil { + stopAndLog() + return err + } + // Launch all the goroutines for each arbitrator so they can carry out // their duties. for _, arbitrator := range c.activeChannels { - if err := arbitrator.Start(); err != nil { - c.Stop() + startState, ok := startStates[arbitrator.cfg.ChanPoint] + if !ok { + stopAndLog() + return fmt.Errorf("arbitrator: %v has no start state", + arbitrator.cfg.ChanPoint) + } + + if err := arbitrator.Start(startState); err != nil { + stopAndLog() return err } } @@ -1060,7 +1103,7 @@ func (c *ChainArbitrator) WatchNewChannel(newChan *channeldb.OpenChannel) error // arbitrators, then launch it. c.activeChannels[chanPoint] = channelArb - if err := channelArb.Start(); err != nil { + if err := channelArb.Start(nil); err != nil { return err } diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 8b4b3df7..6d008051 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -379,16 +379,58 @@ func NewChannelArbitrator(cfg ChannelArbitratorConfig, } } +// chanArbStartState contains the information from disk that we need to start +// up a channel arbitrator. +type chanArbStartState struct { + currentState ArbitratorState + commitSet *CommitSet +} + +// getStartState retrieves the information from disk that our channel arbitrator +// requires to start. +func (c *ChannelArbitrator) getStartState(tx kvdb.RTx) (*chanArbStartState, + error) { + + // First, we'll read our last state from disk, so our internal state + // machine can act accordingly. + state, err := c.log.CurrentState(tx) + if err != nil { + return nil, err + } + + // Next we'll fetch our confirmed commitment set. This will only exist + // if the channel has been closed out on chain for modern nodes. For + // older nodes, this won't be found at all, and will rely on the + // existing written chain actions. Additionally, if this channel hasn't + // logged any actions in the log, then this field won't be present. + commitSet, err := c.log.FetchConfirmedCommitSet(tx) + if err != nil && err != errNoCommitSet && err != errScopeBucketNoExist { + return nil, err + } + + return &chanArbStartState{ + currentState: state, + commitSet: commitSet, + }, nil +} + // Start starts all the goroutines that the ChannelArbitrator needs to operate. -func (c *ChannelArbitrator) Start() error { +// If takes a start state, which will be looked up on disk if it is not +// provided. +func (c *ChannelArbitrator) Start(state *chanArbStartState) error { if !atomic.CompareAndSwapInt32(&c.started, 0, 1) { return nil } c.startTimestamp = c.cfg.Clock.Now() - var ( - err error - ) + // If the state passed in is nil, we look it up now. + if state == nil { + var err error + state, err = c.getStartState(nil) + if err != nil { + return err + } + } log.Debugf("Starting ChannelArbitrator(%v), htlc_set=%v", c.cfg.ChanPoint, newLogClosure(func() string { @@ -396,12 +438,8 @@ func (c *ChannelArbitrator) Start() error { }), ) - // First, we'll read our last state from disk, so our internal state - // machine can act accordingly. - c.state, err = c.log.CurrentState(nil) - if err != nil { - return err - } + // Set our state from our starting state. + c.state = state.currentState _, bestHeight, err := c.cfg.ChainIO.GetBestBlock() if err != nil { @@ -449,21 +487,11 @@ func (c *ChannelArbitrator) Start() error { "triggerHeight=%v", c.cfg.ChanPoint, c.state, trigger, triggerHeight) - // Next we'll fetch our confirmed commitment set. This will only exist - // if the channel has been closed out on chain for modern nodes. For - // older nodes, this won't be found at all, and will rely on the - // existing written chain actions. Additionally, if this channel hasn't - // logged any actions in the log, then this field won't be present. - commitSet, err := c.log.FetchConfirmedCommitSet(nil) - if err != nil && err != errNoCommitSet && err != errScopeBucketNoExist { - return err - } - // We'll now attempt to advance our state forward based on the current // on-chain state, and our set of active contracts. startingState := c.state nextState, _, err := c.advanceState( - triggerHeight, trigger, commitSet, + triggerHeight, trigger, state.commitSet, ) if err != nil { switch err { @@ -500,7 +528,8 @@ func (c *ChannelArbitrator) Start() error { // receive a chain event from the chain watcher than the // commitment has been confirmed on chain, and before we // advance our state step, we call InsertConfirmedCommitSet. - if err := c.relaunchResolvers(commitSet, triggerHeight); err != nil { + err := c.relaunchResolvers(state.commitSet, triggerHeight) + if err != nil { return err } } diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index 3371998f..a4e9d003 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -275,7 +275,7 @@ func (c *chanArbTestCtx) Restart(restartClosure func(*chanArbTestCtx)) (*chanArb restartClosure(newCtx) } - if err := newCtx.chanArb.Start(); err != nil { + if err := newCtx.chanArb.Start(nil); err != nil { return nil, err } @@ -444,7 +444,7 @@ func TestChannelArbitratorCooperativeClose(t *testing.T) { t.Fatalf("unable to create ChannelArbitrator: %v", err) } - if err := chanArbCtx.chanArb.Start(); err != nil { + if err := chanArbCtx.chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer func() { @@ -506,7 +506,7 @@ func TestChannelArbitratorRemoteForceClose(t *testing.T) { } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -561,7 +561,7 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) { } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -667,7 +667,7 @@ func TestChannelArbitratorBreachClose(t *testing.T) { } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer func() { @@ -712,7 +712,7 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { chanArb.cfg.PreimageDB = newMockWitnessBeacon() chanArb.cfg.Registry = &mockRegistry{} - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -984,7 +984,7 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) { } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -1093,7 +1093,7 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) { } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -1201,7 +1201,7 @@ func TestChannelArbitratorPersistence(t *testing.T) { } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } @@ -1325,7 +1325,7 @@ func TestChannelArbitratorForceCloseBreachedChannel(t *testing.T) { } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } @@ -1487,7 +1487,7 @@ func TestChannelArbitratorCommitFailure(t *testing.T) { } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } @@ -1572,7 +1572,7 @@ func TestChannelArbitratorEmptyResolutions(t *testing.T) { chanArb.cfg.ClosingHeight = 100 chanArb.cfg.CloseType = channeldb.RemoteForceClose - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } @@ -1604,7 +1604,7 @@ func TestChannelArbitratorAlreadyForceClosed(t *testing.T) { t.Fatalf("unable to create ChannelArbitrator: %v", err) } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -1702,7 +1702,7 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) { t.Fatalf("unable to create ChannelArbitrator: %v", err) } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -1893,7 +1893,7 @@ func TestChannelArbitratorPendingExpiredHTLC(t *testing.T) { return false } - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer func() { @@ -2050,7 +2050,7 @@ func TestRemoteCloseInitiator(t *testing.T) { } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start "+ "ChannelArbitrator: %v", err) } @@ -2120,7 +2120,7 @@ func TestChannelArbitratorAnchors(t *testing.T) { {}, {}, } - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer func() {