Merge pull request #3133 from cfromknecht/wt-polish

watchtower: integrate altruist watchtower and watchtower client
This commit is contained in:
Olaoluwa Osuntokun 2019-06-14 21:34:10 +02:00 committed by GitHub
commit a53323205c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
50 changed files with 1562 additions and 486 deletions

@ -32,6 +32,7 @@ import (
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing" "github.com/lightningnetwork/lnd/routing"
"github.com/lightningnetwork/lnd/tor" "github.com/lightningnetwork/lnd/tor"
"github.com/lightningnetwork/lnd/watchtower"
) )
const ( const (
@ -39,6 +40,7 @@ const (
defaultDataDirname = "data" defaultDataDirname = "data"
defaultChainSubDirname = "chain" defaultChainSubDirname = "chain"
defaultGraphSubDirname = "graph" defaultGraphSubDirname = "graph"
defaultTowerSubDirname = "watchtower"
defaultTLSCertFilename = "tls.cert" defaultTLSCertFilename = "tls.cert"
defaultTLSKeyFilename = "tls.key" defaultTLSKeyFilename = "tls.key"
defaultAdminMacFilename = "admin.macaroon" defaultAdminMacFilename = "admin.macaroon"
@ -132,6 +134,8 @@ var (
defaultDataDir = filepath.Join(defaultLndDir, defaultDataDirname) defaultDataDir = filepath.Join(defaultLndDir, defaultDataDirname)
defaultLogDir = filepath.Join(defaultLndDir, defaultLogDirname) defaultLogDir = filepath.Join(defaultLndDir, defaultLogDirname)
defaultTowerDir = filepath.Join(defaultDataDir, defaultTowerSubDirname)
defaultTLSCertPath = filepath.Join(defaultLndDir, defaultTLSCertFilename) defaultTLSCertPath = filepath.Join(defaultLndDir, defaultTLSCertFilename)
defaultTLSKeyPath = filepath.Join(defaultLndDir, defaultTLSKeyFilename) defaultTLSKeyPath = filepath.Join(defaultLndDir, defaultTLSKeyFilename)
@ -315,6 +319,10 @@ type config struct {
Caches *lncfg.Caches `group:"caches" namespace:"caches"` Caches *lncfg.Caches `group:"caches" namespace:"caches"`
Prometheus lncfg.Prometheus `group:"prometheus" namespace:"prometheus"` Prometheus lncfg.Prometheus `group:"prometheus" namespace:"prometheus"`
WtClient *lncfg.WtClient `group:"wtclient" namespace:"wtclient"`
Watchtower *lncfg.Watchtower `group:"watchtower" namespace:"watchtower"`
} }
// loadConfig initializes and parses the config using a config file and command // loadConfig initializes and parses the config using a config file and command
@ -410,6 +418,9 @@ func loadConfig() (*config, error) {
ChannelCacheSize: channeldb.DefaultChannelCacheSize, ChannelCacheSize: channeldb.DefaultChannelCacheSize,
}, },
Prometheus: lncfg.DefaultPrometheus(), Prometheus: lncfg.DefaultPrometheus(),
Watchtower: &lncfg.Watchtower{
TowerDir: defaultTowerDir,
},
} }
// Pre-parse the command line options to pick up an alternative config // Pre-parse the command line options to pick up an alternative config
@ -470,6 +481,14 @@ func loadConfig() (*config, error) {
cfg.TLSCertPath = filepath.Join(lndDir, defaultTLSCertFilename) cfg.TLSCertPath = filepath.Join(lndDir, defaultTLSCertFilename)
cfg.TLSKeyPath = filepath.Join(lndDir, defaultTLSKeyFilename) cfg.TLSKeyPath = filepath.Join(lndDir, defaultTLSKeyFilename)
cfg.LogDir = filepath.Join(lndDir, defaultLogDirname) cfg.LogDir = filepath.Join(lndDir, defaultLogDirname)
// If the watchtower's directory is set to the default, i.e. the
// user has not requested a different location, we'll move the
// location to be relative to the specified lnd directory.
if cfg.Watchtower.TowerDir == defaultTowerDir {
cfg.Watchtower.TowerDir =
filepath.Join(cfg.DataDir, defaultTowerSubDirname)
}
} }
// Create the lnd directory if it doesn't already exist. // Create the lnd directory if it doesn't already exist.
@ -506,6 +525,7 @@ func loadConfig() (*config, error) {
cfg.BitcoindMode.Dir = cleanAndExpandPath(cfg.BitcoindMode.Dir) cfg.BitcoindMode.Dir = cleanAndExpandPath(cfg.BitcoindMode.Dir)
cfg.LitecoindMode.Dir = cleanAndExpandPath(cfg.LitecoindMode.Dir) cfg.LitecoindMode.Dir = cleanAndExpandPath(cfg.LitecoindMode.Dir)
cfg.Tor.PrivateKeyPath = cleanAndExpandPath(cfg.Tor.PrivateKeyPath) cfg.Tor.PrivateKeyPath = cleanAndExpandPath(cfg.Tor.PrivateKeyPath)
cfg.Watchtower.TowerDir = cleanAndExpandPath(cfg.Watchtower.TowerDir)
// Ensure that the user didn't attempt to specify negative values for // Ensure that the user didn't attempt to specify negative values for
// any of the autopilot params. // any of the autopilot params.
@ -1051,15 +1071,27 @@ func loadConfig() (*config, error) {
"minbackoff") "minbackoff")
} }
// Validate the subconfigs for workers and caches. // Validate the subconfigs for workers, caches, and the tower client.
err = lncfg.Validate( err = lncfg.Validate(
cfg.Workers, cfg.Workers,
cfg.Caches, cfg.Caches,
cfg.WtClient,
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
// If the user provided private watchtower addresses, parse them to
// obtain the LN addresses.
if cfg.WtClient.IsActive() {
err := cfg.WtClient.ParsePrivateTowers(
watchtower.DefaultPeerPort, cfg.net.ResolveTCPAddr,
)
if err != nil {
return nil, err
}
}
// Finally, ensure that the user's color is correctly formatted, // Finally, ensure that the user's color is correctly formatted,
// otherwise the server will not be able to start after the unlocking // otherwise the server will not be able to start after the unlocking
// the wallet. // the wallet.

@ -6,6 +6,7 @@ import (
"github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/invoices"
"github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lnpeer"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
) )
@ -159,3 +160,20 @@ type ForwardingLog interface {
// visualizations, etc. // visualizations, etc.
AddForwardingEvents([]channeldb.ForwardingEvent) error AddForwardingEvents([]channeldb.ForwardingEvent) error
} }
// TowerClient is the primary interface used by the daemon to backup pre-signed
// justice transactions to watchtowers.
type TowerClient interface {
// RegisterChannel persistently initializes any channel-dependent
// parameters within the client. This should be called during link
// startup to ensure that the client is able to support the link during
// operation.
RegisterChannel(lnwire.ChannelID) error
// BackupState initiates a request to back up a particular revoked
// state. If the method returns nil, the backup is guaranteed to be
// successful unless the tower is unavailable and client is force quit,
// or the justice transaction would create dust outputs when trying to
// abide by the negotiated policy.
BackupState(*lnwire.ChannelID, *lnwallet.BreachRetribution) error
}

@ -229,10 +229,14 @@ type ChannelLinkConfig struct {
// receiving node is persistent. // receiving node is persistent.
UnsafeReplay bool UnsafeReplay bool
// MinFeeUpdateTimeout and MaxFeeUpdateTimeout represent the timeout // MinFeeUpdateTimeout represents the minimum interval in which a link
// interval bounds in which a link will propose to update its commitment // will propose to update its commitment fee rate. A random timeout will
// fee rate. A random timeout will be selected between these values. // be selected between this and MaxFeeUpdateTimeout.
MinFeeUpdateTimeout time.Duration MinFeeUpdateTimeout time.Duration
// MaxFeeUpdateTimeout represents the maximum interval in which a link
// will propose to update its commitment fee rate. A random timeout will
// be selected between this and MinFeeUpdateTimeout.
MaxFeeUpdateTimeout time.Duration MaxFeeUpdateTimeout time.Duration
// OutgoingCltvRejectDelta defines the number of blocks before expiry of // OutgoingCltvRejectDelta defines the number of blocks before expiry of
@ -240,6 +244,11 @@ type ChannelLinkConfig struct {
// the outgoing broadcast delta, because in any case we don't want to // the outgoing broadcast delta, because in any case we don't want to
// risk offering an htlc that triggers channel closure. // risk offering an htlc that triggers channel closure.
OutgoingCltvRejectDelta uint32 OutgoingCltvRejectDelta uint32
// TowerClient is an optional engine that manages the signing,
// encrypting, and uploading of justice transactions to the daemon's
// configured set of watchtowers.
TowerClient TowerClient
} }
// channelLink is the service which drives a channel's commitment update // channelLink is the service which drives a channel's commitment update
@ -396,6 +405,15 @@ func (l *channelLink) Start() error {
log.Infof("ChannelLink(%v) is starting", l) log.Infof("ChannelLink(%v) is starting", l)
// If the config supplied watchtower client, ensure the channel is
// registered before trying to use it during operation.
if l.cfg.TowerClient != nil {
err := l.cfg.TowerClient.RegisterChannel(l.ChanID())
if err != nil {
return err
}
}
l.mailBox.ResetMessages() l.mailBox.ResetMessages()
l.overflowQueue.Start() l.overflowQueue.Start()
l.hodlQueue.Start() l.hodlQueue.Start()
@ -1786,6 +1804,28 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) {
return return
} }
// If we have a tower client, we'll proceed in backing up the
// state that was just revoked.
if l.cfg.TowerClient != nil {
state := l.channel.State()
breachInfo, err := lnwallet.NewBreachRetribution(
state, state.RemoteCommitment.CommitHeight-1, 0,
)
if err != nil {
l.fail(LinkFailureError{code: ErrInternalError},
"failed to load breach info: %v", err)
return
}
chanID := l.ChanID()
err = l.cfg.TowerClient.BackupState(&chanID, breachInfo)
if err != nil {
l.fail(LinkFailureError{code: ErrInternalError},
"unable to queue breach backup: %v", err)
return
}
}
l.processRemoteSettleFails(fwdPkg, settleFails) l.processRemoteSettleFails(fwdPkg, settleFails)
needUpdate := l.processRemoteAdds(fwdPkg, adds) needUpdate := l.processRemoteAdds(fwdPkg, adds)

@ -18,12 +18,14 @@ var (
loopBackAddrs = []string{"localhost", "127.0.0.1", "[::1]"} loopBackAddrs = []string{"localhost", "127.0.0.1", "[::1]"}
) )
type tcpResolver = func(network, addr string) (*net.TCPAddr, error) // TCPResolver is a function signature that resolves an address on a given
// network.
type TCPResolver = func(network, addr string) (*net.TCPAddr, error)
// NormalizeAddresses returns a new slice with all the passed addresses // NormalizeAddresses returns a new slice with all the passed addresses
// normalized with the given default port and all duplicates removed. // normalized with the given default port and all duplicates removed.
func NormalizeAddresses(addrs []string, defaultPort string, func NormalizeAddresses(addrs []string, defaultPort string,
tcpResolver tcpResolver) ([]net.Addr, error) { tcpResolver TCPResolver) ([]net.Addr, error) {
result := make([]net.Addr, 0, len(addrs)) result := make([]net.Addr, 0, len(addrs))
seen := map[string]struct{}{} seen := map[string]struct{}{}
@ -120,7 +122,7 @@ func IsUnix(addr net.Addr) bool {
// connections. We accept a custom function to resolve any TCP addresses so // connections. We accept a custom function to resolve any TCP addresses so
// that caller is able control exactly how resolution is performed. // that caller is able control exactly how resolution is performed.
func ParseAddressString(strAddress string, defaultPort string, func ParseAddressString(strAddress string, defaultPort string,
tcpResolver tcpResolver) (net.Addr, error) { tcpResolver TCPResolver) (net.Addr, error) {
var parsedNetwork, parsedAddr string var parsedNetwork, parsedAddr string
@ -188,9 +190,9 @@ func ParseAddressString(strAddress string, defaultPort string,
// 33-byte, compressed public key that lies on the secp256k1 curve. The <addr> // 33-byte, compressed public key that lies on the secp256k1 curve. The <addr>
// may be any address supported by ParseAddressString. If no port is specified, // may be any address supported by ParseAddressString. If no port is specified,
// the defaultPort will be used. Any tcp addresses that need resolving will be // the defaultPort will be used. Any tcp addresses that need resolving will be
// resolved using the custom tcpResolver. // resolved using the custom TCPResolver.
func ParseLNAddressString(strAddress string, defaultPort string, func ParseLNAddressString(strAddress string, defaultPort string,
tcpResolver tcpResolver) (*lnwire.NetAddress, error) { tcpResolver TCPResolver) (*lnwire.NetAddress, error) {
// Split the address string around the @ sign. // Split the address string around the @ sign.
parts := strings.Split(strAddress, "@") parts := strings.Split(strAddress, "@")

13
lncfg/watchtower.go Normal file

@ -0,0 +1,13 @@
package lncfg
import "github.com/lightningnetwork/lnd/watchtower"
// Watchtower holds the daemon specific configuration parameters for running a
// watchtower that shares resources with the daemon.
type Watchtower struct {
Active bool `long:"active" description:"If the watchtower should be active or not"`
TowerDir string `long:"towerdir" description:"Directory of the watchtower.db"`
watchtower.Conf
}

65
lncfg/wtclient.go Normal file

@ -0,0 +1,65 @@
package lncfg
import (
"fmt"
"strconv"
"github.com/lightningnetwork/lnd/lnwire"
)
// WtClient holds the configuration options for the daemon's watchtower client.
type WtClient struct {
// PrivateTowerURIs specifies the lightning URIs of the towers the
// watchtower client should send new backups to.
PrivateTowerURIs []string `long:"private-tower-uris" description:"Specifies the URIs of private watchtowers to use in backing up revoked states. URIs must be of the form <pubkey>@<addr>. Only 1 URI is supported at this time, if none are provided the tower will not be enabled."`
// PrivateTowers is the list of towers parsed from the URIs provided in
// PrivateTowerURIs.
PrivateTowers []*lnwire.NetAddress
// SweepFeeRate specifies the fee rate in sat/byte to be used when
// constructing justice transactions sent to the tower.
SweepFeeRate uint64 `long:"sweep-fee-rate" description:"Specifies the fee rate in sat/byte to be used when constructing justice transactions sent to the watchtower."`
}
// Validate asserts that at most 1 private watchtower is requested.
//
// NOTE: Part of the Validator interface.
func (c *WtClient) Validate() error {
if len(c.PrivateTowerURIs) > 1 {
return fmt.Errorf("at most 1 private watchtower is supported, "+
"found %d", len(c.PrivateTowerURIs))
}
return nil
}
// IsActive returns true if the watchtower client should be active.
func (c *WtClient) IsActive() bool {
return len(c.PrivateTowerURIs) > 0
}
// ParsePrivateTowers parses any private tower URIs held PrivateTowerURIs. The
// value of port should be the default port to use when a URI does not have one.
func (c *WtClient) ParsePrivateTowers(port int, resolver TCPResolver) error {
towers := make([]*lnwire.NetAddress, 0, len(c.PrivateTowerURIs))
for _, uri := range c.PrivateTowerURIs {
addr, err := ParseLNAddressString(
uri, strconv.Itoa(port), resolver,
)
if err != nil {
return fmt.Errorf("unable to parse private "+
"watchtower address: %v", err)
}
towers = append(towers, addr)
}
c.PrivateTowers = towers
return nil
}
// Compile-time constraint to ensure WtClient implements the Validator
// interface.
var _ Validator = (*WtClient)(nil)

69
lnd.go

@ -36,6 +36,7 @@ import (
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcutil"
"github.com/btcsuite/btcwallet/wallet" "github.com/btcsuite/btcwallet/wallet"
proxy "github.com/grpc-ecosystem/grpc-gateway/runtime" proxy "github.com/grpc-ecosystem/grpc-gateway/runtime"
"github.com/lightninglabs/neutrino" "github.com/lightninglabs/neutrino"
@ -51,6 +52,8 @@ import (
"github.com/lightningnetwork/lnd/macaroons" "github.com/lightningnetwork/lnd/macaroons"
"github.com/lightningnetwork/lnd/signal" "github.com/lightningnetwork/lnd/signal"
"github.com/lightningnetwork/lnd/walletunlocker" "github.com/lightningnetwork/lnd/walletunlocker"
"github.com/lightningnetwork/lnd/watchtower"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
) )
const ( const (
@ -313,11 +316,65 @@ func Main() error {
"is proxying over Tor as well", cfg.Tor.StreamIsolation) "is proxying over Tor as well", cfg.Tor.StreamIsolation)
} }
// If the watchtower client should be active, open the client database.
// This is done here so that Close always executes when lndMain returns.
var towerClientDB *wtdb.ClientDB
if cfg.WtClient.IsActive() {
var err error
towerClientDB, err = wtdb.OpenClientDB(graphDir)
if err != nil {
ltndLog.Errorf("Unable to open watchtower client db: %v", err)
}
defer towerClientDB.Close()
}
var tower *watchtower.Standalone
if cfg.Watchtower.Active {
// Segment the watchtower directory by chain and network.
towerDBDir := filepath.Join(
cfg.Watchtower.TowerDir,
registeredChains.PrimaryChain().String(),
normalizeNetwork(activeNetParams.Name),
)
towerDB, err := wtdb.OpenTowerDB(towerDBDir)
if err != nil {
ltndLog.Errorf("Unable to open watchtower db: %v", err)
return err
}
defer towerDB.Close()
wtConfig, err := cfg.Watchtower.Apply(&watchtower.Config{
BlockFetcher: activeChainControl.chainIO,
DB: towerDB,
EpochRegistrar: activeChainControl.chainNotifier,
Net: cfg.net,
NewAddress: func() (btcutil.Address, error) {
return activeChainControl.wallet.NewAddress(
lnwallet.WitnessPubKey, false,
)
},
NodePrivKey: idPrivKey,
PublishTx: activeChainControl.wallet.PublishTransaction,
ChainHash: *activeNetParams.GenesisHash,
}, lncfg.NormalizeAddresses)
if err != nil {
ltndLog.Errorf("Unable to configure watchtower: %v", err)
return err
}
tower, err = watchtower.New(wtConfig)
if err != nil {
ltndLog.Errorf("Unable to create watchtower: %v", err)
return err
}
}
// Set up the core server which will listen for incoming peer // Set up the core server which will listen for incoming peer
// connections. // connections.
server, err := newServer( server, err := newServer(
cfg.Listeners, chanDB, activeChainControl, idPrivKey, cfg.Listeners, chanDB, towerClientDB, activeChainControl,
walletInitParams.ChansToRestore, idPrivKey, walletInitParams.ChansToRestore,
) )
if err != nil { if err != nil {
srvrLog.Errorf("unable to create server: %v\n", err) srvrLog.Errorf("unable to create server: %v\n", err)
@ -418,6 +475,14 @@ func Main() error {
} }
} }
if cfg.Watchtower.Active {
if err := tower.Start(); err != nil {
ltndLog.Errorf("Unable to start watchtower: %v", err)
return err
}
defer tower.Stop()
}
// Wait for shutdown signal from either a graceful server stop or from // Wait for shutdown signal from either a graceful server stop or from
// the interrupt handler. // the interrupt handler.
<-signal.ShutdownChannel() <-signal.ShutdownChannel()

@ -7606,6 +7606,353 @@ func testRevokedCloseRetributionRemoteHodl(net *lntest.NetworkHarness,
assertNodeNumChannels(t, dave, 0) assertNodeNumChannels(t, dave, 0)
} }
// testRevokedCloseRetributionAltruistWatchtower establishes a channel between
// Carol and Dave, where Carol is using a third node Willy as her watchtower.
// After sending some payments, Dave reverts his state and force closes to
// trigger a breach. Carol is kept offline throughout the process and the test
// asserts that Willy responds by broadcasting the justice transaction on
// Carol's behalf sweeping her funds without a reward.
func testRevokedCloseRetributionAltruistWatchtower(net *lntest.NetworkHarness,
t *harnessTest) {
ctxb := context.Background()
const (
chanAmt = lnd.MaxBtcFundingAmount
paymentAmt = 10000
numInvoices = 6
)
// Since we'd like to test some multi-hop failure scenarios, we'll
// introduce another node into our test network: Carol.
carol, err := net.NewNode("Carol", []string{
"--debughtlc", "--hodl.exit-settle",
})
if err != nil {
t.Fatalf("unable to create new nodes: %v", err)
}
defer shutdownAndAssert(net, t, carol)
// Willy the watchtower will protect Dave from Carol's breach. He will
// remain online in order to punish Carol on Dave's behalf, since the
// breach will happen while Dave is offline.
willy, err := net.NewNode("Willy", []string{"--watchtower.active"})
if err != nil {
t.Fatalf("unable to create new nodes: %v", err)
}
defer shutdownAndAssert(net, t, willy)
ctxt, _ := context.WithTimeout(ctxb, defaultTimeout)
willyInfo, err := willy.GetInfo(ctxt, &lnrpc.GetInfoRequest{})
if err != nil {
t.Fatalf("unable to getinfo from willy: %v", err)
}
willyAddr := willyInfo.Uris[0]
parts := strings.Split(willyAddr, ":")
willyTowerAddr := parts[0]
// Dave will be the breached party. We set --nolisten to ensure Carol
// won't be able to connect to him and trigger the channel data
// protection logic automatically.
dave, err := net.NewNode("Dave", []string{
"--nolisten",
"--wtclient.private-tower-uris=" + willyTowerAddr,
})
if err != nil {
t.Fatalf("unable to create new node: %v", err)
}
defer shutdownAndAssert(net, t, dave)
// We must let Dave have an open channel before she can send a node
// announcement, so we open a channel with Carol,
if err := net.ConnectNodes(ctxb, dave, carol); err != nil {
t.Fatalf("unable to connect dave to carol: %v", err)
}
// Before we make a channel, we'll load up Dave with some coins sent
// directly from the miner.
err = net.SendCoins(ctxb, btcutil.SatoshiPerBitcoin, dave)
if err != nil {
t.Fatalf("unable to send coins to dave: %v", err)
}
// In order to test Dave's response to an uncooperative channel
// closure by Carol, we'll first open up a channel between them with a
// 0.5 BTC value.
ctxt, _ = context.WithTimeout(ctxb, channelOpenTimeout)
chanPoint := openChannelAndAssert(
ctxt, t, net, dave, carol,
lntest.OpenChannelParams{
Amt: 3 * (chanAmt / 4),
PushAmt: chanAmt / 4,
},
)
// With the channel open, we'll create a few invoices for Carol that
// Dave will pay to in order to advance the state of the channel.
carolPayReqs, _, _, err := createPayReqs(
carol, paymentAmt, numInvoices,
)
if err != nil {
t.Fatalf("unable to create pay reqs: %v", err)
}
// Wait for Dave to receive the channel edge from the funding manager.
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
err = dave.WaitForNetworkChannelOpen(ctxt, chanPoint)
if err != nil {
t.Fatalf("dave didn't see the dave->carol channel before "+
"timeout: %v", err)
}
// Next query for Carol's channel state, as we sent 0 payments, Carol
// should still see her balance as the push amount, which is 1/4 of the
// capacity.
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
carolChan, err := getChanInfo(ctxt, carol)
if err != nil {
t.Fatalf("unable to get carol's channel info: %v", err)
}
if carolChan.LocalBalance != int64(chanAmt/4) {
t.Fatalf("carol's balance is incorrect, got %v, expected %v",
carolChan.LocalBalance, chanAmt/4)
}
// Grab Carol's current commitment height (update number), we'll later
// revert her to this state after additional updates to force him to
// broadcast this soon to be revoked state.
carolStateNumPreCopy := carolChan.NumUpdates
// Create a temporary file to house Carol's database state at this
// particular point in history.
carolTempDbPath, err := ioutil.TempDir("", "carol-past-state")
if err != nil {
t.Fatalf("unable to create temp db folder: %v", err)
}
carolTempDbFile := filepath.Join(carolTempDbPath, "channel.db")
defer os.Remove(carolTempDbPath)
// With the temporary file created, copy Carol's current state into the
// temporary file we created above. Later after more updates, we'll
// restore this state.
if err := lntest.CopyFile(carolTempDbFile, carol.DBPath()); err != nil {
t.Fatalf("unable to copy database files: %v", err)
}
// Finally, send payments from Dave to Carol, consuming Carol's remaining
// payment hashes.
err = completePaymentRequests(ctxb, dave, carolPayReqs, false)
if err != nil {
t.Fatalf("unable to send payments: %v", err)
}
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
daveBalReq := &lnrpc.WalletBalanceRequest{}
daveBalResp, err := dave.WalletBalance(ctxt, daveBalReq)
if err != nil {
t.Fatalf("unable to get dave's balance: %v", err)
}
davePreSweepBalance := daveBalResp.ConfirmedBalance
// Shutdown Dave to simulate going offline for an extended period of
// time. Once he's not watching, Carol will try to breach the channel.
restart, err := net.SuspendNode(dave)
if err != nil {
t.Fatalf("unable to suspend Dave: %v", err)
}
// Now we shutdown Carol, copying over the his temporary database state
// which has the *prior* channel state over his current most up to date
// state. With this, we essentially force Carol to travel back in time
// within the channel's history.
if err = net.RestartNode(carol, func() error {
return os.Rename(carolTempDbFile, carol.DBPath())
}); err != nil {
t.Fatalf("unable to restart node: %v", err)
}
// Now query for Carol's channel state, it should show that he's at a
// state number in the past, not the *latest* state.
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
carolChan, err = getChanInfo(ctxt, carol)
if err != nil {
t.Fatalf("unable to get carol chan info: %v", err)
}
if carolChan.NumUpdates != carolStateNumPreCopy {
t.Fatalf("db copy failed: %v", carolChan.NumUpdates)
}
// TODO(conner): add hook for backup completion
time.Sleep(3 * time.Second)
// Now force Carol to execute a *force* channel closure by unilaterally
// broadcasting his current channel state. This is actually the
// commitment transaction of a prior *revoked* state, so he'll soon
// feel the wrath of Dave's retribution.
closeUpdates, closeTxId, err := net.CloseChannel(
ctxb, carol, chanPoint, true,
)
if err != nil {
t.Fatalf("unable to close channel: %v", err)
}
// Query the mempool for the breaching closing transaction, this should
// be broadcast by Carol when she force closes the channel above.
txid, err := waitForTxInMempool(net.Miner.Node, minerMempoolTimeout)
if err != nil {
t.Fatalf("unable to find Carol's force close tx in mempool: %v",
err)
}
if *txid != *closeTxId {
t.Fatalf("expected closeTx(%v) in mempool, instead found %v",
closeTxId, txid)
}
// Finally, generate a single block, wait for the final close status
// update, then ensure that the closing transaction was included in the
// block.
block := mineBlocks(t, net, 1, 1)[0]
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
breachTXID, err := net.WaitForChannelClose(ctxt, closeUpdates)
if err != nil {
t.Fatalf("error while waiting for channel close: %v", err)
}
assertTxInBlock(t, block, breachTXID)
// Query the mempool for Dave's justice transaction, this should be
// broadcast as Carol's contract breaching transaction gets confirmed
// above.
justiceTXID, err := waitForTxInMempool(net.Miner.Node, minerMempoolTimeout)
if err != nil {
t.Fatalf("unable to find Dave's justice tx in mempool: %v",
err)
}
time.Sleep(100 * time.Millisecond)
// Query for the mempool transaction found above. Then assert that all
// the inputs of this transaction are spending outputs generated by
// Carol's breach transaction above.
justiceTx, err := net.Miner.Node.GetRawTransaction(justiceTXID)
if err != nil {
t.Fatalf("unable to query for justice tx: %v", err)
}
for _, txIn := range justiceTx.MsgTx().TxIn {
if !bytes.Equal(txIn.PreviousOutPoint.Hash[:], breachTXID[:]) {
t.Fatalf("justice tx not spending commitment utxo "+
"instead is: %v", txIn.PreviousOutPoint)
}
}
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
willyBalReq := &lnrpc.WalletBalanceRequest{}
willyBalResp, err := willy.WalletBalance(ctxt, willyBalReq)
if err != nil {
t.Fatalf("unable to get willy's balance: %v", err)
}
if willyBalResp.ConfirmedBalance != 0 {
t.Fatalf("willy should have 0 balance before mining "+
"justice transaction, instead has %d",
willyBalResp.ConfirmedBalance)
}
// Now mine a block, this transaction should include Dave's justice
// transaction which was just accepted into the mempool.
block = mineBlocks(t, net, 1, 1)[0]
// The block should have exactly *two* transactions, one of which is
// the justice transaction.
if len(block.Transactions) != 2 {
t.Fatalf("transaction wasn't mined")
}
justiceSha := block.Transactions[1].TxHash()
if !bytes.Equal(justiceTx.Hash()[:], justiceSha[:]) {
t.Fatalf("justice tx wasn't mined")
}
// Ensure that Willy doesn't get any funds, as he is acting as an
// altruist watchtower.
var predErr error
err = lntest.WaitInvariant(func() bool {
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
willyBalReq := &lnrpc.WalletBalanceRequest{}
willyBalResp, err := willy.WalletBalance(ctxt, willyBalReq)
if err != nil {
t.Fatalf("unable to get willy's balance: %v", err)
}
if willyBalResp.ConfirmedBalance != 0 {
predErr = fmt.Errorf("Expected Willy to have no funds "+
"after justice transaction was mined, found %v",
willyBalResp)
return false
}
return true
}, time.Second*5)
if err != nil {
t.Fatalf("%v", predErr)
}
// Restart Dave, who will still think his channel with Carol is open.
// We should him to detect the breach, but realize that the funds have
// then been swept to his wallet by Willy.
err = restart()
if err != nil {
t.Fatalf("unable to restart dave: %v", err)
}
err = lntest.WaitPredicate(func() bool {
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
daveBalReq := &lnrpc.ChannelBalanceRequest{}
daveBalResp, err := dave.ChannelBalance(ctxt, daveBalReq)
if err != nil {
t.Fatalf("unable to get dave's balance: %v", err)
}
if daveBalResp.Balance != 0 {
predErr = fmt.Errorf("Dave should end up with zero "+
"channel balance, instead has %d",
daveBalResp.Balance)
return false
}
return true
}, time.Second*15)
if err != nil {
t.Fatalf("%v", predErr)
}
assertNumPendingChannels(t, dave, 0, 0)
err = lntest.WaitPredicate(func() bool {
ctxt, _ := context.WithTimeout(ctxb, defaultTimeout)
daveBalReq := &lnrpc.WalletBalanceRequest{}
daveBalResp, err := dave.WalletBalance(ctxt, daveBalReq)
if err != nil {
t.Fatalf("unable to get dave's balance: %v", err)
}
if daveBalResp.ConfirmedBalance <= davePreSweepBalance {
predErr = fmt.Errorf("Dave should have more than %d "+
"after sweep, instead has %d",
davePreSweepBalance,
daveBalResp.ConfirmedBalance)
return false
}
return true
}, time.Second*15)
if err != nil {
t.Fatalf("%v", predErr)
}
// Dave should have no open channels.
assertNodeNumChannels(t, dave, 0)
}
// assertNumPendingChannels checks that a PendingChannels response from the // assertNumPendingChannels checks that a PendingChannels response from the
// node reports the expected number of pending channels. // node reports the expected number of pending channels.
func assertNumPendingChannels(t *harnessTest, node *lntest.HarnessNode, func assertNumPendingChannels(t *harnessTest, node *lntest.HarnessNode,
@ -13840,6 +14187,10 @@ var testsCases = []*testCase{
name: "revoked uncooperative close retribution remote hodl", name: "revoked uncooperative close retribution remote hodl",
test: testRevokedCloseRetributionRemoteHodl, test: testRevokedCloseRetributionRemoteHodl,
}, },
{
name: "revoked uncooperative close retribution altruist watchtower",
test: testRevokedCloseRetributionAltruistWatchtower,
},
{ {
name: "data loss protection", name: "data loss protection",
test: testDataLossProtection, test: testDataLossProtection,

4
log.go

@ -34,6 +34,7 @@ import (
"github.com/lightningnetwork/lnd/signal" "github.com/lightningnetwork/lnd/signal"
"github.com/lightningnetwork/lnd/sweep" "github.com/lightningnetwork/lnd/sweep"
"github.com/lightningnetwork/lnd/watchtower" "github.com/lightningnetwork/lnd/watchtower"
"github.com/lightningnetwork/lnd/watchtower/wtclient"
) )
// Loggers per subsystem. A single backend logger is created and all subsystem // Loggers per subsystem. A single backend logger is created and all subsystem
@ -87,6 +88,7 @@ var (
chnfLog = build.NewSubLogger("CHNF", backendLog.Logger) chnfLog = build.NewSubLogger("CHNF", backendLog.Logger)
chbuLog = build.NewSubLogger("CHBU", backendLog.Logger) chbuLog = build.NewSubLogger("CHBU", backendLog.Logger)
promLog = build.NewSubLogger("PROM", backendLog.Logger) promLog = build.NewSubLogger("PROM", backendLog.Logger)
wtclLog = build.NewSubLogger("WTCL", backendLog.Logger)
) )
// Initialize package-global logger variables. // Initialize package-global logger variables.
@ -115,6 +117,7 @@ func init() {
channelnotifier.UseLogger(chnfLog) channelnotifier.UseLogger(chnfLog)
chanbackup.UseLogger(chbuLog) chanbackup.UseLogger(chbuLog)
monitoring.UseLogger(promLog) monitoring.UseLogger(promLog)
wtclient.UseLogger(wtclLog)
addSubLogger(routerrpc.Subsystem, routerrpc.UseLogger) addSubLogger(routerrpc.Subsystem, routerrpc.UseLogger)
} }
@ -159,6 +162,7 @@ var subsystemLoggers = map[string]btclog.Logger{
"CHNF": chnfLog, "CHNF": chnfLog,
"CHBU": chbuLog, "CHBU": chbuLog,
"PROM": promLog, "PROM": promLog,
"WTCL": wtclLog,
} }
// initLogRotator initializes the logging rotator to write logs to logFile and // initLogRotator initializes the logging rotator to write logs to logFile and

@ -594,6 +594,7 @@ func (p *peer) addLink(chanPoint *wire.OutPoint,
MinFeeUpdateTimeout: htlcswitch.DefaultMinLinkFeeUpdateTimeout, MinFeeUpdateTimeout: htlcswitch.DefaultMinLinkFeeUpdateTimeout,
MaxFeeUpdateTimeout: htlcswitch.DefaultMaxLinkFeeUpdateTimeout, MaxFeeUpdateTimeout: htlcswitch.DefaultMaxLinkFeeUpdateTimeout,
OutgoingCltvRejectDelta: p.outgoingCltvRejectDelta, OutgoingCltvRejectDelta: p.outgoingCltvRejectDelta,
TowerClient: p.server.towerClient,
} }
link := htlcswitch.NewChannelLink(linkCfg, lnChan) link := htlcswitch.NewChannelLink(linkCfg, lnChan)

@ -2173,6 +2173,9 @@ func (r *rpcServer) ChannelBalance(ctx context.Context,
pendingOpenBalance += channel.LocalCommitment.LocalBalance.ToSatoshis() pendingOpenBalance += channel.LocalCommitment.LocalBalance.ToSatoshis()
} }
rpcsLog.Debugf("[channelbalance] balance=%v pending-open=%v",
balance, pendingOpenBalance)
return &lnrpc.ChannelBalanceResponse{ return &lnrpc.ChannelBalanceResponse{
Balance: int64(balance), Balance: int64(balance),
PendingOpenBalance: int64(pendingOpenBalance), PendingOpenBalance: int64(pendingOpenBalance),

@ -20,6 +20,7 @@ import (
"github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/connmgr" "github.com/btcsuite/btcd/connmgr"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/coreos/bbolt" "github.com/coreos/bbolt"
@ -50,6 +51,9 @@ import (
"github.com/lightningnetwork/lnd/ticker" "github.com/lightningnetwork/lnd/ticker"
"github.com/lightningnetwork/lnd/tor" "github.com/lightningnetwork/lnd/tor"
"github.com/lightningnetwork/lnd/walletunlocker" "github.com/lightningnetwork/lnd/walletunlocker"
"github.com/lightningnetwork/lnd/watchtower/wtclient"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
"github.com/lightningnetwork/lnd/zpay32" "github.com/lightningnetwork/lnd/zpay32"
) )
@ -204,6 +208,8 @@ type server struct {
sphinx *htlcswitch.OnionProcessor sphinx *htlcswitch.OnionProcessor
towerClient wtclient.Client
connMgr *connmgr.ConnManager connMgr *connmgr.ConnManager
sigPool *lnwallet.SigPool sigPool *lnwallet.SigPool
@ -282,7 +288,8 @@ func noiseDial(idPriv *btcec.PrivateKey) func(net.Addr) (net.Conn, error) {
// newServer creates a new instance of the server which is to listen using the // newServer creates a new instance of the server which is to listen using the
// passed listener address. // passed listener address.
func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl, func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB,
towerClientDB *wtdb.ClientDB, cc *chainControl,
privKey *btcec.PrivateKey, privKey *btcec.PrivateKey,
chansToRestore walletunlocker.ChannelsToRecover) (*server, error) { chansToRestore walletunlocker.ChannelsToRecover) (*server, error) {
@ -723,10 +730,8 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl,
} }
s.sweeper = sweep.New(&sweep.UtxoSweeperConfig{ s.sweeper = sweep.New(&sweep.UtxoSweeperConfig{
FeeEstimator: cc.feeEstimator, FeeEstimator: cc.feeEstimator,
GenSweepScript: func() ([]byte, error) { GenSweepScript: newSweepPkScriptGen(cc.wallet),
return newSweepPkScript(cc.wallet)
},
Signer: cc.wallet.Cfg.Signer, Signer: cc.wallet.Cfg.Signer,
PublishTransaction: cc.wallet.PublishTransaction, PublishTransaction: cc.wallet.PublishTransaction,
NewBatchTimer: func() <-chan time.Time { NewBatchTimer: func() <-chan time.Time {
@ -769,10 +774,8 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl,
ChainHash: *activeNetParams.GenesisHash, ChainHash: *activeNetParams.GenesisHash,
IncomingBroadcastDelta: DefaultIncomingBroadcastDelta, IncomingBroadcastDelta: DefaultIncomingBroadcastDelta,
OutgoingBroadcastDelta: DefaultOutgoingBroadcastDelta, OutgoingBroadcastDelta: DefaultOutgoingBroadcastDelta,
NewSweepAddr: func() ([]byte, error) { NewSweepAddr: newSweepPkScriptGen(cc.wallet),
return newSweepPkScript(cc.wallet) PublishTx: cc.wallet.PublishTransaction,
},
PublishTx: cc.wallet.PublishTransaction,
DeliverResolutionMsg: func(msgs ...contractcourt.ResolutionMsg) error { DeliverResolutionMsg: func(msgs ...contractcourt.ResolutionMsg) error {
for _, msg := range msgs { for _, msg := range msgs {
err := s.htlcSwitch.ProcessContractResolution(msg) err := s.htlcSwitch.ProcessContractResolution(msg)
@ -845,12 +848,10 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl,
}, chanDB) }, chanDB)
s.breachArbiter = newBreachArbiter(&BreachConfig{ s.breachArbiter = newBreachArbiter(&BreachConfig{
CloseLink: closeLink, CloseLink: closeLink,
DB: chanDB, DB: chanDB,
Estimator: s.cc.feeEstimator, Estimator: s.cc.feeEstimator,
GenSweepScript: func() ([]byte, error) { GenSweepScript: newSweepPkScriptGen(cc.wallet),
return newSweepPkScript(cc.wallet)
},
Notifier: cc.chainNotifier, Notifier: cc.chainNotifier,
PublishTransaction: cc.wallet.PublishTransaction, PublishTransaction: cc.wallet.PublishTransaction,
ContractBreaches: contractBreaches, ContractBreaches: contractBreaches,
@ -1056,6 +1057,41 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl,
return nil, err return nil, err
} }
if cfg.WtClient.IsActive() {
policy := wtpolicy.DefaultPolicy()
if cfg.WtClient.SweepFeeRate != 0 {
// We expose the sweep fee rate in sat/byte, but the
// tower protocol operations on sat/kw.
sweepRateSatPerByte := lnwallet.SatPerKVByte(
1000 * cfg.WtClient.SweepFeeRate,
)
policy.SweepFeeRate = sweepRateSatPerByte.FeePerKWeight()
}
if err := policy.Validate(); err != nil {
return nil, err
}
s.towerClient, err = wtclient.New(&wtclient.Config{
Signer: cc.wallet.Cfg.Signer,
NewAddress: newSweepPkScriptGen(cc.wallet),
SecretKeyRing: s.cc.keyRing,
Dial: cfg.net.Dial,
AuthDial: wtclient.AuthDial,
DB: towerClientDB,
Policy: wtpolicy.DefaultPolicy(),
PrivateTower: cfg.WtClient.PrivateTowers[0],
ChainHash: *activeNetParams.GenesisHash,
MinBackoff: 10 * time.Second,
MaxBackoff: 5 * time.Minute,
ForceQuitDelay: wtclient.DefaultForceQuitDelay,
})
if err != nil {
return nil, err
}
}
// Create the connection manager which will be responsible for // Create the connection manager which will be responsible for
// maintaining persistent outbound connections and also accepting new // maintaining persistent outbound connections and also accepting new
// incoming connections // incoming connections
@ -1128,6 +1164,12 @@ func (s *server) Start() error {
startErr = err startErr = err
return return
} }
if s.towerClient != nil {
if err := s.towerClient.Start(); err != nil {
startErr = err
return
}
}
if err := s.htlcSwitch.Start(); err != nil { if err := s.htlcSwitch.Start(); err != nil {
startErr = err startErr = err
return return
@ -1290,6 +1332,14 @@ func (s *server) Stop() error {
s.DisconnectPeer(peer.addr.IdentityKey) s.DisconnectPeer(peer.addr.IdentityKey)
} }
// Now that all connections have been torn down, stop the tower
// client which will reliably flush all queued states to the
// tower. If this is halted for any reason, the force quit timer
// will kick in and abort to allow this method to return.
if s.towerClient != nil {
s.towerClient.Stop()
}
// Wait for all lingering goroutines to quit. // Wait for all lingering goroutines to quit.
s.wg.Wait() s.wg.Wait()
@ -3180,3 +3230,20 @@ func (s *server) applyChannelUpdate(update *lnwire.ChannelUpdate) error {
return ErrServerShuttingDown return ErrServerShuttingDown
} }
} }
// newSweepPkScriptGen creates closure that generates a new public key script
// which should be used to sweep any funds into the on-chain wallet.
// Specifically, the script generated is a version 0, pay-to-witness-pubkey-hash
// (p2wkh) output.
func newSweepPkScriptGen(
wallet lnwallet.WalletController) func() ([]byte, error) {
return func() ([]byte, error) {
sweepAddr, err := wallet.NewAddress(lnwallet.WitnessPubKey, false)
if err != nil {
return nil, err
}
return txscript.PayToAddrScript(sweepAddr)
}
}

@ -8,7 +8,6 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
@ -1231,19 +1230,6 @@ func (u *utxoNursery) closeAndRemoveIfMature(chanPoint *wire.OutPoint) error {
return nil return nil
} }
// newSweepPkScript creates a new public key script which should be used to
// sweep any time-locked, or contested channel funds into the wallet.
// Specifically, the script generated is a version 0, pay-to-witness-pubkey-hash
// (p2wkh) output.
func newSweepPkScript(wallet lnwallet.WalletController) ([]byte, error) {
sweepAddr, err := wallet.NewAddress(lnwallet.WitnessPubKey, false)
if err != nil {
return nil, err
}
return txscript.PayToAddrScript(sweepAddr)
}
// babyOutput represents a two-stage CSV locked output, and is used to track // babyOutput represents a two-stage CSV locked output, and is used to track
// htlc outputs through incubation. The first stage requires broadcasting a // htlc outputs through incubation. The first stage requires broadcasting a
// presigned timeout txn that spends from the CLTV locked output on the // presigned timeout txn that spends from the CLTV locked output on the

@ -0,0 +1,70 @@
package blob
import (
"crypto/sha256"
"encoding/hex"
"github.com/btcsuite/btcd/chaincfg/chainhash"
)
// BreachHintSize is the length of the identifier used to detect remote
// commitment broadcasts.
const BreachHintSize = 16
// BreachHint is the first 16-bytes of SHA256(txid), which is used to identify
// the breach transaction.
type BreachHint [BreachHintSize]byte
// NewBreachHintFromHash creates a breach hint from a transaction ID.
func NewBreachHintFromHash(hash *chainhash.Hash) BreachHint {
h := sha256.New()
h.Write(hash[:])
var hint BreachHint
copy(hint[:], h.Sum(nil))
return hint
}
// String returns a hex encoding of the breach hint.
func (h BreachHint) String() string {
return hex.EncodeToString(h[:])
}
// BreachKey is computed as SHA256(txid || txid), which produces the key for
// decrypting a client's encrypted blobs.
type BreachKey [KeySize]byte
// NewBreachKeyFromHash creates a breach key from a transaction ID.
func NewBreachKeyFromHash(hash *chainhash.Hash) BreachKey {
h := sha256.New()
h.Write(hash[:])
h.Write(hash[:])
var key BreachKey
copy(key[:], h.Sum(nil))
return key
}
// String returns a hex encoding of the breach key.
func (k BreachKey) String() string {
return hex.EncodeToString(k[:])
}
// NewBreachHintAndKeyFromHash derives a BreachHint and BreachKey from a given
// txid in a single pass. The hint and key are computed as:
// hint = SHA256(txid)
// key = SHA256(txid || txid)
func NewBreachHintAndKeyFromHash(hash *chainhash.Hash) (BreachHint, BreachKey) {
var (
hint BreachHint
key BreachKey
)
h := sha256.New()
h.Write(hash[:])
copy(hint[:], h.Sum(nil))
h.Write(hash[:])
copy(key[:], h.Sum(nil))
return hint, key
}

@ -75,11 +75,6 @@ var (
"ciphertext is too small for chacha20poly1305", "ciphertext is too small for chacha20poly1305",
) )
// ErrKeySize signals that the provided key is improperly sized.
ErrKeySize = fmt.Errorf(
"chacha20poly1305 key size must be %d bytes", KeySize,
)
// ErrNoCommitToRemoteOutput is returned when trying to retrieve the // ErrNoCommitToRemoteOutput is returned when trying to retrieve the
// commit to-remote output from the blob, though none exists. // commit to-remote output from the blob, though none exists.
ErrNoCommitToRemoteOutput = errors.New( ErrNoCommitToRemoteOutput = errors.New(
@ -223,12 +218,7 @@ func (b *JusticeKit) CommitToRemoteWitnessStack() ([][]byte, error) {
// //
// NOTE: It is the caller's responsibility to ensure that this method is only // NOTE: It is the caller's responsibility to ensure that this method is only
// called once for a given (nonce, key) pair. // called once for a given (nonce, key) pair.
func (b *JusticeKit) Encrypt(key []byte, blobType Type) ([]byte, error) { func (b *JusticeKit) Encrypt(key BreachKey, blobType Type) ([]byte, error) {
// Fail if the nonce is not 32-bytes.
if len(key) != KeySize {
return nil, ErrKeySize
}
// Encode the plaintext using the provided version, to obtain the // Encode the plaintext using the provided version, to obtain the
// plaintext bytes. // plaintext bytes.
var ptxtBuf bytes.Buffer var ptxtBuf bytes.Buffer
@ -238,7 +228,7 @@ func (b *JusticeKit) Encrypt(key []byte, blobType Type) ([]byte, error) {
} }
// Create a new chacha20poly1305 cipher, using a 32-byte key. // Create a new chacha20poly1305 cipher, using a 32-byte key.
cipher, err := chacha20poly1305.NewX(key) cipher, err := chacha20poly1305.NewX(key[:])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -264,21 +254,17 @@ func (b *JusticeKit) Encrypt(key []byte, blobType Type) ([]byte, error) {
// Decrypt unenciphers a blob of justice by decrypting the ciphertext using // Decrypt unenciphers a blob of justice by decrypting the ciphertext using
// chacha20poly1305 with the chosen (nonce, key) pair. The internal plaintext is // chacha20poly1305 with the chosen (nonce, key) pair. The internal plaintext is
// then deserialized using the given encoding version. // then deserialized using the given encoding version.
func Decrypt(key, ciphertext []byte, blobType Type) (*JusticeKit, error) { func Decrypt(key BreachKey, ciphertext []byte,
switch { blobType Type) (*JusticeKit, error) {
// Fail if the blob's overall length is less than required for the nonce // Fail if the blob's overall length is less than required for the nonce
// and expansion factor. // and expansion factor.
case len(ciphertext) < NonceSize+CiphertextExpansion: if len(ciphertext) < NonceSize+CiphertextExpansion {
return nil, ErrCiphertextTooSmall return nil, ErrCiphertextTooSmall
// Fail if the key is not 32-bytes.
case len(key) != KeySize:
return nil, ErrKeySize
} }
// Create a new chacha20poly1305 cipher, using a 32-byte key. // Create a new chacha20poly1305 cipher, using a 32-byte key.
cipher, err := chacha20poly1305.NewX(key) cipher, err := chacha20poly1305.NewX(key[:])
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -56,15 +56,11 @@ type descriptorTest struct {
decErr error decErr error
} }
var rewardAndCommitType = blob.TypeFromFlags(
blob.FlagReward, blob.FlagCommitOutputs,
)
var descriptorTests = []descriptorTest{ var descriptorTests = []descriptorTest{
{ {
name: "to-local only", name: "to-local only",
encVersion: blob.TypeDefault, encVersion: blob.TypeAltruistCommit,
decVersion: blob.TypeDefault, decVersion: blob.TypeAltruistCommit,
sweepAddr: makeAddr(22), sweepAddr: makeAddr(22),
revPubKey: makePubKey(0), revPubKey: makePubKey(0),
delayPubKey: makePubKey(1), delayPubKey: makePubKey(1),
@ -73,8 +69,8 @@ var descriptorTests = []descriptorTest{
}, },
{ {
name: "to-local and p2wkh", name: "to-local and p2wkh",
encVersion: rewardAndCommitType, encVersion: blob.TypeRewardCommit,
decVersion: rewardAndCommitType, decVersion: blob.TypeRewardCommit,
sweepAddr: makeAddr(22), sweepAddr: makeAddr(22),
revPubKey: makePubKey(0), revPubKey: makePubKey(0),
delayPubKey: makePubKey(1), delayPubKey: makePubKey(1),
@ -87,7 +83,7 @@ var descriptorTests = []descriptorTest{
{ {
name: "unknown encrypt version", name: "unknown encrypt version",
encVersion: 0, encVersion: 0,
decVersion: blob.TypeDefault, decVersion: blob.TypeAltruistCommit,
sweepAddr: makeAddr(34), sweepAddr: makeAddr(34),
revPubKey: makePubKey(0), revPubKey: makePubKey(0),
delayPubKey: makePubKey(1), delayPubKey: makePubKey(1),
@ -97,7 +93,7 @@ var descriptorTests = []descriptorTest{
}, },
{ {
name: "unknown decrypt version", name: "unknown decrypt version",
encVersion: blob.TypeDefault, encVersion: blob.TypeAltruistCommit,
decVersion: 0, decVersion: 0,
sweepAddr: makeAddr(34), sweepAddr: makeAddr(34),
revPubKey: makePubKey(0), revPubKey: makePubKey(0),
@ -108,8 +104,8 @@ var descriptorTests = []descriptorTest{
}, },
{ {
name: "sweep addr length zero", name: "sweep addr length zero",
encVersion: blob.TypeDefault, encVersion: blob.TypeAltruistCommit,
decVersion: blob.TypeDefault, decVersion: blob.TypeAltruistCommit,
sweepAddr: makeAddr(0), sweepAddr: makeAddr(0),
revPubKey: makePubKey(0), revPubKey: makePubKey(0),
delayPubKey: makePubKey(1), delayPubKey: makePubKey(1),
@ -118,8 +114,8 @@ var descriptorTests = []descriptorTest{
}, },
{ {
name: "sweep addr max size", name: "sweep addr max size",
encVersion: blob.TypeDefault, encVersion: blob.TypeAltruistCommit,
decVersion: blob.TypeDefault, decVersion: blob.TypeAltruistCommit,
sweepAddr: makeAddr(blob.MaxSweepAddrSize), sweepAddr: makeAddr(blob.MaxSweepAddrSize),
revPubKey: makePubKey(0), revPubKey: makePubKey(0),
delayPubKey: makePubKey(1), delayPubKey: makePubKey(1),
@ -128,8 +124,8 @@ var descriptorTests = []descriptorTest{
}, },
{ {
name: "sweep addr too long", name: "sweep addr too long",
encVersion: blob.TypeDefault, encVersion: blob.TypeAltruistCommit,
decVersion: blob.TypeDefault, decVersion: blob.TypeAltruistCommit,
sweepAddr: makeAddr(blob.MaxSweepAddrSize + 1), sweepAddr: makeAddr(blob.MaxSweepAddrSize + 1),
revPubKey: makePubKey(0), revPubKey: makePubKey(0),
delayPubKey: makePubKey(1), delayPubKey: makePubKey(1),
@ -165,8 +161,8 @@ func testBlobJusticeKitEncryptDecrypt(t *testing.T, test descriptorTest) {
// Generate a random encryption key for the blob. The key is // Generate a random encryption key for the blob. The key is
// sized at 32 byte, as in practice we will be using the remote // sized at 32 byte, as in practice we will be using the remote
// party's commitment txid as the key. // party's commitment txid as the key.
key := make([]byte, blob.KeySize) var key blob.BreachKey
_, err := io.ReadFull(rand.Reader, key) _, err := rand.Read(key[:])
if err != nil { if err != nil {
t.Fatalf("unable to generate blob encryption key: %v", err) t.Fatalf("unable to generate blob encryption key: %v", err)
} }

@ -45,9 +45,15 @@ func (f Flag) String() string {
// of the blob itself. // of the blob itself.
type Type uint16 type Type uint16
// TypeDefault sweeps only commitment outputs to a sweep address controlled by const (
// the user, and does not give the tower a reward. // TypeAltruistCommit sweeps only commitment outputs to a sweep address
const TypeDefault = Type(FlagCommitOutputs) // controlled by the user, and does not give the tower a reward.
TypeAltruistCommit = Type(FlagCommitOutputs)
// TypeRewardCommit sweeps only commitment outputs to a sweep address
// controlled by the user, and pays a negotiated reward to the tower.
TypeRewardCommit = Type(FlagCommitOutputs | FlagReward)
)
// Has returns true if the Type has the passed flag enabled. // Has returns true if the Type has the passed flag enabled.
func (t Type) Has(flag Flag) bool { func (t Type) Has(flag Flag) bool {
@ -114,8 +120,8 @@ func (t Type) String() string {
// supportedTypes is the set of all configurations known to be supported by the // supportedTypes is the set of all configurations known to be supported by the
// package. // package.
var supportedTypes = map[Type]struct{}{ var supportedTypes = map[Type]struct{}{
FlagCommitOutputs.Type(): {}, TypeAltruistCommit: {},
(FlagCommitOutputs | FlagReward).Type(): {}, TypeRewardCommit: {},
} }
// IsSupportedType returns true if the given type is supported by the package. // IsSupportedType returns true if the given type is supported by the package.

@ -17,12 +17,12 @@ type typeStringTest struct {
var typeStringTests = []typeStringTest{ var typeStringTests = []typeStringTest{
{ {
name: "commit no-reward", name: "commit no-reward",
typ: blob.TypeDefault, typ: blob.TypeAltruistCommit,
expStr: "[FlagCommitOutputs|No-FlagReward]", expStr: "[FlagCommitOutputs|No-FlagReward]",
}, },
{ {
name: "commit reward", name: "commit reward",
typ: (blob.FlagCommitOutputs | blob.FlagReward).Type(), typ: blob.TypeRewardCommit,
expStr: "[FlagCommitOutputs|FlagReward]", expStr: "[FlagCommitOutputs|FlagReward]",
}, },
{ {
@ -75,7 +75,7 @@ var typeFromFlagTests = []typeFromFlagTest{
{ {
name: "multiple flags", name: "multiple flags",
flags: []blob.Flag{blob.FlagReward, blob.FlagCommitOutputs}, flags: []blob.Flag{blob.FlagReward, blob.FlagCommitOutputs},
expType: blob.Type(blob.FlagReward | blob.FlagCommitOutputs), expType: blob.TypeRewardCommit,
}, },
{ {
name: "duplicate flag", name: "duplicate flag",
@ -119,8 +119,8 @@ func TestTypeFromFlags(t *testing.T) {
// blob.DefaultType returns true. // blob.DefaultType returns true.
func TestSupportedTypes(t *testing.T) { func TestSupportedTypes(t *testing.T) {
// Assert that the package's default type is supported. // Assert that the package's default type is supported.
if !blob.IsSupportedType(blob.TypeDefault) { if !blob.IsSupportedType(blob.TypeAltruistCommit) {
t.Fatalf("default type %s is not supported", blob.TypeDefault) t.Fatalf("default type %s is not supported", blob.TypeAltruistCommit)
} }
// Assert that all claimed supported types are actually supported. // Assert that all claimed supported types are actually supported.

@ -1,14 +1,68 @@
// +build !experimental
package watchtower package watchtower
// Conf specifies the watchtower options that be configured from the command import (
// line or configuration file. In non-experimental builds, we disallow such "time"
// configuration. )
type Conf struct{}
// Apply returns an error signaling that the Conf could not be applied in // Conf specifies the watchtower options that can be configured from the command
// non-experimental builds. // line or configuration file.
func (c *Conf) Apply(cfg *Config) (*Config, error) { type Conf struct {
return nil, ErrNonExperimentalConf // RawListeners configures the watchtower's listening ports/interfaces.
RawListeners []string `long:"listen" description:"Add interfaces/ports to listen for peer connections"`
// ReadTimeout specifies the duration the tower will wait when trying to
// read a message from a client before hanging up.
ReadTimeout time.Duration `long:"readtimeout" description:"Duration the watchtower server will wait for messages to be received before hanging up on clients"`
// WriteTimeout specifies the duration the tower will wait when trying
// to write a message from a client before hanging up.
WriteTimeout time.Duration `long:"writetimeout" description:"Duration the watchtower server will wait for messages to be written before hanging up on client connections"`
}
// Apply completes the passed Config struct by applying any parsed Conf options.
// If the corresponding values parsed by Conf are already set in the Config,
// those fields will be not be modified.
func (c *Conf) Apply(cfg *Config,
normalizer AddressNormalizer) (*Config, error) {
// Set the Config's listening addresses if they are empty.
if cfg.ListenAddrs == nil {
// Without a network, we will be unable to resolve the listening
// addresses.
if cfg.Net == nil {
return nil, ErrNoNetwork
}
// If no addresses are specified by the Config, we will resort
// to the default peer port.
if len(c.RawListeners) == 0 {
addr := DefaultPeerPortStr
c.RawListeners = append(c.RawListeners, addr)
}
// Normalize the raw listening addresses so that they can be
// used by the brontide listener.
var err error
cfg.ListenAddrs, err = normalizer(
c.RawListeners, DefaultPeerPortStr,
cfg.Net.ResolveTCPAddr,
)
if err != nil {
return nil, err
}
}
// If the Config has no read timeout, we will use the parsed Conf
// value.
if cfg.ReadTimeout == 0 && c.ReadTimeout != 0 {
cfg.ReadTimeout = c.ReadTimeout
}
// If the Config has no write timeout, we will use the parsed Conf
// value.
if cfg.WriteTimeout == 0 && c.WriteTimeout != 0 {
cfg.WriteTimeout = c.WriteTimeout
}
return cfg, nil
} }

@ -1,65 +0,0 @@
// +build experimental
package watchtower
import (
"time"
"github.com/lightningnetwork/lnd/lncfg"
)
// Conf specifies the watchtower options that can be configured from the command
// line or configuration file.
type Conf struct {
RawListeners []string `long:"listen" description:"Add interfaces/ports to listen for peer connections"`
ReadTimeout time.Duration `long:"readtimeout" description:"Duration the watchtower server will wait for messages to be received before hanging up on clients"`
WriteTimeout time.Duration `long:"writetimeout" description:"Duration the watchtower server will wait for messages to be written before hanging up on client connections"`
}
// Apply completes the passed Config struct by applying any parsed Conf options.
// If the corresponding values parsed by Conf are already set in the Config,
// those fields will be not be modified.
func (c *Conf) Apply(cfg *Config) (*Config, error) {
// Set the Config's listening addresses if they are empty.
if cfg.ListenAddrs == nil {
// Without a network, we will be unable to resolve the listening
// addresses.
if cfg.Net == nil {
return nil, ErrNoNetwork
}
// If no addresses are specified by the Config, we will resort
// to the default peer port.
if len(c.RawListeners) == 0 {
addr := DefaultPeerPortStr
c.RawListeners = append(c.RawListeners, addr)
}
// Normalize the raw listening addresses so that they can be
// used by the brontide listener.
var err error
cfg.ListenAddrs, err = lncfg.NormalizeAddresses(
c.RawListeners, DefaultPeerPortStr,
cfg.Net.ResolveTCPAddr,
)
if err != nil {
return nil, err
}
}
// If the Config has no read timeout, we will use the parsed Conf
// value.
if cfg.ReadTimeout == 0 && c.ReadTimeout != 0 {
cfg.ReadTimeout = c.ReadTimeout
}
// If the Config has no write timeout, we will use the parsed Conf
// value.
if cfg.WriteTimeout == 0 && c.WriteTimeout != 0 {
cfg.WriteTimeout = c.WriteTimeout
}
return cfg, nil
}

@ -7,11 +7,6 @@ var (
// rendering the tower unable to receive client requests. // rendering the tower unable to receive client requests.
ErrNoListeners = errors.New("no listening ports were specified") ErrNoListeners = errors.New("no listening ports were specified")
// ErrNonExperimentalConf signals that an attempt to apply a
// non-experimental Conf to a Config was detected.
ErrNonExperimentalConf = errors.New("cannot use watchtower in non-" +
"experimental builds")
// ErrNoNetwork signals that no tor.Net is provided in the Config, which // ErrNoNetwork signals that no tor.Net is provided in the Config, which
// prevents resolution of listening addresses. // prevents resolution of listening addresses.
ErrNoNetwork = errors.New("no network specified, must be tor or clearnet") ErrNoNetwork = errors.New("no network specified, must be tor or clearnet")

@ -1,6 +1,8 @@
package watchtower package watchtower
import ( import (
"net"
"github.com/lightningnetwork/lnd/watchtower/lookout" "github.com/lightningnetwork/lnd/watchtower/lookout"
"github.com/lightningnetwork/lnd/watchtower/wtserver" "github.com/lightningnetwork/lnd/watchtower/wtserver"
) )
@ -12,3 +14,8 @@ type DB interface {
lookout.DB lookout.DB
wtserver.DB wtserver.DB
} }
// AddressNormalizer is a function signature that allows the tower to resolve
// TCP addresses on clear or onion networks.
type AddressNormalizer func(addrs []string, defaultPort string,
resolver func(string, string) (*net.TCPAddr, error)) ([]net.Addr, error)

@ -4,6 +4,7 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/lightningnetwork/lnd/watchtower/wtdb"
) )
@ -37,7 +38,7 @@ type DB interface {
// QueryMatches searches its database for any state updates matching the // QueryMatches searches its database for any state updates matching the
// provided breach hints. If any matches are found, they will be // provided breach hints. If any matches are found, they will be
// returned along with encrypted blobs so that justice can be exacted. // returned along with encrypted blobs so that justice can be exacted.
QueryMatches([]wtdb.BreachHint) ([]wtdb.Match, error) QueryMatches([]blob.BreachHint) ([]wtdb.Match, error)
// SetLookoutTip writes the best epoch for which the watchtower has // SetLookoutTip writes the best epoch for which the watchtower has
// queried for breach hints. // queried for breach hints.

@ -156,9 +156,11 @@ func testJusticeDescriptor(t *testing.T, blobType blob.Type) {
// parameters that should be used in constructing the justice // parameters that should be used in constructing the justice
// transaction. // transaction.
policy := wtpolicy.Policy{ policy := wtpolicy.Policy{
BlobType: blobType, TxPolicy: wtpolicy.TxPolicy{
SweepFeeRate: 2000, BlobType: blobType,
RewardRate: 900000, SweepFeeRate: 2000,
RewardRate: 900000,
},
} }
sessionInfo := &wtdb.SessionInfo{ sessionInfo := &wtdb.SessionInfo{
Policy: policy, Policy: policy,

@ -7,7 +7,6 @@ import (
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
) )
// Config houses the Lookout's required resources to properly fulfill it's duty, // Config houses the Lookout's required resources to properly fulfill it's duty,
@ -159,11 +158,11 @@ func (l *Lookout) processEpoch(epoch *chainntnfs.BlockEpoch,
// Iterate over the transactions contained in the block, deriving a // Iterate over the transactions contained in the block, deriving a
// breach hint for each transaction and constructing an index mapping // breach hint for each transaction and constructing an index mapping
// the hint back to it's original transaction. // the hint back to it's original transaction.
hintToTx := make(map[wtdb.BreachHint]*wire.MsgTx, numTxnsInBlock) hintToTx := make(map[blob.BreachHint]*wire.MsgTx, numTxnsInBlock)
txHints := make([]wtdb.BreachHint, 0, numTxnsInBlock) txHints := make([]blob.BreachHint, 0, numTxnsInBlock)
for _, tx := range block.Transactions { for _, tx := range block.Transactions {
hash := tx.TxHash() hash := tx.TxHash()
hint := wtdb.NewBreachHintFromHash(&hash) hint := blob.NewBreachHintFromHash(&hash)
txHints = append(txHints, hint) txHints = append(txHints, hint)
hintToTx[hint] = tx hintToTx[hint] = tx
@ -203,13 +202,16 @@ func (l *Lookout) processEpoch(epoch *chainntnfs.BlockEpoch,
// The decryption key for the state update should be the full // The decryption key for the state update should be the full
// txid of the breaching commitment transaction. // txid of the breaching commitment transaction.
commitTxID := commitTx.TxHash() // The decryption key for the state update should be computed as
// key = SHA256(txid).
breachTxID := commitTx.TxHash()
breachKey := blob.NewBreachKeyFromHash(&breachTxID)
// Now, decrypt the blob of justice that we received in the // Now, decrypt the blob of justice that we received in the
// state update. This will contain all information required to // state update. This will contain all information required to
// sweep the breached commitment outputs. // sweep the breached commitment outputs.
justiceKit, err := blob.Decrypt( justiceKit, err := blob.Decrypt(
commitTxID[:], match.EncryptedBlob, breachKey, match.EncryptedBlob,
match.SessionInfo.Policy.BlobType, match.SessionInfo.Policy.BlobType,
) )
if err != nil { if err != nil {

@ -96,7 +96,10 @@ func TestLookoutBreachMatching(t *testing.T) {
sessionInfo1 := &wtdb.SessionInfo{ sessionInfo1 := &wtdb.SessionInfo{
ID: makeArray33(1), ID: makeArray33(1),
Policy: wtpolicy.Policy{ Policy: wtpolicy.Policy{
BlobType: rewardAndCommitType, TxPolicy: wtpolicy.TxPolicy{
BlobType: rewardAndCommitType,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
MaxUpdates: 10, MaxUpdates: 10,
}, },
RewardAddress: makeAddrSlice(22), RewardAddress: makeAddrSlice(22),
@ -104,7 +107,10 @@ func TestLookoutBreachMatching(t *testing.T) {
sessionInfo2 := &wtdb.SessionInfo{ sessionInfo2 := &wtdb.SessionInfo{
ID: makeArray33(2), ID: makeArray33(2),
Policy: wtpolicy.Policy{ Policy: wtpolicy.Policy{
BlobType: rewardAndCommitType, TxPolicy: wtpolicy.TxPolicy{
BlobType: rewardAndCommitType,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
MaxUpdates: 10, MaxUpdates: 10,
}, },
RewardAddress: makeAddrSlice(22), RewardAddress: makeAddrSlice(22),
@ -148,14 +154,17 @@ func TestLookoutBreachMatching(t *testing.T) {
CommitToLocalSig: makeArray64(2), CommitToLocalSig: makeArray64(2),
} }
// Encrypt the first justice kit under the txid of the first txn. key1 := blob.NewBreachKeyFromHash(&hash1)
encBlob1, err := blob1.Encrypt(hash1[:], blob.FlagCommitOutputs.Type()) key2 := blob.NewBreachKeyFromHash(&hash2)
// Encrypt the first justice kit under breach key one.
encBlob1, err := blob1.Encrypt(key1, blob.FlagCommitOutputs.Type())
if err != nil { if err != nil {
t.Fatalf("unable to encrypt sweep detail 1: %v", err) t.Fatalf("unable to encrypt sweep detail 1: %v", err)
} }
// Encrypt the second justice kit under the txid of the second txn. // Encrypt the second justice kit under breach key two.
encBlob2, err := blob2.Encrypt(hash2[:], blob.FlagCommitOutputs.Type()) encBlob2, err := blob2.Encrypt(key2, blob.FlagCommitOutputs.Type())
if err != nil { if err != nil {
t.Fatalf("unable to encrypt sweep detail 2: %v", err) t.Fatalf("unable to encrypt sweep detail 2: %v", err)
} }
@ -163,13 +172,13 @@ func TestLookoutBreachMatching(t *testing.T) {
// Add both state updates to the tower's database. // Add both state updates to the tower's database.
txBlob1 := &wtdb.SessionStateUpdate{ txBlob1 := &wtdb.SessionStateUpdate{
ID: makeArray33(1), ID: makeArray33(1),
Hint: wtdb.NewBreachHintFromHash(&hash1), Hint: blob.NewBreachHintFromHash(&hash1),
EncryptedBlob: encBlob1, EncryptedBlob: encBlob1,
SeqNum: 1, SeqNum: 1,
} }
txBlob2 := &wtdb.SessionStateUpdate{ txBlob2 := &wtdb.SessionStateUpdate{
ID: makeArray33(2), ID: makeArray33(2),
Hint: wtdb.NewBreachHintFromHash(&hash2), Hint: blob.NewBreachHintFromHash(&hash2),
EncryptedBlob: encBlob2, EncryptedBlob: encBlob2,
SeqNum: 1, SeqNum: 1,
} }

@ -78,13 +78,14 @@ func New(cfg *Config) (*Standalone, error) {
// Initialize the server with its required resources. // Initialize the server with its required resources.
server, err := wtserver.New(&wtserver.Config{ server, err := wtserver.New(&wtserver.Config{
ChainHash: cfg.ChainHash, ChainHash: cfg.ChainHash,
DB: cfg.DB, DB: cfg.DB,
NodePrivKey: cfg.NodePrivKey, NodePrivKey: cfg.NodePrivKey,
Listeners: listeners, Listeners: listeners,
ReadTimeout: cfg.ReadTimeout, ReadTimeout: cfg.ReadTimeout,
WriteTimeout: cfg.WriteTimeout, WriteTimeout: cfg.WriteTimeout,
NewAddress: cfg.NewAddress, NewAddress: cfg.NewAddress,
DisableReward: true,
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -173,9 +173,9 @@ func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error {
// required pieces from signatures, witness scripts, etc are then packaged into // required pieces from signatures, witness scripts, etc are then packaged into
// a JusticeKit and encrypted using the breach transaction's key. // a JusticeKit and encrypted using the breach transaction's key.
func (t *backupTask) craftSessionPayload( func (t *backupTask) craftSessionPayload(
signer input.Signer) (wtdb.BreachHint, []byte, error) { signer input.Signer) (blob.BreachHint, []byte, error) {
var hint wtdb.BreachHint var hint blob.BreachHint
// First, copy over the sweep pkscript, the pubkeys used to derive the // First, copy over the sweep pkscript, the pubkeys used to derive the
// to-local script, and the remote CSV delay. // to-local script, and the remote CSV delay.
@ -276,22 +276,19 @@ func (t *backupTask) craftSessionPayload(
} }
} }
// Compute the breach hint from the breach transaction id's prefix. breachTxID := t.breachInfo.BreachTransaction.TxHash()
breachKey := t.breachInfo.BreachTransaction.TxHash()
// Compute the breach key as SHA256(txid).
hint, key := blob.NewBreachHintAndKeyFromHash(&breachTxID)
// Then, we'll encrypt the computed justice kit using the full breach // Then, we'll encrypt the computed justice kit using the full breach
// transaction id, which will allow the tower to recover the contents // transaction id, which will allow the tower to recover the contents
// after the transaction is seen in the chain or mempool. // after the transaction is seen in the chain or mempool.
encBlob, err := justiceKit.Encrypt(breachKey[:], t.blobType) encBlob, err := justiceKit.Encrypt(key, t.blobType)
if err != nil { if err != nil {
return hint, nil, err return hint, nil, err
} }
// Finally, compute the breach hint, taken as the first half of the
// breach transactions txid. Once the tower sees the breach transaction
// on the network, it can use the full txid to decyrpt the blob.
hint = wtdb.NewBreachHintFromHash(&breachKey)
return hint, encBlob, nil return hint, encBlob, nil
} }

@ -207,9 +207,11 @@ func genTaskTest(
expRewardScript: rewardScript, expRewardScript: rewardScript,
session: &wtdb.ClientSessionBody{ session: &wtdb.ClientSessionBody{
Policy: wtpolicy.Policy{ Policy: wtpolicy.Policy{
BlobType: blobType, TxPolicy: wtpolicy.TxPolicy{
SweepFeeRate: sweepFeeRate, BlobType: blobType,
RewardRate: 10000, SweepFeeRate: sweepFeeRate,
RewardRate: 10000,
},
}, },
RewardPkScript: rewardScript, RewardPkScript: rewardScript,
}, },
@ -516,7 +518,7 @@ func testBackupTask(t *testing.T, test backupTaskTest) {
// Verify that the breach hint matches the breach txid's prefix. // Verify that the breach hint matches the breach txid's prefix.
breachTxID := test.breachInfo.BreachTransaction.TxHash() breachTxID := test.breachInfo.BreachTransaction.TxHash()
expHint := wtdb.NewBreachHintFromHash(&breachTxID) expHint := blob.NewBreachHintFromHash(&breachTxID)
if hint != expHint { if hint != expHint {
t.Fatalf("breach hint mismatch, want: %x, got: %v", t.Fatalf("breach hint mismatch, want: %x, got: %v",
expHint, hint) expHint, hint)
@ -524,7 +526,8 @@ func testBackupTask(t *testing.T, test backupTaskTest) {
// Decrypt the return blob to obtain the JusticeKit containing its // Decrypt the return blob to obtain the JusticeKit containing its
// contents. // contents.
jKit, err := blob.Decrypt(breachTxID[:], encBlob, policy.BlobType) key := blob.NewBreachKeyFromHash(&breachTxID)
jKit, err := blob.Decrypt(key, encBlob, policy.BlobType)
if err != nil { if err != nil {
t.Fatalf("unable to decrypt blob: %v", err) t.Fatalf("unable to decrypt blob: %v", err)
} }

@ -29,6 +29,11 @@ const (
// DefaultStatInterval specifies the default interval between logging // DefaultStatInterval specifies the default interval between logging
// metrics about the client's operation. // metrics about the client's operation.
DefaultStatInterval = 30 * time.Second DefaultStatInterval = 30 * time.Second
// DefaultForceQuitDelay specifies the default duration after which the
// client should abandon any pending updates or session negotiations
// before terminating.
DefaultForceQuitDelay = 10 * time.Second
) )
// Client is the primary interface used by the daemon to control a client's // Client is the primary interface used by the daemon to control a client's
@ -514,9 +519,10 @@ func (c *TowerClient) nextSessionQueue() *sessionQueue {
delete(c.candidateSessions, id) delete(c.candidateSessions, id)
// Skip any sessions with policies that don't match the current // Skip any sessions with policies that don't match the current
// configuration. These can be used again if the client changes // TxPolicy, as they would result in different justice
// their configuration back. // transactions from what is requested. These can be used again
if sessionInfo.Policy != c.cfg.Policy { // if the client changes their configuration and restarting.
if sessionInfo.Policy.TxPolicy != c.cfg.Policy.TxPolicy {
continue continue
} }
@ -561,6 +567,7 @@ func (c *TowerClient) backupDispatcher() {
// Wait until we receive the newly negotiated session. // Wait until we receive the newly negotiated session.
// All backups sent in the meantime are queued in the // All backups sent in the meantime are queued in the
// revoke queue, as we cannot process them. // revoke queue, as we cannot process them.
awaitSession:
select { select {
case session := <-c.negotiator.NewSessions(): case session := <-c.negotiator.NewSessions():
log.Infof("Acquired new session with id=%s", log.Infof("Acquired new session with id=%s",
@ -571,6 +578,12 @@ func (c *TowerClient) backupDispatcher() {
case <-c.statTicker.C: case <-c.statTicker.C:
log.Infof("Client stats: %s", c.stats) log.Infof("Client stats: %s", c.stats)
// Instead of looping, we'll jump back into the
// select case and await the delivery of the
// session to prevent us from re-requesting
// additional sessions.
goto awaitSession
case <-c.forceQuit: case <-c.forceQuit:
return return
} }
@ -626,9 +639,7 @@ func (c *TowerClient) backupDispatcher() {
return return
} }
log.Debugf("Processing backup task chanid=%s "+ log.Debugf("Processing %v", task.id)
"commit-height=%d", task.id.ChanID,
task.id.CommitHeight)
c.stats.taskReceived() c.stats.taskReceived()
c.processTask(task) c.processTask(task)
@ -659,8 +670,8 @@ func (c *TowerClient) processTask(task *backupTask) {
// sessionQueue will be removed if accepting the task left the sessionQueue in // sessionQueue will be removed if accepting the task left the sessionQueue in
// an exhausted state. // an exhausted state.
func (c *TowerClient) taskAccepted(task *backupTask, newStatus reserveStatus) { func (c *TowerClient) taskAccepted(task *backupTask, newStatus reserveStatus) {
log.Infof("Backup chanid=%s commit-height=%d accepted successfully", log.Infof("Queued %v successfully for session %v",
task.id.ChanID, task.id.CommitHeight) task.id, c.sessionQueue.ID())
c.stats.taskAccepted() c.stats.taskAccepted()
@ -701,16 +712,14 @@ func (c *TowerClient) taskRejected(task *backupTask, curStatus reserveStatus) {
case reserveAvailable: case reserveAvailable:
c.stats.taskIneligible() c.stats.taskIneligible()
log.Infof("Backup chanid=%s commit-height=%d is ineligible", log.Infof("Ignoring ineligible %v", task.id)
task.id.ChanID, task.id.CommitHeight)
err := c.cfg.DB.MarkBackupIneligible( err := c.cfg.DB.MarkBackupIneligible(
task.id.ChanID, task.id.CommitHeight, task.id.ChanID, task.id.CommitHeight,
) )
if err != nil { if err != nil {
log.Errorf("Unable to mark task chanid=%s "+ log.Errorf("Unable to mark %v ineligible: %v",
"commit-height=%d ineligible: %v", task.id, err)
task.id.ChanID, task.id.CommitHeight, err)
// It is safe to not handle this error, even if we could // It is safe to not handle this error, even if we could
// not persist the result. At worst, this task may be // not persist the result. At worst, this task may be
@ -729,10 +738,8 @@ func (c *TowerClient) taskRejected(task *backupTask, curStatus reserveStatus) {
case reserveExhausted: case reserveExhausted:
c.stats.sessionExhausted() c.stats.sessionExhausted()
log.Debugf("Session %s exhausted, backup chanid=%s "+ log.Debugf("Session %v exhausted, %s queued for next session",
"commit-height=%d queued for next session", c.sessionQueue.ID(), task.id)
c.sessionQueue.ID(), task.id.ChanID,
task.id.CommitHeight)
// Cache the task that we pulled off, so that we can process it // Cache the task that we pulled off, so that we can process it
// once a new session queue is available. // once a new session queue is available.

@ -576,17 +576,17 @@ func (h *testHarness) registerChannel(id uint64) {
// advanceChannelN calls advanceState on the channel identified by id the number // advanceChannelN calls advanceState on the channel identified by id the number
// of provided times and returns the breach hints corresponding to the new // of provided times and returns the breach hints corresponding to the new
// states. // states.
func (h *testHarness) advanceChannelN(id uint64, n int) []wtdb.BreachHint { func (h *testHarness) advanceChannelN(id uint64, n int) []blob.BreachHint {
h.t.Helper() h.t.Helper()
channel := h.channel(id) channel := h.channel(id)
var hints []wtdb.BreachHint var hints []blob.BreachHint
for i := uint64(0); i < uint64(n); i++ { for i := uint64(0); i < uint64(n); i++ {
channel.advanceState(h.t) channel.advanceState(h.t)
commitTx, _ := h.channel(id).getState(i) commitTx, _ := h.channel(id).getState(i)
breachTxID := commitTx.TxHash() breachTxID := commitTx.TxHash()
hints = append(hints, wtdb.NewBreachHintFromHash(&breachTxID)) hints = append(hints, blob.NewBreachHintFromHash(&breachTxID))
} }
return hints return hints
@ -621,18 +621,18 @@ func (h *testHarness) backupState(id, i uint64, expErr error) {
// party for each state in from-to times and returns the breach hints for states // party for each state in from-to times and returns the breach hints for states
// [from, to). // [from, to).
func (h *testHarness) sendPayments(id, from, to uint64, func (h *testHarness) sendPayments(id, from, to uint64,
amt lnwire.MilliSatoshi) []wtdb.BreachHint { amt lnwire.MilliSatoshi) []blob.BreachHint {
h.t.Helper() h.t.Helper()
channel := h.channel(id) channel := h.channel(id)
var hints []wtdb.BreachHint var hints []blob.BreachHint
for i := from; i < to; i++ { for i := from; i < to; i++ {
h.channel(id).sendPayment(h.t, amt) h.channel(id).sendPayment(h.t, amt)
commitTx, _ := channel.getState(i) commitTx, _ := channel.getState(i)
breachTxID := commitTx.TxHash() breachTxID := commitTx.TxHash()
hints = append(hints, wtdb.NewBreachHintFromHash(&breachTxID)) hints = append(hints, blob.NewBreachHintFromHash(&breachTxID))
} }
return hints return hints
@ -642,18 +642,18 @@ func (h *testHarness) sendPayments(id, from, to uint64,
// remote party for each state in from-to times and returns the breach hints for // remote party for each state in from-to times and returns the breach hints for
// states [from, to). // states [from, to).
func (h *testHarness) recvPayments(id, from, to uint64, func (h *testHarness) recvPayments(id, from, to uint64,
amt lnwire.MilliSatoshi) []wtdb.BreachHint { amt lnwire.MilliSatoshi) []blob.BreachHint {
h.t.Helper() h.t.Helper()
channel := h.channel(id) channel := h.channel(id)
var hints []wtdb.BreachHint var hints []blob.BreachHint
for i := from; i < to; i++ { for i := from; i < to; i++ {
channel.receivePayment(h.t, amt) channel.receivePayment(h.t, amt)
commitTx, _ := channel.getState(i) commitTx, _ := channel.getState(i)
breachTxID := commitTx.TxHash() breachTxID := commitTx.TxHash()
hints = append(hints, wtdb.NewBreachHintFromHash(&breachTxID)) hints = append(hints, blob.NewBreachHintFromHash(&breachTxID))
} }
return hints return hints
@ -662,7 +662,7 @@ func (h *testHarness) recvPayments(id, from, to uint64,
// waitServerUpdates blocks until the breach hints provided all appear in the // waitServerUpdates blocks until the breach hints provided all appear in the
// watchtower's database or the timeout expires. This is used to test that the // watchtower's database or the timeout expires. This is used to test that the
// client in fact sends the updates to the server, even if it is offline. // client in fact sends the updates to the server, even if it is offline.
func (h *testHarness) waitServerUpdates(hints []wtdb.BreachHint, func (h *testHarness) waitServerUpdates(hints []blob.BreachHint,
timeout time.Duration) { timeout time.Duration) {
h.t.Helper() h.t.Helper()
@ -671,7 +671,7 @@ func (h *testHarness) waitServerUpdates(hints []wtdb.BreachHint,
// assert that no updates appear. // assert that no updates appear.
wantUpdates := len(hints) > 0 wantUpdates := len(hints) > 0
hintSet := make(map[wtdb.BreachHint]struct{}) hintSet := make(map[blob.BreachHint]struct{})
for _, hint := range hints { for _, hint := range hints {
hintSet[hint] = struct{}{} hintSet[hint] = struct{}{}
} }
@ -737,7 +737,7 @@ func (h *testHarness) waitServerUpdates(hints []wtdb.BreachHint,
// assertUpdatesForPolicy queries the server db for matches using the provided // assertUpdatesForPolicy queries the server db for matches using the provided
// breach hints, then asserts that each match has a session with the expected // breach hints, then asserts that each match has a session with the expected
// policy. // policy.
func (h *testHarness) assertUpdatesForPolicy(hints []wtdb.BreachHint, func (h *testHarness) assertUpdatesForPolicy(hints []blob.BreachHint,
expPolicy wtpolicy.Policy) { expPolicy wtpolicy.Policy) {
// Query for matches on the provided hints. // Query for matches on the provided hints.
@ -785,9 +785,11 @@ var clientTests = []clientTest{
localBalance: localBalance, localBalance: localBalance,
remoteBalance: remoteBalance, remoteBalance: remoteBalance,
policy: wtpolicy.Policy{ policy: wtpolicy.Policy{
BlobType: blob.TypeDefault, TxPolicy: wtpolicy.TxPolicy{
MaxUpdates: 20000, BlobType: blob.TypeAltruistCommit,
SweepFeeRate: 1, SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
MaxUpdates: 20000,
}, },
noRegisterChan0: true, noRegisterChan0: true,
}, },
@ -817,9 +819,11 @@ var clientTests = []clientTest{
localBalance: localBalance, localBalance: localBalance,
remoteBalance: remoteBalance, remoteBalance: remoteBalance,
policy: wtpolicy.Policy{ policy: wtpolicy.Policy{
BlobType: blob.TypeDefault, TxPolicy: wtpolicy.TxPolicy{
MaxUpdates: 20000, BlobType: blob.TypeAltruistCommit,
SweepFeeRate: 1, SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
MaxUpdates: 20000,
}, },
}, },
fn: func(h *testHarness) { fn: func(h *testHarness) {
@ -850,9 +854,11 @@ var clientTests = []clientTest{
localBalance: localBalance, localBalance: localBalance,
remoteBalance: remoteBalance, remoteBalance: remoteBalance,
policy: wtpolicy.Policy{ policy: wtpolicy.Policy{
BlobType: blob.TypeDefault, TxPolicy: wtpolicy.TxPolicy{
MaxUpdates: 5, BlobType: blob.TypeAltruistCommit,
SweepFeeRate: 1, SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
MaxUpdates: 5,
}, },
}, },
fn: func(h *testHarness) { fn: func(h *testHarness) {
@ -884,9 +890,11 @@ var clientTests = []clientTest{
localBalance: localBalance, localBalance: localBalance,
remoteBalance: remoteBalance, remoteBalance: remoteBalance,
policy: wtpolicy.Policy{ policy: wtpolicy.Policy{
BlobType: blob.TypeDefault, TxPolicy: wtpolicy.TxPolicy{
MaxUpdates: 20000, BlobType: blob.TypeAltruistCommit,
SweepFeeRate: 1000000, // high sweep fee creates dust SweepFeeRate: 1000000, // high sweep fee creates dust
},
MaxUpdates: 20000,
}, },
}, },
fn: func(h *testHarness) { fn: func(h *testHarness) {
@ -913,9 +921,11 @@ var clientTests = []clientTest{
localBalance: localBalance, localBalance: localBalance,
remoteBalance: remoteBalance, remoteBalance: remoteBalance,
policy: wtpolicy.Policy{ policy: wtpolicy.Policy{
BlobType: blob.TypeDefault, TxPolicy: wtpolicy.TxPolicy{
MaxUpdates: 20000, BlobType: blob.TypeAltruistCommit,
SweepFeeRate: 1, SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
MaxUpdates: 20000,
}, },
}, },
fn: func(h *testHarness) { fn: func(h *testHarness) {
@ -993,9 +1003,11 @@ var clientTests = []clientTest{
localBalance: localBalance, localBalance: localBalance,
remoteBalance: remoteBalance, remoteBalance: remoteBalance,
policy: wtpolicy.Policy{ policy: wtpolicy.Policy{
BlobType: blob.TypeDefault, TxPolicy: wtpolicy.TxPolicy{
MaxUpdates: 5, BlobType: blob.TypeAltruistCommit,
SweepFeeRate: 1, SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
MaxUpdates: 5,
}, },
}, },
fn: func(h *testHarness) { fn: func(h *testHarness) {
@ -1049,9 +1061,11 @@ var clientTests = []clientTest{
localBalance: 10000001, // ensure (% amt != 0) localBalance: 10000001, // ensure (% amt != 0)
remoteBalance: 20000001, // ensure (% amt != 0) remoteBalance: 20000001, // ensure (% amt != 0)
policy: wtpolicy.Policy{ policy: wtpolicy.Policy{
BlobType: blob.TypeDefault, TxPolicy: wtpolicy.TxPolicy{
MaxUpdates: 1000, BlobType: blob.TypeAltruistCommit,
SweepFeeRate: 1, SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
MaxUpdates: 1000,
}, },
}, },
fn: func(h *testHarness) { fn: func(h *testHarness) {
@ -1091,9 +1105,11 @@ var clientTests = []clientTest{
localBalance: localBalance, localBalance: localBalance,
remoteBalance: remoteBalance, remoteBalance: remoteBalance,
policy: wtpolicy.Policy{ policy: wtpolicy.Policy{
BlobType: blob.TypeDefault, TxPolicy: wtpolicy.TxPolicy{
MaxUpdates: 5, BlobType: blob.TypeAltruistCommit,
SweepFeeRate: 1, SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
MaxUpdates: 5,
}, },
}, },
fn: func(h *testHarness) { fn: func(h *testHarness) {
@ -1113,7 +1129,7 @@ var clientTests = []clientTest{
// Generate the retributions for all 10 channels and // Generate the retributions for all 10 channels and
// collect the breach hints. // collect the breach hints.
var hints []wtdb.BreachHint var hints []blob.BreachHint
for id := uint64(0); id < 10; id++ { for id := uint64(0); id < 10; id++ {
chanHints := h.advanceChannelN(id, numUpdates) chanHints := h.advanceChannelN(id, numUpdates)
hints = append(hints, chanHints...) hints = append(hints, chanHints...)
@ -1139,9 +1155,11 @@ var clientTests = []clientTest{
localBalance: localBalance, localBalance: localBalance,
remoteBalance: remoteBalance, remoteBalance: remoteBalance,
policy: wtpolicy.Policy{ policy: wtpolicy.Policy{
BlobType: blob.TypeDefault, TxPolicy: wtpolicy.TxPolicy{
MaxUpdates: 5, BlobType: blob.TypeAltruistCommit,
SweepFeeRate: 1, SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
MaxUpdates: 5,
}, },
noAckCreateSession: true, noAckCreateSession: true,
}, },
@ -1195,9 +1213,11 @@ var clientTests = []clientTest{
localBalance: localBalance, localBalance: localBalance,
remoteBalance: remoteBalance, remoteBalance: remoteBalance,
policy: wtpolicy.Policy{ policy: wtpolicy.Policy{
BlobType: blob.TypeDefault, TxPolicy: wtpolicy.TxPolicy{
MaxUpdates: 5, BlobType: blob.TypeAltruistCommit,
SweepFeeRate: 1, SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
MaxUpdates: 5,
}, },
noAckCreateSession: true, noAckCreateSession: true,
}, },
@ -1230,7 +1250,7 @@ var clientTests = []clientTest{
// Restart the client with a new policy, which will // Restart the client with a new policy, which will
// immediately try to overwrite the prior session with // immediately try to overwrite the prior session with
// the old policy. // the old policy.
h.clientCfg.Policy.SweepFeeRate = 2 h.clientCfg.Policy.SweepFeeRate *= 2
h.startClient() h.startClient()
defer h.client.ForceQuit() defer h.client.ForceQuit()
@ -1246,6 +1266,67 @@ var clientTests = []clientTest{
h.assertUpdatesForPolicy(hints, h.clientCfg.Policy) h.assertUpdatesForPolicy(hints, h.clientCfg.Policy)
}, },
}, },
{
// Asserts that the client will not request a new session if
// already has an existing session with the same TxPolicy. This
// permits the client to continue using policies that differ in
// operational parameters, but don't manifest in different
// justice transactions.
name: "create session change policy same txpolicy",
cfg: harnessCfg{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
MaxUpdates: 10,
},
},
fn: func(h *testHarness) {
const (
chanID = 0
numUpdates = 6
)
// Generate the retributions that will be backed up.
hints := h.advanceChannelN(chanID, numUpdates)
// Now, queue the first half of the retributions.
h.backupStates(chanID, 0, numUpdates/2, nil)
// Wait for the server to collect the first half.
h.waitServerUpdates(hints[:numUpdates/2], time.Second)
// Stop the client, which should have no more backups.
h.client.Stop()
// Record the policy that the first half was stored
// under. We'll expect the second half to also be stored
// under the original policy, since we are only adjusting
// the MaxUpdates. The client should detect that the
// two policies have equivalent TxPolicies and continue
// using the first.
expPolicy := h.clientCfg.Policy
// Restart the client with a new policy.
h.clientCfg.Policy.MaxUpdates = 20
h.startClient()
defer h.client.ForceQuit()
// Now, queue the second half of the retributions.
h.backupStates(chanID, numUpdates/2, numUpdates, nil)
// Wait for all of the updates to be populated in the
// server's database.
h.waitServerUpdates(hints, 5*time.Second)
// Assert that the server has updates for the client's
// original policy.
h.assertUpdatesForPolicy(hints, expPolicy)
},
},
{ {
// Asserts that the client will deduplicate backups presented by // Asserts that the client will deduplicate backups presented by
// a channel both in memory and after a restart. The client // a channel both in memory and after a restart. The client
@ -1256,9 +1337,11 @@ var clientTests = []clientTest{
localBalance: localBalance, localBalance: localBalance,
remoteBalance: remoteBalance, remoteBalance: remoteBalance,
policy: wtpolicy.Policy{ policy: wtpolicy.Policy{
BlobType: blob.TypeDefault, TxPolicy: wtpolicy.TxPolicy{
MaxUpdates: 5, BlobType: blob.TypeAltruistCommit,
SweepFeeRate: 1, SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
MaxUpdates: 5,
}, },
}, },
fn: func(h *testHarness) { fn: func(h *testHarness) {

@ -112,7 +112,7 @@ var _ SessionNegotiator = (*sessionNegotiator)(nil)
// newSessionNegotiator initializes a fresh sessionNegotiator instance. // newSessionNegotiator initializes a fresh sessionNegotiator instance.
func newSessionNegotiator(cfg *NegotiatorConfig) *sessionNegotiator { func newSessionNegotiator(cfg *NegotiatorConfig) *sessionNegotiator {
localInit := wtwire.NewInitMessage( localInit := wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired), lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired),
cfg.ChainHash, cfg.ChainHash,
) )

@ -109,7 +109,7 @@ type sessionQueue struct {
// newSessionQueue intiializes a fresh sessionQueue. // newSessionQueue intiializes a fresh sessionQueue.
func newSessionQueue(cfg *sessionQueueConfig) *sessionQueue { func newSessionQueue(cfg *sessionQueueConfig) *sessionQueue {
localInit := wtwire.NewInitMessage( localInit := wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired), lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired),
cfg.ChainHash, cfg.ChainHash,
) )
@ -156,7 +156,7 @@ func (q *sessionQueue) Start() {
// will clear all pending tasks in the queue before returning to the caller. // will clear all pending tasks in the queue before returning to the caller.
func (q *sessionQueue) Stop() { func (q *sessionQueue) Stop() {
q.stopped.Do(func() { q.stopped.Do(func() {
log.Debugf("Stopping session queue %s", q.ID()) log.Debugf("SessionQueue(%s) stopping ...", q.ID())
close(q.quit) close(q.quit)
q.signalUntilShutdown() q.signalUntilShutdown()
@ -168,7 +168,7 @@ func (q *sessionQueue) Stop() {
default: default:
} }
log.Debugf("Session queue %s successfully stopped", q.ID()) log.Debugf("SessionQueue(%s) stopped", q.ID())
}) })
} }
@ -176,12 +176,12 @@ func (q *sessionQueue) Stop() {
// he caller after all lingering goroutines have spun down. // he caller after all lingering goroutines have spun down.
func (q *sessionQueue) ForceQuit() { func (q *sessionQueue) ForceQuit() {
q.forced.Do(func() { q.forced.Do(func() {
log.Infof("Force quitting session queue %s", q.ID()) log.Infof("SessionQueue(%s) force quitting...", q.ID())
close(q.forceQuit) close(q.forceQuit)
q.signalUntilShutdown() q.signalUntilShutdown()
log.Infof("Session queue %s unclean shutdown complete", q.ID()) log.Infof("SessionQueue(%s) force quit", q.ID())
}) })
} }
@ -197,8 +197,15 @@ func (q *sessionQueue) ID() *wtdb.SessionID {
func (q *sessionQueue) AcceptTask(task *backupTask) (reserveStatus, bool) { func (q *sessionQueue) AcceptTask(task *backupTask) (reserveStatus, bool) {
q.queueCond.L.Lock() q.queueCond.L.Lock()
numPending := uint32(q.pendingQueue.Len())
maxUpdates := q.cfg.ClientSession.Policy.MaxUpdates
log.Debugf("SessionQueue(%x) deciding to accept %v seqnum=%d "+
"pending=%d max-updates=%d",
q.ID(), task.id, q.seqNum, numPending, maxUpdates)
// Examine the current reserve status of the session queue. // Examine the current reserve status of the session queue.
curStatus := q.reserveStatus() curStatus := q.reserveStatus()
switch curStatus { switch curStatus {
// The session queue is exhausted, and cannot accept the task because it // The session queue is exhausted, and cannot accept the task because it
@ -218,9 +225,8 @@ func (q *sessionQueue) AcceptTask(task *backupTask) (reserveStatus, bool) {
err := task.bindSession(&q.cfg.ClientSession.ClientSessionBody) err := task.bindSession(&q.cfg.ClientSession.ClientSessionBody)
if err != nil { if err != nil {
q.queueCond.L.Unlock() q.queueCond.L.Unlock()
log.Debugf("SessionQueue %s rejected backup chanid=%s "+ log.Debugf("SessionQueue(%s) rejected %v: %v ",
"commit-height=%d: %v", q.ID(), task.id.ChanID, q.ID(), task.id, err)
task.id.CommitHeight, err)
return curStatus, false return curStatus, false
} }
} }
@ -288,8 +294,8 @@ func (q *sessionQueue) drainBackups() {
// First, check that we are able to dial this session's tower. // First, check that we are able to dial this session's tower.
conn, err := q.cfg.Dial(q.cfg.ClientSession.SessionPrivKey, q.towerAddr) conn, err := q.cfg.Dial(q.cfg.ClientSession.SessionPrivKey, q.towerAddr)
if err != nil { if err != nil {
log.Errorf("Unable to dial watchtower at %v: %v", log.Errorf("SessionQueue(%s) unable to dial tower at %v: %v",
q.towerAddr, err) q.ID(), q.towerAddr, err)
q.increaseBackoff() q.increaseBackoff()
select { select {
@ -308,9 +314,10 @@ func (q *sessionQueue) drainBackups() {
// Generate the next state update to upload to the tower. This // Generate the next state update to upload to the tower. This
// method will first proceed in dequeueing committed updates // method will first proceed in dequeueing committed updates
// before attempting to dequeue any pending updates. // before attempting to dequeue any pending updates.
stateUpdate, isPending, err := q.nextStateUpdate() stateUpdate, isPending, backupID, err := q.nextStateUpdate()
if err != nil { if err != nil {
log.Errorf("Unable to get next state update: %v", err) log.Errorf("SessionQueue(%s) unable to get next state "+
"update: %v", err)
return return
} }
@ -319,7 +326,8 @@ func (q *sessionQueue) drainBackups() {
conn, stateUpdate, q.localInit, sendInit, isPending, conn, stateUpdate, q.localInit, sendInit, isPending,
) )
if err != nil { if err != nil {
log.Errorf("Unable to send state update: %v", err) log.Errorf("SessionQueue(%s) unable to send state "+
"update: %v", q.ID(), err)
q.increaseBackoff() q.increaseBackoff()
select { select {
@ -329,6 +337,9 @@ func (q *sessionQueue) drainBackups() {
return return
} }
log.Infof("SessionQueue(%s) uploaded %v seqnum=%d",
q.ID(), backupID, stateUpdate.SeqNum)
// If the last task was backed up successfully, we'll exit and // If the last task was backed up successfully, we'll exit and
// continue once more tasks are added to the queue. We'll also // continue once more tasks are added to the queue. We'll also
// clear any accumulated backoff as this batch was able to be // clear any accumulated backoff as this batch was able to be
@ -357,7 +368,9 @@ func (q *sessionQueue) drainBackups() {
// boolean value in the response is true if the state update is taken from the // boolean value in the response is true if the state update is taken from the
// pending queue, allowing the caller to remove the update from either the // pending queue, allowing the caller to remove the update from either the
// commit or pending queue if the update is successfully acked. // commit or pending queue if the update is successfully acked.
func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) { func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool,
wtdb.BackupID, error) {
var ( var (
seqNum uint16 seqNum uint16
update wtdb.CommittedUpdate update wtdb.CommittedUpdate
@ -382,8 +395,9 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) {
isLast = q.commitQueue.Len() == 1 && q.pendingQueue.Len() == 0 isLast = q.commitQueue.Len() == 1 && q.pendingQueue.Len() == 0
q.queueCond.L.Unlock() q.queueCond.L.Unlock()
log.Debugf("Reprocessing committed state update for "+ log.Debugf("SessionQueue(%s) reprocessing committed state "+
"session=%s seqnum=%d", q.ID(), seqNum) "update for %v seqnum=%d",
q.ID(), update.BackupID, seqNum)
// Otherwise, craft and commit the next update from the pending queue. // Otherwise, craft and commit the next update from the pending queue.
default: default:
@ -407,8 +421,9 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) {
hint, encBlob, err := task.craftSessionPayload(q.cfg.Signer) hint, encBlob, err := task.craftSessionPayload(q.cfg.Signer)
if err != nil { if err != nil {
// TODO(conner): mark will not send // TODO(conner): mark will not send
return nil, false, fmt.Errorf("unable to craft "+ err := fmt.Errorf("unable to craft session payload: %v",
"session payload: %v", err) err)
return nil, false, wtdb.BackupID{}, err
} }
// TODO(conner): special case other obscure errors // TODO(conner): special case other obscure errors
@ -421,8 +436,8 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) {
}, },
} }
log.Debugf("Committing state update for session=%s seqnum=%d", log.Debugf("SessionQueue(%s) committing state update "+
q.ID(), seqNum) "%v seqnum=%d", q.ID(), update.BackupID, seqNum)
} }
// Before sending the task to the tower, commit the state update // Before sending the task to the tower, commit the state update
@ -439,8 +454,9 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) {
lastApplied, err := q.cfg.DB.CommitUpdate(q.ID(), &update) lastApplied, err := q.cfg.DB.CommitUpdate(q.ID(), &update)
if err != nil { if err != nil {
// TODO(conner): mark failed/reschedule // TODO(conner): mark failed/reschedule
return nil, false, fmt.Errorf("unable to commit state update "+ err := fmt.Errorf("unable to commit state update for "+
"for session=%s seqnum=%d: %v", q.ID(), seqNum, err) "%v seqnum=%d: %v", update.BackupID, seqNum, err)
return nil, false, wtdb.BackupID{}, err
} }
stateUpdate := &wtwire.StateUpdate{ stateUpdate := &wtwire.StateUpdate{
@ -455,7 +471,7 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) {
stateUpdate.IsComplete = 1 stateUpdate.IsComplete = 1
} }
return stateUpdate, isPending, nil return stateUpdate, isPending, update.BackupID, nil
} }
// sendStateUpdate sends a wtwire.StateUpdate to the watchtower and processes // sendStateUpdate sends a wtwire.StateUpdate to the watchtower and processes
@ -486,8 +502,8 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer,
remoteInit, ok := remoteMsg.(*wtwire.Init) remoteInit, ok := remoteMsg.(*wtwire.Init)
if !ok { if !ok {
return fmt.Errorf("watchtower responded with %T to "+ return fmt.Errorf("watchtower %s responded with %T "+
"Init", remoteMsg) "to Init", q.towerAddr, remoteMsg)
} }
// Validate Init. // Validate Init.
@ -513,8 +529,8 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer,
stateUpdateReply, ok := remoteMsg.(*wtwire.StateUpdateReply) stateUpdateReply, ok := remoteMsg.(*wtwire.StateUpdateReply)
if !ok { if !ok {
return fmt.Errorf("watchtower responded with %T to StateUpdate", return fmt.Errorf("watchtower %s responded with %T to "+
remoteMsg) "StateUpdate", q.towerAddr, remoteMsg)
} }
// Process the reply from the tower. // Process the reply from the tower.
@ -527,10 +543,10 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer,
// TODO(conner): handle other error cases properly, ban towers, etc. // TODO(conner): handle other error cases properly, ban towers, etc.
default: default:
err := fmt.Errorf("received error code %v in "+ err := fmt.Errorf("received error code %v in "+
"StateUpdateReply from tower=%x session=%v", "StateUpdateReply for seqnum=%d",
stateUpdateReply.Code, stateUpdateReply.Code, stateUpdate.SeqNum)
conn.RemotePub().SerializeCompressed(), q.ID()) log.Warnf("SessionQueue(%s) unable to upload state update to "+
log.Warnf("Unable to upload state update: %v", err) "tower=%s: %v", q.ID(), q.towerAddr, err)
return err return err
} }
@ -539,28 +555,27 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer,
switch { switch {
case err == wtdb.ErrUnallocatedLastApplied: case err == wtdb.ErrUnallocatedLastApplied:
// TODO(conner): borked watchtower // TODO(conner): borked watchtower
err = fmt.Errorf("unable to ack update=%d session=%s: %v", err = fmt.Errorf("unable to ack seqnum=%d: %v",
stateUpdate.SeqNum, q.ID(), err) stateUpdate.SeqNum, err)
log.Errorf("Failed to ack update: %v", err) log.Errorf("SessionQueue(%s) failed to ack update: %v", err)
return err return err
case err == wtdb.ErrLastAppliedReversion: case err == wtdb.ErrLastAppliedReversion:
// TODO(conner): borked watchtower // TODO(conner): borked watchtower
err = fmt.Errorf("unable to ack update=%d session=%s: %v", err = fmt.Errorf("unable to ack seqnum=%d: %v",
stateUpdate.SeqNum, q.ID(), err) stateUpdate.SeqNum, err)
log.Errorf("Failed to ack update: %v", err) log.Errorf("SessionQueue(%s) failed to ack update: %v",
q.ID(), err)
return err return err
case err != nil: case err != nil:
err = fmt.Errorf("unable to ack update=%d session=%s: %v", err = fmt.Errorf("unable to ack seqnum=%d: %v",
stateUpdate.SeqNum, q.ID(), err) stateUpdate.SeqNum, err)
log.Errorf("Failed to ack update: %v", err) log.Errorf("SessionQueue(%s) failed to ack update: %v",
q.ID(), err)
return err return err
} }
log.Infof("Removing update session=%s seqnum=%d is_pending=%v "+
"from memory", q.ID(), stateUpdate.SeqNum, isPending)
q.queueCond.L.Lock() q.queueCond.L.Lock()
if isPending { if isPending {
// If a pending update was successfully sent, increment the // If a pending update was successfully sent, increment the
@ -591,9 +606,6 @@ func (q *sessionQueue) reserveStatus() reserveStatus {
numPending := uint32(q.pendingQueue.Len()) numPending := uint32(q.pendingQueue.Len())
maxUpdates := uint32(q.cfg.ClientSession.Policy.MaxUpdates) maxUpdates := uint32(q.cfg.ClientSession.Policy.MaxUpdates)
log.Debugf("SessionQueue %s reserveStatus seqnum=%d pending=%d "+
"max-updates=%d", q.ID(), q.seqNum, numPending, maxUpdates)
if uint32(q.seqNum)+numPending < maxUpdates { if uint32(q.seqNum)+numPending < maxUpdates {
return reserveAvailable return reserveAvailable
} }

@ -1,27 +0,0 @@
package wtdb
import (
"encoding/hex"
"github.com/btcsuite/btcd/chaincfg/chainhash"
)
// BreachHintSize is the length of the txid prefix used to identify remote
// commitment broadcasts.
const BreachHintSize = 16
// BreachHint is the first 16-bytes of the txid belonging to a revoked
// commitment transaction.
type BreachHint [BreachHintSize]byte
// NewBreachHintFromHash creates a breach hint from a transaction ID.
func NewBreachHintFromHash(hash *chainhash.Hash) BreachHint {
var hint BreachHint
copy(hint[:], hash[:BreachHintSize])
return hint
}
// String returns a hex encoding of the breach hint.
func (h BreachHint) String() string {
return hex.EncodeToString(h[:])
}

@ -664,7 +664,7 @@ func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate {
t.Fatalf("unable to generate chan id: %v", err) t.Fatalf("unable to generate chan id: %v", err)
} }
var hint wtdb.BreachHint var hint blob.BreachHint
if _, err := io.ReadFull(crand.Reader, hint[:]); err != nil { if _, err := io.ReadFull(crand.Reader, hint[:]); err != nil {
t.Fatalf("unable to generate breach hint: %v", err) t.Fatalf("unable to generate breach hint: %v", err)
} }

@ -1,10 +1,12 @@
package wtdb package wtdb
import ( import (
"fmt"
"io" "io"
"github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/btcec"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy" "github.com/lightningnetwork/lnd/watchtower/wtpolicy"
) )
@ -159,6 +161,11 @@ func (b *BackupID) Decode(r io.Reader) error {
) )
} }
// String returns a human-readable encoding of a BackupID.
func (b *BackupID) String() string {
return fmt.Sprintf("backup(%x, %d)", b.ChanID, b.CommitHeight)
}
// CommittedUpdate holds a state update sent by a client along with its // CommittedUpdate holds a state update sent by a client along with its
// allocated sequence number and the exact remote commitment the encrypted // allocated sequence number and the exact remote commitment the encrypted
// justice transaction can rectify. // justice transaction can rectify.
@ -178,7 +185,7 @@ type CommittedUpdateBody struct {
BackupID BackupID BackupID BackupID
// Hint is the 16-byte prefix of the revoked commitment transaction ID. // Hint is the 16-byte prefix of the revoked commitment transaction ID.
Hint BreachHint Hint blob.BreachHint
// EncryptedBlob is a ciphertext containing the sweep information for // EncryptedBlob is a ciphertext containing the sweep information for
// exacting justice if the commitment transaction matching the breach // exacting justice if the commitment transaction matching the breach

@ -36,7 +36,7 @@ func ReadElement(r io.Reader, element interface{}) error {
return err return err
} }
case *BreachHint: case *blob.BreachHint:
if _, err := io.ReadFull(r, e[:]); err != nil { if _, err := io.ReadFull(r, e[:]); err != nil {
return err return err
} }
@ -94,7 +94,7 @@ func WriteElement(w io.Writer, element interface{}) error {
return err return err
} }
case BreachHint: case blob.BreachHint:
if _, err := w.Write(e[:]); err != nil { if _, err := w.Write(e[:]); err != nil {
return err return err
} }

@ -4,6 +4,7 @@ import (
"errors" "errors"
"io" "io"
"github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy" "github.com/lightningnetwork/lnd/watchtower/wtpolicy"
) )
@ -134,7 +135,7 @@ type Match struct {
SeqNum uint16 SeqNum uint16
// Hint is the breach hint that triggered the match. // Hint is the breach hint that triggered the match.
Hint BreachHint Hint blob.BreachHint
// EncryptedBlob is the encrypted payload containing the justice kit // EncryptedBlob is the encrypted payload containing the justice kit
// uploaded by the client. // uploaded by the client.

@ -1,6 +1,10 @@
package wtdb package wtdb
import "io" import (
"io"
"github.com/lightningnetwork/lnd/watchtower/blob"
)
// SessionStateUpdate holds a state update sent by a client along with its // SessionStateUpdate holds a state update sent by a client along with its
// SessionID. // SessionID.
@ -16,7 +20,7 @@ type SessionStateUpdate struct {
LastApplied uint16 LastApplied uint16
// Hint is the 16-byte prefix of the revoked commitment transaction. // Hint is the 16-byte prefix of the revoked commitment transaction.
Hint BreachHint Hint blob.BreachHint
// EncryptedBlob is a ciphertext containing the sweep information for // EncryptedBlob is a ciphertext containing the sweep information for
// exacting justice if the commitment transaction matching the breach // exacting justice if the commitment transaction matching the breach

@ -7,6 +7,7 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/coreos/bbolt" "github.com/coreos/bbolt"
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/watchtower/blob"
) )
const ( const (
@ -45,6 +46,10 @@ var (
// ErrNoSessionHintIndex signals that an active session does not have an // ErrNoSessionHintIndex signals that an active session does not have an
// initialized index for tracking its own state updates. // initialized index for tracking its own state updates.
ErrNoSessionHintIndex = errors.New("session hint index missing") ErrNoSessionHintIndex = errors.New("session hint index missing")
// ErrInvalidBlobSize indicates that the encrypted blob provided by the
// client is not valid according to the blob type of the session.
ErrInvalidBlobSize = errors.New("invalid blob size")
) )
// TowerDB is single database providing a persistent storage engine for the // TowerDB is single database providing a persistent storage engine for the
@ -188,6 +193,12 @@ func (t *TowerDB) InsertSessionInfo(session *SessionInfo) error {
return ErrSessionAlreadyExists return ErrSessionAlreadyExists
} }
// Perform a quick sanity check on the session policy before
// accepting.
if err := session.Policy.Validate(); err != nil {
return err
}
err = putSession(sessions, session) err = putSession(sessions, session)
if err != nil { if err != nil {
return err return err
@ -232,6 +243,13 @@ func (t *TowerDB) InsertStateUpdate(update *SessionStateUpdate) (uint16, error)
return err return err
} }
// Assert that the blob is the correct size for the session's
// blob type.
expBlobSize := blob.Size(session.Policy.BlobType)
if len(update.EncryptedBlob) != expBlobSize {
return ErrInvalidBlobSize
}
// Validate the update against the current state of the session. // Validate the update against the current state of the session.
err = session.AcceptUpdateSequence( err = session.AcceptUpdateSequence(
update.SeqNum, update.LastApplied, update.SeqNum, update.LastApplied,
@ -369,7 +387,7 @@ func (t *TowerDB) DeleteSession(target SessionID) error {
// QueryMatches searches against all known state updates for any that match the // QueryMatches searches against all known state updates for any that match the
// passed breachHints. More than one Match will be returned for a given hint if // passed breachHints. More than one Match will be returned for a given hint if
// they exist in the database. // they exist in the database.
func (t *TowerDB) QueryMatches(breachHints []BreachHint) ([]Match, error) { func (t *TowerDB) QueryMatches(breachHints []blob.BreachHint) ([]Match, error) {
var matches []Match var matches []Match
err := t.db.View(func(tx *bbolt.Tx) error { err := t.db.View(func(tx *bbolt.Tx) error {
sessions := tx.Bucket(sessionsBkt) sessions := tx.Bucket(sessionsBkt)
@ -534,20 +552,20 @@ func removeSessionHintBkt(updateIndex *bbolt.Bucket, id *SessionID) error {
// If the index for the session has not been initialized, this method returns // If the index for the session has not been initialized, this method returns
// ErrNoSessionHintIndex. // ErrNoSessionHintIndex.
func getHintsForSession(updateIndex *bbolt.Bucket, func getHintsForSession(updateIndex *bbolt.Bucket,
id *SessionID) ([]BreachHint, error) { id *SessionID) ([]blob.BreachHint, error) {
sessionHints := updateIndex.Bucket(id[:]) sessionHints := updateIndex.Bucket(id[:])
if sessionHints == nil { if sessionHints == nil {
return nil, ErrNoSessionHintIndex return nil, ErrNoSessionHintIndex
} }
var hints []BreachHint var hints []blob.BreachHint
err := sessionHints.ForEach(func(k, _ []byte) error { err := sessionHints.ForEach(func(k, _ []byte) error {
if len(k) != BreachHintSize { if len(k) != blob.BreachHintSize {
return nil return nil
} }
var hint BreachHint var hint blob.BreachHint
copy(hint[:], k) copy(hint[:], k)
hints = append(hints, hint) hints = append(hints, hint)
return nil return nil
@ -565,7 +583,7 @@ func getHintsForSession(updateIndex *bbolt.Bucket,
// for the session has not been initialized, this method returns // for the session has not been initialized, this method returns
// ErrNoSessionHintIndex. // ErrNoSessionHintIndex.
func putHintForSession(updateIndex *bbolt.Bucket, id *SessionID, func putHintForSession(updateIndex *bbolt.Bucket, id *SessionID,
hint BreachHint) error { hint blob.BreachHint) error {
sessionHints := updateIndex.Bucket(id[:]) sessionHints := updateIndex.Bucket(id[:])
if sessionHints == nil { if sessionHints == nil {

@ -1,6 +1,7 @@
package wtdb_test package wtdb_test
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"io/ioutil" "io/ioutil"
"os" "os"
@ -10,11 +11,16 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/watchtower" "github.com/lightningnetwork/lnd/watchtower"
"github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtmock" "github.com/lightningnetwork/lnd/watchtower/wtmock"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy" "github.com/lightningnetwork/lnd/watchtower/wtpolicy"
) )
var (
testBlob = make([]byte, blob.Size(blob.TypeAltruistCommit))
)
// dbInit is a closure used to initialize a watchtower.DB instance and its // dbInit is a closure used to initialize a watchtower.DB instance and its
// cleanup function. // cleanup function.
type dbInit func(*testing.T) (watchtower.DB, func()) type dbInit func(*testing.T) (watchtower.DB, func())
@ -97,10 +103,10 @@ func (h *towerDBHarness) deleteSession(id wtdb.SessionID, expErr error) {
// queryMatches queries that database for the passed breach hint, returning all // queryMatches queries that database for the passed breach hint, returning all
// matches found. // matches found.
func (h *towerDBHarness) queryMatches(hint wtdb.BreachHint) []wtdb.Match { func (h *towerDBHarness) queryMatches(hint blob.BreachHint) []wtdb.Match {
h.t.Helper() h.t.Helper()
matches, err := h.db.QueryMatches([]wtdb.BreachHint{hint}) matches, err := h.db.QueryMatches([]blob.BreachHint{hint})
if err != nil { if err != nil {
h.t.Fatalf("unable to query matches: %v", err) h.t.Fatalf("unable to query matches: %v", err)
} }
@ -111,7 +117,7 @@ func (h *towerDBHarness) queryMatches(hint wtdb.BreachHint) []wtdb.Match {
// hasUpdate queries the database for the passed breach hint, asserting that // hasUpdate queries the database for the passed breach hint, asserting that
// only one match is present and that the hints indeed match. If successful, the // only one match is present and that the hints indeed match. If successful, the
// match is returned. // match is returned.
func (h *towerDBHarness) hasUpdate(hint wtdb.BreachHint) wtdb.Match { func (h *towerDBHarness) hasUpdate(hint blob.BreachHint) wtdb.Match {
h.t.Helper() h.t.Helper()
matches := h.queryMatches(hint) matches := h.queryMatches(hint)
@ -136,11 +142,21 @@ func testInsertSession(h *towerDBHarness) {
session := &wtdb.SessionInfo{ session := &wtdb.SessionInfo{
ID: id, ID: id,
Policy: wtpolicy.Policy{ Policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
},
MaxUpdates: 100, MaxUpdates: 100,
}, },
RewardAddress: []byte{0x01, 0x02, 0x03}, RewardAddress: []byte{0x01, 0x02, 0x03},
} }
// Try to insert the session, which should fail since the policy doesn't
// meet the current sanity checks.
h.insertSession(session, wtpolicy.ErrSweepFeeRateTooLow)
// Now assign a sane sweep fee rate to the policy, inserting should
// succeed.
session.Policy.SweepFeeRate = wtpolicy.DefaultSweepFeeRate
h.insertSession(session, nil) h.insertSession(session, nil)
session2 := h.getSession(&id, nil) session2 := h.getSession(&id, nil)
@ -154,8 +170,9 @@ func testInsertSession(h *towerDBHarness) {
// Insert a state update to fully commit the session parameters. // Insert a state update to fully commit the session parameters.
update := &wtdb.SessionStateUpdate{ update := &wtdb.SessionStateUpdate{
ID: id, ID: id,
SeqNum: 1, SeqNum: 1,
EncryptedBlob: testBlob,
} }
h.insertUpdate(update, nil) h.insertUpdate(update, nil)
@ -169,12 +186,16 @@ func testMultipleMatches(h *towerDBHarness) {
const numUpdates = 3 const numUpdates = 3
// Create a new session and send updates with all the same hint. // Create a new session and send updates with all the same hint.
var hint wtdb.BreachHint var hint blob.BreachHint
for i := 0; i < numUpdates; i++ { for i := 0; i < numUpdates; i++ {
id := *id(i) id := *id(i)
session := &wtdb.SessionInfo{ session := &wtdb.SessionInfo{
ID: id, ID: id,
Policy: wtpolicy.Policy{ Policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
MaxUpdates: 3, MaxUpdates: 3,
}, },
RewardAddress: []byte{}, RewardAddress: []byte{},
@ -182,9 +203,10 @@ func testMultipleMatches(h *towerDBHarness) {
h.insertSession(session, nil) h.insertSession(session, nil)
update := &wtdb.SessionStateUpdate{ update := &wtdb.SessionStateUpdate{
ID: id, ID: id,
SeqNum: 1, SeqNum: 1,
Hint: hint, // Use same hint to cause multiple matches Hint: hint, // Use same hint to cause multiple matches
EncryptedBlob: testBlob,
} }
h.insertUpdate(update, nil) h.insertUpdate(update, nil)
} }
@ -266,6 +288,10 @@ func testDeleteSession(h *towerDBHarness) {
session0 := &wtdb.SessionInfo{ session0 := &wtdb.SessionInfo{
ID: *id0, ID: *id0,
Policy: wtpolicy.Policy{ Policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
MaxUpdates: 3, MaxUpdates: 3,
}, },
RewardAddress: []byte{}, RewardAddress: []byte{},
@ -284,6 +310,10 @@ func testDeleteSession(h *towerDBHarness) {
session1 := &wtdb.SessionInfo{ session1 := &wtdb.SessionInfo{
ID: *id1, ID: *id1,
Policy: wtpolicy.Policy{ Policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
MaxUpdates: 3, MaxUpdates: 3,
}, },
RewardAddress: []byte{}, RewardAddress: []byte{},
@ -291,18 +321,18 @@ func testDeleteSession(h *towerDBHarness) {
h.insertSession(session1, nil) h.insertSession(session1, nil)
// Create and insert updates for both sessions that have the same hint. // Create and insert updates for both sessions that have the same hint.
var hint wtdb.BreachHint var hint blob.BreachHint
update0 := &wtdb.SessionStateUpdate{ update0 := &wtdb.SessionStateUpdate{
ID: *id0, ID: *id0,
Hint: hint, Hint: hint,
SeqNum: 1, SeqNum: 1,
EncryptedBlob: []byte{}, EncryptedBlob: testBlob,
} }
update1 := &wtdb.SessionStateUpdate{ update1 := &wtdb.SessionStateUpdate{
ID: *id1, ID: *id1,
Hint: hint, Hint: hint,
SeqNum: 1, SeqNum: 1,
EncryptedBlob: []byte{}, EncryptedBlob: testBlob,
} }
// Insert both updates should succeed. // Insert both updates should succeed.
@ -413,7 +443,7 @@ func runStateUpdateTest(test stateUpdateTest) func(*towerDBHarness) {
var stateUpdateNoSession = stateUpdateTest{ var stateUpdateNoSession = stateUpdateTest{
session: nil, session: nil,
updates: []*wtdb.SessionStateUpdate{ updates: []*wtdb.SessionStateUpdate{
{ID: *id(0), SeqNum: 1, LastApplied: 0}, updateFromInt(id(0), 1, 0),
}, },
updateErrs: []error{ updateErrs: []error{
wtdb.ErrSessionNotFound, wtdb.ErrSessionNotFound,
@ -424,6 +454,10 @@ var stateUpdateExhaustSession = stateUpdateTest{
session: &wtdb.SessionInfo{ session: &wtdb.SessionInfo{
ID: *id(0), ID: *id(0),
Policy: wtpolicy.Policy{ Policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
MaxUpdates: 3, MaxUpdates: 3,
}, },
RewardAddress: []byte{}, RewardAddress: []byte{},
@ -443,6 +477,10 @@ var stateUpdateSeqNumEqualLastApplied = stateUpdateTest{
session: &wtdb.SessionInfo{ session: &wtdb.SessionInfo{
ID: *id(0), ID: *id(0),
Policy: wtpolicy.Policy{ Policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
MaxUpdates: 3, MaxUpdates: 3,
}, },
RewardAddress: []byte{}, RewardAddress: []byte{},
@ -462,6 +500,10 @@ var stateUpdateSeqNumLTLastApplied = stateUpdateTest{
session: &wtdb.SessionInfo{ session: &wtdb.SessionInfo{
ID: *id(0), ID: *id(0),
Policy: wtpolicy.Policy{ Policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
MaxUpdates: 3, MaxUpdates: 3,
}, },
RewardAddress: []byte{}, RewardAddress: []byte{},
@ -480,6 +522,10 @@ var stateUpdateSeqNumZeroInvalid = stateUpdateTest{
session: &wtdb.SessionInfo{ session: &wtdb.SessionInfo{
ID: *id(0), ID: *id(0),
Policy: wtpolicy.Policy{ Policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
MaxUpdates: 3, MaxUpdates: 3,
}, },
RewardAddress: []byte{}, RewardAddress: []byte{},
@ -496,6 +542,10 @@ var stateUpdateSkipSeqNum = stateUpdateTest{
session: &wtdb.SessionInfo{ session: &wtdb.SessionInfo{
ID: *id(0), ID: *id(0),
Policy: wtpolicy.Policy{ Policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
MaxUpdates: 3, MaxUpdates: 3,
}, },
RewardAddress: []byte{}, RewardAddress: []byte{},
@ -512,6 +562,10 @@ var stateUpdateRevertSeqNum = stateUpdateTest{
session: &wtdb.SessionInfo{ session: &wtdb.SessionInfo{
ID: *id(0), ID: *id(0),
Policy: wtpolicy.Policy{ Policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
MaxUpdates: 3, MaxUpdates: 3,
}, },
RewardAddress: []byte{}, RewardAddress: []byte{},
@ -530,6 +584,10 @@ var stateUpdateRevertLastApplied = stateUpdateTest{
session: &wtdb.SessionInfo{ session: &wtdb.SessionInfo{
ID: *id(0), ID: *id(0),
Policy: wtpolicy.Policy{ Policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
MaxUpdates: 3, MaxUpdates: 3,
}, },
RewardAddress: []byte{}, RewardAddress: []byte{},
@ -545,6 +603,31 @@ var stateUpdateRevertLastApplied = stateUpdateTest{
}, },
} }
var stateUpdateInvalidBlobSize = stateUpdateTest{
session: &wtdb.SessionInfo{
ID: *id(0),
Policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
MaxUpdates: 3,
},
RewardAddress: []byte{},
},
updates: []*wtdb.SessionStateUpdate{
{
ID: *id(0),
SeqNum: 1,
LastApplied: 0,
EncryptedBlob: []byte{0x01, 0x02, 0x03}, // too $hort
},
},
updateErrs: []error{
wtdb.ErrInvalidBlobSize,
},
}
func TestTowerDB(t *testing.T) { func TestTowerDB(t *testing.T) {
dbs := []struct { dbs := []struct {
name string name string
@ -662,6 +745,10 @@ func TestTowerDB(t *testing.T) {
name: "state update revert last applied", name: "state update revert last applied",
run: runStateUpdateTest(stateUpdateRevertLastApplied), run: runStateUpdateTest(stateUpdateRevertLastApplied),
}, },
{
name: "invalid blob size",
run: runStateUpdateTest(stateUpdateInvalidBlobSize),
},
{ {
name: "multiple breach matches", name: "multiple breach matches",
run: testMultipleMatches, run: testMultipleMatches,
@ -705,16 +792,18 @@ func updateFromInt(id *wtdb.SessionID, i int,
lastApplied uint16) *wtdb.SessionStateUpdate { lastApplied uint16) *wtdb.SessionStateUpdate {
// Ensure the hint is unique. // Ensure the hint is unique.
var hint wtdb.BreachHint var hint blob.BreachHint
copy(hint[:4], id[:4]) copy(hint[:4], id[:4])
binary.BigEndian.PutUint16(hint[4:6], uint16(i)) binary.BigEndian.PutUint16(hint[4:6], uint16(i))
blobSize := blob.Size(blob.TypeAltruistCommit)
return &wtdb.SessionStateUpdate{ return &wtdb.SessionStateUpdate{
ID: *id, ID: *id,
Hint: hint, Hint: hint,
SeqNum: uint16(i), SeqNum: uint16(i),
LastApplied: lastApplied, LastApplied: lastApplied,
EncryptedBlob: []byte{byte(i)}, EncryptedBlob: bytes.Repeat([]byte{byte(i)}, blobSize),
} }
} }

@ -4,6 +4,7 @@ import (
"sync" "sync"
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/lightningnetwork/lnd/watchtower/wtdb"
) )
@ -12,14 +13,14 @@ type TowerDB struct {
mu sync.Mutex mu sync.Mutex
lastEpoch *chainntnfs.BlockEpoch lastEpoch *chainntnfs.BlockEpoch
sessions map[wtdb.SessionID]*wtdb.SessionInfo sessions map[wtdb.SessionID]*wtdb.SessionInfo
blobs map[wtdb.BreachHint]map[wtdb.SessionID]*wtdb.SessionStateUpdate blobs map[blob.BreachHint]map[wtdb.SessionID]*wtdb.SessionStateUpdate
} }
// NewTowerDB initializes a fresh mock TowerDB. // NewTowerDB initializes a fresh mock TowerDB.
func NewTowerDB() *TowerDB { func NewTowerDB() *TowerDB {
return &TowerDB{ return &TowerDB{
sessions: make(map[wtdb.SessionID]*wtdb.SessionInfo), sessions: make(map[wtdb.SessionID]*wtdb.SessionInfo),
blobs: make(map[wtdb.BreachHint]map[wtdb.SessionID]*wtdb.SessionStateUpdate), blobs: make(map[blob.BreachHint]map[wtdb.SessionID]*wtdb.SessionStateUpdate),
} }
} }
@ -36,6 +37,11 @@ func (db *TowerDB) InsertStateUpdate(update *wtdb.SessionStateUpdate) (uint16, e
return 0, wtdb.ErrSessionNotFound return 0, wtdb.ErrSessionNotFound
} }
// Assert that the blob is the correct size for the session's blob type.
if len(update.EncryptedBlob) != blob.Size(info.Policy.BlobType) {
return 0, wtdb.ErrInvalidBlobSize
}
err := info.AcceptUpdateSequence(update.SeqNum, update.LastApplied) err := info.AcceptUpdateSequence(update.SeqNum, update.LastApplied)
if err != nil { if err != nil {
return info.LastApplied, err return info.LastApplied, err
@ -75,6 +81,11 @@ func (db *TowerDB) InsertSessionInfo(info *wtdb.SessionInfo) error {
return wtdb.ErrSessionAlreadyExists return wtdb.ErrSessionAlreadyExists
} }
// Perform a quick sanity check on the session policy before accepting.
if err := info.Policy.Validate(); err != nil {
return err
}
db.sessions[info.ID] = info db.sessions[info.ID] = info
return nil return nil
@ -113,7 +124,7 @@ func (db *TowerDB) DeleteSession(target wtdb.SessionID) error {
// passed breachHints. More than one Match will be returned for a given hint if // passed breachHints. More than one Match will be returned for a given hint if
// they exist in the database. // they exist in the database.
func (db *TowerDB) QueryMatches( func (db *TowerDB) QueryMatches(
breachHints []wtdb.BreachHint) ([]wtdb.Match, error) { breachHints []blob.BreachHint) ([]wtdb.Match, error) {
db.mu.Lock() db.mu.Lock()
defer db.mu.Unlock() defer db.mu.Unlock()

@ -27,7 +27,11 @@ const (
// DefaultSweepFeeRate specifies the fee rate used to construct justice // DefaultSweepFeeRate specifies the fee rate used to construct justice
// transactions. The value is expressed in satoshis per kilo-weight. // transactions. The value is expressed in satoshis per kilo-weight.
DefaultSweepFeeRate = 3000 DefaultSweepFeeRate = lnwallet.SatPerKWeight(12000)
// MinSweepFeeRate is the minimum sweep fee rate a client may use in its
// policy, the current value is 4 sat/kw.
MinSweepFeeRate = lnwallet.SatPerKWeight(4000)
) )
var ( var (
@ -43,34 +47,42 @@ var (
// ErrCreatesDust signals that the session's policy would create a dust // ErrCreatesDust signals that the session's policy would create a dust
// output for the victim. // output for the victim.
ErrCreatesDust = errors.New("justice transaction creates dust at fee rate") ErrCreatesDust = errors.New("justice transaction creates dust at fee rate")
// ErrAltruistReward signals that the policy is invalid because it
// contains a non-zero RewardBase or RewardRate on an altruist policy.
ErrAltruistReward = errors.New("altruist policy has reward params")
// ErrNoMaxUpdates signals that the policy specified zero MaxUpdates.
ErrNoMaxUpdates = errors.New("max updates must be positive")
// ErrSweepFeeRateTooLow signals that the policy's fee rate is too low
// to get into the mempool during low congestion.
ErrSweepFeeRateTooLow = errors.New("sweep fee rate too low")
) )
// DefaultPolicy returns a Policy containing the default parameters that can be // DefaultPolicy returns a Policy containing the default parameters that can be
// used by clients or servers. // used by clients or servers.
func DefaultPolicy() Policy { func DefaultPolicy() Policy {
return Policy{ return Policy{
BlobType: blob.TypeDefault, TxPolicy: TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: DefaultSweepFeeRate,
},
MaxUpdates: DefaultMaxUpdates, MaxUpdates: DefaultMaxUpdates,
RewardRate: DefaultRewardRate,
SweepFeeRate: lnwallet.SatPerKWeight(
DefaultSweepFeeRate,
),
} }
} }
// Policy defines the negotiated parameters for a session between a client and // TxPolicy defines the negotiate parameters that determine the form of the
// server. The parameters specify the format of encrypted blobs sent to the // justice transaction for a given breached state. Thus, for any given revoked
// tower, the reward schedule for the tower, and the number of encrypted blobs a // state, an identical key will result in an identical justice transaction
// client can send in one session. // (barring signatures). The parameters specify the format of encrypted blobs
type Policy struct { // sent to the tower, the reward schedule for the tower, and the number of
// encrypted blobs a client can send in one session.
type TxPolicy struct {
// BlobType specifies the blob format that must be used by all updates sent // BlobType specifies the blob format that must be used by all updates sent
// under the session key used to negotiate this session. // under the session key used to negotiate this session.
BlobType blob.Type BlobType blob.Type
// MaxUpdates is the maximum number of updates the watchtower will honor
// for this session.
MaxUpdates uint16
// RewardBase is the fixed amount allocated to the tower when the // RewardBase is the fixed amount allocated to the tower when the
// policy's blob type specifies a reward for the tower. This is taken // policy's blob type specifies a reward for the tower. This is taken
// before adding the proportional reward. // before adding the proportional reward.
@ -88,6 +100,18 @@ type Policy struct {
SweepFeeRate lnwallet.SatPerKWeight SweepFeeRate lnwallet.SatPerKWeight
} }
// Policy defines the negotiated parameters for a session between a client and
// server. In addition to the TxPolicy that governs the shape of the justice
// transaction, the Policy also includes features which only affect the
// operation of the session.
type Policy struct {
TxPolicy
// MaxUpdates is the maximum number of updates the watchtower will honor
// for this session.
MaxUpdates uint16
}
// String returns a human-readable description of the current policy. // String returns a human-readable description of the current policy.
func (p Policy) String() string { func (p Policy) String() string {
return fmt.Sprintf("(blob-type=%b max-updates=%d reward-rate=%d "+ return fmt.Sprintf("(blob-type=%b max-updates=%d reward-rate=%d "+
@ -95,6 +119,31 @@ func (p Policy) String() string {
p.SweepFeeRate) p.SweepFeeRate)
} }
// Validate ensures that the policy satisfies some minimal correctness
// constraints.
func (p Policy) Validate() error {
// RewardBase and RewardRate should not be set if the policy doesn't
// have a reward.
if !p.BlobType.Has(blob.FlagReward) &&
(p.RewardBase != 0 || p.RewardRate != 0) {
return ErrAltruistReward
}
// MaxUpdates must be positive.
if p.MaxUpdates == 0 {
return ErrNoMaxUpdates
}
// SweepFeeRate must be sane enough to get in the mempool during low
// congestion.
if p.SweepFeeRate < MinSweepFeeRate {
return ErrSweepFeeRateTooLow
}
return nil
}
// ComputeAltruistOutput computes the lone output value of a justice transaction // ComputeAltruistOutput computes the lone output value of a justice transaction
// that pays no reward to the tower. The value is computed using the weight of // that pays no reward to the tower. The value is computed using the weight of
// of the justice transaction and subtracting an amount that satisfies the // of the justice transaction and subtracting an amount that satisfies the

@ -0,0 +1,93 @@
package wtpolicy_test
import (
"testing"
"github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
)
var validationTests = []struct {
name string
policy wtpolicy.Policy
expErr error
}{
{
name: "fail no maxupdates",
policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
},
},
expErr: wtpolicy.ErrNoMaxUpdates,
},
{
name: "fail altruist with reward base",
policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
RewardBase: 1,
},
},
expErr: wtpolicy.ErrAltruistReward,
},
{
name: "fail altruist with reward rate",
policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
RewardRate: 1,
},
},
expErr: wtpolicy.ErrAltruistReward,
},
{
name: "fail sweep fee rate too low",
policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
},
MaxUpdates: 1,
},
expErr: wtpolicy.ErrSweepFeeRateTooLow,
},
{
name: "minimal valid altruist policy",
policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.MinSweepFeeRate,
},
MaxUpdates: 1,
},
},
{
name: "valid altruist policy with default sweep rate",
policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
MaxUpdates: 1,
},
},
{
name: "valid default policy",
policy: wtpolicy.DefaultPolicy(),
},
}
// TestPolicyValidate asserts that the sanity checks for policies behave as
// expected.
func TestPolicyValidate(t *testing.T) {
for i := range validationTests {
test := validationTests[i]
t.Run(test.name, func(t *testing.T) {
err := test.policy.Validate()
if err != test.expErr {
t.Fatalf("validation error mismatch, "+
"want: %v, got: %v", test.expErr, err)
}
})
}
}

@ -54,6 +54,17 @@ func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID,
) )
} }
// If the request asks for a reward session and the tower has them
// disabled, we will reject the request.
if s.cfg.DisableReward && req.BlobType.Has(blob.FlagReward) {
log.Debugf("Rejecting CreateSession from %s, reward "+
"sessions disabled", id)
return s.replyCreateSession(
peer, id, wtwire.CreateSessionCodeRejectBlobType, 0,
nil,
)
}
// Now that we've established that this session does not exist in the // Now that we've established that this session does not exist in the
// database, retrieve the sweep address that will be given to the // database, retrieve the sweep address that will be given to the
// client. This address is to be included by the client when signing // client. This address is to be included by the client when signing
@ -89,11 +100,13 @@ func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID,
info := wtdb.SessionInfo{ info := wtdb.SessionInfo{
ID: *id, ID: *id,
Policy: wtpolicy.Policy{ Policy: wtpolicy.Policy{
BlobType: req.BlobType, TxPolicy: wtpolicy.TxPolicy{
MaxUpdates: req.MaxUpdates, BlobType: req.BlobType,
RewardBase: req.RewardBase, RewardBase: req.RewardBase,
RewardRate: req.RewardRate, RewardRate: req.RewardRate,
SweepFeeRate: req.SweepFeeRate, SweepFeeRate: req.SweepFeeRate,
},
MaxUpdates: req.MaxUpdates,
}, },
RewardAddress: rewardScript, RewardAddress: rewardScript,
} }

@ -63,6 +63,10 @@ type Config struct {
// NoAckUpdates causes the server to not acknowledge state updates, this // NoAckUpdates causes the server to not acknowledge state updates, this
// should only be used for testing. // should only be used for testing.
NoAckUpdates bool NoAckUpdates bool
// DisableReward causes the server to reject any session creation
// attempts that request rewards.
DisableReward bool
} }
// Server houses the state required to handle watchtower peers. It's primary job // Server houses the state required to handle watchtower peers. It's primary job
@ -92,7 +96,7 @@ type Server struct {
// sessions and send state updates. // sessions and send state updates.
func New(cfg *Config) (*Server, error) { func New(cfg *Config) (*Server, error) {
localInit := wtwire.NewInitMessage( localInit := wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional), lnwire.NewRawFeatureVector(wtwire.AltruistSessionsOptional),
cfg.ChainHash, cfg.ChainHash,
) )

@ -28,7 +28,7 @@ var (
testnetChainHash = *chaincfg.TestNet3Params.GenesisHash testnetChainHash = *chaincfg.TestNet3Params.GenesisHash
rewardType = (blob.FlagCommitOutputs | blob.FlagReward).Type() testBlob = make([]byte, blob.Size(blob.TypeAltruistCommit))
) )
// randPubKey generates a new secp keypair, and returns the public key. // randPubKey generates a new secp keypair, and returns the public key.
@ -168,11 +168,11 @@ var createSessionTests = []createSessionTestCase{
testnetChainHash, testnetChainHash,
), ),
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault, BlobType: blob.TypeAltruistCommit,
MaxUpdates: 1000, MaxUpdates: 1000,
RewardBase: 0, RewardBase: 0,
RewardRate: 0, RewardRate: 0,
SweepFeeRate: 1, SweepFeeRate: 10000,
}, },
expReply: &wtwire.CreateSessionReply{ expReply: &wtwire.CreateSessionReply{
Code: wtwire.CodeOK, Code: wtwire.CodeOK,
@ -190,11 +190,11 @@ var createSessionTests = []createSessionTestCase{
testnetChainHash, testnetChainHash,
), ),
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault, BlobType: blob.TypeAltruistCommit,
MaxUpdates: 1000, MaxUpdates: 1000,
RewardBase: 0, RewardBase: 0,
RewardRate: 0, RewardRate: 0,
SweepFeeRate: 1, SweepFeeRate: 10000,
}, },
expReply: &wtwire.CreateSessionReply{ expReply: &wtwire.CreateSessionReply{
Code: wtwire.CodeOK, Code: wtwire.CodeOK,
@ -214,11 +214,11 @@ var createSessionTests = []createSessionTestCase{
testnetChainHash, testnetChainHash,
), ),
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobType: rewardType, BlobType: blob.TypeRewardCommit,
MaxUpdates: 1000, MaxUpdates: 1000,
RewardBase: 0, RewardBase: 0,
RewardRate: 0, RewardRate: 0,
SweepFeeRate: 1, SweepFeeRate: 10000,
}, },
expReply: &wtwire.CreateSessionReply{ expReply: &wtwire.CreateSessionReply{
Code: wtwire.CodeOK, Code: wtwire.CodeOK,
@ -240,7 +240,7 @@ var createSessionTests = []createSessionTestCase{
MaxUpdates: 1000, MaxUpdates: 1000,
RewardBase: 0, RewardBase: 0,
RewardRate: 0, RewardRate: 0,
SweepFeeRate: 1, SweepFeeRate: 10000,
}, },
expReply: &wtwire.CreateSessionReply{ expReply: &wtwire.CreateSessionReply{
Code: wtwire.CreateSessionCodeRejectBlobType, Code: wtwire.CreateSessionCodeRejectBlobType,
@ -302,8 +302,9 @@ func testServerCreateSession(t *testing.T, i int, test createSessionTestCase) {
peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0) peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
connect(t, s, peer, test.initMsg, timeoutDuration) connect(t, s, peer, test.initMsg, timeoutDuration)
update := &wtwire.StateUpdate{ update := &wtwire.StateUpdate{
SeqNum: 1, SeqNum: 1,
IsComplete: 1, IsComplete: 1,
EncryptedBlob: testBlob,
} }
sendMsg(t, update, peer, timeoutDuration) sendMsg(t, update, peer, timeoutDuration)
@ -325,8 +326,8 @@ func testServerCreateSession(t *testing.T, i int, test createSessionTestCase) {
// Ensure that the server's reply matches our expected response for a // Ensure that the server's reply matches our expected response for a
// duplicate send. // duplicate send.
if !reflect.DeepEqual(reply, test.expDupReply) { if !reflect.DeepEqual(reply, test.expDupReply) {
t.Fatalf("[test %d] expected reply %v, got %d", t.Fatalf("[test %d] expected reply %v, got %v",
i, test.expReply, reply) i, test.expDupReply, reply)
} }
// Finally, check that the server tore down the connection. // Finally, check that the server tore down the connection.
@ -350,17 +351,17 @@ var stateUpdateTests = []stateUpdateTestCase{
testnetChainHash, testnetChainHash,
), ),
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault, BlobType: blob.TypeAltruistCommit,
MaxUpdates: 3, MaxUpdates: 3,
RewardBase: 0, RewardBase: 0,
RewardRate: 0, RewardRate: 0,
SweepFeeRate: 1, SweepFeeRate: 10000,
}, },
updates: []*wtwire.StateUpdate{ updates: []*wtwire.StateUpdate{
{SeqNum: 1, LastApplied: 0}, {SeqNum: 1, LastApplied: 0, EncryptedBlob: testBlob},
{SeqNum: 2, LastApplied: 1}, {SeqNum: 2, LastApplied: 1, EncryptedBlob: testBlob},
{SeqNum: 3, LastApplied: 2}, {SeqNum: 3, LastApplied: 2, EncryptedBlob: testBlob},
{SeqNum: 3, LastApplied: 3}, {SeqNum: 3, LastApplied: 3, EncryptedBlob: testBlob},
}, },
replies: []*wtwire.StateUpdateReply{ replies: []*wtwire.StateUpdateReply{
{Code: wtwire.CodeOK, LastApplied: 1}, {Code: wtwire.CodeOK, LastApplied: 1},
@ -380,14 +381,14 @@ var stateUpdateTests = []stateUpdateTestCase{
testnetChainHash, testnetChainHash,
), ),
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault, BlobType: blob.TypeAltruistCommit,
MaxUpdates: 4, MaxUpdates: 4,
RewardBase: 0, RewardBase: 0,
RewardRate: 0, RewardRate: 0,
SweepFeeRate: 1, SweepFeeRate: 10000,
}, },
updates: []*wtwire.StateUpdate{ updates: []*wtwire.StateUpdate{
{SeqNum: 2, LastApplied: 0}, {SeqNum: 2, LastApplied: 0, EncryptedBlob: testBlob},
}, },
replies: []*wtwire.StateUpdateReply{ replies: []*wtwire.StateUpdateReply{
{ {
@ -404,16 +405,16 @@ var stateUpdateTests = []stateUpdateTestCase{
testnetChainHash, testnetChainHash,
), ),
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault, BlobType: blob.TypeAltruistCommit,
MaxUpdates: 4, MaxUpdates: 4,
RewardBase: 0, RewardBase: 0,
RewardRate: 0, RewardRate: 0,
SweepFeeRate: 1, SweepFeeRate: 10000,
}, },
updates: []*wtwire.StateUpdate{ updates: []*wtwire.StateUpdate{
{SeqNum: 1, LastApplied: 0}, {SeqNum: 1, LastApplied: 0, EncryptedBlob: testBlob},
{SeqNum: 2, LastApplied: 0}, {SeqNum: 2, LastApplied: 0, EncryptedBlob: testBlob},
{SeqNum: 1, LastApplied: 0}, {SeqNum: 1, LastApplied: 0, EncryptedBlob: testBlob},
}, },
replies: []*wtwire.StateUpdateReply{ replies: []*wtwire.StateUpdateReply{
{Code: wtwire.CodeOK, LastApplied: 1}, {Code: wtwire.CodeOK, LastApplied: 1},
@ -432,17 +433,17 @@ var stateUpdateTests = []stateUpdateTestCase{
testnetChainHash, testnetChainHash,
), ),
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault, BlobType: blob.TypeAltruistCommit,
MaxUpdates: 4, MaxUpdates: 4,
RewardBase: 0, RewardBase: 0,
RewardRate: 0, RewardRate: 0,
SweepFeeRate: 1, SweepFeeRate: 10000,
}, },
updates: []*wtwire.StateUpdate{ updates: []*wtwire.StateUpdate{
{SeqNum: 1, LastApplied: 0}, {SeqNum: 1, LastApplied: 0, EncryptedBlob: testBlob},
{SeqNum: 2, LastApplied: 1}, {SeqNum: 2, LastApplied: 1, EncryptedBlob: testBlob},
{SeqNum: 3, LastApplied: 2}, {SeqNum: 3, LastApplied: 2, EncryptedBlob: testBlob},
{SeqNum: 4, LastApplied: 1}, {SeqNum: 4, LastApplied: 1, EncryptedBlob: testBlob},
}, },
replies: []*wtwire.StateUpdateReply{ replies: []*wtwire.StateUpdateReply{
{Code: wtwire.CodeOK, LastApplied: 1}, {Code: wtwire.CodeOK, LastApplied: 1},
@ -460,18 +461,18 @@ var stateUpdateTests = []stateUpdateTestCase{
testnetChainHash, testnetChainHash,
), ),
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault, BlobType: blob.TypeAltruistCommit,
MaxUpdates: 4, MaxUpdates: 4,
RewardBase: 0, RewardBase: 0,
RewardRate: 0, RewardRate: 0,
SweepFeeRate: 1, SweepFeeRate: 10000,
}, },
updates: []*wtwire.StateUpdate{ updates: []*wtwire.StateUpdate{
{SeqNum: 1, LastApplied: 0}, {SeqNum: 1, LastApplied: 0, EncryptedBlob: testBlob},
{SeqNum: 2, LastApplied: 1}, {SeqNum: 2, LastApplied: 1, EncryptedBlob: testBlob},
nil, // Wait for read timeout to drop conn, then reconnect. nil, // Wait for read timeout to drop conn, then reconnect.
{SeqNum: 3, LastApplied: 2}, {SeqNum: 3, LastApplied: 2, EncryptedBlob: testBlob},
{SeqNum: 4, LastApplied: 3}, {SeqNum: 4, LastApplied: 3, EncryptedBlob: testBlob},
}, },
replies: []*wtwire.StateUpdateReply{ replies: []*wtwire.StateUpdateReply{
{Code: wtwire.CodeOK, LastApplied: 1}, {Code: wtwire.CodeOK, LastApplied: 1},
@ -490,18 +491,18 @@ var stateUpdateTests = []stateUpdateTestCase{
testnetChainHash, testnetChainHash,
), ),
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault, BlobType: blob.TypeAltruistCommit,
MaxUpdates: 4, MaxUpdates: 4,
RewardBase: 0, RewardBase: 0,
RewardRate: 0, RewardRate: 0,
SweepFeeRate: 1, SweepFeeRate: 10000,
}, },
updates: []*wtwire.StateUpdate{ updates: []*wtwire.StateUpdate{
{SeqNum: 1, LastApplied: 0}, {SeqNum: 1, LastApplied: 0, EncryptedBlob: testBlob},
{SeqNum: 2, LastApplied: 0}, {SeqNum: 2, LastApplied: 0, EncryptedBlob: testBlob},
nil, // Wait for read timeout to drop conn, then reconnect. nil, // Wait for read timeout to drop conn, then reconnect.
{SeqNum: 3, LastApplied: 0}, {SeqNum: 3, LastApplied: 0, EncryptedBlob: testBlob},
{SeqNum: 4, LastApplied: 3}, {SeqNum: 4, LastApplied: 3, EncryptedBlob: testBlob},
}, },
replies: []*wtwire.StateUpdateReply{ replies: []*wtwire.StateUpdateReply{
{Code: wtwire.CodeOK, LastApplied: 1}, {Code: wtwire.CodeOK, LastApplied: 1},
@ -520,19 +521,19 @@ var stateUpdateTests = []stateUpdateTestCase{
testnetChainHash, testnetChainHash,
), ),
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault, BlobType: blob.TypeAltruistCommit,
MaxUpdates: 4, MaxUpdates: 4,
RewardBase: 0, RewardBase: 0,
RewardRate: 0, RewardRate: 0,
SweepFeeRate: 1, SweepFeeRate: 10000,
}, },
updates: []*wtwire.StateUpdate{ updates: []*wtwire.StateUpdate{
{SeqNum: 1, LastApplied: 0}, {SeqNum: 1, LastApplied: 0, EncryptedBlob: testBlob},
{SeqNum: 2, LastApplied: 0}, {SeqNum: 2, LastApplied: 0, EncryptedBlob: testBlob},
nil, // Wait for read timeout to drop conn, then reconnect. nil, // Wait for read timeout to drop conn, then reconnect.
{SeqNum: 2, LastApplied: 0}, {SeqNum: 2, LastApplied: 0, EncryptedBlob: testBlob},
{SeqNum: 3, LastApplied: 0}, {SeqNum: 3, LastApplied: 0, EncryptedBlob: testBlob},
{SeqNum: 4, LastApplied: 3}, {SeqNum: 4, LastApplied: 3, EncryptedBlob: testBlob},
}, },
replies: []*wtwire.StateUpdateReply{ replies: []*wtwire.StateUpdateReply{
{Code: wtwire.CodeOK, LastApplied: 1}, {Code: wtwire.CodeOK, LastApplied: 1},
@ -551,17 +552,17 @@ var stateUpdateTests = []stateUpdateTestCase{
testnetChainHash, testnetChainHash,
), ),
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault, BlobType: blob.TypeAltruistCommit,
MaxUpdates: 3, MaxUpdates: 3,
RewardBase: 0, RewardBase: 0,
RewardRate: 0, RewardRate: 0,
SweepFeeRate: 1, SweepFeeRate: 10000,
}, },
updates: []*wtwire.StateUpdate{ updates: []*wtwire.StateUpdate{
{SeqNum: 1, LastApplied: 0}, {SeqNum: 1, LastApplied: 0, EncryptedBlob: testBlob},
{SeqNum: 2, LastApplied: 1}, {SeqNum: 2, LastApplied: 1, EncryptedBlob: testBlob},
{SeqNum: 3, LastApplied: 2}, {SeqNum: 3, LastApplied: 2, EncryptedBlob: testBlob},
{SeqNum: 4, LastApplied: 3}, {SeqNum: 4, LastApplied: 3, EncryptedBlob: testBlob},
}, },
replies: []*wtwire.StateUpdateReply{ replies: []*wtwire.StateUpdateReply{
{Code: wtwire.CodeOK, LastApplied: 1}, {Code: wtwire.CodeOK, LastApplied: 1},
@ -581,14 +582,14 @@ var stateUpdateTests = []stateUpdateTestCase{
testnetChainHash, testnetChainHash,
), ),
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobType: blob.TypeDefault, BlobType: blob.TypeAltruistCommit,
MaxUpdates: 3, MaxUpdates: 3,
RewardBase: 0, RewardBase: 0,
RewardRate: 0, RewardRate: 0,
SweepFeeRate: 1, SweepFeeRate: 10000,
}, },
updates: []*wtwire.StateUpdate{ updates: []*wtwire.StateUpdate{
{SeqNum: 0, LastApplied: 0}, {SeqNum: 0, LastApplied: 0, EncryptedBlob: testBlob},
}, },
replies: []*wtwire.StateUpdateReply{ replies: []*wtwire.StateUpdateReply{
{ {
@ -718,11 +719,11 @@ func TestServerDeleteSession(t *testing.T) {
) )
createSession := &wtwire.CreateSession{ createSession := &wtwire.CreateSession{
BlobType: blob.TypeDefault, BlobType: blob.TypeAltruistCommit,
MaxUpdates: 1000, MaxUpdates: 1000,
RewardBase: 0, RewardBase: 0,
RewardRate: 0, RewardRate: 0,
SweepFeeRate: 1, SweepFeeRate: 10000,
} }
const timeoutDuration = 100 * time.Millisecond const timeoutDuration = 100 * time.Millisecond

@ -5,18 +5,18 @@ import "github.com/lightningnetwork/lnd/lnwire"
// FeatureNames holds a mapping from each feature bit understood by this // FeatureNames holds a mapping from each feature bit understood by this
// implementation to its common name. // implementation to its common name.
var FeatureNames = map[lnwire.FeatureBit]string{ var FeatureNames = map[lnwire.FeatureBit]string{
WtSessionsRequired: "wt-sessions", AltruistSessionsRequired: "altruist-sessions",
WtSessionsOptional: "wt-sessions", AltruistSessionsOptional: "altruist-sessions",
} }
const ( const (
// WtSessionsRequired specifies that the advertising node requires the // AltruistSessionsRequired specifies that the advertising node requires
// remote party to understand the protocol for creating and updating // the remote party to understand the protocol for creating and updating
// watchtower sessions. // watchtower sessions.
WtSessionsRequired lnwire.FeatureBit = 8 AltruistSessionsRequired lnwire.FeatureBit = 0
// WtSessionsOptional specifies that the advertising node can support // AltruistSessionsOptional specifies that the advertising node can
// a remote party who understand the protocol for creating and updating // support a remote party who understand the protocol for creating and
// watchtower sessions. // updating watchtower sessions.
WtSessionsOptional lnwire.FeatureBit = 9 AltruistSessionsOptional lnwire.FeatureBit = 1
) )

@ -26,37 +26,37 @@ type checkRemoteInitTest struct {
var checkRemoteInitTests = []checkRemoteInitTest{ var checkRemoteInitTests = []checkRemoteInitTest{
{ {
name: "same chain, local-optional remote-required", name: "same chain, local-optional remote-required",
lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional), lFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsOptional),
lHash: testnetChainHash, lHash: testnetChainHash,
rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired), rFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired),
rHash: testnetChainHash, rHash: testnetChainHash,
}, },
{ {
name: "same chain, local-required remote-optional", name: "same chain, local-required remote-optional",
lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired), lFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired),
lHash: testnetChainHash, lHash: testnetChainHash,
rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional), rFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsOptional),
rHash: testnetChainHash, rHash: testnetChainHash,
}, },
{ {
name: "different chain, local-optional remote-required", name: "different chain, local-optional remote-required",
lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional), lFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsOptional),
lHash: testnetChainHash, lHash: testnetChainHash,
rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired), rFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired),
rHash: mainnetChainHash, rHash: mainnetChainHash,
expErr: wtwire.NewErrUnknownChainHash(mainnetChainHash), expErr: wtwire.NewErrUnknownChainHash(mainnetChainHash),
}, },
{ {
name: "different chain, local-required remote-optional", name: "different chain, local-required remote-optional",
lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional), lFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsOptional),
lHash: testnetChainHash, lHash: testnetChainHash,
rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired), rFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired),
rHash: mainnetChainHash, rHash: mainnetChainHash,
expErr: wtwire.NewErrUnknownChainHash(mainnetChainHash), expErr: wtwire.NewErrUnknownChainHash(mainnetChainHash),
}, },
{ {
name: "same chain, remote-unknown-required", name: "same chain, remote-unknown-required",
lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional), lFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsOptional),
lHash: testnetChainHash, lHash: testnetChainHash,
rFeatures: lnwire.NewRawFeatureVector(lnwire.GossipQueriesRequired), rFeatures: lnwire.NewRawFeatureVector(lnwire.GossipQueriesRequired),
rHash: testnetChainHash, rHash: testnetChainHash,

@ -24,29 +24,29 @@ type MessageType uint16
// Watchtower protocol. // Watchtower protocol.
const ( const (
// MsgInit identifies an encoded Init message. // MsgInit identifies an encoded Init message.
MsgInit MessageType = 300 MsgInit MessageType = 600
// MsgError identifies an encoded Error message. // MsgError identifies an encoded Error message.
MsgError = 301 MsgError MessageType = 601
// MsgCreateSession identifies an encoded CreateSession message. // MsgCreateSession identifies an encoded CreateSession message.
MsgCreateSession MessageType = 302 MsgCreateSession MessageType = 602
// MsgCreateSessionReply identifies an encoded CreateSessionReply message. // MsgCreateSessionReply identifies an encoded CreateSessionReply message.
MsgCreateSessionReply MessageType = 303 MsgCreateSessionReply MessageType = 603
// MsgStateUpdate identifies an encoded StateUpdate message. // MsgStateUpdate identifies an encoded StateUpdate message.
MsgStateUpdate MessageType = 304 MsgStateUpdate MessageType = 604
// MsgStateUpdateReply identifies an encoded StateUpdateReply message. // MsgStateUpdateReply identifies an encoded StateUpdateReply message.
MsgStateUpdateReply MessageType = 305 MsgStateUpdateReply MessageType = 605
// MsgDeleteSession identifies an encoded DeleteSession message. // MsgDeleteSession identifies an encoded DeleteSession message.
MsgDeleteSession MessageType = 306 MsgDeleteSession MessageType = 606
// MsgDeleteSessionReply identifies an encoded DeleteSessionReply // MsgDeleteSessionReply identifies an encoded DeleteSessionReply
// message. // message.
MsgDeleteSessionReply MessageType = 307 MsgDeleteSessionReply MessageType = 607
) )
// String returns a human readable description of the message type. // String returns a human readable description of the message type.