Merge pull request #2618 from cfromknecht/wtclient

watchtower/wtclient: reliable, asynchronous pipeline for revoked state backups
This commit is contained in:
Olaoluwa Osuntokun 2019-03-16 14:31:55 -07:00 committed by GitHub
commit ec62104acc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 4155 additions and 56 deletions

@ -34,9 +34,8 @@ import (
// necessary components are stripped out and encrypted before being sent to // necessary components are stripped out and encrypted before being sent to
// the tower in a StateUpdate. // the tower in a StateUpdate.
type backupTask struct { type backupTask struct {
chanID lnwire.ChannelID id wtdb.BackupID
commitHeight uint64 breachInfo *lnwallet.BreachRetribution
breachInfo *lnwallet.BreachRetribution
// state-dependent variables // state-dependent variables
@ -96,8 +95,10 @@ func newBackupTask(chanID *lnwire.ChannelID,
} }
return &backupTask{ return &backupTask{
chanID: *chanID, id: wtdb.BackupID{
commitHeight: breachInfo.RevokedStateNum, ChanID: *chanID,
CommitHeight: breachInfo.RevokedStateNum,
},
breachInfo: breachInfo, breachInfo: breachInfo,
toLocalInput: toLocalInput, toLocalInput: toLocalInput,
toRemoteInput: toRemoteInput, toRemoteInput: toRemoteInput,
@ -125,7 +126,7 @@ func (t *backupTask) inputs() map[wire.OutPoint]input.Input {
// SessionInfo's policy. If no error is returned, the task has been bound to the // SessionInfo's policy. If no error is returned, the task has been bound to the
// session and can be queued to upload to the tower. Otherwise, the bind failed // session and can be queued to upload to the tower. Otherwise, the bind failed
// and should be rescheduled with a different session. // and should be rescheduled with a different session.
func (t *backupTask) bindSession(session *wtdb.SessionInfo) error { func (t *backupTask) bindSession(session *wtdb.ClientSession) error {
// First we'll begin by deriving a weight estimate for the justice // First we'll begin by deriving a weight estimate for the justice
// transaction. The final weight can be different depending on whether // transaction. The final weight can be different depending on whether
@ -154,7 +155,7 @@ func (t *backupTask) bindSession(session *wtdb.SessionInfo) error {
// in the current session's policy. // in the current session's policy.
outputs, err := session.Policy.ComputeJusticeTxOuts( outputs, err := session.Policy.ComputeJusticeTxOuts(
t.totalAmt, int64(weightEstimate.Weight()), t.totalAmt, int64(weightEstimate.Weight()),
t.sweepPkScript, session.RewardAddress, t.sweepPkScript, session.RewardPkScript,
) )
if err != nil { if err != nil {
return err return err

@ -69,7 +69,7 @@ type backupTaskTest struct {
expSweepAmt int64 expSweepAmt int64
expRewardAmt int64 expRewardAmt int64
expRewardScript []byte expRewardScript []byte
session *wtdb.SessionInfo session *wtdb.ClientSession
bindErr error bindErr error
expSweepScript []byte expSweepScript []byte
signer input.Signer signer input.Signer
@ -205,13 +205,13 @@ func genTaskTest(
expSweepAmt: expSweepAmt, expSweepAmt: expSweepAmt,
expRewardAmt: expRewardAmt, expRewardAmt: expRewardAmt,
expRewardScript: rewardScript, expRewardScript: rewardScript,
session: &wtdb.SessionInfo{ session: &wtdb.ClientSession{
Policy: wtpolicy.Policy{ Policy: wtpolicy.Policy{
BlobType: blobType, BlobType: blobType,
SweepFeeRate: sweepFeeRate, SweepFeeRate: sweepFeeRate,
RewardRate: 10000, RewardRate: 10000,
}, },
RewardAddress: rewardScript, RewardPkScript: rewardScript,
}, },
bindErr: bindErr, bindErr: bindErr,
expSweepScript: makeAddrSlice(22), expSweepScript: makeAddrSlice(22),
@ -379,7 +379,7 @@ var backupTaskTests = []backupTaskTest{
} }
// TestBackupTaskBind tests the initialization and binding of a backupTask to a // TestBackupTaskBind tests the initialization and binding of a backupTask to a
// SessionInfo. After a succesfful bind, all parameters of the justice // ClientSession. After a successful bind, all parameters of the justice
// transaction should be solidified, so we assert there correctness. In an // transaction should be solidified, so we assert there correctness. In an
// unsuccessful bind, the session-dependent parameters should be unmodified so // unsuccessful bind, the session-dependent parameters should be unmodified so
// that the backup task can be rescheduled if necessary. Finally, we assert that // that the backup task can be rescheduled if necessary. Finally, we assert that
@ -401,14 +401,14 @@ func testBackupTask(t *testing.T, test backupTaskTest) {
// Assert that all parameters set during initialization are properly // Assert that all parameters set during initialization are properly
// populated. // populated.
if task.chanID != test.chanID { if task.id.ChanID != test.chanID {
t.Fatalf("channel id mismatch, want: %s, got: %s", t.Fatalf("channel id mismatch, want: %s, got: %s",
test.chanID, task.chanID) test.chanID, task.id.ChanID)
} }
if task.commitHeight != test.breachInfo.RevokedStateNum { if task.id.CommitHeight != test.breachInfo.RevokedStateNum {
t.Fatalf("commit height mismatch, want: %d, got: %d", t.Fatalf("commit height mismatch, want: %d, got: %d",
test.breachInfo.RevokedStateNum, task.commitHeight) test.breachInfo.RevokedStateNum, task.id.CommitHeight)
} }
if task.totalAmt != test.expTotalAmt { if task.totalAmt != test.expTotalAmt {

@ -0,0 +1,82 @@
package wtclient
import (
"container/list"
"sync"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
)
// TowerCandidateIterator provides an abstraction for iterating through possible
// watchtower addresses when attempting to create a new session.
type TowerCandidateIterator interface {
// Reset clears any internal iterator state, making previously taken
// candidates available as long as they remain in the set.
Reset() error
// Next returns the next candidate tower. The iterator is not required
// to return results in any particular order. If no more candidates are
// available, ErrTowerCandidatesExhausted is returned.
Next() (*wtdb.Tower, error)
}
// towerListIterator is a linked-list backed TowerCandidateIterator.
type towerListIterator struct {
mu sync.Mutex
candidates *list.List
nextCandidate *list.Element
}
// Compile-time constraint to ensure *towerListIterator implements the
// TowerCandidateIterator interface.
var _ TowerCandidateIterator = (*towerListIterator)(nil)
// newTowerListIterator initializes a new towerListIterator from a variadic list
// of lnwire.NetAddresses.
func newTowerListIterator(candidates ...*wtdb.Tower) *towerListIterator {
iter := &towerListIterator{
candidates: list.New(),
}
for _, candidate := range candidates {
iter.candidates.PushBack(candidate)
}
iter.Reset()
return iter
}
// Reset clears the iterators state, and makes the address at the front of the
// list the next item to be returned..
func (t *towerListIterator) Reset() error {
t.mu.Lock()
defer t.mu.Unlock()
// Reset the next candidate to the front of the linked-list.
t.nextCandidate = t.candidates.Front()
return nil
}
// Next returns the next candidate tower. This iterator will always return
// candidates in the order given when the iterator was instantiated. If no more
// candidates are available, ErrTowerCandidatesExhausted is returned.
func (t *towerListIterator) Next() (*wtdb.Tower, error) {
t.mu.Lock()
defer t.mu.Unlock()
// If the next candidate is nil, we've exhausted the list.
if t.nextCandidate == nil {
return nil, ErrTowerCandidatesExhausted
}
// Propose the tower at the front of the list.
tower := t.nextCandidate.Value.(*wtdb.Tower)
// Set the next candidate to the subsequent element.
t.nextCandidate = t.nextCandidate.Next()
return tower, nil
}
// TODO(conner): implement graph-backed candidate iterator for public towers.

@ -0,0 +1,804 @@
package wtclient
import (
"bytes"
"fmt"
"sync"
"time"
"github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
"github.com/lightningnetwork/lnd/watchtower/wtserver"
"github.com/lightningnetwork/lnd/watchtower/wtwire"
)
const (
// DefaultReadTimeout specifies the default duration we will wait during
// a read before breaking out of a blocking read.
DefaultReadTimeout = 15 * time.Second
// DefaultWriteTimeout specifies the default duration we will wait during
// a write before breaking out of a blocking write.
DefaultWriteTimeout = 15 * time.Second
// DefaultStatInterval specifies the default interval between logging
// metrics about the client's operation.
DefaultStatInterval = 30 * time.Second
)
// Client is the primary interface used by the daemon to control a client's
// lifecycle and backup revoked states.
type Client interface {
// RegisterChannel persistently initializes any channel-dependent
// parameters within the client. This should be called during link
// startup to ensure that the client is able to support the link during
// operation.
RegisterChannel(lnwire.ChannelID) error
// BackupState initiates a request to back up a particular revoked
// state. If the method returns nil, the backup is guaranteed to be
// successful unless the client is force quit, or the justice
// transaction would create dust outputs when trying to abide by the
// negotiated policy.
BackupState(*lnwire.ChannelID, *lnwallet.BreachRetribution) error
// Start initializes the watchtower client, allowing it process requests
// to backup revoked channel states.
Start() error
// Stop attempts a graceful shutdown of the watchtower client. In doing
// so, it will attempt to flush the pipeline and deliver any queued
// states to the tower before exiting.
Stop() error
// ForceQuit will forcibly shutdown the watchtower client. Calling this
// may lead to queued states being dropped.
ForceQuit()
}
// Config provides the TowerClient with access to the resources it requires to
// perform its duty. All nillable fields must be non-nil for the tower to be
// initialized properly.
type Config struct {
// Signer provides access to the wallet so that the client can sign
// justice transactions that spend from a remote party's commitment
// transaction.
Signer input.Signer
// NewAddress generates a new on-chain sweep pkscript.
NewAddress func() ([]byte, error)
// SecretKeyRing is used to derive the session keys used to communicate
// with the tower. The client only stores the KeyLocators internally so
// that we never store private keys on disk.
SecretKeyRing keychain.SecretKeyRing
// Dial connects to an addr using the specified net and returns the
// connection object.
Dial Dial
// AuthDialer establishes a brontide connection over an onion or clear
// network.
AuthDial AuthDialer
// DB provides access to the client's stable storage medium.
DB DB
// Policy is the session policy the client will propose when creating
// new sessions with the tower. If the policy differs from any active
// sessions recorded in the database, those sessions will be ignored and
// new sessions will be requested immediately.
Policy wtpolicy.Policy
// PrivateTower is the net address of a private tower. The client will
// try to create all sessions with this tower.
PrivateTower *lnwire.NetAddress
// ChainHash identifies the chain that the client is on and for which
// the tower must be watching to monitor for breaches.
ChainHash chainhash.Hash
// ForceQuitDelay is the duration after attempting to shutdown that the
// client will automatically abort any pending backups if an unclean
// shutdown is detected. If the value is less than or equal to zero, a
// call to Stop may block indefinitely. The client can always be
// ForceQuit externally irrespective of the chosen parameter.
ForceQuitDelay time.Duration
// ReadTimeout is the duration we will wait during a read before
// breaking out of a blocking read. If the value is less than or equal
// to zero, the default will be used instead.
ReadTimeout time.Duration
// WriteTimeout is the duration we will wait during a write before
// breaking out of a blocking write. If the value is less than or equal
// to zero, the default will be used instead.
WriteTimeout time.Duration
// MinBackoff defines the initial backoff applied to connections with
// watchtowers. Subsequent backoff durations will grow exponentially up
// until MaxBackoff.
MinBackoff time.Duration
// MaxBackoff defines the maximum backoff applied to conenctions with
// watchtowers. If the exponential backoff produces a timeout greater
// than this value, the backoff will be clamped to MaxBackoff.
MaxBackoff time.Duration
}
// TowerClient is a concrete implementation of the Client interface, offering a
// non-blocking, reliable subsystem for backing up revoked states to a specified
// private tower.
type TowerClient struct {
started sync.Once
stopped sync.Once
forced sync.Once
cfg *Config
pipeline *taskPipeline
negotiator SessionNegotiator
candidateSessions map[wtdb.SessionID]*wtdb.ClientSession
activeSessions sessionQueueSet
sessionQueue *sessionQueue
prevTask *backupTask
sweepPkScriptMu sync.RWMutex
sweepPkScripts map[lnwire.ChannelID][]byte
statTicker *time.Ticker
stats clientStats
wg sync.WaitGroup
forceQuit chan struct{}
}
// Compile-time constraint to ensure *TowerClient implements the Client
// interface.
var _ Client = (*TowerClient)(nil)
// New initializes a new TowerClient from the provide Config. An error is
// returned if the client could not initialized.
func New(config *Config) (*TowerClient, error) {
// Copy the config to prevent side-effects from modifying both the
// internal and external version of the Config.
cfg := new(Config)
*cfg = *config
// Set the read timeout to the default if none was provided.
if cfg.ReadTimeout <= 0 {
cfg.ReadTimeout = DefaultReadTimeout
}
// Set the write timeout to the default if none was provided.
if cfg.WriteTimeout <= 0 {
cfg.WriteTimeout = DefaultWriteTimeout
}
// Record the tower in our database, also loading any addresses
// previously associated with its public key.
tower, err := cfg.DB.CreateTower(cfg.PrivateTower)
if err != nil {
return nil, err
}
log.Infof("Using private watchtower %s, offering policy %s",
cfg.PrivateTower, cfg.Policy)
c := &TowerClient{
cfg: cfg,
pipeline: newTaskPipeline(),
activeSessions: make(sessionQueueSet),
statTicker: time.NewTicker(DefaultStatInterval),
forceQuit: make(chan struct{}),
}
c.negotiator = newSessionNegotiator(&NegotiatorConfig{
DB: cfg.DB,
Policy: cfg.Policy,
ChainHash: cfg.ChainHash,
SendMessage: c.sendMessage,
ReadMessage: c.readMessage,
Dial: c.dial,
Candidates: newTowerListIterator(tower),
MinBackoff: cfg.MinBackoff,
MaxBackoff: cfg.MaxBackoff,
})
// Next, load all active sessions from the db into the client. We will
// use any of these session if their policies match the current policy
// of the client, otherwise they will be ignored and new sessions will
// be requested.
c.candidateSessions, err = c.cfg.DB.ListClientSessions()
if err != nil {
return nil, err
}
// Finally, load the sweep pkscripts that have been generated for all
// previously registered channels.
c.sweepPkScripts, err = c.cfg.DB.FetchChanPkScripts()
if err != nil {
return nil, err
}
return c, nil
}
// Start initializes the watchtower client by loading or negotiating an active
// session and then begins processing backup tasks from the request pipeline.
func (c *TowerClient) Start() error {
var err error
c.started.Do(func() {
log.Infof("Starting watchtower client")
// First, restart a session queue for any sessions that have
// committed but unacked state updates. This ensures that these
// sessions will be able to flush the committed updates after a
// restart.
for _, session := range c.candidateSessions {
if len(session.CommittedUpdates) > 0 {
log.Infof("Starting session=%s to process "+
"%d committed backups", session.ID,
len(session.CommittedUpdates))
c.initActiveQueue(session)
}
}
// Now start the session negotiator, which will allow us to
// request new session as soon as the backupDispatcher starts
// up.
err = c.negotiator.Start()
if err != nil {
return
}
// Start the task pipeline to which new backup tasks will be
// submitted from active links.
c.pipeline.Start()
c.wg.Add(1)
go c.backupDispatcher()
log.Infof("Watchtower client started successfully")
})
return err
}
// Stop idempotently initiates a graceful shutdown of the watchtower client.
func (c *TowerClient) Stop() error {
c.stopped.Do(func() {
log.Debugf("Stopping watchtower client")
// 1. Shutdown the backup queue, which will prevent any further
// updates from being accepted. In practice, the links should be
// shutdown before the client has been stopped, so all updates
// would have been added prior.
c.pipeline.Stop()
// 2. To ensure we don't hang forever on shutdown due to
// unintended failures, we'll delay a call to force quit the
// pipeline if a ForceQuitDelay is specified. This will have no
// effect if the pipeline shuts down cleanly before the delay
// fires.
//
// For full safety, this can be set to 0 and wait out
// indefinitely. However for mobile clients which may have a
// limited amount of time to exit before the background process
// is killed, this offers a way to ensure the process
// terminates.
if c.cfg.ForceQuitDelay > 0 {
time.AfterFunc(c.cfg.ForceQuitDelay, c.ForceQuit)
}
// 3. Once the backup queue has shutdown, wait for the main
// dispatcher to exit. The backup queue will signal it's
// completion to the dispatcher, which releases the wait group
// after all tasks have been assigned to session queues.
c.wg.Wait()
// 4. Since all valid tasks have been assigned to session
// queues, we no longer need to negotiate sessions.
c.negotiator.Stop()
log.Debugf("Waiting for active session queues to finish "+
"draining, stats: %s", c.stats)
// 5. Shutdown all active session queues in parallel. These will
// exit once all updates have been acked by the watchtower.
c.activeSessions.ApplyAndWait(func(s *sessionQueue) func() {
return s.Stop
})
// Skip log if force quitting.
select {
case <-c.forceQuit:
return
default:
}
log.Debugf("Client successfully stopped, stats: %s", c.stats)
})
return nil
}
// ForceQuit idempotently initiates an unclean shutdown of the watchtower
// client. This should only be executed if Stop is unable to exit cleanly.
func (c *TowerClient) ForceQuit() {
c.forced.Do(func() {
log.Infof("Force quitting watchtower client")
// Cancel log message from stop.
close(c.forceQuit)
// 1. Shutdown the backup queue, which will prevent any further
// updates from being accepted. In practice, the links should be
// shutdown before the client has been stopped, so all updates
// would have been added prior.
c.pipeline.ForceQuit()
// 2. Once the backup queue has shutdown, wait for the main
// dispatcher to exit. The backup queue will signal it's
// completion to the dispatcher, which releases the wait group
// after all tasks have been assigned to session queues.
c.wg.Wait()
// 3. Since all valid tasks have been assigned to session
// queues, we no longer need to negotiate sessions.
c.negotiator.Stop()
// 4. Force quit all active session queues in parallel. These
// will exit once all updates have been acked by the watchtower.
c.activeSessions.ApplyAndWait(func(s *sessionQueue) func() {
return s.ForceQuit
})
log.Infof("Watchtower client unclean shutdown complete, "+
"stats: %s", c.stats)
})
}
// RegisterChannel persistently initializes any channel-dependent parameters
// within the client. This should be called during link startup to ensure that
// the client is able to support the link during operation.
func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error {
c.sweepPkScriptMu.Lock()
defer c.sweepPkScriptMu.Unlock()
// If a pkscript for this channel already exists, the channel has been
// previously registered.
if _, ok := c.sweepPkScripts[chanID]; ok {
return nil
}
// Otherwise, generate a new sweep pkscript used to sweep funds for this
// channel.
pkScript, err := c.cfg.NewAddress()
if err != nil {
return err
}
// Persist the sweep pkscript so that restarts will not introduce
// address inflation when the channel is reregistered after a restart.
err = c.cfg.DB.AddChanPkScript(chanID, pkScript)
if err != nil {
return err
}
// Finally, cache the pkscript in our in-memory cache to avoid db
// lookups for the remainder of the daemon's execution.
c.sweepPkScripts[chanID] = pkScript
return nil
}
// BackupState initiates a request to back up a particular revoked state. If the
// method returns nil, the backup is guaranteed to be successful unless the:
// - client is force quit,
// - justice transaction would create dust outputs when trying to abide by the
// negotiated policy, or
// - breached outputs contain too little value to sweep at the target sweep fee
// rate.
func (c *TowerClient) BackupState(chanID *lnwire.ChannelID,
breachInfo *lnwallet.BreachRetribution) error {
// Retrieve the cached sweep pkscript used for this channel.
c.sweepPkScriptMu.RLock()
sweepPkScript, ok := c.sweepPkScripts[*chanID]
c.sweepPkScriptMu.RUnlock()
if !ok {
return ErrUnregisteredChannel
}
task := newBackupTask(chanID, breachInfo, sweepPkScript)
return c.pipeline.QueueBackupTask(task)
}
// nextSessionQueue attempts to fetch an active session from our set of
// candidate sessions. Candidate sessions with a differing policy from the
// active client's advertised policy will be ignored, but may be resumed if the
// client is restarted with a matching policy. If no candidates were found, nil
// is returned to signal that we need to request a new policy.
func (c *TowerClient) nextSessionQueue() *sessionQueue {
// Select any candidate session at random, and remove it from the set of
// candidate sessions.
var candidateSession *wtdb.ClientSession
for id, sessionInfo := range c.candidateSessions {
delete(c.candidateSessions, id)
// Skip any sessions with policies that don't match the current
// configuration. These can be used again if the client changes
// their configuration back.
if sessionInfo.Policy != c.cfg.Policy {
continue
}
candidateSession = sessionInfo
break
}
// If none of the sessions could be used or none were found, we'll
// return nil to signal that we need another session to be negotiated.
if candidateSession == nil {
return nil
}
// Initialize the session queue and spin it up so it can begin handling
// updates. If the queue was already made active on startup, this will
// simply return the existing session queue from the set.
return c.getOrInitActiveQueue(candidateSession)
}
// backupDispatcher processes events coming from the taskPipeline and is
// responsible for detecting when the client needs to renegotiate a session to
// fulfill continuing demand. The event loop exits after all tasks have been
// received from the upstream taskPipeline, or the taskPipeline is force quit.
//
// NOTE: This method MUST be run as a goroutine.
func (c *TowerClient) backupDispatcher() {
defer c.wg.Done()
log.Tracef("Starting backup dispatcher")
defer log.Tracef("Stopping backup dispatcher")
for {
switch {
// No active session queue and no additional sessions.
case c.sessionQueue == nil && len(c.candidateSessions) == 0:
log.Infof("Requesting new session.")
// Immediately request a new session.
c.negotiator.RequestSession()
// Wait until we receive the newly negotiated session.
// All backups sent in the meantime are queued in the
// revoke queue, as we cannot process them.
select {
case session := <-c.negotiator.NewSessions():
log.Infof("Acquired new session with id=%s",
session.ID)
c.candidateSessions[session.ID] = session
c.stats.sessionAcquired()
case <-c.statTicker.C:
log.Infof("Client stats: %s", c.stats)
}
// No active session queue but have additional sessions.
case c.sessionQueue == nil && len(c.candidateSessions) > 0:
// We've exhausted the prior session, we'll pop another
// from the remaining sessions and continue processing
// backup tasks.
c.sessionQueue = c.nextSessionQueue()
if c.sessionQueue != nil {
log.Debugf("Loaded next candidate session "+
"queue id=%s", c.sessionQueue.ID())
}
// Have active session queue, process backups.
case c.sessionQueue != nil:
if c.prevTask != nil {
c.processTask(c.prevTask)
// Continue to ensure the sessionQueue is
// properly initialized before attempting to
// process more tasks from the pipeline.
continue
}
// Normal operation where new tasks are read from the
// pipeline.
select {
// If any sessions are negotiated while we have an
// active session queue, queue them for future use.
// This shouldn't happen with the current design, so
// it doesn't hurt to select here just in case. In the
// future, we will likely allow more asynchrony so that
// we can request new sessions before the session is
// fully empty, which this case would handle.
case session := <-c.negotiator.NewSessions():
log.Warnf("Acquired new session with id=%s",
"while processing tasks", session.ID)
c.candidateSessions[session.ID] = session
c.stats.sessionAcquired()
case <-c.statTicker.C:
log.Infof("Client stats: %s", c.stats)
// Process each backup task serially from the queue of
// revoked states.
case task, ok := <-c.pipeline.NewBackupTasks():
// All backups in the pipeline have been
// processed, it is now safe to exit.
if !ok {
return
}
log.Debugf("Processing backup task chanid=%s "+
"commit-height=%d", task.id.ChanID,
task.id.CommitHeight)
c.stats.taskReceived()
c.processTask(task)
}
}
}
}
// processTask attempts to schedule the given backupTask on the active
// sessionQueue. The task will either be accepted or rejected, afterwhich the
// appropriate modifications to the client's state machine will be made. After
// every invocation of processTask, the caller should ensure that the
// sessionQueue hasn't been exhausted before proceeding to the next task. Tasks
// that are rejected because the active sessionQueue is full will be cached as
// the prevTask, and should be reprocessed after obtaining a new sessionQueue.
func (c *TowerClient) processTask(task *backupTask) {
status, accepted := c.sessionQueue.AcceptTask(task)
if accepted {
c.taskAccepted(task, status)
} else {
c.taskRejected(task, status)
}
}
// taskAccepted processes the acceptance of a task by a sessionQueue depending
// on the state the sessionQueue is in *after* the task is added. The client's
// prevTask is always removed as a result of this call. The client's
// sessionQueue will be removed if accepting the task left the sessionQueue in
// an exhausted state.
func (c *TowerClient) taskAccepted(task *backupTask, newStatus reserveStatus) {
log.Infof("Backup chanid=%s commit-height=%d accepted successfully",
task.id.ChanID, task.id.CommitHeight)
c.stats.taskAccepted()
// If this task was accepted, we discard anything held in the prevTask.
// Either it was nil before, or is the task which was just accepted.
c.prevTask = nil
switch newStatus {
// The sessionQueue still has capacity after accepting this task.
case reserveAvailable:
// The sessionQueue is full after accepting this task, so we will need
// to request a new one before proceeding.
case reserveExhausted:
c.stats.sessionExhausted()
log.Debugf("Session %s exhausted", c.sessionQueue.ID())
// This task left the session exhausted, set it to nil and
// proceed to the next loop so we can consume another
// pre-negotiated session or request another.
c.sessionQueue = nil
}
}
// taskRejected process the rejection of a task by a sessionQueue depending on
// the state the was in *before* the task was rejected. The client's prevTask
// will cache the task if the sessionQueue was exhausted before hand, and nil
// the sessionQueue to find a new session. If the sessionQueue was not
// exhausted, the client marks the task as ineligible, as this implies we
// couldn't construct a valid justice transaction given the session's policy.
func (c *TowerClient) taskRejected(task *backupTask, curStatus reserveStatus) {
switch curStatus {
// The sessionQueue has available capacity but the task was rejected,
// this indicates that the task was ineligible for backup.
case reserveAvailable:
c.stats.taskIneligible()
log.Infof("Backup chanid=%s commit-height=%d is ineligible",
task.id.ChanID, task.id.CommitHeight)
err := c.cfg.DB.MarkBackupIneligible(
task.id.ChanID, task.id.CommitHeight,
)
if err != nil {
log.Errorf("Unable to mark task chanid=%s "+
"commit-height=%d ineligible: %v",
task.id.ChanID, task.id.CommitHeight, err)
// It is safe to not handle this error, even if we could
// not persist the result. At worst, this task may be
// reprocessed on a subsequent start up, and will either
// succeed do a change in session parameters or fail in
// the same manner.
}
// If this task was rejected *and* the session had available
// capacity, we discard anything held in the prevTask. Either it
// was nil before, or is the task which was just rejected.
c.prevTask = nil
// The sessionQueue rejected the task because it is full, we will stash
// this task and try to add it to the next available sessionQueue.
case reserveExhausted:
c.stats.sessionExhausted()
log.Debugf("Session %s exhausted, backup chanid=%s "+
"commit-height=%d queued for next session",
c.sessionQueue.ID(), task.id.ChanID,
task.id.CommitHeight)
// Cache the task that we pulled off, so that we can process it
// once a new session queue is available.
c.sessionQueue = nil
c.prevTask = task
}
}
// dial connects the peer at addr using privKey as our secret key for the
// connection. The connection will use the configured Net's resolver to resolve
// the address for either Tor or clear net connections.
func (c *TowerClient) dial(privKey *btcec.PrivateKey,
addr *lnwire.NetAddress) (wtserver.Peer, error) {
return c.cfg.AuthDial(privKey, addr, c.cfg.Dial)
}
// readMessage receives and parses the next message from the given Peer. An
// error is returned if a message is not received before the server's read
// timeout, the read off the wire failed, or the message could not be
// deserialized.
func (c *TowerClient) readMessage(peer wtserver.Peer) (wtwire.Message, error) {
// Set a read timeout to ensure we drop the connection if nothing is
// received in a timely manner.
err := peer.SetReadDeadline(time.Now().Add(c.cfg.ReadTimeout))
if err != nil {
err = fmt.Errorf("unable to set read deadline: %v", err)
log.Errorf("Unable to read msg: %v", err)
return nil, err
}
// Pull the next message off the wire,
rawMsg, err := peer.ReadNextMessage()
if err != nil {
err = fmt.Errorf("unable to read message: %v", err)
log.Errorf("Unable to read msg: %v", err)
return nil, err
}
// Parse the received message according to the watchtower wire
// specification.
msgReader := bytes.NewReader(rawMsg)
msg, err := wtwire.ReadMessage(msgReader, 0)
if err != nil {
err = fmt.Errorf("unable to parse message: %v", err)
log.Errorf("Unable to read msg: %v", err)
return nil, err
}
logMessage(peer, msg, true)
return msg, nil
}
// sendMessage sends a watchtower wire message to the target peer.
func (c *TowerClient) sendMessage(peer wtserver.Peer, msg wtwire.Message) error {
// Encode the next wire message into the buffer.
// TODO(conner): use buffer pool
var b bytes.Buffer
_, err := wtwire.WriteMessage(&b, msg, 0)
if err != nil {
err = fmt.Errorf("Unable to encode msg: %v", err)
log.Errorf("Unable to send msg: %v", err)
return err
}
// Set the write deadline for the connection, ensuring we drop the
// connection if nothing is sent in a timely manner.
err = peer.SetWriteDeadline(time.Now().Add(c.cfg.WriteTimeout))
if err != nil {
err = fmt.Errorf("unable to set write deadline: %v", err)
log.Errorf("Unable to send msg: %v", err)
return err
}
logMessage(peer, msg, false)
// Write out the full message to the remote peer.
_, err = peer.Write(b.Bytes())
if err != nil {
log.Errorf("Unable to send msg: %v", err)
}
return err
}
// newSessionQueue creates a sessionQueue from a ClientSession loaded from the
// database and supplying it with the resources needed by the client.
func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession) *sessionQueue {
return newSessionQueue(&sessionQueueConfig{
ClientSession: s,
ChainHash: c.cfg.ChainHash,
Dial: c.dial,
ReadMessage: c.readMessage,
SendMessage: c.sendMessage,
Signer: c.cfg.Signer,
DB: c.cfg.DB,
MinBackoff: c.cfg.MinBackoff,
MaxBackoff: c.cfg.MaxBackoff,
})
}
// getOrInitActiveQueue checks the activeSessions set for a sessionQueue for the
// passed ClientSession. If it exists, the active sessionQueue is returned.
// Otherwise a new sessionQueue is initialized and added to the set.
func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession) *sessionQueue {
if sq, ok := c.activeSessions[s.ID]; ok {
return sq
}
return c.initActiveQueue(s)
}
// initActiveQueue creates a new sessionQueue from the passed ClientSession,
// adds the sessionQueue to the activeSessions set, and starts the sessionQueue
// so that it can deliver any committed updates or begin accepting newly
// assigned tasks.
func (c *TowerClient) initActiveQueue(s *wtdb.ClientSession) *sessionQueue {
// Initialize the session queue, providing it with all of the resources
// it requires from the client instance.
sq := c.newSessionQueue(s)
// Add the session queue as an active session so that we remember to
// stop it on shutdown.
c.activeSessions.Add(sq)
// Start the queue so that it can be active in processing newly assigned
// tasks or to upload previously committed updates.
sq.Start()
return sq
}
// logMessage writes information about a message received from a remote peer,
// using directional prepositions to signal whether the message was sent or
// received.
func logMessage(peer wtserver.Peer, msg wtwire.Message, read bool) {
var action = "Received"
var preposition = "from"
if !read {
action = "Sending"
preposition = "to"
}
summary := wtwire.MessageSummary(msg)
if len(summary) > 0 {
summary = "(" + summary + ")"
}
log.Debugf("%s %s%v %s %x@%s", action, msg.MsgType(), summary,
preposition, peer.RemotePub().SerializeCompressed(),
peer.RemoteAddr())
}

@ -0,0 +1,1118 @@
// +build dev
package wtclient_test
import (
"encoding/binary"
"net"
"sync"
"testing"
"time"
"github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtclient"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtmock"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
"github.com/lightningnetwork/lnd/watchtower/wtserver"
)
const csvDelay uint32 = 144
var (
revPrivBytes = []byte{
0x8f, 0x4b, 0x51, 0x83, 0xa9, 0x34, 0xbd, 0x5f,
0x74, 0x6c, 0x9d, 0x5c, 0xae, 0x88, 0x2d, 0x31,
0x06, 0x90, 0xdd, 0x8c, 0x9b, 0x31, 0xbc, 0xd1,
0x78, 0x91, 0x88, 0x2a, 0xf9, 0x74, 0xa0, 0xef,
}
toLocalPrivBytes = []byte{
0xde, 0x17, 0xc1, 0x2f, 0xdc, 0x1b, 0xc0, 0xc6,
0x59, 0x5d, 0xf9, 0xc1, 0x3e, 0x89, 0xbc, 0x6f,
0x01, 0x85, 0x45, 0x76, 0x26, 0xce, 0x9c, 0x55,
0x3b, 0xc9, 0xec, 0x3d, 0xd8, 0x8b, 0xac, 0xa8,
}
toRemotePrivBytes = []byte{
0x28, 0x59, 0x6f, 0x36, 0xb8, 0x9f, 0x19, 0x5d,
0xcb, 0x07, 0x48, 0x8a, 0xe5, 0x89, 0x71, 0x74,
0x70, 0x4c, 0xff, 0x1e, 0x9c, 0x00, 0x93, 0xbe,
0xe2, 0x2e, 0x68, 0x08, 0x4c, 0xb4, 0x0f, 0x4f,
}
// addr is the server's reward address given to watchtower clients.
addr, _ = btcutil.DecodeAddress(
"mrX9vMRYLfVy1BnZbc5gZjuyaqH3ZW2ZHz", &chaincfg.TestNet3Params,
)
addrScript, _ = txscript.PayToAddrScript(addr)
)
// randPrivKey generates a new secp keypair, and returns the public key.
func randPrivKey(t *testing.T) *btcec.PrivateKey {
t.Helper()
sk, err := btcec.NewPrivateKey(btcec.S256())
if err != nil {
t.Fatalf("unable to generate pubkey: %v", err)
}
return sk
}
type mockNet struct {
mu sync.RWMutex
connCallback func(wtserver.Peer)
}
func newMockNet(cb func(wtserver.Peer)) *mockNet {
return &mockNet{
connCallback: cb,
}
}
func (m *mockNet) Dial(network string, address string) (net.Conn, error) {
return nil, nil
}
func (m *mockNet) LookupHost(host string) ([]string, error) {
panic("not implemented")
}
func (m *mockNet) LookupSRV(service string, proto string, name string) (string, []*net.SRV, error) {
panic("not implemented")
}
func (m *mockNet) ResolveTCPAddr(network string, address string) (*net.TCPAddr, error) {
panic("not implemented")
}
func (m *mockNet) AuthDial(localPriv *btcec.PrivateKey, netAddr *lnwire.NetAddress,
dialer func(string, string) (net.Conn, error)) (wtserver.Peer, error) {
localPk := localPriv.PubKey()
localAddr := &net.TCPAddr{
IP: net.IP{0x32, 0x31, 0x30, 0x29},
Port: 36723,
}
localPeer, remotePeer := wtmock.NewMockConn(
localPk, netAddr.IdentityKey, localAddr, netAddr.Address, 0,
)
m.mu.RLock()
m.connCallback(remotePeer)
m.mu.RUnlock()
return localPeer, nil
}
func (m *mockNet) setConnCallback(cb func(wtserver.Peer)) {
m.mu.Lock()
defer m.mu.Unlock()
m.connCallback = cb
}
type mockChannel struct {
mu sync.Mutex
commitHeight uint64
retributions map[uint64]*lnwallet.BreachRetribution
localBalance lnwire.MilliSatoshi
remoteBalance lnwire.MilliSatoshi
revSK *btcec.PrivateKey
revPK *btcec.PublicKey
revKeyLoc keychain.KeyLocator
toRemoteSK *btcec.PrivateKey
toRemotePK *btcec.PublicKey
toRemoteKeyLoc keychain.KeyLocator
toLocalPK *btcec.PublicKey // only need to generate to-local script
dustLimit lnwire.MilliSatoshi
csvDelay uint32
}
func newMockChannel(t *testing.T, signer *wtmock.MockSigner,
localAmt, remoteAmt lnwire.MilliSatoshi) *mockChannel {
// Generate the revocation, to-local, and to-remote keypairs.
revSK := randPrivKey(t)
revPK := revSK.PubKey()
toLocalSK := randPrivKey(t)
toLocalPK := toLocalSK.PubKey()
toRemoteSK := randPrivKey(t)
toRemotePK := toRemoteSK.PubKey()
// Register the revocation secret key and the to-remote secret key with
// the signer. We will not need to sign with the to-local key, as this
// is to be known only by the counterparty.
revKeyLoc := signer.AddPrivKey(revSK)
toRemoteKeyLoc := signer.AddPrivKey(toRemoteSK)
c := &mockChannel{
retributions: make(map[uint64]*lnwallet.BreachRetribution),
localBalance: localAmt,
remoteBalance: remoteAmt,
revSK: revSK,
revPK: revPK,
revKeyLoc: revKeyLoc,
toLocalPK: toLocalPK,
toRemoteSK: toRemoteSK,
toRemotePK: toRemotePK,
toRemoteKeyLoc: toRemoteKeyLoc,
dustLimit: 546000,
csvDelay: 144,
}
// Create the initial remote commitment with the initial balances.
c.createRemoteCommitTx(t)
return c
}
func (c *mockChannel) createRemoteCommitTx(t *testing.T) {
t.Helper()
// Construct the to-local witness script.
toLocalScript, err := input.CommitScriptToSelf(
c.csvDelay, c.toLocalPK, c.revPK,
)
if err != nil {
t.Fatalf("unable to create to-local script: %v", err)
}
// Compute the to-local witness script hash.
toLocalScriptHash, err := input.WitnessScriptHash(toLocalScript)
if err != nil {
t.Fatalf("unable to create to-local witness script hash: %v", err)
}
// Compute the to-remote witness script hash.
toRemoteScriptHash, err := input.CommitScriptUnencumbered(c.toRemotePK)
if err != nil {
t.Fatalf("unable to create to-remote script: %v", err)
}
// Construct the remote commitment txn, containing the to-local and
// to-remote outputs. The balances are flipped since the transaction is
// from the PoV of the remote party. We don't need any inputs for this
// test. We increment the version with the commit height to ensure that
// all commitment transactions are unique even if the same distribution
// of funds is used more than once.
commitTxn := &wire.MsgTx{
Version: int32(c.commitHeight + 1),
}
var (
toLocalSignDesc *input.SignDescriptor
toRemoteSignDesc *input.SignDescriptor
)
var outputIndex int
if c.remoteBalance >= c.dustLimit {
commitTxn.TxOut = append(commitTxn.TxOut, &wire.TxOut{
Value: int64(c.remoteBalance.ToSatoshis()),
PkScript: toLocalScriptHash,
})
// Create the sign descriptor used to sign for the to-local
// input.
toLocalSignDesc = &input.SignDescriptor{
KeyDesc: keychain.KeyDescriptor{
KeyLocator: c.revKeyLoc,
PubKey: c.revPK,
},
WitnessScript: toLocalScript,
Output: commitTxn.TxOut[outputIndex],
HashType: txscript.SigHashAll,
}
outputIndex++
}
if c.localBalance >= c.dustLimit {
commitTxn.TxOut = append(commitTxn.TxOut, &wire.TxOut{
Value: int64(c.localBalance.ToSatoshis()),
PkScript: toRemoteScriptHash,
})
// Create the sign descriptor used to sign for the to-remote
// input.
toRemoteSignDesc = &input.SignDescriptor{
KeyDesc: keychain.KeyDescriptor{
KeyLocator: c.toRemoteKeyLoc,
PubKey: c.toRemotePK,
},
WitnessScript: toRemoteScriptHash,
Output: commitTxn.TxOut[outputIndex],
HashType: txscript.SigHashAll,
}
outputIndex++
}
txid := commitTxn.TxHash()
var (
toLocalOutPoint wire.OutPoint
toRemoteOutPoint wire.OutPoint
)
outputIndex = 0
if toLocalSignDesc != nil {
toLocalOutPoint = wire.OutPoint{
Hash: txid,
Index: uint32(outputIndex),
}
outputIndex++
}
if toRemoteSignDesc != nil {
toRemoteOutPoint = wire.OutPoint{
Hash: txid,
Index: uint32(outputIndex),
}
outputIndex++
}
commitKeyRing := &lnwallet.CommitmentKeyRing{
RevocationKey: c.revPK,
NoDelayKey: c.toLocalPK,
DelayKey: c.toRemotePK,
}
retribution := &lnwallet.BreachRetribution{
BreachTransaction: commitTxn,
RevokedStateNum: c.commitHeight,
KeyRing: commitKeyRing,
RemoteDelay: c.csvDelay,
LocalOutpoint: toRemoteOutPoint,
LocalOutputSignDesc: toRemoteSignDesc,
RemoteOutpoint: toLocalOutPoint,
RemoteOutputSignDesc: toLocalSignDesc,
}
c.retributions[c.commitHeight] = retribution
c.commitHeight++
}
// advanceState creates the next channel state and retribution without altering
// channel balances.
func (c *mockChannel) advanceState(t *testing.T) {
c.mu.Lock()
defer c.mu.Unlock()
c.createRemoteCommitTx(t)
}
// sendPayment creates the next channel state and retribution after transferring
// amt to the remote party.
func (c *mockChannel) sendPayment(t *testing.T, amt lnwire.MilliSatoshi) {
t.Helper()
c.mu.Lock()
defer c.mu.Unlock()
if c.localBalance < amt {
t.Fatalf("insufficient funds to send, need: %v, have: %v",
amt, c.localBalance)
}
c.localBalance -= amt
c.remoteBalance += amt
c.createRemoteCommitTx(t)
}
// receivePayment creates the next channel state and retribution after
// transferring amt to the local party.
func (c *mockChannel) receivePayment(t *testing.T, amt lnwire.MilliSatoshi) {
t.Helper()
c.mu.Lock()
defer c.mu.Unlock()
if c.remoteBalance < amt {
t.Fatalf("insufficient funds to recv, need: %v, have: %v",
amt, c.remoteBalance)
}
c.localBalance += amt
c.remoteBalance -= amt
c.createRemoteCommitTx(t)
}
// getState retrieves the channel's commitment and retribution at state i.
func (c *mockChannel) getState(i uint64) (*wire.MsgTx, *lnwallet.BreachRetribution) {
c.mu.Lock()
defer c.mu.Unlock()
retribution := c.retributions[i]
return retribution.BreachTransaction, retribution
}
type testHarness struct {
t *testing.T
cfg harnessCfg
signer *wtmock.MockSigner
capacity lnwire.MilliSatoshi
clientDB *wtmock.ClientDB
clientCfg *wtclient.Config
client wtclient.Client
serverDB *wtdb.MockDB
serverCfg *wtserver.Config
server *wtserver.Server
net *mockNet
mu sync.Mutex
channels map[lnwire.ChannelID]*mockChannel
}
type harnessCfg struct {
localBalance lnwire.MilliSatoshi
remoteBalance lnwire.MilliSatoshi
policy wtpolicy.Policy
noRegisterChan0 bool
}
func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
towerAddrStr := "18.28.243.2:9911"
towerTCPAddr, err := net.ResolveTCPAddr("tcp", towerAddrStr)
if err != nil {
t.Fatalf("Unable to resolve tower TCP addr: %v", err)
}
privKey, err := btcec.NewPrivateKey(btcec.S256())
if err != nil {
t.Fatalf("Unable to generate tower private key: %v", err)
}
towerPubKey := privKey.PubKey()
towerAddr := &lnwire.NetAddress{
IdentityKey: towerPubKey,
Address: towerTCPAddr,
}
const timeout = 200 * time.Millisecond
serverDB := wtdb.NewMockDB()
serverCfg := &wtserver.Config{
DB: serverDB,
ReadTimeout: timeout,
WriteTimeout: timeout,
NewAddress: func() (btcutil.Address, error) {
return addr, nil
},
}
server, err := wtserver.New(serverCfg)
if err != nil {
t.Fatalf("unable to create wtserver: %v", err)
}
signer := wtmock.NewMockSigner()
mockNet := newMockNet(server.InboundPeerConnected)
clientDB := wtmock.NewClientDB()
clientCfg := &wtclient.Config{
Signer: signer,
Dial: func(string, string) (net.Conn, error) {
return nil, nil
},
DB: clientDB,
AuthDial: mockNet.AuthDial,
PrivateTower: towerAddr,
Policy: cfg.policy,
NewAddress: func() ([]byte, error) {
return addrScript, nil
},
ReadTimeout: timeout,
WriteTimeout: timeout,
MinBackoff: time.Millisecond,
MaxBackoff: 10 * time.Millisecond,
}
client, err := wtclient.New(clientCfg)
if err != nil {
t.Fatalf("Unable to create wtclient: %v", err)
}
if err := server.Start(); err != nil {
t.Fatalf("Unable to start wtserver: %v", err)
}
if err = client.Start(); err != nil {
server.Stop()
t.Fatalf("Unable to start wtclient: %v", err)
}
h := &testHarness{
t: t,
cfg: cfg,
signer: signer,
capacity: cfg.localBalance + cfg.remoteBalance,
clientDB: clientDB,
clientCfg: clientCfg,
client: client,
serverDB: serverDB,
serverCfg: serverCfg,
server: server,
net: mockNet,
channels: make(map[lnwire.ChannelID]*mockChannel),
}
h.makeChannel(0, h.cfg.localBalance, h.cfg.remoteBalance)
if !cfg.noRegisterChan0 {
h.registerChannel(0)
}
return h
}
// startServer creates a new server using the harness's current serverCfg and
// starts it after pointing the mockNet's callback to the new server.
func (h *testHarness) startServer() {
h.t.Helper()
var err error
h.server, err = wtserver.New(h.serverCfg)
if err != nil {
h.t.Fatalf("unable to create wtserver: %v", err)
}
h.net.setConnCallback(h.server.InboundPeerConnected)
if err := h.server.Start(); err != nil {
h.t.Fatalf("unable to start wtserver: %v", err)
}
}
// startClient creates a new server using the harness's current clientCf and
// starts it.
func (h *testHarness) startClient() {
h.t.Helper()
var err error
h.client, err = wtclient.New(h.clientCfg)
if err != nil {
h.t.Fatalf("unable to create wtclient: %v", err)
}
if err := h.client.Start(); err != nil {
h.t.Fatalf("unable to start wtclient: %v", err)
}
}
// chanIDFromInt creates a unique channel id given a unique integral id.
func chanIDFromInt(id uint64) lnwire.ChannelID {
var chanID lnwire.ChannelID
binary.BigEndian.PutUint64(chanID[:8], id)
return chanID
}
// makeChannel creates new channel with id, using the localAmt and remoteAmt as
// the starting balances. The channel will be available by using h.channel(id).
//
// NOTE: The method fails if channel for id already exists.
func (h *testHarness) makeChannel(id uint64,
localAmt, remoteAmt lnwire.MilliSatoshi) {
h.t.Helper()
chanID := chanIDFromInt(id)
c := newMockChannel(h.t, h.signer, localAmt, remoteAmt)
c.mu.Lock()
_, ok := h.channels[chanID]
if !ok {
h.channels[chanID] = c
}
c.mu.Unlock()
if ok {
h.t.Fatalf("channel %d already created", id)
}
}
// channel retrieves the channel corresponding to id.
//
// NOTE: The method fails if a channel for id does not exist.
func (h *testHarness) channel(id uint64) *mockChannel {
h.t.Helper()
h.mu.Lock()
c, ok := h.channels[chanIDFromInt(id)]
h.mu.Unlock()
if !ok {
h.t.Fatalf("unable to fetch channel %d", id)
}
return c
}
// registerChannel registers the channel identified by id with the client.
func (h *testHarness) registerChannel(id uint64) {
h.t.Helper()
chanID := chanIDFromInt(id)
err := h.client.RegisterChannel(chanID)
if err != nil {
h.t.Fatalf("unable to register channel %d: %v", id, err)
}
}
// advanceChannelN calls advanceState on the channel identified by id the number
// of provided times and returns the breach hints corresponding to the new
// states.
func (h *testHarness) advanceChannelN(id uint64, n int) []wtdb.BreachHint {
h.t.Helper()
channel := h.channel(id)
var hints []wtdb.BreachHint
for i := uint64(0); i < uint64(n); i++ {
channel.advanceState(h.t)
commitTx, _ := h.channel(id).getState(i)
breachTxID := commitTx.TxHash()
hints = append(hints, wtdb.NewBreachHintFromHash(&breachTxID))
}
return hints
}
// backupStates instructs the channel identified by id to send backups to the
// client for states in the range [to, from).
func (h *testHarness) backupStates(id, from, to uint64, expErr error) {
h.t.Helper()
for i := from; i < to; i++ {
h.backupState(id, i, expErr)
}
}
// backupStates instructs the channel identified by id to send a backup for
// state i.
func (h *testHarness) backupState(id, i uint64, expErr error) {
_, retribution := h.channel(id).getState(i)
chanID := chanIDFromInt(id)
err := h.client.BackupState(&chanID, retribution)
if err != expErr {
h.t.Fatalf("back error mismatch, want: %v, got: %v",
expErr, err)
}
}
// sendPayments instructs the channel identified by id to send amt to the remote
// party for each state in from-to times and returns the breach hints for states
// [from, to).
func (h *testHarness) sendPayments(id, from, to uint64,
amt lnwire.MilliSatoshi) []wtdb.BreachHint {
h.t.Helper()
channel := h.channel(id)
var hints []wtdb.BreachHint
for i := from; i < to; i++ {
h.channel(id).sendPayment(h.t, amt)
commitTx, _ := channel.getState(i)
breachTxID := commitTx.TxHash()
hints = append(hints, wtdb.NewBreachHintFromHash(&breachTxID))
}
return hints
}
// receivePayment instructs the channel identified by id to recv amt from the
// remote party for each state in from-to times and returns the breach hints for
// states [from, to).
func (h *testHarness) recvPayments(id, from, to uint64,
amt lnwire.MilliSatoshi) []wtdb.BreachHint {
h.t.Helper()
channel := h.channel(id)
var hints []wtdb.BreachHint
for i := from; i < to; i++ {
channel.receivePayment(h.t, amt)
commitTx, _ := channel.getState(i)
breachTxID := commitTx.TxHash()
hints = append(hints, wtdb.NewBreachHintFromHash(&breachTxID))
}
return hints
}
// waitServerUpdates blocks until the breach hints provided all appear in the
// watchtower's database or the timeout expires. This is used to test that the
// client in fact sends the updates to the server, even if it is offline.
func (h *testHarness) waitServerUpdates(hints []wtdb.BreachHint,
timeout time.Duration) {
h.t.Helper()
// If no breach hints are provided, we will wait out the full timeout to
// assert that no updates appear.
wantUpdates := len(hints) > 0
hintSet := make(map[wtdb.BreachHint]struct{})
for _, hint := range hints {
hintSet[hint] = struct{}{}
}
if len(hints) != len(hintSet) {
h.t.Fatalf("breach hints are not unique, list-len: %d "+
"set-len: %d", len(hints), len(hintSet))
}
// Closure to assert the server's matches are consistent with the hint
// set.
serverHasHints := func(matches []wtdb.Match) bool {
if len(hintSet) != len(matches) {
return false
}
for _, match := range matches {
if _, ok := hintSet[match.Hint]; ok {
continue
}
h.t.Fatalf("match %v in db is not in hint set",
match.Hint)
}
return true
}
failTimeout := time.After(timeout)
for {
select {
case <-time.After(time.Second):
matches, err := h.serverDB.QueryMatches(hints)
switch {
case err != nil:
h.t.Fatalf("unable to query for hints: %v", err)
case wantUpdates && serverHasHints(matches):
return
case wantUpdates:
h.t.Logf("Received %d/%d\n", len(matches),
len(hints))
}
case <-failTimeout:
matches, err := h.serverDB.QueryMatches(hints)
switch {
case err != nil:
h.t.Fatalf("unable to query for hints: %v", err)
case serverHasHints(matches):
return
default:
h.t.Fatalf("breach hints not received, only "+
"got %d/%d", len(matches), len(hints))
}
}
}
}
const (
localBalance = lnwire.MilliSatoshi(100000000)
remoteBalance = lnwire.MilliSatoshi(200000000)
)
type clientTest struct {
name string
cfg harnessCfg
fn func(*testHarness)
}
var clientTests = []clientTest{
{
// Asserts that client will return the ErrUnregisteredChannel
// error when trying to backup states for a channel that has not
// been registered (and received it's pkscript).
name: "backup unregistered channel",
cfg: harnessCfg{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
BlobType: blob.TypeDefault,
MaxUpdates: 20000,
SweepFeeRate: 1,
},
noRegisterChan0: true,
},
fn: func(h *testHarness) {
const (
numUpdates = 5
chanID = 0
)
// Advance the channel and backup the retributions. We
// expect ErrUnregisteredChannel to be returned since
// the channel was not registered during harness
// creation.
h.advanceChannelN(chanID, numUpdates)
h.backupStates(
chanID, 0, numUpdates,
wtclient.ErrUnregisteredChannel,
)
},
},
{
// Asserts that the client returns an ErrClientExiting when
// trying to backup channels after the Stop method has been
// called.
name: "backup after stop",
cfg: harnessCfg{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
BlobType: blob.TypeDefault,
MaxUpdates: 20000,
SweepFeeRate: 1,
},
},
fn: func(h *testHarness) {
const (
numUpdates = 5
chanID = 0
)
// Stop the client, subsequent backups should fail.
h.client.Stop()
// Advance the channel and try to back up the states. We
// expect ErrClientExiting to be returned from
// BackupState.
h.advanceChannelN(chanID, numUpdates)
h.backupStates(
chanID, 0, numUpdates,
wtclient.ErrClientExiting,
)
},
},
{
// Asserts that the client will continue to back up all states
// that have previously been enqueued before it finishes
// exiting.
name: "backup reliable flush",
cfg: harnessCfg{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
BlobType: blob.TypeDefault,
MaxUpdates: 5,
SweepFeeRate: 1,
},
},
fn: func(h *testHarness) {
const (
numUpdates = 5
chanID = 0
)
// Generate numUpdates retributions and back them up to
// the tower.
hints := h.advanceChannelN(chanID, numUpdates)
h.backupStates(chanID, 0, numUpdates, nil)
// Stop the client in the background, to assert the
// pipeline is always flushed before it exits.
go h.client.Stop()
// Wait for all of the updates to be populated in the
// server's database.
h.waitServerUpdates(hints, time.Second)
},
},
{
// Assert that the client will not send out backups for states
// whose justice transactions are ineligible for backup, e.g.
// creating dust outputs.
name: "backup dust ineligible",
cfg: harnessCfg{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
BlobType: blob.TypeDefault,
MaxUpdates: 20000,
SweepFeeRate: 1000000, // high sweep fee creates dust
},
},
fn: func(h *testHarness) {
const (
numUpdates = 5
chanID = 0
)
// Create the retributions and queue them for backup.
h.advanceChannelN(chanID, numUpdates)
h.backupStates(chanID, 0, numUpdates, nil)
// Ensure that no updates are received by the server,
// since they should all be marked as ineligible.
h.waitServerUpdates(nil, time.Second)
},
},
{
// Verifies that the client will properly retransmit a committed
// state update to the watchtower after a restart if the update
// was not acked while the client was active last.
name: "committed update restart",
cfg: harnessCfg{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
BlobType: blob.TypeDefault,
MaxUpdates: 20000,
SweepFeeRate: 1,
},
},
fn: func(h *testHarness) {
const (
numUpdates = 5
chanID = 0
)
hints := h.advanceChannelN(0, numUpdates)
var numSent uint64
// Add the first two states to the client's pipeline.
h.backupStates(chanID, 0, 2, nil)
numSent = 2
// Wait for both to be reflected in the server's
// database.
h.waitServerUpdates(hints[:numSent], time.Second)
// Now, restart the server and prevent it from acking
// state updates.
h.server.Stop()
h.serverCfg.NoAckUpdates = true
h.startServer()
defer h.server.Stop()
// Send the next state update to the tower. Since the
// tower isn't acking state updates, we expect this
// update to be committed and sent by the session queue,
// but it will never receive an ack.
h.backupState(chanID, numSent, nil)
numSent++
// Force quit the client to abort the state updates it
// has queued. The sleep ensures that the session queues
// have enough time to commit the state updates before
// the client is killed.
time.Sleep(time.Second)
h.client.ForceQuit()
// Restart the server and allow it to ack the updates
// after the client retransmits the unacked update.
h.server.Stop()
h.serverCfg.NoAckUpdates = false
h.startServer()
defer h.server.Stop()
// Restart the client and allow it to process the
// committed update.
h.startClient()
defer h.client.ForceQuit()
// Wait for the committed update to be accepted by the
// tower.
h.waitServerUpdates(hints[:numSent], time.Second)
// Finally, send the rest of the updates and wait for
// the tower to receive the remaining states.
h.backupStates(chanID, numSent, numUpdates, nil)
// Wait for all of the updates to be populated in the
// server's database.
h.waitServerUpdates(hints, time.Second)
},
},
{
// Asserts that the client will continue to retry sending state
// updates if it doesn't receive an ack from the server. The
// client is expected to flush everything in its in-memory
// pipeline once the server begins sending acks again.
name: "no ack from server",
cfg: harnessCfg{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
BlobType: blob.TypeDefault,
MaxUpdates: 5,
SweepFeeRate: 1,
},
},
fn: func(h *testHarness) {
const (
numUpdates = 100
chanID = 0
)
// Generate the retributions that will be backed up.
hints := h.advanceChannelN(chanID, numUpdates)
// Restart the server and prevent it from acking state
// updates.
h.server.Stop()
h.serverCfg.NoAckUpdates = true
h.startServer()
defer h.server.Stop()
// Now, queue the retributions for backup.
h.backupStates(chanID, 0, numUpdates, nil)
// Stop the client in the background, to assert the
// pipeline is always flushed before it exits.
go h.client.Stop()
// Give the client time to saturate a large number of
// session queues for which the server has not acked the
// state updates that it has received.
time.Sleep(time.Second)
// Restart the server and allow it to ack the updates
// after the client retransmits the unacked updates.
h.server.Stop()
h.serverCfg.NoAckUpdates = false
h.startServer()
defer h.server.Stop()
// Wait for all of the updates to be populated in the
// server's database.
h.waitServerUpdates(hints, 5*time.Second)
},
},
{
// Asserts that the client is able to send state updates to the
// tower for a full range of channel values, assuming the sweep
// fee rates permit it. We expect all of these to be successful
// since a sweep transactions spending only from one output is
// less expensive than one that sweeps both.
name: "send and recv",
cfg: harnessCfg{
localBalance: 10000001, // ensure (% amt != 0)
remoteBalance: 20000001, // ensure (% amt != 0)
policy: wtpolicy.Policy{
BlobType: blob.TypeDefault,
MaxUpdates: 1000,
SweepFeeRate: 1,
},
},
fn: func(h *testHarness) {
var (
capacity = h.cfg.localBalance + h.cfg.remoteBalance
paymentAmt = lnwire.MilliSatoshi(200000)
numSends = uint64(h.cfg.localBalance / paymentAmt)
numRecvs = uint64(capacity / paymentAmt)
numUpdates = numSends + numRecvs // 200 updates
chanID = uint64(0)
)
// Send money to the remote party until all funds are
// depleted.
sendHints := h.sendPayments(chanID, 0, numSends, paymentAmt)
// Now, sequentially receive the entire channel balance
// from the remote party.
recvHints := h.recvPayments(chanID, numSends, numUpdates, paymentAmt)
// Collect the hints generated by both sending and
// receiving.
hints := append(sendHints, recvHints...)
// Backup the channel's states the client.
h.backupStates(chanID, 0, numUpdates, nil)
// Wait for all of the updates to be populated in the
// server's database.
h.waitServerUpdates(hints, 3*time.Second)
},
},
{
// Asserts that the client is able to support multiple links.
name: "multiple link backup",
cfg: harnessCfg{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
BlobType: blob.TypeDefault,
MaxUpdates: 5,
SweepFeeRate: 1,
},
},
fn: func(h *testHarness) {
const (
numUpdates = 5
numChans = 10
)
// Initialize and register an additional 9 channels.
for id := uint64(1); id < 10; id++ {
h.makeChannel(
id, h.cfg.localBalance,
h.cfg.remoteBalance,
)
h.registerChannel(id)
}
// Generate the retributions for all 10 channels and
// collect the breach hints.
var hints []wtdb.BreachHint
for id := uint64(0); id < 10; id++ {
chanHints := h.advanceChannelN(id, numUpdates)
hints = append(hints, chanHints...)
}
// Provided all retributions to the client from all
// channels.
for id := uint64(0); id < 10; id++ {
h.backupStates(id, 0, numUpdates, nil)
}
// Test reliable flush under multi-client scenario.
go h.client.Stop()
// Wait for all of the updates to be populated in the
// server's database.
h.waitServerUpdates(hints, 10*time.Second)
},
},
}
// TestClient executes the client test suite, asserting the ability to backup
// states in a number of failure cases and it's reliability during shutdown.
func TestClient(t *testing.T) {
for _, test := range clientTests {
tc := test
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
h := newHarness(t, tc.cfg)
defer h.server.Stop()
defer h.client.ForceQuit()
tc.fn(h)
})
}
}

@ -0,0 +1,35 @@
package wtclient
import "errors"
var (
// ErrClientExiting signals that the watchtower client is shutting down.
ErrClientExiting = errors.New("watchtower client shutting down")
// ErrTowerCandidatesExhausted signals that a TowerCandidateIterator has
// cycled through all available candidates.
ErrTowerCandidatesExhausted = errors.New("exhausted all tower " +
"candidates")
// ErrPermanentTowerFailure signals that the tower has reported that it
// has permanently failed or the client believes this has happened based
// on the tower's behavior.
ErrPermanentTowerFailure = errors.New("permanent tower failure")
// ErrNegotiatorExiting signals that the SessionNegotiator is shutting
// down.
ErrNegotiatorExiting = errors.New("negotiator exiting")
// ErrNoTowerAddrs signals that the client could not be created because
// we have no addresses with which we can reach a tower.
ErrNoTowerAddrs = errors.New("no tower addresses")
// ErrFailedNegotiation signals that the session negotiator could not
// acquire a new session as requested.
ErrFailedNegotiation = errors.New("session negotiation unsuccessful")
// ErrUnregisteredChannel signals that the client was unable to backup a
// revoked state becuase the channel had not been previously registered
// with the client.
ErrUnregisteredChannel = errors.New("channel is not registered")
)

@ -0,0 +1,76 @@
package wtclient
import (
"net"
"github.com/btcsuite/btcd/btcec"
"github.com/lightningnetwork/lnd/brontide"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtserver"
)
// DB abstracts the required database operations required by the watchtower
// client.
type DB interface {
// CreateTower initialize an address record used to communicate with a
// watchtower. Each Tower is assigned a unique ID, that is used to
// amortize storage costs of the public key when used by multiple
// sessions.
CreateTower(*lnwire.NetAddress) (*wtdb.Tower, error)
// CreateClientSession saves a newly negotiated client session to the
// client's database. This enables the session to be used across
// restarts.
CreateClientSession(*wtdb.ClientSession) error
// ListClientSessions returns all sessions that have not yet been
// exhausted. This is used on startup to find any sessions which may
// still be able to accept state updates.
ListClientSessions() (map[wtdb.SessionID]*wtdb.ClientSession, error)
// FetchChanPkScripts returns a map of all sweep pkscripts for
// registered channels. This is used on startup to cache the sweep
// pkscripts of registered channels in memory.
FetchChanPkScripts() (map[lnwire.ChannelID][]byte, error)
// AddChanPkScript inserts a newly generated sweep pkscript for the
// given channel.
AddChanPkScript(lnwire.ChannelID, []byte) error
// MarkBackupIneligible records that the state identified by the
// (channel id, commit height) tuple was ineligible for being backed up
// under the current policy. This state can be retried later under a
// different policy.
MarkBackupIneligible(chanID lnwire.ChannelID, commitHeight uint64) error
// CommitUpdate writes the next state update for a particular
// session, so that we can be sure to resend it after a restart if it
// hasn't been ACK'd by the tower. The sequence number of the update
// should be exactly one greater than the existing entry, and less that
// or equal to the session's MaxUpdates.
CommitUpdate(id *wtdb.SessionID, seqNum uint16,
update *wtdb.CommittedUpdate) (uint16, error)
// AckUpdate records an acknowledgment from the watchtower that the
// update identified by seqNum was received and saved. The returned
// lastApplied will be recorded.
AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) error
}
// Dial connects to an addr using the specified net and returns the connection
// object.
type Dial func(net, addr string) (net.Conn, error)
// AuthDialer connects to a remote node using an authenticated transport, such as
// brontide. The dialer argument is used to specify a resolver, which allows
// this method to be used over Tor or clear net connections.
type AuthDialer func(localPriv *btcec.PrivateKey, netAddr *lnwire.NetAddress,
dialer func(string, string) (net.Conn, error)) (wtserver.Peer, error)
// AuthDial is the watchtower client's default method of dialing.
func AuthDial(localPriv *btcec.PrivateKey, netAddr *lnwire.NetAddress,
dialer func(string, string) (net.Conn, error)) (wtserver.Peer, error) {
return brontide.Dial(localPriv, netAddr, dialer)
}

@ -0,0 +1,29 @@
package wtclient
import (
"github.com/btcsuite/btclog"
"github.com/lightningnetwork/lnd/build"
)
// log is a logger that is initialized with no output filters. This
// means the package will not perform any logging by default until the caller
// requests it.
var log btclog.Logger
// The default amount of logging is none.
func init() {
UseLogger(build.NewSubLogger("WTCL", nil))
}
// DisableLog disables all library log output. Logging output is disabled
// by default until UseLogger is called.
func DisableLog() {
UseLogger(btclog.Disabled)
}
// UseLogger uses a specified Logger to output package logging info.
// This should be used in preference to SetLogWriter if the caller is also
// using btclog.
func UseLogger(logger btclog.Logger) {
log = logger
}

@ -0,0 +1,451 @@
package wtclient
import (
"fmt"
"sync"
"time"
"github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
"github.com/lightningnetwork/lnd/watchtower/wtserver"
"github.com/lightningnetwork/lnd/watchtower/wtwire"
)
// SessionNegotiator is an interface for asynchronously requesting new sessions.
type SessionNegotiator interface {
// RequestSession signals to the session negotiator that the client
// needs another session. Once the session is negotiated, it should be
// returned via NewSessions.
RequestSession()
// NewSessions is a read-only channel where newly negotiated sessions
// will be delivered.
NewSessions() <-chan *wtdb.ClientSession
// Start safely initializes the session negotiator.
Start() error
// Stop safely shuts down the session negotiator.
Stop() error
}
// NegotiatorConfig provides access to the resources required by a
// SessionNegotiator to faithfully carry out its duties. All nil-able field must
// be initialized.
type NegotiatorConfig struct {
// DB provides access to a persistent storage medium used by the tower
// to properly allocate session ephemeral keys and record successfully
// negotiated sessions.
DB DB
// Candidates is an abstract set of tower candidates that the negotiator
// will traverse serially when attempting to negotiate a new session.
Candidates TowerCandidateIterator
// Policy defines the session policy that will be proposed to towers
// when attempting to negotiate a new session. This policy will be used
// across all negotiation proposals for the lifetime of the negotiator.
Policy wtpolicy.Policy
// Dial initiates an outbound brontide connection to the given address
// using a specified private key. The peer is returned in the event of a
// successful connection.
Dial func(*btcec.PrivateKey, *lnwire.NetAddress) (wtserver.Peer, error)
// SendMessage writes a wtwire message to remote peer.
SendMessage func(wtserver.Peer, wtwire.Message) error
// ReadMessage reads a message from a remote peer and returns the
// decoded wtwire message.
ReadMessage func(wtserver.Peer) (wtwire.Message, error)
// ChainHash the genesis hash identifying the chain for any negotiated
// sessions. Any state updates sent to that session should also
// originate from this chain.
ChainHash chainhash.Hash
// MinBackoff defines the initial backoff applied by the session
// negotiator after all tower candidates have been exhausted and
// reattempting negotiation with the same set of candidates. Subsequent
// backoff durations will grow exponentially.
MinBackoff time.Duration
// MaxBackoff defines the maximum backoff applied by the session
// negotiator after all tower candidates have been exhausted and
// reattempting negotation with the same set of candidates. If the
// exponential backoff produces a timeout greater than this value, the
// backoff duration will be clamped to MaxBackoff.
MaxBackoff time.Duration
}
// sessionNegotiator is concrete SessionNegotiator that is able to request new
// sessions from a set of candidate towers asynchronously and return successful
// sessions to the primary client.
type sessionNegotiator struct {
started sync.Once
stopped sync.Once
localInit *wtwire.Init
cfg *NegotiatorConfig
dispatcher chan struct{}
newSessions chan *wtdb.ClientSession
successfulNegotiations chan *wtdb.ClientSession
wg sync.WaitGroup
quit chan struct{}
}
// Compile-time constraint to ensure a *sessionNegotiator implements the
// SessionNegotiator interface.
var _ SessionNegotiator = (*sessionNegotiator)(nil)
// newSessionNegotiator initializes a fresh sessionNegotiator instance.
func newSessionNegotiator(cfg *NegotiatorConfig) *sessionNegotiator {
localInit := wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired),
cfg.ChainHash,
)
return &sessionNegotiator{
cfg: cfg,
localInit: localInit,
dispatcher: make(chan struct{}, 1),
newSessions: make(chan *wtdb.ClientSession),
successfulNegotiations: make(chan *wtdb.ClientSession),
quit: make(chan struct{}),
}
}
// Start safely starts up the sessionNegotiator.
func (n *sessionNegotiator) Start() error {
n.started.Do(func() {
log.Debugf("Starting session negotiator")
n.wg.Add(1)
go n.negotiationDispatcher()
})
return nil
}
// Stop safely shutsdown the sessionNegotiator.
func (n *sessionNegotiator) Stop() error {
n.stopped.Do(func() {
log.Debugf("Stopping session negotiator")
close(n.quit)
n.wg.Wait()
})
return nil
}
// NewSessions returns a receive-only channel from which newly negotiated
// sessions will be returned.
func (n *sessionNegotiator) NewSessions() <-chan *wtdb.ClientSession {
return n.newSessions
}
// RequestSession sends a request to the sessionNegotiator to begin requesting a
// new session. If one is already in the process of being negotiated, the
// request will be ignored.
func (n *sessionNegotiator) RequestSession() {
select {
case n.dispatcher <- struct{}{}:
default:
}
}
// negotiationDispatcher acts as the primary event loop for the
// sessionNegotiator, coordinating requests for more sessions and dispatching
// attempts to negotiate them from a list of candidates.
func (n *sessionNegotiator) negotiationDispatcher() {
defer n.wg.Done()
var pendingNegotiations int
for {
select {
case <-n.dispatcher:
pendingNegotiations++
if pendingNegotiations > 1 {
log.Debugf("Already negotiating session, " +
"waiting for existing negotiation to " +
"complete")
continue
}
// TODO(conner): consider reusing good towers
log.Debugf("Dispatching session negotiation")
n.wg.Add(1)
go n.negotiate()
case session := <-n.successfulNegotiations:
select {
case n.newSessions <- session:
pendingNegotiations--
case <-n.quit:
return
}
if pendingNegotiations > 0 {
log.Debugf("Dispatching pending session " +
"negotiation")
n.wg.Add(1)
go n.negotiate()
}
case <-n.quit:
return
}
}
}
// negotiate handles the process of iterating through potential tower candidates
// and attempting to negotiate a new session until a successful negotiation
// occurs. If the candidate iterator becomes exhausted because none were
// successful, this method will back off exponentially up to the configured max
// backoff. This method will continue trying until a negotiation is succesful
// before returning the negotiated session to the dispatcher via the succeed
// channel.
//
// NOTE: This method MUST be run as a goroutine.
func (n *sessionNegotiator) negotiate() {
defer n.wg.Done()
// On the first pass, initialize the backoff to our configured min
// backoff.
backoff := n.cfg.MinBackoff
retryWithBackoff:
// If we are retrying, wait out the delay before continuing.
if backoff > 0 {
select {
case <-time.After(backoff):
case <-n.quit:
return
}
}
// Before attempting a bout of session negotiation, reset the candidate
// iterator to ensure the results are fresh.
n.cfg.Candidates.Reset()
for {
// Pull the next candidate from our list of addresses.
tower, err := n.cfg.Candidates.Next()
if err != nil {
// We've run out of addresses, double and clamp backoff.
backoff *= 2
if backoff > n.cfg.MaxBackoff {
backoff = n.cfg.MaxBackoff
}
log.Debugf("Unable to get new tower candidate, "+
"retrying after %v -- reason: %v", backoff, err)
goto retryWithBackoff
}
log.Debugf("Attempting session negotiation with tower=%x",
tower.IdentityKey.SerializeCompressed())
// We'll now attempt the CreateSession dance with the tower to
// get a new session, trying all addresses if necessary.
err = n.createSession(tower)
if err != nil {
log.Debugf("Session negotiation with tower=%x "+
"failed, trying again -- reason: %v",
tower.IdentityKey.SerializeCompressed(), err)
continue
}
// Success.
return
}
}
// createSession takes a tower an attempts to negotiate a session using any of
// its stored addresses. This method returns after the first successful
// negotiation, or after all addresses have failed with ErrFailedNegotiation. If
// the tower has no addresses, ErrNoTowerAddrs is returned.
func (n *sessionNegotiator) createSession(tower *wtdb.Tower) error {
// If the tower has no addresses, there's nothing we can do.
if len(tower.Addresses) == 0 {
return ErrNoTowerAddrs
}
// TODO(conner): create with hdkey at random index
sessionPrivKey, err := btcec.NewPrivateKey(btcec.S256())
if err != nil {
return err
}
// TODO(conner): write towerAddr+privkey
for _, lnAddr := range tower.LNAddrs() {
err = n.tryAddress(sessionPrivKey, tower, lnAddr)
switch {
case err == ErrPermanentTowerFailure:
// TODO(conner): report to iterator? can then be reset
// with restart
fallthrough
case err != nil:
log.Debugf("Request for session negotiation with "+
"tower=%s failed, trying again -- reason: "+
"%v", lnAddr, err)
continue
default:
return nil
}
}
return ErrFailedNegotiation
}
// tryAddress executes a single create session dance using the given address.
// The address should belong to the tower's set of addresses. This method only
// returns true if all steps succeed and the new session has been persisted, and
// fails otherwise.
func (n *sessionNegotiator) tryAddress(privKey *btcec.PrivateKey,
tower *wtdb.Tower, lnAddr *lnwire.NetAddress) error {
// Connect to the tower address using our generated session key.
conn, err := n.cfg.Dial(privKey, lnAddr)
if err != nil {
return err
}
// Send local Init message.
err = n.cfg.SendMessage(conn, n.localInit)
if err != nil {
return fmt.Errorf("unable to send Init: %v", err)
}
// Receive remote Init message.
remoteMsg, err := n.cfg.ReadMessage(conn)
if err != nil {
return fmt.Errorf("unable to read Init: %v", err)
}
// Check that returned message is wtwire.Init.
remoteInit, ok := remoteMsg.(*wtwire.Init)
if !ok {
return fmt.Errorf("expected Init, got %T in reply", remoteMsg)
}
// Verify the watchtower's remote Init message against our own.
err = n.localInit.CheckRemoteInit(remoteInit, wtwire.FeatureNames)
if err != nil {
return err
}
policy := n.cfg.Policy
createSession := &wtwire.CreateSession{
BlobType: policy.BlobType,
MaxUpdates: policy.MaxUpdates,
RewardBase: policy.RewardBase,
RewardRate: policy.RewardRate,
SweepFeeRate: policy.SweepFeeRate,
}
// Send CreateSession message.
err = n.cfg.SendMessage(conn, createSession)
if err != nil {
return fmt.Errorf("unable to send CreateSession: %v", err)
}
// Receive CreateSessionReply message.
remoteMsg, err = n.cfg.ReadMessage(conn)
if err != nil {
return fmt.Errorf("unable to read CreateSessionReply: %v", err)
}
// Check that returned message is wtwire.CreateSessionReply.
createSessionReply, ok := remoteMsg.(*wtwire.CreateSessionReply)
if !ok {
return fmt.Errorf("expected CreateSessionReply, got %T in "+
"reply", remoteMsg)
}
switch createSessionReply.Code {
case wtwire.CodeOK, wtwire.CreateSessionCodeAlreadyExists:
// TODO(conner): add last-applied to create session reply to
// handle case where we lose state, session already exists, and
// we want to possibly resume using the session
// TODO(conner): validate reward address
rewardPkScript := createSessionReply.Data
sessionID := wtdb.NewSessionIDFromPubKey(
privKey.PubKey(),
)
clientSession := &wtdb.ClientSession{
TowerID: tower.ID,
Tower: tower,
SessionPrivKey: privKey, // remove after using HD keys
ID: sessionID,
Policy: n.cfg.Policy,
SeqNum: 0,
RewardPkScript: rewardPkScript,
}
err = n.cfg.DB.CreateClientSession(clientSession)
if err != nil {
return fmt.Errorf("unable to persist ClientSession: %v",
err)
}
log.Debugf("New session negotiated with %s, policy: %s",
lnAddr, clientSession.Policy)
// We have a newly negotiated session, return it to the
// dispatcher so that it can update how many outstanding
// negotiation requests we have.
select {
case n.successfulNegotiations <- clientSession:
return nil
case <-n.quit:
return ErrNegotiatorExiting
}
// TODO(conner): handle error codes properly
case wtwire.CreateSessionCodeRejectBlobType:
return fmt.Errorf("tower rejected blob type: %v",
policy.BlobType)
case wtwire.CreateSessionCodeRejectMaxUpdates:
return fmt.Errorf("tower rejected max updates: %v",
policy.MaxUpdates)
case wtwire.CreateSessionCodeRejectRewardRate:
// The tower rejected the session because of the reward rate. If
// we didn't request a reward session, we'll treat this as a
// permanent tower failure.
if !policy.BlobType.Has(blob.FlagReward) {
return ErrPermanentTowerFailure
}
return fmt.Errorf("tower rejected reward rate: %v",
policy.RewardRate)
case wtwire.CreateSessionCodeRejectSweepFeeRate:
return fmt.Errorf("tower rejected sweep fee rate: %v",
policy.SweepFeeRate)
default:
return fmt.Errorf("received unhandled error code: %v",
createSessionReply.Code)
}
}

@ -0,0 +1,688 @@
package wtclient
import (
"container/list"
"fmt"
"sort"
"sync"
"time"
"github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtserver"
"github.com/lightningnetwork/lnd/watchtower/wtwire"
)
// retryInterval is the default duration we will wait between attempting to
// connect back out to a tower if the prior state update failed.
const retryInterval = 2 * time.Second
// reserveStatus is an enum that signals how full a particular session is.
type reserveStatus uint8
const (
// reserveAvailable indicates that the session has space for at least
// one more backup.
reserveAvailable reserveStatus = iota
// reserveExhausted indicates that all slots in the session have been
// allocated.
reserveExhausted
)
// sessionQueueConfig bundles the resources required by the sessionQueue to
// perform its duties. All entries MUST be non-nil.
type sessionQueueConfig struct {
// ClientSession provides access to the negotiated session parameters
// and updating its persistent storage.
ClientSession *wtdb.ClientSession
// ChainHash identifies the chain for which the session's justice
// transactions are targeted.
ChainHash chainhash.Hash
// Dial allows the client to dial the tower using it's public key and
// net address.
Dial func(*btcec.PrivateKey,
*lnwire.NetAddress) (wtserver.Peer, error)
// SendMessage encodes, encrypts, and writes a message to the given peer.
SendMessage func(wtserver.Peer, wtwire.Message) error
// ReadMessage receives, decypts, and decodes a message from the given
// peer.
ReadMessage func(wtserver.Peer) (wtwire.Message, error)
// Signer facilitates signing of inputs, used to construct the witnesses
// for justice transaction inputs.
Signer input.Signer
// DB provides access to the client's stable storage.
DB DB
// MinBackoff defines the initial backoff applied by the session
// queue before reconnecting to the tower after a failed or partially
// successful batch is sent. Subsequent backoff durations will grow
// exponentially up until MaxBackoff.
MinBackoff time.Duration
// MaxBackoff defines the maximum backoff applied by the session
// queue before reconnecting to the tower after a failed or partially
// successful batch is sent. If the exponential backoff produces a
// timeout greater than this value, the backoff duration will be clamped
// to MaxBackoff.
MaxBackoff time.Duration
}
// sessionQueue implements a reliable queue that will encrypt and send accepted
// backups to the watchtower specified in the config's ClientSession. Calling
// Quit will attempt to perform a clean shutdown by receiving an ACK from the
// tower for all pending backups before exiting. The clean shutdown can be
// aborted by using ForceQuit, which will attempt to shutdown the queue
// immediately.
type sessionQueue struct {
started sync.Once
stopped sync.Once
forced sync.Once
cfg *sessionQueueConfig
commitQueue *list.List
pendingQueue *list.List
queueMtx sync.Mutex
queueCond *sync.Cond
localInit *wtwire.Init
towerAddr *lnwire.NetAddress
seqNum uint16
retryBackoff time.Duration
quit chan struct{}
forceQuit chan struct{}
shutdown chan struct{}
}
// newSessionQueue intiializes a fresh sessionQueue.
func newSessionQueue(cfg *sessionQueueConfig) *sessionQueue {
localInit := wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired),
cfg.ChainHash,
)
towerAddr := &lnwire.NetAddress{
IdentityKey: cfg.ClientSession.Tower.IdentityKey,
Address: cfg.ClientSession.Tower.Addresses[0],
}
sq := &sessionQueue{
cfg: cfg,
commitQueue: list.New(),
pendingQueue: list.New(),
localInit: localInit,
towerAddr: towerAddr,
seqNum: cfg.ClientSession.SeqNum,
retryBackoff: cfg.MinBackoff,
quit: make(chan struct{}),
forceQuit: make(chan struct{}),
shutdown: make(chan struct{}),
}
sq.queueCond = sync.NewCond(&sq.queueMtx)
sq.restoreCommittedUpdates()
return sq
}
// Start idempotently starts the sessionQueue so that it can begin accepting
// backups.
func (q *sessionQueue) Start() {
q.started.Do(func() {
// TODO(conner): load prior committed state updates from disk an
// populate in queue.
go q.sessionManager()
})
}
// Stop idempotently stops the sessionQueue by initiating a clean shutdown that
// will clear all pending tasks in the queue before returning to the caller.
func (q *sessionQueue) Stop() {
q.stopped.Do(func() {
log.Debugf("Stopping session queue %s", q.ID())
close(q.quit)
q.signalUntilShutdown()
// Skip log if we also force quit.
select {
case <-q.forceQuit:
return
default:
}
log.Debugf("Session queue %s successfully stopped", q.ID())
})
}
// ForceQuit idempotently aborts any clean shutdown in progress and returns to
// he caller after all lingering goroutines have spun down.
func (q *sessionQueue) ForceQuit() {
q.forced.Do(func() {
log.Infof("Force quitting session queue %s", q.ID())
close(q.forceQuit)
q.signalUntilShutdown()
log.Infof("Session queue %s unclean shutdown complete", q.ID())
})
}
// ID returns the wtdb.SessionID for the queue, which can be used to uniquely
// identify this a particular queue.
func (q *sessionQueue) ID() *wtdb.SessionID {
return &q.cfg.ClientSession.ID
}
// AcceptTask attempts to queue a backupTask for delivery to the sessionQueue's
// tower. The session will only be accepted if the queue is not already
// exhausted and the task is successfully bound to the ClientSession.
func (q *sessionQueue) AcceptTask(task *backupTask) (reserveStatus, bool) {
q.queueCond.L.Lock()
// Examine the current reserve status of the session queue.
curStatus := q.reserveStatus()
switch curStatus {
// The session queue is exhausted, and cannot accept the task because it
// is full. Reject the task such that it can be tried against a
// different session.
case reserveExhausted:
q.queueCond.L.Unlock()
return curStatus, false
// The session queue is not exhausted. Compute the sweep and reward
// outputs as a function of the session parameters. If the outputs are
// dusty or uneconomical to backup, the task is rejected and will not be
// tried again.
//
// TODO(conner): queue backups and retry with different session params.
case reserveAvailable:
err := task.bindSession(q.cfg.ClientSession)
if err != nil {
q.queueCond.L.Unlock()
log.Debugf("SessionQueue %s rejected backup chanid=%s "+
"commit-height=%d: %v", q.ID(), task.id.ChanID,
task.id.CommitHeight, err)
return curStatus, false
}
}
// The sweep and reward outputs satisfy the session's policy, queue the
// task for final signing and delivery.
q.pendingQueue.PushBack(task)
// Finally, compute the session's *new* reserve status. This will be
// used by the client to determine if it can continue using this session
// queue, or if it should negotiate a new one.
newStatus := q.reserveStatus()
q.queueCond.L.Unlock()
q.queueCond.Signal()
return newStatus, true
}
// updateWithSeqNum stores a CommittedUpdate with its assigned sequence number.
// This allows committed updates to be sorted after a restart, and added to the
// commitQueue in the proper order for delivery.
type updateWithSeqNum struct {
seqNum uint16
update *wtdb.CommittedUpdate
}
// restoreCommittedUpdates processes any CommittedUpdates loaded on startup by
// sorting them in ascending order of sequence numbers and adding them to the
// commitQueue. These will be sent before any pending updates are processed.
func (q *sessionQueue) restoreCommittedUpdates() {
committedUpdates := q.cfg.ClientSession.CommittedUpdates
// Construct and unordered slice of all committed updates with their
// assigned sequence numbers.
sortedUpdates := make([]updateWithSeqNum, 0, len(committedUpdates))
for seqNum, update := range committedUpdates {
sortedUpdates = append(sortedUpdates, updateWithSeqNum{
seqNum: seqNum,
update: update,
})
}
// Sort the resulting slice by increasing sequence number.
sort.Slice(sortedUpdates, func(i, j int) bool {
return sortedUpdates[i].seqNum < sortedUpdates[j].seqNum
})
// Finally, add the sorted, committed updates to he commitQueue. These
// updates will be prioritized before any new tasks are assigned to the
// sessionQueue. The queue will begin uploading any tasks in the
// commitQueue as soon as it is started, e.g. during client
// initialization when detecting that this session has unacked updates.
for _, update := range sortedUpdates {
q.commitQueue.PushBack(update)
}
}
// sessionManager is the primary event loop for the sessionQueue, and is
// responsible for encrypting and sending accepted tasks to the tower.
func (q *sessionQueue) sessionManager() {
defer close(q.shutdown)
for {
q.queueCond.L.Lock()
for q.commitQueue.Len() == 0 &&
q.pendingQueue.Len() == 0 {
q.queueCond.Wait()
select {
case <-q.quit:
if q.commitQueue.Len() == 0 &&
q.pendingQueue.Len() == 0 {
q.queueCond.L.Unlock()
return
}
case <-q.forceQuit:
q.queueCond.L.Unlock()
return
default:
}
}
q.queueCond.L.Unlock()
// Exit immediately if a force quit has been requested. If the
// either of the queues still has state updates to send to the
// tower, we may never exit in the above case if we are unable
// to reach the tower for some reason.
select {
case <-q.forceQuit:
return
default:
}
// Initiate a new connection to the watchtower and attempt to
// drain all pending tasks.
q.drainBackups()
}
}
// drainBackups attempts to send all pending updates in the queue to the tower.
func (q *sessionQueue) drainBackups() {
// First, check that we are able to dial this session's tower.
conn, err := q.cfg.Dial(q.cfg.ClientSession.SessionPrivKey, q.towerAddr)
if err != nil {
log.Errorf("Unable to dial watchtower at %v: %v",
q.towerAddr, err)
q.increaseBackoff()
select {
case <-time.After(q.retryBackoff):
case <-q.forceQuit:
}
return
}
defer conn.Close()
// Begin draining the queue of pending state updates. Before the first
// update is sent, we will precede it with an Init message. If the first
// is successful, subsequent updates can be streamed without sending an
// Init.
for sendInit := true; ; sendInit = false {
// Generate the next state update to upload to the tower. This
// method will first proceed in dequeueing committed updates
// before attempting to dequeue any pending updates.
stateUpdate, isPending, err := q.nextStateUpdate()
if err != nil {
log.Errorf("Unable to get next state update: %v", err)
return
}
// Now, send the state update to the tower and wait for a reply.
err = q.sendStateUpdate(
conn, stateUpdate, q.localInit, sendInit, isPending,
)
if err != nil {
log.Errorf("Unable to send state update: %v", err)
q.increaseBackoff()
select {
case <-time.After(q.retryBackoff):
case <-q.forceQuit:
}
return
}
// If the last task was backed up successfully, we'll exit and
// continue once more tasks are added to the queue. We'll also
// clear any accumulated backoff as this batch was able to be
// sent reliably.
if stateUpdate.IsComplete == 1 {
q.resetBackoff()
return
}
// Always apply a small delay between sends, which makes the
// unit tests more reliable. If we were requested to back off,
// when we will do so.
select {
case <-time.After(time.Millisecond):
case <-q.forceQuit:
return
}
}
}
// nextStateUpdate returns the next wtwire.StateUpdate to upload to the tower.
// If any committed updates are present, this method will reconstruct the state
// update from the committed update using the current last applied value found
// in the database. Otherwise, it will select the next pending update, craft the
// payload, and commit an update before returning the state update to send. The
// boolean value in the response is true if the state update is taken from the
// pending queue, allowing the caller to remove the update from either the
// commit or pending queue if the update is successfully acked.
func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) {
var (
seqNum uint16
update *wtdb.CommittedUpdate
isLast bool
isPending bool
)
q.queueCond.L.Lock()
switch {
// If the commit queue is non-empty, parse the next committed update.
case q.commitQueue.Len() > 0:
next := q.commitQueue.Front()
updateWithSeq := next.Value.(updateWithSeqNum)
seqNum = updateWithSeq.seqNum
update = updateWithSeq.update
// If this is the last item in the commit queue and no items
// exist in the pending queue, we will use the IsComplete flag
// in the StateUpdate to signal that the tower can release the
// connection after replying to free up resources.
isLast = q.commitQueue.Len() == 1 && q.pendingQueue.Len() == 0
q.queueCond.L.Unlock()
log.Debugf("Reprocessing committed state update for "+
"session=%s seqnum=%d", q.ID(), seqNum)
// Otherwise, craft and commit the next update from the pending queue.
default:
isPending = true
// Determine the current sequence number to apply for this
// pending update.
seqNum = q.seqNum + 1
// Obtain the next task from the queue.
next := q.pendingQueue.Front()
task := next.Value.(*backupTask)
// If this is the last item in the pending queue, we will use
// the IsComplete flag in the StateUpdate to signal that the
// tower can release the connection after replying to free up
// resources.
isLast = q.pendingQueue.Len() == 1
q.queueCond.L.Unlock()
hint, encBlob, err := task.craftSessionPayload(q.cfg.Signer)
if err != nil {
// TODO(conner): mark will not send
return nil, false, fmt.Errorf("unable to craft "+
"session payload: %v", err)
}
// TODO(conner): special case other obscure errors
update = &wtdb.CommittedUpdate{
BackupID: task.id,
Hint: hint,
EncryptedBlob: encBlob,
}
log.Debugf("Committing state update for session=%s seqnum=%d",
q.ID(), seqNum)
}
// Before sending the task to the tower, commit the state update
// to disk using the assigned sequence number. If this task has already
// been committed, the call will succeed and only be used for the
// purpose of obtaining the last applied value to send to the tower.
//
// This step ensures that if we crash before receiving an ack that we
// will retransmit the same update. If the tower successfully received
// the update from before, it will reply with an ACK regardless of what
// we send the next time. This step ensures that if we reliably send the
// same update for a given sequence number, to prevent us from thinking
// we backed up a state when we instead backed up another.
lastApplied, err := q.cfg.DB.CommitUpdate(q.ID(), seqNum, update)
if err != nil {
// TODO(conner): mark failed/reschedule
return nil, false, fmt.Errorf("unable to commit state update "+
"for session=%s seqnum=%d: %v", q.ID(), seqNum, err)
}
stateUpdate := &wtwire.StateUpdate{
SeqNum: seqNum,
LastApplied: lastApplied,
Hint: update.Hint,
EncryptedBlob: update.EncryptedBlob,
}
// Set the IsComplete flag if this is the last queued item.
if isLast {
stateUpdate.IsComplete = 1
}
return stateUpdate, isPending, nil
}
// sendStateUpdate sends a wtwire.StateUpdate to the watchtower and processes
// the ACK before returning. If sendInit is true, this method will first send
// the localInit message and verify that the tower supports our required feature
// bits. And error is returned if any part of the send fails. The boolean return
// variable indicates whether or not we should back off before attempting to
// send the next state update.
func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer,
stateUpdate *wtwire.StateUpdate, localInit *wtwire.Init,
sendInit, isPending bool) error {
// If this is the first message being sent to the tower, we must send an
// Init message to establish that server supports the features we
// require.
if sendInit {
// Send Init to tower.
err := q.cfg.SendMessage(conn, q.localInit)
if err != nil {
return err
}
// Receive Init from tower.
remoteMsg, err := q.cfg.ReadMessage(conn)
if err != nil {
return err
}
remoteInit, ok := remoteMsg.(*wtwire.Init)
if !ok {
return fmt.Errorf("watchtower responded with %T to "+
"Init", remoteMsg)
}
// Validate Init.
err = q.localInit.CheckRemoteInit(
remoteInit, wtwire.FeatureNames,
)
if err != nil {
return err
}
}
// Send StateUpdate to tower.
err := q.cfg.SendMessage(conn, stateUpdate)
if err != nil {
return err
}
// Receive StateUpdate from tower.
remoteMsg, err := q.cfg.ReadMessage(conn)
if err != nil {
return err
}
stateUpdateReply, ok := remoteMsg.(*wtwire.StateUpdateReply)
if !ok {
return fmt.Errorf("watchtower responded with %T to StateUpdate",
remoteMsg)
}
// Process the reply from the tower.
switch stateUpdateReply.Code {
// The tower reported a successful update, validate the response and
// record the last applied returned.
case wtwire.CodeOK:
// TODO(conner): handle other error cases properly, ban towers, etc.
default:
err := fmt.Errorf("received error code %s in "+
"StateUpdateReply from tower=%x session=%s",
stateUpdateReply.Code,
conn.RemotePub().SerializeCompressed(), q.ID())
log.Warnf("Unable to upload state update: %v", err)
return err
}
lastApplied := stateUpdateReply.LastApplied
err = q.cfg.DB.AckUpdate(q.ID(), stateUpdate.SeqNum, lastApplied)
switch {
case err == wtdb.ErrUnallocatedLastApplied:
// TODO(conner): borked watchtower
err = fmt.Errorf("unable to ack update=%d session=%s: %v",
stateUpdate.SeqNum, q.ID(), err)
log.Errorf("Failed to ack update: %v", err)
return err
case err == wtdb.ErrLastAppliedReversion:
// TODO(conner): borked watchtower
err = fmt.Errorf("unable to ack update=%d session=%s: %v",
stateUpdate.SeqNum, q.ID(), err)
log.Errorf("Failed to ack update: %v", err)
return err
case err != nil:
err = fmt.Errorf("unable to ack update=%d session=%s: %v",
stateUpdate.SeqNum, q.ID(), err)
log.Errorf("Failed to ack update: %v", err)
return err
}
log.Infof("Removing update session=%s seqnum=%d is_pending=%v "+
"from memory", q.ID(), stateUpdate.SeqNum, isPending)
q.queueCond.L.Lock()
if isPending {
// If a pending update was successfully sent, increment the
// sequence number and remove the item from the queue. This
// ensures the total number of backups in the session remains
// unchanged, which maintains the external view of the session's
// reserve status.
q.seqNum++
q.pendingQueue.Remove(q.pendingQueue.Front())
} else {
// Otherwise, simply remove the update from the committed queue.
// This has no effect on the queues reserve status since the
// update had already been committed.
q.commitQueue.Remove(q.commitQueue.Front())
}
q.queueCond.L.Unlock()
return nil
}
// reserveStatus returns a reserveStatus indicating whether or not the
// sessionQueue can accept another task. reserveAvailable is returned when a
// task can be accepted, and reserveExhausted is returned if the all slots in
// the session have been allocated.
//
// NOTE: This method MUST be called with queueCond's exclusive lock held.
func (q *sessionQueue) reserveStatus() reserveStatus {
numPending := uint32(q.pendingQueue.Len())
maxUpdates := uint32(q.cfg.ClientSession.Policy.MaxUpdates)
log.Debugf("SessionQueue %s reserveStatus seqnum=%d pending=%d "+
"max-updates=%d", q.ID(), q.seqNum, numPending, maxUpdates)
if uint32(q.seqNum)+numPending < maxUpdates {
return reserveAvailable
}
return reserveExhausted
}
// resetBackoff returns the connection backoff the minimum configured backoff.
func (q *sessionQueue) resetBackoff() {
q.retryBackoff = q.cfg.MinBackoff
}
// increaseBackoff doubles the current connection backoff, clamping to the
// configured maximum backoff if it would exceed the limit.
func (q *sessionQueue) increaseBackoff() {
q.retryBackoff *= 2
if q.retryBackoff > q.cfg.MaxBackoff {
q.retryBackoff = q.cfg.MaxBackoff
}
}
// signalUntilShutdown strobes the sessionQueue's condition variable until the
// main event loop exits.
func (q *sessionQueue) signalUntilShutdown() {
for {
select {
case <-time.After(time.Millisecond):
q.queueCond.Signal()
case <-q.shutdown:
return
}
}
}
// sessionQueueSet maintains a mapping of SessionIDs to their corresponding
// sessionQueue.
type sessionQueueSet map[wtdb.SessionID]*sessionQueue
// Add inserts a sessionQueue into the sessionQueueSet.
func (s *sessionQueueSet) Add(sessionQueue *sessionQueue) {
(*s)[*sessionQueue.ID()] = sessionQueue
}
// ApplyAndWait executes the nil-adic function returned from getApply for each
// sessionQueue in the set in parallel, then waits for all of them to finish
// before returning to the caller.
func (s *sessionQueueSet) ApplyAndWait(getApply func(*sessionQueue) func()) {
var wg sync.WaitGroup
for _, sessionq := range *s {
wg.Add(1)
go func(sq *sessionQueue) {
defer wg.Done()
getApply(sq)()
}(sessionq)
}
wg.Wait()
}

@ -0,0 +1,51 @@
package wtclient
import "fmt"
type clientStats struct {
numTasksReceived int
numTasksAccepted int
numTasksIneligible int
numSessionsAcquired int
numSessionsExhausted int
}
// taskReceived increments the number to backup requests the client has received
// from active channels.
func (s *clientStats) taskReceived() {
s.numTasksReceived++
}
// taskAccepted increments the number of tasks that have been assigned to active
// session queues, and are awaiting upload to a tower.
func (s *clientStats) taskAccepted() {
s.numTasksAccepted++
}
// taskIneligible increments the number of tasks that were unable to satisfy the
// active session queue's policy. These can potentially be retried later, but
// typically this means that the balance created dust outputs, so it may not be
// worth backing up at all.
func (s *clientStats) taskIneligible() {
s.numTasksIneligible++
}
// sessionAcquired increments the number of sessions that have been successfully
// negotiated by the client during this execution.
func (s *clientStats) sessionAcquired() {
s.numSessionsAcquired++
}
// sessionExhausted increments the number of session that have become full as a
// result of accepting backup tasks.
func (s *clientStats) sessionExhausted() {
s.numSessionsExhausted++
}
// String returns a human readable summary of the client's metrics.
func (s clientStats) String() string {
return fmt.Sprintf("tasks(received=%d accepted=%d ineligible=%d) "+
"sessions(acquired=%d exhausted=%d)", s.numTasksReceived,
s.numTasksAccepted, s.numTasksIneligible, s.numSessionsAcquired,
s.numSessionsExhausted)
}

@ -0,0 +1,185 @@
package wtclient
import (
"container/list"
"sync"
"time"
)
// taskPipeline implements a reliable, in-order queue that ensures its queue
// fully drained before exiting. Stopping the taskPipeline prevents the pipeline
// from accepting any further tasks, and will cause the pipeline to exit after
// all updates have been delivered to the downstream receiver. If this process
// hangs and is unable to make progress, users can optionally call ForceQuit to
// abandon the reliable draining of the queue in order to permit shutdown.
type taskPipeline struct {
started sync.Once
stopped sync.Once
forced sync.Once
queueMtx sync.Mutex
queueCond *sync.Cond
queue *list.List
newBackupTasks chan *backupTask
quit chan struct{}
forceQuit chan struct{}
shutdown chan struct{}
}
// newTaskPipeline initializes a new taskPipeline.
func newTaskPipeline() *taskPipeline {
rq := &taskPipeline{
queue: list.New(),
newBackupTasks: make(chan *backupTask),
quit: make(chan struct{}),
forceQuit: make(chan struct{}),
shutdown: make(chan struct{}),
}
rq.queueCond = sync.NewCond(&rq.queueMtx)
return rq
}
// Start spins up the taskPipeline, making it eligible to begin receiving backup
// tasks and deliver them to the receiver of NewBackupTasks.
func (q *taskPipeline) Start() {
q.started.Do(func() {
go q.queueManager()
})
}
// Stop begins a graceful shutdown of the taskPipeline. This method returns once
// all backupTasks have been delivered via NewBackupTasks, or a ForceQuit causes
// the delivery of pending tasks to be interrupted.
func (q *taskPipeline) Stop() {
q.stopped.Do(func() {
log.Debugf("Stopping task pipeline")
close(q.quit)
q.signalUntilShutdown()
// Skip log if we also force quit.
select {
case <-q.forceQuit:
default:
log.Debugf("Task pipeline stopped successfully")
}
})
}
// ForceQuit signals the taskPipeline to immediately exit, dropping any
// backupTasks that have not been delivered via NewBackupTasks.
func (q *taskPipeline) ForceQuit() {
q.forced.Do(func() {
log.Infof("Force quitting task pipeline")
close(q.forceQuit)
q.signalUntilShutdown()
log.Infof("Task pipeline unclean shutdown complete")
})
}
// NewBackupTasks returns a read-only channel for enqueue backupTasks. The
// channel will be closed after a call to Stop and all pending tasks have been
// delivered, or if a call to ForceQuit is called before the pending entries
// have been drained.
func (q *taskPipeline) NewBackupTasks() <-chan *backupTask {
return q.newBackupTasks
}
// QueueBackupTask enqueues a backupTask for reliable delivery to the consumer
// of NewBackupTasks. If the taskPipeline is shutting down, ErrClientExiting is
// returned. Otherwise, if QueueBackupTask returns nil it is guaranteed to be
// delivered via NewBackupTasks unless ForceQuit is called before completion.
func (q *taskPipeline) QueueBackupTask(task *backupTask) error {
q.queueCond.L.Lock()
select {
// Reject new tasks after quit has been signaled.
case <-q.quit:
q.queueCond.L.Unlock()
return ErrClientExiting
// Reject new tasks after force quit has been signaled.
case <-q.forceQuit:
q.queueCond.L.Unlock()
return ErrClientExiting
default:
}
// Queue the new task and signal the queue's condition variable to wake up
// the queueManager for processing.
q.queue.PushBack(task)
q.queueCond.L.Unlock()
q.queueCond.Signal()
return nil
}
// queueManager processes all incoming backup requests that get added via
// QueueBackupTask. The manager will exit
//
// NOTE: This method MUST be run as a goroutine.
func (q *taskPipeline) queueManager() {
defer close(q.shutdown)
defer close(q.newBackupTasks)
for {
q.queueCond.L.Lock()
for q.queue.Front() == nil {
q.queueCond.Wait()
select {
case <-q.quit:
// Exit only after the queue has been fully drained.
if q.queue.Len() == 0 {
q.queueCond.L.Unlock()
log.Debugf("Revoked state pipeline flushed.")
return
}
case <-q.forceQuit:
q.queueCond.L.Unlock()
log.Debugf("Revoked state pipeline force quit.")
return
default:
}
}
// Pop the first element from the queue.
e := q.queue.Front()
task := q.queue.Remove(e).(*backupTask)
q.queueCond.L.Unlock()
select {
// Backup task submitted to dispatcher. We don't select on quit to
// ensure that we still drain tasks while shutting down.
case q.newBackupTasks <- task:
// Force quit, return immediately to allow the client to exit.
case <-q.forceQuit:
log.Debugf("Revoked state pipeline force quit.")
return
}
}
}
// signalUntilShutdown strobes the queue's condition variable to ensure the
// queueManager reliably unblocks to check for the exit condition.
func (q *taskPipeline) signalUntilShutdown() {
for {
select {
case <-time.After(time.Millisecond):
q.queueCond.Signal()
case <-q.shutdown:
return
}
}
}

@ -0,0 +1,110 @@
package wtdb
import (
"errors"
"github.com/btcsuite/btcd/btcec"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
)
var (
// ErrClientSessionNotFound signals that the requested client session
// was not found in the database.
ErrClientSessionNotFound = errors.New("client session not found")
// ErrUpdateAlreadyCommitted signals that the chosen sequence number has
// already been committed to an update with a different breach hint.
ErrUpdateAlreadyCommitted = errors.New("update already committed")
// ErrCommitUnorderedUpdate signals the client tried to commit a
// sequence number other than the next unallocated sequence number.
ErrCommitUnorderedUpdate = errors.New("update seqnum not monotonic")
// ErrCommittedUpdateNotFound signals that the tower tried to ACK a
// sequence number that has not yet been allocated by the client.
ErrCommittedUpdateNotFound = errors.New("committed update not found")
// ErrUnallocatedLastApplied signals that the tower tried to provide a
// LastApplied value greater than any allocated sequence number.
ErrUnallocatedLastApplied = errors.New("tower echoed last appiled " +
"greater than allocated seqnum")
)
// ClientSession encapsulates a SessionInfo returned from a successful
// session negotiation, and also records the tower and ephemeral secret used for
// communicating with the tower.
type ClientSession struct {
// ID is the client's public key used when authenticating with the
// tower.
ID SessionID
// SeqNum is the next unallocated sequence number that can be sent to
// the tower.
SeqNum uint16
// TowerLastApplied the last last-applied the tower has echoed back.
TowerLastApplied uint16
// TowerID is the unique, db-assigned identifier that references the
// Tower with which the session is negotiated.
TowerID uint64
// Tower holds the pubkey and address of the watchtower.
//
// NOTE: This value is not serialized. It is recovered by looking up the
// tower with TowerID.
Tower *Tower
// SessionKeyDesc is the key descriptor used to derive the client's
// session key so that it can authenticate with the tower to update its
// session.
SessionKeyDesc keychain.KeyLocator
// SessionPrivKey is the ephemeral secret key used to connect to the
// watchtower.
// TODO(conner): remove after HD keys
SessionPrivKey *btcec.PrivateKey
// Policy holds the negotiated session parameters.
Policy wtpolicy.Policy
// RewardPkScript is the pkscript that the tower's reward will be
// deposited to if a sweep transaction confirms and the sessions
// specifies a reward output.
RewardPkScript []byte
// CommittedUpdates is a map from allocated sequence numbers to unacked
// updates. These updates can be resent after a restart if the update
// failed to send or receive an acknowledgment.
CommittedUpdates map[uint16]*CommittedUpdate
// AckedUpdates is a map from sequence number to backup id to record
// which revoked states were uploaded via this session.
AckedUpdates map[uint16]BackupID
}
// BackupID identifies a particular revoked, remote commitment by channel id and
// commitment height.
type BackupID struct {
// ChanID is the channel id of the revoked commitment.
ChanID lnwire.ChannelID
// CommitHeight is the commitment height of the revoked commitment.
CommitHeight uint64
}
// CommittedUpdate holds a state update sent by a client along with its
// SessionID.
type CommittedUpdate struct {
BackupID BackupID
// Hint is the 16-byte prefix of the revoked commitment transaction ID.
Hint BreachHint
// EncryptedBlob is a ciphertext containing the sweep information for
// exacting justice if the commitment transaction matching the breach
// hint is braodcast.
EncryptedBlob []byte
}

@ -82,7 +82,7 @@ func (s *SessionInfo) AcceptUpdateSequence(seqNum, lastApplied uint16) error {
return ErrSessionConsumed return ErrSessionConsumed
// Client update does not match our expected next seqnum. // Client update does not match our expected next seqnum.
case seqNum != s.LastApplied+1: case seqNum != s.LastApplied && seqNum != s.LastApplied+1:
return ErrUpdateOutOfOrder return ErrUpdateOutOfOrder
} }

65
watchtower/wtdb/tower.go Normal file

@ -0,0 +1,65 @@
package wtdb
import (
"net"
"sync"
"github.com/btcsuite/btcd/btcec"
"github.com/lightningnetwork/lnd/lnwire"
)
// Tower holds the necessary components required to connect to a remote tower.
// Communication is handled by brontide, and requires both a public key and an
// address.
type Tower struct {
// ID is a unique ID for this record assigned by the database.
ID uint64
// IdentityKey is the public key of the remote node, used to
// authenticate the brontide transport.
IdentityKey *btcec.PublicKey
// Addresses is a list of possible addresses to reach the tower.
Addresses []net.Addr
mu sync.RWMutex
}
// AddAddress adds the given address to the tower's in-memory list of addresses.
// If the address's string is already present, the Tower will be left
// unmodified. Otherwise, the adddress is prepended to the beginning of the
// Tower's addresses, on the assumption that it is fresher than the others.
func (t *Tower) AddAddress(addr net.Addr) {
t.mu.Lock()
defer t.mu.Unlock()
// Ensure we don't add a duplicate address.
addrStr := addr.String()
for _, existingAddr := range t.Addresses {
if existingAddr.String() == addrStr {
return
}
}
// Add this address to the front of the list, on the assumption that it
// is a fresher address and will be tried first.
t.Addresses = append([]net.Addr{addr}, t.Addresses...)
}
// LNAddrs generates a list of lnwire.NetAddress from a Tower instance's
// addresses. This can be used to have a client try multiple addresses for the
// same Tower.
func (t *Tower) LNAddrs() []*lnwire.NetAddress {
t.mu.RLock()
defer t.mu.RUnlock()
addrs := make([]*lnwire.NetAddress, 0, len(t.Addresses))
for _, addr := range t.Addresses {
addrs = append(addrs, &lnwire.NetAddress{
IdentityKey: t.IdentityKey,
Address: addr,
})
}
return addrs
}

@ -0,0 +1,223 @@
package wtmock
import (
"fmt"
"net"
"sync"
"sync/atomic"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
)
type towerPK [33]byte
// ClientDB is a mock, in-memory database or testing the watchtower client
// behavior.
type ClientDB struct {
nextTowerID uint64 // to be used atomically
mu sync.Mutex
sweepPkScripts map[lnwire.ChannelID][]byte
activeSessions map[wtdb.SessionID]*wtdb.ClientSession
towerIndex map[towerPK]uint64
towers map[uint64]*wtdb.Tower
}
// NewClientDB initializes a new mock ClientDB.
func NewClientDB() *ClientDB {
return &ClientDB{
sweepPkScripts: make(map[lnwire.ChannelID][]byte),
activeSessions: make(map[wtdb.SessionID]*wtdb.ClientSession),
towerIndex: make(map[towerPK]uint64),
towers: make(map[uint64]*wtdb.Tower),
}
}
// CreateTower initializes a database entry with the given lightning address. If
// the tower exists, the address is append to the list of all addresses used to
// that tower previously.
func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
m.mu.Lock()
defer m.mu.Unlock()
var towerPubKey towerPK
copy(towerPubKey[:], lnAddr.IdentityKey.SerializeCompressed())
var tower *wtdb.Tower
towerID, ok := m.towerIndex[towerPubKey]
if ok {
tower = m.towers[towerID]
tower.AddAddress(lnAddr.Address)
} else {
towerID = atomic.AddUint64(&m.nextTowerID, 1)
tower = &wtdb.Tower{
ID: towerID,
IdentityKey: lnAddr.IdentityKey,
Addresses: []net.Addr{lnAddr.Address},
}
}
m.towerIndex[towerPubKey] = towerID
m.towers[towerID] = tower
return tower, nil
}
// MarkBackupIneligible records that particular commit height is ineligible for
// backup. This allows the client to track which updates it should not attempt
// to retry after startup.
func (m *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, commitHeight uint64) error {
return nil
}
// ListClientSessions returns the set of all client sessions known to the db.
func (m *ClientDB) ListClientSessions() (map[wtdb.SessionID]*wtdb.ClientSession, error) {
m.mu.Lock()
defer m.mu.Unlock()
sessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
for _, session := range m.activeSessions {
sessions[session.ID] = session
}
return sessions, nil
}
// CreateClientSession records a newly negotiated client session in the set of
// active sessions. The session can be identified by its SessionID.
func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
m.mu.Lock()
defer m.mu.Unlock()
m.activeSessions[session.ID] = &wtdb.ClientSession{
TowerID: session.TowerID,
Tower: session.Tower,
SessionKeyDesc: session.SessionKeyDesc,
SessionPrivKey: session.SessionPrivKey,
ID: session.ID,
Policy: session.Policy,
SeqNum: session.SeqNum,
TowerLastApplied: session.TowerLastApplied,
RewardPkScript: session.RewardPkScript,
CommittedUpdates: make(map[uint16]*wtdb.CommittedUpdate),
AckedUpdates: make(map[uint16]wtdb.BackupID),
}
return nil
}
// CommitUpdate persists the CommittedUpdate provided in the slot for (session,
// seqNum). This allows the client to retransmit this update on startup.
func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, seqNum uint16,
update *wtdb.CommittedUpdate) (uint16, error) {
m.mu.Lock()
defer m.mu.Unlock()
// Fail if session doesn't exist.
session, ok := m.activeSessions[*id]
if !ok {
return 0, wtdb.ErrClientSessionNotFound
}
// Check if an update has already been committed for this state.
dbUpdate, ok := session.CommittedUpdates[seqNum]
if ok {
// If the breach hint matches, we'll just return the last
// applied value so the client can retransmit.
if dbUpdate.Hint == update.Hint {
return session.TowerLastApplied, nil
}
// Otherwise, fail since the breach hint doesn't match.
return 0, wtdb.ErrUpdateAlreadyCommitted
}
// Sequence number must increment.
if seqNum != session.SeqNum+1 {
return 0, wtdb.ErrCommitUnorderedUpdate
}
// Save the update and increment the sequence number.
session.CommittedUpdates[seqNum] = update
session.SeqNum++
return session.TowerLastApplied, nil
}
// AckUpdate persists an acknowledgment for a given (session, seqnum) pair. This
// removes the update from the set of committed updates, and validates the
// lastApplied value returned from the tower.
func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) error {
m.mu.Lock()
defer m.mu.Unlock()
// Fail if session doesn't exist.
session, ok := m.activeSessions[*id]
if !ok {
return wtdb.ErrClientSessionNotFound
}
// Retrieve the committed update, failing if none is found. We should
// only receive acks for state updates that we send.
update, ok := session.CommittedUpdates[seqNum]
if !ok {
return wtdb.ErrCommittedUpdateNotFound
}
// Ensure the returned last applied value does not exceed the highest
// allocated sequence number.
if lastApplied > session.SeqNum {
return wtdb.ErrUnallocatedLastApplied
}
// Ensure the last applied value isn't lower than a previous one sent by
// the tower.
if lastApplied < session.TowerLastApplied {
return wtdb.ErrLastAppliedReversion
}
// Finally, remove the committed update from disk and mark the update as
// acked. The tower last applied value is also recorded to send along
// with the next update.
delete(session.CommittedUpdates, seqNum)
session.AckedUpdates[seqNum] = update.BackupID
session.TowerLastApplied = lastApplied
return nil
}
// FetchChanPkScripts returns the set of sweep pkscripts known for all channels.
// This allows the client to cache them in memory on startup.
func (m *ClientDB) FetchChanPkScripts() (map[lnwire.ChannelID][]byte, error) {
m.mu.Lock()
defer m.mu.Unlock()
sweepPkScripts := make(map[lnwire.ChannelID][]byte)
for chanID, pkScript := range m.sweepPkScripts {
sweepPkScripts[chanID] = cloneBytes(pkScript)
}
return sweepPkScripts, nil
}
// AddChanPkScript sets a pkscript or sweeping funds from the channel or chanID.
func (m *ClientDB) AddChanPkScript(chanID lnwire.ChannelID, pkScript []byte) error {
m.mu.Lock()
defer m.mu.Unlock()
if _, ok := m.sweepPkScripts[chanID]; ok {
return fmt.Errorf("pkscript for %x already exists", pkScript)
}
m.sweepPkScripts[chanID] = cloneBytes(pkScript)
return nil
}
func cloneBytes(b []byte) []byte {
bb := make([]byte, len(b))
copy(bb, b)
return bb
}

@ -13,6 +13,8 @@ import (
type MockPeer struct { type MockPeer struct {
remotePub *btcec.PublicKey remotePub *btcec.PublicKey
remoteAddr net.Addr remoteAddr net.Addr
localPub *btcec.PublicKey
localAddr net.Addr
IncomingMsgs chan []byte IncomingMsgs chan []byte
OutgoingMsgs chan []byte OutgoingMsgs chan []byte
@ -20,30 +22,74 @@ type MockPeer struct {
writeDeadline <-chan time.Time writeDeadline <-chan time.Time
readDeadline <-chan time.Time readDeadline <-chan time.Time
Quit chan struct{} RemoteQuit chan struct{}
Quit chan struct{}
} }
// NewMockPeer returns a fresh MockPeer. // NewMockPeer returns a fresh MockPeer.
func NewMockPeer(pk *btcec.PublicKey, addr net.Addr, bufferSize int) *MockPeer { func NewMockPeer(lpk, rpk *btcec.PublicKey, addr net.Addr,
bufferSize int) *MockPeer {
return &MockPeer{ return &MockPeer{
remotePub: pk, remotePub: rpk,
remoteAddr: addr, remoteAddr: addr,
localAddr: &net.TCPAddr{
IP: net.IP{0x32, 0x31, 0x30, 0x29},
Port: 36723,
},
localPub: lpk,
IncomingMsgs: make(chan []byte, bufferSize), IncomingMsgs: make(chan []byte, bufferSize),
OutgoingMsgs: make(chan []byte, bufferSize), OutgoingMsgs: make(chan []byte, bufferSize),
Quit: make(chan struct{}), Quit: make(chan struct{}),
} }
} }
// NewMockConn establishes a bidirectional connection between two MockPeers.
func NewMockConn(localPk, remotePk *btcec.PublicKey,
localAddr, remoteAddr net.Addr,
bufferSize int) (*MockPeer, *MockPeer) {
localPeer := &MockPeer{
remotePub: remotePk,
remoteAddr: remoteAddr,
localPub: localPk,
localAddr: localAddr,
IncomingMsgs: make(chan []byte, bufferSize),
OutgoingMsgs: make(chan []byte, bufferSize),
Quit: make(chan struct{}),
}
remotePeer := &MockPeer{
remotePub: localPk,
remoteAddr: localAddr,
localPub: remotePk,
localAddr: remoteAddr,
IncomingMsgs: localPeer.OutgoingMsgs,
OutgoingMsgs: localPeer.IncomingMsgs,
Quit: make(chan struct{}),
}
localPeer.RemoteQuit = remotePeer.Quit
remotePeer.RemoteQuit = localPeer.Quit
return localPeer, remotePeer
}
// Write sends the raw bytes as the next full message read to the remote peer. // Write sends the raw bytes as the next full message read to the remote peer.
// The write will fail if either party closes the connection or the write // The write will fail if either party closes the connection or the write
// deadline expires. The passed bytes slice is copied before sending, thus the // deadline expires. The passed bytes slice is copied before sending, thus the
// bytes may be reused once the method returns. // bytes may be reused once the method returns.
func (p *MockPeer) Write(b []byte) (n int, err error) { func (p *MockPeer) Write(b []byte) (n int, err error) {
bb := make([]byte, len(b))
copy(bb, b)
select { select {
case p.OutgoingMsgs <- b: case p.OutgoingMsgs <- bb:
return len(b), nil return len(b), nil
case <-p.writeDeadline: case <-p.writeDeadline:
return 0, fmt.Errorf("write timeout expired") return 0, fmt.Errorf("write timeout expired")
case <-p.RemoteQuit:
return 0, fmt.Errorf("remote closed connected")
case <-p.Quit: case <-p.Quit:
return 0, fmt.Errorf("connection closed") return 0, fmt.Errorf("connection closed")
} }
@ -69,6 +115,8 @@ func (p *MockPeer) ReadNextMessage() ([]byte, error) {
return b, nil return b, nil
case <-p.readDeadline: case <-p.readDeadline:
return nil, fmt.Errorf("read timeout expired") return nil, fmt.Errorf("read timeout expired")
case <-p.RemoteQuit:
return nil, fmt.Errorf("remote closed connected")
case <-p.Quit: case <-p.Quit:
return nil, fmt.Errorf("connection closed") return nil, fmt.Errorf("connection closed")
} }
@ -112,6 +160,25 @@ func (p *MockPeer) RemoteAddr() net.Addr {
return p.remoteAddr return p.remoteAddr
} }
// LocalAddr returns the local net address of the peer.
func (p *MockPeer) LocalAddr() net.Addr {
return p.localAddr
}
// Read is not implemented.
func (p *MockPeer) Read(dst []byte) (int, error) {
panic("not implemented")
}
// SetDeadline is not implemented.
func (p *MockPeer) SetDeadline(t time.Time) error {
panic("not implemented")
}
// Compile-time constraint ensuring the MockPeer implements the wserver.Peer // Compile-time constraint ensuring the MockPeer implements the wserver.Peer
// interface. // interface.
var _ wtserver.Peer = (*MockPeer)(nil) var _ wtserver.Peer = (*MockPeer)(nil)
// Compile-time constraint ensuring the MockPeer implements the net.Conn
// interface.
var _ net.Conn = (*MockPeer)(nil)

@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"net" "net"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/btcec"
@ -55,14 +54,18 @@ type Config struct {
// ChainHash identifies the network that the server is watching. // ChainHash identifies the network that the server is watching.
ChainHash chainhash.Hash ChainHash chainhash.Hash
// NoAckUpdates causes the server to not acknowledge state updates, this
// should only be used for testing.
NoAckUpdates bool
} }
// Server houses the state required to handle watchtower peers. It's primary job // Server houses the state required to handle watchtower peers. It's primary job
// is to accept incoming connections, and dispatch processing of the client // is to accept incoming connections, and dispatch processing of the client
// message streams. // message streams.
type Server struct { type Server struct {
started int32 // atomic started sync.Once
shutdown int32 // atomic stopped sync.Once
cfg *Config cfg *Config
@ -71,6 +74,8 @@ type Server struct {
clientMtx sync.RWMutex clientMtx sync.RWMutex
clients map[wtdb.SessionID]Peer clients map[wtdb.SessionID]Peer
newPeers chan Peer
localInit *wtwire.Init localInit *wtwire.Init
wg sync.WaitGroup wg sync.WaitGroup
@ -89,6 +94,7 @@ func New(cfg *Config) (*Server, error) {
s := &Server{ s := &Server{
cfg: cfg, cfg: cfg,
clients: make(map[wtdb.SessionID]Peer), clients: make(map[wtdb.SessionID]Peer),
newPeers: make(chan Peer),
localInit: localInit, localInit: localInit,
quit: make(chan struct{}), quit: make(chan struct{}),
} }
@ -109,36 +115,31 @@ func New(cfg *Config) (*Server, error) {
// Start begins listening on the server's listeners. // Start begins listening on the server's listeners.
func (s *Server) Start() error { func (s *Server) Start() error {
// Already running? s.started.Do(func() {
if !atomic.CompareAndSwapInt32(&s.started, 0, 1) { log.Infof("Starting watchtower server")
return nil
}
log.Infof("Starting watchtower server") s.wg.Add(1)
go s.peerHandler()
s.connMgr.Start() s.connMgr.Start()
log.Infof("Watchtower server started successfully")
log.Infof("Watchtower server started successfully")
})
return nil return nil
} }
// Stop shutdowns down the server's listeners and any active requests. // Stop shutdowns down the server's listeners and any active requests.
func (s *Server) Stop() error { func (s *Server) Stop() error {
// Bail if we're already shutting down. s.stopped.Do(func() {
if !atomic.CompareAndSwapInt32(&s.shutdown, 0, 1) { log.Infof("Stopping watchtower server")
return nil
}
log.Infof("Stopping watchtower server") s.connMgr.Stop()
s.connMgr.Stop() close(s.quit)
s.wg.Wait()
close(s.quit)
s.wg.Wait()
log.Infof("Watchtower server stopped successfully")
log.Infof("Watchtower server stopped successfully")
})
return nil return nil
} }
@ -163,8 +164,29 @@ func (s *Server) inboundPeerConnected(c net.Conn) {
// by the client. This method serves also as a public endpoint for locally // by the client. This method serves also as a public endpoint for locally
// registering new clients with the server. // registering new clients with the server.
func (s *Server) InboundPeerConnected(peer Peer) { func (s *Server) InboundPeerConnected(peer Peer) {
s.wg.Add(1) select {
go s.handleClient(peer) case s.newPeers <- peer:
case <-s.quit:
}
}
// peerHandler processes newly accepted peers and spawns a client handler for
// each. The peerHandler is used to ensure that waitgrouped client handlers are
// spawned from a waitgrouped goroutine.
func (s *Server) peerHandler() {
defer s.wg.Done()
defer s.removeAllPeers()
for {
select {
case peer := <-s.newPeers:
s.wg.Add(1)
go s.handleClient(peer)
case <-s.quit:
return
}
}
} }
// handleClient processes a series watchtower messages sent by a client. The // handleClient processes a series watchtower messages sent by a client. The
@ -445,6 +467,13 @@ func (s *Server) handleStateUpdate(peer Peer, id *wtdb.SessionID,
failCode = wtwire.CodeTemporaryFailure failCode = wtwire.CodeTemporaryFailure
} }
if s.cfg.NoAckUpdates {
return &connFailure{
ID: *id,
Code: uint16(failCode),
}
}
return s.replyStateUpdate( return s.replyStateUpdate(
peer, id, failCode, lastApplied, peer, id, failCode, lastApplied,
) )
@ -614,6 +643,21 @@ func (s *Server) removePeer(id *wtdb.SessionID, addr net.Addr) {
} }
} }
// removeAllPeers iterates through the server's current set of peers and closes
// all open connections.
func (s *Server) removeAllPeers() {
s.clientMtx.Lock()
defer s.clientMtx.Unlock()
for id, peer := range s.clients {
log.Infof("Releasing incoming peer %s@%s", id,
peer.RemoteAddr())
delete(s.clients, id)
peer.Close()
}
}
// logMessage writes information about a message exchanged with a remote peer, // logMessage writes information about a message exchanged with a remote peer,
// using directional prepositions to signal whether the message was sent or // using directional prepositions to signal whether the message was sent or
// received. // received.

@ -87,10 +87,12 @@ func TestServerOnlyAcceptOnePeer(t *testing.T) {
s := initServer(t, nil, timeoutDuration) s := initServer(t, nil, timeoutDuration)
defer s.Stop() defer s.Stop()
localPub := randPubKey(t)
// Create two peers using the same session id. // Create two peers using the same session id.
peerPub := randPubKey(t) peerPub := randPubKey(t)
peer1 := wtmock.NewMockPeer(peerPub, nil, 0) peer1 := wtmock.NewMockPeer(localPub, peerPub, nil, 0)
peer2 := wtmock.NewMockPeer(peerPub, nil, 0) peer2 := wtmock.NewMockPeer(localPub, peerPub, nil, 0)
// Serialize a Init message to be sent by both peers. // Serialize a Init message to be sent by both peers.
init := wtwire.NewInitMessage( init := wtwire.NewInitMessage(
@ -219,9 +221,11 @@ func testServerCreateSession(t *testing.T, i int, test createSessionTestCase) {
s := initServer(t, nil, timeoutDuration) s := initServer(t, nil, timeoutDuration)
defer s.Stop() defer s.Stop()
localPub := randPubKey(t)
// Create a new client and connect to server. // Create a new client and connect to server.
peerPub := randPubKey(t) peerPub := randPubKey(t)
peer := wtmock.NewMockPeer(peerPub, nil, 0) peer := wtmock.NewMockPeer(localPub, peerPub, nil, 0)
connect(t, i, s, peer, test.initMsg, timeoutDuration) connect(t, i, s, peer, test.initMsg, timeoutDuration)
// Send the CreateSession message, and wait for a reply. // Send the CreateSession message, and wait for a reply.
@ -249,7 +253,7 @@ func testServerCreateSession(t *testing.T, i int, test createSessionTestCase) {
// Simulate a peer with the same session id connection to the server // Simulate a peer with the same session id connection to the server
// again. // again.
peer = wtmock.NewMockPeer(peerPub, nil, 0) peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
connect(t, i, s, peer, test.initMsg, timeoutDuration) connect(t, i, s, peer, test.initMsg, timeoutDuration)
// Send the _same_ CreateSession message as the first attempt. // Send the _same_ CreateSession message as the first attempt.
@ -418,8 +422,8 @@ var stateUpdateTests = []stateUpdateTestCase{
{Code: wtwire.CodeOK, LastApplied: 4}, {Code: wtwire.CodeOK, LastApplied: 4},
}, },
}, },
// Valid update sequence with disconnection, ensure resumes resume. // Valid update sequence with disconnection, resume next update. Client
// Client doesn't echo last applied until last message. // doesn't echo last applied until last message.
{ {
name: "resume after disconnect lagging lastapplied", name: "resume after disconnect lagging lastapplied",
initMsg: wtwire.NewInitMessage( initMsg: wtwire.NewInitMessage(
@ -448,6 +452,38 @@ var stateUpdateTests = []stateUpdateTestCase{
{Code: wtwire.CodeOK, LastApplied: 4}, {Code: wtwire.CodeOK, LastApplied: 4},
}, },
}, },
// Valid update sequence with disconnection, resume last update. Client
// doesn't echo last applied until last message.
{
name: "resume after disconnect lagging lastapplied",
initMsg: wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(),
testnetChainHash,
),
createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault,
MaxUpdates: 4,
RewardBase: 0,
RewardRate: 0,
SweepFeeRate: 1,
},
updates: []*wtwire.StateUpdate{
{SeqNum: 1, LastApplied: 0},
{SeqNum: 2, LastApplied: 0},
nil, // Wait for read timeout to drop conn, then reconnect.
{SeqNum: 2, LastApplied: 0},
{SeqNum: 3, LastApplied: 0},
{SeqNum: 4, LastApplied: 3},
},
replies: []*wtwire.StateUpdateReply{
{Code: wtwire.CodeOK, LastApplied: 1},
{Code: wtwire.CodeOK, LastApplied: 2},
nil,
{Code: wtwire.CodeOK, LastApplied: 2},
{Code: wtwire.CodeOK, LastApplied: 3},
{Code: wtwire.CodeOK, LastApplied: 4},
},
},
// Send update with sequence number that exceeds MaxUpdates. // Send update with sequence number that exceeds MaxUpdates.
{ {
name: "seqnum exceed maxupdates", name: "seqnum exceed maxupdates",
@ -527,9 +563,11 @@ func testServerStateUpdates(t *testing.T, i int, test stateUpdateTestCase) {
s := initServer(t, nil, timeoutDuration) s := initServer(t, nil, timeoutDuration)
defer s.Stop() defer s.Stop()
localPub := randPubKey(t)
// Create a new client and connect to the server. // Create a new client and connect to the server.
peerPub := randPubKey(t) peerPub := randPubKey(t)
peer := wtmock.NewMockPeer(peerPub, nil, 0) peer := wtmock.NewMockPeer(localPub, peerPub, nil, 0)
connect(t, i, s, peer, test.initMsg, timeoutDuration) connect(t, i, s, peer, test.initMsg, timeoutDuration)
// Register a session for this client to use in the subsequent tests. // Register a session for this client to use in the subsequent tests.
@ -549,7 +587,7 @@ func testServerStateUpdates(t *testing.T, i int, test stateUpdateTestCase) {
// Now that the original connection has been closed, connect a new // Now that the original connection has been closed, connect a new
// client with the same session id. // client with the same session id.
peer = wtmock.NewMockPeer(peerPub, nil, 0) peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
connect(t, i, s, peer, test.initMsg, timeoutDuration) connect(t, i, s, peer, test.initMsg, timeoutDuration)
// Send the intended StateUpdate messages in series. // Send the intended StateUpdate messages in series.
@ -560,7 +598,7 @@ func testServerStateUpdates(t *testing.T, i int, test stateUpdateTestCase) {
if update == nil { if update == nil {
assertConnClosed(t, peer, 2*timeoutDuration) assertConnClosed(t, peer, 2*timeoutDuration)
peer = wtmock.NewMockPeer(peerPub, nil, 0) peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
connect(t, i, s, peer, test.initMsg, timeoutDuration) connect(t, i, s, peer, test.initMsg, timeoutDuration)
continue continue

@ -14,9 +14,9 @@ const (
// reply was never received and/or processed by the client. // reply was never received and/or processed by the client.
CreateSessionCodeAlreadyExists CreateSessionCode = 60 CreateSessionCodeAlreadyExists CreateSessionCode = 60
// CreateSessionCodeRejectRejectMaxUpdates the tower rejected the maximum // CreateSessionCodeRejectMaxUpdates the tower rejected the maximum
// number of state updates proposed by the client. // number of state updates proposed by the client.
CreateSessionCodeRejectRejectMaxUpdates CreateSessionCode = 61 CreateSessionCodeRejectMaxUpdates CreateSessionCode = 61
// CreateSessionCodeRejectRewardRate the tower rejected the reward rate // CreateSessionCodeRejectRewardRate the tower rejected the reward rate
// proposed by the client. // proposed by the client.

@ -1,5 +1,7 @@
package wtwire package wtwire
import "fmt"
// ErrorCode represents a generic error code used when replying to watchtower // ErrorCode represents a generic error code used when replying to watchtower
// clients. Specific reply messages may extend the ErrorCode primitive and add // clients. Specific reply messages may extend the ErrorCode primitive and add
// custom codes, so long as they don't collide with the generic error codes.. // custom codes, so long as they don't collide with the generic error codes..
@ -18,3 +20,33 @@ const (
// permanently failed, and further communication should be avoided. // permanently failed, and further communication should be avoided.
CodePermanentFailure ErrorCode = 50 CodePermanentFailure ErrorCode = 50
) )
// String returns a human-readable description of an ErrorCode.
func (c ErrorCode) String() string {
switch c {
case CodeOK:
return "CodeOK"
case CodeTemporaryFailure:
return "CodeTemporaryFailure"
case CodePermanentFailure:
return "CodePermanentFailure"
case CreateSessionCodeAlreadyExists:
return "CreateSessionCodeAlreadyExists"
case CreateSessionCodeRejectMaxUpdates:
return "CreateSessionCodeRejectMaxUpdates"
case CreateSessionCodeRejectRewardRate:
return "CreateSessionCodeRejectRewardRate"
case CreateSessionCodeRejectSweepFeeRate:
return "CreateSessionCodeRejectSweepFeeRate"
case CreateSessionCodeRejectBlobType:
return "CreateSessionCodeRejectBlobType"
case StateUpdateCodeClientBehind:
return "StateUpdateCodeClientBehind"
case StateUpdateCodeMaxUpdatesExceeded:
return "StateUpdateCodeMaxUpdatesExceeded"
case StateUpdateCodeSeqNumOutOfOrder:
return "StateUpdateCodeSeqNumOutOfOrder"
default:
return fmt.Sprintf("UnknownErrorCode: %d", c)
}
}