Merge pull request #2783 from cfromknecht/wtserver-db
watchtower/wtdb: add bbolt-backed tower database
This commit is contained in:
commit
0393793733
@ -51,6 +51,12 @@ type UnknownElementType struct {
|
|||||||
element interface{}
|
element interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewUnknownElementType creates a new UnknownElementType error from the passed
|
||||||
|
// method name and element.
|
||||||
|
func NewUnknownElementType(method string, el interface{}) UnknownElementType {
|
||||||
|
return UnknownElementType{method: method, element: el}
|
||||||
|
}
|
||||||
|
|
||||||
// Error returns the name of the method that encountered the error, as well as
|
// Error returns the name of the method that encountered the error, as well as
|
||||||
// the type that was unsupported.
|
// the type that was unsupported.
|
||||||
func (e UnknownElementType) Error() string {
|
func (e UnknownElementType) Error() string {
|
||||||
|
@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/lightningnetwork/lnd/watchtower/blob"
|
"github.com/lightningnetwork/lnd/watchtower/blob"
|
||||||
"github.com/lightningnetwork/lnd/watchtower/lookout"
|
"github.com/lightningnetwork/lnd/watchtower/lookout"
|
||||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower/wtmock"
|
||||||
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -66,7 +67,7 @@ func makeAddrSlice(size int) []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestLookoutBreachMatching(t *testing.T) {
|
func TestLookoutBreachMatching(t *testing.T) {
|
||||||
db := wtdb.NewMockDB()
|
db := wtmock.NewTowerDB()
|
||||||
|
|
||||||
// Initialize an mock backend to feed the lookout blocks.
|
// Initialize an mock backend to feed the lookout blocks.
|
||||||
backend := lookout.NewMockBackend()
|
backend := lookout.NewMockBackend()
|
||||||
|
@ -369,7 +369,7 @@ type testHarness struct {
|
|||||||
clientDB *wtmock.ClientDB
|
clientDB *wtmock.ClientDB
|
||||||
clientCfg *wtclient.Config
|
clientCfg *wtclient.Config
|
||||||
client wtclient.Client
|
client wtclient.Client
|
||||||
serverDB *wtdb.MockDB
|
serverDB *wtmock.TowerDB
|
||||||
serverCfg *wtserver.Config
|
serverCfg *wtserver.Config
|
||||||
server *wtserver.Server
|
server *wtserver.Server
|
||||||
net *mockNet
|
net *mockNet
|
||||||
@ -406,7 +406,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const timeout = 200 * time.Millisecond
|
const timeout = 200 * time.Millisecond
|
||||||
serverDB := wtdb.NewMockDB()
|
serverDB := wtmock.NewTowerDB()
|
||||||
|
|
||||||
serverCfg := &wtserver.Config{
|
serverCfg := &wtserver.Config{
|
||||||
DB: serverDB,
|
DB: serverDB,
|
||||||
|
143
watchtower/wtdb/codec.go
Normal file
143
watchtower/wtdb/codec.go
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
package wtdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/channeldb"
|
||||||
|
"github.com/lightningnetwork/lnd/lnwallet"
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower/blob"
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UnknownElementType is an alias for channeldb.UnknownElementType.
|
||||||
|
type UnknownElementType = channeldb.UnknownElementType
|
||||||
|
|
||||||
|
// ReadElement deserializes a single element from the provided io.Reader.
|
||||||
|
func ReadElement(r io.Reader, element interface{}) error {
|
||||||
|
err := channeldb.ReadElement(r, element)
|
||||||
|
switch {
|
||||||
|
|
||||||
|
// Known to channeldb codec.
|
||||||
|
case err == nil:
|
||||||
|
return nil
|
||||||
|
|
||||||
|
// Fail if error is not UnknownElementType.
|
||||||
|
case err != nil:
|
||||||
|
if _, ok := err.(UnknownElementType); !ok {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process any wtdb-specific extensions to the codec.
|
||||||
|
switch e := element.(type) {
|
||||||
|
|
||||||
|
case *SessionID:
|
||||||
|
if _, err := io.ReadFull(r, e[:]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
case *BreachHint:
|
||||||
|
if _, err := io.ReadFull(r, e[:]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
case *wtpolicy.Policy:
|
||||||
|
var (
|
||||||
|
blobType uint16
|
||||||
|
sweepFeeRate uint64
|
||||||
|
)
|
||||||
|
err := channeldb.ReadElements(r,
|
||||||
|
&blobType,
|
||||||
|
&e.MaxUpdates,
|
||||||
|
&e.RewardBase,
|
||||||
|
&e.RewardRate,
|
||||||
|
&sweepFeeRate,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
e.BlobType = blob.Type(blobType)
|
||||||
|
e.SweepFeeRate = lnwallet.SatPerKWeight(sweepFeeRate)
|
||||||
|
|
||||||
|
// Type is still unknown to wtdb extensions, fail.
|
||||||
|
default:
|
||||||
|
return channeldb.NewUnknownElementType(
|
||||||
|
"ReadElement", element,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteElement serializes a single element into the provided io.Writer.
|
||||||
|
func WriteElement(w io.Writer, element interface{}) error {
|
||||||
|
err := channeldb.WriteElement(w, element)
|
||||||
|
switch {
|
||||||
|
|
||||||
|
// Known to channeldb codec.
|
||||||
|
case err == nil:
|
||||||
|
return nil
|
||||||
|
|
||||||
|
// Fail if error is not UnknownElementType.
|
||||||
|
case err != nil:
|
||||||
|
if _, ok := err.(UnknownElementType); !ok {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process any wtdb-specific extensions to the codec.
|
||||||
|
switch e := element.(type) {
|
||||||
|
|
||||||
|
case SessionID:
|
||||||
|
if _, err := w.Write(e[:]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
case BreachHint:
|
||||||
|
if _, err := w.Write(e[:]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
case wtpolicy.Policy:
|
||||||
|
return channeldb.WriteElements(w,
|
||||||
|
uint16(e.BlobType),
|
||||||
|
e.MaxUpdates,
|
||||||
|
e.RewardBase,
|
||||||
|
e.RewardRate,
|
||||||
|
uint64(e.SweepFeeRate),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Type is still unknown to wtdb extensions, fail.
|
||||||
|
default:
|
||||||
|
return channeldb.NewUnknownElementType(
|
||||||
|
"WriteElement", element,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteElements serializes a variadic list of elements into the given
|
||||||
|
// io.Writer.
|
||||||
|
func WriteElements(w io.Writer, elements ...interface{}) error {
|
||||||
|
for _, element := range elements {
|
||||||
|
if err := WriteElement(w, element); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadElements deserializes the provided io.Reader into a variadic list of
|
||||||
|
// target elements.
|
||||||
|
func ReadElements(r io.Reader, elements ...interface{}) error {
|
||||||
|
for _, element := range elements {
|
||||||
|
if err := ReadElement(r, element); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
86
watchtower/wtdb/codec_test.go
Normal file
86
watchtower/wtdb/codec_test.go
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
package wtdb_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"testing/quick"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||||
|
)
|
||||||
|
|
||||||
|
// dbObject is abstract object support encoding and decoding.
|
||||||
|
type dbObject interface {
|
||||||
|
Encode(io.Writer) error
|
||||||
|
Decode(io.Reader) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCodec serializes and deserializes wtdb objects in order to test that that
|
||||||
|
// the codec understands all of the required field types. The test also asserts
|
||||||
|
// that decoding an object into another results in an equivalent object.
|
||||||
|
func TestCodec(t *testing.T) {
|
||||||
|
mainScenario := func(obj dbObject) bool {
|
||||||
|
// Ensure encoding the object succeeds.
|
||||||
|
var b bytes.Buffer
|
||||||
|
err := obj.Encode(&b)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to encode: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var obj2 dbObject
|
||||||
|
switch obj.(type) {
|
||||||
|
case *wtdb.SessionInfo:
|
||||||
|
obj2 = &wtdb.SessionInfo{}
|
||||||
|
case *wtdb.SessionStateUpdate:
|
||||||
|
obj2 = &wtdb.SessionStateUpdate{}
|
||||||
|
default:
|
||||||
|
t.Fatalf("unknown type: %T", obj)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure decoding the object succeeds.
|
||||||
|
err = obj2.Decode(bytes.NewReader(b.Bytes()))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to decode: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert the original and decoded object match.
|
||||||
|
if !reflect.DeepEqual(obj, obj2) {
|
||||||
|
t.Fatalf("encode/decode mismatch, want: %v, "+
|
||||||
|
"got: %v", obj, obj2)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
scenario interface{}
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "SessionInfo",
|
||||||
|
scenario: func(obj wtdb.SessionInfo) bool {
|
||||||
|
return mainScenario(&obj)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SessionStateUpdate",
|
||||||
|
scenario: func(obj wtdb.SessionStateUpdate) bool {
|
||||||
|
return mainScenario(&obj)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
if err := quick.Check(test.scenario, nil); err != nil {
|
||||||
|
t.Fatalf("fuzz checks for msg=%s failed: %v",
|
||||||
|
test.name, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
45
watchtower/wtdb/log.go
Normal file
45
watchtower/wtdb/log.go
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
package wtdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/btcsuite/btclog"
|
||||||
|
"github.com/lightningnetwork/lnd/build"
|
||||||
|
)
|
||||||
|
|
||||||
|
// log is a logger that is initialized with no output filters. This
|
||||||
|
// means the package will not perform any logging by default until the caller
|
||||||
|
// requests it.
|
||||||
|
var log btclog.Logger
|
||||||
|
|
||||||
|
// The default amount of logging is none.
|
||||||
|
func init() {
|
||||||
|
UseLogger(build.NewSubLogger("WTDB", nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableLog disables all library log output. Logging output is disabled
|
||||||
|
// by default until UseLogger is called.
|
||||||
|
func DisableLog() {
|
||||||
|
UseLogger(btclog.Disabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UseLogger uses a specified Logger to output package logging info.
|
||||||
|
// This should be used in preference to SetLogWriter if the caller is also
|
||||||
|
// using btclog.
|
||||||
|
func UseLogger(logger btclog.Logger) {
|
||||||
|
log = logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// logClosure is used to provide a closure over expensive logging operations so
|
||||||
|
// don't have to be performed when the logging level doesn't warrant it.
|
||||||
|
type logClosure func() string
|
||||||
|
|
||||||
|
// String invokes the underlying function and returns the result.
|
||||||
|
func (c logClosure) String() string {
|
||||||
|
return c()
|
||||||
|
}
|
||||||
|
|
||||||
|
// newLogClosure returns a new closure over a function that returns a string
|
||||||
|
// which itself provides a Stringer interface so that it can be used with the
|
||||||
|
// logging system.
|
||||||
|
func newLogClosure(c func() string) logClosure {
|
||||||
|
return logClosure(c)
|
||||||
|
}
|
@ -1,142 +0,0 @@
|
|||||||
// +build dev
|
|
||||||
|
|
||||||
package wtdb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/lightningnetwork/lnd/chainntnfs"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MockDB struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
lastEpoch *chainntnfs.BlockEpoch
|
|
||||||
sessions map[SessionID]*SessionInfo
|
|
||||||
blobs map[BreachHint]map[SessionID]*SessionStateUpdate
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewMockDB() *MockDB {
|
|
||||||
return &MockDB{
|
|
||||||
sessions: make(map[SessionID]*SessionInfo),
|
|
||||||
blobs: make(map[BreachHint]map[SessionID]*SessionStateUpdate),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *MockDB) InsertStateUpdate(update *SessionStateUpdate) (uint16, error) {
|
|
||||||
db.mu.Lock()
|
|
||||||
defer db.mu.Unlock()
|
|
||||||
|
|
||||||
info, ok := db.sessions[update.ID]
|
|
||||||
if !ok {
|
|
||||||
return 0, ErrSessionNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
err := info.AcceptUpdateSequence(update.SeqNum, update.LastApplied)
|
|
||||||
if err != nil {
|
|
||||||
return info.LastApplied, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sessionsToUpdates, ok := db.blobs[update.Hint]
|
|
||||||
if !ok {
|
|
||||||
sessionsToUpdates = make(map[SessionID]*SessionStateUpdate)
|
|
||||||
db.blobs[update.Hint] = sessionsToUpdates
|
|
||||||
}
|
|
||||||
sessionsToUpdates[update.ID] = update
|
|
||||||
|
|
||||||
return info.LastApplied, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *MockDB) GetSessionInfo(id *SessionID) (*SessionInfo, error) {
|
|
||||||
db.mu.Lock()
|
|
||||||
defer db.mu.Unlock()
|
|
||||||
|
|
||||||
if info, ok := db.sessions[*id]; ok {
|
|
||||||
return info, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, ErrSessionNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *MockDB) InsertSessionInfo(info *SessionInfo) error {
|
|
||||||
db.mu.Lock()
|
|
||||||
defer db.mu.Unlock()
|
|
||||||
|
|
||||||
dbInfo, ok := db.sessions[info.ID]
|
|
||||||
if ok && dbInfo.LastApplied > 0 {
|
|
||||||
return ErrSessionAlreadyExists
|
|
||||||
}
|
|
||||||
|
|
||||||
db.sessions[info.ID] = info
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *MockDB) DeleteSession(target SessionID) error {
|
|
||||||
db.mu.Lock()
|
|
||||||
defer db.mu.Unlock()
|
|
||||||
|
|
||||||
// Fail if the session doesn't exit.
|
|
||||||
if _, ok := db.sessions[target]; !ok {
|
|
||||||
return ErrSessionNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove the target session.
|
|
||||||
delete(db.sessions, target)
|
|
||||||
|
|
||||||
// Remove the state updates for any blobs stored under the target
|
|
||||||
// session identifier.
|
|
||||||
for hint, sessionUpdates := range db.blobs {
|
|
||||||
delete(sessionUpdates, target)
|
|
||||||
|
|
||||||
//If this was the last state update, we can also remove the hint
|
|
||||||
//that would map to an empty set.
|
|
||||||
if len(sessionUpdates) == 0 {
|
|
||||||
delete(db.blobs, hint)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *MockDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) {
|
|
||||||
db.mu.Lock()
|
|
||||||
defer db.mu.Unlock()
|
|
||||||
|
|
||||||
return db.lastEpoch, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *MockDB) QueryMatches(breachHints []BreachHint) ([]Match, error) {
|
|
||||||
db.mu.Lock()
|
|
||||||
defer db.mu.Unlock()
|
|
||||||
|
|
||||||
var matches []Match
|
|
||||||
for _, hint := range breachHints {
|
|
||||||
sessionsToUpdates, ok := db.blobs[hint]
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for id, update := range sessionsToUpdates {
|
|
||||||
info, ok := db.sessions[id]
|
|
||||||
if !ok {
|
|
||||||
panic("session not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
match := Match{
|
|
||||||
ID: id,
|
|
||||||
SeqNum: update.SeqNum,
|
|
||||||
Hint: hint,
|
|
||||||
EncryptedBlob: update.EncryptedBlob,
|
|
||||||
SessionInfo: info,
|
|
||||||
}
|
|
||||||
matches = append(matches, match)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return matches, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *MockDB) SetLookoutTip(epoch *chainntnfs.BlockEpoch) error {
|
|
||||||
db.lastEpoch = epoch
|
|
||||||
return nil
|
|
||||||
}
|
|
@ -2,6 +2,7 @@ package wtdb
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"io"
|
||||||
|
|
||||||
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
||||||
)
|
)
|
||||||
@ -59,6 +60,28 @@ type SessionInfo struct {
|
|||||||
// TODO(conner): store client metrics, DOS score, etc
|
// TODO(conner): store client metrics, DOS score, etc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Encode serializes the session info to the given io.Writer.
|
||||||
|
func (s *SessionInfo) Encode(w io.Writer) error {
|
||||||
|
return WriteElements(w,
|
||||||
|
s.ID,
|
||||||
|
s.Policy,
|
||||||
|
s.LastApplied,
|
||||||
|
s.ClientLastApplied,
|
||||||
|
s.RewardAddress,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode deserializes the session infor from the given io.Reader.
|
||||||
|
func (s *SessionInfo) Decode(r io.Reader) error {
|
||||||
|
return ReadElements(r,
|
||||||
|
&s.ID,
|
||||||
|
&s.Policy,
|
||||||
|
&s.LastApplied,
|
||||||
|
&s.ClientLastApplied,
|
||||||
|
&s.RewardAddress,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// AcceptUpdateSequence validates that a state update's sequence number and last
|
// AcceptUpdateSequence validates that a state update's sequence number and last
|
||||||
// applied are valid given our past history with the client. These checks ensure
|
// applied are valid given our past history with the client. These checks ensure
|
||||||
// that clients are properly in sync and following the update protocol properly.
|
// that clients are properly in sync and following the update protocol properly.
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
package wtdb
|
package wtdb
|
||||||
|
|
||||||
|
import "io"
|
||||||
|
|
||||||
// SessionStateUpdate holds a state update sent by a client along with its
|
// SessionStateUpdate holds a state update sent by a client along with its
|
||||||
// SessionID.
|
// SessionID.
|
||||||
type SessionStateUpdate struct {
|
type SessionStateUpdate struct {
|
||||||
@ -21,3 +23,25 @@ type SessionStateUpdate struct {
|
|||||||
// hint is braodcast.
|
// hint is braodcast.
|
||||||
EncryptedBlob []byte
|
EncryptedBlob []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Encode serializes the state update into the provided io.Writer.
|
||||||
|
func (u *SessionStateUpdate) Encode(w io.Writer) error {
|
||||||
|
return WriteElements(w,
|
||||||
|
u.ID,
|
||||||
|
u.SeqNum,
|
||||||
|
u.LastApplied,
|
||||||
|
u.Hint,
|
||||||
|
u.EncryptedBlob,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode deserializes the target state update from the provided io.Reader.
|
||||||
|
func (u *SessionStateUpdate) Decode(r io.Reader) error {
|
||||||
|
return ReadElements(r,
|
||||||
|
&u.ID,
|
||||||
|
&u.SeqNum,
|
||||||
|
&u.LastApplied,
|
||||||
|
&u.Hint,
|
||||||
|
&u.EncryptedBlob,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
733
watchtower/wtdb/tower_db.go
Normal file
733
watchtower/wtdb/tower_db.go
Normal file
@ -0,0 +1,733 @@
|
|||||||
|
package wtdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||||
|
"github.com/coreos/bbolt"
|
||||||
|
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||||
|
"github.com/lightningnetwork/lnd/channeldb"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// dbName is the filename of tower database.
|
||||||
|
dbName = "watchtower.db"
|
||||||
|
|
||||||
|
// dbFilePermission requests read+write access to the db file.
|
||||||
|
dbFilePermission = 0600
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// sessionsBkt is a bucket containing all negotiated client sessions.
|
||||||
|
// session id -> session
|
||||||
|
sessionsBkt = []byte("sessions-bucket")
|
||||||
|
|
||||||
|
// updatesBkt is a bucket containing all state updates sent by clients.
|
||||||
|
// The updates are further bucketed by session id to prevent clients
|
||||||
|
// from overwrite each other.
|
||||||
|
// hint => session id -> update
|
||||||
|
updatesBkt = []byte("updates-bucket")
|
||||||
|
|
||||||
|
// updateIndexBkt is a bucket that indexes all state updates by their
|
||||||
|
// overarching session id. This allows for efficient lookup of updates
|
||||||
|
// by their session id, which is currently used to aide deletion
|
||||||
|
// performance.
|
||||||
|
// session id => hint1 -> []byte{}
|
||||||
|
// => hint2 -> []byte{}
|
||||||
|
updateIndexBkt = []byte("update-index-bucket")
|
||||||
|
|
||||||
|
// lookoutTipBkt is a bucket containing the last block epoch processed
|
||||||
|
// by the lookout subsystem. It has one key, lookoutTipKey.
|
||||||
|
// lookoutTipKey -> block epoch
|
||||||
|
lookoutTipBkt = []byte("lookout-tip-bucket")
|
||||||
|
|
||||||
|
// lookoutTipKey is a static key used to retrieve lookout tip's block
|
||||||
|
// epoch from the lookoutTipBkt.
|
||||||
|
lookoutTipKey = []byte("lookout-tip")
|
||||||
|
|
||||||
|
// metadataBkt stores all the meta information concerning the state of
|
||||||
|
// the database.
|
||||||
|
metadataBkt = []byte("metadata-bucket")
|
||||||
|
|
||||||
|
// dbVersionKey is a static key used to retrieve the database version
|
||||||
|
// number from the metadataBkt.
|
||||||
|
dbVersionKey = []byte("version")
|
||||||
|
|
||||||
|
// ErrUninitializedDB signals that top-level buckets for the database
|
||||||
|
// have not been initialized.
|
||||||
|
ErrUninitializedDB = errors.New("tower db not initialized")
|
||||||
|
|
||||||
|
// ErrNoDBVersion signals that the database contains no version info.
|
||||||
|
ErrNoDBVersion = errors.New("tower db has no version")
|
||||||
|
|
||||||
|
// ErrNoSessionHintIndex signals that an active session does not have an
|
||||||
|
// initialized index for tracking its own state updates.
|
||||||
|
ErrNoSessionHintIndex = errors.New("session hint index missing")
|
||||||
|
|
||||||
|
byteOrder = binary.BigEndian
|
||||||
|
)
|
||||||
|
|
||||||
|
// TowerDB is single database providing a persistent storage engine for the
|
||||||
|
// wtserver and lookout subsystems.
|
||||||
|
type TowerDB struct {
|
||||||
|
db *bbolt.DB
|
||||||
|
dbPath string
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenTowerDB opens the tower database given the path to the database's
|
||||||
|
// directory. If no such database exists, this method will initialize a fresh
|
||||||
|
// one using the latest version number and bucket structure. If a database
|
||||||
|
// exists but has a lower version number than the current version, any necessary
|
||||||
|
// migrations will be applied before returning. Any attempt to open a database
|
||||||
|
// with a version number higher that the latest version will fail to prevent
|
||||||
|
// accidental reversion.
|
||||||
|
func OpenTowerDB(dbPath string) (*TowerDB, error) {
|
||||||
|
path := filepath.Join(dbPath, dbName)
|
||||||
|
|
||||||
|
// If the database file doesn't exist, this indicates we much initialize
|
||||||
|
// a fresh database with the latest version.
|
||||||
|
firstInit := !fileExists(path)
|
||||||
|
if firstInit {
|
||||||
|
// Ensure all parent directories are initialized.
|
||||||
|
err := os.MkdirAll(dbPath, 0700)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bdb, err := bbolt.Open(path, dbFilePermission, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the file existed previously, we'll now check to see that the
|
||||||
|
// metadata bucket is properly initialized. It could be the case that
|
||||||
|
// the database was created, but we failed to actually populate any
|
||||||
|
// metadata. If the metadata bucket does not actually exist, we'll
|
||||||
|
// set firstInit to true so that we can treat is initialize the bucket.
|
||||||
|
if !firstInit {
|
||||||
|
var metadataExists bool
|
||||||
|
err = bdb.View(func(tx *bbolt.Tx) error {
|
||||||
|
metadataExists = tx.Bucket(metadataBkt) != nil
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !metadataExists {
|
||||||
|
firstInit = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
towerDB := &TowerDB{
|
||||||
|
db: bdb,
|
||||||
|
dbPath: dbPath,
|
||||||
|
}
|
||||||
|
|
||||||
|
if firstInit {
|
||||||
|
// If the database has not yet been created, we'll initialize
|
||||||
|
// the database version with the latest known version.
|
||||||
|
err = towerDB.db.Update(func(tx *bbolt.Tx) error {
|
||||||
|
return initDBVersion(tx, getLatestDBVersion(dbVersions))
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
bdb.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Otherwise, ensure that any migrations are applied to ensure
|
||||||
|
// the data is in the format expected by the latest version.
|
||||||
|
err = towerDB.syncVersions(dbVersions)
|
||||||
|
if err != nil {
|
||||||
|
bdb.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now that the database version fully consistent with our latest known
|
||||||
|
// version, ensure that all top-level buckets known to this version are
|
||||||
|
// initialized. This allows us to assume their presence throughout all
|
||||||
|
// operations. If an known top-level bucket is expected to exist but is
|
||||||
|
// missing, this will trigger a ErrUninitializedDB error.
|
||||||
|
err = towerDB.db.Update(initTowerDBBuckets)
|
||||||
|
if err != nil {
|
||||||
|
bdb.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return towerDB, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// fileExists returns true if the file exists, and false otherwise.
|
||||||
|
func fileExists(path string) bool {
|
||||||
|
if _, err := os.Stat(path); err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// initTowerDBBuckets creates all top-level buckets required to handle database
|
||||||
|
// operations required by the latest version.
|
||||||
|
func initTowerDBBuckets(tx *bbolt.Tx) error {
|
||||||
|
buckets := [][]byte{
|
||||||
|
sessionsBkt,
|
||||||
|
updateIndexBkt,
|
||||||
|
updatesBkt,
|
||||||
|
lookoutTipBkt,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, bucket := range buckets {
|
||||||
|
_, err := tx.CreateBucketIfNotExists(bucket)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// syncVersions ensures the database version is consistent with the highest
|
||||||
|
// known database version, applying any migrations that have not been made. If
|
||||||
|
// the highest known version number is lower than the database's version, this
|
||||||
|
// method will fail to prevent accidental reversions.
|
||||||
|
func (t *TowerDB) syncVersions(versions []version) error {
|
||||||
|
curVersion, err := t.Version()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
latestVersion := getLatestDBVersion(versions)
|
||||||
|
switch {
|
||||||
|
|
||||||
|
// Current version is higher than any known version, fail to prevent
|
||||||
|
// reversion.
|
||||||
|
case curVersion > latestVersion:
|
||||||
|
return channeldb.ErrDBReversion
|
||||||
|
|
||||||
|
// Current version matches highest known version, nothing to do.
|
||||||
|
case curVersion == latestVersion:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, apply any migrations in order to bring the database
|
||||||
|
// version up to the highest known version.
|
||||||
|
updates := getMigrations(versions, curVersion)
|
||||||
|
return t.db.Update(func(tx *bbolt.Tx) error {
|
||||||
|
for _, update := range updates {
|
||||||
|
if update.migration == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Applying migration #%d", update.number)
|
||||||
|
|
||||||
|
err := update.migration(tx)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Unable to apply migration #%d: %v",
|
||||||
|
err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return putDBVersion(tx, latestVersion)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Version returns the database's current version number.
|
||||||
|
func (t *TowerDB) Version() (uint32, error) {
|
||||||
|
var version uint32
|
||||||
|
err := t.db.View(func(tx *bbolt.Tx) error {
|
||||||
|
var err error
|
||||||
|
version, err = getDBVersion(tx)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return version, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the underlying database.
|
||||||
|
func (t *TowerDB) Close() error {
|
||||||
|
return t.db.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSessionInfo retrieves the session for the passed session id. An error is
|
||||||
|
// returned if the session could not be found.
|
||||||
|
func (t *TowerDB) GetSessionInfo(id *SessionID) (*SessionInfo, error) {
|
||||||
|
var session *SessionInfo
|
||||||
|
err := t.db.View(func(tx *bbolt.Tx) error {
|
||||||
|
sessions := tx.Bucket(sessionsBkt)
|
||||||
|
if sessions == nil {
|
||||||
|
return ErrUninitializedDB
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
session, err = getSession(sessions, id[:])
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InsertSessionInfo records a negotiated session in the tower database. An
|
||||||
|
// error is returned if the session already exists.
|
||||||
|
func (t *TowerDB) InsertSessionInfo(session *SessionInfo) error {
|
||||||
|
return t.db.Update(func(tx *bbolt.Tx) error {
|
||||||
|
sessions := tx.Bucket(sessionsBkt)
|
||||||
|
if sessions == nil {
|
||||||
|
return ErrUninitializedDB
|
||||||
|
}
|
||||||
|
|
||||||
|
updateIndex := tx.Bucket(updateIndexBkt)
|
||||||
|
if updateIndex == nil {
|
||||||
|
return ErrUninitializedDB
|
||||||
|
}
|
||||||
|
|
||||||
|
dbSession, err := getSession(sessions, session.ID[:])
|
||||||
|
switch {
|
||||||
|
case err == ErrSessionNotFound:
|
||||||
|
// proceed.
|
||||||
|
|
||||||
|
case err != nil:
|
||||||
|
return err
|
||||||
|
|
||||||
|
case dbSession.LastApplied > 0:
|
||||||
|
return ErrSessionAlreadyExists
|
||||||
|
}
|
||||||
|
|
||||||
|
err = putSession(sessions, session)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize the session-hint index which will be used to track
|
||||||
|
// all updates added for this session. Upon deletion, we will
|
||||||
|
// consult the index to determine exactly which updates should
|
||||||
|
// be deleted without needing to iterate over the entire
|
||||||
|
// database.
|
||||||
|
return touchSessionHintBkt(updateIndex, &session.ID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// InsertStateUpdate stores an update sent by the client after validating that
|
||||||
|
// the update is well-formed in the context of other updates sent for the same
|
||||||
|
// session. This include verifying that the sequence number is incremented
|
||||||
|
// properly and the last applied values echoed by the client are sane.
|
||||||
|
func (t *TowerDB) InsertStateUpdate(update *SessionStateUpdate) (uint16, error) {
|
||||||
|
var lastApplied uint16
|
||||||
|
err := t.db.Update(func(tx *bbolt.Tx) error {
|
||||||
|
sessions := tx.Bucket(sessionsBkt)
|
||||||
|
if sessions == nil {
|
||||||
|
return ErrUninitializedDB
|
||||||
|
}
|
||||||
|
|
||||||
|
updates := tx.Bucket(updatesBkt)
|
||||||
|
if updates == nil {
|
||||||
|
return ErrUninitializedDB
|
||||||
|
}
|
||||||
|
|
||||||
|
updateIndex := tx.Bucket(updateIndexBkt)
|
||||||
|
if updateIndex == nil {
|
||||||
|
return ErrUninitializedDB
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch the session corresponding to the update's session id.
|
||||||
|
// This will be used to validate that the update's sequence
|
||||||
|
// number and last applied values are sane.
|
||||||
|
session, err := getSession(sessions, update.ID[:])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate the update against the current state of the session.
|
||||||
|
err = session.AcceptUpdateSequence(
|
||||||
|
update.SeqNum, update.LastApplied,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validation succeeded, therefore the update is committed and
|
||||||
|
// the session's last applied value is equal to the update's
|
||||||
|
// sequence number.
|
||||||
|
lastApplied = session.LastApplied
|
||||||
|
|
||||||
|
// Store the updated session to persist the updated last applied
|
||||||
|
// values.
|
||||||
|
err = putSession(sessions, session)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create or load the hint bucket for this state update's hint
|
||||||
|
// and write the given update.
|
||||||
|
hints, err := updates.CreateBucketIfNotExists(update.Hint[:])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
err = update.Encode(&b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = hints.Put(update.ID[:], b.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, create an entry in the update index to track this
|
||||||
|
// hint under its session id. This will allow us to delete the
|
||||||
|
// entries efficiently if the session is ever removed.
|
||||||
|
return putHintForSession(updateIndex, &update.ID, update.Hint)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return lastApplied, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteSession removes all data associated with a particular session id from
|
||||||
|
// the tower's database.
|
||||||
|
func (t *TowerDB) DeleteSession(target SessionID) error {
|
||||||
|
return t.db.Update(func(tx *bbolt.Tx) error {
|
||||||
|
sessions := tx.Bucket(sessionsBkt)
|
||||||
|
if sessions == nil {
|
||||||
|
return ErrUninitializedDB
|
||||||
|
}
|
||||||
|
|
||||||
|
updates := tx.Bucket(updatesBkt)
|
||||||
|
if updates == nil {
|
||||||
|
return ErrUninitializedDB
|
||||||
|
}
|
||||||
|
|
||||||
|
updateIndex := tx.Bucket(updateIndexBkt)
|
||||||
|
if updateIndex == nil {
|
||||||
|
return ErrUninitializedDB
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fail if the session doesn't exit.
|
||||||
|
_, err := getSession(sessions, target[:])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the target session.
|
||||||
|
err = sessions.Delete(target[:])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next, check the update index for any hints that were added
|
||||||
|
// under this session.
|
||||||
|
hints, err := getHintsForSession(updateIndex, &target)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, hint := range hints {
|
||||||
|
// Remove the state updates for any blobs stored under
|
||||||
|
// the target session identifier.
|
||||||
|
updatesForHint := updates.Bucket(hint[:])
|
||||||
|
if updatesForHint == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
update := updatesForHint.Get(target[:])
|
||||||
|
if update == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err := updatesForHint.Delete(target[:])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this was the last state update, we can also remove
|
||||||
|
// the hint that would map to an empty set.
|
||||||
|
err = isBucketEmpty(updatesForHint)
|
||||||
|
switch {
|
||||||
|
|
||||||
|
// Other updates exist for this hint, keep the bucket.
|
||||||
|
case err == errBucketNotEmpty:
|
||||||
|
continue
|
||||||
|
|
||||||
|
// Unexpected error.
|
||||||
|
case err != nil:
|
||||||
|
return err
|
||||||
|
|
||||||
|
// No more updates for this hint, prune hint bucket.
|
||||||
|
default:
|
||||||
|
err = updates.DeleteBucket(hint[:])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, remove this session from the update index, which
|
||||||
|
// also removes any of the indexed hints beneath it.
|
||||||
|
return removeSessionHintBkt(updateIndex, &target)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryMatches searches against all known state updates for any that match the
|
||||||
|
// passed breachHints. More than one Match will be returned for a given hint if
|
||||||
|
// they exist in the database.
|
||||||
|
func (t *TowerDB) QueryMatches(breachHints []BreachHint) ([]Match, error) {
|
||||||
|
var matches []Match
|
||||||
|
err := t.db.View(func(tx *bbolt.Tx) error {
|
||||||
|
sessions := tx.Bucket(sessionsBkt)
|
||||||
|
if sessions == nil {
|
||||||
|
return ErrUninitializedDB
|
||||||
|
}
|
||||||
|
|
||||||
|
updates := tx.Bucket(updatesBkt)
|
||||||
|
if updates == nil {
|
||||||
|
return ErrUninitializedDB
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iterate through the target breach hints, appending any
|
||||||
|
// matching updates to the set of matches.
|
||||||
|
for _, hint := range breachHints {
|
||||||
|
// If a bucket does not exist for this hint, no matches
|
||||||
|
// are known.
|
||||||
|
updatesForHint := updates.Bucket(hint[:])
|
||||||
|
if updatesForHint == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, iterate through all (session id, update)
|
||||||
|
// pairs, creating a Match for each.
|
||||||
|
err := updatesForHint.ForEach(func(k, v []byte) error {
|
||||||
|
// Load the session via the session id for this
|
||||||
|
// update. The session info contains further
|
||||||
|
// instructions for how to process the state
|
||||||
|
// update.
|
||||||
|
session, err := getSession(sessions, k)
|
||||||
|
switch {
|
||||||
|
case err == ErrSessionNotFound:
|
||||||
|
log.Warnf("Missing session=%x for "+
|
||||||
|
"matched state update hint=%x",
|
||||||
|
k, hint)
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case err != nil:
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode the state update containing the
|
||||||
|
// encrypted blob.
|
||||||
|
update := &SessionStateUpdate{}
|
||||||
|
err = update.Decode(bytes.NewReader(v))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var id SessionID
|
||||||
|
copy(id[:], k)
|
||||||
|
|
||||||
|
// Construct the final match using the found
|
||||||
|
// update and its session info.
|
||||||
|
match := Match{
|
||||||
|
ID: id,
|
||||||
|
SeqNum: update.SeqNum,
|
||||||
|
Hint: hint,
|
||||||
|
EncryptedBlob: update.EncryptedBlob,
|
||||||
|
SessionInfo: session,
|
||||||
|
}
|
||||||
|
|
||||||
|
matches = append(matches, match)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return matches, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLookoutTip stores the provided epoch as the latest lookout tip epoch in
|
||||||
|
// the tower database.
|
||||||
|
func (t *TowerDB) SetLookoutTip(epoch *chainntnfs.BlockEpoch) error {
|
||||||
|
return t.db.Update(func(tx *bbolt.Tx) error {
|
||||||
|
lookoutTip := tx.Bucket(lookoutTipBkt)
|
||||||
|
if lookoutTip == nil {
|
||||||
|
return ErrUninitializedDB
|
||||||
|
}
|
||||||
|
|
||||||
|
return putLookoutEpoch(lookoutTip, epoch)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLookoutTip retrieves the current lookout tip block epoch from the tower
|
||||||
|
// database.
|
||||||
|
func (t *TowerDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) {
|
||||||
|
var epoch *chainntnfs.BlockEpoch
|
||||||
|
err := t.db.View(func(tx *bbolt.Tx) error {
|
||||||
|
lookoutTip := tx.Bucket(lookoutTipBkt)
|
||||||
|
if lookoutTip == nil {
|
||||||
|
return ErrUninitializedDB
|
||||||
|
}
|
||||||
|
|
||||||
|
epoch = getLookoutEpoch(lookoutTip)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return epoch, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSession retrieves the session info from the sessions bucket identified by
|
||||||
|
// its session id. An error is returned if the session is not found or a
|
||||||
|
// deserialization error occurs.
|
||||||
|
func getSession(sessions *bbolt.Bucket, id []byte) (*SessionInfo, error) {
|
||||||
|
sessionBytes := sessions.Get(id)
|
||||||
|
if sessionBytes == nil {
|
||||||
|
return nil, ErrSessionNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
var session SessionInfo
|
||||||
|
err := session.Decode(bytes.NewReader(sessionBytes))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// putSession stores the session info in the sessions bucket identified by its
|
||||||
|
// session id. An error is returned if a serialization error occurs.
|
||||||
|
func putSession(sessions *bbolt.Bucket, session *SessionInfo) error {
|
||||||
|
var b bytes.Buffer
|
||||||
|
err := session.Encode(&b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return sessions.Put(session.ID[:], b.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
// touchSessionHintBkt initializes the session-hint bucket for a particular
|
||||||
|
// session id. This ensures that future calls to getHintsForSession or
|
||||||
|
// putHintForSession can rely on the bucket already being created, and fail if
|
||||||
|
// index has not been initialized as this points to improper usage.
|
||||||
|
func touchSessionHintBkt(updateIndex *bbolt.Bucket, id *SessionID) error {
|
||||||
|
_, err := updateIndex.CreateBucketIfNotExists(id[:])
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeSessionHintBkt prunes the session-hint bucket for the given session id
|
||||||
|
// and all of the hints contained inside. This should be used to clean up the
|
||||||
|
// index upon session deletion.
|
||||||
|
func removeSessionHintBkt(updateIndex *bbolt.Bucket, id *SessionID) error {
|
||||||
|
return updateIndex.DeleteBucket(id[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// getHintsForSession returns all known hints belonging to the given session id.
|
||||||
|
// If the index for the session has not been initialized, this method returns
|
||||||
|
// ErrNoSessionHintIndex.
|
||||||
|
func getHintsForSession(updateIndex *bbolt.Bucket,
|
||||||
|
id *SessionID) ([]BreachHint, error) {
|
||||||
|
|
||||||
|
sessionHints := updateIndex.Bucket(id[:])
|
||||||
|
if sessionHints == nil {
|
||||||
|
return nil, ErrNoSessionHintIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
var hints []BreachHint
|
||||||
|
err := sessionHints.ForEach(func(k, _ []byte) error {
|
||||||
|
if len(k) != BreachHintSize {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var hint BreachHint
|
||||||
|
copy(hint[:], k)
|
||||||
|
hints = append(hints, hint)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return hints, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// putHintForSession inserts a record into the update index for a given
|
||||||
|
// (session, hint) pair. The hints are coalesced under a bucket for the target
|
||||||
|
// session id, and used to perform efficient removal of updates. If the index
|
||||||
|
// for the session has not been initialized, this method returns
|
||||||
|
// ErrNoSessionHintIndex.
|
||||||
|
func putHintForSession(updateIndex *bbolt.Bucket, id *SessionID,
|
||||||
|
hint BreachHint) error {
|
||||||
|
|
||||||
|
sessionHints := updateIndex.Bucket(id[:])
|
||||||
|
if sessionHints == nil {
|
||||||
|
return ErrNoSessionHintIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
return sessionHints.Put(hint[:], []byte{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// putLookoutEpoch stores the given lookout tip block epoch in provided bucket.
|
||||||
|
func putLookoutEpoch(bkt *bbolt.Bucket, epoch *chainntnfs.BlockEpoch) error {
|
||||||
|
epochBytes := make([]byte, 36)
|
||||||
|
copy(epochBytes, epoch.Hash[:])
|
||||||
|
byteOrder.PutUint32(epochBytes[32:], uint32(epoch.Height))
|
||||||
|
|
||||||
|
return bkt.Put(lookoutTipKey, epochBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getLookoutEpoch retrieves the lookout tip block epoch from the given bucket.
|
||||||
|
// A nil epoch is returned if no update exists.
|
||||||
|
func getLookoutEpoch(bkt *bbolt.Bucket) *chainntnfs.BlockEpoch {
|
||||||
|
epochBytes := bkt.Get(lookoutTipKey)
|
||||||
|
if len(epochBytes) != 36 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var hash chainhash.Hash
|
||||||
|
copy(hash[:], epochBytes[:32])
|
||||||
|
height := byteOrder.Uint32(epochBytes[32:])
|
||||||
|
|
||||||
|
return &chainntnfs.BlockEpoch{
|
||||||
|
Hash: &hash,
|
||||||
|
Height: int32(height),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// errBucketNotEmpty is a helper error returned when testing whether a bucket is
|
||||||
|
// empty or not.
|
||||||
|
var errBucketNotEmpty = errors.New("bucket not empty")
|
||||||
|
|
||||||
|
// isBucketEmpty returns errBucketNotEmpty if the bucket is not empty.
|
||||||
|
func isBucketEmpty(bkt *bbolt.Bucket) error {
|
||||||
|
return bkt.ForEach(func(_, _ []byte) error {
|
||||||
|
return errBucketNotEmpty
|
||||||
|
})
|
||||||
|
}
|
730
watchtower/wtdb/tower_db_test.go
Normal file
730
watchtower/wtdb/tower_db_test.go
Normal file
@ -0,0 +1,730 @@
|
|||||||
|
package wtdb_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||||
|
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower"
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower/wtmock"
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// dbInit is a closure used to initialize a watchtower.DB instance and its
|
||||||
|
// cleanup function.
|
||||||
|
type dbInit func(*testing.T) (watchtower.DB, func())
|
||||||
|
|
||||||
|
// towerDBHarness holds the resources required to execute the tower db tests.
|
||||||
|
type towerDBHarness struct {
|
||||||
|
t *testing.T
|
||||||
|
db watchtower.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// newTowerDBHarness initializes a fresh test harness for testing watchtower.DB
|
||||||
|
// implementations.
|
||||||
|
func newTowerDBHarness(t *testing.T, init dbInit) (*towerDBHarness, func()) {
|
||||||
|
db, cleanup := init(t)
|
||||||
|
|
||||||
|
h := &towerDBHarness{
|
||||||
|
t: t,
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
|
||||||
|
return h, cleanup
|
||||||
|
}
|
||||||
|
|
||||||
|
// insertSession attempts to isnert the passed session and asserts that the
|
||||||
|
// error returned matches expErr.
|
||||||
|
func (h *towerDBHarness) insertSession(s *wtdb.SessionInfo, expErr error) {
|
||||||
|
h.t.Helper()
|
||||||
|
|
||||||
|
err := h.db.InsertSessionInfo(s)
|
||||||
|
if err != expErr {
|
||||||
|
h.t.Fatalf("expected insert session error: %v, got : %v",
|
||||||
|
expErr, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSession retrieves the session identified by id, asserting that the call
|
||||||
|
// returns expErr. If successful, the found session is returned.
|
||||||
|
func (h *towerDBHarness) getSession(id *wtdb.SessionID,
|
||||||
|
expErr error) *wtdb.SessionInfo {
|
||||||
|
|
||||||
|
h.t.Helper()
|
||||||
|
|
||||||
|
session, err := h.db.GetSessionInfo(id)
|
||||||
|
if err != expErr {
|
||||||
|
h.t.Fatalf("expected get session error: %v, got: %v",
|
||||||
|
expErr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return session
|
||||||
|
}
|
||||||
|
|
||||||
|
// insertUpdate attempts to insert the passed state update and asserts that the
|
||||||
|
// error returned matches expErr. If successful, the session's last applied
|
||||||
|
// value is returned.
|
||||||
|
func (h *towerDBHarness) insertUpdate(s *wtdb.SessionStateUpdate,
|
||||||
|
expErr error) uint16 {
|
||||||
|
|
||||||
|
h.t.Helper()
|
||||||
|
|
||||||
|
lastApplied, err := h.db.InsertStateUpdate(s)
|
||||||
|
if err != expErr {
|
||||||
|
h.t.Fatalf("expected insert update error: %v, got: %v",
|
||||||
|
expErr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return lastApplied
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteSession attempts to delete the session identified by id and asserts
|
||||||
|
// that the error returned from DeleteSession matches the expected error.
|
||||||
|
func (h *towerDBHarness) deleteSession(id wtdb.SessionID, expErr error) {
|
||||||
|
h.t.Helper()
|
||||||
|
|
||||||
|
err := h.db.DeleteSession(id)
|
||||||
|
if err != expErr {
|
||||||
|
h.t.Fatalf("expected deletion error: %v, got: %v",
|
||||||
|
expErr, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// queryMatches queries that database for the passed breach hint, returning all
|
||||||
|
// matches found.
|
||||||
|
func (h *towerDBHarness) queryMatches(hint wtdb.BreachHint) []wtdb.Match {
|
||||||
|
h.t.Helper()
|
||||||
|
|
||||||
|
matches, err := h.db.QueryMatches([]wtdb.BreachHint{hint})
|
||||||
|
if err != nil {
|
||||||
|
h.t.Fatalf("unable to query matches: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return matches
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasUpdate queries the database for the passed breach hint, asserting that
|
||||||
|
// only one match is present and that the hints indeed match. If successful, the
|
||||||
|
// match is returned.
|
||||||
|
func (h *towerDBHarness) hasUpdate(hint wtdb.BreachHint) wtdb.Match {
|
||||||
|
h.t.Helper()
|
||||||
|
|
||||||
|
matches := h.queryMatches(hint)
|
||||||
|
if len(matches) != 1 {
|
||||||
|
h.t.Fatalf("expected 1 match, found: %d", len(matches))
|
||||||
|
}
|
||||||
|
|
||||||
|
match := matches[0]
|
||||||
|
if match.Hint != hint {
|
||||||
|
h.t.Fatalf("expected hint: %x, got: %x", hint, match.Hint)
|
||||||
|
}
|
||||||
|
|
||||||
|
return match
|
||||||
|
}
|
||||||
|
|
||||||
|
// testInsertSession asserts that a session can only be inserted if a session
|
||||||
|
// with the same session id does not already exist.
|
||||||
|
func testInsertSession(h *towerDBHarness) {
|
||||||
|
var id wtdb.SessionID
|
||||||
|
h.getSession(&id, wtdb.ErrSessionNotFound)
|
||||||
|
|
||||||
|
session := &wtdb.SessionInfo{
|
||||||
|
ID: id,
|
||||||
|
Policy: wtpolicy.Policy{
|
||||||
|
MaxUpdates: 100,
|
||||||
|
},
|
||||||
|
RewardAddress: []byte{0x01, 0x02, 0x03},
|
||||||
|
}
|
||||||
|
|
||||||
|
h.insertSession(session, nil)
|
||||||
|
|
||||||
|
session2 := h.getSession(&id, nil)
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(session, session2) {
|
||||||
|
h.t.Fatalf("expected session: %v, got %v",
|
||||||
|
session, session2)
|
||||||
|
}
|
||||||
|
|
||||||
|
h.insertSession(session, nil)
|
||||||
|
|
||||||
|
// Insert a state update to fully commit the session parameters.
|
||||||
|
update := &wtdb.SessionStateUpdate{
|
||||||
|
ID: id,
|
||||||
|
SeqNum: 1,
|
||||||
|
}
|
||||||
|
h.insertUpdate(update, nil)
|
||||||
|
|
||||||
|
// Trying to insert a new session under the same ID should fail.
|
||||||
|
h.insertSession(session, wtdb.ErrSessionAlreadyExists)
|
||||||
|
}
|
||||||
|
|
||||||
|
// testMultipleMatches asserts that if multiple sessions insert state updates
|
||||||
|
// with the same breach hint that all will be returned from QueryMatches.
|
||||||
|
func testMultipleMatches(h *towerDBHarness) {
|
||||||
|
const numUpdates = 3
|
||||||
|
|
||||||
|
// Create a new session and send updates with all the same hint.
|
||||||
|
var hint wtdb.BreachHint
|
||||||
|
for i := 0; i < numUpdates; i++ {
|
||||||
|
id := *id(i)
|
||||||
|
session := &wtdb.SessionInfo{
|
||||||
|
ID: id,
|
||||||
|
Policy: wtpolicy.Policy{
|
||||||
|
MaxUpdates: 3,
|
||||||
|
},
|
||||||
|
RewardAddress: []byte{},
|
||||||
|
}
|
||||||
|
h.insertSession(session, nil)
|
||||||
|
|
||||||
|
update := &wtdb.SessionStateUpdate{
|
||||||
|
ID: id,
|
||||||
|
SeqNum: 1,
|
||||||
|
Hint: hint, // Use same hint to cause multiple matches
|
||||||
|
}
|
||||||
|
h.insertUpdate(update, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query the db for matches on the chosen hint.
|
||||||
|
matches := h.queryMatches(hint)
|
||||||
|
if len(matches) != numUpdates {
|
||||||
|
h.t.Fatalf("num updates mismatch, want: %d, got: %d",
|
||||||
|
numUpdates, len(matches))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert that the hints are what we asked for, and compute the set of
|
||||||
|
// sessions returned.
|
||||||
|
sessions := make(map[wtdb.SessionID]struct{})
|
||||||
|
for _, match := range matches {
|
||||||
|
if match.Hint != hint {
|
||||||
|
h.t.Fatalf("hint mismatch, want: %v, got: %v",
|
||||||
|
hint, match.Hint)
|
||||||
|
}
|
||||||
|
sessions[match.ID] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert that the sessions returned match the session ids of the
|
||||||
|
// sessions we initially created.
|
||||||
|
for i := 0; i < numUpdates; i++ {
|
||||||
|
if _, ok := sessions[*id(i)]; !ok {
|
||||||
|
h.t.Fatalf("match for session %v not found", *id(i))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// testLookoutTip asserts that the database properly stores and returns the
|
||||||
|
// lookout tip block epochs. It also asserts that the epoch returned is nil when
|
||||||
|
// no tip has ever been set.
|
||||||
|
func testLookoutTip(h *towerDBHarness) {
|
||||||
|
// Retrieve lookout tip on fresh db.
|
||||||
|
epoch, err := h.db.GetLookoutTip()
|
||||||
|
if err != nil {
|
||||||
|
h.t.Fatalf("unable to fetch lookout tip: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert that the epoch is nil.
|
||||||
|
if epoch != nil {
|
||||||
|
h.t.Fatalf("lookout tip should not be set, found: %v", epoch)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a closure that inserts an epoch, retrieves it, and asserts
|
||||||
|
// that the returned epoch matches what was inserted.
|
||||||
|
setAndCheck := func(i int) {
|
||||||
|
expEpoch := epochFromInt(1)
|
||||||
|
err = h.db.SetLookoutTip(expEpoch)
|
||||||
|
if err != nil {
|
||||||
|
h.t.Fatalf("unable to set lookout tip: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
epoch, err = h.db.GetLookoutTip()
|
||||||
|
if err != nil {
|
||||||
|
h.t.Fatalf("unable to fetch lookout tip: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(epoch, expEpoch) {
|
||||||
|
h.t.Fatalf("lookout tip mismatch, want: %v, got: %v",
|
||||||
|
expEpoch, epoch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set and assert the lookout tip.
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
setAndCheck(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// testDeleteSession asserts the behavior of a tower database when deleting
|
||||||
|
// session data. The test asserts that the only proper the target session is
|
||||||
|
// remmoved, and that only updates for a particular session are pruned.
|
||||||
|
func testDeleteSession(h *towerDBHarness) {
|
||||||
|
// First, create a session so that the database is not empty.
|
||||||
|
id0 := id(0)
|
||||||
|
session0 := &wtdb.SessionInfo{
|
||||||
|
ID: *id0,
|
||||||
|
Policy: wtpolicy.Policy{
|
||||||
|
MaxUpdates: 3,
|
||||||
|
},
|
||||||
|
RewardAddress: []byte{},
|
||||||
|
}
|
||||||
|
h.insertSession(session0, nil)
|
||||||
|
|
||||||
|
// Now, attempt to delete a session which does not exist, that is also
|
||||||
|
// different from the first one created.
|
||||||
|
id1 := id(1)
|
||||||
|
h.deleteSession(*id1, wtdb.ErrSessionNotFound)
|
||||||
|
|
||||||
|
// The first session should still be present.
|
||||||
|
h.getSession(id0, nil)
|
||||||
|
|
||||||
|
// Now insert a second session under a different id.
|
||||||
|
session1 := &wtdb.SessionInfo{
|
||||||
|
ID: *id1,
|
||||||
|
Policy: wtpolicy.Policy{
|
||||||
|
MaxUpdates: 3,
|
||||||
|
},
|
||||||
|
RewardAddress: []byte{},
|
||||||
|
}
|
||||||
|
h.insertSession(session1, nil)
|
||||||
|
|
||||||
|
// Create and insert updates for both sessions that have the same hint.
|
||||||
|
var hint wtdb.BreachHint
|
||||||
|
update0 := &wtdb.SessionStateUpdate{
|
||||||
|
ID: *id0,
|
||||||
|
Hint: hint,
|
||||||
|
SeqNum: 1,
|
||||||
|
EncryptedBlob: []byte{},
|
||||||
|
}
|
||||||
|
update1 := &wtdb.SessionStateUpdate{
|
||||||
|
ID: *id1,
|
||||||
|
Hint: hint,
|
||||||
|
SeqNum: 1,
|
||||||
|
EncryptedBlob: []byte{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert both updates should succeed.
|
||||||
|
h.insertUpdate(update0, nil)
|
||||||
|
h.insertUpdate(update1, nil)
|
||||||
|
|
||||||
|
// Remove the new session, which should succeed.
|
||||||
|
h.deleteSession(*id1, nil)
|
||||||
|
|
||||||
|
// The first session should still be present.
|
||||||
|
h.getSession(id0, nil)
|
||||||
|
|
||||||
|
// The second session should be removed.
|
||||||
|
h.getSession(id1, wtdb.ErrSessionNotFound)
|
||||||
|
|
||||||
|
// Assert that only one update is still present.
|
||||||
|
matches := h.queryMatches(hint)
|
||||||
|
if len(matches) != 1 {
|
||||||
|
h.t.Fatalf("expected one update, found: %d", len(matches))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert that the update belongs to the first session.
|
||||||
|
if matches[0].ID != *id0 {
|
||||||
|
h.t.Fatalf("expected match for %v, instead is for: %v",
|
||||||
|
*id0, matches[0].ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, remove the first session added.
|
||||||
|
h.deleteSession(*id0, nil)
|
||||||
|
|
||||||
|
// The session should no longer be present.
|
||||||
|
h.getSession(id0, wtdb.ErrSessionNotFound)
|
||||||
|
|
||||||
|
// No matches should exist for this hint.
|
||||||
|
matches = h.queryMatches(hint)
|
||||||
|
if len(matches) != 0 {
|
||||||
|
h.t.Fatalf("expected zero updates, found: %d", len(matches))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type stateUpdateTest struct {
|
||||||
|
session *wtdb.SessionInfo
|
||||||
|
sessionErr error
|
||||||
|
updates []*wtdb.SessionStateUpdate
|
||||||
|
updateErrs []error
|
||||||
|
}
|
||||||
|
|
||||||
|
func runStateUpdateTest(test stateUpdateTest) func(*towerDBHarness) {
|
||||||
|
return func(h *towerDBHarness) {
|
||||||
|
// We may need to modify the initial session as we process
|
||||||
|
// updates to discern the expected state of the session. We'll
|
||||||
|
// create a copy of the test session if necessary to prevent
|
||||||
|
// mutations from impacting other tests.
|
||||||
|
var expSession *wtdb.SessionInfo
|
||||||
|
|
||||||
|
// Create the session if the tests requests one.
|
||||||
|
if test.session != nil {
|
||||||
|
// Copy the initial session and insert it into the
|
||||||
|
// database.
|
||||||
|
ogSession := *test.session
|
||||||
|
expErr := test.sessionErr
|
||||||
|
h.insertSession(&ogSession, expErr)
|
||||||
|
|
||||||
|
if expErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy the initial state of the accepted session.
|
||||||
|
expSession = &wtdb.SessionInfo{}
|
||||||
|
*expSession = *test.session
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(test.updates) != len(test.updateErrs) {
|
||||||
|
h.t.Fatalf("malformed test case, num updates " +
|
||||||
|
"should match num errors")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send any updates provided in the test.
|
||||||
|
for i, update := range test.updates {
|
||||||
|
expErr := test.updateErrs[i]
|
||||||
|
h.insertUpdate(update, expErr)
|
||||||
|
|
||||||
|
if expErr != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Don't perform the following checks and modfications
|
||||||
|
// if we don't have an expected session to compare
|
||||||
|
// against.
|
||||||
|
if expSession == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the session's last applied and client last
|
||||||
|
// applied.
|
||||||
|
expSession.LastApplied = update.SeqNum
|
||||||
|
expSession.ClientLastApplied = update.LastApplied
|
||||||
|
|
||||||
|
match := h.hasUpdate(update.Hint)
|
||||||
|
if !reflect.DeepEqual(match.SessionInfo, expSession) {
|
||||||
|
h.t.Fatalf("expected session: %v, got: %v",
|
||||||
|
expSession, match.SessionInfo)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var stateUpdateNoSession = stateUpdateTest{
|
||||||
|
session: nil,
|
||||||
|
updates: []*wtdb.SessionStateUpdate{
|
||||||
|
{ID: *id(0), SeqNum: 1, LastApplied: 0},
|
||||||
|
},
|
||||||
|
updateErrs: []error{
|
||||||
|
wtdb.ErrSessionNotFound,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var stateUpdateExhaustSession = stateUpdateTest{
|
||||||
|
session: &wtdb.SessionInfo{
|
||||||
|
ID: *id(0),
|
||||||
|
Policy: wtpolicy.Policy{
|
||||||
|
MaxUpdates: 3,
|
||||||
|
},
|
||||||
|
RewardAddress: []byte{},
|
||||||
|
},
|
||||||
|
updates: []*wtdb.SessionStateUpdate{
|
||||||
|
updateFromInt(id(0), 1, 0),
|
||||||
|
updateFromInt(id(0), 2, 0),
|
||||||
|
updateFromInt(id(0), 3, 0),
|
||||||
|
updateFromInt(id(0), 4, 0),
|
||||||
|
},
|
||||||
|
updateErrs: []error{
|
||||||
|
nil, nil, nil, wtdb.ErrSessionConsumed,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var stateUpdateSeqNumEqualLastApplied = stateUpdateTest{
|
||||||
|
session: &wtdb.SessionInfo{
|
||||||
|
ID: *id(0),
|
||||||
|
Policy: wtpolicy.Policy{
|
||||||
|
MaxUpdates: 3,
|
||||||
|
},
|
||||||
|
RewardAddress: []byte{},
|
||||||
|
},
|
||||||
|
updates: []*wtdb.SessionStateUpdate{
|
||||||
|
updateFromInt(id(0), 1, 0),
|
||||||
|
updateFromInt(id(0), 2, 1),
|
||||||
|
updateFromInt(id(0), 3, 2),
|
||||||
|
updateFromInt(id(0), 3, 3),
|
||||||
|
},
|
||||||
|
updateErrs: []error{
|
||||||
|
nil, nil, nil, wtdb.ErrSeqNumAlreadyApplied,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var stateUpdateSeqNumLTLastApplied = stateUpdateTest{
|
||||||
|
session: &wtdb.SessionInfo{
|
||||||
|
ID: *id(0),
|
||||||
|
Policy: wtpolicy.Policy{
|
||||||
|
MaxUpdates: 3,
|
||||||
|
},
|
||||||
|
RewardAddress: []byte{},
|
||||||
|
},
|
||||||
|
updates: []*wtdb.SessionStateUpdate{
|
||||||
|
updateFromInt(id(0), 1, 0),
|
||||||
|
updateFromInt(id(0), 2, 1),
|
||||||
|
updateFromInt(id(0), 1, 2),
|
||||||
|
},
|
||||||
|
updateErrs: []error{
|
||||||
|
nil, nil, wtdb.ErrSeqNumAlreadyApplied,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var stateUpdateSeqNumZeroInvalid = stateUpdateTest{
|
||||||
|
session: &wtdb.SessionInfo{
|
||||||
|
ID: *id(0),
|
||||||
|
Policy: wtpolicy.Policy{
|
||||||
|
MaxUpdates: 3,
|
||||||
|
},
|
||||||
|
RewardAddress: []byte{},
|
||||||
|
},
|
||||||
|
updates: []*wtdb.SessionStateUpdate{
|
||||||
|
updateFromInt(id(0), 0, 0),
|
||||||
|
},
|
||||||
|
updateErrs: []error{
|
||||||
|
wtdb.ErrSeqNumAlreadyApplied,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var stateUpdateSkipSeqNum = stateUpdateTest{
|
||||||
|
session: &wtdb.SessionInfo{
|
||||||
|
ID: *id(0),
|
||||||
|
Policy: wtpolicy.Policy{
|
||||||
|
MaxUpdates: 3,
|
||||||
|
},
|
||||||
|
RewardAddress: []byte{},
|
||||||
|
},
|
||||||
|
updates: []*wtdb.SessionStateUpdate{
|
||||||
|
updateFromInt(id(0), 2, 0),
|
||||||
|
},
|
||||||
|
updateErrs: []error{
|
||||||
|
wtdb.ErrUpdateOutOfOrder,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var stateUpdateRevertSeqNum = stateUpdateTest{
|
||||||
|
session: &wtdb.SessionInfo{
|
||||||
|
ID: *id(0),
|
||||||
|
Policy: wtpolicy.Policy{
|
||||||
|
MaxUpdates: 3,
|
||||||
|
},
|
||||||
|
RewardAddress: []byte{},
|
||||||
|
},
|
||||||
|
updates: []*wtdb.SessionStateUpdate{
|
||||||
|
updateFromInt(id(0), 1, 0),
|
||||||
|
updateFromInt(id(0), 2, 0),
|
||||||
|
updateFromInt(id(0), 1, 0),
|
||||||
|
},
|
||||||
|
updateErrs: []error{
|
||||||
|
nil, nil, wtdb.ErrUpdateOutOfOrder,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var stateUpdateRevertLastApplied = stateUpdateTest{
|
||||||
|
session: &wtdb.SessionInfo{
|
||||||
|
ID: *id(0),
|
||||||
|
Policy: wtpolicy.Policy{
|
||||||
|
MaxUpdates: 3,
|
||||||
|
},
|
||||||
|
RewardAddress: []byte{},
|
||||||
|
},
|
||||||
|
updates: []*wtdb.SessionStateUpdate{
|
||||||
|
updateFromInt(id(0), 1, 0),
|
||||||
|
updateFromInt(id(0), 2, 1),
|
||||||
|
updateFromInt(id(0), 3, 2),
|
||||||
|
updateFromInt(id(0), 4, 1),
|
||||||
|
},
|
||||||
|
updateErrs: []error{
|
||||||
|
nil, nil, nil, wtdb.ErrLastAppliedReversion,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTowerDB(t *testing.T) {
|
||||||
|
dbs := []struct {
|
||||||
|
name string
|
||||||
|
init dbInit
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "fresh boltdb",
|
||||||
|
init: func(t *testing.T) (watchtower.DB, func()) {
|
||||||
|
path, err := ioutil.TempDir("", "towerdb")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to make temp dir: %v",
|
||||||
|
err)
|
||||||
|
}
|
||||||
|
|
||||||
|
db, err := wtdb.OpenTowerDB(path)
|
||||||
|
if err != nil {
|
||||||
|
os.RemoveAll(path)
|
||||||
|
t.Fatalf("unable to open db: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup := func() {
|
||||||
|
db.Close()
|
||||||
|
os.RemoveAll(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
return db, cleanup
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "reopened boltdb",
|
||||||
|
init: func(t *testing.T) (watchtower.DB, func()) {
|
||||||
|
path, err := ioutil.TempDir("", "towerdb")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to make temp dir: %v",
|
||||||
|
err)
|
||||||
|
}
|
||||||
|
|
||||||
|
db, err := wtdb.OpenTowerDB(path)
|
||||||
|
if err != nil {
|
||||||
|
os.RemoveAll(path)
|
||||||
|
t.Fatalf("unable to open db: %v", err)
|
||||||
|
}
|
||||||
|
db.Close()
|
||||||
|
|
||||||
|
// Open the db again, ensuring we test a
|
||||||
|
// different path during open and that all
|
||||||
|
// buckets remain initialized.
|
||||||
|
db, err = wtdb.OpenTowerDB(path)
|
||||||
|
if err != nil {
|
||||||
|
os.RemoveAll(path)
|
||||||
|
t.Fatalf("unable to open db: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup := func() {
|
||||||
|
db.Close()
|
||||||
|
os.RemoveAll(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
return db, cleanup
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mock",
|
||||||
|
init: func(t *testing.T) (watchtower.DB, func()) {
|
||||||
|
return wtmock.NewTowerDB(), func() {}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
run func(*towerDBHarness)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "create session",
|
||||||
|
run: testInsertSession,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "delete session",
|
||||||
|
run: testDeleteSession,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "state update no session",
|
||||||
|
run: runStateUpdateTest(stateUpdateNoSession),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "state update exhaust session",
|
||||||
|
run: runStateUpdateTest(stateUpdateExhaustSession),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "state update seqnum equal last applied",
|
||||||
|
run: runStateUpdateTest(
|
||||||
|
stateUpdateSeqNumEqualLastApplied,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "state update seqnum less than last applied",
|
||||||
|
run: runStateUpdateTest(
|
||||||
|
stateUpdateSeqNumLTLastApplied,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "state update seqnum zero invalid",
|
||||||
|
run: runStateUpdateTest(stateUpdateSeqNumZeroInvalid),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "state update skip seqnum",
|
||||||
|
run: runStateUpdateTest(stateUpdateSkipSeqNum),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "state update revert seqnum",
|
||||||
|
run: runStateUpdateTest(stateUpdateRevertSeqNum),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "state update revert last applied",
|
||||||
|
run: runStateUpdateTest(stateUpdateRevertLastApplied),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple breach matches",
|
||||||
|
run: testMultipleMatches,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "lookout tip",
|
||||||
|
run: testLookoutTip,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, database := range dbs {
|
||||||
|
db := database
|
||||||
|
t.Run(db.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
h, cleanup := newTowerDBHarness(
|
||||||
|
t, db.init,
|
||||||
|
)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
test.run(h)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// id creates a session id from an integer.
|
||||||
|
func id(i int) *wtdb.SessionID {
|
||||||
|
var id wtdb.SessionID
|
||||||
|
binary.BigEndian.PutUint32(id[:4], uint32(i))
|
||||||
|
return &id
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateFromInt creates a unique update for a given (session, seqnum) pair. The
|
||||||
|
// lastApplied argument can be used to construct updates simulating different
|
||||||
|
// levels of synchronicity between client and db.
|
||||||
|
func updateFromInt(id *wtdb.SessionID, i int,
|
||||||
|
lastApplied uint16) *wtdb.SessionStateUpdate {
|
||||||
|
|
||||||
|
// Ensure the hint is unique.
|
||||||
|
var hint wtdb.BreachHint
|
||||||
|
copy(hint[:4], id[:4])
|
||||||
|
binary.BigEndian.PutUint16(hint[4:6], uint16(i))
|
||||||
|
|
||||||
|
return &wtdb.SessionStateUpdate{
|
||||||
|
ID: *id,
|
||||||
|
Hint: hint,
|
||||||
|
SeqNum: uint16(i),
|
||||||
|
LastApplied: lastApplied,
|
||||||
|
EncryptedBlob: []byte{byte(i)},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// epochFromInt creates a block epoch from an integer.
|
||||||
|
func epochFromInt(i int) *chainntnfs.BlockEpoch {
|
||||||
|
var hash chainhash.Hash
|
||||||
|
binary.BigEndian.PutUint32(hash[:4], uint32(i))
|
||||||
|
|
||||||
|
return &chainntnfs.BlockEpoch{
|
||||||
|
Hash: &hash,
|
||||||
|
Height: int32(i),
|
||||||
|
}
|
||||||
|
}
|
84
watchtower/wtdb/version.go
Normal file
84
watchtower/wtdb/version.go
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
package wtdb
|
||||||
|
|
||||||
|
import "github.com/coreos/bbolt"
|
||||||
|
|
||||||
|
// migration is a function which takes a prior outdated version of the database
|
||||||
|
// instances and mutates the key/bucket structure to arrive at a more
|
||||||
|
// up-to-date version of the database.
|
||||||
|
type migration func(tx *bbolt.Tx) error
|
||||||
|
|
||||||
|
// version pairs a version number with the migration that would need to be
|
||||||
|
// applied from the prior version to upgrade.
|
||||||
|
type version struct {
|
||||||
|
number uint32
|
||||||
|
migration migration
|
||||||
|
}
|
||||||
|
|
||||||
|
// dbVersions stores all versions and migrations of the database. This list will
|
||||||
|
// be used when opening the database to determine if any migrations must be
|
||||||
|
// applied.
|
||||||
|
var dbVersions = []version{
|
||||||
|
{
|
||||||
|
// Initial version requires no migration.
|
||||||
|
number: 0,
|
||||||
|
migration: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// getLatestDBVersion returns the last known database version.
|
||||||
|
func getLatestDBVersion(versions []version) uint32 {
|
||||||
|
return versions[len(versions)-1].number
|
||||||
|
}
|
||||||
|
|
||||||
|
// getMigrations returns a slice of all updates with a greater number that
|
||||||
|
// curVersion that need to be applied to sync up with the latest version.
|
||||||
|
func getMigrations(versions []version, curVersion uint32) []version {
|
||||||
|
var updates []version
|
||||||
|
for _, v := range versions {
|
||||||
|
if v.number > curVersion {
|
||||||
|
updates = append(updates, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return updates
|
||||||
|
}
|
||||||
|
|
||||||
|
// getDBVersion retrieves the current database version from the metadata bucket
|
||||||
|
// using the dbVersionKey.
|
||||||
|
func getDBVersion(tx *bbolt.Tx) (uint32, error) {
|
||||||
|
metadata := tx.Bucket(metadataBkt)
|
||||||
|
if metadata == nil {
|
||||||
|
return 0, ErrUninitializedDB
|
||||||
|
}
|
||||||
|
|
||||||
|
versionBytes := metadata.Get(dbVersionKey)
|
||||||
|
if len(versionBytes) != 4 {
|
||||||
|
return 0, ErrNoDBVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
return byteOrder.Uint32(versionBytes), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// initDBVersion initializes the top-level metadata bucket and writes the passed
|
||||||
|
// version number as the current version.
|
||||||
|
func initDBVersion(tx *bbolt.Tx, version uint32) error {
|
||||||
|
_, err := tx.CreateBucketIfNotExists(metadataBkt)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return putDBVersion(tx, version)
|
||||||
|
}
|
||||||
|
|
||||||
|
// putDBVersion stores the passed database version in the metadata bucket under
|
||||||
|
// the dbVersionKey.
|
||||||
|
func putDBVersion(tx *bbolt.Tx, version uint32) error {
|
||||||
|
metadata := tx.Bucket(metadataBkt)
|
||||||
|
if metadata == nil {
|
||||||
|
return ErrUninitializedDB
|
||||||
|
}
|
||||||
|
|
||||||
|
versionBytes := make([]byte, 4)
|
||||||
|
byteOrder.PutUint32(versionBytes, version)
|
||||||
|
return metadata.Put(dbVersionKey, versionBytes)
|
||||||
|
}
|
162
watchtower/wtmock/tower_db.go
Normal file
162
watchtower/wtmock/tower_db.go
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
package wtmock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TowerDB is a mock, in-memory implementation of a watchtower.DB.
|
||||||
|
type TowerDB struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
lastEpoch *chainntnfs.BlockEpoch
|
||||||
|
sessions map[wtdb.SessionID]*wtdb.SessionInfo
|
||||||
|
blobs map[wtdb.BreachHint]map[wtdb.SessionID]*wtdb.SessionStateUpdate
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTowerDB initializes a fresh mock TowerDB.
|
||||||
|
func NewTowerDB() *TowerDB {
|
||||||
|
return &TowerDB{
|
||||||
|
sessions: make(map[wtdb.SessionID]*wtdb.SessionInfo),
|
||||||
|
blobs: make(map[wtdb.BreachHint]map[wtdb.SessionID]*wtdb.SessionStateUpdate),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InsertStateUpdate stores an update sent by the client after validating that
|
||||||
|
// the update is well-formed in the context of other updates sent for the same
|
||||||
|
// session. This include verifying that the sequence number is incremented
|
||||||
|
// properly and the last applied values echoed by the client are sane.
|
||||||
|
func (db *TowerDB) InsertStateUpdate(update *wtdb.SessionStateUpdate) (uint16, error) {
|
||||||
|
db.mu.Lock()
|
||||||
|
defer db.mu.Unlock()
|
||||||
|
|
||||||
|
info, ok := db.sessions[update.ID]
|
||||||
|
if !ok {
|
||||||
|
return 0, wtdb.ErrSessionNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
err := info.AcceptUpdateSequence(update.SeqNum, update.LastApplied)
|
||||||
|
if err != nil {
|
||||||
|
return info.LastApplied, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionsToUpdates, ok := db.blobs[update.Hint]
|
||||||
|
if !ok {
|
||||||
|
sessionsToUpdates = make(map[wtdb.SessionID]*wtdb.SessionStateUpdate)
|
||||||
|
db.blobs[update.Hint] = sessionsToUpdates
|
||||||
|
}
|
||||||
|
sessionsToUpdates[update.ID] = update
|
||||||
|
|
||||||
|
return info.LastApplied, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSessionInfo retrieves the session for the passed session id. An error is
|
||||||
|
// returned if the session could not be found.
|
||||||
|
func (db *TowerDB) GetSessionInfo(id *wtdb.SessionID) (*wtdb.SessionInfo, error) {
|
||||||
|
db.mu.Lock()
|
||||||
|
defer db.mu.Unlock()
|
||||||
|
|
||||||
|
if info, ok := db.sessions[*id]; ok {
|
||||||
|
return info, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, wtdb.ErrSessionNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// InsertSessionInfo records a negotiated session in the tower database. An
|
||||||
|
// error is returned if the session already exists.
|
||||||
|
func (db *TowerDB) InsertSessionInfo(info *wtdb.SessionInfo) error {
|
||||||
|
db.mu.Lock()
|
||||||
|
defer db.mu.Unlock()
|
||||||
|
|
||||||
|
dbInfo, ok := db.sessions[info.ID]
|
||||||
|
if ok && dbInfo.LastApplied > 0 {
|
||||||
|
return wtdb.ErrSessionAlreadyExists
|
||||||
|
}
|
||||||
|
|
||||||
|
db.sessions[info.ID] = info
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteSession removes all data associated with a particular session id from
|
||||||
|
// the tower's database.
|
||||||
|
func (db *TowerDB) DeleteSession(target wtdb.SessionID) error {
|
||||||
|
db.mu.Lock()
|
||||||
|
defer db.mu.Unlock()
|
||||||
|
|
||||||
|
// Fail if the session doesn't exit.
|
||||||
|
if _, ok := db.sessions[target]; !ok {
|
||||||
|
return wtdb.ErrSessionNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the target session.
|
||||||
|
delete(db.sessions, target)
|
||||||
|
|
||||||
|
// Remove the state updates for any blobs stored under the target
|
||||||
|
// session identifier.
|
||||||
|
for hint, sessionUpdates := range db.blobs {
|
||||||
|
delete(sessionUpdates, target)
|
||||||
|
|
||||||
|
// If this was the last state update, we can also remove the
|
||||||
|
// hint that would map to an empty set.
|
||||||
|
if len(sessionUpdates) == 0 {
|
||||||
|
delete(db.blobs, hint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryMatches searches against all known state updates for any that match the
|
||||||
|
// passed breachHints. More than one Match will be returned for a given hint if
|
||||||
|
// they exist in the database.
|
||||||
|
func (db *TowerDB) QueryMatches(
|
||||||
|
breachHints []wtdb.BreachHint) ([]wtdb.Match, error) {
|
||||||
|
|
||||||
|
db.mu.Lock()
|
||||||
|
defer db.mu.Unlock()
|
||||||
|
|
||||||
|
var matches []wtdb.Match
|
||||||
|
for _, hint := range breachHints {
|
||||||
|
sessionsToUpdates, ok := db.blobs[hint]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, update := range sessionsToUpdates {
|
||||||
|
info, ok := db.sessions[id]
|
||||||
|
if !ok {
|
||||||
|
panic("session not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
match := wtdb.Match{
|
||||||
|
ID: id,
|
||||||
|
SeqNum: update.SeqNum,
|
||||||
|
Hint: hint,
|
||||||
|
EncryptedBlob: update.EncryptedBlob,
|
||||||
|
SessionInfo: info,
|
||||||
|
}
|
||||||
|
matches = append(matches, match)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return matches, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLookoutTip stores the provided epoch as the latest lookout tip epoch in
|
||||||
|
// the tower database.
|
||||||
|
func (db *TowerDB) SetLookoutTip(epoch *chainntnfs.BlockEpoch) error {
|
||||||
|
db.lastEpoch = epoch
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLookoutTip retrieves the current lookout tip block epoch from the tower
|
||||||
|
// database.
|
||||||
|
func (db *TowerDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) {
|
||||||
|
db.mu.Lock()
|
||||||
|
defer db.mu.Unlock()
|
||||||
|
|
||||||
|
return db.lastEpoch, nil
|
||||||
|
}
|
@ -102,7 +102,7 @@ func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID,
|
|||||||
// successful, the session will now be ready for use.
|
// successful, the session will now be ready for use.
|
||||||
err = s.cfg.DB.InsertSessionInfo(&info)
|
err = s.cfg.DB.InsertSessionInfo(&info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("unable to create session for %s", id)
|
log.Errorf("Unable to create session for %s: %v", id, err)
|
||||||
return s.replyCreateSession(
|
return s.replyCreateSession(
|
||||||
peer, id, wtwire.CodeTemporaryFailure, 0, nil,
|
peer, id, wtwire.CodeTemporaryFailure, 0, nil,
|
||||||
)
|
)
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
// +build dev
|
|
||||||
|
|
||||||
package wtserver_test
|
package wtserver_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@ -53,7 +51,7 @@ func initServer(t *testing.T, db wtserver.DB,
|
|||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
if db == nil {
|
if db == nil {
|
||||||
db = wtdb.NewMockDB()
|
db = wtmock.NewTowerDB()
|
||||||
}
|
}
|
||||||
|
|
||||||
s, err := wtserver.New(&wtserver.Config{
|
s, err := wtserver.New(&wtserver.Config{
|
||||||
@ -687,7 +685,7 @@ func testServerStateUpdates(t *testing.T, test stateUpdateTestCase) {
|
|||||||
// checking that the proper error is returned when the session doesn't exist and
|
// checking that the proper error is returned when the session doesn't exist and
|
||||||
// that a successful deletion does not disrupt other sessions.
|
// that a successful deletion does not disrupt other sessions.
|
||||||
func TestServerDeleteSession(t *testing.T) {
|
func TestServerDeleteSession(t *testing.T) {
|
||||||
db := wtdb.NewMockDB()
|
db := wtmock.NewTowerDB()
|
||||||
|
|
||||||
localPub := randPubKey(t)
|
localPub := randPubKey(t)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user