diff --git a/autopilot/betweenness_centrality.go b/autopilot/betweenness_centrality.go index bfc15e36..34a10f7a 100644 --- a/autopilot/betweenness_centrality.go +++ b/autopilot/betweenness_centrality.go @@ -1,5 +1,10 @@ package autopilot +import ( + "fmt" + "sync" +) + // stack is a simple int stack to help with readability of Brandes' // betweenness centrality implementation below. type stack struct { @@ -50,6 +55,10 @@ func (q *queue) empty() bool { // shortest paths starting or ending at that node. This is a useful metric // to measure control of individual nodes over the whole network. type BetweennessCentrality struct { + // workers number of goroutines are used to parallelize + // centrality calculation. + workers int + // centrality stores original (not normalized) centrality values for // each node in the graph. centrality map[NodeID]float64 @@ -62,8 +71,15 @@ type BetweennessCentrality struct { } // NewBetweennessCentralityMetric creates a new BetweennessCentrality instance. -func NewBetweennessCentralityMetric() *BetweennessCentrality { - return &BetweennessCentrality{} +// Users can specify the number of workers to use for calculating centrality. +func NewBetweennessCentralityMetric(workers int) (*BetweennessCentrality, error) { + // There should be at least one worker. + if workers < 1 { + return nil, fmt.Errorf("workers must be positive") + } + return &BetweennessCentrality{ + workers: workers, + }, nil } // Name returns the name of the metric. @@ -158,10 +174,47 @@ func (bc *BetweennessCentrality) Refresh(graph ChannelGraph) error { return err } - // TODO: parallelize updates to centrality. - centrality := make([]float64, len(cache.Nodes)) + var wg sync.WaitGroup + work := make(chan int) + partials := make(chan []float64, bc.workers) + + // Each worker will compute a partial result. This + // partial result is a sum of centrality updates on + // roughly N / workers nodes. + worker := func() { + defer wg.Done() + partial := make([]float64, len(cache.Nodes)) + + // Consume the next node, update centrality + // parital to avoid unnecessary synchronizaton. + for node := range work { + betweennessCentrality(cache, node, partial) + } + partials <- partial + } + + // Now start the N workers. + wg.Add(bc.workers) + for i := 0; i < bc.workers; i++ { + go worker() + } + + // Distribute work amongst workers Should be + // fair when graph is sufficiently large. for node := range cache.Nodes { - betweennessCentrality(cache, node, centrality) + work <- node + } + + close(work) + wg.Wait() + close(partials) + + // Collect and sum partials for final result. + centrality := make([]float64, len(cache.Nodes)) + for partial := range partials { + for i := 0; i < len(partial); i++ { + centrality[i] += partial[i] + } } // Get min/max to be able to normalize @@ -169,11 +222,11 @@ func (bc *BetweennessCentrality) Refresh(graph ChannelGraph) error { bc.min = 0 bc.max = 0 if len(centrality) > 0 { - for i := 1; i < len(centrality); i++ { - if centrality[i] < bc.min { - bc.min = centrality[i] - } else if centrality[i] > bc.max { - bc.max = centrality[i] + for _, v := range centrality { + if v < bc.min { + bc.min = v + } else if v > bc.max { + bc.max = v } } } diff --git a/autopilot/betweenness_centrality_test.go b/autopilot/betweenness_centrality_test.go index 09fd7fc3..4b71f77f 100644 --- a/autopilot/betweenness_centrality_test.go +++ b/autopilot/betweenness_centrality_test.go @@ -1,15 +1,38 @@ package autopilot import ( + "fmt" "testing" "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcutil" ) +func TestBetweennessCentralityMetricConstruction(t *testing.T) { + failing := []int{-1, 0} + ok := []int{1, 10} + + for _, workers := range failing { + m, err := NewBetweennessCentralityMetric(workers) + if m != nil || err == nil { + t.Fatalf("construction must fail with <= 0 workers") + } + } + + for _, workers := range ok { + m, err := NewBetweennessCentralityMetric(workers) + if m == nil || err != nil { + t.Fatalf("construction must succeed with >= 1 workers") + } + } +} + // Tests that empty graph results in empty centrality result. func TestBetweennessCentralityEmptyGraph(t *testing.T) { - centralityMetric := NewBetweennessCentralityMetric() + centralityMetric, err := NewBetweennessCentralityMetric(1) + if err != nil { + t.Fatalf("construction must succeed with positive number of workers") + } for _, chanGraph := range chanGraphs { graph, cleanup, err := chanGraph.genFunc() @@ -91,8 +114,9 @@ func TestBetweennessCentralityWithNonEmptyGraph(t *testing.T) { }, } - tests := []struct { - name string + workers := []int{1, 3, 9, 100} + + results := []struct { normalize bool centrality []float64 }{ @@ -110,47 +134,54 @@ func TestBetweennessCentralityWithNonEmptyGraph(t *testing.T) { }, } - for _, chanGraph := range chanGraphs { - graph, cleanup, err := chanGraph.genFunc() - if err != nil { - t.Fatalf("unable to create graph: %v", err) - } - if cleanup != nil { - defer cleanup() - } - - success := t.Run(chanGraph.name, func(t1 *testing.T) { - centralityMetric := NewBetweennessCentralityMetric() - graphNodes := buildTestGraph(t1, graph, graphDesc) - - if err := centralityMetric.Refresh(graph); err != nil { - t1.Fatalf("error while calculating betweeness centrality") + for _, numWorkers := range workers { + for _, chanGraph := range chanGraphs { + numWorkers := numWorkers + graph, cleanup, err := chanGraph.genFunc() + if err != nil { + t.Fatalf("unable to create graph: %v", err) + } + if cleanup != nil { + defer cleanup() } - for _, test := range tests { - test := test - centrality := centralityMetric.GetMetric(test.normalize) - if len(centrality) != graphDesc.nodes { - t.Fatalf("expected %v values, got: %v", - graphDesc.nodes, len(centrality)) + testName := fmt.Sprintf("%v %d workers", chanGraph.name, numWorkers) + success := t.Run(testName, func(t1 *testing.T) { + centralityMetric, err := NewBetweennessCentralityMetric(numWorkers) + if err != nil { + t.Fatalf("construction must succeed with positive number of workers") } - for node, nodeCentrality := range test.centrality { - nodeID := NewNodeID(graphNodes[node]) - calculatedCentrality, ok := centrality[nodeID] - if !ok { - t1.Fatalf("no result for node: %x (%v)", nodeID, node) + graphNodes := buildTestGraph(t1, graph, graphDesc) + if err := centralityMetric.Refresh(graph); err != nil { + t1.Fatalf("error while calculating betweeness centrality") + } + for _, expected := range results { + expected := expected + centrality := centralityMetric.GetMetric(expected.normalize) + + if len(centrality) != graphDesc.nodes { + t.Fatalf("expected %v values, got: %v", + graphDesc.nodes, len(centrality)) } - if nodeCentrality != calculatedCentrality { - t1.Errorf("centrality for node: %v should be %v, got: %v", - node, test.centrality[node], calculatedCentrality) + for node, nodeCentrality := range expected.centrality { + nodeID := NewNodeID(graphNodes[node]) + calculatedCentrality, ok := centrality[nodeID] + if !ok { + t1.Fatalf("no result for node: %x (%v)", nodeID, node) + } + + if nodeCentrality != calculatedCentrality { + t1.Errorf("centrality for node: %v should be %v, got: %v", + node, nodeCentrality, calculatedCentrality) + } } } + }) + if !success { + break } - }) - if !success { - break } } } diff --git a/rpcserver.go b/rpcserver.go index 995cc7b6..5cb8ade2 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -10,6 +10,7 @@ import ( "io" "math" "net/http" + "runtime" "sort" "strings" "sync" @@ -4693,7 +4694,12 @@ func (r *rpcServer) GetNodeMetrics(ctx context.Context, // Calculate betweenness centrality if requested. Note that depending on the // graph size, this may take up to a few minutes. channelGraph := autopilot.ChannelGraphFromDatabase(graph) - centralityMetric := autopilot.NewBetweennessCentralityMetric() + centralityMetric, err := autopilot.NewBetweennessCentralityMetric( + runtime.NumCPU(), + ) + if err != nil { + return nil, err + } if err := centralityMetric.Refresh(channelGraph); err != nil { return nil, err }