From 01ab551b221c142dfc2a639e7f0acf8e3a85040d Mon Sep 17 00:00:00 2001 From: Wilmer Paulino Date: Mon, 11 May 2020 15:23:43 -0700 Subject: [PATCH] wtclient: refactor existing candidate session filtering into method --- watchtower/wtclient/client.go | 95 +++++++++++++++++++++-------------- 1 file changed, 58 insertions(+), 37 deletions(-) diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 76aa2a4b..8a37abe5 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -38,6 +38,14 @@ const ( DefaultForceQuitDelay = 10 * time.Second ) +var ( + // activeSessionFilter is a filter that ignored any sessions which are + // not active. + activeSessionFilter = func(s *wtdb.ClientSession) bool { + return s.Status == wtdb.CSessionActive + } +) + // RegisteredTower encompasses information about a registered watchtower with // the client. type RegisteredTower struct { @@ -268,49 +276,18 @@ func New(config *Config) (*TowerClient, error) { // the client. We will use any of these session if their policies match // the current policy of the client, otherwise they will be ignored and // new sessions will be requested. - sessions, err := cfg.DB.ListClientSessions(nil) + candidateSessions, err := getClientSessions( + cfg.DB, cfg.SecretKeyRing, nil, activeSessionFilter, + ) if err != nil { return nil, err } - candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession) - sessionTowers := make(map[wtdb.TowerID]*wtdb.Tower) - for _, s := range sessions { - // Candidate sessions must be in an active state. - if s.Status != wtdb.CSessionActive { - continue - } - - // Reload the tower from disk using the tower ID contained in - // each candidate session. We will also rederive any session - // keys needed to be able to communicate with the towers and - // authenticate session requests. This prevents us from having - // to store the private keys on disk. - tower, ok := sessionTowers[s.TowerID] - if !ok { - var err error - tower, err = cfg.DB.LoadTowerByID(s.TowerID) - if err != nil { - return nil, err - } - } - s.Tower = tower - - sessionKey, err := DeriveSessionKey(cfg.SecretKeyRing, s.KeyIndex) - if err != nil { - return nil, err - } - s.SessionPrivKey = sessionKey - - candidateSessions[s.ID] = s - sessionTowers[tower.ID] = tower - } - var candidateTowers []*wtdb.Tower - for _, tower := range sessionTowers { + for _, s := range candidateSessions { log.Infof("Using private watchtower %s, offering policy %s", - tower, cfg.Policy) - candidateTowers = append(candidateTowers, tower) + s.Tower, cfg.Policy) + candidateTowers = append(candidateTowers, s.Tower) } // Load the sweep pkscripts that have been generated for all previously @@ -353,6 +330,50 @@ func New(config *Config) (*TowerClient, error) { return c, nil } +// getClientSessions retrieves the client sessions for a particular tower if +// specified, otherwise all client sessions for all towers are retrieved. An +// optional filter can be provided to filter out any undesired client sessions. +// +// NOTE: This method should only be used when deserialization of a +// ClientSession's Tower and SessionPrivKey fields is desired, otherwise, the +// existing ListClientSessions method should be used. +func getClientSessions(db DB, keyRing SecretKeyRing, forTower *wtdb.TowerID, + passesFilter func(*wtdb.ClientSession) bool) ( + map[wtdb.SessionID]*wtdb.ClientSession, error) { + + sessions, err := db.ListClientSessions(forTower) + if err != nil { + return nil, err + } + + // Reload the tower from disk using the tower ID contained in each + // candidate session. We will also rederive any session keys needed to + // be able to communicate with the towers and authenticate session + // requests. This prevents us from having to store the private keys on + // disk. + for _, s := range sessions { + tower, err := db.LoadTowerByID(s.TowerID) + if err != nil { + return nil, err + } + s.Tower = tower + + sessionKey, err := DeriveSessionKey(keyRing, s.KeyIndex) + if err != nil { + return nil, err + } + s.SessionPrivKey = sessionKey + + // If an optional filter was provided, use it to filter out any + // undesired sessions. + if passesFilter != nil && !passesFilter(s) { + delete(sessions, s.ID) + } + } + + return sessions, nil +} + // buildHighestCommitHeights inspects the full set of candidate client sessions // loaded from disk, and determines the highest known commit height for each // channel. This allows the client to reject backups that it has already