watchtower/wtclient: integrate ClientChannelSummaries

In this commit, we utilize the more generic ClientChanSummary instead of
exposing methods that only allow us to set and fetch sweep pkscripts.
This commit is contained in:
Conner Fromknecht 2019-05-23 20:48:50 -07:00
parent 25fc464a6e
commit b35a5b8892
No known key found for this signature in database
GPG Key ID: E7D737B67FA592C7
5 changed files with 57 additions and 34 deletions

@ -150,8 +150,8 @@ type TowerClient struct {
sessionQueue *sessionQueue sessionQueue *sessionQueue
prevTask *backupTask prevTask *backupTask
sweepPkScriptMu sync.RWMutex summaryMu sync.RWMutex
sweepPkScripts map[lnwire.ChannelID][]byte summaries wtdb.ChannelSummaries
statTicker *time.Ticker statTicker *time.Ticker
stats clientStats stats clientStats
@ -245,7 +245,7 @@ func New(config *Config) (*TowerClient, error) {
// Finally, load the sweep pkscripts that have been generated for all // Finally, load the sweep pkscripts that have been generated for all
// previously registered channels. // previously registered channels.
c.sweepPkScripts, err = c.cfg.DB.FetchChanPkScripts() c.summaries, err = c.cfg.DB.FetchChanSummaries()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -388,12 +388,12 @@ func (c *TowerClient) ForceQuit() {
// within the client. This should be called during link startup to ensure that // within the client. This should be called during link startup to ensure that
// the client is able to support the link during operation. // the client is able to support the link during operation.
func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error { func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error {
c.sweepPkScriptMu.Lock() c.summaryMu.Lock()
defer c.sweepPkScriptMu.Unlock() defer c.summaryMu.Unlock()
// If a pkscript for this channel already exists, the channel has been // If a pkscript for this channel already exists, the channel has been
// previously registered. // previously registered.
if _, ok := c.sweepPkScripts[chanID]; ok { if _, ok := c.summaries[chanID]; ok {
return nil return nil
} }
@ -406,14 +406,16 @@ func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error {
// Persist the sweep pkscript so that restarts will not introduce // Persist the sweep pkscript so that restarts will not introduce
// address inflation when the channel is reregistered after a restart. // address inflation when the channel is reregistered after a restart.
err = c.cfg.DB.AddChanPkScript(chanID, pkScript) err = c.cfg.DB.RegisterChannel(chanID, pkScript)
if err != nil { if err != nil {
return err return err
} }
// Finally, cache the pkscript in our in-memory cache to avoid db // Finally, cache the pkscript in our in-memory cache to avoid db
// lookups for the remainder of the daemon's execution. // lookups for the remainder of the daemon's execution.
c.sweepPkScripts[chanID] = pkScript c.summaries[chanID] = wtdb.ClientChanSummary{
SweepPkScript: pkScript,
}
return nil return nil
} }
@ -429,14 +431,14 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID,
breachInfo *lnwallet.BreachRetribution) error { breachInfo *lnwallet.BreachRetribution) error {
// Retrieve the cached sweep pkscript used for this channel. // Retrieve the cached sweep pkscript used for this channel.
c.sweepPkScriptMu.RLock() c.summaryMu.RLock()
sweepPkScript, ok := c.sweepPkScripts[*chanID] summary, ok := c.summaries[*chanID]
c.sweepPkScriptMu.RUnlock() c.summaryMu.RUnlock()
if !ok { if !ok {
return ErrUnregisteredChannel return ErrUnregisteredChannel
} }
task := newBackupTask(chanID, breachInfo, sweepPkScript) task := newBackupTask(chanID, breachInfo, summary.SweepPkScript)
return c.pipeline.QueueBackupTask(task) return c.pipeline.QueueBackupTask(task)
} }

@ -605,6 +605,8 @@ func (h *testHarness) backupStates(id, from, to uint64, expErr error) {
// backupStates instructs the channel identified by id to send a backup for // backupStates instructs the channel identified by id to send a backup for
// state i. // state i.
func (h *testHarness) backupState(id, i uint64, expErr error) { func (h *testHarness) backupState(id, i uint64, expErr error) {
h.t.Helper()
_, retribution := h.channel(id).getState(i) _, retribution := h.channel(id).getState(i)
chanID := chanIDFromInt(id) chanID := chanIDFromInt(id)

@ -41,14 +41,17 @@ type DB interface {
// still be able to accept state updates. // still be able to accept state updates.
ListClientSessions() (map[wtdb.SessionID]*wtdb.ClientSession, error) ListClientSessions() (map[wtdb.SessionID]*wtdb.ClientSession, error)
// FetchChanPkScripts returns a map of all sweep pkscripts for // FetchChanSummaries loads a mapping from all registered channels to
// registered channels. This is used on startup to cache the sweep // their channel summaries.
// pkscripts of registered channels in memory. FetchChanSummaries() (wtdb.ChannelSummaries, error)
FetchChanPkScripts() (map[lnwire.ChannelID][]byte, error)
// AddChanPkScript inserts a newly generated sweep pkscript for the // RegisterChannel registers a channel for use within the client
// given channel. // database. For now, all that is stored in the channel summary is the
AddChanPkScript(lnwire.ChannelID, []byte) error // sweep pkscript that we'd like any tower sweeps to pay into. In the
// future, this will be extended to contain more info to allow the
// client efficiently request historical states to be backed up under
// the client's active policy.
RegisterChannel(lnwire.ChannelID, []byte) error
// MarkBackupIneligible records that the state identified by the // MarkBackupIneligible records that the state identified by the
// (channel id, commit height) tuple was ineligible for being backed up // (channel id, commit height) tuple was ineligible for being backed up

@ -1,11 +1,18 @@
package wtdb package wtdb
import ( import (
"errors"
"io" "io"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
) )
var (
// ErrChannelAlreadyRegistered signals a duplicate attempt to
// register a channel with the client database.
ErrChannelAlreadyRegistered = errors.New("channel already registered")
)
// ChannelSummaries is a map for a given channel id to it's ClientChanSummary. // ChannelSummaries is a map for a given channel id to it's ClientChanSummary.
type ChannelSummaries map[lnwire.ChannelID]ClientChanSummary type ChannelSummaries map[lnwire.ChannelID]ClientChanSummary

@ -1,7 +1,6 @@
package wtmock package wtmock
import ( import (
"fmt"
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -18,7 +17,7 @@ type ClientDB struct {
nextTowerID uint64 // to be used atomically nextTowerID uint64 // to be used atomically
mu sync.Mutex mu sync.Mutex
sweepPkScripts map[lnwire.ChannelID][]byte summaries map[lnwire.ChannelID]wtdb.ClientChanSummary
activeSessions map[wtdb.SessionID]*wtdb.ClientSession activeSessions map[wtdb.SessionID]*wtdb.ClientSession
towerIndex map[towerPK]wtdb.TowerID towerIndex map[towerPK]wtdb.TowerID
towers map[wtdb.TowerID]*wtdb.Tower towers map[wtdb.TowerID]*wtdb.Tower
@ -30,7 +29,7 @@ type ClientDB struct {
// NewClientDB initializes a new mock ClientDB. // NewClientDB initializes a new mock ClientDB.
func NewClientDB() *ClientDB { func NewClientDB() *ClientDB {
return &ClientDB{ return &ClientDB{
sweepPkScripts: make(map[lnwire.ChannelID][]byte), summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary),
activeSessions: make(map[wtdb.SessionID]*wtdb.ClientSession), activeSessions: make(map[wtdb.SessionID]*wtdb.ClientSession),
towerIndex: make(map[towerPK]wtdb.TowerID), towerIndex: make(map[towerPK]wtdb.TowerID),
towers: make(map[wtdb.TowerID]*wtdb.Tower), towers: make(map[wtdb.TowerID]*wtdb.Tower),
@ -252,30 +251,40 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err
return wtdb.ErrCommittedUpdateNotFound return wtdb.ErrCommittedUpdateNotFound
} }
// FetchChanPkScripts returns the set of sweep pkscripts known for all channels. // FetchChanSummaries loads a mapping from all registered channels to their
// This allows the client to cache them in memory on startup. // channel summaries.
func (m *ClientDB) FetchChanPkScripts() (map[lnwire.ChannelID][]byte, error) { func (m *ClientDB) FetchChanSummaries() (wtdb.ChannelSummaries, error) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
sweepPkScripts := make(map[lnwire.ChannelID][]byte) summaries := make(map[lnwire.ChannelID]wtdb.ClientChanSummary)
for chanID, pkScript := range m.sweepPkScripts { for chanID, summary := range m.summaries {
sweepPkScripts[chanID] = cloneBytes(pkScript) summaries[chanID] = wtdb.ClientChanSummary{
SweepPkScript: cloneBytes(summary.SweepPkScript),
}
} }
return sweepPkScripts, nil return summaries, nil
} }
// AddChanPkScript sets a pkscript or sweeping funds from the channel or chanID. // RegisterChannel registers a channel for use within the client database. For
func (m *ClientDB) AddChanPkScript(chanID lnwire.ChannelID, pkScript []byte) error { // now, all that is stored in the channel summary is the sweep pkscript that
// we'd like any tower sweeps to pay into. In the future, this will be extended
// to contain more info to allow the client efficiently request historical
// states to be backed up under the client's active policy.
func (m *ClientDB) RegisterChannel(chanID lnwire.ChannelID,
sweepPkScript []byte) error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
if _, ok := m.sweepPkScripts[chanID]; ok { if _, ok := m.summaries[chanID]; ok {
return fmt.Errorf("pkscript for %x already exists", pkScript) return wtdb.ErrChannelAlreadyRegistered
} }
m.sweepPkScripts[chanID] = cloneBytes(pkScript) m.summaries[chanID] = wtdb.ClientChanSummary{
SweepPkScript: cloneBytes(sweepPkScript),
}
return nil return nil
} }