Skip to content

Commit 950f301

Browse files
committed
prevent negative weights in Choice
These should break things (by design), so let's just guard against it at the API interface via type system. Things are kept internally as ints because most golang stdlib functions expect that, so we can avoid casting everywhere.
1 parent def26bd commit 950f301

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

weightedrand.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import (
1818
// Choice is a generic wrapper that can be used to add weights for any object
1919
type Choice struct {
2020
Item interface{}
21-
Weight int
21+
Weight uint
2222
}
2323

2424
// A Chooser caches many possible Choices in a structure designed to improve
@@ -38,7 +38,7 @@ func NewChooser(cs ...Choice) Chooser {
3838
totals := make([]int, n, n)
3939
runningTotal := 0
4040
for i, c := range cs {
41-
runningTotal += c.Weight
41+
runningTotal += int(c.Weight)
4242
totals[i] = runningTotal
4343
}
4444
return Chooser{data: cs, totals: totals, max: runningTotal}

weightedrand_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ func mockChoices(n int) []Choice {
1616
for i := 0; i < n; i++ {
1717
s := "⚽️"
1818
w := rand.Intn(10)
19-
c := Choice{Item: s, Weight: w}
19+
c := Choice{Item: s, Weight: uint(w)}
2020
choices = append(choices, c)
2121
}
2222
return choices
@@ -35,7 +35,7 @@ func TestWeightedChoice(t *testing.T) {
3535
presorted data. */
3636
list := rand.Perm(10)
3737
for _, v := range list {
38-
c := Choice{Weight: v, Item: v}
38+
c := Choice{Weight: uint(v), Item: v}
3939
choices = append(choices, c)
4040
}
4141
t.Log("FYI mocked choices of", choices)
@@ -59,8 +59,8 @@ func TestWeightedChoice(t *testing.T) {
5959
for i, c := range choices[0 : len(choices)-1] {
6060
next := choices[i+1]
6161
cw, nw := c.Weight, next.Weight
62-
if !(chosenCount[cw] < chosenCount[nw]) {
63-
t.Error("Value not lesser", cw, nw, chosenCount[cw], chosenCount[nw])
62+
if !(chosenCount[int(cw)] < chosenCount[int(nw)]) {
63+
t.Error("Value not lesser", cw, nw, chosenCount[int(cw)], chosenCount[int(nw)])
6464
}
6565
}
6666

0 commit comments

Comments
 (0)