Merge pull request #2820 from cfromknecht/session-key-derivation
wtclient: session private key derivation
This commit is contained in:
commit
f1df2eadb7
@ -90,6 +90,12 @@ const (
|
||||
// a payment, or self stored on disk in a single file containing all
|
||||
// the static channel backups.
|
||||
KeyFamilyStaticBackup KeyFamily = 7
|
||||
|
||||
// KeyFamilyTowerSession is the family of keys that will be used to
|
||||
// derive session keys when negotiating sessions with watchtowers. The
|
||||
// session keys are limited to the lifetime of the session and are used
|
||||
// to increase privacy in the watchtower protocol.
|
||||
KeyFamilyTowerSession KeyFamily = 8
|
||||
)
|
||||
|
||||
// KeyLocator is a two-tuple that can be used to derive *any* key that has ever
|
||||
|
@ -9,7 +9,6 @@ import (
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/lightningnetwork/lnd/input"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
"github.com/lightningnetwork/lnd/lnwallet"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
@ -77,7 +76,7 @@ type Config struct {
|
||||
// SecretKeyRing is used to derive the session keys used to communicate
|
||||
// with the tower. The client only stores the KeyLocators internally so
|
||||
// that we never store private keys on disk.
|
||||
SecretKeyRing keychain.SecretKeyRing
|
||||
SecretKeyRing SecretKeyRing
|
||||
|
||||
// Dial connects to an addr using the specified net and returns the
|
||||
// connection object.
|
||||
@ -201,15 +200,16 @@ func New(config *Config) (*TowerClient, error) {
|
||||
forceQuit: make(chan struct{}),
|
||||
}
|
||||
c.negotiator = newSessionNegotiator(&NegotiatorConfig{
|
||||
DB: cfg.DB,
|
||||
Policy: cfg.Policy,
|
||||
ChainHash: cfg.ChainHash,
|
||||
SendMessage: c.sendMessage,
|
||||
ReadMessage: c.readMessage,
|
||||
Dial: c.dial,
|
||||
Candidates: newTowerListIterator(tower),
|
||||
MinBackoff: cfg.MinBackoff,
|
||||
MaxBackoff: cfg.MaxBackoff,
|
||||
DB: cfg.DB,
|
||||
SecretKeyRing: cfg.SecretKeyRing,
|
||||
Policy: cfg.Policy,
|
||||
ChainHash: cfg.ChainHash,
|
||||
SendMessage: c.sendMessage,
|
||||
ReadMessage: c.readMessage,
|
||||
Dial: c.dial,
|
||||
Candidates: newTowerListIterator(tower),
|
||||
MinBackoff: cfg.MinBackoff,
|
||||
MaxBackoff: cfg.MaxBackoff,
|
||||
})
|
||||
|
||||
// Next, load all active sessions from the db into the client. We will
|
||||
@ -221,6 +221,28 @@ func New(config *Config) (*TowerClient, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Reload any towers from disk using the tower IDs contained in each
|
||||
// candidate session. We will also rederive any session keys needed to
|
||||
// be able to communicate with the towers and authenticate session
|
||||
// requests. This prevents us from having to store the private keys on
|
||||
// disk.
|
||||
for _, s := range c.candidateSessions {
|
||||
tower, err := c.cfg.DB.LoadTower(s.TowerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sessionPriv, err := DeriveSessionKey(
|
||||
c.cfg.SecretKeyRing, s.KeyIndex,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.Tower = tower
|
||||
s.SessionPrivKey = sessionPriv
|
||||
}
|
||||
|
||||
// Finally, load the sweep pkscripts that have been generated for all
|
||||
// previously registered channels.
|
||||
c.sweepPkScripts, err = c.cfg.DB.FetchChanPkScripts()
|
||||
@ -334,9 +356,6 @@ func (c *TowerClient) ForceQuit() {
|
||||
c.forced.Do(func() {
|
||||
log.Infof("Force quitting watchtower client")
|
||||
|
||||
// Cancel log message from stop.
|
||||
close(c.forceQuit)
|
||||
|
||||
// 1. Shutdown the backup queue, which will prevent any further
|
||||
// updates from being accepted. In practice, the links should be
|
||||
// shutdown before the client has been stopped, so all updates
|
||||
@ -347,6 +366,7 @@ func (c *TowerClient) ForceQuit() {
|
||||
// dispatcher to exit. The backup queue will signal it's
|
||||
// completion to the dispatcher, which releases the wait group
|
||||
// after all tasks have been assigned to session queues.
|
||||
close(c.forceQuit)
|
||||
c.wg.Wait()
|
||||
|
||||
// 3. Since all valid tasks have been assigned to session
|
||||
@ -490,6 +510,9 @@ func (c *TowerClient) backupDispatcher() {
|
||||
|
||||
case <-c.statTicker.C:
|
||||
log.Infof("Client stats: %s", c.stats)
|
||||
|
||||
case <-c.forceQuit:
|
||||
return
|
||||
}
|
||||
|
||||
// No active session queue but have additional sessions.
|
||||
|
@ -379,10 +379,11 @@ type testHarness struct {
|
||||
}
|
||||
|
||||
type harnessCfg struct {
|
||||
localBalance lnwire.MilliSatoshi
|
||||
remoteBalance lnwire.MilliSatoshi
|
||||
policy wtpolicy.Policy
|
||||
noRegisterChan0 bool
|
||||
localBalance lnwire.MilliSatoshi
|
||||
remoteBalance lnwire.MilliSatoshi
|
||||
policy wtpolicy.Policy
|
||||
noRegisterChan0 bool
|
||||
noAckCreateSession bool
|
||||
}
|
||||
|
||||
func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
|
||||
@ -414,6 +415,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
|
||||
NewAddress: func() (btcutil.Address, error) {
|
||||
return addr, nil
|
||||
},
|
||||
NoAckCreateSession: cfg.noAckCreateSession,
|
||||
}
|
||||
|
||||
server, err := wtserver.New(serverCfg)
|
||||
@ -430,10 +432,11 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
|
||||
Dial: func(string, string) (net.Conn, error) {
|
||||
return nil, nil
|
||||
},
|
||||
DB: clientDB,
|
||||
AuthDial: mockNet.AuthDial,
|
||||
PrivateTower: towerAddr,
|
||||
Policy: cfg.policy,
|
||||
DB: clientDB,
|
||||
AuthDial: mockNet.AuthDial,
|
||||
SecretKeyRing: wtmock.NewSecretKeyRing(),
|
||||
PrivateTower: towerAddr,
|
||||
Policy: cfg.policy,
|
||||
NewAddress: func() ([]byte, error) {
|
||||
return addrScript, nil
|
||||
},
|
||||
@ -729,6 +732,36 @@ func (h *testHarness) waitServerUpdates(hints []wtdb.BreachHint,
|
||||
}
|
||||
}
|
||||
|
||||
// assertUpdatesForPolicy queries the server db for matches using the provided
|
||||
// breach hints, then asserts that each match has a session with the expected
|
||||
// policy.
|
||||
func (h *testHarness) assertUpdatesForPolicy(hints []wtdb.BreachHint,
|
||||
expPolicy wtpolicy.Policy) {
|
||||
|
||||
// Query for matches on the provided hints.
|
||||
matches, err := h.serverDB.QueryMatches(hints)
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to query for matches: %v", err)
|
||||
}
|
||||
|
||||
// Assert that the number of matches is exactly the number of provided
|
||||
// hints.
|
||||
if len(matches) != len(hints) {
|
||||
h.t.Fatalf("expected: %d matches, got: %d", len(hints),
|
||||
len(matches))
|
||||
}
|
||||
|
||||
// Assert that all of the matches correspond to a session with the
|
||||
// expected policy.
|
||||
for _, match := range matches {
|
||||
matchPolicy := match.SessionInfo.Policy
|
||||
if expPolicy != matchPolicy {
|
||||
h.t.Fatalf("expected session to have policy: %v, "+
|
||||
"got: %v", expPolicy, matchPolicy)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
localBalance = lnwire.MilliSatoshi(100000000)
|
||||
remoteBalance = lnwire.MilliSatoshi(200000000)
|
||||
@ -1098,6 +1131,119 @@ var clientTests = []clientTest{
|
||||
h.waitServerUpdates(hints, 10*time.Second)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "create session no ack",
|
||||
cfg: harnessCfg{
|
||||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
BlobType: blob.TypeDefault,
|
||||
MaxUpdates: 5,
|
||||
SweepFeeRate: 1,
|
||||
},
|
||||
noAckCreateSession: true,
|
||||
},
|
||||
fn: func(h *testHarness) {
|
||||
const (
|
||||
chanID = 0
|
||||
numUpdates = 3
|
||||
)
|
||||
|
||||
// Generate the retributions that will be backed up.
|
||||
hints := h.advanceChannelN(chanID, numUpdates)
|
||||
|
||||
// Now, queue the retributions for backup.
|
||||
h.backupStates(chanID, 0, numUpdates, nil)
|
||||
|
||||
// Since the client is unable to create a session, the
|
||||
// server should have no updates.
|
||||
h.waitServerUpdates(nil, time.Second)
|
||||
|
||||
// Force quit the client since it has queued backups.
|
||||
h.client.ForceQuit()
|
||||
|
||||
// Restart the server and allow it to ack session
|
||||
// creation.
|
||||
h.server.Stop()
|
||||
h.serverCfg.NoAckCreateSession = false
|
||||
h.startServer()
|
||||
defer h.server.Stop()
|
||||
|
||||
// Restart the client with the same policy, which will
|
||||
// immediately try to overwrite the old session with an
|
||||
// identical one.
|
||||
h.startClient()
|
||||
defer h.client.ForceQuit()
|
||||
|
||||
// Now, queue the retributions for backup.
|
||||
h.backupStates(chanID, 0, numUpdates, nil)
|
||||
|
||||
// Wait for all of the updates to be populated in the
|
||||
// server's database.
|
||||
h.waitServerUpdates(hints, 5*time.Second)
|
||||
|
||||
// Assert that the server has updates for the clients
|
||||
// most recent policy.
|
||||
h.assertUpdatesForPolicy(hints, h.clientCfg.Policy)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "create session no ack change policy",
|
||||
cfg: harnessCfg{
|
||||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
BlobType: blob.TypeDefault,
|
||||
MaxUpdates: 5,
|
||||
SweepFeeRate: 1,
|
||||
},
|
||||
noAckCreateSession: true,
|
||||
},
|
||||
fn: func(h *testHarness) {
|
||||
const (
|
||||
chanID = 0
|
||||
numUpdates = 3
|
||||
)
|
||||
|
||||
// Generate the retributions that will be backed up.
|
||||
hints := h.advanceChannelN(chanID, numUpdates)
|
||||
|
||||
// Now, queue the retributions for backup.
|
||||
h.backupStates(chanID, 0, numUpdates, nil)
|
||||
|
||||
// Since the client is unable to create a session, the
|
||||
// server should have no updates.
|
||||
h.waitServerUpdates(nil, time.Second)
|
||||
|
||||
// Force quit the client since it has queued backups.
|
||||
h.client.ForceQuit()
|
||||
|
||||
// Restart the server and allow it to ack session
|
||||
// creation.
|
||||
h.server.Stop()
|
||||
h.serverCfg.NoAckCreateSession = false
|
||||
h.startServer()
|
||||
defer h.server.Stop()
|
||||
|
||||
// Restart the client with a new policy, which will
|
||||
// immediately try to overwrite the prior session with
|
||||
// the old policy.
|
||||
h.clientCfg.Policy.SweepFeeRate = 2
|
||||
h.startClient()
|
||||
defer h.client.ForceQuit()
|
||||
|
||||
// Now, queue the retributions for backup.
|
||||
h.backupStates(chanID, 0, numUpdates, nil)
|
||||
|
||||
// Wait for all of the updates to be populated in the
|
||||
// server's database.
|
||||
h.waitServerUpdates(hints, 5*time.Second)
|
||||
|
||||
// Assert that the server has updates for the clients
|
||||
// most recent policy.
|
||||
h.assertUpdatesForPolicy(hints, h.clientCfg.Policy)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// TestClient executes the client test suite, asserting the ability to backup
|
||||
|
24
watchtower/wtclient/derivation.go
Normal file
24
watchtower/wtclient/derivation.go
Normal file
@ -0,0 +1,24 @@
|
||||
package wtclient
|
||||
|
||||
import (
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
)
|
||||
|
||||
// DeriveSessionKey accepts an session key index for an existing session and
|
||||
// derives the HD private key to be used to authenticate the brontide transport
|
||||
// and authenticate requests sent to the tower. The key will use the
|
||||
// keychain.KeyFamilyTowerSession and the provided index, giving a BIP43
|
||||
// derivation path of:
|
||||
//
|
||||
// * m/1017'/coinType'/8/0/index
|
||||
func DeriveSessionKey(keyRing SecretKeyRing,
|
||||
index uint32) (*btcec.PrivateKey, error) {
|
||||
|
||||
return keyRing.DerivePrivKey(keychain.KeyDescriptor{
|
||||
KeyLocator: keychain.KeyLocator{
|
||||
Family: keychain.KeyFamilyTowerSession,
|
||||
Index: index,
|
||||
},
|
||||
})
|
||||
}
|
@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/lightningnetwork/lnd/brontide"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtserver"
|
||||
@ -19,6 +20,17 @@ type DB interface {
|
||||
// sessions.
|
||||
CreateTower(*lnwire.NetAddress) (*wtdb.Tower, error)
|
||||
|
||||
// LoadTower retrieves a tower by its tower ID.
|
||||
LoadTower(uint64) (*wtdb.Tower, error)
|
||||
|
||||
// NextSessionKeyIndex reserves a new session key derivation index for a
|
||||
// particular tower id. The index is reserved for that tower until
|
||||
// CreateClientSession is invoked for that tower and index, at which
|
||||
// point a new index for that tower can be reserved. Multiple calls to
|
||||
// this method before CreateClientSession is invoked should return the
|
||||
// same index.
|
||||
NextSessionKeyIndex(uint64) (uint32, error)
|
||||
|
||||
// CreateClientSession saves a newly negotiated client session to the
|
||||
// client's database. This enables the session to be used across
|
||||
// restarts.
|
||||
@ -74,3 +86,11 @@ func AuthDial(localPriv *btcec.PrivateKey, netAddr *lnwire.NetAddress,
|
||||
|
||||
return brontide.Dial(localPriv, netAddr, dialer)
|
||||
}
|
||||
|
||||
// SecretKeyRing abstracts the ability to derive HD private keys given a
|
||||
// description of the derivation path.
|
||||
type SecretKeyRing interface {
|
||||
// DerivePrivKey derives the private key from the root seed using a
|
||||
// key descriptor specifying the key's derivation path.
|
||||
DerivePrivKey(loc keychain.KeyDescriptor) (*btcec.PrivateKey, error)
|
||||
}
|
||||
|
@ -42,6 +42,10 @@ type NegotiatorConfig struct {
|
||||
// negotiated sessions.
|
||||
DB DB
|
||||
|
||||
// SecretKeyRing allows the client to derive new session private keys
|
||||
// when attempting to negotiate session with a tower.
|
||||
SecretKeyRing SecretKeyRing
|
||||
|
||||
// Candidates is an abstract set of tower candidates that the negotiator
|
||||
// will traverse serially when attempting to negotiate a new session.
|
||||
Candidates TowerCandidateIterator
|
||||
@ -224,7 +228,7 @@ func (n *sessionNegotiator) negotiate() {
|
||||
|
||||
// On the first pass, initialize the backoff to our configured min
|
||||
// backoff.
|
||||
backoff := n.cfg.MinBackoff
|
||||
var backoff time.Duration
|
||||
|
||||
retryWithBackoff:
|
||||
// If we are retrying, wait out the delay before continuing.
|
||||
@ -240,13 +244,24 @@ retryWithBackoff:
|
||||
// iterator to ensure the results are fresh.
|
||||
n.cfg.Candidates.Reset()
|
||||
for {
|
||||
select {
|
||||
case <-n.quit:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Pull the next candidate from our list of addresses.
|
||||
tower, err := n.cfg.Candidates.Next()
|
||||
if err != nil {
|
||||
// We've run out of addresses, double and clamp backoff.
|
||||
backoff *= 2
|
||||
if backoff > n.cfg.MaxBackoff {
|
||||
backoff = n.cfg.MaxBackoff
|
||||
if backoff == 0 {
|
||||
backoff = n.cfg.MinBackoff
|
||||
} else {
|
||||
// We've run out of addresses, double and clamp
|
||||
// backoff.
|
||||
backoff *= 2
|
||||
if backoff > n.cfg.MaxBackoff {
|
||||
backoff = n.cfg.MaxBackoff
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("Unable to get new tower candidate, "+
|
||||
@ -255,12 +270,23 @@ retryWithBackoff:
|
||||
goto retryWithBackoff
|
||||
}
|
||||
|
||||
towerPub := tower.IdentityKey.SerializeCompressed()
|
||||
log.Debugf("Attempting session negotiation with tower=%x",
|
||||
tower.IdentityKey.SerializeCompressed())
|
||||
towerPub)
|
||||
|
||||
// Before proceeding, we will reserve a session key index to use
|
||||
// with this specific tower. If one is already reserved, the
|
||||
// existing index will be returned.
|
||||
keyIndex, err := n.cfg.DB.NextSessionKeyIndex(tower.ID)
|
||||
if err != nil {
|
||||
log.Debugf("Unable to reserve session key index "+
|
||||
"for tower=%x: %v", towerPub, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// We'll now attempt the CreateSession dance with the tower to
|
||||
// get a new session, trying all addresses if necessary.
|
||||
err = n.createSession(tower)
|
||||
err = n.createSession(tower, keyIndex)
|
||||
if err != nil {
|
||||
log.Debugf("Session negotiation with tower=%x "+
|
||||
"failed, trying again -- reason: %v",
|
||||
@ -277,22 +303,21 @@ retryWithBackoff:
|
||||
// its stored addresses. This method returns after the first successful
|
||||
// negotiation, or after all addresses have failed with ErrFailedNegotiation. If
|
||||
// the tower has no addresses, ErrNoTowerAddrs is returned.
|
||||
func (n *sessionNegotiator) createSession(tower *wtdb.Tower) error {
|
||||
func (n *sessionNegotiator) createSession(tower *wtdb.Tower,
|
||||
keyIndex uint32) error {
|
||||
|
||||
// If the tower has no addresses, there's nothing we can do.
|
||||
if len(tower.Addresses) == 0 {
|
||||
return ErrNoTowerAddrs
|
||||
}
|
||||
|
||||
// TODO(conner): create with hdkey at random index
|
||||
sessionPrivKey, err := btcec.NewPrivateKey(btcec.S256())
|
||||
sessionPriv, err := DeriveSessionKey(n.cfg.SecretKeyRing, keyIndex)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO(conner): write towerAddr+privkey
|
||||
|
||||
for _, lnAddr := range tower.LNAddrs() {
|
||||
err = n.tryAddress(sessionPrivKey, tower, lnAddr)
|
||||
err = n.tryAddress(sessionPriv, keyIndex, tower, lnAddr)
|
||||
switch {
|
||||
case err == ErrPermanentTowerFailure:
|
||||
// TODO(conner): report to iterator? can then be reset
|
||||
@ -318,7 +343,7 @@ func (n *sessionNegotiator) createSession(tower *wtdb.Tower) error {
|
||||
// returns true if all steps succeed and the new session has been persisted, and
|
||||
// fails otherwise.
|
||||
func (n *sessionNegotiator) tryAddress(privKey *btcec.PrivateKey,
|
||||
tower *wtdb.Tower, lnAddr *lnwire.NetAddress) error {
|
||||
keyIndex uint32, tower *wtdb.Tower, lnAddr *lnwire.NetAddress) error {
|
||||
|
||||
// Connect to the tower address using our generated session key.
|
||||
conn, err := n.cfg.Dial(privKey, lnAddr)
|
||||
@ -394,7 +419,8 @@ func (n *sessionNegotiator) tryAddress(privKey *btcec.PrivateKey,
|
||||
clientSession := &wtdb.ClientSession{
|
||||
TowerID: tower.ID,
|
||||
Tower: tower,
|
||||
SessionPrivKey: privKey, // remove after using HD keys
|
||||
KeyIndex: keyIndex,
|
||||
SessionPrivKey: privKey,
|
||||
ID: sessionID,
|
||||
Policy: n.cfg.Policy,
|
||||
SeqNum: 0,
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"errors"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
||||
)
|
||||
@ -30,6 +29,15 @@ var (
|
||||
// LastApplied value greater than any allocated sequence number.
|
||||
ErrUnallocatedLastApplied = errors.New("tower echoed last appiled " +
|
||||
"greater than allocated seqnum")
|
||||
|
||||
// ErrNoReservedKeyIndex signals that a client session could not be
|
||||
// created because no session key index was reserved.
|
||||
ErrNoReservedKeyIndex = errors.New("key index not reserved")
|
||||
|
||||
// ErrIncorrectKeyIndex signals that the client session could not be
|
||||
// created because session key index differs from the reserved key
|
||||
// index.
|
||||
ErrIncorrectKeyIndex = errors.New("incorrect key index")
|
||||
)
|
||||
|
||||
// ClientSession encapsulates a SessionInfo returned from a successful
|
||||
@ -57,14 +65,17 @@ type ClientSession struct {
|
||||
// tower with TowerID.
|
||||
Tower *Tower
|
||||
|
||||
// SessionKeyDesc is the key descriptor used to derive the client's
|
||||
// KeyIndex is the index of key locator used to derive the client's
|
||||
// session key so that it can authenticate with the tower to update its
|
||||
// session.
|
||||
SessionKeyDesc keychain.KeyLocator
|
||||
// session. In order to rederive the private key, the key locator should
|
||||
// use the keychain.KeyFamilyTowerSession key family.
|
||||
KeyIndex uint32
|
||||
|
||||
// SessionPrivKey is the ephemeral secret key used to connect to the
|
||||
// watchtower.
|
||||
// TODO(conner): remove after HD keys
|
||||
//
|
||||
// NOTE: This value is not serialized. It is derived using the KeyIndex
|
||||
// on startup to avoid storing private keys on disk.
|
||||
SessionPrivKey *btcec.PrivateKey
|
||||
|
||||
// Policy holds the negotiated session parameters.
|
||||
|
@ -61,7 +61,8 @@ func (db *MockDB) InsertSessionInfo(info *SessionInfo) error {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
if _, ok := db.sessions[info.ID]; ok {
|
||||
dbInfo, ok := db.sessions[info.ID]
|
||||
if ok && dbInfo.LastApplied > 0 {
|
||||
return ErrSessionAlreadyExists
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package wtdb
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
@ -8,6 +9,12 @@ import (
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrTowerNotFound signals that the target tower was not found in the
|
||||
// database.
|
||||
ErrTowerNotFound = errors.New("tower not found")
|
||||
)
|
||||
|
||||
// Tower holds the necessary components required to connect to a remote tower.
|
||||
// Communication is handled by brontide, and requires both a public key and an
|
||||
// address.
|
||||
|
@ -22,6 +22,9 @@ type ClientDB struct {
|
||||
activeSessions map[wtdb.SessionID]*wtdb.ClientSession
|
||||
towerIndex map[towerPK]uint64
|
||||
towers map[uint64]*wtdb.Tower
|
||||
|
||||
nextIndex uint32
|
||||
indexes map[uint64]uint32
|
||||
}
|
||||
|
||||
// NewClientDB initializes a new mock ClientDB.
|
||||
@ -31,6 +34,7 @@ func NewClientDB() *ClientDB {
|
||||
activeSessions: make(map[wtdb.SessionID]*wtdb.ClientSession),
|
||||
towerIndex: make(map[towerPK]uint64),
|
||||
towers: make(map[uint64]*wtdb.Tower),
|
||||
indexes: make(map[uint64]uint32),
|
||||
}
|
||||
}
|
||||
|
||||
@ -64,6 +68,18 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
|
||||
return tower, nil
|
||||
}
|
||||
|
||||
// LoadTower retrieves a tower by its tower ID.
|
||||
func (m *ClientDB) LoadTower(towerID uint64) (*wtdb.Tower, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if tower, ok := m.towers[towerID]; ok {
|
||||
return tower, nil
|
||||
}
|
||||
|
||||
return nil, wtdb.ErrTowerNotFound
|
||||
}
|
||||
|
||||
// MarkBackupIneligible records that particular commit height is ineligible for
|
||||
// backup. This allows the client to track which updates it should not attempt
|
||||
// to retry after startup.
|
||||
@ -90,16 +106,29 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Ensure that a session key index has been reserved for this tower.
|
||||
keyIndex, ok := m.indexes[session.TowerID]
|
||||
if !ok {
|
||||
return wtdb.ErrNoReservedKeyIndex
|
||||
}
|
||||
|
||||
// Ensure that the session's index matches the reserved index.
|
||||
if keyIndex != session.KeyIndex {
|
||||
return wtdb.ErrIncorrectKeyIndex
|
||||
}
|
||||
|
||||
// Remove the key index reservation for this tower. Once committed, this
|
||||
// permits us to create another session with this tower.
|
||||
delete(m.indexes, session.TowerID)
|
||||
|
||||
m.activeSessions[session.ID] = &wtdb.ClientSession{
|
||||
TowerID: session.TowerID,
|
||||
Tower: session.Tower,
|
||||
SessionKeyDesc: session.SessionKeyDesc,
|
||||
SessionPrivKey: session.SessionPrivKey,
|
||||
KeyIndex: session.KeyIndex,
|
||||
ID: session.ID,
|
||||
Policy: session.Policy,
|
||||
SeqNum: session.SeqNum,
|
||||
TowerLastApplied: session.TowerLastApplied,
|
||||
RewardPkScript: session.RewardPkScript,
|
||||
RewardPkScript: cloneBytes(session.RewardPkScript),
|
||||
CommittedUpdates: make(map[uint16]*wtdb.CommittedUpdate),
|
||||
AckedUpdates: make(map[uint16]wtdb.BackupID),
|
||||
}
|
||||
@ -107,6 +136,27 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// NextSessionKeyIndex reserves a new session key derivation index for a
|
||||
// particular tower id. The index is reserved for that tower until
|
||||
// CreateClientSession is invoked for that tower and index, at which point a new
|
||||
// index for that tower can be reserved. Multiple calls to this method before
|
||||
// CreateClientSession is invoked should return the same index.
|
||||
func (m *ClientDB) NextSessionKeyIndex(towerID uint64) (uint32, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if index, ok := m.indexes[towerID]; ok {
|
||||
return index, nil
|
||||
}
|
||||
|
||||
index := m.nextIndex
|
||||
m.indexes[towerID] = index
|
||||
|
||||
m.nextIndex++
|
||||
|
||||
return index, nil
|
||||
}
|
||||
|
||||
// CommitUpdate persists the CommittedUpdate provided in the slot for (session,
|
||||
// seqNum). This allows the client to retransmit this update on startup.
|
||||
func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, seqNum uint16,
|
||||
@ -217,7 +267,12 @@ func (m *ClientDB) AddChanPkScript(chanID lnwire.ChannelID, pkScript []byte) err
|
||||
}
|
||||
|
||||
func cloneBytes(b []byte) []byte {
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bb := make([]byte, len(b))
|
||||
copy(bb, b)
|
||||
|
||||
return bb
|
||||
}
|
||||
|
44
watchtower/wtmock/keyring.go
Normal file
44
watchtower/wtmock/keyring.go
Normal file
@ -0,0 +1,44 @@
|
||||
package wtmock
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
)
|
||||
|
||||
// SecretKeyRing is a mock, in-memory implementation for deriving private keys.
|
||||
type SecretKeyRing struct {
|
||||
mu sync.Mutex
|
||||
keys map[keychain.KeyLocator]*btcec.PrivateKey
|
||||
}
|
||||
|
||||
// NewSecretKeyRing creates a new mock SecretKeyRing.
|
||||
func NewSecretKeyRing() *SecretKeyRing {
|
||||
return &SecretKeyRing{
|
||||
keys: make(map[keychain.KeyLocator]*btcec.PrivateKey),
|
||||
}
|
||||
}
|
||||
|
||||
// DerivePrivKey derives the private key for a given key descriptor. If
|
||||
// this method is called twice with the same argument, it will return the same
|
||||
// private key.
|
||||
func (m *SecretKeyRing) DerivePrivKey(
|
||||
desc keychain.KeyDescriptor) (*btcec.PrivateKey, error) {
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if key, ok := m.keys[desc.KeyLocator]; ok {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
privKey, err := btcec.NewPrivateKey(btcec.S256())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.keys[desc.KeyLocator] = privKey
|
||||
|
||||
return privKey, nil
|
||||
}
|
@ -21,45 +21,26 @@ func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID,
|
||||
existingInfo, err := s.cfg.DB.GetSessionInfo(id)
|
||||
switch {
|
||||
|
||||
// We already have a session, though it is currently unused. We'll allow
|
||||
// the client to recommit the session if it wanted to change the policy.
|
||||
case err == nil && existingInfo.LastApplied == 0:
|
||||
|
||||
// We already have a session corresponding to this session id, return an
|
||||
// error signaling that it already exists in our database. We return the
|
||||
// reward address to the client in case they were not able to process
|
||||
// our reply earlier.
|
||||
case err == nil:
|
||||
case err == nil && existingInfo.LastApplied > 0:
|
||||
log.Debugf("Already have session for %s", id)
|
||||
return s.replyCreateSession(
|
||||
peer, id, wtwire.CreateSessionCodeAlreadyExists,
|
||||
existingInfo.RewardAddress,
|
||||
existingInfo.LastApplied, existingInfo.RewardAddress,
|
||||
)
|
||||
|
||||
// Some other database error occurred, return a temporary failure.
|
||||
case err != wtdb.ErrSessionNotFound:
|
||||
log.Errorf("unable to load session info for %s", id)
|
||||
return s.replyCreateSession(
|
||||
peer, id, wtwire.CodeTemporaryFailure, nil,
|
||||
)
|
||||
}
|
||||
|
||||
// Now that we've established that this session does not exist in the
|
||||
// database, retrieve the sweep address that will be given to the
|
||||
// client. This address is to be included by the client when signing
|
||||
// sweep transactions destined for this tower, if its negotiated output
|
||||
// is not dust.
|
||||
rewardAddress, err := s.cfg.NewAddress()
|
||||
if err != nil {
|
||||
log.Errorf("unable to generate reward addr for %s", id)
|
||||
return s.replyCreateSession(
|
||||
peer, id, wtwire.CodeTemporaryFailure, nil,
|
||||
)
|
||||
}
|
||||
|
||||
// Construct the pkscript the client should pay to when signing justice
|
||||
// transactions for this session.
|
||||
rewardScript, err := txscript.PayToAddrScript(rewardAddress)
|
||||
if err != nil {
|
||||
log.Errorf("unable to generate reward script for %s", id)
|
||||
return s.replyCreateSession(
|
||||
peer, id, wtwire.CodeTemporaryFailure, nil,
|
||||
peer, id, wtwire.CodeTemporaryFailure, 0, nil,
|
||||
)
|
||||
}
|
||||
|
||||
@ -68,10 +49,39 @@ func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID,
|
||||
log.Debugf("Rejecting CreateSession from %s, unsupported blob "+
|
||||
"type %s", id, req.BlobType)
|
||||
return s.replyCreateSession(
|
||||
peer, id, wtwire.CreateSessionCodeRejectBlobType, nil,
|
||||
peer, id, wtwire.CreateSessionCodeRejectBlobType, 0,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
// Now that we've established that this session does not exist in the
|
||||
// database, retrieve the sweep address that will be given to the
|
||||
// client. This address is to be included by the client when signing
|
||||
// sweep transactions destined for this tower, if its negotiated output
|
||||
// is not dust.
|
||||
var rewardScript []byte
|
||||
if req.BlobType.Has(blob.FlagReward) {
|
||||
rewardAddress, err := s.cfg.NewAddress()
|
||||
if err != nil {
|
||||
log.Errorf("Unable to generate reward addr for %s: %v",
|
||||
id, err)
|
||||
return s.replyCreateSession(
|
||||
peer, id, wtwire.CodeTemporaryFailure, 0, nil,
|
||||
)
|
||||
}
|
||||
|
||||
// Construct the pkscript the client should pay to when signing
|
||||
// justice transactions for this session.
|
||||
rewardScript, err = txscript.PayToAddrScript(rewardAddress)
|
||||
if err != nil {
|
||||
log.Errorf("Unable to generate reward script for "+
|
||||
"%s: %v", id, err)
|
||||
return s.replyCreateSession(
|
||||
peer, id, wtwire.CodeTemporaryFailure, 0, nil,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(conner): create invoice for upfront payment
|
||||
|
||||
// Assemble the session info using the agreed upon parameters, reward
|
||||
@ -94,14 +104,14 @@ func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID,
|
||||
if err != nil {
|
||||
log.Errorf("unable to create session for %s", id)
|
||||
return s.replyCreateSession(
|
||||
peer, id, wtwire.CodeTemporaryFailure, nil,
|
||||
peer, id, wtwire.CodeTemporaryFailure, 0, nil,
|
||||
)
|
||||
}
|
||||
|
||||
log.Infof("Accepted session for %s", id)
|
||||
|
||||
return s.replyCreateSession(
|
||||
peer, id, wtwire.CodeOK, rewardScript,
|
||||
peer, id, wtwire.CodeOK, 0, rewardScript,
|
||||
)
|
||||
}
|
||||
|
||||
@ -110,11 +120,19 @@ func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID,
|
||||
// Otherwise, this method returns a connection error to ensure we don't continue
|
||||
// communication with the client.
|
||||
func (s *Server) replyCreateSession(peer Peer, id *wtdb.SessionID,
|
||||
code wtwire.ErrorCode, data []byte) error {
|
||||
code wtwire.ErrorCode, lastApplied uint16, data []byte) error {
|
||||
|
||||
if s.cfg.NoAckCreateSession {
|
||||
return &connFailure{
|
||||
ID: *id,
|
||||
Code: code,
|
||||
}
|
||||
}
|
||||
|
||||
msg := &wtwire.CreateSessionReply{
|
||||
Code: code,
|
||||
Data: data,
|
||||
Code: code,
|
||||
LastApplied: lastApplied,
|
||||
Data: data,
|
||||
}
|
||||
|
||||
err := s.sendMessage(peer, msg)
|
||||
@ -131,6 +149,6 @@ func (s *Server) replyCreateSession(peer Peer, id *wtdb.SessionID,
|
||||
// disconnect the client.
|
||||
return &connFailure{
|
||||
ID: *id,
|
||||
Code: uint16(code),
|
||||
Code: code,
|
||||
}
|
||||
}
|
||||
|
@ -52,6 +52,6 @@ func (s *Server) replyDeleteSession(peer Peer, id *wtdb.SessionID,
|
||||
// disconnect the client.
|
||||
return &connFailure{
|
||||
ID: *id,
|
||||
Code: uint16(code),
|
||||
Code: code,
|
||||
}
|
||||
}
|
||||
|
@ -56,6 +56,10 @@ type Config struct {
|
||||
// ChainHash identifies the network that the server is watching.
|
||||
ChainHash chainhash.Hash
|
||||
|
||||
// NoAckCreateSession causes the server to not reply to create session
|
||||
// requests, this should only be used for testing.
|
||||
NoAckCreateSession bool
|
||||
|
||||
// NoAckUpdates causes the server to not acknowledge state updates, this
|
||||
// should only be used for testing.
|
||||
NoAckUpdates bool
|
||||
@ -283,12 +287,12 @@ func (s *Server) handleClient(peer Peer) {
|
||||
// error code.
|
||||
type connFailure struct {
|
||||
ID wtdb.SessionID
|
||||
Code uint16
|
||||
Code wtwire.ErrorCode
|
||||
}
|
||||
|
||||
// Error displays the SessionID and Code that caused the connection failure.
|
||||
func (f *connFailure) Error() string {
|
||||
return fmt.Sprintf("connection with %s failed with code=%v",
|
||||
return fmt.Sprintf("connection with %s failed with code=%s",
|
||||
f.ID, f.Code,
|
||||
)
|
||||
}
|
||||
|
@ -29,6 +29,8 @@ var (
|
||||
addrScript, _ = txscript.PayToAddrScript(addr)
|
||||
|
||||
testnetChainHash = *chaincfg.TestNet3Params.GenesisHash
|
||||
|
||||
rewardType = (blob.FlagCommitOutputs | blob.FlagReward).Type()
|
||||
)
|
||||
|
||||
// randPubKey generates a new secp keypair, and returns the public key.
|
||||
@ -152,16 +154,17 @@ func TestServerOnlyAcceptOnePeer(t *testing.T) {
|
||||
}
|
||||
|
||||
type createSessionTestCase struct {
|
||||
name string
|
||||
initMsg *wtwire.Init
|
||||
createMsg *wtwire.CreateSession
|
||||
expReply *wtwire.CreateSessionReply
|
||||
expDupReply *wtwire.CreateSessionReply
|
||||
name string
|
||||
initMsg *wtwire.Init
|
||||
createMsg *wtwire.CreateSession
|
||||
expReply *wtwire.CreateSessionReply
|
||||
expDupReply *wtwire.CreateSessionReply
|
||||
sendStateUpdate bool
|
||||
}
|
||||
|
||||
var createSessionTests = []createSessionTestCase{
|
||||
{
|
||||
name: "reject duplicate session create",
|
||||
name: "duplicate session create",
|
||||
initMsg: wtwire.NewInitMessage(
|
||||
lnwire.NewRawFeatureVector(),
|
||||
testnetChainHash,
|
||||
@ -173,12 +176,58 @@ var createSessionTests = []createSessionTestCase{
|
||||
RewardRate: 0,
|
||||
SweepFeeRate: 1,
|
||||
},
|
||||
expReply: &wtwire.CreateSessionReply{
|
||||
Code: wtwire.CodeOK,
|
||||
Data: []byte{},
|
||||
},
|
||||
expDupReply: &wtwire.CreateSessionReply{
|
||||
Code: wtwire.CodeOK,
|
||||
Data: []byte{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "duplicate session create after use",
|
||||
initMsg: wtwire.NewInitMessage(
|
||||
lnwire.NewRawFeatureVector(),
|
||||
testnetChainHash,
|
||||
),
|
||||
createMsg: &wtwire.CreateSession{
|
||||
BlobType: blob.TypeDefault,
|
||||
MaxUpdates: 1000,
|
||||
RewardBase: 0,
|
||||
RewardRate: 0,
|
||||
SweepFeeRate: 1,
|
||||
},
|
||||
expReply: &wtwire.CreateSessionReply{
|
||||
Code: wtwire.CodeOK,
|
||||
Data: []byte{},
|
||||
},
|
||||
expDupReply: &wtwire.CreateSessionReply{
|
||||
Code: wtwire.CreateSessionCodeAlreadyExists,
|
||||
LastApplied: 1,
|
||||
Data: []byte{},
|
||||
},
|
||||
sendStateUpdate: true,
|
||||
},
|
||||
{
|
||||
name: "duplicate session create reward",
|
||||
initMsg: wtwire.NewInitMessage(
|
||||
lnwire.NewRawFeatureVector(),
|
||||
testnetChainHash,
|
||||
),
|
||||
createMsg: &wtwire.CreateSession{
|
||||
BlobType: rewardType,
|
||||
MaxUpdates: 1000,
|
||||
RewardBase: 0,
|
||||
RewardRate: 0,
|
||||
SweepFeeRate: 1,
|
||||
},
|
||||
expReply: &wtwire.CreateSessionReply{
|
||||
Code: wtwire.CodeOK,
|
||||
Data: addrScript,
|
||||
},
|
||||
expDupReply: &wtwire.CreateSessionReply{
|
||||
Code: wtwire.CreateSessionCodeAlreadyExists,
|
||||
Code: wtwire.CodeOK,
|
||||
Data: addrScript,
|
||||
},
|
||||
},
|
||||
@ -251,6 +300,18 @@ func testServerCreateSession(t *testing.T, i int, test createSessionTestCase) {
|
||||
return
|
||||
}
|
||||
|
||||
if test.sendStateUpdate {
|
||||
peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
|
||||
connect(t, s, peer, test.initMsg, timeoutDuration)
|
||||
update := &wtwire.StateUpdate{
|
||||
SeqNum: 1,
|
||||
IsComplete: 1,
|
||||
}
|
||||
sendMsg(t, update, peer, timeoutDuration)
|
||||
|
||||
assertConnClosed(t, peer, 2*timeoutDuration)
|
||||
}
|
||||
|
||||
// Simulate a peer with the same session id connection to the server
|
||||
// again.
|
||||
peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
|
||||
@ -705,7 +766,7 @@ func TestServerDeleteSession(t *testing.T) {
|
||||
send: createSession,
|
||||
recv: &wtwire.CreateSessionReply{
|
||||
Code: wtwire.CodeOK,
|
||||
Data: addrScript,
|
||||
Data: []byte{},
|
||||
},
|
||||
assert: func(t *testing.T) {
|
||||
// Both peers should have sessions.
|
||||
|
@ -117,7 +117,7 @@ func (s *Server) handleStateUpdate(peer Peer, id *wtdb.SessionID,
|
||||
if s.cfg.NoAckUpdates {
|
||||
return &connFailure{
|
||||
ID: *id,
|
||||
Code: uint16(failCode),
|
||||
Code: failCode,
|
||||
}
|
||||
}
|
||||
|
||||
@ -152,6 +152,6 @@ func (s *Server) replyStateUpdate(peer Peer, id *wtdb.SessionID,
|
||||
// disconnect the client.
|
||||
return &connFailure{
|
||||
ID: *id,
|
||||
Code: uint16(code),
|
||||
Code: code,
|
||||
}
|
||||
}
|
||||
|
@ -43,6 +43,12 @@ type CreateSessionReply struct {
|
||||
// Code will be non-zero if the watchtower rejected the session init.
|
||||
Code CreateSessionCode
|
||||
|
||||
// LastApplied is the tower's last accepted sequence number for the
|
||||
// session. This is useful when the session already exists but the
|
||||
// client doesn't realize it's already used the session, such as after a
|
||||
// restoration.
|
||||
LastApplied uint16
|
||||
|
||||
// Data is a byte slice returned the caller of the message, and is to be
|
||||
// interpreted according to the error Code. When the response is
|
||||
// CreateSessionCodeOK, data encodes the reward address to be included in
|
||||
@ -63,6 +69,7 @@ var _ Message = (*CreateSessionReply)(nil)
|
||||
func (m *CreateSessionReply) Decode(r io.Reader, pver uint32) error {
|
||||
return ReadElements(r,
|
||||
&m.Code,
|
||||
&m.LastApplied,
|
||||
&m.Data,
|
||||
)
|
||||
}
|
||||
@ -74,6 +81,7 @@ func (m *CreateSessionReply) Decode(r io.Reader, pver uint32) error {
|
||||
func (m *CreateSessionReply) Encode(w io.Writer, pver uint32) error {
|
||||
return WriteElements(w,
|
||||
m.Code,
|
||||
m.LastApplied,
|
||||
m.Data,
|
||||
)
|
||||
}
|
||||
|
@ -12,8 +12,6 @@ const (
|
||||
// client side, or that the tower had already deleted the session in a
|
||||
// prior request that the client may not have received.
|
||||
DeleteSessionCodeNotFound DeleteSessionCode = 80
|
||||
|
||||
// TODO(conner): add String method after wtclient is merged
|
||||
)
|
||||
|
||||
// DeleteSessionReply is a message sent in response to a client's DeleteSession
|
||||
|
@ -46,6 +46,8 @@ func (c ErrorCode) String() string {
|
||||
return "StateUpdateCodeMaxUpdatesExceeded"
|
||||
case StateUpdateCodeSeqNumOutOfOrder:
|
||||
return "StateUpdateCodeSeqNumOutOfOrder"
|
||||
case DeleteSessionCodeNotFound:
|
||||
return "DeleteSessionCodeNotFound"
|
||||
default:
|
||||
return fmt.Sprintf("UnknownErrorCode: %d", c)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user