package autopilot

import (
	"fmt"
	"testing"

	"github.com/stretchr/testify/require"
)

// TestBetweennessCentralityMetricConstruction asserts that constructing the
// betweenness centrality metric fails for the worker counts -1 and 0, and
// succeeds for the worker counts 1 and 10.
func TestBetweennessCentralityMetricConstruction(t *testing.T) {
	cases := []struct {
		workers int
		valid   bool
	}{
		{workers: -1, valid: false},
		{workers: 0, valid: false},
		{workers: 1, valid: true},
		{workers: 10, valid: true},
	}

	for _, tc := range cases {
		metric, err := NewBetweennessCentralityMetric(tc.workers)

		// Invalid worker counts must yield an error and no metric.
		if !tc.valid {
			require.Error(
				t, err, "construction must fail with <= 0 workers",
			)
			require.Nil(t, metric)

			continue
		}

		// Valid worker counts must yield a usable metric.
		require.NoError(
			t, err, "construction must succeed with >= 1 workers",
		)
		require.NotNil(t, metric)
	}
}

// TestBetweennessCentralityEmptyGraph tests that an empty graph results in an
// empty centrality result, for both the non-normalized and the normalized
// metric. One subtest is run per entry in chanGraphs; the loop stops at the
// first failing subtest.
func TestBetweennessCentralityEmptyGraph(t *testing.T) {
	centralityMetric, err := NewBetweennessCentralityMetric(1)
	require.NoError(
		t, err,
		"construction must succeed with positive number of workers",
	)

	for _, chanGraph := range chanGraphs {
		graph, cleanup, err := chanGraph.genFunc()
		success := t.Run(chanGraph.name, func(t1 *testing.T) {
			// All assertions inside the subtest must go through
			// t1: calling FailNow on the parent t from the subtest
			// goroutine is not allowed by the testing package and
			// would attribute the failure to the wrong test.
			require.NoError(t1, err, "unable to create graph")

			// Release the graph once this subtest is done.
			if cleanup != nil {
				defer cleanup()
			}

			require.NoError(t1, centralityMetric.Refresh(graph))

			// Non-normalized result must be empty.
			centrality := centralityMetric.GetMetric(false)
			require.Equal(t1, 0, len(centrality))

			// Normalized result must be empty as well.
			centrality = centralityMetric.GetMetric(true)
			require.Equal(t1, 0, len(centrality))
		})
		if !success {
			break
		}
	}
}

// TestBetweennessCentralityWithNonEmptyGraph tests the betweenness centrality
// calculation using an example graph. For every worker count and every entry
// in chanGraphs a subtest builds centralityTestGraph, refreshes the metric and
// compares both the normalized and the non-normalized results against the
// expected per-node values.
func TestBetweennessCentralityWithNonEmptyGraph(t *testing.T) {
	workers := []int{1, 3, 9, 100}

	tests := []struct {
		normalize  bool
		centrality []float64
	}{
		{
			normalize:  true,
			centrality: normalizedTestGraphCentrality,
		},
		{
			normalize:  false,
			centrality: testGraphCentrality,
		},
	}

	for _, numWorkers := range workers {
		for _, chanGraph := range chanGraphs {
			numWorkers := numWorkers
			graph, cleanup, err := chanGraph.genFunc()
			require.NoError(t, err, "unable to create graph")

			testName := fmt.Sprintf(
				"%v %d workers", chanGraph.name, numWorkers,
			)

			success := t.Run(testName, func(t1 *testing.T) {
				// Release the graph as soon as this subtest
				// ends. Deferring in the enclosing loop would
				// keep every generated graph alive until the
				// whole test function returns.
				if cleanup != nil {
					defer cleanup()
				}

				// Assertions inside the subtest must use t1:
				// calling FailNow on the parent t from the
				// subtest goroutine is not allowed by the
				// testing package.
				metric, err := NewBetweennessCentralityMetric(
					numWorkers,
				)
				require.NoError(
					t1, err,
					"construction must succeed with "+
						"positive number of workers",
				)

				graphNodes := buildTestGraph(
					t1, graph, centralityTestGraph,
				)

				err = metric.Refresh(graph)
				require.NoError(t1, err)

				for _, expected := range tests {
					centrality := metric.GetMetric(
						expected.normalize,
					)

					// Every node of the test graph must
					// have a centrality entry.
					require.Equal(t1,
						centralityTestGraph.nodes,
						len(centrality),
					)

					// The i-th expected value belongs to
					// the i-th node returned by
					// buildTestGraph.
					for i, c := range expected.centrality {
						nodeID := NewNodeID(
							graphNodes[i],
						)
						result, ok := centrality[nodeID]
						require.True(t1, ok)
						require.Equal(t1, c, result)
					}
				}
			})
			if !success {
				// Skip the remaining graphs for this worker
				// count.
				break
			}
		}
	}
}