Merge pull request #2783 from cfromknecht/wtserver-db

watchtower/wtdb: add bbolt-backed tower database
This commit is contained in:
Olaoluwa Osuntokun 2019-04-26 18:08:32 -07:00 committed by GitHub
commit 0393793733
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 2043 additions and 150 deletions

@ -51,6 +51,12 @@ type UnknownElementType struct {
element interface{}
}
// NewUnknownElementType creates a new UnknownElementType error from the passed
// method name and element.
func NewUnknownElementType(method string, el interface{}) UnknownElementType {
return UnknownElementType{method: method, element: el}
}
// Error returns the name of the method that encountered the error, as well as
// the type that was unsupported.
func (e UnknownElementType) Error() string {

@ -15,6 +15,7 @@ import (
"github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/lookout"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtmock"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
)
@ -66,7 +67,7 @@ func makeAddrSlice(size int) []byte {
}
func TestLookoutBreachMatching(t *testing.T) {
db := wtdb.NewMockDB()
db := wtmock.NewTowerDB()
// Initialize an mock backend to feed the lookout blocks.
backend := lookout.NewMockBackend()

@ -369,7 +369,7 @@ type testHarness struct {
clientDB *wtmock.ClientDB
clientCfg *wtclient.Config
client wtclient.Client
serverDB *wtdb.MockDB
serverDB *wtmock.TowerDB
serverCfg *wtserver.Config
server *wtserver.Server
net *mockNet
@ -406,7 +406,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
}
const timeout = 200 * time.Millisecond
serverDB := wtdb.NewMockDB()
serverDB := wtmock.NewTowerDB()
serverCfg := &wtserver.Config{
DB: serverDB,

143
watchtower/wtdb/codec.go Normal file

@ -0,0 +1,143 @@
package wtdb
import (
"io"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
)
// UnknownElementType is an alias for channeldb.UnknownElementType.
type UnknownElementType = channeldb.UnknownElementType
// ReadElement deserializes a single element from the provided io.Reader.
func ReadElement(r io.Reader, element interface{}) error {
err := channeldb.ReadElement(r, element)
switch {
// Known to channeldb codec.
case err == nil:
return nil
// Fail if error is not UnknownElementType.
case err != nil:
if _, ok := err.(UnknownElementType); !ok {
return err
}
}
// Process any wtdb-specific extensions to the codec.
switch e := element.(type) {
case *SessionID:
if _, err := io.ReadFull(r, e[:]); err != nil {
return err
}
case *BreachHint:
if _, err := io.ReadFull(r, e[:]); err != nil {
return err
}
case *wtpolicy.Policy:
var (
blobType uint16
sweepFeeRate uint64
)
err := channeldb.ReadElements(r,
&blobType,
&e.MaxUpdates,
&e.RewardBase,
&e.RewardRate,
&sweepFeeRate,
)
if err != nil {
return err
}
e.BlobType = blob.Type(blobType)
e.SweepFeeRate = lnwallet.SatPerKWeight(sweepFeeRate)
// Type is still unknown to wtdb extensions, fail.
default:
return channeldb.NewUnknownElementType(
"ReadElement", element,
)
}
return nil
}
// WriteElement serializes a single element into the provided io.Writer.
func WriteElement(w io.Writer, element interface{}) error {
err := channeldb.WriteElement(w, element)
switch {
// Known to channeldb codec.
case err == nil:
return nil
// Fail if error is not UnknownElementType.
case err != nil:
if _, ok := err.(UnknownElementType); !ok {
return err
}
}
// Process any wtdb-specific extensions to the codec.
switch e := element.(type) {
case SessionID:
if _, err := w.Write(e[:]); err != nil {
return err
}
case BreachHint:
if _, err := w.Write(e[:]); err != nil {
return err
}
case wtpolicy.Policy:
return channeldb.WriteElements(w,
uint16(e.BlobType),
e.MaxUpdates,
e.RewardBase,
e.RewardRate,
uint64(e.SweepFeeRate),
)
// Type is still unknown to wtdb extensions, fail.
default:
return channeldb.NewUnknownElementType(
"WriteElement", element,
)
}
return nil
}
// WriteElements serializes a variadic list of elements into the given
// io.Writer.
func WriteElements(w io.Writer, elements ...interface{}) error {
for _, element := range elements {
if err := WriteElement(w, element); err != nil {
return err
}
}
return nil
}
// ReadElements deserializes the provided io.Reader into a variadic list of
// target elements.
func ReadElements(r io.Reader, elements ...interface{}) error {
for _, element := range elements {
if err := ReadElement(r, element); err != nil {
return err
}
}
return nil
}

@ -0,0 +1,86 @@
package wtdb_test
import (
"bytes"
"io"
"reflect"
"testing"
"testing/quick"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
)
// dbObject is abstract object support encoding and decoding.
type dbObject interface {
Encode(io.Writer) error
Decode(io.Reader) error
}
// TestCodec serializes and deserializes wtdb objects in order to test that that
// the codec understands all of the required field types. The test also asserts
// that decoding an object into another results in an equivalent object.
func TestCodec(t *testing.T) {
mainScenario := func(obj dbObject) bool {
// Ensure encoding the object succeeds.
var b bytes.Buffer
err := obj.Encode(&b)
if err != nil {
t.Fatalf("unable to encode: %v", err)
return false
}
var obj2 dbObject
switch obj.(type) {
case *wtdb.SessionInfo:
obj2 = &wtdb.SessionInfo{}
case *wtdb.SessionStateUpdate:
obj2 = &wtdb.SessionStateUpdate{}
default:
t.Fatalf("unknown type: %T", obj)
return false
}
// Ensure decoding the object succeeds.
err = obj2.Decode(bytes.NewReader(b.Bytes()))
if err != nil {
t.Fatalf("unable to decode: %v", err)
return false
}
// Assert the original and decoded object match.
if !reflect.DeepEqual(obj, obj2) {
t.Fatalf("encode/decode mismatch, want: %v, "+
"got: %v", obj, obj2)
return false
}
return true
}
tests := []struct {
name string
scenario interface{}
}{
{
name: "SessionInfo",
scenario: func(obj wtdb.SessionInfo) bool {
return mainScenario(&obj)
},
},
{
name: "SessionStateUpdate",
scenario: func(obj wtdb.SessionStateUpdate) bool {
return mainScenario(&obj)
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if err := quick.Check(test.scenario, nil); err != nil {
t.Fatalf("fuzz checks for msg=%s failed: %v",
test.name, err)
}
})
}
}

45
watchtower/wtdb/log.go Normal file

@ -0,0 +1,45 @@
package wtdb
import (
"github.com/btcsuite/btclog"
"github.com/lightningnetwork/lnd/build"
)
// log is a logger that is initialized with no output filters. This
// means the package will not perform any logging by default until the caller
// requests it.
var log btclog.Logger
// The default amount of logging is none.
func init() {
UseLogger(build.NewSubLogger("WTDB", nil))
}
// DisableLog disables all library log output. Logging output is disabled
// by default until UseLogger is called.
func DisableLog() {
UseLogger(btclog.Disabled)
}
// UseLogger uses a specified Logger to output package logging info.
// This should be used in preference to SetLogWriter if the caller is also
// using btclog.
func UseLogger(logger btclog.Logger) {
log = logger
}
// logClosure is used to provide a closure over expensive logging operations so
// don't have to be performed when the logging level doesn't warrant it.
type logClosure func() string
// String invokes the underlying function and returns the result.
func (c logClosure) String() string {
return c()
}
// newLogClosure returns a new closure over a function that returns a string
// which itself provides a Stringer interface so that it can be used with the
// logging system.
func newLogClosure(c func() string) logClosure {
return logClosure(c)
}

@ -1,142 +0,0 @@
// +build dev
package wtdb
import (
"sync"
"github.com/lightningnetwork/lnd/chainntnfs"
)
type MockDB struct {
mu sync.Mutex
lastEpoch *chainntnfs.BlockEpoch
sessions map[SessionID]*SessionInfo
blobs map[BreachHint]map[SessionID]*SessionStateUpdate
}
func NewMockDB() *MockDB {
return &MockDB{
sessions: make(map[SessionID]*SessionInfo),
blobs: make(map[BreachHint]map[SessionID]*SessionStateUpdate),
}
}
func (db *MockDB) InsertStateUpdate(update *SessionStateUpdate) (uint16, error) {
db.mu.Lock()
defer db.mu.Unlock()
info, ok := db.sessions[update.ID]
if !ok {
return 0, ErrSessionNotFound
}
err := info.AcceptUpdateSequence(update.SeqNum, update.LastApplied)
if err != nil {
return info.LastApplied, err
}
sessionsToUpdates, ok := db.blobs[update.Hint]
if !ok {
sessionsToUpdates = make(map[SessionID]*SessionStateUpdate)
db.blobs[update.Hint] = sessionsToUpdates
}
sessionsToUpdates[update.ID] = update
return info.LastApplied, nil
}
func (db *MockDB) GetSessionInfo(id *SessionID) (*SessionInfo, error) {
db.mu.Lock()
defer db.mu.Unlock()
if info, ok := db.sessions[*id]; ok {
return info, nil
}
return nil, ErrSessionNotFound
}
func (db *MockDB) InsertSessionInfo(info *SessionInfo) error {
db.mu.Lock()
defer db.mu.Unlock()
dbInfo, ok := db.sessions[info.ID]
if ok && dbInfo.LastApplied > 0 {
return ErrSessionAlreadyExists
}
db.sessions[info.ID] = info
return nil
}
func (db *MockDB) DeleteSession(target SessionID) error {
db.mu.Lock()
defer db.mu.Unlock()
// Fail if the session doesn't exit.
if _, ok := db.sessions[target]; !ok {
return ErrSessionNotFound
}
// Remove the target session.
delete(db.sessions, target)
// Remove the state updates for any blobs stored under the target
// session identifier.
for hint, sessionUpdates := range db.blobs {
delete(sessionUpdates, target)
//If this was the last state update, we can also remove the hint
//that would map to an empty set.
if len(sessionUpdates) == 0 {
delete(db.blobs, hint)
}
}
return nil
}
func (db *MockDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) {
db.mu.Lock()
defer db.mu.Unlock()
return db.lastEpoch, nil
}
func (db *MockDB) QueryMatches(breachHints []BreachHint) ([]Match, error) {
db.mu.Lock()
defer db.mu.Unlock()
var matches []Match
for _, hint := range breachHints {
sessionsToUpdates, ok := db.blobs[hint]
if !ok {
continue
}
for id, update := range sessionsToUpdates {
info, ok := db.sessions[id]
if !ok {
panic("session not found")
}
match := Match{
ID: id,
SeqNum: update.SeqNum,
Hint: hint,
EncryptedBlob: update.EncryptedBlob,
SessionInfo: info,
}
matches = append(matches, match)
}
}
return matches, nil
}
func (db *MockDB) SetLookoutTip(epoch *chainntnfs.BlockEpoch) error {
db.lastEpoch = epoch
return nil
}

@ -2,6 +2,7 @@ package wtdb
import (
"errors"
"io"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
)
@ -59,6 +60,28 @@ type SessionInfo struct {
// TODO(conner): store client metrics, DOS score, etc
}
// Encode serializes the session info to the given io.Writer.
func (s *SessionInfo) Encode(w io.Writer) error {
return WriteElements(w,
s.ID,
s.Policy,
s.LastApplied,
s.ClientLastApplied,
s.RewardAddress,
)
}
// Decode deserializes the session infor from the given io.Reader.
func (s *SessionInfo) Decode(r io.Reader) error {
return ReadElements(r,
&s.ID,
&s.Policy,
&s.LastApplied,
&s.ClientLastApplied,
&s.RewardAddress,
)
}
// AcceptUpdateSequence validates that a state update's sequence number and last
// applied are valid given our past history with the client. These checks ensure
// that clients are properly in sync and following the update protocol properly.

@ -1,5 +1,7 @@
package wtdb
import "io"
// SessionStateUpdate holds a state update sent by a client along with its
// SessionID.
type SessionStateUpdate struct {
@ -21,3 +23,25 @@ type SessionStateUpdate struct {
// hint is braodcast.
EncryptedBlob []byte
}
// Encode serializes the state update into the provided io.Writer.
func (u *SessionStateUpdate) Encode(w io.Writer) error {
return WriteElements(w,
u.ID,
u.SeqNum,
u.LastApplied,
u.Hint,
u.EncryptedBlob,
)
}
// Decode deserializes the target state update from the provided io.Reader.
func (u *SessionStateUpdate) Decode(r io.Reader) error {
return ReadElements(r,
&u.ID,
&u.SeqNum,
&u.LastApplied,
&u.Hint,
&u.EncryptedBlob,
)
}

733
watchtower/wtdb/tower_db.go Normal file

@ -0,0 +1,733 @@
package wtdb
import (
"bytes"
"encoding/binary"
"errors"
"os"
"path/filepath"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/coreos/bbolt"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
)
const (
// dbName is the filename of tower database.
dbName = "watchtower.db"
// dbFilePermission requests read+write access to the db file.
dbFilePermission = 0600
)
var (
// sessionsBkt is a bucket containing all negotiated client sessions.
// session id -> session
sessionsBkt = []byte("sessions-bucket")
// updatesBkt is a bucket containing all state updates sent by clients.
// The updates are further bucketed by session id to prevent clients
// from overwrite each other.
// hint => session id -> update
updatesBkt = []byte("updates-bucket")
// updateIndexBkt is a bucket that indexes all state updates by their
// overarching session id. This allows for efficient lookup of updates
// by their session id, which is currently used to aide deletion
// performance.
// session id => hint1 -> []byte{}
// => hint2 -> []byte{}
updateIndexBkt = []byte("update-index-bucket")
// lookoutTipBkt is a bucket containing the last block epoch processed
// by the lookout subsystem. It has one key, lookoutTipKey.
// lookoutTipKey -> block epoch
lookoutTipBkt = []byte("lookout-tip-bucket")
// lookoutTipKey is a static key used to retrieve lookout tip's block
// epoch from the lookoutTipBkt.
lookoutTipKey = []byte("lookout-tip")
// metadataBkt stores all the meta information concerning the state of
// the database.
metadataBkt = []byte("metadata-bucket")
// dbVersionKey is a static key used to retrieve the database version
// number from the metadataBkt.
dbVersionKey = []byte("version")
// ErrUninitializedDB signals that top-level buckets for the database
// have not been initialized.
ErrUninitializedDB = errors.New("tower db not initialized")
// ErrNoDBVersion signals that the database contains no version info.
ErrNoDBVersion = errors.New("tower db has no version")
// ErrNoSessionHintIndex signals that an active session does not have an
// initialized index for tracking its own state updates.
ErrNoSessionHintIndex = errors.New("session hint index missing")
byteOrder = binary.BigEndian
)
// TowerDB is single database providing a persistent storage engine for the
// wtserver and lookout subsystems.
type TowerDB struct {
db *bbolt.DB
dbPath string
}
// OpenTowerDB opens the tower database given the path to the database's
// directory. If no such database exists, this method will initialize a fresh
// one using the latest version number and bucket structure. If a database
// exists but has a lower version number than the current version, any necessary
// migrations will be applied before returning. Any attempt to open a database
// with a version number higher that the latest version will fail to prevent
// accidental reversion.
func OpenTowerDB(dbPath string) (*TowerDB, error) {
path := filepath.Join(dbPath, dbName)
// If the database file doesn't exist, this indicates we much initialize
// a fresh database with the latest version.
firstInit := !fileExists(path)
if firstInit {
// Ensure all parent directories are initialized.
err := os.MkdirAll(dbPath, 0700)
if err != nil {
return nil, err
}
}
bdb, err := bbolt.Open(path, dbFilePermission, nil)
if err != nil {
return nil, err
}
// If the file existed previously, we'll now check to see that the
// metadata bucket is properly initialized. It could be the case that
// the database was created, but we failed to actually populate any
// metadata. If the metadata bucket does not actually exist, we'll
// set firstInit to true so that we can treat is initialize the bucket.
if !firstInit {
var metadataExists bool
err = bdb.View(func(tx *bbolt.Tx) error {
metadataExists = tx.Bucket(metadataBkt) != nil
return nil
})
if err != nil {
return nil, err
}
if !metadataExists {
firstInit = true
}
}
towerDB := &TowerDB{
db: bdb,
dbPath: dbPath,
}
if firstInit {
// If the database has not yet been created, we'll initialize
// the database version with the latest known version.
err = towerDB.db.Update(func(tx *bbolt.Tx) error {
return initDBVersion(tx, getLatestDBVersion(dbVersions))
})
if err != nil {
bdb.Close()
return nil, err
}
} else {
// Otherwise, ensure that any migrations are applied to ensure
// the data is in the format expected by the latest version.
err = towerDB.syncVersions(dbVersions)
if err != nil {
bdb.Close()
return nil, err
}
}
// Now that the database version fully consistent with our latest known
// version, ensure that all top-level buckets known to this version are
// initialized. This allows us to assume their presence throughout all
// operations. If an known top-level bucket is expected to exist but is
// missing, this will trigger a ErrUninitializedDB error.
err = towerDB.db.Update(initTowerDBBuckets)
if err != nil {
bdb.Close()
return nil, err
}
return towerDB, nil
}
// fileExists returns true if the file exists, and false otherwise.
func fileExists(path string) bool {
if _, err := os.Stat(path); err != nil {
if os.IsNotExist(err) {
return false
}
}
return true
}
// initTowerDBBuckets creates all top-level buckets required to handle database
// operations required by the latest version.
func initTowerDBBuckets(tx *bbolt.Tx) error {
buckets := [][]byte{
sessionsBkt,
updateIndexBkt,
updatesBkt,
lookoutTipBkt,
}
for _, bucket := range buckets {
_, err := tx.CreateBucketIfNotExists(bucket)
if err != nil {
return err
}
}
return nil
}
// syncVersions ensures the database version is consistent with the highest
// known database version, applying any migrations that have not been made. If
// the highest known version number is lower than the database's version, this
// method will fail to prevent accidental reversions.
func (t *TowerDB) syncVersions(versions []version) error {
curVersion, err := t.Version()
if err != nil {
return err
}
latestVersion := getLatestDBVersion(versions)
switch {
// Current version is higher than any known version, fail to prevent
// reversion.
case curVersion > latestVersion:
return channeldb.ErrDBReversion
// Current version matches highest known version, nothing to do.
case curVersion == latestVersion:
return nil
}
// Otherwise, apply any migrations in order to bring the database
// version up to the highest known version.
updates := getMigrations(versions, curVersion)
return t.db.Update(func(tx *bbolt.Tx) error {
for _, update := range updates {
if update.migration == nil {
continue
}
log.Infof("Applying migration #%d", update.number)
err := update.migration(tx)
if err != nil {
log.Errorf("Unable to apply migration #%d: %v",
err)
return err
}
}
return putDBVersion(tx, latestVersion)
})
}
// Version returns the database's current version number.
func (t *TowerDB) Version() (uint32, error) {
var version uint32
err := t.db.View(func(tx *bbolt.Tx) error {
var err error
version, err = getDBVersion(tx)
return err
})
if err != nil {
return 0, err
}
return version, nil
}
// Close closes the underlying database.
func (t *TowerDB) Close() error {
return t.db.Close()
}
// GetSessionInfo retrieves the session for the passed session id. An error is
// returned if the session could not be found.
func (t *TowerDB) GetSessionInfo(id *SessionID) (*SessionInfo, error) {
var session *SessionInfo
err := t.db.View(func(tx *bbolt.Tx) error {
sessions := tx.Bucket(sessionsBkt)
if sessions == nil {
return ErrUninitializedDB
}
var err error
session, err = getSession(sessions, id[:])
return err
})
if err != nil {
return nil, err
}
return session, nil
}
// InsertSessionInfo records a negotiated session in the tower database. An
// error is returned if the session already exists.
func (t *TowerDB) InsertSessionInfo(session *SessionInfo) error {
return t.db.Update(func(tx *bbolt.Tx) error {
sessions := tx.Bucket(sessionsBkt)
if sessions == nil {
return ErrUninitializedDB
}
updateIndex := tx.Bucket(updateIndexBkt)
if updateIndex == nil {
return ErrUninitializedDB
}
dbSession, err := getSession(sessions, session.ID[:])
switch {
case err == ErrSessionNotFound:
// proceed.
case err != nil:
return err
case dbSession.LastApplied > 0:
return ErrSessionAlreadyExists
}
err = putSession(sessions, session)
if err != nil {
return err
}
// Initialize the session-hint index which will be used to track
// all updates added for this session. Upon deletion, we will
// consult the index to determine exactly which updates should
// be deleted without needing to iterate over the entire
// database.
return touchSessionHintBkt(updateIndex, &session.ID)
})
}
// InsertStateUpdate stores an update sent by the client after validating that
// the update is well-formed in the context of other updates sent for the same
// session. This include verifying that the sequence number is incremented
// properly and the last applied values echoed by the client are sane.
func (t *TowerDB) InsertStateUpdate(update *SessionStateUpdate) (uint16, error) {
var lastApplied uint16
err := t.db.Update(func(tx *bbolt.Tx) error {
sessions := tx.Bucket(sessionsBkt)
if sessions == nil {
return ErrUninitializedDB
}
updates := tx.Bucket(updatesBkt)
if updates == nil {
return ErrUninitializedDB
}
updateIndex := tx.Bucket(updateIndexBkt)
if updateIndex == nil {
return ErrUninitializedDB
}
// Fetch the session corresponding to the update's session id.
// This will be used to validate that the update's sequence
// number and last applied values are sane.
session, err := getSession(sessions, update.ID[:])
if err != nil {
return err
}
// Validate the update against the current state of the session.
err = session.AcceptUpdateSequence(
update.SeqNum, update.LastApplied,
)
if err != nil {
return err
}
// Validation succeeded, therefore the update is committed and
// the session's last applied value is equal to the update's
// sequence number.
lastApplied = session.LastApplied
// Store the updated session to persist the updated last applied
// values.
err = putSession(sessions, session)
if err != nil {
return err
}
// Create or load the hint bucket for this state update's hint
// and write the given update.
hints, err := updates.CreateBucketIfNotExists(update.Hint[:])
if err != nil {
return err
}
var b bytes.Buffer
err = update.Encode(&b)
if err != nil {
return err
}
err = hints.Put(update.ID[:], b.Bytes())
if err != nil {
return err
}
// Finally, create an entry in the update index to track this
// hint under its session id. This will allow us to delete the
// entries efficiently if the session is ever removed.
return putHintForSession(updateIndex, &update.ID, update.Hint)
})
if err != nil {
return 0, err
}
return lastApplied, nil
}
// DeleteSession removes all data associated with a particular session id from
// the tower's database.
func (t *TowerDB) DeleteSession(target SessionID) error {
return t.db.Update(func(tx *bbolt.Tx) error {
sessions := tx.Bucket(sessionsBkt)
if sessions == nil {
return ErrUninitializedDB
}
updates := tx.Bucket(updatesBkt)
if updates == nil {
return ErrUninitializedDB
}
updateIndex := tx.Bucket(updateIndexBkt)
if updateIndex == nil {
return ErrUninitializedDB
}
// Fail if the session doesn't exit.
_, err := getSession(sessions, target[:])
if err != nil {
return err
}
// Remove the target session.
err = sessions.Delete(target[:])
if err != nil {
return err
}
// Next, check the update index for any hints that were added
// under this session.
hints, err := getHintsForSession(updateIndex, &target)
if err != nil {
return err
}
for _, hint := range hints {
// Remove the state updates for any blobs stored under
// the target session identifier.
updatesForHint := updates.Bucket(hint[:])
if updatesForHint == nil {
continue
}
update := updatesForHint.Get(target[:])
if update == nil {
continue
}
err := updatesForHint.Delete(target[:])
if err != nil {
return err
}
// If this was the last state update, we can also remove
// the hint that would map to an empty set.
err = isBucketEmpty(updatesForHint)
switch {
// Other updates exist for this hint, keep the bucket.
case err == errBucketNotEmpty:
continue
// Unexpected error.
case err != nil:
return err
// No more updates for this hint, prune hint bucket.
default:
err = updates.DeleteBucket(hint[:])
if err != nil {
return err
}
}
}
// Finally, remove this session from the update index, which
// also removes any of the indexed hints beneath it.
return removeSessionHintBkt(updateIndex, &target)
})
}
// QueryMatches searches against all known state updates for any that match the
// passed breachHints. More than one Match will be returned for a given hint if
// they exist in the database.
func (t *TowerDB) QueryMatches(breachHints []BreachHint) ([]Match, error) {
var matches []Match
err := t.db.View(func(tx *bbolt.Tx) error {
sessions := tx.Bucket(sessionsBkt)
if sessions == nil {
return ErrUninitializedDB
}
updates := tx.Bucket(updatesBkt)
if updates == nil {
return ErrUninitializedDB
}
// Iterate through the target breach hints, appending any
// matching updates to the set of matches.
for _, hint := range breachHints {
// If a bucket does not exist for this hint, no matches
// are known.
updatesForHint := updates.Bucket(hint[:])
if updatesForHint == nil {
continue
}
// Otherwise, iterate through all (session id, update)
// pairs, creating a Match for each.
err := updatesForHint.ForEach(func(k, v []byte) error {
// Load the session via the session id for this
// update. The session info contains further
// instructions for how to process the state
// update.
session, err := getSession(sessions, k)
switch {
case err == ErrSessionNotFound:
log.Warnf("Missing session=%x for "+
"matched state update hint=%x",
k, hint)
return nil
case err != nil:
return err
}
// Decode the state update containing the
// encrypted blob.
update := &SessionStateUpdate{}
err = update.Decode(bytes.NewReader(v))
if err != nil {
return err
}
var id SessionID
copy(id[:], k)
// Construct the final match using the found
// update and its session info.
match := Match{
ID: id,
SeqNum: update.SeqNum,
Hint: hint,
EncryptedBlob: update.EncryptedBlob,
SessionInfo: session,
}
matches = append(matches, match)
return nil
})
if err != nil {
return err
}
}
return nil
})
if err != nil {
return nil, err
}
return matches, nil
}
// SetLookoutTip stores the provided epoch as the latest lookout tip epoch in
// the tower database.
func (t *TowerDB) SetLookoutTip(epoch *chainntnfs.BlockEpoch) error {
return t.db.Update(func(tx *bbolt.Tx) error {
lookoutTip := tx.Bucket(lookoutTipBkt)
if lookoutTip == nil {
return ErrUninitializedDB
}
return putLookoutEpoch(lookoutTip, epoch)
})
}
// GetLookoutTip retrieves the current lookout tip block epoch from the tower
// database.
func (t *TowerDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) {
var epoch *chainntnfs.BlockEpoch
err := t.db.View(func(tx *bbolt.Tx) error {
lookoutTip := tx.Bucket(lookoutTipBkt)
if lookoutTip == nil {
return ErrUninitializedDB
}
epoch = getLookoutEpoch(lookoutTip)
return nil
})
if err != nil {
return nil, err
}
return epoch, nil
}
// getSession retrieves the session info from the sessions bucket identified by
// its session id. An error is returned if the session is not found or a
// deserialization error occurs.
func getSession(sessions *bbolt.Bucket, id []byte) (*SessionInfo, error) {
sessionBytes := sessions.Get(id)
if sessionBytes == nil {
return nil, ErrSessionNotFound
}
var session SessionInfo
err := session.Decode(bytes.NewReader(sessionBytes))
if err != nil {
return nil, err
}
return &session, nil
}
// putSession stores the session info in the sessions bucket identified by its
// session id. An error is returned if a serialization error occurs.
func putSession(sessions *bbolt.Bucket, session *SessionInfo) error {
var b bytes.Buffer
err := session.Encode(&b)
if err != nil {
return err
}
return sessions.Put(session.ID[:], b.Bytes())
}
// touchSessionHintBkt initializes the session-hint bucket for a particular
// session id. This ensures that future calls to getHintsForSession or
// putHintForSession can rely on the bucket already being created, and fail if
// index has not been initialized as this points to improper usage.
func touchSessionHintBkt(updateIndex *bbolt.Bucket, id *SessionID) error {
_, err := updateIndex.CreateBucketIfNotExists(id[:])
return err
}
// removeSessionHintBkt prunes the session-hint bucket for the given session id
// and all of the hints contained inside. This should be used to clean up the
// index upon session deletion.
func removeSessionHintBkt(updateIndex *bbolt.Bucket, id *SessionID) error {
return updateIndex.DeleteBucket(id[:])
}
// getHintsForSession returns all known hints belonging to the given session id.
// If the index for the session has not been initialized, this method returns
// ErrNoSessionHintIndex.
func getHintsForSession(updateIndex *bbolt.Bucket,
id *SessionID) ([]BreachHint, error) {
sessionHints := updateIndex.Bucket(id[:])
if sessionHints == nil {
return nil, ErrNoSessionHintIndex
}
var hints []BreachHint
err := sessionHints.ForEach(func(k, _ []byte) error {
if len(k) != BreachHintSize {
return nil
}
var hint BreachHint
copy(hint[:], k)
hints = append(hints, hint)
return nil
})
if err != nil {
return nil, err
}
return hints, nil
}
// putHintForSession inserts a record into the update index for a given
// (session, hint) pair. The hints are coalesced under a bucket for the target
// session id, and used to perform efficient removal of updates. If the index
// for the session has not been initialized, this method returns
// ErrNoSessionHintIndex.
func putHintForSession(updateIndex *bbolt.Bucket, id *SessionID,
hint BreachHint) error {
sessionHints := updateIndex.Bucket(id[:])
if sessionHints == nil {
return ErrNoSessionHintIndex
}
return sessionHints.Put(hint[:], []byte{})
}
// putLookoutEpoch stores the given lookout tip block epoch in provided bucket.
func putLookoutEpoch(bkt *bbolt.Bucket, epoch *chainntnfs.BlockEpoch) error {
epochBytes := make([]byte, 36)
copy(epochBytes, epoch.Hash[:])
byteOrder.PutUint32(epochBytes[32:], uint32(epoch.Height))
return bkt.Put(lookoutTipKey, epochBytes)
}
// getLookoutEpoch retrieves the lookout tip block epoch from the given bucket.
// A nil epoch is returned if no update exists.
func getLookoutEpoch(bkt *bbolt.Bucket) *chainntnfs.BlockEpoch {
epochBytes := bkt.Get(lookoutTipKey)
if len(epochBytes) != 36 {
return nil
}
var hash chainhash.Hash
copy(hash[:], epochBytes[:32])
height := byteOrder.Uint32(epochBytes[32:])
return &chainntnfs.BlockEpoch{
Hash: &hash,
Height: int32(height),
}
}
// errBucketNotEmpty is a helper error returned when testing whether a bucket is
// empty or not.
var errBucketNotEmpty = errors.New("bucket not empty")
// isBucketEmpty returns errBucketNotEmpty if the bucket is not empty.
func isBucketEmpty(bkt *bbolt.Bucket) error {
return bkt.ForEach(func(_, _ []byte) error {
return errBucketNotEmpty
})
}

@ -0,0 +1,730 @@
package wtdb_test
import (
"encoding/binary"
"io/ioutil"
"os"
"reflect"
"testing"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/watchtower"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtmock"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
)
// dbInit is a closure used to initialize a watchtower.DB instance and its
// cleanup function.
type dbInit func(*testing.T) (watchtower.DB, func())
// towerDBHarness holds the resources required to execute the tower db tests.
type towerDBHarness struct {
t *testing.T
db watchtower.DB
}
// newTowerDBHarness initializes a fresh test harness for testing watchtower.DB
// implementations.
func newTowerDBHarness(t *testing.T, init dbInit) (*towerDBHarness, func()) {
db, cleanup := init(t)
h := &towerDBHarness{
t: t,
db: db,
}
return h, cleanup
}
// insertSession attempts to isnert the passed session and asserts that the
// error returned matches expErr.
func (h *towerDBHarness) insertSession(s *wtdb.SessionInfo, expErr error) {
h.t.Helper()
err := h.db.InsertSessionInfo(s)
if err != expErr {
h.t.Fatalf("expected insert session error: %v, got : %v",
expErr, err)
}
}
// getSession retrieves the session identified by id, asserting that the call
// returns expErr. If successful, the found session is returned.
func (h *towerDBHarness) getSession(id *wtdb.SessionID,
expErr error) *wtdb.SessionInfo {
h.t.Helper()
session, err := h.db.GetSessionInfo(id)
if err != expErr {
h.t.Fatalf("expected get session error: %v, got: %v",
expErr, err)
}
return session
}
// insertUpdate attempts to insert the passed state update and asserts that the
// error returned matches expErr. If successful, the session's last applied
// value is returned.
func (h *towerDBHarness) insertUpdate(s *wtdb.SessionStateUpdate,
expErr error) uint16 {
h.t.Helper()
lastApplied, err := h.db.InsertStateUpdate(s)
if err != expErr {
h.t.Fatalf("expected insert update error: %v, got: %v",
expErr, err)
}
return lastApplied
}
// deleteSession attempts to delete the session identified by id and asserts
// that the error returned from DeleteSession matches the expected error.
func (h *towerDBHarness) deleteSession(id wtdb.SessionID, expErr error) {
h.t.Helper()
err := h.db.DeleteSession(id)
if err != expErr {
h.t.Fatalf("expected deletion error: %v, got: %v",
expErr, err)
}
}
// queryMatches queries that database for the passed breach hint, returning all
// matches found.
func (h *towerDBHarness) queryMatches(hint wtdb.BreachHint) []wtdb.Match {
h.t.Helper()
matches, err := h.db.QueryMatches([]wtdb.BreachHint{hint})
if err != nil {
h.t.Fatalf("unable to query matches: %v", err)
}
return matches
}
// hasUpdate queries the database for the passed breach hint, asserting that
// only one match is present and that the hints indeed match. If successful, the
// match is returned.
func (h *towerDBHarness) hasUpdate(hint wtdb.BreachHint) wtdb.Match {
h.t.Helper()
matches := h.queryMatches(hint)
if len(matches) != 1 {
h.t.Fatalf("expected 1 match, found: %d", len(matches))
}
match := matches[0]
if match.Hint != hint {
h.t.Fatalf("expected hint: %x, got: %x", hint, match.Hint)
}
return match
}
// testInsertSession asserts that a session can only be inserted if a session
// with the same session id does not already exist.
func testInsertSession(h *towerDBHarness) {
var id wtdb.SessionID
h.getSession(&id, wtdb.ErrSessionNotFound)
session := &wtdb.SessionInfo{
ID: id,
Policy: wtpolicy.Policy{
MaxUpdates: 100,
},
RewardAddress: []byte{0x01, 0x02, 0x03},
}
h.insertSession(session, nil)
session2 := h.getSession(&id, nil)
if !reflect.DeepEqual(session, session2) {
h.t.Fatalf("expected session: %v, got %v",
session, session2)
}
h.insertSession(session, nil)
// Insert a state update to fully commit the session parameters.
update := &wtdb.SessionStateUpdate{
ID: id,
SeqNum: 1,
}
h.insertUpdate(update, nil)
// Trying to insert a new session under the same ID should fail.
h.insertSession(session, wtdb.ErrSessionAlreadyExists)
}
// testMultipleMatches asserts that if multiple sessions insert state updates
// with the same breach hint that all will be returned from QueryMatches.
func testMultipleMatches(h *towerDBHarness) {
const numUpdates = 3
// Create a new session and send updates with all the same hint.
var hint wtdb.BreachHint
for i := 0; i < numUpdates; i++ {
id := *id(i)
session := &wtdb.SessionInfo{
ID: id,
Policy: wtpolicy.Policy{
MaxUpdates: 3,
},
RewardAddress: []byte{},
}
h.insertSession(session, nil)
update := &wtdb.SessionStateUpdate{
ID: id,
SeqNum: 1,
Hint: hint, // Use same hint to cause multiple matches
}
h.insertUpdate(update, nil)
}
// Query the db for matches on the chosen hint.
matches := h.queryMatches(hint)
if len(matches) != numUpdates {
h.t.Fatalf("num updates mismatch, want: %d, got: %d",
numUpdates, len(matches))
}
// Assert that the hints are what we asked for, and compute the set of
// sessions returned.
sessions := make(map[wtdb.SessionID]struct{})
for _, match := range matches {
if match.Hint != hint {
h.t.Fatalf("hint mismatch, want: %v, got: %v",
hint, match.Hint)
}
sessions[match.ID] = struct{}{}
}
// Assert that the sessions returned match the session ids of the
// sessions we initially created.
for i := 0; i < numUpdates; i++ {
if _, ok := sessions[*id(i)]; !ok {
h.t.Fatalf("match for session %v not found", *id(i))
}
}
}
// testLookoutTip asserts that the database properly stores and returns the
// lookout tip block epochs. It also asserts that the epoch returned is nil when
// no tip has ever been set.
func testLookoutTip(h *towerDBHarness) {
// Retrieve lookout tip on fresh db.
epoch, err := h.db.GetLookoutTip()
if err != nil {
h.t.Fatalf("unable to fetch lookout tip: %v", err)
}
// Assert that the epoch is nil.
if epoch != nil {
h.t.Fatalf("lookout tip should not be set, found: %v", epoch)
}
// Create a closure that inserts an epoch, retrieves it, and asserts
// that the returned epoch matches what was inserted.
setAndCheck := func(i int) {
expEpoch := epochFromInt(1)
err = h.db.SetLookoutTip(expEpoch)
if err != nil {
h.t.Fatalf("unable to set lookout tip: %v", err)
}
epoch, err = h.db.GetLookoutTip()
if err != nil {
h.t.Fatalf("unable to fetch lookout tip: %v", err)
}
if !reflect.DeepEqual(epoch, expEpoch) {
h.t.Fatalf("lookout tip mismatch, want: %v, got: %v",
expEpoch, epoch)
}
}
// Set and assert the lookout tip.
for i := 0; i < 5; i++ {
setAndCheck(i)
}
}
// testDeleteSession asserts the behavior of a tower database when deleting
// session data. The test asserts that the only proper the target session is
// remmoved, and that only updates for a particular session are pruned.
func testDeleteSession(h *towerDBHarness) {
// First, create a session so that the database is not empty.
id0 := id(0)
session0 := &wtdb.SessionInfo{
ID: *id0,
Policy: wtpolicy.Policy{
MaxUpdates: 3,
},
RewardAddress: []byte{},
}
h.insertSession(session0, nil)
// Now, attempt to delete a session which does not exist, that is also
// different from the first one created.
id1 := id(1)
h.deleteSession(*id1, wtdb.ErrSessionNotFound)
// The first session should still be present.
h.getSession(id0, nil)
// Now insert a second session under a different id.
session1 := &wtdb.SessionInfo{
ID: *id1,
Policy: wtpolicy.Policy{
MaxUpdates: 3,
},
RewardAddress: []byte{},
}
h.insertSession(session1, nil)
// Create and insert updates for both sessions that have the same hint.
var hint wtdb.BreachHint
update0 := &wtdb.SessionStateUpdate{
ID: *id0,
Hint: hint,
SeqNum: 1,
EncryptedBlob: []byte{},
}
update1 := &wtdb.SessionStateUpdate{
ID: *id1,
Hint: hint,
SeqNum: 1,
EncryptedBlob: []byte{},
}
// Insert both updates should succeed.
h.insertUpdate(update0, nil)
h.insertUpdate(update1, nil)
// Remove the new session, which should succeed.
h.deleteSession(*id1, nil)
// The first session should still be present.
h.getSession(id0, nil)
// The second session should be removed.
h.getSession(id1, wtdb.ErrSessionNotFound)
// Assert that only one update is still present.
matches := h.queryMatches(hint)
if len(matches) != 1 {
h.t.Fatalf("expected one update, found: %d", len(matches))
}
// Assert that the update belongs to the first session.
if matches[0].ID != *id0 {
h.t.Fatalf("expected match for %v, instead is for: %v",
*id0, matches[0].ID)
}
// Finally, remove the first session added.
h.deleteSession(*id0, nil)
// The session should no longer be present.
h.getSession(id0, wtdb.ErrSessionNotFound)
// No matches should exist for this hint.
matches = h.queryMatches(hint)
if len(matches) != 0 {
h.t.Fatalf("expected zero updates, found: %d", len(matches))
}
}
type stateUpdateTest struct {
session *wtdb.SessionInfo
sessionErr error
updates []*wtdb.SessionStateUpdate
updateErrs []error
}
func runStateUpdateTest(test stateUpdateTest) func(*towerDBHarness) {
return func(h *towerDBHarness) {
// We may need to modify the initial session as we process
// updates to discern the expected state of the session. We'll
// create a copy of the test session if necessary to prevent
// mutations from impacting other tests.
var expSession *wtdb.SessionInfo
// Create the session if the tests requests one.
if test.session != nil {
// Copy the initial session and insert it into the
// database.
ogSession := *test.session
expErr := test.sessionErr
h.insertSession(&ogSession, expErr)
if expErr != nil {
return
}
// Copy the initial state of the accepted session.
expSession = &wtdb.SessionInfo{}
*expSession = *test.session
}
if len(test.updates) != len(test.updateErrs) {
h.t.Fatalf("malformed test case, num updates " +
"should match num errors")
}
// Send any updates provided in the test.
for i, update := range test.updates {
expErr := test.updateErrs[i]
h.insertUpdate(update, expErr)
if expErr != nil {
continue
}
// Don't perform the following checks and modfications
// if we don't have an expected session to compare
// against.
if expSession == nil {
continue
}
// Update the session's last applied and client last
// applied.
expSession.LastApplied = update.SeqNum
expSession.ClientLastApplied = update.LastApplied
match := h.hasUpdate(update.Hint)
if !reflect.DeepEqual(match.SessionInfo, expSession) {
h.t.Fatalf("expected session: %v, got: %v",
expSession, match.SessionInfo)
}
}
}
}
var stateUpdateNoSession = stateUpdateTest{
session: nil,
updates: []*wtdb.SessionStateUpdate{
{ID: *id(0), SeqNum: 1, LastApplied: 0},
},
updateErrs: []error{
wtdb.ErrSessionNotFound,
},
}
var stateUpdateExhaustSession = stateUpdateTest{
session: &wtdb.SessionInfo{
ID: *id(0),
Policy: wtpolicy.Policy{
MaxUpdates: 3,
},
RewardAddress: []byte{},
},
updates: []*wtdb.SessionStateUpdate{
updateFromInt(id(0), 1, 0),
updateFromInt(id(0), 2, 0),
updateFromInt(id(0), 3, 0),
updateFromInt(id(0), 4, 0),
},
updateErrs: []error{
nil, nil, nil, wtdb.ErrSessionConsumed,
},
}
var stateUpdateSeqNumEqualLastApplied = stateUpdateTest{
session: &wtdb.SessionInfo{
ID: *id(0),
Policy: wtpolicy.Policy{
MaxUpdates: 3,
},
RewardAddress: []byte{},
},
updates: []*wtdb.SessionStateUpdate{
updateFromInt(id(0), 1, 0),
updateFromInt(id(0), 2, 1),
updateFromInt(id(0), 3, 2),
updateFromInt(id(0), 3, 3),
},
updateErrs: []error{
nil, nil, nil, wtdb.ErrSeqNumAlreadyApplied,
},
}
var stateUpdateSeqNumLTLastApplied = stateUpdateTest{
session: &wtdb.SessionInfo{
ID: *id(0),
Policy: wtpolicy.Policy{
MaxUpdates: 3,
},
RewardAddress: []byte{},
},
updates: []*wtdb.SessionStateUpdate{
updateFromInt(id(0), 1, 0),
updateFromInt(id(0), 2, 1),
updateFromInt(id(0), 1, 2),
},
updateErrs: []error{
nil, nil, wtdb.ErrSeqNumAlreadyApplied,
},
}
var stateUpdateSeqNumZeroInvalid = stateUpdateTest{
session: &wtdb.SessionInfo{
ID: *id(0),
Policy: wtpolicy.Policy{
MaxUpdates: 3,
},
RewardAddress: []byte{},
},
updates: []*wtdb.SessionStateUpdate{
updateFromInt(id(0), 0, 0),
},
updateErrs: []error{
wtdb.ErrSeqNumAlreadyApplied,
},
}
var stateUpdateSkipSeqNum = stateUpdateTest{
session: &wtdb.SessionInfo{
ID: *id(0),
Policy: wtpolicy.Policy{
MaxUpdates: 3,
},
RewardAddress: []byte{},
},
updates: []*wtdb.SessionStateUpdate{
updateFromInt(id(0), 2, 0),
},
updateErrs: []error{
wtdb.ErrUpdateOutOfOrder,
},
}
var stateUpdateRevertSeqNum = stateUpdateTest{
session: &wtdb.SessionInfo{
ID: *id(0),
Policy: wtpolicy.Policy{
MaxUpdates: 3,
},
RewardAddress: []byte{},
},
updates: []*wtdb.SessionStateUpdate{
updateFromInt(id(0), 1, 0),
updateFromInt(id(0), 2, 0),
updateFromInt(id(0), 1, 0),
},
updateErrs: []error{
nil, nil, wtdb.ErrUpdateOutOfOrder,
},
}
var stateUpdateRevertLastApplied = stateUpdateTest{
session: &wtdb.SessionInfo{
ID: *id(0),
Policy: wtpolicy.Policy{
MaxUpdates: 3,
},
RewardAddress: []byte{},
},
updates: []*wtdb.SessionStateUpdate{
updateFromInt(id(0), 1, 0),
updateFromInt(id(0), 2, 1),
updateFromInt(id(0), 3, 2),
updateFromInt(id(0), 4, 1),
},
updateErrs: []error{
nil, nil, nil, wtdb.ErrLastAppliedReversion,
},
}
func TestTowerDB(t *testing.T) {
dbs := []struct {
name string
init dbInit
}{
{
name: "fresh boltdb",
init: func(t *testing.T) (watchtower.DB, func()) {
path, err := ioutil.TempDir("", "towerdb")
if err != nil {
t.Fatalf("unable to make temp dir: %v",
err)
}
db, err := wtdb.OpenTowerDB(path)
if err != nil {
os.RemoveAll(path)
t.Fatalf("unable to open db: %v", err)
}
cleanup := func() {
db.Close()
os.RemoveAll(path)
}
return db, cleanup
},
},
{
name: "reopened boltdb",
init: func(t *testing.T) (watchtower.DB, func()) {
path, err := ioutil.TempDir("", "towerdb")
if err != nil {
t.Fatalf("unable to make temp dir: %v",
err)
}
db, err := wtdb.OpenTowerDB(path)
if err != nil {
os.RemoveAll(path)
t.Fatalf("unable to open db: %v", err)
}
db.Close()
// Open the db again, ensuring we test a
// different path during open and that all
// buckets remain initialized.
db, err = wtdb.OpenTowerDB(path)
if err != nil {
os.RemoveAll(path)
t.Fatalf("unable to open db: %v", err)
}
cleanup := func() {
db.Close()
os.RemoveAll(path)
}
return db, cleanup
},
},
{
name: "mock",
init: func(t *testing.T) (watchtower.DB, func()) {
return wtmock.NewTowerDB(), func() {}
},
},
}
tests := []struct {
name string
run func(*towerDBHarness)
}{
{
name: "create session",
run: testInsertSession,
},
{
name: "delete session",
run: testDeleteSession,
},
{
name: "state update no session",
run: runStateUpdateTest(stateUpdateNoSession),
},
{
name: "state update exhaust session",
run: runStateUpdateTest(stateUpdateExhaustSession),
},
{
name: "state update seqnum equal last applied",
run: runStateUpdateTest(
stateUpdateSeqNumEqualLastApplied,
),
},
{
name: "state update seqnum less than last applied",
run: runStateUpdateTest(
stateUpdateSeqNumLTLastApplied,
),
},
{
name: "state update seqnum zero invalid",
run: runStateUpdateTest(stateUpdateSeqNumZeroInvalid),
},
{
name: "state update skip seqnum",
run: runStateUpdateTest(stateUpdateSkipSeqNum),
},
{
name: "state update revert seqnum",
run: runStateUpdateTest(stateUpdateRevertSeqNum),
},
{
name: "state update revert last applied",
run: runStateUpdateTest(stateUpdateRevertLastApplied),
},
{
name: "multiple breach matches",
run: testMultipleMatches,
},
{
name: "lookout tip",
run: testLookoutTip,
},
}
for _, database := range dbs {
db := database
t.Run(db.name, func(t *testing.T) {
t.Parallel()
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
h, cleanup := newTowerDBHarness(
t, db.init,
)
defer cleanup()
test.run(h)
})
}
})
}
}
// id creates a session id from an integer.
func id(i int) *wtdb.SessionID {
var id wtdb.SessionID
binary.BigEndian.PutUint32(id[:4], uint32(i))
return &id
}
// updateFromInt creates a unique update for a given (session, seqnum) pair. The
// lastApplied argument can be used to construct updates simulating different
// levels of synchronicity between client and db.
func updateFromInt(id *wtdb.SessionID, i int,
lastApplied uint16) *wtdb.SessionStateUpdate {
// Ensure the hint is unique.
var hint wtdb.BreachHint
copy(hint[:4], id[:4])
binary.BigEndian.PutUint16(hint[4:6], uint16(i))
return &wtdb.SessionStateUpdate{
ID: *id,
Hint: hint,
SeqNum: uint16(i),
LastApplied: lastApplied,
EncryptedBlob: []byte{byte(i)},
}
}
// epochFromInt creates a block epoch from an integer.
func epochFromInt(i int) *chainntnfs.BlockEpoch {
var hash chainhash.Hash
binary.BigEndian.PutUint32(hash[:4], uint32(i))
return &chainntnfs.BlockEpoch{
Hash: &hash,
Height: int32(i),
}
}

@ -0,0 +1,84 @@
package wtdb
import "github.com/coreos/bbolt"
// migration is a function which takes a prior outdated version of the database
// instances and mutates the key/bucket structure to arrive at a more
// up-to-date version of the database.
type migration func(tx *bbolt.Tx) error
// version pairs a version number with the migration that would need to be
// applied from the prior version to upgrade.
type version struct {
number uint32
migration migration
}
// dbVersions stores all versions and migrations of the database. This list will
// be used when opening the database to determine if any migrations must be
// applied.
var dbVersions = []version{
{
// Initial version requires no migration.
number: 0,
migration: nil,
},
}
// getLatestDBVersion returns the last known database version.
func getLatestDBVersion(versions []version) uint32 {
return versions[len(versions)-1].number
}
// getMigrations returns a slice of all updates with a greater number that
// curVersion that need to be applied to sync up with the latest version.
func getMigrations(versions []version, curVersion uint32) []version {
var updates []version
for _, v := range versions {
if v.number > curVersion {
updates = append(updates, v)
}
}
return updates
}
// getDBVersion retrieves the current database version from the metadata bucket
// using the dbVersionKey.
func getDBVersion(tx *bbolt.Tx) (uint32, error) {
metadata := tx.Bucket(metadataBkt)
if metadata == nil {
return 0, ErrUninitializedDB
}
versionBytes := metadata.Get(dbVersionKey)
if len(versionBytes) != 4 {
return 0, ErrNoDBVersion
}
return byteOrder.Uint32(versionBytes), nil
}
// initDBVersion initializes the top-level metadata bucket and writes the passed
// version number as the current version.
func initDBVersion(tx *bbolt.Tx, version uint32) error {
_, err := tx.CreateBucketIfNotExists(metadataBkt)
if err != nil {
return err
}
return putDBVersion(tx, version)
}
// putDBVersion stores the passed database version in the metadata bucket under
// the dbVersionKey.
func putDBVersion(tx *bbolt.Tx, version uint32) error {
metadata := tx.Bucket(metadataBkt)
if metadata == nil {
return ErrUninitializedDB
}
versionBytes := make([]byte, 4)
byteOrder.PutUint32(versionBytes, version)
return metadata.Put(dbVersionKey, versionBytes)
}

@ -0,0 +1,162 @@
package wtmock
import (
"sync"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
)
// TowerDB is a mock, in-memory implementation of a watchtower.DB.
type TowerDB struct {
mu sync.Mutex
lastEpoch *chainntnfs.BlockEpoch
sessions map[wtdb.SessionID]*wtdb.SessionInfo
blobs map[wtdb.BreachHint]map[wtdb.SessionID]*wtdb.SessionStateUpdate
}
// NewTowerDB initializes a fresh mock TowerDB.
func NewTowerDB() *TowerDB {
return &TowerDB{
sessions: make(map[wtdb.SessionID]*wtdb.SessionInfo),
blobs: make(map[wtdb.BreachHint]map[wtdb.SessionID]*wtdb.SessionStateUpdate),
}
}
// InsertStateUpdate stores an update sent by the client after validating that
// the update is well-formed in the context of other updates sent for the same
// session. This include verifying that the sequence number is incremented
// properly and the last applied values echoed by the client are sane.
func (db *TowerDB) InsertStateUpdate(update *wtdb.SessionStateUpdate) (uint16, error) {
db.mu.Lock()
defer db.mu.Unlock()
info, ok := db.sessions[update.ID]
if !ok {
return 0, wtdb.ErrSessionNotFound
}
err := info.AcceptUpdateSequence(update.SeqNum, update.LastApplied)
if err != nil {
return info.LastApplied, err
}
sessionsToUpdates, ok := db.blobs[update.Hint]
if !ok {
sessionsToUpdates = make(map[wtdb.SessionID]*wtdb.SessionStateUpdate)
db.blobs[update.Hint] = sessionsToUpdates
}
sessionsToUpdates[update.ID] = update
return info.LastApplied, nil
}
// GetSessionInfo retrieves the session for the passed session id. An error is
// returned if the session could not be found.
func (db *TowerDB) GetSessionInfo(id *wtdb.SessionID) (*wtdb.SessionInfo, error) {
db.mu.Lock()
defer db.mu.Unlock()
if info, ok := db.sessions[*id]; ok {
return info, nil
}
return nil, wtdb.ErrSessionNotFound
}
// InsertSessionInfo records a negotiated session in the tower database. An
// error is returned if the session already exists.
func (db *TowerDB) InsertSessionInfo(info *wtdb.SessionInfo) error {
db.mu.Lock()
defer db.mu.Unlock()
dbInfo, ok := db.sessions[info.ID]
if ok && dbInfo.LastApplied > 0 {
return wtdb.ErrSessionAlreadyExists
}
db.sessions[info.ID] = info
return nil
}
// DeleteSession removes all data associated with a particular session id from
// the tower's database.
func (db *TowerDB) DeleteSession(target wtdb.SessionID) error {
db.mu.Lock()
defer db.mu.Unlock()
// Fail if the session doesn't exit.
if _, ok := db.sessions[target]; !ok {
return wtdb.ErrSessionNotFound
}
// Remove the target session.
delete(db.sessions, target)
// Remove the state updates for any blobs stored under the target
// session identifier.
for hint, sessionUpdates := range db.blobs {
delete(sessionUpdates, target)
// If this was the last state update, we can also remove the
// hint that would map to an empty set.
if len(sessionUpdates) == 0 {
delete(db.blobs, hint)
}
}
return nil
}
// QueryMatches searches against all known state updates for any that match the
// passed breachHints. More than one Match will be returned for a given hint if
// they exist in the database.
func (db *TowerDB) QueryMatches(
breachHints []wtdb.BreachHint) ([]wtdb.Match, error) {
db.mu.Lock()
defer db.mu.Unlock()
var matches []wtdb.Match
for _, hint := range breachHints {
sessionsToUpdates, ok := db.blobs[hint]
if !ok {
continue
}
for id, update := range sessionsToUpdates {
info, ok := db.sessions[id]
if !ok {
panic("session not found")
}
match := wtdb.Match{
ID: id,
SeqNum: update.SeqNum,
Hint: hint,
EncryptedBlob: update.EncryptedBlob,
SessionInfo: info,
}
matches = append(matches, match)
}
}
return matches, nil
}
// SetLookoutTip stores the provided epoch as the latest lookout tip epoch in
// the tower database.
func (db *TowerDB) SetLookoutTip(epoch *chainntnfs.BlockEpoch) error {
db.lastEpoch = epoch
return nil
}
// GetLookoutTip retrieves the current lookout tip block epoch from the tower
// database.
func (db *TowerDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) {
db.mu.Lock()
defer db.mu.Unlock()
return db.lastEpoch, nil
}

@ -102,7 +102,7 @@ func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID,
// successful, the session will now be ready for use.
err = s.cfg.DB.InsertSessionInfo(&info)
if err != nil {
log.Errorf("unable to create session for %s", id)
log.Errorf("Unable to create session for %s: %v", id, err)
return s.replyCreateSession(
peer, id, wtwire.CodeTemporaryFailure, 0, nil,
)

@ -1,5 +1,3 @@
// +build dev
package wtserver_test
import (
@ -53,7 +51,7 @@ func initServer(t *testing.T, db wtserver.DB,
t.Helper()
if db == nil {
db = wtdb.NewMockDB()
db = wtmock.NewTowerDB()
}
s, err := wtserver.New(&wtserver.Config{
@ -687,7 +685,7 @@ func testServerStateUpdates(t *testing.T, test stateUpdateTestCase) {
// checking that the proper error is returned when the session doesn't exist and
// that a successful deletion does not disrupt other sessions.
func TestServerDeleteSession(t *testing.T) {
db := wtdb.NewMockDB()
db := wtmock.NewTowerDB()
localPub := randPubKey(t)