From 89fe21b79aca40f29fe58d4de3fbc05a04c20467 Mon Sep 17 00:00:00 2001 From: carla Date: Thu, 12 Nov 2020 15:23:24 +0200 Subject: [PATCH 1/3] contractcourt: use single block subscription for block epochs --- contractcourt/chain_arbitrator.go | 134 +++++++++++++++++++---- contractcourt/channel_arbitrator.go | 26 ++--- contractcourt/channel_arbitrator_test.go | 18 +-- 3 files changed, 128 insertions(+), 50 deletions(-) diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index 79c722cc..8be4fc0e 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -312,18 +312,8 @@ func newActiveChannelArbitrator(channel *channeldb.OpenChannel, log.Tracef("Creating ChannelArbitrator for ChannelPoint(%v)", channel.FundingOutpoint) - // We'll start by registering for a block epoch notifications so this - // channel can keep track of the current state of the main chain. - // // TODO(roasbeef): fetch best height (or pass in) so can ensure block // epoch delivers all the notifications to - // - // TODO(roasbeef): instead 1 block epoch that multi-plexes to the rest? - // * reduces the number of goroutines - blockEpoch, err := c.cfg.Notifier.RegisterBlockEpochNtfn(nil) - if err != nil { - return nil, err - } chanPoint := channel.FundingOutpoint @@ -333,7 +323,6 @@ func newActiveChannelArbitrator(channel *channeldb.OpenChannel, ChanPoint: chanPoint, Channel: c.getArbChannel(channel), ShortChanID: channel.ShortChanID(), - BlockEpochs: blockEpoch, MarkCommitmentBroadcasted: channel.MarkCommitmentBroadcasted, MarkChannelClosed: func(summary *channeldb.ChannelCloseSummary, @@ -369,7 +358,6 @@ func newActiveChannelArbitrator(channel *channeldb.OpenChannel, c.chanSource.Backend, arbCfg, c.cfg.ChainHash, chanPoint, ) if err != nil { - blockEpoch.Cancel() return nil, err } @@ -385,7 +373,6 @@ func newActiveChannelArbitrator(channel *channeldb.OpenChannel, pendingRemoteCommitment, err := channel.RemoteCommitChainTip() if err != nil && err != channeldb.ErrNoPendingCommit { - blockEpoch.Cancel() return nil, err } if pendingRemoteCommitment != nil { @@ -545,18 +532,12 @@ func (c *ChainArbitrator) Start() error { // the chain any longer, only resolve the contracts on the confirmed // commitment. for _, closeChanInfo := range closingChannels { - blockEpoch, err := c.cfg.Notifier.RegisterBlockEpochNtfn(nil) - if err != nil { - return err - } - // We can leave off the CloseContract and ForceCloseChan // methods as the channel is already closed at this point. chanPoint := closeChanInfo.ChanPoint arbCfg := ChannelArbitratorConfig{ ChanPoint: chanPoint, ShortChanID: closeChanInfo.ShortChanID, - BlockEpochs: blockEpoch, ChainArbitratorConfig: c.cfg, ChainEvents: &ChainEventSubscription{}, IsPendingClose: true, @@ -574,7 +555,6 @@ func (c *ChainArbitrator) Start() error { c.chanSource.Backend, arbCfg, c.cfg.ChainHash, chanPoint, ) if err != nil { - blockEpoch.Cancel() return err } arbCfg.MarkChannelResolved = func() error { @@ -627,8 +607,8 @@ func (c *ChainArbitrator) Start() error { } } - // Finally, we'll launch all the goroutines for each arbitrator so they - // can carry out their duties. + // Launch all the goroutines for each arbitrator so they can carry out + // their duties. for _, arbitrator := range c.activeChannels { if err := arbitrator.Start(); err != nil { c.Stop() @@ -636,11 +616,121 @@ func (c *ChainArbitrator) Start() error { } } + // Subscribe to a single stream of block epoch notifications that we + // will dispatch to all active arbitrators. + blockEpoch, err := c.cfg.Notifier.RegisterBlockEpochNtfn(nil) + if err != nil { + return err + } + + // Start our goroutine which will dispatch blocks to each arbitrator. + c.wg.Add(1) + go func() { + defer c.wg.Done() + c.dispatchBlocks(blockEpoch) + }() + // TODO(roasbeef): eventually move all breach watching here return nil } +// blockRecipient contains the information we need to dispatch a block to a +// channel arbitrator. +type blockRecipient struct { + // chanPoint is the funding outpoint of the channel. + chanPoint wire.OutPoint + + // blocks is the channel that new block heights are sent into. This + // channel should be sufficiently buffered as to not block the sender. + blocks chan<- int32 + + // quit is closed if the receiving entity is shutting down. + quit chan struct{} +} + +// dispatchBlocks consumes a block epoch notification stream and dispatches +// blocks to each of the chain arb's active channel arbitrators. This function +// must be run in a goroutine. +func (c *ChainArbitrator) dispatchBlocks( + blockEpoch *chainntnfs.BlockEpochEvent) { + + // getRecipients is a helper function which acquires the chain arb + // lock and returns a set of block recipients which can be used to + // dispatch blocks. + getRecipients := func() []blockRecipient { + c.Lock() + blocks := make([]blockRecipient, 0, len(c.activeChannels)) + for _, channel := range c.activeChannels { + blocks = append(blocks, blockRecipient{ + chanPoint: channel.cfg.ChanPoint, + blocks: channel.blocks, + quit: channel.quit, + }) + } + c.Unlock() + + return blocks + } + + // On exit, cancel our blocks subscription and close each block channel + // so that the arbitrators know they will no longer be receiving blocks. + defer func() { + blockEpoch.Cancel() + + recipients := getRecipients() + for _, recipient := range recipients { + close(recipient.blocks) + } + }() + + // Consume block epochs until we receive the instruction to shutdown. + for { + select { + // Consume block epochs, exiting if our subscription is + // terminated. + case block, ok := <-blockEpoch.Epochs: + if !ok { + log.Trace("dispatchBlocks block epoch " + + "cancelled") + return + } + + // Get the set of currently active channels block + // subscription channels and dispatch the block to + // each. + for _, recipient := range getRecipients() { + select { + // Deliver the block to the arbitrator. + case recipient.blocks <- block.Height: + + // If the recipient is shutting down, exit + // without delivering the block. This may be + // the case when two blocks are mined in quick + // succession, and the arbitrator resolves + // after the first block, and does not need to + // consume the second block. + case <-recipient.quit: + log.Debugf("channel: %v exit without "+ + "receiving block: %v", + recipient.chanPoint, + block.Height) + + // If the chain arb is shutting down, we don't + // need to deliver any more blocks (everything + // will be shutting down). + case <-c.quit: + return + } + } + + // Exit if the chain arbitrator is shutting down. + case <-c.quit: + return + } + } +} + // publishClosingTxs will load any stored cooperative or unilater closing // transactions and republish them. This helps ensure propagation of the // transactions in the event that prior publications failed. diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 2d50ca1d..86ddd87d 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -12,7 +12,6 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/kvdb" "github.com/lightningnetwork/lnd/input" @@ -34,6 +33,10 @@ const ( // anchorSweepConfTarget is the conf target used when sweeping // commitment anchors. anchorSweepConfTarget = 6 + + // arbitratorBlockBufferSize is the size of the buffer we give to each + // channel arbitrator. + arbitratorBlockBufferSize = 20 ) // WitnessSubscription represents an intent to be notified once new witnesses @@ -108,12 +111,6 @@ type ChannelArbitratorConfig struct { // to the switch during contract resolution. ShortChanID lnwire.ShortChannelID - // BlockEpochs is an active block epoch event stream backed by an - // active ChainNotifier instance. We will use new block notifications - // sent over this channel to decide when we should go on chain to - // reclaim/redeem the funds in an HTLC sent to/from us. - BlockEpochs *chainntnfs.BlockEpochEvent - // ChainEvents is an active subscription to the chain watcher for this // channel to be notified of any on-chain activity related to this // channel. @@ -325,6 +322,11 @@ type ChannelArbitrator struct { // to do its duty. cfg ChannelArbitratorConfig + // blocks is a channel that the arbitrator will receive new blocks on. + // This channel should be buffered by so that it does not block the + // sender. + blocks chan int32 + // signalUpdates is a channel that any new live signals for the channel // we're watching over will be sent. signalUpdates chan *signalUpdateMsg @@ -366,6 +368,7 @@ func NewChannelArbitrator(cfg ChannelArbitratorConfig, return &ChannelArbitrator{ log: log, + blocks: make(chan int32, arbitratorBlockBufferSize), signalUpdates: make(chan *signalUpdateMsg), htlcUpdates: make(<-chan *ContractUpdate), resolutionSignal: make(chan struct{}), @@ -397,13 +400,11 @@ func (c *ChannelArbitrator) Start() error { // machine can act accordingly. c.state, err = c.log.CurrentState() if err != nil { - c.cfg.BlockEpochs.Cancel() return err } _, bestHeight, err := c.cfg.ChainIO.GetBestBlock() if err != nil { - c.cfg.BlockEpochs.Cancel() return err } @@ -479,7 +480,6 @@ func (c *ChannelArbitrator) Start() error { c.cfg.ChanPoint) default: - c.cfg.BlockEpochs.Cancel() return err } } @@ -501,7 +501,6 @@ func (c *ChannelArbitrator) Start() error { // commitment has been confirmed on chain, and before we // advance our state step, we call InsertConfirmedCommitSet. if err := c.relaunchResolvers(commitSet, triggerHeight); err != nil { - c.cfg.BlockEpochs.Cancel() return err } } @@ -2111,7 +2110,6 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { // TODO(roasbeef): tell top chain arb we're done defer func() { - c.cfg.BlockEpochs.Cancel() c.wg.Done() }() @@ -2121,11 +2119,11 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { // A new block has arrived, we'll examine all the active HTLC's // to see if any of them have expired, and also update our // track of the best current height. - case blockEpoch, ok := <-c.cfg.BlockEpochs.Epochs: + case blockHeight, ok := <-c.blocks: if !ok { return } - bestHeight = blockEpoch.Height + bestHeight = blockHeight // If we're not in the default state, then we can // ignore this signal as we're waiting for contract diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index d3c85f26..38970b6b 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -197,8 +197,6 @@ type chanArbTestCtx struct { resolvedChan chan struct{} - blockEpochs chan *chainntnfs.BlockEpoch - incubationRequests chan struct{} resolutions chan []ResolutionMsg @@ -304,12 +302,6 @@ func withMarkClosed(markClosed func(*channeldb.ChannelCloseSummary, func createTestChannelArbitrator(t *testing.T, log ArbitratorLog, opts ...testChanArbOption) (*chanArbTestCtx, error) { - blockEpochs := make(chan *chainntnfs.BlockEpoch) - blockEpoch := &chainntnfs.BlockEpochEvent{ - Epochs: blockEpochs, - Cancel: func() {}, - } - chanPoint := wire.OutPoint{} shortChanID := lnwire.ShortChannelID{} chanEvents := &ChainEventSubscription{ @@ -366,7 +358,6 @@ func createTestChannelArbitrator(t *testing.T, log ArbitratorLog, arbCfg := &ChannelArbitratorConfig{ ChanPoint: chanPoint, ShortChanID: shortChanID, - BlockEpochs: blockEpoch, MarkChannelResolved: func() error { resolvedChan <- struct{}{} return nil @@ -433,7 +424,6 @@ func createTestChannelArbitrator(t *testing.T, log ArbitratorLog, cleanUp: cleanUp, resolvedChan: resolvedChan, resolutions: resolutionChan, - blockEpochs: blockEpochs, log: log, incubationRequests: incubateChan, sweeper: mockSweeper, @@ -1759,7 +1749,7 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) { // now mine a block (height 5), which is 5 blocks away // (our grace delta) from the expiry of that HTLC. case testCase.htlcExpired: - chanArbCtx.blockEpochs <- &chainntnfs.BlockEpoch{Height: 5} + chanArbCtx.chanArb.blocks <- 5 // Otherwise, we'll just trigger a regular force close // request. @@ -1863,7 +1853,7 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) { // so instead, we'll mine another block which'll cause // it to re-examine its state and realize there're no // more HTLCs. - chanArbCtx.blockEpochs <- &chainntnfs.BlockEpoch{Height: 6} + chanArbCtx.chanArb.blocks <- 6 chanArbCtx.AssertStateTransitions(StateFullyResolved) }) } @@ -1940,13 +1930,13 @@ func TestChannelArbitratorPendingExpiredHTLC(t *testing.T) { // We will advance the uptime to 10 seconds which should be still within // the grace period and should not trigger going to chain. testClock.SetTime(startTime.Add(time.Second * 10)) - chanArbCtx.blockEpochs <- &chainntnfs.BlockEpoch{Height: 5} + chanArbCtx.chanArb.blocks <- 5 chanArbCtx.AssertState(StateDefault) // We will advance the uptime to 16 seconds which should trigger going // to chain. testClock.SetTime(startTime.Add(time.Second * 16)) - chanArbCtx.blockEpochs <- &chainntnfs.BlockEpoch{Height: 6} + chanArbCtx.chanArb.blocks <- 6 chanArbCtx.AssertStateTransitions( StateBroadcastCommit, StateCommitmentBroadcasted, From f1404af4755e256c26061c104babda4e29ef9e00 Mon Sep 17 00:00:00 2001 From: carla Date: Thu, 12 Nov 2020 15:23:24 +0200 Subject: [PATCH 2/3] contractcourt: pass in optional txns to lookups required for arb startup To allow us to grab all of the information we need for our channel arbs in a more efficient way on startup, we add an optional tx to our lookup functions required on start. --- contractcourt/briefcase.go | 106 ++++++++++++++--------- contractcourt/briefcase_test.go | 12 +-- contractcourt/channel_arbitrator.go | 4 +- contractcourt/channel_arbitrator_test.go | 4 +- 4 files changed, 76 insertions(+), 50 deletions(-) diff --git a/contractcourt/briefcase.go b/contractcourt/briefcase.go index 6b377eeb..eb1489d5 100644 --- a/contractcourt/briefcase.go +++ b/contractcourt/briefcase.go @@ -54,8 +54,10 @@ type ArbitratorLog interface { // TODO(roasbeef): document on interface the errors expected to be // returned - // CurrentState returns the current state of the ChannelArbitrator. - CurrentState() (ArbitratorState, error) + // CurrentState returns the current state of the ChannelArbitrator. It + // takes an optional database transaction, which will be used if it is + // non-nil, otherwise the lookup will be done in its own transaction. + CurrentState(tx kvdb.RTx) (ArbitratorState, error) // CommitState persists, the current state of the chain attendant. CommitState(ArbitratorState) error @@ -96,8 +98,10 @@ type ArbitratorLog interface { InsertConfirmedCommitSet(c *CommitSet) error // FetchConfirmedCommitSet fetches the known confirmed active HTLC set - // from the database. - FetchConfirmedCommitSet() (*CommitSet, error) + // from the database. It takes an optional database transaction, which + // will be used if it is non-nil, otherwise the lookup will be done in + // its own transaction. + FetchConfirmedCommitSet(tx kvdb.RTx) (*CommitSet, error) // FetchChainActions attempts to fetch the set of previously stored // chain actions. We'll use this upon restart to properly advance our @@ -412,27 +416,28 @@ func (b *boltArbitratorLog) writeResolver(contractBucket kvdb.RwBucket, return contractBucket.Put(resKey, buf.Bytes()) } -// CurrentState returns the current state of the ChannelArbitrator. +// CurrentState returns the current state of the ChannelArbitrator. It takes an +// optional database transaction, which will be used if it is non-nil, otherwise +// the lookup will be done in its own transaction. // // NOTE: Part of the ContractResolver interface. -func (b *boltArbitratorLog) CurrentState() (ArbitratorState, error) { - var s ArbitratorState - err := kvdb.View(b.db, func(tx kvdb.RTx) error { - scopeBucket := tx.ReadBucket(b.scopeKey[:]) - if scopeBucket == nil { - return errScopeBucketNoExist - } +func (b *boltArbitratorLog) CurrentState(tx kvdb.RTx) (ArbitratorState, error) { + var ( + s ArbitratorState + err error + ) - stateBytes := scopeBucket.Get(stateKey) - if stateBytes == nil { - return nil - } + if tx != nil { + s, err = b.currentState(tx) + } else { + err = kvdb.View(b.db, func(tx kvdb.RTx) error { + s, err = b.currentState(tx) + return err + }, func() { + s = 0 + }) + } - s = ArbitratorState(stateBytes[0]) - return nil - }, func() { - s = 0 - }) if err != nil && err != errScopeBucketNoExist { return s, err } @@ -440,6 +445,20 @@ func (b *boltArbitratorLog) CurrentState() (ArbitratorState, error) { return s, nil } +func (b *boltArbitratorLog) currentState(tx kvdb.RTx) (ArbitratorState, error) { + scopeBucket := tx.ReadBucket(b.scopeKey[:]) + if scopeBucket == nil { + return 0, errScopeBucketNoExist + } + + stateBytes := scopeBucket.Get(stateKey) + if stateBytes == nil { + return 0, nil + } + + return ArbitratorState(stateBytes[0]), nil +} + // CommitState persists, the current state of the chain attendant. // // NOTE: Part of the ContractResolver interface. @@ -851,29 +870,20 @@ func (b *boltArbitratorLog) InsertConfirmedCommitSet(c *CommitSet) error { } // FetchConfirmedCommitSet fetches the known confirmed active HTLC set from the -// database. +// database. It takes an optional database transaction, which will be used if it +// is non-nil, otherwise the lookup will be done in its own transaction. // // NOTE: Part of the ContractResolver interface. -func (b *boltArbitratorLog) FetchConfirmedCommitSet() (*CommitSet, error) { +func (b *boltArbitratorLog) FetchConfirmedCommitSet(tx kvdb.RTx) (*CommitSet, error) { + if tx != nil { + return b.fetchConfirmedCommitSet(tx) + } + var c *CommitSet err := kvdb.View(b.db, func(tx kvdb.RTx) error { - scopeBucket := tx.ReadBucket(b.scopeKey[:]) - if scopeBucket == nil { - return errScopeBucketNoExist - } - - commitSetBytes := scopeBucket.Get(commitSetKey) - if commitSetBytes == nil { - return errNoCommitSet - } - - commitSet, err := decodeCommitSet(bytes.NewReader(commitSetBytes)) - if err != nil { - return err - } - - c = commitSet - return nil + var err error + c, err = b.fetchConfirmedCommitSet(tx) + return err }, func() { c = nil }) @@ -884,6 +894,22 @@ func (b *boltArbitratorLog) FetchConfirmedCommitSet() (*CommitSet, error) { return c, nil } +func (b *boltArbitratorLog) fetchConfirmedCommitSet(tx kvdb.RTx) (*CommitSet, + error) { + + scopeBucket := tx.ReadBucket(b.scopeKey[:]) + if scopeBucket == nil { + return nil, errScopeBucketNoExist + } + + commitSetBytes := scopeBucket.Get(commitSetKey) + if commitSetBytes == nil { + return nil, errNoCommitSet + } + + return decodeCommitSet(bytes.NewReader(commitSetBytes)) +} + // WipeHistory is to be called ONLY once *all* contracts have been fully // resolved, and the channel closure if finalized. This method will delete all // on-disk state within the persistent log. diff --git a/contractcourt/briefcase_test.go b/contractcourt/briefcase_test.go index 1e88f607..6c2936b4 100644 --- a/contractcourt/briefcase_test.go +++ b/contractcourt/briefcase_test.go @@ -611,7 +611,7 @@ func TestStateMutation(t *testing.T) { defer cleanUp() // The default state of an arbitrator should be StateDefault. - arbState, err := testLog.CurrentState() + arbState, err := testLog.CurrentState(nil) if err != nil { t.Fatalf("unable to read arb state: %v", err) } @@ -625,7 +625,7 @@ func TestStateMutation(t *testing.T) { if err := testLog.CommitState(StateFullyResolved); err != nil { t.Fatalf("unable to write state: %v", err) } - arbState, err = testLog.CurrentState() + arbState, err = testLog.CurrentState(nil) if err != nil { t.Fatalf("unable to read arb state: %v", err) } @@ -643,7 +643,7 @@ func TestStateMutation(t *testing.T) { // If we try to query for the state again, we should get the default // state again. - arbState, err = testLog.CurrentState() + arbState, err = testLog.CurrentState(nil) if err != nil { t.Fatalf("unable to query current state: %v", err) } @@ -687,11 +687,11 @@ func TestScopeIsolation(t *testing.T) { // Querying each log, the states should be the prior one we set, and be // disjoint. - log1State, err := testLog1.CurrentState() + log1State, err := testLog1.CurrentState(nil) if err != nil { t.Fatalf("unable to read arb state: %v", err) } - log2State, err := testLog2.CurrentState() + log2State, err := testLog2.CurrentState(nil) if err != nil { t.Fatalf("unable to read arb state: %v", err) } @@ -752,7 +752,7 @@ func TestCommitSetStorage(t *testing.T) { t.Fatalf("unable to write commit set: %v", err) } - diskCommitSet, err := testLog.FetchConfirmedCommitSet() + diskCommitSet, err := testLog.FetchConfirmedCommitSet(nil) if err != nil { t.Fatalf("unable to read commit set: %v", err) } diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 86ddd87d..8b4b3df7 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -398,7 +398,7 @@ func (c *ChannelArbitrator) Start() error { // First, we'll read our last state from disk, so our internal state // machine can act accordingly. - c.state, err = c.log.CurrentState() + c.state, err = c.log.CurrentState(nil) if err != nil { return err } @@ -454,7 +454,7 @@ func (c *ChannelArbitrator) Start() error { // older nodes, this won't be found at all, and will rely on the // existing written chain actions. Additionally, if this channel hasn't // logged any actions in the log, then this field won't be present. - commitSet, err := c.log.FetchConfirmedCommitSet() + commitSet, err := c.log.FetchConfirmedCommitSet(nil) if err != nil && err != errNoCommitSet && err != errScopeBucketNoExist { return err } diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index 38970b6b..3371998f 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -51,7 +51,7 @@ type mockArbitratorLog struct { // interface. var _ ArbitratorLog = (*mockArbitratorLog)(nil) -func (b *mockArbitratorLog) CurrentState() (ArbitratorState, error) { +func (b *mockArbitratorLog) CurrentState(kvdb.RTx) (ArbitratorState, error) { return b.state, nil } @@ -140,7 +140,7 @@ func (b *mockArbitratorLog) InsertConfirmedCommitSet(c *CommitSet) error { return nil } -func (b *mockArbitratorLog) FetchConfirmedCommitSet() (*CommitSet, error) { +func (b *mockArbitratorLog) FetchConfirmedCommitSet(kvdb.RTx) (*CommitSet, error) { return b.commitSet, nil } From 697dbf7f3a1908f32b9674a6ec087e59c13715e2 Mon Sep 17 00:00:00 2001 From: carla Date: Thu, 12 Nov 2020 15:23:25 +0200 Subject: [PATCH 3/3] contractcourt: get arbitrator state before we start each arbitrator --- contractcourt/chain_arbitrator.go | 51 +++++++++++++++-- contractcourt/channel_arbitrator.go | 73 +++++++++++++++++------- contractcourt/channel_arbitrator_test.go | 34 +++++------ 3 files changed, 115 insertions(+), 43 deletions(-) diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index 8be4fc0e..e4127cf4 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -10,6 +10,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" + "github.com/btcsuite/btcwallet/walletdb" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/kvdb" @@ -597,21 +598,63 @@ func (c *ChainArbitrator) Start() error { close(watcherErrs) }() + // stopAndLog is a helper function which shuts down the chain arb and + // logs errors if they occur. + stopAndLog := func() { + if err := c.Stop(); err != nil { + log.Errorf("ChainArbitrator could not shutdown: %v", err) + } + } + // Handle all errors returned from spawning our chain watchers. If any // of them failed, we will stop the chain arb to shutdown any active // goroutines. for err := range watcherErrs { if err != nil { - c.Stop() + stopAndLog() return err } } + // Before we start all of our arbitrators, we do a preliminary state + // lookup so that we can combine all of these lookups in a single db + // transaction. + var startStates map[wire.OutPoint]*chanArbStartState + + err = kvdb.View(c.chanSource, func(tx walletdb.ReadTx) error { + for _, arbitrator := range c.activeChannels { + startState, err := arbitrator.getStartState(tx) + if err != nil { + return err + } + + startStates[arbitrator.cfg.ChanPoint] = startState + } + + return nil + }, func() { + startStates = make( + map[wire.OutPoint]*chanArbStartState, + len(c.activeChannels), + ) + }) + if err != nil { + stopAndLog() + return err + } + // Launch all the goroutines for each arbitrator so they can carry out // their duties. for _, arbitrator := range c.activeChannels { - if err := arbitrator.Start(); err != nil { - c.Stop() + startState, ok := startStates[arbitrator.cfg.ChanPoint] + if !ok { + stopAndLog() + return fmt.Errorf("arbitrator: %v has no start state", + arbitrator.cfg.ChanPoint) + } + + if err := arbitrator.Start(startState); err != nil { + stopAndLog() return err } } @@ -1060,7 +1103,7 @@ func (c *ChainArbitrator) WatchNewChannel(newChan *channeldb.OpenChannel) error // arbitrators, then launch it. c.activeChannels[chanPoint] = channelArb - if err := channelArb.Start(); err != nil { + if err := channelArb.Start(nil); err != nil { return err } diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 8b4b3df7..6d008051 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -379,16 +379,58 @@ func NewChannelArbitrator(cfg ChannelArbitratorConfig, } } +// chanArbStartState contains the information from disk that we need to start +// up a channel arbitrator. +type chanArbStartState struct { + currentState ArbitratorState + commitSet *CommitSet +} + +// getStartState retrieves the information from disk that our channel arbitrator +// requires to start. +func (c *ChannelArbitrator) getStartState(tx kvdb.RTx) (*chanArbStartState, + error) { + + // First, we'll read our last state from disk, so our internal state + // machine can act accordingly. + state, err := c.log.CurrentState(tx) + if err != nil { + return nil, err + } + + // Next we'll fetch our confirmed commitment set. This will only exist + // if the channel has been closed out on chain for modern nodes. For + // older nodes, this won't be found at all, and will rely on the + // existing written chain actions. Additionally, if this channel hasn't + // logged any actions in the log, then this field won't be present. + commitSet, err := c.log.FetchConfirmedCommitSet(tx) + if err != nil && err != errNoCommitSet && err != errScopeBucketNoExist { + return nil, err + } + + return &chanArbStartState{ + currentState: state, + commitSet: commitSet, + }, nil +} + // Start starts all the goroutines that the ChannelArbitrator needs to operate. -func (c *ChannelArbitrator) Start() error { +// If takes a start state, which will be looked up on disk if it is not +// provided. +func (c *ChannelArbitrator) Start(state *chanArbStartState) error { if !atomic.CompareAndSwapInt32(&c.started, 0, 1) { return nil } c.startTimestamp = c.cfg.Clock.Now() - var ( - err error - ) + // If the state passed in is nil, we look it up now. + if state == nil { + var err error + state, err = c.getStartState(nil) + if err != nil { + return err + } + } log.Debugf("Starting ChannelArbitrator(%v), htlc_set=%v", c.cfg.ChanPoint, newLogClosure(func() string { @@ -396,12 +438,8 @@ func (c *ChannelArbitrator) Start() error { }), ) - // First, we'll read our last state from disk, so our internal state - // machine can act accordingly. - c.state, err = c.log.CurrentState(nil) - if err != nil { - return err - } + // Set our state from our starting state. + c.state = state.currentState _, bestHeight, err := c.cfg.ChainIO.GetBestBlock() if err != nil { @@ -449,21 +487,11 @@ func (c *ChannelArbitrator) Start() error { "triggerHeight=%v", c.cfg.ChanPoint, c.state, trigger, triggerHeight) - // Next we'll fetch our confirmed commitment set. This will only exist - // if the channel has been closed out on chain for modern nodes. For - // older nodes, this won't be found at all, and will rely on the - // existing written chain actions. Additionally, if this channel hasn't - // logged any actions in the log, then this field won't be present. - commitSet, err := c.log.FetchConfirmedCommitSet(nil) - if err != nil && err != errNoCommitSet && err != errScopeBucketNoExist { - return err - } - // We'll now attempt to advance our state forward based on the current // on-chain state, and our set of active contracts. startingState := c.state nextState, _, err := c.advanceState( - triggerHeight, trigger, commitSet, + triggerHeight, trigger, state.commitSet, ) if err != nil { switch err { @@ -500,7 +528,8 @@ func (c *ChannelArbitrator) Start() error { // receive a chain event from the chain watcher than the // commitment has been confirmed on chain, and before we // advance our state step, we call InsertConfirmedCommitSet. - if err := c.relaunchResolvers(commitSet, triggerHeight); err != nil { + err := c.relaunchResolvers(state.commitSet, triggerHeight) + if err != nil { return err } } diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index 3371998f..a4e9d003 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -275,7 +275,7 @@ func (c *chanArbTestCtx) Restart(restartClosure func(*chanArbTestCtx)) (*chanArb restartClosure(newCtx) } - if err := newCtx.chanArb.Start(); err != nil { + if err := newCtx.chanArb.Start(nil); err != nil { return nil, err } @@ -444,7 +444,7 @@ func TestChannelArbitratorCooperativeClose(t *testing.T) { t.Fatalf("unable to create ChannelArbitrator: %v", err) } - if err := chanArbCtx.chanArb.Start(); err != nil { + if err := chanArbCtx.chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer func() { @@ -506,7 +506,7 @@ func TestChannelArbitratorRemoteForceClose(t *testing.T) { } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -561,7 +561,7 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) { } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -667,7 +667,7 @@ func TestChannelArbitratorBreachClose(t *testing.T) { } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer func() { @@ -712,7 +712,7 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { chanArb.cfg.PreimageDB = newMockWitnessBeacon() chanArb.cfg.Registry = &mockRegistry{} - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -984,7 +984,7 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) { } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -1093,7 +1093,7 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) { } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -1201,7 +1201,7 @@ func TestChannelArbitratorPersistence(t *testing.T) { } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } @@ -1325,7 +1325,7 @@ func TestChannelArbitratorForceCloseBreachedChannel(t *testing.T) { } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } @@ -1487,7 +1487,7 @@ func TestChannelArbitratorCommitFailure(t *testing.T) { } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } @@ -1572,7 +1572,7 @@ func TestChannelArbitratorEmptyResolutions(t *testing.T) { chanArb.cfg.ClosingHeight = 100 chanArb.cfg.CloseType = channeldb.RemoteForceClose - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } @@ -1604,7 +1604,7 @@ func TestChannelArbitratorAlreadyForceClosed(t *testing.T) { t.Fatalf("unable to create ChannelArbitrator: %v", err) } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -1702,7 +1702,7 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) { t.Fatalf("unable to create ChannelArbitrator: %v", err) } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -1893,7 +1893,7 @@ func TestChannelArbitratorPendingExpiredHTLC(t *testing.T) { return false } - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer func() { @@ -2050,7 +2050,7 @@ func TestRemoteCloseInitiator(t *testing.T) { } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start "+ "ChannelArbitrator: %v", err) } @@ -2120,7 +2120,7 @@ func TestChannelArbitratorAnchors(t *testing.T) { {}, {}, } - if err := chanArb.Start(); err != nil { + if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer func() {