From 01ab551b221c142dfc2a639e7f0acf8e3a85040d Mon Sep 17 00:00:00 2001 From: Wilmer Paulino Date: Mon, 11 May 2020 15:23:43 -0700 Subject: [PATCH 1/6] wtclient: refactor existing candidate session filtering into method --- watchtower/wtclient/client.go | 95 +++++++++++++++++++++-------------- 1 file changed, 58 insertions(+), 37 deletions(-) diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 76aa2a4b..8a37abe5 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -38,6 +38,14 @@ const ( DefaultForceQuitDelay = 10 * time.Second ) +var ( + // activeSessionFilter is a filter that ignored any sessions which are + // not active. + activeSessionFilter = func(s *wtdb.ClientSession) bool { + return s.Status == wtdb.CSessionActive + } +) + // RegisteredTower encompasses information about a registered watchtower with // the client. type RegisteredTower struct { @@ -268,49 +276,18 @@ func New(config *Config) (*TowerClient, error) { // the client. We will use any of these session if their policies match // the current policy of the client, otherwise they will be ignored and // new sessions will be requested. - sessions, err := cfg.DB.ListClientSessions(nil) + candidateSessions, err := getClientSessions( + cfg.DB, cfg.SecretKeyRing, nil, activeSessionFilter, + ) if err != nil { return nil, err } - candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession) - sessionTowers := make(map[wtdb.TowerID]*wtdb.Tower) - for _, s := range sessions { - // Candidate sessions must be in an active state. - if s.Status != wtdb.CSessionActive { - continue - } - - // Reload the tower from disk using the tower ID contained in - // each candidate session. We will also rederive any session - // keys needed to be able to communicate with the towers and - // authenticate session requests. This prevents us from having - // to store the private keys on disk. - tower, ok := sessionTowers[s.TowerID] - if !ok { - var err error - tower, err = cfg.DB.LoadTowerByID(s.TowerID) - if err != nil { - return nil, err - } - } - s.Tower = tower - - sessionKey, err := DeriveSessionKey(cfg.SecretKeyRing, s.KeyIndex) - if err != nil { - return nil, err - } - s.SessionPrivKey = sessionKey - - candidateSessions[s.ID] = s - sessionTowers[tower.ID] = tower - } - var candidateTowers []*wtdb.Tower - for _, tower := range sessionTowers { + for _, s := range candidateSessions { log.Infof("Using private watchtower %s, offering policy %s", - tower, cfg.Policy) - candidateTowers = append(candidateTowers, tower) + s.Tower, cfg.Policy) + candidateTowers = append(candidateTowers, s.Tower) } // Load the sweep pkscripts that have been generated for all previously @@ -353,6 +330,50 @@ func New(config *Config) (*TowerClient, error) { return c, nil } +// getClientSessions retrieves the client sessions for a particular tower if +// specified, otherwise all client sessions for all towers are retrieved. An +// optional filter can be provided to filter out any undesired client sessions. +// +// NOTE: This method should only be used when deserialization of a +// ClientSession's Tower and SessionPrivKey fields is desired, otherwise, the +// existing ListClientSessions method should be used. +func getClientSessions(db DB, keyRing SecretKeyRing, forTower *wtdb.TowerID, + passesFilter func(*wtdb.ClientSession) bool) ( + map[wtdb.SessionID]*wtdb.ClientSession, error) { + + sessions, err := db.ListClientSessions(forTower) + if err != nil { + return nil, err + } + + // Reload the tower from disk using the tower ID contained in each + // candidate session. We will also rederive any session keys needed to + // be able to communicate with the towers and authenticate session + // requests. This prevents us from having to store the private keys on + // disk. + for _, s := range sessions { + tower, err := db.LoadTowerByID(s.TowerID) + if err != nil { + return nil, err + } + s.Tower = tower + + sessionKey, err := DeriveSessionKey(keyRing, s.KeyIndex) + if err != nil { + return nil, err + } + s.SessionPrivKey = sessionKey + + // If an optional filter was provided, use it to filter out any + // undesired sessions. + if passesFilter != nil && !passesFilter(s) { + delete(sessions, s.ID) + } + } + + return sessions, nil +} + // buildHighestCommitHeights inspects the full set of candidate client sessions // loaded from disk, and determines the highest known commit height for each // channel. This allows the client to reject backups that it has already From 75c2ebd79475e793c76daa6d20468e55a00127a4 Mon Sep 17 00:00:00 2001 From: Wilmer Paulino Date: Mon, 11 May 2020 15:26:12 -0700 Subject: [PATCH 2/6] wtclient: load missing info into client sessions upon new tower This addresses a potential panic in where we relied on this missing info being populated. --- watchtower/wtclient/client.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 8a37abe5..6827cdca 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -1060,7 +1060,9 @@ func (c *TowerClient) handleNewTower(msg *newTowerMsg) error { c.candidateTowers.AddCandidate(tower) // Include all of its corresponding sessions to our set of candidates. - sessions, err := c.cfg.DB.ListClientSessions(&tower.ID) + sessions, err := getClientSessions( + c.cfg.DB, c.cfg.SecretKeyRing, &tower.ID, nil, + ) if err != nil { return fmt.Errorf("unable to determine sessions for tower %x: "+ "%v", tower.IdentityKey.SerializeCompressed(), err) From ec5c941512372edaa60d3cc73165b06de0737d66 Mon Sep 17 00:00:00 2001 From: Wilmer Paulino Date: Mon, 11 May 2020 15:26:36 -0700 Subject: [PATCH 3/6] wtclient: filter out inactive sessions upon adding existing/new tower --- watchtower/wtclient/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 6827cdca..1f3c1c3c 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -1061,7 +1061,7 @@ func (c *TowerClient) handleNewTower(msg *newTowerMsg) error { // Include all of its corresponding sessions to our set of candidates. sessions, err := getClientSessions( - c.cfg.DB, c.cfg.SecretKeyRing, &tower.ID, nil, + c.cfg.DB, c.cfg.SecretKeyRing, &tower.ID, activeSessionFilter, ) if err != nil { return fmt.Errorf("unable to determine sessions for tower %x: "+ From b195d39ad75c2106bcf1fbfc436f2b072df630b3 Mon Sep 17 00:00:00 2001 From: Wilmer Paulino Date: Mon, 11 May 2020 16:05:04 -0700 Subject: [PATCH 4/6] rpc: use existing rpc logger for wtclientrpc The logger string used to identify the wtclient and wtclientrpc loggers was the same, leading to being unable to modify the log level of the wtclient logger as it would be overwritten with the wtclientrpc's one. To simplify things, we decide to use the existing RPC logger for wtclientrpc. --- lnrpc/wtclientrpc/config.go | 4 +++ lnrpc/wtclientrpc/log.go | 48 ----------------------------------- lnrpc/wtclientrpc/wtclient.go | 4 +-- log.go | 2 -- rpcserver.go | 1 + subrpcserver_config.go | 7 ++++- 6 files changed, 13 insertions(+), 53 deletions(-) delete mode 100644 lnrpc/wtclientrpc/log.go diff --git a/lnrpc/wtclientrpc/config.go b/lnrpc/wtclientrpc/config.go index a008ca0d..8796bd05 100644 --- a/lnrpc/wtclientrpc/config.go +++ b/lnrpc/wtclientrpc/config.go @@ -1,6 +1,7 @@ package wtclientrpc import ( + "github.com/btcsuite/btclog" "github.com/lightningnetwork/lnd/lncfg" "github.com/lightningnetwork/lnd/watchtower/wtclient" ) @@ -22,4 +23,7 @@ type Config struct { // addresses to ensure we don't leak any information when running over // non-clear networks, e.g. Tor, etc. Resolver lncfg.TCPResolver + + // Log is the logger instance we should log output to. + Log btclog.Logger } diff --git a/lnrpc/wtclientrpc/log.go b/lnrpc/wtclientrpc/log.go deleted file mode 100644 index eaa847d6..00000000 --- a/lnrpc/wtclientrpc/log.go +++ /dev/null @@ -1,48 +0,0 @@ -package wtclientrpc - -import ( - "github.com/btcsuite/btclog" - "github.com/lightningnetwork/lnd/build" -) - -// Subsystem defines the logging code for this subsystem. -const Subsystem = "WTCL" - -// log is a logger that is initialized with no output filters. This means the -// package will not perform any logging by default until the caller requests -// it. -var log btclog.Logger - -// The default amount of logging is none. -func init() { - UseLogger(build.NewSubLogger(Subsystem, nil)) -} - -// DisableLog disables all library log output. Logging output is disabled by -// by default until UseLogger is called. -func DisableLog() { - UseLogger(btclog.Disabled) -} - -// UseLogger uses a specified Logger to output package logging info. This -// should be used in preference to SetLogWriter if the caller is also using -// btclog. -func UseLogger(logger btclog.Logger) { - log = logger -} - -// logClosure is used to provide a closure over expensive logging operations so -// don't have to be performed when the logging level doesn't warrant it. -type logClosure func() string // nolint:unused - -// String invokes the underlying function and returns the result. -func (c logClosure) String() string { - return c() -} - -// newLogClosure returns a new closure over a function that returns a string -// which itself provides a Stringer interface so that it can be used with the -// logging system. -func newLogClosure(c func() string) logClosure { // nolint:unused - return logClosure(c) -} diff --git a/lnrpc/wtclientrpc/wtclient.go b/lnrpc/wtclientrpc/wtclient.go index 8cbdc9bf..b6f1fa2a 100644 --- a/lnrpc/wtclientrpc/wtclient.go +++ b/lnrpc/wtclientrpc/wtclient.go @@ -115,8 +115,8 @@ func (c *WatchtowerClient) RegisterWithRootServer(grpcServer *grpc.Server) error // all our methods are routed properly. RegisterWatchtowerClientServer(grpcServer, c) - log.Debugf("WatchtowerClient RPC server successfully registered with " + - "root gRPC server") + c.cfg.Log.Debugf("WatchtowerClient RPC server successfully registered " + + "with root gRPC server") return nil } diff --git a/log.go b/log.go index 51754cca..abd16726 100644 --- a/log.go +++ b/log.go @@ -25,7 +25,6 @@ import ( "github.com/lightningnetwork/lnd/lnrpc/signrpc" "github.com/lightningnetwork/lnd/lnrpc/verrpc" "github.com/lightningnetwork/lnd/lnrpc/walletrpc" - "github.com/lightningnetwork/lnd/lnrpc/wtclientrpc" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chanfunding" "github.com/lightningnetwork/lnd/monitoring" @@ -102,7 +101,6 @@ func init() { addSubLogger(routing.Subsystem, routing.UseLogger, localchans.UseLogger) addSubLogger(routerrpc.Subsystem, routerrpc.UseLogger) - addSubLogger(wtclientrpc.Subsystem, wtclientrpc.UseLogger) addSubLogger(chanfitness.Subsystem, chanfitness.UseLogger) addSubLogger(verrpc.Subsystem, verrpc.UseLogger) } diff --git a/rpcserver.go b/rpcserver.go index c725b67b..d014daf4 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -589,6 +589,7 @@ func newRPCServer(s *server, macService *macaroons.Service, s.htlcSwitch, activeNetParams.Params, s.chanRouter, routerBackend, s.nodeSigner, s.chanDB, s.sweeper, tower, s.towerClient, cfg.net.ResolveTCPAddr, genInvoiceFeatures, + rpcsLog, ) if err != nil { return nil, err diff --git a/subrpcserver_config.go b/subrpcserver_config.go index 31632028..5c87f3ab 100644 --- a/subrpcserver_config.go +++ b/subrpcserver_config.go @@ -5,6 +5,7 @@ import ( "reflect" "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btclog" "github.com/lightningnetwork/lnd/autopilot" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/htlcswitch" @@ -94,7 +95,8 @@ func (s *subRPCServerConfigs) PopulateDependencies(cc *chainControl, tower *watchtower.Standalone, towerClient wtclient.Client, tcpResolver lncfg.TCPResolver, - genInvoiceFeatures func() *lnwire.FeatureVector) error { + genInvoiceFeatures func() *lnwire.FeatureVector, + rpcLogger btclog.Logger) error { // First, we'll use reflect to obtain a version of the config struct // that allows us to programmatically inspect its fields. @@ -244,6 +246,9 @@ func (s *subRPCServerConfigs) PopulateDependencies(cc *chainControl, subCfgValue.FieldByName("Resolver").Set( reflect.ValueOf(tcpResolver), ) + subCfgValue.FieldByName("Log").Set( + reflect.ValueOf(rpcLogger), + ) default: return fmt.Errorf("unknown field: %v, %T", fieldName, From f6f0d3819f5974e4dd65cf13c02a96f14c4a5d8c Mon Sep 17 00:00:00 2001 From: Wilmer Paulino Date: Thu, 14 May 2020 11:48:06 -0700 Subject: [PATCH 5/6] wtclient: test case re-add removed tower --- watchtower/wtclient/client_test.go | 121 +++++++++++++++++++++++------ 1 file changed, 98 insertions(+), 23 deletions(-) diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index ea681d57..577d33f9 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -366,17 +366,18 @@ func (c *mockChannel) getState(i uint64) (*wire.MsgTx, *lnwallet.BreachRetributi } type testHarness struct { - t *testing.T - cfg harnessCfg - signer *wtmock.MockSigner - capacity lnwire.MilliSatoshi - clientDB *wtmock.ClientDB - clientCfg *wtclient.Config - client wtclient.Client - serverDB *wtmock.TowerDB - serverCfg *wtserver.Config - server *wtserver.Server - net *mockNet + t *testing.T + cfg harnessCfg + signer *wtmock.MockSigner + capacity lnwire.MilliSatoshi + clientDB *wtmock.ClientDB + clientCfg *wtclient.Config + client wtclient.Client + serverAddr *lnwire.NetAddress + serverDB *wtmock.TowerDB + serverCfg *wtserver.Config + server *wtserver.Server + net *mockNet mu sync.Mutex channels map[lnwire.ChannelID]*mockChannel @@ -467,18 +468,19 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { } h := &testHarness{ - t: t, - cfg: cfg, - signer: signer, - capacity: cfg.localBalance + cfg.remoteBalance, - clientDB: clientDB, - clientCfg: clientCfg, - client: client, - serverDB: serverDB, - serverCfg: serverCfg, - server: server, - net: mockNet, - channels: make(map[lnwire.ChannelID]*mockChannel), + t: t, + cfg: cfg, + signer: signer, + capacity: cfg.localBalance + cfg.remoteBalance, + clientDB: clientDB, + clientCfg: clientCfg, + client: client, + serverAddr: towerAddr, + serverDB: serverDB, + serverCfg: serverCfg, + server: server, + net: mockNet, + channels: make(map[lnwire.ChannelID]*mockChannel), } h.makeChannel(0, h.cfg.localBalance, h.cfg.remoteBalance) @@ -782,6 +784,25 @@ func (h *testHarness) assertUpdatesForPolicy(hints []blob.BreachHint, } } +// addTower adds a tower found at `addr` to the client. +func (h *testHarness) addTower(addr *lnwire.NetAddress) { + h.t.Helper() + + if err := h.client.AddTower(addr); err != nil { + h.t.Fatalf("unable to add tower: %v", err) + } +} + +// removeTower removes a tower from the client. If `addr` is specified, then the +// only said address is removed from the tower. +func (h *testHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr) { + h.t.Helper() + + if err := h.client.RemoveTower(pubKey, addr); err != nil { + h.t.Fatalf("unable to remove tower: %v", err) + } +} + const ( localBalance = lnwire.MilliSatoshi(100000000) remoteBalance = lnwire.MilliSatoshi(200000000) @@ -1396,6 +1417,60 @@ var clientTests = []clientTest{ h.waitServerUpdates(hints, 5*time.Second) }, }, + { + // Asserts that the client can continue making backups to a + // tower that's been re-added after it's been removed. + name: "re-add removed tower", + cfg: harnessCfg{ + localBalance: localBalance, + remoteBalance: remoteBalance, + policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, + MaxUpdates: 5, + }, + }, + fn: func(h *testHarness) { + const ( + chanID = 0 + numUpdates = 4 + ) + + // Create four channel updates and only back up the + // first two. + hints := h.advanceChannelN(chanID, numUpdates) + h.backupStates(chanID, 0, numUpdates/2, nil) + h.waitServerUpdates(hints[:numUpdates/2], 5*time.Second) + + // Fully remove the tower, causing its existing sessions + // to be marked inactive. + h.removeTower(h.serverAddr.IdentityKey, nil) + + // Back up the remaining states. Since the tower has + // been removed, it shouldn't receive any updates. + h.backupStates(chanID, numUpdates/2, numUpdates, nil) + h.waitServerUpdates(nil, time.Second) + + // Re-add the tower. We prevent the tower from acking + // session creation to ensure the inactive sessions are + // not used. + h.server.Stop() + h.serverCfg.NoAckCreateSession = true + h.startServer() + h.addTower(h.serverAddr) + h.waitServerUpdates(nil, time.Second) + + // Finally, allow the tower to ack session creation, + // allowing the state updates to be sent through the new + // session. + h.server.Stop() + h.serverCfg.NoAckCreateSession = false + h.startServer() + h.waitServerUpdates(hints[numUpdates/2:], 5*time.Second) + }, + }, } // TestClient executes the client test suite, asserting the ability to backup From c76070054530d568278aa4c59fc80aae6789fcfd Mon Sep 17 00:00:00 2001 From: Wilmer Paulino Date: Thu, 14 May 2020 15:28:49 -0700 Subject: [PATCH 6/6] wtmock: prevent race conditions by not using ClientSession pointers These race conditions originate from the mock database storing and returning pointers, rather than returning a copy. Observed on Travis: WARNING: DATA RACE Read at 0x00c0003222b8 by goroutine 149: github.com/lightningnetwork/lnd/watchtower/wtclient.(*sessionQueue).drainBackups() /home/runner/work/lnd/lnd/watchtower/wtclient/session_queue.go:288 +0xed github.com/lightningnetwork/lnd/watchtower/wtclient.(*sessionQueue).sessionManager() /home/runner/work/lnd/lnd/watchtower/wtclient/session_queue.go:281 +0x450 Previous write at 0x00c0003222b8 by goroutine 93: github.com/lightningnetwork/lnd/watchtower/wtclient.getClientSessions() /home/runner/work/lnd/lnd/watchtower/wtclient/client.go:365 +0x24f github.com/lightningnetwork/lnd/watchtower/wtclient.(*TowerClient).handleNewTower() /home/runner/work/lnd/lnd/watchtower/wtclient/client.go:1063 +0x23e github.com/lightningnetwork/lnd/watchtower/wtclient.(*TowerClient).backupDispatcher() /home/runner/work/lnd/lnd/watchtower/wtclient/client.go:784 +0x10b9 --- watchtower/wtmock/client_db.go | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 395f16a4..1f66e245 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -19,7 +19,7 @@ type ClientDB struct { mu sync.Mutex summaries map[lnwire.ChannelID]wtdb.ClientChanSummary - activeSessions map[wtdb.SessionID]*wtdb.ClientSession + activeSessions map[wtdb.SessionID]wtdb.ClientSession towerIndex map[towerPK]wtdb.TowerID towers map[wtdb.TowerID]*wtdb.Tower @@ -31,7 +31,7 @@ type ClientDB struct { func NewClientDB() *ClientDB { return &ClientDB{ summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary), - activeSessions: make(map[wtdb.SessionID]*wtdb.ClientSession), + activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession), towerIndex: make(map[towerPK]wtdb.TowerID), towers: make(map[wtdb.TowerID]*wtdb.Tower), indexes: make(map[wtdb.TowerID]uint32), @@ -62,7 +62,7 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) { } for id, session := range towerSessions { session.Status = wtdb.CSessionActive - m.activeSessions[id] = session + m.activeSessions[id] = *session } } else { towerID = wtdb.TowerID(atomic.AddUint64(&m.nextTowerID, 1)) @@ -122,7 +122,7 @@ func (m *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { return wtdb.ErrTowerUnackedUpdates } session.Status = wtdb.CSessionInactive - m.activeSessions[id] = session + m.activeSessions[id] = *session } return nil @@ -205,10 +205,11 @@ func (m *ClientDB) listClientSessions( sessions := make(map[wtdb.SessionID]*wtdb.ClientSession) for _, session := range m.activeSessions { + session := session if tower != nil && *tower != session.TowerID { continue } - sessions[session.ID] = session + sessions[session.ID] = &session } return sessions, nil @@ -240,7 +241,7 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { // permits us to create another session with this tower. delete(m.indexes, session.TowerID) - m.activeSessions[session.ID] = &wtdb.ClientSession{ + m.activeSessions[session.ID] = wtdb.ClientSession{ ID: session.ID, ClientSessionBody: wtdb.ClientSessionBody{ SeqNum: session.SeqNum, @@ -313,6 +314,7 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, // Save the update and increment the sequence number. session.CommittedUpdates = append(session.CommittedUpdates, *update) session.SeqNum++ + m.activeSessions[*id] = session return session.TowerLastApplied, nil } @@ -360,6 +362,7 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err session.AckedUpdates[seqNum] = update.BackupID session.TowerLastApplied = lastApplied + m.activeSessions[*id] = session return nil }