Merge pull request #4384 from bhandras/atpl_bc_topk
add modified greedy topK centrality heuristic to autopilot
This commit is contained in:
commit
77549f1fb8
@ -657,17 +657,6 @@ func (a *Agent) openChans(availableFunds btcutil.Amount, numChans uint32,
|
||||
log.Tracef("Creating attachment directive for chosen node %x",
|
||||
nID[:])
|
||||
|
||||
// Add addresses to the candidates.
|
||||
addrs := addresses[nID]
|
||||
|
||||
// If the node has no known addresses, we cannot connect to it,
|
||||
// so we'll skip it.
|
||||
if len(addrs) == 0 {
|
||||
log.Tracef("Skipping scored node %x with no addresses",
|
||||
nID[:])
|
||||
continue
|
||||
}
|
||||
|
||||
// Track the available funds we have left.
|
||||
if availableFunds < chanSize {
|
||||
chanSize = availableFunds
|
||||
@ -685,7 +674,7 @@ func (a *Agent) openChans(availableFunds btcutil.Amount, numChans uint32,
|
||||
chanCandidates[nID] = &AttachmentDirective{
|
||||
NodeID: nID,
|
||||
ChanAmt: chanSize,
|
||||
Addrs: addrs,
|
||||
Addrs: addresses[nID],
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -4,8 +4,7 @@ import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/btcsuite/btcutil"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBetweennessCentralityMetricConstruction(t *testing.T) {
|
||||
@ -14,50 +13,46 @@ func TestBetweennessCentralityMetricConstruction(t *testing.T) {
|
||||
|
||||
for _, workers := range failing {
|
||||
m, err := NewBetweennessCentralityMetric(workers)
|
||||
if m != nil || err == nil {
|
||||
t.Fatalf("construction must fail with <= 0 workers")
|
||||
}
|
||||
require.Error(
|
||||
t, err, "construction must fail with <= 0 workers",
|
||||
)
|
||||
require.Nil(t, m)
|
||||
}
|
||||
|
||||
for _, workers := range ok {
|
||||
m, err := NewBetweennessCentralityMetric(workers)
|
||||
if m == nil || err != nil {
|
||||
t.Fatalf("construction must succeed with >= 1 workers")
|
||||
}
|
||||
require.NoError(
|
||||
t, err, "construction must succeed with >= 1 workers",
|
||||
)
|
||||
require.NotNil(t, m)
|
||||
}
|
||||
}
|
||||
|
||||
// Tests that empty graph results in empty centrality result.
|
||||
func TestBetweennessCentralityEmptyGraph(t *testing.T) {
|
||||
centralityMetric, err := NewBetweennessCentralityMetric(1)
|
||||
if err != nil {
|
||||
t.Fatalf("construction must succeed with positive number of workers")
|
||||
}
|
||||
require.NoError(
|
||||
t, err,
|
||||
"construction must succeed with positive number of workers",
|
||||
)
|
||||
|
||||
for _, chanGraph := range chanGraphs {
|
||||
graph, cleanup, err := chanGraph.genFunc()
|
||||
success := t.Run(chanGraph.name, func(t1 *testing.T) {
|
||||
if err != nil {
|
||||
t1.Fatalf("unable to create graph: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "unable to create graph")
|
||||
|
||||
if cleanup != nil {
|
||||
defer cleanup()
|
||||
}
|
||||
|
||||
if err := centralityMetric.Refresh(graph); err != nil {
|
||||
t.Fatalf("unexpected failure during metric refresh: %v", err)
|
||||
}
|
||||
err := centralityMetric.Refresh(graph)
|
||||
require.NoError(t, err)
|
||||
|
||||
centrality := centralityMetric.GetMetric(false)
|
||||
if len(centrality) > 0 {
|
||||
t.Fatalf("expected empty metric, got: %v", len(centrality))
|
||||
}
|
||||
require.Equal(t, 0, len(centrality))
|
||||
|
||||
centrality = centralityMetric.GetMetric(true)
|
||||
if len(centrality) > 0 {
|
||||
t.Fatalf("expected empty metric, got: %v", len(centrality))
|
||||
}
|
||||
|
||||
require.Equal(t, 0, len(centrality))
|
||||
})
|
||||
if !success {
|
||||
break
|
||||
@ -65,72 +60,21 @@ func TestBetweennessCentralityEmptyGraph(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// testGraphDesc is a helper type to describe a test graph.
|
||||
type testGraphDesc struct {
|
||||
nodes int
|
||||
edges map[int][]int
|
||||
}
|
||||
|
||||
// buildTestGraph builds a test graph from a passed graph desriptor.
|
||||
func buildTestGraph(t *testing.T,
|
||||
graph testGraph, desc testGraphDesc) map[int]*btcec.PublicKey {
|
||||
|
||||
nodes := make(map[int]*btcec.PublicKey)
|
||||
|
||||
for i := 0; i < desc.nodes; i++ {
|
||||
key, err := graph.addRandNode()
|
||||
if err != nil {
|
||||
t.Fatalf("cannot create random node")
|
||||
}
|
||||
|
||||
nodes[i] = key
|
||||
}
|
||||
|
||||
const chanCapacity = btcutil.SatoshiPerBitcoin
|
||||
for u, neighbors := range desc.edges {
|
||||
for _, v := range neighbors {
|
||||
_, _, err := graph.addRandChannel(nodes[u], nodes[v], chanCapacity)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error adding random channel: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nodes
|
||||
}
|
||||
|
||||
// Test betweenness centrality calculating using an example graph.
|
||||
func TestBetweennessCentralityWithNonEmptyGraph(t *testing.T) {
|
||||
graphDesc := testGraphDesc{
|
||||
nodes: 9,
|
||||
edges: map[int][]int{
|
||||
0: {1, 2, 3},
|
||||
1: {2},
|
||||
2: {3},
|
||||
3: {4, 5},
|
||||
4: {5, 6, 7},
|
||||
5: {6, 7},
|
||||
6: {7, 8},
|
||||
},
|
||||
}
|
||||
|
||||
workers := []int{1, 3, 9, 100}
|
||||
|
||||
results := []struct {
|
||||
tests := []struct {
|
||||
normalize bool
|
||||
centrality []float64
|
||||
}{
|
||||
{
|
||||
normalize: true,
|
||||
centrality: []float64{
|
||||
0.2, 0.0, 0.2, 1.0, 0.4, 0.4, 7.0 / 15.0, 0.0, 0.0,
|
||||
},
|
||||
normalize: true,
|
||||
centrality: normalizedTestGraphCentrality,
|
||||
},
|
||||
{
|
||||
normalize: false,
|
||||
centrality: []float64{
|
||||
3.0, 0.0, 3.0, 15.0, 6.0, 6.0, 7.0, 0.0, 0.0,
|
||||
},
|
||||
normalize: false,
|
||||
centrality: testGraphCentrality,
|
||||
},
|
||||
}
|
||||
|
||||
@ -138,49 +82,51 @@ func TestBetweennessCentralityWithNonEmptyGraph(t *testing.T) {
|
||||
for _, chanGraph := range chanGraphs {
|
||||
numWorkers := numWorkers
|
||||
graph, cleanup, err := chanGraph.genFunc()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create graph: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "unable to create graph")
|
||||
|
||||
if cleanup != nil {
|
||||
defer cleanup()
|
||||
}
|
||||
|
||||
testName := fmt.Sprintf("%v %d workers", chanGraph.name, numWorkers)
|
||||
testName := fmt.Sprintf(
|
||||
"%v %d workers", chanGraph.name, numWorkers,
|
||||
)
|
||||
|
||||
success := t.Run(testName, func(t1 *testing.T) {
|
||||
centralityMetric, err := NewBetweennessCentralityMetric(
|
||||
metric, err := NewBetweennessCentralityMetric(
|
||||
numWorkers,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("construction must succeed with " +
|
||||
"positive number of workers")
|
||||
}
|
||||
require.NoError(
|
||||
t, err,
|
||||
"construction must succeed with "+
|
||||
"positive number of workers",
|
||||
)
|
||||
|
||||
graphNodes := buildTestGraph(t1, graph, graphDesc)
|
||||
if err := centralityMetric.Refresh(graph); err != nil {
|
||||
t1.Fatalf("error while calculating betweeness centrality")
|
||||
}
|
||||
for _, expected := range results {
|
||||
graphNodes := buildTestGraph(
|
||||
t1, graph, centralityTestGraph,
|
||||
)
|
||||
|
||||
err = metric.Refresh(graph)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, expected := range tests {
|
||||
expected := expected
|
||||
centrality := centralityMetric.GetMetric(expected.normalize)
|
||||
centrality := metric.GetMetric(
|
||||
expected.normalize,
|
||||
)
|
||||
|
||||
if len(centrality) != graphDesc.nodes {
|
||||
t.Fatalf("expected %v values, got: %v",
|
||||
graphDesc.nodes, len(centrality))
|
||||
}
|
||||
require.Equal(t,
|
||||
centralityTestGraph.nodes,
|
||||
len(centrality),
|
||||
)
|
||||
|
||||
for node, nodeCentrality := range expected.centrality {
|
||||
nodeID := NewNodeID(graphNodes[node])
|
||||
calculatedCentrality, ok := centrality[nodeID]
|
||||
if !ok {
|
||||
t1.Fatalf("no result for node: %x (%v)",
|
||||
nodeID, node)
|
||||
}
|
||||
|
||||
if nodeCentrality != calculatedCentrality {
|
||||
t1.Errorf("centrality for node: %v "+
|
||||
"should be %v, got: %v",
|
||||
node, nodeCentrality, calculatedCentrality)
|
||||
}
|
||||
for i, c := range expected.centrality {
|
||||
nodeID := NewNodeID(
|
||||
graphNodes[i],
|
||||
)
|
||||
result, ok := centrality[nodeID]
|
||||
require.True(t, ok)
|
||||
require.Equal(t, c, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
68
autopilot/centrality_testdata_test.go
Normal file
68
autopilot/centrality_testdata_test.go
Normal file
@ -0,0 +1,68 @@
|
||||
package autopilot
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/btcsuite/btcutil"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testGraphDesc is a helper type to describe a test graph.
|
||||
type testGraphDesc struct {
|
||||
nodes int
|
||||
edges map[int][]int
|
||||
}
|
||||
|
||||
var centralityTestGraph = testGraphDesc{
|
||||
nodes: 9,
|
||||
edges: map[int][]int{
|
||||
0: {1, 2, 3},
|
||||
1: {2},
|
||||
2: {3},
|
||||
3: {4, 5},
|
||||
4: {5, 6, 7},
|
||||
5: {6, 7},
|
||||
6: {7, 8},
|
||||
},
|
||||
}
|
||||
|
||||
var testGraphCentrality = []float64{
|
||||
3.0, 0.0, 3.0, 15.0, 6.0, 6.0, 7.0, 0.0, 0.0,
|
||||
}
|
||||
|
||||
var normalizedTestGraphCentrality = []float64{
|
||||
0.2, 0.0, 0.2, 1.0, 0.4, 0.4, 7.0 / 15.0, 0.0, 0.0,
|
||||
}
|
||||
|
||||
// buildTestGraph builds a test graph from a passed graph desriptor.
|
||||
func buildTestGraph(t *testing.T,
|
||||
graph testGraph, desc testGraphDesc) map[int]*btcec.PublicKey {
|
||||
|
||||
nodes := make(map[int]*btcec.PublicKey)
|
||||
|
||||
for i := 0; i < desc.nodes; i++ {
|
||||
key, err := graph.addRandNode()
|
||||
require.NoError(t, err, "cannot create random node")
|
||||
|
||||
nodes[i] = key
|
||||
}
|
||||
|
||||
const chanCapacity = btcutil.SatoshiPerBitcoin
|
||||
for u, neighbors := range desc.edges {
|
||||
for _, v := range neighbors {
|
||||
_, _, err := graph.addRandChannel(
|
||||
nodes[u], nodes[v], chanCapacity,
|
||||
)
|
||||
require.NoError(t, err,
|
||||
"unexpected error adding random channel",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error adding"+
|
||||
"random channel: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nodes
|
||||
}
|
@ -185,6 +185,7 @@ var (
|
||||
availableHeuristics = []AttachmentHeuristic{
|
||||
NewPrefAttachment(),
|
||||
NewExternalScoreAttachment(),
|
||||
NewTopCentrality(),
|
||||
}
|
||||
|
||||
// AvailableHeuristics is a map that holds the name of available
|
||||
|
93
autopilot/top_centrality.go
Normal file
93
autopilot/top_centrality.go
Normal file
@ -0,0 +1,93 @@
|
||||
package autopilot
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
|
||||
"github.com/btcsuite/btcutil"
|
||||
)
|
||||
|
||||
// TopCentrality is a simple greedy technique to create connections to nodes
|
||||
// with the top betweenness centrality value. This algorithm is usually
|
||||
// referred to as TopK in the literature. The idea is that by opening channels
|
||||
// to nodes with top betweenness centrality we also increase our own betweenness
|
||||
// centrality (given we already have at least one channel, or create at least
|
||||
// two new channels).
|
||||
// A different and much better approach is instead of selecting nodes with top
|
||||
// centrality value, we extend the graph in a loop by inserting a new non
|
||||
// existing edge and recalculate the betweenness centrality of each node. This
|
||||
// technique is usually referred to as "greedy" algorithm and gives better
|
||||
// results than TopK but is considerably slower too.
|
||||
type TopCentrality struct {
|
||||
centralityMetric *BetweennessCentrality
|
||||
}
|
||||
|
||||
// A compile time assertion to ensure TopCentrality meets the
|
||||
// AttachmentHeuristic interface.
|
||||
var _ AttachmentHeuristic = (*TopCentrality)(nil)
|
||||
|
||||
// NewTopCentrality constructs and returns a new TopCentrality heuristic.
|
||||
func NewTopCentrality() *TopCentrality {
|
||||
metric, err := NewBetweennessCentralityMetric(
|
||||
runtime.NumCPU(),
|
||||
)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return &TopCentrality{
|
||||
centralityMetric: metric,
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the name of the heuristic.
|
||||
func (g *TopCentrality) Name() string {
|
||||
return "top_centrality"
|
||||
}
|
||||
|
||||
// NodeScores will return a [0,1] normalized map of scores for the given nodes
|
||||
// except for the ones we already have channels with. The scores will simply
|
||||
// be the betweenness centrality values of the nodes.
|
||||
// As our current implementation of betweenness centrality is non-incremental,
|
||||
// NodeScores will recalculate the centrality values on every call, which is
|
||||
// slow for large graphs.
|
||||
func (g *TopCentrality) NodeScores(graph ChannelGraph, chans []Channel,
|
||||
chanSize btcutil.Amount, nodes map[NodeID]struct{}) (
|
||||
map[NodeID]*NodeScore, error) {
|
||||
|
||||
// Calculate betweenness centrality for the whole graph.
|
||||
if err := g.centralityMetric.Refresh(graph); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
normalize := true
|
||||
centrality := g.centralityMetric.GetMetric(normalize)
|
||||
|
||||
// Create a map of the existing peers for faster filtering.
|
||||
existingPeers := make(map[NodeID]struct{})
|
||||
for _, c := range chans {
|
||||
existingPeers[c.Node] = struct{}{}
|
||||
}
|
||||
|
||||
result := make(map[NodeID]*NodeScore, len(nodes))
|
||||
for nodeID := range nodes {
|
||||
// Skip nodes we already have channel with.
|
||||
if _, ok := existingPeers[nodeID]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip passed nodes not in the graph. This could happen if
|
||||
// the graph changed before computing the centrality values as
|
||||
// the nodes we iterate are prefiltered by the autopilot agent.
|
||||
score, ok := centrality[nodeID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
result[nodeID] = &NodeScore{
|
||||
NodeID: nodeID,
|
||||
Score: score,
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
109
autopilot/top_centrality_test.go
Normal file
109
autopilot/top_centrality_test.go
Normal file
@ -0,0 +1,109 @@
|
||||
package autopilot
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/btcsuite/btcutil"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testTopCentrality is subtest helper to which given the passed graph and
|
||||
// channels creates the expected centrality score set and checks that the
|
||||
// calculated score set matches it.
|
||||
func testTopCentrality(t *testing.T, graph testGraph,
|
||||
graphNodes map[int]*btcec.PublicKey, channelsWith []int) {
|
||||
|
||||
topCentrality := NewTopCentrality()
|
||||
|
||||
var channels []Channel
|
||||
for _, ch := range channelsWith {
|
||||
channels = append(channels, Channel{
|
||||
Node: NewNodeID(graphNodes[ch]),
|
||||
})
|
||||
}
|
||||
|
||||
// Start iteration from -1 to also test the case where the node set
|
||||
// is empty.
|
||||
for i := -1; i < len(graphNodes); i++ {
|
||||
nodes := make(map[NodeID]struct{})
|
||||
expected := make(map[NodeID]*NodeScore)
|
||||
|
||||
for j := 0; j <= i; j++ {
|
||||
// Add node to the interest set.
|
||||
nodeID := NewNodeID(graphNodes[j])
|
||||
nodes[nodeID] = struct{}{}
|
||||
|
||||
// Add to the expected set unless it's a node we have
|
||||
// a channel with.
|
||||
haveChannel := false
|
||||
for _, ch := range channels {
|
||||
if nodeID == ch.Node {
|
||||
haveChannel = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !haveChannel {
|
||||
score := normalizedTestGraphCentrality[j]
|
||||
expected[nodeID] = &NodeScore{
|
||||
NodeID: nodeID,
|
||||
Score: score,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const chanSize = btcutil.SatoshiPerBitcoin
|
||||
|
||||
// Attempt to get centrality scores and expect
|
||||
// that the result equals with the expected set.
|
||||
scores, err := topCentrality.NodeScores(
|
||||
graph, channels, chanSize, nodes,
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expected, scores)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTopCentrality tests that we return the correct normalized centralitiy
|
||||
// values given a non empty graph, and given our node has an increasing amount
|
||||
// of channels from 0 to N-1 simulating the whole range from non-connected to
|
||||
// fully connected.
|
||||
func TestTopCentrality(t *testing.T) {
|
||||
// Generate channels: {}, {0}, {0, 1}, ... {0, 1, ..., N-1}
|
||||
channelsWith := [][]int{nil}
|
||||
|
||||
for i := 0; i < centralityTestGraph.nodes; i++ {
|
||||
channels := make([]int, i+1)
|
||||
for j := 0; j <= i; j++ {
|
||||
channels[j] = j
|
||||
}
|
||||
channelsWith = append(channelsWith, channels)
|
||||
}
|
||||
|
||||
for _, chanGraph := range chanGraphs {
|
||||
chanGraph := chanGraph
|
||||
|
||||
success := t.Run(chanGraph.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
graph, cleanup, err := chanGraph.genFunc()
|
||||
require.NoError(t, err, "unable to create graph")
|
||||
if cleanup != nil {
|
||||
defer cleanup()
|
||||
}
|
||||
|
||||
// Build the test graph.
|
||||
graphNodes := buildTestGraph(
|
||||
t, graph, centralityTestGraph,
|
||||
)
|
||||
|
||||
for _, chans := range channelsWith {
|
||||
testTopCentrality(t, graph, graphNodes, chans)
|
||||
}
|
||||
})
|
||||
|
||||
require.True(t, success)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user