diff --git a/channeldb/db.go b/channeldb/db.go index fa844a2d..4d6958a3 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -8,10 +8,10 @@ import ( "path/filepath" "sync" - "github.com/coreos/bbolt" - "github.com/go-errors/errors" "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/wire" + "github.com/coreos/bbolt" + "github.com/go-errors/errors" ) const ( @@ -233,56 +233,70 @@ func fileExists(path string) bool { return true } -// FetchOpenChannels returns all stored currently active/open channels -// associated with the target nodeID. In the case that no active channels are -// known to have been created with this node, then a zero-length slice is -// returned. +// FetchOpenChannels starts a new database transaction and returns all stored +// currently active/open channels associated with the target nodeID. In the case +// that no active channels are known to have been created with this node, then a +// zero-length slice is returned. func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error) { var channels []*OpenChannel err := d.View(func(tx *bolt.Tx) error { - // Get the bucket dedicated to storing the metadata for open - // channels. - openChanBucket := tx.Bucket(openChannelBucket) - if openChanBucket == nil { + var err error + channels, err = d.fetchOpenChannels(tx, nodeID) + return err + }) + + return channels, err +} + +// fetchOpenChannels uses and existing database transaction and returns all +// stored currently active/open channels associated with the target nodeID. In +// the case that no active channels are known to have been created with this +// node, then a zero-length slice is returned. +func (d *DB) fetchOpenChannels(tx *bolt.Tx, + nodeID *btcec.PublicKey) ([]*OpenChannel, error) { + + // Get the bucket dedicated to storing the metadata for open channels. + openChanBucket := tx.Bucket(openChannelBucket) + if openChanBucket == nil { + return nil, nil + } + + // Within this top level bucket, fetch the bucket dedicated to storing + // open channel data specific to the remote node. + pub := nodeID.SerializeCompressed() + nodeChanBucket := openChanBucket.Bucket(pub) + if nodeChanBucket == nil { + return nil, nil + } + + // Next, we'll need to go down an additional layer in order to retrieve + // the channels for each chain the node knows of. + var channels []*OpenChannel + err := nodeChanBucket.ForEach(func(chainHash, v []byte) error { + // If there's a value, it's not a bucket so ignore it. + if v != nil { return nil } - // Within this top level bucket, fetch the bucket dedicated to - // storing open channel data specific to the remote node. - pub := nodeID.SerializeCompressed() - nodeChanBucket := openChanBucket.Bucket(pub) - if nodeChanBucket == nil { - return nil + // If we've found a valid chainhash bucket, then we'll retrieve + // that so we can extract all the channels. + chainBucket := nodeChanBucket.Bucket(chainHash) + if chainBucket == nil { + return fmt.Errorf("unable to read bucket for chain=%x", + chainHash[:]) } - // Next, we'll need to go down an additional layer in order to - // retrieve the channels for each chain the node knows of. - return nodeChanBucket.ForEach(func(chainHash, v []byte) error { - // If there's a value, it's not a bucket so ignore it. - if v != nil { - return nil - } + // Finally, we both of the necessary buckets retrieved, fetch + // all the active channels related to this node. + nodeChannels, err := d.fetchNodeChannels(chainBucket) + if err != nil { + return fmt.Errorf("unable to read channel for "+ + "chain_hash=%x, node_key=%x: %v", + chainHash[:], pub, err) + } - // If we've found a valid chainhash bucket, then we'll - // retrieve that so we can extract all the channels. - chainBucket := nodeChanBucket.Bucket(chainHash) - if chainBucket == nil { - return fmt.Errorf("unable to read bucket for "+ - "chain=%x", chainHash[:]) - } - - // Finally, we both of the necessary buckets retrieved, - // fetch all the active channels related to this node. - nodeChannels, err := d.fetchNodeChannels(chainBucket) - if err != nil { - return fmt.Errorf("unable to read channel for "+ - "chain_hash=%x, node_key=%x: %v", - chainHash[:], pub, err) - } - - channels = nodeChannels - return nil - }) + channels = nodeChannels + return nil }) return channels, err @@ -583,7 +597,56 @@ func (d *DB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error { return err } - return closedChanBucket.Put(chanID, newSummary.Bytes()) + err = closedChanBucket.Put(chanID, newSummary.Bytes()) + if err != nil { + return err + } + + // Now that the channel is closed, we'll check if we have any + // other open channels with this peer. If we don't we'll + // garbage collect it to ensure we don't establish persistent + // connections to peers without open channels. + return d.pruneLinkNode(tx, chanSummary.RemotePub) + }) +} + +// pruneLinkNode determines whether we should garbage collect a link node from +// the database due to no longer having any open channels with it. If there are +// any left, then this acts as a no-op. +func (db *DB) pruneLinkNode(tx *bolt.Tx, remotePub *btcec.PublicKey) error { + openChannels, err := db.fetchOpenChannels(tx, remotePub) + if err != nil { + return fmt.Errorf("unable to fetch open channels for peer %x: "+ + "%v", remotePub.SerializeCompressed(), err) + } + + if len(openChannels) > 0 { + return nil + } + + log.Infof("Pruning link node %x with zero open channels from database", + remotePub.SerializeCompressed()) + + return db.deleteLinkNode(tx, remotePub) +} + +// PruneLinkNodes attempts to prune all link nodes found within the databse with +// whom we no longer have any open channels with. +func (db *DB) PruneLinkNodes() error { + return db.Update(func(tx *bolt.Tx) error { + linkNodes, err := db.fetchAllLinkNodes(tx) + if err != nil { + return err + } + + for _, linkNode := range linkNodes { + err := db.pruneLinkNode(tx, linkNode.IdentityPub) + if err != nil { + return err + } + } + + return nil }) } diff --git a/channeldb/graph.go b/channeldb/graph.go index 886ceee4..920193aa 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -259,27 +259,12 @@ func (c *ChannelGraph) ForEachNode(tx *bolt.Tx, cb func(*bolt.Tx, *LightningNode func (c *ChannelGraph) SourceNode() (*LightningNode, error) { var source *LightningNode err := c.db.View(func(tx *bolt.Tx) error { - // First grab the nodes bucket which stores the mapping from - // pubKey to node information. - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - - selfPub := nodes.Get(sourceKey) - if selfPub == nil { - return ErrSourceNodeNotSet - } - - // With the pubKey of the source node retrieved, we're able to - // fetch the full node information. - node, err := fetchLightningNode(nodes, selfPub) + node, err := c.sourceNode(tx) if err != nil { return err } + source = node - source = &node - source.db = c.db return nil }) if err != nil { @@ -289,6 +274,34 @@ func (c *ChannelGraph) SourceNode() (*LightningNode, error) { return source, nil } +// sourceNode uses an existing database transaction and returns the source node +// of the graph. The source node is treated as the center node within a +// star-graph. This method may be used to kick off a path finding algorithm in +// order to explore the reachability of another node based off the source node. +func (c *ChannelGraph) sourceNode(tx *bolt.Tx) (*LightningNode, error) { + // First grab the nodes bucket which stores the mapping from + // pubKey to node information. + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return nil, ErrGraphNotFound + } + + selfPub := nodes.Get(sourceKey) + if selfPub == nil { + return nil, ErrSourceNodeNotSet + } + + // With the pubKey of the source node retrieved, we're able to + // fetch the full node information. + node, err := fetchLightningNode(nodes, selfPub) + if err != nil { + return nil, err + } + node.db = c.db + + return &node, nil +} + // SetSourceNode sets the source node within the graph database. The source // node is to be used as the center of a star-graph within path finding // algorithms. @@ -384,30 +397,36 @@ func (c *ChannelGraph) LookupAlias(pub *btcec.PublicKey) (string, error) { return alias, nil } -// DeleteLightningNode removes a vertex/node from the database according to the -// node's public key. +// DeleteLightningNode starts a new database transaction to remove a vertex/node +// from the database according to the node's public key. func (c *ChannelGraph) DeleteLightningNode(nodePub *btcec.PublicKey) error { - pub := nodePub.SerializeCompressed() - // TODO(roasbeef): ensure dangling edges are removed... return c.db.Update(func(tx *bolt.Tx) error { - nodes, err := tx.CreateBucketIfNotExists(nodeBucket) - if err != nil { - return err - } - - aliases, err := tx.CreateBucketIfNotExists(aliasIndexBucket) - if err != nil { - return err - } - - if err := aliases.Delete(pub); err != nil { - return err - } - return nodes.Delete(pub) + return c.deleteLightningNode(tx, nodePub.SerializeCompressed()) }) } +// deleteLightningNode uses an existing database transaction to remove a +// vertex/node from the database according to the node's public key. +func (c *ChannelGraph) deleteLightningNode(tx *bolt.Tx, + compressedPubKey []byte) error { + + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNodesNotFound + } + + aliases := nodes.Bucket(aliasIndexBucket) + if aliases == nil { + return ErrGraphNodesNotFound + } + + if err := aliases.Delete(compressedPubKey); err != nil { + return err + } + return nodes.Delete(compressedPubKey) +} + // AddChannelEdge adds a new (undirected, blank) edge to the graph database. An // undirected edge from the two target nodes are created. The information // stored denotes the static attributes of the channel, such as the channelID, @@ -569,6 +588,13 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, var chansClosed []*ChannelEdgeInfo + // nodesWithChansClosed is the set of nodes, each identified by their + // compressed public key, who had a channel closed within the latest + // block. We'll use this later on to determine whether we should prune + // them from the channel graph due to no longer having any other open + // channels. + nodesWithChansClosed := make(map[[33]byte]struct{}) + err := c.db.Update(func(tx *bolt.Tx) error { // First grab the edges bucket which houses the information // we'd like to delete @@ -617,7 +643,6 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, if err != nil { return err } - chansClosed = append(chansClosed, &edgeInfo) // Attempt to delete the channel, an ErrEdgeNotFound // will be returned if that outpoint isn't known to be @@ -629,6 +654,12 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, if err != nil && err != ErrEdgeNotFound { return err } + + // Include this channel in our list of closed channels + // and collect the node public keys at each end. + chansClosed = append(chansClosed, &edgeInfo) + nodesWithChansClosed[edgeInfo.NodeKey1Bytes] = struct{}{} + nodesWithChansClosed[edgeInfo.NodeKey2Bytes] = struct{}{} } metaBucket, err := tx.CreateBucketIfNotExists(graphMetaBucket) @@ -650,7 +681,15 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, var newTip [pruneTipBytes]byte copy(newTip[:], blockHash[:]) - return pruneBucket.Put(blockHeightBytes[:], newTip[:]) + err = pruneBucket.Put(blockHeightBytes[:], newTip[:]) + if err != nil { + return err + } + + // Now that the graph has been pruned, we'll also attempt to + // prune any nodes that have had a channel closed within the + // latest block. + return c.pruneGraphNodes(tx, nodes, nodesWithChansClosed) }) if err != nil { return nil, err @@ -659,6 +698,58 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, return chansClosed, nil } +// pruneGraphNodes attempts to remove any nodes from the graph who have had a +// channel closed within the current block. If the node still has existing +// channels in the graph, this will act as a no-op. +func (c *ChannelGraph) pruneGraphNodes(tx *bolt.Tx, nodes *bolt.Bucket, + nodePubKeys map[[33]byte]struct{}) error { + + log.Trace("Pruning nodes from graph with no open channels") + + // We'll retrieve the graph's source node to ensure we don't remove it + // even if it no longer has any open channels. + sourceNode, err := c.sourceNode(tx) + if err != nil { + return err + } + + // We'll now iterate over every node which had a channel closed and + // check whether they have any other open channels left within the + // graph. If they don't, they'll be pruned from the channel graph. + for nodePubKey := range nodePubKeys { + if bytes.Equal(nodePubKey[:], sourceNode.PubKeyBytes[:]) { + continue + } + + node, err := fetchLightningNode(nodes, nodePubKey[:]) + if err != nil { + continue + } + node.db = c.db + + numChansLeft := 0 + err = node.ForEachChannel(tx, func(*bolt.Tx, *ChannelEdgeInfo, + *ChannelEdgePolicy, *ChannelEdgePolicy) error { + + numChansLeft++ + return nil + }) + if err != nil { + continue + } + + if numChansLeft == 0 { + err := c.deleteLightningNode(tx, nodePubKey[:]) + if err != nil { + log.Tracef("Unable to prune node %x from the "+ + "graph: %v", nodePubKey, err) + } + } + } + + return nil +} + // DisconnectBlockAtHeight is used to indicate that the block specified // by the passed height has been disconnected from the main chain. This // will "rewind" the graph back to the height below, deleting channels diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index 4f7756fb..2533a66a 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -421,6 +421,13 @@ func TestDisconnectBlockAtHeight(t *testing.T) { } graph := db.ChannelGraph() + sourceNode, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create source node: %v", err) + } + if err := graph.SetSourceNode(sourceNode); err != nil { + t.Fatalf("unable to set source node: %v", err) + } // We'd like to test the insertion/deletion of edges, so we create two // vertexes to connect. @@ -964,7 +971,7 @@ func assertNumChans(t *testing.T, graph *ChannelGraph, n int) { return nil }); err != nil { _, _, line, _ := runtime.Caller(1) - t.Fatalf("line %v:unable to scan channels: %v", line, err) + t.Fatalf("line %v: unable to scan channels: %v", line, err) } if numChans != n { _, _, line, _ := runtime.Caller(1) @@ -973,6 +980,23 @@ func assertNumChans(t *testing.T, graph *ChannelGraph, n int) { } } +func assertNumNodes(t *testing.T, graph *ChannelGraph, n int) { + numNodes := 0 + err := graph.ForEachNode(nil, func(_ *bolt.Tx, _ *LightningNode) error { + numNodes++ + return nil + }) + if err != nil { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: unable to scan nodes: %v", line, err) + } + + if numNodes != n { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: expected %v nodes, got %v", line, n, numNodes) + } +} + func assertChanViewEqual(t *testing.T, a []wire.OutPoint, b []*wire.OutPoint) { if len(a) != len(b) { _, _, line, _ := runtime.Caller(1) @@ -1003,6 +1027,13 @@ func TestGraphPruning(t *testing.T) { } graph := db.ChannelGraph() + sourceNode, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create source node: %v", err) + } + if err := graph.SetSourceNode(sourceNode); err != nil { + t.Fatalf("unable to set source node: %v", err) + } // As initial set up for the test, we'll create a graph with 5 vertexes // and enough edges to create a fully connected graph. The graph will @@ -1137,9 +1168,11 @@ func TestGraphPruning(t *testing.T) { t.Fatalf("channels were pruned but shouldn't have been") } - // Once again, the prune tip should have been updated. + // Once again, the prune tip should have been updated. We should still + // see both channels and their participants, along with the source node. assertPruneTip(t, graph, &blockHash, blockHeight) assertNumChans(t, graph, 2) + assertNumNodes(t, graph, 4) // Finally, create a block that prunes the remainder of the channels // from the graph. @@ -1159,10 +1192,11 @@ func TestGraphPruning(t *testing.T) { "expected %v, got %v", 2, len(prunedChans)) } - // The prune tip should be updated, and no channels should be found - // within the current graph. + // The prune tip should be updated, no channels should be found, and + // only the source node should remain within the current graph. assertPruneTip(t, graph, &blockHash, blockHeight) assertNumChans(t, graph, 0) + assertNumNodes(t, graph, 1) // Finally, the channel view at this point in the graph should now be // completely empty. Those channels should also be missing from the @@ -1888,6 +1922,13 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { } graph := db.ChannelGraph() + sourceNode, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create source node: %v", err) + } + if err := graph.SetSourceNode(sourceNode); err != nil { + t.Fatalf("unable to set source node: %v", err) + } // We'll first populate our graph with two nodes. All channels created // below will be made between these two nodes. diff --git a/channeldb/nodes.go b/channeldb/nodes.go index 3b1e7fee..59c67d41 100644 --- a/channeldb/nodes.go +++ b/channeldb/nodes.go @@ -6,9 +6,9 @@ import ( "net" "time" - "github.com/coreos/bbolt" "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/wire" + "github.com/coreos/bbolt" ) var ( @@ -127,6 +127,24 @@ func putLinkNode(nodeMetaBucket *bolt.Bucket, l *LinkNode) error { return nodeMetaBucket.Put(nodePub, b.Bytes()) } +// DeleteLinkNode removes the link node with the given identity from the +// database. +func (d *DB) DeleteLinkNode(identity *btcec.PublicKey) error { + return d.Update(func(tx *bolt.Tx) error { + return d.deleteLinkNode(tx, identity) + }) +} + +func (d *DB) deleteLinkNode(tx *bolt.Tx, identity *btcec.PublicKey) error { + nodeMetaBucket := tx.Bucket(nodeInfoBucket) + if nodeMetaBucket == nil { + return ErrLinkNodesNotFound + } + + pubKey := identity.SerializeCompressed() + return nodeMetaBucket.Delete(pubKey) +} + // FetchLinkNode attempts to lookup the data for a LinkNode based on a target // identity public key. If a particular LinkNode for the passed identity public // key cannot be found, then ErrNodeNotFound if returned. @@ -165,32 +183,48 @@ func (db *DB) FetchLinkNode(identity *btcec.PublicKey) (*LinkNode, error) { return node, nil } -// FetchAllLinkNodes attempts to fetch all active LinkNodes from the database. -// If there haven't been any channels explicitly linked to LinkNodes written to -// the database, then this function will return an empty slice. +// FetchAllLinkNodes starts a new database transaction to fetch all nodes with +// whom we have active channels with. func (db *DB) FetchAllLinkNodes() ([]*LinkNode, error) { var linkNodes []*LinkNode - err := db.View(func(tx *bolt.Tx) error { - nodeMetaBucket := tx.Bucket(nodeInfoBucket) - if nodeMetaBucket == nil { - return ErrLinkNodesNotFound + nodes, err := db.fetchAllLinkNodes(tx) + if err != nil { + return err } - return nodeMetaBucket.ForEach(func(k, v []byte) error { - if v == nil { - return nil - } + linkNodes = nodes + return nil + }) + if err != nil { + return nil, err + } - nodeReader := bytes.NewReader(v) - linkNode, err := deserializeLinkNode(nodeReader) - if err != nil { - return err - } + return linkNodes, nil +} - linkNodes = append(linkNodes, linkNode) +// fetchAllLinkNodes uses an existing database transaction to fetch all nodes +// with whom we have active channels with. +func (db *DB) fetchAllLinkNodes(tx *bolt.Tx) ([]*LinkNode, error) { + nodeMetaBucket := tx.Bucket(nodeInfoBucket) + if nodeMetaBucket == nil { + return nil, ErrLinkNodesNotFound + } + + var linkNodes []*LinkNode + err := nodeMetaBucket.ForEach(func(k, v []byte) error { + if v == nil { return nil - }) + } + + nodeReader := bytes.NewReader(v) + linkNode, err := deserializeLinkNode(nodeReader) + if err != nil { + return err + } + + linkNodes = append(linkNodes, linkNode) + return nil }) if err != nil { return nil, err diff --git a/channeldb/nodes_test.go b/channeldb/nodes_test.go index 7f968efe..755177aa 100644 --- a/channeldb/nodes_test.go +++ b/channeldb/nodes_test.go @@ -106,3 +106,35 @@ func TestLinkNodeEncodeDecode(t *testing.T) { addr2.String(), node1DB.Addresses[1].String()) } } + +func TestDeleteLinkNode(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + _, pubKey := btcec.PrivKeyFromBytes(btcec.S256(), key[:]) + addr := &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 1337, + } + linkNode := cdb.NewLinkNode(wire.TestNet3, pubKey, addr) + if err := linkNode.Sync(); err != nil { + t.Fatalf("unable to write link node to db: %v", err) + } + + if _, err := cdb.FetchLinkNode(pubKey); err != nil { + t.Fatalf("unable to find link node: %v", err) + } + + if err := cdb.DeleteLinkNode(pubKey); err != nil { + t.Fatalf("unable to delete link node from db: %v", err) + } + + if _, err := cdb.FetchLinkNode(pubKey); err == nil { + t.Fatal("should not have found link node in db, but did") + } +} diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index 0d8b6b5d..fb5f1c97 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -342,23 +342,6 @@ func (c *ChainArbitrator) Start() error { pCache: c.cfg.PreimageDB, signer: c.cfg.Signer, isOurAddr: c.cfg.IsOurAddress, - notifyChanClosed: func() error { - c.Lock() - delete(c.activeChannels, chanPoint) - - chainWatcher, ok := c.activeWatchers[chanPoint] - if ok { - // Since the chainWatcher is - // calling notifyChanClosed, we - // must stop it in a goroutine - // to not deadlock. - go chainWatcher.Stop() - } - delete(c.activeWatchers, chanPoint) - c.Unlock() - - return nil - }, contractBreach: func(retInfo *lnwallet.BreachRetribution) error { return c.cfg.ContractBreach(chanPoint, retInfo) }, @@ -697,22 +680,6 @@ func (c *ChainArbitrator) WatchNewChannel(newChan *channeldb.OpenChannel) error pCache: c.cfg.PreimageDB, signer: c.cfg.Signer, isOurAddr: c.cfg.IsOurAddress, - notifyChanClosed: func() error { - c.Lock() - delete(c.activeChannels, chanPoint) - - chainWatcher, ok := c.activeWatchers[chanPoint] - if ok { - // Since the chainWatcher is calling - // notifyChanClosed, we must stop it in - // a goroutine to not deadlock. - go chainWatcher.Stop() - } - delete(c.activeWatchers, chanPoint) - c.Unlock() - - return nil - }, contractBreach: func(retInfo *lnwallet.BreachRetribution) error { return c.cfg.ContractBreach(chanPoint, retInfo) }, diff --git a/contractcourt/chain_watcher.go b/contractcourt/chain_watcher.go index 8333b52f..1ab959cb 100644 --- a/contractcourt/chain_watcher.go +++ b/contractcourt/chain_watcher.go @@ -78,13 +78,6 @@ type chainWatcherConfig struct { // machine. signer lnwallet.Signer - // notifyChanClosed is a method that will be called by the watcher when - // it has detected a close on-chain and performed all necessary - // actions, like marking the channel closed in the database and - // notified all its subcribers. It lets the chain arbitrator know that - // the chain watcher chan be stopped. - notifyChanClosed func() error - // contractBreach is a method that will be called by the watcher if it // detects that a contract breach transaction has been confirmed. Only // when this method returns with a non-nil error it will be safe to mark @@ -492,16 +485,7 @@ func (c *chainWatcher) dispatchCooperativeClose(commitSpend *chainntnfs.SpendDet } c.Unlock() - // Now notify the ChainArbitrator that the watcher's job is done, such - // that it can shut it down and clean up. - if err := c.cfg.notifyChanClosed(); err != nil { - log.Errorf("unable to notify channel closed for "+ - "ChannelPoint(%v): %v", - c.cfg.chanState.FundingOutpoint, err) - } - return nil - } // dispatchLocalForceClose processes a unilateral close by us being confirmed. diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index b80995b9..3d96d51b 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -1364,10 +1364,16 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { ) // We've cooperatively closed the channel, so we're no longer - // needed. + // needed. We'll mark the channel as resolved and exit. case <-c.cfg.ChainEvents.CooperativeClosure: log.Infof("ChannelArbitrator(%v) closing due to co-op "+ "closure", c.cfg.ChanPoint) + + if err := c.cfg.MarkChannelResolved(); err != nil { + log.Errorf("Unable to mark contract "+ + "resolved: %v", err) + } + return // We have broadcasted our commitment, and it is now confirmed diff --git a/lnd_test.go b/lnd_test.go index 8f2a9beb..f3a34503 100644 --- a/lnd_test.go +++ b/lnd_test.go @@ -22,18 +22,18 @@ import ( "crypto/sha256" prand "math/rand" - "github.com/btcsuite/btclog" - "github.com/davecgh/go-spew/spew" - "github.com/go-errors/errors" - "github.com/lightningnetwork/lnd/lnrpc" - "github.com/lightningnetwork/lnd/lntest" - "github.com/lightningnetwork/lnd/lnwire" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/integration/rpctest" "github.com/btcsuite/btcd/rpcclient" "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btclog" "github.com/btcsuite/btcutil" + "github.com/davecgh/go-spew/spew" + "github.com/go-errors/errors" + "github.com/lightningnetwork/lnd/lnrpc" + "github.com/lightningnetwork/lnd/lntest" + "github.com/lightningnetwork/lnd/lnwire" "golang.org/x/net/context" "google.golang.org/grpc" ) @@ -2809,7 +2809,7 @@ func updateChannelPolicy(t *harnessTest, node *lntest.HarnessNode, // Wait for listener node to receive the channel update from node. ctxt, _ = context.WithTimeout(ctxb, timeout) - listenerUpdates, aQuit := subscribeGraphNotifications(t, ctxt, + listenerUpdates, aQuit := subscribeGraphNotifications(t, ctxt, listenerNode) defer close(aQuit) @@ -2980,8 +2980,8 @@ func testMultiHopPayments(net *lntest.NetworkHarness, t *harnessTest) { time.Sleep(time.Millisecond * 50) - // Set the fee policies of the Alice -> Bob and the Dave -> Alice - // channel edges to relatively large non default values. This makes it + // Set the fee policies of the Alice -> Bob and the Dave -> Alice + // channel edges to relatively large non default values. This makes it // possible to pick up more subtle fee calculation errors. updateChannelPolicy(t, net.Alice, chanPointAlice, 1000, 100000, 144, carol) @@ -3017,10 +3017,10 @@ func testMultiHopPayments(net *lntest.NetworkHarness, t *harnessTest) { assertAmountPaid(t, ctxb, "Alice(local) => Bob(remote)", net.Alice, aliceFundPoint, expectedAmountPaidAtoB, int64(0)) - // To forward a payment of 1000 sat, Alice is charging a fee of + // To forward a payment of 1000 sat, Alice is charging a fee of // 1 sat + 10% = 101 sat. const expectedFeeAlice = 5 * 101 - + // Dave needs to pay what Alice pays plus Alice's fee. expectedAmountPaidDtoA := expectedAmountPaidAtoB + expectedFeeAlice @@ -3029,7 +3029,7 @@ func testMultiHopPayments(net *lntest.NetworkHarness, t *harnessTest) { assertAmountPaid(t, ctxb, "Dave(local) => Alice(remote)", dave, daveFundPoint, expectedAmountPaidDtoA, int64(0)) - // To forward a payment of 1101 sat, Dave is charging a fee of + // To forward a payment of 1101 sat, Dave is charging a fee of // 5 sat + 15% = 170.15 sat. This is rounded down in rpcserver to 170. const expectedFeeDave = 5 * 170 @@ -4900,6 +4900,191 @@ func testFailingChannel(net *lntest.NetworkHarness, t *harnessTest) { } } +// testGarbageCollectLinkNodes tests that we properly garbase collect link nodes +// from the database and the set of persistent connections within the server. +func testGarbageCollectLinkNodes(net *lntest.NetworkHarness, t *harnessTest) { + const ( + timeout = time.Second * 10 + chanAmt = 1000000 + ) + + // Open a channel between Alice and Bob which will later be + // cooperatively closed. + ctxb := context.Background() + ctxt, _ := context.WithTimeout(ctxb, timeout) + coopChanPoint := openChannelAndAssert( + ctxt, t, net, net.Alice, net.Bob, chanAmt, 0, false, + ) + + // Create Carol's node and connect Alice to her. + carol, err := net.NewNode("Carol", nil) + if err != nil { + t.Fatalf("unable to create carol's node: %v", err) + } + defer shutdownAndAssert(net, t, carol) + ctxt, _ = context.WithTimeout(ctxb, timeout) + if err := net.ConnectNodes(ctxt, net.Alice, carol); err != nil { + t.Fatalf("unable to connect alice and carol: %v", err) + } + + // Open a channel between Alice and Carol which will later be force + // closed. + ctxt, _ = context.WithTimeout(ctxb, timeout) + forceCloseChanPoint := openChannelAndAssert( + ctxt, t, net, net.Alice, carol, chanAmt, 0, false, + ) + + // Now, create Dave's a node and also open a channel between Alice and + // him. This link will serve as the only persistent link throughout + // restarts in this test. + dave, err := net.NewNode("Dave", nil) + if err != nil { + t.Fatalf("unable to create dave's node: %v", err) + } + defer shutdownAndAssert(net, t, dave) + if err := net.ConnectNodes(ctxt, net.Alice, dave); err != nil { + t.Fatalf("unable to connect alice to dave: %v", err) + } + ctxt, _ = context.WithTimeout(ctxb, timeout) + persistentChanPoint := openChannelAndAssert( + ctxt, t, net, net.Alice, dave, chanAmt, 0, false, + ) + + // isConnected is a helper closure that checks if a peer is connected to + // Alice. + isConnected := func(pubKey string) bool { + req := &lnrpc.ListPeersRequest{} + resp, err := net.Alice.ListPeers(ctxb, req) + if err != nil { + t.Fatalf("unable to retrieve alice's peers: %v", err) + } + + for _, peer := range resp.Peers { + if peer.PubKey == pubKey { + return true + } + } + + return false + } + + // Restart both Bob and Carol to ensure Alice is able to reconnect to + // them. + if err := net.RestartNode(net.Bob, nil); err != nil { + t.Fatalf("unable to restart bob's node: %v", err) + } + if err := net.RestartNode(carol, nil); err != nil { + t.Fatalf("unable to restart carol's node: %v", err) + } + + err = lntest.WaitPredicate(func() bool { + return isConnected(net.Bob.PubKeyStr) + }, 15*time.Second) + if err != nil { + t.Fatalf("alice did not reconnect to bob") + } + err = lntest.WaitPredicate(func() bool { + return isConnected(carol.PubKeyStr) + }, 15*time.Second) + if err != nil { + t.Fatalf("alice did not reconnect to carol") + } + + // We'll also restart Alice to ensure she can reconnect to her peers + // with open channels. + if err := net.RestartNode(net.Alice, nil); err != nil { + t.Fatalf("unable to restart alice's node: %v", err) + } + + err = lntest.WaitPredicate(func() bool { + return isConnected(net.Bob.PubKeyStr) + }, 15*time.Second) + if err != nil { + t.Fatalf("alice did not reconnect to bob") + } + err = lntest.WaitPredicate(func() bool { + return isConnected(carol.PubKeyStr) + }, 15*time.Second) + if err != nil { + t.Fatalf("alice did not reconnect to carol") + } + + // testReconnection is a helper closure that restarts the nodes at both + // ends of a channel to ensure they do not reconnect after restarting. + // When restarting Alice, we'll first need to ensure she has + // reestablished her connection with Dave, as they still have an open + // channel together. + testReconnection := func(node *lntest.HarnessNode) { + if err := net.RestartNode(node, nil); err != nil { + t.Fatalf("unable to restart %v's node: %v", node.Name(), + err) + } + err = lntest.WaitPredicate(func() bool { + return !isConnected(node.PubKeyStr) + }, 20*time.Second) + if err != nil { + t.Fatalf("alice reconnected to %v", node.Name()) + } + + if err := net.RestartNode(net.Alice, nil); err != nil { + t.Fatalf("unable to restart alice's node: %v", err) + } + err = lntest.WaitPredicate(func() bool { + if !isConnected(dave.PubKeyStr) { + return false + } + return !isConnected(node.PubKeyStr) + }, 20*time.Second) + if err != nil { + t.Fatalf("alice reconnected to %v", node.Name()) + } + } + + // Now, we'll close the channel between Alice and Bob and ensure there + // is no reconnection logic between the both once the channel is fully + // closed. + ctxt, _ = context.WithTimeout(ctxb, timeout) + closeChannelAndAssert(ctxt, t, net, net.Alice, coopChanPoint, false) + + testReconnection(net.Bob) + + // We'll do the same with Alice and Carol, but this time we'll force + // close the channel instead. + ctxt, _ = context.WithTimeout(ctxb, timeout) + closeChannelAndAssert(ctxt, t, net, net.Alice, forceCloseChanPoint, true) + + // We'll need to mine some blocks in order to mark the channel fully + // closed. + _, err = net.Miner.Node.Generate(defaultBitcoinTimeLockDelta) + if err != nil { + t.Fatalf("unable to generate blocks: %v", err) + } + + testReconnection(carol) + + // Finally, we'll ensure that Bob and Carol no longer show in Alice's + // channel graph. + describeGraphReq := &lnrpc.ChannelGraphRequest{} + channelGraph, err := net.Alice.DescribeGraph(ctxb, describeGraphReq) + if err != nil { + t.Fatalf("unable to query for alice's channel graph: %v", err) + } + for _, node := range channelGraph.Nodes { + if node.PubKey == net.Bob.PubKeyStr { + t.Fatalf("did not expect to find bob in the channel " + + "graph, but did") + } + if node.PubKey == carol.PubKeyStr { + t.Fatalf("did not expect to find carol in the channel " + + "graph, but did") + } + } + + // Now that the test is done, we can also close the persistent link. + ctxt, _ = context.WithTimeout(ctxb, timeout) + closeChannelAndAssert(ctxt, t, net, net.Alice, persistentChanPoint, false) +} + // testRevokedCloseRetribution tests that Alice is able carry out // retribution in the event that she fails immediately after detecting Bob's // breach txn in the mempool. @@ -10324,6 +10509,10 @@ var testsCases = []*testCase{ name: "failing link", test: testFailingChannel, }, + { + name: "garbage collect link nodes", + test: testGarbageCollectLinkNodes, + }, { name: "revoked uncooperative close retribution zero value remote output", test: testRevokedCloseRetributionZeroValueRemoteOutput, @@ -10422,6 +10611,14 @@ func TestLightningNetworkDaemon(t *testing.T) { for _, testCase := range testsCases { logLine := fmt.Sprintf("STARTING ============ %v ============\n", testCase.name) + + err := lndHarness.EnsureConnected( + context.Background(), lndHarness.Alice, lndHarness.Bob, + ) + if err != nil { + t.Fatalf("unable to connect alice to bob: %v", err) + } + if err := lndHarness.Alice.AddToLog(logLine); err != nil { t.Fatalf("unable to add to log: %v", err) } diff --git a/peer.go b/peer.go index e1113c24..3b92cc8c 100644 --- a/peer.go +++ b/peer.go @@ -1697,9 +1697,6 @@ func (p *peer) fetchActiveChanCloser(chanID lnwire.ChannelID) (*channelCloser, e // handleLocalCloseReq kicks-off the workflow to execute a cooperative or // forced unilateral closure of the channel initiated by a local subsystem. -// -// TODO(roasbeef): if no more active channels with peer call Remove on connMgr -// with peerID func (p *peer) handleLocalCloseReq(req *htlcswitch.ChanClose) { chanID := lnwire.NewChanIDFromOutPoint(req.ChanPoint) @@ -1852,6 +1849,18 @@ func (p *peer) finalizeChanClosure(chanCloser *channelCloser) { }, } } + + // Remove the persistent connection to this peer if we + // no longer have open channels with them. + p.activeChanMtx.Lock() + numActiveChans := len(p.activeChannels) + p.activeChanMtx.Unlock() + + if numActiveChans == 0 { + p.server.prunePersistentPeerConnection( + p.pubKeyBytes, + ) + } }) } @@ -1905,6 +1914,9 @@ func (p *peer) WipeChannel(chanPoint *wire.OutPoint) error { if channel, ok := p.activeChannels[chanID]; ok { channel.Stop() delete(p.activeChannels, chanID) + if len(p.activeChannels) == 0 { + p.server.prunePersistentPeerConnection(p.pubKeyBytes) + } } p.activeChanMtx.Unlock() diff --git a/server.go b/server.go index cd979aa5..805d3788 100644 --- a/server.go +++ b/server.go @@ -759,7 +759,12 @@ func (s *server) Start() error { // With all the relevant sub-systems started, we'll now attempt to // establish persistent connections to our direct channel collaborators - // within the network. + // within the network. Before doing so however, we'll prune our set of + // link nodes found within the database to ensure we don't reconnect to + // any nodes we no longer have open channels with. + if err := s.chanDB.PruneLinkNodes(); err != nil { + return err + } if err := s.establishPersistentConnections(); err != nil { return err } @@ -1476,6 +1481,22 @@ func (s *server) establishPersistentConnections() error { return nil } +// prunePersistentPeerConnection removes all internal state related to +// persistent connections to a peer within the server. This is used to avoid +// persistent connection retries to peers we do not have any open channels with. +func (s *server) prunePersistentPeerConnection(compressedPubKey [33]byte) { + srvrLog.Infof("Pruning peer %x from persistent connections, number of "+ + "open channels is now zero", compressedPubKey) + + pubKeyStr := string(compressedPubKey[:]) + + s.mu.Lock() + delete(s.persistentPeers, pubKeyStr) + delete(s.persistentPeersBackoff, pubKeyStr) + s.cancelConnReqs(pubKeyStr, nil) + s.mu.Unlock() +} + // BroadcastMessage sends a request to the server to broadcast a set of // messages to all peers other than the one specified by the `skips` parameter. //