From ec7c16fdc1c9409bdb691c9a702d06d72d6f9fc2 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Thu, 23 May 2019 20:47:08 -0700 Subject: [PATCH 01/11] watchtower/wtdb: prepare for addition of client db This commit renames the variables dbName to towerDBName and dbVersions to towerDBVersions, to distinguish between the upcoming clientDBName clientDBVersions. We also move resusable portions of the database initialization and default endianness to its own file so that it can be shared between both tower and client databases. --- watchtower/wtdb/db_common.go | 92 ++++++++++++++++++++++ watchtower/wtdb/tower_db.go | 147 ++++------------------------------- watchtower/wtdb/version.go | 104 +++++++++++++++++++++---- 3 files changed, 195 insertions(+), 148 deletions(-) create mode 100644 watchtower/wtdb/db_common.go diff --git a/watchtower/wtdb/db_common.go b/watchtower/wtdb/db_common.go new file mode 100644 index 00000000..63bca9ae --- /dev/null +++ b/watchtower/wtdb/db_common.go @@ -0,0 +1,92 @@ +package wtdb + +import ( + "encoding/binary" + "errors" + "os" + "path/filepath" + + "github.com/coreos/bbolt" +) + +const ( + // dbFilePermission requests read+write access to the db file. + dbFilePermission = 0600 +) + +var ( + // metadataBkt stores all the meta information concerning the state of + // the database. + metadataBkt = []byte("metadata-bucket") + + // dbVersionKey is a static key used to retrieve the database version + // number from the metadataBkt. + dbVersionKey = []byte("version") + + // ErrUninitializedDB signals that top-level buckets for the database + // have not been initialized. + ErrUninitializedDB = errors.New("db not initialized") + + // ErrNoDBVersion signals that the database contains no version info. + ErrNoDBVersion = errors.New("db has no version") + + // byteOrder is the default endianness used when serializing integers. + byteOrder = binary.BigEndian +) + +// fileExists returns true if the file exists, and false otherwise. +func fileExists(path string) bool { + if _, err := os.Stat(path); err != nil { + if os.IsNotExist(err) { + return false + } + } + + return true +} + +// createDBIfNotExist opens the boltdb database at dbPath/name, creating one if +// one doesn't exist. The boolean returned indicates if the database did not +// exist before, or if it has been created but no version metadata exists within +// it. +func createDBIfNotExist(dbPath, name string) (*bbolt.DB, bool, error) { + path := filepath.Join(dbPath, name) + + // If the database file doesn't exist, this indicates we much initialize + // a fresh database with the latest version. + firstInit := !fileExists(path) + if firstInit { + // Ensure all parent directories are initialized. + err := os.MkdirAll(dbPath, 0700) + if err != nil { + return nil, false, err + } + } + + bdb, err := bbolt.Open(path, dbFilePermission, nil) + if err != nil { + return nil, false, err + } + + // If the file existed previously, we'll now check to see that the + // metadata bucket is properly initialized. It could be the case that + // the database was created, but we failed to actually populate any + // metadata. If the metadata bucket does not actually exist, we'll + // set firstInit to true so that we can treat is initialize the bucket. + if !firstInit { + var metadataExists bool + err = bdb.View(func(tx *bbolt.Tx) error { + metadataExists = tx.Bucket(metadataBkt) != nil + return nil + }) + if err != nil { + return nil, false, err + } + + if !metadataExists { + firstInit = true + } + } + + return bdb, firstInit, nil +} diff --git a/watchtower/wtdb/tower_db.go b/watchtower/wtdb/tower_db.go index 0bcd271c..96edafca 100644 --- a/watchtower/wtdb/tower_db.go +++ b/watchtower/wtdb/tower_db.go @@ -2,23 +2,16 @@ package wtdb import ( "bytes" - "encoding/binary" "errors" - "os" - "path/filepath" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/coreos/bbolt" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/channeldb" ) const ( - // dbName is the filename of tower database. - dbName = "watchtower.db" - - // dbFilePermission requests read+write access to the db file. - dbFilePermission = 0600 + // towerDBName is the filename of tower database. + towerDBName = "watchtower.db" ) var ( @@ -49,26 +42,9 @@ var ( // epoch from the lookoutTipBkt. lookoutTipKey = []byte("lookout-tip") - // metadataBkt stores all the meta information concerning the state of - // the database. - metadataBkt = []byte("metadata-bucket") - - // dbVersionKey is a static key used to retrieve the database version - // number from the metadataBkt. - dbVersionKey = []byte("version") - - // ErrUninitializedDB signals that top-level buckets for the database - // have not been initialized. - ErrUninitializedDB = errors.New("tower db not initialized") - - // ErrNoDBVersion signals that the database contains no version info. - ErrNoDBVersion = errors.New("tower db has no version") - // ErrNoSessionHintIndex signals that an active session does not have an // initialized index for tracking its own state updates. ErrNoSessionHintIndex = errors.New("session hint index missing") - - byteOrder = binary.BigEndian ) // TowerDB is single database providing a persistent storage engine for the @@ -86,67 +62,20 @@ type TowerDB struct { // with a version number higher that the latest version will fail to prevent // accidental reversion. func OpenTowerDB(dbPath string) (*TowerDB, error) { - path := filepath.Join(dbPath, dbName) - - // If the database file doesn't exist, this indicates we much initialize - // a fresh database with the latest version. - firstInit := !fileExists(path) - if firstInit { - // Ensure all parent directories are initialized. - err := os.MkdirAll(dbPath, 0700) - if err != nil { - return nil, err - } - } - - bdb, err := bbolt.Open(path, dbFilePermission, nil) + bdb, firstInit, err := createDBIfNotExist(dbPath, towerDBName) if err != nil { return nil, err } - // If the file existed previously, we'll now check to see that the - // metadata bucket is properly initialized. It could be the case that - // the database was created, but we failed to actually populate any - // metadata. If the metadata bucket does not actually exist, we'll - // set firstInit to true so that we can treat is initialize the bucket. - if !firstInit { - var metadataExists bool - err = bdb.View(func(tx *bbolt.Tx) error { - metadataExists = tx.Bucket(metadataBkt) != nil - return nil - }) - if err != nil { - return nil, err - } - - if !metadataExists { - firstInit = true - } - } - towerDB := &TowerDB{ db: bdb, dbPath: dbPath, } - if firstInit { - // If the database has not yet been created, we'll initialize - // the database version with the latest known version. - err = towerDB.db.Update(func(tx *bbolt.Tx) error { - return initDBVersion(tx, getLatestDBVersion(dbVersions)) - }) - if err != nil { - bdb.Close() - return nil, err - } - } else { - // Otherwise, ensure that any migrations are applied to ensure - // the data is in the format expected by the latest version. - err = towerDB.syncVersions(dbVersions) - if err != nil { - bdb.Close() - return nil, err - } + err = initOrSyncVersions(towerDB, firstInit, towerDBVersions) + if err != nil { + bdb.Close() + return nil, err } // Now that the database version fully consistent with our latest known @@ -163,17 +92,6 @@ func OpenTowerDB(dbPath string) (*TowerDB, error) { return towerDB, nil } -// fileExists returns true if the file exists, and false otherwise. -func fileExists(path string) bool { - if _, err := os.Stat(path); err != nil { - if os.IsNotExist(err) { - return false - } - } - - return true -} - // initTowerDBBuckets creates all top-level buckets required to handle database // operations required by the latest version. func initTowerDBBuckets(tx *bbolt.Tx) error { @@ -194,53 +112,16 @@ func initTowerDBBuckets(tx *bbolt.Tx) error { return nil } -// syncVersions ensures the database version is consistent with the highest -// known database version, applying any migrations that have not been made. If -// the highest known version number is lower than the database's version, this -// method will fail to prevent accidental reversions. -func (t *TowerDB) syncVersions(versions []version) error { - curVersion, err := t.Version() - if err != nil { - return err - } - - latestVersion := getLatestDBVersion(versions) - switch { - - // Current version is higher than any known version, fail to prevent - // reversion. - case curVersion > latestVersion: - return channeldb.ErrDBReversion - - // Current version matches highest known version, nothing to do. - case curVersion == latestVersion: - return nil - } - - // Otherwise, apply any migrations in order to bring the database - // version up to the highest known version. - updates := getMigrations(versions, curVersion) - return t.db.Update(func(tx *bbolt.Tx) error { - for _, update := range updates { - if update.migration == nil { - continue - } - - log.Infof("Applying migration #%d", update.number) - - err := update.migration(tx) - if err != nil { - log.Errorf("Unable to apply migration #%d: %v", - err) - return err - } - } - - return putDBVersion(tx, latestVersion) - }) +// bdb returns the backing bbolt.DB instance. +// +// NOTE: Part of the versionedDB interface. +func (t *TowerDB) bdb() *bbolt.DB { + return t.db } // Version returns the database's current version number. +// +// NOTE: Part of the versionedDB interface. func (t *TowerDB) Version() (uint32, error) { var version uint32 err := t.db.View(func(tx *bbolt.Tx) error { diff --git a/watchtower/wtdb/version.go b/watchtower/wtdb/version.go index fd7481af..974f25b0 100644 --- a/watchtower/wtdb/version.go +++ b/watchtower/wtdb/version.go @@ -1,6 +1,9 @@ package wtdb -import "github.com/coreos/bbolt" +import ( + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/channeldb" +) // migration is a function which takes a prior outdated version of the database // instances and mutates the key/bucket structure to arrive at a more @@ -10,32 +13,25 @@ type migration func(tx *bbolt.Tx) error // version pairs a version number with the migration that would need to be // applied from the prior version to upgrade. type version struct { - number uint32 migration migration } -// dbVersions stores all versions and migrations of the database. This list will -// be used when opening the database to determine if any migrations must be -// applied. -var dbVersions = []version{ - { - // Initial version requires no migration. - number: 0, - migration: nil, - }, -} +// towerDBVersions stores all versions and migrations of the tower database. +// This list will be used when opening the database to determine if any +// migrations must be applied. +var towerDBVersions = []version{} // getLatestDBVersion returns the last known database version. func getLatestDBVersion(versions []version) uint32 { - return versions[len(versions)-1].number + return uint32(len(versions)) } // getMigrations returns a slice of all updates with a greater number that // curVersion that need to be applied to sync up with the latest version. func getMigrations(versions []version, curVersion uint32) []version { var updates []version - for _, v := range versions { - if v.number > curVersion { + for i, v := range versions { + if uint32(i)+1 > curVersion { updates = append(updates, v) } } @@ -82,3 +78,81 @@ func putDBVersion(tx *bbolt.Tx, version uint32) error { byteOrder.PutUint32(versionBytes, version) return metadata.Put(dbVersionKey, versionBytes) } + +// versionedDB is a private interface implemented by both the tower and client +// databases, permitting all versioning operations to be performed generically +// on either. +type versionedDB interface { + // bdb returns the underlying bbolt database. + bdb() *bbolt.DB + + // Version returns the current version stored in the database. + Version() (uint32, error) +} + +// initOrSyncVersions ensures that the database version is properly set before +// opening the database up for regular use. When the database is being +// initialized for the first time, the caller should set init to true, which +// will simply write the latest version to the database. Otherwise, passing init +// as false will cause the database to apply any needed migrations to ensure its +// version matches the latest version in the provided versions list. +func initOrSyncVersions(db versionedDB, init bool, versions []version) error { + // If the database has not yet been created, we'll initialize the + // database version with the latest known version. + if init { + return db.bdb().Update(func(tx *bbolt.Tx) error { + return initDBVersion(tx, getLatestDBVersion(versions)) + }) + } + + // Otherwise, ensure that any migrations are applied to ensure the data + // is in the format expected by the latest version. + return syncVersions(db, versions) +} + +// syncVersions ensures the database version is consistent with the highest +// known database version, applying any migrations that have not been made. If +// the highest known version number is lower than the database's version, this +// method will fail to prevent accidental reversions. +func syncVersions(db versionedDB, versions []version) error { + curVersion, err := db.Version() + if err != nil { + return err + } + + latestVersion := getLatestDBVersion(versions) + switch { + + // Current version is higher than any known version, fail to prevent + // reversion. + case curVersion > latestVersion: + return channeldb.ErrDBReversion + + // Current version matches highest known version, nothing to do. + case curVersion == latestVersion: + return nil + } + + // Otherwise, apply any migrations in order to bring the database + // version up to the highest known version. + updates := getMigrations(versions, curVersion) + return db.bdb().Update(func(tx *bbolt.Tx) error { + for i, update := range updates { + if update.migration == nil { + continue + } + + version := curVersion + uint32(i) + 1 + log.Infof("Applying migration #%d", version) + + err := update.migration(tx) + if err != nil { + log.Errorf("Unable to apply migration #%d: %v", + version, err) + return err + } + } + + return putDBVersion(tx, latestVersion) + }) +} From 3509c0c991eab897aa5fcead81530d8541b6f6e2 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Thu, 23 May 2019 20:47:22 -0700 Subject: [PATCH 02/11] watchtower/multi: use proper TowerID type This allows serialization methods to be added with TowerID method receivers. --- watchtower/wtclient/interface.go | 4 ++-- watchtower/wtdb/client_session.go | 2 +- watchtower/wtdb/tower.go | 21 ++++++++++++++++++++- watchtower/wtmock/client_db.go | 20 ++++++++++---------- 4 files changed, 33 insertions(+), 14 deletions(-) diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index 5aef8619..1d091d50 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -21,7 +21,7 @@ type DB interface { CreateTower(*lnwire.NetAddress) (*wtdb.Tower, error) // LoadTower retrieves a tower by its tower ID. - LoadTower(uint64) (*wtdb.Tower, error) + LoadTower(wtdb.TowerID) (*wtdb.Tower, error) // NextSessionKeyIndex reserves a new session key derivation index for a // particular tower id. The index is reserved for that tower until @@ -29,7 +29,7 @@ type DB interface { // point a new index for that tower can be reserved. Multiple calls to // this method before CreateClientSession is invoked should return the // same index. - NextSessionKeyIndex(uint64) (uint32, error) + NextSessionKeyIndex(wtdb.TowerID) (uint32, error) // CreateClientSession saves a newly negotiated client session to the // client's database. This enables the session to be used across diff --git a/watchtower/wtdb/client_session.go b/watchtower/wtdb/client_session.go index 43df8bb9..ded54d8a 100644 --- a/watchtower/wtdb/client_session.go +++ b/watchtower/wtdb/client_session.go @@ -57,7 +57,7 @@ type ClientSession struct { // TowerID is the unique, db-assigned identifier that references the // Tower with which the session is negotiated. - TowerID uint64 + TowerID TowerID // Tower holds the pubkey and address of the watchtower. // diff --git a/watchtower/wtdb/tower.go b/watchtower/wtdb/tower.go index ff7a48df..e4f28781 100644 --- a/watchtower/wtdb/tower.go +++ b/watchtower/wtdb/tower.go @@ -15,12 +15,31 @@ var ( ErrTowerNotFound = errors.New("tower not found") ) +// TowerID is a unique 64-bit identifier allocated to each unique watchtower. +// This allows the client to conserve on-disk space by not needing to always +// reference towers by their pubkey. +type TowerID uint64 + +// TowerIDFromBytes constructs a TowerID from the provided byte slice. The +// argument must have at least 8 bytes, and should contain the TowerID in +// big-endian byte order. +func TowerIDFromBytes(towerIDBytes []byte) TowerID { + return TowerID(byteOrder.Uint64(towerIDBytes)) +} + +// Bytes encodes a TowerID into an 8-byte slice in big-endian byte order. +func (id TowerID) Bytes() []byte { + var buf [8]byte + byteOrder.PutUint64(buf[:], uint64(id)) + return buf[:] +} + // Tower holds the necessary components required to connect to a remote tower. // Communication is handled by brontide, and requires both a public key and an // address. type Tower struct { // ID is a unique ID for this record assigned by the database. - ID uint64 + ID TowerID // IdentityKey is the public key of the remote node, used to // authenticate the brontide transport. diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index a075e7d9..267a290e 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -20,11 +20,11 @@ type ClientDB struct { mu sync.Mutex sweepPkScripts map[lnwire.ChannelID][]byte activeSessions map[wtdb.SessionID]*wtdb.ClientSession - towerIndex map[towerPK]uint64 - towers map[uint64]*wtdb.Tower + towerIndex map[towerPK]wtdb.TowerID + towers map[wtdb.TowerID]*wtdb.Tower nextIndex uint32 - indexes map[uint64]uint32 + indexes map[wtdb.TowerID]uint32 } // NewClientDB initializes a new mock ClientDB. @@ -32,9 +32,9 @@ func NewClientDB() *ClientDB { return &ClientDB{ sweepPkScripts: make(map[lnwire.ChannelID][]byte), activeSessions: make(map[wtdb.SessionID]*wtdb.ClientSession), - towerIndex: make(map[towerPK]uint64), - towers: make(map[uint64]*wtdb.Tower), - indexes: make(map[uint64]uint32), + towerIndex: make(map[towerPK]wtdb.TowerID), + towers: make(map[wtdb.TowerID]*wtdb.Tower), + indexes: make(map[wtdb.TowerID]uint32), } } @@ -54,9 +54,9 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) { tower = m.towers[towerID] tower.AddAddress(lnAddr.Address) } else { - towerID = atomic.AddUint64(&m.nextTowerID, 1) + towerID = wtdb.TowerID(atomic.AddUint64(&m.nextTowerID, 1)) tower = &wtdb.Tower{ - ID: towerID, + ID: wtdb.TowerID(towerID), IdentityKey: lnAddr.IdentityKey, Addresses: []net.Addr{lnAddr.Address}, } @@ -69,7 +69,7 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) { } // LoadTower retrieves a tower by its tower ID. -func (m *ClientDB) LoadTower(towerID uint64) (*wtdb.Tower, error) { +func (m *ClientDB) LoadTower(towerID wtdb.TowerID) (*wtdb.Tower, error) { m.mu.Lock() defer m.mu.Unlock() @@ -141,7 +141,7 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { // CreateClientSession is invoked for that tower and index, at which point a new // index for that tower can be reserved. Multiple calls to this method before // CreateClientSession is invoked should return the same index. -func (m *ClientDB) NextSessionKeyIndex(towerID uint64) (uint32, error) { +func (m *ClientDB) NextSessionKeyIndex(towerID wtdb.TowerID) (uint32, error) { m.mu.Lock() defer m.mu.Unlock() From 5ad9530502862679d6af1dd7d25a3bedd48308bb Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Thu, 23 May 2019 20:47:36 -0700 Subject: [PATCH 03/11] watchtower/wtdb: return sorted ClientSession.CommittedUpdates This commit replaces the map-based CommittedUpdates field with a slice. When reading from disk, these will already be sorted by bbolt, so the client restore the updates as presented without needing to sort them first. Since the key in the map variant was the sequence number, we refactor the CommittedUpdate struct to have a sequence number and an embedded CommittedUpdateBody (which is equivalent to the old CommittedUpdate). The database is then expected to populate the sequence number from the key on disk. Since the sequence number is now directly integrated in the CommittedUpdate struct, this allow allows us to remove the now redundant seqNum argument from CommitUpdate. --- watchtower/wtclient/interface.go | 2 +- watchtower/wtclient/session_queue.go | 68 +++++++--------------------- watchtower/wtdb/client_session.go | 23 ++++++++-- watchtower/wtmock/client_db.go | 62 ++++++++++++++----------- 4 files changed, 71 insertions(+), 84 deletions(-) diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index 1d091d50..4de81acb 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -61,7 +61,7 @@ type DB interface { // hasn't been ACK'd by the tower. The sequence number of the update // should be exactly one greater than the existing entry, and less that // or equal to the session's MaxUpdates. - CommitUpdate(id *wtdb.SessionID, seqNum uint16, + CommitUpdate(id *wtdb.SessionID, update *wtdb.CommittedUpdate) (uint16, error) // AckUpdate records an acknowledgment from the watchtower that the diff --git a/watchtower/wtclient/session_queue.go b/watchtower/wtclient/session_queue.go index baf260aa..258e274a 100644 --- a/watchtower/wtclient/session_queue.go +++ b/watchtower/wtclient/session_queue.go @@ -3,7 +3,6 @@ package wtclient import ( "container/list" "fmt" - "sort" "sync" "time" @@ -133,7 +132,11 @@ func newSessionQueue(cfg *sessionQueueConfig) *sessionQueue { } sq.queueCond = sync.NewCond(&sq.queueMtx) - sq.restoreCommittedUpdates() + // The database should return them in sorted order, and session queue's + // sequence number will be equal to that of the last committed update. + for _, update := range sq.cfg.ClientSession.CommittedUpdates { + sq.commitQueue.PushBack(update) + } return sq } @@ -237,45 +240,6 @@ func (q *sessionQueue) AcceptTask(task *backupTask) (reserveStatus, bool) { return newStatus, true } -// updateWithSeqNum stores a CommittedUpdate with its assigned sequence number. -// This allows committed updates to be sorted after a restart, and added to the -// commitQueue in the proper order for delivery. -type updateWithSeqNum struct { - seqNum uint16 - update *wtdb.CommittedUpdate -} - -// restoreCommittedUpdates processes any CommittedUpdates loaded on startup by -// sorting them in ascending order of sequence numbers and adding them to the -// commitQueue. These will be sent before any pending updates are processed. -func (q *sessionQueue) restoreCommittedUpdates() { - committedUpdates := q.cfg.ClientSession.CommittedUpdates - - // Construct and unordered slice of all committed updates with their - // assigned sequence numbers. - sortedUpdates := make([]updateWithSeqNum, 0, len(committedUpdates)) - for seqNum, update := range committedUpdates { - sortedUpdates = append(sortedUpdates, updateWithSeqNum{ - seqNum: seqNum, - update: update, - }) - } - - // Sort the resulting slice by increasing sequence number. - sort.Slice(sortedUpdates, func(i, j int) bool { - return sortedUpdates[i].seqNum < sortedUpdates[j].seqNum - }) - - // Finally, add the sorted, committed updates to he commitQueue. These - // updates will be prioritized before any new tasks are assigned to the - // sessionQueue. The queue will begin uploading any tasks in the - // commitQueue as soon as it is started, e.g. during client - // initialization when detecting that this session has unacked updates. - for _, update := range sortedUpdates { - q.commitQueue.PushBack(update) - } -} - // sessionManager is the primary event loop for the sessionQueue, and is // responsible for encrypting and sending accepted tasks to the tower. func (q *sessionQueue) sessionManager() { @@ -396,7 +360,7 @@ func (q *sessionQueue) drainBackups() { func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) { var ( seqNum uint16 - update *wtdb.CommittedUpdate + update wtdb.CommittedUpdate isLast bool isPending bool ) @@ -407,10 +371,9 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) { // If the commit queue is non-empty, parse the next committed update. case q.commitQueue.Len() > 0: next := q.commitQueue.Front() - updateWithSeq := next.Value.(updateWithSeqNum) - seqNum = updateWithSeq.seqNum - update = updateWithSeq.update + update = next.Value.(wtdb.CommittedUpdate) + seqNum = update.SeqNum // If this is the last item in the commit queue and no items // exist in the pending queue, we will use the IsComplete flag @@ -449,10 +412,13 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) { } // TODO(conner): special case other obscure errors - update = &wtdb.CommittedUpdate{ - BackupID: task.id, - Hint: hint, - EncryptedBlob: encBlob, + update = wtdb.CommittedUpdate{ + SeqNum: seqNum, + CommittedUpdateBody: wtdb.CommittedUpdateBody{ + BackupID: task.id, + Hint: hint, + EncryptedBlob: encBlob, + }, } log.Debugf("Committing state update for session=%s seqnum=%d", @@ -470,7 +436,7 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) { // we send the next time. This step ensures that if we reliably send the // same update for a given sequence number, to prevent us from thinking // we backed up a state when we instead backed up another. - lastApplied, err := q.cfg.DB.CommitUpdate(q.ID(), seqNum, update) + lastApplied, err := q.cfg.DB.CommitUpdate(q.ID(), &update) if err != nil { // TODO(conner): mark failed/reschedule return nil, false, fmt.Errorf("unable to commit state update "+ @@ -478,7 +444,7 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) { } stateUpdate := &wtwire.StateUpdate{ - SeqNum: seqNum, + SeqNum: update.SeqNum, LastApplied: lastApplied, Hint: update.Hint, EncryptedBlob: update.EncryptedBlob, diff --git a/watchtower/wtdb/client_session.go b/watchtower/wtdb/client_session.go index ded54d8a..9eed54f4 100644 --- a/watchtower/wtdb/client_session.go +++ b/watchtower/wtdb/client_session.go @@ -86,10 +86,10 @@ type ClientSession struct { // specifies a reward output. RewardPkScript []byte - // CommittedUpdates is a map from allocated sequence numbers to unacked - // updates. These updates can be resent after a restart if the update - // failed to send or receive an acknowledgment. - CommittedUpdates map[uint16]*CommittedUpdate + // CommittedUpdates is a sorted list of unacked updates. These updates + // can be resent after a restart if the updates failed to send or + // receive an acknowledgment. + CommittedUpdates []CommittedUpdate // AckedUpdates is a map from sequence number to backup id to record // which revoked states were uploaded via this session. @@ -107,8 +107,21 @@ type BackupID struct { } // CommittedUpdate holds a state update sent by a client along with its -// SessionID. +// allocated sequence number and the exact remote commitment the encrypted +// justice transaction can rectify. type CommittedUpdate struct { + // SeqNum is the unique sequence number allocated by the session to this + // update. + SeqNum uint16 + + CommittedUpdateBody +} + +// CommittedUpdateBody represents the primary components of a CommittedUpdate. +// On disk, this is stored under the sequence number, which acts as its key. +type CommittedUpdateBody struct { + // BackupID identifies the breached commitment that the encrypted blob + // can spend from. BackupID BackupID // Hint is the 16-byte prefix of the revoked commitment transaction ID. diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 267a290e..e4e13c83 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -129,7 +129,7 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { SeqNum: session.SeqNum, TowerLastApplied: session.TowerLastApplied, RewardPkScript: cloneBytes(session.RewardPkScript), - CommittedUpdates: make(map[uint16]*wtdb.CommittedUpdate), + CommittedUpdates: make([]wtdb.CommittedUpdate, 0), AckedUpdates: make(map[uint16]wtdb.BackupID), } @@ -159,7 +159,7 @@ func (m *ClientDB) NextSessionKeyIndex(towerID wtdb.TowerID) (uint32, error) { // CommitUpdate persists the CommittedUpdate provided in the slot for (session, // seqNum). This allows the client to retransmit this update on startup. -func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, seqNum uint16, +func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, update *wtdb.CommittedUpdate) (uint16, error) { m.mu.Lock() @@ -172,25 +172,26 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, seqNum uint16, } // Check if an update has already been committed for this state. - dbUpdate, ok := session.CommittedUpdates[seqNum] - if ok { - // If the breach hint matches, we'll just return the last - // applied value so the client can retransmit. - if dbUpdate.Hint == update.Hint { - return session.TowerLastApplied, nil - } + for _, dbUpdate := range session.CommittedUpdates { + if dbUpdate.SeqNum == update.SeqNum { + // If the breach hint matches, we'll just return the + // last applied value so the client can retransmit. + if dbUpdate.Hint == update.Hint { + return session.TowerLastApplied, nil + } - // Otherwise, fail since the breach hint doesn't match. - return 0, wtdb.ErrUpdateAlreadyCommitted + // Otherwise, fail since the breach hint doesn't match. + return 0, wtdb.ErrUpdateAlreadyCommitted + } } // Sequence number must increment. - if seqNum != session.SeqNum+1 { + if update.SeqNum != session.SeqNum+1 { return 0, wtdb.ErrCommitUnorderedUpdate } // Save the update and increment the sequence number. - session.CommittedUpdates[seqNum] = update + session.CommittedUpdates = append(session.CommittedUpdates, *update) session.SeqNum++ return session.TowerLastApplied, nil @@ -209,13 +210,6 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err return wtdb.ErrClientSessionNotFound } - // Retrieve the committed update, failing if none is found. We should - // only receive acks for state updates that we send. - update, ok := session.CommittedUpdates[seqNum] - if !ok { - return wtdb.ErrCommittedUpdateNotFound - } - // Ensure the returned last applied value does not exceed the highest // allocated sequence number. if lastApplied > session.SeqNum { @@ -228,14 +222,28 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err return wtdb.ErrLastAppliedReversion } - // Finally, remove the committed update from disk and mark the update as - // acked. The tower last applied value is also recorded to send along - // with the next update. - delete(session.CommittedUpdates, seqNum) - session.AckedUpdates[seqNum] = update.BackupID - session.TowerLastApplied = lastApplied + // Retrieve the committed update, failing if none is found. We should + // only receive acks for state updates that we send. + updates := session.CommittedUpdates + for i, update := range updates { + if update.SeqNum != seqNum { + continue + } - return nil + // Remove the committed update from disk and mark the update as + // acked. The tower last applied value is also recorded to send + // along with the next update. + copy(updates[:i], updates[i+1:]) + updates[len(updates)-1] = wtdb.CommittedUpdate{} + session.CommittedUpdates = updates[:len(updates)-1] + + session.AckedUpdates[seqNum] = update.BackupID + session.TowerLastApplied = lastApplied + + return nil + } + + return wtdb.ErrCommittedUpdateNotFound } // FetchChanPkScripts returns the set of sweep pkscripts known for all channels. From 1db9bf2fd4b1e8e3e25aeb7fd594c18af412d04b Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Thu, 23 May 2019 20:47:49 -0700 Subject: [PATCH 04/11] watchtower/wtdb: create embedded ClientSessionBody This commit splits out the portions of the ClientSession into an embedded ClientSessionBody, since these fields will be serialized together on-disk. --- watchtower/wtclient/backup_task.go | 3 +- .../wtclient/backup_task_internal_test.go | 4 +- watchtower/wtclient/session_negotiator.go | 11 ++-- watchtower/wtclient/session_queue.go | 2 +- watchtower/wtdb/client_session.go | 60 ++++++++++++------- watchtower/wtmock/client_db.go | 16 ++--- 6 files changed, 59 insertions(+), 37 deletions(-) diff --git a/watchtower/wtclient/backup_task.go b/watchtower/wtclient/backup_task.go index c88bfd0f..72e14934 100644 --- a/watchtower/wtclient/backup_task.go +++ b/watchtower/wtclient/backup_task.go @@ -126,8 +126,7 @@ func (t *backupTask) inputs() map[wire.OutPoint]input.Input { // SessionInfo's policy. If no error is returned, the task has been bound to the // session and can be queued to upload to the tower. Otherwise, the bind failed // and should be rescheduled with a different session. -func (t *backupTask) bindSession(session *wtdb.ClientSession) error { - +func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error { // First we'll begin by deriving a weight estimate for the justice // transaction. The final weight can be different depending on whether // the watchtower is taking a reward. diff --git a/watchtower/wtclient/backup_task_internal_test.go b/watchtower/wtclient/backup_task_internal_test.go index 2c25c9a0..869c4042 100644 --- a/watchtower/wtclient/backup_task_internal_test.go +++ b/watchtower/wtclient/backup_task_internal_test.go @@ -69,7 +69,7 @@ type backupTaskTest struct { expSweepAmt int64 expRewardAmt int64 expRewardScript []byte - session *wtdb.ClientSession + session *wtdb.ClientSessionBody bindErr error expSweepScript []byte signer input.Signer @@ -205,7 +205,7 @@ func genTaskTest( expSweepAmt: expSweepAmt, expRewardAmt: expRewardAmt, expRewardScript: rewardScript, - session: &wtdb.ClientSession{ + session: &wtdb.ClientSessionBody{ Policy: wtpolicy.Policy{ BlobType: blobType, SweepFeeRate: sweepFeeRate, diff --git a/watchtower/wtclient/session_negotiator.go b/watchtower/wtclient/session_negotiator.go index 355701bd..e3d58c1f 100644 --- a/watchtower/wtclient/session_negotiator.go +++ b/watchtower/wtclient/session_negotiator.go @@ -417,14 +417,15 @@ func (n *sessionNegotiator) tryAddress(privKey *btcec.PrivateKey, privKey.PubKey(), ) clientSession := &wtdb.ClientSession{ - TowerID: tower.ID, + ClientSessionBody: wtdb.ClientSessionBody{ + TowerID: tower.ID, + KeyIndex: keyIndex, + Policy: n.cfg.Policy, + RewardPkScript: rewardPkScript, + }, Tower: tower, - KeyIndex: keyIndex, SessionPrivKey: privKey, ID: sessionID, - Policy: n.cfg.Policy, - SeqNum: 0, - RewardPkScript: rewardPkScript, } err = n.cfg.DB.CreateClientSession(clientSession) diff --git a/watchtower/wtclient/session_queue.go b/watchtower/wtclient/session_queue.go index 258e274a..39ab0a43 100644 --- a/watchtower/wtclient/session_queue.go +++ b/watchtower/wtclient/session_queue.go @@ -215,7 +215,7 @@ func (q *sessionQueue) AcceptTask(task *backupTask) (reserveStatus, bool) { // // TODO(conner): queue backups and retry with different session params. case reserveAvailable: - err := task.bindSession(q.cfg.ClientSession) + err := task.bindSession(&q.cfg.ClientSession.ClientSessionBody) if err != nil { q.queueCond.L.Unlock() log.Debugf("SessionQueue %s rejected backup chanid=%s "+ diff --git a/watchtower/wtdb/client_session.go b/watchtower/wtdb/client_session.go index 9eed54f4..5b2d39d7 100644 --- a/watchtower/wtdb/client_session.go +++ b/watchtower/wtdb/client_session.go @@ -46,8 +46,48 @@ var ( type ClientSession struct { // ID is the client's public key used when authenticating with the // tower. + // + // NOTE: This value is not serialized with the body of the struct, it + // should be set and recovered as the ClientSession's key. ID SessionID + ClientSessionBody + + // CommittedUpdates is a sorted list of unacked updates. These updates + // can be resent after a restart if the updates failed to send or + // receive an acknowledgment. + // + // NOTE: This list is serialized in it's own bucket, separate from the + // body of the ClientSession. The representation on disk is a key value + // map from sequence number to CommittedUpdateBody to allow efficient + // insertion and retrieval. + CommittedUpdates []CommittedUpdate + + // AckedUpdates is a map from sequence number to backup id to record + // which revoked states were uploaded via this session. + // + // NOTE: This map is serialized in it's own bucket, separate from the + // body of the ClientSession. + AckedUpdates map[uint16]BackupID + + // Tower holds the pubkey and address of the watchtower. + // + // NOTE: This value is not serialized. It is recovered by looking up the + // tower with TowerID. + Tower *Tower + + // SessionPrivKey is the ephemeral secret key used to connect to the + // watchtower. + // + // NOTE: This value is not serialized. It is derived using the KeyIndex + // on startup to avoid storing private keys on disk. + SessionPrivKey *btcec.PrivateKey +} + +// ClientSessionBody represents the primary components of a ClientSession that +// are serialized together within the database. The CommittedUpdates and +// AckedUpdates are serialized in buckets separate from the body. +type ClientSessionBody struct { // SeqNum is the next unallocated sequence number that can be sent to // the tower. SeqNum uint16 @@ -59,25 +99,12 @@ type ClientSession struct { // Tower with which the session is negotiated. TowerID TowerID - // Tower holds the pubkey and address of the watchtower. - // - // NOTE: This value is not serialized. It is recovered by looking up the - // tower with TowerID. - Tower *Tower - // KeyIndex is the index of key locator used to derive the client's // session key so that it can authenticate with the tower to update its // session. In order to rederive the private key, the key locator should // use the keychain.KeyFamilyTowerSession key family. KeyIndex uint32 - // SessionPrivKey is the ephemeral secret key used to connect to the - // watchtower. - // - // NOTE: This value is not serialized. It is derived using the KeyIndex - // on startup to avoid storing private keys on disk. - SessionPrivKey *btcec.PrivateKey - // Policy holds the negotiated session parameters. Policy wtpolicy.Policy @@ -86,14 +113,7 @@ type ClientSession struct { // specifies a reward output. RewardPkScript []byte - // CommittedUpdates is a sorted list of unacked updates. These updates - // can be resent after a restart if the updates failed to send or - // receive an acknowledgment. - CommittedUpdates []CommittedUpdate - // AckedUpdates is a map from sequence number to backup id to record - // which revoked states were uploaded via this session. - AckedUpdates map[uint16]BackupID } // BackupID identifies a particular revoked, remote commitment by channel id and diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index e4e13c83..32898e4e 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -122,13 +122,15 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { delete(m.indexes, session.TowerID) m.activeSessions[session.ID] = &wtdb.ClientSession{ - TowerID: session.TowerID, - KeyIndex: session.KeyIndex, - ID: session.ID, - Policy: session.Policy, - SeqNum: session.SeqNum, - TowerLastApplied: session.TowerLastApplied, - RewardPkScript: cloneBytes(session.RewardPkScript), + ID: session.ID, + ClientSessionBody: wtdb.ClientSessionBody{ + SeqNum: session.SeqNum, + TowerLastApplied: session.TowerLastApplied, + TowerID: session.TowerID, + KeyIndex: session.KeyIndex, + Policy: session.Policy, + RewardPkScript: cloneBytes(session.RewardPkScript), + }, CommittedUpdates: make([]wtdb.CommittedUpdate, 0), AckedUpdates: make(map[uint16]wtdb.BackupID), } From 2a904cb69f7ac7803a0834602301eec2bc765aa5 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Thu, 23 May 2019 20:48:08 -0700 Subject: [PATCH 05/11] watchtower/wtdb: add Encode/Decode methods to wtclient structs --- channeldb/codec.go | 10 ++ watchtower/wtdb/client_session.go | 73 ++++++++++++ watchtower/wtdb/codec_test.go | 181 +++++++++++++++++++++++++++++- watchtower/wtdb/tower.go | 32 ++++-- 4 files changed, 284 insertions(+), 12 deletions(-) diff --git a/channeldb/codec.go b/channeldb/codec.go index 1da362dd..ec6e165b 100644 --- a/channeldb/codec.go +++ b/channeldb/codec.go @@ -103,6 +103,11 @@ func WriteElement(w io.Writer, element interface{}) error { return err } + case lnwire.ChannelID: + if _, err := w.Write(e[:]); err != nil { + return err + } + case uint64: if err := binary.Write(w, byteOrder, e); err != nil { return err @@ -259,6 +264,11 @@ func ReadElement(r io.Reader, element interface{}) error { } *e = lnwire.NewShortChanIDFromInt(a) + case *lnwire.ChannelID: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + case *uint64: if err := binary.Read(r, byteOrder, e); err != nil { return err diff --git a/watchtower/wtdb/client_session.go b/watchtower/wtdb/client_session.go index 5b2d39d7..ab068683 100644 --- a/watchtower/wtdb/client_session.go +++ b/watchtower/wtdb/client_session.go @@ -2,6 +2,7 @@ package wtdb import ( "errors" + "io" "github.com/btcsuite/btcd/btcec" "github.com/lightningnetwork/lnd/lnwire" @@ -112,8 +113,38 @@ type ClientSessionBody struct { // deposited to if a sweep transaction confirms and the sessions // specifies a reward output. RewardPkScript []byte +} +// Encode writes a ClientSessionBody to the passed io.Writer. +func (s *ClientSessionBody) Encode(w io.Writer) error { + return WriteElements(w, + s.SeqNum, + s.TowerLastApplied, + uint64(s.TowerID), + s.KeyIndex, + s.Policy, + s.RewardPkScript, + ) +} +// Decode reads a ClientSessionBody from the passed io.Reader. +func (s *ClientSessionBody) Decode(r io.Reader) error { + var towerID uint64 + err := ReadElements(r, + &s.SeqNum, + &s.TowerLastApplied, + &towerID, + &s.KeyIndex, + &s.Policy, + &s.RewardPkScript, + ) + if err != nil { + return err + } + + s.TowerID = TowerID(towerID) + + return nil } // BackupID identifies a particular revoked, remote commitment by channel id and @@ -126,6 +157,22 @@ type BackupID struct { CommitHeight uint64 } +// Encode writes the BackupID from the passed io.Writer. +func (b *BackupID) Encode(w io.Writer) error { + return WriteElements(w, + b.ChanID, + b.CommitHeight, + ) +} + +// Decode reads a BackupID from the passed io.Reader. +func (b *BackupID) Decode(r io.Reader) error { + return ReadElements(r, + &b.ChanID, + &b.CommitHeight, + ) +} + // CommittedUpdate holds a state update sent by a client along with its // allocated sequence number and the exact remote commitment the encrypted // justice transaction can rectify. @@ -152,3 +199,29 @@ type CommittedUpdateBody struct { // hint is broadcast. EncryptedBlob []byte } + +// Encode writes the CommittedUpdateBody to the passed io.Writer. +func (u *CommittedUpdateBody) Encode(w io.Writer) error { + err := u.BackupID.Encode(w) + if err != nil { + return err + } + + return WriteElements(w, + u.Hint, + u.EncryptedBlob, + ) +} + +// Decode reads a CommittedUpdateBody from the passed io.Reader. +func (u *CommittedUpdateBody) Decode(r io.Reader) error { + err := u.BackupID.Decode(r) + if err != nil { + return err + } + + return ReadElements(r, + &u.Hint, + &u.EncryptedBlob, + ) +} diff --git a/watchtower/wtdb/codec_test.go b/watchtower/wtdb/codec_test.go index 948ec4ee..21e11c6f 100644 --- a/watchtower/wtdb/codec_test.go +++ b/watchtower/wtdb/codec_test.go @@ -2,14 +2,122 @@ package wtdb_test import ( "bytes" + "encoding/binary" "io" + "math/rand" + "net" "reflect" "testing" "testing/quick" + "github.com/btcsuite/btcd/btcec" + "github.com/lightningnetwork/lnd/tor" "github.com/lightningnetwork/lnd/watchtower/wtdb" ) +func randPubKey() (*btcec.PublicKey, error) { + priv, err := btcec.NewPrivateKey(btcec.S256()) + if err != nil { + return nil, err + } + + return priv.PubKey(), nil +} + +func randTCP4Addr(r *rand.Rand) (*net.TCPAddr, error) { + var ip [4]byte + if _, err := r.Read(ip[:]); err != nil { + return nil, err + } + + var port [2]byte + if _, err := r.Read(port[:]); err != nil { + return nil, err + } + + addrIP := net.IP(ip[:]) + addrPort := int(binary.BigEndian.Uint16(port[:])) + + return &net.TCPAddr{IP: addrIP, Port: addrPort}, nil +} + +func randTCP6Addr(r *rand.Rand) (*net.TCPAddr, error) { + var ip [16]byte + if _, err := r.Read(ip[:]); err != nil { + return nil, err + } + + var port [2]byte + if _, err := r.Read(port[:]); err != nil { + return nil, err + } + + addrIP := net.IP(ip[:]) + addrPort := int(binary.BigEndian.Uint16(port[:])) + + return &net.TCPAddr{IP: addrIP, Port: addrPort}, nil +} + +func randV2OnionAddr(r *rand.Rand) (*tor.OnionAddr, error) { + var serviceID [tor.V2DecodedLen]byte + if _, err := r.Read(serviceID[:]); err != nil { + return nil, err + } + + var port [2]byte + if _, err := r.Read(port[:]); err != nil { + return nil, err + } + + onionService := tor.Base32Encoding.EncodeToString(serviceID[:]) + onionService += tor.OnionSuffix + addrPort := int(binary.BigEndian.Uint16(port[:])) + + return &tor.OnionAddr{OnionService: onionService, Port: addrPort}, nil +} + +func randV3OnionAddr(r *rand.Rand) (*tor.OnionAddr, error) { + var serviceID [tor.V3DecodedLen]byte + if _, err := r.Read(serviceID[:]); err != nil { + return nil, err + } + + var port [2]byte + if _, err := r.Read(port[:]); err != nil { + return nil, err + } + + onionService := tor.Base32Encoding.EncodeToString(serviceID[:]) + onionService += tor.OnionSuffix + addrPort := int(binary.BigEndian.Uint16(port[:])) + + return &tor.OnionAddr{OnionService: onionService, Port: addrPort}, nil +} + +func randAddrs(r *rand.Rand) ([]net.Addr, error) { + tcp4Addr, err := randTCP4Addr(r) + if err != nil { + return nil, err + } + + tcp6Addr, err := randTCP6Addr(r) + if err != nil { + return nil, err + } + + v2OnionAddr, err := randV2OnionAddr(r) + if err != nil { + return nil, err + } + + v3OnionAddr, err := randV3OnionAddr(r) + if err != nil { + return nil, err + } + + return []net.Addr{tcp4Addr, tcp6Addr, v2OnionAddr, v3OnionAddr}, nil +} + // dbObject is abstract object support encoding and decoding. type dbObject interface { Encode(io.Writer) error @@ -19,7 +127,9 @@ type dbObject interface { // TestCodec serializes and deserializes wtdb objects in order to test that that // the codec understands all of the required field types. The test also asserts // that decoding an object into another results in an equivalent object. -func TestCodec(t *testing.T) { +func TestCodec(tt *testing.T) { + + var t *testing.T mainScenario := func(obj dbObject) bool { // Ensure encoding the object succeeds. var b bytes.Buffer @@ -35,6 +145,14 @@ func TestCodec(t *testing.T) { obj2 = &wtdb.SessionInfo{} case *wtdb.SessionStateUpdate: obj2 = &wtdb.SessionStateUpdate{} + case *wtdb.ClientSessionBody: + obj2 = &wtdb.ClientSessionBody{} + case *wtdb.CommittedUpdateBody: + obj2 = &wtdb.CommittedUpdateBody{} + case *wtdb.BackupID: + obj2 = &wtdb.BackupID{} + case *wtdb.Tower: + obj2 = &wtdb.Tower{} default: t.Fatalf("unknown type: %T", obj) return false @@ -57,6 +175,29 @@ func TestCodec(t *testing.T) { return true } + customTypeGen := map[string]func([]reflect.Value, *rand.Rand){ + "Tower": func(v []reflect.Value, r *rand.Rand) { + pk, err := randPubKey() + if err != nil { + t.Fatalf("unable to generate pubkey: %v", err) + return + } + + addrs, err := randAddrs(r) + if err != nil { + t.Fatalf("unable to generate addrs: %v", err) + return + } + + obj := wtdb.Tower{ + IdentityKey: pk, + Addresses: addrs, + } + + v[0] = reflect.ValueOf(obj) + }, + } + tests := []struct { name string scenario interface{} @@ -73,11 +214,45 @@ func TestCodec(t *testing.T) { return mainScenario(&obj) }, }, + { + name: "ClientSessionBody", + scenario: func(obj wtdb.ClientSessionBody) bool { + return mainScenario(&obj) + }, + }, + { + name: "CommittedUpdateBody", + scenario: func(obj wtdb.CommittedUpdateBody) bool { + return mainScenario(&obj) + }, + }, + { + name: "BackupID", + scenario: func(obj wtdb.BackupID) bool { + return mainScenario(&obj) + }, + }, + { + name: "Tower", + scenario: func(obj wtdb.Tower) bool { + return mainScenario(&obj) + }, + }, } for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if err := quick.Check(test.scenario, nil); err != nil { + tt.Run(test.name, func(h *testing.T) { + t = h + + var config *quick.Config + if valueGen, ok := customTypeGen[test.name]; ok { + config = &quick.Config{ + Values: valueGen, + } + } + + err := quick.Check(test.scenario, config) + if err != nil { t.Fatalf("fuzz checks for msg=%s failed: %v", test.name, err) } diff --git a/watchtower/wtdb/tower.go b/watchtower/wtdb/tower.go index e4f28781..518da750 100644 --- a/watchtower/wtdb/tower.go +++ b/watchtower/wtdb/tower.go @@ -2,8 +2,8 @@ package wtdb import ( "errors" + "io" "net" - "sync" "github.com/btcsuite/btcd/btcec" "github.com/lightningnetwork/lnd/lnwire" @@ -47,18 +47,15 @@ type Tower struct { // Addresses is a list of possible addresses to reach the tower. Addresses []net.Addr - - mu sync.RWMutex } // AddAddress adds the given address to the tower's in-memory list of addresses. // If the address's string is already present, the Tower will be left // unmodified. Otherwise, the adddress is prepended to the beginning of the // Tower's addresses, on the assumption that it is fresher than the others. +// +// NOTE: This method is NOT safe for concurrent use. func (t *Tower) AddAddress(addr net.Addr) { - t.mu.Lock() - defer t.mu.Unlock() - // Ensure we don't add a duplicate address. addrStr := addr.String() for _, existingAddr := range t.Addresses { @@ -75,10 +72,9 @@ func (t *Tower) AddAddress(addr net.Addr) { // LNAddrs generates a list of lnwire.NetAddress from a Tower instance's // addresses. This can be used to have a client try multiple addresses for the // same Tower. +// +// NOTE: This method is NOT safe for concurrent use. func (t *Tower) LNAddrs() []*lnwire.NetAddress { - t.mu.RLock() - defer t.mu.RUnlock() - addrs := make([]*lnwire.NetAddress, 0, len(t.Addresses)) for _, addr := range t.Addresses { addrs = append(addrs, &lnwire.NetAddress{ @@ -89,3 +85,21 @@ func (t *Tower) LNAddrs() []*lnwire.NetAddress { return addrs } + +// Encode writes the Tower to the passed io.Writer. The TowerID is not +// serialized, since it acts as the key. +func (t *Tower) Encode(w io.Writer) error { + return WriteElements(w, + t.IdentityKey, + t.Addresses, + ) +} + +// Decode reads a Tower from the passed io.Reader. The TowerID is meant to be +// decoded from the key. +func (t *Tower) Decode(r io.Reader) error { + return ReadElements(r, + &t.IdentityKey, + &t.Addresses, + ) +} From 440ae7818ae2630939b8e6565bffb266cd742ef8 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Thu, 23 May 2019 20:48:23 -0700 Subject: [PATCH 06/11] watchtower/wtmock/client_db: adjust mock clientdb behavior In advance of the upcoming wtdb.ClientDB, we'll modify the behavior of the mockdb to be more like the final bbolt backed one, and assert that all or our tests are still passing. --- watchtower/wtdb/client_session.go | 6 ++++++ watchtower/wtmock/client_db.go | 23 +++++++++++++++++++---- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/watchtower/wtdb/client_session.go b/watchtower/wtdb/client_session.go index ab068683..d29e1f5f 100644 --- a/watchtower/wtdb/client_session.go +++ b/watchtower/wtdb/client_session.go @@ -39,6 +39,12 @@ var ( // created because session key index differs from the reserved key // index. ErrIncorrectKeyIndex = errors.New("incorrect key index") + + // ErrClientSessionAlreadyExists signals an attempt to reinsert + // a client session that has already been created. + ErrClientSessionAlreadyExists = errors.New( + "client session already exists", + ) ) // ClientSession encapsulates a SessionInfo returned from a successful diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 32898e4e..b903e78a 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -65,7 +65,7 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) { m.towerIndex[towerPubKey] = towerID m.towers[towerID] = tower - return tower, nil + return copyTower(tower), nil } // LoadTower retrieves a tower by its tower ID. @@ -74,7 +74,7 @@ func (m *ClientDB) LoadTower(towerID wtdb.TowerID) (*wtdb.Tower, error) { defer m.mu.Unlock() if tower, ok := m.towers[towerID]; ok { - return tower, nil + return copyTower(tower), nil } return nil, wtdb.ErrTowerNotFound @@ -106,6 +106,11 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { m.mu.Lock() defer m.mu.Unlock() + // Ensure that we aren't overwriting an existing session. + if _, ok := m.activeSessions[session.ID]; ok { + return wtdb.ErrClientSessionAlreadyExists + } + // Ensure that a session key index has been reserved for this tower. keyIndex, ok := m.indexes[session.TowerID] if !ok { @@ -151,11 +156,10 @@ func (m *ClientDB) NextSessionKeyIndex(towerID wtdb.TowerID) (uint32, error) { return index, nil } + m.nextIndex++ index := m.nextIndex m.indexes[towerID] = index - m.nextIndex++ - return index, nil } @@ -286,3 +290,14 @@ func cloneBytes(b []byte) []byte { return bb } + +func copyTower(tower *wtdb.Tower) *wtdb.Tower { + t := &wtdb.Tower{ + ID: tower.ID, + IdentityKey: tower.IdentityKey, + Addresses: make([]net.Addr, len(tower.Addresses)), + } + copy(t.Addresses, tower.Addresses) + + return t +} From 25fc464a6e91a3d7a0c263206fe7c076065558f8 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Thu, 23 May 2019 20:48:36 -0700 Subject: [PATCH 07/11] watchtower/wtdb/client_chan_summary: add ClientChanSummary A ClientChanSummary will be inserted for each channel registered with the client, which for now will just track the sweep pkscript to use. In the future, this will be extended with additional information to enable the client to efficiently compute which historical states need to be backed up under a given policy. --- watchtower/wtdb/client_chan_summary.go | 32 ++++++++++++++++++++++++++ watchtower/wtdb/codec_test.go | 8 +++++++ 2 files changed, 40 insertions(+) create mode 100644 watchtower/wtdb/client_chan_summary.go diff --git a/watchtower/wtdb/client_chan_summary.go b/watchtower/wtdb/client_chan_summary.go new file mode 100644 index 00000000..d4b3c3c3 --- /dev/null +++ b/watchtower/wtdb/client_chan_summary.go @@ -0,0 +1,32 @@ +package wtdb + +import ( + "io" + + "github.com/lightningnetwork/lnd/lnwire" +) + +// ChannelSummaries is a map for a given channel id to it's ClientChanSummary. +type ChannelSummaries map[lnwire.ChannelID]ClientChanSummary + +// ClientChanSummary tracks channel-specific information. A new +// ClientChanSummary is inserted in the database the first time the client +// encounters a particular channel. +type ClientChanSummary struct { + // SweepPkScript is the pkscript to which all justice transactions will + // deposit recovered funds for this particular channel. + SweepPkScript []byte + + // TODO(conner): later extend with info about initial commit height, + // ineligible states, etc. +} + +// Encode writes the ClientChanSummary to the passed io.Writer. +func (s *ClientChanSummary) Encode(w io.Writer) error { + return WriteElement(w, s.SweepPkScript) +} + +// Decode reads a ClientChanSummary form the passed io.Reader. +func (s *ClientChanSummary) Decode(r io.Reader) error { + return ReadElement(r, &s.SweepPkScript) +} diff --git a/watchtower/wtdb/codec_test.go b/watchtower/wtdb/codec_test.go index 21e11c6f..69c7b059 100644 --- a/watchtower/wtdb/codec_test.go +++ b/watchtower/wtdb/codec_test.go @@ -153,6 +153,8 @@ func TestCodec(tt *testing.T) { obj2 = &wtdb.BackupID{} case *wtdb.Tower: obj2 = &wtdb.Tower{} + case *wtdb.ClientChanSummary: + obj2 = &wtdb.ClientChanSummary{} default: t.Fatalf("unknown type: %T", obj) return false @@ -238,6 +240,12 @@ func TestCodec(tt *testing.T) { return mainScenario(&obj) }, }, + { + name: "ClientChanSummary", + scenario: func(obj wtdb.ClientChanSummary) bool { + return mainScenario(&obj) + }, + }, } for _, test := range tests { From b35a5b8892c6c4fd6824b70e252f98fe1caf104e Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Thu, 23 May 2019 20:48:50 -0700 Subject: [PATCH 08/11] watchtower/wtclient: integrate ClientChannelSummaries In this commit, we utilize the more generic ClientChanSummary instead of exposing methods that only allow us to set and fetch sweep pkscripts. --- watchtower/wtclient/client.go | 26 +++++++++-------- watchtower/wtclient/client_test.go | 2 ++ watchtower/wtclient/interface.go | 17 ++++++----- watchtower/wtdb/client_chan_summary.go | 7 +++++ watchtower/wtmock/client_db.go | 39 ++++++++++++++++---------- 5 files changed, 57 insertions(+), 34 deletions(-) diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 1e614cc2..6a037570 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -150,8 +150,8 @@ type TowerClient struct { sessionQueue *sessionQueue prevTask *backupTask - sweepPkScriptMu sync.RWMutex - sweepPkScripts map[lnwire.ChannelID][]byte + summaryMu sync.RWMutex + summaries wtdb.ChannelSummaries statTicker *time.Ticker stats clientStats @@ -245,7 +245,7 @@ func New(config *Config) (*TowerClient, error) { // Finally, load the sweep pkscripts that have been generated for all // previously registered channels. - c.sweepPkScripts, err = c.cfg.DB.FetchChanPkScripts() + c.summaries, err = c.cfg.DB.FetchChanSummaries() if err != nil { return nil, err } @@ -388,12 +388,12 @@ func (c *TowerClient) ForceQuit() { // within the client. This should be called during link startup to ensure that // the client is able to support the link during operation. func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error { - c.sweepPkScriptMu.Lock() - defer c.sweepPkScriptMu.Unlock() + c.summaryMu.Lock() + defer c.summaryMu.Unlock() // If a pkscript for this channel already exists, the channel has been // previously registered. - if _, ok := c.sweepPkScripts[chanID]; ok { + if _, ok := c.summaries[chanID]; ok { return nil } @@ -406,14 +406,16 @@ func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error { // Persist the sweep pkscript so that restarts will not introduce // address inflation when the channel is reregistered after a restart. - err = c.cfg.DB.AddChanPkScript(chanID, pkScript) + err = c.cfg.DB.RegisterChannel(chanID, pkScript) if err != nil { return err } // Finally, cache the pkscript in our in-memory cache to avoid db // lookups for the remainder of the daemon's execution. - c.sweepPkScripts[chanID] = pkScript + c.summaries[chanID] = wtdb.ClientChanSummary{ + SweepPkScript: pkScript, + } return nil } @@ -429,14 +431,14 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, breachInfo *lnwallet.BreachRetribution) error { // Retrieve the cached sweep pkscript used for this channel. - c.sweepPkScriptMu.RLock() - sweepPkScript, ok := c.sweepPkScripts[*chanID] - c.sweepPkScriptMu.RUnlock() + c.summaryMu.RLock() + summary, ok := c.summaries[*chanID] + c.summaryMu.RUnlock() if !ok { return ErrUnregisteredChannel } - task := newBackupTask(chanID, breachInfo, sweepPkScript) + task := newBackupTask(chanID, breachInfo, summary.SweepPkScript) return c.pipeline.QueueBackupTask(task) } diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 86811bf0..ac4ebf2d 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -605,6 +605,8 @@ func (h *testHarness) backupStates(id, from, to uint64, expErr error) { // backupStates instructs the channel identified by id to send a backup for // state i. func (h *testHarness) backupState(id, i uint64, expErr error) { + h.t.Helper() + _, retribution := h.channel(id).getState(i) chanID := chanIDFromInt(id) diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index 4de81acb..e8a8b865 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -41,14 +41,17 @@ type DB interface { // still be able to accept state updates. ListClientSessions() (map[wtdb.SessionID]*wtdb.ClientSession, error) - // FetchChanPkScripts returns a map of all sweep pkscripts for - // registered channels. This is used on startup to cache the sweep - // pkscripts of registered channels in memory. - FetchChanPkScripts() (map[lnwire.ChannelID][]byte, error) + // FetchChanSummaries loads a mapping from all registered channels to + // their channel summaries. + FetchChanSummaries() (wtdb.ChannelSummaries, error) - // AddChanPkScript inserts a newly generated sweep pkscript for the - // given channel. - AddChanPkScript(lnwire.ChannelID, []byte) error + // RegisterChannel registers a channel for use within the client + // database. For now, all that is stored in the channel summary is the + // sweep pkscript that we'd like any tower sweeps to pay into. In the + // future, this will be extended to contain more info to allow the + // client efficiently request historical states to be backed up under + // the client's active policy. + RegisterChannel(lnwire.ChannelID, []byte) error // MarkBackupIneligible records that the state identified by the // (channel id, commit height) tuple was ineligible for being backed up diff --git a/watchtower/wtdb/client_chan_summary.go b/watchtower/wtdb/client_chan_summary.go index d4b3c3c3..0925150a 100644 --- a/watchtower/wtdb/client_chan_summary.go +++ b/watchtower/wtdb/client_chan_summary.go @@ -1,11 +1,18 @@ package wtdb import ( + "errors" "io" "github.com/lightningnetwork/lnd/lnwire" ) +var ( + // ErrChannelAlreadyRegistered signals a duplicate attempt to + // register a channel with the client database. + ErrChannelAlreadyRegistered = errors.New("channel already registered") +) + // ChannelSummaries is a map for a given channel id to it's ClientChanSummary. type ChannelSummaries map[lnwire.ChannelID]ClientChanSummary diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index b903e78a..88cde50f 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -1,7 +1,6 @@ package wtmock import ( - "fmt" "net" "sync" "sync/atomic" @@ -18,7 +17,7 @@ type ClientDB struct { nextTowerID uint64 // to be used atomically mu sync.Mutex - sweepPkScripts map[lnwire.ChannelID][]byte + summaries map[lnwire.ChannelID]wtdb.ClientChanSummary activeSessions map[wtdb.SessionID]*wtdb.ClientSession towerIndex map[towerPK]wtdb.TowerID towers map[wtdb.TowerID]*wtdb.Tower @@ -30,7 +29,7 @@ type ClientDB struct { // NewClientDB initializes a new mock ClientDB. func NewClientDB() *ClientDB { return &ClientDB{ - sweepPkScripts: make(map[lnwire.ChannelID][]byte), + summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary), activeSessions: make(map[wtdb.SessionID]*wtdb.ClientSession), towerIndex: make(map[towerPK]wtdb.TowerID), towers: make(map[wtdb.TowerID]*wtdb.Tower), @@ -252,30 +251,40 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err return wtdb.ErrCommittedUpdateNotFound } -// FetchChanPkScripts returns the set of sweep pkscripts known for all channels. -// This allows the client to cache them in memory on startup. -func (m *ClientDB) FetchChanPkScripts() (map[lnwire.ChannelID][]byte, error) { +// FetchChanSummaries loads a mapping from all registered channels to their +// channel summaries. +func (m *ClientDB) FetchChanSummaries() (wtdb.ChannelSummaries, error) { m.mu.Lock() defer m.mu.Unlock() - sweepPkScripts := make(map[lnwire.ChannelID][]byte) - for chanID, pkScript := range m.sweepPkScripts { - sweepPkScripts[chanID] = cloneBytes(pkScript) + summaries := make(map[lnwire.ChannelID]wtdb.ClientChanSummary) + for chanID, summary := range m.summaries { + summaries[chanID] = wtdb.ClientChanSummary{ + SweepPkScript: cloneBytes(summary.SweepPkScript), + } } - return sweepPkScripts, nil + return summaries, nil } -// AddChanPkScript sets a pkscript or sweeping funds from the channel or chanID. -func (m *ClientDB) AddChanPkScript(chanID lnwire.ChannelID, pkScript []byte) error { +// RegisterChannel registers a channel for use within the client database. For +// now, all that is stored in the channel summary is the sweep pkscript that +// we'd like any tower sweeps to pay into. In the future, this will be extended +// to contain more info to allow the client efficiently request historical +// states to be backed up under the client's active policy. +func (m *ClientDB) RegisterChannel(chanID lnwire.ChannelID, + sweepPkScript []byte) error { + m.mu.Lock() defer m.mu.Unlock() - if _, ok := m.sweepPkScripts[chanID]; ok { - return fmt.Errorf("pkscript for %x already exists", pkScript) + if _, ok := m.summaries[chanID]; ok { + return wtdb.ErrChannelAlreadyRegistered } - m.sweepPkScripts[chanID] = cloneBytes(pkScript) + m.summaries[chanID] = wtdb.ClientChanSummary{ + SweepPkScript: cloneBytes(sweepPkScript), + } return nil } From 3be651b0b3e9b4c55eb589e7e9a7b004cc9d53f9 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Thu, 23 May 2019 20:49:04 -0700 Subject: [PATCH 09/11] watchtower/wtdb: add ClientDB This commit adds the full bbolt-backed client database as well as a set of unit tests to assert that it exactly implements the same behavior as the mock ClientDB. --- watchtower/wtdb/client_chan_summary.go | 7 - watchtower/wtdb/client_db.go | 908 +++++++++++++++++++++++++ watchtower/wtdb/client_db_test.go | 688 +++++++++++++++++++ watchtower/wtdb/client_session.go | 39 -- watchtower/wtdb/tower.go | 7 - watchtower/wtdb/version.go | 5 + 6 files changed, 1601 insertions(+), 53 deletions(-) create mode 100644 watchtower/wtdb/client_db.go create mode 100644 watchtower/wtdb/client_db_test.go diff --git a/watchtower/wtdb/client_chan_summary.go b/watchtower/wtdb/client_chan_summary.go index 0925150a..d4b3c3c3 100644 --- a/watchtower/wtdb/client_chan_summary.go +++ b/watchtower/wtdb/client_chan_summary.go @@ -1,18 +1,11 @@ package wtdb import ( - "errors" "io" "github.com/lightningnetwork/lnd/lnwire" ) -var ( - // ErrChannelAlreadyRegistered signals a duplicate attempt to - // register a channel with the client database. - ErrChannelAlreadyRegistered = errors.New("channel already registered") -) - // ChannelSummaries is a map for a given channel id to it's ClientChanSummary. type ChannelSummaries map[lnwire.ChannelID]ClientChanSummary diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go new file mode 100644 index 00000000..92307e99 --- /dev/null +++ b/watchtower/wtdb/client_db.go @@ -0,0 +1,908 @@ +package wtdb + +import ( + "bytes" + "errors" + "fmt" + "math" + "net" + + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lnwire" +) + +const ( + // clientDBName is the filename of client database. + clientDBName = "wtclient.db" +) + +var ( + // cSessionKeyIndexBkt is a top-level bucket storing: + // tower-id -> reserved-session-key-index (uint32). + cSessionKeyIndexBkt = []byte("client-session-key-index-bucket") + + // cChanSummaryBkt is a top-level bucket storing: + // channel-id -> encoded ClientChanSummary. + cChanSummaryBkt = []byte("client-channel-summary-bucket") + + // cSessionBkt is a top-level bucket storing: + // session-id => cSessionBody -> encoded ClientSessionBody + // => cSessionCommits => seqnum -> encoded CommittedUpdate + // => cSessionAcks => seqnum -> encoded BackupID + cSessionBkt = []byte("client-session-bucket") + + // cSessionBody is a sub-bucket of cSessionBkt storing only the body of + // the ClientSession. + cSessionBody = []byte("client-session-body") + + // cSessionBody is a sub-bucket of cSessionBkt storing: + // seqnum -> encoded CommittedUpdate. + cSessionCommits = []byte("client-session-commits") + + // cSessionAcks is a sub-bucket of cSessionBkt storing: + // seqnum -> encoded BackupID. + cSessionAcks = []byte("client-session-acks") + + // cTowerBkt is a top-level bucket storing: + // tower-id -> encoded Tower. + cTowerBkt = []byte("client-tower-bucket") + + // cTowerIndexBkt is a top-level bucket storing: + // tower-pubkey -> tower-id. + cTowerIndexBkt = []byte("client-tower-index-bucket") + + // ErrTowerNotFound signals that the target tower was not found in the + // database. + ErrTowerNotFound = errors.New("tower not found") + + // ErrCorruptClientSession signals that the client session's on-disk + // structure deviates from what is expected. + ErrCorruptClientSession = errors.New("client session corrupted") + + // ErrClientSessionAlreadyExists signals an attempt to reinsert a client + // session that has already been created. + ErrClientSessionAlreadyExists = errors.New( + "client session already exists", + ) + + // ErrChannelAlreadyRegistered signals a duplicate attempt to register a + // channel with the client database. + ErrChannelAlreadyRegistered = errors.New("channel already registered") + + // ErrChannelNotRegistered signals a channel has not yet been registered + // in the client database. + ErrChannelNotRegistered = errors.New("channel not registered") + + // ErrClientSessionNotFound signals that the requested client session + // was not found in the database. + ErrClientSessionNotFound = errors.New("client session not found") + + // ErrUpdateAlreadyCommitted signals that the chosen sequence number has + // already been committed to an update with a different breach hint. + ErrUpdateAlreadyCommitted = errors.New("update already committed") + + // ErrCommitUnorderedUpdate signals the client tried to commit a + // sequence number other than the next unallocated sequence number. + ErrCommitUnorderedUpdate = errors.New("update seqnum not monotonic") + + // ErrCommittedUpdateNotFound signals that the tower tried to ACK a + // sequence number that has not yet been allocated by the client. + ErrCommittedUpdateNotFound = errors.New("committed update not found") + + // ErrUnallocatedLastApplied signals that the tower tried to provide a + // LastApplied value greater than any allocated sequence number. + ErrUnallocatedLastApplied = errors.New("tower echoed last appiled " + + "greater than allocated seqnum") + + // ErrNoReservedKeyIndex signals that a client session could not be + // created because no session key index was reserved. + ErrNoReservedKeyIndex = errors.New("key index not reserved") + + // ErrIncorrectKeyIndex signals that the client session could not be + // created because session key index differs from the reserved key + // index. + ErrIncorrectKeyIndex = errors.New("incorrect key index") +) + +// ClientDB is single database providing a persistent storage engine for the +// wtclient. +type ClientDB struct { + db *bbolt.DB + dbPath string +} + +// OpenClientDB opens the client database given the path to the database's +// directory. If no such database exists, this method will initialize a fresh +// one using the latest version number and bucket structure. If a database +// exists but has a lower version number than the current version, any necessary +// migrations will be applied before returning. Any attempt to open a database +// with a version number higher that the latest version will fail to prevent +// accidental reversion. +func OpenClientDB(dbPath string) (*ClientDB, error) { + bdb, firstInit, err := createDBIfNotExist(dbPath, clientDBName) + if err != nil { + return nil, err + } + + clientDB := &ClientDB{ + db: bdb, + dbPath: dbPath, + } + + err = initOrSyncVersions(clientDB, firstInit, clientDBVersions) + if err != nil { + bdb.Close() + return nil, err + } + + // Now that the database version fully consistent with our latest known + // version, ensure that all top-level buckets known to this version are + // initialized. This allows us to assume their presence throughout all + // operations. If an known top-level bucket is expected to exist but is + // missing, this will trigger a ErrUninitializedDB error. + err = clientDB.db.Update(initClientDBBuckets) + if err != nil { + bdb.Close() + return nil, err + } + + return clientDB, nil +} + +// initClientDBBuckets creates all top-level buckets required to handle database +// operations required by the latest version. +func initClientDBBuckets(tx *bbolt.Tx) error { + buckets := [][]byte{ + cSessionKeyIndexBkt, + cChanSummaryBkt, + cSessionBkt, + cTowerBkt, + cTowerIndexBkt, + } + + for _, bucket := range buckets { + _, err := tx.CreateBucketIfNotExists(bucket) + if err != nil { + return err + } + } + + return nil +} + +// bdb returns the backing bbolt.DB instance. +// +// NOTE: Part of the versionedDB interface. +func (c *ClientDB) bdb() *bbolt.DB { + return c.db +} + +// Version returns the database's current version number. +// +// NOTE: Part of the versionedDB interface. +func (c *ClientDB) Version() (uint32, error) { + var version uint32 + err := c.db.View(func(tx *bbolt.Tx) error { + var err error + version, err = getDBVersion(tx) + return err + }) + if err != nil { + return 0, err + } + + return version, nil +} + +// Close closes the underlying database. +func (c *ClientDB) Close() error { + return c.db.Close() +} + +// CreateTower initializes a database entry with the given lightning address. If +// the tower exists, the address is append to the list of all addresses used to +// that tower previously. +func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) { + var towerPubKey [33]byte + copy(towerPubKey[:], lnAddr.IdentityKey.SerializeCompressed()) + + var tower *Tower + err := c.db.Update(func(tx *bbolt.Tx) error { + towerIndex := tx.Bucket(cTowerIndexBkt) + if towerIndex == nil { + return ErrUninitializedDB + } + + towers := tx.Bucket(cTowerBkt) + if towers == nil { + return ErrUninitializedDB + } + + // Check if the tower index already knows of this pubkey. + towerIDBytes := towerIndex.Get(towerPubKey[:]) + if len(towerIDBytes) == 8 { + // The tower already exists, deserialize the existing + // record. + var err error + tower, err = getTower(towers, towerIDBytes) + if err != nil { + return err + } + + // Add the new address to the existing tower. If the + // address is a duplicate, this will result in no + // change. + tower.AddAddress(lnAddr.Address) + } else { + // No such tower exists, create a new tower id for our + // new tower. The error is unhandled since NextSequence + // never fails in an Update. + towerID, _ := towerIndex.NextSequence() + + tower = &Tower{ + ID: TowerID(towerID), + IdentityKey: lnAddr.IdentityKey, + Addresses: []net.Addr{lnAddr.Address}, + } + + towerIDBytes = tower.ID.Bytes() + + // Since this tower is new, record the mapping from + // tower pubkey to tower id in the tower index. + err := towerIndex.Put(towerPubKey[:], towerIDBytes) + if err != nil { + return err + } + } + + // Store the new or updated tower under its tower id. + return putTower(towers, tower) + }) + if err != nil { + return nil, err + } + + return tower, nil +} + +// LoadTower retrieves a tower by its tower ID. +func (c *ClientDB) LoadTower(towerID TowerID) (*Tower, error) { + var tower *Tower + err := c.db.View(func(tx *bbolt.Tx) error { + towers := tx.Bucket(cTowerBkt) + if towers == nil { + return ErrUninitializedDB + } + + var err error + tower, err = getTower(towers, towerID.Bytes()) + return err + }) + if err != nil { + return nil, err + } + + return tower, nil +} + +// NextSessionKeyIndex reserves a new session key derivation index for a +// particular tower id. The index is reserved for that tower until +// CreateClientSession is invoked for that tower and index, at which point a new +// index for that tower can be reserved. Multiple calls to this method before +// CreateClientSession is invoked should return the same index. +func (c *ClientDB) NextSessionKeyIndex(towerID TowerID) (uint32, error) { + var index uint32 + err := c.db.Update(func(tx *bbolt.Tx) error { + keyIndex := tx.Bucket(cSessionKeyIndexBkt) + if keyIndex == nil { + return ErrUninitializedDB + } + + // Check the session key index to see if a key has already been + // reserved for this tower. If so, we'll deserialize and return + // the index directly. + towerIDBytes := towerID.Bytes() + indexBytes := keyIndex.Get(towerIDBytes) + if len(indexBytes) == 4 { + index = byteOrder.Uint32(indexBytes) + return nil + } + + // Otherwise, generate a new session key index since the node + // doesn't already have reserved index. The error is ignored + // since NextSequence can't fail inside Update. + index64, _ := keyIndex.NextSequence() + + // As a sanity check, assert that the index is still in the + // valid range of unhardened pubkeys. In the future, we should + // move to only using hardened keys, and this will prevent any + // overlap from occurring until then. This also prevents us from + // overflowing uint32s. + if index64 > math.MaxInt32 { + return fmt.Errorf("exhausted session key indexes") + } + + index = uint32(index64) + + var indexBuf [4]byte + byteOrder.PutUint32(indexBuf[:], index) + + // Record the reserved session key index under this tower's id. + return keyIndex.Put(towerIDBytes, indexBuf[:]) + }) + if err != nil { + return 0, err + } + + return index, nil +} + +// CreateClientSession records a newly negotiated client session in the set of +// active sessions. The session can be identified by its SessionID. +func (c *ClientDB) CreateClientSession(session *ClientSession) error { + return c.db.Update(func(tx *bbolt.Tx) error { + keyIndexes := tx.Bucket(cSessionKeyIndexBkt) + if keyIndexes == nil { + return ErrUninitializedDB + } + + sessions := tx.Bucket(cSessionBkt) + if sessions == nil { + return ErrUninitializedDB + } + + // Check that client session with this session id doesn't + // already exist. + existingSessionBytes := sessions.Bucket(session.ID[:]) + if existingSessionBytes != nil { + return ErrClientSessionAlreadyExists + } + + // Check that this tower has a reserved key index. + towerIDBytes := session.TowerID.Bytes() + keyIndexBytes := keyIndexes.Get(towerIDBytes) + if len(keyIndexBytes) != 4 { + return ErrNoReservedKeyIndex + } + + // Assert that the key index of the inserted session matches the + // reserved session key index. + index := byteOrder.Uint32(keyIndexBytes) + if index != session.KeyIndex { + return ErrIncorrectKeyIndex + } + + // Remove the key index reservation. + err := keyIndexes.Delete(towerIDBytes) + if err != nil { + return err + } + + // Finally, write the client session's body in the sessions + // bucket. + return putClientSessionBody(sessions, session) + }) +} + +// ListClientSessions returns the set of all client sessions known to the db. +func (c *ClientDB) ListClientSessions() (map[SessionID]*ClientSession, error) { + clientSessions := make(map[SessionID]*ClientSession) + err := c.db.View(func(tx *bbolt.Tx) error { + sessions := tx.Bucket(cSessionBkt) + if sessions == nil { + return ErrUninitializedDB + } + + return sessions.ForEach(func(k, _ []byte) error { + // We'll load the full client session since the client + // will need the CommittedUpdates and AckedUpdates on + // startup to resume committed updates and compute the + // highest known commit height for each channel. + session, err := getClientSession(sessions, k) + if err != nil { + return err + } + + clientSessions[session.ID] = session + + return nil + }) + }) + if err != nil { + return nil, err + } + + return clientSessions, nil +} + +// FetchChanSummaries loads a mapping from all registered channels to their +// channel summaries. +func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) { + summaries := make(map[lnwire.ChannelID]ClientChanSummary) + err := c.db.View(func(tx *bbolt.Tx) error { + chanSummaries := tx.Bucket(cChanSummaryBkt) + if chanSummaries == nil { + return ErrUninitializedDB + } + + return chanSummaries.ForEach(func(k, v []byte) error { + var chanID lnwire.ChannelID + copy(chanID[:], k) + + var summary ClientChanSummary + err := summary.Decode(bytes.NewReader(v)) + if err != nil { + return err + } + + summaries[chanID] = summary + + return nil + }) + }) + if err != nil { + return nil, err + } + + return summaries, nil +} + +// RegisterChannel registers a channel for use within the client database. For +// now, all that is stored in the channel summary is the sweep pkscript that +// we'd like any tower sweeps to pay into. In the future, this will be extended +// to contain more info to allow the client efficiently request historical +// states to be backed up under the client's active policy. +func (c *ClientDB) RegisterChannel(chanID lnwire.ChannelID, + sweepPkScript []byte) error { + + return c.db.Update(func(tx *bbolt.Tx) error { + chanSummaries := tx.Bucket(cChanSummaryBkt) + if chanSummaries == nil { + return ErrUninitializedDB + } + + _, err := getChanSummary(chanSummaries, chanID) + switch { + + // Summary already exists. + case err == nil: + return ErrChannelAlreadyRegistered + + // Channel is not registered, proceed with registration. + case err == ErrChannelNotRegistered: + + // Unexpected error. + case err != nil: + return err + } + + summary := ClientChanSummary{ + SweepPkScript: sweepPkScript, + } + + return putChanSummary(chanSummaries, chanID, &summary) + }) +} + +// MarkBackupIneligible records that the state identified by the (channel id, +// commit height) tuple was ineligible for being backed up under the current +// policy. This state can be retried later under a different policy. +func (c *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, + commitHeight uint64) error { + + return nil +} + +// CommitUpdate persists the CommittedUpdate provided in the slot for (session, +// seqNum). This allows the client to retransmit this update on startup. +func (c *ClientDB) CommitUpdate(id *SessionID, + update *CommittedUpdate) (uint16, error) { + + var lastApplied uint16 + err := c.db.Update(func(tx *bbolt.Tx) error { + sessions := tx.Bucket(cSessionBkt) + if sessions == nil { + return ErrUninitializedDB + } + + // We'll only load the ClientSession body for performance, since + // we primarily need to inspect its SeqNum and TowerLastApplied + // fields. The CommittedUpdates will be modified on disk + // directly. + session, err := getClientSessionBody(sessions, id[:]) + if err != nil { + return err + } + + // Can't fail if the above didn't fail. + sessionBkt := sessions.Bucket(id[:]) + + // Ensure the session commits sub-bucket is initialized. + sessionCommits, err := sessionBkt.CreateBucketIfNotExists( + cSessionCommits, + ) + if err != nil { + return err + } + + var seqNumBuf [2]byte + byteOrder.PutUint16(seqNumBuf[:], update.SeqNum) + + // Check to see if a committed update already exists for this + // sequence number. + committedUpdateBytes := sessionCommits.Get(seqNumBuf[:]) + if committedUpdateBytes != nil { + var dbUpdate CommittedUpdate + err := dbUpdate.Decode( + bytes.NewReader(committedUpdateBytes), + ) + if err != nil { + return err + } + + // If an existing committed update has a different hint, + // we'll reject this newer update. + if dbUpdate.Hint != update.Hint { + return ErrUpdateAlreadyCommitted + } + + // Otherwise, capture the last applied value and + // succeed. + lastApplied = session.TowerLastApplied + return nil + } + + // There's no committed update for this sequence number, ensure + // that we are committing the next unallocated one. + if update.SeqNum != session.SeqNum+1 { + return ErrCommitUnorderedUpdate + } + + // Increment the session's sequence number and store the updated + // client session. + // + // TODO(conner): split out seqnum and last applied own bucket to + // eliminate serialization of full struct during CommitUpdate? + // Can also read/write directly to byes [:2] without migration. + session.SeqNum++ + err = putClientSessionBody(sessions, session) + if err != nil { + return err + } + + // Encode and store the committed update in the sessionCommits + // sub-bucket under the requested sequence number. + var b bytes.Buffer + err = update.Encode(&b) + if err != nil { + return err + } + + err = sessionCommits.Put(seqNumBuf[:], b.Bytes()) + if err != nil { + return err + } + + // Finally, capture the session's last applied value so it can + // be sent in the next state update to the tower. + lastApplied = session.TowerLastApplied + + return nil + + }) + if err != nil { + return 0, err + } + + return lastApplied, nil +} + +// AckUpdate persists an acknowledgment for a given (session, seqnum) pair. This +// removes the update from the set of committed updates, and validates the +// lastApplied value returned from the tower. +func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16, + lastApplied uint16) error { + + return c.db.Update(func(tx *bbolt.Tx) error { + sessions := tx.Bucket(cSessionBkt) + if sessions == nil { + return ErrUninitializedDB + } + + // We'll only load the ClientSession body for performance, since + // we primarily need to inspect its SeqNum and TowerLastApplied + // fields. The CommittedUpdates and AckedUpdates will be + // modified on disk directly. + session, err := getClientSessionBody(sessions, id[:]) + if err != nil { + return err + } + + // If the tower has acked a sequence number beyond our highest + // sequence number, fail. + if lastApplied > session.SeqNum { + return ErrUnallocatedLastApplied + } + + // If the tower acked with a lower sequence number than it gave + // us prior, fail. + if lastApplied < session.TowerLastApplied { + return ErrLastAppliedReversion + } + + // TODO(conner): split out seqnum and last applied own bucket to + // eliminate serialization of full struct during AckUpdate? Can + // also read/write directly to byes [2:4] without migration. + session.TowerLastApplied = lastApplied + + // Write the client session with the updated last applied value. + err = putClientSessionBody(sessions, session) + if err != nil { + return err + } + + // Can't fail because of getClientSession succeeded. + sessionBkt := sessions.Bucket(id[:]) + + // If the commits sub-bucket doesn't exist, there can't possibly + // be a corresponding committed update to remove. + sessionCommits := sessionBkt.Bucket(cSessionCommits) + if sessionCommits == nil { + return ErrCommittedUpdateNotFound + } + + var seqNumBuf [2]byte + byteOrder.PutUint16(seqNumBuf[:], seqNum) + + // Assert that a committed update exists for this sequence + // number. + committedUpdateBytes := sessionCommits.Get(seqNumBuf[:]) + if committedUpdateBytes == nil { + return ErrCommittedUpdateNotFound + } + + var committedUpdate CommittedUpdate + err = committedUpdate.Decode( + bytes.NewReader(committedUpdateBytes), + ) + if err != nil { + return err + } + + // Remove the corresponding committed update. + err = sessionCommits.Delete(seqNumBuf[:]) + if err != nil { + return err + } + + // Ensure that the session acks sub-bucket is initialized so we + // can insert an entry. + sessionAcks, err := sessionBkt.CreateBucketIfNotExists( + cSessionAcks, + ) + if err != nil { + return err + } + + // The session acks only need to track the backup id of the + // update, so we can discard the blob and hint. + var b bytes.Buffer + err = committedUpdate.BackupID.Encode(&b) + if err != nil { + return err + } + + // Finally, insert the ack into the sessionAcks sub-bucket. + return sessionAcks.Put(seqNumBuf[:], b.Bytes()) + }) +} + +// getClientSessionBody loads the body of a ClientSession from the sessions +// bucket corresponding to the serialized session id. This does not deserialize +// the CommittedUpdates or AckUpdates associated with the session. If the caller +// requires this info, use getClientSession. +func getClientSessionBody(sessions *bbolt.Bucket, + idBytes []byte) (*ClientSession, error) { + + sessionBkt := sessions.Bucket(idBytes) + if sessionBkt == nil { + return nil, ErrClientSessionNotFound + } + + // Should never have a sessionBkt without also having its body. + sessionBody := sessionBkt.Get(cSessionBody) + if sessionBody == nil { + return nil, ErrCorruptClientSession + } + + var session ClientSession + copy(session.ID[:], idBytes) + + err := session.Decode(bytes.NewReader(sessionBody)) + if err != nil { + return nil, err + } + + return &session, nil +} + +// getClientSession loads the full ClientSession associated with the serialized +// session id. This method populates the CommittedUpdates and AckUpdates in +// addition to the ClientSession's body. +func getClientSession(sessions *bbolt.Bucket, + idBytes []byte) (*ClientSession, error) { + + session, err := getClientSessionBody(sessions, idBytes) + if err != nil { + return nil, err + } + + // Fetch the committed updates for this session. + commitedUpdates, err := getClientSessionCommits(sessions, idBytes) + if err != nil { + return nil, err + } + + // Fetch the acked updates for this session. + ackedUpdates, err := getClientSessionAcks(sessions, idBytes) + if err != nil { + return nil, err + } + + session.CommittedUpdates = commitedUpdates + session.AckedUpdates = ackedUpdates + + return session, nil +} + +// getClientSessionCommits retrieves all committed updates for the session +// identified by the serialized session id. +func getClientSessionCommits(sessions *bbolt.Bucket, + idBytes []byte) ([]CommittedUpdate, error) { + + // Can't fail because client session body has already been read. + sessionBkt := sessions.Bucket(idBytes) + + // Initialize commitedUpdates so that we can return an initialized map + // if no committed updates exist. + committedUpdates := make([]CommittedUpdate, 0) + + sessionCommits := sessionBkt.Bucket(cSessionCommits) + if sessionCommits == nil { + return committedUpdates, nil + } + + err := sessionCommits.ForEach(func(k, v []byte) error { + var committedUpdate CommittedUpdate + err := committedUpdate.Decode(bytes.NewReader(v)) + if err != nil { + return err + } + committedUpdate.SeqNum = byteOrder.Uint16(k) + + committedUpdates = append(committedUpdates, committedUpdate) + + return nil + }) + if err != nil { + return nil, err + } + + return committedUpdates, nil +} + +// getClientSessionAcks retrieves all acked updates for the session identified +// by the serialized session id. +func getClientSessionAcks(sessions *bbolt.Bucket, + idBytes []byte) (map[uint16]BackupID, error) { + + // Can't fail because client session body has already been read. + sessionBkt := sessions.Bucket(idBytes) + + // Initialize ackedUpdates so that we can return an initialized map if + // no acked updates exist. + ackedUpdates := make(map[uint16]BackupID) + + sessionAcks := sessionBkt.Bucket(cSessionAcks) + if sessionAcks == nil { + return ackedUpdates, nil + } + + err := sessionAcks.ForEach(func(k, v []byte) error { + seqNum := byteOrder.Uint16(k) + + var backupID BackupID + err := backupID.Decode(bytes.NewReader(v)) + if err != nil { + return err + } + + ackedUpdates[seqNum] = backupID + + return nil + }) + if err != nil { + return nil, err + } + + return ackedUpdates, nil +} + +// putClientSessionBody stores the body of the ClientSession (everything but the +// CommittedUpdates and AckedUpdates). +func putClientSessionBody(sessions *bbolt.Bucket, + session *ClientSession) error { + + sessionBkt, err := sessions.CreateBucketIfNotExists(session.ID[:]) + if err != nil { + return err + } + + var b bytes.Buffer + err = session.Encode(&b) + if err != nil { + return err + } + + return sessionBkt.Put(cSessionBody, b.Bytes()) +} + +// getChanSummary loads a ClientChanSummary for the passed chanID. +func getChanSummary(chanSummaries *bbolt.Bucket, + chanID lnwire.ChannelID) (*ClientChanSummary, error) { + + chanSummaryBytes := chanSummaries.Get(chanID[:]) + if chanSummaryBytes == nil { + return nil, ErrChannelNotRegistered + } + + var summary ClientChanSummary + err := summary.Decode(bytes.NewReader(chanSummaryBytes)) + if err != nil { + return nil, err + } + + return &summary, nil +} + +// putChanSummary stores a ClientChanSummary for the passed chanID. +func putChanSummary(chanSummaries *bbolt.Bucket, chanID lnwire.ChannelID, + summary *ClientChanSummary) error { + + var b bytes.Buffer + err := summary.Encode(&b) + if err != nil { + return err + } + + return chanSummaries.Put(chanID[:], b.Bytes()) +} + +// getTower loads a Tower identified by its serialized tower id. +func getTower(towers *bbolt.Bucket, id []byte) (*Tower, error) { + towerBytes := towers.Get(id) + if towerBytes == nil { + return nil, ErrTowerNotFound + } + + var tower Tower + err := tower.Decode(bytes.NewReader(towerBytes)) + if err != nil { + return nil, err + } + + tower.ID = TowerIDFromBytes(id) + + return &tower, nil +} + +// putTower stores a Tower identified by its serialized tower id. +func putTower(towers *bbolt.Bucket, tower *Tower) error { + var b bytes.Buffer + err := tower.Encode(&b) + if err != nil { + return err + } + + return towers.Put(tower.ID.Bytes(), b.Bytes()) +} diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go new file mode 100644 index 00000000..66fd8a4e --- /dev/null +++ b/watchtower/wtdb/client_db_test.go @@ -0,0 +1,688 @@ +package wtdb_test + +import ( + "bytes" + crand "crypto/rand" + "io" + "io/ioutil" + "net" + "os" + "reflect" + "testing" + + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/watchtower/blob" + "github.com/lightningnetwork/lnd/watchtower/wtclient" + "github.com/lightningnetwork/lnd/watchtower/wtdb" + "github.com/lightningnetwork/lnd/watchtower/wtmock" + "github.com/lightningnetwork/lnd/watchtower/wtpolicy" +) + +// clientDBInit is a closure used to initialize a wtclient.DB instance its +// cleanup function. +type clientDBInit func(t *testing.T) (wtclient.DB, func()) + +type clientDBHarness struct { + t *testing.T + db wtclient.DB +} + +func newClientDBHarness(t *testing.T, init clientDBInit) (*clientDBHarness, func()) { + db, cleanup := init(t) + + h := &clientDBHarness{ + t: t, + db: db, + } + + return h, cleanup +} + +func (h *clientDBHarness) insertSession(session *wtdb.ClientSession, expErr error) { + h.t.Helper() + + err := h.db.CreateClientSession(session) + if err != expErr { + h.t.Fatalf("expected create client session error: %v, got: %v", + expErr, err) + } +} + +func (h *clientDBHarness) listSessions() map[wtdb.SessionID]*wtdb.ClientSession { + h.t.Helper() + + sessions, err := h.db.ListClientSessions() + if err != nil { + h.t.Fatalf("unable to list client sessions: %v", err) + } + + return sessions +} + +func (h *clientDBHarness) nextKeyIndex(id wtdb.TowerID, expErr error) uint32 { + h.t.Helper() + + index, err := h.db.NextSessionKeyIndex(id) + if err != expErr { + h.t.Fatalf("expected next session key index error: %v, got: %v", + expErr, err) + } + + if index == 0 { + h.t.Fatalf("next key index should never be 0") + } + + return index +} + +func (h *clientDBHarness) createTower(lnAddr *lnwire.NetAddress, + expErr error) *wtdb.Tower { + + h.t.Helper() + + tower, err := h.db.CreateTower(lnAddr) + if err != expErr { + h.t.Fatalf("expected create tower error: %v, got: %v", expErr, err) + } + + if tower.ID == 0 { + h.t.Fatalf("tower id should never be 0") + } + + return tower +} + +func (h *clientDBHarness) loadTower(id wtdb.TowerID, expErr error) *wtdb.Tower { + h.t.Helper() + + tower, err := h.db.LoadTower(id) + if err != expErr { + h.t.Fatalf("expected load tower error: %v, got: %v", expErr, err) + } + + return tower +} + +func (h *clientDBHarness) fetchChanSummaries() map[lnwire.ChannelID]wtdb.ClientChanSummary { + h.t.Helper() + + summaries, err := h.db.FetchChanSummaries() + if err != nil { + h.t.Fatalf("unable to fetch chan summaries: %v", err) + } + + return summaries +} + +func (h *clientDBHarness) registerChan(chanID lnwire.ChannelID, + sweepPkScript []byte, expErr error) { + + h.t.Helper() + + err := h.db.RegisterChannel(chanID, sweepPkScript) + if err != expErr { + h.t.Fatalf("expected register channel error: %v, got: %v", + expErr, err) + } +} + +func (h *clientDBHarness) commitUpdate(id *wtdb.SessionID, + update *wtdb.CommittedUpdate, expErr error) uint16 { + + h.t.Helper() + + lastApplied, err := h.db.CommitUpdate(id, update) + if err != expErr { + h.t.Fatalf("expected commit update error: %v, got: %v", + expErr, err) + } + + return lastApplied +} + +func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16, + lastApplied uint16, expErr error) { + + h.t.Helper() + + err := h.db.AckUpdate(id, seqNum, lastApplied) + if err != expErr { + h.t.Fatalf("expected commit update error: %v, got: %v", + expErr, err) + } +} + +// testCreateClientSession asserts various conditions regarding the creation of +// a new ClientSession. The test asserts: +// - client sessions can only be created if a session key index is reserved. +// - client sessions cannot be created with an incorrect session key index . +// - inserting duplicate sessions fails. +func testCreateClientSession(h *clientDBHarness) { + // Create a test client session to insert. + session := &wtdb.ClientSession{ + ClientSessionBody: wtdb.ClientSessionBody{ + TowerID: wtdb.TowerID(3), + Policy: wtpolicy.Policy{ + MaxUpdates: 100, + }, + RewardPkScript: []byte{0x01, 0x02, 0x03}, + }, + ID: wtdb.SessionID([33]byte{0x01}), + } + + // First, assert that this session is not already present in the + // database. + if _, ok := h.listSessions()[session.ID]; ok { + h.t.Fatalf("session for id %x should not exist yet", session.ID) + } + + // Attempting to insert the client session without reserving a session + // key index should fail. + h.insertSession(session, wtdb.ErrNoReservedKeyIndex) + + // Now, reserve a session key for this tower. + keyIndex := h.nextKeyIndex(session.TowerID, nil) + + // The client session hasn't been updated with the reserved key index + // (since it's still zero). Inserting should fail due to the mismatch. + h.insertSession(session, wtdb.ErrIncorrectKeyIndex) + + // Reserve another key for the same index. Since no session has been + // successfully created, it should return the same index to maintain + // idempotency across restarts. + keyIndex2 := h.nextKeyIndex(session.TowerID, nil) + if keyIndex != keyIndex2 { + h.t.Fatalf("next key index should be idempotent: want: %v, "+ + "got %v", keyIndex, keyIndex2) + } + + // Now, set the client session's key index so that it is proper and + // insert it. This should succeed. + session.KeyIndex = keyIndex + h.insertSession(session, nil) + + // Verify that the session now exists in the database. + if _, ok := h.listSessions()[session.ID]; !ok { + h.t.Fatalf("session for id %x should exist now", session.ID) + } + + // Attempt to insert the session again, which should fail due to the + // session already existing. + h.insertSession(session, wtdb.ErrClientSessionAlreadyExists) + + // Finally, assert that reserving another key index succeeds with a + // different key index, now that the first one has been finalized. + keyIndex3 := h.nextKeyIndex(session.TowerID, nil) + if keyIndex == keyIndex3 { + h.t.Fatalf("key index still reserved after creating session") + } +} + +// testCreateTower asserts the behavior of creating new Tower objects within the +// database, and that the latest address is always prepended to the list of +// known addresses for the tower. +func testCreateTower(h *clientDBHarness) { + // Test that loading a tower with an arbitrary tower id fails. + h.loadTower(20, wtdb.ErrTowerNotFound) + + pk, err := randPubKey() + if err != nil { + h.t.Fatalf("unable to generate pubkey: %v", err) + } + + addr1 := &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911} + lnAddr := &lnwire.NetAddress{ + IdentityKey: pk, + Address: addr1, + } + + // Insert a random tower into the database. + tower := h.createTower(lnAddr, nil) + + // Load the tower from the database and assert that it matches the tower + // we created. + tower2 := h.loadTower(tower.ID, nil) + if !reflect.DeepEqual(tower, tower2) { + h.t.Fatalf("loaded tower mismatch, want: %v, got: %v", + tower, tower2) + } + + // Insert the address again into the database. Since the address is the + // same, this should result in an unmodified tower record. + towerDupAddr := h.createTower(lnAddr, nil) + if len(towerDupAddr.Addresses) != 1 { + h.t.Fatalf("duplicate address should be deduped") + } + if !reflect.DeepEqual(tower, towerDupAddr) { + h.t.Fatalf("mismatch towers, want: %v, got: %v", + tower, towerDupAddr) + } + + // Generate a new address for this tower. + addr2 := &net.TCPAddr{IP: []byte{0x02, 0x00, 0x00, 0x00}, Port: 9911} + + lnAddr2 := &lnwire.NetAddress{ + IdentityKey: pk, + Address: addr2, + } + + // Insert the updated address, which should produce a tower with a new + // address. + towerNewAddr := h.createTower(lnAddr2, nil) + + // Load the tower from the database, and assert that it matches the + // tower returned from creation. + towerNewAddr2 := h.loadTower(tower.ID, nil) + if !reflect.DeepEqual(towerNewAddr, towerNewAddr2) { + h.t.Fatalf("loaded tower mismatch, want: %v, got: %v", + towerNewAddr, towerNewAddr2) + } + + // Assert that there are now two addresses on the tower object. + if len(towerNewAddr.Addresses) != 2 { + h.t.Fatalf("new address should be added") + } + + // Finally, assert that the new address was prepended since it is deemed + // fresher. + if !reflect.DeepEqual(tower.Addresses, towerNewAddr.Addresses[1:]) { + h.t.Fatalf("new address should be prepended") + } +} + +// testChanSummaries tests the process of a registering a channel and its +// associated sweep pkscript. +func testChanSummaries(h *clientDBHarness) { + // First, assert that this channel is not already registered. + var chanID lnwire.ChannelID + if _, ok := h.fetchChanSummaries()[chanID]; ok { + h.t.Fatalf("pkscript for channel %x should not exist yet", + chanID) + } + + // Generate a random sweep pkscript and register it for this channel. + expPkScript := make([]byte, 22) + if _, err := io.ReadFull(crand.Reader, expPkScript); err != nil { + h.t.Fatalf("unable to generate pkscript: %v", err) + } + h.registerChan(chanID, expPkScript, nil) + + // Assert that the channel exists and that its sweep pkscript matches + // the one we registered. + summary, ok := h.fetchChanSummaries()[chanID] + if !ok { + h.t.Fatalf("pkscript for channel %x should not exist yet", + chanID) + } else if bytes.Compare(expPkScript, summary.SweepPkScript) != 0 { + h.t.Fatalf("pkscript mismatch, want: %x, got: %x", + expPkScript, summary.SweepPkScript) + } + + // Finally, assert that re-registering the same channel produces a + // failure. + h.registerChan(chanID, expPkScript, wtdb.ErrChannelAlreadyRegistered) +} + +// testCommitUpdate tests the behavior of CommitUpdate, ensuring that they can +func testCommitUpdate(h *clientDBHarness) { + session := &wtdb.ClientSession{ + ClientSessionBody: wtdb.ClientSessionBody{ + TowerID: wtdb.TowerID(3), + Policy: wtpolicy.Policy{ + MaxUpdates: 100, + }, + RewardPkScript: []byte{0x01, 0x02, 0x03}, + }, + ID: wtdb.SessionID([33]byte{0x02}), + } + + // Generate a random update and try to commit before inserting the + // session, which should fail. + update1 := randCommittedUpdate(h.t, 1) + h.commitUpdate(&session.ID, update1, wtdb.ErrClientSessionNotFound) + + // Reserve a session key index and insert the session. + session.KeyIndex = h.nextKeyIndex(session.TowerID, nil) + h.insertSession(session, nil) + + // Now, try to commit the update that failed initially which should + // succeed. The lastApplied value should be 0 since we have not received + // an ack from the tower. + lastApplied := h.commitUpdate(&session.ID, update1, nil) + if lastApplied != 0 { + h.t.Fatalf("last applied mismatch, want: 0, got: %v", + lastApplied) + } + + // Assert that the committed update appears in the client session's + // CommittedUpdates map when loaded from disk and that there are no + // AckedUpdates. + dbSession := h.listSessions()[session.ID] + checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{ + *update1, + }) + checkAckedUpdates(h.t, dbSession, nil) + + // Try to commit the same update, which should succeed due to + // idempotency (which is preserved when the breach hint is identical to + // the on-disk update's hint). The lastApplied value should remain + // unchanged. + lastApplied2 := h.commitUpdate(&session.ID, update1, nil) + if lastApplied2 != lastApplied { + h.t.Fatalf("last applied should not have changed, got %v", + lastApplied2) + } + + // Assert that the loaded ClientSession is the same as before. + dbSession = h.listSessions()[session.ID] + checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{ + *update1, + }) + checkAckedUpdates(h.t, dbSession, nil) + + // Generate another random update and try to commit it at the identical + // sequence number. Since the breach hint has changed, this should fail. + update2 := randCommittedUpdate(h.t, 1) + h.commitUpdate(&session.ID, update2, wtdb.ErrUpdateAlreadyCommitted) + + // Next, insert the new update at the next unallocated sequence number + // which should succeed. + update2.SeqNum = 2 + lastApplied3 := h.commitUpdate(&session.ID, update2, nil) + if lastApplied3 != lastApplied { + h.t.Fatalf("last applied should not have changed, got %v", + lastApplied3) + } + + // Check that both updates now appear as committed on the ClientSession + // loaded from disk. + dbSession = h.listSessions()[session.ID] + checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{ + *update1, + *update2, + }) + checkAckedUpdates(h.t, dbSession, nil) + + // Finally, create one more random update and try to commit it at index + // 4, which should be rejected since 3 is the next slot the database + // expects. + update4 := randCommittedUpdate(h.t, 4) + h.commitUpdate(&session.ID, update4, wtdb.ErrCommitUnorderedUpdate) + + // Assert that the ClientSession loaded from disk remains unchanged. + dbSession = h.listSessions()[session.ID] + checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{ + *update1, + *update2, + }) + checkAckedUpdates(h.t, dbSession, nil) +} + +// testAckUpdate asserts the behavior of AckUpdate. +func testAckUpdate(h *clientDBHarness) { + // Create a new session that the updates in this will be tied to. + session := &wtdb.ClientSession{ + ClientSessionBody: wtdb.ClientSessionBody{ + TowerID: wtdb.TowerID(3), + Policy: wtpolicy.Policy{ + MaxUpdates: 100, + }, + RewardPkScript: []byte{0x01, 0x02, 0x03}, + }, + ID: wtdb.SessionID([33]byte{0x03}), + } + + // Try to ack an update before inserting the client session, which + // should fail. + h.ackUpdate(&session.ID, 1, 0, wtdb.ErrClientSessionNotFound) + + // Reserve a session key and insert the client session. + session.KeyIndex = h.nextKeyIndex(session.TowerID, nil) + h.insertSession(session, nil) + + // Now, try to ack update 1. This should fail since update 1 was never + // committed. + h.ackUpdate(&session.ID, 1, 0, wtdb.ErrCommittedUpdateNotFound) + + // Commit to a random update at seqnum 1. + update1 := randCommittedUpdate(h.t, 1) + lastApplied := h.commitUpdate(&session.ID, update1, nil) + if lastApplied != 0 { + h.t.Fatalf("last applied mismatch, want: 0, got: %v", + lastApplied) + } + + // Acking seqnum 1 should succeed. + h.ackUpdate(&session.ID, 1, 1, nil) + + // Acking seqnum 1 again should fail. + h.ackUpdate(&session.ID, 1, 1, wtdb.ErrCommittedUpdateNotFound) + + // Acking a valid seqnum with a reverted last applied value should fail. + h.ackUpdate(&session.ID, 1, 0, wtdb.ErrLastAppliedReversion) + + // Acking with a last applied greater than any allocated seqnum should + // fail. + h.ackUpdate(&session.ID, 4, 3, wtdb.ErrUnallocatedLastApplied) + + // Assert that the ClientSession loaded from disk has one update in it's + // AckedUpdates map, and that the committed update has been removed. + dbSession := h.listSessions()[session.ID] + checkCommittedUpdates(h.t, dbSession, nil) + checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{ + 1: update1.BackupID, + }) + + // Commit to another random update, and assert that the last applied + // value is 1, since this was what was provided in the last successful + // ack. + update2 := randCommittedUpdate(h.t, 2) + lastApplied = h.commitUpdate(&session.ID, update2, nil) + if lastApplied != 1 { + h.t.Fatalf("last applied mismatch, want: 1, got: %v", + lastApplied) + } + + // Ack seqnum 2. + h.ackUpdate(&session.ID, 2, 2, nil) + + // Assert that both updates exist as AckedUpdates when loaded from disk. + dbSession = h.listSessions()[session.ID] + checkCommittedUpdates(h.t, dbSession, nil) + checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{ + 1: update1.BackupID, + 2: update2.BackupID, + }) + + // Acking again with a lower last applied should fail. + h.ackUpdate(&session.ID, 2, 1, wtdb.ErrLastAppliedReversion) + + // Acking an unallocated seqnum should fail. + h.ackUpdate(&session.ID, 4, 2, wtdb.ErrCommittedUpdateNotFound) + + // Acking with a last applied greater than any allocated seqnum should + // fail. + h.ackUpdate(&session.ID, 4, 3, wtdb.ErrUnallocatedLastApplied) +} + +// checkCommittedUpdates asserts that the CommittedUpdates on session match the +// expUpdates provided. +func checkCommittedUpdates(t *testing.T, session *wtdb.ClientSession, + expUpdates []wtdb.CommittedUpdate) { + + t.Helper() + + // We promote nil expUpdates to an initialized slice since the database + // should never return a nil slice. This promotion is done purely out of + // convenience for the testing framework. + if expUpdates == nil { + expUpdates = make([]wtdb.CommittedUpdate, 0) + } + + if !reflect.DeepEqual(session.CommittedUpdates, expUpdates) { + t.Fatalf("committed updates mismatch, want: %v, got: %v", + expUpdates, session.CommittedUpdates) + } +} + +// checkAckedUpdates asserts that the AckedUpdates on a sessio match the +// expUpdates provided. +func checkAckedUpdates(t *testing.T, session *wtdb.ClientSession, + expUpdates map[uint16]wtdb.BackupID) { + + // We promote nil expUpdates to an initialized map since the database + // should never return a nil map. This promotion is done purely out of + // convenience for the testing framework. + if expUpdates == nil { + expUpdates = make(map[uint16]wtdb.BackupID) + } + + if !reflect.DeepEqual(session.AckedUpdates, expUpdates) { + t.Fatalf("acked updates mismatch, want: %v, got: %v", + expUpdates, session.AckedUpdates) + } +} + +// TestClientDB asserts the behavior of a fresh client db, a reopened client db, +// and the mock implementation. This ensures that all databases function +// identically, especially in the negative paths. +func TestClientDB(t *testing.T) { + dbs := []struct { + name string + init clientDBInit + }{ + { + name: "fresh clientdb", + init: func(t *testing.T) (wtclient.DB, func()) { + path, err := ioutil.TempDir("", "clientdb") + if err != nil { + t.Fatalf("unable to make temp dir: %v", + err) + } + + db, err := wtdb.OpenClientDB(path) + if err != nil { + os.RemoveAll(path) + t.Fatalf("unable to open db: %v", err) + } + + cleanup := func() { + db.Close() + os.RemoveAll(path) + } + + return db, cleanup + }, + }, + { + name: "reopened clientdb", + init: func(t *testing.T) (wtclient.DB, func()) { + path, err := ioutil.TempDir("", "clientdb") + if err != nil { + t.Fatalf("unable to make temp dir: %v", + err) + } + + db, err := wtdb.OpenClientDB(path) + if err != nil { + os.RemoveAll(path) + t.Fatalf("unable to open db: %v", err) + } + db.Close() + + db, err = wtdb.OpenClientDB(path) + if err != nil { + os.RemoveAll(path) + t.Fatalf("unable to reopen db: %v", err) + } + + cleanup := func() { + db.Close() + os.RemoveAll(path) + } + + return db, cleanup + }, + }, + { + name: "mock", + init: func(t *testing.T) (wtclient.DB, func()) { + return wtmock.NewClientDB(), func() {} + }, + }, + } + + tests := []struct { + name string + run func(*clientDBHarness) + }{ + { + name: "create client session", + run: testCreateClientSession, + }, + { + name: "create tower", + run: testCreateTower, + }, + { + name: "chan summaries", + run: testChanSummaries, + }, + { + name: "commit update", + run: testCommitUpdate, + }, + { + name: "ack update", + run: testAckUpdate, + }, + } + + for _, database := range dbs { + db := database + t.Run(db.name, func(t *testing.T) { + t.Parallel() + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + h, cleanup := newClientDBHarness( + t, db.init, + ) + defer cleanup() + + test.run(h) + }) + } + }) + } +} + +// randCommittedUpdate generates a random committed update. +func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate { + var chanID lnwire.ChannelID + if _, err := io.ReadFull(crand.Reader, chanID[:]); err != nil { + t.Fatalf("unable to generate chan id: %v", err) + } + + var hint wtdb.BreachHint + if _, err := io.ReadFull(crand.Reader, hint[:]); err != nil { + t.Fatalf("unable to generate breach hint: %v", err) + } + + encBlob := make([]byte, blob.Size(blob.FlagCommitOutputs.Type())) + if _, err := io.ReadFull(crand.Reader, encBlob); err != nil { + t.Fatalf("unable to generate encrypted blob: %v", err) + } + + return &wtdb.CommittedUpdate{ + SeqNum: seqNum, + CommittedUpdateBody: wtdb.CommittedUpdateBody{ + BackupID: wtdb.BackupID{ + ChanID: chanID, + CommitHeight: 666, + }, + Hint: hint, + EncryptedBlob: encBlob, + }, + } +} diff --git a/watchtower/wtdb/client_session.go b/watchtower/wtdb/client_session.go index d29e1f5f..34e2168b 100644 --- a/watchtower/wtdb/client_session.go +++ b/watchtower/wtdb/client_session.go @@ -1,7 +1,6 @@ package wtdb import ( - "errors" "io" "github.com/btcsuite/btcd/btcec" @@ -9,44 +8,6 @@ import ( "github.com/lightningnetwork/lnd/watchtower/wtpolicy" ) -var ( - // ErrClientSessionNotFound signals that the requested client session - // was not found in the database. - ErrClientSessionNotFound = errors.New("client session not found") - - // ErrUpdateAlreadyCommitted signals that the chosen sequence number has - // already been committed to an update with a different breach hint. - ErrUpdateAlreadyCommitted = errors.New("update already committed") - - // ErrCommitUnorderedUpdate signals the client tried to commit a - // sequence number other than the next unallocated sequence number. - ErrCommitUnorderedUpdate = errors.New("update seqnum not monotonic") - - // ErrCommittedUpdateNotFound signals that the tower tried to ACK a - // sequence number that has not yet been allocated by the client. - ErrCommittedUpdateNotFound = errors.New("committed update not found") - - // ErrUnallocatedLastApplied signals that the tower tried to provide a - // LastApplied value greater than any allocated sequence number. - ErrUnallocatedLastApplied = errors.New("tower echoed last appiled " + - "greater than allocated seqnum") - - // ErrNoReservedKeyIndex signals that a client session could not be - // created because no session key index was reserved. - ErrNoReservedKeyIndex = errors.New("key index not reserved") - - // ErrIncorrectKeyIndex signals that the client session could not be - // created because session key index differs from the reserved key - // index. - ErrIncorrectKeyIndex = errors.New("incorrect key index") - - // ErrClientSessionAlreadyExists signals an attempt to reinsert - // a client session that has already been created. - ErrClientSessionAlreadyExists = errors.New( - "client session already exists", - ) -) - // ClientSession encapsulates a SessionInfo returned from a successful // session negotiation, and also records the tower and ephemeral secret used for // communicating with the tower. diff --git a/watchtower/wtdb/tower.go b/watchtower/wtdb/tower.go index 518da750..426a6e83 100644 --- a/watchtower/wtdb/tower.go +++ b/watchtower/wtdb/tower.go @@ -1,7 +1,6 @@ package wtdb import ( - "errors" "io" "net" @@ -9,12 +8,6 @@ import ( "github.com/lightningnetwork/lnd/lnwire" ) -var ( - // ErrTowerNotFound signals that the target tower was not found in the - // database. - ErrTowerNotFound = errors.New("tower not found") -) - // TowerID is a unique 64-bit identifier allocated to each unique watchtower. // This allows the client to conserve on-disk space by not needing to always // reference towers by their pubkey. diff --git a/watchtower/wtdb/version.go b/watchtower/wtdb/version.go index 974f25b0..b8aa2b7e 100644 --- a/watchtower/wtdb/version.go +++ b/watchtower/wtdb/version.go @@ -21,6 +21,11 @@ type version struct { // migrations must be applied. var towerDBVersions = []version{} +// clientDBVersions stores all versions and migrations of the client database. +// This list will be used when opening the database to determine if any +// migrations must be applied. +var clientDBVersions = []version{} + // getLatestDBVersion returns the last known database version. func getLatestDBVersion(versions []version) uint32 { return uint32(len(versions)) From 9157c88f9328f5a2a5b76aed6b5156a84b9ff5d1 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Thu, 23 May 2019 20:49:18 -0700 Subject: [PATCH 10/11] watchtower/wtclient: dedup backups across restarts Now that the committed and acked updates are persisted across restarts, we will use them to filter out duplicate commit heights presented by the client. --- watchtower/wtclient/client.go | 70 +++++++++++++++++++++++++++--- watchtower/wtclient/client_test.go | 49 +++++++++++++++++++++ 2 files changed, 113 insertions(+), 6 deletions(-) diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 6a037570..8f0cbc9f 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -150,8 +150,9 @@ type TowerClient struct { sessionQueue *sessionQueue prevTask *backupTask - summaryMu sync.RWMutex - summaries wtdb.ChannelSummaries + backupMu sync.Mutex + summaries wtdb.ChannelSummaries + chanCommitHeights map[lnwire.ChannelID]uint64 statTicker *time.Ticker stats clientStats @@ -243,6 +244,10 @@ func New(config *Config) (*TowerClient, error) { s.SessionPrivKey = sessionPriv } + // Reconstruct the highest commit height processed for each channel + // under the client's current policy. + c.buildHighestCommitHeights() + // Finally, load the sweep pkscripts that have been generated for all // previously registered channels. c.summaries, err = c.cfg.DB.FetchChanSummaries() @@ -253,6 +258,44 @@ func New(config *Config) (*TowerClient, error) { return c, nil } +// buildHighestCommitHeights inspects the full set of candidate client sessions +// loaded from disk, and determines the highest known commit height for each +// channel. This allows the client to reject backups that it has already +// processed for it's active policy. +func (c *TowerClient) buildHighestCommitHeights() { + chanCommitHeights := make(map[lnwire.ChannelID]uint64) + for _, s := range c.candidateSessions { + // We only want to consider accepted updates that have been + // accepted under an identical policy to the client's current + // policy. + if s.Policy != c.cfg.Policy { + continue + } + + // Take the highest commit height found in the session's + // committed updates. + for _, committedUpdate := range s.CommittedUpdates { + bid := committedUpdate.BackupID + + height, ok := chanCommitHeights[bid.ChanID] + if !ok || bid.CommitHeight > height { + chanCommitHeights[bid.ChanID] = bid.CommitHeight + } + } + + // Take the heights commit height found in the session's acked + // updates. + for _, bid := range s.AckedUpdates { + height, ok := chanCommitHeights[bid.ChanID] + if !ok || bid.CommitHeight > height { + chanCommitHeights[bid.ChanID] = bid.CommitHeight + } + } + } + + c.chanCommitHeights = chanCommitHeights +} + // Start initializes the watchtower client by loading or negotiating an active // session and then begins processing backup tasks from the request pipeline. func (c *TowerClient) Start() error { @@ -388,8 +431,8 @@ func (c *TowerClient) ForceQuit() { // within the client. This should be called during link startup to ensure that // the client is able to support the link during operation. func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error { - c.summaryMu.Lock() - defer c.summaryMu.Unlock() + c.backupMu.Lock() + defer c.backupMu.Unlock() // If a pkscript for this channel already exists, the channel has been // previously registered. @@ -431,13 +474,28 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, breachInfo *lnwallet.BreachRetribution) error { // Retrieve the cached sweep pkscript used for this channel. - c.summaryMu.RLock() + c.backupMu.Lock() summary, ok := c.summaries[*chanID] - c.summaryMu.RUnlock() if !ok { + c.backupMu.Unlock() return ErrUnregisteredChannel } + // Ignore backups that have already been presented to the client. + height, ok := c.chanCommitHeights[*chanID] + if ok && breachInfo.RevokedStateNum <= height { + c.backupMu.Unlock() + log.Debugf("Ignoring duplicate backup for chanid=%v at height=%d", + chanID, breachInfo.RevokedStateNum) + return nil + } + + // This backup has a higher commit height than any known backup for this + // channel. We'll update our tip so that we won't accept it again if the + // link flaps. + c.chanCommitHeights[*chanID] = breachInfo.RevokedStateNum + c.backupMu.Unlock() + task := newBackupTask(chanID, breachInfo, summary.SweepPkScript) return c.pipeline.QueueBackupTask(task) diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index ac4ebf2d..b5a9bbbd 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -1246,6 +1246,55 @@ var clientTests = []clientTest{ h.assertUpdatesForPolicy(hints, h.clientCfg.Policy) }, }, + { + // Asserts that the client will deduplicate backups presented by + // a channel both in memory and after a restart. The client + // should only accept backups with a commit height greater than + // any processed already processed for a given policy. + name: "dedup backups", + cfg: harnessCfg{ + localBalance: localBalance, + remoteBalance: remoteBalance, + policy: wtpolicy.Policy{ + BlobType: blob.TypeDefault, + MaxUpdates: 5, + SweepFeeRate: 1, + }, + }, + fn: func(h *testHarness) { + const ( + numUpdates = 10 + chanID = 0 + ) + + // Generate the retributions that will be backed up. + hints := h.advanceChannelN(chanID, numUpdates) + + // Queue the first half of the retributions twice, the + // second batch should be entirely deduped by the + // client's in-memory tracking. + h.backupStates(chanID, 0, numUpdates/2, nil) + h.backupStates(chanID, 0, numUpdates/2, nil) + + // Wait for the first half of the updates to be + // populated in the server's database. + h.waitServerUpdates(hints[:len(hints)/2], 5*time.Second) + + // Restart the client, so we can ensure the deduping is + // maintained across restarts. + h.client.Stop() + h.startClient() + defer h.client.ForceQuit() + + // Try to back up the full range of retributions. Only + // the second half should actually be sent. + h.backupStates(chanID, 0, numUpdates, nil) + + // Wait for all of the updates to be populated in the + // server's database. + h.waitServerUpdates(hints, 5*time.Second) + }, + }, } // TestClient executes the client test suite, asserting the ability to backup From 28bf49807e0722549e45c82eeca72d58ea418871 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 24 May 2019 17:04:07 -0700 Subject: [PATCH 11/11] watchtower/wtdb: add CSessionStatus field to ClientSession This commit adds persisted status bit-field to ClientSessions, that can be used to modify behavior of their handling in the client. Currently, only a default CSessionActive status is defined. However, the intention is that this could later be used to signal that a session is abandoned without needing to perform a db migration to add the field. As we move forward with testing, this will likely be useful if a session gets borked and we need a simple method of the client to temporarily ignore certain sessions. The field may be useful in signaling other types of status changes, though this was the primary motivation that warranted the addition. --- channeldb/codec.go | 10 ++++++++++ watchtower/wtdb/client_session.go | 21 ++++++++++++++++++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/channeldb/codec.go b/channeldb/codec.go index ec6e165b..ca5cfeed 100644 --- a/channeldb/codec.go +++ b/channeldb/codec.go @@ -128,6 +128,11 @@ func WriteElement(w io.Writer, element interface{}) error { return err } + case uint8: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + case bool: if err := binary.Write(w, byteOrder, e); err != nil { return err @@ -289,6 +294,11 @@ func ReadElement(r io.Reader, element interface{}) error { return err } + case *uint8: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + case *bool: if err := binary.Read(r, byteOrder, e); err != nil { return err diff --git a/watchtower/wtdb/client_session.go b/watchtower/wtdb/client_session.go index 34e2168b..cb59ca57 100644 --- a/watchtower/wtdb/client_session.go +++ b/watchtower/wtdb/client_session.go @@ -8,6 +8,16 @@ import ( "github.com/lightningnetwork/lnd/watchtower/wtpolicy" ) +// CSessionStatus is a bit-field representing the possible statuses of +// ClientSessions. +type CSessionStatus uint8 + +const ( + // CSessionActive indicates that the ClientSession is active and can be + // used for backups. + CSessionActive CSessionStatus = 0 +) + // ClientSession encapsulates a SessionInfo returned from a successful // session negotiation, and also records the tower and ephemeral secret used for // communicating with the tower. @@ -76,6 +86,9 @@ type ClientSessionBody struct { // Policy holds the negotiated session parameters. Policy wtpolicy.Policy + // Status indicates the current state of the ClientSession. + Status CSessionStatus + // RewardPkScript is the pkscript that the tower's reward will be // deposited to if a sweep transaction confirms and the sessions // specifies a reward output. @@ -89,6 +102,7 @@ func (s *ClientSessionBody) Encode(w io.Writer) error { s.TowerLastApplied, uint64(s.TowerID), s.KeyIndex, + uint8(s.Status), s.Policy, s.RewardPkScript, ) @@ -96,12 +110,16 @@ func (s *ClientSessionBody) Encode(w io.Writer) error { // Decode reads a ClientSessionBody from the passed io.Reader. func (s *ClientSessionBody) Decode(r io.Reader) error { - var towerID uint64 + var ( + towerID uint64 + status uint8 + ) err := ReadElements(r, &s.SeqNum, &s.TowerLastApplied, &towerID, &s.KeyIndex, + &status, &s.Policy, &s.RewardPkScript, ) @@ -110,6 +128,7 @@ func (s *ClientSessionBody) Decode(r io.Reader) error { } s.TowerID = TowerID(towerID) + s.Status = CSessionStatus(status) return nil }