watchtower/wtclient/session_negotiator: add session negotiation
This commit is contained in:
parent
a8721bcedf
commit
95fa7659e0
451
watchtower/wtclient/session_negotiator.go
Normal file
451
watchtower/wtclient/session_negotiator.go
Normal file
@ -0,0 +1,451 @@
|
||||
package wtclient
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/watchtower/blob"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtserver"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtwire"
|
||||
)
|
||||
|
||||
// SessionNegotiator is an interface for asynchronously requesting new sessions.
|
||||
type SessionNegotiator interface {
|
||||
// RequestSession signals to the session negotiator that the client
|
||||
// needs another session. Once the session is negotiated, it should be
|
||||
// returned via NewSessions.
|
||||
RequestSession()
|
||||
|
||||
// NewSessions is a read-only channel where newly negotiated sessions
|
||||
// will be delivered.
|
||||
NewSessions() <-chan *wtdb.ClientSession
|
||||
|
||||
// Start safely initializes the session negotiator.
|
||||
Start() error
|
||||
|
||||
// Stop safely shuts down the session negotiator.
|
||||
Stop() error
|
||||
}
|
||||
|
||||
// NegotiatorConfig provides access to the resources required by a
|
||||
// SessionNegotiator to faithfully carry out its duties. All nil-able field must
|
||||
// be initialized.
|
||||
type NegotiatorConfig struct {
|
||||
// DB provides access to a persistent storage medium used by the tower
|
||||
// to properly allocate session ephemeral keys and record successfully
|
||||
// negotiated sessions.
|
||||
DB DB
|
||||
|
||||
// Candidates is an abstract set of tower candidates that the negotiator
|
||||
// will traverse serially when attempting to negotiate a new session.
|
||||
Candidates TowerCandidateIterator
|
||||
|
||||
// Policy defines the session policy that will be proposed to towers
|
||||
// when attempting to negotiate a new session. This policy will be used
|
||||
// across all negotiation proposals for the lifetime of the negotiator.
|
||||
Policy wtpolicy.Policy
|
||||
|
||||
// Dial initiates an outbound brontide connection to the given address
|
||||
// using a specified private key. The peer is returned in the event of a
|
||||
// successful connection.
|
||||
Dial func(*btcec.PrivateKey, *lnwire.NetAddress) (wtserver.Peer, error)
|
||||
|
||||
// SendMessage writes a wtwire message to remote peer.
|
||||
SendMessage func(wtserver.Peer, wtwire.Message) error
|
||||
|
||||
// ReadMessage reads a message from a remote peer and returns the
|
||||
// decoded wtwire message.
|
||||
ReadMessage func(wtserver.Peer) (wtwire.Message, error)
|
||||
|
||||
// ChainHash the genesis hash identifying the chain for any negotiated
|
||||
// sessions. Any state updates sent to that session should also
|
||||
// originate from this chain.
|
||||
ChainHash chainhash.Hash
|
||||
|
||||
// MinBackoff defines the initial backoff applied by the session
|
||||
// negotiator after all tower candidates have been exhausted and
|
||||
// reattempting negotiation with the same set of candidates. Subsequent
|
||||
// backoff durations will grow exponentially.
|
||||
MinBackoff time.Duration
|
||||
|
||||
// MaxBackoff defines the maximum backoff applied by the session
|
||||
// negotiator after all tower candidates have been exhausted and
|
||||
// reattempting negotation with the same set of candidates. If the
|
||||
// exponential backoff produces a timeout greater than this value, the
|
||||
// backoff duration will be clamped to MaxBackoff.
|
||||
MaxBackoff time.Duration
|
||||
}
|
||||
|
||||
// sessionNegotiator is concrete SessionNegotiator that is able to request new
|
||||
// sessions from a set of candidate towers asynchronously and return successful
|
||||
// sessions to the primary client.
|
||||
type sessionNegotiator struct {
|
||||
started sync.Once
|
||||
stopped sync.Once
|
||||
|
||||
localInit *wtwire.Init
|
||||
|
||||
cfg *NegotiatorConfig
|
||||
|
||||
dispatcher chan struct{}
|
||||
newSessions chan *wtdb.ClientSession
|
||||
successfulNegotiations chan *wtdb.ClientSession
|
||||
|
||||
wg sync.WaitGroup
|
||||
quit chan struct{}
|
||||
}
|
||||
|
||||
// Compile-time constraint to ensure a *sessionNegotiator implements the
|
||||
// SessionNegotiator interface.
|
||||
var _ SessionNegotiator = (*sessionNegotiator)(nil)
|
||||
|
||||
// newSessionNegotiator initializes a fresh sessionNegotiator instance.
|
||||
func newSessionNegotiator(cfg *NegotiatorConfig) *sessionNegotiator {
|
||||
localInit := wtwire.NewInitMessage(
|
||||
lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired),
|
||||
cfg.ChainHash,
|
||||
)
|
||||
|
||||
return &sessionNegotiator{
|
||||
cfg: cfg,
|
||||
localInit: localInit,
|
||||
dispatcher: make(chan struct{}, 1),
|
||||
newSessions: make(chan *wtdb.ClientSession),
|
||||
successfulNegotiations: make(chan *wtdb.ClientSession),
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start safely starts up the sessionNegotiator.
|
||||
func (n *sessionNegotiator) Start() error {
|
||||
n.started.Do(func() {
|
||||
log.Debugf("Starting session negotiator")
|
||||
|
||||
n.wg.Add(1)
|
||||
go n.negotiationDispatcher()
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop safely shutsdown the sessionNegotiator.
|
||||
func (n *sessionNegotiator) Stop() error {
|
||||
n.stopped.Do(func() {
|
||||
log.Debugf("Stopping session negotiator")
|
||||
|
||||
close(n.quit)
|
||||
n.wg.Wait()
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewSessions returns a receive-only channel from which newly negotiated
|
||||
// sessions will be returned.
|
||||
func (n *sessionNegotiator) NewSessions() <-chan *wtdb.ClientSession {
|
||||
return n.newSessions
|
||||
}
|
||||
|
||||
// RequestSession sends a request to the sessionNegotiator to begin requesting a
|
||||
// new session. If one is already in the process of being negotiated, the
|
||||
// request will be ignored.
|
||||
func (n *sessionNegotiator) RequestSession() {
|
||||
select {
|
||||
case n.dispatcher <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// negotiationDispatcher acts as the primary event loop for the
|
||||
// sessionNegotiator, coordinating requests for more sessions and dispatching
|
||||
// attempts to negotiate them from a list of candidates.
|
||||
func (n *sessionNegotiator) negotiationDispatcher() {
|
||||
defer n.wg.Done()
|
||||
|
||||
var pendingNegotiations int
|
||||
for {
|
||||
select {
|
||||
case <-n.dispatcher:
|
||||
pendingNegotiations++
|
||||
|
||||
if pendingNegotiations > 1 {
|
||||
log.Debugf("Already negotiating session, " +
|
||||
"waiting for existing negotiation to " +
|
||||
"complete")
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO(conner): consider reusing good towers
|
||||
|
||||
log.Debugf("Dispatching session negotiation")
|
||||
|
||||
n.wg.Add(1)
|
||||
go n.negotiate()
|
||||
|
||||
case session := <-n.successfulNegotiations:
|
||||
select {
|
||||
case n.newSessions <- session:
|
||||
pendingNegotiations--
|
||||
case <-n.quit:
|
||||
return
|
||||
}
|
||||
|
||||
if pendingNegotiations > 0 {
|
||||
log.Debugf("Dispatching pending session " +
|
||||
"negotiation")
|
||||
|
||||
n.wg.Add(1)
|
||||
go n.negotiate()
|
||||
}
|
||||
|
||||
case <-n.quit:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// negotiate handles the process of iterating through potential tower candidates
|
||||
// and attempting to negotiate a new session until a successful negotiation
|
||||
// occurs. If the candidate iterator becomes exhausted because none were
|
||||
// successful, this method will back off exponentially up to the configured max
|
||||
// backoff. This method will continue trying until a negotiation is succesful
|
||||
// before returning the negotiated session to the dispatcher via the succeed
|
||||
// channel.
|
||||
//
|
||||
// NOTE: This method MUST be run as a goroutine.
|
||||
func (n *sessionNegotiator) negotiate() {
|
||||
defer n.wg.Done()
|
||||
|
||||
// On the first pass, initialize the backoff to our configured min
|
||||
// backoff.
|
||||
backoff := n.cfg.MinBackoff
|
||||
|
||||
retryWithBackoff:
|
||||
// If we are retrying, wait out the delay before continuing.
|
||||
if backoff > 0 {
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-n.quit:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Before attempting a bout of session negotiation, reset the candidate
|
||||
// iterator to ensure the results are fresh.
|
||||
n.cfg.Candidates.Reset()
|
||||
for {
|
||||
// Pull the next candidate from our list of addresses.
|
||||
tower, err := n.cfg.Candidates.Next()
|
||||
if err != nil {
|
||||
// We've run out of addresses, double and clamp backoff.
|
||||
backoff *= 2
|
||||
if backoff > n.cfg.MaxBackoff {
|
||||
backoff = n.cfg.MaxBackoff
|
||||
}
|
||||
|
||||
log.Debugf("Unable to get new tower candidate, "+
|
||||
"retrying after %v -- reason: %v", backoff, err)
|
||||
|
||||
goto retryWithBackoff
|
||||
}
|
||||
|
||||
log.Debugf("Attempting session negotiation with tower=%x",
|
||||
tower.IdentityKey.SerializeCompressed())
|
||||
|
||||
// We'll now attempt the CreateSession dance with the tower to
|
||||
// get a new session, trying all addresses if necessary.
|
||||
err = n.createSession(tower)
|
||||
if err != nil {
|
||||
log.Debugf("Session negotiation with tower=%x "+
|
||||
"failed, trying again -- reason: %v",
|
||||
tower.IdentityKey.SerializeCompressed(), err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Success.
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// createSession takes a tower an attempts to negotiate a session using any of
|
||||
// its stored addresses. This method returns after the first successful
|
||||
// negotiation, or after all addresses have failed with ErrFailedNegotiation. If
|
||||
// the tower has no addresses, ErrNoTowerAddrs is returned.
|
||||
func (n *sessionNegotiator) createSession(tower *wtdb.Tower) error {
|
||||
// If the tower has no addresses, there's nothing we can do.
|
||||
if len(tower.Addresses) == 0 {
|
||||
return ErrNoTowerAddrs
|
||||
}
|
||||
|
||||
// TODO(conner): create with hdkey at random index
|
||||
sessionPrivKey, err := btcec.NewPrivateKey(btcec.S256())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO(conner): write towerAddr+privkey
|
||||
|
||||
for _, lnAddr := range tower.LNAddrs() {
|
||||
err = n.tryAddress(sessionPrivKey, tower, lnAddr)
|
||||
switch {
|
||||
case err == ErrPermanentTowerFailure:
|
||||
// TODO(conner): report to iterator? can then be reset
|
||||
// with restart
|
||||
fallthrough
|
||||
|
||||
case err != nil:
|
||||
log.Debugf("Request for session negotiation with "+
|
||||
"tower=%s failed, trying again -- reason: "+
|
||||
"%v", lnAddr, err)
|
||||
continue
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return ErrFailedNegotiation
|
||||
}
|
||||
|
||||
// tryAddress executes a single create session dance using the given address.
|
||||
// The address should belong to the tower's set of addresses. This method only
|
||||
// returns true if all steps succeed and the new session has been persisted, and
|
||||
// fails otherwise.
|
||||
func (n *sessionNegotiator) tryAddress(privKey *btcec.PrivateKey,
|
||||
tower *wtdb.Tower, lnAddr *lnwire.NetAddress) error {
|
||||
|
||||
// Connect to the tower address using our generated session key.
|
||||
conn, err := n.cfg.Dial(privKey, lnAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send local Init message.
|
||||
err = n.cfg.SendMessage(conn, n.localInit)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to send Init: %v", err)
|
||||
}
|
||||
|
||||
// Receive remote Init message.
|
||||
remoteMsg, err := n.cfg.ReadMessage(conn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to read Init: %v", err)
|
||||
}
|
||||
|
||||
// Check that returned message is wtwire.Init.
|
||||
remoteInit, ok := remoteMsg.(*wtwire.Init)
|
||||
if !ok {
|
||||
return fmt.Errorf("expected Init, got %T in reply", remoteMsg)
|
||||
}
|
||||
|
||||
// Verify the watchtower's remote Init message against our own.
|
||||
err = n.localInit.CheckRemoteInit(remoteInit, wtwire.FeatureNames)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
policy := n.cfg.Policy
|
||||
createSession := &wtwire.CreateSession{
|
||||
BlobType: policy.BlobType,
|
||||
MaxUpdates: policy.MaxUpdates,
|
||||
RewardBase: policy.RewardBase,
|
||||
RewardRate: policy.RewardRate,
|
||||
SweepFeeRate: policy.SweepFeeRate,
|
||||
}
|
||||
|
||||
// Send CreateSession message.
|
||||
err = n.cfg.SendMessage(conn, createSession)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to send CreateSession: %v", err)
|
||||
}
|
||||
|
||||
// Receive CreateSessionReply message.
|
||||
remoteMsg, err = n.cfg.ReadMessage(conn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to read CreateSessionReply: %v", err)
|
||||
}
|
||||
|
||||
// Check that returned message is wtwire.CreateSessionReply.
|
||||
createSessionReply, ok := remoteMsg.(*wtwire.CreateSessionReply)
|
||||
if !ok {
|
||||
return fmt.Errorf("expected CreateSessionReply, got %T in "+
|
||||
"reply", remoteMsg)
|
||||
}
|
||||
|
||||
switch createSessionReply.Code {
|
||||
case wtwire.CodeOK, wtwire.CreateSessionCodeAlreadyExists:
|
||||
|
||||
// TODO(conner): add last-applied to create session reply to
|
||||
// handle case where we lose state, session already exists, and
|
||||
// we want to possibly resume using the session
|
||||
|
||||
// TODO(conner): validate reward address
|
||||
rewardPkScript := createSessionReply.Data
|
||||
|
||||
sessionID := wtdb.NewSessionIDFromPubKey(
|
||||
privKey.PubKey(),
|
||||
)
|
||||
clientSession := &wtdb.ClientSession{
|
||||
TowerID: tower.ID,
|
||||
Tower: tower,
|
||||
SessionPrivKey: privKey, // remove after using HD keys
|
||||
ID: sessionID,
|
||||
Policy: n.cfg.Policy,
|
||||
SeqNum: 0,
|
||||
RewardPkScript: rewardPkScript,
|
||||
}
|
||||
|
||||
err = n.cfg.DB.CreateClientSession(clientSession)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to persist ClientSession: %v",
|
||||
err)
|
||||
}
|
||||
|
||||
log.Debugf("New session negotiated with %s, policy: %s",
|
||||
lnAddr, clientSession.Policy)
|
||||
|
||||
// We have a newly negotiated session, return it to the
|
||||
// dispatcher so that it can update how many outstanding
|
||||
// negotiation requests we have.
|
||||
select {
|
||||
case n.successfulNegotiations <- clientSession:
|
||||
return nil
|
||||
case <-n.quit:
|
||||
return ErrNegotiatorExiting
|
||||
}
|
||||
|
||||
// TODO(conner): handle error codes properly
|
||||
case wtwire.CreateSessionCodeRejectBlobType:
|
||||
return fmt.Errorf("tower rejected blob type: %v",
|
||||
policy.BlobType)
|
||||
|
||||
case wtwire.CreateSessionCodeRejectMaxUpdates:
|
||||
return fmt.Errorf("tower rejected max updates: %v",
|
||||
policy.MaxUpdates)
|
||||
|
||||
case wtwire.CreateSessionCodeRejectRewardRate:
|
||||
// The tower rejected the session because of the reward rate. If
|
||||
// we didn't request a reward session, we'll treat this as a
|
||||
// permanent tower failure.
|
||||
if !policy.BlobType.Has(blob.FlagReward) {
|
||||
return ErrPermanentTowerFailure
|
||||
}
|
||||
|
||||
return fmt.Errorf("tower rejected reward rate: %v",
|
||||
policy.RewardRate)
|
||||
|
||||
case wtwire.CreateSessionCodeRejectSweepFeeRate:
|
||||
return fmt.Errorf("tower rejected sweep fee rate: %v",
|
||||
policy.SweepFeeRate)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("received unhandled error code: %v",
|
||||
createSessionReply.Code)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user