Merge pull request #4274 from wpaulino/wtclient-add-tower-filter
wtclient: load missing info into client sessions upon new tower
This commit is contained in:
commit
c6ae06242d
@ -1,6 +1,7 @@
|
||||
package wtclientrpc
|
||||
|
||||
import (
|
||||
"github.com/btcsuite/btclog"
|
||||
"github.com/lightningnetwork/lnd/lncfg"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtclient"
|
||||
)
|
||||
@ -22,4 +23,7 @@ type Config struct {
|
||||
// addresses to ensure we don't leak any information when running over
|
||||
// non-clear networks, e.g. Tor, etc.
|
||||
Resolver lncfg.TCPResolver
|
||||
|
||||
// Log is the logger instance we should log output to.
|
||||
Log btclog.Logger
|
||||
}
|
||||
|
@ -1,48 +0,0 @@
|
||||
package wtclientrpc
|
||||
|
||||
import (
|
||||
"github.com/btcsuite/btclog"
|
||||
"github.com/lightningnetwork/lnd/build"
|
||||
)
|
||||
|
||||
// Subsystem defines the logging code for this subsystem.
|
||||
const Subsystem = "WTCL"
|
||||
|
||||
// log is a logger that is initialized with no output filters. This means the
|
||||
// package will not perform any logging by default until the caller requests
|
||||
// it.
|
||||
var log btclog.Logger
|
||||
|
||||
// The default amount of logging is none.
|
||||
func init() {
|
||||
UseLogger(build.NewSubLogger(Subsystem, nil))
|
||||
}
|
||||
|
||||
// DisableLog disables all library log output. Logging output is disabled by
|
||||
// by default until UseLogger is called.
|
||||
func DisableLog() {
|
||||
UseLogger(btclog.Disabled)
|
||||
}
|
||||
|
||||
// UseLogger uses a specified Logger to output package logging info. This
|
||||
// should be used in preference to SetLogWriter if the caller is also using
|
||||
// btclog.
|
||||
func UseLogger(logger btclog.Logger) {
|
||||
log = logger
|
||||
}
|
||||
|
||||
// logClosure is used to provide a closure over expensive logging operations so
|
||||
// don't have to be performed when the logging level doesn't warrant it.
|
||||
type logClosure func() string // nolint:unused
|
||||
|
||||
// String invokes the underlying function and returns the result.
|
||||
func (c logClosure) String() string {
|
||||
return c()
|
||||
}
|
||||
|
||||
// newLogClosure returns a new closure over a function that returns a string
|
||||
// which itself provides a Stringer interface so that it can be used with the
|
||||
// logging system.
|
||||
func newLogClosure(c func() string) logClosure { // nolint:unused
|
||||
return logClosure(c)
|
||||
}
|
@ -115,8 +115,8 @@ func (c *WatchtowerClient) RegisterWithRootServer(grpcServer *grpc.Server) error
|
||||
// all our methods are routed properly.
|
||||
RegisterWatchtowerClientServer(grpcServer, c)
|
||||
|
||||
log.Debugf("WatchtowerClient RPC server successfully registered with " +
|
||||
"root gRPC server")
|
||||
c.cfg.Log.Debugf("WatchtowerClient RPC server successfully registered " +
|
||||
"with root gRPC server")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
2
log.go
2
log.go
@ -25,7 +25,6 @@ import (
|
||||
"github.com/lightningnetwork/lnd/lnrpc/signrpc"
|
||||
"github.com/lightningnetwork/lnd/lnrpc/verrpc"
|
||||
"github.com/lightningnetwork/lnd/lnrpc/walletrpc"
|
||||
"github.com/lightningnetwork/lnd/lnrpc/wtclientrpc"
|
||||
"github.com/lightningnetwork/lnd/lnwallet"
|
||||
"github.com/lightningnetwork/lnd/lnwallet/chanfunding"
|
||||
"github.com/lightningnetwork/lnd/monitoring"
|
||||
@ -102,7 +101,6 @@ func init() {
|
||||
|
||||
addSubLogger(routing.Subsystem, routing.UseLogger, localchans.UseLogger)
|
||||
addSubLogger(routerrpc.Subsystem, routerrpc.UseLogger)
|
||||
addSubLogger(wtclientrpc.Subsystem, wtclientrpc.UseLogger)
|
||||
addSubLogger(chanfitness.Subsystem, chanfitness.UseLogger)
|
||||
addSubLogger(verrpc.Subsystem, verrpc.UseLogger)
|
||||
}
|
||||
|
@ -589,6 +589,7 @@ func newRPCServer(s *server, macService *macaroons.Service,
|
||||
s.htlcSwitch, activeNetParams.Params, s.chanRouter,
|
||||
routerBackend, s.nodeSigner, s.chanDB, s.sweeper, tower,
|
||||
s.towerClient, cfg.net.ResolveTCPAddr, genInvoiceFeatures,
|
||||
rpcsLog,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"reflect"
|
||||
|
||||
"github.com/btcsuite/btcd/chaincfg"
|
||||
"github.com/btcsuite/btclog"
|
||||
"github.com/lightningnetwork/lnd/autopilot"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/htlcswitch"
|
||||
@ -94,7 +95,8 @@ func (s *subRPCServerConfigs) PopulateDependencies(cc *chainControl,
|
||||
tower *watchtower.Standalone,
|
||||
towerClient wtclient.Client,
|
||||
tcpResolver lncfg.TCPResolver,
|
||||
genInvoiceFeatures func() *lnwire.FeatureVector) error {
|
||||
genInvoiceFeatures func() *lnwire.FeatureVector,
|
||||
rpcLogger btclog.Logger) error {
|
||||
|
||||
// First, we'll use reflect to obtain a version of the config struct
|
||||
// that allows us to programmatically inspect its fields.
|
||||
@ -244,6 +246,9 @@ func (s *subRPCServerConfigs) PopulateDependencies(cc *chainControl,
|
||||
subCfgValue.FieldByName("Resolver").Set(
|
||||
reflect.ValueOf(tcpResolver),
|
||||
)
|
||||
subCfgValue.FieldByName("Log").Set(
|
||||
reflect.ValueOf(rpcLogger),
|
||||
)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unknown field: %v, %T", fieldName,
|
||||
|
@ -38,6 +38,14 @@ const (
|
||||
DefaultForceQuitDelay = 10 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
// activeSessionFilter is a filter that ignored any sessions which are
|
||||
// not active.
|
||||
activeSessionFilter = func(s *wtdb.ClientSession) bool {
|
||||
return s.Status == wtdb.CSessionActive
|
||||
}
|
||||
)
|
||||
|
||||
// RegisteredTower encompasses information about a registered watchtower with
|
||||
// the client.
|
||||
type RegisteredTower struct {
|
||||
@ -268,49 +276,18 @@ func New(config *Config) (*TowerClient, error) {
|
||||
// the client. We will use any of these session if their policies match
|
||||
// the current policy of the client, otherwise they will be ignored and
|
||||
// new sessions will be requested.
|
||||
sessions, err := cfg.DB.ListClientSessions(nil)
|
||||
candidateSessions, err := getClientSessions(
|
||||
cfg.DB, cfg.SecretKeyRing, nil, activeSessionFilter,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
|
||||
sessionTowers := make(map[wtdb.TowerID]*wtdb.Tower)
|
||||
for _, s := range sessions {
|
||||
// Candidate sessions must be in an active state.
|
||||
if s.Status != wtdb.CSessionActive {
|
||||
continue
|
||||
}
|
||||
|
||||
// Reload the tower from disk using the tower ID contained in
|
||||
// each candidate session. We will also rederive any session
|
||||
// keys needed to be able to communicate with the towers and
|
||||
// authenticate session requests. This prevents us from having
|
||||
// to store the private keys on disk.
|
||||
tower, ok := sessionTowers[s.TowerID]
|
||||
if !ok {
|
||||
var err error
|
||||
tower, err = cfg.DB.LoadTowerByID(s.TowerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
s.Tower = tower
|
||||
|
||||
sessionKey, err := DeriveSessionKey(cfg.SecretKeyRing, s.KeyIndex)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.SessionPrivKey = sessionKey
|
||||
|
||||
candidateSessions[s.ID] = s
|
||||
sessionTowers[tower.ID] = tower
|
||||
}
|
||||
|
||||
var candidateTowers []*wtdb.Tower
|
||||
for _, tower := range sessionTowers {
|
||||
for _, s := range candidateSessions {
|
||||
log.Infof("Using private watchtower %s, offering policy %s",
|
||||
tower, cfg.Policy)
|
||||
candidateTowers = append(candidateTowers, tower)
|
||||
s.Tower, cfg.Policy)
|
||||
candidateTowers = append(candidateTowers, s.Tower)
|
||||
}
|
||||
|
||||
// Load the sweep pkscripts that have been generated for all previously
|
||||
@ -353,6 +330,50 @@ func New(config *Config) (*TowerClient, error) {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// getClientSessions retrieves the client sessions for a particular tower if
|
||||
// specified, otherwise all client sessions for all towers are retrieved. An
|
||||
// optional filter can be provided to filter out any undesired client sessions.
|
||||
//
|
||||
// NOTE: This method should only be used when deserialization of a
|
||||
// ClientSession's Tower and SessionPrivKey fields is desired, otherwise, the
|
||||
// existing ListClientSessions method should be used.
|
||||
func getClientSessions(db DB, keyRing SecretKeyRing, forTower *wtdb.TowerID,
|
||||
passesFilter func(*wtdb.ClientSession) bool) (
|
||||
map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
||||
|
||||
sessions, err := db.ListClientSessions(forTower)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Reload the tower from disk using the tower ID contained in each
|
||||
// candidate session. We will also rederive any session keys needed to
|
||||
// be able to communicate with the towers and authenticate session
|
||||
// requests. This prevents us from having to store the private keys on
|
||||
// disk.
|
||||
for _, s := range sessions {
|
||||
tower, err := db.LoadTowerByID(s.TowerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.Tower = tower
|
||||
|
||||
sessionKey, err := DeriveSessionKey(keyRing, s.KeyIndex)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.SessionPrivKey = sessionKey
|
||||
|
||||
// If an optional filter was provided, use it to filter out any
|
||||
// undesired sessions.
|
||||
if passesFilter != nil && !passesFilter(s) {
|
||||
delete(sessions, s.ID)
|
||||
}
|
||||
}
|
||||
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
// buildHighestCommitHeights inspects the full set of candidate client sessions
|
||||
// loaded from disk, and determines the highest known commit height for each
|
||||
// channel. This allows the client to reject backups that it has already
|
||||
@ -1039,7 +1060,9 @@ func (c *TowerClient) handleNewTower(msg *newTowerMsg) error {
|
||||
c.candidateTowers.AddCandidate(tower)
|
||||
|
||||
// Include all of its corresponding sessions to our set of candidates.
|
||||
sessions, err := c.cfg.DB.ListClientSessions(&tower.ID)
|
||||
sessions, err := getClientSessions(
|
||||
c.cfg.DB, c.cfg.SecretKeyRing, &tower.ID, activeSessionFilter,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to determine sessions for tower %x: "+
|
||||
"%v", tower.IdentityKey.SerializeCompressed(), err)
|
||||
|
@ -366,17 +366,18 @@ func (c *mockChannel) getState(i uint64) (*wire.MsgTx, *lnwallet.BreachRetributi
|
||||
}
|
||||
|
||||
type testHarness struct {
|
||||
t *testing.T
|
||||
cfg harnessCfg
|
||||
signer *wtmock.MockSigner
|
||||
capacity lnwire.MilliSatoshi
|
||||
clientDB *wtmock.ClientDB
|
||||
clientCfg *wtclient.Config
|
||||
client wtclient.Client
|
||||
serverDB *wtmock.TowerDB
|
||||
serverCfg *wtserver.Config
|
||||
server *wtserver.Server
|
||||
net *mockNet
|
||||
t *testing.T
|
||||
cfg harnessCfg
|
||||
signer *wtmock.MockSigner
|
||||
capacity lnwire.MilliSatoshi
|
||||
clientDB *wtmock.ClientDB
|
||||
clientCfg *wtclient.Config
|
||||
client wtclient.Client
|
||||
serverAddr *lnwire.NetAddress
|
||||
serverDB *wtmock.TowerDB
|
||||
serverCfg *wtserver.Config
|
||||
server *wtserver.Server
|
||||
net *mockNet
|
||||
|
||||
mu sync.Mutex
|
||||
channels map[lnwire.ChannelID]*mockChannel
|
||||
@ -467,18 +468,19 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
|
||||
}
|
||||
|
||||
h := &testHarness{
|
||||
t: t,
|
||||
cfg: cfg,
|
||||
signer: signer,
|
||||
capacity: cfg.localBalance + cfg.remoteBalance,
|
||||
clientDB: clientDB,
|
||||
clientCfg: clientCfg,
|
||||
client: client,
|
||||
serverDB: serverDB,
|
||||
serverCfg: serverCfg,
|
||||
server: server,
|
||||
net: mockNet,
|
||||
channels: make(map[lnwire.ChannelID]*mockChannel),
|
||||
t: t,
|
||||
cfg: cfg,
|
||||
signer: signer,
|
||||
capacity: cfg.localBalance + cfg.remoteBalance,
|
||||
clientDB: clientDB,
|
||||
clientCfg: clientCfg,
|
||||
client: client,
|
||||
serverAddr: towerAddr,
|
||||
serverDB: serverDB,
|
||||
serverCfg: serverCfg,
|
||||
server: server,
|
||||
net: mockNet,
|
||||
channels: make(map[lnwire.ChannelID]*mockChannel),
|
||||
}
|
||||
|
||||
h.makeChannel(0, h.cfg.localBalance, h.cfg.remoteBalance)
|
||||
@ -782,6 +784,25 @@ func (h *testHarness) assertUpdatesForPolicy(hints []blob.BreachHint,
|
||||
}
|
||||
}
|
||||
|
||||
// addTower adds a tower found at `addr` to the client.
|
||||
func (h *testHarness) addTower(addr *lnwire.NetAddress) {
|
||||
h.t.Helper()
|
||||
|
||||
if err := h.client.AddTower(addr); err != nil {
|
||||
h.t.Fatalf("unable to add tower: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// removeTower removes a tower from the client. If `addr` is specified, then the
|
||||
// only said address is removed from the tower.
|
||||
func (h *testHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr) {
|
||||
h.t.Helper()
|
||||
|
||||
if err := h.client.RemoveTower(pubKey, addr); err != nil {
|
||||
h.t.Fatalf("unable to remove tower: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
localBalance = lnwire.MilliSatoshi(100000000)
|
||||
remoteBalance = lnwire.MilliSatoshi(200000000)
|
||||
@ -1396,6 +1417,60 @@ var clientTests = []clientTest{
|
||||
h.waitServerUpdates(hints, 5*time.Second)
|
||||
},
|
||||
},
|
||||
{
|
||||
// Asserts that the client can continue making backups to a
|
||||
// tower that's been re-added after it's been removed.
|
||||
name: "re-add removed tower",
|
||||
cfg: harnessCfg{
|
||||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
TxPolicy: wtpolicy.TxPolicy{
|
||||
BlobType: blob.TypeAltruistCommit,
|
||||
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
|
||||
},
|
||||
MaxUpdates: 5,
|
||||
},
|
||||
},
|
||||
fn: func(h *testHarness) {
|
||||
const (
|
||||
chanID = 0
|
||||
numUpdates = 4
|
||||
)
|
||||
|
||||
// Create four channel updates and only back up the
|
||||
// first two.
|
||||
hints := h.advanceChannelN(chanID, numUpdates)
|
||||
h.backupStates(chanID, 0, numUpdates/2, nil)
|
||||
h.waitServerUpdates(hints[:numUpdates/2], 5*time.Second)
|
||||
|
||||
// Fully remove the tower, causing its existing sessions
|
||||
// to be marked inactive.
|
||||
h.removeTower(h.serverAddr.IdentityKey, nil)
|
||||
|
||||
// Back up the remaining states. Since the tower has
|
||||
// been removed, it shouldn't receive any updates.
|
||||
h.backupStates(chanID, numUpdates/2, numUpdates, nil)
|
||||
h.waitServerUpdates(nil, time.Second)
|
||||
|
||||
// Re-add the tower. We prevent the tower from acking
|
||||
// session creation to ensure the inactive sessions are
|
||||
// not used.
|
||||
h.server.Stop()
|
||||
h.serverCfg.NoAckCreateSession = true
|
||||
h.startServer()
|
||||
h.addTower(h.serverAddr)
|
||||
h.waitServerUpdates(nil, time.Second)
|
||||
|
||||
// Finally, allow the tower to ack session creation,
|
||||
// allowing the state updates to be sent through the new
|
||||
// session.
|
||||
h.server.Stop()
|
||||
h.serverCfg.NoAckCreateSession = false
|
||||
h.startServer()
|
||||
h.waitServerUpdates(hints[numUpdates/2:], 5*time.Second)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// TestClient executes the client test suite, asserting the ability to backup
|
||||
|
@ -19,7 +19,7 @@ type ClientDB struct {
|
||||
|
||||
mu sync.Mutex
|
||||
summaries map[lnwire.ChannelID]wtdb.ClientChanSummary
|
||||
activeSessions map[wtdb.SessionID]*wtdb.ClientSession
|
||||
activeSessions map[wtdb.SessionID]wtdb.ClientSession
|
||||
towerIndex map[towerPK]wtdb.TowerID
|
||||
towers map[wtdb.TowerID]*wtdb.Tower
|
||||
|
||||
@ -31,7 +31,7 @@ type ClientDB struct {
|
||||
func NewClientDB() *ClientDB {
|
||||
return &ClientDB{
|
||||
summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary),
|
||||
activeSessions: make(map[wtdb.SessionID]*wtdb.ClientSession),
|
||||
activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession),
|
||||
towerIndex: make(map[towerPK]wtdb.TowerID),
|
||||
towers: make(map[wtdb.TowerID]*wtdb.Tower),
|
||||
indexes: make(map[wtdb.TowerID]uint32),
|
||||
@ -62,7 +62,7 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
|
||||
}
|
||||
for id, session := range towerSessions {
|
||||
session.Status = wtdb.CSessionActive
|
||||
m.activeSessions[id] = session
|
||||
m.activeSessions[id] = *session
|
||||
}
|
||||
} else {
|
||||
towerID = wtdb.TowerID(atomic.AddUint64(&m.nextTowerID, 1))
|
||||
@ -122,7 +122,7 @@ func (m *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
|
||||
return wtdb.ErrTowerUnackedUpdates
|
||||
}
|
||||
session.Status = wtdb.CSessionInactive
|
||||
m.activeSessions[id] = session
|
||||
m.activeSessions[id] = *session
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -205,10 +205,11 @@ func (m *ClientDB) listClientSessions(
|
||||
|
||||
sessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
|
||||
for _, session := range m.activeSessions {
|
||||
session := session
|
||||
if tower != nil && *tower != session.TowerID {
|
||||
continue
|
||||
}
|
||||
sessions[session.ID] = session
|
||||
sessions[session.ID] = &session
|
||||
}
|
||||
|
||||
return sessions, nil
|
||||
@ -240,7 +241,7 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
|
||||
// permits us to create another session with this tower.
|
||||
delete(m.indexes, session.TowerID)
|
||||
|
||||
m.activeSessions[session.ID] = &wtdb.ClientSession{
|
||||
m.activeSessions[session.ID] = wtdb.ClientSession{
|
||||
ID: session.ID,
|
||||
ClientSessionBody: wtdb.ClientSessionBody{
|
||||
SeqNum: session.SeqNum,
|
||||
@ -313,6 +314,7 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
|
||||
// Save the update and increment the sequence number.
|
||||
session.CommittedUpdates = append(session.CommittedUpdates, *update)
|
||||
session.SeqNum++
|
||||
m.activeSessions[*id] = session
|
||||
|
||||
return session.TowerLastApplied, nil
|
||||
}
|
||||
@ -360,6 +362,7 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err
|
||||
session.AckedUpdates[seqNum] = update.BackupID
|
||||
session.TowerLastApplied = lastApplied
|
||||
|
||||
m.activeSessions[*id] = session
|
||||
return nil
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user