|  | 
| 8 | 8 | import itertools | 
| 9 | 9 | import logging | 
| 10 | 10 | import os | 
| 11 |  | -import random | 
| 12 | 11 | from abc import ABC, abstractmethod | 
| 13 | 12 | from dataclasses import dataclass | 
| 14 | 13 | from datetime import datetime | 
| @@ -133,7 +132,6 @@ def record_episode_sample(key: str, episode): | 
| 133 | 132 |     Args: | 
| 134 | 133 |         key (str): logging prefix (e.g. "rollout/sample"). | 
| 135 | 134 |         episode (Episode): episode object with filled attributes. | 
| 136 |  | -        reward_breakdown (dict[str, float]): per-function rewards, e.g. {"MathReward": 0.8, "FormatReward": 1.0}. | 
| 137 | 135 |     """ | 
| 138 | 136 |     sample = { | 
| 139 | 137 |         "episode_id": episode.episode_id, | 
| @@ -246,65 +244,7 @@ def reduce_metrics_states( | 
| 246 | 244 | ################# | 
| 247 | 245 | 
 | 
| 248 | 246 | 
 | 
| 249 |  | -class SampleFilter(ABC): | 
| 250 |  | -    """Abstract base class for sample filtering.""" | 
| 251 |  | - | 
| 252 |  | -    @abstractmethod | 
| 253 |  | -    def filter_append(self, sample: Dict) -> bool: | 
| 254 |  | -        """ | 
| 255 |  | -        Decide whether a sample should be kept at append time. | 
| 256 |  | -        Return True if the sample should be stored, False otherwise. | 
| 257 |  | -        """ | 
| 258 |  | -        pass | 
| 259 |  | - | 
| 260 |  | -    def filter_flush(self, samples: List[Dict]) -> List[Dict]: | 
| 261 |  | -        """ | 
| 262 |  | -        Optionally filter or transform the collected samples at flush time. | 
| 263 |  | -        Default: return the samples unchanged. | 
| 264 |  | -        """ | 
| 265 |  | -        return samples | 
| 266 |  | - | 
| 267 |  | -    def reset(self) -> None: | 
| 268 |  | -        """Clears for next accumulation cycle.""" | 
| 269 |  | -        pass | 
| 270 |  | - | 
| 271 |  | - | 
| 272 |  | -class RandomRatioFilter: | 
| 273 |  | -    """Randomly keep a fraction of samples.""" | 
| 274 |  | - | 
| 275 |  | -    def __init__(self, ratio=0.05): | 
| 276 |  | -        self.ratio = ratio | 
| 277 |  | - | 
| 278 |  | -    def filter_append(self, sample): | 
| 279 |  | -        return random.random() < self.ratio | 
| 280 |  | - | 
| 281 |  | - | 
| 282 |  | -class RewardThresholdFilter: | 
| 283 |  | -    """ | 
| 284 |  | -    Keep samples only if their reward is < lt or > gt (depending on which bound is set). | 
| 285 |  | -    If a bound is None, that side of the filter is disabled. | 
| 286 |  | -    """ | 
| 287 |  | - | 
| 288 |  | -    def __init__(self, lt=None, gt=None): | 
| 289 |  | -        self.lt = lt | 
| 290 |  | -        self.gt = gt | 
| 291 |  | - | 
| 292 |  | -    def filter_append(self, sample): | 
| 293 |  | -        r = sample.get("reward", 0.0) | 
| 294 |  | - | 
| 295 |  | -        # If lt is set: drop samples with reward >= lt | 
| 296 |  | -        if self.lt is not None and r >= self.lt: | 
| 297 |  | -            return False | 
| 298 |  | - | 
| 299 |  | -        # If gt is set: drop samples with reward <= gt | 
| 300 |  | -        if self.gt is not None and r <= self.gt: | 
| 301 |  | -            return False | 
| 302 |  | - | 
| 303 |  | -        # Otherwise, keep this sample | 
| 304 |  | -        return True | 
| 305 |  | - | 
| 306 |  | - | 
| 307 |  | -class TopBottomKFilter(SampleFilter): | 
|  | 247 | +class TopBottomKFilter: | 
| 308 | 248 |     """Keep the top-k and bottom-k samples by a given key (e.g., reward).""" | 
| 309 | 249 | 
 | 
| 310 | 250 |     def __init__(self, top_k=1, bottom_k=1, key="reward"): | 
| @@ -539,18 +479,19 @@ def reset(self) -> None: | 
| 539 | 479 | class SampleAccumulator(MetricAccumulator): | 
| 540 | 480 |     """Accumulator for sample-level metrics (e.g., prompt/response/reward dicts). | 
| 541 | 481 | 
 | 
| 542 |  | -    Optionally uses a SampleFilter to decide what to keep at append/flush time. | 
|  | 482 | +    Optionally uses a sample filter to decide what to keep at append/flush time. | 
| 543 | 483 |     """ | 
| 544 | 484 | 
 | 
| 545 | 485 |     def __init__( | 
| 546 |  | -        self, reduction: Reduce, filter: SampleFilter | None = TopBottomKFilter() | 
|  | 486 | +        self, reduction: Reduce, filter: TopBottomKFilter | None = TopBottomKFilter() | 
| 547 | 487 |     ): | 
| 548 | 488 |         super().__init__(reduction) | 
| 549 | 489 |         self.samples: List[Dict[str, Any]] = [] | 
| 550 | 490 |         self.filter = filter | 
| 551 | 491 | 
 | 
| 552 | 492 |     def append(self, value: dict) -> None: | 
| 553 |  | -        assert isinstance(value, dict) | 
|  | 493 | +        if not isinstance(value, dict): | 
|  | 494 | +            raise ValueError(f"Expected dict, got {type(value)}") | 
| 554 | 495 | 
 | 
| 555 | 496 |         # If filter is provided, only keep the sample if filter_append returns True | 
| 556 | 497 |         if self.filter: | 
|  | 
0 commit comments