diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go new file mode 100644 index 00000000..54f9a697 --- /dev/null +++ b/watchtower/wtmock/client_db.go @@ -0,0 +1,223 @@ +package wtmock + +import ( + "fmt" + "net" + "sync" + "sync/atomic" + + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/watchtower/wtdb" +) + +type towerPK [33]byte + +// ClientDB is a mock, in-memory database or testing the watchtower client +// behavior. +type ClientDB struct { + nextTowerID uint64 // to be used atomically + + mu sync.Mutex + sweepPkScripts map[lnwire.ChannelID][]byte + activeSessions map[wtdb.SessionID]*wtdb.ClientSession + towerIndex map[towerPK]uint64 + towers map[uint64]*wtdb.Tower +} + +// NewClientDB initializes a new mock ClientDB. +func NewClientDB() *ClientDB { + return &ClientDB{ + sweepPkScripts: make(map[lnwire.ChannelID][]byte), + activeSessions: make(map[wtdb.SessionID]*wtdb.ClientSession), + towerIndex: make(map[towerPK]uint64), + towers: make(map[uint64]*wtdb.Tower), + } +} + +// CreateTower initializes a database entry with the given lightning address. If +// the tower exists, the address is append to the list of all addresses used to +// that tower previously. +func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) { + m.mu.Lock() + defer m.mu.Unlock() + + var towerPubKey towerPK + copy(towerPubKey[:], lnAddr.IdentityKey.SerializeCompressed()) + + var tower *wtdb.Tower + towerID, ok := m.towerIndex[towerPubKey] + if ok { + tower = m.towers[towerID] + tower.AddAddress(lnAddr.Address) + } else { + towerID = atomic.AddUint64(&m.nextTowerID, 1) + tower = &wtdb.Tower{ + ID: towerID, + IdentityKey: lnAddr.IdentityKey, + Addresses: []net.Addr{lnAddr.Address}, + } + } + + m.towerIndex[towerPubKey] = towerID + m.towers[towerID] = tower + + return tower, nil +} + +// MarkBackupIneligible records that particular commit height is ineligible for +// backup. This allows the client to track which updates it should not attempt +// to retry after startup. +func (m *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, commitHeight uint64) error { + return nil +} + +// ListClientSessions returns the set of all client sessions known to the db. +func (m *ClientDB) ListClientSessions() (map[wtdb.SessionID]*wtdb.ClientSession, error) { + m.mu.Lock() + defer m.mu.Unlock() + + sessions := make(map[wtdb.SessionID]*wtdb.ClientSession) + for _, session := range m.activeSessions { + sessions[session.ID] = session + } + + return sessions, nil +} + +// CreateClientSession records a newly negotiated client session in the set of +// active sessions. The session can be identified by its SessionID. +func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { + m.mu.Lock() + defer m.mu.Unlock() + + m.activeSessions[session.ID] = &wtdb.ClientSession{ + TowerID: session.TowerID, + Tower: session.Tower, + SessionKeyDesc: session.SessionKeyDesc, + SessionPrivKey: session.SessionPrivKey, + ID: session.ID, + Policy: session.Policy, + SeqNum: session.SeqNum, + TowerLastApplied: session.TowerLastApplied, + RewardPkScript: session.RewardPkScript, + CommittedUpdates: make(map[uint16]*wtdb.CommittedUpdate), + AckedUpdates: make(map[uint16]wtdb.BackupID), + } + + return nil +} + +// CommitUpdate persists the CommittedUpdate provided in the slot for (session, +// seqNum). This allows the client to retransmit this update on startup. +func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, seqNum uint16, + update *wtdb.CommittedUpdate) (uint16, error) { + + m.mu.Lock() + defer m.mu.Unlock() + + // Fail if session doesn't exist. + session, ok := m.activeSessions[*id] + if !ok { + return 0, wtdb.ErrClientSessionNotFound + } + + // Check if an update has already been committed for this state. + dbUpdate, ok := session.CommittedUpdates[seqNum] + if ok { + // If the breach hint matches, we'll just return the last + // applied value so the client can retransmit. + if dbUpdate.Hint == update.Hint { + return session.TowerLastApplied, nil + } + + // Otherwise, fail since the breach hint doesn't match. + return 0, wtdb.ErrUpdateAlreadyCommitted + } + + // Sequence number must increment. + if seqNum != session.SeqNum+1 { + return 0, wtdb.ErrCommitUnorderedUpdate + } + + // Save the update and increment the sequence number. + session.CommittedUpdates[seqNum] = update + session.SeqNum++ + + return session.TowerLastApplied, nil +} + +// AckUpdate persists an acknowledgment for a given (session, seqnum) pair. This +// removes the update from the set of committed updates, and validates the +// lastApplied value returned from the tower. +func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) error { + m.mu.Lock() + defer m.mu.Unlock() + + // Fail if session doesn't exist. + session, ok := m.activeSessions[*id] + if !ok { + return wtdb.ErrClientSessionNotFound + } + + // Retrieve the committed update, failing if none is found. We should + // only receive acks for state updates that we send. + update, ok := session.CommittedUpdates[seqNum] + if !ok { + return wtdb.ErrCommittedUpdateNotFound + } + + // Ensure the returned last applied value does not exceed the highest + // allocated sequence number. + if lastApplied > session.SeqNum { + return wtdb.ErrUnallocatedLastApplied + } + + // Ensure the last applied value isn't lower than a previous one sent by + // the tower. + if lastApplied < session.TowerLastApplied { + return wtdb.ErrLastAppliedReversion + } + + // Finally, remove the committed update from disk and mark the update as + // acked. The tower last applied value is also recorded to send along + // with the next update. + delete(session.CommittedUpdates, seqNum) + session.AckedUpdates[seqNum] = update.BackupID + session.TowerLastApplied = lastApplied + + return nil +} + +// FetchChanPkScripts returns the set of sweep pkscripts known for all channels. +// This allows the client to cache them in memory on startup. +func (m *ClientDB) FetchChanPkScripts() (map[lnwire.ChannelID][]byte, error) { + m.mu.Lock() + defer m.mu.Unlock() + + sweepPkScripts := make(map[lnwire.ChannelID][]byte) + for chanID, pkScript := range m.sweepPkScripts { + sweepPkScripts[chanID] = cloneBytes(pkScript) + } + + return sweepPkScripts, nil +} + +// AddChanPkScript sets a pkscript or sweeping funds from the channel or chanID. +func (m *ClientDB) AddChanPkScript(chanID lnwire.ChannelID, pkScript []byte) error { + m.mu.Lock() + defer m.mu.Unlock() + + if _, ok := m.sweepPkScripts[chanID]; ok { + return fmt.Errorf("pkscript for %x already exists", pkScript) + } + + m.sweepPkScripts[chanID] = cloneBytes(pkScript) + + return nil +} + +func cloneBytes(b []byte) []byte { + bb := make([]byte, len(b)) + copy(bb, b) + return bb +}