diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 395f16a4..1f66e245 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -19,7 +19,7 @@ type ClientDB struct { mu sync.Mutex summaries map[lnwire.ChannelID]wtdb.ClientChanSummary - activeSessions map[wtdb.SessionID]*wtdb.ClientSession + activeSessions map[wtdb.SessionID]wtdb.ClientSession towerIndex map[towerPK]wtdb.TowerID towers map[wtdb.TowerID]*wtdb.Tower @@ -31,7 +31,7 @@ type ClientDB struct { func NewClientDB() *ClientDB { return &ClientDB{ summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary), - activeSessions: make(map[wtdb.SessionID]*wtdb.ClientSession), + activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession), towerIndex: make(map[towerPK]wtdb.TowerID), towers: make(map[wtdb.TowerID]*wtdb.Tower), indexes: make(map[wtdb.TowerID]uint32), @@ -62,7 +62,7 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) { } for id, session := range towerSessions { session.Status = wtdb.CSessionActive - m.activeSessions[id] = session + m.activeSessions[id] = *session } } else { towerID = wtdb.TowerID(atomic.AddUint64(&m.nextTowerID, 1)) @@ -122,7 +122,7 @@ func (m *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { return wtdb.ErrTowerUnackedUpdates } session.Status = wtdb.CSessionInactive - m.activeSessions[id] = session + m.activeSessions[id] = *session } return nil @@ -205,10 +205,11 @@ func (m *ClientDB) listClientSessions( sessions := make(map[wtdb.SessionID]*wtdb.ClientSession) for _, session := range m.activeSessions { + session := session if tower != nil && *tower != session.TowerID { continue } - sessions[session.ID] = session + sessions[session.ID] = &session } return sessions, nil @@ -240,7 +241,7 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { // permits us to create another session with this tower. delete(m.indexes, session.TowerID) - m.activeSessions[session.ID] = &wtdb.ClientSession{ + m.activeSessions[session.ID] = wtdb.ClientSession{ ID: session.ID, ClientSessionBody: wtdb.ClientSessionBody{ SeqNum: session.SeqNum, @@ -313,6 +314,7 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, // Save the update and increment the sequence number. session.CommittedUpdates = append(session.CommittedUpdates, *update) session.SeqNum++ + m.activeSessions[*id] = session return session.TowerLastApplied, nil } @@ -360,6 +362,7 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err session.AckedUpdates[seqNum] = update.BackupID session.TowerLastApplied = lastApplied + m.activeSessions[*id] = session return nil }