Merge pull request #4274 from wpaulino/wtclient-add-tower-filter

wtclient: load missing info into client sessions upon new tower
Olaoluwa Osuntokun 2020-05-15 16:13:22 -07:00 committed by GitHub
commit c6ae06242d
9 changed files with 181 additions and 120 deletions
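
The gist of the fix: sessions returned by `DB.ListClientSessions` do not carry their `Tower` or `SessionPrivKey` fields, since neither is persisted to disk. Before this change, `handleNewTower` added such incomplete sessions as candidates. A minimal sketch of the gap being closed (simplified, not lnd's exact call sites):

```go
// Hypothetical illustration: a session listed straight from the DB is
// unusable until its tower and session key are reattached.
sessions, err := db.ListClientSessions(&tower.ID)
if err != nil {
	return err
}
for _, s := range sessions {
	// Before this commit, s.Tower and s.SessionPrivKey were left unset
	// here: both are rederived at load time rather than stored, so
	// negotiating or backing up with this session would fail.
	_ = s
}
```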

@@ -1,6 +1,7 @@
 package wtclientrpc
 
 import (
+	"github.com/btcsuite/btclog"
 	"github.com/lightningnetwork/lnd/lncfg"
 	"github.com/lightningnetwork/lnd/watchtower/wtclient"
 )
@@ -22,4 +23,7 @@ type Config struct {
 	// addresses to ensure we don't leak any information when running over
 	// non-clear networks, e.g. Tor, etc.
 	Resolver lncfg.TCPResolver
+
+	// Log is the logger instance we should log output to.
+	Log btclog.Logger
 }
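
With the package-level logger removed (see the deleted log.go below), callers now inject a logger through the config. A minimal sketch of wiring the new field, assuming a btclog backend writing to stdout; the helper function here is illustrative only:

```go
import (
	"net"
	"os"

	"github.com/btcsuite/btclog"
	"github.com/lightningnetwork/lnd/lnrpc/wtclientrpc"
)

// buildConfig is a hypothetical helper showing the injected logger.
func buildConfig() *wtclientrpc.Config {
	backend := btclog.NewBackend(os.Stdout)
	return &wtclientrpc.Config{
		// net.ResolveTCPAddr satisfies lncfg.TCPResolver.
		Resolver: net.ResolveTCPAddr,
		Log:      backend.Logger("WTCL"),
	}
}
```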

@@ -1,48 +0,0 @@
-package wtclientrpc
-
-import (
-	"github.com/btcsuite/btclog"
-	"github.com/lightningnetwork/lnd/build"
-)
-
-// Subsystem defines the logging code for this subsystem.
-const Subsystem = "WTCL"
-
-// log is a logger that is initialized with no output filters. This means the
-// package will not perform any logging by default until the caller requests
-// it.
-var log btclog.Logger
-
-// The default amount of logging is none.
-func init() {
-	UseLogger(build.NewSubLogger(Subsystem, nil))
-}
-
-// DisableLog disables all library log output. Logging output is disabled
-// by default until UseLogger is called.
-func DisableLog() {
-	UseLogger(btclog.Disabled)
-}
-
-// UseLogger uses a specified Logger to output package logging info. This
-// should be used in preference to SetLogWriter if the caller is also using
-// btclog.
-func UseLogger(logger btclog.Logger) {
-	log = logger
-}
-
-// logClosure is used to provide a closure over expensive logging operations so
-// don't have to be performed when the logging level doesn't warrant it.
-type logClosure func() string // nolint:unused
-
-// String invokes the underlying function and returns the result.
-func (c logClosure) String() string {
-	return c()
-}
-
-// newLogClosure returns a new closure over a function that returns a string
-// which itself provides a Stringer interface so that it can be used with the
-// logging system.
-func newLogClosure(c func() string) logClosure { // nolint:unused
-	return logClosure(c)
-}

@@ -115,8 +115,8 @@ func (c *WatchtowerClient) RegisterWithRootServer(grpcServer *grpc.Server) error
 	// all our methods are routed properly.
 	RegisterWatchtowerClientServer(grpcServer, c)
 
-	log.Debugf("WatchtowerClient RPC server successfully registered with " +
-		"root gRPC server")
+	c.cfg.Log.Debugf("WatchtowerClient RPC server successfully registered " +
+		"with root gRPC server")
 
 	return nil
 }

log.go

@@ -25,7 +25,6 @@ import (
 	"github.com/lightningnetwork/lnd/lnrpc/signrpc"
 	"github.com/lightningnetwork/lnd/lnrpc/verrpc"
 	"github.com/lightningnetwork/lnd/lnrpc/walletrpc"
-	"github.com/lightningnetwork/lnd/lnrpc/wtclientrpc"
 	"github.com/lightningnetwork/lnd/lnwallet"
 	"github.com/lightningnetwork/lnd/lnwallet/chanfunding"
 	"github.com/lightningnetwork/lnd/monitoring"
@@ -102,7 +101,6 @@ func init() {
 	addSubLogger(routing.Subsystem, routing.UseLogger, localchans.UseLogger)
 	addSubLogger(routerrpc.Subsystem, routerrpc.UseLogger)
-	addSubLogger(wtclientrpc.Subsystem, wtclientrpc.UseLogger)
 	addSubLogger(chanfitness.Subsystem, chanfitness.UseLogger)
 	addSubLogger(verrpc.Subsystem, verrpc.UseLogger)
 }

@@ -589,6 +589,7 @@ func newRPCServer(s *server, macService *macaroons.Service,
 		s.htlcSwitch, activeNetParams.Params, s.chanRouter,
 		routerBackend, s.nodeSigner, s.chanDB, s.sweeper, tower,
 		s.towerClient, cfg.net.ResolveTCPAddr, genInvoiceFeatures,
+		rpcsLog,
 	)
 	if err != nil {
 		return nil, err

@@ -5,6 +5,7 @@ import (
 	"reflect"
 
 	"github.com/btcsuite/btcd/chaincfg"
+	"github.com/btcsuite/btclog"
 	"github.com/lightningnetwork/lnd/autopilot"
 	"github.com/lightningnetwork/lnd/channeldb"
 	"github.com/lightningnetwork/lnd/htlcswitch"
@@ -94,7 +95,8 @@ func (s *subRPCServerConfigs) PopulateDependencies(cc *chainControl,
 	tower *watchtower.Standalone,
 	towerClient wtclient.Client,
 	tcpResolver lncfg.TCPResolver,
-	genInvoiceFeatures func() *lnwire.FeatureVector) error {
+	genInvoiceFeatures func() *lnwire.FeatureVector,
+	rpcLogger btclog.Logger) error {
 
 	// First, we'll use reflect to obtain a version of the config struct
 	// that allows us to programmatically inspect its fields.
@@ -244,6 +246,9 @@ func (s *subRPCServerConfigs) PopulateDependencies(cc *chainControl,
 			subCfgValue.FieldByName("Resolver").Set(
 				reflect.ValueOf(tcpResolver),
 			)
+			subCfgValue.FieldByName("Log").Set(
+				reflect.ValueOf(rpcLogger),
+			)
 
 		default:
 			return fmt.Errorf("unknown field: %v, %T", fieldName,
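
PopulateDependencies assigns each sub-server's config fields by name via reflection, which is why the new logger needs only one extra `FieldByName` call. A self-contained sketch of that pattern (a generic helper, not lnd's code):

```go
package main

import (
	"fmt"
	"reflect"
)

// setField sets a struct field by name on a pointer to a struct: the
// same reflect pattern PopulateDependencies applies per known field.
// Note that Set panics if the value's type does not match the field's.
func setField(cfg interface{}, name string, value interface{}) error {
	v := reflect.ValueOf(cfg).Elem() // cfg must be a non-nil pointer
	f := v.FieldByName(name)
	if !f.IsValid() || !f.CanSet() {
		return fmt.Errorf("no settable field %q", name)
	}
	f.Set(reflect.ValueOf(value))
	return nil
}

func main() {
	type config struct{ Log string }
	var c config
	if err := setField(&c, "Log", "WTCL"); err != nil {
		panic(err)
	}
	fmt.Println(c.Log) // WTCL
}
```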

@@ -38,6 +38,14 @@ const (
 	DefaultForceQuitDelay = 10 * time.Second
 )
 
+var (
+	// activeSessionFilter is a filter that ignores any sessions which are
+	// not active.
+	activeSessionFilter = func(s *wtdb.ClientSession) bool {
+		return s.Status == wtdb.CSessionActive
+	}
+)
+
 // RegisteredTower encompasses information about a registered watchtower with
 // the client.
 type RegisteredTower struct {
@@ -268,49 +276,18 @@ func New(config *Config) (*TowerClient, error) {
 	// the client. We will use any of these sessions if their policies match
 	// the current policy of the client, otherwise they will be ignored and
 	// new sessions will be requested.
-	sessions, err := cfg.DB.ListClientSessions(nil)
+	candidateSessions, err := getClientSessions(
+		cfg.DB, cfg.SecretKeyRing, nil, activeSessionFilter,
+	)
 	if err != nil {
 		return nil, err
 	}
 
-	candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
-	sessionTowers := make(map[wtdb.TowerID]*wtdb.Tower)
-	for _, s := range sessions {
-		// Candidate sessions must be in an active state.
-		if s.Status != wtdb.CSessionActive {
-			continue
-		}
-
-		// Reload the tower from disk using the tower ID contained in
-		// each candidate session. We will also rederive any session
-		// keys needed to be able to communicate with the towers and
-		// authenticate session requests. This prevents us from having
-		// to store the private keys on disk.
-		tower, ok := sessionTowers[s.TowerID]
-		if !ok {
-			var err error
-			tower, err = cfg.DB.LoadTowerByID(s.TowerID)
-			if err != nil {
-				return nil, err
-			}
-		}
-		s.Tower = tower
-
-		sessionKey, err := DeriveSessionKey(cfg.SecretKeyRing, s.KeyIndex)
-		if err != nil {
-			return nil, err
-		}
-		s.SessionPrivKey = sessionKey
-
-		candidateSessions[s.ID] = s
-		sessionTowers[tower.ID] = tower
-	}
-
 	var candidateTowers []*wtdb.Tower
-	for _, tower := range sessionTowers {
+	for _, s := range candidateSessions {
 		log.Infof("Using private watchtower %s, offering policy %s",
-			tower, cfg.Policy)
-		candidateTowers = append(candidateTowers, tower)
+			s.Tower, cfg.Policy)
+		candidateTowers = append(candidateTowers, s.Tower)
 	}
 
 	// Load the sweep pkscripts that have been generated for all previously
@@ -353,6 +330,50 @@ func New(config *Config) (*TowerClient, error) {
 	return c, nil
 }
 
+// getClientSessions retrieves the client sessions for a particular tower if
+// specified, otherwise all client sessions for all towers are retrieved. An
+// optional filter can be provided to filter out any undesired client sessions.
+//
+// NOTE: This method should only be used when deserialization of a
+// ClientSession's Tower and SessionPrivKey fields is desired, otherwise, the
+// existing ListClientSessions method should be used.
+func getClientSessions(db DB, keyRing SecretKeyRing, forTower *wtdb.TowerID,
+	passesFilter func(*wtdb.ClientSession) bool) (
+	map[wtdb.SessionID]*wtdb.ClientSession, error) {
+
+	sessions, err := db.ListClientSessions(forTower)
+	if err != nil {
+		return nil, err
+	}
+
+	// Reload the tower from disk using the tower ID contained in each
+	// candidate session. We will also rederive any session keys needed to
+	// be able to communicate with the towers and authenticate session
+	// requests. This prevents us from having to store the private keys on
+	// disk.
+	for _, s := range sessions {
+		tower, err := db.LoadTowerByID(s.TowerID)
+		if err != nil {
+			return nil, err
+		}
+		s.Tower = tower
+
+		sessionKey, err := DeriveSessionKey(keyRing, s.KeyIndex)
+		if err != nil {
+			return nil, err
+		}
+		s.SessionPrivKey = sessionKey
+
+		// If an optional filter was provided, use it to filter out any
+		// undesired sessions.
+		if passesFilter != nil && !passesFilter(s) {
+			delete(sessions, s.ID)
+		}
+	}
+
+	return sessions, nil
+}
+
 // buildHighestCommitHeights inspects the full set of candidate client sessions
 // loaded from disk, and determines the highest known commit height for each
 // channel. This allows the client to reject backups that it has already
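
Usage mirrors the call in New above: a nil tower ID retrieves sessions for all towers, and the filter drops unwanted ones. Note the design choice that the filter runs after the tower load and key derivation, so even filtered-out sessions cost a DB read and a derivation. A sketch, using only names from the diff:

```go
// Every returned session has Tower and SessionPrivKey populated, and
// inactive sessions are already removed by activeSessionFilter.
candidates, err := getClientSessions(
	cfg.DB, cfg.SecretKeyRing, nil, activeSessionFilter,
)
if err != nil {
	return nil, err
}
for id, s := range candidates {
	log.Infof("Session %v is a candidate for tower %v", id, s.Tower)
}
```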
@@ -1039,7 +1060,9 @@ func (c *TowerClient) handleNewTower(msg *newTowerMsg) error {
 	c.candidateTowers.AddCandidate(tower)
 
 	// Include all of its corresponding sessions to our set of candidates.
-	sessions, err := c.cfg.DB.ListClientSessions(&tower.ID)
+	sessions, err := getClientSessions(
+		c.cfg.DB, c.cfg.SecretKeyRing, &tower.ID, activeSessionFilter,
+	)
 	if err != nil {
 		return fmt.Errorf("unable to determine sessions for tower %x: "+
 			"%v", tower.IdentityKey.SerializeCompressed(), err)

@@ -366,17 +366,18 @@ func (c *mockChannel) getState(i uint64) (*wire.MsgTx, *lnwallet.BreachRetribution
 }
 
 type testHarness struct {
-	t         *testing.T
-	cfg       harnessCfg
-	signer    *wtmock.MockSigner
-	capacity  lnwire.MilliSatoshi
-	clientDB  *wtmock.ClientDB
-	clientCfg *wtclient.Config
-	client    wtclient.Client
-	serverDB  *wtmock.TowerDB
-	serverCfg *wtserver.Config
-	server    *wtserver.Server
-	net       *mockNet
+	t          *testing.T
+	cfg        harnessCfg
+	signer     *wtmock.MockSigner
+	capacity   lnwire.MilliSatoshi
+	clientDB   *wtmock.ClientDB
+	clientCfg  *wtclient.Config
+	client     wtclient.Client
+	serverAddr *lnwire.NetAddress
+	serverDB   *wtmock.TowerDB
+	serverCfg  *wtserver.Config
+	server     *wtserver.Server
+	net        *mockNet
 
 	mu       sync.Mutex
 	channels map[lnwire.ChannelID]*mockChannel
@@ -467,18 +468,19 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
 	}
 
 	h := &testHarness{
-		t:         t,
-		cfg:       cfg,
-		signer:    signer,
-		capacity:  cfg.localBalance + cfg.remoteBalance,
-		clientDB:  clientDB,
-		clientCfg: clientCfg,
-		client:    client,
-		serverDB:  serverDB,
-		serverCfg: serverCfg,
-		server:    server,
-		net:       mockNet,
-		channels:  make(map[lnwire.ChannelID]*mockChannel),
+		t:          t,
+		cfg:        cfg,
+		signer:     signer,
+		capacity:   cfg.localBalance + cfg.remoteBalance,
+		clientDB:   clientDB,
+		clientCfg:  clientCfg,
+		client:     client,
+		serverAddr: towerAddr,
+		serverDB:   serverDB,
+		serverCfg:  serverCfg,
+		server:     server,
+		net:        mockNet,
+		channels:   make(map[lnwire.ChannelID]*mockChannel),
 	}
 
 	h.makeChannel(0, h.cfg.localBalance, h.cfg.remoteBalance)
@@ -782,6 +784,25 @@ func (h *testHarness) assertUpdatesForPolicy(hints []blob.BreachHint,
 	}
 }
 
+// addTower adds a tower found at `addr` to the client.
+func (h *testHarness) addTower(addr *lnwire.NetAddress) {
+	h.t.Helper()
+
+	if err := h.client.AddTower(addr); err != nil {
+		h.t.Fatalf("unable to add tower: %v", err)
+	}
+}
+
+// removeTower removes a tower from the client. If `addr` is specified, then
+// only said address is removed from the tower.
+func (h *testHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr) {
+	h.t.Helper()
+
+	if err := h.client.RemoveTower(pubKey, addr); err != nil {
+		h.t.Fatalf("unable to remove tower: %v", err)
+	}
+}
+
 const (
 	localBalance  = lnwire.MilliSatoshi(100000000)
 	remoteBalance = lnwire.MilliSatoshi(200000000)
@@ -1396,6 +1417,60 @@ var clientTests = []clientTest{
 			h.waitServerUpdates(hints, 5*time.Second)
 		},
 	},
+	{
+		// Asserts that the client can continue making backups to a
+		// tower that's been re-added after it's been removed.
+		name: "re-add removed tower",
+		cfg: harnessCfg{
+			localBalance:  localBalance,
+			remoteBalance: remoteBalance,
+			policy: wtpolicy.Policy{
+				TxPolicy: wtpolicy.TxPolicy{
+					BlobType:     blob.TypeAltruistCommit,
+					SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
+				},
+				MaxUpdates: 5,
+			},
+		},
+		fn: func(h *testHarness) {
+			const (
+				chanID     = 0
+				numUpdates = 4
+			)
+
+			// Create four channel updates and only back up the
+			// first two.
+			hints := h.advanceChannelN(chanID, numUpdates)
+			h.backupStates(chanID, 0, numUpdates/2, nil)
+			h.waitServerUpdates(hints[:numUpdates/2], 5*time.Second)
+
+			// Fully remove the tower, causing its existing sessions
+			// to be marked inactive.
+			h.removeTower(h.serverAddr.IdentityKey, nil)
+
+			// Back up the remaining states. Since the tower has
+			// been removed, it shouldn't receive any updates.
+			h.backupStates(chanID, numUpdates/2, numUpdates, nil)
+			h.waitServerUpdates(nil, time.Second)
+
+			// Re-add the tower. We prevent the tower from acking
+			// session creation to ensure the inactive sessions are
+			// not used.
+			h.server.Stop()
+			h.serverCfg.NoAckCreateSession = true
+			h.startServer()
+			h.addTower(h.serverAddr)
+			h.waitServerUpdates(nil, time.Second)
+
+			// Finally, allow the tower to ack session creation,
+			// allowing the state updates to be sent through the new
+			// session.
+			h.server.Stop()
+			h.serverCfg.NoAckCreateSession = false
+			h.startServer()
+			h.waitServerUpdates(hints[numUpdates/2:], 5*time.Second)
+		},
+	},
 }
 
 // TestClient executes the client test suite, asserting the ability to backup

@@ -19,7 +19,7 @@ type ClientDB struct {
 	mu             sync.Mutex
 	summaries      map[lnwire.ChannelID]wtdb.ClientChanSummary
-	activeSessions map[wtdb.SessionID]*wtdb.ClientSession
+	activeSessions map[wtdb.SessionID]wtdb.ClientSession
 	towerIndex     map[towerPK]wtdb.TowerID
 	towers         map[wtdb.TowerID]*wtdb.Tower
@@ -31,7 +31,7 @@ func NewClientDB() *ClientDB {
 	return &ClientDB{
 		summaries:      make(map[lnwire.ChannelID]wtdb.ClientChanSummary),
-		activeSessions: make(map[wtdb.SessionID]*wtdb.ClientSession),
+		activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession),
 		towerIndex:     make(map[towerPK]wtdb.TowerID),
 		towers:         make(map[wtdb.TowerID]*wtdb.Tower),
 		indexes:        make(map[wtdb.TowerID]uint32),
@@ -62,7 +62,7 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
 		}
 		for id, session := range towerSessions {
 			session.Status = wtdb.CSessionActive
-			m.activeSessions[id] = session
+			m.activeSessions[id] = *session
 		}
 	} else {
 		towerID = wtdb.TowerID(atomic.AddUint64(&m.nextTowerID, 1))
@@ -122,7 +122,7 @@ func (m *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
 			return wtdb.ErrTowerUnackedUpdates
 		}
 		session.Status = wtdb.CSessionInactive
-		m.activeSessions[id] = *session
+		m.activeSessions[id] = *session
 	}
 
 	return nil
@@ -205,10 +205,11 @@ func (m *ClientDB) listClientSessions(
 	sessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
 	for _, session := range m.activeSessions {
+		session := session
 		if tower != nil && *tower != session.TowerID {
 			continue
 		}
-		sessions[session.ID] = session
+		sessions[session.ID] = &session
 	}
 
 	return sessions, nil
@@ -240,7 +241,7 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
 	// permits us to create another session with this tower.
 	delete(m.indexes, session.TowerID)
 
-	m.activeSessions[session.ID] = &wtdb.ClientSession{
+	m.activeSessions[session.ID] = wtdb.ClientSession{
 		ID: session.ID,
 		ClientSessionBody: wtdb.ClientSessionBody{
 			SeqNum: session.SeqNum,
@@ -313,6 +314,7 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
 	// Save the update and increment the sequence number.
 	session.CommittedUpdates = append(session.CommittedUpdates, *update)
 	session.SeqNum++
+	m.activeSessions[*id] = session
 
 	return session.TowerLastApplied, nil
 }
@@ -360,6 +362,7 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) error
 	session.AckedUpdates[seqNum] = update.BackupID
 	session.TowerLastApplied = lastApplied
+	m.activeSessions[*id] = session
 
 	return nil
 }
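
The mock DB now stores sessions by value rather than by pointer, which is why CommitUpdate and AckUpdate gained explicit write-backs: a locally mutated copy is not visible in the map until reassigned, and listClientSessions hands out pointers to per-iteration copies (`session := session`) so callers can no longer mutate the mock's "persisted" state in place. A toy illustration of the difference:

```go
package main

import "fmt"

type session struct{ seq int }

func main() {
	// Pointer map: the caller's mutation aliases the stored entry.
	byPtr := map[int]*session{1: {seq: 1}}
	p := byPtr[1]
	p.seq = 99
	fmt.Println(byPtr[1].seq) // 99: mutated through the alias

	// Value map: mutations stay on the copy until written back,
	// mirroring the explicit m.activeSessions[*id] = session above.
	byVal := map[int]session{1: {seq: 1}}
	v := byVal[1]
	v.seq = 99
	fmt.Println(byVal[1].seq) // still 1
	byVal[1] = v              // explicit write-back required
	fmt.Println(byVal[1].seq) // now 99
}
```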