diff --git a/autopilot/graph.go b/autopilot/graph.go index 5641bb21..766919ac 100644 --- a/autopilot/graph.go +++ b/autopilot/graph.go @@ -121,7 +121,7 @@ func (d dbNode) ForEachChannel(cb func(ChannelEdge) error) error { // // NOTE: Part of the autopilot.ChannelGraph interface. func (d *databaseChannelGraph) ForEachNode(cb func(Node) error) error { - return d.db.ForEachNode(nil, func(tx kvdb.ReadTx, n *channeldb.LightningNode) error { + return d.db.ForEachNode(func(tx kvdb.ReadTx, n *channeldb.LightningNode) error { // We'll skip over any node that doesn't have any advertised // addresses. As we won't be able to reach them to actually diff --git a/channeldb/graph.go b/channeldb/graph.go index 0bbd790e..3722bfbd 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -321,14 +321,9 @@ func (c *ChannelGraph) DisabledChannelIDs() ([]uint64, error) { // returns an error, then the transaction is aborted and the iteration stops // early. // -// If the caller wishes to re-use an existing boltdb transaction, then it -// should be passed as the first argument. Otherwise the first argument should -// be nil and a fresh transaction will be created to execute the graph -// traversal -// // TODO(roasbeef): add iterator interface to allow for memory efficient graph // traversal when graph gets mega -func (c *ChannelGraph) ForEachNode(tx kvdb.RwTx, cb func(kvdb.ReadTx, *LightningNode) error) error { // nolint:interfacer +func (c *ChannelGraph) ForEachNode(cb func(kvdb.ReadTx, *LightningNode) error) error { // nolint:interfacer traversal := func(tx kvdb.ReadTx) error { // First grab the nodes bucket which stores the mapping from // pubKey to node information. @@ -358,15 +353,7 @@ func (c *ChannelGraph) ForEachNode(tx kvdb.RwTx, cb func(kvdb.ReadTx, *Lightning }) } - // If no transaction was provided, then we'll create a new transaction - // to execute the transaction within. - if tx == nil { - return kvdb.View(c.db, traversal) - } - - // Otherwise, we re-use the existing transaction to execute the graph - // traversal. - return traversal(tx) + return kvdb.View(c.db, traversal) } // SourceNode returns the source node of the graph. The source node is treated diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index b1355baf..1fa98118 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -882,7 +882,7 @@ func TestGraphTraversal(t *testing.T) { // Iterate over each node as returned by the graph, if all nodes are // reached, then the map created above should be empty. - err = graph.ForEachNode(nil, func(_ kvdb.ReadTx, node *LightningNode) error { + err = graph.ForEachNode(func(_ kvdb.ReadTx, node *LightningNode) error { delete(nodeIndex, node.Alias) return nil }) @@ -1051,7 +1051,7 @@ func assertNumChans(t *testing.T, graph *ChannelGraph, n int) { func assertNumNodes(t *testing.T, graph *ChannelGraph, n int) { numNodes := 0 - err := graph.ForEachNode(nil, func(_ kvdb.ReadTx, _ *LightningNode) error { + err := graph.ForEachNode(func(_ kvdb.ReadTx, _ *LightningNode) error { numNodes++ return nil }) diff --git a/routing/router.go b/routing/router.go index d0ac67fd..e651b450 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2181,7 +2181,7 @@ func (r *ChannelRouter) FetchLightningNode(node route.Vertex) (*channeldb.Lightn // // NOTE: This method is part of the ChannelGraphSource interface. func (r *ChannelRouter) ForEachNode(cb func(*channeldb.LightningNode) error) error { - return r.cfg.Graph.ForEachNode(nil, func(_ kvdb.ReadTx, n *channeldb.LightningNode) error { + return r.cfg.Graph.ForEachNode(func(_ kvdb.ReadTx, n *channeldb.LightningNode) error { return cb(n) }) } diff --git a/rpcserver.go b/rpcserver.go index 189d9d38..df632299 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -4594,7 +4594,7 @@ func (r *rpcServer) DescribeGraph(ctx context.Context, // First iterate through all the known nodes (connected or unconnected // within the graph), collating their current state into the RPC // response. - err := graph.ForEachNode(nil, func(_ kvdb.ReadTx, node *channeldb.LightningNode) error { + err := graph.ForEachNode(func(_ kvdb.ReadTx, node *channeldb.LightningNode) error { nodeAddrs := make([]*lnrpc.NodeAddress, 0) for _, addr := range node.Addresses { nodeAddr := &lnrpc.NodeAddress{ @@ -4909,7 +4909,7 @@ func (r *rpcServer) GetNetworkInfo(ctx context.Context, // network, tallying up the total number of nodes, and also gathering // each node so we can measure the graph diameter and degree stats // below. - if err := graph.ForEachNode(nil, func(tx kvdb.ReadTx, node *channeldb.LightningNode) error { + if err := graph.ForEachNode(func(tx kvdb.ReadTx, node *channeldb.LightningNode) error { // Increment the total number of nodes with each iteration. numNodes++