Skip to content

Commit 4205bff

Browse files
committed
Add CoinFlip probability dataset and curriculum with test script
1 parent b399c65 commit 4205bff

File tree

2 files changed

+262
-0
lines changed

2 files changed

+262
-0
lines changed
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
from dataclasses import dataclass
2+
from typing import Optional
3+
from reasoning_gym.dataset import ProceduralDataset
4+
import random
5+
import math
6+
from fractions import Fraction
7+
8+
from ..coaching import BaseCurriculum, RangeAttributeDefinition
9+
from ..factory import register_dataset
10+
11+
# Name used when registering this dataset and reported as "source_dataset" in item metadata.
DATASET_NAME = "coin_flip"
12+
13+
@dataclass
class CoinFlipConfig:
    """Configuration for coin flip probability task generation."""

    # Inclusive bounds on the number of coin flips per problem.
    min_trials: int = 3
    max_trials: int = 15
    # Whether "exactly k heads" problems may be generated.
    allow_exact: bool = True
    # Whether "at least k heads" problems may be generated.
    allow_at_least: bool = True
    # Base seed for item generation.
    seed: Optional[int] = None
    # Number of items in the dataset.
    size: int = 500

    def validate(self) -> None:
        """Check that the configured values are usable.

        Raises:
            AssertionError: if ``size`` or ``min_trials`` is not positive, or
                if ``max_trials`` is smaller than ``min_trials``.
        """
        assert self.size > 0, "size must be positive"
        assert self.min_trials > 0, "min_trials must be positive"
        assert self.max_trials >= self.min_trials, "max_trials must be >= min_trials"
28+
29+
30+
class CoinFlipDataset(ProceduralDataset):
    """Generates coin-flip probability problems (exactly k heads / at least k heads)."""

    def __init__(self, config: CoinFlipConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict:
        """Generate a single coin-flip probability problem.

        Args:
            idx: Index of the item to generate.

        Returns:
            dict with keys:
                - question: str, the probability question in English
                - answer: str, the probability formatted with 10 significant digits
                - metadata: dict with the generation parameters and the exact
                  probability as ``rational`` (numerator / denominator)
        """
        # Deterministic per-item RNG derived from the base seed and the index.
        # NOTE(review): assumes the base class always leaves an int in
        # self.seed, even when config.seed is None -- confirm.
        rng = random.Random(self.seed + idx)

        # Number of flips for this problem.
        n = rng.randint(self.config.min_trials, self.config.max_trials)

        available_types = []
        if self.config.allow_exact:
            available_types.append("exact")
        if self.config.allow_at_least:
            available_types.append("at_least")
        if not available_types:
            # Both problem types disabled: fall back to "exact" rather than fail.
            available_types = ["exact"]

        problem_type = rng.choice(available_types)
        # Both problem types draw k the same way, so the RNG call order is
        # n, problem type, k.
        k = rng.randint(0, n)

        if problem_type == "exact":
            question = f"What is the probability of getting exactly {k} heads in {n} fair coin flips?"
        else:
            question = f"What is the probability of getting at least {k} heads in {n} fair coin flips?"

        # Count the favourable outcomes once; the float answer and the
        # rational metadata are both derived from it and therefore agree.
        numerator = self._rational_numerator(n, k, problem_type)
        denominator = 2 ** n
        answer_str = format(numerator / denominator, ".10g")

        return {
            "question": question,
            "answer": answer_str,
            "metadata": {
                "source_dataset": DATASET_NAME,
                "source_index": idx,
                "num_tosses": n,
                "k_heads": k,
                "problem_type": problem_type,
                "rational": {
                    "numerator": numerator,
                    "denominator": denominator,
                },
            },
        }

    def _prob_exact_heads(self, n: int, k: int) -> float:
        """Return the probability of exactly k heads in n fair coin tosses."""
        return math.comb(n, k) / 2 ** n

    def _prob_at_least_heads(self, n: int, k: int) -> float:
        """Return the probability of at least k heads in n fair coin tosses."""
        return sum(math.comb(n, i) for i in range(k, n + 1)) / 2 ** n

    def _rational_numerator(self, n: int, k: int, problem_type: str) -> int:
        """Return the number of favourable outcomes out of the 2**n possible ones."""
        if problem_type == "exact":
            return math.comb(n, k)
        return sum(math.comb(n, i) for i in range(k, n + 1))

    def score_answer(self, answer: Optional[str], entry: dict, tol: float = 1e-4) -> float:
        """Compute the reward for an answer against the oracle probability.

        The answer may be a decimal ("0.3125") or a fraction ("5/16"); commas
        are stripped before parsing. Anything that cannot be parsed as a
        single number scores 0.0.

        Args:
            answer: The answer text to score, or None.
            entry: Dataset item whose "answer" field holds the oracle value.
            tol: Absolute tolerance within which the answer earns full reward.

        Returns:
            1.0 when the answer is within ``tol`` of the oracle; otherwise
            the share of leading characters (10-significant-digit formatting)
            that the answer has in common with the oracle; 0.0 for missing
            or unparseable answers.
        """
        if answer is None or not answer.strip():
            return 0.0

        oracle_answer = entry["answer"]
        try:
            answer_float = float(Fraction(answer.replace(",", "")))
            oracle_answer_float = float(Fraction(oracle_answer.replace(",", "")))
        except (ValueError, ZeroDivisionError, OverflowError):
            # OverflowError: values such as "1e400" parse as a Fraction but
            # do not fit in a float.
            return 0.0

        if abs(answer_float - oracle_answer_float) <= tol:
            return 1.0

        answer_str = f"{answer_float:.10g}"
        oracle_answer_str = f"{oracle_answer_float:.10g}"

        # Partial reward for a matching prefix.
        match_len = 0
        for a_char, o_char in zip(answer_str, oracle_answer_str):
            if a_char != o_char:
                break
            match_len += 1

        # Normalise by the longer string: dividing by the shorter one gave a
        # wrong answer that is a prefix of the oracle (or vice versa) the
        # full reward of 1.0.
        return match_len / max(len(oracle_answer_str), len(answer_str))
146+
147+
148+
class CoinFlipCurriculum(BaseCurriculum):
    """Curriculum that allows scaling the number of tosses."""

    def __init__(self):
        super().__init__(CoinFlipCurriculum.__name__, CoinFlipConfig)
        self._define_attributes(
            RangeAttributeDefinition(
                name="num_trials",
                levels=list(range(3, 16)),  # 3 up to and including 15 tosses
                default_level=0,
                description="Number of coin tosses (difficulty)",
                # Config fields that receive the lower / upper end of the range.
                lower_field_name="min_trials",
                upper_field_name="max_trials",
            ),
        )
162+
163+
# Hand the dataset, config and curriculum classes to the factory's register_dataset under DATASET_NAME.
register_dataset(DATASET_NAME, CoinFlipDataset, CoinFlipConfig, CoinFlipCurriculum)

tests/test_coin_flip.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import pytest
2+
from fractions import Fraction
3+
from reasoning_gym.probability.coin_flip import CoinFlipDataset, CoinFlipConfig, CoinFlipCurriculum
4+
5+
def test_coin_flip_config_validation():
    """Test that invalid configs raise errors"""
    invalid_configs = [
        CoinFlipConfig(size=0),
        CoinFlipConfig(min_trials=0),
        CoinFlipConfig(min_trials=5, max_trials=3),
    ]
    for config in invalid_configs:
        # Only validate() sits inside the raises block, so an unexpected
        # AssertionError from construction cannot make the test pass.
        with pytest.raises(AssertionError):
            config.validate()
18+
19+
20+
def test_coin_flip_deterministic():
    """Dataset generates same items with same seed"""
    config = CoinFlipConfig(size=10, seed=42)
    dataset1 = CoinFlipDataset(config)
    dataset2 = CoinFlipDataset(config)

    # Compare item by item; both datasets are indexed with the same position.
    for i in range(len(dataset1)):
        assert dataset1[i] == dataset2[i]
27+
28+
29+
def test_coin_flip_items():
    """Test basic properties of generated items"""
    config = CoinFlipConfig(min_trials=3, max_trials=6, size=7, seed=42)
    dataset = CoinFlipDataset(config)

    for i in range(len(dataset)):
        item = dataset[i]
        assert isinstance(item, dict)
        assert "question" in item
        assert "answer" in item
        assert 0.0 <= float(item["answer"]) <= 1.0
        assert "metadata" in item

        metadata = item["metadata"]
        assert "num_tosses" in metadata
        assert "k_heads" in metadata
        assert "problem_type" in metadata
        assert metadata["problem_type"] in ["exact", "at_least"]
        assert config.min_trials <= metadata["num_tosses"] <= config.max_trials
        assert 0 <= metadata["k_heads"] <= metadata["num_tosses"]

        rational = metadata["rational"]
        assert rational["denominator"] == 2 ** metadata["num_tosses"]
        assert 0 < rational["numerator"] <= rational["denominator"]
        # The formatted answer must agree with the exact rational value.
        assert float(item["answer"]) == pytest.approx(rational["numerator"] / rational["denominator"])
52+
53+
54+
def test_coin_flip_score_answer():
    """Test full and partial reward behavior"""
    config = CoinFlipConfig(size=200, seed=42)
    dataset = CoinFlipDataset(config)

    for i in range(len(dataset)):
        entry = dataset[i]
        answer = entry["answer"]

        # Exact answer -> full reward
        assert dataset.score_answer(answer, entry) == 1.0

        # The same value written as a fraction -> full reward
        rational = entry["metadata"]["rational"]
        as_fraction = f"{rational['numerator']}/{rational['denominator']}"
        assert dataset.score_answer(as_fraction, entry) == 1.0

        # Missing or unparseable answers -> no reward
        assert dataset.score_answer(None, entry) == 0.0
        assert dataset.score_answer("", entry) == 0.0
        assert dataset.score_answer("not a number", entry) == 0.0

        # Slightly wrong answer -> reward stays a valid score in [0, 1].
        # The offset keeps the perturbed value inside the probability range.
        if float(answer) + 0.01 <= 1.0:
            slightly_wrong = str(float(answer) + 0.01)
        else:
            slightly_wrong = str(float(answer) - 0.01)
        reward_partial = dataset.score_answer(slightly_wrong, entry)
        assert 0.0 <= reward_partial <= 1.0
74+
75+
def test_coin_flip_curriculum():
    """Test curriculum generates valid configurations and increments attributes"""
    curriculum = CoinFlipCurriculum()
    base_value = {"size": 100, "seed": 32}

    # Default level: both range ends sit on the first level (3 tosses).
    cfg = curriculum.generate_configuration(base_value)
    assert isinstance(cfg, CoinFlipConfig)
    assert cfg.size == 100
    assert cfg.seed == 32
    assert cfg.min_trials == 3
    assert cfg.max_trials == 3

    # Increment attribute level for num_trials
    curriculum.increment_attr_level("num_trials")
    cfg_inc = curriculum.generate_configuration(base_value)
    assert cfg_inc.min_trials == 3
    assert cfg_inc.max_trials == 4

    # Decrement attribute level
    curriculum.decrement_attr_level("num_trials")
    cfg_dec = curriculum.generate_configuration(base_value)
    assert cfg_dec.min_trials == 3
    assert cfg_dec.max_trials == 3

0 commit comments

Comments
 (0)