diff --git a/routing/router_test.go b/routing/router_test.go
index a04aa7b3..19ad1ef3 100644
--- a/routing/router_test.go
+++ b/routing/router_test.go
@@ -72,6 +72,14 @@ func (c *testCtx) RestartRouter() error {
 func createTestCtxFromGraphInstance(startingHeight uint32, graphInstance *testGraphInstance) (
 	*testCtx, func(), error) {
 
+	return createTestCtxFromGraphInstanceAssumeValid(
+		startingHeight, graphInstance, false,
+	)
+}
+
+func createTestCtxFromGraphInstanceAssumeValid(startingHeight uint32,
+	graphInstance *testGraphInstance, assumeValid bool) (*testCtx, func(), error) {
+
 	// We'll initialize an instance of the channel router with mock
 	// versions of the chain and channel notifier. As we don't need to test
 	// any p2p functionality, the peer send and switch send messages won't
@@ -126,8 +134,9 @@ func createTestCtxFromGraphInstance(startingHeight uint32, graphInstance *testGr
 			next := atomic.AddUint64(&uniquePaymentID, 1)
 			return next, nil
 		},
-		PathFindingConfig: pathFindingConfig,
-		Clock:             clock.NewTestClock(time.Unix(1, 0)),
+		PathFindingConfig:  pathFindingConfig,
+		Clock:              clock.NewTestClock(time.Unix(1, 0)),
+		AssumeChannelValid: assumeValid,
 	})
 	if err != nil {
 		return nil, nil, fmt.Errorf("unable to create router %v", err)
@@ -2034,6 +2043,15 @@ func TestPruneChannelGraphStaleEdges(t *testing.T) {
 func TestPruneChannelGraphDoubleDisabled(t *testing.T) {
 	t.Parallel()
 
+	t.Run("no_assumechannelvalid", func(t *testing.T) {
+		testPruneChannelGraphDoubleDisabled(t, false)
+	})
+	t.Run("assumechannelvalid", func(t *testing.T) {
+		testPruneChannelGraphDoubleDisabled(t, true)
+	})
+}
+
+func testPruneChannelGraphDoubleDisabled(t *testing.T, assumeValid bool) {
 	// We'll create the following test graph so that only the last channel
 	// is pruned. We'll use a fresh timestamp to ensure they're not pruned
 	// according to that heuristic.
@@ -2125,34 +2143,37 @@ func TestPruneChannelGraphDoubleDisabled(t *testing.T) {
 	defer testGraph.cleanUp()
 
 	const startingHeight = 100
-	ctx, cleanUp, err := createTestCtxFromGraphInstance(
-		startingHeight, testGraph,
+	ctx, cleanUp, err := createTestCtxFromGraphInstanceAssumeValid(
+		startingHeight, testGraph, assumeValid,
 	)
 	if err != nil {
 		t.Fatalf("unable to create test context: %v", err)
 	}
 	defer cleanUp()
 
-	// All the channels should exist within the graph before pruning them.
-	assertChannelsPruned(t, ctx.graph, testChannels)
+	// All the channels should exist within the graph before pruning them
+	// when not using AssumeChannelValid, otherwise we should have pruned
+	// the last channel on startup.
+	if !assumeValid {
+		assertChannelsPruned(t, ctx.graph, testChannels)
+	} else {
+		prunedChannel := testChannels[len(testChannels)-1].ChannelID
+		assertChannelsPruned(t, ctx.graph, testChannels, prunedChannel)
+	}
 
-	// If we attempt to prune them without AssumeChannelValid being set,
-	// none should be pruned.
 	if err := ctx.router.pruneZombieChans(); err != nil {
 		t.Fatalf("unable to prune zombie channels: %v", err)
 	}
 
-	assertChannelsPruned(t, ctx.graph, testChannels)
-
-	// Now that AssumeChannelValid is set, we'll prune the graph again and
-	// the last channel should be the only one pruned.
-	ctx.router.cfg.AssumeChannelValid = true
-	if err := ctx.router.pruneZombieChans(); err != nil {
-		t.Fatalf("unable to prune zombie channels: %v", err)
+	// If we attempted to prune them without AssumeChannelValid being set,
+	// none should be pruned. Otherwise the last channel should still be
+	// pruned.
+	if !assumeValid {
+		assertChannelsPruned(t, ctx.graph, testChannels)
+	} else {
+		prunedChannel := testChannels[len(testChannels)-1].ChannelID
+		assertChannelsPruned(t, ctx.graph, testChannels, prunedChannel)
 	}
-
-	prunedChannel := testChannels[len(testChannels)-1].ChannelID
-	assertChannelsPruned(t, ctx.graph, testChannels, prunedChannel)
 }
 
 // TestFindPathFeeWeighting tests that the findPath method will properly prefer