diff --git a/watchtower/wtdb/client_session.go b/watchtower/wtdb/client_session.go index ab068683..d29e1f5f 100644 --- a/watchtower/wtdb/client_session.go +++ b/watchtower/wtdb/client_session.go @@ -39,6 +39,12 @@ var ( // created because session key index differs from the reserved key // index. ErrIncorrectKeyIndex = errors.New("incorrect key index") + + // ErrClientSessionAlreadyExists signals an attempt to reinsert + // a client session that has already been created. + ErrClientSessionAlreadyExists = errors.New( + "client session already exists", + ) ) // ClientSession encapsulates a SessionInfo returned from a successful diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 32898e4e..b903e78a 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -65,7 +65,7 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) { m.towerIndex[towerPubKey] = towerID m.towers[towerID] = tower - return tower, nil + return copyTower(tower), nil } // LoadTower retrieves a tower by its tower ID. @@ -74,7 +74,7 @@ func (m *ClientDB) LoadTower(towerID wtdb.TowerID) (*wtdb.Tower, error) { defer m.mu.Unlock() if tower, ok := m.towers[towerID]; ok { - return tower, nil + return copyTower(tower), nil } return nil, wtdb.ErrTowerNotFound @@ -106,6 +106,11 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { m.mu.Lock() defer m.mu.Unlock() + // Ensure that we aren't overwriting an existing session. + if _, ok := m.activeSessions[session.ID]; ok { + return wtdb.ErrClientSessionAlreadyExists + } + // Ensure that a session key index has been reserved for this tower. keyIndex, ok := m.indexes[session.TowerID] if !ok { @@ -151,11 +156,10 @@ func (m *ClientDB) NextSessionKeyIndex(towerID wtdb.TowerID) (uint32, error) { return index, nil } + m.nextIndex++ index := m.nextIndex m.indexes[towerID] = index - m.nextIndex++ - return index, nil } @@ -286,3 +290,14 @@ func cloneBytes(b []byte) []byte { return bb } + +func copyTower(tower *wtdb.Tower) *wtdb.Tower { + t := &wtdb.Tower{ + ID: tower.ID, + IdentityKey: tower.IdentityKey, + Addresses: make([]net.Addr, len(tower.Addresses)), + } + copy(t.Addresses, tower.Addresses) + + return t +}