htlcswitch: convert to use new kvdb abstraction

This commit is contained in:
Olaoluwa Osuntokun 2020-01-09 18:44:27 -08:00
parent 320101d054
commit 4e68914e9d
No known key found for this signature in database
GPG Key ID: BC13F65E2DC84465
7 changed files with 61 additions and 58 deletions

@ -5,10 +5,10 @@ import (
"fmt"
"sync"
"github.com/coreos/bbolt"
"github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/kvdb"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lnwire"
)
@ -213,13 +213,13 @@ func NewCircuitMap(cfg *CircuitMapConfig) (CircuitMap, error) {
// initBuckets ensures that the primary buckets used by the circuit are
// initialized so that we can assume their existence after startup.
func (cm *circuitMap) initBuckets() error {
return cm.cfg.DB.Update(func(tx *bbolt.Tx) error {
_, err := tx.CreateBucketIfNotExists(circuitKeystoneKey)
return kvdb.Update(cm.cfg.DB, func(tx kvdb.RwTx) error {
_, err := tx.CreateTopLevelBucket(circuitKeystoneKey)
if err != nil {
return err
}
_, err = tx.CreateBucketIfNotExists(circuitAddKey)
_, err = tx.CreateTopLevelBucket(circuitAddKey)
return err
})
}
@ -238,10 +238,10 @@ func (cm *circuitMap) restoreMemState() error {
pending = make(map[CircuitKey]*PaymentCircuit)
)
if err := cm.cfg.DB.Update(func(tx *bbolt.Tx) error {
if err := kvdb.Update(cm.cfg.DB, func(tx kvdb.RwTx) error {
// Restore any of the circuits persisted in the circuit bucket
// back into memory.
circuitBkt := tx.Bucket(circuitAddKey)
circuitBkt := tx.ReadWriteBucket(circuitAddKey)
if circuitBkt == nil {
return ErrCorruptedCircuitMap
}
@ -262,7 +262,7 @@ func (cm *circuitMap) restoreMemState() error {
// Furthermore, load the keystone bucket and resurrect the
// keystones used in any open circuits.
keystoneBkt := tx.Bucket(circuitKeystoneKey)
keystoneBkt := tx.ReadWriteBucket(circuitKeystoneKey)
if keystoneBkt == nil {
return ErrCorruptedCircuitMap
}
@ -463,8 +463,8 @@ func (cm *circuitMap) TrimOpenCircuits(chanID lnwire.ShortChannelID,
return nil
}
return cm.cfg.DB.Update(func(tx *bbolt.Tx) error {
keystoneBkt := tx.Bucket(circuitKeystoneKey)
return kvdb.Update(cm.cfg.DB, func(tx kvdb.RwTx) error {
keystoneBkt := tx.ReadWriteBucket(circuitKeystoneKey)
if keystoneBkt == nil {
return ErrCorruptedCircuitMap
}
@ -616,8 +616,8 @@ func (cm *circuitMap) CommitCircuits(circuits ...*PaymentCircuit) (
// Write the entire batch of circuits to the persistent circuit bucket
// using bolt's Batch write. This method must be called from multiple,
// distinct goroutines to have any impact on performance.
err := cm.cfg.DB.Batch(func(tx *bbolt.Tx) error {
circuitBkt := tx.Bucket(circuitAddKey)
err := kvdb.Batch(cm.cfg.DB.Backend, func(tx kvdb.RwTx) error {
circuitBkt := tx.ReadWriteBucket(circuitAddKey)
if circuitBkt == nil {
return ErrCorruptedCircuitMap
}
@ -706,10 +706,10 @@ func (cm *circuitMap) OpenCircuits(keystones ...Keystone) error {
}
cm.mtx.RUnlock()
err := cm.cfg.DB.Update(func(tx *bbolt.Tx) error {
err := kvdb.Update(cm.cfg.DB, func(tx kvdb.RwTx) error {
// Now, load the circuit bucket to which we will write the
// already serialized circuit.
keystoneBkt := tx.Bucket(circuitKeystoneKey)
keystoneBkt := tx.ReadWriteBucket(circuitKeystoneKey)
if keystoneBkt == nil {
return ErrCorruptedCircuitMap
}
@ -847,13 +847,13 @@ func (cm *circuitMap) DeleteCircuits(inKeys ...CircuitKey) error {
}
cm.mtx.Unlock()
err := cm.cfg.DB.Batch(func(tx *bbolt.Tx) error {
err := kvdb.Batch(cm.cfg.DB.Backend, func(tx kvdb.RwTx) error {
for _, circuit := range removedCircuits {
// If this htlc made it to an outgoing link, load the
// keystone bucket from which we will remove the
// outgoing circuit key.
if circuit.HasKeystone() {
keystoneBkt := tx.Bucket(circuitKeystoneKey)
keystoneBkt := tx.ReadWriteBucket(circuitKeystoneKey)
if keystoneBkt == nil {
return ErrCorruptedCircuitMap
}
@ -868,7 +868,7 @@ func (cm *circuitMap) DeleteCircuits(inKeys ...CircuitKey) error {
// Remove the circuit itself based on the incoming
// circuit key.
circuitBkt := tx.Bucket(circuitAddKey)
circuitBkt := tx.ReadWriteBucket(circuitAddKey)
if circuitBkt == nil {
return ErrCorruptedCircuitMap
}

@ -8,9 +8,9 @@ import (
"sync"
"sync/atomic"
"github.com/coreos/bbolt"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb/kvdb"
)
const (
@ -56,7 +56,7 @@ type DecayedLog struct {
dbPath string
db *bbolt.DB
db kvdb.Backend
notifier chainntnfs.ChainNotifier
@ -92,7 +92,10 @@ func (d *DecayedLog) Start() error {
// Open the boltdb for use.
var err error
if d.db, err = bbolt.Open(d.dbPath, dbPermissions, nil); err != nil {
d.db, err = kvdb.Create(
kvdb.BoltBackendName, d.dbPath, true,
)
if err != nil {
return fmt.Errorf("Could not open boltdb: %v", err)
}
@ -119,13 +122,13 @@ func (d *DecayedLog) Start() error {
// initBuckets initializes the primary buckets used by the decayed log, namely
// the shared hash bucket, and batch replay
func (d *DecayedLog) initBuckets() error {
return d.db.Update(func(tx *bbolt.Tx) error {
_, err := tx.CreateBucketIfNotExists(sharedHashBucket)
return kvdb.Update(d.db, func(tx kvdb.RwTx) error {
_, err := tx.CreateTopLevelBucket(sharedHashBucket)
if err != nil {
return ErrDecayedLogInit
}
_, err = tx.CreateBucketIfNotExists(batchReplayBucket)
_, err = tx.CreateTopLevelBucket(batchReplayBucket)
if err != nil {
return ErrDecayedLogInit
}
@ -196,11 +199,11 @@ func (d *DecayedLog) garbageCollector(epochClient *chainntnfs.BlockEpochEvent) {
func (d *DecayedLog) gcExpiredHashes(height uint32) (uint32, error) {
var numExpiredHashes uint32
err := d.db.Batch(func(tx *bbolt.Tx) error {
err := kvdb.Batch(d.db, func(tx kvdb.RwTx) error {
numExpiredHashes = 0
// Grab the shared hash bucket
sharedHashes := tx.Bucket(sharedHashBucket)
sharedHashes := tx.ReadWriteBucket(sharedHashBucket)
if sharedHashes == nil {
return fmt.Errorf("sharedHashBucket " +
"is nil")
@ -246,8 +249,8 @@ func (d *DecayedLog) gcExpiredHashes(height uint32) (uint32, error) {
// Delete removes a <shared secret hash, CLTV> key-pair from the
// sharedHashBucket.
func (d *DecayedLog) Delete(hash *sphinx.HashPrefix) error {
return d.db.Batch(func(tx *bbolt.Tx) error {
sharedHashes := tx.Bucket(sharedHashBucket)
return kvdb.Batch(d.db, func(tx kvdb.RwTx) error {
sharedHashes := tx.ReadWriteBucket(sharedHashBucket)
if sharedHashes == nil {
return ErrDecayedLogCorrupted
}
@ -261,10 +264,10 @@ func (d *DecayedLog) Delete(hash *sphinx.HashPrefix) error {
func (d *DecayedLog) Get(hash *sphinx.HashPrefix) (uint32, error) {
var value uint32
err := d.db.View(func(tx *bbolt.Tx) error {
err := kvdb.View(d.db, func(tx kvdb.ReadTx) error {
// Grab the shared hash bucket which stores the mapping from
// truncated sha-256 hashes of shared secrets to CLTV's.
sharedHashes := tx.Bucket(sharedHashBucket)
sharedHashes := tx.ReadBucket(sharedHashBucket)
if sharedHashes == nil {
return fmt.Errorf("sharedHashes is nil, could " +
"not retrieve CLTV value")
@ -294,8 +297,8 @@ func (d *DecayedLog) Put(hash *sphinx.HashPrefix, cltv uint32) error {
var scratch [4]byte
binary.BigEndian.PutUint32(scratch[:], cltv)
return d.db.Batch(func(tx *bbolt.Tx) error {
sharedHashes := tx.Bucket(sharedHashBucket)
return kvdb.Batch(d.db, func(tx kvdb.RwTx) error {
sharedHashes := tx.ReadWriteBucket(sharedHashBucket)
if sharedHashes == nil {
return ErrDecayedLogCorrupted
}
@ -327,8 +330,8 @@ func (d *DecayedLog) PutBatch(b *sphinx.Batch) (*sphinx.ReplaySet, error) {
// to generate the complete replay set. If this batch was previously
// processed, the replay set will be deserialized from disk.
var replays *sphinx.ReplaySet
if err := d.db.Batch(func(tx *bbolt.Tx) error {
sharedHashes := tx.Bucket(sharedHashBucket)
if err := kvdb.Batch(d.db, func(tx kvdb.RwTx) error {
sharedHashes := tx.ReadWriteBucket(sharedHashBucket)
if sharedHashes == nil {
return ErrDecayedLogCorrupted
}
@ -336,7 +339,7 @@ func (d *DecayedLog) PutBatch(b *sphinx.Batch) (*sphinx.ReplaySet, error) {
// Load the batch replay bucket, which will be used to either
// retrieve the result of previously processing this batch, or
// to write the result of this operation.
batchReplayBkt := tx.Bucket(batchReplayBucket)
batchReplayBkt := tx.ReadWriteBucket(batchReplayBucket)
if batchReplayBkt == nil {
return ErrDecayedLogCorrupted
}

@ -19,12 +19,12 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil"
"github.com/coreos/bbolt"
"github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/build"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/kvdb"
"github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/htlcswitch/hodl"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
@ -5170,32 +5170,32 @@ type mockPackager struct {
failLoadFwdPkgs bool
}
func (*mockPackager) AddFwdPkg(tx *bbolt.Tx, fwdPkg *channeldb.FwdPkg) error {
func (*mockPackager) AddFwdPkg(tx kvdb.RwTx, fwdPkg *channeldb.FwdPkg) error {
return nil
}
func (*mockPackager) SetFwdFilter(tx *bbolt.Tx, height uint64,
func (*mockPackager) SetFwdFilter(tx kvdb.RwTx, height uint64,
fwdFilter *channeldb.PkgFilter) error {
return nil
}
func (*mockPackager) AckAddHtlcs(tx *bbolt.Tx,
func (*mockPackager) AckAddHtlcs(tx kvdb.RwTx,
addRefs ...channeldb.AddRef) error {
return nil
}
func (m *mockPackager) LoadFwdPkgs(tx *bbolt.Tx) ([]*channeldb.FwdPkg, error) {
func (m *mockPackager) LoadFwdPkgs(tx kvdb.ReadTx) ([]*channeldb.FwdPkg, error) {
if m.failLoadFwdPkgs {
return nil, fmt.Errorf("failing LoadFwdPkgs")
}
return nil, nil
}
func (*mockPackager) RemovePkg(tx *bbolt.Tx, height uint64) error {
func (*mockPackager) RemovePkg(tx kvdb.RwTx, height uint64) error {
return nil
}
func (*mockPackager) AckSettleFails(tx *bbolt.Tx,
func (*mockPackager) AckSettleFails(tx kvdb.RwTx,
settleFailRefs ...channeldb.SettleFailRef) error {
return nil
}

@ -7,8 +7,8 @@ import (
"io"
"sync"
"github.com/coreos/bbolt"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/multimutex"
)
@ -137,8 +137,8 @@ func (store *networkResultStore) storeResult(paymentID uint64,
var paymentIDBytes [8]byte
binary.BigEndian.PutUint64(paymentIDBytes[:], paymentID)
err := store.db.Batch(func(tx *bbolt.Tx) error {
networkResults, err := tx.CreateBucketIfNotExists(
err := kvdb.Batch(store.db.Backend, func(tx kvdb.RwTx) error {
networkResults, err := tx.CreateTopLevelBucket(
networkResultStoreBucketKey,
)
if err != nil {
@ -180,7 +180,7 @@ func (store *networkResultStore) subscribeResult(paymentID uint64) (
resultChan = make(chan *networkResult, 1)
)
err := store.db.View(func(tx *bbolt.Tx) error {
err := kvdb.View(store.db, func(tx kvdb.ReadTx) error {
var err error
result, err = fetchResult(tx, paymentID)
switch {
@ -226,7 +226,7 @@ func (store *networkResultStore) getResult(pid uint64) (
*networkResult, error) {
var result *networkResult
err := store.db.View(func(tx *bbolt.Tx) error {
err := kvdb.View(store.db, func(tx kvdb.ReadTx) error {
var err error
result, err = fetchResult(tx, pid)
return err
@ -238,11 +238,11 @@ func (store *networkResultStore) getResult(pid uint64) (
return result, nil
}
func fetchResult(tx *bbolt.Tx, pid uint64) (*networkResult, error) {
func fetchResult(tx kvdb.ReadTx, pid uint64) (*networkResult, error) {
var paymentIDBytes [8]byte
binary.BigEndian.PutUint64(paymentIDBytes[:], pid)
networkResults := tx.Bucket(networkResultStoreBucketKey)
networkResults := tx.ReadBucket(networkResultStoreBucketKey)
if networkResults == nil {
return nil, ErrPaymentIDNotFound
}

@ -3,9 +3,9 @@ package htlcswitch
import (
"sync"
"github.com/coreos/bbolt"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/kvdb"
)
// defaultSequenceBatchSize specifies the window of sequence numbers that are
@ -87,8 +87,8 @@ func (s *persistentSequencer) NextID() (uint64, error) {
// allocated will start from the last known tip on disk, which is fine
// as we only require uniqueness of the allocated numbers.
var nextHorizonID uint64
if err := s.db.Update(func(tx *bbolt.Tx) error {
nextIDBkt := tx.Bucket(nextPaymentIDKey)
if err := kvdb.Update(s.db, func(tx kvdb.RwTx) error {
nextIDBkt := tx.ReadWriteBucket(nextPaymentIDKey)
if nextIDBkt == nil {
return ErrSequencerCorrupted
}
@ -121,8 +121,8 @@ func (s *persistentSequencer) NextID() (uint64, error) {
// initDB populates the bucket used to generate payment sequence numbers.
func (s *persistentSequencer) initDB() error {
return s.db.Update(func(tx *bbolt.Tx) error {
_, err := tx.CreateBucketIfNotExists(nextPaymentIDKey)
return kvdb.Update(s.db, func(tx kvdb.RwTx) error {
_, err := tx.CreateTopLevelBucket(nextPaymentIDKey)
return err
})
}

@ -10,10 +10,10 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil"
"github.com/coreos/bbolt"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/kvdb"
"github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lntypes"
@ -1419,7 +1419,7 @@ func (s *Switch) closeCircuit(pkt *htlcPacket) (*PaymentCircuit, error) {
// we're the originator of the payment, so the link stops attempting to
// re-broadcast.
func (s *Switch) ackSettleFail(settleFailRefs ...channeldb.SettleFailRef) error {
return s.cfg.DB.Batch(func(tx *bbolt.Tx) error {
return kvdb.Batch(s.cfg.DB.Backend, func(tx kvdb.RwTx) error {
return s.cfg.SwitchPackager.AckSettleFails(tx, settleFailRefs...)
})
}
@ -1865,7 +1865,7 @@ func (s *Switch) reforwardResponses() error {
func (s *Switch) loadChannelFwdPkgs(source lnwire.ShortChannelID) ([]*channeldb.FwdPkg, error) {
var fwdPkgs []*channeldb.FwdPkg
if err := s.cfg.DB.Update(func(tx *bbolt.Tx) error {
if err := kvdb.Update(s.cfg.DB, func(tx kvdb.RwTx) error {
var err error
fwdPkgs, err = s.cfg.SwitchPackager.LoadChannelFwdPkgs(
tx, source,

@ -21,10 +21,10 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil"
"github.com/btcsuite/fastsha256"
"github.com/coreos/bbolt"
"github.com/go-errors/errors"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/kvdb"
"github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/input"
@ -420,7 +420,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
aliceStoredChannels, err := dbAlice.FetchOpenChannels(aliceKeyPub)
switch err {
case nil:
case bbolt.ErrDatabaseNotOpen:
case kvdb.ErrDatabaseNotOpen:
dbAlice, err = channeldb.Open(dbAlice.Path())
if err != nil {
return nil, errors.Errorf("unable to reopen alice "+
@ -464,7 +464,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
bobStoredChannels, err := dbBob.FetchOpenChannels(bobKeyPub)
switch err {
case nil:
case bbolt.ErrDatabaseNotOpen:
case kvdb.ErrDatabaseNotOpen:
dbBob, err = channeldb.Open(dbBob.Path())
if err != nil {
return nil, errors.Errorf("unable to reopen bob "+