autopilot: parallelize betweenness centrality
This commit parallelizes betweenness centrality calculation, by distributing the algo to N workers and creating partial results.
This commit is contained in:
parent
7e50997bb4
commit
608354032c
@ -1,5 +1,10 @@
|
|||||||
package autopilot
|
package autopilot
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
// stack is a simple int stack to help with readability of Brandes'
|
// stack is a simple int stack to help with readability of Brandes'
|
||||||
// betweenness centrality implementation below.
|
// betweenness centrality implementation below.
|
||||||
type stack struct {
|
type stack struct {
|
||||||
@ -50,6 +55,10 @@ func (q *queue) empty() bool {
|
|||||||
// shortest paths starting or ending at that node. This is a useful metric
|
// shortest paths starting or ending at that node. This is a useful metric
|
||||||
// to measure control of individual nodes over the whole network.
|
// to measure control of individual nodes over the whole network.
|
||||||
type BetweennessCentrality struct {
|
type BetweennessCentrality struct {
|
||||||
|
// workers number of goroutines are used to parallelize
|
||||||
|
// centrality calculation.
|
||||||
|
workers int
|
||||||
|
|
||||||
// centrality stores original (not normalized) centrality values for
|
// centrality stores original (not normalized) centrality values for
|
||||||
// each node in the graph.
|
// each node in the graph.
|
||||||
centrality map[NodeID]float64
|
centrality map[NodeID]float64
|
||||||
@ -62,8 +71,15 @@ type BetweennessCentrality struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewBetweennessCentralityMetric creates a new BetweennessCentrality instance.
|
// NewBetweennessCentralityMetric creates a new BetweennessCentrality instance.
|
||||||
func NewBetweennessCentralityMetric() *BetweennessCentrality {
|
// Users can specify the number of workers to use for calculating centrality.
|
||||||
return &BetweennessCentrality{}
|
func NewBetweennessCentralityMetric(workers int) (*BetweennessCentrality, error) {
|
||||||
|
// There should be at least one worker.
|
||||||
|
if workers < 1 {
|
||||||
|
return nil, fmt.Errorf("workers must be positive")
|
||||||
|
}
|
||||||
|
return &BetweennessCentrality{
|
||||||
|
workers: workers,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Name returns the name of the metric.
|
// Name returns the name of the metric.
|
||||||
@ -158,10 +174,47 @@ func (bc *BetweennessCentrality) Refresh(graph ChannelGraph) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: parallelize updates to centrality.
|
var wg sync.WaitGroup
|
||||||
centrality := make([]float64, len(cache.Nodes))
|
work := make(chan int)
|
||||||
|
partials := make(chan []float64, bc.workers)
|
||||||
|
|
||||||
|
// Each worker will compute a partial result. This
|
||||||
|
// partial result is a sum of centrality updates on
|
||||||
|
// roughly N / workers nodes.
|
||||||
|
worker := func() {
|
||||||
|
defer wg.Done()
|
||||||
|
partial := make([]float64, len(cache.Nodes))
|
||||||
|
|
||||||
|
// Consume the next node, update centrality
|
||||||
|
// parital to avoid unnecessary synchronizaton.
|
||||||
|
for node := range work {
|
||||||
|
betweennessCentrality(cache, node, partial)
|
||||||
|
}
|
||||||
|
partials <- partial
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now start the N workers.
|
||||||
|
wg.Add(bc.workers)
|
||||||
|
for i := 0; i < bc.workers; i++ {
|
||||||
|
go worker()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Distribute work amongst workers Should be
|
||||||
|
// fair when graph is sufficiently large.
|
||||||
for node := range cache.Nodes {
|
for node := range cache.Nodes {
|
||||||
betweennessCentrality(cache, node, centrality)
|
work <- node
|
||||||
|
}
|
||||||
|
|
||||||
|
close(work)
|
||||||
|
wg.Wait()
|
||||||
|
close(partials)
|
||||||
|
|
||||||
|
// Collect and sum partials for final result.
|
||||||
|
centrality := make([]float64, len(cache.Nodes))
|
||||||
|
for partial := range partials {
|
||||||
|
for i := 0; i < len(partial); i++ {
|
||||||
|
centrality[i] += partial[i]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get min/max to be able to normalize
|
// Get min/max to be able to normalize
|
||||||
@ -169,11 +222,11 @@ func (bc *BetweennessCentrality) Refresh(graph ChannelGraph) error {
|
|||||||
bc.min = 0
|
bc.min = 0
|
||||||
bc.max = 0
|
bc.max = 0
|
||||||
if len(centrality) > 0 {
|
if len(centrality) > 0 {
|
||||||
for i := 1; i < len(centrality); i++ {
|
for _, v := range centrality {
|
||||||
if centrality[i] < bc.min {
|
if v < bc.min {
|
||||||
bc.min = centrality[i]
|
bc.min = v
|
||||||
} else if centrality[i] > bc.max {
|
} else if v > bc.max {
|
||||||
bc.max = centrality[i]
|
bc.max = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,15 +1,38 @@
|
|||||||
package autopilot
|
package autopilot
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/btcsuite/btcd/btcec"
|
"github.com/btcsuite/btcd/btcec"
|
||||||
"github.com/btcsuite/btcutil"
|
"github.com/btcsuite/btcutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestBetweennessCentralityMetricConstruction(t *testing.T) {
|
||||||
|
failing := []int{-1, 0}
|
||||||
|
ok := []int{1, 10}
|
||||||
|
|
||||||
|
for _, workers := range failing {
|
||||||
|
m, err := NewBetweennessCentralityMetric(workers)
|
||||||
|
if m != nil || err == nil {
|
||||||
|
t.Fatalf("construction must fail with <= 0 workers")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, workers := range ok {
|
||||||
|
m, err := NewBetweennessCentralityMetric(workers)
|
||||||
|
if m == nil || err != nil {
|
||||||
|
t.Fatalf("construction must succeed with >= 1 workers")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Tests that empty graph results in empty centrality result.
|
// Tests that empty graph results in empty centrality result.
|
||||||
func TestBetweennessCentralityEmptyGraph(t *testing.T) {
|
func TestBetweennessCentralityEmptyGraph(t *testing.T) {
|
||||||
centralityMetric := NewBetweennessCentralityMetric()
|
centralityMetric, err := NewBetweennessCentralityMetric(1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("construction must succeed with positive number of workers")
|
||||||
|
}
|
||||||
|
|
||||||
for _, chanGraph := range chanGraphs {
|
for _, chanGraph := range chanGraphs {
|
||||||
graph, cleanup, err := chanGraph.genFunc()
|
graph, cleanup, err := chanGraph.genFunc()
|
||||||
@ -91,8 +114,9 @@ func TestBetweennessCentralityWithNonEmptyGraph(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []struct {
|
workers := []int{1, 3, 9, 100}
|
||||||
name string
|
|
||||||
|
results := []struct {
|
||||||
normalize bool
|
normalize bool
|
||||||
centrality []float64
|
centrality []float64
|
||||||
}{
|
}{
|
||||||
@ -110,7 +134,9 @@ func TestBetweennessCentralityWithNonEmptyGraph(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, numWorkers := range workers {
|
||||||
for _, chanGraph := range chanGraphs {
|
for _, chanGraph := range chanGraphs {
|
||||||
|
numWorkers := numWorkers
|
||||||
graph, cleanup, err := chanGraph.genFunc()
|
graph, cleanup, err := chanGraph.genFunc()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to create graph: %v", err)
|
t.Fatalf("unable to create graph: %v", err)
|
||||||
@ -119,23 +145,27 @@ func TestBetweennessCentralityWithNonEmptyGraph(t *testing.T) {
|
|||||||
defer cleanup()
|
defer cleanup()
|
||||||
}
|
}
|
||||||
|
|
||||||
success := t.Run(chanGraph.name, func(t1 *testing.T) {
|
testName := fmt.Sprintf("%v %d workers", chanGraph.name, numWorkers)
|
||||||
centralityMetric := NewBetweennessCentralityMetric()
|
success := t.Run(testName, func(t1 *testing.T) {
|
||||||
graphNodes := buildTestGraph(t1, graph, graphDesc)
|
centralityMetric, err := NewBetweennessCentralityMetric(numWorkers)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("construction must succeed with positive number of workers")
|
||||||
|
}
|
||||||
|
|
||||||
|
graphNodes := buildTestGraph(t1, graph, graphDesc)
|
||||||
if err := centralityMetric.Refresh(graph); err != nil {
|
if err := centralityMetric.Refresh(graph); err != nil {
|
||||||
t1.Fatalf("error while calculating betweeness centrality")
|
t1.Fatalf("error while calculating betweeness centrality")
|
||||||
}
|
}
|
||||||
for _, test := range tests {
|
for _, expected := range results {
|
||||||
test := test
|
expected := expected
|
||||||
centrality := centralityMetric.GetMetric(test.normalize)
|
centrality := centralityMetric.GetMetric(expected.normalize)
|
||||||
|
|
||||||
if len(centrality) != graphDesc.nodes {
|
if len(centrality) != graphDesc.nodes {
|
||||||
t.Fatalf("expected %v values, got: %v",
|
t.Fatalf("expected %v values, got: %v",
|
||||||
graphDesc.nodes, len(centrality))
|
graphDesc.nodes, len(centrality))
|
||||||
}
|
}
|
||||||
|
|
||||||
for node, nodeCentrality := range test.centrality {
|
for node, nodeCentrality := range expected.centrality {
|
||||||
nodeID := NewNodeID(graphNodes[node])
|
nodeID := NewNodeID(graphNodes[node])
|
||||||
calculatedCentrality, ok := centrality[nodeID]
|
calculatedCentrality, ok := centrality[nodeID]
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -144,7 +174,7 @@ func TestBetweennessCentralityWithNonEmptyGraph(t *testing.T) {
|
|||||||
|
|
||||||
if nodeCentrality != calculatedCentrality {
|
if nodeCentrality != calculatedCentrality {
|
||||||
t1.Errorf("centrality for node: %v should be %v, got: %v",
|
t1.Errorf("centrality for node: %v should be %v, got: %v",
|
||||||
node, test.centrality[node], calculatedCentrality)
|
node, nodeCentrality, calculatedCentrality)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -154,3 +184,4 @@ func TestBetweennessCentralityWithNonEmptyGraph(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"runtime"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -4693,7 +4694,12 @@ func (r *rpcServer) GetNodeMetrics(ctx context.Context,
|
|||||||
// Calculate betweenness centrality if requested. Note that depending on the
|
// Calculate betweenness centrality if requested. Note that depending on the
|
||||||
// graph size, this may take up to a few minutes.
|
// graph size, this may take up to a few minutes.
|
||||||
channelGraph := autopilot.ChannelGraphFromDatabase(graph)
|
channelGraph := autopilot.ChannelGraphFromDatabase(graph)
|
||||||
centralityMetric := autopilot.NewBetweennessCentralityMetric()
|
centralityMetric, err := autopilot.NewBetweennessCentralityMetric(
|
||||||
|
runtime.NumCPU(),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if err := centralityMetric.Refresh(channelGraph); err != nil {
|
if err := centralityMetric.Refresh(channelGraph); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user