From 95ddab57faa611ecde2d3ab6402d1854ae42b0e0 Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Fri, 20 Dec 2019 10:14:13 +0100 Subject: [PATCH] channeldb: add tx argument for FetchLightningNode To allow execution within an existing tx. --- autopilot/graph.go | 2 +- channeldb/graph.go | 21 +++++++++++++++++---- channeldb/graph_test.go | 15 ++++++++------- routing/pathfind.go | 2 +- routing/pathfind_test.go | 4 ++-- routing/router.go | 2 +- routing/router_test.go | 4 ++-- rpcserver.go | 2 +- server.go | 2 +- 9 files changed, 34 insertions(+), 20 deletions(-) diff --git a/autopilot/graph.go b/autopilot/graph.go index c63d6650..5792443f 100644 --- a/autopilot/graph.go +++ b/autopilot/graph.go @@ -153,7 +153,7 @@ func (d *databaseChannelGraph) addRandChannel(node1, node2 *btcec.PublicKey, return nil, err } - dbNode, err := d.db.FetchLightningNode(vertex) + dbNode, err := d.db.FetchLightningNode(nil, vertex) switch { case err == channeldb.ErrGraphNodeNotFound: fallthrough diff --git a/channeldb/graph.go b/channeldb/graph.go index 3ec06981..61ca7a5f 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -2179,11 +2179,17 @@ func (l *LightningNode) isPublic(tx *bbolt.Tx, sourcePubKey []byte) (bool, error // FetchLightningNode attempts to look up a target node by its identity public // key. If the node isn't found in the database, then ErrGraphNodeNotFound is // returned. -func (c *ChannelGraph) FetchLightningNode(nodePub route.Vertex) (*LightningNode, - error) { +// +// If the caller wishes to re-use an existing boltdb transaction, then it +// should be passed as the first argument. Otherwise the first argument should +// be nil and a fresh transaction will be created to execute the graph +// traversal. +func (c *ChannelGraph) FetchLightningNode(tx *bbolt.Tx, nodePub route.Vertex) ( + *LightningNode, error) { var node *LightningNode - err := c.db.View(func(tx *bbolt.Tx) error { + + fetchNode := func(tx *bbolt.Tx) error { // First grab the nodes bucket which stores the mapping from // pubKey to node information. nodes := tx.Bucket(nodeBucket) @@ -2210,7 +2216,14 @@ func (c *ChannelGraph) FetchLightningNode(nodePub route.Vertex) (*LightningNode, node = &n return nil - }) + } + + var err error + if tx == nil { + err = c.db.View(fetchNode) + } else { + err = fetchNode(tx) + } if err != nil { return nil, err } diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index 27e29871..d79f15dd 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -104,7 +104,7 @@ func TestNodeInsertionAndDeletion(t *testing.T) { // Next, fetch the node from the database to ensure everything was // serialized properly. - dbNode, err := graph.FetchLightningNode(testPub) + dbNode, err := graph.FetchLightningNode(nil, testPub) if err != nil { t.Fatalf("unable to locate node: %v", err) } @@ -128,7 +128,7 @@ func TestNodeInsertionAndDeletion(t *testing.T) { // Finally, attempt to fetch the node again. This should fail as the // node should have been deleted from the database. - _, err = graph.FetchLightningNode(testPub) + _, err = graph.FetchLightningNode(nil, testPub) if err != ErrGraphNodeNotFound { t.Fatalf("fetch after delete should fail!") } @@ -160,7 +160,7 @@ func TestPartialNode(t *testing.T) { // Next, fetch the node from the database to ensure everything was // serialized properly. - dbNode, err := graph.FetchLightningNode(testPub) + dbNode, err := graph.FetchLightningNode(nil, testPub) if err != nil { t.Fatalf("unable to locate node: %v", err) } @@ -192,7 +192,7 @@ func TestPartialNode(t *testing.T) { // Finally, attempt to fetch the node again. This should fail as the // node should have been deleted from the database. - _, err = graph.FetchLightningNode(testPub) + _, err = graph.FetchLightningNode(nil, testPub) if err != ErrGraphNodeNotFound { t.Fatalf("fetch after delete should fail!") } @@ -2387,7 +2387,8 @@ func TestPruneGraphNodes(t *testing.T) { // Finally, we'll ensure that node3, the only fully unconnected node as // properly deleted from the graph and not another node in its place. - if _, err := graph.FetchLightningNode(node3.PubKeyBytes); err == nil { + _, err = graph.FetchLightningNode(nil, node3.PubKeyBytes) + if err == nil { t.Fatalf("node 3 should have been deleted!") } } @@ -2429,7 +2430,7 @@ func TestAddChannelEdgeShellNodes(t *testing.T) { // Ensure that node1 was inserted as a full node, while node2 only has // a shell node present. - node1, err = graph.FetchLightningNode(node1.PubKeyBytes) + node1, err = graph.FetchLightningNode(nil, node1.PubKeyBytes) if err != nil { t.Fatalf("unable to fetch node1: %v", err) } @@ -2437,7 +2438,7 @@ func TestAddChannelEdgeShellNodes(t *testing.T) { t.Fatalf("have shell announcement for node1, shouldn't") } - node2, err = graph.FetchLightningNode(node2.PubKeyBytes) + node2, err = graph.FetchLightningNode(nil, node2.PubKeyBytes) if err != nil { t.Fatalf("unable to fetch node2: %v", err) } diff --git a/routing/pathfind.go b/routing/pathfind.go index 9eca590a..1049e5da 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -446,7 +446,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // we have for the target node from our graph. features := r.DestFeatures if features == nil { - targetNode, err := g.graph.FetchLightningNode(target) + targetNode, err := g.graph.FetchLightningNode(tx, target) switch { // If the node exists and has features, use them directly. diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 3a26428c..1bec640a 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -2070,7 +2070,7 @@ func TestPathFindSpecExample(t *testing.T) { // Carol, so we set "B" as the source node so path finding starts from // Bob. bob := ctx.aliases["B"] - bobNode, err := ctx.graph.FetchLightningNode(bob) + bobNode, err := ctx.graph.FetchLightningNode(nil, bob) if err != nil { t.Fatalf("unable to find bob: %v", err) } @@ -2119,7 +2119,7 @@ func TestPathFindSpecExample(t *testing.T) { // Next, we'll set A as the source node so we can assert that we create // the proper route for any queries starting with Alice. alice := ctx.aliases["A"] - aliceNode, err := ctx.graph.FetchLightningNode(alice) + aliceNode, err := ctx.graph.FetchLightningNode(nil, alice) if err != nil { t.Fatalf("unable to find alice: %v", err) } diff --git a/routing/router.go b/routing/router.go index 87a73937..085db61f 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2089,7 +2089,7 @@ func (r *ChannelRouter) GetChannelByID(chanID lnwire.ShortChannelID) ( // // NOTE: This method is part of the ChannelGraphSource interface. func (r *ChannelRouter) FetchLightningNode(node route.Vertex) (*channeldb.LightningNode, error) { - return r.cfg.Graph.FetchLightningNode(node) + return r.cfg.Graph.FetchLightningNode(nil, node) } // ForEachNode is used to iterate over every node in router topology. diff --git a/routing/router_test.go b/routing/router_test.go index 483abec4..26803492 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -1319,7 +1319,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { t.Fatalf("unable to find any routes: %v", err) } - copy1, err := ctx.graph.FetchLightningNode(pub1) + copy1, err := ctx.graph.FetchLightningNode(nil, pub1) if err != nil { t.Fatalf("unable to fetch node: %v", err) } @@ -1328,7 +1328,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { t.Fatalf("fetched node not equal to original") } - copy2, err := ctx.graph.FetchLightningNode(pub2) + copy2, err := ctx.graph.FetchLightningNode(nil, pub2) if err != nil { t.Fatalf("unable to fetch node: %v", err) } diff --git a/rpcserver.go b/rpcserver.go index 703ce25a..1ae6f2ad 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -4239,7 +4239,7 @@ func (r *rpcServer) GetNodeInfo(ctx context.Context, // With the public key decoded, attempt to fetch the node corresponding // to this public key. If the node cannot be found, then an error will // be returned. - node, err := graph.FetchLightningNode(pubKey) + node, err := graph.FetchLightningNode(nil, pubKey) if err != nil { return nil, err } diff --git a/server.go b/server.go index 58f8f17b..dd97b2be 100644 --- a/server.go +++ b/server.go @@ -3374,7 +3374,7 @@ func (s *server) fetchNodeAdvertisedAddr(pub *btcec.PublicKey) (net.Addr, error) return nil, err } - node, err := s.chanDB.ChannelGraph().FetchLightningNode(vertex) + node, err := s.chanDB.ChannelGraph().FetchLightningNode(nil, vertex) if err != nil { return nil, err }