1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+ from reasoning_gym .dataset import ProceduralDataset
4
+ import random
5
+ import math
6
+ from fractions import Fraction
7
+
8
+ from ..coaching import BaseCurriculum , RangeAttributeDefinition
9
+ from ..factory import register_dataset
10
+
11
# Name recorded in each item's metadata and passed to register_dataset below.
DATASET_NAME = "coin_flip"
12
+
13
@dataclass
class CoinFlipConfig:
    """Settings controlling how coin-flip probability problems are generated."""

    # Inclusive bounds for the number of coin flips in a problem.
    min_trials: int = 3
    max_trials: int = 15
    # Enable "exactly k heads" problems.
    allow_exact: bool = True
    # Enable "at least k heads" problems.
    allow_at_least: bool = True
    # Base seed for item generation; None leaves the choice to the dataset.
    seed: Optional[int] = None
    # Number of items the dataset exposes.
    size: int = 500

    def validate(self) -> None:
        """Check the numeric fields for consistency.

        Raises:
            AssertionError: if ``size`` or ``min_trials`` is not positive, or
                if ``max_trials`` is smaller than ``min_trials``.
        """
        assert self.size > 0, "size must be positive"
        assert self.min_trials > 0, "min_trials must be positive"
        assert (
            self.max_trials >= self.min_trials
        ), "max_trials must be >= min_trials"
28
+
29
+
30
class CoinFlipDataset(ProceduralDataset):
    """Generates coin-flip probability problems (exact k heads / at-least k heads)."""

    def __init__(self, config: CoinFlipConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict:
        """Generate a single fair-coin probability problem.

        Args:
            idx: Index of the item to generate.

        Returns:
            dict with keys:
                - question: str, the natural-language probability question
                - answer: str, the probability formatted with 10 significant digits
                - metadata: dict with the generation parameters and the exact
                  probability as a numerator / denominator pair
        """
        # Deterministic per-item RNG derived from the dataset seed and the index.
        # NOTE(review): self.seed is not assigned in this class; presumably the
        # base-class constructor sets it from the `seed` argument.
        rng = random.Random(self.seed + idx)

        # Number of flips in this problem (inclusive bounds).
        n = rng.randint(self.config.min_trials, self.config.max_trials)

        available_types = []
        if self.config.allow_exact:
            available_types.append("exact")
        if self.config.allow_at_least:
            available_types.append("at_least")

        # Best-effort fallback: if both problem types are disabled, still
        # produce "exact" problems rather than failing.
        if not available_types:
            available_types = ["exact"]

        problem_type = rng.choice(available_types)

        # Both problem types draw k the same way, so the draw is shared; the
        # order of RNG calls (n, type, k) is what keeps items reproducible.
        k = rng.randint(0, n)

        if problem_type == "exact":
            question = f"What is the probability of getting exactly {k} heads in {n} fair coin flips?"
            prob = self._prob_exact_heads(n, k)
        else:
            question = f"What is the probability of getting at least {k} heads in {n} fair coin flips?"
            prob = self._prob_at_least_heads(n, k)

        answer_str = format(prob, ".10g")

        return {
            "question": question,
            "answer": answer_str,
            "metadata": {
                "source_dataset": DATASET_NAME,
                "source_index": idx,
                "num_tosses": n,
                "k_heads": k,
                "problem_type": problem_type,
                "rational": {
                    "numerator": self._rational_numerator(n, k, problem_type),
                    "denominator": 2 ** n,
                },
            },
        }

    def _prob_exact_heads(self, n: int, k: int) -> float:
        """Return probability of exactly k heads in n fair coin tosses."""
        # Integer true division instead of `comb * 0.5 ** n`: the float product
        # overflows or underflows for large n, while int / int does not.
        return self._rational_numerator(n, k, "exact") / 2 ** n

    def _prob_at_least_heads(self, n: int, k: int) -> float:
        """Return probability of at least k heads in n fair coin tosses."""
        return self._rational_numerator(n, k, "at_least") / 2 ** n

    def _rational_numerator(self, n: int, k: int, problem_type: str) -> int:
        """Return the numerator of the probability over the denominator 2**n.

        For ``"exact"`` this is C(n, k); for any other problem type it is the
        sum of C(n, i) for i from k to n (the "at least k" count).
        """
        if problem_type == "exact":
            return math.comb(n, k)
        return sum(math.comb(n, i) for i in range(k, n + 1))

    def score_answer(self, answer: Optional[str], entry: dict, tol: float = 1e-4) -> float:
        """Compute the reward for an answer against the oracle probability.

        The answer must be a bare number that ``fractions.Fraction`` can parse:
        a decimal such as ``"0.3125"`` or a fraction such as ``"5/16"``.
        Thousands separators (``,``) are removed first. Anything else, including
        a number surrounded by extra text, scores 0.0.

        Args:
            answer: The candidate answer, or None.
            entry: Dataset item; only ``entry["answer"]`` is read.
            tol: Absolute tolerance for a full reward.

        Returns:
            1.0 if the answer is within ``tol`` of the oracle; otherwise a
            partial reward in [0, 1) equal to the length of the common prefix
            of both values (formatted with 10 significant digits) divided by
            the length of the longer string; 0.0 for missing or unparsable
            answers.
        """
        if answer is None or not answer.strip():
            return 0.0

        answer = answer.replace(",", "")
        oracle_answer = entry["answer"].replace(",", "")

        try:
            answer_float = float(Fraction(answer))
            oracle_answer_float = float(Fraction(oracle_answer))
        except (ValueError, ZeroDivisionError, OverflowError):
            # ValueError: not a number; ZeroDivisionError: "1/0";
            # OverflowError: magnitude too large for a float (e.g. "1e400").
            return 0.0

        if abs(answer_float - oracle_answer_float) <= tol:
            return 1.0

        answer_str = f"{answer_float:.10g}"
        oracle_answer_str = f"{oracle_answer_float:.10g}"

        # Partial reward for a matching prefix.
        match_len = 0
        for a_char, o_char in zip(answer_str, oracle_answer_str):
            if a_char != o_char:
                break
            match_len += 1

        # Normalise by the longer string. Dividing by the shorter one would
        # give a wrong short answer such as "0" (oracle "0.3125") a full 1.0.
        return match_len / max(len(oracle_answer_str), len(answer_str))
146
+
147
+
148
class CoinFlipCurriculum(BaseCurriculum):
    """Curriculum whose single attribute scales the number of coin tosses."""

    def __init__(self):
        super().__init__(CoinFlipCurriculum.__name__, CoinFlipConfig)

        # One level per toss count, from 3 up to and including 15.
        toss_levels = list(range(3, 16))

        # Range attribute wired to the config's min_trials / max_trials fields.
        trials_attribute = RangeAttributeDefinition(
            name="num_trials",
            levels=toss_levels,
            default_level=0,
            description="Number of coin tosses (difficulty)",
            lower_field_name="min_trials",
            upper_field_name="max_trials",
        )
        self._define_attributes(trials_attribute)
162
+
163
# Register the dataset class, its config and its curriculum under DATASET_NAME.
register_dataset(
    DATASET_NAME,
    CoinFlipDataset,
    CoinFlipConfig,
    CoinFlipCurriculum,
)
0 commit comments