From 4e68914e9d86d514740db51c43eab9b4f3e7e80a Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Thu, 9 Jan 2020 18:44:27 -0800 Subject: [PATCH] htlcswitch: convert to use new kvdb abstraction --- htlcswitch/circuit_map.go | 32 +++++++++++++++---------------- htlcswitch/decayedlog.go | 37 +++++++++++++++++++----------------- htlcswitch/link_test.go | 14 +++++++------- htlcswitch/payment_result.go | 14 +++++++------- htlcswitch/sequencer.go | 10 +++++----- htlcswitch/switch.go | 6 +++--- htlcswitch/test_utils.go | 6 +++--- 7 files changed, 61 insertions(+), 58 deletions(-) diff --git a/htlcswitch/circuit_map.go b/htlcswitch/circuit_map.go index fa91bfcd..7711f247 100644 --- a/htlcswitch/circuit_map.go +++ b/htlcswitch/circuit_map.go @@ -5,10 +5,10 @@ import ( "fmt" "sync" - "github.com/coreos/bbolt" "github.com/davecgh/go-spew/spew" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/kvdb" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/lnwire" ) @@ -213,13 +213,13 @@ func NewCircuitMap(cfg *CircuitMapConfig) (CircuitMap, error) { // initBuckets ensures that the primary buckets used by the circuit are // initialized so that we can assume their existence after startup. func (cm *circuitMap) initBuckets() error { - return cm.cfg.DB.Update(func(tx *bbolt.Tx) error { - _, err := tx.CreateBucketIfNotExists(circuitKeystoneKey) + return kvdb.Update(cm.cfg.DB, func(tx kvdb.RwTx) error { + _, err := tx.CreateTopLevelBucket(circuitKeystoneKey) if err != nil { return err } - _, err = tx.CreateBucketIfNotExists(circuitAddKey) + _, err = tx.CreateTopLevelBucket(circuitAddKey) return err }) } @@ -238,10 +238,10 @@ func (cm *circuitMap) restoreMemState() error { pending = make(map[CircuitKey]*PaymentCircuit) ) - if err := cm.cfg.DB.Update(func(tx *bbolt.Tx) error { + if err := kvdb.Update(cm.cfg.DB, func(tx kvdb.RwTx) error { // Restore any of the circuits persisted in the circuit bucket // back into memory. - circuitBkt := tx.Bucket(circuitAddKey) + circuitBkt := tx.ReadWriteBucket(circuitAddKey) if circuitBkt == nil { return ErrCorruptedCircuitMap } @@ -262,7 +262,7 @@ func (cm *circuitMap) restoreMemState() error { // Furthermore, load the keystone bucket and resurrect the // keystones used in any open circuits. - keystoneBkt := tx.Bucket(circuitKeystoneKey) + keystoneBkt := tx.ReadWriteBucket(circuitKeystoneKey) if keystoneBkt == nil { return ErrCorruptedCircuitMap } @@ -463,8 +463,8 @@ func (cm *circuitMap) TrimOpenCircuits(chanID lnwire.ShortChannelID, return nil } - return cm.cfg.DB.Update(func(tx *bbolt.Tx) error { - keystoneBkt := tx.Bucket(circuitKeystoneKey) + return kvdb.Update(cm.cfg.DB, func(tx kvdb.RwTx) error { + keystoneBkt := tx.ReadWriteBucket(circuitKeystoneKey) if keystoneBkt == nil { return ErrCorruptedCircuitMap } @@ -616,8 +616,8 @@ func (cm *circuitMap) CommitCircuits(circuits ...*PaymentCircuit) ( // Write the entire batch of circuits to the persistent circuit bucket // using bolt's Batch write. This method must be called from multiple, // distinct goroutines to have any impact on performance. - err := cm.cfg.DB.Batch(func(tx *bbolt.Tx) error { - circuitBkt := tx.Bucket(circuitAddKey) + err := kvdb.Batch(cm.cfg.DB.Backend, func(tx kvdb.RwTx) error { + circuitBkt := tx.ReadWriteBucket(circuitAddKey) if circuitBkt == nil { return ErrCorruptedCircuitMap } @@ -706,10 +706,10 @@ func (cm *circuitMap) OpenCircuits(keystones ...Keystone) error { } cm.mtx.RUnlock() - err := cm.cfg.DB.Update(func(tx *bbolt.Tx) error { + err := kvdb.Update(cm.cfg.DB, func(tx kvdb.RwTx) error { // Now, load the circuit bucket to which we will write the // already serialized circuit. - keystoneBkt := tx.Bucket(circuitKeystoneKey) + keystoneBkt := tx.ReadWriteBucket(circuitKeystoneKey) if keystoneBkt == nil { return ErrCorruptedCircuitMap } @@ -847,13 +847,13 @@ func (cm *circuitMap) DeleteCircuits(inKeys ...CircuitKey) error { } cm.mtx.Unlock() - err := cm.cfg.DB.Batch(func(tx *bbolt.Tx) error { + err := kvdb.Batch(cm.cfg.DB.Backend, func(tx kvdb.RwTx) error { for _, circuit := range removedCircuits { // If this htlc made it to an outgoing link, load the // keystone bucket from which we will remove the // outgoing circuit key. if circuit.HasKeystone() { - keystoneBkt := tx.Bucket(circuitKeystoneKey) + keystoneBkt := tx.ReadWriteBucket(circuitKeystoneKey) if keystoneBkt == nil { return ErrCorruptedCircuitMap } @@ -868,7 +868,7 @@ func (cm *circuitMap) DeleteCircuits(inKeys ...CircuitKey) error { // Remove the circuit itself based on the incoming // circuit key. - circuitBkt := tx.Bucket(circuitAddKey) + circuitBkt := tx.ReadWriteBucket(circuitAddKey) if circuitBkt == nil { return ErrCorruptedCircuitMap } diff --git a/htlcswitch/decayedlog.go b/htlcswitch/decayedlog.go index 6b3c62b5..3a60e112 100644 --- a/htlcswitch/decayedlog.go +++ b/htlcswitch/decayedlog.go @@ -8,9 +8,9 @@ import ( "sync" "sync/atomic" - "github.com/coreos/bbolt" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/channeldb/kvdb" ) const ( @@ -56,7 +56,7 @@ type DecayedLog struct { dbPath string - db *bbolt.DB + db kvdb.Backend notifier chainntnfs.ChainNotifier @@ -92,7 +92,10 @@ func (d *DecayedLog) Start() error { // Open the boltdb for use. var err error - if d.db, err = bbolt.Open(d.dbPath, dbPermissions, nil); err != nil { + d.db, err = kvdb.Create( + kvdb.BoltBackendName, d.dbPath, true, + ) + if err != nil { return fmt.Errorf("Could not open boltdb: %v", err) } @@ -119,13 +122,13 @@ func (d *DecayedLog) Start() error { // initBuckets initializes the primary buckets used by the decayed log, namely // the shared hash bucket, and batch replay func (d *DecayedLog) initBuckets() error { - return d.db.Update(func(tx *bbolt.Tx) error { - _, err := tx.CreateBucketIfNotExists(sharedHashBucket) + return kvdb.Update(d.db, func(tx kvdb.RwTx) error { + _, err := tx.CreateTopLevelBucket(sharedHashBucket) if err != nil { return ErrDecayedLogInit } - _, err = tx.CreateBucketIfNotExists(batchReplayBucket) + _, err = tx.CreateTopLevelBucket(batchReplayBucket) if err != nil { return ErrDecayedLogInit } @@ -196,11 +199,11 @@ func (d *DecayedLog) garbageCollector(epochClient *chainntnfs.BlockEpochEvent) { func (d *DecayedLog) gcExpiredHashes(height uint32) (uint32, error) { var numExpiredHashes uint32 - err := d.db.Batch(func(tx *bbolt.Tx) error { + err := kvdb.Batch(d.db, func(tx kvdb.RwTx) error { numExpiredHashes = 0 // Grab the shared hash bucket - sharedHashes := tx.Bucket(sharedHashBucket) + sharedHashes := tx.ReadWriteBucket(sharedHashBucket) if sharedHashes == nil { return fmt.Errorf("sharedHashBucket " + "is nil") @@ -246,8 +249,8 @@ func (d *DecayedLog) gcExpiredHashes(height uint32) (uint32, error) { // Delete removes a key-pair from the // sharedHashBucket. func (d *DecayedLog) Delete(hash *sphinx.HashPrefix) error { - return d.db.Batch(func(tx *bbolt.Tx) error { - sharedHashes := tx.Bucket(sharedHashBucket) + return kvdb.Batch(d.db, func(tx kvdb.RwTx) error { + sharedHashes := tx.ReadWriteBucket(sharedHashBucket) if sharedHashes == nil { return ErrDecayedLogCorrupted } @@ -261,10 +264,10 @@ func (d *DecayedLog) Delete(hash *sphinx.HashPrefix) error { func (d *DecayedLog) Get(hash *sphinx.HashPrefix) (uint32, error) { var value uint32 - err := d.db.View(func(tx *bbolt.Tx) error { + err := kvdb.View(d.db, func(tx kvdb.ReadTx) error { // Grab the shared hash bucket which stores the mapping from // truncated sha-256 hashes of shared secrets to CLTV's. - sharedHashes := tx.Bucket(sharedHashBucket) + sharedHashes := tx.ReadBucket(sharedHashBucket) if sharedHashes == nil { return fmt.Errorf("sharedHashes is nil, could " + "not retrieve CLTV value") @@ -294,8 +297,8 @@ func (d *DecayedLog) Put(hash *sphinx.HashPrefix, cltv uint32) error { var scratch [4]byte binary.BigEndian.PutUint32(scratch[:], cltv) - return d.db.Batch(func(tx *bbolt.Tx) error { - sharedHashes := tx.Bucket(sharedHashBucket) + return kvdb.Batch(d.db, func(tx kvdb.RwTx) error { + sharedHashes := tx.ReadWriteBucket(sharedHashBucket) if sharedHashes == nil { return ErrDecayedLogCorrupted } @@ -327,8 +330,8 @@ func (d *DecayedLog) PutBatch(b *sphinx.Batch) (*sphinx.ReplaySet, error) { // to generate the complete replay set. If this batch was previously // processed, the replay set will be deserialized from disk. var replays *sphinx.ReplaySet - if err := d.db.Batch(func(tx *bbolt.Tx) error { - sharedHashes := tx.Bucket(sharedHashBucket) + if err := kvdb.Batch(d.db, func(tx kvdb.RwTx) error { + sharedHashes := tx.ReadWriteBucket(sharedHashBucket) if sharedHashes == nil { return ErrDecayedLogCorrupted } @@ -336,7 +339,7 @@ func (d *DecayedLog) PutBatch(b *sphinx.Batch) (*sphinx.ReplaySet, error) { // Load the batch replay bucket, which will be used to either // retrieve the result of previously processing this batch, or // to write the result of this operation. - batchReplayBkt := tx.Bucket(batchReplayBucket) + batchReplayBkt := tx.ReadWriteBucket(batchReplayBucket) if batchReplayBkt == nil { return ErrDecayedLogCorrupted } diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 0b162cd0..0906141a 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -19,12 +19,12 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" - "github.com/coreos/bbolt" "github.com/davecgh/go-spew/spew" "github.com/go-errors/errors" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/build" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/kvdb" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/htlcswitch/hodl" "github.com/lightningnetwork/lnd/htlcswitch/hop" @@ -5170,32 +5170,32 @@ type mockPackager struct { failLoadFwdPkgs bool } -func (*mockPackager) AddFwdPkg(tx *bbolt.Tx, fwdPkg *channeldb.FwdPkg) error { +func (*mockPackager) AddFwdPkg(tx kvdb.RwTx, fwdPkg *channeldb.FwdPkg) error { return nil } -func (*mockPackager) SetFwdFilter(tx *bbolt.Tx, height uint64, +func (*mockPackager) SetFwdFilter(tx kvdb.RwTx, height uint64, fwdFilter *channeldb.PkgFilter) error { return nil } -func (*mockPackager) AckAddHtlcs(tx *bbolt.Tx, +func (*mockPackager) AckAddHtlcs(tx kvdb.RwTx, addRefs ...channeldb.AddRef) error { return nil } -func (m *mockPackager) LoadFwdPkgs(tx *bbolt.Tx) ([]*channeldb.FwdPkg, error) { +func (m *mockPackager) LoadFwdPkgs(tx kvdb.ReadTx) ([]*channeldb.FwdPkg, error) { if m.failLoadFwdPkgs { return nil, fmt.Errorf("failing LoadFwdPkgs") } return nil, nil } -func (*mockPackager) RemovePkg(tx *bbolt.Tx, height uint64) error { +func (*mockPackager) RemovePkg(tx kvdb.RwTx, height uint64) error { return nil } -func (*mockPackager) AckSettleFails(tx *bbolt.Tx, +func (*mockPackager) AckSettleFails(tx kvdb.RwTx, settleFailRefs ...channeldb.SettleFailRef) error { return nil } diff --git a/htlcswitch/payment_result.go b/htlcswitch/payment_result.go index faf15d84..b23dbe0a 100644 --- a/htlcswitch/payment_result.go +++ b/htlcswitch/payment_result.go @@ -7,8 +7,8 @@ import ( "io" "sync" - "github.com/coreos/bbolt" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/multimutex" ) @@ -137,8 +137,8 @@ func (store *networkResultStore) storeResult(paymentID uint64, var paymentIDBytes [8]byte binary.BigEndian.PutUint64(paymentIDBytes[:], paymentID) - err := store.db.Batch(func(tx *bbolt.Tx) error { - networkResults, err := tx.CreateBucketIfNotExists( + err := kvdb.Batch(store.db.Backend, func(tx kvdb.RwTx) error { + networkResults, err := tx.CreateTopLevelBucket( networkResultStoreBucketKey, ) if err != nil { @@ -180,7 +180,7 @@ func (store *networkResultStore) subscribeResult(paymentID uint64) ( resultChan = make(chan *networkResult, 1) ) - err := store.db.View(func(tx *bbolt.Tx) error { + err := kvdb.View(store.db, func(tx kvdb.ReadTx) error { var err error result, err = fetchResult(tx, paymentID) switch { @@ -226,7 +226,7 @@ func (store *networkResultStore) getResult(pid uint64) ( *networkResult, error) { var result *networkResult - err := store.db.View(func(tx *bbolt.Tx) error { + err := kvdb.View(store.db, func(tx kvdb.ReadTx) error { var err error result, err = fetchResult(tx, pid) return err @@ -238,11 +238,11 @@ func (store *networkResultStore) getResult(pid uint64) ( return result, nil } -func fetchResult(tx *bbolt.Tx, pid uint64) (*networkResult, error) { +func fetchResult(tx kvdb.ReadTx, pid uint64) (*networkResult, error) { var paymentIDBytes [8]byte binary.BigEndian.PutUint64(paymentIDBytes[:], pid) - networkResults := tx.Bucket(networkResultStoreBucketKey) + networkResults := tx.ReadBucket(networkResultStoreBucketKey) if networkResults == nil { return nil, ErrPaymentIDNotFound } diff --git a/htlcswitch/sequencer.go b/htlcswitch/sequencer.go index 3a5247db..5b1526b6 100644 --- a/htlcswitch/sequencer.go +++ b/htlcswitch/sequencer.go @@ -3,9 +3,9 @@ package htlcswitch import ( "sync" - "github.com/coreos/bbolt" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/kvdb" ) // defaultSequenceBatchSize specifies the window of sequence numbers that are @@ -87,8 +87,8 @@ func (s *persistentSequencer) NextID() (uint64, error) { // allocated will start from the last known tip on disk, which is fine // as we only require uniqueness of the allocated numbers. var nextHorizonID uint64 - if err := s.db.Update(func(tx *bbolt.Tx) error { - nextIDBkt := tx.Bucket(nextPaymentIDKey) + if err := kvdb.Update(s.db, func(tx kvdb.RwTx) error { + nextIDBkt := tx.ReadWriteBucket(nextPaymentIDKey) if nextIDBkt == nil { return ErrSequencerCorrupted } @@ -121,8 +121,8 @@ func (s *persistentSequencer) NextID() (uint64, error) { // initDB populates the bucket used to generate payment sequence numbers. func (s *persistentSequencer) initDB() error { - return s.db.Update(func(tx *bbolt.Tx) error { - _, err := tx.CreateBucketIfNotExists(nextPaymentIDKey) + return kvdb.Update(s.db, func(tx kvdb.RwTx) error { + _, err := tx.CreateTopLevelBucket(nextPaymentIDKey) return err }) } diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 602dc55b..9d03b1b7 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -10,10 +10,10 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" - "github.com/coreos/bbolt" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/kvdb" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/lntypes" @@ -1419,7 +1419,7 @@ func (s *Switch) closeCircuit(pkt *htlcPacket) (*PaymentCircuit, error) { // we're the originator of the payment, so the link stops attempting to // re-broadcast. func (s *Switch) ackSettleFail(settleFailRefs ...channeldb.SettleFailRef) error { - return s.cfg.DB.Batch(func(tx *bbolt.Tx) error { + return kvdb.Batch(s.cfg.DB.Backend, func(tx kvdb.RwTx) error { return s.cfg.SwitchPackager.AckSettleFails(tx, settleFailRefs...) }) } @@ -1865,7 +1865,7 @@ func (s *Switch) reforwardResponses() error { func (s *Switch) loadChannelFwdPkgs(source lnwire.ShortChannelID) ([]*channeldb.FwdPkg, error) { var fwdPkgs []*channeldb.FwdPkg - if err := s.cfg.DB.Update(func(tx *bbolt.Tx) error { + if err := kvdb.Update(s.cfg.DB, func(tx kvdb.RwTx) error { var err error fwdPkgs, err = s.cfg.SwitchPackager.LoadChannelFwdPkgs( tx, source, diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index 26b8dad5..f628f280 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -21,10 +21,10 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/btcsuite/fastsha256" - "github.com/coreos/bbolt" "github.com/go-errors/errors" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/kvdb" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/input" @@ -420,7 +420,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, aliceStoredChannels, err := dbAlice.FetchOpenChannels(aliceKeyPub) switch err { case nil: - case bbolt.ErrDatabaseNotOpen: + case kvdb.ErrDatabaseNotOpen: dbAlice, err = channeldb.Open(dbAlice.Path()) if err != nil { return nil, errors.Errorf("unable to reopen alice "+ @@ -464,7 +464,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, bobStoredChannels, err := dbBob.FetchOpenChannels(bobKeyPub) switch err { case nil: - case bbolt.ErrDatabaseNotOpen: + case kvdb.ErrDatabaseNotOpen: dbBob, err = channeldb.Open(dbBob.Path()) if err != nil { return nil, errors.Errorf("unable to reopen bob "+