diff --git a/lnd.go b/lnd.go index ad8cf0a9..99a85bb8 100644 --- a/lnd.go +++ b/lnd.go @@ -386,6 +386,11 @@ func Main(cfg *Config, lisCfg ListenerCfg, shutdownChan <-chan struct{}) error { walletInitParams = *params privateWalletPw = walletInitParams.Password publicWalletPw = walletInitParams.Password + defer func() { + if err := walletInitParams.UnloadWallet(); err != nil { + ltndLog.Errorf("Could not unload wallet: %v", err) + } + }() if walletInitParams.RecoveryWindow > 0 { ltndLog.Infof("Wallet recovery mode enabled with "+ @@ -1002,6 +1007,10 @@ type WalletUnlockParams struct { // ChansToRestore a set of static channel backups that should be // restored before the main server instance starts up. ChansToRestore walletunlocker.ChannelsToRecover + + // UnloadWallet is a function for unloading the wallet, which should + // be called on shutdown. + UnloadWallet func() error } // waitForWalletPassword will spin up gRPC and REST endpoints for the @@ -1163,6 +1172,7 @@ func waitForWalletPassword(cfg *Config, restEndpoints []net.Addr, RecoveryWindow: recoveryWindow, Wallet: newWallet, ChansToRestore: initMsg.ChanBackups, + UnloadWallet: loader.UnloadWallet, }, nil // The wallet has already been created in the past, and is simply being @@ -1173,6 +1183,7 @@ func waitForWalletPassword(cfg *Config, restEndpoints []net.Addr, RecoveryWindow: unlockMsg.RecoveryWindow, Wallet: unlockMsg.Wallet, ChansToRestore: unlockMsg.ChanBackups, + UnloadWallet: unlockMsg.UnloadWallet, }, nil case <-signal.ShutdownChannel(): diff --git a/walletunlocker/service.go b/walletunlocker/service.go index 90e84bae..39a66701 100644 --- a/walletunlocker/service.go +++ b/walletunlocker/service.go @@ -81,6 +81,10 @@ type WalletUnlockMsg struct { // ChanBackups a set of static channel backups that should be received // after the wallet has been unlocked. ChanBackups ChannelsToRecover + + // UnloadWallet is a function for unloading the wallet, which should + // be called on shutdown. + UnloadWallet func() error } // UnlockerService implements the WalletUnlocker service used to provide lnd @@ -346,6 +350,7 @@ func (u *UnlockerService) UnlockWallet(ctx context.Context, Passphrase: password, RecoveryWindow: recoveryWindow, Wallet: unlockedWallet, + UnloadWallet: loader.UnloadWallet, } // Before we return the unlock payload, we'll check if we can extract