watchtower: extend client db to filter sessions for a specific tower

This currently takes O(N) time as there does not exist an index of
active client sessions for each watchtower within the client's database.
This index is likely to be added in the future.
This commit is contained in:
Wilmer Paulino 2019-06-07 17:44:55 -07:00
parent 4abadc82f3
commit 56d66c80a1
No known key found for this signature in database
GPG Key ID: 6DF57B9F9514972F
5 changed files with 111 additions and 29 deletions

@ -206,7 +206,7 @@ func New(config *Config) (*TowerClient, error) {
// use any of these session if their policies match the current policy // use any of these session if their policies match the current policy
// of the client, otherwise they will be ignored and new sessions will // of the client, otherwise they will be ignored and new sessions will
// be requested. // be requested.
sessions, err := cfg.DB.ListClientSessions() sessions, err := cfg.DB.ListClientSessions(nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -38,8 +38,10 @@ type DB interface {
// ListClientSessions returns all sessions that have not yet been // ListClientSessions returns all sessions that have not yet been
// exhausted. This is used on startup to find any sessions which may // exhausted. This is used on startup to find any sessions which may
// still be able to accept state updates. // still be able to accept state updates. An optional tower ID can be
ListClientSessions() (map[wtdb.SessionID]*wtdb.ClientSession, error) // used to filter out any client sessions in the response that do not
// correspond to this tower.
ListClientSessions(*wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error)
// FetchChanSummaries loads a mapping from all registered channels to // FetchChanSummaries loads a mapping from all registered channels to
// their channel summaries. // their channel summaries.

@ -384,30 +384,54 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error {
}) })
} }
// ListClientSessions returns the set of all client sessions known to the db. // ListClientSessions returns the set of all client sessions known to the db. An
func (c *ClientDB) ListClientSessions() (map[SessionID]*ClientSession, error) { // optional tower ID can be used to filter out any client sessions in the
clientSessions := make(map[SessionID]*ClientSession) // response that do not correspond to this tower.
func (c *ClientDB) ListClientSessions(id *TowerID) (map[SessionID]*ClientSession, error) {
var clientSessions map[SessionID]*ClientSession
err := c.db.View(func(tx *bbolt.Tx) error { err := c.db.View(func(tx *bbolt.Tx) error {
sessions := tx.Bucket(cSessionBkt) sessions := tx.Bucket(cSessionBkt)
if sessions == nil { if sessions == nil {
return ErrUninitializedDB return ErrUninitializedDB
} }
var err error
clientSessions, err = listClientSessions(sessions, id)
return err
})
if err != nil {
return nil, err
}
return sessions.ForEach(func(k, _ []byte) error { return clientSessions, nil
// We'll load the full client session since the client }
// will need the CommittedUpdates and AckedUpdates on
// startup to resume committed updates and compute the // listClientSessions returns the set of all client sessions known to the db. An
// highest known commit height for each channel. // optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower.
func listClientSessions(sessions *bbolt.Bucket,
id *TowerID) (map[SessionID]*ClientSession, error) {
clientSessions := make(map[SessionID]*ClientSession)
err := sessions.ForEach(func(k, _ []byte) error {
// We'll load the full client session since the client will need
// the CommittedUpdates and AckedUpdates on startup to resume
// committed updates and compute the highest known commit height
// for each channel.
session, err := getClientSession(sessions, k) session, err := getClientSession(sessions, k)
if err != nil { if err != nil {
return err return err
} }
// Filter out any sessions that don't correspond to the given
// tower if one was set.
if id != nil && session.TowerID != *id {
return nil
}
clientSessions[session.ID] = session clientSessions[session.ID] = session
return nil return nil
}) })
})
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -48,10 +48,10 @@ func (h *clientDBHarness) insertSession(session *wtdb.ClientSession, expErr erro
} }
} }
func (h *clientDBHarness) listSessions() map[wtdb.SessionID]*wtdb.ClientSession { func (h *clientDBHarness) listSessions(id *wtdb.TowerID) map[wtdb.SessionID]*wtdb.ClientSession {
h.t.Helper() h.t.Helper()
sessions, err := h.db.ListClientSessions() sessions, err := h.db.ListClientSessions(id)
if err != nil { if err != nil {
h.t.Fatalf("unable to list client sessions: %v", err) h.t.Fatalf("unable to list client sessions: %v", err)
} }
@ -172,7 +172,7 @@ func testCreateClientSession(h *clientDBHarness) {
// First, assert that this session is not already present in the // First, assert that this session is not already present in the
// database. // database.
if _, ok := h.listSessions()[session.ID]; ok { if _, ok := h.listSessions(nil)[session.ID]; ok {
h.t.Fatalf("session for id %x should not exist yet", session.ID) h.t.Fatalf("session for id %x should not exist yet", session.ID)
} }
@ -202,7 +202,7 @@ func testCreateClientSession(h *clientDBHarness) {
h.insertSession(session, nil) h.insertSession(session, nil)
// Verify that the session now exists in the database. // Verify that the session now exists in the database.
if _, ok := h.listSessions()[session.ID]; !ok { if _, ok := h.listSessions(nil)[session.ID]; !ok {
h.t.Fatalf("session for id %x should exist now", session.ID) h.t.Fatalf("session for id %x should exist now", session.ID)
} }
@ -218,6 +218,51 @@ func testCreateClientSession(h *clientDBHarness) {
} }
} }
// testFilterClientSessions asserts that we can correctly filter client sessions
// for a specific tower.
func testFilterClientSessions(h *clientDBHarness) {
// We'll create three client sessions, the first two belonging to one
// tower, and the last belonging to another one.
const numSessions = 3
towerSessions := make(map[wtdb.TowerID][]wtdb.SessionID)
for i := 0; i < numSessions; i++ {
towerID := wtdb.TowerID(1)
if i == numSessions-1 {
towerID = wtdb.TowerID(2)
}
keyIndex := h.nextKeyIndex(towerID, nil)
sessionID := wtdb.SessionID([33]byte{byte(i)})
h.insertSession(&wtdb.ClientSession{
ClientSessionBody: wtdb.ClientSessionBody{
TowerID: towerID,
Policy: wtpolicy.Policy{
MaxUpdates: 100,
},
RewardPkScript: []byte{0x01, 0x02, 0x03},
KeyIndex: keyIndex,
},
ID: sessionID,
}, nil)
towerSessions[towerID] = append(towerSessions[towerID], sessionID)
}
// We should see the expected sessions for each tower when filtering
// them.
for towerID, expectedSessions := range towerSessions {
sessions := h.listSessions(&towerID)
if len(sessions) != len(expectedSessions) {
h.t.Fatalf("expected %v sessions for tower %v, got %v",
len(expectedSessions), towerID, len(sessions))
}
for _, expectedSession := range expectedSessions {
if _, ok := sessions[expectedSession]; !ok {
h.t.Fatalf("expected session %v for tower %v",
expectedSession, towerID)
}
}
}
}
// testCreateTower asserts the behavior of creating new Tower objects within the // testCreateTower asserts the behavior of creating new Tower objects within the
// database, and that the latest address is always prepended to the list of // database, and that the latest address is always prepended to the list of
// known addresses for the tower. // known addresses for the tower.
@ -357,7 +402,7 @@ func testCommitUpdate(h *clientDBHarness) {
// Assert that the committed update appears in the client session's // Assert that the committed update appears in the client session's
// CommittedUpdates map when loaded from disk and that there are no // CommittedUpdates map when loaded from disk and that there are no
// AckedUpdates. // AckedUpdates.
dbSession := h.listSessions()[session.ID] dbSession := h.listSessions(nil)[session.ID]
checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{ checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
*update1, *update1,
}) })
@ -374,7 +419,7 @@ func testCommitUpdate(h *clientDBHarness) {
} }
// Assert that the loaded ClientSession is the same as before. // Assert that the loaded ClientSession is the same as before.
dbSession = h.listSessions()[session.ID] dbSession = h.listSessions(nil)[session.ID]
checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{ checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
*update1, *update1,
}) })
@ -396,7 +441,7 @@ func testCommitUpdate(h *clientDBHarness) {
// Check that both updates now appear as committed on the ClientSession // Check that both updates now appear as committed on the ClientSession
// loaded from disk. // loaded from disk.
dbSession = h.listSessions()[session.ID] dbSession = h.listSessions(nil)[session.ID]
checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{ checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
*update1, *update1,
*update2, *update2,
@ -410,7 +455,7 @@ func testCommitUpdate(h *clientDBHarness) {
h.commitUpdate(&session.ID, update4, wtdb.ErrCommitUnorderedUpdate) h.commitUpdate(&session.ID, update4, wtdb.ErrCommitUnorderedUpdate)
// Assert that the ClientSession loaded from disk remains unchanged. // Assert that the ClientSession loaded from disk remains unchanged.
dbSession = h.listSessions()[session.ID] dbSession = h.listSessions(nil)[session.ID]
checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{ checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
*update1, *update1,
*update2, *update2,
@ -467,7 +512,7 @@ func testAckUpdate(h *clientDBHarness) {
// Assert that the ClientSession loaded from disk has one update in it's // Assert that the ClientSession loaded from disk has one update in it's
// AckedUpdates map, and that the committed update has been removed. // AckedUpdates map, and that the committed update has been removed.
dbSession := h.listSessions()[session.ID] dbSession := h.listSessions(nil)[session.ID]
checkCommittedUpdates(h.t, dbSession, nil) checkCommittedUpdates(h.t, dbSession, nil)
checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{ checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{
1: update1.BackupID, 1: update1.BackupID,
@ -487,7 +532,7 @@ func testAckUpdate(h *clientDBHarness) {
h.ackUpdate(&session.ID, 2, 2, nil) h.ackUpdate(&session.ID, 2, 2, nil)
// Assert that both updates exist as AckedUpdates when loaded from disk. // Assert that both updates exist as AckedUpdates when loaded from disk.
dbSession = h.listSessions()[session.ID] dbSession = h.listSessions(nil)[session.ID]
checkCommittedUpdates(h.t, dbSession, nil) checkCommittedUpdates(h.t, dbSession, nil)
checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{ checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{
1: update1.BackupID, 1: update1.BackupID,
@ -620,6 +665,10 @@ func TestClientDB(t *testing.T) {
name: "create client session", name: "create client session",
run: testCreateClientSession, run: testCreateClientSession,
}, },
{
name: "filter client sessions",
run: testFilterClientSessions,
},
{ {
name: "create tower", name: "create tower",
run: testCreateTower, run: testCreateTower,

@ -86,13 +86,20 @@ func (m *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, commitHeight ui
return nil return nil
} }
// ListClientSessions returns the set of all client sessions known to the db. // ListClientSessions returns the set of all client sessions known to the db. An
func (m *ClientDB) ListClientSessions() (map[wtdb.SessionID]*wtdb.ClientSession, error) { // optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower.
func (m *ClientDB) ListClientSessions(
tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
sessions := make(map[wtdb.SessionID]*wtdb.ClientSession) sessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
for _, session := range m.activeSessions { for _, session := range m.activeSessions {
if tower != nil && *tower != session.TowerID {
continue
}
sessions[session.ID] = session sessions[session.ID] = session
} }