
Commit 371e062

debug; blocked by wandb table upload bug
1 parent 6495802 commit 371e062

3 files changed: +75 −34 lines

apps/grpo/main.py

Lines changed: 0 additions & 1 deletion

@@ -437,7 +437,6 @@ async def continuous_rollouts():
         await replay_buffer.add.call_one(episode)
         record_episode_sample("rollout/sample", episode)

-        record_metric("sample/", {}, Reduce.SAMPLE)
         # Log metrics
         rollout_count += 1
         record_metric(

src/forge/observability/metric_actors.py

Lines changed: 7 additions & 4 deletions

@@ -8,8 +8,6 @@
 import logging
 from typing import Any, Dict, Optional

-from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc
-
 from forge.observability.metrics import (
     get_logger_backend_class,
     LoggerBackend,
@@ -20,6 +18,8 @@

 from forge.observability.utils import detect_actor_name_from_call_stack

+from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc
+
 logger = logging.getLogger(__name__)

 _global_logger = None
@@ -365,11 +365,14 @@ def extract_values_from_valuemesh(results):
             return

         # Reduce metrics from states
-        reduced_metrics = reduce_metrics_states(all_local_states)
+        reduced_metrics, reduced_samples = reduce_metrics_states(all_local_states)

         # Log to global backends
         for backend_name, backend in self.global_logger_backends.items():
-            await backend.log_batch(reduced_metrics, step)
+            if reduced_metrics:
+                await backend.log_batch(reduced_metrics, step)
+            if reduced_samples:
+                await backend.log_samples(reduced_samples, step)

     @endpoint
     def has_fetcher(self, name: str | ProcMesh) -> bool:
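With `reduce_metrics_states` now returning a `(metrics, samples)` pair and the logger dispatching each channel only when non-empty, a backend implements whichever channel it consumes (the base class gains a no-op `log_samples`, per the metrics.py diff below). A minimal sketch of the contract, using a hypothetical `PrintBackend` as a stand-in for a real `LoggerBackend` subclass:

import asyncio
from typing import Any, Dict, List


class PrintBackend:
    """Hypothetical stand-in for a LoggerBackend subclass."""

    async def log_batch(self, metrics: List[Any], step: int) -> None:
        print(f"[step {step}] scalars: {metrics}")

    async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
        # Backends that ignore samples can simply inherit the base no-op.
        for key, rows in samples.items():
            print(f"[step {step}] {key}: {len(rows)} sample rows")


async def dispatch(backend: PrintBackend, reduced_metrics, reduced_samples, step: int) -> None:
    # Mirrors the commit's pattern in both the global and per-rank paths:
    # each channel is logged only if non-empty.
    if reduced_metrics:
        await backend.log_batch(reduced_metrics, step)
    if reduced_samples:
        await backend.log_samples(reduced_samples, step)


asyncio.run(dispatch(PrintBackend(), [("loss", 2.0)], {"rollout/sample": [{"id": 1}]}, step=7))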

src/forge/observability/metrics.py

Lines changed: 68 additions & 29 deletions

@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import heapq
+import itertools
 import logging
 import os
 import random
@@ -157,37 +158,55 @@ def record_episode_sample(key: str, episode):
     record_metric(key, sample, Reduce.SAMPLE)


-def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> List[Metric]:
-    """Reduce metric accumulator states to a list of metrics.
+def reduce_metrics_states(
+    states: List[Dict[str, Dict[str, Any]]]
+) -> tuple[List[Metric], Dict[str, list[dict]]]:
+    """
+    Reduce metric accumulator states across ranks into two groups:
+    - scalar metrics (mean/sum/etc.)
+    - sample metrics (list[dict])

-    Can be used when reducing metrics across ranks or services, as merging
-    states is more precise than merging locally reduced metrics.
+    This function merges metric accumulator states from multiple ranks or processes
+    into final reduced values. It automatically distinguishes between scalar reductions
+    (e.g., MEAN, SUM) and structured SAMPLE-type reductions (e.g., per-example dicts).

     Args:
         states (List[Dict[str, Dict[str, Any]]]): List of states of one or more metrics,
             normally retrieved using `forge.observability.metrics.MetricAccumulator.get_state()`.

     Returns:
-        List[Metric]: List of reduced metrics
+        metrics (List[Metric]): list of reduced scalar metrics.
+        samples (Dict[str, list[dict]]): {metric_key: merged_list_of_samples}.

     Example:
-        states = [
-            {"loss": {"count": 5, "sum": 14, "reduction_type": Reduce.MEAN}},
-            {"loss": {"count": 10, "sum": 16, "reduction_type": Reduce.MEAN}},
-        ]
-        reduce_metrics_states(states)
-        >>> [Metric(key="loss", value=2.0, reduction=Reduce.MEAN)]
+        >>> states = [
+        ...     {
+        ...         "loss": {"count": 5, "sum": 14, "reduction_type": Reduce.MEAN},
+        ...         "rollout/sample": {"reduction_type": "sample", "samples": [{"id": 1}]},
+        ...     },
+        ...     {
+        ...         "loss": {"count": 10, "sum": 16, "reduction_type": Reduce.MEAN},
+        ...         "rollout/sample": {"reduction_type": "sample", "samples": [{"id": 2}]},
+        ...     },
+        ... ]
+        >>> metrics, samples = reduce_metrics_states(states)
+        >>> metrics
+        [Metric(key="loss", value=2.0, reduction=Reduce.MEAN)]
+        >>> samples
+        {'rollout/sample': [{'id': 1}, {'id': 2}]}

     Raises:
         ValueError: on mismatched reduction types for the same metric key.
     """
     if not states:
-        return []
+        return [], {}

     # Collect unique keys across all states
     all_keys = set(k for state in states for k in state)

-    reduced_metrics = []
+    samples: Dict[str, list[dict]] = {}
+    reduced_metrics: List[Metric] = []
+
     for key in all_keys:
         metric_states = [state.get(key) for state in states if key in state]
         if not metric_states:
@@ -215,7 +234,11 @@ def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> List[Metric]:
         )
         reduced_metrics.append(metric)

-    return reduced_metrics
+        # Create sample list if this is a SAMPLE reduction
+        if first_reduction_type == Reduce.SAMPLE.value:
+            samples[key] = reduced_value
+
+    return reduced_metrics, samples


 #################
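For intuition, here is a tiny self-contained model of the two-channel reduce: scalar states merge numerically while SAMPLE states concatenate across ranks. The real function also builds `Metric` objects and validates reduction types; all names here are simplified stand-ins, not the commit's code.

from typing import Any, Dict, List, Tuple


def toy_reduce(states: List[Dict[str, Dict[str, Any]]]) -> Tuple[Dict[str, float], Dict[str, list]]:
    scalars: Dict[str, float] = {}
    samples: Dict[str, list] = {}
    for key in {k for state in states for k in state}:
        merged = [state[key] for state in states if key in state]
        if merged[0]["reduction_type"] == "sample":
            # SAMPLE channel: concatenate per-rank sample lists.
            samples[key] = [row for m in merged for row in m["samples"]]
        else:
            # Scalar channel (mean here): combine raw sums and counts,
            # which is more precise than averaging per-rank means.
            scalars[key] = sum(m["sum"] for m in merged) / sum(m["count"] for m in merged)
    return scalars, samples


states = [
    {"loss": {"reduction_type": "mean", "sum": 14, "count": 5},
     "rollout/sample": {"reduction_type": "sample", "samples": [{"id": 1}]}},
    {"loss": {"reduction_type": "mean", "sum": 16, "count": 10},
     "rollout/sample": {"reduction_type": "sample", "samples": [{"id": 2}]}},
]
print(toy_reduce(states))
# ({'loss': 2.0}, {'rollout/sample': [{'id': 1}, {'id': 2}]})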
@@ -290,36 +313,39 @@ def __init__(self, top_k=1, bottom_k=1, key="reward"):
         self.key = key
         self._top_heap = []  # min-heap for top-k
         self._bottom_heap = []  # max-heap for bottom-k (store -value)
+        self._counter = itertools.count()  # tie-breaker id generator

     def filter_append(self, sample: Dict) -> bool:
         val = sample.get(self.key, 0.0)
+        idx = next(self._counter)  # unique tiebreaker

         # If top_k or bottom_k <= 0, it means "disable" that side of filtering (i.e., keep none).
         # maintain top-k
         if self.top_k > 0:
             if len(self._top_heap) < self.top_k:
-                heapq.heappush(self._top_heap, (val, sample))
+                heapq.heappush(self._top_heap, (val, idx, sample))
             else:
-                heapq.heappushpop(self._top_heap, (val, sample))
+                heapq.heappushpop(self._top_heap, (val, idx, sample))

         # maintain bottom-k
         if self.bottom_k > 0:
             if len(self._bottom_heap) < self.bottom_k:
-                heapq.heappush(self._bottom_heap, (-val, sample))
+                heapq.heappush(self._bottom_heap, (-val, idx, sample))
             else:
-                heapq.heappushpop(self._bottom_heap, (-val, sample))
+                heapq.heappushpop(self._bottom_heap, (-val, idx, sample))

         # always return False here because we don't store in samples list
         return False

     def filter_flush(self, samples: List[Dict]) -> List[Dict]:
-        tops = [s for _, s in self._top_heap]
-        bottoms = [s for _, s in self._bottom_heap]
+        tops = [s for _, _, s in self._top_heap]
+        bottoms = [s for _, _, s in self._bottom_heap]
         return bottoms + tops

     def reset(self):
         self._top_heap = []
         self._bottom_heap = []
+        self._counter = itertools.count()


 ################
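The `itertools.count()` tie-breaker fixes a real failure mode: heap entries are tuples, and on a reward tie `heapq` falls through to comparing the next tuple element, which used to be the sample dict itself. Dicts are not orderable, so ties raised `TypeError`. A standalone illustration of the bug and the fix:

import heapq
import itertools

# Without a tie-breaker, equal rewards force heapq to compare the dicts:
heap = []
heapq.heappush(heap, (1.0, {"id": "a"}))
try:
    heapq.heappush(heap, (1.0, {"id": "b"}))  # tie on 1.0 -> compares dicts
except TypeError as e:
    print(e)  # '<' not supported between instances of 'dict' and 'dict'

# With a unique, monotonically increasing counter in the middle, ties are
# resolved by insertion order and the dict is never compared:
counter = itertools.count()
heap = []
heapq.heappush(heap, (1.0, next(counter), {"id": "a"}))
heapq.heappush(heap, (1.0, next(counter), {"id": "b"}))  # fine
print([s["id"] for _, _, s in heap])  # ['a', 'b']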
@@ -766,11 +792,14 @@ async def flush(

         # Log to PER_RANK_REDUCE backends only (NO_REDUCE already logged in push)
         if self.per_rank_reduce_backends:
-            metrics_for_backends = reduce_metrics_states([states])
+            reduced_metrics, reduced_samples = reduce_metrics_states([states])

             # Log to PER_RANK_REDUCE backends
             for backend in self.per_rank_reduce_backends:
-                await backend.log_batch(metrics_for_backends, step)
+                if reduced_metrics:
+                    await backend.log_batch(reduced_metrics, step)
+                if reduced_samples:
+                    await backend.log_samples(reduced_samples, step)

         return states if return_state else {}

@@ -840,6 +869,9 @@ def log_stream(self, metric, step):
         """
         pass

+    async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
+        pass
+
     async def finish(self) -> None:
         pass

@@ -881,13 +913,13 @@ async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
         """Pretty-print sample-level logs to console."""
         if not samples:
             return
-        import pprint
+        import json

         logger.info(f"=== [{self.prefix}] - SAMPLE LOGS STEP {step} ===")
         for key, rows in samples.items():
             logger.info(f"[{key}] ({len(rows)} samples)")
             for sample in rows:
-                pretty = pprint.pformat(sample, indent=4, width=120, compact=True)
+                pretty = json.dumps(sample, indent=2, ensure_ascii=False)
                 logger.info(pretty)
         logger.info("==============================================\n")
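Swapping `pprint` for `json` makes the console output copy-pasteable as JSON, but unlike `pprint.pformat`, `json.dumps` raises on values that are not JSON-serializable. A small illustration, with a `default=str` fallback shown as one assumed-safe option (not part of the commit):

import json

sample = {"prompt": "2+2=?", "reward": 1.0, "tokens": range(3)}  # range isn't JSON-serializable

# json.dumps raises where pprint would happily render anything:
try:
    json.dumps(sample, indent=2, ensure_ascii=False)
except TypeError as e:
    print(e)  # Object of type range is not JSON serializable

# Defensive fallback (an assumption, not the commit's code): stringify
# anything json doesn't understand.
print(json.dumps(sample, indent=2, ensure_ascii=False, default=str))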

@@ -1027,18 +1059,25 @@ async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:

         if not self.run or not samples:
             return
-
         for key, rows in samples.items():
             if not rows:
                 continue
-
             # Create a WandB Table dynamically based on keys of first sample
             columns = list(rows[0].keys())
             table = wandb.Table(columns=columns)
             for sample in rows:
-                table.add_data(*[sample.get(c) for c in columns])
-
-            self.run.log({f"{key}_table": table, "global_step": step})
+                # table.add_data(*[sample.get(c) for c in columns])
+                values = [sample.get(c) for c in columns]
+                logger.info(f"Adding row to {key}_table: {values}")
+                table.add_data(*values)
+            self.run.log(
+                {
+                    f"{key}_step_{step}_table": table,
+                    "_sample_rows_logged": len(rows),
+                    "global_step": step,
+                },
+                commit=True,
+            )
             logger.info(
                 f"WandbBackend: Logged {len(rows)} samples for {key} at step {step}"
             )
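The commit message flags a wandb table upload bug; this hunk works around it by logging each step's table under a fresh, step-suffixed key with `commit=True` instead of re-logging a mutated `Table` under one key, which wandb has historically not handled well. A minimal standalone sketch of that pattern, assuming wandb is installed and you are logged in; the project name and rows are placeholders, not part of the commit:

import wandb

run = wandb.init(project="table-upload-repro")  # hypothetical project name

for step in range(3):
    rows = [{"id": i, "reward": float(i) / (step + 1)} for i in range(4)]
    columns = list(rows[0].keys())
    table = wandb.Table(columns=columns)
    for sample in rows:
        table.add_data(*[sample.get(c) for c in columns])
    # A fresh, step-suffixed key per table means each step uploads a new
    # table object rather than re-logging a previously logged one.
    run.log(
        {f"rollout/sample_step_{step}_table": table, "global_step": step},
        commit=True,
    )

run.finish()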
