Merge pull request #2618 from cfromknecht/wtclient
watchtower/wtclient: reliable, asynchronous pipeline for revoked state backups
Commit ec62104acc
@@ -34,8 +34,7 @@ import (
|
||||
// necessary components are stripped out and encrypted before being sent to
|
||||
// the tower in a StateUpdate.
|
||||
type backupTask struct {
|
||||
chanID lnwire.ChannelID
|
||||
commitHeight uint64
|
||||
id wtdb.BackupID
|
||||
breachInfo *lnwallet.BreachRetribution
|
||||
|
||||
// state-dependent variables
|
||||
@@ -96,8 +95,10 @@ func newBackupTask(chanID *lnwire.ChannelID,
|
||||
}
|
||||
|
||||
return &backupTask{
|
||||
chanID: *chanID,
|
||||
commitHeight: breachInfo.RevokedStateNum,
|
||||
id: wtdb.BackupID{
|
||||
ChanID: *chanID,
|
||||
CommitHeight: breachInfo.RevokedStateNum,
|
||||
},
|
||||
breachInfo: breachInfo,
|
||||
toLocalInput: toLocalInput,
|
||||
toRemoteInput: toRemoteInput,
|
||||
@@ -125,7 +126,7 @@ func (t *backupTask) inputs() map[wire.OutPoint]input.Input {
|
||||
// SessionInfo's policy. If no error is returned, the task has been bound to the
|
||||
// session and can be queued to upload to the tower. Otherwise, the bind failed
|
||||
// and should be rescheduled with a different session.
|
||||
func (t *backupTask) bindSession(session *wtdb.SessionInfo) error {
|
||||
func (t *backupTask) bindSession(session *wtdb.ClientSession) error {
|
||||
|
||||
// First we'll begin by deriving a weight estimate for the justice
|
||||
// transaction. The final weight can be different depending on whether
|
||||
@@ -154,7 +155,7 @@ func (t *backupTask) bindSession(session *wtdb.SessionInfo) error {
|
||||
// in the current session's policy.
|
||||
outputs, err := session.Policy.ComputeJusticeTxOuts(
|
||||
t.totalAmt, int64(weightEstimate.Weight()),
|
||||
t.sweepPkScript, session.RewardAddress,
|
||||
t.sweepPkScript, session.RewardPkScript,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@@ -69,7 +69,7 @@ type backupTaskTest struct {
|
||||
expSweepAmt int64
|
||||
expRewardAmt int64
|
||||
expRewardScript []byte
|
||||
session *wtdb.SessionInfo
|
||||
session *wtdb.ClientSession
|
||||
bindErr error
|
||||
expSweepScript []byte
|
||||
signer input.Signer
|
||||
@@ -205,13 +205,13 @@ func genTaskTest(
|
||||
expSweepAmt: expSweepAmt,
|
||||
expRewardAmt: expRewardAmt,
|
||||
expRewardScript: rewardScript,
|
||||
session: &wtdb.SessionInfo{
|
||||
session: &wtdb.ClientSession{
|
||||
Policy: wtpolicy.Policy{
|
||||
BlobType: blobType,
|
||||
SweepFeeRate: sweepFeeRate,
|
||||
RewardRate: 10000,
|
||||
},
|
||||
RewardAddress: rewardScript,
|
||||
RewardPkScript: rewardScript,
|
||||
},
|
||||
bindErr: bindErr,
|
||||
expSweepScript: makeAddrSlice(22),
|
||||
@@ -379,7 +379,7 @@ var backupTaskTests = []backupTaskTest{
|
||||
}
|
||||
|
||||
// TestBackupTaskBind tests the initialization and binding of a backupTask to a
|
||||
// SessionInfo. After a succesfful bind, all parameters of the justice
|
||||
// ClientSession. After a successful bind, all parameters of the justice
|
||||
// transaction should be solidified, so we assert their correctness. In an
|
||||
// unsuccessful bind, the session-dependent parameters should be unmodified so
|
||||
// that the backup task can be rescheduled if necessary. Finally, we assert that
|
||||
@@ -401,14 +401,14 @@ func testBackupTask(t *testing.T, test backupTaskTest) {
|
||||
|
||||
// Assert that all parameters set during initialization are properly
|
||||
// populated.
|
||||
if task.chanID != test.chanID {
|
||||
if task.id.ChanID != test.chanID {
|
||||
t.Fatalf("channel id mismatch, want: %s, got: %s",
|
||||
test.chanID, task.chanID)
|
||||
test.chanID, task.id.ChanID)
|
||||
}
|
||||
|
||||
if task.commitHeight != test.breachInfo.RevokedStateNum {
|
||||
if task.id.CommitHeight != test.breachInfo.RevokedStateNum {
|
||||
t.Fatalf("commit height mismatch, want: %d, got: %d",
|
||||
test.breachInfo.RevokedStateNum, task.commitHeight)
|
||||
test.breachInfo.RevokedStateNum, task.id.CommitHeight)
|
||||
}
|
||||
|
||||
if task.totalAmt != test.expTotalAmt {
|
||||
|
82
watchtower/wtclient/candidate_iterator.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package wtclient
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sync"
|
||||
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
)
|
||||
|
||||
// TowerCandidateIterator provides an abstraction for iterating through possible
|
||||
// watchtower addresses when attempting to create a new session.
|
||||
type TowerCandidateIterator interface {
|
||||
// Reset clears any internal iterator state, making previously taken
|
||||
// candidates available as long as they remain in the set.
|
||||
Reset() error
|
||||
|
||||
// Next returns the next candidate tower. The iterator is not required
|
||||
// to return results in any particular order. If no more candidates are
|
||||
// available, ErrTowerCandidatesExhausted is returned.
|
||||
Next() (*wtdb.Tower, error)
|
||||
}
|
||||
|
||||
// towerListIterator is a linked-list backed TowerCandidateIterator.
|
||||
type towerListIterator struct {
|
||||
mu sync.Mutex
|
||||
candidates *list.List
|
||||
nextCandidate *list.Element
|
||||
}
|
||||
|
||||
// Compile-time constraint to ensure *towerListIterator implements the
|
||||
// TowerCandidateIterator interface.
|
||||
var _ TowerCandidateIterator = (*towerListIterator)(nil)
|
||||
|
||||
// newTowerListIterator initializes a new towerListIterator from a variadic
// list of wtdb.Towers.
|
||||
func newTowerListIterator(candidates ...*wtdb.Tower) *towerListIterator {
|
||||
iter := &towerListIterator{
|
||||
candidates: list.New(),
|
||||
}
|
||||
|
||||
for _, candidate := range candidates {
|
||||
iter.candidates.PushBack(candidate)
|
||||
}
|
||||
iter.Reset()
|
||||
|
||||
return iter
|
||||
}
|
||||
|
||||
// Reset clears the iterator's state, and makes the tower at the front of the
// list the next item to be returned.
|
||||
func (t *towerListIterator) Reset() error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
// Reset the next candidate to the front of the linked-list.
|
||||
t.nextCandidate = t.candidates.Front()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Next returns the next candidate tower. This iterator will always return
|
||||
// candidates in the order given when the iterator was instantiated. If no more
|
||||
// candidates are available, ErrTowerCandidatesExhausted is returned.
|
||||
func (t *towerListIterator) Next() (*wtdb.Tower, error) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
// If the next candidate is nil, we've exhausted the list.
|
||||
if t.nextCandidate == nil {
|
||||
return nil, ErrTowerCandidatesExhausted
|
||||
}
|
||||
|
||||
// Propose the tower at the front of the list.
|
||||
tower := t.nextCandidate.Value.(*wtdb.Tower)
|
||||
|
||||
// Set the next candidate to the subsequent element.
|
||||
t.nextCandidate = t.nextCandidate.Next()
|
||||
|
||||
return tower, nil
|
||||
}
|
||||
|
||||
// TODO(conner): implement graph-backed candidate iterator for public towers.
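The iterator contract above is small enough to illustrate with a short sketch. The block below is not part of the diff; the helper name drain is hypothetical, but it only uses the Next/Reset behavior and the ErrTowerCandidatesExhausted sentinel described above.

```go
package wtclient

import "github.com/lightningnetwork/lnd/watchtower/wtdb"

// drain is a hypothetical helper that consumes every candidate from the
// iterator, then resets it so the same towers can be proposed again.
func drain(iter TowerCandidateIterator) ([]*wtdb.Tower, error) {
	var towers []*wtdb.Tower
	for {
		tower, err := iter.Next()
		if err == ErrTowerCandidatesExhausted {
			// All candidates consumed; make them available for the
			// next pass and return what we gathered.
			return towers, iter.Reset()
		}
		if err != nil {
			return nil, err
		}

		towers = append(towers, tower)
	}
}
```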
|
804
watchtower/wtclient/client.go
Normal file
@@ -0,0 +1,804 @@
|
||||
package wtclient
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/lightningnetwork/lnd/input"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
"github.com/lightningnetwork/lnd/lnwallet"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtserver"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtwire"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultReadTimeout specifies the default duration we will wait during
|
||||
// a read before breaking out of a blocking read.
|
||||
DefaultReadTimeout = 15 * time.Second
|
||||
|
||||
// DefaultWriteTimeout specifies the default duration we will wait during
|
||||
// a write before breaking out of a blocking write.
|
||||
DefaultWriteTimeout = 15 * time.Second
|
||||
|
||||
// DefaultStatInterval specifies the default interval between logging
|
||||
// metrics about the client's operation.
|
||||
DefaultStatInterval = 30 * time.Second
|
||||
)
|
||||
|
||||
// Client is the primary interface used by the daemon to control a client's
|
||||
// lifecycle and backup revoked states.
|
||||
type Client interface {
|
||||
// RegisterChannel persistently initializes any channel-dependent
|
||||
// parameters within the client. This should be called during link
|
||||
// startup to ensure that the client is able to support the link during
|
||||
// operation.
|
||||
RegisterChannel(lnwire.ChannelID) error
|
||||
|
||||
// BackupState initiates a request to back up a particular revoked
|
||||
// state. If the method returns nil, the backup is guaranteed to be
|
||||
// successful unless the client is force quit, or the justice
|
||||
// transaction would create dust outputs when trying to abide by the
|
||||
// negotiated policy.
|
||||
BackupState(*lnwire.ChannelID, *lnwallet.BreachRetribution) error
|
||||
|
||||
// Start initializes the watchtower client, allowing it to process requests
// to back up revoked channel states.
|
||||
Start() error
|
||||
|
||||
// Stop attempts a graceful shutdown of the watchtower client. In doing
|
||||
// so, it will attempt to flush the pipeline and deliver any queued
|
||||
// states to the tower before exiting.
|
||||
Stop() error
|
||||
|
||||
// ForceQuit will forcibly shutdown the watchtower client. Calling this
|
||||
// may lead to queued states being dropped.
|
||||
ForceQuit()
|
||||
}
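To make the interface concrete, here is a hypothetical call site (not part of this change), e.g. from a channel link: the channel is registered once at startup, and each revoked state is then handed to the client. Only the method signatures shown above are assumed.

```go
package wtclient

import (
	"github.com/lightningnetwork/lnd/lnwallet"
	"github.com/lightningnetwork/lnd/lnwire"
)

// backupRevokedState is a hypothetical caller of the Client interface. A nil
// return from BackupState means the update will be delivered to the tower
// unless the client is force quit or the state cannot satisfy the negotiated
// policy.
func backupRevokedState(c Client, chanID lnwire.ChannelID,
	breach *lnwallet.BreachRetribution) error {

	// Ensure the channel's sweep pkscript has been generated and persisted
	// before handing off any states.
	if err := c.RegisterChannel(chanID); err != nil {
		return err
	}

	return c.BackupState(&chanID, breach)
}
```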
|
||||
|
||||
// Config provides the TowerClient with access to the resources it requires to
|
||||
// perform its duty. All nillable fields must be non-nil for the tower to be
|
||||
// initialized properly.
|
||||
type Config struct {
|
||||
// Signer provides access to the wallet so that the client can sign
|
||||
// justice transactions that spend from a remote party's commitment
|
||||
// transaction.
|
||||
Signer input.Signer
|
||||
|
||||
// NewAddress generates a new on-chain sweep pkscript.
|
||||
NewAddress func() ([]byte, error)
|
||||
|
||||
// SecretKeyRing is used to derive the session keys used to communicate
|
||||
// with the tower. The client only stores the KeyLocators internally so
|
||||
// that we never store private keys on disk.
|
||||
SecretKeyRing keychain.SecretKeyRing
|
||||
|
||||
// Dial connects to an addr using the specified net and returns the
|
||||
// connection object.
|
||||
Dial Dial
|
||||
|
||||
// AuthDialer establishes a brontide connection over an onion or clear
|
||||
// network.
|
||||
AuthDial AuthDialer
|
||||
|
||||
// DB provides access to the client's stable storage medium.
|
||||
DB DB
|
||||
|
||||
// Policy is the session policy the client will propose when creating
|
||||
// new sessions with the tower. If the policy differs from any active
|
||||
// sessions recorded in the database, those sessions will be ignored and
|
||||
// new sessions will be requested immediately.
|
||||
Policy wtpolicy.Policy
|
||||
|
||||
// PrivateTower is the net address of a private tower. The client will
|
||||
// try to create all sessions with this tower.
|
||||
PrivateTower *lnwire.NetAddress
|
||||
|
||||
// ChainHash identifies the chain that the client is on and for which
|
||||
// the tower must be watching to monitor for breaches.
|
||||
ChainHash chainhash.Hash
|
||||
|
||||
// ForceQuitDelay is the duration after attempting to shutdown that the
|
||||
// client will automatically abort any pending backups if an unclean
|
||||
// shutdown is detected. If the value is less than or equal to zero, a
|
||||
// call to Stop may block indefinitely. The client can always be
|
||||
// ForceQuit externally irrespective of the chosen parameter.
|
||||
ForceQuitDelay time.Duration
|
||||
|
||||
// ReadTimeout is the duration we will wait during a read before
|
||||
// breaking out of a blocking read. If the value is less than or equal
|
||||
// to zero, the default will be used instead.
|
||||
ReadTimeout time.Duration
|
||||
|
||||
// WriteTimeout is the duration we will wait during a write before
|
||||
// breaking out of a blocking write. If the value is less than or equal
|
||||
// to zero, the default will be used instead.
|
||||
WriteTimeout time.Duration
|
||||
|
||||
// MinBackoff defines the initial backoff applied to connections with
|
||||
// watchtowers. Subsequent backoff durations will grow exponentially up
|
||||
// until MaxBackoff.
|
||||
MinBackoff time.Duration
|
||||
|
||||
// MaxBackoff defines the maximum backoff applied to connections with
|
||||
// watchtowers. If the exponential backoff produces a timeout greater
|
||||
// than this value, the backoff will be clamped to MaxBackoff.
|
||||
MaxBackoff time.Duration
|
||||
}
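MinBackoff and MaxBackoff imply a clamped exponential retry schedule for tower connections. The sketch below shows that behavior under the assumption of a simple doubling policy; the actual retry logic lives in the negotiator and session queue and may differ.

```go
package wtclient

import "time"

// nextBackoff is a sketch of the clamped exponential backoff implied by
// MinBackoff and MaxBackoff: each failed attempt doubles the previous delay
// until it is capped at the maximum.
func nextBackoff(cur, min, max time.Duration) time.Duration {
	if cur < min {
		return min
	}

	cur *= 2
	if cur > max {
		return max
	}

	return cur
}
```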
|
||||
|
||||
// TowerClient is a concrete implementation of the Client interface, offering a
|
||||
// non-blocking, reliable subsystem for backing up revoked states to a specified
|
||||
// private tower.
|
||||
type TowerClient struct {
|
||||
started sync.Once
|
||||
stopped sync.Once
|
||||
forced sync.Once
|
||||
|
||||
cfg *Config
|
||||
|
||||
pipeline *taskPipeline
|
||||
|
||||
negotiator SessionNegotiator
|
||||
candidateSessions map[wtdb.SessionID]*wtdb.ClientSession
|
||||
activeSessions sessionQueueSet
|
||||
|
||||
sessionQueue *sessionQueue
|
||||
prevTask *backupTask
|
||||
|
||||
sweepPkScriptMu sync.RWMutex
|
||||
sweepPkScripts map[lnwire.ChannelID][]byte
|
||||
|
||||
statTicker *time.Ticker
|
||||
stats clientStats
|
||||
|
||||
wg sync.WaitGroup
|
||||
forceQuit chan struct{}
|
||||
}
|
||||
|
||||
// Compile-time constraint to ensure *TowerClient implements the Client
|
||||
// interface.
|
||||
var _ Client = (*TowerClient)(nil)
|
||||
|
||||
// New initializes a new TowerClient from the provided Config. An error is
// returned if the client could not be initialized.
|
||||
func New(config *Config) (*TowerClient, error) {
|
||||
// Copy the config to prevent side-effects from modifying both the
|
||||
// internal and external version of the Config.
|
||||
cfg := new(Config)
|
||||
*cfg = *config
|
||||
|
||||
// Set the read timeout to the default if none was provided.
|
||||
if cfg.ReadTimeout <= 0 {
|
||||
cfg.ReadTimeout = DefaultReadTimeout
|
||||
}
|
||||
|
||||
// Set the write timeout to the default if none was provided.
|
||||
if cfg.WriteTimeout <= 0 {
|
||||
cfg.WriteTimeout = DefaultWriteTimeout
|
||||
}
|
||||
|
||||
// Record the tower in our database, also loading any addresses
|
||||
// previously associated with its public key.
|
||||
tower, err := cfg.DB.CreateTower(cfg.PrivateTower)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("Using private watchtower %s, offering policy %s",
|
||||
cfg.PrivateTower, cfg.Policy)
|
||||
|
||||
c := &TowerClient{
|
||||
cfg: cfg,
|
||||
pipeline: newTaskPipeline(),
|
||||
activeSessions: make(sessionQueueSet),
|
||||
statTicker: time.NewTicker(DefaultStatInterval),
|
||||
forceQuit: make(chan struct{}),
|
||||
}
|
||||
c.negotiator = newSessionNegotiator(&NegotiatorConfig{
|
||||
DB: cfg.DB,
|
||||
Policy: cfg.Policy,
|
||||
ChainHash: cfg.ChainHash,
|
||||
SendMessage: c.sendMessage,
|
||||
ReadMessage: c.readMessage,
|
||||
Dial: c.dial,
|
||||
Candidates: newTowerListIterator(tower),
|
||||
MinBackoff: cfg.MinBackoff,
|
||||
MaxBackoff: cfg.MaxBackoff,
|
||||
})
|
||||
|
||||
// Next, load all active sessions from the db into the client. We will
|
||||
// use any of these sessions if their policies match the current policy
|
||||
// of the client, otherwise they will be ignored and new sessions will
|
||||
// be requested.
|
||||
c.candidateSessions, err = c.cfg.DB.ListClientSessions()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Finally, load the sweep pkscripts that have been generated for all
|
||||
// previously registered channels.
|
||||
c.sweepPkScripts, err = c.cfg.DB.FetchChanPkScripts()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Start initializes the watchtower client by loading or negotiating an active
|
||||
// session and then begins processing backup tasks from the request pipeline.
|
||||
func (c *TowerClient) Start() error {
|
||||
var err error
|
||||
c.started.Do(func() {
|
||||
log.Infof("Starting watchtower client")
|
||||
|
||||
// First, restart a session queue for any sessions that have
|
||||
// committed but unacked state updates. This ensures that these
|
||||
// sessions will be able to flush the committed updates after a
|
||||
// restart.
|
||||
for _, session := range c.candidateSessions {
|
||||
if len(session.CommittedUpdates) > 0 {
|
||||
log.Infof("Starting session=%s to process "+
|
||||
"%d committed backups", session.ID,
|
||||
len(session.CommittedUpdates))
|
||||
c.initActiveQueue(session)
|
||||
}
|
||||
}
|
||||
|
||||
// Now start the session negotiator, which will allow us to
|
||||
// request new sessions as soon as the backupDispatcher starts
|
||||
// up.
|
||||
err = c.negotiator.Start()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Start the task pipeline to which new backup tasks will be
|
||||
// submitted from active links.
|
||||
c.pipeline.Start()
|
||||
|
||||
c.wg.Add(1)
|
||||
go c.backupDispatcher()
|
||||
|
||||
log.Infof("Watchtower client started successfully")
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// Stop idempotently initiates a graceful shutdown of the watchtower client.
|
||||
func (c *TowerClient) Stop() error {
|
||||
c.stopped.Do(func() {
|
||||
log.Debugf("Stopping watchtower client")
|
||||
|
||||
// 1. Shutdown the backup queue, which will prevent any further
|
||||
// updates from being accepted. In practice, the links should be
|
||||
// shutdown before the client has been stopped, so all updates
|
||||
// would have been added prior.
|
||||
c.pipeline.Stop()
|
||||
|
||||
// 2. To ensure we don't hang forever on shutdown due to
|
||||
// unintended failures, we'll delay a call to force quit the
|
||||
// pipeline if a ForceQuitDelay is specified. This will have no
|
||||
// effect if the pipeline shuts down cleanly before the delay
|
||||
// fires.
|
||||
//
|
||||
// For full safety, this can be set to 0 and wait out
|
||||
// indefinitely. However for mobile clients which may have a
|
||||
// limited amount of time to exit before the background process
|
||||
// is killed, this offers a way to ensure the process
|
||||
// terminates.
|
||||
if c.cfg.ForceQuitDelay > 0 {
|
||||
time.AfterFunc(c.cfg.ForceQuitDelay, c.ForceQuit)
|
||||
}
|
||||
|
||||
// 3. Once the backup queue has shutdown, wait for the main
|
||||
// dispatcher to exit. The backup queue will signal its
|
||||
// completion to the dispatcher, which releases the wait group
|
||||
// after all tasks have been assigned to session queues.
|
||||
c.wg.Wait()
|
||||
|
||||
// 4. Since all valid tasks have been assigned to session
|
||||
// queues, we no longer need to negotiate sessions.
|
||||
c.negotiator.Stop()
|
||||
|
||||
log.Debugf("Waiting for active session queues to finish "+
|
||||
"draining, stats: %s", c.stats)
|
||||
|
||||
// 5. Shutdown all active session queues in parallel. These will
|
||||
// exit once all updates have been acked by the watchtower.
|
||||
c.activeSessions.ApplyAndWait(func(s *sessionQueue) func() {
|
||||
return s.Stop
|
||||
})
|
||||
|
||||
// Skip log if force quitting.
|
||||
select {
|
||||
case <-c.forceQuit:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
log.Debugf("Client successfully stopped, stats: %s", c.stats)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// ForceQuit idempotently initiates an unclean shutdown of the watchtower
|
||||
// client. This should only be executed if Stop is unable to exit cleanly.
|
||||
func (c *TowerClient) ForceQuit() {
|
||||
c.forced.Do(func() {
|
||||
log.Infof("Force quitting watchtower client")
|
||||
|
||||
// Cancel log message from stop.
|
||||
close(c.forceQuit)
|
||||
|
||||
// 1. Shutdown the backup queue, which will prevent any further
|
||||
// updates from being accepted. In practice, the links should be
|
||||
// shutdown before the client has been stopped, so all updates
|
||||
// would have been added prior.
|
||||
c.pipeline.ForceQuit()
|
||||
|
||||
// 2. Once the backup queue has shutdown, wait for the main
|
||||
// dispatcher to exit. The backup queue will signal its
|
||||
// completion to the dispatcher, which releases the wait group
|
||||
// after all tasks have been assigned to session queues.
|
||||
c.wg.Wait()
|
||||
|
||||
// 3. Since all valid tasks have been assigned to session
|
||||
// queues, we no longer need to negotiate sessions.
|
||||
c.negotiator.Stop()
|
||||
|
||||
// 4. Force quit all active session queues in parallel. These
|
||||
// will exit once all updates have been acked by the watchtower.
|
||||
c.activeSessions.ApplyAndWait(func(s *sessionQueue) func() {
|
||||
return s.ForceQuit
|
||||
})
|
||||
|
||||
log.Infof("Watchtower client unclean shutdown complete, "+
|
||||
"stats: %s", c.stats)
|
||||
})
|
||||
}
|
||||
|
||||
// RegisterChannel persistently initializes any channel-dependent parameters
|
||||
// within the client. This should be called during link startup to ensure that
|
||||
// the client is able to support the link during operation.
|
||||
func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error {
|
||||
c.sweepPkScriptMu.Lock()
|
||||
defer c.sweepPkScriptMu.Unlock()
|
||||
|
||||
// If a pkscript for this channel already exists, the channel has been
|
||||
// previously registered.
|
||||
if _, ok := c.sweepPkScripts[chanID]; ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Otherwise, generate a new sweep pkscript used to sweep funds for this
|
||||
// channel.
|
||||
pkScript, err := c.cfg.NewAddress()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Persist the sweep pkscript so that restarts will not introduce
|
||||
// address inflation when the channel is reregistered after a restart.
|
||||
err = c.cfg.DB.AddChanPkScript(chanID, pkScript)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Finally, cache the pkscript in our in-memory cache to avoid db
|
||||
// lookups for the remainder of the daemon's execution.
|
||||
c.sweepPkScripts[chanID] = pkScript
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BackupState initiates a request to back up a particular revoked state. If the
|
||||
// method returns nil, the backup is guaranteed to be successful unless the:
|
||||
// - client is force quit,
|
||||
// - justice transaction would create dust outputs when trying to abide by the
|
||||
// negotiated policy, or
|
||||
// - breached outputs contain too little value to sweep at the target sweep fee
|
||||
// rate.
|
||||
func (c *TowerClient) BackupState(chanID *lnwire.ChannelID,
|
||||
breachInfo *lnwallet.BreachRetribution) error {
|
||||
|
||||
// Retrieve the cached sweep pkscript used for this channel.
|
||||
c.sweepPkScriptMu.RLock()
|
||||
sweepPkScript, ok := c.sweepPkScripts[*chanID]
|
||||
c.sweepPkScriptMu.RUnlock()
|
||||
if !ok {
|
||||
return ErrUnregisteredChannel
|
||||
}
|
||||
|
||||
task := newBackupTask(chanID, breachInfo, sweepPkScript)
|
||||
|
||||
return c.pipeline.QueueBackupTask(task)
|
||||
}
|
||||
|
||||
// nextSessionQueue attempts to fetch an active session from our set of
|
||||
// candidate sessions. Candidate sessions with a differing policy from the
|
||||
// active client's advertised policy will be ignored, but may be resumed if the
|
||||
// client is restarted with a matching policy. If no candidates were found, nil
|
||||
// is returned to signal that we need to request a new policy.
|
||||
func (c *TowerClient) nextSessionQueue() *sessionQueue {
|
||||
// Select any candidate session at random, and remove it from the set of
|
||||
// candidate sessions.
|
||||
var candidateSession *wtdb.ClientSession
|
||||
for id, sessionInfo := range c.candidateSessions {
|
||||
delete(c.candidateSessions, id)
|
||||
|
||||
// Skip any sessions with policies that don't match the current
|
||||
// configuration. These can be used again if the client changes
|
||||
// their configuration back.
|
||||
if sessionInfo.Policy != c.cfg.Policy {
|
||||
continue
|
||||
}
|
||||
|
||||
candidateSession = sessionInfo
|
||||
break
|
||||
}
|
||||
|
||||
// If none of the sessions could be used or none were found, we'll
|
||||
// return nil to signal that we need another session to be negotiated.
|
||||
if candidateSession == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initialize the session queue and spin it up so it can begin handling
|
||||
// updates. If the queue was already made active on startup, this will
|
||||
// simply return the existing session queue from the set.
|
||||
return c.getOrInitActiveQueue(candidateSession)
|
||||
}
|
||||
|
||||
// backupDispatcher processes events coming from the taskPipeline and is
|
||||
// responsible for detecting when the client needs to renegotiate a session to
|
||||
// fulfill continuing demand. The event loop exits after all tasks have been
|
||||
// received from the upstream taskPipeline, or the taskPipeline is force quit.
|
||||
//
|
||||
// NOTE: This method MUST be run as a goroutine.
|
||||
func (c *TowerClient) backupDispatcher() {
|
||||
defer c.wg.Done()
|
||||
|
||||
log.Tracef("Starting backup dispatcher")
|
||||
defer log.Tracef("Stopping backup dispatcher")
|
||||
|
||||
for {
|
||||
switch {
|
||||
|
||||
// No active session queue and no additional sessions.
|
||||
case c.sessionQueue == nil && len(c.candidateSessions) == 0:
|
||||
log.Infof("Requesting new session.")
|
||||
|
||||
// Immediately request a new session.
|
||||
c.negotiator.RequestSession()
|
||||
|
||||
// Wait until we receive the newly negotiated session.
|
||||
// All backups sent in the meantime are queued in the
|
||||
// revoke queue, as we cannot process them.
|
||||
select {
|
||||
case session := <-c.negotiator.NewSessions():
|
||||
log.Infof("Acquired new session with id=%s",
|
||||
session.ID)
|
||||
c.candidateSessions[session.ID] = session
|
||||
c.stats.sessionAcquired()
|
||||
|
||||
case <-c.statTicker.C:
|
||||
log.Infof("Client stats: %s", c.stats)
|
||||
}
|
||||
|
||||
// No active session queue but have additional sessions.
|
||||
case c.sessionQueue == nil && len(c.candidateSessions) > 0:
|
||||
// We've exhausted the prior session, we'll pop another
|
||||
// from the remaining sessions and continue processing
|
||||
// backup tasks.
|
||||
c.sessionQueue = c.nextSessionQueue()
|
||||
if c.sessionQueue != nil {
|
||||
log.Debugf("Loaded next candidate session "+
|
||||
"queue id=%s", c.sessionQueue.ID())
|
||||
}
|
||||
|
||||
// Have active session queue, process backups.
|
||||
case c.sessionQueue != nil:
|
||||
if c.prevTask != nil {
|
||||
c.processTask(c.prevTask)
|
||||
|
||||
// Continue to ensure the sessionQueue is
|
||||
// properly initialized before attempting to
|
||||
// process more tasks from the pipeline.
|
||||
continue
|
||||
}
|
||||
|
||||
// Normal operation where new tasks are read from the
|
||||
// pipeline.
|
||||
select {
|
||||
|
||||
// If any sessions are negotiated while we have an
|
||||
// active session queue, queue them for future use.
|
||||
// This shouldn't happen with the current design, so
|
||||
// it doesn't hurt to select here just in case. In the
|
||||
// future, we will likely allow more asynchrony so that
|
||||
// we can request new sessions before the session is
|
||||
// fully empty, which this case would handle.
|
||||
case session := <-c.negotiator.NewSessions():
|
||||
log.Warnf("Acquired new session with id=%s",
|
||||
"while processing tasks", session.ID)
|
||||
c.candidateSessions[session.ID] = session
|
||||
c.stats.sessionAcquired()
|
||||
|
||||
case <-c.statTicker.C:
|
||||
log.Infof("Client stats: %s", c.stats)
|
||||
|
||||
// Process each backup task serially from the queue of
|
||||
// revoked states.
|
||||
case task, ok := <-c.pipeline.NewBackupTasks():
|
||||
// All backups in the pipeline have been
|
||||
// processed, it is now safe to exit.
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("Processing backup task chanid=%s "+
|
||||
"commit-height=%d", task.id.ChanID,
|
||||
task.id.CommitHeight)
|
||||
|
||||
c.stats.taskReceived()
|
||||
c.processTask(task)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processTask attempts to schedule the given backupTask on the active
|
||||
// sessionQueue. The task will either be accepted or rejected, after which the
|
||||
// appropriate modifications to the client's state machine will be made. After
|
||||
// every invocation of processTask, the caller should ensure that the
|
||||
// sessionQueue hasn't been exhausted before proceeding to the next task. Tasks
|
||||
// that are rejected because the active sessionQueue is full will be cached as
|
||||
// the prevTask, and should be reprocessed after obtaining a new sessionQueue.
|
||||
func (c *TowerClient) processTask(task *backupTask) {
|
||||
status, accepted := c.sessionQueue.AcceptTask(task)
|
||||
if accepted {
|
||||
c.taskAccepted(task, status)
|
||||
} else {
|
||||
c.taskRejected(task, status)
|
||||
}
|
||||
}
|
||||
|
||||
// taskAccepted processes the acceptance of a task by a sessionQueue depending
|
||||
// on the state the sessionQueue is in *after* the task is added. The client's
|
||||
// prevTask is always removed as a result of this call. The client's
|
||||
// sessionQueue will be removed if accepting the task left the sessionQueue in
|
||||
// an exhausted state.
|
||||
func (c *TowerClient) taskAccepted(task *backupTask, newStatus reserveStatus) {
|
||||
log.Infof("Backup chanid=%s commit-height=%d accepted successfully",
|
||||
task.id.ChanID, task.id.CommitHeight)
|
||||
|
||||
c.stats.taskAccepted()
|
||||
|
||||
// If this task was accepted, we discard anything held in the prevTask.
|
||||
// Either it was nil before, or is the task which was just accepted.
|
||||
c.prevTask = nil
|
||||
|
||||
switch newStatus {
|
||||
|
||||
// The sessionQueue still has capacity after accepting this task.
|
||||
case reserveAvailable:
|
||||
|
||||
// The sessionQueue is full after accepting this task, so we will need
|
||||
// to request a new one before proceeding.
|
||||
case reserveExhausted:
|
||||
c.stats.sessionExhausted()
|
||||
|
||||
log.Debugf("Session %s exhausted", c.sessionQueue.ID())
|
||||
|
||||
// This task left the session exhausted, set it to nil and
|
||||
// proceed to the next loop so we can consume another
|
||||
// pre-negotiated session or request another.
|
||||
c.sessionQueue = nil
|
||||
}
|
||||
}
|
||||
|
||||
// taskRejected processes the rejection of a task by a sessionQueue depending on
// the state it was in *before* the task was rejected. The client's prevTask
// will cache the task if the sessionQueue was exhausted beforehand, and nil
|
||||
// the sessionQueue to find a new session. If the sessionQueue was not
|
||||
// exhausted, the client marks the task as ineligible, as this implies we
|
||||
// couldn't construct a valid justice transaction given the session's policy.
|
||||
func (c *TowerClient) taskRejected(task *backupTask, curStatus reserveStatus) {
|
||||
switch curStatus {
|
||||
|
||||
// The sessionQueue has available capacity but the task was rejected,
|
||||
// this indicates that the task was ineligible for backup.
|
||||
case reserveAvailable:
|
||||
c.stats.taskIneligible()
|
||||
|
||||
log.Infof("Backup chanid=%s commit-height=%d is ineligible",
|
||||
task.id.ChanID, task.id.CommitHeight)
|
||||
|
||||
err := c.cfg.DB.MarkBackupIneligible(
|
||||
task.id.ChanID, task.id.CommitHeight,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("Unable to mark task chanid=%s "+
|
||||
"commit-height=%d ineligible: %v",
|
||||
task.id.ChanID, task.id.CommitHeight, err)
|
||||
|
||||
// It is safe to not handle this error, even if we could
|
||||
// not persist the result. At worst, this task may be
|
||||
// reprocessed on a subsequent start up, and will either
|
||||
// succeed due to a change in session parameters or fail in
|
||||
// the same manner.
|
||||
}
|
||||
|
||||
// If this task was rejected *and* the session had available
|
||||
// capacity, we discard anything held in the prevTask. Either it
|
||||
// was nil before, or is the task which was just rejected.
|
||||
c.prevTask = nil
|
||||
|
||||
// The sessionQueue rejected the task because it is full, we will stash
|
||||
// this task and try to add it to the next available sessionQueue.
|
||||
case reserveExhausted:
|
||||
c.stats.sessionExhausted()
|
||||
|
||||
log.Debugf("Session %s exhausted, backup chanid=%s "+
|
||||
"commit-height=%d queued for next session",
|
||||
c.sessionQueue.ID(), task.id.ChanID,
|
||||
task.id.CommitHeight)
|
||||
|
||||
// Cache the task that we pulled off, so that we can process it
|
||||
// once a new session queue is available.
|
||||
c.sessionQueue = nil
|
||||
c.prevTask = task
|
||||
}
|
||||
}
|
||||
|
||||
// dial connects the peer at addr using privKey as our secret key for the
|
||||
// connection. The connection will use the configured Net's resolver to resolve
|
||||
// the address for either Tor or clear net connections.
|
||||
func (c *TowerClient) dial(privKey *btcec.PrivateKey,
|
||||
addr *lnwire.NetAddress) (wtserver.Peer, error) {
|
||||
|
||||
return c.cfg.AuthDial(privKey, addr, c.cfg.Dial)
|
||||
}
|
||||
|
||||
// readMessage receives and parses the next message from the given Peer. An
|
||||
// error is returned if a message is not received before the client's read
|
||||
// timeout, the read off the wire failed, or the message could not be
|
||||
// deserialized.
|
||||
func (c *TowerClient) readMessage(peer wtserver.Peer) (wtwire.Message, error) {
|
||||
// Set a read timeout to ensure we drop the connection if nothing is
|
||||
// received in a timely manner.
|
||||
err := peer.SetReadDeadline(time.Now().Add(c.cfg.ReadTimeout))
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to set read deadline: %v", err)
|
||||
log.Errorf("Unable to read msg: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Pull the next message off the wire.
|
||||
rawMsg, err := peer.ReadNextMessage()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to read message: %v", err)
|
||||
log.Errorf("Unable to read msg: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse the received message according to the watchtower wire
|
||||
// specification.
|
||||
msgReader := bytes.NewReader(rawMsg)
|
||||
msg, err := wtwire.ReadMessage(msgReader, 0)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to parse message: %v", err)
|
||||
log.Errorf("Unable to read msg: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logMessage(peer, msg, true)
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// sendMessage sends a watchtower wire message to the target peer.
|
||||
func (c *TowerClient) sendMessage(peer wtserver.Peer, msg wtwire.Message) error {
|
||||
// Encode the next wire message into the buffer.
|
||||
// TODO(conner): use buffer pool
|
||||
var b bytes.Buffer
|
||||
_, err := wtwire.WriteMessage(&b, msg, 0)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Unable to encode msg: %v", err)
|
||||
log.Errorf("Unable to send msg: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Set the write deadline for the connection, ensuring we drop the
|
||||
// connection if nothing is sent in a timely manner.
|
||||
err = peer.SetWriteDeadline(time.Now().Add(c.cfg.WriteTimeout))
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to set write deadline: %v", err)
|
||||
log.Errorf("Unable to send msg: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
logMessage(peer, msg, false)
|
||||
|
||||
// Write out the full message to the remote peer.
|
||||
_, err = peer.Write(b.Bytes())
|
||||
if err != nil {
|
||||
log.Errorf("Unable to send msg: %v", err)
|
||||
}
|
||||
return err
|
||||
}
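sendMessage and readMessage share the wtwire framing shown above. The round-trip sketch below is not part of the diff; roundTrip is a hypothetical helper that encodes a message the way sendMessage does and parses it back the way readMessage does, with timeouts and logging omitted.

```go
package wtclient

import (
	"bytes"

	"github.com/lightningnetwork/lnd/watchtower/wtwire"
)

// roundTrip serializes msg with wtwire.WriteMessage and parses it back with
// wtwire.ReadMessage, mirroring the framing used by sendMessage and
// readMessage above.
func roundTrip(msg wtwire.Message) (wtwire.Message, error) {
	var b bytes.Buffer
	if _, err := wtwire.WriteMessage(&b, msg, 0); err != nil {
		return nil, err
	}

	return wtwire.ReadMessage(bytes.NewReader(b.Bytes()), 0)
}
```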
|
||||
|
||||
// newSessionQueue creates a sessionQueue from a ClientSession loaded from the
|
||||
// database, supplying it with the resources needed by the client.
|
||||
func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession) *sessionQueue {
|
||||
return newSessionQueue(&sessionQueueConfig{
|
||||
ClientSession: s,
|
||||
ChainHash: c.cfg.ChainHash,
|
||||
Dial: c.dial,
|
||||
ReadMessage: c.readMessage,
|
||||
SendMessage: c.sendMessage,
|
||||
Signer: c.cfg.Signer,
|
||||
DB: c.cfg.DB,
|
||||
MinBackoff: c.cfg.MinBackoff,
|
||||
MaxBackoff: c.cfg.MaxBackoff,
|
||||
})
|
||||
}
|
||||
|
||||
// getOrInitActiveQueue checks the activeSessions set for a sessionQueue for the
|
||||
// passed ClientSession. If it exists, the active sessionQueue is returned.
|
||||
// Otherwise a new sessionQueue is initialized and added to the set.
|
||||
func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession) *sessionQueue {
|
||||
if sq, ok := c.activeSessions[s.ID]; ok {
|
||||
return sq
|
||||
}
|
||||
|
||||
return c.initActiveQueue(s)
|
||||
}
|
||||
|
||||
// initActiveQueue creates a new sessionQueue from the passed ClientSession,
|
||||
// adds the sessionQueue to the activeSessions set, and starts the sessionQueue
|
||||
// so that it can deliver any committed updates or begin accepting newly
|
||||
// assigned tasks.
|
||||
func (c *TowerClient) initActiveQueue(s *wtdb.ClientSession) *sessionQueue {
|
||||
// Initialize the session queue, providing it with all of the resources
|
||||
// it requires from the client instance.
|
||||
sq := c.newSessionQueue(s)
|
||||
|
||||
// Add the session queue as an active session so that we remember to
|
||||
// stop it on shutdown.
|
||||
c.activeSessions.Add(sq)
|
||||
|
||||
// Start the queue so that it can be active in processing newly assigned
|
||||
// tasks or to upload previously committed updates.
|
||||
sq.Start()
|
||||
|
||||
return sq
|
||||
}
|
||||
|
||||
// logMessage writes information about a message exchanged with a remote peer,
|
||||
// using directional prepositions to signal whether the message was sent or
|
||||
// received.
|
||||
func logMessage(peer wtserver.Peer, msg wtwire.Message, read bool) {
|
||||
var action = "Received"
|
||||
var preposition = "from"
|
||||
if !read {
|
||||
action = "Sending"
|
||||
preposition = "to"
|
||||
}
|
||||
|
||||
summary := wtwire.MessageSummary(msg)
|
||||
if len(summary) > 0 {
|
||||
summary = "(" + summary + ")"
|
||||
}
|
||||
|
||||
log.Debugf("%s %s%v %s %x@%s", action, msg.MsgType(), summary,
|
||||
preposition, peer.RemotePub().SerializeCompressed(),
|
||||
peer.RemoteAddr())
|
||||
}
|
1118
watchtower/wtclient/client_test.go
Normal file
@@ -0,0 +1,1118 @@
|
||||
// +build dev
|
||||
|
||||
package wtclient_test
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/btcsuite/btcd/chaincfg"
|
||||
"github.com/btcsuite/btcd/txscript"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/btcsuite/btcutil"
|
||||
"github.com/lightningnetwork/lnd/input"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
"github.com/lightningnetwork/lnd/lnwallet"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/watchtower/blob"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtclient"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtmock"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtserver"
|
||||
)
|
||||
|
||||
const csvDelay uint32 = 144
|
||||
|
||||
var (
|
||||
revPrivBytes = []byte{
|
||||
0x8f, 0x4b, 0x51, 0x83, 0xa9, 0x34, 0xbd, 0x5f,
|
||||
0x74, 0x6c, 0x9d, 0x5c, 0xae, 0x88, 0x2d, 0x31,
|
||||
0x06, 0x90, 0xdd, 0x8c, 0x9b, 0x31, 0xbc, 0xd1,
|
||||
0x78, 0x91, 0x88, 0x2a, 0xf9, 0x74, 0xa0, 0xef,
|
||||
}
|
||||
|
||||
toLocalPrivBytes = []byte{
|
||||
0xde, 0x17, 0xc1, 0x2f, 0xdc, 0x1b, 0xc0, 0xc6,
|
||||
0x59, 0x5d, 0xf9, 0xc1, 0x3e, 0x89, 0xbc, 0x6f,
|
||||
0x01, 0x85, 0x45, 0x76, 0x26, 0xce, 0x9c, 0x55,
|
||||
0x3b, 0xc9, 0xec, 0x3d, 0xd8, 0x8b, 0xac, 0xa8,
|
||||
}
|
||||
|
||||
toRemotePrivBytes = []byte{
|
||||
0x28, 0x59, 0x6f, 0x36, 0xb8, 0x9f, 0x19, 0x5d,
|
||||
0xcb, 0x07, 0x48, 0x8a, 0xe5, 0x89, 0x71, 0x74,
|
||||
0x70, 0x4c, 0xff, 0x1e, 0x9c, 0x00, 0x93, 0xbe,
|
||||
0xe2, 0x2e, 0x68, 0x08, 0x4c, 0xb4, 0x0f, 0x4f,
|
||||
}
|
||||
|
||||
// addr is the server's reward address given to watchtower clients.
|
||||
addr, _ = btcutil.DecodeAddress(
|
||||
"mrX9vMRYLfVy1BnZbc5gZjuyaqH3ZW2ZHz", &chaincfg.TestNet3Params,
|
||||
)
|
||||
|
||||
addrScript, _ = txscript.PayToAddrScript(addr)
|
||||
)
|
||||
|
||||
// randPrivKey generates a new secp keypair, and returns the private key.
|
||||
func randPrivKey(t *testing.T) *btcec.PrivateKey {
|
||||
t.Helper()
|
||||
|
||||
sk, err := btcec.NewPrivateKey(btcec.S256())
|
||||
if err != nil {
|
||||
t.Fatalf("unable to generate pubkey: %v", err)
|
||||
}
|
||||
|
||||
return sk
|
||||
}
|
||||
|
||||
type mockNet struct {
|
||||
mu sync.RWMutex
|
||||
connCallback func(wtserver.Peer)
|
||||
}
|
||||
|
||||
func newMockNet(cb func(wtserver.Peer)) *mockNet {
|
||||
return &mockNet{
|
||||
connCallback: cb,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockNet) Dial(network string, address string) (net.Conn, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockNet) LookupHost(host string) ([]string, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (m *mockNet) LookupSRV(service string, proto string, name string) (string, []*net.SRV, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (m *mockNet) ResolveTCPAddr(network string, address string) (*net.TCPAddr, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (m *mockNet) AuthDial(localPriv *btcec.PrivateKey, netAddr *lnwire.NetAddress,
|
||||
dialer func(string, string) (net.Conn, error)) (wtserver.Peer, error) {
|
||||
|
||||
localPk := localPriv.PubKey()
|
||||
localAddr := &net.TCPAddr{
|
||||
IP: net.IP{0x32, 0x31, 0x30, 0x29},
|
||||
Port: 36723,
|
||||
}
|
||||
|
||||
localPeer, remotePeer := wtmock.NewMockConn(
|
||||
localPk, netAddr.IdentityKey, localAddr, netAddr.Address, 0,
|
||||
)
|
||||
|
||||
m.mu.RLock()
|
||||
m.connCallback(remotePeer)
|
||||
m.mu.RUnlock()
|
||||
|
||||
return localPeer, nil
|
||||
}
|
||||
|
||||
func (m *mockNet) setConnCallback(cb func(wtserver.Peer)) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.connCallback = cb
|
||||
}
|
||||
|
||||
type mockChannel struct {
|
||||
mu sync.Mutex
|
||||
commitHeight uint64
|
||||
retributions map[uint64]*lnwallet.BreachRetribution
|
||||
localBalance lnwire.MilliSatoshi
|
||||
remoteBalance lnwire.MilliSatoshi
|
||||
|
||||
revSK *btcec.PrivateKey
|
||||
revPK *btcec.PublicKey
|
||||
revKeyLoc keychain.KeyLocator
|
||||
|
||||
toRemoteSK *btcec.PrivateKey
|
||||
toRemotePK *btcec.PublicKey
|
||||
toRemoteKeyLoc keychain.KeyLocator
|
||||
|
||||
toLocalPK *btcec.PublicKey // only need to generate to-local script
|
||||
|
||||
dustLimit lnwire.MilliSatoshi
|
||||
csvDelay uint32
|
||||
}
|
||||
|
||||
func newMockChannel(t *testing.T, signer *wtmock.MockSigner,
|
||||
localAmt, remoteAmt lnwire.MilliSatoshi) *mockChannel {
|
||||
|
||||
// Generate the revocation, to-local, and to-remote keypairs.
|
||||
revSK := randPrivKey(t)
|
||||
revPK := revSK.PubKey()
|
||||
|
||||
toLocalSK := randPrivKey(t)
|
||||
toLocalPK := toLocalSK.PubKey()
|
||||
|
||||
toRemoteSK := randPrivKey(t)
|
||||
toRemotePK := toRemoteSK.PubKey()
|
||||
|
||||
// Register the revocation secret key and the to-remote secret key with
|
||||
// the signer. We will not need to sign with the to-local key, as this
|
||||
// is to be known only by the counterparty.
|
||||
revKeyLoc := signer.AddPrivKey(revSK)
|
||||
toRemoteKeyLoc := signer.AddPrivKey(toRemoteSK)
|
||||
|
||||
c := &mockChannel{
|
||||
retributions: make(map[uint64]*lnwallet.BreachRetribution),
|
||||
localBalance: localAmt,
|
||||
remoteBalance: remoteAmt,
|
||||
revSK: revSK,
|
||||
revPK: revPK,
|
||||
revKeyLoc: revKeyLoc,
|
||||
toLocalPK: toLocalPK,
|
||||
toRemoteSK: toRemoteSK,
|
||||
toRemotePK: toRemotePK,
|
||||
toRemoteKeyLoc: toRemoteKeyLoc,
|
||||
dustLimit: 546000,
|
||||
csvDelay: 144,
|
||||
}
|
||||
|
||||
// Create the initial remote commitment with the initial balances.
|
||||
c.createRemoteCommitTx(t)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *mockChannel) createRemoteCommitTx(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
// Construct the to-local witness script.
|
||||
toLocalScript, err := input.CommitScriptToSelf(
|
||||
c.csvDelay, c.toLocalPK, c.revPK,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create to-local script: %v", err)
|
||||
}
|
||||
|
||||
// Compute the to-local witness script hash.
|
||||
toLocalScriptHash, err := input.WitnessScriptHash(toLocalScript)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create to-local witness script hash: %v", err)
|
||||
}
|
||||
|
||||
// Compute the to-remote witness script hash.
|
||||
toRemoteScriptHash, err := input.CommitScriptUnencumbered(c.toRemotePK)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create to-remote script: %v", err)
|
||||
}
|
||||
|
||||
// Construct the remote commitment txn, containing the to-local and
|
||||
// to-remote outputs. The balances are flipped since the transaction is
|
||||
// from the PoV of the remote party. We don't need any inputs for this
|
||||
// test. We increment the version with the commit height to ensure that
|
||||
// all commitment transactions are unique even if the same distribution
|
||||
// of funds is used more than once.
|
||||
commitTxn := &wire.MsgTx{
|
||||
Version: int32(c.commitHeight + 1),
|
||||
}
|
||||
|
||||
var (
|
||||
toLocalSignDesc *input.SignDescriptor
|
||||
toRemoteSignDesc *input.SignDescriptor
|
||||
)
|
||||
|
||||
var outputIndex int
|
||||
if c.remoteBalance >= c.dustLimit {
|
||||
commitTxn.TxOut = append(commitTxn.TxOut, &wire.TxOut{
|
||||
Value: int64(c.remoteBalance.ToSatoshis()),
|
||||
PkScript: toLocalScriptHash,
|
||||
})
|
||||
|
||||
// Create the sign descriptor used to sign for the to-local
|
||||
// input.
|
||||
toLocalSignDesc = &input.SignDescriptor{
|
||||
KeyDesc: keychain.KeyDescriptor{
|
||||
KeyLocator: c.revKeyLoc,
|
||||
PubKey: c.revPK,
|
||||
},
|
||||
WitnessScript: toLocalScript,
|
||||
Output: commitTxn.TxOut[outputIndex],
|
||||
HashType: txscript.SigHashAll,
|
||||
}
|
||||
outputIndex++
|
||||
}
|
||||
if c.localBalance >= c.dustLimit {
|
||||
commitTxn.TxOut = append(commitTxn.TxOut, &wire.TxOut{
|
||||
Value: int64(c.localBalance.ToSatoshis()),
|
||||
PkScript: toRemoteScriptHash,
|
||||
})
|
||||
|
||||
// Create the sign descriptor used to sign for the to-remote
|
||||
// input.
|
||||
toRemoteSignDesc = &input.SignDescriptor{
|
||||
KeyDesc: keychain.KeyDescriptor{
|
||||
KeyLocator: c.toRemoteKeyLoc,
|
||||
PubKey: c.toRemotePK,
|
||||
},
|
||||
WitnessScript: toRemoteScriptHash,
|
||||
Output: commitTxn.TxOut[outputIndex],
|
||||
HashType: txscript.SigHashAll,
|
||||
}
|
||||
outputIndex++
|
||||
}
|
||||
|
||||
txid := commitTxn.TxHash()
|
||||
|
||||
var (
|
||||
toLocalOutPoint wire.OutPoint
|
||||
toRemoteOutPoint wire.OutPoint
|
||||
)
|
||||
|
||||
outputIndex = 0
|
||||
if toLocalSignDesc != nil {
|
||||
toLocalOutPoint = wire.OutPoint{
|
||||
Hash: txid,
|
||||
Index: uint32(outputIndex),
|
||||
}
|
||||
outputIndex++
|
||||
}
|
||||
if toRemoteSignDesc != nil {
|
||||
toRemoteOutPoint = wire.OutPoint{
|
||||
Hash: txid,
|
||||
Index: uint32(outputIndex),
|
||||
}
|
||||
outputIndex++
|
||||
}
|
||||
|
||||
commitKeyRing := &lnwallet.CommitmentKeyRing{
|
||||
RevocationKey: c.revPK,
|
||||
NoDelayKey: c.toLocalPK,
|
||||
DelayKey: c.toRemotePK,
|
||||
}
|
||||
|
||||
retribution := &lnwallet.BreachRetribution{
|
||||
BreachTransaction: commitTxn,
|
||||
RevokedStateNum: c.commitHeight,
|
||||
KeyRing: commitKeyRing,
|
||||
RemoteDelay: c.csvDelay,
|
||||
LocalOutpoint: toRemoteOutPoint,
|
||||
LocalOutputSignDesc: toRemoteSignDesc,
|
||||
RemoteOutpoint: toLocalOutPoint,
|
||||
RemoteOutputSignDesc: toLocalSignDesc,
|
||||
}
|
||||
|
||||
c.retributions[c.commitHeight] = retribution
|
||||
c.commitHeight++
|
||||
}
|
||||
|
||||
// advanceState creates the next channel state and retribution without altering
|
||||
// channel balances.
|
||||
func (c *mockChannel) advanceState(t *testing.T) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.createRemoteCommitTx(t)
|
||||
}
|
||||
|
||||
// sendPayment creates the next channel state and retribution after transferring
|
||||
// amt to the remote party.
|
||||
func (c *mockChannel) sendPayment(t *testing.T, amt lnwire.MilliSatoshi) {
|
||||
t.Helper()
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.localBalance < amt {
|
||||
t.Fatalf("insufficient funds to send, need: %v, have: %v",
|
||||
amt, c.localBalance)
|
||||
}
|
||||
|
||||
c.localBalance -= amt
|
||||
c.remoteBalance += amt
|
||||
c.createRemoteCommitTx(t)
|
||||
}
|
||||
|
||||
// receivePayment creates the next channel state and retribution after
|
||||
// transferring amt to the local party.
|
||||
func (c *mockChannel) receivePayment(t *testing.T, amt lnwire.MilliSatoshi) {
|
||||
t.Helper()
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.remoteBalance < amt {
|
||||
t.Fatalf("insufficient funds to recv, need: %v, have: %v",
|
||||
amt, c.remoteBalance)
|
||||
}
|
||||
|
||||
c.localBalance += amt
|
||||
c.remoteBalance -= amt
|
||||
c.createRemoteCommitTx(t)
|
||||
}
|
||||
|
||||
// getState retrieves the channel's commitment and retribution at state i.
|
||||
func (c *mockChannel) getState(i uint64) (*wire.MsgTx, *lnwallet.BreachRetribution) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
retribution := c.retributions[i]
|
||||
|
||||
return retribution.BreachTransaction, retribution
|
||||
}
|
||||
|
||||
type testHarness struct {
|
||||
t *testing.T
|
||||
cfg harnessCfg
|
||||
signer *wtmock.MockSigner
|
||||
capacity lnwire.MilliSatoshi
|
||||
clientDB *wtmock.ClientDB
|
||||
clientCfg *wtclient.Config
|
||||
client wtclient.Client
|
||||
serverDB *wtdb.MockDB
|
||||
serverCfg *wtserver.Config
|
||||
server *wtserver.Server
|
||||
net *mockNet
|
||||
|
||||
mu sync.Mutex
|
||||
channels map[lnwire.ChannelID]*mockChannel
|
||||
}
|
||||
|
||||
type harnessCfg struct {
|
||||
localBalance lnwire.MilliSatoshi
|
||||
remoteBalance lnwire.MilliSatoshi
|
||||
policy wtpolicy.Policy
|
||||
noRegisterChan0 bool
|
||||
}
|
||||
|
||||
func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
|
||||
towerAddrStr := "18.28.243.2:9911"
|
||||
towerTCPAddr, err := net.ResolveTCPAddr("tcp", towerAddrStr)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to resolve tower TCP addr: %v", err)
|
||||
}
|
||||
|
||||
privKey, err := btcec.NewPrivateKey(btcec.S256())
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to generate tower private key: %v", err)
|
||||
}
|
||||
|
||||
towerPubKey := privKey.PubKey()
|
||||
|
||||
towerAddr := &lnwire.NetAddress{
|
||||
IdentityKey: towerPubKey,
|
||||
Address: towerTCPAddr,
|
||||
}
|
||||
|
||||
const timeout = 200 * time.Millisecond
|
||||
serverDB := wtdb.NewMockDB()
|
||||
|
||||
serverCfg := &wtserver.Config{
|
||||
DB: serverDB,
|
||||
ReadTimeout: timeout,
|
||||
WriteTimeout: timeout,
|
||||
NewAddress: func() (btcutil.Address, error) {
|
||||
return addr, nil
|
||||
},
|
||||
}
|
||||
|
||||
server, err := wtserver.New(serverCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create wtserver: %v", err)
|
||||
}
|
||||
|
||||
signer := wtmock.NewMockSigner()
|
||||
mockNet := newMockNet(server.InboundPeerConnected)
|
||||
clientDB := wtmock.NewClientDB()
|
||||
|
||||
clientCfg := &wtclient.Config{
|
||||
Signer: signer,
|
||||
Dial: func(string, string) (net.Conn, error) {
|
||||
return nil, nil
|
||||
},
|
||||
DB: clientDB,
|
||||
AuthDial: mockNet.AuthDial,
|
||||
PrivateTower: towerAddr,
|
||||
Policy: cfg.policy,
|
||||
NewAddress: func() ([]byte, error) {
|
||||
return addrScript, nil
|
||||
},
|
||||
ReadTimeout: timeout,
|
||||
WriteTimeout: timeout,
|
||||
MinBackoff: time.Millisecond,
|
||||
MaxBackoff: 10 * time.Millisecond,
|
||||
}
|
||||
client, err := wtclient.New(clientCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to create wtclient: %v", err)
|
||||
}
|
||||
|
||||
if err := server.Start(); err != nil {
|
||||
t.Fatalf("Unable to start wtserver: %v", err)
|
||||
}
|
||||
|
||||
if err = client.Start(); err != nil {
|
||||
server.Stop()
|
||||
t.Fatalf("Unable to start wtclient: %v", err)
|
||||
}
|
||||
|
||||
h := &testHarness{
|
||||
t: t,
|
||||
cfg: cfg,
|
||||
signer: signer,
|
||||
capacity: cfg.localBalance + cfg.remoteBalance,
|
||||
clientDB: clientDB,
|
||||
clientCfg: clientCfg,
|
||||
client: client,
|
||||
serverDB: serverDB,
|
||||
serverCfg: serverCfg,
|
||||
server: server,
|
||||
net: mockNet,
|
||||
channels: make(map[lnwire.ChannelID]*mockChannel),
|
||||
}
|
||||
|
||||
h.makeChannel(0, h.cfg.localBalance, h.cfg.remoteBalance)
|
||||
if !cfg.noRegisterChan0 {
|
||||
h.registerChannel(0)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// startServer creates a new server using the harness's current serverCfg and
|
||||
// starts it after pointing the mockNet's callback to the new server.
|
||||
func (h *testHarness) startServer() {
|
||||
h.t.Helper()
|
||||
|
||||
var err error
|
||||
h.server, err = wtserver.New(h.serverCfg)
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to create wtserver: %v", err)
|
||||
}
|
||||
|
||||
h.net.setConnCallback(h.server.InboundPeerConnected)
|
||||
|
||||
if err := h.server.Start(); err != nil {
|
||||
h.t.Fatalf("unable to start wtserver: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// startClient creates a new client using the harness's current clientCfg and
// starts it.
|
||||
func (h *testHarness) startClient() {
|
||||
h.t.Helper()
|
||||
|
||||
var err error
|
||||
h.client, err = wtclient.New(h.clientCfg)
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to create wtclient: %v", err)
|
||||
}
|
||||
if err := h.client.Start(); err != nil {
|
||||
h.t.Fatalf("unable to start wtclient: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// chanIDFromInt creates a unique channel id given a unique integral id.
|
||||
func chanIDFromInt(id uint64) lnwire.ChannelID {
|
||||
var chanID lnwire.ChannelID
|
||||
binary.BigEndian.PutUint64(chanID[:8], id)
|
||||
return chanID
|
||||
}
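As an illustrative check (hypothetical, not part of the diff, relying on the encoding/binary and testing imports already present in this file), the id lands in the first eight bytes of the ChannelID in big-endian order, so distinct integral ids always yield distinct channel ids:

```go
// TestChanIDFromIntEncoding is a hypothetical assertion exercising the helper
// above from within the same test package.
func TestChanIDFromIntEncoding(t *testing.T) {
	id := chanIDFromInt(42)
	if binary.BigEndian.Uint64(id[:8]) != 42 {
		t.Fatalf("unexpected channel id encoding: %x", id[:8])
	}
}
```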
|
||||
|
||||
// makeChannel creates a new channel with id, using the localAmt and remoteAmt as
|
||||
// the starting balances. The channel will be available by using h.channel(id).
|
||||
//
|
||||
// NOTE: The method fails if channel for id already exists.
|
||||
func (h *testHarness) makeChannel(id uint64,
|
||||
localAmt, remoteAmt lnwire.MilliSatoshi) {
|
||||
|
||||
h.t.Helper()
|
||||
|
||||
chanID := chanIDFromInt(id)
|
||||
c := newMockChannel(h.t, h.signer, localAmt, remoteAmt)
|
||||
|
||||
h.mu.Lock()
|
||||
_, ok := h.channels[chanID]
|
||||
if !ok {
|
||||
h.channels[chanID] = c
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
if ok {
|
||||
h.t.Fatalf("channel %d already created", id)
|
||||
}
|
||||
}
|
||||
|
||||
// channel retrieves the channel corresponding to id.
|
||||
//
|
||||
// NOTE: The method fails if a channel for id does not exist.
|
||||
func (h *testHarness) channel(id uint64) *mockChannel {
|
||||
h.t.Helper()
|
||||
|
||||
h.mu.Lock()
|
||||
c, ok := h.channels[chanIDFromInt(id)]
|
||||
h.mu.Unlock()
|
||||
if !ok {
|
||||
h.t.Fatalf("unable to fetch channel %d", id)
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// registerChannel registers the channel identified by id with the client.
|
||||
func (h *testHarness) registerChannel(id uint64) {
|
||||
h.t.Helper()
|
||||
|
||||
chanID := chanIDFromInt(id)
|
||||
err := h.client.RegisterChannel(chanID)
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to register channel %d: %v", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
// advanceChannelN calls advanceState on the channel identified by id the
|
||||
// provided number of times and returns the breach hints corresponding to the new
|
||||
// states.
|
||||
func (h *testHarness) advanceChannelN(id uint64, n int) []wtdb.BreachHint {
|
||||
h.t.Helper()
|
||||
|
||||
channel := h.channel(id)
|
||||
|
||||
var hints []wtdb.BreachHint
|
||||
for i := uint64(0); i < uint64(n); i++ {
|
||||
channel.advanceState(h.t)
|
||||
commitTx, _ := h.channel(id).getState(i)
|
||||
breachTxID := commitTx.TxHash()
|
||||
hints = append(hints, wtdb.NewBreachHintFromHash(&breachTxID))
|
||||
}
|
||||
|
||||
return hints
|
||||
}
|
||||
|
||||
// backupStates instructs the channel identified by id to send backups to the
|
||||
// client for states in the range [from, to).
|
||||
func (h *testHarness) backupStates(id, from, to uint64, expErr error) {
|
||||
h.t.Helper()
|
||||
|
||||
for i := from; i < to; i++ {
|
||||
h.backupState(id, i, expErr)
|
||||
}
|
||||
}
|
||||
|
||||
// backupState instructs the channel identified by id to send a backup for
|
||||
// state i.
|
||||
func (h *testHarness) backupState(id, i uint64, expErr error) {
|
||||
_, retribution := h.channel(id).getState(i)
|
||||
|
||||
chanID := chanIDFromInt(id)
|
||||
err := h.client.BackupState(&chanID, retribution)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("back error mismatch, want: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
// sendPayments instructs the channel identified by id to send amt to the remote
|
||||
// party once for each state in the range [from, to), and returns the breach
|
||||
// hints for those states.
|
||||
func (h *testHarness) sendPayments(id, from, to uint64,
|
||||
amt lnwire.MilliSatoshi) []wtdb.BreachHint {
|
||||
|
||||
h.t.Helper()
|
||||
|
||||
channel := h.channel(id)
|
||||
|
||||
var hints []wtdb.BreachHint
|
||||
for i := from; i < to; i++ {
|
||||
h.channel(id).sendPayment(h.t, amt)
|
||||
commitTx, _ := channel.getState(i)
|
||||
breachTxID := commitTx.TxHash()
|
||||
hints = append(hints, wtdb.NewBreachHintFromHash(&breachTxID))
|
||||
}
|
||||
|
||||
return hints
|
||||
}
|
||||
|
||||
// recvPayments instructs the channel identified by id to receive amt from the
|
||||
// remote party once for each state in the range [from, to), and returns the
|
||||
// breach hints for those states.
|
||||
func (h *testHarness) recvPayments(id, from, to uint64,
|
||||
amt lnwire.MilliSatoshi) []wtdb.BreachHint {
|
||||
|
||||
h.t.Helper()
|
||||
|
||||
channel := h.channel(id)
|
||||
|
||||
var hints []wtdb.BreachHint
|
||||
for i := from; i < to; i++ {
|
||||
channel.receivePayment(h.t, amt)
|
||||
commitTx, _ := channel.getState(i)
|
||||
breachTxID := commitTx.TxHash()
|
||||
hints = append(hints, wtdb.NewBreachHintFromHash(&breachTxID))
|
||||
}
|
||||
|
||||
return hints
|
||||
}
|
||||
|
||||
// waitServerUpdates blocks until the breach hints provided all appear in the
|
||||
// watchtower's database or the timeout expires. This is used to test that the
|
||||
// client in fact sends the updates to the server, even if it is offline.
|
||||
func (h *testHarness) waitServerUpdates(hints []wtdb.BreachHint,
|
||||
timeout time.Duration) {
|
||||
|
||||
h.t.Helper()
|
||||
|
||||
// If no breach hints are provided, we will wait out the full timeout to
|
||||
// assert that no updates appear.
|
||||
wantUpdates := len(hints) > 0
|
||||
|
||||
hintSet := make(map[wtdb.BreachHint]struct{})
|
||||
for _, hint := range hints {
|
||||
hintSet[hint] = struct{}{}
|
||||
}
|
||||
|
||||
if len(hints) != len(hintSet) {
|
||||
h.t.Fatalf("breach hints are not unique, list-len: %d "+
|
||||
"set-len: %d", len(hints), len(hintSet))
|
||||
}
|
||||
|
||||
// Closure to assert the server's matches are consistent with the hint
|
||||
// set.
|
||||
serverHasHints := func(matches []wtdb.Match) bool {
|
||||
if len(hintSet) != len(matches) {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, match := range matches {
|
||||
if _, ok := hintSet[match.Hint]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
h.t.Fatalf("match %v in db is not in hint set",
|
||||
match.Hint)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
failTimeout := time.After(timeout)
|
||||
for {
|
||||
select {
|
||||
case <-time.After(time.Second):
|
||||
matches, err := h.serverDB.QueryMatches(hints)
|
||||
switch {
|
||||
case err != nil:
|
||||
h.t.Fatalf("unable to query for hints: %v", err)
|
||||
|
||||
case wantUpdates && serverHasHints(matches):
|
||||
return
|
||||
|
||||
case wantUpdates:
|
||||
h.t.Logf("Received %d/%d\n", len(matches),
|
||||
len(hints))
|
||||
}
|
||||
|
||||
case <-failTimeout:
|
||||
matches, err := h.serverDB.QueryMatches(hints)
|
||||
switch {
|
||||
case err != nil:
|
||||
h.t.Fatalf("unable to query for hints: %v", err)
|
||||
|
||||
case serverHasHints(matches):
|
||||
return
|
||||
|
||||
default:
|
||||
h.t.Fatalf("breach hints not received, only "+
|
||||
"got %d/%d", len(matches), len(hints))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
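// For illustration only, a minimal sketch (assuming the default harness
// configuration used elsewhere in this file) of how the helpers above compose
// inside a clientTest body:
//
//	fn: func(h *testHarness) {
//		// Create five revoked states and hand them to the client.
//		hints := h.advanceChannelN(0, 5)
//		h.backupStates(0, 0, 5, nil)
//
//		// Request a clean shutdown so the pipeline flushes, then
//		// verify the tower persisted every hint.
//		go h.client.Stop()
//		h.waitServerUpdates(hints, time.Second)
//	},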
|
||||
const (
|
||||
localBalance = lnwire.MilliSatoshi(100000000)
|
||||
remoteBalance = lnwire.MilliSatoshi(200000000)
|
||||
)
|
||||
|
||||
type clientTest struct {
|
||||
name string
|
||||
cfg harnessCfg
|
||||
fn func(*testHarness)
|
||||
}
|
||||
|
||||
var clientTests = []clientTest{
|
||||
{
|
||||
// Asserts that client will return the ErrUnregisteredChannel
|
||||
// error when trying to backup states for a channel that has not
|
||||
// been registered (and received its pkscript).
|
||||
name: "backup unregistered channel",
|
||||
cfg: harnessCfg{
|
||||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
BlobType: blob.TypeDefault,
|
||||
MaxUpdates: 20000,
|
||||
SweepFeeRate: 1,
|
||||
},
|
||||
noRegisterChan0: true,
|
||||
},
|
||||
fn: func(h *testHarness) {
|
||||
const (
|
||||
numUpdates = 5
|
||||
chanID = 0
|
||||
)
|
||||
|
||||
// Advance the channel and backup the retributions. We
|
||||
// expect ErrUnregisteredChannel to be returned since
|
||||
// the channel was not registered during harness
|
||||
// creation.
|
||||
h.advanceChannelN(chanID, numUpdates)
|
||||
h.backupStates(
|
||||
chanID, 0, numUpdates,
|
||||
wtclient.ErrUnregisteredChannel,
|
||||
)
|
||||
},
|
||||
},
|
||||
{
|
||||
// Asserts that the client returns an ErrClientExiting when
|
||||
// trying to backup channels after the Stop method has been
|
||||
// called.
|
||||
name: "backup after stop",
|
||||
cfg: harnessCfg{
|
||||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
BlobType: blob.TypeDefault,
|
||||
MaxUpdates: 20000,
|
||||
SweepFeeRate: 1,
|
||||
},
|
||||
},
|
||||
fn: func(h *testHarness) {
|
||||
const (
|
||||
numUpdates = 5
|
||||
chanID = 0
|
||||
)
|
||||
|
||||
// Stop the client, subsequent backups should fail.
|
||||
h.client.Stop()
|
||||
|
||||
// Advance the channel and try to back up the states. We
|
||||
// expect ErrClientExiting to be returned from
|
||||
// BackupState.
|
||||
h.advanceChannelN(chanID, numUpdates)
|
||||
h.backupStates(
|
||||
chanID, 0, numUpdates,
|
||||
wtclient.ErrClientExiting,
|
||||
)
|
||||
},
|
||||
},
|
||||
{
|
||||
// Asserts that the client will continue to back up all states
|
||||
// that have previously been enqueued before it finishes
|
||||
// exiting.
|
||||
name: "backup reliable flush",
|
||||
cfg: harnessCfg{
|
||||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
BlobType: blob.TypeDefault,
|
||||
MaxUpdates: 5,
|
||||
SweepFeeRate: 1,
|
||||
},
|
||||
},
|
||||
fn: func(h *testHarness) {
|
||||
const (
|
||||
numUpdates = 5
|
||||
chanID = 0
|
||||
)
|
||||
|
||||
// Generate numUpdates retributions and back them up to
|
||||
// the tower.
|
||||
hints := h.advanceChannelN(chanID, numUpdates)
|
||||
h.backupStates(chanID, 0, numUpdates, nil)
|
||||
|
||||
// Stop the client in the background, to assert the
|
||||
// pipeline is always flushed before it exits.
|
||||
go h.client.Stop()
|
||||
|
||||
// Wait for all of the updates to be populated in the
|
||||
// server's database.
|
||||
h.waitServerUpdates(hints, time.Second)
|
||||
},
|
||||
},
|
||||
{
|
||||
// Assert that the client will not send out backups for states
|
||||
// whose justice transactions are ineligible for backup, e.g.
|
||||
// creating dust outputs.
|
||||
name: "backup dust ineligible",
|
||||
cfg: harnessCfg{
|
||||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
BlobType: blob.TypeDefault,
|
||||
MaxUpdates: 20000,
|
||||
SweepFeeRate: 1000000, // high sweep fee creates dust
|
||||
},
|
||||
},
|
||||
fn: func(h *testHarness) {
|
||||
const (
|
||||
numUpdates = 5
|
||||
chanID = 0
|
||||
)
|
||||
|
||||
// Create the retributions and queue them for backup.
|
||||
h.advanceChannelN(chanID, numUpdates)
|
||||
h.backupStates(chanID, 0, numUpdates, nil)
|
||||
|
||||
// Ensure that no updates are received by the server,
|
||||
// since they should all be marked as ineligible.
|
||||
h.waitServerUpdates(nil, time.Second)
|
||||
},
|
||||
},
|
||||
{
|
||||
// Verifies that the client will properly retransmit a committed
|
||||
// state update to the watchtower after a restart if the update
|
||||
// was not acked while the client was last active.
|
||||
name: "committed update restart",
|
||||
cfg: harnessCfg{
|
||||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
BlobType: blob.TypeDefault,
|
||||
MaxUpdates: 20000,
|
||||
SweepFeeRate: 1,
|
||||
},
|
||||
},
|
||||
fn: func(h *testHarness) {
|
||||
const (
|
||||
numUpdates = 5
|
||||
chanID = 0
|
||||
)
|
||||
|
||||
hints := h.advanceChannelN(0, numUpdates)
|
||||
|
||||
var numSent uint64
|
||||
|
||||
// Add the first two states to the client's pipeline.
|
||||
h.backupStates(chanID, 0, 2, nil)
|
||||
numSent = 2
|
||||
|
||||
// Wait for both to be reflected in the server's
|
||||
// database.
|
||||
h.waitServerUpdates(hints[:numSent], time.Second)
|
||||
|
||||
// Now, restart the server and prevent it from acking
|
||||
// state updates.
|
||||
h.server.Stop()
|
||||
h.serverCfg.NoAckUpdates = true
|
||||
h.startServer()
|
||||
defer h.server.Stop()
|
||||
|
||||
// Send the next state update to the tower. Since the
|
||||
// tower isn't acking state updates, we expect this
|
||||
// update to be committed and sent by the session queue,
|
||||
// but it will never receive an ack.
|
||||
h.backupState(chanID, numSent, nil)
|
||||
numSent++
|
||||
|
||||
// Force quit the client to abort the state updates it
|
||||
// has queued. The sleep ensures that the session queues
|
||||
// have enough time to commit the state updates before
|
||||
// the client is killed.
|
||||
time.Sleep(time.Second)
|
||||
h.client.ForceQuit()
|
||||
|
||||
// Restart the server and allow it to ack the updates
|
||||
// after the client retransmits the unacked update.
|
||||
h.server.Stop()
|
||||
h.serverCfg.NoAckUpdates = false
|
||||
h.startServer()
|
||||
defer h.server.Stop()
|
||||
|
||||
// Restart the client and allow it to process the
|
||||
// committed update.
|
||||
h.startClient()
|
||||
defer h.client.ForceQuit()
|
||||
|
||||
// Wait for the committed update to be accepted by the
|
||||
// tower.
|
||||
h.waitServerUpdates(hints[:numSent], time.Second)
|
||||
|
||||
// Finally, send the rest of the updates and wait for
|
||||
// the tower to receive the remaining states.
|
||||
h.backupStates(chanID, numSent, numUpdates, nil)
|
||||
|
||||
// Wait for all of the updates to be populated in the
|
||||
// server's database.
|
||||
h.waitServerUpdates(hints, time.Second)
|
||||
|
||||
},
|
||||
},
|
||||
{
|
||||
// Asserts that the client will continue to retry sending state
|
||||
// updates if it doesn't receive an ack from the server. The
|
||||
// client is expected to flush everything in its in-memory
|
||||
// pipeline once the server begins sending acks again.
|
||||
name: "no ack from server",
|
||||
cfg: harnessCfg{
|
||||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
BlobType: blob.TypeDefault,
|
||||
MaxUpdates: 5,
|
||||
SweepFeeRate: 1,
|
||||
},
|
||||
},
|
||||
fn: func(h *testHarness) {
|
||||
const (
|
||||
numUpdates = 100
|
||||
chanID = 0
|
||||
)
|
||||
|
||||
// Generate the retributions that will be backed up.
|
||||
hints := h.advanceChannelN(chanID, numUpdates)
|
||||
|
||||
// Restart the server and prevent it from acking state
|
||||
// updates.
|
||||
h.server.Stop()
|
||||
h.serverCfg.NoAckUpdates = true
|
||||
h.startServer()
|
||||
defer h.server.Stop()
|
||||
|
||||
// Now, queue the retributions for backup.
|
||||
h.backupStates(chanID, 0, numUpdates, nil)
|
||||
|
||||
// Stop the client in the background, to assert the
|
||||
// pipeline is always flushed before it exits.
|
||||
go h.client.Stop()
|
||||
|
||||
// Give the client time to saturate a large number of
|
||||
// session queues for which the server has not acked the
|
||||
// state updates that it has received.
|
||||
time.Sleep(time.Second)
|
||||
|
||||
// Restart the server and allow it to ack the updates
|
||||
// after the client retransmits the unacked updates.
|
||||
h.server.Stop()
|
||||
h.serverCfg.NoAckUpdates = false
|
||||
h.startServer()
|
||||
defer h.server.Stop()
|
||||
|
||||
// Wait for all of the updates to be populated in the
|
||||
// server's database.
|
||||
h.waitServerUpdates(hints, 5*time.Second)
|
||||
},
|
||||
},
|
||||
{
|
||||
// Asserts that the client is able to send state updates to the
|
||||
// tower for a full range of channel values, assuming the sweep
|
||||
// fee rates permit it. We expect all of these to be successful
|
||||
// since a sweep transaction spending only one output is
|
||||
// less expensive than one that sweeps both.
|
||||
name: "send and recv",
|
||||
cfg: harnessCfg{
|
||||
localBalance: 10000001, // ensure (% amt != 0)
|
||||
remoteBalance: 20000001, // ensure (% amt != 0)
|
||||
policy: wtpolicy.Policy{
|
||||
BlobType: blob.TypeDefault,
|
||||
MaxUpdates: 1000,
|
||||
SweepFeeRate: 1,
|
||||
},
|
||||
},
|
||||
fn: func(h *testHarness) {
|
||||
var (
|
||||
capacity = h.cfg.localBalance + h.cfg.remoteBalance
|
||||
paymentAmt = lnwire.MilliSatoshi(200000)
|
||||
numSends = uint64(h.cfg.localBalance / paymentAmt)
|
||||
numRecvs = uint64(capacity / paymentAmt)
|
||||
numUpdates = numSends + numRecvs // 200 updates
|
||||
chanID = uint64(0)
|
||||
)
|
||||
|
||||
// Send money to the remote party until all funds are
|
||||
// depleted.
|
||||
sendHints := h.sendPayments(chanID, 0, numSends, paymentAmt)
|
||||
|
||||
// Now, sequentially receive the entire channel balance
|
||||
// from the remote party.
|
||||
recvHints := h.recvPayments(chanID, numSends, numUpdates, paymentAmt)
|
||||
|
||||
// Collect the hints generated by both sending and
|
||||
// receiving.
|
||||
hints := append(sendHints, recvHints...)
|
||||
|
||||
// Back up the channel's states to the client.
|
||||
h.backupStates(chanID, 0, numUpdates, nil)
|
||||
|
||||
// Wait for all of the updates to be populated in the
|
||||
// server's database.
|
||||
h.waitServerUpdates(hints, 3*time.Second)
|
||||
},
|
||||
},
|
||||
{
|
||||
// Asserts that the client is able to support multiple links.
|
||||
name: "multiple link backup",
|
||||
cfg: harnessCfg{
|
||||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
BlobType: blob.TypeDefault,
|
||||
MaxUpdates: 5,
|
||||
SweepFeeRate: 1,
|
||||
},
|
||||
},
|
||||
fn: func(h *testHarness) {
|
||||
const (
|
||||
numUpdates = 5
|
||||
numChans = 10
|
||||
)
|
||||
|
||||
// Initialize and register an additional 9 channels.
|
||||
for id := uint64(1); id < 10; id++ {
|
||||
h.makeChannel(
|
||||
id, h.cfg.localBalance,
|
||||
h.cfg.remoteBalance,
|
||||
)
|
||||
h.registerChannel(id)
|
||||
}
|
||||
|
||||
// Generate the retributions for all 10 channels and
|
||||
// collect the breach hints.
|
||||
var hints []wtdb.BreachHint
|
||||
for id := uint64(0); id < 10; id++ {
|
||||
chanHints := h.advanceChannelN(id, numUpdates)
|
||||
hints = append(hints, chanHints...)
|
||||
}
|
||||
|
||||
// Provide all retributions to the client from all
|
||||
// channels.
|
||||
for id := uint64(0); id < 10; id++ {
|
||||
h.backupStates(id, 0, numUpdates, nil)
|
||||
}
|
||||
|
||||
// Test the reliable flush in a multi-channel scenario.
|
||||
go h.client.Stop()
|
||||
|
||||
// Wait for all of the updates to be populated in the
|
||||
// server's database.
|
||||
h.waitServerUpdates(hints, 10*time.Second)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// TestClient executes the client test suite, asserting the ability to back up
|
||||
// states in a number of failure cases and its reliability during shutdown.
|
||||
func TestClient(t *testing.T) {
|
||||
for _, test := range clientTests {
|
||||
tc := test
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := newHarness(t, tc.cfg)
|
||||
defer h.server.Stop()
|
||||
defer h.client.ForceQuit()
|
||||
|
||||
tc.fn(h)
|
||||
})
|
||||
}
|
||||
}
|
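// Usage note (illustrative): since each case above runs as a named subtest, a
// single scenario can be exercised in isolation with go test's -run flag,
// where spaces in subtest names match as underscores, e.g.:
//
//	go test ./watchtower/wtclient -run 'TestClient/backup_reliable_flush'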
35  watchtower/wtclient/errors.go  Normal file
@ -0,0 +1,35 @@
|
||||
package wtclient
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrClientExiting signals that the watchtower client is shutting down.
|
||||
ErrClientExiting = errors.New("watchtower client shutting down")
|
||||
|
||||
// ErrTowerCandidatesExhausted signals that a TowerCandidateIterator has
|
||||
// cycled through all available candidates.
|
||||
ErrTowerCandidatesExhausted = errors.New("exhausted all tower " +
|
||||
"candidates")
|
||||
|
||||
// ErrPermanentTowerFailure signals that the tower has reported that it
|
||||
// has permanently failed or the client believes this has happened based
|
||||
// on the tower's behavior.
|
||||
ErrPermanentTowerFailure = errors.New("permanent tower failure")
|
||||
|
||||
// ErrNegotiatorExiting signals that the SessionNegotiator is shutting
|
||||
// down.
|
||||
ErrNegotiatorExiting = errors.New("negotiator exiting")
|
||||
|
||||
// ErrNoTowerAddrs signals that the client could not be created because
|
||||
// we have no addresses with which we can reach a tower.
|
||||
ErrNoTowerAddrs = errors.New("no tower addresses")
|
||||
|
||||
// ErrFailedNegotiation signals that the session negotiator could not
|
||||
// acquire a new session as requested.
|
||||
ErrFailedNegotiation = errors.New("session negotiation unsuccessful")
|
||||
|
||||
// ErrUnregisteredChannel signals that the client was unable to back up a
|
||||
// revoked state because the channel had not been previously registered
|
||||
// with the client.
|
||||
ErrUnregisteredChannel = errors.New("channel is not registered")
|
||||
)
|
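// Illustrative sketch: since the values above are package-level sentinel
// errors, callers can match them with a direct comparison. The client,
// chanID, and breachInfo identifiers below are placeholders, not part of this
// package.
//
//	err := client.BackupState(&chanID, breachInfo)
//	switch err {
//	case nil:
//		// Backup accepted into the pipeline.
//	case wtclient.ErrUnregisteredChannel:
//		// Register the channel first, then retry the backup.
//	case wtclient.ErrClientExiting:
//		// The client is shutting down; nothing more to do.
//	}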
76  watchtower/wtclient/interface.go  Normal file
@ -0,0 +1,76 @@
|
||||
package wtclient
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/lightningnetwork/lnd/brontide"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtserver"
|
||||
)
|
||||
|
||||
// DB abstracts the required database operations required by the watchtower
|
||||
// client.
|
||||
type DB interface {
|
||||
// CreateTower initializes an address record used to communicate with a
|
||||
// watchtower. Each Tower is assigned a unique ID, that is used to
|
||||
// amortize storage costs of the public key when used by multiple
|
||||
// sessions.
|
||||
CreateTower(*lnwire.NetAddress) (*wtdb.Tower, error)
|
||||
|
||||
// CreateClientSession saves a newly negotiated client session to the
|
||||
// client's database. This enables the session to be used across
|
||||
// restarts.
|
||||
CreateClientSession(*wtdb.ClientSession) error
|
||||
|
||||
// ListClientSessions returns all sessions that have not yet been
|
||||
// exhausted. This is used on startup to find any sessions which may
|
||||
// still be able to accept state updates.
|
||||
ListClientSessions() (map[wtdb.SessionID]*wtdb.ClientSession, error)
|
||||
|
||||
// FetchChanPkScripts returns a map of all sweep pkscripts for
|
||||
// registered channels. This is used on startup to cache the sweep
|
||||
// pkscripts of registered channels in memory.
|
||||
FetchChanPkScripts() (map[lnwire.ChannelID][]byte, error)
|
||||
|
||||
// AddChanPkScript inserts a newly generated sweep pkscript for the
|
||||
// given channel.
|
||||
AddChanPkScript(lnwire.ChannelID, []byte) error
|
||||
|
||||
// MarkBackupIneligible records that the state identified by the
|
||||
// (channel id, commit height) tuple was ineligible for being backed up
|
||||
// under the current policy. This state can be retried later under a
|
||||
// different policy.
|
||||
MarkBackupIneligible(chanID lnwire.ChannelID, commitHeight uint64) error
|
||||
|
||||
// CommitUpdate writes the next state update for a particular
|
||||
// session, so that we can be sure to resend it after a restart if it
|
||||
// hasn't been ACK'd by the tower. The sequence number of the update
|
||||
// should be exactly one greater than the existing entry, and less than
|
||||
// or equal to the session's MaxUpdates.
|
||||
CommitUpdate(id *wtdb.SessionID, seqNum uint16,
|
||||
update *wtdb.CommittedUpdate) (uint16, error)
|
||||
|
||||
// AckUpdate records an acknowledgment from the watchtower that the
|
||||
// update identified by seqNum was received and saved. The returned
|
||||
// lastApplied will be recorded.
|
||||
AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) error
|
||||
}
|
||||
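// Illustrative sketch of the CommitUpdate/AckUpdate contract documented above,
// where db, sess, update, and replyLastApplied are placeholders: an update is
// committed under its sequence number before being sent, and acked with the
// last-applied value echoed by the tower afterwards.
//
//	seqNum := sess.SeqNum + 1
//	lastApplied, err := db.CommitUpdate(&sess.ID, seqNum, update)
//	if err != nil {
//		// Handle commit failure.
//	}
//	// ... send a wtwire.StateUpdate carrying seqNum and lastApplied, then
//	// wait for the tower's reply containing replyLastApplied ...
//	if err := db.AckUpdate(&sess.ID, seqNum, replyLastApplied); err != nil {
//		// Handle ack failure.
//	}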
|
||||
// Dial connects to an addr using the specified net and returns the connection
|
||||
// object.
|
||||
type Dial func(net, addr string) (net.Conn, error)
|
||||
|
||||
// AuthDialer connects to a remote node using an authenticated transport, such as
|
||||
// brontide. The dialer argument is used to specify a resolver, which allows
|
||||
// this method to be used over Tor or clear net connections.
|
||||
type AuthDialer func(localPriv *btcec.PrivateKey, netAddr *lnwire.NetAddress,
|
||||
dialer func(string, string) (net.Conn, error)) (wtserver.Peer, error)
|
||||
|
||||
// AuthDial is the watchtower client's default method of dialing.
|
||||
func AuthDial(localPriv *btcec.PrivateKey, netAddr *lnwire.NetAddress,
|
||||
dialer func(string, string) (net.Conn, error)) (wtserver.Peer, error) {
|
||||
|
||||
return brontide.Dial(localPriv, netAddr, dialer)
|
||||
}
|
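// Illustrative sketch: the dialer argument selects the underlying transport.
// A clearnet setup can simply pass net.Dial, while a Tor setup would pass a
// SOCKS-aware dialer instead. The sessionPrivKey and towerAddr values below
// are placeholders.
//
//	peer, err := AuthDial(sessionPrivKey, towerAddr, net.Dial)
//	if err != nil {
//		// Unable to reach the tower over an authenticated transport.
//	}
//	defer peer.Close()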
29  watchtower/wtclient/log.go  Normal file
@ -0,0 +1,29 @@
|
||||
package wtclient
|
||||
|
||||
import (
|
||||
"github.com/btcsuite/btclog"
|
||||
"github.com/lightningnetwork/lnd/build"
|
||||
)
|
||||
|
||||
// log is a logger that is initialized with no output filters. This
|
||||
// means the package will not perform any logging by default until the caller
|
||||
// requests it.
|
||||
var log btclog.Logger
|
||||
|
||||
// The default amount of logging is none.
|
||||
func init() {
|
||||
UseLogger(build.NewSubLogger("WTCL", nil))
|
||||
}
|
||||
|
||||
// DisableLog disables all library log output. Logging output is disabled
|
||||
// by default until UseLogger is called.
|
||||
func DisableLog() {
|
||||
UseLogger(btclog.Disabled)
|
||||
}
|
||||
|
||||
// UseLogger uses a specified Logger to output package logging info.
|
||||
// This should be used in preference to SetLogWriter if the caller is also
|
||||
// using btclog.
|
||||
func UseLogger(logger btclog.Logger) {
|
||||
log = logger
|
||||
}
|
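// Illustrative sketch (the backend wiring below is an assumption, not part of
// lnd's logging setup): a caller that wants wtclient debug output can install
// its own logger via UseLogger.
//
//	backend := btclog.NewBackend(os.Stdout)
//	logger := backend.Logger("WTCL")
//	logger.SetLevel(btclog.LevelDebug)
//	UseLogger(logger)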
451  watchtower/wtclient/session_negotiator.go  Normal file
@ -0,0 +1,451 @@
|
||||
package wtclient
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/watchtower/blob"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtserver"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtwire"
|
||||
)
|
||||
|
||||
// SessionNegotiator is an interface for asynchronously requesting new sessions.
|
||||
type SessionNegotiator interface {
|
||||
// RequestSession signals to the session negotiator that the client
|
||||
// needs another session. Once the session is negotiated, it should be
|
||||
// returned via NewSessions.
|
||||
RequestSession()
|
||||
|
||||
// NewSessions is a read-only channel where newly negotiated sessions
|
||||
// will be delivered.
|
||||
NewSessions() <-chan *wtdb.ClientSession
|
||||
|
||||
// Start safely initializes the session negotiator.
|
||||
Start() error
|
||||
|
||||
// Stop safely shuts down the session negotiator.
|
||||
Stop() error
|
||||
}
|
||||
|
||||
// NegotiatorConfig provides access to the resources required by a
|
||||
// SessionNegotiator to faithfully carry out its duties. All nil-able fields must
|
||||
// be initialized.
|
||||
type NegotiatorConfig struct {
|
||||
// DB provides access to a persistent storage medium used by the tower
|
||||
// to properly allocate session ephemeral keys and record successfully
|
||||
// negotiated sessions.
|
||||
DB DB
|
||||
|
||||
// Candidates is an abstract set of tower candidates that the negotiator
|
||||
// will traverse serially when attempting to negotiate a new session.
|
||||
Candidates TowerCandidateIterator
|
||||
|
||||
// Policy defines the session policy that will be proposed to towers
|
||||
// when attempting to negotiate a new session. This policy will be used
|
||||
// across all negotiation proposals for the lifetime of the negotiator.
|
||||
Policy wtpolicy.Policy
|
||||
|
||||
// Dial initiates an outbound brontide connection to the given address
|
||||
// using a specified private key. The peer is returned in the event of a
|
||||
// successful connection.
|
||||
Dial func(*btcec.PrivateKey, *lnwire.NetAddress) (wtserver.Peer, error)
|
||||
|
||||
// SendMessage writes a wtwire message to remote peer.
|
||||
SendMessage func(wtserver.Peer, wtwire.Message) error
|
||||
|
||||
// ReadMessage reads a message from a remote peer and returns the
|
||||
// decoded wtwire message.
|
||||
ReadMessage func(wtserver.Peer) (wtwire.Message, error)
|
||||
|
||||
// ChainHash is the genesis hash identifying the chain for any negotiated
|
||||
// sessions. Any state updates sent to that session should also
|
||||
// originate from this chain.
|
||||
ChainHash chainhash.Hash
|
||||
|
||||
// MinBackoff defines the initial backoff applied by the session
|
||||
// negotiator after all tower candidates have been exhausted and
|
||||
// reattempting negotiation with the same set of candidates. Subsequent
|
||||
// backoff durations will grow exponentially.
|
||||
MinBackoff time.Duration
|
||||
|
||||
// MaxBackoff defines the maximum backoff applied by the session
|
||||
// negotiator after all tower candidates have been exhausted and
|
||||
// reattempting negotiation with the same set of candidates. If the
|
||||
// exponential backoff produces a timeout greater than this value, the
|
||||
// backoff duration will be clamped to MaxBackoff.
|
||||
MaxBackoff time.Duration
|
||||
}
|
||||
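// Illustrative sketch of the MinBackoff/MaxBackoff behavior described above:
// the retry delay starts at MinBackoff, doubles after each exhausted pass over
// the candidates, and is clamped at MaxBackoff. The helper name nextBackoff is
// hypothetical.
//
//	func nextBackoff(cur, max time.Duration) time.Duration {
//		cur *= 2
//		if cur > max {
//			cur = max
//		}
//		return cur
//	}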
|
||||
// sessionNegotiator is a concrete SessionNegotiator that is able to request new
|
||||
// sessions from a set of candidate towers asynchronously and return successful
|
||||
// sessions to the primary client.
|
||||
type sessionNegotiator struct {
|
||||
started sync.Once
|
||||
stopped sync.Once
|
||||
|
||||
localInit *wtwire.Init
|
||||
|
||||
cfg *NegotiatorConfig
|
||||
|
||||
dispatcher chan struct{}
|
||||
newSessions chan *wtdb.ClientSession
|
||||
successfulNegotiations chan *wtdb.ClientSession
|
||||
|
||||
wg sync.WaitGroup
|
||||
quit chan struct{}
|
||||
}
|
||||
|
||||
// Compile-time constraint to ensure a *sessionNegotiator implements the
|
||||
// SessionNegotiator interface.
|
||||
var _ SessionNegotiator = (*sessionNegotiator)(nil)
|
||||
|
||||
// newSessionNegotiator initializes a fresh sessionNegotiator instance.
|
||||
func newSessionNegotiator(cfg *NegotiatorConfig) *sessionNegotiator {
|
||||
localInit := wtwire.NewInitMessage(
|
||||
lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired),
|
||||
cfg.ChainHash,
|
||||
)
|
||||
|
||||
return &sessionNegotiator{
|
||||
cfg: cfg,
|
||||
localInit: localInit,
|
||||
dispatcher: make(chan struct{}, 1),
|
||||
newSessions: make(chan *wtdb.ClientSession),
|
||||
successfulNegotiations: make(chan *wtdb.ClientSession),
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start safely starts up the sessionNegotiator.
|
||||
func (n *sessionNegotiator) Start() error {
|
||||
n.started.Do(func() {
|
||||
log.Debugf("Starting session negotiator")
|
||||
|
||||
n.wg.Add(1)
|
||||
go n.negotiationDispatcher()
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop safely shuts down the sessionNegotiator.
|
||||
func (n *sessionNegotiator) Stop() error {
|
||||
n.stopped.Do(func() {
|
||||
log.Debugf("Stopping session negotiator")
|
||||
|
||||
close(n.quit)
|
||||
n.wg.Wait()
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewSessions returns a receive-only channel from which newly negotiated
|
||||
// sessions will be returned.
|
||||
func (n *sessionNegotiator) NewSessions() <-chan *wtdb.ClientSession {
|
||||
return n.newSessions
|
||||
}
|
||||
|
||||
// RequestSession sends a request to the sessionNegotiator to begin requesting a
|
||||
// new session. If one is already in the process of being negotiated, the
|
||||
// request will be ignored.
|
||||
func (n *sessionNegotiator) RequestSession() {
|
||||
select {
|
||||
case n.dispatcher <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
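// Illustrative sketch: because dispatcher is a buffered channel of size one,
// concurrent RequestSession calls coalesce while a negotiation is in flight.
// A consumer pairs the request with NewSessions roughly as follows, where
// negotiator and quit are placeholders.
//
//	negotiator.RequestSession()
//	select {
//	case session := <-negotiator.NewSessions():
//		// Begin using the freshly negotiated session.
//		_ = session
//	case <-quit:
//	}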
|
||||
// negotiationDispatcher acts as the primary event loop for the
|
||||
// sessionNegotiator, coordinating requests for more sessions and dispatching
|
||||
// attempts to negotiate them from a list of candidates.
|
||||
func (n *sessionNegotiator) negotiationDispatcher() {
|
||||
defer n.wg.Done()
|
||||
|
||||
var pendingNegotiations int
|
||||
for {
|
||||
select {
|
||||
case <-n.dispatcher:
|
||||
pendingNegotiations++
|
||||
|
||||
if pendingNegotiations > 1 {
|
||||
log.Debugf("Already negotiating session, " +
|
||||
"waiting for existing negotiation to " +
|
||||
"complete")
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO(conner): consider reusing good towers
|
||||
|
||||
log.Debugf("Dispatching session negotiation")
|
||||
|
||||
n.wg.Add(1)
|
||||
go n.negotiate()
|
||||
|
||||
case session := <-n.successfulNegotiations:
|
||||
select {
|
||||
case n.newSessions <- session:
|
||||
pendingNegotiations--
|
||||
case <-n.quit:
|
||||
return
|
||||
}
|
||||
|
||||
if pendingNegotiations > 0 {
|
||||
log.Debugf("Dispatching pending session " +
|
||||
"negotiation")
|
||||
|
||||
n.wg.Add(1)
|
||||
go n.negotiate()
|
||||
}
|
||||
|
||||
case <-n.quit:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// negotiate handles the process of iterating through potential tower candidates
|
||||
// and attempting to negotiate a new session until a successful negotiation
|
||||
// occurs. If the candidate iterator becomes exhausted because none were
|
||||
// successful, this method will back off exponentially up to the configured max
|
||||
// backoff. This method will continue trying until a negotiation is successful,
|
||||
// at which point the negotiated session is returned to the dispatcher via the
|
||||
// successfulNegotiations channel.
|
||||
//
|
||||
// NOTE: This method MUST be run as a goroutine.
|
||||
func (n *sessionNegotiator) negotiate() {
|
||||
defer n.wg.Done()
|
||||
|
||||
// On the first pass, initialize the backoff to our configured min
|
||||
// backoff.
|
||||
backoff := n.cfg.MinBackoff
|
||||
|
||||
retryWithBackoff:
|
||||
// If we are retrying, wait out the delay before continuing.
|
||||
if backoff > 0 {
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-n.quit:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Before attempting a bout of session negotiation, reset the candidate
|
||||
// iterator to ensure the results are fresh.
|
||||
n.cfg.Candidates.Reset()
|
||||
for {
|
||||
// Pull the next candidate from our list of addresses.
|
||||
tower, err := n.cfg.Candidates.Next()
|
||||
if err != nil {
|
||||
// We've run out of addresses, double and clamp backoff.
|
||||
backoff *= 2
|
||||
if backoff > n.cfg.MaxBackoff {
|
||||
backoff = n.cfg.MaxBackoff
|
||||
}
|
||||
|
||||
log.Debugf("Unable to get new tower candidate, "+
|
||||
"retrying after %v -- reason: %v", backoff, err)
|
||||
|
||||
goto retryWithBackoff
|
||||
}
|
||||
|
||||
log.Debugf("Attempting session negotiation with tower=%x",
|
||||
tower.IdentityKey.SerializeCompressed())
|
||||
|
||||
// We'll now attempt the CreateSession dance with the tower to
|
||||
// get a new session, trying all addresses if necessary.
|
||||
err = n.createSession(tower)
|
||||
if err != nil {
|
||||
log.Debugf("Session negotiation with tower=%x "+
|
||||
"failed, trying again -- reason: %v",
|
||||
tower.IdentityKey.SerializeCompressed(), err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Success.
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// createSession takes a tower and attempts to negotiate a session using any of
|
||||
// its stored addresses. This method returns after the first successful
|
||||
// negotiation, or after all addresses have failed with ErrFailedNegotiation. If
|
||||
// the tower has no addresses, ErrNoTowerAddrs is returned.
|
||||
func (n *sessionNegotiator) createSession(tower *wtdb.Tower) error {
|
||||
// If the tower has no addresses, there's nothing we can do.
|
||||
if len(tower.Addresses) == 0 {
|
||||
return ErrNoTowerAddrs
|
||||
}
|
||||
|
||||
// TODO(conner): create with hdkey at random index
|
||||
sessionPrivKey, err := btcec.NewPrivateKey(btcec.S256())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO(conner): write towerAddr+privkey
|
||||
|
||||
for _, lnAddr := range tower.LNAddrs() {
|
||||
err = n.tryAddress(sessionPrivKey, tower, lnAddr)
|
||||
switch {
|
||||
case err == ErrPermanentTowerFailure:
|
||||
// TODO(conner): report to iterator? can then be reset
|
||||
// with restart
|
||||
fallthrough
|
||||
|
||||
case err != nil:
|
||||
log.Debugf("Request for session negotiation with "+
|
||||
"tower=%s failed, trying again -- reason: "+
|
||||
"%v", lnAddr, err)
|
||||
continue
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return ErrFailedNegotiation
|
||||
}
|
||||
|
||||
// tryAddress executes a single create session dance using the given address.
|
||||
// The address should belong to the tower's set of addresses. This method only
|
||||
// returns nil if all steps succeed and the new session has been persisted,
|
||||
// and returns an error otherwise.
|
||||
func (n *sessionNegotiator) tryAddress(privKey *btcec.PrivateKey,
|
||||
tower *wtdb.Tower, lnAddr *lnwire.NetAddress) error {
|
||||
|
||||
// Connect to the tower address using our generated session key.
|
||||
conn, err := n.cfg.Dial(privKey, lnAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send local Init message.
|
||||
err = n.cfg.SendMessage(conn, n.localInit)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to send Init: %v", err)
|
||||
}
|
||||
|
||||
// Receive remote Init message.
|
||||
remoteMsg, err := n.cfg.ReadMessage(conn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to read Init: %v", err)
|
||||
}
|
||||
|
||||
// Check that returned message is wtwire.Init.
|
||||
remoteInit, ok := remoteMsg.(*wtwire.Init)
|
||||
if !ok {
|
||||
return fmt.Errorf("expected Init, got %T in reply", remoteMsg)
|
||||
}
|
||||
|
||||
// Verify the watchtower's remote Init message against our own.
|
||||
err = n.localInit.CheckRemoteInit(remoteInit, wtwire.FeatureNames)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
policy := n.cfg.Policy
|
||||
createSession := &wtwire.CreateSession{
|
||||
BlobType: policy.BlobType,
|
||||
MaxUpdates: policy.MaxUpdates,
|
||||
RewardBase: policy.RewardBase,
|
||||
RewardRate: policy.RewardRate,
|
||||
SweepFeeRate: policy.SweepFeeRate,
|
||||
}
|
||||
|
||||
// Send CreateSession message.
|
||||
err = n.cfg.SendMessage(conn, createSession)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to send CreateSession: %v", err)
|
||||
}
|
||||
|
||||
// Receive CreateSessionReply message.
|
||||
remoteMsg, err = n.cfg.ReadMessage(conn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to read CreateSessionReply: %v", err)
|
||||
}
|
||||
|
||||
// Check that returned message is wtwire.CreateSessionReply.
|
||||
createSessionReply, ok := remoteMsg.(*wtwire.CreateSessionReply)
|
||||
if !ok {
|
||||
return fmt.Errorf("expected CreateSessionReply, got %T in "+
|
||||
"reply", remoteMsg)
|
||||
}
|
||||
|
||||
switch createSessionReply.Code {
|
||||
case wtwire.CodeOK, wtwire.CreateSessionCodeAlreadyExists:
|
||||
|
||||
// TODO(conner): add last-applied to create session reply to
|
||||
// handle case where we lose state, session already exists, and
|
||||
// we want to possibly resume using the session
|
||||
|
||||
// TODO(conner): validate reward address
|
||||
rewardPkScript := createSessionReply.Data
|
||||
|
||||
sessionID := wtdb.NewSessionIDFromPubKey(
|
||||
privKey.PubKey(),
|
||||
)
|
||||
clientSession := &wtdb.ClientSession{
|
||||
TowerID: tower.ID,
|
||||
Tower: tower,
|
||||
SessionPrivKey: privKey, // remove after using HD keys
|
||||
ID: sessionID,
|
||||
Policy: n.cfg.Policy,
|
||||
SeqNum: 0,
|
||||
RewardPkScript: rewardPkScript,
|
||||
}
|
||||
|
||||
err = n.cfg.DB.CreateClientSession(clientSession)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to persist ClientSession: %v",
|
||||
err)
|
||||
}
|
||||
|
||||
log.Debugf("New session negotiated with %s, policy: %s",
|
||||
lnAddr, clientSession.Policy)
|
||||
|
||||
// We have a newly negotiated session, return it to the
|
||||
// dispatcher so that it can update how many outstanding
|
||||
// negotiation requests we have.
|
||||
select {
|
||||
case n.successfulNegotiations <- clientSession:
|
||||
return nil
|
||||
case <-n.quit:
|
||||
return ErrNegotiatorExiting
|
||||
}
|
||||
|
||||
// TODO(conner): handle error codes properly
|
||||
case wtwire.CreateSessionCodeRejectBlobType:
|
||||
return fmt.Errorf("tower rejected blob type: %v",
|
||||
policy.BlobType)
|
||||
|
||||
case wtwire.CreateSessionCodeRejectMaxUpdates:
|
||||
return fmt.Errorf("tower rejected max updates: %v",
|
||||
policy.MaxUpdates)
|
||||
|
||||
case wtwire.CreateSessionCodeRejectRewardRate:
|
||||
// The tower rejected the session because of the reward rate. If
|
||||
// we didn't request a reward session, we'll treat this as a
|
||||
// permanent tower failure.
|
||||
if !policy.BlobType.Has(blob.FlagReward) {
|
||||
return ErrPermanentTowerFailure
|
||||
}
|
||||
|
||||
return fmt.Errorf("tower rejected reward rate: %v",
|
||||
policy.RewardRate)
|
||||
|
||||
case wtwire.CreateSessionCodeRejectSweepFeeRate:
|
||||
return fmt.Errorf("tower rejected sweep fee rate: %v",
|
||||
policy.SweepFeeRate)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("received unhandled error code: %v",
|
||||
createSessionReply.Code)
|
||||
}
|
||||
}
|
688  watchtower/wtclient/session_queue.go  Normal file
@ -0,0 +1,688 @@
|
||||
package wtclient
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/lightningnetwork/lnd/input"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtserver"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtwire"
|
||||
)
|
||||
|
||||
// retryInterval is the default duration we will wait between attempting to
|
||||
// connect back out to a tower if the prior state update failed.
|
||||
const retryInterval = 2 * time.Second
|
||||
|
||||
// reserveStatus is an enum that signals how full a particular session is.
|
||||
type reserveStatus uint8
|
||||
|
||||
const (
|
||||
// reserveAvailable indicates that the session has space for at least
|
||||
// one more backup.
|
||||
reserveAvailable reserveStatus = iota
|
||||
|
||||
// reserveExhausted indicates that all slots in the session have been
|
||||
// allocated.
|
||||
reserveExhausted
|
||||
)
|
||||
|
||||
// sessionQueueConfig bundles the resources required by the sessionQueue to
|
||||
// perform its duties. All entries MUST be non-nil.
|
||||
type sessionQueueConfig struct {
|
||||
// ClientSession provides access to the negotiated session parameters
|
||||
// and updating its persistent storage.
|
||||
ClientSession *wtdb.ClientSession
|
||||
|
||||
// ChainHash identifies the chain for which the session's justice
|
||||
// transactions are targeted.
|
||||
ChainHash chainhash.Hash
|
||||
|
||||
// Dial allows the client to dial the tower using its public key and
|
||||
// net address.
|
||||
Dial func(*btcec.PrivateKey,
|
||||
*lnwire.NetAddress) (wtserver.Peer, error)
|
||||
|
||||
// SendMessage encodes, encrypts, and writes a message to the given peer.
|
||||
SendMessage func(wtserver.Peer, wtwire.Message) error
|
||||
|
||||
// ReadMessage receives, decrypts, and decodes a message from the given
|
||||
// peer.
|
||||
ReadMessage func(wtserver.Peer) (wtwire.Message, error)
|
||||
|
||||
// Signer facilitates signing of inputs, used to construct the witnesses
|
||||
// for justice transaction inputs.
|
||||
Signer input.Signer
|
||||
|
||||
// DB provides access to the client's stable storage.
|
||||
DB DB
|
||||
|
||||
// MinBackoff defines the initial backoff applied by the session
|
||||
// queue before reconnecting to the tower after a failed or partially
|
||||
// successful batch is sent. Subsequent backoff durations will grow
|
||||
// exponentially up until MaxBackoff.
|
||||
MinBackoff time.Duration
|
||||
|
||||
// MaxBackoff defines the maximum backoff applied by the session
|
||||
// queue before reconnecting to the tower after a failed or partially
|
||||
// successful batch is sent. If the exponential backoff produces a
|
||||
// timeout greater than this value, the backoff duration will be clamped
|
||||
// to MaxBackoff.
|
||||
MaxBackoff time.Duration
|
||||
}
|
||||
|
||||
// sessionQueue implements a reliable queue that will encrypt and send accepted
|
||||
// backups to the watchtower specified in the config's ClientSession. Calling
|
||||
// Quit will attempt to perform a clean shutdown by receiving an ACK from the
|
||||
// tower for all pending backups before exiting. The clean shutdown can be
|
||||
// aborted by using ForceQuit, which will attempt to shutdown the queue
|
||||
// immediately.
|
||||
type sessionQueue struct {
|
||||
started sync.Once
|
||||
stopped sync.Once
|
||||
forced sync.Once
|
||||
|
||||
cfg *sessionQueueConfig
|
||||
|
||||
commitQueue *list.List
|
||||
pendingQueue *list.List
|
||||
queueMtx sync.Mutex
|
||||
queueCond *sync.Cond
|
||||
|
||||
localInit *wtwire.Init
|
||||
towerAddr *lnwire.NetAddress
|
||||
|
||||
seqNum uint16
|
||||
|
||||
retryBackoff time.Duration
|
||||
|
||||
quit chan struct{}
|
||||
forceQuit chan struct{}
|
||||
shutdown chan struct{}
|
||||
}
|
||||
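// Illustrative sketch of the shutdown semantics described above: Stop waits
// for pending backups to be ACKed, while ForceQuit aborts that wait. A caller
// wanting a bounded shutdown might combine the two; the 30 second timeout is
// an assumption.
//
//	done := make(chan struct{})
//	go func() {
//		queue.Stop()
//		close(done)
//	}()
//	select {
//	case <-done:
//	case <-time.After(30 * time.Second):
//		queue.ForceQuit()
//	}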
|
||||
// newSessionQueue initializes a fresh sessionQueue.
|
||||
func newSessionQueue(cfg *sessionQueueConfig) *sessionQueue {
|
||||
localInit := wtwire.NewInitMessage(
|
||||
lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired),
|
||||
cfg.ChainHash,
|
||||
)
|
||||
|
||||
towerAddr := &lnwire.NetAddress{
|
||||
IdentityKey: cfg.ClientSession.Tower.IdentityKey,
|
||||
Address: cfg.ClientSession.Tower.Addresses[0],
|
||||
}
|
||||
|
||||
sq := &sessionQueue{
|
||||
cfg: cfg,
|
||||
commitQueue: list.New(),
|
||||
pendingQueue: list.New(),
|
||||
localInit: localInit,
|
||||
towerAddr: towerAddr,
|
||||
seqNum: cfg.ClientSession.SeqNum,
|
||||
retryBackoff: cfg.MinBackoff,
|
||||
quit: make(chan struct{}),
|
||||
forceQuit: make(chan struct{}),
|
||||
shutdown: make(chan struct{}),
|
||||
}
|
||||
sq.queueCond = sync.NewCond(&sq.queueMtx)
|
||||
|
||||
sq.restoreCommittedUpdates()
|
||||
|
||||
return sq
|
||||
}
|
||||
|
||||
// Start idempotently starts the sessionQueue so that it can begin accepting
|
||||
// backups.
|
||||
func (q *sessionQueue) Start() {
|
||||
q.started.Do(func() {
|
||||
// TODO(conner): load prior committed state updates from disk and
|
||||
// populate in queue.
|
||||
|
||||
go q.sessionManager()
|
||||
})
|
||||
}
|
||||
|
||||
// Stop idempotently stops the sessionQueue by initiating a clean shutdown that
|
||||
// will clear all pending tasks in the queue before returning to the caller.
|
||||
func (q *sessionQueue) Stop() {
|
||||
q.stopped.Do(func() {
|
||||
log.Debugf("Stopping session queue %s", q.ID())
|
||||
|
||||
close(q.quit)
|
||||
q.signalUntilShutdown()
|
||||
|
||||
// Skip log if we also force quit.
|
||||
select {
|
||||
case <-q.forceQuit:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
log.Debugf("Session queue %s successfully stopped", q.ID())
|
||||
})
|
||||
}
|
||||
|
||||
// ForceQuit idempotently aborts any clean shutdown in progress and returns to
|
||||
// the caller after all lingering goroutines have spun down.
|
||||
func (q *sessionQueue) ForceQuit() {
|
||||
q.forced.Do(func() {
|
||||
log.Infof("Force quitting session queue %s", q.ID())
|
||||
|
||||
close(q.forceQuit)
|
||||
q.signalUntilShutdown()
|
||||
|
||||
log.Infof("Session queue %s unclean shutdown complete", q.ID())
|
||||
})
|
||||
}
|
||||
|
||||
// ID returns the wtdb.SessionID for the queue, which can be used to uniquely
|
||||
// identify a particular queue.
|
||||
func (q *sessionQueue) ID() *wtdb.SessionID {
|
||||
return &q.cfg.ClientSession.ID
|
||||
}
|
||||
|
||||
// AcceptTask attempts to queue a backupTask for delivery to the sessionQueue's
|
||||
// tower. The task will only be accepted if the queue is not already
|
||||
// exhausted and the task is successfully bound to the ClientSession.
|
||||
func (q *sessionQueue) AcceptTask(task *backupTask) (reserveStatus, bool) {
|
||||
q.queueCond.L.Lock()
|
||||
|
||||
// Examine the current reserve status of the session queue.
|
||||
curStatus := q.reserveStatus()
|
||||
switch curStatus {
|
||||
|
||||
// The session queue is exhausted, and cannot accept the task because it
|
||||
// is full. Reject the task such that it can be tried against a
|
||||
// different session.
|
||||
case reserveExhausted:
|
||||
q.queueCond.L.Unlock()
|
||||
return curStatus, false
|
||||
|
||||
// The session queue is not exhausted. Compute the sweep and reward
|
||||
// outputs as a function of the session parameters. If the outputs are
|
||||
// dusty or uneconomical to back up, the task is rejected and will not be
|
||||
// tried again.
|
||||
//
|
||||
// TODO(conner): queue backups and retry with different session params.
|
||||
case reserveAvailable:
|
||||
err := task.bindSession(q.cfg.ClientSession)
|
||||
if err != nil {
|
||||
q.queueCond.L.Unlock()
|
||||
log.Debugf("SessionQueue %s rejected backup chanid=%s "+
|
||||
"commit-height=%d: %v", q.ID(), task.id.ChanID,
|
||||
task.id.CommitHeight, err)
|
||||
return curStatus, false
|
||||
}
|
||||
}
|
||||
|
||||
// The sweep and reward outputs satisfy the session's policy, queue the
|
||||
// task for final signing and delivery.
|
||||
q.pendingQueue.PushBack(task)
|
||||
|
||||
// Finally, compute the session's *new* reserve status. This will be
|
||||
// used by the client to determine if it can continue using this session
|
||||
// queue, or if it should negotiate a new one.
|
||||
newStatus := q.reserveStatus()
|
||||
q.queueCond.L.Unlock()
|
||||
|
||||
q.queueCond.Signal()
|
||||
|
||||
return newStatus, true
|
||||
}
|
||||
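// Illustrative sketch of how a caller reacts to AcceptTask's return values, as
// documented above; the queue, task, and negotiator identifiers are
// placeholders for the client's task pipeline.
//
//	status, accepted := queue.AcceptTask(task)
//	if !accepted {
//		// Rebind the task to a different (or newly negotiated) session.
//	}
//	if status == reserveExhausted {
//		// This session is full; request a replacement for future tasks.
//		negotiator.RequestSession()
//	}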
|
||||
// updateWithSeqNum stores a CommittedUpdate with its assigned sequence number.
|
||||
// This allows committed updates to be sorted after a restart, and added to the
|
||||
// commitQueue in the proper order for delivery.
|
||||
type updateWithSeqNum struct {
|
||||
seqNum uint16
|
||||
update *wtdb.CommittedUpdate
|
||||
}
|
||||
|
||||
// restoreCommittedUpdates processes any CommittedUpdates loaded on startup by
|
||||
// sorting them in ascending order of sequence numbers and adding them to the
|
||||
// commitQueue. These will be sent before any pending updates are processed.
|
||||
func (q *sessionQueue) restoreCommittedUpdates() {
|
||||
committedUpdates := q.cfg.ClientSession.CommittedUpdates
|
||||
|
||||
// Construct an unordered slice of all committed updates with their
|
||||
// assigned sequence numbers.
|
||||
sortedUpdates := make([]updateWithSeqNum, 0, len(committedUpdates))
|
||||
for seqNum, update := range committedUpdates {
|
||||
sortedUpdates = append(sortedUpdates, updateWithSeqNum{
|
||||
seqNum: seqNum,
|
||||
update: update,
|
||||
})
|
||||
}
|
||||
|
||||
// Sort the resulting slice by increasing sequence number.
|
||||
sort.Slice(sortedUpdates, func(i, j int) bool {
|
||||
return sortedUpdates[i].seqNum < sortedUpdates[j].seqNum
|
||||
})
|
||||
|
||||
// Finally, add the sorted, committed updates to the commitQueue. These
|
||||
// updates will be prioritized before any new tasks are assigned to the
|
||||
// sessionQueue. The queue will begin uploading any tasks in the
|
||||
// commitQueue as soon as it is started, e.g. during client
|
||||
// initialization when detecting that this session has unacked updates.
|
||||
for _, update := range sortedUpdates {
|
||||
q.commitQueue.PushBack(update)
|
||||
}
|
||||
}
|
||||
|
||||
// sessionManager is the primary event loop for the sessionQueue, and is
|
||||
// responsible for encrypting and sending accepted tasks to the tower.
|
||||
func (q *sessionQueue) sessionManager() {
|
||||
defer close(q.shutdown)
|
||||
|
||||
for {
|
||||
q.queueCond.L.Lock()
|
||||
for q.commitQueue.Len() == 0 &&
|
||||
q.pendingQueue.Len() == 0 {
|
||||
|
||||
q.queueCond.Wait()
|
||||
|
||||
select {
|
||||
case <-q.quit:
|
||||
if q.commitQueue.Len() == 0 &&
|
||||
q.pendingQueue.Len() == 0 {
|
||||
q.queueCond.L.Unlock()
|
||||
return
|
||||
}
|
||||
case <-q.forceQuit:
|
||||
q.queueCond.L.Unlock()
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
q.queueCond.L.Unlock()
|
||||
|
||||
// Exit immediately if a force quit has been requested. If
|
||||
// either of the queues still has state updates to send to the
|
||||
// tower, we may never exit in the above case if we are unable
|
||||
// to reach the tower for some reason.
|
||||
select {
|
||||
case <-q.forceQuit:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Initiate a new connection to the watchtower and attempt to
|
||||
// drain all pending tasks.
|
||||
q.drainBackups()
|
||||
}
|
||||
}
|
||||
|
||||
// drainBackups attempts to send all pending updates in the queue to the tower.
|
||||
func (q *sessionQueue) drainBackups() {
|
||||
// First, check that we are able to dial this session's tower.
|
||||
conn, err := q.cfg.Dial(q.cfg.ClientSession.SessionPrivKey, q.towerAddr)
|
||||
if err != nil {
|
||||
log.Errorf("Unable to dial watchtower at %v: %v",
|
||||
q.towerAddr, err)
|
||||
|
||||
q.increaseBackoff()
|
||||
select {
|
||||
case <-time.After(q.retryBackoff):
|
||||
case <-q.forceQuit:
|
||||
}
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Begin draining the queue of pending state updates. Before the first
|
||||
// update is sent, we will precede it with an Init message. If the first update
|
||||
// is successful, subsequent updates can be streamed without sending an
|
||||
// Init.
|
||||
for sendInit := true; ; sendInit = false {
|
||||
// Generate the next state update to upload to the tower. This
|
||||
// method will first proceed in dequeueing committed updates
|
||||
// before attempting to dequeue any pending updates.
|
||||
stateUpdate, isPending, err := q.nextStateUpdate()
|
||||
if err != nil {
|
||||
log.Errorf("Unable to get next state update: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Now, send the state update to the tower and wait for a reply.
|
||||
err = q.sendStateUpdate(
|
||||
conn, stateUpdate, q.localInit, sendInit, isPending,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("Unable to send state update: %v", err)
|
||||
|
||||
q.increaseBackoff()
|
||||
select {
|
||||
case <-time.After(q.retryBackoff):
|
||||
case <-q.forceQuit:
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// If the last task was backed up successfully, we'll exit and
|
||||
// continue once more tasks are added to the queue. We'll also
|
||||
// clear any accumulated backoff as this batch was able to be
|
||||
// sent reliably.
|
||||
if stateUpdate.IsComplete == 1 {
|
||||
q.resetBackoff()
|
||||
return
|
||||
}
|
||||
|
||||
// Always apply a small delay between sends, which makes the
|
||||
// unit tests more reliable. If we were requested to back off,
|
||||
// then we will do so.
|
||||
select {
|
||||
case <-time.After(time.Millisecond):
|
||||
case <-q.forceQuit:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// nextStateUpdate returns the next wtwire.StateUpdate to upload to the tower.
|
||||
// If any committed updates are present, this method will reconstruct the state
|
||||
// update from the committed update using the current last applied value found
|
||||
// in the database. Otherwise, it will select the next pending update, craft the
|
||||
// payload, and commit an update before returning the state update to send. The
|
||||
// boolean value in the response is true if the state update is taken from the
|
||||
// pending queue, allowing the caller to remove the update from either the
|
||||
// commit or pending queue if the update is successfully acked.
|
||||
func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) {
|
||||
var (
|
||||
seqNum uint16
|
||||
update *wtdb.CommittedUpdate
|
||||
isLast bool
|
||||
isPending bool
|
||||
)
|
||||
|
||||
q.queueCond.L.Lock()
|
||||
switch {
|
||||
|
||||
// If the commit queue is non-empty, parse the next committed update.
|
||||
case q.commitQueue.Len() > 0:
|
||||
next := q.commitQueue.Front()
|
||||
updateWithSeq := next.Value.(updateWithSeqNum)
|
||||
|
||||
seqNum = updateWithSeq.seqNum
|
||||
update = updateWithSeq.update
|
||||
|
||||
// If this is the last item in the commit queue and no items
|
||||
// exist in the pending queue, we will use the IsComplete flag
|
||||
// in the StateUpdate to signal that the tower can release the
|
||||
// connection after replying to free up resources.
|
||||
isLast = q.commitQueue.Len() == 1 && q.pendingQueue.Len() == 0
|
||||
q.queueCond.L.Unlock()
|
||||
|
||||
log.Debugf("Reprocessing committed state update for "+
|
||||
"session=%s seqnum=%d", q.ID(), seqNum)
|
||||
|
||||
// Otherwise, craft and commit the next update from the pending queue.
|
||||
default:
|
||||
isPending = true
|
||||
|
||||
// Determine the current sequence number to apply for this
|
||||
// pending update.
|
||||
seqNum = q.seqNum + 1
|
||||
|
||||
// Obtain the next task from the queue.
|
||||
next := q.pendingQueue.Front()
|
||||
task := next.Value.(*backupTask)
|
||||
|
||||
// If this is the last item in the pending queue, we will use
|
||||
// the IsComplete flag in the StateUpdate to signal that the
|
||||
// tower can release the connection after replying to free up
|
||||
// resources.
|
||||
isLast = q.pendingQueue.Len() == 1
|
||||
q.queueCond.L.Unlock()
|
||||
|
||||
hint, encBlob, err := task.craftSessionPayload(q.cfg.Signer)
|
||||
if err != nil {
|
||||
// TODO(conner): mark will not send
|
||||
return nil, false, fmt.Errorf("unable to craft "+
|
||||
"session payload: %v", err)
|
||||
}
|
||||
// TODO(conner): special case other obscure errors
|
||||
|
||||
update = &wtdb.CommittedUpdate{
|
||||
BackupID: task.id,
|
||||
Hint: hint,
|
||||
EncryptedBlob: encBlob,
|
||||
}
|
||||
|
||||
log.Debugf("Committing state update for session=%s seqnum=%d",
|
||||
q.ID(), seqNum)
|
||||
}
|
||||
|
||||
// Before sending the task to the tower, commit the state update
|
||||
// to disk using the assigned sequence number. If this task has already
|
||||
// been committed, the call will succeed and only be used for the
|
||||
// purpose of obtaining the last applied value to send to the tower.
|
||||
//
|
||||
// This step ensures that if we crash before receiving an ack that we
|
||||
// will retransmit the same update. If the tower successfully received
|
||||
// the update from before, it will reply with an ACK regardless of what
|
||||
// we send the next time. This step ensures that we reliably send the
// same update for a given sequence number, preventing us from thinking
// we backed up a state when we instead backed up another.
|
||||
lastApplied, err := q.cfg.DB.CommitUpdate(q.ID(), seqNum, update)
|
||||
if err != nil {
|
||||
// TODO(conner): mark failed/reschedule
|
||||
return nil, false, fmt.Errorf("unable to commit state update "+
|
||||
"for session=%s seqnum=%d: %v", q.ID(), seqNum, err)
|
||||
}
|
||||
|
||||
stateUpdate := &wtwire.StateUpdate{
|
||||
SeqNum: seqNum,
|
||||
LastApplied: lastApplied,
|
||||
Hint: update.Hint,
|
||||
EncryptedBlob: update.EncryptedBlob,
|
||||
}
|
||||
|
||||
// Set the IsComplete flag if this is the last queued item.
|
||||
if isLast {
|
||||
stateUpdate.IsComplete = 1
|
||||
}
|
||||
|
||||
return stateUpdate, isPending, nil
|
||||
}
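The commit-before-send step described in the comments above boils down to an idempotent commit: re-committing the same (seqnum, hint) pair succeeds, so a retransmit after a crash always carries an identical payload. The following self-contained sketch illustrates only that property; miniDB and its Commit method are hypothetical stand-ins, not part of wtdb.

```go
package main

import (
	"errors"
	"fmt"
)

// miniDB mimics the idempotent commit semantics: a seqnum may be committed
// twice only if the breach hint is identical.
type miniDB struct {
	committed map[uint16]string // seqnum -> breach hint
}

var errAlreadyCommitted = errors.New("seqnum committed with different hint")

func (db *miniDB) Commit(seqNum uint16, hint string) error {
	if prev, ok := db.committed[seqNum]; ok {
		if prev == hint {
			return nil // safe retransmit after a crash
		}
		return errAlreadyCommitted
	}
	db.committed[seqNum] = hint
	return nil
}

func main() {
	db := &miniDB{committed: make(map[uint16]string)}
	fmt.Println(db.Commit(1, "hint-a")) // <nil>
	fmt.Println(db.Commit(1, "hint-a")) // <nil>, identical retransmit
	fmt.Println(db.Commit(1, "hint-b")) // error: would record a different state
}
```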
|
||||
|
||||
// sendStateUpdate sends a wtwire.StateUpdate to the watchtower and processes
|
||||
// the ACK before returning. If sendInit is true, this method will first send
|
||||
// the localInit message and verify that the tower supports our required feature
// bits. An error is returned if any part of the send fails.
|
||||
func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer,
|
||||
stateUpdate *wtwire.StateUpdate, localInit *wtwire.Init,
|
||||
sendInit, isPending bool) error {
|
||||
|
||||
// If this is the first message being sent to the tower, we must send an
|
||||
// Init message to establish that the server supports the features we
|
||||
// require.
|
||||
if sendInit {
|
||||
// Send Init to tower.
|
||||
err := q.cfg.SendMessage(conn, q.localInit)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Receive Init from tower.
|
||||
remoteMsg, err := q.cfg.ReadMessage(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
remoteInit, ok := remoteMsg.(*wtwire.Init)
|
||||
if !ok {
|
||||
return fmt.Errorf("watchtower responded with %T to "+
|
||||
"Init", remoteMsg)
|
||||
}
|
||||
|
||||
// Validate Init.
|
||||
err = q.localInit.CheckRemoteInit(
|
||||
remoteInit, wtwire.FeatureNames,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Send StateUpdate to tower.
|
||||
err := q.cfg.SendMessage(conn, stateUpdate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Receive StateUpdate from tower.
|
||||
remoteMsg, err := q.cfg.ReadMessage(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stateUpdateReply, ok := remoteMsg.(*wtwire.StateUpdateReply)
|
||||
if !ok {
|
||||
return fmt.Errorf("watchtower responded with %T to StateUpdate",
|
||||
remoteMsg)
|
||||
}
|
||||
|
||||
// Process the reply from the tower.
|
||||
switch stateUpdateReply.Code {
|
||||
|
||||
// The tower reported a successful update, validate the response and
|
||||
// record the last applied returned.
|
||||
case wtwire.CodeOK:
|
||||
|
||||
// TODO(conner): handle other error cases properly, ban towers, etc.
|
||||
default:
|
||||
err := fmt.Errorf("received error code %s in "+
|
||||
"StateUpdateReply from tower=%x session=%s",
|
||||
stateUpdateReply.Code,
|
||||
conn.RemotePub().SerializeCompressed(), q.ID())
|
||||
log.Warnf("Unable to upload state update: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
lastApplied := stateUpdateReply.LastApplied
|
||||
err = q.cfg.DB.AckUpdate(q.ID(), stateUpdate.SeqNum, lastApplied)
|
||||
switch {
|
||||
case err == wtdb.ErrUnallocatedLastApplied:
|
||||
// TODO(conner): borked watchtower
|
||||
err = fmt.Errorf("unable to ack update=%d session=%s: %v",
|
||||
stateUpdate.SeqNum, q.ID(), err)
|
||||
log.Errorf("Failed to ack update: %v", err)
|
||||
return err
|
||||
|
||||
case err == wtdb.ErrLastAppliedReversion:
|
||||
// TODO(conner): borked watchtower
|
||||
err = fmt.Errorf("unable to ack update=%d session=%s: %v",
|
||||
stateUpdate.SeqNum, q.ID(), err)
|
||||
log.Errorf("Failed to ack update: %v", err)
|
||||
return err
|
||||
|
||||
case err != nil:
|
||||
err = fmt.Errorf("unable to ack update=%d session=%s: %v",
|
||||
stateUpdate.SeqNum, q.ID(), err)
|
||||
log.Errorf("Failed to ack update: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("Removing update session=%s seqnum=%d is_pending=%v "+
|
||||
"from memory", q.ID(), stateUpdate.SeqNum, isPending)
|
||||
|
||||
q.queueCond.L.Lock()
|
||||
if isPending {
|
||||
// If a pending update was successfully sent, increment the
|
||||
// sequence number and remove the item from the queue. This
|
||||
// ensures the total number of backups in the session remains
|
||||
// unchanged, which maintains the external view of the session's
|
||||
// reserve status.
|
||||
q.seqNum++
|
||||
q.pendingQueue.Remove(q.pendingQueue.Front())
|
||||
} else {
|
||||
// Otherwise, simply remove the update from the committed queue.
|
||||
// This has no effect on the queue's reserve status since the
|
||||
// update had already been committed.
|
||||
q.commitQueue.Remove(q.commitQueue.Front())
|
||||
}
|
||||
q.queueCond.L.Unlock()
|
||||
|
||||
return nil
|
||||
}
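For readers skimming the handshake logic above, the flow is: optionally exchange Init messages, send the StateUpdate, then branch on the reply code before acking locally. Below is a minimal, self-contained sketch of that ordering; every type here (fakeConn, initMsg, stateUpdate, reply) is invented for illustration and does not mirror the wtwire API.

```go
package main

import (
	"errors"
	"fmt"
)

type initMsg struct{ features uint32 }
type stateUpdate struct{ seqNum, lastApplied uint16 }
type reply struct {
	code        uint16
	lastApplied uint16
}

const codeOK = 0

// fakeConn answers every message locally so the example is runnable.
type fakeConn struct{ towerLastApplied uint16 }

func (c *fakeConn) exchangeInit(local initMsg) (initMsg, error) {
	return initMsg{features: local.features}, nil
}

func (c *fakeConn) exchangeUpdate(u stateUpdate) (reply, error) {
	c.towerLastApplied = u.seqNum
	return reply{code: codeOK, lastApplied: c.towerLastApplied}, nil
}

// sendStateUpdate mirrors the ordering above: Init handshake first (if
// requested), then the state update, then inspect the reply code.
func sendStateUpdate(c *fakeConn, u stateUpdate, sendInit bool) (uint16, error) {
	if sendInit {
		if _, err := c.exchangeInit(initMsg{features: 1}); err != nil {
			return 0, err
		}
	}
	r, err := c.exchangeUpdate(u)
	if err != nil {
		return 0, err
	}
	if r.code != codeOK {
		return 0, errors.New("tower rejected update")
	}
	return r.lastApplied, nil
}

func main() {
	c := &fakeConn{}
	last, err := sendStateUpdate(c, stateUpdate{seqNum: 1}, true)
	fmt.Println(last, err) // 1 <nil>
}
```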
|
||||
|
||||
// reserveStatus returns a reserveStatus indicating whether or not the
|
||||
// sessionQueue can accept another task. reserveAvailable is returned when a
|
||||
// task can be accepted, and reserveExhausted is returned if all slots in
|
||||
// the session have been allocated.
|
||||
//
|
||||
// NOTE: This method MUST be called with queueCond's exclusive lock held.
|
||||
func (q *sessionQueue) reserveStatus() reserveStatus {
|
||||
numPending := uint32(q.pendingQueue.Len())
|
||||
maxUpdates := uint32(q.cfg.ClientSession.Policy.MaxUpdates)
|
||||
|
||||
log.Debugf("SessionQueue %s reserveStatus seqnum=%d pending=%d "+
|
||||
"max-updates=%d", q.ID(), q.seqNum, numPending, maxUpdates)
|
||||
|
||||
if uint32(q.seqNum)+numPending < maxUpdates {
|
||||
return reserveAvailable
|
||||
}
|
||||
|
||||
return reserveExhausted
|
||||
|
||||
}
|
||||
|
||||
// resetBackoff resets the connection backoff to the minimum configured backoff.
|
||||
func (q *sessionQueue) resetBackoff() {
|
||||
q.retryBackoff = q.cfg.MinBackoff
|
||||
}
|
||||
|
||||
// increaseBackoff doubles the current connection backoff, clamping to the
|
||||
// configured maximum backoff if it would exceed the limit.
|
||||
func (q *sessionQueue) increaseBackoff() {
|
||||
q.retryBackoff *= 2
|
||||
if q.retryBackoff > q.cfg.MaxBackoff {
|
||||
q.retryBackoff = q.cfg.MaxBackoff
|
||||
}
|
||||
}
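The backoff schedule above is a plain doubling clamped at a ceiling. A tiny standalone sketch, using illustrative bounds rather than the client's configured values:

```go
package main

import (
	"fmt"
	"time"
)

// Illustrative bounds only, not the client's defaults.
const (
	minBackoff = time.Second
	maxBackoff = time.Minute
)

// increase doubles the current backoff, clamping at maxBackoff.
func increase(cur time.Duration) time.Duration {
	cur *= 2
	if cur > maxBackoff {
		cur = maxBackoff
	}
	return cur
}

func main() {
	b := minBackoff
	for i := 0; i < 8; i++ {
		fmt.Println(b) // 1s, 2s, 4s, ... clamped at 1m0s
		b = increase(b)
	}
}
```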
|
||||
|
||||
// signalUntilShutdown strobes the sessionQueue's condition variable until the
|
||||
// main event loop exits.
|
||||
func (q *sessionQueue) signalUntilShutdown() {
|
||||
for {
|
||||
select {
|
||||
case <-time.After(time.Millisecond):
|
||||
q.queueCond.Signal()
|
||||
case <-q.shutdown:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sessionQueueSet maintains a mapping of SessionIDs to their corresponding
|
||||
// sessionQueue.
|
||||
type sessionQueueSet map[wtdb.SessionID]*sessionQueue
|
||||
|
||||
// Add inserts a sessionQueue into the sessionQueueSet.
|
||||
func (s *sessionQueueSet) Add(sessionQueue *sessionQueue) {
|
||||
(*s)[*sessionQueue.ID()] = sessionQueue
|
||||
}
|
||||
|
||||
// ApplyAndWait executes the nil-adic function returned from getApply for each
|
||||
// sessionQueue in the set in parallel, then waits for all of them to finish
|
||||
// before returning to the caller.
|
||||
func (s *sessionQueueSet) ApplyAndWait(getApply func(*sessionQueue) func()) {
|
||||
var wg sync.WaitGroup
|
||||
for _, sessionq := range *s {
|
||||
wg.Add(1)
|
||||
go func(sq *sessionQueue) {
|
||||
defer wg.Done()
|
||||
getApply(sq)()
|
||||
}(sessionq)
|
||||
}
|
||||
wg.Wait()
|
||||
}
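ApplyAndWait is the usual fan-out/fan-in idiom: run one closure per element in parallel, then block until all return. A self-contained version of the same pattern, with a placeholder worker type standing in for sessionQueue:

```go
package main

import (
	"fmt"
	"sync"
)

type worker struct{ id int }

func (w *worker) Stop() { fmt.Printf("worker %d stopped\n", w.id) }

// applyAndWait runs getApply(w)() for each worker in parallel and waits for
// all of them to finish before returning.
func applyAndWait(workers []*worker, getApply func(*worker) func()) {
	var wg sync.WaitGroup
	for _, w := range workers {
		wg.Add(1)
		go func(w *worker) {
			defer wg.Done()
			getApply(w)()
		}(w)
	}
	wg.Wait()
}

func main() {
	ws := []*worker{{id: 1}, {id: 2}, {id: 3}}
	applyAndWait(ws, func(w *worker) func() { return w.Stop })
	fmt.Println("all workers stopped")
}
```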
watchtower/wtclient/stats.go (new file, 51 lines)
@ -0,0 +1,51 @@
|
||||
package wtclient
|
||||
|
||||
import "fmt"
|
||||
|
||||
type clientStats struct {
|
||||
numTasksReceived int
|
||||
numTasksAccepted int
|
||||
numTasksIneligible int
|
||||
numSessionsAcquired int
|
||||
numSessionsExhausted int
|
||||
}
|
||||
|
||||
// taskReceived increments the number of backup requests the client has received
|
||||
// from active channels.
|
||||
func (s *clientStats) taskReceived() {
|
||||
s.numTasksReceived++
|
||||
}
|
||||
|
||||
// taskAccepted increments the number of tasks that have been assigned to active
|
||||
// session queues, and are awaiting upload to a tower.
|
||||
func (s *clientStats) taskAccepted() {
|
||||
s.numTasksAccepted++
|
||||
}
|
||||
|
||||
// taskIneligible increments the number of tasks that were unable to satisfy the
|
||||
// active session queue's policy. These can potentially be retried later, but
|
||||
// typically this means that the balance created dust outputs, so it may not be
|
||||
// worth backing up at all.
|
||||
func (s *clientStats) taskIneligible() {
|
||||
s.numTasksIneligible++
|
||||
}
|
||||
|
||||
// sessionAcquired increments the number of sessions that have been successfully
|
||||
// negotiated by the client during this execution.
|
||||
func (s *clientStats) sessionAcquired() {
|
||||
s.numSessionsAcquired++
|
||||
}
|
||||
|
||||
// sessionExhausted increments the number of sessions that have become full as a
|
||||
// result of accepting backup tasks.
|
||||
func (s *clientStats) sessionExhausted() {
|
||||
s.numSessionsExhausted++
|
||||
}
|
||||
|
||||
// String returns a human readable summary of the client's metrics.
|
||||
func (s clientStats) String() string {
|
||||
return fmt.Sprintf("tasks(received=%d accepted=%d ineligible=%d) "+
|
||||
"sessions(acquired=%d exhausted=%d)", s.numTasksReceived,
|
||||
s.numTasksAccepted, s.numTasksIneligible, s.numSessionsAcquired,
|
||||
s.numSessionsExhausted)
|
||||
}
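Counter structs like clientStats are typically bumped on each event and logged via their String method. A standalone sketch of that usage, with a locally defined stand-in type rather than the package's clientStats:

```go
package main

import "fmt"

// stats is a local stand-in for a simple event-counter struct.
type stats struct {
	received, accepted, ineligible int
}

func (s stats) String() string {
	return fmt.Sprintf("tasks(received=%d accepted=%d ineligible=%d)",
		s.received, s.accepted, s.ineligible)
}

func main() {
	var s stats
	for i := 0; i < 5; i++ {
		s.received++
		if i%2 == 0 {
			s.accepted++
		} else {
			s.ineligible++
		}
	}
	fmt.Println(s) // tasks(received=5 accepted=3 ineligible=2)
}
```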
watchtower/wtclient/task_pipeline.go (new file, 185 lines)
@ -0,0 +1,185 @@
|
||||
package wtclient
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// taskPipeline implements a reliable, in-order queue that ensures its queue
// is fully drained before exiting. Stopping the taskPipeline prevents the pipeline
|
||||
// from accepting any further tasks, and will cause the pipeline to exit after
|
||||
// all updates have been delivered to the downstream receiver. If this process
|
||||
// hangs and is unable to make progress, users can optionally call ForceQuit to
|
||||
// abandon the reliable draining of the queue in order to permit shutdown.
|
||||
type taskPipeline struct {
|
||||
started sync.Once
|
||||
stopped sync.Once
|
||||
forced sync.Once
|
||||
|
||||
queueMtx sync.Mutex
|
||||
queueCond *sync.Cond
|
||||
queue *list.List
|
||||
|
||||
newBackupTasks chan *backupTask
|
||||
|
||||
quit chan struct{}
|
||||
forceQuit chan struct{}
|
||||
shutdown chan struct{}
|
||||
}
|
||||
|
||||
// newTaskPipeline initializes a new taskPipeline.
|
||||
func newTaskPipeline() *taskPipeline {
|
||||
rq := &taskPipeline{
|
||||
queue: list.New(),
|
||||
newBackupTasks: make(chan *backupTask),
|
||||
quit: make(chan struct{}),
|
||||
forceQuit: make(chan struct{}),
|
||||
shutdown: make(chan struct{}),
|
||||
}
|
||||
rq.queueCond = sync.NewCond(&rq.queueMtx)
|
||||
|
||||
return rq
|
||||
}
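The lifecycle described above is: start the pipeline, enqueue tasks, consume them from the exposed channel, then stop and let the consumer drain. The simplified, channel-only sketch below captures that flow; the pipe type is a stand-in, not the package's taskPipeline.

```go
package main

import (
	"fmt"
	"sync"
)

// pipe is a simplified stand-in for a task pipeline: a buffered channel plus
// a WaitGroup to wait out the consumer on shutdown.
type pipe struct {
	tasks chan string
	wg    sync.WaitGroup
}

func newPipe() *pipe { return &pipe{tasks: make(chan string, 8)} }

func (p *pipe) Queue(t string) { p.tasks <- t }

// Stop closes the task channel so the consumer drains what is left and exits.
func (p *pipe) Stop() {
	close(p.tasks)
	p.wg.Wait()
}

func main() {
	p := newPipe()

	p.wg.Add(1)
	go func() {
		defer p.wg.Done()
		for t := range p.tasks { // exits once Stop closes the channel
			fmt.Println("backing up", t)
		}
	}()

	p.Queue("chan-1/height-42")
	p.Queue("chan-2/height-7")
	p.Stop()
}
```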
|
||||
|
||||
// Start spins up the taskPipeline, making it eligible to begin receiving backup
|
||||
// tasks and deliver them to the receiver of NewBackupTasks.
|
||||
func (q *taskPipeline) Start() {
|
||||
q.started.Do(func() {
|
||||
go q.queueManager()
|
||||
})
|
||||
}
|
||||
|
||||
// Stop begins a graceful shutdown of the taskPipeline. This method returns once
|
||||
// all backupTasks have been delivered via NewBackupTasks, or a ForceQuit causes
|
||||
// the delivery of pending tasks to be interrupted.
|
||||
func (q *taskPipeline) Stop() {
|
||||
q.stopped.Do(func() {
|
||||
log.Debugf("Stopping task pipeline")
|
||||
|
||||
close(q.quit)
|
||||
q.signalUntilShutdown()
|
||||
|
||||
// Skip log if we also force quit.
|
||||
select {
|
||||
case <-q.forceQuit:
|
||||
default:
|
||||
log.Debugf("Task pipeline stopped successfully")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ForceQuit signals the taskPipeline to immediately exit, dropping any
|
||||
// backupTasks that have not been delivered via NewBackupTasks.
|
||||
func (q *taskPipeline) ForceQuit() {
|
||||
q.forced.Do(func() {
|
||||
log.Infof("Force quitting task pipeline")
|
||||
|
||||
close(q.forceQuit)
|
||||
q.signalUntilShutdown()
|
||||
|
||||
log.Infof("Task pipeline unclean shutdown complete")
|
||||
})
|
||||
}
|
||||
|
||||
// NewBackupTasks returns a read-only channel of enqueued backupTasks. The
// channel will be closed after a call to Stop once all pending tasks have been
// delivered, or if ForceQuit is called before the pending entries
|
||||
// have been drained.
|
||||
func (q *taskPipeline) NewBackupTasks() <-chan *backupTask {
|
||||
return q.newBackupTasks
|
||||
}
|
||||
|
||||
// QueueBackupTask enqueues a backupTask for reliable delivery to the consumer
|
||||
// of NewBackupTasks. If the taskPipeline is shutting down, ErrClientExiting is
|
||||
// returned. Otherwise, if QueueBackupTask returns nil it is guaranteed to be
|
||||
// delivered via NewBackupTasks unless ForceQuit is called before completion.
|
||||
func (q *taskPipeline) QueueBackupTask(task *backupTask) error {
|
||||
q.queueCond.L.Lock()
|
||||
select {
|
||||
|
||||
// Reject new tasks after quit has been signaled.
|
||||
case <-q.quit:
|
||||
q.queueCond.L.Unlock()
|
||||
return ErrClientExiting
|
||||
|
||||
// Reject new tasks after force quit has been signaled.
|
||||
case <-q.forceQuit:
|
||||
q.queueCond.L.Unlock()
|
||||
return ErrClientExiting
|
||||
|
||||
default:
|
||||
}
|
||||
|
||||
// Queue the new task and signal the queue's condition variable to wake up
|
||||
// the queueManager for processing.
|
||||
q.queue.PushBack(task)
|
||||
q.queueCond.L.Unlock()
|
||||
|
||||
q.queueCond.Signal()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// queueManager processes all incoming backup requests that get added via
// QueueBackupTask. The manager will exit once the pipeline is stopped and the
// queue has been fully drained, or immediately if a force quit is requested.
|
||||
//
|
||||
// NOTE: This method MUST be run as a goroutine.
|
||||
func (q *taskPipeline) queueManager() {
|
||||
defer close(q.shutdown)
|
||||
defer close(q.newBackupTasks)
|
||||
|
||||
for {
|
||||
q.queueCond.L.Lock()
|
||||
for q.queue.Front() == nil {
|
||||
q.queueCond.Wait()
|
||||
|
||||
select {
|
||||
case <-q.quit:
|
||||
// Exit only after the queue has been fully drained.
|
||||
if q.queue.Len() == 0 {
|
||||
q.queueCond.L.Unlock()
|
||||
log.Debugf("Revoked state pipeline flushed.")
|
||||
return
|
||||
}
|
||||
|
||||
case <-q.forceQuit:
|
||||
q.queueCond.L.Unlock()
|
||||
log.Debugf("Revoked state pipeline force quit.")
|
||||
return
|
||||
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Pop the first element from the queue.
|
||||
e := q.queue.Front()
|
||||
task := q.queue.Remove(e).(*backupTask)
|
||||
q.queueCond.L.Unlock()
|
||||
|
||||
select {
|
||||
|
||||
// Backup task submitted to dispatcher. We don't select on quit to
|
||||
// ensure that we still drain tasks while shutting down.
|
||||
case q.newBackupTasks <- task:
|
||||
|
||||
// Force quit, return immediately to allow the client to exit.
|
||||
case <-q.forceQuit:
|
||||
log.Debugf("Revoked state pipeline force quit.")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
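The loop above is the classic condition-variable drain pattern: wait while the queue is empty, re-check the quit signal after every wake-up, and only exit once the queue has emptied. A self-contained sketch of the same pattern, including the signal-until-shutdown nudge used by the caller:

```go
package main

import (
	"container/list"
	"fmt"
	"sync"
	"time"
)

func main() {
	var (
		mu    sync.Mutex
		cond  = sync.NewCond(&mu)
		queue = list.New()
		quit  = make(chan struct{})
		done  = make(chan struct{})
	)

	// Manager: delivers queued items, exits only when quit is signaled and
	// the queue is fully drained.
	go func() {
		defer close(done)
		for {
			mu.Lock()
			for queue.Front() == nil {
				cond.Wait()
				select {
				case <-quit:
					if queue.Len() == 0 {
						mu.Unlock()
						return
					}
				default:
				}
			}
			e := queue.Front()
			task := queue.Remove(e).(string)
			mu.Unlock()

			fmt.Println("delivering", task)
		}
	}()

	mu.Lock()
	queue.PushBack("task-1")
	mu.Unlock()
	cond.Signal()

	time.Sleep(10 * time.Millisecond)
	close(quit)

	// Keep signalling until the manager notices quit and exits.
	for {
		select {
		case <-done:
			fmt.Println("drained and shut down")
			return
		case <-time.After(time.Millisecond):
			cond.Signal()
		}
	}
}
```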
|
||||
|
||||
// signalUntilShutdown strobes the queue's condition variable to ensure the
|
||||
// queueManager reliably unblocks to check for the exit condition.
|
||||
func (q *taskPipeline) signalUntilShutdown() {
|
||||
for {
|
||||
select {
|
||||
case <-time.After(time.Millisecond):
|
||||
q.queueCond.Signal()
|
||||
case <-q.shutdown:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
watchtower/wtdb/client_session.go (new file, 110 lines)
@ -0,0 +1,110 @@
|
||||
package wtdb
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrClientSessionNotFound signals that the requested client session
|
||||
// was not found in the database.
|
||||
ErrClientSessionNotFound = errors.New("client session not found")
|
||||
|
||||
// ErrUpdateAlreadyCommitted signals that the chosen sequence number has
|
||||
// already been committed to an update with a different breach hint.
|
||||
ErrUpdateAlreadyCommitted = errors.New("update already committed")
|
||||
|
||||
// ErrCommitUnorderedUpdate signals the client tried to commit a
|
||||
// sequence number other than the next unallocated sequence number.
|
||||
ErrCommitUnorderedUpdate = errors.New("update seqnum not monotonic")
|
||||
|
||||
// ErrCommittedUpdateNotFound signals that the tower tried to ACK a
|
||||
// sequence number that has not yet been allocated by the client.
|
||||
ErrCommittedUpdateNotFound = errors.New("committed update not found")
|
||||
|
||||
// ErrUnallocatedLastApplied signals that the tower tried to provide a
|
||||
// LastApplied value greater than any allocated sequence number.
|
||||
ErrUnallocatedLastApplied = errors.New("tower echoed last applied " +
|
||||
"greater than allocated seqnum")
|
||||
)
|
||||
|
||||
// ClientSession encapsulates a SessionInfo returned from a successful
|
||||
// session negotiation, and also records the tower and ephemeral secret used for
|
||||
// communicating with the tower.
|
||||
type ClientSession struct {
|
||||
// ID is the client's public key used when authenticating with the
|
||||
// tower.
|
||||
ID SessionID
|
||||
|
||||
// SeqNum is the next unallocated sequence number that can be sent to
|
||||
// the tower.
|
||||
SeqNum uint16
|
||||
|
||||
// TowerLastApplied is the last LastApplied value the tower has echoed back.
|
||||
TowerLastApplied uint16
|
||||
|
||||
// TowerID is the unique, db-assigned identifier that references the
|
||||
// Tower with which the session is negotiated.
|
||||
TowerID uint64
|
||||
|
||||
// Tower holds the pubkey and address of the watchtower.
|
||||
//
|
||||
// NOTE: This value is not serialized. It is recovered by looking up the
|
||||
// tower with TowerID.
|
||||
Tower *Tower
|
||||
|
||||
// SessionKeyDesc is the key descriptor used to derive the client's
|
||||
// session key so that it can authenticate with the tower to update its
|
||||
// session.
|
||||
SessionKeyDesc keychain.KeyLocator
|
||||
|
||||
// SessionPrivKey is the ephemeral secret key used to connect to the
|
||||
// watchtower.
|
||||
// TODO(conner): remove after HD keys
|
||||
SessionPrivKey *btcec.PrivateKey
|
||||
|
||||
// Policy holds the negotiated session parameters.
|
||||
Policy wtpolicy.Policy
|
||||
|
||||
// RewardPkScript is the pkscript that the tower's reward will be
|
||||
// deposited to if a sweep transaction confirms and the session
|
||||
// specifies a reward output.
|
||||
RewardPkScript []byte
|
||||
|
||||
// CommittedUpdates is a map from allocated sequence numbers to unacked
|
||||
// updates. These updates can be resent after a restart if the update
|
||||
// failed to send or receive an acknowledgment.
|
||||
CommittedUpdates map[uint16]*CommittedUpdate
|
||||
|
||||
// AckedUpdates is a map from sequence number to backup id to record
|
||||
// which revoked states were uploaded via this session.
|
||||
AckedUpdates map[uint16]BackupID
|
||||
}
|
||||
|
||||
// BackupID identifies a particular revoked, remote commitment by channel id and
|
||||
// commitment height.
|
||||
type BackupID struct {
|
||||
// ChanID is the channel id of the revoked commitment.
|
||||
ChanID lnwire.ChannelID
|
||||
|
||||
// CommitHeight is the commitment height of the revoked commitment.
|
||||
CommitHeight uint64
|
||||
}
|
||||
|
||||
// CommittedUpdate holds a state update sent by a client along with its
|
||||
// SessionID.
|
||||
type CommittedUpdate struct {
|
||||
BackupID BackupID
|
||||
|
||||
// Hint is the 16-byte prefix of the revoked commitment transaction ID.
|
||||
Hint BreachHint
|
||||
|
||||
// EncryptedBlob is a ciphertext containing the sweep information for
|
||||
// exacting justice if the commitment transaction matching the breach
|
||||
// hint is broadcast.
|
||||
EncryptedBlob []byte
|
||||
}
|
@ -82,7 +82,7 @@ func (s *SessionInfo) AcceptUpdateSequence(seqNum, lastApplied uint16) error {
|
||||
return ErrSessionConsumed
|
||||
|
||||
// Client update does not match our expected next seqnum.
|
||||
case seqNum != s.LastApplied+1:
|
||||
case seqNum != s.LastApplied && seqNum != s.LastApplied+1:
|
||||
return ErrUpdateOutOfOrder
|
||||
}
watchtower/wtdb/tower.go (new file, 65 lines)
@ -0,0 +1,65 @@
|
||||
package wtdb
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
)
|
||||
|
||||
// Tower holds the necessary components required to connect to a remote tower.
|
||||
// Communication is handled by brontide, and requires both a public key and an
|
||||
// address.
|
||||
type Tower struct {
|
||||
// ID is a unique ID for this record assigned by the database.
|
||||
ID uint64
|
||||
|
||||
// IdentityKey is the public key of the remote node, used to
|
||||
// authenticate the brontide transport.
|
||||
IdentityKey *btcec.PublicKey
|
||||
|
||||
// Addresses is a list of possible addresses to reach the tower.
|
||||
Addresses []net.Addr
|
||||
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// AddAddress adds the given address to the tower's in-memory list of addresses.
|
||||
// If the address's string is already present, the Tower will be left
|
||||
// unmodified. Otherwise, the address is prepended to the beginning of the
|
||||
// Tower's addresses, on the assumption that it is fresher than the others.
|
||||
func (t *Tower) AddAddress(addr net.Addr) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
// Ensure we don't add a duplicate address.
|
||||
addrStr := addr.String()
|
||||
for _, existingAddr := range t.Addresses {
|
||||
if existingAddr.String() == addrStr {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Add this address to the front of the list, on the assumption that it
|
||||
// is a fresher address and will be tried first.
|
||||
t.Addresses = append([]net.Addr{addr}, t.Addresses...)
|
||||
}
|
||||
|
||||
// LNAddrs generates a list of lnwire.NetAddress from a Tower instance's
|
||||
// addresses. This can be used to have a client try multiple addresses for the
|
||||
// same Tower.
|
||||
func (t *Tower) LNAddrs() []*lnwire.NetAddress {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
|
||||
addrs := make([]*lnwire.NetAddress, 0, len(t.Addresses))
|
||||
for _, addr := range t.Addresses {
|
||||
addrs = append(addrs, &lnwire.NetAddress{
|
||||
IdentityKey: t.IdentityKey,
|
||||
Address: addr,
|
||||
})
|
||||
}
|
||||
|
||||
return addrs
|
||||
}
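AddAddress deduplicates by string form and prepends new addresses so the freshest one is tried first. A quick standalone demonstration of that ordering, using plain strings instead of net.Addr and an example hostname that is purely illustrative:

```go
package main

import "fmt"

// addAddress mirrors the dedupe-and-prepend behaviour described above.
func addAddress(addrs []string, addr string) []string {
	for _, a := range addrs {
		if a == addr {
			return addrs // already known, keep ordering as-is
		}
	}
	return append([]string{addr}, addrs...) // freshest address first
}

func main() {
	addrs := []string{"tower1.example.com:9911"}
	addrs = addAddress(addrs, "tower1.example.com:9911") // duplicate, no-op
	addrs = addAddress(addrs, "10.0.0.2:9911")           // new, moves to front
	fmt.Println(addrs) // [10.0.0.2:9911 tower1.example.com:9911]
}
```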
watchtower/wtmock/client_db.go (new file, 223 lines)
@ -0,0 +1,223 @@
|
||||
package wtmock
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
)
|
||||
|
||||
type towerPK [33]byte
|
||||
|
||||
// ClientDB is a mock, in-memory database for testing the watchtower client
|
||||
// behavior.
|
||||
type ClientDB struct {
|
||||
nextTowerID uint64 // to be used atomically
|
||||
|
||||
mu sync.Mutex
|
||||
sweepPkScripts map[lnwire.ChannelID][]byte
|
||||
activeSessions map[wtdb.SessionID]*wtdb.ClientSession
|
||||
towerIndex map[towerPK]uint64
|
||||
towers map[uint64]*wtdb.Tower
|
||||
}
|
||||
|
||||
// NewClientDB initializes a new mock ClientDB.
|
||||
func NewClientDB() *ClientDB {
|
||||
return &ClientDB{
|
||||
sweepPkScripts: make(map[lnwire.ChannelID][]byte),
|
||||
activeSessions: make(map[wtdb.SessionID]*wtdb.ClientSession),
|
||||
towerIndex: make(map[towerPK]uint64),
|
||||
towers: make(map[uint64]*wtdb.Tower),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTower initializes a database entry with the given lightning address. If
|
||||
// the tower exists, the address is appended to the list of all addresses used
// to reach that tower previously.
|
||||
func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
var towerPubKey towerPK
|
||||
copy(towerPubKey[:], lnAddr.IdentityKey.SerializeCompressed())
|
||||
|
||||
var tower *wtdb.Tower
|
||||
towerID, ok := m.towerIndex[towerPubKey]
|
||||
if ok {
|
||||
tower = m.towers[towerID]
|
||||
tower.AddAddress(lnAddr.Address)
|
||||
} else {
|
||||
towerID = atomic.AddUint64(&m.nextTowerID, 1)
|
||||
tower = &wtdb.Tower{
|
||||
ID: towerID,
|
||||
IdentityKey: lnAddr.IdentityKey,
|
||||
Addresses: []net.Addr{lnAddr.Address},
|
||||
}
|
||||
}
|
||||
|
||||
m.towerIndex[towerPubKey] = towerID
|
||||
m.towers[towerID] = tower
|
||||
|
||||
return tower, nil
|
||||
}
|
||||
|
||||
// MarkBackupIneligible records that a particular commit height is ineligible for
|
||||
// backup. This allows the client to track which updates it should not attempt
|
||||
// to retry after startup.
|
||||
func (m *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, commitHeight uint64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListClientSessions returns the set of all client sessions known to the db.
|
||||
func (m *ClientDB) ListClientSessions() (map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
sessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
|
||||
for _, session := range m.activeSessions {
|
||||
sessions[session.ID] = session
|
||||
}
|
||||
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
// CreateClientSession records a newly negotiated client session in the set of
|
||||
// active sessions. The session can be identified by its SessionID.
|
||||
func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.activeSessions[session.ID] = &wtdb.ClientSession{
|
||||
TowerID: session.TowerID,
|
||||
Tower: session.Tower,
|
||||
SessionKeyDesc: session.SessionKeyDesc,
|
||||
SessionPrivKey: session.SessionPrivKey,
|
||||
ID: session.ID,
|
||||
Policy: session.Policy,
|
||||
SeqNum: session.SeqNum,
|
||||
TowerLastApplied: session.TowerLastApplied,
|
||||
RewardPkScript: session.RewardPkScript,
|
||||
CommittedUpdates: make(map[uint16]*wtdb.CommittedUpdate),
|
||||
AckedUpdates: make(map[uint16]wtdb.BackupID),
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CommitUpdate persists the CommittedUpdate provided in the slot for (session,
|
||||
// seqNum). This allows the client to retransmit this update on startup.
|
||||
func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, seqNum uint16,
|
||||
update *wtdb.CommittedUpdate) (uint16, error) {
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Fail if session doesn't exist.
|
||||
session, ok := m.activeSessions[*id]
|
||||
if !ok {
|
||||
return 0, wtdb.ErrClientSessionNotFound
|
||||
}
|
||||
|
||||
// Check if an update has already been committed for this state.
|
||||
dbUpdate, ok := session.CommittedUpdates[seqNum]
|
||||
if ok {
|
||||
// If the breach hint matches, we'll just return the last
|
||||
// applied value so the client can retransmit.
|
||||
if dbUpdate.Hint == update.Hint {
|
||||
return session.TowerLastApplied, nil
|
||||
}
|
||||
|
||||
// Otherwise, fail since the breach hint doesn't match.
|
||||
return 0, wtdb.ErrUpdateAlreadyCommitted
|
||||
}
|
||||
|
||||
// Sequence number must increment.
|
||||
if seqNum != session.SeqNum+1 {
|
||||
return 0, wtdb.ErrCommitUnorderedUpdate
|
||||
}
|
||||
|
||||
// Save the update and increment the sequence number.
|
||||
session.CommittedUpdates[seqNum] = update
|
||||
session.SeqNum++
|
||||
|
||||
return session.TowerLastApplied, nil
|
||||
}
|
||||
|
||||
// AckUpdate persists an acknowledgment for a given (session, seqnum) pair. This
|
||||
// removes the update from the set of committed updates, and validates the
|
||||
// lastApplied value returned from the tower.
|
||||
func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Fail if session doesn't exist.
|
||||
session, ok := m.activeSessions[*id]
|
||||
if !ok {
|
||||
return wtdb.ErrClientSessionNotFound
|
||||
}
|
||||
|
||||
// Retrieve the committed update, failing if none is found. We should
|
||||
// only receive acks for state updates that we send.
|
||||
update, ok := session.CommittedUpdates[seqNum]
|
||||
if !ok {
|
||||
return wtdb.ErrCommittedUpdateNotFound
|
||||
}
|
||||
|
||||
// Ensure the returned last applied value does not exceed the highest
|
||||
// allocated sequence number.
|
||||
if lastApplied > session.SeqNum {
|
||||
return wtdb.ErrUnallocatedLastApplied
|
||||
}
|
||||
|
||||
// Ensure the last applied value isn't lower than a previous one sent by
|
||||
// the tower.
|
||||
if lastApplied < session.TowerLastApplied {
|
||||
return wtdb.ErrLastAppliedReversion
|
||||
}
|
||||
|
||||
// Finally, remove the committed update from disk and mark the update as
|
||||
// acked. The tower last applied value is also recorded to send along
|
||||
// with the next update.
|
||||
delete(session.CommittedUpdates, seqNum)
|
||||
session.AckedUpdates[seqNum] = update.BackupID
|
||||
session.TowerLastApplied = lastApplied
|
||||
|
||||
return nil
|
||||
}
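The two sanity checks in AckUpdate capture the tower's last-applied invariants: the echoed value may never exceed the highest allocated sequence number and may never move backwards. A compact, self-contained sketch of just those rules (the types and error values below are local stand-ins, not wtdb's):

```go
package main

import (
	"errors"
	"fmt"
)

var (
	errUnallocatedLastApplied = errors.New("last applied exceeds allocated seqnum")
	errLastAppliedReversion   = errors.New("last applied moved backwards")
)

// session tracks only the two counters the invariants depend on.
type session struct {
	seqNum           uint16 // highest allocated sequence number
	towerLastApplied uint16 // highest value echoed by the tower so far
}

// ack validates and records the tower's echoed lastApplied value.
func (s *session) ack(lastApplied uint16) error {
	if lastApplied > s.seqNum {
		return errUnallocatedLastApplied
	}
	if lastApplied < s.towerLastApplied {
		return errLastAppliedReversion
	}
	s.towerLastApplied = lastApplied
	return nil
}

func main() {
	s := &session{seqNum: 3, towerLastApplied: 1}
	fmt.Println(s.ack(2)) // <nil>
	fmt.Println(s.ack(5)) // unallocated last applied
	fmt.Println(s.ack(1)) // last applied reversion
}
```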
|
||||
|
||||
// FetchChanPkScripts returns the set of sweep pkscripts known for all channels.
|
||||
// This allows the client to cache them in memory on startup.
|
||||
func (m *ClientDB) FetchChanPkScripts() (map[lnwire.ChannelID][]byte, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
sweepPkScripts := make(map[lnwire.ChannelID][]byte)
|
||||
for chanID, pkScript := range m.sweepPkScripts {
|
||||
sweepPkScripts[chanID] = cloneBytes(pkScript)
|
||||
}
|
||||
|
||||
return sweepPkScripts, nil
|
||||
}
|
||||
|
||||
// AddChanPkScript sets a pkscript for sweeping funds from the channel with the given chanID.
|
||||
func (m *ClientDB) AddChanPkScript(chanID lnwire.ChannelID, pkScript []byte) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, ok := m.sweepPkScripts[chanID]; ok {
|
||||
return fmt.Errorf("pkscript for %x already exists", pkScript)
|
||||
}
|
||||
|
||||
m.sweepPkScripts[chanID] = cloneBytes(pkScript)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func cloneBytes(b []byte) []byte {
|
||||
bb := make([]byte, len(b))
|
||||
copy(bb, b)
|
||||
return bb
|
||||
}
|
@ -13,6 +13,8 @@ import (
|
||||
type MockPeer struct {
|
||||
remotePub *btcec.PublicKey
|
||||
remoteAddr net.Addr
|
||||
localPub *btcec.PublicKey
|
||||
localAddr net.Addr
|
||||
|
||||
IncomingMsgs chan []byte
|
||||
OutgoingMsgs chan []byte
|
||||
@ -20,30 +22,74 @@ type MockPeer struct {
|
||||
writeDeadline <-chan time.Time
|
||||
readDeadline <-chan time.Time
|
||||
|
||||
RemoteQuit chan struct{}
|
||||
Quit chan struct{}
|
||||
}
|
||||
|
||||
// NewMockPeer returns a fresh MockPeer.
|
||||
func NewMockPeer(pk *btcec.PublicKey, addr net.Addr, bufferSize int) *MockPeer {
|
||||
func NewMockPeer(lpk, rpk *btcec.PublicKey, addr net.Addr,
|
||||
bufferSize int) *MockPeer {
|
||||
|
||||
return &MockPeer{
|
||||
remotePub: pk,
|
||||
remotePub: rpk,
|
||||
remoteAddr: addr,
|
||||
localAddr: &net.TCPAddr{
|
||||
IP: net.IP{0x32, 0x31, 0x30, 0x29},
|
||||
Port: 36723,
|
||||
},
|
||||
localPub: lpk,
|
||||
IncomingMsgs: make(chan []byte, bufferSize),
|
||||
OutgoingMsgs: make(chan []byte, bufferSize),
|
||||
Quit: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// NewMockConn establishes a bidirectional connection between two MockPeers.
|
||||
func NewMockConn(localPk, remotePk *btcec.PublicKey,
|
||||
localAddr, remoteAddr net.Addr,
|
||||
bufferSize int) (*MockPeer, *MockPeer) {
|
||||
|
||||
localPeer := &MockPeer{
|
||||
remotePub: remotePk,
|
||||
remoteAddr: remoteAddr,
|
||||
localPub: localPk,
|
||||
localAddr: localAddr,
|
||||
IncomingMsgs: make(chan []byte, bufferSize),
|
||||
OutgoingMsgs: make(chan []byte, bufferSize),
|
||||
Quit: make(chan struct{}),
|
||||
}
|
||||
|
||||
remotePeer := &MockPeer{
|
||||
remotePub: localPk,
|
||||
remoteAddr: localAddr,
|
||||
localPub: remotePk,
|
||||
localAddr: remoteAddr,
|
||||
IncomingMsgs: localPeer.OutgoingMsgs,
|
||||
OutgoingMsgs: localPeer.IncomingMsgs,
|
||||
Quit: make(chan struct{}),
|
||||
}
|
||||
|
||||
localPeer.RemoteQuit = remotePeer.Quit
|
||||
remotePeer.RemoteQuit = localPeer.Quit
|
||||
|
||||
return localPeer, remotePeer
|
||||
}
|
||||
|
||||
// Write sends the raw bytes as the next full message read by the remote peer.
|
||||
// The write will fail if either party closes the connection or the write
|
||||
// deadline expires. The passed bytes slice is copied before sending, thus the
|
||||
// bytes may be reused once the method returns.
|
||||
func (p *MockPeer) Write(b []byte) (n int, err error) {
|
||||
bb := make([]byte, len(b))
|
||||
copy(bb, b)
|
||||
|
||||
select {
|
||||
case p.OutgoingMsgs <- b:
|
||||
case p.OutgoingMsgs <- bb:
|
||||
return len(b), nil
|
||||
case <-p.writeDeadline:
|
||||
return 0, fmt.Errorf("write timeout expired")
|
||||
case <-p.RemoteQuit:
|
||||
return 0, fmt.Errorf("remote closed connected")
|
||||
case <-p.Quit:
|
||||
return 0, fmt.Errorf("connection closed")
|
||||
}
|
||||
@ -69,6 +115,8 @@ func (p *MockPeer) ReadNextMessage() ([]byte, error) {
|
||||
return b, nil
|
||||
case <-p.readDeadline:
|
||||
return nil, fmt.Errorf("read timeout expired")
|
||||
case <-p.RemoteQuit:
|
||||
return nil, fmt.Errorf("remote closed connected")
|
||||
case <-p.Quit:
|
||||
return nil, fmt.Errorf("connection closed")
|
||||
}
|
||||
@ -112,6 +160,25 @@ func (p *MockPeer) RemoteAddr() net.Addr {
|
||||
return p.remoteAddr
|
||||
}
|
||||
|
||||
// LocalAddr returns the local net address of the peer.
|
||||
func (p *MockPeer) LocalAddr() net.Addr {
|
||||
return p.localAddr
|
||||
}
|
||||
|
||||
// Read is not implemented.
|
||||
func (p *MockPeer) Read(dst []byte) (int, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
// SetDeadline is not implemented.
|
||||
func (p *MockPeer) SetDeadline(t time.Time) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
// Compile-time constraint ensuring the MockPeer implements the wtserver.Peer
|
||||
// interface.
|
||||
var _ wtserver.Peer = (*MockPeer)(nil)
|
||||
|
||||
// Compile-time constraint ensuring the MockPeer implements the net.Conn
|
||||
// interface.
|
||||
var _ net.Conn = (*MockPeer)(nil)
|
||||
|
@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
@ -55,14 +54,18 @@ type Config struct {
|
||||
|
||||
// ChainHash identifies the network that the server is watching.
|
||||
ChainHash chainhash.Hash
|
||||
|
||||
// NoAckUpdates causes the server to not acknowledge state updates; this
|
||||
// should only be used for testing.
|
||||
NoAckUpdates bool
|
||||
}
|
||||
|
||||
// Server houses the state required to handle watchtower peers. Its primary job
|
||||
// is to accept incoming connections, and dispatch processing of the client
|
||||
// message streams.
|
||||
type Server struct {
|
||||
started int32 // atomic
|
||||
shutdown int32 // atomic
|
||||
started sync.Once
|
||||
stopped sync.Once
|
||||
|
||||
cfg *Config
|
||||
|
||||
@ -71,6 +74,8 @@ type Server struct {
|
||||
clientMtx sync.RWMutex
|
||||
clients map[wtdb.SessionID]Peer
|
||||
|
||||
newPeers chan Peer
|
||||
|
||||
localInit *wtwire.Init
|
||||
|
||||
wg sync.WaitGroup
|
||||
@ -89,6 +94,7 @@ func New(cfg *Config) (*Server, error) {
|
||||
s := &Server{
|
||||
cfg: cfg,
|
||||
clients: make(map[wtdb.SessionID]Peer),
|
||||
newPeers: make(chan Peer),
|
||||
localInit: localInit,
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
@ -109,27 +115,22 @@ func New(cfg *Config) (*Server, error) {
|
||||
|
||||
// Start begins listening on the server's listeners.
|
||||
func (s *Server) Start() error {
|
||||
// Already running?
|
||||
if !atomic.CompareAndSwapInt32(&s.started, 0, 1) {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.started.Do(func() {
|
||||
log.Infof("Starting watchtower server")
|
||||
|
||||
s.wg.Add(1)
|
||||
go s.peerHandler()
|
||||
|
||||
s.connMgr.Start()
|
||||
|
||||
log.Infof("Watchtower server started successfully")
|
||||
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop shuts down the server's listeners and any active requests.
|
||||
func (s *Server) Stop() error {
|
||||
// Bail if we're already shutting down.
|
||||
if !atomic.CompareAndSwapInt32(&s.shutdown, 0, 1) {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.stopped.Do(func() {
|
||||
log.Infof("Stopping watchtower server")
|
||||
|
||||
s.connMgr.Stop()
|
||||
@ -138,7 +139,7 @@ func (s *Server) Stop() error {
|
||||
s.wg.Wait()
|
||||
|
||||
log.Infof("Watchtower server stopped successfully")
|
||||
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -163,8 +164,29 @@ func (s *Server) inboundPeerConnected(c net.Conn) {
|
||||
// by the client. This method serves also as a public endpoint for locally
|
||||
// registering new clients with the server.
|
||||
func (s *Server) InboundPeerConnected(peer Peer) {
|
||||
select {
|
||||
case s.newPeers <- peer:
|
||||
case <-s.quit:
|
||||
}
|
||||
}
|
||||
|
||||
// peerHandler processes newly accepted peers and spawns a client handler for
|
||||
// each. The peerHandler is used to ensure that waitgrouped client handlers are
|
||||
// spawned from a waitgrouped goroutine.
|
||||
func (s *Server) peerHandler() {
|
||||
defer s.wg.Done()
|
||||
defer s.removeAllPeers()
|
||||
|
||||
for {
|
||||
select {
|
||||
case peer := <-s.newPeers:
|
||||
s.wg.Add(1)
|
||||
go s.handleClient(peer)
|
||||
|
||||
case <-s.quit:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
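peerHandler exists so that every handleClient goroutine is registered on the WaitGroup from a goroutine that is itself waitgrouped, letting Stop wait for all of them. A minimal standalone sketch of that structure:

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	var (
		wg       sync.WaitGroup
		newPeers = make(chan int)
		quit     = make(chan struct{})
	)

	// peerHandler: the only goroutine that spawns client handlers, so the
	// wg.Add for each handler happens while its own wg slot is still held.
	wg.Add(1)
	go func() {
		defer wg.Done()
		for {
			select {
			case peer := <-newPeers:
				wg.Add(1)
				go func(p int) { // handleClient
					defer wg.Done()
					fmt.Println("handling peer", p)
				}(peer)
			case <-quit:
				return
			}
		}
	}()

	newPeers <- 1
	newPeers <- 2
	close(quit)
	wg.Wait()
	fmt.Println("all handlers finished")
}
```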
|
||||
|
||||
// handleClient processes a series of watchtower messages sent by a client. The
|
||||
@ -445,6 +467,13 @@ func (s *Server) handleStateUpdate(peer Peer, id *wtdb.SessionID,
|
||||
failCode = wtwire.CodeTemporaryFailure
|
||||
}
|
||||
|
||||
if s.cfg.NoAckUpdates {
|
||||
return &connFailure{
|
||||
ID: *id,
|
||||
Code: uint16(failCode),
|
||||
}
|
||||
}
|
||||
|
||||
return s.replyStateUpdate(
|
||||
peer, id, failCode, lastApplied,
|
||||
)
|
||||
@ -614,6 +643,21 @@ func (s *Server) removePeer(id *wtdb.SessionID, addr net.Addr) {
|
||||
}
|
||||
}
|
||||
|
||||
// removeAllPeers iterates through the server's current set of peers and closes
|
||||
// all open connections.
|
||||
func (s *Server) removeAllPeers() {
|
||||
s.clientMtx.Lock()
|
||||
defer s.clientMtx.Unlock()
|
||||
|
||||
for id, peer := range s.clients {
|
||||
log.Infof("Releasing incoming peer %s@%s", id,
|
||||
peer.RemoteAddr())
|
||||
|
||||
delete(s.clients, id)
|
||||
peer.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// logMessage writes information about a message exchanged with a remote peer,
|
||||
// using directional prepositions to signal whether the message was sent or
|
||||
// received.
|
||||
|
@ -87,10 +87,12 @@ func TestServerOnlyAcceptOnePeer(t *testing.T) {
|
||||
s := initServer(t, nil, timeoutDuration)
|
||||
defer s.Stop()
|
||||
|
||||
localPub := randPubKey(t)
|
||||
|
||||
// Create two peers using the same session id.
|
||||
peerPub := randPubKey(t)
|
||||
peer1 := wtmock.NewMockPeer(peerPub, nil, 0)
|
||||
peer2 := wtmock.NewMockPeer(peerPub, nil, 0)
|
||||
peer1 := wtmock.NewMockPeer(localPub, peerPub, nil, 0)
|
||||
peer2 := wtmock.NewMockPeer(localPub, peerPub, nil, 0)
|
||||
|
||||
// Serialize an Init message to be sent by both peers.
|
||||
init := wtwire.NewInitMessage(
|
||||
@ -219,9 +221,11 @@ func testServerCreateSession(t *testing.T, i int, test createSessionTestCase) {
|
||||
s := initServer(t, nil, timeoutDuration)
|
||||
defer s.Stop()
|
||||
|
||||
localPub := randPubKey(t)
|
||||
|
||||
// Create a new client and connect to server.
|
||||
peerPub := randPubKey(t)
|
||||
peer := wtmock.NewMockPeer(peerPub, nil, 0)
|
||||
peer := wtmock.NewMockPeer(localPub, peerPub, nil, 0)
|
||||
connect(t, i, s, peer, test.initMsg, timeoutDuration)
|
||||
|
||||
// Send the CreateSession message, and wait for a reply.
|
||||
@ -249,7 +253,7 @@ func testServerCreateSession(t *testing.T, i int, test createSessionTestCase) {
|
||||
|
||||
// Simulate a peer with the same session id connection to the server
|
||||
// again.
|
||||
peer = wtmock.NewMockPeer(peerPub, nil, 0)
|
||||
peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
|
||||
connect(t, i, s, peer, test.initMsg, timeoutDuration)
|
||||
|
||||
// Send the _same_ CreateSession message as the first attempt.
|
||||
@ -418,8 +422,8 @@ var stateUpdateTests = []stateUpdateTestCase{
|
||||
{Code: wtwire.CodeOK, LastApplied: 4},
|
||||
},
|
||||
},
|
||||
// Valid update sequence with disconnection, ensure resumes resume.
|
||||
// Client doesn't echo last applied until last message.
|
||||
// Valid update sequence with disconnection, resume next update. Client
|
||||
// doesn't echo last applied until last message.
|
||||
{
|
||||
name: "resume after disconnect lagging lastapplied",
|
||||
initMsg: wtwire.NewInitMessage(
|
||||
@ -448,6 +452,38 @@ var stateUpdateTests = []stateUpdateTestCase{
|
||||
{Code: wtwire.CodeOK, LastApplied: 4},
|
||||
},
|
||||
},
|
||||
// Valid update sequence with disconnection, resume last update. Client
|
||||
// doesn't echo last applied until last message.
|
||||
{
|
||||
name: "resume after disconnect lagging lastapplied",
|
||||
initMsg: wtwire.NewInitMessage(
|
||||
lnwire.NewRawFeatureVector(),
|
||||
testnetChainHash,
|
||||
),
|
||||
createMsg: &wtwire.CreateSession{
|
||||
BlobType: blob.TypeDefault,
|
||||
MaxUpdates: 4,
|
||||
RewardBase: 0,
|
||||
RewardRate: 0,
|
||||
SweepFeeRate: 1,
|
||||
},
|
||||
updates: []*wtwire.StateUpdate{
|
||||
{SeqNum: 1, LastApplied: 0},
|
||||
{SeqNum: 2, LastApplied: 0},
|
||||
nil, // Wait for read timeout to drop conn, then reconnect.
|
||||
{SeqNum: 2, LastApplied: 0},
|
||||
{SeqNum: 3, LastApplied: 0},
|
||||
{SeqNum: 4, LastApplied: 3},
|
||||
},
|
||||
replies: []*wtwire.StateUpdateReply{
|
||||
{Code: wtwire.CodeOK, LastApplied: 1},
|
||||
{Code: wtwire.CodeOK, LastApplied: 2},
|
||||
nil,
|
||||
{Code: wtwire.CodeOK, LastApplied: 2},
|
||||
{Code: wtwire.CodeOK, LastApplied: 3},
|
||||
{Code: wtwire.CodeOK, LastApplied: 4},
|
||||
},
|
||||
},
|
||||
// Send update with sequence number that exceeds MaxUpdates.
|
||||
{
|
||||
name: "seqnum exceed maxupdates",
|
||||
@ -527,9 +563,11 @@ func testServerStateUpdates(t *testing.T, i int, test stateUpdateTestCase) {
|
||||
s := initServer(t, nil, timeoutDuration)
|
||||
defer s.Stop()
|
||||
|
||||
localPub := randPubKey(t)
|
||||
|
||||
// Create a new client and connect to the server.
|
||||
peerPub := randPubKey(t)
|
||||
peer := wtmock.NewMockPeer(peerPub, nil, 0)
|
||||
peer := wtmock.NewMockPeer(localPub, peerPub, nil, 0)
|
||||
connect(t, i, s, peer, test.initMsg, timeoutDuration)
|
||||
|
||||
// Register a session for this client to use in the subsequent tests.
|
||||
@ -549,7 +587,7 @@ func testServerStateUpdates(t *testing.T, i int, test stateUpdateTestCase) {
|
||||
|
||||
// Now that the original connection has been closed, connect a new
|
||||
// client with the same session id.
|
||||
peer = wtmock.NewMockPeer(peerPub, nil, 0)
|
||||
peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
|
||||
connect(t, i, s, peer, test.initMsg, timeoutDuration)
|
||||
|
||||
// Send the intended StateUpdate messages in series.
|
||||
@ -560,7 +598,7 @@ func testServerStateUpdates(t *testing.T, i int, test stateUpdateTestCase) {
|
||||
if update == nil {
|
||||
assertConnClosed(t, peer, 2*timeoutDuration)
|
||||
|
||||
peer = wtmock.NewMockPeer(peerPub, nil, 0)
|
||||
peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
|
||||
connect(t, i, s, peer, test.initMsg, timeoutDuration)
|
||||
|
||||
continue
|
||||
|
@ -14,9 +14,9 @@ const (
|
||||
// reply was never received and/or processed by the client.
|
||||
CreateSessionCodeAlreadyExists CreateSessionCode = 60
|
||||
|
||||
// CreateSessionCodeRejectRejectMaxUpdates the tower rejected the maximum
|
||||
// CreateSessionCodeRejectMaxUpdates the tower rejected the maximum
|
||||
// number of state updates proposed by the client.
|
||||
CreateSessionCodeRejectRejectMaxUpdates CreateSessionCode = 61
|
||||
CreateSessionCodeRejectMaxUpdates CreateSessionCode = 61
|
||||
|
||||
// CreateSessionCodeRejectRewardRate the tower rejected the reward rate
|
||||
// proposed by the client.
|
||||
|
@ -1,5 +1,7 @@
|
||||
package wtwire
|
||||
|
||||
import "fmt"
|
||||
|
||||
// ErrorCode represents a generic error code used when replying to watchtower
|
||||
// clients. Specific reply messages may extend the ErrorCode primitive and add
|
||||
// custom codes, so long as they don't collide with the generic error codes.
|
||||
@ -18,3 +20,33 @@ const (
|
||||
// permanently failed, and further communication should be avoided.
|
||||
CodePermanentFailure ErrorCode = 50
|
||||
)
|
||||
|
||||
// String returns a human-readable description of an ErrorCode.
|
||||
func (c ErrorCode) String() string {
|
||||
switch c {
|
||||
case CodeOK:
|
||||
return "CodeOK"
|
||||
case CodeTemporaryFailure:
|
||||
return "CodeTemporaryFailure"
|
||||
case CodePermanentFailure:
|
||||
return "CodePermanentFailure"
|
||||
case CreateSessionCodeAlreadyExists:
|
||||
return "CreateSessionCodeAlreadyExists"
|
||||
case CreateSessionCodeRejectMaxUpdates:
|
||||
return "CreateSessionCodeRejectMaxUpdates"
|
||||
case CreateSessionCodeRejectRewardRate:
|
||||
return "CreateSessionCodeRejectRewardRate"
|
||||
case CreateSessionCodeRejectSweepFeeRate:
|
||||
return "CreateSessionCodeRejectSweepFeeRate"
|
||||
case CreateSessionCodeRejectBlobType:
|
||||
return "CreateSessionCodeRejectBlobType"
|
||||
case StateUpdateCodeClientBehind:
|
||||
return "StateUpdateCodeClientBehind"
|
||||
case StateUpdateCodeMaxUpdatesExceeded:
|
||||
return "StateUpdateCodeMaxUpdatesExceeded"
|
||||
case StateUpdateCodeSeqNumOutOfOrder:
|
||||
return "StateUpdateCodeSeqNumOutOfOrder"
|
||||
default:
|
||||
return fmt.Sprintf("UnknownErrorCode: %d", c)
|
||||
}
|
||||
}