routing: avoid modifying AssumeChannelValid in unit tests

This produces a race condition when reading AssumeChannelValid from a
different goroutine. Instead we isolate the test cases and initial
AssumeChannelValid properly.
This commit is contained in:
Conner Fromknecht 2021-02-17 18:55:56 -08:00
parent f7c5236bf6
commit 250bc8560e
No known key found for this signature in database
GPG Key ID: E7D737B67FA592C7

@ -72,6 +72,14 @@ func (c *testCtx) RestartRouter() error {
func createTestCtxFromGraphInstance(startingHeight uint32, graphInstance *testGraphInstance) ( func createTestCtxFromGraphInstance(startingHeight uint32, graphInstance *testGraphInstance) (
*testCtx, func(), error) { *testCtx, func(), error) {
return createTestCtxFromGraphInstanceAssumeValid(
startingHeight, graphInstance, false,
)
}
func createTestCtxFromGraphInstanceAssumeValid(startingHeight uint32,
graphInstance *testGraphInstance, assumeValid bool) (*testCtx, func(), error) {
// We'll initialize an instance of the channel router with mock // We'll initialize an instance of the channel router with mock
// versions of the chain and channel notifier. As we don't need to test // versions of the chain and channel notifier. As we don't need to test
// any p2p functionality, the peer send and switch send messages won't // any p2p functionality, the peer send and switch send messages won't
@ -128,6 +136,7 @@ func createTestCtxFromGraphInstance(startingHeight uint32, graphInstance *testGr
}, },
PathFindingConfig: pathFindingConfig, PathFindingConfig: pathFindingConfig,
Clock: clock.NewTestClock(time.Unix(1, 0)), Clock: clock.NewTestClock(time.Unix(1, 0)),
AssumeChannelValid: assumeValid,
}) })
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("unable to create router %v", err) return nil, nil, fmt.Errorf("unable to create router %v", err)
@ -2034,6 +2043,15 @@ func TestPruneChannelGraphStaleEdges(t *testing.T) {
func TestPruneChannelGraphDoubleDisabled(t *testing.T) { func TestPruneChannelGraphDoubleDisabled(t *testing.T) {
t.Parallel() t.Parallel()
t.Run("no_assumechannelvalid", func(t *testing.T) {
testPruneChannelGraphDoubleDisabled(t, false)
})
t.Run("assumechannelvalid", func(t *testing.T) {
testPruneChannelGraphDoubleDisabled(t, true)
})
}
func testPruneChannelGraphDoubleDisabled(t *testing.T, assumeValid bool) {
// We'll create the following test graph so that only the last channel // We'll create the following test graph so that only the last channel
// is pruned. We'll use a fresh timestamp to ensure they're not pruned // is pruned. We'll use a fresh timestamp to ensure they're not pruned
// according to that heuristic. // according to that heuristic.
@ -2125,36 +2143,39 @@ func TestPruneChannelGraphDoubleDisabled(t *testing.T) {
defer testGraph.cleanUp() defer testGraph.cleanUp()
const startingHeight = 100 const startingHeight = 100
ctx, cleanUp, err := createTestCtxFromGraphInstance( ctx, cleanUp, err := createTestCtxFromGraphInstanceAssumeValid(
startingHeight, testGraph, startingHeight, testGraph, assumeValid,
) )
if err != nil { if err != nil {
t.Fatalf("unable to create test context: %v", err) t.Fatalf("unable to create test context: %v", err)
} }
defer cleanUp() defer cleanUp()
// All the channels should exist within the graph before pruning them. // All the channels should exist within the graph before pruning them
// when not using AssumeChannelValid, otherwise we should have pruned
// the last channel on startup.
if !assumeValid {
assertChannelsPruned(t, ctx.graph, testChannels) assertChannelsPruned(t, ctx.graph, testChannels)
} else {
// If we attempt to prune them without AssumeChannelValid being set,
// none should be pruned.
if err := ctx.router.pruneZombieChans(); err != nil {
t.Fatalf("unable to prune zombie channels: %v", err)
}
assertChannelsPruned(t, ctx.graph, testChannels)
// Now that AssumeChannelValid is set, we'll prune the graph again and
// the last channel should be the only one pruned.
ctx.router.cfg.AssumeChannelValid = true
if err := ctx.router.pruneZombieChans(); err != nil {
t.Fatalf("unable to prune zombie channels: %v", err)
}
prunedChannel := testChannels[len(testChannels)-1].ChannelID prunedChannel := testChannels[len(testChannels)-1].ChannelID
assertChannelsPruned(t, ctx.graph, testChannels, prunedChannel) assertChannelsPruned(t, ctx.graph, testChannels, prunedChannel)
} }
if err := ctx.router.pruneZombieChans(); err != nil {
t.Fatalf("unable to prune zombie channels: %v", err)
}
// If we attempted to prune them without AssumeChannelValid being set,
// none should be pruned. Otherwise the last channel should still be
// pruned.
if !assumeValid {
assertChannelsPruned(t, ctx.graph, testChannels)
} else {
prunedChannel := testChannels[len(testChannels)-1].ChannelID
assertChannelsPruned(t, ctx.graph, testChannels, prunedChannel)
}
}
// TestFindPathFeeWeighting tests that the findPath method will properly prefer // TestFindPathFeeWeighting tests that the findPath method will properly prefer
// routes with lower fees over routes with lower time lock values. This is // routes with lower fees over routes with lower time lock values. This is
// meant to exercise the fact that the internal findPath method ranks edges // meant to exercise the fact that the internal findPath method ranks edges