diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index df8f376e..5aef8619 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -23,6 +23,14 @@ type DB interface { // LoadTower retrieves a tower by its tower ID. LoadTower(uint64) (*wtdb.Tower, error) + // NextSessionKeyIndex reserves a new session key derivation index for a + // particular tower id. The index is reserved for that tower until + // CreateClientSession is invoked for that tower and index, at which + // point a new index for that tower can be reserved. Multiple calls to + // this method before CreateClientSession is invoked should return the + // same index. + NextSessionKeyIndex(uint64) (uint32, error) + // CreateClientSession saves a newly negotiated client session to the // client's database. This enables the session to be used across // restarts. diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 5853c363..6ce73dbc 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -22,6 +22,9 @@ type ClientDB struct { activeSessions map[wtdb.SessionID]*wtdb.ClientSession towerIndex map[towerPK]uint64 towers map[uint64]*wtdb.Tower + + nextIndex uint32 + indexes map[uint64]uint32 } // NewClientDB initializes a new mock ClientDB. @@ -31,6 +34,7 @@ func NewClientDB() *ClientDB { activeSessions: make(map[wtdb.SessionID]*wtdb.ClientSession), towerIndex: make(map[towerPK]uint64), towers: make(map[uint64]*wtdb.Tower), + indexes: make(map[uint64]uint32), } } @@ -118,6 +122,27 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { return nil } +// NextSessionKeyIndex reserves a new session key derivation index for a +// particular tower id. The index is reserved for that tower until +// CreateClientSession is invoked for that tower and index, at which point a new +// index for that tower can be reserved. Multiple calls to this method before +// CreateClientSession is invoked should return the same index. +func (m *ClientDB) NextSessionKeyIndex(towerID uint64) (uint32, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if index, ok := m.indexes[towerID]; ok { + return index, nil + } + + index := m.nextIndex + m.indexes[towerID] = index + + m.nextIndex++ + + return index, nil +} + // CommitUpdate persists the CommittedUpdate provided in the slot for (session, // seqNum). This allows the client to retransmit this update on startup. func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, seqNum uint16,