diff --git a/channeldb/payment_control_test.go b/channeldb/payment_control_test.go index 030c1325..c470a8f5 100644 --- a/channeldb/payment_control_test.go +++ b/channeldb/payment_control_test.go @@ -5,7 +5,6 @@ import ( "crypto/sha256" "fmt" "io" - "io/ioutil" "reflect" "testing" "time" @@ -15,20 +14,6 @@ import ( "github.com/lightningnetwork/lnd/record" ) -func initDB() (*DB, error) { - tempPath, err := ioutil.TempDir("", "switchdb") - if err != nil { - return nil, err - } - - db, err := Open(tempPath) - if err != nil { - return nil, err - } - - return db, err -} - func genPreimage() ([32]byte, error) { var preimage [32]byte if _, err := io.ReadFull(rand.Reader, preimage[:]); err != nil { @@ -66,7 +51,8 @@ func genInfo() (*PaymentCreationInfo, *HTLCAttemptInfo, func TestPaymentControlSwitchFail(t *testing.T) { t.Parallel() - db, err := initDB() + db, cleanup, err := makeTestDB() + defer cleanup() if err != nil { t.Fatalf("unable to init db: %v", err) } @@ -202,7 +188,9 @@ func TestPaymentControlSwitchFail(t *testing.T) { func TestPaymentControlSwitchDoubleSend(t *testing.T) { t.Parallel() - db, err := initDB() + db, cleanup, err := makeTestDB() + defer cleanup() + if err != nil { t.Fatalf("unable to init db: %v", err) } @@ -282,7 +270,9 @@ func TestPaymentControlSwitchDoubleSend(t *testing.T) { func TestPaymentControlSuccessesWithoutInFlight(t *testing.T) { t.Parallel() - db, err := initDB() + db, cleanup, err := makeTestDB() + defer cleanup() + if err != nil { t.Fatalf("unable to init db: %v", err) } @@ -313,7 +303,9 @@ func TestPaymentControlSuccessesWithoutInFlight(t *testing.T) { func TestPaymentControlFailsWithoutInFlight(t *testing.T) { t.Parallel() - db, err := initDB() + db, cleanup, err := makeTestDB() + defer cleanup() + if err != nil { t.Fatalf("unable to init db: %v", err) } @@ -339,7 +331,9 @@ func TestPaymentControlFailsWithoutInFlight(t *testing.T) { func TestPaymentControlDeleteNonInFligt(t *testing.T) { t.Parallel() - db, err := initDB() + db, cleanup, err := makeTestDB() + defer cleanup() + if err != nil { t.Fatalf("unable to init db: %v", err) } @@ -481,7 +475,9 @@ func TestPaymentControlMultiShard(t *testing.T) { } runSubTest := func(t *testing.T, test testCase) { - db, err := initDB() + db, cleanup, err := makeTestDB() + defer cleanup() + if err != nil { t.Fatalf("unable to init db: %v", err) } @@ -728,7 +724,9 @@ func TestPaymentControlMultiShard(t *testing.T) { func TestPaymentControlMPPRecordValidation(t *testing.T) { t.Parallel() - db, err := initDB() + db, cleanup, err := makeTestDB() + defer cleanup() + if err != nil { t.Fatalf("unable to init db: %v", err) } diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go index b5228722..2f0d88bc 100644 --- a/channeldb/payments_test.go +++ b/channeldb/payments_test.go @@ -351,7 +351,9 @@ func TestQueryPayments(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - db, err := initDB() + db, cleanup, err := makeTestDB() + defer cleanup() + if err != nil { t.Fatalf("unable to init db: %v", err) }