diff --git a/routing/graph.go b/routing/graph.go index f3dfa121..14eca178 100644 --- a/routing/graph.go +++ b/routing/graph.go @@ -1,4 +1,104 @@ package routing -// TODO(roasbeef): abstract out graph to interface -// * add in-memory version of graph for tests +import ( + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" +) + +// routingGraph is an abstract interface that provides information about nodes +// and edges to pathfinding. +type routingGraph interface { + // forEachNodeChannel calls the callback for every channel of the given node. + forEachNodeChannel(nodePub route.Vertex, + cb func(*channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, + *channeldb.ChannelEdgePolicy) error) error + + // sourceNode returns the source node of the graph. + sourceNode() route.Vertex + + // fetchNodeFeatures returns the features of the given node. + fetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error) +} + +// dbRoutingTx is a routingGraph implementation that retrieves from the +// database. +type dbRoutingTx struct { + graph *channeldb.ChannelGraph + tx *bbolt.Tx + source route.Vertex +} + +// newDbRoutingTx instantiates a new db-connected routing graph. It implictly +// instantiates a new read transaction. +func newDbRoutingTx(graph *channeldb.ChannelGraph) (*dbRoutingTx, error) { + sourceNode, err := graph.SourceNode() + if err != nil { + return nil, err + } + + tx, err := graph.Database().Begin(false) + if err != nil { + return nil, err + } + + return &dbRoutingTx{ + graph: graph, + tx: tx, + source: sourceNode.PubKeyBytes, + }, nil +} + +// close closes the underlying db transaction. +func (g *dbRoutingTx) close() error { + return g.tx.Rollback() +} + +// forEachNodeChannel calls the callback for every channel of the given node. +// +// NOTE: Part of the routingGraph interface. +func (g *dbRoutingTx) forEachNodeChannel(nodePub route.Vertex, + cb func(*channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, + *channeldb.ChannelEdgePolicy) error) error { + + txCb := func(_ *bbolt.Tx, info *channeldb.ChannelEdgeInfo, + p1, p2 *channeldb.ChannelEdgePolicy) error { + + return cb(info, p1, p2) + } + + return g.graph.ForEachNodeChannel(g.tx, nodePub[:], txCb) +} + +// sourceNode returns the source node of the graph. +// +// NOTE: Part of the routingGraph interface. +func (g *dbRoutingTx) sourceNode() route.Vertex { + return g.source +} + +// fetchNodeFeatures returns the features of the given node. If the node is +// unknown, assume no additional features are supported. +// +// NOTE: Part of the routingGraph interface. +func (g *dbRoutingTx) fetchNodeFeatures(nodePub route.Vertex) ( + *lnwire.FeatureVector, error) { + + targetNode, err := g.graph.FetchLightningNode(g.tx, nodePub) + switch err { + + // If the node exists and has features, return them directly. + case nil: + return targetNode.Features, nil + + // If we couldn't find a node announcement, populate a blank feature + // vector. + case channeldb.ErrGraphNodeNotFound: + return lnwire.EmptyFeatureVector(), nil + + // Otherwise bubble the error up. + default: + return nil, err + } +} diff --git a/routing/pathfind.go b/routing/pathfind.go index 77686337..f2ea782d 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -7,7 +7,6 @@ import ( "math" "time" - "github.com/coreos/bbolt" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/feature" @@ -346,10 +345,11 @@ type PathFindingConfig struct { // getMaxOutgoingAmt returns the maximum available balance in any of the // channels of the given node. func getMaxOutgoingAmt(node route.Vertex, outgoingChan *uint64, - g *graphParams, tx *bbolt.Tx) (lnwire.MilliSatoshi, error) { + bandwidthHints map[uint64]lnwire.MilliSatoshi, + g routingGraph) (lnwire.MilliSatoshi, error) { var max lnwire.MilliSatoshi - cb := func(_ *bbolt.Tx, edgeInfo *channeldb.ChannelEdgeInfo, outEdge, + cb := func(edgeInfo *channeldb.ChannelEdgeInfo, outEdge, _ *channeldb.ChannelEdgePolicy) error { if outEdge == nil { @@ -363,7 +363,7 @@ func getMaxOutgoingAmt(node route.Vertex, outgoingChan *uint64, return nil } - bandwidth, ok := g.bandwidthHints[chanID] + bandwidth, ok := bandwidthHints[chanID] // If the bandwidth is not available for whatever reason, don't // fail the pathfinding early. @@ -380,28 +380,54 @@ func getMaxOutgoingAmt(node route.Vertex, outgoingChan *uint64, } // Iterate over all channels of the to node. - err := g.graph.ForEachNodeChannel(tx, node[:], cb) + err := g.forEachNodeChannel(node, cb) if err != nil { return 0, err } return max, err } -// findPath attempts to find a path from the source node within the -// ChannelGraph to the target node that's capable of supporting a payment of -// `amt` value. The current approach implemented is modified version of -// Dijkstra's algorithm to find a single shortest path between the source node -// and the destination. The distance metric used for edges is related to the -// time-lock+fee costs along a particular edge. If a path is found, this -// function returns a slice of ChannelHop structs which encoded the chosen path -// from the target to the source. The search is performed backwards from -// destination node back to source. This is to properly accumulate fees -// that need to be paid along the path and accurately check the amount -// to forward at every node against the available bandwidth. +// findPath attempts to find a path from the source node within the ChannelGraph +// to the target node that's capable of supporting a payment of `amt` value. The +// current approach implemented is modified version of Dijkstra's algorithm to +// find a single shortest path between the source node and the destination. The +// distance metric used for edges is related to the time-lock+fee costs along a +// particular edge. If a path is found, this function returns a slice of +// ChannelHop structs which encoded the chosen path from the target to the +// source. The search is performed backwards from destination node back to +// source. This is to properly accumulate fees that need to be paid along the +// path and accurately check the amount to forward at every node against the +// available bandwidth. func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, source, target route.Vertex, amt lnwire.MilliSatoshi, finalHtlcExpiry int32) ( []*channeldb.ChannelEdgePolicy, error) { + routingTx, err := newDbRoutingTx(g.graph) + if err != nil { + return nil, err + } + defer func() { + err := routingTx.close() + if err != nil { + log.Errorf("Error closing db tx: %v", err) + } + }() + + return findPathInternal( + g.additionalEdges, g.bandwidthHints, routingTx, r, cfg, source, + target, amt, finalHtlcExpiry, + ) +} + +// findPathInternal is the internal implementation of findPath. +func findPathInternal( + additionalEdges map[route.Vertex][]*channeldb.ChannelEdgePolicy, + bandwidthHints map[uint64]lnwire.MilliSatoshi, + graph routingGraph, + r *RestrictParams, cfg *PathFindingConfig, + source, target route.Vertex, amt lnwire.MilliSatoshi, finalHtlcExpiry int32) ( + []*channeldb.ChannelEdgePolicy, error) { + // Pathfinding can be a significant portion of the total payment // latency, especially on low-powered devices. Log several metrics to // aid in the analysis performance problems in this area. @@ -414,45 +440,20 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, "time=%v", nodesVisited, edgesExpanded, timeElapsed) }() - // Get source node outside of the pathfinding tx, to prevent a deadlock. - selfNode, err := g.graph.SourceNode() - if err != nil { - return nil, err - } - self := selfNode.PubKeyBytes - - // Get a db transaction to execute the graph queries in. - tx, err := g.graph.Database().Begin(false) - if err != nil { - return nil, err - } - defer tx.Rollback() - // If no destination features are provided, we will load what features // we have for the target node from our graph. features := r.DestFeatures if features == nil { - targetNode, err := g.graph.FetchLightningNode(tx, target) - switch { - - // If the node exists and has features, use them directly. - case err == nil: - features = targetNode.Features - - // If an error other than the node not existing is hit, abort. - case err != channeldb.ErrGraphNodeNotFound: + var err error + features, err = graph.fetchNodeFeatures(target) + if err != nil { return nil, err - - // Otherwise, we couldn't find a node announcement, populate a - // blank feature vector. - default: - features = lnwire.EmptyFeatureVector() } } // Ensure that the destination's features don't include unknown // required features. - err = feature.ValidateRequired(features) + err := feature.ValidateRequired(features) if err != nil { return nil, err } @@ -485,8 +486,12 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // If we are routing from ourselves, check that we have enough local // balance available. + self := graph.sourceNode() + if source == self { - max, err := getMaxOutgoingAmt(self, r.OutgoingChannelID, g, tx) + max, err := getMaxOutgoingAmt( + self, r.OutgoingChannelID, bandwidthHints, graph, + ) if err != nil { return nil, err } @@ -504,7 +509,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, distance := make(map[route.Vertex]*nodeWithDist, estimatedNodeCount) additionalEdgesWithSrc := make(map[route.Vertex][]*edgePolicyWithSource) - for vertex, outgoingEdgePolicies := range g.additionalEdges { + for vertex, outgoingEdgePolicies := range additionalEdges { // Build reverse lookup to find incoming edges. Needed because // search is taken place from target to source. for _, outgoingEdgePolicy := range outgoingEdgePolicies { @@ -746,45 +751,35 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // Check cache for features of the fromNode. fromFeatures, ok := featureCache[node] - if !ok { - targetNode, err := g.graph.FetchLightningNode(tx, node) - switch { - - // If the node exists and has valid features, use them. - case err == nil: - nodeFeatures := targetNode.Features - - // Don't route through nodes that contain - // unknown required features. - err = feature.ValidateRequired(nodeFeatures) - if err != nil { - break - } - - // Don't route through nodes that don't properly - // set all transitive feature dependencies. - err = feature.ValidateDeps(nodeFeatures) - if err != nil { - break - } - - fromFeatures = nodeFeatures - - // If an error other than the node not existing is hit, - // abort. - case err != channeldb.ErrGraphNodeNotFound: - return nil, err - - // Otherwise, we couldn't find a node announcement, - // populate a blank feature vector. - default: - fromFeatures = lnwire.EmptyFeatureVector() - } - - // Update cache. - featureCache[node] = fromFeatures + if ok { + return fromFeatures, nil } + // Fetch node features fresh from the graph. + fromFeatures, err := graph.fetchNodeFeatures(node) + if err != nil { + return nil, err + } + + // Don't route through nodes that contain unknown required + // features and mark as nil in the cache. + err = feature.ValidateRequired(fromFeatures) + if err != nil { + featureCache[node] = nil + return nil, nil + } + + // Don't route through nodes that don't properly set all + // transitive feature dependencies and mark as nil in the cache. + err = feature.ValidateDeps(fromFeatures) + if err != nil { + featureCache[node] = nil + return nil, nil + } + + // Update cache. + featureCache[node] = fromFeatures + return fromFeatures, nil } @@ -797,7 +792,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // Create unified policies for all incoming connections. u := newUnifiedPolicies(self, pivot, r.OutgoingChannelID) - err := u.addGraphPolicies(g.graph, tx) + err := u.addGraphPolicies(graph) if err != nil { return nil, err } @@ -828,7 +823,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, } policy := unifiedPolicy.getPolicy( - amtToSend, g.bandwidthHints, + amtToSend, bandwidthHints, ) if policy == nil { diff --git a/routing/router.go b/routing/router.go index a7c24ba4..e42ce7a0 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2311,6 +2311,13 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, return nil, err } + // Fetch the current block height outside the routing transaction, to + // prevent the rpc call blocking the database. + _, height, err := r.cfg.Chain.GetBestBlock() + if err != nil { + return nil, err + } + // Allocate a list that will contain the unified policies for this // route. edges := make([]*unifiedPolicy, len(hops)) @@ -2328,6 +2335,18 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, runningAmt = *amt } + // Open a transaction to execute the graph queries in. + routingTx, err := newDbRoutingTx(r.cfg.Graph) + if err != nil { + return nil, err + } + defer func() { + err := routingTx.close() + if err != nil { + log.Errorf("Error closing db tx: %v", err) + } + }() + // Traverse hops backwards to accumulate fees in the running amounts. source := r.selfNode.PubKeyBytes for i := len(hops) - 1; i >= 0; i-- { @@ -2346,7 +2365,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, // known in the graph. u := newUnifiedPolicies(source, toNode, outgoingChan) - err := u.addGraphPolicies(r.cfg.Graph, nil) + err := u.addGraphPolicies(routingTx) if err != nil { return nil, err } @@ -2414,11 +2433,6 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, } // Build and return the final route. - _, height, err := r.cfg.Chain.GetBestBlock() - if err != nil { - return nil, err - } - return newRoute( source, pathEdges, uint32(height), finalHopParams{ diff --git a/routing/unified_policies.go b/routing/unified_policies.go index 81e646c2..3759175a 100644 --- a/routing/unified_policies.go +++ b/routing/unified_policies.go @@ -2,7 +2,6 @@ package routing import ( "github.com/btcsuite/btcutil" - "github.com/coreos/bbolt" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -69,10 +68,8 @@ func (u *unifiedPolicies) addPolicy(fromNode route.Vertex, // addGraphPolicies adds all policies that are known for the toNode in the // graph. -func (u *unifiedPolicies) addGraphPolicies(g *channeldb.ChannelGraph, - tx *bbolt.Tx) error { - - cb := func(_ *bbolt.Tx, edgeInfo *channeldb.ChannelEdgeInfo, _, +func (u *unifiedPolicies) addGraphPolicies(g routingGraph) error { + cb := func(edgeInfo *channeldb.ChannelEdgeInfo, _, inEdge *channeldb.ChannelEdgePolicy) error { // If there is no edge policy for this candidate node, skip. @@ -95,7 +92,7 @@ func (u *unifiedPolicies) addGraphPolicies(g *channeldb.ChannelGraph, } // Iterate over all channels of the to node. - return g.ForEachNodeChannel(tx, u.toNode[:], cb) + return g.forEachNodeChannel(u.toNode, cb) } // unifiedPolicyEdge is the individual channel data that is kept inside an