diff --git a/kvdb/etcd/config.go b/kvdb/etcd/config.go index 5869438c..4d6e9f85 100644 --- a/kvdb/etcd/config.go +++ b/kvdb/etcd/config.go @@ -25,4 +25,8 @@ type Config struct { InsecureSkipVerify bool `long:"insecure_skip_verify" description:"Whether we intend to skip TLS verification"` CollectStats bool `long:"collect_stats" description:"Whether to collect etcd commit stats."` + + // SingleWriter should be set to true if we intend to only allow a + // single writer to the database at a time. + SingleWriter bool } diff --git a/kvdb/etcd/db.go b/kvdb/etcd/db.go index 0a2955ca..5da44bc8 100644 --- a/kvdb/etcd/db.go +++ b/kvdb/etcd/db.go @@ -125,6 +125,7 @@ type db struct { cli *clientv3.Client commitStatsCollector *commitStatsCollector txQueue *commitQueue + txMutex sync.RWMutex } // Enforce db implements the walletdb.DB interface. @@ -204,9 +205,14 @@ func (db *db) getSTMOptions() []STMOptionFunc { // expect retries of the f closure (depending on the database backend used), the // reset function will be called before each retry respectively. func (db *db) View(f func(tx walletdb.ReadTx) error, reset func()) error { + if db.cfg.SingleWriter { + db.txMutex.RLock() + defer db.txMutex.RUnlock() + } + apply := func(stm STM) error { reset() - return f(newReadWriteTx(stm, etcdDefaultRootBucketId)) + return f(newReadWriteTx(stm, etcdDefaultRootBucketId, nil)) } return RunSTM(db.cli, apply, db.txQueue, db.getSTMOptions()...) @@ -220,9 +226,14 @@ func (db *db) View(f func(tx walletdb.ReadTx) error, reset func()) error { // returned. As callers may expect retries of the f closure, the reset function // will be called before each retry respectively. func (db *db) Update(f func(tx walletdb.ReadWriteTx) error, reset func()) error { + if db.cfg.SingleWriter { + db.txMutex.Lock() + defer db.txMutex.Unlock() + } + apply := func(stm STM) error { reset() - return f(newReadWriteTx(stm, etcdDefaultRootBucketId)) + return f(newReadWriteTx(stm, etcdDefaultRootBucketId, nil)) } return RunSTM(db.cli, apply, db.txQueue, db.getSTMOptions()...) @@ -239,17 +250,29 @@ func (db *db) PrintStats() string { // BeginReadWriteTx opens a database read+write transaction. func (db *db) BeginReadWriteTx() (walletdb.ReadWriteTx, error) { + var locker sync.Locker + if db.cfg.SingleWriter { + db.txMutex.Lock() + locker = &db.txMutex + } + return newReadWriteTx( NewSTM(db.cli, db.txQueue, db.getSTMOptions()...), - etcdDefaultRootBucketId, + etcdDefaultRootBucketId, locker, ), nil } // BeginReadTx opens a database read transaction. func (db *db) BeginReadTx() (walletdb.ReadTx, error) { + var locker sync.Locker + if db.cfg.SingleWriter { + db.txMutex.RLock() + locker = db.txMutex.RLocker() + } + return newReadWriteTx( NewSTM(db.cli, db.txQueue, db.getSTMOptions()...), - etcdDefaultRootBucketId, + etcdDefaultRootBucketId, locker, ), nil } diff --git a/kvdb/etcd/fixture.go b/kvdb/etcd/fixture.go index 01781b7b..aee50d42 100644 --- a/kvdb/etcd/fixture.go +++ b/kvdb/etcd/fixture.go @@ -78,8 +78,13 @@ func NewEtcdTestFixture(t *testing.T) *EtcdTestFixture { } } -func (f *EtcdTestFixture) NewBackend() walletdb.DB { - db, err := newEtcdBackend(context.TODO(), f.BackendConfig()) +func (f *EtcdTestFixture) NewBackend(singleWriter bool) walletdb.DB { + cfg := f.BackendConfig() + if singleWriter { + cfg.SingleWriter = true + } + + db, err := newEtcdBackend(context.TODO(), cfg) require.NoError(f.t, err) return db diff --git a/kvdb/etcd/readwrite_tx.go b/kvdb/etcd/readwrite_tx.go index 7605c6cd..12bc2779 100644 --- a/kvdb/etcd/readwrite_tx.go +++ b/kvdb/etcd/readwrite_tx.go @@ -3,6 +3,8 @@ package etcd import ( + "sync" + "github.com/btcsuite/btcwallet/walletdb" ) @@ -17,13 +19,18 @@ type readWriteTx struct { // active is true if the transaction hasn't been committed yet. active bool + + // lock is passed on for manual txns when the backend is instantiated + // such that we read/write lock transactions to ensure a single writer. + lock sync.Locker } // newReadWriteTx creates an rw transaction with the passed STM. -func newReadWriteTx(stm STM, prefix string) *readWriteTx { +func newReadWriteTx(stm STM, prefix string, lock sync.Locker) *readWriteTx { return &readWriteTx{ stm: stm, active: true, + lock: lock, rootBucketID: makeBucketID([]byte(prefix)), } } @@ -65,6 +72,10 @@ func (tx *readWriteTx) Rollback() error { return walletdb.ErrTxClosed } + if tx.lock != nil { + defer tx.lock.Unlock() + } + // Rollback the STM and set the tx to inactive. tx.stm.Rollback() tx.active = false @@ -99,6 +110,10 @@ func (tx *readWriteTx) Commit() error { return walletdb.ErrTxClosed } + if tx.lock != nil { + defer tx.lock.Unlock() + } + // Try committing the transaction. if err := tx.stm.Commit(); err != nil { return err diff --git a/kvdb/etcd_test.go b/kvdb/etcd_test.go index 6d96d284..7a0cdbdd 100644 --- a/kvdb/etcd_test.go +++ b/kvdb/etcd_test.go @@ -3,6 +3,7 @@ package kvdb import ( + "fmt" "testing" "github.com/btcsuite/btcwallet/walletdb" @@ -143,18 +144,23 @@ func TestEtcd(t *testing.T) { continue } - t.Run(test.name, func(t *testing.T) { - t.Parallel() + rwLock := []bool{false, true} + for _, doRwLock := range rwLock { + name := fmt.Sprintf("%v/RWLock=%v", test.name, doRwLock) - f := etcd.NewEtcdTestFixture(t) - defer f.Cleanup() + t.Run(name, func(t *testing.T) { + t.Parallel() - test.test(t, f.NewBackend()) + f := etcd.NewEtcdTestFixture(t) + defer f.Cleanup() - if test.expectedDb != nil { - dump := f.Dump() - require.Equal(t, test.expectedDb, dump) - } - }) + test.test(t, f.NewBackend(doRwLock)) + + if test.expectedDb != nil { + dump := f.Dump() + require.Equal(t, test.expectedDb, dump) + } + }) + } } }