Skip to content

Commit 492df7a

Browse files
committed
a working version
1 parent 307e037 commit 492df7a

File tree

3 files changed

+42
-148
lines changed

3 files changed

+42
-148
lines changed

src/forge/observability/metrics.py

Lines changed: 5 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import itertools
99
import logging
1010
import os
11-
import random
1211
from abc import ABC, abstractmethod
1312
from dataclasses import dataclass
1413
from datetime import datetime
@@ -133,7 +132,6 @@ def record_episode_sample(key: str, episode):
133132
Args:
134133
key (str): logging prefix (e.g. "rollout/sample").
135134
episode (Episode): episode object with filled attributes.
136-
reward_breakdown (dict[str, float]): per-function rewards, e.g. {"MathReward": 0.8, "FormatReward": 1.0}.
137135
"""
138136
sample = {
139137
"episode_id": episode.episode_id,
@@ -246,65 +244,7 @@ def reduce_metrics_states(
246244
#################
247245

248246

249-
class SampleFilter(ABC):
250-
"""Abstract base class for sample filtering."""
251-
252-
@abstractmethod
253-
def filter_append(self, sample: Dict) -> bool:
254-
"""
255-
Decide whether a sample should be kept at append time.
256-
Return True if the sample should be stored, False otherwise.
257-
"""
258-
pass
259-
260-
def filter_flush(self, samples: List[Dict]) -> List[Dict]:
261-
"""
262-
Optionally filter or transform the collected samples at flush time.
263-
Default: return the samples unchanged.
264-
"""
265-
return samples
266-
267-
def reset(self) -> None:
268-
"""Clears for next accumulation cycle."""
269-
pass
270-
271-
272-
class RandomRatioFilter:
273-
"""Randomly keep a fraction of samples."""
274-
275-
def __init__(self, ratio=0.05):
276-
self.ratio = ratio
277-
278-
def filter_append(self, sample):
279-
return random.random() < self.ratio
280-
281-
282-
class RewardThresholdFilter:
283-
"""
284-
Keep samples only if their reward is < lt or > gt (depending on which bound is set).
285-
If a bound is None, that side of the filter is disabled.
286-
"""
287-
288-
def __init__(self, lt=None, gt=None):
289-
self.lt = lt
290-
self.gt = gt
291-
292-
def filter_append(self, sample):
293-
r = sample.get("reward", 0.0)
294-
295-
# If lt is set: drop samples with reward >= lt
296-
if self.lt is not None and r >= self.lt:
297-
return False
298-
299-
# If gt is set: drop samples with reward <= gt
300-
if self.gt is not None and r <= self.gt:
301-
return False
302-
303-
# Otherwise, keep this sample
304-
return True
305-
306-
307-
class TopBottomKFilter(SampleFilter):
247+
class TopBottomKFilter:
308248
"""Keep the top-k and bottom-k samples by a given key (e.g., reward)."""
309249

310250
def __init__(self, top_k=1, bottom_k=1, key="reward"):
@@ -539,18 +479,19 @@ def reset(self) -> None:
539479
class SampleAccumulator(MetricAccumulator):
540480
"""Accumulator for sample-level metrics (e.g., prompt/response/reward dicts).
541481
542-
Optionally uses a SampleFilter to decide what to keep at append/flush time.
482+
Optionally uses a sample filter to decide what to keep at append/flush time.
543483
"""
544484

545485
def __init__(
546-
self, reduction: Reduce, filter: SampleFilter | None = TopBottomKFilter()
486+
self, reduction: Reduce, filter: TopBottomKFilter | None = TopBottomKFilter()
547487
):
548488
super().__init__(reduction)
549489
self.samples: List[Dict[str, Any]] = []
550490
self.filter = filter
551491

552492
def append(self, value: dict) -> None:
553-
assert isinstance(value, dict)
493+
if not isinstance(value, dict):
494+
raise ValueError(f"Expected dict, got {type(value)}")
554495

555496
# If filter is provided, only keep the sample if filter_append returns True
556497
if self.filter:

tests/unit_tests/data/test_metrics_aggregator.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,43 @@ def test_handler_replacement_warning(self, caplog):
246246
assert len(caplog.records) == 1
247247
assert "Replacing handler for AggregationType.SUM" in caplog.records[0].message
248248

249+
def test_sample_accumulator_with_topbottom_filter(self):
250+
"""Ensure SampleAccumulator integrates with TopBottomKFilter correctly."""
251+
from forge.observability.metrics import (
252+
Reduce,
253+
SampleAccumulator,
254+
TopBottomKFilter,
255+
)
256+
257+
f = TopBottomKFilter(top_k=2, bottom_k=1, key="reward")
258+
acc = SampleAccumulator(Reduce.SAMPLE, filter=f)
259+
260+
rewards = [0.1, 0.9, 0.5, 0.7, 0.3]
261+
for r in rewards:
262+
acc.append({"reward": r, "prompt": "Q", "response": "A"})
263+
264+
result = acc.get_value()
265+
result_rewards = sorted(s["reward"] for s in result)
266+
267+
# Expect bottom-1 (0.1) and top-2 (0.7, 0.9)
268+
assert result_rewards == [0.1, 0.7, 0.9]
269+
270+
def test_sample_accumulator_no_filter_returns_all(self):
271+
"""Ensure SampleAccumulator without a filter returns all samples."""
272+
from forge.observability.metrics import Reduce, SampleAccumulator
273+
274+
acc = SampleAccumulator(Reduce.SAMPLE, filter=None)
275+
276+
samples = [
277+
{"reward": r, "prompt": "Q", "response": "A"} for r in [0.2, 0.4, 0.6]
278+
]
279+
for s in samples:
280+
acc.append(s)
281+
282+
result = acc.get_value()
283+
assert len(result) == len(samples)
284+
assert [s["reward"] for s in result] == [0.2, 0.4, 0.6]
285+
249286

250287
class TestDistributedMetricsAggregator(FSDPTest):
251288
"""Distributed tests for MetricsAggregator using FSDPTest infrastructure."""

tests/unit_tests/data/test_metrics_sampler.py

Lines changed: 0 additions & 84 deletions
This file was deleted.

0 commit comments

Comments
 (0)