diff --git a/chainreg/chainregistry.go b/chainreg/chainregistry.go index 246876d4..aec44b88 100644 --- a/chainreg/chainregistry.go +++ b/chainreg/chainregistry.go @@ -231,7 +231,7 @@ type ChainControl struct { // full-node, another backed by a running bitcoind full-node, and the other // backed by a running neutrino light client instance. When running with a // neutrino light client instance, `neutrinoCS` must be non-nil. -func NewChainControl(cfg *Config) (*ChainControl, error) { +func NewChainControl(cfg *Config) (*ChainControl, func(), error) { // Set the RPC config from the "home" chain. Multi-chain isn't yet // active, so we'll restrict usage to a particular chain for now. @@ -269,7 +269,7 @@ func NewChainControl(cfg *Config) (*ChainControl, error) { DefaultLitecoinStaticFeePerKW, 0, ) default: - return nil, fmt.Errorf("default routing policy for chain %v is "+ + return nil, nil, fmt.Errorf("default routing policy for chain %v is "+ "unknown", cfg.PrimaryChain()) } @@ -299,7 +299,7 @@ func NewChainControl(cfg *Config) (*ChainControl, error) { heightHintCacheConfig, cfg.LocalChanDB, ) if err != nil { - return nil, fmt.Errorf("unable to initialize height hint "+ + return nil, nil, fmt.Errorf("unable to initialize height hint "+ "cache: %v", err) } @@ -316,14 +316,14 @@ func NewChainControl(cfg *Config) (*ChainControl, error) { ) cc.ChainView, err = chainview.NewCfFilteredChainView(cfg.NeutrinoCS) if err != nil { - return nil, err + return nil, nil, err } // Map the deprecated neutrino feeurl flag to the general fee // url. if cfg.NeutrinoMode.FeeURL != "" { if cfg.FeeURL != "" { - return nil, errors.New("feeurl and " + + return nil, nil, errors.New("feeurl and " + "neutrino.feeurl are mutually exclusive") } @@ -363,7 +363,7 @@ func NewChainControl(cfg *Config) (*ChainControl, error) { // this back to the btcwallet/bitcoind port. 
rpcPort, err := strconv.Atoi(cfg.ActiveNetParams.RPCPort) if err != nil { - return nil, err + return nil, nil, err } rpcPort -= 2 bitcoindHost = fmt.Sprintf("%v:%d", @@ -400,11 +400,11 @@ func NewChainControl(cfg *Config) (*ChainControl, error) { PrunedModeMaxPeers: bitcoindMode.PrunedNodeMaxPeers, }) if err != nil { - return nil, err + return nil, nil, err } if err := bitcoindConn.Start(); err != nil { - return nil, fmt.Errorf("unable to connect to bitcoind: "+ + return nil, nil, fmt.Errorf("unable to connect to bitcoind: "+ "%v", err) } @@ -439,7 +439,7 @@ func NewChainControl(cfg *Config) (*ChainControl, error) { fallBackFeeRate.FeePerKWeight(), ) if err != nil { - return nil, err + return nil, nil, err } } else if cfg.Litecoin.Active && !cfg.Litecoin.RegTest { log.Infof("Initializing litecoind backed fee estimator in "+ @@ -455,7 +455,7 @@ func NewChainControl(cfg *Config) (*ChainControl, error) { fallBackFeeRate.FeePerKWeight(), ) if err != nil { - return nil, err + return nil, nil, err } } @@ -464,14 +464,14 @@ func NewChainControl(cfg *Config) (*ChainControl, error) { // connection. chainConn, err := rpcclient.New(rpcConfig, nil) if err != nil { - return nil, err + return nil, nil, err } // The api we will use for our health check depends on the // bitcoind version. 
cmd, err := getBitcoindHealthCheckCmd(chainConn) if err != nil { - return nil, err + return nil, nil, err } cc.HealthCheck = func() error { @@ -497,19 +497,19 @@ func NewChainControl(cfg *Config) (*ChainControl, error) { if btcdMode.RawRPCCert != "" { rpcCert, err = hex.DecodeString(btcdMode.RawRPCCert) if err != nil { - return nil, err + return nil, nil, err } } else { certFile, err := os.Open(btcdMode.RPCCert) if err != nil { - return nil, err + return nil, nil, err } rpcCert, err = ioutil.ReadAll(certFile) if err != nil { - return nil, err + return nil, nil, err } if err := certFile.Close(); err != nil { - return nil, err + return nil, nil, err } } @@ -541,7 +541,7 @@ func NewChainControl(cfg *Config) (*ChainControl, error) { rpcConfig, cfg.ActiveNetParams.Params, hintCache, hintCache, ) if err != nil { - return nil, err + return nil, nil, err } // Finally, we'll create an instance of the default chain view to be @@ -549,7 +549,7 @@ func NewChainControl(cfg *Config) (*ChainControl, error) { cc.ChainView, err = chainview.NewBtcdFilteredChainView(*rpcConfig) if err != nil { log.Errorf("unable to create chain view: %v", err) - return nil, err + return nil, nil, err } // Create a special websockets rpc client for btcd which will be used @@ -557,7 +557,7 @@ func NewChainControl(cfg *Config) (*ChainControl, error) { chainRPC, err := chain.NewRPCClient(cfg.ActiveNetParams.Params, btcdHost, btcdUser, btcdPass, rpcCert, false, 20) if err != nil { - return nil, err + return nil, nil, err } walletConfig.ChainSource = chainRPC @@ -584,11 +584,11 @@ func NewChainControl(cfg *Config) (*ChainControl, error) { *rpcConfig, fallBackFeeRate.FeePerKWeight(), ) if err != nil { - return nil, err + return nil, nil, err } } default: - return nil, fmt.Errorf("unknown node type: %s", + return nil, nil, fmt.Errorf("unknown node type: %s", homeChainConfig.Node) } @@ -599,7 +599,7 @@ func NewChainControl(cfg *Config) (*ChainControl, error) { case cfg.FeeURL == "" && cfg.Bitcoin.MainNet && 
homeChainConfig.Node == "neutrino": - return nil, fmt.Errorf("--feeurl parameter required when " + + return nil, nil, fmt.Errorf("--feeurl parameter required when " + "running neutrino on mainnet") // Override default fee estimator if an external service is specified. @@ -619,15 +619,29 @@ func NewChainControl(cfg *Config) (*ChainControl, error) { ) } + ccCleanup := func() { + if cc.Wallet != nil { + if err := cc.Wallet.Shutdown(); err != nil { + log.Errorf("Failed to shutdown wallet: %v", err) + } + } + + if cc.FeeEstimator != nil { + if err := cc.FeeEstimator.Stop(); err != nil { + log.Errorf("Failed to stop feeEstimator: %v", err) + } + } + } + // Start fee estimator. if err := cc.FeeEstimator.Start(); err != nil { - return nil, err + return nil, nil, err } wc, err := btcwallet.New(*walletConfig) if err != nil { fmt.Printf("unable to create wallet controller: %v\n", err) - return nil, err + return nil, ccCleanup, err } cc.MsgSigner = wc @@ -662,18 +676,17 @@ func NewChainControl(cfg *Config) (*ChainControl, error) { lnWallet, err := lnwallet.NewLightningWallet(walletCfg) if err != nil { fmt.Printf("unable to create wallet: %v\n", err) - return nil, err + return nil, ccCleanup, err } if err := lnWallet.Startup(); err != nil { fmt.Printf("unable to start wallet: %v\n", err) - return nil, err + return nil, ccCleanup, err } log.Info("LightningWallet opened") - cc.Wallet = lnWallet - return cc, nil + return cc, ccCleanup, nil } // getBitcoindHealthCheckCmd queries bitcoind for its version to decide which diff --git a/channelnotifier/channelnotifier.go b/channelnotifier/channelnotifier.go index 5d67fd51..f5aa2961 100644 --- a/channelnotifier/channelnotifier.go +++ b/channelnotifier/channelnotifier.go @@ -86,10 +86,12 @@ func (c *ChannelNotifier) Start() error { } // Stop signals the notifier for a graceful shutdown. 
-func (c *ChannelNotifier) Stop() { +func (c *ChannelNotifier) Stop() error { + var err error c.stopped.Do(func() { - c.ntfnServer.Stop() + err = c.ntfnServer.Stop() }) + return err } // SubscribeChannelEvents returns a subscribe.Client that will receive updates diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 5eb5d689..4d583cc7 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -465,8 +465,9 @@ func (d *AuthenticatedGossiper) start() error { } // Stop signals any active goroutines for a graceful closure. -func (d *AuthenticatedGossiper) Stop() { +func (d *AuthenticatedGossiper) Stop() error { d.stopped.Do(d.stop) + return nil } func (d *AuthenticatedGossiper) stop() { diff --git a/funding/manager.go b/funding/manager.go index a02ae114..f74b5c99 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -675,12 +675,14 @@ func (f *Manager) start() error { // Stop signals all helper goroutines to execute a graceful shutdown. This // method will block until all goroutines have exited. -func (f *Manager) Stop() { +func (f *Manager) Stop() error { f.stopped.Do(func() { log.Info("Funding manager shutting down") close(f.quit) f.wg.Wait() }) + + return nil } // nextPendingChanID returns the next free pending channel ID to be used to diff --git a/funding/manager_test.go b/funding/manager_test.go index 4a995040..fa08d2e0 100644 --- a/funding/manager_test.go +++ b/funding/manager_test.go @@ -484,7 +484,9 @@ func createTestFundingManager(t *testing.T, privKey *btcec.PrivateKey, func recreateAliceFundingManager(t *testing.T, alice *testNode) { // Stop the old fundingManager before creating a new one. 
close(alice.shutdownChannel) - alice.fundingMgr.Stop() + if err := alice.fundingMgr.Stop(); err != nil { + t.Fatalf("failed stop funding manager: %v", err) + } aliceMsgChan := make(chan lnwire.Message) aliceAnnounceChan := make(chan lnwire.Message) @@ -622,8 +624,12 @@ func tearDownFundingManagers(t *testing.T, a, b *testNode) { close(a.shutdownChannel) close(b.shutdownChannel) - a.fundingMgr.Stop() - b.fundingMgr.Stop() + if err := a.fundingMgr.Stop(); err != nil { + t.Fatalf("failed stop funding manager: %v", err) + } + if err := b.fundingMgr.Stop(); err != nil { + t.Fatalf("failed stop funding manager: %v", err) + } os.RemoveAll(a.testDir) os.RemoveAll(b.testDir) } @@ -1502,7 +1508,9 @@ func TestFundingManagerRestartBehavior(t *testing.T) { // implementation, and expect it to retry sending the fundingLocked // message. We'll explicitly shut down Alice's funding manager to // prevent a race when overriding the sendMessage implementation. - alice.fundingMgr.Stop() + if err := alice.fundingMgr.Stop(); err != nil { + t.Fatalf("failed stop funding manager: %v", err) + } bob.sendMessage = workingSendMessage recreateAliceFundingManager(t, alice) diff --git a/htlcswitch/htlcnotifier.go b/htlcswitch/htlcnotifier.go index 25953e65..e73ec543 100644 --- a/htlcswitch/htlcnotifier.go +++ b/htlcswitch/htlcnotifier.go @@ -87,12 +87,14 @@ func (h *HtlcNotifier) Start() error { } // Stop signals the notifier for a graceful shutdown. 
-func (h *HtlcNotifier) Stop() { +func (h *HtlcNotifier) Stop() error { + var err error h.stopped.Do(func() { - if err := h.ntfnServer.Stop(); err != nil { + if err = h.ntfnServer.Stop(); err != nil { log.Warnf("error stopping htlc notifier: %v", err) } }) + return err } // SubscribeHtlcEvents returns a subscribe.Client that will receive updates diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index fac7a9dd..adf18d91 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -2867,19 +2867,31 @@ func testHtcNotifier(t *testing.T, testOpts []serverOption, iterations int, if err := aliceNotifier.Start(); err != nil { t.Fatalf("could not start alice notifier") } - defer aliceNotifier.Stop() + defer func() { + if err := aliceNotifier.Stop(); err != nil { + t.Fatalf("failed to stop alice notifier: %v", err) + } + }() bobNotifier := NewHtlcNotifier(mockTime) if err := bobNotifier.Start(); err != nil { t.Fatalf("could not start bob notifier") } - defer bobNotifier.Stop() + defer func() { + if err := bobNotifier.Stop(); err != nil { + t.Fatalf("failed to stop bob notifier: %v", err) + } + }() carolNotifier := NewHtlcNotifier(mockTime) if err := carolNotifier.Start(); err != nil { t.Fatalf("could not start carol notifier") } - defer carolNotifier.Stop() + defer func() { + if err := carolNotifier.Stop(); err != nil { + t.Fatalf("failed to stop carol notifier: %v", err) + } + }() // Create a notifier server option which will set our htlc notifiers // for the three hop network. diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index b73de5e5..bb336be7 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -237,7 +237,7 @@ func (i *InvoiceRegistry) Start() error { // delete them. err = i.scanInvoicesOnStart() if err != nil { - i.Stop() + _ = i.Stop() return err } @@ -245,12 +245,13 @@ func (i *InvoiceRegistry) Start() error { } // Stop signals the registry for a graceful shutdown. 
-func (i *InvoiceRegistry) Stop() { +func (i *InvoiceRegistry) Stop() error { i.expiryWatcher.Stop() close(i.quit) i.wg.Wait() + return nil } // invoiceEvent represents a new event that has modified on invoice on disk. diff --git a/invoices/invoiceregistry_test.go b/invoices/invoiceregistry_test.go index 1d1198eb..799d5395 100644 --- a/invoices/invoiceregistry_test.go +++ b/invoices/invoiceregistry_test.go @@ -527,7 +527,11 @@ func TestCancelHoldInvoice(t *testing.T) { if err != nil { t.Fatal(err) } - defer registry.Stop() + defer func() { + if err := registry.Stop(); err != nil { + t.Fatalf("failed to stop invoice registry: %v", err) + } + }() // Add the invoice. _, err = registry.AddInvoice(testHodlInvoice, testInvoicePaymentHash) @@ -1005,7 +1009,9 @@ func TestInvoiceExpiryWithRegistry(t *testing.T) { // Give some time to the watcher to cancel everything. time.Sleep(500 * time.Millisecond) - registry.Stop() + if err = registry.Stop(); err != nil { + t.Fatalf("failed to stop invoice registry: %v", err) + } // Create the expected cancellation set before the final check. 
expectedCancellations = append( diff --git a/invoices/test_utils_test.go b/invoices/test_utils_test.go index 6a454a9e..3e49a957 100644 --- a/invoices/test_utils_test.go +++ b/invoices/test_utils_test.go @@ -215,7 +215,9 @@ func newTestContext(t *testing.T) *testContext { clock: clock, t: t, cleanup: func() { - registry.Stop() + if err = registry.Stop(); err != nil { + t.Fatalf("failed to stop invoice registry: %v", err) + } cleanup() }, } diff --git a/lnd.go b/lnd.go index 658c82b7..b82d2f7c 100644 --- a/lnd.go +++ b/lnd.go @@ -548,7 +548,10 @@ func Main(cfg *Config, lisCfg ListenerCfg, interceptor signal.Interceptor) error }, } - activeChainControl, err := chainreg.NewChainControl(chainControlCfg) + activeChainControl, cleanup, err := chainreg.NewChainControl(chainControlCfg) + if cleanup != nil { + defer cleanup() + } if err != nil { err := fmt.Errorf("unable to create chain control: %v", err) ltndLog.Error(err) diff --git a/peernotifier/peernotifier.go b/peernotifier/peernotifier.go index 0943c82a..4a92ff26 100644 --- a/peernotifier/peernotifier.go +++ b/peernotifier/peernotifier.go @@ -49,11 +49,13 @@ func (p *PeerNotifier) Start() error { } // Stop signals the notifier for a graceful shutdown. -func (p *PeerNotifier) Stop() { +func (p *PeerNotifier) Stop() error { + var err error p.stopped.Do(func() { log.Info("Stopping PeerNotifier") - p.ntfnServer.Stop() + err = p.ntfnServer.Stop() }) + return err } // SubscribePeerEvents returns a subscribe.Client that will receive updates diff --git a/server.go b/server.go index 688b9cda..9d913be6 100644 --- a/server.go +++ b/server.go @@ -1435,17 +1435,45 @@ func (s *server) Started() bool { return atomic.LoadInt32(&s.active) != 0 } +// cleaner is used to aggregate "cleanup" functions during an operation that +// starts several subsystems. 
In case one of the subsystems fails to start +// and a proper resource cleanup is required, the "run" method achieves this +// by running all these added "cleanup" functions +type cleaner []func() error + +// add is used to add a cleanup function to be called when +// the run function is executed +func (c cleaner) add(cleanup func() error) cleaner { + return append(c, cleanup) +} + +// run is used to run all the previously added cleanup functions +func (c cleaner) run() { + for i := len(c) - 1; i >= 0; i-- { + if err := c[i](); err != nil { + srvrLog.Infof("Cleanup failed: %v", err) + } + } +} + // Start starts the main daemon server, all requested listeners, and any helper // goroutines. // NOTE: This function is safe for concurrent access. func (s *server) Start() error { var startErr error + + // If one subsystem fails to start, the following code ensures that the + // previously started ones are stopped. It also ensures a proper wallet + // shutdown which is important for releasing its resources (boltdb, etc...) + cleanup := cleaner{} + s.start.Do(func() { if s.torController != nil { if err := s.createNewHiddenService(); err != nil { startErr = err return } + cleanup = cleanup.add(s.torController.Stop) } if s.natTraversal != nil { @@ -1458,6 +1486,7 @@ func (s *server) Start() error { startErr = err return } + cleanup = cleanup.add(s.hostAnn.Stop) } if s.livelinessMonitor != nil { @@ -1465,6 +1494,7 @@ func (s *server) Start() error { startErr = err return } + cleanup = cleanup.add(s.livelinessMonitor.Stop) } // Start the notification server. 
This is used so channel @@ -1476,91 +1506,134 @@ func (s *server) Start() error { startErr = err return } + cleanup = cleanup.add(s.sigPool.Stop) + if err := s.writePool.Start(); err != nil { startErr = err return } + cleanup = cleanup.add(s.writePool.Stop) + if err := s.readPool.Start(); err != nil { startErr = err return } + cleanup = cleanup.add(s.readPool.Stop) + if err := s.cc.ChainNotifier.Start(); err != nil { startErr = err return } + cleanup = cleanup.add(s.cc.ChainNotifier.Stop) + if err := s.channelNotifier.Start(); err != nil { startErr = err return } + cleanup = cleanup.add(s.channelNotifier.Stop) + if err := s.peerNotifier.Start(); err != nil { startErr = err return } + cleanup = cleanup.add(func() error { + return s.peerNotifier.Stop() + }) if err := s.htlcNotifier.Start(); err != nil { startErr = err return } + cleanup = cleanup.add(s.htlcNotifier.Stop) + if err := s.sphinx.Start(); err != nil { startErr = err return } + cleanup = cleanup.add(s.sphinx.Stop) + if s.towerClient != nil { if err := s.towerClient.Start(); err != nil { startErr = err return } + cleanup = cleanup.add(s.towerClient.Stop) } if s.anchorTowerClient != nil { if err := s.anchorTowerClient.Start(); err != nil { startErr = err return } + cleanup = cleanup.add(s.anchorTowerClient.Stop) } + if err := s.htlcSwitch.Start(); err != nil { startErr = err return } + cleanup = cleanup.add(s.htlcSwitch.Stop) + if err := s.sweeper.Start(); err != nil { startErr = err return } + cleanup = cleanup.add(s.sweeper.Stop) + if err := s.utxoNursery.Start(); err != nil { startErr = err return } + cleanup = cleanup.add(s.utxoNursery.Stop) + if err := s.chainArb.Start(); err != nil { startErr = err return } + cleanup = cleanup.add(s.chainArb.Stop) + if err := s.breachArbiter.Start(); err != nil { startErr = err return } + cleanup = cleanup.add(s.breachArbiter.Stop) + if err := s.authGossiper.Start(); err != nil { startErr = err return } + cleanup = cleanup.add(s.authGossiper.Stop) + if err := 
s.chanRouter.Start(); err != nil { startErr = err return } + cleanup = cleanup.add(s.chanRouter.Stop) + if err := s.fundingMgr.Start(); err != nil { startErr = err return } + cleanup = cleanup.add(s.fundingMgr.Stop) + if err := s.invoices.Start(); err != nil { startErr = err return } + cleanup = cleanup.add(s.invoices.Stop) + if err := s.chanStatusMgr.Start(); err != nil { startErr = err return } + cleanup = cleanup.add(s.chanStatusMgr.Stop) if err := s.chanEventStore.Start(); err != nil { startErr = err return } + cleanup = cleanup.add(func() error { + s.chanEventStore.Stop() + return nil + }) // Before we start the connMgr, we'll check to see if we have // any backups to recover. We do this now as we want to ensure @@ -1598,8 +1671,13 @@ func (s *server) Start() error { startErr = err return } + cleanup = cleanup.add(s.chanSubSwapper.Stop) s.connMgr.Start() + cleanup = cleanup.add(func() error { + s.connMgr.Stop() + return nil + }) // With all the relevant sub-systems started, we'll now attempt // to establish persistent connections to our direct channel @@ -1692,6 +1770,9 @@ func (s *server) Start() error { atomic.StoreInt32(&s.active, 1) }) + if startErr != nil { + cleanup.run() + } return startErr } @@ -1710,30 +1791,49 @@ func (s *server) Stop() error { if err := s.cc.ChainNotifier.Stop(); err != nil { srvrLog.Warnf("Unable to stop ChainNotifier: %v", err) } - s.chanRouter.Stop() - s.htlcSwitch.Stop() - s.sphinx.Stop() - s.utxoNursery.Stop() - s.breachArbiter.Stop() - s.authGossiper.Stop() - s.chainArb.Stop() - s.sweeper.Stop() - s.channelNotifier.Stop() - s.peerNotifier.Stop() - s.htlcNotifier.Stop() - if err := s.cc.Wallet.Shutdown(); err != nil { - srvrLog.Warnf("Unable to stop Wallet: %v", err) + if err := s.chanRouter.Stop(); err != nil { + srvrLog.Warnf("failed to stop chanRouter: %v", err) } - if err := s.cc.ChainView.Stop(); err != nil { - srvrLog.Warnf("Unable to stop ChainView: %v", err) + if err := s.htlcSwitch.Stop(); err != nil { + 
srvrLog.Warnf("failed to stop htlcSwitch: %v", err) + } + if err := s.sphinx.Stop(); err != nil { + srvrLog.Warnf("failed to stop sphinx: %v", err) + } + if err := s.utxoNursery.Stop(); err != nil { + srvrLog.Warnf("failed to stop utxoNursery: %v", err) + } + if err := s.breachArbiter.Stop(); err != nil { + srvrLog.Warnf("failed to stop breachArbiter: %v", err) + } + if err := s.authGossiper.Stop(); err != nil { + srvrLog.Warnf("failed to stop authGossiper: %v", err) + } + if err := s.chainArb.Stop(); err != nil { + srvrLog.Warnf("failed to stop chainArb: %v", err) + } + if err := s.sweeper.Stop(); err != nil { + srvrLog.Warnf("failed to stop sweeper: %v", err) + } + if err := s.channelNotifier.Stop(); err != nil { + srvrLog.Warnf("failed to stop channelNotifier: %v", err) + } + if err := s.peerNotifier.Stop(); err != nil { + srvrLog.Warnf("failed to stop peerNotifier: %v", err) + } + if err := s.htlcNotifier.Stop(); err != nil { + srvrLog.Warnf("failed to stop htlcNotifier: %v", err) } s.connMgr.Stop() - if err := s.cc.FeeEstimator.Stop(); err != nil { - srvrLog.Warnf("Unable to stop FeeEstimator: %v", err) + if err := s.invoices.Stop(); err != nil { + srvrLog.Warnf("failed to stop invoices: %v", err) + } + if err := s.fundingMgr.Stop(); err != nil { + srvrLog.Warnf("failed to stop fundingMgr: %v", err) + } + if err := s.chanSubSwapper.Stop(); err != nil { + srvrLog.Warnf("failed to stop chanSubSwapper: %v", err) } - s.invoices.Stop() - s.fundingMgr.Stop() - s.chanSubSwapper.Stop() s.chanEventStore.Stop() // Disconnect from each active peers to ensure that