diff --git a/routing/integrated_routing_context_test.go b/routing/integrated_routing_context_test.go index 96c85d5f..d421c62f 100644 --- a/routing/integrated_routing_context_test.go +++ b/routing/integrated_routing_context_test.go @@ -11,6 +11,11 @@ import ( "github.com/lightningnetwork/lnd/routing/route" ) +const ( + sourceNodeID = 1 + targetNodeID = 2 +) + // integratedRoutingContext defines the context in which integrated routing // tests run. type integratedRoutingContext struct { @@ -31,8 +36,8 @@ type integratedRoutingContext struct { // context with a source and a target node. func newIntegratedRoutingContext(t *testing.T) *integratedRoutingContext { // Instantiate a mock graph. - source := newMockNode() - target := newMockNode() + source := newMockNode(sourceNodeID) + target := newMockNode(targetNodeID) graph := newMockGraph(t) graph.addNode(source) diff --git a/routing/integrated_routing_test.go b/routing/integrated_routing_test.go index ccd22117..0362f208 100644 --- a/routing/integrated_routing_test.go +++ b/routing/integrated_routing_test.go @@ -15,22 +15,25 @@ func TestProbabilityExtrapolation(t *testing.T) { // source -> intermediate1 (free routing) -> intermediate(1-10) (free routing) -> target g := ctx.graph - expensiveNode := newMockNode() + const expensiveNodeID = 3 + expensiveNode := newMockNode(expensiveNodeID) expensiveNode.baseFee = 10000 g.addNode(expensiveNode) - g.addChannel(ctx.source, expensiveNode, 100000) - g.addChannel(ctx.target, expensiveNode, 100000) + g.addChannel(100, sourceNodeID, expensiveNodeID, 100000) + g.addChannel(101, targetNodeID, expensiveNodeID, 100000) - intermediate1 := newMockNode() + const intermediate1NodeID = 4 + intermediate1 := newMockNode(intermediate1NodeID) g.addNode(intermediate1) - g.addChannel(ctx.source, intermediate1, 100000) + g.addChannel(102, sourceNodeID, intermediate1NodeID, 100000) for i := 0; i < 10; i++ { - imNode := newMockNode() + imNodeID := byte(10 + i) + imNode := newMockNode(imNodeID) g.addNode(imNode) - g.addChannel(imNode, ctx.target, 100000) - g.addChannel(imNode, intermediate1, 100000) + g.addChannel(uint64(200+i), imNodeID, targetNodeID, 100000) + g.addChannel(uint64(300+i), imNodeID, intermediate1NodeID, 100000) // The channels from intermediate1 all have insufficient balance. g.nodes[intermediate1.pubkey].channels[imNode.pubkey].balance = 0 diff --git a/routing/mock_graph_test.go b/routing/mock_graph_test.go index 075a416a..3834d9e5 100644 --- a/routing/mock_graph_test.go +++ b/routing/mock_graph_test.go @@ -11,14 +11,9 @@ import ( "github.com/lightningnetwork/lnd/routing/route" ) -// nextTestPubkey is global variable that is used to deterministically generate -// test keys. -var nextTestPubkey byte - // createPubkey return a new test pubkey. -func createPubkey() route.Vertex { - pubkey := route.Vertex{nextTestPubkey} - nextTestPubkey++ +func createPubkey(id byte) route.Vertex { + pubkey := route.Vertex{id} return pubkey } @@ -38,8 +33,8 @@ type mockNode struct { } // newMockNode instantiates a new mock node with a newly generated pubkey. -func newMockNode() *mockNode { - pubkey := createPubkey() +func newMockNode(id byte) *mockNode { + pubkey := createPubkey(id) return &mockNode{ channels: make(map[route.Vertex]*mockChannel), pubkey: pubkey, @@ -106,10 +101,9 @@ func (m *mockNode) fwd(from *mockNode, route *hop) (htlcResult, error) { // mockGraph contains a set of nodes that together for a mocked graph. type mockGraph struct { - t *testing.T - nodes map[route.Vertex]*mockNode - nextChanID uint64 - source *mockNode + t *testing.T + nodes map[route.Vertex]*mockNode + source *mockNode } // newMockGraph instantiates a new mock graph. @@ -122,6 +116,11 @@ func newMockGraph(t *testing.T) *mockGraph { // addNode adds the given mock node to the network. func (m *mockGraph) addNode(node *mockNode) { + m.t.Helper() + + if _, exists := m.nodes[node.pubkey]; exists { + m.t.Fatal("node already exists") + } m.nodes[node.pubkey] = node } @@ -131,16 +130,25 @@ func (m *mockGraph) addNode(node *mockNode) { // Ignore linter error because addChannel isn't yet called with different // capacities. // nolint:unparam -func (m *mockGraph) addChannel(node1, node2 *mockNode, capacity btcutil.Amount) { - id := m.nextChanID - m.nextChanID++ +func (m *mockGraph) addChannel(id uint64, node1id, node2id byte, + capacity btcutil.Amount) { - m.nodes[node1.pubkey].channels[node2.pubkey] = &mockChannel{ + node1pubkey := createPubkey(node1id) + node2pubkey := createPubkey(node2id) + + if _, exists := m.nodes[node1pubkey].channels[node2pubkey]; exists { + m.t.Fatal("channel already exists") + } + if _, exists := m.nodes[node2pubkey].channels[node1pubkey]; exists { + m.t.Fatal("channel already exists") + } + + m.nodes[node1pubkey].channels[node2pubkey] = &mockChannel{ capacity: capacity, id: id, balance: lnwire.NewMSatFromSatoshis(capacity / 2), } - m.nodes[node2.pubkey].channels[node1.pubkey] = &mockChannel{ + m.nodes[node2pubkey].channels[node1pubkey] = &mockChannel{ capacity: capacity, id: id, balance: lnwire.NewMSatFromSatoshis(capacity / 2),