lnd.xprv/watchtower/wtclient/session_negotiator.go
2021-01-04 08:32:34 -08:00

519 lines
15 KiB
Go

package wtclient
import (
"fmt"
"sync"
"time"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btclog"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
"github.com/lightningnetwork/lnd/watchtower/wtserver"
"github.com/lightningnetwork/lnd/watchtower/wtwire"
)
// SessionNegotiator asynchronously produces new client sessions on demand.
// Callers signal that another session is wanted with RequestSession and
// collect the results from the channel returned by NewSessions.
type SessionNegotiator interface {
	// Start safely initializes the session negotiator.
	Start() error

	// Stop safely shuts down the session negotiator.
	Stop() error

	// RequestSession signals that the client needs another session. The
	// negotiated session is delivered later via NewSessions.
	RequestSession()

	// NewSessions returns a read-only channel on which newly negotiated
	// sessions are delivered.
	NewSessions() <-chan *wtdb.ClientSession
}
// NegotiatorConfig provides access to the resources required by a
// SessionNegotiator to faithfully carry out its duties. All nil-able fields
// must be initialized.
type NegotiatorConfig struct {
// DB provides access to the persistent storage used by the negotiator to
// reserve a session key index per tower (NextSessionKeyIndex) and to
// record successfully negotiated sessions (CreateClientSession).
DB DB
// SecretKeyRing allows the client to derive new session private keys
// when attempting to negotiate a session with a tower. A derived key is
// used both to dial the tower and to compute the resulting session's ID.
SecretKeyRing ECDHKeyRing
// Candidates is an abstract set of tower candidates that the negotiator
// will traverse serially when attempting to negotiate a new session.
// The negotiator calls Reset once Next reports
// ErrTowerCandidatesExhausted.
Candidates TowerCandidateIterator
// Policy defines the session policy that will be proposed to towers
// when attempting to negotiate a new session. This policy will be used
// across all negotiation proposals for the lifetime of the negotiator.
// It also decides whether anchor commitment support is required of the
// tower in our Init message.
Policy wtpolicy.Policy
// Dial initiates an outbound brontide connection to the given address
// using a specified private key. The peer is returned in the event of a
// successful connection.
Dial func(keychain.SingleKeyECDH, *lnwire.NetAddress) (wtserver.Peer,
error)
// SendMessage writes a wtwire message to a remote peer.
SendMessage func(wtserver.Peer, wtwire.Message) error
// ReadMessage reads a message from a remote peer and returns the
// decoded wtwire message.
ReadMessage func(wtserver.Peer) (wtwire.Message, error)
// ChainHash is the genesis hash identifying the chain for any negotiated
// sessions. Any state updates sent to that session should also
// originate from this chain.
ChainHash chainhash.Hash
// MinBackoff defines the initial backoff applied by the session
// negotiator after all tower candidates have been exhausted and
// reattempting negotiation with the same set of candidates. Subsequent
// backoff durations will grow exponentially.
//
// NOTE: the first backoff is set to exactly MinBackoff and later ones
// are derived by doubling it, so a zero value means failed rounds are
// retried without any delay.
MinBackoff time.Duration
// MaxBackoff defines the maximum backoff applied by the session
// negotiator after all tower candidates have been exhausted and
// reattempting negotiation with the same set of candidates. If the
// exponential backoff produces a timeout greater than this value, the
// backoff duration will be clamped to MaxBackoff. The initial
// MinBackoff step itself is not clamped.
MaxBackoff time.Duration
// Log specifies the desired log output, which should be prefixed by the
// client type, e.g. anchor or legacy.
Log btclog.Logger
}
// sessionNegotiator is concrete SessionNegotiator that is able to request new
// sessions from a set of candidate towers asynchronously and return successful
// sessions to the primary client.
type sessionNegotiator struct {
started sync.Once
stopped sync.Once
localInit *wtwire.Init
cfg *NegotiatorConfig
log btclog.Logger
dispatcher chan struct{}
newSessions chan *wtdb.ClientSession
successfulNegotiations chan *wtdb.ClientSession
wg sync.WaitGroup
quit chan struct{}
}
// Compile-time constraint to ensure a *sessionNegotiator implements the
// SessionNegotiator interface.
var _ SessionNegotiator = (*sessionNegotiator)(nil)
// newSessionNegotiator builds a sessionNegotiator from cfg. No goroutines are
// started here; the negotiator stays idle until Start is called.
func newSessionNegotiator(cfg *NegotiatorConfig) *sessionNegotiator {
	// Every tower must support altruist sessions. When the configured
	// policy targets anchor channels, we additionally require anchor
	// commitment support from the tower.
	featureBits := []lnwire.FeatureBit{wtwire.AltruistSessionsRequired}
	if cfg.Policy.IsAnchorChannel() {
		featureBits = append(featureBits, wtwire.AnchorCommitRequired)
	}

	// The Init message depends only on the configuration, so it is built
	// once here and reused for every connection attempt.
	initMsg := wtwire.NewInitMessage(
		lnwire.NewRawFeatureVector(featureBits...), cfg.ChainHash,
	)

	n := &sessionNegotiator{
		cfg:       cfg,
		log:       cfg.Log,
		localInit: initMsg,
		quit:      make(chan struct{}),

		// A single-slot buffer lets RequestSession record one pending
		// request without blocking; extra requests are dropped while
		// the slot is occupied.
		dispatcher: make(chan struct{}, 1),

		newSessions:            make(chan *wtdb.ClientSession),
		successfulNegotiations: make(chan *wtdb.ClientSession),
	}

	return n
}
// Start launches the negotiation dispatcher goroutine. Only the first call
// has any effect, and the returned error is always nil.
func (n *sessionNegotiator) Start() error {
	launch := func() {
		n.log.Debugf("Starting session negotiator")

		n.wg.Add(1)
		go n.negotiationDispatcher()
	}
	n.started.Do(launch)

	return nil
}
// Stop signals every negotiator goroutine to exit and blocks until all of
// them have returned. Only the first call has any effect, and the returned
// error is always nil.
func (n *sessionNegotiator) Stop() error {
	shutdown := func() {
		n.log.Debugf("Stopping session negotiator")

		// Closing quit broadcasts the shutdown to the dispatcher and to
		// any negotiation goroutine that is currently running.
		close(n.quit)
		n.wg.Wait()
	}
	n.stopped.Do(shutdown)

	return nil
}
// NewSessions exposes the receive-only side of the channel on which newly
// negotiated sessions are delivered to the client.
func (n *sessionNegotiator) NewSessions() <-chan *wtdb.ClientSession {
	var sessions <-chan *wtdb.ClientSession = n.newSessions
	return sessions
}

// RequestSession asks the negotiator for one more session without blocking.
// The dispatcher channel holds at most one unconsumed signal. If that slot is
// already occupied, this request is silently dropped. Otherwise the dispatch
// loop counts it and negotiates the session once any in-flight negotiation
// has completed.
func (n *sessionNegotiator) RequestSession() {
	signal := struct{}{}

	select {
	case n.dispatcher <- signal:
		// Request recorded.
	default:
		// A signal is already queued; drop this one.
	}
}
// negotiationDispatcher is the negotiator's main event loop. It counts the
// outstanding session requests and dispatches at most one negotiation at a
// time. Each successful session is forwarded to the client, after which the
// next negotiation is launched if requests remain. The loop exits when quit
// is closed.
//
// NOTE: This method MUST be run as a goroutine.
func (n *sessionNegotiator) negotiationDispatcher() {
	defer n.wg.Done()

	// outstanding is the number of requested sessions not yet delivered
	// to the client.
	var outstanding int

	// spawn starts one negotiation attempt tracked by the wait group.
	spawn := func() {
		n.wg.Add(1)
		go n.negotiate()
	}

	for {
		select {
		case <-n.quit:
			return

		case <-n.dispatcher:
			outstanding++

			// A negotiation is already underway. This request is
			// served after it completes.
			if outstanding > 1 {
				n.log.Debugf("Already negotiating session, " +
					"waiting for existing negotiation to " +
					"complete")
				continue
			}

			// TODO(conner): consider reusing good towers
			n.log.Debugf("Dispatching session negotiation")
			spawn()

		case session := <-n.successfulNegotiations:
			// Hand the session to the client, bailing out if we
			// are told to shut down while waiting.
			select {
			case n.newSessions <- session:
				outstanding--
			case <-n.quit:
				return
			}

			if outstanding > 0 {
				n.log.Debugf("Dispatching pending session " +
					"negotiation")
				spawn()
			}
		}
	}
}
// negotiate walks the tower candidates and tries to negotiate a new session.
// It keeps going until one attempt succeeds or the negotiator is shut down.
// A successful session is handed to the dispatcher inside createSession, so
// this method has nothing to return.
//
// The retry delay grows each time the candidate iterator fails or a
// negotiation with a tower fails. It starts at MinBackoff, then doubles on
// every later failure, capped at MaxBackoff. When the iterator reports
// ErrTowerCandidatesExhausted, it is reset so the next round starts over.
//
// NOTE: This method MUST be run as a goroutine.
func (n *sessionNegotiator) negotiate() {
	defer n.wg.Done()

	// delay stays zero until the first failure, so the first round of
	// attempts starts immediately.
	var delay time.Duration

	// bumpDelay advances delay to the next step of the exponential
	// backoff: MinBackoff first, then doubling, capped at MaxBackoff.
	bumpDelay := func() {
		switch {
		case delay == 0:
			delay = n.cfg.MinBackoff
		case delay*2 > n.cfg.MaxBackoff:
			delay = n.cfg.MaxBackoff
		default:
			delay *= 2
		}
	}

retry:
	for {
		// Wait out the current delay before starting another round,
		// unless we are asked to shut down first.
		if delay > 0 {
			select {
			case <-time.After(delay):
			case <-n.quit:
				return
			}
		}

		for {
			select {
			case <-n.quit:
				return
			default:
			}

			tower, err := n.cfg.Candidates.Next()
			if err != nil {
				bumpDelay()
				n.log.Debugf("Unable to get new tower candidate, "+
					"retrying after %v -- reason: %v", delay, err)

				// Resetting only once every candidate has been
				// tried spreads sessions more evenly across the
				// tower candidates.
				if err == ErrTowerCandidatesExhausted {
					n.cfg.Candidates.Reset()
				}
				continue retry
			}

			towerPub := tower.IdentityKey.SerializeCompressed()
			n.log.Debugf("Attempting session negotiation with tower=%x",
				towerPub)

			// Reserve a session key index for this tower, or get
			// back the one already reserved for it.
			keyIndex, err := n.cfg.DB.NextSessionKeyIndex(
				tower.ID, n.cfg.Policy.BlobType,
			)
			if err != nil {
				n.log.Debugf("Unable to reserve session key index "+
					"for tower=%x: %v", towerPub, err)
				continue
			}

			// Run the CreateSession exchange, trying each of the
			// tower's addresses as needed.
			if err := n.createSession(tower, keyIndex); err != nil {
				bumpDelay()
				n.log.Debugf("Session negotiation with tower=%x "+
					"failed, trying again -- reason: %v",
					towerPub, err)
				continue retry
			}

			// Success.
			return
		}
	}
}
// createSession attempts to negotiate a session with tower. It tries each of
// the tower's stored addresses in turn, using a session key derived from
// keyIndex. It returns:
//   - nil after the first successful negotiation;
//   - ErrNoTowerAddrs if the tower has no addresses;
//   - ErrNegotiatorExiting if the negotiator is shutting down;
//   - ErrFailedNegotiation if every address failed;
//   - the key-derivation error if the session key cannot be derived.
func (n *sessionNegotiator) createSession(tower *wtdb.Tower,
	keyIndex uint32) error {

	// If the tower has no addresses, there's nothing we can do.
	if len(tower.Addresses) == 0 {
		return ErrNoTowerAddrs
	}

	// Derive the session key for this index. It is used to dial the
	// tower and determines the resulting session's ID.
	sessionKeyDesc, err := n.cfg.SecretKeyRing.DeriveKey(
		keychain.KeyLocator{
			Family: keychain.KeyFamilyTowerSession,
			Index:  keyIndex,
		},
	)
	if err != nil {
		return err
	}

	sessionKey := keychain.NewPubKeyECDH(
		sessionKeyDesc, n.cfg.SecretKeyRing,
	)

	for _, lnAddr := range tower.LNAddrs() {
		// Don't start another dial if we've been asked to shut down.
		select {
		case <-n.quit:
			return ErrNegotiatorExiting
		default:
		}

		err := n.tryAddress(sessionKey, keyIndex, tower, lnAddr)
		switch {
		case err == nil:
			return nil

		// tryAddress returns this only after the session was already
		// persisted and shutdown began while handing it off. Trying
		// the remaining addresses would be wasted work and could
		// attempt to persist the same session again.
		case err == ErrNegotiatorExiting:
			return err

		case err == ErrPermanentTowerFailure:
			// TODO(conner): report to iterator? can then be reset
			// with restart
			fallthrough

		default:
			n.log.Debugf("Request for session negotiation with "+
				"tower=%s failed, trying again -- reason: "+
				"%v", lnAddr, err)
		}
	}

	return ErrFailedNegotiation
}
// tryAddress performs a single CreateSession exchange with tower over lnAddr,
// which should be one of the tower's addresses. The steps are:
//   - dial the tower using sessionKey;
//   - exchange Init messages and validate the tower's;
//   - send a CreateSession built from the configured policy;
//   - interpret the tower's CreateSessionReply.
//
// On an OK or AlreadyExists reply, the session is persisted with
// DB.CreateClientSession and handed to the dispatcher. nil is returned only
// if all of these steps succeed. Other outcomes:
//   - ErrNegotiatorExiting if shutdown is signalled during the hand-off;
//   - ErrPermanentTowerFailure if the tower rejects the reward rate of a
//     session that did not request a reward;
//   - a descriptive error for any other failure.
//
// The connection to the tower is closed before this method returns.
func (n *sessionNegotiator) tryAddress(sessionKey keychain.SingleKeyECDH,
	keyIndex uint32, tower *wtdb.Tower, lnAddr *lnwire.NetAddress) error {

	// Connect to the tower address using our generated session key.
	conn, err := n.cfg.Dial(sessionKey, lnAddr)
	if err != nil {
		return err
	}

	// The connection is only needed for this exchange. Close it on every
	// return path so failed or completed negotiations don't leak it. A
	// close error is not actionable here, so it is deliberately ignored.
	defer conn.Close()

	// Send local Init message.
	err = n.cfg.SendMessage(conn, n.localInit)
	if err != nil {
		return fmt.Errorf("unable to send Init: %w", err)
	}

	// Receive remote Init message.
	remoteMsg, err := n.cfg.ReadMessage(conn)
	if err != nil {
		return fmt.Errorf("unable to read Init: %w", err)
	}

	// Check that returned message is wtwire.Init.
	remoteInit, ok := remoteMsg.(*wtwire.Init)
	if !ok {
		return fmt.Errorf("expected Init, got %T in reply", remoteMsg)
	}

	// Verify the watchtower's remote Init message against our own.
	err = n.localInit.CheckRemoteInit(remoteInit, wtwire.FeatureNames)
	if err != nil {
		return err
	}

	// Propose a session using the configured policy.
	policy := n.cfg.Policy
	createSession := &wtwire.CreateSession{
		BlobType:     policy.BlobType,
		MaxUpdates:   policy.MaxUpdates,
		RewardBase:   policy.RewardBase,
		RewardRate:   policy.RewardRate,
		SweepFeeRate: policy.SweepFeeRate,
	}

	// Send CreateSession message.
	err = n.cfg.SendMessage(conn, createSession)
	if err != nil {
		return fmt.Errorf("unable to send CreateSession: %w", err)
	}

	// Receive CreateSessionReply message.
	remoteMsg, err = n.cfg.ReadMessage(conn)
	if err != nil {
		return fmt.Errorf("unable to read CreateSessionReply: %w", err)
	}

	// Check that returned message is wtwire.CreateSessionReply.
	createSessionReply, ok := remoteMsg.(*wtwire.CreateSessionReply)
	if !ok {
		return fmt.Errorf("expected CreateSessionReply, got %T in "+
			"reply", remoteMsg)
	}

	switch createSessionReply.Code {
	case wtwire.CodeOK, wtwire.CreateSessionCodeAlreadyExists:
		// TODO(conner): add last-applied to create session reply to
		// handle case where we lose state, session already exists, and
		// we want to possibly resume using the session

		// TODO(conner): validate reward address
		rewardPkScript := createSessionReply.Data

		sessionID := wtdb.NewSessionIDFromPubKey(sessionKey.PubKey())
		clientSession := &wtdb.ClientSession{
			ClientSessionBody: wtdb.ClientSessionBody{
				TowerID:        tower.ID,
				KeyIndex:       keyIndex,
				Policy:         n.cfg.Policy,
				RewardPkScript: rewardPkScript,
			},
			Tower:          tower,
			SessionKeyECDH: sessionKey,
			ID:             sessionID,
		}

		err = n.cfg.DB.CreateClientSession(clientSession)
		if err != nil {
			return fmt.Errorf("unable to persist ClientSession: %w",
				err)
		}

		n.log.Debugf("New session negotiated with %s, policy: %s",
			lnAddr, clientSession.Policy)

		// We have a newly negotiated session. Return it to the
		// dispatcher so that it can update how many outstanding
		// negotiation requests we have.
		select {
		case n.successfulNegotiations <- clientSession:
			return nil
		case <-n.quit:
			return ErrNegotiatorExiting
		}

	// TODO(conner): handle error codes properly
	case wtwire.CreateSessionCodeRejectBlobType:
		return fmt.Errorf("tower rejected blob type: %v",
			policy.BlobType)

	case wtwire.CreateSessionCodeRejectMaxUpdates:
		return fmt.Errorf("tower rejected max updates: %v",
			policy.MaxUpdates)

	case wtwire.CreateSessionCodeRejectRewardRate:
		// The tower rejected the session because of the reward rate. If
		// we didn't request a reward session, we'll treat this as a
		// permanent tower failure.
		if !policy.BlobType.Has(blob.FlagReward) {
			return ErrPermanentTowerFailure
		}

		return fmt.Errorf("tower rejected reward rate: %v",
			policy.RewardRate)

	case wtwire.CreateSessionCodeRejectSweepFeeRate:
		return fmt.Errorf("tower rejected sweep fee rate: %v",
			policy.SweepFeeRate)

	default:
		return fmt.Errorf("received unhandled error code: %v",
			createSessionReply.Code)
	}
}