wtclient: refactor existing candidate session filtering into method
This commit is contained in:
parent
8b09ac07d3
commit
01ab551b22
@ -38,6 +38,14 @@ const (
|
||||
DefaultForceQuitDelay = 10 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
// activeSessionFilter is a filter that ignored any sessions which are
|
||||
// not active.
|
||||
activeSessionFilter = func(s *wtdb.ClientSession) bool {
|
||||
return s.Status == wtdb.CSessionActive
|
||||
}
|
||||
)
|
||||
|
||||
// RegisteredTower encompasses information about a registered watchtower with
|
||||
// the client.
|
||||
type RegisteredTower struct {
|
||||
@ -268,49 +276,18 @@ func New(config *Config) (*TowerClient, error) {
|
||||
// the client. We will use any of these session if their policies match
|
||||
// the current policy of the client, otherwise they will be ignored and
|
||||
// new sessions will be requested.
|
||||
sessions, err := cfg.DB.ListClientSessions(nil)
|
||||
candidateSessions, err := getClientSessions(
|
||||
cfg.DB, cfg.SecretKeyRing, nil, activeSessionFilter,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
|
||||
sessionTowers := make(map[wtdb.TowerID]*wtdb.Tower)
|
||||
for _, s := range sessions {
|
||||
// Candidate sessions must be in an active state.
|
||||
if s.Status != wtdb.CSessionActive {
|
||||
continue
|
||||
}
|
||||
|
||||
// Reload the tower from disk using the tower ID contained in
|
||||
// each candidate session. We will also rederive any session
|
||||
// keys needed to be able to communicate with the towers and
|
||||
// authenticate session requests. This prevents us from having
|
||||
// to store the private keys on disk.
|
||||
tower, ok := sessionTowers[s.TowerID]
|
||||
if !ok {
|
||||
var err error
|
||||
tower, err = cfg.DB.LoadTowerByID(s.TowerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
s.Tower = tower
|
||||
|
||||
sessionKey, err := DeriveSessionKey(cfg.SecretKeyRing, s.KeyIndex)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.SessionPrivKey = sessionKey
|
||||
|
||||
candidateSessions[s.ID] = s
|
||||
sessionTowers[tower.ID] = tower
|
||||
}
|
||||
|
||||
var candidateTowers []*wtdb.Tower
|
||||
for _, tower := range sessionTowers {
|
||||
for _, s := range candidateSessions {
|
||||
log.Infof("Using private watchtower %s, offering policy %s",
|
||||
tower, cfg.Policy)
|
||||
candidateTowers = append(candidateTowers, tower)
|
||||
s.Tower, cfg.Policy)
|
||||
candidateTowers = append(candidateTowers, s.Tower)
|
||||
}
|
||||
|
||||
// Load the sweep pkscripts that have been generated for all previously
|
||||
@ -353,6 +330,50 @@ func New(config *Config) (*TowerClient, error) {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// getClientSessions retrieves the client sessions for a particular tower if
|
||||
// specified, otherwise all client sessions for all towers are retrieved. An
|
||||
// optional filter can be provided to filter out any undesired client sessions.
|
||||
//
|
||||
// NOTE: This method should only be used when deserialization of a
|
||||
// ClientSession's Tower and SessionPrivKey fields is desired, otherwise, the
|
||||
// existing ListClientSessions method should be used.
|
||||
func getClientSessions(db DB, keyRing SecretKeyRing, forTower *wtdb.TowerID,
|
||||
passesFilter func(*wtdb.ClientSession) bool) (
|
||||
map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
||||
|
||||
sessions, err := db.ListClientSessions(forTower)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Reload the tower from disk using the tower ID contained in each
|
||||
// candidate session. We will also rederive any session keys needed to
|
||||
// be able to communicate with the towers and authenticate session
|
||||
// requests. This prevents us from having to store the private keys on
|
||||
// disk.
|
||||
for _, s := range sessions {
|
||||
tower, err := db.LoadTowerByID(s.TowerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.Tower = tower
|
||||
|
||||
sessionKey, err := DeriveSessionKey(keyRing, s.KeyIndex)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.SessionPrivKey = sessionKey
|
||||
|
||||
// If an optional filter was provided, use it to filter out any
|
||||
// undesired sessions.
|
||||
if passesFilter != nil && !passesFilter(s) {
|
||||
delete(sessions, s.ID)
|
||||
}
|
||||
}
|
||||
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
// buildHighestCommitHeights inspects the full set of candidate client sessions
|
||||
// loaded from disk, and determines the highest known commit height for each
|
||||
// channel. This allows the client to reject backups that it has already
|
||||
|
Loading…
Reference in New Issue
Block a user