diff --git a/autopilot/choice.go b/autopilot/choice.go index c51062d2..e3052535 100644 --- a/autopilot/choice.go +++ b/autopilot/choice.go @@ -10,28 +10,23 @@ import ( // weights left to choose from. var ErrNoPositive = errors.New("no positive weights left") -// weightedChoice draws a random index from the map of channel candidates, with -// a probability propotional to their score. -func weightedChoice(s map[NodeID]*AttachmentDirective) (NodeID, error) { - // Calculate the sum of scores found in the map. +// weightedChoice draws a random index from the slice of weights, with a +// probability propotional to the weight at the given index. +func weightedChoice(w []float64) (int, error) { + // Calculate the sum of weights. var sum float64 - for _, v := range s { - sum += v.Score + for _, v := range w { + sum += v } if sum <= 0 { - return NodeID{}, ErrNoPositive + return 0, ErrNoPositive } - // Create a map of normalized scores such, that they sum to 1.0. - norm := make(map[NodeID]float64) - for k, v := range s { - norm[k] = v.Score / sum - } - - // Pick a random number in the range [0.0, 1.0), and iterate the map - // until the number goes below 0. This means that each index is picked - // with a probablity equal to their normalized score. + // Pick a random number in the range [0.0, 1.0) and multiply it with + // the sum of weights. Then we'll iterate the weights until the number + // goes below 0. This means that each index is picked with a probablity + // equal to their normalized score. // // Example: // Items with scores [1, 5, 2, 2] @@ -40,14 +35,15 @@ func weightedChoice(s map[NodeID]*AttachmentDirective) (NodeID, error) { // in [0, 1.0]: // [|-0.1-||-----0.5-----||--0.2--||--0.2--|] // The following loop is now equivalent to "hitting" the intervals. - r := rand.Float64() - for k, v := range norm { - r -= v + r := rand.Float64() * sum + for i := range w { + r -= w[i] if r <= 0 { - return k, nil + return i, nil } } - return NodeID{}, fmt.Errorf("unable to make choice") + + return 0, fmt.Errorf("unable to make choice") } // chooseN picks at random min[n, len(s)] nodes if from the @@ -55,25 +51,36 @@ func weightedChoice(s map[NodeID]*AttachmentDirective) (NodeID, error) { func chooseN(n uint32, s map[NodeID]*AttachmentDirective) ( map[NodeID]*AttachmentDirective, error) { - // Keep a map of nodes not yet choosen. - rem := make(map[NodeID]*AttachmentDirective) + // Keep track of the number of nodes not yet chosen, in addition to + // their scores and NodeIDs. + rem := len(s) + scores := make([]float64, len(s)) + nodeIDs := make([]NodeID, len(s)) + i := 0 for k, v := range s { - rem[k] = v + scores[i] = v.Score + nodeIDs[i] = k + i++ } // Pick a weighted choice from the remaining nodes as long as there are // nodes left, and we haven't already picked n. chosen := make(map[NodeID]*AttachmentDirective) - for len(chosen) < int(n) && len(rem) > 0 { - choice, err := weightedChoice(rem) + for len(chosen) < int(n) && rem > 0 { + choice, err := weightedChoice(scores) if err == ErrNoPositive { return chosen, nil } else if err != nil { return nil, err } - chosen[choice] = rem[choice] - delete(rem, choice) + nID := nodeIDs[choice] + + chosen[nID] = s[nID] + + // We set the score of the chosen node to 0, so it won't be + // picked the next iteration. + scores[choice] = 0 } return chosen, nil