diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 916e9165..01be53d0 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -206,7 +206,7 @@ func New(config *Config) (*TowerClient, error) { // use any of these session if their policies match the current policy // of the client, otherwise they will be ignored and new sessions will // be requested. - sessions, err := cfg.DB.ListClientSessions() + sessions, err := cfg.DB.ListClientSessions(nil) if err != nil { return nil, err } diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index e8a8b865..faab7439 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -38,8 +38,10 @@ type DB interface { // ListClientSessions returns all sessions that have not yet been // exhausted. This is used on startup to find any sessions which may - // still be able to accept state updates. - ListClientSessions() (map[wtdb.SessionID]*wtdb.ClientSession, error) + // still be able to accept state updates. An optional tower ID can be + // used to filter out any client sessions in the response that do not + // correspond to this tower. + ListClientSessions(*wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) // FetchChanSummaries loads a mapping from all registered channels to // their channel summaries. diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 92307e99..da979ef6 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -384,29 +384,53 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error { }) } -// ListClientSessions returns the set of all client sessions known to the db. -func (c *ClientDB) ListClientSessions() (map[SessionID]*ClientSession, error) { - clientSessions := make(map[SessionID]*ClientSession) +// ListClientSessions returns the set of all client sessions known to the db. An +// optional tower ID can be used to filter out any client sessions in the +// response that do not correspond to this tower. +func (c *ClientDB) ListClientSessions(id *TowerID) (map[SessionID]*ClientSession, error) { + var clientSessions map[SessionID]*ClientSession err := c.db.View(func(tx *bbolt.Tx) error { sessions := tx.Bucket(cSessionBkt) if sessions == nil { return ErrUninitializedDB } + var err error + clientSessions, err = listClientSessions(sessions, id) + return err + }) + if err != nil { + return nil, err + } - return sessions.ForEach(func(k, _ []byte) error { - // We'll load the full client session since the client - // will need the CommittedUpdates and AckedUpdates on - // startup to resume committed updates and compute the - // highest known commit height for each channel. - session, err := getClientSession(sessions, k) - if err != nil { - return err - } + return clientSessions, nil +} - clientSessions[session.ID] = session +// listClientSessions returns the set of all client sessions known to the db. An +// optional tower ID can be used to filter out any client sessions in the +// response that do not correspond to this tower. +func listClientSessions(sessions *bbolt.Bucket, + id *TowerID) (map[SessionID]*ClientSession, error) { + clientSessions := make(map[SessionID]*ClientSession) + err := sessions.ForEach(func(k, _ []byte) error { + // We'll load the full client session since the client will need + // the CommittedUpdates and AckedUpdates on startup to resume + // committed updates and compute the highest known commit height + // for each channel. + session, err := getClientSession(sessions, k) + if err != nil { + return err + } + + // Filter out any sessions that don't correspond to the given + // tower if one was set. + if id != nil && session.TowerID != *id { return nil - }) + } + + clientSessions[session.ID] = session + + return nil }) if err != nil { return nil, err diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index 92ebc95b..b400972e 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -48,10 +48,10 @@ func (h *clientDBHarness) insertSession(session *wtdb.ClientSession, expErr erro } } -func (h *clientDBHarness) listSessions() map[wtdb.SessionID]*wtdb.ClientSession { +func (h *clientDBHarness) listSessions(id *wtdb.TowerID) map[wtdb.SessionID]*wtdb.ClientSession { h.t.Helper() - sessions, err := h.db.ListClientSessions() + sessions, err := h.db.ListClientSessions(id) if err != nil { h.t.Fatalf("unable to list client sessions: %v", err) } @@ -172,7 +172,7 @@ func testCreateClientSession(h *clientDBHarness) { // First, assert that this session is not already present in the // database. - if _, ok := h.listSessions()[session.ID]; ok { + if _, ok := h.listSessions(nil)[session.ID]; ok { h.t.Fatalf("session for id %x should not exist yet", session.ID) } @@ -202,7 +202,7 @@ func testCreateClientSession(h *clientDBHarness) { h.insertSession(session, nil) // Verify that the session now exists in the database. - if _, ok := h.listSessions()[session.ID]; !ok { + if _, ok := h.listSessions(nil)[session.ID]; !ok { h.t.Fatalf("session for id %x should exist now", session.ID) } @@ -218,6 +218,51 @@ func testCreateClientSession(h *clientDBHarness) { } } +// testFilterClientSessions asserts that we can correctly filter client sessions +// for a specific tower. +func testFilterClientSessions(h *clientDBHarness) { + // We'll create three client sessions, the first two belonging to one + // tower, and the last belonging to another one. + const numSessions = 3 + towerSessions := make(map[wtdb.TowerID][]wtdb.SessionID) + for i := 0; i < numSessions; i++ { + towerID := wtdb.TowerID(1) + if i == numSessions-1 { + towerID = wtdb.TowerID(2) + } + keyIndex := h.nextKeyIndex(towerID, nil) + sessionID := wtdb.SessionID([33]byte{byte(i)}) + h.insertSession(&wtdb.ClientSession{ + ClientSessionBody: wtdb.ClientSessionBody{ + TowerID: towerID, + Policy: wtpolicy.Policy{ + MaxUpdates: 100, + }, + RewardPkScript: []byte{0x01, 0x02, 0x03}, + KeyIndex: keyIndex, + }, + ID: sessionID, + }, nil) + towerSessions[towerID] = append(towerSessions[towerID], sessionID) + } + + // We should see the expected sessions for each tower when filtering + // them. + for towerID, expectedSessions := range towerSessions { + sessions := h.listSessions(&towerID) + if len(sessions) != len(expectedSessions) { + h.t.Fatalf("expected %v sessions for tower %v, got %v", + len(expectedSessions), towerID, len(sessions)) + } + for _, expectedSession := range expectedSessions { + if _, ok := sessions[expectedSession]; !ok { + h.t.Fatalf("expected session %v for tower %v", + expectedSession, towerID) + } + } + } +} + // testCreateTower asserts the behavior of creating new Tower objects within the // database, and that the latest address is always prepended to the list of // known addresses for the tower. @@ -357,7 +402,7 @@ func testCommitUpdate(h *clientDBHarness) { // Assert that the committed update appears in the client session's // CommittedUpdates map when loaded from disk and that there are no // AckedUpdates. - dbSession := h.listSessions()[session.ID] + dbSession := h.listSessions(nil)[session.ID] checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{ *update1, }) @@ -374,7 +419,7 @@ func testCommitUpdate(h *clientDBHarness) { } // Assert that the loaded ClientSession is the same as before. - dbSession = h.listSessions()[session.ID] + dbSession = h.listSessions(nil)[session.ID] checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{ *update1, }) @@ -396,7 +441,7 @@ func testCommitUpdate(h *clientDBHarness) { // Check that both updates now appear as committed on the ClientSession // loaded from disk. - dbSession = h.listSessions()[session.ID] + dbSession = h.listSessions(nil)[session.ID] checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{ *update1, *update2, @@ -410,7 +455,7 @@ func testCommitUpdate(h *clientDBHarness) { h.commitUpdate(&session.ID, update4, wtdb.ErrCommitUnorderedUpdate) // Assert that the ClientSession loaded from disk remains unchanged. - dbSession = h.listSessions()[session.ID] + dbSession = h.listSessions(nil)[session.ID] checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{ *update1, *update2, @@ -467,7 +512,7 @@ func testAckUpdate(h *clientDBHarness) { // Assert that the ClientSession loaded from disk has one update in it's // AckedUpdates map, and that the committed update has been removed. - dbSession := h.listSessions()[session.ID] + dbSession := h.listSessions(nil)[session.ID] checkCommittedUpdates(h.t, dbSession, nil) checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{ 1: update1.BackupID, @@ -487,7 +532,7 @@ func testAckUpdate(h *clientDBHarness) { h.ackUpdate(&session.ID, 2, 2, nil) // Assert that both updates exist as AckedUpdates when loaded from disk. - dbSession = h.listSessions()[session.ID] + dbSession = h.listSessions(nil)[session.ID] checkCommittedUpdates(h.t, dbSession, nil) checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{ 1: update1.BackupID, @@ -620,6 +665,10 @@ func TestClientDB(t *testing.T) { name: "create client session", run: testCreateClientSession, }, + { + name: "filter client sessions", + run: testFilterClientSessions, + }, { name: "create tower", run: testCreateTower, diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 88cde50f..ddaca523 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -86,13 +86,20 @@ func (m *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, commitHeight ui return nil } -// ListClientSessions returns the set of all client sessions known to the db. -func (m *ClientDB) ListClientSessions() (map[wtdb.SessionID]*wtdb.ClientSession, error) { +// ListClientSessions returns the set of all client sessions known to the db. An +// optional tower ID can be used to filter out any client sessions in the +// response that do not correspond to this tower. +func (m *ClientDB) ListClientSessions( + tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) { + m.mu.Lock() defer m.mu.Unlock() sessions := make(map[wtdb.SessionID]*wtdb.ClientSession) for _, session := range m.activeSessions { + if tower != nil && *tower != session.TowerID { + continue + } sessions[session.ID] = session }