Merged

Changes from all commits (38 commits)
77488cf  commit (Oct 8, 2025)
feb4771  commit (Oct 8, 2025)
41ceaa4  update backend role typehints and enum (Oct 8, 2025)
8a24e71  update where we check FORGE_DISABLE_METRICS (Oct 8, 2025)
3f3bc51  remove protected import (Oct 8, 2025)
d82c354  Merge branch 'timestamp_logging_diff1' into timestamp_logging_diff2 (Oct 8, 2025)
4fe2611  protect import (Oct 8, 2025)
8759bc8  Merge branch 'timestamp_logging_diff1' into timestamp_logging_diff2 (Oct 8, 2025)
fbb4a9e  Merge branch 'main' of https://github.com/meta-pytorch/forge into tim… (Oct 8, 2025)
d81a4ed  record_metric uses dataclass Metric (Oct 8, 2025)
1e2255d  commit (Oct 8, 2025)
a94c612  Merge branch 'main' of https://github.com/meta-pytorch/forge into tim… (Oct 8, 2025)
5b477e8  commit (Oct 9, 2025)
f2b3eed  commit (Oct 9, 2025)
471b88a  revert (Oct 9, 2025)
1a02784  Merge branch 'timestamp_logging_diff2_5' into timestamp_logging_diff3 (Oct 9, 2025)
fa4895f  remove unnecessary code (Oct 9, 2025)
7bb1fe7  better logging (Oct 9, 2025)
43d5d27  docs/names (Oct 9, 2025)
c97eb98  Merge branch 'timestamp_logging_diff2_5' into timestamp_logging_diff3 (Oct 9, 2025)
75355a2  commit (Oct 9, 2025)
70e9c67  Merge branch 'main' of https://github.com/meta-pytorch/forge into tim… (Oct 9, 2025)
12f77c9  Merge branch 'timestamp_logging_diff3' into timestamp_logging_diff4 (Oct 9, 2025)
1186aec  update cfg back to true (Oct 9, 2025)
a02ea75  Merge branch 'main' of https://github.com/meta-pytorch/forge into tim… (Oct 13, 2025)
aa00898  Merge branch 'timestamp_logging_diff3' into timestamp_logging_diff4 (Oct 13, 2025)
b75aa31  tests pass (Oct 13, 2025)
192d32e  Merge branch 'main' of https://github.com/meta-pytorch/forge into sft… (Oct 13, 2025)
57877da  main works (Oct 13, 2025)
2bd3b35  docs and naming (Oct 14, 2025)
5d68cf6  Merge branch 'main' of https://github.com/meta-pytorch/forge into sft… (Oct 21, 2025)
6ec9733  nits (Oct 22, 2025)
710703e  Merge branch 'main' of https://github.com/meta-pytorch/forge into sft… (Oct 22, 2025)
7838dc4  fix tests (Oct 22, 2025)
766319c  config (Oct 22, 2025)
ac8ad72  delete old tests and update cfg (Oct 22, 2025)
afe7e7a  update cfg (Oct 22, 2025)
661b7b6  loss.item (Oct 22, 2025)
8 changes: 7 additions & 1 deletion apps/sft/llama3_8b.yaml
@@ -46,7 +46,7 @@ parallelism:
 checkpoint:
   enable: true
   folder: ./checkpoint # The folder to save checkpoints to.
-  initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists.
+  initial_load_path: hf://${model_name} # The path to load the initial checkpoint from. Ignored if `folder` exists.
   initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo
   last_save_in_hf: true
   interval: 500
@@ -56,6 +56,12 @@ activation_checkpoint:
   mode: selective
   selective_ac_option: op
 
+metric_logging:
+  wandb:
+    project: sft-training
+    group: sft_exp_${oc.env:USER}
+    logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
+
 # profiling:
 #   enable_profiling: false
 
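For context: the metric_logging block added above is the config that the controller in apps/sft/main.py (diffed below) hands to init_backends. A minimal sketch of that wiring, assuming OmegaConf resolves the ${oc.env:USER} interpolation and that each top-level key under metric_logging names a backend (here wandb) carrying its own settings, including logging_mode:

# Sketch only: load a config shaped like the block above and hand its
# metric_logging section to the logger. get_or_create_metric_logger and
# init_backends come straight from the main.py diff below; the YAML path
# and this helper's name are hypothetical.
from omegaconf import OmegaConf

from forge.observability import get_or_create_metric_logger


async def init_metrics(path: str = "apps/sft/llama3_8b.yaml"):
    cfg = OmegaConf.load(path)
    metric_logging_cfg = cfg.get("metric_logging", {})  # e.g. {"wandb": {...}}
    mlogger = await get_or_create_metric_logger(process_name="Controller")
    await mlogger.init_backends.call_one(metric_logging_cfg)
    return mlogger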
42 changes: 38 additions & 4 deletions apps/sft/main.py
@@ -27,6 +27,7 @@
 from forge.data.datasets.packed import PackedDataset, TextPacker
 from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
 from forge.data.tokenizer import HuggingFaceModelTokenizer
+from forge.observability import get_or_create_metric_logger, record_metric, Reduce
 from forge.util.config import parse
 
 from monarch.actor import current_rank, current_size, endpoint
@@ -77,7 +78,6 @@ def __init__(self, config: DictConfig):
 
         self.current_step = 0
         self.num_training_steps = job_config.training.steps
-        self.metric_logger = None  # TODO: fix this
         self.gradient_accumulation_steps = 1  # Example value, adjust as needed
         self._rank = current_rank().rank
         self._size = math.prod(current_size().values())
@@ -109,9 +109,22 @@ def _init_dist(self):
         os.environ.update(env)
         logger.info("env: {}".format(env))
 
+    async def setup_metric_logger(self):
+        """Initialization happens in the main process. Here we just retrieve it."""
+        mlogger = await get_or_create_metric_logger()
+        return mlogger
+
+    def record_batch_metrics(self, data_metrics: list):
+        """Since the dataloader creates new processes, we don't call `record_metric` in the dataset.
+        Instead, pop the metrics from the batch and record them here."""
+        for metric in data_metrics:
+            record_metric(metric.key, metric.value, metric.reduction)
+
     @endpoint
     async def setup(self):
         self.train_dataloader = self.setup_data()
+        self.mlogger = await self.setup_metric_logger()
+
         # self.train_dataloader = self.setup_data(
         #     self.train_config.train_dataset_config,
         #     self.train_config.train_dataloader_config,
@@ -234,7 +247,9 @@ def train_step(self, batch) -> None:
         #     ) as grad_acc:
         labels = batch.pop("labels")
         loss = self.forward_backward(batch, labels)
+        loss = loss.item()
 
+        record_metric("ForgeSFTRecipe/train_step/loss", loss, Reduce.MEAN)
         logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}")
         # self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
         # self.pbar.update(1)
@@ -251,14 +266,25 @@ async def train(self) -> None:
 
         while self.current_step < self.num_training_steps:
             batch = next(dataloader)
+
+            # Pop and record metrics from batch before moving to device
+            self.record_batch_metrics(batch.pop("metrics", []))
+            record_metric("ForgeSFTRecipe/train/step", self.current_step, Reduce.MEAN)
+
             # Move tensors to the appropriate device
             for k, v in batch.items():
                 if isinstance(v, torch.Tensor):
                     batch[k] = v.to("cuda")  # TODO: hardcoded for now
 
             self.train_step(batch)
             # self.profiler.step()
             self.current_step += 1
+
+            # Flush metrics
+            if self._rank == 0:
+                logger.debug(f"Flushing metrics at step {self.current_step}")
+                await self.mlogger.flush.call_one(global_step=self.current_step)
+
             self.checkpointer.save(
                 curr_step=self.current_step,
                 last_step=self.current_step == self.num_training_steps,
@@ -270,16 +296,23 @@
     async def cleanup(self) -> None:
         if self.checkpointer:
             self.checkpointer.close()
-        if self.metric_logger:
-            self.metric_logger.close()
+        if getattr(self, "mlogger", None):
+            await self.mlogger.shutdown.call_one()
 
     def __repr__(self) -> str:
         return "Trainer"
 
+
 async def run(cfg: DictConfig) -> None:
-    logging.info("Spawing recipe...")
+    logging.info("Spawning recipe...")
 
     process_cfg = cfg.pop("processes")
+
+    # Initialize metric logger in main process
+    metric_logging_cfg = cfg.get("metric_logging", {})
+    mlogger = await get_or_create_metric_logger(process_name="Controller")
+    await mlogger.init_backends.call_one(metric_logging_cfg)
+
     recipe = await ForgeSFTRecipe.options(**process_cfg).as_actor(cfg)
 
     logging.info("Created recipe, running setup.")
@@ -290,6 +323,7 @@
 
     logging.info("Done training. Clean up")
     await recipe.cleanup.call()
+
     await recipe.mesh.stop()
     logging.info("All done!")
 
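Note on the record_batch_metrics helper above: it only assumes that each entry in batch["metrics"] exposes key, value, and reduction attributes (per the "record_metric uses dataclass Metric" commit). A minimal sketch of that carrier pattern, with the dataclass fields inferred from those attribute accesses rather than taken from forge's actual definition:

# Hypothetical stand-in for forge's Metric dataclass; the field names
# mirror the attribute accesses in record_batch_metrics. record_metric
# and Reduce are the real imports from forge.observability shown above.
from dataclasses import dataclass

from forge.observability import record_metric, Reduce


@dataclass
class Metric:
    key: str
    value: float
    reduction: Reduce


# A dataset-side transform attaches metrics to each sample inside the
# dataloader worker process...
sample = {"tokens": [1, 2, 3], "metrics": [Metric("dataset/seq_len", 3.0, Reduce.MEAN)]}

# ...and the trainer pops and records them on its own side of the process
# boundary, which is exactly why record_metric is not called in the dataset:
for metric in sample.pop("metrics", []):
    record_metric(metric.key, metric.value, metric.reduction)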
8 changes: 7 additions & 1 deletion apps/sft/qwen3_8b.yaml
@@ -45,7 +45,7 @@ parallelism:
 checkpoint:
   enable: true
   folder: ./checkpoint # The folder to save checkpoints to.
-  initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists.
+  initial_load_path: hf://${model_name} # The path to load the initial checkpoint from. Ignored if `folder` exists.
   initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo
   last_save_in_hf: true
   interval: 500
@@ -55,6 +55,12 @@ activation_checkpoint:
   mode: selective
   selective_ac_option: op
 
+metric_logging:
+  wandb:
+    project: sft-training
+    group: sft_exp_${oc.env:USER}
+    logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
+
 # profiling:
 #   enable_profiling: false
 
8 changes: 7 additions & 1 deletion src/forge/data/__init__.py
@@ -4,6 +4,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 from .collate import collate_packed
+from .metric_transform import DefaultDatasetMetricTransform, MetricTransform
 from .utils import CROSS_ENTROPY_IGNORE_IDX
 
-__all__ = ["collate_packed", "CROSS_ENTROPY_IGNORE_IDX"]
+__all__ = [
+    "collate_packed",
+    "CROSS_ENTROPY_IGNORE_IDX",
+    "MetricTransform",
+    "DefaultDatasetMetricTransform",
+]
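With the package root updated, the transform types become importable directly from forge.data. A quick smoke test of the new import surface (the subclass relationship is an assumption; the diff only shows the re-exports):

# Assumes an environment with forge installed at this commit.
from forge.data import (
    collate_packed,  # unchanged export
    CROSS_ENTROPY_IGNORE_IDX,  # unchanged export
    DefaultDatasetMetricTransform,
    MetricTransform,
)

# Presumably the default transform implements the MetricTransform
# interface (assumption, not shown in this diff).
assert issubclass(DefaultDatasetMetricTransform, MetricTransform)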
39 changes: 0 additions & 39 deletions src/forge/data/dataset_metrics/__init__.py

This file was deleted.
