diff --git a/autopilot/betweenness_centrality.go b/autopilot/betweenness_centrality.go
new file mode 100644
index 00000000..bfc15e36
--- /dev/null
+++ b/autopilot/betweenness_centrality.go
@@ -0,0 +1,212 @@
+package autopilot
+
+// stack is a simple int stack to help with readability of Brandes'
+// betweenness centrality implementation below.
+type stack struct {
+	stack []int
+}
+
+func (s *stack) push(v int) {
+	s.stack = append(s.stack, v)
+}
+
+func (s *stack) top() int {
+	return s.stack[len(s.stack)-1]
+}
+
+func (s *stack) pop() {
+	s.stack = s.stack[:len(s.stack)-1]
+}
+
+func (s *stack) empty() bool {
+	return len(s.stack) == 0
+}
+
+// queue is a simple int queue to help with readability of Brandes'
+// betweenness centrality implementation below.
+type queue struct {
+	queue []int
+}
+
+func (q *queue) push(v int) {
+	q.queue = append(q.queue, v)
+}
+
+func (q *queue) front() int {
+	return q.queue[0]
+}
+
+func (q *queue) pop() {
+	q.queue = q.queue[1:]
+}
+
+func (q *queue) empty() bool {
+	return len(q.queue) == 0
+}
+
+// BetweennessCentrality is a NodeMetric that calculates node betweenness
+// centrality using Brandes' algorithm. Betweenness centrality for each node
+// is the number of shortest paths passing through that node, not counting
+// shortest paths starting or ending at that node. This is a useful metric
+// to measure control of individual nodes over the whole network.
+type BetweennessCentrality struct {
+	// centrality stores original (not normalized) centrality values for
+	// each node in the graph.
+	centrality map[NodeID]float64
+
+	// min is the minimum centrality in the graph.
+	min float64
+
+	// max is the maximum centrality in the graph.
+	max float64
+}
+
+// NewBetweennessCentralityMetric creates a new BetweennessCentrality instance.
+func NewBetweennessCentralityMetric() *BetweennessCentrality {
+	return &BetweennessCentrality{}
+}
+
+// Name returns the name of the metric.
+func (bc *BetweennessCentrality) Name() string {
+	return "betweenness_centrality"
+}
+
+// betweennessCentrality is the core of Brandes' algorithm.
+// We first calculate the shortest paths from the start node s to all other
+// nodes with BFS, then update the betweenness centrality values by using
+// Brandes' dependency trick.
+// For a detailed explanation please read:
+// https://www.cl.cam.ac.uk/teaching/1617/MLRD/handbook/brandes.html
+func betweennessCentrality(g *SimpleGraph, s int, centrality []float64) {
+	// pred[w] is the list of nodes that immediately precede w on a
+	// shortest path from s to t for each node t.
+	pred := make([][]int, len(g.Nodes))
+
+	// sigma[t] is the number of shortest paths between nodes s and t
+	// for each node t.
+	sigma := make([]int, len(g.Nodes))
+	sigma[s] = 1
+
+	// dist[t] holds the distance between s and t for each node t. We
+	// initialize this to -1 (meaning infinity) for each t != s.
+	dist := make([]int, len(g.Nodes))
+	for i := range dist {
+		dist[i] = -1
+	}
+
+	dist[s] = 0
+
+	var (
+		st stack
+		q  queue
+	)
+	q.push(s)
+
+	// BFS to calculate the shortest paths (sigma and pred)
+	// from s to t for each node t.
+	for !q.empty() {
+		v := q.front()
+		q.pop()
+		st.push(v)
+
+		for _, w := range g.Adj[v] {
+			// If the distance from s to w is infinity (-1)
+			// then set it and enqueue w.
+			if dist[w] < 0 {
+				dist[w] = dist[v] + 1
+				q.push(w)
+			}
+
+			// If w is on a shortest path then update
+			// sigma and add v to w's predecessor list.
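+			// Note that dist[w] == dist[v]+1 also holds when w
+			// was already discovered through another node at the
+			// same depth as v, so w can accumulate several
+			// predecessors and their shortest path counts.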
+			if dist[w] == dist[v]+1 {
+				sigma[w] += sigma[v]
+				pred[w] = append(pred[w], v)
+			}
+		}
+	}
+
+	// delta[v] is the sum over all targets t of the ratio of the number
+	// of shortest paths between s and t that go through v to the total
+	// number of shortest paths between s and t.
+	// If we have delta then the betweenness centrality is simply the sum
+	// of delta[w] for each w != s.
+	delta := make([]float64, len(g.Nodes))
+
+	for !st.empty() {
+		w := st.top()
+		st.pop()
+
+		// pred[w] is the list of nodes that immediately precede w on a
+		// shortest path from s.
+		for _, v := range pred[w] {
+			// Update delta using Brandes' equation.
+			delta[v] += (float64(sigma[v]) / float64(sigma[w])) *
+				(1.0 + delta[w])
+		}
+
+		if w != s {
+			// As noted above, centrality is simply the sum
+			// of delta[w] for each w != s.
+			centrality[w] += delta[w]
+		}
+	}
+}
+
+// Refresh recalculates and stores centrality values.
+func (bc *BetweennessCentrality) Refresh(graph ChannelGraph) error {
+	cache, err := NewSimpleGraph(graph)
+	if err != nil {
+		return err
+	}
+
+	// TODO: parallelize updates to centrality.
+	centrality := make([]float64, len(cache.Nodes))
+	for node := range cache.Nodes {
+		betweennessCentrality(cache, node, centrality)
+	}
+
+	// Get min/max to be able to normalize
+	// centrality values between 0 and 1.
+	bc.min = 0
+	bc.max = 0
+	if len(centrality) > 0 {
+		bc.min = centrality[0]
+		bc.max = centrality[0]
+		for i := 1; i < len(centrality); i++ {
+			if centrality[i] < bc.min {
+				bc.min = centrality[i]
+			} else if centrality[i] > bc.max {
+				bc.max = centrality[i]
+			}
+		}
+	}
+
+	// Divide by two as this is an undirected graph.
+	bc.min /= 2.0
+	bc.max /= 2.0
+
+	bc.centrality = make(map[NodeID]float64)
+	for u, value := range centrality {
+		// Divide by two as this is an undirected graph.
+		bc.centrality[cache.Nodes[u]] = value / 2.0
+	}
+
+	return nil
+}
+
+// GetMetric returns the current centrality values for each node indexed
+// by node id.
+func (bc *BetweennessCentrality) GetMetric(normalize bool) map[NodeID]float64 {
+	// Normalization factor.
+	var z float64
+	if (bc.max - bc.min) > 0 {
+		z = 1.0 / (bc.max - bc.min)
+	}
+
+	centrality := make(map[NodeID]float64)
+	for k, v := range bc.centrality {
+		if normalize {
+			v = (v - bc.min) * z
+		}
+		centrality[k] = v
+	}
+
+	return centrality
+}
diff --git a/autopilot/betweenness_centrality_test.go b/autopilot/betweenness_centrality_test.go
new file mode 100644
index 00000000..09fd7fc3
--- /dev/null
+++ b/autopilot/betweenness_centrality_test.go
@@ -0,0 +1,156 @@
+package autopilot
+
+import (
+	"testing"
+
+	"github.com/btcsuite/btcd/btcec"
+	"github.com/btcsuite/btcutil"
+)
+
+// Tests that an empty graph results in an empty centrality result.
+func TestBetweennessCentralityEmptyGraph(t *testing.T) {
+	centralityMetric := NewBetweennessCentralityMetric()
+
+	for _, chanGraph := range chanGraphs {
+		graph, cleanup, err := chanGraph.genFunc()
+		success := t.Run(chanGraph.name, func(t1 *testing.T) {
+			if err != nil {
+				t1.Fatalf("unable to create graph: %v", err)
+			}
+			if cleanup != nil {
+				defer cleanup()
+			}
+
+			if err := centralityMetric.Refresh(graph); err != nil {
+				t1.Fatalf("unexpected failure during metric refresh: %v", err)
+			}
+
+			centrality := centralityMetric.GetMetric(false)
+			if len(centrality) > 0 {
+				t1.Fatalf("expected empty metric, got: %v", len(centrality))
+			}
+
+			centrality = centralityMetric.GetMetric(true)
+			if len(centrality) > 0 {
+				t1.Fatalf("expected empty metric, got: %v", len(centrality))
+			}
+		})
+		if !success {
+			break
+		}
+	}
+}
+
+// testGraphDesc is a helper type to describe a test graph.
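+// The edges map lists, for each node index, the indices of the peers it is
+// connected to; buildTestGraph below opens one channel per listed pair.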
+type testGraphDesc struct {
+	nodes int
+	edges map[int][]int
+}
+
+// buildTestGraph builds a test graph from a passed graph descriptor.
+func buildTestGraph(t *testing.T,
+	graph testGraph, desc testGraphDesc) map[int]*btcec.PublicKey {
+
+	nodes := make(map[int]*btcec.PublicKey)
+
+	for i := 0; i < desc.nodes; i++ {
+		key, err := graph.addRandNode()
+		if err != nil {
+			t.Fatalf("cannot create random node: %v", err)
+		}
+
+		nodes[i] = key
+	}
+
+	const chanCapacity = btcutil.SatoshiPerBitcoin
+	for u, neighbors := range desc.edges {
+		for _, v := range neighbors {
+			_, _, err := graph.addRandChannel(nodes[u], nodes[v], chanCapacity)
+			if err != nil {
+				t.Fatalf("unexpected error while adding random channel: %v", err)
+			}
+		}
+	}
+
+	return nodes
+}
+
+// Tests betweenness centrality calculation using an example graph.
+func TestBetweennessCentralityWithNonEmptyGraph(t *testing.T) {
+	graphDesc := testGraphDesc{
+		nodes: 9,
+		edges: map[int][]int{
+			0: {1, 2, 3},
+			1: {2},
+			2: {3},
+			3: {4, 5},
+			4: {5, 6, 7},
+			5: {6, 7},
+			6: {7, 8},
+		},
+	}
+
+	tests := []struct {
+		normalize  bool
+		centrality []float64
+	}{
+		{
+			normalize: true,
+			centrality: []float64{
+				0.2, 0.0, 0.2, 1.0, 0.4, 0.4, 7.0 / 15.0, 0.0, 0.0,
+			},
+		},
+		{
+			normalize: false,
+			centrality: []float64{
+				3.0, 0.0, 3.0, 15.0, 6.0, 6.0, 7.0, 0.0, 0.0,
+			},
+		},
+	}
+
+	for _, chanGraph := range chanGraphs {
+		graph, cleanup, err := chanGraph.genFunc()
+		if err != nil {
+			t.Fatalf("unable to create graph: %v", err)
+		}
+		if cleanup != nil {
+			defer cleanup()
+		}
+
+		success := t.Run(chanGraph.name, func(t1 *testing.T) {
+			centralityMetric := NewBetweennessCentralityMetric()
+			graphNodes := buildTestGraph(t1, graph, graphDesc)
+
+			if err := centralityMetric.Refresh(graph); err != nil {
+				t1.Fatalf("error while calculating betweenness centrality: %v", err)
+			}
+			for _, test := range tests {
+				test := test
+				centrality := centralityMetric.GetMetric(test.normalize)
+
+				if len(centrality) != graphDesc.nodes {
+					t1.Fatalf("expected %v values, got: %v",
+						graphDesc.nodes, len(centrality))
+				}
+
+				for node, nodeCentrality := range test.centrality {
+					nodeID := NewNodeID(graphNodes[node])
+					calculatedCentrality, ok := centrality[nodeID]
+					if !ok {
+						t1.Fatalf("no result for node: %x (%v)", nodeID, node)
+					}
+
+					if nodeCentrality != calculatedCentrality {
+						t1.Errorf("centrality for node: %v should be %v, got: %v",
+							node, test.centrality[node], calculatedCentrality)
+					}
+				}
+			}
+		})
+		if !success {
+			break
+		}
+	}
+}
diff --git a/autopilot/simple_graph.go b/autopilot/simple_graph.go
new file mode 100644
index 00000000..208a784e
--- /dev/null
+++ b/autopilot/simple_graph.go
@@ -0,0 +1,66 @@
+package autopilot
+
+// SimpleGraph stores a simplified adjacency graph of a channel graph to
+// speed up graph processing by eliminating all unnecessary hashing and map
+// access.
+type SimpleGraph struct {
+	// Nodes maps node indices to NodeIDs.
+	Nodes []NodeID
+
+	// Adj stores nodes and neighbors in an adjacency list.
+	Adj [][]int
+}
+
+// NewSimpleGraph creates a simplified graph from the current channel graph.
+// Returns an error if iteration over the channel graph fails due to an
+// underlying failure.
+func NewSimpleGraph(g ChannelGraph) (*SimpleGraph, error) {
+	nodes := make(map[NodeID]int)
+	adj := make(map[int][]int)
+	nextIndex := 0
+
+	// getNodeIndex returns the integer index of the passed node.
+	// The returned index is then used to create a simplified adjacency
+	// list where each node is identified by its index instead of its
+	// pubkey, and also to create a mapping from node index to node pubkey.
+	getNodeIndex := func(node Node) int {
+		key := NodeID(node.PubKey())
+		nodeIndex, ok := nodes[key]
+
+		// Nodes seen for the first time are assigned the next free
+		// index.
+		if !ok {
+			nodes[key] = nextIndex
+			nodeIndex = nextIndex
+			nextIndex++
+		}
+
+		return nodeIndex
+	}
+
+	// Iterate over each node and each channel and update the adjacency
+	// list and the node index.
+	err := g.ForEachNode(func(node Node) error {
+		u := getNodeIndex(node)
+
+		return node.ForEachChannel(func(edge ChannelEdge) error {
+			v := getNodeIndex(edge.Peer)
+
+			adj[u] = append(adj[u], v)
+			return nil
+		})
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	graph := &SimpleGraph{
+		Nodes: make([]NodeID, len(nodes)),
+		Adj:   make([][]int, len(nodes)),
+	}
+
+	// Fill the adjacency list and the node index to node pubkey mapping.
+	for nodeID, nodeIndex := range nodes {
+		graph.Adj[nodeIndex] = adj[nodeIndex]
+		graph.Nodes[nodeIndex] = nodeID
+	}
+
+	return graph, nil
+}
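
As a sanity check on the expected values in TestBetweennessCentralityWithNonEmptyGraph, the normalized row follows from the unnormalized row by the same min/max scaling that GetMetric applies. Below is a standalone sketch of that arithmetic; it is not part of the patch and only reuses the numbers from the test table, independent of the autopilot package.

package main

import "fmt"

func main() {
	// Unnormalized betweenness values expected by the test for the
	// 9-node example graph (already halved for the undirected graph).
	centrality := []float64{3, 0, 3, 15, 6, 6, 7, 0, 0}

	// Reproduce GetMetric's normalization: shift by the minimum and
	// scale by 1 / (max - min) so values land in [0, 1].
	min, max := centrality[0], centrality[0]
	for _, v := range centrality {
		if v < min {
			min = v
		}
		if v > max {
			max = v
		}
	}

	var z float64
	if max-min > 0 {
		z = 1.0 / (max - min)
	}

	for node, v := range centrality {
		fmt.Printf("node %d: %.4f\n", node, (v-min)*z)
	}
	// Prints 0.2, 0.0, 0.2, 1.0, 0.4, 0.4, 0.4667 (= 7/15), 0.0, 0.0,
	// matching the normalized expectations in the test table.
}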