lnd: switch to nested a main function, fix defer execution

This commit adds a new nested main function “lndMain”, within the
packages’s normal main function. This nesting is required in order to
properly execute all queued defer statements in the case of a forced
exit.
This commit is contained in:
Olaoluwa Osuntokun 2016-07-12 17:03:29 -07:00
parent 48491a7fee
commit 850feed877
No known key found for this signature in database
GPG Key ID: 9CC5B105D03521A2

42
lnd.go

@ -12,7 +12,6 @@ import (
"strconv" "strconv"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/grpclog"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc"
@ -24,16 +23,15 @@ var (
shutdownChannel = make(chan struct{}) shutdownChannel = make(chan struct{})
) )
func main() { // lndMain is the true entry point for lnd. This function is required since
// Use all processor cores. // defers created in the top-level scope of a main method aren't executed if
// TODO(roasbeef): remove this if required version # is > 1.6? // os.Exit() is called.
runtime.GOMAXPROCS(runtime.NumCPU()) func lndMain() error {
// Load the configuration, and parse any command line options. This // Load the configuration, and parse any command line options. This
// function will also set up logging properly. // function will also set up logging properly.
loadedConfig, err := loadConfig() loadedConfig, err := loadConfig()
if err != nil { if err != nil {
os.Exit(1) return err
} }
cfg = loadedConfig cfg = loadedConfig
defer backendLog.Flush() defer backendLog.Flush()
@ -43,7 +41,7 @@ func main() {
if loadedConfig.SPVMode == true { if loadedConfig.SPVMode == true {
shell(loadedConfig.SPVHostAdr, activeNetParams) shell(loadedConfig.SPVHostAdr, activeNetParams)
return return err
} }
// Enable http profiling server if requested. // Enable http profiling server if requested.
@ -62,7 +60,7 @@ func main() {
chanDB, err := channeldb.Open(loadedConfig.DataDir, activeNetParams) chanDB, err := channeldb.Open(loadedConfig.DataDir, activeNetParams)
if err != nil { if err != nil {
fmt.Println("unable to open channeldb: ", err) fmt.Println("unable to open channeldb: ", err)
os.Exit(1) return err
} }
defer chanDB.Close() defer chanDB.Close()
@ -70,12 +68,12 @@ func main() {
f, err := os.Open(loadedConfig.RPCCert) f, err := os.Open(loadedConfig.RPCCert)
if err != nil { if err != nil {
fmt.Fprintln(os.Stderr, err) fmt.Fprintln(os.Stderr, err)
os.Exit(1) return err
} }
cert, err := ioutil.ReadAll(f) cert, err := ioutil.ReadAll(f)
if err != nil { if err != nil {
fmt.Fprintln(os.Stderr, err) fmt.Fprintln(os.Stderr, err)
os.Exit(1) return err
} }
defer f.Close() defer f.Close()
@ -93,11 +91,11 @@ func main() {
wallet, err := lnwallet.NewLightningWallet(config, chanDB) wallet, err := lnwallet.NewLightningWallet(config, chanDB)
if err != nil { if err != nil {
fmt.Printf("unable to create wallet: %v\n", err) fmt.Printf("unable to create wallet: %v\n", err)
os.Exit(1) return err
} }
if err := wallet.Startup(); err != nil { if err := wallet.Startup(); err != nil {
fmt.Printf("unable to start wallet: %v\n", err) fmt.Printf("unable to start wallet: %v\n", err)
os.Exit(1) return err
} }
ltndLog.Info("LightningWallet opened") ltndLog.Info("LightningWallet opened")
@ -112,7 +110,7 @@ func main() {
server, err := newServer(defaultListenAddrs, wallet, chanDB) server, err := newServer(defaultListenAddrs, wallet, chanDB)
if err != nil { if err != nil {
srvrLog.Errorf("unable to create server: %v\n", err) srvrLog.Errorf("unable to create server: %v\n", err)
os.Exit(1) return err
} }
server.Start() server.Start()
@ -130,9 +128,8 @@ func main() {
// Finally, start the grpc server listening for HTTP/2 connections. // Finally, start the grpc server listening for HTTP/2 connections.
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", loadedConfig.RPCPort)) lis, err := net.Listen("tcp", fmt.Sprintf(":%d", loadedConfig.RPCPort))
if err != nil { if err != nil {
grpclog.Fatalf("failed to listen: %v", err)
fmt.Printf("failed to listen: %v", err) fmt.Printf("failed to listen: %v", err)
os.Exit(1) return err
} }
go func() { go func() {
rpcsLog.Infof("RPC server listening on %s", lis.Addr()) rpcsLog.Infof("RPC server listening on %s", lis.Addr())
@ -143,4 +140,17 @@ func main() {
// the interrupt handler. // the interrupt handler.
<-shutdownChannel <-shutdownChannel
ltndLog.Info("Shutdown complete") ltndLog.Info("Shutdown complete")
return nil
}
func main() {
// Use all processor cores.
// TODO(roasbeef): remove this if required version # is > 1.6?
runtime.GOMAXPROCS(runtime.NumCPU())
// Call the "real" main in a nested manner so the defers will properly
// be executed in the case of a graceful shutdown.
if err := lndMain(); err != nil {
os.Exit(1)
}
} }