diff --git a/routing/pathfind.go b/routing/pathfind.go index c3e335ce..e2f1b447 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -263,6 +263,10 @@ type RestrictParams struct { // hop. If nil, any channel may be used. OutgoingChannelID *uint64 + // LastHop is the pubkey of the last node before the final destination + // is reached. If nil, any node may be used. + LastHop *route.Vertex + // CltvLimit is the maximum time lock of the route excluding the final // ctlv. After path finding is complete, the caller needs to increase // all cltv expiry heights with the required final cltv delta. @@ -562,6 +566,13 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // Expand all connections using the optimal policy for each // connection. for fromNode, unifiedPolicy := range u.policies { + // Apply last hop restriction if set. + if r.LastHop != nil && + pivot == target && fromNode != *r.LastHop { + + continue + } + policy := unifiedPolicy.getPolicy( amtToSend, g.bandwidthHints, ) diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 2c7afde4..dd477548 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -1874,6 +1874,51 @@ func TestRestrictOutgoingChannel(t *testing.T) { } } +// TestRestrictLastHop asserts that a last hop restriction is obeyed by the path +// finding algorithm. +func TestRestrictLastHop(t *testing.T) { + t.Parallel() + + // Set up a test graph with three possible paths from roasbeef to + // target. The path via channel 1 and 2 is the lowest cost path. + testChannels := []*testChannel{ + symmetricTestChannel("source", "a", 100000, &testChannelPolicy{ + Expiry: 144, + }, 1), + symmetricTestChannel("a", "target", 100000, &testChannelPolicy{ + Expiry: 144, + FeeRate: 400, + }, 2), + symmetricTestChannel("source", "b", 100000, &testChannelPolicy{ + Expiry: 144, + }, 3), + symmetricTestChannel("b", "target", 100000, &testChannelPolicy{ + Expiry: 144, + FeeRate: 800, + }, 4), + } + + ctx := newPathFindingTestContext(t, testChannels, "source") + defer ctx.cleanup() + + paymentAmt := lnwire.NewMSatFromSatoshis(100) + target := ctx.keyFromAlias("target") + lastHop := ctx.keyFromAlias("b") + + // Find the best path given the restriction to use b as the last hop. + // This should force pathfinding to not take the lowest cost option. + ctx.restrictParams.LastHop = &lastHop + path, err := ctx.findPath(target, paymentAmt) + if err != nil { + t.Fatalf("unable to find path: %v", err) + } + if path[0].ChannelID != 3 { + t.Fatalf("expected route to pass through channel 3, "+ + "but channel %v was selected instead", + path[0].ChannelID) + } +} + // TestCltvLimit asserts that a cltv limit is obeyed by the path finding // algorithm. func TestCltvLimit(t *testing.T) {