From 5daf75b2642afe9a516da8379080e8bc8757f604 Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Thu, 16 Aug 2018 21:35:38 +0200 Subject: [PATCH] routing: restructure test context creation --- routing/notifications_test.go | 8 +- routing/pathfind_test.go | 189 ++++++++++++++++++---------------- routing/router_test.go | 120 +++++++++++---------- 3 files changed, 168 insertions(+), 149 deletions(-) diff --git a/routing/notifications_test.go b/routing/notifications_test.go index 0a5cc393..2fa12274 100644 --- a/routing/notifications_test.go +++ b/routing/notifications_test.go @@ -339,7 +339,7 @@ func (m *mockChainView) Stop() error { func TestEdgeUpdateNotification(t *testing.T) { t.Parallel() - ctx, cleanUp, err := createTestCtx(0) + ctx, cleanUp, err := createTestCtxSingleNode(0) defer cleanUp() if err != nil { t.Fatalf("unable to create router: %v", err) @@ -528,7 +528,7 @@ func TestNodeUpdateNotification(t *testing.T) { t.Parallel() const startingBlockHeight = 101 - ctx, cleanUp, err := createTestCtx(startingBlockHeight) + ctx, cleanUp, err := createTestCtxSingleNode(startingBlockHeight) defer cleanUp() if err != nil { t.Fatalf("unable to create router: %v", err) @@ -706,7 +706,7 @@ func TestNotificationCancellation(t *testing.T) { t.Parallel() const startingBlockHeight = 101 - ctx, cleanUp, err := createTestCtx(startingBlockHeight) + ctx, cleanUp, err := createTestCtxSingleNode(startingBlockHeight) defer cleanUp() if err != nil { t.Fatalf("unable to create router: %v", err) @@ -798,7 +798,7 @@ func TestChannelCloseNotification(t *testing.T) { t.Parallel() const startingBlockHeight = 101 - ctx, cleanUp, err := createTestCtx(startingBlockHeight) + ctx, cleanUp, err := createTestCtxSingleNode(startingBlockHeight) defer cleanUp() if err != nil { t.Fatalf("unable to create router: %v", err) diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 8180aa07..f1ee682f 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -121,17 +121,12 @@ func makeTestGraph() (*channeldb.ChannelGraph, func(), error) { return cdb.ChannelGraph(), cleanUp, nil } -// aliasMap is a map from a node's alias to its public key. This type is -// provided in order to allow easily look up from the human memorable alias -// to an exact node's public key. -type aliasMap map[string]*btcec.PublicKey - // parseTestGraph returns a fully populated ChannelGraph given a path to a JSON // file which encodes a test graph. -func parseTestGraph(path string) (*channeldb.ChannelGraph, func(), aliasMap, error) { +func parseTestGraph(path string) (*testGraphInstance, error) { graphJSON, err := ioutil.ReadFile(path) if err != nil { - return nil, nil, nil, err + return nil, err } // First unmarshal the JSON graph into an instance of the testGraph @@ -139,7 +134,7 @@ func parseTestGraph(path string) (*channeldb.ChannelGraph, func(), aliasMap, err // will be properly parsed into the struct above. var g testGraph if err := json.Unmarshal(graphJSON, &g); err != nil { - return nil, nil, nil, err + return nil, err } // We'll use this fake address for the IP address of all the nodes in @@ -148,14 +143,14 @@ func parseTestGraph(path string) (*channeldb.ChannelGraph, func(), aliasMap, err var testAddrs []net.Addr testAddr, err := net.ResolveTCPAddr("tcp", "192.0.0.1:8888") if err != nil { - return nil, nil, nil, err + return nil, err } testAddrs = append(testAddrs, testAddr) // Next, create a temporary graph database for usage within the test. graph, cleanUp, err := makeTestGraph() if err != nil { - return nil, nil, nil, err + return nil, err } aliasMap := make(map[string]*btcec.PublicKey) @@ -165,7 +160,7 @@ func parseTestGraph(path string) (*channeldb.ChannelGraph, func(), aliasMap, err for _, node := range g.Nodes { pubBytes, err := hex.DecodeString(node.PubKey) if err != nil { - return nil, nil, nil, err + return nil, err } dbNode := &channeldb.LightningNode{ @@ -181,13 +176,13 @@ func parseTestGraph(path string) (*channeldb.ChannelGraph, func(), aliasMap, err // We require all aliases within the graph to be unique for our // tests. if _, ok := aliasMap[node.Alias]; ok { - return nil, nil, nil, errors.New("aliases for nodes " + + return nil, errors.New("aliases for nodes " + "must be unique!") } pub, err := btcec.ParsePubKey(pubBytes, btcec.S256()) if err != nil { - return nil, nil, nil, err + return nil, err } // If the alias is unique, then add the node to the @@ -203,7 +198,7 @@ func parseTestGraph(path string) (*channeldb.ChannelGraph, func(), aliasMap, err // iteration, then the JSON has an error as only ONE // node can be the source in the graph. if source != nil { - return nil, nil, nil, errors.New("JSON is invalid " + + return nil, errors.New("JSON is invalid " + "multiple nodes are tagged as the source") } @@ -213,14 +208,14 @@ func parseTestGraph(path string) (*channeldb.ChannelGraph, func(), aliasMap, err // With the node fully parsed, add it as a vertex within the // graph. if err := graph.AddLightningNode(dbNode); err != nil { - return nil, nil, nil, err + return nil, err } } if source != nil { // Set the selected source node if err := graph.SetSourceNode(source); err != nil { - return nil, nil, nil, err + return nil, err } } @@ -229,18 +224,18 @@ func parseTestGraph(path string) (*channeldb.ChannelGraph, func(), aliasMap, err for _, edge := range g.Edges { node1Bytes, err := hex.DecodeString(edge.Node1) if err != nil { - return nil, nil, nil, err + return nil, err } node2Bytes, err := hex.DecodeString(edge.Node2) if err != nil { - return nil, nil, nil, err + return nil, err } fundingTXID := strings.Split(edge.ChannelPoint, ":")[0] txidBytes, err := chainhash.NewHashFromStr(fundingTXID) if err != nil { - return nil, nil, nil, err + return nil, err } fundingPoint := wire.OutPoint{ Hash: *txidBytes, @@ -263,7 +258,7 @@ func parseTestGraph(path string) (*channeldb.ChannelGraph, func(), aliasMap, err err = graph.AddChannelEdge(&edgeInfo) if err != nil && err != channeldb.ErrEdgeAlreadyExist { - return nil, nil, nil, err + return nil, err } edgePolicy := &channeldb.ChannelEdgePolicy{ @@ -277,11 +272,15 @@ func parseTestGraph(path string) (*channeldb.ChannelGraph, func(), aliasMap, err FeeProportionalMillionths: lnwire.MilliSatoshi(edge.FeeRate), } if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { - return nil, nil, nil, err + return nil, err } } - return graph, cleanUp, aliasMap, nil + return &testGraphInstance{ + graph: graph, + cleanUp: cleanUp, + aliasMap: aliasMap, + }, nil } type testChannelPolicy struct { @@ -329,26 +328,36 @@ type testChannel struct { Capacity btcutil.Amount } -// createTestGraph returns a fully populated ChannelGraph based on a set of +type testGraphInstance struct { + graph *channeldb.ChannelGraph + cleanUp func() + + // aliasMap is a map from a node's alias to its public key. This type is + // provided in order to allow easily look up from the human memorable alias + // to an exact node's public key. + aliasMap map[string]*btcec.PublicKey +} + +// createTestGraphFromChannels returns a fully populated ChannelGraph based on a set of // test channels. Additional required information like keys are derived in // a deterministical way and added to the channel graph. A list of nodes is // not required and derived from the channel data. The goal is to keep // instantiating a test channel graph as light weight as possible. -func createTestGraph(testChannels []*testChannel) (*channeldb.ChannelGraph, func(), aliasMap, error) { +func createTestGraphFromChannels(testChannels []*testChannel) (*testGraphInstance, error) { // We'll use this fake address for the IP address of all the nodes in // our tests. This value isn't needed for path finding so it doesn't // need to be unique. var testAddrs []net.Addr testAddr, err := net.ResolveTCPAddr("tcp", "192.0.0.1:8888") if err != nil { - return nil, nil, nil, err + return nil, err } testAddrs = append(testAddrs, testAddr) // Next, create a temporary graph database for usage within the test. graph, cleanUp, err := makeTestGraph() if err != nil { - return nil, nil, nil, err + return nil, err } aliasMap := make(map[string]*btcec.PublicKey) @@ -391,12 +400,12 @@ func createTestGraph(testChannels []*testChannel) (*channeldb.ChannelGraph, func var source *channeldb.LightningNode if source, err = addNodeWithAlias("roasbeef"); err != nil { - return nil, nil, nil, err + return nil, err } // Set the source node if err := graph.SetSourceNode(source); err != nil { - return nil, nil, nil, err + return nil, err } channelID := uint64(0) @@ -437,7 +446,7 @@ func createTestGraph(testChannels []*testChannel) (*channeldb.ChannelGraph, func err = graph.AddChannelEdge(&edgeInfo) if err != nil && err != channeldb.ErrEdgeAlreadyExist { - return nil, nil, nil, err + return nil, err } edgePolicy := &channeldb.ChannelEdgePolicy{ @@ -451,7 +460,7 @@ func createTestGraph(testChannels []*testChannel) (*channeldb.ChannelGraph, func FeeProportionalMillionths: testChannel.Node1.FeeRate, } if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { - return nil, nil, nil, err + return nil, err } edgePolicy = &channeldb.ChannelEdgePolicy{ @@ -466,13 +475,16 @@ func createTestGraph(testChannels []*testChannel) (*channeldb.ChannelGraph, func } if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { - return nil, nil, nil, err + return nil, err } channelID++ } - return graph, cleanUp, aliasMap, nil + return &testGraphInstance{ + graph: graph, + cleanUp: cleanUp, + aliasMap: aliasMap}, nil } // TestFindLowestFeePath tests that out of two routes with identical total @@ -513,13 +525,13 @@ func TestFindLowestFeePath(t *testing.T) { }), } - graph, cleanUp, aliases, err := createTestGraph(testChannels) - defer cleanUp() + testGraphInstance, err := createTestGraphFromChannels(testChannels) + defer testGraphInstance.cleanUp() if err != nil { t.Fatalf("unable to create graph: %v", err) } - sourceNode, err := graph.SourceNode() + sourceNode, err := testGraphInstance.graph.SourceNode() if err != nil { t.Fatalf("unable to fetch source node: %v", err) } @@ -534,10 +546,10 @@ func TestFindLowestFeePath(t *testing.T) { ) paymentAmt := lnwire.NewMSatFromSatoshis(100) - target := aliases["target"] + target := testGraphInstance.aliasMap["target"] path, err := findPath( - nil, graph, nil, sourceNode, target, ignoredVertexes, - ignoredEdges, paymentAmt, noFeeLimit, nil, + nil, testGraphInstance.graph, nil, sourceNode, target, + ignoredVertexes, ignoredEdges, paymentAmt, noFeeLimit, nil, ) if err != nil { t.Fatalf("unable to find path: %v", err) @@ -551,7 +563,7 @@ func TestFindLowestFeePath(t *testing.T) { // Assert that the lowest fee route is returned. if !bytes.Equal(route.Hops[1].Channel.Node.PubKeyBytes[:], - aliases["b"].SerializeCompressed()) { + testGraphInstance.aliasMap["b"].SerializeCompressed()) { t.Fatalf("expected route to pass through b, "+ "but got a route through %v", route.Hops[1].Channel.Node.Alias) @@ -621,8 +633,8 @@ var basicGraphPathFindingTests = []basicGraphPathFindingTestCase{ func TestBasicGraphPathFinding(t *testing.T) { t.Parallel() - graph, cleanUp, aliases, err := parseTestGraph(basicGraphFilePath) - defer cleanUp() + testGraphInstance, err := parseTestGraph(basicGraphFilePath) + defer testGraphInstance.cleanUp() if err != nil { t.Fatalf("unable to create graph: %v", err) } @@ -634,18 +646,19 @@ func TestBasicGraphPathFinding(t *testing.T) { for _, testCase := range basicGraphPathFindingTests { t.Run(testCase.target, func(subT *testing.T) { - testBasicGraphPathFindingCase(subT, graph, aliases, &testCase) + testBasicGraphPathFindingCase(subT, testGraphInstance, &testCase) }) } } -func testBasicGraphPathFindingCase(t *testing.T, graph *channeldb.ChannelGraph, - aliases aliasMap, test *basicGraphPathFindingTestCase) { +func testBasicGraphPathFindingCase(t *testing.T, graphInstance *testGraphInstance, + test *basicGraphPathFindingTestCase) { + aliases := graphInstance.aliasMap expectedHops := test.expectedHops expectedHopCount := len(expectedHops) - sourceNode, err := graph.SourceNode() + sourceNode, err := graphInstance.graph.SourceNode() if err != nil { t.Fatalf("unable to fetch source node: %v", err) } @@ -660,10 +673,10 @@ func testBasicGraphPathFindingCase(t *testing.T, graph *channeldb.ChannelGraph, ) paymentAmt := lnwire.NewMSatFromSatoshis(test.paymentAmt) - target := aliases[test.target] + target := graphInstance.aliasMap[test.target] path, err := findPath( - nil, graph, nil, sourceNode, target, ignoredVertexes, - ignoredEdges, paymentAmt, test.feeLimit, nil, + nil, graphInstance.graph, nil, sourceNode, target, + ignoredVertexes, ignoredEdges, paymentAmt, test.feeLimit, nil, ) if test.expectFailureNoPath { if err == nil { @@ -799,13 +812,13 @@ func testBasicGraphPathFindingCase(t *testing.T, graph *channeldb.ChannelGraph, func TestPathFindingWithAdditionalEdges(t *testing.T) { t.Parallel() - graph, cleanUp, aliases, err := parseTestGraph(basicGraphFilePath) - defer cleanUp() + graph, err := parseTestGraph(basicGraphFilePath) + defer graph.cleanUp() if err != nil { t.Fatalf("unable to create graph: %v", err) } - sourceNode, err := graph.SourceNode() + sourceNode, err := graph.graph.SourceNode() if err != nil { t.Fatalf("unable to fetch source node: %v", err) } @@ -842,12 +855,12 @@ func TestPathFindingWithAdditionalEdges(t *testing.T) { } additionalEdges := map[Vertex][]*channeldb.ChannelEdgePolicy{ - NewVertex(aliases["songoku"]): {songokuToDoge}, + NewVertex(graph.aliasMap["songoku"]): {songokuToDoge}, } // We should now be able to find a path from roasbeef to doge. path, err := findPath( - nil, graph, additionalEdges, sourceNode, dogePubKey, nil, nil, + nil, graph.graph, additionalEdges, sourceNode, dogePubKey, nil, nil, paymentAmt, noFeeLimit, nil, ) if err != nil { @@ -862,13 +875,13 @@ func TestPathFindingWithAdditionalEdges(t *testing.T) { func TestKShortestPathFinding(t *testing.T) { t.Parallel() - graph, cleanUp, aliases, err := parseTestGraph(basicGraphFilePath) - defer cleanUp() + graph, err := parseTestGraph(basicGraphFilePath) + defer graph.cleanUp() if err != nil { t.Fatalf("unable to create graph: %v", err) } - sourceNode, err := graph.SourceNode() + sourceNode, err := graph.graph.SourceNode() if err != nil { t.Fatalf("unable to fetch source node: %v", err) } @@ -882,9 +895,9 @@ func TestKShortestPathFinding(t *testing.T) { // them in order of their total "distance". paymentAmt := lnwire.NewMSatFromSatoshis(100) - target := aliases["luoji"] + target := graph.aliasMap["luoji"] paths, err := findPaths( - nil, graph, sourceNode, target, paymentAmt, noFeeLimit, 100, + nil, graph.graph, sourceNode, target, paymentAmt, noFeeLimit, 100, nil, ) if err != nil { @@ -1201,13 +1214,13 @@ func TestNewRoutePathTooLong(t *testing.T) { // Ensure that potential paths which are over the maximum hop-limit are // rejected. - graph, cleanUp, aliases, err := parseTestGraph(excessiveHopsGraphFilePath) - defer cleanUp() + graph, err := parseTestGraph(excessiveHopsGraphFilePath) + defer graph.cleanUp() if err != nil { t.Fatalf("unable to create graph: %v", err) } - sourceNode, err := graph.SourceNode() + sourceNode, err := graph.graph.SourceNode() if err != nil { t.Fatalf("unable to fetch source node: %v", err) } @@ -1219,9 +1232,9 @@ func TestNewRoutePathTooLong(t *testing.T) { // We start by confirming that routing a payment 20 hops away is possible. // Alice should be able to find a valid route to ursula. - target := aliases["ursula"] + target := graph.aliasMap["ursula"] _, err = findPath( - nil, graph, nil, sourceNode, target, ignoredVertexes, + nil, graph.graph, nil, sourceNode, target, ignoredVertexes, ignoredEdges, paymentAmt, noFeeLimit, nil, ) if err != nil { @@ -1230,9 +1243,9 @@ func TestNewRoutePathTooLong(t *testing.T) { // Vincent is 21 hops away from Alice, and thus no valid route should be // presented to Alice. - target = aliases["vincent"] + target = graph.aliasMap["vincent"] path, err := findPath( - nil, graph, nil, sourceNode, target, ignoredVertexes, + nil, graph.graph, nil, sourceNode, target, ignoredVertexes, ignoredEdges, paymentAmt, noFeeLimit, nil, ) if err == nil { @@ -1246,13 +1259,13 @@ func TestNewRoutePathTooLong(t *testing.T) { func TestPathNotAvailable(t *testing.T) { t.Parallel() - graph, cleanUp, _, err := parseTestGraph(basicGraphFilePath) - defer cleanUp() + graph, err := parseTestGraph(basicGraphFilePath) + defer graph.cleanUp() if err != nil { t.Fatalf("unable to create graph: %v", err) } - sourceNode, err := graph.SourceNode() + sourceNode, err := graph.graph.SourceNode() if err != nil { t.Fatalf("unable to fetch source node: %v", err) } @@ -1274,7 +1287,7 @@ func TestPathNotAvailable(t *testing.T) { } _, err = findPath( - nil, graph, nil, sourceNode, unknownNode, ignoredVertexes, + nil, graph.graph, nil, sourceNode, unknownNode, ignoredVertexes, ignoredEdges, 100, noFeeLimit, nil, ) if !IsError(err, ErrNoPathFound) { @@ -1285,13 +1298,13 @@ func TestPathNotAvailable(t *testing.T) { func TestPathInsufficientCapacity(t *testing.T) { t.Parallel() - graph, cleanUp, aliases, err := parseTestGraph(basicGraphFilePath) - defer cleanUp() + graph, err := parseTestGraph(basicGraphFilePath) + defer graph.cleanUp() if err != nil { t.Fatalf("unable to create graph: %v", err) } - sourceNode, err := graph.SourceNode() + sourceNode, err := graph.graph.SourceNode() if err != nil { t.Fatalf("unable to fetch source node: %v", err) } @@ -1306,11 +1319,11 @@ func TestPathInsufficientCapacity(t *testing.T) { // satoshis. The largest channel in the basic graph is of size 100k // satoshis, so we shouldn't be able to find a path to sophon even // though we have a 2-hop link. - target := aliases["sophon"] + target := graph.aliasMap["sophon"] payAmt := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin) _, err = findPath( - nil, graph, nil, sourceNode, target, ignoredVertexes, + nil, graph.graph, nil, sourceNode, target, ignoredVertexes, ignoredEdges, payAmt, noFeeLimit, nil, ) if !IsError(err, ErrNoPathFound) { @@ -1323,13 +1336,13 @@ func TestPathInsufficientCapacity(t *testing.T) { func TestRouteFailMinHTLC(t *testing.T) { t.Parallel() - graph, cleanUp, aliases, err := parseTestGraph(basicGraphFilePath) - defer cleanUp() + graph, err := parseTestGraph(basicGraphFilePath) + defer graph.cleanUp() if err != nil { t.Fatalf("unable to create graph: %v", err) } - sourceNode, err := graph.SourceNode() + sourceNode, err := graph.graph.SourceNode() if err != nil { t.Fatalf("unable to fetch source node: %v", err) } @@ -1339,10 +1352,10 @@ func TestRouteFailMinHTLC(t *testing.T) { // We'll not attempt to route an HTLC of 10 SAT from roasbeef to Son // Goku. However, the min HTLC of Son Goku is 1k SAT, as a result, this // attempt should fail. - target := aliases["songoku"] + target := graph.aliasMap["songoku"] payAmt := lnwire.MilliSatoshi(10) _, err = findPath( - nil, graph, nil, sourceNode, target, ignoredVertexes, + nil, graph.graph, nil, sourceNode, target, ignoredVertexes, ignoredEdges, payAmt, noFeeLimit, nil, ) if !IsError(err, ErrNoPathFound) { @@ -1356,13 +1369,13 @@ func TestRouteFailMinHTLC(t *testing.T) { func TestRouteFailDisabledEdge(t *testing.T) { t.Parallel() - graph, cleanUp, aliases, err := parseTestGraph(basicGraphFilePath) - defer cleanUp() + graph, err := parseTestGraph(basicGraphFilePath) + defer graph.cleanUp() if err != nil { t.Fatalf("unable to create graph: %v", err) } - sourceNode, err := graph.SourceNode() + sourceNode, err := graph.graph.SourceNode() if err != nil { t.Fatalf("unable to fetch source node: %v", err) } @@ -1371,10 +1384,10 @@ func TestRouteFailDisabledEdge(t *testing.T) { // First, we'll try to route from roasbeef -> sophon. This should // succeed without issue, and return a single path via phamnuwen - target := aliases["sophon"] + target := graph.aliasMap["sophon"] payAmt := lnwire.NewMSatFromSatoshis(105000) _, err = findPath( - nil, graph, nil, sourceNode, target, ignoredVertexes, + nil, graph.graph, nil, sourceNode, target, ignoredVertexes, ignoredEdges, payAmt, noFeeLimit, nil, ) if err != nil { @@ -1383,19 +1396,19 @@ func TestRouteFailDisabledEdge(t *testing.T) { // First, we'll modify the edge from roasbeef -> phamnuwen, to read that // it's disabled. - _, _, phamnuwenEdge, err := graph.FetchChannelEdgesByID(999991) + _, _, phamnuwenEdge, err := graph.graph.FetchChannelEdgesByID(999991) if err != nil { t.Fatalf("unable to fetch goku's edge: %v", err) } phamnuwenEdge.Flags = lnwire.ChanUpdateDisabled | lnwire.ChanUpdateDirection - if err := graph.UpdateEdgePolicy(phamnuwenEdge); err != nil { + if err := graph.graph.UpdateEdgePolicy(phamnuwenEdge); err != nil { t.Fatalf("unable to update edge: %v", err) } // Now, if we attempt to route through that edge, we should get a // failure as it is no longer eligible. _, err = findPath( - nil, graph, nil, sourceNode, target, ignoredVertexes, + nil, graph.graph, nil, sourceNode, target, ignoredVertexes, ignoredEdges, payAmt, noFeeLimit, nil, ) if !IsError(err, ErrNoPathFound) { @@ -1420,7 +1433,7 @@ func TestPathFindSpecExample(t *testing.T) { // we'll pass that in to ensure that the router uses 100 as the current // height. const startingHeight = 100 - ctx, cleanUp, err := createTestCtx(startingHeight, specExampleFilePath) + ctx, cleanUp, err := createTestCtxFromFile(startingHeight, specExampleFilePath) defer cleanUp() if err != nil { t.Fatalf("unable to create router: %v", err) diff --git a/routing/router_test.go b/routing/router_test.go index 4ad37b67..a1bcbdc4 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -74,51 +74,17 @@ func copyPubKey(pub *btcec.PublicKey) *btcec.PublicKey { } } -func createTestCtx(startingHeight uint32, testGraph ...string) (*testCtx, func(), error) { - var ( - graph *channeldb.ChannelGraph - sourceNode *channeldb.LightningNode - cleanup func() - err error - ) +func createTestCtxFromGraphInstance(startingHeight uint32, graphInstance *testGraphInstance) ( + *testCtx, func(), error) { - aliasMap := make(map[string]*btcec.PublicKey) - - // If the testGraph isn't set, then we'll create an empty graph to - // start out with. Our usage of a variadic parameter allows caller to - // omit the testGraph argument all together if they wish to start with - // a blank graph. - if testGraph == nil { - // First we'll set up a test graph for usage within the test. - graph, cleanup, err = makeTestGraph() - if err != nil { - return nil, nil, fmt.Errorf("unable to create test graph: %v", err) - } - - sourceNode, err = createTestNode() - if err != nil { - return nil, nil, fmt.Errorf("unable to create source node: %v", err) - } - if err = graph.SetSourceNode(sourceNode); err != nil { - return nil, nil, fmt.Errorf("unable to set source node: %v", err) - } - } else { - // Otherwise, we'll attempt to locate and parse out the file - // that encodes the graph that our tests should be run against. - graph, cleanup, aliasMap, err = parseTestGraph(testGraph[0]) - if err != nil { - return nil, nil, fmt.Errorf("unable to create test graph: %v", err) - } - } - - // Next we'll initialize an instance of the channel router with mock + // We'll initialize an instance of the channel router with mock // versions of the chain and channel notifier. As we don't need to test // any p2p functionality, the peer send and switch send messages won't // be populated. chain := newMockChain(startingHeight) chainView := newMockChainView(chain) router, err := New(Config{ - Graph: graph, + Graph: graphInstance.graph, Chain: chain, ChainView: chainView, SendToSwitch: func(_ lnwire.ShortChannelID, @@ -141,20 +107,60 @@ func createTestCtx(startingHeight uint32, testGraph ...string) (*testCtx, func() ctx := &testCtx{ router: router, - graph: graph, - aliases: aliasMap, + graph: graphInstance.graph, + aliases: graphInstance.aliasMap, chain: chain, chainView: chainView, } cleanUp := func() { ctx.router.Stop() - cleanup() + graphInstance.cleanUp() } return ctx, cleanUp, nil } +func createTestCtxSingleNode(startingHeight uint32) (*testCtx, func(), error) { + var ( + graph *channeldb.ChannelGraph + sourceNode *channeldb.LightningNode + cleanup func() + err error + ) + + graph, cleanup, err = makeTestGraph() + if err != nil { + return nil, nil, fmt.Errorf("unable to create test graph: %v", err) + } + + sourceNode, err = createTestNode() + if err != nil { + return nil, nil, fmt.Errorf("unable to create source node: %v", err) + } + if err = graph.SetSourceNode(sourceNode); err != nil { + return nil, nil, fmt.Errorf("unable to set source node: %v", err) + } + + graphInstance := &testGraphInstance{ + graph: graph, + cleanUp: cleanup, + } + + return createTestCtxFromGraphInstance(startingHeight, graphInstance) +} + +func createTestCtxFromFile(startingHeight uint32, testGraph string) (*testCtx, func(), error) { + // We'll attempt to locate and parse out the file + // that encodes the graph that our tests should be run against. + graphInstance, err := parseTestGraph(testGraph) + if err != nil { + return nil, nil, fmt.Errorf("unable to create test graph: %v", err) + } + + return createTestCtxFromGraphInstance(startingHeight, graphInstance) +} + // TestFindRoutesFeeSorting asserts that routes found by the FindRoutes method // within the channel router are properly returned in a sorted order, with the // lowest fee route coming first. @@ -162,7 +168,7 @@ func TestFindRoutesFeeSorting(t *testing.T) { t.Parallel() const startingBlockHeight = 101 - ctx, cleanUp, err := createTestCtx(startingBlockHeight, basicGraphFilePath) + ctx, cleanUp, err := createTestCtxFromFile(startingBlockHeight, basicGraphFilePath) defer cleanUp() if err != nil { t.Fatalf("unable to create router: %v", err) @@ -215,7 +221,7 @@ func TestFindRoutesWithFeeLimit(t *testing.T) { t.Parallel() const startingBlockHeight = 101 - ctx, cleanUp, err := createTestCtx( + ctx, cleanUp, err := createTestCtxFromFile( startingBlockHeight, basicGraphFilePath, ) defer cleanUp() @@ -269,7 +275,7 @@ func TestSendPaymentRouteFailureFallback(t *testing.T) { t.Parallel() const startingBlockHeight = 101 - ctx, cleanUp, err := createTestCtx(startingBlockHeight, basicGraphFilePath) + ctx, cleanUp, err := createTestCtxFromFile(startingBlockHeight, basicGraphFilePath) defer cleanUp() if err != nil { t.Fatalf("unable to create router: %v", err) @@ -348,7 +354,7 @@ func TestSendPaymentErrorRepeatedFeeInsufficient(t *testing.T) { t.Parallel() const startingBlockHeight = 101 - ctx, cleanUp, err := createTestCtx(startingBlockHeight, basicGraphFilePath) + ctx, cleanUp, err := createTestCtxFromFile(startingBlockHeight, basicGraphFilePath) defer cleanUp() if err != nil { t.Fatalf("unable to create router: %v", err) @@ -449,7 +455,7 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) { t.Parallel() const startingBlockHeight = 101 - ctx, cleanUp, err := createTestCtx(startingBlockHeight, basicGraphFilePath) + ctx, cleanUp, err := createTestCtxFromFile(startingBlockHeight, basicGraphFilePath) defer cleanUp() if err != nil { t.Fatalf("unable to create router: %v", err) @@ -581,7 +587,7 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { t.Parallel() const startingBlockHeight = 101 - ctx, cleanUp, err := createTestCtx(startingBlockHeight, basicGraphFilePath) + ctx, cleanUp, err := createTestCtxFromFile(startingBlockHeight, basicGraphFilePath) defer cleanUp() if err != nil { t.Fatalf("unable to create router: %v", err) @@ -753,7 +759,7 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { func TestAddProof(t *testing.T) { t.Parallel() - ctx, cleanup, err := createTestCtx(0) + ctx, cleanup, err := createTestCtxSingleNode(0) if err != nil { t.Fatal(err) } @@ -816,7 +822,7 @@ func TestIgnoreNodeAnnouncement(t *testing.T) { t.Parallel() const startingBlockHeight = 101 - ctx, cleanUp, err := createTestCtx(startingBlockHeight, + ctx, cleanUp, err := createTestCtxFromFile(startingBlockHeight, basicGraphFilePath) defer cleanUp() if err != nil { @@ -849,7 +855,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { t.Parallel() const startingBlockHeight = 101 - ctx, cleanUp, err := createTestCtx(startingBlockHeight, + ctx, cleanUp, err := createTestCtxFromFile(startingBlockHeight, basicGraphFilePath) defer cleanUp() if err != nil { @@ -1119,7 +1125,7 @@ func TestWakeUpOnStaleBranch(t *testing.T) { t.Parallel() const startingBlockHeight = 101 - ctx, cleanUp, err := createTestCtx(startingBlockHeight) + ctx, cleanUp, err := createTestCtxSingleNode(startingBlockHeight) defer cleanUp() if err != nil { t.Fatalf("unable to create router: %v", err) @@ -1322,7 +1328,7 @@ func TestDisconnectedBlocks(t *testing.T) { t.Parallel() const startingBlockHeight = 101 - ctx, cleanUp, err := createTestCtx(startingBlockHeight) + ctx, cleanUp, err := createTestCtxSingleNode(startingBlockHeight) defer cleanUp() if err != nil { t.Fatalf("unable to create router: %v", err) @@ -1512,7 +1518,7 @@ func TestRouterChansClosedOfflinePruneGraph(t *testing.T) { t.Parallel() const startingBlockHeight = 101 - ctx, cleanUp, err := createTestCtx(startingBlockHeight) + ctx, cleanUp, err := createTestCtxSingleNode(startingBlockHeight) defer cleanUp() if err != nil { t.Fatalf("unable to create router: %v", err) @@ -1665,7 +1671,7 @@ func TestFindPathFeeWeighting(t *testing.T) { t.Parallel() const startingBlockHeight = 101 - ctx, cleanUp, err := createTestCtx(startingBlockHeight, basicGraphFilePath) + ctx, cleanUp, err := createTestCtxFromFile(startingBlockHeight, basicGraphFilePath) defer cleanUp() if err != nil { t.Fatalf("unable to create router: %v", err) @@ -1716,7 +1722,7 @@ func TestIsStaleNode(t *testing.T) { t.Parallel() const startingBlockHeight = 101 - ctx, cleanUp, err := createTestCtx(startingBlockHeight) + ctx, cleanUp, err := createTestCtxSingleNode(startingBlockHeight) defer cleanUp() if err != nil { t.Fatalf("unable to create router: %v", err) @@ -1798,7 +1804,7 @@ func TestIsKnownEdge(t *testing.T) { t.Parallel() const startingBlockHeight = 101 - ctx, cleanUp, err := createTestCtx(startingBlockHeight) + ctx, cleanUp, err := createTestCtxSingleNode(startingBlockHeight) defer cleanUp() if err != nil { t.Fatalf("unable to create router: %v", err) @@ -1850,7 +1856,7 @@ func TestIsStaleEdgePolicy(t *testing.T) { t.Parallel() const startingBlockHeight = 101 - ctx, cleanUp, err := createTestCtx(startingBlockHeight, + ctx, cleanUp, err := createTestCtxFromFile(startingBlockHeight, basicGraphFilePath) defer cleanUp() if err != nil {