diff --git a/config.go b/config.go index 3cce0d55..27a59988 100644 --- a/config.go +++ b/config.go @@ -32,6 +32,7 @@ import ( "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing" "github.com/lightningnetwork/lnd/tor" + "github.com/lightningnetwork/lnd/watchtower" ) const ( @@ -39,6 +40,7 @@ const ( defaultDataDirname = "data" defaultChainSubDirname = "chain" defaultGraphSubDirname = "graph" + defaultTowerSubDirname = "watchtower" defaultTLSCertFilename = "tls.cert" defaultTLSKeyFilename = "tls.key" defaultAdminMacFilename = "admin.macaroon" @@ -132,6 +134,8 @@ var ( defaultDataDir = filepath.Join(defaultLndDir, defaultDataDirname) defaultLogDir = filepath.Join(defaultLndDir, defaultLogDirname) + defaultTowerDir = filepath.Join(defaultDataDir, defaultTowerSubDirname) + defaultTLSCertPath = filepath.Join(defaultLndDir, defaultTLSCertFilename) defaultTLSKeyPath = filepath.Join(defaultLndDir, defaultTLSKeyFilename) @@ -315,6 +319,10 @@ type config struct { Caches *lncfg.Caches `group:"caches" namespace:"caches"` Prometheus lncfg.Prometheus `group:"prometheus" namespace:"prometheus"` + + WtClient *lncfg.WtClient `group:"wtclient" namespace:"wtclient"` + + Watchtower *lncfg.Watchtower `group:"watchtower" namespace:"watchtower"` } // loadConfig initializes and parses the config using a config file and command @@ -410,6 +418,9 @@ func loadConfig() (*config, error) { ChannelCacheSize: channeldb.DefaultChannelCacheSize, }, Prometheus: lncfg.DefaultPrometheus(), + Watchtower: &lncfg.Watchtower{ + TowerDir: defaultTowerDir, + }, } // Pre-parse the command line options to pick up an alternative config @@ -470,6 +481,14 @@ func loadConfig() (*config, error) { cfg.TLSCertPath = filepath.Join(lndDir, defaultTLSCertFilename) cfg.TLSKeyPath = filepath.Join(lndDir, defaultTLSKeyFilename) cfg.LogDir = filepath.Join(lndDir, defaultLogDirname) + + // If the watchtower's directory is set to the default, i.e. the + // user has not requested a different location, we'll move the + // location to be relative to the specified lnd directory. + if cfg.Watchtower.TowerDir == defaultTowerDir { + cfg.Watchtower.TowerDir = + filepath.Join(cfg.DataDir, defaultTowerSubDirname) + } } // Create the lnd directory if it doesn't already exist. @@ -506,6 +525,7 @@ func loadConfig() (*config, error) { cfg.BitcoindMode.Dir = cleanAndExpandPath(cfg.BitcoindMode.Dir) cfg.LitecoindMode.Dir = cleanAndExpandPath(cfg.LitecoindMode.Dir) cfg.Tor.PrivateKeyPath = cleanAndExpandPath(cfg.Tor.PrivateKeyPath) + cfg.Watchtower.TowerDir = cleanAndExpandPath(cfg.Watchtower.TowerDir) // Ensure that the user didn't attempt to specify negative values for // any of the autopilot params. @@ -1051,15 +1071,27 @@ func loadConfig() (*config, error) { "minbackoff") } - // Validate the subconfigs for workers and caches. + // Validate the subconfigs for workers, caches, and the tower client. err = lncfg.Validate( cfg.Workers, cfg.Caches, + cfg.WtClient, ) if err != nil { return nil, err } + // If the user provided private watchtower addresses, parse them to + // obtain the LN addresses. + if cfg.WtClient.IsActive() { + err := cfg.WtClient.ParsePrivateTowers( + watchtower.DefaultPeerPort, cfg.net.ResolveTCPAddr, + ) + if err != nil { + return nil, err + } + } + // Finally, ensure that the user's color is correctly formatted, // otherwise the server will not be able to start after the unlocking // the wallet. diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index 68350f62..4dff21a5 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -6,6 +6,7 @@ import ( "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" ) @@ -159,3 +160,20 @@ type ForwardingLog interface { // visualizations, etc. AddForwardingEvents([]channeldb.ForwardingEvent) error } + +// TowerClient is the primary interface used by the daemon to backup pre-signed +// justice transactions to watchtowers. +type TowerClient interface { + // RegisterChannel persistently initializes any channel-dependent + // parameters within the client. This should be called during link + // startup to ensure that the client is able to support the link during + // operation. + RegisterChannel(lnwire.ChannelID) error + + // BackupState initiates a request to back up a particular revoked + // state. If the method returns nil, the backup is guaranteed to be + // successful unless the tower is unavailable and client is force quit, + // or the justice transaction would create dust outputs when trying to + // abide by the negotiated policy. + BackupState(*lnwire.ChannelID, *lnwallet.BreachRetribution) error +} diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 6f3a8316..15c65ec4 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -229,10 +229,14 @@ type ChannelLinkConfig struct { // receiving node is persistent. UnsafeReplay bool - // MinFeeUpdateTimeout and MaxFeeUpdateTimeout represent the timeout - // interval bounds in which a link will propose to update its commitment - // fee rate. A random timeout will be selected between these values. + // MinFeeUpdateTimeout represents the minimum interval in which a link + // will propose to update its commitment fee rate. A random timeout will + // be selected between this and MaxFeeUpdateTimeout. MinFeeUpdateTimeout time.Duration + + // MaxFeeUpdateTimeout represents the maximum interval in which a link + // will propose to update its commitment fee rate. A random timeout will + // be selected between this and MinFeeUpdateTimeout. MaxFeeUpdateTimeout time.Duration // OutgoingCltvRejectDelta defines the number of blocks before expiry of @@ -240,6 +244,11 @@ type ChannelLinkConfig struct { // the outgoing broadcast delta, because in any case we don't want to // risk offering an htlc that triggers channel closure. OutgoingCltvRejectDelta uint32 + + // TowerClient is an optional engine that manages the signing, + // encrypting, and uploading of justice transactions to the daemon's + // configured set of watchtowers. + TowerClient TowerClient } // channelLink is the service which drives a channel's commitment update @@ -396,6 +405,15 @@ func (l *channelLink) Start() error { log.Infof("ChannelLink(%v) is starting", l) + // If the config supplied watchtower client, ensure the channel is + // registered before trying to use it during operation. + if l.cfg.TowerClient != nil { + err := l.cfg.TowerClient.RegisterChannel(l.ChanID()) + if err != nil { + return err + } + } + l.mailBox.ResetMessages() l.overflowQueue.Start() l.hodlQueue.Start() @@ -1786,6 +1804,28 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) { return } + // If we have a tower client, we'll proceed in backing up the + // state that was just revoked. + if l.cfg.TowerClient != nil { + state := l.channel.State() + breachInfo, err := lnwallet.NewBreachRetribution( + state, state.RemoteCommitment.CommitHeight-1, 0, + ) + if err != nil { + l.fail(LinkFailureError{code: ErrInternalError}, + "failed to load breach info: %v", err) + return + } + + chanID := l.ChanID() + err = l.cfg.TowerClient.BackupState(&chanID, breachInfo) + if err != nil { + l.fail(LinkFailureError{code: ErrInternalError}, + "unable to queue breach backup: %v", err) + return + } + } + l.processRemoteSettleFails(fwdPkg, settleFails) needUpdate := l.processRemoteAdds(fwdPkg, adds) diff --git a/lncfg/address.go b/lncfg/address.go index 66bca821..5a5a07fd 100644 --- a/lncfg/address.go +++ b/lncfg/address.go @@ -18,12 +18,14 @@ var ( loopBackAddrs = []string{"localhost", "127.0.0.1", "[::1]"} ) -type tcpResolver = func(network, addr string) (*net.TCPAddr, error) +// TCPResolver is a function signature that resolves an address on a given +// network. +type TCPResolver = func(network, addr string) (*net.TCPAddr, error) // NormalizeAddresses returns a new slice with all the passed addresses // normalized with the given default port and all duplicates removed. func NormalizeAddresses(addrs []string, defaultPort string, - tcpResolver tcpResolver) ([]net.Addr, error) { + tcpResolver TCPResolver) ([]net.Addr, error) { result := make([]net.Addr, 0, len(addrs)) seen := map[string]struct{}{} @@ -120,7 +122,7 @@ func IsUnix(addr net.Addr) bool { // connections. We accept a custom function to resolve any TCP addresses so // that caller is able control exactly how resolution is performed. func ParseAddressString(strAddress string, defaultPort string, - tcpResolver tcpResolver) (net.Addr, error) { + tcpResolver TCPResolver) (net.Addr, error) { var parsedNetwork, parsedAddr string @@ -188,9 +190,9 @@ func ParseAddressString(strAddress string, defaultPort string, // 33-byte, compressed public key that lies on the secp256k1 curve. The // may be any address supported by ParseAddressString. If no port is specified, // the defaultPort will be used. Any tcp addresses that need resolving will be -// resolved using the custom tcpResolver. +// resolved using the custom TCPResolver. func ParseLNAddressString(strAddress string, defaultPort string, - tcpResolver tcpResolver) (*lnwire.NetAddress, error) { + tcpResolver TCPResolver) (*lnwire.NetAddress, error) { // Split the address string around the @ sign. parts := strings.Split(strAddress, "@") diff --git a/lncfg/watchtower.go b/lncfg/watchtower.go new file mode 100644 index 00000000..d03d3695 --- /dev/null +++ b/lncfg/watchtower.go @@ -0,0 +1,13 @@ +package lncfg + +import "github.com/lightningnetwork/lnd/watchtower" + +// Watchtower holds the daemon specific configuration parameters for running a +// watchtower that shares resources with the daemon. +type Watchtower struct { + Active bool `long:"active" description:"If the watchtower should be active or not"` + + TowerDir string `long:"towerdir" description:"Directory of the watchtower.db"` + + watchtower.Conf +} diff --git a/lncfg/wtclient.go b/lncfg/wtclient.go new file mode 100644 index 00000000..a4c424fd --- /dev/null +++ b/lncfg/wtclient.go @@ -0,0 +1,65 @@ +package lncfg + +import ( + "fmt" + "strconv" + + "github.com/lightningnetwork/lnd/lnwire" +) + +// WtClient holds the configuration options for the daemon's watchtower client. +type WtClient struct { + // PrivateTowerURIs specifies the lightning URIs of the towers the + // watchtower client should send new backups to. + PrivateTowerURIs []string `long:"private-tower-uris" description:"Specifies the URIs of private watchtowers to use in backing up revoked states. URIs must be of the form @. Only 1 URI is supported at this time, if none are provided the tower will not be enabled."` + + // PrivateTowers is the list of towers parsed from the URIs provided in + // PrivateTowerURIs. + PrivateTowers []*lnwire.NetAddress + + // SweepFeeRate specifies the fee rate in sat/byte to be used when + // constructing justice transactions sent to the tower. + SweepFeeRate uint64 `long:"sweep-fee-rate" description:"Specifies the fee rate in sat/byte to be used when constructing justice transactions sent to the watchtower."` +} + +// Validate asserts that at most 1 private watchtower is requested. +// +// NOTE: Part of the Validator interface. +func (c *WtClient) Validate() error { + if len(c.PrivateTowerURIs) > 1 { + return fmt.Errorf("at most 1 private watchtower is supported, "+ + "found %d", len(c.PrivateTowerURIs)) + } + + return nil +} + +// IsActive returns true if the watchtower client should be active. +func (c *WtClient) IsActive() bool { + return len(c.PrivateTowerURIs) > 0 +} + +// ParsePrivateTowers parses any private tower URIs held PrivateTowerURIs. The +// value of port should be the default port to use when a URI does not have one. +func (c *WtClient) ParsePrivateTowers(port int, resolver TCPResolver) error { + towers := make([]*lnwire.NetAddress, 0, len(c.PrivateTowerURIs)) + for _, uri := range c.PrivateTowerURIs { + addr, err := ParseLNAddressString( + uri, strconv.Itoa(port), resolver, + ) + if err != nil { + return fmt.Errorf("unable to parse private "+ + "watchtower address: %v", err) + } + + towers = append(towers, addr) + } + + c.PrivateTowers = towers + + return nil +} + +// Compile-time constraint to ensure WtClient implements the Validator +// interface. +var _ Validator = (*WtClient)(nil) diff --git a/lnd.go b/lnd.go index 830884d5..28b3d511 100644 --- a/lnd.go +++ b/lnd.go @@ -36,6 +36,7 @@ import ( "google.golang.org/grpc/credentials" "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcutil" "github.com/btcsuite/btcwallet/wallet" proxy "github.com/grpc-ecosystem/grpc-gateway/runtime" "github.com/lightninglabs/neutrino" @@ -51,6 +52,8 @@ import ( "github.com/lightningnetwork/lnd/macaroons" "github.com/lightningnetwork/lnd/signal" "github.com/lightningnetwork/lnd/walletunlocker" + "github.com/lightningnetwork/lnd/watchtower" + "github.com/lightningnetwork/lnd/watchtower/wtdb" ) const ( @@ -313,11 +316,65 @@ func Main() error { "is proxying over Tor as well", cfg.Tor.StreamIsolation) } + // If the watchtower client should be active, open the client database. + // This is done here so that Close always executes when lndMain returns. + var towerClientDB *wtdb.ClientDB + if cfg.WtClient.IsActive() { + var err error + towerClientDB, err = wtdb.OpenClientDB(graphDir) + if err != nil { + ltndLog.Errorf("Unable to open watchtower client db: %v", err) + } + defer towerClientDB.Close() + } + + var tower *watchtower.Standalone + if cfg.Watchtower.Active { + // Segment the watchtower directory by chain and network. + towerDBDir := filepath.Join( + cfg.Watchtower.TowerDir, + registeredChains.PrimaryChain().String(), + normalizeNetwork(activeNetParams.Name), + ) + + towerDB, err := wtdb.OpenTowerDB(towerDBDir) + if err != nil { + ltndLog.Errorf("Unable to open watchtower db: %v", err) + return err + } + defer towerDB.Close() + + wtConfig, err := cfg.Watchtower.Apply(&watchtower.Config{ + BlockFetcher: activeChainControl.chainIO, + DB: towerDB, + EpochRegistrar: activeChainControl.chainNotifier, + Net: cfg.net, + NewAddress: func() (btcutil.Address, error) { + return activeChainControl.wallet.NewAddress( + lnwallet.WitnessPubKey, false, + ) + }, + NodePrivKey: idPrivKey, + PublishTx: activeChainControl.wallet.PublishTransaction, + ChainHash: *activeNetParams.GenesisHash, + }, lncfg.NormalizeAddresses) + if err != nil { + ltndLog.Errorf("Unable to configure watchtower: %v", err) + return err + } + + tower, err = watchtower.New(wtConfig) + if err != nil { + ltndLog.Errorf("Unable to create watchtower: %v", err) + return err + } + } + // Set up the core server which will listen for incoming peer // connections. server, err := newServer( - cfg.Listeners, chanDB, activeChainControl, idPrivKey, - walletInitParams.ChansToRestore, + cfg.Listeners, chanDB, towerClientDB, activeChainControl, + idPrivKey, walletInitParams.ChansToRestore, ) if err != nil { srvrLog.Errorf("unable to create server: %v\n", err) @@ -418,6 +475,14 @@ func Main() error { } } + if cfg.Watchtower.Active { + if err := tower.Start(); err != nil { + ltndLog.Errorf("Unable to start watchtower: %v", err) + return err + } + defer tower.Stop() + } + // Wait for shutdown signal from either a graceful server stop or from // the interrupt handler. <-signal.ShutdownChannel() diff --git a/lntest/itest/lnd_test.go b/lntest/itest/lnd_test.go index a06bb8df..b34fe433 100644 --- a/lntest/itest/lnd_test.go +++ b/lntest/itest/lnd_test.go @@ -7606,6 +7606,353 @@ func testRevokedCloseRetributionRemoteHodl(net *lntest.NetworkHarness, assertNodeNumChannels(t, dave, 0) } +// testRevokedCloseRetributionAltruistWatchtower establishes a channel between +// Carol and Dave, where Carol is using a third node Willy as her watchtower. +// After sending some payments, Dave reverts his state and force closes to +// trigger a breach. Carol is kept offline throughout the process and the test +// asserts that Willy responds by broadcasting the justice transaction on +// Carol's behalf sweeping her funds without a reward. +func testRevokedCloseRetributionAltruistWatchtower(net *lntest.NetworkHarness, + t *harnessTest) { + + ctxb := context.Background() + const ( + chanAmt = lnd.MaxBtcFundingAmount + paymentAmt = 10000 + numInvoices = 6 + ) + + // Since we'd like to test some multi-hop failure scenarios, we'll + // introduce another node into our test network: Carol. + carol, err := net.NewNode("Carol", []string{ + "--debughtlc", "--hodl.exit-settle", + }) + if err != nil { + t.Fatalf("unable to create new nodes: %v", err) + } + defer shutdownAndAssert(net, t, carol) + + // Willy the watchtower will protect Dave from Carol's breach. He will + // remain online in order to punish Carol on Dave's behalf, since the + // breach will happen while Dave is offline. + willy, err := net.NewNode("Willy", []string{"--watchtower.active"}) + if err != nil { + t.Fatalf("unable to create new nodes: %v", err) + } + defer shutdownAndAssert(net, t, willy) + + ctxt, _ := context.WithTimeout(ctxb, defaultTimeout) + willyInfo, err := willy.GetInfo(ctxt, &lnrpc.GetInfoRequest{}) + if err != nil { + t.Fatalf("unable to getinfo from willy: %v", err) + } + + willyAddr := willyInfo.Uris[0] + parts := strings.Split(willyAddr, ":") + willyTowerAddr := parts[0] + + // Dave will be the breached party. We set --nolisten to ensure Carol + // won't be able to connect to him and trigger the channel data + // protection logic automatically. + dave, err := net.NewNode("Dave", []string{ + "--nolisten", + "--wtclient.private-tower-uris=" + willyTowerAddr, + }) + if err != nil { + t.Fatalf("unable to create new node: %v", err) + } + defer shutdownAndAssert(net, t, dave) + + // We must let Dave have an open channel before she can send a node + // announcement, so we open a channel with Carol, + if err := net.ConnectNodes(ctxb, dave, carol); err != nil { + t.Fatalf("unable to connect dave to carol: %v", err) + } + + // Before we make a channel, we'll load up Dave with some coins sent + // directly from the miner. + err = net.SendCoins(ctxb, btcutil.SatoshiPerBitcoin, dave) + if err != nil { + t.Fatalf("unable to send coins to dave: %v", err) + } + + // In order to test Dave's response to an uncooperative channel + // closure by Carol, we'll first open up a channel between them with a + // 0.5 BTC value. + ctxt, _ = context.WithTimeout(ctxb, channelOpenTimeout) + chanPoint := openChannelAndAssert( + ctxt, t, net, dave, carol, + lntest.OpenChannelParams{ + Amt: 3 * (chanAmt / 4), + PushAmt: chanAmt / 4, + }, + ) + + // With the channel open, we'll create a few invoices for Carol that + // Dave will pay to in order to advance the state of the channel. + carolPayReqs, _, _, err := createPayReqs( + carol, paymentAmt, numInvoices, + ) + if err != nil { + t.Fatalf("unable to create pay reqs: %v", err) + } + + // Wait for Dave to receive the channel edge from the funding manager. + ctxt, _ = context.WithTimeout(ctxb, defaultTimeout) + err = dave.WaitForNetworkChannelOpen(ctxt, chanPoint) + if err != nil { + t.Fatalf("dave didn't see the dave->carol channel before "+ + "timeout: %v", err) + } + + // Next query for Carol's channel state, as we sent 0 payments, Carol + // should still see her balance as the push amount, which is 1/4 of the + // capacity. + ctxt, _ = context.WithTimeout(ctxb, defaultTimeout) + carolChan, err := getChanInfo(ctxt, carol) + if err != nil { + t.Fatalf("unable to get carol's channel info: %v", err) + } + if carolChan.LocalBalance != int64(chanAmt/4) { + t.Fatalf("carol's balance is incorrect, got %v, expected %v", + carolChan.LocalBalance, chanAmt/4) + } + + // Grab Carol's current commitment height (update number), we'll later + // revert her to this state after additional updates to force him to + // broadcast this soon to be revoked state. + carolStateNumPreCopy := carolChan.NumUpdates + + // Create a temporary file to house Carol's database state at this + // particular point in history. + carolTempDbPath, err := ioutil.TempDir("", "carol-past-state") + if err != nil { + t.Fatalf("unable to create temp db folder: %v", err) + } + carolTempDbFile := filepath.Join(carolTempDbPath, "channel.db") + defer os.Remove(carolTempDbPath) + + // With the temporary file created, copy Carol's current state into the + // temporary file we created above. Later after more updates, we'll + // restore this state. + if err := lntest.CopyFile(carolTempDbFile, carol.DBPath()); err != nil { + t.Fatalf("unable to copy database files: %v", err) + } + + // Finally, send payments from Dave to Carol, consuming Carol's remaining + // payment hashes. + err = completePaymentRequests(ctxb, dave, carolPayReqs, false) + if err != nil { + t.Fatalf("unable to send payments: %v", err) + } + + ctxt, _ = context.WithTimeout(ctxb, defaultTimeout) + daveBalReq := &lnrpc.WalletBalanceRequest{} + daveBalResp, err := dave.WalletBalance(ctxt, daveBalReq) + if err != nil { + t.Fatalf("unable to get dave's balance: %v", err) + } + + davePreSweepBalance := daveBalResp.ConfirmedBalance + + // Shutdown Dave to simulate going offline for an extended period of + // time. Once he's not watching, Carol will try to breach the channel. + restart, err := net.SuspendNode(dave) + if err != nil { + t.Fatalf("unable to suspend Dave: %v", err) + } + + // Now we shutdown Carol, copying over the his temporary database state + // which has the *prior* channel state over his current most up to date + // state. With this, we essentially force Carol to travel back in time + // within the channel's history. + if err = net.RestartNode(carol, func() error { + return os.Rename(carolTempDbFile, carol.DBPath()) + }); err != nil { + t.Fatalf("unable to restart node: %v", err) + } + + // Now query for Carol's channel state, it should show that he's at a + // state number in the past, not the *latest* state. + ctxt, _ = context.WithTimeout(ctxb, defaultTimeout) + carolChan, err = getChanInfo(ctxt, carol) + if err != nil { + t.Fatalf("unable to get carol chan info: %v", err) + } + if carolChan.NumUpdates != carolStateNumPreCopy { + t.Fatalf("db copy failed: %v", carolChan.NumUpdates) + } + + // TODO(conner): add hook for backup completion + time.Sleep(3 * time.Second) + + // Now force Carol to execute a *force* channel closure by unilaterally + // broadcasting his current channel state. This is actually the + // commitment transaction of a prior *revoked* state, so he'll soon + // feel the wrath of Dave's retribution. + closeUpdates, closeTxId, err := net.CloseChannel( + ctxb, carol, chanPoint, true, + ) + if err != nil { + t.Fatalf("unable to close channel: %v", err) + } + + // Query the mempool for the breaching closing transaction, this should + // be broadcast by Carol when she force closes the channel above. + txid, err := waitForTxInMempool(net.Miner.Node, minerMempoolTimeout) + if err != nil { + t.Fatalf("unable to find Carol's force close tx in mempool: %v", + err) + } + if *txid != *closeTxId { + t.Fatalf("expected closeTx(%v) in mempool, instead found %v", + closeTxId, txid) + } + + // Finally, generate a single block, wait for the final close status + // update, then ensure that the closing transaction was included in the + // block. + block := mineBlocks(t, net, 1, 1)[0] + + ctxt, _ = context.WithTimeout(ctxb, defaultTimeout) + breachTXID, err := net.WaitForChannelClose(ctxt, closeUpdates) + if err != nil { + t.Fatalf("error while waiting for channel close: %v", err) + } + assertTxInBlock(t, block, breachTXID) + + // Query the mempool for Dave's justice transaction, this should be + // broadcast as Carol's contract breaching transaction gets confirmed + // above. + justiceTXID, err := waitForTxInMempool(net.Miner.Node, minerMempoolTimeout) + if err != nil { + t.Fatalf("unable to find Dave's justice tx in mempool: %v", + err) + } + time.Sleep(100 * time.Millisecond) + + // Query for the mempool transaction found above. Then assert that all + // the inputs of this transaction are spending outputs generated by + // Carol's breach transaction above. + justiceTx, err := net.Miner.Node.GetRawTransaction(justiceTXID) + if err != nil { + t.Fatalf("unable to query for justice tx: %v", err) + } + for _, txIn := range justiceTx.MsgTx().TxIn { + if !bytes.Equal(txIn.PreviousOutPoint.Hash[:], breachTXID[:]) { + t.Fatalf("justice tx not spending commitment utxo "+ + "instead is: %v", txIn.PreviousOutPoint) + } + } + + ctxt, _ = context.WithTimeout(ctxb, defaultTimeout) + willyBalReq := &lnrpc.WalletBalanceRequest{} + willyBalResp, err := willy.WalletBalance(ctxt, willyBalReq) + if err != nil { + t.Fatalf("unable to get willy's balance: %v", err) + } + + if willyBalResp.ConfirmedBalance != 0 { + t.Fatalf("willy should have 0 balance before mining "+ + "justice transaction, instead has %d", + willyBalResp.ConfirmedBalance) + } + + // Now mine a block, this transaction should include Dave's justice + // transaction which was just accepted into the mempool. + block = mineBlocks(t, net, 1, 1)[0] + + // The block should have exactly *two* transactions, one of which is + // the justice transaction. + if len(block.Transactions) != 2 { + t.Fatalf("transaction wasn't mined") + } + justiceSha := block.Transactions[1].TxHash() + if !bytes.Equal(justiceTx.Hash()[:], justiceSha[:]) { + t.Fatalf("justice tx wasn't mined") + } + + // Ensure that Willy doesn't get any funds, as he is acting as an + // altruist watchtower. + var predErr error + err = lntest.WaitInvariant(func() bool { + ctxt, _ = context.WithTimeout(ctxb, defaultTimeout) + willyBalReq := &lnrpc.WalletBalanceRequest{} + willyBalResp, err := willy.WalletBalance(ctxt, willyBalReq) + if err != nil { + t.Fatalf("unable to get willy's balance: %v", err) + } + + if willyBalResp.ConfirmedBalance != 0 { + predErr = fmt.Errorf("Expected Willy to have no funds "+ + "after justice transaction was mined, found %v", + willyBalResp) + return false + } + + return true + }, time.Second*5) + if err != nil { + t.Fatalf("%v", predErr) + } + + // Restart Dave, who will still think his channel with Carol is open. + // We should him to detect the breach, but realize that the funds have + // then been swept to his wallet by Willy. + err = restart() + if err != nil { + t.Fatalf("unable to restart dave: %v", err) + } + + err = lntest.WaitPredicate(func() bool { + ctxt, _ = context.WithTimeout(ctxb, defaultTimeout) + daveBalReq := &lnrpc.ChannelBalanceRequest{} + daveBalResp, err := dave.ChannelBalance(ctxt, daveBalReq) + if err != nil { + t.Fatalf("unable to get dave's balance: %v", err) + } + + if daveBalResp.Balance != 0 { + predErr = fmt.Errorf("Dave should end up with zero "+ + "channel balance, instead has %d", + daveBalResp.Balance) + return false + } + + return true + }, time.Second*15) + if err != nil { + t.Fatalf("%v", predErr) + } + + assertNumPendingChannels(t, dave, 0, 0) + + err = lntest.WaitPredicate(func() bool { + ctxt, _ := context.WithTimeout(ctxb, defaultTimeout) + daveBalReq := &lnrpc.WalletBalanceRequest{} + daveBalResp, err := dave.WalletBalance(ctxt, daveBalReq) + if err != nil { + t.Fatalf("unable to get dave's balance: %v", err) + } + + if daveBalResp.ConfirmedBalance <= davePreSweepBalance { + predErr = fmt.Errorf("Dave should have more than %d "+ + "after sweep, instead has %d", + davePreSweepBalance, + daveBalResp.ConfirmedBalance) + return false + } + + return true + }, time.Second*15) + if err != nil { + t.Fatalf("%v", predErr) + } + + // Dave should have no open channels. + assertNodeNumChannels(t, dave, 0) +} + // assertNumPendingChannels checks that a PendingChannels response from the // node reports the expected number of pending channels. func assertNumPendingChannels(t *harnessTest, node *lntest.HarnessNode, @@ -13840,6 +14187,10 @@ var testsCases = []*testCase{ name: "revoked uncooperative close retribution remote hodl", test: testRevokedCloseRetributionRemoteHodl, }, + { + name: "revoked uncooperative close retribution altruist watchtower", + test: testRevokedCloseRetributionAltruistWatchtower, + }, { name: "data loss protection", test: testDataLossProtection, diff --git a/log.go b/log.go index c84fd7b2..b79fa047 100644 --- a/log.go +++ b/log.go @@ -34,6 +34,7 @@ import ( "github.com/lightningnetwork/lnd/signal" "github.com/lightningnetwork/lnd/sweep" "github.com/lightningnetwork/lnd/watchtower" + "github.com/lightningnetwork/lnd/watchtower/wtclient" ) // Loggers per subsystem. A single backend logger is created and all subsystem @@ -87,6 +88,7 @@ var ( chnfLog = build.NewSubLogger("CHNF", backendLog.Logger) chbuLog = build.NewSubLogger("CHBU", backendLog.Logger) promLog = build.NewSubLogger("PROM", backendLog.Logger) + wtclLog = build.NewSubLogger("WTCL", backendLog.Logger) ) // Initialize package-global logger variables. @@ -115,6 +117,7 @@ func init() { channelnotifier.UseLogger(chnfLog) chanbackup.UseLogger(chbuLog) monitoring.UseLogger(promLog) + wtclient.UseLogger(wtclLog) addSubLogger(routerrpc.Subsystem, routerrpc.UseLogger) } @@ -159,6 +162,7 @@ var subsystemLoggers = map[string]btclog.Logger{ "CHNF": chnfLog, "CHBU": chbuLog, "PROM": promLog, + "WTCL": wtclLog, } // initLogRotator initializes the logging rotator to write logs to logFile and diff --git a/peer.go b/peer.go index 0b36f63a..7cbd824e 100644 --- a/peer.go +++ b/peer.go @@ -594,6 +594,7 @@ func (p *peer) addLink(chanPoint *wire.OutPoint, MinFeeUpdateTimeout: htlcswitch.DefaultMinLinkFeeUpdateTimeout, MaxFeeUpdateTimeout: htlcswitch.DefaultMaxLinkFeeUpdateTimeout, OutgoingCltvRejectDelta: p.outgoingCltvRejectDelta, + TowerClient: p.server.towerClient, } link := htlcswitch.NewChannelLink(linkCfg, lnChan) diff --git a/rpcserver.go b/rpcserver.go index 3dcd5868..71498f7e 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -2173,6 +2173,9 @@ func (r *rpcServer) ChannelBalance(ctx context.Context, pendingOpenBalance += channel.LocalCommitment.LocalBalance.ToSatoshis() } + rpcsLog.Debugf("[channelbalance] balance=%v pending-open=%v", + balance, pendingOpenBalance) + return &lnrpc.ChannelBalanceResponse{ Balance: int64(balance), PendingOpenBalance: int64(pendingOpenBalance), diff --git a/server.go b/server.go index dd91d5bd..b0c2350f 100644 --- a/server.go +++ b/server.go @@ -20,6 +20,7 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/connmgr" + "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/coreos/bbolt" @@ -50,6 +51,9 @@ import ( "github.com/lightningnetwork/lnd/ticker" "github.com/lightningnetwork/lnd/tor" "github.com/lightningnetwork/lnd/walletunlocker" + "github.com/lightningnetwork/lnd/watchtower/wtclient" + "github.com/lightningnetwork/lnd/watchtower/wtdb" + "github.com/lightningnetwork/lnd/watchtower/wtpolicy" "github.com/lightningnetwork/lnd/zpay32" ) @@ -204,6 +208,8 @@ type server struct { sphinx *htlcswitch.OnionProcessor + towerClient wtclient.Client + connMgr *connmgr.ConnManager sigPool *lnwallet.SigPool @@ -282,7 +288,8 @@ func noiseDial(idPriv *btcec.PrivateKey) func(net.Addr) (net.Conn, error) { // newServer creates a new instance of the server which is to listen using the // passed listener address. -func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl, +func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, + towerClientDB *wtdb.ClientDB, cc *chainControl, privKey *btcec.PrivateKey, chansToRestore walletunlocker.ChannelsToRecover) (*server, error) { @@ -723,10 +730,8 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl, } s.sweeper = sweep.New(&sweep.UtxoSweeperConfig{ - FeeEstimator: cc.feeEstimator, - GenSweepScript: func() ([]byte, error) { - return newSweepPkScript(cc.wallet) - }, + FeeEstimator: cc.feeEstimator, + GenSweepScript: newSweepPkScriptGen(cc.wallet), Signer: cc.wallet.Cfg.Signer, PublishTransaction: cc.wallet.PublishTransaction, NewBatchTimer: func() <-chan time.Time { @@ -769,10 +774,8 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl, ChainHash: *activeNetParams.GenesisHash, IncomingBroadcastDelta: DefaultIncomingBroadcastDelta, OutgoingBroadcastDelta: DefaultOutgoingBroadcastDelta, - NewSweepAddr: func() ([]byte, error) { - return newSweepPkScript(cc.wallet) - }, - PublishTx: cc.wallet.PublishTransaction, + NewSweepAddr: newSweepPkScriptGen(cc.wallet), + PublishTx: cc.wallet.PublishTransaction, DeliverResolutionMsg: func(msgs ...contractcourt.ResolutionMsg) error { for _, msg := range msgs { err := s.htlcSwitch.ProcessContractResolution(msg) @@ -845,12 +848,10 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl, }, chanDB) s.breachArbiter = newBreachArbiter(&BreachConfig{ - CloseLink: closeLink, - DB: chanDB, - Estimator: s.cc.feeEstimator, - GenSweepScript: func() ([]byte, error) { - return newSweepPkScript(cc.wallet) - }, + CloseLink: closeLink, + DB: chanDB, + Estimator: s.cc.feeEstimator, + GenSweepScript: newSweepPkScriptGen(cc.wallet), Notifier: cc.chainNotifier, PublishTransaction: cc.wallet.PublishTransaction, ContractBreaches: contractBreaches, @@ -1056,6 +1057,41 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl, return nil, err } + if cfg.WtClient.IsActive() { + policy := wtpolicy.DefaultPolicy() + + if cfg.WtClient.SweepFeeRate != 0 { + // We expose the sweep fee rate in sat/byte, but the + // tower protocol operations on sat/kw. + sweepRateSatPerByte := lnwallet.SatPerKVByte( + 1000 * cfg.WtClient.SweepFeeRate, + ) + policy.SweepFeeRate = sweepRateSatPerByte.FeePerKWeight() + } + + if err := policy.Validate(); err != nil { + return nil, err + } + + s.towerClient, err = wtclient.New(&wtclient.Config{ + Signer: cc.wallet.Cfg.Signer, + NewAddress: newSweepPkScriptGen(cc.wallet), + SecretKeyRing: s.cc.keyRing, + Dial: cfg.net.Dial, + AuthDial: wtclient.AuthDial, + DB: towerClientDB, + Policy: wtpolicy.DefaultPolicy(), + PrivateTower: cfg.WtClient.PrivateTowers[0], + ChainHash: *activeNetParams.GenesisHash, + MinBackoff: 10 * time.Second, + MaxBackoff: 5 * time.Minute, + ForceQuitDelay: wtclient.DefaultForceQuitDelay, + }) + if err != nil { + return nil, err + } + } + // Create the connection manager which will be responsible for // maintaining persistent outbound connections and also accepting new // incoming connections @@ -1128,6 +1164,12 @@ func (s *server) Start() error { startErr = err return } + if s.towerClient != nil { + if err := s.towerClient.Start(); err != nil { + startErr = err + return + } + } if err := s.htlcSwitch.Start(); err != nil { startErr = err return @@ -1290,6 +1332,14 @@ func (s *server) Stop() error { s.DisconnectPeer(peer.addr.IdentityKey) } + // Now that all connections have been torn down, stop the tower + // client which will reliably flush all queued states to the + // tower. If this is halted for any reason, the force quit timer + // will kick in and abort to allow this method to return. + if s.towerClient != nil { + s.towerClient.Stop() + } + // Wait for all lingering goroutines to quit. s.wg.Wait() @@ -3180,3 +3230,20 @@ func (s *server) applyChannelUpdate(update *lnwire.ChannelUpdate) error { return ErrServerShuttingDown } } + +// newSweepPkScriptGen creates closure that generates a new public key script +// which should be used to sweep any funds into the on-chain wallet. +// Specifically, the script generated is a version 0, pay-to-witness-pubkey-hash +// (p2wkh) output. +func newSweepPkScriptGen( + wallet lnwallet.WalletController) func() ([]byte, error) { + + return func() ([]byte, error) { + sweepAddr, err := wallet.NewAddress(lnwallet.WitnessPubKey, false) + if err != nil { + return nil, err + } + + return txscript.PayToAddrScript(sweepAddr) + } +} diff --git a/utxonursery.go b/utxonursery.go index bfdc075a..0e4501b6 100644 --- a/utxonursery.go +++ b/utxonursery.go @@ -8,7 +8,6 @@ import ( "sync" "sync/atomic" - "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/davecgh/go-spew/spew" @@ -1231,19 +1230,6 @@ func (u *utxoNursery) closeAndRemoveIfMature(chanPoint *wire.OutPoint) error { return nil } -// newSweepPkScript creates a new public key script which should be used to -// sweep any time-locked, or contested channel funds into the wallet. -// Specifically, the script generated is a version 0, pay-to-witness-pubkey-hash -// (p2wkh) output. -func newSweepPkScript(wallet lnwallet.WalletController) ([]byte, error) { - sweepAddr, err := wallet.NewAddress(lnwallet.WitnessPubKey, false) - if err != nil { - return nil, err - } - - return txscript.PayToAddrScript(sweepAddr) -} - // babyOutput represents a two-stage CSV locked output, and is used to track // htlc outputs through incubation. The first stage requires broadcasting a // presigned timeout txn that spends from the CLTV locked output on the diff --git a/watchtower/blob/derivation.go b/watchtower/blob/derivation.go new file mode 100644 index 00000000..d5427942 --- /dev/null +++ b/watchtower/blob/derivation.go @@ -0,0 +1,70 @@ +package blob + +import ( + "crypto/sha256" + "encoding/hex" + + "github.com/btcsuite/btcd/chaincfg/chainhash" +) + +// BreachHintSize is the length of the identifier used to detect remote +// commitment broadcasts. +const BreachHintSize = 16 + +// BreachHint is the first 16-bytes of SHA256(txid), which is used to identify +// the breach transaction. +type BreachHint [BreachHintSize]byte + +// NewBreachHintFromHash creates a breach hint from a transaction ID. +func NewBreachHintFromHash(hash *chainhash.Hash) BreachHint { + h := sha256.New() + h.Write(hash[:]) + + var hint BreachHint + copy(hint[:], h.Sum(nil)) + return hint +} + +// String returns a hex encoding of the breach hint. +func (h BreachHint) String() string { + return hex.EncodeToString(h[:]) +} + +// BreachKey is computed as SHA256(txid || txid), which produces the key for +// decrypting a client's encrypted blobs. +type BreachKey [KeySize]byte + +// NewBreachKeyFromHash creates a breach key from a transaction ID. +func NewBreachKeyFromHash(hash *chainhash.Hash) BreachKey { + h := sha256.New() + h.Write(hash[:]) + h.Write(hash[:]) + + var key BreachKey + copy(key[:], h.Sum(nil)) + return key +} + +// String returns a hex encoding of the breach key. +func (k BreachKey) String() string { + return hex.EncodeToString(k[:]) +} + +// NewBreachHintAndKeyFromHash derives a BreachHint and BreachKey from a given +// txid in a single pass. The hint and key are computed as: +// hint = SHA256(txid) +// key = SHA256(txid || txid) +func NewBreachHintAndKeyFromHash(hash *chainhash.Hash) (BreachHint, BreachKey) { + var ( + hint BreachHint + key BreachKey + ) + + h := sha256.New() + h.Write(hash[:]) + copy(hint[:], h.Sum(nil)) + h.Write(hash[:]) + copy(key[:], h.Sum(nil)) + + return hint, key +} diff --git a/watchtower/blob/justice_kit.go b/watchtower/blob/justice_kit.go index 46abaa00..dd0213fb 100644 --- a/watchtower/blob/justice_kit.go +++ b/watchtower/blob/justice_kit.go @@ -75,11 +75,6 @@ var ( "ciphertext is too small for chacha20poly1305", ) - // ErrKeySize signals that the provided key is improperly sized. - ErrKeySize = fmt.Errorf( - "chacha20poly1305 key size must be %d bytes", KeySize, - ) - // ErrNoCommitToRemoteOutput is returned when trying to retrieve the // commit to-remote output from the blob, though none exists. ErrNoCommitToRemoteOutput = errors.New( @@ -223,12 +218,7 @@ func (b *JusticeKit) CommitToRemoteWitnessStack() ([][]byte, error) { // // NOTE: It is the caller's responsibility to ensure that this method is only // called once for a given (nonce, key) pair. -func (b *JusticeKit) Encrypt(key []byte, blobType Type) ([]byte, error) { - // Fail if the nonce is not 32-bytes. - if len(key) != KeySize { - return nil, ErrKeySize - } - +func (b *JusticeKit) Encrypt(key BreachKey, blobType Type) ([]byte, error) { // Encode the plaintext using the provided version, to obtain the // plaintext bytes. var ptxtBuf bytes.Buffer @@ -238,7 +228,7 @@ func (b *JusticeKit) Encrypt(key []byte, blobType Type) ([]byte, error) { } // Create a new chacha20poly1305 cipher, using a 32-byte key. - cipher, err := chacha20poly1305.NewX(key) + cipher, err := chacha20poly1305.NewX(key[:]) if err != nil { return nil, err } @@ -264,21 +254,17 @@ func (b *JusticeKit) Encrypt(key []byte, blobType Type) ([]byte, error) { // Decrypt unenciphers a blob of justice by decrypting the ciphertext using // chacha20poly1305 with the chosen (nonce, key) pair. The internal plaintext is // then deserialized using the given encoding version. -func Decrypt(key, ciphertext []byte, blobType Type) (*JusticeKit, error) { - switch { +func Decrypt(key BreachKey, ciphertext []byte, + blobType Type) (*JusticeKit, error) { // Fail if the blob's overall length is less than required for the nonce // and expansion factor. - case len(ciphertext) < NonceSize+CiphertextExpansion: + if len(ciphertext) < NonceSize+CiphertextExpansion { return nil, ErrCiphertextTooSmall - - // Fail if the key is not 32-bytes. - case len(key) != KeySize: - return nil, ErrKeySize } // Create a new chacha20poly1305 cipher, using a 32-byte key. - cipher, err := chacha20poly1305.NewX(key) + cipher, err := chacha20poly1305.NewX(key[:]) if err != nil { return nil, err } diff --git a/watchtower/blob/justice_kit_test.go b/watchtower/blob/justice_kit_test.go index 977289e0..34fd726a 100644 --- a/watchtower/blob/justice_kit_test.go +++ b/watchtower/blob/justice_kit_test.go @@ -56,15 +56,11 @@ type descriptorTest struct { decErr error } -var rewardAndCommitType = blob.TypeFromFlags( - blob.FlagReward, blob.FlagCommitOutputs, -) - var descriptorTests = []descriptorTest{ { name: "to-local only", - encVersion: blob.TypeDefault, - decVersion: blob.TypeDefault, + encVersion: blob.TypeAltruistCommit, + decVersion: blob.TypeAltruistCommit, sweepAddr: makeAddr(22), revPubKey: makePubKey(0), delayPubKey: makePubKey(1), @@ -73,8 +69,8 @@ var descriptorTests = []descriptorTest{ }, { name: "to-local and p2wkh", - encVersion: rewardAndCommitType, - decVersion: rewardAndCommitType, + encVersion: blob.TypeRewardCommit, + decVersion: blob.TypeRewardCommit, sweepAddr: makeAddr(22), revPubKey: makePubKey(0), delayPubKey: makePubKey(1), @@ -87,7 +83,7 @@ var descriptorTests = []descriptorTest{ { name: "unknown encrypt version", encVersion: 0, - decVersion: blob.TypeDefault, + decVersion: blob.TypeAltruistCommit, sweepAddr: makeAddr(34), revPubKey: makePubKey(0), delayPubKey: makePubKey(1), @@ -97,7 +93,7 @@ var descriptorTests = []descriptorTest{ }, { name: "unknown decrypt version", - encVersion: blob.TypeDefault, + encVersion: blob.TypeAltruistCommit, decVersion: 0, sweepAddr: makeAddr(34), revPubKey: makePubKey(0), @@ -108,8 +104,8 @@ var descriptorTests = []descriptorTest{ }, { name: "sweep addr length zero", - encVersion: blob.TypeDefault, - decVersion: blob.TypeDefault, + encVersion: blob.TypeAltruistCommit, + decVersion: blob.TypeAltruistCommit, sweepAddr: makeAddr(0), revPubKey: makePubKey(0), delayPubKey: makePubKey(1), @@ -118,8 +114,8 @@ var descriptorTests = []descriptorTest{ }, { name: "sweep addr max size", - encVersion: blob.TypeDefault, - decVersion: blob.TypeDefault, + encVersion: blob.TypeAltruistCommit, + decVersion: blob.TypeAltruistCommit, sweepAddr: makeAddr(blob.MaxSweepAddrSize), revPubKey: makePubKey(0), delayPubKey: makePubKey(1), @@ -128,8 +124,8 @@ var descriptorTests = []descriptorTest{ }, { name: "sweep addr too long", - encVersion: blob.TypeDefault, - decVersion: blob.TypeDefault, + encVersion: blob.TypeAltruistCommit, + decVersion: blob.TypeAltruistCommit, sweepAddr: makeAddr(blob.MaxSweepAddrSize + 1), revPubKey: makePubKey(0), delayPubKey: makePubKey(1), @@ -165,8 +161,8 @@ func testBlobJusticeKitEncryptDecrypt(t *testing.T, test descriptorTest) { // Generate a random encryption key for the blob. The key is // sized at 32 byte, as in practice we will be using the remote // party's commitment txid as the key. - key := make([]byte, blob.KeySize) - _, err := io.ReadFull(rand.Reader, key) + var key blob.BreachKey + _, err := rand.Read(key[:]) if err != nil { t.Fatalf("unable to generate blob encryption key: %v", err) } diff --git a/watchtower/blob/type.go b/watchtower/blob/type.go index 3a6e03de..923614db 100644 --- a/watchtower/blob/type.go +++ b/watchtower/blob/type.go @@ -45,9 +45,15 @@ func (f Flag) String() string { // of the blob itself. type Type uint16 -// TypeDefault sweeps only commitment outputs to a sweep address controlled by -// the user, and does not give the tower a reward. -const TypeDefault = Type(FlagCommitOutputs) +const ( + // TypeAltruistCommit sweeps only commitment outputs to a sweep address + // controlled by the user, and does not give the tower a reward. + TypeAltruistCommit = Type(FlagCommitOutputs) + + // TypeRewardCommit sweeps only commitment outputs to a sweep address + // controlled by the user, and pays a negotiated reward to the tower. + TypeRewardCommit = Type(FlagCommitOutputs | FlagReward) +) // Has returns true if the Type has the passed flag enabled. func (t Type) Has(flag Flag) bool { @@ -114,8 +120,8 @@ func (t Type) String() string { // supportedTypes is the set of all configurations known to be supported by the // package. var supportedTypes = map[Type]struct{}{ - FlagCommitOutputs.Type(): {}, - (FlagCommitOutputs | FlagReward).Type(): {}, + TypeAltruistCommit: {}, + TypeRewardCommit: {}, } // IsSupportedType returns true if the given type is supported by the package. diff --git a/watchtower/blob/type_test.go b/watchtower/blob/type_test.go index f5e0d85f..95b01c01 100644 --- a/watchtower/blob/type_test.go +++ b/watchtower/blob/type_test.go @@ -17,12 +17,12 @@ type typeStringTest struct { var typeStringTests = []typeStringTest{ { name: "commit no-reward", - typ: blob.TypeDefault, + typ: blob.TypeAltruistCommit, expStr: "[FlagCommitOutputs|No-FlagReward]", }, { name: "commit reward", - typ: (blob.FlagCommitOutputs | blob.FlagReward).Type(), + typ: blob.TypeRewardCommit, expStr: "[FlagCommitOutputs|FlagReward]", }, { @@ -75,7 +75,7 @@ var typeFromFlagTests = []typeFromFlagTest{ { name: "multiple flags", flags: []blob.Flag{blob.FlagReward, blob.FlagCommitOutputs}, - expType: blob.Type(blob.FlagReward | blob.FlagCommitOutputs), + expType: blob.TypeRewardCommit, }, { name: "duplicate flag", @@ -119,8 +119,8 @@ func TestTypeFromFlags(t *testing.T) { // blob.DefaultType returns true. func TestSupportedTypes(t *testing.T) { // Assert that the package's default type is supported. - if !blob.IsSupportedType(blob.TypeDefault) { - t.Fatalf("default type %s is not supported", blob.TypeDefault) + if !blob.IsSupportedType(blob.TypeAltruistCommit) { + t.Fatalf("default type %s is not supported", blob.TypeAltruistCommit) } // Assert that all claimed supported types are actually supported. diff --git a/watchtower/conf.go b/watchtower/conf.go index 9383d817..21210995 100644 --- a/watchtower/conf.go +++ b/watchtower/conf.go @@ -1,14 +1,68 @@ -// +build !experimental - package watchtower -// Conf specifies the watchtower options that be configured from the command -// line or configuration file. In non-experimental builds, we disallow such -// configuration. -type Conf struct{} +import ( + "time" +) -// Apply returns an error signaling that the Conf could not be applied in -// non-experimental builds. -func (c *Conf) Apply(cfg *Config) (*Config, error) { - return nil, ErrNonExperimentalConf +// Conf specifies the watchtower options that can be configured from the command +// line or configuration file. +type Conf struct { + // RawListeners configures the watchtower's listening ports/interfaces. + RawListeners []string `long:"listen" description:"Add interfaces/ports to listen for peer connections"` + + // ReadTimeout specifies the duration the tower will wait when trying to + // read a message from a client before hanging up. + ReadTimeout time.Duration `long:"readtimeout" description:"Duration the watchtower server will wait for messages to be received before hanging up on clients"` + + // WriteTimeout specifies the duration the tower will wait when trying + // to write a message from a client before hanging up. + WriteTimeout time.Duration `long:"writetimeout" description:"Duration the watchtower server will wait for messages to be written before hanging up on client connections"` +} + +// Apply completes the passed Config struct by applying any parsed Conf options. +// If the corresponding values parsed by Conf are already set in the Config, +// those fields will be not be modified. +func (c *Conf) Apply(cfg *Config, + normalizer AddressNormalizer) (*Config, error) { + + // Set the Config's listening addresses if they are empty. + if cfg.ListenAddrs == nil { + // Without a network, we will be unable to resolve the listening + // addresses. + if cfg.Net == nil { + return nil, ErrNoNetwork + } + + // If no addresses are specified by the Config, we will resort + // to the default peer port. + if len(c.RawListeners) == 0 { + addr := DefaultPeerPortStr + c.RawListeners = append(c.RawListeners, addr) + } + + // Normalize the raw listening addresses so that they can be + // used by the brontide listener. + var err error + cfg.ListenAddrs, err = normalizer( + c.RawListeners, DefaultPeerPortStr, + cfg.Net.ResolveTCPAddr, + ) + if err != nil { + return nil, err + } + } + + // If the Config has no read timeout, we will use the parsed Conf + // value. + if cfg.ReadTimeout == 0 && c.ReadTimeout != 0 { + cfg.ReadTimeout = c.ReadTimeout + } + + // If the Config has no write timeout, we will use the parsed Conf + // value. + if cfg.WriteTimeout == 0 && c.WriteTimeout != 0 { + cfg.WriteTimeout = c.WriteTimeout + } + + return cfg, nil } diff --git a/watchtower/conf_experimental.go b/watchtower/conf_experimental.go deleted file mode 100644 index 4b474998..00000000 --- a/watchtower/conf_experimental.go +++ /dev/null @@ -1,65 +0,0 @@ -// +build experimental - -package watchtower - -import ( - "time" - - "github.com/lightningnetwork/lnd/lncfg" -) - -// Conf specifies the watchtower options that can be configured from the command -// line or configuration file. -type Conf struct { - RawListeners []string `long:"listen" description:"Add interfaces/ports to listen for peer connections"` - - ReadTimeout time.Duration `long:"readtimeout" description:"Duration the watchtower server will wait for messages to be received before hanging up on clients"` - - WriteTimeout time.Duration `long:"writetimeout" description:"Duration the watchtower server will wait for messages to be written before hanging up on client connections"` -} - -// Apply completes the passed Config struct by applying any parsed Conf options. -// If the corresponding values parsed by Conf are already set in the Config, -// those fields will be not be modified. -func (c *Conf) Apply(cfg *Config) (*Config, error) { - // Set the Config's listening addresses if they are empty. - if cfg.ListenAddrs == nil { - // Without a network, we will be unable to resolve the listening - // addresses. - if cfg.Net == nil { - return nil, ErrNoNetwork - } - - // If no addresses are specified by the Config, we will resort - // to the default peer port. - if len(c.RawListeners) == 0 { - addr := DefaultPeerPortStr - c.RawListeners = append(c.RawListeners, addr) - } - - // Normalize the raw listening addresses so that they can be - // used by the brontide listener. - var err error - cfg.ListenAddrs, err = lncfg.NormalizeAddresses( - c.RawListeners, DefaultPeerPortStr, - cfg.Net.ResolveTCPAddr, - ) - if err != nil { - return nil, err - } - } - - // If the Config has no read timeout, we will use the parsed Conf - // value. - if cfg.ReadTimeout == 0 && c.ReadTimeout != 0 { - cfg.ReadTimeout = c.ReadTimeout - } - - // If the Config has no write timeout, we will use the parsed Conf - // value. - if cfg.WriteTimeout == 0 && c.WriteTimeout != 0 { - cfg.WriteTimeout = c.WriteTimeout - } - - return cfg, nil -} diff --git a/watchtower/errors.go b/watchtower/errors.go index e8682d4f..1ddeeb3f 100644 --- a/watchtower/errors.go +++ b/watchtower/errors.go @@ -7,11 +7,6 @@ var ( // rendering the tower unable to receive client requests. ErrNoListeners = errors.New("no listening ports were specified") - // ErrNonExperimentalConf signals that an attempt to apply a - // non-experimental Conf to a Config was detected. - ErrNonExperimentalConf = errors.New("cannot use watchtower in non-" + - "experimental builds") - // ErrNoNetwork signals that no tor.Net is provided in the Config, which // prevents resolution of listening addresses. ErrNoNetwork = errors.New("no network specified, must be tor or clearnet") diff --git a/watchtower/interface.go b/watchtower/interface.go index 59b1b848..cc7c08b4 100644 --- a/watchtower/interface.go +++ b/watchtower/interface.go @@ -1,6 +1,8 @@ package watchtower import ( + "net" + "github.com/lightningnetwork/lnd/watchtower/lookout" "github.com/lightningnetwork/lnd/watchtower/wtserver" ) @@ -12,3 +14,8 @@ type DB interface { lookout.DB wtserver.DB } + +// AddressNormalizer is a function signature that allows the tower to resolve +// TCP addresses on clear or onion networks. +type AddressNormalizer func(addrs []string, defaultPort string, + resolver func(string, string) (*net.TCPAddr, error)) ([]net.Addr, error) diff --git a/watchtower/lookout/interface.go b/watchtower/lookout/interface.go index 6a750c9f..a84f22b0 100644 --- a/watchtower/lookout/interface.go +++ b/watchtower/lookout/interface.go @@ -4,6 +4,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/wtdb" ) @@ -37,7 +38,7 @@ type DB interface { // QueryMatches searches its database for any state updates matching the // provided breach hints. If any matches are found, they will be // returned along with encrypted blobs so that justice can be exacted. - QueryMatches([]wtdb.BreachHint) ([]wtdb.Match, error) + QueryMatches([]blob.BreachHint) ([]wtdb.Match, error) // SetLookoutTip writes the best epoch for which the watchtower has // queried for breach hints. diff --git a/watchtower/lookout/justice_descriptor_test.go b/watchtower/lookout/justice_descriptor_test.go index c4c4d35e..1f880c23 100644 --- a/watchtower/lookout/justice_descriptor_test.go +++ b/watchtower/lookout/justice_descriptor_test.go @@ -156,9 +156,11 @@ func testJusticeDescriptor(t *testing.T, blobType blob.Type) { // parameters that should be used in constructing the justice // transaction. policy := wtpolicy.Policy{ - BlobType: blobType, - SweepFeeRate: 2000, - RewardRate: 900000, + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blobType, + SweepFeeRate: 2000, + RewardRate: 900000, + }, } sessionInfo := &wtdb.SessionInfo{ Policy: policy, diff --git a/watchtower/lookout/lookout.go b/watchtower/lookout/lookout.go index 556db2e8..fc7badde 100644 --- a/watchtower/lookout/lookout.go +++ b/watchtower/lookout/lookout.go @@ -7,7 +7,6 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/watchtower/blob" - "github.com/lightningnetwork/lnd/watchtower/wtdb" ) // Config houses the Lookout's required resources to properly fulfill it's duty, @@ -159,11 +158,11 @@ func (l *Lookout) processEpoch(epoch *chainntnfs.BlockEpoch, // Iterate over the transactions contained in the block, deriving a // breach hint for each transaction and constructing an index mapping // the hint back to it's original transaction. - hintToTx := make(map[wtdb.BreachHint]*wire.MsgTx, numTxnsInBlock) - txHints := make([]wtdb.BreachHint, 0, numTxnsInBlock) + hintToTx := make(map[blob.BreachHint]*wire.MsgTx, numTxnsInBlock) + txHints := make([]blob.BreachHint, 0, numTxnsInBlock) for _, tx := range block.Transactions { hash := tx.TxHash() - hint := wtdb.NewBreachHintFromHash(&hash) + hint := blob.NewBreachHintFromHash(&hash) txHints = append(txHints, hint) hintToTx[hint] = tx @@ -203,13 +202,16 @@ func (l *Lookout) processEpoch(epoch *chainntnfs.BlockEpoch, // The decryption key for the state update should be the full // txid of the breaching commitment transaction. - commitTxID := commitTx.TxHash() + // The decryption key for the state update should be computed as + // key = SHA256(txid). + breachTxID := commitTx.TxHash() + breachKey := blob.NewBreachKeyFromHash(&breachTxID) // Now, decrypt the blob of justice that we received in the // state update. This will contain all information required to // sweep the breached commitment outputs. justiceKit, err := blob.Decrypt( - commitTxID[:], match.EncryptedBlob, + breachKey, match.EncryptedBlob, match.SessionInfo.Policy.BlobType, ) if err != nil { diff --git a/watchtower/lookout/lookout_test.go b/watchtower/lookout/lookout_test.go index fb70a961..893b3d7c 100644 --- a/watchtower/lookout/lookout_test.go +++ b/watchtower/lookout/lookout_test.go @@ -96,7 +96,10 @@ func TestLookoutBreachMatching(t *testing.T) { sessionInfo1 := &wtdb.SessionInfo{ ID: makeArray33(1), Policy: wtpolicy.Policy{ - BlobType: rewardAndCommitType, + TxPolicy: wtpolicy.TxPolicy{ + BlobType: rewardAndCommitType, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, MaxUpdates: 10, }, RewardAddress: makeAddrSlice(22), @@ -104,7 +107,10 @@ func TestLookoutBreachMatching(t *testing.T) { sessionInfo2 := &wtdb.SessionInfo{ ID: makeArray33(2), Policy: wtpolicy.Policy{ - BlobType: rewardAndCommitType, + TxPolicy: wtpolicy.TxPolicy{ + BlobType: rewardAndCommitType, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, MaxUpdates: 10, }, RewardAddress: makeAddrSlice(22), @@ -148,14 +154,17 @@ func TestLookoutBreachMatching(t *testing.T) { CommitToLocalSig: makeArray64(2), } - // Encrypt the first justice kit under the txid of the first txn. - encBlob1, err := blob1.Encrypt(hash1[:], blob.FlagCommitOutputs.Type()) + key1 := blob.NewBreachKeyFromHash(&hash1) + key2 := blob.NewBreachKeyFromHash(&hash2) + + // Encrypt the first justice kit under breach key one. + encBlob1, err := blob1.Encrypt(key1, blob.FlagCommitOutputs.Type()) if err != nil { t.Fatalf("unable to encrypt sweep detail 1: %v", err) } - // Encrypt the second justice kit under the txid of the second txn. - encBlob2, err := blob2.Encrypt(hash2[:], blob.FlagCommitOutputs.Type()) + // Encrypt the second justice kit under breach key two. + encBlob2, err := blob2.Encrypt(key2, blob.FlagCommitOutputs.Type()) if err != nil { t.Fatalf("unable to encrypt sweep detail 2: %v", err) } @@ -163,13 +172,13 @@ func TestLookoutBreachMatching(t *testing.T) { // Add both state updates to the tower's database. txBlob1 := &wtdb.SessionStateUpdate{ ID: makeArray33(1), - Hint: wtdb.NewBreachHintFromHash(&hash1), + Hint: blob.NewBreachHintFromHash(&hash1), EncryptedBlob: encBlob1, SeqNum: 1, } txBlob2 := &wtdb.SessionStateUpdate{ ID: makeArray33(2), - Hint: wtdb.NewBreachHintFromHash(&hash2), + Hint: blob.NewBreachHintFromHash(&hash2), EncryptedBlob: encBlob2, SeqNum: 1, } diff --git a/watchtower/standalone.go b/watchtower/standalone.go index f55f44cb..83bf49dc 100644 --- a/watchtower/standalone.go +++ b/watchtower/standalone.go @@ -78,13 +78,14 @@ func New(cfg *Config) (*Standalone, error) { // Initialize the server with its required resources. server, err := wtserver.New(&wtserver.Config{ - ChainHash: cfg.ChainHash, - DB: cfg.DB, - NodePrivKey: cfg.NodePrivKey, - Listeners: listeners, - ReadTimeout: cfg.ReadTimeout, - WriteTimeout: cfg.WriteTimeout, - NewAddress: cfg.NewAddress, + ChainHash: cfg.ChainHash, + DB: cfg.DB, + NodePrivKey: cfg.NodePrivKey, + Listeners: listeners, + ReadTimeout: cfg.ReadTimeout, + WriteTimeout: cfg.WriteTimeout, + NewAddress: cfg.NewAddress, + DisableReward: true, }) if err != nil { return nil, err diff --git a/watchtower/wtclient/backup_task.go b/watchtower/wtclient/backup_task.go index 72e14934..9f284b98 100644 --- a/watchtower/wtclient/backup_task.go +++ b/watchtower/wtclient/backup_task.go @@ -173,9 +173,9 @@ func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error { // required pieces from signatures, witness scripts, etc are then packaged into // a JusticeKit and encrypted using the breach transaction's key. func (t *backupTask) craftSessionPayload( - signer input.Signer) (wtdb.BreachHint, []byte, error) { + signer input.Signer) (blob.BreachHint, []byte, error) { - var hint wtdb.BreachHint + var hint blob.BreachHint // First, copy over the sweep pkscript, the pubkeys used to derive the // to-local script, and the remote CSV delay. @@ -276,22 +276,19 @@ func (t *backupTask) craftSessionPayload( } } - // Compute the breach hint from the breach transaction id's prefix. - breachKey := t.breachInfo.BreachTransaction.TxHash() + breachTxID := t.breachInfo.BreachTransaction.TxHash() + + // Compute the breach key as SHA256(txid). + hint, key := blob.NewBreachHintAndKeyFromHash(&breachTxID) // Then, we'll encrypt the computed justice kit using the full breach // transaction id, which will allow the tower to recover the contents // after the transaction is seen in the chain or mempool. - encBlob, err := justiceKit.Encrypt(breachKey[:], t.blobType) + encBlob, err := justiceKit.Encrypt(key, t.blobType) if err != nil { return hint, nil, err } - // Finally, compute the breach hint, taken as the first half of the - // breach transactions txid. Once the tower sees the breach transaction - // on the network, it can use the full txid to decyrpt the blob. - hint = wtdb.NewBreachHintFromHash(&breachKey) - return hint, encBlob, nil } diff --git a/watchtower/wtclient/backup_task_internal_test.go b/watchtower/wtclient/backup_task_internal_test.go index 869c4042..5bc0dbb4 100644 --- a/watchtower/wtclient/backup_task_internal_test.go +++ b/watchtower/wtclient/backup_task_internal_test.go @@ -207,9 +207,11 @@ func genTaskTest( expRewardScript: rewardScript, session: &wtdb.ClientSessionBody{ Policy: wtpolicy.Policy{ - BlobType: blobType, - SweepFeeRate: sweepFeeRate, - RewardRate: 10000, + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blobType, + SweepFeeRate: sweepFeeRate, + RewardRate: 10000, + }, }, RewardPkScript: rewardScript, }, @@ -516,7 +518,7 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // Verify that the breach hint matches the breach txid's prefix. breachTxID := test.breachInfo.BreachTransaction.TxHash() - expHint := wtdb.NewBreachHintFromHash(&breachTxID) + expHint := blob.NewBreachHintFromHash(&breachTxID) if hint != expHint { t.Fatalf("breach hint mismatch, want: %x, got: %v", expHint, hint) @@ -524,7 +526,8 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // Decrypt the return blob to obtain the JusticeKit containing its // contents. - jKit, err := blob.Decrypt(breachTxID[:], encBlob, policy.BlobType) + key := blob.NewBreachKeyFromHash(&breachTxID) + jKit, err := blob.Decrypt(key, encBlob, policy.BlobType) if err != nil { t.Fatalf("unable to decrypt blob: %v", err) } diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 8f0cbc9f..09bb5484 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -29,6 +29,11 @@ const ( // DefaultStatInterval specifies the default interval between logging // metrics about the client's operation. DefaultStatInterval = 30 * time.Second + + // DefaultForceQuitDelay specifies the default duration after which the + // client should abandon any pending updates or session negotiations + // before terminating. + DefaultForceQuitDelay = 10 * time.Second ) // Client is the primary interface used by the daemon to control a client's @@ -514,9 +519,10 @@ func (c *TowerClient) nextSessionQueue() *sessionQueue { delete(c.candidateSessions, id) // Skip any sessions with policies that don't match the current - // configuration. These can be used again if the client changes - // their configuration back. - if sessionInfo.Policy != c.cfg.Policy { + // TxPolicy, as they would result in different justice + // transactions from what is requested. These can be used again + // if the client changes their configuration and restarting. + if sessionInfo.Policy.TxPolicy != c.cfg.Policy.TxPolicy { continue } @@ -561,6 +567,7 @@ func (c *TowerClient) backupDispatcher() { // Wait until we receive the newly negotiated session. // All backups sent in the meantime are queued in the // revoke queue, as we cannot process them. + awaitSession: select { case session := <-c.negotiator.NewSessions(): log.Infof("Acquired new session with id=%s", @@ -571,6 +578,12 @@ func (c *TowerClient) backupDispatcher() { case <-c.statTicker.C: log.Infof("Client stats: %s", c.stats) + // Instead of looping, we'll jump back into the + // select case and await the delivery of the + // session to prevent us from re-requesting + // additional sessions. + goto awaitSession + case <-c.forceQuit: return } @@ -626,9 +639,7 @@ func (c *TowerClient) backupDispatcher() { return } - log.Debugf("Processing backup task chanid=%s "+ - "commit-height=%d", task.id.ChanID, - task.id.CommitHeight) + log.Debugf("Processing %v", task.id) c.stats.taskReceived() c.processTask(task) @@ -659,8 +670,8 @@ func (c *TowerClient) processTask(task *backupTask) { // sessionQueue will be removed if accepting the task left the sessionQueue in // an exhausted state. func (c *TowerClient) taskAccepted(task *backupTask, newStatus reserveStatus) { - log.Infof("Backup chanid=%s commit-height=%d accepted successfully", - task.id.ChanID, task.id.CommitHeight) + log.Infof("Queued %v successfully for session %v", + task.id, c.sessionQueue.ID()) c.stats.taskAccepted() @@ -701,16 +712,14 @@ func (c *TowerClient) taskRejected(task *backupTask, curStatus reserveStatus) { case reserveAvailable: c.stats.taskIneligible() - log.Infof("Backup chanid=%s commit-height=%d is ineligible", - task.id.ChanID, task.id.CommitHeight) + log.Infof("Ignoring ineligible %v", task.id) err := c.cfg.DB.MarkBackupIneligible( task.id.ChanID, task.id.CommitHeight, ) if err != nil { - log.Errorf("Unable to mark task chanid=%s "+ - "commit-height=%d ineligible: %v", - task.id.ChanID, task.id.CommitHeight, err) + log.Errorf("Unable to mark %v ineligible: %v", + task.id, err) // It is safe to not handle this error, even if we could // not persist the result. At worst, this task may be @@ -729,10 +738,8 @@ func (c *TowerClient) taskRejected(task *backupTask, curStatus reserveStatus) { case reserveExhausted: c.stats.sessionExhausted() - log.Debugf("Session %s exhausted, backup chanid=%s "+ - "commit-height=%d queued for next session", - c.sessionQueue.ID(), task.id.ChanID, - task.id.CommitHeight) + log.Debugf("Session %v exhausted, %s queued for next session", + c.sessionQueue.ID(), task.id) // Cache the task that we pulled off, so that we can process it // once a new session queue is available. diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index b5a9bbbd..1bdc6dda 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -576,17 +576,17 @@ func (h *testHarness) registerChannel(id uint64) { // advanceChannelN calls advanceState on the channel identified by id the number // of provided times and returns the breach hints corresponding to the new // states. -func (h *testHarness) advanceChannelN(id uint64, n int) []wtdb.BreachHint { +func (h *testHarness) advanceChannelN(id uint64, n int) []blob.BreachHint { h.t.Helper() channel := h.channel(id) - var hints []wtdb.BreachHint + var hints []blob.BreachHint for i := uint64(0); i < uint64(n); i++ { channel.advanceState(h.t) commitTx, _ := h.channel(id).getState(i) breachTxID := commitTx.TxHash() - hints = append(hints, wtdb.NewBreachHintFromHash(&breachTxID)) + hints = append(hints, blob.NewBreachHintFromHash(&breachTxID)) } return hints @@ -621,18 +621,18 @@ func (h *testHarness) backupState(id, i uint64, expErr error) { // party for each state in from-to times and returns the breach hints for states // [from, to). func (h *testHarness) sendPayments(id, from, to uint64, - amt lnwire.MilliSatoshi) []wtdb.BreachHint { + amt lnwire.MilliSatoshi) []blob.BreachHint { h.t.Helper() channel := h.channel(id) - var hints []wtdb.BreachHint + var hints []blob.BreachHint for i := from; i < to; i++ { h.channel(id).sendPayment(h.t, amt) commitTx, _ := channel.getState(i) breachTxID := commitTx.TxHash() - hints = append(hints, wtdb.NewBreachHintFromHash(&breachTxID)) + hints = append(hints, blob.NewBreachHintFromHash(&breachTxID)) } return hints @@ -642,18 +642,18 @@ func (h *testHarness) sendPayments(id, from, to uint64, // remote party for each state in from-to times and returns the breach hints for // states [from, to). func (h *testHarness) recvPayments(id, from, to uint64, - amt lnwire.MilliSatoshi) []wtdb.BreachHint { + amt lnwire.MilliSatoshi) []blob.BreachHint { h.t.Helper() channel := h.channel(id) - var hints []wtdb.BreachHint + var hints []blob.BreachHint for i := from; i < to; i++ { channel.receivePayment(h.t, amt) commitTx, _ := channel.getState(i) breachTxID := commitTx.TxHash() - hints = append(hints, wtdb.NewBreachHintFromHash(&breachTxID)) + hints = append(hints, blob.NewBreachHintFromHash(&breachTxID)) } return hints @@ -662,7 +662,7 @@ func (h *testHarness) recvPayments(id, from, to uint64, // waitServerUpdates blocks until the breach hints provided all appear in the // watchtower's database or the timeout expires. This is used to test that the // client in fact sends the updates to the server, even if it is offline. -func (h *testHarness) waitServerUpdates(hints []wtdb.BreachHint, +func (h *testHarness) waitServerUpdates(hints []blob.BreachHint, timeout time.Duration) { h.t.Helper() @@ -671,7 +671,7 @@ func (h *testHarness) waitServerUpdates(hints []wtdb.BreachHint, // assert that no updates appear. wantUpdates := len(hints) > 0 - hintSet := make(map[wtdb.BreachHint]struct{}) + hintSet := make(map[blob.BreachHint]struct{}) for _, hint := range hints { hintSet[hint] = struct{}{} } @@ -737,7 +737,7 @@ func (h *testHarness) waitServerUpdates(hints []wtdb.BreachHint, // assertUpdatesForPolicy queries the server db for matches using the provided // breach hints, then asserts that each match has a session with the expected // policy. -func (h *testHarness) assertUpdatesForPolicy(hints []wtdb.BreachHint, +func (h *testHarness) assertUpdatesForPolicy(hints []blob.BreachHint, expPolicy wtpolicy.Policy) { // Query for matches on the provided hints. @@ -785,9 +785,11 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - BlobType: blob.TypeDefault, - MaxUpdates: 20000, - SweepFeeRate: 1, + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, + MaxUpdates: 20000, }, noRegisterChan0: true, }, @@ -817,9 +819,11 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - BlobType: blob.TypeDefault, - MaxUpdates: 20000, - SweepFeeRate: 1, + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, + MaxUpdates: 20000, }, }, fn: func(h *testHarness) { @@ -850,9 +854,11 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - BlobType: blob.TypeDefault, - MaxUpdates: 5, - SweepFeeRate: 1, + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, + MaxUpdates: 5, }, }, fn: func(h *testHarness) { @@ -884,9 +890,11 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - BlobType: blob.TypeDefault, - MaxUpdates: 20000, - SweepFeeRate: 1000000, // high sweep fee creates dust + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: 1000000, // high sweep fee creates dust + }, + MaxUpdates: 20000, }, }, fn: func(h *testHarness) { @@ -913,9 +921,11 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - BlobType: blob.TypeDefault, - MaxUpdates: 20000, - SweepFeeRate: 1, + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, + MaxUpdates: 20000, }, }, fn: func(h *testHarness) { @@ -993,9 +1003,11 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - BlobType: blob.TypeDefault, - MaxUpdates: 5, - SweepFeeRate: 1, + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, + MaxUpdates: 5, }, }, fn: func(h *testHarness) { @@ -1049,9 +1061,11 @@ var clientTests = []clientTest{ localBalance: 10000001, // ensure (% amt != 0) remoteBalance: 20000001, // ensure (% amt != 0) policy: wtpolicy.Policy{ - BlobType: blob.TypeDefault, - MaxUpdates: 1000, - SweepFeeRate: 1, + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, + MaxUpdates: 1000, }, }, fn: func(h *testHarness) { @@ -1091,9 +1105,11 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - BlobType: blob.TypeDefault, - MaxUpdates: 5, - SweepFeeRate: 1, + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, + MaxUpdates: 5, }, }, fn: func(h *testHarness) { @@ -1113,7 +1129,7 @@ var clientTests = []clientTest{ // Generate the retributions for all 10 channels and // collect the breach hints. - var hints []wtdb.BreachHint + var hints []blob.BreachHint for id := uint64(0); id < 10; id++ { chanHints := h.advanceChannelN(id, numUpdates) hints = append(hints, chanHints...) @@ -1139,9 +1155,11 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - BlobType: blob.TypeDefault, - MaxUpdates: 5, - SweepFeeRate: 1, + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, + MaxUpdates: 5, }, noAckCreateSession: true, }, @@ -1195,9 +1213,11 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - BlobType: blob.TypeDefault, - MaxUpdates: 5, - SweepFeeRate: 1, + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, + MaxUpdates: 5, }, noAckCreateSession: true, }, @@ -1230,7 +1250,7 @@ var clientTests = []clientTest{ // Restart the client with a new policy, which will // immediately try to overwrite the prior session with // the old policy. - h.clientCfg.Policy.SweepFeeRate = 2 + h.clientCfg.Policy.SweepFeeRate *= 2 h.startClient() defer h.client.ForceQuit() @@ -1246,6 +1266,67 @@ var clientTests = []clientTest{ h.assertUpdatesForPolicy(hints, h.clientCfg.Policy) }, }, + { + // Asserts that the client will not request a new session if + // already has an existing session with the same TxPolicy. This + // permits the client to continue using policies that differ in + // operational parameters, but don't manifest in different + // justice transactions. + name: "create session change policy same txpolicy", + cfg: harnessCfg{ + localBalance: localBalance, + remoteBalance: remoteBalance, + policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, + MaxUpdates: 10, + }, + }, + fn: func(h *testHarness) { + const ( + chanID = 0 + numUpdates = 6 + ) + + // Generate the retributions that will be backed up. + hints := h.advanceChannelN(chanID, numUpdates) + + // Now, queue the first half of the retributions. + h.backupStates(chanID, 0, numUpdates/2, nil) + + // Wait for the server to collect the first half. + h.waitServerUpdates(hints[:numUpdates/2], time.Second) + + // Stop the client, which should have no more backups. + h.client.Stop() + + // Record the policy that the first half was stored + // under. We'll expect the second half to also be stored + // under the original policy, since we are only adjusting + // the MaxUpdates. The client should detect that the + // two policies have equivalent TxPolicies and continue + // using the first. + expPolicy := h.clientCfg.Policy + + // Restart the client with a new policy. + h.clientCfg.Policy.MaxUpdates = 20 + h.startClient() + defer h.client.ForceQuit() + + // Now, queue the second half of the retributions. + h.backupStates(chanID, numUpdates/2, numUpdates, nil) + + // Wait for all of the updates to be populated in the + // server's database. + h.waitServerUpdates(hints, 5*time.Second) + + // Assert that the server has updates for the client's + // original policy. + h.assertUpdatesForPolicy(hints, expPolicy) + }, + }, { // Asserts that the client will deduplicate backups presented by // a channel both in memory and after a restart. The client @@ -1256,9 +1337,11 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - BlobType: blob.TypeDefault, - MaxUpdates: 5, - SweepFeeRate: 1, + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, + MaxUpdates: 5, }, }, fn: func(h *testHarness) { diff --git a/watchtower/wtclient/session_negotiator.go b/watchtower/wtclient/session_negotiator.go index e3d58c1f..9f92ee0a 100644 --- a/watchtower/wtclient/session_negotiator.go +++ b/watchtower/wtclient/session_negotiator.go @@ -112,7 +112,7 @@ var _ SessionNegotiator = (*sessionNegotiator)(nil) // newSessionNegotiator initializes a fresh sessionNegotiator instance. func newSessionNegotiator(cfg *NegotiatorConfig) *sessionNegotiator { localInit := wtwire.NewInitMessage( - lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired), + lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired), cfg.ChainHash, ) diff --git a/watchtower/wtclient/session_queue.go b/watchtower/wtclient/session_queue.go index 39ab0a43..a279176a 100644 --- a/watchtower/wtclient/session_queue.go +++ b/watchtower/wtclient/session_queue.go @@ -109,7 +109,7 @@ type sessionQueue struct { // newSessionQueue intiializes a fresh sessionQueue. func newSessionQueue(cfg *sessionQueueConfig) *sessionQueue { localInit := wtwire.NewInitMessage( - lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired), + lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired), cfg.ChainHash, ) @@ -156,7 +156,7 @@ func (q *sessionQueue) Start() { // will clear all pending tasks in the queue before returning to the caller. func (q *sessionQueue) Stop() { q.stopped.Do(func() { - log.Debugf("Stopping session queue %s", q.ID()) + log.Debugf("SessionQueue(%s) stopping ...", q.ID()) close(q.quit) q.signalUntilShutdown() @@ -168,7 +168,7 @@ func (q *sessionQueue) Stop() { default: } - log.Debugf("Session queue %s successfully stopped", q.ID()) + log.Debugf("SessionQueue(%s) stopped", q.ID()) }) } @@ -176,12 +176,12 @@ func (q *sessionQueue) Stop() { // he caller after all lingering goroutines have spun down. func (q *sessionQueue) ForceQuit() { q.forced.Do(func() { - log.Infof("Force quitting session queue %s", q.ID()) + log.Infof("SessionQueue(%s) force quitting...", q.ID()) close(q.forceQuit) q.signalUntilShutdown() - log.Infof("Session queue %s unclean shutdown complete", q.ID()) + log.Infof("SessionQueue(%s) force quit", q.ID()) }) } @@ -197,8 +197,15 @@ func (q *sessionQueue) ID() *wtdb.SessionID { func (q *sessionQueue) AcceptTask(task *backupTask) (reserveStatus, bool) { q.queueCond.L.Lock() + numPending := uint32(q.pendingQueue.Len()) + maxUpdates := q.cfg.ClientSession.Policy.MaxUpdates + log.Debugf("SessionQueue(%x) deciding to accept %v seqnum=%d "+ + "pending=%d max-updates=%d", + q.ID(), task.id, q.seqNum, numPending, maxUpdates) + // Examine the current reserve status of the session queue. curStatus := q.reserveStatus() + switch curStatus { // The session queue is exhausted, and cannot accept the task because it @@ -218,9 +225,8 @@ func (q *sessionQueue) AcceptTask(task *backupTask) (reserveStatus, bool) { err := task.bindSession(&q.cfg.ClientSession.ClientSessionBody) if err != nil { q.queueCond.L.Unlock() - log.Debugf("SessionQueue %s rejected backup chanid=%s "+ - "commit-height=%d: %v", q.ID(), task.id.ChanID, - task.id.CommitHeight, err) + log.Debugf("SessionQueue(%s) rejected %v: %v ", + q.ID(), task.id, err) return curStatus, false } } @@ -288,8 +294,8 @@ func (q *sessionQueue) drainBackups() { // First, check that we are able to dial this session's tower. conn, err := q.cfg.Dial(q.cfg.ClientSession.SessionPrivKey, q.towerAddr) if err != nil { - log.Errorf("Unable to dial watchtower at %v: %v", - q.towerAddr, err) + log.Errorf("SessionQueue(%s) unable to dial tower at %v: %v", + q.ID(), q.towerAddr, err) q.increaseBackoff() select { @@ -308,9 +314,10 @@ func (q *sessionQueue) drainBackups() { // Generate the next state update to upload to the tower. This // method will first proceed in dequeueing committed updates // before attempting to dequeue any pending updates. - stateUpdate, isPending, err := q.nextStateUpdate() + stateUpdate, isPending, backupID, err := q.nextStateUpdate() if err != nil { - log.Errorf("Unable to get next state update: %v", err) + log.Errorf("SessionQueue(%s) unable to get next state "+ + "update: %v", err) return } @@ -319,7 +326,8 @@ func (q *sessionQueue) drainBackups() { conn, stateUpdate, q.localInit, sendInit, isPending, ) if err != nil { - log.Errorf("Unable to send state update: %v", err) + log.Errorf("SessionQueue(%s) unable to send state "+ + "update: %v", q.ID(), err) q.increaseBackoff() select { @@ -329,6 +337,9 @@ func (q *sessionQueue) drainBackups() { return } + log.Infof("SessionQueue(%s) uploaded %v seqnum=%d", + q.ID(), backupID, stateUpdate.SeqNum) + // If the last task was backed up successfully, we'll exit and // continue once more tasks are added to the queue. We'll also // clear any accumulated backoff as this batch was able to be @@ -357,7 +368,9 @@ func (q *sessionQueue) drainBackups() { // boolean value in the response is true if the state update is taken from the // pending queue, allowing the caller to remove the update from either the // commit or pending queue if the update is successfully acked. -func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) { +func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, + wtdb.BackupID, error) { + var ( seqNum uint16 update wtdb.CommittedUpdate @@ -382,8 +395,9 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) { isLast = q.commitQueue.Len() == 1 && q.pendingQueue.Len() == 0 q.queueCond.L.Unlock() - log.Debugf("Reprocessing committed state update for "+ - "session=%s seqnum=%d", q.ID(), seqNum) + log.Debugf("SessionQueue(%s) reprocessing committed state "+ + "update for %v seqnum=%d", + q.ID(), update.BackupID, seqNum) // Otherwise, craft and commit the next update from the pending queue. default: @@ -407,8 +421,9 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) { hint, encBlob, err := task.craftSessionPayload(q.cfg.Signer) if err != nil { // TODO(conner): mark will not send - return nil, false, fmt.Errorf("unable to craft "+ - "session payload: %v", err) + err := fmt.Errorf("unable to craft session payload: %v", + err) + return nil, false, wtdb.BackupID{}, err } // TODO(conner): special case other obscure errors @@ -421,8 +436,8 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) { }, } - log.Debugf("Committing state update for session=%s seqnum=%d", - q.ID(), seqNum) + log.Debugf("SessionQueue(%s) committing state update "+ + "%v seqnum=%d", q.ID(), update.BackupID, seqNum) } // Before sending the task to the tower, commit the state update @@ -439,8 +454,9 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) { lastApplied, err := q.cfg.DB.CommitUpdate(q.ID(), &update) if err != nil { // TODO(conner): mark failed/reschedule - return nil, false, fmt.Errorf("unable to commit state update "+ - "for session=%s seqnum=%d: %v", q.ID(), seqNum, err) + err := fmt.Errorf("unable to commit state update for "+ + "%v seqnum=%d: %v", update.BackupID, seqNum, err) + return nil, false, wtdb.BackupID{}, err } stateUpdate := &wtwire.StateUpdate{ @@ -455,7 +471,7 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) { stateUpdate.IsComplete = 1 } - return stateUpdate, isPending, nil + return stateUpdate, isPending, update.BackupID, nil } // sendStateUpdate sends a wtwire.StateUpdate to the watchtower and processes @@ -486,8 +502,8 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer, remoteInit, ok := remoteMsg.(*wtwire.Init) if !ok { - return fmt.Errorf("watchtower responded with %T to "+ - "Init", remoteMsg) + return fmt.Errorf("watchtower %s responded with %T "+ + "to Init", q.towerAddr, remoteMsg) } // Validate Init. @@ -513,8 +529,8 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer, stateUpdateReply, ok := remoteMsg.(*wtwire.StateUpdateReply) if !ok { - return fmt.Errorf("watchtower responded with %T to StateUpdate", - remoteMsg) + return fmt.Errorf("watchtower %s responded with %T to "+ + "StateUpdate", q.towerAddr, remoteMsg) } // Process the reply from the tower. @@ -527,10 +543,10 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer, // TODO(conner): handle other error cases properly, ban towers, etc. default: err := fmt.Errorf("received error code %v in "+ - "StateUpdateReply from tower=%x session=%v", - stateUpdateReply.Code, - conn.RemotePub().SerializeCompressed(), q.ID()) - log.Warnf("Unable to upload state update: %v", err) + "StateUpdateReply for seqnum=%d", + stateUpdateReply.Code, stateUpdate.SeqNum) + log.Warnf("SessionQueue(%s) unable to upload state update to "+ + "tower=%s: %v", q.ID(), q.towerAddr, err) return err } @@ -539,28 +555,27 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer, switch { case err == wtdb.ErrUnallocatedLastApplied: // TODO(conner): borked watchtower - err = fmt.Errorf("unable to ack update=%d session=%s: %v", - stateUpdate.SeqNum, q.ID(), err) - log.Errorf("Failed to ack update: %v", err) + err = fmt.Errorf("unable to ack seqnum=%d: %v", + stateUpdate.SeqNum, err) + log.Errorf("SessionQueue(%s) failed to ack update: %v", err) return err case err == wtdb.ErrLastAppliedReversion: // TODO(conner): borked watchtower - err = fmt.Errorf("unable to ack update=%d session=%s: %v", - stateUpdate.SeqNum, q.ID(), err) - log.Errorf("Failed to ack update: %v", err) + err = fmt.Errorf("unable to ack seqnum=%d: %v", + stateUpdate.SeqNum, err) + log.Errorf("SessionQueue(%s) failed to ack update: %v", + q.ID(), err) return err case err != nil: - err = fmt.Errorf("unable to ack update=%d session=%s: %v", - stateUpdate.SeqNum, q.ID(), err) - log.Errorf("Failed to ack update: %v", err) + err = fmt.Errorf("unable to ack seqnum=%d: %v", + stateUpdate.SeqNum, err) + log.Errorf("SessionQueue(%s) failed to ack update: %v", + q.ID(), err) return err } - log.Infof("Removing update session=%s seqnum=%d is_pending=%v "+ - "from memory", q.ID(), stateUpdate.SeqNum, isPending) - q.queueCond.L.Lock() if isPending { // If a pending update was successfully sent, increment the @@ -591,9 +606,6 @@ func (q *sessionQueue) reserveStatus() reserveStatus { numPending := uint32(q.pendingQueue.Len()) maxUpdates := uint32(q.cfg.ClientSession.Policy.MaxUpdates) - log.Debugf("SessionQueue %s reserveStatus seqnum=%d pending=%d "+ - "max-updates=%d", q.ID(), q.seqNum, numPending, maxUpdates) - if uint32(q.seqNum)+numPending < maxUpdates { return reserveAvailable } diff --git a/watchtower/wtdb/breach_hint.go b/watchtower/wtdb/breach_hint.go deleted file mode 100644 index 8332745b..00000000 --- a/watchtower/wtdb/breach_hint.go +++ /dev/null @@ -1,27 +0,0 @@ -package wtdb - -import ( - "encoding/hex" - - "github.com/btcsuite/btcd/chaincfg/chainhash" -) - -// BreachHintSize is the length of the txid prefix used to identify remote -// commitment broadcasts. -const BreachHintSize = 16 - -// BreachHint is the first 16-bytes of the txid belonging to a revoked -// commitment transaction. -type BreachHint [BreachHintSize]byte - -// NewBreachHintFromHash creates a breach hint from a transaction ID. -func NewBreachHintFromHash(hash *chainhash.Hash) BreachHint { - var hint BreachHint - copy(hint[:], hash[:BreachHintSize]) - return hint -} - -// String returns a hex encoding of the breach hint. -func (h BreachHint) String() string { - return hex.EncodeToString(h[:]) -} diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index 66fd8a4e..92ebc95b 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -664,7 +664,7 @@ func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate { t.Fatalf("unable to generate chan id: %v", err) } - var hint wtdb.BreachHint + var hint blob.BreachHint if _, err := io.ReadFull(crand.Reader, hint[:]); err != nil { t.Fatalf("unable to generate breach hint: %v", err) } diff --git a/watchtower/wtdb/client_session.go b/watchtower/wtdb/client_session.go index cb59ca57..e1fd564b 100644 --- a/watchtower/wtdb/client_session.go +++ b/watchtower/wtdb/client_session.go @@ -1,10 +1,12 @@ package wtdb import ( + "fmt" "io" "github.com/btcsuite/btcd/btcec" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/wtpolicy" ) @@ -159,6 +161,11 @@ func (b *BackupID) Decode(r io.Reader) error { ) } +// String returns a human-readable encoding of a BackupID. +func (b *BackupID) String() string { + return fmt.Sprintf("backup(%x, %d)", b.ChanID, b.CommitHeight) +} + // CommittedUpdate holds a state update sent by a client along with its // allocated sequence number and the exact remote commitment the encrypted // justice transaction can rectify. @@ -178,7 +185,7 @@ type CommittedUpdateBody struct { BackupID BackupID // Hint is the 16-byte prefix of the revoked commitment transaction ID. - Hint BreachHint + Hint blob.BreachHint // EncryptedBlob is a ciphertext containing the sweep information for // exacting justice if the commitment transaction matching the breach diff --git a/watchtower/wtdb/codec.go b/watchtower/wtdb/codec.go index 2fd30196..8c88f814 100644 --- a/watchtower/wtdb/codec.go +++ b/watchtower/wtdb/codec.go @@ -36,7 +36,7 @@ func ReadElement(r io.Reader, element interface{}) error { return err } - case *BreachHint: + case *blob.BreachHint: if _, err := io.ReadFull(r, e[:]); err != nil { return err } @@ -94,7 +94,7 @@ func WriteElement(w io.Writer, element interface{}) error { return err } - case BreachHint: + case blob.BreachHint: if _, err := w.Write(e[:]); err != nil { return err } diff --git a/watchtower/wtdb/session_info.go b/watchtower/wtdb/session_info.go index f4acf764..2c8809a9 100644 --- a/watchtower/wtdb/session_info.go +++ b/watchtower/wtdb/session_info.go @@ -4,6 +4,7 @@ import ( "errors" "io" + "github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/wtpolicy" ) @@ -134,7 +135,7 @@ type Match struct { SeqNum uint16 // Hint is the breach hint that triggered the match. - Hint BreachHint + Hint blob.BreachHint // EncryptedBlob is the encrypted payload containing the justice kit // uploaded by the client. diff --git a/watchtower/wtdb/session_state_update.go b/watchtower/wtdb/session_state_update.go index 9b178220..c5e3131e 100644 --- a/watchtower/wtdb/session_state_update.go +++ b/watchtower/wtdb/session_state_update.go @@ -1,6 +1,10 @@ package wtdb -import "io" +import ( + "io" + + "github.com/lightningnetwork/lnd/watchtower/blob" +) // SessionStateUpdate holds a state update sent by a client along with its // SessionID. @@ -16,7 +20,7 @@ type SessionStateUpdate struct { LastApplied uint16 // Hint is the 16-byte prefix of the revoked commitment transaction. - Hint BreachHint + Hint blob.BreachHint // EncryptedBlob is a ciphertext containing the sweep information for // exacting justice if the commitment transaction matching the breach diff --git a/watchtower/wtdb/tower_db.go b/watchtower/wtdb/tower_db.go index 96edafca..92a9e55a 100644 --- a/watchtower/wtdb/tower_db.go +++ b/watchtower/wtdb/tower_db.go @@ -7,6 +7,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/coreos/bbolt" "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/watchtower/blob" ) const ( @@ -45,6 +46,10 @@ var ( // ErrNoSessionHintIndex signals that an active session does not have an // initialized index for tracking its own state updates. ErrNoSessionHintIndex = errors.New("session hint index missing") + + // ErrInvalidBlobSize indicates that the encrypted blob provided by the + // client is not valid according to the blob type of the session. + ErrInvalidBlobSize = errors.New("invalid blob size") ) // TowerDB is single database providing a persistent storage engine for the @@ -188,6 +193,12 @@ func (t *TowerDB) InsertSessionInfo(session *SessionInfo) error { return ErrSessionAlreadyExists } + // Perform a quick sanity check on the session policy before + // accepting. + if err := session.Policy.Validate(); err != nil { + return err + } + err = putSession(sessions, session) if err != nil { return err @@ -232,6 +243,13 @@ func (t *TowerDB) InsertStateUpdate(update *SessionStateUpdate) (uint16, error) return err } + // Assert that the blob is the correct size for the session's + // blob type. + expBlobSize := blob.Size(session.Policy.BlobType) + if len(update.EncryptedBlob) != expBlobSize { + return ErrInvalidBlobSize + } + // Validate the update against the current state of the session. err = session.AcceptUpdateSequence( update.SeqNum, update.LastApplied, @@ -369,7 +387,7 @@ func (t *TowerDB) DeleteSession(target SessionID) error { // QueryMatches searches against all known state updates for any that match the // passed breachHints. More than one Match will be returned for a given hint if // they exist in the database. -func (t *TowerDB) QueryMatches(breachHints []BreachHint) ([]Match, error) { +func (t *TowerDB) QueryMatches(breachHints []blob.BreachHint) ([]Match, error) { var matches []Match err := t.db.View(func(tx *bbolt.Tx) error { sessions := tx.Bucket(sessionsBkt) @@ -534,20 +552,20 @@ func removeSessionHintBkt(updateIndex *bbolt.Bucket, id *SessionID) error { // If the index for the session has not been initialized, this method returns // ErrNoSessionHintIndex. func getHintsForSession(updateIndex *bbolt.Bucket, - id *SessionID) ([]BreachHint, error) { + id *SessionID) ([]blob.BreachHint, error) { sessionHints := updateIndex.Bucket(id[:]) if sessionHints == nil { return nil, ErrNoSessionHintIndex } - var hints []BreachHint + var hints []blob.BreachHint err := sessionHints.ForEach(func(k, _ []byte) error { - if len(k) != BreachHintSize { + if len(k) != blob.BreachHintSize { return nil } - var hint BreachHint + var hint blob.BreachHint copy(hint[:], k) hints = append(hints, hint) return nil @@ -565,7 +583,7 @@ func getHintsForSession(updateIndex *bbolt.Bucket, // for the session has not been initialized, this method returns // ErrNoSessionHintIndex. func putHintForSession(updateIndex *bbolt.Bucket, id *SessionID, - hint BreachHint) error { + hint blob.BreachHint) error { sessionHints := updateIndex.Bucket(id[:]) if sessionHints == nil { diff --git a/watchtower/wtdb/tower_db_test.go b/watchtower/wtdb/tower_db_test.go index c9920bcb..c408e7ef 100644 --- a/watchtower/wtdb/tower_db_test.go +++ b/watchtower/wtdb/tower_db_test.go @@ -1,6 +1,7 @@ package wtdb_test import ( + "bytes" "encoding/binary" "io/ioutil" "os" @@ -10,11 +11,16 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/watchtower" + "github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/lightningnetwork/lnd/watchtower/wtmock" "github.com/lightningnetwork/lnd/watchtower/wtpolicy" ) +var ( + testBlob = make([]byte, blob.Size(blob.TypeAltruistCommit)) +) + // dbInit is a closure used to initialize a watchtower.DB instance and its // cleanup function. type dbInit func(*testing.T) (watchtower.DB, func()) @@ -97,10 +103,10 @@ func (h *towerDBHarness) deleteSession(id wtdb.SessionID, expErr error) { // queryMatches queries that database for the passed breach hint, returning all // matches found. -func (h *towerDBHarness) queryMatches(hint wtdb.BreachHint) []wtdb.Match { +func (h *towerDBHarness) queryMatches(hint blob.BreachHint) []wtdb.Match { h.t.Helper() - matches, err := h.db.QueryMatches([]wtdb.BreachHint{hint}) + matches, err := h.db.QueryMatches([]blob.BreachHint{hint}) if err != nil { h.t.Fatalf("unable to query matches: %v", err) } @@ -111,7 +117,7 @@ func (h *towerDBHarness) queryMatches(hint wtdb.BreachHint) []wtdb.Match { // hasUpdate queries the database for the passed breach hint, asserting that // only one match is present and that the hints indeed match. If successful, the // match is returned. -func (h *towerDBHarness) hasUpdate(hint wtdb.BreachHint) wtdb.Match { +func (h *towerDBHarness) hasUpdate(hint blob.BreachHint) wtdb.Match { h.t.Helper() matches := h.queryMatches(hint) @@ -136,11 +142,21 @@ func testInsertSession(h *towerDBHarness) { session := &wtdb.SessionInfo{ ID: id, Policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + }, MaxUpdates: 100, }, RewardAddress: []byte{0x01, 0x02, 0x03}, } + // Try to insert the session, which should fail since the policy doesn't + // meet the current sanity checks. + h.insertSession(session, wtpolicy.ErrSweepFeeRateTooLow) + + // Now assign a sane sweep fee rate to the policy, inserting should + // succeed. + session.Policy.SweepFeeRate = wtpolicy.DefaultSweepFeeRate h.insertSession(session, nil) session2 := h.getSession(&id, nil) @@ -154,8 +170,9 @@ func testInsertSession(h *towerDBHarness) { // Insert a state update to fully commit the session parameters. update := &wtdb.SessionStateUpdate{ - ID: id, - SeqNum: 1, + ID: id, + SeqNum: 1, + EncryptedBlob: testBlob, } h.insertUpdate(update, nil) @@ -169,12 +186,16 @@ func testMultipleMatches(h *towerDBHarness) { const numUpdates = 3 // Create a new session and send updates with all the same hint. - var hint wtdb.BreachHint + var hint blob.BreachHint for i := 0; i < numUpdates; i++ { id := *id(i) session := &wtdb.SessionInfo{ ID: id, Policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, MaxUpdates: 3, }, RewardAddress: []byte{}, @@ -182,9 +203,10 @@ func testMultipleMatches(h *towerDBHarness) { h.insertSession(session, nil) update := &wtdb.SessionStateUpdate{ - ID: id, - SeqNum: 1, - Hint: hint, // Use same hint to cause multiple matches + ID: id, + SeqNum: 1, + Hint: hint, // Use same hint to cause multiple matches + EncryptedBlob: testBlob, } h.insertUpdate(update, nil) } @@ -266,6 +288,10 @@ func testDeleteSession(h *towerDBHarness) { session0 := &wtdb.SessionInfo{ ID: *id0, Policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, MaxUpdates: 3, }, RewardAddress: []byte{}, @@ -284,6 +310,10 @@ func testDeleteSession(h *towerDBHarness) { session1 := &wtdb.SessionInfo{ ID: *id1, Policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, MaxUpdates: 3, }, RewardAddress: []byte{}, @@ -291,18 +321,18 @@ func testDeleteSession(h *towerDBHarness) { h.insertSession(session1, nil) // Create and insert updates for both sessions that have the same hint. - var hint wtdb.BreachHint + var hint blob.BreachHint update0 := &wtdb.SessionStateUpdate{ ID: *id0, Hint: hint, SeqNum: 1, - EncryptedBlob: []byte{}, + EncryptedBlob: testBlob, } update1 := &wtdb.SessionStateUpdate{ ID: *id1, Hint: hint, SeqNum: 1, - EncryptedBlob: []byte{}, + EncryptedBlob: testBlob, } // Insert both updates should succeed. @@ -413,7 +443,7 @@ func runStateUpdateTest(test stateUpdateTest) func(*towerDBHarness) { var stateUpdateNoSession = stateUpdateTest{ session: nil, updates: []*wtdb.SessionStateUpdate{ - {ID: *id(0), SeqNum: 1, LastApplied: 0}, + updateFromInt(id(0), 1, 0), }, updateErrs: []error{ wtdb.ErrSessionNotFound, @@ -424,6 +454,10 @@ var stateUpdateExhaustSession = stateUpdateTest{ session: &wtdb.SessionInfo{ ID: *id(0), Policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, MaxUpdates: 3, }, RewardAddress: []byte{}, @@ -443,6 +477,10 @@ var stateUpdateSeqNumEqualLastApplied = stateUpdateTest{ session: &wtdb.SessionInfo{ ID: *id(0), Policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, MaxUpdates: 3, }, RewardAddress: []byte{}, @@ -462,6 +500,10 @@ var stateUpdateSeqNumLTLastApplied = stateUpdateTest{ session: &wtdb.SessionInfo{ ID: *id(0), Policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, MaxUpdates: 3, }, RewardAddress: []byte{}, @@ -480,6 +522,10 @@ var stateUpdateSeqNumZeroInvalid = stateUpdateTest{ session: &wtdb.SessionInfo{ ID: *id(0), Policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, MaxUpdates: 3, }, RewardAddress: []byte{}, @@ -496,6 +542,10 @@ var stateUpdateSkipSeqNum = stateUpdateTest{ session: &wtdb.SessionInfo{ ID: *id(0), Policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, MaxUpdates: 3, }, RewardAddress: []byte{}, @@ -512,6 +562,10 @@ var stateUpdateRevertSeqNum = stateUpdateTest{ session: &wtdb.SessionInfo{ ID: *id(0), Policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, MaxUpdates: 3, }, RewardAddress: []byte{}, @@ -530,6 +584,10 @@ var stateUpdateRevertLastApplied = stateUpdateTest{ session: &wtdb.SessionInfo{ ID: *id(0), Policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, MaxUpdates: 3, }, RewardAddress: []byte{}, @@ -545,6 +603,31 @@ var stateUpdateRevertLastApplied = stateUpdateTest{ }, } +var stateUpdateInvalidBlobSize = stateUpdateTest{ + session: &wtdb.SessionInfo{ + ID: *id(0), + Policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, + MaxUpdates: 3, + }, + RewardAddress: []byte{}, + }, + updates: []*wtdb.SessionStateUpdate{ + { + ID: *id(0), + SeqNum: 1, + LastApplied: 0, + EncryptedBlob: []byte{0x01, 0x02, 0x03}, // too $hort + }, + }, + updateErrs: []error{ + wtdb.ErrInvalidBlobSize, + }, +} + func TestTowerDB(t *testing.T) { dbs := []struct { name string @@ -662,6 +745,10 @@ func TestTowerDB(t *testing.T) { name: "state update revert last applied", run: runStateUpdateTest(stateUpdateRevertLastApplied), }, + { + name: "invalid blob size", + run: runStateUpdateTest(stateUpdateInvalidBlobSize), + }, { name: "multiple breach matches", run: testMultipleMatches, @@ -705,16 +792,18 @@ func updateFromInt(id *wtdb.SessionID, i int, lastApplied uint16) *wtdb.SessionStateUpdate { // Ensure the hint is unique. - var hint wtdb.BreachHint + var hint blob.BreachHint copy(hint[:4], id[:4]) binary.BigEndian.PutUint16(hint[4:6], uint16(i)) + blobSize := blob.Size(blob.TypeAltruistCommit) + return &wtdb.SessionStateUpdate{ ID: *id, Hint: hint, SeqNum: uint16(i), LastApplied: lastApplied, - EncryptedBlob: []byte{byte(i)}, + EncryptedBlob: bytes.Repeat([]byte{byte(i)}, blobSize), } } diff --git a/watchtower/wtmock/tower_db.go b/watchtower/wtmock/tower_db.go index 403d61e3..35c01ab1 100644 --- a/watchtower/wtmock/tower_db.go +++ b/watchtower/wtmock/tower_db.go @@ -4,6 +4,7 @@ import ( "sync" "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/wtdb" ) @@ -12,14 +13,14 @@ type TowerDB struct { mu sync.Mutex lastEpoch *chainntnfs.BlockEpoch sessions map[wtdb.SessionID]*wtdb.SessionInfo - blobs map[wtdb.BreachHint]map[wtdb.SessionID]*wtdb.SessionStateUpdate + blobs map[blob.BreachHint]map[wtdb.SessionID]*wtdb.SessionStateUpdate } // NewTowerDB initializes a fresh mock TowerDB. func NewTowerDB() *TowerDB { return &TowerDB{ sessions: make(map[wtdb.SessionID]*wtdb.SessionInfo), - blobs: make(map[wtdb.BreachHint]map[wtdb.SessionID]*wtdb.SessionStateUpdate), + blobs: make(map[blob.BreachHint]map[wtdb.SessionID]*wtdb.SessionStateUpdate), } } @@ -36,6 +37,11 @@ func (db *TowerDB) InsertStateUpdate(update *wtdb.SessionStateUpdate) (uint16, e return 0, wtdb.ErrSessionNotFound } + // Assert that the blob is the correct size for the session's blob type. + if len(update.EncryptedBlob) != blob.Size(info.Policy.BlobType) { + return 0, wtdb.ErrInvalidBlobSize + } + err := info.AcceptUpdateSequence(update.SeqNum, update.LastApplied) if err != nil { return info.LastApplied, err @@ -75,6 +81,11 @@ func (db *TowerDB) InsertSessionInfo(info *wtdb.SessionInfo) error { return wtdb.ErrSessionAlreadyExists } + // Perform a quick sanity check on the session policy before accepting. + if err := info.Policy.Validate(); err != nil { + return err + } + db.sessions[info.ID] = info return nil @@ -113,7 +124,7 @@ func (db *TowerDB) DeleteSession(target wtdb.SessionID) error { // passed breachHints. More than one Match will be returned for a given hint if // they exist in the database. func (db *TowerDB) QueryMatches( - breachHints []wtdb.BreachHint) ([]wtdb.Match, error) { + breachHints []blob.BreachHint) ([]wtdb.Match, error) { db.mu.Lock() defer db.mu.Unlock() diff --git a/watchtower/wtpolicy/policy.go b/watchtower/wtpolicy/policy.go index b21f73d3..163c1f1e 100644 --- a/watchtower/wtpolicy/policy.go +++ b/watchtower/wtpolicy/policy.go @@ -27,7 +27,11 @@ const ( // DefaultSweepFeeRate specifies the fee rate used to construct justice // transactions. The value is expressed in satoshis per kilo-weight. - DefaultSweepFeeRate = 3000 + DefaultSweepFeeRate = lnwallet.SatPerKWeight(12000) + + // MinSweepFeeRate is the minimum sweep fee rate a client may use in its + // policy, the current value is 4 sat/kw. + MinSweepFeeRate = lnwallet.SatPerKWeight(4000) ) var ( @@ -43,34 +47,42 @@ var ( // ErrCreatesDust signals that the session's policy would create a dust // output for the victim. ErrCreatesDust = errors.New("justice transaction creates dust at fee rate") + + // ErrAltruistReward signals that the policy is invalid because it + // contains a non-zero RewardBase or RewardRate on an altruist policy. + ErrAltruistReward = errors.New("altruist policy has reward params") + + // ErrNoMaxUpdates signals that the policy specified zero MaxUpdates. + ErrNoMaxUpdates = errors.New("max updates must be positive") + + // ErrSweepFeeRateTooLow signals that the policy's fee rate is too low + // to get into the mempool during low congestion. + ErrSweepFeeRateTooLow = errors.New("sweep fee rate too low") ) // DefaultPolicy returns a Policy containing the default parameters that can be // used by clients or servers. func DefaultPolicy() Policy { return Policy{ - BlobType: blob.TypeDefault, + TxPolicy: TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: DefaultSweepFeeRate, + }, MaxUpdates: DefaultMaxUpdates, - RewardRate: DefaultRewardRate, - SweepFeeRate: lnwallet.SatPerKWeight( - DefaultSweepFeeRate, - ), } } -// Policy defines the negotiated parameters for a session between a client and -// server. The parameters specify the format of encrypted blobs sent to the -// tower, the reward schedule for the tower, and the number of encrypted blobs a -// client can send in one session. -type Policy struct { +// TxPolicy defines the negotiate parameters that determine the form of the +// justice transaction for a given breached state. Thus, for any given revoked +// state, an identical key will result in an identical justice transaction +// (barring signatures). The parameters specify the format of encrypted blobs +// sent to the tower, the reward schedule for the tower, and the number of +// encrypted blobs a client can send in one session. +type TxPolicy struct { // BlobType specifies the blob format that must be used by all updates sent // under the session key used to negotiate this session. BlobType blob.Type - // MaxUpdates is the maximum number of updates the watchtower will honor - // for this session. - MaxUpdates uint16 - // RewardBase is the fixed amount allocated to the tower when the // policy's blob type specifies a reward for the tower. This is taken // before adding the proportional reward. @@ -88,6 +100,18 @@ type Policy struct { SweepFeeRate lnwallet.SatPerKWeight } +// Policy defines the negotiated parameters for a session between a client and +// server. In addition to the TxPolicy that governs the shape of the justice +// transaction, the Policy also includes features which only affect the +// operation of the session. +type Policy struct { + TxPolicy + + // MaxUpdates is the maximum number of updates the watchtower will honor + // for this session. + MaxUpdates uint16 +} + // String returns a human-readable description of the current policy. func (p Policy) String() string { return fmt.Sprintf("(blob-type=%b max-updates=%d reward-rate=%d "+ @@ -95,6 +119,31 @@ func (p Policy) String() string { p.SweepFeeRate) } +// Validate ensures that the policy satisfies some minimal correctness +// constraints. +func (p Policy) Validate() error { + // RewardBase and RewardRate should not be set if the policy doesn't + // have a reward. + if !p.BlobType.Has(blob.FlagReward) && + (p.RewardBase != 0 || p.RewardRate != 0) { + + return ErrAltruistReward + } + + // MaxUpdates must be positive. + if p.MaxUpdates == 0 { + return ErrNoMaxUpdates + } + + // SweepFeeRate must be sane enough to get in the mempool during low + // congestion. + if p.SweepFeeRate < MinSweepFeeRate { + return ErrSweepFeeRateTooLow + } + + return nil +} + // ComputeAltruistOutput computes the lone output value of a justice transaction // that pays no reward to the tower. The value is computed using the weight of // of the justice transaction and subtracting an amount that satisfies the diff --git a/watchtower/wtpolicy/policy_test.go b/watchtower/wtpolicy/policy_test.go new file mode 100644 index 00000000..4182a0de --- /dev/null +++ b/watchtower/wtpolicy/policy_test.go @@ -0,0 +1,93 @@ +package wtpolicy_test + +import ( + "testing" + + "github.com/lightningnetwork/lnd/watchtower/blob" + "github.com/lightningnetwork/lnd/watchtower/wtpolicy" +) + +var validationTests = []struct { + name string + policy wtpolicy.Policy + expErr error +}{ + { + name: "fail no maxupdates", + policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + }, + }, + expErr: wtpolicy.ErrNoMaxUpdates, + }, + { + name: "fail altruist with reward base", + policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + RewardBase: 1, + }, + }, + expErr: wtpolicy.ErrAltruistReward, + }, + { + name: "fail altruist with reward rate", + policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + RewardRate: 1, + }, + }, + expErr: wtpolicy.ErrAltruistReward, + }, + { + name: "fail sweep fee rate too low", + policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + }, + MaxUpdates: 1, + }, + expErr: wtpolicy.ErrSweepFeeRateTooLow, + }, + { + name: "minimal valid altruist policy", + policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.MinSweepFeeRate, + }, + MaxUpdates: 1, + }, + }, + { + name: "valid altruist policy with default sweep rate", + policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, + }, + MaxUpdates: 1, + }, + }, + { + name: "valid default policy", + policy: wtpolicy.DefaultPolicy(), + }, +} + +// TestPolicyValidate asserts that the sanity checks for policies behave as +// expected. +func TestPolicyValidate(t *testing.T) { + for i := range validationTests { + test := validationTests[i] + t.Run(test.name, func(t *testing.T) { + err := test.policy.Validate() + if err != test.expErr { + t.Fatalf("validation error mismatch, "+ + "want: %v, got: %v", test.expErr, err) + } + }) + } +} diff --git a/watchtower/wtserver/create_session.go b/watchtower/wtserver/create_session.go index 5636b79d..3fd58dbd 100644 --- a/watchtower/wtserver/create_session.go +++ b/watchtower/wtserver/create_session.go @@ -54,6 +54,17 @@ func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID, ) } + // If the request asks for a reward session and the tower has them + // disabled, we will reject the request. + if s.cfg.DisableReward && req.BlobType.Has(blob.FlagReward) { + log.Debugf("Rejecting CreateSession from %s, reward "+ + "sessions disabled", id) + return s.replyCreateSession( + peer, id, wtwire.CreateSessionCodeRejectBlobType, 0, + nil, + ) + } + // Now that we've established that this session does not exist in the // database, retrieve the sweep address that will be given to the // client. This address is to be included by the client when signing @@ -89,11 +100,13 @@ func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID, info := wtdb.SessionInfo{ ID: *id, Policy: wtpolicy.Policy{ - BlobType: req.BlobType, - MaxUpdates: req.MaxUpdates, - RewardBase: req.RewardBase, - RewardRate: req.RewardRate, - SweepFeeRate: req.SweepFeeRate, + TxPolicy: wtpolicy.TxPolicy{ + BlobType: req.BlobType, + RewardBase: req.RewardBase, + RewardRate: req.RewardRate, + SweepFeeRate: req.SweepFeeRate, + }, + MaxUpdates: req.MaxUpdates, }, RewardAddress: rewardScript, } diff --git a/watchtower/wtserver/server.go b/watchtower/wtserver/server.go index d4ee8874..1c46201b 100644 --- a/watchtower/wtserver/server.go +++ b/watchtower/wtserver/server.go @@ -63,6 +63,10 @@ type Config struct { // NoAckUpdates causes the server to not acknowledge state updates, this // should only be used for testing. NoAckUpdates bool + + // DisableReward causes the server to reject any session creation + // attempts that request rewards. + DisableReward bool } // Server houses the state required to handle watchtower peers. It's primary job @@ -92,7 +96,7 @@ type Server struct { // sessions and send state updates. func New(cfg *Config) (*Server, error) { localInit := wtwire.NewInitMessage( - lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional), + lnwire.NewRawFeatureVector(wtwire.AltruistSessionsOptional), cfg.ChainHash, ) diff --git a/watchtower/wtserver/server_test.go b/watchtower/wtserver/server_test.go index c3ae2a33..0bb5806d 100644 --- a/watchtower/wtserver/server_test.go +++ b/watchtower/wtserver/server_test.go @@ -28,7 +28,7 @@ var ( testnetChainHash = *chaincfg.TestNet3Params.GenesisHash - rewardType = (blob.FlagCommitOutputs | blob.FlagReward).Type() + testBlob = make([]byte, blob.Size(blob.TypeAltruistCommit)) ) // randPubKey generates a new secp keypair, and returns the public key. @@ -168,11 +168,11 @@ var createSessionTests = []createSessionTestCase{ testnetChainHash, ), createMsg: &wtwire.CreateSession{ - BlobType: blob.TypeDefault, + BlobType: blob.TypeAltruistCommit, MaxUpdates: 1000, RewardBase: 0, RewardRate: 0, - SweepFeeRate: 1, + SweepFeeRate: 10000, }, expReply: &wtwire.CreateSessionReply{ Code: wtwire.CodeOK, @@ -190,11 +190,11 @@ var createSessionTests = []createSessionTestCase{ testnetChainHash, ), createMsg: &wtwire.CreateSession{ - BlobType: blob.TypeDefault, + BlobType: blob.TypeAltruistCommit, MaxUpdates: 1000, RewardBase: 0, RewardRate: 0, - SweepFeeRate: 1, + SweepFeeRate: 10000, }, expReply: &wtwire.CreateSessionReply{ Code: wtwire.CodeOK, @@ -214,11 +214,11 @@ var createSessionTests = []createSessionTestCase{ testnetChainHash, ), createMsg: &wtwire.CreateSession{ - BlobType: rewardType, + BlobType: blob.TypeRewardCommit, MaxUpdates: 1000, RewardBase: 0, RewardRate: 0, - SweepFeeRate: 1, + SweepFeeRate: 10000, }, expReply: &wtwire.CreateSessionReply{ Code: wtwire.CodeOK, @@ -240,7 +240,7 @@ var createSessionTests = []createSessionTestCase{ MaxUpdates: 1000, RewardBase: 0, RewardRate: 0, - SweepFeeRate: 1, + SweepFeeRate: 10000, }, expReply: &wtwire.CreateSessionReply{ Code: wtwire.CreateSessionCodeRejectBlobType, @@ -302,8 +302,9 @@ func testServerCreateSession(t *testing.T, i int, test createSessionTestCase) { peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0) connect(t, s, peer, test.initMsg, timeoutDuration) update := &wtwire.StateUpdate{ - SeqNum: 1, - IsComplete: 1, + SeqNum: 1, + IsComplete: 1, + EncryptedBlob: testBlob, } sendMsg(t, update, peer, timeoutDuration) @@ -325,8 +326,8 @@ func testServerCreateSession(t *testing.T, i int, test createSessionTestCase) { // Ensure that the server's reply matches our expected response for a // duplicate send. if !reflect.DeepEqual(reply, test.expDupReply) { - t.Fatalf("[test %d] expected reply %v, got %d", - i, test.expReply, reply) + t.Fatalf("[test %d] expected reply %v, got %v", + i, test.expDupReply, reply) } // Finally, check that the server tore down the connection. @@ -350,17 +351,17 @@ var stateUpdateTests = []stateUpdateTestCase{ testnetChainHash, ), createMsg: &wtwire.CreateSession{ - BlobType: blob.TypeDefault, + BlobType: blob.TypeAltruistCommit, MaxUpdates: 3, RewardBase: 0, RewardRate: 0, - SweepFeeRate: 1, + SweepFeeRate: 10000, }, updates: []*wtwire.StateUpdate{ - {SeqNum: 1, LastApplied: 0}, - {SeqNum: 2, LastApplied: 1}, - {SeqNum: 3, LastApplied: 2}, - {SeqNum: 3, LastApplied: 3}, + {SeqNum: 1, LastApplied: 0, EncryptedBlob: testBlob}, + {SeqNum: 2, LastApplied: 1, EncryptedBlob: testBlob}, + {SeqNum: 3, LastApplied: 2, EncryptedBlob: testBlob}, + {SeqNum: 3, LastApplied: 3, EncryptedBlob: testBlob}, }, replies: []*wtwire.StateUpdateReply{ {Code: wtwire.CodeOK, LastApplied: 1}, @@ -380,14 +381,14 @@ var stateUpdateTests = []stateUpdateTestCase{ testnetChainHash, ), createMsg: &wtwire.CreateSession{ - BlobType: blob.TypeDefault, + BlobType: blob.TypeAltruistCommit, MaxUpdates: 4, RewardBase: 0, RewardRate: 0, - SweepFeeRate: 1, + SweepFeeRate: 10000, }, updates: []*wtwire.StateUpdate{ - {SeqNum: 2, LastApplied: 0}, + {SeqNum: 2, LastApplied: 0, EncryptedBlob: testBlob}, }, replies: []*wtwire.StateUpdateReply{ { @@ -404,16 +405,16 @@ var stateUpdateTests = []stateUpdateTestCase{ testnetChainHash, ), createMsg: &wtwire.CreateSession{ - BlobType: blob.TypeDefault, + BlobType: blob.TypeAltruistCommit, MaxUpdates: 4, RewardBase: 0, RewardRate: 0, - SweepFeeRate: 1, + SweepFeeRate: 10000, }, updates: []*wtwire.StateUpdate{ - {SeqNum: 1, LastApplied: 0}, - {SeqNum: 2, LastApplied: 0}, - {SeqNum: 1, LastApplied: 0}, + {SeqNum: 1, LastApplied: 0, EncryptedBlob: testBlob}, + {SeqNum: 2, LastApplied: 0, EncryptedBlob: testBlob}, + {SeqNum: 1, LastApplied: 0, EncryptedBlob: testBlob}, }, replies: []*wtwire.StateUpdateReply{ {Code: wtwire.CodeOK, LastApplied: 1}, @@ -432,17 +433,17 @@ var stateUpdateTests = []stateUpdateTestCase{ testnetChainHash, ), createMsg: &wtwire.CreateSession{ - BlobType: blob.TypeDefault, + BlobType: blob.TypeAltruistCommit, MaxUpdates: 4, RewardBase: 0, RewardRate: 0, - SweepFeeRate: 1, + SweepFeeRate: 10000, }, updates: []*wtwire.StateUpdate{ - {SeqNum: 1, LastApplied: 0}, - {SeqNum: 2, LastApplied: 1}, - {SeqNum: 3, LastApplied: 2}, - {SeqNum: 4, LastApplied: 1}, + {SeqNum: 1, LastApplied: 0, EncryptedBlob: testBlob}, + {SeqNum: 2, LastApplied: 1, EncryptedBlob: testBlob}, + {SeqNum: 3, LastApplied: 2, EncryptedBlob: testBlob}, + {SeqNum: 4, LastApplied: 1, EncryptedBlob: testBlob}, }, replies: []*wtwire.StateUpdateReply{ {Code: wtwire.CodeOK, LastApplied: 1}, @@ -460,18 +461,18 @@ var stateUpdateTests = []stateUpdateTestCase{ testnetChainHash, ), createMsg: &wtwire.CreateSession{ - BlobType: blob.TypeDefault, + BlobType: blob.TypeAltruistCommit, MaxUpdates: 4, RewardBase: 0, RewardRate: 0, - SweepFeeRate: 1, + SweepFeeRate: 10000, }, updates: []*wtwire.StateUpdate{ - {SeqNum: 1, LastApplied: 0}, - {SeqNum: 2, LastApplied: 1}, + {SeqNum: 1, LastApplied: 0, EncryptedBlob: testBlob}, + {SeqNum: 2, LastApplied: 1, EncryptedBlob: testBlob}, nil, // Wait for read timeout to drop conn, then reconnect. - {SeqNum: 3, LastApplied: 2}, - {SeqNum: 4, LastApplied: 3}, + {SeqNum: 3, LastApplied: 2, EncryptedBlob: testBlob}, + {SeqNum: 4, LastApplied: 3, EncryptedBlob: testBlob}, }, replies: []*wtwire.StateUpdateReply{ {Code: wtwire.CodeOK, LastApplied: 1}, @@ -490,18 +491,18 @@ var stateUpdateTests = []stateUpdateTestCase{ testnetChainHash, ), createMsg: &wtwire.CreateSession{ - BlobType: blob.TypeDefault, + BlobType: blob.TypeAltruistCommit, MaxUpdates: 4, RewardBase: 0, RewardRate: 0, - SweepFeeRate: 1, + SweepFeeRate: 10000, }, updates: []*wtwire.StateUpdate{ - {SeqNum: 1, LastApplied: 0}, - {SeqNum: 2, LastApplied: 0}, + {SeqNum: 1, LastApplied: 0, EncryptedBlob: testBlob}, + {SeqNum: 2, LastApplied: 0, EncryptedBlob: testBlob}, nil, // Wait for read timeout to drop conn, then reconnect. - {SeqNum: 3, LastApplied: 0}, - {SeqNum: 4, LastApplied: 3}, + {SeqNum: 3, LastApplied: 0, EncryptedBlob: testBlob}, + {SeqNum: 4, LastApplied: 3, EncryptedBlob: testBlob}, }, replies: []*wtwire.StateUpdateReply{ {Code: wtwire.CodeOK, LastApplied: 1}, @@ -520,19 +521,19 @@ var stateUpdateTests = []stateUpdateTestCase{ testnetChainHash, ), createMsg: &wtwire.CreateSession{ - BlobType: blob.TypeDefault, + BlobType: blob.TypeAltruistCommit, MaxUpdates: 4, RewardBase: 0, RewardRate: 0, - SweepFeeRate: 1, + SweepFeeRate: 10000, }, updates: []*wtwire.StateUpdate{ - {SeqNum: 1, LastApplied: 0}, - {SeqNum: 2, LastApplied: 0}, + {SeqNum: 1, LastApplied: 0, EncryptedBlob: testBlob}, + {SeqNum: 2, LastApplied: 0, EncryptedBlob: testBlob}, nil, // Wait for read timeout to drop conn, then reconnect. - {SeqNum: 2, LastApplied: 0}, - {SeqNum: 3, LastApplied: 0}, - {SeqNum: 4, LastApplied: 3}, + {SeqNum: 2, LastApplied: 0, EncryptedBlob: testBlob}, + {SeqNum: 3, LastApplied: 0, EncryptedBlob: testBlob}, + {SeqNum: 4, LastApplied: 3, EncryptedBlob: testBlob}, }, replies: []*wtwire.StateUpdateReply{ {Code: wtwire.CodeOK, LastApplied: 1}, @@ -551,17 +552,17 @@ var stateUpdateTests = []stateUpdateTestCase{ testnetChainHash, ), createMsg: &wtwire.CreateSession{ - BlobType: blob.TypeDefault, + BlobType: blob.TypeAltruistCommit, MaxUpdates: 3, RewardBase: 0, RewardRate: 0, - SweepFeeRate: 1, + SweepFeeRate: 10000, }, updates: []*wtwire.StateUpdate{ - {SeqNum: 1, LastApplied: 0}, - {SeqNum: 2, LastApplied: 1}, - {SeqNum: 3, LastApplied: 2}, - {SeqNum: 4, LastApplied: 3}, + {SeqNum: 1, LastApplied: 0, EncryptedBlob: testBlob}, + {SeqNum: 2, LastApplied: 1, EncryptedBlob: testBlob}, + {SeqNum: 3, LastApplied: 2, EncryptedBlob: testBlob}, + {SeqNum: 4, LastApplied: 3, EncryptedBlob: testBlob}, }, replies: []*wtwire.StateUpdateReply{ {Code: wtwire.CodeOK, LastApplied: 1}, @@ -581,14 +582,14 @@ var stateUpdateTests = []stateUpdateTestCase{ testnetChainHash, ), createMsg: &wtwire.CreateSession{ - BlobType: blob.TypeDefault, + BlobType: blob.TypeAltruistCommit, MaxUpdates: 3, RewardBase: 0, RewardRate: 0, - SweepFeeRate: 1, + SweepFeeRate: 10000, }, updates: []*wtwire.StateUpdate{ - {SeqNum: 0, LastApplied: 0}, + {SeqNum: 0, LastApplied: 0, EncryptedBlob: testBlob}, }, replies: []*wtwire.StateUpdateReply{ { @@ -718,11 +719,11 @@ func TestServerDeleteSession(t *testing.T) { ) createSession := &wtwire.CreateSession{ - BlobType: blob.TypeDefault, + BlobType: blob.TypeAltruistCommit, MaxUpdates: 1000, RewardBase: 0, RewardRate: 0, - SweepFeeRate: 1, + SweepFeeRate: 10000, } const timeoutDuration = 100 * time.Millisecond diff --git a/watchtower/wtwire/features.go b/watchtower/wtwire/features.go index e407c96e..7ba298e0 100644 --- a/watchtower/wtwire/features.go +++ b/watchtower/wtwire/features.go @@ -5,18 +5,18 @@ import "github.com/lightningnetwork/lnd/lnwire" // FeatureNames holds a mapping from each feature bit understood by this // implementation to its common name. var FeatureNames = map[lnwire.FeatureBit]string{ - WtSessionsRequired: "wt-sessions", - WtSessionsOptional: "wt-sessions", + AltruistSessionsRequired: "altruist-sessions", + AltruistSessionsOptional: "altruist-sessions", } const ( - // WtSessionsRequired specifies that the advertising node requires the - // remote party to understand the protocol for creating and updating + // AltruistSessionsRequired specifies that the advertising node requires + // the remote party to understand the protocol for creating and updating // watchtower sessions. - WtSessionsRequired lnwire.FeatureBit = 8 + AltruistSessionsRequired lnwire.FeatureBit = 0 - // WtSessionsOptional specifies that the advertising node can support - // a remote party who understand the protocol for creating and updating - // watchtower sessions. - WtSessionsOptional lnwire.FeatureBit = 9 + // AltruistSessionsOptional specifies that the advertising node can + // support a remote party who understand the protocol for creating and + // updating watchtower sessions. + AltruistSessionsOptional lnwire.FeatureBit = 1 ) diff --git a/watchtower/wtwire/init_test.go b/watchtower/wtwire/init_test.go index 337c1de2..1aee5530 100644 --- a/watchtower/wtwire/init_test.go +++ b/watchtower/wtwire/init_test.go @@ -26,37 +26,37 @@ type checkRemoteInitTest struct { var checkRemoteInitTests = []checkRemoteInitTest{ { name: "same chain, local-optional remote-required", - lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional), + lFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsOptional), lHash: testnetChainHash, - rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired), + rFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired), rHash: testnetChainHash, }, { name: "same chain, local-required remote-optional", - lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired), + lFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired), lHash: testnetChainHash, - rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional), + rFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsOptional), rHash: testnetChainHash, }, { name: "different chain, local-optional remote-required", - lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional), + lFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsOptional), lHash: testnetChainHash, - rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired), + rFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired), rHash: mainnetChainHash, expErr: wtwire.NewErrUnknownChainHash(mainnetChainHash), }, { name: "different chain, local-required remote-optional", - lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional), + lFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsOptional), lHash: testnetChainHash, - rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired), + rFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired), rHash: mainnetChainHash, expErr: wtwire.NewErrUnknownChainHash(mainnetChainHash), }, { name: "same chain, remote-unknown-required", - lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional), + lFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsOptional), lHash: testnetChainHash, rFeatures: lnwire.NewRawFeatureVector(lnwire.GossipQueriesRequired), rHash: testnetChainHash, diff --git a/watchtower/wtwire/message.go b/watchtower/wtwire/message.go index 192a397c..2b3d2288 100644 --- a/watchtower/wtwire/message.go +++ b/watchtower/wtwire/message.go @@ -24,29 +24,29 @@ type MessageType uint16 // Watchtower protocol. const ( // MsgInit identifies an encoded Init message. - MsgInit MessageType = 300 + MsgInit MessageType = 600 // MsgError identifies an encoded Error message. - MsgError = 301 + MsgError MessageType = 601 // MsgCreateSession identifies an encoded CreateSession message. - MsgCreateSession MessageType = 302 + MsgCreateSession MessageType = 602 // MsgCreateSessionReply identifies an encoded CreateSessionReply message. - MsgCreateSessionReply MessageType = 303 + MsgCreateSessionReply MessageType = 603 // MsgStateUpdate identifies an encoded StateUpdate message. - MsgStateUpdate MessageType = 304 + MsgStateUpdate MessageType = 604 // MsgStateUpdateReply identifies an encoded StateUpdateReply message. - MsgStateUpdateReply MessageType = 305 + MsgStateUpdateReply MessageType = 605 // MsgDeleteSession identifies an encoded DeleteSession message. - MsgDeleteSession MessageType = 306 + MsgDeleteSession MessageType = 606 // MsgDeleteSessionReply identifies an encoded DeleteSessionReply // message. - MsgDeleteSessionReply MessageType = 307 + MsgDeleteSessionReply MessageType = 607 ) // String returns a human readable description of the message type.