From 440ae7818ae2630939b8e6565bffb266cd742ef8 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Thu, 23 May 2019 20:48:23 -0700 Subject: [PATCH] watchtower/wtmock/client_db: adjust mock clientdb behavior In advance of the upcoming wtdb.ClientDB, we'll modify the behavior of the mockdb to be more like the final bbolt backed one, and assert that all or our tests are still passing. --- watchtower/wtdb/client_session.go | 6 ++++++ watchtower/wtmock/client_db.go | 23 +++++++++++++++++++---- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/watchtower/wtdb/client_session.go b/watchtower/wtdb/client_session.go index ab068683..d29e1f5f 100644 --- a/watchtower/wtdb/client_session.go +++ b/watchtower/wtdb/client_session.go @@ -39,6 +39,12 @@ var ( // created because session key index differs from the reserved key // index. ErrIncorrectKeyIndex = errors.New("incorrect key index") + + // ErrClientSessionAlreadyExists signals an attempt to reinsert + // a client session that has already been created. + ErrClientSessionAlreadyExists = errors.New( + "client session already exists", + ) ) // ClientSession encapsulates a SessionInfo returned from a successful diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 32898e4e..b903e78a 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -65,7 +65,7 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) { m.towerIndex[towerPubKey] = towerID m.towers[towerID] = tower - return tower, nil + return copyTower(tower), nil } // LoadTower retrieves a tower by its tower ID. @@ -74,7 +74,7 @@ func (m *ClientDB) LoadTower(towerID wtdb.TowerID) (*wtdb.Tower, error) { defer m.mu.Unlock() if tower, ok := m.towers[towerID]; ok { - return tower, nil + return copyTower(tower), nil } return nil, wtdb.ErrTowerNotFound @@ -106,6 +106,11 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { m.mu.Lock() defer m.mu.Unlock() + // Ensure that we aren't overwriting an existing session. + if _, ok := m.activeSessions[session.ID]; ok { + return wtdb.ErrClientSessionAlreadyExists + } + // Ensure that a session key index has been reserved for this tower. keyIndex, ok := m.indexes[session.TowerID] if !ok { @@ -151,11 +156,10 @@ func (m *ClientDB) NextSessionKeyIndex(towerID wtdb.TowerID) (uint32, error) { return index, nil } + m.nextIndex++ index := m.nextIndex m.indexes[towerID] = index - m.nextIndex++ - return index, nil } @@ -286,3 +290,14 @@ func cloneBytes(b []byte) []byte { return bb } + +func copyTower(tower *wtdb.Tower) *wtdb.Tower { + t := &wtdb.Tower{ + ID: tower.ID, + IdentityKey: tower.IdentityKey, + Addresses: make([]net.Addr, len(tower.Addresses)), + } + copy(t.Addresses, tower.Addresses) + + return t +}