watchtower: convert to use new kvdb abstraction

This commit is contained in:
Olaoluwa Osuntokun 2020-01-09 18:45:04 -08:00
parent 28bbaa2a94
commit 557b930c5f
No known key found for this signature in database
GPG Key ID: BC13F65E2DC84465
4 changed files with 115 additions and 120 deletions

@ -8,7 +8,7 @@ import (
"net" "net"
"github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/btcec"
"github.com/coreos/bbolt" "github.com/lightningnetwork/lnd/channeldb/kvdb"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
) )
@ -113,7 +113,7 @@ var (
// ClientDB is single database providing a persistent storage engine for the // ClientDB is single database providing a persistent storage engine for the
// wtclient. // wtclient.
type ClientDB struct { type ClientDB struct {
db *bbolt.DB db kvdb.Backend
dbPath string dbPath string
} }
@ -146,7 +146,7 @@ func OpenClientDB(dbPath string) (*ClientDB, error) {
// initialized. This allows us to assume their presence throughout all // initialized. This allows us to assume their presence throughout all
// operations. If an known top-level bucket is expected to exist but is // operations. If an known top-level bucket is expected to exist but is
// missing, this will trigger a ErrUninitializedDB error. // missing, this will trigger a ErrUninitializedDB error.
err = clientDB.db.Update(initClientDBBuckets) err = kvdb.Update(clientDB.db, initClientDBBuckets)
if err != nil { if err != nil {
bdb.Close() bdb.Close()
return nil, err return nil, err
@ -157,7 +157,7 @@ func OpenClientDB(dbPath string) (*ClientDB, error) {
// initClientDBBuckets creates all top-level buckets required to handle database // initClientDBBuckets creates all top-level buckets required to handle database
// operations required by the latest version. // operations required by the latest version.
func initClientDBBuckets(tx *bbolt.Tx) error { func initClientDBBuckets(tx kvdb.RwTx) error {
buckets := [][]byte{ buckets := [][]byte{
cSessionKeyIndexBkt, cSessionKeyIndexBkt,
cChanSummaryBkt, cChanSummaryBkt,
@ -167,7 +167,7 @@ func initClientDBBuckets(tx *bbolt.Tx) error {
} }
for _, bucket := range buckets { for _, bucket := range buckets {
_, err := tx.CreateBucketIfNotExists(bucket) _, err := tx.CreateTopLevelBucket(bucket)
if err != nil { if err != nil {
return err return err
} }
@ -179,7 +179,7 @@ func initClientDBBuckets(tx *bbolt.Tx) error {
// bdb returns the backing bbolt.DB instance. // bdb returns the backing bbolt.DB instance.
// //
// NOTE: Part of the versionedDB interface. // NOTE: Part of the versionedDB interface.
func (c *ClientDB) bdb() *bbolt.DB { func (c *ClientDB) bdb() kvdb.Backend {
return c.db return c.db
} }
@ -188,7 +188,7 @@ func (c *ClientDB) bdb() *bbolt.DB {
// NOTE: Part of the versionedDB interface. // NOTE: Part of the versionedDB interface.
func (c *ClientDB) Version() (uint32, error) { func (c *ClientDB) Version() (uint32, error) {
var version uint32 var version uint32
err := c.db.View(func(tx *bbolt.Tx) error { err := kvdb.View(c.db, func(tx kvdb.ReadTx) error {
var err error var err error
version, err = getDBVersion(tx) version, err = getDBVersion(tx)
return err return err
@ -215,13 +215,13 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
copy(towerPubKey[:], lnAddr.IdentityKey.SerializeCompressed()) copy(towerPubKey[:], lnAddr.IdentityKey.SerializeCompressed())
var tower *Tower var tower *Tower
err := c.db.Update(func(tx *bbolt.Tx) error { err := kvdb.Update(c.db, func(tx kvdb.RwTx) error {
towerIndex := tx.Bucket(cTowerIndexBkt) towerIndex := tx.ReadWriteBucket(cTowerIndexBkt)
if towerIndex == nil { if towerIndex == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
towers := tx.Bucket(cTowerBkt) towers := tx.ReadWriteBucket(cTowerBkt)
if towers == nil { if towers == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
@ -248,7 +248,7 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
// //
// TODO(wilmer): with an index of tower -> sessions we // TODO(wilmer): with an index of tower -> sessions we
// can avoid the linear lookup. // can avoid the linear lookup.
sessions := tx.Bucket(cSessionBkt) sessions := tx.ReadWriteBucket(cSessionBkt)
if sessions == nil { if sessions == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
@ -308,12 +308,12 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
// //
// NOTE: An error is not returned if the tower doesn't exist. // NOTE: An error is not returned if the tower doesn't exist.
func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
return c.db.Update(func(tx *bbolt.Tx) error { return kvdb.Update(c.db, func(tx kvdb.RwTx) error {
towers := tx.Bucket(cTowerBkt) towers := tx.ReadWriteBucket(cTowerBkt)
if towers == nil { if towers == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
towerIndex := tx.Bucket(cTowerIndexBkt) towerIndex := tx.ReadWriteBucket(cTowerIndexBkt)
if towerIndex == nil { if towerIndex == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
@ -342,7 +342,7 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
// //
// TODO(wilmer): with an index of tower -> sessions we can avoid // TODO(wilmer): with an index of tower -> sessions we can avoid
// the linear lookup. // the linear lookup.
sessions := tx.Bucket(cSessionBkt) sessions := tx.ReadWriteBucket(cSessionBkt)
if sessions == nil { if sessions == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
@ -383,8 +383,8 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
// LoadTowerByID retrieves a tower by its tower ID. // LoadTowerByID retrieves a tower by its tower ID.
func (c *ClientDB) LoadTowerByID(towerID TowerID) (*Tower, error) { func (c *ClientDB) LoadTowerByID(towerID TowerID) (*Tower, error) {
var tower *Tower var tower *Tower
err := c.db.View(func(tx *bbolt.Tx) error { err := kvdb.View(c.db, func(tx kvdb.ReadTx) error {
towers := tx.Bucket(cTowerBkt) towers := tx.ReadBucket(cTowerBkt)
if towers == nil { if towers == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
@ -403,12 +403,12 @@ func (c *ClientDB) LoadTowerByID(towerID TowerID) (*Tower, error) {
// LoadTower retrieves a tower by its public key. // LoadTower retrieves a tower by its public key.
func (c *ClientDB) LoadTower(pubKey *btcec.PublicKey) (*Tower, error) { func (c *ClientDB) LoadTower(pubKey *btcec.PublicKey) (*Tower, error) {
var tower *Tower var tower *Tower
err := c.db.View(func(tx *bbolt.Tx) error { err := kvdb.View(c.db, func(tx kvdb.ReadTx) error {
towers := tx.Bucket(cTowerBkt) towers := tx.ReadBucket(cTowerBkt)
if towers == nil { if towers == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
towerIndex := tx.Bucket(cTowerIndexBkt) towerIndex := tx.ReadBucket(cTowerIndexBkt)
if towerIndex == nil { if towerIndex == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
@ -432,8 +432,8 @@ func (c *ClientDB) LoadTower(pubKey *btcec.PublicKey) (*Tower, error) {
// ListTowers retrieves the list of towers available within the database. // ListTowers retrieves the list of towers available within the database.
func (c *ClientDB) ListTowers() ([]*Tower, error) { func (c *ClientDB) ListTowers() ([]*Tower, error) {
var towers []*Tower var towers []*Tower
err := c.db.View(func(tx *bbolt.Tx) error { err := kvdb.View(c.db, func(tx kvdb.ReadTx) error {
towerBucket := tx.Bucket(cTowerBkt) towerBucket := tx.ReadBucket(cTowerBkt)
if towerBucket == nil { if towerBucket == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
@ -461,8 +461,8 @@ func (c *ClientDB) ListTowers() ([]*Tower, error) {
// CreateClientSession is invoked should return the same index. // CreateClientSession is invoked should return the same index.
func (c *ClientDB) NextSessionKeyIndex(towerID TowerID) (uint32, error) { func (c *ClientDB) NextSessionKeyIndex(towerID TowerID) (uint32, error) {
var index uint32 var index uint32
err := c.db.Update(func(tx *bbolt.Tx) error { err := kvdb.Update(c.db, func(tx kvdb.RwTx) error {
keyIndex := tx.Bucket(cSessionKeyIndexBkt) keyIndex := tx.ReadWriteBucket(cSessionKeyIndexBkt)
if keyIndex == nil { if keyIndex == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
@ -509,20 +509,20 @@ func (c *ClientDB) NextSessionKeyIndex(towerID TowerID) (uint32, error) {
// CreateClientSession records a newly negotiated client session in the set of // CreateClientSession records a newly negotiated client session in the set of
// active sessions. The session can be identified by its SessionID. // active sessions. The session can be identified by its SessionID.
func (c *ClientDB) CreateClientSession(session *ClientSession) error { func (c *ClientDB) CreateClientSession(session *ClientSession) error {
return c.db.Update(func(tx *bbolt.Tx) error { return kvdb.Update(c.db, func(tx kvdb.RwTx) error {
keyIndexes := tx.Bucket(cSessionKeyIndexBkt) keyIndexes := tx.ReadWriteBucket(cSessionKeyIndexBkt)
if keyIndexes == nil { if keyIndexes == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
sessions := tx.Bucket(cSessionBkt) sessions := tx.ReadWriteBucket(cSessionBkt)
if sessions == nil { if sessions == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
// Check that client session with this session id doesn't // Check that client session with this session id doesn't
// already exist. // already exist.
existingSessionBytes := sessions.Bucket(session.ID[:]) existingSessionBytes := sessions.NestedReadWriteBucket(session.ID[:])
if existingSessionBytes != nil { if existingSessionBytes != nil {
return ErrClientSessionAlreadyExists return ErrClientSessionAlreadyExists
} }
@ -558,8 +558,8 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error {
// response that do not correspond to this tower. // response that do not correspond to this tower.
func (c *ClientDB) ListClientSessions(id *TowerID) (map[SessionID]*ClientSession, error) { func (c *ClientDB) ListClientSessions(id *TowerID) (map[SessionID]*ClientSession, error) {
var clientSessions map[SessionID]*ClientSession var clientSessions map[SessionID]*ClientSession
err := c.db.View(func(tx *bbolt.Tx) error { err := kvdb.View(c.db, func(tx kvdb.ReadTx) error {
sessions := tx.Bucket(cSessionBkt) sessions := tx.ReadBucket(cSessionBkt)
if sessions == nil { if sessions == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
@ -577,7 +577,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (map[SessionID]*ClientSession
// listClientSessions returns the set of all client sessions known to the db. An // listClientSessions returns the set of all client sessions known to the db. An
// optional tower ID can be used to filter out any client sessions in the // optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower. // response that do not correspond to this tower.
func listClientSessions(sessions *bbolt.Bucket, func listClientSessions(sessions kvdb.ReadBucket,
id *TowerID) (map[SessionID]*ClientSession, error) { id *TowerID) (map[SessionID]*ClientSession, error) {
clientSessions := make(map[SessionID]*ClientSession) clientSessions := make(map[SessionID]*ClientSession)
@ -612,8 +612,8 @@ func listClientSessions(sessions *bbolt.Bucket,
// channel summaries. // channel summaries.
func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) { func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) {
summaries := make(map[lnwire.ChannelID]ClientChanSummary) summaries := make(map[lnwire.ChannelID]ClientChanSummary)
err := c.db.View(func(tx *bbolt.Tx) error { err := kvdb.View(c.db, func(tx kvdb.ReadTx) error {
chanSummaries := tx.Bucket(cChanSummaryBkt) chanSummaries := tx.ReadBucket(cChanSummaryBkt)
if chanSummaries == nil { if chanSummaries == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
@ -648,8 +648,8 @@ func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) {
func (c *ClientDB) RegisterChannel(chanID lnwire.ChannelID, func (c *ClientDB) RegisterChannel(chanID lnwire.ChannelID,
sweepPkScript []byte) error { sweepPkScript []byte) error {
return c.db.Update(func(tx *bbolt.Tx) error { return kvdb.Update(c.db, func(tx kvdb.RwTx) error {
chanSummaries := tx.Bucket(cChanSummaryBkt) chanSummaries := tx.ReadWriteBucket(cChanSummaryBkt)
if chanSummaries == nil { if chanSummaries == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
@ -692,8 +692,8 @@ func (c *ClientDB) CommitUpdate(id *SessionID,
update *CommittedUpdate) (uint16, error) { update *CommittedUpdate) (uint16, error) {
var lastApplied uint16 var lastApplied uint16
err := c.db.Update(func(tx *bbolt.Tx) error { err := kvdb.Update(c.db, func(tx kvdb.RwTx) error {
sessions := tx.Bucket(cSessionBkt) sessions := tx.ReadWriteBucket(cSessionBkt)
if sessions == nil { if sessions == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
@ -708,7 +708,7 @@ func (c *ClientDB) CommitUpdate(id *SessionID,
} }
// Can't fail if the above didn't fail. // Can't fail if the above didn't fail.
sessionBkt := sessions.Bucket(id[:]) sessionBkt := sessions.NestedReadWriteBucket(id[:])
// Ensure the session commits sub-bucket is initialized. // Ensure the session commits sub-bucket is initialized.
sessionCommits, err := sessionBkt.CreateBucketIfNotExists( sessionCommits, err := sessionBkt.CreateBucketIfNotExists(
@ -796,8 +796,8 @@ func (c *ClientDB) CommitUpdate(id *SessionID,
func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16, func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
lastApplied uint16) error { lastApplied uint16) error {
return c.db.Update(func(tx *bbolt.Tx) error { return kvdb.Update(c.db, func(tx kvdb.RwTx) error {
sessions := tx.Bucket(cSessionBkt) sessions := tx.ReadWriteBucket(cSessionBkt)
if sessions == nil { if sessions == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
@ -835,11 +835,11 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
} }
// Can't fail because of getClientSession succeeded. // Can't fail because of getClientSession succeeded.
sessionBkt := sessions.Bucket(id[:]) sessionBkt := sessions.NestedReadWriteBucket(id[:])
// If the commits sub-bucket doesn't exist, there can't possibly // If the commits sub-bucket doesn't exist, there can't possibly
// be a corresponding committed update to remove. // be a corresponding committed update to remove.
sessionCommits := sessionBkt.Bucket(cSessionCommits) sessionCommits := sessionBkt.NestedReadWriteBucket(cSessionCommits)
if sessionCommits == nil { if sessionCommits == nil {
return ErrCommittedUpdateNotFound return ErrCommittedUpdateNotFound
} }
@ -894,10 +894,10 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
// bucket corresponding to the serialized session id. This does not deserialize // bucket corresponding to the serialized session id. This does not deserialize
// the CommittedUpdates or AckUpdates associated with the session. If the caller // the CommittedUpdates or AckUpdates associated with the session. If the caller
// requires this info, use getClientSession. // requires this info, use getClientSession.
func getClientSessionBody(sessions *bbolt.Bucket, func getClientSessionBody(sessions kvdb.ReadBucket,
idBytes []byte) (*ClientSession, error) { idBytes []byte) (*ClientSession, error) {
sessionBkt := sessions.Bucket(idBytes) sessionBkt := sessions.NestedReadBucket(idBytes)
if sessionBkt == nil { if sessionBkt == nil {
return nil, ErrClientSessionNotFound return nil, ErrClientSessionNotFound
} }
@ -922,7 +922,7 @@ func getClientSessionBody(sessions *bbolt.Bucket,
// getClientSession loads the full ClientSession associated with the serialized // getClientSession loads the full ClientSession associated with the serialized
// session id. This method populates the CommittedUpdates and AckUpdates in // session id. This method populates the CommittedUpdates and AckUpdates in
// addition to the ClientSession's body. // addition to the ClientSession's body.
func getClientSession(sessions *bbolt.Bucket, func getClientSession(sessions kvdb.ReadBucket,
idBytes []byte) (*ClientSession, error) { idBytes []byte) (*ClientSession, error) {
session, err := getClientSessionBody(sessions, idBytes) session, err := getClientSessionBody(sessions, idBytes)
@ -950,17 +950,17 @@ func getClientSession(sessions *bbolt.Bucket,
// getClientSessionCommits retrieves all committed updates for the session // getClientSessionCommits retrieves all committed updates for the session
// identified by the serialized session id. // identified by the serialized session id.
func getClientSessionCommits(sessions *bbolt.Bucket, func getClientSessionCommits(sessions kvdb.ReadBucket,
idBytes []byte) ([]CommittedUpdate, error) { idBytes []byte) ([]CommittedUpdate, error) {
// Can't fail because client session body has already been read. // Can't fail because client session body has already been read.
sessionBkt := sessions.Bucket(idBytes) sessionBkt := sessions.NestedReadBucket(idBytes)
// Initialize commitedUpdates so that we can return an initialized map // Initialize commitedUpdates so that we can return an initialized map
// if no committed updates exist. // if no committed updates exist.
committedUpdates := make([]CommittedUpdate, 0) committedUpdates := make([]CommittedUpdate, 0)
sessionCommits := sessionBkt.Bucket(cSessionCommits) sessionCommits := sessionBkt.NestedReadBucket(cSessionCommits)
if sessionCommits == nil { if sessionCommits == nil {
return committedUpdates, nil return committedUpdates, nil
} }
@ -986,17 +986,17 @@ func getClientSessionCommits(sessions *bbolt.Bucket,
// getClientSessionAcks retrieves all acked updates for the session identified // getClientSessionAcks retrieves all acked updates for the session identified
// by the serialized session id. // by the serialized session id.
func getClientSessionAcks(sessions *bbolt.Bucket, func getClientSessionAcks(sessions kvdb.ReadBucket,
idBytes []byte) (map[uint16]BackupID, error) { idBytes []byte) (map[uint16]BackupID, error) {
// Can't fail because client session body has already been read. // Can't fail because client session body has already been read.
sessionBkt := sessions.Bucket(idBytes) sessionBkt := sessions.NestedReadBucket(idBytes)
// Initialize ackedUpdates so that we can return an initialized map if // Initialize ackedUpdates so that we can return an initialized map if
// no acked updates exist. // no acked updates exist.
ackedUpdates := make(map[uint16]BackupID) ackedUpdates := make(map[uint16]BackupID)
sessionAcks := sessionBkt.Bucket(cSessionAcks) sessionAcks := sessionBkt.NestedReadBucket(cSessionAcks)
if sessionAcks == nil { if sessionAcks == nil {
return ackedUpdates, nil return ackedUpdates, nil
} }
@ -1023,7 +1023,7 @@ func getClientSessionAcks(sessions *bbolt.Bucket,
// putClientSessionBody stores the body of the ClientSession (everything but the // putClientSessionBody stores the body of the ClientSession (everything but the
// CommittedUpdates and AckedUpdates). // CommittedUpdates and AckedUpdates).
func putClientSessionBody(sessions *bbolt.Bucket, func putClientSessionBody(sessions kvdb.RwBucket,
session *ClientSession) error { session *ClientSession) error {
sessionBkt, err := sessions.CreateBucketIfNotExists(session.ID[:]) sessionBkt, err := sessions.CreateBucketIfNotExists(session.ID[:])
@ -1042,7 +1042,7 @@ func putClientSessionBody(sessions *bbolt.Bucket,
// markSessionStatus updates the persisted state of the session to the new // markSessionStatus updates the persisted state of the session to the new
// status. // status.
func markSessionStatus(sessions *bbolt.Bucket, session *ClientSession, func markSessionStatus(sessions kvdb.RwBucket, session *ClientSession,
status CSessionStatus) error { status CSessionStatus) error {
session.Status = status session.Status = status
@ -1050,7 +1050,7 @@ func markSessionStatus(sessions *bbolt.Bucket, session *ClientSession,
} }
// getChanSummary loads a ClientChanSummary for the passed chanID. // getChanSummary loads a ClientChanSummary for the passed chanID.
func getChanSummary(chanSummaries *bbolt.Bucket, func getChanSummary(chanSummaries kvdb.ReadBucket,
chanID lnwire.ChannelID) (*ClientChanSummary, error) { chanID lnwire.ChannelID) (*ClientChanSummary, error) {
chanSummaryBytes := chanSummaries.Get(chanID[:]) chanSummaryBytes := chanSummaries.Get(chanID[:])
@ -1068,7 +1068,7 @@ func getChanSummary(chanSummaries *bbolt.Bucket,
} }
// putChanSummary stores a ClientChanSummary for the passed chanID. // putChanSummary stores a ClientChanSummary for the passed chanID.
func putChanSummary(chanSummaries *bbolt.Bucket, chanID lnwire.ChannelID, func putChanSummary(chanSummaries kvdb.RwBucket, chanID lnwire.ChannelID,
summary *ClientChanSummary) error { summary *ClientChanSummary) error {
var b bytes.Buffer var b bytes.Buffer
@ -1081,7 +1081,7 @@ func putChanSummary(chanSummaries *bbolt.Bucket, chanID lnwire.ChannelID,
} }
// getTower loads a Tower identified by its serialized tower id. // getTower loads a Tower identified by its serialized tower id.
func getTower(towers *bbolt.Bucket, id []byte) (*Tower, error) { func getTower(towers kvdb.ReadBucket, id []byte) (*Tower, error) {
towerBytes := towers.Get(id) towerBytes := towers.Get(id)
if towerBytes == nil { if towerBytes == nil {
return nil, ErrTowerNotFound return nil, ErrTowerNotFound
@ -1099,7 +1099,7 @@ func getTower(towers *bbolt.Bucket, id []byte) (*Tower, error) {
} }
// putTower stores a Tower identified by its serialized tower id. // putTower stores a Tower identified by its serialized tower id.
func putTower(towers *bbolt.Bucket, tower *Tower) error { func putTower(towers kvdb.RwBucket, tower *Tower) error {
var b bytes.Buffer var b bytes.Buffer
err := tower.Encode(&b) err := tower.Encode(&b)
if err != nil { if err != nil {

@ -6,7 +6,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"github.com/coreos/bbolt" "github.com/lightningnetwork/lnd/channeldb/kvdb"
) )
const ( const (
@ -49,7 +49,7 @@ func fileExists(path string) bool {
// one doesn't exist. The boolean returned indicates if the database did not // one doesn't exist. The boolean returned indicates if the database did not
// exist before, or if it has been created but no version metadata exists within // exist before, or if it has been created but no version metadata exists within
// it. // it.
func createDBIfNotExist(dbPath, name string) (*bbolt.DB, bool, error) { func createDBIfNotExist(dbPath, name string) (kvdb.Backend, bool, error) {
path := filepath.Join(dbPath, name) path := filepath.Join(dbPath, name)
// If the database file doesn't exist, this indicates we much initialize // If the database file doesn't exist, this indicates we much initialize
@ -65,12 +65,7 @@ func createDBIfNotExist(dbPath, name string) (*bbolt.DB, bool, error) {
// Specify bbolt freelist options to reduce heap pressure in case the // Specify bbolt freelist options to reduce heap pressure in case the
// freelist grows to be very large. // freelist grows to be very large.
options := &bbolt.Options{ bdb, err := kvdb.Create(kvdb.BoltBackendName, path, true)
NoFreelistSync: true,
FreelistType: bbolt.FreelistMapType,
}
bdb, err := bbolt.Open(path, dbFilePermission, options)
if err != nil { if err != nil {
return nil, false, err return nil, false, err
} }
@ -82,8 +77,8 @@ func createDBIfNotExist(dbPath, name string) (*bbolt.DB, bool, error) {
// set firstInit to true so that we can treat is initialize the bucket. // set firstInit to true so that we can treat is initialize the bucket.
if !firstInit { if !firstInit {
var metadataExists bool var metadataExists bool
err = bdb.View(func(tx *bbolt.Tx) error { err = kvdb.View(bdb, func(tx kvdb.ReadTx) error {
metadataExists = tx.Bucket(metadataBkt) != nil metadataExists = tx.ReadBucket(metadataBkt) != nil
return nil return nil
}) })
if err != nil { if err != nil {

@ -5,8 +5,8 @@ import (
"errors" "errors"
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/coreos/bbolt"
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb/kvdb"
"github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/blob"
) )
@ -55,7 +55,7 @@ var (
// TowerDB is single database providing a persistent storage engine for the // TowerDB is single database providing a persistent storage engine for the
// wtserver and lookout subsystems. // wtserver and lookout subsystems.
type TowerDB struct { type TowerDB struct {
db *bbolt.DB db kvdb.Backend
dbPath string dbPath string
} }
@ -88,7 +88,7 @@ func OpenTowerDB(dbPath string) (*TowerDB, error) {
// initialized. This allows us to assume their presence throughout all // initialized. This allows us to assume their presence throughout all
// operations. If an known top-level bucket is expected to exist but is // operations. If an known top-level bucket is expected to exist but is
// missing, this will trigger a ErrUninitializedDB error. // missing, this will trigger a ErrUninitializedDB error.
err = towerDB.db.Update(initTowerDBBuckets) err = kvdb.Update(towerDB.db, initTowerDBBuckets)
if err != nil { if err != nil {
bdb.Close() bdb.Close()
return nil, err return nil, err
@ -99,7 +99,7 @@ func OpenTowerDB(dbPath string) (*TowerDB, error) {
// initTowerDBBuckets creates all top-level buckets required to handle database // initTowerDBBuckets creates all top-level buckets required to handle database
// operations required by the latest version. // operations required by the latest version.
func initTowerDBBuckets(tx *bbolt.Tx) error { func initTowerDBBuckets(tx kvdb.RwTx) error {
buckets := [][]byte{ buckets := [][]byte{
sessionsBkt, sessionsBkt,
updateIndexBkt, updateIndexBkt,
@ -108,7 +108,7 @@ func initTowerDBBuckets(tx *bbolt.Tx) error {
} }
for _, bucket := range buckets { for _, bucket := range buckets {
_, err := tx.CreateBucketIfNotExists(bucket) _, err := tx.CreateTopLevelBucket(bucket)
if err != nil { if err != nil {
return err return err
} }
@ -120,7 +120,7 @@ func initTowerDBBuckets(tx *bbolt.Tx) error {
// bdb returns the backing bbolt.DB instance. // bdb returns the backing bbolt.DB instance.
// //
// NOTE: Part of the versionedDB interface. // NOTE: Part of the versionedDB interface.
func (t *TowerDB) bdb() *bbolt.DB { func (t *TowerDB) bdb() kvdb.Backend {
return t.db return t.db
} }
@ -129,7 +129,7 @@ func (t *TowerDB) bdb() *bbolt.DB {
// NOTE: Part of the versionedDB interface. // NOTE: Part of the versionedDB interface.
func (t *TowerDB) Version() (uint32, error) { func (t *TowerDB) Version() (uint32, error) {
var version uint32 var version uint32
err := t.db.View(func(tx *bbolt.Tx) error { err := kvdb.View(t.db, func(tx kvdb.ReadTx) error {
var err error var err error
version, err = getDBVersion(tx) version, err = getDBVersion(tx)
return err return err
@ -150,8 +150,8 @@ func (t *TowerDB) Close() error {
// returned if the session could not be found. // returned if the session could not be found.
func (t *TowerDB) GetSessionInfo(id *SessionID) (*SessionInfo, error) { func (t *TowerDB) GetSessionInfo(id *SessionID) (*SessionInfo, error) {
var session *SessionInfo var session *SessionInfo
err := t.db.View(func(tx *bbolt.Tx) error { err := kvdb.View(t.db, func(tx kvdb.ReadTx) error {
sessions := tx.Bucket(sessionsBkt) sessions := tx.ReadBucket(sessionsBkt)
if sessions == nil { if sessions == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
@ -170,13 +170,13 @@ func (t *TowerDB) GetSessionInfo(id *SessionID) (*SessionInfo, error) {
// InsertSessionInfo records a negotiated session in the tower database. An // InsertSessionInfo records a negotiated session in the tower database. An
// error is returned if the session already exists. // error is returned if the session already exists.
func (t *TowerDB) InsertSessionInfo(session *SessionInfo) error { func (t *TowerDB) InsertSessionInfo(session *SessionInfo) error {
return t.db.Update(func(tx *bbolt.Tx) error { return kvdb.Update(t.db, func(tx kvdb.RwTx) error {
sessions := tx.Bucket(sessionsBkt) sessions := tx.ReadWriteBucket(sessionsBkt)
if sessions == nil { if sessions == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
updateIndex := tx.Bucket(updateIndexBkt) updateIndex := tx.ReadWriteBucket(updateIndexBkt)
if updateIndex == nil { if updateIndex == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
@ -219,18 +219,18 @@ func (t *TowerDB) InsertSessionInfo(session *SessionInfo) error {
// properly and the last applied values echoed by the client are sane. // properly and the last applied values echoed by the client are sane.
func (t *TowerDB) InsertStateUpdate(update *SessionStateUpdate) (uint16, error) { func (t *TowerDB) InsertStateUpdate(update *SessionStateUpdate) (uint16, error) {
var lastApplied uint16 var lastApplied uint16
err := t.db.Update(func(tx *bbolt.Tx) error { err := kvdb.Update(t.db, func(tx kvdb.RwTx) error {
sessions := tx.Bucket(sessionsBkt) sessions := tx.ReadWriteBucket(sessionsBkt)
if sessions == nil { if sessions == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
updates := tx.Bucket(updatesBkt) updates := tx.ReadWriteBucket(updatesBkt)
if updates == nil { if updates == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
updateIndex := tx.Bucket(updateIndexBkt) updateIndex := tx.ReadWriteBucket(updateIndexBkt)
if updateIndex == nil { if updateIndex == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
@ -303,18 +303,18 @@ func (t *TowerDB) InsertStateUpdate(update *SessionStateUpdate) (uint16, error)
// DeleteSession removes all data associated with a particular session id from // DeleteSession removes all data associated with a particular session id from
// the tower's database. // the tower's database.
func (t *TowerDB) DeleteSession(target SessionID) error { func (t *TowerDB) DeleteSession(target SessionID) error {
return t.db.Update(func(tx *bbolt.Tx) error { return kvdb.Update(t.db, func(tx kvdb.RwTx) error {
sessions := tx.Bucket(sessionsBkt) sessions := tx.ReadWriteBucket(sessionsBkt)
if sessions == nil { if sessions == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
updates := tx.Bucket(updatesBkt) updates := tx.ReadWriteBucket(updatesBkt)
if updates == nil { if updates == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
updateIndex := tx.Bucket(updateIndexBkt) updateIndex := tx.ReadWriteBucket(updateIndexBkt)
if updateIndex == nil { if updateIndex == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
@ -341,7 +341,7 @@ func (t *TowerDB) DeleteSession(target SessionID) error {
for _, hint := range hints { for _, hint := range hints {
// Remove the state updates for any blobs stored under // Remove the state updates for any blobs stored under
// the target session identifier. // the target session identifier.
updatesForHint := updates.Bucket(hint[:]) updatesForHint := updates.NestedReadWriteBucket(hint[:])
if updatesForHint == nil { if updatesForHint == nil {
continue continue
} }
@ -371,7 +371,7 @@ func (t *TowerDB) DeleteSession(target SessionID) error {
// No more updates for this hint, prune hint bucket. // No more updates for this hint, prune hint bucket.
default: default:
err = updates.DeleteBucket(hint[:]) err = updates.DeleteNestedBucket(hint[:])
if err != nil { if err != nil {
return err return err
} }
@ -389,13 +389,13 @@ func (t *TowerDB) DeleteSession(target SessionID) error {
// they exist in the database. // they exist in the database.
func (t *TowerDB) QueryMatches(breachHints []blob.BreachHint) ([]Match, error) { func (t *TowerDB) QueryMatches(breachHints []blob.BreachHint) ([]Match, error) {
var matches []Match var matches []Match
err := t.db.View(func(tx *bbolt.Tx) error { err := kvdb.View(t.db, func(tx kvdb.ReadTx) error {
sessions := tx.Bucket(sessionsBkt) sessions := tx.ReadBucket(sessionsBkt)
if sessions == nil { if sessions == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
updates := tx.Bucket(updatesBkt) updates := tx.ReadBucket(updatesBkt)
if updates == nil { if updates == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
@ -405,7 +405,7 @@ func (t *TowerDB) QueryMatches(breachHints []blob.BreachHint) ([]Match, error) {
for _, hint := range breachHints { for _, hint := range breachHints {
// If a bucket does not exist for this hint, no matches // If a bucket does not exist for this hint, no matches
// are known. // are known.
updatesForHint := updates.Bucket(hint[:]) updatesForHint := updates.NestedReadBucket(hint[:])
if updatesForHint == nil { if updatesForHint == nil {
continue continue
} }
@ -471,8 +471,8 @@ func (t *TowerDB) QueryMatches(breachHints []blob.BreachHint) ([]Match, error) {
// SetLookoutTip stores the provided epoch as the latest lookout tip epoch in // SetLookoutTip stores the provided epoch as the latest lookout tip epoch in
// the tower database. // the tower database.
func (t *TowerDB) SetLookoutTip(epoch *chainntnfs.BlockEpoch) error { func (t *TowerDB) SetLookoutTip(epoch *chainntnfs.BlockEpoch) error {
return t.db.Update(func(tx *bbolt.Tx) error { return kvdb.Update(t.db, func(tx kvdb.RwTx) error {
lookoutTip := tx.Bucket(lookoutTipBkt) lookoutTip := tx.ReadWriteBucket(lookoutTipBkt)
if lookoutTip == nil { if lookoutTip == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
@ -485,8 +485,8 @@ func (t *TowerDB) SetLookoutTip(epoch *chainntnfs.BlockEpoch) error {
// database. // database.
func (t *TowerDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) { func (t *TowerDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) {
var epoch *chainntnfs.BlockEpoch var epoch *chainntnfs.BlockEpoch
err := t.db.View(func(tx *bbolt.Tx) error { err := kvdb.View(t.db, func(tx kvdb.ReadTx) error {
lookoutTip := tx.Bucket(lookoutTipBkt) lookoutTip := tx.ReadBucket(lookoutTipBkt)
if lookoutTip == nil { if lookoutTip == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
@ -505,7 +505,7 @@ func (t *TowerDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) {
// getSession retrieves the session info from the sessions bucket identified by // getSession retrieves the session info from the sessions bucket identified by
// its session id. An error is returned if the session is not found or a // its session id. An error is returned if the session is not found or a
// deserialization error occurs. // deserialization error occurs.
func getSession(sessions *bbolt.Bucket, id []byte) (*SessionInfo, error) { func getSession(sessions kvdb.ReadBucket, id []byte) (*SessionInfo, error) {
sessionBytes := sessions.Get(id) sessionBytes := sessions.Get(id)
if sessionBytes == nil { if sessionBytes == nil {
return nil, ErrSessionNotFound return nil, ErrSessionNotFound
@ -522,7 +522,7 @@ func getSession(sessions *bbolt.Bucket, id []byte) (*SessionInfo, error) {
// putSession stores the session info in the sessions bucket identified by its // putSession stores the session info in the sessions bucket identified by its
// session id. An error is returned if a serialization error occurs. // session id. An error is returned if a serialization error occurs.
func putSession(sessions *bbolt.Bucket, session *SessionInfo) error { func putSession(sessions kvdb.RwBucket, session *SessionInfo) error {
var b bytes.Buffer var b bytes.Buffer
err := session.Encode(&b) err := session.Encode(&b)
if err != nil { if err != nil {
@ -536,7 +536,7 @@ func putSession(sessions *bbolt.Bucket, session *SessionInfo) error {
// session id. This ensures that future calls to getHintsForSession or // session id. This ensures that future calls to getHintsForSession or
// putHintForSession can rely on the bucket already being created, and fail if // putHintForSession can rely on the bucket already being created, and fail if
// index has not been initialized as this points to improper usage. // index has not been initialized as this points to improper usage.
func touchSessionHintBkt(updateIndex *bbolt.Bucket, id *SessionID) error { func touchSessionHintBkt(updateIndex kvdb.RwBucket, id *SessionID) error {
_, err := updateIndex.CreateBucketIfNotExists(id[:]) _, err := updateIndex.CreateBucketIfNotExists(id[:])
return err return err
} }
@ -544,17 +544,17 @@ func touchSessionHintBkt(updateIndex *bbolt.Bucket, id *SessionID) error {
// removeSessionHintBkt prunes the session-hint bucket for the given session id // removeSessionHintBkt prunes the session-hint bucket for the given session id
// and all of the hints contained inside. This should be used to clean up the // and all of the hints contained inside. This should be used to clean up the
// index upon session deletion. // index upon session deletion.
func removeSessionHintBkt(updateIndex *bbolt.Bucket, id *SessionID) error { func removeSessionHintBkt(updateIndex kvdb.RwBucket, id *SessionID) error {
return updateIndex.DeleteBucket(id[:]) return updateIndex.DeleteNestedBucket(id[:])
} }
// getHintsForSession returns all known hints belonging to the given session id. // getHintsForSession returns all known hints belonging to the given session id.
// If the index for the session has not been initialized, this method returns // If the index for the session has not been initialized, this method returns
// ErrNoSessionHintIndex. // ErrNoSessionHintIndex.
func getHintsForSession(updateIndex *bbolt.Bucket, func getHintsForSession(updateIndex kvdb.ReadBucket,
id *SessionID) ([]blob.BreachHint, error) { id *SessionID) ([]blob.BreachHint, error) {
sessionHints := updateIndex.Bucket(id[:]) sessionHints := updateIndex.NestedReadBucket(id[:])
if sessionHints == nil { if sessionHints == nil {
return nil, ErrNoSessionHintIndex return nil, ErrNoSessionHintIndex
} }
@ -582,10 +582,10 @@ func getHintsForSession(updateIndex *bbolt.Bucket,
// session id, and used to perform efficient removal of updates. If the index // session id, and used to perform efficient removal of updates. If the index
// for the session has not been initialized, this method returns // for the session has not been initialized, this method returns
// ErrNoSessionHintIndex. // ErrNoSessionHintIndex.
func putHintForSession(updateIndex *bbolt.Bucket, id *SessionID, func putHintForSession(updateIndex kvdb.RwBucket, id *SessionID,
hint blob.BreachHint) error { hint blob.BreachHint) error {
sessionHints := updateIndex.Bucket(id[:]) sessionHints := updateIndex.NestedReadWriteBucket(id[:])
if sessionHints == nil { if sessionHints == nil {
return ErrNoSessionHintIndex return ErrNoSessionHintIndex
} }
@ -594,7 +594,7 @@ func putHintForSession(updateIndex *bbolt.Bucket, id *SessionID,
} }
// putLookoutEpoch stores the given lookout tip block epoch in provided bucket. // putLookoutEpoch stores the given lookout tip block epoch in provided bucket.
func putLookoutEpoch(bkt *bbolt.Bucket, epoch *chainntnfs.BlockEpoch) error { func putLookoutEpoch(bkt kvdb.RwBucket, epoch *chainntnfs.BlockEpoch) error {
epochBytes := make([]byte, 36) epochBytes := make([]byte, 36)
copy(epochBytes, epoch.Hash[:]) copy(epochBytes, epoch.Hash[:])
byteOrder.PutUint32(epochBytes[32:], uint32(epoch.Height)) byteOrder.PutUint32(epochBytes[32:], uint32(epoch.Height))
@ -604,7 +604,7 @@ func putLookoutEpoch(bkt *bbolt.Bucket, epoch *chainntnfs.BlockEpoch) error {
// getLookoutEpoch retrieves the lookout tip block epoch from the given bucket. // getLookoutEpoch retrieves the lookout tip block epoch from the given bucket.
// A nil epoch is returned if no update exists. // A nil epoch is returned if no update exists.
func getLookoutEpoch(bkt *bbolt.Bucket) *chainntnfs.BlockEpoch { func getLookoutEpoch(bkt kvdb.ReadBucket) *chainntnfs.BlockEpoch {
epochBytes := bkt.Get(lookoutTipKey) epochBytes := bkt.Get(lookoutTipKey)
if len(epochBytes) != 36 { if len(epochBytes) != 36 {
return nil return nil
@ -625,7 +625,7 @@ func getLookoutEpoch(bkt *bbolt.Bucket) *chainntnfs.BlockEpoch {
var errBucketNotEmpty = errors.New("bucket not empty") var errBucketNotEmpty = errors.New("bucket not empty")
// isBucketEmpty returns errBucketNotEmpty if the bucket is not empty. // isBucketEmpty returns errBucketNotEmpty if the bucket is not empty.
func isBucketEmpty(bkt *bbolt.Bucket) error { func isBucketEmpty(bkt kvdb.ReadBucket) error {
return bkt.ForEach(func(_, _ []byte) error { return bkt.ForEach(func(_, _ []byte) error {
return errBucketNotEmpty return errBucketNotEmpty
}) })

@ -1,14 +1,14 @@
package wtdb package wtdb
import ( import (
"github.com/coreos/bbolt"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/kvdb"
) )
// migration is a function which takes a prior outdated version of the database // migration is a function which takes a prior outdated version of the database
// instances and mutates the key/bucket structure to arrive at a more // instances and mutates the key/bucket structure to arrive at a more
// up-to-date version of the database. // up-to-date version of the database.
type migration func(tx *bbolt.Tx) error type migration func(tx kvdb.RwTx) error
// version pairs a version number with the migration that would need to be // version pairs a version number with the migration that would need to be
// applied from the prior version to upgrade. // applied from the prior version to upgrade.
@ -46,8 +46,8 @@ func getMigrations(versions []version, curVersion uint32) []version {
// getDBVersion retrieves the current database version from the metadata bucket // getDBVersion retrieves the current database version from the metadata bucket
// using the dbVersionKey. // using the dbVersionKey.
func getDBVersion(tx *bbolt.Tx) (uint32, error) { func getDBVersion(tx kvdb.ReadTx) (uint32, error) {
metadata := tx.Bucket(metadataBkt) metadata := tx.ReadBucket(metadataBkt)
if metadata == nil { if metadata == nil {
return 0, ErrUninitializedDB return 0, ErrUninitializedDB
} }
@ -62,8 +62,8 @@ func getDBVersion(tx *bbolt.Tx) (uint32, error) {
// initDBVersion initializes the top-level metadata bucket and writes the passed // initDBVersion initializes the top-level metadata bucket and writes the passed
// version number as the current version. // version number as the current version.
func initDBVersion(tx *bbolt.Tx, version uint32) error { func initDBVersion(tx kvdb.RwTx, version uint32) error {
_, err := tx.CreateBucketIfNotExists(metadataBkt) _, err := tx.CreateTopLevelBucket(metadataBkt)
if err != nil { if err != nil {
return err return err
} }
@ -73,8 +73,8 @@ func initDBVersion(tx *bbolt.Tx, version uint32) error {
// putDBVersion stores the passed database version in the metadata bucket under // putDBVersion stores the passed database version in the metadata bucket under
// the dbVersionKey. // the dbVersionKey.
func putDBVersion(tx *bbolt.Tx, version uint32) error { func putDBVersion(tx kvdb.RwTx, version uint32) error {
metadata := tx.Bucket(metadataBkt) metadata := tx.ReadWriteBucket(metadataBkt)
if metadata == nil { if metadata == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
@ -89,7 +89,7 @@ func putDBVersion(tx *bbolt.Tx, version uint32) error {
// on either. // on either.
type versionedDB interface { type versionedDB interface {
// bdb returns the underlying bbolt database. // bdb returns the underlying bbolt database.
bdb() *bbolt.DB bdb() kvdb.Backend
// Version returns the current version stored in the database. // Version returns the current version stored in the database.
Version() (uint32, error) Version() (uint32, error)
@ -105,7 +105,7 @@ func initOrSyncVersions(db versionedDB, init bool, versions []version) error {
// If the database has not yet been created, we'll initialize the // If the database has not yet been created, we'll initialize the
// database version with the latest known version. // database version with the latest known version.
if init { if init {
return db.bdb().Update(func(tx *bbolt.Tx) error { return kvdb.Update(db.bdb(), func(tx kvdb.RwTx) error {
return initDBVersion(tx, getLatestDBVersion(versions)) return initDBVersion(tx, getLatestDBVersion(versions))
}) })
} }
@ -141,7 +141,7 @@ func syncVersions(db versionedDB, versions []version) error {
// Otherwise, apply any migrations in order to bring the database // Otherwise, apply any migrations in order to bring the database
// version up to the highest known version. // version up to the highest known version.
updates := getMigrations(versions, curVersion) updates := getMigrations(versions, curVersion)
return db.bdb().Update(func(tx *bbolt.Tx) error { return kvdb.Update(db.bdb(), func(tx kvdb.RwTx) error {
for i, update := range updates { for i, update := range updates {
if update.migration == nil { if update.migration == nil {
continue continue