autopilot: parallelize betweenness centrality

This commit parallelizes the betweenness centrality calculation by
distributing the algorithm across N workers that each produce a partial result.
Andras Banki-Horvath 2020-03-24 17:28:55 +01:00
parent 7e50997bb4
commit 608354032c
3 changed files with 136 additions and 46 deletions
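
In outline, the change follows a standard fan-out pattern: node indices are pushed onto a work channel, each worker accumulates centrality updates into its own partial slice, and the partials are summed once all workers finish. Below is a minimal, self-contained sketch of that pattern (a simplified stand-in, not the commit's exact code; computePartial is a placeholder for the real betweennessCentrality update).

package main

import (
	"fmt"
	"sync"
)

// computePartial stands in for the per-node centrality update; the real code
// calls betweennessCentrality(cache, node, partial) here.
func computePartial(node int, partial []float64) {
	partial[node] += 1
}

// parallelSum distributes node indices to the given number of workers, each
// accumulating into a private partial slice, and sums the partials at the end.
func parallelSum(numNodes, workers int) []float64 {
	var wg sync.WaitGroup
	work := make(chan int)
	partials := make(chan []float64, workers)

	wg.Add(workers)
	for i := 0; i < workers; i++ {
		go func() {
			defer wg.Done()
			partial := make([]float64, numNodes)
			for node := range work {
				computePartial(node, partial)
			}
			partials <- partial
		}()
	}

	// Feed all nodes to the workers, then signal completion.
	for node := 0; node < numNodes; node++ {
		work <- node
	}
	close(work)
	wg.Wait()
	close(partials)

	// Reduce the per-worker partials into the final result.
	result := make([]float64, numNodes)
	for partial := range partials {
		for i, v := range partial {
			result[i] += v
		}
	}
	return result
}

func main() {
	fmt.Println(parallelSum(8, 3))
}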

View File

@@ -1,5 +1,10 @@
package autopilot
import (
"fmt"
"sync"
)
// stack is a simple int stack to help with readability of Brandes'
// betweenness centrality implementation below.
type stack struct {
@@ -50,6 +55,10 @@ func (q *queue) empty() bool {
// shortest paths starting or ending at that node. This is a useful metric
// to measure control of individual nodes over the whole network.
type BetweennessCentrality struct {
// workers is the number of goroutines used to parallelize
// the centrality calculation.
workers int
// centrality stores original (not normalized) centrality values for
// each node in the graph.
centrality map[NodeID]float64
@@ -62,8 +71,15 @@ type BetweennessCentrality struct {
}
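
For reference, the quantity the doc comment above describes is the standard betweenness centrality, in which paths starting or ending at the node itself are excluded; this is the textbook definition, not something introduced by this diff:

C_B(v) = \sum_{s \neq v \neq t} \frac{\sigma_{st}(v)}{\sigma_{st}}

Here \sigma_{st} is the number of shortest paths between s and t, and \sigma_{st}(v) is the number of those paths that pass through v.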
// NewBetweennessCentralityMetric creates a new BetweennessCentrality instance.
func NewBetweennessCentralityMetric() *BetweennessCentrality {
return &BetweennessCentrality{}
// Users can specify the number of workers to use for calculating centrality.
func NewBetweennessCentralityMetric(workers int) (*BetweennessCentrality, error) {
// There should be at least one worker.
if workers < 1 {
return nil, fmt.Errorf("workers must be positive")
}
return &BetweennessCentrality{
workers: workers,
}, nil
}
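
A usage sketch of the new constructor, mirroring how the rpcserver change below wires it up (channelGraph is assumed to come from autopilot.ChannelGraphFromDatabase, and error handling is abbreviated):

metric, err := autopilot.NewBetweennessCentralityMetric(runtime.NumCPU())
if err != nil {
	return err
}
if err := metric.Refresh(channelGraph); err != nil {
	return err
}
// GetMetric(true) returns normalized values per NodeID,
// GetMetric(false) the raw (unnormalized) ones.
centrality := metric.GetMetric(true)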
// Name returns the name of the metric.
@@ -158,10 +174,47 @@ func (bc *BetweennessCentrality) Refresh(graph ChannelGraph) error {
return err
}
// TODO: parallelize updates to centrality.
centrality := make([]float64, len(cache.Nodes))
var wg sync.WaitGroup
work := make(chan int)
partials := make(chan []float64, bc.workers)
// Each worker will compute a partial result. This
// partial result is a sum of centrality updates on
// roughly N / workers nodes.
worker := func() {
defer wg.Done()
partial := make([]float64, len(cache.Nodes))
// Consume the next node and update the centrality
// partial to avoid unnecessary synchronization.
for node := range work {
betweennessCentrality(cache, node, partial)
}
partials <- partial
}
// Now start the N workers.
wg.Add(bc.workers)
for i := 0; i < bc.workers; i++ {
go worker()
}
// Distribute work amongst the workers. This should be
// fair when the graph is sufficiently large.
for node := range cache.Nodes {
betweennessCentrality(cache, node, centrality)
work <- node
}
close(work)
wg.Wait()
close(partials)
// Collect and sum partials for final result.
centrality := make([]float64, len(cache.Nodes))
for partial := range partials {
for i := 0; i < len(partial); i++ {
centrality[i] += partial[i]
}
}
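
For contrast, the synchronization that the per-worker partial slices avoid would look roughly like the snippet below, with every node update contending on a single lock (illustrative only, not part of this commit):

var mu sync.Mutex
worker := func() {
	defer wg.Done()
	for node := range work {
		// Updating the shared slice directly would require holding the
		// lock for the whole per-node update, serializing the workers.
		mu.Lock()
		betweennessCentrality(cache, node, centrality)
		mu.Unlock()
	}
}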
// Get min/max to be able to normalize
@@ -169,11 +222,11 @@ func (bc *BetweennessCentrality) Refresh(graph ChannelGraph) error {
bc.min = 0
bc.max = 0
if len(centrality) > 0 {
for i := 1; i < len(centrality); i++ {
if centrality[i] < bc.min {
bc.min = centrality[i]
} else if centrality[i] > bc.max {
bc.max = centrality[i]
for _, v := range centrality {
if v < bc.min {
bc.min = v
} else if v > bc.max {
bc.max = v
}
}
}
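
The min and max recorded here are presumably consumed by GetMetric when normalization is requested; a sketch of that min-max scaling, assuming values are mapped into [0, 1] (the body of GetMetric is not part of this diff):

normalize := func(v float64) float64 {
	// Guard against a flat distribution to avoid dividing by zero.
	if bc.max == bc.min {
		return 0
	}
	return (v - bc.min) / (bc.max - bc.min)
}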

View File

@@ -1,15 +1,38 @@
package autopilot
import (
"fmt"
"testing"
"github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcutil"
)
func TestBetweennessCentralityMetricConstruction(t *testing.T) {
failing := []int{-1, 0}
ok := []int{1, 10}
for _, workers := range failing {
m, err := NewBetweennessCentralityMetric(workers)
if m != nil || err == nil {
t.Fatalf("construction must fail with <= 0 workers")
}
}
for _, workers := range ok {
m, err := NewBetweennessCentralityMetric(workers)
if m == nil || err != nil {
t.Fatalf("construction must succeed with >= 1 workers")
}
}
}
// Tests that an empty graph results in an empty centrality result.
func TestBetweennessCentralityEmptyGraph(t *testing.T) {
centralityMetric := NewBetweennessCentralityMetric()
centralityMetric, err := NewBetweennessCentralityMetric(1)
if err != nil {
t.Fatalf("construction must succeed with positive number of workers")
}
for _, chanGraph := range chanGraphs {
graph, cleanup, err := chanGraph.genFunc()
@@ -91,8 +114,9 @@ func TestBetweennessCentralityWithNonEmptyGraph(t *testing.T) {
},
}
tests := []struct {
name string
workers := []int{1, 3, 9, 100}
results := []struct {
normalize bool
centrality []float64
}{
@@ -110,47 +134,54 @@ func TestBetweennessCentralityWithNonEmptyGraph(t *testing.T) {
},
}
for _, chanGraph := range chanGraphs {
graph, cleanup, err := chanGraph.genFunc()
if err != nil {
t.Fatalf("unable to create graph: %v", err)
}
if cleanup != nil {
defer cleanup()
}
success := t.Run(chanGraph.name, func(t1 *testing.T) {
centralityMetric := NewBetweennessCentralityMetric()
graphNodes := buildTestGraph(t1, graph, graphDesc)
if err := centralityMetric.Refresh(graph); err != nil {
t1.Fatalf("error while calculating betweeness centrality")
for _, numWorkers := range workers {
for _, chanGraph := range chanGraphs {
numWorkers := numWorkers
graph, cleanup, err := chanGraph.genFunc()
if err != nil {
t.Fatalf("unable to create graph: %v", err)
}
if cleanup != nil {
defer cleanup()
}
for _, test := range tests {
test := test
centrality := centralityMetric.GetMetric(test.normalize)
if len(centrality) != graphDesc.nodes {
t.Fatalf("expected %v values, got: %v",
graphDesc.nodes, len(centrality))
testName := fmt.Sprintf("%v %d workers", chanGraph.name, numWorkers)
success := t.Run(testName, func(t1 *testing.T) {
centralityMetric, err := NewBetweennessCentralityMetric(numWorkers)
if err != nil {
t.Fatalf("construction must succeed with positive number of workers")
}
for node, nodeCentrality := range test.centrality {
nodeID := NewNodeID(graphNodes[node])
calculatedCentrality, ok := centrality[nodeID]
if !ok {
t1.Fatalf("no result for node: %x (%v)", nodeID, node)
graphNodes := buildTestGraph(t1, graph, graphDesc)
if err := centralityMetric.Refresh(graph); err != nil {
t1.Fatalf("error while calculating betweeness centrality")
}
for _, expected := range results {
expected := expected
centrality := centralityMetric.GetMetric(expected.normalize)
if len(centrality) != graphDesc.nodes {
t.Fatalf("expected %v values, got: %v",
graphDesc.nodes, len(centrality))
}
if nodeCentrality != calculatedCentrality {
t1.Errorf("centrality for node: %v should be %v, got: %v",
node, test.centrality[node], calculatedCentrality)
for node, nodeCentrality := range expected.centrality {
nodeID := NewNodeID(graphNodes[node])
calculatedCentrality, ok := centrality[nodeID]
if !ok {
t1.Fatalf("no result for node: %x (%v)", nodeID, node)
}
if nodeCentrality != calculatedCentrality {
t1.Errorf("centrality for node: %v should be %v, got: %v",
node, nodeCentrality, calculatedCentrality)
}
}
}
})
if !success {
break
}
})
if !success {
break
}
}
}

View File

@@ -10,6 +10,7 @@ import (
"io"
"math"
"net/http"
"runtime"
"sort"
"strings"
"sync"
@@ -4693,7 +4694,12 @@ func (r *rpcServer) GetNodeMetrics(ctx context.Context,
// Calculate betweenness centrality if requested. Note that depending on the
// graph size, this may take up to a few minutes.
channelGraph := autopilot.ChannelGraphFromDatabase(graph)
centralityMetric := autopilot.NewBetweennessCentralityMetric()
centralityMetric, err := autopilot.NewBetweennessCentralityMetric(
runtime.NumCPU(),
)
if err != nil {
return nil, err
}
if err := centralityMetric.Refresh(channelGraph); err != nil {
return nil, err
}
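
Since Refresh walks the whole graph and can take minutes, the handler only needs to run it once per request and can then read both variants cheaply; a sketch using only the methods shown in this diff (response wiring omitted, since it is not part of the change):

// Both maps are keyed by autopilot.NodeID and can be read after a
// single Refresh without recomputation.
normalized := centralityMetric.GetMetric(true)
raw := centralityMetric.GetMetric(false)
for nodeID, value := range normalized {
	_ = value       // normalized centrality for nodeID
	_ = raw[nodeID] // unnormalized centrality for the same node
}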