|
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
7 | 7 | import heapq |
| 8 | +import itertools |
8 | 9 | import logging |
9 | 10 | import os |
10 | 11 | import random |
@@ -157,37 +158,55 @@ def record_episode_sample(key: str, episode): |
157 | 158 | record_metric(key, sample, Reduce.SAMPLE) |
158 | 159 |
|
159 | 160 |
|
160 | | -def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> List[Metric]: |
161 | | - """Reduce metric accumulators states to a list of metrics. |
| 161 | +def reduce_metrics_states( |
| 162 | + states: List[Dict[str, Dict[str, Any]]] |
| 163 | +) -> tuple[List[Metric], Dict[str, list[dict]]]: |
| 164 | + """ |
| 165 | + Reduce metric accumulator states across ranks into two groups: |
| 166 | + - scalar metrics (mean/sum/etc.) |
| 167 | + - sample metrics (list[dict]) |
162 | 168 |
|
163 | | - Can be used when reducing metrics across ranks or services, as merging |
164 | | - states is more precise than merging locally reduced metrics. |
| 169 | + This function merges metric accumulator states from multiple ranks or processes |
| 170 | + into final reduced values. It automatically distinguishes between scalar reductions |
| 171 | + (e.g., MEAN, SUM) and structured SAMPLE-type reductions (e.g., per-example dicts). |
165 | 172 |
|
166 | 173 | Args: |
167 | 174 | states (List[Dict[str, Dict[str, Any]]]): List of state of one or more metrics, |
168 | 175 | normally retrieved using `forge.observability.metrics.MetricAccumulator.get_state()`. |
169 | 176 |
|
170 | 177 | Returns: |
171 | | - List[Metric]: List of reduced metrics |
| 178 | + metrics: List[Metric], List of reduced metrics |
| 179 | + samples: Dict[str, list[dict]], {metric_key: merged_list_of_samples} |
172 | 180 |
|
173 | 181 | Example: |
174 | | - states = [ |
175 | | - {"loss": {"count": 5, "sum": 14, "reduction_type": Reduce.MEAN}}, |
176 | | - {"loss": {"count": 10, "sum": 16, "reduction_type": Reduce.MEAN}}, |
177 | | - ] |
178 | | - reduce_metrics_states(states) |
179 | | - >>> [Metric(key="loss", value=2.0, reduction=Reduce.MEAN)] |
| 182 | + >>> states = [ |
| 183 | + ... { |
| 184 | + ... "loss": {"count": 5, "sum": 14, "reduction_type": Reduce.MEAN}, |
| 185 | + ... "rollout/sample": {"reduction_type": "sample", "samples": [{"id": 1}]}, |
| 186 | + ... }, |
| 187 | + ... { |
| 188 | + ... "loss": {"count": 10, "sum": 16, "reduction_type": Reduce.MEAN}, |
| 189 | + ... "rollout/sample": {"reduction_type": "sample", "samples": [{"id": 2}]}, |
| 190 | + ... }, |
| 191 | + ... ] |
| 192 | + >>> metrics, samples = reduce_metrics_states(states) |
| 193 | + >>> metrics |
| 194 | + [Metric(key="loss", value=2.0, reduction=Reduce.MEAN)] |
| 195 | + >>> samples |
| 196 | + {'rollout/sample': [{'id': 1}, {'id': 2}]} |
180 | 197 |
|
181 | 198 | Raises: |
182 | 199 | ValueError: on mismatched reduction types for the same metric key. |
183 | 200 | """ |
184 | 201 | if not states: |
185 | | - return [] |
| 202 | + return [], {} |
186 | 203 |
|
187 | 204 | # Collect unique keys across all states |
188 | 205 | all_keys = set(k for state in states for k in state) |
189 | 206 |
|
190 | | - reduced_metrics = [] |
| 207 | + samples: Dict[str, list[dict]] = {} |
| 208 | + reduced_metrics: List[Metric] = [] |
| 209 | + |
191 | 210 | for key in all_keys: |
192 | 211 | metric_states = [state.get(key) for state in states if key in state] |
193 | 212 | if not metric_states: |
@@ -215,7 +234,11 @@ def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> List[Metri |
215 | 234 | ) |
216 | 235 | reduced_metrics.append(metric) |
217 | 236 |
|
218 | | - return reduced_metrics |
| 237 | + # Create sample list if this is a SAMPLE reduction |
| 238 | + if first_reduction_type == Reduce.SAMPLE.value: |
| 239 | + samples[key] = reduced_value |
| 240 | + |
| 241 | + return reduced_metrics, samples |
219 | 242 |
|
220 | 243 |
|
221 | 244 | ################# |
@@ -290,36 +313,39 @@ def __init__(self, top_k=1, bottom_k=1, key="reward"): |
290 | 313 | self.key = key |
291 | 314 | self._top_heap = [] # min-heap for top-k |
292 | 315 | self._bottom_heap = [] # max-heap for bottom-k (store -value) |
| 316 | + self._counter = itertools.count() # tie-breaker id generator |
293 | 317 |
|
294 | 318 | def filter_append(self, sample: Dict) -> bool: |
295 | 319 | val = sample.get(self.key, 0.0) |
| 320 | + idx = next(self._counter) # unique tiebreaker |
296 | 321 |
|
297 | 322 | # If top_k or bottom_k <= 0, it means "disable" that side of filtering (i.e., keep none). |
298 | 323 | # maintain top-k |
299 | 324 | if self.top_k > 0: |
300 | 325 | if len(self._top_heap) < self.top_k: |
301 | | - heapq.heappush(self._top_heap, (val, sample)) |
| 326 | + heapq.heappush(self._top_heap, (val, idx, sample)) |
302 | 327 | else: |
303 | | - heapq.heappushpop(self._top_heap, (val, sample)) |
| 328 | + heapq.heappushpop(self._top_heap, (val, idx, sample)) |
304 | 329 |
|
305 | 330 | # maintain bottom-k |
306 | 331 | if self.bottom_k > 0: |
307 | 332 | if len(self._bottom_heap) < self.bottom_k: |
308 | | - heapq.heappush(self._bottom_heap, (-val, sample)) |
| 333 | + heapq.heappush(self._bottom_heap, (-val, idx, sample)) |
309 | 334 | else: |
310 | | - heapq.heappushpop(self._bottom_heap, (-val, sample)) |
| 335 | + heapq.heappushpop(self._bottom_heap, (-val, idx, sample)) |
311 | 336 |
|
312 | 337 | # always return False here because we don't store in samples list |
313 | 338 | return False |
314 | 339 |
|
315 | 340 | def filter_flush(self, samples: List[Dict]) -> List[Dict]: |
316 | | - tops = [s for _, s in self._top_heap] |
317 | | - bottoms = [s for _, s in self._bottom_heap] |
| 341 | + tops = [s for _, _, s in self._top_heap] |
| 342 | + bottoms = [s for _, _, s in self._bottom_heap] |
318 | 343 | return bottoms + tops |
319 | 344 |
|
320 | 345 | def reset(self): |
321 | 346 | self._top_heap = [] |
322 | 347 | self._bottom_heap = [] |
| 348 | + self._counter = itertools.count() |
323 | 349 |
|
324 | 350 |
|
325 | 351 | ################ |
@@ -766,11 +792,14 @@ async def flush( |
766 | 792 |
|
767 | 793 | # Log to PER_RANK_REDUCE backends only (NO_REDUCE already logged in push) |
768 | 794 | if self.per_rank_reduce_backends: |
769 | | - metrics_for_backends = reduce_metrics_states([states]) |
| 795 | + reduced_metrics, reduced_samples = reduce_metrics_states([states]) |
770 | 796 |
|
771 | 797 | # Log to PER_RANK_REDUCE backends |
772 | 798 | for backend in self.per_rank_reduce_backends: |
773 | | - await backend.log_batch(metrics_for_backends, step) |
| 799 | + if reduced_metrics: |
| 800 | + await backend.log_batch(reduced_metrics, step) |
| 801 | + if reduced_samples: |
| 802 | + await backend.log_samples(reduced_samples, step) |
774 | 803 |
|
775 | 804 | return states if return_state else {} |
776 | 805 |
|
@@ -840,6 +869,9 @@ def log_stream(self, metric, step): |
840 | 869 | """ |
841 | 870 | pass |
842 | 871 |
|
| 872 | + async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None: |
| 873 | + pass |
| 874 | + |
843 | 875 | async def finish(self) -> None: |
844 | 876 | pass |
845 | 877 |
|
@@ -881,13 +913,13 @@ async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None: |
881 | 913 | """Pretty-print sample-level logs to console.""" |
882 | 914 | if not samples: |
883 | 915 | return |
884 | | - import pprint |
| 916 | + import json |
885 | 917 |
|
886 | 918 | logger.info(f"=== [{self.prefix}] - SAMPLE LOGS STEP {step} ===") |
887 | 919 | for key, rows in samples.items(): |
888 | 920 | logger.info(f"[{key}] ({len(rows)} samples)") |
889 | 921 | for sample in rows: |
890 | | - pretty = pprint.pformat(sample, indent=4, width=120, compact=True) |
| 922 | + pretty = json.dumps(sample, indent=2, ensure_ascii=False) |
891 | 923 | logger.info(pretty) |
892 | 924 | logger.info("==============================================\n") |
893 | 925 |
|
@@ -1027,18 +1059,25 @@ async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None: |
1027 | 1059 |
|
1028 | 1060 | if not self.run or not samples: |
1029 | 1061 | return |
1030 | | - |
1031 | 1062 | for key, rows in samples.items(): |
1032 | 1063 | if not rows: |
1033 | 1064 | continue |
1034 | | - |
1035 | 1065 | # Create a WandB Table dynamically based on keys of first sample |
1036 | 1066 | columns = list(rows[0].keys()) |
1037 | 1067 | table = wandb.Table(columns=columns) |
1038 | 1068 | for sample in rows: |
1039 | | - table.add_data(*[sample.get(c) for c in columns]) |
1040 | | - |
1041 | | - self.run.log({f"{key}_table": table, "global_step": step}) |
| 1069 | + # table.add_data(*[sample.get(c) for c in columns]) |
| 1070 | + values = [sample.get(c) for c in columns] |
| 1071 | + logger.info(f"Adding row to {key}_table: {values}") |
| 1072 | + table.add_data(*values) |
| 1073 | + self.run.log( |
| 1074 | + { |
| 1075 | + f"{key}_step_{step}_table": table, |
| 1076 | + "_sample_rows_logged": len(rows), |
| 1077 | + "global_step": step, |
| 1078 | + }, |
| 1079 | + commit=True, |
| 1080 | + ) |
1042 | 1081 | logger.info( |
1043 | 1082 | f"WandbBackend: Logged {len(rows)} samples for {key} at step {step}" |
1044 | 1083 | ) |
|
0 commit comments