Merge pull request #3106 from cfromknecht/wtclient-db
watchtower/wtdb: add bbolt-backed ClientDB
This commit is contained in:
commit
6e3b92b55f
@ -103,6 +103,11 @@ func WriteElement(w io.Writer, element interface{}) error {
|
||||
return err
|
||||
}
|
||||
|
||||
case lnwire.ChannelID:
|
||||
if _, err := w.Write(e[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case uint64:
|
||||
if err := binary.Write(w, byteOrder, e); err != nil {
|
||||
return err
|
||||
@ -123,6 +128,11 @@ func WriteElement(w io.Writer, element interface{}) error {
|
||||
return err
|
||||
}
|
||||
|
||||
case uint8:
|
||||
if err := binary.Write(w, byteOrder, e); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case bool:
|
||||
if err := binary.Write(w, byteOrder, e); err != nil {
|
||||
return err
|
||||
@ -259,6 +269,11 @@ func ReadElement(r io.Reader, element interface{}) error {
|
||||
}
|
||||
*e = lnwire.NewShortChanIDFromInt(a)
|
||||
|
||||
case *lnwire.ChannelID:
|
||||
if _, err := io.ReadFull(r, e[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case *uint64:
|
||||
if err := binary.Read(r, byteOrder, e); err != nil {
|
||||
return err
|
||||
@ -279,6 +294,11 @@ func ReadElement(r io.Reader, element interface{}) error {
|
||||
return err
|
||||
}
|
||||
|
||||
case *uint8:
|
||||
if err := binary.Read(r, byteOrder, e); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case *bool:
|
||||
if err := binary.Read(r, byteOrder, e); err != nil {
|
||||
return err
|
||||
|
@ -126,8 +126,7 @@ func (t *backupTask) inputs() map[wire.OutPoint]input.Input {
|
||||
// SessionInfo's policy. If no error is returned, the task has been bound to the
|
||||
// session and can be queued to upload to the tower. Otherwise, the bind failed
|
||||
// and should be rescheduled with a different session.
|
||||
func (t *backupTask) bindSession(session *wtdb.ClientSession) error {
|
||||
|
||||
func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error {
|
||||
// First we'll begin by deriving a weight estimate for the justice
|
||||
// transaction. The final weight can be different depending on whether
|
||||
// the watchtower is taking a reward.
|
||||
|
@ -69,7 +69,7 @@ type backupTaskTest struct {
|
||||
expSweepAmt int64
|
||||
expRewardAmt int64
|
||||
expRewardScript []byte
|
||||
session *wtdb.ClientSession
|
||||
session *wtdb.ClientSessionBody
|
||||
bindErr error
|
||||
expSweepScript []byte
|
||||
signer input.Signer
|
||||
@ -205,7 +205,7 @@ func genTaskTest(
|
||||
expSweepAmt: expSweepAmt,
|
||||
expRewardAmt: expRewardAmt,
|
||||
expRewardScript: rewardScript,
|
||||
session: &wtdb.ClientSession{
|
||||
session: &wtdb.ClientSessionBody{
|
||||
Policy: wtpolicy.Policy{
|
||||
BlobType: blobType,
|
||||
SweepFeeRate: sweepFeeRate,
|
||||
|
@ -150,8 +150,9 @@ type TowerClient struct {
|
||||
sessionQueue *sessionQueue
|
||||
prevTask *backupTask
|
||||
|
||||
sweepPkScriptMu sync.RWMutex
|
||||
sweepPkScripts map[lnwire.ChannelID][]byte
|
||||
backupMu sync.Mutex
|
||||
summaries wtdb.ChannelSummaries
|
||||
chanCommitHeights map[lnwire.ChannelID]uint64
|
||||
|
||||
statTicker *time.Ticker
|
||||
stats clientStats
|
||||
@ -243,9 +244,13 @@ func New(config *Config) (*TowerClient, error) {
|
||||
s.SessionPrivKey = sessionPriv
|
||||
}
|
||||
|
||||
// Reconstruct the highest commit height processed for each channel
|
||||
// under the client's current policy.
|
||||
c.buildHighestCommitHeights()
|
||||
|
||||
// Finally, load the sweep pkscripts that have been generated for all
|
||||
// previously registered channels.
|
||||
c.sweepPkScripts, err = c.cfg.DB.FetchChanPkScripts()
|
||||
c.summaries, err = c.cfg.DB.FetchChanSummaries()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -253,6 +258,44 @@ func New(config *Config) (*TowerClient, error) {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// buildHighestCommitHeights inspects the full set of candidate client sessions
|
||||
// loaded from disk, and determines the highest known commit height for each
|
||||
// channel. This allows the client to reject backups that it has already
|
||||
// processed for it's active policy.
|
||||
func (c *TowerClient) buildHighestCommitHeights() {
|
||||
chanCommitHeights := make(map[lnwire.ChannelID]uint64)
|
||||
for _, s := range c.candidateSessions {
|
||||
// We only want to consider accepted updates that have been
|
||||
// accepted under an identical policy to the client's current
|
||||
// policy.
|
||||
if s.Policy != c.cfg.Policy {
|
||||
continue
|
||||
}
|
||||
|
||||
// Take the highest commit height found in the session's
|
||||
// committed updates.
|
||||
for _, committedUpdate := range s.CommittedUpdates {
|
||||
bid := committedUpdate.BackupID
|
||||
|
||||
height, ok := chanCommitHeights[bid.ChanID]
|
||||
if !ok || bid.CommitHeight > height {
|
||||
chanCommitHeights[bid.ChanID] = bid.CommitHeight
|
||||
}
|
||||
}
|
||||
|
||||
// Take the heights commit height found in the session's acked
|
||||
// updates.
|
||||
for _, bid := range s.AckedUpdates {
|
||||
height, ok := chanCommitHeights[bid.ChanID]
|
||||
if !ok || bid.CommitHeight > height {
|
||||
chanCommitHeights[bid.ChanID] = bid.CommitHeight
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.chanCommitHeights = chanCommitHeights
|
||||
}
|
||||
|
||||
// Start initializes the watchtower client by loading or negotiating an active
|
||||
// session and then begins processing backup tasks from the request pipeline.
|
||||
func (c *TowerClient) Start() error {
|
||||
@ -388,12 +431,12 @@ func (c *TowerClient) ForceQuit() {
|
||||
// within the client. This should be called during link startup to ensure that
|
||||
// the client is able to support the link during operation.
|
||||
func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error {
|
||||
c.sweepPkScriptMu.Lock()
|
||||
defer c.sweepPkScriptMu.Unlock()
|
||||
c.backupMu.Lock()
|
||||
defer c.backupMu.Unlock()
|
||||
|
||||
// If a pkscript for this channel already exists, the channel has been
|
||||
// previously registered.
|
||||
if _, ok := c.sweepPkScripts[chanID]; ok {
|
||||
if _, ok := c.summaries[chanID]; ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -406,14 +449,16 @@ func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error {
|
||||
|
||||
// Persist the sweep pkscript so that restarts will not introduce
|
||||
// address inflation when the channel is reregistered after a restart.
|
||||
err = c.cfg.DB.AddChanPkScript(chanID, pkScript)
|
||||
err = c.cfg.DB.RegisterChannel(chanID, pkScript)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Finally, cache the pkscript in our in-memory cache to avoid db
|
||||
// lookups for the remainder of the daemon's execution.
|
||||
c.sweepPkScripts[chanID] = pkScript
|
||||
c.summaries[chanID] = wtdb.ClientChanSummary{
|
||||
SweepPkScript: pkScript,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -429,14 +474,29 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID,
|
||||
breachInfo *lnwallet.BreachRetribution) error {
|
||||
|
||||
// Retrieve the cached sweep pkscript used for this channel.
|
||||
c.sweepPkScriptMu.RLock()
|
||||
sweepPkScript, ok := c.sweepPkScripts[*chanID]
|
||||
c.sweepPkScriptMu.RUnlock()
|
||||
c.backupMu.Lock()
|
||||
summary, ok := c.summaries[*chanID]
|
||||
if !ok {
|
||||
c.backupMu.Unlock()
|
||||
return ErrUnregisteredChannel
|
||||
}
|
||||
|
||||
task := newBackupTask(chanID, breachInfo, sweepPkScript)
|
||||
// Ignore backups that have already been presented to the client.
|
||||
height, ok := c.chanCommitHeights[*chanID]
|
||||
if ok && breachInfo.RevokedStateNum <= height {
|
||||
c.backupMu.Unlock()
|
||||
log.Debugf("Ignoring duplicate backup for chanid=%v at height=%d",
|
||||
chanID, breachInfo.RevokedStateNum)
|
||||
return nil
|
||||
}
|
||||
|
||||
// This backup has a higher commit height than any known backup for this
|
||||
// channel. We'll update our tip so that we won't accept it again if the
|
||||
// link flaps.
|
||||
c.chanCommitHeights[*chanID] = breachInfo.RevokedStateNum
|
||||
c.backupMu.Unlock()
|
||||
|
||||
task := newBackupTask(chanID, breachInfo, summary.SweepPkScript)
|
||||
|
||||
return c.pipeline.QueueBackupTask(task)
|
||||
}
|
||||
|
@ -605,6 +605,8 @@ func (h *testHarness) backupStates(id, from, to uint64, expErr error) {
|
||||
// backupStates instructs the channel identified by id to send a backup for
|
||||
// state i.
|
||||
func (h *testHarness) backupState(id, i uint64, expErr error) {
|
||||
h.t.Helper()
|
||||
|
||||
_, retribution := h.channel(id).getState(i)
|
||||
|
||||
chanID := chanIDFromInt(id)
|
||||
@ -1244,6 +1246,55 @@ var clientTests = []clientTest{
|
||||
h.assertUpdatesForPolicy(hints, h.clientCfg.Policy)
|
||||
},
|
||||
},
|
||||
{
|
||||
// Asserts that the client will deduplicate backups presented by
|
||||
// a channel both in memory and after a restart. The client
|
||||
// should only accept backups with a commit height greater than
|
||||
// any processed already processed for a given policy.
|
||||
name: "dedup backups",
|
||||
cfg: harnessCfg{
|
||||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
BlobType: blob.TypeDefault,
|
||||
MaxUpdates: 5,
|
||||
SweepFeeRate: 1,
|
||||
},
|
||||
},
|
||||
fn: func(h *testHarness) {
|
||||
const (
|
||||
numUpdates = 10
|
||||
chanID = 0
|
||||
)
|
||||
|
||||
// Generate the retributions that will be backed up.
|
||||
hints := h.advanceChannelN(chanID, numUpdates)
|
||||
|
||||
// Queue the first half of the retributions twice, the
|
||||
// second batch should be entirely deduped by the
|
||||
// client's in-memory tracking.
|
||||
h.backupStates(chanID, 0, numUpdates/2, nil)
|
||||
h.backupStates(chanID, 0, numUpdates/2, nil)
|
||||
|
||||
// Wait for the first half of the updates to be
|
||||
// populated in the server's database.
|
||||
h.waitServerUpdates(hints[:len(hints)/2], 5*time.Second)
|
||||
|
||||
// Restart the client, so we can ensure the deduping is
|
||||
// maintained across restarts.
|
||||
h.client.Stop()
|
||||
h.startClient()
|
||||
defer h.client.ForceQuit()
|
||||
|
||||
// Try to back up the full range of retributions. Only
|
||||
// the second half should actually be sent.
|
||||
h.backupStates(chanID, 0, numUpdates, nil)
|
||||
|
||||
// Wait for all of the updates to be populated in the
|
||||
// server's database.
|
||||
h.waitServerUpdates(hints, 5*time.Second)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// TestClient executes the client test suite, asserting the ability to backup
|
||||
|
@ -21,7 +21,7 @@ type DB interface {
|
||||
CreateTower(*lnwire.NetAddress) (*wtdb.Tower, error)
|
||||
|
||||
// LoadTower retrieves a tower by its tower ID.
|
||||
LoadTower(uint64) (*wtdb.Tower, error)
|
||||
LoadTower(wtdb.TowerID) (*wtdb.Tower, error)
|
||||
|
||||
// NextSessionKeyIndex reserves a new session key derivation index for a
|
||||
// particular tower id. The index is reserved for that tower until
|
||||
@ -29,7 +29,7 @@ type DB interface {
|
||||
// point a new index for that tower can be reserved. Multiple calls to
|
||||
// this method before CreateClientSession is invoked should return the
|
||||
// same index.
|
||||
NextSessionKeyIndex(uint64) (uint32, error)
|
||||
NextSessionKeyIndex(wtdb.TowerID) (uint32, error)
|
||||
|
||||
// CreateClientSession saves a newly negotiated client session to the
|
||||
// client's database. This enables the session to be used across
|
||||
@ -41,14 +41,17 @@ type DB interface {
|
||||
// still be able to accept state updates.
|
||||
ListClientSessions() (map[wtdb.SessionID]*wtdb.ClientSession, error)
|
||||
|
||||
// FetchChanPkScripts returns a map of all sweep pkscripts for
|
||||
// registered channels. This is used on startup to cache the sweep
|
||||
// pkscripts of registered channels in memory.
|
||||
FetchChanPkScripts() (map[lnwire.ChannelID][]byte, error)
|
||||
// FetchChanSummaries loads a mapping from all registered channels to
|
||||
// their channel summaries.
|
||||
FetchChanSummaries() (wtdb.ChannelSummaries, error)
|
||||
|
||||
// AddChanPkScript inserts a newly generated sweep pkscript for the
|
||||
// given channel.
|
||||
AddChanPkScript(lnwire.ChannelID, []byte) error
|
||||
// RegisterChannel registers a channel for use within the client
|
||||
// database. For now, all that is stored in the channel summary is the
|
||||
// sweep pkscript that we'd like any tower sweeps to pay into. In the
|
||||
// future, this will be extended to contain more info to allow the
|
||||
// client efficiently request historical states to be backed up under
|
||||
// the client's active policy.
|
||||
RegisterChannel(lnwire.ChannelID, []byte) error
|
||||
|
||||
// MarkBackupIneligible records that the state identified by the
|
||||
// (channel id, commit height) tuple was ineligible for being backed up
|
||||
@ -61,7 +64,7 @@ type DB interface {
|
||||
// hasn't been ACK'd by the tower. The sequence number of the update
|
||||
// should be exactly one greater than the existing entry, and less that
|
||||
// or equal to the session's MaxUpdates.
|
||||
CommitUpdate(id *wtdb.SessionID, seqNum uint16,
|
||||
CommitUpdate(id *wtdb.SessionID,
|
||||
update *wtdb.CommittedUpdate) (uint16, error)
|
||||
|
||||
// AckUpdate records an acknowledgment from the watchtower that the
|
||||
|
@ -417,14 +417,15 @@ func (n *sessionNegotiator) tryAddress(privKey *btcec.PrivateKey,
|
||||
privKey.PubKey(),
|
||||
)
|
||||
clientSession := &wtdb.ClientSession{
|
||||
TowerID: tower.ID,
|
||||
ClientSessionBody: wtdb.ClientSessionBody{
|
||||
TowerID: tower.ID,
|
||||
KeyIndex: keyIndex,
|
||||
Policy: n.cfg.Policy,
|
||||
RewardPkScript: rewardPkScript,
|
||||
},
|
||||
Tower: tower,
|
||||
KeyIndex: keyIndex,
|
||||
SessionPrivKey: privKey,
|
||||
ID: sessionID,
|
||||
Policy: n.cfg.Policy,
|
||||
SeqNum: 0,
|
||||
RewardPkScript: rewardPkScript,
|
||||
}
|
||||
|
||||
err = n.cfg.DB.CreateClientSession(clientSession)
|
||||
|
@ -3,7 +3,6 @@ package wtclient
|
||||
import (
|
||||
"container/list"
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -133,7 +132,11 @@ func newSessionQueue(cfg *sessionQueueConfig) *sessionQueue {
|
||||
}
|
||||
sq.queueCond = sync.NewCond(&sq.queueMtx)
|
||||
|
||||
sq.restoreCommittedUpdates()
|
||||
// The database should return them in sorted order, and session queue's
|
||||
// sequence number will be equal to that of the last committed update.
|
||||
for _, update := range sq.cfg.ClientSession.CommittedUpdates {
|
||||
sq.commitQueue.PushBack(update)
|
||||
}
|
||||
|
||||
return sq
|
||||
}
|
||||
@ -212,7 +215,7 @@ func (q *sessionQueue) AcceptTask(task *backupTask) (reserveStatus, bool) {
|
||||
//
|
||||
// TODO(conner): queue backups and retry with different session params.
|
||||
case reserveAvailable:
|
||||
err := task.bindSession(q.cfg.ClientSession)
|
||||
err := task.bindSession(&q.cfg.ClientSession.ClientSessionBody)
|
||||
if err != nil {
|
||||
q.queueCond.L.Unlock()
|
||||
log.Debugf("SessionQueue %s rejected backup chanid=%s "+
|
||||
@ -237,45 +240,6 @@ func (q *sessionQueue) AcceptTask(task *backupTask) (reserveStatus, bool) {
|
||||
return newStatus, true
|
||||
}
|
||||
|
||||
// updateWithSeqNum stores a CommittedUpdate with its assigned sequence number.
|
||||
// This allows committed updates to be sorted after a restart, and added to the
|
||||
// commitQueue in the proper order for delivery.
|
||||
type updateWithSeqNum struct {
|
||||
seqNum uint16
|
||||
update *wtdb.CommittedUpdate
|
||||
}
|
||||
|
||||
// restoreCommittedUpdates processes any CommittedUpdates loaded on startup by
|
||||
// sorting them in ascending order of sequence numbers and adding them to the
|
||||
// commitQueue. These will be sent before any pending updates are processed.
|
||||
func (q *sessionQueue) restoreCommittedUpdates() {
|
||||
committedUpdates := q.cfg.ClientSession.CommittedUpdates
|
||||
|
||||
// Construct and unordered slice of all committed updates with their
|
||||
// assigned sequence numbers.
|
||||
sortedUpdates := make([]updateWithSeqNum, 0, len(committedUpdates))
|
||||
for seqNum, update := range committedUpdates {
|
||||
sortedUpdates = append(sortedUpdates, updateWithSeqNum{
|
||||
seqNum: seqNum,
|
||||
update: update,
|
||||
})
|
||||
}
|
||||
|
||||
// Sort the resulting slice by increasing sequence number.
|
||||
sort.Slice(sortedUpdates, func(i, j int) bool {
|
||||
return sortedUpdates[i].seqNum < sortedUpdates[j].seqNum
|
||||
})
|
||||
|
||||
// Finally, add the sorted, committed updates to he commitQueue. These
|
||||
// updates will be prioritized before any new tasks are assigned to the
|
||||
// sessionQueue. The queue will begin uploading any tasks in the
|
||||
// commitQueue as soon as it is started, e.g. during client
|
||||
// initialization when detecting that this session has unacked updates.
|
||||
for _, update := range sortedUpdates {
|
||||
q.commitQueue.PushBack(update)
|
||||
}
|
||||
}
|
||||
|
||||
// sessionManager is the primary event loop for the sessionQueue, and is
|
||||
// responsible for encrypting and sending accepted tasks to the tower.
|
||||
func (q *sessionQueue) sessionManager() {
|
||||
@ -396,7 +360,7 @@ func (q *sessionQueue) drainBackups() {
|
||||
func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) {
|
||||
var (
|
||||
seqNum uint16
|
||||
update *wtdb.CommittedUpdate
|
||||
update wtdb.CommittedUpdate
|
||||
isLast bool
|
||||
isPending bool
|
||||
)
|
||||
@ -407,10 +371,9 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) {
|
||||
// If the commit queue is non-empty, parse the next committed update.
|
||||
case q.commitQueue.Len() > 0:
|
||||
next := q.commitQueue.Front()
|
||||
updateWithSeq := next.Value.(updateWithSeqNum)
|
||||
|
||||
seqNum = updateWithSeq.seqNum
|
||||
update = updateWithSeq.update
|
||||
update = next.Value.(wtdb.CommittedUpdate)
|
||||
seqNum = update.SeqNum
|
||||
|
||||
// If this is the last item in the commit queue and no items
|
||||
// exist in the pending queue, we will use the IsComplete flag
|
||||
@ -449,10 +412,13 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) {
|
||||
}
|
||||
// TODO(conner): special case other obscure errors
|
||||
|
||||
update = &wtdb.CommittedUpdate{
|
||||
BackupID: task.id,
|
||||
Hint: hint,
|
||||
EncryptedBlob: encBlob,
|
||||
update = wtdb.CommittedUpdate{
|
||||
SeqNum: seqNum,
|
||||
CommittedUpdateBody: wtdb.CommittedUpdateBody{
|
||||
BackupID: task.id,
|
||||
Hint: hint,
|
||||
EncryptedBlob: encBlob,
|
||||
},
|
||||
}
|
||||
|
||||
log.Debugf("Committing state update for session=%s seqnum=%d",
|
||||
@ -470,7 +436,7 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) {
|
||||
// we send the next time. This step ensures that if we reliably send the
|
||||
// same update for a given sequence number, to prevent us from thinking
|
||||
// we backed up a state when we instead backed up another.
|
||||
lastApplied, err := q.cfg.DB.CommitUpdate(q.ID(), seqNum, update)
|
||||
lastApplied, err := q.cfg.DB.CommitUpdate(q.ID(), &update)
|
||||
if err != nil {
|
||||
// TODO(conner): mark failed/reschedule
|
||||
return nil, false, fmt.Errorf("unable to commit state update "+
|
||||
@ -478,7 +444,7 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) {
|
||||
}
|
||||
|
||||
stateUpdate := &wtwire.StateUpdate{
|
||||
SeqNum: seqNum,
|
||||
SeqNum: update.SeqNum,
|
||||
LastApplied: lastApplied,
|
||||
Hint: update.Hint,
|
||||
EncryptedBlob: update.EncryptedBlob,
|
||||
|
32
watchtower/wtdb/client_chan_summary.go
Normal file
32
watchtower/wtdb/client_chan_summary.go
Normal file
@ -0,0 +1,32 @@
|
||||
package wtdb
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
)
|
||||
|
||||
// ChannelSummaries is a map for a given channel id to it's ClientChanSummary.
|
||||
type ChannelSummaries map[lnwire.ChannelID]ClientChanSummary
|
||||
|
||||
// ClientChanSummary tracks channel-specific information. A new
|
||||
// ClientChanSummary is inserted in the database the first time the client
|
||||
// encounters a particular channel.
|
||||
type ClientChanSummary struct {
|
||||
// SweepPkScript is the pkscript to which all justice transactions will
|
||||
// deposit recovered funds for this particular channel.
|
||||
SweepPkScript []byte
|
||||
|
||||
// TODO(conner): later extend with info about initial commit height,
|
||||
// ineligible states, etc.
|
||||
}
|
||||
|
||||
// Encode writes the ClientChanSummary to the passed io.Writer.
|
||||
func (s *ClientChanSummary) Encode(w io.Writer) error {
|
||||
return WriteElement(w, s.SweepPkScript)
|
||||
}
|
||||
|
||||
// Decode reads a ClientChanSummary form the passed io.Reader.
|
||||
func (s *ClientChanSummary) Decode(r io.Reader) error {
|
||||
return ReadElement(r, &s.SweepPkScript)
|
||||
}
|
908
watchtower/wtdb/client_db.go
Normal file
908
watchtower/wtdb/client_db.go
Normal file
@ -0,0 +1,908 @@
|
||||
package wtdb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
|
||||
"github.com/coreos/bbolt"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
)
|
||||
|
||||
const (
|
||||
// clientDBName is the filename of client database.
|
||||
clientDBName = "wtclient.db"
|
||||
)
|
||||
|
||||
var (
|
||||
// cSessionKeyIndexBkt is a top-level bucket storing:
|
||||
// tower-id -> reserved-session-key-index (uint32).
|
||||
cSessionKeyIndexBkt = []byte("client-session-key-index-bucket")
|
||||
|
||||
// cChanSummaryBkt is a top-level bucket storing:
|
||||
// channel-id -> encoded ClientChanSummary.
|
||||
cChanSummaryBkt = []byte("client-channel-summary-bucket")
|
||||
|
||||
// cSessionBkt is a top-level bucket storing:
|
||||
// session-id => cSessionBody -> encoded ClientSessionBody
|
||||
// => cSessionCommits => seqnum -> encoded CommittedUpdate
|
||||
// => cSessionAcks => seqnum -> encoded BackupID
|
||||
cSessionBkt = []byte("client-session-bucket")
|
||||
|
||||
// cSessionBody is a sub-bucket of cSessionBkt storing only the body of
|
||||
// the ClientSession.
|
||||
cSessionBody = []byte("client-session-body")
|
||||
|
||||
// cSessionBody is a sub-bucket of cSessionBkt storing:
|
||||
// seqnum -> encoded CommittedUpdate.
|
||||
cSessionCommits = []byte("client-session-commits")
|
||||
|
||||
// cSessionAcks is a sub-bucket of cSessionBkt storing:
|
||||
// seqnum -> encoded BackupID.
|
||||
cSessionAcks = []byte("client-session-acks")
|
||||
|
||||
// cTowerBkt is a top-level bucket storing:
|
||||
// tower-id -> encoded Tower.
|
||||
cTowerBkt = []byte("client-tower-bucket")
|
||||
|
||||
// cTowerIndexBkt is a top-level bucket storing:
|
||||
// tower-pubkey -> tower-id.
|
||||
cTowerIndexBkt = []byte("client-tower-index-bucket")
|
||||
|
||||
// ErrTowerNotFound signals that the target tower was not found in the
|
||||
// database.
|
||||
ErrTowerNotFound = errors.New("tower not found")
|
||||
|
||||
// ErrCorruptClientSession signals that the client session's on-disk
|
||||
// structure deviates from what is expected.
|
||||
ErrCorruptClientSession = errors.New("client session corrupted")
|
||||
|
||||
// ErrClientSessionAlreadyExists signals an attempt to reinsert a client
|
||||
// session that has already been created.
|
||||
ErrClientSessionAlreadyExists = errors.New(
|
||||
"client session already exists",
|
||||
)
|
||||
|
||||
// ErrChannelAlreadyRegistered signals a duplicate attempt to register a
|
||||
// channel with the client database.
|
||||
ErrChannelAlreadyRegistered = errors.New("channel already registered")
|
||||
|
||||
// ErrChannelNotRegistered signals a channel has not yet been registered
|
||||
// in the client database.
|
||||
ErrChannelNotRegistered = errors.New("channel not registered")
|
||||
|
||||
// ErrClientSessionNotFound signals that the requested client session
|
||||
// was not found in the database.
|
||||
ErrClientSessionNotFound = errors.New("client session not found")
|
||||
|
||||
// ErrUpdateAlreadyCommitted signals that the chosen sequence number has
|
||||
// already been committed to an update with a different breach hint.
|
||||
ErrUpdateAlreadyCommitted = errors.New("update already committed")
|
||||
|
||||
// ErrCommitUnorderedUpdate signals the client tried to commit a
|
||||
// sequence number other than the next unallocated sequence number.
|
||||
ErrCommitUnorderedUpdate = errors.New("update seqnum not monotonic")
|
||||
|
||||
// ErrCommittedUpdateNotFound signals that the tower tried to ACK a
|
||||
// sequence number that has not yet been allocated by the client.
|
||||
ErrCommittedUpdateNotFound = errors.New("committed update not found")
|
||||
|
||||
// ErrUnallocatedLastApplied signals that the tower tried to provide a
|
||||
// LastApplied value greater than any allocated sequence number.
|
||||
ErrUnallocatedLastApplied = errors.New("tower echoed last appiled " +
|
||||
"greater than allocated seqnum")
|
||||
|
||||
// ErrNoReservedKeyIndex signals that a client session could not be
|
||||
// created because no session key index was reserved.
|
||||
ErrNoReservedKeyIndex = errors.New("key index not reserved")
|
||||
|
||||
// ErrIncorrectKeyIndex signals that the client session could not be
|
||||
// created because session key index differs from the reserved key
|
||||
// index.
|
||||
ErrIncorrectKeyIndex = errors.New("incorrect key index")
|
||||
)
|
||||
|
||||
// ClientDB is single database providing a persistent storage engine for the
|
||||
// wtclient.
|
||||
type ClientDB struct {
|
||||
db *bbolt.DB
|
||||
dbPath string
|
||||
}
|
||||
|
||||
// OpenClientDB opens the client database given the path to the database's
|
||||
// directory. If no such database exists, this method will initialize a fresh
|
||||
// one using the latest version number and bucket structure. If a database
|
||||
// exists but has a lower version number than the current version, any necessary
|
||||
// migrations will be applied before returning. Any attempt to open a database
|
||||
// with a version number higher that the latest version will fail to prevent
|
||||
// accidental reversion.
|
||||
func OpenClientDB(dbPath string) (*ClientDB, error) {
|
||||
bdb, firstInit, err := createDBIfNotExist(dbPath, clientDBName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clientDB := &ClientDB{
|
||||
db: bdb,
|
||||
dbPath: dbPath,
|
||||
}
|
||||
|
||||
err = initOrSyncVersions(clientDB, firstInit, clientDBVersions)
|
||||
if err != nil {
|
||||
bdb.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Now that the database version fully consistent with our latest known
|
||||
// version, ensure that all top-level buckets known to this version are
|
||||
// initialized. This allows us to assume their presence throughout all
|
||||
// operations. If an known top-level bucket is expected to exist but is
|
||||
// missing, this will trigger a ErrUninitializedDB error.
|
||||
err = clientDB.db.Update(initClientDBBuckets)
|
||||
if err != nil {
|
||||
bdb.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return clientDB, nil
|
||||
}
|
||||
|
||||
// initClientDBBuckets creates all top-level buckets required to handle database
|
||||
// operations required by the latest version.
|
||||
func initClientDBBuckets(tx *bbolt.Tx) error {
|
||||
buckets := [][]byte{
|
||||
cSessionKeyIndexBkt,
|
||||
cChanSummaryBkt,
|
||||
cSessionBkt,
|
||||
cTowerBkt,
|
||||
cTowerIndexBkt,
|
||||
}
|
||||
|
||||
for _, bucket := range buckets {
|
||||
_, err := tx.CreateBucketIfNotExists(bucket)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// bdb returns the backing bbolt.DB instance.
|
||||
//
|
||||
// NOTE: Part of the versionedDB interface.
|
||||
func (c *ClientDB) bdb() *bbolt.DB {
|
||||
return c.db
|
||||
}
|
||||
|
||||
// Version returns the database's current version number.
|
||||
//
|
||||
// NOTE: Part of the versionedDB interface.
|
||||
func (c *ClientDB) Version() (uint32, error) {
|
||||
var version uint32
|
||||
err := c.db.View(func(tx *bbolt.Tx) error {
|
||||
var err error
|
||||
version, err = getDBVersion(tx)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return version, nil
|
||||
}
|
||||
|
||||
// Close closes the underlying database.
|
||||
func (c *ClientDB) Close() error {
|
||||
return c.db.Close()
|
||||
}
|
||||
|
||||
// CreateTower initializes a database entry with the given lightning address. If
|
||||
// the tower exists, the address is append to the list of all addresses used to
|
||||
// that tower previously.
|
||||
func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
|
||||
var towerPubKey [33]byte
|
||||
copy(towerPubKey[:], lnAddr.IdentityKey.SerializeCompressed())
|
||||
|
||||
var tower *Tower
|
||||
err := c.db.Update(func(tx *bbolt.Tx) error {
|
||||
towerIndex := tx.Bucket(cTowerIndexBkt)
|
||||
if towerIndex == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
towers := tx.Bucket(cTowerBkt)
|
||||
if towers == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
// Check if the tower index already knows of this pubkey.
|
||||
towerIDBytes := towerIndex.Get(towerPubKey[:])
|
||||
if len(towerIDBytes) == 8 {
|
||||
// The tower already exists, deserialize the existing
|
||||
// record.
|
||||
var err error
|
||||
tower, err = getTower(towers, towerIDBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add the new address to the existing tower. If the
|
||||
// address is a duplicate, this will result in no
|
||||
// change.
|
||||
tower.AddAddress(lnAddr.Address)
|
||||
} else {
|
||||
// No such tower exists, create a new tower id for our
|
||||
// new tower. The error is unhandled since NextSequence
|
||||
// never fails in an Update.
|
||||
towerID, _ := towerIndex.NextSequence()
|
||||
|
||||
tower = &Tower{
|
||||
ID: TowerID(towerID),
|
||||
IdentityKey: lnAddr.IdentityKey,
|
||||
Addresses: []net.Addr{lnAddr.Address},
|
||||
}
|
||||
|
||||
towerIDBytes = tower.ID.Bytes()
|
||||
|
||||
// Since this tower is new, record the mapping from
|
||||
// tower pubkey to tower id in the tower index.
|
||||
err := towerIndex.Put(towerPubKey[:], towerIDBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Store the new or updated tower under its tower id.
|
||||
return putTower(towers, tower)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tower, nil
|
||||
}
|
||||
|
||||
// LoadTower retrieves a tower by its tower ID.
|
||||
func (c *ClientDB) LoadTower(towerID TowerID) (*Tower, error) {
|
||||
var tower *Tower
|
||||
err := c.db.View(func(tx *bbolt.Tx) error {
|
||||
towers := tx.Bucket(cTowerBkt)
|
||||
if towers == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
var err error
|
||||
tower, err = getTower(towers, towerID.Bytes())
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tower, nil
|
||||
}
|
||||
|
||||
// NextSessionKeyIndex reserves a new session key derivation index for a
|
||||
// particular tower id. The index is reserved for that tower until
|
||||
// CreateClientSession is invoked for that tower and index, at which point a new
|
||||
// index for that tower can be reserved. Multiple calls to this method before
|
||||
// CreateClientSession is invoked should return the same index.
|
||||
func (c *ClientDB) NextSessionKeyIndex(towerID TowerID) (uint32, error) {
|
||||
var index uint32
|
||||
err := c.db.Update(func(tx *bbolt.Tx) error {
|
||||
keyIndex := tx.Bucket(cSessionKeyIndexBkt)
|
||||
if keyIndex == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
// Check the session key index to see if a key has already been
|
||||
// reserved for this tower. If so, we'll deserialize and return
|
||||
// the index directly.
|
||||
towerIDBytes := towerID.Bytes()
|
||||
indexBytes := keyIndex.Get(towerIDBytes)
|
||||
if len(indexBytes) == 4 {
|
||||
index = byteOrder.Uint32(indexBytes)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Otherwise, generate a new session key index since the node
|
||||
// doesn't already have reserved index. The error is ignored
|
||||
// since NextSequence can't fail inside Update.
|
||||
index64, _ := keyIndex.NextSequence()
|
||||
|
||||
// As a sanity check, assert that the index is still in the
|
||||
// valid range of unhardened pubkeys. In the future, we should
|
||||
// move to only using hardened keys, and this will prevent any
|
||||
// overlap from occurring until then. This also prevents us from
|
||||
// overflowing uint32s.
|
||||
if index64 > math.MaxInt32 {
|
||||
return fmt.Errorf("exhausted session key indexes")
|
||||
}
|
||||
|
||||
index = uint32(index64)
|
||||
|
||||
var indexBuf [4]byte
|
||||
byteOrder.PutUint32(indexBuf[:], index)
|
||||
|
||||
// Record the reserved session key index under this tower's id.
|
||||
return keyIndex.Put(towerIDBytes, indexBuf[:])
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return index, nil
|
||||
}
|
||||
|
||||
// CreateClientSession records a newly negotiated client session in the set of
|
||||
// active sessions. The session can be identified by its SessionID.
|
||||
func (c *ClientDB) CreateClientSession(session *ClientSession) error {
|
||||
return c.db.Update(func(tx *bbolt.Tx) error {
|
||||
keyIndexes := tx.Bucket(cSessionKeyIndexBkt)
|
||||
if keyIndexes == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
sessions := tx.Bucket(cSessionBkt)
|
||||
if sessions == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
// Check that client session with this session id doesn't
|
||||
// already exist.
|
||||
existingSessionBytes := sessions.Bucket(session.ID[:])
|
||||
if existingSessionBytes != nil {
|
||||
return ErrClientSessionAlreadyExists
|
||||
}
|
||||
|
||||
// Check that this tower has a reserved key index.
|
||||
towerIDBytes := session.TowerID.Bytes()
|
||||
keyIndexBytes := keyIndexes.Get(towerIDBytes)
|
||||
if len(keyIndexBytes) != 4 {
|
||||
return ErrNoReservedKeyIndex
|
||||
}
|
||||
|
||||
// Assert that the key index of the inserted session matches the
|
||||
// reserved session key index.
|
||||
index := byteOrder.Uint32(keyIndexBytes)
|
||||
if index != session.KeyIndex {
|
||||
return ErrIncorrectKeyIndex
|
||||
}
|
||||
|
||||
// Remove the key index reservation.
|
||||
err := keyIndexes.Delete(towerIDBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Finally, write the client session's body in the sessions
|
||||
// bucket.
|
||||
return putClientSessionBody(sessions, session)
|
||||
})
|
||||
}
|
||||
|
||||
// ListClientSessions returns the set of all client sessions known to the db.
|
||||
func (c *ClientDB) ListClientSessions() (map[SessionID]*ClientSession, error) {
|
||||
clientSessions := make(map[SessionID]*ClientSession)
|
||||
err := c.db.View(func(tx *bbolt.Tx) error {
|
||||
sessions := tx.Bucket(cSessionBkt)
|
||||
if sessions == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
return sessions.ForEach(func(k, _ []byte) error {
|
||||
// We'll load the full client session since the client
|
||||
// will need the CommittedUpdates and AckedUpdates on
|
||||
// startup to resume committed updates and compute the
|
||||
// highest known commit height for each channel.
|
||||
session, err := getClientSession(sessions, k)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
clientSessions[session.ID] = session
|
||||
|
||||
return nil
|
||||
})
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return clientSessions, nil
|
||||
}
|
||||
|
||||
// FetchChanSummaries loads a mapping from all registered channels to their
|
||||
// channel summaries.
|
||||
func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) {
|
||||
summaries := make(map[lnwire.ChannelID]ClientChanSummary)
|
||||
err := c.db.View(func(tx *bbolt.Tx) error {
|
||||
chanSummaries := tx.Bucket(cChanSummaryBkt)
|
||||
if chanSummaries == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
return chanSummaries.ForEach(func(k, v []byte) error {
|
||||
var chanID lnwire.ChannelID
|
||||
copy(chanID[:], k)
|
||||
|
||||
var summary ClientChanSummary
|
||||
err := summary.Decode(bytes.NewReader(v))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
summaries[chanID] = summary
|
||||
|
||||
return nil
|
||||
})
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return summaries, nil
|
||||
}
|
||||
|
||||
// RegisterChannel registers a channel for use within the client database. For
|
||||
// now, all that is stored in the channel summary is the sweep pkscript that
|
||||
// we'd like any tower sweeps to pay into. In the future, this will be extended
|
||||
// to contain more info to allow the client efficiently request historical
|
||||
// states to be backed up under the client's active policy.
|
||||
func (c *ClientDB) RegisterChannel(chanID lnwire.ChannelID,
|
||||
sweepPkScript []byte) error {
|
||||
|
||||
return c.db.Update(func(tx *bbolt.Tx) error {
|
||||
chanSummaries := tx.Bucket(cChanSummaryBkt)
|
||||
if chanSummaries == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
_, err := getChanSummary(chanSummaries, chanID)
|
||||
switch {
|
||||
|
||||
// Summary already exists.
|
||||
case err == nil:
|
||||
return ErrChannelAlreadyRegistered
|
||||
|
||||
// Channel is not registered, proceed with registration.
|
||||
case err == ErrChannelNotRegistered:
|
||||
|
||||
// Unexpected error.
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
summary := ClientChanSummary{
|
||||
SweepPkScript: sweepPkScript,
|
||||
}
|
||||
|
||||
return putChanSummary(chanSummaries, chanID, &summary)
|
||||
})
|
||||
}
|
||||
|
||||
// MarkBackupIneligible records that the state identified by the (channel id,
|
||||
// commit height) tuple was ineligible for being backed up under the current
|
||||
// policy. This state can be retried later under a different policy.
|
||||
func (c *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID,
|
||||
commitHeight uint64) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CommitUpdate persists the CommittedUpdate provided in the slot for (session,
|
||||
// seqNum). This allows the client to retransmit this update on startup.
|
||||
func (c *ClientDB) CommitUpdate(id *SessionID,
|
||||
update *CommittedUpdate) (uint16, error) {
|
||||
|
||||
var lastApplied uint16
|
||||
err := c.db.Update(func(tx *bbolt.Tx) error {
|
||||
sessions := tx.Bucket(cSessionBkt)
|
||||
if sessions == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
// We'll only load the ClientSession body for performance, since
|
||||
// we primarily need to inspect its SeqNum and TowerLastApplied
|
||||
// fields. The CommittedUpdates will be modified on disk
|
||||
// directly.
|
||||
session, err := getClientSessionBody(sessions, id[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Can't fail if the above didn't fail.
|
||||
sessionBkt := sessions.Bucket(id[:])
|
||||
|
||||
// Ensure the session commits sub-bucket is initialized.
|
||||
sessionCommits, err := sessionBkt.CreateBucketIfNotExists(
|
||||
cSessionCommits,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var seqNumBuf [2]byte
|
||||
byteOrder.PutUint16(seqNumBuf[:], update.SeqNum)
|
||||
|
||||
// Check to see if a committed update already exists for this
|
||||
// sequence number.
|
||||
committedUpdateBytes := sessionCommits.Get(seqNumBuf[:])
|
||||
if committedUpdateBytes != nil {
|
||||
var dbUpdate CommittedUpdate
|
||||
err := dbUpdate.Decode(
|
||||
bytes.NewReader(committedUpdateBytes),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If an existing committed update has a different hint,
|
||||
// we'll reject this newer update.
|
||||
if dbUpdate.Hint != update.Hint {
|
||||
return ErrUpdateAlreadyCommitted
|
||||
}
|
||||
|
||||
// Otherwise, capture the last applied value and
|
||||
// succeed.
|
||||
lastApplied = session.TowerLastApplied
|
||||
return nil
|
||||
}
|
||||
|
||||
// There's no committed update for this sequence number, ensure
|
||||
// that we are committing the next unallocated one.
|
||||
if update.SeqNum != session.SeqNum+1 {
|
||||
return ErrCommitUnorderedUpdate
|
||||
}
|
||||
|
||||
// Increment the session's sequence number and store the updated
|
||||
// client session.
|
||||
//
|
||||
// TODO(conner): split out seqnum and last applied own bucket to
|
||||
// eliminate serialization of full struct during CommitUpdate?
|
||||
// Can also read/write directly to byes [:2] without migration.
|
||||
session.SeqNum++
|
||||
err = putClientSessionBody(sessions, session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Encode and store the committed update in the sessionCommits
|
||||
// sub-bucket under the requested sequence number.
|
||||
var b bytes.Buffer
|
||||
err = update.Encode(&b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = sessionCommits.Put(seqNumBuf[:], b.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Finally, capture the session's last applied value so it can
|
||||
// be sent in the next state update to the tower.
|
||||
lastApplied = session.TowerLastApplied
|
||||
|
||||
return nil
|
||||
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return lastApplied, nil
|
||||
}
|
||||
|
||||
// AckUpdate persists an acknowledgment for a given (session, seqnum) pair. This
|
||||
// removes the update from the set of committed updates, and validates the
|
||||
// lastApplied value returned from the tower.
|
||||
func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
|
||||
lastApplied uint16) error {
|
||||
|
||||
return c.db.Update(func(tx *bbolt.Tx) error {
|
||||
sessions := tx.Bucket(cSessionBkt)
|
||||
if sessions == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
// We'll only load the ClientSession body for performance, since
|
||||
// we primarily need to inspect its SeqNum and TowerLastApplied
|
||||
// fields. The CommittedUpdates and AckedUpdates will be
|
||||
// modified on disk directly.
|
||||
session, err := getClientSessionBody(sessions, id[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If the tower has acked a sequence number beyond our highest
|
||||
// sequence number, fail.
|
||||
if lastApplied > session.SeqNum {
|
||||
return ErrUnallocatedLastApplied
|
||||
}
|
||||
|
||||
// If the tower acked with a lower sequence number than it gave
|
||||
// us prior, fail.
|
||||
if lastApplied < session.TowerLastApplied {
|
||||
return ErrLastAppliedReversion
|
||||
}
|
||||
|
||||
// TODO(conner): split out seqnum and last applied own bucket to
|
||||
// eliminate serialization of full struct during AckUpdate? Can
|
||||
// also read/write directly to byes [2:4] without migration.
|
||||
session.TowerLastApplied = lastApplied
|
||||
|
||||
// Write the client session with the updated last applied value.
|
||||
err = putClientSessionBody(sessions, session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Can't fail because of getClientSession succeeded.
|
||||
sessionBkt := sessions.Bucket(id[:])
|
||||
|
||||
// If the commits sub-bucket doesn't exist, there can't possibly
|
||||
// be a corresponding committed update to remove.
|
||||
sessionCommits := sessionBkt.Bucket(cSessionCommits)
|
||||
if sessionCommits == nil {
|
||||
return ErrCommittedUpdateNotFound
|
||||
}
|
||||
|
||||
var seqNumBuf [2]byte
|
||||
byteOrder.PutUint16(seqNumBuf[:], seqNum)
|
||||
|
||||
// Assert that a committed update exists for this sequence
|
||||
// number.
|
||||
committedUpdateBytes := sessionCommits.Get(seqNumBuf[:])
|
||||
if committedUpdateBytes == nil {
|
||||
return ErrCommittedUpdateNotFound
|
||||
}
|
||||
|
||||
var committedUpdate CommittedUpdate
|
||||
err = committedUpdate.Decode(
|
||||
bytes.NewReader(committedUpdateBytes),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove the corresponding committed update.
|
||||
err = sessionCommits.Delete(seqNumBuf[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure that the session acks sub-bucket is initialized so we
|
||||
// can insert an entry.
|
||||
sessionAcks, err := sessionBkt.CreateBucketIfNotExists(
|
||||
cSessionAcks,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// The session acks only need to track the backup id of the
|
||||
// update, so we can discard the blob and hint.
|
||||
var b bytes.Buffer
|
||||
err = committedUpdate.BackupID.Encode(&b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Finally, insert the ack into the sessionAcks sub-bucket.
|
||||
return sessionAcks.Put(seqNumBuf[:], b.Bytes())
|
||||
})
|
||||
}
|
||||
|
||||
// getClientSessionBody loads the body of a ClientSession from the sessions
|
||||
// bucket corresponding to the serialized session id. This does not deserialize
|
||||
// the CommittedUpdates or AckUpdates associated with the session. If the caller
|
||||
// requires this info, use getClientSession.
|
||||
func getClientSessionBody(sessions *bbolt.Bucket,
|
||||
idBytes []byte) (*ClientSession, error) {
|
||||
|
||||
sessionBkt := sessions.Bucket(idBytes)
|
||||
if sessionBkt == nil {
|
||||
return nil, ErrClientSessionNotFound
|
||||
}
|
||||
|
||||
// Should never have a sessionBkt without also having its body.
|
||||
sessionBody := sessionBkt.Get(cSessionBody)
|
||||
if sessionBody == nil {
|
||||
return nil, ErrCorruptClientSession
|
||||
}
|
||||
|
||||
var session ClientSession
|
||||
copy(session.ID[:], idBytes)
|
||||
|
||||
err := session.Decode(bytes.NewReader(sessionBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// getClientSession loads the full ClientSession associated with the serialized
|
||||
// session id. This method populates the CommittedUpdates and AckUpdates in
|
||||
// addition to the ClientSession's body.
|
||||
func getClientSession(sessions *bbolt.Bucket,
|
||||
idBytes []byte) (*ClientSession, error) {
|
||||
|
||||
session, err := getClientSessionBody(sessions, idBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fetch the committed updates for this session.
|
||||
commitedUpdates, err := getClientSessionCommits(sessions, idBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fetch the acked updates for this session.
|
||||
ackedUpdates, err := getClientSessionAcks(sessions, idBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session.CommittedUpdates = commitedUpdates
|
||||
session.AckedUpdates = ackedUpdates
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// getClientSessionCommits retrieves all committed updates for the session
|
||||
// identified by the serialized session id.
|
||||
func getClientSessionCommits(sessions *bbolt.Bucket,
|
||||
idBytes []byte) ([]CommittedUpdate, error) {
|
||||
|
||||
// Can't fail because client session body has already been read.
|
||||
sessionBkt := sessions.Bucket(idBytes)
|
||||
|
||||
// Initialize commitedUpdates so that we can return an initialized map
|
||||
// if no committed updates exist.
|
||||
committedUpdates := make([]CommittedUpdate, 0)
|
||||
|
||||
sessionCommits := sessionBkt.Bucket(cSessionCommits)
|
||||
if sessionCommits == nil {
|
||||
return committedUpdates, nil
|
||||
}
|
||||
|
||||
err := sessionCommits.ForEach(func(k, v []byte) error {
|
||||
var committedUpdate CommittedUpdate
|
||||
err := committedUpdate.Decode(bytes.NewReader(v))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
committedUpdate.SeqNum = byteOrder.Uint16(k)
|
||||
|
||||
committedUpdates = append(committedUpdates, committedUpdate)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return committedUpdates, nil
|
||||
}
|
||||
|
||||
// getClientSessionAcks retrieves all acked updates for the session identified
|
||||
// by the serialized session id.
|
||||
func getClientSessionAcks(sessions *bbolt.Bucket,
|
||||
idBytes []byte) (map[uint16]BackupID, error) {
|
||||
|
||||
// Can't fail because client session body has already been read.
|
||||
sessionBkt := sessions.Bucket(idBytes)
|
||||
|
||||
// Initialize ackedUpdates so that we can return an initialized map if
|
||||
// no acked updates exist.
|
||||
ackedUpdates := make(map[uint16]BackupID)
|
||||
|
||||
sessionAcks := sessionBkt.Bucket(cSessionAcks)
|
||||
if sessionAcks == nil {
|
||||
return ackedUpdates, nil
|
||||
}
|
||||
|
||||
err := sessionAcks.ForEach(func(k, v []byte) error {
|
||||
seqNum := byteOrder.Uint16(k)
|
||||
|
||||
var backupID BackupID
|
||||
err := backupID.Decode(bytes.NewReader(v))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ackedUpdates[seqNum] = backupID
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ackedUpdates, nil
|
||||
}
|
||||
|
||||
// putClientSessionBody stores the body of the ClientSession (everything but the
|
||||
// CommittedUpdates and AckedUpdates).
|
||||
func putClientSessionBody(sessions *bbolt.Bucket,
|
||||
session *ClientSession) error {
|
||||
|
||||
sessionBkt, err := sessions.CreateBucketIfNotExists(session.ID[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
err = session.Encode(&b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return sessionBkt.Put(cSessionBody, b.Bytes())
|
||||
}
|
||||
|
||||
// getChanSummary loads a ClientChanSummary for the passed chanID.
|
||||
func getChanSummary(chanSummaries *bbolt.Bucket,
|
||||
chanID lnwire.ChannelID) (*ClientChanSummary, error) {
|
||||
|
||||
chanSummaryBytes := chanSummaries.Get(chanID[:])
|
||||
if chanSummaryBytes == nil {
|
||||
return nil, ErrChannelNotRegistered
|
||||
}
|
||||
|
||||
var summary ClientChanSummary
|
||||
err := summary.Decode(bytes.NewReader(chanSummaryBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &summary, nil
|
||||
}
|
||||
|
||||
// putChanSummary stores a ClientChanSummary for the passed chanID.
|
||||
func putChanSummary(chanSummaries *bbolt.Bucket, chanID lnwire.ChannelID,
|
||||
summary *ClientChanSummary) error {
|
||||
|
||||
var b bytes.Buffer
|
||||
err := summary.Encode(&b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return chanSummaries.Put(chanID[:], b.Bytes())
|
||||
}
|
||||
|
||||
// getTower loads a Tower identified by its serialized tower id.
|
||||
func getTower(towers *bbolt.Bucket, id []byte) (*Tower, error) {
|
||||
towerBytes := towers.Get(id)
|
||||
if towerBytes == nil {
|
||||
return nil, ErrTowerNotFound
|
||||
}
|
||||
|
||||
var tower Tower
|
||||
err := tower.Decode(bytes.NewReader(towerBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tower.ID = TowerIDFromBytes(id)
|
||||
|
||||
return &tower, nil
|
||||
}
|
||||
|
||||
// putTower stores a Tower identified by its serialized tower id.
|
||||
func putTower(towers *bbolt.Bucket, tower *Tower) error {
|
||||
var b bytes.Buffer
|
||||
err := tower.Encode(&b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return towers.Put(tower.ID.Bytes(), b.Bytes())
|
||||
}
|
688
watchtower/wtdb/client_db_test.go
Normal file
688
watchtower/wtdb/client_db_test.go
Normal file
@ -0,0 +1,688 @@
|
||||
package wtdb_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
crand "crypto/rand"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/watchtower/blob"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtclient"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtmock"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
||||
)
|
||||
|
||||
// clientDBInit is a closure used to initialize a wtclient.DB instance its
|
||||
// cleanup function.
|
||||
type clientDBInit func(t *testing.T) (wtclient.DB, func())
|
||||
|
||||
type clientDBHarness struct {
|
||||
t *testing.T
|
||||
db wtclient.DB
|
||||
}
|
||||
|
||||
func newClientDBHarness(t *testing.T, init clientDBInit) (*clientDBHarness, func()) {
|
||||
db, cleanup := init(t)
|
||||
|
||||
h := &clientDBHarness{
|
||||
t: t,
|
||||
db: db,
|
||||
}
|
||||
|
||||
return h, cleanup
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) insertSession(session *wtdb.ClientSession, expErr error) {
|
||||
h.t.Helper()
|
||||
|
||||
err := h.db.CreateClientSession(session)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected create client session error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) listSessions() map[wtdb.SessionID]*wtdb.ClientSession {
|
||||
h.t.Helper()
|
||||
|
||||
sessions, err := h.db.ListClientSessions()
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to list client sessions: %v", err)
|
||||
}
|
||||
|
||||
return sessions
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) nextKeyIndex(id wtdb.TowerID, expErr error) uint32 {
|
||||
h.t.Helper()
|
||||
|
||||
index, err := h.db.NextSessionKeyIndex(id)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected next session key index error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
|
||||
if index == 0 {
|
||||
h.t.Fatalf("next key index should never be 0")
|
||||
}
|
||||
|
||||
return index
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) createTower(lnAddr *lnwire.NetAddress,
|
||||
expErr error) *wtdb.Tower {
|
||||
|
||||
h.t.Helper()
|
||||
|
||||
tower, err := h.db.CreateTower(lnAddr)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected create tower error: %v, got: %v", expErr, err)
|
||||
}
|
||||
|
||||
if tower.ID == 0 {
|
||||
h.t.Fatalf("tower id should never be 0")
|
||||
}
|
||||
|
||||
return tower
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) loadTower(id wtdb.TowerID, expErr error) *wtdb.Tower {
|
||||
h.t.Helper()
|
||||
|
||||
tower, err := h.db.LoadTower(id)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected load tower error: %v, got: %v", expErr, err)
|
||||
}
|
||||
|
||||
return tower
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) fetchChanSummaries() map[lnwire.ChannelID]wtdb.ClientChanSummary {
|
||||
h.t.Helper()
|
||||
|
||||
summaries, err := h.db.FetchChanSummaries()
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to fetch chan summaries: %v", err)
|
||||
}
|
||||
|
||||
return summaries
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) registerChan(chanID lnwire.ChannelID,
|
||||
sweepPkScript []byte, expErr error) {
|
||||
|
||||
h.t.Helper()
|
||||
|
||||
err := h.db.RegisterChannel(chanID, sweepPkScript)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected register channel error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) commitUpdate(id *wtdb.SessionID,
|
||||
update *wtdb.CommittedUpdate, expErr error) uint16 {
|
||||
|
||||
h.t.Helper()
|
||||
|
||||
lastApplied, err := h.db.CommitUpdate(id, update)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected commit update error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
|
||||
return lastApplied
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16,
|
||||
lastApplied uint16, expErr error) {
|
||||
|
||||
h.t.Helper()
|
||||
|
||||
err := h.db.AckUpdate(id, seqNum, lastApplied)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected commit update error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
// testCreateClientSession asserts various conditions regarding the creation of
|
||||
// a new ClientSession. The test asserts:
|
||||
// - client sessions can only be created if a session key index is reserved.
|
||||
// - client sessions cannot be created with an incorrect session key index .
|
||||
// - inserting duplicate sessions fails.
|
||||
func testCreateClientSession(h *clientDBHarness) {
|
||||
// Create a test client session to insert.
|
||||
session := &wtdb.ClientSession{
|
||||
ClientSessionBody: wtdb.ClientSessionBody{
|
||||
TowerID: wtdb.TowerID(3),
|
||||
Policy: wtpolicy.Policy{
|
||||
MaxUpdates: 100,
|
||||
},
|
||||
RewardPkScript: []byte{0x01, 0x02, 0x03},
|
||||
},
|
||||
ID: wtdb.SessionID([33]byte{0x01}),
|
||||
}
|
||||
|
||||
// First, assert that this session is not already present in the
|
||||
// database.
|
||||
if _, ok := h.listSessions()[session.ID]; ok {
|
||||
h.t.Fatalf("session for id %x should not exist yet", session.ID)
|
||||
}
|
||||
|
||||
// Attempting to insert the client session without reserving a session
|
||||
// key index should fail.
|
||||
h.insertSession(session, wtdb.ErrNoReservedKeyIndex)
|
||||
|
||||
// Now, reserve a session key for this tower.
|
||||
keyIndex := h.nextKeyIndex(session.TowerID, nil)
|
||||
|
||||
// The client session hasn't been updated with the reserved key index
|
||||
// (since it's still zero). Inserting should fail due to the mismatch.
|
||||
h.insertSession(session, wtdb.ErrIncorrectKeyIndex)
|
||||
|
||||
// Reserve another key for the same index. Since no session has been
|
||||
// successfully created, it should return the same index to maintain
|
||||
// idempotency across restarts.
|
||||
keyIndex2 := h.nextKeyIndex(session.TowerID, nil)
|
||||
if keyIndex != keyIndex2 {
|
||||
h.t.Fatalf("next key index should be idempotent: want: %v, "+
|
||||
"got %v", keyIndex, keyIndex2)
|
||||
}
|
||||
|
||||
// Now, set the client session's key index so that it is proper and
|
||||
// insert it. This should succeed.
|
||||
session.KeyIndex = keyIndex
|
||||
h.insertSession(session, nil)
|
||||
|
||||
// Verify that the session now exists in the database.
|
||||
if _, ok := h.listSessions()[session.ID]; !ok {
|
||||
h.t.Fatalf("session for id %x should exist now", session.ID)
|
||||
}
|
||||
|
||||
// Attempt to insert the session again, which should fail due to the
|
||||
// session already existing.
|
||||
h.insertSession(session, wtdb.ErrClientSessionAlreadyExists)
|
||||
|
||||
// Finally, assert that reserving another key index succeeds with a
|
||||
// different key index, now that the first one has been finalized.
|
||||
keyIndex3 := h.nextKeyIndex(session.TowerID, nil)
|
||||
if keyIndex == keyIndex3 {
|
||||
h.t.Fatalf("key index still reserved after creating session")
|
||||
}
|
||||
}
|
||||
|
||||
// testCreateTower asserts the behavior of creating new Tower objects within the
|
||||
// database, and that the latest address is always prepended to the list of
|
||||
// known addresses for the tower.
|
||||
func testCreateTower(h *clientDBHarness) {
|
||||
// Test that loading a tower with an arbitrary tower id fails.
|
||||
h.loadTower(20, wtdb.ErrTowerNotFound)
|
||||
|
||||
pk, err := randPubKey()
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to generate pubkey: %v", err)
|
||||
}
|
||||
|
||||
addr1 := &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911}
|
||||
lnAddr := &lnwire.NetAddress{
|
||||
IdentityKey: pk,
|
||||
Address: addr1,
|
||||
}
|
||||
|
||||
// Insert a random tower into the database.
|
||||
tower := h.createTower(lnAddr, nil)
|
||||
|
||||
// Load the tower from the database and assert that it matches the tower
|
||||
// we created.
|
||||
tower2 := h.loadTower(tower.ID, nil)
|
||||
if !reflect.DeepEqual(tower, tower2) {
|
||||
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
|
||||
tower, tower2)
|
||||
}
|
||||
|
||||
// Insert the address again into the database. Since the address is the
|
||||
// same, this should result in an unmodified tower record.
|
||||
towerDupAddr := h.createTower(lnAddr, nil)
|
||||
if len(towerDupAddr.Addresses) != 1 {
|
||||
h.t.Fatalf("duplicate address should be deduped")
|
||||
}
|
||||
if !reflect.DeepEqual(tower, towerDupAddr) {
|
||||
h.t.Fatalf("mismatch towers, want: %v, got: %v",
|
||||
tower, towerDupAddr)
|
||||
}
|
||||
|
||||
// Generate a new address for this tower.
|
||||
addr2 := &net.TCPAddr{IP: []byte{0x02, 0x00, 0x00, 0x00}, Port: 9911}
|
||||
|
||||
lnAddr2 := &lnwire.NetAddress{
|
||||
IdentityKey: pk,
|
||||
Address: addr2,
|
||||
}
|
||||
|
||||
// Insert the updated address, which should produce a tower with a new
|
||||
// address.
|
||||
towerNewAddr := h.createTower(lnAddr2, nil)
|
||||
|
||||
// Load the tower from the database, and assert that it matches the
|
||||
// tower returned from creation.
|
||||
towerNewAddr2 := h.loadTower(tower.ID, nil)
|
||||
if !reflect.DeepEqual(towerNewAddr, towerNewAddr2) {
|
||||
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
|
||||
towerNewAddr, towerNewAddr2)
|
||||
}
|
||||
|
||||
// Assert that there are now two addresses on the tower object.
|
||||
if len(towerNewAddr.Addresses) != 2 {
|
||||
h.t.Fatalf("new address should be added")
|
||||
}
|
||||
|
||||
// Finally, assert that the new address was prepended since it is deemed
|
||||
// fresher.
|
||||
if !reflect.DeepEqual(tower.Addresses, towerNewAddr.Addresses[1:]) {
|
||||
h.t.Fatalf("new address should be prepended")
|
||||
}
|
||||
}
|
||||
|
||||
// testChanSummaries tests the process of a registering a channel and its
|
||||
// associated sweep pkscript.
|
||||
func testChanSummaries(h *clientDBHarness) {
|
||||
// First, assert that this channel is not already registered.
|
||||
var chanID lnwire.ChannelID
|
||||
if _, ok := h.fetchChanSummaries()[chanID]; ok {
|
||||
h.t.Fatalf("pkscript for channel %x should not exist yet",
|
||||
chanID)
|
||||
}
|
||||
|
||||
// Generate a random sweep pkscript and register it for this channel.
|
||||
expPkScript := make([]byte, 22)
|
||||
if _, err := io.ReadFull(crand.Reader, expPkScript); err != nil {
|
||||
h.t.Fatalf("unable to generate pkscript: %v", err)
|
||||
}
|
||||
h.registerChan(chanID, expPkScript, nil)
|
||||
|
||||
// Assert that the channel exists and that its sweep pkscript matches
|
||||
// the one we registered.
|
||||
summary, ok := h.fetchChanSummaries()[chanID]
|
||||
if !ok {
|
||||
h.t.Fatalf("pkscript for channel %x should not exist yet",
|
||||
chanID)
|
||||
} else if bytes.Compare(expPkScript, summary.SweepPkScript) != 0 {
|
||||
h.t.Fatalf("pkscript mismatch, want: %x, got: %x",
|
||||
expPkScript, summary.SweepPkScript)
|
||||
}
|
||||
|
||||
// Finally, assert that re-registering the same channel produces a
|
||||
// failure.
|
||||
h.registerChan(chanID, expPkScript, wtdb.ErrChannelAlreadyRegistered)
|
||||
}
|
||||
|
||||
// testCommitUpdate tests the behavior of CommitUpdate, ensuring that they can
|
||||
func testCommitUpdate(h *clientDBHarness) {
|
||||
session := &wtdb.ClientSession{
|
||||
ClientSessionBody: wtdb.ClientSessionBody{
|
||||
TowerID: wtdb.TowerID(3),
|
||||
Policy: wtpolicy.Policy{
|
||||
MaxUpdates: 100,
|
||||
},
|
||||
RewardPkScript: []byte{0x01, 0x02, 0x03},
|
||||
},
|
||||
ID: wtdb.SessionID([33]byte{0x02}),
|
||||
}
|
||||
|
||||
// Generate a random update and try to commit before inserting the
|
||||
// session, which should fail.
|
||||
update1 := randCommittedUpdate(h.t, 1)
|
||||
h.commitUpdate(&session.ID, update1, wtdb.ErrClientSessionNotFound)
|
||||
|
||||
// Reserve a session key index and insert the session.
|
||||
session.KeyIndex = h.nextKeyIndex(session.TowerID, nil)
|
||||
h.insertSession(session, nil)
|
||||
|
||||
// Now, try to commit the update that failed initially which should
|
||||
// succeed. The lastApplied value should be 0 since we have not received
|
||||
// an ack from the tower.
|
||||
lastApplied := h.commitUpdate(&session.ID, update1, nil)
|
||||
if lastApplied != 0 {
|
||||
h.t.Fatalf("last applied mismatch, want: 0, got: %v",
|
||||
lastApplied)
|
||||
}
|
||||
|
||||
// Assert that the committed update appears in the client session's
|
||||
// CommittedUpdates map when loaded from disk and that there are no
|
||||
// AckedUpdates.
|
||||
dbSession := h.listSessions()[session.ID]
|
||||
checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
|
||||
*update1,
|
||||
})
|
||||
checkAckedUpdates(h.t, dbSession, nil)
|
||||
|
||||
// Try to commit the same update, which should succeed due to
|
||||
// idempotency (which is preserved when the breach hint is identical to
|
||||
// the on-disk update's hint). The lastApplied value should remain
|
||||
// unchanged.
|
||||
lastApplied2 := h.commitUpdate(&session.ID, update1, nil)
|
||||
if lastApplied2 != lastApplied {
|
||||
h.t.Fatalf("last applied should not have changed, got %v",
|
||||
lastApplied2)
|
||||
}
|
||||
|
||||
// Assert that the loaded ClientSession is the same as before.
|
||||
dbSession = h.listSessions()[session.ID]
|
||||
checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
|
||||
*update1,
|
||||
})
|
||||
checkAckedUpdates(h.t, dbSession, nil)
|
||||
|
||||
// Generate another random update and try to commit it at the identical
|
||||
// sequence number. Since the breach hint has changed, this should fail.
|
||||
update2 := randCommittedUpdate(h.t, 1)
|
||||
h.commitUpdate(&session.ID, update2, wtdb.ErrUpdateAlreadyCommitted)
|
||||
|
||||
// Next, insert the new update at the next unallocated sequence number
|
||||
// which should succeed.
|
||||
update2.SeqNum = 2
|
||||
lastApplied3 := h.commitUpdate(&session.ID, update2, nil)
|
||||
if lastApplied3 != lastApplied {
|
||||
h.t.Fatalf("last applied should not have changed, got %v",
|
||||
lastApplied3)
|
||||
}
|
||||
|
||||
// Check that both updates now appear as committed on the ClientSession
|
||||
// loaded from disk.
|
||||
dbSession = h.listSessions()[session.ID]
|
||||
checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
|
||||
*update1,
|
||||
*update2,
|
||||
})
|
||||
checkAckedUpdates(h.t, dbSession, nil)
|
||||
|
||||
// Finally, create one more random update and try to commit it at index
|
||||
// 4, which should be rejected since 3 is the next slot the database
|
||||
// expects.
|
||||
update4 := randCommittedUpdate(h.t, 4)
|
||||
h.commitUpdate(&session.ID, update4, wtdb.ErrCommitUnorderedUpdate)
|
||||
|
||||
// Assert that the ClientSession loaded from disk remains unchanged.
|
||||
dbSession = h.listSessions()[session.ID]
|
||||
checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
|
||||
*update1,
|
||||
*update2,
|
||||
})
|
||||
checkAckedUpdates(h.t, dbSession, nil)
|
||||
}
|
||||
|
||||
// testAckUpdate asserts the behavior of AckUpdate.
|
||||
func testAckUpdate(h *clientDBHarness) {
|
||||
// Create a new session that the updates in this will be tied to.
|
||||
session := &wtdb.ClientSession{
|
||||
ClientSessionBody: wtdb.ClientSessionBody{
|
||||
TowerID: wtdb.TowerID(3),
|
||||
Policy: wtpolicy.Policy{
|
||||
MaxUpdates: 100,
|
||||
},
|
||||
RewardPkScript: []byte{0x01, 0x02, 0x03},
|
||||
},
|
||||
ID: wtdb.SessionID([33]byte{0x03}),
|
||||
}
|
||||
|
||||
// Try to ack an update before inserting the client session, which
|
||||
// should fail.
|
||||
h.ackUpdate(&session.ID, 1, 0, wtdb.ErrClientSessionNotFound)
|
||||
|
||||
// Reserve a session key and insert the client session.
|
||||
session.KeyIndex = h.nextKeyIndex(session.TowerID, nil)
|
||||
h.insertSession(session, nil)
|
||||
|
||||
// Now, try to ack update 1. This should fail since update 1 was never
|
||||
// committed.
|
||||
h.ackUpdate(&session.ID, 1, 0, wtdb.ErrCommittedUpdateNotFound)
|
||||
|
||||
// Commit to a random update at seqnum 1.
|
||||
update1 := randCommittedUpdate(h.t, 1)
|
||||
lastApplied := h.commitUpdate(&session.ID, update1, nil)
|
||||
if lastApplied != 0 {
|
||||
h.t.Fatalf("last applied mismatch, want: 0, got: %v",
|
||||
lastApplied)
|
||||
}
|
||||
|
||||
// Acking seqnum 1 should succeed.
|
||||
h.ackUpdate(&session.ID, 1, 1, nil)
|
||||
|
||||
// Acking seqnum 1 again should fail.
|
||||
h.ackUpdate(&session.ID, 1, 1, wtdb.ErrCommittedUpdateNotFound)
|
||||
|
||||
// Acking a valid seqnum with a reverted last applied value should fail.
|
||||
h.ackUpdate(&session.ID, 1, 0, wtdb.ErrLastAppliedReversion)
|
||||
|
||||
// Acking with a last applied greater than any allocated seqnum should
|
||||
// fail.
|
||||
h.ackUpdate(&session.ID, 4, 3, wtdb.ErrUnallocatedLastApplied)
|
||||
|
||||
// Assert that the ClientSession loaded from disk has one update in it's
|
||||
// AckedUpdates map, and that the committed update has been removed.
|
||||
dbSession := h.listSessions()[session.ID]
|
||||
checkCommittedUpdates(h.t, dbSession, nil)
|
||||
checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{
|
||||
1: update1.BackupID,
|
||||
})
|
||||
|
||||
// Commit to another random update, and assert that the last applied
|
||||
// value is 1, since this was what was provided in the last successful
|
||||
// ack.
|
||||
update2 := randCommittedUpdate(h.t, 2)
|
||||
lastApplied = h.commitUpdate(&session.ID, update2, nil)
|
||||
if lastApplied != 1 {
|
||||
h.t.Fatalf("last applied mismatch, want: 1, got: %v",
|
||||
lastApplied)
|
||||
}
|
||||
|
||||
// Ack seqnum 2.
|
||||
h.ackUpdate(&session.ID, 2, 2, nil)
|
||||
|
||||
// Assert that both updates exist as AckedUpdates when loaded from disk.
|
||||
dbSession = h.listSessions()[session.ID]
|
||||
checkCommittedUpdates(h.t, dbSession, nil)
|
||||
checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{
|
||||
1: update1.BackupID,
|
||||
2: update2.BackupID,
|
||||
})
|
||||
|
||||
// Acking again with a lower last applied should fail.
|
||||
h.ackUpdate(&session.ID, 2, 1, wtdb.ErrLastAppliedReversion)
|
||||
|
||||
// Acking an unallocated seqnum should fail.
|
||||
h.ackUpdate(&session.ID, 4, 2, wtdb.ErrCommittedUpdateNotFound)
|
||||
|
||||
// Acking with a last applied greater than any allocated seqnum should
|
||||
// fail.
|
||||
h.ackUpdate(&session.ID, 4, 3, wtdb.ErrUnallocatedLastApplied)
|
||||
}
|
||||
|
||||
// checkCommittedUpdates asserts that the CommittedUpdates on session match the
|
||||
// expUpdates provided.
|
||||
func checkCommittedUpdates(t *testing.T, session *wtdb.ClientSession,
|
||||
expUpdates []wtdb.CommittedUpdate) {
|
||||
|
||||
t.Helper()
|
||||
|
||||
// We promote nil expUpdates to an initialized slice since the database
|
||||
// should never return a nil slice. This promotion is done purely out of
|
||||
// convenience for the testing framework.
|
||||
if expUpdates == nil {
|
||||
expUpdates = make([]wtdb.CommittedUpdate, 0)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(session.CommittedUpdates, expUpdates) {
|
||||
t.Fatalf("committed updates mismatch, want: %v, got: %v",
|
||||
expUpdates, session.CommittedUpdates)
|
||||
}
|
||||
}
|
||||
|
||||
// checkAckedUpdates asserts that the AckedUpdates on a sessio match the
|
||||
// expUpdates provided.
|
||||
func checkAckedUpdates(t *testing.T, session *wtdb.ClientSession,
|
||||
expUpdates map[uint16]wtdb.BackupID) {
|
||||
|
||||
// We promote nil expUpdates to an initialized map since the database
|
||||
// should never return a nil map. This promotion is done purely out of
|
||||
// convenience for the testing framework.
|
||||
if expUpdates == nil {
|
||||
expUpdates = make(map[uint16]wtdb.BackupID)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(session.AckedUpdates, expUpdates) {
|
||||
t.Fatalf("acked updates mismatch, want: %v, got: %v",
|
||||
expUpdates, session.AckedUpdates)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClientDB asserts the behavior of a fresh client db, a reopened client db,
|
||||
// and the mock implementation. This ensures that all databases function
|
||||
// identically, especially in the negative paths.
|
||||
func TestClientDB(t *testing.T) {
|
||||
dbs := []struct {
|
||||
name string
|
||||
init clientDBInit
|
||||
}{
|
||||
{
|
||||
name: "fresh clientdb",
|
||||
init: func(t *testing.T) (wtclient.DB, func()) {
|
||||
path, err := ioutil.TempDir("", "clientdb")
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make temp dir: %v",
|
||||
err)
|
||||
}
|
||||
|
||||
db, err := wtdb.OpenClientDB(path)
|
||||
if err != nil {
|
||||
os.RemoveAll(path)
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
db.Close()
|
||||
os.RemoveAll(path)
|
||||
}
|
||||
|
||||
return db, cleanup
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "reopened clientdb",
|
||||
init: func(t *testing.T) (wtclient.DB, func()) {
|
||||
path, err := ioutil.TempDir("", "clientdb")
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make temp dir: %v",
|
||||
err)
|
||||
}
|
||||
|
||||
db, err := wtdb.OpenClientDB(path)
|
||||
if err != nil {
|
||||
os.RemoveAll(path)
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
db.Close()
|
||||
|
||||
db, err = wtdb.OpenClientDB(path)
|
||||
if err != nil {
|
||||
os.RemoveAll(path)
|
||||
t.Fatalf("unable to reopen db: %v", err)
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
db.Close()
|
||||
os.RemoveAll(path)
|
||||
}
|
||||
|
||||
return db, cleanup
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "mock",
|
||||
init: func(t *testing.T) (wtclient.DB, func()) {
|
||||
return wtmock.NewClientDB(), func() {}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
run func(*clientDBHarness)
|
||||
}{
|
||||
{
|
||||
name: "create client session",
|
||||
run: testCreateClientSession,
|
||||
},
|
||||
{
|
||||
name: "create tower",
|
||||
run: testCreateTower,
|
||||
},
|
||||
{
|
||||
name: "chan summaries",
|
||||
run: testChanSummaries,
|
||||
},
|
||||
{
|
||||
name: "commit update",
|
||||
run: testCommitUpdate,
|
||||
},
|
||||
{
|
||||
name: "ack update",
|
||||
run: testAckUpdate,
|
||||
},
|
||||
}
|
||||
|
||||
for _, database := range dbs {
|
||||
db := database
|
||||
t.Run(db.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
h, cleanup := newClientDBHarness(
|
||||
t, db.init,
|
||||
)
|
||||
defer cleanup()
|
||||
|
||||
test.run(h)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// randCommittedUpdate generates a random committed update.
|
||||
func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate {
|
||||
var chanID lnwire.ChannelID
|
||||
if _, err := io.ReadFull(crand.Reader, chanID[:]); err != nil {
|
||||
t.Fatalf("unable to generate chan id: %v", err)
|
||||
}
|
||||
|
||||
var hint wtdb.BreachHint
|
||||
if _, err := io.ReadFull(crand.Reader, hint[:]); err != nil {
|
||||
t.Fatalf("unable to generate breach hint: %v", err)
|
||||
}
|
||||
|
||||
encBlob := make([]byte, blob.Size(blob.FlagCommitOutputs.Type()))
|
||||
if _, err := io.ReadFull(crand.Reader, encBlob); err != nil {
|
||||
t.Fatalf("unable to generate encrypted blob: %v", err)
|
||||
}
|
||||
|
||||
return &wtdb.CommittedUpdate{
|
||||
SeqNum: seqNum,
|
||||
CommittedUpdateBody: wtdb.CommittedUpdateBody{
|
||||
BackupID: wtdb.BackupID{
|
||||
ChanID: chanID,
|
||||
CommitHeight: 666,
|
||||
},
|
||||
Hint: hint,
|
||||
EncryptedBlob: encBlob,
|
||||
},
|
||||
}
|
||||
}
|
@ -1,43 +1,21 @@
|
||||
package wtdb
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrClientSessionNotFound signals that the requested client session
|
||||
// was not found in the database.
|
||||
ErrClientSessionNotFound = errors.New("client session not found")
|
||||
// CSessionStatus is a bit-field representing the possible statuses of
|
||||
// ClientSessions.
|
||||
type CSessionStatus uint8
|
||||
|
||||
// ErrUpdateAlreadyCommitted signals that the chosen sequence number has
|
||||
// already been committed to an update with a different breach hint.
|
||||
ErrUpdateAlreadyCommitted = errors.New("update already committed")
|
||||
|
||||
// ErrCommitUnorderedUpdate signals the client tried to commit a
|
||||
// sequence number other than the next unallocated sequence number.
|
||||
ErrCommitUnorderedUpdate = errors.New("update seqnum not monotonic")
|
||||
|
||||
// ErrCommittedUpdateNotFound signals that the tower tried to ACK a
|
||||
// sequence number that has not yet been allocated by the client.
|
||||
ErrCommittedUpdateNotFound = errors.New("committed update not found")
|
||||
|
||||
// ErrUnallocatedLastApplied signals that the tower tried to provide a
|
||||
// LastApplied value greater than any allocated sequence number.
|
||||
ErrUnallocatedLastApplied = errors.New("tower echoed last appiled " +
|
||||
"greater than allocated seqnum")
|
||||
|
||||
// ErrNoReservedKeyIndex signals that a client session could not be
|
||||
// created because no session key index was reserved.
|
||||
ErrNoReservedKeyIndex = errors.New("key index not reserved")
|
||||
|
||||
// ErrIncorrectKeyIndex signals that the client session could not be
|
||||
// created because session key index differs from the reserved key
|
||||
// index.
|
||||
ErrIncorrectKeyIndex = errors.New("incorrect key index")
|
||||
const (
|
||||
// CSessionActive indicates that the ClientSession is active and can be
|
||||
// used for backups.
|
||||
CSessionActive CSessionStatus = 0
|
||||
)
|
||||
|
||||
// ClientSession encapsulates a SessionInfo returned from a successful
|
||||
@ -46,8 +24,48 @@ var (
|
||||
type ClientSession struct {
|
||||
// ID is the client's public key used when authenticating with the
|
||||
// tower.
|
||||
//
|
||||
// NOTE: This value is not serialized with the body of the struct, it
|
||||
// should be set and recovered as the ClientSession's key.
|
||||
ID SessionID
|
||||
|
||||
ClientSessionBody
|
||||
|
||||
// CommittedUpdates is a sorted list of unacked updates. These updates
|
||||
// can be resent after a restart if the updates failed to send or
|
||||
// receive an acknowledgment.
|
||||
//
|
||||
// NOTE: This list is serialized in it's own bucket, separate from the
|
||||
// body of the ClientSession. The representation on disk is a key value
|
||||
// map from sequence number to CommittedUpdateBody to allow efficient
|
||||
// insertion and retrieval.
|
||||
CommittedUpdates []CommittedUpdate
|
||||
|
||||
// AckedUpdates is a map from sequence number to backup id to record
|
||||
// which revoked states were uploaded via this session.
|
||||
//
|
||||
// NOTE: This map is serialized in it's own bucket, separate from the
|
||||
// body of the ClientSession.
|
||||
AckedUpdates map[uint16]BackupID
|
||||
|
||||
// Tower holds the pubkey and address of the watchtower.
|
||||
//
|
||||
// NOTE: This value is not serialized. It is recovered by looking up the
|
||||
// tower with TowerID.
|
||||
Tower *Tower
|
||||
|
||||
// SessionPrivKey is the ephemeral secret key used to connect to the
|
||||
// watchtower.
|
||||
//
|
||||
// NOTE: This value is not serialized. It is derived using the KeyIndex
|
||||
// on startup to avoid storing private keys on disk.
|
||||
SessionPrivKey *btcec.PrivateKey
|
||||
}
|
||||
|
||||
// ClientSessionBody represents the primary components of a ClientSession that
|
||||
// are serialized together within the database. The CommittedUpdates and
|
||||
// AckedUpdates are serialized in buckets separate from the body.
|
||||
type ClientSessionBody struct {
|
||||
// SeqNum is the next unallocated sequence number that can be sent to
|
||||
// the tower.
|
||||
SeqNum uint16
|
||||
@ -57,13 +75,7 @@ type ClientSession struct {
|
||||
|
||||
// TowerID is the unique, db-assigned identifier that references the
|
||||
// Tower with which the session is negotiated.
|
||||
TowerID uint64
|
||||
|
||||
// Tower holds the pubkey and address of the watchtower.
|
||||
//
|
||||
// NOTE: This value is not serialized. It is recovered by looking up the
|
||||
// tower with TowerID.
|
||||
Tower *Tower
|
||||
TowerID TowerID
|
||||
|
||||
// KeyIndex is the index of key locator used to derive the client's
|
||||
// session key so that it can authenticate with the tower to update its
|
||||
@ -71,29 +83,54 @@ type ClientSession struct {
|
||||
// use the keychain.KeyFamilyTowerSession key family.
|
||||
KeyIndex uint32
|
||||
|
||||
// SessionPrivKey is the ephemeral secret key used to connect to the
|
||||
// watchtower.
|
||||
//
|
||||
// NOTE: This value is not serialized. It is derived using the KeyIndex
|
||||
// on startup to avoid storing private keys on disk.
|
||||
SessionPrivKey *btcec.PrivateKey
|
||||
|
||||
// Policy holds the negotiated session parameters.
|
||||
Policy wtpolicy.Policy
|
||||
|
||||
// Status indicates the current state of the ClientSession.
|
||||
Status CSessionStatus
|
||||
|
||||
// RewardPkScript is the pkscript that the tower's reward will be
|
||||
// deposited to if a sweep transaction confirms and the sessions
|
||||
// specifies a reward output.
|
||||
RewardPkScript []byte
|
||||
}
|
||||
|
||||
// CommittedUpdates is a map from allocated sequence numbers to unacked
|
||||
// updates. These updates can be resent after a restart if the update
|
||||
// failed to send or receive an acknowledgment.
|
||||
CommittedUpdates map[uint16]*CommittedUpdate
|
||||
// Encode writes a ClientSessionBody to the passed io.Writer.
|
||||
func (s *ClientSessionBody) Encode(w io.Writer) error {
|
||||
return WriteElements(w,
|
||||
s.SeqNum,
|
||||
s.TowerLastApplied,
|
||||
uint64(s.TowerID),
|
||||
s.KeyIndex,
|
||||
uint8(s.Status),
|
||||
s.Policy,
|
||||
s.RewardPkScript,
|
||||
)
|
||||
}
|
||||
|
||||
// AckedUpdates is a map from sequence number to backup id to record
|
||||
// which revoked states were uploaded via this session.
|
||||
AckedUpdates map[uint16]BackupID
|
||||
// Decode reads a ClientSessionBody from the passed io.Reader.
|
||||
func (s *ClientSessionBody) Decode(r io.Reader) error {
|
||||
var (
|
||||
towerID uint64
|
||||
status uint8
|
||||
)
|
||||
err := ReadElements(r,
|
||||
&s.SeqNum,
|
||||
&s.TowerLastApplied,
|
||||
&towerID,
|
||||
&s.KeyIndex,
|
||||
&status,
|
||||
&s.Policy,
|
||||
&s.RewardPkScript,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.TowerID = TowerID(towerID)
|
||||
s.Status = CSessionStatus(status)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BackupID identifies a particular revoked, remote commitment by channel id and
|
||||
@ -106,9 +143,38 @@ type BackupID struct {
|
||||
CommitHeight uint64
|
||||
}
|
||||
|
||||
// Encode writes the BackupID from the passed io.Writer.
|
||||
func (b *BackupID) Encode(w io.Writer) error {
|
||||
return WriteElements(w,
|
||||
b.ChanID,
|
||||
b.CommitHeight,
|
||||
)
|
||||
}
|
||||
|
||||
// Decode reads a BackupID from the passed io.Reader.
|
||||
func (b *BackupID) Decode(r io.Reader) error {
|
||||
return ReadElements(r,
|
||||
&b.ChanID,
|
||||
&b.CommitHeight,
|
||||
)
|
||||
}
|
||||
|
||||
// CommittedUpdate holds a state update sent by a client along with its
|
||||
// SessionID.
|
||||
// allocated sequence number and the exact remote commitment the encrypted
|
||||
// justice transaction can rectify.
|
||||
type CommittedUpdate struct {
|
||||
// SeqNum is the unique sequence number allocated by the session to this
|
||||
// update.
|
||||
SeqNum uint16
|
||||
|
||||
CommittedUpdateBody
|
||||
}
|
||||
|
||||
// CommittedUpdateBody represents the primary components of a CommittedUpdate.
|
||||
// On disk, this is stored under the sequence number, which acts as its key.
|
||||
type CommittedUpdateBody struct {
|
||||
// BackupID identifies the breached commitment that the encrypted blob
|
||||
// can spend from.
|
||||
BackupID BackupID
|
||||
|
||||
// Hint is the 16-byte prefix of the revoked commitment transaction ID.
|
||||
@ -119,3 +185,29 @@ type CommittedUpdate struct {
|
||||
// hint is broadcast.
|
||||
EncryptedBlob []byte
|
||||
}
|
||||
|
||||
// Encode writes the CommittedUpdateBody to the passed io.Writer.
|
||||
func (u *CommittedUpdateBody) Encode(w io.Writer) error {
|
||||
err := u.BackupID.Encode(w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return WriteElements(w,
|
||||
u.Hint,
|
||||
u.EncryptedBlob,
|
||||
)
|
||||
}
|
||||
|
||||
// Decode reads a CommittedUpdateBody from the passed io.Reader.
|
||||
func (u *CommittedUpdateBody) Decode(r io.Reader) error {
|
||||
err := u.BackupID.Decode(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ReadElements(r,
|
||||
&u.Hint,
|
||||
&u.EncryptedBlob,
|
||||
)
|
||||
}
|
||||
|
@ -2,14 +2,122 @@ package wtdb_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
"testing/quick"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/lightningnetwork/lnd/tor"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
)
|
||||
|
||||
func randPubKey() (*btcec.PublicKey, error) {
|
||||
priv, err := btcec.NewPrivateKey(btcec.S256())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return priv.PubKey(), nil
|
||||
}
|
||||
|
||||
func randTCP4Addr(r *rand.Rand) (*net.TCPAddr, error) {
|
||||
var ip [4]byte
|
||||
if _, err := r.Read(ip[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var port [2]byte
|
||||
if _, err := r.Read(port[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
addrIP := net.IP(ip[:])
|
||||
addrPort := int(binary.BigEndian.Uint16(port[:]))
|
||||
|
||||
return &net.TCPAddr{IP: addrIP, Port: addrPort}, nil
|
||||
}
|
||||
|
||||
func randTCP6Addr(r *rand.Rand) (*net.TCPAddr, error) {
|
||||
var ip [16]byte
|
||||
if _, err := r.Read(ip[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var port [2]byte
|
||||
if _, err := r.Read(port[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
addrIP := net.IP(ip[:])
|
||||
addrPort := int(binary.BigEndian.Uint16(port[:]))
|
||||
|
||||
return &net.TCPAddr{IP: addrIP, Port: addrPort}, nil
|
||||
}
|
||||
|
||||
func randV2OnionAddr(r *rand.Rand) (*tor.OnionAddr, error) {
|
||||
var serviceID [tor.V2DecodedLen]byte
|
||||
if _, err := r.Read(serviceID[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var port [2]byte
|
||||
if _, err := r.Read(port[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
onionService := tor.Base32Encoding.EncodeToString(serviceID[:])
|
||||
onionService += tor.OnionSuffix
|
||||
addrPort := int(binary.BigEndian.Uint16(port[:]))
|
||||
|
||||
return &tor.OnionAddr{OnionService: onionService, Port: addrPort}, nil
|
||||
}
|
||||
|
||||
func randV3OnionAddr(r *rand.Rand) (*tor.OnionAddr, error) {
|
||||
var serviceID [tor.V3DecodedLen]byte
|
||||
if _, err := r.Read(serviceID[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var port [2]byte
|
||||
if _, err := r.Read(port[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
onionService := tor.Base32Encoding.EncodeToString(serviceID[:])
|
||||
onionService += tor.OnionSuffix
|
||||
addrPort := int(binary.BigEndian.Uint16(port[:]))
|
||||
|
||||
return &tor.OnionAddr{OnionService: onionService, Port: addrPort}, nil
|
||||
}
|
||||
|
||||
func randAddrs(r *rand.Rand) ([]net.Addr, error) {
|
||||
tcp4Addr, err := randTCP4Addr(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tcp6Addr, err := randTCP6Addr(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
v2OnionAddr, err := randV2OnionAddr(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
v3OnionAddr, err := randV3OnionAddr(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return []net.Addr{tcp4Addr, tcp6Addr, v2OnionAddr, v3OnionAddr}, nil
|
||||
}
|
||||
|
||||
// dbObject is abstract object support encoding and decoding.
|
||||
type dbObject interface {
|
||||
Encode(io.Writer) error
|
||||
@ -19,7 +127,9 @@ type dbObject interface {
|
||||
// TestCodec serializes and deserializes wtdb objects in order to test that that
|
||||
// the codec understands all of the required field types. The test also asserts
|
||||
// that decoding an object into another results in an equivalent object.
|
||||
func TestCodec(t *testing.T) {
|
||||
func TestCodec(tt *testing.T) {
|
||||
|
||||
var t *testing.T
|
||||
mainScenario := func(obj dbObject) bool {
|
||||
// Ensure encoding the object succeeds.
|
||||
var b bytes.Buffer
|
||||
@ -35,6 +145,16 @@ func TestCodec(t *testing.T) {
|
||||
obj2 = &wtdb.SessionInfo{}
|
||||
case *wtdb.SessionStateUpdate:
|
||||
obj2 = &wtdb.SessionStateUpdate{}
|
||||
case *wtdb.ClientSessionBody:
|
||||
obj2 = &wtdb.ClientSessionBody{}
|
||||
case *wtdb.CommittedUpdateBody:
|
||||
obj2 = &wtdb.CommittedUpdateBody{}
|
||||
case *wtdb.BackupID:
|
||||
obj2 = &wtdb.BackupID{}
|
||||
case *wtdb.Tower:
|
||||
obj2 = &wtdb.Tower{}
|
||||
case *wtdb.ClientChanSummary:
|
||||
obj2 = &wtdb.ClientChanSummary{}
|
||||
default:
|
||||
t.Fatalf("unknown type: %T", obj)
|
||||
return false
|
||||
@ -57,6 +177,29 @@ func TestCodec(t *testing.T) {
|
||||
return true
|
||||
}
|
||||
|
||||
customTypeGen := map[string]func([]reflect.Value, *rand.Rand){
|
||||
"Tower": func(v []reflect.Value, r *rand.Rand) {
|
||||
pk, err := randPubKey()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to generate pubkey: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
addrs, err := randAddrs(r)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to generate addrs: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
obj := wtdb.Tower{
|
||||
IdentityKey: pk,
|
||||
Addresses: addrs,
|
||||
}
|
||||
|
||||
v[0] = reflect.ValueOf(obj)
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
scenario interface{}
|
||||
@ -73,11 +216,51 @@ func TestCodec(t *testing.T) {
|
||||
return mainScenario(&obj)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ClientSessionBody",
|
||||
scenario: func(obj wtdb.ClientSessionBody) bool {
|
||||
return mainScenario(&obj)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "CommittedUpdateBody",
|
||||
scenario: func(obj wtdb.CommittedUpdateBody) bool {
|
||||
return mainScenario(&obj)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "BackupID",
|
||||
scenario: func(obj wtdb.BackupID) bool {
|
||||
return mainScenario(&obj)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Tower",
|
||||
scenario: func(obj wtdb.Tower) bool {
|
||||
return mainScenario(&obj)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ClientChanSummary",
|
||||
scenario: func(obj wtdb.ClientChanSummary) bool {
|
||||
return mainScenario(&obj)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
if err := quick.Check(test.scenario, nil); err != nil {
|
||||
tt.Run(test.name, func(h *testing.T) {
|
||||
t = h
|
||||
|
||||
var config *quick.Config
|
||||
if valueGen, ok := customTypeGen[test.name]; ok {
|
||||
config = &quick.Config{
|
||||
Values: valueGen,
|
||||
}
|
||||
}
|
||||
|
||||
err := quick.Check(test.scenario, config)
|
||||
if err != nil {
|
||||
t.Fatalf("fuzz checks for msg=%s failed: %v",
|
||||
test.name, err)
|
||||
}
|
||||
|
92
watchtower/wtdb/db_common.go
Normal file
92
watchtower/wtdb/db_common.go
Normal file
@ -0,0 +1,92 @@
|
||||
package wtdb
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/coreos/bbolt"
|
||||
)
|
||||
|
||||
const (
|
||||
// dbFilePermission requests read+write access to the db file.
|
||||
dbFilePermission = 0600
|
||||
)
|
||||
|
||||
var (
|
||||
// metadataBkt stores all the meta information concerning the state of
|
||||
// the database.
|
||||
metadataBkt = []byte("metadata-bucket")
|
||||
|
||||
// dbVersionKey is a static key used to retrieve the database version
|
||||
// number from the metadataBkt.
|
||||
dbVersionKey = []byte("version")
|
||||
|
||||
// ErrUninitializedDB signals that top-level buckets for the database
|
||||
// have not been initialized.
|
||||
ErrUninitializedDB = errors.New("db not initialized")
|
||||
|
||||
// ErrNoDBVersion signals that the database contains no version info.
|
||||
ErrNoDBVersion = errors.New("db has no version")
|
||||
|
||||
// byteOrder is the default endianness used when serializing integers.
|
||||
byteOrder = binary.BigEndian
|
||||
)
|
||||
|
||||
// fileExists returns true if the file exists, and false otherwise.
|
||||
func fileExists(path string) bool {
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// createDBIfNotExist opens the boltdb database at dbPath/name, creating one if
|
||||
// one doesn't exist. The boolean returned indicates if the database did not
|
||||
// exist before, or if it has been created but no version metadata exists within
|
||||
// it.
|
||||
func createDBIfNotExist(dbPath, name string) (*bbolt.DB, bool, error) {
|
||||
path := filepath.Join(dbPath, name)
|
||||
|
||||
// If the database file doesn't exist, this indicates we much initialize
|
||||
// a fresh database with the latest version.
|
||||
firstInit := !fileExists(path)
|
||||
if firstInit {
|
||||
// Ensure all parent directories are initialized.
|
||||
err := os.MkdirAll(dbPath, 0700)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
}
|
||||
|
||||
bdb, err := bbolt.Open(path, dbFilePermission, nil)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
// If the file existed previously, we'll now check to see that the
|
||||
// metadata bucket is properly initialized. It could be the case that
|
||||
// the database was created, but we failed to actually populate any
|
||||
// metadata. If the metadata bucket does not actually exist, we'll
|
||||
// set firstInit to true so that we can treat is initialize the bucket.
|
||||
if !firstInit {
|
||||
var metadataExists bool
|
||||
err = bdb.View(func(tx *bbolt.Tx) error {
|
||||
metadataExists = tx.Bucket(metadataBkt) != nil
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
if !metadataExists {
|
||||
firstInit = true
|
||||
}
|
||||
}
|
||||
|
||||
return bdb, firstInit, nil
|
||||
}
|
@ -1,26 +1,38 @@
|
||||
package wtdb
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrTowerNotFound signals that the target tower was not found in the
|
||||
// database.
|
||||
ErrTowerNotFound = errors.New("tower not found")
|
||||
)
|
||||
// TowerID is a unique 64-bit identifier allocated to each unique watchtower.
|
||||
// This allows the client to conserve on-disk space by not needing to always
|
||||
// reference towers by their pubkey.
|
||||
type TowerID uint64
|
||||
|
||||
// TowerIDFromBytes constructs a TowerID from the provided byte slice. The
|
||||
// argument must have at least 8 bytes, and should contain the TowerID in
|
||||
// big-endian byte order.
|
||||
func TowerIDFromBytes(towerIDBytes []byte) TowerID {
|
||||
return TowerID(byteOrder.Uint64(towerIDBytes))
|
||||
}
|
||||
|
||||
// Bytes encodes a TowerID into an 8-byte slice in big-endian byte order.
|
||||
func (id TowerID) Bytes() []byte {
|
||||
var buf [8]byte
|
||||
byteOrder.PutUint64(buf[:], uint64(id))
|
||||
return buf[:]
|
||||
}
|
||||
|
||||
// Tower holds the necessary components required to connect to a remote tower.
|
||||
// Communication is handled by brontide, and requires both a public key and an
|
||||
// address.
|
||||
type Tower struct {
|
||||
// ID is a unique ID for this record assigned by the database.
|
||||
ID uint64
|
||||
ID TowerID
|
||||
|
||||
// IdentityKey is the public key of the remote node, used to
|
||||
// authenticate the brontide transport.
|
||||
@ -28,18 +40,15 @@ type Tower struct {
|
||||
|
||||
// Addresses is a list of possible addresses to reach the tower.
|
||||
Addresses []net.Addr
|
||||
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// AddAddress adds the given address to the tower's in-memory list of addresses.
|
||||
// If the address's string is already present, the Tower will be left
|
||||
// unmodified. Otherwise, the adddress is prepended to the beginning of the
|
||||
// Tower's addresses, on the assumption that it is fresher than the others.
|
||||
//
|
||||
// NOTE: This method is NOT safe for concurrent use.
|
||||
func (t *Tower) AddAddress(addr net.Addr) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
// Ensure we don't add a duplicate address.
|
||||
addrStr := addr.String()
|
||||
for _, existingAddr := range t.Addresses {
|
||||
@ -56,10 +65,9 @@ func (t *Tower) AddAddress(addr net.Addr) {
|
||||
// LNAddrs generates a list of lnwire.NetAddress from a Tower instance's
|
||||
// addresses. This can be used to have a client try multiple addresses for the
|
||||
// same Tower.
|
||||
//
|
||||
// NOTE: This method is NOT safe for concurrent use.
|
||||
func (t *Tower) LNAddrs() []*lnwire.NetAddress {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
|
||||
addrs := make([]*lnwire.NetAddress, 0, len(t.Addresses))
|
||||
for _, addr := range t.Addresses {
|
||||
addrs = append(addrs, &lnwire.NetAddress{
|
||||
@ -70,3 +78,21 @@ func (t *Tower) LNAddrs() []*lnwire.NetAddress {
|
||||
|
||||
return addrs
|
||||
}
|
||||
|
||||
// Encode writes the Tower to the passed io.Writer. The TowerID is not
|
||||
// serialized, since it acts as the key.
|
||||
func (t *Tower) Encode(w io.Writer) error {
|
||||
return WriteElements(w,
|
||||
t.IdentityKey,
|
||||
t.Addresses,
|
||||
)
|
||||
}
|
||||
|
||||
// Decode reads a Tower from the passed io.Reader. The TowerID is meant to be
|
||||
// decoded from the key.
|
||||
func (t *Tower) Decode(r io.Reader) error {
|
||||
return ReadElements(r,
|
||||
&t.IdentityKey,
|
||||
&t.Addresses,
|
||||
)
|
||||
}
|
||||
|
@ -2,23 +2,16 @@ package wtdb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/coreos/bbolt"
|
||||
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
)
|
||||
|
||||
const (
|
||||
// dbName is the filename of tower database.
|
||||
dbName = "watchtower.db"
|
||||
|
||||
// dbFilePermission requests read+write access to the db file.
|
||||
dbFilePermission = 0600
|
||||
// towerDBName is the filename of tower database.
|
||||
towerDBName = "watchtower.db"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -49,26 +42,9 @@ var (
|
||||
// epoch from the lookoutTipBkt.
|
||||
lookoutTipKey = []byte("lookout-tip")
|
||||
|
||||
// metadataBkt stores all the meta information concerning the state of
|
||||
// the database.
|
||||
metadataBkt = []byte("metadata-bucket")
|
||||
|
||||
// dbVersionKey is a static key used to retrieve the database version
|
||||
// number from the metadataBkt.
|
||||
dbVersionKey = []byte("version")
|
||||
|
||||
// ErrUninitializedDB signals that top-level buckets for the database
|
||||
// have not been initialized.
|
||||
ErrUninitializedDB = errors.New("tower db not initialized")
|
||||
|
||||
// ErrNoDBVersion signals that the database contains no version info.
|
||||
ErrNoDBVersion = errors.New("tower db has no version")
|
||||
|
||||
// ErrNoSessionHintIndex signals that an active session does not have an
|
||||
// initialized index for tracking its own state updates.
|
||||
ErrNoSessionHintIndex = errors.New("session hint index missing")
|
||||
|
||||
byteOrder = binary.BigEndian
|
||||
)
|
||||
|
||||
// TowerDB is single database providing a persistent storage engine for the
|
||||
@ -86,67 +62,20 @@ type TowerDB struct {
|
||||
// with a version number higher that the latest version will fail to prevent
|
||||
// accidental reversion.
|
||||
func OpenTowerDB(dbPath string) (*TowerDB, error) {
|
||||
path := filepath.Join(dbPath, dbName)
|
||||
|
||||
// If the database file doesn't exist, this indicates we much initialize
|
||||
// a fresh database with the latest version.
|
||||
firstInit := !fileExists(path)
|
||||
if firstInit {
|
||||
// Ensure all parent directories are initialized.
|
||||
err := os.MkdirAll(dbPath, 0700)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
bdb, err := bbolt.Open(path, dbFilePermission, nil)
|
||||
bdb, firstInit, err := createDBIfNotExist(dbPath, towerDBName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If the file existed previously, we'll now check to see that the
|
||||
// metadata bucket is properly initialized. It could be the case that
|
||||
// the database was created, but we failed to actually populate any
|
||||
// metadata. If the metadata bucket does not actually exist, we'll
|
||||
// set firstInit to true so that we can treat is initialize the bucket.
|
||||
if !firstInit {
|
||||
var metadataExists bool
|
||||
err = bdb.View(func(tx *bbolt.Tx) error {
|
||||
metadataExists = tx.Bucket(metadataBkt) != nil
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !metadataExists {
|
||||
firstInit = true
|
||||
}
|
||||
}
|
||||
|
||||
towerDB := &TowerDB{
|
||||
db: bdb,
|
||||
dbPath: dbPath,
|
||||
}
|
||||
|
||||
if firstInit {
|
||||
// If the database has not yet been created, we'll initialize
|
||||
// the database version with the latest known version.
|
||||
err = towerDB.db.Update(func(tx *bbolt.Tx) error {
|
||||
return initDBVersion(tx, getLatestDBVersion(dbVersions))
|
||||
})
|
||||
if err != nil {
|
||||
bdb.Close()
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// Otherwise, ensure that any migrations are applied to ensure
|
||||
// the data is in the format expected by the latest version.
|
||||
err = towerDB.syncVersions(dbVersions)
|
||||
if err != nil {
|
||||
bdb.Close()
|
||||
return nil, err
|
||||
}
|
||||
err = initOrSyncVersions(towerDB, firstInit, towerDBVersions)
|
||||
if err != nil {
|
||||
bdb.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Now that the database version fully consistent with our latest known
|
||||
@ -163,17 +92,6 @@ func OpenTowerDB(dbPath string) (*TowerDB, error) {
|
||||
return towerDB, nil
|
||||
}
|
||||
|
||||
// fileExists returns true if the file exists, and false otherwise.
|
||||
func fileExists(path string) bool {
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// initTowerDBBuckets creates all top-level buckets required to handle database
|
||||
// operations required by the latest version.
|
||||
func initTowerDBBuckets(tx *bbolt.Tx) error {
|
||||
@ -194,53 +112,16 @@ func initTowerDBBuckets(tx *bbolt.Tx) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// syncVersions ensures the database version is consistent with the highest
|
||||
// known database version, applying any migrations that have not been made. If
|
||||
// the highest known version number is lower than the database's version, this
|
||||
// method will fail to prevent accidental reversions.
|
||||
func (t *TowerDB) syncVersions(versions []version) error {
|
||||
curVersion, err := t.Version()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
latestVersion := getLatestDBVersion(versions)
|
||||
switch {
|
||||
|
||||
// Current version is higher than any known version, fail to prevent
|
||||
// reversion.
|
||||
case curVersion > latestVersion:
|
||||
return channeldb.ErrDBReversion
|
||||
|
||||
// Current version matches highest known version, nothing to do.
|
||||
case curVersion == latestVersion:
|
||||
return nil
|
||||
}
|
||||
|
||||
// Otherwise, apply any migrations in order to bring the database
|
||||
// version up to the highest known version.
|
||||
updates := getMigrations(versions, curVersion)
|
||||
return t.db.Update(func(tx *bbolt.Tx) error {
|
||||
for _, update := range updates {
|
||||
if update.migration == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Infof("Applying migration #%d", update.number)
|
||||
|
||||
err := update.migration(tx)
|
||||
if err != nil {
|
||||
log.Errorf("Unable to apply migration #%d: %v",
|
||||
err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return putDBVersion(tx, latestVersion)
|
||||
})
|
||||
// bdb returns the backing bbolt.DB instance.
|
||||
//
|
||||
// NOTE: Part of the versionedDB interface.
|
||||
func (t *TowerDB) bdb() *bbolt.DB {
|
||||
return t.db
|
||||
}
|
||||
|
||||
// Version returns the database's current version number.
|
||||
//
|
||||
// NOTE: Part of the versionedDB interface.
|
||||
func (t *TowerDB) Version() (uint32, error) {
|
||||
var version uint32
|
||||
err := t.db.View(func(tx *bbolt.Tx) error {
|
||||
|
@ -1,6 +1,9 @@
|
||||
package wtdb
|
||||
|
||||
import "github.com/coreos/bbolt"
|
||||
import (
|
||||
"github.com/coreos/bbolt"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
)
|
||||
|
||||
// migration is a function which takes a prior outdated version of the database
|
||||
// instances and mutates the key/bucket structure to arrive at a more
|
||||
@ -10,32 +13,30 @@ type migration func(tx *bbolt.Tx) error
|
||||
// version pairs a version number with the migration that would need to be
|
||||
// applied from the prior version to upgrade.
|
||||
type version struct {
|
||||
number uint32
|
||||
migration migration
|
||||
}
|
||||
|
||||
// dbVersions stores all versions and migrations of the database. This list will
|
||||
// be used when opening the database to determine if any migrations must be
|
||||
// applied.
|
||||
var dbVersions = []version{
|
||||
{
|
||||
// Initial version requires no migration.
|
||||
number: 0,
|
||||
migration: nil,
|
||||
},
|
||||
}
|
||||
// towerDBVersions stores all versions and migrations of the tower database.
|
||||
// This list will be used when opening the database to determine if any
|
||||
// migrations must be applied.
|
||||
var towerDBVersions = []version{}
|
||||
|
||||
// clientDBVersions stores all versions and migrations of the client database.
|
||||
// This list will be used when opening the database to determine if any
|
||||
// migrations must be applied.
|
||||
var clientDBVersions = []version{}
|
||||
|
||||
// getLatestDBVersion returns the last known database version.
|
||||
func getLatestDBVersion(versions []version) uint32 {
|
||||
return versions[len(versions)-1].number
|
||||
return uint32(len(versions))
|
||||
}
|
||||
|
||||
// getMigrations returns a slice of all updates with a greater number that
|
||||
// curVersion that need to be applied to sync up with the latest version.
|
||||
func getMigrations(versions []version, curVersion uint32) []version {
|
||||
var updates []version
|
||||
for _, v := range versions {
|
||||
if v.number > curVersion {
|
||||
for i, v := range versions {
|
||||
if uint32(i)+1 > curVersion {
|
||||
updates = append(updates, v)
|
||||
}
|
||||
}
|
||||
@ -82,3 +83,81 @@ func putDBVersion(tx *bbolt.Tx, version uint32) error {
|
||||
byteOrder.PutUint32(versionBytes, version)
|
||||
return metadata.Put(dbVersionKey, versionBytes)
|
||||
}
|
||||
|
||||
// versionedDB is a private interface implemented by both the tower and client
|
||||
// databases, permitting all versioning operations to be performed generically
|
||||
// on either.
|
||||
type versionedDB interface {
|
||||
// bdb returns the underlying bbolt database.
|
||||
bdb() *bbolt.DB
|
||||
|
||||
// Version returns the current version stored in the database.
|
||||
Version() (uint32, error)
|
||||
}
|
||||
|
||||
// initOrSyncVersions ensures that the database version is properly set before
|
||||
// opening the database up for regular use. When the database is being
|
||||
// initialized for the first time, the caller should set init to true, which
|
||||
// will simply write the latest version to the database. Otherwise, passing init
|
||||
// as false will cause the database to apply any needed migrations to ensure its
|
||||
// version matches the latest version in the provided versions list.
|
||||
func initOrSyncVersions(db versionedDB, init bool, versions []version) error {
|
||||
// If the database has not yet been created, we'll initialize the
|
||||
// database version with the latest known version.
|
||||
if init {
|
||||
return db.bdb().Update(func(tx *bbolt.Tx) error {
|
||||
return initDBVersion(tx, getLatestDBVersion(versions))
|
||||
})
|
||||
}
|
||||
|
||||
// Otherwise, ensure that any migrations are applied to ensure the data
|
||||
// is in the format expected by the latest version.
|
||||
return syncVersions(db, versions)
|
||||
}
|
||||
|
||||
// syncVersions ensures the database version is consistent with the highest
|
||||
// known database version, applying any migrations that have not been made. If
|
||||
// the highest known version number is lower than the database's version, this
|
||||
// method will fail to prevent accidental reversions.
|
||||
func syncVersions(db versionedDB, versions []version) error {
|
||||
curVersion, err := db.Version()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
latestVersion := getLatestDBVersion(versions)
|
||||
switch {
|
||||
|
||||
// Current version is higher than any known version, fail to prevent
|
||||
// reversion.
|
||||
case curVersion > latestVersion:
|
||||
return channeldb.ErrDBReversion
|
||||
|
||||
// Current version matches highest known version, nothing to do.
|
||||
case curVersion == latestVersion:
|
||||
return nil
|
||||
}
|
||||
|
||||
// Otherwise, apply any migrations in order to bring the database
|
||||
// version up to the highest known version.
|
||||
updates := getMigrations(versions, curVersion)
|
||||
return db.bdb().Update(func(tx *bbolt.Tx) error {
|
||||
for i, update := range updates {
|
||||
if update.migration == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
version := curVersion + uint32(i) + 1
|
||||
log.Infof("Applying migration #%d", version)
|
||||
|
||||
err := update.migration(tx)
|
||||
if err != nil {
|
||||
log.Errorf("Unable to apply migration #%d: %v",
|
||||
version, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return putDBVersion(tx, latestVersion)
|
||||
})
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
package wtmock
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@ -18,23 +17,23 @@ type ClientDB struct {
|
||||
nextTowerID uint64 // to be used atomically
|
||||
|
||||
mu sync.Mutex
|
||||
sweepPkScripts map[lnwire.ChannelID][]byte
|
||||
summaries map[lnwire.ChannelID]wtdb.ClientChanSummary
|
||||
activeSessions map[wtdb.SessionID]*wtdb.ClientSession
|
||||
towerIndex map[towerPK]uint64
|
||||
towers map[uint64]*wtdb.Tower
|
||||
towerIndex map[towerPK]wtdb.TowerID
|
||||
towers map[wtdb.TowerID]*wtdb.Tower
|
||||
|
||||
nextIndex uint32
|
||||
indexes map[uint64]uint32
|
||||
indexes map[wtdb.TowerID]uint32
|
||||
}
|
||||
|
||||
// NewClientDB initializes a new mock ClientDB.
|
||||
func NewClientDB() *ClientDB {
|
||||
return &ClientDB{
|
||||
sweepPkScripts: make(map[lnwire.ChannelID][]byte),
|
||||
summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary),
|
||||
activeSessions: make(map[wtdb.SessionID]*wtdb.ClientSession),
|
||||
towerIndex: make(map[towerPK]uint64),
|
||||
towers: make(map[uint64]*wtdb.Tower),
|
||||
indexes: make(map[uint64]uint32),
|
||||
towerIndex: make(map[towerPK]wtdb.TowerID),
|
||||
towers: make(map[wtdb.TowerID]*wtdb.Tower),
|
||||
indexes: make(map[wtdb.TowerID]uint32),
|
||||
}
|
||||
}
|
||||
|
||||
@ -54,9 +53,9 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
|
||||
tower = m.towers[towerID]
|
||||
tower.AddAddress(lnAddr.Address)
|
||||
} else {
|
||||
towerID = atomic.AddUint64(&m.nextTowerID, 1)
|
||||
towerID = wtdb.TowerID(atomic.AddUint64(&m.nextTowerID, 1))
|
||||
tower = &wtdb.Tower{
|
||||
ID: towerID,
|
||||
ID: wtdb.TowerID(towerID),
|
||||
IdentityKey: lnAddr.IdentityKey,
|
||||
Addresses: []net.Addr{lnAddr.Address},
|
||||
}
|
||||
@ -65,16 +64,16 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
|
||||
m.towerIndex[towerPubKey] = towerID
|
||||
m.towers[towerID] = tower
|
||||
|
||||
return tower, nil
|
||||
return copyTower(tower), nil
|
||||
}
|
||||
|
||||
// LoadTower retrieves a tower by its tower ID.
|
||||
func (m *ClientDB) LoadTower(towerID uint64) (*wtdb.Tower, error) {
|
||||
func (m *ClientDB) LoadTower(towerID wtdb.TowerID) (*wtdb.Tower, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if tower, ok := m.towers[towerID]; ok {
|
||||
return tower, nil
|
||||
return copyTower(tower), nil
|
||||
}
|
||||
|
||||
return nil, wtdb.ErrTowerNotFound
|
||||
@ -106,6 +105,11 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Ensure that we aren't overwriting an existing session.
|
||||
if _, ok := m.activeSessions[session.ID]; ok {
|
||||
return wtdb.ErrClientSessionAlreadyExists
|
||||
}
|
||||
|
||||
// Ensure that a session key index has been reserved for this tower.
|
||||
keyIndex, ok := m.indexes[session.TowerID]
|
||||
if !ok {
|
||||
@ -122,14 +126,16 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
|
||||
delete(m.indexes, session.TowerID)
|
||||
|
||||
m.activeSessions[session.ID] = &wtdb.ClientSession{
|
||||
TowerID: session.TowerID,
|
||||
KeyIndex: session.KeyIndex,
|
||||
ID: session.ID,
|
||||
Policy: session.Policy,
|
||||
SeqNum: session.SeqNum,
|
||||
TowerLastApplied: session.TowerLastApplied,
|
||||
RewardPkScript: cloneBytes(session.RewardPkScript),
|
||||
CommittedUpdates: make(map[uint16]*wtdb.CommittedUpdate),
|
||||
ID: session.ID,
|
||||
ClientSessionBody: wtdb.ClientSessionBody{
|
||||
SeqNum: session.SeqNum,
|
||||
TowerLastApplied: session.TowerLastApplied,
|
||||
TowerID: session.TowerID,
|
||||
KeyIndex: session.KeyIndex,
|
||||
Policy: session.Policy,
|
||||
RewardPkScript: cloneBytes(session.RewardPkScript),
|
||||
},
|
||||
CommittedUpdates: make([]wtdb.CommittedUpdate, 0),
|
||||
AckedUpdates: make(map[uint16]wtdb.BackupID),
|
||||
}
|
||||
|
||||
@ -141,7 +147,7 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
|
||||
// CreateClientSession is invoked for that tower and index, at which point a new
|
||||
// index for that tower can be reserved. Multiple calls to this method before
|
||||
// CreateClientSession is invoked should return the same index.
|
||||
func (m *ClientDB) NextSessionKeyIndex(towerID uint64) (uint32, error) {
|
||||
func (m *ClientDB) NextSessionKeyIndex(towerID wtdb.TowerID) (uint32, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
@ -149,17 +155,16 @@ func (m *ClientDB) NextSessionKeyIndex(towerID uint64) (uint32, error) {
|
||||
return index, nil
|
||||
}
|
||||
|
||||
m.nextIndex++
|
||||
index := m.nextIndex
|
||||
m.indexes[towerID] = index
|
||||
|
||||
m.nextIndex++
|
||||
|
||||
return index, nil
|
||||
}
|
||||
|
||||
// CommitUpdate persists the CommittedUpdate provided in the slot for (session,
|
||||
// seqNum). This allows the client to retransmit this update on startup.
|
||||
func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, seqNum uint16,
|
||||
func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
|
||||
update *wtdb.CommittedUpdate) (uint16, error) {
|
||||
|
||||
m.mu.Lock()
|
||||
@ -172,25 +177,26 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, seqNum uint16,
|
||||
}
|
||||
|
||||
// Check if an update has already been committed for this state.
|
||||
dbUpdate, ok := session.CommittedUpdates[seqNum]
|
||||
if ok {
|
||||
// If the breach hint matches, we'll just return the last
|
||||
// applied value so the client can retransmit.
|
||||
if dbUpdate.Hint == update.Hint {
|
||||
return session.TowerLastApplied, nil
|
||||
}
|
||||
for _, dbUpdate := range session.CommittedUpdates {
|
||||
if dbUpdate.SeqNum == update.SeqNum {
|
||||
// If the breach hint matches, we'll just return the
|
||||
// last applied value so the client can retransmit.
|
||||
if dbUpdate.Hint == update.Hint {
|
||||
return session.TowerLastApplied, nil
|
||||
}
|
||||
|
||||
// Otherwise, fail since the breach hint doesn't match.
|
||||
return 0, wtdb.ErrUpdateAlreadyCommitted
|
||||
// Otherwise, fail since the breach hint doesn't match.
|
||||
return 0, wtdb.ErrUpdateAlreadyCommitted
|
||||
}
|
||||
}
|
||||
|
||||
// Sequence number must increment.
|
||||
if seqNum != session.SeqNum+1 {
|
||||
if update.SeqNum != session.SeqNum+1 {
|
||||
return 0, wtdb.ErrCommitUnorderedUpdate
|
||||
}
|
||||
|
||||
// Save the update and increment the sequence number.
|
||||
session.CommittedUpdates[seqNum] = update
|
||||
session.CommittedUpdates = append(session.CommittedUpdates, *update)
|
||||
session.SeqNum++
|
||||
|
||||
return session.TowerLastApplied, nil
|
||||
@ -209,13 +215,6 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err
|
||||
return wtdb.ErrClientSessionNotFound
|
||||
}
|
||||
|
||||
// Retrieve the committed update, failing if none is found. We should
|
||||
// only receive acks for state updates that we send.
|
||||
update, ok := session.CommittedUpdates[seqNum]
|
||||
if !ok {
|
||||
return wtdb.ErrCommittedUpdateNotFound
|
||||
}
|
||||
|
||||
// Ensure the returned last applied value does not exceed the highest
|
||||
// allocated sequence number.
|
||||
if lastApplied > session.SeqNum {
|
||||
@ -228,40 +227,64 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err
|
||||
return wtdb.ErrLastAppliedReversion
|
||||
}
|
||||
|
||||
// Finally, remove the committed update from disk and mark the update as
|
||||
// acked. The tower last applied value is also recorded to send along
|
||||
// with the next update.
|
||||
delete(session.CommittedUpdates, seqNum)
|
||||
session.AckedUpdates[seqNum] = update.BackupID
|
||||
session.TowerLastApplied = lastApplied
|
||||
// Retrieve the committed update, failing if none is found. We should
|
||||
// only receive acks for state updates that we send.
|
||||
updates := session.CommittedUpdates
|
||||
for i, update := range updates {
|
||||
if update.SeqNum != seqNum {
|
||||
continue
|
||||
}
|
||||
|
||||
return nil
|
||||
// Remove the committed update from disk and mark the update as
|
||||
// acked. The tower last applied value is also recorded to send
|
||||
// along with the next update.
|
||||
copy(updates[:i], updates[i+1:])
|
||||
updates[len(updates)-1] = wtdb.CommittedUpdate{}
|
||||
session.CommittedUpdates = updates[:len(updates)-1]
|
||||
|
||||
session.AckedUpdates[seqNum] = update.BackupID
|
||||
session.TowerLastApplied = lastApplied
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return wtdb.ErrCommittedUpdateNotFound
|
||||
}
|
||||
|
||||
// FetchChanPkScripts returns the set of sweep pkscripts known for all channels.
|
||||
// This allows the client to cache them in memory on startup.
|
||||
func (m *ClientDB) FetchChanPkScripts() (map[lnwire.ChannelID][]byte, error) {
|
||||
// FetchChanSummaries loads a mapping from all registered channels to their
|
||||
// channel summaries.
|
||||
func (m *ClientDB) FetchChanSummaries() (wtdb.ChannelSummaries, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
sweepPkScripts := make(map[lnwire.ChannelID][]byte)
|
||||
for chanID, pkScript := range m.sweepPkScripts {
|
||||
sweepPkScripts[chanID] = cloneBytes(pkScript)
|
||||
summaries := make(map[lnwire.ChannelID]wtdb.ClientChanSummary)
|
||||
for chanID, summary := range m.summaries {
|
||||
summaries[chanID] = wtdb.ClientChanSummary{
|
||||
SweepPkScript: cloneBytes(summary.SweepPkScript),
|
||||
}
|
||||
}
|
||||
|
||||
return sweepPkScripts, nil
|
||||
return summaries, nil
|
||||
}
|
||||
|
||||
// AddChanPkScript sets a pkscript or sweeping funds from the channel or chanID.
|
||||
func (m *ClientDB) AddChanPkScript(chanID lnwire.ChannelID, pkScript []byte) error {
|
||||
// RegisterChannel registers a channel for use within the client database. For
|
||||
// now, all that is stored in the channel summary is the sweep pkscript that
|
||||
// we'd like any tower sweeps to pay into. In the future, this will be extended
|
||||
// to contain more info to allow the client efficiently request historical
|
||||
// states to be backed up under the client's active policy.
|
||||
func (m *ClientDB) RegisterChannel(chanID lnwire.ChannelID,
|
||||
sweepPkScript []byte) error {
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, ok := m.sweepPkScripts[chanID]; ok {
|
||||
return fmt.Errorf("pkscript for %x already exists", pkScript)
|
||||
if _, ok := m.summaries[chanID]; ok {
|
||||
return wtdb.ErrChannelAlreadyRegistered
|
||||
}
|
||||
|
||||
m.sweepPkScripts[chanID] = cloneBytes(pkScript)
|
||||
m.summaries[chanID] = wtdb.ClientChanSummary{
|
||||
SweepPkScript: cloneBytes(sweepPkScript),
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -276,3 +299,14 @@ func cloneBytes(b []byte) []byte {
|
||||
|
||||
return bb
|
||||
}
|
||||
|
||||
func copyTower(tower *wtdb.Tower) *wtdb.Tower {
|
||||
t := &wtdb.Tower{
|
||||
ID: tower.ID,
|
||||
IdentityKey: tower.IdentityKey,
|
||||
Addresses: make([]net.Addr, len(tower.Addresses)),
|
||||
}
|
||||
copy(t.Addresses, tower.Addresses)
|
||||
|
||||
return t
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user