From 7d99005dde705281485cbcd9896825fa531f705e Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Wed, 27 Mar 2019 16:50:59 -0700 Subject: [PATCH] watchtower/wtclient/interface: add LoadTower and mock impl --- watchtower/wtclient/client.go | 11 +++++++++++ watchtower/wtclient/interface.go | 3 +++ watchtower/wtdb/tower.go | 7 +++++++ watchtower/wtmock/client_db.go | 13 ++++++++++++- 4 files changed, 33 insertions(+), 1 deletion(-) diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 9bd8e9e8..51d3c61b 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -221,6 +221,17 @@ func New(config *Config) (*TowerClient, error) { return nil, err } + // Reload any towers from disk using the tower IDs contained in each + // candidate session. + for _, s := range c.candidateSessions { + tower, err := c.cfg.DB.LoadTower(s.TowerID) + if err != nil { + return nil, err + } + + s.Tower = tower + } + // Finally, load the sweep pkscripts that have been generated for all // previously registered channels. c.sweepPkScripts, err = c.cfg.DB.FetchChanPkScripts() diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index 5164acea..b770343e 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -19,6 +19,9 @@ type DB interface { // sessions. CreateTower(*lnwire.NetAddress) (*wtdb.Tower, error) + // LoadTower retrieves a tower by its tower ID. + LoadTower(uint64) (*wtdb.Tower, error) + // CreateClientSession saves a newly negotiated client session to the // client's database. This enables the session to be used across // restarts. diff --git a/watchtower/wtdb/tower.go b/watchtower/wtdb/tower.go index e7213cab..ff7a48df 100644 --- a/watchtower/wtdb/tower.go +++ b/watchtower/wtdb/tower.go @@ -1,6 +1,7 @@ package wtdb import ( + "errors" "net" "sync" @@ -8,6 +9,12 @@ import ( "github.com/lightningnetwork/lnd/lnwire" ) +var ( + // ErrTowerNotFound signals that the target tower was not found in the + // database. + ErrTowerNotFound = errors.New("tower not found") +) + // Tower holds the necessary components required to connect to a remote tower. // Communication is handled by brontide, and requires both a public key and an // address. diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 54f9a697..ad5ca79d 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -64,6 +64,18 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) { return tower, nil } +// LoadTower retrieves a tower by its tower ID. +func (m *ClientDB) LoadTower(towerID uint64) (*wtdb.Tower, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if tower, ok := m.towers[towerID]; ok { + return tower, nil + } + + return nil, wtdb.ErrTowerNotFound +} + // MarkBackupIneligible records that particular commit height is ineligible for // backup. This allows the client to track which updates it should not attempt // to retry after startup. @@ -92,7 +104,6 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { m.activeSessions[session.ID] = &wtdb.ClientSession{ TowerID: session.TowerID, - Tower: session.Tower, SessionKeyDesc: session.SessionKeyDesc, SessionPrivKey: session.SessionPrivKey, ID: session.ID,