routing: add graph interface

This commit is contained in:
Joost Jager 2020-01-27 12:33:53 +01:00
parent a8ed1b342a
commit 06bdeb56e2
No known key found for this signature in database
GPG Key ID: A61B9D4C393C59C7
4 changed files with 206 additions and 100 deletions

@ -1,4 +1,104 @@
package routing package routing
// TODO(roasbeef): abstract out graph to interface import (
// * add in-memory version of graph for tests "github.com/coreos/bbolt"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
)
// routingGraph is an abstract interface that provides information about nodes
// and edges to pathfinding.
type routingGraph interface {
// forEachNodeChannel calls the callback for every channel of the given node.
forEachNodeChannel(nodePub route.Vertex,
cb func(*channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy,
*channeldb.ChannelEdgePolicy) error) error
// sourceNode returns the source node of the graph.
sourceNode() route.Vertex
// fetchNodeFeatures returns the features of the given node.
fetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error)
}
// dbRoutingTx is a routingGraph implementation that retrieves from the
// database.
type dbRoutingTx struct {
graph *channeldb.ChannelGraph
tx *bbolt.Tx
source route.Vertex
}
// newDbRoutingTx instantiates a new db-connected routing graph. It implictly
// instantiates a new read transaction.
func newDbRoutingTx(graph *channeldb.ChannelGraph) (*dbRoutingTx, error) {
sourceNode, err := graph.SourceNode()
if err != nil {
return nil, err
}
tx, err := graph.Database().Begin(false)
if err != nil {
return nil, err
}
return &dbRoutingTx{
graph: graph,
tx: tx,
source: sourceNode.PubKeyBytes,
}, nil
}
// close closes the underlying db transaction.
func (g *dbRoutingTx) close() error {
return g.tx.Rollback()
}
// forEachNodeChannel calls the callback for every channel of the given node.
//
// NOTE: Part of the routingGraph interface.
func (g *dbRoutingTx) forEachNodeChannel(nodePub route.Vertex,
cb func(*channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy,
*channeldb.ChannelEdgePolicy) error) error {
txCb := func(_ *bbolt.Tx, info *channeldb.ChannelEdgeInfo,
p1, p2 *channeldb.ChannelEdgePolicy) error {
return cb(info, p1, p2)
}
return g.graph.ForEachNodeChannel(g.tx, nodePub[:], txCb)
}
// sourceNode returns the source node of the graph.
//
// NOTE: Part of the routingGraph interface.
func (g *dbRoutingTx) sourceNode() route.Vertex {
return g.source
}
// fetchNodeFeatures returns the features of the given node. If the node is
// unknown, assume no additional features are supported.
//
// NOTE: Part of the routingGraph interface.
func (g *dbRoutingTx) fetchNodeFeatures(nodePub route.Vertex) (
*lnwire.FeatureVector, error) {
targetNode, err := g.graph.FetchLightningNode(g.tx, nodePub)
switch err {
// If the node exists and has features, return them directly.
case nil:
return targetNode.Features, nil
// If we couldn't find a node announcement, populate a blank feature
// vector.
case channeldb.ErrGraphNodeNotFound:
return lnwire.EmptyFeatureVector(), nil
// Otherwise bubble the error up.
default:
return nil, err
}
}

@ -7,7 +7,6 @@ import (
"math" "math"
"time" "time"
"github.com/coreos/bbolt"
sphinx "github.com/lightningnetwork/lightning-onion" sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/feature" "github.com/lightningnetwork/lnd/feature"
@ -346,10 +345,11 @@ type PathFindingConfig struct {
// getMaxOutgoingAmt returns the maximum available balance in any of the // getMaxOutgoingAmt returns the maximum available balance in any of the
// channels of the given node. // channels of the given node.
func getMaxOutgoingAmt(node route.Vertex, outgoingChan *uint64, func getMaxOutgoingAmt(node route.Vertex, outgoingChan *uint64,
g *graphParams, tx *bbolt.Tx) (lnwire.MilliSatoshi, error) { bandwidthHints map[uint64]lnwire.MilliSatoshi,
g routingGraph) (lnwire.MilliSatoshi, error) {
var max lnwire.MilliSatoshi var max lnwire.MilliSatoshi
cb := func(_ *bbolt.Tx, edgeInfo *channeldb.ChannelEdgeInfo, outEdge, cb := func(edgeInfo *channeldb.ChannelEdgeInfo, outEdge,
_ *channeldb.ChannelEdgePolicy) error { _ *channeldb.ChannelEdgePolicy) error {
if outEdge == nil { if outEdge == nil {
@ -363,7 +363,7 @@ func getMaxOutgoingAmt(node route.Vertex, outgoingChan *uint64,
return nil return nil
} }
bandwidth, ok := g.bandwidthHints[chanID] bandwidth, ok := bandwidthHints[chanID]
// If the bandwidth is not available for whatever reason, don't // If the bandwidth is not available for whatever reason, don't
// fail the pathfinding early. // fail the pathfinding early.
@ -380,28 +380,54 @@ func getMaxOutgoingAmt(node route.Vertex, outgoingChan *uint64,
} }
// Iterate over all channels of the to node. // Iterate over all channels of the to node.
err := g.graph.ForEachNodeChannel(tx, node[:], cb) err := g.forEachNodeChannel(node, cb)
if err != nil { if err != nil {
return 0, err return 0, err
} }
return max, err return max, err
} }
// findPath attempts to find a path from the source node within the // findPath attempts to find a path from the source node within the ChannelGraph
// ChannelGraph to the target node that's capable of supporting a payment of // to the target node that's capable of supporting a payment of `amt` value. The
// `amt` value. The current approach implemented is modified version of // current approach implemented is modified version of Dijkstra's algorithm to
// Dijkstra's algorithm to find a single shortest path between the source node // find a single shortest path between the source node and the destination. The
// and the destination. The distance metric used for edges is related to the // distance metric used for edges is related to the time-lock+fee costs along a
// time-lock+fee costs along a particular edge. If a path is found, this // particular edge. If a path is found, this function returns a slice of
// function returns a slice of ChannelHop structs which encoded the chosen path // ChannelHop structs which encoded the chosen path from the target to the
// from the target to the source. The search is performed backwards from // source. The search is performed backwards from destination node back to
// destination node back to source. This is to properly accumulate fees // source. This is to properly accumulate fees that need to be paid along the
// that need to be paid along the path and accurately check the amount // path and accurately check the amount to forward at every node against the
// to forward at every node against the available bandwidth. // available bandwidth.
func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
source, target route.Vertex, amt lnwire.MilliSatoshi, finalHtlcExpiry int32) ( source, target route.Vertex, amt lnwire.MilliSatoshi, finalHtlcExpiry int32) (
[]*channeldb.ChannelEdgePolicy, error) { []*channeldb.ChannelEdgePolicy, error) {
routingTx, err := newDbRoutingTx(g.graph)
if err != nil {
return nil, err
}
defer func() {
err := routingTx.close()
if err != nil {
log.Errorf("Error closing db tx: %v", err)
}
}()
return findPathInternal(
g.additionalEdges, g.bandwidthHints, routingTx, r, cfg, source,
target, amt, finalHtlcExpiry,
)
}
// findPathInternal is the internal implementation of findPath.
func findPathInternal(
additionalEdges map[route.Vertex][]*channeldb.ChannelEdgePolicy,
bandwidthHints map[uint64]lnwire.MilliSatoshi,
graph routingGraph,
r *RestrictParams, cfg *PathFindingConfig,
source, target route.Vertex, amt lnwire.MilliSatoshi, finalHtlcExpiry int32) (
[]*channeldb.ChannelEdgePolicy, error) {
// Pathfinding can be a significant portion of the total payment // Pathfinding can be a significant portion of the total payment
// latency, especially on low-powered devices. Log several metrics to // latency, especially on low-powered devices. Log several metrics to
// aid in the analysis performance problems in this area. // aid in the analysis performance problems in this area.
@ -414,45 +440,20 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
"time=%v", nodesVisited, edgesExpanded, timeElapsed) "time=%v", nodesVisited, edgesExpanded, timeElapsed)
}() }()
// Get source node outside of the pathfinding tx, to prevent a deadlock.
selfNode, err := g.graph.SourceNode()
if err != nil {
return nil, err
}
self := selfNode.PubKeyBytes
// Get a db transaction to execute the graph queries in.
tx, err := g.graph.Database().Begin(false)
if err != nil {
return nil, err
}
defer tx.Rollback()
// If no destination features are provided, we will load what features // If no destination features are provided, we will load what features
// we have for the target node from our graph. // we have for the target node from our graph.
features := r.DestFeatures features := r.DestFeatures
if features == nil { if features == nil {
targetNode, err := g.graph.FetchLightningNode(tx, target) var err error
switch { features, err = graph.fetchNodeFeatures(target)
if err != nil {
// If the node exists and has features, use them directly.
case err == nil:
features = targetNode.Features
// If an error other than the node not existing is hit, abort.
case err != channeldb.ErrGraphNodeNotFound:
return nil, err return nil, err
// Otherwise, we couldn't find a node announcement, populate a
// blank feature vector.
default:
features = lnwire.EmptyFeatureVector()
} }
} }
// Ensure that the destination's features don't include unknown // Ensure that the destination's features don't include unknown
// required features. // required features.
err = feature.ValidateRequired(features) err := feature.ValidateRequired(features)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -485,8 +486,12 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
// If we are routing from ourselves, check that we have enough local // If we are routing from ourselves, check that we have enough local
// balance available. // balance available.
self := graph.sourceNode()
if source == self { if source == self {
max, err := getMaxOutgoingAmt(self, r.OutgoingChannelID, g, tx) max, err := getMaxOutgoingAmt(
self, r.OutgoingChannelID, bandwidthHints, graph,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -504,7 +509,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
distance := make(map[route.Vertex]*nodeWithDist, estimatedNodeCount) distance := make(map[route.Vertex]*nodeWithDist, estimatedNodeCount)
additionalEdgesWithSrc := make(map[route.Vertex][]*edgePolicyWithSource) additionalEdgesWithSrc := make(map[route.Vertex][]*edgePolicyWithSource)
for vertex, outgoingEdgePolicies := range g.additionalEdges { for vertex, outgoingEdgePolicies := range additionalEdges {
// Build reverse lookup to find incoming edges. Needed because // Build reverse lookup to find incoming edges. Needed because
// search is taken place from target to source. // search is taken place from target to source.
for _, outgoingEdgePolicy := range outgoingEdgePolicies { for _, outgoingEdgePolicy := range outgoingEdgePolicies {
@ -746,45 +751,35 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
// Check cache for features of the fromNode. // Check cache for features of the fromNode.
fromFeatures, ok := featureCache[node] fromFeatures, ok := featureCache[node]
if !ok { if ok {
targetNode, err := g.graph.FetchLightningNode(tx, node) return fromFeatures, nil
switch {
// If the node exists and has valid features, use them.
case err == nil:
nodeFeatures := targetNode.Features
// Don't route through nodes that contain
// unknown required features.
err = feature.ValidateRequired(nodeFeatures)
if err != nil {
break
}
// Don't route through nodes that don't properly
// set all transitive feature dependencies.
err = feature.ValidateDeps(nodeFeatures)
if err != nil {
break
}
fromFeatures = nodeFeatures
// If an error other than the node not existing is hit,
// abort.
case err != channeldb.ErrGraphNodeNotFound:
return nil, err
// Otherwise, we couldn't find a node announcement,
// populate a blank feature vector.
default:
fromFeatures = lnwire.EmptyFeatureVector()
}
// Update cache.
featureCache[node] = fromFeatures
} }
// Fetch node features fresh from the graph.
fromFeatures, err := graph.fetchNodeFeatures(node)
if err != nil {
return nil, err
}
// Don't route through nodes that contain unknown required
// features and mark as nil in the cache.
err = feature.ValidateRequired(fromFeatures)
if err != nil {
featureCache[node] = nil
return nil, nil
}
// Don't route through nodes that don't properly set all
// transitive feature dependencies and mark as nil in the cache.
err = feature.ValidateDeps(fromFeatures)
if err != nil {
featureCache[node] = nil
return nil, nil
}
// Update cache.
featureCache[node] = fromFeatures
return fromFeatures, nil return fromFeatures, nil
} }
@ -797,7 +792,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
// Create unified policies for all incoming connections. // Create unified policies for all incoming connections.
u := newUnifiedPolicies(self, pivot, r.OutgoingChannelID) u := newUnifiedPolicies(self, pivot, r.OutgoingChannelID)
err := u.addGraphPolicies(g.graph, tx) err := u.addGraphPolicies(graph)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -828,7 +823,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
} }
policy := unifiedPolicy.getPolicy( policy := unifiedPolicy.getPolicy(
amtToSend, g.bandwidthHints, amtToSend, bandwidthHints,
) )
if policy == nil { if policy == nil {

@ -2311,6 +2311,13 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi,
return nil, err return nil, err
} }
// Fetch the current block height outside the routing transaction, to
// prevent the rpc call blocking the database.
_, height, err := r.cfg.Chain.GetBestBlock()
if err != nil {
return nil, err
}
// Allocate a list that will contain the unified policies for this // Allocate a list that will contain the unified policies for this
// route. // route.
edges := make([]*unifiedPolicy, len(hops)) edges := make([]*unifiedPolicy, len(hops))
@ -2328,6 +2335,18 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi,
runningAmt = *amt runningAmt = *amt
} }
// Open a transaction to execute the graph queries in.
routingTx, err := newDbRoutingTx(r.cfg.Graph)
if err != nil {
return nil, err
}
defer func() {
err := routingTx.close()
if err != nil {
log.Errorf("Error closing db tx: %v", err)
}
}()
// Traverse hops backwards to accumulate fees in the running amounts. // Traverse hops backwards to accumulate fees in the running amounts.
source := r.selfNode.PubKeyBytes source := r.selfNode.PubKeyBytes
for i := len(hops) - 1; i >= 0; i-- { for i := len(hops) - 1; i >= 0; i-- {
@ -2346,7 +2365,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi,
// known in the graph. // known in the graph.
u := newUnifiedPolicies(source, toNode, outgoingChan) u := newUnifiedPolicies(source, toNode, outgoingChan)
err := u.addGraphPolicies(r.cfg.Graph, nil) err := u.addGraphPolicies(routingTx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -2414,11 +2433,6 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi,
} }
// Build and return the final route. // Build and return the final route.
_, height, err := r.cfg.Chain.GetBestBlock()
if err != nil {
return nil, err
}
return newRoute( return newRoute(
source, pathEdges, uint32(height), source, pathEdges, uint32(height),
finalHopParams{ finalHopParams{

@ -2,7 +2,6 @@ package routing
import ( import (
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/coreos/bbolt"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
@ -69,10 +68,8 @@ func (u *unifiedPolicies) addPolicy(fromNode route.Vertex,
// addGraphPolicies adds all policies that are known for the toNode in the // addGraphPolicies adds all policies that are known for the toNode in the
// graph. // graph.
func (u *unifiedPolicies) addGraphPolicies(g *channeldb.ChannelGraph, func (u *unifiedPolicies) addGraphPolicies(g routingGraph) error {
tx *bbolt.Tx) error { cb := func(edgeInfo *channeldb.ChannelEdgeInfo, _,
cb := func(_ *bbolt.Tx, edgeInfo *channeldb.ChannelEdgeInfo, _,
inEdge *channeldb.ChannelEdgePolicy) error { inEdge *channeldb.ChannelEdgePolicy) error {
// If there is no edge policy for this candidate node, skip. // If there is no edge policy for this candidate node, skip.
@ -95,7 +92,7 @@ func (u *unifiedPolicies) addGraphPolicies(g *channeldb.ChannelGraph,
} }
// Iterate over all channels of the to node. // Iterate over all channels of the to node.
return g.ForEachNodeChannel(tx, u.toNode[:], cb) return g.forEachNodeChannel(u.toNode, cb)
} }
// unifiedPolicyEdge is the individual channel data that is kept inside an // unifiedPolicyEdge is the individual channel data that is kept inside an