routing: restructure test context creation

This commit is contained in:
Joost Jager 2018-08-16 21:35:38 +02:00
parent 2d255e3bc3
commit 5daf75b264
No known key found for this signature in database
GPG Key ID: AE6B0D042C8E38D9
3 changed files with 168 additions and 149 deletions

@ -339,7 +339,7 @@ func (m *mockChainView) Stop() error {
func TestEdgeUpdateNotification(t *testing.T) { func TestEdgeUpdateNotification(t *testing.T) {
t.Parallel() t.Parallel()
ctx, cleanUp, err := createTestCtx(0) ctx, cleanUp, err := createTestCtxSingleNode(0)
defer cleanUp() defer cleanUp()
if err != nil { if err != nil {
t.Fatalf("unable to create router: %v", err) t.Fatalf("unable to create router: %v", err)
@ -528,7 +528,7 @@ func TestNodeUpdateNotification(t *testing.T) {
t.Parallel() t.Parallel()
const startingBlockHeight = 101 const startingBlockHeight = 101
ctx, cleanUp, err := createTestCtx(startingBlockHeight) ctx, cleanUp, err := createTestCtxSingleNode(startingBlockHeight)
defer cleanUp() defer cleanUp()
if err != nil { if err != nil {
t.Fatalf("unable to create router: %v", err) t.Fatalf("unable to create router: %v", err)
@ -706,7 +706,7 @@ func TestNotificationCancellation(t *testing.T) {
t.Parallel() t.Parallel()
const startingBlockHeight = 101 const startingBlockHeight = 101
ctx, cleanUp, err := createTestCtx(startingBlockHeight) ctx, cleanUp, err := createTestCtxSingleNode(startingBlockHeight)
defer cleanUp() defer cleanUp()
if err != nil { if err != nil {
t.Fatalf("unable to create router: %v", err) t.Fatalf("unable to create router: %v", err)
@ -798,7 +798,7 @@ func TestChannelCloseNotification(t *testing.T) {
t.Parallel() t.Parallel()
const startingBlockHeight = 101 const startingBlockHeight = 101
ctx, cleanUp, err := createTestCtx(startingBlockHeight) ctx, cleanUp, err := createTestCtxSingleNode(startingBlockHeight)
defer cleanUp() defer cleanUp()
if err != nil { if err != nil {
t.Fatalf("unable to create router: %v", err) t.Fatalf("unable to create router: %v", err)

@ -121,17 +121,12 @@ func makeTestGraph() (*channeldb.ChannelGraph, func(), error) {
return cdb.ChannelGraph(), cleanUp, nil return cdb.ChannelGraph(), cleanUp, nil
} }
// aliasMap is a map from a node's alias to its public key. This type is
// provided in order to allow easily look up from the human memorable alias
// to an exact node's public key.
type aliasMap map[string]*btcec.PublicKey
// parseTestGraph returns a fully populated ChannelGraph given a path to a JSON // parseTestGraph returns a fully populated ChannelGraph given a path to a JSON
// file which encodes a test graph. // file which encodes a test graph.
func parseTestGraph(path string) (*channeldb.ChannelGraph, func(), aliasMap, error) { func parseTestGraph(path string) (*testGraphInstance, error) {
graphJSON, err := ioutil.ReadFile(path) graphJSON, err := ioutil.ReadFile(path)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, err
} }
// First unmarshal the JSON graph into an instance of the testGraph // First unmarshal the JSON graph into an instance of the testGraph
@ -139,7 +134,7 @@ func parseTestGraph(path string) (*channeldb.ChannelGraph, func(), aliasMap, err
// will be properly parsed into the struct above. // will be properly parsed into the struct above.
var g testGraph var g testGraph
if err := json.Unmarshal(graphJSON, &g); err != nil { if err := json.Unmarshal(graphJSON, &g); err != nil {
return nil, nil, nil, err return nil, err
} }
// We'll use this fake address for the IP address of all the nodes in // We'll use this fake address for the IP address of all the nodes in
@ -148,14 +143,14 @@ func parseTestGraph(path string) (*channeldb.ChannelGraph, func(), aliasMap, err
var testAddrs []net.Addr var testAddrs []net.Addr
testAddr, err := net.ResolveTCPAddr("tcp", "192.0.0.1:8888") testAddr, err := net.ResolveTCPAddr("tcp", "192.0.0.1:8888")
if err != nil { if err != nil {
return nil, nil, nil, err return nil, err
} }
testAddrs = append(testAddrs, testAddr) testAddrs = append(testAddrs, testAddr)
// Next, create a temporary graph database for usage within the test. // Next, create a temporary graph database for usage within the test.
graph, cleanUp, err := makeTestGraph() graph, cleanUp, err := makeTestGraph()
if err != nil { if err != nil {
return nil, nil, nil, err return nil, err
} }
aliasMap := make(map[string]*btcec.PublicKey) aliasMap := make(map[string]*btcec.PublicKey)
@ -165,7 +160,7 @@ func parseTestGraph(path string) (*channeldb.ChannelGraph, func(), aliasMap, err
for _, node := range g.Nodes { for _, node := range g.Nodes {
pubBytes, err := hex.DecodeString(node.PubKey) pubBytes, err := hex.DecodeString(node.PubKey)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, err
} }
dbNode := &channeldb.LightningNode{ dbNode := &channeldb.LightningNode{
@ -181,13 +176,13 @@ func parseTestGraph(path string) (*channeldb.ChannelGraph, func(), aliasMap, err
// We require all aliases within the graph to be unique for our // We require all aliases within the graph to be unique for our
// tests. // tests.
if _, ok := aliasMap[node.Alias]; ok { if _, ok := aliasMap[node.Alias]; ok {
return nil, nil, nil, errors.New("aliases for nodes " + return nil, errors.New("aliases for nodes " +
"must be unique!") "must be unique!")
} }
pub, err := btcec.ParsePubKey(pubBytes, btcec.S256()) pub, err := btcec.ParsePubKey(pubBytes, btcec.S256())
if err != nil { if err != nil {
return nil, nil, nil, err return nil, err
} }
// If the alias is unique, then add the node to the // If the alias is unique, then add the node to the
@ -203,7 +198,7 @@ func parseTestGraph(path string) (*channeldb.ChannelGraph, func(), aliasMap, err
// iteration, then the JSON has an error as only ONE // iteration, then the JSON has an error as only ONE
// node can be the source in the graph. // node can be the source in the graph.
if source != nil { if source != nil {
return nil, nil, nil, errors.New("JSON is invalid " + return nil, errors.New("JSON is invalid " +
"multiple nodes are tagged as the source") "multiple nodes are tagged as the source")
} }
@ -213,14 +208,14 @@ func parseTestGraph(path string) (*channeldb.ChannelGraph, func(), aliasMap, err
// With the node fully parsed, add it as a vertex within the // With the node fully parsed, add it as a vertex within the
// graph. // graph.
if err := graph.AddLightningNode(dbNode); err != nil { if err := graph.AddLightningNode(dbNode); err != nil {
return nil, nil, nil, err return nil, err
} }
} }
if source != nil { if source != nil {
// Set the selected source node // Set the selected source node
if err := graph.SetSourceNode(source); err != nil { if err := graph.SetSourceNode(source); err != nil {
return nil, nil, nil, err return nil, err
} }
} }
@ -229,18 +224,18 @@ func parseTestGraph(path string) (*channeldb.ChannelGraph, func(), aliasMap, err
for _, edge := range g.Edges { for _, edge := range g.Edges {
node1Bytes, err := hex.DecodeString(edge.Node1) node1Bytes, err := hex.DecodeString(edge.Node1)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, err
} }
node2Bytes, err := hex.DecodeString(edge.Node2) node2Bytes, err := hex.DecodeString(edge.Node2)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, err
} }
fundingTXID := strings.Split(edge.ChannelPoint, ":")[0] fundingTXID := strings.Split(edge.ChannelPoint, ":")[0]
txidBytes, err := chainhash.NewHashFromStr(fundingTXID) txidBytes, err := chainhash.NewHashFromStr(fundingTXID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, err
} }
fundingPoint := wire.OutPoint{ fundingPoint := wire.OutPoint{
Hash: *txidBytes, Hash: *txidBytes,
@ -263,7 +258,7 @@ func parseTestGraph(path string) (*channeldb.ChannelGraph, func(), aliasMap, err
err = graph.AddChannelEdge(&edgeInfo) err = graph.AddChannelEdge(&edgeInfo)
if err != nil && err != channeldb.ErrEdgeAlreadyExist { if err != nil && err != channeldb.ErrEdgeAlreadyExist {
return nil, nil, nil, err return nil, err
} }
edgePolicy := &channeldb.ChannelEdgePolicy{ edgePolicy := &channeldb.ChannelEdgePolicy{
@ -277,11 +272,15 @@ func parseTestGraph(path string) (*channeldb.ChannelGraph, func(), aliasMap, err
FeeProportionalMillionths: lnwire.MilliSatoshi(edge.FeeRate), FeeProportionalMillionths: lnwire.MilliSatoshi(edge.FeeRate),
} }
if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { if err := graph.UpdateEdgePolicy(edgePolicy); err != nil {
return nil, nil, nil, err return nil, err
} }
} }
return graph, cleanUp, aliasMap, nil return &testGraphInstance{
graph: graph,
cleanUp: cleanUp,
aliasMap: aliasMap,
}, nil
} }
type testChannelPolicy struct { type testChannelPolicy struct {
@ -329,26 +328,36 @@ type testChannel struct {
Capacity btcutil.Amount Capacity btcutil.Amount
} }
// createTestGraph returns a fully populated ChannelGraph based on a set of type testGraphInstance struct {
graph *channeldb.ChannelGraph
cleanUp func()
// aliasMap is a map from a node's alias to its public key. This type is
// provided in order to allow easily look up from the human memorable alias
// to an exact node's public key.
aliasMap map[string]*btcec.PublicKey
}
// createTestGraphFromChannels returns a fully populated ChannelGraph based on a set of
// test channels. Additional required information like keys are derived in // test channels. Additional required information like keys are derived in
// a deterministical way and added to the channel graph. A list of nodes is // a deterministical way and added to the channel graph. A list of nodes is
// not required and derived from the channel data. The goal is to keep // not required and derived from the channel data. The goal is to keep
// instantiating a test channel graph as light weight as possible. // instantiating a test channel graph as light weight as possible.
func createTestGraph(testChannels []*testChannel) (*channeldb.ChannelGraph, func(), aliasMap, error) { func createTestGraphFromChannels(testChannels []*testChannel) (*testGraphInstance, error) {
// We'll use this fake address for the IP address of all the nodes in // We'll use this fake address for the IP address of all the nodes in
// our tests. This value isn't needed for path finding so it doesn't // our tests. This value isn't needed for path finding so it doesn't
// need to be unique. // need to be unique.
var testAddrs []net.Addr var testAddrs []net.Addr
testAddr, err := net.ResolveTCPAddr("tcp", "192.0.0.1:8888") testAddr, err := net.ResolveTCPAddr("tcp", "192.0.0.1:8888")
if err != nil { if err != nil {
return nil, nil, nil, err return nil, err
} }
testAddrs = append(testAddrs, testAddr) testAddrs = append(testAddrs, testAddr)
// Next, create a temporary graph database for usage within the test. // Next, create a temporary graph database for usage within the test.
graph, cleanUp, err := makeTestGraph() graph, cleanUp, err := makeTestGraph()
if err != nil { if err != nil {
return nil, nil, nil, err return nil, err
} }
aliasMap := make(map[string]*btcec.PublicKey) aliasMap := make(map[string]*btcec.PublicKey)
@ -391,12 +400,12 @@ func createTestGraph(testChannels []*testChannel) (*channeldb.ChannelGraph, func
var source *channeldb.LightningNode var source *channeldb.LightningNode
if source, err = addNodeWithAlias("roasbeef"); err != nil { if source, err = addNodeWithAlias("roasbeef"); err != nil {
return nil, nil, nil, err return nil, err
} }
// Set the source node // Set the source node
if err := graph.SetSourceNode(source); err != nil { if err := graph.SetSourceNode(source); err != nil {
return nil, nil, nil, err return nil, err
} }
channelID := uint64(0) channelID := uint64(0)
@ -437,7 +446,7 @@ func createTestGraph(testChannels []*testChannel) (*channeldb.ChannelGraph, func
err = graph.AddChannelEdge(&edgeInfo) err = graph.AddChannelEdge(&edgeInfo)
if err != nil && err != channeldb.ErrEdgeAlreadyExist { if err != nil && err != channeldb.ErrEdgeAlreadyExist {
return nil, nil, nil, err return nil, err
} }
edgePolicy := &channeldb.ChannelEdgePolicy{ edgePolicy := &channeldb.ChannelEdgePolicy{
@ -451,7 +460,7 @@ func createTestGraph(testChannels []*testChannel) (*channeldb.ChannelGraph, func
FeeProportionalMillionths: testChannel.Node1.FeeRate, FeeProportionalMillionths: testChannel.Node1.FeeRate,
} }
if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { if err := graph.UpdateEdgePolicy(edgePolicy); err != nil {
return nil, nil, nil, err return nil, err
} }
edgePolicy = &channeldb.ChannelEdgePolicy{ edgePolicy = &channeldb.ChannelEdgePolicy{
@ -466,13 +475,16 @@ func createTestGraph(testChannels []*testChannel) (*channeldb.ChannelGraph, func
} }
if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { if err := graph.UpdateEdgePolicy(edgePolicy); err != nil {
return nil, nil, nil, err return nil, err
} }
channelID++ channelID++
} }
return graph, cleanUp, aliasMap, nil return &testGraphInstance{
graph: graph,
cleanUp: cleanUp,
aliasMap: aliasMap}, nil
} }
// TestFindLowestFeePath tests that out of two routes with identical total // TestFindLowestFeePath tests that out of two routes with identical total
@ -513,13 +525,13 @@ func TestFindLowestFeePath(t *testing.T) {
}), }),
} }
graph, cleanUp, aliases, err := createTestGraph(testChannels) testGraphInstance, err := createTestGraphFromChannels(testChannels)
defer cleanUp() defer testGraphInstance.cleanUp()
if err != nil { if err != nil {
t.Fatalf("unable to create graph: %v", err) t.Fatalf("unable to create graph: %v", err)
} }
sourceNode, err := graph.SourceNode() sourceNode, err := testGraphInstance.graph.SourceNode()
if err != nil { if err != nil {
t.Fatalf("unable to fetch source node: %v", err) t.Fatalf("unable to fetch source node: %v", err)
} }
@ -534,10 +546,10 @@ func TestFindLowestFeePath(t *testing.T) {
) )
paymentAmt := lnwire.NewMSatFromSatoshis(100) paymentAmt := lnwire.NewMSatFromSatoshis(100)
target := aliases["target"] target := testGraphInstance.aliasMap["target"]
path, err := findPath( path, err := findPath(
nil, graph, nil, sourceNode, target, ignoredVertexes, nil, testGraphInstance.graph, nil, sourceNode, target,
ignoredEdges, paymentAmt, noFeeLimit, nil, ignoredVertexes, ignoredEdges, paymentAmt, noFeeLimit, nil,
) )
if err != nil { if err != nil {
t.Fatalf("unable to find path: %v", err) t.Fatalf("unable to find path: %v", err)
@ -551,7 +563,7 @@ func TestFindLowestFeePath(t *testing.T) {
// Assert that the lowest fee route is returned. // Assert that the lowest fee route is returned.
if !bytes.Equal(route.Hops[1].Channel.Node.PubKeyBytes[:], if !bytes.Equal(route.Hops[1].Channel.Node.PubKeyBytes[:],
aliases["b"].SerializeCompressed()) { testGraphInstance.aliasMap["b"].SerializeCompressed()) {
t.Fatalf("expected route to pass through b, "+ t.Fatalf("expected route to pass through b, "+
"but got a route through %v", "but got a route through %v",
route.Hops[1].Channel.Node.Alias) route.Hops[1].Channel.Node.Alias)
@ -621,8 +633,8 @@ var basicGraphPathFindingTests = []basicGraphPathFindingTestCase{
func TestBasicGraphPathFinding(t *testing.T) { func TestBasicGraphPathFinding(t *testing.T) {
t.Parallel() t.Parallel()
graph, cleanUp, aliases, err := parseTestGraph(basicGraphFilePath) testGraphInstance, err := parseTestGraph(basicGraphFilePath)
defer cleanUp() defer testGraphInstance.cleanUp()
if err != nil { if err != nil {
t.Fatalf("unable to create graph: %v", err) t.Fatalf("unable to create graph: %v", err)
} }
@ -634,18 +646,19 @@ func TestBasicGraphPathFinding(t *testing.T) {
for _, testCase := range basicGraphPathFindingTests { for _, testCase := range basicGraphPathFindingTests {
t.Run(testCase.target, func(subT *testing.T) { t.Run(testCase.target, func(subT *testing.T) {
testBasicGraphPathFindingCase(subT, graph, aliases, &testCase) testBasicGraphPathFindingCase(subT, testGraphInstance, &testCase)
}) })
} }
} }
func testBasicGraphPathFindingCase(t *testing.T, graph *channeldb.ChannelGraph, func testBasicGraphPathFindingCase(t *testing.T, graphInstance *testGraphInstance,
aliases aliasMap, test *basicGraphPathFindingTestCase) { test *basicGraphPathFindingTestCase) {
aliases := graphInstance.aliasMap
expectedHops := test.expectedHops expectedHops := test.expectedHops
expectedHopCount := len(expectedHops) expectedHopCount := len(expectedHops)
sourceNode, err := graph.SourceNode() sourceNode, err := graphInstance.graph.SourceNode()
if err != nil { if err != nil {
t.Fatalf("unable to fetch source node: %v", err) t.Fatalf("unable to fetch source node: %v", err)
} }
@ -660,10 +673,10 @@ func testBasicGraphPathFindingCase(t *testing.T, graph *channeldb.ChannelGraph,
) )
paymentAmt := lnwire.NewMSatFromSatoshis(test.paymentAmt) paymentAmt := lnwire.NewMSatFromSatoshis(test.paymentAmt)
target := aliases[test.target] target := graphInstance.aliasMap[test.target]
path, err := findPath( path, err := findPath(
nil, graph, nil, sourceNode, target, ignoredVertexes, nil, graphInstance.graph, nil, sourceNode, target,
ignoredEdges, paymentAmt, test.feeLimit, nil, ignoredVertexes, ignoredEdges, paymentAmt, test.feeLimit, nil,
) )
if test.expectFailureNoPath { if test.expectFailureNoPath {
if err == nil { if err == nil {
@ -799,13 +812,13 @@ func testBasicGraphPathFindingCase(t *testing.T, graph *channeldb.ChannelGraph,
func TestPathFindingWithAdditionalEdges(t *testing.T) { func TestPathFindingWithAdditionalEdges(t *testing.T) {
t.Parallel() t.Parallel()
graph, cleanUp, aliases, err := parseTestGraph(basicGraphFilePath) graph, err := parseTestGraph(basicGraphFilePath)
defer cleanUp() defer graph.cleanUp()
if err != nil { if err != nil {
t.Fatalf("unable to create graph: %v", err) t.Fatalf("unable to create graph: %v", err)
} }
sourceNode, err := graph.SourceNode() sourceNode, err := graph.graph.SourceNode()
if err != nil { if err != nil {
t.Fatalf("unable to fetch source node: %v", err) t.Fatalf("unable to fetch source node: %v", err)
} }
@ -842,12 +855,12 @@ func TestPathFindingWithAdditionalEdges(t *testing.T) {
} }
additionalEdges := map[Vertex][]*channeldb.ChannelEdgePolicy{ additionalEdges := map[Vertex][]*channeldb.ChannelEdgePolicy{
NewVertex(aliases["songoku"]): {songokuToDoge}, NewVertex(graph.aliasMap["songoku"]): {songokuToDoge},
} }
// We should now be able to find a path from roasbeef to doge. // We should now be able to find a path from roasbeef to doge.
path, err := findPath( path, err := findPath(
nil, graph, additionalEdges, sourceNode, dogePubKey, nil, nil, nil, graph.graph, additionalEdges, sourceNode, dogePubKey, nil, nil,
paymentAmt, noFeeLimit, nil, paymentAmt, noFeeLimit, nil,
) )
if err != nil { if err != nil {
@ -862,13 +875,13 @@ func TestPathFindingWithAdditionalEdges(t *testing.T) {
func TestKShortestPathFinding(t *testing.T) { func TestKShortestPathFinding(t *testing.T) {
t.Parallel() t.Parallel()
graph, cleanUp, aliases, err := parseTestGraph(basicGraphFilePath) graph, err := parseTestGraph(basicGraphFilePath)
defer cleanUp() defer graph.cleanUp()
if err != nil { if err != nil {
t.Fatalf("unable to create graph: %v", err) t.Fatalf("unable to create graph: %v", err)
} }
sourceNode, err := graph.SourceNode() sourceNode, err := graph.graph.SourceNode()
if err != nil { if err != nil {
t.Fatalf("unable to fetch source node: %v", err) t.Fatalf("unable to fetch source node: %v", err)
} }
@ -882,9 +895,9 @@ func TestKShortestPathFinding(t *testing.T) {
// them in order of their total "distance". // them in order of their total "distance".
paymentAmt := lnwire.NewMSatFromSatoshis(100) paymentAmt := lnwire.NewMSatFromSatoshis(100)
target := aliases["luoji"] target := graph.aliasMap["luoji"]
paths, err := findPaths( paths, err := findPaths(
nil, graph, sourceNode, target, paymentAmt, noFeeLimit, 100, nil, graph.graph, sourceNode, target, paymentAmt, noFeeLimit, 100,
nil, nil,
) )
if err != nil { if err != nil {
@ -1201,13 +1214,13 @@ func TestNewRoutePathTooLong(t *testing.T) {
// Ensure that potential paths which are over the maximum hop-limit are // Ensure that potential paths which are over the maximum hop-limit are
// rejected. // rejected.
graph, cleanUp, aliases, err := parseTestGraph(excessiveHopsGraphFilePath) graph, err := parseTestGraph(excessiveHopsGraphFilePath)
defer cleanUp() defer graph.cleanUp()
if err != nil { if err != nil {
t.Fatalf("unable to create graph: %v", err) t.Fatalf("unable to create graph: %v", err)
} }
sourceNode, err := graph.SourceNode() sourceNode, err := graph.graph.SourceNode()
if err != nil { if err != nil {
t.Fatalf("unable to fetch source node: %v", err) t.Fatalf("unable to fetch source node: %v", err)
} }
@ -1219,9 +1232,9 @@ func TestNewRoutePathTooLong(t *testing.T) {
// We start by confirming that routing a payment 20 hops away is possible. // We start by confirming that routing a payment 20 hops away is possible.
// Alice should be able to find a valid route to ursula. // Alice should be able to find a valid route to ursula.
target := aliases["ursula"] target := graph.aliasMap["ursula"]
_, err = findPath( _, err = findPath(
nil, graph, nil, sourceNode, target, ignoredVertexes, nil, graph.graph, nil, sourceNode, target, ignoredVertexes,
ignoredEdges, paymentAmt, noFeeLimit, nil, ignoredEdges, paymentAmt, noFeeLimit, nil,
) )
if err != nil { if err != nil {
@ -1230,9 +1243,9 @@ func TestNewRoutePathTooLong(t *testing.T) {
// Vincent is 21 hops away from Alice, and thus no valid route should be // Vincent is 21 hops away from Alice, and thus no valid route should be
// presented to Alice. // presented to Alice.
target = aliases["vincent"] target = graph.aliasMap["vincent"]
path, err := findPath( path, err := findPath(
nil, graph, nil, sourceNode, target, ignoredVertexes, nil, graph.graph, nil, sourceNode, target, ignoredVertexes,
ignoredEdges, paymentAmt, noFeeLimit, nil, ignoredEdges, paymentAmt, noFeeLimit, nil,
) )
if err == nil { if err == nil {
@ -1246,13 +1259,13 @@ func TestNewRoutePathTooLong(t *testing.T) {
func TestPathNotAvailable(t *testing.T) { func TestPathNotAvailable(t *testing.T) {
t.Parallel() t.Parallel()
graph, cleanUp, _, err := parseTestGraph(basicGraphFilePath) graph, err := parseTestGraph(basicGraphFilePath)
defer cleanUp() defer graph.cleanUp()
if err != nil { if err != nil {
t.Fatalf("unable to create graph: %v", err) t.Fatalf("unable to create graph: %v", err)
} }
sourceNode, err := graph.SourceNode() sourceNode, err := graph.graph.SourceNode()
if err != nil { if err != nil {
t.Fatalf("unable to fetch source node: %v", err) t.Fatalf("unable to fetch source node: %v", err)
} }
@ -1274,7 +1287,7 @@ func TestPathNotAvailable(t *testing.T) {
} }
_, err = findPath( _, err = findPath(
nil, graph, nil, sourceNode, unknownNode, ignoredVertexes, nil, graph.graph, nil, sourceNode, unknownNode, ignoredVertexes,
ignoredEdges, 100, noFeeLimit, nil, ignoredEdges, 100, noFeeLimit, nil,
) )
if !IsError(err, ErrNoPathFound) { if !IsError(err, ErrNoPathFound) {
@ -1285,13 +1298,13 @@ func TestPathNotAvailable(t *testing.T) {
func TestPathInsufficientCapacity(t *testing.T) { func TestPathInsufficientCapacity(t *testing.T) {
t.Parallel() t.Parallel()
graph, cleanUp, aliases, err := parseTestGraph(basicGraphFilePath) graph, err := parseTestGraph(basicGraphFilePath)
defer cleanUp() defer graph.cleanUp()
if err != nil { if err != nil {
t.Fatalf("unable to create graph: %v", err) t.Fatalf("unable to create graph: %v", err)
} }
sourceNode, err := graph.SourceNode() sourceNode, err := graph.graph.SourceNode()
if err != nil { if err != nil {
t.Fatalf("unable to fetch source node: %v", err) t.Fatalf("unable to fetch source node: %v", err)
} }
@ -1306,11 +1319,11 @@ func TestPathInsufficientCapacity(t *testing.T) {
// satoshis. The largest channel in the basic graph is of size 100k // satoshis. The largest channel in the basic graph is of size 100k
// satoshis, so we shouldn't be able to find a path to sophon even // satoshis, so we shouldn't be able to find a path to sophon even
// though we have a 2-hop link. // though we have a 2-hop link.
target := aliases["sophon"] target := graph.aliasMap["sophon"]
payAmt := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin) payAmt := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin)
_, err = findPath( _, err = findPath(
nil, graph, nil, sourceNode, target, ignoredVertexes, nil, graph.graph, nil, sourceNode, target, ignoredVertexes,
ignoredEdges, payAmt, noFeeLimit, nil, ignoredEdges, payAmt, noFeeLimit, nil,
) )
if !IsError(err, ErrNoPathFound) { if !IsError(err, ErrNoPathFound) {
@ -1323,13 +1336,13 @@ func TestPathInsufficientCapacity(t *testing.T) {
func TestRouteFailMinHTLC(t *testing.T) { func TestRouteFailMinHTLC(t *testing.T) {
t.Parallel() t.Parallel()
graph, cleanUp, aliases, err := parseTestGraph(basicGraphFilePath) graph, err := parseTestGraph(basicGraphFilePath)
defer cleanUp() defer graph.cleanUp()
if err != nil { if err != nil {
t.Fatalf("unable to create graph: %v", err) t.Fatalf("unable to create graph: %v", err)
} }
sourceNode, err := graph.SourceNode() sourceNode, err := graph.graph.SourceNode()
if err != nil { if err != nil {
t.Fatalf("unable to fetch source node: %v", err) t.Fatalf("unable to fetch source node: %v", err)
} }
@ -1339,10 +1352,10 @@ func TestRouteFailMinHTLC(t *testing.T) {
// We'll not attempt to route an HTLC of 10 SAT from roasbeef to Son // We'll not attempt to route an HTLC of 10 SAT from roasbeef to Son
// Goku. However, the min HTLC of Son Goku is 1k SAT, as a result, this // Goku. However, the min HTLC of Son Goku is 1k SAT, as a result, this
// attempt should fail. // attempt should fail.
target := aliases["songoku"] target := graph.aliasMap["songoku"]
payAmt := lnwire.MilliSatoshi(10) payAmt := lnwire.MilliSatoshi(10)
_, err = findPath( _, err = findPath(
nil, graph, nil, sourceNode, target, ignoredVertexes, nil, graph.graph, nil, sourceNode, target, ignoredVertexes,
ignoredEdges, payAmt, noFeeLimit, nil, ignoredEdges, payAmt, noFeeLimit, nil,
) )
if !IsError(err, ErrNoPathFound) { if !IsError(err, ErrNoPathFound) {
@ -1356,13 +1369,13 @@ func TestRouteFailMinHTLC(t *testing.T) {
func TestRouteFailDisabledEdge(t *testing.T) { func TestRouteFailDisabledEdge(t *testing.T) {
t.Parallel() t.Parallel()
graph, cleanUp, aliases, err := parseTestGraph(basicGraphFilePath) graph, err := parseTestGraph(basicGraphFilePath)
defer cleanUp() defer graph.cleanUp()
if err != nil { if err != nil {
t.Fatalf("unable to create graph: %v", err) t.Fatalf("unable to create graph: %v", err)
} }
sourceNode, err := graph.SourceNode() sourceNode, err := graph.graph.SourceNode()
if err != nil { if err != nil {
t.Fatalf("unable to fetch source node: %v", err) t.Fatalf("unable to fetch source node: %v", err)
} }
@ -1371,10 +1384,10 @@ func TestRouteFailDisabledEdge(t *testing.T) {
// First, we'll try to route from roasbeef -> sophon. This should // First, we'll try to route from roasbeef -> sophon. This should
// succeed without issue, and return a single path via phamnuwen // succeed without issue, and return a single path via phamnuwen
target := aliases["sophon"] target := graph.aliasMap["sophon"]
payAmt := lnwire.NewMSatFromSatoshis(105000) payAmt := lnwire.NewMSatFromSatoshis(105000)
_, err = findPath( _, err = findPath(
nil, graph, nil, sourceNode, target, ignoredVertexes, nil, graph.graph, nil, sourceNode, target, ignoredVertexes,
ignoredEdges, payAmt, noFeeLimit, nil, ignoredEdges, payAmt, noFeeLimit, nil,
) )
if err != nil { if err != nil {
@ -1383,19 +1396,19 @@ func TestRouteFailDisabledEdge(t *testing.T) {
// First, we'll modify the edge from roasbeef -> phamnuwen, to read that // First, we'll modify the edge from roasbeef -> phamnuwen, to read that
// it's disabled. // it's disabled.
_, _, phamnuwenEdge, err := graph.FetchChannelEdgesByID(999991) _, _, phamnuwenEdge, err := graph.graph.FetchChannelEdgesByID(999991)
if err != nil { if err != nil {
t.Fatalf("unable to fetch goku's edge: %v", err) t.Fatalf("unable to fetch goku's edge: %v", err)
} }
phamnuwenEdge.Flags = lnwire.ChanUpdateDisabled | lnwire.ChanUpdateDirection phamnuwenEdge.Flags = lnwire.ChanUpdateDisabled | lnwire.ChanUpdateDirection
if err := graph.UpdateEdgePolicy(phamnuwenEdge); err != nil { if err := graph.graph.UpdateEdgePolicy(phamnuwenEdge); err != nil {
t.Fatalf("unable to update edge: %v", err) t.Fatalf("unable to update edge: %v", err)
} }
// Now, if we attempt to route through that edge, we should get a // Now, if we attempt to route through that edge, we should get a
// failure as it is no longer eligible. // failure as it is no longer eligible.
_, err = findPath( _, err = findPath(
nil, graph, nil, sourceNode, target, ignoredVertexes, nil, graph.graph, nil, sourceNode, target, ignoredVertexes,
ignoredEdges, payAmt, noFeeLimit, nil, ignoredEdges, payAmt, noFeeLimit, nil,
) )
if !IsError(err, ErrNoPathFound) { if !IsError(err, ErrNoPathFound) {
@ -1420,7 +1433,7 @@ func TestPathFindSpecExample(t *testing.T) {
// we'll pass that in to ensure that the router uses 100 as the current // we'll pass that in to ensure that the router uses 100 as the current
// height. // height.
const startingHeight = 100 const startingHeight = 100
ctx, cleanUp, err := createTestCtx(startingHeight, specExampleFilePath) ctx, cleanUp, err := createTestCtxFromFile(startingHeight, specExampleFilePath)
defer cleanUp() defer cleanUp()
if err != nil { if err != nil {
t.Fatalf("unable to create router: %v", err) t.Fatalf("unable to create router: %v", err)

@ -74,51 +74,17 @@ func copyPubKey(pub *btcec.PublicKey) *btcec.PublicKey {
} }
} }
func createTestCtx(startingHeight uint32, testGraph ...string) (*testCtx, func(), error) { func createTestCtxFromGraphInstance(startingHeight uint32, graphInstance *testGraphInstance) (
var ( *testCtx, func(), error) {
graph *channeldb.ChannelGraph
sourceNode *channeldb.LightningNode
cleanup func()
err error
)
aliasMap := make(map[string]*btcec.PublicKey) // We'll initialize an instance of the channel router with mock
// If the testGraph isn't set, then we'll create an empty graph to
// start out with. Our usage of a variadic parameter allows caller to
// omit the testGraph argument all together if they wish to start with
// a blank graph.
if testGraph == nil {
// First we'll set up a test graph for usage within the test.
graph, cleanup, err = makeTestGraph()
if err != nil {
return nil, nil, fmt.Errorf("unable to create test graph: %v", err)
}
sourceNode, err = createTestNode()
if err != nil {
return nil, nil, fmt.Errorf("unable to create source node: %v", err)
}
if err = graph.SetSourceNode(sourceNode); err != nil {
return nil, nil, fmt.Errorf("unable to set source node: %v", err)
}
} else {
// Otherwise, we'll attempt to locate and parse out the file
// that encodes the graph that our tests should be run against.
graph, cleanup, aliasMap, err = parseTestGraph(testGraph[0])
if err != nil {
return nil, nil, fmt.Errorf("unable to create test graph: %v", err)
}
}
// Next we'll initialize an instance of the channel router with mock
// versions of the chain and channel notifier. As we don't need to test // versions of the chain and channel notifier. As we don't need to test
// any p2p functionality, the peer send and switch send messages won't // any p2p functionality, the peer send and switch send messages won't
// be populated. // be populated.
chain := newMockChain(startingHeight) chain := newMockChain(startingHeight)
chainView := newMockChainView(chain) chainView := newMockChainView(chain)
router, err := New(Config{ router, err := New(Config{
Graph: graph, Graph: graphInstance.graph,
Chain: chain, Chain: chain,
ChainView: chainView, ChainView: chainView,
SendToSwitch: func(_ lnwire.ShortChannelID, SendToSwitch: func(_ lnwire.ShortChannelID,
@ -141,20 +107,60 @@ func createTestCtx(startingHeight uint32, testGraph ...string) (*testCtx, func()
ctx := &testCtx{ ctx := &testCtx{
router: router, router: router,
graph: graph, graph: graphInstance.graph,
aliases: aliasMap, aliases: graphInstance.aliasMap,
chain: chain, chain: chain,
chainView: chainView, chainView: chainView,
} }
cleanUp := func() { cleanUp := func() {
ctx.router.Stop() ctx.router.Stop()
cleanup() graphInstance.cleanUp()
} }
return ctx, cleanUp, nil return ctx, cleanUp, nil
} }
func createTestCtxSingleNode(startingHeight uint32) (*testCtx, func(), error) {
var (
graph *channeldb.ChannelGraph
sourceNode *channeldb.LightningNode
cleanup func()
err error
)
graph, cleanup, err = makeTestGraph()
if err != nil {
return nil, nil, fmt.Errorf("unable to create test graph: %v", err)
}
sourceNode, err = createTestNode()
if err != nil {
return nil, nil, fmt.Errorf("unable to create source node: %v", err)
}
if err = graph.SetSourceNode(sourceNode); err != nil {
return nil, nil, fmt.Errorf("unable to set source node: %v", err)
}
graphInstance := &testGraphInstance{
graph: graph,
cleanUp: cleanup,
}
return createTestCtxFromGraphInstance(startingHeight, graphInstance)
}
func createTestCtxFromFile(startingHeight uint32, testGraph string) (*testCtx, func(), error) {
// We'll attempt to locate and parse out the file
// that encodes the graph that our tests should be run against.
graphInstance, err := parseTestGraph(testGraph)
if err != nil {
return nil, nil, fmt.Errorf("unable to create test graph: %v", err)
}
return createTestCtxFromGraphInstance(startingHeight, graphInstance)
}
// TestFindRoutesFeeSorting asserts that routes found by the FindRoutes method // TestFindRoutesFeeSorting asserts that routes found by the FindRoutes method
// within the channel router are properly returned in a sorted order, with the // within the channel router are properly returned in a sorted order, with the
// lowest fee route coming first. // lowest fee route coming first.
@ -162,7 +168,7 @@ func TestFindRoutesFeeSorting(t *testing.T) {
t.Parallel() t.Parallel()
const startingBlockHeight = 101 const startingBlockHeight = 101
ctx, cleanUp, err := createTestCtx(startingBlockHeight, basicGraphFilePath) ctx, cleanUp, err := createTestCtxFromFile(startingBlockHeight, basicGraphFilePath)
defer cleanUp() defer cleanUp()
if err != nil { if err != nil {
t.Fatalf("unable to create router: %v", err) t.Fatalf("unable to create router: %v", err)
@ -215,7 +221,7 @@ func TestFindRoutesWithFeeLimit(t *testing.T) {
t.Parallel() t.Parallel()
const startingBlockHeight = 101 const startingBlockHeight = 101
ctx, cleanUp, err := createTestCtx( ctx, cleanUp, err := createTestCtxFromFile(
startingBlockHeight, basicGraphFilePath, startingBlockHeight, basicGraphFilePath,
) )
defer cleanUp() defer cleanUp()
@ -269,7 +275,7 @@ func TestSendPaymentRouteFailureFallback(t *testing.T) {
t.Parallel() t.Parallel()
const startingBlockHeight = 101 const startingBlockHeight = 101
ctx, cleanUp, err := createTestCtx(startingBlockHeight, basicGraphFilePath) ctx, cleanUp, err := createTestCtxFromFile(startingBlockHeight, basicGraphFilePath)
defer cleanUp() defer cleanUp()
if err != nil { if err != nil {
t.Fatalf("unable to create router: %v", err) t.Fatalf("unable to create router: %v", err)
@ -348,7 +354,7 @@ func TestSendPaymentErrorRepeatedFeeInsufficient(t *testing.T) {
t.Parallel() t.Parallel()
const startingBlockHeight = 101 const startingBlockHeight = 101
ctx, cleanUp, err := createTestCtx(startingBlockHeight, basicGraphFilePath) ctx, cleanUp, err := createTestCtxFromFile(startingBlockHeight, basicGraphFilePath)
defer cleanUp() defer cleanUp()
if err != nil { if err != nil {
t.Fatalf("unable to create router: %v", err) t.Fatalf("unable to create router: %v", err)
@ -449,7 +455,7 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) {
t.Parallel() t.Parallel()
const startingBlockHeight = 101 const startingBlockHeight = 101
ctx, cleanUp, err := createTestCtx(startingBlockHeight, basicGraphFilePath) ctx, cleanUp, err := createTestCtxFromFile(startingBlockHeight, basicGraphFilePath)
defer cleanUp() defer cleanUp()
if err != nil { if err != nil {
t.Fatalf("unable to create router: %v", err) t.Fatalf("unable to create router: %v", err)
@ -581,7 +587,7 @@ func TestSendPaymentErrorPathPruning(t *testing.T) {
t.Parallel() t.Parallel()
const startingBlockHeight = 101 const startingBlockHeight = 101
ctx, cleanUp, err := createTestCtx(startingBlockHeight, basicGraphFilePath) ctx, cleanUp, err := createTestCtxFromFile(startingBlockHeight, basicGraphFilePath)
defer cleanUp() defer cleanUp()
if err != nil { if err != nil {
t.Fatalf("unable to create router: %v", err) t.Fatalf("unable to create router: %v", err)
@ -753,7 +759,7 @@ func TestSendPaymentErrorPathPruning(t *testing.T) {
func TestAddProof(t *testing.T) { func TestAddProof(t *testing.T) {
t.Parallel() t.Parallel()
ctx, cleanup, err := createTestCtx(0) ctx, cleanup, err := createTestCtxSingleNode(0)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -816,7 +822,7 @@ func TestIgnoreNodeAnnouncement(t *testing.T) {
t.Parallel() t.Parallel()
const startingBlockHeight = 101 const startingBlockHeight = 101
ctx, cleanUp, err := createTestCtx(startingBlockHeight, ctx, cleanUp, err := createTestCtxFromFile(startingBlockHeight,
basicGraphFilePath) basicGraphFilePath)
defer cleanUp() defer cleanUp()
if err != nil { if err != nil {
@ -849,7 +855,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) {
t.Parallel() t.Parallel()
const startingBlockHeight = 101 const startingBlockHeight = 101
ctx, cleanUp, err := createTestCtx(startingBlockHeight, ctx, cleanUp, err := createTestCtxFromFile(startingBlockHeight,
basicGraphFilePath) basicGraphFilePath)
defer cleanUp() defer cleanUp()
if err != nil { if err != nil {
@ -1119,7 +1125,7 @@ func TestWakeUpOnStaleBranch(t *testing.T) {
t.Parallel() t.Parallel()
const startingBlockHeight = 101 const startingBlockHeight = 101
ctx, cleanUp, err := createTestCtx(startingBlockHeight) ctx, cleanUp, err := createTestCtxSingleNode(startingBlockHeight)
defer cleanUp() defer cleanUp()
if err != nil { if err != nil {
t.Fatalf("unable to create router: %v", err) t.Fatalf("unable to create router: %v", err)
@ -1322,7 +1328,7 @@ func TestDisconnectedBlocks(t *testing.T) {
t.Parallel() t.Parallel()
const startingBlockHeight = 101 const startingBlockHeight = 101
ctx, cleanUp, err := createTestCtx(startingBlockHeight) ctx, cleanUp, err := createTestCtxSingleNode(startingBlockHeight)
defer cleanUp() defer cleanUp()
if err != nil { if err != nil {
t.Fatalf("unable to create router: %v", err) t.Fatalf("unable to create router: %v", err)
@ -1512,7 +1518,7 @@ func TestRouterChansClosedOfflinePruneGraph(t *testing.T) {
t.Parallel() t.Parallel()
const startingBlockHeight = 101 const startingBlockHeight = 101
ctx, cleanUp, err := createTestCtx(startingBlockHeight) ctx, cleanUp, err := createTestCtxSingleNode(startingBlockHeight)
defer cleanUp() defer cleanUp()
if err != nil { if err != nil {
t.Fatalf("unable to create router: %v", err) t.Fatalf("unable to create router: %v", err)
@ -1665,7 +1671,7 @@ func TestFindPathFeeWeighting(t *testing.T) {
t.Parallel() t.Parallel()
const startingBlockHeight = 101 const startingBlockHeight = 101
ctx, cleanUp, err := createTestCtx(startingBlockHeight, basicGraphFilePath) ctx, cleanUp, err := createTestCtxFromFile(startingBlockHeight, basicGraphFilePath)
defer cleanUp() defer cleanUp()
if err != nil { if err != nil {
t.Fatalf("unable to create router: %v", err) t.Fatalf("unable to create router: %v", err)
@ -1716,7 +1722,7 @@ func TestIsStaleNode(t *testing.T) {
t.Parallel() t.Parallel()
const startingBlockHeight = 101 const startingBlockHeight = 101
ctx, cleanUp, err := createTestCtx(startingBlockHeight) ctx, cleanUp, err := createTestCtxSingleNode(startingBlockHeight)
defer cleanUp() defer cleanUp()
if err != nil { if err != nil {
t.Fatalf("unable to create router: %v", err) t.Fatalf("unable to create router: %v", err)
@ -1798,7 +1804,7 @@ func TestIsKnownEdge(t *testing.T) {
t.Parallel() t.Parallel()
const startingBlockHeight = 101 const startingBlockHeight = 101
ctx, cleanUp, err := createTestCtx(startingBlockHeight) ctx, cleanUp, err := createTestCtxSingleNode(startingBlockHeight)
defer cleanUp() defer cleanUp()
if err != nil { if err != nil {
t.Fatalf("unable to create router: %v", err) t.Fatalf("unable to create router: %v", err)
@ -1850,7 +1856,7 @@ func TestIsStaleEdgePolicy(t *testing.T) {
t.Parallel() t.Parallel()
const startingBlockHeight = 101 const startingBlockHeight = 101
ctx, cleanUp, err := createTestCtx(startingBlockHeight, ctx, cleanUp, err := createTestCtxFromFile(startingBlockHeight,
basicGraphFilePath) basicGraphFilePath)
defer cleanUp() defer cleanUp()
if err != nil { if err != nil {