224 lines
6.4 KiB
Go
224 lines
6.4 KiB
Go
package wtmock
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
"github.com/lightningnetwork/lnd/lnwire"
|
|
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
|
)
|
|
|
|
// towerPK is the fixed-size key type used to index towers by identity. It
// holds the compressed serialization of a tower's identity public key.
type towerPK [33]byte
|
|
|
|
// ClientDB is a mock, in-memory database or testing the watchtower client
|
|
// behavior.
|
|
type ClientDB struct {
|
|
nextTowerID uint64 // to be used atomically
|
|
|
|
mu sync.Mutex
|
|
sweepPkScripts map[lnwire.ChannelID][]byte
|
|
activeSessions map[wtdb.SessionID]*wtdb.ClientSession
|
|
towerIndex map[towerPK]uint64
|
|
towers map[uint64]*wtdb.Tower
|
|
}
|
|
|
|
// NewClientDB initializes a new mock ClientDB.
|
|
func NewClientDB() *ClientDB {
|
|
return &ClientDB{
|
|
sweepPkScripts: make(map[lnwire.ChannelID][]byte),
|
|
activeSessions: make(map[wtdb.SessionID]*wtdb.ClientSession),
|
|
towerIndex: make(map[towerPK]uint64),
|
|
towers: make(map[uint64]*wtdb.Tower),
|
|
}
|
|
}
|
|
|
|
// CreateTower initializes a database entry with the given lightning address. If
|
|
// the tower exists, the address is append to the list of all addresses used to
|
|
// that tower previously.
|
|
func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
var towerPubKey towerPK
|
|
copy(towerPubKey[:], lnAddr.IdentityKey.SerializeCompressed())
|
|
|
|
var tower *wtdb.Tower
|
|
towerID, ok := m.towerIndex[towerPubKey]
|
|
if ok {
|
|
tower = m.towers[towerID]
|
|
tower.AddAddress(lnAddr.Address)
|
|
} else {
|
|
towerID = atomic.AddUint64(&m.nextTowerID, 1)
|
|
tower = &wtdb.Tower{
|
|
ID: towerID,
|
|
IdentityKey: lnAddr.IdentityKey,
|
|
Addresses: []net.Addr{lnAddr.Address},
|
|
}
|
|
}
|
|
|
|
m.towerIndex[towerPubKey] = towerID
|
|
m.towers[towerID] = tower
|
|
|
|
return tower, nil
|
|
}
|
|
|
|
// MarkBackupIneligible records that particular commit height is ineligible for
|
|
// backup. This allows the client to track which updates it should not attempt
|
|
// to retry after startup.
|
|
func (m *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, commitHeight uint64) error {
|
|
return nil
|
|
}
|
|
|
|
// ListClientSessions returns the set of all client sessions known to the db.
|
|
func (m *ClientDB) ListClientSessions() (map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
sessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
|
|
for _, session := range m.activeSessions {
|
|
sessions[session.ID] = session
|
|
}
|
|
|
|
return sessions, nil
|
|
}
|
|
|
|
// CreateClientSession records a newly negotiated client session in the set of
|
|
// active sessions. The session can be identified by its SessionID.
|
|
func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
m.activeSessions[session.ID] = &wtdb.ClientSession{
|
|
TowerID: session.TowerID,
|
|
Tower: session.Tower,
|
|
SessionKeyDesc: session.SessionKeyDesc,
|
|
SessionPrivKey: session.SessionPrivKey,
|
|
ID: session.ID,
|
|
Policy: session.Policy,
|
|
SeqNum: session.SeqNum,
|
|
TowerLastApplied: session.TowerLastApplied,
|
|
RewardPkScript: session.RewardPkScript,
|
|
CommittedUpdates: make(map[uint16]*wtdb.CommittedUpdate),
|
|
AckedUpdates: make(map[uint16]wtdb.BackupID),
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// CommitUpdate persists the CommittedUpdate provided in the slot for (session,
|
|
// seqNum). This allows the client to retransmit this update on startup.
|
|
func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, seqNum uint16,
|
|
update *wtdb.CommittedUpdate) (uint16, error) {
|
|
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
// Fail if session doesn't exist.
|
|
session, ok := m.activeSessions[*id]
|
|
if !ok {
|
|
return 0, wtdb.ErrClientSessionNotFound
|
|
}
|
|
|
|
// Check if an update has already been committed for this state.
|
|
dbUpdate, ok := session.CommittedUpdates[seqNum]
|
|
if ok {
|
|
// If the breach hint matches, we'll just return the last
|
|
// applied value so the client can retransmit.
|
|
if dbUpdate.Hint == update.Hint {
|
|
return session.TowerLastApplied, nil
|
|
}
|
|
|
|
// Otherwise, fail since the breach hint doesn't match.
|
|
return 0, wtdb.ErrUpdateAlreadyCommitted
|
|
}
|
|
|
|
// Sequence number must increment.
|
|
if seqNum != session.SeqNum+1 {
|
|
return 0, wtdb.ErrCommitUnorderedUpdate
|
|
}
|
|
|
|
// Save the update and increment the sequence number.
|
|
session.CommittedUpdates[seqNum] = update
|
|
session.SeqNum++
|
|
|
|
return session.TowerLastApplied, nil
|
|
}
|
|
|
|
// AckUpdate persists an acknowledgment for a given (session, seqnum) pair. This
|
|
// removes the update from the set of committed updates, and validates the
|
|
// lastApplied value returned from the tower.
|
|
func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
// Fail if session doesn't exist.
|
|
session, ok := m.activeSessions[*id]
|
|
if !ok {
|
|
return wtdb.ErrClientSessionNotFound
|
|
}
|
|
|
|
// Retrieve the committed update, failing if none is found. We should
|
|
// only receive acks for state updates that we send.
|
|
update, ok := session.CommittedUpdates[seqNum]
|
|
if !ok {
|
|
return wtdb.ErrCommittedUpdateNotFound
|
|
}
|
|
|
|
// Ensure the returned last applied value does not exceed the highest
|
|
// allocated sequence number.
|
|
if lastApplied > session.SeqNum {
|
|
return wtdb.ErrUnallocatedLastApplied
|
|
}
|
|
|
|
// Ensure the last applied value isn't lower than a previous one sent by
|
|
// the tower.
|
|
if lastApplied < session.TowerLastApplied {
|
|
return wtdb.ErrLastAppliedReversion
|
|
}
|
|
|
|
// Finally, remove the committed update from disk and mark the update as
|
|
// acked. The tower last applied value is also recorded to send along
|
|
// with the next update.
|
|
delete(session.CommittedUpdates, seqNum)
|
|
session.AckedUpdates[seqNum] = update.BackupID
|
|
session.TowerLastApplied = lastApplied
|
|
|
|
return nil
|
|
}
|
|
|
|
// FetchChanPkScripts returns the set of sweep pkscripts known for all channels.
|
|
// This allows the client to cache them in memory on startup.
|
|
func (m *ClientDB) FetchChanPkScripts() (map[lnwire.ChannelID][]byte, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
sweepPkScripts := make(map[lnwire.ChannelID][]byte)
|
|
for chanID, pkScript := range m.sweepPkScripts {
|
|
sweepPkScripts[chanID] = cloneBytes(pkScript)
|
|
}
|
|
|
|
return sweepPkScripts, nil
|
|
}
|
|
|
|
// AddChanPkScript sets a pkscript or sweeping funds from the channel or chanID.
|
|
func (m *ClientDB) AddChanPkScript(chanID lnwire.ChannelID, pkScript []byte) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if _, ok := m.sweepPkScripts[chanID]; ok {
|
|
return fmt.Errorf("pkscript for %x already exists", pkScript)
|
|
}
|
|
|
|
m.sweepPkScripts[chanID] = cloneBytes(pkScript)
|
|
|
|
return nil
|
|
}
|
|
|
|
// cloneBytes returns a newly allocated copy of b that shares no memory with
// it. The result is never nil, even when b is nil or empty, and its capacity
// equals its length.
func cloneBytes(b []byte) []byte {
	return append(make([]byte, 0, len(b)), b...)
}
|