Skip to content

Commit 14c1924

Browse files
authored
Create get_by_distribution.py
1 parent f1c18e3 commit 14c1924

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed

get_by_distribution.py

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
#!/usr/bin/env python
2+
3+
"""
4+
Provide convenient way to select list elements with different probability.
5+
"""
6+
7+
import random
8+
import bisect
9+
import collections
10+
import logging
11+
12+
13+
def set_log():
14+
logger = logging.getLogger()
15+
logger.setLevel(LOG_LEVEL)
16+
17+
fm = logging.Formatter('%(filename)s [LINE:%(lineno)d]# %(levelname)-8s [%(asctime)s] %(message)s')
18+
19+
console = logging.StreamHandler()
20+
console.setLevel(LOG_LEVEL)
21+
console.setFormatter(fm)
22+
23+
logger.addHandler(console)
24+
25+
26+
def cdf(weights):
27+
total = sum(weights)
28+
result = []
29+
cum_sum = 0
30+
31+
for w in weights:
32+
cum_sum += w
33+
result.append(cum_sum/total)
34+
35+
return result
36+
37+
38+
def get_by_distribution(collection, weights):
39+
assert len(collection) == len(weights)
40+
41+
cdf_values = cdf(weights)
42+
x = random.random()
43+
idx = bisect.bisect(cdf_values, x)
44+
logging.debug("cdf_values: %s x: %d idx: %d", cdf_values, x, idx)
45+
46+
return collection[idx]
47+
48+
49+
if __name__ == '__main__':
50+
population = 'ABC'
51+
distribution = [0.3, 0.4, 0.3]
52+
53+
LOG_LEVEL = 'INFO' # 'DEBUG'
54+
set_log()
55+
56+
counts = collections.defaultdict(int)
57+
for i in range(10000):
58+
counts[get_by_distribution(population, distribution)] += 1
59+
logging.info(counts)
60+
61+
# % test.py
62+
# defaultdict(<type 'int'>, {'A': 3066, 'C': 2964, 'B': 3970})

0 commit comments

Comments
 (0)