Merge pull request #1371 from wpaulino/prune-link-nodes

server: prune link nodes without any open channels
This commit is contained in:
Olaoluwa Osuntokun 2018-07-21 18:55:38 -07:00 committed by GitHub
commit bca926d6af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 618 additions and 170 deletions

@ -8,10 +8,10 @@ import (
"path/filepath" "path/filepath"
"sync" "sync"
"github.com/coreos/bbolt"
"github.com/go-errors/errors"
"github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/coreos/bbolt"
"github.com/go-errors/errors"
) )
const ( const (
@ -233,46 +233,61 @@ func fileExists(path string) bool {
return true return true
} }
// FetchOpenChannels returns all stored currently active/open channels // FetchOpenChannels starts a new database transaction and returns all stored
// associated with the target nodeID. In the case that no active channels are // currently active/open channels associated with the target nodeID. In the case
// known to have been created with this node, then a zero-length slice is // that no active channels are known to have been created with this node, then a
// returned. // zero-length slice is returned.
func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error) { func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error) {
var channels []*OpenChannel var channels []*OpenChannel
err := d.View(func(tx *bolt.Tx) error { err := d.View(func(tx *bolt.Tx) error {
// Get the bucket dedicated to storing the metadata for open var err error
// channels. channels, err = d.fetchOpenChannels(tx, nodeID)
return err
})
return channels, err
}
// fetchOpenChannels uses an existing database transaction and returns all
// stored currently active/open channels associated with the target nodeID. In
// the case that no active channels are known to have been created with this
// node, then a zero-length slice is returned.
func (d *DB) fetchOpenChannels(tx *bolt.Tx,
nodeID *btcec.PublicKey) ([]*OpenChannel, error) {
// Get the bucket dedicated to storing the metadata for open channels.
openChanBucket := tx.Bucket(openChannelBucket) openChanBucket := tx.Bucket(openChannelBucket)
if openChanBucket == nil { if openChanBucket == nil {
return nil return nil, nil
} }
// Within this top level bucket, fetch the bucket dedicated to // Within this top level bucket, fetch the bucket dedicated to storing
// storing open channel data specific to the remote node. // open channel data specific to the remote node.
pub := nodeID.SerializeCompressed() pub := nodeID.SerializeCompressed()
nodeChanBucket := openChanBucket.Bucket(pub) nodeChanBucket := openChanBucket.Bucket(pub)
if nodeChanBucket == nil { if nodeChanBucket == nil {
return nil return nil, nil
} }
// Next, we'll need to go down an additional layer in order to // Next, we'll need to go down an additional layer in order to retrieve
// retrieve the channels for each chain the node knows of. // the channels for each chain the node knows of.
return nodeChanBucket.ForEach(func(chainHash, v []byte) error { var channels []*OpenChannel
err := nodeChanBucket.ForEach(func(chainHash, v []byte) error {
// If there's a value, it's not a bucket so ignore it. // If there's a value, it's not a bucket so ignore it.
if v != nil { if v != nil {
return nil return nil
} }
// If we've found a valid chainhash bucket, then we'll // If we've found a valid chainhash bucket, then we'll retrieve
// retrieve that so we can extract all the channels. // that so we can extract all the channels.
chainBucket := nodeChanBucket.Bucket(chainHash) chainBucket := nodeChanBucket.Bucket(chainHash)
if chainBucket == nil { if chainBucket == nil {
return fmt.Errorf("unable to read bucket for "+ return fmt.Errorf("unable to read bucket for chain=%x",
"chain=%x", chainHash[:]) chainHash[:])
} }
// Finally, we both of the necessary buckets retrieved, // Finally, we both of the necessary buckets retrieved, fetch
// fetch all the active channels related to this node. // all the active channels related to this node.
nodeChannels, err := d.fetchNodeChannels(chainBucket) nodeChannels, err := d.fetchNodeChannels(chainBucket)
if err != nil { if err != nil {
return fmt.Errorf("unable to read channel for "+ return fmt.Errorf("unable to read channel for "+
@ -283,7 +298,6 @@ func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error)
channels = nodeChannels channels = nodeChannels
return nil return nil
}) })
})
return channels, err return channels, err
} }
@ -583,7 +597,56 @@ func (d *DB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error {
return err return err
} }
return closedChanBucket.Put(chanID, newSummary.Bytes()) err = closedChanBucket.Put(chanID, newSummary.Bytes())
if err != nil {
return err
}
// Now that the channel is closed, we'll check if we have any
// other open channels with this peer. If we don't we'll
// garbage collect it to ensure we don't establish persistent
// connections to peers without open channels.
return d.pruneLinkNode(tx, chanSummary.RemotePub)
})
}
// pruneLinkNode garbage collects the link node identified by remotePub when
// fetchOpenChannels reports no open channels for it within the supplied
// transaction. If at least one open channel remains, the link node is left
// untouched and nil is returned.
func (db *DB) pruneLinkNode(tx *bolt.Tx, remotePub *btcec.PublicKey) error {
	chans, err := db.fetchOpenChannels(tx, remotePub)
	switch {
	case err != nil:
		return fmt.Errorf("unable to fetch open channels for peer %x: "+
			"%v", remotePub.SerializeCompressed(), err)

	// We still have channels open with this peer, so its link node must
	// be kept around.
	case len(chans) != 0:
		return nil
	}

	log.Infof("Pruning link node %x with zero open channels from database",
		remotePub.SerializeCompressed())

	return db.deleteLinkNode(tx, remotePub)
}
// PruneLinkNodes attempts to prune all link nodes found within the database
// with whom we no longer have any open channels.
func (db *DB) PruneLinkNodes() error {
return db.Update(func(tx *bolt.Tx) error {
linkNodes, err := db.fetchAllLinkNodes(tx)
if err != nil {
return err
}
for _, linkNode := range linkNodes {
err := db.pruneLinkNode(tx, linkNode.IdentityPub)
if err != nil {
return err
}
}
return nil
}) })
} }

@ -259,27 +259,12 @@ func (c *ChannelGraph) ForEachNode(tx *bolt.Tx, cb func(*bolt.Tx, *LightningNode
func (c *ChannelGraph) SourceNode() (*LightningNode, error) { func (c *ChannelGraph) SourceNode() (*LightningNode, error) {
var source *LightningNode var source *LightningNode
err := c.db.View(func(tx *bolt.Tx) error { err := c.db.View(func(tx *bolt.Tx) error {
// First grab the nodes bucket which stores the mapping from node, err := c.sourceNode(tx)
// pubKey to node information.
nodes := tx.Bucket(nodeBucket)
if nodes == nil {
return ErrGraphNotFound
}
selfPub := nodes.Get(sourceKey)
if selfPub == nil {
return ErrSourceNodeNotSet
}
// With the pubKey of the source node retrieved, we're able to
// fetch the full node information.
node, err := fetchLightningNode(nodes, selfPub)
if err != nil { if err != nil {
return err return err
} }
source = node
source = &node
source.db = c.db
return nil return nil
}) })
if err != nil { if err != nil {
@ -289,6 +274,34 @@ func (c *ChannelGraph) SourceNode() (*LightningNode, error) {
return source, nil return source, nil
} }
// sourceNode looks up the source node of the graph using an existing database
// transaction. The source node is treated as the center node within a
// star-graph, and may be used to kick off a path finding algorithm in order
// to explore the reachability of another node. ErrGraphNotFound is returned
// if the nodes bucket is missing, and ErrSourceNodeNotSet if no source key
// has been stored in it.
func (c *ChannelGraph) sourceNode(tx *bolt.Tx) (*LightningNode, error) {
	// The nodes bucket maps each pubKey to its node information, and also
	// records under sourceKey which pubKey belongs to the source node.
	nodeBkt := tx.Bucket(nodeBucket)
	if nodeBkt == nil {
		return nil, ErrGraphNotFound
	}

	sourcePub := nodeBkt.Get(sourceKey)
	if sourcePub == nil {
		return nil, ErrSourceNodeNotSet
	}

	// Resolve the stored pubKey into the full node record and attach the
	// database handle before handing it back.
	source, err := fetchLightningNode(nodeBkt, sourcePub)
	if err != nil {
		return nil, err
	}
	source.db = c.db

	return &source, nil
}
// SetSourceNode sets the source node within the graph database. The source // SetSourceNode sets the source node within the graph database. The source
// node is to be used as the center of a star-graph within path finding // node is to be used as the center of a star-graph within path finding
// algorithms. // algorithms.
@ -384,30 +397,36 @@ func (c *ChannelGraph) LookupAlias(pub *btcec.PublicKey) (string, error) {
return alias, nil return alias, nil
} }
// DeleteLightningNode removes a vertex/node from the database according to the // DeleteLightningNode starts a new database transaction to remove a vertex/node
// node's public key. // from the database according to the node's public key.
func (c *ChannelGraph) DeleteLightningNode(nodePub *btcec.PublicKey) error { func (c *ChannelGraph) DeleteLightningNode(nodePub *btcec.PublicKey) error {
pub := nodePub.SerializeCompressed()
// TODO(roasbeef): ensure dangling edges are removed... // TODO(roasbeef): ensure dangling edges are removed...
return c.db.Update(func(tx *bolt.Tx) error { return c.db.Update(func(tx *bolt.Tx) error {
nodes, err := tx.CreateBucketIfNotExists(nodeBucket) return c.deleteLightningNode(tx, nodePub.SerializeCompressed())
if err != nil {
return err
}
aliases, err := tx.CreateBucketIfNotExists(aliasIndexBucket)
if err != nil {
return err
}
if err := aliases.Delete(pub); err != nil {
return err
}
return nodes.Delete(pub)
}) })
} }
// deleteLightningNode removes the vertex/node keyed by compressedPubKey, along
// with its entry in the alias index, using an existing database transaction.
// ErrGraphNodesNotFound is returned if either the nodes bucket or the alias
// index bucket nested within it does not exist.
func (c *ChannelGraph) deleteLightningNode(tx *bolt.Tx,
	compressedPubKey []byte) error {

	nodeBkt := tx.Bucket(nodeBucket)
	if nodeBkt == nil {
		return ErrGraphNodesNotFound
	}

	aliasBkt := nodeBkt.Bucket(aliasIndexBucket)
	if aliasBkt == nil {
		return ErrGraphNodesNotFound
	}

	// Drop the alias entry first, then the node record itself.
	err := aliasBkt.Delete(compressedPubKey)
	if err != nil {
		return err
	}

	return nodeBkt.Delete(compressedPubKey)
}
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An // AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
// undirected edge from the two target nodes are created. The information // undirected edge from the two target nodes are created. The information
// stored denotes the static attributes of the channel, such as the channelID, // stored denotes the static attributes of the channel, such as the channelID,
@ -569,6 +588,13 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint,
var chansClosed []*ChannelEdgeInfo var chansClosed []*ChannelEdgeInfo
// nodesWithChansClosed is the set of nodes, each identified by their
// compressed public key, who had a channel closed within the latest
// block. We'll use this later on to determine whether we should prune
// them from the channel graph due to no longer having any other open
// channels.
nodesWithChansClosed := make(map[[33]byte]struct{})
err := c.db.Update(func(tx *bolt.Tx) error { err := c.db.Update(func(tx *bolt.Tx) error {
// First grab the edges bucket which houses the information // First grab the edges bucket which houses the information
// we'd like to delete // we'd like to delete
@ -617,7 +643,6 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint,
if err != nil { if err != nil {
return err return err
} }
chansClosed = append(chansClosed, &edgeInfo)
// Attempt to delete the channel, an ErrEdgeNotFound // Attempt to delete the channel, an ErrEdgeNotFound
// will be returned if that outpoint isn't known to be // will be returned if that outpoint isn't known to be
@ -629,6 +654,12 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint,
if err != nil && err != ErrEdgeNotFound { if err != nil && err != ErrEdgeNotFound {
return err return err
} }
// Include this channel in our list of closed channels
// and collect the node public keys at each end.
chansClosed = append(chansClosed, &edgeInfo)
nodesWithChansClosed[edgeInfo.NodeKey1Bytes] = struct{}{}
nodesWithChansClosed[edgeInfo.NodeKey2Bytes] = struct{}{}
} }
metaBucket, err := tx.CreateBucketIfNotExists(graphMetaBucket) metaBucket, err := tx.CreateBucketIfNotExists(graphMetaBucket)
@ -650,7 +681,15 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint,
var newTip [pruneTipBytes]byte var newTip [pruneTipBytes]byte
copy(newTip[:], blockHash[:]) copy(newTip[:], blockHash[:])
return pruneBucket.Put(blockHeightBytes[:], newTip[:]) err = pruneBucket.Put(blockHeightBytes[:], newTip[:])
if err != nil {
return err
}
// Now that the graph has been pruned, we'll also attempt to
// prune any nodes that have had a channel closed within the
// latest block.
return c.pruneGraphNodes(tx, nodes, nodesWithChansClosed)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -659,6 +698,58 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint,
return chansClosed, nil return chansClosed, nil
} }
// pruneGraphNodes attempts to remove any nodes from the graph who have had a
// channel closed within the current block. If the node still has existing
// channels in the graph, this will act as a no-op.
//
// tx is the transaction all reads and deletes are performed in, nodes is the
// bucket node records are fetched from, and nodePubKeys is the set of
// candidate nodes keyed by compressed public key. The graph's source node is
// never pruned.
//
// Pruning is best-effort: a failure to fetch, scan, or delete an individual
// node is logged and that node is skipped. The only error returned is a
// failure to look up the source node.
func (c *ChannelGraph) pruneGraphNodes(tx *bolt.Tx, nodes *bolt.Bucket,
	nodePubKeys map[[33]byte]struct{}) error {

	log.Trace("Pruning nodes from graph with no open channels")

	// We'll retrieve the graph's source node to ensure we don't remove it
	// even if it no longer has any open channels.
	sourceNode, err := c.sourceNode(tx)
	if err != nil {
		return err
	}

	// We'll now iterate over every node which had a channel closed and
	// check whether they have any other open channels left within the
	// graph. If they don't, they'll be pruned from the channel graph.
	for nodePubKey := range nodePubKeys {
		if bytes.Equal(nodePubKey[:], sourceNode.PubKeyBytes[:]) {
			continue
		}

		node, err := fetchLightningNode(nodes, nodePubKey[:])
		if err != nil {
			// Skip the node, but leave a trace of why rather than
			// silently dropping the error.
			log.Tracef("Unable to fetch node %x while pruning the "+
				"graph: %v", nodePubKey, err)
			continue
		}
		node.db = c.db

		numChansLeft := 0
		err = node.ForEachChannel(tx, func(*bolt.Tx, *ChannelEdgeInfo,
			*ChannelEdgePolicy, *ChannelEdgePolicy) error {

			numChansLeft++
			return nil
		})
		if err != nil {
			// The channel count can't be trusted, so leave the
			// node in place.
			log.Tracef("Unable to scan channels of node %x while "+
				"pruning the graph: %v", nodePubKey, err)
			continue
		}

		if numChansLeft != 0 {
			continue
		}

		err = c.deleteLightningNode(tx, nodePubKey[:])
		if err != nil {
			log.Tracef("Unable to prune node %x from the "+
				"graph: %v", nodePubKey, err)
		}
	}

	return nil
}
// DisconnectBlockAtHeight is used to indicate that the block specified // DisconnectBlockAtHeight is used to indicate that the block specified
// by the passed height has been disconnected from the main chain. This // by the passed height has been disconnected from the main chain. This
// will "rewind" the graph back to the height below, deleting channels // will "rewind" the graph back to the height below, deleting channels

@ -421,6 +421,13 @@ func TestDisconnectBlockAtHeight(t *testing.T) {
} }
graph := db.ChannelGraph() graph := db.ChannelGraph()
sourceNode, err := createTestVertex(db)
if err != nil {
t.Fatalf("unable to create source node: %v", err)
}
if err := graph.SetSourceNode(sourceNode); err != nil {
t.Fatalf("unable to set source node: %v", err)
}
// We'd like to test the insertion/deletion of edges, so we create two // We'd like to test the insertion/deletion of edges, so we create two
// vertexes to connect. // vertexes to connect.
@ -964,7 +971,7 @@ func assertNumChans(t *testing.T, graph *ChannelGraph, n int) {
return nil return nil
}); err != nil { }); err != nil {
_, _, line, _ := runtime.Caller(1) _, _, line, _ := runtime.Caller(1)
t.Fatalf("line %v:unable to scan channels: %v", line, err) t.Fatalf("line %v: unable to scan channels: %v", line, err)
} }
if numChans != n { if numChans != n {
_, _, line, _ := runtime.Caller(1) _, _, line, _ := runtime.Caller(1)
@ -973,6 +980,23 @@ func assertNumChans(t *testing.T, graph *ChannelGraph, n int) {
} }
} }
// assertNumNodes walks every node in the graph via ForEachNode and fails the
// test unless exactly n nodes are visited. Failure messages carry the line
// number of the caller, obtained through runtime.Caller.
func assertNumNodes(t *testing.T, graph *ChannelGraph, n int) {
	var found int
	countNode := func(_ *bolt.Tx, _ *LightningNode) error {
		found++
		return nil
	}

	if err := graph.ForEachNode(nil, countNode); err != nil {
		_, _, line, _ := runtime.Caller(1)
		t.Fatalf("line %v: unable to scan nodes: %v", line, err)
	}

	if found != n {
		_, _, line, _ := runtime.Caller(1)
		t.Fatalf("line %v: expected %v nodes, got %v", line, n, found)
	}
}
func assertChanViewEqual(t *testing.T, a []wire.OutPoint, b []*wire.OutPoint) { func assertChanViewEqual(t *testing.T, a []wire.OutPoint, b []*wire.OutPoint) {
if len(a) != len(b) { if len(a) != len(b) {
_, _, line, _ := runtime.Caller(1) _, _, line, _ := runtime.Caller(1)
@ -1003,6 +1027,13 @@ func TestGraphPruning(t *testing.T) {
} }
graph := db.ChannelGraph() graph := db.ChannelGraph()
sourceNode, err := createTestVertex(db)
if err != nil {
t.Fatalf("unable to create source node: %v", err)
}
if err := graph.SetSourceNode(sourceNode); err != nil {
t.Fatalf("unable to set source node: %v", err)
}
// As initial set up for the test, we'll create a graph with 5 vertexes // As initial set up for the test, we'll create a graph with 5 vertexes
// and enough edges to create a fully connected graph. The graph will // and enough edges to create a fully connected graph. The graph will
@ -1137,9 +1168,11 @@ func TestGraphPruning(t *testing.T) {
t.Fatalf("channels were pruned but shouldn't have been") t.Fatalf("channels were pruned but shouldn't have been")
} }
// Once again, the prune tip should have been updated. // Once again, the prune tip should have been updated. We should still
// see both channels and their participants, along with the source node.
assertPruneTip(t, graph, &blockHash, blockHeight) assertPruneTip(t, graph, &blockHash, blockHeight)
assertNumChans(t, graph, 2) assertNumChans(t, graph, 2)
assertNumNodes(t, graph, 4)
// Finally, create a block that prunes the remainder of the channels // Finally, create a block that prunes the remainder of the channels
// from the graph. // from the graph.
@ -1159,10 +1192,11 @@ func TestGraphPruning(t *testing.T) {
"expected %v, got %v", 2, len(prunedChans)) "expected %v, got %v", 2, len(prunedChans))
} }
// The prune tip should be updated, and no channels should be found // The prune tip should be updated, no channels should be found, and
// within the current graph. // only the source node should remain within the current graph.
assertPruneTip(t, graph, &blockHash, blockHeight) assertPruneTip(t, graph, &blockHash, blockHeight)
assertNumChans(t, graph, 0) assertNumChans(t, graph, 0)
assertNumNodes(t, graph, 1)
// Finally, the channel view at this point in the graph should now be // Finally, the channel view at this point in the graph should now be
// completely empty. Those channels should also be missing from the // completely empty. Those channels should also be missing from the
@ -1888,6 +1922,13 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) {
} }
graph := db.ChannelGraph() graph := db.ChannelGraph()
sourceNode, err := createTestVertex(db)
if err != nil {
t.Fatalf("unable to create source node: %v", err)
}
if err := graph.SetSourceNode(sourceNode); err != nil {
t.Fatalf("unable to set source node: %v", err)
}
// We'll first populate our graph with two nodes. All channels created // We'll first populate our graph with two nodes. All channels created
// below will be made between these two nodes. // below will be made between these two nodes.

@ -6,9 +6,9 @@ import (
"net" "net"
"time" "time"
"github.com/coreos/bbolt"
"github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/coreos/bbolt"
) )
var ( var (
@ -127,6 +127,24 @@ func putLinkNode(nodeMetaBucket *bolt.Bucket, l *LinkNode) error {
return nodeMetaBucket.Put(nodePub, b.Bytes()) return nodeMetaBucket.Put(nodePub, b.Bytes())
} }
// DeleteLinkNode removes the link node with the given identity public key
// from the database by running deleteLinkNode inside a d.Update call.
func (d *DB) DeleteLinkNode(identity *btcec.PublicKey) error {
	removeNode := func(tx *bolt.Tx) error {
		return d.deleteLinkNode(tx, identity)
	}

	return d.Update(removeNode)
}
// deleteLinkNode uses an existing database transaction to delete the entry
// keyed by the compressed serialization of identity from the nodeInfoBucket
// bucket. ErrLinkNodesNotFound is returned if that bucket does not exist.
func (d *DB) deleteLinkNode(tx *bolt.Tx, identity *btcec.PublicKey) error {
	bucket := tx.Bucket(nodeInfoBucket)
	if bucket == nil {
		return ErrLinkNodesNotFound
	}

	return bucket.Delete(identity.SerializeCompressed())
}
// FetchLinkNode attempts to lookup the data for a LinkNode based on a target // FetchLinkNode attempts to lookup the data for a LinkNode based on a target
// identity public key. If a particular LinkNode for the passed identity public // identity public key. If a particular LinkNode for the passed identity public
// key cannot be found, then ErrNodeNotFound if returned. // key cannot be found, then ErrNodeNotFound if returned.
@ -165,19 +183,36 @@ func (db *DB) FetchLinkNode(identity *btcec.PublicKey) (*LinkNode, error) {
return node, nil return node, nil
} }
// FetchAllLinkNodes attempts to fetch all active LinkNodes from the database. // FetchAllLinkNodes starts a new database transaction to fetch all nodes with
// If there haven't been any channels explicitly linked to LinkNodes written to // whom we have active channels with.
// the database, then this function will return an empty slice.
func (db *DB) FetchAllLinkNodes() ([]*LinkNode, error) { func (db *DB) FetchAllLinkNodes() ([]*LinkNode, error) {
var linkNodes []*LinkNode var linkNodes []*LinkNode
err := db.View(func(tx *bolt.Tx) error { err := db.View(func(tx *bolt.Tx) error {
nodeMetaBucket := tx.Bucket(nodeInfoBucket) nodes, err := db.fetchAllLinkNodes(tx)
if nodeMetaBucket == nil { if err != nil {
return ErrLinkNodesNotFound return err
} }
return nodeMetaBucket.ForEach(func(k, v []byte) error { linkNodes = nodes
return nil
})
if err != nil {
return nil, err
}
return linkNodes, nil
}
// fetchAllLinkNodes uses an existing database transaction to fetch all nodes
// with whom we have active channels.
func (db *DB) fetchAllLinkNodes(tx *bolt.Tx) ([]*LinkNode, error) {
nodeMetaBucket := tx.Bucket(nodeInfoBucket)
if nodeMetaBucket == nil {
return nil, ErrLinkNodesNotFound
}
var linkNodes []*LinkNode
err := nodeMetaBucket.ForEach(func(k, v []byte) error {
if v == nil { if v == nil {
return nil return nil
} }
@ -191,7 +226,6 @@ func (db *DB) FetchAllLinkNodes() ([]*LinkNode, error) {
linkNodes = append(linkNodes, linkNode) linkNodes = append(linkNodes, linkNode)
return nil return nil
}) })
})
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -106,3 +106,35 @@ func TestLinkNodeEncodeDecode(t *testing.T) {
addr2.String(), node1DB.Addresses[1].String()) addr2.String(), node1DB.Addresses[1].String())
} }
} }
// TestDeleteLinkNode ensures that a link node synced to the database can be
// fetched, and that fetching it fails once it has been deleted.
func TestDeleteLinkNode(t *testing.T) {
	t.Parallel()

	cdb, cleanUp, err := makeTestDB()
	if err != nil {
		t.Fatalf("unable to make test database: %v", err)
	}
	defer cleanUp()

	// Write a link node for a known key to the database.
	_, pubKey := btcec.PrivKeyFromBytes(btcec.S256(), key[:])
	node := cdb.NewLinkNode(wire.TestNet3, pubKey, &net.TCPAddr{
		IP:   net.ParseIP("127.0.0.1"),
		Port: 1337,
	})
	err = node.Sync()
	if err != nil {
		t.Fatalf("unable to write link node to db: %v", err)
	}

	// The node should be retrievable before it is deleted.
	_, err = cdb.FetchLinkNode(pubKey)
	if err != nil {
		t.Fatalf("unable to find link node: %v", err)
	}

	err = cdb.DeleteLinkNode(pubKey)
	if err != nil {
		t.Fatalf("unable to delete link node from db: %v", err)
	}

	// After deletion, the lookup is expected to fail.
	_, err = cdb.FetchLinkNode(pubKey)
	if err == nil {
		t.Fatal("should not have found link node in db, but did")
	}
}

@ -342,23 +342,6 @@ func (c *ChainArbitrator) Start() error {
pCache: c.cfg.PreimageDB, pCache: c.cfg.PreimageDB,
signer: c.cfg.Signer, signer: c.cfg.Signer,
isOurAddr: c.cfg.IsOurAddress, isOurAddr: c.cfg.IsOurAddress,
notifyChanClosed: func() error {
c.Lock()
delete(c.activeChannels, chanPoint)
chainWatcher, ok := c.activeWatchers[chanPoint]
if ok {
// Since the chainWatcher is
// calling notifyChanClosed, we
// must stop it in a goroutine
// to not deadlock.
go chainWatcher.Stop()
}
delete(c.activeWatchers, chanPoint)
c.Unlock()
return nil
},
contractBreach: func(retInfo *lnwallet.BreachRetribution) error { contractBreach: func(retInfo *lnwallet.BreachRetribution) error {
return c.cfg.ContractBreach(chanPoint, retInfo) return c.cfg.ContractBreach(chanPoint, retInfo)
}, },
@ -697,22 +680,6 @@ func (c *ChainArbitrator) WatchNewChannel(newChan *channeldb.OpenChannel) error
pCache: c.cfg.PreimageDB, pCache: c.cfg.PreimageDB,
signer: c.cfg.Signer, signer: c.cfg.Signer,
isOurAddr: c.cfg.IsOurAddress, isOurAddr: c.cfg.IsOurAddress,
notifyChanClosed: func() error {
c.Lock()
delete(c.activeChannels, chanPoint)
chainWatcher, ok := c.activeWatchers[chanPoint]
if ok {
// Since the chainWatcher is calling
// notifyChanClosed, we must stop it in
// a goroutine to not deadlock.
go chainWatcher.Stop()
}
delete(c.activeWatchers, chanPoint)
c.Unlock()
return nil
},
contractBreach: func(retInfo *lnwallet.BreachRetribution) error { contractBreach: func(retInfo *lnwallet.BreachRetribution) error {
return c.cfg.ContractBreach(chanPoint, retInfo) return c.cfg.ContractBreach(chanPoint, retInfo)
}, },

@ -78,13 +78,6 @@ type chainWatcherConfig struct {
// machine. // machine.
signer lnwallet.Signer signer lnwallet.Signer
// notifyChanClosed is a method that will be called by the watcher when
// it has detected a close on-chain and performed all necessary
// actions, like marking the channel closed in the database and
// notified all its subcribers. It lets the chain arbitrator know that
// the chain watcher chan be stopped.
notifyChanClosed func() error
// contractBreach is a method that will be called by the watcher if it // contractBreach is a method that will be called by the watcher if it
// detects that a contract breach transaction has been confirmed. Only // detects that a contract breach transaction has been confirmed. Only
// when this method returns with a non-nil error it will be safe to mark // when this method returns with a non-nil error it will be safe to mark
@ -492,16 +485,7 @@ func (c *chainWatcher) dispatchCooperativeClose(commitSpend *chainntnfs.SpendDet
} }
c.Unlock() c.Unlock()
// Now notify the ChainArbitrator that the watcher's job is done, such
// that it can shut it down and clean up.
if err := c.cfg.notifyChanClosed(); err != nil {
log.Errorf("unable to notify channel closed for "+
"ChannelPoint(%v): %v",
c.cfg.chanState.FundingOutpoint, err)
}
return nil return nil
} }
// dispatchLocalForceClose processes a unilateral close by us being confirmed. // dispatchLocalForceClose processes a unilateral close by us being confirmed.

@ -1364,10 +1364,16 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) {
) )
// We've cooperatively closed the channel, so we're no longer // We've cooperatively closed the channel, so we're no longer
// needed. // needed. We'll mark the channel as resolved and exit.
case <-c.cfg.ChainEvents.CooperativeClosure: case <-c.cfg.ChainEvents.CooperativeClosure:
log.Infof("ChannelArbitrator(%v) closing due to co-op "+ log.Infof("ChannelArbitrator(%v) closing due to co-op "+
"closure", c.cfg.ChanPoint) "closure", c.cfg.ChanPoint)
if err := c.cfg.MarkChannelResolved(); err != nil {
log.Errorf("Unable to mark contract "+
"resolved: %v", err)
}
return return
// We have broadcasted our commitment, and it is now confirmed // We have broadcasted our commitment, and it is now confirmed

@ -22,18 +22,18 @@ import (
"crypto/sha256" "crypto/sha256"
prand "math/rand" prand "math/rand"
"github.com/btcsuite/btclog"
"github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lntest"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/integration/rpctest" "github.com/btcsuite/btcd/integration/rpctest"
"github.com/btcsuite/btcd/rpcclient" "github.com/btcsuite/btcd/rpcclient"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btclog"
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lntest"
"github.com/lightningnetwork/lnd/lnwire"
"golang.org/x/net/context" "golang.org/x/net/context"
"google.golang.org/grpc" "google.golang.org/grpc"
) )
@ -4900,6 +4900,191 @@ func testFailingChannel(net *lntest.NetworkHarness, t *harnessTest) {
} }
} }
// testGarbageCollectLinkNodes tests that we properly garbage collect link nodes
// from the database and the set of persistent connections within the server.
func testGarbageCollectLinkNodes(net *lntest.NetworkHarness, t *harnessTest) {
const (
timeout = time.Second * 10
chanAmt = 1000000
)
// Open a channel between Alice and Bob which will later be
// cooperatively closed.
ctxb := context.Background()
ctxt, _ := context.WithTimeout(ctxb, timeout)
coopChanPoint := openChannelAndAssert(
ctxt, t, net, net.Alice, net.Bob, chanAmt, 0, false,
)
// Create Carol's node and connect Alice to her.
carol, err := net.NewNode("Carol", nil)
if err != nil {
t.Fatalf("unable to create carol's node: %v", err)
}
defer shutdownAndAssert(net, t, carol)
ctxt, _ = context.WithTimeout(ctxb, timeout)
if err := net.ConnectNodes(ctxt, net.Alice, carol); err != nil {
t.Fatalf("unable to connect alice and carol: %v", err)
}
// Open a channel between Alice and Carol which will later be force
// closed.
ctxt, _ = context.WithTimeout(ctxb, timeout)
forceCloseChanPoint := openChannelAndAssert(
ctxt, t, net, net.Alice, carol, chanAmt, 0, false,
)
// Now, create Dave's node and also open a channel between Alice and
// him. This link will serve as the only persistent link throughout
// restarts in this test.
dave, err := net.NewNode("Dave", nil)
if err != nil {
t.Fatalf("unable to create dave's node: %v", err)
}
defer shutdownAndAssert(net, t, dave)
if err := net.ConnectNodes(ctxt, net.Alice, dave); err != nil {
t.Fatalf("unable to connect alice to dave: %v", err)
}
ctxt, _ = context.WithTimeout(ctxb, timeout)
persistentChanPoint := openChannelAndAssert(
ctxt, t, net, net.Alice, dave, chanAmt, 0, false,
)
// isConnected is a helper closure that checks if a peer is connected to
// Alice.
isConnected := func(pubKey string) bool {
req := &lnrpc.ListPeersRequest{}
resp, err := net.Alice.ListPeers(ctxb, req)
if err != nil {
t.Fatalf("unable to retrieve alice's peers: %v", err)
}
for _, peer := range resp.Peers {
if peer.PubKey == pubKey {
return true
}
}
return false
}
// Restart both Bob and Carol to ensure Alice is able to reconnect to
// them.
if err := net.RestartNode(net.Bob, nil); err != nil {
t.Fatalf("unable to restart bob's node: %v", err)
}
if err := net.RestartNode(carol, nil); err != nil {
t.Fatalf("unable to restart carol's node: %v", err)
}
err = lntest.WaitPredicate(func() bool {
return isConnected(net.Bob.PubKeyStr)
}, 15*time.Second)
if err != nil {
t.Fatalf("alice did not reconnect to bob")
}
err = lntest.WaitPredicate(func() bool {
return isConnected(carol.PubKeyStr)
}, 15*time.Second)
if err != nil {
t.Fatalf("alice did not reconnect to carol")
}
// We'll also restart Alice to ensure she can reconnect to her peers
// with open channels.
if err := net.RestartNode(net.Alice, nil); err != nil {
t.Fatalf("unable to restart alice's node: %v", err)
}
err = lntest.WaitPredicate(func() bool {
return isConnected(net.Bob.PubKeyStr)
}, 15*time.Second)
if err != nil {
t.Fatalf("alice did not reconnect to bob")
}
err = lntest.WaitPredicate(func() bool {
return isConnected(carol.PubKeyStr)
}, 15*time.Second)
if err != nil {
t.Fatalf("alice did not reconnect to carol")
}
// testReconnection is a helper closure that restarts the nodes at both
// ends of a channel to ensure they do not reconnect after restarting.
// When restarting Alice, we'll first need to ensure she has
// reestablished her connection with Dave, as they still have an open
// channel together.
testReconnection := func(node *lntest.HarnessNode) {
if err := net.RestartNode(node, nil); err != nil {
t.Fatalf("unable to restart %v's node: %v", node.Name(),
err)
}
err = lntest.WaitPredicate(func() bool {
return !isConnected(node.PubKeyStr)
}, 20*time.Second)
if err != nil {
t.Fatalf("alice reconnected to %v", node.Name())
}
if err := net.RestartNode(net.Alice, nil); err != nil {
t.Fatalf("unable to restart alice's node: %v", err)
}
err = lntest.WaitPredicate(func() bool {
if !isConnected(dave.PubKeyStr) {
return false
}
return !isConnected(node.PubKeyStr)
}, 20*time.Second)
if err != nil {
t.Fatalf("alice reconnected to %v", node.Name())
}
}
// Now, we'll close the channel between Alice and Bob and ensure there
// is no reconnection logic between the both once the channel is fully
// closed.
ctxt, _ = context.WithTimeout(ctxb, timeout)
closeChannelAndAssert(ctxt, t, net, net.Alice, coopChanPoint, false)
testReconnection(net.Bob)
// We'll do the same with Alice and Carol, but this time we'll force
// close the channel instead.
ctxt, _ = context.WithTimeout(ctxb, timeout)
closeChannelAndAssert(ctxt, t, net, net.Alice, forceCloseChanPoint, true)
// We'll need to mine some blocks in order to mark the channel fully
// closed.
_, err = net.Miner.Node.Generate(defaultBitcoinTimeLockDelta)
if err != nil {
t.Fatalf("unable to generate blocks: %v", err)
}
testReconnection(carol)
// Finally, we'll ensure that Bob and Carol no longer show in Alice's
// channel graph.
describeGraphReq := &lnrpc.ChannelGraphRequest{}
channelGraph, err := net.Alice.DescribeGraph(ctxb, describeGraphReq)
if err != nil {
t.Fatalf("unable to query for alice's channel graph: %v", err)
}
for _, node := range channelGraph.Nodes {
if node.PubKey == net.Bob.PubKeyStr {
t.Fatalf("did not expect to find bob in the channel " +
"graph, but did")
}
if node.PubKey == carol.PubKeyStr {
t.Fatalf("did not expect to find carol in the channel " +
"graph, but did")
}
}
// Now that the test is done, we can also close the persistent link.
ctxt, _ = context.WithTimeout(ctxb, timeout)
closeChannelAndAssert(ctxt, t, net, net.Alice, persistentChanPoint, false)
}
// testRevokedCloseRetribution tests that Alice is able to carry out
// retribution in the event that she fails immediately after detecting Bob's
// breach txn in the mempool.
@ -10324,6 +10509,10 @@ var testsCases = []*testCase{
name: "failing link", name: "failing link",
test: testFailingChannel, test: testFailingChannel,
}, },
{
name: "garbage collect link nodes",
test: testGarbageCollectLinkNodes,
},
{ {
name: "revoked uncooperative close retribution zero value remote output", name: "revoked uncooperative close retribution zero value remote output",
test: testRevokedCloseRetributionZeroValueRemoteOutput, test: testRevokedCloseRetributionZeroValueRemoteOutput,
@ -10422,6 +10611,14 @@ func TestLightningNetworkDaemon(t *testing.T) {
for _, testCase := range testsCases { for _, testCase := range testsCases {
logLine := fmt.Sprintf("STARTING ============ %v ============\n", logLine := fmt.Sprintf("STARTING ============ %v ============\n",
testCase.name) testCase.name)
err := lndHarness.EnsureConnected(
context.Background(), lndHarness.Alice, lndHarness.Bob,
)
if err != nil {
t.Fatalf("unable to connect alice to bob: %v", err)
}
if err := lndHarness.Alice.AddToLog(logLine); err != nil { if err := lndHarness.Alice.AddToLog(logLine); err != nil {
t.Fatalf("unable to add to log: %v", err) t.Fatalf("unable to add to log: %v", err)
} }

18
peer.go

@ -1697,9 +1697,6 @@ func (p *peer) fetchActiveChanCloser(chanID lnwire.ChannelID) (*channelCloser, e
// handleLocalCloseReq kicks-off the workflow to execute a cooperative or
// forced unilateral closure of the channel initiated by a local subsystem.
//
// TODO(roasbeef): if no more active channels with peer call Remove on connMgr
// with peerID
func (p *peer) handleLocalCloseReq(req *htlcswitch.ChanClose) { func (p *peer) handleLocalCloseReq(req *htlcswitch.ChanClose) {
chanID := lnwire.NewChanIDFromOutPoint(req.ChanPoint) chanID := lnwire.NewChanIDFromOutPoint(req.ChanPoint)
@ -1852,6 +1849,18 @@ func (p *peer) finalizeChanClosure(chanCloser *channelCloser) {
}, },
} }
} }
// Remove the persistent connection to this peer if we
// no longer have open channels with them.
p.activeChanMtx.Lock()
numActiveChans := len(p.activeChannels)
p.activeChanMtx.Unlock()
if numActiveChans == 0 {
p.server.prunePersistentPeerConnection(
p.pubKeyBytes,
)
}
}) })
} }
@ -1905,6 +1914,9 @@ func (p *peer) WipeChannel(chanPoint *wire.OutPoint) error {
if channel, ok := p.activeChannels[chanID]; ok { if channel, ok := p.activeChannels[chanID]; ok {
channel.Stop() channel.Stop()
delete(p.activeChannels, chanID) delete(p.activeChannels, chanID)
if len(p.activeChannels) == 0 {
p.server.prunePersistentPeerConnection(p.pubKeyBytes)
}
} }
p.activeChanMtx.Unlock() p.activeChanMtx.Unlock()

@ -759,7 +759,12 @@ func (s *server) Start() error {
// With all the relevant sub-systems started, we'll now attempt to
// establish persistent connections to our direct channel collaborators
// within the network. Before doing so however, we'll prune our set of
// link nodes found within the database to ensure we don't reconnect to
// any nodes we no longer have open channels with.
if err := s.chanDB.PruneLinkNodes(); err != nil {
return err
}
if err := s.establishPersistentConnections(); err != nil { if err := s.establishPersistentConnections(); err != nil {
return err return err
} }
@ -1476,6 +1481,22 @@ func (s *server) establishPersistentConnections() error {
return nil return nil
} }
// prunePersistentPeerConnection drops the server's persistent-connection
// bookkeeping for the peer identified by compressedPubKey. While holding the
// server mutex, the peer's entries in persistentPeers and
// persistentPeersBackoff are deleted and cancelConnReqs is invoked for it. This
// keeps the server from retrying connections to a peer we no longer have any
// open channels with.
func (s *server) prunePersistentPeerConnection(compressedPubKey [33]byte) {
	// The persistent peer maps are keyed by the raw bytes of the compressed
	// public key, held in a string.
	peerKey := string(compressedPubKey[:])

	srvrLog.Infof("Pruning peer %x from persistent connections, number of "+
		"open channels is now zero", compressedPubKey)

	s.mu.Lock()

	delete(s.persistentPeers, peerKey)
	delete(s.persistentPeersBackoff, peerKey)
	s.cancelConnReqs(peerKey, nil)

	s.mu.Unlock()
}
// BroadcastMessage sends a request to the server to broadcast a set of // BroadcastMessage sends a request to the server to broadcast a set of
// messages to all peers other than the one specified by the `skips` parameter. // messages to all peers other than the one specified by the `skips` parameter.
// //