diff --git a/sweep/store.go b/sweep/store.go index f1ca8475..287646a7 100644 --- a/sweep/store.go +++ b/sweep/store.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/channeldb/kvdb" ) var ( @@ -56,26 +56,28 @@ type SweeperStore interface { } type sweeperStore struct { - db *bbolt.DB + db kvdb.Backend } // NewSweeperStore returns a new store instance. -func NewSweeperStore(db *bbolt.DB, chainHash *chainhash.Hash) ( +func NewSweeperStore(db kvdb.Backend, chainHash *chainhash.Hash) ( SweeperStore, error) { - err := db.Update(func(tx *bbolt.Tx) error { - _, err := tx.CreateBucketIfNotExists( + err := kvdb.Update(db, func(tx kvdb.RwTx) error { + _, err := tx.CreateTopLevelBucket( lastTxBucketKey, ) if err != nil { return err } - if tx.Bucket(txHashesBucketKey) != nil { + if tx.ReadWriteBucket(txHashesBucketKey) != nil { return nil } - txHashesBucket, err := tx.CreateBucket(txHashesBucketKey) + txHashesBucket, err := tx.CreateTopLevelBucket( + txHashesBucketKey, + ) if err != nil { return err } @@ -97,7 +99,7 @@ func NewSweeperStore(db *bbolt.DB, chainHash *chainhash.Hash) ( // migrateTxHashes migrates nursery finalized txes to the tx hashes bucket. This // is not implemented as a database migration, to keep the downgrade path open. -func migrateTxHashes(tx *bbolt.Tx, txHashesBucket *bbolt.Bucket, +func migrateTxHashes(tx kvdb.RwTx, txHashesBucket kvdb.RwBucket, chainHash *chainhash.Hash) error { log.Infof("Migrating UTXO nursery finalized TXIDs") @@ -113,20 +115,20 @@ func migrateTxHashes(tx *bbolt.Tx, txHashesBucket *bbolt.Bucket, } // Get chain bucket if exists. - chainBucket := tx.Bucket(b.Bytes()) + chainBucket := tx.ReadWriteBucket(b.Bytes()) if chainBucket == nil { return nil } // Retrieve the existing height index. - hghtIndex := chainBucket.Bucket(utxnHeightIndexKey) + hghtIndex := chainBucket.NestedReadWriteBucket(utxnHeightIndexKey) if hghtIndex == nil { return nil } // Retrieve all heights. err := hghtIndex.ForEach(func(k, v []byte) error { - heightBucket := hghtIndex.Bucket(k) + heightBucket := hghtIndex.NestedReadWriteBucket(k) if heightBucket == nil { return nil } @@ -163,13 +165,13 @@ func migrateTxHashes(tx *bbolt.Tx, txHashesBucket *bbolt.Bucket, // NotifyPublishTx signals that we are about to publish a tx. func (s *sweeperStore) NotifyPublishTx(sweepTx *wire.MsgTx) error { - return s.db.Update(func(tx *bbolt.Tx) error { - lastTxBucket := tx.Bucket(lastTxBucketKey) + return kvdb.Update(s.db, func(tx kvdb.RwTx) error { + lastTxBucket := tx.ReadWriteBucket(lastTxBucketKey) if lastTxBucket == nil { return errors.New("last tx bucket does not exist") } - txHashesBucket := tx.Bucket(txHashesBucketKey) + txHashesBucket := tx.ReadWriteBucket(txHashesBucketKey) if txHashesBucket == nil { return errors.New("tx hashes bucket does not exist") } @@ -194,8 +196,8 @@ func (s *sweeperStore) NotifyPublishTx(sweepTx *wire.MsgTx) error { func (s *sweeperStore) GetLastPublishedTx() (*wire.MsgTx, error) { var sweepTx *wire.MsgTx - err := s.db.View(func(tx *bbolt.Tx) error { - lastTxBucket := tx.Bucket(lastTxBucketKey) + err := kvdb.View(s.db, func(tx kvdb.ReadTx) error { + lastTxBucket := tx.ReadBucket(lastTxBucketKey) if lastTxBucket == nil { return errors.New("last tx bucket does not exist") } @@ -225,8 +227,8 @@ func (s *sweeperStore) GetLastPublishedTx() (*wire.MsgTx, error) { func (s *sweeperStore) IsOurTx(hash chainhash.Hash) (bool, error) { var ours bool - err := s.db.View(func(tx *bbolt.Tx) error { - txHashesBucket := tx.Bucket(txHashesBucketKey) + err := kvdb.View(s.db, func(tx kvdb.ReadTx) error { + txHashesBucket := tx.ReadBucket(txHashesBucketKey) if txHashesBucket == nil { return errors.New("tx hashes bucket does not exist") } diff --git a/sweep/store_test.go b/sweep/store_test.go index 6853c7bb..23714c78 100644 --- a/sweep/store_test.go +++ b/sweep/store_test.go @@ -53,7 +53,7 @@ func TestStore(t *testing.T) { testStore(t, func() (SweeperStore, error) { var chain chainhash.Hash - return NewSweeperStore(cdb.DB, &chain) + return NewSweeperStore(cdb, &chain) }) }) t.Run("mock", func(t *testing.T) {