Merge pull request #4590 from hsjoberg/signal-proper-shutdown-glob

signal: handle shutdown properly
This commit is contained in:
Johan T. Halseth 2021-03-18 19:09:31 +01:00 committed by GitHub
commit b1309277b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 170 additions and 129 deletions

@ -2,20 +2,21 @@ package build
import ( import (
"github.com/btcsuite/btclog" "github.com/btcsuite/btclog"
"github.com/lightningnetwork/lnd/signal"
) )
// ShutdownLogger wraps an existing logger with a shutdown function which will // ShutdownLogger wraps an existing logger with a shutdown function which will
// be called on Critical/Criticalf to prompt shutdown. // be called on Critical/Criticalf to prompt shutdown.
type ShutdownLogger struct { type ShutdownLogger struct {
btclog.Logger btclog.Logger
shutdown func()
} }
// NewShutdownLogger creates a shutdown logger for the log provided which will // NewShutdownLogger creates a shutdown logger for the log provided which will
// use the signal package to request shutdown on critical errors. // use the signal package to request shutdown on critical errors.
func NewShutdownLogger(logger btclog.Logger) *ShutdownLogger { func NewShutdownLogger(logger btclog.Logger, shutdown func()) *ShutdownLogger {
return &ShutdownLogger{ return &ShutdownLogger{
Logger: logger, Logger: logger,
shutdown: shutdown,
} }
} }
@ -26,6 +27,7 @@ func NewShutdownLogger(logger btclog.Logger) *ShutdownLogger {
// Note: it is part of the btclog.Logger interface. // Note: it is part of the btclog.Logger interface.
func (s *ShutdownLogger) Criticalf(format string, params ...interface{}) { func (s *ShutdownLogger) Criticalf(format string, params ...interface{}) {
s.Logger.Criticalf(format, params...) s.Logger.Criticalf(format, params...)
s.Logger.Info("Sending request for shutdown")
s.shutdown() s.shutdown()
} }
@ -36,18 +38,6 @@ func (s *ShutdownLogger) Criticalf(format string, params ...interface{}) {
// Note: it is part of the btclog.Logger interface. // Note: it is part of the btclog.Logger interface.
func (s *ShutdownLogger) Critical(v ...interface{}) { func (s *ShutdownLogger) Critical(v ...interface{}) {
s.Logger.Critical(v) s.Logger.Critical(v)
s.Logger.Info("Sending request for shutdown")
s.shutdown() s.shutdown()
} }
// shutdown checks whether we are listening for interrupts, since a shutdown
// request to the signal package will block if it is not running, and requests
// shutdown if possible.
func (s *ShutdownLogger) shutdown() {
if !signal.Listening() {
s.Logger.Info("Request for shutdown ignored")
return
}
s.Logger.Info("Sending request for shutdown")
signal.RequestShutdown()
}

@ -14,10 +14,6 @@ import (
// RotatingLogWriter is a wrapper around the LogWriter that supports log file // RotatingLogWriter is a wrapper around the LogWriter that supports log file
// rotation. // rotation.
type RotatingLogWriter struct { type RotatingLogWriter struct {
// GenSubLogger is a function that returns a new logger for a subsystem
// belonging to the current RotatingLogWriter.
GenSubLogger func(string) btclog.Logger
logWriter *LogWriter logWriter *LogWriter
backendLog *btclog.Backend backendLog *btclog.Backend
@ -39,16 +35,19 @@ func NewRotatingLogWriter() *RotatingLogWriter {
logWriter := &LogWriter{} logWriter := &LogWriter{}
backendLog := btclog.NewBackend(logWriter) backendLog := btclog.NewBackend(logWriter)
return &RotatingLogWriter{ return &RotatingLogWriter{
GenSubLogger: func(tag string) btclog.Logger {
logger := backendLog.Logger(tag)
return NewShutdownLogger(logger)
},
logWriter: logWriter, logWriter: logWriter,
backendLog: backendLog, backendLog: backendLog,
subsystemLoggers: SubLoggers{}, subsystemLoggers: SubLoggers{},
} }
} }
// GenSubLogger creates a new sublogger. A shutdown callback function
// is provided to be able to shutdown in case of a critical error.
func (r *RotatingLogWriter) GenSubLogger(tag string, shutdown func()) btclog.Logger {
logger := r.backendLog.Logger(tag)
return NewShutdownLogger(logger, shutdown)
}
// RegisterSubLogger registers a new subsystem logger. // RegisterSubLogger registers a new subsystem logger.
func (r *RotatingLogWriter) RegisterSubLogger(subsystem string, func (r *RotatingLogWriter) RegisterSubLogger(subsystem string,
logger btclog.Logger) { logger btclog.Logger) {

@ -18,7 +18,6 @@ import (
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lnwallet/chanfunding" "github.com/lightningnetwork/lnd/lnwallet/chanfunding"
"github.com/lightningnetwork/lnd/signal"
"github.com/urfave/cli" "github.com/urfave/cli"
) )
@ -467,7 +466,7 @@ func openChannelPsbt(rpcCtx context.Context, ctx *cli.Context,
// the server. // the server.
go func() { go func() {
select { select {
case <-signal.ShutdownChannel(): case <-rpcCtx.Done():
fmt.Printf("\nInterrupt signal received.\n") fmt.Printf("\nInterrupt signal received.\n")
close(quit) close(quit)

@ -41,14 +41,15 @@ const (
) )
func getContext() context.Context { func getContext() context.Context {
if err := signal.Intercept(); err != nil { shutdownInterceptor, err := signal.Intercept()
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, err) _, _ = fmt.Fprintln(os.Stderr, err)
os.Exit(1) os.Exit(1)
} }
ctxc, cancel := context.WithCancel(context.Background()) ctxc, cancel := context.WithCancel(context.Background())
go func() { go func() {
<-signal.ShutdownChannel() <-shutdownInterceptor.ShutdownChannel()
cancel() cancel()
}() }()
return ctxc return ctxc

@ -10,9 +10,16 @@ import (
) )
func main() { func main() {
// Hook interceptor for os signals.
shutdownInterceptor, err := signal.Intercept()
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
// Load the configuration, and parse any command line options. This // Load the configuration, and parse any command line options. This
// function will also set up logging properly. // function will also set up logging properly.
loadedConfig, err := lnd.LoadConfig() loadedConfig, err := lnd.LoadConfig(shutdownInterceptor)
if err != nil { if err != nil {
if e, ok := err.(*flags.Error); !ok || e.Type != flags.ErrHelp { if e, ok := err.(*flags.Error); !ok || e.Type != flags.ErrHelp {
// Print error if not due to help request. // Print error if not due to help request.
@ -24,16 +31,10 @@ func main() {
os.Exit(0) os.Exit(0)
} }
// Hook interceptor for os signals.
if err := signal.Intercept(); err != nil {
_, _ = fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
// Call the "real" main in a nested manner so the defers will properly // Call the "real" main in a nested manner so the defers will properly
// be executed in the case of a graceful shutdown. // be executed in the case of a graceful shutdown.
if err := lnd.Main( if err = lnd.Main(
loadedConfig, lnd.ListenerCfg{}, signal.ShutdownChannel(), loadedConfig, lnd.ListenerCfg{}, shutdownInterceptor,
); err != nil { ); err != nil {
_, _ = fmt.Fprintln(os.Stderr, err) _, _ = fmt.Fprintln(os.Stderr, err)
os.Exit(1) os.Exit(1)

@ -36,6 +36,7 @@ import (
"github.com/lightningnetwork/lnd/lnrpc/signrpc" "github.com/lightningnetwork/lnd/lnrpc/signrpc"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/routing" "github.com/lightningnetwork/lnd/routing"
"github.com/lightningnetwork/lnd/signal"
"github.com/lightningnetwork/lnd/tor" "github.com/lightningnetwork/lnd/tor"
) )
@ -510,7 +511,7 @@ func DefaultConfig() Config {
// 2) Pre-parse the command line to check for an alternative config file // 2) Pre-parse the command line to check for an alternative config file
// 3) Load configuration file overwriting defaults with any specified options // 3) Load configuration file overwriting defaults with any specified options
// 4) Parse CLI options and overwrite/add any specified options // 4) Parse CLI options and overwrite/add any specified options
func LoadConfig() (*Config, error) { func LoadConfig(interceptor signal.Interceptor) (*Config, error) {
// Pre-parse the command line options to pick up an alternative config // Pre-parse the command line options to pick up an alternative config
// file. // file.
preCfg := DefaultConfig() preCfg := DefaultConfig()
@ -563,7 +564,7 @@ func LoadConfig() (*Config, error) {
} }
// Make sure everything we just loaded makes sense. // Make sure everything we just loaded makes sense.
cleanCfg, err := ValidateConfig(cfg, usageMessage) cleanCfg, err := ValidateConfig(cfg, usageMessage, interceptor)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -581,7 +582,8 @@ func LoadConfig() (*Config, error) {
// ValidateConfig check the given configuration to be sane. This makes sure no // ValidateConfig check the given configuration to be sane. This makes sure no
// illegal values or combination of values are set. All file system paths are // illegal values or combination of values are set. All file system paths are
// normalized. The cleaned up config is returned on success. // normalized. The cleaned up config is returned on success.
func ValidateConfig(cfg Config, usageMessage string) (*Config, error) { func ValidateConfig(cfg Config, usageMessage string,
interceptor signal.Interceptor) (*Config, error) {
// If the provided lnd directory is not the default, we'll modify the // If the provided lnd directory is not the default, we'll modify the
// path to all of the files and directories that will live within it. // path to all of the files and directories that will live within it.
lndDir := CleanAndExpandPath(cfg.LndDir) lndDir := CleanAndExpandPath(cfg.LndDir)
@ -1151,7 +1153,7 @@ func ValidateConfig(cfg Config, usageMessage string) (*Config, error) {
} }
// Initialize logging at the default logging level. // Initialize logging at the default logging level.
SetupLoggers(cfg.LogWriter) SetupLoggers(cfg.LogWriter, interceptor)
err = cfg.LogWriter.InitLogRotator( err = cfg.LogWriter.InitLogRotator(
filepath.Join(cfg.LogDir, defaultLogFilename), filepath.Join(cfg.LogDir, defaultLogFilename),
cfg.MaxLogFileSize, cfg.MaxLogFiles, cfg.MaxLogFileSize, cfg.MaxLogFiles,

15
lnd.go

@ -190,7 +190,7 @@ type ListenerCfg struct {
// validated main configuration struct and an optional listener config struct. // validated main configuration struct and an optional listener config struct.
// This function starts all main system components then blocks until a signal // This function starts all main system components then blocks until a signal
// is received on the shutdownChan at which point everything is shut down again. // is received on the shutdownChan at which point everything is shut down again.
func Main(cfg *Config, lisCfg ListenerCfg, shutdownChan <-chan struct{}) error { func Main(cfg *Config, lisCfg ListenerCfg, interceptor signal.Interceptor) error {
defer func() { defer func() {
ltndLog.Info("Shutdown complete\n") ltndLog.Info("Shutdown complete\n")
err := cfg.LogWriter.Close() err := cfg.LogWriter.Close()
@ -378,6 +378,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, shutdownChan <-chan struct{}) error {
rpcServer := newRPCServer( rpcServer := newRPCServer(
cfg, interceptorChain, lisCfg.ExternalRPCSubserverCfg, cfg, interceptorChain, lisCfg.ExternalRPCSubserverCfg,
lisCfg.ExternalRestRegistrar, lisCfg.ExternalRestRegistrar,
interceptor,
) )
err = rpcServer.RegisterWithGrpcServer(grpcServer) err = rpcServer.RegisterWithGrpcServer(grpcServer)
@ -408,7 +409,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, shutdownChan <-chan struct{}) error {
// started with the --noseedbackup flag, we use the default password // started with the --noseedbackup flag, we use the default password
// for wallet encryption. // for wallet encryption.
if !cfg.NoSeedBackup { if !cfg.NoSeedBackup {
params, err := waitForWalletPassword(cfg, pwService) params, err := waitForWalletPassword(cfg, pwService, interceptor.ShutdownChannel())
if err != nil { if err != nil {
err := fmt.Errorf("unable to set up wallet password "+ err := fmt.Errorf("unable to set up wallet password "+
"listeners: %v", err) "listeners: %v", err)
@ -793,7 +794,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, shutdownChan <-chan struct{}) error {
"start_height=%v", bestHeight) "start_height=%v", bestHeight)
for { for {
if !signal.Alive() { if !interceptor.Alive() {
return nil return nil
} }
@ -856,7 +857,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, shutdownChan <-chan struct{}) error {
// Wait for shutdown signal from either a graceful server stop or from // Wait for shutdown signal from either a graceful server stop or from
// the interrupt handler. // the interrupt handler.
<-shutdownChan <-interceptor.ShutdownChannel()
return nil return nil
} }
@ -1328,7 +1329,8 @@ func startRestProxy(cfg *Config, rpcServer *rpcServer, restDialOpts []grpc.DialO
// waitForWalletPassword blocks until a password is provided by the user to // waitForWalletPassword blocks until a password is provided by the user to
// this RPC server. // this RPC server.
func waitForWalletPassword(cfg *Config, func waitForWalletPassword(cfg *Config,
pwService *walletunlocker.UnlockerService) (*WalletUnlockParams, error) { pwService *walletunlocker.UnlockerService,
shutdownChan <-chan struct{}) (*WalletUnlockParams, error) {
chainConfig := cfg.Bitcoin chainConfig := cfg.Bitcoin
if cfg.registeredChains.PrimaryChain() == chainreg.LitecoinChain { if cfg.registeredChains.PrimaryChain() == chainreg.LitecoinChain {
@ -1430,7 +1432,8 @@ func waitForWalletPassword(cfg *Config,
MacResponseChan: pwService.MacResponseChan, MacResponseChan: pwService.MacResponseChan,
}, nil }, nil
case <-signal.ShutdownChannel(): // If we got a shutdown signal we just return with an error immediately
case <-shutdownChan:
return nil, fmt.Errorf("shutting down") return nil, fmt.Errorf("shutting down")
} }
} }

104
log.go

@ -84,12 +84,36 @@ var (
atplLog = addLndPkgLogger("ATPL") atplLog = addLndPkgLogger("ATPL")
) )
// genSubLogger creates a logger for a subsystem. We provide an instance of
// a signal.Interceptor to be able to shutdown in the case of a critical error.
func genSubLogger(root *build.RotatingLogWriter,
interceptor signal.Interceptor) func(string) btclog.Logger {
// Create a shutdown function which will request shutdown from our
// interceptor if it is listening.
shutdown := func() {
if !interceptor.Listening() {
return
}
interceptor.RequestShutdown()
}
// Return a function which will create a sublogger from our root
// logger without shutdown fn.
return func(tag string) btclog.Logger {
return root.GenSubLogger(tag, shutdown)
}
}
// SetupLoggers initializes all package-global logger variables. // SetupLoggers initializes all package-global logger variables.
func SetupLoggers(root *build.RotatingLogWriter) { func SetupLoggers(root *build.RotatingLogWriter, interceptor signal.Interceptor) {
genLogger := genSubLogger(root, interceptor)
// Now that we have the proper root logger, we can replace the // Now that we have the proper root logger, we can replace the
// placeholder lnd package loggers. // placeholder lnd package loggers.
for _, l := range lndPkgLoggers { for _, l := range lndPkgLoggers {
l.Logger = build.NewSubLogger(l.subsystem, root.GenSubLogger) l.Logger = build.NewSubLogger(l.subsystem, genLogger)
SetSubLogger(root, l.subsystem, l.Logger) SetSubLogger(root, l.subsystem, l.Logger)
} }
@ -98,51 +122,55 @@ func SetupLoggers(root *build.RotatingLogWriter) {
signal.UseLogger(ltndLog) signal.UseLogger(ltndLog)
autopilot.UseLogger(atplLog) autopilot.UseLogger(atplLog)
AddSubLogger(root, "LNWL", lnwallet.UseLogger) AddSubLogger(root, "LNWL", interceptor, lnwallet.UseLogger)
AddSubLogger(root, "DISC", discovery.UseLogger) AddSubLogger(root, "DISC", interceptor, discovery.UseLogger)
AddSubLogger(root, "NTFN", chainntnfs.UseLogger) AddSubLogger(root, "NTFN", interceptor, chainntnfs.UseLogger)
AddSubLogger(root, "CHDB", channeldb.UseLogger) AddSubLogger(root, "CHDB", interceptor, channeldb.UseLogger)
AddSubLogger(root, "HSWC", htlcswitch.UseLogger) AddSubLogger(root, "HSWC", interceptor, htlcswitch.UseLogger)
AddSubLogger(root, "CMGR", connmgr.UseLogger) AddSubLogger(root, "CMGR", interceptor, connmgr.UseLogger)
AddSubLogger(root, "BTCN", neutrino.UseLogger) AddSubLogger(root, "BTCN", interceptor, neutrino.UseLogger)
AddSubLogger(root, "CNCT", contractcourt.UseLogger) AddSubLogger(root, "CNCT", interceptor, contractcourt.UseLogger)
AddSubLogger(root, "SPHX", sphinx.UseLogger) AddSubLogger(root, "SPHX", interceptor, sphinx.UseLogger)
AddSubLogger(root, "SWPR", sweep.UseLogger) AddSubLogger(root, "SWPR", interceptor, sweep.UseLogger)
AddSubLogger(root, "SGNR", signrpc.UseLogger) AddSubLogger(root, "SGNR", interceptor, signrpc.UseLogger)
AddSubLogger(root, "WLKT", walletrpc.UseLogger) AddSubLogger(root, "WLKT", interceptor, walletrpc.UseLogger)
AddSubLogger(root, "ARPC", autopilotrpc.UseLogger) AddSubLogger(root, "ARPC", interceptor, autopilotrpc.UseLogger)
AddSubLogger(root, "INVC", invoices.UseLogger) AddSubLogger(root, "INVC", interceptor, invoices.UseLogger)
AddSubLogger(root, "NANN", netann.UseLogger) AddSubLogger(root, "NANN", interceptor, netann.UseLogger)
AddSubLogger(root, "WTWR", watchtower.UseLogger) AddSubLogger(root, "WTWR", interceptor, watchtower.UseLogger)
AddSubLogger(root, "NTFR", chainrpc.UseLogger) AddSubLogger(root, "NTFR", interceptor, chainrpc.UseLogger)
AddSubLogger(root, "IRPC", invoicesrpc.UseLogger) AddSubLogger(root, "IRPC", interceptor, invoicesrpc.UseLogger)
AddSubLogger(root, "CHNF", channelnotifier.UseLogger) AddSubLogger(root, "CHNF", interceptor, channelnotifier.UseLogger)
AddSubLogger(root, "CHBU", chanbackup.UseLogger) AddSubLogger(root, "CHBU", interceptor, chanbackup.UseLogger)
AddSubLogger(root, "PROM", monitoring.UseLogger) AddSubLogger(root, "PROM", interceptor, monitoring.UseLogger)
AddSubLogger(root, "WTCL", wtclient.UseLogger) AddSubLogger(root, "WTCL", interceptor, wtclient.UseLogger)
AddSubLogger(root, "PRNF", peernotifier.UseLogger) AddSubLogger(root, "PRNF", interceptor, peernotifier.UseLogger)
AddSubLogger(root, "CHFD", chanfunding.UseLogger) AddSubLogger(root, "CHFD", interceptor, chanfunding.UseLogger)
AddSubLogger(root, "PEER", peer.UseLogger) AddSubLogger(root, "PEER", interceptor, peer.UseLogger)
AddSubLogger(root, "CHCL", chancloser.UseLogger) AddSubLogger(root, "CHCL", interceptor, chancloser.UseLogger)
AddSubLogger(root, routing.Subsystem, routing.UseLogger, localchans.UseLogger) AddSubLogger(root, routing.Subsystem, interceptor, routing.UseLogger, localchans.UseLogger)
AddSubLogger(root, routerrpc.Subsystem, routerrpc.UseLogger) AddSubLogger(root, routerrpc.Subsystem, interceptor, routerrpc.UseLogger)
AddSubLogger(root, chanfitness.Subsystem, chanfitness.UseLogger) AddSubLogger(root, chanfitness.Subsystem, interceptor, chanfitness.UseLogger)
AddSubLogger(root, verrpc.Subsystem, verrpc.UseLogger) AddSubLogger(root, verrpc.Subsystem, interceptor, verrpc.UseLogger)
AddSubLogger(root, healthcheck.Subsystem, healthcheck.UseLogger) AddSubLogger(root, healthcheck.Subsystem, interceptor, healthcheck.UseLogger)
AddSubLogger(root, chainreg.Subsystem, chainreg.UseLogger) AddSubLogger(root, chainreg.Subsystem, interceptor, chainreg.UseLogger)
AddSubLogger(root, chanacceptor.Subsystem, chanacceptor.UseLogger) AddSubLogger(root, chanacceptor.Subsystem, interceptor, chanacceptor.UseLogger)
AddSubLogger(root, funding.Subsystem, funding.UseLogger) AddSubLogger(root, funding.Subsystem, interceptor, funding.UseLogger)
} }
// AddSubLogger is a helper method to conveniently create and register the // AddSubLogger is a helper method to conveniently create and register the
// logger of one or more sub systems. // logger of one or more sub systems.
func AddSubLogger(root *build.RotatingLogWriter, subsystem string, func AddSubLogger(root *build.RotatingLogWriter, subsystem string,
useLoggers ...func(btclog.Logger)) { interceptor signal.Interceptor, useLoggers ...func(btclog.Logger)) {
// genSubLogger will return a callback for creating a logger instance,
// which we will give to the root logger.
genLogger := genSubLogger(root, interceptor)
// Create and register just a single logger to prevent them from // Create and register just a single logger to prevent them from
// overwriting each other internally. // overwriting each other internally.
logger := build.NewSubLogger(subsystem, root.GenSubLogger) logger := build.NewSubLogger(subsystem, genLogger)
SetSubLogger(root, subsystem, logger, useLoggers...) SetSubLogger(root, subsystem, logger, useLoggers...)
} }

@ -45,17 +45,18 @@ func Start(extraArgs string, unlockerReady, rpcReady Callback) {
// LoadConfig below. // LoadConfig below.
os.Args = append(os.Args, splitArgs...) os.Args = append(os.Args, splitArgs...)
// Load the configuration, and parse the extra arguments as command // Hook interceptor for os signals.
// line options. This function will also set up logging properly. shutdownInterceptor, err := signal.Intercept()
loadedConfig, err := lnd.LoadConfig()
if err != nil { if err != nil {
_, _ = fmt.Fprintln(os.Stderr, err) _, _ = fmt.Fprintln(os.Stderr, err)
rpcReady.OnError(err) rpcReady.OnError(err)
return return
} }
// Hook interceptor for os signals. // Load the configuration, and parse the extra arguments as command
if err := signal.Intercept(); err != nil { // line options. This function will also set up logging properly.
loadedConfig, err := lnd.LoadConfig(shutdownInterceptor)
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, err) _, _ = fmt.Fprintln(os.Stderr, err)
rpcReady.OnError(err) rpcReady.OnError(err)
return return
@ -85,7 +86,7 @@ func Start(extraArgs string, unlockerReady, rpcReady Callback) {
// be executed in the case of a graceful shutdown. // be executed in the case of a graceful shutdown.
go func() { go func() {
if err := lnd.Main( if err := lnd.Main(
loadedConfig, cfg, signal.ShutdownChannel(), loadedConfig, cfg, shutdownInterceptor,
); err != nil { ); err != nil {
if e, ok := err.(*flags.Error); ok && if e, ok := err.(*flags.Error); ok &&
e.Type == flags.ErrHelp { e.Type == flags.ErrHelp {

@ -520,6 +520,9 @@ type rpcServer struct {
// extRestRegistrar is optional and specifies the registration // extRestRegistrar is optional and specifies the registration
// callback to register external REST subservers. // callback to register external REST subservers.
extRestRegistrar RestRegistrar extRestRegistrar RestRegistrar
// interceptor is used to be able to request a shutdown
interceptor signal.Interceptor
} }
// A compile time check to ensure that rpcServer fully implements the // A compile time check to ensure that rpcServer fully implements the
@ -531,7 +534,8 @@ var _ lnrpc.LightningServer = (*rpcServer)(nil)
// be used to register the LightningService with the gRPC server. // be used to register the LightningService with the gRPC server.
func newRPCServer(cfg *Config, interceptorChain *rpcperms.InterceptorChain, func newRPCServer(cfg *Config, interceptorChain *rpcperms.InterceptorChain,
extSubserverCfg *RPCSubserverConfig, extSubserverCfg *RPCSubserverConfig,
extRestRegistrar RestRegistrar) *rpcServer { extRestRegistrar RestRegistrar,
interceptor signal.Interceptor) *rpcServer {
// We go trhough the list of registered sub-servers, and create a gRPC // We go trhough the list of registered sub-servers, and create a gRPC
// handler for each. These are used to register with the gRPC server // handler for each. These are used to register with the gRPC server
@ -552,6 +556,7 @@ func newRPCServer(cfg *Config, interceptorChain *rpcperms.InterceptorChain,
extSubserverCfg: extSubserverCfg, extSubserverCfg: extSubserverCfg,
extRestRegistrar: extRestRegistrar, extRestRegistrar: extRestRegistrar,
quit: make(chan struct{}, 1), quit: make(chan struct{}, 1),
interceptor: interceptor,
} }
} }
@ -5383,8 +5388,7 @@ func (r *rpcServer) GetNetworkInfo(ctx context.Context,
// a graceful shutdown of the daemon. // a graceful shutdown of the daemon.
func (r *rpcServer) StopDaemon(ctx context.Context, func (r *rpcServer) StopDaemon(ctx context.Context,
_ *lnrpc.StopRequest) (*lnrpc.StopResponse, error) { _ *lnrpc.StopRequest) (*lnrpc.StopResponse, error) {
r.interceptor.RequestShutdown()
signal.RequestShutdown()
return &lnrpc.StopResponse{}, nil return &lnrpc.StopResponse{}, nil
} }

@ -14,29 +14,40 @@ import (
) )
var ( var (
// interruptChannel is used to receive SIGINT (Ctrl+C) signals.
interruptChannel = make(chan os.Signal, 1)
// shutdownRequestChannel is used to request the daemon to shutdown
// gracefully, similar to when receiving SIGINT.
shutdownRequestChannel = make(chan struct{})
// started indicates whether we have started our main interrupt handler. // started indicates whether we have started our main interrupt handler.
// This field should be used atomically. // This field should be used atomically.
started int32 started int32
// quit is closed when instructing the main interrupt handler to exit.
quit = make(chan struct{})
// shutdownChannel is closed once the main interrupt handler exits.
shutdownChannel = make(chan struct{})
) )
// Intercept starts the interception of interrupt signals. Note that this // Interceptor contains channels and methods regarding application shutdown
// function can only be called once. // and interrupt signals
func Intercept() error { type Interceptor struct {
// interruptChannel is used to receive SIGINT (Ctrl+C) signals.
interruptChannel chan os.Signal
// shutdownChannel is closed once the main interrupt handler exits.
shutdownChannel chan struct{}
// shutdownRequestChannel is used to request the daemon to shutdown
// gracefully, similar to when receiving SIGINT.
shutdownRequestChannel chan struct{}
// quit is closed when instructing the main interrupt handler to exit.
quit chan struct{}
}
// Intercept starts the interception of interrupt signals and returns an `Interceptor` instance.
// Note that any previous active interceptor must be stopped before a new one can be created
func Intercept() (Interceptor, error) {
if !atomic.CompareAndSwapInt32(&started, 0, 1) { if !atomic.CompareAndSwapInt32(&started, 0, 1) {
return errors.New("intercept already started") return Interceptor{}, errors.New("intercept already started")
}
channels := Interceptor{
interruptChannel: make(chan os.Signal, 1),
shutdownChannel: make(chan struct{}),
shutdownRequestChannel: make(chan struct{}),
quit: make(chan struct{}),
} }
signalsToCatch := []os.Signal{ signalsToCatch := []os.Signal{
@ -45,10 +56,10 @@ func Intercept() error {
syscall.SIGTERM, syscall.SIGTERM,
syscall.SIGQUIT, syscall.SIGQUIT,
} }
signal.Notify(interruptChannel, signalsToCatch...) signal.Notify(channels.interruptChannel, signalsToCatch...)
go mainInterruptHandler() go channels.mainInterruptHandler()
return nil return channels, nil
} }
// mainInterruptHandler listens for SIGINT (Ctrl+C) signals on the // mainInterruptHandler listens for SIGINT (Ctrl+C) signals on the
@ -56,7 +67,8 @@ func Intercept() error {
// invokes the registered interruptCallbacks accordingly. It also listens for // invokes the registered interruptCallbacks accordingly. It also listens for
// callback registration. // callback registration.
// It must be run as a goroutine. // It must be run as a goroutine.
func mainInterruptHandler() { func (c *Interceptor) mainInterruptHandler() {
defer atomic.StoreInt32(&started, 0)
// isShutdown is a flag which is used to indicate whether or not // isShutdown is a flag which is used to indicate whether or not
// the shutdown signal has already been received and hence any future // the shutdown signal has already been received and hence any future
// attempts to add a new interrupt handler should invoke them // attempts to add a new interrupt handler should invoke them
@ -76,22 +88,23 @@ func mainInterruptHandler() {
// Signal the main interrupt handler to exit, and stop accept // Signal the main interrupt handler to exit, and stop accept
// post-facto requests. // post-facto requests.
close(quit) close(c.quit)
} }
for { for {
select { select {
case signal := <-interruptChannel: case signal := <-c.interruptChannel:
log.Infof("Received %v", signal) log.Infof("Received %v", signal)
shutdown() shutdown()
case <-shutdownRequestChannel: case <-c.shutdownRequestChannel:
log.Infof("Received shutdown request.") log.Infof("Received shutdown request.")
shutdown() shutdown()
case <-quit: case <-c.quit:
log.Infof("Gracefully shutting down.") log.Infof("Gracefully shutting down.")
close(shutdownChannel) close(c.shutdownChannel)
signal.Stop(c.interruptChannel)
return return
} }
} }
@ -99,7 +112,7 @@ func mainInterruptHandler() {
// Listening returns true if the main interrupt handler has been started, and // Listening returns true if the main interrupt handler has been started, and
// has not been killed. // has not been killed.
func Listening() bool { func (c *Interceptor) Listening() bool {
// If our started field is not set, we are not yet listening for // If our started field is not set, we are not yet listening for
// interrupts. // interrupts.
if atomic.LoadInt32(&started) != 1 { if atomic.LoadInt32(&started) != 1 {
@ -108,13 +121,13 @@ func Listening() bool {
// If we have started our main goroutine, we check whether we have // If we have started our main goroutine, we check whether we have
// stopped it yet. // stopped it yet.
return Alive() return c.Alive()
} }
// Alive returns true if the main interrupt handler has not been killed. // Alive returns true if the main interrupt handler has not been killed.
func Alive() bool { func (c *Interceptor) Alive() bool {
select { select {
case <-quit: case <-c.quit:
return false return false
default: default:
return true return true
@ -122,15 +135,15 @@ func Alive() bool {
} }
// RequestShutdown initiates a graceful shutdown from the application. // RequestShutdown initiates a graceful shutdown from the application.
func RequestShutdown() { func (c *Interceptor) RequestShutdown() {
select { select {
case shutdownRequestChannel <- struct{}{}: case c.shutdownRequestChannel <- struct{}{}:
case <-quit: case <-c.quit:
} }
} }
// ShutdownChannel returns the channel that will be closed once the main // ShutdownChannel returns the channel that will be closed once the main
// interrupt handler has exited. // interrupt handler has exited.
func ShutdownChannel() <-chan struct{} { func (c *Interceptor) ShutdownChannel() <-chan struct{} {
return shutdownChannel return c.shutdownChannel
} }