Merge pull request #1535 from cfromknecht/wtwire-server
[watchtower/server] Server-Side Wire Protocol
This commit is contained in:
commit
5c6c966891
66
watchtower/server/interface.go
Normal file
66
watchtower/server/interface.go
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/btcec"
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Interface represents a simple, listen-only service that accepts watchtower
|
||||||
|
// clients, and provides responses to their requests.
|
||||||
|
type Interface interface {
|
||||||
|
// InboundPeerConnected accepts a new watchtower client, and handles any
|
||||||
|
// requests sent by the peer.
|
||||||
|
InboundPeerConnected(Peer)
|
||||||
|
|
||||||
|
// Start sets up the watchtower server.
|
||||||
|
Start() error
|
||||||
|
|
||||||
|
// Stop cleans up the watchtower's current connections and resources.
|
||||||
|
Stop() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Peer is the primary interface used to abstract watchtower clients.
|
||||||
|
type Peer interface {
|
||||||
|
io.WriteCloser
|
||||||
|
|
||||||
|
// ReadNextMessage pulls the next framed message from the client.
|
||||||
|
ReadNextMessage() ([]byte, error)
|
||||||
|
|
||||||
|
// SetWriteDeadline specifies the time by which the client must have
|
||||||
|
// read a message sent by the server. In practice, the connection is
|
||||||
|
// buffered, so the client must read enough from the connection to
|
||||||
|
// support the server adding another reply.
|
||||||
|
SetWriteDeadline(time.Time) error
|
||||||
|
|
||||||
|
// SetReadDeadline specifies the time by which the client must send
|
||||||
|
// another message.
|
||||||
|
SetReadDeadline(time.Time) error
|
||||||
|
|
||||||
|
// RemotePub returns the client's public key.
|
||||||
|
RemotePub() *btcec.PublicKey
|
||||||
|
|
||||||
|
// RemoteAddr returns the client's network address.
|
||||||
|
RemoteAddr() net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// DB provides the server access to session creation and retrieval, as well as
|
||||||
|
// persisting state updates sent by clients.
|
||||||
|
type DB interface {
|
||||||
|
// InsertSessionInfo saves a newly agreed-upon session from a client.
|
||||||
|
// This method should fail if a session with the same session id already
|
||||||
|
// exists.
|
||||||
|
InsertSessionInfo(*wtdb.SessionInfo) error
|
||||||
|
|
||||||
|
// GetSessionInfo retrieves the SessionInfo associated with the session
|
||||||
|
// id, if it exists.
|
||||||
|
GetSessionInfo(*wtdb.SessionID) (*wtdb.SessionInfo, error)
|
||||||
|
|
||||||
|
// InsertStateUpdate persists a state update sent by a client, and
|
||||||
|
// validates the update against the current SessionInfo stored under the
|
||||||
|
// update's session id..
|
||||||
|
InsertStateUpdate(*wtdb.SessionStateUpdate) (uint16, error)
|
||||||
|
}
|
29
watchtower/server/log.go
Normal file
29
watchtower/server/log.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/btcsuite/btclog"
|
||||||
|
"github.com/lightningnetwork/lnd/build"
|
||||||
|
)
|
||||||
|
|
||||||
|
// log is a logger that is initialized with no output filters. This
|
||||||
|
// means the package will not perform any logging by default until the caller
|
||||||
|
// requests it.
|
||||||
|
var log btclog.Logger
|
||||||
|
|
||||||
|
// The default amount of logging is none.
|
||||||
|
func init() {
|
||||||
|
UseLogger(build.NewSubLogger("WTSV", nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableLog disables all library log output. Logging output is disabled
|
||||||
|
// by default until UseLogger is called.
|
||||||
|
func DisableLog() {
|
||||||
|
UseLogger(btclog.Disabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UseLogger uses a specified Logger to output package logging info.
|
||||||
|
// This should be used in preference to SetLogWriter if the caller is also
|
||||||
|
// using btclog.
|
||||||
|
func UseLogger(logger btclog.Logger) {
|
||||||
|
log = logger
|
||||||
|
}
|
98
watchtower/server/mock.go
Normal file
98
watchtower/server/mock.go
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
// +build dev
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/btcec"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MockPeer struct {
|
||||||
|
remotePub *btcec.PublicKey
|
||||||
|
remoteAddr net.Addr
|
||||||
|
|
||||||
|
IncomingMsgs chan []byte
|
||||||
|
OutgoingMsgs chan []byte
|
||||||
|
|
||||||
|
writeDeadline <-chan time.Time
|
||||||
|
readDeadline <-chan time.Time
|
||||||
|
|
||||||
|
Quit chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMockPeer(pk *btcec.PublicKey, addr net.Addr, bufferSize int) *MockPeer {
|
||||||
|
return &MockPeer{
|
||||||
|
remotePub: pk,
|
||||||
|
remoteAddr: addr,
|
||||||
|
IncomingMsgs: make(chan []byte, bufferSize),
|
||||||
|
OutgoingMsgs: make(chan []byte, bufferSize),
|
||||||
|
Quit: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *MockPeer) Write(b []byte) (n int, err error) {
|
||||||
|
select {
|
||||||
|
case p.OutgoingMsgs <- b:
|
||||||
|
return len(b), nil
|
||||||
|
case <-p.writeDeadline:
|
||||||
|
return 0, fmt.Errorf("write timeout expired")
|
||||||
|
case <-p.Quit:
|
||||||
|
return 0, fmt.Errorf("connection closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *MockPeer) Close() error {
|
||||||
|
select {
|
||||||
|
case <-p.Quit:
|
||||||
|
return fmt.Errorf("connection already closed")
|
||||||
|
default:
|
||||||
|
close(p.Quit)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *MockPeer) ReadNextMessage() ([]byte, error) {
|
||||||
|
select {
|
||||||
|
case b := <-p.IncomingMsgs:
|
||||||
|
return b, nil
|
||||||
|
case <-p.readDeadline:
|
||||||
|
return nil, fmt.Errorf("read timeout expired")
|
||||||
|
case <-p.Quit:
|
||||||
|
return nil, fmt.Errorf("connection closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *MockPeer) SetWriteDeadline(t time.Time) error {
|
||||||
|
if t.IsZero() {
|
||||||
|
p.writeDeadline = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
duration := time.Until(t)
|
||||||
|
p.writeDeadline = time.After(duration)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *MockPeer) SetReadDeadline(t time.Time) error {
|
||||||
|
if t.IsZero() {
|
||||||
|
p.readDeadline = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
duration := time.Until(t)
|
||||||
|
p.readDeadline = time.After(duration)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *MockPeer) RemotePub() *btcec.PublicKey {
|
||||||
|
return p.remotePub
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *MockPeer) RemoteAddr() net.Addr {
|
||||||
|
return p.remoteAddr
|
||||||
|
}
|
611
watchtower/server/server.go
Normal file
611
watchtower/server/server.go
Normal file
@ -0,0 +1,611 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/btcec"
|
||||||
|
"github.com/btcsuite/btcd/connmgr"
|
||||||
|
"github.com/btcsuite/btcutil"
|
||||||
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower/wtwire"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrPeerAlreadyConnected signals that a peer with the same session id
|
||||||
|
// is already active within the server.
|
||||||
|
ErrPeerAlreadyConnected = errors.New("peer already connected")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config abstracts the primary components and dependencies of the server.
|
||||||
|
type Config struct {
|
||||||
|
// DB provides persistent access to the server's sessions and for
|
||||||
|
// storing state updates.
|
||||||
|
DB DB
|
||||||
|
|
||||||
|
// NodePrivKey is private key to be used in accepting new brontide
|
||||||
|
// connections.
|
||||||
|
NodePrivKey *btcec.PrivateKey
|
||||||
|
|
||||||
|
// Listeners specifies which address to which clients may connect.
|
||||||
|
Listeners []net.Listener
|
||||||
|
|
||||||
|
// ReadTimeout specifies how long a client may go without sending a
|
||||||
|
// message.
|
||||||
|
ReadTimeout time.Duration
|
||||||
|
|
||||||
|
// WriteTimeout specifies how long a client may go without reading a
|
||||||
|
// message from the other end, if the connection has stopped buffering
|
||||||
|
// the server's replies.
|
||||||
|
WriteTimeout time.Duration
|
||||||
|
|
||||||
|
// NewAddress is used to generate reward addresses, where a cut of
|
||||||
|
// successfully sent funds can be received.
|
||||||
|
NewAddress func() (btcutil.Address, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server houses the state required to handle watchtower peers. It's primary job
|
||||||
|
// is to accept incoming connections, and dispatch processing of the client
|
||||||
|
// message streams.
|
||||||
|
type Server struct {
|
||||||
|
started int32 // atomic
|
||||||
|
shutdown int32 // atomic
|
||||||
|
|
||||||
|
cfg *Config
|
||||||
|
|
||||||
|
connMgr *connmgr.ConnManager
|
||||||
|
|
||||||
|
clientMtx sync.RWMutex
|
||||||
|
clients map[wtdb.SessionID]Peer
|
||||||
|
|
||||||
|
globalFeatures *lnwire.RawFeatureVector
|
||||||
|
localFeatures *lnwire.RawFeatureVector
|
||||||
|
|
||||||
|
wg sync.WaitGroup
|
||||||
|
quit chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new server to handle watchtower clients. The server will accept
|
||||||
|
// clients connecting to the listener addresses, and allows them to open
|
||||||
|
// sessions and send state updates.
|
||||||
|
func New(cfg *Config) (*Server, error) {
|
||||||
|
localFeatures := lnwire.NewRawFeatureVector(
|
||||||
|
wtwire.WtSessionsOptional,
|
||||||
|
)
|
||||||
|
|
||||||
|
s := &Server{
|
||||||
|
cfg: cfg,
|
||||||
|
clients: make(map[wtdb.SessionID]Peer),
|
||||||
|
globalFeatures: lnwire.NewRawFeatureVector(),
|
||||||
|
localFeatures: localFeatures,
|
||||||
|
quit: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
connMgr, err := connmgr.New(&connmgr.Config{
|
||||||
|
Listeners: cfg.Listeners,
|
||||||
|
OnAccept: s.inboundPeerConnected,
|
||||||
|
Dial: noDial,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.connMgr = connMgr
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start begins listening on the server's listeners.
|
||||||
|
func (s *Server) Start() error {
|
||||||
|
// Already running?
|
||||||
|
if !atomic.CompareAndSwapInt32(&s.started, 0, 1) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.connMgr.Start()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop shutdowns down the server's listeners and any active requests.
|
||||||
|
func (s *Server) Stop() error {
|
||||||
|
// Bail if we're already shutting down.
|
||||||
|
if !atomic.CompareAndSwapInt32(&s.shutdown, 0, 1) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.connMgr.Stop()
|
||||||
|
|
||||||
|
close(s.quit)
|
||||||
|
s.wg.Wait()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// inboundPeerConnected is the callback given to the connection manager, and is
|
||||||
|
// called each time a new connection is made to the watchtower. This method
|
||||||
|
// proxies the new peers by filtering out those that do not satisfy the
|
||||||
|
// server.Peer interface, and closes their connection. Successful connections
|
||||||
|
// will be passed on to the public InboundPeerConnected method.
|
||||||
|
func (s *Server) inboundPeerConnected(c net.Conn) {
|
||||||
|
peer, ok := c.(Peer)
|
||||||
|
if !ok {
|
||||||
|
log.Warnf("incoming connection %T does not satisfy "+
|
||||||
|
"server.Peer interface", c)
|
||||||
|
c.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.InboundPeerConnected(peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InboundPeerConnected accepts a server.Peer, and handles the request submitted
|
||||||
|
// by the client. This method serves also as a public endpoint for locally
|
||||||
|
// registering new clients with the server.
|
||||||
|
func (s *Server) InboundPeerConnected(peer Peer) {
|
||||||
|
s.wg.Add(1)
|
||||||
|
go s.handleClient(peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleClient processes a series watchtower messages sent by a client. The
|
||||||
|
// client may either send:
|
||||||
|
// * a single CreateSession message.
|
||||||
|
// * a series of StateUpdate messages.
|
||||||
|
//
|
||||||
|
// This method uses the server's peer map to ensure at most one peer using the
|
||||||
|
// same session id can enter the main event loop. The connection will be
|
||||||
|
// dropped by the watchtower if no messages are sent or received by the
|
||||||
|
// configured Read/WriteTimeouts.
|
||||||
|
//
|
||||||
|
// NOTE: This method MUST be run as a goroutine.
|
||||||
|
func (s *Server) handleClient(peer Peer) {
|
||||||
|
defer s.wg.Done()
|
||||||
|
|
||||||
|
// Use the connection's remote pubkey as the client's session id.
|
||||||
|
id := wtdb.NewSessionIDFromPubKey(peer.RemotePub())
|
||||||
|
|
||||||
|
// Register this peer in the server's client map, and defer the
|
||||||
|
// connection's cleanup. If the peer already exists, we will close the
|
||||||
|
// connection and exit immediately.
|
||||||
|
err := s.addPeer(&id, peer)
|
||||||
|
if err != nil {
|
||||||
|
peer.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer s.removePeer(&id)
|
||||||
|
|
||||||
|
msg, err := s.readMessage(peer)
|
||||||
|
remoteInit, ok := msg.(*wtwire.Init)
|
||||||
|
if !ok {
|
||||||
|
log.Errorf("Client %s@%s did not send Init msg as first "+
|
||||||
|
"message", id, peer.RemoteAddr())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
localInit := wtwire.NewInitMessage(
|
||||||
|
s.localFeatures, s.globalFeatures,
|
||||||
|
)
|
||||||
|
|
||||||
|
err = s.sendMessage(peer, localInit)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Unable to send Init msg to %s: %v", id, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = s.handleInit(localInit, remoteInit); err != nil {
|
||||||
|
log.Errorf("Cannot support client %s: %v", id, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateUpdateOnlyMode will become true if the client's first message is
|
||||||
|
// a StateUpdate. If instead, it is a CreateSession, this method will exit
|
||||||
|
// immediately after replying. We track this to ensure that the client
|
||||||
|
// can't send a CreateSession after having already sent a StateUpdate.
|
||||||
|
var stateUpdateOnlyMode bool
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.quit:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
nextMsg, err := s.readMessage(peer)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Unable to read watchtower msg from %x: %v",
|
||||||
|
id[:], err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the request according to the message's type.
|
||||||
|
switch msg := nextMsg.(type) {
|
||||||
|
|
||||||
|
// A CreateSession indicates a request to establish a new session
|
||||||
|
// with our watchtower.
|
||||||
|
case *wtwire.CreateSession:
|
||||||
|
// Ensure CreateSession can only be sent as the first
|
||||||
|
// message.
|
||||||
|
if stateUpdateOnlyMode {
|
||||||
|
log.Errorf("client %x sent CreateSession after "+
|
||||||
|
"StateUpdate", id)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Received CreateSession from %s, "+
|
||||||
|
"version=%d nupdates=%d rewardrate=%d "+
|
||||||
|
"sweepfeerate=%d", id, msg.BlobVersion,
|
||||||
|
msg.MaxUpdates, msg.RewardRate,
|
||||||
|
msg.SweepFeeRate)
|
||||||
|
|
||||||
|
// Attempt to open a new session for this client.
|
||||||
|
err := s.handleCreateSession(peer, &id, msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to handle CreateSession "+
|
||||||
|
"from %s: %v", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exit after replying to CreateSession.
|
||||||
|
return
|
||||||
|
|
||||||
|
// A StateUpdate indicates an existing client attempting to
|
||||||
|
// back-up a revoked commitment state.
|
||||||
|
case *wtwire.StateUpdate:
|
||||||
|
log.Infof("Received SessionUpdate from %s, seqnum=%d "+
|
||||||
|
"lastapplied=%d complete=%v hint=%x", id,
|
||||||
|
msg.SeqNum, msg.LastApplied, msg.IsComplete,
|
||||||
|
msg.Hint[:])
|
||||||
|
|
||||||
|
// Try to accept the state update from the client.
|
||||||
|
err := s.handleStateUpdate(peer, &id, msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to handle StateUpdate "+
|
||||||
|
"from %s: %v", id, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the client signals that this is last StateUpdate
|
||||||
|
// message, we can disconnect the client.
|
||||||
|
if msg.IsComplete == 1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// The client has signaled that more StateUpdates are
|
||||||
|
// yet to come. Enter state-update-only mode to disallow
|
||||||
|
// future sends of CreateSession messages.
|
||||||
|
stateUpdateOnlyMode = true
|
||||||
|
|
||||||
|
default:
|
||||||
|
log.Errorf("received unsupported message type: %T "+
|
||||||
|
"from %s", nextMsg, id)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleInit(localInit, remoteInit *wtwire.Init) error {
|
||||||
|
remoteLocalFeatures := lnwire.NewFeatureVector(
|
||||||
|
remoteInit.LocalFeatures, wtwire.LocalFeatures,
|
||||||
|
)
|
||||||
|
remoteGlobalFeatures := lnwire.NewFeatureVector(
|
||||||
|
remoteInit.GlobalFeatures, wtwire.GlobalFeatures,
|
||||||
|
)
|
||||||
|
|
||||||
|
unknownLocalFeatures := remoteLocalFeatures.UnknownRequiredFeatures()
|
||||||
|
if len(unknownLocalFeatures) > 0 {
|
||||||
|
err := fmt.Errorf("Peer set unknown local feature bits: %v",
|
||||||
|
unknownLocalFeatures)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
unknownGlobalFeatures := remoteGlobalFeatures.UnknownRequiredFeatures()
|
||||||
|
if len(unknownGlobalFeatures) > 0 {
|
||||||
|
err := fmt.Errorf("Peer set unknown global feature bits: %v",
|
||||||
|
unknownGlobalFeatures)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) readMessage(peer Peer) (wtwire.Message, error) {
|
||||||
|
// Set a read timeout to ensure we drop the client if not sent in a
|
||||||
|
// timely manner.
|
||||||
|
err := peer.SetReadDeadline(time.Now().Add(s.cfg.ReadTimeout))
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("unable to set read deadline: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pull the next message off the wire, and parse it according to the
|
||||||
|
// watchtower wire specification.
|
||||||
|
rawMsg, err := peer.ReadNextMessage()
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("unable to read message: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
msgReader := bytes.NewReader(rawMsg)
|
||||||
|
nextMsg, err := wtwire.ReadMessage(msgReader, 0)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("unable to parse message: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nextMsg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleCreateSession processes a CreateSession message from the peer, and returns
|
||||||
|
// a CreateSessionReply in response. This method will only succeed if no existing
|
||||||
|
// session info is known about the session id. If an existing session is found,
|
||||||
|
// the reward address is returned in case the client lost our reply.
|
||||||
|
func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID,
|
||||||
|
init *wtwire.CreateSession) error {
|
||||||
|
|
||||||
|
// TODO(conner): validate accept against policy
|
||||||
|
|
||||||
|
// Query the db for session info belonging to the client's session id.
|
||||||
|
existingInfo, err := s.cfg.DB.GetSessionInfo(id)
|
||||||
|
switch {
|
||||||
|
|
||||||
|
// We already have a session corresponding to this session id, return an
|
||||||
|
// error signaling that it already exists in our database. We return the
|
||||||
|
// reward address to the client in case they were not able to process
|
||||||
|
// our reply earlier.
|
||||||
|
case err == nil:
|
||||||
|
log.Debugf("Already have session for %s", id)
|
||||||
|
return s.replyCreateSession(
|
||||||
|
peer, id, wtwire.CreateSessionCodeAlreadyExists,
|
||||||
|
[]byte(existingInfo.RewardAddress),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Some other database error occurred, return a temporary failure.
|
||||||
|
case err != wtdb.ErrSessionNotFound:
|
||||||
|
log.Errorf("unable to load session info for %s", id)
|
||||||
|
return s.replyCreateSession(
|
||||||
|
peer, id, wtwire.CodeTemporaryFailure, nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now that we've established that this session does not exist in the
|
||||||
|
// database, retrieve the sweep address that will be given to the
|
||||||
|
// client. This address is to be included by the client when signing
|
||||||
|
// sweep transactions destined for this tower, if its negotiated output
|
||||||
|
// is not dust.
|
||||||
|
rewardAddress, err := s.cfg.NewAddress()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to generate reward addr for %s", id)
|
||||||
|
return s.replyCreateSession(
|
||||||
|
peer, id, wtwire.CodeTemporaryFailure, nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
rewardAddrBytes := rewardAddress.ScriptAddress()
|
||||||
|
|
||||||
|
// TODO(conner): create invoice for upfront payment
|
||||||
|
|
||||||
|
// Assemble the session info using the agreed upon parameters, reward
|
||||||
|
// address, and session id.
|
||||||
|
info := wtdb.SessionInfo{
|
||||||
|
ID: *id,
|
||||||
|
Version: init.BlobVersion,
|
||||||
|
MaxUpdates: init.MaxUpdates,
|
||||||
|
RewardRate: init.RewardRate,
|
||||||
|
SweepFeeRate: init.SweepFeeRate,
|
||||||
|
RewardAddress: rewardAddrBytes,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert the session info into the watchtower's database. If
|
||||||
|
// successful, the session will now be ready for use.
|
||||||
|
err = s.cfg.DB.InsertSessionInfo(&info)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to create session for %s", id)
|
||||||
|
return s.replyCreateSession(
|
||||||
|
peer, id, wtwire.CodeTemporaryFailure, nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Accepted session for %s", id)
|
||||||
|
|
||||||
|
return s.replyCreateSession(
|
||||||
|
peer, id, wtwire.CodeOK, rewardAddrBytes,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleStateUpdate processes a StateUpdate message request from a client. An
|
||||||
|
// attempt will be made to insert the update into the db, where it is validated
|
||||||
|
// against the client's session. The possible errors are then mapped back to
|
||||||
|
// StateUpdateCodes specified by the watchtower wire protocol, and sent back
|
||||||
|
// using a StateUpdateReply message.
|
||||||
|
func (s *Server) handleStateUpdate(peer Peer, id *wtdb.SessionID,
|
||||||
|
update *wtwire.StateUpdate) error {
|
||||||
|
|
||||||
|
var (
|
||||||
|
lastApplied uint16
|
||||||
|
failCode wtwire.ErrorCode
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
sessionUpdate := wtdb.SessionStateUpdate{
|
||||||
|
ID: *id,
|
||||||
|
Hint: update.Hint,
|
||||||
|
SeqNum: update.SeqNum,
|
||||||
|
LastApplied: update.LastApplied,
|
||||||
|
EncryptedBlob: update.EncryptedBlob,
|
||||||
|
}
|
||||||
|
|
||||||
|
lastApplied, err = s.cfg.DB.InsertStateUpdate(&sessionUpdate)
|
||||||
|
switch {
|
||||||
|
case err == nil:
|
||||||
|
log.Infof("State update %d accepted for %s",
|
||||||
|
update.SeqNum, id)
|
||||||
|
|
||||||
|
failCode = wtwire.CodeOK
|
||||||
|
|
||||||
|
// Return a permanent failure if a client tries to send an update for
|
||||||
|
// which we have no session.
|
||||||
|
case err == wtdb.ErrSessionNotFound:
|
||||||
|
failCode = wtwire.CodePermanentFailure
|
||||||
|
|
||||||
|
case err == wtdb.ErrSeqNumAlreadyApplied:
|
||||||
|
failCode = wtwire.CodePermanentFailure
|
||||||
|
|
||||||
|
// TODO(conner): remove session state for protocol
|
||||||
|
// violation. Could also double as clean up method for
|
||||||
|
// session-related state.
|
||||||
|
|
||||||
|
case err == wtdb.ErrLastAppliedReversion:
|
||||||
|
failCode = wtwire.StateUpdateCodeClientBehind
|
||||||
|
|
||||||
|
case err == wtdb.ErrSessionConsumed:
|
||||||
|
failCode = wtwire.StateUpdateCodeMaxUpdatesExceeded
|
||||||
|
|
||||||
|
case err == wtdb.ErrUpdateOutOfOrder:
|
||||||
|
failCode = wtwire.StateUpdateCodeSeqNumOutOfOrder
|
||||||
|
|
||||||
|
default:
|
||||||
|
failCode = wtwire.CodeTemporaryFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.replyStateUpdate(
|
||||||
|
peer, id, failCode, lastApplied,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// connFailure is a default error used when a request failed with a non-zero
|
||||||
|
// error code.
|
||||||
|
type connFailure struct {
|
||||||
|
ID wtdb.SessionID
|
||||||
|
Code uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error displays the SessionID and Code that caused the connection failure.
|
||||||
|
func (f *connFailure) Error() string {
|
||||||
|
return fmt.Sprintf("connection with %s failed with code=%v",
|
||||||
|
f.ID, f.Code,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// replyCreateSession sends a response to a CreateSession from a client. If the
|
||||||
|
// status code in the reply is OK, the error from the write will be bubbled up.
|
||||||
|
// Otherwise, this method returns a connection error to ensure we don't continue
|
||||||
|
// communication with the client.
|
||||||
|
func (s *Server) replyCreateSession(peer Peer, id *wtdb.SessionID,
|
||||||
|
code wtwire.ErrorCode, data []byte) error {
|
||||||
|
|
||||||
|
msg := &wtwire.CreateSessionReply{
|
||||||
|
Code: code,
|
||||||
|
Data: data,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.sendMessage(peer, msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to send CreateSessionReply to %s", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the write error if the request succeeded.
|
||||||
|
if code == wtwire.CodeOK {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise the request failed, return a connection failure to
|
||||||
|
// disconnect the client.
|
||||||
|
return &connFailure{
|
||||||
|
ID: *id,
|
||||||
|
Code: uint16(code),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// replyStateUpdate sends a response to a StateUpdate from a client. If the
|
||||||
|
// status code in the reply is OK, the error from the write will be bubbled up.
|
||||||
|
// Otherwise, this method returns a connection error to ensure we don't continue
|
||||||
|
// communication with the client.
|
||||||
|
func (s *Server) replyStateUpdate(peer Peer, id *wtdb.SessionID,
|
||||||
|
code wtwire.StateUpdateCode, lastApplied uint16) error {
|
||||||
|
|
||||||
|
msg := &wtwire.StateUpdateReply{
|
||||||
|
Code: code,
|
||||||
|
LastApplied: lastApplied,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.sendMessage(peer, msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to send StateUpdateReply to %s", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the write error if the request succeeded.
|
||||||
|
if code == wtwire.CodeOK {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise the request failed, return a connection failure to
|
||||||
|
// disconnect the client.
|
||||||
|
return &connFailure{
|
||||||
|
ID: *id,
|
||||||
|
Code: uint16(code),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendMessage sends a watchtower wire message to the target peer.
|
||||||
|
func (s *Server) sendMessage(peer Peer, msg wtwire.Message) error {
|
||||||
|
// TODO(conner): use buffer pool?
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
_, err := wtwire.WriteMessage(&b, msg, 0)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("unable to encode msg: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = peer.SetWriteDeadline(time.Now().Add(s.cfg.WriteTimeout))
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("unable to set write deadline: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = peer.Write(b.Bytes())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// addPeer stores a client in the server's client map. An error is returned if a
|
||||||
|
// client with the same session id already exists.
|
||||||
|
func (s *Server) addPeer(id *wtdb.SessionID, peer Peer) error {
|
||||||
|
s.clientMtx.Lock()
|
||||||
|
defer s.clientMtx.Unlock()
|
||||||
|
|
||||||
|
if existingPeer, ok := s.clients[*id]; ok {
|
||||||
|
log.Infof("Already connected to peer %s@%s, disconnecting %s",
|
||||||
|
id, existingPeer.RemoteAddr(), peer.RemoteAddr())
|
||||||
|
return ErrPeerAlreadyConnected
|
||||||
|
}
|
||||||
|
s.clients[*id] = peer
|
||||||
|
|
||||||
|
log.Infof("Accepted incoming peer %s@%s",
|
||||||
|
id, peer.RemoteAddr())
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removePeer deletes a client from the server's client map. If a peer is found,
|
||||||
|
// this method will close the peer's connection.
|
||||||
|
func (s *Server) removePeer(id *wtdb.SessionID) {
|
||||||
|
log.Infof("Releasing incoming peer %s", id)
|
||||||
|
|
||||||
|
s.clientMtx.Lock()
|
||||||
|
peer, ok := s.clients[*id]
|
||||||
|
delete(s.clients, *id)
|
||||||
|
s.clientMtx.Unlock()
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
peer.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// noDial is a dummy dial method passed to the server's connmgr.
|
||||||
|
func noDial(_ net.Addr) (net.Conn, error) {
|
||||||
|
return nil, fmt.Errorf("watchtower cannot make outgoing conns")
|
||||||
|
}
|
634
watchtower/server/server_test.go
Normal file
634
watchtower/server/server_test.go
Normal file
@ -0,0 +1,634 @@
|
|||||||
|
// +build dev
|
||||||
|
|
||||||
|
package server_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/btcec"
|
||||||
|
"github.com/btcsuite/btcd/chaincfg"
|
||||||
|
"github.com/btcsuite/btcutil"
|
||||||
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower/server"
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower/wtwire"
|
||||||
|
)
|
||||||
|
|
||||||
|
// addr is the server's reward address given to watchtower clients.
|
||||||
|
var addr, _ = btcutil.DecodeAddress(
|
||||||
|
"mrX9vMRYLfVy1BnZbc5gZjuyaqH3ZW2ZHz", &chaincfg.TestNet3Params,
|
||||||
|
)
|
||||||
|
|
||||||
|
// randPubKey generates a new secp keypair, and returns the public key.
|
||||||
|
func randPubKey(t *testing.T) *btcec.PublicKey {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
sk, err := btcec.NewPrivateKey(btcec.S256())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to generate pubkey: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sk.PubKey()
|
||||||
|
}
|
||||||
|
|
||||||
|
// initServer creates and starts a new server using the server.DB and timeout.
|
||||||
|
// If the provided database is nil, a mock db will be used.
|
||||||
|
func initServer(t *testing.T, db server.DB,
|
||||||
|
timeout time.Duration) server.Interface {
|
||||||
|
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if db == nil {
|
||||||
|
db = wtdb.NewMockDB()
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := server.New(&server.Config{
|
||||||
|
DB: db,
|
||||||
|
ReadTimeout: timeout,
|
||||||
|
WriteTimeout: timeout,
|
||||||
|
NewAddress: func() (btcutil.Address, error) {
|
||||||
|
return addr, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to create server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = s.Start(); err != nil {
|
||||||
|
t.Fatalf("unable to start server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServerOnlyAcceptOnePeer checks that the server will reject duplicate
|
||||||
|
// peers with the same session id by disconnecting them. This is accomplished by
|
||||||
|
// connecting two distinct peers with the same session id, and trying to send
|
||||||
|
// messages on both connections. Since one should be rejected, we verify that
|
||||||
|
// only one of the connections is able to send messages.
|
||||||
|
func TestServerOnlyAcceptOnePeer(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
const timeoutDuration = 500 * time.Millisecond
|
||||||
|
|
||||||
|
s := initServer(t, nil, timeoutDuration)
|
||||||
|
defer s.Stop()
|
||||||
|
|
||||||
|
// Create two peers using the same session id.
|
||||||
|
peerPub := randPubKey(t)
|
||||||
|
peer1 := server.NewMockPeer(peerPub, nil, 0)
|
||||||
|
peer2 := server.NewMockPeer(peerPub, nil, 0)
|
||||||
|
|
||||||
|
// Serialize a Init message to be sent by both peers.
|
||||||
|
init := wtwire.NewInitMessage(
|
||||||
|
lnwire.NewRawFeatureVector(), lnwire.NewRawFeatureVector(),
|
||||||
|
)
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
_, err := wtwire.WriteMessage(&b, init, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to write message: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := b.Bytes()
|
||||||
|
|
||||||
|
// Connect both peers to the server simultaneously.
|
||||||
|
s.InboundPeerConnected(peer1)
|
||||||
|
s.InboundPeerConnected(peer2)
|
||||||
|
|
||||||
|
// Use a timeout of twice the server's timeouts, to ensure the server
|
||||||
|
// has time to process the messages.
|
||||||
|
timeout := time.After(2 * timeoutDuration)
|
||||||
|
|
||||||
|
// Try to send a message on either peer, and record the opposite peer as
|
||||||
|
// the one we assume to be rejected.
|
||||||
|
var (
|
||||||
|
rejectedPeer *server.MockPeer
|
||||||
|
acceptedPeer *server.MockPeer
|
||||||
|
)
|
||||||
|
select {
|
||||||
|
case peer1.IncomingMsgs <- msg:
|
||||||
|
acceptedPeer = peer1
|
||||||
|
rejectedPeer = peer2
|
||||||
|
case peer2.IncomingMsgs <- msg:
|
||||||
|
acceptedPeer = peer2
|
||||||
|
rejectedPeer = peer1
|
||||||
|
case <-timeout:
|
||||||
|
t.Fatalf("unable to send message via either peer")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try again to send a message, this time only via the assumed-rejected
|
||||||
|
// peer. We expect our conservative timeout to expire, as the server
|
||||||
|
// isn't reading from this peer. Before the timeout, the accepted peer
|
||||||
|
// should also receive a reply to its Init message.
|
||||||
|
select {
|
||||||
|
case <-acceptedPeer.OutgoingMsgs:
|
||||||
|
select {
|
||||||
|
case rejectedPeer.IncomingMsgs <- msg:
|
||||||
|
t.Fatalf("rejected peer should not have received message")
|
||||||
|
case <-timeout:
|
||||||
|
// Accepted peer got reply, rejected peer go nothing.
|
||||||
|
}
|
||||||
|
case rejectedPeer.IncomingMsgs <- msg:
|
||||||
|
t.Fatalf("rejected peer should not have received message")
|
||||||
|
case <-timeout:
|
||||||
|
t.Fatalf("accepted peer should have received init message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type createSessionTestCase struct {
|
||||||
|
name string
|
||||||
|
initMsg *wtwire.Init
|
||||||
|
createMsg *wtwire.CreateSession
|
||||||
|
expReply *wtwire.CreateSessionReply
|
||||||
|
expDupReply *wtwire.CreateSessionReply
|
||||||
|
}
|
||||||
|
|
||||||
|
var createSessionTests = []createSessionTestCase{
|
||||||
|
{
|
||||||
|
name: "reject duplicate session create",
|
||||||
|
initMsg: wtwire.NewInitMessage(
|
||||||
|
lnwire.NewRawFeatureVector(),
|
||||||
|
lnwire.NewRawFeatureVector(),
|
||||||
|
),
|
||||||
|
createMsg: &wtwire.CreateSession{
|
||||||
|
BlobVersion: 0,
|
||||||
|
MaxUpdates: 1000,
|
||||||
|
RewardRate: 0,
|
||||||
|
SweepFeeRate: 1,
|
||||||
|
},
|
||||||
|
expReply: &wtwire.CreateSessionReply{
|
||||||
|
Code: wtwire.CodeOK,
|
||||||
|
Data: []byte(addr.ScriptAddress()),
|
||||||
|
},
|
||||||
|
expDupReply: &wtwire.CreateSessionReply{
|
||||||
|
Code: wtwire.CreateSessionCodeAlreadyExists,
|
||||||
|
Data: []byte(addr.ScriptAddress()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// TODO(conner): add policy rejection tests
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServerCreateSession checks the server's behavior in response to a
|
||||||
|
// table-driven set of CreateSession messages.
|
||||||
|
func TestServerCreateSession(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
for i, test := range createSessionTests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
testServerCreateSession(t, i, test)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testServerCreateSession(t *testing.T, i int, test createSessionTestCase) {
|
||||||
|
const timeoutDuration = 500 * time.Millisecond
|
||||||
|
|
||||||
|
s := initServer(t, nil, timeoutDuration)
|
||||||
|
defer s.Stop()
|
||||||
|
|
||||||
|
// Create a new client and connect to server.
|
||||||
|
peerPub := randPubKey(t)
|
||||||
|
peer := server.NewMockPeer(peerPub, nil, 0)
|
||||||
|
connect(t, i, s, peer, test.initMsg, timeoutDuration)
|
||||||
|
|
||||||
|
// Send the CreateSession message, and wait for a reply.
|
||||||
|
sendMsg(t, i, test.createMsg, peer, timeoutDuration)
|
||||||
|
|
||||||
|
reply := recvReply(
|
||||||
|
t, i, "CreateSessionReply", peer, timeoutDuration,
|
||||||
|
).(*wtwire.CreateSessionReply)
|
||||||
|
|
||||||
|
// Verify that the server's response matches our expectation.
|
||||||
|
if !reflect.DeepEqual(reply, test.expReply) {
|
||||||
|
t.Fatalf("[test %d] expected reply %v, got %d",
|
||||||
|
i, test.expReply, reply)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert that the server closes the connection after processing the
|
||||||
|
// CreateSession.
|
||||||
|
assertConnClosed(t, peer, 2*timeoutDuration)
|
||||||
|
|
||||||
|
// If this test did not request sending a duplicate CreateSession, we can
|
||||||
|
// continue to the next test.
|
||||||
|
if test.expDupReply == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate a peer with the same session id connection to the server
|
||||||
|
// again.
|
||||||
|
peer = server.NewMockPeer(peerPub, nil, 0)
|
||||||
|
connect(t, i, s, peer, test.initMsg, timeoutDuration)
|
||||||
|
|
||||||
|
// Send the _same_ CreateSession message as the first attempt.
|
||||||
|
sendMsg(t, i, test.createMsg, peer, timeoutDuration)
|
||||||
|
|
||||||
|
reply = recvReply(
|
||||||
|
t, i, "CreateSessionReply", peer, timeoutDuration,
|
||||||
|
).(*wtwire.CreateSessionReply)
|
||||||
|
|
||||||
|
// Ensure that the server's reply matches our expected response for a
|
||||||
|
// duplicate send.
|
||||||
|
if !reflect.DeepEqual(reply, test.expDupReply) {
|
||||||
|
t.Fatalf("[test %d] expected reply %v, got %d",
|
||||||
|
i, test.expReply, reply)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, check that the server tore down the connection.
|
||||||
|
assertConnClosed(t, peer, 2*timeoutDuration)
|
||||||
|
}
|
||||||
|
|
||||||
|
type stateUpdateTestCase struct {
|
||||||
|
name string
|
||||||
|
initMsg *wtwire.Init
|
||||||
|
createMsg *wtwire.CreateSession
|
||||||
|
updates []*wtwire.StateUpdate
|
||||||
|
replies []*wtwire.StateUpdateReply
|
||||||
|
}
|
||||||
|
|
||||||
|
var stateUpdateTests = []stateUpdateTestCase{
|
||||||
|
// Valid update sequence, send seqnum == lastapplied as last update.
|
||||||
|
{
|
||||||
|
name: "perm fail after sending seqnum equal lastapplied",
|
||||||
|
initMsg: &wtwire.Init{&lnwire.Init{
|
||||||
|
LocalFeatures: lnwire.NewRawFeatureVector(),
|
||||||
|
GlobalFeatures: lnwire.NewRawFeatureVector(),
|
||||||
|
}},
|
||||||
|
createMsg: &wtwire.CreateSession{
|
||||||
|
BlobVersion: 0,
|
||||||
|
MaxUpdates: 3,
|
||||||
|
RewardRate: 0,
|
||||||
|
SweepFeeRate: 1,
|
||||||
|
},
|
||||||
|
updates: []*wtwire.StateUpdate{
|
||||||
|
{SeqNum: 1, LastApplied: 0},
|
||||||
|
{SeqNum: 2, LastApplied: 1},
|
||||||
|
{SeqNum: 3, LastApplied: 2},
|
||||||
|
{SeqNum: 3, LastApplied: 3},
|
||||||
|
},
|
||||||
|
replies: []*wtwire.StateUpdateReply{
|
||||||
|
{Code: wtwire.CodeOK, LastApplied: 1},
|
||||||
|
{Code: wtwire.CodeOK, LastApplied: 2},
|
||||||
|
{Code: wtwire.CodeOK, LastApplied: 3},
|
||||||
|
{
|
||||||
|
Code: wtwire.CodePermanentFailure,
|
||||||
|
LastApplied: 3,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// Send update that skips next expected sequence number.
|
||||||
|
{
|
||||||
|
name: "skip sequence number",
|
||||||
|
initMsg: &wtwire.Init{&lnwire.Init{
|
||||||
|
LocalFeatures: lnwire.NewRawFeatureVector(),
|
||||||
|
GlobalFeatures: lnwire.NewRawFeatureVector(),
|
||||||
|
}},
|
||||||
|
createMsg: &wtwire.CreateSession{
|
||||||
|
BlobVersion: 0,
|
||||||
|
MaxUpdates: 4,
|
||||||
|
RewardRate: 0,
|
||||||
|
SweepFeeRate: 1,
|
||||||
|
},
|
||||||
|
updates: []*wtwire.StateUpdate{
|
||||||
|
{SeqNum: 2, LastApplied: 0},
|
||||||
|
},
|
||||||
|
replies: []*wtwire.StateUpdateReply{
|
||||||
|
{
|
||||||
|
Code: wtwire.StateUpdateCodeSeqNumOutOfOrder,
|
||||||
|
LastApplied: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// Send update that reverts to older sequence number.
|
||||||
|
{
|
||||||
|
name: "revert to older seqnum",
|
||||||
|
initMsg: &wtwire.Init{&lnwire.Init{
|
||||||
|
LocalFeatures: lnwire.NewRawFeatureVector(),
|
||||||
|
GlobalFeatures: lnwire.NewRawFeatureVector(),
|
||||||
|
}},
|
||||||
|
createMsg: &wtwire.CreateSession{
|
||||||
|
BlobVersion: 0,
|
||||||
|
MaxUpdates: 4,
|
||||||
|
RewardRate: 0,
|
||||||
|
SweepFeeRate: 1,
|
||||||
|
},
|
||||||
|
updates: []*wtwire.StateUpdate{
|
||||||
|
{SeqNum: 1, LastApplied: 0},
|
||||||
|
{SeqNum: 2, LastApplied: 0},
|
||||||
|
{SeqNum: 1, LastApplied: 0},
|
||||||
|
},
|
||||||
|
replies: []*wtwire.StateUpdateReply{
|
||||||
|
{Code: wtwire.CodeOK, LastApplied: 1},
|
||||||
|
{Code: wtwire.CodeOK, LastApplied: 2},
|
||||||
|
{
|
||||||
|
Code: wtwire.StateUpdateCodeSeqNumOutOfOrder,
|
||||||
|
LastApplied: 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// Send update echoing a last applied that is lower than previous value.
|
||||||
|
{
|
||||||
|
name: "revert to older lastapplied",
|
||||||
|
initMsg: &wtwire.Init{&lnwire.Init{
|
||||||
|
LocalFeatures: lnwire.NewRawFeatureVector(),
|
||||||
|
GlobalFeatures: lnwire.NewRawFeatureVector(),
|
||||||
|
}},
|
||||||
|
createMsg: &wtwire.CreateSession{
|
||||||
|
BlobVersion: 0,
|
||||||
|
MaxUpdates: 4,
|
||||||
|
RewardRate: 0,
|
||||||
|
SweepFeeRate: 1,
|
||||||
|
},
|
||||||
|
updates: []*wtwire.StateUpdate{
|
||||||
|
{SeqNum: 1, LastApplied: 0},
|
||||||
|
{SeqNum: 2, LastApplied: 1},
|
||||||
|
{SeqNum: 3, LastApplied: 2},
|
||||||
|
{SeqNum: 4, LastApplied: 1},
|
||||||
|
},
|
||||||
|
replies: []*wtwire.StateUpdateReply{
|
||||||
|
{Code: wtwire.CodeOK, LastApplied: 1},
|
||||||
|
{Code: wtwire.CodeOK, LastApplied: 2},
|
||||||
|
{Code: wtwire.CodeOK, LastApplied: 3},
|
||||||
|
{Code: wtwire.StateUpdateCodeClientBehind, LastApplied: 3},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// Valid update sequence with disconnection, ensure resumes resume.
|
||||||
|
// Client echos last applied as they are received.
|
||||||
|
{
|
||||||
|
name: "resume after disconnect",
|
||||||
|
initMsg: &wtwire.Init{&lnwire.Init{
|
||||||
|
LocalFeatures: lnwire.NewRawFeatureVector(),
|
||||||
|
GlobalFeatures: lnwire.NewRawFeatureVector(),
|
||||||
|
}},
|
||||||
|
createMsg: &wtwire.CreateSession{
|
||||||
|
BlobVersion: 0,
|
||||||
|
MaxUpdates: 4,
|
||||||
|
RewardRate: 0,
|
||||||
|
SweepFeeRate: 1,
|
||||||
|
},
|
||||||
|
updates: []*wtwire.StateUpdate{
|
||||||
|
{SeqNum: 1, LastApplied: 0},
|
||||||
|
{SeqNum: 2, LastApplied: 1},
|
||||||
|
nil, // Wait for read timeout to drop conn, then reconnect.
|
||||||
|
{SeqNum: 3, LastApplied: 2},
|
||||||
|
{SeqNum: 4, LastApplied: 3},
|
||||||
|
},
|
||||||
|
replies: []*wtwire.StateUpdateReply{
|
||||||
|
{Code: wtwire.CodeOK, LastApplied: 1},
|
||||||
|
{Code: wtwire.CodeOK, LastApplied: 2},
|
||||||
|
nil,
|
||||||
|
{Code: wtwire.CodeOK, LastApplied: 3},
|
||||||
|
{Code: wtwire.CodeOK, LastApplied: 4},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// Valid update sequence with disconnection, ensure resumes resume.
|
||||||
|
// Client doesn't echo last applied until last message.
|
||||||
|
{
|
||||||
|
name: "resume after disconnect lagging lastapplied",
|
||||||
|
initMsg: &wtwire.Init{&lnwire.Init{
|
||||||
|
LocalFeatures: lnwire.NewRawFeatureVector(),
|
||||||
|
GlobalFeatures: lnwire.NewRawFeatureVector(),
|
||||||
|
}},
|
||||||
|
createMsg: &wtwire.CreateSession{
|
||||||
|
BlobVersion: 0,
|
||||||
|
MaxUpdates: 4,
|
||||||
|
RewardRate: 0,
|
||||||
|
SweepFeeRate: 1,
|
||||||
|
},
|
||||||
|
updates: []*wtwire.StateUpdate{
|
||||||
|
{SeqNum: 1, LastApplied: 0},
|
||||||
|
{SeqNum: 2, LastApplied: 0},
|
||||||
|
nil, // Wait for read timeout to drop conn, then reconnect.
|
||||||
|
{SeqNum: 3, LastApplied: 0},
|
||||||
|
{SeqNum: 4, LastApplied: 3},
|
||||||
|
},
|
||||||
|
replies: []*wtwire.StateUpdateReply{
|
||||||
|
{Code: wtwire.CodeOK, LastApplied: 1},
|
||||||
|
{Code: wtwire.CodeOK, LastApplied: 2},
|
||||||
|
nil,
|
||||||
|
{Code: wtwire.CodeOK, LastApplied: 3},
|
||||||
|
{Code: wtwire.CodeOK, LastApplied: 4},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// Send update with sequence number that exceeds MaxUpdates.
|
||||||
|
{
|
||||||
|
name: "seqnum exceed maxupdates",
|
||||||
|
initMsg: &wtwire.Init{&lnwire.Init{
|
||||||
|
LocalFeatures: lnwire.NewRawFeatureVector(),
|
||||||
|
GlobalFeatures: lnwire.NewRawFeatureVector(),
|
||||||
|
}},
|
||||||
|
createMsg: &wtwire.CreateSession{
|
||||||
|
BlobVersion: 0,
|
||||||
|
MaxUpdates: 3,
|
||||||
|
RewardRate: 0,
|
||||||
|
SweepFeeRate: 1,
|
||||||
|
},
|
||||||
|
updates: []*wtwire.StateUpdate{
|
||||||
|
{SeqNum: 1, LastApplied: 0},
|
||||||
|
{SeqNum: 2, LastApplied: 1},
|
||||||
|
{SeqNum: 3, LastApplied: 2},
|
||||||
|
{SeqNum: 4, LastApplied: 3},
|
||||||
|
},
|
||||||
|
replies: []*wtwire.StateUpdateReply{
|
||||||
|
{Code: wtwire.CodeOK, LastApplied: 1},
|
||||||
|
{Code: wtwire.CodeOK, LastApplied: 2},
|
||||||
|
{Code: wtwire.CodeOK, LastApplied: 3},
|
||||||
|
{
|
||||||
|
Code: wtwire.StateUpdateCodeMaxUpdatesExceeded,
|
||||||
|
LastApplied: 3,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// Ensure sequence number 0 causes permanent failure.
|
||||||
|
{
|
||||||
|
name: "perm fail after seqnum 0",
|
||||||
|
initMsg: &wtwire.Init{&lnwire.Init{
|
||||||
|
LocalFeatures: lnwire.NewRawFeatureVector(),
|
||||||
|
GlobalFeatures: lnwire.NewRawFeatureVector(),
|
||||||
|
}},
|
||||||
|
createMsg: &wtwire.CreateSession{
|
||||||
|
BlobVersion: 0,
|
||||||
|
MaxUpdates: 3,
|
||||||
|
RewardRate: 0,
|
||||||
|
SweepFeeRate: 1,
|
||||||
|
},
|
||||||
|
updates: []*wtwire.StateUpdate{
|
||||||
|
{SeqNum: 0, LastApplied: 0},
|
||||||
|
},
|
||||||
|
replies: []*wtwire.StateUpdateReply{
|
||||||
|
{
|
||||||
|
Code: wtwire.CodePermanentFailure,
|
||||||
|
LastApplied: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServerStateUpdates tests the behavior of the server in response to
|
||||||
|
// watchtower clients sending StateUpdate messages, after having already
|
||||||
|
// established an open session. The test asserts that the server responds
|
||||||
|
// with the appropriate failure codes in a number of failure conditions where
|
||||||
|
// the server and client desynchronize. It also checks the ability of the client
|
||||||
|
// to disconnect, connect, and continue updating from the last successful state
|
||||||
|
// update.
|
||||||
|
func TestServerStateUpdates(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
for i, test := range stateUpdateTests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
testServerStateUpdates(t, i, test)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testServerStateUpdates(t *testing.T, i int, test stateUpdateTestCase) {
|
||||||
|
const timeoutDuration = 100 * time.Millisecond
|
||||||
|
|
||||||
|
s := initServer(t, nil, timeoutDuration)
|
||||||
|
defer s.Stop()
|
||||||
|
|
||||||
|
// Create a new client and connect to the server.
|
||||||
|
peerPub := randPubKey(t)
|
||||||
|
peer := server.NewMockPeer(peerPub, nil, 0)
|
||||||
|
connect(t, i, s, peer, test.initMsg, timeoutDuration)
|
||||||
|
|
||||||
|
// Register a session for this client to use in the subsequent tests.
|
||||||
|
sendMsg(t, i, test.createMsg, peer, timeoutDuration)
|
||||||
|
initReply := recvReply(
|
||||||
|
t, i, "CreateSessionReply", peer, timeoutDuration,
|
||||||
|
).(*wtwire.CreateSessionReply)
|
||||||
|
|
||||||
|
// Fail if the server rejected our proposed CreateSession message.
|
||||||
|
if initReply.Code != wtwire.CodeOK {
|
||||||
|
t.Fatalf("[test %d] server rejected session init", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that the server closed the connection used to register the
|
||||||
|
// session.
|
||||||
|
assertConnClosed(t, peer, 2*timeoutDuration)
|
||||||
|
|
||||||
|
// Now that the original connection has been closed, connect a new
|
||||||
|
// client with the same session id.
|
||||||
|
peer = server.NewMockPeer(peerPub, nil, 0)
|
||||||
|
connect(t, i, s, peer, test.initMsg, timeoutDuration)
|
||||||
|
|
||||||
|
// Send the intended StateUpdate messages in series.
|
||||||
|
for j, update := range test.updates {
|
||||||
|
// A nil update signals that we should wait for the prior
|
||||||
|
// connection to die, before re-register with the same session
|
||||||
|
// identifier.
|
||||||
|
if update == nil {
|
||||||
|
assertConnClosed(t, peer, 2*timeoutDuration)
|
||||||
|
|
||||||
|
peer = server.NewMockPeer(peerPub, nil, 0)
|
||||||
|
connect(t, i, s, peer, test.initMsg, timeoutDuration)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the state update and verify it against our expected
|
||||||
|
// response.
|
||||||
|
sendMsg(t, i, update, peer, timeoutDuration)
|
||||||
|
reply := recvReply(
|
||||||
|
t, i, "StateUpdateReply", peer, timeoutDuration,
|
||||||
|
).(*wtwire.StateUpdateReply)
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(reply, test.replies[j]) {
|
||||||
|
t.Fatalf("[test %d, update %d] expected reply "+
|
||||||
|
"%v, got %d", i, j,
|
||||||
|
test.replies[j], reply)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that the final connection is properly cleaned up by the server.
|
||||||
|
assertConnClosed(t, peer, 2*timeoutDuration)
|
||||||
|
}
|
||||||
|
|
||||||
|
func connect(t *testing.T, i int, s server.Interface, peer *server.MockPeer,
|
||||||
|
initMsg *wtwire.Init, timeout time.Duration) {
|
||||||
|
|
||||||
|
s.InboundPeerConnected(peer)
|
||||||
|
sendMsg(t, i, initMsg, peer, timeout)
|
||||||
|
recvReply(t, i, "Init", peer, timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendMsg sends a wtwire.Message message via a server.MockPeer.
|
||||||
|
func sendMsg(t *testing.T, i int, msg wtwire.Message,
|
||||||
|
peer *server.MockPeer, timeout time.Duration) {
|
||||||
|
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
_, err := wtwire.WriteMessage(&b, msg, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("[test %d] unable to encode %T message: %v",
|
||||||
|
i, msg, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case peer.IncomingMsgs <- b.Bytes():
|
||||||
|
case <-time.After(2 * timeout):
|
||||||
|
t.Fatalf("[test %d] unable to send %T message", i, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// recvReply receives a message from the server, and parses it according to
|
||||||
|
// expected reply type. The supported replies are CreateSessionReply and
|
||||||
|
// StateUpdateReply.
|
||||||
|
func recvReply(t *testing.T, i int, name string,
|
||||||
|
peer *server.MockPeer, timeout time.Duration) wtwire.Message {
|
||||||
|
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var (
|
||||||
|
msg wtwire.Message
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case b := <-peer.OutgoingMsgs:
|
||||||
|
msg, err = wtwire.ReadMessage(bytes.NewReader(b), 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("[test %d] unable to decode server "+
|
||||||
|
"reply: %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-time.After(2 * timeout):
|
||||||
|
t.Fatalf("[test %d] server did not reply", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch name {
|
||||||
|
case "Init":
|
||||||
|
if _, ok := msg.(*wtwire.Init); !ok {
|
||||||
|
t.Fatalf("[test %d] expected %s reply "+
|
||||||
|
"message, got %T", i, name, msg)
|
||||||
|
}
|
||||||
|
case "CreateSessionReply":
|
||||||
|
if _, ok := msg.(*wtwire.CreateSessionReply); !ok {
|
||||||
|
t.Fatalf("[test %d] expected %s reply "+
|
||||||
|
"message, got %T", i, name, msg)
|
||||||
|
}
|
||||||
|
case "StateUpdateReply":
|
||||||
|
if _, ok := msg.(*wtwire.StateUpdateReply); !ok {
|
||||||
|
t.Fatalf("[test %d] expected %s reply "+
|
||||||
|
"message, got %T", i, name, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
// assertConnClosed checks that the peer's connection is closed before the
|
||||||
|
// timeout expires.
|
||||||
|
func assertConnClosed(t *testing.T, peer *server.MockPeer, duration time.Duration) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-peer.Quit:
|
||||||
|
case <-time.After(duration):
|
||||||
|
t.Fatalf("expected connection to be closed")
|
||||||
|
}
|
||||||
|
}
|
27
watchtower/wtdb/breach_hint.go
Normal file
27
watchtower/wtdb/breach_hint.go
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
package wtdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BreachHintSize is the length of the txid prefix used to identify remote
|
||||||
|
// commitment broadcasts.
|
||||||
|
const BreachHintSize = 16
|
||||||
|
|
||||||
|
// BreachHint is the first 16-bytes of the txid belonging to a revoked
|
||||||
|
// commitment transaction.
|
||||||
|
type BreachHint [BreachHintSize]byte
|
||||||
|
|
||||||
|
// NewBreachHintFromHash creates a breach hint from a transaction ID.
|
||||||
|
func NewBreachHintFromHash(hash *chainhash.Hash) BreachHint {
|
||||||
|
var hint BreachHint
|
||||||
|
copy(hint[:], hash[:BreachHintSize])
|
||||||
|
return hint
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns a hex encoding of the breach hint.
|
||||||
|
func (h BreachHint) String() string {
|
||||||
|
return hex.EncodeToString(h[:])
|
||||||
|
}
|
57
watchtower/wtdb/mock.go
Normal file
57
watchtower/wtdb/mock.go
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
// +build dev
|
||||||
|
|
||||||
|
package wtdb
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
|
type MockDB struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
sessions map[SessionID]*SessionInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMockDB() *MockDB {
|
||||||
|
return &MockDB{
|
||||||
|
sessions: make(map[SessionID]*SessionInfo),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *MockDB) InsertStateUpdate(update *SessionStateUpdate) (uint16, error) {
|
||||||
|
db.mu.Lock()
|
||||||
|
defer db.mu.Unlock()
|
||||||
|
|
||||||
|
info, ok := db.sessions[update.ID]
|
||||||
|
if !ok {
|
||||||
|
return 0, ErrSessionNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
err := info.AcceptUpdateSequence(update.SeqNum, update.LastApplied)
|
||||||
|
if err != nil {
|
||||||
|
return info.LastApplied, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return info.LastApplied, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *MockDB) GetSessionInfo(id *SessionID) (*SessionInfo, error) {
|
||||||
|
db.mu.Lock()
|
||||||
|
defer db.mu.Unlock()
|
||||||
|
|
||||||
|
if info, ok := db.sessions[*id]; ok {
|
||||||
|
return info, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, ErrSessionNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *MockDB) InsertSessionInfo(info *SessionInfo) error {
|
||||||
|
db.mu.Lock()
|
||||||
|
defer db.mu.Unlock()
|
||||||
|
|
||||||
|
if _, ok := db.sessions[info.ID]; ok {
|
||||||
|
return ErrSessionAlreadyExists
|
||||||
|
}
|
||||||
|
|
||||||
|
db.sessions[info.ID] = info
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
26
watchtower/wtdb/session_id.go
Normal file
26
watchtower/wtdb/session_id.go
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
package wtdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/btcec"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SessionIDSize is 33-bytes; it is a serialized, compressed public key.
|
||||||
|
const SessionIDSize = 33
|
||||||
|
|
||||||
|
// SessionID is created from the remote public key of a client, and serves as a
|
||||||
|
// unique identifier and authentication for sending state updates.
|
||||||
|
type SessionID [SessionIDSize]byte
|
||||||
|
|
||||||
|
// NewSessionIDFromPubKey creates a new SessionID from a public key.
|
||||||
|
func NewSessionIDFromPubKey(pubKey *btcec.PublicKey) SessionID {
|
||||||
|
var sid SessionID
|
||||||
|
copy(sid[:], pubKey.SerializeCompressed())
|
||||||
|
return sid
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns a hex encoding of the session id.
|
||||||
|
func (s SessionID) String() string {
|
||||||
|
return hex.EncodeToString(s[:])
|
||||||
|
}
|
105
watchtower/wtdb/session_info.go
Normal file
105
watchtower/wtdb/session_info.go
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
package wtdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/lnwallet"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrSessionNotFound is returned when querying by session id for a
|
||||||
|
// session that does not exist.
|
||||||
|
ErrSessionNotFound = errors.New("session not found in db")
|
||||||
|
|
||||||
|
// ErrSessionAlreadyExists signals that a session creation failed
|
||||||
|
// because a session with the same session id already exists.
|
||||||
|
ErrSessionAlreadyExists = errors.New("session already exists")
|
||||||
|
|
||||||
|
// ErrUpdateOutOfOrder is returned when the sequence number is not equal
|
||||||
|
// to the server's LastApplied+1.
|
||||||
|
ErrUpdateOutOfOrder = errors.New("update sequence number is not " +
|
||||||
|
"sequential")
|
||||||
|
|
||||||
|
// ErrLastAppliedReversion is returned when the client echos a
|
||||||
|
// last-applied value that is less than it claimed in a prior update.
|
||||||
|
ErrLastAppliedReversion = errors.New("update last applied must be " +
|
||||||
|
"non-decreasing")
|
||||||
|
|
||||||
|
// ErrSeqNumAlreadyApplied is returned when the client sends a sequence
|
||||||
|
// number for which they already claim to have an ACK.
|
||||||
|
ErrSeqNumAlreadyApplied = errors.New("update sequence number has " +
|
||||||
|
"already been applied")
|
||||||
|
|
||||||
|
// ErrSessionConsumed is returned if the client tries to send a sequence
|
||||||
|
// number larger than the session's max number of updates.
|
||||||
|
ErrSessionConsumed = errors.New("all session updates have been " +
|
||||||
|
"consumed")
|
||||||
|
)
|
||||||
|
|
||||||
|
// SessionInfo holds the negotiated session parameters for single session id,
|
||||||
|
// and handles the acceptance and validation of state updates sent by the
|
||||||
|
// client.
|
||||||
|
type SessionInfo struct {
|
||||||
|
// ID is the remote public key of the watchtower client.
|
||||||
|
ID SessionID
|
||||||
|
|
||||||
|
// Version specifies the plaintext blob encoding of all state updates.
|
||||||
|
Version uint16
|
||||||
|
|
||||||
|
// MaxUpdates is the total number of updates the client can send for
|
||||||
|
// this session.
|
||||||
|
MaxUpdates uint16
|
||||||
|
|
||||||
|
// LastApplied the sequence number of the last successful state update.
|
||||||
|
LastApplied uint16
|
||||||
|
|
||||||
|
// ClientLastApplied the last last-applied the client has echoed back.
|
||||||
|
ClientLastApplied uint16
|
||||||
|
|
||||||
|
// RewardRate the fraction of the swept amount that goes to the tower,
|
||||||
|
// expressed in millionths of the swept balance.
|
||||||
|
RewardRate uint32
|
||||||
|
|
||||||
|
// SweepFeeRate is the agreed upon fee rate used to sign any sweep
|
||||||
|
// transactions.
|
||||||
|
SweepFeeRate lnwallet.SatPerKWeight
|
||||||
|
|
||||||
|
// RewardAddress the address that the tower's reward will be deposited
|
||||||
|
// to if a sweep transaction confirms.
|
||||||
|
RewardAddress []byte
|
||||||
|
|
||||||
|
// TODO(conner): store client metrics, DOS score, etc
|
||||||
|
}
|
||||||
|
|
||||||
|
// AcceptUpdateSequence validates that a state update's sequence number and last
|
||||||
|
// applied are valid given our past history with the client. These checks ensure
|
||||||
|
// that clients are properly in sync and following the update protocol properly.
|
||||||
|
// If validation is successful, the receiver's LastApplied and ClientLastApplied
|
||||||
|
// are updated with the latest values presented by the client. Any errors
|
||||||
|
// returned from this method are converted into an appropriate
|
||||||
|
// wtwire.StateUpdateCode.
|
||||||
|
func (s *SessionInfo) AcceptUpdateSequence(seqNum, lastApplied uint16) error {
|
||||||
|
switch {
|
||||||
|
|
||||||
|
// Client already claims to have an ACK for this seqnum.
|
||||||
|
case seqNum <= lastApplied:
|
||||||
|
return ErrSeqNumAlreadyApplied
|
||||||
|
|
||||||
|
// Client echos a last applied that is lower than previously sent.
|
||||||
|
case lastApplied < s.ClientLastApplied:
|
||||||
|
return ErrLastAppliedReversion
|
||||||
|
|
||||||
|
// Client update exceeds capacity of session.
|
||||||
|
case seqNum > s.MaxUpdates:
|
||||||
|
return ErrSessionConsumed
|
||||||
|
|
||||||
|
// Client update does not match our expected next seqnum.
|
||||||
|
case seqNum != s.LastApplied+1:
|
||||||
|
return ErrUpdateOutOfOrder
|
||||||
|
}
|
||||||
|
|
||||||
|
s.LastApplied = seqNum
|
||||||
|
s.ClientLastApplied = lastApplied
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
23
watchtower/wtdb/session_state_update.go
Normal file
23
watchtower/wtdb/session_state_update.go
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
package wtdb
|
||||||
|
|
||||||
|
// SessionStateUpdate holds a state update sent by a client along with its
|
||||||
|
// SessionID.
|
||||||
|
type SessionStateUpdate struct {
|
||||||
|
// ID the session id of the client who sent the state update.
|
||||||
|
ID SessionID
|
||||||
|
|
||||||
|
// SeqNum the sequence number of the update within the session.
|
||||||
|
SeqNum uint16
|
||||||
|
|
||||||
|
// LastApplied the highest index that client has acknowledged is
|
||||||
|
// committed
|
||||||
|
LastApplied uint16
|
||||||
|
|
||||||
|
// Hint is the 16-byte prefix of the revoked commitment transaction.
|
||||||
|
Hint BreachHint
|
||||||
|
|
||||||
|
// EncryptedBlob is a ciphertext containing the sweep information for
|
||||||
|
// exacting justice if the commitment transaction matching the breach
|
||||||
|
// hint is braodcast.
|
||||||
|
EncryptedBlob []byte
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user