routing: use routingGraph interface in payment session

Preparation for more test coverage of payment session.

The function findPath now has the call signature of the former
findPathInternal function.
This commit is contained in:
Joost Jager 2020-03-17 11:32:07 +01:00
parent cb4cd49dc8
commit 47f9c1c3fd
No known key found for this signature in database
GPG Key ID: A61B9D4C393C59C7
9 changed files with 155 additions and 146 deletions

@ -111,8 +111,11 @@ func (c *integratedRoutingContext) testPayment(expectedNofAttempts int) {
} }
// Find a route. // Find a route.
path, err := findPathInternal( path, err := findPath(
nil, bandwidthHints, c.graph, &graphParams{
graph: c.graph,
bandwidthHints: bandwidthHints,
},
&restrictParams, &restrictParams,
&c.pathFindingCfg, &c.pathFindingCfg,
c.source.pubkey, c.target.pubkey, c.source.pubkey, c.target.pubkey,

@ -253,7 +253,7 @@ func edgeWeight(lockedAmt lnwire.MilliSatoshi, fee lnwire.MilliSatoshi,
// graphParams wraps the set of graph parameters passed to findPath. // graphParams wraps the set of graph parameters passed to findPath.
type graphParams struct { type graphParams struct {
// graph is the ChannelGraph to be used during path finding. // graph is the ChannelGraph to be used during path finding.
graph *channeldb.ChannelGraph graph routingGraph
// additionalEdges is an optional set of edges that should be // additionalEdges is an optional set of edges that should be
// considered during path finding, that is not already found in the // considered during path finding, that is not already found in the
@ -381,34 +381,8 @@ func getMaxOutgoingAmt(node route.Vertex, outgoingChan *uint64,
// path and accurately check the amount to forward at every node against the // path and accurately check the amount to forward at every node against the
// available bandwidth. // available bandwidth.
func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
source, target route.Vertex, amt lnwire.MilliSatoshi, finalHtlcExpiry int32) ( source, target route.Vertex, amt lnwire.MilliSatoshi,
[]*channeldb.ChannelEdgePolicy, error) { finalHtlcExpiry int32) ([]*channeldb.ChannelEdgePolicy, error) {
routingTx, err := newDbRoutingTx(g.graph)
if err != nil {
return nil, err
}
defer func() {
err := routingTx.close()
if err != nil {
log.Errorf("Error closing db tx: %v", err)
}
}()
return findPathInternal(
g.additionalEdges, g.bandwidthHints, routingTx, r, cfg, source,
target, amt, finalHtlcExpiry,
)
}
// findPathInternal is the internal implementation of findPath.
func findPathInternal(
additionalEdges map[route.Vertex][]*channeldb.ChannelEdgePolicy,
bandwidthHints map[uint64]lnwire.MilliSatoshi,
graph routingGraph,
r *RestrictParams, cfg *PathFindingConfig,
source, target route.Vertex, amt lnwire.MilliSatoshi, finalHtlcExpiry int32) (
[]*channeldb.ChannelEdgePolicy, error) {
// Pathfinding can be a significant portion of the total payment // Pathfinding can be a significant portion of the total payment
// latency, especially on low-powered devices. Log several metrics to // latency, especially on low-powered devices. Log several metrics to
@ -427,7 +401,7 @@ func findPathInternal(
features := r.DestFeatures features := r.DestFeatures
if features == nil { if features == nil {
var err error var err error
features, err = graph.fetchNodeFeatures(target) features, err = g.graph.fetchNodeFeatures(target)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -468,11 +442,11 @@ func findPathInternal(
// If we are routing from ourselves, check that we have enough local // If we are routing from ourselves, check that we have enough local
// balance available. // balance available.
self := graph.sourceNode() self := g.graph.sourceNode()
if source == self { if source == self {
max, err := getMaxOutgoingAmt( max, err := getMaxOutgoingAmt(
self, r.OutgoingChannelID, bandwidthHints, graph, self, r.OutgoingChannelID, g.bandwidthHints, g.graph,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -491,7 +465,7 @@ func findPathInternal(
distance := make(map[route.Vertex]*nodeWithDist, estimatedNodeCount) distance := make(map[route.Vertex]*nodeWithDist, estimatedNodeCount)
additionalEdgesWithSrc := make(map[route.Vertex][]*edgePolicyWithSource) additionalEdgesWithSrc := make(map[route.Vertex][]*edgePolicyWithSource)
for vertex, outgoingEdgePolicies := range additionalEdges { for vertex, outgoingEdgePolicies := range g.additionalEdges {
// Build reverse lookup to find incoming edges. Needed because // Build reverse lookup to find incoming edges. Needed because
// search is taken place from target to source. // search is taken place from target to source.
for _, outgoingEdgePolicy := range outgoingEdgePolicies { for _, outgoingEdgePolicy := range outgoingEdgePolicies {
@ -739,7 +713,7 @@ func findPathInternal(
} }
// Fetch node features fresh from the graph. // Fetch node features fresh from the graph.
fromFeatures, err := graph.fetchNodeFeatures(node) fromFeatures, err := g.graph.fetchNodeFeatures(node)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -775,7 +749,7 @@ func findPathInternal(
// Create unified policies for all incoming connections. // Create unified policies for all incoming connections.
u := newUnifiedPolicies(self, pivot, r.OutgoingChannelID) u := newUnifiedPolicies(self, pivot, r.OutgoingChannelID)
err := u.addGraphPolicies(graph) err := u.addGraphPolicies(g.graph)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -806,7 +780,7 @@ func findPathInternal(
} }
policy := unifiedPolicy.getPolicy( policy := unifiedPolicy.getPolicy(
amtToSend, bandwidthHints, amtToSend, g.bandwidthHints,
) )
if policy == nil { if policy == nil {

@ -811,10 +811,8 @@ func testBasicGraphPathFindingCase(t *testing.T, graphInstance *testGraphInstanc
paymentAmt := lnwire.NewMSatFromSatoshis(test.paymentAmt) paymentAmt := lnwire.NewMSatFromSatoshis(test.paymentAmt)
target := graphInstance.aliasMap[test.target] target := graphInstance.aliasMap[test.target]
path, err := findPath( path, err := dbFindPath(
&graphParams{ graphInstance.graph, nil, nil,
graph: graphInstance.graph,
},
&RestrictParams{ &RestrictParams{
FeeLimit: test.feeLimit, FeeLimit: test.feeLimit,
ProbabilitySource: noProbabilitySource, ProbabilitySource: noProbabilitySource,
@ -1005,11 +1003,8 @@ func TestPathFindingWithAdditionalEdges(t *testing.T) {
find := func(r *RestrictParams) ( find := func(r *RestrictParams) (
[]*channeldb.ChannelEdgePolicy, error) { []*channeldb.ChannelEdgePolicy, error) {
return findPath( return dbFindPath(
&graphParams{ graph.graph, additionalEdges, nil,
graph: graph.graph,
additionalEdges: additionalEdges,
},
r, testPathFindingConfig, r, testPathFindingConfig,
sourceNode.PubKeyBytes, doge.PubKeyBytes, paymentAmt, sourceNode.PubKeyBytes, doge.PubKeyBytes, paymentAmt,
0, 0,
@ -1433,10 +1428,8 @@ func TestPathNotAvailable(t *testing.T) {
var unknownNode route.Vertex var unknownNode route.Vertex
copy(unknownNode[:], unknownNodeBytes) copy(unknownNode[:], unknownNodeBytes)
_, err = findPath( _, err = dbFindPath(
&graphParams{ graph.graph, nil, nil,
graph: graph.graph,
},
noRestrictions, testPathFindingConfig, noRestrictions, testPathFindingConfig,
sourceNode.PubKeyBytes, unknownNode, 100, 0, sourceNode.PubKeyBytes, unknownNode, 100, 0,
) )
@ -1482,7 +1475,7 @@ func TestDestTLVGraphFallback(t *testing.T) {
ctx := newPathFindingTestContext(t, testChannels, "roasbeef") ctx := newPathFindingTestContext(t, testChannels, "roasbeef")
defer ctx.cleanup() defer ctx.cleanup()
sourceNode, err := ctx.graphParams.graph.SourceNode() sourceNode, err := ctx.graph.SourceNode()
if err != nil { if err != nil {
t.Fatalf("unable to fetch source node: %v", err) t.Fatalf("unable to fetch source node: %v", err)
@ -1491,10 +1484,8 @@ func TestDestTLVGraphFallback(t *testing.T) {
find := func(r *RestrictParams, find := func(r *RestrictParams,
target route.Vertex) ([]*channeldb.ChannelEdgePolicy, error) { target route.Vertex) ([]*channeldb.ChannelEdgePolicy, error) {
return findPath( return dbFindPath(
&graphParams{ ctx.graph, nil, nil,
graph: ctx.graphParams.graph,
},
r, testPathFindingConfig, r, testPathFindingConfig,
sourceNode.PubKeyBytes, target, 100, 0, sourceNode.PubKeyBytes, target, 100, 0,
) )
@ -1765,10 +1756,8 @@ func TestPathInsufficientCapacity(t *testing.T) {
target := graph.aliasMap["sophon"] target := graph.aliasMap["sophon"]
payAmt := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin) payAmt := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin)
_, err = findPath( _, err = dbFindPath(
&graphParams{ graph.graph, nil, nil,
graph: graph.graph,
},
noRestrictions, testPathFindingConfig, noRestrictions, testPathFindingConfig,
sourceNode.PubKeyBytes, target, payAmt, 0, sourceNode.PubKeyBytes, target, payAmt, 0,
) )
@ -1798,10 +1787,8 @@ func TestRouteFailMinHTLC(t *testing.T) {
// attempt should fail. // attempt should fail.
target := graph.aliasMap["songoku"] target := graph.aliasMap["songoku"]
payAmt := lnwire.MilliSatoshi(10) payAmt := lnwire.MilliSatoshi(10)
_, err = findPath( _, err = dbFindPath(
&graphParams{ graph.graph, nil, nil,
graph: graph.graph,
},
noRestrictions, testPathFindingConfig, noRestrictions, testPathFindingConfig,
sourceNode.PubKeyBytes, target, payAmt, 0, sourceNode.PubKeyBytes, target, payAmt, 0,
) )
@ -1897,10 +1884,8 @@ func TestRouteFailDisabledEdge(t *testing.T) {
// succeed without issue, and return a single path via phamnuwen // succeed without issue, and return a single path via phamnuwen
target := graph.aliasMap["sophon"] target := graph.aliasMap["sophon"]
payAmt := lnwire.NewMSatFromSatoshis(105000) payAmt := lnwire.NewMSatFromSatoshis(105000)
_, err = findPath( _, err = dbFindPath(
&graphParams{ graph.graph, nil, nil,
graph: graph.graph,
},
noRestrictions, testPathFindingConfig, noRestrictions, testPathFindingConfig,
sourceNode.PubKeyBytes, target, payAmt, 0, sourceNode.PubKeyBytes, target, payAmt, 0,
) )
@ -1925,10 +1910,8 @@ func TestRouteFailDisabledEdge(t *testing.T) {
t.Fatalf("unable to update edge: %v", err) t.Fatalf("unable to update edge: %v", err)
} }
_, err = findPath( _, err = dbFindPath(
&graphParams{ graph.graph, nil, nil,
graph: graph.graph,
},
noRestrictions, testPathFindingConfig, noRestrictions, testPathFindingConfig,
sourceNode.PubKeyBytes, target, payAmt, 0, sourceNode.PubKeyBytes, target, payAmt, 0,
) )
@ -1950,10 +1933,8 @@ func TestRouteFailDisabledEdge(t *testing.T) {
// If we attempt to route through that edge, we should get a failure as // If we attempt to route through that edge, we should get a failure as
// it is no longer eligible. // it is no longer eligible.
_, err = findPath( _, err = dbFindPath(
&graphParams{ graph.graph, nil, nil,
graph: graph.graph,
},
noRestrictions, testPathFindingConfig, noRestrictions, testPathFindingConfig,
sourceNode.PubKeyBytes, target, payAmt, 0, sourceNode.PubKeyBytes, target, payAmt, 0,
) )
@ -1984,10 +1965,8 @@ func TestPathSourceEdgesBandwidth(t *testing.T) {
// cheapest path. // cheapest path.
target := graph.aliasMap["sophon"] target := graph.aliasMap["sophon"]
payAmt := lnwire.NewMSatFromSatoshis(50000) payAmt := lnwire.NewMSatFromSatoshis(50000)
path, err := findPath( path, err := dbFindPath(
&graphParams{ graph.graph, nil, nil,
graph: graph.graph,
},
noRestrictions, testPathFindingConfig, noRestrictions, testPathFindingConfig,
sourceNode.PubKeyBytes, target, payAmt, 0, sourceNode.PubKeyBytes, target, payAmt, 0,
) )
@ -2007,11 +1986,8 @@ func TestPathSourceEdgesBandwidth(t *testing.T) {
// Since both these edges has a bandwidth of zero, no path should be // Since both these edges has a bandwidth of zero, no path should be
// found. // found.
_, err = findPath( _, err = dbFindPath(
&graphParams{ graph.graph, nil, bandwidths,
graph: graph.graph,
bandwidthHints: bandwidths,
},
noRestrictions, testPathFindingConfig, noRestrictions, testPathFindingConfig,
sourceNode.PubKeyBytes, target, payAmt, 0, sourceNode.PubKeyBytes, target, payAmt, 0,
) )
@ -2025,11 +2001,8 @@ func TestPathSourceEdgesBandwidth(t *testing.T) {
// Now, if we attempt to route again, we should find the path via // Now, if we attempt to route again, we should find the path via
// phamnuven, as the other source edge won't be considered. // phamnuven, as the other source edge won't be considered.
path, err = findPath( path, err = dbFindPath(
&graphParams{ graph.graph, nil, bandwidths,
graph: graph.graph,
bandwidthHints: bandwidths,
},
noRestrictions, testPathFindingConfig, noRestrictions, testPathFindingConfig,
sourceNode.PubKeyBytes, target, payAmt, 0, sourceNode.PubKeyBytes, target, payAmt, 0,
) )
@ -2056,11 +2029,8 @@ func TestPathSourceEdgesBandwidth(t *testing.T) {
// Since we ignore disable flags for local channels, a path should // Since we ignore disable flags for local channels, a path should
// still be found. // still be found.
path, err = findPath( path, err = dbFindPath(
&graphParams{ graph.graph, nil, bandwidths,
graph: graph.graph,
bandwidthHints: bandwidths,
},
noRestrictions, testPathFindingConfig, noRestrictions, testPathFindingConfig,
sourceNode.PubKeyBytes, target, payAmt, 0, sourceNode.PubKeyBytes, target, payAmt, 0,
) )
@ -2811,7 +2781,7 @@ func TestRouteToSelf(t *testing.T) {
type pathFindingTestContext struct { type pathFindingTestContext struct {
t *testing.T t *testing.T
graphParams graphParams graph *channeldb.ChannelGraph
restrictParams RestrictParams restrictParams RestrictParams
pathFindingConfig PathFindingConfig pathFindingConfig PathFindingConfig
testGraphInstance *testGraphInstance testGraphInstance *testGraphInstance
@ -2838,9 +2808,7 @@ func newPathFindingTestContext(t *testing.T, testChannels []*testChannel,
testGraphInstance: testGraphInstance, testGraphInstance: testGraphInstance,
source: route.Vertex(sourceNode.PubKeyBytes), source: route.Vertex(sourceNode.PubKeyBytes),
pathFindingConfig: *testPathFindingConfig, pathFindingConfig: *testPathFindingConfig,
graphParams: graphParams{
graph: testGraphInstance.graph, graph: testGraphInstance.graph,
},
restrictParams: *noRestrictions, restrictParams: *noRestrictions,
} }
@ -2868,8 +2836,8 @@ func (c *pathFindingTestContext) findPath(target route.Vertex,
amt lnwire.MilliSatoshi) ([]*channeldb.ChannelEdgePolicy, amt lnwire.MilliSatoshi) ([]*channeldb.ChannelEdgePolicy,
error) { error) {
return findPath( return dbFindPath(
&c.graphParams, &c.restrictParams, &c.pathFindingConfig, c.graph, nil, nil, &c.restrictParams, &c.pathFindingConfig,
c.source, target, amt, 0, c.source, target, amt, 0,
) )
} }
@ -2887,3 +2855,33 @@ func (c *pathFindingTestContext) assertPath(path []*channeldb.ChannelEdgePolicy,
} }
} }
} }
// dbFindPath calls findPath after getting a db transaction from the database
// graph.
func dbFindPath(graph *channeldb.ChannelGraph,
additionalEdges map[route.Vertex][]*channeldb.ChannelEdgePolicy,
bandwidthHints map[uint64]lnwire.MilliSatoshi,
r *RestrictParams, cfg *PathFindingConfig,
source, target route.Vertex, amt lnwire.MilliSatoshi,
finalHtlcExpiry int32) ([]*channeldb.ChannelEdgePolicy, error) {
routingTx, err := newDbRoutingTx(graph)
if err != nil {
return nil, err
}
defer func() {
err := routingTx.close()
if err != nil {
log.Errorf("Error closing db tx: %v", err)
}
}()
return findPath(
&graphParams{
additionalEdges: additionalEdges,
bandwidthHints: bandwidthHints,
graph: routingTx,
},
r, cfg, source, target, amt, finalHtlcExpiry,
)
}

@ -98,13 +98,19 @@ type paymentSession struct {
getBandwidthHints func() (map[uint64]lnwire.MilliSatoshi, error) getBandwidthHints func() (map[uint64]lnwire.MilliSatoshi, error)
sessionSource *SessionSource
payment *LightningPayment payment *LightningPayment
empty bool empty bool
pathFinder pathFinder pathFinder pathFinder
getRoutingGraph func() (routingGraph, func(), error)
// pathFindingConfig defines global parameters that control the
// trade-off in path finding between fees and probabiity.
pathFindingConfig PathFindingConfig
missionControl MissionController
} }
// RequestRoute returns a route which is likely to be capable for successfully // RequestRoute returns a route which is likely to be capable for successfully
@ -138,10 +144,8 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
// Taking into account this prune view, we'll attempt to locate a path // Taking into account this prune view, we'll attempt to locate a path
// to our destination, respecting the recommendations from // to our destination, respecting the recommendations from
// MissionControl. // MissionControl.
ss := p.sessionSource
restrictions := &RestrictParams{ restrictions := &RestrictParams{
ProbabilitySource: ss.MissionControl.GetProbability, ProbabilitySource: p.missionControl.GetProbability,
FeeLimit: feeLimit, FeeLimit: feeLimit,
OutgoingChannelID: p.payment.OutgoingChannelID, OutgoingChannelID: p.payment.OutgoingChannelID,
LastHop: p.payment.LastHop, LastHop: p.payment.LastHop,
@ -164,14 +168,22 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
finalHtlcExpiry := int32(height) + int32(finalCltvDelta) finalHtlcExpiry := int32(height) + int32(finalCltvDelta)
routingGraph, cleanup, err := p.getRoutingGraph()
if err != nil {
return nil, err
}
defer cleanup()
sourceVertex := routingGraph.sourceNode()
path, err := p.pathFinder( path, err := p.pathFinder(
&graphParams{ &graphParams{
graph: ss.Graph,
additionalEdges: p.additionalEdges, additionalEdges: p.additionalEdges,
bandwidthHints: bandwidthHints, bandwidthHints: bandwidthHints,
graph: routingGraph,
}, },
restrictions, &ss.PathFindingConfig, restrictions, &p.pathFindingConfig,
ss.SelfNode.PubKeyBytes, p.payment.Target, sourceVertex, p.payment.Target,
maxAmt, finalHtlcExpiry, maxAmt, finalHtlcExpiry,
) )
if err != nil { if err != nil {
@ -180,7 +192,6 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
// With the next candidate path found, we'll attempt to turn this into // With the next candidate path found, we'll attempt to turn this into
// a route by applying the time-lock and fee requirements. // a route by applying the time-lock and fee requirements.
sourceVertex := route.Vertex(ss.SelfNode.PubKeyBytes)
route, err := newRoute( route, err := newRoute(
sourceVertex, path, height, sourceVertex, path, height,
finalHopParams{ finalHopParams{

@ -26,9 +26,6 @@ type SessionSource struct {
// the available bandwidth of the link should be returned. // the available bandwidth of the link should be returned.
QueryBandwidth func(*channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi QueryBandwidth func(*channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi
// SelfNode is our own node.
SelfNode *channeldb.LightningNode
// MissionControl is a shared memory of sorts that executions of payment // MissionControl is a shared memory of sorts that executions of payment
// path finding use in order to remember which vertexes/edges were // path finding use in order to remember which vertexes/edges were
// pruned from prior attempts. During payment execution, errors sent by // pruned from prior attempts. During payment execution, errors sent by
@ -43,6 +40,21 @@ type SessionSource struct {
PathFindingConfig PathFindingConfig PathFindingConfig PathFindingConfig
} }
// getRoutingGraph returns a routing graph and a clean-up function for
// pathfinding.
func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) {
routingTx, err := newDbRoutingTx(m.Graph)
if err != nil {
return nil, nil, err
}
return routingTx, func() {
err := routingTx.close()
if err != nil {
log.Errorf("Error closing db tx: %v", err)
}
}, nil
}
// NewPaymentSession creates a new payment session backed by the latest prune // NewPaymentSession creates a new payment session backed by the latest prune
// view from Mission Control. An optional set of routing hints can be provided // view from Mission Control. An optional set of routing hints can be provided
// in order to populate additional edges to explore when finding a path to the // in order to populate additional edges to explore when finding a path to the
@ -69,9 +81,11 @@ func (m *SessionSource) NewPaymentSession(p *LightningPayment) (
return &paymentSession{ return &paymentSession{
additionalEdges: edges, additionalEdges: edges,
getBandwidthHints: getBandwidthHints, getBandwidthHints: getBandwidthHints,
sessionSource: m,
payment: p, payment: p,
pathFinder: findPath, pathFinder: findPath,
getRoutingGraph: m.getRoutingGraph,
pathFindingConfig: m.PathFindingConfig,
missionControl: m.MissionControl,
}, nil }, nil
} }
@ -80,7 +94,6 @@ func (m *SessionSource) NewPaymentSession(p *LightningPayment) (
// missioncontrol for resumed payment we don't want to make more attempts for. // missioncontrol for resumed payment we don't want to make more attempts for.
func (m *SessionSource) NewPaymentSessionEmpty() PaymentSession { func (m *SessionSource) NewPaymentSessionEmpty() PaymentSession {
return &paymentSession{ return &paymentSession{
sessionSource: m,
empty: true, empty: true,
} }
} }

@ -13,10 +13,11 @@ func TestRequestRoute(t *testing.T) {
height = 10 height = 10
) )
findPath := func(g *graphParams, r *RestrictParams, findPath := func(
cfg *PathFindingConfig, source, target route.Vertex, g *graphParams,
amt lnwire.MilliSatoshi, finalHtlcExpiry int32) ( r *RestrictParams, cfg *PathFindingConfig,
[]*channeldb.ChannelEdgePolicy, error) { source, target route.Vertex, amt lnwire.MilliSatoshi,
finalHtlcExpiry int32) ([]*channeldb.ChannelEdgePolicy, error) {
// We expect find path to receive a cltv limit excluding the // We expect find path to receive a cltv limit excluding the
// final cltv delta (including the block padding). // final cltv delta (including the block padding).
@ -37,13 +38,6 @@ func TestRequestRoute(t *testing.T) {
return path, nil return path, nil
} }
sessionSource := &SessionSource{
SelfNode: &channeldb.LightningNode{},
MissionControl: &MissionControl{
cfg: &MissionControlConfig{},
},
}
cltvLimit := uint32(30) cltvLimit := uint32(30)
finalCltvDelta := uint16(8) finalCltvDelta := uint16(8)
@ -60,9 +54,14 @@ func TestRequestRoute(t *testing.T) {
return nil, nil return nil, nil
}, },
sessionSource: sessionSource,
payment: payment, payment: payment,
pathFinder: findPath, pathFinder: findPath,
missionControl: &MissionControl{
cfg: &MissionControlConfig{},
},
getRoutingGraph: func() (routingGraph, func(), error) {
return &sessionGraph{}, func() {}, nil
},
} }
route, err := session.RequestRoute( route, err := session.RequestRoute(
@ -79,3 +78,11 @@ func TestRequestRoute(t *testing.T) {
route.TotalTimeLock) route.TotalTimeLock)
} }
} }
type sessionGraph struct {
routingGraph
}
func (g *sessionGraph) sourceNode() route.Vertex {
return route.Vertex{}
}

@ -1426,13 +1426,25 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex,
// execute our path finding algorithm. // execute our path finding algorithm.
finalHtlcExpiry := currentHeight + int32(finalExpiry) finalHtlcExpiry := currentHeight + int32(finalExpiry)
routingTx, err := newDbRoutingTx(r.cfg.Graph)
if err != nil {
return nil, err
}
defer func() {
err := routingTx.close()
if err != nil {
log.Errorf("Error closing db tx: %v", err)
}
}()
path, err := findPath( path, err := findPath(
&graphParams{ &graphParams{
graph: r.cfg.Graph,
bandwidthHints: bandwidthHints,
additionalEdges: routeHints, additionalEdges: routeHints,
bandwidthHints: bandwidthHints,
graph: routingTx,
}, },
restrictions, &r.cfg.PathFindingConfig, restrictions,
&r.cfg.PathFindingConfig,
source, target, amt, finalHtlcExpiry, source, target, amt, finalHtlcExpiry,
) )
if err != nil { if err != nil {

@ -79,11 +79,6 @@ func createTestCtxFromGraphInstance(startingHeight uint32, graphInstance *testGr
chain := newMockChain(startingHeight) chain := newMockChain(startingHeight)
chainView := newMockChainView(chain) chainView := newMockChainView(chain)
selfNode, err := graphInstance.graph.SourceNode()
if err != nil {
return nil, nil, err
}
pathFindingConfig := PathFindingConfig{ pathFindingConfig := PathFindingConfig{
MinProbability: 0.01, MinProbability: 0.01,
PaymentAttemptPenalty: 100, PaymentAttemptPenalty: 100,
@ -105,7 +100,6 @@ func createTestCtxFromGraphInstance(startingHeight uint32, graphInstance *testGr
sessionSource := &SessionSource{ sessionSource := &SessionSource{
Graph: graphInstance.graph, Graph: graphInstance.graph,
SelfNode: selfNode,
QueryBandwidth: func(e *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { QueryBandwidth: func(e *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi {
return lnwire.NewMSatFromSatoshis(e.Capacity) return lnwire.NewMSatFromSatoshis(e.Capacity)
}, },
@ -2188,10 +2182,8 @@ func TestFindPathFeeWeighting(t *testing.T) {
// We'll now attempt a path finding attempt using this set up. Due to // We'll now attempt a path finding attempt using this set up. Due to
// the edge weighting, we should select the direct path over the 2 hop // the edge weighting, we should select the direct path over the 2 hop
// path even though the direct path has a higher potential time lock. // path even though the direct path has a higher potential time lock.
path, err := findPath( path, err := dbFindPath(
&graphParams{ ctx.graph, nil, nil,
graph: ctx.graph,
},
noRestrictions, noRestrictions,
testPathFindingConfig, testPathFindingConfig,
sourceNode.PubKeyBytes, target, amt, 0, sourceNode.PubKeyBytes, target, amt, 0,

@ -731,7 +731,6 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB,
Graph: chanGraph, Graph: chanGraph,
MissionControl: s.missionControl, MissionControl: s.missionControl,
QueryBandwidth: queryBandwidth, QueryBandwidth: queryBandwidth,
SelfNode: selfNode,
PathFindingConfig: pathFindingConfig, PathFindingConfig: pathFindingConfig,
} }