From 6a24a03cec97bd0e32c8da40dad1e7bf22f54f3a Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Tue, 18 Feb 2020 19:35:53 +0100 Subject: [PATCH] channeldb+kvdb: walletdb/kvdb interface etcd implementation This commit adds a full interface implementation of the walletdb/kvdb interface with detailed tests. --- channeldb/kvdb/etcd.go | 10 + channeldb/kvdb/etcd/bucket.go | 83 ++++ channeldb/kvdb/etcd/bucket_test.go | 38 ++ channeldb/kvdb/etcd/db.go | 79 +++- channeldb/kvdb/etcd/db_test.go | 42 ++ channeldb/kvdb/etcd/driver.go | 66 +++ channeldb/kvdb/etcd/driver_test.go | 28 ++ channeldb/kvdb/etcd/readwrite_bucket.go | 355 ++++++++++++++++ channeldb/kvdb/etcd/readwrite_bucket_test.go | 402 +++++++++++++++++++ channeldb/kvdb/etcd/readwrite_cursor.go | 143 +++++++ channeldb/kvdb/etcd/readwrite_cursor_test.go | 291 ++++++++++++++ channeldb/kvdb/etcd/readwrite_tx.go | 93 +++++ channeldb/kvdb/etcd/readwrite_tx_test.go | 154 +++++++ 13 files changed, 1783 insertions(+), 1 deletion(-) create mode 100644 channeldb/kvdb/etcd.go create mode 100644 channeldb/kvdb/etcd/bucket.go create mode 100644 channeldb/kvdb/etcd/bucket_test.go create mode 100644 channeldb/kvdb/etcd/db_test.go create mode 100644 channeldb/kvdb/etcd/driver.go create mode 100644 channeldb/kvdb/etcd/driver_test.go create mode 100644 channeldb/kvdb/etcd/readwrite_bucket.go create mode 100644 channeldb/kvdb/etcd/readwrite_bucket_test.go create mode 100644 channeldb/kvdb/etcd/readwrite_cursor.go create mode 100644 channeldb/kvdb/etcd/readwrite_cursor_test.go create mode 100644 channeldb/kvdb/etcd/readwrite_tx.go create mode 100644 channeldb/kvdb/etcd/readwrite_tx_test.go diff --git a/channeldb/kvdb/etcd.go b/channeldb/kvdb/etcd.go new file mode 100644 index 00000000..89ec463b --- /dev/null +++ b/channeldb/kvdb/etcd.go @@ -0,0 +1,10 @@ +package kvdb + +import ( + _ "github.com/lightningnetwork/lnd/channeldb/kvdb/etcd" // Import to register backend. +) + +// EtcdBackendName is the name of the backend that should be passed into +// kvdb.Create to initialize a new instance of kvdb.Backend backed by a live +// instance of etcd. +const EtcdBackendName = "etcd" diff --git a/channeldb/kvdb/etcd/bucket.go b/channeldb/kvdb/etcd/bucket.go new file mode 100644 index 00000000..48a32b9c --- /dev/null +++ b/channeldb/kvdb/etcd/bucket.go @@ -0,0 +1,83 @@ +package etcd + +import ( + "crypto/sha256" +) + +const ( + bucketIDLength = 32 +) + +var ( + bucketPrefix = []byte("b") + valuePrefix = []byte("v") + sequencePrefix = []byte("$") +) + +// rootBucketId returns a zero filled 32 byte array +func rootBucketID() []byte { + var rootID [bucketIDLength]byte + return rootID[:] +} + +// makeBucketID returns a deterministic key for the passed byte slice. +// Currently it returns the sha256 hash of the slice. +func makeBucketID(key []byte) [bucketIDLength]byte { + return sha256.Sum256(key) +} + +// isValidBucketID checks if the passed slice is the required length to be a +// valid bucket id. +func isValidBucketID(s []byte) bool { + return len(s) == bucketIDLength +} + +// makeKey concatenates prefix, parent and key into one byte slice. +// The prefix indicates the use of this key (whether bucket, value or sequence), +// while parentID refers to the parent bucket. +func makeKey(prefix, parent, key []byte) []byte { + keyBuf := make([]byte, len(prefix)+len(parent)+len(key)) + copy(keyBuf, prefix) + copy(keyBuf[len(prefix):], parent) + copy(keyBuf[len(prefix)+len(parent):], key) + + return keyBuf +} + +// makePrefix concatenates prefix with parent into one byte slice. +func makePrefix(prefix []byte, parent []byte) []byte { + prefixBuf := make([]byte, len(prefix)+len(parent)) + copy(prefixBuf, prefix) + copy(prefixBuf[len(prefix):], parent) + + return prefixBuf +} + +// makeBucketKey returns a bucket key from the passed parent bucket id and +// the key. +func makeBucketKey(parent []byte, key []byte) []byte { + return makeKey(bucketPrefix, parent, key) +} + +// makeValueKey returns a value key from the passed parent bucket id and +// the key. +func makeValueKey(parent []byte, key []byte) []byte { + return makeKey(valuePrefix, parent, key) +} + +// makeSequenceKey returns a sequence key of the passed parent bucket id. +func makeSequenceKey(parent []byte) []byte { + return makeKey(sequencePrefix, parent, nil) +} + +// makeBucketPrefix returns the bucket prefix of the passed parent bucket id. +// This prefix is used for all sub buckets. +func makeBucketPrefix(parent []byte) []byte { + return makePrefix(bucketPrefix, parent) +} + +// makeValuePrefix returns the value prefix of the passed parent bucket id. +// This prefix is used for all key/values in the bucket. +func makeValuePrefix(parent []byte) []byte { + return makePrefix(valuePrefix, parent) +} diff --git a/channeldb/kvdb/etcd/bucket_test.go b/channeldb/kvdb/etcd/bucket_test.go new file mode 100644 index 00000000..45d6155d --- /dev/null +++ b/channeldb/kvdb/etcd/bucket_test.go @@ -0,0 +1,38 @@ +package etcd + +// bkey is a helper functon used in tests to create a bucket key from passed +// bucket list. +func bkey(buckets ...string) string { + var bucketKey []byte + + parent := rootBucketID() + + for _, bucketName := range buckets { + bucketKey = makeBucketKey(parent, []byte(bucketName)) + id := makeBucketID(bucketKey) + parent = id[:] + } + + return string(bucketKey) +} + +// bval is a helper function used in tests to create a bucket value (the value +// for a bucket key) from the passed bucket list. +func bval(buckets ...string) string { + id := makeBucketID([]byte(bkey(buckets...))) + return string(id[:]) +} + +// vkey is a helper function used in tests to create a value key from the +// passed key and bucket list. +func vkey(key string, buckets ...string) string { + bucket := rootBucketID() + + for _, bucketName := range buckets { + bucketKey := makeBucketKey(bucket, []byte(bucketName)) + id := makeBucketID(bucketKey) + bucket = id[:] + } + + return string(makeValueKey(bucket, []byte(key))) +} diff --git a/channeldb/kvdb/etcd/db.go b/channeldb/kvdb/etcd/db.go index 4d7c9d0d..a5d844b0 100644 --- a/channeldb/kvdb/etcd/db.go +++ b/channeldb/kvdb/etcd/db.go @@ -1,8 +1,11 @@ package etcd import ( + "context" + "io" "time" + "github.com/btcsuite/btcwallet/walletdb" "github.com/coreos/etcd/clientv3" ) @@ -10,6 +13,9 @@ const ( // etcdConnectionTimeout is the timeout until successful connection to the // etcd instance. etcdConnectionTimeout = 10 * time.Second + + // etcdLongTimeout is a timeout for longer taking etcd operatons. + etcdLongTimeout = 30 * time.Second ) // db holds a reference to the etcd client connection. @@ -17,6 +23,9 @@ type db struct { cli *clientv3.Client } +// Enforce db implements the walletdb.DB interface. +var _ walletdb.DB = (*db)(nil) + // BackendConfig holds and etcd backend config and connection parameters. type BackendConfig struct { // Host holds the peer url of the etcd instance. @@ -49,7 +58,75 @@ func newEtcdBackend(config BackendConfig) (*db, error) { return backend, nil } -// Close closes the db, but closing the underlying etcd client connection. +// View opens a database read transaction and executes the function f with the +// transaction passed as a parameter. After f exits, the transaction is rolled +// back. If f errors, its error is returned, not a rollback error (if any +// occur). +func (db *db) View(f func(tx walletdb.ReadTx) error) error { + apply := func(stm STM) error { + return f(newReadWriteTx(stm)) + } + + return RunSTM(db.cli, apply) +} + +// Update opens a database read/write transaction and executes the function f +// with the transaction passed as a parameter. After f exits, if f did not +// error, the transaction is committed. Otherwise, if f did error, the +// transaction is rolled back. If the rollback fails, the original error +// returned by f is still returned. If the commit fails, the commit error is +// returned. +func (db *db) Update(f func(tx walletdb.ReadWriteTx) error) error { + apply := func(stm STM) error { + return f(newReadWriteTx(stm)) + } + + return RunSTM(db.cli, apply) +} + +// BeginReadTx opens a database read transaction. +func (db *db) BeginReadWriteTx() (walletdb.ReadWriteTx, error) { + return newReadWriteTx(NewSTM(db.cli)), nil +} + +// BeginReadWriteTx opens a database read+write transaction. +func (db *db) BeginReadTx() (walletdb.ReadTx, error) { + return newReadWriteTx(NewSTM(db.cli)), nil +} + +// Copy writes a copy of the database to the provided writer. This call will +// start a read-only transaction to perform all operations. +// This function is part of the walletdb.Db interface implementation. +func (db *db) Copy(w io.Writer) error { + ctx := context.Background() + + ctx, cancel := context.WithTimeout(ctx, etcdLongTimeout) + defer cancel() + + readCloser, err := db.cli.Snapshot(ctx) + if err != nil { + return err + } + + _, err = io.Copy(w, readCloser) + + return err +} + +// Close cleanly shuts down the database and syncs all data. +// This function is part of the walletdb.Db interface implementation. func (db *db) Close() error { return db.cli.Close() } + +// Batch opens a database read/write transaction and executes the function f +// with the transaction passed as a parameter. After f exits, if f did not +// error, the transaction is committed. Otherwise, if f did error, the +// transaction is rolled back. If the rollback fails, the original error +// returned by f is still returned. If the commit fails, the commit error is +// returned. +// +// Batch is only useful when there are multiple goroutines calling it. +func (db *db) Batch(apply func(tx walletdb.ReadWriteTx) error) error { + return db.Update(apply) +} diff --git a/channeldb/kvdb/etcd/db_test.go b/channeldb/kvdb/etcd/db_test.go new file mode 100644 index 00000000..ecf9b06c --- /dev/null +++ b/channeldb/kvdb/etcd/db_test.go @@ -0,0 +1,42 @@ +package etcd + +import ( + "bytes" + "testing" + + "github.com/btcsuite/btcwallet/walletdb" + "github.com/stretchr/testify/assert" +) + +func TestCopy(t *testing.T) { + t.Parallel() + + f := NewEtcdTestFixture(t) + defer f.Cleanup() + + db, err := newEtcdBackend(f.BackendConfig()) + assert.NoError(t, err) + + err = db.Update(func(tx walletdb.ReadWriteTx) error { + // "apple" + apple, err := tx.CreateTopLevelBucket([]byte("apple")) + assert.NoError(t, err) + assert.NotNil(t, apple) + + assert.NoError(t, apple.Put([]byte("key"), []byte("val"))) + return nil + }) + + // Expect non-zero copy. + var buf bytes.Buffer + + assert.NoError(t, db.Copy(&buf)) + assert.Greater(t, buf.Len(), 0) + assert.Nil(t, err) + + expected := map[string]string{ + bkey("apple"): bval("apple"), + vkey("key", "apple"): "val", + } + assert.Equal(t, expected, f.Dump()) +} diff --git a/channeldb/kvdb/etcd/driver.go b/channeldb/kvdb/etcd/driver.go new file mode 100644 index 00000000..ccdfbf7d --- /dev/null +++ b/channeldb/kvdb/etcd/driver.go @@ -0,0 +1,66 @@ +package etcd + +import ( + "fmt" + + "github.com/btcsuite/btcwallet/walletdb" +) + +const ( + dbType = "etcd" +) + +// parseArgs parses the arguments from the walletdb Open/Create methods. +func parseArgs(funcName string, args ...interface{}) (*BackendConfig, error) { + if len(args) != 1 { + return nil, fmt.Errorf("invalid number of arguments to %s.%s -- "+ + "expected: etcd.BackendConfig", + dbType, funcName, + ) + } + + config, ok := args[0].(BackendConfig) + if !ok { + return nil, fmt.Errorf("argument to %s.%s is invalid -- "+ + "expected: etcd.BackendConfig", + dbType, funcName, + ) + } + + return &config, nil +} + +// createDBDriver is the callback provided during driver registration that +// creates, initializes, and opens a database for use. +func createDBDriver(args ...interface{}) (walletdb.DB, error) { + config, err := parseArgs("Create", args...) + if err != nil { + return nil, err + } + + return newEtcdBackend(*config) +} + +// openDBDriver is the callback provided during driver registration that opens +// an existing database for use. +func openDBDriver(args ...interface{}) (walletdb.DB, error) { + config, err := parseArgs("Open", args...) + if err != nil { + return nil, err + } + + return newEtcdBackend(*config) +} + +func init() { + // Register the driver. + driver := walletdb.Driver{ + DbType: dbType, + Create: createDBDriver, + Open: openDBDriver, + } + if err := walletdb.RegisterDriver(driver); err != nil { + panic(fmt.Sprintf("Failed to regiser database driver '%s': %v", + dbType, err)) + } +} diff --git a/channeldb/kvdb/etcd/driver_test.go b/channeldb/kvdb/etcd/driver_test.go new file mode 100644 index 00000000..ad8578c6 --- /dev/null +++ b/channeldb/kvdb/etcd/driver_test.go @@ -0,0 +1,28 @@ +package etcd + +import ( + "testing" + + "github.com/btcsuite/btcwallet/walletdb" + "github.com/stretchr/testify/assert" +) + +func TestOpenCreateFailure(t *testing.T) { + t.Parallel() + + db, err := walletdb.Open(dbType) + assert.Error(t, err) + assert.Nil(t, db) + + db, err = walletdb.Open(dbType, "wrong") + assert.Error(t, err) + assert.Nil(t, db) + + db, err = walletdb.Create(dbType) + assert.Error(t, err) + assert.Nil(t, db) + + db, err = walletdb.Create(dbType, "wrong") + assert.Error(t, err) + assert.Nil(t, db) +} diff --git a/channeldb/kvdb/etcd/readwrite_bucket.go b/channeldb/kvdb/etcd/readwrite_bucket.go new file mode 100644 index 00000000..94fa5321 --- /dev/null +++ b/channeldb/kvdb/etcd/readwrite_bucket.go @@ -0,0 +1,355 @@ +package etcd + +import ( + "strconv" + + "github.com/btcsuite/btcwallet/walletdb" +) + +// readWriteBucket stores the bucket id and the buckets transaction. +type readWriteBucket struct { + // id is used to identify the bucket and is created by + // hashing the parent id with the bucket key. For each key/value, + // sub-bucket or the bucket sequence the bucket id is used with the + // appropriate prefix to prefix the key. + id []byte + + // tx holds the parent transaction. + tx *readWriteTx +} + +// newReadWriteBucket creates a new rw bucket with the passed transaction +// and bucket id. +func newReadWriteBucket(tx *readWriteTx, id []byte) *readWriteBucket { + return &readWriteBucket{ + id: id, + tx: tx, + } +} + +// NestedReadBucket retrieves a nested read bucket with the given key. +// Returns nil if the bucket does not exist. +func (b *readWriteBucket) NestedReadBucket(key []byte) walletdb.ReadBucket { + return b.NestedReadWriteBucket(key) +} + +// ForEach invokes the passed function with every key/value pair in +// the bucket. This includes nested buckets, in which case the value +// is nil, but it does not include the key/value pairs within those +// nested buckets. +func (b *readWriteBucket) ForEach(cb func(k, v []byte) error) error { + prefix := makeValuePrefix(b.id) + prefixLen := len(prefix) + + // Get the first matching key that is in the bucket. + kv, err := b.tx.stm.First(string(prefix)) + if err != nil { + return err + } + + for kv != nil { + if err := cb([]byte(kv.key[prefixLen:]), []byte(kv.val)); err != nil { + return err + } + + // Step to the next key. + kv, err = b.tx.stm.Next(string(prefix), kv.key) + if err != nil { + return err + } + } + + // Make a bucket prefix. This prefixes all sub buckets. + prefix = makeBucketPrefix(b.id) + prefixLen = len(prefix) + + // Get the first bucket. + kv, err = b.tx.stm.First(string(prefix)) + if err != nil { + return err + } + + for kv != nil { + if err := cb([]byte(kv.key[prefixLen:]), nil); err != nil { + return err + } + + // Step to the next bucket. + kv, err = b.tx.stm.Next(string(prefix), kv.key) + if err != nil { + return err + } + } + + return nil +} + +// Get returns the value for the given key. Returns nil if the key does +// not exist in this bucket. +func (b *readWriteBucket) Get(key []byte) []byte { + // Return nil if the key is empty. + if len(key) == 0 { + return nil + } + + // Fetch the associated value. + val, err := b.tx.stm.Get(string(makeValueKey(b.id, key))) + if err != nil { + // TODO: we should return the error once the + // kvdb inteface is extended. + return nil + } + + if val == nil { + return nil + } + + return val +} + +func (b *readWriteBucket) ReadCursor() walletdb.ReadCursor { + return newReadWriteCursor(b) +} + +// NestedReadWriteBucket retrieves a nested bucket with the given key. +// Returns nil if the bucket does not exist. +func (b *readWriteBucket) NestedReadWriteBucket(key []byte) walletdb.ReadWriteBucket { + if len(key) == 0 { + return nil + } + + // Get the bucket id (and return nil if bucket doesn't exist). + bucketVal, err := b.tx.stm.Get(string(makeBucketKey(b.id, key))) + if err != nil { + // TODO: we should return the error once the + // kvdb inteface is extended. + return nil + } + + if !isValidBucketID(bucketVal) { + return nil + } + + // Return the bucket with the fetched bucket id. + return newReadWriteBucket(b.tx, bucketVal) +} + +// CreateBucket creates and returns a new nested bucket with the given +// key. Returns ErrBucketExists if the bucket already exists, +// ErrBucketNameRequired if the key is empty, or ErrIncompatibleValue +// if the key value is otherwise invalid for the particular database +// implementation. Other errors are possible depending on the +// implementation. +func (b *readWriteBucket) CreateBucket(key []byte) ( + walletdb.ReadWriteBucket, error) { + + if len(key) == 0 { + return nil, walletdb.ErrBucketNameRequired + } + + // Check if the bucket already exists. + bucketKey := makeBucketKey(b.id, key) + + bucketVal, err := b.tx.stm.Get(string(bucketKey)) + if err != nil { + return nil, err + } + + if isValidBucketID(bucketVal) { + return nil, walletdb.ErrBucketExists + } + + // Create a deterministic bucket id from the bucket key. + newID := makeBucketID(bucketKey) + + // Create the bucket. + b.tx.stm.Put(string(bucketKey), string(newID[:])) + + return newReadWriteBucket(b.tx, newID[:]), nil +} + +// CreateBucketIfNotExists creates and returns a new nested bucket with +// the given key if it does not already exist. Returns +// ErrBucketNameRequired if the key is empty or ErrIncompatibleValue +// if the key value is otherwise invalid for the particular database +// backend. Other errors are possible depending on the implementation. +func (b *readWriteBucket) CreateBucketIfNotExists(key []byte) ( + walletdb.ReadWriteBucket, error) { + + if len(key) == 0 { + return nil, walletdb.ErrBucketNameRequired + } + + // Check for the bucket and create if it doesn't exist. + bucketKey := string(makeBucketKey(b.id, key)) + + bucketVal, err := b.tx.stm.Get(bucketKey) + if err != nil { + return nil, err + } + + if !isValidBucketID(bucketVal) { + newID := makeBucketID([]byte(bucketKey)) + b.tx.stm.Put(bucketKey, string(newID[:])) + + return newReadWriteBucket(b.tx, newID[:]), nil + } + + // Otherwise return the bucket with the fetched bucket id. + return newReadWriteBucket(b.tx, bucketVal), nil +} + +// DeleteNestedBucket deletes the nested bucket and its sub-buckets +// pointed to by the passed key. All values in the bucket and sub-buckets +// will be deleted as well. +func (b *readWriteBucket) DeleteNestedBucket(key []byte) error { + // TODO shouldn't empty key return ErrBucketNameRequired ? + if len(key) == 0 { + return walletdb.ErrIncompatibleValue + } + + // Get the bucket first. + bucketKey := string(makeBucketKey(b.id, key)) + + bucketVal, err := b.tx.stm.Get(bucketKey) + if err != nil { + return err + } + + if !isValidBucketID(bucketVal) { + return walletdb.ErrBucketNotFound + } + + // Enqueue the top level bucket id. + queue := [][]byte{bucketVal} + + // Traverse the buckets breadth first. + for len(queue) != 0 { + if !isValidBucketID(queue[0]) { + return walletdb.ErrBucketNotFound + } + + id := queue[0] + queue = queue[1:] + + // Delete values in the current bucket + valuePrefix := string(makeValuePrefix(id)) + + kv, err := b.tx.stm.First(valuePrefix) + if err != nil { + return err + } + + for kv != nil { + b.tx.stm.Del(kv.key) + + kv, err = b.tx.stm.Next(valuePrefix, kv.key) + if err != nil { + return err + } + } + + // Iterate sub buckets + bucketPrefix := string(makeBucketPrefix(id)) + + kv, err = b.tx.stm.First(bucketPrefix) + if err != nil { + return err + } + + for kv != nil { + // Delete sub bucket key. + b.tx.stm.Del(kv.key) + // Queue it for traversal. + queue = append(queue, []byte(kv.val)) + + kv, err = b.tx.stm.Next(bucketPrefix, kv.key) + if err != nil { + return err + } + } + } + + // Delete the top level bucket. + b.tx.stm.Del(bucketKey) + + return nil +} + +// Put updates the value for the passed key. +// Returns ErrKeyRequred if te passed key is empty. +func (b *readWriteBucket) Put(key, value []byte) error { + if len(key) == 0 { + return walletdb.ErrKeyRequired + } + + // Update the transaction with the new value. + b.tx.stm.Put(string(makeValueKey(b.id, key)), string(value)) + + return nil +} + +// Delete deletes the key/value pointed to by the passed key. +// Returns ErrKeyRequred if the passed key is empty. +func (b *readWriteBucket) Delete(key []byte) error { + if len(key) == 0 { + return walletdb.ErrKeyRequired + } + + // Update the transaction to delete the key/value. + b.tx.stm.Del(string(makeValueKey(b.id, key))) + + return nil +} + +// ReadWriteCursor returns a new read-write cursor for this bucket. +func (b *readWriteBucket) ReadWriteCursor() walletdb.ReadWriteCursor { + return newReadWriteCursor(b) +} + +// Tx returns the buckets transaction. +func (b *readWriteBucket) Tx() walletdb.ReadWriteTx { + return b.tx +} + +// NextSequence returns an autoincrementing sequence number for this bucket. +// Note that this is not a thread safe function and as such it must not be used +// for synchronization. +func (b *readWriteBucket) NextSequence() (uint64, error) { + seq := b.Sequence() + 1 + + return seq, b.SetSequence(seq) +} + +// SetSequence updates the sequence number for the bucket. +func (b *readWriteBucket) SetSequence(v uint64) error { + // Convert the number to string. + val := strconv.FormatUint(v, 10) + + // Update the transaction with the new value for the sequence key. + b.tx.stm.Put(string(makeSequenceKey(b.id)), val) + + return nil +} + +// Sequence returns the current sequence number for this bucket without +// incrementing it. +func (b *readWriteBucket) Sequence() uint64 { + val, err := b.tx.stm.Get(string(makeSequenceKey(b.id))) + if err != nil { + // TODO: This update kvdb interface such that error + // may be returned here. + return 0 + } + + if val == nil { + // If the sequence number is not yet + // stored, then take the default value. + return 0 + } + + // Otherwise try to parse a 64 bit unsigned integer from the value. + num, _ := strconv.ParseUint(string(val), 10, 64) + + return num +} diff --git a/channeldb/kvdb/etcd/readwrite_bucket_test.go b/channeldb/kvdb/etcd/readwrite_bucket_test.go new file mode 100644 index 00000000..8e919403 --- /dev/null +++ b/channeldb/kvdb/etcd/readwrite_bucket_test.go @@ -0,0 +1,402 @@ +package etcd + +import ( + "fmt" + "math" + "testing" + + "github.com/btcsuite/btcwallet/walletdb" + "github.com/stretchr/testify/assert" +) + +func TestBucketCreation(t *testing.T) { + t.Parallel() + + f := NewEtcdTestFixture(t) + defer f.Cleanup() + + db, err := newEtcdBackend(f.BackendConfig()) + assert.NoError(t, err) + + err = db.Update(func(tx walletdb.ReadWriteTx) error { + // empty bucket name + b, err := tx.CreateTopLevelBucket(nil) + assert.Error(t, walletdb.ErrBucketNameRequired, err) + assert.Nil(t, b) + + // empty bucket name + b, err = tx.CreateTopLevelBucket([]byte("")) + assert.Error(t, walletdb.ErrBucketNameRequired, err) + assert.Nil(t, b) + + // "apple" + apple, err := tx.CreateTopLevelBucket([]byte("apple")) + assert.NoError(t, err) + assert.NotNil(t, apple) + + // Check bucket tx. + assert.Equal(t, tx, apple.Tx()) + + // "apple" already created + b, err = tx.CreateTopLevelBucket([]byte("apple")) + assert.NoError(t, err) + assert.NotNil(t, b) + + // "apple/banana" + banana, err := apple.CreateBucket([]byte("banana")) + assert.NoError(t, err) + assert.NotNil(t, banana) + + banana, err = apple.CreateBucketIfNotExists([]byte("banana")) + assert.NoError(t, err) + assert.NotNil(t, banana) + + // Try creating "apple/banana" again + b, err = apple.CreateBucket([]byte("banana")) + assert.Error(t, walletdb.ErrBucketExists, err) + assert.Nil(t, b) + + // "apple/mango" + mango, err := apple.CreateBucket([]byte("mango")) + assert.Nil(t, err) + assert.NotNil(t, mango) + + // "apple/banana/pear" + pear, err := banana.CreateBucket([]byte("pear")) + assert.Nil(t, err) + assert.NotNil(t, pear) + + // empty bucket + assert.Nil(t, apple.NestedReadWriteBucket(nil)) + assert.Nil(t, apple.NestedReadWriteBucket([]byte(""))) + + // "apple/pear" doesn't exist + assert.Nil(t, apple.NestedReadWriteBucket([]byte("pear"))) + + // "apple/banana" exits + assert.NotNil(t, apple.NestedReadWriteBucket([]byte("banana"))) + assert.NotNil(t, apple.NestedReadBucket([]byte("banana"))) + return nil + }) + + assert.Nil(t, err) + + expected := map[string]string{ + bkey("apple"): bval("apple"), + bkey("apple", "banana"): bval("apple", "banana"), + bkey("apple", "mango"): bval("apple", "mango"), + bkey("apple", "banana", "pear"): bval("apple", "banana", "pear"), + } + assert.Equal(t, expected, f.Dump()) +} + +func TestBucketDeletion(t *testing.T) { + t.Parallel() + + f := NewEtcdTestFixture(t) + defer f.Cleanup() + + db, err := newEtcdBackend(f.BackendConfig()) + assert.NoError(t, err) + + err = db.Update(func(tx walletdb.ReadWriteTx) error { + // "apple" + apple, err := tx.CreateTopLevelBucket([]byte("apple")) + assert.Nil(t, err) + assert.NotNil(t, apple) + + // "apple/banana" + banana, err := apple.CreateBucket([]byte("banana")) + assert.Nil(t, err) + assert.NotNil(t, banana) + + kvs := []KV{{"key1", "val1"}, {"key2", "val2"}, {"key3", "val3"}} + + for _, kv := range kvs { + assert.NoError(t, banana.Put([]byte(kv.key), []byte(kv.val))) + assert.Equal(t, []byte(kv.val), banana.Get([]byte(kv.key))) + } + + // Delete a k/v from "apple/banana" + assert.NoError(t, banana.Delete([]byte("key2"))) + // Try getting/putting/deleting invalid k/v's. + assert.Nil(t, banana.Get(nil)) + assert.Error(t, walletdb.ErrKeyRequired, banana.Put(nil, []byte("val"))) + assert.Error(t, walletdb.ErrKeyRequired, banana.Delete(nil)) + + // Try deleting a k/v that doesn't exist. + assert.NoError(t, banana.Delete([]byte("nokey"))) + + // "apple/pear" + pear, err := apple.CreateBucket([]byte("pear")) + assert.Nil(t, err) + assert.NotNil(t, pear) + + // Put some values into "apple/pear" + for _, kv := range kvs { + assert.Nil(t, pear.Put([]byte(kv.key), []byte(kv.val))) + assert.Equal(t, []byte(kv.val), pear.Get([]byte(kv.key))) + } + + // Create nested bucket "apple/pear/cherry" + cherry, err := pear.CreateBucket([]byte("cherry")) + assert.Nil(t, err) + assert.NotNil(t, cherry) + + // Put some values into "apple/pear/cherry" + for _, kv := range kvs { + assert.NoError(t, cherry.Put([]byte(kv.key), []byte(kv.val))) + } + + // Read back values in "apple/pear/cherry" trough a read bucket. + cherryReadBucket := pear.NestedReadBucket([]byte("cherry")) + for _, kv := range kvs { + assert.Equal( + t, []byte(kv.val), + cherryReadBucket.Get([]byte(kv.key)), + ) + } + + // Try deleting some invalid buckets. + assert.Error(t, + walletdb.ErrBucketNameRequired, apple.DeleteNestedBucket(nil), + ) + + // Try deleting a non existing bucket. + assert.Error( + t, + walletdb.ErrBucketNotFound, + apple.DeleteNestedBucket([]byte("missing")), + ) + + // Delete "apple/pear" + assert.Nil(t, apple.DeleteNestedBucket([]byte("pear"))) + + // "apple/pear" deleted + assert.Nil(t, apple.NestedReadWriteBucket([]byte("pear"))) + + // "apple/pear/cherry" deleted + assert.Nil(t, pear.NestedReadWriteBucket([]byte("cherry"))) + + // Values deleted too. + for _, kv := range kvs { + assert.Nil(t, pear.Get([]byte(kv.key))) + assert.Nil(t, cherry.Get([]byte(kv.key))) + } + + // "aple/banana" exists + assert.NotNil(t, apple.NestedReadWriteBucket([]byte("banana"))) + return nil + }) + + assert.Nil(t, err) + + expected := map[string]string{ + bkey("apple"): bval("apple"), + bkey("apple", "banana"): bval("apple", "banana"), + vkey("key1", "apple", "banana"): "val1", + vkey("key3", "apple", "banana"): "val3", + } + assert.Equal(t, expected, f.Dump()) +} + +func TestBucketForEach(t *testing.T) { + t.Parallel() + + f := NewEtcdTestFixture(t) + defer f.Cleanup() + + db, err := newEtcdBackend(f.BackendConfig()) + assert.NoError(t, err) + + err = db.Update(func(tx walletdb.ReadWriteTx) error { + // "apple" + apple, err := tx.CreateTopLevelBucket([]byte("apple")) + assert.Nil(t, err) + assert.NotNil(t, apple) + + // "apple/banana" + banana, err := apple.CreateBucket([]byte("banana")) + assert.Nil(t, err) + assert.NotNil(t, banana) + + kvs := []KV{{"key1", "val1"}, {"key2", "val2"}, {"key3", "val3"}} + + // put some values into "apple" and "apple/banana" too + for _, kv := range kvs { + assert.Nil(t, apple.Put([]byte(kv.key), []byte(kv.val))) + assert.Equal(t, []byte(kv.val), apple.Get([]byte(kv.key))) + + assert.Nil(t, banana.Put([]byte(kv.key), []byte(kv.val))) + assert.Equal(t, []byte(kv.val), banana.Get([]byte(kv.key))) + } + + got := make(map[string]string) + err = apple.ForEach(func(key, val []byte) error { + got[string(key)] = string(val) + return nil + }) + + expected := map[string]string{ + "key1": "val1", + "key2": "val2", + "key3": "val3", + "banana": "", + } + + assert.NoError(t, err) + assert.Equal(t, expected, got) + + got = make(map[string]string) + err = banana.ForEach(func(key, val []byte) error { + got[string(key)] = string(val) + return nil + }) + + assert.NoError(t, err) + // remove the sub-bucket key + delete(expected, "banana") + assert.Equal(t, expected, got) + + return nil + }) + + assert.Nil(t, err) + + expected := map[string]string{ + bkey("apple"): bval("apple"), + bkey("apple", "banana"): bval("apple", "banana"), + vkey("key1", "apple"): "val1", + vkey("key2", "apple"): "val2", + vkey("key3", "apple"): "val3", + vkey("key1", "apple", "banana"): "val1", + vkey("key2", "apple", "banana"): "val2", + vkey("key3", "apple", "banana"): "val3", + } + assert.Equal(t, expected, f.Dump()) +} + +func TestBucketForEachWithError(t *testing.T) { + t.Parallel() + + f := NewEtcdTestFixture(t) + defer f.Cleanup() + + db, err := newEtcdBackend(f.BackendConfig()) + assert.NoError(t, err) + + err = db.Update(func(tx walletdb.ReadWriteTx) error { + // "apple" + apple, err := tx.CreateTopLevelBucket([]byte("apple")) + assert.Nil(t, err) + assert.NotNil(t, apple) + + // "apple/banana" + banana, err := apple.CreateBucket([]byte("banana")) + assert.Nil(t, err) + assert.NotNil(t, banana) + + // "apple/pear" + pear, err := apple.CreateBucket([]byte("pear")) + assert.Nil(t, err) + assert.NotNil(t, pear) + + kvs := []KV{{"key1", "val1"}, {"key2", "val2"}} + + // Put some values into "apple" and "apple/banana" too. + for _, kv := range kvs { + assert.Nil(t, apple.Put([]byte(kv.key), []byte(kv.val))) + assert.Equal(t, []byte(kv.val), apple.Get([]byte(kv.key))) + } + + got := make(map[string]string) + i := 0 + // Error while iterating value keys. + err = apple.ForEach(func(key, val []byte) error { + if i == 1 { + return fmt.Errorf("error") + } + + got[string(key)] = string(val) + i++ + return nil + }) + + expected := map[string]string{ + "key1": "val1", + } + + assert.Equal(t, expected, got) + assert.Error(t, err) + + got = make(map[string]string) + i = 0 + // Erro while iterating buckets. + err = apple.ForEach(func(key, val []byte) error { + if i == 3 { + return fmt.Errorf("error") + } + + got[string(key)] = string(val) + i++ + return nil + }) + + expected = map[string]string{ + "key1": "val1", + "key2": "val2", + "banana": "", + } + + assert.Equal(t, expected, got) + assert.Error(t, err) + return nil + }) + + assert.Nil(t, err) + + expected := map[string]string{ + bkey("apple"): bval("apple"), + bkey("apple", "banana"): bval("apple", "banana"), + bkey("apple", "pear"): bval("apple", "pear"), + vkey("key1", "apple"): "val1", + vkey("key2", "apple"): "val2", + } + assert.Equal(t, expected, f.Dump()) +} + +func TestBucketSequence(t *testing.T) { + t.Parallel() + + f := NewEtcdTestFixture(t) + defer f.Cleanup() + + db, err := newEtcdBackend(f.BackendConfig()) + assert.NoError(t, err) + + err = db.Update(func(tx walletdb.ReadWriteTx) error { + apple, err := tx.CreateTopLevelBucket([]byte("apple")) + assert.Nil(t, err) + assert.NotNil(t, apple) + + banana, err := apple.CreateBucket([]byte("banana")) + assert.Nil(t, err) + assert.NotNil(t, banana) + + assert.Equal(t, uint64(0), apple.Sequence()) + assert.Equal(t, uint64(0), banana.Sequence()) + + assert.Nil(t, apple.SetSequence(math.MaxUint64)) + assert.Equal(t, uint64(math.MaxUint64), apple.Sequence()) + + for i := uint64(0); i < uint64(5); i++ { + s, err := apple.NextSequence() + assert.Nil(t, err) + assert.Equal(t, i, s) + } + + return nil + }) + + assert.Nil(t, err) +} diff --git a/channeldb/kvdb/etcd/readwrite_cursor.go b/channeldb/kvdb/etcd/readwrite_cursor.go new file mode 100644 index 00000000..da30cb6f --- /dev/null +++ b/channeldb/kvdb/etcd/readwrite_cursor.go @@ -0,0 +1,143 @@ +package etcd + +// readWriteCursor holds a reference to the cursors bucket, the value +// prefix and the current key used while iterating. +type readWriteCursor struct { + // bucket holds the reference to the parent bucket. + bucket *readWriteBucket + + // prefix holds the value prefix which is in front of each + // value key in the bucket. + prefix string + + // currKey holds the current key of the cursor. + currKey string +} + +func newReadWriteCursor(bucket *readWriteBucket) *readWriteCursor { + return &readWriteCursor{ + bucket: bucket, + prefix: string(makeValuePrefix(bucket.id)), + } +} + +// First positions the cursor at the first key/value pair and returns +// the pair. +func (c *readWriteCursor) First() (key, value []byte) { + // Get the first key with the value prefix. + kv, err := c.bucket.tx.stm.First(c.prefix) + if err != nil { + // TODO: revise this once kvdb interface supports errors + return nil, nil + } + + if kv != nil { + c.currKey = kv.key + // Chop the prefix and return the key/value. + return []byte(kv.key[len(c.prefix):]), []byte(kv.val) + } + + return nil, nil +} + +// Last positions the cursor at the last key/value pair and returns the +// pair. +func (c *readWriteCursor) Last() (key, value []byte) { + kv, err := c.bucket.tx.stm.Last(c.prefix) + if err != nil { + // TODO: revise this once kvdb interface supports errors + return nil, nil + } + + if kv != nil { + c.currKey = kv.key + // Chop the prefix and return the key/value. + return []byte(kv.key[len(c.prefix):]), []byte(kv.val) + } + + return nil, nil +} + +// Next moves the cursor one key/value pair forward and returns the new +// pair. +func (c *readWriteCursor) Next() (key, value []byte) { + kv, err := c.bucket.tx.stm.Next(c.prefix, c.currKey) + if err != nil { + // TODO: revise this once kvdb interface supports errors + return nil, nil + } + + if kv != nil { + c.currKey = kv.key + // Chop the prefix and return the key/value. + return []byte(kv.key[len(c.prefix):]), []byte(kv.val) + } + + return nil, nil +} + +// Prev moves the cursor one key/value pair backward and returns the new +// pair. +func (c *readWriteCursor) Prev() (key, value []byte) { + kv, err := c.bucket.tx.stm.Prev(c.prefix, c.currKey) + if err != nil { + // TODO: revise this once kvdb interface supports errors + return nil, nil + } + + if kv != nil { + c.currKey = kv.key + // Chop the prefix and return the key/value. + return []byte(kv.key[len(c.prefix):]), []byte(kv.val) + } + + return nil, nil +} + +// Seek positions the cursor at the passed seek key. If the key does +// not exist, the cursor is moved to the next key after seek. Returns +// the new pair. +func (c *readWriteCursor) Seek(seek []byte) (key, value []byte) { + // Return nil if trying to seek to an empty key. + if seek == nil { + return nil, nil + } + + // Seek to the first key with prefix + seek. If that key is not present + // STM will seek to the next matching key with prefix. + kv, err := c.bucket.tx.stm.Seek(c.prefix, c.prefix+string(seek)) + if err != nil { + // TODO: revise this once kvdb interface supports errors + return nil, nil + } + + if kv != nil { + c.currKey = kv.key + // Chop the prefix and return the key/value. + return []byte(kv.key[len(c.prefix):]), []byte(kv.val) + } + + return nil, nil +} + +// Delete removes the current key/value pair the cursor is at without +// invalidating the cursor. Returns ErrIncompatibleValue if attempted +// when the cursor points to a nested bucket. +func (c *readWriteCursor) Delete() error { + // Get the next key after the current one. We could do this + // after deletion too but it's one step more efficient here. + nextKey, err := c.bucket.tx.stm.Next(c.prefix, c.currKey) + if err != nil { + return err + } + + // Delete the current key. + c.bucket.tx.stm.Del(c.currKey) + + // Set current key to the next one if possible. + if nextKey != nil { + c.currKey = nextKey.key + } + + return nil +} diff --git a/channeldb/kvdb/etcd/readwrite_cursor_test.go b/channeldb/kvdb/etcd/readwrite_cursor_test.go new file mode 100644 index 00000000..fd8ca01d --- /dev/null +++ b/channeldb/kvdb/etcd/readwrite_cursor_test.go @@ -0,0 +1,291 @@ +package etcd + +import ( + "testing" + + "github.com/btcsuite/btcwallet/walletdb" + "github.com/stretchr/testify/assert" +) + +func TestReadCursorEmptyInterval(t *testing.T) { + t.Parallel() + + f := NewEtcdTestFixture(t) + defer f.Cleanup() + + db, err := newEtcdBackend(f.BackendConfig()) + assert.NoError(t, err) + + err = db.Update(func(tx walletdb.ReadWriteTx) error { + b, err := tx.CreateTopLevelBucket([]byte("alma")) + assert.NoError(t, err) + assert.NotNil(t, b) + + return nil + }) + assert.NoError(t, err) + + err = db.View(func(tx walletdb.ReadTx) error { + b := tx.ReadBucket([]byte("alma")) + assert.NotNil(t, b) + + cursor := b.ReadCursor() + k, v := cursor.First() + assert.Nil(t, k) + assert.Nil(t, v) + + k, v = cursor.Next() + assert.Nil(t, k) + assert.Nil(t, v) + + k, v = cursor.Last() + assert.Nil(t, k) + assert.Nil(t, v) + + k, v = cursor.Prev() + assert.Nil(t, k) + assert.Nil(t, v) + + return nil + }) + assert.NoError(t, err) +} + +func TestReadCursorNonEmptyInterval(t *testing.T) { + t.Parallel() + + f := NewEtcdTestFixture(t) + defer f.Cleanup() + + db, err := newEtcdBackend(f.BackendConfig()) + assert.NoError(t, err) + + testKeyValues := []KV{ + {"b", "1"}, + {"c", "2"}, + {"da", "3"}, + {"e", "4"}, + } + + err = db.Update(func(tx walletdb.ReadWriteTx) error { + b, err := tx.CreateTopLevelBucket([]byte("alma")) + assert.NoError(t, err) + assert.NotNil(t, b) + + for _, kv := range testKeyValues { + assert.NoError(t, b.Put([]byte(kv.key), []byte(kv.val))) + } + return nil + }) + + assert.NoError(t, err) + + err = db.View(func(tx walletdb.ReadTx) error { + b := tx.ReadBucket([]byte("alma")) + assert.NotNil(t, b) + + // Iterate from the front. + var kvs []KV + cursor := b.ReadCursor() + k, v := cursor.First() + + for k != nil && v != nil { + kvs = append(kvs, KV{string(k), string(v)}) + k, v = cursor.Next() + } + assert.Equal(t, testKeyValues, kvs) + + // Iterate from the back. + kvs = []KV{} + k, v = cursor.Last() + + for k != nil && v != nil { + kvs = append(kvs, KV{string(k), string(v)}) + k, v = cursor.Prev() + } + assert.Equal(t, reverseKVs(testKeyValues), kvs) + + // Random access + perm := []int{3, 0, 2, 1} + for _, i := range perm { + k, v := cursor.Seek([]byte(testKeyValues[i].key)) + assert.Equal(t, []byte(testKeyValues[i].key), k) + assert.Equal(t, []byte(testKeyValues[i].val), v) + } + + // Seek to nonexisting key. + k, v = cursor.Seek(nil) + assert.Nil(t, k) + assert.Nil(t, v) + + k, v = cursor.Seek([]byte("x")) + assert.Nil(t, k) + assert.Nil(t, v) + + return nil + }) + + assert.NoError(t, err) +} + +func TestReadWriteCursor(t *testing.T) { + t.Parallel() + + f := NewEtcdTestFixture(t) + defer f.Cleanup() + + db, err := newEtcdBackend(f.BackendConfig()) + assert.NoError(t, err) + + testKeyValues := []KV{ + {"b", "1"}, + {"c", "2"}, + {"da", "3"}, + {"e", "4"}, + } + + count := len(testKeyValues) + + // Pre-store the first half of the interval. + assert.NoError(t, db.Update(func(tx walletdb.ReadWriteTx) error { + b, err := tx.CreateTopLevelBucket([]byte("apple")) + assert.NoError(t, err) + assert.NotNil(t, b) + + for i := 0; i < count/2; i++ { + err = b.Put( + []byte(testKeyValues[i].key), + []byte(testKeyValues[i].val), + ) + assert.NoError(t, err) + } + return nil + })) + + err = db.Update(func(tx walletdb.ReadWriteTx) error { + b := tx.ReadWriteBucket([]byte("apple")) + assert.NotNil(t, b) + + // Store the second half of the interval. + for i := count / 2; i < count; i++ { + err = b.Put( + []byte(testKeyValues[i].key), + []byte(testKeyValues[i].val), + ) + assert.NoError(t, err) + } + + cursor := b.ReadWriteCursor() + + // First on valid interval. + fk, fv := cursor.First() + assert.Equal(t, []byte("b"), fk) + assert.Equal(t, []byte("1"), fv) + + // Prev(First()) = nil + k, v := cursor.Prev() + assert.Nil(t, k) + assert.Nil(t, v) + + // Last on valid interval. + lk, lv := cursor.Last() + assert.Equal(t, []byte("e"), lk) + assert.Equal(t, []byte("4"), lv) + + // Next(Last()) = nil + k, v = cursor.Next() + assert.Nil(t, k) + assert.Nil(t, v) + + // Delete first item, then add an item before the + // deleted one. Check that First/Next will "jump" + // over the deleted item and return the new first. + _, _ = cursor.First() + assert.NoError(t, cursor.Delete()) + assert.NoError(t, b.Put([]byte("a"), []byte("0"))) + fk, fv = cursor.First() + + assert.Equal(t, []byte("a"), fk) + assert.Equal(t, []byte("0"), fv) + + k, v = cursor.Next() + assert.Equal(t, []byte("c"), k) + assert.Equal(t, []byte("2"), v) + + // Similarly test that a new end is returned if + // the old end is deleted first. + _, _ = cursor.Last() + assert.NoError(t, cursor.Delete()) + assert.NoError(t, b.Put([]byte("f"), []byte("5"))) + + lk, lv = cursor.Last() + assert.Equal(t, []byte("f"), lk) + assert.Equal(t, []byte("5"), lv) + + k, v = cursor.Prev() + assert.Equal(t, []byte("da"), k) + assert.Equal(t, []byte("3"), v) + + // Overwrite k/v in the middle of the interval. + assert.NoError(t, b.Put([]byte("c"), []byte("3"))) + k, v = cursor.Prev() + assert.Equal(t, []byte("c"), k) + assert.Equal(t, []byte("3"), v) + + // Insert new key/values. + assert.NoError(t, b.Put([]byte("cx"), []byte("x"))) + assert.NoError(t, b.Put([]byte("cy"), []byte("y"))) + + k, v = cursor.Next() + assert.Equal(t, []byte("cx"), k) + assert.Equal(t, []byte("x"), v) + + k, v = cursor.Next() + assert.Equal(t, []byte("cy"), k) + assert.Equal(t, []byte("y"), v) + + expected := []KV{ + {"a", "0"}, + {"c", "3"}, + {"cx", "x"}, + {"cy", "y"}, + {"da", "3"}, + {"f", "5"}, + } + + // Iterate from the front. + var kvs []KV + k, v = cursor.First() + + for k != nil && v != nil { + kvs = append(kvs, KV{string(k), string(v)}) + k, v = cursor.Next() + } + assert.Equal(t, expected, kvs) + + // Iterate from the back. + kvs = []KV{} + k, v = cursor.Last() + + for k != nil && v != nil { + kvs = append(kvs, KV{string(k), string(v)}) + k, v = cursor.Prev() + } + assert.Equal(t, reverseKVs(expected), kvs) + + return nil + }) + + assert.NoError(t, err) + + expected := map[string]string{ + bkey("apple"): bval("apple"), + vkey("a", "apple"): "0", + vkey("c", "apple"): "3", + vkey("cx", "apple"): "x", + vkey("cy", "apple"): "y", + vkey("da", "apple"): "3", + vkey("f", "apple"): "5", + } + assert.Equal(t, expected, f.Dump()) +} diff --git a/channeldb/kvdb/etcd/readwrite_tx.go b/channeldb/kvdb/etcd/readwrite_tx.go new file mode 100644 index 00000000..591ff55d --- /dev/null +++ b/channeldb/kvdb/etcd/readwrite_tx.go @@ -0,0 +1,93 @@ +package etcd + +import ( + "github.com/btcsuite/btcwallet/walletdb" +) + +// readWriteTx holds a reference to the STM transaction. +type readWriteTx struct { + // stm is the reference to the parent STM. + stm STM + + // active is true if the transaction hasn't been + // committed yet. + active bool +} + +// newReadWriteTx creates an rw transaction with the passed STM. +func newReadWriteTx(stm STM) *readWriteTx { + return &readWriteTx{ + stm: stm, + active: true, + } +} + +// rooBucket is a helper function to return the always present +// root bucket. +func rootBucket(tx *readWriteTx) *readWriteBucket { + return newReadWriteBucket(tx, rootBucketID()) +} + +// ReadBucket opens the root bucket for read only access. If the bucket +// described by the key does not exist, nil is returned. +func (tx *readWriteTx) ReadBucket(key []byte) walletdb.ReadBucket { + return rootBucket(tx).NestedReadWriteBucket(key) +} + +// Rollback closes the transaction, discarding changes (if any) if the +// database was modified by a write transaction. +func (tx *readWriteTx) Rollback() error { + // If the transaction has been closed roolback will fail. + if !tx.active { + return walletdb.ErrTxClosed + } + + // Rollback the STM and set the tx to inactive. + tx.stm.Rollback() + tx.active = false + + return nil +} + +// ReadWriteBucket opens the root bucket for read/write access. If the +// bucket described by the key does not exist, nil is returned. +func (tx *readWriteTx) ReadWriteBucket(key []byte) walletdb.ReadWriteBucket { + return rootBucket(tx).NestedReadWriteBucket(key) +} + +// CreateTopLevelBucket creates the top level bucket for a key if it +// does not exist. The newly-created bucket it returned. +func (tx *readWriteTx) CreateTopLevelBucket(key []byte) (walletdb.ReadWriteBucket, error) { + return rootBucket(tx).CreateBucketIfNotExists(key) +} + +// DeleteTopLevelBucket deletes the top level bucket for a key. This +// errors if the bucket can not be found or the key keys a single value +// instead of a bucket. +func (tx *readWriteTx) DeleteTopLevelBucket(key []byte) error { + return rootBucket(tx).DeleteNestedBucket(key) +} + +// Commit commits the transaction if not already committed. Will return +// error if the underlying STM fails. +func (tx *readWriteTx) Commit() error { + // Commit will fail if the transaction is already committed. + if !tx.active { + return walletdb.ErrTxClosed + } + + // Try committing the transaction. + if err := tx.stm.Commit(); err != nil { + return err + } + + // Mark the transaction as not active after commit. + tx.active = false + + return nil +} + +// OnCommit sets the commit callback (overriding if already set). +func (tx *readWriteTx) OnCommit(cb func()) { + tx.stm.OnCommit(cb) +} diff --git a/channeldb/kvdb/etcd/readwrite_tx_test.go b/channeldb/kvdb/etcd/readwrite_tx_test.go new file mode 100644 index 00000000..14a904c5 --- /dev/null +++ b/channeldb/kvdb/etcd/readwrite_tx_test.go @@ -0,0 +1,154 @@ +package etcd + +import ( + "testing" + + "github.com/btcsuite/btcwallet/walletdb" + "github.com/stretchr/testify/assert" +) + +func TestTxManualCommit(t *testing.T) { + t.Parallel() + + f := NewEtcdTestFixture(t) + defer f.Cleanup() + + db, err := newEtcdBackend(f.BackendConfig()) + assert.NoError(t, err) + + tx, err := db.BeginReadWriteTx() + assert.NoError(t, err) + assert.NotNil(t, tx) + + committed := false + + tx.OnCommit(func() { + committed = true + }) + + apple, err := tx.CreateTopLevelBucket([]byte("apple")) + assert.NoError(t, err) + assert.NotNil(t, apple) + assert.NoError(t, apple.Put([]byte("testKey"), []byte("testVal"))) + + banana, err := tx.CreateTopLevelBucket([]byte("banana")) + assert.NoError(t, err) + assert.NotNil(t, banana) + assert.NoError(t, banana.Put([]byte("testKey"), []byte("testVal"))) + assert.NoError(t, tx.DeleteTopLevelBucket([]byte("banana"))) + + assert.NoError(t, tx.Commit()) + assert.True(t, committed) + + expected := map[string]string{ + bkey("apple"): bval("apple"), + vkey("testKey", "apple"): "testVal", + } + assert.Equal(t, expected, f.Dump()) +} + +func TestTxRollback(t *testing.T) { + t.Parallel() + + f := NewEtcdTestFixture(t) + defer f.Cleanup() + + db, err := newEtcdBackend(f.BackendConfig()) + assert.NoError(t, err) + + tx, err := db.BeginReadWriteTx() + assert.Nil(t, err) + assert.NotNil(t, tx) + + apple, err := tx.CreateTopLevelBucket([]byte("apple")) + assert.Nil(t, err) + assert.NotNil(t, apple) + + assert.NoError(t, apple.Put([]byte("testKey"), []byte("testVal"))) + + assert.NoError(t, tx.Rollback()) + assert.Error(t, walletdb.ErrTxClosed, tx.Commit()) + assert.Equal(t, map[string]string{}, f.Dump()) +} + +func TestChangeDuringManualTx(t *testing.T) { + t.Parallel() + + f := NewEtcdTestFixture(t) + defer f.Cleanup() + + db, err := newEtcdBackend(f.BackendConfig()) + assert.NoError(t, err) + + tx, err := db.BeginReadWriteTx() + assert.Nil(t, err) + assert.NotNil(t, tx) + + apple, err := tx.CreateTopLevelBucket([]byte("apple")) + assert.Nil(t, err) + assert.NotNil(t, apple) + + assert.NoError(t, apple.Put([]byte("testKey"), []byte("testVal"))) + + // Try overwriting the bucket key. + f.Put(bkey("apple"), "banana") + + // TODO: translate error + assert.NotNil(t, tx.Commit()) + assert.Equal(t, map[string]string{ + bkey("apple"): "banana", + }, f.Dump()) +} + +func TestChangeDuringUpdate(t *testing.T) { + t.Parallel() + + f := NewEtcdTestFixture(t) + defer f.Cleanup() + + db, err := newEtcdBackend(f.BackendConfig()) + assert.NoError(t, err) + + count := 0 + + err = db.Update(func(tx walletdb.ReadWriteTx) error { + apple, err := tx.CreateTopLevelBucket([]byte("apple")) + assert.NoError(t, err) + assert.NotNil(t, apple) + + assert.NoError(t, apple.Put([]byte("key"), []byte("value"))) + + if count == 0 { + f.Put(vkey("key", "apple"), "new_value") + f.Put(vkey("key2", "apple"), "value2") + } + + cursor := apple.ReadCursor() + k, v := cursor.First() + assert.Equal(t, []byte("key"), k) + assert.Equal(t, []byte("value"), v) + assert.Equal(t, v, apple.Get([]byte("key"))) + + k, v = cursor.Next() + if count == 0 { + assert.Nil(t, k) + assert.Nil(t, v) + } else { + assert.Equal(t, []byte("key2"), k) + assert.Equal(t, []byte("value2"), v) + } + + count++ + return nil + }) + + assert.Nil(t, err) + assert.Equal(t, count, 2) + + expected := map[string]string{ + bkey("apple"): bval("apple"), + vkey("key", "apple"): "value", + vkey("key2", "apple"): "value2", + } + assert.Equal(t, expected, f.Dump()) +}