watchtower/wtdb: add ClientDB
This commit adds the full bbolt-backed client database as well as a set of unit tests to assert that it exactly implements the same behavior as the mock ClientDB.
This commit is contained in:
parent
b35a5b8892
commit
3be651b0b3
@ -1,18 +1,11 @@
|
||||
package wtdb
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrChannelAlreadyRegistered signals a duplicate attempt to
|
||||
// register a channel with the client database.
|
||||
ErrChannelAlreadyRegistered = errors.New("channel already registered")
|
||||
)
|
||||
|
||||
// ChannelSummaries is a map for a given channel id to it's ClientChanSummary.
|
||||
type ChannelSummaries map[lnwire.ChannelID]ClientChanSummary
|
||||
|
||||
|
908
watchtower/wtdb/client_db.go
Normal file
908
watchtower/wtdb/client_db.go
Normal file
@ -0,0 +1,908 @@
|
||||
package wtdb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
|
||||
"github.com/coreos/bbolt"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
)
|
||||
|
||||
const (
|
||||
// clientDBName is the filename of client database.
|
||||
clientDBName = "wtclient.db"
|
||||
)
|
||||
|
||||
var (
|
||||
// cSessionKeyIndexBkt is a top-level bucket storing:
|
||||
// tower-id -> reserved-session-key-index (uint32).
|
||||
cSessionKeyIndexBkt = []byte("client-session-key-index-bucket")
|
||||
|
||||
// cChanSummaryBkt is a top-level bucket storing:
|
||||
// channel-id -> encoded ClientChanSummary.
|
||||
cChanSummaryBkt = []byte("client-channel-summary-bucket")
|
||||
|
||||
// cSessionBkt is a top-level bucket storing:
|
||||
// session-id => cSessionBody -> encoded ClientSessionBody
|
||||
// => cSessionCommits => seqnum -> encoded CommittedUpdate
|
||||
// => cSessionAcks => seqnum -> encoded BackupID
|
||||
cSessionBkt = []byte("client-session-bucket")
|
||||
|
||||
// cSessionBody is a sub-bucket of cSessionBkt storing only the body of
|
||||
// the ClientSession.
|
||||
cSessionBody = []byte("client-session-body")
|
||||
|
||||
// cSessionBody is a sub-bucket of cSessionBkt storing:
|
||||
// seqnum -> encoded CommittedUpdate.
|
||||
cSessionCommits = []byte("client-session-commits")
|
||||
|
||||
// cSessionAcks is a sub-bucket of cSessionBkt storing:
|
||||
// seqnum -> encoded BackupID.
|
||||
cSessionAcks = []byte("client-session-acks")
|
||||
|
||||
// cTowerBkt is a top-level bucket storing:
|
||||
// tower-id -> encoded Tower.
|
||||
cTowerBkt = []byte("client-tower-bucket")
|
||||
|
||||
// cTowerIndexBkt is a top-level bucket storing:
|
||||
// tower-pubkey -> tower-id.
|
||||
cTowerIndexBkt = []byte("client-tower-index-bucket")
|
||||
|
||||
// ErrTowerNotFound signals that the target tower was not found in the
|
||||
// database.
|
||||
ErrTowerNotFound = errors.New("tower not found")
|
||||
|
||||
// ErrCorruptClientSession signals that the client session's on-disk
|
||||
// structure deviates from what is expected.
|
||||
ErrCorruptClientSession = errors.New("client session corrupted")
|
||||
|
||||
// ErrClientSessionAlreadyExists signals an attempt to reinsert a client
|
||||
// session that has already been created.
|
||||
ErrClientSessionAlreadyExists = errors.New(
|
||||
"client session already exists",
|
||||
)
|
||||
|
||||
// ErrChannelAlreadyRegistered signals a duplicate attempt to register a
|
||||
// channel with the client database.
|
||||
ErrChannelAlreadyRegistered = errors.New("channel already registered")
|
||||
|
||||
// ErrChannelNotRegistered signals a channel has not yet been registered
|
||||
// in the client database.
|
||||
ErrChannelNotRegistered = errors.New("channel not registered")
|
||||
|
||||
// ErrClientSessionNotFound signals that the requested client session
|
||||
// was not found in the database.
|
||||
ErrClientSessionNotFound = errors.New("client session not found")
|
||||
|
||||
// ErrUpdateAlreadyCommitted signals that the chosen sequence number has
|
||||
// already been committed to an update with a different breach hint.
|
||||
ErrUpdateAlreadyCommitted = errors.New("update already committed")
|
||||
|
||||
// ErrCommitUnorderedUpdate signals the client tried to commit a
|
||||
// sequence number other than the next unallocated sequence number.
|
||||
ErrCommitUnorderedUpdate = errors.New("update seqnum not monotonic")
|
||||
|
||||
// ErrCommittedUpdateNotFound signals that the tower tried to ACK a
|
||||
// sequence number that has not yet been allocated by the client.
|
||||
ErrCommittedUpdateNotFound = errors.New("committed update not found")
|
||||
|
||||
// ErrUnallocatedLastApplied signals that the tower tried to provide a
|
||||
// LastApplied value greater than any allocated sequence number.
|
||||
ErrUnallocatedLastApplied = errors.New("tower echoed last appiled " +
|
||||
"greater than allocated seqnum")
|
||||
|
||||
// ErrNoReservedKeyIndex signals that a client session could not be
|
||||
// created because no session key index was reserved.
|
||||
ErrNoReservedKeyIndex = errors.New("key index not reserved")
|
||||
|
||||
// ErrIncorrectKeyIndex signals that the client session could not be
|
||||
// created because session key index differs from the reserved key
|
||||
// index.
|
||||
ErrIncorrectKeyIndex = errors.New("incorrect key index")
|
||||
)
|
||||
|
||||
// ClientDB is single database providing a persistent storage engine for the
|
||||
// wtclient.
|
||||
type ClientDB struct {
|
||||
db *bbolt.DB
|
||||
dbPath string
|
||||
}
|
||||
|
||||
// OpenClientDB opens the client database given the path to the database's
|
||||
// directory. If no such database exists, this method will initialize a fresh
|
||||
// one using the latest version number and bucket structure. If a database
|
||||
// exists but has a lower version number than the current version, any necessary
|
||||
// migrations will be applied before returning. Any attempt to open a database
|
||||
// with a version number higher that the latest version will fail to prevent
|
||||
// accidental reversion.
|
||||
func OpenClientDB(dbPath string) (*ClientDB, error) {
|
||||
bdb, firstInit, err := createDBIfNotExist(dbPath, clientDBName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clientDB := &ClientDB{
|
||||
db: bdb,
|
||||
dbPath: dbPath,
|
||||
}
|
||||
|
||||
err = initOrSyncVersions(clientDB, firstInit, clientDBVersions)
|
||||
if err != nil {
|
||||
bdb.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Now that the database version fully consistent with our latest known
|
||||
// version, ensure that all top-level buckets known to this version are
|
||||
// initialized. This allows us to assume their presence throughout all
|
||||
// operations. If an known top-level bucket is expected to exist but is
|
||||
// missing, this will trigger a ErrUninitializedDB error.
|
||||
err = clientDB.db.Update(initClientDBBuckets)
|
||||
if err != nil {
|
||||
bdb.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return clientDB, nil
|
||||
}
|
||||
|
||||
// initClientDBBuckets creates all top-level buckets required to handle database
|
||||
// operations required by the latest version.
|
||||
func initClientDBBuckets(tx *bbolt.Tx) error {
|
||||
buckets := [][]byte{
|
||||
cSessionKeyIndexBkt,
|
||||
cChanSummaryBkt,
|
||||
cSessionBkt,
|
||||
cTowerBkt,
|
||||
cTowerIndexBkt,
|
||||
}
|
||||
|
||||
for _, bucket := range buckets {
|
||||
_, err := tx.CreateBucketIfNotExists(bucket)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// bdb returns the backing bbolt.DB instance.
|
||||
//
|
||||
// NOTE: Part of the versionedDB interface.
|
||||
func (c *ClientDB) bdb() *bbolt.DB {
|
||||
return c.db
|
||||
}
|
||||
|
||||
// Version returns the database's current version number.
|
||||
//
|
||||
// NOTE: Part of the versionedDB interface.
|
||||
func (c *ClientDB) Version() (uint32, error) {
|
||||
var version uint32
|
||||
err := c.db.View(func(tx *bbolt.Tx) error {
|
||||
var err error
|
||||
version, err = getDBVersion(tx)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return version, nil
|
||||
}
|
||||
|
||||
// Close closes the underlying database.
|
||||
func (c *ClientDB) Close() error {
|
||||
return c.db.Close()
|
||||
}
|
||||
|
||||
// CreateTower initializes a database entry with the given lightning address. If
|
||||
// the tower exists, the address is append to the list of all addresses used to
|
||||
// that tower previously.
|
||||
func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
|
||||
var towerPubKey [33]byte
|
||||
copy(towerPubKey[:], lnAddr.IdentityKey.SerializeCompressed())
|
||||
|
||||
var tower *Tower
|
||||
err := c.db.Update(func(tx *bbolt.Tx) error {
|
||||
towerIndex := tx.Bucket(cTowerIndexBkt)
|
||||
if towerIndex == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
towers := tx.Bucket(cTowerBkt)
|
||||
if towers == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
// Check if the tower index already knows of this pubkey.
|
||||
towerIDBytes := towerIndex.Get(towerPubKey[:])
|
||||
if len(towerIDBytes) == 8 {
|
||||
// The tower already exists, deserialize the existing
|
||||
// record.
|
||||
var err error
|
||||
tower, err = getTower(towers, towerIDBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add the new address to the existing tower. If the
|
||||
// address is a duplicate, this will result in no
|
||||
// change.
|
||||
tower.AddAddress(lnAddr.Address)
|
||||
} else {
|
||||
// No such tower exists, create a new tower id for our
|
||||
// new tower. The error is unhandled since NextSequence
|
||||
// never fails in an Update.
|
||||
towerID, _ := towerIndex.NextSequence()
|
||||
|
||||
tower = &Tower{
|
||||
ID: TowerID(towerID),
|
||||
IdentityKey: lnAddr.IdentityKey,
|
||||
Addresses: []net.Addr{lnAddr.Address},
|
||||
}
|
||||
|
||||
towerIDBytes = tower.ID.Bytes()
|
||||
|
||||
// Since this tower is new, record the mapping from
|
||||
// tower pubkey to tower id in the tower index.
|
||||
err := towerIndex.Put(towerPubKey[:], towerIDBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Store the new or updated tower under its tower id.
|
||||
return putTower(towers, tower)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tower, nil
|
||||
}
|
||||
|
||||
// LoadTower retrieves a tower by its tower ID.
|
||||
func (c *ClientDB) LoadTower(towerID TowerID) (*Tower, error) {
|
||||
var tower *Tower
|
||||
err := c.db.View(func(tx *bbolt.Tx) error {
|
||||
towers := tx.Bucket(cTowerBkt)
|
||||
if towers == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
var err error
|
||||
tower, err = getTower(towers, towerID.Bytes())
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tower, nil
|
||||
}
|
||||
|
||||
// NextSessionKeyIndex reserves a new session key derivation index for a
|
||||
// particular tower id. The index is reserved for that tower until
|
||||
// CreateClientSession is invoked for that tower and index, at which point a new
|
||||
// index for that tower can be reserved. Multiple calls to this method before
|
||||
// CreateClientSession is invoked should return the same index.
|
||||
func (c *ClientDB) NextSessionKeyIndex(towerID TowerID) (uint32, error) {
|
||||
var index uint32
|
||||
err := c.db.Update(func(tx *bbolt.Tx) error {
|
||||
keyIndex := tx.Bucket(cSessionKeyIndexBkt)
|
||||
if keyIndex == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
// Check the session key index to see if a key has already been
|
||||
// reserved for this tower. If so, we'll deserialize and return
|
||||
// the index directly.
|
||||
towerIDBytes := towerID.Bytes()
|
||||
indexBytes := keyIndex.Get(towerIDBytes)
|
||||
if len(indexBytes) == 4 {
|
||||
index = byteOrder.Uint32(indexBytes)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Otherwise, generate a new session key index since the node
|
||||
// doesn't already have reserved index. The error is ignored
|
||||
// since NextSequence can't fail inside Update.
|
||||
index64, _ := keyIndex.NextSequence()
|
||||
|
||||
// As a sanity check, assert that the index is still in the
|
||||
// valid range of unhardened pubkeys. In the future, we should
|
||||
// move to only using hardened keys, and this will prevent any
|
||||
// overlap from occurring until then. This also prevents us from
|
||||
// overflowing uint32s.
|
||||
if index64 > math.MaxInt32 {
|
||||
return fmt.Errorf("exhausted session key indexes")
|
||||
}
|
||||
|
||||
index = uint32(index64)
|
||||
|
||||
var indexBuf [4]byte
|
||||
byteOrder.PutUint32(indexBuf[:], index)
|
||||
|
||||
// Record the reserved session key index under this tower's id.
|
||||
return keyIndex.Put(towerIDBytes, indexBuf[:])
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return index, nil
|
||||
}
|
||||
|
||||
// CreateClientSession records a newly negotiated client session in the set of
|
||||
// active sessions. The session can be identified by its SessionID.
|
||||
func (c *ClientDB) CreateClientSession(session *ClientSession) error {
|
||||
return c.db.Update(func(tx *bbolt.Tx) error {
|
||||
keyIndexes := tx.Bucket(cSessionKeyIndexBkt)
|
||||
if keyIndexes == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
sessions := tx.Bucket(cSessionBkt)
|
||||
if sessions == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
// Check that client session with this session id doesn't
|
||||
// already exist.
|
||||
existingSessionBytes := sessions.Bucket(session.ID[:])
|
||||
if existingSessionBytes != nil {
|
||||
return ErrClientSessionAlreadyExists
|
||||
}
|
||||
|
||||
// Check that this tower has a reserved key index.
|
||||
towerIDBytes := session.TowerID.Bytes()
|
||||
keyIndexBytes := keyIndexes.Get(towerIDBytes)
|
||||
if len(keyIndexBytes) != 4 {
|
||||
return ErrNoReservedKeyIndex
|
||||
}
|
||||
|
||||
// Assert that the key index of the inserted session matches the
|
||||
// reserved session key index.
|
||||
index := byteOrder.Uint32(keyIndexBytes)
|
||||
if index != session.KeyIndex {
|
||||
return ErrIncorrectKeyIndex
|
||||
}
|
||||
|
||||
// Remove the key index reservation.
|
||||
err := keyIndexes.Delete(towerIDBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Finally, write the client session's body in the sessions
|
||||
// bucket.
|
||||
return putClientSessionBody(sessions, session)
|
||||
})
|
||||
}
|
||||
|
||||
// ListClientSessions returns the set of all client sessions known to the db.
|
||||
func (c *ClientDB) ListClientSessions() (map[SessionID]*ClientSession, error) {
|
||||
clientSessions := make(map[SessionID]*ClientSession)
|
||||
err := c.db.View(func(tx *bbolt.Tx) error {
|
||||
sessions := tx.Bucket(cSessionBkt)
|
||||
if sessions == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
return sessions.ForEach(func(k, _ []byte) error {
|
||||
// We'll load the full client session since the client
|
||||
// will need the CommittedUpdates and AckedUpdates on
|
||||
// startup to resume committed updates and compute the
|
||||
// highest known commit height for each channel.
|
||||
session, err := getClientSession(sessions, k)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
clientSessions[session.ID] = session
|
||||
|
||||
return nil
|
||||
})
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return clientSessions, nil
|
||||
}
|
||||
|
||||
// FetchChanSummaries loads a mapping from all registered channels to their
|
||||
// channel summaries.
|
||||
func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) {
|
||||
summaries := make(map[lnwire.ChannelID]ClientChanSummary)
|
||||
err := c.db.View(func(tx *bbolt.Tx) error {
|
||||
chanSummaries := tx.Bucket(cChanSummaryBkt)
|
||||
if chanSummaries == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
return chanSummaries.ForEach(func(k, v []byte) error {
|
||||
var chanID lnwire.ChannelID
|
||||
copy(chanID[:], k)
|
||||
|
||||
var summary ClientChanSummary
|
||||
err := summary.Decode(bytes.NewReader(v))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
summaries[chanID] = summary
|
||||
|
||||
return nil
|
||||
})
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return summaries, nil
|
||||
}
|
||||
|
||||
// RegisterChannel registers a channel for use within the client database. For
|
||||
// now, all that is stored in the channel summary is the sweep pkscript that
|
||||
// we'd like any tower sweeps to pay into. In the future, this will be extended
|
||||
// to contain more info to allow the client efficiently request historical
|
||||
// states to be backed up under the client's active policy.
|
||||
func (c *ClientDB) RegisterChannel(chanID lnwire.ChannelID,
|
||||
sweepPkScript []byte) error {
|
||||
|
||||
return c.db.Update(func(tx *bbolt.Tx) error {
|
||||
chanSummaries := tx.Bucket(cChanSummaryBkt)
|
||||
if chanSummaries == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
_, err := getChanSummary(chanSummaries, chanID)
|
||||
switch {
|
||||
|
||||
// Summary already exists.
|
||||
case err == nil:
|
||||
return ErrChannelAlreadyRegistered
|
||||
|
||||
// Channel is not registered, proceed with registration.
|
||||
case err == ErrChannelNotRegistered:
|
||||
|
||||
// Unexpected error.
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
summary := ClientChanSummary{
|
||||
SweepPkScript: sweepPkScript,
|
||||
}
|
||||
|
||||
return putChanSummary(chanSummaries, chanID, &summary)
|
||||
})
|
||||
}
|
||||
|
||||
// MarkBackupIneligible records that the state identified by the (channel id,
|
||||
// commit height) tuple was ineligible for being backed up under the current
|
||||
// policy. This state can be retried later under a different policy.
|
||||
func (c *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID,
|
||||
commitHeight uint64) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CommitUpdate persists the CommittedUpdate provided in the slot for (session,
|
||||
// seqNum). This allows the client to retransmit this update on startup.
|
||||
func (c *ClientDB) CommitUpdate(id *SessionID,
|
||||
update *CommittedUpdate) (uint16, error) {
|
||||
|
||||
var lastApplied uint16
|
||||
err := c.db.Update(func(tx *bbolt.Tx) error {
|
||||
sessions := tx.Bucket(cSessionBkt)
|
||||
if sessions == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
// We'll only load the ClientSession body for performance, since
|
||||
// we primarily need to inspect its SeqNum and TowerLastApplied
|
||||
// fields. The CommittedUpdates will be modified on disk
|
||||
// directly.
|
||||
session, err := getClientSessionBody(sessions, id[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Can't fail if the above didn't fail.
|
||||
sessionBkt := sessions.Bucket(id[:])
|
||||
|
||||
// Ensure the session commits sub-bucket is initialized.
|
||||
sessionCommits, err := sessionBkt.CreateBucketIfNotExists(
|
||||
cSessionCommits,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var seqNumBuf [2]byte
|
||||
byteOrder.PutUint16(seqNumBuf[:], update.SeqNum)
|
||||
|
||||
// Check to see if a committed update already exists for this
|
||||
// sequence number.
|
||||
committedUpdateBytes := sessionCommits.Get(seqNumBuf[:])
|
||||
if committedUpdateBytes != nil {
|
||||
var dbUpdate CommittedUpdate
|
||||
err := dbUpdate.Decode(
|
||||
bytes.NewReader(committedUpdateBytes),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If an existing committed update has a different hint,
|
||||
// we'll reject this newer update.
|
||||
if dbUpdate.Hint != update.Hint {
|
||||
return ErrUpdateAlreadyCommitted
|
||||
}
|
||||
|
||||
// Otherwise, capture the last applied value and
|
||||
// succeed.
|
||||
lastApplied = session.TowerLastApplied
|
||||
return nil
|
||||
}
|
||||
|
||||
// There's no committed update for this sequence number, ensure
|
||||
// that we are committing the next unallocated one.
|
||||
if update.SeqNum != session.SeqNum+1 {
|
||||
return ErrCommitUnorderedUpdate
|
||||
}
|
||||
|
||||
// Increment the session's sequence number and store the updated
|
||||
// client session.
|
||||
//
|
||||
// TODO(conner): split out seqnum and last applied own bucket to
|
||||
// eliminate serialization of full struct during CommitUpdate?
|
||||
// Can also read/write directly to byes [:2] without migration.
|
||||
session.SeqNum++
|
||||
err = putClientSessionBody(sessions, session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Encode and store the committed update in the sessionCommits
|
||||
// sub-bucket under the requested sequence number.
|
||||
var b bytes.Buffer
|
||||
err = update.Encode(&b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = sessionCommits.Put(seqNumBuf[:], b.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Finally, capture the session's last applied value so it can
|
||||
// be sent in the next state update to the tower.
|
||||
lastApplied = session.TowerLastApplied
|
||||
|
||||
return nil
|
||||
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return lastApplied, nil
|
||||
}
|
||||
|
||||
// AckUpdate persists an acknowledgment for a given (session, seqnum) pair. This
|
||||
// removes the update from the set of committed updates, and validates the
|
||||
// lastApplied value returned from the tower.
|
||||
func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
|
||||
lastApplied uint16) error {
|
||||
|
||||
return c.db.Update(func(tx *bbolt.Tx) error {
|
||||
sessions := tx.Bucket(cSessionBkt)
|
||||
if sessions == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
// We'll only load the ClientSession body for performance, since
|
||||
// we primarily need to inspect its SeqNum and TowerLastApplied
|
||||
// fields. The CommittedUpdates and AckedUpdates will be
|
||||
// modified on disk directly.
|
||||
session, err := getClientSessionBody(sessions, id[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If the tower has acked a sequence number beyond our highest
|
||||
// sequence number, fail.
|
||||
if lastApplied > session.SeqNum {
|
||||
return ErrUnallocatedLastApplied
|
||||
}
|
||||
|
||||
// If the tower acked with a lower sequence number than it gave
|
||||
// us prior, fail.
|
||||
if lastApplied < session.TowerLastApplied {
|
||||
return ErrLastAppliedReversion
|
||||
}
|
||||
|
||||
// TODO(conner): split out seqnum and last applied own bucket to
|
||||
// eliminate serialization of full struct during AckUpdate? Can
|
||||
// also read/write directly to byes [2:4] without migration.
|
||||
session.TowerLastApplied = lastApplied
|
||||
|
||||
// Write the client session with the updated last applied value.
|
||||
err = putClientSessionBody(sessions, session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Can't fail because of getClientSession succeeded.
|
||||
sessionBkt := sessions.Bucket(id[:])
|
||||
|
||||
// If the commits sub-bucket doesn't exist, there can't possibly
|
||||
// be a corresponding committed update to remove.
|
||||
sessionCommits := sessionBkt.Bucket(cSessionCommits)
|
||||
if sessionCommits == nil {
|
||||
return ErrCommittedUpdateNotFound
|
||||
}
|
||||
|
||||
var seqNumBuf [2]byte
|
||||
byteOrder.PutUint16(seqNumBuf[:], seqNum)
|
||||
|
||||
// Assert that a committed update exists for this sequence
|
||||
// number.
|
||||
committedUpdateBytes := sessionCommits.Get(seqNumBuf[:])
|
||||
if committedUpdateBytes == nil {
|
||||
return ErrCommittedUpdateNotFound
|
||||
}
|
||||
|
||||
var committedUpdate CommittedUpdate
|
||||
err = committedUpdate.Decode(
|
||||
bytes.NewReader(committedUpdateBytes),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove the corresponding committed update.
|
||||
err = sessionCommits.Delete(seqNumBuf[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure that the session acks sub-bucket is initialized so we
|
||||
// can insert an entry.
|
||||
sessionAcks, err := sessionBkt.CreateBucketIfNotExists(
|
||||
cSessionAcks,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// The session acks only need to track the backup id of the
|
||||
// update, so we can discard the blob and hint.
|
||||
var b bytes.Buffer
|
||||
err = committedUpdate.BackupID.Encode(&b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Finally, insert the ack into the sessionAcks sub-bucket.
|
||||
return sessionAcks.Put(seqNumBuf[:], b.Bytes())
|
||||
})
|
||||
}
|
||||
|
||||
// getClientSessionBody loads the body of a ClientSession from the sessions
|
||||
// bucket corresponding to the serialized session id. This does not deserialize
|
||||
// the CommittedUpdates or AckUpdates associated with the session. If the caller
|
||||
// requires this info, use getClientSession.
|
||||
func getClientSessionBody(sessions *bbolt.Bucket,
|
||||
idBytes []byte) (*ClientSession, error) {
|
||||
|
||||
sessionBkt := sessions.Bucket(idBytes)
|
||||
if sessionBkt == nil {
|
||||
return nil, ErrClientSessionNotFound
|
||||
}
|
||||
|
||||
// Should never have a sessionBkt without also having its body.
|
||||
sessionBody := sessionBkt.Get(cSessionBody)
|
||||
if sessionBody == nil {
|
||||
return nil, ErrCorruptClientSession
|
||||
}
|
||||
|
||||
var session ClientSession
|
||||
copy(session.ID[:], idBytes)
|
||||
|
||||
err := session.Decode(bytes.NewReader(sessionBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// getClientSession loads the full ClientSession associated with the serialized
|
||||
// session id. This method populates the CommittedUpdates and AckUpdates in
|
||||
// addition to the ClientSession's body.
|
||||
func getClientSession(sessions *bbolt.Bucket,
|
||||
idBytes []byte) (*ClientSession, error) {
|
||||
|
||||
session, err := getClientSessionBody(sessions, idBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fetch the committed updates for this session.
|
||||
commitedUpdates, err := getClientSessionCommits(sessions, idBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fetch the acked updates for this session.
|
||||
ackedUpdates, err := getClientSessionAcks(sessions, idBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session.CommittedUpdates = commitedUpdates
|
||||
session.AckedUpdates = ackedUpdates
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// getClientSessionCommits retrieves all committed updates for the session
|
||||
// identified by the serialized session id.
|
||||
func getClientSessionCommits(sessions *bbolt.Bucket,
|
||||
idBytes []byte) ([]CommittedUpdate, error) {
|
||||
|
||||
// Can't fail because client session body has already been read.
|
||||
sessionBkt := sessions.Bucket(idBytes)
|
||||
|
||||
// Initialize commitedUpdates so that we can return an initialized map
|
||||
// if no committed updates exist.
|
||||
committedUpdates := make([]CommittedUpdate, 0)
|
||||
|
||||
sessionCommits := sessionBkt.Bucket(cSessionCommits)
|
||||
if sessionCommits == nil {
|
||||
return committedUpdates, nil
|
||||
}
|
||||
|
||||
err := sessionCommits.ForEach(func(k, v []byte) error {
|
||||
var committedUpdate CommittedUpdate
|
||||
err := committedUpdate.Decode(bytes.NewReader(v))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
committedUpdate.SeqNum = byteOrder.Uint16(k)
|
||||
|
||||
committedUpdates = append(committedUpdates, committedUpdate)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return committedUpdates, nil
|
||||
}
|
||||
|
||||
// getClientSessionAcks retrieves all acked updates for the session identified
|
||||
// by the serialized session id.
|
||||
func getClientSessionAcks(sessions *bbolt.Bucket,
|
||||
idBytes []byte) (map[uint16]BackupID, error) {
|
||||
|
||||
// Can't fail because client session body has already been read.
|
||||
sessionBkt := sessions.Bucket(idBytes)
|
||||
|
||||
// Initialize ackedUpdates so that we can return an initialized map if
|
||||
// no acked updates exist.
|
||||
ackedUpdates := make(map[uint16]BackupID)
|
||||
|
||||
sessionAcks := sessionBkt.Bucket(cSessionAcks)
|
||||
if sessionAcks == nil {
|
||||
return ackedUpdates, nil
|
||||
}
|
||||
|
||||
err := sessionAcks.ForEach(func(k, v []byte) error {
|
||||
seqNum := byteOrder.Uint16(k)
|
||||
|
||||
var backupID BackupID
|
||||
err := backupID.Decode(bytes.NewReader(v))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ackedUpdates[seqNum] = backupID
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ackedUpdates, nil
|
||||
}
|
||||
|
||||
// putClientSessionBody stores the body of the ClientSession (everything but the
|
||||
// CommittedUpdates and AckedUpdates).
|
||||
func putClientSessionBody(sessions *bbolt.Bucket,
|
||||
session *ClientSession) error {
|
||||
|
||||
sessionBkt, err := sessions.CreateBucketIfNotExists(session.ID[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
err = session.Encode(&b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return sessionBkt.Put(cSessionBody, b.Bytes())
|
||||
}
|
||||
|
||||
// getChanSummary loads a ClientChanSummary for the passed chanID.
|
||||
func getChanSummary(chanSummaries *bbolt.Bucket,
|
||||
chanID lnwire.ChannelID) (*ClientChanSummary, error) {
|
||||
|
||||
chanSummaryBytes := chanSummaries.Get(chanID[:])
|
||||
if chanSummaryBytes == nil {
|
||||
return nil, ErrChannelNotRegistered
|
||||
}
|
||||
|
||||
var summary ClientChanSummary
|
||||
err := summary.Decode(bytes.NewReader(chanSummaryBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &summary, nil
|
||||
}
|
||||
|
||||
// putChanSummary stores a ClientChanSummary for the passed chanID.
|
||||
func putChanSummary(chanSummaries *bbolt.Bucket, chanID lnwire.ChannelID,
|
||||
summary *ClientChanSummary) error {
|
||||
|
||||
var b bytes.Buffer
|
||||
err := summary.Encode(&b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return chanSummaries.Put(chanID[:], b.Bytes())
|
||||
}
|
||||
|
||||
// getTower loads a Tower identified by its serialized tower id.
|
||||
func getTower(towers *bbolt.Bucket, id []byte) (*Tower, error) {
|
||||
towerBytes := towers.Get(id)
|
||||
if towerBytes == nil {
|
||||
return nil, ErrTowerNotFound
|
||||
}
|
||||
|
||||
var tower Tower
|
||||
err := tower.Decode(bytes.NewReader(towerBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tower.ID = TowerIDFromBytes(id)
|
||||
|
||||
return &tower, nil
|
||||
}
|
||||
|
||||
// putTower stores a Tower identified by its serialized tower id.
|
||||
func putTower(towers *bbolt.Bucket, tower *Tower) error {
|
||||
var b bytes.Buffer
|
||||
err := tower.Encode(&b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return towers.Put(tower.ID.Bytes(), b.Bytes())
|
||||
}
|
688
watchtower/wtdb/client_db_test.go
Normal file
688
watchtower/wtdb/client_db_test.go
Normal file
@ -0,0 +1,688 @@
|
||||
package wtdb_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
crand "crypto/rand"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/watchtower/blob"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtclient"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtmock"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
||||
)
|
||||
|
||||
// clientDBInit is a closure used to initialize a wtclient.DB instance its
|
||||
// cleanup function.
|
||||
type clientDBInit func(t *testing.T) (wtclient.DB, func())
|
||||
|
||||
type clientDBHarness struct {
|
||||
t *testing.T
|
||||
db wtclient.DB
|
||||
}
|
||||
|
||||
func newClientDBHarness(t *testing.T, init clientDBInit) (*clientDBHarness, func()) {
|
||||
db, cleanup := init(t)
|
||||
|
||||
h := &clientDBHarness{
|
||||
t: t,
|
||||
db: db,
|
||||
}
|
||||
|
||||
return h, cleanup
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) insertSession(session *wtdb.ClientSession, expErr error) {
|
||||
h.t.Helper()
|
||||
|
||||
err := h.db.CreateClientSession(session)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected create client session error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) listSessions() map[wtdb.SessionID]*wtdb.ClientSession {
|
||||
h.t.Helper()
|
||||
|
||||
sessions, err := h.db.ListClientSessions()
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to list client sessions: %v", err)
|
||||
}
|
||||
|
||||
return sessions
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) nextKeyIndex(id wtdb.TowerID, expErr error) uint32 {
|
||||
h.t.Helper()
|
||||
|
||||
index, err := h.db.NextSessionKeyIndex(id)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected next session key index error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
|
||||
if index == 0 {
|
||||
h.t.Fatalf("next key index should never be 0")
|
||||
}
|
||||
|
||||
return index
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) createTower(lnAddr *lnwire.NetAddress,
|
||||
expErr error) *wtdb.Tower {
|
||||
|
||||
h.t.Helper()
|
||||
|
||||
tower, err := h.db.CreateTower(lnAddr)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected create tower error: %v, got: %v", expErr, err)
|
||||
}
|
||||
|
||||
if tower.ID == 0 {
|
||||
h.t.Fatalf("tower id should never be 0")
|
||||
}
|
||||
|
||||
return tower
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) loadTower(id wtdb.TowerID, expErr error) *wtdb.Tower {
|
||||
h.t.Helper()
|
||||
|
||||
tower, err := h.db.LoadTower(id)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected load tower error: %v, got: %v", expErr, err)
|
||||
}
|
||||
|
||||
return tower
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) fetchChanSummaries() map[lnwire.ChannelID]wtdb.ClientChanSummary {
|
||||
h.t.Helper()
|
||||
|
||||
summaries, err := h.db.FetchChanSummaries()
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to fetch chan summaries: %v", err)
|
||||
}
|
||||
|
||||
return summaries
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) registerChan(chanID lnwire.ChannelID,
|
||||
sweepPkScript []byte, expErr error) {
|
||||
|
||||
h.t.Helper()
|
||||
|
||||
err := h.db.RegisterChannel(chanID, sweepPkScript)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected register channel error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) commitUpdate(id *wtdb.SessionID,
|
||||
update *wtdb.CommittedUpdate, expErr error) uint16 {
|
||||
|
||||
h.t.Helper()
|
||||
|
||||
lastApplied, err := h.db.CommitUpdate(id, update)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected commit update error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
|
||||
return lastApplied
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16,
|
||||
lastApplied uint16, expErr error) {
|
||||
|
||||
h.t.Helper()
|
||||
|
||||
err := h.db.AckUpdate(id, seqNum, lastApplied)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected commit update error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
// testCreateClientSession asserts various conditions regarding the creation of
|
||||
// a new ClientSession. The test asserts:
|
||||
// - client sessions can only be created if a session key index is reserved.
|
||||
// - client sessions cannot be created with an incorrect session key index .
|
||||
// - inserting duplicate sessions fails.
|
||||
func testCreateClientSession(h *clientDBHarness) {
|
||||
// Create a test client session to insert.
|
||||
session := &wtdb.ClientSession{
|
||||
ClientSessionBody: wtdb.ClientSessionBody{
|
||||
TowerID: wtdb.TowerID(3),
|
||||
Policy: wtpolicy.Policy{
|
||||
MaxUpdates: 100,
|
||||
},
|
||||
RewardPkScript: []byte{0x01, 0x02, 0x03},
|
||||
},
|
||||
ID: wtdb.SessionID([33]byte{0x01}),
|
||||
}
|
||||
|
||||
// First, assert that this session is not already present in the
|
||||
// database.
|
||||
if _, ok := h.listSessions()[session.ID]; ok {
|
||||
h.t.Fatalf("session for id %x should not exist yet", session.ID)
|
||||
}
|
||||
|
||||
// Attempting to insert the client session without reserving a session
|
||||
// key index should fail.
|
||||
h.insertSession(session, wtdb.ErrNoReservedKeyIndex)
|
||||
|
||||
// Now, reserve a session key for this tower.
|
||||
keyIndex := h.nextKeyIndex(session.TowerID, nil)
|
||||
|
||||
// The client session hasn't been updated with the reserved key index
|
||||
// (since it's still zero). Inserting should fail due to the mismatch.
|
||||
h.insertSession(session, wtdb.ErrIncorrectKeyIndex)
|
||||
|
||||
// Reserve another key for the same index. Since no session has been
|
||||
// successfully created, it should return the same index to maintain
|
||||
// idempotency across restarts.
|
||||
keyIndex2 := h.nextKeyIndex(session.TowerID, nil)
|
||||
if keyIndex != keyIndex2 {
|
||||
h.t.Fatalf("next key index should be idempotent: want: %v, "+
|
||||
"got %v", keyIndex, keyIndex2)
|
||||
}
|
||||
|
||||
// Now, set the client session's key index so that it is proper and
|
||||
// insert it. This should succeed.
|
||||
session.KeyIndex = keyIndex
|
||||
h.insertSession(session, nil)
|
||||
|
||||
// Verify that the session now exists in the database.
|
||||
if _, ok := h.listSessions()[session.ID]; !ok {
|
||||
h.t.Fatalf("session for id %x should exist now", session.ID)
|
||||
}
|
||||
|
||||
// Attempt to insert the session again, which should fail due to the
|
||||
// session already existing.
|
||||
h.insertSession(session, wtdb.ErrClientSessionAlreadyExists)
|
||||
|
||||
// Finally, assert that reserving another key index succeeds with a
|
||||
// different key index, now that the first one has been finalized.
|
||||
keyIndex3 := h.nextKeyIndex(session.TowerID, nil)
|
||||
if keyIndex == keyIndex3 {
|
||||
h.t.Fatalf("key index still reserved after creating session")
|
||||
}
|
||||
}
|
||||
|
||||
// testCreateTower asserts the behavior of creating new Tower objects within the
|
||||
// database, and that the latest address is always prepended to the list of
|
||||
// known addresses for the tower.
|
||||
func testCreateTower(h *clientDBHarness) {
|
||||
// Test that loading a tower with an arbitrary tower id fails.
|
||||
h.loadTower(20, wtdb.ErrTowerNotFound)
|
||||
|
||||
pk, err := randPubKey()
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to generate pubkey: %v", err)
|
||||
}
|
||||
|
||||
addr1 := &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911}
|
||||
lnAddr := &lnwire.NetAddress{
|
||||
IdentityKey: pk,
|
||||
Address: addr1,
|
||||
}
|
||||
|
||||
// Insert a random tower into the database.
|
||||
tower := h.createTower(lnAddr, nil)
|
||||
|
||||
// Load the tower from the database and assert that it matches the tower
|
||||
// we created.
|
||||
tower2 := h.loadTower(tower.ID, nil)
|
||||
if !reflect.DeepEqual(tower, tower2) {
|
||||
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
|
||||
tower, tower2)
|
||||
}
|
||||
|
||||
// Insert the address again into the database. Since the address is the
|
||||
// same, this should result in an unmodified tower record.
|
||||
towerDupAddr := h.createTower(lnAddr, nil)
|
||||
if len(towerDupAddr.Addresses) != 1 {
|
||||
h.t.Fatalf("duplicate address should be deduped")
|
||||
}
|
||||
if !reflect.DeepEqual(tower, towerDupAddr) {
|
||||
h.t.Fatalf("mismatch towers, want: %v, got: %v",
|
||||
tower, towerDupAddr)
|
||||
}
|
||||
|
||||
// Generate a new address for this tower.
|
||||
addr2 := &net.TCPAddr{IP: []byte{0x02, 0x00, 0x00, 0x00}, Port: 9911}
|
||||
|
||||
lnAddr2 := &lnwire.NetAddress{
|
||||
IdentityKey: pk,
|
||||
Address: addr2,
|
||||
}
|
||||
|
||||
// Insert the updated address, which should produce a tower with a new
|
||||
// address.
|
||||
towerNewAddr := h.createTower(lnAddr2, nil)
|
||||
|
||||
// Load the tower from the database, and assert that it matches the
|
||||
// tower returned from creation.
|
||||
towerNewAddr2 := h.loadTower(tower.ID, nil)
|
||||
if !reflect.DeepEqual(towerNewAddr, towerNewAddr2) {
|
||||
h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
|
||||
towerNewAddr, towerNewAddr2)
|
||||
}
|
||||
|
||||
// Assert that there are now two addresses on the tower object.
|
||||
if len(towerNewAddr.Addresses) != 2 {
|
||||
h.t.Fatalf("new address should be added")
|
||||
}
|
||||
|
||||
// Finally, assert that the new address was prepended since it is deemed
|
||||
// fresher.
|
||||
if !reflect.DeepEqual(tower.Addresses, towerNewAddr.Addresses[1:]) {
|
||||
h.t.Fatalf("new address should be prepended")
|
||||
}
|
||||
}
|
||||
|
||||
// testChanSummaries tests the process of a registering a channel and its
|
||||
// associated sweep pkscript.
|
||||
func testChanSummaries(h *clientDBHarness) {
|
||||
// First, assert that this channel is not already registered.
|
||||
var chanID lnwire.ChannelID
|
||||
if _, ok := h.fetchChanSummaries()[chanID]; ok {
|
||||
h.t.Fatalf("pkscript for channel %x should not exist yet",
|
||||
chanID)
|
||||
}
|
||||
|
||||
// Generate a random sweep pkscript and register it for this channel.
|
||||
expPkScript := make([]byte, 22)
|
||||
if _, err := io.ReadFull(crand.Reader, expPkScript); err != nil {
|
||||
h.t.Fatalf("unable to generate pkscript: %v", err)
|
||||
}
|
||||
h.registerChan(chanID, expPkScript, nil)
|
||||
|
||||
// Assert that the channel exists and that its sweep pkscript matches
|
||||
// the one we registered.
|
||||
summary, ok := h.fetchChanSummaries()[chanID]
|
||||
if !ok {
|
||||
h.t.Fatalf("pkscript for channel %x should not exist yet",
|
||||
chanID)
|
||||
} else if bytes.Compare(expPkScript, summary.SweepPkScript) != 0 {
|
||||
h.t.Fatalf("pkscript mismatch, want: %x, got: %x",
|
||||
expPkScript, summary.SweepPkScript)
|
||||
}
|
||||
|
||||
// Finally, assert that re-registering the same channel produces a
|
||||
// failure.
|
||||
h.registerChan(chanID, expPkScript, wtdb.ErrChannelAlreadyRegistered)
|
||||
}
|
||||
|
||||
// testCommitUpdate tests the behavior of CommitUpdate, ensuring that they can
|
||||
func testCommitUpdate(h *clientDBHarness) {
|
||||
session := &wtdb.ClientSession{
|
||||
ClientSessionBody: wtdb.ClientSessionBody{
|
||||
TowerID: wtdb.TowerID(3),
|
||||
Policy: wtpolicy.Policy{
|
||||
MaxUpdates: 100,
|
||||
},
|
||||
RewardPkScript: []byte{0x01, 0x02, 0x03},
|
||||
},
|
||||
ID: wtdb.SessionID([33]byte{0x02}),
|
||||
}
|
||||
|
||||
// Generate a random update and try to commit before inserting the
|
||||
// session, which should fail.
|
||||
update1 := randCommittedUpdate(h.t, 1)
|
||||
h.commitUpdate(&session.ID, update1, wtdb.ErrClientSessionNotFound)
|
||||
|
||||
// Reserve a session key index and insert the session.
|
||||
session.KeyIndex = h.nextKeyIndex(session.TowerID, nil)
|
||||
h.insertSession(session, nil)
|
||||
|
||||
// Now, try to commit the update that failed initially which should
|
||||
// succeed. The lastApplied value should be 0 since we have not received
|
||||
// an ack from the tower.
|
||||
lastApplied := h.commitUpdate(&session.ID, update1, nil)
|
||||
if lastApplied != 0 {
|
||||
h.t.Fatalf("last applied mismatch, want: 0, got: %v",
|
||||
lastApplied)
|
||||
}
|
||||
|
||||
// Assert that the committed update appears in the client session's
|
||||
// CommittedUpdates map when loaded from disk and that there are no
|
||||
// AckedUpdates.
|
||||
dbSession := h.listSessions()[session.ID]
|
||||
checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
|
||||
*update1,
|
||||
})
|
||||
checkAckedUpdates(h.t, dbSession, nil)
|
||||
|
||||
// Try to commit the same update, which should succeed due to
|
||||
// idempotency (which is preserved when the breach hint is identical to
|
||||
// the on-disk update's hint). The lastApplied value should remain
|
||||
// unchanged.
|
||||
lastApplied2 := h.commitUpdate(&session.ID, update1, nil)
|
||||
if lastApplied2 != lastApplied {
|
||||
h.t.Fatalf("last applied should not have changed, got %v",
|
||||
lastApplied2)
|
||||
}
|
||||
|
||||
// Assert that the loaded ClientSession is the same as before.
|
||||
dbSession = h.listSessions()[session.ID]
|
||||
checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
|
||||
*update1,
|
||||
})
|
||||
checkAckedUpdates(h.t, dbSession, nil)
|
||||
|
||||
// Generate another random update and try to commit it at the identical
|
||||
// sequence number. Since the breach hint has changed, this should fail.
|
||||
update2 := randCommittedUpdate(h.t, 1)
|
||||
h.commitUpdate(&session.ID, update2, wtdb.ErrUpdateAlreadyCommitted)
|
||||
|
||||
// Next, insert the new update at the next unallocated sequence number
|
||||
// which should succeed.
|
||||
update2.SeqNum = 2
|
||||
lastApplied3 := h.commitUpdate(&session.ID, update2, nil)
|
||||
if lastApplied3 != lastApplied {
|
||||
h.t.Fatalf("last applied should not have changed, got %v",
|
||||
lastApplied3)
|
||||
}
|
||||
|
||||
// Check that both updates now appear as committed on the ClientSession
|
||||
// loaded from disk.
|
||||
dbSession = h.listSessions()[session.ID]
|
||||
checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
|
||||
*update1,
|
||||
*update2,
|
||||
})
|
||||
checkAckedUpdates(h.t, dbSession, nil)
|
||||
|
||||
// Finally, create one more random update and try to commit it at index
|
||||
// 4, which should be rejected since 3 is the next slot the database
|
||||
// expects.
|
||||
update4 := randCommittedUpdate(h.t, 4)
|
||||
h.commitUpdate(&session.ID, update4, wtdb.ErrCommitUnorderedUpdate)
|
||||
|
||||
// Assert that the ClientSession loaded from disk remains unchanged.
|
||||
dbSession = h.listSessions()[session.ID]
|
||||
checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
|
||||
*update1,
|
||||
*update2,
|
||||
})
|
||||
checkAckedUpdates(h.t, dbSession, nil)
|
||||
}
|
||||
|
||||
// testAckUpdate asserts the behavior of AckUpdate.
|
||||
func testAckUpdate(h *clientDBHarness) {
|
||||
// Create a new session that the updates in this will be tied to.
|
||||
session := &wtdb.ClientSession{
|
||||
ClientSessionBody: wtdb.ClientSessionBody{
|
||||
TowerID: wtdb.TowerID(3),
|
||||
Policy: wtpolicy.Policy{
|
||||
MaxUpdates: 100,
|
||||
},
|
||||
RewardPkScript: []byte{0x01, 0x02, 0x03},
|
||||
},
|
||||
ID: wtdb.SessionID([33]byte{0x03}),
|
||||
}
|
||||
|
||||
// Try to ack an update before inserting the client session, which
|
||||
// should fail.
|
||||
h.ackUpdate(&session.ID, 1, 0, wtdb.ErrClientSessionNotFound)
|
||||
|
||||
// Reserve a session key and insert the client session.
|
||||
session.KeyIndex = h.nextKeyIndex(session.TowerID, nil)
|
||||
h.insertSession(session, nil)
|
||||
|
||||
// Now, try to ack update 1. This should fail since update 1 was never
|
||||
// committed.
|
||||
h.ackUpdate(&session.ID, 1, 0, wtdb.ErrCommittedUpdateNotFound)
|
||||
|
||||
// Commit to a random update at seqnum 1.
|
||||
update1 := randCommittedUpdate(h.t, 1)
|
||||
lastApplied := h.commitUpdate(&session.ID, update1, nil)
|
||||
if lastApplied != 0 {
|
||||
h.t.Fatalf("last applied mismatch, want: 0, got: %v",
|
||||
lastApplied)
|
||||
}
|
||||
|
||||
// Acking seqnum 1 should succeed.
|
||||
h.ackUpdate(&session.ID, 1, 1, nil)
|
||||
|
||||
// Acking seqnum 1 again should fail.
|
||||
h.ackUpdate(&session.ID, 1, 1, wtdb.ErrCommittedUpdateNotFound)
|
||||
|
||||
// Acking a valid seqnum with a reverted last applied value should fail.
|
||||
h.ackUpdate(&session.ID, 1, 0, wtdb.ErrLastAppliedReversion)
|
||||
|
||||
// Acking with a last applied greater than any allocated seqnum should
|
||||
// fail.
|
||||
h.ackUpdate(&session.ID, 4, 3, wtdb.ErrUnallocatedLastApplied)
|
||||
|
||||
// Assert that the ClientSession loaded from disk has one update in it's
|
||||
// AckedUpdates map, and that the committed update has been removed.
|
||||
dbSession := h.listSessions()[session.ID]
|
||||
checkCommittedUpdates(h.t, dbSession, nil)
|
||||
checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{
|
||||
1: update1.BackupID,
|
||||
})
|
||||
|
||||
// Commit to another random update, and assert that the last applied
|
||||
// value is 1, since this was what was provided in the last successful
|
||||
// ack.
|
||||
update2 := randCommittedUpdate(h.t, 2)
|
||||
lastApplied = h.commitUpdate(&session.ID, update2, nil)
|
||||
if lastApplied != 1 {
|
||||
h.t.Fatalf("last applied mismatch, want: 1, got: %v",
|
||||
lastApplied)
|
||||
}
|
||||
|
||||
// Ack seqnum 2.
|
||||
h.ackUpdate(&session.ID, 2, 2, nil)
|
||||
|
||||
// Assert that both updates exist as AckedUpdates when loaded from disk.
|
||||
dbSession = h.listSessions()[session.ID]
|
||||
checkCommittedUpdates(h.t, dbSession, nil)
|
||||
checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{
|
||||
1: update1.BackupID,
|
||||
2: update2.BackupID,
|
||||
})
|
||||
|
||||
// Acking again with a lower last applied should fail.
|
||||
h.ackUpdate(&session.ID, 2, 1, wtdb.ErrLastAppliedReversion)
|
||||
|
||||
// Acking an unallocated seqnum should fail.
|
||||
h.ackUpdate(&session.ID, 4, 2, wtdb.ErrCommittedUpdateNotFound)
|
||||
|
||||
// Acking with a last applied greater than any allocated seqnum should
|
||||
// fail.
|
||||
h.ackUpdate(&session.ID, 4, 3, wtdb.ErrUnallocatedLastApplied)
|
||||
}
|
||||
|
||||
// checkCommittedUpdates asserts that the CommittedUpdates on session match the
|
||||
// expUpdates provided.
|
||||
func checkCommittedUpdates(t *testing.T, session *wtdb.ClientSession,
|
||||
expUpdates []wtdb.CommittedUpdate) {
|
||||
|
||||
t.Helper()
|
||||
|
||||
// We promote nil expUpdates to an initialized slice since the database
|
||||
// should never return a nil slice. This promotion is done purely out of
|
||||
// convenience for the testing framework.
|
||||
if expUpdates == nil {
|
||||
expUpdates = make([]wtdb.CommittedUpdate, 0)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(session.CommittedUpdates, expUpdates) {
|
||||
t.Fatalf("committed updates mismatch, want: %v, got: %v",
|
||||
expUpdates, session.CommittedUpdates)
|
||||
}
|
||||
}
|
||||
|
||||
// checkAckedUpdates asserts that the AckedUpdates on a sessio match the
|
||||
// expUpdates provided.
|
||||
func checkAckedUpdates(t *testing.T, session *wtdb.ClientSession,
|
||||
expUpdates map[uint16]wtdb.BackupID) {
|
||||
|
||||
// We promote nil expUpdates to an initialized map since the database
|
||||
// should never return a nil map. This promotion is done purely out of
|
||||
// convenience for the testing framework.
|
||||
if expUpdates == nil {
|
||||
expUpdates = make(map[uint16]wtdb.BackupID)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(session.AckedUpdates, expUpdates) {
|
||||
t.Fatalf("acked updates mismatch, want: %v, got: %v",
|
||||
expUpdates, session.AckedUpdates)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClientDB asserts the behavior of a fresh client db, a reopened client db,
|
||||
// and the mock implementation. This ensures that all databases function
|
||||
// identically, especially in the negative paths.
|
||||
func TestClientDB(t *testing.T) {
|
||||
dbs := []struct {
|
||||
name string
|
||||
init clientDBInit
|
||||
}{
|
||||
{
|
||||
name: "fresh clientdb",
|
||||
init: func(t *testing.T) (wtclient.DB, func()) {
|
||||
path, err := ioutil.TempDir("", "clientdb")
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make temp dir: %v",
|
||||
err)
|
||||
}
|
||||
|
||||
db, err := wtdb.OpenClientDB(path)
|
||||
if err != nil {
|
||||
os.RemoveAll(path)
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
db.Close()
|
||||
os.RemoveAll(path)
|
||||
}
|
||||
|
||||
return db, cleanup
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "reopened clientdb",
|
||||
init: func(t *testing.T) (wtclient.DB, func()) {
|
||||
path, err := ioutil.TempDir("", "clientdb")
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make temp dir: %v",
|
||||
err)
|
||||
}
|
||||
|
||||
db, err := wtdb.OpenClientDB(path)
|
||||
if err != nil {
|
||||
os.RemoveAll(path)
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
db.Close()
|
||||
|
||||
db, err = wtdb.OpenClientDB(path)
|
||||
if err != nil {
|
||||
os.RemoveAll(path)
|
||||
t.Fatalf("unable to reopen db: %v", err)
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
db.Close()
|
||||
os.RemoveAll(path)
|
||||
}
|
||||
|
||||
return db, cleanup
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "mock",
|
||||
init: func(t *testing.T) (wtclient.DB, func()) {
|
||||
return wtmock.NewClientDB(), func() {}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
run func(*clientDBHarness)
|
||||
}{
|
||||
{
|
||||
name: "create client session",
|
||||
run: testCreateClientSession,
|
||||
},
|
||||
{
|
||||
name: "create tower",
|
||||
run: testCreateTower,
|
||||
},
|
||||
{
|
||||
name: "chan summaries",
|
||||
run: testChanSummaries,
|
||||
},
|
||||
{
|
||||
name: "commit update",
|
||||
run: testCommitUpdate,
|
||||
},
|
||||
{
|
||||
name: "ack update",
|
||||
run: testAckUpdate,
|
||||
},
|
||||
}
|
||||
|
||||
for _, database := range dbs {
|
||||
db := database
|
||||
t.Run(db.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
h, cleanup := newClientDBHarness(
|
||||
t, db.init,
|
||||
)
|
||||
defer cleanup()
|
||||
|
||||
test.run(h)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// randCommittedUpdate generates a random committed update.
|
||||
func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate {
|
||||
var chanID lnwire.ChannelID
|
||||
if _, err := io.ReadFull(crand.Reader, chanID[:]); err != nil {
|
||||
t.Fatalf("unable to generate chan id: %v", err)
|
||||
}
|
||||
|
||||
var hint wtdb.BreachHint
|
||||
if _, err := io.ReadFull(crand.Reader, hint[:]); err != nil {
|
||||
t.Fatalf("unable to generate breach hint: %v", err)
|
||||
}
|
||||
|
||||
encBlob := make([]byte, blob.Size(blob.FlagCommitOutputs.Type()))
|
||||
if _, err := io.ReadFull(crand.Reader, encBlob); err != nil {
|
||||
t.Fatalf("unable to generate encrypted blob: %v", err)
|
||||
}
|
||||
|
||||
return &wtdb.CommittedUpdate{
|
||||
SeqNum: seqNum,
|
||||
CommittedUpdateBody: wtdb.CommittedUpdateBody{
|
||||
BackupID: wtdb.BackupID{
|
||||
ChanID: chanID,
|
||||
CommitHeight: 666,
|
||||
},
|
||||
Hint: hint,
|
||||
EncryptedBlob: encBlob,
|
||||
},
|
||||
}
|
||||
}
|
@ -1,7 +1,6 @@
|
||||
package wtdb
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
@ -9,44 +8,6 @@ import (
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrClientSessionNotFound signals that the requested client session
|
||||
// was not found in the database.
|
||||
ErrClientSessionNotFound = errors.New("client session not found")
|
||||
|
||||
// ErrUpdateAlreadyCommitted signals that the chosen sequence number has
|
||||
// already been committed to an update with a different breach hint.
|
||||
ErrUpdateAlreadyCommitted = errors.New("update already committed")
|
||||
|
||||
// ErrCommitUnorderedUpdate signals the client tried to commit a
|
||||
// sequence number other than the next unallocated sequence number.
|
||||
ErrCommitUnorderedUpdate = errors.New("update seqnum not monotonic")
|
||||
|
||||
// ErrCommittedUpdateNotFound signals that the tower tried to ACK a
|
||||
// sequence number that has not yet been allocated by the client.
|
||||
ErrCommittedUpdateNotFound = errors.New("committed update not found")
|
||||
|
||||
// ErrUnallocatedLastApplied signals that the tower tried to provide a
|
||||
// LastApplied value greater than any allocated sequence number.
|
||||
ErrUnallocatedLastApplied = errors.New("tower echoed last appiled " +
|
||||
"greater than allocated seqnum")
|
||||
|
||||
// ErrNoReservedKeyIndex signals that a client session could not be
|
||||
// created because no session key index was reserved.
|
||||
ErrNoReservedKeyIndex = errors.New("key index not reserved")
|
||||
|
||||
// ErrIncorrectKeyIndex signals that the client session could not be
|
||||
// created because session key index differs from the reserved key
|
||||
// index.
|
||||
ErrIncorrectKeyIndex = errors.New("incorrect key index")
|
||||
|
||||
// ErrClientSessionAlreadyExists signals an attempt to reinsert
|
||||
// a client session that has already been created.
|
||||
ErrClientSessionAlreadyExists = errors.New(
|
||||
"client session already exists",
|
||||
)
|
||||
)
|
||||
|
||||
// ClientSession encapsulates a SessionInfo returned from a successful
|
||||
// session negotiation, and also records the tower and ephemeral secret used for
|
||||
// communicating with the tower.
|
||||
|
@ -1,7 +1,6 @@
|
||||
package wtdb
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
@ -9,12 +8,6 @@ import (
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrTowerNotFound signals that the target tower was not found in the
|
||||
// database.
|
||||
ErrTowerNotFound = errors.New("tower not found")
|
||||
)
|
||||
|
||||
// TowerID is a unique 64-bit identifier allocated to each unique watchtower.
|
||||
// This allows the client to conserve on-disk space by not needing to always
|
||||
// reference towers by their pubkey.
|
||||
|
@ -21,6 +21,11 @@ type version struct {
|
||||
// migrations must be applied.
|
||||
var towerDBVersions = []version{}
|
||||
|
||||
// clientDBVersions stores all versions and migrations of the client database.
|
||||
// This list will be used when opening the database to determine if any
|
||||
// migrations must be applied.
|
||||
var clientDBVersions = []version{}
|
||||
|
||||
// getLatestDBVersion returns the last known database version.
|
||||
func getLatestDBVersion(versions []version) uint32 {
|
||||
return uint32(len(versions))
|
||||
|
Loading…
Reference in New Issue
Block a user