Skip to content

Commit 32008c0

Browse files
committed
create init and add difficulty metadata for curriculum
1 parent 4205bff commit 32008c0

File tree

3 files changed

+23
-8
lines changed

3 files changed

+23
-8
lines changed
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""
2+
Probability reasoning tasks.
3+
"""
4+
5+
from .coin_flip import CoinFlipDataset, CoinFlipConfig, CoinFlipCurriculum
6+
7+
__all__ = [
8+
"CoinFlipDataset",
9+
"CoinFlipConfig",
10+
"CoinFlipCurriculum"
11+
]

reasoning_gym/probability/coin_flip.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def validate(self) -> None:
2525
assert self.size > 0, "size must be positive"
2626
assert self.min_trials > 0, "min_trials must be positive"
2727
assert self.max_trials >= self.min_trials, "max_trials must be >= min_trials"
28+
assert self.allow_exact or self.allow_at_least, "At least one of allow_exact or allow_at_least must be True"
2829

2930

3031
class CoinFlipDataset(ProceduralDataset):
@@ -55,9 +56,6 @@ def __getitem__(self, idx: int) -> dict:
5556
available_types.append("exact")
5657
if self.config.allow_at_least:
5758
available_types.append("at_least")
58-
59-
if not available_types:
60-
available_types = ["exact"]
6159

6260
problem_type = rng.choice(available_types)
6361

@@ -79,13 +77,16 @@ def __getitem__(self, idx: int) -> dict:
7977
"metadata": {
8078
"source_dataset": DATASET_NAME,
8179
"source_index": idx,
82-
"num_tosses": n,
80+
"num_trials": n,
8381
"k_heads": k,
8482
"problem_type": problem_type,
8583
"rational": {
8684
"numerator": self._rational_numerator(n, k, problem_type),
8785
"denominator": 2 ** n,
88-
}
86+
},
87+
"difficulty": {
88+
"num_trials": (self.config.min_trials, self.config.max_trials),
89+
},
8990
}
9091
}
9192

tests/test_coin_flip.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22
from fractions import Fraction
3-
from reasoning_gym.probability.coin_flip import CoinFlipDataset, CoinFlipConfig, CoinFlipCurriculum
3+
from reasoning_gym.probability import CoinFlipDataset, CoinFlipConfig, CoinFlipCurriculum
44

55
def test_coin_flip_config_validation():
66
"""Test that invalid configs raise errors"""
@@ -16,6 +16,9 @@ def test_coin_flip_config_validation():
1616
config = CoinFlipConfig(min_trials=5, max_trials=3)
1717
config.validate()
1818

19+
with pytest.raises(AssertionError):
20+
config = CoinFlipConfig(allow_exact=False, allow_at_least=False)
21+
config.validate()
1922

2023
def test_coin_flip_deterministic():
2124
"""Dataset generates same items with same seed"""
@@ -41,13 +44,13 @@ def test_coin_flip_items():
4144

4245

4346
metadata = item["metadata"]
44-
assert "num_tosses" in metadata
47+
assert "num_trials" in metadata
4548
assert "k_heads" in metadata
4649
assert "problem_type" in metadata
4750
assert metadata["problem_type"] in ["exact", "at_least"]
4851

4952
rational = metadata["rational"]
50-
assert rational["denominator"] == 2 ** metadata["num_tosses"]
53+
assert rational["denominator"] == 2 ** metadata["num_trials"]
5154
assert rational["numerator"] > 0
5255

5356

0 commit comments

Comments
 (0)