@@ -1,15 +1,20 @@
defaults: ../../grpo_math_1B.yaml
grpo:
num_prompts_per_step: 128
num_prompts_per_step: 64
num_generations_per_prompt: 16
policy:
model_name: nvidia/Llama-3_3-Nemotron-Super-49B-v1_5
model_name: /lustre/fsw/portfolios/coreai/users/joyang/models/llama-3_3-nemotron-49b-instruct-128k-v1_2-hf
tokenizer:
name: nvidia/Llama-3_3-Nemotron-Super-49B-v1_5
max_total_sequence_length: 1024
train_global_batch_size: 128
max_total_sequence_length: 24576
#max_total_sequence_length: 1024
train_global_batch_size: 64
train_micro_batch_size: 1
logprob_batch_size: 2
dtensor_cfg:
activation_checkpointing: true
tensor_parallel_size: 8
context_parallel_size: 4
tensor_parallel_size: 2
custom_parallel_plan: examples.configs.recipes.llm.llama_nemotron_super_49b_custom_plan.custom_parallel_plan
dynamic_batching:
enabled: true
@@ -32,16 +37,19 @@ policy:
- 13
generation:
vllm_cfg:
async_engine: false
tensor_parallel_size: 4
#pipeline_parallel_size: 2
make_sequence_length_divisible_by: ${max:${mul:${policy.dtensor_cfg.context_parallel_size}, 2}, ${policy.max_total_sequence_length}}
logger:
wandb_enabled: true
monitor_gpus: false
wandb:
project: grpo-nemotron-super-49b
name: grpo-${data.dataset_name}-nemotron-super-49b-tp${policy.dtensor_cfg.tensor_parallel_size}
name: grpo-${data.dataset_name}-nemotron-super-49b-tp${policy.dtensor_cfg.tensor_parallel_size}-cp${policy.dtensor_cfg.context_parallel_size}
mlflow:
experiment_name: sft-dev
run_name: grpo-nemotron-super-49b
cluster:
gpus_per_node: 8
num_nodes: 4
num_nodes: 8
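
For reference, the `mul`/`max` interpolations used above (for example in `make_sequence_length_divisible_by`) rely on the custom OmegaConf resolvers registered later in this diff in the run scripts. A minimal, self-contained sketch of how such a value resolves, using illustrative keys rather than the exact recipe schema:

```python
# Minimal sketch (assuming only that omegaconf is installed) of how the custom
# "mul"/"max" resolvers behave. The keys here are illustrative, not the exact
# recipe schema.
from omegaconf import OmegaConf

OmegaConf.register_new_resolver("mul", lambda a, b: a * b)
OmegaConf.register_new_resolver("max", lambda a, b: max(a, b))

cfg = OmegaConf.create(
    {
        "context_parallel_size": 4,
        "max_total_sequence_length": 24576,
        # max(context_parallel_size * 2, max_total_sequence_length)
        "divisible_by": "${max:${mul:${context_parallel_size}, 2}, ${max_total_sequence_length}}",
    }
)
print(cfg.divisible_by)  # 24576 with these values
```

With this recipe's values (context_parallel_size: 4, max_total_sequence_length: 24576) the expression resolves to 24576, i.e. the max term dominates.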
@@ -12,38 +12,73 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import cast

from torch.distributed.tensor.parallel import (
ColwiseParallel,
ParallelStyle,
PrepareModuleInput,
PrepareModuleOutput,
RowwiseParallel,
SequenceParallel,
)
from torch.distributed.tensor.placement_types import Replicate, Shard

custom_parallel_plan: dict[str, ParallelStyle] = {
"model.layers.*.self_attn": PrepareModuleInput(
input_kwarg_layouts={"attention_mask": Replicate()},
desired_input_kwarg_layouts={"attention_mask": Replicate()},
),
"model.embed_tokens": RowwiseParallel(
input_layouts=Replicate(), output_layouts=Replicate(), use_local_output=True
),
"model.layers.*.self_attn.q_proj": ColwiseParallel(use_local_output=False),
"model.layers.*.self_attn.k_proj": ColwiseParallel(use_local_output=False),
"model.layers.*.self_attn.v_proj": ColwiseParallel(use_local_output=False),
"model.layers.*.self_attn.o_proj": RowwiseParallel(
output_layouts=Replicate(), use_local_output=True
),
"model.layers.*.self_attn.rotary_emb": PrepareModuleOutput(
output_layouts=(Replicate(), Replicate()),
desired_output_layouts=(Replicate(), Replicate()),
use_local_output=False,
),
"model.layers.*.mlp.up_proj": ColwiseParallel(),
"model.layers.*.mlp.gate_proj": ColwiseParallel(),
"model.layers.*.mlp.down_proj": RowwiseParallel(
output_layouts=Replicate(), use_local_output=True
),
"lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False),
}

def get_custom_parallel_plan():
# Reuse the default Llama tensor-parallel plan as the base
base_model_tp_plan: dict[str, ParallelStyle] = {
"model.embed_tokens": RowwiseParallel(input_layouts=Replicate()),
"model.layers.*.self_attn.q_proj": ColwiseParallel(),
"model.layers.*.self_attn.k_proj": ColwiseParallel(),
"model.layers.*.self_attn.v_proj": ColwiseParallel(),
"model.layers.*.self_attn.o_proj": RowwiseParallel(),
"model.layers.*.mlp.up_proj": ColwiseParallel(),
"model.layers.*.mlp.gate_proj": ColwiseParallel(),
"model.layers.*.mlp.down_proj": RowwiseParallel(),
"lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False),
}

base_model_sp_plan = {
"model.embed_tokens": RowwiseParallel(
input_layouts=Replicate(), output_layouts=Shard(1)
),
"model.norm": SequenceParallel(),
"model.layers.*.input_layernorm": SequenceParallel(),
"model.layers.*.self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)),
"model.layers.*.post_attention_layernorm": SequenceParallel(),
"model.layers.*.mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)),
"lm_head": ColwiseParallel(
input_layouts=Shard(1), output_layouts=Shard(-1), use_local_output=False
),
}

# Sequence parallelism is left disabled here; flip this flag (typically only
# when tensor_parallel_size > 1) to merge the SP plan into the TP plan.
enable_sequence_parallel = False
if enable_sequence_parallel:
base_model_tp_plan.update(cast(dict[str, ParallelStyle], base_model_sp_plan))

return base_model_tp_plan


custom_parallel_plan: dict[str, ParallelStyle] = get_custom_parallel_plan()
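
The recipe YAMLs point at this plan with a dotted path (`examples.configs.recipes.llm.llama_nemotron_super_49b_custom_plan.custom_parallel_plan`). A hedged sketch of how such a path can be resolved to the plan dict and applied; the actual loader inside nemo_rl may differ, and `model`/`device_mesh` are assumed to already exist:

```python
# Hedged sketch of resolving a dotted config path to the plan dict and applying
# it with torch's parallelize_module. Not necessarily how nemo_rl loads it.
import importlib

from torch.distributed.tensor.parallel import parallelize_module


def load_parallel_plan(dotted_path: str):
    """Import "pkg.module.attr" and return the attribute."""
    module_path, attr_name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), attr_name)


# plan = load_parallel_plan(
#     "examples.configs.recipes.llm.llama_nemotron_super_49b_custom_plan.custom_parallel_plan"
# )
# model = parallelize_module(model, device_mesh["tp"], plan)
```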
134 changes: 134 additions & 0 deletions examples/configs/sft_nemotron_super_49b.yaml
@@ -0,0 +1,134 @@
# SFT Algorithm Configuration
sft:
max_num_epochs: 3
max_num_steps: 100
val_period: 10
val_batches: 8
val_global_batch_size: 128
val_micro_batch_size: 1
val_at_start: true
seed: 42

checkpointing:
enabled: true
checkpoint_dir: "results/sft_nemotron_super_49b"
metric_name: "val_loss"
higher_is_better: false
keep_top_k: 100
save_period: 500
checkpoint_must_save_by: null

policy:
# model_name: Qwen/Qwen2.5-7B-Instruct
# tokenizer:
# name: Qwen/Qwen2.5-7B-Instruct
model_name: "/lustre/fsw/portfolios/coreai/users/joyang/models/llama-3_3-nemotron-49b-instruct-128k-v1_2-hf"
tokenizer:
name: ${policy.model_name}
max_total_sequence_length: 4096
precision: "bfloat16"
train_global_batch_size: 128
train_micro_batch_size: 8

dtensor_cfg:
_v2: true
activation_checkpointing: true
context_parallel_size: 2
cpu_offload: false
enabled: true
sequence_parallel: false
tensor_parallel_size: 4
custom_parallel_plan: examples.configs.recipes.llm.llama_nemotron_super_49b_custom_plan.custom_parallel_plan

megatron_cfg:
enabled: false

dynamic_batching:
enabled: false
train_mb_tokens: 4096
logprob_mb_tokens: 8192
sequence_length_round: 64

sequence_packing:
enabled: false
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
algorithm: "modified_first_fit_decreasing"
sequence_length_round: 64


# makes the training sequence length divisible by the tensor parallel size
# this is useful for sequence parallel training
make_sequence_length_divisible_by: ${max:${mul:${policy.dtensor_cfg.context_parallel_size}, 2}, ${policy.max_total_sequence_length}}
max_grad_norm: null

optimizer:
name: "torch.optim.AdamW"
kwargs:
lr: 2e-5
weight_decay: 0.01
betas: [0.9, 0.98]
eps: 1e-8
# when using Dtensor, we need to set foreach
# and fused to False
foreach: False
fused: False

# data:
# add_bos: true
# add_eos: true
# add_generation_prompt: false
# dataset_name: "tulu3_sft_mixture"
# cache_dir: "/lustre/fsw/portfolios/coreai/users/gvenkatakris/data-cache"
# max_input_seq_length: 1024
# max_samples: 10000
# shuffle: true
# test_size: 0.05

data:
max_input_seq_length: ${policy.max_total_sequence_length}
add_bos: true
add_eos: true
add_generation_prompt: false
shuffle: true
num_workers: 20

dataset_name: "squad"
# You can use custom response datasets for training and validation. For example:
# data:
# dataset_name: ResponseDataset
# train_data_path: <PathToTrainingDataset> # e.g., /path/to/local/dataset.jsonl or hf_org/hf_dataset_name (HuggingFace)
# val_data_path: <PathToValidationDataset>
# input_key: <QuestionKey>, default is "input"
# output_key: <AnswerKey>, default is "output"
# train_split: <TrainSplit>, default is None # used for HuggingFace datasets
# val_split: <ValSplit>, default is None # used for HuggingFace datasets
# See https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/sft.md#datasets for more details.

## unused with squad dataset
prompt_file: null
split: null
output_key: null
seed: null

logger:
log_dir: "logs" # Base directory for all logs
wandb_enabled: true # Make sure you run `wandb login [Your API key]` before running
tensorboard_enabled: false
mlflow_enabled: false
monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard
swanlab_enabled: false
wandb:
project: "sft-nemotron-joyang"
name: "sft-${data.dataset_name}-nemotron-super-49b-joyang"
tensorboard:
log_dir: "tb_logs-openmathinstruct-nemorl-1M_train"
mlflow:
experiment_name: "sft-dev"
run_name: "openmathinstruct-nemorl-1M_train"
gpu_monitoring:
collection_interval: 10 # How often to collect GPU usage metrics (in seconds)
flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds)

cluster:
gpus_per_node: 8
num_nodes: 1
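
As the optimizer comment in this config notes, the DTensor path keeps `foreach` and `fused` disabled. A minimal sketch of constructing the optimizer from these settings (the helper name here is illustrative):

```python
# Sketch of building AdamW with the kwargs from policy.optimizer above.
# foreach/fused stay False for DTensor parameters, as noted in the config.
import torch
from torch import nn


def build_optimizer(model: nn.Module) -> torch.optim.AdamW:
    return torch.optim.AdamW(
        model.parameters(),
        lr=2e-5,
        weight_decay=0.01,
        betas=(0.9, 0.98),
        eps=1e-8,
        foreach=False,
        fused=False,
    )
```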
15 changes: 13 additions & 2 deletions examples/run_grpo_math.py
Expand Up @@ -19,7 +19,7 @@
from typing import Any, Optional

from omegaconf import OmegaConf
from transformers import PreTrainedTokenizerBase
from transformers import AutoConfig, PreTrainedTokenizerBase

from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup
from nemo_rl.algorithms.utils import get_tokenizer
@@ -41,6 +41,7 @@
from nemo_rl.utils.logger import get_next_experiment_dir

OmegaConf.register_new_resolver("mul", lambda a, b: a * b)
OmegaConf.register_new_resolver("max", lambda a, b: max(a, b))


def parse_args() -> tuple[argparse.Namespace, list[str]]:
@@ -158,7 +159,17 @@ def main() -> None:

init_ray()

# setup tokenizer
# Set up the tokenizer and preload the model config so HF downloads the model
# files and remote-code modules up front, avoiding a race condition inside the
# generation/policy workers.
try:
_ = AutoConfig.from_pretrained(
config["policy"]["model_name"], trust_remote_code=True
)
print(f"Config preloaded successfully: {config['policy']['model_name']}")
except Exception as e:
print("WARNIN: error in preloading model, in general it's not a problem: ")
print(e)

tokenizer = get_tokenizer(config["policy"]["tokenizer"])
assert config["policy"]["generation"] is not None, (
"A generation config is required for GRPO"
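
The `AutoConfig.from_pretrained` call above only warms the config cache. If the whole checkpoint needs to be present before the Ray workers start, a fuller prefetch could look like the sketch below; `snapshot_download` from `huggingface_hub` is an assumption here, not something this PR uses:

```python
# Hedged sketch: cache the full checkpoint in the driver process before workers
# start, so they never race on the same download. Local paths pass through.
import os

from huggingface_hub import snapshot_download


def prefetch_checkpoint(model_name_or_path: str) -> str:
    if os.path.isdir(model_name_or_path):
        return model_name_or_path  # already a local checkpoint
    return snapshot_download(repo_id=model_name_or_path)
```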
1 change: 1 addition & 0 deletions examples/run_sft.py
@@ -32,6 +32,7 @@
from nemo_rl.utils.logger import get_next_experiment_dir

OmegaConf.register_new_resolver("mul", lambda a, b: a * b)
OmegaConf.register_new_resolver("max", lambda a, b: max(a, b))


def parse_args():
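
Both run scripts now register the same `mul` and `max` resolvers. One way to avoid the duplication (and re-registration errors if both paths are imported) would be a small shared helper; this is an illustration, not an existing nemo_rl utility:

```python
# Illustrative helper that both run scripts could import; replace=True makes
# repeated registration a no-op instead of raising.
from omegaconf import OmegaConf


def register_config_resolvers() -> None:
    OmegaConf.register_new_resolver("mul", lambda a, b: a * b, replace=True)
    OmegaConf.register_new_resolver("max", lambda a, b: max(a, b), replace=True)
```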
6 changes: 5 additions & 1 deletion nemo_rl/algorithms/grpo.py
@@ -793,7 +793,8 @@ def grpo_train(

print("▶ Computing logprobs...", flush=True)
with timer.time("policy_and_reference_logprobs"):
fprop_logprobs = policy.get_logprobs(train_data)["logprobs"]
logprobs_results = policy.get_logprobs(train_data)
fprop_logprobs = logprobs_results["logprobs"]
reference_logprobs = policy.get_reference_policy_logprobs(
train_data
)["reference_logprobs"]
@@ -915,12 +916,15 @@
log_data, f"train_data_step{total_steps}.jsonl"
)

print(f"train_results: {train_results['train_max_seq_len']}")
metrics = {
"loss": train_results["loss"].numpy(),
"train_max_seq_len": train_results["train_max_seq_len"],
"reward": rewards.numpy(),
"grad_norm": train_results["grad_norm"].numpy(),
"mean_prompt_length": repeated_batch["length"].numpy(),
"total_num_tokens": input_lengths.numpy(),
"train_max_seq_len": train_results["train_max_seq_len"],
}
metrics.update(train_results["all_mb_metrics"])
for k, v in metrics.items():
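
The new `train_max_seq_len` metric is taken directly from `train_results`. For illustration only (this is not the nemo_rl implementation), a per-step maximum sequence length can be computed locally and reduced across ranks like this:

```python
# Illustrative sketch: take the longest sequence seen on this rank, then
# all-reduce with MAX across data-parallel ranks when distributed is active.
import torch
import torch.distributed as dist


def global_max_seq_len(input_lengths: torch.Tensor) -> int:
    local_max = input_lengths.max().clone()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(local_max, op=dist.ReduceOp.MAX)
    return int(local_max.item())
```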