routing: allow route to self

This commit is contained in:
Joost Jager 2019-11-18 10:19:20 +01:00
parent 81b7798c03
commit f8e9efbf99
No known key found for this signature in database
GPG Key ID: A61B9D4C393C59C7
2 changed files with 82 additions and 8 deletions

@ -381,11 +381,14 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
// We can't always assume that the end destination is publicly
// advertised to the network so we'll manually include the target node.
// The target node charges no fee. Distance is set to 0, because this
// is the starting point of the graph traversal. We are searching
// backwards to get the fees first time right and correctly match
// channel bandwidth.
distance[target] = &nodeWithDist{
// The target node charges no fee. Distance is set to 0, because this is
// the starting point of the graph traversal. We are searching backwards
// to get the fees first time right and correctly match channel
// bandwidth.
//
// Don't record the initial partial path in the distance map and reserve
// that key for the source key in the case we route to ourselves.
partialPath := &nodeWithDist{
dist: 0,
weight: 0,
node: target,
@ -530,9 +533,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
// TODO(roasbeef): also add path caching
// * similar to route caching, but doesn't factor in the amount
// The partial path that we start out with is a path that consists of
// just the target node.
partialPath := distance[target]
routeToSelf := source == target
for {
nodesVisited++
@ -555,6 +556,15 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
// Expand all connections using the optimal policy for each
// connection.
for fromNode, unifiedPolicy := range u.policies {
// The target node is not recorded in the distance map.
// Therefore we need to have this check to prevent
// creating a cycle. Only when we intend to route to
// self, we allow this cycle to form. In that case we'll
// also break out of the search loop below.
if !routeToSelf && fromNode == target {
continue
}
// Apply last hop restriction if set.
if r.LastHop != nil &&
pivot == target && fromNode != *r.LastHop {
@ -610,6 +620,9 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
// Advance current node.
currentNode = currentNodeWithDist.nextHop.Node.PubKeyBytes
// Check stop condition at the end of this loop. This prevents
// breaking out too soon for self-payments that have target set
// to source.
if currentNode == target {
break
}

@ -2224,6 +2224,53 @@ func TestNoCycle(t *testing.T) {
}
}
// TestRouteToSelf tests that it is possible to find a route to the self node.
func TestRouteToSelf(t *testing.T) {
t.Parallel()
testChannels := []*testChannel{
symmetricTestChannel("source", "a", 100000, &testChannelPolicy{
Expiry: 144,
FeeBaseMsat: 500,
}, 1),
symmetricTestChannel("source", "b", 100000, &testChannelPolicy{
Expiry: 144,
FeeBaseMsat: 1000,
}, 2),
symmetricTestChannel("a", "b", 100000, &testChannelPolicy{
Expiry: 144,
FeeBaseMsat: 1000,
}, 3),
}
ctx := newPathFindingTestContext(t, testChannels, "source")
defer ctx.cleanup()
paymentAmt := lnwire.NewMSatFromSatoshis(100)
target := ctx.source
// Find the best path to self. We expect this to be source->a->source,
// because a charges the lowest forwarding fee.
path, err := ctx.findPath(target, paymentAmt)
if err != nil {
t.Fatalf("unable to find path: %v", err)
}
ctx.assertPath(path, []uint64{1, 1})
outgoingChanID := uint64(1)
lastHop := ctx.keyFromAlias("b")
ctx.restrictParams.OutgoingChannelID = &outgoingChanID
ctx.restrictParams.LastHop = &lastHop
// Find the best path to self given that we want to go out via channel 1
// and return through node b.
path, err = ctx.findPath(target, paymentAmt)
if err != nil {
t.Fatalf("unable to find path: %v", err)
}
ctx.assertPath(path, []uint64{1, 3, 2})
}
type pathFindingTestContext struct {
t *testing.T
graphParams graphParams
@ -2291,3 +2338,17 @@ func (c *pathFindingTestContext) findPath(target route.Vertex,
c.source, target, amt,
)
}
func (c *pathFindingTestContext) assertPath(path []*channeldb.ChannelEdgePolicy, expected []uint64) {
if len(path) != len(expected) {
c.t.Fatalf("expected path of length %v, but got %v",
len(expected), len(path))
}
for i, edge := range path {
if edge.ChannelID != expected[i] {
c.t.Fatalf("expected hop %v to be channel %v, "+
"but got %v", i, expected[i], edge.ChannelID)
}
}
}