diff --git a/lntest/harness.go b/lntest/harness.go index 8ba49559..2a8f9668 100644 --- a/lntest/harness.go +++ b/lntest/harness.go @@ -1548,9 +1548,9 @@ func FileExists(path string) bool { return true } -// CopyAll copies all files and directories from srcDir to dstDir recursively. +// copyAll copies all files and directories from srcDir to dstDir recursively. // Note that this function does not support links. -func CopyAll(dstDir, srcDir string) error { +func copyAll(dstDir, srcDir string) error { entries, err := ioutil.ReadDir(srcDir) if err != nil { return err @@ -1571,7 +1571,7 @@ func CopyAll(dstDir, srcDir string) error { return err } - err = CopyAll(dstPath, srcPath) + err = copyAll(dstPath, srcPath) if err != nil { return err } @@ -1582,3 +1582,43 @@ func CopyAll(dstDir, srcDir string) error { return nil } + +// BackupDb creates a backup of the current database. +func (n *NetworkHarness) BackupDb(hn *HarnessNode) error { + if hn.backupDbDir != "" { + return errors.New("backup already created") + } + + // Backup files. + tempDir, err := ioutil.TempDir("", "past-state") + if err != nil { + return fmt.Errorf("unable to create temp db folder: %v", err) + } + + if err := copyAll(tempDir, hn.DBDir()); err != nil { + return fmt.Errorf("unable to copy database files: %v", err) + } + + hn.backupDbDir = tempDir + + return nil +} + +// RestoreDb restores a database backup. +func (n *NetworkHarness) RestoreDb(hn *HarnessNode) error { + if hn.backupDbDir == "" { + return errors.New("no database backup created") + } + + // Restore files. + if err := copyAll(hn.DBDir(), hn.backupDbDir); err != nil { + return fmt.Errorf("unable to copy database files: %v", err) + } + + if err := os.RemoveAll(hn.backupDbDir); err != nil { + return fmt.Errorf("unable to remove backup dir: %v", err) + } + hn.backupDbDir = "" + + return nil +} diff --git a/lntest/itest/lnd_test.go b/lntest/itest/lnd_test.go index 5abf8230..9244f21e 100644 --- a/lntest/itest/lnd_test.go +++ b/lntest/itest/lnd_test.go @@ -6629,18 +6629,10 @@ func testRevokedCloseRetribution(net *lntest.NetworkHarness, t *harnessTest) { // broadcast this soon to be revoked state. bobStateNumPreCopy := bobChan.NumUpdates - // Create a temporary file to house Bob's database state at this - // particular point in history. - bobTempDbPath, err := ioutil.TempDir("", "bob-past-state") - if err != nil { - t.Fatalf("unable to create temp db folder: %v", err) - } - defer os.Remove(bobTempDbPath) - // With the temporary file created, copy Bob's current state into the // temporary file we created above. Later after more updates, we'll // restore this state. - if err := lntest.CopyAll(bobTempDbPath, net.Bob.DBDir()); err != nil { + if err := net.BackupDb(net.Bob); err != nil { t.Fatalf("unable to copy database files: %v", err) } @@ -6666,7 +6658,7 @@ func testRevokedCloseRetribution(net *lntest.NetworkHarness, t *harnessTest) { // state. With this, we essentially force Bob to travel back in time // within the channel's history. if err = net.RestartNode(net.Bob, func() error { - return lntest.CopyAll(net.Bob.DBDir(), bobTempDbPath) + return net.RestoreDb(net.Bob) }); err != nil { t.Fatalf("unable to restart node: %v", err) } @@ -6872,18 +6864,10 @@ func testRevokedCloseRetributionZeroValueRemoteOutput(net *lntest.NetworkHarness // broadcast this soon to be revoked state. carolStateNumPreCopy := carolChan.NumUpdates - // Create a temporary file to house Carol's database state at this - // particular point in history. - carolTempDbPath, err := ioutil.TempDir("", "carol-past-state") - if err != nil { - t.Fatalf("unable to create temp db folder: %v", err) - } - defer os.Remove(carolTempDbPath) - // With the temporary file created, copy Carol's current state into the // temporary file we created above. Later after more updates, we'll // restore this state. - if err := lntest.CopyAll(carolTempDbPath, carol.DBDir()); err != nil { + if err := net.BackupDb(carol); err != nil { t.Fatalf("unable to copy database files: %v", err) } @@ -6908,7 +6892,7 @@ func testRevokedCloseRetributionZeroValueRemoteOutput(net *lntest.NetworkHarness // state. With this, we essentially force Carol to travel back in time // within the channel's history. if err = net.RestartNode(carol, func() error { - return lntest.CopyAll(carol.DBDir(), carolTempDbPath) + return net.RestoreDb(carol) }); err != nil { t.Fatalf("unable to restart node: %v", err) } @@ -7185,18 +7169,10 @@ func testRevokedCloseRetributionRemoteHodl(net *lntest.NetworkHarness, // to her channel. checkCarolNumUpdatesAtLeast(1) - // Create a temporary file to house Carol's database state at this - // particular point in history. - carolTempDbPath, err := ioutil.TempDir("", "carol-past-state") - if err != nil { - t.Fatalf("unable to create temp db folder: %v", err) - } - defer os.Remove(carolTempDbPath) - // With the temporary file created, copy Carol's current state into the // temporary file we created above. Later after more updates, we'll // restore this state. - if err := lntest.CopyAll(carolTempDbPath, carol.DBDir()); err != nil { + if err := net.BackupDb(carol); err != nil { t.Fatalf("unable to copy database files: %v", err) } @@ -7230,7 +7206,7 @@ func testRevokedCloseRetributionRemoteHodl(net *lntest.NetworkHarness, // state. With this, we essentially force Carol to travel back in time // within the channel's history. if err = net.RestartNode(carol, func() error { - return lntest.CopyAll(carol.DBDir(), carolTempDbPath) + return net.RestoreDb(carol) }); err != nil { t.Fatalf("unable to restart node: %v", err) } @@ -7618,18 +7594,10 @@ func testRevokedCloseRetributionAltruistWatchtowerCase( // broadcast this soon to be revoked state. carolStateNumPreCopy := carolChan.NumUpdates - // Create a temporary file to house Carol's database state at this - // particular point in history. - carolTempDbPath, err := ioutil.TempDir("", "carol-past-state") - if err != nil { - t.Fatalf("unable to create temp db folder: %v", err) - } - defer os.Remove(carolTempDbPath) - // With the temporary file created, copy Carol's current state into the // temporary file we created above. Later after more updates, we'll // restore this state. - if err := lntest.CopyAll(carolTempDbPath, carol.DBDir()); err != nil { + if err := net.BackupDb(carol); err != nil { t.Fatalf("unable to copy database files: %v", err) } @@ -7688,7 +7656,7 @@ func testRevokedCloseRetributionAltruistWatchtowerCase( // state. With this, we essentially force Carol to travel back in time // within the channel's history. if err = net.RestartNode(carol, func() error { - return lntest.CopyAll(carol.DBDir(), carolTempDbPath) + return net.RestoreDb(carol) }); err != nil { t.Fatalf("unable to restart node: %v", err) } @@ -8158,18 +8126,10 @@ func testDataLossProtection(net *lntest.NetworkHarness, t *harnessTest) { // revoke this state. stateNumPreCopy := nodeChan.NumUpdates - // Create a temporary file to house the database state at this - // particular point in history. - tempDbPath, err := ioutil.TempDir("", node.Name()+"-past-state") - if err != nil { - t.Fatalf("unable to create temp db folder: %v", err) - } - defer os.Remove(tempDbPath) - // With the temporary file created, copy the current state into // the temporary file we created above. Later after more // updates, we'll restore this state. - if err := lntest.CopyAll(tempDbPath, node.DBDir()); err != nil { + if err := net.BackupDb(node); err != nil { t.Fatalf("unable to copy database files: %v", err) } @@ -8196,7 +8156,7 @@ func testDataLossProtection(net *lntest.NetworkHarness, t *harnessTest) { // force the node to travel back in time within the channel's // history. if err = net.RestartNode(node, func() error { - return lntest.CopyAll(node.DBDir(), tempDbPath) + return net.RestoreDb(node) }); err != nil { t.Fatalf("unable to restart node: %v", err) } diff --git a/lntest/node.go b/lntest/node.go index 3af26bae..d9902326 100644 --- a/lntest/node.go +++ b/lntest/node.go @@ -389,6 +389,9 @@ type HarnessNode struct { WalletKitClient walletrpc.WalletKitClient Watchtower watchtowerrpc.WatchtowerClient WatchtowerClient wtclientrpc.WatchtowerClientClient + + // backupDbDir is the path where a database backup is stored, if any. + backupDbDir string } // Assert *HarnessNode implements the lnrpc.LightningClient interface. @@ -1098,6 +1101,13 @@ func (hn *HarnessNode) SetExtraArgs(extraArgs []string) { // cleanup cleans up all the temporary files created by the node's process. func (hn *HarnessNode) cleanup() error { + if hn.backupDbDir != "" { + err := os.RemoveAll(hn.backupDbDir) + if err != nil { + return fmt.Errorf("unable to remove backup dir: %v", err) + } + } + return os.RemoveAll(hn.Cfg.BaseDir) }