diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index e06f8994..9d67af91 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -95,7 +95,7 @@ func TestOpenChannelEncodeDecode(t *testing.T) { // Next, create channeldb for the first time, also setting a mock // EncryptorDecryptor implementation for testing purposes. - cdb, err := Create(tempDirName) + cdb, err := Open(tempDirName) if err != nil { t.Fatalf("unable to create channeldb: %v", err) } @@ -154,7 +154,7 @@ func TestOpenChannelEncodeDecode(t *testing.T) { TotalSatoshisReceived: 2, TotalNetFees: 9, CreationTime: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC), - db: cdb, + Db: cdb, } if err := state.FullSync(); err != nil { diff --git a/channeldb/db.go b/channeldb/db.go index c947ebba..a962cecc 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -44,22 +44,15 @@ type DB struct { // sensitive data encrypted by the passed EncryptorDecryptor implementation. // TODO(roasbeef): versioning? func Open(dbPath string) (*DB, error) { - if !fileExists(dbPath) { - return nil, ErrNoExists - } - path := filepath.Join(dbPath, dbName) - bdb, err := bolt.Open(path, 0600, nil) - if err != nil { - return nil, err + + if !fileExists(path) { + if err := createChannelDB(dbPath); err != nil { + return nil, err + } } - return &DB{store: bdb}, nil -} - -// Create... -func Create(dbPath string) (*DB, error) { - bdb, err := createChannelDB(dbPath) + bdb, err := bolt.Open(path, 0600, nil) if err != nil { return nil, err } @@ -86,17 +79,17 @@ func (d *DB) Close() error { } // createChannelDB... -func createChannelDB(dbPath string) (*bolt.DB, error) { +func createChannelDB(dbPath string) error { if !fileExists(dbPath) { if err := os.MkdirAll(dbPath, 0700); err != nil { - return nil, err + return err } } path := filepath.Join(dbPath, dbName) bdb, err := bolt.Open(path, 0600, nil) if err != nil { - return nil, err + return err } err = bdb.Update(func(tx *bolt.Tx) error { @@ -115,10 +108,10 @@ func createChannelDB(dbPath string) (*bolt.DB, error) { return nil }) if err != nil { - return nil, fmt.Errorf("unable to create new channeldb") + return fmt.Errorf("unable to create new channeldb") } - return bdb, nil + return bdb.Close() } // fileExists... diff --git a/channeldb/db_test.go b/channeldb/db_test.go index f2053856..69c2d3b5 100644 --- a/channeldb/db_test.go +++ b/channeldb/db_test.go @@ -3,16 +3,11 @@ package channeldb import ( "io/ioutil" "os" + "path/filepath" "testing" ) -func TestOpenNotCreated(t *testing.T) { - if _, err := Open("path doesn't exist"); err != ErrNoExists { - t.Fatalf("channeldb Open should fail due to non-existant dir") - } -} - -func TestCreateThenOpen(t *testing.T) { +func TestOpenWithCreate(t *testing.T) { // First, create a temporary directory to be used for the duration of // this test. tempDirName, err := ioutil.TempDir("", "channeldb") @@ -21,8 +16,9 @@ func TestCreateThenOpen(t *testing.T) { } defer os.RemoveAll(tempDirName) - // Next, create channeldb for the first time. - cdb, err := Create(tempDirName) + // Next, open thereby creating channeldb for the first time. + dbPath := filepath.Join(tempDirName, "cdb") + cdb, err := Open(dbPath) if err != nil { t.Fatalf("unable to create channeldb: %v", err) } @@ -30,9 +26,8 @@ func TestCreateThenOpen(t *testing.T) { t.Fatalf("unable to close channeldb: %v", err) } - // Open should now succeed as the cdb was created above. - cdb, err = Open(tempDirName) - if err != nil { - t.Fatalf("unable to open channeldb: %v", err) + // The path should have been succesfully created. + if !fileExists(dbPath) { + t.Fatalf("channeldb failed to create data directory") } }