diff --git a/lnrpc/wtclientrpc/config.go b/lnrpc/wtclientrpc/config.go index a008ca0d..8796bd05 100644 --- a/lnrpc/wtclientrpc/config.go +++ b/lnrpc/wtclientrpc/config.go @@ -1,6 +1,7 @@ package wtclientrpc import ( + "github.com/btcsuite/btclog" "github.com/lightningnetwork/lnd/lncfg" "github.com/lightningnetwork/lnd/watchtower/wtclient" ) @@ -22,4 +23,7 @@ type Config struct { // addresses to ensure we don't leak any information when running over // non-clear networks, e.g. Tor, etc. Resolver lncfg.TCPResolver + + // Log is the logger instance we should log output to. + Log btclog.Logger } diff --git a/lnrpc/wtclientrpc/log.go b/lnrpc/wtclientrpc/log.go deleted file mode 100644 index eaa847d6..00000000 --- a/lnrpc/wtclientrpc/log.go +++ /dev/null @@ -1,48 +0,0 @@ -package wtclientrpc - -import ( - "github.com/btcsuite/btclog" - "github.com/lightningnetwork/lnd/build" -) - -// Subsystem defines the logging code for this subsystem. -const Subsystem = "WTCL" - -// log is a logger that is initialized with no output filters. This means the -// package will not perform any logging by default until the caller requests -// it. -var log btclog.Logger - -// The default amount of logging is none. -func init() { - UseLogger(build.NewSubLogger(Subsystem, nil)) -} - -// DisableLog disables all library log output. Logging output is disabled by -// by default until UseLogger is called. -func DisableLog() { - UseLogger(btclog.Disabled) -} - -// UseLogger uses a specified Logger to output package logging info. This -// should be used in preference to SetLogWriter if the caller is also using -// btclog. -func UseLogger(logger btclog.Logger) { - log = logger -} - -// logClosure is used to provide a closure over expensive logging operations so -// don't have to be performed when the logging level doesn't warrant it. -type logClosure func() string // nolint:unused - -// String invokes the underlying function and returns the result. -func (c logClosure) String() string { - return c() -} - -// newLogClosure returns a new closure over a function that returns a string -// which itself provides a Stringer interface so that it can be used with the -// logging system. -func newLogClosure(c func() string) logClosure { // nolint:unused - return logClosure(c) -} diff --git a/lnrpc/wtclientrpc/wtclient.go b/lnrpc/wtclientrpc/wtclient.go index 8cbdc9bf..b6f1fa2a 100644 --- a/lnrpc/wtclientrpc/wtclient.go +++ b/lnrpc/wtclientrpc/wtclient.go @@ -115,8 +115,8 @@ func (c *WatchtowerClient) RegisterWithRootServer(grpcServer *grpc.Server) error // all our methods are routed properly. RegisterWatchtowerClientServer(grpcServer, c) - log.Debugf("WatchtowerClient RPC server successfully registered with " + - "root gRPC server") + c.cfg.Log.Debugf("WatchtowerClient RPC server successfully registered " + + "with root gRPC server") return nil } diff --git a/log.go b/log.go index 51754cca..abd16726 100644 --- a/log.go +++ b/log.go @@ -25,7 +25,6 @@ import ( "github.com/lightningnetwork/lnd/lnrpc/signrpc" "github.com/lightningnetwork/lnd/lnrpc/verrpc" "github.com/lightningnetwork/lnd/lnrpc/walletrpc" - "github.com/lightningnetwork/lnd/lnrpc/wtclientrpc" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chanfunding" "github.com/lightningnetwork/lnd/monitoring" @@ -102,7 +101,6 @@ func init() { addSubLogger(routing.Subsystem, routing.UseLogger, localchans.UseLogger) addSubLogger(routerrpc.Subsystem, routerrpc.UseLogger) - addSubLogger(wtclientrpc.Subsystem, wtclientrpc.UseLogger) addSubLogger(chanfitness.Subsystem, chanfitness.UseLogger) addSubLogger(verrpc.Subsystem, verrpc.UseLogger) } diff --git a/rpcserver.go b/rpcserver.go index c725b67b..d014daf4 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -589,6 +589,7 @@ func newRPCServer(s *server, macService *macaroons.Service, s.htlcSwitch, activeNetParams.Params, s.chanRouter, routerBackend, s.nodeSigner, s.chanDB, s.sweeper, tower, s.towerClient, cfg.net.ResolveTCPAddr, genInvoiceFeatures, + rpcsLog, ) if err != nil { return nil, err diff --git a/subrpcserver_config.go b/subrpcserver_config.go index 31632028..5c87f3ab 100644 --- a/subrpcserver_config.go +++ b/subrpcserver_config.go @@ -5,6 +5,7 @@ import ( "reflect" "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btclog" "github.com/lightningnetwork/lnd/autopilot" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/htlcswitch" @@ -94,7 +95,8 @@ func (s *subRPCServerConfigs) PopulateDependencies(cc *chainControl, tower *watchtower.Standalone, towerClient wtclient.Client, tcpResolver lncfg.TCPResolver, - genInvoiceFeatures func() *lnwire.FeatureVector) error { + genInvoiceFeatures func() *lnwire.FeatureVector, + rpcLogger btclog.Logger) error { // First, we'll use reflect to obtain a version of the config struct // that allows us to programmatically inspect its fields. @@ -244,6 +246,9 @@ func (s *subRPCServerConfigs) PopulateDependencies(cc *chainControl, subCfgValue.FieldByName("Resolver").Set( reflect.ValueOf(tcpResolver), ) + subCfgValue.FieldByName("Log").Set( + reflect.ValueOf(rpcLogger), + ) default: return fmt.Errorf("unknown field: %v, %T", fieldName, diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 76aa2a4b..1f3c1c3c 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -38,6 +38,14 @@ const ( DefaultForceQuitDelay = 10 * time.Second ) +var ( + // activeSessionFilter is a filter that ignored any sessions which are + // not active. + activeSessionFilter = func(s *wtdb.ClientSession) bool { + return s.Status == wtdb.CSessionActive + } +) + // RegisteredTower encompasses information about a registered watchtower with // the client. type RegisteredTower struct { @@ -268,49 +276,18 @@ func New(config *Config) (*TowerClient, error) { // the client. We will use any of these session if their policies match // the current policy of the client, otherwise they will be ignored and // new sessions will be requested. - sessions, err := cfg.DB.ListClientSessions(nil) + candidateSessions, err := getClientSessions( + cfg.DB, cfg.SecretKeyRing, nil, activeSessionFilter, + ) if err != nil { return nil, err } - candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession) - sessionTowers := make(map[wtdb.TowerID]*wtdb.Tower) - for _, s := range sessions { - // Candidate sessions must be in an active state. - if s.Status != wtdb.CSessionActive { - continue - } - - // Reload the tower from disk using the tower ID contained in - // each candidate session. We will also rederive any session - // keys needed to be able to communicate with the towers and - // authenticate session requests. This prevents us from having - // to store the private keys on disk. - tower, ok := sessionTowers[s.TowerID] - if !ok { - var err error - tower, err = cfg.DB.LoadTowerByID(s.TowerID) - if err != nil { - return nil, err - } - } - s.Tower = tower - - sessionKey, err := DeriveSessionKey(cfg.SecretKeyRing, s.KeyIndex) - if err != nil { - return nil, err - } - s.SessionPrivKey = sessionKey - - candidateSessions[s.ID] = s - sessionTowers[tower.ID] = tower - } - var candidateTowers []*wtdb.Tower - for _, tower := range sessionTowers { + for _, s := range candidateSessions { log.Infof("Using private watchtower %s, offering policy %s", - tower, cfg.Policy) - candidateTowers = append(candidateTowers, tower) + s.Tower, cfg.Policy) + candidateTowers = append(candidateTowers, s.Tower) } // Load the sweep pkscripts that have been generated for all previously @@ -353,6 +330,50 @@ func New(config *Config) (*TowerClient, error) { return c, nil } +// getClientSessions retrieves the client sessions for a particular tower if +// specified, otherwise all client sessions for all towers are retrieved. An +// optional filter can be provided to filter out any undesired client sessions. +// +// NOTE: This method should only be used when deserialization of a +// ClientSession's Tower and SessionPrivKey fields is desired, otherwise, the +// existing ListClientSessions method should be used. +func getClientSessions(db DB, keyRing SecretKeyRing, forTower *wtdb.TowerID, + passesFilter func(*wtdb.ClientSession) bool) ( + map[wtdb.SessionID]*wtdb.ClientSession, error) { + + sessions, err := db.ListClientSessions(forTower) + if err != nil { + return nil, err + } + + // Reload the tower from disk using the tower ID contained in each + // candidate session. We will also rederive any session keys needed to + // be able to communicate with the towers and authenticate session + // requests. This prevents us from having to store the private keys on + // disk. + for _, s := range sessions { + tower, err := db.LoadTowerByID(s.TowerID) + if err != nil { + return nil, err + } + s.Tower = tower + + sessionKey, err := DeriveSessionKey(keyRing, s.KeyIndex) + if err != nil { + return nil, err + } + s.SessionPrivKey = sessionKey + + // If an optional filter was provided, use it to filter out any + // undesired sessions. + if passesFilter != nil && !passesFilter(s) { + delete(sessions, s.ID) + } + } + + return sessions, nil +} + // buildHighestCommitHeights inspects the full set of candidate client sessions // loaded from disk, and determines the highest known commit height for each // channel. This allows the client to reject backups that it has already @@ -1039,7 +1060,9 @@ func (c *TowerClient) handleNewTower(msg *newTowerMsg) error { c.candidateTowers.AddCandidate(tower) // Include all of its corresponding sessions to our set of candidates. - sessions, err := c.cfg.DB.ListClientSessions(&tower.ID) + sessions, err := getClientSessions( + c.cfg.DB, c.cfg.SecretKeyRing, &tower.ID, activeSessionFilter, + ) if err != nil { return fmt.Errorf("unable to determine sessions for tower %x: "+ "%v", tower.IdentityKey.SerializeCompressed(), err) diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index ea681d57..577d33f9 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -366,17 +366,18 @@ func (c *mockChannel) getState(i uint64) (*wire.MsgTx, *lnwallet.BreachRetributi } type testHarness struct { - t *testing.T - cfg harnessCfg - signer *wtmock.MockSigner - capacity lnwire.MilliSatoshi - clientDB *wtmock.ClientDB - clientCfg *wtclient.Config - client wtclient.Client - serverDB *wtmock.TowerDB - serverCfg *wtserver.Config - server *wtserver.Server - net *mockNet + t *testing.T + cfg harnessCfg + signer *wtmock.MockSigner + capacity lnwire.MilliSatoshi + clientDB *wtmock.ClientDB + clientCfg *wtclient.Config + client wtclient.Client + serverAddr *lnwire.NetAddress + serverDB *wtmock.TowerDB + serverCfg *wtserver.Config + server *wtserver.Server + net *mockNet mu sync.Mutex channels map[lnwire.ChannelID]*mockChannel @@ -467,18 +468,19 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { } h := &testHarness{ - t: t, - cfg: cfg, - signer: signer, - capacity: cfg.localBalance + cfg.remoteBalance, - clientDB: clientDB, - clientCfg: clientCfg, - client: client, - serverDB: serverDB, - serverCfg: serverCfg, - server: server, - net: mockNet, - channels: make(map[lnwire.ChannelID]*mockChannel), + t: t, + cfg: cfg, + signer: signer, + capacity: cfg.localBalance + cfg.remoteBalance, + clientDB: clientDB, + clientCfg: clientCfg, + client: client, + serverAddr: towerAddr, + serverDB: serverDB, + serverCfg: serverCfg, + server: server, + net: mockNet, + channels: make(map[lnwire.ChannelID]*mockChannel), } h.makeChannel(0, h.cfg.localBalance, h.cfg.remoteBalance) @@ -782,6 +784,25 @@ func (h *testHarness) assertUpdatesForPolicy(hints []blob.BreachHint, } } +// addTower adds a tower found at `addr` to the client. +func (h *testHarness) addTower(addr *lnwire.NetAddress) { + h.t.Helper() + + if err := h.client.AddTower(addr); err != nil { + h.t.Fatalf("unable to add tower: %v", err) + } +} + +// removeTower removes a tower from the client. If `addr` is specified, then the +// only said address is removed from the tower. +func (h *testHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr) { + h.t.Helper() + + if err := h.client.RemoveTower(pubKey, addr); err != nil { + h.t.Fatalf("unable to remove tower: %v", err) + } +} + const ( localBalance = lnwire.MilliSatoshi(100000000) remoteBalance = lnwire.MilliSatoshi(200000000) @@ -1396,6 +1417,60 @@ var clientTests = []clientTest{ h.waitServerUpdates(hints, 5*time.Second) }, }, + { + // Asserts that the client can continue making backups to a + // tower that's been re-added after it's been removed. + name: "re-add removed tower", + cfg: harnessCfg{ + localBalance: localBalance, + remoteBalance: remoteBalance, + policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, + MaxUpdates: 5, + }, + }, + fn: func(h *testHarness) { + const ( + chanID = 0 + numUpdates = 4 + ) + + // Create four channel updates and only back up the + // first two. + hints := h.advanceChannelN(chanID, numUpdates) + h.backupStates(chanID, 0, numUpdates/2, nil) + h.waitServerUpdates(hints[:numUpdates/2], 5*time.Second) + + // Fully remove the tower, causing its existing sessions + // to be marked inactive. + h.removeTower(h.serverAddr.IdentityKey, nil) + + // Back up the remaining states. Since the tower has + // been removed, it shouldn't receive any updates. + h.backupStates(chanID, numUpdates/2, numUpdates, nil) + h.waitServerUpdates(nil, time.Second) + + // Re-add the tower. We prevent the tower from acking + // session creation to ensure the inactive sessions are + // not used. + h.server.Stop() + h.serverCfg.NoAckCreateSession = true + h.startServer() + h.addTower(h.serverAddr) + h.waitServerUpdates(nil, time.Second) + + // Finally, allow the tower to ack session creation, + // allowing the state updates to be sent through the new + // session. + h.server.Stop() + h.serverCfg.NoAckCreateSession = false + h.startServer() + h.waitServerUpdates(hints[numUpdates/2:], 5*time.Second) + }, + }, } // TestClient executes the client test suite, asserting the ability to backup diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 395f16a4..1f66e245 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -19,7 +19,7 @@ type ClientDB struct { mu sync.Mutex summaries map[lnwire.ChannelID]wtdb.ClientChanSummary - activeSessions map[wtdb.SessionID]*wtdb.ClientSession + activeSessions map[wtdb.SessionID]wtdb.ClientSession towerIndex map[towerPK]wtdb.TowerID towers map[wtdb.TowerID]*wtdb.Tower @@ -31,7 +31,7 @@ type ClientDB struct { func NewClientDB() *ClientDB { return &ClientDB{ summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary), - activeSessions: make(map[wtdb.SessionID]*wtdb.ClientSession), + activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession), towerIndex: make(map[towerPK]wtdb.TowerID), towers: make(map[wtdb.TowerID]*wtdb.Tower), indexes: make(map[wtdb.TowerID]uint32), @@ -62,7 +62,7 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) { } for id, session := range towerSessions { session.Status = wtdb.CSessionActive - m.activeSessions[id] = session + m.activeSessions[id] = *session } } else { towerID = wtdb.TowerID(atomic.AddUint64(&m.nextTowerID, 1)) @@ -122,7 +122,7 @@ func (m *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { return wtdb.ErrTowerUnackedUpdates } session.Status = wtdb.CSessionInactive - m.activeSessions[id] = session + m.activeSessions[id] = *session } return nil @@ -205,10 +205,11 @@ func (m *ClientDB) listClientSessions( sessions := make(map[wtdb.SessionID]*wtdb.ClientSession) for _, session := range m.activeSessions { + session := session if tower != nil && *tower != session.TowerID { continue } - sessions[session.ID] = session + sessions[session.ID] = &session } return sessions, nil @@ -240,7 +241,7 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { // permits us to create another session with this tower. delete(m.indexes, session.TowerID) - m.activeSessions[session.ID] = &wtdb.ClientSession{ + m.activeSessions[session.ID] = wtdb.ClientSession{ ID: session.ID, ClientSessionBody: wtdb.ClientSessionBody{ SeqNum: session.SeqNum, @@ -313,6 +314,7 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, // Save the update and increment the sequence number. session.CommittedUpdates = append(session.CommittedUpdates, *update) session.SeqNum++ + m.activeSessions[*id] = session return session.TowerLastApplied, nil } @@ -360,6 +362,7 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err session.AckedUpdates[seqNum] = update.BackupID session.TowerLastApplied = lastApplied + m.activeSessions[*id] = session return nil }