Merge pull request #3106 from cfromknecht/wtclient-db

watchtower/wtdb: add bbolt-backed ClientDB
This commit is contained in:
Olaoluwa Osuntokun 2019-05-24 18:53:00 -07:00 committed by GitHub
commit 6e3b92b55f
18 changed files with 2480 additions and 365 deletions

@@ -103,6 +103,11 @@ func WriteElement(w io.Writer, element interface{}) error {
return err
}
case lnwire.ChannelID:
if _, err := w.Write(e[:]); err != nil {
return err
}
case uint64:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
@@ -123,6 +128,11 @@ func WriteElement(w io.Writer, element interface{}) error {
return err
}
case uint8:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case bool:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
@@ -259,6 +269,11 @@ func ReadElement(r io.Reader, element interface{}) error {
}
*e = lnwire.NewShortChanIDFromInt(a)
case *lnwire.ChannelID:
if _, err := io.ReadFull(r, e[:]); err != nil {
return err
}
case *uint64:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
@@ -279,6 +294,11 @@ func ReadElement(r io.Reader, element interface{}) error {
return err
}
case *uint8:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *bool:
if err := binary.Read(r, byteOrder, e); err != nil {
return err

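The new codec cases above make lnwire.ChannelID round-trippable through WriteElement/ReadElement alongside the existing integer and bool support. A minimal round-trip sketch under that assumption (this program is illustrative and not part of the diff):

package main

import (
	"bytes"
	"fmt"

	"github.com/lightningnetwork/lnd/lnwire"
	"github.com/lightningnetwork/lnd/watchtower/wtdb"
)

func main() {
	var id lnwire.ChannelID
	id[0] = 0x01

	// Serialize the 32-byte channel id using the new codec case.
	var b bytes.Buffer
	if err := wtdb.WriteElement(&b, id); err != nil {
		panic(err)
	}

	// Read it back; the pointer case fills the array via io.ReadFull.
	var got lnwire.ChannelID
	if err := wtdb.ReadElement(&b, &got); err != nil {
		panic(err)
	}
	fmt.Println(got == id) // true
}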
@@ -126,8 +126,7 @@ func (t *backupTask) inputs() map[wire.OutPoint]input.Input {
// SessionInfo's policy. If no error is returned, the task has been bound to the
// session and can be queued to upload to the tower. Otherwise, the bind failed
// and should be rescheduled with a different session.
func (t *backupTask) bindSession(session *wtdb.ClientSession) error {
func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error {
// First we'll begin by deriving a weight estimate for the justice
// transaction. The final weight can be different depending on whether
// the watchtower is taking a reward.

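Since bindSession now takes only the persisted session body, a task can be bound without access to the full ClientSession, which also carries the private key and tower handle. A hedged sketch of a caller inside the wtclient package, assuming task is a *backupTask and using illustrative policy values:

	session := &wtdb.ClientSession{
		ClientSessionBody: wtdb.ClientSessionBody{
			Policy: wtpolicy.Policy{
				BlobType:     blob.TypeDefault,
				MaxUpdates:   1024,
				SweepFeeRate: 3000,
			},
		},
	}

	// Bind against the body alone; key material is not required.
	if err := task.bindSession(&session.ClientSessionBody); err != nil {
		// The bind failed; reschedule with a different session.
	}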
@@ -69,7 +69,7 @@ type backupTaskTest struct {
expSweepAmt int64
expRewardAmt int64
expRewardScript []byte
session *wtdb.ClientSession
session *wtdb.ClientSessionBody
bindErr error
expSweepScript []byte
signer input.Signer
@@ -205,7 +205,7 @@ func genTaskTest(
expSweepAmt: expSweepAmt,
expRewardAmt: expRewardAmt,
expRewardScript: rewardScript,
session: &wtdb.ClientSession{
session: &wtdb.ClientSessionBody{
Policy: wtpolicy.Policy{
BlobType: blobType,
SweepFeeRate: sweepFeeRate,

@@ -150,8 +150,9 @@ type TowerClient struct {
sessionQueue *sessionQueue
prevTask *backupTask
sweepPkScriptMu sync.RWMutex
sweepPkScripts map[lnwire.ChannelID][]byte
backupMu sync.Mutex
summaries wtdb.ChannelSummaries
chanCommitHeights map[lnwire.ChannelID]uint64
statTicker *time.Ticker
stats clientStats
@@ -243,9 +244,13 @@ func New(config *Config) (*TowerClient, error) {
s.SessionPrivKey = sessionPriv
}
// Reconstruct the highest commit height processed for each channel
// under the client's current policy.
c.buildHighestCommitHeights()
// Finally, load the sweep pkscripts that have been generated for all
// previously registered channels.
c.sweepPkScripts, err = c.cfg.DB.FetchChanPkScripts()
c.summaries, err = c.cfg.DB.FetchChanSummaries()
if err != nil {
return nil, err
}
@@ -253,6 +258,44 @@ func New(config *Config) (*TowerClient, error) {
return c, nil
}
// buildHighestCommitHeights inspects the full set of candidate client sessions
// loaded from disk, and determines the highest known commit height for each
// channel. This allows the client to reject backups that it has already
// processed for its active policy.
func (c *TowerClient) buildHighestCommitHeights() {
chanCommitHeights := make(map[lnwire.ChannelID]uint64)
for _, s := range c.candidateSessions {
// We only want to consider accepted updates that have been
// accepted under an identical policy to the client's current
// policy.
if s.Policy != c.cfg.Policy {
continue
}
// Take the highest commit height found in the session's
// committed updates.
for _, committedUpdate := range s.CommittedUpdates {
bid := committedUpdate.BackupID
height, ok := chanCommitHeights[bid.ChanID]
if !ok || bid.CommitHeight > height {
chanCommitHeights[bid.ChanID] = bid.CommitHeight
}
}
// Take the highest commit height found in the session's acked
// updates.
for _, bid := range s.AckedUpdates {
height, ok := chanCommitHeights[bid.ChanID]
if !ok || bid.CommitHeight > height {
chanCommitHeights[bid.ChanID] = bid.CommitHeight
}
}
}
c.chanCommitHeights = chanCommitHeights
}
// Start initializes the watchtower client by loading or negotiating an active
// session and then begins processing backup tasks from the request pipeline.
func (c *TowerClient) Start() error {
@@ -388,12 +431,12 @@ func (c *TowerClient) ForceQuit() {
// within the client. This should be called during link startup to ensure that
// the client is able to support the link during operation.
func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error {
c.sweepPkScriptMu.Lock()
defer c.sweepPkScriptMu.Unlock()
c.backupMu.Lock()
defer c.backupMu.Unlock()
// If a pkscript for this channel already exists, the channel has been
// previously registered.
if _, ok := c.sweepPkScripts[chanID]; ok {
if _, ok := c.summaries[chanID]; ok {
return nil
}
@@ -406,14 +449,16 @@ func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error {
// Persist the sweep pkscript so that restarts will not introduce
// address inflation when the channel is reregistered.
err = c.cfg.DB.AddChanPkScript(chanID, pkScript)
err = c.cfg.DB.RegisterChannel(chanID, pkScript)
if err != nil {
return err
}
// Finally, cache the pkscript in our in-memory cache to avoid db
// lookups for the remainder of the daemon's execution.
c.sweepPkScripts[chanID] = pkScript
c.summaries[chanID] = wtdb.ClientChanSummary{
SweepPkScript: pkScript,
}
return nil
}
@@ -429,14 +474,29 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID,
breachInfo *lnwallet.BreachRetribution) error {
// Retrieve the cached sweep pkscript used for this channel.
c.sweepPkScriptMu.RLock()
sweepPkScript, ok := c.sweepPkScripts[*chanID]
c.sweepPkScriptMu.RUnlock()
c.backupMu.Lock()
summary, ok := c.summaries[*chanID]
if !ok {
c.backupMu.Unlock()
return ErrUnregisteredChannel
}
task := newBackupTask(chanID, breachInfo, sweepPkScript)
// Ignore backups that have already been presented to the client.
height, ok := c.chanCommitHeights[*chanID]
if ok && breachInfo.RevokedStateNum <= height {
c.backupMu.Unlock()
log.Debugf("Ignoring duplicate backup for chanid=%v at height=%d",
chanID, breachInfo.RevokedStateNum)
return nil
}
// This backup has a higher commit height than any known backup for this
// channel. We'll update our tip so that we won't accept it again if the
// link flaps.
c.chanCommitHeights[*chanID] = breachInfo.RevokedStateNum
c.backupMu.Unlock()
task := newBackupTask(chanID, breachInfo, summary.SweepPkScript)
return c.pipeline.QueueBackupTask(task)
}

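Taken together, the height map built by buildHighestCommitHeights and the guard in BackupState implement one rule: a backup for a channel is accepted only if its commit height exceeds the highest height already processed for that channel under the current policy. A standalone sketch of that rule (the helper name is illustrative, not part of the diff):

// shouldAccept reports whether a revoked state at commitHeight should
// be queued for backup, advancing the in-memory tip when it is.
func shouldAccept(heights map[lnwire.ChannelID]uint64,
	chanID lnwire.ChannelID, commitHeight uint64) bool {

	if height, ok := heights[chanID]; ok && commitHeight <= height {
		// Duplicate or stale state, e.g. after a link flap.
		return false
	}
	heights[chanID] = commitHeight
	return true
}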
@@ -605,6 +605,8 @@ func (h *testHarness) backupStates(id, from, to uint64, expErr error) {
// backupState instructs the channel identified by id to send a backup for
// state i.
func (h *testHarness) backupState(id, i uint64, expErr error) {
h.t.Helper()
_, retribution := h.channel(id).getState(i)
chanID := chanIDFromInt(id)
@@ -1244,6 +1246,55 @@ var clientTests = []clientTest{
h.assertUpdatesForPolicy(hints, h.clientCfg.Policy)
},
},
{
// Asserts that the client will deduplicate backups presented by
// a channel both in memory and after a restart. The client
// should only accept backups with a commit height greater than
// any already processed for a given policy.
name: "dedup backups",
cfg: harnessCfg{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
BlobType: blob.TypeDefault,
MaxUpdates: 5,
SweepFeeRate: 1,
},
},
fn: func(h *testHarness) {
const (
numUpdates = 10
chanID = 0
)
// Generate the retributions that will be backed up.
hints := h.advanceChannelN(chanID, numUpdates)
// Queue the first half of the retributions twice; the
// second batch should be entirely deduped by the
// client's in-memory tracking.
h.backupStates(chanID, 0, numUpdates/2, nil)
h.backupStates(chanID, 0, numUpdates/2, nil)
// Wait for the first half of the updates to be
// populated in the server's database.
h.waitServerUpdates(hints[:len(hints)/2], 5*time.Second)
// Restart the client, so we can ensure the deduping is
// maintained across restarts.
h.client.Stop()
h.startClient()
defer h.client.ForceQuit()
// Try to back up the full range of retributions. Only
// the second half should actually be sent.
h.backupStates(chanID, 0, numUpdates, nil)
// Wait for all of the updates to be populated in the
// server's database.
h.waitServerUpdates(hints, 5*time.Second)
},
},
}
// TestClient executes the client test suite, asserting the ability to backup

@@ -21,7 +21,7 @@ type DB interface {
CreateTower(*lnwire.NetAddress) (*wtdb.Tower, error)
// LoadTower retrieves a tower by its tower ID.
LoadTower(uint64) (*wtdb.Tower, error)
LoadTower(wtdb.TowerID) (*wtdb.Tower, error)
// NextSessionKeyIndex reserves a new session key derivation index for a
// particular tower id. The index is reserved for that tower until
@@ -29,7 +29,7 @@ type DB interface {
// point a new index for that tower can be reserved. Multiple calls to
// this method before CreateClientSession is invoked should return the
// same index.
NextSessionKeyIndex(uint64) (uint32, error)
NextSessionKeyIndex(wtdb.TowerID) (uint32, error)
// CreateClientSession saves a newly negotiated client session to the
// client's database. This enables the session to be used across
@@ -41,14 +41,17 @@ type DB interface {
// still be able to accept state updates.
ListClientSessions() (map[wtdb.SessionID]*wtdb.ClientSession, error)
// FetchChanPkScripts returns a map of all sweep pkscripts for
// registered channels. This is used on startup to cache the sweep
// pkscripts of registered channels in memory.
FetchChanPkScripts() (map[lnwire.ChannelID][]byte, error)
// FetchChanSummaries loads a mapping from all registered channels to
// their channel summaries.
FetchChanSummaries() (wtdb.ChannelSummaries, error)
// AddChanPkScript inserts a newly generated sweep pkscript for the
// given channel.
AddChanPkScript(lnwire.ChannelID, []byte) error
// RegisterChannel registers a channel for use within the client
// database. For now, all that is stored in the channel summary is the
// sweep pkscript that we'd like any tower sweeps to pay into. In the
// future, this will be extended to contain more info to allow the
// client efficiently request historical states to be backed up under
// the client's active policy.
RegisterChannel(lnwire.ChannelID, []byte) error
// MarkBackupIneligible records that the state identified by the
// (channel id, commit height) tuple was ineligible for being backed up
@@ -61,7 +64,7 @@ type DB interface {
// hasn't been ACK'd by the tower. The sequence number of the update
// should be exactly one greater than the existing entry, and less than
// or equal to the session's MaxUpdates.
CommitUpdate(id *wtdb.SessionID, seqNum uint16,
CommitUpdate(id *wtdb.SessionID,
update *wtdb.CommittedUpdate) (uint16, error)
// AckUpdate records an acknowledgment from the watchtower that the

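With the revised interface, the sequence number travels on the CommittedUpdate itself instead of as a separate argument to CommitUpdate. An illustrative caller under that assumption (the helper name and plumbing below are not part of the diff):

// commitNext persists the next update for a session, returning the
// tower's last applied seqnum so it can be echoed in the next
// StateUpdate message.
func commitNext(db DB, id *wtdb.SessionID, seqNum uint16,
	body wtdb.CommittedUpdateBody) (uint16, error) {

	update := &wtdb.CommittedUpdate{
		SeqNum:              seqNum,
		CommittedUpdateBody: body,
	}
	return db.CommitUpdate(id, update)
}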
@@ -417,14 +417,15 @@ func (n *sessionNegotiator) tryAddress(privKey *btcec.PrivateKey,
privKey.PubKey(),
)
clientSession := &wtdb.ClientSession{
TowerID: tower.ID,
ClientSessionBody: wtdb.ClientSessionBody{
TowerID: tower.ID,
KeyIndex: keyIndex,
Policy: n.cfg.Policy,
RewardPkScript: rewardPkScript,
},
Tower: tower,
KeyIndex: keyIndex,
SessionPrivKey: privKey,
ID: sessionID,
Policy: n.cfg.Policy,
SeqNum: 0,
RewardPkScript: rewardPkScript,
}
err = n.cfg.DB.CreateClientSession(clientSession)

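Because ClientSession now embeds ClientSessionBody, the persisted fields stay reachable under their old names via Go's field promotion, which is why most call sites need no changes. A small sketch:

	s := &wtdb.ClientSession{
		ClientSessionBody: wtdb.ClientSessionBody{
			TowerID:  wtdb.TowerID(1),
			KeyIndex: 42,
		},
	}

	// Promoted selectors resolve to the embedded body's fields.
	fmt.Println(s.TowerID == s.ClientSessionBody.TowerID)   // true
	fmt.Println(s.KeyIndex == s.ClientSessionBody.KeyIndex) // true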
@@ -3,7 +3,6 @@ package wtclient
import (
"container/list"
"fmt"
"sort"
"sync"
"time"
@@ -133,7 +132,11 @@ func newSessionQueue(cfg *sessionQueueConfig) *sessionQueue {
}
sq.queueCond = sync.NewCond(&sq.queueMtx)
sq.restoreCommittedUpdates()
// The database should return them in sorted order, and the session
// queue's sequence number will be equal to that of the last committed
// update.
for _, update := range sq.cfg.ClientSession.CommittedUpdates {
sq.commitQueue.PushBack(update)
}
return sq
}
@@ -212,7 +215,7 @@ func (q *sessionQueue) AcceptTask(task *backupTask) (reserveStatus, bool) {
//
// TODO(conner): queue backups and retry with different session params.
case reserveAvailable:
err := task.bindSession(q.cfg.ClientSession)
err := task.bindSession(&q.cfg.ClientSession.ClientSessionBody)
if err != nil {
q.queueCond.L.Unlock()
log.Debugf("SessionQueue %s rejected backup chanid=%s "+
@@ -237,45 +240,6 @@ func (q *sessionQueue) AcceptTask(task *backupTask) (reserveStatus, bool) {
return newStatus, true
}
// updateWithSeqNum stores a CommittedUpdate with its assigned sequence number.
// This allows committed updates to be sorted after a restart, and added to the
// commitQueue in the proper order for delivery.
type updateWithSeqNum struct {
seqNum uint16
update *wtdb.CommittedUpdate
}
// restoreCommittedUpdates processes any CommittedUpdates loaded on startup by
// sorting them in ascending order of sequence numbers and adding them to the
// commitQueue. These will be sent before any pending updates are processed.
func (q *sessionQueue) restoreCommittedUpdates() {
committedUpdates := q.cfg.ClientSession.CommittedUpdates
// Construct and unordered slice of all committed updates with their
// assigned sequence numbers.
sortedUpdates := make([]updateWithSeqNum, 0, len(committedUpdates))
for seqNum, update := range committedUpdates {
sortedUpdates = append(sortedUpdates, updateWithSeqNum{
seqNum: seqNum,
update: update,
})
}
// Sort the resulting slice by increasing sequence number.
sort.Slice(sortedUpdates, func(i, j int) bool {
return sortedUpdates[i].seqNum < sortedUpdates[j].seqNum
})
// Finally, add the sorted, committed updates to he commitQueue. These
// updates will be prioritized before any new tasks are assigned to the
// sessionQueue. The queue will begin uploading any tasks in the
// commitQueue as soon as it is started, e.g. during client
// initialization when detecting that this session has unacked updates.
for _, update := range sortedUpdates {
q.commitQueue.PushBack(update)
}
}
// sessionManager is the primary event loop for the sessionQueue, and is
// responsible for encrypting and sending accepted tasks to the tower.
func (q *sessionQueue) sessionManager() {
@@ -396,7 +360,7 @@ func (q *sessionQueue) drainBackups() {
func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) {
var (
seqNum uint16
update *wtdb.CommittedUpdate
update wtdb.CommittedUpdate
isLast bool
isPending bool
)
@@ -407,10 +371,9 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) {
// If the commit queue is non-empty, parse the next committed update.
case q.commitQueue.Len() > 0:
next := q.commitQueue.Front()
updateWithSeq := next.Value.(updateWithSeqNum)
seqNum = updateWithSeq.seqNum
update = updateWithSeq.update
update = next.Value.(wtdb.CommittedUpdate)
seqNum = update.SeqNum
// If this is the last item in the commit queue and no items
// exist in the pending queue, we will use the IsComplete flag
@@ -449,10 +412,13 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) {
}
// TODO(conner): special case other obscure errors
update = &wtdb.CommittedUpdate{
BackupID: task.id,
Hint: hint,
EncryptedBlob: encBlob,
update = wtdb.CommittedUpdate{
SeqNum: seqNum,
CommittedUpdateBody: wtdb.CommittedUpdateBody{
BackupID: task.id,
Hint: hint,
EncryptedBlob: encBlob,
},
}
log.Debugf("Committing state update for session=%s seqnum=%d",
@@ -470,7 +436,7 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) {
// we send the next time. This step ensures that we reliably send the
// same update for a given sequence number, preventing us from thinking
// we backed up a state when we instead backed up another.
lastApplied, err := q.cfg.DB.CommitUpdate(q.ID(), seqNum, update)
lastApplied, err := q.cfg.DB.CommitUpdate(q.ID(), &update)
if err != nil {
// TODO(conner): mark failed/reschedule
return nil, false, fmt.Errorf("unable to commit state update "+
@@ -478,7 +444,7 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) {
}
stateUpdate := &wtwire.StateUpdate{
SeqNum: seqNum,
SeqNum: update.SeqNum,
LastApplied: lastApplied,
Hint: update.Hint,
EncryptedBlob: update.EncryptedBlob,

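The in-memory sort above can be deleted because ordering is now a property of the storage layout: sequence numbers are written as fixed-width keys, and bbolt iterates bucket keys in ascending byte order, so getClientSessionCommits yields updates already sorted by seqnum. A sketch of the invariant, assuming wtdb's byteOrder is big-endian (fixed-width big-endian keys sort numerically under bytewise comparison):

	var key [2]byte
	binary.BigEndian.PutUint16(key[:], seqNum)
	// sessionCommits.Put(key[:], encodedUpdate)
	//
	// Iterating the bucket now visits seqnums 1, 2, 3, ... in order,
	// so the restored commitQueue needs no explicit sort.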
@@ -0,0 +1,32 @@
package wtdb
import (
"io"
"github.com/lightningnetwork/lnd/lnwire"
)
// ChannelSummaries maps a given channel id to its ClientChanSummary.
type ChannelSummaries map[lnwire.ChannelID]ClientChanSummary
// ClientChanSummary tracks channel-specific information. A new
// ClientChanSummary is inserted in the database the first time the client
// encounters a particular channel.
type ClientChanSummary struct {
// SweepPkScript is the pkscript to which all justice transactions will
// deposit recovered funds for this particular channel.
SweepPkScript []byte
// TODO(conner): later extend with info about initial commit height,
// ineligible states, etc.
}
// Encode writes the ClientChanSummary to the passed io.Writer.
func (s *ClientChanSummary) Encode(w io.Writer) error {
return WriteElement(w, s.SweepPkScript)
}
// Decode reads a ClientChanSummary from the passed io.Reader.
func (s *ClientChanSummary) Decode(r io.Reader) error {
return ReadElement(r, &s.SweepPkScript)
}

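A quick round-trip sketch for the new summary type, assuming only SweepPkScript is serialized as shown above (illustrative helper, not part of the diff):

func roundTripSummary(pkScript []byte) (*wtdb.ClientChanSummary, error) {
	in := wtdb.ClientChanSummary{SweepPkScript: pkScript}

	var b bytes.Buffer
	if err := in.Encode(&b); err != nil {
		return nil, err
	}

	var out wtdb.ClientChanSummary
	if err := out.Decode(&b); err != nil {
		return nil, err
	}
	return &out, nil
}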
@@ -0,0 +1,908 @@
package wtdb
import (
"bytes"
"errors"
"fmt"
"math"
"net"
"github.com/coreos/bbolt"
"github.com/lightningnetwork/lnd/lnwire"
)
const (
// clientDBName is the filename of the client database.
clientDBName = "wtclient.db"
)
var (
// cSessionKeyIndexBkt is a top-level bucket storing:
// tower-id -> reserved-session-key-index (uint32).
cSessionKeyIndexBkt = []byte("client-session-key-index-bucket")
// cChanSummaryBkt is a top-level bucket storing:
// channel-id -> encoded ClientChanSummary.
cChanSummaryBkt = []byte("client-channel-summary-bucket")
// cSessionBkt is a top-level bucket storing:
// session-id => cSessionBody -> encoded ClientSessionBody
// => cSessionCommits => seqnum -> encoded CommittedUpdate
// => cSessionAcks => seqnum -> encoded BackupID
cSessionBkt = []byte("client-session-bucket")
// cSessionBody is a sub-bucket of cSessionBkt storing only the body of
// the ClientSession.
cSessionBody = []byte("client-session-body")
// cSessionCommits is a sub-bucket of cSessionBkt storing:
// seqnum -> encoded CommittedUpdate.
cSessionCommits = []byte("client-session-commits")
// cSessionAcks is a sub-bucket of cSessionBkt storing:
// seqnum -> encoded BackupID.
cSessionAcks = []byte("client-session-acks")
// cTowerBkt is a top-level bucket storing:
// tower-id -> encoded Tower.
cTowerBkt = []byte("client-tower-bucket")
// cTowerIndexBkt is a top-level bucket storing:
// tower-pubkey -> tower-id.
cTowerIndexBkt = []byte("client-tower-index-bucket")
// ErrTowerNotFound signals that the target tower was not found in the
// database.
ErrTowerNotFound = errors.New("tower not found")
// ErrCorruptClientSession signals that the client session's on-disk
// structure deviates from what is expected.
ErrCorruptClientSession = errors.New("client session corrupted")
// ErrClientSessionAlreadyExists signals an attempt to reinsert a client
// session that has already been created.
ErrClientSessionAlreadyExists = errors.New(
"client session already exists",
)
// ErrChannelAlreadyRegistered signals a duplicate attempt to register a
// channel with the client database.
ErrChannelAlreadyRegistered = errors.New("channel already registered")
// ErrChannelNotRegistered signals a channel has not yet been registered
// in the client database.
ErrChannelNotRegistered = errors.New("channel not registered")
// ErrClientSessionNotFound signals that the requested client session
// was not found in the database.
ErrClientSessionNotFound = errors.New("client session not found")
// ErrUpdateAlreadyCommitted signals that the chosen sequence number has
// already been committed to an update with a different breach hint.
ErrUpdateAlreadyCommitted = errors.New("update already committed")
// ErrCommitUnorderedUpdate signals the client tried to commit a
// sequence number other than the next unallocated sequence number.
ErrCommitUnorderedUpdate = errors.New("update seqnum not monotonic")
// ErrCommittedUpdateNotFound signals that the tower tried to ACK a
// sequence number that has not yet been allocated by the client.
ErrCommittedUpdateNotFound = errors.New("committed update not found")
// ErrUnallocatedLastApplied signals that the tower tried to provide a
// LastApplied value greater than any allocated sequence number.
ErrUnallocatedLastApplied = errors.New("tower echoed last applied " +
"greater than allocated seqnum")
// ErrNoReservedKeyIndex signals that a client session could not be
// created because no session key index was reserved.
ErrNoReservedKeyIndex = errors.New("key index not reserved")
// ErrIncorrectKeyIndex signals that the client session could not be
// created because session key index differs from the reserved key
// index.
ErrIncorrectKeyIndex = errors.New("incorrect key index")
)
// ClientDB is a single database providing a persistent storage engine for the
// wtclient.
type ClientDB struct {
db *bbolt.DB
dbPath string
}
// OpenClientDB opens the client database given the path to the database's
// directory. If no such database exists, this method will initialize a fresh
// one using the latest version number and bucket structure. If a database
// exists but has a lower version number than the current version, any necessary
// migrations will be applied before returning. Any attempt to open a database
// with a version number higher than the latest version will fail to prevent
// accidental reversion.
func OpenClientDB(dbPath string) (*ClientDB, error) {
bdb, firstInit, err := createDBIfNotExist(dbPath, clientDBName)
if err != nil {
return nil, err
}
clientDB := &ClientDB{
db: bdb,
dbPath: dbPath,
}
err = initOrSyncVersions(clientDB, firstInit, clientDBVersions)
if err != nil {
bdb.Close()
return nil, err
}
// Now that the database version is fully consistent with our latest known
// version, ensure that all top-level buckets known to this version are
// initialized. This allows us to assume their presence throughout all
// operations. If a known top-level bucket is expected to exist but is
// missing, this will trigger an ErrUninitializedDB error.
err = clientDB.db.Update(initClientDBBuckets)
if err != nil {
bdb.Close()
return nil, err
}
return clientDB, nil
}
// initClientDBBuckets creates all top-level buckets required to handle database
// operations required by the latest version.
func initClientDBBuckets(tx *bbolt.Tx) error {
buckets := [][]byte{
cSessionKeyIndexBkt,
cChanSummaryBkt,
cSessionBkt,
cTowerBkt,
cTowerIndexBkt,
}
for _, bucket := range buckets {
_, err := tx.CreateBucketIfNotExists(bucket)
if err != nil {
return err
}
}
return nil
}
// bdb returns the backing bbolt.DB instance.
//
// NOTE: Part of the versionedDB interface.
func (c *ClientDB) bdb() *bbolt.DB {
return c.db
}
// Version returns the database's current version number.
//
// NOTE: Part of the versionedDB interface.
func (c *ClientDB) Version() (uint32, error) {
var version uint32
err := c.db.View(func(tx *bbolt.Tx) error {
var err error
version, err = getDBVersion(tx)
return err
})
if err != nil {
return 0, err
}
return version, nil
}
// Close closes the underlying database.
func (c *ClientDB) Close() error {
return c.db.Close()
}
// CreateTower initializes a database entry with the given lightning address. If
// the tower exists, the address is appended to the list of all addresses
// previously used to reach that tower.
func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
var towerPubKey [33]byte
copy(towerPubKey[:], lnAddr.IdentityKey.SerializeCompressed())
var tower *Tower
err := c.db.Update(func(tx *bbolt.Tx) error {
towerIndex := tx.Bucket(cTowerIndexBkt)
if towerIndex == nil {
return ErrUninitializedDB
}
towers := tx.Bucket(cTowerBkt)
if towers == nil {
return ErrUninitializedDB
}
// Check if the tower index already knows of this pubkey.
towerIDBytes := towerIndex.Get(towerPubKey[:])
if len(towerIDBytes) == 8 {
// The tower already exists, deserialize the existing
// record.
var err error
tower, err = getTower(towers, towerIDBytes)
if err != nil {
return err
}
// Add the new address to the existing tower. If the
// address is a duplicate, this will result in no
// change.
tower.AddAddress(lnAddr.Address)
} else {
// No such tower exists, create a new tower id for our
// new tower. The error is unhandled since NextSequence
// never fails in an Update.
towerID, _ := towerIndex.NextSequence()
tower = &Tower{
ID: TowerID(towerID),
IdentityKey: lnAddr.IdentityKey,
Addresses: []net.Addr{lnAddr.Address},
}
towerIDBytes = tower.ID.Bytes()
// Since this tower is new, record the mapping from
// tower pubkey to tower id in the tower index.
err := towerIndex.Put(towerPubKey[:], towerIDBytes)
if err != nil {
return err
}
}
// Store the new or updated tower under its tower id.
return putTower(towers, tower)
})
if err != nil {
return nil, err
}
return tower, nil
}
// LoadTower retrieves a tower by its tower ID.
func (c *ClientDB) LoadTower(towerID TowerID) (*Tower, error) {
var tower *Tower
err := c.db.View(func(tx *bbolt.Tx) error {
towers := tx.Bucket(cTowerBkt)
if towers == nil {
return ErrUninitializedDB
}
var err error
tower, err = getTower(towers, towerID.Bytes())
return err
})
if err != nil {
return nil, err
}
return tower, nil
}
// NextSessionKeyIndex reserves a new session key derivation index for a
// particular tower id. The index is reserved for that tower until
// CreateClientSession is invoked for that tower and index, at which point a new
// index for that tower can be reserved. Multiple calls to this method before
// CreateClientSession is invoked should return the same index.
func (c *ClientDB) NextSessionKeyIndex(towerID TowerID) (uint32, error) {
var index uint32
err := c.db.Update(func(tx *bbolt.Tx) error {
keyIndex := tx.Bucket(cSessionKeyIndexBkt)
if keyIndex == nil {
return ErrUninitializedDB
}
// Check the session key index to see if a key has already been
// reserved for this tower. If so, we'll deserialize and return
// the index directly.
towerIDBytes := towerID.Bytes()
indexBytes := keyIndex.Get(towerIDBytes)
if len(indexBytes) == 4 {
index = byteOrder.Uint32(indexBytes)
return nil
}
// Otherwise, generate a new session key index since the node
// doesn't already have a reserved index. The error is ignored
// since NextSequence can't fail inside Update.
index64, _ := keyIndex.NextSequence()
// As a sanity check, assert that the index is still in the
// valid range of unhardened pubkeys. In the future, we should
// move to only using hardened keys, and this will prevent any
// overlap from occurring until then. This also prevents us from
// overflowing uint32s.
if index64 > math.MaxInt32 {
return fmt.Errorf("exhausted session key indexes")
}
index = uint32(index64)
var indexBuf [4]byte
byteOrder.PutUint32(indexBuf[:], index)
// Record the reserved session key index under this tower's id.
return keyIndex.Put(towerIDBytes, indexBuf[:])
})
if err != nil {
return 0, err
}
return index, nil
}
// CreateClientSession records a newly negotiated client session in the set of
// active sessions. The session can be identified by its SessionID.
func (c *ClientDB) CreateClientSession(session *ClientSession) error {
return c.db.Update(func(tx *bbolt.Tx) error {
keyIndexes := tx.Bucket(cSessionKeyIndexBkt)
if keyIndexes == nil {
return ErrUninitializedDB
}
sessions := tx.Bucket(cSessionBkt)
if sessions == nil {
return ErrUninitializedDB
}
// Check that a client session with this session id doesn't
// already exist.
existingSessionBytes := sessions.Bucket(session.ID[:])
if existingSessionBytes != nil {
return ErrClientSessionAlreadyExists
}
// Check that this tower has a reserved key index.
towerIDBytes := session.TowerID.Bytes()
keyIndexBytes := keyIndexes.Get(towerIDBytes)
if len(keyIndexBytes) != 4 {
return ErrNoReservedKeyIndex
}
// Assert that the key index of the inserted session matches the
// reserved session key index.
index := byteOrder.Uint32(keyIndexBytes)
if index != session.KeyIndex {
return ErrIncorrectKeyIndex
}
// Remove the key index reservation.
err := keyIndexes.Delete(towerIDBytes)
if err != nil {
return err
}
// Finally, write the client session's body in the sessions
// bucket.
return putClientSessionBody(sessions, session)
})
}
// ListClientSessions returns the set of all client sessions known to the db.
func (c *ClientDB) ListClientSessions() (map[SessionID]*ClientSession, error) {
clientSessions := make(map[SessionID]*ClientSession)
err := c.db.View(func(tx *bbolt.Tx) error {
sessions := tx.Bucket(cSessionBkt)
if sessions == nil {
return ErrUninitializedDB
}
return sessions.ForEach(func(k, _ []byte) error {
// We'll load the full client session since the client
// will need the CommittedUpdates and AckedUpdates on
// startup to resume committed updates and compute the
// highest known commit height for each channel.
session, err := getClientSession(sessions, k)
if err != nil {
return err
}
clientSessions[session.ID] = session
return nil
})
})
if err != nil {
return nil, err
}
return clientSessions, nil
}
// FetchChanSummaries loads a mapping from all registered channels to their
// channel summaries.
func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) {
summaries := make(map[lnwire.ChannelID]ClientChanSummary)
err := c.db.View(func(tx *bbolt.Tx) error {
chanSummaries := tx.Bucket(cChanSummaryBkt)
if chanSummaries == nil {
return ErrUninitializedDB
}
return chanSummaries.ForEach(func(k, v []byte) error {
var chanID lnwire.ChannelID
copy(chanID[:], k)
var summary ClientChanSummary
err := summary.Decode(bytes.NewReader(v))
if err != nil {
return err
}
summaries[chanID] = summary
return nil
})
})
if err != nil {
return nil, err
}
return summaries, nil
}
// RegisterChannel registers a channel for use within the client database. For
// now, all that is stored in the channel summary is the sweep pkscript that
// we'd like any tower sweeps to pay into. In the future, this will be extended
// to contain more info to allow the client to efficiently request historical
// states to be backed up under the client's active policy.
func (c *ClientDB) RegisterChannel(chanID lnwire.ChannelID,
sweepPkScript []byte) error {
return c.db.Update(func(tx *bbolt.Tx) error {
chanSummaries := tx.Bucket(cChanSummaryBkt)
if chanSummaries == nil {
return ErrUninitializedDB
}
_, err := getChanSummary(chanSummaries, chanID)
switch {
// Summary already exists.
case err == nil:
return ErrChannelAlreadyRegistered
// Channel is not registered, proceed with registration.
case err == ErrChannelNotRegistered:
// Unexpected error.
case err != nil:
return err
}
summary := ClientChanSummary{
SweepPkScript: sweepPkScript,
}
return putChanSummary(chanSummaries, chanID, &summary)
})
}
// MarkBackupIneligible records that the state identified by the (channel id,
// commit height) tuple was ineligible for being backed up under the current
// policy. This state can be retried later under a different policy.
func (c *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID,
commitHeight uint64) error {
return nil
}
// CommitUpdate persists the CommittedUpdate provided in the slot for (session,
// seqNum). This allows the client to retransmit this update on startup.
func (c *ClientDB) CommitUpdate(id *SessionID,
update *CommittedUpdate) (uint16, error) {
var lastApplied uint16
err := c.db.Update(func(tx *bbolt.Tx) error {
sessions := tx.Bucket(cSessionBkt)
if sessions == nil {
return ErrUninitializedDB
}
// We'll only load the ClientSession body for performance, since
// we primarily need to inspect its SeqNum and TowerLastApplied
// fields. The CommittedUpdates will be modified on disk
// directly.
session, err := getClientSessionBody(sessions, id[:])
if err != nil {
return err
}
// Can't fail because getClientSessionBody succeeded above.
sessionBkt := sessions.Bucket(id[:])
// Ensure the session commits sub-bucket is initialized.
sessionCommits, err := sessionBkt.CreateBucketIfNotExists(
cSessionCommits,
)
if err != nil {
return err
}
var seqNumBuf [2]byte
byteOrder.PutUint16(seqNumBuf[:], update.SeqNum)
// Check to see if a committed update already exists for this
// sequence number.
committedUpdateBytes := sessionCommits.Get(seqNumBuf[:])
if committedUpdateBytes != nil {
var dbUpdate CommittedUpdate
err := dbUpdate.Decode(
bytes.NewReader(committedUpdateBytes),
)
if err != nil {
return err
}
// If an existing committed update has a different hint,
// we'll reject this newer update.
if dbUpdate.Hint != update.Hint {
return ErrUpdateAlreadyCommitted
}
// Otherwise, capture the last applied value and
// succeed.
lastApplied = session.TowerLastApplied
return nil
}
// There's no committed update for this sequence number, ensure
// that we are committing the next unallocated one.
if update.SeqNum != session.SeqNum+1 {
return ErrCommitUnorderedUpdate
}
// Increment the session's sequence number and store the updated
// client session.
//
// TODO(conner): split out seqnum and last applied into own bucket to
// eliminate serialization of full struct during CommitUpdate?
// Can also read/write directly to bytes [:2] without migration.
session.SeqNum++
err = putClientSessionBody(sessions, session)
if err != nil {
return err
}
// Encode and store the committed update in the sessionCommits
// sub-bucket under the requested sequence number.
var b bytes.Buffer
err = update.Encode(&b)
if err != nil {
return err
}
err = sessionCommits.Put(seqNumBuf[:], b.Bytes())
if err != nil {
return err
}
// Finally, capture the session's last applied value so it can
// be sent in the next state update to the tower.
lastApplied = session.TowerLastApplied
return nil
})
if err != nil {
return 0, err
}
return lastApplied, nil
}
// AckUpdate persists an acknowledgment for a given (session, seqnum) pair. This
// removes the update from the set of committed updates, and validates the
// lastApplied value returned from the tower.
func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
lastApplied uint16) error {
return c.db.Update(func(tx *bbolt.Tx) error {
sessions := tx.Bucket(cSessionBkt)
if sessions == nil {
return ErrUninitializedDB
}
// We'll only load the ClientSession body for performance, since
// we primarily need to inspect its SeqNum and TowerLastApplied
// fields. The CommittedUpdates and AckedUpdates will be
// modified on disk directly.
session, err := getClientSessionBody(sessions, id[:])
if err != nil {
return err
}
// If the tower has acked a sequence number beyond our highest
// sequence number, fail.
if lastApplied > session.SeqNum {
return ErrUnallocatedLastApplied
}
// If the tower acked with a lower sequence number than it gave
// us prior, fail.
if lastApplied < session.TowerLastApplied {
return ErrLastAppliedReversion
}
// TODO(conner): split out seqnum and last applied into own bucket to
// eliminate serialization of full struct during AckUpdate? Can
// also read/write directly to bytes [2:4] without migration.
session.TowerLastApplied = lastApplied
// Write the client session with the updated last applied value.
err = putClientSessionBody(sessions, session)
if err != nil {
return err
}
// Can't fail because getClientSessionBody succeeded.
sessionBkt := sessions.Bucket(id[:])
// If the commits sub-bucket doesn't exist, there can't possibly
// be a corresponding committed update to remove.
sessionCommits := sessionBkt.Bucket(cSessionCommits)
if sessionCommits == nil {
return ErrCommittedUpdateNotFound
}
var seqNumBuf [2]byte
byteOrder.PutUint16(seqNumBuf[:], seqNum)
// Assert that a committed update exists for this sequence
// number.
committedUpdateBytes := sessionCommits.Get(seqNumBuf[:])
if committedUpdateBytes == nil {
return ErrCommittedUpdateNotFound
}
var committedUpdate CommittedUpdate
err = committedUpdate.Decode(
bytes.NewReader(committedUpdateBytes),
)
if err != nil {
return err
}
// Remove the corresponding committed update.
err = sessionCommits.Delete(seqNumBuf[:])
if err != nil {
return err
}
// Ensure that the session acks sub-bucket is initialized so we
// can insert an entry.
sessionAcks, err := sessionBkt.CreateBucketIfNotExists(
cSessionAcks,
)
if err != nil {
return err
}
// The session acks only need to track the backup id of the
// update, so we can discard the blob and hint.
var b bytes.Buffer
err = committedUpdate.BackupID.Encode(&b)
if err != nil {
return err
}
// Finally, insert the ack into the sessionAcks sub-bucket.
return sessionAcks.Put(seqNumBuf[:], b.Bytes())
})
}
// getClientSessionBody loads the body of a ClientSession from the sessions
// bucket corresponding to the serialized session id. This does not deserialize
// the CommittedUpdates or AckedUpdates associated with the session. If the caller
// requires this info, use getClientSession.
func getClientSessionBody(sessions *bbolt.Bucket,
idBytes []byte) (*ClientSession, error) {
sessionBkt := sessions.Bucket(idBytes)
if sessionBkt == nil {
return nil, ErrClientSessionNotFound
}
// Should never have a sessionBkt without also having its body.
sessionBody := sessionBkt.Get(cSessionBody)
if sessionBody == nil {
return nil, ErrCorruptClientSession
}
var session ClientSession
copy(session.ID[:], idBytes)
err := session.Decode(bytes.NewReader(sessionBody))
if err != nil {
return nil, err
}
return &session, nil
}
// getClientSession loads the full ClientSession associated with the serialized
// session id. This method populates the CommittedUpdates and AckedUpdates in
// addition to the ClientSession's body.
func getClientSession(sessions *bbolt.Bucket,
idBytes []byte) (*ClientSession, error) {
session, err := getClientSessionBody(sessions, idBytes)
if err != nil {
return nil, err
}
// Fetch the committed updates for this session.
committedUpdates, err := getClientSessionCommits(sessions, idBytes)
if err != nil {
return nil, err
}
// Fetch the acked updates for this session.
ackedUpdates, err := getClientSessionAcks(sessions, idBytes)
if err != nil {
return nil, err
}
session.CommittedUpdates = committedUpdates
session.AckedUpdates = ackedUpdates
return session, nil
}
// getClientSessionCommits retrieves all committed updates for the session
// identified by the serialized session id.
func getClientSessionCommits(sessions *bbolt.Bucket,
idBytes []byte) ([]CommittedUpdate, error) {
// Can't fail because client session body has already been read.
sessionBkt := sessions.Bucket(idBytes)
// Initialize committedUpdates so that we can return an initialized
// slice if no committed updates exist.
committedUpdates := make([]CommittedUpdate, 0)
sessionCommits := sessionBkt.Bucket(cSessionCommits)
if sessionCommits == nil {
return committedUpdates, nil
}
err := sessionCommits.ForEach(func(k, v []byte) error {
var committedUpdate CommittedUpdate
err := committedUpdate.Decode(bytes.NewReader(v))
if err != nil {
return err
}
committedUpdate.SeqNum = byteOrder.Uint16(k)
committedUpdates = append(committedUpdates, committedUpdate)
return nil
})
if err != nil {
return nil, err
}
return committedUpdates, nil
}
// getClientSessionAcks retrieves all acked updates for the session identified
// by the serialized session id.
func getClientSessionAcks(sessions *bbolt.Bucket,
idBytes []byte) (map[uint16]BackupID, error) {
// Can't fail because client session body has already been read.
sessionBkt := sessions.Bucket(idBytes)
// Initialize ackedUpdates so that we can return an initialized map if
// no acked updates exist.
ackedUpdates := make(map[uint16]BackupID)
sessionAcks := sessionBkt.Bucket(cSessionAcks)
if sessionAcks == nil {
return ackedUpdates, nil
}
err := sessionAcks.ForEach(func(k, v []byte) error {
seqNum := byteOrder.Uint16(k)
var backupID BackupID
err := backupID.Decode(bytes.NewReader(v))
if err != nil {
return err
}
ackedUpdates[seqNum] = backupID
return nil
})
if err != nil {
return nil, err
}
return ackedUpdates, nil
}
// putClientSessionBody stores the body of the ClientSession (everything but the
// CommittedUpdates and AckedUpdates).
func putClientSessionBody(sessions *bbolt.Bucket,
session *ClientSession) error {
sessionBkt, err := sessions.CreateBucketIfNotExists(session.ID[:])
if err != nil {
return err
}
var b bytes.Buffer
err = session.Encode(&b)
if err != nil {
return err
}
return sessionBkt.Put(cSessionBody, b.Bytes())
}
// getChanSummary loads a ClientChanSummary for the passed chanID.
func getChanSummary(chanSummaries *bbolt.Bucket,
chanID lnwire.ChannelID) (*ClientChanSummary, error) {
chanSummaryBytes := chanSummaries.Get(chanID[:])
if chanSummaryBytes == nil {
return nil, ErrChannelNotRegistered
}
var summary ClientChanSummary
err := summary.Decode(bytes.NewReader(chanSummaryBytes))
if err != nil {
return nil, err
}
return &summary, nil
}
// putChanSummary stores a ClientChanSummary for the passed chanID.
func putChanSummary(chanSummaries *bbolt.Bucket, chanID lnwire.ChannelID,
summary *ClientChanSummary) error {
var b bytes.Buffer
err := summary.Encode(&b)
if err != nil {
return err
}
return chanSummaries.Put(chanID[:], b.Bytes())
}
// getTower loads a Tower identified by its serialized tower id.
func getTower(towers *bbolt.Bucket, id []byte) (*Tower, error) {
towerBytes := towers.Get(id)
if towerBytes == nil {
return nil, ErrTowerNotFound
}
var tower Tower
err := tower.Decode(bytes.NewReader(towerBytes))
if err != nil {
return nil, err
}
tower.ID = TowerIDFromBytes(id)
return &tower, nil
}
// putTower stores a Tower identified by its serialized tower id.
func putTower(towers *bbolt.Bucket, tower *Tower) error {
var b bytes.Buffer
err := tower.Encode(&b)
if err != nil {
return err
}
return towers.Put(tower.ID.Bytes(), b.Bytes())
}

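Taken together, the happy path through the new ClientDB looks roughly like the following. Error handling is elided, and dbPath, towerAddr, sessionID, chanID, sweepPkScript, and update are assumed to be in scope; this is a hedged sketch, not code from the diff:

	db, _ := wtdb.OpenClientDB(dbPath)
	defer db.Close()

	// 1. Record the tower and reserve a session key index for it.
	tower, _ := db.CreateTower(towerAddr)
	keyIndex, _ := db.NextSessionKeyIndex(tower.ID)

	// 2. Persist the negotiated session using the reserved index,
	// which consumes the reservation.
	session := &wtdb.ClientSession{
		ClientSessionBody: wtdb.ClientSessionBody{
			TowerID:  tower.ID,
			KeyIndex: keyIndex,
		},
		ID: sessionID,
	}
	_ = db.CreateClientSession(session)

	// 3. Register channels, then commit updates as tasks are sent and
	// ack them as the tower confirms.
	_ = db.RegisterChannel(chanID, sweepPkScript)
	lastApplied, _ := db.CommitUpdate(&session.ID, update)
	_ = lastApplied
	_ = db.AckUpdate(&session.ID, update.SeqNum, update.SeqNum)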
@@ -0,0 +1,688 @@
package wtdb_test
import (
"bytes"
crand "crypto/rand"
"io"
"io/ioutil"
"net"
"os"
"reflect"
"testing"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtclient"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtmock"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
)
// clientDBInit is a closure used to initialize a wtclient.DB instance and its
// cleanup function.
type clientDBInit func(t *testing.T) (wtclient.DB, func())
type clientDBHarness struct {
t *testing.T
db wtclient.DB
}
func newClientDBHarness(t *testing.T, init clientDBInit) (*clientDBHarness, func()) {
db, cleanup := init(t)
h := &clientDBHarness{
t: t,
db: db,
}
return h, cleanup
}
func (h *clientDBHarness) insertSession(session *wtdb.ClientSession, expErr error) {
h.t.Helper()
err := h.db.CreateClientSession(session)
if err != expErr {
h.t.Fatalf("expected create client session error: %v, got: %v",
expErr, err)
}
}
func (h *clientDBHarness) listSessions() map[wtdb.SessionID]*wtdb.ClientSession {
h.t.Helper()
sessions, err := h.db.ListClientSessions()
if err != nil {
h.t.Fatalf("unable to list client sessions: %v", err)
}
return sessions
}
func (h *clientDBHarness) nextKeyIndex(id wtdb.TowerID, expErr error) uint32 {
h.t.Helper()
index, err := h.db.NextSessionKeyIndex(id)
if err != expErr {
h.t.Fatalf("expected next session key index error: %v, got: %v",
expErr, err)
}
if index == 0 {
h.t.Fatalf("next key index should never be 0")
}
return index
}
func (h *clientDBHarness) createTower(lnAddr *lnwire.NetAddress,
expErr error) *wtdb.Tower {
h.t.Helper()
tower, err := h.db.CreateTower(lnAddr)
if err != expErr {
h.t.Fatalf("expected create tower error: %v, got: %v", expErr, err)
}
if tower.ID == 0 {
h.t.Fatalf("tower id should never be 0")
}
return tower
}
func (h *clientDBHarness) loadTower(id wtdb.TowerID, expErr error) *wtdb.Tower {
h.t.Helper()
tower, err := h.db.LoadTower(id)
if err != expErr {
h.t.Fatalf("expected load tower error: %v, got: %v", expErr, err)
}
return tower
}
func (h *clientDBHarness) fetchChanSummaries() map[lnwire.ChannelID]wtdb.ClientChanSummary {
h.t.Helper()
summaries, err := h.db.FetchChanSummaries()
if err != nil {
h.t.Fatalf("unable to fetch chan summaries: %v", err)
}
return summaries
}
func (h *clientDBHarness) registerChan(chanID lnwire.ChannelID,
sweepPkScript []byte, expErr error) {
h.t.Helper()
err := h.db.RegisterChannel(chanID, sweepPkScript)
if err != expErr {
h.t.Fatalf("expected register channel error: %v, got: %v",
expErr, err)
}
}
func (h *clientDBHarness) commitUpdate(id *wtdb.SessionID,
update *wtdb.CommittedUpdate, expErr error) uint16 {
h.t.Helper()
lastApplied, err := h.db.CommitUpdate(id, update)
if err != expErr {
h.t.Fatalf("expected commit update error: %v, got: %v",
expErr, err)
}
return lastApplied
}
func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16,
lastApplied uint16, expErr error) {
h.t.Helper()
err := h.db.AckUpdate(id, seqNum, lastApplied)
if err != expErr {
h.t.Fatalf("expected commit update error: %v, got: %v",
expErr, err)
}
}
// testCreateClientSession asserts various conditions regarding the creation of
// a new ClientSession. The test asserts:
// - client sessions can only be created if a session key index is reserved.
// - client sessions cannot be created with an incorrect session key index.
// - inserting duplicate sessions fails.
func testCreateClientSession(h *clientDBHarness) {
// Create a test client session to insert.
session := &wtdb.ClientSession{
ClientSessionBody: wtdb.ClientSessionBody{
TowerID: wtdb.TowerID(3),
Policy: wtpolicy.Policy{
MaxUpdates: 100,
},
RewardPkScript: []byte{0x01, 0x02, 0x03},
},
ID: wtdb.SessionID([33]byte{0x01}),
}
// First, assert that this session is not already present in the
// database.
if _, ok := h.listSessions()[session.ID]; ok {
h.t.Fatalf("session for id %x should not exist yet", session.ID)
}
// Attempting to insert the client session without reserving a session
// key index should fail.
h.insertSession(session, wtdb.ErrNoReservedKeyIndex)
// Now, reserve a session key for this tower.
keyIndex := h.nextKeyIndex(session.TowerID, nil)
// The client session hasn't been updated with the reserved key index
// (since it's still zero). Inserting should fail due to the mismatch.
h.insertSession(session, wtdb.ErrIncorrectKeyIndex)
// Reserve another key for the same tower. Since no session has been
// successfully created, it should return the same index to maintain
// idempotency across restarts.
keyIndex2 := h.nextKeyIndex(session.TowerID, nil)
if keyIndex != keyIndex2 {
h.t.Fatalf("next key index should be idempotent: want: %v, "+
"got %v", keyIndex, keyIndex2)
}
// Now, set the client session's key index so that it is proper and
// insert it. This should succeed.
session.KeyIndex = keyIndex
h.insertSession(session, nil)
// Verify that the session now exists in the database.
if _, ok := h.listSessions()[session.ID]; !ok {
h.t.Fatalf("session for id %x should exist now", session.ID)
}
// Attempt to insert the session again, which should fail due to the
// session already existing.
h.insertSession(session, wtdb.ErrClientSessionAlreadyExists)
// Finally, assert that reserving another key index succeeds with a
// different key index, now that the first one has been finalized.
keyIndex3 := h.nextKeyIndex(session.TowerID, nil)
if keyIndex == keyIndex3 {
h.t.Fatalf("key index still reserved after creating session")
}
}
// testCreateTower asserts the behavior of creating new Tower objects within the
// database, and that the latest address is always prepended to the list of
// known addresses for the tower.
func testCreateTower(h *clientDBHarness) {
// Test that loading a tower with an arbitrary tower id fails.
h.loadTower(20, wtdb.ErrTowerNotFound)
pk, err := randPubKey()
if err != nil {
h.t.Fatalf("unable to generate pubkey: %v", err)
}
addr1 := &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911}
lnAddr := &lnwire.NetAddress{
IdentityKey: pk,
Address: addr1,
}
// Insert a random tower into the database.
tower := h.createTower(lnAddr, nil)
// Load the tower from the database and assert that it matches the tower
// we created.
tower2 := h.loadTower(tower.ID, nil)
if !reflect.DeepEqual(tower, tower2) {
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
tower, tower2)
}
// Insert the address again into the database. Since the address is the
// same, this should result in an unmodified tower record.
towerDupAddr := h.createTower(lnAddr, nil)
if len(towerDupAddr.Addresses) != 1 {
h.t.Fatalf("duplicate address should be deduped")
}
if !reflect.DeepEqual(tower, towerDupAddr) {
h.t.Fatalf("mismatch towers, want: %v, got: %v",
tower, towerDupAddr)
}
// Generate a new address for this tower.
addr2 := &net.TCPAddr{IP: []byte{0x02, 0x00, 0x00, 0x00}, Port: 9911}
lnAddr2 := &lnwire.NetAddress{
IdentityKey: pk,
Address: addr2,
}
// Insert the updated address, which should produce a tower with a new
// address.
towerNewAddr := h.createTower(lnAddr2, nil)
// Load the tower from the database, and assert that it matches the
// tower returned from creation.
towerNewAddr2 := h.loadTower(tower.ID, nil)
if !reflect.DeepEqual(towerNewAddr, towerNewAddr2) {
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
towerNewAddr, towerNewAddr2)
}
// Assert that there are now two addresses on the tower object.
if len(towerNewAddr.Addresses) != 2 {
h.t.Fatalf("new address should be added")
}
// Finally, assert that the new address was prepended since it is deemed
// fresher.
if !reflect.DeepEqual(tower.Addresses, towerNewAddr.Addresses[1:]) {
h.t.Fatalf("new address should be prepended")
}
}
// testChanSummaries tests the process of registering a channel and its
// associated sweep pkscript.
func testChanSummaries(h *clientDBHarness) {
// First, assert that this channel is not already registered.
var chanID lnwire.ChannelID
if _, ok := h.fetchChanSummaries()[chanID]; ok {
h.t.Fatalf("pkscript for channel %x should not exist yet",
chanID)
}
// Generate a random sweep pkscript and register it for this channel.
expPkScript := make([]byte, 22)
if _, err := io.ReadFull(crand.Reader, expPkScript); err != nil {
h.t.Fatalf("unable to generate pkscript: %v", err)
}
h.registerChan(chanID, expPkScript, nil)
// Assert that the channel exists and that its sweep pkscript matches
// the one we registered.
summary, ok := h.fetchChanSummaries()[chanID]
if !ok {
h.t.Fatalf("pkscript for channel %x should exist",
chanID)
} else if !bytes.Equal(expPkScript, summary.SweepPkScript) {
h.t.Fatalf("pkscript mismatch, want: %x, got: %x",
expPkScript, summary.SweepPkScript)
}
// Finally, assert that re-registering the same channel produces a
// failure.
h.registerChan(chanID, expPkScript, wtdb.ErrChannelAlreadyRegistered)
}
// testCommitUpdate tests the behavior of CommitUpdate, asserting that updates
// can only be committed in sequence, that commits with an identical breach
// hint are idempotent, and that out-of-order commits are rejected.
func testCommitUpdate(h *clientDBHarness) {
session := &wtdb.ClientSession{
ClientSessionBody: wtdb.ClientSessionBody{
TowerID: wtdb.TowerID(3),
Policy: wtpolicy.Policy{
MaxUpdates: 100,
},
RewardPkScript: []byte{0x01, 0x02, 0x03},
},
ID: wtdb.SessionID([33]byte{0x02}),
}
// Generate a random update and try to commit before inserting the
// session, which should fail.
update1 := randCommittedUpdate(h.t, 1)
h.commitUpdate(&session.ID, update1, wtdb.ErrClientSessionNotFound)
// Reserve a session key index and insert the session.
session.KeyIndex = h.nextKeyIndex(session.TowerID, nil)
h.insertSession(session, nil)
// Now, try to commit the update that failed initially which should
// succeed. The lastApplied value should be 0 since we have not received
// an ack from the tower.
lastApplied := h.commitUpdate(&session.ID, update1, nil)
if lastApplied != 0 {
h.t.Fatalf("last applied mismatch, want: 0, got: %v",
lastApplied)
}
// Assert that the committed update appears in the client session's
// CommittedUpdates map when loaded from disk and that there are no
// AckedUpdates.
dbSession := h.listSessions()[session.ID]
checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
*update1,
})
checkAckedUpdates(h.t, dbSession, nil)
// Try to commit the same update, which should succeed due to
// idempotency (which is preserved when the breach hint is identical to
// the on-disk update's hint). The lastApplied value should remain
// unchanged.
lastApplied2 := h.commitUpdate(&session.ID, update1, nil)
if lastApplied2 != lastApplied {
h.t.Fatalf("last applied should not have changed, got %v",
lastApplied2)
}
// Assert that the loaded ClientSession is the same as before.
dbSession = h.listSessions()[session.ID]
checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
*update1,
})
checkAckedUpdates(h.t, dbSession, nil)
// Generate another random update and try to commit it at the identical
// sequence number. Since the breach hint has changed, this should fail.
update2 := randCommittedUpdate(h.t, 1)
h.commitUpdate(&session.ID, update2, wtdb.ErrUpdateAlreadyCommitted)
// Next, insert the new update at the next unallocated sequence number
// which should succeed.
update2.SeqNum = 2
lastApplied3 := h.commitUpdate(&session.ID, update2, nil)
if lastApplied3 != lastApplied {
h.t.Fatalf("last applied should not have changed, got %v",
lastApplied3)
}
// Check that both updates now appear as committed on the ClientSession
// loaded from disk.
dbSession = h.listSessions()[session.ID]
checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
*update1,
*update2,
})
checkAckedUpdates(h.t, dbSession, nil)
// Finally, create one more random update and try to commit it at index
// 4, which should be rejected since 3 is the next slot the database
// expects.
update4 := randCommittedUpdate(h.t, 4)
h.commitUpdate(&session.ID, update4, wtdb.ErrCommitUnorderedUpdate)
// Assert that the ClientSession loaded from disk remains unchanged.
dbSession = h.listSessions()[session.ID]
checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
*update1,
*update2,
})
checkAckedUpdates(h.t, dbSession, nil)
}
// testAckUpdate asserts the behavior of AckUpdate.
func testAckUpdate(h *clientDBHarness) {
// Create a new session that the updates in this test will be tied to.
session := &wtdb.ClientSession{
ClientSessionBody: wtdb.ClientSessionBody{
TowerID: wtdb.TowerID(3),
Policy: wtpolicy.Policy{
MaxUpdates: 100,
},
RewardPkScript: []byte{0x01, 0x02, 0x03},
},
ID: wtdb.SessionID([33]byte{0x03}),
}
// Try to ack an update before inserting the client session, which
// should fail.
h.ackUpdate(&session.ID, 1, 0, wtdb.ErrClientSessionNotFound)
// Reserve a session key and insert the client session.
session.KeyIndex = h.nextKeyIndex(session.TowerID, nil)
h.insertSession(session, nil)
// Now, try to ack update 1. This should fail since update 1 was never
// committed.
h.ackUpdate(&session.ID, 1, 0, wtdb.ErrCommittedUpdateNotFound)
// Commit to a random update at seqnum 1.
update1 := randCommittedUpdate(h.t, 1)
lastApplied := h.commitUpdate(&session.ID, update1, nil)
if lastApplied != 0 {
h.t.Fatalf("last applied mismatch, want: 0, got: %v",
lastApplied)
}
// Acking seqnum 1 should succeed.
h.ackUpdate(&session.ID, 1, 1, nil)
// Acking seqnum 1 again should fail.
h.ackUpdate(&session.ID, 1, 1, wtdb.ErrCommittedUpdateNotFound)
// Acking a valid seqnum with a reverted last applied value should fail.
h.ackUpdate(&session.ID, 1, 0, wtdb.ErrLastAppliedReversion)
// Acking with a last applied greater than any allocated seqnum should
// fail.
h.ackUpdate(&session.ID, 4, 3, wtdb.ErrUnallocatedLastApplied)
// Assert that the ClientSession loaded from disk has one update in its
// AckedUpdates map, and that the committed update has been removed.
dbSession := h.listSessions()[session.ID]
checkCommittedUpdates(h.t, dbSession, nil)
checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{
1: update1.BackupID,
})
// Commit to another random update, and assert that the last applied
// value is 1, since this was what was provided in the last successful
// ack.
update2 := randCommittedUpdate(h.t, 2)
lastApplied = h.commitUpdate(&session.ID, update2, nil)
if lastApplied != 1 {
h.t.Fatalf("last applied mismatch, want: 1, got: %v",
lastApplied)
}
// Ack seqnum 2.
h.ackUpdate(&session.ID, 2, 2, nil)
// Assert that both updates exist as AckedUpdates when loaded from disk.
dbSession = h.listSessions()[session.ID]
checkCommittedUpdates(h.t, dbSession, nil)
checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{
1: update1.BackupID,
2: update2.BackupID,
})
// Acking again with a lower last applied should fail.
h.ackUpdate(&session.ID, 2, 1, wtdb.ErrLastAppliedReversion)
// Acking an unallocated seqnum should fail.
h.ackUpdate(&session.ID, 4, 2, wtdb.ErrCommittedUpdateNotFound)
// Acking with a last applied greater than any allocated seqnum should
// fail.
h.ackUpdate(&session.ID, 4, 3, wtdb.ErrUnallocatedLastApplied)
}
// checkCommittedUpdates asserts that the CommittedUpdates on session match the
// expUpdates provided.
func checkCommittedUpdates(t *testing.T, session *wtdb.ClientSession,
expUpdates []wtdb.CommittedUpdate) {
t.Helper()
// We promote nil expUpdates to an initialized slice since the database
	// should never return a nil slice. This promotion is done purely for
	// the convenience of the testing framework.
if expUpdates == nil {
expUpdates = make([]wtdb.CommittedUpdate, 0)
}
if !reflect.DeepEqual(session.CommittedUpdates, expUpdates) {
t.Fatalf("committed updates mismatch, want: %v, got: %v",
expUpdates, session.CommittedUpdates)
}
}
// checkAckedUpdates asserts that the AckedUpdates on a session match the
// expUpdates provided.
func checkAckedUpdates(t *testing.T, session *wtdb.ClientSession,
	expUpdates map[uint16]wtdb.BackupID) {
	t.Helper()
	// We promote nil expUpdates to an initialized map since the database
	// should never return a nil map. This promotion is done purely for
	// the convenience of the testing framework.
if expUpdates == nil {
expUpdates = make(map[uint16]wtdb.BackupID)
}
if !reflect.DeepEqual(session.AckedUpdates, expUpdates) {
t.Fatalf("acked updates mismatch, want: %v, got: %v",
expUpdates, session.AckedUpdates)
}
}
// TestClientDB asserts the behavior of a fresh client db, a reopened client db,
// and the mock implementation. This ensures that all databases function
// identically, especially in the negative paths.
func TestClientDB(t *testing.T) {
dbs := []struct {
name string
init clientDBInit
}{
{
name: "fresh clientdb",
init: func(t *testing.T) (wtclient.DB, func()) {
path, err := ioutil.TempDir("", "clientdb")
if err != nil {
t.Fatalf("unable to make temp dir: %v",
err)
}
db, err := wtdb.OpenClientDB(path)
if err != nil {
os.RemoveAll(path)
t.Fatalf("unable to open db: %v", err)
}
cleanup := func() {
db.Close()
os.RemoveAll(path)
}
return db, cleanup
},
},
{
name: "reopened clientdb",
init: func(t *testing.T) (wtclient.DB, func()) {
path, err := ioutil.TempDir("", "clientdb")
if err != nil {
t.Fatalf("unable to make temp dir: %v",
err)
}
db, err := wtdb.OpenClientDB(path)
if err != nil {
os.RemoveAll(path)
t.Fatalf("unable to open db: %v", err)
}
db.Close()
db, err = wtdb.OpenClientDB(path)
if err != nil {
os.RemoveAll(path)
t.Fatalf("unable to reopen db: %v", err)
}
cleanup := func() {
db.Close()
os.RemoveAll(path)
}
return db, cleanup
},
},
{
name: "mock",
init: func(t *testing.T) (wtclient.DB, func()) {
return wtmock.NewClientDB(), func() {}
},
},
}
tests := []struct {
name string
run func(*clientDBHarness)
}{
{
name: "create client session",
run: testCreateClientSession,
},
{
name: "create tower",
run: testCreateTower,
},
{
name: "chan summaries",
run: testChanSummaries,
},
{
name: "commit update",
run: testCommitUpdate,
},
{
name: "ack update",
run: testAckUpdate,
},
}
for _, database := range dbs {
db := database
t.Run(db.name, func(t *testing.T) {
t.Parallel()
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
h, cleanup := newClientDBHarness(
t, db.init,
)
defer cleanup()
test.run(h)
})
}
})
}
}
// randCommittedUpdate generates a random committed update.
func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate {
var chanID lnwire.ChannelID
if _, err := io.ReadFull(crand.Reader, chanID[:]); err != nil {
t.Fatalf("unable to generate chan id: %v", err)
}
var hint wtdb.BreachHint
if _, err := io.ReadFull(crand.Reader, hint[:]); err != nil {
t.Fatalf("unable to generate breach hint: %v", err)
}
encBlob := make([]byte, blob.Size(blob.FlagCommitOutputs.Type()))
if _, err := io.ReadFull(crand.Reader, encBlob); err != nil {
t.Fatalf("unable to generate encrypted blob: %v", err)
}
return &wtdb.CommittedUpdate{
SeqNum: seqNum,
CommittedUpdateBody: wtdb.CommittedUpdateBody{
BackupID: wtdb.BackupID{
ChanID: chanID,
CommitHeight: 666,
},
Hint: hint,
EncryptedBlob: encBlob,
},
}
}

@ -1,43 +1,21 @@
package wtdb
import (
"errors"
"io"
"github.com/btcsuite/btcd/btcec"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
)
var (
// ErrClientSessionNotFound signals that the requested client session
// was not found in the database.
ErrClientSessionNotFound = errors.New("client session not found")
// CSessionStatus is a bit-field representing the possible statuses of
// ClientSessions.
type CSessionStatus uint8
// ErrUpdateAlreadyCommitted signals that the chosen sequence number has
// already been committed to an update with a different breach hint.
ErrUpdateAlreadyCommitted = errors.New("update already committed")
// ErrCommitUnorderedUpdate signals the client tried to commit a
// sequence number other than the next unallocated sequence number.
ErrCommitUnorderedUpdate = errors.New("update seqnum not monotonic")
// ErrCommittedUpdateNotFound signals that the tower tried to ACK a
// sequence number that has not yet been allocated by the client.
ErrCommittedUpdateNotFound = errors.New("committed update not found")
// ErrUnallocatedLastApplied signals that the tower tried to provide a
// LastApplied value greater than any allocated sequence number.
	ErrUnallocatedLastApplied = errors.New("tower echoed last applied " +
"greater than allocated seqnum")
// ErrNoReservedKeyIndex signals that a client session could not be
// created because no session key index was reserved.
ErrNoReservedKeyIndex = errors.New("key index not reserved")
// ErrIncorrectKeyIndex signals that the client session could not be
// created because session key index differs from the reserved key
// index.
ErrIncorrectKeyIndex = errors.New("incorrect key index")
const (
// CSessionActive indicates that the ClientSession is active and can be
// used for backups.
CSessionActive CSessionStatus = 0
)
// ClientSession encapsulates a SessionInfo returned from a successful
@ -46,8 +24,48 @@ var (
type ClientSession struct {
// ID is the client's public key used when authenticating with the
// tower.
//
	// NOTE: This value is not serialized with the body of the struct; it
// should be set and recovered as the ClientSession's key.
ID SessionID
ClientSessionBody
// CommittedUpdates is a sorted list of unacked updates. These updates
// can be resent after a restart if the updates failed to send or
// receive an acknowledgment.
//
	// NOTE: This list is serialized in its own bucket, separate from the
// body of the ClientSession. The representation on disk is a key value
// map from sequence number to CommittedUpdateBody to allow efficient
// insertion and retrieval.
CommittedUpdates []CommittedUpdate
// AckedUpdates is a map from sequence number to backup id to record
// which revoked states were uploaded via this session.
//
	// NOTE: This map is serialized in its own bucket, separate from the
// body of the ClientSession.
AckedUpdates map[uint16]BackupID
// Tower holds the pubkey and address of the watchtower.
//
// NOTE: This value is not serialized. It is recovered by looking up the
// tower with TowerID.
Tower *Tower
// SessionPrivKey is the ephemeral secret key used to connect to the
// watchtower.
//
// NOTE: This value is not serialized. It is derived using the KeyIndex
// on startup to avoid storing private keys on disk.
SessionPrivKey *btcec.PrivateKey
}
// ClientSessionBody represents the primary components of a ClientSession that
// are serialized together within the database. The CommittedUpdates and
// AckedUpdates are serialized in buckets separate from the body.
type ClientSessionBody struct {
// SeqNum is the next unallocated sequence number that can be sent to
// the tower.
SeqNum uint16
@ -57,13 +75,7 @@ type ClientSession struct {
// TowerID is the unique, db-assigned identifier that references the
// Tower with which the session is negotiated.
TowerID uint64
// Tower holds the pubkey and address of the watchtower.
//
// NOTE: This value is not serialized. It is recovered by looking up the
// tower with TowerID.
Tower *Tower
TowerID TowerID
	// KeyIndex is the index of the key locator used to derive the client's
// session key so that it can authenticate with the tower to update its
@ -71,29 +83,54 @@ type ClientSession struct {
// use the keychain.KeyFamilyTowerSession key family.
KeyIndex uint32
// SessionPrivKey is the ephemeral secret key used to connect to the
// watchtower.
//
// NOTE: This value is not serialized. It is derived using the KeyIndex
// on startup to avoid storing private keys on disk.
SessionPrivKey *btcec.PrivateKey
// Policy holds the negotiated session parameters.
Policy wtpolicy.Policy
// Status indicates the current state of the ClientSession.
Status CSessionStatus
// RewardPkScript is the pkscript that the tower's reward will be
	// deposited to if a sweep transaction confirms and the session
// specifies a reward output.
RewardPkScript []byte
}
// CommittedUpdates is a map from allocated sequence numbers to unacked
// updates. These updates can be resent after a restart if the update
// failed to send or receive an acknowledgment.
CommittedUpdates map[uint16]*CommittedUpdate
// Encode writes a ClientSessionBody to the passed io.Writer.
func (s *ClientSessionBody) Encode(w io.Writer) error {
return WriteElements(w,
s.SeqNum,
s.TowerLastApplied,
uint64(s.TowerID),
s.KeyIndex,
uint8(s.Status),
s.Policy,
s.RewardPkScript,
)
}
// AckedUpdates is a map from sequence number to backup id to record
// which revoked states were uploaded via this session.
AckedUpdates map[uint16]BackupID
// Decode reads a ClientSessionBody from the passed io.Reader.
func (s *ClientSessionBody) Decode(r io.Reader) error {
var (
towerID uint64
status uint8
)
err := ReadElements(r,
&s.SeqNum,
&s.TowerLastApplied,
&towerID,
&s.KeyIndex,
&status,
&s.Policy,
&s.RewardPkScript,
)
if err != nil {
return err
}
s.TowerID = TowerID(towerID)
s.Status = CSessionStatus(status)
return nil
}
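// A minimal round-trip sketch of the body codec (illustrative only; error
// handling elided):
//
//	var buf bytes.Buffer
//	body := ClientSessionBody{
//		SeqNum:         1,
//		TowerID:        TowerID(42),
//		RewardPkScript: []byte{0x01, 0x02, 0x03},
//	}
//	_ = body.Encode(&buf)
//	var body2 ClientSessionBody
//	_ = body2.Decode(&buf)
//	// body2 now mirrors body field for field.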
// BackupID identifies a particular revoked, remote commitment by channel id and
@ -106,9 +143,38 @@ type BackupID struct {
CommitHeight uint64
}
// Encode writes the BackupID to the passed io.Writer.
func (b *BackupID) Encode(w io.Writer) error {
return WriteElements(w,
b.ChanID,
b.CommitHeight,
)
}
// Decode reads a BackupID from the passed io.Reader.
func (b *BackupID) Decode(r io.Reader) error {
return ReadElements(r,
&b.ChanID,
&b.CommitHeight,
)
}
// CommittedUpdate holds a state update sent by a client along with its
// allocated sequence number and the exact remote commitment the encrypted
// justice transaction can rectify.
type CommittedUpdate struct {
// SeqNum is the unique sequence number allocated by the session to this
// update.
SeqNum uint16
CommittedUpdateBody
}
// CommittedUpdateBody represents the primary components of a CommittedUpdate.
// On disk, this is stored under the sequence number, which acts as its key.
type CommittedUpdateBody struct {
// BackupID identifies the breached commitment that the encrypted blob
// can spend from.
BackupID BackupID
// Hint is the 16-byte prefix of the revoked commitment transaction ID.
@ -119,3 +185,29 @@ type CommittedUpdate struct {
// hint is broadcast.
EncryptedBlob []byte
}
// Encode writes the CommittedUpdateBody to the passed io.Writer.
func (u *CommittedUpdateBody) Encode(w io.Writer) error {
err := u.BackupID.Encode(w)
if err != nil {
return err
}
return WriteElements(w,
u.Hint,
u.EncryptedBlob,
)
}
// Decode reads a CommittedUpdateBody from the passed io.Reader.
func (u *CommittedUpdateBody) Decode(r io.Reader) error {
err := u.BackupID.Decode(r)
if err != nil {
return err
}
return ReadElements(r,
&u.Hint,
&u.EncryptedBlob,
)
}

@ -2,14 +2,122 @@ package wtdb_test
import (
"bytes"
"encoding/binary"
"io"
"math/rand"
"net"
"reflect"
"testing"
"testing/quick"
"github.com/btcsuite/btcd/btcec"
"github.com/lightningnetwork/lnd/tor"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
)
func randPubKey() (*btcec.PublicKey, error) {
priv, err := btcec.NewPrivateKey(btcec.S256())
if err != nil {
return nil, err
}
return priv.PubKey(), nil
}
func randTCP4Addr(r *rand.Rand) (*net.TCPAddr, error) {
var ip [4]byte
if _, err := r.Read(ip[:]); err != nil {
return nil, err
}
var port [2]byte
if _, err := r.Read(port[:]); err != nil {
return nil, err
}
addrIP := net.IP(ip[:])
addrPort := int(binary.BigEndian.Uint16(port[:]))
return &net.TCPAddr{IP: addrIP, Port: addrPort}, nil
}
func randTCP6Addr(r *rand.Rand) (*net.TCPAddr, error) {
var ip [16]byte
if _, err := r.Read(ip[:]); err != nil {
return nil, err
}
var port [2]byte
if _, err := r.Read(port[:]); err != nil {
return nil, err
}
addrIP := net.IP(ip[:])
addrPort := int(binary.BigEndian.Uint16(port[:]))
return &net.TCPAddr{IP: addrIP, Port: addrPort}, nil
}
func randV2OnionAddr(r *rand.Rand) (*tor.OnionAddr, error) {
var serviceID [tor.V2DecodedLen]byte
if _, err := r.Read(serviceID[:]); err != nil {
return nil, err
}
var port [2]byte
if _, err := r.Read(port[:]); err != nil {
return nil, err
}
onionService := tor.Base32Encoding.EncodeToString(serviceID[:])
onionService += tor.OnionSuffix
addrPort := int(binary.BigEndian.Uint16(port[:]))
return &tor.OnionAddr{OnionService: onionService, Port: addrPort}, nil
}
func randV3OnionAddr(r *rand.Rand) (*tor.OnionAddr, error) {
var serviceID [tor.V3DecodedLen]byte
if _, err := r.Read(serviceID[:]); err != nil {
return nil, err
}
var port [2]byte
if _, err := r.Read(port[:]); err != nil {
return nil, err
}
onionService := tor.Base32Encoding.EncodeToString(serviceID[:])
onionService += tor.OnionSuffix
addrPort := int(binary.BigEndian.Uint16(port[:]))
return &tor.OnionAddr{OnionService: onionService, Port: addrPort}, nil
}
func randAddrs(r *rand.Rand) ([]net.Addr, error) {
tcp4Addr, err := randTCP4Addr(r)
if err != nil {
return nil, err
}
tcp6Addr, err := randTCP6Addr(r)
if err != nil {
return nil, err
}
v2OnionAddr, err := randV2OnionAddr(r)
if err != nil {
return nil, err
}
v3OnionAddr, err := randV3OnionAddr(r)
if err != nil {
return nil, err
}
return []net.Addr{tcp4Addr, tcp6Addr, v2OnionAddr, v3OnionAddr}, nil
}
// dbObject is an abstract object supporting encoding and decoding.
type dbObject interface {
Encode(io.Writer) error
@ -19,7 +127,9 @@ type dbObject interface {
// TestCodec serializes and deserializes wtdb objects in order to test that
// the codec understands all of the required field types. The test also asserts
// that decoding an object into another results in an equivalent object.
func TestCodec(t *testing.T) {
func TestCodec(tt *testing.T) {
var t *testing.T
mainScenario := func(obj dbObject) bool {
// Ensure encoding the object succeeds.
var b bytes.Buffer
@ -35,6 +145,16 @@ func TestCodec(t *testing.T) {
obj2 = &wtdb.SessionInfo{}
case *wtdb.SessionStateUpdate:
obj2 = &wtdb.SessionStateUpdate{}
case *wtdb.ClientSessionBody:
obj2 = &wtdb.ClientSessionBody{}
case *wtdb.CommittedUpdateBody:
obj2 = &wtdb.CommittedUpdateBody{}
case *wtdb.BackupID:
obj2 = &wtdb.BackupID{}
case *wtdb.Tower:
obj2 = &wtdb.Tower{}
case *wtdb.ClientChanSummary:
obj2 = &wtdb.ClientChanSummary{}
default:
t.Fatalf("unknown type: %T", obj)
return false
@ -57,6 +177,29 @@ func TestCodec(t *testing.T) {
return true
}
customTypeGen := map[string]func([]reflect.Value, *rand.Rand){
"Tower": func(v []reflect.Value, r *rand.Rand) {
pk, err := randPubKey()
if err != nil {
t.Fatalf("unable to generate pubkey: %v", err)
return
}
addrs, err := randAddrs(r)
if err != nil {
t.Fatalf("unable to generate addrs: %v", err)
return
}
obj := wtdb.Tower{
IdentityKey: pk,
Addresses: addrs,
}
v[0] = reflect.ValueOf(obj)
},
}
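	// With a custom generator registered, quick.Check no longer builds
	// values via reflection; each trial instead receives a Tower populated
	// with a valid secp256k1 pubkey and usable addresses, without which
	// the codec could not encode the object.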
tests := []struct {
name string
scenario interface{}
@ -73,11 +216,51 @@ func TestCodec(t *testing.T) {
return mainScenario(&obj)
},
},
{
name: "ClientSessionBody",
scenario: func(obj wtdb.ClientSessionBody) bool {
return mainScenario(&obj)
},
},
{
name: "CommittedUpdateBody",
scenario: func(obj wtdb.CommittedUpdateBody) bool {
return mainScenario(&obj)
},
},
{
name: "BackupID",
scenario: func(obj wtdb.BackupID) bool {
return mainScenario(&obj)
},
},
{
name: "Tower",
scenario: func(obj wtdb.Tower) bool {
return mainScenario(&obj)
},
},
{
name: "ClientChanSummary",
scenario: func(obj wtdb.ClientChanSummary) bool {
return mainScenario(&obj)
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if err := quick.Check(test.scenario, nil); err != nil {
tt.Run(test.name, func(h *testing.T) {
t = h
var config *quick.Config
if valueGen, ok := customTypeGen[test.name]; ok {
config = &quick.Config{
Values: valueGen,
}
}
err := quick.Check(test.scenario, config)
if err != nil {
t.Fatalf("fuzz checks for msg=%s failed: %v",
test.name, err)
}

@ -0,0 +1,92 @@
package wtdb
import (
"encoding/binary"
"errors"
"os"
"path/filepath"
"github.com/coreos/bbolt"
)
const (
// dbFilePermission requests read+write access to the db file.
dbFilePermission = 0600
)
var (
// metadataBkt stores all the meta information concerning the state of
// the database.
metadataBkt = []byte("metadata-bucket")
// dbVersionKey is a static key used to retrieve the database version
// number from the metadataBkt.
dbVersionKey = []byte("version")
// ErrUninitializedDB signals that top-level buckets for the database
// have not been initialized.
ErrUninitializedDB = errors.New("db not initialized")
// ErrNoDBVersion signals that the database contains no version info.
ErrNoDBVersion = errors.New("db has no version")
// byteOrder is the default endianness used when serializing integers.
byteOrder = binary.BigEndian
)
// fileExists returns true if the file exists, and false otherwise.
func fileExists(path string) bool {
if _, err := os.Stat(path); err != nil {
if os.IsNotExist(err) {
return false
}
}
return true
}
// createDBIfNotExist opens the boltdb database at dbPath/name, creating one if
// it doesn't exist. The returned boolean indicates whether the database is
// fresh: either the file did not exist before, or it was created previously
// but no version metadata was ever populated within it.
func createDBIfNotExist(dbPath, name string) (*bbolt.DB, bool, error) {
path := filepath.Join(dbPath, name)
	// If the database file doesn't exist, this indicates we must initialize
// a fresh database with the latest version.
firstInit := !fileExists(path)
if firstInit {
// Ensure all parent directories are initialized.
err := os.MkdirAll(dbPath, 0700)
if err != nil {
return nil, false, err
}
}
bdb, err := bbolt.Open(path, dbFilePermission, nil)
if err != nil {
return nil, false, err
}
// If the file existed previously, we'll now check to see that the
// metadata bucket is properly initialized. It could be the case that
// the database was created, but we failed to actually populate any
// metadata. If the metadata bucket does not actually exist, we'll
	// set firstInit to true so that the bucket is initialized as if this
	// were a fresh database.
if !firstInit {
var metadataExists bool
err = bdb.View(func(tx *bbolt.Tx) error {
metadataExists = tx.Bucket(metadataBkt) != nil
return nil
})
if err != nil {
return nil, false, err
}
if !metadataExists {
firstInit = true
}
}
return bdb, firstInit, nil
}
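// A sketch of the intended call pattern, mirroring OpenTowerDB below. The
// client-side opener follows the same shape; the clientDBName constant and
// the ClientDB fields used here are assumptions, as they are not shown in
// this hunk:
//
//	bdb, firstInit, err := createDBIfNotExist(dbPath, clientDBName)
//	if err != nil {
//		return nil, err
//	}
//	clientDB := &ClientDB{db: bdb, dbPath: dbPath}
//	err = initOrSyncVersions(clientDB, firstInit, clientDBVersions)
//	if err != nil {
//		bdb.Close()
//		return nil, err
//	}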

@ -1,26 +1,38 @@
package wtdb
import (
"errors"
"io"
"net"
"sync"
"github.com/btcsuite/btcd/btcec"
"github.com/lightningnetwork/lnd/lnwire"
)
var (
// ErrTowerNotFound signals that the target tower was not found in the
// database.
ErrTowerNotFound = errors.New("tower not found")
)
// TowerID is a unique 64-bit identifier allocated to each unique watchtower.
// This allows the client to conserve on-disk space by not needing to always
// reference towers by their pubkey.
type TowerID uint64
// TowerIDFromBytes constructs a TowerID from the provided byte slice. The
// argument must have at least 8 bytes, and should contain the TowerID in
// big-endian byte order.
func TowerIDFromBytes(towerIDBytes []byte) TowerID {
return TowerID(byteOrder.Uint64(towerIDBytes))
}
// Bytes encodes a TowerID into an 8-byte slice in big-endian byte order.
func (id TowerID) Bytes() []byte {
var buf [8]byte
byteOrder.PutUint64(buf[:], uint64(id))
return buf[:]
}
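// The two helpers above round-trip losslessly, which is what lets a TowerID
// serve directly as a fixed-width bbolt key:
//
//	id := TowerID(7)
//	key := id.Bytes()             // 8-byte big-endian slice
//	same := TowerIDFromBytes(key) // TowerID(7) again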
// Tower holds the necessary components required to connect to a remote tower.
// Communication is handled by brontide, and requires both a public key and an
// address.
type Tower struct {
// ID is a unique ID for this record assigned by the database.
ID uint64
ID TowerID
// IdentityKey is the public key of the remote node, used to
// authenticate the brontide transport.
@ -28,18 +40,15 @@ type Tower struct {
// Addresses is a list of possible addresses to reach the tower.
Addresses []net.Addr
mu sync.RWMutex
}
// AddAddress adds the given address to the tower's in-memory list of addresses.
// If the address's string is already present, the Tower will be left
// unmodified. Otherwise, the address is prepended to the beginning of the
// Tower's addresses, on the assumption that it is fresher than the others.
//
// NOTE: This method is NOT safe for concurrent use.
func (t *Tower) AddAddress(addr net.Addr) {
t.mu.Lock()
defer t.mu.Unlock()
// Ensure we don't add a duplicate address.
addrStr := addr.String()
for _, existingAddr := range t.Addresses {
@ -56,10 +65,9 @@ func (t *Tower) AddAddress(addr net.Addr) {
// LNAddrs generates a list of lnwire.NetAddress from a Tower instance's
// addresses. This can be used to have a client try multiple addresses for the
// same Tower.
//
// NOTE: This method is NOT safe for concurrent use.
func (t *Tower) LNAddrs() []*lnwire.NetAddress {
t.mu.RLock()
defer t.mu.RUnlock()
addrs := make([]*lnwire.NetAddress, 0, len(t.Addresses))
for _, addr := range t.Addresses {
addrs = append(addrs, &lnwire.NetAddress{
@ -70,3 +78,21 @@ func (t *Tower) LNAddrs() []*lnwire.NetAddress {
return addrs
}
// Encode writes the Tower to the passed io.Writer. The TowerID is not
// serialized, since it acts as the key.
func (t *Tower) Encode(w io.Writer) error {
return WriteElements(w,
t.IdentityKey,
t.Addresses,
)
}
// Decode reads a Tower from the passed io.Reader. The TowerID is meant to be
// decoded from the key.
func (t *Tower) Decode(r io.Reader) error {
return ReadElements(r,
&t.IdentityKey,
&t.Addresses,
)
}

@ -2,23 +2,16 @@ package wtdb
import (
"bytes"
"encoding/binary"
"errors"
"os"
"path/filepath"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/coreos/bbolt"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
)
const (
// dbName is the filename of tower database.
dbName = "watchtower.db"
// dbFilePermission requests read+write access to the db file.
dbFilePermission = 0600
// towerDBName is the filename of tower database.
towerDBName = "watchtower.db"
)
var (
@ -49,26 +42,9 @@ var (
// epoch from the lookoutTipBkt.
lookoutTipKey = []byte("lookout-tip")
// metadataBkt stores all the meta information concerning the state of
// the database.
metadataBkt = []byte("metadata-bucket")
// dbVersionKey is a static key used to retrieve the database version
// number from the metadataBkt.
dbVersionKey = []byte("version")
// ErrUninitializedDB signals that top-level buckets for the database
// have not been initialized.
ErrUninitializedDB = errors.New("tower db not initialized")
// ErrNoDBVersion signals that the database contains no version info.
ErrNoDBVersion = errors.New("tower db has no version")
// ErrNoSessionHintIndex signals that an active session does not have an
// initialized index for tracking its own state updates.
ErrNoSessionHintIndex = errors.New("session hint index missing")
byteOrder = binary.BigEndian
)
// TowerDB is a single database providing a persistent storage engine for the
@ -86,67 +62,20 @@ type TowerDB struct {
// with a version number higher than the latest version will fail to prevent
// accidental reversion.
func OpenTowerDB(dbPath string) (*TowerDB, error) {
path := filepath.Join(dbPath, dbName)
	// If the database file doesn't exist, this indicates we must initialize
// a fresh database with the latest version.
firstInit := !fileExists(path)
if firstInit {
// Ensure all parent directories are initialized.
err := os.MkdirAll(dbPath, 0700)
if err != nil {
return nil, err
}
}
bdb, err := bbolt.Open(path, dbFilePermission, nil)
bdb, firstInit, err := createDBIfNotExist(dbPath, towerDBName)
if err != nil {
return nil, err
}
// If the file existed previously, we'll now check to see that the
// metadata bucket is properly initialized. It could be the case that
// the database was created, but we failed to actually populate any
// metadata. If the metadata bucket does not actually exist, we'll
	// set firstInit to true so that the bucket is initialized as if this
	// were a fresh database.
if !firstInit {
var metadataExists bool
err = bdb.View(func(tx *bbolt.Tx) error {
metadataExists = tx.Bucket(metadataBkt) != nil
return nil
})
if err != nil {
return nil, err
}
if !metadataExists {
firstInit = true
}
}
towerDB := &TowerDB{
db: bdb,
dbPath: dbPath,
}
if firstInit {
// If the database has not yet been created, we'll initialize
// the database version with the latest known version.
err = towerDB.db.Update(func(tx *bbolt.Tx) error {
return initDBVersion(tx, getLatestDBVersion(dbVersions))
})
if err != nil {
bdb.Close()
return nil, err
}
} else {
// Otherwise, ensure that any migrations are applied to ensure
// the data is in the format expected by the latest version.
err = towerDB.syncVersions(dbVersions)
if err != nil {
bdb.Close()
return nil, err
}
err = initOrSyncVersions(towerDB, firstInit, towerDBVersions)
if err != nil {
bdb.Close()
return nil, err
}
	// Now that the database version is fully consistent with our latest known
@ -163,17 +92,6 @@ func OpenTowerDB(dbPath string) (*TowerDB, error) {
return towerDB, nil
}
// fileExists returns true if the file exists, and false otherwise.
func fileExists(path string) bool {
if _, err := os.Stat(path); err != nil {
if os.IsNotExist(err) {
return false
}
}
return true
}
// initTowerDBBuckets creates all top-level buckets required to handle database
// operations required by the latest version.
func initTowerDBBuckets(tx *bbolt.Tx) error {
@ -194,53 +112,16 @@ func initTowerDBBuckets(tx *bbolt.Tx) error {
return nil
}
// syncVersions ensures the database version is consistent with the highest
// known database version, applying any migrations that have not been made. If
// the highest known version number is lower than the database's version, this
// method will fail to prevent accidental reversions.
func (t *TowerDB) syncVersions(versions []version) error {
curVersion, err := t.Version()
if err != nil {
return err
}
latestVersion := getLatestDBVersion(versions)
switch {
// Current version is higher than any known version, fail to prevent
// reversion.
case curVersion > latestVersion:
return channeldb.ErrDBReversion
// Current version matches highest known version, nothing to do.
case curVersion == latestVersion:
return nil
}
// Otherwise, apply any migrations in order to bring the database
// version up to the highest known version.
updates := getMigrations(versions, curVersion)
return t.db.Update(func(tx *bbolt.Tx) error {
for _, update := range updates {
if update.migration == nil {
continue
}
log.Infof("Applying migration #%d", update.number)
err := update.migration(tx)
if err != nil {
log.Errorf("Unable to apply migration #%d: %v",
err)
return err
}
}
return putDBVersion(tx, latestVersion)
})
// bdb returns the backing bbolt.DB instance.
//
// NOTE: Part of the versionedDB interface.
func (t *TowerDB) bdb() *bbolt.DB {
return t.db
}
// Version returns the database's current version number.
//
// NOTE: Part of the versionedDB interface.
func (t *TowerDB) Version() (uint32, error) {
var version uint32
err := t.db.View(func(tx *bbolt.Tx) error {

@ -1,6 +1,9 @@
package wtdb
import "github.com/coreos/bbolt"
import (
"github.com/coreos/bbolt"
"github.com/lightningnetwork/lnd/channeldb"
)
// migration is a function which takes a prior outdated version of the database
// instances and mutates the key/bucket structure to arrive at a more
@ -10,32 +13,30 @@ type migration func(tx *bbolt.Tx) error
// version pairs a version number with the migration that would need to be
// applied from the prior version to upgrade.
type version struct {
number uint32
migration migration
}
// dbVersions stores all versions and migrations of the database. This list will
// be used when opening the database to determine if any migrations must be
// applied.
var dbVersions = []version{
{
// Initial version requires no migration.
number: 0,
migration: nil,
},
}
// towerDBVersions stores all versions and migrations of the tower database.
// This list will be used when opening the database to determine if any
// migrations must be applied.
var towerDBVersions = []version{}
// clientDBVersions stores all versions and migrations of the client database.
// This list will be used when opening the database to determine if any
// migrations must be applied.
var clientDBVersions = []version{}
// getLatestDBVersion returns the last known database version.
func getLatestDBVersion(versions []version) uint32 {
return versions[len(versions)-1].number
return uint32(len(versions))
}
// getMigrations returns a slice of all updates with a version number greater
// than curVersion that need to be applied to sync up with the latest version.
func getMigrations(versions []version, curVersion uint32) []version {
var updates []version
for _, v := range versions {
if v.number > curVersion {
for i, v := range versions {
if uint32(i)+1 > curVersion {
updates = append(updates, v)
}
}
@ -82,3 +83,81 @@ func putDBVersion(tx *bbolt.Tx, version uint32) error {
byteOrder.PutUint32(versionBytes, version)
return metadata.Put(dbVersionKey, versionBytes)
}
// versionedDB is a private interface implemented by both the tower and client
// databases, permitting all versioning operations to be performed generically
// on either.
type versionedDB interface {
// bdb returns the underlying bbolt database.
bdb() *bbolt.DB
// Version returns the current version stored in the database.
Version() (uint32, error)
}
// initOrSyncVersions ensures that the database version is properly set before
// opening the database up for regular use. When the database is being
// initialized for the first time, the caller should set init to true, which
// will simply write the latest version to the database. Otherwise, passing init
// as false will cause the database to apply any needed migrations to ensure its
// version matches the latest version in the provided versions list.
func initOrSyncVersions(db versionedDB, init bool, versions []version) error {
// If the database has not yet been created, we'll initialize the
// database version with the latest known version.
if init {
return db.bdb().Update(func(tx *bbolt.Tx) error {
return initDBVersion(tx, getLatestDBVersion(versions))
})
}
// Otherwise, ensure that any migrations are applied to ensure the data
// is in the format expected by the latest version.
return syncVersions(db, versions)
}
// syncVersions ensures the database version is consistent with the highest
// known database version, applying any migrations that have not been made. If
// the highest known version number is lower than the database's version, this
// method will fail to prevent accidental reversions.
func syncVersions(db versionedDB, versions []version) error {
curVersion, err := db.Version()
if err != nil {
return err
}
latestVersion := getLatestDBVersion(versions)
switch {
// Current version is higher than any known version, fail to prevent
// reversion.
case curVersion > latestVersion:
return channeldb.ErrDBReversion
// Current version matches highest known version, nothing to do.
case curVersion == latestVersion:
return nil
}
// Otherwise, apply any migrations in order to bring the database
// version up to the highest known version.
updates := getMigrations(versions, curVersion)
return db.bdb().Update(func(tx *bbolt.Tx) error {
for i, update := range updates {
if update.migration == nil {
continue
}
version := curVersion + uint32(i) + 1
log.Infof("Applying migration #%d", version)
err := update.migration(tx)
if err != nil {
log.Errorf("Unable to apply migration #%d: %v",
version, err)
return err
}
}
return putDBVersion(tx, latestVersion)
})
}
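// Worked example of the index-based numbering: with three entries in a
// versions list, the latest version is 3. A database currently at version 1
// receives the last two entries from getMigrations, which are logged and
// applied as migrations #2 and #3 before the stored version is bumped to 3.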

@ -1,7 +1,6 @@
package wtmock
import (
"fmt"
"net"
"sync"
"sync/atomic"
@ -18,23 +17,23 @@ type ClientDB struct {
nextTowerID uint64 // to be used atomically
mu sync.Mutex
sweepPkScripts map[lnwire.ChannelID][]byte
summaries map[lnwire.ChannelID]wtdb.ClientChanSummary
activeSessions map[wtdb.SessionID]*wtdb.ClientSession
towerIndex map[towerPK]uint64
towers map[uint64]*wtdb.Tower
towerIndex map[towerPK]wtdb.TowerID
towers map[wtdb.TowerID]*wtdb.Tower
nextIndex uint32
indexes map[uint64]uint32
indexes map[wtdb.TowerID]uint32
}
// NewClientDB initializes a new mock ClientDB.
func NewClientDB() *ClientDB {
return &ClientDB{
sweepPkScripts: make(map[lnwire.ChannelID][]byte),
summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary),
activeSessions: make(map[wtdb.SessionID]*wtdb.ClientSession),
towerIndex: make(map[towerPK]uint64),
towers: make(map[uint64]*wtdb.Tower),
indexes: make(map[uint64]uint32),
towerIndex: make(map[towerPK]wtdb.TowerID),
towers: make(map[wtdb.TowerID]*wtdb.Tower),
indexes: make(map[wtdb.TowerID]uint32),
}
}
@ -54,9 +53,9 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
tower = m.towers[towerID]
tower.AddAddress(lnAddr.Address)
} else {
towerID = atomic.AddUint64(&m.nextTowerID, 1)
towerID = wtdb.TowerID(atomic.AddUint64(&m.nextTowerID, 1))
tower = &wtdb.Tower{
ID: towerID,
ID: wtdb.TowerID(towerID),
IdentityKey: lnAddr.IdentityKey,
Addresses: []net.Addr{lnAddr.Address},
}
@ -65,16 +64,16 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
m.towerIndex[towerPubKey] = towerID
m.towers[towerID] = tower
return tower, nil
return copyTower(tower), nil
}
// LoadTower retrieves a tower by its tower ID.
func (m *ClientDB) LoadTower(towerID uint64) (*wtdb.Tower, error) {
func (m *ClientDB) LoadTower(towerID wtdb.TowerID) (*wtdb.Tower, error) {
m.mu.Lock()
defer m.mu.Unlock()
if tower, ok := m.towers[towerID]; ok {
return tower, nil
return copyTower(tower), nil
}
return nil, wtdb.ErrTowerNotFound
@ -106,6 +105,11 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
m.mu.Lock()
defer m.mu.Unlock()
// Ensure that we aren't overwriting an existing session.
if _, ok := m.activeSessions[session.ID]; ok {
return wtdb.ErrClientSessionAlreadyExists
}
// Ensure that a session key index has been reserved for this tower.
keyIndex, ok := m.indexes[session.TowerID]
if !ok {
@ -122,14 +126,16 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
delete(m.indexes, session.TowerID)
m.activeSessions[session.ID] = &wtdb.ClientSession{
TowerID: session.TowerID,
KeyIndex: session.KeyIndex,
ID: session.ID,
Policy: session.Policy,
SeqNum: session.SeqNum,
TowerLastApplied: session.TowerLastApplied,
RewardPkScript: cloneBytes(session.RewardPkScript),
CommittedUpdates: make(map[uint16]*wtdb.CommittedUpdate),
ID: session.ID,
ClientSessionBody: wtdb.ClientSessionBody{
SeqNum: session.SeqNum,
TowerLastApplied: session.TowerLastApplied,
TowerID: session.TowerID,
KeyIndex: session.KeyIndex,
Policy: session.Policy,
RewardPkScript: cloneBytes(session.RewardPkScript),
},
CommittedUpdates: make([]wtdb.CommittedUpdate, 0),
AckedUpdates: make(map[uint16]wtdb.BackupID),
}
@ -141,7 +147,7 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
// CreateClientSession is invoked for that tower and index, at which point a new
// index for that tower can be reserved. Multiple calls to this method before
// CreateClientSession is invoked should return the same index.
func (m *ClientDB) NextSessionKeyIndex(towerID uint64) (uint32, error) {
func (m *ClientDB) NextSessionKeyIndex(towerID wtdb.TowerID) (uint32, error) {
m.mu.Lock()
defer m.mu.Unlock()
@ -149,17 +155,16 @@ func (m *ClientDB) NextSessionKeyIndex(towerID uint64) (uint32, error) {
return index, nil
}
m.nextIndex++
index := m.nextIndex
m.indexes[towerID] = index
m.nextIndex++
return index, nil
}
// CommitUpdate persists the CommittedUpdate provided in the slot for (session,
// seqNum). This allows the client to retransmit this update on startup.
func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, seqNum uint16,
func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
update *wtdb.CommittedUpdate) (uint16, error) {
m.mu.Lock()
@ -172,25 +177,26 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, seqNum uint16,
}
// Check if an update has already been committed for this state.
dbUpdate, ok := session.CommittedUpdates[seqNum]
if ok {
// If the breach hint matches, we'll just return the last
// applied value so the client can retransmit.
if dbUpdate.Hint == update.Hint {
return session.TowerLastApplied, nil
}
for _, dbUpdate := range session.CommittedUpdates {
if dbUpdate.SeqNum == update.SeqNum {
// If the breach hint matches, we'll just return the
// last applied value so the client can retransmit.
if dbUpdate.Hint == update.Hint {
return session.TowerLastApplied, nil
}
// Otherwise, fail since the breach hint doesn't match.
return 0, wtdb.ErrUpdateAlreadyCommitted
// Otherwise, fail since the breach hint doesn't match.
return 0, wtdb.ErrUpdateAlreadyCommitted
}
}
// Sequence number must increment.
if seqNum != session.SeqNum+1 {
if update.SeqNum != session.SeqNum+1 {
return 0, wtdb.ErrCommitUnorderedUpdate
}
// Save the update and increment the sequence number.
session.CommittedUpdates[seqNum] = update
session.CommittedUpdates = append(session.CommittedUpdates, *update)
session.SeqNum++
return session.TowerLastApplied, nil
@ -209,13 +215,6 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err
return wtdb.ErrClientSessionNotFound
}
// Retrieve the committed update, failing if none is found. We should
// only receive acks for state updates that we send.
update, ok := session.CommittedUpdates[seqNum]
if !ok {
return wtdb.ErrCommittedUpdateNotFound
}
// Ensure the returned last applied value does not exceed the highest
// allocated sequence number.
if lastApplied > session.SeqNum {
@ -228,40 +227,64 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err
return wtdb.ErrLastAppliedReversion
}
// Finally, remove the committed update from disk and mark the update as
// acked. The tower last applied value is also recorded to send along
// with the next update.
delete(session.CommittedUpdates, seqNum)
session.AckedUpdates[seqNum] = update.BackupID
session.TowerLastApplied = lastApplied
// Retrieve the committed update, failing if none is found. We should
// only receive acks for state updates that we send.
updates := session.CommittedUpdates
for i, update := range updates {
if update.SeqNum != seqNum {
continue
}
return nil
// Remove the committed update from disk and mark the update as
// acked. The tower last applied value is also recorded to send
// along with the next update.
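		// This is the order-preserving slice removal idiom: shift
		// the tail left by one, zero the now-duplicated final slot
		// so it can be garbage collected, then truncate by one.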
		copy(updates[i:], updates[i+1:])
updates[len(updates)-1] = wtdb.CommittedUpdate{}
session.CommittedUpdates = updates[:len(updates)-1]
session.AckedUpdates[seqNum] = update.BackupID
session.TowerLastApplied = lastApplied
return nil
}
return wtdb.ErrCommittedUpdateNotFound
}
// FetchChanPkScripts returns the set of sweep pkscripts known for all channels.
// This allows the client to cache them in memory on startup.
func (m *ClientDB) FetchChanPkScripts() (map[lnwire.ChannelID][]byte, error) {
// FetchChanSummaries loads a mapping from all registered channels to their
// channel summaries.
func (m *ClientDB) FetchChanSummaries() (wtdb.ChannelSummaries, error) {
m.mu.Lock()
defer m.mu.Unlock()
sweepPkScripts := make(map[lnwire.ChannelID][]byte)
for chanID, pkScript := range m.sweepPkScripts {
sweepPkScripts[chanID] = cloneBytes(pkScript)
summaries := make(map[lnwire.ChannelID]wtdb.ClientChanSummary)
for chanID, summary := range m.summaries {
summaries[chanID] = wtdb.ClientChanSummary{
SweepPkScript: cloneBytes(summary.SweepPkScript),
}
}
return sweepPkScripts, nil
return summaries, nil
}
// AddChanPkScript sets a pkscript for sweeping funds from the channel chanID.
func (m *ClientDB) AddChanPkScript(chanID lnwire.ChannelID, pkScript []byte) error {
// RegisterChannel registers a channel for use within the client database. For
// now, all that is stored in the channel summary is the sweep pkscript that
// we'd like any tower sweeps to pay into. In the future, this will be extended
// to contain more info to allow the client to efficiently request historical
// states to be backed up under the client's active policy.
func (m *ClientDB) RegisterChannel(chanID lnwire.ChannelID,
sweepPkScript []byte) error {
m.mu.Lock()
defer m.mu.Unlock()
if _, ok := m.sweepPkScripts[chanID]; ok {
return fmt.Errorf("pkscript for %x already exists", pkScript)
if _, ok := m.summaries[chanID]; ok {
return wtdb.ErrChannelAlreadyRegistered
}
m.sweepPkScripts[chanID] = cloneBytes(pkScript)
m.summaries[chanID] = wtdb.ClientChanSummary{
SweepPkScript: cloneBytes(sweepPkScript),
}
return nil
}
@ -276,3 +299,14 @@ func cloneBytes(b []byte) []byte {
return bb
}
func copyTower(tower *wtdb.Tower) *wtdb.Tower {
t := &wtdb.Tower{
ID: tower.ID,
IdentityKey: tower.IdentityKey,
Addresses: make([]net.Addr, len(tower.Addresses)),
}
copy(t.Addresses, tower.Addresses)
return t
}
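// Returning copies from the mock prevents callers from mutating its internal
// state, matching the isolation the bbolt-backed ClientDB provides naturally
// by deserializing fresh values on every read.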