wtclient: refactor existing candidate session filtering into method
This commit is contained in:
parent
8b09ac07d3
commit
01ab551b22
@ -38,6 +38,14 @@ const (
|
|||||||
DefaultForceQuitDelay = 10 * time.Second
|
DefaultForceQuitDelay = 10 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// activeSessionFilter is a filter that ignored any sessions which are
|
||||||
|
// not active.
|
||||||
|
activeSessionFilter = func(s *wtdb.ClientSession) bool {
|
||||||
|
return s.Status == wtdb.CSessionActive
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
// RegisteredTower encompasses information about a registered watchtower with
|
// RegisteredTower encompasses information about a registered watchtower with
|
||||||
// the client.
|
// the client.
|
||||||
type RegisteredTower struct {
|
type RegisteredTower struct {
|
||||||
@ -268,49 +276,18 @@ func New(config *Config) (*TowerClient, error) {
|
|||||||
// the client. We will use any of these session if their policies match
|
// the client. We will use any of these session if their policies match
|
||||||
// the current policy of the client, otherwise they will be ignored and
|
// the current policy of the client, otherwise they will be ignored and
|
||||||
// new sessions will be requested.
|
// new sessions will be requested.
|
||||||
sessions, err := cfg.DB.ListClientSessions(nil)
|
candidateSessions, err := getClientSessions(
|
||||||
|
cfg.DB, cfg.SecretKeyRing, nil, activeSessionFilter,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
|
|
||||||
sessionTowers := make(map[wtdb.TowerID]*wtdb.Tower)
|
|
||||||
for _, s := range sessions {
|
|
||||||
// Candidate sessions must be in an active state.
|
|
||||||
if s.Status != wtdb.CSessionActive {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reload the tower from disk using the tower ID contained in
|
|
||||||
// each candidate session. We will also rederive any session
|
|
||||||
// keys needed to be able to communicate with the towers and
|
|
||||||
// authenticate session requests. This prevents us from having
|
|
||||||
// to store the private keys on disk.
|
|
||||||
tower, ok := sessionTowers[s.TowerID]
|
|
||||||
if !ok {
|
|
||||||
var err error
|
|
||||||
tower, err = cfg.DB.LoadTowerByID(s.TowerID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s.Tower = tower
|
|
||||||
|
|
||||||
sessionKey, err := DeriveSessionKey(cfg.SecretKeyRing, s.KeyIndex)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
s.SessionPrivKey = sessionKey
|
|
||||||
|
|
||||||
candidateSessions[s.ID] = s
|
|
||||||
sessionTowers[tower.ID] = tower
|
|
||||||
}
|
|
||||||
|
|
||||||
var candidateTowers []*wtdb.Tower
|
var candidateTowers []*wtdb.Tower
|
||||||
for _, tower := range sessionTowers {
|
for _, s := range candidateSessions {
|
||||||
log.Infof("Using private watchtower %s, offering policy %s",
|
log.Infof("Using private watchtower %s, offering policy %s",
|
||||||
tower, cfg.Policy)
|
s.Tower, cfg.Policy)
|
||||||
candidateTowers = append(candidateTowers, tower)
|
candidateTowers = append(candidateTowers, s.Tower)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load the sweep pkscripts that have been generated for all previously
|
// Load the sweep pkscripts that have been generated for all previously
|
||||||
@ -353,6 +330,50 @@ func New(config *Config) (*TowerClient, error) {
|
|||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getClientSessions retrieves the client sessions for a particular tower if
|
||||||
|
// specified, otherwise all client sessions for all towers are retrieved. An
|
||||||
|
// optional filter can be provided to filter out any undesired client sessions.
|
||||||
|
//
|
||||||
|
// NOTE: This method should only be used when deserialization of a
|
||||||
|
// ClientSession's Tower and SessionPrivKey fields is desired, otherwise, the
|
||||||
|
// existing ListClientSessions method should be used.
|
||||||
|
func getClientSessions(db DB, keyRing SecretKeyRing, forTower *wtdb.TowerID,
|
||||||
|
passesFilter func(*wtdb.ClientSession) bool) (
|
||||||
|
map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
||||||
|
|
||||||
|
sessions, err := db.ListClientSessions(forTower)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reload the tower from disk using the tower ID contained in each
|
||||||
|
// candidate session. We will also rederive any session keys needed to
|
||||||
|
// be able to communicate with the towers and authenticate session
|
||||||
|
// requests. This prevents us from having to store the private keys on
|
||||||
|
// disk.
|
||||||
|
for _, s := range sessions {
|
||||||
|
tower, err := db.LoadTowerByID(s.TowerID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s.Tower = tower
|
||||||
|
|
||||||
|
sessionKey, err := DeriveSessionKey(keyRing, s.KeyIndex)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s.SessionPrivKey = sessionKey
|
||||||
|
|
||||||
|
// If an optional filter was provided, use it to filter out any
|
||||||
|
// undesired sessions.
|
||||||
|
if passesFilter != nil && !passesFilter(s) {
|
||||||
|
delete(sessions, s.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sessions, nil
|
||||||
|
}
|
||||||
|
|
||||||
// buildHighestCommitHeights inspects the full set of candidate client sessions
|
// buildHighestCommitHeights inspects the full set of candidate client sessions
|
||||||
// loaded from disk, and determines the highest known commit height for each
|
// loaded from disk, and determines the highest known commit height for each
|
||||||
// channel. This allows the client to reject backups that it has already
|
// channel. This allows the client to reject backups that it has already
|
||||||
|
Loading…
Reference in New Issue
Block a user