Merge pull request #4815 from cfromknecht/wtclient-key-segregation
wtdb+wtclient: segregate session-key-index reservations by blob-type
This commit is contained in:
commit
c58589db3b
@ -7,6 +7,7 @@ import (
|
|||||||
"github.com/lightningnetwork/lnd/keychain"
|
"github.com/lightningnetwork/lnd/keychain"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
"github.com/lightningnetwork/lnd/tor"
|
"github.com/lightningnetwork/lnd/tor"
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower/blob"
|
||||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||||
"github.com/lightningnetwork/lnd/watchtower/wtserver"
|
"github.com/lightningnetwork/lnd/watchtower/wtserver"
|
||||||
)
|
)
|
||||||
@ -44,12 +45,12 @@ type DB interface {
|
|||||||
ListTowers() ([]*wtdb.Tower, error)
|
ListTowers() ([]*wtdb.Tower, error)
|
||||||
|
|
||||||
// NextSessionKeyIndex reserves a new session key derivation index for a
|
// NextSessionKeyIndex reserves a new session key derivation index for a
|
||||||
// particular tower id. The index is reserved for that tower until
|
// particular tower id and blob type. The index is reserved for that
|
||||||
// CreateClientSession is invoked for that tower and index, at which
|
// (tower, blob type) pair until CreateClientSession is invoked for that
|
||||||
// point a new index for that tower can be reserved. Multiple calls to
|
// tower and index, at which point a new index for that tower can be
|
||||||
// this method before CreateClientSession is invoked should return the
|
// reserved. Multiple calls to this method before CreateClientSession is
|
||||||
// same index.
|
// invoked should return the same index.
|
||||||
NextSessionKeyIndex(wtdb.TowerID) (uint32, error)
|
NextSessionKeyIndex(wtdb.TowerID, blob.Type) (uint32, error)
|
||||||
|
|
||||||
// CreateClientSession saves a newly negotiated client session to the
|
// CreateClientSession saves a newly negotiated client session to the
|
||||||
// client's database. This enables the session to be used across
|
// client's database. This enables the session to be used across
|
||||||
|
@ -298,7 +298,9 @@ retryWithBackoff:
|
|||||||
// Before proceeding, we will reserve a session key index to use
|
// Before proceeding, we will reserve a session key index to use
|
||||||
// with this specific tower. If one is already reserved, the
|
// with this specific tower. If one is already reserved, the
|
||||||
// existing index will be returned.
|
// existing index will be returned.
|
||||||
keyIndex, err := n.cfg.DB.NextSessionKeyIndex(tower.ID)
|
keyIndex, err := n.cfg.DB.NextSessionKeyIndex(
|
||||||
|
tower.ID, n.cfg.Policy.BlobType,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("Unable to reserve session key index "+
|
log.Debugf("Unable to reserve session key index "+
|
||||||
"for tower=%x: %v", towerPub, err)
|
"for tower=%x: %v", towerPub, err)
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
"github.com/btcsuite/btcd/btcec"
|
"github.com/btcsuite/btcd/btcec"
|
||||||
"github.com/lightningnetwork/lnd/channeldb/kvdb"
|
"github.com/lightningnetwork/lnd/channeldb/kvdb"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower/blob"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -479,7 +480,9 @@ func (c *ClientDB) ListTowers() ([]*Tower, error) {
|
|||||||
// CreateClientSession is invoked for that tower and index, at which point a new
|
// CreateClientSession is invoked for that tower and index, at which point a new
|
||||||
// index for that tower can be reserved. Multiple calls to this method before
|
// index for that tower can be reserved. Multiple calls to this method before
|
||||||
// CreateClientSession is invoked should return the same index.
|
// CreateClientSession is invoked should return the same index.
|
||||||
func (c *ClientDB) NextSessionKeyIndex(towerID TowerID) (uint32, error) {
|
func (c *ClientDB) NextSessionKeyIndex(towerID TowerID,
|
||||||
|
blobType blob.Type) (uint32, error) {
|
||||||
|
|
||||||
var index uint32
|
var index uint32
|
||||||
err := kvdb.Update(c.db, func(tx kvdb.RwTx) error {
|
err := kvdb.Update(c.db, func(tx kvdb.RwTx) error {
|
||||||
keyIndex := tx.ReadWriteBucket(cSessionKeyIndexBkt)
|
keyIndex := tx.ReadWriteBucket(cSessionKeyIndexBkt)
|
||||||
@ -490,10 +493,9 @@ func (c *ClientDB) NextSessionKeyIndex(towerID TowerID) (uint32, error) {
|
|||||||
// Check the session key index to see if a key has already been
|
// Check the session key index to see if a key has already been
|
||||||
// reserved for this tower. If so, we'll deserialize and return
|
// reserved for this tower. If so, we'll deserialize and return
|
||||||
// the index directly.
|
// the index directly.
|
||||||
towerIDBytes := towerID.Bytes()
|
var err error
|
||||||
indexBytes := keyIndex.Get(towerIDBytes)
|
index, err = getSessionKeyIndex(keyIndex, towerID, blobType)
|
||||||
if len(indexBytes) == 4 {
|
if err == nil {
|
||||||
index = byteOrder.Uint32(indexBytes)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -511,13 +513,16 @@ func (c *ClientDB) NextSessionKeyIndex(towerID TowerID) (uint32, error) {
|
|||||||
return fmt.Errorf("exhausted session key indexes")
|
return fmt.Errorf("exhausted session key indexes")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create the key that will used to be store the reserved index.
|
||||||
|
keyBytes := createSessionKeyIndexKey(towerID, blobType)
|
||||||
|
|
||||||
index = uint32(index64)
|
index = uint32(index64)
|
||||||
|
|
||||||
var indexBuf [4]byte
|
var indexBuf [4]byte
|
||||||
byteOrder.PutUint32(indexBuf[:], index)
|
byteOrder.PutUint32(indexBuf[:], index)
|
||||||
|
|
||||||
// Record the reserved session key index under this tower's id.
|
// Record the reserved session key index under this tower's id.
|
||||||
return keyIndex.Put(towerIDBytes, indexBuf[:])
|
return keyIndex.Put(keyBytes, indexBuf[:])
|
||||||
}, func() {
|
}, func() {
|
||||||
index = 0
|
index = 0
|
||||||
})
|
})
|
||||||
@ -549,25 +554,34 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error {
|
|||||||
return ErrClientSessionAlreadyExists
|
return ErrClientSessionAlreadyExists
|
||||||
}
|
}
|
||||||
|
|
||||||
|
towerID := session.TowerID
|
||||||
|
blobType := session.Policy.BlobType
|
||||||
|
|
||||||
// Check that this tower has a reserved key index.
|
// Check that this tower has a reserved key index.
|
||||||
towerIDBytes := session.TowerID.Bytes()
|
index, err := getSessionKeyIndex(keyIndexes, towerID, blobType)
|
||||||
keyIndexBytes := keyIndexes.Get(towerIDBytes)
|
if err != nil {
|
||||||
if len(keyIndexBytes) != 4 {
|
return err
|
||||||
return ErrNoReservedKeyIndex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Assert that the key index of the inserted session matches the
|
// Assert that the key index of the inserted session matches the
|
||||||
// reserved session key index.
|
// reserved session key index.
|
||||||
index := byteOrder.Uint32(keyIndexBytes)
|
|
||||||
if index != session.KeyIndex {
|
if index != session.KeyIndex {
|
||||||
return ErrIncorrectKeyIndex
|
return ErrIncorrectKeyIndex
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove the key index reservation.
|
// Remove the key index reservation. For altruist commit
|
||||||
err := keyIndexes.Delete(towerIDBytes)
|
// sessions, we'll also purge under the old legacy key format.
|
||||||
|
key := createSessionKeyIndexKey(towerID, blobType)
|
||||||
|
err = keyIndexes.Delete(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if blobType == blob.TypeAltruistCommit {
|
||||||
|
err = keyIndexes.Delete(towerID.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Finally, write the client session's body in the sessions
|
// Finally, write the client session's body in the sessions
|
||||||
// bucket.
|
// bucket.
|
||||||
@ -575,6 +589,50 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error {
|
|||||||
}, func() {})
|
}, func() {})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// createSessionKeyIndexKey returns the indentifier used in the
|
||||||
|
// session-key-index index, created as tower-id||blob-type.
|
||||||
|
//
|
||||||
|
// NOTE: The original serialization only used tower-id, which prevents
|
||||||
|
// concurrent client types from reserving sessions with the same tower.
|
||||||
|
func createSessionKeyIndexKey(towerID TowerID, blobType blob.Type) []byte {
|
||||||
|
towerIDBytes := towerID.Bytes()
|
||||||
|
|
||||||
|
// Session key indexes are stored under as tower-id||blob-type.
|
||||||
|
var keyBytes [6]byte
|
||||||
|
copy(keyBytes[:4], towerIDBytes)
|
||||||
|
byteOrder.PutUint16(keyBytes[4:], uint16(blobType))
|
||||||
|
|
||||||
|
return keyBytes[:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSessionKeyIndex is a helper method
|
||||||
|
func getSessionKeyIndex(keyIndexes kvdb.RwBucket, towerID TowerID,
|
||||||
|
blobType blob.Type) (uint32, error) {
|
||||||
|
|
||||||
|
// Session key indexes are store under as tower-id||blob-type. The
|
||||||
|
// original serialization only used tower-id, which prevents concurrent
|
||||||
|
// client types from reserving sessions with the same tower.
|
||||||
|
keyBytes := createSessionKeyIndexKey(towerID, blobType)
|
||||||
|
|
||||||
|
// Retrieve the index using the key bytes. If the key wasn't found, we
|
||||||
|
// will fall back to the legacy format that only uses the tower id, but
|
||||||
|
// _only_ if the blob type is for altruist commit sessions since that
|
||||||
|
// was the only operational session type prior to changing the key
|
||||||
|
// format.
|
||||||
|
keyIndexBytes := keyIndexes.Get(keyBytes)
|
||||||
|
if keyIndexBytes == nil && blobType == blob.TypeAltruistCommit {
|
||||||
|
keyIndexBytes = keyIndexes.Get(towerID.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
// All session key indexes should be serialized uint32's. If no key
|
||||||
|
// index was found, the length of keyIndexBytes will be 0.
|
||||||
|
if len(keyIndexBytes) != 4 {
|
||||||
|
return 0, ErrNoReservedKeyIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
return byteOrder.Uint32(keyIndexBytes), nil
|
||||||
|
}
|
||||||
|
|
||||||
// ListClientSessions returns the set of all client sessions known to the db. An
|
// ListClientSessions returns the set of all client sessions known to the db. An
|
||||||
// optional tower ID can be used to filter out any client sessions in the
|
// optional tower ID can be used to filter out any client sessions in the
|
||||||
// response that do not correspond to this tower.
|
// response that do not correspond to this tower.
|
||||||
|
@ -60,13 +60,14 @@ func (h *clientDBHarness) listSessions(id *wtdb.TowerID) map[wtdb.SessionID]*wtd
|
|||||||
return sessions
|
return sessions
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *clientDBHarness) nextKeyIndex(id wtdb.TowerID, expErr error) uint32 {
|
func (h *clientDBHarness) nextKeyIndex(id wtdb.TowerID,
|
||||||
|
blobType blob.Type) uint32 {
|
||||||
|
|
||||||
h.t.Helper()
|
h.t.Helper()
|
||||||
|
|
||||||
index, err := h.db.NextSessionKeyIndex(id)
|
index, err := h.db.NextSessionKeyIndex(id, blobType)
|
||||||
if err != expErr {
|
if err != nil {
|
||||||
h.t.Fatalf("expected next session key index error: %v, got: %v",
|
h.t.Fatalf("unable to create next session key index: %v", err)
|
||||||
expErr, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if index == 0 {
|
if index == 0 {
|
||||||
@ -227,11 +228,16 @@ func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16,
|
|||||||
// - client sessions cannot be created with an incorrect session key index .
|
// - client sessions cannot be created with an incorrect session key index .
|
||||||
// - inserting duplicate sessions fails.
|
// - inserting duplicate sessions fails.
|
||||||
func testCreateClientSession(h *clientDBHarness) {
|
func testCreateClientSession(h *clientDBHarness) {
|
||||||
|
const blobType = blob.TypeAltruistAnchorCommit
|
||||||
|
|
||||||
// Create a test client session to insert.
|
// Create a test client session to insert.
|
||||||
session := &wtdb.ClientSession{
|
session := &wtdb.ClientSession{
|
||||||
ClientSessionBody: wtdb.ClientSessionBody{
|
ClientSessionBody: wtdb.ClientSessionBody{
|
||||||
TowerID: wtdb.TowerID(3),
|
TowerID: wtdb.TowerID(3),
|
||||||
Policy: wtpolicy.Policy{
|
Policy: wtpolicy.Policy{
|
||||||
|
TxPolicy: wtpolicy.TxPolicy{
|
||||||
|
BlobType: blobType,
|
||||||
|
},
|
||||||
MaxUpdates: 100,
|
MaxUpdates: 100,
|
||||||
},
|
},
|
||||||
RewardPkScript: []byte{0x01, 0x02, 0x03},
|
RewardPkScript: []byte{0x01, 0x02, 0x03},
|
||||||
@ -250,7 +256,7 @@ func testCreateClientSession(h *clientDBHarness) {
|
|||||||
h.insertSession(session, wtdb.ErrNoReservedKeyIndex)
|
h.insertSession(session, wtdb.ErrNoReservedKeyIndex)
|
||||||
|
|
||||||
// Now, reserve a session key for this tower.
|
// Now, reserve a session key for this tower.
|
||||||
keyIndex := h.nextKeyIndex(session.TowerID, nil)
|
keyIndex := h.nextKeyIndex(session.TowerID, blobType)
|
||||||
|
|
||||||
// The client session hasn't been updated with the reserved key index
|
// The client session hasn't been updated with the reserved key index
|
||||||
// (since it's still zero). Inserting should fail due to the mismatch.
|
// (since it's still zero). Inserting should fail due to the mismatch.
|
||||||
@ -259,7 +265,7 @@ func testCreateClientSession(h *clientDBHarness) {
|
|||||||
// Reserve another key for the same index. Since no session has been
|
// Reserve another key for the same index. Since no session has been
|
||||||
// successfully created, it should return the same index to maintain
|
// successfully created, it should return the same index to maintain
|
||||||
// idempotency across restarts.
|
// idempotency across restarts.
|
||||||
keyIndex2 := h.nextKeyIndex(session.TowerID, nil)
|
keyIndex2 := h.nextKeyIndex(session.TowerID, blobType)
|
||||||
if keyIndex != keyIndex2 {
|
if keyIndex != keyIndex2 {
|
||||||
h.t.Fatalf("next key index should be idempotent: want: %v, "+
|
h.t.Fatalf("next key index should be idempotent: want: %v, "+
|
||||||
"got %v", keyIndex, keyIndex2)
|
"got %v", keyIndex, keyIndex2)
|
||||||
@ -281,7 +287,7 @@ func testCreateClientSession(h *clientDBHarness) {
|
|||||||
|
|
||||||
// Finally, assert that reserving another key index succeeds with a
|
// Finally, assert that reserving another key index succeeds with a
|
||||||
// different key index, now that the first one has been finalized.
|
// different key index, now that the first one has been finalized.
|
||||||
keyIndex3 := h.nextKeyIndex(session.TowerID, nil)
|
keyIndex3 := h.nextKeyIndex(session.TowerID, blobType)
|
||||||
if keyIndex == keyIndex3 {
|
if keyIndex == keyIndex3 {
|
||||||
h.t.Fatalf("key index still reserved after creating session")
|
h.t.Fatalf("key index still reserved after creating session")
|
||||||
}
|
}
|
||||||
@ -293,18 +299,22 @@ func testFilterClientSessions(h *clientDBHarness) {
|
|||||||
// We'll create three client sessions, the first two belonging to one
|
// We'll create three client sessions, the first two belonging to one
|
||||||
// tower, and the last belonging to another one.
|
// tower, and the last belonging to another one.
|
||||||
const numSessions = 3
|
const numSessions = 3
|
||||||
|
const blobType = blob.TypeAltruistCommit
|
||||||
towerSessions := make(map[wtdb.TowerID][]wtdb.SessionID)
|
towerSessions := make(map[wtdb.TowerID][]wtdb.SessionID)
|
||||||
for i := 0; i < numSessions; i++ {
|
for i := 0; i < numSessions; i++ {
|
||||||
towerID := wtdb.TowerID(1)
|
towerID := wtdb.TowerID(1)
|
||||||
if i == numSessions-1 {
|
if i == numSessions-1 {
|
||||||
towerID = wtdb.TowerID(2)
|
towerID = wtdb.TowerID(2)
|
||||||
}
|
}
|
||||||
keyIndex := h.nextKeyIndex(towerID, nil)
|
keyIndex := h.nextKeyIndex(towerID, blobType)
|
||||||
sessionID := wtdb.SessionID([33]byte{byte(i)})
|
sessionID := wtdb.SessionID([33]byte{byte(i)})
|
||||||
h.insertSession(&wtdb.ClientSession{
|
h.insertSession(&wtdb.ClientSession{
|
||||||
ClientSessionBody: wtdb.ClientSessionBody{
|
ClientSessionBody: wtdb.ClientSessionBody{
|
||||||
TowerID: towerID,
|
TowerID: towerID,
|
||||||
Policy: wtpolicy.Policy{
|
Policy: wtpolicy.Policy{
|
||||||
|
TxPolicy: wtpolicy.TxPolicy{
|
||||||
|
BlobType: blobType,
|
||||||
|
},
|
||||||
MaxUpdates: 100,
|
MaxUpdates: 100,
|
||||||
},
|
},
|
||||||
RewardPkScript: []byte{0x01, 0x02, 0x03},
|
RewardPkScript: []byte{0x01, 0x02, 0x03},
|
||||||
@ -458,14 +468,18 @@ func testRemoveTower(h *clientDBHarness) {
|
|||||||
Address: addr1,
|
Address: addr1,
|
||||||
}, nil)
|
}, nil)
|
||||||
|
|
||||||
|
const blobType = blob.TypeAltruistCommit
|
||||||
session := &wtdb.ClientSession{
|
session := &wtdb.ClientSession{
|
||||||
ClientSessionBody: wtdb.ClientSessionBody{
|
ClientSessionBody: wtdb.ClientSessionBody{
|
||||||
TowerID: tower.ID,
|
TowerID: tower.ID,
|
||||||
Policy: wtpolicy.Policy{
|
Policy: wtpolicy.Policy{
|
||||||
|
TxPolicy: wtpolicy.TxPolicy{
|
||||||
|
BlobType: blobType,
|
||||||
|
},
|
||||||
MaxUpdates: 100,
|
MaxUpdates: 100,
|
||||||
},
|
},
|
||||||
RewardPkScript: []byte{0x01, 0x02, 0x03},
|
RewardPkScript: []byte{0x01, 0x02, 0x03},
|
||||||
KeyIndex: h.nextKeyIndex(tower.ID, nil),
|
KeyIndex: h.nextKeyIndex(tower.ID, blobType),
|
||||||
},
|
},
|
||||||
ID: wtdb.SessionID([33]byte{0x01}),
|
ID: wtdb.SessionID([33]byte{0x01}),
|
||||||
}
|
}
|
||||||
@ -525,10 +539,14 @@ func testChanSummaries(h *clientDBHarness) {
|
|||||||
|
|
||||||
// testCommitUpdate tests the behavior of CommitUpdate, ensuring that they can
|
// testCommitUpdate tests the behavior of CommitUpdate, ensuring that they can
|
||||||
func testCommitUpdate(h *clientDBHarness) {
|
func testCommitUpdate(h *clientDBHarness) {
|
||||||
|
const blobType = blob.TypeAltruistCommit
|
||||||
session := &wtdb.ClientSession{
|
session := &wtdb.ClientSession{
|
||||||
ClientSessionBody: wtdb.ClientSessionBody{
|
ClientSessionBody: wtdb.ClientSessionBody{
|
||||||
TowerID: wtdb.TowerID(3),
|
TowerID: wtdb.TowerID(3),
|
||||||
Policy: wtpolicy.Policy{
|
Policy: wtpolicy.Policy{
|
||||||
|
TxPolicy: wtpolicy.TxPolicy{
|
||||||
|
BlobType: blobType,
|
||||||
|
},
|
||||||
MaxUpdates: 100,
|
MaxUpdates: 100,
|
||||||
},
|
},
|
||||||
RewardPkScript: []byte{0x01, 0x02, 0x03},
|
RewardPkScript: []byte{0x01, 0x02, 0x03},
|
||||||
@ -542,7 +560,7 @@ func testCommitUpdate(h *clientDBHarness) {
|
|||||||
h.commitUpdate(&session.ID, update1, wtdb.ErrClientSessionNotFound)
|
h.commitUpdate(&session.ID, update1, wtdb.ErrClientSessionNotFound)
|
||||||
|
|
||||||
// Reserve a session key index and insert the session.
|
// Reserve a session key index and insert the session.
|
||||||
session.KeyIndex = h.nextKeyIndex(session.TowerID, nil)
|
session.KeyIndex = h.nextKeyIndex(session.TowerID, blobType)
|
||||||
h.insertSession(session, nil)
|
h.insertSession(session, nil)
|
||||||
|
|
||||||
// Now, try to commit the update that failed initially which should
|
// Now, try to commit the update that failed initially which should
|
||||||
@ -620,11 +638,16 @@ func testCommitUpdate(h *clientDBHarness) {
|
|||||||
|
|
||||||
// testAckUpdate asserts the behavior of AckUpdate.
|
// testAckUpdate asserts the behavior of AckUpdate.
|
||||||
func testAckUpdate(h *clientDBHarness) {
|
func testAckUpdate(h *clientDBHarness) {
|
||||||
|
const blobType = blob.TypeAltruistCommit
|
||||||
|
|
||||||
// Create a new session that the updates in this will be tied to.
|
// Create a new session that the updates in this will be tied to.
|
||||||
session := &wtdb.ClientSession{
|
session := &wtdb.ClientSession{
|
||||||
ClientSessionBody: wtdb.ClientSessionBody{
|
ClientSessionBody: wtdb.ClientSessionBody{
|
||||||
TowerID: wtdb.TowerID(3),
|
TowerID: wtdb.TowerID(3),
|
||||||
Policy: wtpolicy.Policy{
|
Policy: wtpolicy.Policy{
|
||||||
|
TxPolicy: wtpolicy.TxPolicy{
|
||||||
|
BlobType: blobType,
|
||||||
|
},
|
||||||
MaxUpdates: 100,
|
MaxUpdates: 100,
|
||||||
},
|
},
|
||||||
RewardPkScript: []byte{0x01, 0x02, 0x03},
|
RewardPkScript: []byte{0x01, 0x02, 0x03},
|
||||||
@ -637,7 +660,7 @@ func testAckUpdate(h *clientDBHarness) {
|
|||||||
h.ackUpdate(&session.ID, 1, 0, wtdb.ErrClientSessionNotFound)
|
h.ackUpdate(&session.ID, 1, 0, wtdb.ErrClientSessionNotFound)
|
||||||
|
|
||||||
// Reserve a session key and insert the client session.
|
// Reserve a session key and insert the client session.
|
||||||
session.KeyIndex = h.nextKeyIndex(session.TowerID, nil)
|
session.KeyIndex = h.nextKeyIndex(session.TowerID, blobType)
|
||||||
h.insertSession(session, nil)
|
h.insertSession(session, nil)
|
||||||
|
|
||||||
// Now, try to ack update 1. This should fail since update 1 was never
|
// Now, try to ack update 1. This should fail since update 1 was never
|
||||||
|
@ -7,11 +7,17 @@ import (
|
|||||||
|
|
||||||
"github.com/btcsuite/btcd/btcec"
|
"github.com/btcsuite/btcd/btcec"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower/blob"
|
||||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||||
)
|
)
|
||||||
|
|
||||||
type towerPK [33]byte
|
type towerPK [33]byte
|
||||||
|
|
||||||
|
type keyIndexKey struct {
|
||||||
|
towerID wtdb.TowerID
|
||||||
|
blobType blob.Type
|
||||||
|
}
|
||||||
|
|
||||||
// ClientDB is a mock, in-memory database or testing the watchtower client
|
// ClientDB is a mock, in-memory database or testing the watchtower client
|
||||||
// behavior.
|
// behavior.
|
||||||
type ClientDB struct {
|
type ClientDB struct {
|
||||||
@ -24,7 +30,8 @@ type ClientDB struct {
|
|||||||
towers map[wtdb.TowerID]*wtdb.Tower
|
towers map[wtdb.TowerID]*wtdb.Tower
|
||||||
|
|
||||||
nextIndex uint32
|
nextIndex uint32
|
||||||
indexes map[wtdb.TowerID]uint32
|
indexes map[keyIndexKey]uint32
|
||||||
|
legacyIndexes map[wtdb.TowerID]uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClientDB initializes a new mock ClientDB.
|
// NewClientDB initializes a new mock ClientDB.
|
||||||
@ -34,7 +41,8 @@ func NewClientDB() *ClientDB {
|
|||||||
activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession),
|
activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession),
|
||||||
towerIndex: make(map[towerPK]wtdb.TowerID),
|
towerIndex: make(map[towerPK]wtdb.TowerID),
|
||||||
towers: make(map[wtdb.TowerID]*wtdb.Tower),
|
towers: make(map[wtdb.TowerID]*wtdb.Tower),
|
||||||
indexes: make(map[wtdb.TowerID]uint32),
|
indexes: make(map[keyIndexKey]uint32),
|
||||||
|
legacyIndexes: make(map[wtdb.TowerID]uint32),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -229,10 +237,15 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
|
|||||||
return wtdb.ErrClientSessionAlreadyExists
|
return wtdb.ErrClientSessionAlreadyExists
|
||||||
}
|
}
|
||||||
|
|
||||||
|
key := keyIndexKey{
|
||||||
|
towerID: session.TowerID,
|
||||||
|
blobType: session.Policy.BlobType,
|
||||||
|
}
|
||||||
|
|
||||||
// Ensure that a session key index has been reserved for this tower.
|
// Ensure that a session key index has been reserved for this tower.
|
||||||
keyIndex, ok := m.indexes[session.TowerID]
|
keyIndex, err := m.getSessionKeyIndex(key)
|
||||||
if !ok {
|
if err != nil {
|
||||||
return wtdb.ErrNoReservedKeyIndex
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure that the session's index matches the reserved index.
|
// Ensure that the session's index matches the reserved index.
|
||||||
@ -242,7 +255,10 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
|
|||||||
|
|
||||||
// Remove the key index reservation for this tower. Once committed, this
|
// Remove the key index reservation for this tower. Once committed, this
|
||||||
// permits us to create another session with this tower.
|
// permits us to create another session with this tower.
|
||||||
delete(m.indexes, session.TowerID)
|
delete(m.indexes, key)
|
||||||
|
if key.blobType == blob.TypeAltruistCommit {
|
||||||
|
delete(m.legacyIndexes, key.towerID)
|
||||||
|
}
|
||||||
|
|
||||||
m.activeSessions[session.ID] = wtdb.ClientSession{
|
m.activeSessions[session.ID] = wtdb.ClientSession{
|
||||||
ID: session.ID,
|
ID: session.ID,
|
||||||
@ -266,21 +282,42 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
|
|||||||
// CreateClientSession is invoked for that tower and index, at which point a new
|
// CreateClientSession is invoked for that tower and index, at which point a new
|
||||||
// index for that tower can be reserved. Multiple calls to this method before
|
// index for that tower can be reserved. Multiple calls to this method before
|
||||||
// CreateClientSession is invoked should return the same index.
|
// CreateClientSession is invoked should return the same index.
|
||||||
func (m *ClientDB) NextSessionKeyIndex(towerID wtdb.TowerID) (uint32, error) {
|
func (m *ClientDB) NextSessionKeyIndex(towerID wtdb.TowerID,
|
||||||
|
blobType blob.Type) (uint32, error) {
|
||||||
|
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
if index, ok := m.indexes[towerID]; ok {
|
key := keyIndexKey{
|
||||||
|
towerID: towerID,
|
||||||
|
blobType: blobType,
|
||||||
|
}
|
||||||
|
|
||||||
|
if index, err := m.getSessionKeyIndex(key); err == nil {
|
||||||
return index, nil
|
return index, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
m.nextIndex++
|
m.nextIndex++
|
||||||
index := m.nextIndex
|
index := m.nextIndex
|
||||||
m.indexes[towerID] = index
|
m.indexes[key] = index
|
||||||
|
|
||||||
return index, nil
|
return index, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *ClientDB) getSessionKeyIndex(key keyIndexKey) (uint32, error) {
|
||||||
|
if index, ok := m.indexes[key]; ok {
|
||||||
|
return index, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if key.blobType == blob.TypeAltruistCommit {
|
||||||
|
if index, ok := m.legacyIndexes[key.towerID]; ok {
|
||||||
|
return index, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, wtdb.ErrNoReservedKeyIndex
|
||||||
|
}
|
||||||
|
|
||||||
// CommitUpdate persists the CommittedUpdate provided in the slot for (session,
|
// CommitUpdate persists the CommittedUpdate provided in the slot for (session,
|
||||||
// seqNum). This allows the client to retransmit this update on startup.
|
// seqNum). This allows the client to retransmit this update on startup.
|
||||||
func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
|
func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
|
||||||
|
Loading…
Reference in New Issue
Block a user