package wtclient

import (
	"fmt"
	"sync"
	"time"

	"github.com/btcsuite/btcd/chaincfg/chainhash"
	"github.com/lightningnetwork/lnd/keychain"
	"github.com/lightningnetwork/lnd/lnwire"
	"github.com/lightningnetwork/lnd/watchtower/blob"
	"github.com/lightningnetwork/lnd/watchtower/wtdb"
	"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
	"github.com/lightningnetwork/lnd/watchtower/wtserver"
	"github.com/lightningnetwork/lnd/watchtower/wtwire"
)

// SessionNegotiator is an interface for asynchronously requesting new sessions.
type SessionNegotiator interface {
	// RequestSession signals to the session negotiator that the client
	// needs another session. Once the session is negotiated, it should be
	// returned via NewSessions.
	RequestSession()

	// NewSessions is a read-only channel where newly negotiated sessions
	// will be delivered.
	NewSessions() <-chan *wtdb.ClientSession

	// Start safely initializes the session negotiator.
	Start() error

	// Stop safely shuts down the session negotiator.
	Stop() error
}

// NegotiatorConfig provides access to the resources required by a
// SessionNegotiator to faithfully carry out its duties. All nil-able fields
// must be initialized.
type NegotiatorConfig struct {
	// DB provides access to a persistent storage medium used by the tower
	// client to properly allocate session ephemeral keys and record
	// successfully negotiated sessions.
	DB DB

	// SecretKeyRing allows the client to derive new session private keys
	// when attempting to negotiate a session with a tower.
	SecretKeyRing ECDHKeyRing

	// Candidates is an abstract set of tower candidates that the negotiator
	// will traverse serially when attempting to negotiate a new session.
	Candidates TowerCandidateIterator

	// Policy defines the session policy that will be proposed to towers
	// when attempting to negotiate a new session. This policy will be used
	// across all negotiation proposals for the lifetime of the negotiator.
	Policy wtpolicy.Policy

	// Dial initiates an outbound brontide connection to the given address
	// using a specified private key. The peer is returned in the event of a
	// successful connection.
	Dial func(keychain.SingleKeyECDH, *lnwire.NetAddress) (wtserver.Peer,
		error)

	// SendMessage writes a wtwire message to the remote peer.
	SendMessage func(wtserver.Peer, wtwire.Message) error

	// ReadMessage reads a message from a remote peer and returns the
	// decoded wtwire message.
	ReadMessage func(wtserver.Peer) (wtwire.Message, error)

	// ChainHash is the genesis hash identifying the chain for any
	// negotiated sessions. Any state updates sent to that session should
	// also originate from this chain.
	ChainHash chainhash.Hash

	// MinBackoff defines the initial backoff applied by the session
	// negotiator after all tower candidates have been exhausted and
	// reattempting negotiation with the same set of candidates. Subsequent
	// backoff durations will grow exponentially.
	MinBackoff time.Duration

	// MaxBackoff defines the maximum backoff applied by the session
	// negotiator after all tower candidates have been exhausted and
	// reattempting negotiation with the same set of candidates. If the
	// exponential backoff produces a timeout greater than this value, the
	// backoff duration will be clamped to MaxBackoff.
	MaxBackoff time.Duration
}

// sessionNegotiator is a concrete SessionNegotiator that is able to request new
// sessions from a set of candidate towers asynchronously and return successful
// sessions to the primary client.
type sessionNegotiator struct {
	// started and stopped ensure the bodies of Start and Stop each run at
	// most once.
	started sync.Once
	stopped sync.Once

	// localInit is the Init message built at construction time and sent to
	// every tower at the beginning of a negotiation.
	localInit *wtwire.Init

	// cfg holds the dependencies and parameters supplied by the caller.
	cfg *NegotiatorConfig

	// dispatcher carries session requests from RequestSession to the
	// dispatcher goroutine. It is created with a buffer of one.
	dispatcher chan struct{}

	// newSessions delivers negotiated sessions to the consumer of
	// NewSessions.
	newSessions chan *wtdb.ClientSession

	// successfulNegotiations hands freshly persisted sessions from a
	// negotiation goroutine back to the dispatcher goroutine.
	successfulNegotiations chan *wtdb.ClientSession

	// wg tracks the dispatcher and negotiation goroutines so that Stop can
	// wait for them, and quit is closed by Stop to signal them to exit.
	wg   sync.WaitGroup
	quit chan struct{}
}

// Compile-time constraint to ensure a *sessionNegotiator implements the
// SessionNegotiator interface.
var _ SessionNegotiator = (*sessionNegotiator)(nil)

// newSessionNegotiator builds a sessionNegotiator around the provided config.
// The local Init message advertised to towers is assembled once here from the
// required altruist-sessions feature bit and the configured chain hash. The
// returned negotiator is idle until Start is called.
func newSessionNegotiator(cfg *NegotiatorConfig) *sessionNegotiator {
	features := lnwire.NewRawFeatureVector(
		wtwire.AltruistSessionsRequired,
	)
	initMsg := wtwire.NewInitMessage(features, cfg.ChainHash)

	negotiator := &sessionNegotiator{
		cfg:                    cfg,
		localInit:              initMsg,
		dispatcher:             make(chan struct{}, 1),
		newSessions:            make(chan *wtdb.ClientSession),
		successfulNegotiations: make(chan *wtdb.ClientSession),
		quit:                   make(chan struct{}),
	}

	return negotiator
}

// Start launches the dispatcher goroutine. Only the first call has any effect;
// the returned error is always nil.
func (n *sessionNegotiator) Start() error {
	launch := func() {
		log.Debugf("Starting session negotiator")

		n.wg.Add(1)
		go n.negotiationDispatcher()
	}
	n.started.Do(launch)

	return nil
}

// Stop signals every goroutine owned by the negotiator to exit and blocks until
// they have all returned. Only the first call has any effect; the returned
// error is always nil.
func (n *sessionNegotiator) Stop() error {
	shutdown := func() {
		log.Debugf("Stopping session negotiator")

		close(n.quit)
		n.wg.Wait()
	}
	n.stopped.Do(shutdown)

	return nil
}

// NewSessions exposes the receive-only channel on which freshly negotiated
// sessions are delivered.
func (n *sessionNegotiator) NewSessions() <-chan *wtdb.ClientSession {
	return n.newSessions
}

// RequestSession asks the negotiator for one more session. The call never
// blocks: the signal is dropped when the single-slot dispatcher buffer is
// already occupied by an earlier request that has not been picked up yet.
func (n *sessionNegotiator) RequestSession() {
	select {
	case n.dispatcher <- struct{}{}:
		// The request was handed to the dispatcher.

	default:
		// A request is already queued, nothing more to do.
	}
}

// negotiationDispatcher acts as the primary event loop for the
// sessionNegotiator, coordinating requests for more sessions and dispatching
// attempts to negotiate them from a list of candidates.
func (n *sessionNegotiator) negotiationDispatcher() { defer n.wg.Done() var pendingNegotiations int for { select { case <-n.dispatcher: pendingNegotiations++ if pendingNegotiations > 1 { log.Debugf("Already negotiating session, " + "waiting for existing negotiation to " + "complete") continue } // TODO(conner): consider reusing good towers log.Debugf("Dispatching session negotiation") n.wg.Add(1) go n.negotiate() case session := <-n.successfulNegotiations: select { case n.newSessions <- session: pendingNegotiations-- case <-n.quit: return } if pendingNegotiations > 0 { log.Debugf("Dispatching pending session " + "negotiation") n.wg.Add(1) go n.negotiate() } case <-n.quit: return } } } // negotiate handles the process of iterating through potential tower candidates // and attempting to negotiate a new session until a successful negotiation // occurs. If the candidate iterator becomes exhausted because none were // successful, this method will back off exponentially up to the configured max // backoff. This method will continue trying until a negotiation is successful // before returning the negotiated session to the dispatcher via the succeed // channel. // // NOTE: This method MUST be run as a goroutine. func (n *sessionNegotiator) negotiate() { defer n.wg.Done() // On the first pass, initialize the backoff to our configured min // backoff. var backoff time.Duration retryWithBackoff: // If we are retrying, wait out the delay before continuing. if backoff > 0 { select { case <-time.After(backoff): case <-n.quit: return } } for { select { case <-n.quit: return default: } // Pull the next candidate from our list of addresses. tower, err := n.cfg.Candidates.Next() if err != nil { if backoff == 0 { backoff = n.cfg.MinBackoff } else { // We've run out of addresses, double and clamp // backoff. 
backoff *= 2 if backoff > n.cfg.MaxBackoff { backoff = n.cfg.MaxBackoff } } log.Debugf("Unable to get new tower candidate, "+ "retrying after %v -- reason: %v", backoff, err) // Only reset the iterator once we've exhausted all // candidates. Doing so allows us to load balance // sessions better amongst all of the tower candidates. if err == ErrTowerCandidatesExhausted { n.cfg.Candidates.Reset() } goto retryWithBackoff } towerPub := tower.IdentityKey.SerializeCompressed() log.Debugf("Attempting session negotiation with tower=%x", towerPub) // Before proceeding, we will reserve a session key index to use // with this specific tower. If one is already reserved, the // existing index will be returned. keyIndex, err := n.cfg.DB.NextSessionKeyIndex(tower.ID) if err != nil { log.Debugf("Unable to reserve session key index "+ "for tower=%x: %v", towerPub, err) continue } // We'll now attempt the CreateSession dance with the tower to // get a new session, trying all addresses if necessary. err = n.createSession(tower, keyIndex) if err != nil { log.Debugf("Session negotiation with tower=%x "+ "failed, trying again -- reason: %v", tower.IdentityKey.SerializeCompressed(), err) continue } // Success. return } } // createSession takes a tower an attempts to negotiate a session using any of // its stored addresses. This method returns after the first successful // negotiation, or after all addresses have failed with ErrFailedNegotiation. If // the tower has no addresses, ErrNoTowerAddrs is returned. func (n *sessionNegotiator) createSession(tower *wtdb.Tower, keyIndex uint32) error { // If the tower has no addresses, there's nothing we can do. 
if len(tower.Addresses) == 0 { return ErrNoTowerAddrs } sessionKeyDesc, err := n.cfg.SecretKeyRing.DeriveKey( keychain.KeyLocator{ Family: keychain.KeyFamilyTowerSession, Index: keyIndex, }, ) if err != nil { return err } sessionKey := keychain.NewPubKeyECDH( sessionKeyDesc, n.cfg.SecretKeyRing, ) for _, lnAddr := range tower.LNAddrs() { err := n.tryAddress(sessionKey, keyIndex, tower, lnAddr) switch { case err == ErrPermanentTowerFailure: // TODO(conner): report to iterator? can then be reset // with restart fallthrough case err != nil: log.Debugf("Request for session negotiation with "+ "tower=%s failed, trying again -- reason: "+ "%v", lnAddr, err) continue default: return nil } } return ErrFailedNegotiation } // tryAddress executes a single create session dance using the given address. // The address should belong to the tower's set of addresses. This method only // returns true if all steps succeed and the new session has been persisted, and // fails otherwise. func (n *sessionNegotiator) tryAddress(sessionKey keychain.SingleKeyECDH, keyIndex uint32, tower *wtdb.Tower, lnAddr *lnwire.NetAddress) error { // Connect to the tower address using our generated session key. conn, err := n.cfg.Dial(sessionKey, lnAddr) if err != nil { return err } // Send local Init message. err = n.cfg.SendMessage(conn, n.localInit) if err != nil { return fmt.Errorf("unable to send Init: %v", err) } // Receive remote Init message. remoteMsg, err := n.cfg.ReadMessage(conn) if err != nil { return fmt.Errorf("unable to read Init: %v", err) } // Check that returned message is wtwire.Init. remoteInit, ok := remoteMsg.(*wtwire.Init) if !ok { return fmt.Errorf("expected Init, got %T in reply", remoteMsg) } // Verify the watchtower's remote Init message against our own. 
err = n.localInit.CheckRemoteInit(remoteInit, wtwire.FeatureNames) if err != nil { return err } policy := n.cfg.Policy createSession := &wtwire.CreateSession{ BlobType: policy.BlobType, MaxUpdates: policy.MaxUpdates, RewardBase: policy.RewardBase, RewardRate: policy.RewardRate, SweepFeeRate: policy.SweepFeeRate, } // Send CreateSession message. err = n.cfg.SendMessage(conn, createSession) if err != nil { return fmt.Errorf("unable to send CreateSession: %v", err) } // Receive CreateSessionReply message. remoteMsg, err = n.cfg.ReadMessage(conn) if err != nil { return fmt.Errorf("unable to read CreateSessionReply: %v", err) } // Check that returned message is wtwire.CreateSessionReply. createSessionReply, ok := remoteMsg.(*wtwire.CreateSessionReply) if !ok { return fmt.Errorf("expected CreateSessionReply, got %T in "+ "reply", remoteMsg) } switch createSessionReply.Code { case wtwire.CodeOK, wtwire.CreateSessionCodeAlreadyExists: // TODO(conner): add last-applied to create session reply to // handle case where we lose state, session already exists, and // we want to possibly resume using the session // TODO(conner): validate reward address rewardPkScript := createSessionReply.Data sessionID := wtdb.NewSessionIDFromPubKey(sessionKey.PubKey()) clientSession := &wtdb.ClientSession{ ClientSessionBody: wtdb.ClientSessionBody{ TowerID: tower.ID, KeyIndex: keyIndex, Policy: n.cfg.Policy, RewardPkScript: rewardPkScript, }, Tower: tower, SessionKeyECDH: sessionKey, ID: sessionID, } err = n.cfg.DB.CreateClientSession(clientSession) if err != nil { return fmt.Errorf("unable to persist ClientSession: %v", err) } log.Debugf("New session negotiated with %s, policy: %s", lnAddr, clientSession.Policy) // We have a newly negotiated session, return it to the // dispatcher so that it can update how many outstanding // negotiation requests we have. 
select { case n.successfulNegotiations <- clientSession: return nil case <-n.quit: return ErrNegotiatorExiting } // TODO(conner): handle error codes properly case wtwire.CreateSessionCodeRejectBlobType: return fmt.Errorf("tower rejected blob type: %v", policy.BlobType) case wtwire.CreateSessionCodeRejectMaxUpdates: return fmt.Errorf("tower rejected max updates: %v", policy.MaxUpdates) case wtwire.CreateSessionCodeRejectRewardRate: // The tower rejected the session because of the reward rate. If // we didn't request a reward session, we'll treat this as a // permanent tower failure. if !policy.BlobType.Has(blob.FlagReward) { return ErrPermanentTowerFailure } return fmt.Errorf("tower rejected reward rate: %v", policy.RewardRate) case wtwire.CreateSessionCodeRejectSweepFeeRate: return fmt.Errorf("tower rejected sweep fee rate: %v", policy.SweepFeeRate) default: return fmt.Errorf("received unhandled error code: %v", createSessionReply.Code) } }