diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 51d3c61b..d7d05635 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -9,7 +9,6 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/lightningnetwork/lnd/input" - "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/watchtower/wtdb" @@ -77,7 +76,7 @@ type Config struct { // SecretKeyRing is used to derive the session keys used to communicate // with the tower. The client only stores the KeyLocators internally so // that we never store private keys on disk. - SecretKeyRing keychain.SecretKeyRing + SecretKeyRing SecretKeyRing // Dial connects to an addr using the specified net and returns the // connection object. @@ -201,15 +200,16 @@ func New(config *Config) (*TowerClient, error) { forceQuit: make(chan struct{}), } c.negotiator = newSessionNegotiator(&NegotiatorConfig{ - DB: cfg.DB, - Policy: cfg.Policy, - ChainHash: cfg.ChainHash, - SendMessage: c.sendMessage, - ReadMessage: c.readMessage, - Dial: c.dial, - Candidates: newTowerListIterator(tower), - MinBackoff: cfg.MinBackoff, - MaxBackoff: cfg.MaxBackoff, + DB: cfg.DB, + SecretKeyRing: cfg.SecretKeyRing, + Policy: cfg.Policy, + ChainHash: cfg.ChainHash, + SendMessage: c.sendMessage, + ReadMessage: c.readMessage, + Dial: c.dial, + Candidates: newTowerListIterator(tower), + MinBackoff: cfg.MinBackoff, + MaxBackoff: cfg.MaxBackoff, }) // Next, load all active sessions from the db into the client. We will @@ -222,14 +222,25 @@ func New(config *Config) (*TowerClient, error) { } // Reload any towers from disk using the tower IDs contained in each - // candidate session. + // candidate session. We will also rederive any session keys needed to + // be able to communicate with the towers and authenticate session + // requests. This prevents us from having to store the private keys on + // disk. for _, s := range c.candidateSessions { tower, err := c.cfg.DB.LoadTower(s.TowerID) if err != nil { return nil, err } + sessionPriv, err := DeriveSessionKey( + c.cfg.SecretKeyRing, s.KeyIndex, + ) + if err != nil { + return nil, err + } + s.Tower = tower + s.SessionPrivKey = sessionPriv } // Finally, load the sweep pkscripts that have been generated for all diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index dba4275e..05753d37 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -430,10 +430,11 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { Dial: func(string, string) (net.Conn, error) { return nil, nil }, - DB: clientDB, - AuthDial: mockNet.AuthDial, - PrivateTower: towerAddr, - Policy: cfg.policy, + DB: clientDB, + AuthDial: mockNet.AuthDial, + SecretKeyRing: wtmock.NewSecretKeyRing(), + PrivateTower: towerAddr, + Policy: cfg.policy, NewAddress: func() ([]byte, error) { return addrScript, nil }, diff --git a/watchtower/wtclient/session_negotiator.go b/watchtower/wtclient/session_negotiator.go index b62819cb..19c7347c 100644 --- a/watchtower/wtclient/session_negotiator.go +++ b/watchtower/wtclient/session_negotiator.go @@ -42,6 +42,10 @@ type NegotiatorConfig struct { // negotiated sessions. DB DB + // SecretKeyRing allows the client to derive new session private keys + // when attempting to negotiate session with a tower. + SecretKeyRing SecretKeyRing + // Candidates is an abstract set of tower candidates that the negotiator // will traverse serially when attempting to negotiate a new session. Candidates TowerCandidateIterator @@ -255,12 +259,23 @@ retryWithBackoff: goto retryWithBackoff } + towerPub := tower.IdentityKey.SerializeCompressed() log.Debugf("Attempting session negotiation with tower=%x", - tower.IdentityKey.SerializeCompressed()) + towerPub) + + // Before proceeding, we will reserve a session key index to use + // with this specific tower. If one is already reserved, the + // existing index will be returned. + keyIndex, err := n.cfg.DB.NextSessionKeyIndex(tower.ID) + if err != nil { + log.Debugf("Unable to reserve session key index "+ + "for tower=%x: %v", towerPub, err) + continue + } // We'll now attempt the CreateSession dance with the tower to // get a new session, trying all addresses if necessary. - err = n.createSession(tower) + err = n.createSession(tower, keyIndex) if err != nil { log.Debugf("Session negotiation with tower=%x "+ "failed, trying again -- reason: %v", @@ -277,22 +292,21 @@ retryWithBackoff: // its stored addresses. This method returns after the first successful // negotiation, or after all addresses have failed with ErrFailedNegotiation. If // the tower has no addresses, ErrNoTowerAddrs is returned. -func (n *sessionNegotiator) createSession(tower *wtdb.Tower) error { +func (n *sessionNegotiator) createSession(tower *wtdb.Tower, + keyIndex uint32) error { + // If the tower has no addresses, there's nothing we can do. if len(tower.Addresses) == 0 { return ErrNoTowerAddrs } - // TODO(conner): create with hdkey at random index - sessionPrivKey, err := btcec.NewPrivateKey(btcec.S256()) + sessionPriv, err := DeriveSessionKey(n.cfg.SecretKeyRing, keyIndex) if err != nil { return err } - // TODO(conner): write towerAddr+privkey - for _, lnAddr := range tower.LNAddrs() { - err = n.tryAddress(sessionPrivKey, tower, lnAddr) + err = n.tryAddress(sessionPriv, keyIndex, tower, lnAddr) switch { case err == ErrPermanentTowerFailure: // TODO(conner): report to iterator? can then be reset @@ -318,7 +332,7 @@ func (n *sessionNegotiator) createSession(tower *wtdb.Tower) error { // returns true if all steps succeed and the new session has been persisted, and // fails otherwise. func (n *sessionNegotiator) tryAddress(privKey *btcec.PrivateKey, - tower *wtdb.Tower, lnAddr *lnwire.NetAddress) error { + keyIndex uint32, tower *wtdb.Tower, lnAddr *lnwire.NetAddress) error { // Connect to the tower address using our generated session key. conn, err := n.cfg.Dial(privKey, lnAddr) @@ -394,7 +408,8 @@ func (n *sessionNegotiator) tryAddress(privKey *btcec.PrivateKey, clientSession := &wtdb.ClientSession{ TowerID: tower.ID, Tower: tower, - SessionPrivKey: privKey, // remove after using HD keys + KeyIndex: keyIndex, + SessionPrivKey: privKey, ID: sessionID, Policy: n.cfg.Policy, SeqNum: 0, diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 6ce73dbc..a4408470 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -109,7 +109,6 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { m.activeSessions[session.ID] = &wtdb.ClientSession{ TowerID: session.TowerID, KeyIndex: session.KeyIndex, - SessionPrivKey: session.SessionPrivKey, ID: session.ID, Policy: session.Policy, SeqNum: session.SeqNum,