diff --git a/recipes/dev/grpo_full_finetune_distributed.py b/recipes/dev/grpo_full_finetune_distributed.py index 145f5cd661..f75f3de073 100644 --- a/recipes/dev/grpo_full_finetune_distributed.py +++ b/recipes/dev/grpo_full_finetune_distributed.py @@ -21,7 +21,7 @@ from torchtune.config._utils import _get_component_from_path from torchtune.datasets import ConcatDataset from torchtune.dev.rl.generation import generate -from torchtune.dev.rl.rewards import batched_rewards +from torchtune.dev.grpo.rewards import batch_shaped_correctness_reward from torchtune.dev.rl.types import GRPOStats, GRPOTrajectory from torchtune.modules import local_kv_cache from torchtune.recipe_interfaces import FTRecipeInterface @@ -646,7 +646,7 @@ def generate_trajectory( # Do some reward modelingggggggg # responses :: [B x G, L] responses = responses.reshape(batch_size, grpo_size, -1) # [B, G, L] - rewards, successes = batched_rewards(self._tokenizer, responses, answers) + rewards, successes = batch_shaped_correctness_reward(self._tokenizer, responses, answers) rewards = rewards.to(self._device) # [B, G] successes = successes.to(self._device) # [B, G] diff --git a/recipes/dev/grpo_full_finetune_single_device.py b/recipes/dev/grpo_full_finetune_single_device.py new file mode 100644 index 0000000000..06391eabfc --- /dev/null +++ b/recipes/dev/grpo_full_finetune_single_device.py @@ -0,0 +1,920 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import sys +import time +from functools import partial +from typing import Any, Optional, Union +from warnings import warn + +import torch +from omegaconf import DictConfig, ListConfig +from torch import nn +from torch.optim import Optimizer +from torchdata.stateful_dataloader import StatefulDataLoader +from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler +from torchtune import config, generation, modules, rlhf, training, utils +from torchtune.config._utils import _get_component_from_path +from torchtune.datasets import ConcatDataset +from torchtune.dev.rl.generation import generate +from torchtune.dev.grpo.rewards import batch_shaped_correctness_reward +from torchtune.dev.rl.types import GRPOStats, GRPOTrajectory +from torchtune.modules import local_kv_cache +from torchtune.recipe_interfaces import FTRecipeInterface +from torchtune.training import disable_dropout, DummyProfiler, PROFILER_KEY +from torchtune.training.lr_schedulers import get_lr +from tqdm import tqdm + +log = utils.get_logger("DEBUG") + + +class GRPOFullFinetuneRecipeSingleDevice(FTRecipeInterface): + """ + Full finetuning recipe for RLHF with GRPO for dense transformer-based LLMs such as Llama2. + This recipe is optimized for single GPU training. Training on CPU is not supported. + + Features: + - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing`` + flag. Activation checkpointing helps reduce the memory footprint since we no longer keep + activations in memory and instead recompute them during the backward pass. This is especially + helpful for larger batch sizes when you're memory constrained. But these savings in memory + come at the cost of training performance. In most cases training can slow-down quite a bit as + a result of this activation recomputation. + + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` + flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In + most cases this should halve the memory footprint of full precision (fp32) training, without + loss in model quality (will depend on the model, training data and other settings). For + GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 + precision are currently not supported. + + - Adjusting batch sizes when memory constrained. This recipe uses different batch sizes: + - ``batch_size`` controls the total number of samples which are sampled from the dataset for a single trajectory. + - ``forward_batch_size`` controls the mini-batch size for trajectory generation. Since gradients are disabled + during trajectory generation, memory consumption is lower and this can be higher than the main batch size. + + - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is + controlled using the ``gradient_accumulation_steps`` flag. + + - Optimizer in Backward. Fusing the optimizer step into the backward pass helps reduce the memory + footprint associated with gradients. This can be especially helpful when you are memory + constrained. Note that users can only use ONE of gradient accumulation or optimizer in backward. + These features currently do not work together. + + - Lower precision optimizers. This recipe supports lower-precision optimizers from the bitsandbytes + library. We've tested the recipe with 8-bit AdamW and Paged AdamW. These optimizers are especially + helpful when you are memory constrained since they help reduce the memory footprint associated + with the optimizer states. + + - Checkpointing. Model weights are checkpointed both at the end of each epoch, and at the end of + training. Optimizer State and recipe state (seed, total_epochs, number of epochs run etc) are + only saved at the end of a given epoch and used in case of resuming training. + + - Logging. Terminal, Disk, WandB and TensorBoard are all supported. + + Args: + cfg (DictConfig): OmegaConf object parsed from yaml file + + Raises: + RuntimeError: If ``dtype`` is set to fp16. + """ + + def __init__(self, cfg: DictConfig) -> None: + self._device = utils.get_device(device=cfg.device) + self._dtype = training.get_dtype(cfg.dtype, device=self._device) + self._output_dir = cfg.output_dir + + # Disable for fp16, as we haven't validated "full" fp16 with this recipe, nor + # enabled necessary features such as gradient scaling. + if self._dtype == torch.float16: + raise RuntimeError( + "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." + ) + + # Logging attributes + self._log_every_n_steps = cfg.get("log_every_n_steps", 1) + self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) + if self._log_peak_memory_stats and self._device.type != "cuda": + log.info( + "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False." + ) + self._log_peak_memory_stats = False + + # Training attributes + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._clip_grad_norm = cfg.get("clip_grad_norm", None) + self._enable_activation_checkpointing = cfg.get( + "enable_activation_checkpointing", False + ) + self._compile = cfg.get("compile", False) + + # Recipe state attributes + self.seed = training.set_seed(seed=cfg.seed) + self.total_epochs = cfg.epochs + self.global_step = 0 + self._steps_run = 0 + self._total_steps = 0 + self._epochs_run = 0 + self._rng = torch.Generator(self._device).manual_seed(self.seed) + + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> dict[str, Any]: + """ + Extract the checkpoint state from file and validate. If resume_from_checkpoint + is True, this also includes the recipe state. + """ + self._checkpointer = config.instantiate( + cfg_checkpointer, + resume_from_checkpoint=self._resume_from_checkpoint, + ) + checkpoint_dict = self._checkpointer.load_checkpoint() + return checkpoint_dict + + def _update_recipe_state(self, ckpt_dict: dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. + """ + try: + self._epochs_run = ckpt_dict[training.EPOCHS_KEY] + self._rng.set_state(ckpt_dict[training.RNG_KEY]) + + # on mismatch, warn the user and prevent the override + if self.seed != ckpt_dict[training.SEED_KEY]: + warn( + message=( + "Config value for seed does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}" + ) + ) + self.seed = ckpt_dict[training.SEED_KEY] + + # on mismatch, warn the user but allow the override + if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: + warn( + message=( + "Config value for total_epochs does not match the checkpoint value, " + f"using the config value: {self.total_epochs}" + ) + ) + + except KeyError as e: + raise KeyError( + "Checkpoint does not contain the required keys needed for updating recipe state. " + "Are you sure you passed in the right recipe checkpoint?" + ) from e + + def setup(self, cfg: DictConfig) -> None: + """ + Setup the recipe. This includes training state (if resume_from_checkpoint is True), + model, tokenizer, loss, optimizer, lr scheduler, sampler, and dataloader. + """ + self._metric_logger = config.instantiate(cfg.metric_logger) + self._metric_logger.log_config(cfg) + + # Setup model to train + checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + if self._resume_from_checkpoint: + self._update_recipe_state(checkpoint_dict) + self._model = self._setup_model( + cfg_model=cfg.model, + enable_activation_checkpointing=self._enable_activation_checkpointing, + model_sd=checkpoint_dict[training.MODEL_KEY], + ) + + # Setup reference model + ref_checkpoint_dict = self.load_checkpoint( + cfg_checkpointer=cfg.ref_checkpointer + ) + self._ref_model = self._setup_model( + cfg_model=cfg.model, + enable_activation_checkpointing=self._enable_activation_checkpointing, + model_sd=ref_checkpoint_dict[training.MODEL_KEY], + eval_mode=True, + ) + + # Utilize the same tokenizer for both models + self._tokenizer = config.instantiate(cfg.tokenizer) + + self._optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, + opt_state_dict=( + checkpoint_dict[training.OPT_KEY] + if self._resume_from_checkpoint + else None + ), + ) + + # initialize loss + self._loss_fn = config.instantiate(cfg.loss) + if self._compile: + training.compile_loss(self._loss_fn, verbose=True) + + # sampler and dataloader depend on the tokenizer and loss_fn and should be + # setup after both of these are initialized + collate_name = cfg.get( + "collate_fn", "torchtune.dev.grpo.data.padded_collate_rl" + ) + self._dataloader = self._setup_data( + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, + collate_fn=collate_name, + dataloader_state_dict=( + checkpoint_dict[training.DATALOADER_KEY] + if self._resume_from_checkpoint + else None + ), + ) + + # Finally update the recipe state which can only be correctly set after all of the + # other components have been initialized and updated. + # + # Number of training steps in each epoch depends on the number of batches produced + # by the dataloader. + # This value is used for logging and tracking + # training state. The computation should happen after the dataloader has been setup + self._steps_per_epoch = len(self._dataloader) + self.global_step = self._epochs_run * self._steps_per_epoch + + # Setup lr scheduler + self._lr_scheduler = self._setup_lr_scheduler( + cfg_lr_scheduler=cfg.get("lr_scheduler", None), + num_training_steps=self.total_epochs * self._steps_per_epoch, + last_epoch=self.global_step - 1, + ) + + # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) + # if cfg is missing profiler key or if `cfg.profiler.enabled = False` + self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + + # RL params + self.grpo_samples = cfg.grpo_samples + self._temperature = cfg.temperature + self._top_k = cfg.top_k + self._max_generated_tokens = cfg.max_generated_tokens + self.batch_size = cfg.batch_size + self._forward_batch_size = cfg.forward_batch_size + + self._ppo_epochs = cfg.ppo_epochs + self._save_every_n_epochs = cfg.save_every_n_epochs + self._total_steps = cfg.num_steps + + if cfg.get("stop_token_ids", False): + stop_token_ids = cfg.stop_token_ids + if self._tokenizer.eos_id not in stop_token_ids: + warn( + f"tokenizer eos_id ({self._tokenizer.eos_id}) is not in stop_token_ids ({stop_token_ids})." + "This may lead to unexpected behaviour." + ) + else: + if not hasattr(self._tokenizer, "stop_tokens"): + warn( + "No stop tokens defined in tokenizer, and no stop_token_ids provided. This may lead to unexpected behaviour." + ) + stop_token_ids = [] + else: + stop_token_ids = self._tokenizer.stop_tokens + self._stop_token_ids = torch.tensor(stop_token_ids, device=self._device) + + def _setup_lr_scheduler( + self, + cfg_lr_scheduler: Optional[DictConfig], + num_training_steps: int, + last_epoch: int, + ) -> Optional[Optimizer]: + """ + Set up the learning rate scheduler based on the provided configuration. + + Args: + cfg_lr_scheduler (Optional[DictConfig]): The learning rate scheduler configuration. + num_training_steps (int): The total number of training steps. + last_epoch (int): The index of the last epoch. + + Returns: + lr_scheduler (Optional[Optimizer]): The learning rate scheduler. + """ + if cfg_lr_scheduler is None: + log.info( + "No learning rate scheduler configured. Using constant learning rate." + ) + return None + + optimizer = self._optimizer + + # Instantiate the learning rate scheduler + lr_scheduler = config.instantiate( + cfg_lr_scheduler, + optimizer, + num_training_steps=num_training_steps, + last_epoch=last_epoch, + ) + + log.info("Learning rate scheduler is initialized.") + + return lr_scheduler + + def _setup_profiler( + self, cfg_profiler: Optional[DictConfig] = None + ) -> Union[torch.profiler.profile, DummyProfiler]: + """ + Parses the `profiler` section of top-level `cfg` and sets up profiler + """ + # Missing profiler section in config, assume disabled + if cfg_profiler is None: + cfg_profiler = DictConfig({"enabled": False}) + + # Check that component is included and set correctly + if cfg_profiler.get("_component_", None) is None: + cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" + else: + assert ( + cfg_profiler.get("_component_") + == "torchtune.training.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" + + profiler, profiler_cfg = config.instantiate(cfg_profiler) + + log.info(f"Profiler config after instantiation: {profiler_cfg}") + self.profiler_profile_memory = profiler_cfg.get("profile_memory", False) + if profiler_cfg["enabled"]: + self.profiler_wait_steps = profiler_cfg["wait_steps"] + self.profiler_warmup_steps = profiler_cfg["warmup_steps"] + self.profiler_active_steps = profiler_cfg["active_steps"] + self.profiler_num_cycles = profiler_cfg["num_cycles"] + + return profiler + + def _setup_model( + self, + cfg_model: DictConfig, + enable_activation_checkpointing: bool, + model_sd: dict[str, Any], + eval_mode: bool = False, + ) -> nn.Module: + """ + Model initialization for single device training. + """ + log.info("Instantiating model and loading checkpoint...") + init_start = time.perf_counter() + + with training.set_default_dtype(self._dtype), torch.device("meta"): + model = config.instantiate(cfg_model) + + if eval_mode: + model.eval() + for p in model.parameters(): + p.requires_grad = False + + if self._compile: + training.compile_model(model, verbose=True) + + if enable_activation_checkpointing: + training.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} + ) + + with training.set_default_dtype(self._dtype), self._device: + for m in model.modules(): + # RoPE is not covered in state dict + if hasattr(m, "rope_init"): + m.rope_init() + + # Load the model state dict + model.load_state_dict(model_sd, strict=True) + + # Ensure no params and buffers are on meta device + training.validate_no_params_on_meta_device(model) + log.info( + f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs" + ) + + memory_stats = training.get_memory_stats(device=self._device) + training.log_memory_stats(memory_stats) + + disable_dropout(model) + + return model + + def _setup_optimizer( + self, + cfg_optimizer: DictConfig, + opt_state_dict: Optional[dict[str, Any]] = None, + ) -> Optional[Optimizer]: + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + if opt_state_dict: + optimizer.load_state_dict(opt_state_dict) + log.info("Optimizer is initialized.") + return optimizer + + def _setup_data( + self, + cfg_dataset: DictConfig, + shuffle: bool, + batch_size: int, + collate_fn: str, + dataloader_state_dict: Optional[dict[str, Any]] = None, + ) -> StatefulDataLoader: + """ + All data related setup happens here. Currently this recipe only supports the + StatefulDistributedSampler with Map-style Datasets which fit into memory. + """ + + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + else: + ds = config.instantiate(cfg_dataset, self._tokenizer) + + # Instantiate collate_fn + collate_fn = _get_component_from_path(collate_fn) + + sampler = StatefulDistributedSampler( + ds, + num_replicas=1, # Single device + rank=0, # Single device + shuffle=shuffle, + seed=self.seed, + ) + dataloader = StatefulDataLoader( + dataset=ds, + batch_size=batch_size, + sampler=sampler, + collate_fn=( + partial( + collate_fn, + padding_idx=self._tokenizer.pad_id, + ) + ), + # dropping last avoids shape issues with compile + flex attention + drop_last=True, + ) + if dataloader_state_dict is not None: + dataloader.load_state_dict(dataloader_state_dict) + # B/c we currently only save at epoch boundaries, if we cut the previous epoch short + # we need to force the dataloader to finish the last iteration before it's actually used + list(dataloader) + return dataloader + + def save_checkpoint( + self, + epoch: int, + ) -> None: + """ + Checkpoint the state of the recipe. The constructed checkpoint state dict + contains the following information: + - Model weights with key training.MODEL_KEY + - Relevant recipe state if training is not complete + + Checkpointer will save the model weights and recipe state in + different checkpoint files. To correctly resume training from an intermediate checkpoint, + the model weights and recipe state must be provided. + """ + # final dict passed onto the checkpointer + checkpoint_dict = {} + + intermediate_checkpoint = epoch + 1 < self.total_epochs + + log.info( + "Saving checkpoint. This may take some time. Retrieving full model state dict..." + ) + start = time.perf_counter() + + # Get the full model state dict + cpu_state_dict = training.gather_cpu_state_dict( + self._model, + True, # is_rank_zero + device=self._device, + ) + + log.info( + f"Getting full model state dict took {time.perf_counter() - start:.2f} secs" + ) + + if intermediate_checkpoint: + start = time.perf_counter() + log.info("Getting optimizer state dict...") + opt_state_dict = training.get_full_optimizer_state_dict( + self._model, + self._optimizer, + True, # is_rank_zero + device=self._device, + ) + log.info( + f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs" + ) + else: + opt_state_dict = None + + # Now that we have the model and opt state dict, create the actual checkpoint dict + # to be sent to the checkpointer and ultimately written to file + + start = time.perf_counter() + checkpoint_dict.update({training.MODEL_KEY: cpu_state_dict}) + + # if training is in-progress, checkpoint the optimizer state and recipe state + # as well. + if intermediate_checkpoint: + checkpoint_dict.update( + { + training.OPT_KEY: opt_state_dict, + training.SEED_KEY: self.seed, + training.EPOCHS_KEY: self._epochs_run, + training.TOTAL_EPOCHS_KEY: self.total_epochs, + training.RNG_KEY: self._rng.get_state(), + training.DATALOADER_KEY: self._dataloader.state_dict(), + } + ) + + self._checkpointer.save_checkpoint( + checkpoint_dict, + epoch=epoch, + intermediate_checkpoint=intermediate_checkpoint, + ) + log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs") + + def generate_trajectory( + self, input_ids: torch.Tensor, answers: list[str] + ) -> GRPOTrajectory: + """ + Generates a trajectory given the current policy model, the reference policy model, the reward function, + and batch of inputs. This is done over the following steps: + + 1: Generate responses, and logits corresponding to the responses using the current policy, + generating (query, response) pairs. + 2. Estimate logprobs of the generated responses using the current policy. + 3. Compute rewards and successes for the generated responses. + 4. Estimate advantages using GRPO. + 5. Replace any tokens in the response after the first stop token (usually EOS token) with padding, + producing truncated responses. + + Args: + input_ids (torch.Tensor): tensor of input token IDs with shape [b, seq_length] + answers (list[str]): list of answers corresponding to the input_ids + + Returns: + Trajectory: An instance of :class:`~torchtune.rlhf.GRPOTrajectory` comprising + the current trajectory. + """ + batch_size, context_length = input_ids.shape + grpo_size = self.grpo_samples + + batch_input_ids = input_ids[:, None, :].expand(-1, grpo_size, -1) # [B, G, L] + batch_input_ids = batch_input_ids.reshape(batch_size * grpo_size, -1) + + # step 1: generate responses, and logits corresponding to the responses using the current policy + with local_kv_cache( + model=self._model, + batch_size=batch_size * grpo_size, + device=self._device, + dtype=self._dtype, + decoder_max_seq_len=context_length + self._max_generated_tokens, + ): + query_responses, _ = generate( # [B x G, L], [B x G, L, V] + model=self._model, + prompt=batch_input_ids, + max_generated_tokens=self._max_generated_tokens, + temperature=self._temperature, + top_k=self._top_k, + pad_id=self._tokenizer.pad_id, + rng=self._rng, + stop_tokens=self._tokenizer.stop_tokens, + return_logits=False, + ) + + responses = query_responses[:, context_length:].clone() + query_response_padding_masks = query_responses != self._tokenizer.pad_id + + # step 1.1 create attention masks and position IDs for any padding tokens in inputs, used for future forward passes + masks = generation.get_causal_mask_from_padding_mask( + query_response_padding_masks + ) + position_ids = generation.get_position_ids_from_padding_mask( + query_response_padding_masks + ) + del query_response_padding_masks + + # step 2. estimate logprobs of the responses using the current policy + logits = self._model(query_responses, input_pos=position_ids, mask=masks) + logits = logits[:, context_length - 1 :] + logprobs = rlhf.batched_logits_to_logprobs(logits, responses, self._temperature) + del logits + torch.cuda.empty_cache() + + # step 2.1 estimate logprobs of the responses using the reference policy + ref_logits = self._ref_model( + query_responses, input_pos=position_ids, mask=masks + ) + ref_logits = rlhf.truncate_sequence_for_logprobs(ref_logits, context_length) + ref_logprobs = rlhf.batched_logits_to_logprobs( + ref_logits, responses, self._temperature + ) + del ref_logits + torch.cuda.empty_cache() + + # step 4. replace any tokens in the responses after the first stop token (usually EOS token) with padding + # resulting in truncated responses + ( + response_padding_masks, + responses, + ) = rlhf.truncate_sequence_at_first_stop_token( # [B x G, L] + responses, self._stop_token_ids, self._tokenizer.pad_id + ) + + # Do some reward modeling + # responses :: [B x G, L] + responses = responses.reshape(batch_size, grpo_size, -1) # [B, G, L] + rewards, successes = batch_shaped_correctness_reward(self._tokenizer, responses, answers) + rewards = rewards.to(self._device) # [B, G] + successes = successes.to(self._device) # [B, G] + + advantages = (rewards - rewards.mean(1, keepdim=True)) / ( + rewards.std(1, keepdim=True) + 1e-4 + ) + advantages = advantages.reshape(batch_size * grpo_size) # flatten + del responses + torch.cuda.empty_cache() + + # step 6. mask out all the invalid values in the trajectory due to padding tokens + logprobs[response_padding_masks] = 1.0 + ref_logprobs[response_padding_masks] = 1.0 + + return GRPOTrajectory( + query_responses=query_responses, + logprobs=logprobs, + ref_logprobs=ref_logprobs, + rewards=rewards.reshape(batch_size * grpo_size), + successes=successes.reshape(batch_size * grpo_size), + advantages=advantages, + masks=masks, + position_ids=position_ids, + response_padding_masks=response_padding_masks, + seq_lens=training.get_unmasked_sequence_lengths(response_padding_masks), + ) + + def generate_trajectory_batched( + self, input_ids: torch.Tensor, answers: list[str] + ) -> GRPOTrajectory: + """ + Generates a ``self.batch_size`` batch of trajectories using `self._forward_batch_size` batch sizes. + See ``generate_trajectory`` for more details. + + Args: + input_ids (torch.Tensor): tensor of input token IDs with shape [b, seq_length] + answers: (list[str]): list of answers corresponding to the input_ids + + Returns: + Trajectory: An instance of :class:`~torchtune.rlhf.Trajectory`, comprising + the current trajectory. + """ + trajectories: list[GRPOTrajectory] = [] + with torch.no_grad(): + for batch_start in range(0, self.batch_size, self._forward_batch_size): + batch_input_ids = input_ids[ + batch_start : batch_start + self._forward_batch_size + ] + batch_answers = answers[ + batch_start : batch_start + self._forward_batch_size + ] + torch.cuda.empty_cache() + trajectories.append( + self.generate_trajectory(batch_input_ids, batch_answers) + ) + torch.cuda.empty_cache() + return GRPOTrajectory(*map(torch.cat, zip(*trajectories))) + + def grpo_step( + self, + trajectory: GRPOTrajectory, + context_length: int, + ) -> GRPOStats: + """ + Perform a single GRPO optimization step over a batch of trajectories and corresponding advantages and returns. + + Args: + trajectory (Trajectory): a batch of trajectories + context_length (int): input ids sequence length + + Returns: + GRPOStats: An instance of :class:`~torchtune.rlhf.GRPOStats`, a NamedTuple containing: + - loss (torch.Tensor): The total GRPO loss. + - policy_loss (torch.Tensor): The policy loss component. + - kl_loss (torch.Tensor): The KL divergence loss component. + - ratios (torch.Tensor): The ratio between the current and old policy probabilities. + - clipfrac (torch.Tensor): The fraction of ratios that were clipped. + - approx_policy_kls: Average estimated KL divergence between the policy before and after the optimisation step. + + """ + torch.cuda.empty_cache() + + # estimate logprobs from the policy at the current optimisation step + pi_logits = self._model( + trajectory.query_responses, + input_pos=trajectory.position_ids, + mask=trajectory.masks, + ) + + pi_logits = rlhf.truncate_sequence_for_logprobs(pi_logits, context_length) + pi_logprobs = rlhf.batched_logits_to_logprobs( + pi_logits, + trajectory.query_responses[:, context_length:], + self._temperature, + chunk_size=1, + ) + + pi_logprobs[trajectory.response_padding_masks] = 1.0 + + del pi_logits + torch.cuda.empty_cache() + + # calculate grpo loss + loss, policy_loss, kl_loss, ratios, clipfrac = self._loss_fn( + trajectory.logprobs, + pi_logprobs, + trajectory.ref_logprobs, + trajectory.advantages, + padding_masks=~trajectory.response_padding_masks, + ) + + torch.cuda.empty_cache() + loss.backward() + + with torch.no_grad(): + approx_policy_kls = ( + 0.5 * (pi_logprobs - trajectory.logprobs).pow(2) + ).mean() + + return GRPOStats( + loss, + policy_loss, + kl_loss, + ratios, + clipfrac, + approx_policy_kls, + ) + + def train(self) -> None: + """ + The core training loop. + """ + # clean up before training begins + training.cleanup_before_training() + + # zero out the gradients before starting training + self._optimizer.zero_grad() + + # Initialize tokens count and running loss (for grad accumulation) + grad_norm = None + + training_completed = False + self._profiler.start() + # self.epochs_run should be non-zero when we're resuming from a checkpoint + for curr_epoch in range(self._epochs_run, self.total_epochs): + pbar = tqdm(total=self._steps_per_epoch) + self._dataloader.sampler.set_epoch(curr_epoch) + for idx, batch in enumerate(self._dataloader): + # Start tracking CUDA memory for active steps for just the first epoch + if ( + curr_epoch == 0 + and self.profiler_profile_memory + and idx == self.profiler_wait_steps + self.profiler_warmup_steps + and self._device.type == "cuda" + ): + torch.cuda.memory._record_memory_history() + + tokens = batch["tokens"] # type: ignore + answers = batch["answers"] # type: ignore + tokens = tokens.to(self._device) # [B, P] + + _, context_length = tokens.shape + + trajectory = self.generate_trajectory_batched(tokens, answers) + + grpo_stats: list[GRPOStats] = [] + for _ in range(self._ppo_epochs): + step_stats = self.grpo_step(trajectory, context_length) + + grpo_stats.append(step_stats) + + if self._clip_grad_norm is not None: + grad_norm = torch.nn.utils.clip_grad_norm_( + self._model.parameters(), + max_norm=float(self._clip_grad_norm), + ) + + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + + self.global_step += 1 + + if self._lr_scheduler is not None: + self._lr_scheduler.step() + + # Stop tracking CUDA memory now that active steps are complete + if ( + curr_epoch == 0 + and self.profiler_profile_memory + and idx + == self.profiler_wait_steps + + self.profiler_warmup_steps + + self.profiler_active_steps + and self._device.type == "cuda" + ): + torch.cuda.memory._record_memory_history(enabled=None) + + self._steps_run += 1 + if self._steps_run % self._log_every_n_steps == 0: + extra_metrics = {} + extra_metrics["lr"] = get_lr(self._optimizer) + if grad_norm is not None: + extra_metrics["grad_norm"] = grad_norm + + self.log_metrics( + trajectory, + GRPOStats(*map(torch.stack, zip(*grpo_stats))), + **extra_metrics, + ) + + self.cleanup_after_step(trajectory, grpo_stats) + self._profiler.step() + + pbar.update(1) + + if self._steps_run == self._total_steps: + training_completed = True + break + + self._epochs_run += 1 + if self._epochs_run % self._save_every_n_epochs == 0: + self.save_checkpoint(curr_epoch) + if training_completed: + return + + self._profiler.stop() + + def log_metrics( + self, trajectory: GRPOTrajectory, grpo_stats: GRPOStats, **extras + ) -> None: + """ + Log metrics and statistics for the current step to the metric logger. + """ + log_dict = { + "rewards": trajectory.rewards.mean(), + "successes": trajectory.successes.mean(), + "num_stop_tokens": trajectory.response_padding_masks.any(-1).sum(), + "loss": grpo_stats.loss.mean(), + "policy_loss": grpo_stats.policy_loss.mean(), + "kl_loss": grpo_stats.kl_loss.mean(), + "clipfrac": grpo_stats.clipfrac.mean(), + "ratios": grpo_stats.ratios.mean(), + "approx_policy_kl": grpo_stats.approx_policy_kls.mean(), + "response_lengths": trajectory.seq_lens.float().mean(), + **extras, + } + + if self._device.type == "cuda" and self._log_peak_memory_stats: + log_dict.update(training.get_memory_stats(device=self._device)) + + self._metric_logger.log_dict(log_dict, step=self.global_step) + + def cleanup(self) -> None: + self._metric_logger.close() + + def cleanup_after_step( + self, + trajectory: GRPOTrajectory, + l_grpo_stats: list[GRPOStats], + ) -> None: + for v in trajectory: + del v + del trajectory + for g in l_grpo_stats: + for v in g: + del v + del g + del l_grpo_stats + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. + + Configurable parameters are read in the following order: + - Parameters specified in config (see available configs through ``tune ls``) + - Overwritten by arguments from the command-line + """ + + recipe = GRPOFullFinetuneRecipeSingleDevice(cfg=cfg) + config.log_config(recipe_name="GRPOFullFinetuneRecipeSingleDevice", cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index ee3ae60dd2..dcd2dcfe6c 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -31,6 +31,14 @@ class Recipe: ], supports_distributed=True, ), + Recipe( + name="dev/grpo_full_finetune_single_device", + file_path="dev/grpo_full_finetune_single_device.py", + configs=[ + Config(name="dev/3B_full_grpo", file_path="dev/3B_full_grpo.yaml"), + ], + supports_distributed=False, + ), Recipe( name="dev/async_grpo_full_finetune_distributed", file_path="dev/async_grpo_full_finetune_distributed.py", diff --git a/torchtune/dev/rl/rewards.py b/torchtune/dev/rl/rewards.py index 8d1ec1e79f..a7592ee5db 100644 --- a/torchtune/dev/rl/rewards.py +++ b/torchtune/dev/rl/rewards.py @@ -11,6 +11,8 @@ import torch +from torchtune.modules.transforms.tokenizers import ModelTokenizer + @dataclass class RewardOutput: