Merge pull request #2783 from cfromknecht/wtserver-db
watchtower/wtdb: add bbolt-backed tower database
This commit is contained in:
commit
0393793733
@ -51,6 +51,12 @@ type UnknownElementType struct {
|
||||
element interface{}
|
||||
}
|
||||
|
||||
// NewUnknownElementType creates a new UnknownElementType error from the passed
|
||||
// method name and element.
|
||||
func NewUnknownElementType(method string, el interface{}) UnknownElementType {
|
||||
return UnknownElementType{method: method, element: el}
|
||||
}
|
||||
|
||||
// Error returns the name of the method that encountered the error, as well as
|
||||
// the type that was unsupported.
|
||||
func (e UnknownElementType) Error() string {
|
||||
|
@ -15,6 +15,7 @@ import (
|
||||
"github.com/lightningnetwork/lnd/watchtower/blob"
|
||||
"github.com/lightningnetwork/lnd/watchtower/lookout"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtmock"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
||||
)
|
||||
|
||||
@ -66,7 +67,7 @@ func makeAddrSlice(size int) []byte {
|
||||
}
|
||||
|
||||
func TestLookoutBreachMatching(t *testing.T) {
|
||||
db := wtdb.NewMockDB()
|
||||
db := wtmock.NewTowerDB()
|
||||
|
||||
// Initialize an mock backend to feed the lookout blocks.
|
||||
backend := lookout.NewMockBackend()
|
||||
|
@ -369,7 +369,7 @@ type testHarness struct {
|
||||
clientDB *wtmock.ClientDB
|
||||
clientCfg *wtclient.Config
|
||||
client wtclient.Client
|
||||
serverDB *wtdb.MockDB
|
||||
serverDB *wtmock.TowerDB
|
||||
serverCfg *wtserver.Config
|
||||
server *wtserver.Server
|
||||
net *mockNet
|
||||
@ -406,7 +406,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
|
||||
}
|
||||
|
||||
const timeout = 200 * time.Millisecond
|
||||
serverDB := wtdb.NewMockDB()
|
||||
serverDB := wtmock.NewTowerDB()
|
||||
|
||||
serverCfg := &wtserver.Config{
|
||||
DB: serverDB,
|
||||
|
143
watchtower/wtdb/codec.go
Normal file
143
watchtower/wtdb/codec.go
Normal file
@ -0,0 +1,143 @@
|
||||
package wtdb
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/lnwallet"
|
||||
"github.com/lightningnetwork/lnd/watchtower/blob"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
||||
)
|
||||
|
||||
// UnknownElementType is an alias for channeldb.UnknownElementType.
|
||||
type UnknownElementType = channeldb.UnknownElementType
|
||||
|
||||
// ReadElement deserializes a single element from the provided io.Reader.
|
||||
func ReadElement(r io.Reader, element interface{}) error {
|
||||
err := channeldb.ReadElement(r, element)
|
||||
switch {
|
||||
|
||||
// Known to channeldb codec.
|
||||
case err == nil:
|
||||
return nil
|
||||
|
||||
// Fail if error is not UnknownElementType.
|
||||
case err != nil:
|
||||
if _, ok := err.(UnknownElementType); !ok {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Process any wtdb-specific extensions to the codec.
|
||||
switch e := element.(type) {
|
||||
|
||||
case *SessionID:
|
||||
if _, err := io.ReadFull(r, e[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case *BreachHint:
|
||||
if _, err := io.ReadFull(r, e[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case *wtpolicy.Policy:
|
||||
var (
|
||||
blobType uint16
|
||||
sweepFeeRate uint64
|
||||
)
|
||||
err := channeldb.ReadElements(r,
|
||||
&blobType,
|
||||
&e.MaxUpdates,
|
||||
&e.RewardBase,
|
||||
&e.RewardRate,
|
||||
&sweepFeeRate,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.BlobType = blob.Type(blobType)
|
||||
e.SweepFeeRate = lnwallet.SatPerKWeight(sweepFeeRate)
|
||||
|
||||
// Type is still unknown to wtdb extensions, fail.
|
||||
default:
|
||||
return channeldb.NewUnknownElementType(
|
||||
"ReadElement", element,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteElement serializes a single element into the provided io.Writer.
|
||||
func WriteElement(w io.Writer, element interface{}) error {
|
||||
err := channeldb.WriteElement(w, element)
|
||||
switch {
|
||||
|
||||
// Known to channeldb codec.
|
||||
case err == nil:
|
||||
return nil
|
||||
|
||||
// Fail if error is not UnknownElementType.
|
||||
case err != nil:
|
||||
if _, ok := err.(UnknownElementType); !ok {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Process any wtdb-specific extensions to the codec.
|
||||
switch e := element.(type) {
|
||||
|
||||
case SessionID:
|
||||
if _, err := w.Write(e[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case BreachHint:
|
||||
if _, err := w.Write(e[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case wtpolicy.Policy:
|
||||
return channeldb.WriteElements(w,
|
||||
uint16(e.BlobType),
|
||||
e.MaxUpdates,
|
||||
e.RewardBase,
|
||||
e.RewardRate,
|
||||
uint64(e.SweepFeeRate),
|
||||
)
|
||||
|
||||
// Type is still unknown to wtdb extensions, fail.
|
||||
default:
|
||||
return channeldb.NewUnknownElementType(
|
||||
"WriteElement", element,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteElements serializes a variadic list of elements into the given
|
||||
// io.Writer.
|
||||
func WriteElements(w io.Writer, elements ...interface{}) error {
|
||||
for _, element := range elements {
|
||||
if err := WriteElement(w, element); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadElements deserializes the provided io.Reader into a variadic list of
|
||||
// target elements.
|
||||
func ReadElements(r io.Reader, elements ...interface{}) error {
|
||||
for _, element := range elements {
|
||||
if err := ReadElement(r, element); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
86
watchtower/wtdb/codec_test.go
Normal file
86
watchtower/wtdb/codec_test.go
Normal file
@ -0,0 +1,86 @@
|
||||
package wtdb_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"reflect"
|
||||
"testing"
|
||||
"testing/quick"
|
||||
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
)
|
||||
|
||||
// dbObject is abstract object support encoding and decoding.
|
||||
type dbObject interface {
|
||||
Encode(io.Writer) error
|
||||
Decode(io.Reader) error
|
||||
}
|
||||
|
||||
// TestCodec serializes and deserializes wtdb objects in order to test that that
|
||||
// the codec understands all of the required field types. The test also asserts
|
||||
// that decoding an object into another results in an equivalent object.
|
||||
func TestCodec(t *testing.T) {
|
||||
mainScenario := func(obj dbObject) bool {
|
||||
// Ensure encoding the object succeeds.
|
||||
var b bytes.Buffer
|
||||
err := obj.Encode(&b)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to encode: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
var obj2 dbObject
|
||||
switch obj.(type) {
|
||||
case *wtdb.SessionInfo:
|
||||
obj2 = &wtdb.SessionInfo{}
|
||||
case *wtdb.SessionStateUpdate:
|
||||
obj2 = &wtdb.SessionStateUpdate{}
|
||||
default:
|
||||
t.Fatalf("unknown type: %T", obj)
|
||||
return false
|
||||
}
|
||||
|
||||
// Ensure decoding the object succeeds.
|
||||
err = obj2.Decode(bytes.NewReader(b.Bytes()))
|
||||
if err != nil {
|
||||
t.Fatalf("unable to decode: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
// Assert the original and decoded object match.
|
||||
if !reflect.DeepEqual(obj, obj2) {
|
||||
t.Fatalf("encode/decode mismatch, want: %v, "+
|
||||
"got: %v", obj, obj2)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
scenario interface{}
|
||||
}{
|
||||
{
|
||||
name: "SessionInfo",
|
||||
scenario: func(obj wtdb.SessionInfo) bool {
|
||||
return mainScenario(&obj)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SessionStateUpdate",
|
||||
scenario: func(obj wtdb.SessionStateUpdate) bool {
|
||||
return mainScenario(&obj)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
if err := quick.Check(test.scenario, nil); err != nil {
|
||||
t.Fatalf("fuzz checks for msg=%s failed: %v",
|
||||
test.name, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
45
watchtower/wtdb/log.go
Normal file
45
watchtower/wtdb/log.go
Normal file
@ -0,0 +1,45 @@
|
||||
package wtdb
|
||||
|
||||
import (
|
||||
"github.com/btcsuite/btclog"
|
||||
"github.com/lightningnetwork/lnd/build"
|
||||
)
|
||||
|
||||
// log is a logger that is initialized with no output filters. This
|
||||
// means the package will not perform any logging by default until the caller
|
||||
// requests it.
|
||||
var log btclog.Logger
|
||||
|
||||
// The default amount of logging is none.
|
||||
func init() {
|
||||
UseLogger(build.NewSubLogger("WTDB", nil))
|
||||
}
|
||||
|
||||
// DisableLog disables all library log output. Logging output is disabled
|
||||
// by default until UseLogger is called.
|
||||
func DisableLog() {
|
||||
UseLogger(btclog.Disabled)
|
||||
}
|
||||
|
||||
// UseLogger uses a specified Logger to output package logging info.
|
||||
// This should be used in preference to SetLogWriter if the caller is also
|
||||
// using btclog.
|
||||
func UseLogger(logger btclog.Logger) {
|
||||
log = logger
|
||||
}
|
||||
|
||||
// logClosure is used to provide a closure over expensive logging operations so
|
||||
// don't have to be performed when the logging level doesn't warrant it.
|
||||
type logClosure func() string
|
||||
|
||||
// String invokes the underlying function and returns the result.
|
||||
func (c logClosure) String() string {
|
||||
return c()
|
||||
}
|
||||
|
||||
// newLogClosure returns a new closure over a function that returns a string
|
||||
// which itself provides a Stringer interface so that it can be used with the
|
||||
// logging system.
|
||||
func newLogClosure(c func() string) logClosure {
|
||||
return logClosure(c)
|
||||
}
|
@ -1,142 +0,0 @@
|
||||
// +build dev
|
||||
|
||||
package wtdb
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||
)
|
||||
|
||||
type MockDB struct {
|
||||
mu sync.Mutex
|
||||
lastEpoch *chainntnfs.BlockEpoch
|
||||
sessions map[SessionID]*SessionInfo
|
||||
blobs map[BreachHint]map[SessionID]*SessionStateUpdate
|
||||
}
|
||||
|
||||
func NewMockDB() *MockDB {
|
||||
return &MockDB{
|
||||
sessions: make(map[SessionID]*SessionInfo),
|
||||
blobs: make(map[BreachHint]map[SessionID]*SessionStateUpdate),
|
||||
}
|
||||
}
|
||||
|
||||
func (db *MockDB) InsertStateUpdate(update *SessionStateUpdate) (uint16, error) {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
info, ok := db.sessions[update.ID]
|
||||
if !ok {
|
||||
return 0, ErrSessionNotFound
|
||||
}
|
||||
|
||||
err := info.AcceptUpdateSequence(update.SeqNum, update.LastApplied)
|
||||
if err != nil {
|
||||
return info.LastApplied, err
|
||||
}
|
||||
|
||||
sessionsToUpdates, ok := db.blobs[update.Hint]
|
||||
if !ok {
|
||||
sessionsToUpdates = make(map[SessionID]*SessionStateUpdate)
|
||||
db.blobs[update.Hint] = sessionsToUpdates
|
||||
}
|
||||
sessionsToUpdates[update.ID] = update
|
||||
|
||||
return info.LastApplied, nil
|
||||
}
|
||||
|
||||
func (db *MockDB) GetSessionInfo(id *SessionID) (*SessionInfo, error) {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
if info, ok := db.sessions[*id]; ok {
|
||||
return info, nil
|
||||
}
|
||||
|
||||
return nil, ErrSessionNotFound
|
||||
}
|
||||
|
||||
func (db *MockDB) InsertSessionInfo(info *SessionInfo) error {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
dbInfo, ok := db.sessions[info.ID]
|
||||
if ok && dbInfo.LastApplied > 0 {
|
||||
return ErrSessionAlreadyExists
|
||||
}
|
||||
|
||||
db.sessions[info.ID] = info
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *MockDB) DeleteSession(target SessionID) error {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
// Fail if the session doesn't exit.
|
||||
if _, ok := db.sessions[target]; !ok {
|
||||
return ErrSessionNotFound
|
||||
}
|
||||
|
||||
// Remove the target session.
|
||||
delete(db.sessions, target)
|
||||
|
||||
// Remove the state updates for any blobs stored under the target
|
||||
// session identifier.
|
||||
for hint, sessionUpdates := range db.blobs {
|
||||
delete(sessionUpdates, target)
|
||||
|
||||
//If this was the last state update, we can also remove the hint
|
||||
//that would map to an empty set.
|
||||
if len(sessionUpdates) == 0 {
|
||||
delete(db.blobs, hint)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *MockDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
return db.lastEpoch, nil
|
||||
}
|
||||
|
||||
func (db *MockDB) QueryMatches(breachHints []BreachHint) ([]Match, error) {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
var matches []Match
|
||||
for _, hint := range breachHints {
|
||||
sessionsToUpdates, ok := db.blobs[hint]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
for id, update := range sessionsToUpdates {
|
||||
info, ok := db.sessions[id]
|
||||
if !ok {
|
||||
panic("session not found")
|
||||
}
|
||||
|
||||
match := Match{
|
||||
ID: id,
|
||||
SeqNum: update.SeqNum,
|
||||
Hint: hint,
|
||||
EncryptedBlob: update.EncryptedBlob,
|
||||
SessionInfo: info,
|
||||
}
|
||||
matches = append(matches, match)
|
||||
}
|
||||
}
|
||||
|
||||
return matches, nil
|
||||
}
|
||||
|
||||
func (db *MockDB) SetLookoutTip(epoch *chainntnfs.BlockEpoch) error {
|
||||
db.lastEpoch = epoch
|
||||
return nil
|
||||
}
|
@ -2,6 +2,7 @@ package wtdb
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
||||
)
|
||||
@ -59,6 +60,28 @@ type SessionInfo struct {
|
||||
// TODO(conner): store client metrics, DOS score, etc
|
||||
}
|
||||
|
||||
// Encode serializes the session info to the given io.Writer.
|
||||
func (s *SessionInfo) Encode(w io.Writer) error {
|
||||
return WriteElements(w,
|
||||
s.ID,
|
||||
s.Policy,
|
||||
s.LastApplied,
|
||||
s.ClientLastApplied,
|
||||
s.RewardAddress,
|
||||
)
|
||||
}
|
||||
|
||||
// Decode deserializes the session infor from the given io.Reader.
|
||||
func (s *SessionInfo) Decode(r io.Reader) error {
|
||||
return ReadElements(r,
|
||||
&s.ID,
|
||||
&s.Policy,
|
||||
&s.LastApplied,
|
||||
&s.ClientLastApplied,
|
||||
&s.RewardAddress,
|
||||
)
|
||||
}
|
||||
|
||||
// AcceptUpdateSequence validates that a state update's sequence number and last
|
||||
// applied are valid given our past history with the client. These checks ensure
|
||||
// that clients are properly in sync and following the update protocol properly.
|
||||
|
@ -1,5 +1,7 @@
|
||||
package wtdb
|
||||
|
||||
import "io"
|
||||
|
||||
// SessionStateUpdate holds a state update sent by a client along with its
|
||||
// SessionID.
|
||||
type SessionStateUpdate struct {
|
||||
@ -21,3 +23,25 @@ type SessionStateUpdate struct {
|
||||
// hint is braodcast.
|
||||
EncryptedBlob []byte
|
||||
}
|
||||
|
||||
// Encode serializes the state update into the provided io.Writer.
|
||||
func (u *SessionStateUpdate) Encode(w io.Writer) error {
|
||||
return WriteElements(w,
|
||||
u.ID,
|
||||
u.SeqNum,
|
||||
u.LastApplied,
|
||||
u.Hint,
|
||||
u.EncryptedBlob,
|
||||
)
|
||||
}
|
||||
|
||||
// Decode deserializes the target state update from the provided io.Reader.
|
||||
func (u *SessionStateUpdate) Decode(r io.Reader) error {
|
||||
return ReadElements(r,
|
||||
&u.ID,
|
||||
&u.SeqNum,
|
||||
&u.LastApplied,
|
||||
&u.Hint,
|
||||
&u.EncryptedBlob,
|
||||
)
|
||||
}
|
||||
|
733
watchtower/wtdb/tower_db.go
Normal file
733
watchtower/wtdb/tower_db.go
Normal file
@ -0,0 +1,733 @@
|
||||
package wtdb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/coreos/bbolt"
|
||||
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
)
|
||||
|
||||
const (
|
||||
// dbName is the filename of tower database.
|
||||
dbName = "watchtower.db"
|
||||
|
||||
// dbFilePermission requests read+write access to the db file.
|
||||
dbFilePermission = 0600
|
||||
)
|
||||
|
||||
var (
|
||||
// sessionsBkt is a bucket containing all negotiated client sessions.
|
||||
// session id -> session
|
||||
sessionsBkt = []byte("sessions-bucket")
|
||||
|
||||
// updatesBkt is a bucket containing all state updates sent by clients.
|
||||
// The updates are further bucketed by session id to prevent clients
|
||||
// from overwrite each other.
|
||||
// hint => session id -> update
|
||||
updatesBkt = []byte("updates-bucket")
|
||||
|
||||
// updateIndexBkt is a bucket that indexes all state updates by their
|
||||
// overarching session id. This allows for efficient lookup of updates
|
||||
// by their session id, which is currently used to aide deletion
|
||||
// performance.
|
||||
// session id => hint1 -> []byte{}
|
||||
// => hint2 -> []byte{}
|
||||
updateIndexBkt = []byte("update-index-bucket")
|
||||
|
||||
// lookoutTipBkt is a bucket containing the last block epoch processed
|
||||
// by the lookout subsystem. It has one key, lookoutTipKey.
|
||||
// lookoutTipKey -> block epoch
|
||||
lookoutTipBkt = []byte("lookout-tip-bucket")
|
||||
|
||||
// lookoutTipKey is a static key used to retrieve lookout tip's block
|
||||
// epoch from the lookoutTipBkt.
|
||||
lookoutTipKey = []byte("lookout-tip")
|
||||
|
||||
// metadataBkt stores all the meta information concerning the state of
|
||||
// the database.
|
||||
metadataBkt = []byte("metadata-bucket")
|
||||
|
||||
// dbVersionKey is a static key used to retrieve the database version
|
||||
// number from the metadataBkt.
|
||||
dbVersionKey = []byte("version")
|
||||
|
||||
// ErrUninitializedDB signals that top-level buckets for the database
|
||||
// have not been initialized.
|
||||
ErrUninitializedDB = errors.New("tower db not initialized")
|
||||
|
||||
// ErrNoDBVersion signals that the database contains no version info.
|
||||
ErrNoDBVersion = errors.New("tower db has no version")
|
||||
|
||||
// ErrNoSessionHintIndex signals that an active session does not have an
|
||||
// initialized index for tracking its own state updates.
|
||||
ErrNoSessionHintIndex = errors.New("session hint index missing")
|
||||
|
||||
byteOrder = binary.BigEndian
|
||||
)
|
||||
|
||||
// TowerDB is single database providing a persistent storage engine for the
|
||||
// wtserver and lookout subsystems.
|
||||
type TowerDB struct {
|
||||
db *bbolt.DB
|
||||
dbPath string
|
||||
}
|
||||
|
||||
// OpenTowerDB opens the tower database given the path to the database's
|
||||
// directory. If no such database exists, this method will initialize a fresh
|
||||
// one using the latest version number and bucket structure. If a database
|
||||
// exists but has a lower version number than the current version, any necessary
|
||||
// migrations will be applied before returning. Any attempt to open a database
|
||||
// with a version number higher that the latest version will fail to prevent
|
||||
// accidental reversion.
|
||||
func OpenTowerDB(dbPath string) (*TowerDB, error) {
|
||||
path := filepath.Join(dbPath, dbName)
|
||||
|
||||
// If the database file doesn't exist, this indicates we much initialize
|
||||
// a fresh database with the latest version.
|
||||
firstInit := !fileExists(path)
|
||||
if firstInit {
|
||||
// Ensure all parent directories are initialized.
|
||||
err := os.MkdirAll(dbPath, 0700)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
bdb, err := bbolt.Open(path, dbFilePermission, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If the file existed previously, we'll now check to see that the
|
||||
// metadata bucket is properly initialized. It could be the case that
|
||||
// the database was created, but we failed to actually populate any
|
||||
// metadata. If the metadata bucket does not actually exist, we'll
|
||||
// set firstInit to true so that we can treat is initialize the bucket.
|
||||
if !firstInit {
|
||||
var metadataExists bool
|
||||
err = bdb.View(func(tx *bbolt.Tx) error {
|
||||
metadataExists = tx.Bucket(metadataBkt) != nil
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !metadataExists {
|
||||
firstInit = true
|
||||
}
|
||||
}
|
||||
|
||||
towerDB := &TowerDB{
|
||||
db: bdb,
|
||||
dbPath: dbPath,
|
||||
}
|
||||
|
||||
if firstInit {
|
||||
// If the database has not yet been created, we'll initialize
|
||||
// the database version with the latest known version.
|
||||
err = towerDB.db.Update(func(tx *bbolt.Tx) error {
|
||||
return initDBVersion(tx, getLatestDBVersion(dbVersions))
|
||||
})
|
||||
if err != nil {
|
||||
bdb.Close()
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// Otherwise, ensure that any migrations are applied to ensure
|
||||
// the data is in the format expected by the latest version.
|
||||
err = towerDB.syncVersions(dbVersions)
|
||||
if err != nil {
|
||||
bdb.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Now that the database version fully consistent with our latest known
|
||||
// version, ensure that all top-level buckets known to this version are
|
||||
// initialized. This allows us to assume their presence throughout all
|
||||
// operations. If an known top-level bucket is expected to exist but is
|
||||
// missing, this will trigger a ErrUninitializedDB error.
|
||||
err = towerDB.db.Update(initTowerDBBuckets)
|
||||
if err != nil {
|
||||
bdb.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return towerDB, nil
|
||||
}
|
||||
|
||||
// fileExists returns true if the file exists, and false otherwise.
|
||||
func fileExists(path string) bool {
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// initTowerDBBuckets creates all top-level buckets required to handle database
|
||||
// operations required by the latest version.
|
||||
func initTowerDBBuckets(tx *bbolt.Tx) error {
|
||||
buckets := [][]byte{
|
||||
sessionsBkt,
|
||||
updateIndexBkt,
|
||||
updatesBkt,
|
||||
lookoutTipBkt,
|
||||
}
|
||||
|
||||
for _, bucket := range buckets {
|
||||
_, err := tx.CreateBucketIfNotExists(bucket)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// syncVersions ensures the database version is consistent with the highest
|
||||
// known database version, applying any migrations that have not been made. If
|
||||
// the highest known version number is lower than the database's version, this
|
||||
// method will fail to prevent accidental reversions.
|
||||
func (t *TowerDB) syncVersions(versions []version) error {
|
||||
curVersion, err := t.Version()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
latestVersion := getLatestDBVersion(versions)
|
||||
switch {
|
||||
|
||||
// Current version is higher than any known version, fail to prevent
|
||||
// reversion.
|
||||
case curVersion > latestVersion:
|
||||
return channeldb.ErrDBReversion
|
||||
|
||||
// Current version matches highest known version, nothing to do.
|
||||
case curVersion == latestVersion:
|
||||
return nil
|
||||
}
|
||||
|
||||
// Otherwise, apply any migrations in order to bring the database
|
||||
// version up to the highest known version.
|
||||
updates := getMigrations(versions, curVersion)
|
||||
return t.db.Update(func(tx *bbolt.Tx) error {
|
||||
for _, update := range updates {
|
||||
if update.migration == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Infof("Applying migration #%d", update.number)
|
||||
|
||||
err := update.migration(tx)
|
||||
if err != nil {
|
||||
log.Errorf("Unable to apply migration #%d: %v",
|
||||
err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return putDBVersion(tx, latestVersion)
|
||||
})
|
||||
}
|
||||
|
||||
// Version returns the database's current version number.
|
||||
func (t *TowerDB) Version() (uint32, error) {
|
||||
var version uint32
|
||||
err := t.db.View(func(tx *bbolt.Tx) error {
|
||||
var err error
|
||||
version, err = getDBVersion(tx)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return version, nil
|
||||
}
|
||||
|
||||
// Close closes the underlying database.
|
||||
func (t *TowerDB) Close() error {
|
||||
return t.db.Close()
|
||||
}
|
||||
|
||||
// GetSessionInfo retrieves the session for the passed session id. An error is
|
||||
// returned if the session could not be found.
|
||||
func (t *TowerDB) GetSessionInfo(id *SessionID) (*SessionInfo, error) {
|
||||
var session *SessionInfo
|
||||
err := t.db.View(func(tx *bbolt.Tx) error {
|
||||
sessions := tx.Bucket(sessionsBkt)
|
||||
if sessions == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
var err error
|
||||
session, err = getSession(sessions, id[:])
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// InsertSessionInfo records a negotiated session in the tower database. An
|
||||
// error is returned if the session already exists.
|
||||
func (t *TowerDB) InsertSessionInfo(session *SessionInfo) error {
|
||||
return t.db.Update(func(tx *bbolt.Tx) error {
|
||||
sessions := tx.Bucket(sessionsBkt)
|
||||
if sessions == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
updateIndex := tx.Bucket(updateIndexBkt)
|
||||
if updateIndex == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
dbSession, err := getSession(sessions, session.ID[:])
|
||||
switch {
|
||||
case err == ErrSessionNotFound:
|
||||
// proceed.
|
||||
|
||||
case err != nil:
|
||||
return err
|
||||
|
||||
case dbSession.LastApplied > 0:
|
||||
return ErrSessionAlreadyExists
|
||||
}
|
||||
|
||||
err = putSession(sessions, session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize the session-hint index which will be used to track
|
||||
// all updates added for this session. Upon deletion, we will
|
||||
// consult the index to determine exactly which updates should
|
||||
// be deleted without needing to iterate over the entire
|
||||
// database.
|
||||
return touchSessionHintBkt(updateIndex, &session.ID)
|
||||
})
|
||||
}
|
||||
|
||||
// InsertStateUpdate stores an update sent by the client after validating that
|
||||
// the update is well-formed in the context of other updates sent for the same
|
||||
// session. This include verifying that the sequence number is incremented
|
||||
// properly and the last applied values echoed by the client are sane.
|
||||
func (t *TowerDB) InsertStateUpdate(update *SessionStateUpdate) (uint16, error) {
|
||||
var lastApplied uint16
|
||||
err := t.db.Update(func(tx *bbolt.Tx) error {
|
||||
sessions := tx.Bucket(sessionsBkt)
|
||||
if sessions == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
updates := tx.Bucket(updatesBkt)
|
||||
if updates == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
updateIndex := tx.Bucket(updateIndexBkt)
|
||||
if updateIndex == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
// Fetch the session corresponding to the update's session id.
|
||||
// This will be used to validate that the update's sequence
|
||||
// number and last applied values are sane.
|
||||
session, err := getSession(sessions, update.ID[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate the update against the current state of the session.
|
||||
err = session.AcceptUpdateSequence(
|
||||
update.SeqNum, update.LastApplied,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validation succeeded, therefore the update is committed and
|
||||
// the session's last applied value is equal to the update's
|
||||
// sequence number.
|
||||
lastApplied = session.LastApplied
|
||||
|
||||
// Store the updated session to persist the updated last applied
|
||||
// values.
|
||||
err = putSession(sessions, session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create or load the hint bucket for this state update's hint
|
||||
// and write the given update.
|
||||
hints, err := updates.CreateBucketIfNotExists(update.Hint[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
err = update.Encode(&b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = hints.Put(update.ID[:], b.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Finally, create an entry in the update index to track this
|
||||
// hint under its session id. This will allow us to delete the
|
||||
// entries efficiently if the session is ever removed.
|
||||
return putHintForSession(updateIndex, &update.ID, update.Hint)
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return lastApplied, nil
|
||||
}
|
||||
|
||||
// DeleteSession removes all data associated with a particular session id from
|
||||
// the tower's database.
|
||||
func (t *TowerDB) DeleteSession(target SessionID) error {
|
||||
return t.db.Update(func(tx *bbolt.Tx) error {
|
||||
sessions := tx.Bucket(sessionsBkt)
|
||||
if sessions == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
updates := tx.Bucket(updatesBkt)
|
||||
if updates == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
updateIndex := tx.Bucket(updateIndexBkt)
|
||||
if updateIndex == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
// Fail if the session doesn't exit.
|
||||
_, err := getSession(sessions, target[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove the target session.
|
||||
err = sessions.Delete(target[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Next, check the update index for any hints that were added
|
||||
// under this session.
|
||||
hints, err := getHintsForSession(updateIndex, &target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, hint := range hints {
|
||||
// Remove the state updates for any blobs stored under
|
||||
// the target session identifier.
|
||||
updatesForHint := updates.Bucket(hint[:])
|
||||
if updatesForHint == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
update := updatesForHint.Get(target[:])
|
||||
if update == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
err := updatesForHint.Delete(target[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If this was the last state update, we can also remove
|
||||
// the hint that would map to an empty set.
|
||||
err = isBucketEmpty(updatesForHint)
|
||||
switch {
|
||||
|
||||
// Other updates exist for this hint, keep the bucket.
|
||||
case err == errBucketNotEmpty:
|
||||
continue
|
||||
|
||||
// Unexpected error.
|
||||
case err != nil:
|
||||
return err
|
||||
|
||||
// No more updates for this hint, prune hint bucket.
|
||||
default:
|
||||
err = updates.DeleteBucket(hint[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, remove this session from the update index, which
|
||||
// also removes any of the indexed hints beneath it.
|
||||
return removeSessionHintBkt(updateIndex, &target)
|
||||
})
|
||||
}
|
||||
|
||||
// QueryMatches searches against all known state updates for any that match the
|
||||
// passed breachHints. More than one Match will be returned for a given hint if
|
||||
// they exist in the database.
|
||||
func (t *TowerDB) QueryMatches(breachHints []BreachHint) ([]Match, error) {
|
||||
var matches []Match
|
||||
err := t.db.View(func(tx *bbolt.Tx) error {
|
||||
sessions := tx.Bucket(sessionsBkt)
|
||||
if sessions == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
updates := tx.Bucket(updatesBkt)
|
||||
if updates == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
// Iterate through the target breach hints, appending any
|
||||
// matching updates to the set of matches.
|
||||
for _, hint := range breachHints {
|
||||
// If a bucket does not exist for this hint, no matches
|
||||
// are known.
|
||||
updatesForHint := updates.Bucket(hint[:])
|
||||
if updatesForHint == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Otherwise, iterate through all (session id, update)
|
||||
// pairs, creating a Match for each.
|
||||
err := updatesForHint.ForEach(func(k, v []byte) error {
|
||||
// Load the session via the session id for this
|
||||
// update. The session info contains further
|
||||
// instructions for how to process the state
|
||||
// update.
|
||||
session, err := getSession(sessions, k)
|
||||
switch {
|
||||
case err == ErrSessionNotFound:
|
||||
log.Warnf("Missing session=%x for "+
|
||||
"matched state update hint=%x",
|
||||
k, hint)
|
||||
return nil
|
||||
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
// Decode the state update containing the
|
||||
// encrypted blob.
|
||||
update := &SessionStateUpdate{}
|
||||
err = update.Decode(bytes.NewReader(v))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var id SessionID
|
||||
copy(id[:], k)
|
||||
|
||||
// Construct the final match using the found
|
||||
// update and its session info.
|
||||
match := Match{
|
||||
ID: id,
|
||||
SeqNum: update.SeqNum,
|
||||
Hint: hint,
|
||||
EncryptedBlob: update.EncryptedBlob,
|
||||
SessionInfo: session,
|
||||
}
|
||||
|
||||
matches = append(matches, match)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return matches, nil
|
||||
}
|
||||
|
||||
// SetLookoutTip stores the provided epoch as the latest lookout tip epoch in
|
||||
// the tower database.
|
||||
func (t *TowerDB) SetLookoutTip(epoch *chainntnfs.BlockEpoch) error {
|
||||
return t.db.Update(func(tx *bbolt.Tx) error {
|
||||
lookoutTip := tx.Bucket(lookoutTipBkt)
|
||||
if lookoutTip == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
return putLookoutEpoch(lookoutTip, epoch)
|
||||
})
|
||||
}
|
||||
|
||||
// GetLookoutTip retrieves the current lookout tip block epoch from the tower
|
||||
// database.
|
||||
func (t *TowerDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) {
|
||||
var epoch *chainntnfs.BlockEpoch
|
||||
err := t.db.View(func(tx *bbolt.Tx) error {
|
||||
lookoutTip := tx.Bucket(lookoutTipBkt)
|
||||
if lookoutTip == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
epoch = getLookoutEpoch(lookoutTip)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return epoch, nil
|
||||
}
|
||||
|
||||
// getSession retrieves the session info from the sessions bucket identified by
|
||||
// its session id. An error is returned if the session is not found or a
|
||||
// deserialization error occurs.
|
||||
func getSession(sessions *bbolt.Bucket, id []byte) (*SessionInfo, error) {
|
||||
sessionBytes := sessions.Get(id)
|
||||
if sessionBytes == nil {
|
||||
return nil, ErrSessionNotFound
|
||||
}
|
||||
|
||||
var session SessionInfo
|
||||
err := session.Decode(bytes.NewReader(sessionBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// putSession stores the session info in the sessions bucket identified by its
|
||||
// session id. An error is returned if a serialization error occurs.
|
||||
func putSession(sessions *bbolt.Bucket, session *SessionInfo) error {
|
||||
var b bytes.Buffer
|
||||
err := session.Encode(&b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return sessions.Put(session.ID[:], b.Bytes())
|
||||
}
|
||||
|
||||
// touchSessionHintBkt initializes the session-hint bucket for a particular
|
||||
// session id. This ensures that future calls to getHintsForSession or
|
||||
// putHintForSession can rely on the bucket already being created, and fail if
|
||||
// index has not been initialized as this points to improper usage.
|
||||
func touchSessionHintBkt(updateIndex *bbolt.Bucket, id *SessionID) error {
|
||||
_, err := updateIndex.CreateBucketIfNotExists(id[:])
|
||||
return err
|
||||
}
|
||||
|
||||
// removeSessionHintBkt prunes the session-hint bucket for the given session id
|
||||
// and all of the hints contained inside. This should be used to clean up the
|
||||
// index upon session deletion.
|
||||
func removeSessionHintBkt(updateIndex *bbolt.Bucket, id *SessionID) error {
|
||||
return updateIndex.DeleteBucket(id[:])
|
||||
}
|
||||
|
||||
// getHintsForSession returns all known hints belonging to the given session id.
|
||||
// If the index for the session has not been initialized, this method returns
|
||||
// ErrNoSessionHintIndex.
|
||||
func getHintsForSession(updateIndex *bbolt.Bucket,
|
||||
id *SessionID) ([]BreachHint, error) {
|
||||
|
||||
sessionHints := updateIndex.Bucket(id[:])
|
||||
if sessionHints == nil {
|
||||
return nil, ErrNoSessionHintIndex
|
||||
}
|
||||
|
||||
var hints []BreachHint
|
||||
err := sessionHints.ForEach(func(k, _ []byte) error {
|
||||
if len(k) != BreachHintSize {
|
||||
return nil
|
||||
}
|
||||
|
||||
var hint BreachHint
|
||||
copy(hint[:], k)
|
||||
hints = append(hints, hint)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return hints, nil
|
||||
}
|
||||
|
||||
// putHintForSession inserts a record into the update index for a given
|
||||
// (session, hint) pair. The hints are coalesced under a bucket for the target
|
||||
// session id, and used to perform efficient removal of updates. If the index
|
||||
// for the session has not been initialized, this method returns
|
||||
// ErrNoSessionHintIndex.
|
||||
func putHintForSession(updateIndex *bbolt.Bucket, id *SessionID,
|
||||
hint BreachHint) error {
|
||||
|
||||
sessionHints := updateIndex.Bucket(id[:])
|
||||
if sessionHints == nil {
|
||||
return ErrNoSessionHintIndex
|
||||
}
|
||||
|
||||
return sessionHints.Put(hint[:], []byte{})
|
||||
}
|
||||
|
||||
// putLookoutEpoch stores the given lookout tip block epoch in provided bucket.
|
||||
func putLookoutEpoch(bkt *bbolt.Bucket, epoch *chainntnfs.BlockEpoch) error {
|
||||
epochBytes := make([]byte, 36)
|
||||
copy(epochBytes, epoch.Hash[:])
|
||||
byteOrder.PutUint32(epochBytes[32:], uint32(epoch.Height))
|
||||
|
||||
return bkt.Put(lookoutTipKey, epochBytes)
|
||||
}
|
||||
|
||||
// getLookoutEpoch retrieves the lookout tip block epoch from the given bucket.
|
||||
// A nil epoch is returned if no update exists.
|
||||
func getLookoutEpoch(bkt *bbolt.Bucket) *chainntnfs.BlockEpoch {
|
||||
epochBytes := bkt.Get(lookoutTipKey)
|
||||
if len(epochBytes) != 36 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var hash chainhash.Hash
|
||||
copy(hash[:], epochBytes[:32])
|
||||
height := byteOrder.Uint32(epochBytes[32:])
|
||||
|
||||
return &chainntnfs.BlockEpoch{
|
||||
Hash: &hash,
|
||||
Height: int32(height),
|
||||
}
|
||||
}
|
||||
|
||||
// errBucketNotEmpty is a helper error returned when testing whether a bucket is
|
||||
// empty or not.
|
||||
var errBucketNotEmpty = errors.New("bucket not empty")
|
||||
|
||||
// isBucketEmpty returns errBucketNotEmpty if the bucket is not empty.
|
||||
func isBucketEmpty(bkt *bbolt.Bucket) error {
|
||||
return bkt.ForEach(func(_, _ []byte) error {
|
||||
return errBucketNotEmpty
|
||||
})
|
||||
}
|
730
watchtower/wtdb/tower_db_test.go
Normal file
730
watchtower/wtdb/tower_db_test.go
Normal file
@ -0,0 +1,730 @@
|
||||
package wtdb_test
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||
"github.com/lightningnetwork/lnd/watchtower"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtmock"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
||||
)
|
||||
|
||||
// dbInit is a closure used to initialize a watchtower.DB instance and its
|
||||
// cleanup function.
|
||||
type dbInit func(*testing.T) (watchtower.DB, func())
|
||||
|
||||
// towerDBHarness holds the resources required to execute the tower db tests.
|
||||
type towerDBHarness struct {
|
||||
t *testing.T
|
||||
db watchtower.DB
|
||||
}
|
||||
|
||||
// newTowerDBHarness initializes a fresh test harness for testing watchtower.DB
|
||||
// implementations.
|
||||
func newTowerDBHarness(t *testing.T, init dbInit) (*towerDBHarness, func()) {
|
||||
db, cleanup := init(t)
|
||||
|
||||
h := &towerDBHarness{
|
||||
t: t,
|
||||
db: db,
|
||||
}
|
||||
|
||||
return h, cleanup
|
||||
}
|
||||
|
||||
// insertSession attempts to isnert the passed session and asserts that the
|
||||
// error returned matches expErr.
|
||||
func (h *towerDBHarness) insertSession(s *wtdb.SessionInfo, expErr error) {
|
||||
h.t.Helper()
|
||||
|
||||
err := h.db.InsertSessionInfo(s)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected insert session error: %v, got : %v",
|
||||
expErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
// getSession retrieves the session identified by id, asserting that the call
|
||||
// returns expErr. If successful, the found session is returned.
|
||||
func (h *towerDBHarness) getSession(id *wtdb.SessionID,
|
||||
expErr error) *wtdb.SessionInfo {
|
||||
|
||||
h.t.Helper()
|
||||
|
||||
session, err := h.db.GetSessionInfo(id)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected get session error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
// insertUpdate attempts to insert the passed state update and asserts that the
|
||||
// error returned matches expErr. If successful, the session's last applied
|
||||
// value is returned.
|
||||
func (h *towerDBHarness) insertUpdate(s *wtdb.SessionStateUpdate,
|
||||
expErr error) uint16 {
|
||||
|
||||
h.t.Helper()
|
||||
|
||||
lastApplied, err := h.db.InsertStateUpdate(s)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected insert update error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
|
||||
return lastApplied
|
||||
}
|
||||
|
||||
// deleteSession attempts to delete the session identified by id and asserts
|
||||
// that the error returned from DeleteSession matches the expected error.
|
||||
func (h *towerDBHarness) deleteSession(id wtdb.SessionID, expErr error) {
|
||||
h.t.Helper()
|
||||
|
||||
err := h.db.DeleteSession(id)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("expected deletion error: %v, got: %v",
|
||||
expErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
// queryMatches queries that database for the passed breach hint, returning all
|
||||
// matches found.
|
||||
func (h *towerDBHarness) queryMatches(hint wtdb.BreachHint) []wtdb.Match {
|
||||
h.t.Helper()
|
||||
|
||||
matches, err := h.db.QueryMatches([]wtdb.BreachHint{hint})
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to query matches: %v", err)
|
||||
}
|
||||
|
||||
return matches
|
||||
}
|
||||
|
||||
// hasUpdate queries the database for the passed breach hint, asserting that
|
||||
// only one match is present and that the hints indeed match. If successful, the
|
||||
// match is returned.
|
||||
func (h *towerDBHarness) hasUpdate(hint wtdb.BreachHint) wtdb.Match {
|
||||
h.t.Helper()
|
||||
|
||||
matches := h.queryMatches(hint)
|
||||
if len(matches) != 1 {
|
||||
h.t.Fatalf("expected 1 match, found: %d", len(matches))
|
||||
}
|
||||
|
||||
match := matches[0]
|
||||
if match.Hint != hint {
|
||||
h.t.Fatalf("expected hint: %x, got: %x", hint, match.Hint)
|
||||
}
|
||||
|
||||
return match
|
||||
}
|
||||
|
||||
// testInsertSession asserts that a session can only be inserted if a session
|
||||
// with the same session id does not already exist.
|
||||
func testInsertSession(h *towerDBHarness) {
|
||||
var id wtdb.SessionID
|
||||
h.getSession(&id, wtdb.ErrSessionNotFound)
|
||||
|
||||
session := &wtdb.SessionInfo{
|
||||
ID: id,
|
||||
Policy: wtpolicy.Policy{
|
||||
MaxUpdates: 100,
|
||||
},
|
||||
RewardAddress: []byte{0x01, 0x02, 0x03},
|
||||
}
|
||||
|
||||
h.insertSession(session, nil)
|
||||
|
||||
session2 := h.getSession(&id, nil)
|
||||
|
||||
if !reflect.DeepEqual(session, session2) {
|
||||
h.t.Fatalf("expected session: %v, got %v",
|
||||
session, session2)
|
||||
}
|
||||
|
||||
h.insertSession(session, nil)
|
||||
|
||||
// Insert a state update to fully commit the session parameters.
|
||||
update := &wtdb.SessionStateUpdate{
|
||||
ID: id,
|
||||
SeqNum: 1,
|
||||
}
|
||||
h.insertUpdate(update, nil)
|
||||
|
||||
// Trying to insert a new session under the same ID should fail.
|
||||
h.insertSession(session, wtdb.ErrSessionAlreadyExists)
|
||||
}
|
||||
|
||||
// testMultipleMatches asserts that if multiple sessions insert state updates
|
||||
// with the same breach hint that all will be returned from QueryMatches.
|
||||
func testMultipleMatches(h *towerDBHarness) {
|
||||
const numUpdates = 3
|
||||
|
||||
// Create a new session and send updates with all the same hint.
|
||||
var hint wtdb.BreachHint
|
||||
for i := 0; i < numUpdates; i++ {
|
||||
id := *id(i)
|
||||
session := &wtdb.SessionInfo{
|
||||
ID: id,
|
||||
Policy: wtpolicy.Policy{
|
||||
MaxUpdates: 3,
|
||||
},
|
||||
RewardAddress: []byte{},
|
||||
}
|
||||
h.insertSession(session, nil)
|
||||
|
||||
update := &wtdb.SessionStateUpdate{
|
||||
ID: id,
|
||||
SeqNum: 1,
|
||||
Hint: hint, // Use same hint to cause multiple matches
|
||||
}
|
||||
h.insertUpdate(update, nil)
|
||||
}
|
||||
|
||||
// Query the db for matches on the chosen hint.
|
||||
matches := h.queryMatches(hint)
|
||||
if len(matches) != numUpdates {
|
||||
h.t.Fatalf("num updates mismatch, want: %d, got: %d",
|
||||
numUpdates, len(matches))
|
||||
}
|
||||
|
||||
// Assert that the hints are what we asked for, and compute the set of
|
||||
// sessions returned.
|
||||
sessions := make(map[wtdb.SessionID]struct{})
|
||||
for _, match := range matches {
|
||||
if match.Hint != hint {
|
||||
h.t.Fatalf("hint mismatch, want: %v, got: %v",
|
||||
hint, match.Hint)
|
||||
}
|
||||
sessions[match.ID] = struct{}{}
|
||||
}
|
||||
|
||||
// Assert that the sessions returned match the session ids of the
|
||||
// sessions we initially created.
|
||||
for i := 0; i < numUpdates; i++ {
|
||||
if _, ok := sessions[*id(i)]; !ok {
|
||||
h.t.Fatalf("match for session %v not found", *id(i))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// testLookoutTip asserts that the database properly stores and returns the
|
||||
// lookout tip block epochs. It also asserts that the epoch returned is nil when
|
||||
// no tip has ever been set.
|
||||
func testLookoutTip(h *towerDBHarness) {
|
||||
// Retrieve lookout tip on fresh db.
|
||||
epoch, err := h.db.GetLookoutTip()
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to fetch lookout tip: %v", err)
|
||||
}
|
||||
|
||||
// Assert that the epoch is nil.
|
||||
if epoch != nil {
|
||||
h.t.Fatalf("lookout tip should not be set, found: %v", epoch)
|
||||
}
|
||||
|
||||
// Create a closure that inserts an epoch, retrieves it, and asserts
|
||||
// that the returned epoch matches what was inserted.
|
||||
setAndCheck := func(i int) {
|
||||
expEpoch := epochFromInt(1)
|
||||
err = h.db.SetLookoutTip(expEpoch)
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to set lookout tip: %v", err)
|
||||
}
|
||||
|
||||
epoch, err = h.db.GetLookoutTip()
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to fetch lookout tip: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(epoch, expEpoch) {
|
||||
h.t.Fatalf("lookout tip mismatch, want: %v, got: %v",
|
||||
expEpoch, epoch)
|
||||
}
|
||||
}
|
||||
|
||||
// Set and assert the lookout tip.
|
||||
for i := 0; i < 5; i++ {
|
||||
setAndCheck(i)
|
||||
}
|
||||
}
|
||||
|
||||
// testDeleteSession asserts the behavior of a tower database when deleting
|
||||
// session data. The test asserts that the only proper the target session is
|
||||
// remmoved, and that only updates for a particular session are pruned.
|
||||
func testDeleteSession(h *towerDBHarness) {
|
||||
// First, create a session so that the database is not empty.
|
||||
id0 := id(0)
|
||||
session0 := &wtdb.SessionInfo{
|
||||
ID: *id0,
|
||||
Policy: wtpolicy.Policy{
|
||||
MaxUpdates: 3,
|
||||
},
|
||||
RewardAddress: []byte{},
|
||||
}
|
||||
h.insertSession(session0, nil)
|
||||
|
||||
// Now, attempt to delete a session which does not exist, that is also
|
||||
// different from the first one created.
|
||||
id1 := id(1)
|
||||
h.deleteSession(*id1, wtdb.ErrSessionNotFound)
|
||||
|
||||
// The first session should still be present.
|
||||
h.getSession(id0, nil)
|
||||
|
||||
// Now insert a second session under a different id.
|
||||
session1 := &wtdb.SessionInfo{
|
||||
ID: *id1,
|
||||
Policy: wtpolicy.Policy{
|
||||
MaxUpdates: 3,
|
||||
},
|
||||
RewardAddress: []byte{},
|
||||
}
|
||||
h.insertSession(session1, nil)
|
||||
|
||||
// Create and insert updates for both sessions that have the same hint.
|
||||
var hint wtdb.BreachHint
|
||||
update0 := &wtdb.SessionStateUpdate{
|
||||
ID: *id0,
|
||||
Hint: hint,
|
||||
SeqNum: 1,
|
||||
EncryptedBlob: []byte{},
|
||||
}
|
||||
update1 := &wtdb.SessionStateUpdate{
|
||||
ID: *id1,
|
||||
Hint: hint,
|
||||
SeqNum: 1,
|
||||
EncryptedBlob: []byte{},
|
||||
}
|
||||
|
||||
// Insert both updates should succeed.
|
||||
h.insertUpdate(update0, nil)
|
||||
h.insertUpdate(update1, nil)
|
||||
|
||||
// Remove the new session, which should succeed.
|
||||
h.deleteSession(*id1, nil)
|
||||
|
||||
// The first session should still be present.
|
||||
h.getSession(id0, nil)
|
||||
|
||||
// The second session should be removed.
|
||||
h.getSession(id1, wtdb.ErrSessionNotFound)
|
||||
|
||||
// Assert that only one update is still present.
|
||||
matches := h.queryMatches(hint)
|
||||
if len(matches) != 1 {
|
||||
h.t.Fatalf("expected one update, found: %d", len(matches))
|
||||
}
|
||||
|
||||
// Assert that the update belongs to the first session.
|
||||
if matches[0].ID != *id0 {
|
||||
h.t.Fatalf("expected match for %v, instead is for: %v",
|
||||
*id0, matches[0].ID)
|
||||
}
|
||||
|
||||
// Finally, remove the first session added.
|
||||
h.deleteSession(*id0, nil)
|
||||
|
||||
// The session should no longer be present.
|
||||
h.getSession(id0, wtdb.ErrSessionNotFound)
|
||||
|
||||
// No matches should exist for this hint.
|
||||
matches = h.queryMatches(hint)
|
||||
if len(matches) != 0 {
|
||||
h.t.Fatalf("expected zero updates, found: %d", len(matches))
|
||||
}
|
||||
}
|
||||
|
||||
type stateUpdateTest struct {
|
||||
session *wtdb.SessionInfo
|
||||
sessionErr error
|
||||
updates []*wtdb.SessionStateUpdate
|
||||
updateErrs []error
|
||||
}
|
||||
|
||||
func runStateUpdateTest(test stateUpdateTest) func(*towerDBHarness) {
|
||||
return func(h *towerDBHarness) {
|
||||
// We may need to modify the initial session as we process
|
||||
// updates to discern the expected state of the session. We'll
|
||||
// create a copy of the test session if necessary to prevent
|
||||
// mutations from impacting other tests.
|
||||
var expSession *wtdb.SessionInfo
|
||||
|
||||
// Create the session if the tests requests one.
|
||||
if test.session != nil {
|
||||
// Copy the initial session and insert it into the
|
||||
// database.
|
||||
ogSession := *test.session
|
||||
expErr := test.sessionErr
|
||||
h.insertSession(&ogSession, expErr)
|
||||
|
||||
if expErr != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Copy the initial state of the accepted session.
|
||||
expSession = &wtdb.SessionInfo{}
|
||||
*expSession = *test.session
|
||||
}
|
||||
|
||||
if len(test.updates) != len(test.updateErrs) {
|
||||
h.t.Fatalf("malformed test case, num updates " +
|
||||
"should match num errors")
|
||||
}
|
||||
|
||||
// Send any updates provided in the test.
|
||||
for i, update := range test.updates {
|
||||
expErr := test.updateErrs[i]
|
||||
h.insertUpdate(update, expErr)
|
||||
|
||||
if expErr != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Don't perform the following checks and modfications
|
||||
// if we don't have an expected session to compare
|
||||
// against.
|
||||
if expSession == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Update the session's last applied and client last
|
||||
// applied.
|
||||
expSession.LastApplied = update.SeqNum
|
||||
expSession.ClientLastApplied = update.LastApplied
|
||||
|
||||
match := h.hasUpdate(update.Hint)
|
||||
if !reflect.DeepEqual(match.SessionInfo, expSession) {
|
||||
h.t.Fatalf("expected session: %v, got: %v",
|
||||
expSession, match.SessionInfo)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var stateUpdateNoSession = stateUpdateTest{
|
||||
session: nil,
|
||||
updates: []*wtdb.SessionStateUpdate{
|
||||
{ID: *id(0), SeqNum: 1, LastApplied: 0},
|
||||
},
|
||||
updateErrs: []error{
|
||||
wtdb.ErrSessionNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
var stateUpdateExhaustSession = stateUpdateTest{
|
||||
session: &wtdb.SessionInfo{
|
||||
ID: *id(0),
|
||||
Policy: wtpolicy.Policy{
|
||||
MaxUpdates: 3,
|
||||
},
|
||||
RewardAddress: []byte{},
|
||||
},
|
||||
updates: []*wtdb.SessionStateUpdate{
|
||||
updateFromInt(id(0), 1, 0),
|
||||
updateFromInt(id(0), 2, 0),
|
||||
updateFromInt(id(0), 3, 0),
|
||||
updateFromInt(id(0), 4, 0),
|
||||
},
|
||||
updateErrs: []error{
|
||||
nil, nil, nil, wtdb.ErrSessionConsumed,
|
||||
},
|
||||
}
|
||||
|
||||
var stateUpdateSeqNumEqualLastApplied = stateUpdateTest{
|
||||
session: &wtdb.SessionInfo{
|
||||
ID: *id(0),
|
||||
Policy: wtpolicy.Policy{
|
||||
MaxUpdates: 3,
|
||||
},
|
||||
RewardAddress: []byte{},
|
||||
},
|
||||
updates: []*wtdb.SessionStateUpdate{
|
||||
updateFromInt(id(0), 1, 0),
|
||||
updateFromInt(id(0), 2, 1),
|
||||
updateFromInt(id(0), 3, 2),
|
||||
updateFromInt(id(0), 3, 3),
|
||||
},
|
||||
updateErrs: []error{
|
||||
nil, nil, nil, wtdb.ErrSeqNumAlreadyApplied,
|
||||
},
|
||||
}
|
||||
|
||||
var stateUpdateSeqNumLTLastApplied = stateUpdateTest{
|
||||
session: &wtdb.SessionInfo{
|
||||
ID: *id(0),
|
||||
Policy: wtpolicy.Policy{
|
||||
MaxUpdates: 3,
|
||||
},
|
||||
RewardAddress: []byte{},
|
||||
},
|
||||
updates: []*wtdb.SessionStateUpdate{
|
||||
updateFromInt(id(0), 1, 0),
|
||||
updateFromInt(id(0), 2, 1),
|
||||
updateFromInt(id(0), 1, 2),
|
||||
},
|
||||
updateErrs: []error{
|
||||
nil, nil, wtdb.ErrSeqNumAlreadyApplied,
|
||||
},
|
||||
}
|
||||
|
||||
var stateUpdateSeqNumZeroInvalid = stateUpdateTest{
|
||||
session: &wtdb.SessionInfo{
|
||||
ID: *id(0),
|
||||
Policy: wtpolicy.Policy{
|
||||
MaxUpdates: 3,
|
||||
},
|
||||
RewardAddress: []byte{},
|
||||
},
|
||||
updates: []*wtdb.SessionStateUpdate{
|
||||
updateFromInt(id(0), 0, 0),
|
||||
},
|
||||
updateErrs: []error{
|
||||
wtdb.ErrSeqNumAlreadyApplied,
|
||||
},
|
||||
}
|
||||
|
||||
var stateUpdateSkipSeqNum = stateUpdateTest{
|
||||
session: &wtdb.SessionInfo{
|
||||
ID: *id(0),
|
||||
Policy: wtpolicy.Policy{
|
||||
MaxUpdates: 3,
|
||||
},
|
||||
RewardAddress: []byte{},
|
||||
},
|
||||
updates: []*wtdb.SessionStateUpdate{
|
||||
updateFromInt(id(0), 2, 0),
|
||||
},
|
||||
updateErrs: []error{
|
||||
wtdb.ErrUpdateOutOfOrder,
|
||||
},
|
||||
}
|
||||
|
||||
var stateUpdateRevertSeqNum = stateUpdateTest{
|
||||
session: &wtdb.SessionInfo{
|
||||
ID: *id(0),
|
||||
Policy: wtpolicy.Policy{
|
||||
MaxUpdates: 3,
|
||||
},
|
||||
RewardAddress: []byte{},
|
||||
},
|
||||
updates: []*wtdb.SessionStateUpdate{
|
||||
updateFromInt(id(0), 1, 0),
|
||||
updateFromInt(id(0), 2, 0),
|
||||
updateFromInt(id(0), 1, 0),
|
||||
},
|
||||
updateErrs: []error{
|
||||
nil, nil, wtdb.ErrUpdateOutOfOrder,
|
||||
},
|
||||
}
|
||||
|
||||
var stateUpdateRevertLastApplied = stateUpdateTest{
|
||||
session: &wtdb.SessionInfo{
|
||||
ID: *id(0),
|
||||
Policy: wtpolicy.Policy{
|
||||
MaxUpdates: 3,
|
||||
},
|
||||
RewardAddress: []byte{},
|
||||
},
|
||||
updates: []*wtdb.SessionStateUpdate{
|
||||
updateFromInt(id(0), 1, 0),
|
||||
updateFromInt(id(0), 2, 1),
|
||||
updateFromInt(id(0), 3, 2),
|
||||
updateFromInt(id(0), 4, 1),
|
||||
},
|
||||
updateErrs: []error{
|
||||
nil, nil, nil, wtdb.ErrLastAppliedReversion,
|
||||
},
|
||||
}
|
||||
|
||||
func TestTowerDB(t *testing.T) {
|
||||
dbs := []struct {
|
||||
name string
|
||||
init dbInit
|
||||
}{
|
||||
{
|
||||
name: "fresh boltdb",
|
||||
init: func(t *testing.T) (watchtower.DB, func()) {
|
||||
path, err := ioutil.TempDir("", "towerdb")
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make temp dir: %v",
|
||||
err)
|
||||
}
|
||||
|
||||
db, err := wtdb.OpenTowerDB(path)
|
||||
if err != nil {
|
||||
os.RemoveAll(path)
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
db.Close()
|
||||
os.RemoveAll(path)
|
||||
}
|
||||
|
||||
return db, cleanup
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "reopened boltdb",
|
||||
init: func(t *testing.T) (watchtower.DB, func()) {
|
||||
path, err := ioutil.TempDir("", "towerdb")
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make temp dir: %v",
|
||||
err)
|
||||
}
|
||||
|
||||
db, err := wtdb.OpenTowerDB(path)
|
||||
if err != nil {
|
||||
os.RemoveAll(path)
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
db.Close()
|
||||
|
||||
// Open the db again, ensuring we test a
|
||||
// different path during open and that all
|
||||
// buckets remain initialized.
|
||||
db, err = wtdb.OpenTowerDB(path)
|
||||
if err != nil {
|
||||
os.RemoveAll(path)
|
||||
t.Fatalf("unable to open db: %v", err)
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
db.Close()
|
||||
os.RemoveAll(path)
|
||||
}
|
||||
|
||||
return db, cleanup
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "mock",
|
||||
init: func(t *testing.T) (watchtower.DB, func()) {
|
||||
return wtmock.NewTowerDB(), func() {}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
run func(*towerDBHarness)
|
||||
}{
|
||||
{
|
||||
name: "create session",
|
||||
run: testInsertSession,
|
||||
},
|
||||
{
|
||||
name: "delete session",
|
||||
run: testDeleteSession,
|
||||
},
|
||||
{
|
||||
name: "state update no session",
|
||||
run: runStateUpdateTest(stateUpdateNoSession),
|
||||
},
|
||||
{
|
||||
name: "state update exhaust session",
|
||||
run: runStateUpdateTest(stateUpdateExhaustSession),
|
||||
},
|
||||
{
|
||||
name: "state update seqnum equal last applied",
|
||||
run: runStateUpdateTest(
|
||||
stateUpdateSeqNumEqualLastApplied,
|
||||
),
|
||||
},
|
||||
{
|
||||
name: "state update seqnum less than last applied",
|
||||
run: runStateUpdateTest(
|
||||
stateUpdateSeqNumLTLastApplied,
|
||||
),
|
||||
},
|
||||
{
|
||||
name: "state update seqnum zero invalid",
|
||||
run: runStateUpdateTest(stateUpdateSeqNumZeroInvalid),
|
||||
},
|
||||
{
|
||||
name: "state update skip seqnum",
|
||||
run: runStateUpdateTest(stateUpdateSkipSeqNum),
|
||||
},
|
||||
{
|
||||
name: "state update revert seqnum",
|
||||
run: runStateUpdateTest(stateUpdateRevertSeqNum),
|
||||
},
|
||||
{
|
||||
name: "state update revert last applied",
|
||||
run: runStateUpdateTest(stateUpdateRevertLastApplied),
|
||||
},
|
||||
{
|
||||
name: "multiple breach matches",
|
||||
run: testMultipleMatches,
|
||||
},
|
||||
{
|
||||
name: "lookout tip",
|
||||
run: testLookoutTip,
|
||||
},
|
||||
}
|
||||
|
||||
for _, database := range dbs {
|
||||
db := database
|
||||
t.Run(db.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
h, cleanup := newTowerDBHarness(
|
||||
t, db.init,
|
||||
)
|
||||
defer cleanup()
|
||||
|
||||
test.run(h)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// id creates a session id from an integer.
|
||||
func id(i int) *wtdb.SessionID {
|
||||
var id wtdb.SessionID
|
||||
binary.BigEndian.PutUint32(id[:4], uint32(i))
|
||||
return &id
|
||||
}
|
||||
|
||||
// updateFromInt creates a unique update for a given (session, seqnum) pair. The
|
||||
// lastApplied argument can be used to construct updates simulating different
|
||||
// levels of synchronicity between client and db.
|
||||
func updateFromInt(id *wtdb.SessionID, i int,
|
||||
lastApplied uint16) *wtdb.SessionStateUpdate {
|
||||
|
||||
// Ensure the hint is unique.
|
||||
var hint wtdb.BreachHint
|
||||
copy(hint[:4], id[:4])
|
||||
binary.BigEndian.PutUint16(hint[4:6], uint16(i))
|
||||
|
||||
return &wtdb.SessionStateUpdate{
|
||||
ID: *id,
|
||||
Hint: hint,
|
||||
SeqNum: uint16(i),
|
||||
LastApplied: lastApplied,
|
||||
EncryptedBlob: []byte{byte(i)},
|
||||
}
|
||||
}
|
||||
|
||||
// epochFromInt creates a block epoch from an integer.
|
||||
func epochFromInt(i int) *chainntnfs.BlockEpoch {
|
||||
var hash chainhash.Hash
|
||||
binary.BigEndian.PutUint32(hash[:4], uint32(i))
|
||||
|
||||
return &chainntnfs.BlockEpoch{
|
||||
Hash: &hash,
|
||||
Height: int32(i),
|
||||
}
|
||||
}
|
84
watchtower/wtdb/version.go
Normal file
84
watchtower/wtdb/version.go
Normal file
@ -0,0 +1,84 @@
|
||||
package wtdb
|
||||
|
||||
import "github.com/coreos/bbolt"
|
||||
|
||||
// migration is a function which takes a prior outdated version of the database
|
||||
// instances and mutates the key/bucket structure to arrive at a more
|
||||
// up-to-date version of the database.
|
||||
type migration func(tx *bbolt.Tx) error
|
||||
|
||||
// version pairs a version number with the migration that would need to be
|
||||
// applied from the prior version to upgrade.
|
||||
type version struct {
|
||||
number uint32
|
||||
migration migration
|
||||
}
|
||||
|
||||
// dbVersions stores all versions and migrations of the database. This list will
|
||||
// be used when opening the database to determine if any migrations must be
|
||||
// applied.
|
||||
var dbVersions = []version{
|
||||
{
|
||||
// Initial version requires no migration.
|
||||
number: 0,
|
||||
migration: nil,
|
||||
},
|
||||
}
|
||||
|
||||
// getLatestDBVersion returns the last known database version.
|
||||
func getLatestDBVersion(versions []version) uint32 {
|
||||
return versions[len(versions)-1].number
|
||||
}
|
||||
|
||||
// getMigrations returns a slice of all updates with a greater number that
|
||||
// curVersion that need to be applied to sync up with the latest version.
|
||||
func getMigrations(versions []version, curVersion uint32) []version {
|
||||
var updates []version
|
||||
for _, v := range versions {
|
||||
if v.number > curVersion {
|
||||
updates = append(updates, v)
|
||||
}
|
||||
}
|
||||
|
||||
return updates
|
||||
}
|
||||
|
||||
// getDBVersion retrieves the current database version from the metadata bucket
|
||||
// using the dbVersionKey.
|
||||
func getDBVersion(tx *bbolt.Tx) (uint32, error) {
|
||||
metadata := tx.Bucket(metadataBkt)
|
||||
if metadata == nil {
|
||||
return 0, ErrUninitializedDB
|
||||
}
|
||||
|
||||
versionBytes := metadata.Get(dbVersionKey)
|
||||
if len(versionBytes) != 4 {
|
||||
return 0, ErrNoDBVersion
|
||||
}
|
||||
|
||||
return byteOrder.Uint32(versionBytes), nil
|
||||
}
|
||||
|
||||
// initDBVersion initializes the top-level metadata bucket and writes the passed
|
||||
// version number as the current version.
|
||||
func initDBVersion(tx *bbolt.Tx, version uint32) error {
|
||||
_, err := tx.CreateBucketIfNotExists(metadataBkt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return putDBVersion(tx, version)
|
||||
}
|
||||
|
||||
// putDBVersion stores the passed database version in the metadata bucket under
|
||||
// the dbVersionKey.
|
||||
func putDBVersion(tx *bbolt.Tx, version uint32) error {
|
||||
metadata := tx.Bucket(metadataBkt)
|
||||
if metadata == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
versionBytes := make([]byte, 4)
|
||||
byteOrder.PutUint32(versionBytes, version)
|
||||
return metadata.Put(dbVersionKey, versionBytes)
|
||||
}
|
162
watchtower/wtmock/tower_db.go
Normal file
162
watchtower/wtmock/tower_db.go
Normal file
@ -0,0 +1,162 @@
|
||||
package wtmock
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
)
|
||||
|
||||
// TowerDB is a mock, in-memory implementation of a watchtower.DB.
|
||||
type TowerDB struct {
|
||||
mu sync.Mutex
|
||||
lastEpoch *chainntnfs.BlockEpoch
|
||||
sessions map[wtdb.SessionID]*wtdb.SessionInfo
|
||||
blobs map[wtdb.BreachHint]map[wtdb.SessionID]*wtdb.SessionStateUpdate
|
||||
}
|
||||
|
||||
// NewTowerDB initializes a fresh mock TowerDB.
|
||||
func NewTowerDB() *TowerDB {
|
||||
return &TowerDB{
|
||||
sessions: make(map[wtdb.SessionID]*wtdb.SessionInfo),
|
||||
blobs: make(map[wtdb.BreachHint]map[wtdb.SessionID]*wtdb.SessionStateUpdate),
|
||||
}
|
||||
}
|
||||
|
||||
// InsertStateUpdate stores an update sent by the client after validating that
|
||||
// the update is well-formed in the context of other updates sent for the same
|
||||
// session. This include verifying that the sequence number is incremented
|
||||
// properly and the last applied values echoed by the client are sane.
|
||||
func (db *TowerDB) InsertStateUpdate(update *wtdb.SessionStateUpdate) (uint16, error) {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
info, ok := db.sessions[update.ID]
|
||||
if !ok {
|
||||
return 0, wtdb.ErrSessionNotFound
|
||||
}
|
||||
|
||||
err := info.AcceptUpdateSequence(update.SeqNum, update.LastApplied)
|
||||
if err != nil {
|
||||
return info.LastApplied, err
|
||||
}
|
||||
|
||||
sessionsToUpdates, ok := db.blobs[update.Hint]
|
||||
if !ok {
|
||||
sessionsToUpdates = make(map[wtdb.SessionID]*wtdb.SessionStateUpdate)
|
||||
db.blobs[update.Hint] = sessionsToUpdates
|
||||
}
|
||||
sessionsToUpdates[update.ID] = update
|
||||
|
||||
return info.LastApplied, nil
|
||||
}
|
||||
|
||||
// GetSessionInfo retrieves the session for the passed session id. An error is
|
||||
// returned if the session could not be found.
|
||||
func (db *TowerDB) GetSessionInfo(id *wtdb.SessionID) (*wtdb.SessionInfo, error) {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
if info, ok := db.sessions[*id]; ok {
|
||||
return info, nil
|
||||
}
|
||||
|
||||
return nil, wtdb.ErrSessionNotFound
|
||||
}
|
||||
|
||||
// InsertSessionInfo records a negotiated session in the tower database. An
|
||||
// error is returned if the session already exists.
|
||||
func (db *TowerDB) InsertSessionInfo(info *wtdb.SessionInfo) error {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
dbInfo, ok := db.sessions[info.ID]
|
||||
if ok && dbInfo.LastApplied > 0 {
|
||||
return wtdb.ErrSessionAlreadyExists
|
||||
}
|
||||
|
||||
db.sessions[info.ID] = info
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteSession removes all data associated with a particular session id from
|
||||
// the tower's database.
|
||||
func (db *TowerDB) DeleteSession(target wtdb.SessionID) error {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
// Fail if the session doesn't exit.
|
||||
if _, ok := db.sessions[target]; !ok {
|
||||
return wtdb.ErrSessionNotFound
|
||||
}
|
||||
|
||||
// Remove the target session.
|
||||
delete(db.sessions, target)
|
||||
|
||||
// Remove the state updates for any blobs stored under the target
|
||||
// session identifier.
|
||||
for hint, sessionUpdates := range db.blobs {
|
||||
delete(sessionUpdates, target)
|
||||
|
||||
// If this was the last state update, we can also remove the
|
||||
// hint that would map to an empty set.
|
||||
if len(sessionUpdates) == 0 {
|
||||
delete(db.blobs, hint)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueryMatches searches against all known state updates for any that match the
|
||||
// passed breachHints. More than one Match will be returned for a given hint if
|
||||
// they exist in the database.
|
||||
func (db *TowerDB) QueryMatches(
|
||||
breachHints []wtdb.BreachHint) ([]wtdb.Match, error) {
|
||||
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
var matches []wtdb.Match
|
||||
for _, hint := range breachHints {
|
||||
sessionsToUpdates, ok := db.blobs[hint]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
for id, update := range sessionsToUpdates {
|
||||
info, ok := db.sessions[id]
|
||||
if !ok {
|
||||
panic("session not found")
|
||||
}
|
||||
|
||||
match := wtdb.Match{
|
||||
ID: id,
|
||||
SeqNum: update.SeqNum,
|
||||
Hint: hint,
|
||||
EncryptedBlob: update.EncryptedBlob,
|
||||
SessionInfo: info,
|
||||
}
|
||||
matches = append(matches, match)
|
||||
}
|
||||
}
|
||||
|
||||
return matches, nil
|
||||
}
|
||||
|
||||
// SetLookoutTip stores the provided epoch as the latest lookout tip epoch in
|
||||
// the tower database.
|
||||
func (db *TowerDB) SetLookoutTip(epoch *chainntnfs.BlockEpoch) error {
|
||||
db.lastEpoch = epoch
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLookoutTip retrieves the current lookout tip block epoch from the tower
|
||||
// database.
|
||||
func (db *TowerDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
return db.lastEpoch, nil
|
||||
}
|
@ -102,7 +102,7 @@ func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID,
|
||||
// successful, the session will now be ready for use.
|
||||
err = s.cfg.DB.InsertSessionInfo(&info)
|
||||
if err != nil {
|
||||
log.Errorf("unable to create session for %s", id)
|
||||
log.Errorf("Unable to create session for %s: %v", id, err)
|
||||
return s.replyCreateSession(
|
||||
peer, id, wtwire.CodeTemporaryFailure, 0, nil,
|
||||
)
|
||||
|
@ -1,5 +1,3 @@
|
||||
// +build dev
|
||||
|
||||
package wtserver_test
|
||||
|
||||
import (
|
||||
@ -53,7 +51,7 @@ func initServer(t *testing.T, db wtserver.DB,
|
||||
t.Helper()
|
||||
|
||||
if db == nil {
|
||||
db = wtdb.NewMockDB()
|
||||
db = wtmock.NewTowerDB()
|
||||
}
|
||||
|
||||
s, err := wtserver.New(&wtserver.Config{
|
||||
@ -687,7 +685,7 @@ func testServerStateUpdates(t *testing.T, test stateUpdateTestCase) {
|
||||
// checking that the proper error is returned when the session doesn't exist and
|
||||
// that a successful deletion does not disrupt other sessions.
|
||||
func TestServerDeleteSession(t *testing.T) {
|
||||
db := wtdb.NewMockDB()
|
||||
db := wtmock.NewTowerDB()
|
||||
|
||||
localPub := randPubKey(t)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user