diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index 395e9576..2c766d00 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -7,6 +7,7 @@ import ( "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tor" + "github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/lightningnetwork/lnd/watchtower/wtserver" ) @@ -44,12 +45,12 @@ type DB interface { ListTowers() ([]*wtdb.Tower, error) // NextSessionKeyIndex reserves a new session key derivation index for a - // particular tower id. The index is reserved for that tower until - // CreateClientSession is invoked for that tower and index, at which - // point a new index for that tower can be reserved. Multiple calls to - // this method before CreateClientSession is invoked should return the - // same index. - NextSessionKeyIndex(wtdb.TowerID) (uint32, error) + // particular tower id and blob type. The index is reserved for that + // (tower, blob type) pair until CreateClientSession is invoked for that + // tower and index, at which point a new index for that tower can be + // reserved. Multiple calls to this method before CreateClientSession is + // invoked should return the same index. + NextSessionKeyIndex(wtdb.TowerID, blob.Type) (uint32, error) // CreateClientSession saves a newly negotiated client session to the // client's database. This enables the session to be used across diff --git a/watchtower/wtclient/session_negotiator.go b/watchtower/wtclient/session_negotiator.go index 8ab4521f..fe85edb0 100644 --- a/watchtower/wtclient/session_negotiator.go +++ b/watchtower/wtclient/session_negotiator.go @@ -298,7 +298,9 @@ retryWithBackoff: // Before proceeding, we will reserve a session key index to use // with this specific tower. If one is already reserved, the // existing index will be returned. - keyIndex, err := n.cfg.DB.NextSessionKeyIndex(tower.ID) + keyIndex, err := n.cfg.DB.NextSessionKeyIndex( + tower.ID, n.cfg.Policy.BlobType, + ) if err != nil { log.Debugf("Unable to reserve session key index "+ "for tower=%x: %v", towerPub, err) diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 60a4599a..bf150b10 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -10,6 +10,7 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/lightningnetwork/lnd/channeldb/kvdb" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/watchtower/blob" ) const ( @@ -479,7 +480,9 @@ func (c *ClientDB) ListTowers() ([]*Tower, error) { // CreateClientSession is invoked for that tower and index, at which point a new // index for that tower can be reserved. Multiple calls to this method before // CreateClientSession is invoked should return the same index. -func (c *ClientDB) NextSessionKeyIndex(towerID TowerID) (uint32, error) { +func (c *ClientDB) NextSessionKeyIndex(towerID TowerID, + blobType blob.Type) (uint32, error) { + var index uint32 err := kvdb.Update(c.db, func(tx kvdb.RwTx) error { keyIndex := tx.ReadWriteBucket(cSessionKeyIndexBkt) @@ -490,10 +493,9 @@ func (c *ClientDB) NextSessionKeyIndex(towerID TowerID) (uint32, error) { // Check the session key index to see if a key has already been // reserved for this tower. If so, we'll deserialize and return // the index directly. - towerIDBytes := towerID.Bytes() - indexBytes := keyIndex.Get(towerIDBytes) - if len(indexBytes) == 4 { - index = byteOrder.Uint32(indexBytes) + var err error + index, err = getSessionKeyIndex(keyIndex, towerID, blobType) + if err == nil { return nil } @@ -511,13 +513,16 @@ func (c *ClientDB) NextSessionKeyIndex(towerID TowerID) (uint32, error) { return fmt.Errorf("exhausted session key indexes") } + // Create the key that will used to be store the reserved index. + keyBytes := createSessionKeyIndexKey(towerID, blobType) + index = uint32(index64) var indexBuf [4]byte byteOrder.PutUint32(indexBuf[:], index) // Record the reserved session key index under this tower's id. - return keyIndex.Put(towerIDBytes, indexBuf[:]) + return keyIndex.Put(keyBytes, indexBuf[:]) }, func() { index = 0 }) @@ -549,25 +554,34 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error { return ErrClientSessionAlreadyExists } + towerID := session.TowerID + blobType := session.Policy.BlobType + // Check that this tower has a reserved key index. - towerIDBytes := session.TowerID.Bytes() - keyIndexBytes := keyIndexes.Get(towerIDBytes) - if len(keyIndexBytes) != 4 { - return ErrNoReservedKeyIndex + index, err := getSessionKeyIndex(keyIndexes, towerID, blobType) + if err != nil { + return err } // Assert that the key index of the inserted session matches the // reserved session key index. - index := byteOrder.Uint32(keyIndexBytes) if index != session.KeyIndex { return ErrIncorrectKeyIndex } - // Remove the key index reservation. - err := keyIndexes.Delete(towerIDBytes) + // Remove the key index reservation. For altruist commit + // sessions, we'll also purge under the old legacy key format. + key := createSessionKeyIndexKey(towerID, blobType) + err = keyIndexes.Delete(key) if err != nil { return err } + if blobType == blob.TypeAltruistCommit { + err = keyIndexes.Delete(towerID.Bytes()) + if err != nil { + return err + } + } // Finally, write the client session's body in the sessions // bucket. @@ -575,6 +589,50 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error { }, func() {}) } +// createSessionKeyIndexKey returns the indentifier used in the +// session-key-index index, created as tower-id||blob-type. +// +// NOTE: The original serialization only used tower-id, which prevents +// concurrent client types from reserving sessions with the same tower. +func createSessionKeyIndexKey(towerID TowerID, blobType blob.Type) []byte { + towerIDBytes := towerID.Bytes() + + // Session key indexes are stored under as tower-id||blob-type. + var keyBytes [6]byte + copy(keyBytes[:4], towerIDBytes) + byteOrder.PutUint16(keyBytes[4:], uint16(blobType)) + + return keyBytes[:] +} + +// getSessionKeyIndex is a helper method +func getSessionKeyIndex(keyIndexes kvdb.RwBucket, towerID TowerID, + blobType blob.Type) (uint32, error) { + + // Session key indexes are store under as tower-id||blob-type. The + // original serialization only used tower-id, which prevents concurrent + // client types from reserving sessions with the same tower. + keyBytes := createSessionKeyIndexKey(towerID, blobType) + + // Retrieve the index using the key bytes. If the key wasn't found, we + // will fall back to the legacy format that only uses the tower id, but + // _only_ if the blob type is for altruist commit sessions since that + // was the only operational session type prior to changing the key + // format. + keyIndexBytes := keyIndexes.Get(keyBytes) + if keyIndexBytes == nil && blobType == blob.TypeAltruistCommit { + keyIndexBytes = keyIndexes.Get(towerID.Bytes()) + } + + // All session key indexes should be serialized uint32's. If no key + // index was found, the length of keyIndexBytes will be 0. + if len(keyIndexBytes) != 4 { + return 0, ErrNoReservedKeyIndex + } + + return byteOrder.Uint32(keyIndexBytes), nil +} + // ListClientSessions returns the set of all client sessions known to the db. An // optional tower ID can be used to filter out any client sessions in the // response that do not correspond to this tower. diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index 92d1c5a7..38f6bd98 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -60,10 +60,12 @@ func (h *clientDBHarness) listSessions(id *wtdb.TowerID) map[wtdb.SessionID]*wtd return sessions } -func (h *clientDBHarness) nextKeyIndex(id wtdb.TowerID, expErr error) uint32 { +func (h *clientDBHarness) nextKeyIndex(id wtdb.TowerID, blobType blob.Type, + expErr error) uint32 { + h.t.Helper() - index, err := h.db.NextSessionKeyIndex(id) + index, err := h.db.NextSessionKeyIndex(id, blobType) if err != expErr { h.t.Fatalf("expected next session key index error: %v, got: %v", expErr, err) @@ -227,11 +229,16 @@ func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16, // - client sessions cannot be created with an incorrect session key index . // - inserting duplicate sessions fails. func testCreateClientSession(h *clientDBHarness) { + const blobType = blob.TypeAltruistCommit + // Create a test client session to insert. session := &wtdb.ClientSession{ ClientSessionBody: wtdb.ClientSessionBody{ TowerID: wtdb.TowerID(3), Policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blobType, + }, MaxUpdates: 100, }, RewardPkScript: []byte{0x01, 0x02, 0x03}, @@ -250,7 +257,7 @@ func testCreateClientSession(h *clientDBHarness) { h.insertSession(session, wtdb.ErrNoReservedKeyIndex) // Now, reserve a session key for this tower. - keyIndex := h.nextKeyIndex(session.TowerID, nil) + keyIndex := h.nextKeyIndex(session.TowerID, blobType, nil) // The client session hasn't been updated with the reserved key index // (since it's still zero). Inserting should fail due to the mismatch. @@ -259,7 +266,7 @@ func testCreateClientSession(h *clientDBHarness) { // Reserve another key for the same index. Since no session has been // successfully created, it should return the same index to maintain // idempotency across restarts. - keyIndex2 := h.nextKeyIndex(session.TowerID, nil) + keyIndex2 := h.nextKeyIndex(session.TowerID, blobType, nil) if keyIndex != keyIndex2 { h.t.Fatalf("next key index should be idempotent: want: %v, "+ "got %v", keyIndex, keyIndex2) @@ -281,7 +288,7 @@ func testCreateClientSession(h *clientDBHarness) { // Finally, assert that reserving another key index succeeds with a // different key index, now that the first one has been finalized. - keyIndex3 := h.nextKeyIndex(session.TowerID, nil) + keyIndex3 := h.nextKeyIndex(session.TowerID, blobType, nil) if keyIndex == keyIndex3 { h.t.Fatalf("key index still reserved after creating session") } @@ -293,18 +300,22 @@ func testFilterClientSessions(h *clientDBHarness) { // We'll create three client sessions, the first two belonging to one // tower, and the last belonging to another one. const numSessions = 3 + const blobType = blob.TypeAltruistCommit towerSessions := make(map[wtdb.TowerID][]wtdb.SessionID) for i := 0; i < numSessions; i++ { towerID := wtdb.TowerID(1) if i == numSessions-1 { towerID = wtdb.TowerID(2) } - keyIndex := h.nextKeyIndex(towerID, nil) + keyIndex := h.nextKeyIndex(towerID, blobType, nil) sessionID := wtdb.SessionID([33]byte{byte(i)}) h.insertSession(&wtdb.ClientSession{ ClientSessionBody: wtdb.ClientSessionBody{ TowerID: towerID, Policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blobType, + }, MaxUpdates: 100, }, RewardPkScript: []byte{0x01, 0x02, 0x03}, @@ -458,14 +469,18 @@ func testRemoveTower(h *clientDBHarness) { Address: addr1, }, nil) + const blobType = blob.TypeAltruistCommit session := &wtdb.ClientSession{ ClientSessionBody: wtdb.ClientSessionBody{ TowerID: tower.ID, Policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blobType, + }, MaxUpdates: 100, }, RewardPkScript: []byte{0x01, 0x02, 0x03}, - KeyIndex: h.nextKeyIndex(tower.ID, nil), + KeyIndex: h.nextKeyIndex(tower.ID, blobType, nil), }, ID: wtdb.SessionID([33]byte{0x01}), } @@ -525,10 +540,14 @@ func testChanSummaries(h *clientDBHarness) { // testCommitUpdate tests the behavior of CommitUpdate, ensuring that they can func testCommitUpdate(h *clientDBHarness) { + const blobType = blob.TypeAltruistCommit session := &wtdb.ClientSession{ ClientSessionBody: wtdb.ClientSessionBody{ TowerID: wtdb.TowerID(3), Policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blobType, + }, MaxUpdates: 100, }, RewardPkScript: []byte{0x01, 0x02, 0x03}, @@ -542,7 +561,7 @@ func testCommitUpdate(h *clientDBHarness) { h.commitUpdate(&session.ID, update1, wtdb.ErrClientSessionNotFound) // Reserve a session key index and insert the session. - session.KeyIndex = h.nextKeyIndex(session.TowerID, nil) + session.KeyIndex = h.nextKeyIndex(session.TowerID, blobType, nil) h.insertSession(session, nil) // Now, try to commit the update that failed initially which should @@ -620,11 +639,16 @@ func testCommitUpdate(h *clientDBHarness) { // testAckUpdate asserts the behavior of AckUpdate. func testAckUpdate(h *clientDBHarness) { + const blobType = blob.TypeAltruistCommit + // Create a new session that the updates in this will be tied to. session := &wtdb.ClientSession{ ClientSessionBody: wtdb.ClientSessionBody{ TowerID: wtdb.TowerID(3), Policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blobType, + }, MaxUpdates: 100, }, RewardPkScript: []byte{0x01, 0x02, 0x03}, @@ -637,7 +661,7 @@ func testAckUpdate(h *clientDBHarness) { h.ackUpdate(&session.ID, 1, 0, wtdb.ErrClientSessionNotFound) // Reserve a session key and insert the client session. - session.KeyIndex = h.nextKeyIndex(session.TowerID, nil) + session.KeyIndex = h.nextKeyIndex(session.TowerID, blobType, nil) h.insertSession(session, nil) // Now, try to ack update 1. This should fail since update 1 was never diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index c06a097a..35190d93 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -7,11 +7,17 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/wtdb" ) type towerPK [33]byte +type keyIndexKey struct { + towerID wtdb.TowerID + blobType blob.Type +} + // ClientDB is a mock, in-memory database or testing the watchtower client // behavior. type ClientDB struct { @@ -23,8 +29,9 @@ type ClientDB struct { towerIndex map[towerPK]wtdb.TowerID towers map[wtdb.TowerID]*wtdb.Tower - nextIndex uint32 - indexes map[wtdb.TowerID]uint32 + nextIndex uint32 + indexes map[keyIndexKey]uint32 + legacyIndexes map[wtdb.TowerID]uint32 } // NewClientDB initializes a new mock ClientDB. @@ -34,7 +41,8 @@ func NewClientDB() *ClientDB { activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession), towerIndex: make(map[towerPK]wtdb.TowerID), towers: make(map[wtdb.TowerID]*wtdb.Tower), - indexes: make(map[wtdb.TowerID]uint32), + indexes: make(map[keyIndexKey]uint32), + legacyIndexes: make(map[wtdb.TowerID]uint32), } } @@ -229,10 +237,15 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { return wtdb.ErrClientSessionAlreadyExists } + key := keyIndexKey{ + towerID: session.TowerID, + blobType: session.Policy.BlobType, + } + // Ensure that a session key index has been reserved for this tower. - keyIndex, ok := m.indexes[session.TowerID] - if !ok { - return wtdb.ErrNoReservedKeyIndex + keyIndex, err := m.getSessionKeyIndex(key) + if err != nil { + return err } // Ensure that the session's index matches the reserved index. @@ -242,7 +255,10 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { // Remove the key index reservation for this tower. Once committed, this // permits us to create another session with this tower. - delete(m.indexes, session.TowerID) + delete(m.indexes, key) + if key.blobType == blob.TypeAltruistCommit { + delete(m.legacyIndexes, key.towerID) + } m.activeSessions[session.ID] = wtdb.ClientSession{ ID: session.ID, @@ -266,21 +282,42 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { // CreateClientSession is invoked for that tower and index, at which point a new // index for that tower can be reserved. Multiple calls to this method before // CreateClientSession is invoked should return the same index. -func (m *ClientDB) NextSessionKeyIndex(towerID wtdb.TowerID) (uint32, error) { +func (m *ClientDB) NextSessionKeyIndex(towerID wtdb.TowerID, + blobType blob.Type) (uint32, error) { + m.mu.Lock() defer m.mu.Unlock() - if index, ok := m.indexes[towerID]; ok { + key := keyIndexKey{ + towerID: towerID, + blobType: blobType, + } + + if index, err := m.getSessionKeyIndex(key); err == nil { return index, nil } m.nextIndex++ index := m.nextIndex - m.indexes[towerID] = index + m.indexes[key] = index return index, nil } +func (m *ClientDB) getSessionKeyIndex(key keyIndexKey) (uint32, error) { + if index, ok := m.indexes[key]; ok { + return index, nil + } + + if key.blobType == blob.TypeAltruistCommit { + if index, ok := m.legacyIndexes[key.towerID]; ok { + return index, nil + } + } + + return 0, wtdb.ErrNoReservedKeyIndex +} + // CommitUpdate persists the CommittedUpdate provided in the slot for (session, // seqNum). This allows the client to retransmit this update on startup. func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,