Commit 0c67e40

felipemello1 and Felipe Mello authored
Metric Logging updates 6/N - Enable SFT metrics / delete old file (#396)
Co-authored-by: Felipe Mello <[email protected]>
1 parent 53b9547 commit 0c67e40

22 files changed: +342, −1,870 lines

apps/sft/llama3_8b.yaml

Lines changed: 7 additions & 1 deletion
@@ -46,7 +46,7 @@ parallelism:
 checkpoint:
   enable: true
   folder: ./checkpoint # The folder to save checkpoints to.
-  initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists.
+  initial_load_path: hf://${model_name} # The path to load the initial checkpoint from. Ignored if `folder` exists.
   initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo
   last_save_in_hf: true
   interval: 500
@@ -56,6 +56,12 @@ activation_checkpoint:
   mode: selective
   selective_ac_option: op
 
+metric_logging:
+  wandb:
+    project: sft-training
+    group: sft_exp_${oc.env:USER}
+  logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
+
 # profiling:
 #   enable_profiling: false
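
Judging by the option names, global_reduce aggregates metrics across ranks before logging, while the per_rank_* modes log per rank with or without reduction; the commit itself does not spell the semantics out. The new metric_logging block is consumed in the controller process. A minimal sketch mirroring run() in the apps/sft/main.py changes below (init_metric_logging and the plain-dict cfg are illustrative, not part of the commit):

# Sketch: how the new `metric_logging` block is consumed. Mirrors run() in
# apps/sft/main.py below; the plain-dict cfg here is an assumption for brevity.
from forge.observability import get_or_create_metric_logger


async def init_metric_logging(cfg: dict) -> None:
    # One sub-dict per backend (here: wandb), plus a `logging_mode` that
    # selects how per-rank values are handled before reaching the backend.
    metric_logging_cfg = cfg.get("metric_logging", {})
    mlogger = await get_or_create_metric_logger(process_name="Controller")
    await mlogger.init_backends.call_one(metric_logging_cfg)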

apps/sft/main.py

Lines changed: 38 additions & 4 deletions
@@ -27,6 +27,7 @@
 from forge.data.datasets.packed import PackedDataset, TextPacker
 from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
 from forge.data.tokenizer import HuggingFaceModelTokenizer
+from forge.observability import get_or_create_metric_logger, record_metric, Reduce
 from forge.util.config import parse
 
 from monarch.actor import current_rank, current_size, endpoint
@@ -77,7 +78,6 @@ def __init__(self, config: DictConfig):
 
         self.current_step = 0
         self.num_training_steps = job_config.training.steps
-        self.metric_logger = None  # TODO: fix this
         self.gradient_accumulation_steps = 1  # Example value, adjust as needed
         self._rank = current_rank().rank
         self._size = math.prod(current_size().values())
@@ -109,9 +109,22 @@ def _init_dist(self):
         os.environ.update(env)
         logger.info("env: {}".format(env))
 
+    async def setup_metric_logger(self):
+        """Initialization happens in the main process. Here we just retrieve it."""
+        mlogger = await get_or_create_metric_logger()
+        return mlogger
+
+    def record_batch_metrics(self, data_metrics: list):
+        """Since the dataloader creates new processes, we don't call `record_metric` in the dataset.
+        Instead, pop the metrics from the batch and record them here."""
+        for metric in data_metrics:
+            record_metric(metric.key, metric.value, metric.reduction)
+
     @endpoint
     async def setup(self):
         self.train_dataloader = self.setup_data()
+        self.mlogger = await self.setup_metric_logger()
+
         # self.train_dataloader = self.setup_data(
         #     self.train_config.train_dataset_config,
         #     self.train_config.train_dataloader_config,
@@ -234,7 +247,9 @@ def train_step(self, batch) -> None:
         # ) as grad_acc:
         labels = batch.pop("labels")
         loss = self.forward_backward(batch, labels)
+        loss = loss.item()
 
+        record_metric("ForgeSFTRecipe/train_step/loss", loss, Reduce.MEAN)
         logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}")
         # self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
         # self.pbar.update(1)
@@ -251,14 +266,25 @@ async def train(self) -> None:
 
         while self.current_step < self.num_training_steps:
             batch = next(dataloader)
+
+            # Pop and record metrics from batch before moving to device
+            self.record_batch_metrics(batch.pop("metrics", []))
+            record_metric("ForgeSFTRecipe/train/step", self.current_step, Reduce.MEAN)
+
             # Move tensors to the appropriate device
             for k, v in batch.items():
                 if isinstance(v, torch.Tensor):
                     batch[k] = v.to("cuda")  # TODO: hardcoded for now
+
             self.train_step(batch)
             # self.profiler.step()
             self.current_step += 1
 
+            # Flush metrics
+            if self._rank == 0:
+                logger.debug(f"Flushing metrics at step {self.current_step}")
+                await self.mlogger.flush.call_one(global_step=self.current_step)
+
             self.checkpointer.save(
                 curr_step=self.current_step,
                 last_step=self.current_step == self.num_training_steps,
@@ -270,16 +296,23 @@ async def train(self) -> None:
     async def cleanup(self) -> None:
         if self.checkpointer:
             self.checkpointer.close()
-        if self.metric_logger:
-            self.metric_logger.close()
+        if getattr(self, "mlogger", None):
+            await self.mlogger.shutdown.call_one()
 
     def __repr__(self) -> str:
         return "Trainer"
 
 
 async def run(cfg: DictConfig) -> None:
-    logging.info("Spawing recipe...")
+
+    logging.info("Spawning recipe...")
     process_cfg = cfg.pop("processes")
+
+    # Initialize metric logger in main process
+    metric_logging_cfg = cfg.get("metric_logging", {})
+    mlogger = await get_or_create_metric_logger(process_name="Controller")
+    await mlogger.init_backends.call_one(metric_logging_cfg)
+
     recipe = await ForgeSFTRecipe.options(**process_cfg).as_actor(cfg)
 
     logging.info("Created recipe, running setup.")
@@ -290,6 +323,7 @@ async def run(cfg: DictConfig) -> None:
 
     logging.info("Done training. Clean up")
     await recipe.cleanup.call()
+
     await recipe.mesh.stop()
     logging.info("All done!")
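
Two things stand out in this diff: recording is decoupled from flushing (record_metric buffers a value together with a reduction rule, and rank 0 pushes the buffered values to the backends once per step), and dataset metrics travel inside the batch because dataloader workers run in separate processes and cannot call record_metric themselves. A condensed sketch of that loop, with mlogger, dataloader, and rank standing in for the recipe's attributes:

# Condensed sketch of the record-then-flush pattern from ForgeSFTRecipe.train().
# `mlogger`, `dataloader`, and `rank` stand in for the recipe's attributes.
from forge.observability import record_metric, Reduce


async def train_loop(mlogger, dataloader, rank: int, num_steps: int) -> None:
    step = 0
    while step < num_steps:
        batch = next(dataloader)

        # Dataset metrics ride along in the batch; record them on this side
        # of the process boundary, as record_batch_metrics does above.
        for metric in batch.pop("metrics", []):
            record_metric(metric.key, metric.value, metric.reduction)
        record_metric("ForgeSFTRecipe/train/step", step, Reduce.MEAN)

        # ... move tensors to device, forward/backward, record the loss ...

        step += 1
        if rank == 0:
            # Push everything buffered since the last flush to the backends.
            await mlogger.flush.call_one(global_step=step)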

apps/sft/qwen3_8b.yaml

Lines changed: 7 additions & 1 deletion
@@ -45,7 +45,7 @@ parallelism:
 checkpoint:
   enable: true
   folder: ./checkpoint # The folder to save checkpoints to.
-  initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists.
+  initial_load_path: hf://${model_name} # The path to load the initial checkpoint from. Ignored if `folder` exists.
   initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo
   last_save_in_hf: true
   interval: 500
@@ -55,6 +55,12 @@ activation_checkpoint:
   mode: selective
   selective_ac_option: op
 
+metric_logging:
+  wandb:
+    project: sft-training
+    group: sft_exp_${oc.env:USER}
+  logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
+
 # profiling:
 #   enable_profiling: false
6066

src/forge/data/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -4,6 +4,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 from .collate import collate_packed
+from .metric_transform import DefaultDatasetMetricTransform, MetricTransform
 from .utils import CROSS_ENTROPY_IGNORE_IDX
 
-__all__ = ["collate_packed", "CROSS_ENTROPY_IGNORE_IDX"]
+__all__ = [
+    "collate_packed",
+    "CROSS_ENTROPY_IGNORE_IDX",
+    "MetricTransform",
+    "DefaultDatasetMetricTransform",
+]
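
These two symbols come from the new forge.data.metric_transform module and replace the deleted dataset_metrics package shown next. The commit does not show the transform's interface, so the sketch below is hypothetical: the no-arg constructor and the callable-with-sample-dict shape are assumptions, inferred from record_batch_metrics above, which pops a metrics list of objects carrying key, value, and reduction:

# Hypothetical sketch; constructor signature and sample layout are assumptions.
from forge.data import DefaultDatasetMetricTransform

transform = DefaultDatasetMetricTransform()  # assumed no-arg construction

# Assumed to attach per-sample metric objects (each with .key, .value,
# .reduction) under sample["metrics"]; ForgeSFTRecipe.record_batch_metrics
# later pops that list and forwards each entry to record_metric().
sample = {"tokens": [1, 2, 3]}
sample = transform(sample)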

src/forge/data/dataset_metrics/__init__.py

Lines changed: 0 additions & 39 deletions
This file was deleted.
