Merge pull request #2820 from cfromknecht/session-key-derivation

wtclient: session private key derivation
This commit is contained in:
Olaoluwa Osuntokun 2019-04-25 17:59:46 -07:00 committed by GitHub
commit f1df2eadb7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 549 additions and 95 deletions

@ -90,6 +90,12 @@ const (
// a payment, or self stored on disk in a single file containing all
// the static channel backups.
KeyFamilyStaticBackup KeyFamily = 7
// KeyFamilyTowerSession is the family of keys that will be used to
// derive session keys when negotiating sessions with watchtowers. The
// session keys are limited to the lifetime of the session and are used
// to increase privacy in the watchtower protocol.
KeyFamilyTowerSession KeyFamily = 8
)
// KeyLocator is a two-tuple that can be used to derive *any* key that has ever

@ -9,7 +9,6 @@ import (
"github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
@ -77,7 +76,7 @@ type Config struct {
// SecretKeyRing is used to derive the session keys used to communicate
// with the tower. The client only stores the KeyLocators internally so
// that we never store private keys on disk.
SecretKeyRing keychain.SecretKeyRing
SecretKeyRing SecretKeyRing
// Dial connects to an addr using the specified net and returns the
// connection object.
@ -201,15 +200,16 @@ func New(config *Config) (*TowerClient, error) {
forceQuit: make(chan struct{}),
}
c.negotiator = newSessionNegotiator(&NegotiatorConfig{
DB: cfg.DB,
Policy: cfg.Policy,
ChainHash: cfg.ChainHash,
SendMessage: c.sendMessage,
ReadMessage: c.readMessage,
Dial: c.dial,
Candidates: newTowerListIterator(tower),
MinBackoff: cfg.MinBackoff,
MaxBackoff: cfg.MaxBackoff,
DB: cfg.DB,
SecretKeyRing: cfg.SecretKeyRing,
Policy: cfg.Policy,
ChainHash: cfg.ChainHash,
SendMessage: c.sendMessage,
ReadMessage: c.readMessage,
Dial: c.dial,
Candidates: newTowerListIterator(tower),
MinBackoff: cfg.MinBackoff,
MaxBackoff: cfg.MaxBackoff,
})
// Next, load all active sessions from the db into the client. We will
@ -221,6 +221,28 @@ func New(config *Config) (*TowerClient, error) {
return nil, err
}
// Reload any towers from disk using the tower IDs contained in each
// candidate session. We will also rederive any session keys needed to
// be able to communicate with the towers and authenticate session
// requests. This prevents us from having to store the private keys on
// disk.
for _, s := range c.candidateSessions {
tower, err := c.cfg.DB.LoadTower(s.TowerID)
if err != nil {
return nil, err
}
sessionPriv, err := DeriveSessionKey(
c.cfg.SecretKeyRing, s.KeyIndex,
)
if err != nil {
return nil, err
}
s.Tower = tower
s.SessionPrivKey = sessionPriv
}
// Finally, load the sweep pkscripts that have been generated for all
// previously registered channels.
c.sweepPkScripts, err = c.cfg.DB.FetchChanPkScripts()
@ -334,9 +356,6 @@ func (c *TowerClient) ForceQuit() {
c.forced.Do(func() {
log.Infof("Force quitting watchtower client")
// Cancel log message from stop.
close(c.forceQuit)
// 1. Shutdown the backup queue, which will prevent any further
// updates from being accepted. In practice, the links should be
// shutdown before the client has been stopped, so all updates
@ -347,6 +366,7 @@ func (c *TowerClient) ForceQuit() {
// dispatcher to exit. The backup queue will signal it's
// completion to the dispatcher, which releases the wait group
// after all tasks have been assigned to session queues.
close(c.forceQuit)
c.wg.Wait()
// 3. Since all valid tasks have been assigned to session
@ -490,6 +510,9 @@ func (c *TowerClient) backupDispatcher() {
case <-c.statTicker.C:
log.Infof("Client stats: %s", c.stats)
case <-c.forceQuit:
return
}
// No active session queue but have additional sessions.

@ -379,10 +379,11 @@ type testHarness struct {
}
type harnessCfg struct {
localBalance lnwire.MilliSatoshi
remoteBalance lnwire.MilliSatoshi
policy wtpolicy.Policy
noRegisterChan0 bool
localBalance lnwire.MilliSatoshi
remoteBalance lnwire.MilliSatoshi
policy wtpolicy.Policy
noRegisterChan0 bool
noAckCreateSession bool
}
func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
@ -414,6 +415,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
NewAddress: func() (btcutil.Address, error) {
return addr, nil
},
NoAckCreateSession: cfg.noAckCreateSession,
}
server, err := wtserver.New(serverCfg)
@ -430,10 +432,11 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
Dial: func(string, string) (net.Conn, error) {
return nil, nil
},
DB: clientDB,
AuthDial: mockNet.AuthDial,
PrivateTower: towerAddr,
Policy: cfg.policy,
DB: clientDB,
AuthDial: mockNet.AuthDial,
SecretKeyRing: wtmock.NewSecretKeyRing(),
PrivateTower: towerAddr,
Policy: cfg.policy,
NewAddress: func() ([]byte, error) {
return addrScript, nil
},
@ -729,6 +732,36 @@ func (h *testHarness) waitServerUpdates(hints []wtdb.BreachHint,
}
}
// assertUpdatesForPolicy queries the server db for matches using the provided
// breach hints, then asserts that each match has a session with the expected
// policy.
func (h *testHarness) assertUpdatesForPolicy(hints []wtdb.BreachHint,
expPolicy wtpolicy.Policy) {
// Query for matches on the provided hints.
matches, err := h.serverDB.QueryMatches(hints)
if err != nil {
h.t.Fatalf("unable to query for matches: %v", err)
}
// Assert that the number of matches is exactly the number of provided
// hints.
if len(matches) != len(hints) {
h.t.Fatalf("expected: %d matches, got: %d", len(hints),
len(matches))
}
// Assert that all of the matches correspond to a session with the
// expected policy.
for _, match := range matches {
matchPolicy := match.SessionInfo.Policy
if expPolicy != matchPolicy {
h.t.Fatalf("expected session to have policy: %v, "+
"got: %v", expPolicy, matchPolicy)
}
}
}
const (
localBalance = lnwire.MilliSatoshi(100000000)
remoteBalance = lnwire.MilliSatoshi(200000000)
@ -1098,6 +1131,119 @@ var clientTests = []clientTest{
h.waitServerUpdates(hints, 10*time.Second)
},
},
{
name: "create session no ack",
cfg: harnessCfg{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
BlobType: blob.TypeDefault,
MaxUpdates: 5,
SweepFeeRate: 1,
},
noAckCreateSession: true,
},
fn: func(h *testHarness) {
const (
chanID = 0
numUpdates = 3
)
// Generate the retributions that will be backed up.
hints := h.advanceChannelN(chanID, numUpdates)
// Now, queue the retributions for backup.
h.backupStates(chanID, 0, numUpdates, nil)
// Since the client is unable to create a session, the
// server should have no updates.
h.waitServerUpdates(nil, time.Second)
// Force quit the client since it has queued backups.
h.client.ForceQuit()
// Restart the server and allow it to ack session
// creation.
h.server.Stop()
h.serverCfg.NoAckCreateSession = false
h.startServer()
defer h.server.Stop()
// Restart the client with the same policy, which will
// immediately try to overwrite the old session with an
// identical one.
h.startClient()
defer h.client.ForceQuit()
// Now, queue the retributions for backup.
h.backupStates(chanID, 0, numUpdates, nil)
// Wait for all of the updates to be populated in the
// server's database.
h.waitServerUpdates(hints, 5*time.Second)
// Assert that the server has updates for the clients
// most recent policy.
h.assertUpdatesForPolicy(hints, h.clientCfg.Policy)
},
},
{
name: "create session no ack change policy",
cfg: harnessCfg{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
BlobType: blob.TypeDefault,
MaxUpdates: 5,
SweepFeeRate: 1,
},
noAckCreateSession: true,
},
fn: func(h *testHarness) {
const (
chanID = 0
numUpdates = 3
)
// Generate the retributions that will be backed up.
hints := h.advanceChannelN(chanID, numUpdates)
// Now, queue the retributions for backup.
h.backupStates(chanID, 0, numUpdates, nil)
// Since the client is unable to create a session, the
// server should have no updates.
h.waitServerUpdates(nil, time.Second)
// Force quit the client since it has queued backups.
h.client.ForceQuit()
// Restart the server and allow it to ack session
// creation.
h.server.Stop()
h.serverCfg.NoAckCreateSession = false
h.startServer()
defer h.server.Stop()
// Restart the client with a new policy, which will
// immediately try to overwrite the prior session with
// the old policy.
h.clientCfg.Policy.SweepFeeRate = 2
h.startClient()
defer h.client.ForceQuit()
// Now, queue the retributions for backup.
h.backupStates(chanID, 0, numUpdates, nil)
// Wait for all of the updates to be populated in the
// server's database.
h.waitServerUpdates(hints, 5*time.Second)
// Assert that the server has updates for the clients
// most recent policy.
h.assertUpdatesForPolicy(hints, h.clientCfg.Policy)
},
},
}
// TestClient executes the client test suite, asserting the ability to backup

@ -0,0 +1,24 @@
package wtclient
import (
"github.com/btcsuite/btcd/btcec"
"github.com/lightningnetwork/lnd/keychain"
)
// DeriveSessionKey accepts an session key index for an existing session and
// derives the HD private key to be used to authenticate the brontide transport
// and authenticate requests sent to the tower. The key will use the
// keychain.KeyFamilyTowerSession and the provided index, giving a BIP43
// derivation path of:
//
// * m/1017'/coinType'/8/0/index
func DeriveSessionKey(keyRing SecretKeyRing,
index uint32) (*btcec.PrivateKey, error) {
return keyRing.DerivePrivKey(keychain.KeyDescriptor{
KeyLocator: keychain.KeyLocator{
Family: keychain.KeyFamilyTowerSession,
Index: index,
},
})
}

@ -5,6 +5,7 @@ import (
"github.com/btcsuite/btcd/btcec"
"github.com/lightningnetwork/lnd/brontide"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtserver"
@ -19,6 +20,17 @@ type DB interface {
// sessions.
CreateTower(*lnwire.NetAddress) (*wtdb.Tower, error)
// LoadTower retrieves a tower by its tower ID.
LoadTower(uint64) (*wtdb.Tower, error)
// NextSessionKeyIndex reserves a new session key derivation index for a
// particular tower id. The index is reserved for that tower until
// CreateClientSession is invoked for that tower and index, at which
// point a new index for that tower can be reserved. Multiple calls to
// this method before CreateClientSession is invoked should return the
// same index.
NextSessionKeyIndex(uint64) (uint32, error)
// CreateClientSession saves a newly negotiated client session to the
// client's database. This enables the session to be used across
// restarts.
@ -74,3 +86,11 @@ func AuthDial(localPriv *btcec.PrivateKey, netAddr *lnwire.NetAddress,
return brontide.Dial(localPriv, netAddr, dialer)
}
// SecretKeyRing abstracts the ability to derive HD private keys given a
// description of the derivation path.
type SecretKeyRing interface {
// DerivePrivKey derives the private key from the root seed using a
// key descriptor specifying the key's derivation path.
DerivePrivKey(loc keychain.KeyDescriptor) (*btcec.PrivateKey, error)
}

@ -42,6 +42,10 @@ type NegotiatorConfig struct {
// negotiated sessions.
DB DB
// SecretKeyRing allows the client to derive new session private keys
// when attempting to negotiate session with a tower.
SecretKeyRing SecretKeyRing
// Candidates is an abstract set of tower candidates that the negotiator
// will traverse serially when attempting to negotiate a new session.
Candidates TowerCandidateIterator
@ -224,7 +228,7 @@ func (n *sessionNegotiator) negotiate() {
// On the first pass, initialize the backoff to our configured min
// backoff.
backoff := n.cfg.MinBackoff
var backoff time.Duration
retryWithBackoff:
// If we are retrying, wait out the delay before continuing.
@ -240,13 +244,24 @@ retryWithBackoff:
// iterator to ensure the results are fresh.
n.cfg.Candidates.Reset()
for {
select {
case <-n.quit:
return
default:
}
// Pull the next candidate from our list of addresses.
tower, err := n.cfg.Candidates.Next()
if err != nil {
// We've run out of addresses, double and clamp backoff.
backoff *= 2
if backoff > n.cfg.MaxBackoff {
backoff = n.cfg.MaxBackoff
if backoff == 0 {
backoff = n.cfg.MinBackoff
} else {
// We've run out of addresses, double and clamp
// backoff.
backoff *= 2
if backoff > n.cfg.MaxBackoff {
backoff = n.cfg.MaxBackoff
}
}
log.Debugf("Unable to get new tower candidate, "+
@ -255,12 +270,23 @@ retryWithBackoff:
goto retryWithBackoff
}
towerPub := tower.IdentityKey.SerializeCompressed()
log.Debugf("Attempting session negotiation with tower=%x",
tower.IdentityKey.SerializeCompressed())
towerPub)
// Before proceeding, we will reserve a session key index to use
// with this specific tower. If one is already reserved, the
// existing index will be returned.
keyIndex, err := n.cfg.DB.NextSessionKeyIndex(tower.ID)
if err != nil {
log.Debugf("Unable to reserve session key index "+
"for tower=%x: %v", towerPub, err)
continue
}
// We'll now attempt the CreateSession dance with the tower to
// get a new session, trying all addresses if necessary.
err = n.createSession(tower)
err = n.createSession(tower, keyIndex)
if err != nil {
log.Debugf("Session negotiation with tower=%x "+
"failed, trying again -- reason: %v",
@ -277,22 +303,21 @@ retryWithBackoff:
// its stored addresses. This method returns after the first successful
// negotiation, or after all addresses have failed with ErrFailedNegotiation. If
// the tower has no addresses, ErrNoTowerAddrs is returned.
func (n *sessionNegotiator) createSession(tower *wtdb.Tower) error {
func (n *sessionNegotiator) createSession(tower *wtdb.Tower,
keyIndex uint32) error {
// If the tower has no addresses, there's nothing we can do.
if len(tower.Addresses) == 0 {
return ErrNoTowerAddrs
}
// TODO(conner): create with hdkey at random index
sessionPrivKey, err := btcec.NewPrivateKey(btcec.S256())
sessionPriv, err := DeriveSessionKey(n.cfg.SecretKeyRing, keyIndex)
if err != nil {
return err
}
// TODO(conner): write towerAddr+privkey
for _, lnAddr := range tower.LNAddrs() {
err = n.tryAddress(sessionPrivKey, tower, lnAddr)
err = n.tryAddress(sessionPriv, keyIndex, tower, lnAddr)
switch {
case err == ErrPermanentTowerFailure:
// TODO(conner): report to iterator? can then be reset
@ -318,7 +343,7 @@ func (n *sessionNegotiator) createSession(tower *wtdb.Tower) error {
// returns true if all steps succeed and the new session has been persisted, and
// fails otherwise.
func (n *sessionNegotiator) tryAddress(privKey *btcec.PrivateKey,
tower *wtdb.Tower, lnAddr *lnwire.NetAddress) error {
keyIndex uint32, tower *wtdb.Tower, lnAddr *lnwire.NetAddress) error {
// Connect to the tower address using our generated session key.
conn, err := n.cfg.Dial(privKey, lnAddr)
@ -394,7 +419,8 @@ func (n *sessionNegotiator) tryAddress(privKey *btcec.PrivateKey,
clientSession := &wtdb.ClientSession{
TowerID: tower.ID,
Tower: tower,
SessionPrivKey: privKey, // remove after using HD keys
KeyIndex: keyIndex,
SessionPrivKey: privKey,
ID: sessionID,
Policy: n.cfg.Policy,
SeqNum: 0,

@ -4,7 +4,6 @@ import (
"errors"
"github.com/btcsuite/btcd/btcec"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
)
@ -30,6 +29,15 @@ var (
// LastApplied value greater than any allocated sequence number.
ErrUnallocatedLastApplied = errors.New("tower echoed last appiled " +
"greater than allocated seqnum")
// ErrNoReservedKeyIndex signals that a client session could not be
// created because no session key index was reserved.
ErrNoReservedKeyIndex = errors.New("key index not reserved")
// ErrIncorrectKeyIndex signals that the client session could not be
// created because session key index differs from the reserved key
// index.
ErrIncorrectKeyIndex = errors.New("incorrect key index")
)
// ClientSession encapsulates a SessionInfo returned from a successful
@ -57,14 +65,17 @@ type ClientSession struct {
// tower with TowerID.
Tower *Tower
// SessionKeyDesc is the key descriptor used to derive the client's
// KeyIndex is the index of key locator used to derive the client's
// session key so that it can authenticate with the tower to update its
// session.
SessionKeyDesc keychain.KeyLocator
// session. In order to rederive the private key, the key locator should
// use the keychain.KeyFamilyTowerSession key family.
KeyIndex uint32
// SessionPrivKey is the ephemeral secret key used to connect to the
// watchtower.
// TODO(conner): remove after HD keys
//
// NOTE: This value is not serialized. It is derived using the KeyIndex
// on startup to avoid storing private keys on disk.
SessionPrivKey *btcec.PrivateKey
// Policy holds the negotiated session parameters.

@ -61,7 +61,8 @@ func (db *MockDB) InsertSessionInfo(info *SessionInfo) error {
db.mu.Lock()
defer db.mu.Unlock()
if _, ok := db.sessions[info.ID]; ok {
dbInfo, ok := db.sessions[info.ID]
if ok && dbInfo.LastApplied > 0 {
return ErrSessionAlreadyExists
}

@ -1,6 +1,7 @@
package wtdb
import (
"errors"
"net"
"sync"
@ -8,6 +9,12 @@ import (
"github.com/lightningnetwork/lnd/lnwire"
)
var (
// ErrTowerNotFound signals that the target tower was not found in the
// database.
ErrTowerNotFound = errors.New("tower not found")
)
// Tower holds the necessary components required to connect to a remote tower.
// Communication is handled by brontide, and requires both a public key and an
// address.

@ -22,6 +22,9 @@ type ClientDB struct {
activeSessions map[wtdb.SessionID]*wtdb.ClientSession
towerIndex map[towerPK]uint64
towers map[uint64]*wtdb.Tower
nextIndex uint32
indexes map[uint64]uint32
}
// NewClientDB initializes a new mock ClientDB.
@ -31,6 +34,7 @@ func NewClientDB() *ClientDB {
activeSessions: make(map[wtdb.SessionID]*wtdb.ClientSession),
towerIndex: make(map[towerPK]uint64),
towers: make(map[uint64]*wtdb.Tower),
indexes: make(map[uint64]uint32),
}
}
@ -64,6 +68,18 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
return tower, nil
}
// LoadTower retrieves a tower by its tower ID.
func (m *ClientDB) LoadTower(towerID uint64) (*wtdb.Tower, error) {
m.mu.Lock()
defer m.mu.Unlock()
if tower, ok := m.towers[towerID]; ok {
return tower, nil
}
return nil, wtdb.ErrTowerNotFound
}
// MarkBackupIneligible records that particular commit height is ineligible for
// backup. This allows the client to track which updates it should not attempt
// to retry after startup.
@ -90,16 +106,29 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
m.mu.Lock()
defer m.mu.Unlock()
// Ensure that a session key index has been reserved for this tower.
keyIndex, ok := m.indexes[session.TowerID]
if !ok {
return wtdb.ErrNoReservedKeyIndex
}
// Ensure that the session's index matches the reserved index.
if keyIndex != session.KeyIndex {
return wtdb.ErrIncorrectKeyIndex
}
// Remove the key index reservation for this tower. Once committed, this
// permits us to create another session with this tower.
delete(m.indexes, session.TowerID)
m.activeSessions[session.ID] = &wtdb.ClientSession{
TowerID: session.TowerID,
Tower: session.Tower,
SessionKeyDesc: session.SessionKeyDesc,
SessionPrivKey: session.SessionPrivKey,
KeyIndex: session.KeyIndex,
ID: session.ID,
Policy: session.Policy,
SeqNum: session.SeqNum,
TowerLastApplied: session.TowerLastApplied,
RewardPkScript: session.RewardPkScript,
RewardPkScript: cloneBytes(session.RewardPkScript),
CommittedUpdates: make(map[uint16]*wtdb.CommittedUpdate),
AckedUpdates: make(map[uint16]wtdb.BackupID),
}
@ -107,6 +136,27 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
return nil
}
// NextSessionKeyIndex reserves a new session key derivation index for a
// particular tower id. The index is reserved for that tower until
// CreateClientSession is invoked for that tower and index, at which point a new
// index for that tower can be reserved. Multiple calls to this method before
// CreateClientSession is invoked should return the same index.
func (m *ClientDB) NextSessionKeyIndex(towerID uint64) (uint32, error) {
m.mu.Lock()
defer m.mu.Unlock()
if index, ok := m.indexes[towerID]; ok {
return index, nil
}
index := m.nextIndex
m.indexes[towerID] = index
m.nextIndex++
return index, nil
}
// CommitUpdate persists the CommittedUpdate provided in the slot for (session,
// seqNum). This allows the client to retransmit this update on startup.
func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, seqNum uint16,
@ -217,7 +267,12 @@ func (m *ClientDB) AddChanPkScript(chanID lnwire.ChannelID, pkScript []byte) err
}
func cloneBytes(b []byte) []byte {
if b == nil {
return nil
}
bb := make([]byte, len(b))
copy(bb, b)
return bb
}

@ -0,0 +1,44 @@
package wtmock
import (
"sync"
"github.com/btcsuite/btcd/btcec"
"github.com/lightningnetwork/lnd/keychain"
)
// SecretKeyRing is a mock, in-memory implementation for deriving private keys.
type SecretKeyRing struct {
mu sync.Mutex
keys map[keychain.KeyLocator]*btcec.PrivateKey
}
// NewSecretKeyRing creates a new mock SecretKeyRing.
func NewSecretKeyRing() *SecretKeyRing {
return &SecretKeyRing{
keys: make(map[keychain.KeyLocator]*btcec.PrivateKey),
}
}
// DerivePrivKey derives the private key for a given key descriptor. If
// this method is called twice with the same argument, it will return the same
// private key.
func (m *SecretKeyRing) DerivePrivKey(
desc keychain.KeyDescriptor) (*btcec.PrivateKey, error) {
m.mu.Lock()
defer m.mu.Unlock()
if key, ok := m.keys[desc.KeyLocator]; ok {
return key, nil
}
privKey, err := btcec.NewPrivateKey(btcec.S256())
if err != nil {
return nil, err
}
m.keys[desc.KeyLocator] = privKey
return privKey, nil
}

@ -21,45 +21,26 @@ func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID,
existingInfo, err := s.cfg.DB.GetSessionInfo(id)
switch {
// We already have a session, though it is currently unused. We'll allow
// the client to recommit the session if it wanted to change the policy.
case err == nil && existingInfo.LastApplied == 0:
// We already have a session corresponding to this session id, return an
// error signaling that it already exists in our database. We return the
// reward address to the client in case they were not able to process
// our reply earlier.
case err == nil:
case err == nil && existingInfo.LastApplied > 0:
log.Debugf("Already have session for %s", id)
return s.replyCreateSession(
peer, id, wtwire.CreateSessionCodeAlreadyExists,
existingInfo.RewardAddress,
existingInfo.LastApplied, existingInfo.RewardAddress,
)
// Some other database error occurred, return a temporary failure.
case err != wtdb.ErrSessionNotFound:
log.Errorf("unable to load session info for %s", id)
return s.replyCreateSession(
peer, id, wtwire.CodeTemporaryFailure, nil,
)
}
// Now that we've established that this session does not exist in the
// database, retrieve the sweep address that will be given to the
// client. This address is to be included by the client when signing
// sweep transactions destined for this tower, if its negotiated output
// is not dust.
rewardAddress, err := s.cfg.NewAddress()
if err != nil {
log.Errorf("unable to generate reward addr for %s", id)
return s.replyCreateSession(
peer, id, wtwire.CodeTemporaryFailure, nil,
)
}
// Construct the pkscript the client should pay to when signing justice
// transactions for this session.
rewardScript, err := txscript.PayToAddrScript(rewardAddress)
if err != nil {
log.Errorf("unable to generate reward script for %s", id)
return s.replyCreateSession(
peer, id, wtwire.CodeTemporaryFailure, nil,
peer, id, wtwire.CodeTemporaryFailure, 0, nil,
)
}
@ -68,10 +49,39 @@ func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID,
log.Debugf("Rejecting CreateSession from %s, unsupported blob "+
"type %s", id, req.BlobType)
return s.replyCreateSession(
peer, id, wtwire.CreateSessionCodeRejectBlobType, nil,
peer, id, wtwire.CreateSessionCodeRejectBlobType, 0,
nil,
)
}
// Now that we've established that this session does not exist in the
// database, retrieve the sweep address that will be given to the
// client. This address is to be included by the client when signing
// sweep transactions destined for this tower, if its negotiated output
// is not dust.
var rewardScript []byte
if req.BlobType.Has(blob.FlagReward) {
rewardAddress, err := s.cfg.NewAddress()
if err != nil {
log.Errorf("Unable to generate reward addr for %s: %v",
id, err)
return s.replyCreateSession(
peer, id, wtwire.CodeTemporaryFailure, 0, nil,
)
}
// Construct the pkscript the client should pay to when signing
// justice transactions for this session.
rewardScript, err = txscript.PayToAddrScript(rewardAddress)
if err != nil {
log.Errorf("Unable to generate reward script for "+
"%s: %v", id, err)
return s.replyCreateSession(
peer, id, wtwire.CodeTemporaryFailure, 0, nil,
)
}
}
// TODO(conner): create invoice for upfront payment
// Assemble the session info using the agreed upon parameters, reward
@ -94,14 +104,14 @@ func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID,
if err != nil {
log.Errorf("unable to create session for %s", id)
return s.replyCreateSession(
peer, id, wtwire.CodeTemporaryFailure, nil,
peer, id, wtwire.CodeTemporaryFailure, 0, nil,
)
}
log.Infof("Accepted session for %s", id)
return s.replyCreateSession(
peer, id, wtwire.CodeOK, rewardScript,
peer, id, wtwire.CodeOK, 0, rewardScript,
)
}
@ -110,11 +120,19 @@ func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID,
// Otherwise, this method returns a connection error to ensure we don't continue
// communication with the client.
func (s *Server) replyCreateSession(peer Peer, id *wtdb.SessionID,
code wtwire.ErrorCode, data []byte) error {
code wtwire.ErrorCode, lastApplied uint16, data []byte) error {
if s.cfg.NoAckCreateSession {
return &connFailure{
ID: *id,
Code: code,
}
}
msg := &wtwire.CreateSessionReply{
Code: code,
Data: data,
Code: code,
LastApplied: lastApplied,
Data: data,
}
err := s.sendMessage(peer, msg)
@ -131,6 +149,6 @@ func (s *Server) replyCreateSession(peer Peer, id *wtdb.SessionID,
// disconnect the client.
return &connFailure{
ID: *id,
Code: uint16(code),
Code: code,
}
}

@ -52,6 +52,6 @@ func (s *Server) replyDeleteSession(peer Peer, id *wtdb.SessionID,
// disconnect the client.
return &connFailure{
ID: *id,
Code: uint16(code),
Code: code,
}
}

@ -56,6 +56,10 @@ type Config struct {
// ChainHash identifies the network that the server is watching.
ChainHash chainhash.Hash
// NoAckCreateSession causes the server to not reply to create session
// requests, this should only be used for testing.
NoAckCreateSession bool
// NoAckUpdates causes the server to not acknowledge state updates, this
// should only be used for testing.
NoAckUpdates bool
@ -283,12 +287,12 @@ func (s *Server) handleClient(peer Peer) {
// error code.
type connFailure struct {
ID wtdb.SessionID
Code uint16
Code wtwire.ErrorCode
}
// Error displays the SessionID and Code that caused the connection failure.
func (f *connFailure) Error() string {
return fmt.Sprintf("connection with %s failed with code=%v",
return fmt.Sprintf("connection with %s failed with code=%s",
f.ID, f.Code,
)
}

@ -29,6 +29,8 @@ var (
addrScript, _ = txscript.PayToAddrScript(addr)
testnetChainHash = *chaincfg.TestNet3Params.GenesisHash
rewardType = (blob.FlagCommitOutputs | blob.FlagReward).Type()
)
// randPubKey generates a new secp keypair, and returns the public key.
@ -152,16 +154,17 @@ func TestServerOnlyAcceptOnePeer(t *testing.T) {
}
type createSessionTestCase struct {
name string
initMsg *wtwire.Init
createMsg *wtwire.CreateSession
expReply *wtwire.CreateSessionReply
expDupReply *wtwire.CreateSessionReply
name string
initMsg *wtwire.Init
createMsg *wtwire.CreateSession
expReply *wtwire.CreateSessionReply
expDupReply *wtwire.CreateSessionReply
sendStateUpdate bool
}
var createSessionTests = []createSessionTestCase{
{
name: "reject duplicate session create",
name: "duplicate session create",
initMsg: wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(),
testnetChainHash,
@ -173,12 +176,58 @@ var createSessionTests = []createSessionTestCase{
RewardRate: 0,
SweepFeeRate: 1,
},
expReply: &wtwire.CreateSessionReply{
Code: wtwire.CodeOK,
Data: []byte{},
},
expDupReply: &wtwire.CreateSessionReply{
Code: wtwire.CodeOK,
Data: []byte{},
},
},
{
name: "duplicate session create after use",
initMsg: wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(),
testnetChainHash,
),
createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault,
MaxUpdates: 1000,
RewardBase: 0,
RewardRate: 0,
SweepFeeRate: 1,
},
expReply: &wtwire.CreateSessionReply{
Code: wtwire.CodeOK,
Data: []byte{},
},
expDupReply: &wtwire.CreateSessionReply{
Code: wtwire.CreateSessionCodeAlreadyExists,
LastApplied: 1,
Data: []byte{},
},
sendStateUpdate: true,
},
{
name: "duplicate session create reward",
initMsg: wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(),
testnetChainHash,
),
createMsg: &wtwire.CreateSession{
BlobType: rewardType,
MaxUpdates: 1000,
RewardBase: 0,
RewardRate: 0,
SweepFeeRate: 1,
},
expReply: &wtwire.CreateSessionReply{
Code: wtwire.CodeOK,
Data: addrScript,
},
expDupReply: &wtwire.CreateSessionReply{
Code: wtwire.CreateSessionCodeAlreadyExists,
Code: wtwire.CodeOK,
Data: addrScript,
},
},
@ -251,6 +300,18 @@ func testServerCreateSession(t *testing.T, i int, test createSessionTestCase) {
return
}
if test.sendStateUpdate {
peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
connect(t, s, peer, test.initMsg, timeoutDuration)
update := &wtwire.StateUpdate{
SeqNum: 1,
IsComplete: 1,
}
sendMsg(t, update, peer, timeoutDuration)
assertConnClosed(t, peer, 2*timeoutDuration)
}
// Simulate a peer with the same session id connection to the server
// again.
peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
@ -705,7 +766,7 @@ func TestServerDeleteSession(t *testing.T) {
send: createSession,
recv: &wtwire.CreateSessionReply{
Code: wtwire.CodeOK,
Data: addrScript,
Data: []byte{},
},
assert: func(t *testing.T) {
// Both peers should have sessions.

@ -117,7 +117,7 @@ func (s *Server) handleStateUpdate(peer Peer, id *wtdb.SessionID,
if s.cfg.NoAckUpdates {
return &connFailure{
ID: *id,
Code: uint16(failCode),
Code: failCode,
}
}
@ -152,6 +152,6 @@ func (s *Server) replyStateUpdate(peer Peer, id *wtdb.SessionID,
// disconnect the client.
return &connFailure{
ID: *id,
Code: uint16(code),
Code: code,
}
}

@ -43,6 +43,12 @@ type CreateSessionReply struct {
// Code will be non-zero if the watchtower rejected the session init.
Code CreateSessionCode
// LastApplied is the tower's last accepted sequence number for the
// session. This is useful when the session already exists but the
// client doesn't realize it's already used the session, such as after a
// restoration.
LastApplied uint16
// Data is a byte slice returned the caller of the message, and is to be
// interpreted according to the error Code. When the response is
// CreateSessionCodeOK, data encodes the reward address to be included in
@ -63,6 +69,7 @@ var _ Message = (*CreateSessionReply)(nil)
func (m *CreateSessionReply) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&m.Code,
&m.LastApplied,
&m.Data,
)
}
@ -74,6 +81,7 @@ func (m *CreateSessionReply) Decode(r io.Reader, pver uint32) error {
func (m *CreateSessionReply) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
m.Code,
m.LastApplied,
m.Data,
)
}

@ -12,8 +12,6 @@ const (
// client side, or that the tower had already deleted the session in a
// prior request that the client may not have received.
DeleteSessionCodeNotFound DeleteSessionCode = 80
// TODO(conner): add String method after wtclient is merged
)
// DeleteSessionReply is a message sent in response to a client's DeleteSession

@ -46,6 +46,8 @@ func (c ErrorCode) String() string {
return "StateUpdateCodeMaxUpdatesExceeded"
case StateUpdateCodeSeqNumOutOfOrder:
return "StateUpdateCodeSeqNumOutOfOrder"
case DeleteSessionCodeNotFound:
return "DeleteSessionCodeNotFound"
default:
return fmt.Sprintf("UnknownErrorCode: %d", c)
}