diff --git a/routing/pathfind.go b/routing/pathfind.go index c52e1cb7..dee06b0c 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -381,11 +381,14 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // We can't always assume that the end destination is publicly // advertised to the network so we'll manually include the target node. - // The target node charges no fee. Distance is set to 0, because this - // is the starting point of the graph traversal. We are searching - // backwards to get the fees first time right and correctly match - // channel bandwidth. - distance[target] = &nodeWithDist{ + // The target node charges no fee. Distance is set to 0, because this is + // the starting point of the graph traversal. We are searching backwards + // to get the fees first time right and correctly match channel + // bandwidth. + // + // Don't record the initial partial path in the distance map and reserve + // that key for the source key in the case we route to ourselves. + partialPath := &nodeWithDist{ dist: 0, weight: 0, node: target, @@ -530,9 +533,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // TODO(roasbeef): also add path caching // * similar to route caching, but doesn't factor in the amount - // The partial path that we start out with is a path that consists of - // just the target node. - partialPath := distance[target] + routeToSelf := source == target for { nodesVisited++ @@ -555,6 +556,15 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // Expand all connections using the optimal policy for each // connection. for fromNode, unifiedPolicy := range u.policies { + // The target node is not recorded in the distance map. + // Therefore we need to have this check to prevent + // creating a cycle. Only when we intend to route to + // self, we allow this cycle to form. In that case we'll + // also break out of the search loop below. + if !routeToSelf && fromNode == target { + continue + } + // Apply last hop restriction if set. if r.LastHop != nil && pivot == target && fromNode != *r.LastHop { @@ -610,6 +620,9 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // Advance current node. currentNode = currentNodeWithDist.nextHop.Node.PubKeyBytes + // Check stop condition at the end of this loop. This prevents + // breaking out too soon for self-payments that have target set + // to source. if currentNode == target { break } diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index e7a93194..6cde1600 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -2224,6 +2224,53 @@ func TestNoCycle(t *testing.T) { } } +// TestRouteToSelf tests that it is possible to find a route to the self node. +func TestRouteToSelf(t *testing.T) { + t.Parallel() + + testChannels := []*testChannel{ + symmetricTestChannel("source", "a", 100000, &testChannelPolicy{ + Expiry: 144, + FeeBaseMsat: 500, + }, 1), + symmetricTestChannel("source", "b", 100000, &testChannelPolicy{ + Expiry: 144, + FeeBaseMsat: 1000, + }, 2), + symmetricTestChannel("a", "b", 100000, &testChannelPolicy{ + Expiry: 144, + FeeBaseMsat: 1000, + }, 3), + } + + ctx := newPathFindingTestContext(t, testChannels, "source") + defer ctx.cleanup() + + paymentAmt := lnwire.NewMSatFromSatoshis(100) + target := ctx.source + + // Find the best path to self. We expect this to be source->a->source, + // because a charges the lowest forwarding fee. + path, err := ctx.findPath(target, paymentAmt) + if err != nil { + t.Fatalf("unable to find path: %v", err) + } + ctx.assertPath(path, []uint64{1, 1}) + + outgoingChanID := uint64(1) + lastHop := ctx.keyFromAlias("b") + ctx.restrictParams.OutgoingChannelID = &outgoingChanID + ctx.restrictParams.LastHop = &lastHop + + // Find the best path to self given that we want to go out via channel 1 + // and return through node b. + path, err = ctx.findPath(target, paymentAmt) + if err != nil { + t.Fatalf("unable to find path: %v", err) + } + ctx.assertPath(path, []uint64{1, 3, 2}) +} + type pathFindingTestContext struct { t *testing.T graphParams graphParams @@ -2291,3 +2338,17 @@ func (c *pathFindingTestContext) findPath(target route.Vertex, c.source, target, amt, ) } + +func (c *pathFindingTestContext) assertPath(path []*channeldb.ChannelEdgePolicy, expected []uint64) { + if len(path) != len(expected) { + c.t.Fatalf("expected path of length %v, but got %v", + len(expected), len(path)) + } + + for i, edge := range path { + if edge.ChannelID != expected[i] { + c.t.Fatalf("expected hop %v to be channel %v, "+ + "but got %v", i, expected[i], edge.ChannelID) + } + } +}