diff --git a/.config/clash/cache.db b/.config/clash/cache.db new file mode 100644 index 0000000000..e449c28987 Binary files /dev/null and b/.config/clash/cache.db differ diff --git a/.config/clash/config.yaml b/.config/clash/config.yaml new file mode 100644 index 0000000000..5d36fe7f61 --- /dev/null +++ b/.config/clash/config.yaml @@ -0,0 +1 @@ +mixed-port: 7890 \ No newline at end of file diff --git a/.gitignore b/.gitignore index 144a8c2b89..1ad0a12edc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,18 +1,36 @@ -__pycache__ +# 模拟数据和临时文件 +enhanced_tensor_logs/ +draw/ +*.pt +*.png +*.jpg +*.jpeg + +# Python缓存 +__pycache__/ +*.py[cod] +*$py.class *.so -build -.coverage_* -*.egg-info -*~ -slurm* -logs -.vscode -local/ -.gitmodules -wandb/ -onelogger.log -onelogger.err + +# 环境变量文件 +.env .venv -runs/ -/test_cases/ -**/dist/ \ No newline at end of file +env/ +venv/ + +# IDE文件 +.vscode/ +.idea/ +*.swp +*.swo + +# 日志文件 +*.log + +# 临时文件 +*.tmp +*.temp + +# 系统文件 +.DS_Store +Thumbs.db \ No newline at end of file diff --git a/.pretrain_gpt.py.swo b/.pretrain_gpt.py.swo new file mode 100644 index 0000000000..4c656790e8 Binary files /dev/null and b/.pretrain_gpt.py.swo differ diff --git a/LAYER_ANALYSIS_README.md b/LAYER_ANALYSIS_README.md new file mode 100644 index 0000000000..731e1d2185 --- /dev/null +++ b/LAYER_ANALYSIS_README.md @@ -0,0 +1,204 @@ +# Layer Distribution Analysis Tool + +## 概述 + +专门分析某个层的tensor分布的工具,支持attention和linear层的q,k,v,output和input,weight,output分析。使用正则表达式匹配tensor文件,生成详细的分布图表和统计报告。 + +## 功能特性 + +### 1. 层分析功能 +- **Attention层分析**: 分析query, key, value, output, attention_weights的分布 +- **Linear层分析**: 分析input, weight, output, bias, hidden的分布 +- **多子图显示**: 一个大图包含6个子图,展示不同tensor类型的分布 +- **统计信息**: 每个子图显示均值、标准差等关键统计信息 + +### 2. 量化对比功能 +- **多量化类型对比**: 同时显示bf16, mxfp8, mxfp4, hifp8的分布对比 +- **特定tensor分析**: 可以针对特定tensor类型进行量化对比 +- **2x2子图布局**: 清晰展示4种量化类型的分布差异 + +### 3. 统计报告 +- **详细统计信息**: 包含均值、标准差、分位数等完整统计 +- **文件计数**: 显示找到的文件数量 +- **数据质量**: 显示有效数据点数量 + +## 使用方法 + +### 基本用法 + +#### Python脚本直接调用 +```bash +# 分析attention层 +python analyze_layer_distribution.py --layer 1 --sample 0 --layer_type attention + +# 分析linear层 +python analyze_layer_distribution.py --layer 2 --sample 1 --layer_type linear + +# 量化对比分析 +python analyze_layer_distribution.py --layer 1 --sample 0 --layer_type attention --tensor_type query --quantization_comparison +``` + +#### Shell脚本调用 +```bash +# 基本用法 +./run_layer_analysis.sh + +# 示例 +./run_layer_analysis.sh ./enhanced_tensor_logs ./layer_output 1 0 attention +./run_layer_analysis.sh ./enhanced_tensor_logs ./layer_output 2 1 linear + +# 带量化对比 +./run_layer_analysis.sh ./enhanced_tensor_logs ./layer_output 1 0 attention query true +``` + +### 参数说明 + +#### Python脚本参数 +- `--tensor_dir`: 张量文件目录 (默认: ./enhanced_tensor_logs) +- `--output_dir`: 输出目录 (默认: ./layer_analysis_output) +- `--layer`: 层号 (必需, 如: 1, 2, 3, ...) +- `--sample`: 样本号 (必需, 如: 0, 1, 2) +- `--layer_type`: 层类型 (必需, attention 或 linear) +- `--tensor_type`: 特定tensor类型 (可选, 用于量化对比) +- `--quantization_comparison`: 启用量化对比 (可选) + +#### Shell脚本参数 +1. `tensor_dir`: 张量文件目录 +2. `output_dir`: 输出目录 +3. `layer`: 层号 +4. `sample`: 样本号 +5. `layer_type`: 层类型 (attention/linear) +6. `tensor_type`: 特定tensor类型 (可选) +7. `quantization_comparison`: 是否启用量化对比 (true/false) + +## 输出文件 + +### 1. 层分析图表 +- **文件名格式**: `layer_{layer}_sample_{sample}_{layer_type}_analysis.png` +- **内容**: 6个子图显示不同tensor类型的分布 +- **统计信息**: 每个子图包含均值、标准差等统计信息 + +### 2. 量化对比图表 +- **文件名格式**: `quantization_comparison_layer_{layer}_sample_{sample}_{layer_type}_{tensor_type}.png` +- **内容**: 2x2子图显示4种量化类型的分布对比 +- **适用场景**: 需要比较不同量化类型对同一tensor的影响 + +### 3. 统计报告 +- **文件名格式**: `statistics_layer_{layer}_sample_{sample}_{layer_type}.txt` +- **内容**: 详细的数值统计信息 +- **包含信息**: 文件数量、数据点数量、均值、标准差、分位数等 + +## 支持的Tensor类型 + +### Attention层 +- `query`: Query张量 +- `key`: Key张量 +- `value`: Value张量 +- `output`: 输出张量 +- `attention_weights`: 注意力权重矩阵 + +### Linear层 +- `input`: 输入张量 +- `weight`: 权重张量 +- `output`: 输出张量 +- `bias`: 偏置张量 +- `hidden`: 隐藏层张量 + +## 文件命名格式支持 + +工具支持以下文件命名格式: +``` +YYYYMMDD_HHMMSS_XXXX_iterXXX_layer_type_LX_operation_phase_component_quant_type_rankXX_sampleXXX_groupXXX_tensor_name.pt +``` + +示例: +``` +20250914_075006_1399_iter000_attention_L1_forward_post_FA_bf16_rank07_sample000_group000_attention_weights.pt +``` + +## 使用示例 + +### 示例1: 分析第1层第0个样本的attention分布 +```bash +python analyze_layer_distribution.py --layer 1 --sample 0 --layer_type attention +``` + +### 示例2: 分析第2层第1个样本的linear分布 +```bash +python analyze_layer_distribution.py --layer 2 --sample 1 --layer_type linear +``` + +### 示例3: 对比第1层第0个样本query张量的量化效果 +```bash +python analyze_layer_distribution.py --layer 1 --sample 0 --layer_type attention --tensor_type query --quantization_comparison +``` + +### 示例4: 使用shell脚本分析 +```bash +# 分析attention层 +./run_layer_analysis.sh ./enhanced_tensor_logs ./output 1 0 attention + +# 分析linear层并启用量化对比 +./run_layer_analysis.sh ./enhanced_tensor_logs ./output 2 1 linear weight true +``` + +## 技术特性 + +### 数据处理 +- **自动数据清理**: 自动移除NaN和Inf值 +- **数据采样**: 大数据集自动采样以提高性能 +- **多文件合并**: 自动合并同一类型的多个tensor文件 + +### 可视化质量 +- **高分辨率**: 300 DPI输出 +- **专业配色**: 科学可视化标准配色 +- **清晰标注**: 详细的图表标签和统计信息 + +### 错误处理 +- **文件损坏处理**: 自动跳过损坏的tensor文件 +- **格式兼容性**: 支持多种tensor文件格式 +- **优雅降级**: 数据缺失时显示友好提示 + +## 依赖要求 + +### Python包 +- torch +- matplotlib +- numpy +- pandas +- seaborn + +### 安装依赖 +```bash +pip install torch matplotlib numpy pandas seaborn scipy +``` + +## 注意事项 + +1. **文件格式**: 确保tensor文件格式正确且可读 +2. **内存使用**: 处理大量tensor文件时注意内存使用 +3. **输出目录**: 确保对输出目录有写权限 +4. **层和样本**: 确保指定的层和样本存在对应的tensor文件 + +## 故障排除 + +### 常见问题 +1. **No data found**: 检查层号、样本号和层类型是否正确 +2. **No valid data**: 检查tensor文件是否损坏或格式不正确 +3. **Import error**: 安装缺失的Python包 +4. **Permission denied**: 检查输出目录的写权限 + +### 调试建议 +1. 检查tensor文件是否存在 +2. 验证文件名格式是否正确 +3. 确认Python环境配置 +4. 查看详细错误信息 + +## 版本历史 + +### v1.0.0 (当前版本) +- 基础层分析功能 +- 支持attention和linear层 +- 量化对比功能 +- 统计报告生成 +- Shell脚本封装 diff --git a/MXFP_SCALING_README.md b/MXFP_SCALING_README.md new file mode 100644 index 0000000000..08d12b25cc --- /dev/null +++ b/MXFP_SCALING_README.md @@ -0,0 +1,396 @@ +# MXFP Scaling Analysis Tools + +这个文档介绍了新增的MXFP量化缩放分析工具,用于分析和优化MXFP量化过程中的缩放策略。 + +## 新增功能 + +### 1. 下溢出分析功能 + +在`quant/mxfp.py`中新增了`_analyze_underflow_before_quantization`函数,该函数会在量化前分析张量的下溢出情况。 + +**功能特点:** +- 在量化前检测潜在的下溢出问题 +- 分析scaling对下溢出的影响 +- 提供详细的下溢出统计信息 +- 不会影响量化过程的正常执行 + +**触发条件:** +- 当检测到下溢出比例 > 0.1% 时自动输出分析报告 +- 提供高、中等下溢出率的警告 + +### 2. MXFP缩放测试工具 + +创建了`quant/mxfp_scaling_test.py`工具,用于测试不同缩放策略对量化精度的影响。 + +**主要功能:** +- 测试从最大值对齐到最小值对齐的不同缩放级别 +- 计算多种精度指标(MSE、余弦相似度、PSNR等) +- 生成详细的折线图可视化结果 +- 支持多种MXFP格式 + +## 使用方法 + +### 基本用法 + +```bash +# 测试单个张量文件的缩放效果 +python quant/mxfp_scaling_test.py input_tensor.pt + +# 测试多个张量文件 +python quant/mxfp_scaling_test.py tensor1.pt tensor2.pt tensor3.pt + +# 指定输出目录和参数 +python quant/mxfp_scaling_test.py input_tensor.pt --output-dir ./results/ --elem-format fp8_e4m3 --num-levels 31 +``` + +### 参数说明 + +- `input_tensor`: 输入的BF16张量文件路径(支持多个文件) +- `--output-dir`: 输出结果目录(默认:./draw/scaling_analysis/{tensor_name}/) +- `--elem-format`: 量化格式(fp8_e4m3, fp8_e5m2, fp4_e2m1等) +- `--scale-bits`: 缩放位数(默认:8) +- `--max-scale-exp`: 最大缩放指数(默认:自动计算,基于tensor最大值对齐) +- `--min-scale-exp`: 最小缩放指数(默认:自动计算,基于tensor最小值对齐) +- `--num-levels`: 测试的缩放级别数量(默认:21) +- `--no-plots`: 跳过生成图表 + +### 多tensor处理特性 + +当提供多个tensor文件时,工具会: + +1. **独立处理**: 每个tensor文件独立进行缩放测试和分析 +2. **独立输出**: 为每个tensor创建独立的输出目录和日志文件 +3. **进度显示**: 实时显示处理进度 `[1/3] Processing: tensor1.pt` +4. **状态反馈**: 显示每个tensor的处理结果(✅成功 / ❌失败) +5. **最终汇总**: 提供所有tensor的处理汇总统计 + +**示例输出:** +``` +Processing 3 tensor(s)... +================================================================================ + +[1/3] Processing: tensor1.pt +------------------------------------------------------------ +✅ Successfully processed: tensor1.pt + +[2/3] Processing: tensor2.pt +------------------------------------------------------------ +✅ Successfully processed: tensor2.pt + +[3/3] Processing: tensor3.pt +------------------------------------------------------------ +✅ Successfully processed: tensor3.pt + +================================================================================ +FINAL SUMMARY +================================================================================ +Total tensors: 3 +Successful: 3 +Failed: 0 +🎉 All tests completed successfully! +``` + +### 输出结果 + +工具会生成以下文件,默认保存在 `draw/scaling_analysis/{tensor_name}/` 目录下: + +1. **详细结果文件**: `mxfp_scaling_results_.txt` + - 包含所有缩放级别的详细指标数据 + +2. **综合图表**: `mxfp_scaling_test_.png` + - 6个子图展示不同指标随缩放指数的变化 + +3. **摘要图表**: `mxfp_scaling_summary_.png` + - 关键指标(MSE、余弦相似度、PSNR)的汇总图 + +4. **日志文件**: `mxfp_scaling_test_{tensor_name}_{format}.log` + - 完整的测试过程日志,包含所有输出信息 + - 同时显示在控制台和保存到文件中 + - 包含详细的缩放因子分析和推荐 + +**输出目录结构示例:** +``` +draw/scaling_analysis/ +├── tensor1/ +│ ├── mxfp_scaling_results_fp8_e4m3.txt +│ ├── mxfp_scaling_test_fp8_e4m3.png +│ ├── mxfp_scaling_summary_fp8_e4m3.png +│ └── mxfp_scaling_test_tensor1_fp8_e4m3.log +├── tensor2/ +│ ├── mxfp_scaling_results_fp8_e5m2.txt +│ ├── mxfp_scaling_test_fp8_e5m2.png +│ ├── mxfp_scaling_summary_fp8_e5m2.png +│ └── mxfp_scaling_test_tensor2_fp8_e5m2.log +└── ... +``` + +### 日志功能 + +工具会自动生成详细的日志文件,记录完整的测试过程: + +**日志特点:** +- **双重输出**: 同时显示在控制台和保存到日志文件 +- **时间戳**: 每条日志都包含精确的时间戳 +- **完整记录**: 记录从测试开始到结束的所有信息 +- **结构化格式**: 清晰的日志格式,便于分析和调试 + +**日志内容包含:** +- 测试开始和结束时间 +- 输入张量信息(形状、数据类型、数值范围) +- 测试参数(格式、缩放位数、指数范围等) +- 每个缩放级别的测试进度和结果 +- **详细的缩放因子分析和推荐** +- 最终的最佳结果汇总 +- 文件保存位置信息 + +### 智能分析功能 + +工具会自动分析测试结果并提供智能推荐: + +**分析内容:** +- **个体指标最优解**: 找出MSE、余弦相似度、PSNR等指标的最佳缩放因子 +- **综合评分推荐**: 基于加权综合评分推荐最佳缩放因子 +- **性能稳定性分析**: 分析不同缩放因子下的性能变化范围 +- **实用性建议**: 根据性能变化程度提供使用建议 + +**推荐算法:** +- 综合评分权重:MSE(30%) + 余弦相似度(30%) + PSNR(20%) + MAE(10%) + 相对误差(10%) +- 自动识别性能稳定性和关键选择点 +- 提供基于数据特征的个性化建议 + +### 智能溢出分析功能 + +工具会自动分析每个缩放级别的上溢出和下溢出情况并提供详细报告: + +**溢出分析内容:** +- **上溢出检测**: 检测值超出格式最大表示范围的情况 +- **下溢出检测**: 检测值小于格式最小表示范围的情况 +- **严重程度分类**: 高严重程度(>1%)、中等严重程度(0.1-1%)、无显著溢出(<0.1%) +- **详细统计**: 每个缩放级别的溢出数量、百分比和刷新到零的统计 +- **张量范围分析**: 显示量化前后的张量数值范围 +- **最优范围推荐**: 基于溢出分析推荐最佳的缩放范围 + +**溢出分析输出示例:** +``` +================================================================================ +OVERFLOW/UNDERFLOW ANALYSIS SUMMARY +================================================================================ +Format: fp8_e4m3 +Analyzed 7 scaling levels +Significant overflow/underflow detected in 7 levels +-------------------------------------------------------------------------------- +🔴 OVERFLOW ISSUES: +---------------------------------------- + Scale Exp: 5.00 (Factor: 32.000000) + Overflow: 10 (5.00%) + Max Normal: 4.48e+02 + Tensor Range: [-7.76e+02, 6.32e+02] + Severity: HIGH + +🟡 UNDERFLOW ISSUES: +---------------------------------------- + Scale Exp: 5.00 (Factor: 32.000000) + Underflow: 100 (50.00%) + Flush to Zero: 100 (50.00%) + Min Normal: 1.56e-02 + Tensor Range: [-7.76e+02, 6.32e+02] + Severity: HIGH + +OVERFLOW EXTREMES: +---------------------------------------- +Worst Overflow: Scale Exp -6.67 + 50.00% overflow + +UNDERFLOW EXTREMES: +---------------------------------------- +Worst Underflow: Scale Exp 5.00 + 50.00% underflow, 50.00% flushed to zero +Best Underflow: Scale Exp -30.00 + 0.00% underflow, 0.50% flushed to zero +-------------------------------------------------------------------------------- +OVERFLOW/UNDERFLOW RECOMMENDATIONS: +---------------------------------------- +⚠️ AVOID scaling factors with HIGH overflow/underflow severity + These factors cause significant precision loss +🔴 OVERFLOW WARNING: + Avoid scaling factors that cause overflow + These values will be saturated to max representable value +🟡 UNDERFLOW CONSIDERATIONS: + Moderate underflow may be acceptable depending on use case + Balance between underflow and overflow risks +⚠️ All scaling levels have some overflow/underflow - choose least problematic +💡 Least problematic scaling: 5.00 + Overflow: 5.00%, Underflow: 50.00% +================================================================================ +``` + +**日志文件命名规则:** +``` +mxfp_scaling_test_{tensor_name}_{format}.log +``` + +例如:`mxfp_scaling_test_my_tensor_fp8_e4m3.log` + +**分析输出示例:** +``` +================================================================================ +SCALING FACTOR ANALYSIS & RECOMMENDATIONS +================================================================================ +Format: fp8_e4m3 (e4m5) +Tested 7 scaling levels from -2.00 to 2.00 +-------------------------------------------------------------------------------- +INDIVIDUAL METRIC OPTIMA: +---------------------------------------- +🏆 Best MSE: Scale Exp = 0.00, Factor = 1.000000 + MSE: 1.765694e+00, Cosine: 0.999654, PSNR: 41.34 dB +🎯 Best Cosine Similarity: Scale Exp = 0.00, Factor = 1.000000 + MSE: 1.765694e+00, Cosine: 0.999654, PSNR: 41.34 dB +📊 Best PSNR: Scale Exp = 0.00, Factor = 1.000000 + MSE: 1.765694e+00, Cosine: 0.999654, PSNR: 41.34 dB +-------------------------------------------------------------------------------- +COMPOSITE RECOMMENDATION: +---------------------------------------- +⭐ RECOMMENDED Scaling Factor: 1.000000 + Scale Exponent: 0.00 + Composite Score: 0.9999 + Balanced Performance: + - MSE: 1.765694e+00 + - Cosine Similarity: 0.999654 + - PSNR: 41.34 dB + - MAE: 8.851570e-01 + - Relative Error: 2.20% +-------------------------------------------------------------------------------- +PERFORMANCE ANALYSIS: +---------------------------------------- +MSE Range: 1.765694e+00 to 1.117014e+01 (Δ: 9.404443e+00) +Cosine Range: 0.997942 to 0.999654 (Δ: 0.001712) +PSNR Range: 33.33 to 41.34 dB (Δ: 8.01 dB) +MSE Stability (std): 3.263742e+00 +Cosine Stability (std): 0.000594 +-------------------------------------------------------------------------------- +RECOMMENDATIONS: +---------------------------------------- +⚠️ MSE varies significantly with scaling - choose the recommended factor carefully +✅ Cosine similarity is very stable - scaling factor has minimal impact on direction preservation +✅ Small PSNR range - scaling factor has limited impact on quality +-------------------------------------------------------------------------------- +FINAL RECOMMENDATION: +---------------------------------------- +🎯 Use scaling factor: 1.000000 + This provides the best balance of accuracy and stability for fp8_e4m3 quantization + Scale exponent: 0.00 + 📍 This is a balanced middle ground between overflow and underflow +================================================================================ +``` + +### 计算的指标 + +- **MSE (Mean Squared Error)**: 均方误差 +- **RMSE (Root Mean Squared Error)**: 均方根误差 +- **Cosine Similarity**: 余弦相似度 +- **PSNR (Peak Signal-to-Noise Ratio)**: 峰值信噪比 +- **MAE (Mean Absolute Error)**: 平均绝对误差 +- **Max Absolute Error**: 最大绝对误差 +- **Relative Error**: 相对误差百分比 + +## 测试和验证 + +### 运行测试脚本 + +```bash +# 运行基本功能测试 +python test_mxfp_scaling.py +``` + +### 示例用法 + +```bash +# 测试FP8 E4M3格式的缩放效果 +python quant/mxfp_scaling_test.py your_tensor.pt --elem-format fp8_e4m3 --num-levels 21 + +# 测试FP8 E5M2格式,扩大测试范围 +python quant/mxfp_scaling_test.py your_tensor.pt --elem-format fp8_e5m2 --max-scale-exp 15 --min-scale-exp -15 --num-levels 31 + +# 测试FP4格式,更多缩放级别 +python quant/mxfp_scaling_test.py your_tensor.pt --elem-format fp4_e2m1 --num-levels 51 +``` + +## 技术细节 + +### 缩放策略 + +工具测试从最大值对齐(maximum alignment)到最小值对齐(minimum alignment)的不同缩放策略: + +1. **最大值对齐**: 缩放使张量的绝对值最大值刚好在格式的最大可表示值范围内,避免上溢出 +2. **最小值对齐**: 缩放使张量的绝对值最小值刚好在格式的最小可表示值范围内,避免下溢出 +3. **中间级别**: 在这两个极端之间均匀分布的缩放级别 + +**对齐计算逻辑:** +- **最大对齐指数**: `floor(log2(tensor.abs().max() / format.max_norm))` +- **最小对齐指数**: `ceil(log2(tensor.abs().min() / format.min_norm))` + +其中`tensor.abs().min()`只考虑非零值,确保下溢出分析的正确性。 + +### 下溢出分析 + +下溢出分析在量化前进行,检测: +- 非零值但小于最小可表示值的元素 +- 会被刷新为零的元素 +- 下溢出和刷新统计信息 + +### 可视化特性 + +- 动态调整图表范围以适应数据分布 +- 突出显示关键边界值 +- 提供详细的统计信息框 +- 支持对数刻度的误差指标 + +## 注意事项 + +1. **输入格式**: 工具期望输入为BF16格式的张量文件 +2. **内存使用**: 大张量可能需要较多内存,建议在GPU上运行 +3. **计算时间**: 更多缩放级别会增加计算时间 +4. **精度**: 指标计算使用float32精度以确保准确性 + +## 故障排除 + +### 常见问题 + +1. **导入错误**: 确保在Megatron-LM根目录运行脚本 +2. **内存不足**: 减小张量大小或缩放级别数量 +3. **格式不支持**: 检查elem-format参数是否支持 +4. **文件格式**: 确保输入文件是有效的PyTorch张量文件 + +### 输出目录管理 + +**默认行为:** +- 如果不指定`--output-dir`,工具会自动创建`draw/scaling_analysis/{tensor_name}/`目录 +- `{tensor_name}`是输入文件名(不包含扩展名) + +**自定义输出目录:** +```bash +# 使用默认目录(基于tensor名称) +python quant/mxfp_scaling_test.py my_tensor.pt + +# 指定自定义输出目录 +python quant/mxfp_scaling_test.py my_tensor.pt --output-dir ./my_results/ + +# 为不同格式指定不同目录 +python quant/mxfp_scaling_test.py my_tensor.pt --elem-format fp8_e4m3 --output-dir ./results_fp8_e4m3/ +python quant/mxfp_scaling_test.py my_tensor.pt --elem-format fp8_e5m2 --output-dir ./results_fp8_e5m2/ +``` + +### 调试选项 + +使用`--no-plots`选项跳过图表生成以加快测试: +```bash +python quant/mxfp_scaling_test.py input.pt --no-plots +``` + +## 扩展功能 + +工具设计为可扩展的,可以轻松添加: +- 新的量化格式支持 +- 额外的精度指标 +- 不同的可视化选项 +- 批量处理多个文件 diff --git a/docs/TIME_RESUME_ADAPTIVE_QUANTIZATION.md b/docs/TIME_RESUME_ADAPTIVE_QUANTIZATION.md new file mode 100644 index 0000000000..6ee33e6437 --- /dev/null +++ b/docs/TIME_RESUME_ADAPTIVE_QUANTIZATION.md @@ -0,0 +1,194 @@ +# Time-Resume Adaptive Quantization Training + +## Overview + +Time-resume adaptive quantization is an advanced training technique that dynamically switches between quantized (fp8/fp4) and high-precision (bf16) training based on loss thresholds. This approach combines the efficiency of quantized training with the stability of high-precision training. + +## Features + +- **Dynamic Precision Switching**: Automatically switches between quantized and bf16 training based on loss +- **Asynchronous Checkpoint Saving**: Non-blocking checkpoint saves to minimize training interruption +- **Window-based Training**: Organizes training into manageable windows with regular checkpoints +- **Loss-based Thresholds**: Configurable loss thresholds for precision switching +- **Recovery System**: Automatic recovery from the best available checkpoint + +## Command Line Arguments + +### Core Parameters + +- `--time-resume`: Enable time-resume adaptive quantization training +- `--quant-loss-threshold`: Loss threshold for switching from quantized to BF16 training (default: 0.1) +- `--quant-window-size`: Number of iterations per training window (default: 5) +- `--quant-checkpoint-interval`: Checkpoint save interval within windows (default: 1) +- `--quant-fallback-strategy`: Fallback precision when quantized training fails (choices: bf16, fp16, default: bf16) +- `--quant-recovery-buffer`: Number of checkpoints to keep for recovery (default: 2) + +### Additional Parameters + +- `--scaling-control`: Scaling control strategy for MX quantization (choices: max, max_minus_1, default: max) + +## Usage Examples + +### Basic Usage + +```bash +bash script/llama32-1b/pretrain_llama32-1b_wikipedia_FA_linear_mxfp4_time_resume.sh \ + "checkpoints/path" \ + "logs/path" \ + "tokenizer/path" \ + "data/path" \ + "bf16" \ + "max_minus_1" \ + "0.1" \ + "5" \ + "1" \ + "bf16" \ + "2" +``` + +### Custom Configuration + +```bash +bash examples/llama/train_llama32_1b_h100_fp8.sh \ + --time-resume \ + --quant-loss-threshold 0.15 \ + --quant-window-size 10 \ + --quant-checkpoint-interval 2 \ + --quant-fallback-strategy bf16 \ + --quant-recovery-buffer 3 \ + --scaling-control max_minus_1 +``` + +## How It Works + +### 1. Initialization + +The adaptive quantization manager is initialized at the start of training with the specified parameters. + +### 2. Training Loop + +During training, the system: +- Monitors loss values +- Compares current loss against the threshold +- Switches precision when needed +- Saves checkpoints asynchronously +- Manages training windows + +### 3. Precision Switching Logic + +- **Switch to BF16**: When loss exceeds threshold for 3 consecutive iterations +- **Switch back to Quantized**: When loss is stable and below 80% of threshold +- **Checkpoint before switch**: Always saves checkpoint before precision changes + +### 4. Window Management + +- Training is organized into windows of specified size +- Checkpoints are saved at window boundaries +- Recovery system maintains multiple checkpoints for rollback + +## Configuration Guidelines + +### Loss Threshold + +- **Too High**: May not catch precision issues early enough +- **Too Low**: May cause unnecessary switching to BF16 +- **Recommended**: 0.1-0.2 for most models + +### Window Size + +- **Small Windows**: More frequent checkpoints, higher overhead +- **Large Windows**: Less overhead, but longer recovery time +- **Recommended**: 5-10 iterations for most cases + +### Checkpoint Interval + +- **Frequent Saves**: Better recovery options, higher I/O overhead +- **Infrequent Saves**: Lower overhead, but limited recovery options +- **Recommended**: 1-2 iterations for critical training + +## Monitoring and Debugging + +### Log Messages + +The system provides detailed logging: + +``` +[TimeResume] Adaptive quantization training enabled +[AdaptiveQuantization] Loss 0.1500 exceeds threshold 0.1000, switching to bf16 +[AdaptiveQuantization] Checkpoint saved: window_1_iter100_bf16 +[AdaptiveQuantization] Loss 0.0800 is stable, switching back to mxfp4 +``` + +### Checkpoint Naming + +Checkpoints are named with descriptive tags: +- `window_1_iter100_quantized`: Window 1, iteration 100, quantized training +- `switch_to_bf16_iter150`: Switch checkpoint at iteration 150 to bf16 +- `window_end_2_iter200`: End of window 2 at iteration 200 + +## Best Practices + +### 1. Initial Setup + +- Start with conservative thresholds (0.1-0.15) +- Use moderate window sizes (5-10 iterations) +- Enable frequent checkpointing during early training + +### 2. Monitoring + +- Watch for frequent precision switching +- Monitor checkpoint storage usage +- Track training stability metrics + +### 3. Optimization + +- Adjust thresholds based on model behavior +- Increase window size for stable models +- Reduce checkpoint frequency for storage-constrained environments + +## Troubleshooting + +### Common Issues + +1. **Frequent Switching**: Lower the loss threshold or increase window size +2. **Storage Issues**: Reduce checkpoint frequency or buffer size +3. **Training Instability**: Increase recovery buffer size + +### Recovery + +If training fails, the system can automatically recover from the most recent checkpoint: + +```python +# Automatic recovery is handled by the manager +adaptive_quantization_manager.load_recovery_checkpoint() +``` + +## Performance Considerations + +### Benefits + +- **Automatic Optimization**: Reduces manual tuning of quantization parameters +- **Fault Tolerance**: Better recovery from training issues +- **Efficiency**: Combines benefits of both quantized and high-precision training + +### Overhead + +- **Checkpoint I/O**: Asynchronous saves minimize impact +- **Memory Usage**: Multiple checkpoints require additional storage +- **Complexity**: Slightly more complex training loop + +## Integration with Existing Features + +Time-resume adaptive quantization works seamlessly with: +- Tensor saving and collection +- Scaling control strategies +- Existing checkpoint systems +- Multi-GPU training setups + +## Future Enhancements + +Potential improvements include: +- Machine learning-based threshold adjustment +- More sophisticated precision selection +- Integration with model compression techniques +- Advanced recovery strategies diff --git a/examples/bert/train_bert_340m_distributed.sh b/examples/bert/train_bert_340m_distributed.sh index f0d9c87c8b..9110688db7 100644 --- a/examples/bert/train_bert_340m_distributed.sh +++ b/examples/bert/train_bert_340m_distributed.sh @@ -76,4 +76,4 @@ torchrun ${DISTRIBUTED_ARGS[@]} pretrain_bert.py \ ${MODEL_PARALLEL_ARGS[@]} \ ${DATA_ARGS[@]} \ ${EVAL_AND_LOGGING_ARGS[@]} - \ No newline at end of file + diff --git a/examples/deepseek2_lite/train_deepseek2_lite_h100_fp8.sh b/examples/deepseek2_lite/train_deepseek2_lite_h100_fp8.sh new file mode 100755 index 0000000000..7bc907f0e0 --- /dev/null +++ b/examples/deepseek2_lite/train_deepseek2_lite_h100_fp8.sh @@ -0,0 +1,204 @@ +#!/bin/bash + +# Environment variables for performance tuning +export CUDA_DEVICE_MAX_CONNECTIONS=${CUDA_DEVICE_MAX_CONNECTIONS:-1} +#export LOG_LEVEL=${LOG_LEVEL:-INFO} +#export NCCL_IB_TIMEOUT=${NCCL_IB_TIMEOUT:-19} +#export NVTE_FWD_LAYERNORM_SM_MARGIN=${NVTE_FWD_LAYERNORM_SM_MARGIN:-16} +#export NVTE_BWD_LAYERNORM_SM_MARGIN=${NVTE_BWD_LAYERNORM_SM_MARGIN:-16} +#export NCCL_P2P_NET_CHUNKSIZE=${NCCL_P2P_NET_CHUNKSIZE:-2097152} +#export NCCL_AVOID_RECORD_STREAMS=${NCCL_AVOID_RECORD_STREAMS:-1} + +CHECKPOINT_PATH=${1:-"checkpoints/deepseek2_lite_fp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/deepseek2_lite_fp8"} +# TOKENIZER_ARG=${3:-"MOCK"} # Path to tokenizer model, or "MOCK" +TOKENIZER_ARG=${3:-"model/deepseek2_lite"} # Path to tokenizer model, or "MOCK" +# DATA_ARG=${4:-"MOCK"} # Data prefix, or "MOCK" +# DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} # Data prefix, or "MOCK" +DATA_ARG=${4:-"dataset/wikitext_processed/wikitext_processed_text_document"} # Data prefix, or "MOCK" + + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# Distributed training setup +GPUS_PER_NODE=8 +NUM_NODES=1 +MASTER_ADDR=${MASTER_ADDR:-localhost} +MASTER_PORT=${MASTER_PORT:-6000} +NODE_RANK=${NODE_RANK:-0} +WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES)) + +# Path to the pretrain_gpt.py script, assuming this script is run from the root of the Megatron-LM repository +PRETRAIN_SCRIPT_PATH="pretrain_gpt.py" + +# Fixed model and training parameters for DeepSeek2-Lite +# DeepSeek2-Lite is a 1.3B parameter model with similar architecture to LLaMA +TP_SIZE=4 +CP_SIZE=1 +PP_SIZE=1 +MICRO_BATCH_SIZE=1 # default 1 +GLOBAL_BATCH_SIZE=128 # default 128 +NUM_LAYERS=16 +# DTYPE="bf16" +DTYPE=${5:-"fp8"} +SEQ_LENGTH=8192 +MAX_POSITION_EMBEDDINGS=8192 + +# Data cache path (useful for both mock and real data) +DATA_CACHE_PATH="${PWD}/benchmark_cache_deepseek2_lite_fp8" +mkdir -p "$DATA_CACHE_PATH" + +DISTRIBUTED_ARGS=( + --nproc_per_node $GPUS_PER_NODE + --nnodes $NUM_NODES + --node_rank $NODE_RANK + --master_addr $MASTER_ADDR + --master_port $MASTER_PORT +) + +# Model architecture parameters for DeepSeek2-Lite (1.3B parameters) +# Based on typical configurations for models of this size +MODEL_ARGS=( + --use-mcore-models + --num-layers $NUM_LAYERS + --hidden-size 2048 + --ffn-hidden-size 8192 + --num-attention-heads 32 + --group-query-attention + --num-query-groups 8 + --kv-channels 128 + --seq-length $SEQ_LENGTH + --max-position-embeddings $MAX_POSITION_EMBEDDINGS + --position-embedding-type rope + --rotary-base 500000 + --rotary-percent 1.0 + --attention-dropout 0.0 + --hidden-dropout 0.0 + --swiglu + --init-method-std 0.0134 + --attention-backend fused + --apply-layernorm-1p + --untie-embeddings-and-output-weights + --disable-bias-linear + --transformer-impl local +) + +TRAINING_ARGS=( + --micro-batch-size $MICRO_BATCH_SIZE + --global-batch-size $GLOBAL_BATCH_SIZE + --train-samples 47340000 + --lr-decay-samples 47245280 + --lr-warmup-samples 94720 + --lr 0.00015 + --min-lr 0.00001 + --decoupled-lr 5.0e-4 # Specific to decoupled AdamW, ensure optimizer is compatible + --decoupled-min-lr 4.5e-5 # Specific to decoupled AdamW + --lr-decay-style cosine + --clip-grad 1.0 + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.95 + --bf16 + --grad-reduce-in-bf16 + --cross-entropy-loss-fusion + --calculate-per-token-loss + --manual-gc + --empty-unused-memory-level 1 + --exit-duration-in-mins 235000000 # default 235 +) + +# Conditional arguments based on DTYPE (FP8) +DTYPE_ARGS=() +if [[ "$DTYPE" == "fp8" ]]; then + DTYPE_ARGS+=( + "--fp8-format hybrid" + "--fp8-amax-history-len 1024" + "--fp8-amax-compute-algo max" + "--fp8-param-gather" + ) +fi + +# Model parallelism arguments +MODEL_PARALLEL_ARGS=( + --tensor-model-parallel-size $TP_SIZE + --context-parallel-size $CP_SIZE + --pipeline-model-parallel-size $PP_SIZE # Not explicitly set in llama script options, assume 1 if not multi-node PP + --sequence-parallel # Always enable sequence parallelism with TP_SIZE=4 +) + +# Distributed Data Parallel (DDP) arguments +# From original script's ddp_args +DDP_ARGS=( + --use-distributed-optimizer + --overlap-grad-reduce + --overlap-param-gather +) +TRAINING_ARGS+=("${DDP_ARGS[@]}") + + +# Data arguments (conditional for mock vs real data) +DATA_ARGS_LIST=() +if [[ "$TOKENIZER_ARG" == "MOCK" ]] || [[ "$DATA_ARG" == "MOCK" ]] || [[ -z "$TOKENIZER_ARG" ]]; then + DATA_ARGS_LIST+=( + "--mock-data" + "--tokenizer-type NullTokenizer" + "--vocab-size 128256" + "--data-cache-path ${DATA_CACHE_PATH}" + "--tiktoken-pattern v2" + "--split '99,1,0'" + "--no-create-attention-mask-in-dataloader" + "--no-mmap-bin-files" + "--num-workers 1" + ) +else + # Settings for real data + DATA_ARGS_LIST+=( + "--data-path $DATA_ARG" + "--tokenizer-type HuggingFaceTokenizer" + "--tokenizer-model $TOKENIZER_ARG" + "--data-cache-path ${DATA_CACHE_PATH}" + "--split '99,1,0'" + "--no-create-attention-mask-in-dataloader" + # "--no-mmap-bin-files" + "--num-workers 1" + # Note: --vocab-size might be inferred by HuggingFaceTokenizer or might need to be explicit. + "--vocab-size 128256" + ) +fi + +EVAL_AND_LOGGING_ARGS=( + --log-interval 1 + --eval-iters 32 + --eval-interval 100 + --save-interval 1000 + --log-throughput + # --profile + # --profile-step-start 4 + # --profile-step-end 6 + --ckpt-format torch_dist + --distributed-timeout-minutes 120 + --save "$CHECKPOINT_PATH" + --load "$CHECKPOINT_PATH" + --tensorboard-dir "$TENSORBOARD_LOGS_PATH" +) + +# Ensure pretrain_gpt.py is found +if [ ! -f "$PRETRAIN_SCRIPT_PATH" ]; then + echo "Error: pretrain_gpt.py not found at $PRETRAIN_SCRIPT_PATH" + echo "Please ensure you are running this script from the root of the Megatron-LM repository, and pretrain_gpt.py is present." + exit 1 +fi + +# Run the training command +torchrun ${DISTRIBUTED_ARGS[@]} \ + "$PRETRAIN_SCRIPT_PATH" \ + ${MODEL_ARGS[@]} \ + ${TRAINING_ARGS[@]} \ + ${DTYPE_ARGS[@]} \ + ${MODEL_PARALLEL_ARGS[@]} \ + ${DATA_ARGS_LIST[@]} \ + ${EVAL_AND_LOGGING_ARGS[@]} + +set +x diff --git a/examples/llama/train_llama32_1b_h100_fp8.sh b/examples/llama/train_llama32_1b_h100_fp8.sh new file mode 100644 index 0000000000..8d23681dba --- /dev/null +++ b/examples/llama/train_llama32_1b_h100_fp8.sh @@ -0,0 +1,236 @@ +#!/bin/bash + +# Environment variables for performance tuning +export CUDA_DEVICE_MAX_CONNECTIONS=${CUDA_DEVICE_MAX_CONNECTIONS:-1} +#export LOG_LEVEL=${LOG_LEVEL:-INFO} +#export NCCL_IB_TIMEOUT=${NCCL_IB_TIMEOUT:-19} +#export NVTE_FWD_LAYERNORM_SM_MARGIN=${NVTE_FWD_LAYERNORM_SM_MARGIN:-16} +#export NVTE_BWD_LAYERNORM_SM_MARGIN=${NVTE_BWD_LAYERNORM_SM_MARGIN:-16} +#export NCCL_P2P_NET_CHUNKSIZE=${NCCL_P2P_NET_CHUNKSIZE:-2097152} +#export NCCL_AVOID_RECORD_STREAMS=${NCCL_AVOID_RECORD_STREAMS:-1} + +CHECKPOINT_PATH=${1:-"checkpoints/llama3_8b_fp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama3_8b_fp8"} +# TOKENIZER_ARG=${3:-"MOCK"} # Path to tokenizer model, or "MOCK" +TOKENIZER_ARG=${3:-"model/llama3.2-1b"} # Path to tokenizer model, or "MOCK" +# DATA_ARG=${4:-"MOCK"} # Data prefix, or "MOCK" +# DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} # Data prefix, or "MOCK" +DATA_ARG=${4:-"dataset/wikitext_processed/wikitext_processed_text_document"} # Data prefix, or "MOCK" +DTYPE=${5:-"fp8"} + +# Parse additional arguments +EXTRA_ARGS=() +shift 5 # Remove the first 5 positional arguments +while [[ $# -gt 0 ]]; do + case $1 in + --control-iter) + EXTRA_ARGS+=("--control-iter" "$2") + shift 2 + ;; + --save-tensors) + EXTRA_ARGS+=("--save-tensors") + shift + ;; + --tensor-save-dir) + EXTRA_ARGS+=("--tensor-save-dir" "$2") + shift 2 + ;; + # collect_micro_batches参数已移除 + *) + EXTRA_ARGS+=("$1") + shift + ;; + esac +done + + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# Distributed training setup +GPUS_PER_NODE=8 +NUM_NODES=1 +MASTER_ADDR=${MASTER_ADDR:-localhost} +MASTER_PORT=${MASTER_PORT:-6000} +NODE_RANK=${NODE_RANK:-0} +WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES)) + +# Path to the pretrain_gpt.py script, assuming this script is run from the root of the Megatron-LM repository +PRETRAIN_SCRIPT_PATH="pretrain_gpt.py" + +# Fixed model and training parameters +TP_SIZE=4 +CP_SIZE=1 +PP_SIZE=1 +MICRO_BATCH_SIZE=1 # default 1 +GLOBAL_BATCH_SIZE=128 # default 128 +NUM_LAYERS=16 +# DTYPE="bf16" +DTYPE=${5:-"fp8"} +SEQ_LENGTH=8192 +MAX_POSITION_EMBEDDINGS=8192 + +# Data cache path (useful for both mock and real data) +DATA_CACHE_PATH="${PWD}/benchmark_cache_llama3_8b_fp8" +mkdir -p "$DATA_CACHE_PATH" + +DISTRIBUTED_ARGS=( + --nproc_per_node $GPUS_PER_NODE + --nnodes $NUM_NODES + --node_rank $NODE_RANK + --master_addr $MASTER_ADDR + --master_port $MASTER_PORT +) + +MODEL_ARGS=( + --use-mcore-models + --num-layers $NUM_LAYERS + --hidden-size 2048 + --ffn-hidden-size 8192 + --num-attention-heads 32 + --group-query-attention + --num-query-groups 8 + --kv-channels 128 + --seq-length $SEQ_LENGTH + --max-position-embeddings $MAX_POSITION_EMBEDDINGS + --position-embedding-type rope + --rotary-base 500000 + --rotary-percent 1.0 + --attention-dropout 0.0 + --hidden-dropout 0.0 + --swiglu + --init-method-std 0.0134 + --attention-backend fused + --apply-layernorm-1p + --untie-embeddings-and-output-weights + --disable-bias-linear + --transformer-impl local +) + +TRAINING_ARGS=( + --micro-batch-size $MICRO_BATCH_SIZE + --global-batch-size $GLOBAL_BATCH_SIZE + --train-iters 369844 # 47340000 / 128 (global_batch_size) = 369844 iterations + --lr-decay-iters 369103 # 47245280 / 128 = 369103 iterations + --lr-warmup-iters 740 # 94720 / 128 = 740 iterations + --use-checkpoint-opt_param-scheduler # Use optimizer parameters from checkpoint + --lr 0.00015 + --min-lr 0.00001 + --decoupled-lr 5.0e-4 # Specific to decoupled AdamW, ensure optimizer is compatible + --decoupled-min-lr 4.5e-5 # Specific to decoupled AdamW + --lr-decay-style cosine + --clip-grad 1.0 + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.95 + --bf16 + --grad-reduce-in-bf16 + --cross-entropy-loss-fusion + --calculate-per-token-loss + --manual-gc + --empty-unused-memory-level 1 + --exit-duration-in-mins 235000000 # default 235 +) + +# Conditional arguments based on DTYPE (FP8) +DTYPE_ARGS=() +if [[ "$DTYPE" == "fp8" ]]; then + DTYPE_ARGS+=( + "--fp8-format hybrid" + "--fp8-amax-history-len 1024" + "--fp8-amax-compute-algo max" + "--fp8-param-gather" + ) +fi + +# Model parallelism arguments +MODEL_PARALLEL_ARGS=( + --tensor-model-parallel-size $TP_SIZE + --context-parallel-size $CP_SIZE + --pipeline-model-parallel-size $PP_SIZE # Not explicitly set in llama script options, assume 1 if not multi-node PP + --sequence-parallel # Always enable sequence parallelism with TP_SIZE=2 +) + +# Distributed Data Parallel (DDP) arguments +# From original script's ddp_args +DDP_ARGS=( + --use-distributed-optimizer + --overlap-grad-reduce + --overlap-param-gather +) +TRAINING_ARGS+=("${DDP_ARGS[@]}") + + +# Data arguments (conditional for mock vs real data) +DATA_ARGS_LIST=() +if [[ "$TOKENIZER_ARG" == "MOCK" ]] || [[ "$DATA_ARG" == "MOCK" ]] || [[ -z "$TOKENIZER_ARG" ]]; then + DATA_ARGS_LIST+=( + "--mock-data" + "--tokenizer-type NullTokenizer" + "--vocab-size 128256" + "--data-cache-path ${DATA_CACHE_PATH}" + "--tiktoken-pattern v2" + "--split '99,1,0'" + "--no-create-attention-mask-in-dataloader" + "--no-mmap-bin-files" + "--num-workers 1" + ) +else + # Settings for real data + DATA_ARGS_LIST+=( + "--data-path $DATA_ARG" + "--tokenizer-type HuggingFaceTokenizer" + "--tokenizer-model $TOKENIZER_ARG" + "--data-cache-path ${DATA_CACHE_PATH}" + "--split '99,1,0'" + "--no-create-attention-mask-in-dataloader" + # "--no-mmap-bin-files" + "--num-workers 1" + # Note: --vocab-size might be inferred by HuggingFaceTokenizer or might need to be explicit. + "--vocab-size 128256" + ) +fi + +EVAL_AND_LOGGING_ARGS=( + --log-interval 1 + --eval-iters 32 + --eval-interval 100 + --save-interval 1000 + --log-throughput + # --profile + # --profile-step-start 4 + # --profile-step-end 6 + --ckpt-format torch_dist + --distributed-timeout-minutes 120 + --save "$CHECKPOINT_PATH" + --tensorboard-dir "$TENSORBOARD_LOGS_PATH" +) + +# Only load checkpoint if it exists +if [ -d "$CHECKPOINT_PATH" ] || [ -f "${CHECKPOINT_PATH}_iter_*.pt" ] 2>/dev/null; then + EVAL_AND_LOGGING_ARGS+=(--load "$CHECKPOINT_PATH") + echo "Loading existing checkpoint from: $CHECKPOINT_PATH" +else + echo "Starting fresh training (no checkpoint found at: $CHECKPOINT_PATH)" +fi + +# Ensure pretrain_gpt.py is found +if [ ! -f "$PRETRAIN_SCRIPT_PATH" ]; then + echo "Error: pretrain_gpt.py not found at $PRETRAIN_SCRIPT_PATH" + echo "Please ensure you are running this script from the root of the Megatron-LM repository, and pretrain_gpt.py is present." + exit 1 +fi + +# Run the training command +torchrun ${DISTRIBUTED_ARGS[@]} \ + "$PRETRAIN_SCRIPT_PATH" \ + ${MODEL_ARGS[@]} \ + ${TRAINING_ARGS[@]} \ + ${DTYPE_ARGS[@]} \ + ${MODEL_PARALLEL_ARGS[@]} \ + ${DATA_ARGS_LIST[@]} \ + ${EVAL_AND_LOGGING_ARGS[@]} \ + ${EXTRA_ARGS[@]} + +set +x diff --git a/examples/llama/train_llama3_8b_h100_fp8.sh b/examples/llama/train_llama3_8b_h100_fp8.sh index f791996308..da9042b298 100644 --- a/examples/llama/train_llama3_8b_h100_fp8.sh +++ b/examples/llama/train_llama3_8b_h100_fp8.sh @@ -11,8 +11,12 @@ export CUDA_DEVICE_MAX_CONNECTIONS=${CUDA_DEVICE_MAX_CONNECTIONS:-1} CHECKPOINT_PATH=${1:-"checkpoints/llama3_8b_fp8"} TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama3_8b_fp8"} -TOKENIZER_ARG=${3:-"MOCK"} # Path to tokenizer model, or "MOCK" -DATA_ARG=${4:-"MOCK"} # Data prefix, or "MOCK" +# TOKENIZER_ARG=${3:-"MOCK"} # Path to tokenizer model, or "MOCK" +TOKENIZER_ARG=${3:-"model/llama3"} # Path to tokenizer model, or "MOCK" +# DATA_ARG=${4:-"MOCK"} # Data prefix, or "MOCK" +# DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} # Data prefix, or "MOCK" +DATA_ARG=${4:-"dataset/wikitext_processed/wikitext_processed_text_document"} # Data prefix, or "MOCK" + # Create directories if they don't exist mkdir -p "$(dirname "$CHECKPOINT_PATH")" @@ -30,13 +34,14 @@ WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES)) PRETRAIN_SCRIPT_PATH="pretrain_gpt.py" # Fixed model and training parameters -TP_SIZE=1 +TP_SIZE=2 CP_SIZE=1 -PP_SIZE=1 -MICRO_BATCH_SIZE=1 -GLOBAL_BATCH_SIZE=128 +PP_SIZE=4 +MICRO_BATCH_SIZE=1 # default 1 +GLOBAL_BATCH_SIZE=128 # default 128 NUM_LAYERS=32 -DTYPE="fp8" +# DTYPE="bf16" +DTYPE=${5:-"fp8"} SEQ_LENGTH=8192 MAX_POSITION_EMBEDDINGS=8192 @@ -79,9 +84,9 @@ MODEL_ARGS=( TRAINING_ARGS=( --micro-batch-size $MICRO_BATCH_SIZE --global-batch-size $GLOBAL_BATCH_SIZE - --train-samples 1953125000 - --lr-decay-samples 1949218748 - --lr-warmup-samples 3906252 + --train-samples 47340000 + --lr-decay-samples 47245280 + --lr-warmup-samples 94720 --lr 0.00015 --min-lr 0.00001 --decoupled-lr 5.0e-4 # Specific to decoupled AdamW, ensure optimizer is compatible @@ -97,7 +102,7 @@ TRAINING_ARGS=( --calculate-per-token-loss --manual-gc --empty-unused-memory-level 1 - --exit-duration-in-mins 235 + --exit-duration-in-mins 235000000 # default 235 ) # Conditional arguments based on DTYPE (FP8) @@ -115,7 +120,7 @@ fi MODEL_PARALLEL_ARGS=( --tensor-model-parallel-size $TP_SIZE --context-parallel-size $CP_SIZE - # --pipeline-model-parallel-size $PP_SIZE # Not explicitly set in llama script options, assume 1 if not multi-node PP + --pipeline-model-parallel-size $PP_SIZE # Not explicitly set in llama script options, assume 1 if not multi-node PP --sequence-parallel # Always enable sequence parallelism with TP_SIZE=2 ) @@ -152,7 +157,7 @@ else "--data-cache-path ${DATA_CACHE_PATH}" "--split '99,1,0'" "--no-create-attention-mask-in-dataloader" - "--no-mmap-bin-files" + # "--no-mmap-bin-files" "--num-workers 1" # Note: --vocab-size might be inferred by HuggingFaceTokenizer or might need to be explicit. "--vocab-size 128256" @@ -192,4 +197,4 @@ torchrun ${DISTRIBUTED_ARGS[@]} \ ${DATA_ARGS_LIST[@]} \ ${EVAL_AND_LOGGING_ARGS[@]} -set +x \ No newline at end of file +set +x diff --git a/megatron/core/adaptive_quantization.py b/megatron/core/adaptive_quantization.py new file mode 100644 index 0000000000..eb2162eefd --- /dev/null +++ b/megatron/core/adaptive_quantization.py @@ -0,0 +1,314 @@ +""" +Adaptive Quantization Training Manager + +This module implements time-resume adaptive quantization training that dynamically +switches between quantized (fp8/fp4) and high-precision (bf16) training based on loss. +""" + +import os +import threading +import time +from collections import deque +from typing import Dict, List, Optional, Tuple, Any +import torch +import torch.distributed as dist + +from megatron.core import parallel_state +from megatron.training.checkpointing import save_checkpoint, load_checkpoint + + +class AdaptiveQuantizationManager: + """ + Manages adaptive quantization training with time-resume capability. + + Features: + - Dynamic switching between quantized and high-precision training + - Asynchronous checkpoint saving + - Loss-based threshold triggering + - Window-based training management + """ + + def __init__(self, args, model, optimizer, opt_param_scheduler, iteration, num_floating_point_operations_so_far): + self.args = args + self.model = model + self.optimizer = optimizer + self.opt_param_scheduler = opt_param_scheduler + self.iteration = iteration + self.num_floating_point_operations_so_far = num_floating_point_operations_so_far + + # Time-resume parameters + self.enabled = getattr(args, 'time_resume', False) + if not self.enabled: + return + + self.loss_threshold = getattr(args, 'quant_loss_threshold', 0.1) + self.window_size = getattr(args, 'quant_window_size', 5) + self.checkpoint_interval = getattr(args, 'quant_checkpoint_interval', 1) + self.fallback_strategy = getattr(args, 'quant_fallback_strategy', 'bf16') + self.recovery_buffer_size = getattr(args, 'quant_recovery_buffer', 2) + + # State management + self.current_precision = 'quantized' # 'quantized' or 'bf16' + self.window_iterations = 0 + self.window_start_iteration = iteration + self.last_checkpoint_iteration = iteration + self.checkpoint_queue = deque(maxlen=self.recovery_buffer_size) + self.loss_history = deque(maxlen=10) # Keep last 10 losses + + # Asynchronous checkpoint saving + self.checkpoint_thread = None + self.checkpoint_lock = threading.Lock() + self.pending_checkpoints = [] + + # Quantization type management + self.quant_types = ['mxfp4', 'mxfp8', 'hifp8'] + self.current_quant_type = 'mxfp4' + + if parallel_state.get_tensor_model_parallel_rank() == 0: + print(f"[AdaptiveQuantization] Initialized with window_size={self.window_size}, " + f"threshold={self.loss_threshold}, fallback={self.fallback_strategy}") + + def should_save_checkpoint(self, iteration: int) -> bool: + """Check if we should save a checkpoint at this iteration.""" + if not self.enabled: + return False + + return (iteration - self.last_checkpoint_iteration) >= self.checkpoint_interval + + def should_switch_precision(self, current_loss: float) -> Tuple[bool, str]: + """ + Determine if we should switch training precision based on loss. + + Returns: + (should_switch, new_precision) + """ + if not self.enabled: + return False, self.current_precision + + self.loss_history.append(current_loss) + + if len(self.loss_history) < 3: + return False, self.current_precision + + # Calculate recent loss trend + recent_losses = list(self.loss_history)[-3:] + avg_recent_loss = sum(recent_losses) / len(recent_losses) + + # Switch to BF16 if loss exceeds threshold + if self.current_precision == 'quantized' and avg_recent_loss > self.loss_threshold: + if parallel_state.get_tensor_model_parallel_rank() == 0: + print(f"[AdaptiveQuantization] Loss {avg_recent_loss:.4f} exceeds threshold {self.loss_threshold:.4f}, " + f"switching to {self.fallback_strategy}") + return True, self.fallback_strategy + + # Switch back to quantized if loss is stable and low + elif self.current_precision == self.fallback_strategy and avg_recent_loss < self.loss_threshold * 0.8: + if parallel_state.get_tensor_model_parallel_rank() == 0: + print(f"[AdaptiveQuantization] Loss {avg_recent_loss:.4f} is stable, " + f"switching back to {self.current_quant_type}") + return True, 'quantized' + + return False, self.current_precision + + def save_checkpoint_async(self, iteration: int, tag: str = None): + """Save checkpoint asynchronously to avoid blocking training.""" + if not self.enabled: + return + + if self.checkpoint_thread and self.checkpoint_thread.is_alive(): + # Wait for previous checkpoint to complete + self.checkpoint_thread.join() + + checkpoint_info = { + 'iteration': iteration, + 'tag': tag or f"window_{iteration // self.window_size}", + 'precision': self.current_precision, + 'quant_type': self.current_quant_type, + 'timestamp': time.time() + } + + self.checkpoint_thread = threading.Thread( + target=self._save_checkpoint_worker, + args=(checkpoint_info,) + ) + self.checkpoint_thread.start() + + self.last_checkpoint_iteration = iteration + self.checkpoint_queue.append(checkpoint_info) + + def _save_checkpoint_worker(self, checkpoint_info: Dict[str, Any]): + """Worker function for asynchronous checkpoint saving.""" + try: + with self.checkpoint_lock: + # Save checkpoint with timestamp + checkpoint_name = f"{checkpoint_info['tag']}_iter{checkpoint_info['iteration']}_{checkpoint_info['precision']}" + + # Temporarily modify args.save to include checkpoint name + original_save = getattr(self.args, 'save', None) + + # Create a unique checkpoint directory for this save + if original_save: + checkpoint_dir = f"{original_save}_{checkpoint_info['tag']}_iter{checkpoint_info['iteration']}_{checkpoint_info['precision']}" + self.args.save = checkpoint_dir + + save_checkpoint( + iteration=checkpoint_info['iteration'], + model=self.model, + optimizer=self.optimizer, + opt_param_scheduler=self.opt_param_scheduler, + num_floating_point_operations_so_far=self.num_floating_point_operations_so_far + ) + + # Restore original save path + self.args.save = original_save + + if parallel_state.get_tensor_model_parallel_rank() == 0: + print(f"[AdaptiveQuantization] Checkpoint saved: {checkpoint_name}") + + except Exception as e: + if parallel_state.get_tensor_model_parallel_rank() == 0: + print(f"[AdaptiveQuantization] Error saving checkpoint: {e}") + + def load_recovery_checkpoint(self, target_iteration: int = None) -> bool: + """ + Load the most recent checkpoint for recovery. + + Args: + target_iteration: Specific iteration to load, or None for most recent + + Returns: + True if checkpoint was loaded successfully + """ + if not self.enabled or not self.checkpoint_queue: + return False + + # Find the best checkpoint to load + if target_iteration is not None: + # Find checkpoint closest to target iteration + best_checkpoint = None + min_diff = float('inf') + for checkpoint in self.checkpoint_queue: + diff = abs(checkpoint['iteration'] - target_iteration) + if diff < min_diff: + min_diff = diff + best_checkpoint = checkpoint + else: + # Load most recent checkpoint + best_checkpoint = self.checkpoint_queue[-1] + + if best_checkpoint is None: + return False + + try: + checkpoint_name = f"{best_checkpoint['tag']}_iter{best_checkpoint['iteration']}_{best_checkpoint['precision']}" + + # Temporarily modify args.load to point to the checkpoint directory + original_load = getattr(self.args, 'load', None) + + # Set the load path to the specific checkpoint directory + if hasattr(self.args, 'save') and self.args.save: + # Construct the checkpoint directory path + checkpoint_dir = f"{self.args.save}_{checkpoint_name}" + self.args.load = checkpoint_dir + + # Load checkpoint + iteration, num_floating_point_operations_so_far = load_checkpoint( + model=self.model, + optimizer=self.optimizer, + opt_param_scheduler=self.opt_param_scheduler, + load_arg='load' + ) + + # Restore original load path + self.args.load = original_load + + # Update state + self.iteration = iteration + self.num_floating_point_operations_so_far = num_floating_point_operations_so_far + self.current_precision = best_checkpoint['precision'] + self.current_quant_type = best_checkpoint.get('quant_type', 'mxfp4') + + if parallel_state.get_tensor_model_parallel_rank() == 0: + print(f"[AdaptiveQuantization] Loaded checkpoint: {checkpoint_name}, " + f"iteration={iteration}, precision={self.current_precision}") + + return True + + except Exception as e: + if parallel_state.get_tensor_model_parallel_rank() == 0: + print(f"[AdaptiveQuantization] Error loading checkpoint: {e}") + return False + + def update_window_state(self, iteration: int): + """Update window state and handle window transitions.""" + if not self.enabled: + return + + self.window_iterations += 1 + + # Check if we've completed a window + if self.window_iterations >= self.window_size: + # Save window checkpoint + self.save_checkpoint_async(iteration, f"window_end_{iteration // self.window_size}") + + # Reset window state + self.window_iterations = 0 + self.window_start_iteration = iteration + + if parallel_state.get_tensor_model_parallel_rank() == 0: + print(f"[AdaptiveQuantization] Completed window, saved checkpoint at iteration {iteration}") + + def get_current_quantization_type(self) -> str: + """Get the current quantization type for training.""" + if not self.enabled: + return 'hifp8' # Default + + if self.current_precision == 'quantized': + return self.current_quant_type + else: + return 'bf16' + + def set_quantization_type(self, quant_type: str): + """Set the quantization type for quantized training.""" + if not self.enabled: + return + + if quant_type in self.quant_types + ['bf16']: + self.current_quant_type = quant_type + if parallel_state.get_tensor_model_parallel_rank() == 0: + print(f"[AdaptiveQuantization] Set quantization type to {quant_type}") + + def get_training_state(self) -> Dict[str, Any]: + """Get current training state for logging.""" + if not self.enabled: + return {} + + return { + 'precision': self.current_precision, + 'quant_type': self.current_quant_type, + 'window_iterations': self.window_iterations, + 'window_start': self.window_start_iteration, + 'recent_losses': list(self.loss_history)[-5:] if self.loss_history else [], + 'checkpoints_available': len(self.checkpoint_queue) + } + + def finalize(self): + """Clean up resources and save final checkpoint.""" + if not self.enabled: + return + + # Wait for any pending checkpoint saves + if self.checkpoint_thread and self.checkpoint_thread.is_alive(): + self.checkpoint_thread.join() + + if parallel_state.get_tensor_model_parallel_rank() == 0: + print("[AdaptiveQuantization] Finalized adaptive quantization training") + + +def get_adaptive_quantization_manager(args, model, optimizer, opt_param_scheduler, + iteration, num_floating_point_operations_so_far): + """Factory function to create adaptive quantization manager.""" + return AdaptiveQuantizationManager( + args, model, optimizer, opt_param_scheduler, + iteration, num_floating_point_operations_so_far + ) diff --git a/megatron/core/datasets/.blended_megatron_dataset_builder.py.swp b/megatron/core/datasets/.blended_megatron_dataset_builder.py.swp new file mode 100644 index 0000000000..3e5e27bcd6 Binary files /dev/null and b/megatron/core/datasets/.blended_megatron_dataset_builder.py.swp differ diff --git a/megatron/core/datasets/blended_megatron_dataset_builder.py b/megatron/core/datasets/blended_megatron_dataset_builder.py index ba47428654..3f1b9e89ef 100644 --- a/megatron/core/datasets/blended_megatron_dataset_builder.py +++ b/megatron/core/datasets/blended_megatron_dataset_builder.py @@ -128,7 +128,7 @@ def build(self) -> List[Optional[TopLevelDataset]]: split """ datasets = self._build_blended_dataset_splits() - + # import pdb;pdb.set_trace() for dataset in datasets: if dataset is not None and len(dataset) > 0: if isinstance(dataset, BlendedDataset): diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py index 7ea63df805..c4c3863ad5 100644 --- a/megatron/core/datasets/gpt_dataset.py +++ b/megatron/core/datasets/gpt_dataset.py @@ -90,6 +90,7 @@ def __init__( index_split: Split, config: GPTDatasetConfig, ) -> None: + # import pdb;pdb.set_trace() super().__init__( indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config ) @@ -326,12 +327,15 @@ def _build_document_sample_shuffle_indices( Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]: The document index, the sample index, and the shuffle index """ + # print_rank_0("in line 330") path_to_cache = self.config.path_to_cache if path_to_cache is None and not self.config.mock: path_to_cache = os.path.join( self.dataset.path_prefix, "cache", f"{type(self).__name__}_indices" ) - + + # import pdb;pdb.set_trace() + # log_single_rank(logger,logging.INFO,"get line 338 in gpt_dataset.py") if path_to_cache: base = f"{self.unique_description_hash}-{type(self).__name__}-{self.index_split.name}" get_path_to = lambda affix: os.path.join(path_to_cache, f"{base}-{affix}") @@ -352,6 +356,8 @@ def _build_document_sample_shuffle_indices( ) else: cache_hit = False + # import pdb;pdb.set_trace() + # log_single_rank(logger,logging.INFO,"get line 358 in gpt_dataset.py") if not path_to_cache or ( not cache_hit @@ -410,11 +416,12 @@ def _build_document_sample_shuffle_indices( numpy_random_state = numpy.random.RandomState(self.config.random_seed) # Build the document index + # import pdb;pdb.set_trace() document_index = _build_document_index( self.indices, num_epochs, numpy_random_state, separate_final_epoch ) - # Build the sample index + from megatron.core.datasets import helpers if self.index_split == Split.valid: @@ -435,6 +442,10 @@ def _build_document_sample_shuffle_indices( sequence_lengths_for_cpp = self.dataset.sequence_lengths.copy() else: sequence_lengths_for_cpp = self.dataset.sequence_lengths + + # import pdb;pdb.set_trace() + # log_single_rank(logger,logging.INFO,"get line 450 in gpt_dataset.py") + sample_index = helpers.build_sample_idx( sequence_lengths_for_cpp, document_index, @@ -445,6 +456,9 @@ def _build_document_sample_shuffle_indices( self.config.add_extra_token_to_sequence, ) + # import pdb;pdb.set_trace() + # log_single_rank(logger,logging.INFO,"get line 463 in gpt_dataset.py") + # Build the shuffle index if separate_final_epoch: shuffle_index = _build_shuffle_index( @@ -454,7 +468,7 @@ def _build_document_sample_shuffle_indices( shuffle_index = _build_shuffle_index( sample_index.shape[0] - 1, sample_index.shape[0] - 1, numpy_random_state ) - + # import pdb;pdb.set_trace() if path_to_cache: os.makedirs(path_to_cache, exist_ok=True) # Write the description @@ -526,6 +540,7 @@ def _get_num_tokens_per_epoch(self) -> int: Returns: int: The number of tokens in a single epoch """ + # import pdb;pdb.set_trace() return int(numpy.sum(self.dataset.sequence_lengths[self.indices])) def _get_num_epochs(self, num_tokens_per_epoch: int) -> int: @@ -572,6 +587,7 @@ def _build_document_index( numpy.ndarray: The document index """ + # import pdb;pdb.set_trace() if not separate_final_epoch or num_epochs == 1: document_index = numpy.mgrid[0:num_epochs, 0 : len(documents)][1] document_index[:] = documents diff --git a/megatron/core/datasets/indexed_dataset.py b/megatron/core/datasets/indexed_dataset.py index 95e6016fa5..a818b3a193 100644 --- a/megatron/core/datasets/indexed_dataset.py +++ b/megatron/core/datasets/indexed_dataset.py @@ -622,6 +622,7 @@ def initialize( object_storage_config (Optional[ObjectStorageConfig]): See IndexedDataset docstring for details. """ + # import pdb;pdb.set_trace() idx_path = get_idx_path(path_prefix) bin_path = get_bin_path(path_prefix) if object_storage_config is None: diff --git a/megatron/core/extensions/.transformer_engine.py.swo b/megatron/core/extensions/.transformer_engine.py.swo new file mode 100644 index 0000000000..fff2a0137f Binary files /dev/null and b/megatron/core/extensions/.transformer_engine.py.swo differ diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 88a6c131c5..cdffdf1257 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -419,6 +419,7 @@ def forward(self, x): _is_first_microbatch = ( None if self.disable_parameter_transpose_cache else self.is_first_microbatch ) + # import pdb; pdb.set_trace() out = super().forward(x, is_first_microbatch=_is_first_microbatch) self.is_first_microbatch = False diff --git a/megatron/core/fp8_utils.py b/megatron/core/fp8_utils.py index d417ccdb4a..1aa3ef3bee 100644 --- a/megatron/core/fp8_utils.py +++ b/megatron/core/fp8_utils.py @@ -511,8 +511,8 @@ def get_fp8_context(config: TransformerConfig, layer_no: int = -1, is_init: bool fp8_group = parallel_state.get_amax_reduction_group( with_context_parallel=True, tp_only_amax_red=config.tp_only_amax_red ) - if not is_init: + # import pdb;pdb.set_trace() fp8_context = transformer_engine.pytorch.fp8_autocast( enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group ) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 6aec66e6dc..ccf98279a8 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -209,7 +209,7 @@ def __init__( else: self.embedding_activation_buffer = None self.grad_output_buffer = None - + # import pdb;pdb.set_trace() self.output_layer = tensor_parallel.ColumnParallelLinear( config.hidden_size, self.vocab_size, @@ -223,6 +223,7 @@ def __init__( embedding_activation_buffer=self.embedding_activation_buffer, grad_output_buffer=self.grad_output_buffer, tp_group=self.model_comm_pgs.tp, + mxfp_quant = False ) if self.pre_process or self.post_process: diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 0c4fbf9b92..b57537685b 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -389,7 +389,7 @@ def forward_step( set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor") set_input_tensor(input_tensor) - + if config.enable_autocast: context_manager = torch.autocast("cuda", dtype=config.autocast_dtype) else: @@ -610,10 +610,30 @@ def forward_backward_no_pipelining( current_microbatch=i, ) total_num_tokens += num_tokens + + # Check if should exit after completing a full forward pass + try: + from megatron.core.tensor_saver import get_tensor_saver + tensor_saver = get_tensor_saver() + if tensor_saver.enabled and not tensor_saver.tensor_collected_in_warmup: + # In no-pipelining mode, collect tensor in first microbatch + tensor_saver.mark_warmup_collection() + except Exception as e: + pass # Silently ignore tensor saver errors + if not forward_only: backward_step( input_tensor, output_tensor, output_tensor_grad, model_type, config ) + + # Check if should exit after backward completion + try: + from megatron.core.tensor_saver import get_tensor_saver + tensor_saver = get_tensor_saver() + if tensor_saver.should_exit_after_backward(): + break + except Exception as e: + pass # Silently ignore tensor saver errors # Run computation for last microbatch out of context handler (want to # synchronize gradients). output_tensor, num_tokens = forward_step( @@ -633,6 +653,16 @@ def forward_backward_no_pipelining( ) total_num_tokens += num_tokens + + # Check tensor collection for last microbatch + try: + from megatron.core.tensor_saver import get_tensor_saver + tensor_saver = get_tensor_saver() + if tensor_saver.enabled and not tensor_saver.tensor_collected_in_warmup: + # In no-pipelining mode, collect tensor in first microbatch + tensor_saver.mark_warmup_collection() + except Exception as e: + pass # Silently ignore tensor saver errors if not forward_only: backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) @@ -1225,6 +1255,17 @@ def forward_step_helper(virtual_microbatch_id, checkpoint_activations_microbatch ) forward_step_helper_postprocess(model_chunk_id, output_tensor, num_tokens) + + # 检查是否应该退出(在完成一个完整forward pass后) + try: + from megatron.core.tensor_saver import get_tensor_saver + tensor_saver = get_tensor_saver() + if tensor_saver.should_exit_after_forward(): + print(f"[Pipeline] 已完成tensor收集,退出interleaving训练循环") + # 返回None表示需要提前退出 + return None + except Exception as e: + print(f"[Pipeline] Warning: 无法检查tensor saver状态: {e}") return output_tensor @@ -1336,6 +1377,10 @@ def forward_backward_helper_wrapper( forward_output_tensor = forward_step_helper( f_virtual_microbatch_id, checkpoint_activations_microbatch ) + # 检查是否需要提前退出 + if forward_output_tensor is None: + print(f"[Pipeline] 检测到提前退出信号,停止interleaving训练循环") + return forward_output_tensor, backward_input_tensor_grad if post_forward is not None: forward_output_tensor = post_forward(forward_output_tensor) @@ -2145,6 +2190,18 @@ def enable_grad_sync(): ) p2p_communicator.send_forward(output_tensor, is_pp_last_stage(p2p_communicator.pp_group)) total_num_tokens += num_tokens + + # 检查是否应该退出(在完成一个完整forward pass后) + try: + from megatron.core.tensor_saver import get_tensor_saver + tensor_saver = get_tensor_saver() + if tensor_saver.enabled and not tensor_saver.tensor_collected_in_warmup: + # 在warmup阶段收集tensor + tensor_saver.mark_warmup_collection() + # 注意:这里不立即标记collection_completed,让tensor实际保存后再标记 + print(f"[Pipeline] 已在warmup阶段开始收集tensor,继续运行以完成收集") + except Exception as e: + print(f"[Pipeline] Warning: 无法检查tensor saver状态: {e}") if not forward_only: input_tensors.append(input_tensor) @@ -2189,6 +2246,17 @@ def enable_grad_sync(): is_last_stage=is_pp_last_stage(p2p_communicator.pp_group), ) total_num_tokens += num_tokens + # import pdb;pdb.set_trace() + + # 检查是否应该退出(tensor已在warmup阶段收集完成) + try: + from megatron.core.tensor_saver import get_tensor_saver + tensor_saver = get_tensor_saver() + if tensor_saver.should_exit_after_forward(): + print(f"[Pipeline] Tensor收集已完成,退出训练循环") + break + except Exception as e: + print(f"[Pipeline] Warning: 无法检查tensor saver状态: {e}") if forward_only: p2p_communicator.send_forward( diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py index 5646f50ca4..235c9031a9 100644 --- a/megatron/core/tensor_parallel/layers.py +++ b/megatron/core/tensor_parallel/layers.py @@ -476,8 +476,8 @@ def forward( total_input = all_gather_buffer else: total_input = input - output = torch.matmul(total_input, weight.t()) + if bias is not None: output = output + bias return output @@ -520,6 +520,7 @@ def backward(ctx, grad_output): total_input = all_gather_buffer else: total_input = input + # grad_input = grad_output.matmul(weight) grad_input = grad_output.matmul(weight) if ctx.sequence_parallel and wgrad_compute: @@ -727,12 +728,488 @@ def linear_with_grad_accumulation_and_async_allreduce( ) linear_with_grad_accumulation_and_async_allreduce.warned = True - return LinearWithGradAccumulationAndAsyncCommunication.apply(*args) + return CustomLinearWithGradAccumulationAndAsyncCommunication.apply(*args) linear_with_grad_accumulation_and_async_allreduce.warned = False +class CustomLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): + """See custom_linear_with_grad_accumulation_and_async_allreduce""" + + @staticmethod + @custom_fwd + def forward( + ctx, + input, + weight, + bias, + gradient_accumulation_fusion, + allreduce_dgrad, + sequence_parallel, + grad_output_buffer, + wgrad_deferral_limit, + tp_group, + ): + """Forward.""" + if gradient_accumulation_fusion and hasattr(weight, "main_grad"): + main_grad = weight.main_grad + else: + main_grad = None + ctx.save_for_backward(input, weight) + # We can't save main_grad in save_for_backward as this module would be + # reused across layers like MTP logits. So, to prevent in-place modification + # checks we save the tensor in ctx. + ctx.main_grad = main_grad + ctx.use_bias = bias is not None + ctx.gradient_accumulation_fusion = gradient_accumulation_fusion + ctx.allreduce_dgrad = allreduce_dgrad + ctx.sequence_parallel = sequence_parallel + ctx.wgrad_deferral_limit = wgrad_deferral_limit + ctx.grad_output_buffer = grad_output_buffer + ctx.tp_group = tp_group + + if sequence_parallel: + dim_size = list(input.size()) + dim_size[0] = dim_size[0] * tp_group.size() + + all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu") + dist_all_gather_func(all_gather_buffer, input, group=tp_group) + total_input = all_gather_buffer + else: + total_input = input + + # 使用集成了tensor保存功能的量化算子 + import os + import inspect + + # 获取量化类型(支持time-resume自适应量化) + custom_quant_type = 'hifp8' + if hasattr(ctx, '_adaptive_quantization_manager'): + custom_quant_type = ctx._adaptive_quantization_manager.get_current_quantization_type() + + # 获取scaling_control参数 + scaling_control = getattr(ctx, 'scaling_control', 'max') + + # 尝试从调用栈获取layer_idx + layer_idx = getattr(ctx, 'layer_idx', None) + if layer_idx is None: + # 尝试从调用栈中获取layer信息 + try: + frame = inspect.currentframe() + while frame: + frame = frame.f_back + if frame and 'self' in frame.f_locals: + self_obj = frame.f_locals['self'] + if hasattr(self_obj, 'layer_number'): + layer_idx = self_obj.layer_number + break + elif hasattr(self_obj, 'layer_idx'): + layer_idx = self_obj.layer_idx + break + except: + pass + + # 获取rank信息 + rank = None + try: + import torch.distributed as dist + if dist.is_initialized(): + rank = dist.get_rank() + except: + pass + + if rank is None: + rank = int(os.environ.get("LOCAL_RANK", 0)) + + # 保存layer_idx到ctx中,供backward阶段使用 + ctx.layer_idx = layer_idx + + # 准备tensor保存参数 + tensor_save_params = { + "layer_type": "linear", + "layer_idx": layer_idx, + "operation": "forward", + "phase": "pre", + "component": "linear", + "rank": rank, + "metadata": { + "sequence_parallel": sequence_parallel, + "use_bias": ctx.use_bias, + "tp_group_size": tp_group.size() if tp_group else None, + } + } + + # 使用集成了tensor保存功能的算子 + from quant.mxfp import mxfp_matmul + from quant.hifp import hifp_matmul + from quant.bf16_operators import bf16_matmul + + if custom_quant_type == 'mxfp4': + output = mxfp_matmul( + total_input, weight.t(), + elem_format='fp4_e2m1', + block_size=32, + scaling_control=scaling_control, + **tensor_save_params + ) + elif custom_quant_type == 'mxfp8': + output = mxfp_matmul( + total_input, weight.t(), + elem_format='fp8_e4m3', + block_size=32, + scaling_control=scaling_control, + **tensor_save_params + ) + elif custom_quant_type == 'hifp8': + output = hifp_matmul( + total_input, weight.t(), + **tensor_save_params + ) + elif custom_quant_type == 'bf16': + output = bf16_matmul( + total_input, weight.t(), + **tensor_save_params + ) + else: + # 对于其他类型,使用BF16算子 + output = bf16_matmul( + total_input, weight.t(), + **tensor_save_params + ) + if bias is not None: + output = output + bias + + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + """Backward.""" + input, weight = ctx.saved_tensors + main_grad = ctx.main_grad + use_bias = ctx.use_bias + grad_output_buffer = ctx.grad_output_buffer + wgrad_deferral_limit = ctx.wgrad_deferral_limit + handle = None + tp_group = ctx.tp_group + + if ctx.gradient_accumulation_fusion: + weight.main_grad = main_grad + + wgrad_compute = True + if grad_output_buffer is not None: + if wgrad_deferral_limit == 0 or len(grad_output_buffer) < wgrad_deferral_limit: + grad_output_buffer.append(grad_output) + wgrad_compute = False + + if wgrad_compute: + if ctx.sequence_parallel: + dim_size = list(input.size()) + dim_size[0] = dim_size[0] * tp_group.size() + + all_gather_buffer = get_global_memory_buffer().get_tensor( + dim_size, input.dtype, "mpu" + ) + handle = dist_all_gather_func( + all_gather_buffer, input, group=tp_group, async_op=True + ) + + # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the + # gather is scheduled before the input gradient computation + total_input = all_gather_buffer + else: + total_input = input + + # 保存backward输入tensor (pre-linear) + from megatron.core.tensor_saver import save_linear_tensors + import os + import inspect + custom_quant_type = 'hifp8' + + # 尝试从调用栈获取layer_idx(与forward阶段保持一致) + layer_idx = getattr(ctx, 'layer_idx', None) + if layer_idx is None: + try: + # 从调用栈中查找layer_idx + frame = inspect.currentframe() + while frame: + frame = frame.f_back + if frame: + local_vars = frame.f_locals + if 'self' in local_vars: + self_obj = local_vars['self'] + if hasattr(self_obj, 'layer_number'): + layer_idx = self_obj.layer_number + break + elif hasattr(self_obj, 'layer_idx'): + layer_idx = self_obj.layer_idx + break + except: + pass + + save_linear_tensors( + input_tensor=grad_output, + weight=weight, + quant_type=custom_quant_type, + operation="backward", + layer_idx=layer_idx, + phase="pre", + component="linear", + metadata={ + "sequence_parallel": ctx.sequence_parallel, + "wgrad_compute": wgrad_compute, + "tp_group_size": tp_group.size() if tp_group else None, + } + ) + + from quant.mxfp import mxfp_matmul + from quant.hifp import hifp_matmul + custom_quant_type = 'hifp8' + if custom_quant_type == 'mxfp4': + grad_input = mxfp_matmul(grad_output,weight,'fp4_e2m1').to(torch.bfloat16) + elif custom_quant_type == 'mxfp8': + grad_input = mxfp_matmul(grad_output,weight,'fp8_e5m2').to(torch.bfloat16) + elif custom_quant_type == 'hifp8': + grad_input = hifp_matmul(grad_output,weight).to(torch.bfloat16) + else: + grad_input = grad_output.matmul(weight) + + if ctx.sequence_parallel and wgrad_compute: + # pylint: disable=possibly-used-before-assignment + handle.wait() + + if wgrad_compute: + grad_output, total_input = prepare_input_tensors_for_wgrad_compute( + grad_output, total_input + ) + + if ctx.allreduce_dgrad: + # Asynchronous all-reduce + handle = torch.distributed.all_reduce(grad_input, group=tp_group, async_op=True) + # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the + # all-reduce is scheduled before the weight gradient computation + + if ctx.sequence_parallel: + assert not ctx.allreduce_dgrad + dim_size = list(input.size()) + sub_grad_input = torch.empty( + dim_size, dtype=input.dtype, device=torch.cuda.current_device(), requires_grad=False + ) + # reduce_scatter + handle = dist_reduce_scatter_func( + sub_grad_input, grad_input, group=tp_group, async_op=True + ) + # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the + # reduce scatter is scheduled before the weight gradient computation + + if ctx.gradient_accumulation_fusion: + if wgrad_compute: + if weight.main_grad.dtype == torch.float32: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32( + total_input, grad_output, weight.main_grad + ) + elif weight.main_grad.dtype in (torch.float16, torch.bfloat16): + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16( + total_input, grad_output, weight.main_grad + ) + else: + raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") + + if hasattr(weight, "grad_added_to_main_grad"): + # When overlap_grad_reduce is True, need to ensure that backward hooks + # are all run on the main backprop thread to prevent deadlocks. Setup + # dummy grad_weight tensor to prevent backward hooks from being run + # in a background thread. + if getattr(weight, "zero_out_wgrad", False): + grad_weight = torch.zeros( + weight.main_grad.shape, + dtype=input.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + else: + grad_weight = torch.empty( + weight.main_grad.shape, + dtype=input.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + weight.grad_added_to_main_grad = True + else: + grad_weight = None + else: + # 对于梯度计算,使用带tensor保存的算子 + import os + + # 获取rank信息 + rank = None + try: + import torch.distributed as dist + if dist.is_initialized(): + rank = dist.get_rank() + except: + pass + + if rank is None: + rank = int(os.environ.get("LOCAL_RANK", 0)) + + # 准备tensor保存参数 + tensor_save_params = { + "layer_type": "linear", + "layer_idx": getattr(ctx, 'layer_idx', None), + "operation": "backward", + "phase": "post", + "component": "linear", + "rank": rank, + "metadata": { + "sequence_parallel": ctx.sequence_parallel, + "wgrad_compute": wgrad_compute, + "tp_group_size": tp_group.size() if tp_group else None, + } + } + + # 使用BF16算子计算梯度并自动保存 + from quant.bf16_operators import bf16_matmul + grad_weight = bf16_matmul( + grad_output.t(), total_input, + **tensor_save_params + ) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.sequence_parallel: + handle.wait() + # Need to return None's as gradient has to flow for all the input arguments + # provided during forward + return (sub_grad_input, grad_weight, grad_bias, None, None, None, None, None, None) + + if ctx.allreduce_dgrad: + handle.wait() + + return grad_input, grad_weight, grad_bias, None, None, None, None, None, None + + +def custom_linear_with_grad_accumulation_and_async_allreduce( + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + gradient_accumulation_fusion: bool, + allreduce_dgrad: bool, + sequence_parallel: bool, + grad_output_buffer: Optional[List[torch.Tensor]] = None, + wgrad_deferral_limit: Optional[int] = 0, + async_grad_allreduce: Optional[bool] = None, + tp_group: Optional[torch.distributed.ProcessGroup] = None, +) -> torch.Tensor: + """Linear layer execution with asynchronous communication and + gradient accumulation fusion in backprop. + + This has the option to accumulate the result of backprop + calculation into an existing gradient buffer, preventing the need + to do an additional addition kernel after the gradient + calculation. + + Additionally, the tensor parallel all reduce of the input + gradients can be done asynchronously with the calculation of + the weight gradients. + + In the case of sequence parallelism, the reduce scatter of the + input gradients is done asynchronously with the calcluation of the + weight gradients. + + Use of this module requires that the environment variable + CUDA_DEVICE_MAX_CONNECTIONS=1. There are a few collective + operations, noted in the code, that should be scheduled before + compute kernels to overlap the communication with the computation, + which is necessary for a speedup but not for correctness so that + ordering isn't imposed by the scheduler. Setting + CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled + in the order they are called. + + Args: + input (torch.Tensor required): input like torch.nn.functional.linear + + weight (torch.Tensor required): weight like torch.nn.functional.linear + + bias (torch.Tensor optional): bias like torch.nn.functional.linear + + gradient_accumulation_fusion (bool required): Perform the gradient + accumulation fusion, requires the custom CUDA extension + fused_weight_gradient_mlp_cuda module. To use + gradient_accumulation_fusion you must install APEX with + --cpp_ext and --cuda_ext. For example: "pip install + --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" + " Note that the extension requires CUDA>=11. Otherwise, you + must turn off gradient accumulation fusion." + + allreduce_dgrad (bool required): Do the allreduce of input gradients. + The allreduce is done asynchronously with the computation of weight + gradients. If sequence_parallel is True, this must be + False, as no all reduce is performed. + + sequence_parallel (bool required): Indicates that sequence + parallelism is used and thus in the forward pass the input is + all gathered, and the backward pass the input gradients are + reduce scattered. + + tp_group (torch.distributed.ProcessGroup required): The process group to use for tensor + parallel operations. + + grad_output_buffer (List[torch.Tensor] optional): Buffer used to save + output gradients when embedding table wgrad compute is deferred. + Defaults to None. + + wgrad_deferral_limit (int optional): Limit on the number of + micro-batches for which embedding weight gradient GEMM should be + deferred. Disable by setting this to 0. Defaults to 0. + + async_grad_allreduce (bool optional): Will be removed with 0.11.0. + Please use allreduce_dgrad instead. + """ + + if async_grad_allreduce is not None: + warnings.warn( + "async_grad_allreduce is deprecated, not in use anymore and will" + " be fully removed with 0.11.0. Please use allreduce_dgrad instead." + ) + + tp_group = get_tensor_model_parallel_group_if_none(tp_group) + + args = [ + input, + weight, + bias, + gradient_accumulation_fusion, + allreduce_dgrad, + sequence_parallel, + grad_output_buffer, + wgrad_deferral_limit, + tp_group, + ] + + if not custom_linear_with_grad_accumulation_and_async_allreduce.warned: + if os.environ.get("CUDA_DEVICE_MAX_CONNECTIONS") != "1": + if sequence_parallel: + warnings.warn( + "When using sequence parallelism it is recommended to set the " + "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for " + "maximum speedup" + ) + custom_linear_with_grad_accumulation_and_async_allreduce.warned = True + + if allreduce_dgrad: + warnings.warn( + "When using async grad allreduce it is recommended to set the " + "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for " + "maximum speedup" + ) + custom_linear_with_grad_accumulation_and_async_allreduce.warned = True + + return CustomLinearWithGradAccumulationAndAsyncCommunication.apply(*args) + + +custom_linear_with_grad_accumulation_and_async_allreduce.warned = False + class ColumnParallelLinear(torch.nn.Module): """Linear layer with column parallelism. @@ -801,9 +1278,11 @@ def __init__( tp_comm_buffer_name: str = None, # Not used disable_grad_reduce: bool = False, tp_group: Optional[torch.distributed.ProcessGroup] = None, + mxfp_quant:bool = True, ): super(ColumnParallelLinear, self).__init__() + self.mxfp_quant = mxfp_quant # Keep input parameters self.input_size = input_size self.output_size = output_size @@ -923,6 +1402,11 @@ def __init__( "`allreduce_dgrad` and `sequence_parallel` cannot be enabled at the same time." ) + if self.mxfp_quant: + self._forward_impl = custom_linear_with_grad_accumulation_and_async_allreduce + else: + self._forward_impl = linear_with_grad_accumulation_and_async_allreduce + # Hook adding a default empty _extra_state for state dict self._register_load_state_dict_pre_hook( lambda state_dict, prefix, *args, **kwargs: state_dict.setdefault( @@ -993,6 +1477,14 @@ def forward( self.embedding_activation_buffer.append(input_parallel) # Matrix multiply. + if not weight.requires_grad: + self._forward_impl = linear_with_frozen_weight + else: + if self.mxfp_quant: + self._forward_impl = custom_linear_with_grad_accumulation_and_async_allreduce + else: + self._forward_impl = linear_with_grad_accumulation_and_async_allreduce + allreduce_dgrad = False if self.explicit_expert_comm else self.allreduce_dgrad if self.config._cpu_offloading_context is not None: @@ -1109,9 +1601,11 @@ def __init__( is_expert: bool = False, tp_comm_buffer_name: str = None, # Not used tp_group: Optional[torch.distributed.ProcessGroup] = None, + mxfp_quant:bool = True, ): super(RowParallelLinear, self).__init__() + self.mxfp_quant = mxfp_quant # Keep input parameters self.input_size = input_size self.output_size = output_size @@ -1202,6 +1696,11 @@ def __init__( else: self.register_parameter("bias", None) + if self.mxfp_quant: + self._forward_impl = custom_linear_with_grad_accumulation_and_async_allreduce + else: + self._forward_impl = linear_with_grad_accumulation_and_async_allreduce + # Hook adding a default empty _extra_state for state dict self._register_load_state_dict_pre_hook( lambda state_dict, prefix, *args, **kwargs: state_dict.setdefault( @@ -1233,6 +1732,14 @@ def forward(self, input_): assert not self.sequence_parallel input_parallel = scatter_to_tensor_model_parallel_region(input_, group=self.tp_group) # Matrix multiply. + if not self.weight.requires_grad: + self._forward_impl = linear_with_frozen_weight + else: + if self.mxfp_quant: + self._forward_impl = custom_linear_with_grad_accumulation_and_async_allreduce + else: + self._forward_impl = linear_with_grad_accumulation_and_async_allreduce + allreduce_dgrad = False if self.config._cpu_offloading_context is not None: diff --git a/megatron/core/tensor_saver.py b/megatron/core/tensor_saver.py new file mode 100644 index 0000000000..39c96c4ef8 --- /dev/null +++ b/megatron/core/tensor_saver.py @@ -0,0 +1,768 @@ +#!/usr/bin/env python3 +""" +Tensor saving utility module +For saving forward/backward input tensors of attention and linear layers +""" + +import os +import torch +import time +from typing import Optional, Dict, Any +from pathlib import Path + + +# Global state management +class TensorCollectionState: + """Tensor collection state manager""" + def __init__(self): + self.current_rank = None + self.current_sample_idx = None + self.current_iteration = 0 + self.batch_idx = 0 + self.sequence_idx = 0 + + def set_rank(self, rank: int): + """设置当前rank""" + self.current_rank = rank + + def set_sample_idx(self, sample_idx: int): + """设置当前sample索引""" + self.current_sample_idx = sample_idx + + def set_iteration(self, iteration: int): + """设置当前iteration""" + self.current_iteration = iteration + + def set_batch_idx(self, batch_idx: int): + """设置当前batch索引""" + self.batch_idx = batch_idx + + def set_sequence_idx(self, sequence_idx: int): + """设置当前sequence索引""" + self.sequence_idx = sequence_idx + + def get_rank(self) -> Optional[int]: + """获取当前rank""" + if self.current_rank is not None: + return self.current_rank + + # 尝试从分布式环境获取 + try: + import torch.distributed as dist + if dist.is_initialized(): + rank = dist.get_rank() + self.current_rank = rank + return rank + except: + pass + + # 尝试从环境变量获取 + rank_env = os.environ.get("LOCAL_RANK") + if rank_env is not None: + try: + rank = int(rank_env) + self.current_rank = rank + return rank + except ValueError: + pass + + return None + + def get_sample_idx(self) -> Optional[int]: + """获取当前sample索引""" + if self.current_sample_idx is not None: + return self.current_sample_idx + + # 尝试从环境变量获取 + sample_env = os.environ.get("CURRENT_SAMPLE_IDX") + if sample_env is not None: + try: + sample_idx = int(sample_env) + self.current_sample_idx = sample_idx + return sample_idx + except ValueError: + pass + + return None + + def get_iteration(self) -> int: + """获取当前iteration""" + return self.current_iteration + + def get_batch_idx(self) -> int: + """获取当前batch索引""" + return self.batch_idx + + def get_sequence_idx(self) -> int: + """获取当前sequence索引""" + return self.sequence_idx + +# 全局状态实例 +_global_tensor_state = TensorCollectionState() + +# 全局tensor索引管理器 +class TensorIndexManager: + """Tensor索引管理器,确保同一层的不同tensor使用相同的索引""" + def __init__(self): + self.layer_tensor_counters = {} # {layer_key: counter} + self.current_layer_key = None + self.current_tensor_group = None + + def get_layer_key(self, layer_type: str, layer_idx: Optional[int], operation: str) -> str: + """生成层标识键""" + if layer_idx is not None: + return f"{layer_type}_L{layer_idx}_{operation}" + else: + return f"{layer_type}_unknown_{operation}" + + def get_tensor_group_key(self, layer_type: str, layer_idx: Optional[int], operation: str) -> str: + """生成tensor组标识键""" + return self.get_layer_key(layer_type, layer_idx, operation) + + def get_tensor_index(self, layer_type: str, layer_idx: Optional[int], operation: str) -> int: + """获取tensor索引""" + layer_key = self.get_layer_key(layer_type, layer_idx, operation) + + if layer_key not in self.layer_tensor_counters: + self.layer_tensor_counters[layer_key] = 0 + + # 对于同一层的不同tensor,使用相同的索引 + return self.layer_tensor_counters[layer_key] + + def increment_layer_counter(self, layer_type: str, layer_idx: Optional[int], operation: str): + """增加层计数器(当该层的所有tensor都保存完毕后调用)""" + layer_key = self.get_layer_key(layer_type, layer_idx, operation) + if layer_key in self.layer_tensor_counters: + self.layer_tensor_counters[layer_key] += 1 + + def reset_layer_counter(self, layer_type: str, layer_idx: Optional[int], operation: str): + """重置层计数器""" + layer_key = self.get_layer_key(layer_type, layer_idx, operation) + if layer_key in self.layer_tensor_counters: + self.layer_tensor_counters[layer_key] = 0 + +# 全局tensor索引管理器实例 +_global_tensor_index_manager = TensorIndexManager() + +def get_tensor_index_manager() -> TensorIndexManager: + """获取全局tensor索引管理器""" + return _global_tensor_index_manager + +def get_tensor_collection_state() -> TensorCollectionState: + """获取全局tensor收集状态""" + return _global_tensor_state + +def set_global_rank(rank: int): + """设置全局rank""" + _global_tensor_state.set_rank(rank) + +def set_global_sample_idx(sample_idx: int): + """设置全局sample索引""" + _global_tensor_state.set_sample_idx(sample_idx) + +def set_global_iteration(iteration: int): + """设置全局iteration""" + _global_tensor_state.set_iteration(iteration) + +def set_global_batch_idx(batch_idx: int): + """设置全局batch索引""" + _global_tensor_state.set_batch_idx(batch_idx) + +def set_global_sequence_idx(sequence_idx: int): + """设置全局sequence索引""" + _global_tensor_state.set_sequence_idx(sequence_idx) + +def get_rank_from_tensor_device(tensor: torch.Tensor) -> Optional[int]: + """尝试从tensor设备信息推断rank""" + try: + if tensor.is_cuda: + device_id = tensor.device.index + if device_id is not None: + # 在某些情况下,device_id可能对应rank + return device_id + except: + pass + return None + +def get_current_rank() -> Optional[int]: + """获取当前的rank信息""" + state = get_tensor_collection_state() + + # 首先尝试从全局状态获取 + rank = state.get_rank() + + # 如果全局状态中没有,尝试直接从分布式环境获取 + if rank is None: + try: + import torch.distributed as dist + if dist.is_initialized(): + rank = dist.get_rank() + # Update global state + state.set_rank(rank) + except Exception as e: + pass # Silently ignore distributed errors + + # If still None, try to get from environment variables + if rank is None: + rank_env = os.environ.get("LOCAL_RANK") or os.environ.get("RANK") + if rank_env is not None: + try: + rank = int(rank_env) + state.set_rank(rank) + except ValueError: + pass + + return rank + +def initialize_tensor_collection(rank: Optional[int] = None, + sample_idx: Optional[int] = None, + iteration: int = 0, + batch_idx: int = 0, + sequence_idx: int = 0): + """初始化tensor收集状态""" + state = get_tensor_collection_state() + + if rank is not None: + state.set_rank(rank) + else: + # Try to auto-detect rank + auto_rank = state.get_rank() + if auto_rank is None: + state.set_rank(0) # 默认值 + + if sample_idx is not None: + state.set_sample_idx(sample_idx) + else: + # Try to auto-detect sample_idx + auto_sample_idx = state.get_sample_idx() + if auto_sample_idx is None: + state.set_sample_idx(0) # 默认值 + + state.set_iteration(iteration) + state.set_batch_idx(batch_idx) + state.set_sequence_idx(sequence_idx) + + + +class TensorSaver: + """Tensor保存器,用于保存量化前后的tensor数据""" + + def __init__(self, save_dir: str = "./enhanced_tensor_logs", enabled: bool = True): + """ + 初始化Tensor保存器 + + Args: + save_dir: 保存目录 + enabled: 是否启用保存功能 + """ + self.save_dir = Path(save_dir) + self.enabled = enabled + self.tensor_counter = 0 + self.current_iteration = 0 + self.micro_batch_count = 0 + self.control_micro_batches = 1 # 固定为1,进行一次完整forward后跳出 + self.collection_completed = False # 标记是否已完成收集 + self.tensor_collected_in_warmup = False # 标记是否已在warmup阶段收集过tensor + + if self.enabled: + self.save_dir.mkdir(parents=True, exist_ok=True) + + def set_iteration(self, iteration: int): + """设置当前iteration""" + self.current_iteration = iteration + self.micro_batch_count = 0 # 重置micro_batch计数 + # 同时更新全局状态 + set_global_iteration(iteration) + + def mark_collection_completed(self): + """标记tensor收集已完成""" + self.collection_completed = True + + def should_exit_after_forward(self) -> bool: + """检查是否应该在forward后退出""" + # 在forward后不立即退出,等待backward完成后再退出 + return False + + def should_exit_after_backward(self) -> bool: + """检查是否应该在backward后退出""" + # 只有在启用tensor保存且已完成收集时才退出 + return self.enabled and self.collection_completed + + def should_collect_tensor(self) -> bool: + """检查是否应该收集tensor""" + # 只有在启用tensor保存且未完成收集时才收集(包括forward和backward) + return self.enabled and not self.collection_completed + + def mark_warmup_collection(self): + """标记已在warmup阶段收集过tensor""" + self.tensor_collected_in_warmup = True + + def should_collect_in_steady_state(self) -> bool: + """检查是否应该在steady state阶段收集tensor""" + return self.enabled and not self.tensor_collected_in_warmup + + def _get_tensor_info(self, tensor: torch.Tensor) -> Dict[str, Any]: + """获取tensor的基本信息""" + if tensor.numel() == 0: + return { + "shape": list(tensor.shape), + "dtype": str(tensor.dtype), + "device": str(tensor.device), + "requires_grad": tensor.requires_grad, + "is_leaf": tensor.is_leaf, + "min": 0.0, + "max": 0.0, + "mean": 0.0, + "std": 0.0, + "overflow_info": { + "upper_overflow_count": 0, + "lower_overflow_count": 0, + "upper_overflow_ratio": 0.0, + "lower_overflow_ratio": 0.0, + "total_overflow_ratio": 0.0 + } + } + + # 计算基本统计信息 + tensor_flat = tensor.float().flatten() + min_val = float(tensor_flat.min().item()) + max_val = float(tensor_flat.max().item()) + mean_val = float(tensor_flat.mean().item()) + std_val = float(tensor_flat.std().item()) + + # 计算溢出信息 + overflow_info = self._calculate_overflow_info(tensor_flat) + + return { + "shape": list(tensor.shape), + "dtype": str(tensor.dtype), + "device": str(tensor.device), + "requires_grad": tensor.requires_grad, + "is_leaf": tensor.is_leaf, + "min": min_val, + "max": max_val, + "mean": mean_val, + "std": std_val, + "overflow_info": overflow_info + } + + def _calculate_overflow_info(self, tensor_flat: torch.Tensor) -> Dict[str, Any]: + """计算tensor的溢出信息""" + total_elements = tensor_flat.numel() + + # 定义不同数据类型的溢出阈值 + dtype_thresholds = { + 'torch.float16': {'max': 65504.0, 'min': -65504.0}, + 'torch.bfloat16': {'max': 3.3895313892515355e+38, 'min': -3.3895313892515355e+38}, + 'torch.float32': {'max': 3.4028235e+38, 'min': -3.4028235e+38}, + 'torch.float64': {'max': 1.7976931348623157e+308, 'min': -1.7976931348623157e+308}, + } + + # 获取当前tensor的阈值 + tensor_dtype = str(tensor_flat.dtype) + if tensor_dtype in dtype_thresholds: + max_threshold = dtype_thresholds[tensor_dtype]['max'] + min_threshold = dtype_thresholds[tensor_dtype]['min'] + else: + # 默认使用float32阈值 + max_threshold = dtype_thresholds['torch.float32']['max'] + min_threshold = dtype_thresholds['torch.float32']['min'] + + # 计算上溢出和下溢出 + upper_overflow_mask = tensor_flat > max_threshold + lower_overflow_mask = tensor_flat < min_threshold + + upper_overflow_count = int(upper_overflow_mask.sum().item()) + lower_overflow_count = int(lower_overflow_mask.sum().item()) + + upper_overflow_ratio = upper_overflow_count / total_elements if total_elements > 0 else 0.0 + lower_overflow_ratio = lower_overflow_count / total_elements if total_elements > 0 else 0.0 + total_overflow_ratio = (upper_overflow_count + lower_overflow_count) / total_elements if total_elements > 0 else 0.0 + + return { + "upper_overflow_count": upper_overflow_count, + "lower_overflow_count": lower_overflow_count, + "upper_overflow_ratio": upper_overflow_ratio, + "lower_overflow_ratio": lower_overflow_ratio, + "total_overflow_ratio": total_overflow_ratio, + "max_threshold": max_threshold, + "min_threshold": min_threshold + } + + def _generate_filename(self, + layer_type: str, + operation: str, + quant_type: str, + tensor_name: str, + layer_idx: Optional[int] = None, + phase: str = "unknown", + component: str = "unknown", + rank: Optional[int] = None, + tensor_group_idx: Optional[int] = None) -> str: + """生成文件名""" + timestamp = time.strftime("%Y%m%d_%H%M%S") + self.tensor_counter += 1 + + # 构建文件名组件 + parts = [ + timestamp, + f"{self.tensor_counter:04d}", + f"iter{self.current_iteration:03d}", + layer_type + ] + + if layer_idx is not None: + parts.append(f"L{layer_idx}") + + parts.extend([operation, phase, component, quant_type]) + + if rank is not None: + parts.append(f"rank{rank:02d}") + + # 添加tensor组索引(同一层的不同tensor使用相同索引) + if tensor_group_idx is not None: + parts.append(f"group{tensor_group_idx:03d}") + + parts.append(tensor_name) + + filename = "_".join(parts) + ".pt" + return filename + + def save_tensor(self, + tensor: torch.Tensor, + layer_type: str, + operation: str, # "forward" or "backward" + quant_type: str, + tensor_name: str, + layer_idx: Optional[int] = None, + phase: str = "unknown", # "pre" or "post" for forward/backward phases + component: str = "unknown", # "linear" or "FA" for component type + rank: Optional[int] = None, + metadata: Optional[Dict[str, Any]] = None) -> Optional[str]: + """ + 保存tensor到文件 + + Args: + tensor: 要保存的tensor + layer_type: 层类型 ("attention" or "linear") + operation: 操作类型 ("forward" or "backward") + quant_type: 量化类型 ("hifp8", "mxfp8", "mxfp4", "bf16", etc.) + tensor_name: tensor名称 ("input", "output", "grad_input", etc.) + layer_idx: 层索引 + phase: 阶段 ("pre" or "post" for forward/backward phases) + component: 组件类型 ("linear" or "FA" for component type) + rank: GPU rank信息 + metadata: 额外的元数据 + + Returns: + 保存的文件路径,如果未启用则返回None + """ + if not self.should_collect_tensor(): + return None + + # 当启用tensor保存时,会在一次forward后自动退出,无需额外检查 + + # 自动获取rank信息(如果未提供) + if rank is None: + rank = get_current_rank() + + + # 如果仍然无法获取rank,尝试从tensor设备信息推断 + if rank is None: + rank = get_rank_from_tensor_device(tensor) + if rank is not None: + # Update global state + state = get_tensor_collection_state() + state.set_rank(rank) + + # 如果仍然无法获取,使用默认值并打印警告 + if rank is None: + rank = 0 # 默认rank为0 + + # 如果rank不是0或1,则不保存 + if rank not in [0, 1]: + return None + # 获取tensor组索引(同一层的不同tensor使用相同索引) + index_manager = get_tensor_index_manager() + tensor_group_idx = index_manager.get_tensor_index(layer_type, layer_idx, operation) + + try: + # 生成文件名 + filename = self._generate_filename(layer_type, operation, quant_type, tensor_name, + layer_idx, phase, component, rank, tensor_group_idx) + filepath = self.save_dir / filename + + # iteration数据计数已简化,无需手动增加 + + # 准备保存数据 - 添加更安全的tensor处理 + try: + # 先获取tensor信息(在移动之前) + tensor_info = self._get_tensor_info(tensor) + + # 安全地处理tensor + if tensor.is_cuda: + tensor_cpu = tensor.detach().cpu() + else: + tensor_cpu = tensor.detach().clone() + + # 确保tensor是连续的 + if not tensor_cpu.is_contiguous(): + tensor_cpu = tensor_cpu.contiguous() + + save_data = { + "tensor": tensor_cpu, + "tensor_info": tensor_info, + "metadata": { + "layer_type": layer_type, + "operation": operation, + "quant_type": quant_type, + "tensor_name": tensor_name, + "layer_idx": layer_idx, + "phase": phase, + "component": component, + "rank": rank, + "iteration": self.current_iteration, + "save_time": time.strftime("%Y-%m-%d %H:%M:%S"), + **(metadata or {}) + } + } + except Exception as tensor_error: + print(f"[TensorSaver] 处理tensor时出错: {tensor_error}") + # 如果tensor处理失败,尝试更简单的方式 + save_data = { + "tensor": tensor.detach().cpu().contiguous(), + "tensor_info": {"shape": list(tensor.shape), "dtype": str(tensor.dtype)}, + "metadata": { + "layer_type": layer_type, + "operation": operation, + "quant_type": quant_type, + "tensor_name": tensor_name, + "layer_idx": layer_idx, + "phase": phase, + "component": component, + "rank": rank, + "iteration": self.current_iteration, + "save_time": time.strftime("%Y-%m-%d %H:%M:%S"), + **(metadata or {}) + } + } + + # 保存到文件 - 添加更安全的保存过程 + try: + # 确保目录存在 + filepath.parent.mkdir(parents=True, exist_ok=True) + + # 使用更安全的保存方式 + torch.save(save_data, filepath, _use_new_zipfile_serialization=False) + + # 验证文件是否保存成功 + if filepath.exists() and filepath.stat().st_size > 0: + print(f"[TensorSaver] 已保存: {filename}") + return str(filepath) + else: + print(f"[TensorSaver] 保存失败: 文件为空或不存在") + return None + + except Exception as save_error: + print(f"[TensorSaver] 保存文件时出错: {save_error}") + # 尝试使用pickle保存 + try: + import pickle + with open(filepath.with_suffix('.pkl'), 'wb') as f: + pickle.dump(save_data, f) + print(f"[TensorSaver] 使用pickle保存成功: {filename}") + return str(filepath.with_suffix('.pkl')) + except Exception as pickle_error: + print(f"[TensorSaver] pickle保存也失败: {pickle_error}") + return None + + except Exception as e: + print(f"[TensorSaver] 保存tensor失败: {e}") + return None + + def save_attention_tensors(self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + quant_type: str, + operation: str = "forward", + layer_idx: Optional[int] = None, + phase: str = "pre", + component: str = "FA", + rank: Optional[int] = None, + attention_weights: Optional[torch.Tensor] = None, + metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Optional[str]]: + """ + 保存attention层的输入tensor + + Args: + query: Query tensor + key: Key tensor + value: Value tensor + quant_type: 量化类型 + operation: 操作类型 + layer_idx: 层索引 + phase: 阶段 + component: 组件类型 + rank: GPU rank信息 + attention_weights: Attention权重矩阵(P分布) + metadata: 额外元数据 + + Returns: + 保存的文件路径字典 + """ + results = {} + + # 保存query tensor + if query is not None: + results["query"] = self.save_tensor( + query, "attention", operation, quant_type, "query", layer_idx, phase, component, rank, metadata + ) + + # 保存key tensor + if key is not None: + results["key"] = self.save_tensor( + key, "attention", operation, quant_type, "key", layer_idx, phase, component, rank, metadata + ) + + # 保存value tensor + if value is not None: + results["value"] = self.save_tensor( + value, "attention", operation, quant_type, "value", layer_idx, phase, component, rank, metadata + ) + + # 保存attention权重(P分布) + if attention_weights is not None: + results["attention_weights"] = self.save_tensor( + attention_weights, "attention", operation, quant_type, "attention_weights", layer_idx, phase, component, rank, metadata + ) + + return results + + def save_linear_tensors(self, + input_tensor: torch.Tensor, + weight: torch.Tensor, + quant_type: str, + operation: str = "forward", + layer_idx: Optional[int] = None, + phase: str = "pre", + component: str = "linear", + rank: Optional[int] = None, + metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Optional[str]]: + """ + 保存linear层的输入tensor + + Args: + input_tensor: 输入tensor + weight: 权重tensor + quant_type: 量化类型 + operation: 操作类型 + layer_idx: 层索引 + phase: 阶段 + component: 组件类型 + rank: GPU rank信息 + metadata: 额外元数据 + + Returns: + 保存的文件路径字典 + """ + results = {} + + # 保存input tensor + if input_tensor is not None: + results["input"] = self.save_tensor( + input_tensor, "linear", operation, quant_type, "input", layer_idx, phase, component, rank, metadata + ) + + # 保存weight tensor + if weight is not None: + results["weight"] = self.save_tensor( + weight, "linear", operation, quant_type, "weight", layer_idx, phase, component, rank, metadata + ) + + return results + + +# 全局tensor保存器实例 +_global_tensor_saver = None + + +def get_tensor_saver() -> TensorSaver: + """获取全局tensor保存器实例""" + global _global_tensor_saver + if _global_tensor_saver is None: + # 从环境变量和命令行参数获取配置 + save_dir = os.environ.get("TENSOR_SAVE_DIR", "./enhanced_tensor_logs") + enabled = os.environ.get("TENSOR_SAVE_ENABLED", "false").lower() == "true" + + # 尝试从命令行参数获取配置(如果可用) + try: + from megatron.training.global_vars import get_args + args = get_args() + if hasattr(args, 'tensor_save_dir') and args.tensor_save_dir: + save_dir = args.tensor_save_dir + if hasattr(args, 'save_tensors'): + enabled = args.save_tensors or enabled + except Exception as e: + pass + + print(f"[TensorSaver] 初始化 - 保存目录: {save_dir}, 启用: {enabled}") + _global_tensor_saver = TensorSaver(save_dir=save_dir, enabled=enabled) + + # 从环境变量设置iteration + iteration = os.environ.get("TENSOR_SAVER_ITERATION") + if iteration is not None: + try: + _global_tensor_saver.set_iteration(int(iteration)) + except ValueError: + print(f"[TensorSaver] 无效的iteration值: {iteration}") + + # 初始化tensor收集状态 + initialize_tensor_collection() + + return _global_tensor_saver + + +def save_attention_tensors(query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + quant_type: str, + operation: str = "forward", + layer_idx: Optional[int] = None, + phase: str = "pre", + component: str = "FA", + rank: Optional[int] = None, + attention_weights: Optional[torch.Tensor] = None, + metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Optional[str]]: + """保存attention层tensor的便捷函数""" + saver = get_tensor_saver() + return saver.save_attention_tensors(query, key, value, quant_type, operation, layer_idx, phase, component, rank, attention_weights, metadata) + + +def save_linear_tensors(input_tensor: torch.Tensor, + weight: torch.Tensor, + quant_type: str, + operation: str = "forward", + layer_idx: Optional[int] = None, + phase: str = "pre", + component: str = "linear", + rank: Optional[int] = None, + metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Optional[str]]: + """保存linear层tensor的便捷函数""" + saver = get_tensor_saver() + return saver.save_linear_tensors(input_tensor, weight, quant_type, operation, layer_idx, phase, component, rank, metadata) + + +def save_tensor(tensor: torch.Tensor, + layer_type: str, + operation: str, + quant_type: str, + tensor_name: str, + layer_idx: Optional[int] = None, + phase: str = "unknown", + component: str = "unknown", + rank: Optional[int] = None, + metadata: Optional[Dict[str, Any]] = None) -> Optional[str]: + """保存单个tensor的便捷函数""" + saver = get_tensor_saver() + return saver.save_tensor(tensor, layer_type, operation, quant_type, tensor_name, layer_idx, phase, component, rank, metadata) diff --git a/megatron/core/transformer/dot_product_attention.py b/megatron/core/transformer/dot_product_attention.py index 9e13cefe89..b55afe8fc2 100644 --- a/megatron/core/transformer/dot_product_attention.py +++ b/megatron/core/transformer/dot_product_attention.py @@ -47,6 +47,7 @@ def __init__( ): super().__init__(config=config) + # import pdb;pdb.set_trace() self.config: TransformerConfig = config assert ( @@ -117,12 +118,15 @@ def forward( packed_seq_params: Optional[PackedSeqParams] = None, ): """Forward.""" + # import pdb;pdb.set_trace() assert packed_seq_params is None, ( "Packed sequence is not supported by DotProductAttention." "Please use TEDotProductAttention instead." ) assert attention_bias is None, "Attention bias is not supported for DotProductAttention." + # attention输入tensor现在通过量化算子自动保存,无需重复保存 + # =================================== # Raw attention scores. [b, n/p, s, s] # =================================== @@ -158,13 +162,97 @@ def forward( ) # Raw attention scores. [b * np, sq, sk] - matmul_result = torch.baddbmm( - matmul_input_buffer, - query.transpose(0, 1), # [b * np, sq, hn] - key.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=self.softmax_scale, - ) + # import pdb;pdb.set_trace() + from quant.mxfp import mxfp_baddbmm + from quant.hifp import hifp_baddbmm + from quant.bf16_operators import bf16_baddbmm + # 从环境变量获取量化类型,默认为hifp8 + import os + custom_quant_type = 'hifp8' + + # 支持time-resume自适应量化 + try: + from megatron.core.adaptive_quantization import get_adaptive_quantization_manager + # 尝试从全局状态获取量化管理器 + if hasattr(self, '_adaptive_quantization_manager'): + custom_quant_type = self._adaptive_quantization_manager.get_current_quantization_type() + except ImportError: + pass + + # 获取scaling_control参数 + scaling_control = 'max' # 默认值,后续可以从args获取 + + # 准备tensor保存参数 + tensor_save_params = { + "layer_type": "attention", + "layer_idx": getattr(self, 'layer_number', None), + "operation": "forward", + "phase": "pre", + "component": "attention", + "rank": None, # 稍后设置 + "metadata": { + "softmax_scale": self.softmax_scale, + "attention_scores": True, + } + } + + # 获取rank信息 + try: + import torch.distributed as dist + if dist.is_initialized(): + tensor_save_params["rank"] = dist.get_rank() + except: + pass + + if tensor_save_params["rank"] is None: + tensor_save_params["rank"] = int(os.environ.get("LOCAL_RANK", 0)) + + if custom_quant_type == 'mxfp4': + matmul_result = mxfp_baddbmm( + matmul_input_buffer, + query.transpose(0, 1), # [b * np, sq, hn] + key.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=self.softmax_scale, + scaling_control=scaling_control, + **tensor_save_params + ) + elif custom_quant_type == 'mxfp8': + matmul_result = mxfp_baddbmm( + matmul_input_buffer, + query.transpose(0, 1), # [b * np, sq, hn] + key.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=self.softmax_scale, + scaling_control=scaling_control, + **tensor_save_params + ) + elif custom_quant_type == 'hifp8': + matmul_result = hifp_baddbmm( + matmul_input_buffer, + query.transpose(0, 1), # [b * np, sq, hn] + key.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=self.softmax_scale, + **tensor_save_params + ) + elif custom_quant_type == 'bf16': + matmul_result = bf16_baddbmm( + matmul_input_buffer, + query.transpose(0, 1), # [b * np, sq, hn] + key.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=self.softmax_scale, + **tensor_save_params + ) + else: + matmul_result = torch.baddbmm( + matmul_input_buffer, + query.transpose(0, 1), # [b * np, sq, hn] + key.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=self.softmax_scale, + ) # change view to [b, np, sq, sk] attention_scores = matmul_result.view(*output_size) @@ -176,6 +264,8 @@ def forward( # attention scores and attention mask [b, np, sq, sk] attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask) + # attention权重现在通过量化算子自动保存,无需重复保存 + # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. @@ -202,8 +292,37 @@ def forward( attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) # matmul: [b * np, sq, hn] - context = torch.bmm(attention_probs, value.transpose(0, 1)) - + from quant.mxfp import mxfp_matmul + from quant.hifp import hifp_matmul + from quant.bf16_operators import bf16_matmul + # 使用相同的量化类型 + # custom_quant_type 已在上面定义 + + # 准备tensor保存参数(用于context计算) + context_tensor_save_params = { + "layer_type": "attention", + "layer_idx": getattr(self, 'layer_number', None), + "operation": "forward", + "phase": "post", + "component": "FA", + "rank": tensor_save_params["rank"], # 重用之前的rank + "metadata": { + "attention_mask_shape": list(attention_mask.shape) if attention_mask is not None else None, + "attn_mask_type": str(attn_mask_type) if attn_mask_type is not None else None, + "output_shape": None, # 稍后设置 + } + } + + if custom_quant_type == 'hifp8': + context = hifp_matmul(attention_probs, value.transpose(0, 1), **context_tensor_save_params) + elif custom_quant_type == 'mxfp8': + context = mxfp_matmul(attention_probs, value.transpose(0, 1), 'fp8_e4m3', scaling_control=scaling_control, **context_tensor_save_params) + elif custom_quant_type == 'mxfp4': + context = mxfp_matmul(attention_probs, value.transpose(0, 1), 'fp4_e2m1', scaling_control=scaling_control, **context_tensor_save_params) + elif custom_quant_type == 'bf16': + context = bf16_matmul(attention_probs, value.transpose(0, 1), **context_tensor_save_params) + else: + context = torch.bmm(attention_probs, value.transpose(0, 1)) # change view [b, np, sq, hn] context = context.view(*output_size) @@ -214,4 +333,6 @@ def forward( new_context_shape = context.size()[:-2] + (self.hidden_size_per_partition,) context = context.view(*new_context_shape) + # context tensor现在通过量化算子自动保存,无需重复保存 + return context diff --git a/megatron/core/transformer/heterogeneous/.linear_replacements.py.swp b/megatron/core/transformer/heterogeneous/.linear_replacements.py.swp new file mode 100644 index 0000000000..833ac5e4f4 Binary files /dev/null and b/megatron/core/transformer/heterogeneous/.linear_replacements.py.swp differ diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index f6aad9fa48..bedf015f9c 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -341,6 +341,7 @@ def build_layer(layer_spec, layer_number): return module # offset is implicit in TransformerLayer + # import pdb;pdb.set_trace() self.layers = torch.nn.ModuleList( [ build_layer(layer_spec, i + 1) diff --git a/megatron/legacy/fused_kernels/tests/test_fused_kernels.py b/megatron/legacy/fused_kernels/tests/test_fused_kernels.py deleted file mode 100644 index f5b2b78a3f..0000000000 --- a/megatron/legacy/fused_kernels/tests/test_fused_kernels.py +++ /dev/null @@ -1,389 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -import math - -import torch -from torch.nn import LayerNorm - -from megatron.legacy.model.enums import AttnMaskType -from megatron.legacy.model.fused_layer_norm import MixedFusedLayerNorm -from megatron.legacy.model.fused_softmax import FusedScaleMaskSoftmax -from megatron.legacy.model.utils import attention_mask_func -from megatron.legacy.fused_kernels import load - -def test_load_fused_kernels(): - try: - import fused_layer_norm_cuda - import scaled_masked_softmax_cuda - import scaled_upper_triang_masked_softmax_cuda - import torch - - print("[Success] load_fused_kernels") - except ImportError as e: - print("[Fail] load_fused_kernels") - raise e - -def test_fused_softmax(): - bert = BertModel.from_pretrained("bert-base-cased").cuda().half() - tokenizer = BertTokenizer.from_pretrained("bert-base-cased") - test_text = ( - "Hello. How are you? I am fine thank you and you? yes Good. " - "hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32 - ) - - tokens = tokenizer( - [test_text] * 4, - return_tensors="pt", - ) - - embedding_output = bert.embeddings( - input_ids=tokens["input_ids"].cuda(), - position_ids=None, - token_type_ids=tokens["token_type_ids"].cuda(), - inputs_embeds=None, - past_key_values_length=0, - ) - - # (bsz, 1, 1, seq_len) - mask = bert.get_extended_attention_mask( - attention_mask=tokens["attention_mask"].cuda(), - input_shape=tokens["input_ids"].shape, - device=bert.device, - ) - # (bsz, 1, seq_len, seq_len) - mask = mask.repeat(1, 1, mask.size()[-1], 1) - - attention = bert.encoder.layer[0].attention.self - key_layer = attention.transpose_for_scores(attention.key(embedding_output)) - query_layer = attention.transpose_for_scores(attention.query(embedding_output)) - - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - attention_scores /= math.sqrt(key_layer.size()[-1]) - - fused_softmax = ( - FusedScaleMaskSoftmax( - input_in_fp16=True, - input_in_bf16=False, - mask_func=attention_mask_func, - scale=None, - softmax_in_fp32=False, - attn_mask_type=AttnMaskType.padding, - scaled_masked_softmax_fusion=True, - ) - .cuda() - .half() - ) - - fused_softmax_output = fused_softmax( - attention_scores, - (mask != 0), - ) - - torch_softmax = ( - FusedScaleMaskSoftmax( - input_in_fp16=True, - input_in_bf16=False, - mask_func=attention_mask_func, - scale=None, - softmax_in_fp32=False, - attn_mask_type=AttnMaskType.padding, - scaled_masked_softmax_fusion=False, - ) - .cuda() - .half() - ) - - torch_softmax_output = torch_softmax( - attention_scores, - (mask != 0), - ) - - test_result = (fused_softmax_output - torch_softmax_output).abs() - - while test_result.dim() != 1: - test_result = test_result.mean(dim=-1) - - diff = test_result.mean(dim=-1) - - if diff <= 1e-3: - print( - f"\n[Success] test_fused_softmax" - f"\n > mean_difference={diff}" - f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}" - f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" - ) - else: - print( - f"\n[Fail] test_fused_softmax" - f"\n > mean_difference={diff}, " - f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, " - f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" - ) - - -def test_fused_upper_triangle_mask_softmax(): - gpt = GPT2Model.from_pretrained("gpt2").cuda().half() - tokenizer = GPT2Tokenizer.from_pretrained("gpt2") - test_text = ( - "Hello. How are you? I am fine thank you and you? yes Good. " - "hi hi hi hi hi hi hi" # 24 - ) - - tokens = tokenizer( - [test_text] * 4, - return_tensors="pt", - ) - - attention_mask = tokens["attention_mask"].cuda() - attention_mask = attention_mask.view(attention_mask.size(0), -1) - attention_mask = attention_mask[:, None, None, :] - attention_mask = (1.0 - attention_mask) * -10000.0 - attention_mask = attention_mask.repeat(1, 1, attention_mask.size()[-1], 1) - attn = gpt.h[0] - - hidden_states = gpt.wte(tokens["input_ids"].cuda()) - q, k, v = attn.attn.c_attn(hidden_states).split(768, dim=-1) - q = attn.attn._split_heads(q, attn.attn.num_heads, attn.attn.head_dim) - k = attn.attn._split_heads(k, attn.attn.num_heads, attn.attn.head_dim) - attn_weights = torch.matmul(q, k.transpose(-1, -2)) - - sq, sk = q.size(-2), k.size(-2) - causal_mask = attn.attn.bias[:, :, sk - sq : sk, :sk].bool() - total_mask = ~(causal_mask & (attention_mask == 0)) - """ - tensor([[[[False, True, True, ..., True, True, True], - [False, False, True, ..., True, True, True], - [False, False, False, ..., True, True, True], - ..., - [False, False, False, ..., False, True, True], - [False, False, False, ..., False, False, True], - [False, False, False, ..., False, False, False]]] - """ - - fused_softmax = ( - FusedScaleMaskSoftmax( - input_in_fp16=True, - input_in_bf16=False, - mask_func=attention_mask_func, - scale=None, - softmax_in_fp32=False, - attn_mask_type=AttnMaskType.causal, - scaled_masked_softmax_fusion=True, - ) - .cuda() - .half() - ) - - fused_softmax_output = fused_softmax( - attn_weights, - total_mask, - ) - - torch_softmax = ( - FusedScaleMaskSoftmax( - input_in_fp16=True, - input_in_bf16=False, - mask_func=attention_mask_func, - scale=None, - softmax_in_fp32=False, - attn_mask_type=AttnMaskType.causal, - scaled_masked_softmax_fusion=False, - ) - .cuda() - .half() - ) - - torch_softmax_output = torch_softmax( - attn_weights, - total_mask, - ) - - test_result = (fused_softmax_output - torch_softmax_output).abs() - - while test_result.dim() != 1: - test_result = test_result.mean(dim=-1) - - diff = test_result.mean(dim=-1) - - if diff <= 1e-3: - print( - f"\n[Success] test_fused_upper_triangle_mask_softmax" - f"\n > mean_difference={diff}" - f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}" - f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" - ) - else: - print( - f"\n[Fail] test_fused_upper_triangle_mask_softmax" - f"\n > mean_difference={diff}, " - f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, " - f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" - ) - - -def test_layer_norm(): - bert = BertModel.from_pretrained("bert-base-cased").cuda().half() - tokenizer = BertTokenizer.from_pretrained("bert-base-cased") - test_text = ( - "Hello. How are you? I am fine thank you and you? yes Good. " - "hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32 - ) - - tokens = tokenizer( - [test_text] * 4, - return_tensors="pt", - ) - - # [bsz, seq_len, d_model] - embedding_output = ( - bert.embeddings( - input_ids=tokens["input_ids"].cuda(), - position_ids=None, - token_type_ids=tokens["token_type_ids"].cuda(), - inputs_embeds=None, - past_key_values_length=0, - ) - .cuda() - .half() - ) - - fused_layernorm_layer = ( - MixedFusedLayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half() - ) - - torch_layernorm_layer = ( - LayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half() - ) - - fused_output = fused_layernorm_layer(embedding_output) - torch_output = torch_layernorm_layer(embedding_output) - test_result = (fused_output - torch_output).abs() - - while test_result.dim() != 1: - test_result = test_result.mean(dim=-1) - - diff = test_result.mean(dim=-1) - - if diff <= 1e-3: - print( - f"\n[Success] test_layer_norm" - f"\n > mean_difference={diff}" - f"\n > fused_values={fused_output[-1][-1][:5].tolist()}" - f"\n > torch_values={torch_output[-1][-1][:5].tolist()}" - ) - else: - print( - f"\n[Fail] test_layer_norm" - f"\n > mean_difference={diff}, " - f"\n > fused_values={fused_output[-1][-1][:5].tolist()}, " - f"\n > torch_values={torch_output[-1][-1][:5].tolist()}" - ) - - -def attention_mask_func(attention_scores, attention_mask): - attention_scores.masked_fill_(attention_mask, -10000.0) - return attention_scores - - -def forward_torch_softmax(input, mask, scale): - input = input * scale - mask_output = attention_mask_func(input, mask) if mask is not None else input - probs = torch.nn.Softmax(dim=-1)(mask_output) - return probs - - -def test_masked_softmax_forward(): - import scaled_masked_softmax_cuda - - batch = 2 - attn = 16 - scale_t = torch.tensor([1.0]) - for qlen in [128, 256, 1024, 2048, 4096]: - for klen in [128, 256, 1024, 2048]: - inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0') - masks = torch.randint(0, 2, (batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0') - softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item()) - softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item()) - error = (softmax_results_torch - softmax_results).abs().max() - assert error < 1e-3 - -def test_masked_softmax_backward(): - import scaled_masked_softmax_cuda - - batch = 2 - attn = 16 - scale_t = torch.tensor([1.0]) - for qlen in [128, 256, 1024, 2048, 4096]: - for klen in [128, 256, 1024, 2048]: - inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0') - backward = torch.rand_like(inputs, dtype=torch.float16, device='cuda:0') - masks = torch.randint(0, 2, (batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0') - softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item()) - back_grad = scaled_masked_softmax_cuda.backward(backward, softmax_results, scale_t[0].item()) - - inputs.requires_grad = True - softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item()) - softmax_results_torch.backward(backward) - error = (back_grad - inputs.grad).abs().max() - assert error < 1e-3 - - -def test_allmasked_softmax_forward(): - import scaled_masked_softmax_cuda - - batch = 2 - attn = 16 - scale_t = torch.tensor([1.0]) - for qlen in [128, 256, 1024, 2048, 4096]: - for klen in [128, 256, 1024, 2048]: - inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0') - masks = torch.ones((batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0') - softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item()) - softmax_results_torch = torch.zeros_like(inputs) - error = (softmax_results_torch - softmax_results).abs().max() - assert error == 0.0 - - -def test_allmasked_softmax_backward(): - import scaled_masked_softmax_cuda - - batch = 2 - attn = 16 - scale_t = torch.tensor([1.0]) - for qlen in [128, 256, 1024, 2048, 4096]: - for klen in [128, 256, 1024, 2048]: - inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0') - backward = torch.rand_like(inputs, dtype=torch.float16, device='cuda:0') - masks = torch.ones((batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0') - softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item()) - back_grad = scaled_masked_softmax_cuda.backward(backward, softmax_results, scale_t[0].item()) - inputs.requires_grad = True - softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item()) - softmax_results_torch.backward(backward) - error = (back_grad - inputs.grad).abs().max() - assert error < 1e-3 - - -if __name__ == "__main__": - try: - from transformers import BertTokenizer, GPT2Tokenizer - from transformers.models.bert.modeling_bert import BertModel - from transformers.models.gpt2.modeling_gpt2 import GPT2Model - import transformers - - transformers.logging.set_verbosity( - transformers.logging.FATAL, - ) - - except ImportError: - print("\n[Fail] Please install `transformers` package to test fused kernels\n") - exit(-1) - - load() - test_masked_softmax_forward() - test_masked_softmax_backward() - test_allmasked_softmax_forward() - test_allmasked_softmax_backward() - test_load_fused_kernels() - test_fused_softmax() - test_fused_upper_triangle_mask_softmax() - test_layer_norm() diff --git a/megatron/legacy/model/transformer.py b/megatron/legacy/model/transformer.py index 2a662a55b1..d74727342a 100644 --- a/megatron/legacy/model/transformer.py +++ b/megatron/legacy/model/transformer.py @@ -1741,6 +1741,7 @@ def forward(self, hidden_states, attention_mask, fp8_group=self.fp8_group ) if self.use_fp8 else nullcontext(): # Determine if the current iteration is first microbatch + # import pdb;pdb.set_trace() if self.num_microbatches_in_previous_step != get_num_microbatches(): self.microbatch_count = 0 # Reset count on new batch size rampup interval self.num_microbatches_in_previous_step = get_num_microbatches() diff --git a/megatron/legacy/mpu/tests/test_cross_entropy.py b/megatron/legacy/mpu/tests/test_cross_entropy.py deleted file mode 100644 index 00ae42228a..0000000000 --- a/megatron/legacy/mpu/tests/test_cross_entropy.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -from commons import set_random_seed -from commons import IdentityLayer -from commons import print_separator -from commons import initialize_distributed -from mpu.cross_entropy import vocab_parallel_cross_entropy -import mpu -import torch.nn.functional as F -import torch -import random -import sys -sys.path.append("../..") - - -def torch_cross_entropy(batch_size, seq_length, vocab_size, - logits_scale, seed): - set_random_seed(seed) - identity = IdentityLayer((batch_size, seq_length, vocab_size), - scale=logits_scale).cuda() - logits = identity() - target = torch.cuda.LongTensor( - size=(batch_size, seq_length)).random_(0, vocab_size) - loss = F.cross_entropy(logits.view(-1, logits.size()[-1]), - target.view(-1), - reduction='none').view_as(target).mean() - loss.backward() - return loss, identity.weight.grad - - -def mpu_cross_entropy(batch_size, seq_length, vocab_size, - logits_scale, seed): - set_random_seed(seed) - identity = IdentityLayer((batch_size, seq_length, vocab_size), - scale=logits_scale).cuda() - logits = identity() - logits_parallel = mpu.scatter_to_tensor_model_parallel_region(logits) - target = torch.cuda.LongTensor( - size=(batch_size, seq_length)).random_(0, vocab_size) - loss = vocab_parallel_cross_entropy(logits_parallel, target).mean() - loss.backward() - return loss, identity.weight.grad - - -def test_cross_entropy(tensor_model_parallel_size): - - if torch.distributed.get_rank() == 0: - print('> testing cross entropy with model parallel size {} ...'. - format(tensor_model_parallel_size)) - - mpu.initialize_model_parallel(tensor_model_parallel_size) - tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() - - batch_size = 13 - seq_length = 17 - vocab_size_per_partition = 11 - logits_scale = 1000.0 - vocab_size = vocab_size_per_partition * tensor_model_parallel_size - seed = 1234 - - loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length, - vocab_size, logits_scale, - seed) - loss_mpu, grad_mpu = mpu_cross_entropy(batch_size, seq_length, - vocab_size, logits_scale, - seed) - - error = loss_torch.sub_(loss_mpu).abs().max() - print(' max error in loss on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 1.0e-6 - - error = grad_torch.sub_(grad_mpu).abs().max() - print(' max error in grad on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 1.0e-6 - - # Reset groups - mpu.destroy_tensor_model_parallel() - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print('>> passed the test :-)') - - -if __name__ == '__main__': - - initialize_distributed() - world_size = torch.distributed.get_world_size() - - tensor_model_parallel_size = 1 - while tensor_model_parallel_size <= world_size: - print_separator('test cross entropy') - test_cross_entropy(tensor_model_parallel_size) - tensor_model_parallel_size *= 2 diff --git a/megatron/legacy/mpu/tests/test_data.py b/megatron/legacy/mpu/tests/test_data.py deleted file mode 100644 index c30bf4bb8d..0000000000 --- a/megatron/legacy/mpu/tests/test_data.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -from commons import print_separator -from commons import initialize_distributed -from mpu import data as data_utils -import mpu -import torch -import functools -import operator -import sys -sys.path.append("../..") - - -def test_broadcast_data(tensor_model_parallel_size): - - if torch.distributed.get_rank() == 0: - print('> testing broadcast_data with model parallel size {} ...'. - format(tensor_model_parallel_size)) - - mpu.initialize_model_parallel(tensor_model_parallel_size) - torch.manual_seed(1234 + mpu.get_data_parallel_rank()) - tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() - - key_size_t = {'key1': [7, 11], - 'key2': [8, 2, 1], - 'key3': [13], - 'key4': [5, 1, 2], - 'key5': [5, 12]} - keys = list(key_size_t.keys()) - - data = {} - data_t = {} - for key in key_size_t: - data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000) - data_t[key] = data[key].clone() - data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000) - data_t['keyX'] = data['keyX'].clone() - if mpu.get_tensor_model_parallel_rank() != 0: - data = None - - data_utils._check_data_types(keys, data_t, torch.int64) - key_size, key_numel, \ - total_numel = data_utils._build_key_size_numel_dictionaries(keys, data) - for key in keys: - assert key_size[key] == key_size_t[key] - total_numel_t = 0 - for key in keys: - target_size = functools.reduce(operator.mul, key_size_t[key], 1) - assert key_numel[key] == target_size - total_numel_t += target_size - assert total_numel == total_numel_t - - data_b = data_utils.broadcast_data(keys, data, torch.int64) - for key in keys: - tensor = data_t[key].cuda() - assert data_b[key].sub(tensor).abs().max() == 0 - - # Reset groups - mpu.destroy_tensor_model_parallel() - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print('>> passed the test :-)') - - -if __name__ == '__main__': - - initialize_distributed() - world_size = torch.distributed.get_world_size() - - tensor_model_parallel_size = 1 - while tensor_model_parallel_size <= world_size: - print_separator('test test broadcast data') - test_broadcast_data(tensor_model_parallel_size) - tensor_model_parallel_size *= 2 diff --git a/megatron/legacy/mpu/tests/test_initialize.py b/megatron/legacy/mpu/tests/test_initialize.py deleted file mode 100644 index 48652080a5..0000000000 --- a/megatron/legacy/mpu/tests/test_initialize.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -from commons import print_separator -from commons import initialize_distributed -import mpu -import torch -import sys -sys.path.append("../..") - - -def test_initialize_model_parallel(tensor_model_parallel_size): - - if torch.distributed.get_rank() == 0: - print('> testing initialize_model_parallel with size {} ...'.format( - tensor_model_parallel_size)) - tensor_model_parallel_size_ = min(tensor_model_parallel_size, - torch.distributed.get_world_size()) - assert not mpu.model_parallel_is_initialized() - mpu.initialize_model_parallel(tensor_model_parallel_size_) - assert mpu.model_parallel_is_initialized() - - # Checks. - def check(group, world_size, rank): - assert world_size == group.size() - assert rank == group.rank() - - # Model parallel. - world_size = tensor_model_parallel_size_ - rank = torch.distributed.get_rank() % tensor_model_parallel_size_ - assert world_size == mpu.get_tensor_model_parallel_world_size() - assert rank == mpu.get_tensor_model_parallel_rank() - check(mpu.get_tensor_model_parallel_group(), world_size, rank) - - # Data parallel. - world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_ - rank = torch.distributed.get_rank() // tensor_model_parallel_size - assert world_size == mpu.get_data_parallel_world_size() - assert rank == mpu.get_data_parallel_rank() - check(mpu.get_data_parallel_group(), world_size, rank) - - # Reset groups - mpu.destroy_model_parallel() - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print('>> passed the test :-)') - - -def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_): - - if torch.distributed.get_rank() == 0: - print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format( - tensor_model_parallel_size_)) - tensor_model_parallel_size = min(tensor_model_parallel_size_, - torch.distributed.get_world_size()) - assert not mpu.model_parallel_is_initialized() - mpu.initialize_model_parallel(tensor_model_parallel_size) - assert mpu.model_parallel_is_initialized() - - # Checks - src_rank = torch.distributed.get_rank() - mpu.get_tensor_model_parallel_rank() - assert mpu.get_tensor_model_parallel_src_rank() == src_rank - - # Reset groups - mpu.destroy_model_parallel() - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print('>> passed the test :-)') - - -if __name__ == '__main__': - - initialize_distributed() - world_size = torch.distributed.get_world_size() - tensor_model_parallel_size = 1 - while tensor_model_parallel_size <= world_size: - print_separator('test initialize model parallel') - test_initialize_model_parallel(tensor_model_parallel_size) - print_separator('test model parallel source rank') - test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size) - tensor_model_parallel_size *= 2 diff --git a/megatron/legacy/mpu/tests/test_layers.py b/megatron/legacy/mpu/tests/test_layers.py deleted file mode 100644 index 73ad4b9459..0000000000 --- a/megatron/legacy/mpu/tests/test_layers.py +++ /dev/null @@ -1,517 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -from mpu import layers -from commons import set_random_seed -from commons import print_separator -from commons import initialize_distributed -import mpu -from torch.nn.parameter import Parameter -import torch.nn.init as init -import torch -import random -import sys -sys.path.append("../..") - - -def test_parallel_embedding(tensor_model_parallel_size): - - if torch.distributed.get_rank() == 0: - print('> testing parallel embedding with model parallel size {} ...'. - format(tensor_model_parallel_size)) - - mpu.initialize_model_parallel(tensor_model_parallel_size) - tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() - - batch_size = 17 - seq_length = 23 - vocab_size = 48 - hidden_size = 16 - seed = 1236 - - set_random_seed(123) - input_data = torch.LongTensor( - size=(batch_size, seq_length)).random_(0, vocab_size).cuda() - loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda() - - set_random_seed(seed) - embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda() - - output = embedding_original(input_data) - loss_original = torch.mul(output, loss_weight).sum() - loss_original.backward() - - set_random_seed(seed) - embedding_parallel = layers.ParallelEmbedding( - vocab_size, hidden_size, init_method=init.normal_).cuda() - output = embedding_parallel(input_data) - loss_parallel = torch.mul(output, loss_weight).sum() - loss_parallel.backward() - - set_random_seed(seed) - embedding_vocab_parallel = layers.VocabParallelEmbedding( - vocab_size, hidden_size, init_method=init.normal_).cuda() - output = embedding_vocab_parallel(input_data) - loss_vocab_parallel = torch.mul(output, loss_weight).sum() - loss_vocab_parallel.backward() - - torch.distributed.barrier() - error = loss_parallel.sub(loss_original).abs() - print(' error in loss (parallel) on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 1.0e-12, 'error: {}'.format(error) - - torch.distributed.barrier() - error = loss_vocab_parallel.sub(loss_original).abs() - print(' error in loss (vocab parallel) on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 1.0e-12, 'error: {}'.format(error) - - weight_grad_orig = torch.split(embedding_original.weight.grad, - hidden_size // tensor_model_parallel_size, - 1)[mpu.get_tensor_model_parallel_rank()] - error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max() - print(' error in grad (parallel) on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 1.0e-12, 'error: {}'.format(error) - - weight_grad_orig = torch.split(embedding_original.weight.grad, - vocab_size // tensor_model_parallel_size, - 0)[mpu.get_tensor_model_parallel_rank()] - error = embedding_vocab_parallel.weight.grad.sub( - weight_grad_orig).abs().max() - print(' error in grad (vocab parallel) on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 1.0e-12, 'error: {}'.format(error) - - # Reset groups - mpu.destroy_model_parallel() - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print('>> passed the test :-)') - - -def test_initialize_affine_weight(tensor_model_parallel_size): - - mpu.initialize_model_parallel(tensor_model_parallel_size) - if torch.distributed.get_rank() == 0: - print('> testing initialize_affine_weight with model parallel ' - 'size: {}'.format(tensor_model_parallel_size)) - tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() - - seed = 12345 - input_size_coeff = 13 - input_size = input_size_coeff * tensor_model_parallel_size - output_size_coeff = 17 - output_size = output_size_coeff * tensor_model_parallel_size - - # --------------- - # Column parallel - # --------------- - weight = torch.empty(output_size_coeff, input_size) - set_random_seed(seed) - layers._initialize_affine_weight(weight, output_size, input_size, - - output_size_coeff, 0, - torch.nn.init.normal_) - # Target. - set_random_seed(seed) - master_weight = torch.empty(output_size, input_size) - torch.nn.init.normal_(master_weight) - rank = mpu.get_tensor_model_parallel_rank() - my_weight = torch.split(master_weight, output_size_coeff, - dim=0)[rank].contiguous().clone() - - # Compare. - error = weight.sub(my_weight).abs().max() - torch.distributed.barrier() - print(' column parallel max error (should be zero) on global rank ' - '{}: {}'.format(torch.distributed.get_rank(), error)) - assert error < 1.0e-6 - - # ------------ - # Row parallel - # ------------ - weight = torch.empty(output_size, input_size_coeff) - set_random_seed(seed) - mpu.layers._initialize_affine_weight(weight, output_size, input_size, - input_size_coeff, 1, - torch.nn.init.normal_) - # Target. - set_random_seed(seed) - master_weight = torch.empty(output_size, input_size) - torch.nn.init.normal_(master_weight) - rank = mpu.get_tensor_model_parallel_rank() - my_weight = torch.split(master_weight, input_size_coeff, - dim=1)[rank].contiguous().clone() - - # Compare. - error = weight.sub(my_weight).abs().max() - torch.distributed.barrier() - print(' row parallel max error (should be zero) on global rank ' - '{}: {}'.format(torch.distributed.get_rank(), error)) - assert error < 1.0e-6 - - # Reset groups - mpu.destroy_model_parallel() - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print(' >> passed the test :-)') - - -class IdentityLayer2D(torch.nn.Module): - def __init__(self, m, n): - super(IdentityLayer2D, self).__init__() - self.weight = Parameter(torch.Tensor(m, n)) - torch.nn.init.xavier_normal_(self.weight) - - def forward(self): - return self.weight - - -def test_column_parallel_linear(tensor_model_parallel_size): - - mpu.initialize_model_parallel(tensor_model_parallel_size) - if torch.distributed.get_rank() == 0: - print('> testing ColumnParallelLinear with model parallel ' - 'size: {}'.format(tensor_model_parallel_size)) - tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() - - seed = 12345 - set_random_seed(seed) - input_size_coeff = 13 - input_size = input_size_coeff * tensor_model_parallel_size - output_size_coeff = 17 - output_size = output_size_coeff * tensor_model_parallel_size - batch_size = 7 - - # Network - identity_layer = IdentityLayer2D(batch_size, input_size).cuda() - linear_layer = mpu.ColumnParallelLinear( - input_size, output_size, keep_master_weight_for_test=True).cuda() - loss_weight = torch.randn([batch_size, output_size]).cuda() - # Forward - input_ = identity_layer() - output = linear_layer(input_) - loss = torch.mul(output, loss_weight).sum() - # Backward - loss.backward() - - # Values. - dLdY = loss_weight - X = identity_layer.weight - A = linear_layer.master_weight.cuda() - dLdA = torch.matmul(dLdY.t(), X) - dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1) - dLdX = torch.matmul(dLdY, A) - - rank = mpu.get_tensor_model_parallel_rank() - my_dLdA = torch.split(dLdA, output_size_coeff, - dim=0)[rank].contiguous().clone() - error = my_dLdA.sub(linear_layer.weight.grad).abs().max() - torch.distributed.barrier() - print(' error in dLdA on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 1.0e-6 - - my_dLdb = torch.split(dLdb, output_size_coeff, - dim=0)[rank].contiguous().clone() - error = my_dLdb.sub(linear_layer.bias.grad).abs().max() - torch.distributed.barrier() - print(' error in dLdb on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 1.0e-6 - - error = dLdX.sub(identity_layer.weight.grad).abs().max() - torch.distributed.barrier() - print(' error in dLdX on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 1.0e-6 - - # Reset groups - mpu.destroy_model_parallel() - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print(' >> passed the test :-)') - - -def test_row_parallel_linear(tensor_model_parallel_size): - - mpu.initialize_model_parallel(tensor_model_parallel_size) - if torch.distributed.get_rank() == 0: - print('> testing RowParallelLinear with model parallel ' - 'size: {}'.format(tensor_model_parallel_size)) - tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() - - seed = 12345 - set_random_seed(seed) - input_size_coeff = 13 - input_size = input_size_coeff * tensor_model_parallel_size - output_size_coeff = 17 - output_size = output_size_coeff * tensor_model_parallel_size - batch_size = 7 - - # Network - identity_layer = IdentityLayer2D(batch_size, input_size).cuda() - linear_layer = mpu.RowParallelLinear( - input_size, output_size, keep_master_weight_for_test=True).cuda() - loss_weight = torch.randn([batch_size, output_size]).cuda() - # Forward - input_ = identity_layer() - output = linear_layer(input_) - loss = torch.mul(output, loss_weight).sum() - # Backward - loss.backward() - - # Values. - dLdY = loss_weight - X = identity_layer.weight - A = linear_layer.master_weight.cuda() - dLdA = torch.matmul(dLdY.t(), X) - dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1) - dLdX = torch.matmul(dLdY, A) - - rank = mpu.get_tensor_model_parallel_rank() - my_dLdA = torch.split(dLdA, input_size_coeff, - dim=1)[rank].contiguous().clone() - error = my_dLdA.sub(linear_layer.weight.grad).abs().max() - torch.distributed.barrier() - print(' error in dLdA on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 1.0e-6 - - error = dLdb.sub(linear_layer.bias.grad).abs().max() - torch.distributed.barrier() - print(' error in dLdb on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 1.0e-6 - - error = dLdX.sub(identity_layer.weight.grad).abs().max() - torch.distributed.barrier() - print(' error in dLdX on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 1.0e-6 - - # Reset groups - mpu.destroy_model_parallel() - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print(' >> passed the test :-)') - - -class IdentityLayer3D(torch.nn.Module): - def __init__(self, m, n, k): - super(IdentityLayer3D, self).__init__() - self.weight = Parameter(torch.Tensor(m, n, k)) - torch.nn.init.xavier_normal_(self.weight) - - def forward(self): - return self.weight - - -def parallel_self_attention(tensor_model_parallel_size, num_att_heads_per_partition, - hidden_size_per_att_head, dropout_prob, batch_size, - sequence_length): - mpu.initialize_model_parallel(tensor_model_parallel_size) - tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() - - seed = 12345 - set_random_seed(seed) - - num_att_heads = num_att_heads_per_partition * \ - torch.distributed.get_world_size() - hidden_size = hidden_size_per_att_head * num_att_heads - - # Network - identity_layer = IdentityLayer3D(batch_size, sequence_length, - hidden_size).cuda() - attention_layer = mpu.BertParallelSelfAttention(hidden_size, num_att_heads, - dropout_prob).cuda() - loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda() - attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda() - # Forward - input_ = identity_layer() - output = attention_layer(input_, attention_mask) - loss = torch.mul(output, loss_weight).sum() - # Backward - loss.backward() - - rank = mpu.get_tensor_model_parallel_rank() - mpu.destroy_model_parallel() - return rank, hidden_size, tensor_model_parallel_size, loss, \ - attention_layer, identity_layer - - -def test_parallel_self_attention(tensor_model_parallel_size): - - if torch.distributed.get_rank() == 0: - print('> testing ParallelSelfAttention with model parallel ' - 'size: {}'.format(tensor_model_parallel_size)) - - num_att_heads_per_partition = 3 - hidden_size_per_att_head = 7 - dropout_prob = 0.0 # has to be zero - batch_size = 5 - sequence_length = 13 - - rank_1, hideen_size_1, tensor_model_parallel_size_1, loss_1, \ - attention_layer_1, identity_layer_1 = parallel_self_attention( - 1, num_att_heads_per_partition, - hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) - - rank, hidden_size, tensor_model_parallel_size, loss, \ - attention_layer, identity_layer = parallel_self_attention( - tensor_model_parallel_size, num_att_heads_per_partition, - hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) - assert hideen_size_1 == hidden_size - - error = loss_1.sub(loss).abs().max() - torch.distributed.barrier() - print(' loss error on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 5.0e-6 - - my_lin_grad_list = torch.split( - attention_layer_1.query_key_value.weight.grad, - hidden_size // tensor_model_parallel_size, 0)[rank::tensor_model_parallel_size] - my_lin_grad = torch.cat(my_lin_grad_list, dim=0) - error = my_lin_grad.sub( - attention_layer.query_key_value.weight.grad).abs().max() - torch.distributed.barrier() - print(' weight gradient error on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 5.0e-6 - - error = identity_layer_1.weight.grad.sub( - identity_layer.weight.grad).abs().max() - torch.distributed.barrier() - print(' input gradient error on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 5.0e-6 - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print(' >> passed the test :-)') - - -def parallel_transformer(tensor_model_parallel_size, num_att_heads_per_partition, - hidden_size_per_att_head, batch_size, sequence_length): - - mpu.initialize_model_parallel(tensor_model_parallel_size) - tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() - - seed = 12345 - set_random_seed(seed) - - num_att_heads = num_att_heads_per_partition * \ - torch.distributed.get_world_size() - hidden_size = hidden_size_per_att_head * num_att_heads - intermediate_size = 4 * hidden_size - - # Network - identity_layer = IdentityLayer3D(batch_size, sequence_length, - hidden_size).cuda() - transformer_layer = mpu.BertParallelTransformerLayer( - hidden_size, intermediate_size, num_att_heads, 0.0, 0.0, - torch.nn.functional.relu, 1.0e-5).cuda() - - loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda() - attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda() - # Forward - input_ = identity_layer() - output = transformer_layer(input_, attention_mask) - loss = torch.mul(output, loss_weight).sum() - # Backward - loss.backward() - - rank = mpu.get_tensor_model_parallel_rank() - mpu.destroy_model_parallel() - return rank, hidden_size, tensor_model_parallel_size, loss, \ - transformer_layer, identity_layer - - -def test_parallel_transformer_layer(tensor_model_parallel_size): - - if torch.distributed.get_rank() == 0: - print('> testing ParallelTransformerLayer with model parallel ' - 'size: {}'.format(tensor_model_parallel_size)) - - num_att_heads_per_partition = 3 - hidden_size_per_att_head = 7 - batch_size = 5 - sequence_length = 13 - - rank_1, hidden_size_1, tensor_model_parallel_size_1, loss_1, \ - transformer_layer_1, identity_layer_1 = parallel_transformer( - 1, num_att_heads_per_partition, - hidden_size_per_att_head, batch_size, sequence_length) - - rank, hidden_size, tensor_model_parallel_size, loss, \ - transformer_layer, identity_layer = parallel_transformer( - tensor_model_parallel_size, num_att_heads_per_partition, - hidden_size_per_att_head, batch_size, sequence_length) - - error = loss_1.sub(loss).abs().max() - torch.distributed.barrier() - print(' loss error on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 5.0e-5, 'error: {}'.format(error) - - error = identity_layer_1.weight.grad.sub( - identity_layer.weight.grad).abs().max() - torch.distributed.barrier() - print(' input gradient error on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 5.0e-5, 'error: {}'.format(error) - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print(' >> passed the test :-)') - - -if __name__ == '__main__': - - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - - initialize_distributed() - world_size = torch.distributed.get_world_size() - - print_separator('test initialize affine weight') - tensor_model_parallel_size = 1 - while tensor_model_parallel_size <= world_size: - test_initialize_affine_weight(tensor_model_parallel_size) - tensor_model_parallel_size *= 2 - - tensor_model_parallel_size = 1 - while tensor_model_parallel_size <= world_size: - print_separator('test parallel embedding') - test_parallel_embedding(tensor_model_parallel_size) - tensor_model_parallel_size *= 2 - - print_separator('test column-parallel linear') - tensor_model_parallel_size = 1 - while tensor_model_parallel_size <= world_size: - test_column_parallel_linear(tensor_model_parallel_size) - tensor_model_parallel_size *= 2 - - print_separator('test row-parallel linear') - tensor_model_parallel_size = 1 - while tensor_model_parallel_size <= world_size: - test_row_parallel_linear(tensor_model_parallel_size) - tensor_model_parallel_size *= 2 - - print_separator('test parallel self-attention') - tensor_model_parallel_size = 1 - while tensor_model_parallel_size <= world_size: - test_parallel_self_attention(tensor_model_parallel_size) - tensor_model_parallel_size *= 2 - - print_separator('test parallel transformer') - tensor_model_parallel_size = 1 - while tensor_model_parallel_size <= world_size: - test_parallel_transformer_layer(tensor_model_parallel_size) - tensor_model_parallel_size *= 2 diff --git a/megatron/legacy/mpu/tests/test_random.py b/megatron/legacy/mpu/tests/test_random.py deleted file mode 100644 index 26092772cf..0000000000 --- a/megatron/legacy/mpu/tests/test_random.py +++ /dev/null @@ -1,191 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -from commons import print_separator -from commons import initialize_distributed -import mpu -import torch -import sys -sys.path.append("../..") - - -def test_set_cuda_rng_state(tensor_model_parallel_size): - - if torch.distributed.get_rank() == 0: - print('> testing set_rng_state with size {} ...'. - format(tensor_model_parallel_size)) - - mpu.initialize_model_parallel(tensor_model_parallel_size) - tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() - - size = 123 - seed = 1234 - torch.cuda.manual_seed(1234) - tensor = torch.tensor(size, dtype=torch.float, device='cuda') - - # Get the state - rng_state = torch.cuda.get_rng_state() - rng_state_copy = rng_state.clone() - - # Do some stuff. - for _ in range(5): - torch.randn(size, out=tensor) - result_1 = tensor.clone() - - assert rng_state.sub(rng_state_copy).max() == 0 - assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0 - - # State should be different. - new_rng_state = torch.cuda.get_rng_state() - max_diff = new_rng_state.sub(rng_state).max() - print(' max diff in rng state (should be non-zero) on global rank {}: {}'. - format(torch.distributed.get_rank(), max_diff)) - assert max_diff > 0 - - # Reset the rng state and do the same stuff. - mpu.random._set_cuda_rng_state(rng_state) - for _ in range(5): - torch.randn(size, out=tensor) - mpu.random._set_cuda_rng_state(rng_state) - for _ in range(5): - torch.randn(size, out=tensor) - result_2 = tensor.clone() - - # Results should be the same - error = result_2.sub(result_1).abs().max() - print(' max error in generated tensors (should be zero) on ' - 'global rank {}: {}'.format(torch.distributed.get_rank(), error)) - assert error < 1.0e-6 - - # Input state should have remained intact. - error = rng_state.sub(rng_state_copy).max() - print(' max error in rng state (should be zero) on global rank {}: {}'. - format(torch.distributed.get_rank(), error)) - assert error == 0 - - # Reset groups - mpu.destroy_model_parallel() - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print('>> passed the test :-)') - - -def test_cuda_rng_tracker(tensor_model_parallel_size): - - if torch.distributed.get_rank() == 0: - print('> testing cuda rng tracker with size {} ...'. - format(tensor_model_parallel_size)) - - mpu.initialize_model_parallel(tensor_model_parallel_size) - tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() - - seed_1 = 1234 - seed_2 = 4321 - size = [12, 21] - tensor = torch.tensor(size, dtype=torch.float, device='cuda') - - # Set to seed_1 and generate two tensors. - torch.cuda.manual_seed(seed_1) - torch.randn(size, out=tensor) - target_11 = tensor.clone() - torch.randn(size, out=tensor) - target_12 = tensor.clone() - - # Set to seed_2 and generate two tensors. - torch.cuda.manual_seed(seed_2) - torch.randn(size, out=tensor) - target_21 = tensor.clone() - torch.randn(size, out=tensor) - target_22 = tensor.clone() - - # Now if we interleave seed_1 and seed_2, - # we should still get the same tensors - torch.cuda.manual_seed(seed_1) - mpu.get_cuda_rng_tracker().add('test', seed_2) - - torch.randn(size, out=tensor) - result_11 = tensor.clone() - - with mpu.get_cuda_rng_tracker().fork('test'): - torch.randn(size, out=tensor) - result_21 = tensor.clone() - - torch.randn(size, out=tensor) - result_12 = tensor.clone() - - with mpu.get_cuda_rng_tracker().fork('test'): - torch.randn(size, out=tensor) - result_22 = tensor.clone() - - diff = result_11.sub(result_21).abs().max() - diff = min(diff, result_12.sub(result_22).abs().max()) - print(' max diff in generated tensors (should be non-zero) on ' - 'global rank {}: {}'.format(torch.distributed.get_rank(), diff)) - assert diff > 1.0e-6 - error = max(result_11.sub(target_11).abs().max(), - result_12.sub(target_12).abs().max()) - error = max(error, result_21.sub(target_21).abs().max()) - error = max(error, result_22.sub(target_22).abs().max()) - print(' max error in generated tensors (should be zero) on ' - 'global rank {}: {}'.format(torch.distributed.get_rank(), error)) - assert error < 1.0e-6 - - # Reset the tracker - mpu.get_cuda_rng_tracker().reset() - - # Reset groups - mpu.destroy_model_parallel() - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print('>> passed the test :-)') - - -def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size): - - if torch.distributed.get_rank() == 0: - print('> testing model parallel cuda manual seed with size {} ...'. - format(tensor_model_parallel_size)) - - mpu.initialize_model_parallel(tensor_model_parallel_size) - tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() - - mpu.model_parallel_cuda_manual_seed(12345) - assert torch.cuda.initial_seed() == 12345 - with mpu.get_cuda_rng_tracker().fork(): - assert torch.cuda.initial_seed() == (12345 + 2718 + - mpu.get_tensor_model_parallel_rank()) - - # Reset the tracker - mpu.get_cuda_rng_tracker().reset() - - # Reset groups - mpu.destroy_model_parallel() - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print('>> passed the test :-)') - - -if __name__ == '__main__': - - initialize_distributed() - world_size = torch.distributed.get_world_size() - - tensor_model_parallel_size = 1 - while tensor_model_parallel_size <= world_size: - print_separator('test set rng state') - test_set_cuda_rng_state(tensor_model_parallel_size) - tensor_model_parallel_size *= 2 - - tensor_model_parallel_size = 1 - while tensor_model_parallel_size <= world_size: - print_separator('test cuda rng tracker') - test_cuda_rng_tracker(tensor_model_parallel_size) - tensor_model_parallel_size *= 2 - - tensor_model_parallel_size = 1 - while tensor_model_parallel_size <= world_size: - print_separator('test model parallel cuda manual seed') - test_model_parallel_cuda_manual_seed(tensor_model_parallel_size) - tensor_model_parallel_size *= 2 diff --git a/megatron/training/.arguments.py.swo b/megatron/training/.arguments.py.swo new file mode 100644 index 0000000000..a563475cf4 Binary files /dev/null and b/megatron/training/.arguments.py.swo differ diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index bce01066fc..bca44120c0 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1265,6 +1265,7 @@ def _add_transformer_engine_args(parser): group.add_argument('--no-fp8-wgrad', action='store_false', help='Execute wgrad in higher precision even for FP8 runs', dest='fp8_wgrad') + # import pdb;pdb.set_trace() group.add_argument('--transformer-impl', default='transformer_engine', choices=['local', 'transformer_engine'], help='Which Transformer implementation to use.') @@ -1739,6 +1740,28 @@ def _add_logging_args(parser): help='Path to save the wandb results locally.') group.add_argument('--logging-level', type=int, default=None, help='Set default logging level') + group.add_argument('--save-tensors', action='store_true', + help='Enable tensor saving for debugging and analysis.') + group.add_argument('--tensor-save-dir', type=str, default='./enhanced_tensor_logs', + help='Directory to save tensor logs (default: ./enhanced_tensor_logs)') + group.add_argument('--control-iter', type=int, default=None, + help='Number of iterations to collect tensors before stopping (default: None, no limit)') + group.add_argument('--scaling-control', type=str, default='max', choices=['max', 'max_minus_1'], + help='Scaling control strategy for MX quantization: max (default) or max_minus_1') + + # Time-resume adaptive quantization parameters + group.add_argument('--time-resume', action='store_true', + help='Enable time-resume adaptive quantization training') + group.add_argument('--quant-loss-threshold', type=float, default=0.1, + help='Loss threshold for switching from quantized to BF16 training') + group.add_argument('--quant-window-size', type=int, default=5, + help='Number of iterations per training window') + group.add_argument('--quant-checkpoint-interval', type=int, default=1, + help='Checkpoint save interval within windows (in iterations)') + group.add_argument('--quant-fallback-strategy', type=str, default='bf16', choices=['bf16', 'fp16'], + help='Fallback precision when quantized training fails') + group.add_argument('--quant-recovery-buffer', type=int, default=2, + help='Number of checkpoints to keep for recovery') return parser diff --git a/megatron/training/training.py b/megatron/training/training.py index e25647e3bf..8c46890cc9 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -1237,6 +1237,8 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch if isinstance(optim_instance, DistributedOptimizer): optim_instance._copy_main_params_to_param_buffer() + # Tensor saving is now handled automatically in the pipeline functions + # Forward pass. losses_reduced = forward_backward_func( forward_step_func=forward_step_func, @@ -1249,6 +1251,15 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch forward_only=False, adjust_tensor_shapes_fn=adjust_tensor_shapes_fn, ) + + # Mark tensor collection as completed if tensor saving is enabled + try: + from megatron.core.tensor_saver import get_tensor_saver + tensor_saver = get_tensor_saver() + if tensor_saver.enabled and tensor_saver.tensor_collected_in_warmup and not tensor_saver.collection_completed: + tensor_saver.mark_collection_completed() + except Exception as e: + pass # Silently ignore tensor saver errors should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit() if should_exit: return {}, True, should_checkpoint, should_exit, exit_code, None, None @@ -1802,6 +1813,7 @@ def checkpoint_and_decide_exit( """Save checkpoint and decide whether to exit based on arguments (e.g., if --exit-duration-in-mins is set). Actual exit happens in main training loop based on the return value of this function.""" + # import pdb;pdb.set_trace() args = get_args() timers = get_timers() @@ -1852,6 +1864,7 @@ def checkpoint_and_decide_exit( non_persistent_ckpt=True, train_data_iterator=train_data_iterator, ) + # import pdb;pdb.set_trace() saved_checkpoint = True # Exit based on duration. @@ -1862,6 +1875,7 @@ def checkpoint_and_decide_exit( ) torch.distributed.all_reduce(done_cuda, op=torch.distributed.ReduceOp.MAX) done = done_cuda.item() + # return False if done: if args.save and not saved_checkpoint: save_checkpoint_and_time( @@ -1874,8 +1888,9 @@ def checkpoint_and_decide_exit( train_data_iterator=train_data_iterator, ) print_datetime(f'exiting program after {train_time} minutes') - + # import pdb;pdb.set_trace() return True + # return False # Exit based on iterations. if args.exit_interval and iteration % args.exit_interval == 0: @@ -1890,6 +1905,7 @@ def checkpoint_and_decide_exit( train_data_iterator=train_data_iterator, ) print_datetime(f'exiting program at iteration {iteration}') + # import pdb;pdb.set_trace() return True @@ -1956,6 +1972,8 @@ def train( write_args_to_tensorboard() # Turn on training mode which enables dropout. + # import pdb;pdb.set_trace() + print(f"model:{model}") for model_module in model: model_module.train() @@ -2121,6 +2139,20 @@ def get_e2e_base_metrics(): ) cuda_graph_helper.create_cudagraphs() + # Initialize adaptive quantization manager if time-resume is enabled + adaptive_quantization_manager = None + if getattr(args, 'time_resume', False): + try: + from megatron.core.adaptive_quantization import get_adaptive_quantization_manager + adaptive_quantization_manager = get_adaptive_quantization_manager( + args, model, optimizer, opt_param_scheduler, + iteration, num_floating_point_operations_so_far + ) + print_rank_0("[TimeResume] Adaptive quantization training enabled") + except ImportError as e: + print_rank_0(f"[TimeResume] Failed to import adaptive quantization: {e}") + adaptive_quantization_manager = None + # Run training iterations till done. buffered_rollouts = None ref_state_dict = None @@ -2172,6 +2204,16 @@ def get_e2e_base_metrics(): continue args.curr_iteration = iteration + + # Update tensor saver iteration if tensor saving is enabled + if getattr(args, 'save_tensors', False): + try: + from megatron.core.tensor_saver import get_tensor_saver + tensor_saver = get_tensor_saver() + tensor_saver.set_iteration(iteration) + except Exception as e: + print(f"[Training] Warning: Failed to update tensor saver iteration: {e}") + # For GRPO, we keep the data for a few epochs. DeepSeekMath paper calls this number $\mu$. # It is similar to a PPO epoch. @@ -2200,6 +2242,17 @@ def get_e2e_base_metrics(): forward_step_func, train_data_iterator, model, optimizer, opt_param_scheduler, config, forward_backward_func ) ft_integration.on_training_step_end() + + # Check if tensor collection is completed and set should_exit if needed + if getattr(args, 'save_tensors', False): + try: + from megatron.core.tensor_saver import get_tensor_saver + tensor_saver = get_tensor_saver() + if tensor_saver.should_exit_after_forward(): + should_exit = True + except Exception as e: + pass # Silently ignore tensor saver errors + if should_checkpoint: save_checkpoint_and_time( iteration, @@ -2211,6 +2264,8 @@ def get_e2e_base_metrics(): train_data_iterator=train_data_iterator, ) if should_exit: + print("exit from log in line 2312") + import pdb;pdb.set_trace() break # Enable forward pre-hooks after first set of forward and backward passes. @@ -2235,6 +2290,14 @@ def get_e2e_base_metrics(): cuda_graph_helper.cuda_graph_set_manual_hooks() iteration += 1 + + # Check if we've reached the control_iter limit and exit if needed + control_iter = getattr(args, 'control_iter', None) + if control_iter is not None and iteration >= control_iter: + print_rank_0(f"[Training] Reached control_iter limit ({control_iter}), exiting training...") + # Exit the training loop early + break + batch_size = ( mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches() ) @@ -2282,6 +2345,35 @@ def get_e2e_base_metrics(): params_norm, num_zeros_in_grad, ) + + # Handle adaptive quantization if time-resume is enabled + if adaptive_quantization_manager is not None and not skipped_iter: + # Get current loss for adaptive quantization + current_loss = None + for key in loss_dict: + if 'loss' in key.lower(): + current_loss = loss_dict[key].item() + break + + if current_loss is not None: + # Check if we should switch precision + should_switch, new_precision = adaptive_quantization_manager.should_switch_precision(current_loss) + + if should_switch: + # Save checkpoint before switching + adaptive_quantization_manager.save_checkpoint_async(iteration, f"switch_to_{new_precision}") + + # Update precision + adaptive_quantization_manager.current_precision = new_precision + + print_rank_0(f"[TimeResume] Switched to {new_precision} training at iteration {iteration}") + + # Check if we should save checkpoint within window + if adaptive_quantization_manager.should_save_checkpoint(iteration): + adaptive_quantization_manager.save_checkpoint_async(iteration) + + # Update window state + adaptive_quantization_manager.update_window_state(iteration) # Evaluation. if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid: @@ -2343,6 +2435,8 @@ def get_e2e_base_metrics(): train_data_iterator, ) if should_exit: + # print("exit from log in line 2442") + # import pdb;pdb.set_trace() break one_logger_utils.track_e2e_metrics() @@ -2370,6 +2464,10 @@ def get_e2e_base_metrics(): print_rank_0(f"Total training energy (GPU): {total_energy / 1e6} MJ") energy_monitor.shutdown() + # Finalize adaptive quantization manager if enabled + if adaptive_quantization_manager is not None: + adaptive_quantization_manager.finalize() + # If any exit conditions (signal handler, duration, iterations) have been reached, exit. if should_exit: wandb_writer = get_wandb_writer() diff --git a/megatron/training/training.py.backup b/megatron/training/training.py.backup new file mode 100644 index 0000000000..59837fa525 --- /dev/null +++ b/megatron/training/training.py.backup @@ -0,0 +1,2856 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +"""Pretrain utilities.""" + +import dataclasses +from datetime import datetime +import gc +import logging +import math +import os +import sys +from typing import List, Optional + +import torch.distributed + +from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer +from .log_handler import CustomHandler + +# Make default logging level INFO, but filter out all log messages not from MCore. +logging.basicConfig(handlers=[CustomHandler()], level=logging.INFO) +from .theoretical_memory_usage import report_theoretical_memory +import time + +# The earliest we can measure the start time. +_TRAIN_START_TIME = time.time() +import torch + +try: + from megatron.training import rl_utils + has_rl_utils = True +except ImportError: + has_rl_utils = False +try: + from megatron.post_training.algos.distillation import ( + get_tensor_shapes_adjust_fn_for_distillation, + ) + + has_nvidia_modelopt = True +except ImportError: + has_nvidia_modelopt = False + +try: + from nvidia_resiliency_ext.inprocess import CallWrapper +except ImportError: + CallWrapper = type(None) + + +from megatron.core import mpu, tensor_parallel +from megatron.core.utils import ( + check_param_hashes_across_dp_replicas, + get_model_config, + StragglerDetector, +) +from megatron.core.fp8_utils import correct_amax_history_if_needed +from megatron.training.checkpointing import load_checkpoint +from megatron.training.checkpointing import save_checkpoint +from megatron.training.checkpointing import checkpoint_exists +from megatron.core.full_cuda_graph import FullCudaGraphWrapper +from megatron.core.transformer.cuda_graphs import TECudaGraphHelper +from megatron.core.transformer.module import Float16Module +from megatron.core.distributed import DistributedDataParallelConfig, TorchFullyShardedDataParallelConfig +from megatron.core.distributed import DistributedDataParallel as DDP +from megatron.core.distributed.fsdp.mcore_fsdp_adapter import FullyShardedDataParallel as megatron_FSDP +from megatron.core.optimizer.optimizer import param_group_identifier_keys + +try: + from megatron.core.distributed import TorchFullyShardedDataParallel as torch_FSDP + + HAVE_FSDP2 = True +except ImportError: + HAVE_FSDP2 = False + +from megatron.core.distributed import finalize_model_grads +from megatron.core.enums import ModelType +from megatron.core.optimizer import get_megatron_optimizer, OptimizerConfig +from megatron.core.rerun_state_machine import ( + get_rerun_state_machine, + destroy_rerun_state_machine, + RerunDataIterator, + RerunMode, +) +from megatron.training.initialize import initialize_megatron +from megatron.training.initialize import write_args_to_tensorboard +from megatron.training.initialize import set_jit_fusion_options +from megatron.training.utils import get_batch_on_this_cp_rank, get_batch_on_this_tp_rank +from megatron.legacy.data.data_samplers import build_pretraining_data_loader +from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler +from megatron.core.transformer.moe import upcycling_utils +from megatron.core.transformer.moe.moe_utils import track_moe_metrics +from megatron.core.transformer.multi_token_prediction import MTPLossLoggingHelper +from megatron.core.parallel_state import destroy_global_memory_buffer, destroy_model_parallel +from megatron.core.pipeline_parallel import get_forward_backward_func +from megatron.core.num_microbatches_calculator import ( + destroy_num_microbatches_calculator, + get_current_global_batch_size, + get_current_running_global_batch_size, + get_num_microbatches, + update_num_microbatches +) + +from .async_utils import maybe_finalize_async_save +from .utils import ( + append_to_progress_log, + calc_params_l2_norm, + check_adlr_autoresume_termination, + logical_and_across_model_parallel_group, + reduce_max_stat_across_model_parallel_group, + is_last_rank, + print_rank_0, + print_rank_last, + report_memory, + unwrap_model, + update_use_dist_ckpt, + to_empty_if_meta_device, +) +from .global_vars import ( + destroy_global_vars, + get_args, + get_signal_handler, + get_timers, + get_tensorboard_writer, + get_wandb_writer, + get_one_logger, + get_tokenizer, + get_energy_monitor, +) +from . import one_logger_utils + +from . import ft_integration + +stimer = StragglerDetector() + +from megatron.core.msc_utils import MultiStorageClientFeature, open_file + + +def destroy_global_state(): + destroy_global_vars() + destroy_num_microbatches_calculator() + destroy_global_memory_buffer() + destroy_model_parallel() + destroy_rerun_state_machine() + + +def print_datetime(string): + """Note that this call will sync across all ranks.""" + torch.distributed.barrier() + time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + print_rank_0(f'[{string}] datetime: {time_str} ') + + +def num_floating_point_operations(args, batch_size): + def calculate_layer_counts(): + """Calculate the number of attention, Mamba, and MLP layers.""" + if args.hybrid_override_pattern: + counts = {'M': 0, '*': 0, '-': 0} + for layer_type in args.hybrid_override_pattern: + if layer_type in counts: + counts[layer_type] += 1 + return counts['*'], counts['M'], counts['-'] + else: + num_attn_layers = round(args.num_layers * args.hybrid_attention_ratio) + num_mlp_layers = round(args.num_layers * args.hybrid_mlp_ratio) + num_mamba_layers = args.num_layers - num_attn_layers - num_mlp_layers + return num_attn_layers, num_mamba_layers, num_mlp_layers + + def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=False): + """Calculate FLOPs for an MLP layer.""" + scale_factor = 3.0 / 2.0 if swiglu else 1.0 + return 4 * expansion * scale_factor * batch_size * seq_len * hidden_size**2 + + def attn_layer_flops( + batch_size, seq_len, hidden_size, num_heads, gqa=True, gqa_groups=8, kv_channels=None + ): + """Calculate FLOPs for an attention layer.""" + p = (kv_channels * num_heads / hidden_size) if kv_channels else 1 + g = gqa_groups if gqa else num_heads + return ( + 4 + * batch_size + * seq_len + * hidden_size + * p + * (hidden_size + (hidden_size * (g / num_heads)) + (seq_len / 2)) + ) + + def mamba_layer_flops(batch_size, seq_len, hidden_size, state_dim=16, + head_dim=64, num_groups=1, num_heads=128): + """Calculate FLOPs for a Mamba layer.""" + # Note (rwaleffe): flops estimate for scan should be updated based on new SSD kernels, + # but small percent of overall layer flops + d_in = 2 * hidden_size + if num_heads: + nheads = num_heads + else: + nheads = d_in // head_dim + return ( + ( + 2 + * batch_size + * seq_len + * hidden_size + * (2 * d_in + 2 * num_groups * state_dim + nheads) + ) # in_proj + + (7 * batch_size * seq_len * d_in * state_dim) # scan + + (2 * batch_size * seq_len * d_in * hidden_size) # out_proj + ) + + def hybrid_flops(batch_size, seq_len, hidden_size, + num_attn_layers, num_mamba_layers, num_mlp_layers, + mamba_state_dim=128, mamba_head_dim=64, + mamba_num_groups=8, mamba_num_heads=128, + num_attn_heads=32,gqa=True, + gqa_groups=8, kv_channels=None, + mlp_expansion=4.0, swiglu=False, + vocab_size=256000): + """Calculate total FLOPs for the hybrid model.""" + flops_fwd = ( + num_attn_layers * attn_layer_flops(batch_size, seq_len, hidden_size, + num_attn_heads, gqa, gqa_groups, kv_channels) + + num_mlp_layers * mlp_layer_flops(batch_size, seq_len, hidden_size, + mlp_expansion, swiglu) + + num_mamba_layers * mamba_layer_flops(batch_size, seq_len, hidden_size, + mamba_state_dim, mamba_head_dim, + mamba_num_groups, mamba_num_heads) + + (2 * batch_size * seq_len * hidden_size * vocab_size) # logits computation + ) + return flops_fwd * 3 + + def transformer_flops(): + """Calculate FLOPs for a standard Transformer model.""" + # TODO(helenn/dnarayanan): Refactor this to reuse the helper methods. + # Attention projection size. + query_projection_size = args.kv_channels * args.num_attention_heads + query_projection_to_hidden_size_ratio = query_projection_size / args.hidden_size + # Group Query Attention. + if not args.group_query_attention: + args.num_query_groups = args.num_attention_heads + # MoE. + if args.num_experts is None: + # Every Transformer MLP is dense. + num_dense_layers = args.num_layers + num_moe_layers = 0 + num_experts_routed_to = 0 + last_layer_is_moe = 0 + else: + # Calculate number of dense and MoE Transformer MLPs. + if isinstance(args.moe_layer_freq, int): + moe_layer_pattern = [ + 1 if (i % args.moe_layer_freq == 0) else 0 for i in range(args.num_layers) + ] + elif isinstance(args.moe_layer_freq, list): + moe_layer_pattern = args.moe_layer_freq + else: + raise RuntimeError("Illegal --moe-layer-freq argument provided!") + assert len(moe_layer_pattern) == args.num_layers, ( + f"Invalid length of moe_layer_pattern: {len(moe_layer_pattern)}, " + f"expected {args.num_layers}, " + f"current moe layer pattern: {args.moe_layer_freq}" + ) + num_moe_layers = sum(moe_layer_pattern) # Number of 1s in `moe_layer_pattern`. + num_dense_layers = args.num_layers - num_moe_layers + num_experts_routed_to = args.moe_router_topk + last_layer_is_moe = moe_layer_pattern[-1] + + if args.mtp_num_layers is not None: + mtp_num_layers = args.mtp_num_layers + num_moe_layers += last_layer_is_moe * mtp_num_layers + num_dense_layers += (1 - last_layer_is_moe) * mtp_num_layers + num_layers = args.num_layers + mtp_num_layers + else: + mtp_num_layers = 0 + num_layers = args.num_layers + + moe_ffn_hidden_size = ( + args.moe_ffn_hidden_size + if args.moe_ffn_hidden_size is not None + else args.ffn_hidden_size + ) + shared_expert_ffn_hidden_size = ( + 0 + if args.moe_shared_expert_intermediate_size is None + else args.moe_shared_expert_intermediate_size + ) + # SwiGLU. + gated_linear_multiplier = 3 / 2 if args.swiglu else 1 + + # The 12x term below comes from the following factors; for more details, see + # "APPENDIX: FLOATING-POINT OPERATIONS" in https://arxiv.org/abs/2104.04473. + # - 3x: Each GEMM in the model needs to be performed 3 times (forward pass, + # backward wgrad [weight gradient], backward dgrad [data gradient]). + # - 2x: GEMMs of a particular size are stacked twice in the standard Transformer model + # architectures implemented in this codebase (e.g., h->ffn_h GEMM and ffn_h->h GEMM + # in MLP layer). + # - 2x: A GEMM of a m*n tensor with a n*k tensor requires 2mnk floating-point operations. + expansion_factor = 3 * 2 * 2 + + if args.multi_latent_attention: + assert not args.group_query_attention + ''' + Basic arithmetic + let B is batch size, s is seq_len, h is embedding dim, + for one self_attnetion block (prenorm is not included) + qkv projection: 6Bsh^2 + attn: 2Bs^2h + attn over value: 2Bs^2h + oproj: 2Bsh^2 + + references + https://arxiv.org/abs/2305.10403 + https://arxiv.org/abs/2205.05198 + ''' + ## MLA + if args.q_lora_rank is None: + q_term = ( + args.hidden_size + * args.num_attention_heads + * (args.qk_head_dim + args.qk_pos_emb_head_dim) + ) + else: + q_term = args.q_lora_rank * ( + args.hidden_size + + args.num_attention_heads * (args.qk_head_dim + args.qk_pos_emb_head_dim) + + 1 + ) + self_attn_term = ( + 3 + * 2 # fwd(1) + bwd(2) *FMA + * num_layers + * ( + ## q lora + rope + q norm + q_term + ## kv lora + rope + kv norm + + args.kv_lora_rank + * ( + args.hidden_size + + args.num_attention_heads * (args.qk_head_dim + args.v_head_dim) + + 1 + ) + + args.hidden_size * args.qk_pos_emb_head_dim + ## o proj + + (args.num_attention_heads * args.v_head_dim) * args.hidden_size + ## core attn + + args.seq_length + * (args.num_attention_heads * (args.qk_head_dim + args.qk_pos_emb_head_dim)) + / 2 + + args.seq_length * args.num_attention_heads * args.v_head_dim / 2 + ) + ) + + else: + ## MHA or GQA + self_attn_term = ( + expansion_factor + * num_layers + * args.hidden_size + * args.hidden_size + * ( + ( + 1 + + (args.num_query_groups / args.num_attention_heads) + # # Only half of the attention matrix is non-zero and needs to be multiplied with V. + + (args.seq_length / args.hidden_size / 2) + ) + * query_projection_to_hidden_size_ratio + ) + ) + + total_floating_point_operations = ( + batch_size + * args.seq_length + * ( + # MLP + expansion_factor + * num_layers + * args.hidden_size + * ( + # dense layer (deepseek v2, v3 style) + (args.ffn_hidden_size * gated_linear_multiplier) + * (num_dense_layers / num_layers) + # routed experts + + (moe_ffn_hidden_size * num_experts_routed_to * gated_linear_multiplier) + * (num_moe_layers / num_layers) + # Shared Experts. + + (shared_expert_ffn_hidden_size * gated_linear_multiplier) + * (num_moe_layers / num_layers) + ) + # Self Attention + + self_attn_term + # MTP norms and proj + + 3 + * 2 + * mtp_num_layers + * ( + # MTP eh norm + final nrom + 3 * args.hidden_size + # MTH eh proj + + 2 * args.hidden_size * args.hidden_size + ) + # Logit. + + 3 * 2 * args.hidden_size * args.padded_vocab_size * (mtp_num_layers + 1) + ) + ) + return total_floating_point_operations + + # Main entrypoint for FLOPs calculation. + if args.is_hybrid_model: + # Calculate the number of each type of layer. + num_attn_layers, num_mamba_layers, num_mlp_layers = calculate_layer_counts() + + # Compute hybrid model FLOPs. + return hybrid_flops( + batch_size=batch_size, + seq_len=args.seq_length, + hidden_size=args.hidden_size, + num_attn_layers=num_attn_layers, + num_mamba_layers=num_mamba_layers, + num_mlp_layers=num_mlp_layers, + mamba_state_dim=args.mamba_state_dim, + mamba_head_dim=args.mamba_head_dim, + mamba_num_groups=args.mamba_num_groups, + mamba_num_heads=args.mamba_num_heads, + num_attn_heads=args.num_attention_heads, + gqa=args.group_query_attention, + gqa_groups=args.num_query_groups, + kv_channels=args.kv_channels, + mlp_expansion=args.ffn_hidden_size / args.hidden_size, + swiglu=args.swiglu, + vocab_size=args.padded_vocab_size, + ) + else: + # Compute standard Transformer model FLOPs. + return transformer_flops() + + +def get_start_time_from_progress_log(): + """ + Gets start time of earliest job with same world size. Also returns the number + of floating-point operations completed in last saved checkpoint. + """ + args = get_args() + assert args.save is not None + progress_log_filename = os.path.join(args.save, "progress.txt") + + # start_time is time when job with same world size started. + # start_num_floating_point_operations is the number of floating-point operations + # completed when this job started. + # latest_num_floating_point_operations is the number of floating-point operations + # completed in most recent saved checkpoint. + start_time = None + start_num_floating_point_operations = None + latest_num_floating_point_operations = 0 + + def _get_field(string, type): + return type(string.split(': ')[1]) + + with open_file(progress_log_filename, 'r') as f: + for line in f: + line = line.strip() + line_tokens = line.split('\t') + world_size_in_line = _get_field(line_tokens[2], int) + if line_tokens[3] == "Saved checkpoint": + latest_num_floating_point_operations = _get_field(line_tokens[7], float) + if world_size_in_line != args.world_size: + # Re-start search if we see a different world size. + start_time = None + start_num_floating_point_operations = None + continue + if line_tokens[3] == "Starting job": + if start_time is None: + start_time = line_tokens[0] + start_num_floating_point_operations = latest_num_floating_point_operations + assert ( + start_time is not None and start_num_floating_point_operations is not None + ), "Should have seen at least one 'Starting job' entry with same world_size" + return datetime.strptime(start_time, '%Y-%m-%d %H:%M:%S'), start_num_floating_point_operations + + +def preprocess_common_state_dict(common_state_dict): + import copy + + # Convert args key of type namespace to dictionary + preprocessed_common_state_dict = copy.deepcopy(common_state_dict) + preprocessed_common_state_dict['args'] = vars(preprocessed_common_state_dict['args']) + # Remove rank and local rank from state dict if it exists, since they are expected to be different + preprocessed_common_state_dict['args'].pop('local_rank', None) + preprocessed_common_state_dict['args'].pop('rank', None) + if ( + preprocessed_common_state_dict['args']['use_distributed_optimizer'] + and "optimizer" in preprocessed_common_state_dict + ): + def reorder_inner_param_groups(optimizer_state_dict): + # When distributed optimizer loading, source param groups will be reordered, + # so we reorder the param groups here to prevent warning. + + # Pop empty param_state. + if "param_state" in optimizer_state_dict and not optimizer_state_dict["param_state"]: + optimizer_state_dict.pop("param_state") + + # Reorder param groups. + if "optimizer" not in optimizer_state_dict: + return + inner_optimizer = optimizer_state_dict["optimizer"] + if "param_groups" not in inner_optimizer: + return + param_groups = inner_optimizer["param_groups"] + key_fn = lambda pg: [pg[key] for key in param_group_identifier_keys] + param_groups.sort(key=key_fn) + inner_optimizer["param_groups"] = param_groups + + optimizer_state_dict = preprocessed_common_state_dict['optimizer'] + if "optimizer" in optimizer_state_dict: + # Only 1 optimizer in chained optimizer. + reorder_inner_param_groups(optimizer_state_dict) + else: + # Multiple optimizers in chained optimizer. + for i in range(len(optimizer_state_dict)): + if i in optimizer_state_dict.keys(): + reorder_inner_param_groups(optimizer_state_dict[i]) + + return preprocessed_common_state_dict + + +def pretrain( + train_valid_test_dataset_provider, + model_provider, + model_type, + forward_step_func, + process_non_loss_data_func=None, + extra_args_provider=None, + args_defaults={}, + get_embedding_ranks=None, + get_position_embedding_ranks=None, + non_loss_data_func=None, + store=None, + inprocess_call_wrapper: Optional[CallWrapper] = None, +): + """Main training program. + + This function will run the followings in the order provided: + 1) initialize Megatron. + 2) setup model, optimizer and lr schedule using the model_provider. + 3) call train_val_test_data_provider to get train/val/test datasets. + 4) train the model using the forward_step_func. + + Args: + train_valid_test_dataset_provider: a function that takes the size of + train/valid/test dataset and returns `train, valid, test` datasets. + model_provider: a function that returns a vanilla version of the + model. By vanilla we mean a simple model on cpu with no fp16 or ddp. + model_type: an enum that specifies the type of model being trained. + forward_step_func: a function that takes a `data iterator` and `model`, + and returns a `loss` scalar with a dictionary with key:values being + the info we would like to monitor during training, for example + `lm-loss: value`. We also require that this function add + `batch generator` to the timers class. + process_non_loss_data_func: a function to post process outputs of the + network. It can be used for dumping output tensors (e.g images) to + tensorboard. It takes `collected data`(list of tensors), + `current iteration index` and `tensorboard writer` as arguments. + extra_args_provider: a function that takes a parser and adds arguments + to it. It is used for programs to add their own arguments. + args_defaults: a dictionary from argument-name to argument-value. It + to set already parse arguments. + get_embedding_ranks (TODO): + get_position_embedding_ranks (TODO): + non_loss_data_func (callable): A custom function to call during evaluation. + It can run e.g. benchmarks. + store: an optional instance of torch.distributed.Store, to be used by + torch.distributed.init_process_group + inprocess_call_wrapper: an optional instance of inprocess.CallWrapper, + it is automatically injected when in-process restart is in use + """ + + if inprocess_call_wrapper is not None: + iteration = inprocess_call_wrapper.iteration + store = torch.distributed.PrefixStore(str(iteration), store) + + # Initalize and get arguments, timers, and Tensorboard writer. + initialize_megatron( + extra_args_provider=extra_args_provider, + args_defaults=args_defaults, + get_embedding_ranks=get_embedding_ranks, + get_position_embedding_ranks=get_position_embedding_ranks, + store=store, + ) + + args = get_args() + timers = get_timers() + + if args.log_progress: + append_to_progress_log("Starting job") + + # Initialize fault tolerance + # NOTE: ft_integration functions other than `setup` are no-op if the FT is not initialized + if args.enable_ft_package: + ft_integration.setup(args) + ft_integration.maybe_setup_simulated_fault() + + # Set pytorch JIT layer fusion options and warmup JIT functions. + set_jit_fusion_options() + + # Adjust the startup time so it reflects the largest value. + # This will be closer to what scheduler will see (outside of + # image ... launches. + global _TRAIN_START_TIME + start_time_tensor = torch.tensor([_TRAIN_START_TIME], dtype=torch.double, device='cuda') + torch.distributed.all_reduce(start_time_tensor, op=torch.distributed.ReduceOp.MIN) + _TRAIN_START_TIME = start_time_tensor.item() + + app_metrics = {} + app_metrics['app_start_time'] = round(_TRAIN_START_TIME * 1000.0) + app_metrics['app_model_init_start_time'] = round(_TRAIN_START_TIME * 1000.0) + + print_rank_0( + 'time to initialize megatron (seconds): {:.3f}'.format(time.time() - _TRAIN_START_TIME) + ) + print_datetime('after megatron is initialized') + app_metrics['app_model_init_finish_time'] = one_logger_utils.get_timestamp_in_ms() + + # Track E2E metrics on pretrain start + one_logger_utils.on_pretrain_start() + + # Context used for persisting some state between checkpoint saves. + if args.non_persistent_ckpt_type == 'local': + try: + from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import ( + LocalCheckpointManager, + ) + from nvidia_resiliency_ext.checkpointing.local.replication.group_utils import ( + parse_group_sequence, + GroupWrapper, + ) + from nvidia_resiliency_ext.checkpointing.local.replication.strategies import ( + CliqueReplicationStrategy, + ) + except ModuleNotFoundError: + raise RuntimeError( + "The 'nvidia_resiliency_ext' module is required for local " + "checkpointing but was not found. Please ensure it is installed." + ) + + if args.replication: + repl_strategy = CliqueReplicationStrategy.from_replication_params( + args.replication_jump, args.replication_factor + ) + else: + repl_strategy = None + + checkpointing_context = { + 'local_checkpoint_manager': LocalCheckpointManager( + args.non_persistent_local_ckpt_dir, repl_strategy=repl_strategy + ) + } + else: + checkpointing_context = {} + + # Model, optimizer, and learning rate. + timers('model-and-optimizer-setup', log_level=0).start(barrier=True) + model, optimizer, opt_param_scheduler = setup_model_and_optimizer( + model_provider, model_type, checkpointing_context=checkpointing_context + ) + + timers('model-and-optimizer-setup').stop() + print_datetime('after model, optimizer, and learning rate ' 'scheduler are built') + config = get_model_config(model[0]) + + # Data stuff. + app_metrics['app_build_dataiters_start_time'] = one_logger_utils.get_timestamp_in_ms() + timers('train/valid/test-data-iterators-setup', log_level=0).start(barrier=True) + if args.virtual_pipeline_model_parallel_size is not None: + train_data_iterator = [] + valid_data_iterator = [] + test_data_iterator = [] + for i in range(len(model)): + iterators = build_train_valid_test_data_iterators(train_valid_test_dataset_provider) + train_data_iterator.append(iterators[0]) + valid_data_iterator.append(iterators[1]) + test_data_iterator.append(iterators[2]) + else: + train_data_iterator, valid_data_iterator, test_data_iterator = ( + build_train_valid_test_data_iterators(train_valid_test_dataset_provider) + ) + timers('train/valid/test-data-iterators-setup').stop() + print_datetime('after dataloaders are built') + app_metrics['app_build_dataiters_finish_time'] = one_logger_utils.get_timestamp_in_ms() + + # Track if training is enabled. Can only be done once args.do_train is assigned after dataloader is built. + one_logger_utils.track_config_flags( + args.train_iters, + args.skip_train, + args.do_train, + args.do_valid, + args.do_test, + args.dataloader_type, + args.retro_project_dir, + args.retro_cyclic_train_iters, + ) + + # Print setup timing. + print_rank_0('done with setup ...') + timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'], barrier=True) + + one_logger = get_one_logger() + one_logger and one_logger.log_metrics(app_metrics) + + wandb_writer = get_wandb_writer() + if wandb_writer: + # Add job name to the wandb config to make it easier to run more singleton dependency jobs. + wandb_writer.config.update({'slurm_job_name': os.getenv("SLURM_JOB_NAME", "N/A")}) + + if not args.skip_train: + print_rank_0('training ...') + + if args.dataloader_type == 'cyclic' and args.retro_project_dir: + assert args.retro_cyclic_train_iters is not None + args.train_iters = args.retro_cyclic_train_iters + print_rank_0("retro cyclic train iters : %d" % args.train_iters) + + iteration = 0 + if args.do_train and args.train_iters > 0: + iteration, num_floating_point_operations_so_far = train( + forward_step_func, + model, + optimizer, + opt_param_scheduler, + train_data_iterator, + valid_data_iterator, + process_non_loss_data_func, + config, + checkpointing_context, + non_loss_data_func, + ) + + print_datetime('after training is done') + + if args.save and iteration != 0 and iteration % args.save_interval != 0: + save_checkpoint( + iteration, + model, + optimizer, + opt_param_scheduler, + num_floating_point_operations_so_far, + checkpointing_context, + train_data_iterator=train_data_iterator, + preprocess_common_state_dict_fn=preprocess_common_state_dict, + ) + + one_logger and one_logger.log_metrics( + {'app_train_loop_finish_time': one_logger_utils.get_timestamp_in_ms()} + ) + + else: + print_rank_0('skipping training (--skip-train is on) ...') + + iteration = args.iteration + + if args.do_valid: + prefix = f'iteration {iteration} on validation set' + if getattr(args, 'perform_rl_step', False): + rl_utils.evaluate_and_print_results_rl( + valid_data_iterator, model, optimizer, + iteration, write_to_tensorboard=not args.skip_train + ) + else: + evaluate_and_print_results( + prefix, forward_step_func, + valid_data_iterator, model, + iteration, process_non_loss_data_func, config, + verbose=True, write_to_tensorboard=not args.skip_train, + non_loss_data_func=non_loss_data_func + ) + + if args.do_test: + prefix = f'iteration {iteration} on test set' + evaluate_and_print_results( + prefix, + forward_step_func, + test_data_iterator, + model, + iteration, + process_non_loss_data_func, + config, + verbose=True, + write_to_tensorboard=not args.skip_train, + non_loss_data_func=non_loss_data_func, + ) + + wandb_writer = get_wandb_writer() + if wandb_writer: + wandb_writer.finish() + + ft_integration.on_checkpointing_start() + maybe_finalize_async_save(blocking=True, terminate=True) + ft_integration.on_checkpointing_end(is_async_finalization=True) + + one_logger and one_logger.log_metrics( + {'app_finish_time': one_logger_utils.get_timestamp_in_ms()} + ) + + ft_integration.shutdown() + one_logger_utils.finish() + + +def update_train_iters(args): + + # For iteration-based training, we don't need to do anything + if args.train_iters: + return + + # Constant batch size with sample-based training. + if args.rampup_batch_size is None: + args.train_iters = args.train_samples // args.global_batch_size + + else: + # Sample based training with rampup batch size. + iterations = 0 + consumed_samples = 0 + # Rampup phase. + while ( + consumed_samples <= int(args.rampup_batch_size[2]) + and consumed_samples <= args.train_samples + ): + update_num_microbatches(consumed_samples, consistency_check=False) + consumed_samples += get_current_global_batch_size() + iterations += 1 + # Reset + update_num_microbatches(0, consistency_check=False) + # Constant phase + # Note that we throw away any partial last batch. + if args.train_samples > consumed_samples: + iterations += (args.train_samples - consumed_samples) // args.global_batch_size + args.train_iters = iterations + + print_rank_0(f'setting training iterations to {args.train_iters}') + + +def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True): + """Build the model.""" + args = get_args() + args.model_type = model_type + + # Build model. + def build_model(): + if ( + mpu.get_pipeline_model_parallel_world_size() > 1 + and args.virtual_pipeline_model_parallel_size is not None + ): + model = [] + for i in range(args.virtual_pipeline_model_parallel_size): + # Set pre_process and post_process only after virtual rank is set. + pre_process = mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=i) + post_process = mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=i) + this_model = model_provider_func( + pre_process=pre_process, post_process=post_process, vp_stage=i) + this_model.model_type = model_type + this_model.vp_stage = i + model.append(this_model) + else: + pre_process = mpu.is_pipeline_first_stage() + post_process = mpu.is_pipeline_last_stage() + model = model_provider_func(pre_process=pre_process, post_process=post_process) + model.model_type = model_type + return model + + if args.init_model_with_meta_device: + with torch.device('meta'): + model = build_model() + else: + model = build_model() + + if not isinstance(model, list): + model = [model] + + # Set tensor model parallel attributes if not set. + # Only parameters that are already tensor model parallel have these + # attributes set for them. We should make sure the default attributes + # are set for all params so the optimizer can use them. + for model_module in model: + for param in model_module.parameters(): + tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param) + + # Print number of parameters. + num_parameters = sum( + [sum([p.nelement() for p in model_module.parameters()]) for model_module in model] + ) + if mpu.get_data_parallel_rank() == 0 and mpu.get_context_parallel_rank() == 0: + print( + ' > number of parameters on (tensor, pipeline) ' + 'model parallel rank ({}, {}): {}'.format( + mpu.get_tensor_model_parallel_rank(), + mpu.get_pipeline_model_parallel_rank(), + num_parameters, + ), + flush=True, + ) + + # GPU allocation. + # For FSDP2, we don't allocate GPU memory here. We allocate GPU memory + # in the fully_shard function of FSDP2 instead. + if ( + not (args.use_torch_fsdp2 and args.use_cpu_initialization) + and not args.init_model_with_meta_device + ): + for model_module in model: + model_module.cuda(torch.cuda.current_device()) + + # Fp16 conversion. + if args.fp16 or args.bf16: + config = get_model_config(model[0]) + model = [Float16Module(config, model_module) for model_module in model] + + # Materialize tensors on meta device (GPU allocation) if not using FSDP2. + if args.init_model_with_meta_device and not args.use_torch_fsdp2: + #for model_module in model: + model = [to_empty_if_meta_device(model_module, device=torch.device("cuda")) for model_module in model] + + + + + # Before TE2.x: The model_module.bfloat16()/model_module.half() above will call the inplace + # copy of TE's Float8Tensor, which will write an unwanted value (amax calculated + # from the current fp8 param) to its amax_history. The below function will correct + # the amax_history back. + # After TE2.x: Below function is an empty function and does nothing. + correct_amax_history_if_needed(model) + + if wrap_with_ddp: + if args.use_torch_fsdp2: + assert HAVE_FSDP2, "Torch FSDP2 requires torch>=2.4.0" + DP = torch_FSDP + elif args.use_megatron_fsdp: + DP = megatron_FSDP + else: + DP = DDP + + config = get_model_config(model[0]) + + if getattr(args, "use_torch_fsdp2", False): + reshard_after_forward = getattr(args, "torch_fsdp2_reshard_after_forward", True) + ddp_config = TorchFullyShardedDataParallelConfig(reshard_after_forward=reshard_after_forward) + else: + kwargs = {} + for f in dataclasses.fields(DistributedDataParallelConfig): + if hasattr(args, f.name): + kwargs[f.name] = getattr(args, f.name) + kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32 + kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad + kwargs['check_for_large_grads'] = args.check_for_large_grads + if args.ddp_num_buckets is not None: + assert args.ddp_bucket_size is None, \ + "Cannot specify both --ddp-num-buckets and --ddp-bucket-size" + assert args.ddp_num_buckets > 0, \ + "--ddp-num-buckets must be greater than 0" + kwargs['bucket_size'] = num_parameters // args.ddp_num_buckets + else: + kwargs['bucket_size'] = args.ddp_bucket_size + kwargs['pad_buckets_for_high_nccl_busbw'] = args.ddp_pad_buckets_for_high_nccl_busbw + kwargs['average_in_collective'] = args.ddp_average_in_collective + if args.use_megatron_fsdp and args.use_precision_aware_optimizer: + kwargs["preserve_fp32_weights"] = False + ddp_config = DistributedDataParallelConfig(**kwargs) + + # In the Megatron FSDP and DDP use path, we need to initialize the bucket size. + # If bucket_size is not provided as an input, use sane default. + # If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL + # ring-reduce implementations are large enough to remain bandwidth-bound rather than + # latency-bound. + if ddp_config.bucket_size is None: + ddp_config.bucket_size = max( + 40000000, 1000000 * mpu.get_data_parallel_world_size(with_context_parallel=True) + ) + # Set bucket_size to infinity if overlap_grad_reduce is False. + if not ddp_config.overlap_grad_reduce: + ddp_config.bucket_size = None + + with torch.cuda.stream(torch.cuda.Stream()): + model = [ + DP( + config=config, + ddp_config=ddp_config, + module=model_chunk, + # Turn off bucketing for model_chunk 2 onwards, since communication for these + # model chunks is overlapped with compute anyway. + disable_bucketing=(model_chunk_idx > 0) + or args.overlap_param_gather_with_optimizer_step, + ) + for (model_chunk_idx, model_chunk) in enumerate(model) + ] + + # Broadcast params from data parallel src rank to other data parallel ranks. + if args.data_parallel_random_init: + for model_module in model: + model_module.broadcast_params() + + return model + + +def get_optimizer_param_scheduler(optimizer): + """Build the learning rate scheduler.""" + args = get_args() + + # Iteration-based training. + if args.train_iters: + if args.lr_decay_iters is None: + args.lr_decay_iters = args.train_iters + lr_decay_steps = args.lr_decay_iters * args.global_batch_size + wd_incr_steps = args.train_iters * args.global_batch_size + wsd_decay_steps = None + if args.lr_wsd_decay_iters is not None: + wsd_decay_steps = args.lr_wsd_decay_iters * args.global_batch_size + if args.lr_warmup_fraction is not None: + lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps + else: + lr_warmup_steps = args.lr_warmup_iters * args.global_batch_size + # Sample-based training. + elif args.train_samples: + # We need to set training iters for later use. Technically + # we need to adjust the training samples too (due to last + # batch being incomplete) but we leave it as is for now. + update_train_iters(args) + if args.lr_decay_samples is None: + args.lr_decay_samples = args.train_samples + lr_decay_steps = args.lr_decay_samples + wd_incr_steps = args.train_samples + wsd_decay_steps = args.lr_wsd_decay_samples + if args.lr_warmup_fraction is not None: + lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps + else: + lr_warmup_steps = args.lr_warmup_samples + else: + raise Exception('either train-iters or train-samples should be provided.') + + opt_param_scheduler = OptimizerParamScheduler( + optimizer, + init_lr=args.lr_warmup_init, + max_lr=args.lr, + min_lr=args.min_lr, + lr_warmup_steps=lr_warmup_steps, + lr_decay_steps=lr_decay_steps, + lr_decay_style=args.lr_decay_style, + start_wd=args.start_weight_decay, + end_wd=args.end_weight_decay, + wd_incr_steps=wd_incr_steps, + wd_incr_style=args.weight_decay_incr_style, + use_checkpoint_opt_param_scheduler=args.use_checkpoint_opt_param_scheduler, + override_opt_param_scheduler=args.override_opt_param_scheduler, + wsd_decay_steps=wsd_decay_steps, + lr_wsd_decay_style=args.lr_wsd_decay_style, + ) + + return opt_param_scheduler + + +def setup_model_and_optimizer( + model_provider_func, + model_type, + no_wd_decay_cond=None, + scale_lr_cond=None, + lr_mult=1.0, + checkpointing_context=None, +): + """Setup model and optimizer.""" + args = get_args() + timers = get_timers() + one_logger = get_one_logger() + + model = get_model(model_provider_func, model_type) + unwrapped_model = unwrap_model(model) + + one_logger and one_logger.log_metrics({"app_build_optimzer_start_time": one_logger_utils.get_timestamp_in_ms()}) + kwargs = {} + for f in dataclasses.fields(OptimizerConfig): + if hasattr(args, f.name): + kwargs[f.name] = getattr(args, f.name) + config = OptimizerConfig(**kwargs) + config.timers = timers + optimizer = get_megatron_optimizer( + config, + model, + no_wd_decay_cond, + scale_lr_cond, + lr_mult, + use_gloo_process_groups=args.enable_gloo_process_groups, + # If the user is asking for a non-zero embedding init std, skip weight decay for embeddings + # to avoid embeddings from shrinking to zero as recommended in https://arxiv.org/abs/2312.16903 + default_skip_embedding_weight_decay=args.embedding_init_method_std is not None, + ) + opt_param_scheduler = get_optimizer_param_scheduler(optimizer) + one_logger and one_logger.log_metrics({"app_build_optimzer_finish_time": one_logger_utils.get_timestamp_in_ms()}) + + if args.moe_use_upcycling: + torch.distributed.barrier() + assert not checkpoint_exists(args.save), ( + "The upcycling destination directory already exists. " + "Please check if --moe-use-upcycling is mistakenly enabled. " + "Upcycling should only be set for the first run when converting the dense model. " + "All subsequent runs should remove this flag. " + ) + # before changing moe related global args, save them in local variables + num_experts = args.num_experts + expert_model_parallel_size = args.expert_model_parallel_size + moe_ffn_hidden_size = args.ffn_hidden_size + + # set dense model related args in to global args before getting dense model + args.num_experts = None + args.expert_model_parallel_size = 1 + args.ffn_hidden_size = moe_ffn_hidden_size * args.moe_upcycling_granularity + + # get dense model + dense_model_for_upcycling = get_model(model_provider_func, model_type) + + # recover moe upcycling related args in global args before executing upcycling + args.num_experts = num_experts + args.expert_model_parallel_size = expert_model_parallel_size + args.ffn_hidden_size = moe_ffn_hidden_size + + # execute upcycling + _, args.num_floating_point_operations_so_far = upcycling_utils.load_and_upcycle_model( + load_checkpoint, + unwrapped_model, + dense_model_for_upcycling, + load_kwargs={ + 'model': dense_model_for_upcycling, + 'optimizer': None, + 'opt_param_scheduler': None, + }, + ) + args.iteration = 1 + save_checkpoint( + args.iteration, model, None, None, args.num_floating_point_operations_so_far + ) + torch.distributed.barrier() + del dense_model_for_upcycling + if (args.fp16 or args.bf16) and optimizer is not None: + optimizer.reload_model_params() + print_rank_0(f'Upcycled checkpoint saved to {args.save}') + + if ( + args.load is not None or args.pretrained_checkpoint is not None + ) and not args.moe_use_upcycling: + one_logger and one_logger.log_metrics( + {'load_checkpoint_start_time': one_logger_utils.get_timestamp_in_ms()} + ) + timers('load-checkpoint', log_level=0).start(barrier=True) + + args.iteration, args.num_floating_point_operations_so_far = load_checkpoint( + model, + optimizer, + opt_param_scheduler, + checkpointing_context=checkpointing_context, + skip_load_to_model_and_opt=HAVE_FSDP2 + and getattr(args, "use_torch_fsdp2", False) + and args.ckpt_format == "torch_dist", + ) + timers('load-checkpoint').stop(barrier=True) + timers.log(['load-checkpoint']) + one_logger and one_logger.log_metrics( + { + 'load_checkpoint_finish_time': one_logger_utils.get_timestamp_in_ms(), + 'load_checkpoint_time': timers('load-checkpoint').active_time(), + } + ) + else: + args.iteration = 0 + args.num_floating_point_operations_so_far = 0 + + # get model without FP16 and/or DDP wrappers + if ( + args.iteration == 0 + and len(unwrapped_model) == 1 + and hasattr(unwrapped_model[0], 'init_state_dict_from_bert') + ): + print_rank_0("Initializing ICT from pretrained BERT model") + unwrapped_model[0].init_state_dict_from_bert() + if args.fp16: + optimizer.reload_model_params() + + # Convert checkpoint format. + if args.ckpt_convert_format is not None: + load_ckpt_format = args.ckpt_format + args.ckpt_format = args.ckpt_convert_format + args.save = os.path.join(args.ckpt_convert_save, args.ckpt_convert_format) + update_use_dist_ckpt(args) + + save_checkpoint( + args.iteration, + model, + optimizer, + opt_param_scheduler, + args.num_floating_point_operations_so_far, + preprocess_common_state_dict_fn=preprocess_common_state_dict, + ) + + print_rank_0("> converted checkpoint: %s -> %s." % (load_ckpt_format, args.ckpt_format)) + torch.distributed.barrier() + exit() + + return model, optimizer, opt_param_scheduler + + +def dummy_train_step(data_iterator): + """Single dummy training step.""" + num_microbatches = get_num_microbatches() + rerun_state_machine = get_rerun_state_machine() + while rerun_state_machine.should_run_forward_backward(data_iterator): + for _ in range(num_microbatches): + # Re-use methods used in get_batch() from pretrain_{gpt, mamba}.py. + batch = get_batch_on_this_tp_rank(data_iterator) + batch = get_batch_on_this_cp_rank(batch) + + +def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config, forward_backward_func): + """Single training step.""" + args = get_args() + timers = get_timers() + + rerun_state_machine = get_rerun_state_machine() + while rerun_state_machine.should_run_forward_backward(data_iterator): + # Set grad to zero. + for model_chunk in model: + model_chunk.zero_grad_buffer() + optimizer.zero_grad() + + if has_nvidia_modelopt: + # [ModelOpt]: Pipeline-parallel Distillation stacks student and teacher tensors + adjust_tensor_shapes_fn = get_tensor_shapes_adjust_fn_for_distillation( + model, args.seq_length, args.micro_batch_size, args.decoder_seq_length + ) + else: + adjust_tensor_shapes_fn = None + + # For the mxfp8_param with reuse_grad_buf_for_mxfp8_param_ag and dp_ag_overlap, + # we need to call the _copy_main_params_to_param_buffer() after the grad buffer + # is zeroed by zero_grad_buffer() because param and grad buffer are shared. + if args.reuse_grad_buf_for_mxfp8_param_ag and args.overlap_param_gather: + for optim_instance in optimizer.chained_optimizers: + if isinstance(optim_instance, DistributedOptimizer): + optim_instance._copy_main_params_to_param_buffer() + + # Update sample index for tensor saving if enabled + if getattr(args, 'save_tensors', False): + try: + from megatron.core.tensor_saver import get_tensor_collection_state + state = get_tensor_collection_state() + # Reset sample index for each training step + state.set_sample_idx(0) + except Exception as e: + print(f"[TrainStep] Warning: Failed to reset sample index: {e}") + + # Forward pass. + losses_reduced = forward_backward_func( + forward_step_func=forward_step_func, + data_iterator=data_iterator, + model=model, + num_microbatches=get_num_microbatches(), + seq_length=args.seq_length, + micro_batch_size=args.micro_batch_size, + decoder_seq_length=args.decoder_seq_length, + forward_only=False, + adjust_tensor_shapes_fn=adjust_tensor_shapes_fn, + ) + should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit() + if should_exit: + return {}, True, should_checkpoint, should_exit, exit_code, None, None + + # Empty unused memory. + if args.empty_unused_memory_level >= 1: + torch.cuda.empty_cache() + + # Vision gradients. + if args.vision_pretraining and args.vision_pretraining_type == "dino": + unwrapped_model = unwrap_model(model[0]) + unwrapped_model.cancel_gradients_last_layer(args.curr_iteration) + + # Update parameters. + + timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time) + update_successful, grad_norm, num_zeros_in_grad = optimizer.step() + timers('optimizer').stop() + + # when freezing sub-models we may have a mixture of successful and unsucessful ranks, + # so we must gather across mp ranks + update_successful = logical_and_across_model_parallel_group(update_successful) + # grad_norm and num_zeros_in_grad will be None on ranks without trainable params, + # so we must gather across mp ranks + grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm) + if args.log_num_zeros_in_grad: + num_zeros_in_grad = reduce_max_stat_across_model_parallel_group(num_zeros_in_grad) + + # Vision momentum. + if args.vision_pretraining and args.vision_pretraining_type == "dino": + unwrapped_model = unwrap_model(model[0]) + unwrapped_model.update_momentum(args.curr_iteration) + + # Update learning rate. + if update_successful: + increment = get_num_microbatches() * args.micro_batch_size * args.data_parallel_size + opt_param_scheduler.step(increment=increment) + skipped_iter = 0 + else: + skipped_iter = 1 + + # Empty unused memory. + if args.empty_unused_memory_level >= 2: + torch.cuda.empty_cache() + + if mpu.is_pipeline_last_stage(ignore_virtual=True): + # Average loss across microbatches. + loss_reduced = {} + + for key in losses_reduced[0].keys(): + val = [x[key].view(-1) for x in losses_reduced] + if val[0].numel() == 2: + if args.sft: + # in mcore the normalization happens on micro batch instead of global + val = torch.vstack(val) + val = val[:, 0] / val[:, 1] + val = val.mean() + torch.distributed.all_reduce( + val, + group=mpu.get_data_parallel_group(with_context_parallel=True) + ) + val /= torch.distributed.get_world_size( + group=mpu.get_data_parallel_group(with_context_parallel=True) + ) + loss_reduced[key] = val + else: + # there is one dict per microbatch. in new reporting, we average + # over the total number of tokens across the global batch. + val = torch.vstack(val).sum(dim=0) + torch.distributed.all_reduce( + val, + group=mpu.get_data_parallel_group(with_context_parallel=True) + ) + loss_reduced[key] = val[0] / val[1] + elif val[0].numel() == 1: + # legacy behavior, we average over the number of microbatches + val = torch.cat(val).mean() + loss_reduced[key] = val + else: + raise ValueError(f"Invalid value shape: {val[0].shape} for key {key}") + return ( + loss_reduced, + skipped_iter, + should_checkpoint, + should_exit, + exit_code, + grad_norm, + num_zeros_in_grad, + ) + return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad + + +def training_log( + loss_dict, + total_loss_dict, + learning_rate, + decoupled_learning_rate, + iteration, + loss_scale, + report_memory_flag, + skipped_iter, + grad_norm, + params_norm, + num_zeros_in_grad, +): + """Log training information such as losses, timing, ....""" + args = get_args() + timers = get_timers() + writer = get_tensorboard_writer() + wandb_writer = get_wandb_writer() + one_logger = get_one_logger() + energy_monitor = get_energy_monitor() + + # Advanced, skipped, and Nan iterations. + advanced_iters_key = 'advanced iterations' + skipped_iters_key = 'skipped iterations' + nan_iters_key = 'nan iterations' + # Advanced iterations. + if not skipped_iter: + total_loss_dict[advanced_iters_key] = total_loss_dict.get(advanced_iters_key, 0) + 1 + else: + if advanced_iters_key not in total_loss_dict: + total_loss_dict[advanced_iters_key] = 0 + # Skipped iterations. + total_loss_dict[skipped_iters_key] = total_loss_dict.get(skipped_iters_key, 0) + skipped_iter + # Update losses and set nan iterations + got_nan = False + for key in loss_dict: + if not skipped_iter: + total_loss_dict[key] = ( + total_loss_dict.get(key, torch.tensor([0.0], dtype=torch.float, device='cuda')) + + loss_dict[key] + ) + else: + value = loss_dict[key].float().sum().item() + is_nan = value == float('inf') or value == -float('inf') or value != value + got_nan = got_nan or is_nan + total_loss_dict[nan_iters_key] = total_loss_dict.get(nan_iters_key, 0) + int(got_nan) + + # Logging. + timers_to_log = [ + 'forward-backward', + 'forward-compute', + 'backward-compute', + 'batch-generator', + 'forward-recv', + 'forward-send', + 'backward-recv', + 'backward-send', + 'forward-send-forward-recv', + 'forward-send-backward-recv', + 'backward-send-forward-recv', + 'backward-send-backward-recv', + 'forward-backward-send-forward-backward-recv', + 'layernorm-grads-all-reduce', + 'embedding-grads-all-reduce', + 'all-grads-sync', + 'params-all-gather', + 'optimizer-copy-to-main-grad', + 'optimizer-unscale-and-check-inf', + 'optimizer-clip-main-grad', + 'optimizer-count-zeros', + 'optimizer-inner-step', + 'optimizer-copy-main-to-model-params', + 'optimizer', + ] + # Add timers from RL loop if needed. + if getattr(args, 'perform_rl_step', False): + timers_to_log.extend(['rollout-collection', 'rollout-collection-barrier', + 'compute-logprobs', 'compute-ref-logprobs', + 'prepare-advantages']) + + # Calculate batch size. + batch_size = args.micro_batch_size * args.data_parallel_size * get_num_microbatches() + + # Track app tag & app tag ID + one_logger_utils.track_app_tag(batch_size, args.world_size, args.seq_length) + + total_iterations = total_loss_dict[advanced_iters_key] + total_loss_dict[skipped_iters_key] + + # learning rate will be None on ranks without trainable params, so we must gather across mp ranks + learning_rate = reduce_max_stat_across_model_parallel_group(learning_rate) + # Tensorboard values. + if writer and (iteration % args.tensorboard_log_interval == 0): + if wandb_writer: + wandb_writer.log({'samples vs steps': args.consumed_train_samples}, iteration) + writer.add_scalar('learning-rate', learning_rate, iteration) + writer.add_scalar('learning-rate vs samples', learning_rate, args.consumed_train_samples) + if wandb_writer: + wandb_writer.log({'learning-rate': learning_rate}, iteration) + if args.decoupled_lr is not None: + writer.add_scalar('decoupled-learning-rate', decoupled_learning_rate, iteration) + if args.skipped_train_samples > 0: + writer.add_scalar('skipped-train-samples', args.skipped_train_samples, iteration) + if wandb_writer: + wandb_writer.log({'skipped-train-samples': args.skipped_train_samples}, iteration) + writer.add_scalar('batch-size', batch_size, iteration) + writer.add_scalar('batch-size vs samples', batch_size, args.consumed_train_samples) + if wandb_writer: + wandb_writer.log({'batch-size': batch_size}, iteration) + for key in loss_dict: + writer.add_scalar(key, loss_dict[key], iteration) + writer.add_scalar(key + ' vs samples', loss_dict[key], args.consumed_train_samples) + if wandb_writer: + wandb_writer.log({key: loss_dict[key]}, iteration) + if args.log_loss_scale_to_tensorboard: + writer.add_scalar('loss-scale', loss_scale, iteration) + writer.add_scalar('loss-scale vs samples', loss_scale, args.consumed_train_samples) + if wandb_writer: + wandb_writer.log({'loss-scale': loss_scale}, iteration) + if args.log_world_size_to_tensorboard: + writer.add_scalar('world-size', args.world_size, iteration) + writer.add_scalar('world-size vs samples', args.world_size, args.consumed_train_samples) + if wandb_writer: + wandb_writer.log({'world-size': args.world_size}, iteration) + if grad_norm is not None: + writer.add_scalar('grad-norm', grad_norm, iteration) + writer.add_scalar('grad-norm vs samples', grad_norm, args.consumed_train_samples) + if wandb_writer: + wandb_writer.log({'grad-norm': grad_norm}, iteration) + if num_zeros_in_grad is not None: + writer.add_scalar('num-zeros', num_zeros_in_grad, iteration) + writer.add_scalar( + 'num-zeros vs samples', num_zeros_in_grad, args.consumed_train_samples + ) + if wandb_writer: + wandb_writer.log({'num-zeros': num_zeros_in_grad}, iteration) + if params_norm is not None: + writer.add_scalar('params-norm', params_norm, iteration) + writer.add_scalar('params-norm vs samples', params_norm, args.consumed_train_samples) + if wandb_writer: + wandb_writer.log({'params-norm': params_norm}, iteration) + if getattr(args, 'perform_rl_step', False): + grpo_collection_iteration = iteration // (args.grpo_iterations * ( ( args.grpo_samples_per_iteration )// args.global_batch_size )) + writer.add_scalar('grpo_collection_iteration', grpo_collection_iteration, iteration) + if wandb_writer: + wandb_writer.log({'grpo_collection_iteration': grpo_collection_iteration}, iteration) + if args.log_memory_to_tensorboard: + mem_stats = torch.cuda.memory_stats() + writer.add_scalar( + "mem-reserved-bytes", mem_stats["reserved_bytes.all.current"], iteration + ) + writer.add_scalar( + "mem-allocated-bytes", mem_stats["allocated_bytes.all.current"], iteration + ) + writer.add_scalar( + "mem-max-allocated-bytes", mem_stats["allocated_bytes.all.peak"], iteration + ) + writer.add_scalar("mem-allocated-count", mem_stats["allocation.all.current"], iteration) + if args.num_experts is not None: + moe_loss_scale = 1 / get_num_microbatches() + track_names = [] + if "aux_loss" in args.moe_router_load_balancing_type: + track_names.append("load_balancing_loss") + if "seq_aux_loss" in args.moe_router_load_balancing_type: + track_names.append("seq_load_balancing_loss") + if "global_aux_loss" in args.moe_router_load_balancing_type: + track_names.append("global_load_balancing_loss") + if args.moe_z_loss_coeff is not None: + track_names.append("z_loss") + track_moe_metrics( + loss_scale=moe_loss_scale, + iteration=iteration, + writer=writer, + wandb_writer=wandb_writer, + total_loss_dict=total_loss_dict, + per_layer_logging=args.moe_per_layer_logging, + force_initialize=True, + track_names=track_names, + num_layers=args.num_layers, + moe_layer_freq=args.moe_layer_freq, + mtp_num_layers=args.mtp_num_layers, + ) + if args.mtp_num_layers is not None: + mtp_loss_scale = 1 / get_num_microbatches() + MTPLossLoggingHelper.track_mtp_metrics( + mtp_loss_scale, iteration, writer, wandb_writer, total_loss_dict + ) + if iteration % args.log_interval == 0: + if args.record_memory_history and is_last_rank(): + snapshot = torch.cuda.memory._snapshot() + from pickle import dump + + with open(args.memory_snapshot_path, 'wb') as f: + dump(snapshot, f) + + elapsed_time = timers('interval-time').elapsed(barrier=True) + elapsed_time_per_iteration = elapsed_time / total_iterations + + throughput = num_floating_point_operations(args, batch_size) / ( + elapsed_time_per_iteration * 10**12 * args.world_size + ) + + one_logger_utils.track_e2e_metrics(args.log_throughput, throughput) + + if args.log_timers_to_tensorboard: + if writer: + writer.add_scalar('iteration-time', elapsed_time_per_iteration, iteration) + if wandb_writer: + wandb_writer.log({'iteration-time': elapsed_time_per_iteration}, iteration) + log_string = f" [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]" + log_string += ' iteration {:8d}/{:8d} |'.format(iteration, args.train_iters) + log_string += ' consumed samples: {:12d} |'.format(args.consumed_train_samples) + if args.skipped_train_samples > 0: + log_string += ' skipped samples: {:12d} |'.format(args.skipped_train_samples) + log_string += ' elapsed time per iteration (ms): {:.1f} |'.format( + elapsed_time_per_iteration * 1000.0 + ) + if args.log_throughput: + log_string += f' throughput per GPU (TFLOP/s/GPU): {throughput:.1f} |' + if args.log_timers_to_tensorboard: + if writer: + writer.add_scalar('throughput', throughput, iteration) + if wandb_writer: + wandb_writer.log({'throughput': throughput}, iteration) + if args.log_energy: + energy = (energy_monitor.lap() / total_iterations) / args.world_size + power = energy / elapsed_time_per_iteration + log_string += f' energy per GPU (J/iter/GPU): {energy:.1f} |' + log_string += f' power per GPU (W/GPU): {power:.1f} |' + if writer: + writer.add_scalar('iter-energy/gpu', energy, iteration) + writer.add_scalar('power/gpu', power, iteration) + if wandb_writer: + wandb_writer.log({'iter-energy/gpu': energy}, iteration) + wandb_writer.log({'power/gpu': power}, iteration) + # Decoupled_learning_rate should be not None only on first and last pipeline stage. + log_string += f' learning rate: {learning_rate:.6E} |' + if args.decoupled_lr is not None and ( + mpu.is_pipeline_first_stage(ignore_virtual=True) + or mpu.is_pipeline_last_stage(ignore_virtual=True) + ): + assert decoupled_learning_rate is not None + log_string += f' decoupled learning rate: {decoupled_learning_rate:.6E} |' + else: + assert decoupled_learning_rate is None + log_string += f' global batch size: {batch_size:5d} |' + for key in total_loss_dict: + if key not in [advanced_iters_key, skipped_iters_key, nan_iters_key]: + avg = total_loss_dict[key].item() / float( + max(1, total_loss_dict[advanced_iters_key]) + ) + if avg > 0.0: + log_string += ' {}: {:.6E} |'.format(key, avg) + total_loss_dict[key] = torch.tensor([0.0], dtype=torch.float, device='cuda') + log_string += f' loss scale: {loss_scale:.1f} |' + if grad_norm is not None: + log_string += f' grad norm: {grad_norm:.3f} |' + if num_zeros_in_grad is not None: + log_string += f' num zeros: {num_zeros_in_grad} |' + if params_norm is not None: + log_string += f' params norm: {params_norm:.3f} |' + log_string += ' number of skipped iterations: {:3d} |'.format( + total_loss_dict[skipped_iters_key] + ) + log_string += ' number of nan iterations: {:3d} |'.format(total_loss_dict[nan_iters_key]) + total_loss_dict[advanced_iters_key] = 0 + total_loss_dict[skipped_iters_key] = 0 + total_loss_dict[nan_iters_key] = 0 + print_rank_last(log_string) + if report_memory_flag: + # Report memory after optimizer state has been initialized. + if torch.distributed.get_rank() == 0: + num_microbatches = get_num_microbatches() + report_theoretical_memory(args, num_microbatches=num_microbatches, verbose=True) + report_memory(f'(after {iteration} iterations)') + report_memory_flag = False + # Write timers to wandb, don't reset the counts + if args.log_timers_to_tensorboard: + timers.write(timers_to_log, writer, iteration, normalizer=args.log_interval, reset=False) + timers.write(timers_to_log, wandb_writer, iteration, normalizer=args.log_interval, reset=False) + # Log timers to stdout + timers.log(timers_to_log, normalizer=args.log_interval) + + return report_memory_flag + + +def compute_throughputs_and_append_to_progress_log(iteration, num_floating_point_operations_so_far): + args = get_args() + if args.save is None: + return + + # Compute job throughput. + # args.num_floating_point_operations_so_far keeps track of floating-point operations + # completed at the start of job. + global _TRAIN_START_TIME + job_throughput = ( + num_floating_point_operations_so_far - args.num_floating_point_operations_so_far + ) / ((time.time() - _TRAIN_START_TIME) * 10**12 * args.world_size) + + # Compute cumulative throughput since jobs of this world size were launched. + # `get_start_time_from_progress_log` returns start time and number of floating-point + # operations of first job of this world size. + start_time, start_num_floating_point_operations = get_start_time_from_progress_log() + elapsed_time = (datetime.now() - start_time).total_seconds() + cumulative_throughput = ( + num_floating_point_operations_so_far - start_num_floating_point_operations + ) / (elapsed_time * 10**12 * args.world_size) + + tokens_so_far = args.consumed_train_samples * args.seq_length + saved_ckpt_prefix = 'Saving async checkpoint' if args.async_save else 'Saved checkpoint' + append_to_progress_log( + f"{saved_ckpt_prefix}\tIteration: {iteration}\t" + f"Job throughput: {job_throughput:.1f} TFLOP/s/GPU\t" + f"Cumulative throughput: {cumulative_throughput:.1f} TFLOP/s/GPU\t" + f"Floating-point operations: {num_floating_point_operations_so_far:.2e}\t" + f"Tokens (in billions): {tokens_so_far / 10**9:.2f}" + ) + + +def enable_forward_pre_hook(model_chunks): + for model_chunk in model_chunks: + assert isinstance(model_chunk, DDP) + model_chunk.enable_forward_pre_hook() + + +def disable_forward_pre_hook(model_chunks, param_sync=True): + for model_chunk in model_chunks: + assert isinstance(model_chunk, DDP) + model_chunk.disable_forward_pre_hook(param_sync=param_sync) + + +def save_checkpoint_and_time( + iteration, + model, + optimizer, + opt_param_scheduler, + num_floating_point_operations_so_far, + checkpointing_context, + non_persistent_ckpt=False, + train_data_iterator=None, +): + args = get_args() + timers = get_timers() + energy_monitor = get_energy_monitor() + + # Stop timer to get accurate train interval time and exclude checkpointing duration + timers('interval-time').stop() + energy_monitor.pause() + + # Extra barrier is added to make sure all ranks report the max time. + timer_key = 'save-checkpoint-non-persistent' if non_persistent_ckpt else 'save-checkpoint' + timers(timer_key, log_level=0).start(barrier=True) + + # Log E2E metrics before save-checkpoint + one_logger_utils.track_e2e_metrics() + if should_disable_forward_pre_hook(args): + disable_forward_pre_hook(model) + save_checkpoint( + iteration, + model, + optimizer, + opt_param_scheduler, + num_floating_point_operations_so_far, + checkpointing_context, + non_persistent_ckpt=non_persistent_ckpt, + train_data_iterator=train_data_iterator, + preprocess_common_state_dict_fn=preprocess_common_state_dict, + ) + if args.fp8: + # Run garbage collection after checkpoint saving to free memory from + # dequantized bf16 tensors that were temporarily created during fp8 + # model checkpoint saving. + gc.collect() + if should_disable_forward_pre_hook(args): + enable_forward_pre_hook(model) + timers(timer_key).stop(barrier=True) + timers.log([timer_key]) + + # Log E2E metrics after save-checkpoint + one_logger_utils.track_e2e_metrics() + save_checkpoint_duration = timers(timer_key).elapsed() + one_logger_utils.on_save_checkpoint_end(save_checkpoint_duration, iteration, args.async_save) + + if args.log_progress and not non_persistent_ckpt: + compute_throughputs_and_append_to_progress_log( + iteration, num_floating_point_operations_so_far + ) + + # Recover timing + energy_monitor.resume() + timers('interval-time', log_level=0).start(barrier=True) + + +def post_training_step_callbacks( + model, + optimizer, + opt_param_scheduler, + iteration, + prof, + num_floating_point_operations_since_last_log_event, +): + """Run all post-training-step functions (e.g., FT heartbeats, GC).""" + args = get_args() + + # Bring CPU and GPU back in sync if on right iteration. + if args.train_sync_interval and iteration % args.train_sync_interval == 0: + torch.cuda.synchronize() + + # Straggler detector. + if iteration % args.log_interval == 0 and args.log_straggler: + stimer.report(num_floating_point_operations_since_last_log_event, args.log_interval) + num_floating_point_operations_since_last_log_event = 0.0 + + # Check weight hash across DP replicas. + if ( + args.check_weight_hash_across_dp_replicas_interval is not None + and iteration % args.check_weight_hash_across_dp_replicas_interval == 0 + ): + if should_disable_forward_pre_hook(args): + disable_forward_pre_hook(model) + assert check_param_hashes_across_dp_replicas( + model, cross_check=True + ), "Parameter hashes not matching across DP replicas" + torch.distributed.barrier() + print_rank_0(f">>> Weight hashes match after {iteration} iterations...") + if should_disable_forward_pre_hook(args): + enable_forward_pre_hook(model) + + # Autoresume. + if args.adlr_autoresume and (iteration % args.adlr_autoresume_interval == 0): + check_adlr_autoresume_termination(iteration, model, optimizer, opt_param_scheduler) + + # Profiling. + if ( + args.profile + and iteration == args.profile_step_end + and torch.distributed.get_rank() in args.profile_ranks + ): + if args.use_pytorch_profiler: + assert prof is not None + prof.stop() + else: + torch.cuda.cudart().cudaProfilerStop() + + # Manual garbage collection. + if args.manual_gc: + if args.manual_gc_interval != 0 and iteration % args.manual_gc_interval == 0: + gc.collect() + + +def checkpoint_and_decide_exit( + model, + optimizer, + opt_param_scheduler, + iteration, + num_floating_point_operations_so_far, + checkpointing_context, + train_data_iterator, +): + """Save checkpoint and decide whether to exit based on arguments (e.g., if + --exit-duration-in-mins is set). Actual exit happens in main training loop + based on the return value of this function.""" + # import pdb;pdb.set_trace() + args = get_args() + timers = get_timers() + + # Exit based on signal handler. + saved_checkpoint = False + if args.exit_signal_handler: + signal_handler = get_signal_handler() + if any(signal_handler.signals_received()): + if args.save: + save_checkpoint_and_time( + iteration, + model, + optimizer, + opt_param_scheduler, + num_floating_point_operations_so_far, + checkpointing_context, + train_data_iterator=train_data_iterator, + ) + print_datetime('exiting program after receiving SIGTERM.') + + return True + + # Regular save (persistent and non-persistent). + if args.save and args.save_interval and iteration % args.save_interval == 0: + save_checkpoint_and_time( + iteration, + model, + optimizer, + opt_param_scheduler, + num_floating_point_operations_so_far, + checkpointing_context, + train_data_iterator=train_data_iterator, + ) + saved_checkpoint = True + + elif ( + args.save + and args.non_persistent_save_interval + and iteration % args.non_persistent_save_interval == 0 + ): + save_checkpoint_and_time( + iteration, + model, + optimizer, + opt_param_scheduler, + num_floating_point_operations_so_far, + checkpointing_context, + non_persistent_ckpt=True, + train_data_iterator=train_data_iterator, + ) + # import pdb;pdb.set_trace() + saved_checkpoint = True + + # Exit based on duration. + if args.exit_duration_in_mins: + train_time = (time.time() - _TRAIN_START_TIME) / 60.0 + done_cuda = torch.tensor( + [train_time > args.exit_duration_in_mins], dtype=torch.int, device='cuda' + ) + torch.distributed.all_reduce(done_cuda, op=torch.distributed.ReduceOp.MAX) + done = done_cuda.item() + # return False + if done: + if args.save and not saved_checkpoint: + save_checkpoint_and_time( + iteration, + model, + optimizer, + opt_param_scheduler, + num_floating_point_operations_so_far, + checkpointing_context, + train_data_iterator=train_data_iterator, + ) + print_datetime(f'exiting program after {train_time} minutes') + # import pdb;pdb.set_trace() + return True + # return False + + # Exit based on iterations. + if args.exit_interval and iteration % args.exit_interval == 0: + if args.save and not saved_checkpoint: + save_checkpoint_and_time( + iteration, + model, + optimizer, + opt_param_scheduler, + num_floating_point_operations_so_far, + checkpointing_context, + train_data_iterator=train_data_iterator, + ) + print_datetime(f'exiting program at iteration {iteration}') + # import pdb;pdb.set_trace() + + return True + + return False + + +def train( + forward_step_func, + model, + optimizer, + opt_param_scheduler, + train_data_iterator, + valid_data_iterator, + process_non_loss_data_func, + config, + checkpointing_context, + non_loss_data_func, +): + """Training function: run train_step desired number of times, run validation, checkpoint.""" + args = get_args() + timers = get_timers() + + if getattr(args, 'perform_rl_step', False): + assert has_rl_utils, "RL cannot run without the lang_rl package" + + # Additional variable initialization for RL training + ref_state_dict = None + + # IMPORTANT FIX: For RL training, reinitialize the microbatch calculator with the correct configuration + if getattr(args, 'perform_rl_step', False): + print_rank_0("> Reinitializing microbatch calculator for GRPO training...") + from megatron.core.num_microbatches_calculator import ( + destroy_num_microbatches_calculator, + init_num_microbatches_calculator + ) + # First destroy the existing calculator + destroy_num_microbatches_calculator() + # Then initialize with the correct perform_rl_step=True context + init_num_microbatches_calculator( + args.rank, + args.rampup_batch_size, + args.global_batch_size, + args.micro_batch_size, + mpu.get_data_parallel_world_size(), + args.decrease_batch_size_if_needed + ) + print_rank_0(f"> GRPO training: num_microbatches set to {get_num_microbatches()}") + + energy_monitor = get_energy_monitor() + one_logger = get_one_logger() + + if args.run_workload_inspector_server: + try: + from workload_inspector.utils.webserver import run_server + import threading + + threading.Thread( + target=run_server, daemon=True, args=(torch.distributed.get_rank(),) + ).start() + except ModuleNotFoundError: + print_rank_0("workload inspector module not found.") + + # Write args to tensorboard + write_args_to_tensorboard() + + # Turn on training mode which enables dropout. + # import pdb;pdb.set_trace() + print(f"model:{model}") + for model_module in model: + model_module.train() + + # Tracking loss. + total_loss_dict = {} + + # Iterations. + iteration = args.iteration + # Make sure rerun_state_machine has the right iteration loaded from checkpoint. + rerun_state_machine = get_rerun_state_machine() + if rerun_state_machine.current_iteration != iteration: + print_rank_0(f"Overwriting rerun_state_machine.current_iteration from " + f"{rerun_state_machine.current_iteration} to {iteration}...") + rerun_state_machine.current_iteration = iteration + + # Track E2E metrics at the start of training. + one_logger_utils.on_train_start( + iteration=iteration, + consumed_train_samples=args.consumed_train_samples, + train_samples=args.train_samples, + seq_length=args.seq_length, + train_iters=args.train_iters, + save=args.save, + async_save=args.async_save, + log_throughput=args.log_throughput, + num_floating_point_operations_so_far=args.num_floating_point_operations_so_far, + ) + + num_floating_point_operations_so_far = args.num_floating_point_operations_so_far + + # Setup some training config params. + config.grad_scale_func = optimizer.scale_loss + config.timers = timers + if isinstance(model[0], (megatron_FSDP, DDP)) and args.overlap_grad_reduce: + assert config.no_sync_func is None, ( + 'When overlap_grad_reduce is True, config.no_sync_func must be None; ' + 'a custom no_sync_func is not supported when overlapping grad-reduce' + ) + config.no_sync_func = [model_chunk.no_sync for model_chunk in model] + if len(model) == 1: + config.no_sync_func = config.no_sync_func[0] + if args.align_grad_reduce: + config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in model] + if len(model) == 1: + config.grad_sync_func = config.grad_sync_func[0] + if args.overlap_param_gather and args.align_param_gather: + config.param_sync_func = [model_chunk.start_param_sync for model_chunk in model] + if len(model) == 1: + config.param_sync_func = config.param_sync_func[0] + config.finalize_model_grads_func = finalize_model_grads + + if args.log_energy: + energy_monitor.setup() + energy_monitor.resume() + + timers('interval-time', log_level=0).start(barrier=True) + print_datetime('before the start of training step') + report_memory_flag = True + pre_hook_enabled = False + should_exit = False + exit_code = 0 + + if args.manual_gc: + # Disable the default garbage collector and perform the collection manually. + # This is to align the timing of garbage collection across ranks. + assert ( + args.manual_gc_interval >= 0 + ), 'Manual garbage collection interval should be larger than or equal to 0' + gc.disable() + gc.collect() + + # Singleton initialization of straggler detector. + if args.log_straggler: + global stimer + world = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + mmcnt = args.straggler_minmax_count + stimer.configure( + world, + rank, + mmcnt=mmcnt, + enabled=not args.disable_straggler_on_startup, + port=args.straggler_ctrlr_port, + ) + num_floating_point_operations_since_last_log_event = 0.0 + + num_microbatches = get_num_microbatches() + eval_duration = 0.0 + eval_iterations = 0 + # Wrap forward_backward_func for Full iteration CUDA graph + forward_backward_func = get_forward_backward_func() + if args.enable_cuda_graph and args.cuda_graph_scope=="full_iteration": + forward_backward_func = FullCudaGraphWrapper(forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps) + + def get_e2e_base_metrics(): + """Get base metrics values for one-logger to calculate E2E tracking metrics.""" + num_floating_point_operations_since_current_train_start = ( + num_floating_point_operations_so_far - args.num_floating_point_operations_so_far + ) + return { + 'iteration': iteration, + 'train_duration': timers('interval-time').active_time(), + 'eval_duration': eval_duration, + 'eval_iterations': eval_iterations, + 'total_flops_since_current_train_start': num_floating_point_operations_since_current_train_start, + 'num_floating_point_operations_so_far': num_floating_point_operations_so_far, + 'consumed_train_samples': args.consumed_train_samples, + 'world_size': args.world_size, + 'seq_length': args.seq_length, + } + + # Cache into one-logger for callback. + if one_logger: + with one_logger.get_context_manager(): + one_logger.store_set('get_e2e_base_metrics', get_e2e_base_metrics) + + prof = None + if ( + args.profile + and torch.distributed.get_rank() in args.profile_ranks + and args.use_pytorch_profiler + ): + prof = torch.profiler.profile( + schedule=torch.profiler.schedule( + wait=max(args.profile_step_start - 1, 0), + warmup=1 if args.profile_step_start > 0 else 0, + active=args.profile_step_end - args.profile_step_start, + repeat=1, + ), + on_trace_ready=torch.profiler.tensorboard_trace_handler(args.tensorboard_dir), + record_shapes=True, + with_stack=True, + ) + prof.start() + + start_iteration = iteration + # Disable forward pre-hook to start training to ensure that errors in checkpoint loading + # or random initialization don't propagate to all ranks in first all-gather (which is a + # no-op if things work correctly). + if should_disable_forward_pre_hook(args): + disable_forward_pre_hook(model, param_sync=False) + # Also remove param_sync_func temporarily so that sync calls made in + # `forward_backward_func` are no-ops. + param_sync_func = config.param_sync_func + config.param_sync_func = None + pre_hook_enabled = False + # Also, check weight hash across DP replicas to be very pedantic. + if args.check_weight_hash_across_dp_replicas_interval is not None: + assert check_param_hashes_across_dp_replicas( + model, cross_check=True + ), "Parameter hashes not matching across DP replicas" + torch.distributed.barrier() + print_rank_0(f">>> Weight hashes match after {iteration} iterations...") + + # Capture CUDA Graphs. + if args.external_cuda_graph: + cuda_graph_helper = TECudaGraphHelper( + model=model, + config=config, + seq_length=args.seq_length, + micro_batch_size=args.micro_batch_size, + optimizers=[optimizer], + ) + cuda_graph_helper.create_cudagraphs() + + # Run training iterations till done. + buffered_rollouts = None + ref_state_dict = None + while iteration < args.train_iters: + if args.profile and torch.distributed.get_rank() in args.profile_ranks: + if args.use_pytorch_profiler: + prof.step() + elif iteration == args.profile_step_start: + torch.cuda.cudart().cudaProfilerStart() + torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__() + + ft_integration.on_checkpointing_start() + maybe_finalize_async_save(blocking=False) + ft_integration.on_checkpointing_end(is_async_finalization=True) + + # Update number of microbatches first without consistency check to decide if a + # checkpoint should be saved. If the number of microbatches is different + # from the previous iteration, save a checkpoint. Then run consistency check + # to make sure training configuration is still valid. + update_num_microbatches(args.consumed_train_samples, consistency_check=False, verbose=True) + if get_num_microbatches() != num_microbatches and iteration != 0: + assert get_num_microbatches() > num_microbatches, ( + f"Number of microbatches should be increasing due to batch size rampup; " + f"instead going from {num_microbatches} to {get_num_microbatches()}" + ) + if args.save is not None: + save_checkpoint_and_time( + iteration, + model, + optimizer, + opt_param_scheduler, + num_floating_point_operations_so_far, + checkpointing_context, + train_data_iterator=train_data_iterator, + ) + num_microbatches = get_num_microbatches() + update_num_microbatches(args.consumed_train_samples, consistency_check=True, verbose=True) + + # Completely skip iteration if needed. + if iteration in args.iterations_to_skip: + # Dummy train_step to fast forward train_data_iterator. + dummy_train_step(train_data_iterator) + iteration += 1 + batch_size = ( + mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches() + ) + args.consumed_train_samples += batch_size + args.skipped_train_samples += batch_size + continue + + args.curr_iteration = iteration + + # Update tensor saver iteration if tensor saving is enabled + if getattr(args, 'save_tensors', False): + try: + from megatron.core.tensor_saver import get_tensor_saver, set_global_sample_idx + tensor_saver = get_tensor_saver() + tensor_saver.set_iteration(iteration) + # Reset sample idx for each iteration + set_global_sample_idx(0) + except Exception as e: + print(f"[Training] Warning: Failed to update tensor saver iteration: {e}") + + # For GRPO, we keep the data for a few epochs. DeepSeekMath paper calls this number $\mu$. + # It is similar to a PPO epoch. + + if getattr(args, 'perform_rl_step', False): + with torch.no_grad(): + if not ref_state_dict: + ref_state_dict = {k: (v.cpu() if v is not None else v) for k, v in model[0].state_dict().items()} + + # We collect new rollouts when we've gone over the collected data 'grpo_iterations' times. + if iteration % (args.grpo_iterations * ((args.grpo_samples_per_iteration) // args.global_batch_size)) == 0: + buffered_rollouts = rl_utils.get_rollout_data_iterator( + model, optimizer, iteration, ref_state_dict, + ) + train_data_iterator = buffered_rollouts + + ft_integration.on_training_step_start() + ( + loss_dict, + skipped_iter, + should_checkpoint, + should_exit, + exit_code, + grad_norm, + num_zeros_in_grad, + ) = train_step( + forward_step_func, train_data_iterator, model, optimizer, opt_param_scheduler, config, forward_backward_func + ) + ft_integration.on_training_step_end() + if should_checkpoint: + save_checkpoint_and_time( + iteration, + model, + optimizer, + opt_param_scheduler, + num_floating_point_operations_so_far, + checkpointing_context, + train_data_iterator=train_data_iterator, + ) + if should_exit: + print("exit from log in line 2312") + import pdb;pdb.set_trace() + break + + # Enable forward pre-hooks after first set of forward and backward passes. + # When running in fp16, skip all NaN iterations until steady-state loss scaling value + # is reached. + if iteration == start_iteration: + if skipped_iter: + # Only enable forward pre-hook after a training step has successfully run. Relevant + # for fp16 codepath where first XX iterations are skipped until steady-state loss + # scale value is reached. + start_iteration = iteration + 1 + else: + # Enable forward pre-hook after training step has successfully run. All subsequent + # forward passes will use the forward pre-hook / `param_sync_func` in + # `forward_backward_func`. + if should_disable_forward_pre_hook(args): + enable_forward_pre_hook(model) + config.param_sync_func = param_sync_func + pre_hook_enabled = True + # Set the manual hooks when CUDA Graphs are used. + if args.external_cuda_graph: + cuda_graph_helper.cuda_graph_set_manual_hooks() + + iteration += 1 + + # Check if we've reached the control_iter limit and exit if needed + control_iter = getattr(args, 'control_iter', None) + if control_iter is not None and iteration >= control_iter: + print_rank_0(f"[Training] Reached control_iter limit ({control_iter}), exiting training...") + # Exit the training loop early + break + + batch_size = ( + mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches() + ) + args.consumed_train_samples += batch_size + num_skipped_samples_in_batch = ( + get_current_global_batch_size() - get_current_running_global_batch_size() + ) + if args.decrease_batch_size_if_needed: + assert num_skipped_samples_in_batch >= 0 + else: + assert num_skipped_samples_in_batch == 0 + args.skipped_train_samples += num_skipped_samples_in_batch + num_floating_point_operations_in_batch = num_floating_point_operations(args, batch_size) + num_floating_point_operations_so_far += num_floating_point_operations_in_batch + num_floating_point_operations_since_last_log_event += num_floating_point_operations_in_batch + + # Logging. + if not optimizer.is_stub_optimizer: + loss_scale = optimizer.get_loss_scale().item() + else: + loss_scale = 1.0 + params_norm = None + + if args.log_params_norm: + params_norm = calc_params_l2_norm(model) + learning_rate = None + decoupled_learning_rate = None + for param_group in optimizer.param_groups: + if len(param_group['params']) == 0: + continue + if param_group['is_decoupled_lr']: + decoupled_learning_rate = param_group['lr'] + else: + learning_rate = param_group['lr'] + report_memory_flag = training_log( + loss_dict, + total_loss_dict, + learning_rate, + decoupled_learning_rate, + iteration, + loss_scale, + report_memory_flag, + skipped_iter, + grad_norm, + params_norm, + num_zeros_in_grad, + ) + + # Evaluation. + if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid: + if args.log_energy: + energy_monitor.pause() + timers('interval-time').stop() + if should_disable_forward_pre_hook(args): + disable_forward_pre_hook(model) + pre_hook_enabled = False + if args.manual_gc and args.manual_gc_eval: + # Collect all objects. + gc.collect() + prefix = f'iteration {iteration}' + timers('eval-time', log_level=0).start(barrier=True) + if getattr(args, 'perform_rl_step', False): + rl_utils.evaluate_and_print_results_rl(valid_data_iterator, model, optimizer, + iteration, write_to_tensorboard=True) + else: + evaluate_and_print_results(prefix, forward_step_func, + valid_data_iterator, model, + iteration, process_non_loss_data_func, + config, verbose=False, write_to_tensorboard=True, + non_loss_data_func=non_loss_data_func) + + eval_duration += timers('eval-time').elapsed() + eval_iterations += sum(args.eval_iters) if isinstance(args.eval_iters, list) else args.eval_iters + timers('eval-time').stop() + one_logger_utils.track_e2e_metrics() + + if args.manual_gc and args.manual_gc_eval: + # Collect only the objects created and used in evaluation. + gc.collect(generation=0) + if should_disable_forward_pre_hook(args): + enable_forward_pre_hook(model) + pre_hook_enabled = True + timers('interval-time', log_level=0).start(barrier=True) + if args.log_energy: + energy_monitor.resume() + + # Miscellaneous post-training-step functions (e.g., FT heartbeats, GC). + # Some of these only happen at specific iterations. + post_training_step_callbacks( + model, + optimizer, + opt_param_scheduler, + iteration, + prof, + num_floating_point_operations_since_last_log_event, + ) + + # Checkpoint and decide whether to exit. + should_exit = checkpoint_and_decide_exit( + model, + optimizer, + opt_param_scheduler, + iteration, + num_floating_point_operations_so_far, + checkpointing_context, + train_data_iterator, + ) + if should_exit: + # print("exit from log in line 2442") + # import pdb;pdb.set_trace() + break + + one_logger_utils.track_e2e_metrics() + + # Flush TensorBoard, WandB writers and one-logger. + writer = get_tensorboard_writer() + if writer: + writer.flush() + + # Close out pre-hooks if using distributed optimizer and overlapped param gather. + if pre_hook_enabled: + disable_forward_pre_hook(model) + + ft_integration.on_checkpointing_start() + # This will finalize all unfinalized async request and terminate + # a persistent async worker if persistent ckpt worker is enabled + maybe_finalize_async_save(blocking=True, terminate=True) + ft_integration.on_checkpointing_end(is_async_finalization=True) + if args.enable_ft_package and ft_integration.get_rank_monitor_client() is not None: + ft_integration.get_rank_monitor_client().shutdown_workload_monitoring() + + if args.log_energy: + energy_monitor.lap() + total_energy = energy_monitor.get_total() + print_rank_0(f"Total training energy (GPU): {total_energy / 1e6} MJ") + energy_monitor.shutdown() + + # If any exit conditions (signal handler, duration, iterations) have been reached, exit. + if should_exit: + wandb_writer = get_wandb_writer() + if wandb_writer: + wandb_writer.finish() + ft_integration.shutdown() + one_logger_utils.finish() + sys.exit(exit_code) + + return iteration, num_floating_point_operations_so_far + + +def evaluate( + forward_step_func, + data_iterator, + model, + process_non_loss_data_func, + config, + verbose=False, + non_loss_data_func=None, + eval_iters=None, +): + """Evaluation.""" + args = get_args() + timers = get_timers() + + timers('evaluate', log_level=0).start(barrier=True) + + if args.vision_pretraining and args.vision_pretraining_type == "dino": + from megatron.legacy.model.vision.knn_monitor import compute_feature_bank + + compute_feature_bank(model) + + # Turn on evaluation mode which disables dropout. + for model_module in model: + model_module.eval() + + # Disable result validation during evaluation + rerun_state_machine = get_rerun_state_machine() + rerun_mode = rerun_state_machine.get_mode() + rerun_state_machine.set_mode(RerunMode.DISABLED) + + total_loss_dict = {} + + # make validation batch size independent from training batch size + eval_batch_size = args.global_batch_size + eval_num_microbatches = eval_batch_size // (args.micro_batch_size * args.data_parallel_size) + forward_backward_func = get_forward_backward_func() + if args.enable_cuda_graph and args.cuda_graph_scope=="full_iteration": + forward_backward_func = FullCudaGraphWrapper(forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps) + + if eval_iters is None: + eval_iters = args.eval_iters + + with torch.no_grad(): + iteration = 0 + if verbose: + print_rank_0(f'Evaluating on {eval_iters * eval_batch_size} samples') + while iteration < eval_iters: + iteration += 1 + if verbose: + print_rank_0(f'Evaluating iter {iteration}/{eval_iters}') + + # Don't care about timing during evaluation + config.timers = None + ft_integration.on_eval_step_start() + loss_dicts = forward_backward_func( + forward_step_func=forward_step_func, + data_iterator=data_iterator, + model=model, + num_microbatches=eval_num_microbatches, + seq_length=args.seq_length, + micro_batch_size=args.micro_batch_size, + decoder_seq_length=args.decoder_seq_length, + forward_only=True, + ) + ft_integration.on_eval_step_end() + config.timers = get_timers() + + # Empty unused memory + if args.empty_unused_memory_level >= 1: + torch.cuda.empty_cache() + + if mpu.is_pipeline_last_stage(ignore_virtual=True): + # Reduce across processes. + for key in loss_dicts[0].keys(): + if key not in total_loss_dict: + total_loss_dict[key] = torch.tensor( + [0.0, 0.0], dtype=torch.float + ).cuda() + val = [x[key].view(-1) for x in loss_dicts] + + if val[0].numel() == 2: + if args.sft: + # normalize over micro batch instead of global + val = torch.vstack(val) + val = val[:, 0] / val[:, 1] + val = val.mean() + torch.distributed.all_reduce( + val, + group=mpu.get_data_parallel_group(with_context_parallel=True) + ) + val /= torch.distributed.get_world_size( + group=mpu.get_data_parallel_group(with_context_parallel=True) + ) + total_loss_dict[key][0] += val + total_loss_dict[key][1] += 1 + else : + val = torch.vstack(val).sum(dim=0) + torch.distributed.all_reduce( + val, + group=mpu.get_data_parallel_group(with_context_parallel=True) + ) + total_loss_dict[key] += val + elif val[0].numel() == 1: + val = torch.cat(val).sum() + total_loss_dict[key][0] += val + total_loss_dict[key][1] += len(loss_dicts) + else: + raise ValueError(f"Invalid value shape: {val[0].shape} for key {key}") + + args.consumed_valid_samples += eval_batch_size + + if args.exit_duration_in_mins: + train_time = (time.time() - _TRAIN_START_TIME) / 60.0 + done_cuda = torch.tensor( + [train_time > args.exit_duration_in_mins], dtype=torch.int, device='cuda' + ) + torch.distributed.all_reduce(done_cuda, op=torch.distributed.ReduceOp.MAX) + done = done_cuda.item() + if done: + rerun_state_machine.set_mode(rerun_mode) + print_rank_0('Exiting during evaluation, timelimit reached') + return None, None, True + + collected_non_loss_data = None + if non_loss_data_func is not None: + collected_non_loss_data = non_loss_data_func(model) + elif process_non_loss_data_func is not None and is_last_rank(): + collected_non_loss_data = forward_backward_func( + forward_step_func=forward_step_func, + data_iterator=data_iterator, + model=model, + num_microbatches=get_num_microbatches(), + seq_length=args.seq_length, + micro_batch_size=args.micro_batch_size, + decoder_seq_length=args.decoder_seq_length, + forward_only=True, + collect_non_loss_data=True, + ) + + # Move model back to the train mode. + for model_module in model: + model_module.train() + + for key in total_loss_dict: + numerator, denominator = total_loss_dict[key] + total_loss_dict[key] = numerator / denominator + + timers('evaluate').stop() + timers.log(['evaluate']) + + rerun_state_machine.set_mode(rerun_mode) + + rerun_state_machine.set_mode(rerun_mode) + + return total_loss_dict, collected_non_loss_data, False + + +def evaluate_and_print_results( + prefix, + forward_step_func, + data_iterator, + model, + iteration, + process_non_loss_data_func, + config, + verbose=False, + write_to_tensorboard=True, + non_loss_data_func=None, +): + """Helper function to evaluate and dump results on screen.""" + args = get_args() + if write_to_tensorboard: + writer = get_tensorboard_writer() + else: + writer = None + + wandb_writer = get_wandb_writer() + + data_iterators = data_iterator if args.multiple_validation_sets else [data_iterator] + + if not args.multiple_validation_sets: + eval_iters = [args.eval_iters] + else: + eval_iters = args.eval_iters + + if args.full_validation: + assert len(eval_iters) == len(data_iterators) + + # with full validation we need to distribute eval_iters to all ranks + if mpu.get_tensor_model_parallel_rank() == 0: + eval_iters = torch.tensor(args.eval_iters, dtype=torch.long, device='cuda') + else: + eval_iters = torch.tensor([0] * len(eval_iters), dtype=torch.long, device='cuda') + torch.distributed.broadcast(eval_iters, 0) + eval_iters = eval_iters.tolist() + args.eval_iters = eval_iters[0] if not args.multiple_validation_sets else eval_iters + elif not args.multiple_validation_sets: + eval_iters = [args.eval_iters] + else: + eval_iters = args.eval_iters + + for index, (iterator, iterations) in enumerate(zip(data_iterators, eval_iters)): + suffix = "" + if args.multiple_validation_sets: + suffix = f"-{index}" + total_loss_dict, collected_non_loss_data, timelimit = evaluate( + forward_step_func, + iterator, + model, + process_non_loss_data_func, + config, + verbose, + non_loss_data_func, + eval_iters=iterations, + ) + # Timelimit hit during evaluation + if timelimit: + return + string = f' validation{suffix} loss at {prefix} | ' + for key in total_loss_dict: + string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item()) + ppl = math.exp(min(20, total_loss_dict[key].item())) + string += '{} PPL: {:.6E} | '.format(key, ppl) + if writer: + writer.add_scalar('{} validation{}'.format(key, suffix), total_loss_dict[key].item(), iteration) + writer.add_scalar( + '{} validation{} vs samples'.format(key, suffix), + total_loss_dict[key].item(), + args.consumed_train_samples, + ) + if args.log_validation_ppl_to_tensorboard: + writer.add_scalar('{} validation{} ppl'.format(key, suffix), ppl, iteration) + writer.add_scalar( + '{} validation{} ppl vs samples'.format(key, suffix), ppl, args.consumed_train_samples + ) + if wandb_writer and is_last_rank(): + wandb_writer.log( + {'{} validation{}'.format(key, suffix): total_loss_dict[key].item()}, iteration + ) + + if process_non_loss_data_func is not None and writer and is_last_rank(): + process_non_loss_data_func(collected_non_loss_data, iteration, writer) + + length = len(string) + 1 + print_rank_last('-' * length) + print_rank_last(string) + print_rank_last('-' * length) + + +def cyclic_iter(iter): + while True: + for x in iter: + yield x + + +def get_train_valid_test_num_samples(): + """Train/valid/test num samples.""" + + args = get_args() + + # Number of train/valid/test samples. + if args.train_samples: + train_samples = args.train_samples + else: + train_samples = args.train_iters * args.global_batch_size + if args.full_validation: + eval_samples = None + else: + eval_iters = (args.train_iters // args.eval_interval + 1) * args.eval_iters + eval_samples = eval_iters * args.global_batch_size + test_iters = args.eval_iters + + return (train_samples, eval_samples, test_iters * args.global_batch_size) + + +def build_train_valid_test_datasets(build_train_valid_test_datasets_provider, train_valid_test_num_samples=None): + """Build pretraining datasets.""" + if train_valid_test_num_samples is None: + train_valid_test_num_samples = get_train_valid_test_num_samples() + print_rank_0(' > datasets target sizes (minimum size):') + print_rank_0(' train: {}'.format(train_valid_test_num_samples[0])) + print_rank_0(' validation: {}'.format(train_valid_test_num_samples[1])) + print_rank_0(' test: {}'.format(train_valid_test_num_samples[2])) + return build_train_valid_test_datasets_provider(train_valid_test_num_samples) + + +def build_train_valid_test_data_loaders(build_train_valid_test_datasets_provider): + """Build pretraining data loaders.""" + + args = get_args() + + (train_dataloader, valid_dataloaders, test_dataloader) = (None, None, None) + + print_rank_0('> building train, validation, and test datasets ...') + + # Backward compatibility, assume fixed batch size. + if args.iteration > 0 and args.consumed_train_samples == 0: + assert ( + args.train_samples is None + ), 'Only backward compatiblity support for iteration-based training' + args.consumed_train_samples = args.iteration * args.global_batch_size + if args.iteration > 0 and args.consumed_valid_samples == 0: + if args.train_samples is None: + args.consumed_valid_samples = ( + (args.iteration // args.eval_interval) * args.eval_iters * args.global_batch_size + ) + + # Rely on distributed-aware core datasets, temporary + is_distributed = getattr(build_train_valid_test_datasets_provider, "is_distributed", False) + + # Construct the data pipeline + if is_distributed or mpu.get_tensor_model_parallel_rank() == 0: + + # Build datasets. + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + build_train_valid_test_datasets_provider, (1, 1, 1) if getattr(args, 'perform_rl_step', False) else None + ) + valid_ds = [valid_ds] if not isinstance(valid_ds, list) else valid_ds + + # Build dataloders. + train_dataloader = build_pretraining_data_loader(train_ds, args.consumed_train_samples) + + valid_dataloaders = [] + for valid_d in valid_ds: + if args.skip_train or args.full_validation: + valid_dataloaders.append(build_pretraining_data_loader(valid_d, 0)) + else: + if args.multiple_validation_sets: + # TODO(bnorick): for multiple validation sets without full validation, args.consumed_valid_samples is not + # correct and needs to be calculated/set per validation set + raise NotImplementedError("--multiple-validation-sets currently requires --full-validation") + valid_dataloaders.append(build_pretraining_data_loader(valid_d, args.consumed_valid_samples)) + if not args.multiple_validation_sets: + assert len(valid_dataloaders) == 1 + test_dataloader = build_pretraining_data_loader(test_ds, 0) + + # Flags to know if we need to do training/validation/testing. + do_train = train_dataloader is not None and args.train_iters > 0 + do_valid = valid_dataloaders is not None and (args.full_validation or args.eval_iters > 0) + do_test = test_dataloader is not None and (args.full_validation or args.eval_iters > 0) + flags = torch.tensor( + [int(do_train), int(do_valid), int(do_test)], dtype=torch.long, device='cuda' + ) + else: + flags = torch.tensor([0, 0, 0], dtype=torch.long, device='cuda') + + torch.distributed.broadcast(flags, 0) + + args.do_train = getattr(args, "do_train", False) or flags[0].item() + args.do_valid = getattr(args, "do_valid", False) or flags[1].item() + args.do_test = getattr(args, "do_test", False) or flags[2].item() + if getattr(args, 'perform_rl_step', False): + args.to_test = False + + return train_dataloader, valid_dataloaders, test_dataloader + + +def build_train_valid_test_data_iterators(build_train_valid_test_datasets_provider): + """Build pretraining data iterators.""" + + args = get_args() + + # Build loaders. + train_dataloader, valid_dataloaders, test_dataloader = build_train_valid_test_data_loaders( + build_train_valid_test_datasets_provider + ) + + # Build iterators. + dl_type = args.dataloader_type + assert dl_type in ['single', 'cyclic', 'external'] + + def _get_iterator(dataloader_type, dataloader): + """Return dataset iterator.""" + if dataloader_type == "single": + return RerunDataIterator(iter(dataloader)) + elif dataloader_type == "cyclic": + return RerunDataIterator(iter(cyclic_iter(dataloader))) + elif dataloader_type == "external": + # External dataloader is passed through. User is expected to define how to iterate. + if isinstance(dataloader, list): + return [RerunDataIterator(d) for d in dataloader] + else: + return RerunDataIterator(dataloader) + else: + raise RuntimeError("unexpected dataloader type") + + if train_dataloader is not None: + train_data_iterator = _get_iterator(dl_type, train_dataloader) + else: + train_data_iterator = None + + # when using full validation, we need to override eval iters with the correct + # number of iterations on tp rank 0 so that it can be distributed to the other + # ranks later + if args.full_validation: + if args.multiple_validation_sets: + if valid_dataloaders[0] is None: + args.eval_iters = [None]*len(valid_dataloaders) + else: + args.eval_iters = [len(dl) for dl in valid_dataloaders] + else: + args.eval_iters = len(valid_dataloaders[0]) + + if args.multiple_validation_sets: + if valid_dataloaders[0] is None: + valid_data_iterators = [None] * len(valid_dataloaders) + else: + valid_dl_type = "cyclic" if args.full_validation else dl_type + print( + f"[VALID DATA LOADER LENGTHS] " + ", ".join(f"{idx}: {len(dl)}" for idx, dl in enumerate(valid_dataloaders)) + ) + valid_data_iterators = [ + _get_iterator(valid_dl_type, dl) for dl in valid_dataloaders + ] + elif valid_dataloaders[0] is not None: + valid_data_iterators = _get_iterator(dl_type, valid_dataloaders[0]) + else: + valid_data_iterators = None + + if test_dataloader is not None: + test_data_iterator = _get_iterator(dl_type, test_dataloader) + else: + test_data_iterator = None + + return train_data_iterator, valid_data_iterators, test_data_iterator + + +def should_disable_forward_pre_hook(args): + """Block forward pre-hook for certain configurations.""" + return not args.use_megatron_fsdp and args.use_distributed_optimizer and args.overlap_param_gather diff --git a/pretrain_gpt.py b/pretrain_gpt.py index f883df6187..b920ef30fc 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -129,6 +129,17 @@ def forward_step(data_iterator, model: GPTModel, return_schedule_plan: bool = Fa """ args = get_args() timers = get_timers() + # import pdb;pdb.set_trace() + + # Update sample index for tensor saving if enabled + if getattr(args, 'save_tensors', False): + try: + from megatron.core.tensor_saver import get_tensor_collection_state + state = get_tensor_collection_state() + current_sample = state.get_sample_idx() or 0 + state.set_sample_idx(current_sample + 1) + except Exception as e: + print(f"[ForwardStep] Warning: Failed to update sample index: {e}") # Get the batch. timers('batch-generator', log_level=2).start() @@ -215,10 +226,15 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): dataset_type = GPTDataset print_rank_0("> building train, validation, and test datasets for GPT ...") - - train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder( + # import pdb;pdb.set_trace() + print_rank_0("> GPT datasets builder ...") + builder = BlendedMegatronDatasetBuilder( dataset_type, train_val_test_num_samples, is_dataset_built_on_rank, config - ).build() + ) + + print_rank_0("> GPT datasets build ...") + train_ds, valid_ds, test_ds = builder.build() + print_rank_0("> finished creating GPT datasets ...") diff --git a/quant/bf16_operators.py b/quant/bf16_operators.py new file mode 100644 index 0000000000..fe6a5991e3 --- /dev/null +++ b/quant/bf16_operators.py @@ -0,0 +1,515 @@ +#!/usr/bin/env python3 +""" +BF16 operators module +Provides BF16 matrix multiplication operators with tensor saving functionality +""" + +import torch +from torch.autograd import Function +from typing import Optional, Dict, Any + + +class BF16MatMul(Function): + """BF16 matrix multiplication operator with integrated tensor saving functionality""" + + @staticmethod + def forward(ctx, A: torch.Tensor, B: torch.Tensor, + layer_type: Optional[str] = None, layer_idx: Optional[int] = None, + operation: str = "forward", phase: str = "pre", component: str = "linear", + rank: Optional[int] = None, metadata: Optional[Dict[str, Any]] = None): + """ + BF16矩阵乘法前向传播 + + Args: + A: 输入tensor A (BF16) + B: 输入tensor B (BF16) + layer_type: 层类型 ("attention", "linear", etc.) + layer_idx: 层索引 + operation: 操作类型 ("forward", "backward") + phase: 阶段 ("pre", "post") + component: 组件类型 ("linear", "FA", etc.) + rank: GPU rank信息 + metadata: 额外的元数据 + """ + # 保存tensor和参数到ctx + ctx.save_for_backward(A, B) + ctx.layer_type = layer_type + ctx.layer_idx = layer_idx + ctx.operation = operation + ctx.phase = phase + ctx.component = component + ctx.rank = rank + # 使用私有属性名保存metadata,避免属性冲突 + ctx._metadata = metadata + + # 确保tensor是BF16格式 + if A.dtype != torch.bfloat16: + A = A.to(torch.bfloat16) + if B.dtype != torch.bfloat16: + B = B.to(torch.bfloat16) + + # 执行矩阵乘法 + output = torch.matmul(A, B) + + # 自动保存forward阶段的tensor + if layer_type is not None: + try: + from megatron.core.tensor_saver import save_tensor + + # 根据component类型确定tensor名称 + if component == "FA" or component == "attention": + # attention操作:A是attention_probs,B是value + tensor_name_A = "attention_probs" + tensor_name_B = "value" + else: + # linear操作:使用通用名称 + tensor_name_A = "input_A" + tensor_name_B = "input_B" + + # 保存输入tensor A + save_tensor( + tensor=A, + layer_type=layer_type, + operation=operation, + quant_type="bf16", + tensor_name=tensor_name_A, + layer_idx=layer_idx, + phase=phase, + component=component, + rank=rank, + metadata=metadata + ) + + # 保存输入tensor B + save_tensor( + tensor=B, + layer_type=layer_type, + operation=operation, + quant_type="bf16", + tensor_name=tensor_name_B, + layer_idx=layer_idx, + phase=phase, + component=component, + rank=rank, + metadata=metadata + ) + + # 保存输出tensor + save_tensor( + tensor=output, + layer_type=layer_type, + operation=operation, + quant_type="bf16", + tensor_name="output", + layer_idx=layer_idx, + phase=phase, + component=component, + rank=rank, + metadata=metadata + ) + + except ImportError: + pass # 如果tensor_saver不可用,静默跳过 + except Exception as e: + pass # Silently ignore tensor saving errors + + return output + + @staticmethod + def backward(ctx, grad_output): + """ + BF16矩阵乘法反向传播 + """ + A, B = ctx.saved_tensors + grad_A = grad_B = None + + # 计算梯度 + if ctx.needs_input_grad[0]: + grad_A = torch.matmul(grad_output, B.transpose(-2, -1)) + if ctx.needs_input_grad[1]: + grad_B = torch.matmul(A.transpose(-2, -1), grad_output) + + # 自动保存backward阶段的tensor + if ctx.layer_type is not None: + try: + from megatron.core.tensor_saver import save_tensor + + # 保存梯度输出 + save_tensor( + tensor=grad_output, + layer_type=ctx.layer_type, + operation="backward", + quant_type="bf16", + tensor_name="grad_output", + layer_idx=ctx.layer_idx, + phase="post", + component=ctx.component, + rank=ctx.rank, + metadata=ctx._metadata + ) + + # 根据component类型确定backward tensor名称 + if ctx.component == "FA" or ctx.component == "attention": + # attention操作:grad_A是grad_attention_probs,grad_B是grad_value + grad_tensor_name_A = "grad_attention_probs" + grad_tensor_name_B = "grad_value" + else: + # linear操作:使用通用名称 + grad_tensor_name_A = "grad_input_A" + grad_tensor_name_B = "grad_input_B" + + # 保存梯度A + if grad_A is not None: + save_tensor( + tensor=grad_A, + layer_type=ctx.layer_type, + operation="backward", + quant_type="bf16", + tensor_name=grad_tensor_name_A, + layer_idx=ctx.layer_idx, + phase="post", + component=ctx.component, + rank=ctx.rank, + metadata=ctx._metadata + ) + + # 保存梯度B + if grad_B is not None: + save_tensor( + tensor=grad_B, + layer_type=ctx.layer_type, + operation="backward", + quant_type="bf16", + tensor_name=grad_tensor_name_B, + layer_idx=ctx.layer_idx, + phase="post", + component=ctx.component, + rank=ctx.rank, + metadata=ctx._metadata + ) + + except ImportError: + pass # 如果tensor_saver不可用,静默跳过 + except Exception as e: + pass # Silently ignore tensor saving errors + + return grad_A, grad_B, None, None, None, None, None, None, None # None对应所有额外参数(9个) + + +class BF16BAddBmm(Function): + """BF16 Batch Add Batch Matrix Multiplication operator with integrated tensor saving functionality""" + + @staticmethod + def forward(ctx, input: torch.Tensor, batch1: torch.Tensor, batch2: torch.Tensor, + beta: float = 1.0, alpha: float = 1.0, + layer_type: Optional[str] = None, layer_idx: Optional[int] = None, + operation: str = "forward", phase: str = "pre", component: str = "attention", + rank: Optional[int] = None, metadata: Optional[Dict[str, Any]] = None): + """ + BF16 Batch Add Batch Matrix Multiplication前向传播 + + Args: + input: 输入tensor + batch1: 第一个batch tensor + batch2: 第二个batch tensor + beta: beta参数 + alpha: alpha参数 + layer_type: 层类型 ("attention", "linear", etc.) + layer_idx: 层索引 + operation: 操作类型 ("forward", "backward") + phase: 阶段 ("pre", "post") + component: 组件类型 ("attention", "linear", etc.) + rank: GPU rank信息 + metadata: 额外的元数据 + """ + # 保存tensor和参数到ctx + ctx.save_for_backward(input, batch1, batch2) + ctx.beta = beta + ctx.alpha = alpha + ctx.layer_type = layer_type + ctx.layer_idx = layer_idx + ctx.operation = operation + ctx.phase = phase + ctx.component = component + ctx.rank = rank + ctx._metadata = metadata + + # 确保tensor是BF16格式 + if input.dtype != torch.bfloat16: + input = input.to(torch.bfloat16) + if batch1.dtype != torch.bfloat16: + batch1 = batch1.to(torch.bfloat16) + if batch2.dtype != torch.bfloat16: + batch2 = batch2.to(torch.bfloat16) + + # 执行batch matrix multiplication + mm_out = torch.bmm(batch1, batch2) + output = beta * input + alpha * mm_out + + # 自动保存forward阶段的tensor + if layer_type is not None: + try: + from megatron.core.tensor_saver import save_tensor + + # 根据component类型确定tensor名称 + if component == "FA" or component == "attention": + # attention操作:input是matmul_input_buffer,batch1是query,batch2是key + tensor_name_input = "matmul_input_buffer" + tensor_name_batch1 = "query" + tensor_name_batch2 = "key" + else: + # 其他操作:使用通用名称 + tensor_name_input = "input" + tensor_name_batch1 = "batch1" + tensor_name_batch2 = "batch2" + + # 保存输入tensor + save_tensor( + tensor=input, + layer_type=layer_type, + operation=operation, + quant_type="bf16", + tensor_name=tensor_name_input, + layer_idx=layer_idx, + phase=phase, + component=component, + rank=rank, + metadata=metadata + ) + + # 保存batch1 tensor + save_tensor( + tensor=batch1, + layer_type=layer_type, + operation=operation, + quant_type="bf16", + tensor_name=tensor_name_batch1, + layer_idx=layer_idx, + phase=phase, + component=component, + rank=rank, + metadata=metadata + ) + + # 保存batch2 tensor + save_tensor( + tensor=batch2, + layer_type=layer_type, + operation=operation, + quant_type="bf16", + tensor_name=tensor_name_batch2, + layer_idx=layer_idx, + phase=phase, + component=component, + rank=rank, + metadata=metadata + ) + + # 保存矩阵乘法结果 + save_tensor( + tensor=mm_out, + layer_type=layer_type, + operation=operation, + quant_type="bf16", + tensor_name="mm_output", + layer_idx=layer_idx, + phase=phase, + component=component, + rank=rank, + metadata=metadata + ) + + # 保存最终输出 + save_tensor( + tensor=output, + layer_type=layer_type, + operation=operation, + quant_type="bf16", + tensor_name="output", + layer_idx=layer_idx, + phase=phase, + component=component, + rank=rank, + metadata=metadata + ) + + except ImportError: + pass # 如果tensor_saver不可用,静默跳过 + except Exception as e: + pass # Silently ignore tensor saving errors + + return output + + @staticmethod + def backward(ctx, grad_output): + input, batch1, batch2 = ctx.saved_tensors + beta, alpha = ctx.beta, ctx.alpha + + grad_input = grad_batch1 = grad_batch2 = None + + # 计算梯度 + if ctx.needs_input_grad[0]: + grad_input = beta * grad_output + if ctx.needs_input_grad[1] or ctx.needs_input_grad[2]: + mm_grad = alpha * grad_output + grad_batch1 = torch.bmm(mm_grad, batch2.transpose(-2, -1)) + grad_batch2 = torch.bmm(batch1.transpose(-2, -1), mm_grad) + + # 自动保存backward阶段的tensor + if ctx.layer_type is not None: + try: + from megatron.core.tensor_saver import save_tensor + + # 保存梯度输出 + save_tensor( + tensor=grad_output, + layer_type=ctx.layer_type, + operation="backward", + quant_type="bf16", + tensor_name="grad_output", + layer_idx=ctx.layer_idx, + phase="post", + component=ctx.component, + rank=ctx.rank, + metadata=ctx._metadata + ) + + # 根据component类型确定backward tensor名称 + if ctx.component == "FA" or ctx.component == "attention": + # attention操作:grad_input是grad_matmul_input_buffer,grad_batch1是grad_query,grad_batch2是grad_key + grad_tensor_name_input = "grad_matmul_input_buffer" + grad_tensor_name_batch1 = "grad_query" + grad_tensor_name_batch2 = "grad_key" + else: + # 其他操作:使用通用名称 + grad_tensor_name_input = "grad_input" + grad_tensor_name_batch1 = "grad_batch1" + grad_tensor_name_batch2 = "grad_batch2" + + # 保存梯度input + if grad_input is not None: + save_tensor( + tensor=grad_input, + layer_type=ctx.layer_type, + operation="backward", + quant_type="bf16", + tensor_name=grad_tensor_name_input, + layer_idx=ctx.layer_idx, + phase="post", + component=ctx.component, + rank=ctx.rank, + metadata=ctx._metadata + ) + + # 保存梯度batch1 + if grad_batch1 is not None: + save_tensor( + tensor=grad_batch1, + layer_type=ctx.layer_type, + operation="backward", + quant_type="bf16", + tensor_name=grad_tensor_name_batch1, + layer_idx=ctx.layer_idx, + phase="post", + component=ctx.component, + rank=ctx.rank, + metadata=ctx._metadata + ) + + # 保存梯度batch2 + if grad_batch2 is not None: + save_tensor( + tensor=grad_batch2, + layer_type=ctx.layer_type, + operation="backward", + quant_type="bf16", + tensor_name=grad_tensor_name_batch2, + layer_idx=ctx.layer_idx, + phase="post", + component=ctx.component, + rank=ctx.rank, + metadata=ctx._metadata + ) + + except ImportError: + pass # 如果tensor_saver不可用,静默跳过 + except Exception as e: + pass # Silently ignore tensor saving errors + + return grad_input, grad_batch1, grad_batch2, None, None, None, None, None, None, None, None, None # None对应所有额外参数(12个) + + + + +# 便捷函数 +def bf16_matmul(A: torch.Tensor, B: torch.Tensor, **tensor_save_kwargs) -> torch.Tensor: + """ + BF16矩阵乘法便捷函数,支持tensor保存 + + Args: + A, B: 输入tensor + **tensor_save_kwargs: tensor保存相关参数 + - layer_type: 层类型 + - layer_idx: 层索引 + - operation: 操作类型 + - phase: 阶段 + - component: 组件类型 + - rank: GPU rank + - metadata: 元数据 + """ + # 如果有tensor保存参数,使用集成算子 + if tensor_save_kwargs and any(key in tensor_save_kwargs for key in + ['layer_type', 'layer_idx', 'operation', 'phase', 'component', 'rank', 'metadata']): + return BF16MatMul.apply( + A, B, + tensor_save_kwargs.get('layer_type'), + tensor_save_kwargs.get('layer_idx'), + tensor_save_kwargs.get('operation', 'forward'), + tensor_save_kwargs.get('phase', 'pre'), + tensor_save_kwargs.get('component', 'linear'), + tensor_save_kwargs.get('rank'), + tensor_save_kwargs.get('metadata') + ) + else: + # 否则使用原始调用方式 + return BF16MatMul.apply(A, B) + + +def bf16_baddbmm(input: torch.Tensor, batch1: torch.Tensor, batch2: torch.Tensor, + beta: float = 1.0, alpha: float = 1.0, **tensor_save_kwargs) -> torch.Tensor: + """ + BF16 Batch Add Batch Matrix Multiplication便捷函数,支持tensor保存 + + Args: + input: 输入tensor + batch1: 第一个batch tensor + batch2: 第二个batch tensor + beta: beta参数 + alpha: alpha参数 + **tensor_save_kwargs: tensor保存相关参数 + - layer_type: 层类型 + - layer_idx: 层索引 + - operation: 操作类型 + - phase: 阶段 + - component: 组件类型 + - rank: GPU rank + - metadata: 元数据 + """ + # 如果有tensor保存参数,使用集成算子 + if tensor_save_kwargs and any(key in tensor_save_kwargs for key in + ['layer_type', 'layer_idx', 'operation', 'phase', 'component', 'rank', 'metadata']): + return BF16BAddBmm.apply( + input, batch1, batch2, beta, alpha, + tensor_save_kwargs.get('layer_type'), + tensor_save_kwargs.get('layer_idx'), + tensor_save_kwargs.get('operation', 'forward'), + tensor_save_kwargs.get('phase', 'pre'), + tensor_save_kwargs.get('component', 'attention'), + tensor_save_kwargs.get('rank'), + tensor_save_kwargs.get('metadata') + ) + else: + # 否则使用原始调用方式 + return BF16BAddBmm.apply(input, batch1, batch2, beta, alpha) + + diff --git a/quant/curve/loss_curve_cmp_non_pretrain.png b/quant/curve/loss_curve_cmp_non_pretrain.png new file mode 100644 index 0000000000..a225433b68 Binary files /dev/null and b/quant/curve/loss_curve_cmp_non_pretrain.png differ diff --git a/quant/fusion_result.json b/quant/fusion_result.json new file mode 100644 index 0000000000..ec747fa47d --- /dev/null +++ b/quant/fusion_result.json @@ -0,0 +1 @@ +null \ No newline at end of file diff --git a/quant/hifp.py b/quant/hifp.py new file mode 100644 index 0000000000..baed71a81a --- /dev/null +++ b/quant/hifp.py @@ -0,0 +1,627 @@ +import numpy as np +from quant.qtype import QType +from torch import Tensor +import torch +from torch.autograd import Function + +def to_HiFX(x, G: int = 64, N: int = 4) -> np.ndarray: + x = np.array(x) + Mi, Ni = x.shape[0],x.shape[1] + Mcnt = np.ceil(Mi / G).astype(int) + res = np.zeros((Mi,Ni)) + Ng = N - 2 + for i in range(Mcnt): + for j in range(Ni): + ori = x[i*G : i*G+G, j] # 当前 64 长度向量 + S = np.ones(G) + S[ori < 0] = -1 + S = S.T + tmpG = np.abs(ori) + + # ---------- level-1 ---------- + EG = np.floor(np.log2(tmpG + 2**(-1000))) + E16 = np.zeros(16) + for k in range(16): + E16[k] = np.max(EG[k*4 : k*4+4]) + + E8 = np.zeros(8) + for k in range(8): + E8[k] = np.max(E16[k*2 : k*2+2]) + + Emax = np.max(E8) + E8_1 = Emax - 2 # [-127, 125] + + # ---------- level-2 ---------- + E1_8 = E8 - E8_1 - 1 # <= 1 + E1_8[E1_8 < 0] = 0 # [0, 1] + + # ---------- level-3 ---------- + E1_8x2 = np.zeros(16) + for k in range(8): + E1_8x2[k*2 : k*2+2] = E1_8[k] + + E1_16 = E16 - E1_8x2 - E8_1 # <= 1 + E1_16[E1_16 < 0] = 0 # [0, 1] + + # ---------- restore ---------- + E16G = E1_16 + E1_8x2 + E8_1 # fused 16 exp + EG = np.zeros(G) + for k in range(16): + EG[k*4 : k*4+4] = E16G[k] + + in_grp = np.floor(tmpG * 2**(-EG + Ng) + 0.5) * 2.0**(-Ng) + in_grp[in_grp >= 2] = 2 - 2**(-Ng) + grp = S * in_grp * 2.0**EG + res[i*G : i*G+G, j] = grp + + return res + + +class RoundHif8_dml(Function): + @staticmethod + def forward(ctx, + x: torch.Tensor, + max_exp: int, + min_exp: int, + Ec: int) -> torch.Tensor: + x_tmp = x.clone().detach() + E = torch.floor(torch.log2(torch.abs(x_tmp))).detach().float() + D = torch.floor(torch.log2(torch.abs(E - Ec))) + 1 + D = torch.where(E != Ec, D, 0) + + x = torch.where(D <= 2, + torch.round(x * torch.exp2(-E + 3)) * torch.exp2(-3 + E), + x) + x = torch.where((D > 2) & (D < 5), + torch.round(x * torch.exp2(-E + 5 - D)) * torch.exp2(D - 5 + E), + x) + x = torch.where(D >= 5, + torch.round(x * torch.exp2(-E)) * torch.exp2(E), + x) + + over_value = 1.25 * 2**(max_exp + Ec) + down_value = 1.5 * 2**(min_exp + Ec) + x = torch.where(x_tmp >= over_value, over_value, x) + x = torch.where(torch.abs(x_tmp) <= down_value, 0.0, x) + x = torch.where(torch.isinf(x_tmp)|torch.isnan(x_tmp), x_tmp, x) # 保持 NaN + return x + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None, None, None + +round_hif8_func_dml = RoundHif8_dml.apply +def any_to_hif8_dml(num: torch.Tensor, + Ec: int = 0, + dml = True) -> torch.Tensor: + dtype = num.dtype + num = num.float() + if dml: + max_exp = 15 + min_exp = -22 + num = round_hif8_func_dml(num, max_exp, min_exp, Ec) + num = num.to(dtype) + return num + + +fp_max_dict = { + "fp16": 65504.0, + "e4m3": 448.0, + "e5m2": 57344.0, + "hif8_7": 224.0, + "hif8_15": 32768.0 +} + +def compute_scaling_factor_fp8(amax: torch.Tensor, + scale: torch.Tensor, + fp_max: float) -> torch.Tensor: + sf = fp_max / amax + sf = torch.where(amax > 0.0, sf, scale) + sf = torch.where(torch.isfinite(amax), sf, scale) + sf = torch.where(torch.isinf(sf), + torch.full_like(sf, torch.finfo(amax.dtype).max), + sf) + return sf + + +@torch.no_grad() +def quant_hif8(x: Tensor, Q: QType=None, qdim: int=-1) -> Tensor: + max_value = (2**15)*0.95 + min_value = 2**(-22) + + x_unsignedl = torch.abs(x) + sign = torch.sign(x) + + x_unsigned = torch.clamp(x_unsignedl, min=min_value, max=max_value) + + if x.dtype == torch.float16: + e = torch.floor(torch.log2(x_unsigned + 2**(-14))) + else: + e = torch.floor(torch.log2(x_unsigned + 2**(-45))) + + abse = e.abs() + mant_bits = torch.zeros_like(abse) + mant_bits[abse <= 15] = 1 + mant_bits[abse <= 7] = 2 + mant_bits[abse <= 3] = 3 + + res = torch.floor(x_unsigned * 2.0**(-e + mant_bits) + 0.5) * 2.0**(e - mant_bits) * sign + return res + +# def hifp_matmul(A:torch.Tensor,B:torch.Tensor)->torch.Tensor: + # A = quant_hif8(A) + # # A = any_to_hif8_dml(A,Ec=15) + # B = quant_hif8(B) + # # B = any_to_hif8_dml(B,Ec=15) + # C = torch.matmul(A,B) + # return C +import torch +from torch.autograd import Function +from typing import Optional, Dict, Any + +class HIFPMatMul(Function): + @staticmethod + def forward(ctx, A: torch.Tensor, B: torch.Tensor, + elem_format: str = 'fp8_e5m2', block_size: int = 32, + layer_type: Optional[str] = None, layer_idx: Optional[int] = None, + operation: str = "forward", phase: str = "pre", component: str = "linear", + rank: Optional[int] = None, metadata: Optional[Dict[str, Any]] = None): + # 保存tensor和参数到ctx + ctx.save_for_backward(A, B) + ctx.elem_format = elem_format + ctx.block_size = block_size + ctx.layer_type = layer_type + ctx.layer_idx = layer_idx + ctx.operation = operation + ctx.phase = phase + ctx.component = component + ctx.rank = rank + ctx._metadata = metadata + + # 量化tensor + A_q = quant_hif8(A) + B_q = quant_hif8(B) + + # 执行矩阵乘法 + output = torch.matmul(A_q, B_q) + + # 自动保存forward阶段的tensor + if layer_type is not None: + try: + from megatron.core.tensor_saver import save_tensor + + # 根据component类型确定tensor名称 + if component == "FA" or component == "attention": + # attention操作:A是attention_probs,B是value + tensor_name_A = "attention_probs" + tensor_name_B = "value" + else: + # linear操作:使用通用名称 + tensor_name_A = "input_A" + tensor_name_B = "input_B" + + # 保存输入tensor A + save_tensor( + tensor=A, + layer_type=layer_type, + operation=operation, + quant_type="hifp8", + tensor_name=tensor_name_A, + layer_idx=layer_idx, + phase=phase, + component=component, + rank=rank, + metadata=metadata + ) + + # 保存输入tensor B + save_tensor( + tensor=B, + layer_type=layer_type, + operation=operation, + quant_type="hifp8", + tensor_name=tensor_name_B, + layer_idx=layer_idx, + phase=phase, + component=component, + rank=rank, + metadata=metadata + ) + + # 保存量化后的tensor A_q + save_tensor( + tensor=A_q, + layer_type=layer_type, + operation=operation, + quant_type="hifp8_quantized", + tensor_name="input_A_quantized", + layer_idx=layer_idx, + phase=phase, + component=component, + rank=rank, + metadata=metadata + ) + + # 保存量化后的tensor B_q + save_tensor( + tensor=B_q, + layer_type=layer_type, + operation=operation, + quant_type="hifp8_quantized", + tensor_name="input_B_quantized", + layer_idx=layer_idx, + phase=phase, + component=component, + rank=rank, + metadata=metadata + ) + + # 保存输出tensor + save_tensor( + tensor=output, + layer_type=layer_type, + operation=operation, + quant_type="hifp8", + tensor_name="output", + layer_idx=layer_idx, + phase=phase, + component=component, + rank=rank, + metadata=metadata + ) + + except ImportError: + pass # 如果tensor_saver不可用,静默跳过 + except Exception as e: + pass # Silently ignore tensor saving errors + + return output + + @staticmethod + def backward(ctx, grad_output): + A, B = ctx.saved_tensors + grad_A = grad_B = None + + # 计算梯度 + if ctx.needs_input_grad[0]: + grad_A = torch.matmul(grad_output, B.transpose(-2, -1)) + if ctx.needs_input_grad[1]: + grad_B = torch.matmul(A.transpose(-2, -1), grad_output) + + # 自动保存backward阶段的tensor + if ctx.layer_type is not None: + try: + from megatron.core.tensor_saver import save_tensor + + # 保存梯度输出 + save_tensor( + tensor=grad_output, + layer_type=ctx.layer_type, + operation="backward", + quant_type="hifp8", + tensor_name="grad_output", + layer_idx=ctx.layer_idx, + phase="post", + component=ctx.component, + rank=ctx.rank, + metadata=ctx._metadata + ) + + # 根据component类型确定backward tensor名称 + if ctx.component == "FA" or ctx.component == "attention": + # attention操作:grad_A是grad_attention_probs,grad_B是grad_value + grad_tensor_name_A = "grad_attention_probs" + grad_tensor_name_B = "grad_value" + else: + # linear操作:使用通用名称 + grad_tensor_name_A = "grad_input_A" + grad_tensor_name_B = "grad_input_B" + + # 保存梯度A + if grad_A is not None: + save_tensor( + tensor=grad_A, + layer_type=ctx.layer_type, + operation="backward", + quant_type="hifp8", + tensor_name=grad_tensor_name_A, + layer_idx=ctx.layer_idx, + phase="post", + component=ctx.component, + rank=ctx.rank, + metadata=ctx._metadata + ) + + # 保存梯度B + if grad_B is not None: + save_tensor( + tensor=grad_B, + layer_type=ctx.layer_type, + operation="backward", + quant_type="hifp8", + tensor_name=grad_tensor_name_B, + layer_idx=ctx.layer_idx, + phase="post", + component=ctx.component, + rank=ctx.rank, + metadata=ctx._metadata + ) + + except ImportError: + pass # 如果tensor_saver不可用,静默跳过 + except Exception as e: + pass # Silently ignore tensor saving errors + + return grad_A, grad_B, None, None, None, None, None, None, None, None, None # None对应所有额外参数 + +class HIFPBAddBmm(Function): + @staticmethod + def forward(ctx, input, batch1, batch2, beta=1.0, alpha=1.0, + layer_type=None, layer_idx=None, operation="forward", + phase="pre", component="attention", rank=None, metadata=None): + ctx.save_for_backward(input, batch1, batch2) + ctx.beta, ctx.alpha = beta, alpha + ctx.layer_type = layer_type + ctx.layer_idx = layer_idx + ctx.operation = operation + ctx.phase = phase + ctx.component = component + ctx.rank = rank + ctx._metadata = metadata + + # 使用集成了tensor保存的HIFPMatMul + mm_out = HIFPMatMul.apply(batch1, batch2, 'fp8_e5m2', 32, + layer_type, layer_idx, operation, phase, component, rank, metadata) + output = beta * input + alpha * mm_out + + # 自动保存forward阶段的tensor + if layer_type is not None: + try: + from megatron.core.tensor_saver import save_tensor + + # 根据component类型确定tensor名称 + if component == "FA" or component == "attention": + # attention操作:input是matmul_input_buffer,batch1是query,batch2是key + tensor_name_input = "matmul_input_buffer" + tensor_name_batch1 = "query" + tensor_name_batch2 = "key" + else: + # 其他操作:使用通用名称 + tensor_name_input = "input" + tensor_name_batch1 = "batch1" + tensor_name_batch2 = "batch2" + + # 保存输入tensor + save_tensor( + tensor=input, + layer_type=layer_type, + operation=operation, + quant_type="hifp", + tensor_name=tensor_name_input, + layer_idx=layer_idx, + phase=phase, + component=component, + rank=rank, + metadata=metadata + ) + + # 保存batch1 tensor + save_tensor( + tensor=batch1, + layer_type=layer_type, + operation=operation, + quant_type="hifp", + tensor_name=tensor_name_batch1, + layer_idx=layer_idx, + phase=phase, + component=component, + rank=rank, + metadata=metadata + ) + + # 保存batch2 tensor + save_tensor( + tensor=batch2, + layer_type=layer_type, + operation=operation, + quant_type="hifp", + tensor_name=tensor_name_batch2, + layer_idx=layer_idx, + phase=phase, + component=component, + rank=rank, + metadata=metadata + ) + + # 保存最终输出 + save_tensor( + tensor=output, + layer_type=layer_type, + operation=operation, + quant_type="hifp", + tensor_name="output", + layer_idx=layer_idx, + phase=phase, + component=component, + rank=rank, + metadata=metadata + ) + + except ImportError: + pass # 如果tensor_saver不可用,静默跳过 + except Exception as e: + pass # Silently ignore tensor saving errors + + return output + + @staticmethod + def backward(ctx, grad_output): + input, batch1, batch2 = ctx.saved_tensors + beta, alpha = ctx.beta, ctx.alpha + + grad_input = grad_batch1 = grad_batch2 = None + if ctx.needs_input_grad[0]: + grad_input = beta * grad_output + if ctx.needs_input_grad[1] or ctx.needs_input_grad[2]: + mm_grad = alpha * grad_output + grad_batch1 = torch.matmul(mm_grad, batch2.transpose(-2, -1)) + grad_batch2 = torch.matmul(batch1.transpose(-2, -1), mm_grad) + + # 自动保存backward阶段的tensor + if ctx.layer_type is not None: + try: + from megatron.core.tensor_saver import save_tensor + + # 保存梯度输出 + save_tensor( + tensor=grad_output, + layer_type=ctx.layer_type, + operation="backward", + quant_type="hifp", + tensor_name="grad_output", + layer_idx=ctx.layer_idx, + phase="post", + component=ctx.component, + rank=ctx.rank, + metadata=ctx._metadata + ) + + # 根据component类型确定backward tensor名称 + if ctx.component == "FA" or ctx.component == "attention": + # attention操作:grad_input是grad_matmul_input_buffer,grad_batch1是grad_query,grad_batch2是grad_key + grad_tensor_name_input = "grad_matmul_input_buffer" + grad_tensor_name_batch1 = "grad_query" + grad_tensor_name_batch2 = "grad_key" + else: + # 其他操作:使用通用名称 + grad_tensor_name_input = "grad_input" + grad_tensor_name_batch1 = "grad_batch1" + grad_tensor_name_batch2 = "grad_batch2" + + # 保存梯度input + if grad_input is not None: + save_tensor( + tensor=grad_input, + layer_type=ctx.layer_type, + operation="backward", + quant_type="hifp", + tensor_name=grad_tensor_name_input, + layer_idx=ctx.layer_idx, + phase="post", + component=ctx.component, + rank=ctx.rank, + metadata=ctx._metadata + ) + + # 保存梯度batch1 + if grad_batch1 is not None: + save_tensor( + tensor=grad_batch1, + layer_type=ctx.layer_type, + operation="backward", + quant_type="hifp", + tensor_name=grad_tensor_name_batch1, + layer_idx=ctx.layer_idx, + phase="post", + component=ctx.component, + rank=ctx.rank, + metadata=ctx._metadata + ) + + # 保存梯度batch2 + if grad_batch2 is not None: + save_tensor( + tensor=grad_batch2, + layer_type=ctx.layer_type, + operation="backward", + quant_type="hifp", + tensor_name=grad_tensor_name_batch2, + layer_idx=ctx.layer_idx, + phase="post", + component=ctx.component, + rank=ctx.rank, + metadata=ctx._metadata + ) + + except ImportError: + pass # 如果tensor_saver不可用,静默跳过 + except Exception as e: + pass # Silently ignore tensor saving errors + + return grad_input, grad_batch1, grad_batch2, None, None, None, None, None, None, None, None # None对应所有额外参数 + +def hifp_matmul(A, B, **tensor_save_kwargs): + """ + HIFP矩阵乘法函数,支持tensor保存 + + Args: + A, B: 输入tensor + **tensor_save_kwargs: tensor保存相关参数 + - layer_type: 层类型 + - layer_idx: 层索引 + - operation: 操作类型 + - phase: 阶段 + - component: 组件类型 + - rank: GPU rank + - metadata: 元数据 + """ + # 如果有tensor保存参数,使用集成算子 + if tensor_save_kwargs and any(key in tensor_save_kwargs for key in + ['layer_type', 'layer_idx', 'operation', 'phase', 'component', 'rank', 'metadata']): + return HIFPMatMul.apply( + A, B, + tensor_save_kwargs.get('elem_format', 'fp8_e5m2'), + tensor_save_kwargs.get('block_size', 32), + tensor_save_kwargs.get('layer_type'), + tensor_save_kwargs.get('layer_idx'), + tensor_save_kwargs.get('operation', 'forward'), + tensor_save_kwargs.get('phase', 'pre'), + tensor_save_kwargs.get('component', 'linear'), + tensor_save_kwargs.get('rank'), + tensor_save_kwargs.get('metadata') + ) + else: + # 否则使用原始调用方式 + return HIFPMatMul.apply(A, B) + +def hifp_baddbmm(input, batch1, batch2, beta=1.0, alpha=1.0, **tensor_save_kwargs): + """ + HIFP Batch Add Batch Matrix Multiplication函数,支持tensor保存 + + Args: + input, batch1, batch2: 输入tensor + beta, alpha: 参数 + **tensor_save_kwargs: tensor保存相关参数 + """ + # 如果有tensor保存参数,使用集成算子 + if tensor_save_kwargs and any(key in tensor_save_kwargs for key in + ['layer_type', 'layer_idx', 'operation', 'phase', 'component', 'rank', 'metadata']): + return HIFPBAddBmm.apply( + input, batch1, batch2, beta, alpha, + tensor_save_kwargs.get('layer_type'), + tensor_save_kwargs.get('layer_idx'), + tensor_save_kwargs.get('operation', 'forward'), + tensor_save_kwargs.get('phase', 'pre'), + tensor_save_kwargs.get('component', 'attention'), + tensor_save_kwargs.get('rank'), + tensor_save_kwargs.get('metadata') + ) + else: + # 否则使用原始调用方式 + return HIFPBAddBmm.apply(input, batch1, batch2, beta, alpha) + +if __name__ == "__main__": + A = torch.load("grad_output.pt", map_location='cpu').cuda() + fp8 = quant_hif8(A) + # fp8 = any_to_hif8_dml(A, Ec=15) + + print("origin_A:", A) + print("hif8_A:", fp8) + + print(f"A_shape:{A.shape},grad_max:{torch.max(A)},grad_min:{torch.min(A)}") + B = torch.load("total_input.pt", map_location='cpu').cuda() + print(f"B_shape:{B.shape},input_max:{torch.max(B)},input_min:{torch.min(B)}") + + C_hifp8 = hifp_matmul(A.transpose(-2,-1),B) + + print(f"C_shape:{C_hifp8.shape},output_max:{torch.max(C_hifp8)},output_min:{torch.min(C_hifp8)}") diff --git a/quant/mxfp.py b/quant/mxfp.py new file mode 100644 index 0000000000..91ce30f39b --- /dev/null +++ b/quant/mxfp.py @@ -0,0 +1,1078 @@ +import torch +# import torch_npu +from enum import Enum, IntEnum +import numpy as np + + +FP32_EXPONENT_BIAS = 127 +FP32_MIN_NORMAL = 2 ** (-FP32_EXPONENT_BIAS + 1) + +# Enum for scalar data formats +class ElemFormat(Enum): + int8 = 1 + int4 = 2 + int2 = 3 + fp8_e5m2 = 4 + fp8_e4m3 = 5 + fp6_e3m2 = 6 + fp6_e2m3 = 7 + fp4 = 8 + fp4_e2m1 = 8 + float16 = 9 + fp16 = 9 + bfloat16 = 10 + bf16 = 10 + + @staticmethod + def from_str(s): + assert(s != None), "String elem_format == None" + s = s.lower() + if hasattr(ElemFormat, s): + return getattr(ElemFormat, s) + else: + raise Exception("Undefined elem format", s) + + +def _get_min_norm(ebits): + """ Valid for all float formats """ + emin = 2 - (2 ** (ebits - 1)) + return 0 if ebits == 0 else 2 ** emin + + +def _analyze_overflow_underflow_before_quantization(A, elem_format, mbits, ebits, max_norm, verbose=True): + """ + Analyze tensor for overflow and underflow conditions before quantization. + This function is called right before element-wise quantization to detect + potential overflow and underflow issues that might be caused by scaling. + + Args: + A (torch.Tensor): Input tensor after scaling but before quantization + elem_format (str): Element format identifier + mbits (int): Number of mantissa bits + ebits (int): Number of exponent bits + max_norm (float): Maximum normal value for the format + verbose (bool): Whether to print analysis results immediately + + Returns: + dict: Analysis results containing overflow and underflow statistics + """ + analysis_result = { + 'elem_format': elem_format, + 'total_elements': 0, + 'underflow_count': 0, + 'underflow_percent': 0.0, + 'flush_count': 0, + 'flush_percent': 0.0, + 'overflow_count': 0, + 'overflow_percent': 0.0, + 'min_denormal': 0.0, + 'min_norm': 0.0, + 'max_norm': max_norm, + 'tensor_range': [0.0, 0.0], + 'has_significant_underflow': False, + 'has_significant_overflow': False, + 'severity': 'none', # 'none', 'moderate', 'high' + 'error': None + } + + try: + # Calculate minimum representable values + min_norm = _get_min_norm(ebits) + min_denormal = min_norm / (2 ** (mbits - 2)) if mbits > 2 else min_norm + + # Convert to numpy for analysis (handle BFloat16) + if A.dtype == torch.bfloat16: + A_float = A.float() + else: + A_float = A + + if A_float.is_cuda: + A_np = A_float.cpu().numpy() + else: + A_np = A_float.numpy() + + # Handle empty tensors + if A_np.size == 0: + analysis_result['total_elements'] = 0 + return analysis_result + + # Count underflow conditions + total_elements = A_np.size + non_zero_mask = A_np != 0.0 + abs_A = np.abs(A_np) + + # Underflow: non-zero values closer to zero than smallest representable + underflow_mask = non_zero_mask & (abs_A < min_denormal) + underflow_count = np.sum(underflow_mask) + underflow_percent = (underflow_count / total_elements) * 100 + + # Also check for values that would be flushed to zero + flush_mask = non_zero_mask & (abs_A < min_norm) + flush_count = np.sum(flush_mask) + flush_percent = (flush_count / total_elements) * 100 + + # Check for overflow: values larger than maximum representable + overflow_mask = abs_A > max_norm + overflow_count = np.sum(overflow_mask) + overflow_percent = (overflow_count / total_elements) * 100 + + # Store analysis results + analysis_result.update({ + 'total_elements': total_elements, + 'underflow_count': int(underflow_count), + 'underflow_percent': float(underflow_percent), + 'flush_count': int(flush_count), + 'flush_percent': float(flush_percent), + 'overflow_count': int(overflow_count), + 'overflow_percent': float(overflow_percent), + 'min_denormal': float(min_denormal), + 'min_norm': float(min_norm), + 'max_norm': float(max_norm), + 'tensor_range': [float(np.min(A_np)), float(np.max(A_np))], + 'has_significant_underflow': underflow_percent > 0.1 or flush_percent > 0.1, + 'has_significant_overflow': overflow_percent > 0.1 + }) + + # Determine severity based on both overflow and underflow + max_issue_percent = max(underflow_percent, overflow_percent) + if max_issue_percent > 1.0: + analysis_result['severity'] = 'high' + elif max_issue_percent > 0.1: + analysis_result['severity'] = 'moderate' + else: + analysis_result['severity'] = 'none' + + # Print analysis if verbose and significant issues detected + if verbose and (analysis_result['has_significant_underflow'] or analysis_result['has_significant_overflow']): + print(f"\n⚠️ OVERFLOW/UNDERFLOW ANALYSIS ({elem_format}):") + print(f" Total elements: {total_elements:,}") + print(f" Min denormal: {min_denormal:.2e}") + print(f" Min normal: {min_norm:.2e}") + print(f" Max normal: {max_norm:.2e}") + print(f" Underflow count: {underflow_count:,} ({underflow_percent:.2f}%)") + print(f" Flush to zero count: {flush_count:,} ({flush_percent:.2f}%)") + print(f" Overflow count: {overflow_count:,} ({overflow_percent:.2f}%)") + print(f" Tensor range: [{np.min(A_np):.2e}, {np.max(A_np):.2e}]") + + if max_issue_percent > 1.0: + if underflow_percent > overflow_percent: + print(f" 🔴 HIGH UNDERFLOW RATE: {underflow_percent:.2f}%") + else: + print(f" 🔴 HIGH OVERFLOW RATE: {overflow_percent:.2f}%") + print(f" Consider adjusting scaling strategy!") + elif max_issue_percent > 0.1: + if underflow_percent > overflow_percent: + print(f" 🟡 MODERATE UNDERFLOW: {underflow_percent:.2f}%") + else: + print(f" 🟡 MODERATE OVERFLOW: {overflow_percent:.2f}%") + + except Exception as e: + # Don't let analysis errors break the quantization process + analysis_result['error'] = str(e) + if verbose: + print(f"Warning: Underflow analysis failed: {str(e)}") + + return analysis_result + + +def _get_max_norm(ebits, mbits): + """ Valid only for floats that define NaN """ + assert(ebits >= 5), "invalid for floats that don't define NaN" + emax = 0 if ebits==0 else 2**(ebits - 1) - 1 + return 2**emax * float(2**(mbits-1) - 1) / 2**(mbits-2) + + +_FORMAT_CACHE = {} +def _get_format_params(fmt): + """ Allowed formats: + - intX: 2 <= X <= 32, assume sign-magnitude, 1.xxx representation + - floatX/fpX: 16 <= X <= 28, assume top exp is used for NaN/Inf + - bfloatX/bfX: 9 <= X <= 32 + - fp4, no NaN/Inf + - fp6_e3m2/e2m3, no NaN/Inf + - fp8_e4m3/e5m2, e5m2 normal NaN/Inf, e4m3 special behavior + + Returns: + ebits: exponent bits + mbits: mantissa bits: includes sign and implicit bits + emax: max normal exponent + max_norm: max normal number + min_norm: min normal number + """ + if type(fmt) is str: + fmt = ElemFormat.from_str(fmt) + + if fmt in _FORMAT_CACHE: + return _FORMAT_CACHE[fmt] + + if fmt == ElemFormat.int8: + ebits, mbits = 0, 8 + emax = 0 + elif fmt == ElemFormat.int4: + ebits, mbits = 0, 4 + emax = 0 + elif fmt == ElemFormat.int2: + ebits, mbits = 0, 2 + emax = 0 + elif fmt == ElemFormat.fp8_e5m2: + ebits, mbits = 5, 4 + emax = 2**(ebits - 1) - 1 + elif fmt == ElemFormat.fp8_e4m3: + ebits, mbits = 4, 5 + emax = 2**(ebits - 1) + elif fmt == ElemFormat.fp6_e3m2: + ebits, mbits = 3, 4 + emax = 2**(ebits - 1) + elif fmt == ElemFormat.fp6_e2m3: + ebits, mbits = 2, 5 + emax = 2**(ebits - 1) + elif fmt == ElemFormat.fp4: + ebits, mbits = 2, 3 + emax = 2**(ebits - 1) + elif fmt == ElemFormat.float16: + ebits, mbits = 5, 12 + emax = 2**(ebits - 1) - 1 + elif fmt == ElemFormat.bfloat16: + ebits, mbits = 8, 9 + emax = 2**(ebits - 1) - 1 + else: + raise Exception("Unknown element format %s" % fmt) + + if fmt != ElemFormat.fp8_e4m3: + max_norm = 2**emax * float(2**(mbits-1) - 1) / 2**(mbits-2) + else: + max_norm = 2**emax * 1.75 # FP8 has custom max_norm + + min_norm = _get_min_norm(ebits) + + _FORMAT_CACHE[fmt] = (ebits, mbits, emax, max_norm, min_norm) + + return ebits, mbits, emax, max_norm, min_norm + + +def _safe_lshift(x, bits, exp): + if exp is None: + return x * (2**bits) + else: + return x / (2 ** exp) * (2**bits) + + +def _safe_rshift(x, bits, exp): + if exp is None: + return x / (2**bits) + else: + return x / (2**bits) * (2 ** exp) + + +def _round_mantissa(A, bits, round, clamp=False): + """ + Rounds mantissa to nearest bits depending on the rounding method 'round' + Args: + A {PyTorch tensor} -- Input tensor + round {str} -- Rounding method + "floor" rounds to the floor + "nearest" rounds to ceil or floor, whichever is nearest + Returns: + A {PyTorch tensor} -- Tensor with mantissas rounded + """ + + if round == "dither": + rand_A = torch.rand_like(A, requires_grad=False) + A = torch.sign(A) * torch.floor(torch.abs(A) + rand_A) + elif round == "floor": + A = torch.sign(A) * torch.floor(torch.abs(A)) + elif round == "nearest": + A = torch.sign(A) * torch.floor(torch.abs(A) + 0.5) + elif round == "even": + absA = torch.abs(A) + # find 0.5, 2.5, 4.5 ... + maskA = ((absA - 0.5) % 2 == torch.zeros_like(A)).type(A.dtype) + A = torch.sign(A) * (torch.floor(absA + 0.5) - maskA) + else: + raise Exception("Unrecognized round method %s" % (round)) + + # Clip values that cannot be expressed by the specified number of bits + if clamp: + max_mantissa = 2 ** (bits - 1) - 1 + A = torch.clamp(A, -max_mantissa, max_mantissa) + return A + + +def _quantize_elemwise_core(A, bits, exp_bits, max_norm, round='nearest', + saturate_normals=False, allow_denorm=True): + """ Core function used for element-wise quantization + Arguments: + A {PyTorch tensor} -- A tensor to be quantized + bits {int} -- Number of mantissa bits. Includes + sign bit and implicit one for floats + exp_bits {int} -- Number of exponent bits, 0 for ints + max_norm {float} -- Largest representable normal number + round {str} -- Rounding mode: (floor, nearest, even) + saturate_normals {bool} -- If True, normal numbers (i.e., not NaN/Inf) + that exceed max norm are clamped. + Must be True for correct MX conversion. + allow_denorm {bool} -- If False, flush denorm numbers in the + elem_format to zero. + Returns: + quantized tensor {PyTorch tensor} -- A tensor that has been quantized + """ + A_is_sparse = A.is_sparse + if A_is_sparse: + if A.layout != torch.sparse_coo: + raise NotImplementedError("Only COO layout sparse tensors are currently supported.") + + sparse_A = A.coalesce() + A = sparse_A.values().clone() + + # Flush values < min_norm to zero if denorms are not allowed + if not allow_denorm and exp_bits > 0: + min_norm = _get_min_norm(exp_bits) + out = (torch.abs(A) >= min_norm).type(A.dtype) * A + else: + out = A + + if exp_bits != 0: + private_exp = torch.floor(torch.log2( + torch.abs(A) + (A == 0).type(A.dtype))) + + # The minimum representable exponent for 8 exp bits is -126 + min_exp = -(2**(exp_bits-1)) + 2 + private_exp = private_exp.clip(min=min_exp) + else: + private_exp = None + + # Scale up so appropriate number of bits are in the integer portion of the number + out = _safe_lshift(out, bits - 2, private_exp) + + out = _round_mantissa(out, bits, round, clamp=False) + + # Undo scaling + out = _safe_rshift(out, bits - 2, private_exp) + + # Set values > max_norm to Inf if desired, else clamp them + if saturate_normals or exp_bits == 0: + out = torch.clamp(out, min=-max_norm, max=max_norm) + else: + out = torch.where((torch.abs(out) > max_norm), + torch.sign(out) * float("Inf"), out) + + # handle Inf/NaN + # out[A == float("Inf")] = float("Inf") + # out[A == -float("Inf")] = -float("Inf") + # out[A == float("NaN")] = float("NaN") + + if A_is_sparse: + output = torch.sparse_coo_tensor(sparse_A.indices(), output, + sparse_A.size(), dtype=sparse_A.dtype, device=sparse_A.device, + requires_grad=sparse_A.requires_grad) + + return out + + +def _shared_exponents(A, method="max", axes=None, ebits=0, scaling_control="max"): + """ + Get shared exponents for the passed matrix A. + Args: + A {PyTorch tensor} -- Input tensor + method {str} -- Exponent selection method. + "max" uses the max absolute value + "none" uses an exponent for each value (i.e., no sharing) + axes {list(int)} -- List of integers which specifies the axes across which + shared exponents are calculated. + Returns: + shared_exp {PyTorch tensor} -- Tensor of shared exponents + """ + + if method == "max": + if axes is None: + max_val = torch.max(torch.abs(A)) + if scaling_control == "max_minus_1": + # Use max - 1 strategy to avoid potential overflow + shared_exp = max_val - 1.0 + else: # default "max" + shared_exp = max_val + else: + shared_exp = A + for axis in axes: + shared_exp, _ = torch.max(torch.abs(shared_exp), dim=axis, keepdim=True) + if scaling_control == "max_minus_1": + # Use max - 1 strategy to avoid potential overflow + shared_exp = shared_exp - 1.0 + elif method == "none": + shared_exp = torch.abs(A) + else: + raise Exception("Unrecognized shared exponent selection method %s" % (method)) + + # log2(shared_exp) and truncate to integer + shared_exp = torch.floor( + torch.log2( + shared_exp + FP32_MIN_NORMAL * (shared_exp == 0).type(shared_exp.dtype) + ) + ) + + # Restrict to [-emax, emax] range + if ebits > 0: + emax = 2**(ebits-1) - 1 + #shared_exp = torch.clamp(shared_exp, -emax, emax) + # Overflow to Inf + shared_exp[shared_exp > emax] = float("NaN") + # Underflows are set to -127 which causes them to be + # flushed to 0 later + shared_exp[shared_exp < -emax] = -emax + + return shared_exp + + +def _reshape_to_blocks(A, axes, block_size): + if axes is None: + raise Exception( + "axes required in order to determine which " + "dimension toapply block size to" + ) + if block_size == 0: + raise Exception("block_size == 0 in _reshape_to_blocks") + + # Fix axes to be positive and sort them + axes = [(x + len(A.shape) if x < 0 else x) for x in axes] + assert all(x >= 0 for x in axes) + axes = sorted(axes) + + # Add extra dimension for tiles + for i in range(len(axes)): + axes[i] += i # Shift axes due to added dimensions + A = torch.unsqueeze(A, dim=axes[i] + 1) + + # Pad to block_size + orig_shape = A.size() + pad = [] + for i in range(len(orig_shape)): + pad += [0, 0] + + do_padding = False + for axis in axes: + pre_pad_size = orig_shape[axis] + if isinstance(pre_pad_size, torch.Tensor): + pre_pad_size = int(pre_pad_size.value) + # Don't pad if the axis is short enough to fit inside one tile + if pre_pad_size % block_size == 0: + pad[2 * axis] = 0 + else: + pad[2 * axis] = block_size - pre_pad_size % block_size + do_padding = True + + if do_padding: + pad = list(reversed(pad)) + A = torch.nn.functional.pad(A, pad, mode="constant") + + def _reshape(shape, reshape_block_size): + for axis in axes: + # Reshape to tiles if axis length > reshape_block_size + if shape[axis] >= reshape_block_size: + assert shape[axis] % reshape_block_size == 0 + shape[axis + 1] = reshape_block_size + shape[axis] = shape[axis] // reshape_block_size + # Otherwise preserve length and insert a 1 into the shape + else: + shape[axis + 1] = shape[axis] + shape[axis] = 1 + return shape + + # Reshape to tiles + padded_shape = A.size() + reshape = _reshape(list(padded_shape), block_size) + + A = A.view(reshape) + return A, axes, orig_shape, padded_shape + + +def _undo_reshape_to_blocks(A, padded_shape, orig_shape, axes): + # Undo tile reshaping + A = A.view(padded_shape) + # Undo padding + if not list(padded_shape) == list(orig_shape): + slices = [slice(0, x) for x in orig_shape] + A = A[slices] + for axis in reversed(axes): + # Remove extra dimension + A = torch.squeeze(A, dim=axis + 1) + return A + + +def _quantize_mx( + A, + scale_bits, + elem_format, # can be None for no quantization + shared_exp_method="max", + axes=None, + block_size=0, + round="nearest", + flush_fp32_subnorms=False, + scaling_control="max", +): + """Function used for MX* quantization + """ + # Shortcut for no quantization + if elem_format == None: + return A + + assert(scale_bits > 0) + + # Make sure axes is a list of non-negative numbers + if axes is None: + axes = [] + else: + axes = [axes] if type(axes) == int else axes + axes = [x + A.ndim if x < 0 else x for x in axes] + + ebits, mbits, emax, max_norm, _ = _get_format_params(elem_format) + + # Perform tiling to the hardware vector size + if block_size > 0: + A, axes, orig_shape, padded_shape = _reshape_to_blocks( + A, axes, block_size + ) + + #################### + # Quantize + #################### + shared_exp_axes = [x + 1 for x in axes] if block_size > 0 else axes + + # Get shared exponents + shared_exp = _shared_exponents( + A, method=shared_exp_method, axes=shared_exp_axes, ebits=0, scaling_control=scaling_control, + ) + + # Flush subnormal FP32 inputs to zero + if flush_fp32_subnorms: + A = A * (shared_exp > -FP32_EXPONENT_BIAS).type(A.dtype) + + # Offset the max exponent by the largest representable exponent + # in the element data format + shared_exp = shared_exp - emax + + scale_emax = 2**(scale_bits-1) - 1 + shared_exp[shared_exp > scale_emax] = float("NaN") + shared_exp[shared_exp < -scale_emax] = -scale_emax + + A = A / (2**shared_exp) + + # Add underflow analysis before quantization + # _analyze_underflow_before_quantization(A, elem_format, mbits, ebits, max_norm) + + A = _quantize_elemwise_core( + A, mbits, ebits, max_norm, round=round, + allow_denorm=True, saturate_normals=True) + + A = A * (2**shared_exp) + + # Undo tile reshaping + if block_size: + A = _undo_reshape_to_blocks(A, padded_shape, orig_shape, axes) + + return A + + +import torch +from torch.autograd import Function +from typing import Optional, Dict, Any + +class MXFPMatMul(Function): + @staticmethod + def forward(ctx, A: torch.Tensor, B: torch.Tensor, + elem_format: str = 'fp8_e5m2', block_size: int = 32, + layer_type: Optional[str] = None, layer_idx: Optional[int] = None, + operation: str = "forward", phase: str = "pre", component: str = "linear", + rank: Optional[int] = None, metadata: Optional[Dict[str, Any]] = None, + scaling_control: str = "max"): + # 保存tensor和参数到ctx + ctx.save_for_backward(A, B) + ctx.elem_format = elem_format + ctx.block_size = block_size + ctx.layer_type = layer_type + ctx.layer_idx = layer_idx + ctx.operation = operation + ctx.phase = phase + ctx.component = component + ctx.rank = rank + ctx._metadata = metadata + ctx.scaling_control = scaling_control + + # 量化tensor + A_q = _quantize_mx( + A, scale_bits=8, elem_format=elem_format, + shared_exp_method="max", axes=-1, block_size=block_size, + round="nearest", flush_fp32_subnorms=False, scaling_control=scaling_control + ) + B_q = _quantize_mx( + B, scale_bits=8, elem_format=elem_format, + shared_exp_method="max", axes=-2, block_size=block_size, + round="nearest", flush_fp32_subnorms=False, scaling_control=scaling_control + ) + + # 执行矩阵乘法 + output = torch.matmul(A_q, B_q) + + # 自动保存forward阶段的tensor + if layer_type is not None: + try: + from megatron.core.tensor_saver import save_tensor + + # 根据component类型确定tensor名称 + if component == "FA" or component == "attention": + # attention操作:A是attention_probs,B是value + tensor_name_A = "attention_probs" + tensor_name_B = "value" + else: + # linear操作:使用通用名称 + tensor_name_A = "input_A" + tensor_name_B = "input_B" + + # 保存输入tensor A + save_tensor( + tensor=A, + layer_type=layer_type, + operation=operation, + quant_type=f"mxfp_{elem_format}", + tensor_name=tensor_name_A, + layer_idx=layer_idx, + phase=phase, + component=component, + rank=rank, + metadata=metadata + ) + + # 保存输入tensor B + save_tensor( + tensor=B, + layer_type=layer_type, + operation=operation, + quant_type=f"mxfp_{elem_format}", + tensor_name=tensor_name_B, + layer_idx=layer_idx, + phase=phase, + component=component, + rank=rank, + metadata=metadata + ) + + # 保存量化后的tensor A_q + save_tensor( + tensor=A_q, + layer_type=layer_type, + operation=operation, + quant_type=f"mxfp_{elem_format}_quantized", + tensor_name="input_A_quantized", + layer_idx=layer_idx, + phase=phase, + component=component, + rank=rank, + metadata=metadata + ) + + # 保存量化后的tensor B_q + save_tensor( + tensor=B_q, + layer_type=layer_type, + operation=operation, + quant_type=f"mxfp_{elem_format}_quantized", + tensor_name="input_B_quantized", + layer_idx=layer_idx, + phase=phase, + component=component, + rank=rank, + metadata=metadata + ) + + # 保存输出tensor + save_tensor( + tensor=output, + layer_type=layer_type, + operation=operation, + quant_type=f"mxfp_{elem_format}", + tensor_name="output", + layer_idx=layer_idx, + phase=phase, + component=component, + rank=rank, + metadata=metadata + ) + + except ImportError: + pass # 如果tensor_saver不可用,静默跳过 + except Exception as e: + pass # Silently ignore tensor saving errors + + return output + + @staticmethod + def backward(ctx, grad_output): + A, B = ctx.saved_tensors + grad_A = grad_B = None + + # 计算梯度 + if ctx.needs_input_grad[0]: + grad_A = torch.matmul(grad_output, B.transpose(-2, -1)) + if ctx.needs_input_grad[1]: + grad_B = torch.matmul(A.transpose(-2, -1), grad_output) + + # 自动保存backward阶段的tensor + if ctx.layer_type is not None: + try: + from megatron.core.tensor_saver import save_tensor + + # 保存梯度输出 + save_tensor( + tensor=grad_output, + layer_type=ctx.layer_type, + operation="backward", + quant_type=f"mxfp_{ctx.elem_format}", + tensor_name="grad_output", + layer_idx=ctx.layer_idx, + phase="post", + component=ctx.component, + rank=ctx.rank, + metadata=ctx._metadata + ) + + # 根据component类型确定backward tensor名称 + if ctx.component == "FA" or ctx.component == "attention": + # attention操作:grad_A是grad_attention_probs,grad_B是grad_value + grad_tensor_name_A = "grad_attention_probs" + grad_tensor_name_B = "grad_value" + else: + # linear操作:使用通用名称 + grad_tensor_name_A = "grad_input_A" + grad_tensor_name_B = "grad_input_B" + + # 保存梯度A + if grad_A is not None: + save_tensor( + tensor=grad_A, + layer_type=ctx.layer_type, + operation="backward", + quant_type=f"mxfp_{ctx.elem_format}", + tensor_name=grad_tensor_name_A, + layer_idx=ctx.layer_idx, + phase="post", + component=ctx.component, + rank=ctx.rank, + metadata=ctx._metadata + ) + + # 保存梯度B + if grad_B is not None: + save_tensor( + tensor=grad_B, + layer_type=ctx.layer_type, + operation="backward", + quant_type=f"mxfp_{ctx.elem_format}", + tensor_name=grad_tensor_name_B, + layer_idx=ctx.layer_idx, + phase="post", + component=ctx.component, + rank=ctx.rank, + metadata=ctx._metadata + ) + + except ImportError: + pass # 如果tensor_saver不可用,静默跳过 + except Exception as e: + pass # Silently ignore tensor saving errors + + return grad_A, grad_B, None, None, None, None, None, None, None, None, None, None # None对应所有额外参数(12个) + +class MXFPBAddBmm(Function): + @staticmethod + def forward(ctx, input, batch1, batch2, beta=1.0, alpha=1.0, + elem_format='fp8_e5m2', block_size=32, + layer_type=None, layer_idx=None, operation="forward", + phase="pre", component="attention", rank=None, metadata=None, + scaling_control="max"): + ctx.save_for_backward(input, batch1, batch2) + ctx.beta, ctx.alpha = beta, alpha + ctx.elem_format = elem_format + ctx.block_size = block_size + ctx.layer_type = layer_type + ctx.layer_idx = layer_idx + ctx.operation = operation + ctx.phase = phase + ctx.component = component + ctx.rank = rank + ctx._metadata = metadata + ctx.scaling_control = scaling_control + + # 使用集成了tensor保存的MXFPMatMul + mm_out = MXFPMatMul.apply(batch1, batch2, elem_format, block_size, + layer_type, layer_idx, operation, phase, component, rank, metadata, scaling_control) + output = beta * input + alpha * mm_out + + # 自动保存forward阶段的tensor + if layer_type is not None: + try: + from megatron.core.tensor_saver import save_tensor + + # 根据component类型确定tensor名称 + if component == "FA" or component == "attention": + # attention操作:input是matmul_input_buffer,batch1是query,batch2是key + tensor_name_input = "matmul_input_buffer" + tensor_name_batch1 = "query" + tensor_name_batch2 = "key" + else: + # 其他操作:使用通用名称 + tensor_name_input = "input" + tensor_name_batch1 = "batch1" + tensor_name_batch2 = "batch2" + + # 保存输入tensor + save_tensor( + tensor=input, + layer_type=layer_type, + operation=operation, + quant_type="mxfp", + tensor_name=tensor_name_input, + layer_idx=layer_idx, + phase=phase, + component=component, + rank=rank, + metadata=metadata + ) + + # 保存batch1 tensor + save_tensor( + tensor=batch1, + layer_type=layer_type, + operation=operation, + quant_type="mxfp", + tensor_name=tensor_name_batch1, + layer_idx=layer_idx, + phase=phase, + component=component, + rank=rank, + metadata=metadata + ) + + # 保存batch2 tensor + save_tensor( + tensor=batch2, + layer_type=layer_type, + operation=operation, + quant_type="mxfp", + tensor_name=tensor_name_batch2, + layer_idx=layer_idx, + phase=phase, + component=component, + rank=rank, + metadata=metadata + ) + + # 保存最终输出 + save_tensor( + tensor=output, + layer_type=layer_type, + operation=operation, + quant_type="mxfp", + tensor_name="output", + layer_idx=layer_idx, + phase=phase, + component=component, + rank=rank, + metadata=metadata + ) + + except ImportError: + pass # 如果tensor_saver不可用,静默跳过 + except Exception as e: + pass # Silently ignore tensor saving errors + + return output + + @staticmethod + def backward(ctx, grad_output): + input, batch1, batch2 = ctx.saved_tensors + beta, alpha = ctx.beta, ctx.alpha + + grad_input = grad_batch1 = grad_batch2 = None + if ctx.needs_input_grad[0]: + grad_input = beta * grad_output + if ctx.needs_input_grad[1] or ctx.needs_input_grad[2]: + mm_grad = alpha * grad_output + grad_batch1 = torch.matmul(mm_grad, batch2.transpose(-2, -1)) + grad_batch2 = torch.matmul(batch1.transpose(-2, -1), mm_grad) + + # 自动保存backward阶段的tensor + if ctx.layer_type is not None: + try: + from megatron.core.tensor_saver import save_tensor + + # 保存梯度输出 + save_tensor( + tensor=grad_output, + layer_type=ctx.layer_type, + operation="backward", + quant_type="mxfp", + tensor_name="grad_output", + layer_idx=ctx.layer_idx, + phase="post", + component=ctx.component, + rank=ctx.rank, + metadata=ctx._metadata + ) + + # 根据component类型确定backward tensor名称 + if ctx.component == "FA" or ctx.component == "attention": + # attention操作:grad_input是grad_matmul_input_buffer,grad_batch1是grad_query,grad_batch2是grad_key + grad_tensor_name_input = "grad_matmul_input_buffer" + grad_tensor_name_batch1 = "grad_query" + grad_tensor_name_batch2 = "grad_key" + else: + # 其他操作:使用通用名称 + grad_tensor_name_input = "grad_input" + grad_tensor_name_batch1 = "grad_batch1" + grad_tensor_name_batch2 = "grad_batch2" + + # 保存梯度input + if grad_input is not None: + save_tensor( + tensor=grad_input, + layer_type=ctx.layer_type, + operation="backward", + quant_type="mxfp", + tensor_name=grad_tensor_name_input, + layer_idx=ctx.layer_idx, + phase="post", + component=ctx.component, + rank=ctx.rank, + metadata=ctx._metadata + ) + + # 保存梯度batch1 + if grad_batch1 is not None: + save_tensor( + tensor=grad_batch1, + layer_type=ctx.layer_type, + operation="backward", + quant_type="mxfp", + tensor_name=grad_tensor_name_batch1, + layer_idx=ctx.layer_idx, + phase="post", + component=ctx.component, + rank=ctx.rank, + metadata=ctx._metadata + ) + + # 保存梯度batch2 + if grad_batch2 is not None: + save_tensor( + tensor=grad_batch2, + layer_type=ctx.layer_type, + operation="backward", + quant_type="mxfp", + tensor_name=grad_tensor_name_batch2, + layer_idx=ctx.layer_idx, + phase="post", + component=ctx.component, + rank=ctx.rank, + metadata=ctx._metadata + ) + + except ImportError: + pass # 如果tensor_saver不可用,静默跳过 + except Exception as e: + pass # Silently ignore tensor saving errors + + return grad_input, grad_batch1, grad_batch2, None, None, None, None, None, None, None, None, None, None, None, None # None对应所有额外参数(15个) + +def mxfp_matmul(A, B, elem_format='fp8_e5m2', block_size=32, scaling_control='max', **tensor_save_kwargs): + """ + MXFP矩阵乘法函数,支持tensor保存 + + Args: + A, B: 输入tensor + elem_format: 元素格式 + block_size: 块大小 + **tensor_save_kwargs: tensor保存相关参数 + - layer_type: 层类型 + - layer_idx: 层索引 + - operation: 操作类型 + - phase: 阶段 + - component: 组件类型 + - rank: GPU rank + - metadata: 元数据 + """ + # 如果有tensor保存参数,使用集成算子 + if tensor_save_kwargs and any(key in tensor_save_kwargs for key in + ['layer_type', 'layer_idx', 'operation', 'phase', 'component', 'rank', 'metadata']): + return MXFPMatMul.apply( + A, B, elem_format, block_size, + tensor_save_kwargs.get('layer_type'), + tensor_save_kwargs.get('layer_idx'), + tensor_save_kwargs.get('operation', 'forward'), + tensor_save_kwargs.get('phase', 'pre'), + tensor_save_kwargs.get('component', 'linear'), + tensor_save_kwargs.get('rank'), + tensor_save_kwargs.get('metadata'), + scaling_control + ) + else: + # 否则使用原始调用方式 + return MXFPMatMul.apply(A, B, elem_format, block_size, None, None, "forward", "pre", "linear", None, None, scaling_control) + +def mxfp_baddbmm(input, batch1, batch2, beta=1.0, alpha=1.0, + elem_format='fp8_e5m2', block_size=32, scaling_control='max', **tensor_save_kwargs): + """ + MXFP Batch Add Batch Matrix Multiplication函数,支持tensor保存 + + Args: + input, batch1, batch2: 输入tensor + beta, alpha: 参数 + elem_format: 元素格式 + block_size: 块大小 + **tensor_save_kwargs: tensor保存相关参数 + """ + # 如果有tensor保存参数,使用集成算子 + if tensor_save_kwargs and any(key in tensor_save_kwargs for key in + ['layer_type', 'layer_idx', 'operation', 'phase', 'component', 'rank', 'metadata']): + return MXFPBAddBmm.apply( + input, batch1, batch2, beta, alpha, elem_format, block_size, + tensor_save_kwargs.get('layer_type'), + tensor_save_kwargs.get('layer_idx'), + tensor_save_kwargs.get('operation', 'forward'), + tensor_save_kwargs.get('phase', 'pre'), + tensor_save_kwargs.get('component', 'attention'), + tensor_save_kwargs.get('rank'), + tensor_save_kwargs.get('metadata'), + scaling_control + ) + else: + # 否则使用原始调用方式 + return MXFPBAddBmm.apply(input, batch1, batch2, beta, alpha, elem_format, block_size, None, None, "forward", "pre", "attention", None, None, scaling_control) + +if __name__ == '__main__': + A = torch.load("grad_output.pt", map_location='cpu').cuda() + print(f"A_shape:{A.shape},grad_max:{torch.max(A)},grad_min:{torch.min(A)}") + B = torch.load("total_input.pt", map_location='cpu').cuda() + print(f"B_shape:{B.shape},input_max:{torch.max(B)},input_min:{torch.min(B)}") + A = A.unsqueeze(0).repeat(3, 1, 1) + B = B.unsqueeze(0).repeat(3, 1, 1) + C = torch.matmul(A.transpose(-2, -1), B) + D = torch.baddbmm(C,A.transpose(-2,-1),B) + print(f"C_shape:{C.shape},output_max:{torch.max(C)},output_min:{torch.min(C)}") + C_e4m3 = mxfp_matmul(A.transpose(-2,-1),B,'fp8_e4m3') + D_e4m3 = mxfp_baddbmm(C,A.transpose(-2,-1),B,elem_format='fp8_e4m3') + print(f"C_shape:{C_e4m3.shape},output_max:{torch.max(C_e4m3)},output_min:{torch.min(C_e4m3)}") + print(f"D_shape:{D_e4m3.shape},output_max:{torch.max(D_e4m3)},output_min:{torch.min(D_e4m3)}") + print(torch.isnan(C).any()) + + mse_e4m3 = torch.mean((C - C_e4m3) ** 2) + max_err_e4m3 = torch.max(torch.abs(C - C_e4m3)) + print(f"MSE: {mse_e4m3:.20f}") + print(f"Max Error: {max_err_e4m3:.20f}") + print(f"相对误差: {mse_e4m3 / torch.mean(C ** 2):.20f}") + + b_mse_e4m3 = torch.mean((D - D_e4m3) ** 2) + b_max_err_e4m3 = torch.max(torch.abs(D - D_e4m3)) + print(f"B_MSE: {b_mse_e4m3:.20f}") + print(f"B Max Error: {b_max_err_e4m3:.20f}") + + + diff --git a/quant/mxfp_scaling_test.py b/quant/mxfp_scaling_test.py new file mode 100644 index 0000000000..19b6793f78 --- /dev/null +++ b/quant/mxfp_scaling_test.py @@ -0,0 +1,1075 @@ +#!/usr/bin/env python3 +""" +MXFP Scaling Test Tool +Tests different scaling strategies for MXFP quantization and evaluates their impact on accuracy. +""" + +import torch +import numpy as np +import matplotlib.pyplot as plt +import argparse +from pathlib import Path +import sys +import os +import logging +from datetime import datetime + +# Add the parent directory to path to import mxfp module +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from quant.mxfp import _quantize_mx, _get_format_params, ElemFormat + +def setup_logging(output_dir, tensor_name, elem_format): + """ + Setup logging to both console and file. + + Args: + output_dir (Path): Output directory for log file + tensor_name (str): Name of the input tensor + elem_format (str): Element format being tested + + Returns: + logging.Logger: Configured logger + """ + # Create logger + logger = logging.getLogger('mxfp_scaling_test') + logger.setLevel(logging.INFO) + + # Clear any existing handlers + logger.handlers.clear() + + # Create formatter + formatter = logging.Formatter( + '%(asctime)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + + # Console handler + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(logging.INFO) + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + + # File handler + log_filename = f"mxfp_scaling_test_{tensor_name}_{elem_format}.log" + log_path = output_dir / log_filename + + file_handler = logging.FileHandler(log_path, mode='w', encoding='utf-8') + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + # Log initial information + logger.info("=" * 80) + logger.info("MXFP SCALING TEST LOG") + logger.info("=" * 80) + logger.info(f"Test started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + logger.info(f"Input tensor: {tensor_name}") + logger.info(f"Element format: {elem_format}") + logger.info(f"Output directory: {output_dir}") + logger.info("=" * 80) + + return logger + +def calculate_metrics(original_tensor, quantized_tensor): + """ + Calculate various metrics between original and quantized tensors. + + Args: + original_tensor (torch.Tensor): Original BF16 tensor + quantized_tensor (torch.Tensor): Quantized tensor + + Returns: + dict: Dictionary containing all calculated metrics + """ + # Convert to float32 for calculation + orig_f32 = original_tensor.float() + quant_f32 = quantized_tensor.float() + + # MSE (Mean Squared Error) + mse = torch.mean((orig_f32 - quant_f32) ** 2).item() + + # RMSE (Root Mean Squared Error) + rmse = torch.sqrt(torch.mean((orig_f32 - quant_f32) ** 2)).item() + + # Cosine Similarity + orig_flat = orig_f32.flatten() + quant_flat = quant_f32.flatten() + + # Avoid division by zero + orig_norm = torch.norm(orig_flat) + quant_norm = torch.norm(quant_flat) + + if orig_norm > 0 and quant_norm > 0: + cosine_sim = torch.dot(orig_flat, quant_flat) / (orig_norm * quant_norm) + cosine_sim = cosine_sim.item() + else: + cosine_sim = 1.0 if orig_norm == 0 and quant_norm == 0 else 0.0 + + # PSNR (Peak Signal-to-Noise Ratio) + if mse > 0: + # Use the maximum value in original tensor as peak signal + max_val = torch.max(torch.abs(orig_f32)).item() + psnr = 20 * np.log10(max_val / np.sqrt(mse)) if max_val > 0 else float('inf') + else: + psnr = float('inf') + + # MAE (Mean Absolute Error) + mae = torch.mean(torch.abs(orig_f32 - quant_f32)).item() + + # Maximum Absolute Error + max_abs_error = torch.max(torch.abs(orig_f32 - quant_f32)).item() + + # Relative Error (percentage) + orig_mean_abs = torch.mean(torch.abs(orig_f32)).item() + relative_error = (mae / orig_mean_abs * 100) if orig_mean_abs > 0 else 0.0 + + return { + 'mse': mse, + 'rmse': rmse, + 'cosine_similarity': cosine_sim, + 'psnr': psnr, + 'mae': mae, + 'max_abs_error': max_abs_error, + 'relative_error': relative_error + } + +def test_scaling_levels(input_tensor, elem_format='fp8_e4m3', scale_bits=8, + max_scale_exp=10, min_scale_exp=-10, logger=None): + """ + Test different scaling levels for MXFP quantization. + + Args: + input_tensor (torch.Tensor): Input BF16 tensor + elem_format (str): Element format for quantization + scale_bits (int): Number of scale bits + max_scale_exp (int): Maximum scale exponent (aligned with max value) + min_scale_exp (int): Minimum scale exponent (aligned with min value) + logger: Logger instance for output + + Returns: + dict: Results for each scaling level (all integers in range) + """ + # Get format parameters + ebits, mbits, emax, max_norm, min_norm = _get_format_params(elem_format) + + # Calculate tensor statistics for alignment + tensor_abs_max = torch.max(torch.abs(input_tensor)).item() + tensor_abs_min = torch.min(torch.abs(input_tensor[input_tensor != 0])).item() if torch.any(input_tensor != 0) else tensor_abs_max + + # Calculate emax for the format (following mxfp.py logic) + emax = 2**(ebits - 1) - 1 if ebits > 0 else 0 + + # Calculate scale exponents following mxfp.py _quantize_mx logic: + # In mxfp.py: + # 1. shared_exp = floor(log2(max_abs_value)) (from _shared_exponents with method="max") + # 2. shared_exp = shared_exp - emax (offset by emax) + # 3. A = A / (2^shared_exp) (apply scaling) + # + # So the actual scaling factor used by mxfp.py is: 2^(floor(log2(max)) - emax) + # + # For alignment calculations: + # - Max alignment: Use the same logic as mxfp.py (global max alignment) + # This gives: scale_exp = floor(log2(tensor_abs_max)) - emax + # - Min alignment: Find scale_exp such that tensor_abs_min / (2^scale_exp) >= min_norm + # So scale_exp <= log2(tensor_abs_min / min_norm) + + # Calculate the scale exponent that mxfp.py would use (for reference) + tensor_shared_exp = np.floor(np.log2(tensor_abs_max)) if tensor_abs_max > 0 else 0 + max_align_exp = tensor_shared_exp - emax # This is what mxfp.py actually uses + + # Calculate min alignment: find scale_exp such that scaled min >= min_norm + min_align_exp = np.floor(np.log2(tensor_abs_min / min_norm)) if tensor_abs_min > 0 and min_norm > 0 else max_align_exp + + # Use user-specified parameters directly, with calculated values as fallback for default parameters + if max_scale_exp == 10: # Default value, use calculated + max_scale_exp = max_align_exp + if min_scale_exp == -10: # Default value, use calculated + min_scale_exp = min_align_exp + + # Ensure max_scale_exp >= min_scale_exp + if max_scale_exp < min_scale_exp: + max_scale_exp, min_scale_exp = min_scale_exp, max_scale_exp + + # Generate integer scale exponents from max to min (inclusive) + max_exp_int = int(max_scale_exp) + min_exp_int = int(min_scale_exp) + + if max_exp_int == min_exp_int: + # Single point range - use the same integer value + scale_exponents = np.array([max_exp_int]) + else: + # Create integer range from max to min (inclusive) + scale_exponents = np.arange(max_exp_int, min_exp_int - 1, -1, dtype=int) + + results = { + 'scale_exponents': scale_exponents.tolist(), + 'metrics': {}, + 'elem_format': elem_format, + 'scale_bits': scale_bits, + 'format_params': { + 'ebits': ebits, + 'mbits': mbits, + 'emax': emax, + 'max_norm': max_norm, + 'min_norm': min_norm + } + } + + log_func = logger.info if logger else print + log_func(f"Tensor absolute value range: [{tensor_abs_min:.6e}, {tensor_abs_max:.6e}]") + log_func(f"Format range: max_norm={max_norm:.6e}, min_norm={min_norm:.6e}") + log_func(f"Calculated alignment (reference): max_align={max_align_exp:.2f}, min_align={min_align_exp:.2f}") + log_func(f"Testing integer scaling levels from {max_scale_exp:.2f} to {min_scale_exp:.2f}") + log_func(f"Element format: {elem_format} (e{ebits}m{mbits})") + log_func(f"Scale bits: {scale_bits}") + log_func("-" * 60) + + for i, scale_exp in enumerate(scale_exponents): + log_func(f"Testing scale exponent {scale_exp} ({i+1}/{len(scale_exponents)})...") + + # Create a custom quantize function with fixed scale exponent + quantized_tensor, overflow_underflow_analysis = quantize_with_fixed_scale( + input_tensor, elem_format, scale_bits, scale_exp, + ebits, mbits, max_norm + ) + + # Calculate metrics + metrics = calculate_metrics(input_tensor, quantized_tensor) + + # Store results + results['metrics'][f'scale_{i}'] = { + 'scale_exponent': float(scale_exp), + 'metrics': metrics, + 'overflow_underflow_analysis': overflow_underflow_analysis + } + + # Print current metrics + log_func(f" MSE: {metrics['mse']:.6e}, " + f"Cosine Sim: {metrics['cosine_similarity']:.6f}, " + f"PSNR: {metrics['psnr']:.2f} dB") + + return results + +def analyze_scaling_results(results, logger=None): + """ + Analyze scaling test results and recommend optimal scaling factors. + + Args: + results (dict): Results from test_scaling_levels + logger: Logger instance for output + + Returns: + dict: Analysis results with recommendations + """ + log_func = logger.info if logger else print + + scale_exponents = results['scale_exponents'] + elem_format = results['elem_format'] + format_params = results['format_params'] + + # Extract metrics for analysis + metrics_data = {} + for metric_name in ['mse', 'cosine_similarity', 'psnr', 'mae', 'relative_error']: + metrics_data[metric_name] = [] + for i in range(len(scale_exponents)): + scale_key = f'scale_{i}' + if scale_key in results['metrics']: + metrics_data[metric_name].append(results['metrics'][scale_key]['metrics'][metric_name]) + + # Find best indices for different metrics + # Use tolerance to handle numerical precision issues + tolerance = 1e-10 + + def find_best_indices(values, is_better_func): + """Find all indices with the best value, return the one with the largest scale exponent when tied""" + best_value = is_better_func(values) + if is_better_func == min: + best_indices = [i for i, v in enumerate(values) if abs(v - best_value) < tolerance] + else: # max + best_indices = [i for i, v in enumerate(values) if abs(v - best_value) < tolerance] + + if best_indices: + # When there are ties, choose the one with the largest scale exponent (closest to 0) + # Since scale_exponents are in descending order, the first index has the largest value + return best_indices[0] + else: + return 0 + + best_mse_idx = find_best_indices(metrics_data['mse'], min) + best_cosine_idx = find_best_indices(metrics_data['cosine_similarity'], max) + best_psnr_idx = find_best_indices(metrics_data['psnr'], max) + best_mae_idx = find_best_indices(metrics_data['mae'], min) + best_relative_error_idx = find_best_indices(metrics_data['relative_error'], min) + + # Calculate composite scores + # Normalize metrics to [0, 1] range for comparison + mse_normalized = 1 - (np.array(metrics_data['mse']) - np.min(metrics_data['mse'])) / (np.max(metrics_data['mse']) - np.min(metrics_data['mse']) + 1e-10) + cosine_normalized = np.array(metrics_data['cosine_similarity']) + psnr_normalized = (np.array(metrics_data['psnr']) - np.min(metrics_data['psnr'])) / (np.max(metrics_data['psnr']) - np.min(metrics_data['psnr']) + 1e-10) + mae_normalized = 1 - (np.array(metrics_data['mae']) - np.min(metrics_data['mae'])) / (np.max(metrics_data['mae']) - np.min(metrics_data['mae']) + 1e-10) + relative_error_normalized = 1 - (np.array(metrics_data['relative_error']) - np.min(metrics_data['relative_error'])) / (np.max(metrics_data['relative_error']) - np.min(metrics_data['relative_error']) + 1e-10) + + # Weighted composite score (can be adjusted based on priorities) + composite_scores = ( + 0.3 * mse_normalized + # Lower MSE is better + 0.3 * cosine_normalized + # Higher cosine similarity is better + 0.2 * psnr_normalized + # Higher PSNR is better + 0.1 * mae_normalized + # Lower MAE is better + 0.1 * relative_error_normalized # Lower relative error is better + ) + + # Find best composite index, handling ties by choosing larger scale exponent + best_composite_score = np.max(composite_scores) + best_composite_indices = [i for i, score in enumerate(composite_scores) if abs(score - best_composite_score) < tolerance] + # When there are ties, choose the one with the largest scale exponent (first index) + best_composite_idx = best_composite_indices[0] if best_composite_indices else 0 + + # Calculate scaling factor from scale exponent + def exp_to_factor(exp): + return 2 ** exp + + # Analysis results + analysis = { + 'best_mse': { + 'index': best_mse_idx, + 'scale_exp': scale_exponents[best_mse_idx], + 'scale_factor': exp_to_factor(scale_exponents[best_mse_idx]), + 'metrics': { + 'mse': metrics_data['mse'][best_mse_idx], + 'cosine_similarity': metrics_data['cosine_similarity'][best_mse_idx], + 'psnr': metrics_data['psnr'][best_mse_idx], + 'mae': metrics_data['mae'][best_mse_idx], + 'relative_error': metrics_data['relative_error'][best_mse_idx] + } + }, + 'best_cosine': { + 'index': best_cosine_idx, + 'scale_exp': scale_exponents[best_cosine_idx], + 'scale_factor': exp_to_factor(scale_exponents[best_cosine_idx]), + 'metrics': { + 'mse': metrics_data['mse'][best_cosine_idx], + 'cosine_similarity': metrics_data['cosine_similarity'][best_cosine_idx], + 'psnr': metrics_data['psnr'][best_cosine_idx], + 'mae': metrics_data['mae'][best_cosine_idx], + 'relative_error': metrics_data['relative_error'][best_cosine_idx] + } + }, + 'best_psnr': { + 'index': best_psnr_idx, + 'scale_exp': scale_exponents[best_psnr_idx], + 'scale_factor': exp_to_factor(scale_exponents[best_psnr_idx]), + 'metrics': { + 'mse': metrics_data['mse'][best_psnr_idx], + 'cosine_similarity': metrics_data['cosine_similarity'][best_psnr_idx], + 'psnr': metrics_data['psnr'][best_psnr_idx], + 'mae': metrics_data['mae'][best_psnr_idx], + 'relative_error': metrics_data['relative_error'][best_psnr_idx] + } + }, + 'best_composite': { + 'index': best_composite_idx, + 'scale_exp': scale_exponents[best_composite_idx], + 'scale_factor': exp_to_factor(scale_exponents[best_composite_idx]), + 'composite_score': composite_scores[best_composite_idx], + 'metrics': { + 'mse': metrics_data['mse'][best_composite_idx], + 'cosine_similarity': metrics_data['cosine_similarity'][best_composite_idx], + 'psnr': metrics_data['psnr'][best_composite_idx], + 'mae': metrics_data['mae'][best_composite_idx], + 'relative_error': metrics_data['relative_error'][best_composite_idx] + } + } + } + + # Log detailed analysis + log_func("\n" + "=" * 80) + log_func("SCALING FACTOR ANALYSIS & RECOMMENDATIONS") + log_func("=" * 80) + + log_func(f"Format: {elem_format} (e{format_params['ebits']}m{format_params['mbits']})") + log_func(f"Tested {len(scale_exponents)} scaling levels from {scale_exponents[0]:.2f} to {scale_exponents[-1]:.2f}") + log_func("-" * 80) + + # Check for ties in individual metrics + individual_indices = [best_mse_idx, best_cosine_idx, best_psnr_idx, best_mae_idx, best_relative_error_idx] + individual_names = ['MSE', 'Cosine Similarity', 'PSNR', 'MAE', 'Relative Error'] + + # Find if all individual metrics point to the same scale exponent + if len(set(individual_indices)) == 1: + log_func("🎯 ALL INDIVIDUAL METRICS AGREE:") + log_func("-" * 40) + log_func(f" All metrics recommend Scale Exp = {scale_exponents[individual_indices[0]]:.2f}") + log_func(f" Scale Factor = {analysis['best_mse']['scale_factor']:.6f}") + + # Check if there were ties and we chose the larger scale exponent + scale_exp = scale_exponents[individual_indices[0]] + all_same_values = [] + for i, (name, idx) in enumerate(zip(['MSE', 'Cosine', 'PSNR', 'MAE', 'Relative'], individual_indices)): + metric_values = [metrics_data[metric_name][idx] for metric_name in ['mse', 'cosine_similarity', 'psnr', 'mae', 'relative_error']] + all_same_values.extend([(name, metrics_data['mse'][idx]), (name, metrics_data['cosine_similarity'][idx])]) + + # Check for ties in the range + tied_indices = [] + for i in range(len(scale_exponents)): + if abs(scale_exponents[i] - scale_exp) < 0.1: # Check for nearby scale exponents + tied_indices.append(i) + + if len(tied_indices) > 1: + log_func(f" Note: Multiple scale exponents ({', '.join([f'{scale_exponents[i]:.2f}' for i in tied_indices])})") + log_func(f" produced identical performance. Selected largest: {scale_exp:.2f}") + else: + # Best results for individual metrics + log_func("INDIVIDUAL METRIC OPTIMA:") + log_func("-" * 40) + + log_func(f"🏆 Best MSE: Scale Exp = {analysis['best_mse']['scale_exp']:.2f}, Factor = {analysis['best_mse']['scale_factor']:.6f}") + log_func(f" MSE: {analysis['best_mse']['metrics']['mse']:.6e}, Cosine: {analysis['best_mse']['metrics']['cosine_similarity']:.6f}, PSNR: {analysis['best_mse']['metrics']['psnr']:.2f} dB") + + log_func(f"🎯 Best Cosine Similarity: Scale Exp = {analysis['best_cosine']['scale_exp']:.2f}, Factor = {analysis['best_cosine']['scale_factor']:.6f}") + log_func(f" MSE: {analysis['best_cosine']['metrics']['mse']:.6e}, Cosine: {analysis['best_cosine']['metrics']['cosine_similarity']:.6f}, PSNR: {analysis['best_cosine']['metrics']['psnr']:.2f} dB") + + log_func(f"📊 Best PSNR: Scale Exp = {analysis['best_psnr']['scale_exp']:.2f}, Factor = {analysis['best_psnr']['scale_factor']:.6f}") + log_func(f" MSE: {analysis['best_psnr']['metrics']['mse']:.6e}, Cosine: {analysis['best_psnr']['metrics']['cosine_similarity']:.6f}, PSNR: {analysis['best_psnr']['metrics']['psnr']:.2f} dB") + + # Composite recommendation + log_func("-" * 80) + log_func("COMPOSITE RECOMMENDATION:") + log_func("-" * 40) + + # Check if composite recommendation agrees with individual metrics + if len(set(individual_indices)) == 1 and individual_indices[0] == best_composite_idx: + log_func("🎯 UNANIMOUS RECOMMENDATION:") + log_func("-" * 40) + log_func(f" All individual metrics AND composite score agree!") + elif best_composite_idx in individual_indices: + log_func("📊 BALANCED RECOMMENDATION:") + log_func("-" * 40) + log_func(f" Composite score matches some individual metrics") + else: + log_func("⚖️ COMPOSITE RECOMMENDATION:") + log_func("-" * 40) + log_func(f" Composite score provides balanced recommendation") + + log_func(f"⭐ RECOMMENDED Scaling Factor: {analysis['best_composite']['scale_factor']:.6f}") + log_func(f" Scale Exponent: {analysis['best_composite']['scale_exp']:.2f}") + log_func(f" Composite Score: {analysis['best_composite']['composite_score']:.4f}") + log_func(f" Balanced Performance:") + log_func(f" - MSE: {analysis['best_composite']['metrics']['mse']:.6e}") + log_func(f" - Cosine Similarity: {analysis['best_composite']['metrics']['cosine_similarity']:.6f}") + log_func(f" - PSNR: {analysis['best_composite']['metrics']['psnr']:.2f} dB") + log_func(f" - MAE: {analysis['best_composite']['metrics']['mae']:.6e}") + log_func(f" - Relative Error: {analysis['best_composite']['metrics']['relative_error']:.2f}%") + + # Performance analysis + log_func("-" * 80) + log_func("PERFORMANCE ANALYSIS:") + log_func("-" * 40) + + # Calculate performance ranges + mse_range = np.max(metrics_data['mse']) - np.min(metrics_data['mse']) + cosine_range = np.max(metrics_data['cosine_similarity']) - np.min(metrics_data['cosine_similarity']) + psnr_range = np.max(metrics_data['psnr']) - np.min(metrics_data['psnr']) + + log_func(f"MSE Range: {np.min(metrics_data['mse']):.6e} to {np.max(metrics_data['mse']):.6e} (Δ: {mse_range:.6e})") + log_func(f"Cosine Range: {np.min(metrics_data['cosine_similarity']):.6f} to {np.max(metrics_data['cosine_similarity']):.6f} (Δ: {cosine_range:.6f})") + log_func(f"PSNR Range: {np.min(metrics_data['psnr']):.2f} to {np.max(metrics_data['psnr']):.2f} dB (Δ: {psnr_range:.2f} dB)") + + # Stability analysis + mse_std = np.std(metrics_data['mse']) + cosine_std = np.std(metrics_data['cosine_similarity']) + + log_func(f"MSE Stability (std): {mse_std:.6e}") + log_func(f"Cosine Stability (std): {cosine_std:.6f}") + + # Recommendations based on analysis + log_func("-" * 80) + log_func("RECOMMENDATIONS:") + log_func("-" * 40) + + if mse_range / np.min(metrics_data['mse']) < 0.1: + log_func("✅ MSE is relatively stable across scaling factors - any factor in the tested range should work well") + else: + log_func("⚠️ MSE varies significantly with scaling - choose the recommended factor carefully") + + if cosine_range < 0.01: + log_func("✅ Cosine similarity is very stable - scaling factor has minimal impact on direction preservation") + else: + log_func("⚠️ Cosine similarity varies with scaling - consider the impact on vector direction") + + if psnr_range > 20: + log_func("📈 Large PSNR range indicates significant quality differences - scaling factor choice is critical") + elif psnr_range > 10: + log_func("📊 Moderate PSNR range - scaling factor has noticeable impact on quality") + else: + log_func("✅ Small PSNR range - scaling factor has limited impact on quality") + + # Final recommendation + log_func("-" * 80) + log_func("FINAL RECOMMENDATION:") + log_func("-" * 40) + log_func(f"🎯 Use scaling factor: {analysis['best_composite']['scale_factor']:.6f}") + log_func(f" This provides the best balance of accuracy and stability for {elem_format} quantization") + log_func(f" Scale exponent: {analysis['best_composite']['scale_exp']:.2f}") + + if analysis['best_composite']['index'] == 0: + log_func(" 📍 This is at the maximum alignment end (minimal overflow risk)") + elif analysis['best_composite']['index'] == len(scale_exponents) - 1: + log_func(" 📍 This is at the minimum alignment end (minimal underflow risk)") + else: + log_func(" 📍 This is a balanced middle ground between overflow and underflow") + + log_func("=" * 80) + + return analysis + +def analyze_overflow_underflow_results(results, logger=None): + """ + Analyze and display overflow and underflow results from scaling tests. + + Args: + results (dict): Results from test_scaling_levels + logger: Logger instance for output + """ + log_func = logger.info if logger else print + + scale_exponents = results['scale_exponents'] + elem_format = results['elem_format'] + + # Collect all overflow/underflow analyses + overflow_underflow_results = [] + significant_issues = [] + + for i in range(len(scale_exponents)): + scale_key = f'scale_{i}' + if scale_key in results['metrics']: + analysis = results['metrics'][scale_key]['overflow_underflow_analysis'] + analysis['scale_exp'] = scale_exponents[i] + analysis['scale_factor'] = 2 ** scale_exponents[i] + overflow_underflow_results.append(analysis) + + if analysis['has_significant_underflow'] or analysis['has_significant_overflow']: + significant_issues.append(analysis) + + # Only display analysis if there are significant issues + if not significant_issues: + log_func("\n✅ No significant overflow or underflow issues detected across all scaling levels") + return + + # Display comprehensive overflow/underflow analysis + log_func("\n" + "=" * 80) + log_func("OVERFLOW/UNDERFLOW ANALYSIS SUMMARY") + log_func("=" * 80) + + log_func(f"Format: {elem_format}") + log_func(f"Analyzed {len(scale_exponents)} scaling levels") + log_func(f"Significant overflow/underflow detected in {len(significant_issues)} levels") + log_func("-" * 80) + + # Group by severity + high_severity = [u for u in significant_issues if u['severity'] == 'high'] + moderate_severity = [u for u in significant_issues if u['severity'] == 'moderate'] + + # Separate overflow and underflow issues + overflow_issues = [u for u in significant_issues if u['has_significant_overflow']] + underflow_issues = [u for u in significant_issues if u['has_significant_underflow']] + + # Display overflow issues + if overflow_issues: + log_func("🔴 OVERFLOW ISSUES:") + log_func("-" * 40) + for uf in overflow_issues: + log_func(f" Scale Exp: {uf['scale_exp']:.2f} (Factor: {uf['scale_factor']:.6f})") + log_func(f" Overflow: {uf['overflow_count']:,} ({uf['overflow_percent']:.2f}%)") + log_func(f" Max Normal: {uf['max_norm']:.2e}") + log_func(f" Tensor Range: [{uf['tensor_range'][0]:.2e}, {uf['tensor_range'][1]:.2e}]") + log_func(f" Severity: {uf['severity'].upper()}") + log_func("") + + # Display underflow issues + if underflow_issues: + log_func("🟡 UNDERFLOW ISSUES:") + log_func("-" * 40) + for uf in underflow_issues: + log_func(f" Scale Exp: {uf['scale_exp']:.2f} (Factor: {uf['scale_factor']:.6f})") + log_func(f" Underflow: {uf['underflow_count']:,} ({uf['underflow_percent']:.2f}%)") + log_func(f" Flush to Zero: {uf['flush_count']:,} ({uf['flush_percent']:.2f}%)") + log_func(f" Min Normal: {uf['min_norm']:.2e}") + log_func(f" Tensor Range: [{uf['tensor_range'][0]:.2e}, {uf['tensor_range'][1]:.2e}]") + log_func(f" Severity: {uf['severity'].upper()}") + log_func("") + + # Find best and worst cases + if overflow_issues: + worst_overflow = max(overflow_issues, key=lambda x: x['overflow_percent']) + log_func("OVERFLOW EXTREMES:") + log_func("-" * 40) + log_func(f"Worst Overflow: Scale Exp {worst_overflow['scale_exp']:.2f}") + log_func(f" {worst_overflow['overflow_percent']:.2f}% overflow") + + if underflow_issues: + worst_underflow = max(underflow_issues, key=lambda x: x['underflow_percent']) + best_underflow = min(underflow_issues, key=lambda x: x['underflow_percent']) + log_func("UNDERFLOW EXTREMES:") + log_func("-" * 40) + log_func(f"Worst Underflow: Scale Exp {worst_underflow['scale_exp']:.2f}") + log_func(f" {worst_underflow['underflow_percent']:.2f}% underflow, {worst_underflow['flush_percent']:.2f}% flushed to zero") + log_func(f"Best Underflow: Scale Exp {best_underflow['scale_exp']:.2f}") + log_func(f" {best_underflow['underflow_percent']:.2f}% underflow, {best_underflow['flush_percent']:.2f}% flushed to zero") + + # Recommendations + log_func("-" * 80) + log_func("OVERFLOW/UNDERFLOW RECOMMENDATIONS:") + log_func("-" * 40) + + if high_severity: + log_func("⚠️ AVOID scaling factors with HIGH overflow/underflow severity") + log_func(" These factors cause significant precision loss") + + if overflow_issues: + log_func("🔴 OVERFLOW WARNING:") + log_func(" Avoid scaling factors that cause overflow") + log_func(" These values will be saturated to max representable value") + + if underflow_issues: + log_func("🟡 UNDERFLOW CONSIDERATIONS:") + log_func(" Moderate underflow may be acceptable depending on use case") + log_func(" Balance between underflow and overflow risks") + + # Find optimal range + no_issue_levels = [u for u in overflow_underflow_results if not u['has_significant_underflow'] and not u['has_significant_overflow']] + if no_issue_levels: + optimal_range = [min(u['scale_exp'] for u in no_issue_levels), + max(u['scale_exp'] for u in no_issue_levels)] + log_func(f"✅ RECOMMENDED scaling range: {optimal_range[0]:.2f} to {optimal_range[1]:.2f}") + log_func(" This range minimizes both overflow and underflow issues") + else: + log_func("⚠️ All scaling levels have some overflow/underflow - choose least problematic") + # Find least problematic range + least_problematic = min(overflow_underflow_results, key=lambda x: max(x['overflow_percent'], x['underflow_percent'])) + log_func(f"💡 Least problematic scaling: {least_problematic['scale_exp']:.2f}") + log_func(f" Overflow: {least_problematic['overflow_percent']:.2f}%, Underflow: {least_problematic['underflow_percent']:.2f}%") + + log_func("=" * 80) + +def quantize_with_fixed_scale(input_tensor, elem_format, scale_bits, scale_exp, + ebits, mbits, max_norm, axes=None, block_size=0): + """ + Custom quantization function with fixed scale exponent. + This function simulates the exact behavior of mxfp.py _quantize_mx function. + + Args: + input_tensor (torch.Tensor): Input tensor + elem_format (str): Element format + scale_bits (int): Number of scale bits + scale_exp (float): Fixed scale exponent (log2 of scaling factor) + ebits (int): Exponent bits + mbits (int): Mantissa bits + max_norm (float): Maximum normal value + axes (list): Axes for shared exponent calculation + block_size (int): Block size for tiling + + Returns: + tuple: (quantized_tensor, overflow_underflow_analysis) + """ + A = input_tensor.clone() + + # Apply scaling directly (this simulates the A = A / (2**shared_exp) step in mxfp.py) + scale_factor = 2.0 ** scale_exp # Use float to handle negative exponents + A = A / scale_factor + + # Quantize element-wise + from quant.mxfp import _quantize_elemwise_core,_analyze_overflow_underflow_before_quantization + + # Analyze overflow/underflow without printing (collect results) + overflow_underflow_analysis = _analyze_overflow_underflow_before_quantization( + A, elem_format, mbits, ebits, max_norm, verbose=False + ) + + A = _quantize_elemwise_core( + A, mbits, ebits, max_norm, round='nearest', + allow_denorm=True, saturate_normals=True + ) + + # Undo scaling + A = A * scale_factor + + return A, overflow_underflow_analysis + +def plot_scaling_results(results, output_path): + """ + Create comprehensive plots showing scaling test results. + + Args: + results (dict): Results from test_scaling_levels + output_path (Path): Output directory for plots + """ + scale_exponents = results['scale_exponents'] + elem_format = results['elem_format'] + + # Extract metrics for plotting + metrics_data = {} + for metric_name in ['mse', 'rmse', 'cosine_similarity', 'psnr', 'mae', 'max_abs_error', 'relative_error']: + metrics_data[metric_name] = [] + for i in range(len(scale_exponents)): + scale_key = f'scale_{i}' + if scale_key in results['metrics']: + metrics_data[metric_name].append(results['metrics'][scale_key]['metrics'][metric_name]) + + # Create figure with subplots + fig, axes = plt.subplots(3, 2, figsize=(15, 18)) + fig.suptitle(f'MXFP Scaling Test Results - {elem_format.upper()}', fontsize=16, fontweight='bold') + + # Plot 1: MSE + axes[0, 0].semilogy(scale_exponents, metrics_data['mse'], 'b-o', linewidth=2, markersize=4) + axes[0, 0].set_xlabel('Scale Exponent') + axes[0, 0].set_ylabel('MSE (log scale)') + axes[0, 0].set_title('Mean Squared Error vs Scale Exponent') + axes[0, 0].grid(True, alpha=0.3) + + # Plot 2: Cosine Similarity + axes[0, 1].plot(scale_exponents, metrics_data['cosine_similarity'], 'g-o', linewidth=2, markersize=4) + axes[0, 1].set_xlabel('Scale Exponent') + axes[0, 1].set_ylabel('Cosine Similarity') + axes[0, 1].set_title('Cosine Similarity vs Scale Exponent') + axes[0, 1].grid(True, alpha=0.3) + axes[0, 1].set_ylim([0, 1]) + + # Plot 3: PSNR + # Handle infinite PSNR values + psnr_values = metrics_data['psnr'] + psnr_finite = [p if p != float('inf') else 1000 for p in psnr_values] # Cap at 1000 for plotting + + axes[1, 0].plot(scale_exponents, psnr_finite, 'r-o', linewidth=2, markersize=4) + axes[1, 0].set_xlabel('Scale Exponent') + axes[1, 0].set_ylabel('PSNR (dB)') + axes[1, 0].set_title('Peak Signal-to-Noise Ratio vs Scale Exponent') + axes[1, 0].grid(True, alpha=0.3) + + # Plot 4: MAE + axes[1, 1].semilogy(scale_exponents, metrics_data['mae'], 'm-o', linewidth=2, markersize=4) + axes[1, 1].set_xlabel('Scale Exponent') + axes[1, 1].set_ylabel('MAE (log scale)') + axes[1, 1].set_title('Mean Absolute Error vs Scale Exponent') + axes[1, 1].grid(True, alpha=0.3) + + # Plot 5: Maximum Absolute Error + axes[2, 0].semilogy(scale_exponents, metrics_data['max_abs_error'], 'c-o', linewidth=2, markersize=4) + axes[2, 0].set_xlabel('Scale Exponent') + axes[2, 0].set_ylabel('Max Absolute Error (log scale)') + axes[2, 0].set_title('Maximum Absolute Error vs Scale Exponent') + axes[2, 0].grid(True, alpha=0.3) + + # Plot 6: Relative Error + axes[2, 1].plot(scale_exponents, metrics_data['relative_error'], 'orange', marker='o', linewidth=2, markersize=4) + axes[2, 1].set_xlabel('Scale Exponent') + axes[2, 1].set_ylabel('Relative Error (%)') + axes[2, 1].set_title('Relative Error vs Scale Exponent') + axes[2, 1].grid(True, alpha=0.3) + + # Add format information + format_params = results['format_params'] + info_text = f"Format: {elem_format}\nE-bits: {format_params['ebits']}, M-bits: {format_params['mbits']}\n" + info_text += f"Max Normal: ±{format_params['max_norm']:.1e}\nMin Normal: {format_params['min_norm']:.1e}" + + fig.text(0.02, 0.02, info_text, fontsize=10, verticalalignment='bottom', + bbox=dict(boxstyle='round,pad=0.5', facecolor='lightgray', alpha=0.8)) + + plt.tight_layout() + plt.subplots_adjust(top=0.93, bottom=0.15) + + # Save plot + plot_path = output_path / f'mxfp_scaling_test_{elem_format}.png' + plt.savefig(plot_path, dpi=300, bbox_inches='tight', facecolor='white') + plt.close() + + # This will be logged by the caller + pass + + # Create summary plot with key metrics + create_summary_plot(results, output_path) + +def create_summary_plot(results, output_path): + """Create a summary plot with the most important metrics.""" + scale_exponents = results['scale_exponents'] + elem_format = results['elem_format'] + + # Extract key metrics + mse_values = [] + cosine_sim_values = [] + psnr_values = [] + + for i in range(len(scale_exponents)): + scale_key = f'scale_{i}' + if scale_key in results['metrics']: + metrics = results['metrics'][scale_key]['metrics'] + mse_values.append(metrics['mse']) + cosine_sim_values.append(metrics['cosine_similarity']) + psnr_values.append(metrics['psnr']) + + # Handle infinite PSNR values + psnr_finite = [p if p != float('inf') else 1000 for p in psnr_values] + + # Create summary plot + fig, ax1 = plt.subplots(figsize=(12, 8)) + + # Plot MSE and PSNR on left y-axis + color1 = 'tab:blue' + ax1.set_xlabel('Scale Exponent', fontsize=12) + ax1.set_ylabel('MSE (log scale)', color=color1, fontsize=12) + line1 = ax1.semilogy(scale_exponents, mse_values, 'o-', color=color1, linewidth=2, markersize=6, label='MSE') + ax1.tick_params(axis='y', labelcolor=color1) + ax1.grid(True, alpha=0.3) + + # Create second y-axis for cosine similarity + ax2 = ax1.twinx() + color2 = 'tab:green' + ax2.set_ylabel('Cosine Similarity', color=color2, fontsize=12) + line2 = ax2.plot(scale_exponents, cosine_sim_values, 's-', color=color2, linewidth=2, markersize=6, label='Cosine Similarity') + ax2.tick_params(axis='y', labelcolor=color2) + ax2.set_ylim([0, 1]) + + # Add PSNR as dashed line on ax1 + ax1_2 = ax1.twinx() + ax1_2.spines['right'].set_position(('outward', 60)) + color3 = 'tab:red' + ax1_2.set_ylabel('PSNR (dB)', color=color3, fontsize=12) + line3 = ax1_2.plot(scale_exponents, psnr_finite, '^-', color=color3, linewidth=2, markersize=6, linestyle='--', label='PSNR') + ax1_2.tick_params(axis='y', labelcolor=color3) + + # Add title and legend + plt.title(f'MXFP Scaling Test Summary - {elem_format.upper()}\nKey Metrics vs Scale Exponent', + fontsize=14, fontweight='bold', pad=20) + + # Combine legends + lines = line1 + line2 + line3 + labels = [l.get_label() for l in lines] + ax1.legend(lines, labels, loc='upper right', fontsize=10) + + plt.tight_layout() + + # Save summary plot + summary_path = output_path / f'mxfp_scaling_summary_{elem_format}.png' + plt.savefig(summary_path, dpi=300, bbox_inches='tight', facecolor='white') + plt.close() + + # This will be logged by the caller + pass + +def save_results_to_file(results, output_path): + """Save detailed results to a text file.""" + results_path = output_path / f'mxfp_scaling_results_{results["elem_format"]}.txt' + + with open(results_path, 'w') as f: + f.write("MXFP Scaling Test Results\n") + f.write("=" * 50 + "\n\n") + + f.write(f"Element Format: {results['elem_format']}\n") + f.write(f"Scale Bits: {results['scale_bits']}\n") + f.write(f"Format Parameters: {results['format_params']}\n\n") + + f.write("Detailed Results:\n") + f.write("-" * 30 + "\n") + + for i, scale_exp in enumerate(results['scale_exponents']): + scale_key = f'scale_{i}' + if scale_key in results['metrics']: + metrics = results['metrics'][scale_key]['metrics'] + overflow_underflow_analysis = results['metrics'][scale_key]['overflow_underflow_analysis'] + + f.write(f"Scale Exponent {scale_exp:.2f} (Factor: {2**scale_exp:.6f}):\n") + f.write(" Performance Metrics:\n") + f.write(f" MSE: {metrics['mse']:.6e}\n") + f.write(f" RMSE: {metrics['rmse']:.6e}\n") + f.write(f" Cosine Similarity: {metrics['cosine_similarity']:.6f}\n") + f.write(f" PSNR: {metrics['psnr']:.2f} dB\n") + f.write(f" MAE: {metrics['mae']:.6e}\n") + f.write(f" Max Absolute Error: {metrics['max_abs_error']:.6e}\n") + f.write(f" Relative Error: {metrics['relative_error']:.2f}%\n") + + f.write(" Overflow/Underflow Analysis:\n") + f.write(f" Total Elements: {overflow_underflow_analysis['total_elements']:,}\n") + f.write(f" Underflow Count: {overflow_underflow_analysis['underflow_count']:,} ({overflow_underflow_analysis['underflow_percent']:.2f}%)\n") + f.write(f" Flush to Zero Count: {overflow_underflow_analysis['flush_count']:,} ({overflow_underflow_analysis['flush_percent']:.2f}%)\n") + f.write(f" Overflow Count: {overflow_underflow_analysis['overflow_count']:,} ({overflow_underflow_analysis['overflow_percent']:.2f}%)\n") + f.write(f" Min Denormal: {overflow_underflow_analysis['min_denormal']:.2e}\n") + f.write(f" Min Normal: {overflow_underflow_analysis['min_norm']:.2e}\n") + f.write(f" Max Normal: {overflow_underflow_analysis['max_norm']:.2e}\n") + f.write(f" Tensor Range: [{overflow_underflow_analysis['tensor_range'][0]:.2e}, {overflow_underflow_analysis['tensor_range'][1]:.2e}]\n") + f.write(f" Severity: {overflow_underflow_analysis['severity'].upper()}\n") + f.write(f" Has Significant Underflow: {'Yes' if overflow_underflow_analysis['has_significant_underflow'] else 'No'}\n") + f.write(f" Has Significant Overflow: {'Yes' if overflow_underflow_analysis['has_significant_overflow'] else 'No'}\n") + if overflow_underflow_analysis['error']: + f.write(f" Analysis Error: {overflow_underflow_analysis['error']}\n") + f.write("\n") + + # This will be logged by the caller + pass + +def process_single_tensor(input_path, args, logger=None): + """Process a single tensor file.""" + + # Validate input file + if not input_path.exists(): + print(f"Error: Input file does not exist: {input_path}") + return 1 + + if not input_path.is_file(): + print(f"Error: Input path is not a file: {input_path}") + return 1 + + # Setup output directory + if args.output_dir is None: + # Generate output directory based on tensor name + tensor_name = input_path.stem # Get filename without extension + output_dir = Path(f"./draw/scaling_analysis/{args.elem_format}/{tensor_name}") + else: + output_dir = Path(args.output_dir) + + output_dir.mkdir(parents=True, exist_ok=True) + + # Setup logging for this tensor + tensor_name = input_path.stem + tensor_logger = setup_logging(output_dir, tensor_name, args.elem_format) + + tensor_logger.info(f"Loading input tensor: {input_path.name}") + tensor_logger.info("=" * 60) + + # Load input tensor + try: + input_tensor = torch.load(str(input_path), map_location='cpu', weights_only=False) + + # Handle case where loaded object is not a tensor + if not isinstance(input_tensor, torch.Tensor): + if isinstance(input_tensor, dict) and 'tensor' in input_tensor: + input_tensor = input_tensor['tensor'] + elif isinstance(input_tensor, (list, tuple)) and len(input_tensor) > 0: + input_tensor = input_tensor[0] + else: + tensor_logger.error(f"Error: Loaded object is not a tensor: {input_path.name}") + return 1 + + # Convert to BF16 if needed + if input_tensor.dtype != torch.bfloat16: + tensor_logger.info(f"Converting tensor from {input_tensor.dtype} to bfloat16") + input_tensor = input_tensor.bfloat16() + + tensor_logger.info(f"Tensor shape: {input_tensor.shape}") + tensor_logger.info(f"Tensor dtype: {input_tensor.dtype}") + tensor_logger.info(f"Value range: [{torch.min(input_tensor):.6f}, {torch.max(input_tensor):.6f}]") + tensor_logger.info(f"Mean ± Std: {torch.mean(input_tensor):.6f} ± {torch.std(input_tensor):.6f}") + + except Exception as e: + tensor_logger.error(f"Error loading tensor {input_path.name}: {str(e)}") + return 1 + + # Run scaling test + results = test_scaling_levels( + input_tensor, + args.elem_format, + args.scale_bits, + max_scale_exp=args.max_scale_exp, + min_scale_exp=args.min_scale_exp, + logger=tensor_logger + ) + + # Save results to file + save_results_to_file(results, output_dir) + tensor_logger.info(f"Detailed results saved to: {output_dir}") + + # Generate plots unless disabled + if not args.no_plots: + plot_scaling_results(results, output_dir) + tensor_logger.info(f"Plots saved to: {output_dir}") + + # Perform detailed analysis + analysis_results = analyze_scaling_results(results, tensor_logger) + + # Analyze overflow/underflow results + analyze_overflow_underflow_results(results, tensor_logger) + + # Print summary + tensor_logger.info("\n" + "=" * 60) + tensor_logger.info("SCALING TEST SUMMARY") + tensor_logger.info("=" * 60) + + # Use analysis results for summary + best_composite = analysis_results['best_composite'] + best_mse = analysis_results['best_mse'] + best_cosine = analysis_results['best_cosine'] + + tensor_logger.info(f"Best Cosine Similarity: {best_cosine['metrics']['cosine_similarity']:.6f} at scale {best_cosine['scale_exp']:.2f}") + tensor_logger.info(f"Best MSE: {best_mse['metrics']['mse']:.6e} at scale {best_mse['scale_exp']:.2f}") + tensor_logger.info(f"Best PSNR: {best_mse['metrics']['psnr']:.2f} dB at scale {best_mse['scale_exp']:.2f}") + + tensor_logger.info(f"\n🎯 RECOMMENDED Scaling Factor: {best_composite['scale_factor']:.6f}") + tensor_logger.info(f" Scale Exponent: {best_composite['scale_exp']:.2f}") + tensor_logger.info(f" Composite Score: {best_composite['composite_score']:.4f}") + + tensor_logger.info(f"\nResults saved to: {output_dir}") + tensor_logger.info("Test completed successfully!") + + # Log completion time + tensor_logger.info("=" * 80) + tensor_logger.info(f"Test completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + tensor_logger.info("=" * 80) + + return 0 + +def main(): + """Main function for MXFP scaling test.""" + parser = argparse.ArgumentParser(description='Test different scaling strategies for MXFP quantization') + parser.add_argument('input_tensors', nargs='+', help='Path(s) to input BF16 tensor file(s) (.pt)') + parser.add_argument('--output-dir', default=None, + help='Output directory for results (default: ./draw/scaling_analysis/{args.elem_format}/{tensor_name}/)') + parser.add_argument('--elem-format', default='fp8_e4m3', + choices=['fp8_e4m3', 'fp8_e5m2', 'fp4_e2m1', 'fp6_e3m2', 'fp6_e2m3'], + help='Element format for quantization (default: fp8_e4m3)') + parser.add_argument('--scale-bits', type=int, default=8, + help='Number of scale bits (default: 8)') + parser.add_argument('--max-scale-exp', type=int, default=10, + help='Maximum scale exponent (default: auto-calculated from tensor max if using default value)') + parser.add_argument('--min-scale-exp', type=int, default=-10, + help='Minimum scale exponent (default: auto-calculated from tensor min if using default value)') + parser.add_argument('--no-plots', action='store_true', + help='Skip generating plots') + + args = parser.parse_args() + + # Process multiple tensors + total_tensors = len(args.input_tensors) + successful_tests = 0 + + print(f"Processing {total_tensors} tensor(s)...") + print("=" * 80) + + for i, tensor_path in enumerate(args.input_tensors, 1): + print(f"\n[{i}/{total_tensors}] Processing: {tensor_path}") + print("-" * 60) + + input_path = Path(tensor_path) + result = process_single_tensor(input_path, args) + + if result == 0: + successful_tests += 1 + print(f"✅ Successfully processed: {tensor_path}") + else: + print(f"❌ Failed to process: {tensor_path}") + + # Final summary + print("\n" + "=" * 80) + print("FINAL SUMMARY") + print("=" * 80) + print(f"Total tensors: {total_tensors}") + print(f"Successful: {successful_tests}") + print(f"Failed: {total_tensors - successful_tests}") + + if successful_tests == total_tensors: + print("🎉 All tests completed successfully!") + return 0 + else: + print("⚠️ Some tests failed. Check individual logs for details.") + return 1 + +if __name__ == '__main__': + exit(main()) diff --git a/quant/opt.py b/quant/opt.py new file mode 100644 index 0000000000..4d37fdcd59 --- /dev/null +++ b/quant/opt.py @@ -0,0 +1,303 @@ +import math +import torch +# import torch_npu +from typing import Tuple + +def optimized_mxfp8_e4m3_matmul(A: torch.Tensor, B: torch.Tensor, block_size: int = 32) -> torch.Tensor: + """优化的MXFP8矩阵乘法实现""" + device = A.device + M, K = A.shape + K2, N = B.shape + assert K == K2, "Inner dimensions must match" + + # E4M3格式常量 + MAX_VAL = 448.0 + + # 预分配输出张量 + C = torch.zeros((M, N), dtype=torch.float32, device=device) + + # 向量化的量化函数 + def vectorized_quantize_e4m3(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """向量化的E4M3量化""" + # x shape: (M, num_blocks, block_size) 或 (num_blocks, block_size, N) + # scale shape: (M, num_blocks, 1) 或 (num_blocks, 1, N) + + # 缩放输入 + scaled_x = x / scale + + # 处理符号 + sign = torch.sign(scaled_x) + abs_x = torch.abs(scaled_x) + + # 饱和处理 + abs_x = torch.clamp(abs_x, 0, MAX_VAL) + + # 近似量化 - 使用查找表或简化的量化 + # 这里使用简化版本:将值量化到2^n * (1 + k/8)的网格 + log_abs = torch.log2(torch.clamp(abs_x, min=1e-10)) + exp = torch.floor(log_abs).clamp(-6, 8) + + # 计算量化后的值 + base = torch.pow(2.0, exp) + normalized = abs_x / base + # 量化mantissa到8个level (3 bits) + quantized_mant = torch.round((normalized - 1.0) * 8) / 8 + quantized_mant = torch.clamp(quantized_mant, 0, 7/8) + + result = sign * base * (1.0 + quantized_mant) + return result + + # 批量处理A的行块 + def process_A_blocks(A_tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """批量处理A的行块量化""" + M, K = A_tensor.shape + num_blocks = (K + block_size - 1) // block_size + + # 重塑为块结构 + padded_K = num_blocks * block_size + if padded_K > K: + A_padded = torch.zeros((M, padded_K), dtype=A_tensor.dtype, device=A_tensor.device) + A_padded[:, :K] = A_tensor + else: + A_padded = A_tensor + + A_blocks = A_padded.view(M, num_blocks, block_size) + + # 计算每个块的缩放因子 (M, num_blocks, 1) + block_max = torch.abs(A_blocks).max(dim=2, keepdim=True)[0] + scales = torch.pow(2.0, torch.ceil(torch.log2(block_max / MAX_VAL))) + scales = torch.where(block_max == 0, torch.ones_like(scales), scales) + + # 向量化量化 + A_quantized = vectorized_quantize_e4m3(A_blocks, scales) + + return A_quantized.view(M, padded_K)[:, :K], scales.squeeze(-1) + + # 批量处理B的列块 + def process_B_blocks(B_tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """批量处理B的列块量化""" + K, N = B_tensor.shape + num_blocks = (K + block_size - 1) // block_size + + # 重塑为块结构 + padded_K = num_blocks * block_size + if padded_K > K: + B_padded = torch.zeros((padded_K, N), dtype=B_tensor.dtype, device=B_tensor.device) + B_padded[:K, :] = B_tensor + else: + B_padded = B_tensor + + B_blocks = B_padded.view(num_blocks, block_size, N) + + # 计算每个块的缩放因子 (num_blocks, 1, N) + block_max = torch.abs(B_blocks).max(dim=1, keepdim=True)[0] + scales = torch.pow(2.0, torch.ceil(torch.log2(block_max / MAX_VAL))) + scales = torch.where(block_max == 0, torch.ones_like(scales), scales) + + # 向量化量化 + B_quantized = vectorized_quantize_e4m3(B_blocks, scales) + + return B_quantized.view(padded_K, N)[:K, :], scales.squeeze(1) + + # 量化A和B + A_quantized, A_scales = process_A_blocks(A) + B_quantized, B_scales = process_B_blocks(B) + + # 分块矩阵乘法 + num_blocks = (K + block_size - 1) // block_size + + for block_idx in range(num_blocks): + start_k = block_idx * block_size + end_k = min(start_k + block_size, K) + + # 提取当前块 + A_block = A_quantized[:, start_k:end_k] + B_block = B_quantized[start_k:end_k, :] + + # 计算部分乘积 + partial = torch.matmul(A_block, B_block) + + # 应用缩放因子 + A_scale_block = A_scales[:, block_idx:block_idx+1] + B_scale_block = B_scales[block_idx:block_idx+1, :] + combined_scale = A_scale_block * B_scale_block + + # 累加到结果 + C += partial * combined_scale + + return C + + +def optimized_mxfp8_e5m2_matmul(A: torch.Tensor, B: torch.Tensor, block_size: int = 32) -> torch.Tensor: + """优化的MXFP8-E5M2矩阵乘法实现""" + device = A.device + M, K = A.shape + K2, N = B.shape + assert K == K2, "Inner dimensions must match" + + # E5M2格式常量 + # E5M2: 1 sign + 5 exponent + 2 mantissa + # 指数范围: -15 到 +16 (偏置15) + # 最大值: 2^16 * (1 + 3/4) = 2^16 * 1.75 = 114688 + MAX_VAL = 114688.0 + MIN_EXP = -15 + MAX_EXP = 16 + + # 预分配输出张量 + C = torch.zeros((M, N), dtype=torch.float32, device=device) + + # 向量化的量化函数 + def vectorized_quantize_e5m2(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """向量化的E5M2量化""" + # x shape: (M, num_blocks, block_size) 或 (num_blocks, block_size, N) + # scale shape: (M, num_blocks, 1) 或 (num_blocks, 1, N) + + # 缩放输入 + scaled_x = x / scale + + # 处理符号 + sign = torch.sign(scaled_x) + abs_x = torch.abs(scaled_x) + + # 饱和处理 + abs_x = torch.clamp(abs_x, 0, MAX_VAL) + + # 处理特殊情况:零值 + zero_mask = (abs_x == 0) + + # E5M2量化 + # 计算指数 + log_abs = torch.log2(torch.clamp(abs_x, min=1e-20)) + exp = torch.floor(log_abs).clamp(MIN_EXP, MAX_EXP) + + # 计算基数值 + base = torch.pow(2.0, exp) + + # 计算归一化的尾数 (范围 [1, 2)) + normalized = abs_x / base + + # 量化尾数到4个级别 (2 bits: 00, 01, 10, 11) + # 对应值: 1.00, 1.25, 1.50, 1.75 + quantized_mant_idx = torch.round((normalized - 1.0) * 4).clamp(0, 3) + quantized_mant = 1.0 + quantized_mant_idx / 4.0 + + # 计算量化后的结果 + result = sign * base * quantized_mant + + # 处理零值 + result = torch.where(zero_mask, torch.zeros_like(result), result) + + return result + + # 批量处理A的行块 + def process_A_blocks(A_tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """批量处理A的行块量化""" + M, K = A_tensor.shape + num_blocks = (K + block_size - 1) // block_size + + # 重塑为块结构 + padded_K = num_blocks * block_size + if padded_K > K: + A_padded = torch.zeros((M, padded_K), dtype=A_tensor.dtype, device=A_tensor.device) + A_padded[:, :K] = A_tensor + else: + A_padded = A_tensor + + A_blocks = A_padded.view(M, num_blocks, block_size) + + # 计算每个块的缩放因子 (M, num_blocks, 1) + block_max = torch.abs(A_blocks).max(dim=2, keepdim=True)[0] + + # 对于E5M2,缩放因子计算需要考虑更大的指数范围 + scales = torch.pow(2.0, torch.ceil(torch.log2(block_max / MAX_VAL))) + scales = torch.where(block_max == 0, torch.ones_like(scales), scales) + + # 向量化量化 + A_quantized = vectorized_quantize_e5m2(A_blocks, scales) + + return A_quantized.view(M, padded_K)[:, :K], scales.squeeze(-1) + + # 批量处理B的列块 + def process_B_blocks(B_tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """批量处理B的列块量化""" + K, N = B_tensor.shape + num_blocks = (K + block_size - 1) // block_size + + # 重塑为块结构 + padded_K = num_blocks * block_size + if padded_K > K: + B_padded = torch.zeros((padded_K, N), dtype=B_tensor.dtype, device=B_tensor.device) + B_padded[:K, :] = B_tensor + else: + B_padded = B_tensor + + B_blocks = B_padded.view(num_blocks, block_size, N) + + # 计算每个块的缩放因子 (num_blocks, 1, N) + block_max = torch.abs(B_blocks).max(dim=1, keepdim=True)[0] + + # 对于E5M2,缩放因子计算需要考虑更大的指数范围 + scales = torch.pow(2.0, torch.ceil(torch.log2(block_max / MAX_VAL))) + scales = torch.where(block_max == 0, torch.ones_like(scales), scales) + + # 向量化量化 + B_quantized = vectorized_quantize_e5m2(B_blocks, scales) + + return B_quantized.view(padded_K, N)[:K, :], scales.squeeze(1) + + # 量化A和B + A_quantized, A_scales = process_A_blocks(A) + B_quantized, B_scales = process_B_blocks(B) + + # 分块矩阵乘法 + num_blocks = (K + block_size - 1) // block_size + + for block_idx in range(num_blocks): + start_k = block_idx * block_size + end_k = min(start_k + block_size, K) + + # 提取当前块 + A_block = A_quantized[:, start_k:end_k] + B_block = B_quantized[start_k:end_k, :] + + # 计算部分乘积 + partial = torch.matmul(A_block, B_block) + + # 应用缩放因子 + A_scale_block = A_scales[:, block_idx:block_idx+1] + B_scale_block = B_scales[block_idx:block_idx+1, :] + combined_scale = A_scale_block * B_scale_block + + # 累加到结果 + C += partial * combined_scale + + return C + + +if __name__ == "__main__": + M, K, N = 128, 256, 64 + M, K, N = 1024, 256, 1024 + + A = torch.rand((M, N), dtype=torch.bfloat16).cuda() + B = torch.rand((N, K), dtype=torch.bfloat16).cuda() + C = torch.matmul(A, B) + + C_opt = optimized_mxfp8_e4m3_matmul(A, B) + mse_opt = torch.mean((C - C_opt) ** 2) + max_err_opt = torch.max(torch.abs(C - C_opt)) + + print(f"E4M3 OPT MSE: {mse_opt:.6f}") + print(f"E4M3 OPT Max Error: {max_err_opt:.6f}") + print(f"E4M3 OPT 相对误差: {mse_opt / torch.mean(C ** 2):.6f}") + + A = torch.rand((M, N), dtype=torch.bfloat16).cuda() * 1e-15 # 模拟梯度数值范围 + B = torch.rand((N, K), dtype=torch.bfloat16).cuda() + + C = torch.matmul(A, B) + + C_e5m2 = optimized_mxfp8_e5m2_matmul(A, B) + mse_e5m2 = torch.mean((C - C_e5m2) ** 2) + max_err_e5m2 = torch.max(torch.abs(C - C_e5m2)) + print(f"E5M2 OPT MSE: {mse_e5m2:.20f}") + print(f"E5M2 OPT Max Error: {max_err_e5m2:.20f}") + print(f"E5M2 OPT 相对误差: {mse_e5m2 / torch.mean(C ** 2):.20f}") diff --git a/quant/plot_loss_curve.py b/quant/plot_loss_curve.py new file mode 100644 index 0000000000..d02333acf4 --- /dev/null +++ b/quant/plot_loss_curve.py @@ -0,0 +1,79 @@ +import matplotlib.pyplot as plt +import re +import numpy as np +from tqdm import tqdm +import os +def parse_log_file(log_file_path): + """ + 解析日志文件,提取 iteration 和 lm loss + """ + iterations = [] + losses = [] + + with open(log_file_path, 'r') as file: + for line in tqdm(file): + # 使用正则表达式匹配所需的值 + + match = re.search(r'\[.*?\]\s+iteration\s+(\d+)/\s*\d+\s*\|.*?lm loss:\s*([\d.E+-]+)', line) + if match: + # import pdb;pdb.set_trace() + iteration = int(match.group(1)) + loss = float(match.group(2)) + if iteration > 300: + continue + iterations.append(iteration) + losses.append(loss) + # print(f"Extracted: iteration={iteration}, loss={loss}") + else: + pass + # print(f"Failed to match line: {line.strip()}") + + return iterations, losses + +def plot_loss_curve(iterations, losses, labels, output_file='loss_curve.png'): + plt.figure(figsize=(10, 6)) + for i in range(len(iterations)): + # import pdb ;pdb.set_trace() + plt.plot(iterations[i], losses[i], label=labels[i], linewidth=0.1) + plt.xlabel('Iteration') + plt.ylabel('LM Loss') + plt.title('LM Loss vs Iteration') + plt.legend() + plt.grid(True) + plt.savefig(output_file) + plt.show() + +def avg_rel_error(iterations,losses,max_len=20000): + min_len = min(len(iteration) for iteration in iterations) + idx = max(0,min_len-max_len) + bf16 = np.array(losses[0][idx:min_len],dtype=float) + res = [] + for i in range(1,len(iterations)): + quant = np.array(losses[i][idx:min_len],dtype=float) + rel = np.abs(bf16 - quant) / bf16 + res.append(float(rel.sum() / (min_len-idx))) + return res + + +if __name__ == "__main__": + LOG_PATH='/mtc_afs/charles/Megatron-LM/tensorboard_logs/llama3_8b_fp8' + log_files_llama3_8b_pretrain=['training_wikipedia_bf16_25-08-05_03-14-01.log','training_wikipedia_25-08-03_22-08-24.log','training_wikipedia_fp8_25-08-03_22-07-19.log','training_wikipedia_fp4_25-08-03_22-05-47.log'] + quant_labels=['bf16','te_fp8','fp8','fp4'] + output_image_path = '/mtc_afs/charles/Megatron-LM/quant/curve/loss_curve_cmp_non_pretrain.png' + # select mode + log_files = log_files_llama3_8b_pretrain + labels = quant_labels + iterations,losses = [],[] + for log_file in log_files: + log_file_path = os.path.join(LOG_PATH,log_file) + iteration, loss = parse_log_file(log_file_path) + iterations.append(iteration) + losses.append(loss) + + plot_loss_curve(iterations, losses, labels, output_image_path) + loss_res = avg_rel_error(iterations,losses) + for i in range(1,len(iterations)): + print(f"{labels[i]}:{loss_res[i-1]}") + loss_res = avg_rel_error(iterations,losses,500) + for i in range(1,len(iterations)): + print(f"{labels[i]} in last 500:{loss_res[i-1]}") diff --git a/quant/profiling.py b/quant/profiling.py new file mode 100644 index 0000000000..f7043f9210 --- /dev/null +++ b/quant/profiling.py @@ -0,0 +1,434 @@ +import torch +import torch_npu +import torch.nn.functional as F +from enum import Enum, IntEnum +import torch._dynamo as dynamo + + +FP32_EXPONENT_BIAS = 127 +FP32_MIN_NORMAL = 2 ** (-FP32_EXPONENT_BIAS + 1) + +# 替换 pow(2, x) 为更高效的位运算 +# 原代码中的 2**shared_exp 可以优化为: +def fast_power_of_2(exp): + return torch.exp2(exp) # 或者使用位移操作 + + +def _safe_lshift(x, bits, exp): + if exp is None: + return x * (2**bits) + else: + return x / (2 ** exp) * (2**bits) + + +def _safe_rshift(x, bits, exp): + if exp is None: + return x / (2**bits) + else: + return x / (2**bits) * (2 ** exp) + + +def _round_mantissa(A, bits, round, clamp=False): + """ + Rounds mantissa to nearest bits depending on the rounding method 'round' + Args: + A {PyTorch tensor} -- Input tensor + round {str} -- Rounding method + "floor" rounds to the floor + "nearest" rounds to ceil or floor, whichever is nearest + Returns: + A {PyTorch tensor} -- Tensor with mantissas rounded + """ + + if round == "dither": + rand_A = torch.rand_like(A, requires_grad=False) + A = torch.sign(A) * torch.floor(torch.abs(A) + rand_A) + elif round == "floor": + A = torch.sign(A) * torch.floor(torch.abs(A)) + elif round == "nearest": + A = torch.sign(A) * torch.floor(torch.abs(A) + 0.5) + elif round == "even": + absA = torch.abs(A) + # find 0.5, 2.5, 4.5 ... + maskA = ((absA - 0.5) % 2 == torch.zeros_like(A)).type(A.dtype) + A = torch.sign(A) * (torch.floor(absA + 0.5) - maskA) + else: + raise Exception("Unrecognized round method %s" % (round)) + + # Clip values that cannot be expressed by the specified number of bits + if clamp: + max_mantissa = 2 ** (bits - 1) - 1 + A = torch.clamp(A, -max_mantissa, max_mantissa) + return A + + +def _quantize_elemwise_core(A, bits, exp_bits, max_norm, round='nearest', + saturate_normals=False, allow_denorm=True): + """ Core function used for element-wise quantization + Arguments: + A {PyTorch tensor} -- A tensor to be quantized + bits {int} -- Number of mantissa bits. Includes + sign bit and implicit one for floats + exp_bits {int} -- Number of exponent bits, 0 for ints + max_norm {float} -- Largest representable normal number + round {str} -- Rounding mode: (floor, nearest, even) + saturate_normals {bool} -- If True, normal numbers (i.e., not NaN/Inf) + that exceed max norm are clamped. + Must be True for correct MX conversion. + allow_denorm {bool} -- If False, flush denorm numbers in the + elem_format to zero. + Returns: + quantized tensor {PyTorch tensor} -- A tensor that has been quantized + """ + + out = A + + private_exp = torch.floor(torch.log2( + torch.abs(A) + (A == 0).type(A.dtype))) + + # The minimum representable exponent for 8 exp bits is -126 + min_exp = -(2**(exp_bits-1)) + 2 + private_exp = private_exp.clip(min=min_exp) + + # Scale up so appropriate number of bits are in the integer portion of the number + out = _safe_lshift(out, bits - 2, private_exp) + + out = _round_mantissa(out, bits, round, clamp=False) + + # Undo scaling + out = _safe_rshift(out, bits - 2, private_exp) + + # Set values > max_norm to Inf if desired, else clamp them + out = torch.clamp(out, min=-max_norm, max=max_norm) + + # handle Inf/NaN + out[A == float("Inf")] = float("Inf") + out[A == -float("Inf")] = -float("Inf") + out[A == float("NaN")] = float("NaN") + + return out + + +def _shared_exponents(A, method="max", axes=None, ebits=0): + """ + Get shared exponents for the passed matrix A. + Args: + A {PyTorch tensor} -- Input tensor + method {str} -- Exponent selection method. + "max" uses the max absolute value + "none" uses an exponent for each value (i.e., no sharing) + axes {list(int)} -- List of integers which specifies the axes across which + shared exponents are calculated. + Returns: + shared_exp {PyTorch tensor} -- Tensor of shared exponents + """ + + if method == "max": + if axes is None: + shared_exp = torch.max(torch.abs(A)) + else: + shared_exp = A + for axis in axes: + shared_exp, _ = torch.max(torch.abs(shared_exp), dim=axis, keepdim=True) + elif method == "none": + shared_exp = torch.abs(A) + else: + raise Exception("Unrecognized shared exponent selection method %s" % (method)) + + # log2(shared_exp) and truncate to integer + shared_exp = torch.floor( + torch.log2( + shared_exp + FP32_MIN_NORMAL * (shared_exp == 0).type(shared_exp.dtype) + ) + ) + + # Restrict to [-emax, emax] range + if ebits > 0: + emax = 2**(ebits-1) - 1 + #shared_exp = torch.clamp(shared_exp, -emax, emax) + # Overflow to Inf + shared_exp[shared_exp > emax] = float("NaN") + # Underflows are set to -127 which causes them to be + # flushed to 0 later + shared_exp[shared_exp < -emax] = -emax + + return shared_exp + + +def _reshape_to_blocks(A, axes, block_size): + if axes is None: + raise Exception( + "axes required in order to determine which " + "dimension toapply block size to" + ) + if block_size == 0: + raise Exception("block_size == 0 in _reshape_to_blocks") + + # Fix axes to be positive and sort them + axes = [(x + len(A.shape) if x < 0 else x) for x in axes] + assert all(x >= 0 for x in axes) + axes = sorted(axes) + + # Add extra dimension for tiles + for i in range(len(axes)): + axes[i] += i # Shift axes due to added dimensions + A = torch.unsqueeze(A, dim=axes[i] + 1) + + # Pad to block_size + orig_shape = A.size() + pad = [] + for i in range(len(orig_shape)): + pad += [0, 0] + + do_padding = False + for axis in axes: + pre_pad_size = orig_shape[axis] + if isinstance(pre_pad_size, torch.Tensor): + pre_pad_size = int(pre_pad_size.value) + # Don't pad if the axis is short enough to fit inside one tile + if pre_pad_size % block_size == 0: + pad[2 * axis] = 0 + else: + pad[2 * axis] = block_size - pre_pad_size % block_size + do_padding = True + + if do_padding: + pad = list(reversed(pad)) + A = torch.nn.functional.pad(A, pad, mode="constant") + + def _reshape(shape, reshape_block_size): + for axis in axes: + # Reshape to tiles if axis length > reshape_block_size + if shape[axis] >= reshape_block_size: + assert shape[axis] % reshape_block_size == 0 + shape[axis + 1] = reshape_block_size + shape[axis] = shape[axis] // reshape_block_size + # Otherwise preserve length and insert a 1 into the shape + else: + shape[axis + 1] = shape[axis] + shape[axis] = 1 + return shape + + # Reshape to tiles + padded_shape = A.size() + reshape = _reshape(list(padded_shape), block_size) + + A = A.view(reshape) + return A, axes, orig_shape, padded_shape + + +def _undo_reshape_to_blocks(A, padded_shape, orig_shape, axes): + # Undo tile reshaping + A = A.view(padded_shape) + # Undo padding + if not list(padded_shape) == list(orig_shape): + slices = [slice(0, x) for x in orig_shape] + A = A[slices] + for axis in reversed(axes): + # Remove extra dimension + A = torch.squeeze(A, dim=axis + 1) + return A + + +def _quantize_mx( + A, + scale_bits, + elem_format, # can be None for no quantization + shared_exp_method="max", + axes=None, + block_size=0, + round="nearest", + flush_fp32_subnorms=False, +): + """Function used for MX* quantization + """ + # Shortcut for no quantization + if elem_format == None: + return A + + assert(scale_bits > 0) + + # Make sure axes is a list of non-negative numbers + axes = [axes] if type(axes) == int else axes + axes = [x + A.ndim if x < 0 else x for x in axes] + + # ebits, mbits, emax, max_norm, _ = _get_format_params(elem_format) + ebits, mbits, emax, max_norm = 4, 5, 8, 448.0 + + # Perform tiling to the hardware vector size + if block_size > 0: + A, axes, orig_shape, padded_shape = _reshape_to_blocks( + A, axes, block_size + ) + + #################### + # Quantize + #################### + shared_exp_axes = [x + 1 for x in axes] if block_size > 0 else axes + + # Get shared exponents + shared_exp = _shared_exponents( + A, method=shared_exp_method, axes=shared_exp_axes, ebits=0, + ) + + # Flush subnormal FP32 inputs to zero + if flush_fp32_subnorms: + A = A * (shared_exp > -FP32_EXPONENT_BIAS).type(A.dtype) + + # Offset the max exponent by the largest representable exponent + # in the element data format + shared_exp = shared_exp - emax + + torch.npu.synchronize() + shape = shared_exp.shape + shared_exp = shared_exp.view(-1) + scale_emax = 127 + shared_exp[shared_exp > scale_emax] = float("NaN") + shared_exp[shared_exp < -scale_emax] = -scale_emax + torch.npu.synchronize() + shared_exp = shared_exp.view(shape) + + A = A / torch.exp2(shared_exp) # 替代 A / (2**shared_exp) + + A = _quantize_elemwise_core( + A, mbits, ebits, max_norm, round=round, + allow_denorm=True, saturate_normals=True) + + A = A * torch.exp2(shared_exp) # 替代 A * (2**shared_exp) + + # Undo tile reshaping + if block_size: + A = _undo_reshape_to_blocks(A, padded_shape, orig_shape, axes) + + return A + + +# 编译关键函数 +# 预热编译,避免运行时编译延迟 +def warmup_compilation(): + print("Warming up compilation...") + with torch.no_grad(): + # 创建小规模测试数据来预热编译 + dummy_A = torch.randn(64, 64, device='npu') + dummy_B = torch.randn(64, 64, device='npu') + + # 预热 _quantize_mx 编译 + _ = _quantize_mx( + dummy_A, + 8, + 'fp8_e4m3', + shared_exp_method="max", + axes=1, + block_size=16, + round="nearest", + flush_fp32_subnorms=False, + ) + + # 预热 _quantize_elemwise_core 编译 + _ = _quantize_elemwise_core( + dummy_A, + 5, 4, 448.0, + round='nearest', + saturate_normals=False, + allow_denorm=True + ) + + torch.npu.synchronize() + print("Compilation warmup completed") + +# 执行预热 +warmup_compilation() + +# 使用编译后的函数(如果需要的话) +# 注意:如果编译开销太大,可以考虑直接使用原函数 +# _quantize_mx_compiled = torch.compile(_quantize_mx) +# _quantize_elemwise_core_compiled = torch.compile(_quantize_elemwise_core) + + +# Load data +A = torch.load("grad_output.pt", map_location='cpu').npu() +print(f"A_shape:{A.shape},grad_max:{torch.max(A)},grad_min:{torch.min(A)}") +B = torch.load("total_input.pt", map_location='cpu').npu() +print(f"B_shape:{B.shape},input_max:{torch.max(B)},input_min:{torch.min(B)}") + +C = torch.matmul(A.t(), B) +print(f"C_shape:{C.shape},output_max:{torch.max(C)},output_min:{torch.min(C)}") + +scale_bits = 8 +elem_format = 'fp8_e4m3' + +# 预热GPU +def warmup_gpu(): + with torch.no_grad(): + dummy = torch.randn(100, 100, device='npu') + _ = torch.matmul(dummy, dummy) + torch.npu.synchronize() + +warmup_gpu() + +# Use PyTorch profiler for performance analysis +import os +trace_dir = "./npu_trace" +os.makedirs(trace_dir, exist_ok=True) +with torch_npu.profiler.profile( + activities=[torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU], + schedule=torch_npu.profiler.schedule(wait=2, warmup=1, active=3, repeat=1), + record_shapes=True, + with_stack=True, + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(trace_dir) +) as prof: + # 使用torch.no_grad()减少内存分配 + with torch.no_grad(): + # 批量处理量化操作以减少函数调用开销 + A_T = _quantize_mx( + A.t(), + scale_bits, + elem_format, + shared_exp_method="max", + axes=1, + block_size=16, + round="nearest", + flush_fp32_subnorms=False, + ) + + B = _quantize_mx( + B, + scale_bits, + elem_format, + shared_exp_method="max", + axes=0, + block_size=16, + round="nearest", + flush_fp32_subnorms=False, + ) + + # 使用torch.matmul的优化版本 + C_e4m3 = torch.matmul(A_T, B) + + # 确保计算完成 + torch.npu.synchronize() +# Print profiling results +# print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) +import subprocess + +msprof = "/usr/local/Ascend/ascend-toolkit/latest/tools/profiler/bin/msprof" +if os.path.exists(msprof): + cmd = [ + msprof, + "--application=./run_profiling.sh", + f"--output={trace_dir}", + ] + # subprocess.run(cmd, check=True) + # print(f"\n解析完成,请查看 {trace_dir}/summary.csv") +else: + print("\n未找到 msprof.py,请确认 CANN 版本 >= 6.0 或手动使用 MindStudio Insight 打开 json") + + +print(f"C_shape:{C_e4m3.shape},output_max:{torch.max(C_e4m3)},output_min:{torch.min(C_e4m3)}") +print(torch.isnan(C).any()) + +mse_e4m3 = torch.mean((C - C_e4m3) ** 2) +max_err_e4m3 = torch.max(torch.abs(C - C_e4m3)) +print(f"MSE: {mse_e4m3:.20f}") +print(f"Max Error: {max_err_e4m3:.20f}") +print(f"相对误差: {mse_e4m3 / torch.mean(C ** 2):.20f}") diff --git a/quant/qtype.py b/quant/qtype.py new file mode 100644 index 0000000000..23b1b725eb --- /dev/null +++ b/quant/qtype.py @@ -0,0 +1,139 @@ +import re +from copy import deepcopy + +class QType: + # declare datatype and default values + desc: str + exp_bits: int = -1 + man_bits: int = -1 + k_bits: int = -1 + k_outer_bits: int = 0 + blk_size: int = 1 + blk_outer_size: int = 1 + exp_max: int = -1 + exp_min: int = -1 + k_max: int = -1 + fp_val_max: float = -1 + q_dim: int = -1 + man_shift_bit: int = -1 + exp_offset: int = 0 + do_carry: bool = True + + def __init__(self, desc: str): + self.desc = desc + + if desc in ['fp16', 'fp32', 'bf16', 'int8sym', 'hif8']: + pass # 保留默认配置 + elif 'nf4' in desc.lower(): + res = re.match(r'^nf4B([0-9]+)b([0-9]+)$', desc) + if res is None: + raise ValueError("Quant type string must be [nf4B*b*]") + self.blk_outer_size = int(res.group(1)) + self.blk_size = int(res.group(2)) + elif desc.lower()[:3]=='tmx': + res = re.match(r"^tmx([0-9]+)$", desc) + self.blk_size = 16 + self.man_bits = int(res.group(1)) - 2 + elif desc.lower() == 'nvf4': + self.blk_size = 256 # padding to 256 to align kernel batch size + else: + if desc.lower() == 'mxfp4': + desc = 'e2m1k8b32c' + elif desc.lower() == 'mxfp6e3m2': + desc = 'e3m2k8b32c' + elif desc.lower() == 'mxfp8e4m3': + desc = 'e4m3k8b32c' + elif desc.lower() == 'mxfp8e5m2': + desc = 'e5m2k8b32c' + elif desc.lower()[:4] == 'hifx' and desc.lower()[-3:] == 'v12': + res = re.match(r"^hifx([2345])_v12$", desc.lower()) + if res is None: + raise ValueError("HiFx only supports hifx[2-5]_v12") + n_bit_tmp = int(res.group(1)) - 1 + desc = f"e0m{n_bit_tmp}k1k4B1b{n_bit_tmp}Coff38" + + res = re.match(r"^e([0-9]+)m([0-9]+)k([0-9]+)b([0-9]+)([Cc]?)$", desc) + res2 = re.match(r"^e([0-9]+)m([0-9]+)K([0-9]+)k([0-9]+)B([0-9]+)b([0-9]+)([Cc]?)(off[0-9]+)?$", desc) + + if res is not None: + self.exp_bits = int(res.group(1)) + self.man_bits = int(res.group(2)) + self.k_bits = int(res.group(3)) + self.blk_size = int(res.group(4)) + offset_number = None + if res.group(5) is None: + self.do_carry = False + else: + self.do_carry = str(res.group(5)).upper() == 'C' + elif res2 is not None: + self.exp_bits = int(res2.group(1)) + self.man_bits = int(res2.group(2)) + self.k_outer_bits = int(res2.group(3)) + self.k_bits = int(res2.group(4)) + self.blk_outer_size = int(res2.group(5)) + self.blk_size = int(res2.group(6)) + if res2.group(7) is None: + self.do_carry = False + else: + self.do_carry = str(res2.group(7)).upper() == 'C' + if res2.group(8) is not None: + offset_number = int(res2.group(8)[3:]) + else: + offset_number = None + else: + raise ValueError( + "Quant type string must be like 'ek' or 'ekkb', or special float types [fp16, fp32, bf16, int8sym]" + ) + + assert self.exp_bits !=1, "exp_bits==1 is not supported. E0M(x) is equivalent to E1M(x-1)" + assert self.man_bits >= 1, "man_bits should >= 1" + + if self.exp_bits == 0: + self.exp_max = self.man_bits - 1 + self.exp_min = 0 + else: + self.exp_max = 2 ** (self.exp_bits - 1) + if self.exp_bits == 5 and self.man_bits == 2: + self.exp_max-=1 + self.exp_min = -2 ** (self.exp_bits - 1) + 2 + + + self.k_max = 2 ** (self.k_bits + self.k_outer_bits - 1) - 1 + + if offset_number is None: + self.exp_offset = self.exp_max + else: + self.exp_offset = offset_number - self.k_max - 1 + self.exp_max + + # 计算 shift bits + self.man_shift_bit = self.man_bits + + # 计算 fp_val_max + if self.exp_bits == 4 and self.man_bits == 3: + self.fp_val_max = 448 # 特殊处理 e4m3 + else: + self.fp_val_max = 2 ** self.exp_max * float(2 ** (self.man_bits + 1) - 1) / (2 ** self.man_bits) + self.fp_val_max = min(self.fp_val_max, 1e38) # 近似 float max + + def dim_(self, dim: int): + """in-place function""" + self.q_dim = dim + return self + + def dim(self, dim: int): + """non-in-place function""" + out = deepcopy(self) + out.q_dim = dim + return out + + def copy(self): + return deepcopy(self) + + def __repr__(self) -> str: + return f'QType: {self.desc} Dim: {self.q_dim} ExpOffset: {self.exp_offset}' + +# 示例用法 +if __name__ == "__main__": + t = QType("e2m1k8b8") + t2 = deepcopy(t) + print(t) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000..462542f1c1 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +transformers +datasets diff --git a/run_tensor_collection.sh b/run_tensor_collection.sh new file mode 100755 index 0000000000..6cfd54f9c0 --- /dev/null +++ b/run_tensor_collection.sh @@ -0,0 +1,420 @@ +#!/bin/bash + +# ============================================================================= +# 统一Tensor收集脚本 +# 简化并整合了所有tensor收集功能,支持多种使用模式 +# ============================================================================= + +# 脚本元数据 +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="2.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') + +echo "==================================================================================" +echo "统一Tensor收集脚本" +echo "Script: $SCRIPT_NAME" +echo "Version: $SCRIPT_VERSION" +echo "Start Time: $START_TIME" +echo "==================================================================================" + +# 默认参数 +MODE="single" +QUANT_TYPE="mxfp8" +BASE_TENSOR_PATH="./enhanced_tensor_logs" +TOKENIZER_PATH="model/llama3.2-1b" +DATA_PATH="dataset/wikipedia_processed/wikipedia_processed_text_document" +DTYPE="bf16" +CONTROL_ITER=1 # 控制收集的micro_batch数量 +# collect_micro_batches已固定为1,进行一次完整forward后跳出 + +# 显示使用帮助 +show_help() { + echo "用法: $0 [OPTIONS] [MODE] [QUANT_TYPE]" + echo "" + echo "选项:" + echo " -h, --help 显示此帮助信息" + echo " --mode MODE 收集模式 (single|batch|quick) [默认: single]" + echo " --quant-type TYPE 量化类型 (bf16|mxfp8|mxfp4|hifp8) [默认: mxfp8]" + echo " --tensor-path PATH Tensor保存路径 [默认: ./enhanced_tensor_logs]" + echo " --tokenizer-path PATH 分词器路径 [默认: model/llama3.2-1b]" + echo " --data-path PATH 数据路径 [默认: dataset/wikipedia_processed/wikipedia_processed_text_document]" + echo " --dtype TYPE 数据类型 [默认: bf16]" + echo " --control-iter NUM 控制收集的micro_batch数量 [默认: 1]" + echo " (collect_micro_batches已固定为1,进行一次完整forward后跳出)" + echo "" + echo "位置参数:" + echo " MODE 收集模式 (single|batch|quick)" + echo " QUANT_TYPE 量化类型 (bf16|mxfp8|mxfp4|hifp8)" + echo "" + echo "使用示例:" + echo " # 基本用法" + echo " $0 single mxfp8" + echo "" + echo " # 使用命令行参数" + echo " $0 --mode single --quant-type mxfp8 --control-iter 3" + echo "" + echo " # 批量收集所有类型" + echo " $0 batch" + echo "" + echo " # 快速收集(收集少量数据用于测试)" + echo " $0 quick hifp8" + echo "" + echo " # 自定义路径和iteration数量" + echo " $0 --mode single --quant-type mxfp4 --tensor-path ./my_tensors --control-iter 3" + echo "" + echo " # 收集多个micro_batch的数据" + echo " $0 --mode single --quant-type mxfp8 --control-iter 5" + echo "" + echo " # 收集tensor(固定为1个micro_batch)" + echo " $0 --mode single --quant-type mxfp8 --control-iter 3" +} + +# 解析命令行参数 +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --mode) + MODE="$2" + shift 2 + ;; + --quant-type) + QUANT_TYPE="$2" + shift 2 + ;; + --tensor-path) + BASE_TENSOR_PATH="$2" + shift 2 + ;; + --tokenizer-path) + TOKENIZER_PATH="$2" + shift 2 + ;; + --data-path) + DATA_PATH="$2" + shift 2 + ;; + --dtype) + DTYPE="$2" + shift 2 + ;; + --control-iter|--control-micro-batches) + CONTROL_ITER="$2" + shift 2 + ;; + # collect_micro_batches参数已移除,固定为1 + single|batch|quick) + MODE="$1" + shift + ;; + bf16|mxfp8|mxfp4|hifp8) + QUANT_TYPE="$1" + shift + ;; + *) + echo "未知参数: $1" + echo "使用 --help 查看帮助信息" + exit 1 + ;; + esac +done + +# 支持的量化类型 +VALID_QUANT_TYPES=("bf16" "mxfp8" "mxfp4" "hifp8") + +# 验证模式 +if [[ ! "$MODE" =~ ^(single|batch|quick)$ ]]; then + echo "错误: 不支持的模式 '$MODE'" + echo "支持的模式: single, batch, quick" + show_help + exit 1 +fi + +# 验证量化类型 +if [[ ! " ${VALID_QUANT_TYPES[@]} " =~ " ${QUANT_TYPE} " ]]; then + echo "错误: 不支持的量化类型 '$QUANT_TYPE'" + echo "支持的量化类型: ${VALID_QUANT_TYPES[*]}" + exit 1 +fi + +# 验证control_iter参数 +if ! [[ "$CONTROL_ITER" =~ ^[0-9]+$ ]] || [ "$CONTROL_ITER" -lt 0 ]; then + echo "错误: control_iter 必须是大于等于0的整数" + echo "当前值: $CONTROL_ITER" + exit 1 +fi + +# collect_micro_batches已固定为1,无需验证 + +echo "配置信息:" +echo " - 模式: $MODE" +echo " - 量化类型: $QUANT_TYPE" +echo " - Tensor保存路径: $BASE_TENSOR_PATH" +echo " - 控制micro_batch数量: $CONTROL_ITER" +echo " - 收集micro_batch数量: 1 (固定)" +echo " - 分词器路径: $TOKENIZER_PATH" +echo " - 数据路径: $DATA_PATH" +echo " - 数据类型: $DTYPE" + +# 检查必要文件 +check_requirements() { + echo "" + echo "检查必要文件..." + + if [ ! -f "examples/llama/train_llama32_1b_h100_fp8.sh" ]; then + echo "错误: 训练脚本不存在: examples/llama/train_llama32_1b_h100_fp8.sh" + exit 1 + fi + + echo "✅ 必要文件检查完成" +} + +# 修改量化类型 +modify_quant_type() { + local quant_type=$1 + echo "" + echo "修改量化类型为: $quant_type" + + # 修改linear层量化类型 + if [ -f "megatron/core/tensor_parallel/layers.py" ]; then + sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'$quant_type'/" \ + megatron/core/tensor_parallel/layers.py + echo " ✅ 已修改 linear 层量化类型" + fi + + # 修改attention层量化类型 + if [ -f "megatron/core/transformer/dot_product_attention.py" ]; then + sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'$quant_type'/" \ + megatron/core/transformer/dot_product_attention.py + echo " ✅ 已修改 attention 层量化类型" + fi +} + +# 收集单个量化类型的tensor +collect_single_quant_type() { + local quant_type=$1 + local tensor_path="$BASE_TENSOR_PATH/${quant_type}" + local max_wait=300 # 最大等待时间(秒) + local control_iter=$CONTROL_ITER # 控制收集的micro_batch数量 + # collect_micro_batches已固定为1 + + echo "" + echo "==================================================================================" + echo "收集 $quant_type 量化类型的tensor" + echo "==================================================================================" + +# 设置环境变量 +export TENSOR_SAVE_ENABLED="true" +export TENSOR_SAVE_DIR="$tensor_path" +export HOST_TENSORBOARD_LOGS_PATH="tensorboard_logs/${quant_type}" +export CONTROL_ITER="$control_iter" +# collect_micro_batches已固定为1,无需设置环境变量 + + # 创建目录 + mkdir -p "$tensor_path" + mkdir -p "tensorboard_logs/${quant_type}" + mkdir -p "checkpoints/llama32_1b/${quant_type}" + + # 修改量化类型 + modify_quant_type "$quant_type" + + # 设置检查点路径 + local checkpoint_path="checkpoints/llama32_1b/${quant_type}" + + echo "开始训练并收集tensor..." + echo " - Tensor保存路径: $tensor_path" + echo " - 检查点路径: $checkpoint_path" + echo " - TensorBoard路径: $HOST_TENSORBOARD_LOGS_PATH" + echo " - 控制micro_batch数量: $control_iter" + echo "" + echo "环境变量设置:" + echo " - TENSOR_SAVE_ENABLED: $TENSOR_SAVE_ENABLED" + echo " - TENSOR_SAVE_DIR: $TENSOR_SAVE_DIR" + echo " - CONTROL_ITER: $CONTROL_ITER" + echo "" + echo "训练命令参数:" + echo " - --save-tensors: 启用" + echo " - --tensor-save-dir: $tensor_path" + echo " - --control-iter: $control_iter" + + # 运行训练脚本 + local log_file="training_${quant_type}_$(date +'%y-%m-%d_%H-%M-%S').log" + bash examples/llama/train_llama32_1b_h100_fp8.sh \ + "$checkpoint_path" \ + "$HOST_TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_PATH" \ + "$DATA_PATH" \ + "$DTYPE" \ + --control-iter "$control_iter" \ + --save-tensors \ + --tensor-save-dir "$tensor_path" \ + 2>&1 | tee "$log_file" + + # 最终统计 + local final_count=$(find "$tensor_path" -name "*.pt" 2>/dev/null | wc -l) + echo "" + echo "✅ $quant_type 量化类型tensor收集完成" + echo " - 最终收集到: $final_count 个tensor文件" + echo " - 实际保存路径: $tensor_path" + + # 检查父目录是否有文件 + local parent_path=$(dirname "$tensor_path") + local parent_count=$(find "$parent_path" -name "*.pt" 2>/dev/null | wc -l) + if [ $parent_count -gt 0 ]; then + echo " - 父目录文件数量: $parent_count (可能数据保存在父目录)" + echo " - 父目录路径: $parent_path" + fi + + # 显示统计信息 + if [ $final_count -gt 0 ]; then + echo " - 文件类型统计:" + local pre_count=$(find "$tensor_path" -name "*_pre_*" 2>/dev/null | wc -l) + local post_count=$(find "$tensor_path" -name "*_post_*" 2>/dev/null | wc -l) + local fa_count=$(find "$tensor_path" -name "*_FA_*" 2>/dev/null | wc -l) + local linear_count=$(find "$tensor_path" -name "*_linear_*" 2>/dev/null | wc -l) + local forward_count=$(find "$tensor_path" -name "*_forward_*" 2>/dev/null | wc -l) + local backward_count=$(find "$tensor_path" -name "*_backward_*" 2>/dev/null | wc -l) + + echo " * Pre阶段: $pre_count, Post阶段: $post_count" + echo " * FA组件: $fa_count, Linear组件: $linear_count" + echo " * Forward: $forward_count, Backward: $backward_count" + + echo " - 部分tensor文件:" + find "$tensor_path" -name "*.pt" | head -3 | while read file; do + echo " * $(basename "$file")" + done + fi + + return $final_count +} + +# 批量收集所有量化类型 +collect_batch() { + echo "" + echo "==================================================================================" + echo "批量收集所有量化类型的tensor" + echo "==================================================================================" + + local total_tensors=0 + local success_count=0 + + for quant_type in "${VALID_QUANT_TYPES[@]}"; do + collect_single_quant_type "$quant_type" + local result=$? + + if [ $result -gt 0 ]; then + success_count=$((success_count + 1)) + total_tensors=$((total_tensors + result)) + fi + + # 在每次运行之间稍作休息 + if [ "$quant_type" != "${VALID_QUANT_TYPES[-1]}" ]; then + echo "" + echo "等待5秒后继续下一个量化类型..." + sleep 5 + fi + done + + echo "" + echo "==================================================================================" + echo "批量收集完成" + echo "==================================================================================" + echo "成功收集: $success_count/${#VALID_QUANT_TYPES[@]} 个量化类型" + echo "总计tensor文件: $total_tensors 个" + echo "保存位置: $BASE_TENSOR_PATH" +} + +# 快速收集(用于测试) +collect_quick() { + echo "" + echo "==================================================================================" + echo "快速收集模式(用于测试)" + echo "==================================================================================" + + # 设置较短的等待时间 + local original_max_wait=$max_wait + max_wait=60 # 快速模式只等待60秒 + + collect_single_quant_type "$QUANT_TYPE" + local result=$? + + max_wait=$original_max_wait + + echo "" + echo "快速收集完成,收集到 $result 个tensor文件" +} + +# 显示结果总结 +show_summary() { + echo "" + echo "==================================================================================" + echo "Tensor收集结果总结" + echo "==================================================================================" + + local total_tensors=0 + local quant_types_found=() + + for quant_type in "${VALID_QUANT_TYPES[@]}"; do + local tensor_path="$BASE_TENSOR_PATH/${quant_type}" + if [ -d "$tensor_path" ]; then + local count=$(find "$tensor_path" -name "*.pt" 2>/dev/null | wc -l) + if [ $count -gt 0 ]; then + total_tensors=$((total_tensors + count)) + quant_types_found+=("$quant_type") + echo " - $quant_type: $count 个tensor文件" + fi + fi + done + + echo "" + echo "总计收集到: $total_tensors 个tensor文件" + echo "成功收集的量化类型: ${quant_types_found[*]}" + echo "保存位置: $BASE_TENSOR_PATH" + + if [ $total_tensors -gt 0 ]; then + echo "" + echo "下一步操作建议:" + echo "1. 查看收集到的tensor文件:" + echo " ls -la $BASE_TENSOR_PATH/*/" + echo "" + echo "2. 使用统一可视化脚本分析:" + echo " ./run_tensor_draw.sh $BASE_TENSOR_PATH" + echo "" + echo "3. 分析特定量化类型:" + for quant_type in "${quant_types_found[@]}"; do + echo " ls -la $BASE_TENSOR_PATH/$quant_type/" + done + fi +} + +# 主执行流程 +main() { + check_requirements + + case "$MODE" in + "single") + collect_single_quant_type "$QUANT_TYPE" + ;; + "batch") + collect_batch + ;; + "quick") + collect_quick + ;; + esac + + show_summary + + END_TIME=$(date '+%Y-%m-%d %H:%M:%S') + echo "" + echo "==================================================================================" + echo "Tensor收集完成" + echo "Start time: $START_TIME" + echo "End time: $END_TIME" + echo "==================================================================================" +} + +# 执行主函数 +main diff --git a/script/README.md b/script/README.md new file mode 100644 index 0000000000..6fb71109f5 --- /dev/null +++ b/script/README.md @@ -0,0 +1,206 @@ +# Script目录结构说明 + +这个目录包含了Megatron-LM项目的所有脚本文件,按功能分类组织。 + +## 📁 目录结构 + +``` +script/ +├── data_processing/ # 数据处理脚本 +│ ├── process_dolma_data.sh +│ ├── process_wikipedia_data.sh +│ ├── process_c4_data.sh +│ ├── process_custom_data.sh +│ ├── data_processing_utils.py +│ └── README.md +├── visualization/ # 可视化脚本 +│ ├── visualize_tensors.py +│ ├── quick_visualize.py +│ ├── one_click_visualize.sh +│ └── README.md +├── utils/ # 工具脚本 +│ ├── quant_type_modifier.py +│ ├── update_scripts_with_pattern_v2.py +│ └── README.md +├── templates/ # 模板文件 +│ ├── improved_script_template.sh +│ └── README.md +├── training/ # 训练脚本(按模型分类) +│ ├── llama32-1b/ +│ ├── llama31-8b/ +│ └── deepseek2_lite/ +└── README.md # 本文件 +``` + +## 🚀 快速开始 + +### 1. 数据处理 +```bash +# 处理Dolma数据集 +cd data_processing +./process_dolma_data.sh + +# 使用工具函数 +python data_processing_utils.py --action check +``` + +### 2. 模型训练 +```bash +# 使用模板创建训练脚本 +cp templates/improved_script_template.sh my_training.sh +chmod +x my_training.sh +./my_training.sh +``` + +### 3. 结果可视化 +```bash +# 一键可视化 +cd visualization +./one_click_visualize.sh +``` + +## 📋 功能概览 + +### 🔧 数据处理 (data_processing/) +- **数据集处理**: 支持Dolma、Wikipedia、C4等主流数据集 +- **格式转换**: 将原始数据转换为Megatron-LM训练格式 +- **批量处理**: 支持大规模数据集的并行处理 +- **工具函数**: 提供环境检查、时间估算等辅助功能 + +### 📊 可视化 (visualization/) +- **Tensor分析**: 可视化训练过程中的tensor数据 +- **量化研究**: 分析不同量化类型的影响 +- **统计图表**: 生成分布图、热力图、对比图等 +- **一键操作**: 自动生成所有分析图表 + +### 🛠️ 工具脚本 (utils/) +- **量化类型管理**: 批量修改脚本中的量化类型 +- **脚本模式更新**: 应用统一的脚本模式 +- **批量操作**: 支持大规模脚本文件的批量处理 + +### 📝 模板文件 (templates/) +- **训练脚本模板**: 功能完整的训练脚本模板 +- **标准化配置**: 统一的参数设置和错误处理 +- **易于定制**: 支持快速创建新的训练脚本 + +### 🏋️ 训练脚本 (training/) +- **模型分类**: 按模型类型组织训练脚本 +- **量化支持**: 支持多种量化类型 (hifp8, mxfp8, mxfp4, bf16等) +- **数据集适配**: 支持多种数据集的训练配置 + +## 🎯 使用场景 + +### 1. 量化研究 +```bash +# 1. 处理数据 +cd data_processing +./process_dolma_data.sh + +# 2. 运行训练(保存tensor) +cd ../training/llama32-1b +./pretrain_llama32-1b_dolma_hifp8.sh + +# 3. 可视化分析 +cd ../../visualization +./one_click_visualize.sh +``` + +### 2. 模型训练 +```bash +# 1. 使用模板创建脚本 +cp templates/improved_script_template.sh my_training.sh + +# 2. 修改参数 +vim my_training.sh + +# 3. 运行训练 +./my_training.sh +``` + +### 3. 批量操作 +```bash +# 1. 批量修改量化类型 +cd utils +python quant_type_modifier.py --directory ../training/ --old_quant_type bf16 --new_quant_type hifp8 + +# 2. 更新脚本模式 +python update_scripts_with_pattern_v2.py +``` + +## 📚 详细文档 + +每个子目录都包含详细的README文档: + +- **[数据处理文档](data_processing/README.md)** - 数据处理脚本的详细说明 +- **[可视化文档](visualization/README.md)** - 可视化工具的完整指南 +- **[工具文档](utils/README.md)** - 工具脚本的使用方法 +- **[模板文档](templates/README.md)** - 模板文件的定制指南 + +## 🔧 环境要求 + +### 基础环境 +- Python 3.8+ +- PyTorch 1.12+ +- CUDA 11.0+ + +### Python依赖 +```bash +pip install matplotlib seaborn pandas scipy +``` + +### 环境变量 +```bash +export CUSTOM_QUANT_TYPE="hifp8" +export TENSOR_SAVE_DIR="./enhanced_tensor_logs" +export TENSOR_SAVE_ENABLED="true" +``` + +## 🚨 注意事项 + +### 1. 文件权限 +```bash +# 设置脚本执行权限 +chmod +x *.sh +``` + +### 2. 路径配置 +- 确保所有路径配置正确 +- 检查数据集和模型文件是否存在 +- 验证输出目录的写入权限 + +### 3. 资源管理 +- 根据系统资源调整工作进程数 +- 监控磁盘空间使用情况 +- 注意内存使用峰值 + +### 4. 错误处理 +- 查看详细的错误日志 +- 使用dry_run模式预览操作 +- 定期备份重要文件 + +## 🤝 贡献指南 + +### 添加新脚本 +1. 选择合适的子目录 +2. 遵循现有的命名规范 +3. 添加详细的文档说明 +4. 测试脚本功能 + +### 修改现有脚本 +1. 创建备份文件 +2. 测试修改后的功能 +3. 更新相关文档 +4. 提交变更说明 + +## 📞 支持 + +如果遇到问题,请: +1. 查看相关子目录的README文档 +2. 检查错误日志和输出信息 +3. 验证环境配置和依赖 +4. 参考使用示例和最佳实践 + +--- + +**最后更新**: 2024年9月8日 +**版本**: 1.0.0 \ No newline at end of file diff --git a/script/STRUCTURE_SUMMARY.md b/script/STRUCTURE_SUMMARY.md new file mode 100644 index 0000000000..51605dafe0 --- /dev/null +++ b/script/STRUCTURE_SUMMARY.md @@ -0,0 +1,263 @@ +# Script目录结构整理总结 + +## 🎯 整理目标 + +将script目录重新组织,按功能分类,提高可维护性和易用性。 + +## ✅ 完成的工作 + +### 1. 目录结构重组 + +#### 📁 新的目录结构 +``` +script/ +├── data_processing/ # 数据处理脚本 +│ ├── process_dolma_data.sh +│ ├── process_wikipedia_data.sh +│ ├── process_c4_data.sh +│ ├── process_custom_data.sh +│ ├── data_processing_utils.py +│ └── README.md +├── visualization/ # 可视化脚本 +│ ├── visualize_tensors.py +│ ├── quick_visualize.py +│ ├── one_click_visualize.sh +│ └── README.md +├── utils/ # 工具脚本 +│ ├── quant_type_modifier.py +│ ├── update_scripts_with_pattern_v2.py +│ └── README.md +├── templates/ # 模板文件 +│ ├── improved_script_template.sh +│ └── README.md +├── training/ # 训练脚本(按模型分类) +│ ├── llama32-1b/ +│ ├── llama31-8b/ +│ └── deepseek2_lite/ +├── navigate.sh # 导航脚本 +├── README.md # 主说明文档 +└── STRUCTURE_SUMMARY.md # 本总结文档 +``` + +### 2. 文件分类和移动 + +#### 🔄 文件移动记录 +- **数据处理脚本** → `data_processing/` + - `process_dolma_data.sh` + - `process_wikipedia_data.sh` + - `process_c4_data.sh` + - `process_custom_data.sh` + - `data_processing_utils.py` + +- **可视化脚本** → `visualization/` + - `visualize_tensors.py` + - `quick_visualize.py` + - `one_click_visualize.sh` + +- **工具脚本** → `utils/` + - `quant_type_modifier.py` + - `update_scripts_with_pattern_v2.py` + +- **模板文件** → `templates/` + - `improved_script_template.sh` + +- **训练脚本** → 保持原有结构 + - `llama32-1b/` + - `llama31-8b/` + - `deepseek2_lite/` + +### 3. 文档完善 + +#### 📚 新增文档 +- **主README.md** - 整体说明和快速开始指南 +- **data_processing/README.md** - 数据处理脚本详细说明 +- **visualization/README.md** - 可视化工具完整指南 +- **utils/README.md** - 工具脚本使用方法 +- **templates/README.md** - 模板文件定制指南 +- **navigate.sh** - 交互式导航脚本 + +### 4. 权限设置 + +#### 🔐 执行权限 +```bash +# 设置所有.sh脚本的执行权限 +find script -name "*.sh" -exec chmod +x {} \; +``` + +## 🚀 新增功能 + +### 1. 数据处理脚本 + +#### 📊 支持的数据集 +- **Dolma数据集**: `process_dolma_data.sh` +- **Wikipedia数据集**: `process_wikipedia_data.sh` +- **C4数据集**: `process_c4_data.sh` +- **自定义数据集**: `process_custom_data.sh` + +#### 🛠️ 工具函数 +- **环境检查**: 验证必要目录和文件 +- **数据集列表**: 列出可用的数据集 +- **模型列表**: 列出可用的模型 +- **时间估算**: 估算处理时间 +- **参数优化**: 推荐最优参数 + +### 2. 可视化工具 + +#### 📈 可视化功能 +- **分布图**: tensor数值分布分析 +- **热力图**: tensor数据热力图 +- **对比图**: 不同量化类型对比 +- **统计图**: 统计信息汇总 +- **Attention分析**: 专门的attention分析 + +#### 🎯 使用方式 +- **一键可视化**: `one_click_visualize.sh` +- **快速可视化**: `quick_visualize.py` +- **完整可视化**: `visualize_tensors.py` + +### 3. 工具脚本 + +#### 🔧 批量操作 +- **量化类型修改**: 批量修改脚本中的量化类型 +- **脚本模式更新**: 应用统一的脚本模式 +- **正则表达式支持**: 灵活的文件匹配 + +### 4. 导航系统 + +#### 🧭 交互式导航 +- **功能模块导航**: 快速访问各个功能模块 +- **快速命令**: 常用命令的快速执行 +- **帮助系统**: 详细的使用说明 + +## 📋 使用指南 + +### 1. 快速开始 + +#### 🚀 使用导航脚本 +```bash +cd script +./navigate.sh +``` + +#### 📖 查看文档 +```bash +# 查看主文档 +cat README.md + +# 查看各模块文档 +cat data_processing/README.md +cat visualization/README.md +cat utils/README.md +cat templates/README.md +``` + +### 2. 数据处理流程 + +#### 📊 典型工作流 +```bash +# 1. 检查环境 +cd data_processing +python data_processing_utils.py --action check + +# 2. 处理数据 +./process_dolma_data.sh + +# 3. 验证结果 +ls -la ../dataset/dolma_processed* +``` + +### 3. 训练和可视化 + +#### 🏋️ 训练流程 +```bash +# 1. 使用模板创建脚本 +cp templates/improved_script_template.sh my_training.sh + +# 2. 运行训练 +./my_training.sh + +# 3. 可视化结果 +cd visualization +./one_click_visualize.sh +``` + +## 🎯 优势和改进 + +### 1. 结构优势 + +#### ✅ 清晰的分类 +- **按功能分类**: 每个目录都有明确的功能定位 +- **易于维护**: 相关文件集中管理 +- **快速定位**: 通过目录结构快速找到所需文件 + +#### ✅ 完善的文档 +- **详细说明**: 每个模块都有完整的README文档 +- **使用示例**: 提供具体的使用示例 +- **故障排除**: 包含常见问题的解决方案 + +### 2. 易用性改进 + +#### 🚀 便捷访问 +- **导航脚本**: 交互式导航系统 +- **快速命令**: 常用操作的快速执行 +- **一键操作**: 简化的操作流程 + +#### 🔧 工具支持 +- **批量操作**: 支持大规模文件的批量处理 +- **参数优化**: 自动推荐最优参数 +- **错误处理**: 完善的错误检查和提示 + +### 3. 扩展性 + +#### 📈 易于扩展 +- **模块化设计**: 新功能可以独立添加 +- **标准化接口**: 统一的参数和接口设计 +- **向后兼容**: 保持与现有脚本的兼容性 + +## 📊 统计信息 + +### 文件统计 +- **总文件数**: 约80个文件 +- **脚本文件**: 约60个.sh脚本 +- **Python脚本**: 约8个.py脚本 +- **文档文件**: 约12个.md文档 + +### 目录统计 +- **主要目录**: 5个功能目录 +- **训练脚本目录**: 3个模型目录 +- **文档目录**: 每个功能目录都有README + +## 🔮 未来规划 + +### 1. 功能扩展 +- **更多数据集支持**: 添加更多数据集的处理脚本 +- **高级可视化**: 增加更多可视化功能 +- **自动化工具**: 开发更多自动化工具 + +### 2. 用户体验 +- **Web界面**: 开发Web管理界面 +- **配置管理**: 统一的配置文件管理 +- **监控系统**: 训练过程监控工具 + +### 3. 性能优化 +- **并行处理**: 优化数据处理性能 +- **内存管理**: 改进内存使用效率 +- **缓存机制**: 添加结果缓存功能 + +## 🎉 总结 + +通过这次整理,script目录现在具有: + +1. **清晰的结构**: 按功能分类,易于导航 +2. **完善的文档**: 详细的使用说明和示例 +3. **便捷的工具**: 交互式导航和快速命令 +4. **强大的功能**: 数据处理、可视化、批量操作 +5. **良好的扩展性**: 易于添加新功能和模块 + +这个新的结构大大提高了script目录的可维护性和易用性,为用户提供了更好的使用体验。 + +--- + +**整理完成时间**: 2024年9月8日 +**整理人员**: AI Assistant +**版本**: 1.0.0 diff --git a/script/data_processing/README.md b/script/data_processing/README.md new file mode 100644 index 0000000000..db6ccb6e07 --- /dev/null +++ b/script/data_processing/README.md @@ -0,0 +1,142 @@ +# 数据处理脚本 + +这个目录包含了用于处理各种数据集的脚本和工具。 + +## 脚本文件 + +### 数据处理脚本 +- **`process_dolma_data.sh`** - 处理Dolma数据集 +- **`process_wikipedia_data.sh`** - 处理Wikipedia数据集 +- **`process_c4_data.sh`** - 处理C4数据集 +- **`process_custom_data.sh`** - 处理自定义数据集 + +### 工具脚本 +- **`data_processing_utils.py`** - 数据处理工具函数 + +## 使用方法 + +### 1. Dolma数据处理 +```bash +# 基本用法 +./process_dolma_data.sh + +# 自定义参数 +./process_dolma_data.sh \ + "./dataset/dolma/**/*.json.gz" \ + "./dataset/dolma_processed" \ + 32 \ + 8 \ + "./model/llama3/" \ + "HuggingFaceTokenizer" +``` + +### 2. Wikipedia数据处理 +```bash +# 基本用法 +./process_wikipedia_data.sh + +# 自定义参数 +./process_wikipedia_data.sh \ + "./dataset/wikipedia/**/*.json" \ + "./dataset/wikipedia_processed" \ + 16 \ + 4 \ + "./model/llama3/" +``` + +### 3. C4数据处理 +```bash +# 基本用法 +./process_c4_data.sh + +# 自定义参数 +./process_c4_data.sh \ + "./dataset/c4/**/*.json" \ + "./dataset/c4_processed" \ + 24 \ + 6 \ + "./model/llama3/" +``` + +### 4. 自定义数据处理 +```bash +# 基本用法 +./process_custom_data.sh + +# 自定义参数 +./process_custom_data.sh \ + "./dataset/custom/**/*.json" \ + "./dataset/custom_processed" \ + 16 \ + 4 \ + "./model/llama3/" \ + "HuggingFaceTokenizer" \ + "true" \ + 2048 \ + "false" \ + "text" +``` + +## 参数说明 + +### 必需参数 +1. **输入路径** - 数据文件的路径(支持通配符) +2. **输出前缀** - 处理后文件的输出前缀 +3. **工作进程数** - 并行处理的工作进程数 +4. **分区数** - 数据分区的数量 +5. **分词器模型** - 分词器模型路径 + +### 可选参数 +- **分词器类型** - 默认为"HuggingFaceTokenizer" +- **追加EOD** - 是否在序列末尾追加EOD token +- **序列长度** - 最大序列长度,默认为2048 +- **覆盖输出** - 是否覆盖已存在的输出文件 + +## 工具函数 + +### data_processing_utils.py + +提供以下功能: +- 环境检查 +- 数据集列表 +- 模型列表 +- 处理时间估算 +- 最优参数推荐 +- 数据验证 +- 脚本生成 + +#### 使用方法 +```bash +# 检查环境 +python data_processing_utils.py --action check + +# 列出可用数据集和模型 +python data_processing_utils.py --action list + +# 估算处理时间 +python data_processing_utils.py --action estimate --input "./dataset/dolma/**/*.json.gz" + +# 运行数据处理 +python data_processing_utils.py --action process \ + --input "./dataset/dolma/**/*.json.gz" \ + --output "./dataset/dolma_processed" \ + --tokenizer "./model/llama3/" \ + --workers 32 \ + --partitions 8 +``` + +## 注意事项 + +1. **数据格式**: 支持JSON、JSONL、TXT格式 +2. **内存使用**: 大数据集建议使用更多分区 +3. **磁盘空间**: 确保有足够的磁盘空间存储处理后的数据 +4. **分词器**: 确保分词器模型路径正确 +5. **权限**: 确保脚本有执行权限 + +## 输出文件 + +处理完成后会生成以下文件: +- `{output_prefix}.bin` - 二进制数据文件 +- `{output_prefix}.idx` - 索引文件 + +这些文件可以直接用于Megatron-LM训练。 diff --git a/script/data_processing/data_processing_utils.py b/script/data_processing/data_processing_utils.py new file mode 100644 index 0000000000..c4b26c6ffe --- /dev/null +++ b/script/data_processing/data_processing_utils.py @@ -0,0 +1,373 @@ +#!/usr/bin/env python3 +""" +数据处理工具函数 +提供常用的数据处理辅助功能 +""" + +import os +import json +import argparse +import subprocess +from pathlib import Path +from typing import Dict, List, Optional, Tuple +import time +import shutil + +class DataProcessor: + """数据处理器类""" + + def __init__(self, base_dir: str = "."): + """ + 初始化数据处理器 + + Args: + base_dir: 基础目录路径 + """ + self.base_dir = Path(base_dir) + self.dataset_dir = self.base_dir / "dataset" + self.model_dir = self.base_dir / "model" + self.tools_dir = self.base_dir / "tools" + + def check_environment(self) -> bool: + """检查环境是否满足要求""" + print("=== 检查环境 ===") + + # 检查必要目录 + required_dirs = [self.dataset_dir, self.model_dir, self.tools_dir] + for dir_path in required_dirs: + if not dir_path.exists(): + print(f"❌ 目录不存在: {dir_path}") + return False + print(f"✅ 目录存在: {dir_path}") + + # 检查preprocess_data.py + preprocess_script = self.tools_dir / "preprocess_data.py" + if not preprocess_script.exists(): + print(f"❌ 预处理脚本不存在: {preprocess_script}") + return False + print(f"✅ 预处理脚本存在: {preprocess_script}") + + return True + + def list_available_datasets(self) -> List[str]: + """列出可用的数据集""" + print("=== 可用数据集 ===") + + if not self.dataset_dir.exists(): + print("数据集目录不存在") + return [] + + datasets = [] + for item in self.dataset_dir.iterdir(): + if item.is_dir(): + # 检查是否包含数据文件 + data_files = list(item.glob("**/*.json*")) + list(item.glob("**/*.txt*")) + if data_files: + datasets.append(item.name) + print(f"✅ {item.name} ({len(data_files)} 个文件)") + else: + print(f"⚠️ {item.name} (无数据文件)") + + return datasets + + def list_available_models(self) -> List[str]: + """列出可用的模型""" + print("=== 可用模型 ===") + + if not self.model_dir.exists(): + print("模型目录不存在") + return [] + + models = [] + for item in self.model_dir.iterdir(): + if item.is_dir(): + # 检查是否包含tokenizer文件 + tokenizer_files = list(item.glob("tokenizer*")) + list(item.glob("*.json")) + if tokenizer_files: + models.append(item.name) + print(f"✅ {item.name}") + else: + print(f"⚠️ {item.name} (无tokenizer文件)") + + return models + + def estimate_processing_time(self, input_path: str, workers: int = 16) -> str: + """估算处理时间""" + input_path = Path(input_path) + + if not input_path.exists(): + return "无法估算:输入路径不存在" + + # 计算文件大小 + total_size = 0 + file_count = 0 + + if input_path.is_file(): + total_size = input_path.stat().st_size + file_count = 1 + else: + for file_path in input_path.rglob("*"): + if file_path.is_file(): + total_size += file_path.stat().st_size + file_count += 1 + + # 估算处理时间(基于经验值) + # 假设每GB数据需要约10-30分钟,取决于硬件配置 + size_gb = total_size / (1024**3) + estimated_minutes = size_gb * 20 # 20分钟/GB + + return f"估算处理时间: {estimated_minutes:.1f}分钟 (基于{size_gb:.2f}GB数据, {file_count}个文件)" + + def get_optimal_workers(self, input_path: str) -> int: + """获取最优的工作进程数""" + import multiprocessing + + # 获取CPU核心数 + cpu_count = multiprocessing.cpu_count() + + # 检查输入文件数量 + input_path = Path(input_path) + file_count = 0 + + if input_path.is_file(): + file_count = 1 + else: + file_count = len(list(input_path.rglob("*"))) + + # 计算最优进程数 + optimal_workers = min(cpu_count, max(1, file_count // 4)) + + return optimal_workers + + def validate_input_data(self, input_path: str) -> Tuple[bool, str]: + """验证输入数据""" + input_path = Path(input_path) + + if not input_path.exists(): + return False, "输入路径不存在" + + # 检查文件格式 + if input_path.is_file(): + if input_path.suffix in ['.json', '.jsonl', '.txt']: + return True, "文件格式正确" + else: + return False, f"不支持的文件格式: {input_path.suffix}" + else: + # 检查目录中的文件 + data_files = list(input_path.rglob("*.json*")) + list(input_path.rglob("*.txt*")) + if not data_files: + return False, "目录中没有找到数据文件" + + return True, f"找到 {len(data_files)} 个数据文件" + + def create_processing_script(self, + dataset_name: str, + input_path: str, + output_prefix: str, + tokenizer_model: str, + **kwargs) -> str: + """创建处理脚本""" + + script_content = f"""#!/bin/bash +# 自动生成的数据处理脚本 - {dataset_name} +# 生成时间: {time.strftime('%Y-%m-%d %H:%M:%S')} + +# 设置参数 +INPUT_PATH="{input_path}" +OUTPUT_PREFIX="{output_prefix}" +TOKENIZER_MODEL="{tokenizer_model}" +WORKERS={kwargs.get('workers', 16)} +PARTITIONS={kwargs.get('partitions', 4)} +TOKENIZER_TYPE="{kwargs.get('tokenizer_type', 'HuggingFaceTokenizer')}" +APPEND_EOD="{kwargs.get('append_eod', 'true')}" +SEQUENCE_LENGTH={kwargs.get('sequence_length', 2048)} +OVERWRITE="{kwargs.get('overwrite', 'false')}" + +echo "=== {dataset_name}数据处理 ===" +echo "输入路径: $INPUT_PATH" +echo "输出前缀: $OUTPUT_PREFIX" +echo "分词器模型: $TOKENIZER_MODEL" +echo "工作进程数: $WORKERS" +echo "分区数: $PARTITIONS" + +# 构建命令 +CMD="python tools/preprocess_data.py" +CMD="$CMD --input '$INPUT_PATH'" +CMD="$CMD --workers $WORKERS" +CMD="$CMD --partitions $PARTITIONS" +CMD="$CMD --output-prefix $OUTPUT_PREFIX" +CMD="$CMD --tokenizer-type $TOKENIZER_TYPE" +CMD="$CMD --tokenizer-model $TOKENIZER_MODEL" + +if [ "$APPEND_EOD" = "true" ]; then + CMD="$CMD --append-eod" +fi + +if [ "$SEQUENCE_LENGTH" != "2048" ]; then + CMD="$CMD --seq-length $SEQUENCE_LENGTH" +fi + +if [ "$OVERWRITE" = "true" ]; then + CMD="$CMD --overwrite" +fi + +echo "执行命令: $CMD" +echo "开始处理时间: $(date)" + +# 执行命令 +eval $CMD + +if [ $? -eq 0 ]; then + echo "✅ 处理完成: $(date)" + echo "输出文件:" + ls -lh "$OUTPUT_PREFIX"* +else + echo "❌ 处理失败" + exit 1 +fi +""" + + # 保存脚本 + script_path = self.base_dir / "script" / f"process_{dataset_name}_auto.sh" + with open(script_path, 'w', encoding='utf-8') as f: + f.write(script_content) + + # 设置执行权限 + os.chmod(script_path, 0o755) + + return str(script_path) + + def run_processing(self, + input_path: str, + output_prefix: str, + tokenizer_model: str, + **kwargs) -> bool: + """运行数据处理""" + + # 验证输入 + is_valid, message = self.validate_input_data(input_path) + if not is_valid: + print(f"❌ 输入验证失败: {message}") + return False + + print(f"✅ 输入验证通过: {message}") + + # 获取最优参数 + optimal_workers = self.get_optimal_workers(input_path) + if 'workers' not in kwargs: + kwargs['workers'] = optimal_workers + print(f"💡 使用最优工作进程数: {optimal_workers}") + + # 估算处理时间 + time_estimate = self.estimate_processing_time(input_path, kwargs['workers']) + print(f"⏱️ {time_estimate}") + + # 构建命令 + cmd = [ + "python", "tools/preprocess_data.py", + "--input", input_path, + "--workers", str(kwargs.get('workers', 16)), + "--partitions", str(kwargs.get('partitions', 4)), + "--output-prefix", output_prefix, + "--tokenizer-type", kwargs.get('tokenizer_type', 'HuggingFaceTokenizer'), + "--tokenizer-model", tokenizer_model + ] + + if kwargs.get('append_eod', True): + cmd.append("--append-eod") + + if kwargs.get('sequence_length', 2048) != 2048: + cmd.extend(["--seq-length", str(kwargs['sequence_length'])]) + + if kwargs.get('overwrite', False): + cmd.append("--overwrite") + + print(f"🚀 执行命令: {' '.join(cmd)}") + + # 记录开始时间 + start_time = time.time() + print(f"⏰ 开始时间: {time.strftime('%Y-%m-%d %H:%M:%S')}") + + try: + # 执行命令 + result = subprocess.run(cmd, check=True, capture_output=True, text=True) + + # 计算处理时间 + end_time = time.time() + duration = end_time - start_time + + print(f"✅ 处理完成!") + print(f"⏰ 结束时间: {time.strftime('%Y-%m-%d %H:%M:%S')}") + print(f"⏱️ 总耗时: {duration:.1f}秒 ({duration/60:.1f}分钟)") + + # 显示输出文件 + output_path = Path(output_prefix) + if output_path.with_suffix('.bin').exists(): + bin_size = output_path.with_suffix('.bin').stat().st_size / (1024**2) + idx_size = output_path.with_suffix('.idx').stat().st_size / (1024**2) + print(f"📁 输出文件大小: .bin={bin_size:.1f}MB, .idx={idx_size:.1f}MB") + + return True + + except subprocess.CalledProcessError as e: + print(f"❌ 处理失败: {e}") + print(f"错误输出: {e.stderr}") + return False + + +def main(): + """主函数""" + parser = argparse.ArgumentParser(description='数据处理工具') + parser.add_argument('--action', choices=['check', 'list', 'process', 'estimate'], + default='check', help='执行的操作') + parser.add_argument('--input', type=str, help='输入数据路径') + parser.add_argument('--output', type=str, help='输出前缀') + parser.add_argument('--tokenizer', type=str, help='分词器模型路径') + parser.add_argument('--workers', type=int, default=16, help='工作进程数') + parser.add_argument('--partitions', type=int, default=4, help='分区数') + parser.add_argument('--seq-length', type=int, default=2048, help='序列长度') + parser.add_argument('--no-eod', action='store_true', help='不追加EOD') + parser.add_argument('--overwrite', action='store_true', help='覆盖输出文件') + + args = parser.parse_args() + + # 创建处理器 + processor = DataProcessor() + + if args.action == 'check': + processor.check_environment() + + elif args.action == 'list': + processor.list_available_datasets() + processor.list_available_models() + + elif args.action == 'estimate': + if not args.input: + print("错误: 需要指定 --input 参数") + return + estimate = processor.estimate_processing_time(args.input, args.workers) + print(estimate) + + elif args.action == 'process': + if not all([args.input, args.output, args.tokenizer]): + print("错误: 需要指定 --input, --output, --tokenizer 参数") + return + + kwargs = { + 'workers': args.workers, + 'partitions': args.partitions, + 'sequence_length': args.seq_length, + 'append_eod': not args.no_eod, + 'overwrite': args.overwrite + } + + success = processor.run_processing( + args.input, args.output, args.tokenizer, **kwargs + ) + + if not success: + exit(1) + + +if __name__ == "__main__": + main() diff --git a/script/data_processing/process_c4_data.sh b/script/data_processing/process_c4_data.sh new file mode 100755 index 0000000000..6ea0204c73 --- /dev/null +++ b/script/data_processing/process_c4_data.sh @@ -0,0 +1,102 @@ +#!/bin/bash +""" +C4数据处理脚本 +用于处理C4 (Colossal Clean Crawled Corpus) 数据集 +""" + +# 设置默认参数 +INPUT_PATH=${1:-"./dataset/c4/**/*.json"} +OUTPUT_PREFIX=${2:-"./dataset/c4_processed"} +WORKERS=${3:-24} +PARTITIONS=${4:-6} +TOKENIZER_MODEL=${5:-"./model/llama3/"} +TOKENIZER_TYPE=${6:-"HuggingFaceTokenizer"} + +# 可选参数 +APPEND_EOD=${7:-"true"} +SEQUENCE_LENGTH=${8:-2048} +OVERWRITE=${9:-"false"} + +echo "=== C4数据处理脚本 ===" +echo "输入路径: $INPUT_PATH" +echo "输出前缀: $OUTPUT_PREFIX" +echo "工作进程数: $WORKERS" +echo "分区数: $PARTITIONS" +echo "分词器模型: $TOKENIZER_MODEL" +echo "分词器类型: $TOKENIZER_TYPE" +echo "追加EOD: $APPEND_EOD" +echo "序列长度: $SEQUENCE_LENGTH" +echo "覆盖输出: $OVERWRITE" + +# 检查输入路径 +if [ ! -d "$(dirname "$INPUT_PATH")" ]; then + echo "错误: 输入目录不存在: $(dirname "$INPUT_PATH")" + exit 1 +fi + +# 检查分词器模型 +if [ ! -d "$TOKENIZER_MODEL" ]; then + echo "错误: 分词器模型目录不存在: $TOKENIZER_MODEL" + exit 1 +fi + +# 创建输出目录 +OUTPUT_DIR=$(dirname "$OUTPUT_PREFIX") +mkdir -p "$OUTPUT_DIR" + +# 构建命令 +CMD="python tools/preprocess_data.py" +CMD="$CMD --input '$INPUT_PATH'" +CMD="$CMD --workers $WORKERS" +CMD="$CMD --partitions $PARTITIONS" +CMD="$CMD --output-prefix $OUTPUT_PREFIX" +CMD="$CMD --tokenizer-type $TOKENIZER_TYPE" +CMD="$CMD --tokenizer-model $TOKENIZER_MODEL" + +# 添加可选参数 +if [ "$APPEND_EOD" = "true" ]; then + CMD="$CMD --append-eod" +fi + +if [ "$SEQUENCE_LENGTH" != "2048" ]; then + CMD="$CMD --seq-length $SEQUENCE_LENGTH" +fi + +if [ "$OVERWRITE" = "true" ]; then + CMD="$CMD --overwrite" +fi + +echo "" +echo "执行命令:" +echo "$CMD" +echo "" + +# 记录开始时间 +START_TIME=$(date +%s) +echo "开始处理时间: $(date)" + +# 执行命令 +eval $CMD + +# 检查执行结果 +if [ $? -eq 0 ]; then + END_TIME=$(date +%s) + DURATION=$((END_TIME - START_TIME)) + echo "" + echo "✅ C4数据处理完成!" + echo "处理时间: ${DURATION}秒" + echo "完成时间: $(date)" + + # 显示输出文件信息 + echo "" + echo "输出文件:" + ls -lh "${OUTPUT_PREFIX}"* 2>/dev/null || echo "未找到输出文件" + +else + echo "" + echo "❌ C4数据处理失败!" + exit 1 +fi + +echo "" +echo "=== 处理完成 ===" diff --git a/script/data_processing/process_custom_data.sh b/script/data_processing/process_custom_data.sh new file mode 100755 index 0000000000..381083f9ce --- /dev/null +++ b/script/data_processing/process_custom_data.sh @@ -0,0 +1,133 @@ +#!/bin/bash +""" +自定义数据处理脚本 +用于处理任意格式的数据集,支持多种配置选项 +""" + +# 设置默认参数 +INPUT_PATH=${1:-"./dataset/custom/**/*.json"} +OUTPUT_PREFIX=${2:-"./dataset/custom_processed"} +WORKERS=${3:-16} +PARTITIONS=${4:-4} +TOKENIZER_MODEL=${5:-"./model/llama3/"} +TOKENIZER_TYPE=${6:-"HuggingFaceTokenizer"} + +# 可选参数 +APPEND_EOD=${7:-"true"} +SEQUENCE_LENGTH=${8:-2048} +OVERWRITE=${9:-"false"} +JSON_KEYS=${10:-"text"} +TOKENIZER_VOCAB_FILE=${11:-""} +TOKENIZER_MERGE_FILE=${12:-""} + +echo "=== 自定义数据处理脚本 ===" +echo "输入路径: $INPUT_PATH" +echo "输出前缀: $OUTPUT_PREFIX" +echo "工作进程数: $WORKERS" +echo "分区数: $PARTITIONS" +echo "分词器模型: $TOKENIZER_MODEL" +echo "分词器类型: $TOKENIZER_TYPE" +echo "追加EOD: $APPEND_EOD" +echo "序列长度: $SEQUENCE_LENGTH" +echo "覆盖输出: $OVERWRITE" +echo "JSON键: $JSON_KEYS" + +# 检查输入路径 +if [ ! -d "$(dirname "$INPUT_PATH")" ]; then + echo "错误: 输入目录不存在: $(dirname "$INPUT_PATH")" + exit 1 +fi + +# 检查分词器模型 +if [ ! -d "$TOKENIZER_MODEL" ] && [ -z "$TOKENIZER_VOCAB_FILE" ]; then + echo "错误: 分词器模型目录不存在: $TOKENIZER_MODEL" + echo "或者请提供vocab文件和merge文件" + exit 1 +fi + +# 创建输出目录 +OUTPUT_DIR=$(dirname "$OUTPUT_PREFIX") +mkdir -p "$OUTPUT_DIR" + +# 构建命令 +CMD="python tools/preprocess_data.py" +CMD="$CMD --input '$INPUT_PATH'" +CMD="$CMD --workers $WORKERS" +CMD="$CMD --partitions $PARTITIONS" +CMD="$CMD --output-prefix $OUTPUT_PREFIX" +CMD="$CMD --tokenizer-type $TOKENIZER_TYPE" + +# 添加分词器相关参数 +if [ -n "$TOKENIZER_VOCAB_FILE" ] && [ -n "$TOKENIZER_MERGE_FILE" ]; then + CMD="$CMD --tokenizer-vocab-file $TOKENIZER_VOCAB_FILE" + CMD="$CMD --tokenizer-merge-file $TOKENIZER_MERGE_FILE" +else + CMD="$CMD --tokenizer-model $TOKENIZER_MODEL" +fi + +# 添加JSON键参数 +if [ "$JSON_KEYS" != "text" ]; then + CMD="$CMD --json-keys $JSON_KEYS" +fi + +# 添加可选参数 +if [ "$APPEND_EOD" = "true" ]; then + CMD="$CMD --append-eod" +fi + +if [ "$SEQUENCE_LENGTH" != "2048" ]; then + CMD="$CMD --seq-length $SEQUENCE_LENGTH" +fi + +if [ "$OVERWRITE" = "true" ]; then + CMD="$CMD --overwrite" +fi + +echo "" +echo "执行命令:" +echo "$CMD" +echo "" + +# 记录开始时间 +START_TIME=$(date +%s) +echo "开始处理时间: $(date)" + +# 执行命令 +eval $CMD + +# 检查执行结果 +if [ $? -eq 0 ]; then + END_TIME=$(date +%s) + DURATION=$((END_TIME - START_TIME)) + echo "" + echo "✅ 自定义数据处理完成!" + echo "处理时间: ${DURATION}秒" + echo "完成时间: $(date)" + + # 显示输出文件信息 + echo "" + echo "输出文件:" + ls -lh "${OUTPUT_PREFIX}"* 2>/dev/null || echo "未找到输出文件" + + # 显示文件统计信息 + if [ -f "${OUTPUT_PREFIX}.bin" ]; then + echo "" + echo "文件统计:" + echo " .bin文件大小: $(du -h "${OUTPUT_PREFIX}.bin" | cut -f1)" + echo " .idx文件大小: $(du -h "${OUTPUT_PREFIX}.idx" | cut -f1)" + + # 尝试获取文档数量 + if [ -f "${OUTPUT_PREFIX}.idx" ]; then + DOC_COUNT=$(wc -l < "${OUTPUT_PREFIX}.idx") + echo " 文档数量: $DOC_COUNT" + fi + fi + +else + echo "" + echo "❌ 自定义数据处理失败!" + exit 1 +fi + +echo "" +echo "=== 处理完成 ===" diff --git a/script/data_processing/process_dolma_data.sh b/script/data_processing/process_dolma_data.sh new file mode 100755 index 0000000000..6e9737fce3 --- /dev/null +++ b/script/data_processing/process_dolma_data.sh @@ -0,0 +1,115 @@ +#!/bin/bash +""" +Dolma数据处理脚本 +用于处理dolma数据集,支持多种配置选项 +""" + +# 设置默认参数 +INPUT_PATH=${1:-"./dataset/dolma/**/*.json.gz"} +OUTPUT_PREFIX=${2:-"./dataset/dolma_processed"} +WORKERS=${3:-32} +PARTITIONS=${4:-8} +TOKENIZER_MODEL=${5:-"./model/llama3/"} +TOKENIZER_TYPE=${6:-"HuggingFaceTokenizer"} + +# 可选参数 +APPEND_EOD=${7:-"true"} +SEQUENCE_LENGTH=${8:-2048} +OVERWRITE=${9:-"false"} + +echo "=== Dolma数据处理脚本 ===" +echo "输入路径: $INPUT_PATH" +echo "输出前缀: $OUTPUT_PREFIX" +echo "工作进程数: $WORKERS" +echo "分区数: $PARTITIONS" +echo "分词器模型: $TOKENIZER_MODEL" +echo "分词器类型: $TOKENIZER_TYPE" +echo "追加EOD: $APPEND_EOD" +echo "序列长度: $SEQUENCE_LENGTH" +echo "覆盖输出: $OVERWRITE" + +# 检查输入路径是否存在 +if [ ! -d "$(dirname "$INPUT_PATH")" ]; then + echo "错误: 输入目录不存在: $(dirname "$INPUT_PATH")" + echo "请确保dolma数据集已下载到正确位置" + exit 1 +fi + +# 检查分词器模型是否存在 +if [ ! -d "$TOKENIZER_MODEL" ]; then + echo "错误: 分词器模型目录不存在: $TOKENIZER_MODEL" + echo "请确保分词器模型已下载到正确位置" + exit 1 +fi + +# 创建输出目录 +OUTPUT_DIR=$(dirname "$OUTPUT_PREFIX") +mkdir -p "$OUTPUT_DIR" + +# 构建命令 +CMD="python tools/preprocess_data.py" +CMD="$CMD --input '$INPUT_PATH'" +CMD="$CMD --workers $WORKERS" +CMD="$CMD --partitions $PARTITIONS" +CMD="$CMD --output-prefix $OUTPUT_PREFIX" +CMD="$CMD --tokenizer-type $TOKENIZER_TYPE" +CMD="$CMD --tokenizer-model $TOKENIZER_MODEL" + +# 添加可选参数 +if [ "$APPEND_EOD" = "true" ]; then + CMD="$CMD --append-eod" +fi + +if [ "$SEQUENCE_LENGTH" != "2048" ]; then + CMD="$CMD --seq-length $SEQUENCE_LENGTH" +fi + +if [ "$OVERWRITE" = "true" ]; then + CMD="$CMD --overwrite" +fi + +echo "" +echo "执行命令:" +echo "$CMD" +echo "" + +# 记录开始时间 +START_TIME=$(date +%s) +echo "开始处理时间: $(date)" + +# 执行命令 +eval $CMD + +# 检查执行结果 +if [ $? -eq 0 ]; then + END_TIME=$(date +%s) + DURATION=$((END_TIME - START_TIME)) + echo "" + echo "✅ 数据处理完成!" + echo "处理时间: ${DURATION}秒" + echo "完成时间: $(date)" + + # 显示输出文件信息 + echo "" + echo "输出文件:" + ls -lh "${OUTPUT_PREFIX}"* 2>/dev/null || echo "未找到输出文件" + + # 显示文件统计信息 + if [ -f "${OUTPUT_PREFIX}.bin" ]; then + echo "" + echo "文件统计:" + echo " .bin文件大小: $(du -h "${OUTPUT_PREFIX}.bin" | cut -f1)" + echo " .idx文件大小: $(du -h "${OUTPUT_PREFIX}.idx" | cut -f1)" + fi + +else + echo "" + echo "❌ 数据处理失败!" + echo "请检查错误信息并重试" + exit 1 +fi + +echo "" +echo "=== 处理完成 ===" +echo "输出文件前缀: $OUTPUT_PREFIX" +echo "可用于训练的数据文件已生成" diff --git a/script/data_processing/process_wikipedia_data.sh b/script/data_processing/process_wikipedia_data.sh new file mode 100755 index 0000000000..ae79db29ea --- /dev/null +++ b/script/data_processing/process_wikipedia_data.sh @@ -0,0 +1,102 @@ +#!/bin/bash +""" +Wikipedia数据处理脚本 +用于处理Wikipedia数据集 +""" + +# 设置默认参数 +INPUT_PATH=${1:-"./dataset/wikipedia/**/*.json"} +OUTPUT_PREFIX=${2:-"./dataset/wikipedia_processed"} +WORKERS=${3:-16} +PARTITIONS=${4:-4} +TOKENIZER_MODEL=${5:-"./model/llama3/"} +TOKENIZER_TYPE=${6:-"HuggingFaceTokenizer"} + +# 可选参数 +APPEND_EOD=${7:-"true"} +SEQUENCE_LENGTH=${8:-2048} +OVERWRITE=${8:-"false"} + +echo "=== Wikipedia数据处理脚本 ===" +echo "输入路径: $INPUT_PATH" +echo "输出前缀: $OUTPUT_PREFIX" +echo "工作进程数: $WORKERS" +echo "分区数: $PARTITIONS" +echo "分词器模型: $TOKENIZER_MODEL" +echo "分词器类型: $TOKENIZER_TYPE" +echo "追加EOD: $APPEND_EOD" +echo "序列长度: $SEQUENCE_LENGTH" +echo "覆盖输出: $OVERWRITE" + +# 检查输入路径 +if [ ! -d "$(dirname "$INPUT_PATH")" ]; then + echo "错误: 输入目录不存在: $(dirname "$INPUT_PATH")" + exit 1 +fi + +# 检查分词器模型 +if [ ! -d "$TOKENIZER_MODEL" ]; then + echo "错误: 分词器模型目录不存在: $TOKENIZER_MODEL" + exit 1 +fi + +# 创建输出目录 +OUTPUT_DIR=$(dirname "$OUTPUT_PREFIX") +mkdir -p "$OUTPUT_DIR" + +# 构建命令 +CMD="python tools/preprocess_data.py" +CMD="$CMD --input '$INPUT_PATH'" +CMD="$CMD --workers $WORKERS" +CMD="$CMD --partitions $PARTITIONS" +CMD="$CMD --output-prefix $OUTPUT_PREFIX" +CMD="$CMD --tokenizer-type $TOKENIZER_TYPE" +CMD="$CMD --tokenizer-model $TOKENIZER_MODEL" + +# 添加可选参数 +if [ "$APPEND_EOD" = "true" ]; then + CMD="$CMD --append-eod" +fi + +if [ "$SEQUENCE_LENGTH" != "2048" ]; then + CMD="$CMD --seq-length $SEQUENCE_LENGTH" +fi + +if [ "$OVERWRITE" = "true" ]; then + CMD="$CMD --overwrite" +fi + +echo "" +echo "执行命令:" +echo "$CMD" +echo "" + +# 记录开始时间 +START_TIME=$(date +%s) +echo "开始处理时间: $(date)" + +# 执行命令 +eval $CMD + +# 检查执行结果 +if [ $? -eq 0 ]; then + END_TIME=$(date +%s) + DURATION=$((END_TIME - START_TIME)) + echo "" + echo "✅ Wikipedia数据处理完成!" + echo "处理时间: ${DURATION}秒" + echo "完成时间: $(date)" + + # 显示输出文件信息 + echo "" + echo "输出文件:" + ls -lh "${OUTPUT_PREFIX}"* 2>/dev/null || echo "未找到输出文件" + +else + echo "" + echo "❌ Wikipedia数据处理失败!" + exit 1 +fi + +echo "" +echo "=== 处理完成 ===" diff --git a/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_FA_hifp8.sh b/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_FA_hifp8.sh new file mode 100755 index 0000000000..6d64ed5899 --- /dev/null +++ b/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_FA_hifp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for DEEPSEEK2_LITE - Updated with new pattern +# Script: pretrain_deepseek2_lite_dolma_FA_hifp8.sh +# Quantization Type: hifp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/deepseek2_lite/pretrain_deepseek2_lite_dolma_FA_hifp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/deepseek2_lite_hifp8"} +TOKENIZER_ARG=${3:-"model/deepseek2_lite"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to hifp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/deepseek2_lite/train_deepseek2_lite_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_deepseek2_lite_dolma_FA_hifp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_FA_linear_hifp8.sh b/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_FA_linear_hifp8.sh new file mode 100755 index 0000000000..01773c486f --- /dev/null +++ b/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_FA_linear_hifp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for DEEPSEEK2_LITE - Updated with new pattern +# Script: pretrain_deepseek2_lite_dolma_FA_linear_hifp8.sh +# Quantization Type: hifp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/deepseek2_lite/pretrain_deepseek2_lite_dolma_FA_linear_hifp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/deepseek2_lite_hifp8"} +TOKENIZER_ARG=${3:-"model/deepseek2_lite"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to hifp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/deepseek2_lite/train_deepseek2_lite_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_deepseek2_lite_dolma_FA_linear_hifp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_FA_linear_mxfp4.sh b/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_FA_linear_mxfp4.sh new file mode 100755 index 0000000000..89bbc67164 --- /dev/null +++ b/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_FA_linear_mxfp4.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for DEEPSEEK2_LITE - Updated with new pattern +# Script: pretrain_deepseek2_lite_dolma_FA_linear_mxfp4.sh +# Quantization Type: mxfp4 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/deepseek2_lite/pretrain_deepseek2_lite_dolma_FA_linear_mxfp4"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/deepseek2_lite_mxfp4"} +TOKENIZER_ARG=${3:-"model/deepseek2_lite"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp4..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/deepseek2_lite/train_deepseek2_lite_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_deepseek2_lite_dolma_FA_linear_mxfp4_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_FA_linear_mxfp8.sh b/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_FA_linear_mxfp8.sh new file mode 100755 index 0000000000..234c282426 --- /dev/null +++ b/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_FA_linear_mxfp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for DEEPSEEK2_LITE - Updated with new pattern +# Script: pretrain_deepseek2_lite_dolma_FA_linear_mxfp8.sh +# Quantization Type: mxfp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/deepseek2_lite/pretrain_deepseek2_lite_dolma_FA_linear_mxfp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/deepseek2_lite_mxfp8"} +TOKENIZER_ARG=${3:-"model/deepseek2_lite"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/deepseek2_lite/train_deepseek2_lite_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_deepseek2_lite_dolma_FA_linear_mxfp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_FA_mxfp4.sh b/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_FA_mxfp4.sh new file mode 100755 index 0000000000..3274b35f4a --- /dev/null +++ b/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_FA_mxfp4.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for DEEPSEEK2_LITE - Updated with new pattern +# Script: pretrain_deepseek2_lite_dolma_FA_mxfp4.sh +# Quantization Type: mxfp4 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/deepseek2_lite/pretrain_deepseek2_lite_dolma_FA_mxfp4"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/deepseek2_lite_mxfp4"} +TOKENIZER_ARG=${3:-"model/deepseek2_lite"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp4..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/deepseek2_lite/train_deepseek2_lite_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_deepseek2_lite_dolma_FA_mxfp4_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_FA_mxfp8.sh b/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_FA_mxfp8.sh new file mode 100755 index 0000000000..f400bcb070 --- /dev/null +++ b/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_FA_mxfp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for DEEPSEEK2_LITE - Updated with new pattern +# Script: pretrain_deepseek2_lite_dolma_FA_mxfp8.sh +# Quantization Type: mxfp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/deepseek2_lite/pretrain_deepseek2_lite_dolma_FA_mxfp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/deepseek2_lite_mxfp8"} +TOKENIZER_ARG=${3:-"model/deepseek2_lite"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/deepseek2_lite/train_deepseek2_lite_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_deepseek2_lite_dolma_FA_mxfp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_bf16.sh b/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_bf16.sh new file mode 100755 index 0000000000..b4b77ea861 --- /dev/null +++ b/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_bf16.sh @@ -0,0 +1,79 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for DEEPSEEK2_LITE - Updated with new pattern +# Script: pretrain_deepseek2_lite_dolma_bf16.sh +# Quantization Type: bf16 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/deepseek2_lite/pretrain_deepseek2_lite_dolma_bf16"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/deepseek2_lite_bf16"} +TOKENIZER_ARG=${3:-"model/deepseek2_lite"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] BF16 training - no quantization modification needed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/deepseek2_lite/train_deepseek2_lite_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_deepseek2_lite_dolma_bf16_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_linear_hifp8.sh b/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_linear_hifp8.sh new file mode 100755 index 0000000000..cf1fc60dfa --- /dev/null +++ b/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_linear_hifp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for DEEPSEEK2_LITE - Updated with new pattern +# Script: pretrain_deepseek2_lite_dolma_linear_hifp8.sh +# Quantization Type: hifp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/deepseek2_lite/pretrain_deepseek2_lite_dolma_linear_hifp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/deepseek2_lite_hifp8"} +TOKENIZER_ARG=${3:-"model/deepseek2_lite"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to hifp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/deepseek2_lite/train_deepseek2_lite_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_deepseek2_lite_dolma_linear_hifp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_linear_mxfp4.sh b/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_linear_mxfp4.sh new file mode 100755 index 0000000000..fd3a618492 --- /dev/null +++ b/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_linear_mxfp4.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for DEEPSEEK2_LITE - Updated with new pattern +# Script: pretrain_deepseek2_lite_dolma_linear_mxfp4.sh +# Quantization Type: mxfp4 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/deepseek2_lite/pretrain_deepseek2_lite_dolma_linear_mxfp4"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/deepseek2_lite_mxfp4"} +TOKENIZER_ARG=${3:-"model/deepseek2_lite"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp4..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/deepseek2_lite/train_deepseek2_lite_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_deepseek2_lite_dolma_linear_mxfp4_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_linear_mxfp8.sh b/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_linear_mxfp8.sh new file mode 100755 index 0000000000..156aba9cb2 --- /dev/null +++ b/script/deepseek2_lite/pretrain_deepseek2_lite_dolma_linear_mxfp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for DEEPSEEK2_LITE - Updated with new pattern +# Script: pretrain_deepseek2_lite_dolma_linear_mxfp8.sh +# Quantization Type: mxfp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/deepseek2_lite/pretrain_deepseek2_lite_dolma_linear_mxfp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/deepseek2_lite_mxfp8"} +TOKENIZER_ARG=${3:-"model/deepseek2_lite"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/deepseek2_lite/train_deepseek2_lite_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_deepseek2_lite_dolma_linear_mxfp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_FA_hifp8.sh b/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_FA_hifp8.sh new file mode 100755 index 0000000000..72c2f8ce95 --- /dev/null +++ b/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_FA_hifp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for DEEPSEEK2_LITE - Updated with new pattern +# Script: pretrain_deepseek2_lite_wikipedia_FA_hifp8.sh +# Quantization Type: hifp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_FA_hifp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/deepseek2_lite_hifp8"} +TOKENIZER_ARG=${3:-"model/deepseek2_lite"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to hifp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/deepseek2_lite/train_deepseek2_lite_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_deepseek2_lite_wikipedia_FA_hifp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_FA_linear_hifp8.sh b/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_FA_linear_hifp8.sh new file mode 100755 index 0000000000..a966c39e7d --- /dev/null +++ b/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_FA_linear_hifp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for DEEPSEEK2_LITE - Updated with new pattern +# Script: pretrain_deepseek2_lite_wikipedia_FA_linear_hifp8.sh +# Quantization Type: hifp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_FA_linear_hifp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/deepseek2_lite_hifp8"} +TOKENIZER_ARG=${3:-"model/deepseek2_lite"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to hifp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/deepseek2_lite/train_deepseek2_lite_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_deepseek2_lite_wikipedia_FA_linear_hifp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_FA_linear_mxfp4.sh b/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_FA_linear_mxfp4.sh new file mode 100755 index 0000000000..dcb01d6cf1 --- /dev/null +++ b/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_FA_linear_mxfp4.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for DEEPSEEK2_LITE - Updated with new pattern +# Script: pretrain_deepseek2_lite_wikipedia_FA_linear_mxfp4.sh +# Quantization Type: mxfp4 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_FA_linear_mxfp4"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/deepseek2_lite_mxfp4"} +TOKENIZER_ARG=${3:-"model/deepseek2_lite"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp4..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/deepseek2_lite/train_deepseek2_lite_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_deepseek2_lite_wikipedia_FA_linear_mxfp4_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_FA_linear_mxfp8.sh b/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_FA_linear_mxfp8.sh new file mode 100755 index 0000000000..8e42f9a2c4 --- /dev/null +++ b/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_FA_linear_mxfp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for DEEPSEEK2_LITE - Updated with new pattern +# Script: pretrain_deepseek2_lite_wikipedia_FA_linear_mxfp8.sh +# Quantization Type: mxfp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_FA_linear_mxfp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/deepseek2_lite_mxfp8"} +TOKENIZER_ARG=${3:-"model/deepseek2_lite"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/deepseek2_lite/train_deepseek2_lite_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_deepseek2_lite_wikipedia_FA_linear_mxfp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_FA_mxfp4.sh b/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_FA_mxfp4.sh new file mode 100755 index 0000000000..bc09760b7f --- /dev/null +++ b/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_FA_mxfp4.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for DEEPSEEK2_LITE - Updated with new pattern +# Script: pretrain_deepseek2_lite_wikipedia_FA_mxfp4.sh +# Quantization Type: mxfp4 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_FA_mxfp4"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/deepseek2_lite_mxfp4"} +TOKENIZER_ARG=${3:-"model/deepseek2_lite"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp4..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/deepseek2_lite/train_deepseek2_lite_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_deepseek2_lite_wikipedia_FA_mxfp4_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_FA_mxfp8.sh b/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_FA_mxfp8.sh new file mode 100755 index 0000000000..75b985f9db --- /dev/null +++ b/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_FA_mxfp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for DEEPSEEK2_LITE - Updated with new pattern +# Script: pretrain_deepseek2_lite_wikipedia_FA_mxfp8.sh +# Quantization Type: mxfp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_FA_mxfp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/deepseek2_lite_mxfp8"} +TOKENIZER_ARG=${3:-"model/deepseek2_lite"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/deepseek2_lite/train_deepseek2_lite_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_deepseek2_lite_wikipedia_FA_mxfp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_bf16.sh b/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_bf16.sh new file mode 100755 index 0000000000..4259932d2b --- /dev/null +++ b/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_bf16.sh @@ -0,0 +1,79 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for DEEPSEEK2_LITE - Updated with new pattern +# Script: pretrain_deepseek2_lite_wikipedia_bf16.sh +# Quantization Type: bf16 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_bf16"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/deepseek2_lite_bf16"} +TOKENIZER_ARG=${3:-"model/deepseek2_lite"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] BF16 training - no quantization modification needed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/deepseek2_lite/train_deepseek2_lite_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_deepseek2_lite_wikipedia_bf16_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_linear_hifp8.sh b/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_linear_hifp8.sh new file mode 100755 index 0000000000..13ebd6b3ef --- /dev/null +++ b/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_linear_hifp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for DEEPSEEK2_LITE - Updated with new pattern +# Script: pretrain_deepseek2_lite_wikipedia_linear_hifp8.sh +# Quantization Type: hifp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_linear_hifp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/deepseek2_lite_hifp8"} +TOKENIZER_ARG=${3:-"model/deepseek2_lite"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to hifp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/deepseek2_lite/train_deepseek2_lite_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_deepseek2_lite_wikipedia_linear_hifp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_linear_mxfp4.sh b/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_linear_mxfp4.sh new file mode 100755 index 0000000000..0b1fd0716c --- /dev/null +++ b/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_linear_mxfp4.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for DEEPSEEK2_LITE - Updated with new pattern +# Script: pretrain_deepseek2_lite_wikipedia_linear_mxfp4.sh +# Quantization Type: mxfp4 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_linear_mxfp4"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/deepseek2_lite_mxfp4"} +TOKENIZER_ARG=${3:-"model/deepseek2_lite"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp4..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/deepseek2_lite/train_deepseek2_lite_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_deepseek2_lite_wikipedia_linear_mxfp4_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_linear_mxfp8.sh b/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_linear_mxfp8.sh new file mode 100755 index 0000000000..1d9675fab9 --- /dev/null +++ b/script/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_linear_mxfp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for DEEPSEEK2_LITE - Updated with new pattern +# Script: pretrain_deepseek2_lite_wikipedia_linear_mxfp8.sh +# Quantization Type: mxfp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/deepseek2_lite/pretrain_deepseek2_lite_wikipedia_linear_mxfp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/deepseek2_lite_mxfp8"} +TOKENIZER_ARG=${3:-"model/deepseek2_lite"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/deepseek2_lite/train_deepseek2_lite_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_deepseek2_lite_wikipedia_linear_mxfp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama31-8b/pretrain_llama31-8b_dolma_FA_hifp8.sh b/script/llama31-8b/pretrain_llama31-8b_dolma_FA_hifp8.sh new file mode 100755 index 0000000000..2eac801d23 --- /dev/null +++ b/script/llama31-8b/pretrain_llama31-8b_dolma_FA_hifp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLAMA31-8B - Updated with new pattern +# Script: pretrain_llama31-8b_dolma_FA_hifp8.sh +# Quantization Type: hifp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama31-8b/pretrain_llama31-8b_dolma_FA_hifp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama31-8b_hifp8"} +TOKENIZER_ARG=${3:-"model/llama31-8b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to hifp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama3_8b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama31-8b_dolma_FA_hifp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama31-8b/pretrain_llama31-8b_dolma_FA_linear_hifp8.sh b/script/llama31-8b/pretrain_llama31-8b_dolma_FA_linear_hifp8.sh new file mode 100755 index 0000000000..a98b3fd189 --- /dev/null +++ b/script/llama31-8b/pretrain_llama31-8b_dolma_FA_linear_hifp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLAMA31-8B - Updated with new pattern +# Script: pretrain_llama31-8b_dolma_FA_linear_hifp8.sh +# Quantization Type: hifp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama31-8b/pretrain_llama31-8b_dolma_FA_linear_hifp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama31-8b_hifp8"} +TOKENIZER_ARG=${3:-"model/llama31-8b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to hifp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama3_8b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama31-8b_dolma_FA_linear_hifp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama31-8b/pretrain_llama31-8b_dolma_FA_linear_mxfp4.sh b/script/llama31-8b/pretrain_llama31-8b_dolma_FA_linear_mxfp4.sh new file mode 100755 index 0000000000..84bce99ca0 --- /dev/null +++ b/script/llama31-8b/pretrain_llama31-8b_dolma_FA_linear_mxfp4.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLAMA31-8B - Updated with new pattern +# Script: pretrain_llama31-8b_dolma_FA_linear_mxfp4.sh +# Quantization Type: mxfp4 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama31-8b/pretrain_llama31-8b_dolma_FA_linear_mxfp4"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama31-8b_mxfp4"} +TOKENIZER_ARG=${3:-"model/llama31-8b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp4..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama3_8b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama31-8b_dolma_FA_linear_mxfp4_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama31-8b/pretrain_llama31-8b_dolma_FA_linear_mxfp8.sh b/script/llama31-8b/pretrain_llama31-8b_dolma_FA_linear_mxfp8.sh new file mode 100755 index 0000000000..48eaf06c96 --- /dev/null +++ b/script/llama31-8b/pretrain_llama31-8b_dolma_FA_linear_mxfp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLAMA31-8B - Updated with new pattern +# Script: pretrain_llama31-8b_dolma_FA_linear_mxfp8.sh +# Quantization Type: mxfp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama31-8b/pretrain_llama31-8b_dolma_FA_linear_mxfp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama31-8b_mxfp8"} +TOKENIZER_ARG=${3:-"model/llama31-8b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama3_8b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama31-8b_dolma_FA_linear_mxfp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama31-8b/pretrain_llama31-8b_dolma_FA_mxfp4.sh b/script/llama31-8b/pretrain_llama31-8b_dolma_FA_mxfp4.sh new file mode 100755 index 0000000000..8a7222363a --- /dev/null +++ b/script/llama31-8b/pretrain_llama31-8b_dolma_FA_mxfp4.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLAMA31-8B - Updated with new pattern +# Script: pretrain_llama31-8b_dolma_FA_mxfp4.sh +# Quantization Type: mxfp4 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama31-8b/pretrain_llama31-8b_dolma_FA_mxfp4"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama31-8b_mxfp4"} +TOKENIZER_ARG=${3:-"model/llama31-8b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp4..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama3_8b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama31-8b_dolma_FA_mxfp4_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama31-8b/pretrain_llama31-8b_dolma_FA_mxfp8.sh b/script/llama31-8b/pretrain_llama31-8b_dolma_FA_mxfp8.sh new file mode 100755 index 0000000000..5817a782f8 --- /dev/null +++ b/script/llama31-8b/pretrain_llama31-8b_dolma_FA_mxfp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLAMA31-8B - Updated with new pattern +# Script: pretrain_llama31-8b_dolma_FA_mxfp8.sh +# Quantization Type: mxfp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama31-8b/pretrain_llama31-8b_dolma_FA_mxfp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama31-8b_mxfp8"} +TOKENIZER_ARG=${3:-"model/llama31-8b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama3_8b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama31-8b_dolma_FA_mxfp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama31-8b/pretrain_llama31-8b_dolma_bf16.sh b/script/llama31-8b/pretrain_llama31-8b_dolma_bf16.sh new file mode 100755 index 0000000000..fc29d5c2c6 --- /dev/null +++ b/script/llama31-8b/pretrain_llama31-8b_dolma_bf16.sh @@ -0,0 +1,79 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLAMA31-8B - Updated with new pattern +# Script: pretrain_llama31-8b_dolma_bf16.sh +# Quantization Type: bf16 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama31-8b/pretrain_llama31-8b_dolma_bf16"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama31-8b_bf16"} +TOKENIZER_ARG=${3:-"model/llama31-8b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] BF16 training - no quantization modification needed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama3_8b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama31-8b_dolma_bf16_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama31-8b/pretrain_llama31-8b_dolma_linear_hifp8.sh b/script/llama31-8b/pretrain_llama31-8b_dolma_linear_hifp8.sh new file mode 100755 index 0000000000..37cecaeb84 --- /dev/null +++ b/script/llama31-8b/pretrain_llama31-8b_dolma_linear_hifp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLAMA31-8B - Updated with new pattern +# Script: pretrain_llama31-8b_dolma_linear_hifp8.sh +# Quantization Type: hifp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama31-8b/pretrain_llama31-8b_dolma_linear_hifp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama31-8b_hifp8"} +TOKENIZER_ARG=${3:-"model/llama31-8b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to hifp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama3_8b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama31-8b_dolma_linear_hifp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama31-8b/pretrain_llama31-8b_dolma_linear_mxfp4.sh b/script/llama31-8b/pretrain_llama31-8b_dolma_linear_mxfp4.sh new file mode 100755 index 0000000000..73a6545d3f --- /dev/null +++ b/script/llama31-8b/pretrain_llama31-8b_dolma_linear_mxfp4.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLAMA31-8B - Updated with new pattern +# Script: pretrain_llama31-8b_dolma_linear_mxfp4.sh +# Quantization Type: mxfp4 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama31-8b/pretrain_llama31-8b_dolma_linear_mxfp4"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama31-8b_mxfp4"} +TOKENIZER_ARG=${3:-"model/llama31-8b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp4..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama3_8b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama31-8b_dolma_linear_mxfp4_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama31-8b/pretrain_llama31-8b_dolma_linear_mxfp8.sh b/script/llama31-8b/pretrain_llama31-8b_dolma_linear_mxfp8.sh new file mode 100755 index 0000000000..a69882f6c9 --- /dev/null +++ b/script/llama31-8b/pretrain_llama31-8b_dolma_linear_mxfp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLAMA31-8B - Updated with new pattern +# Script: pretrain_llama31-8b_dolma_linear_mxfp8.sh +# Quantization Type: mxfp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama31-8b/pretrain_llama31-8b_dolma_linear_mxfp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama31-8b_mxfp8"} +TOKENIZER_ARG=${3:-"model/llama31-8b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama3_8b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama31-8b_dolma_linear_mxfp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama31-8b/pretrain_llama31-8b_wikipedia_FA_QK_mxfp8.sh b/script/llama31-8b/pretrain_llama31-8b_wikipedia_FA_QK_mxfp8.sh new file mode 100755 index 0000000000..26ff813306 --- /dev/null +++ b/script/llama31-8b/pretrain_llama31-8b_wikipedia_FA_QK_mxfp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLAMA31-8B - Updated with new pattern +# Script: pretrain_llama31-8b_wikipedia_FA_QK_mxfp8.sh +# Quantization Type: mxfp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama31-8b/pretrain_llama31-8b_wikipedia_FA_QK_mxfp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama31-8b_mxfp8"} +TOKENIZER_ARG=${3:-"model/llama31-8b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama3_8b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama31-8b_wikipedia_FA_QK_mxfp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama31-8b/pretrain_llama31-8b_wikipedia_FA_hifp8.sh b/script/llama31-8b/pretrain_llama31-8b_wikipedia_FA_hifp8.sh new file mode 100755 index 0000000000..21b3f9c232 --- /dev/null +++ b/script/llama31-8b/pretrain_llama31-8b_wikipedia_FA_hifp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLAMA31-8B - Updated with new pattern +# Script: pretrain_llama31-8b_wikipedia_FA_hifp8.sh +# Quantization Type: hifp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama31-8b/pretrain_llama31-8b_wikipedia_FA_hifp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama31-8b_hifp8"} +TOKENIZER_ARG=${3:-"model/llama31-8b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to hifp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama3_8b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama31-8b_wikipedia_FA_hifp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama31-8b/pretrain_llama31-8b_wikipedia_FA_linear_hifp8.sh b/script/llama31-8b/pretrain_llama31-8b_wikipedia_FA_linear_hifp8.sh new file mode 100755 index 0000000000..ee34487454 --- /dev/null +++ b/script/llama31-8b/pretrain_llama31-8b_wikipedia_FA_linear_hifp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLAMA31-8B - Updated with new pattern +# Script: pretrain_llama31-8b_wikipedia_FA_linear_hifp8.sh +# Quantization Type: hifp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama31-8b/pretrain_llama31-8b_wikipedia_FA_linear_hifp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama31-8b_hifp8"} +TOKENIZER_ARG=${3:-"model/llama31-8b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to hifp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama3_8b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama31-8b_wikipedia_FA_linear_hifp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama31-8b/pretrain_llama31-8b_wikipedia_FA_linear_mxfp4.sh b/script/llama31-8b/pretrain_llama31-8b_wikipedia_FA_linear_mxfp4.sh new file mode 100755 index 0000000000..0c1605cd3d --- /dev/null +++ b/script/llama31-8b/pretrain_llama31-8b_wikipedia_FA_linear_mxfp4.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLAMA31-8B - Updated with new pattern +# Script: pretrain_llama31-8b_wikipedia_FA_linear_mxfp4.sh +# Quantization Type: mxfp4 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama31-8b/pretrain_llama31-8b_wikipedia_FA_linear_mxfp4"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama31-8b_mxfp4"} +TOKENIZER_ARG=${3:-"model/llama31-8b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp4..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama3_8b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama31-8b_wikipedia_FA_linear_mxfp4_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama31-8b/pretrain_llama31-8b_wikipedia_FA_linear_mxfp8.sh b/script/llama31-8b/pretrain_llama31-8b_wikipedia_FA_linear_mxfp8.sh new file mode 100755 index 0000000000..ca3e36097d --- /dev/null +++ b/script/llama31-8b/pretrain_llama31-8b_wikipedia_FA_linear_mxfp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLAMA31-8B - Updated with new pattern +# Script: pretrain_llama31-8b_wikipedia_FA_linear_mxfp8.sh +# Quantization Type: mxfp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama31-8b/pretrain_llama31-8b_wikipedia_FA_linear_mxfp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama31-8b_mxfp8"} +TOKENIZER_ARG=${3:-"model/llama31-8b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama3_8b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama31-8b_wikipedia_FA_linear_mxfp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama31-8b/pretrain_llama31-8b_wikipedia_FA_mxfp4.sh b/script/llama31-8b/pretrain_llama31-8b_wikipedia_FA_mxfp4.sh new file mode 100755 index 0000000000..903a85958a --- /dev/null +++ b/script/llama31-8b/pretrain_llama31-8b_wikipedia_FA_mxfp4.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLAMA31-8B - Updated with new pattern +# Script: pretrain_llama31-8b_wikipedia_FA_mxfp4.sh +# Quantization Type: mxfp4 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama31-8b/pretrain_llama31-8b_wikipedia_FA_mxfp4"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama31-8b_mxfp4"} +TOKENIZER_ARG=${3:-"model/llama31-8b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp4..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama3_8b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama31-8b_wikipedia_FA_mxfp4_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama31-8b/pretrain_llama31-8b_wikipedia_FA_mxfp8.sh b/script/llama31-8b/pretrain_llama31-8b_wikipedia_FA_mxfp8.sh new file mode 100755 index 0000000000..aceb1bc283 --- /dev/null +++ b/script/llama31-8b/pretrain_llama31-8b_wikipedia_FA_mxfp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLAMA31-8B - Updated with new pattern +# Script: pretrain_llama31-8b_wikipedia_FA_mxfp8.sh +# Quantization Type: mxfp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama31-8b/pretrain_llama31-8b_wikipedia_FA_mxfp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama31-8b_mxfp8"} +TOKENIZER_ARG=${3:-"model/llama31-8b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama3_8b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama31-8b_wikipedia_FA_mxfp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama31-8b/pretrain_llama31-8b_wikipedia_bf16.sh b/script/llama31-8b/pretrain_llama31-8b_wikipedia_bf16.sh new file mode 100755 index 0000000000..71f1c0a5df --- /dev/null +++ b/script/llama31-8b/pretrain_llama31-8b_wikipedia_bf16.sh @@ -0,0 +1,79 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLAMA31-8B - Updated with new pattern +# Script: pretrain_llama31-8b_wikipedia_bf16.sh +# Quantization Type: bf16 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama31-8b/pretrain_llama31-8b_wikipedia_bf16"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama31-8b_bf16"} +TOKENIZER_ARG=${3:-"model/llama31-8b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] BF16 training - no quantization modification needed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama3_8b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama31-8b_wikipedia_bf16_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama31-8b/pretrain_llama31-8b_wikipedia_linear_hifp8.sh b/script/llama31-8b/pretrain_llama31-8b_wikipedia_linear_hifp8.sh new file mode 100755 index 0000000000..3dc7df8a72 --- /dev/null +++ b/script/llama31-8b/pretrain_llama31-8b_wikipedia_linear_hifp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLAMA31-8B - Updated with new pattern +# Script: pretrain_llama31-8b_wikipedia_linear_hifp8.sh +# Quantization Type: hifp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama31-8b/pretrain_llama31-8b_wikipedia_linear_hifp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama31-8b_hifp8"} +TOKENIZER_ARG=${3:-"model/llama31-8b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to hifp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama3_8b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama31-8b_wikipedia_linear_hifp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama31-8b/pretrain_llama31-8b_wikipedia_linear_mxfp4.sh b/script/llama31-8b/pretrain_llama31-8b_wikipedia_linear_mxfp4.sh new file mode 100755 index 0000000000..37a77f39d0 --- /dev/null +++ b/script/llama31-8b/pretrain_llama31-8b_wikipedia_linear_mxfp4.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLAMA31-8B - Updated with new pattern +# Script: pretrain_llama31-8b_wikipedia_linear_mxfp4.sh +# Quantization Type: mxfp4 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama31-8b/pretrain_llama31-8b_wikipedia_linear_mxfp4"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama31-8b_mxfp4"} +TOKENIZER_ARG=${3:-"model/llama31-8b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp4..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama3_8b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama31-8b_wikipedia_linear_mxfp4_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama31-8b/pretrain_llama31-8b_wikipedia_linear_mxfp8.sh b/script/llama31-8b/pretrain_llama31-8b_wikipedia_linear_mxfp8.sh new file mode 100755 index 0000000000..0bf3964609 --- /dev/null +++ b/script/llama31-8b/pretrain_llama31-8b_wikipedia_linear_mxfp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLAMA31-8B - Updated with new pattern +# Script: pretrain_llama31-8b_wikipedia_linear_mxfp8.sh +# Quantization Type: mxfp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama31-8b/pretrain_llama31-8b_wikipedia_linear_mxfp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama31-8b_mxfp8"} +TOKENIZER_ARG=${3:-"model/llama31-8b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama3_8b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama31-8b_wikipedia_linear_mxfp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama31-8b/pretrain_llama31-8b_wikipedia_mxfp8.sh b/script/llama31-8b/pretrain_llama31-8b_wikipedia_mxfp8.sh new file mode 100755 index 0000000000..63adf5900c --- /dev/null +++ b/script/llama31-8b/pretrain_llama31-8b_wikipedia_mxfp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLAMA31-8B - Updated with new pattern +# Script: pretrain_llama31-8b_wikipedia_mxfp8.sh +# Quantization Type: mxfp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama31-8b/pretrain_llama31-8b_wikipedia_mxfp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama31-8b_mxfp8"} +TOKENIZER_ARG=${3:-"model/llama31-8b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama3_8b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama31-8b_wikipedia_mxfp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama32-1b/pretrain_llama32-1b_dolma_FA_hifp8.sh b/script/llama32-1b/pretrain_llama32-1b_dolma_FA_hifp8.sh new file mode 100755 index 0000000000..ca04691e02 --- /dev/null +++ b/script/llama32-1b/pretrain_llama32-1b_dolma_FA_hifp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLaMA 3.2 1B - Updated with new pattern +# Script: pretrain_llama32-1b_dolma_FA_hifp8.sh +# Quantization Type: hifp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama32_1b/pretrain_llama32-1b_dolma_FA_hifp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama32_1b_hifp8"} +TOKENIZER_ARG=${3:-"model/llama3.2-1b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to hifp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama32_1b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama32-1b_dolma_FA_hifp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama32-1b/pretrain_llama32-1b_dolma_FA_linear_hifp8.sh b/script/llama32-1b/pretrain_llama32-1b_dolma_FA_linear_hifp8.sh new file mode 100755 index 0000000000..20ac3536f9 --- /dev/null +++ b/script/llama32-1b/pretrain_llama32-1b_dolma_FA_linear_hifp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLaMA 3.2 1B - Updated with new pattern +# Script: pretrain_llama32-1b_dolma_FA_linear_hifp8.sh +# Quantization Type: hifp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama32_1b/pretrain_llama32-1b_dolma_FA_linear_hifp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama32_1b_hifp8"} +TOKENIZER_ARG=${3:-"model/llama3.2-1b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to hifp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama32_1b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama32-1b_dolma_FA_linear_hifp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama32-1b/pretrain_llama32-1b_dolma_FA_linear_mxfp4.sh b/script/llama32-1b/pretrain_llama32-1b_dolma_FA_linear_mxfp4.sh new file mode 100755 index 0000000000..e9bbd06d38 --- /dev/null +++ b/script/llama32-1b/pretrain_llama32-1b_dolma_FA_linear_mxfp4.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLaMA 3.2 1B - Updated with new pattern +# Script: pretrain_llama32-1b_dolma_FA_linear_mxfp4.sh +# Quantization Type: mxfp4 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama32_1b/pretrain_llama32-1b_dolma_FA_linear_mxfp4"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama32_1b_mxfp4"} +TOKENIZER_ARG=${3:-"model/llama3.2-1b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp4..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama32_1b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama32-1b_dolma_FA_linear_mxfp4_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama32-1b/pretrain_llama32-1b_dolma_FA_linear_mxfp8.sh b/script/llama32-1b/pretrain_llama32-1b_dolma_FA_linear_mxfp8.sh new file mode 100755 index 0000000000..b9be887911 --- /dev/null +++ b/script/llama32-1b/pretrain_llama32-1b_dolma_FA_linear_mxfp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLaMA 3.2 1B - Updated with new pattern +# Script: pretrain_llama32-1b_dolma_FA_linear_mxfp8.sh +# Quantization Type: mxfp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama32_1b/pretrain_llama32-1b_dolma_FA_linear_mxfp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama32_1b_mxfp8"} +TOKENIZER_ARG=${3:-"model/llama3.2-1b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama32_1b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama32-1b_dolma_FA_linear_mxfp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama32-1b/pretrain_llama32-1b_dolma_FA_mxfp4.sh b/script/llama32-1b/pretrain_llama32-1b_dolma_FA_mxfp4.sh new file mode 100755 index 0000000000..281822af24 --- /dev/null +++ b/script/llama32-1b/pretrain_llama32-1b_dolma_FA_mxfp4.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLaMA 3.2 1B - Updated with new pattern +# Script: pretrain_llama32-1b_dolma_FA_mxfp4.sh +# Quantization Type: mxfp4 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama32_1b/pretrain_llama32-1b_dolma_FA_mxfp4"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama32_1b_mxfp4"} +TOKENIZER_ARG=${3:-"model/llama3.2-1b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp4..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama32_1b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama32-1b_dolma_FA_mxfp4_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama32-1b/pretrain_llama32-1b_dolma_FA_mxfp8.sh b/script/llama32-1b/pretrain_llama32-1b_dolma_FA_mxfp8.sh new file mode 100755 index 0000000000..c050a33447 --- /dev/null +++ b/script/llama32-1b/pretrain_llama32-1b_dolma_FA_mxfp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLaMA 3.2 1B - Updated with new pattern +# Script: pretrain_llama32-1b_dolma_FA_mxfp8.sh +# Quantization Type: mxfp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama32_1b/pretrain_llama32-1b_dolma_FA_mxfp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama32_1b_mxfp8"} +TOKENIZER_ARG=${3:-"model/llama3.2-1b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama32_1b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama32-1b_dolma_FA_mxfp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama32-1b/pretrain_llama32-1b_dolma_bf16.sh b/script/llama32-1b/pretrain_llama32-1b_dolma_bf16.sh new file mode 100755 index 0000000000..782b378853 --- /dev/null +++ b/script/llama32-1b/pretrain_llama32-1b_dolma_bf16.sh @@ -0,0 +1,79 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLaMA 3.2 1B - Updated with new pattern +# Script: pretrain_llama32-1b_dolma_bf16.sh +# Quantization Type: bf16 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama32_1b/pretrain_llama32-1b_dolma_bf16"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama32_1b_bf16"} +TOKENIZER_ARG=${3:-"model/llama3.2-1b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] BF16 training - no quantization modification needed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama32_1b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama32-1b_dolma_bf16_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama32-1b/pretrain_llama32-1b_dolma_linear_hifp8.sh b/script/llama32-1b/pretrain_llama32-1b_dolma_linear_hifp8.sh new file mode 100755 index 0000000000..84557cfaaa --- /dev/null +++ b/script/llama32-1b/pretrain_llama32-1b_dolma_linear_hifp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLaMA 3.2 1B - Updated with new pattern +# Script: pretrain_llama32-1b_dolma_linear_hifp8.sh +# Quantization Type: hifp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama32_1b/pretrain_llama32-1b_dolma_linear_hifp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama32_1b_hifp8"} +TOKENIZER_ARG=${3:-"model/llama3.2-1b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to hifp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama32_1b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama32-1b_dolma_linear_hifp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama32-1b/pretrain_llama32-1b_dolma_linear_mxfp4.sh b/script/llama32-1b/pretrain_llama32-1b_dolma_linear_mxfp4.sh new file mode 100755 index 0000000000..bb45333f82 --- /dev/null +++ b/script/llama32-1b/pretrain_llama32-1b_dolma_linear_mxfp4.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLaMA 3.2 1B - Updated with new pattern +# Script: pretrain_llama32-1b_dolma_linear_mxfp4.sh +# Quantization Type: mxfp4 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama32_1b/pretrain_llama32-1b_dolma_linear_mxfp4"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama32_1b_mxfp4"} +TOKENIZER_ARG=${3:-"model/llama3.2-1b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp4..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama32_1b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama32-1b_dolma_linear_mxfp4_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama32-1b/pretrain_llama32-1b_dolma_linear_mxfp8.sh b/script/llama32-1b/pretrain_llama32-1b_dolma_linear_mxfp8.sh new file mode 100755 index 0000000000..6d7ced18ec --- /dev/null +++ b/script/llama32-1b/pretrain_llama32-1b_dolma_linear_mxfp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLaMA 3.2 1B - Updated with new pattern +# Script: pretrain_llama32-1b_dolma_linear_mxfp8.sh +# Quantization Type: mxfp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama32_1b/pretrain_llama32-1b_dolma_linear_mxfp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama32_1b_mxfp8"} +TOKENIZER_ARG=${3:-"model/llama3.2-1b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama32_1b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama32-1b_dolma_linear_mxfp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama32-1b/pretrain_llama32-1b_wikipedia_FA_hifp8.sh b/script/llama32-1b/pretrain_llama32-1b_wikipedia_FA_hifp8.sh new file mode 100755 index 0000000000..aa8d2430d1 --- /dev/null +++ b/script/llama32-1b/pretrain_llama32-1b_wikipedia_FA_hifp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLaMA 3.2 1B - Updated with new pattern +# Script: pretrain_llama32-1b_wikipedia_FA_hifp8.sh +# Quantization Type: hifp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama32_1b/pretrain_llama32-1b_wikipedia_FA_hifp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama32_1b_hifp8"} +TOKENIZER_ARG=${3:-"model/llama3.2-1b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to hifp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama32_1b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama32-1b_wikipedia_FA_hifp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama32-1b/pretrain_llama32-1b_wikipedia_FA_linear_hifp8.sh b/script/llama32-1b/pretrain_llama32-1b_wikipedia_FA_linear_hifp8.sh new file mode 100755 index 0000000000..39d54814ad --- /dev/null +++ b/script/llama32-1b/pretrain_llama32-1b_wikipedia_FA_linear_hifp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLaMA 3.2 1B - Updated with new pattern +# Script: pretrain_llama32-1b_wikipedia_FA_linear_hifp8.sh +# Quantization Type: hifp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama32_1b/pretrain_llama32-1b_wikipedia_FA_linear_hifp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama32_1b_hifp8"} +TOKENIZER_ARG=${3:-"model/llama3.2-1b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to hifp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama32_1b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama32-1b_wikipedia_FA_linear_hifp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama32-1b/pretrain_llama32-1b_wikipedia_FA_linear_mxfp4.sh b/script/llama32-1b/pretrain_llama32-1b_wikipedia_FA_linear_mxfp4.sh new file mode 100755 index 0000000000..c01ac4478a --- /dev/null +++ b/script/llama32-1b/pretrain_llama32-1b_wikipedia_FA_linear_mxfp4.sh @@ -0,0 +1,92 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLaMA 3.2 1B - Updated with new pattern +# Script: pretrain_llama32-1b_wikipedia_FA_linear_mxfp4.sh +# Quantization Type: mxfp4 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +# SCALING_CONTROL=${6:-"max"} +SCALING_CONTROL=${6:-"max_minus_1"} +CHECKPOINT_PATH=${1:-"checkpoints/llama32_1b/pretrain_llama32-1b_wikipedia_FA_linear_mxfp4_${SCALING_CONTROL}"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama32_1b_mxfp4_${SCALING_CONTROL}"} +TOKENIZER_ARG=${3:-"model/llama3.2-1b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp4..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama32_1b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + --scaling-control "$SCALING_CONTROL" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama32-1b_wikipedia_FA_linear_mxfp4_${SCALING_CONTROL}_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama32-1b/pretrain_llama32-1b_wikipedia_FA_linear_mxfp4_time_resume.sh b/script/llama32-1b/pretrain_llama32-1b_wikipedia_FA_linear_mxfp4_time_resume.sh new file mode 100644 index 0000000000..579b1a01ec --- /dev/null +++ b/script/llama32-1b/pretrain_llama32-1b_wikipedia_FA_linear_mxfp4_time_resume.sh @@ -0,0 +1,103 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLaMA 3.2 1B - Time-Resume Adaptive Quantization +# Script: pretrain_llama32-1b_wikipedia_FA_linear_mxfp4_time_resume.sh +# Features: Adaptive quantization with time-resume capability +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +SCALING_CONTROL=${6:-"max_minus_1"} +CHECKPOINT_PATH=${1:-"checkpoints/llama32_1b/pretrain_llama32-1b_wikipedia_FA_linear_mxfp4_time_resume_${SCALING_CONTROL}"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama32_1b_mxfp4_time_resume_${SCALING_CONTROL}"} +TOKENIZER_ARG=${3:-"model/llama3.2-1b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# Time-resume specific parameters +QUANT_LOSS_THRESHOLD=${7:-"0.1"} +QUANT_WINDOW_SIZE=${8:-"5"} +QUANT_CHECKPOINT_INTERVAL=${9:-"1"} +QUANT_FALLBACK_STRATEGY=${10:-"bf16"} +QUANT_RECOVERY_BUFFER=${11:-"2"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Time-resume adaptive quantization enabled" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Configuration:" +echo " - Scaling Control: $SCALING_CONTROL" +echo " - Loss Threshold: $QUANT_LOSS_THRESHOLD" +echo " - Window Size: $QUANT_WINDOW_SIZE" +echo " - Checkpoint Interval: $QUANT_CHECKPOINT_INTERVAL" +echo " - Fallback Strategy: $QUANT_FALLBACK_STRATEGY" +echo " - Recovery Buffer: $QUANT_RECOVERY_BUFFER" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# Set up logging paths +HOST_TENSORBOARD_LOGS_PATH="${TENSORBOARD_LOGS_PATH}_logs" +mkdir -p "$HOST_TENSORBOARD_LOGS_PATH" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with time-resume parameters +bash examples/llama/train_llama32_1b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + --scaling-control "$SCALING_CONTROL" \ + --time-resume \ + --quant-loss-threshold "$QUANT_LOSS_THRESHOLD" \ + --quant-window-size "$QUANT_WINDOW_SIZE" \ + --quant-checkpoint-interval "$QUANT_CHECKPOINT_INTERVAL" \ + --quant-fallback-strategy "$QUANT_FALLBACK_STRATEGY" \ + --quant-recovery-buffer "$QUANT_RECOVERY_BUFFER" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama32-1b_wikipedia_FA_linear_mxfp4_time_resume_${SCALING_CONTROL}_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +END_TIME=$(date '+%Y-%m-%d %H:%M:%S') +END_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Training completed" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Start time: $START_TIME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] End time: $END_TIME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Exit code: $TRAINING_EXIT_CODE" + +if [ $TRAINING_EXIT_CODE -eq 0 ]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Logs saved to: $HOST_TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Check logs for details: $HOST_TENSORBOARD_LOGS_PATH" +fi + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Script finished: $SCRIPT_NAME" + +exit $TRAINING_EXIT_CODE diff --git a/script/llama32-1b/pretrain_llama32-1b_wikipedia_FA_linear_mxfp8.sh b/script/llama32-1b/pretrain_llama32-1b_wikipedia_FA_linear_mxfp8.sh new file mode 100755 index 0000000000..be97a00fcd --- /dev/null +++ b/script/llama32-1b/pretrain_llama32-1b_wikipedia_FA_linear_mxfp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLaMA 3.2 1B - Updated with new pattern +# Script: pretrain_llama32-1b_wikipedia_FA_linear_mxfp8.sh +# Quantization Type: mxfp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama32_1b/pretrain_llama32-1b_wikipedia_FA_linear_mxfp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama32_1b_mxfp8"} +TOKENIZER_ARG=${3:-"model/llama3.2-1b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama32_1b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama32-1b_wikipedia_FA_linear_mxfp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama32-1b/pretrain_llama32-1b_wikipedia_FA_mxfp4.sh b/script/llama32-1b/pretrain_llama32-1b_wikipedia_FA_mxfp4.sh new file mode 100755 index 0000000000..deb7d65b4c --- /dev/null +++ b/script/llama32-1b/pretrain_llama32-1b_wikipedia_FA_mxfp4.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLaMA 3.2 1B - Updated with new pattern +# Script: pretrain_llama32-1b_wikipedia_FA_mxfp4.sh +# Quantization Type: mxfp4 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama32_1b/pretrain_llama32-1b_wikipedia_FA_mxfp4"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama32_1b_mxfp4"} +TOKENIZER_ARG=${3:-"model/llama3.2-1b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp4..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama32_1b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama32-1b_wikipedia_FA_mxfp4_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama32-1b/pretrain_llama32-1b_wikipedia_FA_mxfp8.sh b/script/llama32-1b/pretrain_llama32-1b_wikipedia_FA_mxfp8.sh new file mode 100755 index 0000000000..1063808f73 --- /dev/null +++ b/script/llama32-1b/pretrain_llama32-1b_wikipedia_FA_mxfp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLaMA 3.2 1B - Updated with new pattern +# Script: pretrain_llama32-1b_wikipedia_FA_mxfp8.sh +# Quantization Type: mxfp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama32_1b/pretrain_llama32-1b_wikipedia_FA_mxfp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama32_1b_mxfp8"} +TOKENIZER_ARG=${3:-"model/llama3.2-1b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama32_1b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama32-1b_wikipedia_FA_mxfp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama32-1b/pretrain_llama32-1b_wikipedia_bf16.sh b/script/llama32-1b/pretrain_llama32-1b_wikipedia_bf16.sh new file mode 100755 index 0000000000..5d5cd3e2af --- /dev/null +++ b/script/llama32-1b/pretrain_llama32-1b_wikipedia_bf16.sh @@ -0,0 +1,79 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLaMA 3.2 1B - Updated with new pattern +# Script: pretrain_llama32-1b_wikipedia_bf16.sh +# Quantization Type: bf16 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama32_1b/pretrain_llama32-1b_wikipedia_bf16"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama32_1b_bf16"} +TOKENIZER_ARG=${3:-"model/llama3.2-1b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] BF16 training - no quantization modification needed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama32_1b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama32-1b_wikipedia_bf16_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama32-1b/pretrain_llama32-1b_wikipedia_linear_hifp8.sh b/script/llama32-1b/pretrain_llama32-1b_wikipedia_linear_hifp8.sh new file mode 100755 index 0000000000..0aa6785e59 --- /dev/null +++ b/script/llama32-1b/pretrain_llama32-1b_wikipedia_linear_hifp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLaMA 3.2 1B - Updated with new pattern +# Script: pretrain_llama32-1b_wikipedia_linear_hifp8.sh +# Quantization Type: hifp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama32_1b/pretrain_llama32-1b_wikipedia_linear_hifp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama32_1b_hifp8"} +TOKENIZER_ARG=${3:-"model/llama3.2-1b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to hifp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'hifp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama32_1b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama32-1b_wikipedia_linear_hifp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama32-1b/pretrain_llama32-1b_wikipedia_linear_mxfp4.sh b/script/llama32-1b/pretrain_llama32-1b_wikipedia_linear_mxfp4.sh new file mode 100755 index 0000000000..851ba3a8e3 --- /dev/null +++ b/script/llama32-1b/pretrain_llama32-1b_wikipedia_linear_mxfp4.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLaMA 3.2 1B - Updated with new pattern +# Script: pretrain_llama32-1b_wikipedia_linear_mxfp4.sh +# Quantization Type: mxfp4 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama32_1b/pretrain_llama32-1b_wikipedia_linear_mxfp4"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama32_1b_mxfp4"} +TOKENIZER_ARG=${3:-"model/llama3.2-1b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp4..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp4'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama32_1b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama32-1b_wikipedia_linear_mxfp4_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/llama32-1b/pretrain_llama32-1b_wikipedia_linear_mxfp8.sh b/script/llama32-1b/pretrain_llama32-1b_wikipedia_linear_mxfp8.sh new file mode 100755 index 0000000000..cffa22aa6b --- /dev/null +++ b/script/llama32-1b/pretrain_llama32-1b_wikipedia_linear_mxfp8.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# ============================================================================= +# Training Script for LLaMA 3.2 1B - Updated with new pattern +# Script: pretrain_llama32-1b_wikipedia_linear_mxfp8.sh +# Quantization Type: mxfp8 +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${1:-"checkpoints/llama32_1b/pretrain_llama32-1b_wikipedia_linear_mxfp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama32_1b_mxfp8"} +TOKENIZER_ARG=${3:-"model/llama3.2-1b"} +DATA_ARG=${4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"} +DTYPE=${5:-"bf16"} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to mxfp8..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'mxfp8'/" \ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed" + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash examples/llama/train_llama32_1b_h100_fp8.sh \ + "$CHECKPOINT_PATH" \ + "$TENSORBOARD_LOGS_PATH" \ + "$TOKENIZER_ARG" \ + "$DATA_ARG" \ + "$DTYPE" \ + 2>&1 | tee "${HOST_TENSORBOARD_LOGS_PATH}/training_pretrain_llama32-1b_wikipedia_linear_mxfp8_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE diff --git a/script/navigate.sh b/script/navigate.sh new file mode 100755 index 0000000000..0fd7ac3c19 --- /dev/null +++ b/script/navigate.sh @@ -0,0 +1,275 @@ +#!/bin/bash +""" +Script目录导航脚本 +提供快速访问各个功能模块的便捷方式 +""" + +# 颜色定义 +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +PURPLE='\033[0;35m' +CYAN='\033[0;36m' +NC='\033[0m' # No Color + +# 获取脚本目录 +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +# 显示标题 +show_title() { + echo -e "${CYAN}========================================${NC}" + echo -e "${CYAN} Megatron-LM Script 导航工具${NC}" + echo -e "${CYAN}========================================${NC}" + echo "" +} + +# 显示菜单 +show_menu() { + echo -e "${GREEN}请选择要访问的功能模块:${NC}" + echo "" + echo -e "${YELLOW}1.${NC} 数据处理 (data_processing)" + echo -e "${YELLOW}2.${NC} 可视化工具 (visualization)" + echo -e "${YELLOW}3.${NC} 工具脚本 (utils)" + echo -e "${YELLOW}4.${NC} 模板文件 (templates)" + echo -e "${YELLOW}5.${NC} 训练脚本 (training)" + echo -e "${YELLOW}6.${NC} 查看目录结构" + echo -e "${YELLOW}7.${NC} 快速命令" + echo -e "${YELLOW}8.${NC} 帮助信息" + echo -e "${YELLOW}0.${NC} 退出" + echo "" +} + +# 显示数据处理菜单 +show_data_processing_menu() { + echo -e "${BLUE}=== 数据处理模块 ===${NC}" + echo "" + echo -e "${GREEN}可用脚本:${NC}" + echo "1. process_dolma_data.sh - 处理Dolma数据集" + echo "2. process_wikipedia_data.sh - 处理Wikipedia数据集" + echo "3. process_c4_data.sh - 处理C4数据集" + echo "4. process_custom_data.sh - 处理自定义数据集" + echo "5. data_processing_utils.py - 数据处理工具" + echo "6. 返回主菜单" + echo "" + echo -e "${YELLOW}快速命令:${NC}" + echo "cd data_processing && ls -la" + echo "cd data_processing && ./process_dolma_data.sh" + echo "cd data_processing && python data_processing_utils.py --action check" +} + +# 显示可视化菜单 +show_visualization_menu() { + echo -e "${BLUE}=== 可视化模块 ===${NC}" + echo "" + echo -e "${GREEN}可用脚本:${NC}" + echo "1. visualize_tensors.py - 完整tensor可视化" + echo "2. quick_visualize.py - 快速可视化" + echo "3. one_click_visualize.sh - 一键可视化" + echo "4. 返回主菜单" + echo "" + echo -e "${YELLOW}快速命令:${NC}" + echo "cd visualization && ls -la" + echo "cd visualization && ./one_click_visualize.sh" + echo "cd visualization && python quick_visualize.py" +} + +# 显示工具菜单 +show_utils_menu() { + echo -e "${BLUE}=== 工具模块 ===${NC}" + echo "" + echo -e "${GREEN}可用脚本:${NC}" + echo "1. quant_type_modifier.py - 量化类型修改工具" + echo "2. update_scripts_with_pattern_v2.py - 脚本模式更新工具" + echo "3. 返回主菜单" + echo "" + echo -e "${YELLOW}快速命令:${NC}" + echo "cd utils && ls -la" + echo "cd utils && python quant_type_modifier.py --help" + echo "cd utils && python update_scripts_with_pattern_v2.py --help" +} + +# 显示模板菜单 +show_templates_menu() { + echo -e "${BLUE}=== 模板模块 ===${NC}" + echo "" + echo -e "${GREEN}可用文件:${NC}" + echo "1. improved_script_template.sh - 改进的训练脚本模板" + echo "2. 返回主菜单" + echo "" + echo -e "${YELLOW}快速命令:${NC}" + echo "cd templates && ls -la" + echo "cd templates && cp improved_script_template.sh ../my_training.sh" +} + +# 显示训练脚本菜单 +show_training_menu() { + echo -e "${BLUE}=== 训练脚本模块 ===${NC}" + echo "" + echo -e "${GREEN}可用模型:${NC}" + echo "1. llama32-1b - LLaMA 3.2 1B模型" + echo "2. llama31-8b - LLaMA 3.1 8B模型" + echo "3. deepseek2_lite - DeepSeek2 Lite模型" + echo "4. 返回主菜单" + echo "" + echo -e "${YELLOW}快速命令:${NC}" + echo "cd training && ls -la" + echo "cd training/llama32-1b && ls -la" + echo "cd training/llama32-1b && ./pretrain_llama32-1b_dolma_hifp8.sh" +} + +# 显示目录结构 +show_directory_structure() { + echo -e "${BLUE}=== 目录结构 ===${NC}" + echo "" + echo "script/" + echo "├── data_processing/ # 数据处理脚本" + echo "│ ├── process_dolma_data.sh" + echo "│ ├── process_wikipedia_data.sh" + echo "│ ├── process_c4_data.sh" + echo "│ ├── process_custom_data.sh" + echo "│ ├── data_processing_utils.py" + echo "│ └── README.md" + echo "├── visualization/ # 可视化脚本" + echo "│ ├── visualize_tensors.py" + echo "│ ├── quick_visualize.py" + echo "│ ├── one_click_visualize.sh" + echo "│ └── README.md" + echo "├── utils/ # 工具脚本" + echo "│ ├── quant_type_modifier.py" + echo "│ ├── update_scripts_with_pattern_v2.py" + echo "│ └── README.md" + echo "├── templates/ # 模板文件" + echo "│ ├── improved_script_template.sh" + echo "│ └── README.md" + echo "├── training/ # 训练脚本(按模型分类)" + echo "│ ├── llama32-1b/" + echo "│ ├── llama31-8b/" + echo "│ └── deepseek2_lite/" + echo "├── navigate.sh # 本导航脚本" + echo "└── README.md # 主说明文档" + echo "" +} + +# 显示快速命令 +show_quick_commands() { + echo -e "${BLUE}=== 快速命令 ===${NC}" + echo "" + echo -e "${GREEN}数据处理:${NC}" + echo " ./data_processing/process_dolma_data.sh" + echo " python data_processing/data_processing_utils.py --action check" + echo "" + echo -e "${GREEN}可视化:${NC}" + echo " ./visualization/one_click_visualize.sh" + echo " python visualization/quick_visualize.py" + echo "" + echo -e "${GREEN}工具:${NC}" + echo " python utils/quant_type_modifier.py --help" + echo " python utils/update_scripts_with_pattern_v2.py" + echo "" + echo -e "${GREEN}训练:${NC}" + echo " cp templates/improved_script_template.sh my_training.sh" + echo " ./training/llama32-1b/pretrain_llama32-1b_dolma_hifp8.sh" + echo "" +} + +# 显示帮助信息 +show_help() { + echo -e "${BLUE}=== 帮助信息 ===${NC}" + echo "" + echo -e "${GREEN}使用方法:${NC}" + echo "1. 运行此脚本: ./navigate.sh" + echo "2. 选择要访问的功能模块" + echo "3. 按照提示操作" + echo "" + echo -e "${GREEN}各模块功能:${NC}" + echo "• data_processing: 处理各种数据集,转换为训练格式" + echo "• visualization: 可视化tensor数据,分析量化效果" + echo "• utils: 批量修改脚本,管理量化类型" + echo "• templates: 训练脚本模板,快速创建新脚本" + echo "• training: 按模型分类的训练脚本" + echo "" + echo -e "${GREEN}环境要求:${NC}" + echo "• Python 3.8+" + echo "• PyTorch 1.12+" + echo "• CUDA 11.0+" + echo "• 必要的Python包: matplotlib, seaborn, pandas, scipy" + echo "" + echo -e "${GREEN}环境变量:${NC}" + echo "export CUSTOM_QUANT_TYPE=\"hifp8\"" + echo "export TENSOR_SAVE_DIR=\"./enhanced_tensor_logs\"" + echo "export TENSOR_SAVE_ENABLED=\"true\"" + echo "" +} + +# 处理用户选择 +handle_choice() { + local choice=$1 + + case $choice in + 1) + show_data_processing_menu + echo -e "${YELLOW}按回车键返回主菜单...${NC}" + read + ;; + 2) + show_visualization_menu + echo -e "${YELLOW}按回车键返回主菜单...${NC}" + read + ;; + 3) + show_utils_menu + echo -e "${YELLOW}按回车键返回主菜单...${NC}" + read + ;; + 4) + show_templates_menu + echo -e "${YELLOW}按回车键返回主菜单...${NC}" + read + ;; + 5) + show_training_menu + echo -e "${YELLOW}按回车键返回主菜单...${NC}" + read + ;; + 6) + show_directory_structure + echo -e "${YELLOW}按回车键返回主菜单...${NC}" + read + ;; + 7) + show_quick_commands + echo -e "${YELLOW}按回车键返回主菜单...${NC}" + read + ;; + 8) + show_help + echo -e "${YELLOW}按回车键返回主菜单...${NC}" + read + ;; + 0) + echo -e "${GREEN}感谢使用!再见!${NC}" + exit 0 + ;; + *) + echo -e "${RED}无效选择,请重新输入!${NC}" + ;; + esac +} + +# 主循环 +main() { + while true; do + clear + show_title + show_menu + echo -e -n "${GREEN}请输入选择 (0-8): ${NC}" + read choice + handle_choice $choice + done +} + +# 检查是否直接运行 +if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then + main +fi diff --git a/script/templates/README.md b/script/templates/README.md new file mode 100644 index 0000000000..87759247c6 --- /dev/null +++ b/script/templates/README.md @@ -0,0 +1,132 @@ +# 模板文件 + +这个目录包含了各种脚本模板文件。 + +## 模板文件 + +### 脚本模板 +- **`improved_script_template.sh`** - 改进的训练脚本模板 + +## 模板说明 + +### 1. 改进的训练脚本模板 (improved_script_template.sh) + +这是一个功能完整的训练脚本模板,包含以下特性: + +#### 核心功能 +- **环境变量管理**: 自动设置和导出必要的环境变量 +- **量化类型控制**: 动态修改量化类型设置 +- **日志记录**: 带时间戳的详细日志记录 +- **错误处理**: 完善的错误检查和处理机制 +- **参数验证**: 输入参数的验证和默认值设置 + +#### 主要特性 +1. **HOST_TENSORBOARD_LOGS_PATH**: 自动设置tensorboard日志路径 +2. **量化类型修改**: 使用sed命令动态修改量化类型 +3. **命令构建**: 智能构建训练命令 +4. **日志捕获**: 使用tee命令捕获训练日志 +5. **时间戳**: 自动添加时间戳到日志文件名 + +#### 使用方法 +```bash +# 直接使用模板 +cp improved_script_template.sh my_training_script.sh +chmod +x my_training_script.sh + +# 修改参数后运行 +./my_training_script.sh +``` + +#### 参数说明 +脚本接受以下参数: +1. **CHECKPOINT_PATH** - 检查点保存路径 +2. **TENSORBOARD_LOGS_PATH** - TensorBoard日志路径 +3. **MODEL_PATH** - 模型路径 +4. **DATA_PATH** - 数据路径 +5. **PRECISION** - 精度类型 (bf16, fp16, fp8等) + +#### 环境变量 +- `HOST_TENSORBOARD_LOGS_PATH`: TensorBoard日志路径 +- `CUSTOM_QUANT_TYPE`: 自定义量化类型 +- `TENSOR_SAVE_DIR`: Tensor保存目录 +- `TENSOR_SAVE_ENABLED`: 是否启用tensor保存 + +## 模板定制 + +### 1. 修改默认参数 +```bash +# 在模板中修改默认值 +DEFAULT_CHECKPOINT_PATH="./checkpoints/my_model" +DEFAULT_TENSORBOARD_LOGS_PATH="./tensorboard_logs/my_model" +DEFAULT_MODEL_PATH="./model/my_model" +DEFAULT_DATA_PATH="./dataset/my_data" +DEFAULT_PRECISION="bf16" +``` + +### 2. 添加新的量化类型 +```bash +# 在modify_quantization_types函数中添加 +if [ "$QUANT_TYPE" = "new_quant_type" ]; then + sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'new_quant_type'/" \ + megatron/core/tensor_parallel/layers.py + sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'new_quant_type'/" \ + megatron/core/transformer/dot_product_attention.py +fi +``` + +### 3. 修改训练命令 +```bash +# 在build_and_run_command函数中修改 +TRAINING_CMD="bash examples/my_model/train_my_model.sh" +``` + +## 最佳实践 + +### 1. 脚本命名 +- 使用描述性的文件名 +- 包含模型名称、数据集、量化类型等信息 +- 例如: `pretrain_llama32-1b_dolma_hifp8.sh` + +### 2. 参数验证 +- 检查必需参数是否存在 +- 验证路径是否有效 +- 设置合理的默认值 + +### 3. 错误处理 +- 使用`set -e`在错误时退出 +- 检查命令执行结果 +- 提供清晰的错误信息 + +### 4. 日志记录 +- 记录开始和结束时间 +- 保存命令执行日志 +- 使用时间戳命名日志文件 + +## 扩展模板 + +### 1. 创建新的模板 +```bash +# 基于现有模板创建新模板 +cp improved_script_template.sh new_template.sh + +# 修改新模板的内容 +# 添加新的功能或修改现有功能 +``` + +### 2. 模板版本管理 +- 使用版本号管理模板 +- 记录模板的修改历史 +- 保持向后兼容性 + +### 3. 模板测试 +- 使用小数据集测试模板 +- 验证所有功能正常工作 +- 检查日志输出是否正确 + +## 注意事项 + +1. **权限设置**: 确保模板文件有执行权限 +2. **路径检查**: 验证所有路径是否正确 +3. **环境依赖**: 确保所需环境已正确设置 +4. **资源限制**: 根据系统资源调整参数 +5. **备份策略**: 定期备份重要的模板文件 diff --git a/script/templates/improved_script_template.sh b/script/templates/improved_script_template.sh new file mode 100755 index 0000000000..20aafd0496 --- /dev/null +++ b/script/templates/improved_script_template.sh @@ -0,0 +1,435 @@ +#!/bin/bash + +# ============================================================================= +# Enhanced Training Script Template with Timestamped Logging +# ============================================================================= + +# Set script name and version +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Logging Functions +# ============================================================================= + +# Function to get current timestamp +get_timestamp() { + date '+%Y-%m-%d %H:%M:%S' +} + +# Function to log with timestamp +log_info() { + echo "[$(get_timestamp)] [INFO] $*" +} + +log_warn() { + echo "[$(get_timestamp)] [WARN] $*" >&2 +} + +log_error() { + echo "[$(get_timestamp)] [ERROR] $*" >&2 +} + +log_success() { + echo "[$(get_timestamp)] [SUCCESS] $*" +} + +# Function to log script execution +log_script_start() { + log_info "==========================================" + log_info "Script: $SCRIPT_NAME v$SCRIPT_VERSION" + log_info "Started at: $START_TIME" + log_info "PID: $$" + log_info "User: $(whoami)" + log_info "Host: $(hostname)" + log_info "Working Directory: $(pwd)" + log_info "==========================================" +} + +log_script_end() { + local end_time=$(get_timestamp) + local duration=$(($(date +%s) - $(date -d "$START_TIME" +%s))) + log_info "==========================================" + log_info "Script completed at: $end_time" + log_info "Total execution time: ${duration}s" + log_info "==========================================" +} + +# Function to log configuration +log_config() { + log_info "Configuration Summary:" + log_info " Model: $MODEL_NAME" + log_info " Dataset: $DATASET_NAME" + log_info " Quantization: $QUANT_TYPE" + log_info " Experiment: $EXPERIMENT_NAME" + log_info " Checkpoint Path: $CHECKPOINT_PATH" + log_info " TensorBoard Path: $TENSORBOARD_PATH" + log_info " Data Path: $DATA_PATH" + log_info " Tokenizer Path: $TOKENIZER_PATH" +} + +# Function to log training parameters +log_training_params() { + log_info "Training Parameters:" + log_info " Micro Batch Size: $MICRO_BATCH_SIZE" + log_info " Global Batch Size: $GLOBAL_BATCH_SIZE" + log_info " Sequence Length: $SEQ_LENGTH" + log_info " Learning Rate: $LR" + log_info " Min Learning Rate: $MIN_LR" + log_info " Train Samples: $TRAIN_SAMPLES" + log_info " Exit Duration: ${EXIT_DURATION_MINS} minutes" +} + +# Function to log model parameters +log_model_params() { + log_info "Model Parameters:" + log_info " Tensor Parallel Size: $TP_SIZE" + log_info " Context Parallel Size: $CP_SIZE" + log_info " Pipeline Parallel Size: $PP_SIZE" + log_info " Number of Layers: $NUM_LAYERS" + log_info " Hidden Size: $HIDDEN_SIZE" + log_info " FFN Hidden Size: $FFN_HIDDEN_SIZE" + log_info " Number of Attention Heads: $NUM_ATTENTION_HEADS" + log_info " Number of Query Groups: $NUM_QUERY_GROUPS" + log_info " KV Channels: $KV_CHANNELS" + log_info " Rotary Base: $ROTARY_BASE" + log_info " Vocabulary Size: $VOCAB_SIZE" +} + +# Function to validate paths +validate_paths() { + log_info "Validating paths..." + + # Check if data path exists + if [[ ! -d "$(dirname "$DATA_PATH")" ]]; then + log_error "Data directory does not exist: $(dirname "$DATA_PATH")" + return 1 + fi + + # Check if tokenizer path exists + if [[ ! -d "$TOKENIZER_PATH" ]]; then + log_error "Tokenizer directory does not exist: $TOKENIZER_PATH" + return 1 + fi + + # Create checkpoint directory if it doesn't exist + if [[ ! -d "$(dirname "$CHECKPOINT_PATH")" ]]; then + log_info "Creating checkpoint directory: $(dirname "$CHECKPOINT_PATH")" + mkdir -p "$(dirname "$CHECKPOINT_PATH")" + fi + + # Create tensorboard directory if it doesn't exist + if [[ ! -d "$(dirname "$TENSORBOARD_PATH")" ]]; then + log_info "Creating tensorboard directory: $(dirname "$TENSORBOARD_PATH")" + mkdir -p "$(dirname "$TENSORBOARD_PATH")" + fi + + log_success "Path validation completed" + return 0 +} + +# Function to check system resources +check_system_resources() { + log_info "Checking system resources..." + + # Check available memory + local available_memory=$(free -h | awk '/^Mem:/ {print $7}') + log_info "Available memory: $available_memory" + + # Check disk space + local disk_usage=$(df -h . | awk 'NR==2 {print $4}') + log_info "Available disk space: $disk_usage" + + # Check GPU availability + if command -v nvidia-smi &> /dev/null; then + log_info "GPU Information:" + nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader,nounits | while read line; do + log_info " $line" + done + else + log_warn "nvidia-smi not found, GPU information unavailable" + fi + + log_success "System resource check completed" +} + +# Function to modify quantization types using sed +modify_quantization_types() { + local quant_type="$1" + + log_info "Modifying quantization types to: $quant_type" + + # Modify linear layer quantization + if [[ -f "megatron/core/tensor_parallel/layers.py" ]]; then + sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'$quant_type'/" \ + megatron/core/tensor_parallel/layers.py + log_info "Modified linear layer quantization in layers.py" + else + log_warn "layers.py not found, skipping linear quantization modification" + fi + + # Modify attention quantization + if [[ -f "megatron/core/transformer/dot_product_attention.py" ]]; then + sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'$quant_type'/" \ + megatron/core/transformer/dot_product_attention.py + log_info "Modified attention quantization in dot_product_attention.py" + else + log_warn "dot_product_attention.py not found, skipping attention quantization modification" + fi + + log_success "Quantization type modifications completed" +} + +# Function to build and run command with enhanced logging +build_and_run_command() { + local training_config="$1" + local dry_run="$2" + + log_info "Building training command..." + log_info "Training config: $training_config" + log_info "Dry run: $dry_run" + + local cmd=( + "torchrun" "--nproc_per_node" "8" + "--nnodes" "1" "--node_rank" "0" + "--master_addr" "localhost" "--master_port" "6000" + "pretrain_gpt.py" + + # Model arguments + "--use-mcore-models" + "--num-layers" "$NUM_LAYERS" + "--hidden-size" "$HIDDEN_SIZE" + "--ffn-hidden-size" "$FFN_HIDDEN_SIZE" + "--num-attention-heads" "$NUM_ATTENTION_HEADS" + "--group-query-attention" + "--num-query-groups" "$NUM_QUERY_GROUPS" + "--kv-channels" "$KV_CHANNELS" + "--seq-length" "$SEQ_LENGTH" + "--max-position-embeddings" "$SEQ_LENGTH" + "--position-embedding-type" "rope" + "--rotary-base" "$ROTARY_BASE" + "--rotary-percent" "1.0" + "--attention-dropout" "0.0" + "--hidden-dropout" "0.0" + "--swiglu" + "--init-method-std" "0.0134" + "--attention-backend" "fused" + "--apply-layernorm-1p" + "--untie-embeddings-and-output-weights" + "--disable-bias-linear" + + # Training arguments + "--micro-batch-size" "$MICRO_BATCH_SIZE" + "--global-batch-size" "$GLOBAL_BATCH_SIZE" + "--train-samples" "$TRAIN_SAMPLES" + "--lr" "$LR" + "--min-lr" "$MIN_LR" + "--lr-decay-style" "cosine" + "--clip-grad" "1.0" + "--weight-decay" "0.1" + "--adam-beta1" "0.9" + "--adam-beta2" "0.95" + "--bf16" + "--grad-reduce-in-bf16" + "--cross-entropy-loss-fusion" + "--calculate-per-token-loss" + "--manual-gc" + "--empty-unused-memory-level" "1" + "--exit-duration-in-mins" "$EXIT_DURATION_MINS" + "--use-distributed-optimizer" + "--overlap-grad-reduce" + "--overlap-param-gather" + + # Model parallelism + "--tensor-model-parallel-size" "$TP_SIZE" + "--context-parallel-size" "$CP_SIZE" + "--pipeline-model-parallel-size" "$PP_SIZE" + "--sequence-parallel" + + # Data arguments + "--data-path" "$DATA_PATH" + "--tokenizer-type" "HuggingFaceTokenizer" + "--tokenizer-model" "$TOKENIZER_PATH" + "--vocab-size" "$VOCAB_SIZE" + "--split" "99,1,0" + "--no-create-attention-mask-in-dataloader" + "--num-workers" "1" + + # Logging and checkpointing + "--log-interval" "1" + "--eval-iters" "32" + "--eval-interval" "100" + "--save-interval" "1000" + "--log-throughput" + "--ckpt-format" "torch_dist" + "--distributed-timeout-minutes" "60" + "--save" "$CHECKPOINT_PATH" + "--load" "$CHECKPOINT_PATH" + "--tensorboard-dir" "$TENSORBOARD_PATH" + ) + + # Add quantization arguments + if [[ "$QUANT_TYPE" == "fp8" ]]; then + log_info "Adding FP8 quantization arguments..." + cmd+=( + "--fp8-format" "$FP8_FORMAT" + "--fp8-amax-history-len" "1024" + "--fp8-amax-compute-algo" "max" + ) + + if [[ "$LINEAR_QUANT" != "None" ]]; then + log_info "Adding linear quantization: $LINEAR_QUANT" + cmd+=("--linear-quantization" "$LINEAR_QUANT") + fi + if [[ "$ATTENTION_QUANT" != "None" ]]; then + log_info "Adding attention quantization: $ATTENTION_QUANT" + cmd+=("--attention-quantization" "$ATTENTION_QUANT") + fi + fi + + if [[ "$dry_run" == true ]]; then + log_info "Dry run mode - showing command:" + echo "${cmd[*]}" + log_success "Command generation completed (dry run)" + else + log_info "Starting training execution..." + log_info "Command: ${cmd[*]}" + + # Set up tensorboard logs path and timestamped logging + export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_PATH" + local log_file="${HOST_TENSORBOARD_LOGS_PATH}/training_${EXPERIMENT_NAME}_$(date +'%y-%m-%d_%H-%M-%S').log" + + log_info "Training logs will be saved to: $log_file" + + # Execute the command with timestamped logging + if "${cmd[@]}" 2>&1 | tee "$log_file"; then + log_success "Training completed successfully" + else + log_error "Training failed with exit code $?" + return 1 + fi + fi +} + +# Function to show usage with enhanced formatting +show_usage() { + cat << EOF +[$(get_timestamp)] [INFO] ========================================== +[$(get_timestamp)] [INFO] Training Script: $SCRIPT_NAME v$SCRIPT_VERSION +[$(get_timestamp)] [INFO] ========================================== + +Usage: $0 [OPTIONS] + +Training script for ${MODEL_NAME} on ${DATASET_NAME} with ${QUANT_TYPE} quantization. + +Options: + --dry-run Show command without executing + --training-config CONFIG Training configuration (standard|fast) + --help Show this help message + +Examples: + $0 # Run with default settings + $0 --dry-run # Show command without executing + $0 --training-config fast # Run with fast configuration + +Configuration: + Model: $MODEL_NAME + Dataset: $DATASET_NAME + Quantization: $QUANT_TYPE + Experiment: $EXPERIMENT_NAME + +[$(get_timestamp)] [INFO] ========================================== +EOF +} + +# Function to parse arguments with enhanced logging +parse_arguments() { + local training_config="standard" + local dry_run=false + + while [[ $# -gt 0 ]]; do + case $1 in + --dry-run) + dry_run=true + log_info "Dry run mode enabled" + shift + ;; + --training-config) + training_config="$2" + log_info "Training config set to: $training_config" + shift 2 + ;; + --help) + show_usage + exit 0 + ;; + *) + log_error "Unknown option: $1" + show_usage + exit 1 + ;; + esac + done + + log_info "Arguments parsed successfully" + log_info "Training config: $training_config" + log_info "Dry run: $dry_run" +} + +# Main function with comprehensive logging +main() { + # Set up error handling + set -e + trap 'log_error "Script failed at line $LINENO"' ERR + trap 'log_script_end' EXIT + + # Start logging + log_script_start + + # Parse arguments + parse_arguments "$@" + + # Log configuration + log_config + log_model_params + log_training_params + + # Validate paths + if ! validate_paths; then + log_error "Path validation failed" + exit 1 + fi + + # Check system resources + check_system_resources + + # Modify quantization types if needed + if [[ "$QUANT_TYPE" != "bf16" ]] && [[ "$QUANT_TYPE" != "fp16" ]]; then + # Extract quantization type from QUANT_TYPE (e.g., "linear_hifp8" -> "hifp8") + local quant_type_to_use="" + if [[ "$QUANT_TYPE" == *"hifp8"* ]]; then + quant_type_to_use="hifp8" + elif [[ "$QUANT_TYPE" == *"mxfp8"* ]]; then + quant_type_to_use="mxfp8" + elif [[ "$QUANT_TYPE" == *"mxfp4"* ]]; then + quant_type_to_use="mxfp4" + fi + + if [[ -n "$quant_type_to_use" ]]; then + modify_quantization_types "$quant_type_to_use" + fi + fi + + # Build and run command + if ! build_and_run_command "$training_config" "$dry_run"; then + log_error "Command execution failed" + exit 1 + fi + + log_success "Script completed successfully" +} + +# Run main function with all arguments +main "$@" diff --git a/script/utils/README.md b/script/utils/README.md new file mode 100644 index 0000000000..c760822e5f --- /dev/null +++ b/script/utils/README.md @@ -0,0 +1,171 @@ +# 工具脚本 + +这个目录包含了各种辅助工具和实用脚本。 + +## 脚本文件 + +### 工具脚本 +- **`quant_type_modifier.py`** - 量化类型修改工具 +- **`update_scripts_with_pattern_v2.py`** - 脚本模式更新工具 + +## 功能说明 + +### 1. 量化类型修改工具 (quant_type_modifier.py) + +用于批量修改脚本中的量化类型设置。 + +#### 功能特性 +- 批量修改多个脚本文件 +- 支持多种量化类型 (hifp8, mxfp8, mxfp4, bf16, fp16) +- 自动备份原文件 +- 支持正则表达式匹配 + +#### 使用方法 +```bash +# 修改单个文件 +python quant_type_modifier.py \ + --file script/llama32-1b/pretrain_llama32-1b_dolma_bf16.sh \ + --old_quant_type bf16 \ + --new_quant_type hifp8 + +# 批量修改目录下所有文件 +python quant_type_modifier.py \ + --directory script/llama32-1b/ \ + --old_quant_type bf16 \ + --new_quant_type hifp8 + +# 使用正则表达式匹配 +python quant_type_modifier.py \ + --directory script/ \ + --pattern ".*_bf16\.sh$" \ + --new_quant_type hifp8 +``` + +#### 参数说明 +- `--file`: 要修改的单个文件路径 +- `--directory`: 要修改的目录路径 +- `--pattern`: 文件匹配的正则表达式 +- `--old_quant_type`: 旧的量化类型 +- `--new_quant_type`: 新的量化类型 +- `--backup`: 是否创建备份文件 (默认: true) +- `--dry_run`: 预览模式,不实际修改文件 + +### 2. 脚本模式更新工具 (update_scripts_with_pattern_v2.py) + +用于批量更新训练脚本,应用统一的模式。 + +#### 功能特性 +- 自动识别脚本类型和量化类型 +- 应用统一的脚本模式 +- 支持多种模型类型 (llama32-1b, llama31-8b, deepseek2_lite) +- 自动生成新的脚本文件 + +#### 使用方法 +```bash +# 更新所有脚本 +python update_scripts_with_pattern_v2.py + +# 更新特定目录 +python update_scripts_with_pattern_v2.py \ + --target_dir script/llama32-1b/ + +# 预览模式 +python update_scripts_with_pattern_v2.py \ + --dry_run +``` + +#### 参数说明 +- `--target_dir`: 目标目录路径 +- `--template_file`: 模板文件路径 +- `--dry_run`: 预览模式,不实际修改文件 +- `--backup`: 是否创建备份文件 + +## 使用场景 + +### 1. 量化类型切换 +当需要批量切换量化类型时: +```bash +# 将所有bf16脚本改为hifp8 +python quant_type_modifier.py \ + --directory script/ \ + --old_quant_type bf16 \ + --new_quant_type hifp8 +``` + +### 2. 脚本模式统一 +当需要应用新的脚本模式时: +```bash +# 应用新的脚本模式 +python update_scripts_with_pattern_v2.py +``` + +### 3. 批量文件操作 +当需要对大量文件进行相同操作时: +```bash +# 批量修改特定模式的文件 +python quant_type_modifier.py \ + --directory script/ \ + --pattern ".*_mxfp4\.sh$" \ + --new_quant_type mxfp8 +``` + +## 注意事项 + +### 1. 备份文件 +- 工具会自动创建备份文件 +- 备份文件以`.bak`后缀命名 +- 建议在批量操作前手动备份重要文件 + +### 2. 文件权限 +- 确保脚本有读取和写入权限 +- 使用`chmod +x`设置执行权限 + +### 3. 正则表达式 +- 使用正确的正则表达式语法 +- 测试模式匹配是否正确 +- 使用`--dry_run`预览结果 + +### 4. 错误处理 +- 检查文件路径是否正确 +- 确保目标文件存在 +- 查看错误日志定位问题 + +## 故障排除 + +### 常见问题 +1. **文件不存在**: 检查文件路径是否正确 +2. **权限不足**: 使用`chmod`设置正确权限 +3. **正则表达式错误**: 检查正则表达式语法 +4. **编码问题**: 确保文件编码为UTF-8 + +### 调试技巧 +- 使用`--dry_run`预览操作结果 +- 先在小范围文件上测试 +- 查看详细的错误信息 +- 检查备份文件是否正确创建 + +## 扩展功能 + +### 自定义修改规则 +可以修改脚本以支持自定义的修改规则: + +```python +# 在quant_type_modifier.py中添加自定义规则 +def custom_modification(content, old_type, new_type): + # 自定义修改逻辑 + return modified_content +``` + +### 批量操作 +可以结合使用多个工具进行复杂的批量操作: + +```bash +# 1. 先更新脚本模式 +python update_scripts_with_pattern_v2.py + +# 2. 再修改量化类型 +python quant_type_modifier.py \ + --directory script/ \ + --old_quant_type bf16 \ + --new_quant_type hifp8 +``` diff --git a/script/utils/quant_type_modifier.py b/script/utils/quant_type_modifier.py new file mode 100644 index 0000000000..71b2ac3d14 --- /dev/null +++ b/script/utils/quant_type_modifier.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +""" +Script to modify custom_quant_type in Megatron-LM source code +""" + +import os +import re +import argparse +import shutil +from pathlib import Path + +def backup_file(file_path): + """Create a backup of the original file.""" + backup_path = f"{file_path}.backup" + if not os.path.exists(backup_path): + shutil.copy2(file_path, backup_path) + print(f"Created backup: {backup_path}") + else: + print(f"Backup already exists: {backup_path}") + +def modify_quant_type(file_path, line_number, new_quant_type): + """Modify custom_quant_type in a specific line of a file.""" + if not os.path.exists(file_path): + print(f"Error: File {file_path} does not exist") + return False + + # Create backup + backup_file(file_path) + + with open(file_path, 'r') as f: + lines = f.readlines() + + # Check if line number is valid + if line_number > len(lines): + print(f"Error: Line {line_number} does not exist in {file_path}") + return False + + # Get the original line + original_line = lines[line_number - 1].strip() + print(f"Original line {line_number}: {original_line}") + + # Modify the specific line + modified_line = re.sub( + r"custom_quant_type = '[^']*'", + f"custom_quant_type = '{new_quant_type}'", + lines[line_number - 1] + ) + + lines[line_number - 1] = modified_line + + with open(file_path, 'w') as f: + f.writelines(lines) + + print(f"Modified line {line_number}: {modified_line.strip()}") + return True + +def restore_backup(file_path): + """Restore from backup file.""" + backup_path = f"{file_path}.backup" + if os.path.exists(backup_path): + shutil.copy2(backup_path, file_path) + print(f"Restored from backup: {backup_path}") + else: + print(f"No backup found: {backup_path}") + +def main(): + parser = argparse.ArgumentParser(description="Modify custom_quant_type in Megatron-LM source code") + parser.add_argument("--linear-quant", choices=['hifp8', 'mxfp8', 'mxfp4', 'none'], + help="Linear layer quantization type") + parser.add_argument("--qk-quant", choices=['hifp8', 'mxfp8', 'mxfp4', 'none'], + help="QK attention quantization type") + parser.add_argument("--pv-quant", choices=['hifp8', 'mxfp8', 'mxfp4', 'none'], + help="PV attention quantization type") + parser.add_argument("--restore", action="store_true", help="Restore from backup files") + parser.add_argument("--check", action="store_true", help="Check current quantization types") + + args = parser.parse_args() + + # Get Megatron-LM root directory + script_dir = Path(__file__).parent + megatron_root = script_dir.parent + + # File paths + layers_file = megatron_root / "megatron/core/tensor_parallel/layers.py" + attention_file = megatron_root / "megatron/core/transformer/dot_product_attention.py" + + if args.restore: + print("Restoring from backup files...") + restore_backup(layers_file) + restore_backup(attention_file) + return 0 + + if args.check: + print("Checking current quantization types...") + print(f"\nLinear layer ({layers_file}):") + with open(layers_file, 'r') as f: + lines = f.readlines() + if len(lines) >= 783: + print(f" Line 783: {lines[782].strip()}") + + print(f"\nAttention layer ({attention_file}):") + with open(attention_file, 'r') as f: + lines = f.readlines() + if len(lines) >= 166: + print(f" Line 166 (QK): {lines[165].strip()}") + if len(lines) >= 238: + print(f" Line 238 (PV): {lines[237].strip()}") + return 0 + + # Modify quantization types + success = True + + if args.linear_quant: + print(f"Modifying linear layer quantization to {args.linear_quant}...") + success &= modify_quant_type(layers_file, 783, args.linear_quant) + + if args.qk_quant: + print(f"Modifying QK attention quantization to {args.qk_quant}...") + success &= modify_quant_type(attention_file, 166, args.qk_quant) + + if args.pv_quant: + print(f"Modifying PV attention quantization to {args.pv_quant}...") + success &= modify_quant_type(attention_file, 238, args.pv_quant) + + if not args.linear_quant and not args.qk_quant and not args.pv_quant: + print("No quantization type specified. Use --help for usage information.") + return 1 + + if success: + print("\n✅ All modifications completed successfully!") + print("\nTo verify changes, run:") + print(" python3 modify_quant_type.py --check") + print("\nTo restore original files, run:") + print(" python3 modify_quant_type.py --restore") + else: + print("\n❌ Some modifications failed!") + return 1 + + return 0 + +if __name__ == "__main__": + exit(main()) diff --git a/script/utils/update_scripts_with_pattern_v2.py b/script/utils/update_scripts_with_pattern_v2.py new file mode 100644 index 0000000000..fb950d38a0 --- /dev/null +++ b/script/utils/update_scripts_with_pattern_v2.py @@ -0,0 +1,238 @@ +#!/usr/bin/env python3 +""" +Script to update all training scripts with the new pattern: +1. Export HOST_TENSORBOARD_LOGS_PATH +2. Use sed commands to modify quantization types +3. Execute training with timestamped logging + +This version handles llama31-8b and deepseek2_lite directories. +""" + +import os +import re +import shutil +from pathlib import Path + +def backup_file(file_path): + """Create a backup of the original file.""" + backup_path = f"{file_path}.backup_pattern_update_v2" + if not os.path.exists(backup_path): + shutil.copy2(file_path, backup_path) + print(f"Created backup: {backup_path}") + else: + print(f"Backup already exists: {backup_path}") + +def extract_quant_type_from_filename(filename): + """Extract quantization type from filename.""" + filename_str = str(filename) + if "hifp8" in filename_str: + return "hifp8" + elif "mxfp8" in filename_str: + return "mxfp8" + elif "mxfp4" in filename_str: + return "mxfp4" + else: + return "bf16" # Default + +def extract_model_name_from_filename(filename): + """Extract model name from filename.""" + filename_str = str(filename) + if "llama31-8b" in filename_str: + return "llama31-8b" + elif "deepseek2_lite" in filename_str: + return "deepseek2_lite" + else: + return "unknown" + +def get_training_script_path(model_name): + """Get the appropriate training script path based on model name.""" + if model_name == "llama31-8b": + return "examples/llama/train_llama3_8b_h100_fp8.sh" # 注意是 llama3 不是 llama31 + elif model_name == "deepseek2_lite": + return "examples/deepseek2_lite/train_deepseek2_lite_h100_fp8.sh" # 使用新创建的 deepseek2_lite 脚本 + else: + return "examples/llama/train_llama32_1b_h100_fp8.sh" # Default + +def update_script_with_pattern(script_path): + """Update a script with the new pattern.""" + print(f"Updating script: {script_path}") + + # Create backup + backup_file(script_path) + + # Extract information from filename + quant_type = extract_quant_type_from_filename(script_path) + model_name = extract_model_name_from_filename(script_path) + training_script = get_training_script_path(model_name) + + # Create new script content following the pattern + script_name = os.path.basename(str(script_path)) + new_content = f'''#!/bin/bash + +# ============================================================================= +# Training Script for {model_name.upper()} - Updated with new pattern +# Script: {script_name} +# Quantization Type: {quant_type} +# ============================================================================= + +# Set script metadata +SCRIPT_NAME="$(basename "$0")" +SCRIPT_VERSION="1.0.0" +START_TIME=$(date '+%Y-%m-%d %H:%M:%S') +START_TIMESTAMP=$(date '+%Y%m%d_%H%M%S') + +# ============================================================================= +# Configuration Parameters +# ============================================================================= + +# Parse command line arguments +CHECKPOINT_PATH=${{1:-"checkpoints/{model_name}/{script_name.replace('.sh', '')}"}} +TENSORBOARD_LOGS_PATH=${{2:-"tensorboard_logs/{model_name}_{quant_type}"}} +TOKENIZER_ARG=${{3:-"model/{model_name}"}} +DATA_ARG=${{4:-"dataset/wikipedia_processed/wikipedia_processed_text_document"}} +DTYPE=${{5:-"bf16"}} + +# ============================================================================= +# Environment Setup +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training script: $SCRIPT_NAME" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint Path: $CHECKPOINT_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard Path: $TENSORBOARD_LOGS_PATH" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Tokenizer Path: $TOKENIZER_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Path: $DATA_ARG" +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Data Type: $DTYPE" + +# Export tensorboard logs path +export HOST_TENSORBOARD_LOGS_PATH="$TENSORBOARD_LOGS_PATH" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# ============================================================================= +# Quantization Type Modification +# =============================================================================''' + + # Add quantization modification if not bf16 + if quant_type != "bf16": + new_content += f''' + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Modifying quantization types to {quant_type}..." + +# Modify linear layer quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'{quant_type}'/" \\ + megatron/core/tensor_parallel/layers.py + +# Modify attention quantization +sed -i "s/^\([[:space:]]*custom_quant_type[[:space:]]*=[[:space:]]*\)'[^']*'/\1'{quant_type}'/" \\ + megatron/core/transformer/dot_product_attention.py + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Quantization type modifications completed"''' + else: + new_content += ''' + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] BF16 training - no quantization modification needed"''' + + # Add training execution + new_content += f''' + +# ============================================================================= +# Training Execution +# ============================================================================= + +echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Starting training execution..." + +# Execute the training script with timestamped logging +bash {training_script} \\ + "$CHECKPOINT_PATH" \\ + "$TENSORBOARD_LOGS_PATH" \\ + "$TOKENIZER_ARG" \\ + "$DATA_ARG" \\ + "$DTYPE" \\ + 2>&1 | tee "${{HOST_TENSORBOARD_LOGS_PATH}}/training_{script_name.replace('.sh', '')}_$(date +'%y-%m-%d_%H-%M-%S').log" + +TRAINING_EXIT_CODE=${{PIPESTATUS[0]}} + +# ============================================================================= +# Finalization +# ============================================================================= + +if [[ $TRAINING_EXIT_CODE -eq 0 ]]; then + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [SUCCESS] Training completed successfully" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] Checkpoint saved to: $CHECKPOINT_PATH" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [INFO] TensorBoard logs saved to: $TENSORBOARD_LOGS_PATH" +else + echo "[$(date '+%Y-%m-%d %H:%M:%S')] [ERROR] Training failed with exit code: $TRAINING_EXIT_CODE" +fi + +exit $TRAINING_EXIT_CODE +''' + + # Write the new content + with open(str(script_path), 'w') as f: + f.write(new_content) + + # Make the script executable + os.chmod(str(script_path), 0o755) + + print(f"Updated script: {script_path}") + +def update_directory_scripts(directory_path): + """Update all scripts in a directory.""" + if not directory_path.exists(): + print(f"Error: Directory {directory_path} does not exist") + return 0 + + # Find all .sh files in the directory + script_files = list(directory_path.glob("*.sh")) + + # Filter out backup files + script_files = [f for f in script_files if not f.name.endswith('.backup')] + + print(f"\nFound {len(script_files)} scripts to update in {directory_path.name}:") + for script_file in script_files: + print(f" - {script_file.name}") + + # Update each script + updated_count = 0 + for script_file in script_files: + try: + update_script_with_pattern(script_file) + updated_count += 1 + except Exception as e: + print(f"Error updating {script_file}: {e}") + + print(f"\n✅ Updated {updated_count} scripts in {directory_path.name}") + return updated_count + +def main(): + """Main function to update all scripts.""" + script_dir = Path(__file__).parent + + # Directories to update + directories_to_update = [ + script_dir / "llama31-8b", + script_dir / "deepseek2_lite" + ] + + total_updated = 0 + + for directory in directories_to_update: + if directory.exists(): + updated_count = update_directory_scripts(directory) + total_updated += updated_count + else: + print(f"Warning: Directory {directory} does not exist, skipping...") + + print(f"\n🎉 Total updated {total_updated} scripts across all directories!") + print("\nThe updated scripts now include:") + print(" 1. HOST_TENSORBOARD_LOGS_PATH export") + print(" 2. sed commands to modify quantization types") + print(" 3. Timestamped logging with tee command") + print(" 4. Model-specific training script paths") + + return 0 + +if __name__ == "__main__": + exit(main()) diff --git a/tests/functional_tests/python_test_utils/test_inference_regular_pipeline.py b/tests/functional_tests/python_test_utils/test_inference_regular_pipeline.py deleted file mode 100644 index e0286695e6..0000000000 --- a/tests/functional_tests/python_test_utils/test_inference_regular_pipeline.py +++ /dev/null @@ -1,77 +0,0 @@ -import json -import logging -import math - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def test_inference_pipeline(golden_values_path: str, test_values_path: str) -> None: - - with open(golden_values_path, 'r') as f1, open(test_values_path, 'r') as f2: - golden_values_content = f1.read() - tensorboard_content = f2.read() - - output_groundtruth = json.loads(golden_values_content) - - if isinstance(output_groundtruth, str): - # Handle JSONL output, assume only one line in this case. - output_groundtruth = json.loads(output_groundtruth) - - output_current = json.loads(tensorboard_content) - if isinstance(output_current, str): - # Handle JSONL output, assume only one line in this case. - output_current = json.loads(output_current) - - assert set(output_groundtruth.keys()).issuperset( - set(output_current.keys()) - ), f"Some IDs from groundtruth are missing in current: {output_groundtruth.keys()} vs {output_current.keys()}" - if set(output_groundtruth.keys()) != set(output_current.keys()): - logger.warning( - f"Some IDs from groundtruth are missing in output, only the subset of ids in groundtruth will be tested: {output_groundtruth.keys()} vs {output_current.keys()}" - ) - assert len(output_groundtruth) > 0, "No test performed for output" - for request_id, groundtruth_results in output_groundtruth.items(): - current_results = output_current[request_id] - - at_least_one_test_loop = False - if "generated_tokens" in groundtruth_results: - at_least_one_test_loop = True - tokens_groundtruth = groundtruth_results["generated_tokens"] - tokens_current = current_results["generated_tokens"] - # Check token equality - assert ( - tokens_groundtruth == tokens_current - ), f"Token mismatch:\nGround truth: {tokens_groundtruth}\nCurrent: {tokens_current}" - - if "logprobs" in groundtruth_results: - at_least_one_test_loop = True - logprobs_groundtruth = groundtruth_results["logprobs"] - logprobs_current = current_results["logprobs"] - # Check logprobs length and tolerance - assert len(logprobs_groundtruth) == len( - logprobs_current - ), f"Logprobs length mismatch: {len(logprobs_groundtruth)} vs {len(logprobs_current)}" - - for i, (lp1, lp2) in enumerate(zip(logprobs_groundtruth, logprobs_current)): - assert math.isclose( - lp1, lp2, abs_tol=0.001 - ), f"Logprobs differ at index {i}: {lp1:.5f} vs {lp2:.5f}" - - if "generated_text" in groundtruth_results: - at_least_one_test_loop = True - generated_text_groundtruth = groundtruth_results["generated_text"] - generated_text_current = current_results["generated_text"] - min_len = min(len(generated_text_groundtruth), len(generated_text_current)) - assert min_len > 0, ( - "Generated text mismatch:" - f"\nGround truth: {generated_text_groundtruth}\nCurrent: {generated_text_current}" - ) - assert generated_text_groundtruth[:min_len] == generated_text_current[:min_len], ( - "Generated text mismatch:" - f"\nGround truth (truncated to {min_len} chars): {generated_text_groundtruth[:min_len]}" - f"\nCurrent (truncated to {min_len} chars): {generated_text_current[:min_len]}" - ) - - if not at_least_one_test_loop: - raise AssertionError(f"No test performed for output {groundtruth_results}") diff --git a/tests/functional_tests/python_test_utils/test_pretraining_regular_pipeline.py b/tests/functional_tests/python_test_utils/test_pretraining_regular_pipeline.py deleted file mode 100644 index 2887a65df8..0000000000 --- a/tests/functional_tests/python_test_utils/test_pretraining_regular_pipeline.py +++ /dev/null @@ -1,80 +0,0 @@ -import logging -from typing import Dict, List, Optional - -import yaml - -from tests.functional_tests.python_test_utils import common - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -CHECK_THRESHOLDS = { - "iteration-time": [common.ApproximateTest(atol=2.0, rtol=0)], - "mem-allocated-bytes": [ - common.ApproximateTest(atol_func=common.approximate_threshold(rtol=0.05), rtol=0) - ], - "mem-max-allocated-bytes": [ - common.ApproximateTest(atol_func=common.approximate_threshold(rtol=0.05), rtol=0) - ], - "lm loss": [ - common.DeterministicTest(), - common.ApproximateTest(atol_func=common.approximate_threshold(rtol=0.05), rtol=0), - ], - "mtp_1 loss": [ - common.DeterministicTest(), - common.ApproximateTest(atol_func=common.approximate_threshold(rtol=0.05), rtol=0), - ], - "num-zeros": [ - common.DeterministicTest(), - common.ApproximateTest(atol_func=common.approximate_threshold(rtol=0.20), rtol=0), - ], - "generated_tokens": [ - common.DeterministicTest(), - common.ApproximateTest(atol_func=common.approximate_threshold(rtol=0.05), rtol=0), - ], - "logprobs": [ - common.DeterministicTest(), - common.ApproximateTest(atol_func=common.approximate_threshold(rtol=0.05), rtol=0), - ], -} - - -def test_regular_pipeline( - compare_approximate_results: bool, - golden_values: Dict[str, common.GoldenValueMetric], - actual_values: Dict[str, common.GoldenValueMetric], - model_config_path: str, - checks: Optional[Dict[str, List[common.Test]]] = None, -): - if checks is None: - with open(model_config_path) as f: - model_config = yaml.safe_load(f) - - checks_types = ( - model_config["METRICS"] - if "METRICS" in model_config - else ["iteration-time", "lm loss", "num-zeros"] - ) - checks = {metric: CHECK_THRESHOLDS[metric] for metric in checks_types} - - if ( - len( - missing_metrics := [ - golden_metric - for golden_metric in checks.keys() - if golden_metric not in golden_values.keys() - ] - ) - > 0 - ): - logger.error( - f"The following metrics are required but not provided in golden values: {', '.join(missing_metrics)}" - ) - assert False - - common.pipeline( - compare_approximate_results=compare_approximate_results, - golden_values=golden_values, - actual_values=actual_values, - checks=checks, - ) diff --git a/tests/functional_tests/python_test_utils/test_pretraining_resume_checkpoint_pipeline.py b/tests/functional_tests/python_test_utils/test_pretraining_resume_checkpoint_pipeline.py deleted file mode 100644 index 64cbe0b9b5..0000000000 --- a/tests/functional_tests/python_test_utils/test_pretraining_resume_checkpoint_pipeline.py +++ /dev/null @@ -1,86 +0,0 @@ -import logging -from typing import Dict - -import yaml - -from tests.functional_tests.python_test_utils import common, test_pretraining_regular_pipeline - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def test_resume_checkpoint_pipeline( - compare_approximate_results: bool, - actual_values_first_run: Dict[str, common.GoldenValueMetric], - actual_values_second_run: Dict[str, common.GoldenValueMetric], - train_iters: int, - model_config_path: str, -): - with open(model_config_path) as f: - model_config = yaml.safe_load(f) - - checks_types = ( - model_config["METRICS"] - if "METRICS" in model_config - else ["iteration-time", "lm loss", "num-zeros"] - ) - checks = { - metric: test_pretraining_regular_pipeline.CHECK_THRESHOLDS[metric] - for metric in checks_types - } - - if ( - len( - missing_metrics := [ - golden_metric - for golden_metric in checks.keys() - if golden_metric not in actual_values_first_run.keys() - ] - ) - > 0 - ): - logger.error( - f"The following metrics are required but not logged during training: {', '.join(missing_metrics)}" - ) - assert False - - # actual_values_second_run is NaN for the first 50 steps. We want to replace those - # with the first 50 steps of actual_values_first_run - - actual_values_first_run = { - metric_name: metric_values - for (metric_name, metric_values) in actual_values_first_run.items() - if metric_name in checks.keys() - } - - actual_values_second_run = { - metric_name: metric_values - for (metric_name, metric_values) in actual_values_second_run.items() - if metric_name in checks.keys() - } - - for metric_name in checks.keys(): - actual_values_first_run[metric_name].start_step = train_iters // 2 + 1 - actual_values_first_run[metric_name].values = { - k: v - for k, v in actual_values_first_run[metric_name].values.items() - if k > train_iters // 2 - } - - actual_values_second_run[metric_name].start_step = train_iters // 2 + 1 - actual_values_second_run[metric_name].values = { - k: v - for k, v in actual_values_second_run[metric_name].values.items() - if k > train_iters // 2 - } - - logger.info(actual_values_first_run) - logger.info(actual_values_second_run) - - test_pretraining_regular_pipeline.test_regular_pipeline( - compare_approximate_results=compare_approximate_results, - golden_values=actual_values_first_run, - actual_values=actual_values_second_run, - checks=checks, - model_config_path=model_config_path, - ) diff --git a/tests/unit_tests/a2a_overlap/test_schedule_chunk_1f1b.py b/tests/unit_tests/a2a_overlap/test_schedule_chunk_1f1b.py deleted file mode 100644 index 2dd0f20fe2..0000000000 --- a/tests/unit_tests/a2a_overlap/test_schedule_chunk_1f1b.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -import gc - -import pytest -import torch - -from megatron.core.models.common.model_chunk_schedule_plan import TransformerModelChunkSchedulePlan -from megatron.core.models.gpt.gpt_layer_specs import ( - get_gpt_decoder_block_spec, - get_gpt_mtp_block_spec, -) -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.pipeline_parallel.utils import set_streams -from megatron.core.transformer.module import float16_to_fp32 -from megatron.core.utils import is_te_min_version -from tests.unit_tests.a2a_overlap.utils import ( - compare_captures, - deterministic_mode, - get_test_config, - get_valid_fp8_flags, - get_valid_token_dispatcher_types, -) -from tests.unit_tests.test_utilities import Utils - - -def build_model(config): - seq_len = 32 - max_seq_len = 300 - # ids = random.sample([i for i in range(max_seq_len)], seq_len) - ids = [i for i in range(seq_len)] - - # build input tensors - data = { - "input_ids": torch.tensor(ids, dtype=torch.int64).repeat((1, 1)).cuda(), - "labels": torch.tensor(ids, dtype=torch.int64).repeat((1, 1)).cuda(), - "position_ids": torch.tensor([i for i in range(seq_len)], dtype=torch.int64) - .repeat((1, 1)) - .cuda(), - "attention_mask": torch.ones((1, 1, seq_len, seq_len), dtype=bool).cuda(), - } - - # build layer spec - transformer_layer_spec = get_gpt_decoder_block_spec(config=config, use_transformer_engine=True) - mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec.layer_specs[-1], True) - - # build model - gpt_model = GPTModel( - config=config, - transformer_layer_spec=transformer_layer_spec, - mtp_block_spec=mtp_block_spec, - vocab_size=100, - pre_process=True, - post_process=True, - max_sequence_length=max_seq_len, - ) - f_schedule_plan = gpt_model.build_schedule_plan(**data) - return gpt_model, f_schedule_plan, data - - -class TestA2AOverlap: - """ - Test class for all-to-all overlap optimization in transformer models. - - This class contains tests to verify that the all-to-all overlap optimization - produces the same results as the reference implementation. - """ - - def setup_method(self, method): - Utils.initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, - expert_model_parallel_size=4, - ) - set_streams() - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.skipif(not is_te_min_version("1.9.0.dev0"), reason="Requires TE >= 1.9.0.dev0") - @pytest.mark.parametrize("mtp_layers", [0, 1]) - @pytest.mark.parametrize("dispatcher_type", get_valid_token_dispatcher_types()) - @pytest.mark.parametrize("fp8_flag", get_valid_fp8_flags()) - @pytest.mark.parametrize("layers", [[2, 1], [1, 2], [1, 1]]) - def test_1f1b_schedule_model_chunk(self, mtp_layers, dispatcher_type, fp8_flag, layers): - """ - Verifies all-to-all overlap optimization in transformer layer produces - the same results as the reference implementation. - """ - microbatches = 1 - - gpt_models = [] - schedule_plans = [] - ref_captures = [] - datas = [] - - # create TransformerConfig - extra_kwargs = {"moe_token_dispatcher_type": dispatcher_type} - if dispatcher_type == "flex": - extra_kwargs["moe_enable_deepep"] = True - extra_kwargs["moe_router_dtype"] = "fp32" - if fp8_flag is not None: - extra_kwargs["fp8"] = fp8_flag[0] - extra_kwargs["fp8_recipe"] = fp8_flag[1] - if mtp_layers > 0: - extra_kwargs["mtp_num_layers"] = mtp_layers - extra_kwargs["mtp_loss_scaling_factor"] = 1.1 - with deterministic_mode(): - for layer_num in layers: - output_tensors = [] - # build config - config = get_test_config(num_layers=layer_num, extra_kwargs=extra_kwargs) - # build model - gpt_model, schedule_plan, data = build_model(config) - gpt_model.cuda() - gpt_models.append(gpt_model) - datas.append(data) - schedule_plans.append(schedule_plan) - - # run reference - for _ in range(microbatches): - loss = gpt_model.forward(**data) - loss = float16_to_fp32(loss) - loss.backward(torch.ones_like(loss)) - output_tensors.append(loss) - - capture = {"outputs": output_tensors} - for name, param in gpt_model.named_parameters(): - capture[name] = param.grad - ref_captures.append(capture) - gpt_model.zero_grad() - assert gpt_models[0].embedding is not None - assert gpt_models[1].embedding is not None - # run a2a overlap - capture_0 = {"outputs": []} - capture_1 = {"outputs": []} - a2a_captures = [capture_0, capture_1] - for i in range(microbatches): - # 1st forward - if i > 0: - assert ( - schedule_plans[0].pre_process is None - ), "pre_process should be released after backward" - schedule_plans[0] = gpt_models[0].build_schedule_plan(**datas[0]) - schedule_plans[1] = gpt_models[1].build_schedule_plan(**datas[1]) - f_input_0 = TransformerModelChunkSchedulePlan.run(schedule_plans[0], None) - capture_0["outputs"].append(f_input_0) - # overlap - f_input_1 = TransformerModelChunkSchedulePlan.run( - schedule_plans[1], schedule_plans[0], b_grad=torch.ones_like(f_input_0) - ) - capture_1["outputs"].append(f_input_1) - # last backward - TransformerModelChunkSchedulePlan.run( - None, schedule_plans[1], b_grad=torch.ones_like(f_input_1) - ) - for i in range(len(gpt_models)): - for name, param in gpt_models[i].named_parameters(): - a2a_captures[i][name] = param.grad - - # compare results - for i in range(len(ref_captures)): - comp_res = compare_captures(ref_captures[i], a2a_captures[i], True, True) - assert comp_res[0], f"[rank {torch.distributed.get_rank()}] {comp_res[1]}" - - # release resources is necessary, otherwise later testcases will oom - for i in range(len(schedule_plans)): - schedule_plans[i] = None - ref_captures[i] = None - a2a_captures[i] = None - for k in datas[i]: - datas[i][k] = None - datas[i] = None - gpt_models[i].zero_grad() - gpt_models[i] = None - gc.collect() - torch.cuda.empty_cache() diff --git a/tests/unit_tests/a2a_overlap/test_schedule_layer_1f1b.py b/tests/unit_tests/a2a_overlap/test_schedule_layer_1f1b.py deleted file mode 100644 index 9b369410aa..0000000000 --- a/tests/unit_tests/a2a_overlap/test_schedule_layer_1f1b.py +++ /dev/null @@ -1,478 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -from contextlib import nullcontext - -import pytest -import torch - -from megatron.core.fp8_utils import get_fp8_context -from megatron.core.models.common.model_chunk_schedule_plan import TransformerLayerSchedulePlan -from megatron.core.models.gpt.gpt_layer_specs import ( - get_gpt_decoder_block_spec, - get_gpt_layer_with_transformer_engine_spec, - get_gpt_mtp_block_spec, -) -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.utils import is_te_min_version -from tests.unit_tests.a2a_overlap.utils import ( - DummyState, - build_data, - compare_captures, - deterministic_mode, - get_test_config, - get_valid_fp8_flags, - get_valid_token_dispatcher_types, - reset_model, -) -from tests.unit_tests.test_utilities import Utils - - -def run_transformer_layer_ref_with_capture(model, input_tensors, iterations): - """ - Runs the model in reference mode and captures outputs and gradients. - - Args: - model: The transformer model to run. - input_tensors: List of input tensors for each iteration. - iterations: Number of iterations to run the model. - - Returns: - dict: A dictionary containing model outputs and parameter gradients. - """ - - output_tensors = [] - for i in range(iterations): - output = model(input_tensors[i].clone())[0] - output_tensors.append(output) - output.backward(torch.ones_like(output)) - - capture = {"outputs": output_tensors} - for name, param in model.named_parameters(): - capture[name] = param.grad - - return capture - - -def run_transformer_layer_a2a_overlap_with_capture(model, input_tensors, microbatches): - """ - Runs the model with all-to-all overlap optimization and captures outputs and gradients. - - Args: - model: The transformer model to run. - input_tensors: List of input tensors for each microbatch. - microbatches: Number of microbatches to process. - - Returns: - dict: A dictionary containing model outputs and parameter gradients. - """ - for i in range(len(input_tensors)): - input_tensors[i] = input_tensors[i].clone() - - event = torch.cuda.Event() - comp_stream = torch.cuda.current_stream() - comm_stream = torch.cuda.Stream(device="cuda") - layers = [ - TransformerLayerSchedulePlan( - model, - event, - DummyState(), - comp_stream, - comm_stream, - extra_args={"is_moe": True, "enable_deepep": False}, - ) - for _ in range(microbatches) - ] - output_tensors = [] - - # forward for 1st microbatch - output, _ = TransformerLayerSchedulePlan.run( - layers[0], None, f_input=input_tensors[0], b_grad=None - ) - output_tensors.append(output) - torch.cuda.synchronize() - # overlapped forward and backward - for i in range(1, microbatches): - f_input, b_grad = TransformerLayerSchedulePlan.run( - layers[i], layers[i - 1], f_input=input_tensors[i], b_grad=torch.ones_like(output) - ) - output_tensors.append(f_input) - torch.cuda.synchronize() - # backward for last microbatch - TransformerLayerSchedulePlan.run(None, layers[-1], f_input=None, b_grad=torch.ones_like(output)) - torch.cuda.synchronize() - capture = {"outputs": output_tensors} - for name, param in model.named_parameters(): - capture[name] = param.grad - - return capture - - -def run_mtp_layer_ref_with_capture( - model, - hidden_states, - input_ids, - position_ids, - labels, - attention_mask, - rotary_pos_emb, - rotary_pos_cos, - rotary_pos_sin, - microbatches, -): - """ - Runs the model in reference mode and captures outputs and gradients. - - Args: - model: The transformer model to run. - input_tensors: List of input tensors for each iteration. - iterations: Number of iterations to run the model. - - Returns: - dict: A dictionary containing model outputs and parameter gradients. - """ - mtp_block = model.mtp - - output_tensors = [] - for i in range(microbatches): - output = mtp_block( - input_ids=input_ids, - position_ids=position_ids, - hidden_states=hidden_states[i].clone(), - attention_mask=attention_mask, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - embedding=model.embedding, - ) - output_tensors.append(output) - output.backward(torch.ones_like(output)) - - capture = {"outputs": output_tensors} - for name, param in model.named_parameters(): - capture[name] = param.grad - - return capture - - -def run_mtp_layer_a2a_overlap_with_capture( - model, - hidden_states, - input_ids, - position_ids, - labels, - attention_mask, - rotary_pos_emb, - rotary_pos_cos, - rotary_pos_sin, - microbatches, -): - """ - Runs the model with all-to-all overlap optimization and captures outputs and gradients. - - Args: - model: The transformer model to run. - input_tensors: List of input tensors for each microbatch. - microbatches: Number of microbatches to process. - - Returns: - dict: A dictionary containing model outputs and parameter gradients. - """ - for i in range(len(hidden_states)): - hidden_states[i] = hidden_states[i].clone() - - comp_stream = torch.cuda.current_stream() - comm_stream = torch.cuda.Stream(device="cuda") - layers = [] - for _ in range(microbatches): - state = DummyState() - state.mtp_labels = labels - state.input_ids = input_ids - state.position_ids = position_ids - state.attention_mask = attention_mask - state.rotary_pos_emb = rotary_pos_emb - state.rotary_pos_cos = rotary_pos_cos - state.rotary_pos_sin = rotary_pos_sin - state.model = model - event = torch.cuda.Event() - layers.append( - TransformerLayerSchedulePlan( - model.mtp.layers[0], - event, - state, - comp_stream, - comm_stream, - extra_args={ - "is_moe": True, - "enable_deepep": False, - "is_first_layer": True, - "is_last_layer": True, - }, - ) - ) - output_tensors = [] - # forward for 1st microbatch - f_input, _ = TransformerLayerSchedulePlan.run( - layers[0], None, f_input=hidden_states[0], b_grad=None - ) - output_tensors.append(f_input) - torch.cuda.synchronize() - # overlapped forward and backward - for i in range(1, microbatches): - f_input, b_grad = TransformerLayerSchedulePlan.run( - layers[i], layers[i - 1], f_input=hidden_states[i], b_grad=torch.ones_like(f_input) - ) - output_tensors.append(f_input) - torch.cuda.synchronize() - # backward for last microbatch - TransformerLayerSchedulePlan.run( - None, layers[-1], f_input=None, b_grad=torch.ones_like(f_input) - ) - torch.cuda.synchronize() - capture = {"outputs": output_tensors} - for name, param in model.named_parameters(): - capture[name] = param.grad - - return capture - - -class TestA2AOverlap: - """ - Test class for all-to-all overlap optimization in transformer models. - - This class contains tests to verify that the all-to-all overlap optimization - produces the same results as the reference implementation. - """ - - def setup_method(self, method): - Utils.initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, - expert_model_parallel_size=4, - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.skipif(not is_te_min_version("1.9.0.dev0"), reason="Requires TE >= 1.9.0.dev0") - def test_transformer_layer_overlap_dense(self): - """ - Verifies all-to-all overlap optimization in dense transformer layer produces - the same results as the reference implementation. - """ - extra_kwargs = {"moe_token_dispatcher_type": "alltoall"} - config = get_test_config(num_moe_experts=None, extra_kwargs=extra_kwargs) - microbatches = 4 - with deterministic_mode(): - transformer_layer_spec = get_gpt_decoder_block_spec( - config=config, use_transformer_engine=True - ) - gpt_model = GPTModel( - config=config, - transformer_layer_spec=transformer_layer_spec, - vocab_size=100, - pre_process=True, - post_process=True, - max_sequence_length=300, - ) - model = gpt_model.decoder.layers[0] - - params = reset_model(gpt_model) - input_tensors = [build_data() for _ in range(microbatches)] - - fp8_context = ( - get_fp8_context(config, model.layer_number - 1) if config.fp8 else nullcontext() - ) - with fp8_context: - capture_ref = run_transformer_layer_ref_with_capture( - model, input_tensors, microbatches - ) - reset_model(gpt_model, params) - capture_a2a_overlap = run_transformer_layer_a2a_overlap_with_capture( - model, input_tensors, microbatches - ) - comp_res = compare_captures(capture_ref, capture_a2a_overlap, True) - assert comp_res[0], f"[rank {torch.distributed.get_rank()}] {comp_res[1]}" - - @pytest.mark.skipif(not is_te_min_version("1.9.0.dev0"), reason="Requires TE >= 1.9.0.dev0") - def test_transformer_layer_overlap_shared_expert(self): - """ - Verifies all-to-all overlap optimization in transformer layer with shared expert produces - the same results as the reference implement - ation. - """ - extra_kwargs = { - "moe_token_dispatcher_type": "alltoall", - "moe_shared_expert_intermediate_size": 512, - } - config = get_test_config(extra_kwargs=extra_kwargs) - microbatches = 4 - with deterministic_mode(): - transformer_layer_spec = get_gpt_decoder_block_spec( - config=config, use_transformer_engine=True - ) - gpt_model = GPTModel( - config=config, - transformer_layer_spec=transformer_layer_spec, - vocab_size=100, - pre_process=True, - post_process=True, - max_sequence_length=300, - ) - model = gpt_model.decoder.layers[0] - - params = reset_model(gpt_model) - input_tensors = [build_data() for _ in range(microbatches)] - - fp8_context = ( - get_fp8_context(config, model.layer_number - 1) if config.fp8 else nullcontext() - ) - with fp8_context: - capture_ref = run_transformer_layer_ref_with_capture( - model, input_tensors, microbatches - ) - reset_model(gpt_model, params) - capture_a2a_overlap = run_transformer_layer_a2a_overlap_with_capture( - model, input_tensors, microbatches - ) - comp_res = compare_captures(capture_ref, capture_a2a_overlap, True) - assert comp_res[0], f"[rank {torch.distributed.get_rank()}] {comp_res[1]}" - - @pytest.mark.skipif(not is_te_min_version("1.9.0.dev0"), reason="Requires TE >= 1.9.0.dev0") - @pytest.mark.parametrize("dispatcher_type", get_valid_token_dispatcher_types()) - @pytest.mark.parametrize("fp8_flag", get_valid_fp8_flags()) - def test_transformer_layer_overlap(self, dispatcher_type, fp8_flag): - """ - Verifies all-to-all overlap optimization in transformer layer produces - the same results as the reference implementation. - """ - - extra_kwargs = {"moe_token_dispatcher_type": dispatcher_type} - if dispatcher_type == "flex": - extra_kwargs["moe_enable_deepep"] = True - extra_kwargs["moe_router_dtype"] = "fp32" - if fp8_flag is not None: - extra_kwargs["fp8"] = fp8_flag[0] - extra_kwargs["fp8_recipe"] = fp8_flag[1] - config = get_test_config(extra_kwargs=extra_kwargs) - microbatches = 4 - with deterministic_mode(): - transformer_layer_spec = get_gpt_decoder_block_spec( - config=config, use_transformer_engine=True - ) - gpt_model = GPTModel( - config=config, - transformer_layer_spec=transformer_layer_spec, - vocab_size=100, - pre_process=True, - post_process=True, - max_sequence_length=300, - ) - model = gpt_model.decoder.layers[0] - - params = reset_model(gpt_model) - input_tensors = [build_data() for _ in range(microbatches)] - - fp8_context = ( - get_fp8_context(config, model.layer_number - 1) if config.fp8 else nullcontext() - ) - with fp8_context: - capture_ref = run_transformer_layer_ref_with_capture( - model, input_tensors, microbatches - ) - reset_model(gpt_model, params) - capture_a2a_overlap = run_transformer_layer_a2a_overlap_with_capture( - model, input_tensors, microbatches - ) - comp_res = compare_captures(capture_ref, capture_a2a_overlap, True) - assert comp_res[0], f"[rank {torch.distributed.get_rank()}] {comp_res[1]}" - - @pytest.mark.skipif(not is_te_min_version("1.9.0.dev0"), reason="Requires TE >= 1.9.0.dev0") - @pytest.mark.parametrize("dispatcher_type", get_valid_token_dispatcher_types()) - @pytest.mark.parametrize("fp8_flag", get_valid_fp8_flags()) - def test_mtp_layer_overlap(self, dispatcher_type, fp8_flag): - """ - Verifies all-to-all overlap optimization in MTP layer produces - the same results as the reference implementation. - """ - - extra_kwargs = { - "moe_token_dispatcher_type": dispatcher_type, - "mtp_num_layers": 1, - "mtp_loss_scaling_factor": 1.1, - } - if dispatcher_type == "flex": - extra_kwargs["moe_enable_deepep"] = True - extra_kwargs["moe_router_dtype"] = "fp32" - if fp8_flag is not None: - extra_kwargs["fp8_recipe"] = fp8_flag[1] - extra_kwargs["fp8"] = fp8_flag[0] - config = get_test_config(extra_kwargs=extra_kwargs) - microbatches = 1 - seq_len = 32 - with deterministic_mode(): - # init models - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - num_experts=16, - moe_grouped_gemm=True, - qk_layernorm=True, - multi_latent_attention=True, - ) - mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, True) - if mtp_block_spec is None: - # only last rank has mtp block - assert True - return - gpt_model = GPTModel( - config=config, - transformer_layer_spec=transformer_layer_spec, - mtp_block_spec=mtp_block_spec, - vocab_size=100, - pre_process=True, - post_process=True, - max_sequence_length=300, - ) - gpt_model.decoder.final_layernorm = None - gpt_model.cuda() - params = reset_model(gpt_model) - - # build input data - data = list(range(seq_len)) - hidden_states = [build_data(seq_len) for _ in range(microbatches)] - input_ids = torch.tensor(data, dtype=torch.int64).repeat((1, 1)).cuda() - labels = torch.tensor(data, dtype=torch.int64).repeat((1, 1)).cuda() - position_ids = torch.tensor(data, dtype=torch.int64).repeat((1, 1)).cuda() - attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool).cuda() - # get rotary pos emb - _, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, _ = gpt_model._preprocess( - input_ids, position_ids - ) - # reset model - params = reset_model(gpt_model) - - # run reference implementation - capture_ref = run_mtp_layer_ref_with_capture( - model=gpt_model, - hidden_states=hidden_states, - input_ids=input_ids, - position_ids=position_ids, - labels=labels, - attention_mask=attention_mask, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - microbatches=microbatches, - ) - reset_model(gpt_model, params) - capture_a2a_overlap = run_mtp_layer_a2a_overlap_with_capture( - model=gpt_model, - hidden_states=hidden_states, - input_ids=input_ids, - position_ids=position_ids, - labels=labels, - attention_mask=attention_mask, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - microbatches=microbatches, - ) - comp_res = compare_captures(capture_ref, capture_a2a_overlap, True, True) - assert comp_res[0], f"[rank {torch.distributed.get_rank()}] {comp_res[1]}" diff --git a/tests/unit_tests/data/test_bin_reader.py b/tests/unit_tests/data/test_bin_reader.py deleted file mode 100644 index e479676ac4..0000000000 --- a/tests/unit_tests/data/test_bin_reader.py +++ /dev/null @@ -1,232 +0,0 @@ -import os -import random -import sys -import tempfile -from dataclasses import dataclass -from types import ModuleType, SimpleNamespace -from typing import Any, Dict - -import nltk -import pytest - -try: - import boto3 - import botocore.exceptions as exceptions -except ModuleNotFoundError: - boto3 = ModuleType("boto3") - sys.modules[boto3.__name__] = boto3 - exceptions = ModuleType("botocore.exceptions") - sys.modules[exceptions.__name__] = exceptions - -try: - import multistorageclient as msc -except ModuleNotFoundError: - # Create mock msc module - msc = ModuleType("multistorageclient") - - # Create mock types submodule - types_module = ModuleType("multistorageclient.types") - - # Create Range class in types module - class Range: - def __init__(self, offset: int, size: int): - self.offset = offset - self.size = size - - # Add Range class to types module - types_module.Range = Range # type: ignore[attr-defined] - - # Add types submodule to msc - msc.types = types_module - - # Register the mock module in sys.modules - sys.modules[msc.__name__] = msc - sys.modules[types_module.__name__] = types_module - -from megatron.core.datasets.indexed_dataset import ( - IndexedDataset, - ObjectStorageConfig, - _FileBinReader, - _MMapBinReader, - _MultiStorageClientBinReader, - _S3BinReader, -) -from megatron.core.datasets.object_storage_utils import MSC_PREFIX, S3_PREFIX, S3Client -from tests.unit_tests.data.test_preprocess_data import ( - build_datasets, - dummy_jsonl, - gpt2_merge, - gpt2_vocab, -) - -## -# Overload client from boto3 -## - - -class _LocalClient(S3Client): - """Local test client""" - - def __init__(self, *args: Any) -> None: - pass - - def download_file(self, Bucket: str, Key: str, Filename: str) -> None: - os.makedirs(os.path.dirname(Filename), exist_ok=True) - os.system(f"cp {os.path.join('/', Bucket, Key)} {Filename}") - assert os.path.exists(Filename) - - def upload_file(self, Filename: str, Bucket: str, Key: str) -> None: - raise NotImplementedError - - def head_object(self, Bucket: str, Key: str) -> Dict[str, Any]: - assert os.path.exists(os.path.join("/", Bucket, Key)) - return {} - - def get_object(self, Bucket: str, Key: str, Range: str) -> Dict[str, Any]: - _, _range = Range.split("=") - _range_beg, _range_end = tuple(map(int, _range.split("-"))) - - filename = os.path.join("/", Bucket, Key) - - with open(filename, mode='rb', buffering=0) as bin_buffer_file: - bin_buffer_file.seek(_range_beg) - _bytes = bin_buffer_file.read(_range_end - _range_beg) - - response = {"Body": SimpleNamespace(read=lambda: _bytes)} - - return response - - def close(self) -> None: - pass - - -setattr(boto3, "client", _LocalClient) - - -## -# Overload ClientError from botocore.exceptions -## - - -class _LocalClientError(Exception): - """ "Local test client error""" - - pass - - -setattr(exceptions, "ClientError", _LocalClientError) - -## -# Mock multistorageclient module -## - - -def _msc_download_file(remote_path, local_path): - remote_path = remote_path.removeprefix(MSC_PREFIX + "default") - os.makedirs(os.path.dirname(local_path), exist_ok=True) - os.system(f"cp {remote_path} {local_path}") - - -def _msc_resolve_storage_client(path): - class StorageClient: - def read(self, path, byte_range): - with open(path, "rb") as f: - f.seek(byte_range.offset) - return f.read(byte_range.size) - - return StorageClient(), path.removeprefix(MSC_PREFIX + "default") - - -setattr(msc, "open", open) -setattr(msc, "download_file", _msc_download_file) -setattr(msc, "resolve_storage_client", _msc_resolve_storage_client) - - -@pytest.mark.flaky -@pytest.mark.flaky_in_dev -def test_bin_reader(): - with tempfile.TemporaryDirectory() as temp_dir: - # set the default nltk data path - os.environ["NLTK_DATA"] = os.path.join(temp_dir, "nltk_data") - nltk.data.path.append(os.environ["NLTK_DATA"]) - - path_to_raws = os.path.join(temp_dir, "sample_raws") - path_to_data = os.path.join(temp_dir, "sample_data") - path_to_object_storage_cache = os.path.join(temp_dir, "object_storage_cache") - os.mkdir(path_to_raws) - os.mkdir(path_to_data) - os.mkdir(path_to_object_storage_cache) - - # create the dummy resources - dummy_jsonl(path_to_raws) - - # build the datasets - build_datasets( - path_to_raws, - path_to_data, - extra_args=[ - "--tokenizer-type", - "GPT2BPETokenizer", - "--vocab-file", - gpt2_vocab(temp_dir), - "--merge-file", - gpt2_merge(temp_dir), - "--append-eod", - "--workers", - "10", - "--log-interval", - "1", - ], - ) - - prefixes = set( - [ - os.path.join(temp_dir, "sample_data", path.split(".")[0]) - for path in os.listdir(path_to_data) - if path.endswith(".bin") or path.endswith(".idx") - ] - ) - - for prefix in prefixes: - indexed_dataset_file = IndexedDataset(prefix, multimodal=False, mmap=False) - assert isinstance(indexed_dataset_file.bin_reader, _FileBinReader) - - indexed_dataset_mmap = IndexedDataset(prefix, multimodal=False, mmap=True) - assert isinstance(indexed_dataset_mmap.bin_reader, _MMapBinReader) - - indexed_dataset_msc = IndexedDataset( - MSC_PREFIX + "default" + prefix, # use the default profile to access the filesystem - multimodal=False, - mmap=False, - object_storage_config=ObjectStorageConfig( - path_to_idx_cache=path_to_object_storage_cache - ), - ) - assert isinstance(indexed_dataset_msc.bin_reader, _MultiStorageClientBinReader) - assert len(indexed_dataset_msc) == len(indexed_dataset_file) - assert len(indexed_dataset_msc) == len(indexed_dataset_mmap) - - indexed_dataset_s3 = IndexedDataset( - S3_PREFIX + prefix, - multimodal=False, - mmap=False, - object_storage_config=ObjectStorageConfig( - path_to_idx_cache=path_to_object_storage_cache - ), - ) - assert isinstance(indexed_dataset_s3.bin_reader, _S3BinReader) - - assert len(indexed_dataset_s3) == len(indexed_dataset_file) - assert len(indexed_dataset_s3) == len(indexed_dataset_mmap) - - indices = random.sample( - list(range(len(indexed_dataset_s3))), min(100, len(indexed_dataset_s3)) - ) - - for idx in indices: - assert (indexed_dataset_s3[idx] == indexed_dataset_file[idx]).all() - assert (indexed_dataset_s3[idx] == indexed_dataset_mmap[idx]).all() - - -if __name__ == "__main__": - test_bin_reader() diff --git a/tests/unit_tests/data/test_builder.py b/tests/unit_tests/data/test_builder.py deleted file mode 100644 index 939677268b..0000000000 --- a/tests/unit_tests/data/test_builder.py +++ /dev/null @@ -1,301 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -## -# Compile megatron.core.datasets.helpers_cpp dependencies before BlendedDataset import -## - -import os -import tempfile -from collections import defaultdict -from typing import Dict, Optional - -import numpy -import pytest -import torch - -from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder -from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig -from megatron.core.datasets.megatron_dataset import LowLevelDataset, MegatronDataset -from megatron.core.datasets.utils import Split, compile_helpers, get_blend_from_list -from tests.unit_tests.test_utilities import Utils - -_NUM_DATASETS = 10 - -_SEQUENCE_LENGTH = 10 - -_SIZES = {} -for split in Split: - _SIZES[split] = [] - for i in range(_NUM_DATASETS): - _SIZES[split].append({Split.train: 1000, Split.valid: 100, Split.test: 10}[split] * (i + 1)) - -_MARGIN = 0.005 - - -def do_setup(odir): - paths = defaultdict(list) - - for i in range(_NUM_DATASETS): - path_to_data = os.path.join(odir, str(i)) - os.mkdir(path_to_data) - - for split in _SIZES: - data = numpy.zeros((_SIZES[split][i], _SEQUENCE_LENGTH)) - path = os.path.join(path_to_data, f"{split.name}.npy") - numpy.save(path, data) - paths[split].append(path) - - return paths - - -def test_builder(): - if torch.distributed.is_available(): - Utils.initialize_distributed() - if torch.distributed.get_rank() == 0: - compile_helpers() - torch.distributed.barrier() - else: - compile_helpers() - - # Define the class here to avoid pytest warnings - - class TestDataset(MegatronDataset): - def __init__( - self, - dataset: LowLevelDataset, - dataset_path: Optional[str], - indices: numpy.ndarray, - num_samples: Optional[int], - index_split: Split, - config: BlendedMegatronDatasetConfig, - ) -> None: - super().__init__(dataset, dataset_path, indices, num_samples, index_split, config) - - if self.num_samples is None: - self.num_samples = len(self.indices) - - self.sample_index = numpy.random.choice(self.indices, size=self.num_samples) - - @staticmethod - def numel_low_level_dataset(low_level_dataset: LowLevelDataset) -> int: - return len(low_level_dataset) - - @staticmethod - def build_low_level_dataset( - dataset_path: str, config: BlendedMegatronDatasetConfig - ) -> LowLevelDataset: - return numpy.load(dataset_path) - - def __len__(self) -> int: - return len(self.sample_index) - - def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]: - return {"text": self.dataset[self.sample_index[idx]]} - - with tempfile.TemporaryDirectory() as temp_dir: - - paths = do_setup(temp_dir) - - blends = { - split: get_blend_from_list( - [ - weight_or_path - for pair in zip(list(range(1, len(paths[split]) + 1, 1)), paths[split]) - for weight_or_path in pair - ] - ) - for split in Split - } - - blends_unweighted = {split: (blends[split][0], None) for split in blends} - - config = BlendedMegatronDatasetConfig( - random_seed=1234, - sequence_length=_SEQUENCE_LENGTH, - blend_per_split=[blends[Split.train], None, None], - mid_level_dataset_surplus=0.005, - ) - try: - datasets = BlendedMegatronDatasetBuilder( - TestDataset, [None, None, None], lambda: True, config - ).build() - raise RuntimeError - except AssertionError: - pass - - config = BlendedMegatronDatasetConfig( - random_seed=1234, - sequence_length=_SEQUENCE_LENGTH, - blend_per_split=[get_blend_from_list([paths[Split.train][0]]), None, None], - mid_level_dataset_surplus=0.005, - ) - datasets = BlendedMegatronDatasetBuilder( - TestDataset, [1000, None, None], lambda: True, config - ).build() - assert len(datasets[0]) == 1000 and isinstance(datasets[0], TestDataset) - assert datasets[1] is None - assert datasets[2] is None - - config = BlendedMegatronDatasetConfig( - random_seed=1234, - sequence_length=_SEQUENCE_LENGTH, - blend_per_split=[ - blends_unweighted[Split.train], - blends_unweighted[Split.valid], - blends_unweighted[Split.test], - ], - mid_level_dataset_surplus=0.005, - ) - datasets = BlendedMegatronDatasetBuilder( - TestDataset, [1000, 1000, 1000], lambda: True, config - ).build() - assert len(datasets[0]) == 1000 - assert len(datasets[1]) == 1000 - assert len(datasets[2]) == sum(_SIZES[Split.test]) - - config = BlendedMegatronDatasetConfig( - random_seed=1234, - sequence_length=_SEQUENCE_LENGTH, - blend_per_split=[ - blends_unweighted[Split.train], - blends_unweighted[Split.valid], - blends_unweighted[Split.test], - ], - mid_level_dataset_surplus=0.005, - ) - datasets = BlendedMegatronDatasetBuilder( - TestDataset, [None, None, None], lambda: True, config - ).build() - assert len(datasets[0]) == sum(_SIZES[Split.train]) - assert numpy.all( - numpy.array(datasets[0].weights) - == numpy.unique(datasets[0].dataset_index, return_counts=True)[1] - ) - assert len(datasets[1]) == sum(_SIZES[Split.valid]) - assert numpy.all( - numpy.array(datasets[1].weights) - == numpy.unique(datasets[1].dataset_index, return_counts=True)[1] - ) - assert len(datasets[2]) == sum(_SIZES[Split.test]) - assert numpy.all( - numpy.array(datasets[2].weights) - == numpy.unique(datasets[2].dataset_index, return_counts=True)[1] - ) - - config = BlendedMegatronDatasetConfig( - random_seed=1234, - sequence_length=_SEQUENCE_LENGTH, - blend_per_split=[blends_unweighted[Split.train], None, None], - mid_level_dataset_surplus=0.005, - ) - datasets = BlendedMegatronDatasetBuilder( - TestDataset, [1000, None, None], lambda: True, config - ).build() - assert len(datasets[0]) == 1000 - for i in range(_NUM_DATASETS): - assert len(datasets[0].datasets[i]) == _SIZES[Split.train][i] - assert datasets[1] is None - assert datasets[2] is None - - # This build used to fail when building datasets without a sample buffer - config = BlendedMegatronDatasetConfig( - random_seed=1234, - sequence_length=_SEQUENCE_LENGTH, - blend_per_split=[blends[Split.train], None, None], - mid_level_dataset_surplus=0.005, - ) - datasets = BlendedMegatronDatasetBuilder( - TestDataset, [1000, None, None], lambda: True, config - ).build() - - config = BlendedMegatronDatasetConfig( - random_seed=1234, - sequence_length=_SEQUENCE_LENGTH, - blend=blends_unweighted[Split.train], - split="100,0,0", - mid_level_dataset_surplus=0.005, - ) - datasets = BlendedMegatronDatasetBuilder( - TestDataset, [None, None, None], lambda: True, config - ).build() - assert len(datasets[0]) == sum(_SIZES[Split.train]) - assert numpy.all( - numpy.array(datasets[0].weights) - == numpy.unique(datasets[0].dataset_index, return_counts=True)[1] - ) - assert datasets[1] is None - assert datasets[2] is None - - if torch.distributed.is_initialized(): - config = BlendedMegatronDatasetConfig( - random_seed=1234, - sequence_length=_SEQUENCE_LENGTH, - blend=blends_unweighted[Split.train], - split="100,0,0", - mid_level_dataset_surplus=0.005, - ) - datasets = BlendedMegatronDatasetBuilder( - TestDataset, - [None, None, None], - lambda: torch.distributed.get_rank() % 2 == 0, - config, - ).build() - if torch.distributed.get_rank() % 2 == 0: - assert len(datasets[0]) == sum(_SIZES[Split.train]) - assert numpy.all( - numpy.array(datasets[0].weights) - == numpy.unique(datasets[0].dataset_index, return_counts=True)[1] - ) - else: - assert datasets[0] is None - assert datasets[1] is None - assert datasets[2] is None - - config = BlendedMegatronDatasetConfig( - random_seed=1234, - sequence_length=_SEQUENCE_LENGTH, - blend=blends_unweighted[Split.train], - split="50,50,0", - mid_level_dataset_surplus=0.005, - ) - datasets = BlendedMegatronDatasetBuilder( - TestDataset, [1000, 0, None], lambda: True, config - ).build() - assert len(datasets[0]) == 1000 - assert sum(map(len, datasets[0].datasets)) == sum(_SIZES[Split.train]) / 2 - assert sum(map(len, datasets[1].datasets)) == sum(_SIZES[Split.train]) / 2 - assert datasets[1] is not None and len(datasets[1]) == 0 - assert datasets[2] is None - - config = BlendedMegatronDatasetConfig( - random_seed=1234, - sequence_length=_SEQUENCE_LENGTH, - blend=blends_unweighted[Split.train], - split="50,50,0", - mid_level_dataset_surplus=0.005, - ) - datasets = BlendedMegatronDatasetBuilder( - TestDataset, - [int(sum(_SIZES[Split.train]) / 4), int(sum(_SIZES[Split.train])), None], - lambda: True, - config, - ).build() - assert len(datasets[0]) == sum(_SIZES[Split.train]) / 4 - assert len(datasets[1]) == sum(_SIZES[Split.train]) / 2 - assert datasets[2] is None - - # This build used to fail when building datasets without a sample buffer - config = BlendedMegatronDatasetConfig( - random_seed=1234, - sequence_length=_SEQUENCE_LENGTH, - blend=blends[Split.train], - split="990,9,1", - mid_level_dataset_surplus=0.005, - ) - datasets = BlendedMegatronDatasetBuilder( - TestDataset, [100000, 1000, 1], lambda: True, config - ).build() - - -if __name__ == "__main__": - test_builder() diff --git a/tests/unit_tests/data/test_gpt_dataset.py b/tests/unit_tests/data/test_gpt_dataset.py deleted file mode 100644 index fdfa864579..0000000000 --- a/tests/unit_tests/data/test_gpt_dataset.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -## -# Compile megatron.core.datasets.helpers_cpp dependencies before BlendedDataset import -## - -import random - -import numpy -import pytest -import torch - -from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder -from megatron.core.datasets.gpt_dataset import GPTDatasetConfig, MockGPTDataset -from megatron.core.datasets.utils import compile_helpers -from megatron.training.tokenizer.tokenizer import _NullTokenizer -from tests.unit_tests.test_utilities import Utils - -_MOCK_VOCAB_SIZE = 8192 - - -def sample_N(dataset, N, randomize): - if randomize: - indices = [random.randint(0, len(dataset) - 1) for _ in range(N)] - else: - indices = list(range(N)) - samples = [dataset[index]["tokens"].numpy() for index in indices] - return samples - - -def test_mock_gpt_dataset(): - if torch.distributed.is_available(): - Utils.initialize_distributed() - if torch.distributed.get_rank() == 0: - compile_helpers() - torch.distributed.barrier() - else: - compile_helpers() - - tokenizer = _NullTokenizer(vocab_size=_MOCK_VOCAB_SIZE) - - config = GPTDatasetConfig( - random_seed=1234, - sequence_length=1024, - split="990,9,1", - reset_position_ids=True, - reset_attention_mask=True, - eod_mask_loss=True, - tokenizer=tokenizer, - mid_level_dataset_surplus=0.005, - ) - - datasets = BlendedMegatronDatasetBuilder( - MockGPTDataset, [100, 100, 100], lambda: True, config - ).build() - - N = 10 - - # Check iso-index variance by split - subsets = [sample_N(dataset, N, randomize=False) for dataset in datasets] - assert not numpy.allclose(subsets[0], subsets[1]) - assert not numpy.allclose(subsets[0], subsets[2]) - assert not numpy.allclose(subsets[1], subsets[2]) - - # Check iso-split / iso-index identity - subset_1A = sample_N(datasets[0], N, randomize=False) - subset_1B = sample_N(datasets[0], N, randomize=False) - assert numpy.allclose(subset_1A, subset_1B) - - # Check iso-split variance by index - subset_1A = sample_N(datasets[0], N, randomize=True) - subset_1B = sample_N(datasets[0], N, randomize=True) - assert not numpy.allclose(subset_1A, subset_1B) - - config = GPTDatasetConfig( - random_seed=1234, - sequence_length=1024, - split="990,10,0", - reset_position_ids=True, - reset_attention_mask=True, - eod_mask_loss=True, - drop_last_partial_validation_sequence=False, - add_extra_token_to_sequence=False, - tokenizer=tokenizer, - mid_level_dataset_surplus=0.005, - ) - - datasets = BlendedMegatronDatasetBuilder( - MockGPTDataset, [0, None, 0], lambda: True, config - ).build() - - sample = datasets[1][datasets[1].shuffle_index.argmax()] - argmax = sample['labels'].shape[0] - torch.flip(sample['labels'], [0]).argmax() - 1 - - # Test add_extra_token_to_sequence - assert sample['tokens'][argmax] != tokenizer.eod - assert sample['labels'][argmax] == tokenizer.eod - - # Test eod_mask_loss, drop_last_partial_validation_sequence - assert argmax < sample['labels'].shape[0] - 1 - assert torch.all(sample['labels'][argmax + 1 :] == 0) - assert not torch.any( - sample['loss_mask'][ - torch.logical_and(sample['labels'] == tokenizer.eod, sample['labels'] == 0) - ] - ) - - sample = datasets[1][None] - - # Check handling of None index - assert not torch.any(sample['loss_mask']) - - -if __name__ == "__main__": - test_mock_gpt_dataset() diff --git a/tests/unit_tests/data/test_multimodal_dataset.py b/tests/unit_tests/data/test_multimodal_dataset.py deleted file mode 100644 index f6ff575684..0000000000 --- a/tests/unit_tests/data/test_multimodal_dataset.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -## -# Compile megatron.core.datasets.helpers_cpp dependencies before BlendedDataset import -## - -from types import SimpleNamespace - -import torch - -from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder -from megatron.core.datasets.multimodal_dataset import MockMultimodalDataset, MultimodalDatasetConfig -from megatron.core.datasets.utils import compile_helpers -from megatron.training.tokenizer.tokenizer import _NullTokenizer -from tests.unit_tests.test_utilities import Utils - -_MOCK_VOCAB_SIZE = 8192 - - -def test_mock_multimodal_dataset(): - if torch.distributed.is_available(): - Utils.initialize_distributed() - if torch.distributed.get_rank() == 0: - compile_helpers() - torch.distributed.barrier() - else: - compile_helpers() - - config = MultimodalDatasetConfig( - random_seed=1234, - sequence_length=1024, - reset_position_ids=False, - reset_attention_mask=False, - eod_mask_loss=True, - image_h=336, - image_w=336, - split="990,9,1", - tokenizer=_NullTokenizer(vocab_size=_MOCK_VOCAB_SIZE), - mid_level_dataset_surplus=0.005, - ) - - datasets = BlendedMegatronDatasetBuilder( - MockMultimodalDataset, [100, 100, 100], lambda: True, config - ).build() - - for ds in datasets: - sample = ds[0] - assert "image" in sample - assert sample["image"].shape == torch.Size([3, 336, 336]) - assert "tokens" in sample - - -if __name__ == "__main__": - test_mock_multimodal_dataset() diff --git a/tests/unit_tests/data/test_preprocess_data.py b/tests/unit_tests/data/test_preprocess_data.py deleted file mode 100644 index 48f3a2e7bb..0000000000 --- a/tests/unit_tests/data/test_preprocess_data.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -import json -import os -import sys -import tempfile - -import nltk -import pytest -import requests - -from megatron.core.datasets.indexed_dataset import IndexedDataset -from megatron.training.tokenizer.gpt2_tokenization import ( - PRETRAINED_MERGES_ARCHIVE_MAP, - PRETRAINED_VOCAB_ARCHIVE_MAP, -) -from tools.merge_datasets import main as merge_main -from tools.preprocess_data import Encoder -from tools.preprocess_data import get_args as build_args -from tools.preprocess_data import main as build_main - -__HUGGINGFACE_BERT_BASE_UNCASED_VOCAB = ( - "https://huggingface.co/bert-base-uncased/raw/main/vocab.txt" -) - -__LOCAL_BERT_VOCAB = "/home/gitlab-runner/data/bert_data/vocab.txt" - -__LOCAL_GPT2_MERGE = "/home/gitlab-runner/data/gpt3_data/gpt2-merges.txt" - -__LOCAL_GPT2_VOCAB = "/home/gitlab-runner/data/gpt3_data/gpt2-vocab.json" - - -def dummy_jsonl(odir): - # numbers - list_numbers = [json.dumps({"text": str(i + 1)}) + "\n" for i in range(100)] - with open(os.path.join(odir, "numbers.jsonl"), "w") as writer: - writer.writelines(list_numbers) - # numbers ascending - list_numbers_ascending = [ - json.dumps({"text": " ".join([str(j + 1) for j in range(i + 1)])}) + "\n" - for i in range(100) - ] - with open(os.path.join(odir, "numbers_ascending.jsonl"), "w") as writer: - writer.writelines(list_numbers_ascending) - # test - list_test = [] - with open(__file__) as reader: - for line in reader: - list_test.append(json.dumps({"text": line}) + "\n") - with open(os.path.join(odir, "test.jsonl"), "w") as writer: - writer.writelines(list_test) - - -def build_datasets(idir, odir, extra_args=[]): - for name in os.listdir(idir): - sys.argv = [ - sys.argv[0], - "--input", - os.path.join(idir, name), - "--output-prefix", - os.path.join(odir, os.path.splitext(name)[0]), - ] + extra_args - build_main() - - -def merge_datasets(idir): - sys.argv = [sys.argv[0], "--input", idir, "--output-prefix", os.path.join(idir, "merge")] - merge_main() - - -def do_test_preprocess_data(temp_dir, extra_args=[]): - # set the default nltk data path - os.environ["NLTK_DATA"] = os.path.join(temp_dir, "nltk_data") - nltk.data.path.append(os.environ["NLTK_DATA"]) - - path_to_raws = os.path.join(temp_dir, "sample_raws") - path_to_data = os.path.join(temp_dir, "sample_data") - os.mkdir(path_to_raws) - os.mkdir(path_to_data) - - # create the dummy resources - dummy_jsonl(path_to_raws) - - # build the datasets - build_datasets(path_to_raws, path_to_data, extra_args=extra_args) - - # merge the datasets - merge_datasets(path_to_data) - - sys.argv = [sys.argv[0], "--input", None, "--output-prefix", None] + extra_args - encoder = Encoder(build_args()) - encoder.initializer() - - def tokens_to_string(toks): - for option in ["decode", "detokenize"]: - try: - return getattr(encoder.tokenizer, option)(toks) - except: - continue - raise RuntimeError(f"{type(encoder.tokenizer)} tokenizer cannot decode or detokenize") - - merged_index = 0 - merged_dataset = IndexedDataset(os.path.join(path_to_data, "merge")) - - # sorted to ensure ordering matches merged dataset - basenames = sorted( - [ - name - for name in os.listdir(path_to_data) - if name.endswith(".idx") and not name.startswith("merge") - ] - ) - - # index into the merged document index - merged_doc_index_index = 0 - - for basename in basenames: - realpath_raw = f"{os.path.join(path_to_raws, '_'.join(basename.split('_')[:-2]))}.jsonl" - realpath_doc = os.path.join(path_to_data, basename.split(".")[-2]) - - dataset_index = 0 - dataset = IndexedDataset(realpath_doc) - - merged_doc_idx = merged_dataset.document_indices[ - merged_doc_index_index : merged_doc_index_index + len(dataset.document_indices) - ] - merged_doc_idx = merged_doc_idx - merged_doc_idx[0] - - assert ( - dataset.document_indices == merged_doc_idx - ).all(), f"ERROR: {basename.split('_')[:-2]}: merged dataset document indices mismatch" - - merged_doc_index_index += len(dataset.document_indices) - 1 - - with open(realpath_raw, "rt") as reader: - for json_line in reader: - toks = encoder.encode(json_line)[0]["text"] - - raw = tokens_to_string(toks) - - processed_toks = [] - while len(processed_toks) < len(toks): - processed_toks.extend(dataset[dataset_index]) - dataset_index += 1 - processed = tokens_to_string(processed_toks) - - assert ( - raw == processed - ), f"ERROR: {basename.split('_')[:-2]}: raw and processed documents do not match" - - merged_toks = [] - while len(merged_toks) < len(toks): - merged_toks.extend(merged_dataset[merged_index]) - merged_index += 1 - merged = tokens_to_string(merged_toks) - - assert ( - raw == merged - ), f"ERROR: {basename.split('_')[:-2]}: raw and merged documents do not match" - - print( - f"INFO: {''.join(basename.split('_')[:-2])}: raw, processed, and merged documents match!" - ) - - print("INFO: Success!") - - -def gpt2_vocab(odir): - if os.path.exists(__LOCAL_GPT2_VOCAB): - return __LOCAL_GPT2_VOCAB - path = os.path.join(odir, "vocab.json") - with open(path, "wb") as writer: - writer.write(requests.get(PRETRAINED_VOCAB_ARCHIVE_MAP['gpt2']).content) - return path - - -def gpt2_merge(odir): - if os.path.exists(__LOCAL_GPT2_MERGE): - return __LOCAL_GPT2_MERGE - path = os.path.join(odir, "merge.txt") - with open(path, "wb") as writer: - writer.write(requests.get(PRETRAINED_MERGES_ARCHIVE_MAP['gpt2']).content) - return path - - -def test_preprocess_data_gpt(): - with tempfile.TemporaryDirectory() as temp_dir: - - # gpt specific args - gpt_args = [ - "--tokenizer-type", - "GPT2BPETokenizer", - "--vocab-file", - "/opt/data/tokenizers/megatron/gpt2-vocab.json", - "--merge-file", - "/opt/data/tokenizers/megatron/gpt2-merges.txt", - "--append-eod", - "--workers", - "10", - "--log-interval", - "1", - ] - - do_test_preprocess_data(temp_dir, extra_args=gpt_args) - - -def bert_vocab(odir): - if os.path.exists(__LOCAL_BERT_VOCAB): - return __LOCAL_BERT_VOCAB - path = os.path.join(odir, "vocab.txt") - with open(path, "wb") as writer: - writer.write(requests.get(__HUGGINGFACE_BERT_BASE_UNCASED_VOCAB).content) - return path - - -@pytest.mark.flaky -@pytest.mark.flaky_in_dev -def test_preprocess_data_bert(): - with tempfile.TemporaryDirectory() as temp_dir: - - # bert specific args - bert_args = [ - "--tokenizer-type", - "BertWordPieceLowerCase", - "--vocab-file", - "/opt/data/tokenizers/megatron/gpt2-vocab.json", - "--split-sentences", - "--workers", - "10", - "--log-interval", - "1", - "--partitions", - "2", - "--keep-sequential-samples", - ] - - do_test_preprocess_data(temp_dir, extra_args=bert_args) - - -if __name__ == "__main__": - test_preprocess_data_gpt() - test_preprocess_data_bert() diff --git a/tests/unit_tests/data/test_preprocess_mmdata.py b/tests/unit_tests/data/test_preprocess_mmdata.py deleted file mode 100644 index d6ad4eddc7..0000000000 --- a/tests/unit_tests/data/test_preprocess_mmdata.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -import os -import random -import sys -import tempfile - -import nltk -import numpy - -from megatron.core.datasets.indexed_dataset import IndexedDataset -from tests.unit_tests.data.test_preprocess_data import dummy_jsonl, gpt2_merge, gpt2_vocab -from tools.merge_datasets import main as merge_main -from tools.preprocess_mmdata import Encoder -from tools.preprocess_mmdata import get_args as build_args -from tools.preprocess_mmdata import main as build_main - - -def dummy_img(odir_txt, odir_img): - for name in os.listdir(odir_txt): - with open(os.path.join(odir_txt, name), "rt") as reader_txt: - length = sum(1 for _ in reader_txt) - os.makedirs(os.path.join(odir_img, os.path.splitext(name)[0]), exist_ok=False) - for i in range(length): - with open( - os.path.join(odir_img, os.path.splitext(name)[0], f"{str(i).zfill(4)}.img"), "wb" - ) as writer_img: - # 32 * 32 - 1 to induce preprocessing 0-index padding - writer_img.write(bytes([random.randint(0, 255) for _ in range(32 * 32 - 1)])) - - -def build_datasets(idir_txt, idir_img, odir, extra_args=[]): - for name in os.listdir(idir_txt): - sys.argv = [ - sys.argv[0], - "--input", - os.path.join(idir_txt, name), - "--input-image", - os.path.join(idir_img, os.path.splitext(name)[0]), - "--output-prefix", - os.path.join(odir, os.path.splitext(name)[0]), - ] + extra_args - build_main() - - -def merge_datasets(idir): - sys.argv = [ - sys.argv[0], - "--input", - idir, - "--output-prefix", - os.path.join(idir, "merge"), - "--multimodal", - ] - merge_main() - - -def do_test_preprocess_mmdata(temp_dir, extra_args=[]): - # set the default nltk data path - os.environ["NLTK_DATA"] = os.path.join(temp_dir, "nltk_data") - nltk.data.path.append(os.environ["NLTK_DATA"]) - - path_to_raws_txt = os.path.join(temp_dir, "sample_raws_txt") - path_to_raws_img = os.path.join(temp_dir, "sample_raws_img") - path_to_data = os.path.join(temp_dir, "sample_data") - os.mkdir(path_to_raws_txt) - os.mkdir(path_to_raws_img) - os.mkdir(path_to_data) - - # create the dummy text resources - dummy_jsonl(path_to_raws_txt) - - # create the dummy image resources - dummy_img(path_to_raws_txt, path_to_raws_img) - - # build the datasets - build_datasets(path_to_raws_txt, path_to_raws_img, path_to_data, extra_args=extra_args) - - # merge the datasets - merge_datasets(path_to_data) - - sys.argv = [ - sys.argv[0], - "--input", - None, - "--input-image", - None, - "--output-prefix", - None, - ] + extra_args - encoder = Encoder(build_args()) - encoder.initializer() - - def tokens_to_string(toks): - for option in ["decode", "detokenize"]: - try: - return getattr(encoder.tokenizer, option)(toks) - except AttributeError: - continue - raise RuntimeError(f"{type(encoder.tokenizer)} tokenizer cannot `decode` or `detokenize`.") - - merged_index = 0 - merged_dataset = IndexedDataset(os.path.join(path_to_data, "merge"), multimodal=True) - - # sorted to ensure ordering matches merged dataset - basenames = sorted( - [ - name - for name in os.listdir(path_to_data) - if name.endswith(".idx") and not name.startswith("merge") - ] - ) - - # index into the merged document index - merged_doc_index_index = 0 - - for basename in basenames: - realpath_raw_txt = os.path.join(path_to_raws_txt, f"{os.path.splitext(basename)[0]}.jsonl") - realpath_raw_img = os.path.join(path_to_raws_img, os.path.splitext(basename)[0]) - realpath_doc = os.path.join(path_to_data, os.path.splitext(basename)[0]) - - dataset_index = 0 - dataset = IndexedDataset(realpath_doc, multimodal=True) - - merged_doc_idx = merged_dataset.document_indices[ - merged_doc_index_index : merged_doc_index_index + len(dataset.document_indices) - ] - merged_doc_idx = merged_doc_idx - merged_doc_idx[0] - - assert ( - dataset.document_indices == merged_doc_idx - ).all(), f"ERROR: {basename.split('_')[:-2]}: merged dataset document indices mismatch" - - merged_doc_index_index += len(dataset.document_indices) - 1 - - with open(realpath_raw_txt, "rt") as reader: - for json_line, image_path in zip( - reader, - [ - os.path.join(realpath_raw_img, basename) - for basename in os.listdir(realpath_raw_img) - ], - ): - toks, image, length = encoder.encode((json_line, image_path)) - - raw_text = tokens_to_string(toks) - # reverse to account for preprocessing 0-index padding - raw_image = image[::-1] - - processed_toks = dataset[dataset_index][0] - assert dataset[dataset_index][1] == 0 - processed_text = tokens_to_string(processed_toks) - - processed_image = dataset[dataset_index + 1][0] - assert dataset[dataset_index + 1][1] == 1 - # reverse to account for preprocessing 0-index padding - processed_image = processed_image[::-1][0 : raw_image.size] - - assert ( - raw_text == processed_text - ), f"ERROR: {basename.split('_')[:-2]}: raw and processed documents (text) do not match" - - assert numpy.allclose( - raw_image, processed_image - ), f"ERROR: {basename.split('_')[:-2]}: raw and processed documents (image) do not match" - - dataset_index += 2 - - merged_toks = merged_dataset[merged_index][0] - assert merged_dataset[merged_index][1] == 0 - merged_text = tokens_to_string(merged_toks) - - merged_image = merged_dataset[merged_index + 1][0] - assert merged_dataset[merged_index + 1][1] == 1 - # reverse to account for preprocessing 0-index padding - merged_image = merged_image[::-1][0 : raw_image.size] - - assert ( - raw_text == merged_text - ), f"ERROR: {basename.split('_')[:-2]}: raw and merged documents (text) do not match" - - assert numpy.allclose( - raw_image, merged_image - ), f"ERROR: {basename.split('_')[:-2]}: raw and merged documents (image) do not match" - - merged_index += 2 - - print( - f"INFO: {''.join(basename.split('_')[:-2])}: raw, processed, and merged documents match!" - ) - - print("INFO: Success!") - - -def test_preprocess_mmdata(): - with tempfile.TemporaryDirectory() as temp_dir: - - # gpt specific args - gpt_args = [ - "--pad-length", - "1024", - "--tokenizer-type", - "GPT2BPETokenizer", - "--vocab-file", - gpt2_vocab(temp_dir), - "--merge-file", - gpt2_merge(temp_dir), - "--append-eod", - "--workers", - "10", - "--log-interval", - "1", - ] - - do_test_preprocess_mmdata(temp_dir, extra_args=gpt_args) - - -if __name__ == "__main__": - test_preprocess_mmdata() diff --git a/tests/unit_tests/dist_checkpointing/models/test_bert_model.py b/tests/unit_tests/dist_checkpointing/models/test_bert_model.py deleted file mode 100644 index 27f0144785..0000000000 --- a/tests/unit_tests/dist_checkpointing/models/test_bert_model.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import os - -import pytest -import torch - -from megatron.core import parallel_state as ps -from megatron.core.models.bert.bert_layer_specs import ( - bert_layer_local_spec, - bert_layer_with_transformer_engine_spec, -) -from megatron.core.models.bert.bert_model import BertModel -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.enums import AttnBackend -from megatron.core.transformer.transformer_config import TransformerConfig -from tests.unit_tests.dist_checkpointing.models.common import ( - common_test_parallel_reconfiguration_e2e, - common_test_simple_sharded_state_dict_save_load, - common_test_state_dict_comparison, - common_test_vocab_size_padding_change, -) -from tests.unit_tests.test_utilities import Utils - - -def initialize_bert_model( - seed, layer_spec_fn=bert_layer_with_transformer_engine_spec, vocab_size=128, **config_kwargs -): - torch.manual_seed(seed) - model_parallel_cuda_manual_seed(seed) - - layer_spec = layer_spec_fn() if callable(layer_spec_fn) else layer_spec_fn - - default_config_kwargs = dict( - num_layers=8, - hidden_size=16, - num_attention_heads=8, - use_cpu_initialization=True, - pipeline_dtype=torch.bfloat16, - attention_backend=AttnBackend.auto, - ) - default_config_kwargs.update(**config_kwargs) - transformer_config = TransformerConfig(**default_config_kwargs) - pre_process = ps.is_pipeline_first_stage() - post_process = ps.is_pipeline_last_stage() - model = BertModel( - config=transformer_config, - transformer_layer_spec=layer_spec, - vocab_size=vocab_size, - max_sequence_length=4, - pre_process=pre_process, - post_process=post_process, - num_tokentypes=0, - ) - - with torch.no_grad(): - for p in model.parameters(): - p.random_() - return model - - -class TestBertModel: - @pytest.mark.parametrize( - 'src_layer_spec', [bert_layer_with_transformer_engine_spec, bert_layer_local_spec] - ) - @pytest.mark.parametrize( - 'dst_layer_spec', [bert_layer_with_transformer_engine_spec, bert_layer_local_spec] - ) - @pytest.mark.internal - def test_sharded_state_dict_save_load(self, tmp_path_dist_ckpt, src_layer_spec, dst_layer_spec): - common_test_simple_sharded_state_dict_save_load( - initialize_bert_model, tmp_path_dist_ckpt, src_layer_spec, dst_layer_spec - ) - - -class TestBERTModelReconfiguration: - def setup_method(self, method): - pass - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.parametrize( - ('use_fpsl', 'src_tp_pp', 'dest_tp_pp', 'src_layer_spec', 'dst_layer_spec'), - [ - ( - False, - (2, 4), - (4, 2), - bert_layer_with_transformer_engine_spec, - bert_layer_with_transformer_engine_spec, - ), - ( - False, - (1, 8), - (8, 1), - bert_layer_with_transformer_engine_spec, - bert_layer_with_transformer_engine_spec, - ), - ( - True, - (2, 1), - (1, 8), - bert_layer_with_transformer_engine_spec, - bert_layer_with_transformer_engine_spec, - ), - ( - False, - (1, 1), - (2, 2), - bert_layer_with_transformer_engine_spec, - bert_layer_with_transformer_engine_spec, - ), - (True, (2, 1), (1, 8), bert_layer_local_spec, bert_layer_local_spec), - (True, (1, 1), (2, 4), bert_layer_with_transformer_engine_spec, bert_layer_local_spec), - (False, (1, 8), (2, 1), bert_layer_local_spec, bert_layer_with_transformer_engine_spec), - ], - ) - @pytest.mark.internal - def test_parallel_reconfiguration_e2e( - self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp, src_layer_spec, dst_layer_spec, use_fpsl - ): - """Test model saving and loading with different TP/PP""" - Utils.initialize_model_parallel(src_tp_pp[0], src_tp_pp[1]) - - common_test_parallel_reconfiguration_e2e( - initialize_bert_model, - tmp_path_dist_ckpt, - src_tp_pp, - dest_tp_pp, - src_layer_spec, - dst_layer_spec, - use_fpsl, - ) - - @pytest.mark.internal - def test_state_dict_comparison(self, tmp_path_dist_ckpt): - common_test_state_dict_comparison(initialize_bert_model, tmp_path_dist_ckpt) - - @pytest.mark.parametrize( - "vocab_size_base,src_tp_pp,dest_tp_pp", - [ - (128, (2, 4), (4, 2)), - (17, (1, 8), (8, 1)), - (127, (1, 8), (8, 1)), - (31123, (1, 1), (1, 8)), - (17, (1, 1), (1, 8)), - ], - ) - @pytest.mark.internal - def test_vocab_size_padding_change( - self, tmp_path_dist_ckpt, vocab_size_base, src_tp_pp, dest_tp_pp - ): - """Test model loading with different vocab size (caused by TP padding).""" - Utils.initialize_model_parallel(src_tp_pp[0], src_tp_pp[1]) - common_test_vocab_size_padding_change( - initialize_bert_model, tmp_path_dist_ckpt, vocab_size_base, src_tp_pp, dest_tp_pp - ) diff --git a/tests/unit_tests/dist_checkpointing/models/test_gpt_model.py b/tests/unit_tests/dist_checkpointing/models/test_gpt_model.py deleted file mode 100644 index e18d3b4683..0000000000 --- a/tests/unit_tests/dist_checkpointing/models/test_gpt_model.py +++ /dev/null @@ -1,216 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import functools -import pathlib -from collections.abc import Callable -from typing import Optional - -import pytest -import torch - -from megatron.core import parallel_state as ps -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec as gpt_local_spec -from megatron.core.models.gpt.gpt_layer_specs import ( - get_gpt_layer_with_transformer_engine_spec as gpt_te_spec, -) -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import is_te_min_version -from tests.unit_tests.dist_checkpointing.models.common import ( - common_test_parallel_reconfiguration_e2e, - common_test_simple_sharded_state_dict_save_load, - common_test_state_dict_comparison, - common_test_vocab_size_padding_change, -) -from tests.unit_tests.test_utilities import Utils - -# List of model spec functions -_spec_fn_list: list[Callable[[], ModuleSpec]] = [gpt_te_spec, gpt_local_spec] -_gpt_te_spec_op_fuser: Optional[Callable] = None -if is_te_min_version("1.13.0"): - _gpt_te_spec_op_fuser = functools.partial(gpt_te_spec, use_te_op_fuser=True) - _spec_fn_list.append(_gpt_te_spec_op_fuser) - - -def initialize_gpt_model(seed, layer_spec_fn=gpt_te_spec, vocab_size=128, **config_kwargs): - torch.manual_seed(seed) - model_parallel_cuda_manual_seed(seed) - - default_config_kwargs = dict( - num_layers=8, - hidden_size=16, - num_attention_heads=8, - use_cpu_initialization=True, - pipeline_dtype=torch.bfloat16, - ) - default_config_kwargs.update(**config_kwargs) - transformer_config = TransformerConfig(**default_config_kwargs) - pre_process = ps.is_pipeline_first_stage() - post_process = ps.is_pipeline_last_stage() - model = GPTModel( - config=transformer_config, - transformer_layer_spec=layer_spec_fn(), - vocab_size=vocab_size, - max_sequence_length=4, - pre_process=pre_process, - post_process=post_process, - ) - - with torch.no_grad(): - for p in model.parameters(): - p.random_() - return model - - -class TestGPTModel: - @pytest.mark.parametrize('src_layer_spec_fn', _spec_fn_list) - @pytest.mark.parametrize('dst_layer_spec_fn', _spec_fn_list) - def test_sharded_state_dict_save_load( - self, - tmp_path_dist_ckpt: pathlib.Path, - src_layer_spec_fn: Callable[[], ModuleSpec], - dst_layer_spec_fn: Callable[[], ModuleSpec], - ): - common_test_simple_sharded_state_dict_save_load( - initialize_gpt_model, tmp_path_dist_ckpt, src_layer_spec_fn, dst_layer_spec_fn - ) - - -class TestGPTModelReconfiguration: - def setup_method(self, method): - pass - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.parametrize( - ( - 'use_fpsl', - 'load_order', - 'store_order', - 'src_tp_pp', - 'dest_tp_pp', - 'singleton_local_shards', - 'src_layer_spec_fn', - 'dst_layer_spec_fn', - ), - [ - (False, 'tp-dp-pp', 'tp-dp-pp', (2, 4), (4, 2), True, gpt_te_spec, gpt_te_spec), - (False, 'tp-pp-dp', 'tp-pp-dp', (1, 8), (8, 1), False, gpt_te_spec, gpt_te_spec), - (True, 'tp-dp-pp', 'tp-pp-dp', (2, 1), (1, 8), True, gpt_te_spec, gpt_te_spec), - (False, 'tp-dp-pp', 'tp-dp-pp', (1, 1), (2, 2), True, gpt_te_spec, gpt_te_spec), - (True, 'tp-pp-dp', 'tp-pp-dp', (2, 1), (1, 8), False, gpt_local_spec, gpt_local_spec), - (False, 'tp-dp-pp', 'tp-pp-dp', (1, 1), (2, 4), False, gpt_te_spec, gpt_local_spec), - (True, 'tp-dp-pp', 'tp-dp-pp', (2, 4), (4, 2), True, gpt_local_spec, gpt_te_spec), - (False, 'tp-pp-dp', 'tp-pp-dp', (2, 1), (1, 8), False, gpt_te_spec, gpt_local_spec), - (False, 'tp-dp-pp', 'tp-pp-dp', (2, 4), (2, 4), True, gpt_local_spec, gpt_local_spec), - ( - False, - 'tp-dp-pp', - 'tp-dp-pp', - (2, 4), - (4, 2), - False, - gpt_te_spec, - _gpt_te_spec_op_fuser, - ), - ( - False, - 'tp-dp-pp', - 'tp-dp-pp', - (2, 4), - (4, 2), - False, - _gpt_te_spec_op_fuser, - gpt_te_spec, - ), - ], - ) - def test_parallel_reconfiguration_e2e( - self, - tmp_path_dist_ckpt: pathlib.Path, - src_tp_pp: tuple[int, int], - dest_tp_pp: tuple[int, int], - src_layer_spec_fn: Optional[Callable[[], ModuleSpec]], - dst_layer_spec_fn: Optional[Callable[[], ModuleSpec]], - use_fpsl: bool, - load_order: str, - store_order: str, - singleton_local_shards: bool, - ): - """Test model saving and loading with different TP/PP""" - if src_layer_spec_fn is None or dst_layer_spec_fn is None: - pytest.skip("Spec function is not supported") - Utils.initialize_model_parallel(src_tp_pp[0], src_tp_pp[1]) - common_test_parallel_reconfiguration_e2e( - initialize_gpt_model, - tmp_path_dist_ckpt, - src_tp_pp, - dest_tp_pp, - src_layer_spec_fn, - dst_layer_spec_fn, - use_fpsl, - load_order, - store_order, - metadata={'singleton_local_shards': singleton_local_shards}, - ) - - def test_state_dict_comparison(self, tmp_path_dist_ckpt): - common_test_state_dict_comparison(initialize_gpt_model, tmp_path_dist_ckpt) - - @pytest.mark.parametrize( - "vocab_size_base,src_tp_pp,dest_tp_pp", - [ - (128, (2, 4), (4, 2)), - (17, (1, 8), (8, 1)), - (127, (1, 8), (8, 1)), - (31123, (1, 1), (1, 8)), - (17, (1, 1), (1, 8)), - ], - ) - def test_vocab_size_padding_change( - self, - tmp_path_dist_ckpt: pathlib.Path, - vocab_size_base: int, - src_tp_pp: tuple[int, int], - dest_tp_pp: tuple[int, int], - ) -> None: - """Test model loading with different vocab size (caused by TP padding).""" - Utils.initialize_model_parallel(src_tp_pp[0], src_tp_pp[1]) - common_test_vocab_size_padding_change( - initialize_gpt_model, tmp_path_dist_ckpt, vocab_size_base, src_tp_pp, dest_tp_pp - ) - - @pytest.mark.parametrize( - ('src_tp_pp', 'dest_tp_pp', 'src_layer_spec_fn', 'dst_layer_spec_fn'), - [ - ((2, 4), (4, 2), gpt_te_spec, gpt_te_spec), - ((2, 4), (4, 2), gpt_te_spec, gpt_local_spec), - ((2, 4), (4, 2), gpt_local_spec, gpt_te_spec), - ((2, 4), (4, 2), gpt_te_spec, _gpt_te_spec_op_fuser), - ((2, 4), (4, 2), _gpt_te_spec_op_fuser, gpt_te_spec), - ], - ) - def test_mlp_with_glu( - self, - tmp_path_dist_ckpt: pathlib.Path, - src_tp_pp: tuple[int, int], - dest_tp_pp: tuple[int, int], - src_layer_spec_fn: Optional[Callable[[], ModuleSpec]], - dst_layer_spec_fn: Optional[Callable[[], ModuleSpec]], - ) -> None: - """Test model loading when MLP activation is gated linear unit.""" - if src_layer_spec_fn is None or dst_layer_spec_fn is None: - pytest.skip("Spec function is not supported") - Utils.initialize_model_parallel(src_tp_pp[0], src_tp_pp[1]) - common_test_parallel_reconfiguration_e2e( - functools.partial(initialize_gpt_model, gated_linear_unit=True), - tmp_path_dist_ckpt, - src_tp_pp, - dest_tp_pp, - src_layer_spec_fn, - dst_layer_spec_fn, - False, # use_fpsl - ) diff --git a/tests/unit_tests/dist_checkpointing/models/test_mamba.py b/tests/unit_tests/dist_checkpointing/models/test_mamba.py deleted file mode 100644 index dd90cc2a7b..0000000000 --- a/tests/unit_tests/dist_checkpointing/models/test_mamba.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import pytest -import torch - -from megatron.core import parallel_state -from megatron.core.dist_checkpointing import load, load_plain_tensors, save -from megatron.core.dist_checkpointing.dict_utils import diff -from megatron.core.dist_checkpointing.serialization import ( - get_default_load_sharded_strategy, - get_default_save_sharded_strategy, -) -from megatron.core.dist_checkpointing.strategies.fully_parallel import ( - FullyParallelLoadStrategyWrapper, - FullyParallelSaveStrategyWrapper, -) -from megatron.core.extensions.transformer_engine import ( - TELayerNormColumnParallelLinear, - TERowParallelLinear, -) -from megatron.core.process_groups_config import ModelCommProcessGroups -from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer import TransformerConfig -from tests.unit_tests.dist_checkpointing import TempNamedDir -from tests.unit_tests.test_utilities import Utils - - -def initialize_mamba(seed, glu=True, **config_kwargs): - torch.manual_seed(seed) - model_parallel_cuda_manual_seed(seed) - - pp_size = parallel_state.get_pipeline_model_parallel_world_size() - num_moe_experts = 8 - default_config_kwargs = dict( - num_layers=pp_size, - hidden_size=256, # for Mamba: expand=2, headdim=64 -> nheads=8 (divisible by ngroups=8) - num_attention_heads=8, # must be divisible by tp_size (testing up to tp_size=8) - num_moe_experts=num_moe_experts, - use_cpu_initialization=True, - gated_linear_unit=glu, - add_bias_linear=False, - pipeline_dtype=torch.bfloat16, - ) - default_config_kwargs.update(**config_kwargs) - transformer_config = TransformerConfig(**default_config_kwargs) - submodules = MambaMixerSubmodules( - in_proj=TELayerNormColumnParallelLinear, out_proj=TERowParallelLinear - ) - model_comm_pgs = ModelCommProcessGroups.use_mpu_process_groups(required_pgs=['tp', 'cp']) - model = MambaMixer( - transformer_config, - submodules, - transformer_config.hidden_size, - rmsnorm=True, - model_comm_pgs=model_comm_pgs, - ) - return model - - -class TestMambaReconfiguration: - @pytest.mark.parametrize( - "use_fpsl,src_tp_pp_exp_cp,dest_tp_pp_exp_cp,use_glu", - [ - (False, (2, 4, 1, 1), (2, 4, 1, 1), False), - (True, (2, 4, 1, 1), (2, 4, 1, 1), False), - (False, (1, 1, 1, 1), (1, 1, 1, 1), False), - (True, (1, 1, 1, 1), (1, 1, 4, 1), False), - (False, (1, 1, 8, 1), (1, 1, 2, 1), False), - (False, (2, 2, 2, 1), (4, 2, 1, 1), False), - (True, (1, 1, 4, 1), (8, 1, 1, 1), False), - (False, (1, 8, 1, 1), (1, 8, 1, 1), False), - (False, (1, 1, 4, 1), (2, 1, 1, 1), False), - (False, (1, 1, 1, 1), (1, 1, 1, 1), True), - (False, (1, 1, 1, 1), (1, 1, 4, 1), True), - (True, (1, 1, 1, 1), (2, 1, 1, 1), True), - (False, (1, 1, 4, 1), (8, 1, 1, 1), True), - # CP-focused cases: - (False, (8, 1, 1, 1), (1, 1, 1, 8), False), - (False, (4, 1, 1, 2), (2, 1, 1, 4), False), - # TODO(duncan): investigate why changing pp_size (up or down) yields an unexpected shape - # mismatch error on dt_bias - ], - ) - @pytest.mark.parametrize('singleton_local_shards', [True, False]) - def test_parallel_reconfiguration_e2e( - self, - tmp_path_dist_ckpt, - src_tp_pp_exp_cp, - dest_tp_pp_exp_cp, - use_glu, - use_fpsl, - singleton_local_shards, - ): - """Test model saving and loading with different TP/PP/expert parallelism""" - src_tp, src_pp, src_exp, src_cp = src_tp_pp_exp_cp - metadata = {'singleton_local_shards': singleton_local_shards} - Utils.initialize_model_parallel( - src_tp, src_pp, expert_model_parallel_size=src_exp, context_parallel_size=src_cp - ) - dest_tp, dest_pp, dest_exp, dest_cp = dest_tp_pp_exp_cp - with ( - TempNamedDir( - tmp_path_dist_ckpt / 'test_sequential_mlp_reconfiguration_model_A' - ) as ckpt_dir_A, - TempNamedDir( - tmp_path_dist_ckpt / 'test_sequential_mlp_reconfiguration_model_B' - ) as ckpt_dir_B, - ): - # Save checkpoint A - layer_prefix = f'{parallel_state.get_pipeline_model_parallel_rank()}.' - model_A = initialize_mamba( - 1, - use_glu, - tensor_model_parallel_size=src_tp, - pipeline_model_parallel_size=src_pp, - expert_model_parallel_size=src_exp, - context_parallel_size=src_cp, - # Sequence parallelism is required when using both expert and tensor parallelism - sequence_parallel=(src_exp > 1 and src_pp > 1), - ) - sharded_state_dict = model_A.sharded_state_dict(prefix=layer_prefix, metadata=metadata) - - save_strategy = get_default_save_sharded_strategy() - if use_fpsl: - save_strategy = FullyParallelSaveStrategyWrapper( - save_strategy, - parallel_state.get_data_parallel_group(with_context_parallel=True), - True, - ) - save(sharded_state_dict, ckpt_dir_A, save_strategy) - Utils.destroy_model_parallel() - - # Load checkpoint A with different TP/PP/expert/CP and save as checkpoint B - # No FPS this time, only FPL - Utils.initialize_model_parallel( - dest_tp, dest_pp, expert_model_parallel_size=dest_exp, context_parallel_size=dest_cp - ) - model_B = initialize_mamba( - 2, - use_glu, - tensor_model_parallel_size=dest_tp, - pipeline_model_parallel_size=dest_pp, - expert_model_parallel_size=dest_exp, - context_parallel_size=dest_cp, - # Sequence parallelism is required when using both expert and tensor parallelism - sequence_parallel=(dest_exp > 1 and dest_pp > 1), - ) - if use_fpsl: - load_strategy = get_default_load_sharded_strategy(ckpt_dir_A) - load_strategy = FullyParallelLoadStrategyWrapper( - load_strategy, - parallel_state.get_data_parallel_group(with_context_parallel=True), - ) - else: - load_strategy = None - state_dict = load( - model_B.sharded_state_dict(prefix=layer_prefix, metadata=metadata), - ckpt_dir_A, - load_strategy, - ) - model_B.load_state_dict( - {k.removeprefix(layer_prefix): v for k, v in state_dict.items()} - ) - save(model_B.sharded_state_dict(prefix=layer_prefix, metadata=metadata), ckpt_dir_B) - Utils.destroy_model_parallel() - - # Test both checkpoints are equal - Utils.initialize_model_parallel(1, 1) - state_dict_A = load_plain_tensors(ckpt_dir_A) - state_dict_B = load_plain_tensors(ckpt_dir_B) - diffs = diff(state_dict_A, state_dict_B) - assert not any(map(bool, diffs)), diffs - Utils.destroy_model_parallel() diff --git a/tests/unit_tests/dist_checkpointing/models/test_mlp_glu.py b/tests/unit_tests/dist_checkpointing/models/test_mlp_glu.py deleted file mode 100644 index 18cfbf67ce..0000000000 --- a/tests/unit_tests/dist_checkpointing/models/test_mlp_glu.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -import inspect -import logging - -import pytest -import torch -from torch.optim import Adam - -from megatron.core import parallel_state -from megatron.core.dist_checkpointing import ShardedTensor, load, load_plain_tensors, save -from megatron.core.dist_checkpointing.dict_utils import diff, nested_values -from megatron.core.dist_checkpointing.optimizer import ( - get_param_id_to_sharded_param_map, - optim_state_to_sharding_state, -) -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.mlp import MLP, apply_swiglu_sharded_factory -from megatron.core.transformer.transformer_config import TransformerConfig -from tests.unit_tests.dist_checkpointing import TempNamedDir -from tests.unit_tests.test_utilities import Utils - - -def initialize_mlp(glu=True): - model_parallel_cuda_manual_seed(123) - pp_size = parallel_state.get_pipeline_model_parallel_world_size() - transformer_config = TransformerConfig( - num_layers=pp_size, - hidden_size=12, - num_attention_heads=4, - use_cpu_initialization=True, - gated_linear_unit=glu, - ) - return MLP( - transformer_config, get_gpt_layer_with_transformer_engine_spec().submodules.mlp.submodules - ) - - -class TestParallelMLPWithGLU: - def setup_method(self, method): - pass - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.parametrize( - "src_tp_pp,dest_tp_pp", - [ - # changing PP is impossible because the number of layers must be the same - ((2, 2), (4, 2)), - ((1, 1), (8, 1)), - ((1, 8), (1, 8)), - ((1, 1), (2, 1)), - ], - ) - @pytest.mark.parametrize('singleton_local_shards', [True, False]) - def test_parallel_reconfiguration_e2e( - self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp, singleton_local_shards - ): - """Test module saving and loading with different TP/PP""" - Utils.initialize_model_parallel(*src_tp_pp) - metadata = {'singleton_local_shards': singleton_local_shards} - - with ( - TempNamedDir(tmp_path_dist_ckpt / 'test_mlp_glu_reconfiguration_model_A') as ckpt_dir_A, - TempNamedDir(tmp_path_dist_ckpt / 'test_mlp_glu_reconfiguration_model_B') as ckpt_dir_B, - ): - # Save checkpoint A - layer_prefix = f'{parallel_state.get_pipeline_model_parallel_rank()}.' - mlp_A = initialize_mlp() - save(mlp_A.sharded_state_dict(prefix=layer_prefix, metadata=metadata), ckpt_dir_A) - Utils.destroy_model_parallel() - - # Load checkpoint A with different TP/PP and save as checkpoint B - Utils.initialize_model_parallel(*dest_tp_pp) - mlp_B = initialize_mlp() - state_dict = load( - mlp_B.sharded_state_dict(prefix=layer_prefix, metadata=metadata), ckpt_dir_A - ) - mlp_B.load_state_dict({k.removeprefix(layer_prefix): v for k, v in state_dict.items()}) - save(mlp_B.sharded_state_dict(prefix=layer_prefix, metadata=metadata), ckpt_dir_B) - Utils.destroy_model_parallel() - - # Test both checkpoints are equal - Utils.initialize_model_parallel(1, 1) - state_dict_A = load_plain_tensors(ckpt_dir_A) - state_dict_B = load_plain_tensors(ckpt_dir_B) - diffs = diff(state_dict_A, state_dict_B) - assert not any(map(bool, diffs)), diffs - - def test_oom_is_handled(self, caplog): - Utils.initialize_model_parallel(Utils.world_size, 1) - dtype = torch.bfloat16 - - # Compute free memory in bytes - device = torch.cuda.current_device() - allocated = torch.cuda.memory_allocated(device) - total = torch.cuda.get_device_properties(device).total_memory - free = total - allocated - - # We should create two tensor which take up between 50% and 100% of free memory, - # so that the torch.cat tries to allocate twice as many and OOMs. - expected_local_num_bytes = free * 0.6 - - local_num_elems = expected_local_num_bytes // torch._utils._element_size(dtype) - local_num_elems = int(local_num_elems // 1024 * 1024) - assert local_num_elems % 1024 == 0 - - local_w_plus_v_shape = (local_num_elems // 512, 512) - local_w_or_v_shape = (local_num_elems // 1024, 512) - - fc1_weight_sh_ten = ShardedTensor.from_rank_offsets( - 'a', - torch.ones(local_w_plus_v_shape, device='cuda', dtype=dtype), - (0, Utils.rank, Utils.world_size), - ) - fc1_factory = apply_swiglu_sharded_factory(fc1_weight_sh_ten, ()) - sharded_state_dict = fc1_factory.build() - assert len(sharded_state_dict) == 2 - assert sharded_state_dict[0].data.shape == local_w_or_v_shape - # NOTE: with singleton_local_shards=True this assert would fail - global shape is - # `(Utils.world_size * local_w_or_v_shape[0], local_w_or_v_shape[1])` - assert sharded_state_dict[0].global_shape[-2:] == ( - Utils.world_size * local_w_plus_v_shape[0], - local_w_or_v_shape[1], - ) - - # Checkpoint load replaces ShardedTensors with tensors. - # Load happens in-place, so we can just use the same tensors - loaded_state_dict = [sh_ten.data for sh_ten in sharded_state_dict] - - # The critical part that should OOM: - with caplog.at_level(logging.WARNING): - fc1_factory.merge_fn(loaded_state_dict) - assert "CUDA OutOfMemoryError encountered during tensors merging" in caplog.text diff --git a/tests/unit_tests/dist_checkpointing/models/test_moe_experts.py b/tests/unit_tests/dist_checkpointing/models/test_moe_experts.py deleted file mode 100644 index c4a19bd3fc..0000000000 --- a/tests/unit_tests/dist_checkpointing/models/test_moe_experts.py +++ /dev/null @@ -1,442 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -import os - -import pytest -import torch -from transformer_engine.pytorch.fp8 import check_fp8_support, fp8_autocast - -from megatron.core import parallel_state -from megatron.core.dist_checkpointing import load, load_plain_tensors, save -from megatron.core.dist_checkpointing.dict_utils import diff -from megatron.core.dist_checkpointing.serialization import ( - get_default_load_sharded_strategy, - get_default_save_sharded_strategy, -) -from megatron.core.dist_checkpointing.strategies.fully_parallel import ( - FullyParallelLoadStrategyWrapper, - FullyParallelSaveStrategyWrapper, -) -from megatron.core.models.gpt.gpt_layer_specs import ( - get_gpt_layer_local_spec, - get_gpt_layer_with_transformer_engine_spec, -) -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLP -from megatron.core.transformer.moe.moe_utils import get_default_model_comm_pgs -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import is_te_min_version -from tests.unit_tests.dist_checkpointing import TempNamedDir -from tests.unit_tests.test_utilities import Utils - -fp8_available, reason_for_no_fp8 = check_fp8_support() - - -def initialize_expert_layer(seed, glu=True, expert_type='sequential', fp8=False, **config_kwargs): - torch.manual_seed(seed) - model_parallel_cuda_manual_seed(seed) - model_comm_pgs = get_default_model_comm_pgs() - - pp_size = parallel_state.get_pipeline_model_parallel_world_size() - num_moe_experts = 8 - num_local_experts = num_moe_experts // parallel_state.get_expert_model_parallel_world_size() - default_config_kwargs = dict( - num_layers=pp_size, - hidden_size=16, - num_attention_heads=4, - num_moe_experts=num_moe_experts, - use_cpu_initialization=True, - gated_linear_unit=glu, - fp8="hybrid" if fp8 else None, - add_bias_linear=False, - ) - default_config_kwargs.update(**config_kwargs) - transformer_config = TransformerConfig(**default_config_kwargs) - if expert_type == 'grouped': - model = GroupedMLP(num_local_experts, transformer_config, model_comm_pgs) - elif expert_type == 'te_grouped': - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - num_experts=num_moe_experts, moe_grouped_gemm=True - ) - model = TEGroupedMLP( - num_local_experts, - transformer_config, - transformer_layer_spec.submodules.mlp.submodules.experts.submodules, - model_comm_pgs, - ) - elif expert_type == 'sequential': - transformer_layer_spec = get_gpt_layer_local_spec( - num_experts=num_moe_experts, moe_grouped_gemm=False - ) - model = SequentialMLP( - num_local_experts, - transformer_config, - transformer_layer_spec.submodules.mlp.submodules.experts.submodules, - model_comm_pgs, - ) - elif expert_type == 'te_sequential': - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - num_experts=num_moe_experts, moe_grouped_gemm=False - ) - model = SequentialMLP( - num_local_experts, - transformer_config, - transformer_layer_spec.submodules.mlp.submodules.experts.submodules, - model_comm_pgs, - ) - else: - raise ValueError( - 'expert_type can only be one of ["sequential", "te_sequential", "grouped",' - ' "te_grouped"]' - ) - return model - - -expert_type = ['sequential', 'grouped'] -src_dest_expert_type = [('sequential', 'grouped'), ('grouped', 'sequential')] -if is_te_min_version("1.7.0.dev0"): - expert_type.append('te_sequential') - src_dest_expert_type.append(('sequential', 'te_sequential')) - src_dest_expert_type.append(('te_sequential', 'sequential')) -if is_te_min_version("1.9.0.dev0"): - expert_type.append('te_grouped') - src_dest_expert_type.append(('te_sequential', 'te_grouped')) - src_dest_expert_type.append(('te_grouped', 'te_sequential')) - - -class TestExpertLayerReconfiguration: - def setup_method(self, method): - pass - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.internal - @pytest.mark.parametrize( - "use_fpsl,src_tp_pp_ep_etp,dest_tp_pp_ep_etp,use_glu", - [ - # changing PP is impossible because the number of layers must be the same - (False, (2, 4, 1, 2), (2, 4, 1, 2), False), - (True, (2, 4, 1, 2), (2, 4, 1, 2), False), - (False, (2, 4, 1, 2), (1, 4, 1, 2), False), - (True, (2, 1, 1, 2), (1, 1, 1, 2), False), - (False, (1, 1, 1, 1), (1, 1, 1, 1), False), - (True, (1, 1, 1, 1), (1, 1, 4, 1), False), - (False, (1, 1, 8, 1), (1, 1, 2, 1), False), - (False, (2, 2, 2, 2), (4, 2, 1, 4), False), - (True, (1, 1, 4, 1), (8, 1, 1, 1), False), - (False, (1, 8, 1, 1), (1, 8, 1, 1), False), - (False, (1, 1, 4, 1), (2, 1, 1, 2), False), - (False, (2, 1, 4, 1), (2, 1, 1, 4), False), - (False, (1, 1, 1, 1), (1, 1, 1, 1), True), - (False, (1, 1, 1, 1), (1, 1, 4, 1), True), - (True, (1, 1, 1, 1), (2, 1, 1, 1), True), - (False, (1, 1, 4, 1), (8, 1, 1, 8), True), - ], - ) - @pytest.mark.parametrize("expert_type", expert_type) - @pytest.mark.parametrize( - "load_order,store_order", - [ - ("tp-ep-dp-pp", "tp-ep-dp-pp"), - # ("tp-ep-dp-pp", "ep-tp-dp-pp"), - # ("ep-tp-dp-pp", "ep-tp-dp-pp"), - # ("ep-tp-dp-pp", "tp-ep-dp-pp"), - ], - ) - @pytest.mark.parametrize('singleton_local_shards', [True, False]) - def test_parallel_reconfiguration_e2e( - self, - tmp_path_dist_ckpt, - src_tp_pp_ep_etp, - dest_tp_pp_ep_etp, - use_glu, - use_fpsl, - expert_type, - load_order, - store_order, - singleton_local_shards, - ): - """Test model saving and loading with different TP/PP/EP/ETP(expert-tensor-parallel)""" - src_tp, src_pp, src_ep, src_etp = src_tp_pp_ep_etp - dest_tp, dest_pp, dest_ep, dest_etp = dest_tp_pp_ep_etp - metadata = {'singleton_local_shards': singleton_local_shards} - # Save checkpoint A - Utils.initialize_model_parallel( - src_tp, - src_pp, - expert_model_parallel_size=src_ep, - expert_tensor_parallel_size=src_etp, - order=store_order, - ) - with ( - TempNamedDir( - tmp_path_dist_ckpt / 'test_expert_layer_reconfiguration_model_A' - ) as ckpt_dir_A, - TempNamedDir( - tmp_path_dist_ckpt / 'test_expert_layer_reconfiguration_model_B' - ) as ckpt_dir_B, - ): - layer_prefix = f'{parallel_state.get_pipeline_model_parallel_rank()}.' - model_A = initialize_expert_layer(1, use_glu, expert_type) - sharded_state_dict = model_A.sharded_state_dict(prefix=layer_prefix, metadata=metadata) - - save_strategy = get_default_save_sharded_strategy() - if use_fpsl: - save_strategy = FullyParallelSaveStrategyWrapper( - save_strategy, - parallel_state.get_data_parallel_group(with_context_parallel=True), - True, - ) - save(sharded_state_dict, ckpt_dir_A, save_strategy) - Utils.destroy_model_parallel() - - # Load checkpoint A with different TP/PP/EP and save as checkpoint B - # No FPS this time, only FPL - Utils.initialize_model_parallel( - dest_tp, - dest_pp, - expert_model_parallel_size=dest_ep, - expert_tensor_parallel_size=dest_etp, - order=load_order, - ) - model_B = initialize_expert_layer(1, use_glu, expert_type) - if use_fpsl: - load_strategy = get_default_load_sharded_strategy(ckpt_dir_A) - load_strategy = FullyParallelLoadStrategyWrapper( - load_strategy, - parallel_state.get_data_parallel_group(with_context_parallel=True), - ) - else: - load_strategy = None - state_dict = load( - model_B.sharded_state_dict(prefix=layer_prefix, metadata=metadata), - ckpt_dir_A, - load_strategy, - ) - model_B.load_state_dict( - {k.removeprefix(layer_prefix): v for k, v in state_dict.items()} - ) - save(model_B.sharded_state_dict(prefix=layer_prefix, metadata=metadata), ckpt_dir_B) - Utils.destroy_model_parallel() - - # Test both checkpoints are equal - Utils.initialize_model_parallel(1, 1) - state_dict_A = load_plain_tensors(ckpt_dir_A) - state_dict_B = load_plain_tensors(ckpt_dir_B) - diffs = diff(state_dict_A, state_dict_B) - assert not any(map(bool, diffs)), diffs - - @pytest.mark.internal - @pytest.mark.parametrize( - "src_tp_pp_exp,dest_tp_pp_exp,use_glu,singleton_local_shards", - [ - # changing PP is impossible because the number of layers must be the same - ((2, 4, 1), (2, 4, 1), False, False), - ((1, 1, 1), (1, 1, 4), False, True), - ((2, 2, 2), (4, 2, 1), False, False), - ((1, 1, 4), (8, 1, 1), False, True), - ((2, 1, 4), (1, 1, 8), False, False), - ((2, 4, 1), (2, 4, 1), True, True), - ((1, 1, 1), (1, 1, 4), True, False), - ((2, 2, 2), (4, 2, 1), True, True), - ((1, 1, 4), (8, 1, 1), True, False), - ((2, 1, 4), (1, 1, 8), True, True), - ], - ) - @pytest.mark.parametrize("src_module,dest_module", src_dest_expert_type) - def test_sequential_grouped_mlp_interchangeable( - self, - tmp_path_dist_ckpt, - src_tp_pp_exp, - dest_tp_pp_exp, - use_glu, - src_module, - dest_module, - singleton_local_shards, - ): - """Test model saving and loading with different TP/PP/expert parallelism""" - src_tp, src_pp, src_exp = src_tp_pp_exp - dest_tp, dest_pp, dest_exp = dest_tp_pp_exp - metadata = {'singleton_local_shards': singleton_local_shards} - # Save checkpoint A - Utils.initialize_model_parallel(src_tp, src_pp, expert_model_parallel_size=src_exp) - with ( - TempNamedDir( - tmp_path_dist_ckpt / 'test_sequential_grouped_mlp_interchangeable_model_A' - ) as ckpt_dir_A, - TempNamedDir( - tmp_path_dist_ckpt / 'test_sequential_grouped_mlp_interchangeable_model_B' - ) as ckpt_dir_B, - ): - layer_prefix = f'{parallel_state.get_pipeline_model_parallel_rank()}.' - model_A = initialize_expert_layer(1, use_glu, expert_type=src_module) - sharded_state_dict = model_A.sharded_state_dict(prefix=layer_prefix, metadata=metadata) - - save_strategy = get_default_save_sharded_strategy() - save(sharded_state_dict, ckpt_dir_A, save_strategy) - Utils.destroy_model_parallel() - - Utils.initialize_model_parallel(dest_tp, dest_pp, expert_model_parallel_size=dest_exp) - model_B = initialize_expert_layer(1, use_glu, expert_type=dest_module) - load_strategy = None - state_dict = load( - model_B.sharded_state_dict(prefix=layer_prefix, metadata=metadata), - ckpt_dir_A, - load_strategy, - ) - model_B.load_state_dict( - {k.removeprefix(layer_prefix): v for k, v in state_dict.items()} - ) - save(model_B.sharded_state_dict(prefix=layer_prefix, metadata=metadata), ckpt_dir_B) - Utils.destroy_model_parallel() - - # Test both checkpoints are equal - Utils.initialize_model_parallel(1, 1) - state_dict_A = load_plain_tensors(ckpt_dir_A) - state_dict_B = load_plain_tensors(ckpt_dir_B) - diffs = diff(state_dict_A, state_dict_B) - assert not any(map(bool, diffs)), diffs - Utils.destroy_model_parallel() - - @pytest.mark.internal - @pytest.mark.skipif( - not is_te_min_version("1.11.0"), - reason="FP8 support of TEGroupedMLP is only available in TE 1.11.0 and later.", - ) - @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) - @pytest.mark.parametrize( - "src_module,dst_module,src_tp_pp_exp,dest_tp_pp_exp", - [ - # Changing tp/pp/dp doesn't affect _extra_state - ('te_sequential', 'te_grouped', (1, 1, 1), (1, 1, 4)), - ('te_sequential', 'te_grouped', (1, 1, 4), (1, 1, 1)), - ('te_grouped', 'te_sequential', (1, 1, 1), (1, 1, 4)), - ('te_grouped', 'te_sequential', (1, 1, 4), (1, 1, 1)), - ], - ) - @pytest.mark.parametrize('singleton_local_shards', [True, False]) - def test_sequential_grouped_mlp_extra_state( - self, - tmp_path_dist_ckpt, - src_tp_pp_exp, - dest_tp_pp_exp, - src_module, - dst_module, - singleton_local_shards, - ): - """Test saving and loading _extra_state""" - src_tp, src_pp, src_exp = src_tp_pp_exp - dest_tp, dest_pp, dest_exp = dest_tp_pp_exp - metadata = {'singleton_local_shards': singleton_local_shards} - use_glu = True - Utils.initialize_model_parallel(src_tp, src_pp, expert_model_parallel_size=src_exp) - with ( - TempNamedDir(tmp_path_dist_ckpt / 'test_grouped_mlp_extra_state_model_A') as ckpt_dir_A, - TempNamedDir(tmp_path_dist_ckpt / 'test_grouped_mlp_extra_state_model_B') as ckpt_dir_B, - fp8_autocast(), - ): - tokens_per_expert = torch.tensor([16] * (8 // src_exp)) - input_tensor = torch.randn(tokens_per_expert.sum(), 16, device="cuda") - probs = torch.rand((tokens_per_expert.sum(),), dtype=torch.float32, device="cuda") - - # Save checkpoint A - layer_prefix = f'{parallel_state.get_pipeline_model_parallel_rank()}.' - model_A = initialize_expert_layer(1, use_glu, expert_type=src_module, fp8=True) - model_A = model_A.cuda() - # fp8 meta is initialized at the first step - model_A(input_tensor, tokens_per_expert, probs) - sharded_state_dict = model_A.sharded_state_dict(prefix=layer_prefix, metadata=metadata) - - save_strategy = get_default_save_sharded_strategy() - save(sharded_state_dict, ckpt_dir_A, save_strategy) - Utils.destroy_model_parallel() - - Utils.initialize_model_parallel(dest_tp, dest_pp, expert_model_parallel_size=dest_exp) - load_strategy = None - - # model_A load checkpoint A - model_A = initialize_expert_layer(1, use_glu, expert_type=src_module, fp8=True) - model_A = model_A.cuda() - state_dict = load( - model_A.sharded_state_dict(prefix=layer_prefix, metadata=metadata), - ckpt_dir_A, - load_strategy, - ) - model_A.load_state_dict( - {k.removeprefix(layer_prefix): v for k, v in state_dict.items()} - ) - - # model_B load checkpoint A - model_B = initialize_expert_layer(1, use_glu, expert_type=dst_module, fp8=True) - model_B = model_B.cuda() - state_dict = load( - model_B.sharded_state_dict(prefix=layer_prefix, metadata=metadata), - ckpt_dir_A, - load_strategy, - ) - model_B.load_state_dict( - {k.removeprefix(layer_prefix): v for k, v in state_dict.items()} - ) - - # Should be bitwise equal - if src_module == "te_grouped": - model_A, model_B = model_B, model_A - # Compare amax_history - torch.testing.assert_close( - torch.cat( - [ - model_A.local_experts[i] - .linear_fc1.fp8_meta["scaling_fwd"] - .amax_history.view(-1, 1) - for i in range(8 // dest_exp) - ], - dim=1, - ).view(model_A.local_experts[0].linear_fc1.fp8_meta["recipe"].amax_history_len, -1), - model_B.linear_fc1.fp8_meta["scaling_fwd"].amax_history, - rtol=0, - atol=0, - ) - # Compare scale - torch.testing.assert_close( - torch.cat( - [ - model_A.local_experts[i] - .linear_fc1.fp8_meta["scaling_fwd"] - .scale.view(-1, 1) - for i in range(8 // dest_exp) - ], - dim=1, - ).view(-1), - model_B.linear_fc1.fp8_meta["scaling_fwd"].scale, - rtol=0, - atol=0, - ) - - Utils.destroy_model_parallel() - - @pytest.mark.internal - @pytest.mark.skipif( - not is_te_min_version("1.9.0"), - reason="TEGroupedMLP is only supported in TE 1.9.0 and later.", - ) - @pytest.mark.parametrize("ep_size", [1, 2]) - def test_te_grouped_linear_torch_native(self, tmp_path_dist_ckpt, ep_size): - """Test saving and loading torch native checkpoints""" - use_glu = True - Utils.initialize_model_parallel(1, 1, expert_model_parallel_size=ep_size) - with TempNamedDir(tmp_path_dist_ckpt / 'test_te_grouped_linear_torch_native') as ckpt_dir: - tokens_per_expert = torch.tensor([16] * (8 // ep_size)) - input_tensor = torch.randn(tokens_per_expert.sum(), 16, device="cuda") - probs = torch.rand((tokens_per_expert.sum(),), dtype=torch.float32, device="cuda") - - # Save checkpoint - model = initialize_expert_layer(1, use_glu, expert_type="te_grouped") - model = model.cuda() - model(input_tensor, tokens_per_expert, probs) - torch.save(model.state_dict(), ckpt_dir / f"model_ep{torch.distributed.get_rank()}.pt") - - # Load checkpoint - state_dict = torch.load(ckpt_dir / f"model_ep{torch.distributed.get_rank()}.pt") - model.load_state_dict(state_dict) - - Utils.destroy_model_parallel() diff --git a/tests/unit_tests/dist_checkpointing/models/test_t5_model.py b/tests/unit_tests/dist_checkpointing/models/test_t5_model.py deleted file mode 100644 index 1e44ee527a..0000000000 --- a/tests/unit_tests/dist_checkpointing/models/test_t5_model.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import pytest -import torch - -from megatron.core import parallel_state as ps -from megatron.core.dist_checkpointing import load, save -from megatron.core.dist_checkpointing.validation import StrictHandling -from megatron.core.models.retro.decoder_spec import ( - get_retro_decoder_layer_local_spec, - get_retro_decoder_layer_te_spec, -) -from megatron.core.models.retro.encoder_spec import ( - get_retro_encoder_layer_local_spec, - get_retro_encoder_layer_te_spec, -) -from megatron.core.models.T5 import T5Model -from megatron.core.models.T5.t5_spec import decoder_model_with_local_spec as t5_decoder_local_spec -from megatron.core.models.T5.t5_spec import ( - decoder_model_with_transformer_engine_default_spec as t5_decoder_te_spec, -) -from megatron.core.models.T5.t5_spec import encoder_model_with_local_spec as t5_encoder_local_spec -from megatron.core.models.T5.t5_spec import ( - encoder_model_with_transformer_engine_default_spec as t5_encoder_te_spec, -) -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.transformer_block import TransformerBlockSubmodules -from megatron.core.transformer.transformer_config import TransformerConfig -from tests.unit_tests.dist_checkpointing import TempNamedDir -from tests.unit_tests.dist_checkpointing.models.common import ( - common_test_parallel_reconfiguration_e2e, -) -from tests.unit_tests.test_utilities import Utils - - -def initialize_t5_model(seed, encoder_decoder_spec_fn, num_layers=8, **config_kwargs): - encoder_spec_fn, decoder_spec_fn = encoder_decoder_spec_fn - torch.manual_seed(seed) - model_parallel_cuda_manual_seed(seed) - - add_encoder = None - add_decoder = None - - encoder_layers_per_pipeline = num_layers - decoder_layers_per_pipeline = num_layers - - pre_process = ps.is_pipeline_first_stage() - post_process = ps.is_pipeline_last_stage() - - default_config_kwargs = dict( - num_layers=num_layers, - hidden_size=16, - num_attention_heads=12, - kv_channels=64, - ffn_hidden_size=64, - use_cpu_initialization=True, - pipeline_dtype=torch.bfloat16, - ) - default_config_kwargs.update(**config_kwargs) - transformer_config = TransformerConfig(**default_config_kwargs) - - en_block_spec = TransformerBlockSubmodules([encoder_spec_fn()] * encoder_layers_per_pipeline) - de_block_spec = TransformerBlockSubmodules([decoder_spec_fn()] * decoder_layers_per_pipeline) - model = T5Model( - encoder_config=transformer_config, - config=transformer_config, - transformer_encoder_layer_spec=en_block_spec, - transformer_decoder_layer_spec=de_block_spec, - vocab_size=29184, - max_sequence_length=4, - pre_process=pre_process, - post_process=post_process, - add_encoder=add_encoder, - add_decoder=add_decoder, - ) - - with torch.no_grad(): - for p in model.parameters(): - p.random_() - return model - - -class TestT5Model: - def setup_method(self, method): - pass - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.parametrize('src_spec_type', ['te', 'local']) - @pytest.mark.parametrize('dst_spec_type', ['te', 'local']) - @pytest.mark.parametrize('model_type', ['t5']) - def test_sharded_state_dict_save_load( - self, tmp_path_dist_ckpt, src_spec_type, dst_spec_type, model_type - ): - enc_dec_spec_fn = { - 'te': { - 't5': (t5_encoder_te_spec, t5_decoder_te_spec), - 'retro': (get_retro_encoder_layer_te_spec, get_retro_decoder_layer_te_spec), - }, - 'local': { - 't5': (t5_encoder_local_spec, t5_decoder_local_spec), - 'retro': (get_retro_encoder_layer_local_spec, get_retro_decoder_layer_local_spec), - }, - } - src_encoder_decoder_spec_fn = enc_dec_spec_fn[src_spec_type][model_type] - dst_encoder_decoder_spec_fn = enc_dec_spec_fn[dst_spec_type][model_type] - - Utils.initialize_model_parallel(1, 1) - gpt_model = initialize_t5_model(1, src_encoder_decoder_spec_fn) - with TempNamedDir(tmp_path_dist_ckpt / 'test_gpt_model') as ckpt_dir: - # Save - sharded_state_dict = gpt_model.sharded_state_dict() - save(sharded_state_dict, ckpt_dir) - - # Load - gpt_model = initialize_t5_model(2, dst_encoder_decoder_spec_fn) - sharded_state_dict = gpt_model.sharded_state_dict() - - state_dict, missing_keys, unexpected_keys = load( - sharded_state_dict, ckpt_dir, strict=StrictHandling.RETURN_ALL - ) - # Potential mismatch is because of extra states which is ok - assert all('_extra_state' in k for k in missing_keys) - assert all('_extra_state' in k for k in unexpected_keys) - gpt_model.load_state_dict(state_dict) - - Utils.destroy_model_parallel() - - -class TestT5ModelReconfiguration: - - # def teardown_method(self, method): - # Utils.destroy_model_parallel() - - @pytest.mark.parametrize('src_spec_type', ['local']) # ['te', 'local']) - @pytest.mark.parametrize('dst_spec_type', ['local']) # ['te', 'local']) - @pytest.mark.parametrize('model_type', ['t5']) - @pytest.mark.parametrize( - ('use_fpsl', 'src_tp_pp_encpp', 'dest_tp_pp_encpp'), [(False, (1, 1, None), (1, 1, None))] - ) - def test_parallel_reconfiguration_e2e( - self, - tmp_path_dist_ckpt, - src_tp_pp_encpp, - dest_tp_pp_encpp, - use_fpsl, - src_spec_type, - dst_spec_type, - model_type, - ): - """Test model saving and loading with different TP/PP""" - - *src_tp_pp, src_encpp = src_tp_pp_encpp - *dest_tp_pp, dst_encpp = dest_tp_pp_encpp - - enc_dec_spec_fn = { - 'te': { - 't5': (t5_encoder_te_spec, t5_decoder_te_spec), - 'retro': (get_retro_encoder_layer_te_spec, get_retro_decoder_layer_te_spec), - }, - 'local': { - 't5': (t5_encoder_local_spec, t5_decoder_local_spec), - 'retro': (get_retro_encoder_layer_local_spec, get_retro_decoder_layer_local_spec), - }, - } - - common_test_parallel_reconfiguration_e2e( - initialize_t5_model, - tmp_path_dist_ckpt, - src_tp_pp, - dest_tp_pp, - enc_dec_spec_fn[src_spec_type][model_type], - enc_dec_spec_fn[dst_spec_type][model_type], - use_fpsl, - ) diff --git a/tests/unit_tests/dist_checkpointing/test_async_save.py b/tests/unit_tests/dist_checkpointing/test_async_save.py deleted file mode 100644 index b24478984d..0000000000 --- a/tests/unit_tests/dist_checkpointing/test_async_save.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -from unittest import mock - -import pytest -import torch - -from megatron.core.dist_checkpointing import ShardedTensor, load, save -from megatron.core.dist_checkpointing.dict_utils import diff -from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue -from megatron.core.dist_checkpointing.strategies.filesystem_async import FileSystemWriterAsync -from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy -from tests.unit_tests.dist_checkpointing import TempNamedDir -from tests.unit_tests.test_utilities import Utils - - -def write_data_os_err_mock_fn( - transform_list, local_proc_idx, write_bucket, results_queue, count_queue, use_fsync, **kwargs -): - """Raises an error on worker #2 during storage save""" - try: - if local_proc_idx == 2: - raise OSError('worker #2 critical failure') - output = (local_proc_idx, []) - except Exception as e: - output = (local_proc_idx, e) - results_queue.put(output) - count_queue.get() - count_queue.task_done() - - -class TestAsyncSave: - def setup_method(self, method): - pass - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.parametrize('persistent', [True, False]) - @pytest.mark.parametrize('abort', [True, False]) - def test_async_is_equivalent_to_sync(self, tmp_path_dist_ckpt, persistent, abort): - Utils.initialize_model_parallel(2, 4) - - sharded_state_dict = { - 'sd_keyA': ShardedTensor.from_rank_offsets( - 'keyA', torch.ones(2, 4), replica_id=Utils.rank - ), - 'sd_keyB': ShardedTensor.from_rank_offsets( - 'keyB', torch.ones(3, 5, 7), replica_id=Utils.world_size - Utils.rank - 1 - ), - } - - with ( - TempNamedDir(tmp_path_dist_ckpt / 'test_equivalence_async') as async_ckpt_dir, - TempNamedDir(tmp_path_dist_ckpt / 'test_equivalence_sync') as sync_ckpt_dir, - ): - # async - async_calls = AsyncCallsQueue(persistent) - async_request = save(sharded_state_dict, async_ckpt_dir, async_sharded_save=True) - async_calls.schedule_async_request(async_request) - - # sync - save(sharded_state_dict, sync_ckpt_dir, async_sharded_save=False) - - # finalize async - async_calls.maybe_finalize_async_calls(blocking=True) - - # load and compare - loaded_async_state_dict = load(sharded_state_dict, async_ckpt_dir) - loaded_sync_state_dict = load(sharded_state_dict, sync_ckpt_dir) - diffs = diff(loaded_async_state_dict, loaded_sync_state_dict) - assert not any(map(bool, diffs)), diffs - async_calls.close(abort=abort) - - Utils.destroy_model_parallel() - - @pytest.mark.parametrize('async_save', [False, True]) - @pytest.mark.parametrize('worker_fn', [write_data_os_err_mock_fn]) - def test_errors_are_reported(self, tmp_path_dist_ckpt, async_save, worker_fn): - Utils.initialize_model_parallel(2, 4) - sharded_state_dict = { - f'key{i}': ShardedTensor.from_rank_offsets(f'key{i}_rank{Utils.rank}', torch.ones(2, 4)) - for i in range(4) # make sure there is enough non-empty saving workers - } - - with TempNamedDir(tmp_path_dist_ckpt / 'test_errors_are_reported') as ckpt_dir: - async_calls = AsyncCallsQueue() - save_strategy = TorchDistSaveShardedStrategy('torch_dist', 1, thread_count=8) - - try: - orig_fn = FileSystemWriterAsync.write_preloaded_data - FileSystemWriterAsync.write_preloaded_data = worker_fn - with pytest.raises(RuntimeError) as exc_info: - if async_save: - async_request = save( - sharded_state_dict, ckpt_dir, save_strategy, async_sharded_save=True - ) - async_calls.schedule_async_request(async_request) - async_calls.maybe_finalize_async_calls(blocking=True) - else: - save(sharded_state_dict, ckpt_dir, save_strategy) - assert 'Worker failure' in str(exc_info.value) - - finally: - FileSystemWriterAsync.write_preloaded_data = orig_fn - - Utils.destroy_model_parallel() diff --git a/tests/unit_tests/dist_checkpointing/test_checkpointable.py b/tests/unit_tests/dist_checkpointing/test_checkpointable.py deleted file mode 100644 index 7ae5df7121..0000000000 --- a/tests/unit_tests/dist_checkpointing/test_checkpointable.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -import pytest -import torch -from packaging import version -from torch.distributed.checkpoint import FileSystemReader, TensorStorageMetadata - -from megatron.core.dist_checkpointing import ShardedTensor -from megatron.core.dist_checkpointing.strategies.checkpointable import ( - CheckpointableShardedTensor, - LocalShardsContainer, -) -from megatron.core.utils import is_torch_min_version -from tests.unit_tests.dist_checkpointing import TempNamedDir -from tests.unit_tests.test_utilities import Utils - - -@pytest.mark.skipif( - not is_torch_min_version("2.6a0"), - reason="CheckpointableShardedTensor requires PyTorch 2.6 or later", -) -class TestCheckpointableProtocol: - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_sharded_tensor_checkpointing(self, tmp_path_dist_ckpt): - """Test sharded tensor checkpointing with pure DCP.""" - - def get_sd(val=3): - sh_ten = ShardedTensor.from_rank_offsets( - 'b_ten', torch.ones(3) * Utils.rank + val, (0, Utils.rank, Utils.world_size) - ) - return {'b_ten_sd': CheckpointableShardedTensor.from_sh_ten(sh_ten)} - - state_dict = get_sd(3) - with TempNamedDir(tmp_path_dist_ckpt / 'test_sharded_objects') as ckpt_dir: - torch.distributed.checkpoint.save(state_dict, checkpoint_id=ckpt_dir) - torch.distributed.barrier() - - loaded_state_dict = get_sd(4) - assert torch.all(loaded_state_dict['b_ten_sd']._sh_ten.data == Utils.rank + 4) - torch.distributed.checkpoint.load(loaded_state_dict, checkpoint_id=ckpt_dir) - assert torch.all(loaded_state_dict['b_ten_sd']._sh_ten.data == Utils.rank + 3) - - def test_multiple_local_shards(self, tmp_path_dist_ckpt): - def get_sd(val=3): - sh_ten_part_one = ShardedTensor.from_rank_offsets( - 'b_ten', torch.ones(3) * Utils.rank + val, (0, Utils.rank, Utils.world_size * 2) - ) - sh_ten_part_two = ShardedTensor.from_rank_offsets( - 'b_ten', - torch.ones(3) * Utils.rank + val, - (0, Utils.world_size + Utils.rank, Utils.world_size * 2), - ) - - return { - 'b_ten_sd': LocalShardsContainer( - [ - CheckpointableShardedTensor.from_sh_ten(sh_ten_part_one), - CheckpointableShardedTensor.from_sh_ten(sh_ten_part_two), - ] - ) - } - - state_dict = get_sd(3) - with TempNamedDir(tmp_path_dist_ckpt / 'test_sharded_objects') as ckpt_dir: - torch.distributed.checkpoint.save(state_dict, checkpoint_id=ckpt_dir) - torch.distributed.barrier() - - metadata = FileSystemReader(ckpt_dir).read_metadata() - assert isinstance(metadata.state_dict_metadata['b_ten_sd'], TensorStorageMetadata) - - loaded_state_dict = get_sd(4) - for shard in loaded_state_dict['b_ten_sd']._local_shards: - assert torch.all(shard._sh_ten.data == Utils.rank + 4) - torch.distributed.checkpoint.load(loaded_state_dict, checkpoint_id=ckpt_dir) - for shard in loaded_state_dict['b_ten_sd']._local_shards: - assert torch.all(shard._sh_ten.data == Utils.rank + 3) diff --git a/tests/unit_tests/dist_checkpointing/test_flattened_resharding.py b/tests/unit_tests/dist_checkpointing/test_flattened_resharding.py deleted file mode 100644 index 1485eebe10..0000000000 --- a/tests/unit_tests/dist_checkpointing/test_flattened_resharding.py +++ /dev/null @@ -1,268 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import io -from contextlib import nullcontext - -import numpy as np -import pytest -import torch -from torch.distributed.checkpoint import CheckpointException - -from megatron.core import parallel_state -from megatron.core.dist_checkpointing import ShardedTensor, load, save -from megatron.core.dist_checkpointing.core import CheckpointingException, maybe_load_config -from megatron.core.dist_checkpointing.dict_utils import diff -from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedTensorFactory -from megatron.core.dist_checkpointing.serialization import load_tensors_metadata -from megatron.core.dist_checkpointing.strategies.resharding import ( - apply_nd_flattened_tensors_reformulation, - restore_nd_flattened_tensors_formulation, -) -from megatron.core.dist_checkpointing.strategies.torch import get_reformulation_metadata -from megatron.core.dist_checkpointing.validation import ( - determine_global_metadata, - validate_sharding_integrity, -) -from tests.unit_tests.dist_checkpointing import TempNamedDir -from tests.unit_tests.test_utilities import Utils - - -class TestFlattenedResharding: - def setup_method(self, method): - pass - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.parametrize( - ('src_tp_pp', 'dest_tp_pp'), - [((2, 4), (2, 4)), ((2, 4), (2, 2)), ((2, 4), (4, 2)), ((8, 1), (1, 2))], - ) - def test_partition_change_save_load(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp): - Utils.initialize_model_parallel(*src_tp_pp) - with TempNamedDir( - tmp_path_dist_ckpt / 'test_flattened_partition_change_save_load' - ) as ckpt_dir: - - state_dict = self._build_state_dict() - - save(state_dict, ckpt_dir) - - # change TPxPP - Utils.destroy_model_parallel() - Utils.initialize_model_parallel(*dest_tp_pp) - loaded_state_dict = load(self._build_state_dict(random=True), ckpt_dir) - expected_state_dict = {k: v.data for k, v in self._build_state_dict().items()} - - diffs = diff(expected_state_dict, loaded_state_dict) - assert not any(diffs), diffs - - Utils.destroy_model_parallel() - - @pytest.mark.parametrize( - ('src_tp_pp', 'dest_tp_pp', 'expected_ckpt_offsets_by_rank'), - [ - ( - (2, 4), - (2, 2), - { - 0: [(0, 0, 0), (0, 0, 10)], # TP 0, DP 0, PP 0 - 1: [(4, 0, 0), (4, 0, 10)], # TP 1, DP 0, PP 0 - 2: [(0, 0, 0), (0, 0, 10)], # TP 0, DP 1, PP 0 - 3: [(4, 0, 0), (4, 0, 10)], # TP 1, DP 1, PP 0 - 4: [(0, 0, 20), (0, 0, 30)], # TP 0, DP 0, PP 1 - 5: [(4, 0, 20), (4, 0, 30)], # TP 1, DP 0, PP 1 - 6: [(0, 0, 20), (0, 0, 30)], # TP 0, DP 1, PP 1 - 7: [(4, 0, 20), (4, 0, 30)], # TP 1, DP 1, PP 1 - }, - ), - ((8, 1), (1, 2), {rank: [(tp, 0, 0) for tp in range(8)] for rank in range(8)}), - ], - ) - def test_reformulate_nd_flattened_tensors( - self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp, expected_ckpt_offsets_by_rank - ): - Utils.initialize_model_parallel(*src_tp_pp, order='tp-dp-pp') - with TempNamedDir(tmp_path_dist_ckpt / 'test_reformulate_nd_flattened_tensors') as ckpt_dir: - - state_dict = self._build_state_dict() - - ckpt_local_shape = state_dict['sd_key_flat'].local_shape - - save(state_dict, ckpt_dir) - - # change TPxPP - Utils.destroy_model_parallel() - Utils.initialize_model_parallel(*dest_tp_pp, order='tp-dp-pp') - load_state_dict = self._build_state_dict(random=True) - - reformulation_metadata = get_reformulation_metadata(load_state_dict, ckpt_dir) - reformulated_state_dict, formulation_restore_data = ( - apply_nd_flattened_tensors_reformulation(load_state_dict, reformulation_metadata) - ) - assert isinstance(reformulated_state_dict['sd_key_unflat'], ShardedTensor) - assert isinstance(reformulated_state_dict['sd_key_flat'], dict) - - assert reformulated_state_dict['sd_key_flat'].keys() == set( - (offset, ckpt_local_shape) for offset in expected_ckpt_offsets_by_rank[Utils.rank] - ), ( - reformulated_state_dict['sd_key_flat'].keys(), - ckpt_local_shape, - expected_ckpt_offsets_by_rank[Utils.rank], - ) - - # We can even load the reformulated state dict with a high-level API - loaded_state_dict = load( - reformulated_state_dict, ckpt_dir, validate_access_integrity=False - ) - loaded_state_dict = restore_nd_flattened_tensors_formulation( - loaded_state_dict, formulation_restore_data - ) - expected_state_dict = {k: v.data for k, v in self._build_state_dict().items()} - diffs = diff(expected_state_dict, loaded_state_dict) - assert not any(diffs), diffs - - Utils.destroy_model_parallel() - - @pytest.mark.parametrize(('src_tp_pp',), [((2, 4),), ((8, 1),), ((1, 1),), ((1, 4),)]) - def test_load_tensor_metadata(self, tmp_path_dist_ckpt, src_tp_pp): - Utils.initialize_model_parallel(*src_tp_pp, order='tp-dp-pp') - with TempNamedDir(tmp_path_dist_ckpt / 'test_reformulate_nd_flattened_tensors') as ckpt_dir: - - state_dict = self._build_state_dict() - - save(state_dict, ckpt_dir) - - # change TPxPP - Utils.destroy_model_parallel() - Utils.initialize_model_parallel(1, 1) - - sharded_metadata = load_tensors_metadata(ckpt_dir) - - for attr_name in ('local_shape', 'global_shape'): - flat_val = getattr(sharded_metadata['flat'], attr_name) - unflat_val = getattr(sharded_metadata['unflat'], attr_name) - assert flat_val == unflat_val, (attr_name, flat_val, unflat_val) - - for sh_ten in sharded_metadata.values(): - sh_ten.replica_id = Utils.rank - loaded_state_dict = load(sharded_metadata, ckpt_dir) - assert torch.all( - loaded_state_dict['unflat'] == torch.arange(8 * 5 * 40).reshape(8, 5, 40) - ) - assert torch.all(loaded_state_dict['flat'] == torch.arange(8 * 5 * 40)) - - Utils.destroy_model_parallel() - - def _build_state_dict(self, random=False): - tp_rank = parallel_state.get_tensor_model_parallel_rank() - tp_size = parallel_state.get_tensor_model_parallel_world_size() - pp_rank = parallel_state.get_pipeline_model_parallel_rank() - pp_size = parallel_state.get_pipeline_model_parallel_world_size() - dp_rank = parallel_state.get_data_parallel_rank() - dp_size = parallel_state.get_data_parallel_world_size() - - init_fn = torch.rand if random else torch.arange - global_ten = init_fn(8 * 5 * 40).reshape(8, 5, 40) - local_ten = global_ten - local_ten = local_ten.chunk(tp_size, dim=0)[tp_rank] - local_ten = local_ten.chunk(pp_size, dim=2)[pp_rank] - assert local_ten.shape == (8 // tp_size, 5, 40 // pp_size) - - local_ten_size_by_dp = local_ten.numel() - assert local_ten_size_by_dp % dp_size == 0, (local_ten_size_by_dp, dp_size) - local_ten_size_by_dp = local_ten_size_by_dp // dp_size - # make a bit shifted DP slices so that they are not equal - start_jitter = dp_rank - end_jitter = dp_rank + 1 if dp_rank + 1 < dp_size else 0 - local_dp_slice = slice( - local_ten_size_by_dp * dp_rank + start_jitter, - local_ten_size_by_dp * (dp_rank + 1) + end_jitter, - ) - local_flat_ten = local_ten.flatten()[local_dp_slice] - if dp_rank == dp_size - 1: - assert local_flat_ten.numel() == local_ten_size_by_dp - dp_rank - else: - assert local_flat_ten.numel() == local_ten_size_by_dp + 1 - - state_dict = { - 'sd_key_unflat': ShardedTensor.from_rank_offsets( - 'unflat', - local_ten, - (0, tp_rank, tp_size), - (2, pp_rank, pp_size), - replica_id=dp_rank, - ), - 'sd_key_flat': ShardedTensor.from_rank_offsets_flat( - 'flat', - local_flat_ten, - local_ten.shape, - (0, tp_rank, tp_size), - (2, pp_rank, pp_size), - flattened_range=local_dp_slice, - ), - } - return state_dict - - def test_flattened_tensors_are_properly_validated(self, tmp_path_dist_ckpt): - Utils.initialize_model_parallel() - # Global tensor of shape (6, 6) is built from: - # ranks 0, 1, 2 tensors of length 1, 2, 3 - # and then ranks 3, ..., 7 tensors of length 6 - local_flat_ten = torch.ones(Utils.rank + 1 if Utils.rank <= 2 else 6) * Utils.rank - - global_flattened_len = 6 + (Utils.world_size - 3) * 6 - if Utils.world_size == 8: - assert global_flattened_len == 1 + 2 + 3 + 5 * 6 - local_ten_shape = (1, 6) - else: - local_ten_shape = (global_flattened_len,) - - if Utils.rank == 0: - local_dp_slice_start = 0 - elif Utils.rank == 1: - local_dp_slice_start = 1 - elif Utils.rank == 2: - local_dp_slice_start = 3 - else: - local_dp_slice_start = 0 - local_dp_slice = slice(local_dp_slice_start, local_dp_slice_start + len(local_flat_ten)) - - state_dict = { - 'sd_key_flat': ShardedTensor.from_rank_offsets_flat( - 'flat', - local_flat_ten, - local_ten_shape, - *((0, max(0, Utils.rank - 2), 6),) if Utils.world_size == 8 else (), - flattened_range=local_dp_slice, - replica_id=0 - ) - } - validate_sharding_integrity(determine_global_metadata(state_dict)[1]) - if Utils.rank == 1: - old_state_dict = state_dict - state_dict = {} - - with ( - pytest.raises(CheckpointingException) if Utils.rank == 0 else nullcontext() - ) as exc_info: - validate_sharding_integrity(determine_global_metadata(state_dict)[1]) - if Utils.rank == 0: - assert 'Flattened ranges dont cover the whole shard ShardedTensor' in str( - exc_info.value - ) - - if Utils.rank == 1: - state_dict = old_state_dict - - if Utils.rank == 4: - state_dict = {} - - with ( - pytest.raises(CheckpointingException) if Utils.rank == 0 else nullcontext() - ) as exc_info: - validate_sharding_integrity(determine_global_metadata(state_dict)[1]) - if Utils.rank == 0: - assert 'Invalid access pattern' in str(exc_info.value) - - Utils.destroy_model_parallel() diff --git a/tests/unit_tests/dist_checkpointing/test_fp8.py b/tests/unit_tests/dist_checkpointing/test_fp8.py deleted file mode 100644 index 4fb89a8265..0000000000 --- a/tests/unit_tests/dist_checkpointing/test_fp8.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import pytest -import torch -from transformer_engine.pytorch.float8_tensor import Float8Tensor - -from megatron.core.dist_checkpointing import ShardedTensor, load, save -from megatron.core.dist_checkpointing.serialization import ( - get_default_load_sharded_strategy, - get_default_save_sharded_strategy, -) -from megatron.core.dist_checkpointing.strategies.fully_parallel import ( - FullyParallelLoadStrategyWrapper, - FullyParallelSaveStrategyWrapper, -) -from tests.unit_tests.dist_checkpointing import TempNamedDir -from tests.unit_tests.test_utilities import Utils - - -def to_float8(tensor: torch.Tensor) -> Float8Tensor: - """Convert a tensor to FP8 format.""" - try: - return Float8Tensor.to_float8(tensor) - except Exception as e: - # Handle the case where the method fails (due to API changes in TransformerEngine) - # https://github.com/NVIDIA/TransformerEngine/commit/544dd14b4301beb47136f273deff3f532cdde181 - import transformer_engine_torch as tex - from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer - - fp8_dtype = tex.DType.kFloat8E4M3 - scale = 1.0 - - # Create a quantizer for FP8 conversion - quantizer = Float8Quantizer( - scale=torch.full([1], scale, dtype=torch.float32, device="cuda"), - amax=torch.empty([1], dtype=torch.float32, device="cuda"), - fp8_dtype=fp8_dtype, - ) - - # Return the quantized tensor - return quantizer(tensor.cuda()) - - -class TestFP8: - @pytest.mark.parametrize('dtype', ['bf16', 'fp16', 'fp8']) - @pytest.mark.parametrize('src_rank', [0, 6]) - def test_simple_broadcast(self, dtype, src_rank): - Utils.initialize_model_parallel() - - def get_ten(dtype: str = 'fp8'): - if dtype == 'fp8': - return to_float8(torch.full((3,), Utils.rank, dtype=torch.bfloat16, device='cuda')) - elif dtype == 'bf16': - return torch.full((3,), Utils.rank, dtype=torch.bfloat16, device='cuda') - elif dtype == 'fp16': - return torch.full((3,), Utils.rank, dtype=torch.float16, device='cuda') - else: - raise NotImplementedError(dtype) - - ten = get_ten(dtype) - - # because of a bug in TE, with the cast broadcast fails - if isinstance(ten, Float8Tensor): - ten = ten.dequantize() - torch.distributed.broadcast(ten, src=src_rank) - assert torch.all(ten == src_rank) - - @pytest.mark.parametrize( - ('use_fpsl', 'src_tp_pp', 'dest_tp_pp', 'load_exchange_algo'), - [ - (True, (2, 4), (2, 4), 'broadcast'), - (True, (2, 4), (2, 4), 'gather_rounds'), - (False, (2, 4), (2, 4), None), - ], - ) - def test_fp8_save_load( - self, tmp_path_dist_ckpt, use_fpsl, src_tp_pp, dest_tp_pp, load_exchange_algo - ): - Utils.initialize_model_parallel(*src_tp_pp) - - def get_fp8_tensor(fill_val=1): - return to_float8(torch.full((3,), fill_val, dtype=torch.bfloat16, device='cuda')) - - def get_state_dict(fill_val=1): - return { - 'a': ShardedTensor.from_rank_offsets( - 'a', get_fp8_tensor(fill_val), (0, Utils.rank, Utils.world_size), replica_id=0 - ), - 'b': ShardedTensor.from_rank_offsets( - 'b', get_fp8_tensor(fill_val), replica_id=Utils.rank - ), - 'c': ShardedTensor.from_rank_offsets( - 'c', get_fp8_tensor(fill_val), replica_id=Utils.rank - ), - } - - with TempNamedDir(tmp_path_dist_ckpt / 'test_fp8_save_load', sync=True) as ckpt_dir: - save_strategy = get_default_save_sharded_strategy() - if use_fpsl: - save_strategy = FullyParallelSaveStrategyWrapper(save_strategy, None, True) - save(get_state_dict(4), ckpt_dir, save_strategy) - - Utils.destroy_model_parallel() - Utils.initialize_model_parallel(*dest_tp_pp) - - if use_fpsl: - load_strategy = get_default_load_sharded_strategy(ckpt_dir) - load_strategy = FullyParallelLoadStrategyWrapper( - load_strategy, None, False, load_exchange_algo - ) - else: - load_strategy = None - - loaded_state_dict = load(get_state_dict(8), ckpt_dir, load_strategy) - assert torch.all(loaded_state_dict['a'] == 4) - assert torch.all(loaded_state_dict['b'] == 4) - Utils.destroy_model_parallel() diff --git a/tests/unit_tests/dist_checkpointing/test_fully_parallel.py b/tests/unit_tests/dist_checkpointing/test_fully_parallel.py deleted file mode 100644 index 494eaefb44..0000000000 --- a/tests/unit_tests/dist_checkpointing/test_fully_parallel.py +++ /dev/null @@ -1,636 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -import inspect -from collections import defaultdict -from pathlib import Path -from types import MethodType -from typing import Dict, List, Tuple -from unittest import mock - -import pytest -import torch -import torch.distributed - -from megatron.core import parallel_state -from megatron.core.dist_checkpointing import ShardedTensor -from megatron.core.dist_checkpointing.dict_utils import ( - dict_list_map_outplace, - map_reduce, - nested_values, -) -from megatron.core.dist_checkpointing.exchange_utils import ( - _get_empty_tensor_for_exchange, - distribute_shards_to_ranks, -) -from megatron.core.dist_checkpointing.mapping import ( - ShardedObject, - ShardedStateDict, - ShardedTensorFactory, - is_main_replica, -) -from megatron.core.dist_checkpointing.serialization import ( - get_default_load_sharded_strategy, - get_default_save_sharded_strategy, -) -from megatron.core.dist_checkpointing.strategies.base import ( - LoadShardedStrategy, - SaveShardedStrategy, - StrategyAction, - get_default_strategy, -) -from megatron.core.dist_checkpointing.strategies.fully_parallel import ( - FullyParallelLoadStrategyWrapper, - FullyParallelSaveStrategyWrapper, - _sharded_tensor_shard_id, -) -from megatron.core.dist_checkpointing.strategies.torch import ( - MCoreLoadPlanner, - TorchDistSaveShardedStrategy, -) -from megatron.core.utils import get_pg_rank -from tests.unit_tests.dist_checkpointing import TempNamedDir -from tests.unit_tests.test_utilities import Utils - - -class MockSaveStrategy(SaveShardedStrategy): - def __init__(self): - super().__init__('mock', 1) - self.save_keys = set() - - def save(self, sharded_state_dict, ckpt_dir): - for sh_ten in nested_values(sharded_state_dict): - if is_main_replica(sh_ten.replica_id): - self.save_keys.add(sh_ten.key) - - -class MockLoadStrategy(LoadShardedStrategy): - def __init__(self, device='cpu'): - super().__init__() - self.device = device - self.load_keys = set() - - def load(self, sharded_state_dict, ckpt_dir): - for sh_ten in nested_values(sharded_state_dict): - if is_main_replica(sh_ten.replica_id): - self.load_keys.add(sh_ten.key) - - def load_rand(x): - assert isinstance(x, ShardedTensor) or isinstance(x, ShardedObject) - if isinstance(x, ShardedTensor): - x.init_data(self.device) - x.data.fill_(Utils.rank) - return x.data - else: - x.data = [Utils.rank] - return x.data - - return dict_list_map_outplace(load_rand, sharded_state_dict) - - def load_tensors_metadata(self, checkpoint_dir: Path): - pass - - def check_backend_compatibility(self, loaded_version): - pass - - def check_version_compatibility(self, loaded_version): - pass - - -class TestFullyParallelSaveAndLoad: - def setup_method(self, method): - pass - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @staticmethod - def get_sharded_state_dict(): - return { - 'sd_key_tp_repl1': ShardedTensor.from_rank_offsets( - 'key_TP_repl1', - torch.ones(10), - ( - 0, - parallel_state.get_tensor_model_parallel_rank(), - parallel_state.get_tensor_model_parallel_world_size(), - ), - replica_id=parallel_state.get_data_parallel_rank(with_context_parallel=True), - ), - 'sd_key_tp_repl2': ShardedTensor.from_rank_offsets( - 'key_TP_repl2', - torch.ones(10), - ( - 0, - parallel_state.get_tensor_model_parallel_rank(), - parallel_state.get_tensor_model_parallel_world_size(), - ), - replica_id=parallel_state.get_data_parallel_rank(with_context_parallel=True), - ), - 'sd_keyB': ShardedTensor.from_rank_offsets( - 'keyB', torch.ones(20), (0, Utils.rank, Utils.world_size) - ), - 'sd_keyE_no_C': ShardedTensor.from_rank_offsets( - 'keyC', torch.ones(100), replica_id=Utils.rank - ), - 'sd_keyX_no_D': ShardedTensor.from_rank_offsets( - 'keyD', torch.ones(1000), replica_id=Utils.rank - ), - 'sd_keyC_no_E': ShardedTensor.from_rank_offsets( - 'keyE', torch.ones(100), replica_id=Utils.rank - ), - } - - @pytest.mark.parametrize("parallelization_along_dp", [False, True]) - def test_save_distribution(self, parallelization_along_dp, tmp_path_dist_ckpt): - Utils.initialize_model_parallel(2, 1) - state_dict = self.get_sharded_state_dict() - - # Ranks assignment: - # 1. Lowest coverage - # 2. Largest tensor - # 3. Shard id (key) - if not parallelization_along_dp: - expected_key_to_saving_ranks = { - 'keyB': list( - range(Utils.world_size) - ), # everyone must save (disjoint shards, coverage == 1) - 'key_TP_repl1': [0, 1], # lowest coverage (4), first TP domain - 'key_TP_repl2': [2, 3], # lowest coverage (4), second TP domain - 'keyD': [4], # largest tensor - 'keyC': [5], # second largest tensor - 'keyE': [6], # second largest tensor - } - else: - if parallel_state.get_tensor_model_parallel_rank() == 0: - expected_key_to_saving_ranks = { - # everyone must save (disjoint shards, coverage == 1): - 'keyB': list( - range( - parallel_state.get_data_parallel_world_size(with_context_parallel=True) - ) - ), - # this time, TP sharded tensors have the same coverage as fully replicated! - 'keyD': [0], # largest tensor - 'keyC': [1], # second largest tensor - 'keyE': [2], # second largest tensor - 'key_TP_repl1': [3], # smallest tensor - 'key_TP_repl2': [3], # smallest tensor, last rank is the least occupied - } - else: - expected_key_to_saving_ranks = { - # everyone must save (disjoint shards, coverage == 1): - 'keyB': list( - range( - parallel_state.get_data_parallel_world_size(with_context_parallel=True) - ) - ), - # tensors C, D, E are absent in this DP group - 'key_TP_repl1': [0], # smallest tensor - 'key_TP_repl2': [1], # smallest tensor, last rank is the least occupied - } - - parallelization_group = ( - parallel_state.get_data_parallel_group(with_context_parallel=True) - if parallelization_along_dp - else torch.distributed.group.WORLD - ) - dp_rank = get_pg_rank(parallelization_group) - expected_keys_saved_by_current_rank = { - k for k, v in expected_key_to_saving_ranks.items() if dp_rank in v - } - - # Run save and tests - mock_strategy = MockSaveStrategy() - save_strategy = FullyParallelSaveStrategyWrapper( - mock_strategy, parallelization_group, do_cache_distribution=True - ) - with TempNamedDir(tmp_path_dist_ckpt / 'mock_dir') as ckpt_dir_A: - save_strategy.save(state_dict, ckpt_dir_A) - key_to_saving_rank = dict( - map_reduce( - save_strategy.cached_distribution.main_rank_for_shard.items(), - lambda shard_rank: shard_rank[0][0], - lambda shard_rank: shard_rank[1], - ) - ) - assert expected_key_to_saving_ranks == key_to_saving_rank - - for _, sh_ten in state_dict.items(): - if ( - _sharded_tensor_shard_id(sh_ten) - in save_strategy.cached_distribution.shards_in_this_group - ): - is_expected_to_be_saved_by_this_rank = dp_rank in expected_key_to_saving_ranks.get( - sh_ten.key, [] - ) - assert sh_ten.replica_id == int( - not is_expected_to_be_saved_by_this_rank - ), expected_key_to_saving_ranks - - assert mock_strategy.save_keys == expected_keys_saved_by_current_rank, ( - Utils.rank, - mock_strategy.save_keys, - expected_keys_saved_by_current_rank, - ) - - @pytest.mark.internal - @pytest.mark.parametrize("parallelize_within_dp", [False, True]) - def test_load_distribution(self, parallelize_within_dp, tmp_path_dist_ckpt): - Utils.initialize_model_parallel(2, 1) - - state_dict = self.get_sharded_state_dict() - - # Ranks assignment: - # 1. non-cross-DP read - # 2. Lowest coverage - # 3. Largest tensor - # 4. Shard id (key) - if not parallelize_within_dp: - expected_key_to_loading_ranks = { - 'keyB': list( - range(Utils.world_size) - ), # everyone must save (disjoint shards, coverage == 1) - 'key_TP_repl1': [0, 1], # lowest coverage (4), first TP domain - 'key_TP_repl2': [2, 3], # lowest coverage (4), second TP domain - 'keyD': [4], # largest tensor - 'keyC': [5], # second largest tensor - 'keyE': [6], # second largest tensor - } - else: - # We must check if we should expect old load behavior (<= v0.10) or aligned one (v0.11) - sig = inspect.signature(distribute_shards_to_ranks) - aligned_load = 'cross_parallelization_group_loads' in sig.parameters - if not aligned_load or parallel_state.get_tensor_model_parallel_rank() == 0: - # All main ranks are in the first DP group (TP rank 0), - # so the load distribution is the same as the saving one - expected_key_to_loading_ranks = { - # everyone must load (disjoint shards, coverage == 1): - 'keyB': list( - range( - parallel_state.get_data_parallel_world_size(with_context_parallel=True) - ) - ), - # this time, TP sharded tensors have the same coverage as fully replicated! - 'keyD': [0], # largest tensor - 'keyC': [1], # second largest tensor - 'keyE': [2], # second largest tensor - 'key_TP_repl1': [3], # smallest tensor - 'key_TP_repl2': [3], # smallest tensor, last rank is the least occupied - } - else: - # 'C', 'D', 'E' are cross-DP reads, so are assigned at the end. - # First 'key_TP_repl*' are assigned to rank 0 and 1 - expected_key_to_loading_ranks = { - # everyone must load (disjoint shards, coverage == 1): - 'keyB': list( - range( - parallel_state.get_data_parallel_world_size(with_context_parallel=True) - ) - ), - # the only intra-DP reads - 'key_TP_repl1': [0], - 'key_TP_repl2': [1], - # cross-DP reads are assigned at the end - 'keyD': [2], # largest tensor - 'keyC': [3], # second largest tensor - 'keyE': [0], # second largest tensor, round-robin - } - - parallelization_group = ( - parallel_state.get_data_parallel_group(with_context_parallel=True) - if parallelize_within_dp - else torch.distributed.group.WORLD - ) - dp_rank = get_pg_rank(parallelization_group) - expected_keys_loaded_by_current_rank = { - k for k, v in expected_key_to_loading_ranks.items() if dp_rank in v - } - - # Run save and tests - mock_strategy = MockLoadStrategy() - load_strategy = FullyParallelLoadStrategyWrapper( - mock_strategy, parallelization_group, do_cache_distribution=True - ) - with TempNamedDir(tmp_path_dist_ckpt / 'mock_dir') as ckpt_dir_A: - loaded_state_dict = load_strategy.load(state_dict, ckpt_dir_A) - key_to_loading_rank = dict( - map_reduce( - load_strategy.cached_distribution.main_rank_for_shard.items(), - lambda shard_rank: shard_rank[0][0], - lambda shard_rank: shard_rank[1], - ) - ) - assert expected_key_to_loading_ranks == key_to_loading_rank - - assert mock_strategy.load_keys == expected_keys_loaded_by_current_rank, ( - Utils.rank, - mock_strategy.load_keys, - expected_keys_loaded_by_current_rank, - ) - - assert loaded_state_dict.keys() == state_dict.keys() - - @pytest.mark.parametrize('state_dict_device', ['cpu', 'cuda']) - @pytest.mark.flaky - @pytest.mark.flaky_in_dev - def test_memory_usage(self, state_dict_device, tmp_path_dist_ckpt): - Utils.initialize_model_parallel(2, 1) - - megabytes = 1024 * 1024 - mock_strategy = MockLoadStrategy(state_dict_device) - - mem_alloc = [] - - real_get_empty_tensor_for_exchange = _get_empty_tensor_for_exchange - - def mock_get_empty_tensor_for_exchange(*args, **kwargs) -> torch.Tensor: - ret = real_get_empty_tensor_for_exchange(*args, **kwargs) - mem_alloc.append(torch.cuda.memory_allocated()) - return ret - - load_strategy = FullyParallelLoadStrategyWrapper(mock_strategy) - torch.distributed.barrier() - - # Each tensor is 4MB, 40MB in total. - # We expect extra memory usage peak at ~32MB, not 1GB - sharded_state_dict = { - f'ten_{i}': ShardedTensor.from_rank_offsets( - f'ten_{i}', - torch.rand(megabytes, dtype=torch.float, device=state_dict_device), - replica_id=Utils.rank, - ) - for i in range(10) - } - - mem_alloc_start = torch.cuda.memory_allocated() - - with ( - mock.patch( - 'megatron.core.dist_checkpointing.exchange_utils._get_empty_tensor_for_exchange', - new=mock_get_empty_tensor_for_exchange, - ), - TempNamedDir(tmp_path_dist_ckpt / 'mock_dir') as ckpt_dir_A, - ): - _ = load_strategy.load(sharded_state_dict, ckpt_dir_A) - - # Each rank is expected to do 9 allocations for all shards loaded by some other rank. - # There are 10 shards and 8 ranks so ranks <= 1 load 2 shards (and allocate 10 - 2 = 8) - assert len(mem_alloc) == 8 if Utils.rank <= 1 else 9 - # Peak mem usage should be within 4MB (single tensor) - assert max(mem_alloc) - mem_alloc_start < 4.01 * megabytes, ( - max(mem_alloc), - mem_alloc_start, - ) - - Utils.destroy_model_parallel() - - @pytest.mark.internal - def test_only_necessary_exchanges_performed_during_load(self, tmp_path_dist_ckpt): - Utils.initialize_model_parallel(2, 1) - - # State dict with 2 expected exchanges - sharded_state_dict_baseline_two_exchanges = { - 'needed_by_all_A': ShardedTensor.from_rank_offsets( - 'needed_by_all_A', - torch.ones(4, dtype=torch.float, device='cuda'), - replica_id=Utils.rank, - ), - 'needed_by_all_B': ShardedTensor.from_rank_offsets( - 'needed_by_all_B', - torch.ones(4, dtype=torch.float, device='cuda'), - replica_id=Utils.rank, - ), - } - # State dict with 1 expected exchange - sharded_state_dict_baseline_one_exchange = { - 'needed_by_all': sharded_state_dict_baseline_two_exchanges['needed_by_all_A'] - } - # State dict with 1 expected exchanges even though there are 2 tensors to load (1 is unique for each rank) - sharded_state_dict_test_one_exchange = sharded_state_dict_baseline_one_exchange.copy() - sharded_state_dict_test_one_exchange['unique'] = ShardedTensor.from_rank_offsets( - 'unique', - torch.ones(4, dtype=torch.float, device='cuda'), - (0, Utils.rank, Utils.world_size), - ) - - expected_call_counts: List[Tuple[ShardedStateDict, int]] = [ - (sharded_state_dict_baseline_one_exchange, 1), - (sharded_state_dict_baseline_two_exchanges, 2), - (sharded_state_dict_test_one_exchange, 1), - ] - - mock_strategy = MockLoadStrategy() - with TempNamedDir(tmp_path_dist_ckpt / 'mock_dir') as ckpt_dir: - for sharded_state_dict, expected_count in expected_call_counts: - load_strategy = FullyParallelLoadStrategyWrapper( - mock_strategy, None, do_cache_distribution=True, exchange_algo='broadcast' - ) - with mock.patch( - 'megatron.core.dist_checkpointing.strategies.fully_parallel.torch.distributed.broadcast' - ) as broadcast_mock: - _ = load_strategy.load(sharded_state_dict, ckpt_dir) - assert broadcast_mock.call_count == expected_count - - Utils.destroy_model_parallel() - - def test_broadcast_sharded_objects(self, tmp_path_dist_ckpt): - - sharded_state_dict = { - f'Obj_{i}': ShardedObject(f'Obj_{i}', None, (1,), (0,), replica_id=abs(Utils.rank - i)) - for i in range(Utils.world_size) - } - - with TempNamedDir(tmp_path_dist_ckpt / 'test_broadcast_sharded_objects') as ckpt_dir: - load_strategy = MockLoadStrategy() - load_strategy = FullyParallelLoadStrategyWrapper(load_strategy, None) - - loaded_state_dict = load_strategy.load(sharded_state_dict, ckpt_dir) - - # each rank is supposed to only load obj_rank because of how replica_id is set - assert load_strategy.base_strategy.load_keys == set({f'Obj_{Utils.rank}'}) - - # since each rank only loaded their Obj they were broadcasted - assert set(sharded_state_dict.keys()) == set(loaded_state_dict.keys()) - - -class TestCrossRanksReads: - RanksPlacementT = Dict[str, List[Tuple[int, int]]] # maps from name to (TP, DP) - - def teardown_method(self): - Utils.destroy_model_parallel() - - def get_sharded_state_dict(self, ranks_placement: RanksPlacementT): - tp_rank = parallel_state.get_tensor_model_parallel_rank() - dp_rank = parallel_state.get_data_parallel_rank() - - sharded_state_dict = {} - for name, tps_dps in ranks_placement.items(): - if (tp_rank, dp_rank) in tps_dps: - is_main = (tp_rank, dp_rank) == tps_dps[0] - sharded_state_dict[name] = ShardedTensor.from_rank_offsets( - name, torch.ones(1), replica_id=int(not is_main) - ) - - return sharded_state_dict - - def test_full_dp_reads(self, tmp_path_dist_ckpt): - """DP is the whole world.""" - ranks_placement = {'a': [(0, 0)], 'b': [(0, 1)], 'c': [(0, i) for i in range(8)]} - cross_rank_reads, same_rank_reads = self.determine_cross_rank_reads( - 1, ranks_placement, tmp_path_dist_ckpt - ) - - # We expect no cross-DP reads - assert not cross_rank_reads - # `c` was assigned to rank 2 during saving because 0 and 1 already saved `a` and `b` - if Utils.rank == 0: - assert same_rank_reads == {'a': [0]} - elif Utils.rank == 1: - assert same_rank_reads == {'b': [1]} - elif Utils.rank == 2: - assert same_rank_reads == {'c': [2]} - else: - assert not same_rank_reads - - def test_singleton_dp_reads(self, tmp_path_dist_ckpt): - """DP group has 1 rank (TP=8).""" - ranks_placement = {'a': [(0, 0)], 'b': [(1, 0)], 'c': [(i, 0) for i in range(8)]} - cross_rank_reads, same_rank_reads = self.determine_cross_rank_reads( - 8, ranks_placement, tmp_path_dist_ckpt - ) - - # We expect (unfortunately) a lot of cross-DP reads for `c` tensor. - if Utils.rank != 0: - assert cross_rank_reads == {'c': [0]} - - # `c` was assigned to rank 0 during saving because rank 0 belonged to the DP group - # which held the main replica for `c` - if Utils.rank == 0: - assert same_rank_reads == {'a': [0], 'c': [0]} - elif Utils.rank == 1: - assert same_rank_reads == {'b': [1]} - else: - assert not same_rank_reads - - def test_out_of_order_load(self, tmp_path_dist_ckpt): - """DP group has 8 rank (TP=1).""" - ranks_placement = {'a': [(0, 2)]} - cross_rank_reads, same_rank_reads = self.determine_cross_rank_reads( - 1, ranks_placement, tmp_path_dist_ckpt - ) - assert not cross_rank_reads - if Utils.rank == 2: - assert same_rank_reads == {'a': [2]} - - def test_cross_dp_access_does_not_disturb_the_distribution(self, tmp_path_dist_ckpt): - """Each DP group has 4 ranks (TP=2).""" - - # See `distribute_shards_to_ranks` algorithm for assignment explanation - ranks_placement = { - 'a': [(0, 0)], # saved by rank 0 obviously - # main replica is in DP group with ranks [0, 2, 4, 6], - # will be saved on rank 4 because 'c' is assigned first: - 'b': [(tp, dp) for tp in range(2) for dp in range(4)], - # assigned before 'b' because of a smaller potential saving ranks count - 'c': [(0, dp) for dp in range(3)], - # Here main replica is on rank 1 so will be saved by rank 1: - 'd': [(1, 0), (0, 0), (1, 3)], - # Rank 1 saved 'd' so rank 5 saves - 'e': [(1, 0), (1, 2)], - # Can be saved by DP group [1, 3, 5, 7], - # round-robin back to rank 1 - 'f': [(1, 0), (0, 0), (1, 2)], - 'g': [(1, 3)], # saved by rank 7 - } - # This dict encodes the comments above (who saves a given tensor) - # Save order: - # DP group 0: 'a', 'c', 'b' - # DP group 1: 'g', 'd', 'e', 'f' - expected_saving_ranks = {'a': 0, 'b': 4, 'c': 2, 'd': 1, 'e': 5, 'f': 1, 'g': 7} - # Which tensors are cross-read (from a different rank) by each rank. - # After assigning the intra-DP loads on the ranks according to the saving distribution, - # the cross-DP reads are assigned. So first `a`, 'e' and `g` are assigned, then - # rank 0 must cross-read 'd' and 'f' - # and one of the ranks [1, 3, 5, 7] must cross-read 'b'. Rank 3 does that (first empty) - expected_cross_load_ranks = {'d': 0, 'f': 0, 'b': 3} - cross_rank_reads, same_rank_reads = self.determine_cross_rank_reads( - 2, ranks_placement, tmp_path_dist_ckpt - ) - - for key, saving_rank in expected_saving_ranks.items(): - # Check `Utils.rank == saving_rank` *iff* it's expected - if Utils.rank == saving_rank: - assert same_rank_reads[key] == [Utils.rank], saving_rank - if same_rank_reads.get(key, []) == [Utils.rank]: - assert Utils.rank == saving_rank, key - - torch.distributed.barrier() - - if Utils.rank == 0: - assert cross_rank_reads == { - 'd': [expected_saving_ranks['d']], - 'f': [expected_saving_ranks['f']], - } - elif Utils.rank == 3: - assert cross_rank_reads == {'b': [expected_saving_ranks['b']]} - else: - assert not cross_rank_reads - - def determine_cross_rank_reads( - self, - tp_size: int, - ranks_placement: RanksPlacementT, - tmp_path_dist_ckpt: Path, - parallel_within_dp: bool = True, - ): - Utils.initialize_model_parallel(tp_size, 1) - parallelization_group = ( - parallel_state.get_data_parallel_group() - if parallel_within_dp - else torch.distributed.group.WORLD - ) - state_dict = self.get_sharded_state_dict(ranks_placement) - with TempNamedDir(tmp_path_dist_ckpt / 'determine_cross_rank_reads') as ckpt_dir: - save_strategy = FullyParallelSaveStrategyWrapper( - get_default_save_sharded_strategy(), parallelization_group - ) - save_strategy.save(state_dict, ckpt_dir) - - load_strategy = FullyParallelLoadStrategyWrapper( - get_default_strategy(StrategyAction.LOAD_SHARDED, 'torch_dist', 1), - parallelization_group, - do_cache_distribution=True, - exchange_algo='broadcast', - ) - - # Create a mock that will do what it's supposed to do, - # but additionally collect info about cross-rank reads. - cross_rank_reads = None - same_rank_reads = None - - def mock_local_plan(self): - self._validate_global_shapes(self.metadata, self.shapes_validation_sharded_tensors) - local_plan = super(MCoreLoadPlanner, self).create_local_plan() - - nonlocal cross_rank_reads - nonlocal same_rank_reads - cross_rank_reads = defaultdict(list) - same_rank_reads = defaultdict(list) - - # Debug cross-reads - for read_item in local_plan.items: - item_md = self.metadata.storage_data[read_item.storage_index] - - read_rank = int(item_md.relative_path.split('_')[2]) - if read_rank == torch.distributed.get_rank(): - same_rank_reads[read_item.dest_index.fqn].append(read_rank) - else: - cross_rank_reads[read_item.dest_index.fqn].append(read_rank) - - return local_plan - - with mock.patch.object(MCoreLoadPlanner, 'create_local_plan', mock_local_plan): - _ = load_strategy.load(state_dict, ckpt_dir) - - Utils.destroy_model_parallel() - - return cross_rank_reads, same_rank_reads diff --git a/tests/unit_tests/dist_checkpointing/test_global_metadata_reuse.py b/tests/unit_tests/dist_checkpointing/test_global_metadata_reuse.py deleted file mode 100644 index 5d0bc9dbc2..0000000000 --- a/tests/unit_tests/dist_checkpointing/test_global_metadata_reuse.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - - -from unittest import mock - -import pytest - -from megatron.training.arguments import parse_args -from megatron.training.checkpointing import load_checkpoint, save_checkpoint -from tests.unit_tests.dist_checkpointing import ( - TempNamedDir, - init_basic_mock_args, - init_checkpointing_mock_args, - setup_model_and_optimizer, -) -from tests.unit_tests.test_utilities import Utils - - -class TestGlobalMetadataReuse: - def setup_method(self, method): - pass - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.parametrize(('tp,pp'), [(2, 4)]) - def test_global_metadata_reuse(self, tmp_path_dist_ckpt, tp, pp): - Utils.initialize_model_parallel(tp, pp) - num_floating_point_operations_so_far = 0 - model, optimizer = setup_model_and_optimizer(1, tp, pp) - opt_param_scheduler = None - - mock_args = parse_args(ignore_unknown_args=True) - with ( - TempNamedDir( - tmp_path_dist_ckpt / "test_global_metadata_reuse" - ) as non_persistent_ckpt_dir, - mock.patch('megatron.training.checkpointing.get_args', new=lambda: mock_args), - mock.patch("megatron.training.checkpointing.update_num_microbatches"), - ): - init_basic_mock_args(mock_args, tp, pp) - init_checkpointing_mock_args(mock_args, non_persistent_ckpt_dir) - mock_args.non_persistent_ckpt_type = "global" - mock_args.ckpt_assume_constant_structure = True - save_ckpt_context = {} - - # Check we avoid reduce_scatter - with mock.patch( - 'torch.distributed.checkpoint.utils._DistWrapper.reduce_scatter' - ) as reduce_scatter_mock: - save_checkpoint( - 1, - model, - optimizer, - opt_param_scheduler, - num_floating_point_operations_so_far, - save_ckpt_context, - ) - - assert reduce_scatter_mock.call_count == 0 - - assert save_ckpt_context['save_strategy'].cached_global_metadata is None - - resume_ckpt_context = {} - _, _ = load_checkpoint( - model, optimizer, opt_param_scheduler, checkpointing_context=resume_ckpt_context - ) - - load_strategy_cached_metadata = resume_ckpt_context[ - 'load_strategy' - ].cached_global_metadata - assert load_strategy_cached_metadata is not None - assert getattr(load_strategy_cached_metadata, "all_local_plans", None) is not None - - # Check we avoid reduce_scatter - with mock.patch( - 'torch.distributed.checkpoint.utils._DistWrapper.reduce_scatter' - ) as reduce_scatter_mock: - save_checkpoint( - 2, - model, - optimizer, - opt_param_scheduler, - num_floating_point_operations_so_far, - resume_ckpt_context, - ) - assert reduce_scatter_mock.call_count == 0 - - assert ( - load_strategy_cached_metadata - is resume_ckpt_context['save_strategy'].cached_global_metadata - ) - - assert resume_ckpt_context['save_strategy'].validated_loaded_metadata_reuse - - @pytest.mark.parametrize(('tp,pp'), [(2, 4)]) - def test_no_global_metadata_reuse_on_different_parallelism(self, tmp_path_dist_ckpt, tp, pp): - Utils.initialize_model_parallel(tp, pp) - num_floating_point_operations_so_far = 0 - model, optimizer = setup_model_and_optimizer(1, tp, pp) - opt_param_scheduler = None - - mock_args = parse_args(ignore_unknown_args=True) - with ( - TempNamedDir( - tmp_path_dist_ckpt / "test_global_metadata_reuse" - ) as non_persistent_ckpt_dir, - mock.patch('megatron.training.checkpointing.get_args', new=lambda: mock_args), - mock.patch("megatron.training.checkpointing.update_num_microbatches"), - ): - init_basic_mock_args(mock_args, tp, pp) - init_checkpointing_mock_args(mock_args, non_persistent_ckpt_dir) - mock_args.non_persistent_ckpt_type = "global" - mock_args.ckpt_assume_constant_structure = True - mock_args.ckpt_fully_parallel_save = True - mock_args.use_distributed_optimizer = False - - save_ckpt_context = {} - - # Check we avoid reduce_scatter - with mock.patch( - 'torch.distributed.checkpoint.utils._DistWrapper.reduce_scatter' - ) as reduce_scatter_mock: - save_checkpoint( - 1, - model, - optimizer, - opt_param_scheduler, - num_floating_point_operations_so_far, - save_ckpt_context, - ) - - assert reduce_scatter_mock.call_count == 0 - - assert save_ckpt_context['save_strategy'].base_strategy.cached_global_metadata is None - - Utils.destroy_model_parallel() - Utils.initialize_model_parallel(pp, tp) - model, optimizer = setup_model_and_optimizer(1, pp, tp) - init_basic_mock_args(mock_args, pp, tp) - mock_args.use_distributed_optimizer = False - mock_args.no_load_rng = True - - resume_ckpt_context = {} - _, _ = load_checkpoint( - model, optimizer, opt_param_scheduler, checkpointing_context=resume_ckpt_context - ) - - load_strategy_cached_metadata = resume_ckpt_context[ - 'load_strategy' - ].cached_global_metadata - - assert load_strategy_cached_metadata is not None - assert getattr(load_strategy_cached_metadata, "all_local_plans", None) is not None - - # Check we avoid reduce_scatter - with mock.patch( - 'torch.distributed.checkpoint.utils._DistWrapper.reduce_scatter' - ) as reduce_scatter_mock: - save_checkpoint( - 2, - model, - optimizer, - opt_param_scheduler, - num_floating_point_operations_so_far, - resume_ckpt_context, - ) - assert reduce_scatter_mock.call_count == 0 - - assert not resume_ckpt_context[ - 'save_strategy' - ].base_strategy.validated_loaded_metadata_reuse diff --git a/tests/unit_tests/dist_checkpointing/test_local.py b/tests/unit_tests/dist_checkpointing/test_local.py deleted file mode 100644 index 1b2752f846..0000000000 --- a/tests/unit_tests/dist_checkpointing/test_local.py +++ /dev/null @@ -1,326 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -import filecmp -import logging -import shutil -import tempfile -import time -from pathlib import Path -from types import SimpleNamespace -from typing import Any, Callable, Tuple, Union -from unittest import mock - -import pytest -import torch - -from megatron.training.arguments import parse_args - -nvidia_resiliency_ext = pytest.importorskip( - "nvidia_resiliency_ext", - reason="nvidia_resiliency_ext is required for local checkpointing tests", -) - -from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.base_manager import ( - CheckpointingException, -) -from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import ( - LocalCheckpointManager, -) - -from megatron.core.dist_checkpointing import ShardedTensor -from megatron.core.dist_checkpointing.dict_utils import diff -from megatron.core.dist_checkpointing.mapping import ShardedBase, ShardedTensorFactory -from megatron.core.dist_checkpointing.tensor_aware_state_dict import MCoreTensorAwareStateDict -from megatron.core.dist_checkpointing.utils import extract_nonpersistent -from megatron.training.async_utils import maybe_finalize_async_save -from megatron.training.checkpointing import generate_state_dict, load_checkpoint, save_checkpoint -from tests.unit_tests.dist_checkpointing import ( - TempNamedDir, - init_basic_mock_args, - init_checkpointing_mock_args, - setup_model_and_optimizer, -) -from tests.unit_tests.test_utilities import Utils - -from .utils import find_matching_values - - -# TODO: Use mock local checkpointing? -class TestLocalCheckpointingReplication: - - def test_filename_to_id(self): - iteration_string = "0000123" - rank = "4" - with tempfile.TemporaryDirectory() as tmpdir: - ckpt_mgr = LocalCheckpointManager(tmpdir) - filename = ckpt_mgr._filename_from_template(iteration_string, rank) - assert (123, 4) == ckpt_mgr._filename_to_id(filename)[:2] - - @pytest.mark.parametrize(('tp,pp'), [(2, 4)]) - def test_sharded_tensors(self, tp, pp): - Utils.initialize_model_parallel(tp, pp) - num_floating_point_operations_so_far = 0 - model, optimizer = setup_model_and_optimizer(1, tp, pp) - - -class TestLocalCheckpointing: - def setup_method(self, method): - pass - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.parametrize(('tp,pp'), [(2, 4)]) - @pytest.mark.parametrize(('use_torch_fsdp2'), [True, False]) - def test_sharded_tensors(self, tp, pp, use_torch_fsdp2): - Utils.initialize_model_parallel(tp, pp) - num_floating_point_operations_so_far = 0 - model, optimizer = setup_model_and_optimizer(1, tp, pp) - opt_param_scheduler = None - rng_state = None - iteration = None - optim_sd_kwargs = dict(sharding_type='fully_sharded_model_space') - mock_args = parse_args(ignore_unknown_args=True) - mock_args.no_save_optim = False - mock_args.no_save_rng = True - mock_args.use_torch_fsdp2 = use_torch_fsdp2 - # Test save_local - state_dict = generate_state_dict( - mock_args, - model, - optimizer, - opt_param_scheduler, - rng_state, - iteration=iteration, - optim_sd_kwargs=optim_sd_kwargs, - ) - sharded_tensor_factories = find_matching_values( - state_dict, lambda x: isinstance(x, ShardedTensorFactory) - ) - sharded_tensors = find_matching_values(state_dict, lambda x: isinstance(x, ShardedTensor)) - for ten in sharded_tensors: - assert ten.data != None - saved_state_dict, _ = MCoreTensorAwareStateDict.from_state_dict(state_dict, algo='atomic') - saved_sharded_tensors = find_matching_values( - saved_state_dict, lambda x: isinstance(x, ShardedTensor) - ) - assert ( - len(saved_sharded_tensors) - == len(sharded_tensors) + 2 * len(sharded_tensor_factories) - == len(list(saved_state_dict.tensors)) - ) - tensors = saved_state_dict.pop_tensors() - for ten in saved_sharded_tensors: - assert ten.data is None - assert saved_state_dict.is_hollow - hollow_sharded_tensors = find_matching_values( - saved_state_dict, lambda x: isinstance(x, torch.Tensor) - ) - assert hollow_sharded_tensors == [] - saved_state_dict.insert_tensors(tensors) - common_sharded_tensors = find_matching_values( - saved_state_dict.common_state_dict, lambda x: isinstance(x, ShardedTensor) - ) - assert common_sharded_tensors == [] - # Test load_local - state_dict = generate_state_dict( - mock_args, - model, - optimizer, - opt_param_scheduler, - rng_state, - iteration=iteration, - optim_sd_kwargs=optim_sd_kwargs, - ) - nonpersistent_state_dict, _ = extract_nonpersistent(state_dict) - # For a given use case - assert not nonpersistent_state_dict - loaded_state_dict = saved_state_dict.to_state_dict(state_dict) - only_left, only_right, mismatch = diff(loaded_state_dict, state_dict) - assert not only_left - assert not only_right - for i in mismatch: - # ShardedObjects and ShardedTensors should be replaced - assert issubclass(i[-1], ShardedBase) - - @pytest.mark.internal - @pytest.mark.parametrize(('tp,pp'), [(2, 4), (1, 1)]) - @pytest.mark.parametrize(('use_ramdisk'), [True, False]) - @pytest.mark.parametrize(('async_save'), [True, False]) - @pytest.mark.parametrize(('algo'), ['atomic', 'fully_parallel']) - def test_basic_save_load_scenarios( - self, tmp_path_dist_ckpt, tp, pp, use_ramdisk, async_save, algo - ): - Utils.initialize_model_parallel(tp, pp) - num_floating_point_operations_so_far = 0 - model, optimizer = setup_model_and_optimizer(1, tp, pp) - opt_param_scheduler = None - - mock_args = ( - SimpleNamespace() - ) # FIXME: fails with additional arguments (e.g.,'weight_decay') - if use_ramdisk: - tmp_path_dist_ckpt = Path("/dev/shm") - with ( - TempNamedDir(tmp_path_dist_ckpt / "test_local", sync=True) as local_ckpt_dir, - mock.patch('megatron.training.checkpointing.get_args', new=lambda: mock_args), - mock.patch('megatron.training.async_utils.get_args', new=lambda: mock_args), - mock.patch("megatron.training.checkpointing.update_num_microbatches"), - ): - local_ckpt_dir = local_ckpt_dir / "subdir" # Test handling of non-existent directories - init_basic_mock_args(mock_args, tp, pp) - init_checkpointing_mock_args(mock_args, None) - mock_args.non_persistent_ckpt_type = 'local' - mock_args.non_persistent_local_ckpt_algo = algo - mock_args.async_save = async_save - mock_args.ckpt_fully_parallel_save = True # ensure proper sharding_type is set - checkpointing_context = { - 'local_checkpoint_manager': LocalCheckpointManager(local_ckpt_dir) - } - - save_checkpoint( - 1, - model, - optimizer, - opt_param_scheduler, - num_floating_point_operations_so_far, - checkpointing_context=checkpointing_context, - non_persistent_ckpt=True, - ) - if async_save: - maybe_finalize_async_save(True) - iteration, _ = load_checkpoint( - model, optimizer, opt_param_scheduler, checkpointing_context=checkpointing_context - ) - assert iteration == 1 - ckpt_id = checkpointing_context['local_checkpoint_manager']._ckpt_id(iteration) - ckpt_path = checkpointing_context['local_checkpoint_manager']._local_ckpt_path_from_id( - ckpt_id - ) - backup_path = ckpt_path.with_name('backup_' + ckpt_path.name) - checkpointing_context['local_checkpoint_manager'].latest_iteration = -1 - iteration, _ = load_checkpoint( - model, optimizer, opt_param_scheduler, checkpointing_context=checkpointing_context - ) - assert iteration == 1 - shutil.move(ckpt_path, backup_path) - checkpointing_context['local_checkpoint_manager'].latest_iteration = -1 - torch.distributed.barrier() - iteration, _ = load_checkpoint( - model, optimizer, opt_param_scheduler, checkpointing_context=checkpointing_context - ) - assert iteration == 0 - save_checkpoint( - 1, - model, - optimizer, - opt_param_scheduler, - num_floating_point_operations_so_far, - checkpointing_context=checkpointing_context, - non_persistent_ckpt=True, - ) - if async_save: - maybe_finalize_async_save(True) - if Utils.rank > 0: # Skip assertion on rank 0 due to harmless nondeterminism - assert filecmp.cmp(ckpt_path, backup_path, shallow=False), [ckpt_path, backup_path] - save_checkpoint( - 2, - model, - optimizer, - opt_param_scheduler, - num_floating_point_operations_so_far, - checkpointing_context=checkpointing_context, - non_persistent_ckpt=True, - ) - if async_save: - maybe_finalize_async_save(True) - time.sleep(0.01) # Allow sufficient time for async cleanup to complete - assert not ckpt_path.exists() - ckpt_id = checkpointing_context['local_checkpoint_manager']._ckpt_id(2) - ckpt_path = checkpointing_context['local_checkpoint_manager']._local_ckpt_path_from_id( - ckpt_id - ) - assert ckpt_path.exists() - - Utils.destroy_model_parallel() - - @pytest.mark.internal - @pytest.mark.parametrize(('tp,pp'), [(1, 1), (2, 4)]) - @pytest.mark.parametrize(('use_ramdisk'), [True, False]) - @pytest.mark.parametrize(('async_save'), [True, False]) - @pytest.mark.parametrize(('algo'), ['atomic', 'fully_parallel']) - @pytest.mark.flaky_in_dev - def test_failed_save(self, caplog, tmp_path_dist_ckpt, tp, pp, use_ramdisk, async_save, algo): - Utils.initialize_model_parallel(tp, pp) - num_floating_point_operations_so_far = 0 - model, optimizer = setup_model_and_optimizer(1, tp, pp) - opt_param_scheduler = None - - mock_args = parse_args(ignore_unknown_args=True) - if use_ramdisk: - tmp_path_dist_ckpt = Path("/dev/shm") - - def test_save_wrapper(save_wrapper, subdir): - with ( - TempNamedDir(tmp_path_dist_ckpt / subdir, sync=True) as local_ckpt_dir, - mock.patch('megatron.training.checkpointing.get_args', new=lambda: mock_args), - mock.patch('megatron.training.async_utils.get_args', new=lambda: mock_args), - mock.patch("megatron.training.checkpointing.update_num_microbatches"), - mock.patch.object(LocalCheckpointManager, '_save', new=save_wrapper), - caplog.at_level(logging.INFO), - ): - - local_ckpt_dir = ( - local_ckpt_dir / "subdir" - ) # Test handling of non-existent directories - init_basic_mock_args(mock_args, tp, pp) - init_checkpointing_mock_args(mock_args, None) - mock_args.non_persistent_ckpt_type = 'local' - mock_args.non_persistent_local_ckpt_algo = algo - mock_args.async_save = async_save - mock_args.ckpt_fully_parallel_save = True # ensure proper sharding_type is set - checkpointing_context = { - 'local_checkpoint_manager': LocalCheckpointManager(local_ckpt_dir) - } - - with pytest.raises(CheckpointingException): - save_checkpoint( - 1, - model, - optimizer, - opt_param_scheduler, - num_floating_point_operations_so_far, - checkpointing_context=checkpointing_context, - non_persistent_ckpt=True, - ) - if async_save: - maybe_finalize_async_save(True) - iteration, _ = load_checkpoint( - model, - optimizer, - opt_param_scheduler, - checkpointing_context=checkpointing_context, - ) - assert iteration == 0 - assert not any((local_ckpt_dir / str(Utils.rank)).iterdir()) - - if Utils.rank == 1: - assert f"iter_0000001_{Utils.rank}_local.pt" not in caplog.text - else: - assert f"iter_0000001_{Utils.rank}_local.pt" in caplog.text - - original_save = LocalCheckpointManager._save - - def silent_error(self, *args, **kwargs): - if self.rank == 1: - return - return original_save(self, *args, **kwargs) - - def exception(self, *args, **kwargs): - if self.rank == 1: - raise Exception("TEST") - return original_save(self, *args, **kwargs) - - test_save_wrapper(silent_error, "test_sync") - if async_save: - test_save_wrapper(exception, "test_async") - Utils.destroy_model_parallel() diff --git a/tests/unit_tests/dist_checkpointing/test_mapping.py b/tests/unit_tests/dist_checkpointing/test_mapping.py deleted file mode 100644 index 38582d7524..0000000000 --- a/tests/unit_tests/dist_checkpointing/test_mapping.py +++ /dev/null @@ -1,178 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import pytest -import torch - -from megatron.core.dist_checkpointing import ShardedTensor -from megatron.core.dist_checkpointing.core import CheckpointingException -from megatron.core.dist_checkpointing.mapping import ( - ShardedObject, - ShardedTensorFactory, - apply_factories, - apply_factory_merges, - is_main_replica, -) -from megatron.core.transformer.transformer_config import TransformerConfig -from tests.unit_tests.test_utilities import Utils - - -class TestShardedTensor: - - # def setup_method(self, method): - # Utils.initialize_model_parallel(1,1) - # transformer_config = TransformerConfig(num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True) - # self.gpt_embedding = GPTEmbedding(config=transformer_config, vocab_size=100, max_sequence_length=4, add_position_embedding=True) - # - # def teardown_method(self, method): - # Utils.destroy_model_parallel() - - def test_from_rank_offsets_constructor(self, dtype=torch.float, device='cuda'): - data = torch.ones((1, 3, 7, 9), dtype=dtype, device=device) - shape = data.shape - rank_offsets = [(0, 0, 10), (2, 3, 6)] - sh_ten = ShardedTensor.from_rank_offsets('keyA', data, *rank_offsets) - - assert isinstance(sh_ten, ShardedTensor) - assert sh_ten.dtype is dtype - assert sh_ten.local_shape == shape - assert sh_ten.global_shape == (shape[0] * 10, shape[1], shape[2] * 6, shape[3]) - assert sh_ten.global_offset == (0, 0, shape[2] * 3, 0) - assert sh_ten.axis_fragmentations == (10, 1, 6, 1) - - def test_from_rank_offsets_flat_constructor(self, dtype=torch.float, device='cuda'): - data = torch.arange(28, dtype=dtype, device=device).reshape((1, 4, 7)) - shape = data.shape - rank_offsets = [(1, 0, 2), (2, 3, 5)] - flattened_range = slice(4, 9) - flat_data = data.flatten()[flattened_range] - sh_ten = ShardedTensor.from_rank_offsets_flat( - 'keyA', flat_data, data.shape, *rank_offsets, flattened_range=flattened_range - ) - - # The main attributes properties are unchanged - assert isinstance(sh_ten, ShardedTensor) - assert sh_ten.dtype is dtype - assert sh_ten.local_shape == shape - assert sh_ten.global_shape == (shape[0], shape[1] * 2, shape[2] * 5) - assert sh_ten.global_offset == (0, 0, shape[2] * 3) - assert sh_ten.axis_fragmentations == (1, 2, 5) - - assert torch.all(sh_ten.data == torch.arange(4, 9, device=device)) - - def test_metadata_integrity_violation(self): - data = torch.ones((1, 3, 7, 9), device='meta') - rank_offsets = [(0, 0, 10), (2, 3, 6)] - sh_ten = ShardedTensor.from_rank_offsets('keyA', data, *rank_offsets) - sh_ten.validate_metadata_integrity() - with pytest.raises(CheckpointingException): - sh_ten.local_shape = (1, 2, 7, 9) - sh_ten.validate_metadata_integrity() - - sh_ten = ShardedTensor.from_rank_offsets('keyA', data, *rank_offsets) - with pytest.raises(CheckpointingException): - sh_ten.global_offset = (0, 1, 0) - sh_ten.validate_metadata_integrity() - - with pytest.raises(CheckpointingException): - sh_ten = ShardedTensor.from_rank_offsets_flat( - 'keyA', data, data.shape, *rank_offsets, flattened_range=slice(4, 9) - ) - - sh_ten = ShardedTensor.from_rank_offsets_flat( - 'keyA', data.flatten()[4:9], data.shape, *rank_offsets, flattened_range=slice(4, 9) - ) - assert sh_ten.local_shape == (1, 3, 7, 9) - with pytest.raises(CheckpointingException): - sh_ten.local_shape = (5,) - sh_ten.validate_metadata_integrity() - - def test_narrowing(self): - data = torch.ones((1, 3, 7, 9)) - rank_offsets = [(0, 0, 10), (2, 3, 6)] - sh_ten = ShardedTensor.from_rank_offsets('keyA', data, *rank_offsets) - (narr_sh_ten,) = sh_ten.narrow(1, 1, 2) - assert narr_sh_ten.local_shape == (1, 2, 7, 9) - assert narr_sh_ten.global_shape == (10, 2, 42, 9) - assert narr_sh_ten.global_offset == (0, 0, 21, 0) - - (narr_sh_ten,) = sh_ten.narrow(2, 3, 2) - assert narr_sh_ten.local_shape == (1, 3, 2, 9) - assert narr_sh_ten.global_shape == (10, 3, 12, 9) - assert narr_sh_ten.global_offset == (0, 0, 6, 0) - - def test_flat_narrow(self): - data = torch.arange(28).reshape((4, 7)) - rank_offsets = [(0, 1, 2), (1, 3, 5)] - flattened_range = slice(4, 9) - flat_data = data.flatten()[flattened_range] - sh_ten = ShardedTensor.from_rank_offsets_flat( - 'keyA', flat_data, data.shape, *rank_offsets, flattened_range=flattened_range - ) - - # The main attributes properties are unchanged - assert isinstance(sh_ten, ShardedTensor) - assert torch.all(sh_ten.data == torch.arange(4, 9)) - - (narrow_sh_ten,) = sh_ten.narrow( - 0, 0, 1 - ) # First seven elements of unflat, intersection has 3 elements - assert torch.all(narrow_sh_ten.data == torch.arange(4, 7)) - assert narrow_sh_ten.local_shape == (1, 7) - assert narrow_sh_ten.global_shape == (2, 35) - assert narrow_sh_ten.global_offset == (1, 21) - - (narrow_sh_ten,) = sh_ten.narrow( - 0, 0, 3 - ) # First 21 elements of unflat, intersection has all 5 elements - assert torch.all(narrow_sh_ten.data == torch.arange(4, 9)) - assert narrow_sh_ten.local_shape == (3, 7) - assert narrow_sh_ten.global_shape == (6, 35) - assert narrow_sh_ten.global_offset == (3, 21) - - narrow_sh_ten = sh_ten.narrow(0, 2, 1) # empty intersection - assert not narrow_sh_ten, narrow_sh_ten - - -class TestShardedTensorFactory: - def test_build_and_merge(self): - def build_fn(key, tensor, replica_id, flattened_range): - assert flattened_range is None - return { - 'level2_a': ShardedTensor.from_rank_offsets( - key + 'part1', tensor + 1, replica_id=replica_id - ), - 'level2_b': ShardedTensor.from_rank_offsets( - key + 'part2', tensor + 2, replica_id=replica_id - ), - } - - # state_dict will be modified in-place - def get_state_dict(): - return { - 'level1': ShardedTensorFactory( - 'a', torch.arange(3), build_fn, lambda x: x['level2_b'] - ) - } - - state_dict = get_state_dict() - apply_factories(state_dict) - assert torch.allclose(state_dict['level1']['level2_a'].data, torch.tensor([1, 2, 3])) - assert torch.allclose(state_dict['level1']['level2_b'].data, torch.tensor([2, 3, 4])) - - # Simulate loading - state_dict['level1']['level2_a'] = state_dict['level1']['level2_a'].data - state_dict['level1']['level2_b'] = state_dict['level1']['level2_b'].data - - loaded_state_dict = apply_factory_merges(state_dict, get_state_dict()) - assert torch.allclose(loaded_state_dict['level1'], torch.tensor([2, 3, 4])) - - -def test_is_main_replica(): - assert is_main_replica(0) - assert is_main_replica((0,)) - assert is_main_replica((0, 0)) - assert not is_main_replica(1) - assert not is_main_replica(2) - assert not is_main_replica((1,)) - assert not is_main_replica((1, 0)) - assert not is_main_replica((1, 1, 1)) diff --git a/tests/unit_tests/dist_checkpointing/test_msc.py b/tests/unit_tests/dist_checkpointing/test_msc.py deleted file mode 100644 index 5016ddf793..0000000000 --- a/tests/unit_tests/dist_checkpointing/test_msc.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import torch - -try: - from torch.distributed import DeviceMesh - from torch.distributed._tensor import DTensor - - HAVE_DTENSOR = True -except ImportError: - HAVE_DTENSOR = False - -from megatron.core import parallel_state -from megatron.core.dist_checkpointing import ShardedTensor, load, save -from megatron.core.dist_checkpointing.strategies.base import StrategyAction, get_default_strategy -from megatron.core.msc_utils import MultiStorageClientFeature -from tests.unit_tests.dist_checkpointing import TempNamedDir -from tests.unit_tests.test_utilities import Utils - - -class TestSerializationWithMultiStorageClient: - - def setup_method(self, method): - MultiStorageClientFeature.enable() - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_process_save_load(self, tmp_path_dist_ckpt): - Utils.initialize_model_parallel(2, 4) - - sharded_state_dict = { - 'sd_keyA': ShardedTensor.from_rank_offsets( - 'keyA', torch.ones(2, 4), replica_id=Utils.rank - ), - 'sd_keyB': ShardedTensor.from_rank_offsets( - 'keyB', torch.ones(3, 5, 7), replica_id=Utils.rank - ), - } - - if HAVE_DTENSOR: - mesh = DeviceMesh.from_group( - parallel_state.get_data_parallel_group(with_context_parallel=True), "cuda" - ) - sharded_state_dict['sd_keyD'] = ShardedTensor.from_rank_offsets( - 'keyD', - DTensor.from_local(torch.ones(3, 5, 7), mesh)._local_tensor, - replica_id=Utils.rank, - ) - - # sync=True to make sure other ranks wait for rank 0 to finish creating directory. - with TempNamedDir( - tmp_path_dist_ckpt / 'test_single_process_save_load', sync=True - ) as ckpt_dir: - save_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, 'torch_dist', 1) - save(sharded_state_dict, ckpt_dir, save_strategy) - torch.distributed.barrier() - - load_ssd = { - 'load_sd_keyA': ShardedTensor.from_rank_offsets( - 'keyA', torch.ones(2, 4), replica_id=Utils.rank - ) - } - loaded_state_dict = load(load_ssd, ckpt_dir) - - assert set(loaded_state_dict.keys()) == {'load_sd_keyA'} - assert isinstance(loaded_state_dict['load_sd_keyA'], torch.Tensor) - assert loaded_state_dict['load_sd_keyA'].shape == (2, 4) - - Utils.destroy_model_parallel() diff --git a/tests/unit_tests/dist_checkpointing/test_nonpersistent.py b/tests/unit_tests/dist_checkpointing/test_nonpersistent.py deleted file mode 100644 index 1b95becae3..0000000000 --- a/tests/unit_tests/dist_checkpointing/test_nonpersistent.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -import filecmp -import os -from unittest import mock - -import pytest - -from megatron.training.arguments import parse_args -from megatron.training.checkpointing import ( - _NON_PERSISTENT_CKPT_SUBDIR, - load_checkpoint, - save_checkpoint, -) -from tests.unit_tests.dist_checkpointing import ( - TempNamedDir, - init_basic_mock_args, - init_checkpointing_mock_args, - setup_model_and_optimizer, -) -from tests.unit_tests.test_utilities import Utils - - -class TestNonPersistentSaveAndLoad: - def setup_method(self, method): - pass - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.parametrize(('tp,pp'), [(2, 4)]) - def test_basic_save_load_scenarios(self, tmp_path_dist_ckpt, tp, pp): - Utils.initialize_model_parallel(tp, pp) - num_floating_point_operations_so_far = 0 - model, optimizer = setup_model_and_optimizer(1, tp, pp) - opt_param_scheduler = None - - mock_args = parse_args(ignore_unknown_args=True) - with TempNamedDir( - tmp_path_dist_ckpt / "test_non_persistent" - ) as non_persistent_ckpt_dir, mock.patch( - 'megatron.training.checkpointing.get_args', new=lambda: mock_args - ), mock.patch( - "megatron.training.checkpointing.update_num_microbatches" - ): - init_basic_mock_args(mock_args, tp, pp) - init_checkpointing_mock_args(mock_args, non_persistent_ckpt_dir) - mock_args.non_persistent_ckpt_type = "global" - - save_checkpoint( - 2, - model, - optimizer, - opt_param_scheduler, - num_floating_point_operations_so_far, - {}, - non_persistent_ckpt=True, - ) - save_checkpoint( - 3, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far, {} - ) - save_checkpoint( - 4, - model, - optimizer, - opt_param_scheduler, - num_floating_point_operations_so_far, - {}, - non_persistent_ckpt=True, - ) - iteration, _ = load_checkpoint(model, optimizer, opt_param_scheduler) - assert iteration == 4 - save_checkpoint( - 6, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far, {} - ) - iteration, _ = load_checkpoint(model, optimizer, opt_param_scheduler) - assert iteration == 6 - save_checkpoint( - 8, - model, - optimizer, - opt_param_scheduler, - num_floating_point_operations_so_far, - {}, - non_persistent_ckpt=True, - ) - iteration, _ = load_checkpoint(model, optimizer, opt_param_scheduler) - assert iteration == 8 - assert "iter_0000003" in os.listdir(non_persistent_ckpt_dir) - assert "iter_0000006" in os.listdir(non_persistent_ckpt_dir) - assert "iter_0000002" not in os.listdir( - os.path.join(non_persistent_ckpt_dir, _NON_PERSISTENT_CKPT_SUBDIR) - ) - assert "iter_0000004" in os.listdir( - os.path.join(non_persistent_ckpt_dir, _NON_PERSISTENT_CKPT_SUBDIR) - ) - assert "iter_0000008" in os.listdir( - os.path.join(non_persistent_ckpt_dir, _NON_PERSISTENT_CKPT_SUBDIR) - ) - ckpt_dirs = [ - "iter_0000003", - "iter_0000006", - _NON_PERSISTENT_CKPT_SUBDIR + "/iter_0000004", - _NON_PERSISTENT_CKPT_SUBDIR + "/iter_0000008", - ] - for ckpt_a in ckpt_dirs: - for ckpt_b in ckpt_dirs: - for filename in os.listdir(os.path.join(non_persistent_ckpt_dir, ckpt_a)): - if filename != "common.pt" and filename != ".metadata": - assert filecmp.cmp( - os.path.join(non_persistent_ckpt_dir, ckpt_a, filename), - os.path.join(non_persistent_ckpt_dir, ckpt_b, filename), - shallow=False, - ), [filename, ckpt_a, ckpt_b] - Utils.destroy_model_parallel() - - -class TestLegacySaveAndLoad: - @pytest.mark.parametrize(('tp,pp'), [(2, 4)]) - def test_basic_save_load_scenario(self, tmp_path_dist_ckpt, tp, pp): - Utils.initialize_model_parallel(tp, pp) - num_floating_point_operations_so_far = 0 - model, optimizer = setup_model_and_optimizer(1, tp, pp) - opt_param_scheduler = None - - mock_args = parse_args(ignore_unknown_args=True) - with TempNamedDir(tmp_path_dist_ckpt / "test_legacy") as legacy_ckpt_dir, mock.patch( - 'megatron.training.checkpointing.get_args', new=lambda: mock_args - ), mock.patch("megatron.training.checkpointing.update_num_microbatches"): - init_basic_mock_args(mock_args, tp, pp) - init_checkpointing_mock_args(mock_args, legacy_ckpt_dir) - - save_checkpoint( - 2, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far, {} - ) - iteration, _ = load_checkpoint(model, optimizer, opt_param_scheduler) - assert iteration == 2 - assert "iter_0000002" in os.listdir(legacy_ckpt_dir) - - Utils.destroy_model_parallel() diff --git a/tests/unit_tests/dist_checkpointing/test_optimizer.py b/tests/unit_tests/dist_checkpointing/test_optimizer.py deleted file mode 100644 index bbeb393041..0000000000 --- a/tests/unit_tests/dist_checkpointing/test_optimizer.py +++ /dev/null @@ -1,1132 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -import re -from copy import deepcopy -from functools import partial -from unittest import mock -from unittest.mock import patch - -import pytest -import torch -from torch.optim import Adam - -from megatron.core import parallel_state -from megatron.core.dist_checkpointing import ShardedTensor, load, load_plain_tensors, save -from megatron.core.dist_checkpointing.dict_utils import diff, nested_values -from megatron.core.dist_checkpointing.optimizer import ( - get_param_id_to_sharded_param_map, - optim_state_to_sharding_state, -) -from megatron.core.dist_checkpointing.utils import add_prefix_for_sharding, extract_sharded_tensors -from megatron.core.dist_checkpointing.validation import StrictHandling -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec -from megatron.core.models.gpt.gpt_layer_specs import ( - get_gpt_layer_with_transformer_engine_spec as gpt_te_spec, -) -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.optimizer import ChainedOptimizer -from megatron.core.tensor_parallel import model_parallel_cuda_manual_seed -from megatron.core.transformer import MLATransformerConfig, TransformerConfig -from megatron.core.transformer.mlp import apply_swiglu_sharded_factory -from megatron.core.utils import is_torch_min_version -from megatron.training.arguments import parse_args -from megatron.training.checkpointing import load_checkpoint, save_checkpoint -from tests.unit_tests.dist_checkpointing import ( - TempNamedDir, - init_basic_mock_args, - init_checkpointing_mock_args, - initialize_gpt_model, - setup_model_and_optimizer, - setup_moe_model_and_optimizer, -) -from tests.unit_tests.test_utilities import Utils - - -class Model(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv1d(8, 16, 3) - self.proj = torch.nn.Linear(8, 5) - self.config = TransformerConfig( - hidden_size=8, num_attention_heads=1, num_layers=1, bf16=True - ) - - def sharded_state_dict(self): - sharded_state_dict = self.state_dict(keep_vars=True) - # conv - sharded_state_dict['conv.weight'] = ShardedTensor.from_rank_offsets( - 'conv.weight', - sharded_state_dict['conv.weight'], - ( - 1, - parallel_state.get_tensor_model_parallel_rank(), - parallel_state.get_tensor_model_parallel_world_size(), - ), - ) - # bias is non-sharded - sharded_state_dict['conv.bias'] = ShardedTensor.from_rank_offsets( - 'conv.bias', sharded_state_dict['conv.bias'] - ) - - # proj - sharded_state_dict['proj.weight'] = ShardedTensor.from_rank_offsets( - 'proj.weight', sharded_state_dict['proj.weight'], (0, Utils.rank, Utils.world_size) - ) - sharded_state_dict['proj.bias'] = ShardedTensor.from_rank_offsets( - 'proj.bias', sharded_state_dict['proj.bias'], (0, Utils.rank, Utils.world_size) - ) - return sharded_state_dict - - -class SwigluFactoryModel(torch.nn.Module): - def __init__(self, pp_separate_model: bool = False): - super().__init__() - self.linear = torch.nn.Linear( - 5, 64 // parallel_state.get_tensor_model_parallel_world_size(), bias=False - ) - self.config = TransformerConfig( - hidden_size=8, num_attention_heads=1, num_layers=1, bf16=True - ) - self.pp_separate_model = pp_separate_model - - def sharded_state_dict(self): - pp_rank = parallel_state.get_pipeline_model_parallel_rank() - if self.pp_separate_model: - pp_replica_id = 0 - else: - pp_replica_id = pp_rank - sharded_state_dict = self.state_dict(keep_vars=True) - sharded_state_dict['linear.weight'] = ShardedTensor.from_rank_offsets( - 'linear.weight', - sharded_state_dict['linear.weight'], - ( - ( - 0, - parallel_state.get_tensor_model_parallel_rank(), - parallel_state.get_tensor_model_parallel_world_size(), - ) - ), - replica_id=( - ( - pp_replica_id, - 0, - parallel_state.get_data_parallel_rank(with_context_parallel=True), - ) - ), - ) - sharded_state_dict['linear.weight'] = apply_swiglu_sharded_factory( - sharded_state_dict['linear.weight'], () - ) - if self.pp_separate_model: - add_prefix_for_sharding(sharded_state_dict, f'pp_rank_{pp_rank}.') - return sharded_state_dict - - -class SwigluFactoryModel(torch.nn.Module): - def __init__(self, pp_separate_model: bool = False): - super().__init__() - self.linear = torch.nn.Linear(5, 64, bias=False) - self.config = TransformerConfig( - hidden_size=8, num_attention_heads=1, num_layers=1, bf16=True - ) - self.pp_separate_model = pp_separate_model - - def sharded_state_dict(self): - pp_rank = parallel_state.get_pipeline_model_parallel_rank() - if self.pp_separate_model: - pp_replica_id = 0 - else: - pp_replica_id = pp_rank - sharded_state_dict = self.state_dict(keep_vars=True) - sharded_state_dict['linear.weight'] = ShardedTensor.from_rank_offsets( - 'linear.weight', - sharded_state_dict['linear.weight'], - replica_id=( - ( - pp_replica_id, - parallel_state.get_tensor_model_parallel_rank(), - parallel_state.get_data_parallel_rank(with_context_parallel=True), - ) - ), - ) - if self.pp_separate_model: - add_prefix_for_sharding(sharded_state_dict, f'pp_rank_{pp_rank}.') - return sharded_state_dict - - -class Model1dFlattenTensor(torch.nn.Module): - """This model is used to test whether a 1d flatten tensor can be correctly - transformed into torch dist-ckpt form - """ - - def __init__(self): - super().__init__() - self.config = TransformerConfig( - hidden_size=128, num_attention_heads=1, num_layers=1, bf16=True - ) - weight_size_per_rank = ( - self.config.hidden_size // parallel_state.get_tensor_model_parallel_world_size() - ) - self.weight_1d = torch.nn.Parameter(torch.randn(weight_size_per_rank)) - - def sharded_state_dict(self): - sharded_state_dict = self.state_dict(keep_vars=True) - sharded_state_dict['weight_1d'] = ShardedTensor.from_rank_offsets( - 'weight_1d', - sharded_state_dict['weight_1d'], - ( - ( - 0, - parallel_state.get_tensor_model_parallel_rank(), - parallel_state.get_tensor_model_parallel_world_size(), - ) - ), - replica_id=( - ( - parallel_state.get_pipeline_model_parallel_rank(), - 0, - parallel_state.get_data_parallel_rank(with_context_parallel=True), - ) - ), - ) - return sharded_state_dict - - -def get_param_state_dp_zero(optimizer): - if isinstance(optimizer, ChainedOptimizer): - assert len(optimizer.chained_optimizers) == 1 - optim_param_state_A = optimizer.chained_optimizers[0].get_parameter_state_dp_zero( - use_gloo_comm=False - ) - else: - optim_param_state_A = optimizer.get_parameter_state_dp_zero(use_gloo_comm=False) - return optim_param_state_A - - -class TestOptimizer: - def setup_method(self, method): - pass - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_optimizer_params(self, tmp_path_dist_ckpt): - Utils.initialize_model_parallel(1, 1) - model = Model() - # Force optimizer state initialization - for p in model.parameters(): - p.grad = torch.ones_like(p.data) - optim = Adam(model.parameters()) - optim.step() - - model_state_dict = model.sharded_state_dict() - param_map = get_param_id_to_sharded_param_map( - model_state_dict, optim.param_groups[0]['params'] - ) - optim_state_dict = optim.state_dict() - optim_state_to_sharding_state(optim_state_dict, param_map, exclude_keys=('step',)) - - optim_sharded_tensors = nested_values(extract_sharded_tensors(optim_state_dict)[0]) - optim_sharded_keys = {sh_ten.key for sh_ten in optim_sharded_tensors} - assert len(optim_sharded_keys) == 2 * len(model_state_dict) - assert optim_sharded_keys == set( - [ - f'optimizer.state.{state_key}.{layer_name}' - for state_key in ['exp_avg', 'exp_avg_sq'] - for layer_name in model_state_dict - ] - ) - - -def initialize_pp_agnostic_model(pre_process=True, post_process=True, seed=0, **config_kwargs): - torch.manual_seed(seed) - model_parallel_cuda_manual_seed(seed) - - return SwigluFactoryModel(False) - - -def initialize_pp_agnostic_gpt_model(pre_process=True, post_process=True, seed=0, **config_kwargs): - return initialize_gpt_model(False, False, seed=seed, **config_kwargs) - - -def initialize_small_model(pre_process=True, post_process=True, seed=0, **config_kwargs): - torch.manual_seed(seed) - model_parallel_cuda_manual_seed(seed) - - return SwigluFactoryModel() - - -def initialize_1d_flatten_tensor_model( - pre_process=True, post_process=True, seed=0, **config_kwargs -): - # This model is used to test whether a 1d flatten tensor can be correctly - # transformed into torch dist-ckpt form - torch.manual_seed(seed) - model_parallel_cuda_manual_seed(seed) - - return Model1dFlattenTensor() - - -def initialize_real_model( - seed, - pre_process, - post_process, - vp_stage=None, - is_moe=False, - is_mla=False, - virtual_pipeline_model_parallel_size=None, - **config_kwargs, -): - torch.manual_seed(seed) - model_parallel_cuda_manual_seed(seed) - - default_config_kwargs = dict( - num_layers=6, - hidden_size=16, - num_attention_heads=8, - use_cpu_initialization=True, - pipeline_dtype=torch.bfloat16, - bf16=True, - virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size, - ) - if is_moe: - default_config_kwargs["moe_ffn_hidden_size"] = 128 - default_config_kwargs["num_moe_experts"] = 4 - default_config_kwargs["add_bias_linear"] = False - # Pop unused fields - config_kwargs.pop("use_sp") - config_kwargs.pop("use_te") - config_kwargs.pop("use_grouped_mlp") - config_kwargs.pop("use_glu") - if is_mla: - default_config_kwargs["multi_latent_attention"] = True - default_config_kwargs["q_lora_rank"] = 96 - default_config_kwargs["kv_lora_rank"] = 512 - default_config_kwargs["qk_head_dim"] = 64 - default_config_kwargs["qk_pos_emb_head_dim"] = 32 - default_config_kwargs["v_head_dim"] = 64 - default_config_kwargs.update(**config_kwargs) - config_cls = MLATransformerConfig if is_mla else TransformerConfig - transformer_config = config_cls(**default_config_kwargs) - - if is_moe: - layer_spec = get_gpt_decoder_block_spec( - transformer_config, use_transformer_engine=True, vp_stage=vp_stage - ) - else: - layer_spec = gpt_te_spec(multi_latent_attention=is_mla) - this_model = GPTModel( - config=transformer_config, - transformer_layer_spec=layer_spec, - vocab_size=128, - max_sequence_length=4, - pre_process=pre_process, - post_process=post_process, - vp_stage=vp_stage, - ) - - return this_model - - -def load_checkpoint_no_arg_checks(*args, **kwargs): - with mock.patch('megatron.training.checkpointing.check_checkpoint_args'): - with mock.patch('megatron.training.checkpointing.update_num_microbatches'): - return load_checkpoint(*args, **kwargs) - - -class TestDistributedOptimizer: - def setup_method(self, method): - pass - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.parametrize("fully_parallel", [False, True]) - @pytest.mark.parametrize( - ("tp_pp_ep", "is_moe", "is_mla", "test_step", "kwargs"), - [ - ((2, 2, 1), False, False, False, {}), # check TP - ((1, 2, 1), False, False, True, {}), # check "step" is synced - ((1, 2, 1), False, True, False, {}), # check param group order is right - ( - (1, 8, 1), - False, - False, - False, - { - "account_for_embedding_in_pipeline_split": True, - "account_for_loss_in_pipeline_split": True, - }, - ), # check embedding standalone - ( - (1, 2, 2), - True, - False, - True, - {"moe_layer_freq": [0, 0, 0, 1, 1, 1]}, - ), # check moe not on all ranks (case 1) - ( - (1, 2, 2), - True, - False, - True, - {"moe_layer_freq": [1, 1, 1, 0, 0, 0]}, - ), # check moe not on all ranks (case 2) - ], - ) - def test_optimizer_common_state_dict( - self, tmp_path_dist_ckpt, fully_parallel, tp_pp_ep, is_moe, is_mla, test_step, kwargs - ): - initialize_fn = partial(initialize_real_model, is_moe=is_moe, is_mla=is_mla, **kwargs) - - # Initialize parallel - tp, pp, ep = tp_pp_ep - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp, - pipeline_model_parallel_size=pp, - expert_model_parallel_size=ep, - ) - rank = torch.distributed.get_rank() - - with TempNamedDir(tmp_path_dist_ckpt / 'test_dp_sharding', sync=True) as ckpt_dir: - mock_args = parse_args(ignore_unknown_args=True) - mock_args.use_distributed_optimizer = True - with mock.patch('megatron.training.checkpointing.get_args', new=lambda: mock_args): - # Initialize model and optimizer A - if is_moe: - model, optimizer_A = setup_moe_model_and_optimizer( - seed=2, tp=tp, pp=pp, ep=ep, initialize_fn=initialize_fn - ) - else: - model, optimizer_A = setup_model_and_optimizer( - seed=2, tp=tp, pp=pp, initialize_fn=initialize_fn - ) - if test_step: - # Simulate "step" not set in some of the param groups on rank 0. - # TE FusedAdam may have "step" not set in some of the param groups on some ranks. - for i, param_group in enumerate( - optimizer_A.chained_optimizers[0].optimizer.param_groups - ): - if rank > 0 or i == 0: - param_group['step'] = 1234 - - # Save checkpoint - init_checkpointing_mock_args(mock_args, ckpt_dir, fully_parallel=fully_parallel) - from megatron.training.training import preprocess_common_state_dict - - save_checkpoint( - 10, - model, - optimizer_A, - None, - 0, - preprocess_common_state_dict_fn=preprocess_common_state_dict, - ) - - # Get optimizer A param state - optim_param_state_A = optimizer_A.state_dict() - - # Initialize model and optimizer B - if is_moe: - model, optimizer_B = setup_moe_model_and_optimizer( - seed=3, tp=tp, pp=pp, ep=ep, initialize_fn=initialize_fn - ) - else: - model, optimizer_B = setup_model_and_optimizer( - seed=3, tp=tp, pp=pp, initialize_fn=initialize_fn - ) - # Load optimizer B from checkpoint - load_checkpoint_no_arg_checks(model, optimizer_B, None) - if test_step: - # Complete "step" for comparison - for i, param_group in enumerate( - optimizer_A.chained_optimizers[0].optimizer.param_groups - ): - if rank == 0 and i > 0: - param_group['step'] = 1234 - # Get optimizer B param state - optim_param_state_B = optimizer_B.state_dict() - - # Test both param state dicts are equal - diffs = diff(optim_param_state_A, optim_param_state_B) - assert not any(map(bool, diffs)), (rank, diffs) - - Utils.destroy_model_parallel() - - @pytest.mark.parametrize( - ('src_tp_pp', 'dest_tp_pp', 'use_glu'), - [((2, 2), (2, 4), False), ((1, 8), (4, 1), True), ((2, 4), (4, 2), False)], - ) - def test_finetune_doesnt_load_optimizer( - self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp, use_glu - ): - """Test finetuning doesn't try to load the optimizer.""" - Utils.initialize_model_parallel(*src_tp_pp) - with TempNamedDir( - tmp_path_dist_ckpt / 'test_finetune_doesnt_load_optimizer', sync=True - ) as ckpt_dir: - mock_args = parse_args(ignore_unknown_args=True) - with mock.patch('megatron.training.checkpointing.get_args', new=lambda: mock_args): - init_basic_mock_args(mock_args, tp=src_tp_pp[0], pp=src_tp_pp[1]) - init_checkpointing_mock_args(mock_args, ckpt_dir, False) - - model, optimizer = setup_model_and_optimizer( - seed=2, - tp=src_tp_pp[0], - pp=src_tp_pp[1], - initialize_fn=partial(initialize_gpt_model, use_glu=use_glu), - ) - - save_checkpoint(10, model, optimizer, None, 0) - Utils.destroy_model_parallel() - - Utils.initialize_model_parallel(*dest_tp_pp) - mock_args.tensor_model_parallel_size = dest_tp_pp[0] - mock_args.pipeline_model_parallel_size = dest_tp_pp[1] - model, optimizer = setup_model_and_optimizer( - seed=3, - tp=dest_tp_pp[0], - pp=dest_tp_pp[1], - initialize_fn=partial(initialize_gpt_model, use_glu=use_glu), - ) - model_unloaded_state_dict = deepcopy(model[0].state_dict()) - optim_unloaded_state_dict = deepcopy(optimizer.state_dict()) - - # Load with different TPxPP should raise DistributeOptimizer error - with pytest.raises(RuntimeError) as exc_info: - load_checkpoint_no_arg_checks(model, optimizer, None) - # "(TP, PP) mismatch" check is for backwards compatibility tests - assert "(TP, PP) mismatch" in str( - exc_info.value - ) or "(TP, PP, encoder TP, encoder PP) mismatch" in str(exc_info.value) - - # Check that the state didn't change - assert not any(diff(model[0].state_dict(), model_unloaded_state_dict)) - assert not any(diff(optimizer.state_dict(), optim_unloaded_state_dict)) - - # Now test the same with a `finetune` flag - mock_args.finetune = True - load_checkpoint_no_arg_checks(model, optimizer, None) - - # Model weights should be different, but optimizer state is unchanged - diffs = diff(model[0].state_dict(), model_unloaded_state_dict) - # diffs[0] and diffs[1] is structural diff, diffs[2] is values diff - - # we expect only values diff - assert not diffs[0] and not diffs[1] and diffs[2] - assert not any(diff(optimizer.state_dict(), optim_unloaded_state_dict)) - - # ... or `no_load_optim` flag - model, optimizer = setup_model_and_optimizer( - seed=3, - tp=dest_tp_pp[0], - pp=dest_tp_pp[1], - initialize_fn=partial(initialize_gpt_model, use_glu=use_glu), - ) - mock_args.finetune = False - mock_args.no_load_optim = True - mock_args.no_load_rng = True - load_checkpoint_no_arg_checks(model, optimizer, None) - - # Model weights should be different, but optimizer state is unchanged - diffs = diff(model[0].state_dict(), model_unloaded_state_dict) - # diffs[0] and diffs[1] is structural diff, diffs[2] is values diff - - # we expect only values diff - assert not diffs[0] and not diffs[1] and diffs[2] - assert not any(diff(optimizer.state_dict(), optim_unloaded_state_dict)) - - @pytest.mark.skipif( - not is_torch_min_version("2.6a0"), reason="dp_reshardable requires PyTorch 2.5 or later" - ) - @pytest.mark.parametrize( - ('src_tp_pp', 'dest_tp_pp', 'src_bucket_pad_divisor', 'dest_bucket_pad_divisor'), - [ - # PP must be decreasing - # Note: PP must be > 1 if TP <= 2 because of empty buckets otherwise - ((1, 2), (1, 2), 8 * 7, 8 * 5), - ((2, 4), (2, 4), 128, 128), - ((8, 1), (8, 1), 8, 4 * 11), - # DP resharding: - ((4, 2), (4, 1), 8 * 7, 8 * 5), - ((2, 4), (2, 2), 128, 128), - ((1, 4), (1, 2), 8, 4 * 11), - ((1, 8), (1, 4), 8 * 7, 8 * 5), - ((1, 8), (1, 2), 128, 128), - ], - ) - def test_bucket_space_optimizer_save_load( - self, - tmp_path_dist_ckpt, - src_tp_pp, - dest_tp_pp, - src_bucket_pad_divisor, - dest_bucket_pad_divisor, - ): - """Test DistOpt save/load with dp_reshardable format. - - Since unit test have a fixed world size and "bucket_space" format is - only DP-reshardable, we can't simply change DP. The trick is to use PP rank - agnostic model and decrease PP for load - some DP groups will be missing - but the common subset is enough to test correctness. - """ - Utils.initialize_model_parallel(*src_tp_pp) - src_num_dp_groups = src_tp_pp[1] * src_tp_pp[0] - dest_num_dp_groups = dest_tp_pp[1] * dest_tp_pp[0] - assert ( - dest_num_dp_groups <= src_num_dp_groups - ), 'This test cant be run with increasing number of DP groups' - - with ( - TempNamedDir( - tmp_path_dist_ckpt / 'test_bucket_state_optimizer_save_load_A', sync=True - ) as ckpt_dir_A, - TempNamedDir( - tmp_path_dist_ckpt / 'test_bucket_state_optimizer_save_load_B', sync=True - ) as ckpt_dir_B, - ): - # Init model and optimizer with "src" bucket padding - with patch('megatron.core.distributed.param_and_grad_buffer.math.lcm') as lcm_mock: - lcm_mock.return_value = src_bucket_pad_divisor - assert len(lcm_mock.mock_calls) == 0 - model_A, optimizer_A = setup_model_and_optimizer( - seed=2, - tp=src_tp_pp[0], - pp=src_tp_pp[1], - bf16=True, - dist_opt=True, - initialize_fn=initialize_pp_agnostic_model, - ) - assert len(lcm_mock.mock_calls) > 1 - - metadata = {'distrib_optim_sharding_type': 'dp_reshardable'} - - model_sharded_sd = model_A[0].sharded_state_dict() - optim_sd = optimizer_A.sharded_state_dict(model_sharded_sd, metadata=metadata) - per_bucket_numel_unpadded_A = optim_sd['param_state']['per_bucket_numel_unpadded'].data - save(optim_sd, ckpt_dir_A) - Utils.destroy_model_parallel() - - # Load checkpoint A with different PP (and therefore DP) and save as checkpoint B - Utils.initialize_model_parallel(*dest_tp_pp) - dest_dp_group_idx = torch.distributed.get_rank( - parallel_state.get_model_parallel_group() - ) - # Init model and optimizer with "dest" bucket padding - with patch('megatron.core.distributed.param_and_grad_buffer.math.lcm') as lcm_mock: - lcm_mock.return_value = dest_bucket_pad_divisor - assert len(lcm_mock.mock_calls) == 0 - model_B, optimizer_B = setup_model_and_optimizer( - seed=3, - tp=dest_tp_pp[0], - pp=dest_tp_pp[1], - bf16=True, - dist_opt=True, - initialize_fn=initialize_pp_agnostic_model, - ) - assert len(lcm_mock.mock_calls) > 1 - - model_sharded_sd = model_B[0].sharded_state_dict() - load_sharded_state_dict = optimizer_B.sharded_state_dict( - model_sharded_sd, metadata=metadata, is_loading=True - ) - state_dict, missing_keys, unexpected_keys = load( - load_sharded_state_dict, ckpt_dir_A, strict=StrictHandling.RETURN_ALL - ) - - # Check that because of decreasing PP, some DP groups were not read. - assert not unexpected_keys - missing_dp_groups = set() - for missing_key in missing_keys: - match = re.search(r'dp_group_idx_(\d+)', missing_key) - assert match is not None - missing_dp_groups.add(int(match.group(1))) - - assert missing_dp_groups == set(range(dest_num_dp_groups, src_num_dp_groups)) - - # Save optimizer B checkpoint to compare them - optimizer_B.load_state_dict(state_dict) - model_sharded_sd = model_B[0].sharded_state_dict() - optim_sd = optimizer_B.sharded_state_dict(model_sharded_sd, metadata=metadata) - per_bucket_numel_unpadded_B = optim_sd['param_state']['per_bucket_numel_unpadded'].data - save(optim_sd, ckpt_dir_B) - Utils.destroy_model_parallel() - - # Test both checkpoints are equal - # Ckpt A has more keys because of larger PP. - # Each rank tests correctness within its DP group, and only unpadded tensor part. - assert per_bucket_numel_unpadded_A == per_bucket_numel_unpadded_B - assert len(per_bucket_numel_unpadded_A) == 1 # Assuming a simple case with one buffer - per_bucket_numel_unpadded_A = per_bucket_numel_unpadded_A[0] - assert len(per_bucket_numel_unpadded_A) == 1 # Assuming a simple case with one dtype - per_bucket_numel_unpadded = next(iter(per_bucket_numel_unpadded_A.values())) - Utils.initialize_model_parallel(1, 1) - plain_state_dict_A = load_plain_tensors(ckpt_dir_A) - plain_state_dict_B = load_plain_tensors(ckpt_dir_B) - torch.distributed.barrier() - - # We test only the `plain_state_dict_B` keys because of decreasing PP - for key in list(plain_state_dict_B.keys()): - if 'per_bucket_numel' in key or 'param_state_sharding_type' in key: - del plain_state_dict_A[key] - del plain_state_dict_B[key] - continue - match = re.search(r'dp_group_idx_(\d+).+bucket_idx_(\d+)', key) - assert match is not None, key - dp_group_idx = int(match.group(1)) - bucket_idx = int(match.group(2)) - if dp_group_idx != dest_dp_group_idx: - del plain_state_dict_A[key] - del plain_state_dict_B[key] - else: - numel_unpadded = per_bucket_numel_unpadded[bucket_idx] - assert len(plain_state_dict_A[key]) == numel_unpadded - assert len(plain_state_dict_B[key]) == numel_unpadded - - only_left, only_right, mismatch = diff(plain_state_dict_A, plain_state_dict_B) - - missing_tensors = set( - key - for key in missing_keys - if 'per_bucket_numel' not in key and not key.endswith('.optimizer') - ) - assert set(key[0] for key in only_left) == missing_tensors - assert not only_right - assert not mismatch - - @pytest.mark.skipif( - not is_torch_min_version("2.6a0"), reason="dp_reshardable requires PyTorch 2.7 or later" - ) - @pytest.mark.parametrize( - ('src_tp_pp', 'dest_tp_pp', 'sharding_type', 'mem_efficient'), - [ - # PP must be decreasing - # Note: PP must be > 1 if TP <= 2 because of empty buckets otherwise - ((2, 4), (2, 4), 'fully_reshardable', False), - ((4, 2), (4, 2), 'dp_reshardable', None), - ((8, 1), (8, 1), 'fully_sharded_model_space', None), - # DP resharding: - ((4, 2), (4, 1), 'dp_reshardable', None), - ((2, 4), (2, 2), 'fully_reshardable', False), - ((2, 4), (2, 2), 'fully_reshardable', True), - ((1, 8), (1, 2), 'fully_sharded_model_space', None), - ], - ) - @pytest.mark.parametrize("initalize_fn", [initialize_pp_agnostic_model]) - def test_nonreshardable_optimizer_save_load( - self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp, initalize_fn, sharding_type, mem_efficient - ): - """Generalization of the test above for different formats. - - This time we don't load "plain" tensors from the checkpoint to compare. - Instead, we use `get_param_state_dp_zero` method to have common representation - irrespective of DP size. - - This test requires src and dest optimizers to be on the same rank. - The `test_model_parallel_dp_group_idx_preservation` test checks that - there is at least one testing rank for each DP group, given the 'tp-pp-dp' - parallel state order. - """ - Utils.initialize_model_parallel(*src_tp_pp, order='tp-pp-dp') - src_num_dp_groups = src_tp_pp[1] * src_tp_pp[0] - dest_num_dp_groups = dest_tp_pp[1] * dest_tp_pp[0] - src_dp_group_idx = torch.distributed.get_rank(parallel_state.get_model_parallel_group()) - assert ( - dest_num_dp_groups <= src_num_dp_groups - ), 'This test cant be run with increasing number of DP groups' - - with TempNamedDir( - tmp_path_dist_ckpt / 'test_nonreshardable_optimizer_save_load', sync=True - ) as ckpt_dir_A: - model_A, optimizer_A = setup_model_and_optimizer( - seed=2, - tp=src_tp_pp[0], - pp=src_tp_pp[1], - bf16=True, - dist_opt=True, - initialize_fn=initalize_fn, - ) - - metadata = { - 'distrib_optim_sharding_type': sharding_type, - 'distrib_optim_fully_reshardable_mem_efficient': mem_efficient, - } - - model_sharded_sd = model_A[0].sharded_state_dict() - optim_sd = optimizer_A.sharded_state_dict(model_sharded_sd, metadata=metadata) - save(optim_sd, ckpt_dir_A) - - dp_zero_optim_A = get_param_state_dp_zero(optimizer_A) - Utils.destroy_model_parallel() - - # Load checkpoint A with different TP/PP and save as checkpoint B - Utils.initialize_model_parallel(*dest_tp_pp, order='tp-pp-dp') - dest_dp_group_idx = torch.distributed.get_rank( - parallel_state.get_model_parallel_group() - ) - same_dp_group = src_dp_group_idx == dest_dp_group_idx - model_B, optimizer_B = setup_model_and_optimizer( - seed=3, - tp=dest_tp_pp[0], - pp=dest_tp_pp[1], - bf16=True, - dist_opt=True, - initialize_fn=initalize_fn, - ) - # Before checkpoint load the state is expected to differ - dp_zero_optim_B = get_param_state_dp_zero(optimizer_B) - assert not self.check_equal_dp_zero_state( - dp_zero_optim_A, dp_zero_optim_B, same_dp_group - ) - - model_sharded_sd = model_B[0].sharded_state_dict() - load_sharded_state_dict = optimizer_B.sharded_state_dict( - model_sharded_sd, metadata=metadata, is_loading=True - ) - - state_dict, missing_keys, unexpected_keys = load( - load_sharded_state_dict, ckpt_dir_A, strict=StrictHandling.RETURN_ALL - ) - assert not unexpected_keys - missing_dp_groups = set() - for missing_key in missing_keys: - match = re.search(r'dp_group_idx_(\d+)', missing_key) - assert match is not None - missing_dp_groups.add(int(match.group(1))) - - optimizer_B.load_state_dict(state_dict) - dp_zero_optim_B = get_param_state_dp_zero(optimizer_B) - - assert self.check_equal_dp_zero_state( - dp_zero_optim_A, dp_zero_optim_B, same_dp_group, raise_if_different=True - ) - - def check_equal_dp_zero_state( - self, dp_zero_state_A, dp_zero_state_B, same_dp_group, raise_if_different=False - ): - if same_dp_group and parallel_state.get_data_parallel_rank(with_context_parallel=True) == 0: - diffs = diff(dp_zero_state_A, dp_zero_state_B) - is_equal = not any(map(bool, diffs)) - else: - diffs = None - is_equal = True - - all_equal = torch.tensor(int(is_equal), device='cuda') - torch.distributed.all_reduce(all_equal, op=torch.distributed.ReduceOp.MIN) - if bool(all_equal.item()): - return True - else: - if raise_if_different: - raise RuntimeError(f'[{Utils.rank}] {diffs}') - return False - - @pytest.mark.parametrize('tp_pp', [(2, 4), (4, 2), (1, 1), (2, 1), (1, 8)]) - def test_model_parallel_rank_order(self, tp_pp): - """Verifies that DP group idx is `PP rank * TP size + TP rank`.""" - Utils.initialize_model_parallel(*tp_pp, order='tp-pp-dp') - pp_rank = parallel_state.get_pipeline_model_parallel_rank() - tp_rank = parallel_state.get_tensor_model_parallel_rank() - tp_size = parallel_state.get_tensor_model_parallel_world_size() - dp_group_idx = torch.distributed.get_rank(parallel_state.get_model_parallel_group()) - - assert pp_rank * tp_size + tp_rank == dp_group_idx - - @pytest.mark.parametrize( - ('src_pp', 'dest_pp'), - [ - # PP must be decreasing - (8, 1), - (4, 1), - (2, 1), - (1, 1), - (8, 2), - (4, 2), - (2, 2), - (8, 4), - (4, 4), - (8, 4), - ], - ) - @pytest.mark.parametrize('tp', [1, 2, 4, 8]) - def test_model_parallel_dp_group_idx_preservation(self, tp, src_pp, dest_pp): - """For each dst DP group, test there is at least one DP 0 rank both in the src and dest group. - - For this condition to hold, `parallel_state` must be initialized with 'tp-pp-dp' order. - """ - assert src_pp >= dest_pp, 'This test is only for decreasing PP' - if tp * src_pp > Utils.world_size: - pytest.skip(f'TP ({tp}) * PP ({src_pp}) > {Utils.world_size}') - - Utils.initialize_model_parallel(tp, src_pp, order='tp-pp-dp') - src_dp_group_idx = torch.distributed.get_rank(parallel_state.get_model_parallel_group()) - Utils.initialize_model_parallel(tp, dest_pp, order='tp-pp-dp') - dest_dp_group_idx = torch.distributed.get_rank(parallel_state.get_model_parallel_group()) - num_dest_dp_groups = tp * dest_pp - - is_dp_rank_zero = parallel_state.get_data_parallel_rank(with_context_parallel=True) == 0 - if src_dp_group_idx == dest_dp_group_idx and is_dp_rank_zero: - same_dp_group_idx = src_dp_group_idx - else: - same_dp_group_idx = None - - same_groups = [None] * Utils.world_size - torch.distributed.all_gather_object(same_groups, same_dp_group_idx) - - same_groups = set(g for g in same_groups if g is not None) - # Check each dst group has at least 1 rank both in src and dest - assert same_groups == set(range(num_dest_dp_groups)) - - -class TestFP32Optimizer: - def setup_method(self, method): - pass - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.parametrize( - ('src_tp_pp', 'dest_tp_pp'), [((2, 4), (2, 4)), ((2, 4), (4, 2)), ((8, 1), (1, 2))] - ) - def test_fp32_optimizer_resharding(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp): - # sync=True to make sure other ranks wait for rank 0 to finish creating directory. - - def preprocess_fn(optim_common_dict): - import copy - - preprocessed_optimzier_common_dict = copy.deepcopy(optim_common_dict) - list = preprocessed_optimzier_common_dict['optimizer']['param_groups'] - for dict_item in list: - del dict_item['wd_mult'] - return preprocessed_optimzier_common_dict - - Utils.initialize_model_parallel(*src_tp_pp) - with TempNamedDir( - tmp_path_dist_ckpt / 'test_fp32_optimizer_state_dict_A', sync=True - ) as ckpt_dir_A: - with TempNamedDir( - tmp_path_dist_ckpt / 'test_fp32_optimizer_state_dict_B', sync=True - ) as ckpt_dir_B: - - model_A, optimizer_A = setup_model_and_optimizer( - seed=2, - tp=src_tp_pp[0], - pp=src_tp_pp[1], - initialize_fn=initialize_small_model, - bf16=False, - ) - - save( - optimizer_A.sharded_state_dict(model_A[0].sharded_state_dict()), - ckpt_dir_A, - preprocess_common_before_consistancy_check=preprocess_fn, - ) - Utils.destroy_model_parallel() - - # Load checkpoint A with different TP/PP and save as checkpoint B - Utils.initialize_model_parallel(*dest_tp_pp) - model_B, optimizer_B = setup_model_and_optimizer( - seed=3, - tp=dest_tp_pp[0], - pp=dest_tp_pp[1], - initialize_fn=initialize_small_model, - bf16=False, - ) - load_sharded_state_dict = optimizer_B.sharded_state_dict( - model_B[0].sharded_state_dict(), is_loading=True - ) - state_dict = load(load_sharded_state_dict, ckpt_dir_A) - - optimizer_B.load_state_dict(state_dict) - save(optimizer_B.sharded_state_dict(model_B[0].sharded_state_dict()), ckpt_dir_B) - Utils.destroy_model_parallel() - - # Test both checkpoints are equal - Utils.initialize_model_parallel(1, 1) - plain_state_dict_A = load_plain_tensors(ckpt_dir_A) - plain_state_dict_B = load_plain_tensors(ckpt_dir_B) - diffs = diff(plain_state_dict_A, plain_state_dict_B) - assert not any(map(bool, diffs)), diffs - - -class TestOptimizerResharding: - def setup_method(self, method): - pass - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.parametrize( - ('use_dist_opt', 'bf16', 'fully_parallel'), - ( - (False, True, False), # regular BF16 - (True, True, False), # DistOpt BF16 - (True, True, True), # DistOpt BF16 - (False, False, False), # FP32 - ), - ) - @pytest.mark.parametrize( - ('src_tp_pp', 'dest_tp_pp'), - [((2, 4), (2, 4)), ((2, 4), (2, 2)), ((2, 4), (4, 2)), ((8, 1), (1, 2))], - ) - @pytest.mark.parametrize( - "initialize_fn", [initialize_gpt_model, initialize_1d_flatten_tensor_model] - ) - def test_optimizer_resharding( - self, - tmp_path_dist_ckpt, - src_tp_pp, - dest_tp_pp, - use_dist_opt, - bf16, - initialize_fn, - fully_parallel, - ): - Utils.initialize_model_parallel(*src_tp_pp) - with TempNamedDir( - tmp_path_dist_ckpt / 'test_fp32_optimizer_state_dict_A', sync=False - ) as ckpt_dir_A: - with TempNamedDir( - tmp_path_dist_ckpt / 'test_fp32_optimizer_state_dict_B', sync=False - ) as ckpt_dir_B: - model_A, optimizer_A = setup_model_and_optimizer( - seed=2, - tp=src_tp_pp[0], - pp=src_tp_pp[1], - bf16=bf16, - dist_opt=use_dist_opt, - initialize_fn=initialize_fn, - ) - - if fully_parallel: - metadata = {'distrib_optim_sharding_type': 'fully_sharded_model_space'} - else: - metadata = {'distrib_optim_sharding_type': 'fully_reshardable'} - - save( - optimizer_A.sharded_state_dict( - model_A[0].sharded_state_dict(), metadata=metadata - ), - ckpt_dir_A, - ) - Utils.destroy_model_parallel() - - # Load checkpoint A with different TP/PP and save as checkpoint B - Utils.initialize_model_parallel(*dest_tp_pp) - model_B, optimizer_B = setup_model_and_optimizer( - seed=3, - tp=dest_tp_pp[0], - pp=dest_tp_pp[1], - bf16=bf16, - dist_opt=use_dist_opt, - initialize_fn=initialize_fn, - ) - load_sharded_state_dict = optimizer_B.sharded_state_dict( - model_B[0].sharded_state_dict(), metadata=metadata, is_loading=True - ) - state_dict = load(load_sharded_state_dict, ckpt_dir_A) - - optimizer_B.load_state_dict(state_dict) - save( - optimizer_B.sharded_state_dict( - model_B[0].sharded_state_dict(), metadata=metadata - ), - ckpt_dir_B, - ) - Utils.destroy_model_parallel() - - # Test both checkpoints are equal - Utils.initialize_model_parallel(1, 1) - plain_state_dict_A = load_plain_tensors(ckpt_dir_A) - plain_state_dict_B = load_plain_tensors(ckpt_dir_B) - diffs = diff(plain_state_dict_A, plain_state_dict_B) - assert not any(map(bool, diffs)), diffs - - @pytest.mark.parametrize('fully_parallel', (False, True)) - @pytest.mark.parametrize(('use_te', 'use_grouped_mlp'), ((False, False), (False, True))) - @pytest.mark.parametrize('use_glu', [False, True]) - @pytest.mark.parametrize( - ('src_tp_pp_exp', 'dest_tp_pp_exp'), - [ - ((2, 2, 2), (2, 2, 2)), - ((4, 1, 2), (1, 2, 2)), - ((1, 1, 2), (1, 1, 4)), - ((2, 1, 2), (1, 1, 8)), - ], - ) - def test_chained_optimizer_resharding( - self, - tmp_path_dist_ckpt, - src_tp_pp_exp, - dest_tp_pp_exp, - use_te, - use_grouped_mlp, - use_glu, - fully_parallel, - ): - src_tp, src_pp, src_exp = src_tp_pp_exp - dest_tp, dest_pp, dest_exp = dest_tp_pp_exp - with TempNamedDir( - tmp_path_dist_ckpt / 'test_fp32_optimizer_state_dict_A', sync=False - ) as ckpt_dir_A: - with TempNamedDir( - tmp_path_dist_ckpt / 'test_fp32_optimizer_state_dict_B', sync=False - ) as ckpt_dir_B: - Utils.initialize_model_parallel(src_tp, src_pp, expert_model_parallel_size=src_exp) - model_A, optimizer_A = setup_moe_model_and_optimizer( - seed=2, - tp=src_tp, - pp=src_pp, - ep=src_exp, - bf16=True, - dist_opt=True, - use_te=use_te, - use_grouped_mlp=use_grouped_mlp, - use_glu=use_glu, - ) - - if fully_parallel: - metadata = {'distrib_optim_sharding_type': 'fully_sharded_model_space'} - else: - metadata = {'distrib_optim_sharding_type': 'fully_reshardable'} - - save( - optimizer_A.sharded_state_dict( - model_A[0].sharded_state_dict(), metadata=metadata - ), - ckpt_dir_A, - ) - Utils.destroy_model_parallel() - - # Load checkpoint A with different TP/PP and save as checkpoint B - Utils.initialize_model_parallel( - dest_tp, dest_pp, expert_model_parallel_size=dest_exp - ) - model_B, optimizer_B = setup_moe_model_and_optimizer( - seed=3, - tp=dest_tp, - pp=dest_pp, - ep=dest_exp, - bf16=True, - dist_opt=True, - use_te=use_te, - use_grouped_mlp=use_grouped_mlp, - use_glu=use_glu, - ) - load_sharded_state_dict = optimizer_B.sharded_state_dict( - model_B[0].sharded_state_dict(), metadata=metadata, is_loading=True - ) - state_dict = load(load_sharded_state_dict, ckpt_dir_A) - - optimizer_B.load_state_dict(state_dict) - save( - optimizer_B.sharded_state_dict( - model_B[0].sharded_state_dict(), metadata=metadata - ), - ckpt_dir_B, - ) - Utils.destroy_model_parallel() - - # Test both checkpoints are equal - Utils.initialize_model_parallel(1, 1) - plain_state_dict_A = load_plain_tensors(ckpt_dir_A) - plain_state_dict_B = load_plain_tensors(ckpt_dir_B) - diffs = diff(plain_state_dict_A, plain_state_dict_B) - assert not any(map(bool, diffs)), diffs - Utils.destroy_model_parallel() diff --git a/tests/unit_tests/dist_checkpointing/test_pipeline_parallel_layout.py b/tests/unit_tests/dist_checkpointing/test_pipeline_parallel_layout.py deleted file mode 100644 index 8768ae6ed4..0000000000 --- a/tests/unit_tests/dist_checkpointing/test_pipeline_parallel_layout.py +++ /dev/null @@ -1,404 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -import os -from types import SimpleNamespace - -import pytest -import torch - -from megatron.core import mpu -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec -from megatron.core.models.gpt.gpt_layer_specs import ( - get_gpt_layer_with_transformer_engine_spec as gpt_te_spec, -) -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.num_microbatches_calculator import ( - init_num_microbatches_calculator, - unset_num_microbatches_calculator, -) -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.enums import ModelType -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.training.checkpointing import load_checkpoint, save_checkpoint -from megatron.training.global_vars import set_args -from tests.unit_tests.dist_checkpointing import TempNamedDir -from tests.unit_tests.dist_checkpointing.models.common import ( - common_test_parallel_reconfiguration_e2e, -) -from tests.unit_tests.test_utilities import Utils - - -def initialize_gpt_model( - seed, - layer_spec_fn=gpt_te_spec, - vocab_size=128, - virtual_pipeline_model_parallel_size=None, - is_moe=False, - **config_kwargs, -): - torch.manual_seed(seed) - model_parallel_cuda_manual_seed(seed) - - default_config_kwargs = dict( - num_layers=8, - hidden_size=16, - num_attention_heads=8, - use_cpu_initialization=True, - pipeline_dtype=torch.bfloat16, - bf16=True, - virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size, - ) - default_config_kwargs.update(**config_kwargs) - transformer_config = TransformerConfig(**default_config_kwargs) - if is_moe: - transformer_config.moe_layer_freq = [0, 1, 1, 1, 1, 0, 1, 0] - transformer_config.moe_ffn_hidden_size = 128 - transformer_config.num_moe_experts = 4 - model = [] - for i in range(virtual_pipeline_model_parallel_size or 1): - if is_moe: - layer_spec = layer_spec_fn(transformer_config, use_transformer_engine=True, vp_stage=i) - else: - layer_spec = layer_spec_fn() - pre_process = mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=i) - post_process = mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=i) - this_model = GPTModel( - config=transformer_config, - transformer_layer_spec=layer_spec, - vocab_size=vocab_size, - max_sequence_length=4, - pre_process=pre_process, - post_process=post_process, - vp_stage=i, - ) - this_model.model_type = ModelType.encoder_or_decoder - model.append(this_model) - - with torch.no_grad(): - for m in model: - for p in m.parameters(): - p.random_() - if virtual_pipeline_model_parallel_size is None: - model = model[0] - return model - - -# Dense Model Only -@pytest.mark.internal -def test_save_and_load_checkpoint_pp(tmp_path_dist_ckpt): - src_layer_spec_fn = gpt_te_spec - dst_layer_spec_fn = gpt_te_spec - use_fpsl = False - load_order = 'tp-dp-pp' - store_order = 'tp-dp-pp' - src_tp_pp = (1, 4) - src_model_init_kwargs = { - "pipeline_model_parallel_layout": [ - ["embedding"] + ["decoder"] * 2, - ["decoder"] * 3, - [], - ["decoder"] * 3 + ["loss"], - ] - } - dest_tp_pp = (2, 1) - - common_test_parallel_reconfiguration_e2e( - initialize_gpt_model, - tmp_path_dist_ckpt, - src_tp_pp, - dest_tp_pp, - src_layer_spec_fn, - dst_layer_spec_fn, - use_fpsl, - load_order, - store_order, - src_model_init_kwargs=src_model_init_kwargs, - ) - - -@pytest.fixture -def create_args(): - """Setup dummy args.""" - args = SimpleNamespace() - args.finetune = False - args.non_persistent_global_ckpt_dir = None - args.non_persistent_ckpt_type = None - args.non_persistent_save_interval = None - args.exit_on_missing_checkpoint = True - args.async_save = False - args.data_parallel_random_init = False - args.log_progress = False - args.ckpt_fully_parallel_save = False - args.ckpt_fully_parallel_load = False - args.auto_detect_ckpt_format = False - args.retro_add_retriever = False - args.ckpt_convert_update_legacy_dist_opt_format = False - args.ckpt_step = None - args.use_dist_ckpt = True - args.consumed_train_samples = 0 - args.skipped_train_samples = 0 - args.consumed_valid_samples = 0 - args.vocab_file = None - args.add_position_embedding = False - args.ckpt_assume_constant_structure = True - args.dist_ckpt_strictness = "assume_ok_unexpected" - args.fp16 = False - args.bf16 = True - args.no_save_optim = True - args.no_save_rng = True - args.no_load_optim = True - args.no_load_rng = True - args.use_distributed_optimizer = True - args.use_megatron_fsdp = False - - yield args - - -# Dense and MoE Models -@pytest.mark.parametrize( - ('src_tp_pp_vpp', 'dst_tp_pp_vpp', 'src_pp_layout', 'dst_pp_layout', 'is_moe'), - [ - ( - (1, 4, 2), - (1, 2, 1), - [ - ["embedding"], - ["decoder"], - ["decoder"] * 2, - ["decoder"], - [], - ["decoder"], - ["decoder"], - ["decoder"] * 2 + ["loss"], - ], - [["embedding"] + ["decoder"] * 4, ["decoder"] * 4 + ["loss"]], - False, - ), - ( - (1, 4, 2), - (2, 1, None), - [ - ["embedding"], - ["decoder"], - ["decoder"] * 2, - ["decoder"], - [], - ["decoder"], - ["decoder"], - ["decoder"] * 2 + ["loss"], - ], - None, - False, - ), - ( - (4, 1, None), - (1, 2, 1), - None, - [["embedding"] + ["decoder"] * 4, ["decoder"] * 4 + ["loss"]], - False, - ), - ( - (1, 2, 1), - (1, 4, 2), - [["embedding"] + ["decoder"] * 4, ["decoder"] * 4 + ["loss"]], - [ - ["embedding"], - ["decoder"], - ["decoder"] * 2, - ["decoder"], - [], - ["decoder"], - ["decoder"], - ["decoder"] * 2 + ["loss"], - ], - False, - ), - ( - (1, 4, 2), - (1, 2, 1), - [ - ["embedding"], - ["decoder"], - ["decoder"] * 2, - ["decoder"], - [], - ["decoder"], - ["decoder"], - ["decoder"] * 2 + ["loss"], - ], - [["embedding"] + ["decoder"] * 4, ["decoder"] * 4 + ["loss"]], - True, - ), - ( - (1, 4, 2), - (2, 1, None), - [ - ["embedding"], - ["decoder"], - ["decoder"] * 2, - ["decoder"], - [], - ["decoder"], - ["decoder"], - ["decoder"] * 2 + ["loss"], - ], - None, - True, - ), - ( - (4, 1, None), - (1, 2, 1), - None, - [["embedding"] + ["decoder"] * 4, ["decoder"] * 4 + ["loss"]], - True, - ), - ( - (1, 2, 1), - (1, 4, 2), - [["embedding"] + ["decoder"] * 4, ["decoder"] * 4 + ["loss"]], - [ - ["embedding"], - ["decoder"], - ["decoder"] * 2, - ["decoder"], - [], - ["decoder"], - ["decoder"], - ["decoder"] * 2 + ["loss"], - ], - True, - ), - ], -) -def test_save_and_load_checkpoint_vpp( - create_args, - tmp_path_dist_ckpt, - src_tp_pp_vpp, - src_pp_layout, - dst_tp_pp_vpp, - dst_pp_layout, - is_moe, -): - args = create_args - # Model config - args.num_layers = 8 - args.hidden_size = 8 - args.num_attention_heads = 8 - # Ckpt format - args.ckpt_format = "torch_dist" - set_args(args) - - def set_tp_pp_vpp(tp, pp, vpp=None, pp_layout=None, destroy_first=True): - if destroy_first: - Utils.destroy_model_parallel() - Utils.initialize_model_parallel(tp, pp, vpp) - args.tensor_model_parallel_size = tp - args.pipeline_model_parallel_size = pp - args.virtual_pipeline_model_parallel_size = vpp - args.pipeline_model_parallel_layout = pp_layout - - def set_ckpt_path(ckpt_path): - args.save = ckpt_path - args.load = ckpt_path - - set_tp_pp_vpp(*src_tp_pp_vpp, pp_layout=src_pp_layout, destroy_first=False) - init_num_microbatches_calculator(0, None, 1, 1, 1) - - iteration = 123 - layer_spec_fn = get_gpt_decoder_block_spec if is_moe else gpt_te_spec - model = initialize_gpt_model( - 1, - layer_spec_fn=layer_spec_fn, - num_layers=args.num_layers, - hidden_size=args.hidden_size, - num_attention_heads=args.num_attention_heads, - tensor_model_parallel_size=args.tensor_model_parallel_size, - pipeline_model_parallel_size=args.pipeline_model_parallel_size, - virtual_pipeline_model_parallel_size=args.virtual_pipeline_model_parallel_size, - pipeline_model_parallel_layout=args.pipeline_model_parallel_layout, - is_moe=is_moe, - ) - model = model if isinstance(model, list) else [model] - optimizer = None - opt_param_scheduler = None - num_floating_point_operations_so_far = 456 - - with ( - TempNamedDir(tmp_path_dist_ckpt / 'test_gpt_model_reconfiguration_model_A') as ckpt_dir_A, - TempNamedDir(tmp_path_dist_ckpt / 'test_gpt_model_reconfiguration_model_B') as ckpt_dir_B, - ): - set_ckpt_path(ckpt_dir_A) - save_checkpoint( - iteration, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far - ) - - expected_ckpt_path = args.save / "iter_0000123" / ".metadata" - assert os.path.exists(expected_ckpt_path) - - set_tp_pp_vpp(*dst_tp_pp_vpp, pp_layout=dst_pp_layout) - new_model = initialize_gpt_model( - 2, - layer_spec_fn=layer_spec_fn, - num_layers=args.num_layers, - hidden_size=args.hidden_size, - num_attention_heads=args.num_attention_heads, - tensor_model_parallel_size=args.tensor_model_parallel_size, - pipeline_model_parallel_size=args.pipeline_model_parallel_size, - virtual_pipeline_model_parallel_size=args.virtual_pipeline_model_parallel_size, - pipeline_model_parallel_layout=args.pipeline_model_parallel_layout, - is_moe=is_moe, - ) - new_model = new_model if isinstance(new_model, list) else [new_model] - - load_checkpoint(new_model, optimizer, opt_param_scheduler, strict=False) - set_ckpt_path(ckpt_dir_B) - save_checkpoint( - iteration, - new_model, - optimizer, - opt_param_scheduler, - num_floating_point_operations_so_far, - ) - - set_tp_pp_vpp(1, 1) - set_ckpt_path(ckpt_dir_A) - model_A = initialize_gpt_model( - 123, - layer_spec_fn=layer_spec_fn, - num_layers=args.num_layers, - hidden_size=args.hidden_size, - num_attention_heads=args.num_attention_heads, - tensor_model_parallel_size=args.tensor_model_parallel_size, - pipeline_model_parallel_size=args.pipeline_model_parallel_size, - virtual_pipeline_model_parallel_size=args.virtual_pipeline_model_parallel_size, - pipeline_model_parallel_layout=args.pipeline_model_parallel_layout, - is_moe=is_moe, - ) - load_checkpoint([model_A], optimizer, opt_param_scheduler, strict=False) - - set_ckpt_path(ckpt_dir_B) - model_B = initialize_gpt_model( - 321, - layer_spec_fn=layer_spec_fn, - num_layers=args.num_layers, - hidden_size=args.hidden_size, - num_attention_heads=args.num_attention_heads, - tensor_model_parallel_size=args.tensor_model_parallel_size, - pipeline_model_parallel_size=args.pipeline_model_parallel_size, - virtual_pipeline_model_parallel_size=args.virtual_pipeline_model_parallel_size, - pipeline_model_parallel_layout=args.pipeline_model_parallel_layout, - is_moe=is_moe, - ) - load_checkpoint([model_B], optimizer, opt_param_scheduler, strict=False) - - for k in model_A.state_dict(): - if "_extra_state" in k: # Ignore extra states - continue - tensor_a = model_A.state_dict()[k] - tensor_b = model_B.state_dict()[k] - assert tensor_a is not None, k - assert tensor_b is not None, k - assert torch.equal(tensor_a, tensor_b), k - - Utils.destroy_model_parallel() - unset_num_microbatches_calculator() diff --git a/tests/unit_tests/dist_checkpointing/test_replication.py b/tests/unit_tests/dist_checkpointing/test_replication.py deleted file mode 100644 index d4e792bf71..0000000000 --- a/tests/unit_tests/dist_checkpointing/test_replication.py +++ /dev/null @@ -1,164 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -import os -from contextlib import contextmanager -from dataclasses import dataclass -from pathlib import Path -from shutil import rmtree -from typing import Any, Dict, List, Optional -from unittest import mock - -import pytest -import torch -import torch.distributed as dist - -from megatron.training.arguments import parse_args - -nvidia_resiliency_ext = pytest.importorskip( - "nvidia_resiliency_ext", - reason="nvidia_resiliency_ext is required for local checkpointing tests", -) - -from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import ( - LocalCheckpointManager, -) -from nvidia_resiliency_ext.checkpointing.local.replication.group_utils import GroupWrapper -from nvidia_resiliency_ext.checkpointing.local.replication.strategies import ( - CliqueReplicationStrategy, -) - -from megatron.training.async_utils import maybe_finalize_async_save -from megatron.training.checkpointing import load_checkpoint, save_checkpoint -from tests.unit_tests.dist_checkpointing import ( - TempNamedDir, - init_basic_mock_args, - init_checkpointing_mock_args, - setup_model_and_optimizer, -) -from tests.unit_tests.test_utilities import Utils - - -def equal_(a, b): - def bool_generator(): - if isinstance(a, list): - yield isinstance(b, list) - yield len(a) == len(b) - yield all(equal_(aa, bb) for aa, bb in zip(a, b)) - elif isinstance(a, torch.Tensor): - yield isinstance(b, torch.Tensor) - yield torch.equal(a, b) - else: - yield a == b - - return all(bool_generator()) - - -@pytest.mark.parametrize(('tp,pp'), [(2, 4), (1, 1)]) -def test_all_gather_batch(tp, pp): - Utils.initialize_model_parallel(tp, pp) - torch.cuda.set_device(dist.get_rank()) - t0 = torch.arange(4, device="cuda").reshape((2, 2)) - t1 = torch.arange(6, device="cuda").reshape((3, 1, 2)) - t2 = torch.arange(12, device="cuda").reshape((2, 3, 2)) - test_ranks = [0, 3, 7] - test_group = GroupWrapper(dist.new_group(test_ranks)) - rank = dist.get_rank() - if rank not in test_ranks: - dist.barrier() - return - batch = [[t1, t2], [t0], []] - pred_batch = test_group.all_gather_batch(batch[test_group.my_group_rank]) - assert equal_(batch, pred_batch) - dist.barrier() - - -# TODO: Use mock local checkpointing? -@pytest.mark.parametrize(('tp,pp'), [(2, 4), (1, 1)]) -@pytest.mark.parametrize(('async_save'), [True, False]) -@pytest.mark.parametrize(('algo'), ['atomic', 'fully_parallel']) -@pytest.mark.parametrize( - ("repl_groups"), [[[0, 1], [2, 3], [4, 5], [6, 7]], [[2, 6, 7], [3, 1], [5], [0, 4]]] -) -class TestLocalCheckpointingReplication: - # tp: int - # pp: int - # async_save: bool - # algo: str - # repl_groups: List[List[int]] - # # To be filled by post_init - # checkpointing_context: Optional[Dict[str, LocalCheckpointManager]] - # repl_groups: Optional[List[dist.ProcessGroup]] - # local_ckpt_dir: Optional[Path] - - @contextmanager - def post_init(self, root_tmp_dir, tp, pp, async_save, algo, repl_groups): - Utils.initialize_model_parallel(tp, pp) - - mock_args = parse_args(ignore_unknown_args=True) - with ( - mock.patch('megatron.training.checkpointing.get_args', new=lambda: mock_args), - mock.patch('megatron.training.async_utils.get_args', new=lambda: mock_args), - mock.patch("megatron.training.checkpointing.update_num_microbatches"), - ): - self.local_ckpt_dir = ( - root_tmp_dir / "subdir" - ) # Test handling of non-existent directories - init_basic_mock_args(mock_args, tp, pp) - init_checkpointing_mock_args(mock_args, None) - mock_args.non_persistent_ckpt_type = 'local' - mock_args.non_persistent_local_ckpt_algo = algo - mock_args.async_save = async_save - mock_args.ckpt_fully_parallel_save = True # ensure proper sharding_type is set - repl_groups_init = [dist.new_group(g) for g in repl_groups] - my_process_group = GroupWrapper.from_list_of_groups(repl_groups_init) - repl_strategy = CliqueReplicationStrategy(my_process_group, target_device="cpu") - self.checkpointing_context = { - 'local_checkpoint_manager': LocalCheckpointManager( - self.local_ckpt_dir, repl_strategy=repl_strategy - ) - } - self.local_ckpt_dir /= str(dist.get_rank()) - yield - Utils.destroy_model_parallel() - - def test_repl_save_and_load(self, tmp_dir_per_class, tp, pp, async_save, algo, repl_groups): - with self.post_init(tmp_dir_per_class, tp, pp, async_save, algo, repl_groups): - num_floating_point_operations_so_far = 0 - model, optimizer = setup_model_and_optimizer(1, tp, pp) - opt_param_scheduler = None - - save_checkpoint( - 1, - model, - optimizer, - opt_param_scheduler, - num_floating_point_operations_so_far, - checkpointing_context=self.checkpointing_context, - non_persistent_ckpt=True, - ) - if async_save: - maybe_finalize_async_save(True) - - my_group = [group for group in repl_groups if dist.get_rank() in group][0] - assert {f"iter_0000001_{rank}_local.pt" for rank in my_group} == { - f.name for f in self.local_ckpt_dir.rglob("*") - } - with self.post_init(tmp_dir_per_class, tp, pp, async_save, algo, repl_groups): - - ranks_to_break = [6, 3, 4] - if dist.get_rank() in ranks_to_break: - rmtree(self.local_ckpt_dir) - os.makedirs(self.local_ckpt_dir) - - model, optimizer = setup_model_and_optimizer(2, tp, pp) - opt_param_scheduler = None - - iteration, _ = load_checkpoint( - model, - optimizer, - opt_param_scheduler, - checkpointing_context=self.checkpointing_context, - ) - assert iteration == 1 - # Perform cleanup to ensure no side effects on subsequent tests - torch.distributed.barrier() - rmtree(self.local_ckpt_dir) diff --git a/tests/unit_tests/dist_checkpointing/test_safe_globals.py b/tests/unit_tests/dist_checkpointing/test_safe_globals.py deleted file mode 100755 index dc09b2a292..0000000000 --- a/tests/unit_tests/dist_checkpointing/test_safe_globals.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -import os -from argparse import Namespace -from collections import OrderedDict -from pickle import UnpicklingError - -import pytest -import torch - -from megatron.core.utils import is_torch_min_version - - -class UnsafeClass: - def __init__(self, value): - self.value = value - - def __repr__(self): - return f"UnsafeClass(value={self.value})" - - -class TestSafeGlobals: - def test_safe_globals(self, tmp_path_dist_ckpt): - # create dummy checkpoint - ckpt_path = tmp_path_dist_ckpt / "test_safe_globals.pt" - dummy_obj = Namespace(dummy_value=0) - if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: - torch.save(dummy_obj, ckpt_path) - if torch.distributed.is_initialized(): - torch.distributed.barrier() - - torch.load(ckpt_path) - - @pytest.mark.skipif(not is_torch_min_version("2.6a0"), reason="PyTorch 2.6 is required") - def test_unsafe_globals(self, tmp_path_dist_ckpt): - # create dummy checkpoint - ckpt_path = tmp_path_dist_ckpt / "test_safe_globals.pt" - dummy_obj = UnsafeClass(123) - if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: - torch.save(dummy_obj, ckpt_path) - if torch.distributed.is_initialized(): - torch.distributed.barrier() - - # expected error - with pytest.raises(UnpicklingError): - torch.load(ckpt_path) - - # add class to safe globals - torch.serialization.add_safe_globals([UnsafeClass]) - torch.load(ckpt_path) diff --git a/tests/unit_tests/dist_checkpointing/test_serialization.py b/tests/unit_tests/dist_checkpointing/test_serialization.py deleted file mode 100644 index 1a50be17d2..0000000000 --- a/tests/unit_tests/dist_checkpointing/test_serialization.py +++ /dev/null @@ -1,1066 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import io -import logging -import os - -import numpy as np -import pytest -import torch -from torch.distributed.checkpoint import CheckpointException as PyTCheckpointingException -from torch.distributed.checkpoint import FileSystemReader - -try: - from torch.distributed import DeviceMesh - from torch.distributed._tensor import DTensor - - HAVE_DTENSOR = True -except ImportError: - HAVE_DTENSOR = False - -from megatron.core import parallel_state -from megatron.core.dist_checkpointing import ( - ShardedTensor, - load, - load_content_metadata, - remove_sharded_tensors, - save, -) -from megatron.core.dist_checkpointing.core import CheckpointingException, maybe_load_config -from megatron.core.dist_checkpointing.dict_utils import diff -from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedTensorFactory -from megatron.core.dist_checkpointing.serialization import ( - load_sharded_metadata, - load_tensors_metadata, -) -from megatron.core.dist_checkpointing.strategies.base import StrategyAction, get_default_strategy -from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy -from megatron.core.dist_checkpointing.validation import StrictHandling -from megatron.core.utils import is_torch_min_version -from tests.unit_tests.dist_checkpointing import TempNamedDir -from tests.unit_tests.test_utilities import Utils - - -class TestSerialization: - def setup_method(self, method): - pass - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_single_process_save_load(self, tmp_path_dist_ckpt): - Utils.initialize_model_parallel(1, 1) - - sharded_state_dict = { - 'sd_keyA': ShardedTensor.from_rank_offsets( - 'keyA', torch.ones(2, 4), replica_id=Utils.rank - ), - 'sd_keyB': ShardedTensor.from_rank_offsets( - 'keyB', torch.ones(3, 5, 7), replica_id=Utils.rank - ), - } - - if HAVE_DTENSOR: - mesh = DeviceMesh.from_group( - parallel_state.get_data_parallel_group(with_context_parallel=True), "cuda" - ) - sharded_state_dict['sd_keyD'] = ShardedTensor.from_rank_offsets( - 'keyD', - DTensor.from_local(torch.ones(3, 5, 7), mesh)._local_tensor, - replica_id=Utils.rank, - ) - - # sync=True to make sure other ranks wait for rank 0 to finish creating directory. - with TempNamedDir( - tmp_path_dist_ckpt / 'test_single_process_save_load', sync=True - ) as ckpt_dir: - save(sharded_state_dict, ckpt_dir) - torch.distributed.barrier() - - saved_config = maybe_load_config(ckpt_dir) - if saved_config.sharded_backend == 'zarr': - assert (ckpt_dir / 'keyA').is_dir() - assert (ckpt_dir / 'keyB').is_dir() - assert not (ckpt_dir / 'keyC').exists() - assert not (ckpt_dir / 'sd_keyA').is_dir() - - if HAVE_DTENSOR: - assert (ckpt_dir / 'keyD').is_dir() - - load_ssd = { - 'load_sd_keyA': ShardedTensor.from_rank_offsets( - 'keyA', torch.ones(2, 4), replica_id=Utils.rank - ) - } - loaded_state_dict = load(load_ssd, ckpt_dir) - - assert set(loaded_state_dict.keys()) == {'load_sd_keyA'} - assert isinstance(loaded_state_dict['load_sd_keyA'], torch.Tensor) - assert loaded_state_dict['load_sd_keyA'].shape == (2, 4) - - Utils.destroy_model_parallel() - - def test_multi_process_save(self, tmp_path_dist_ckpt): - Utils.initialize_model_parallel(2, 4) - - state_dict = { - 'sd_keyA': ShardedTensor.from_rank_offsets( - 'keyA', torch.ones(2, 4), (0, Utils.rank, Utils.world_size) - ), - 'sd_keyB': ShardedTensor.from_rank_offsets( - 'keyB', torch.ones(3, 5, 7), (2, Utils.rank, Utils.world_size) - ), - 'lr': 0.01, - 'rank': torch.distributed.get_rank(), - } - - def preprocess_fn(x): - del x['rank'] - return x - - # sync=True to make sure other ranks wait for rank 0 to finish creating directory. - with TempNamedDir(tmp_path_dist_ckpt / 'test_multi_process_save', sync=True) as ckpt_dir: - save( - state_dict, - ckpt_dir, - validate_access_integrity=True, - preprocess_common_before_consistancy_check=preprocess_fn, - ) - - saved_config = maybe_load_config(ckpt_dir) - if saved_config.sharded_backend == 'zarr': - assert (ckpt_dir / 'keyA').is_dir() - assert (ckpt_dir / 'keyB').is_dir() - assert not (ckpt_dir / 'keyC').exists() - assert not (ckpt_dir / 'sd_keyA').is_dir() - - Utils.destroy_model_parallel() - - def test_multi_process_save_log_difference(self, tmp_path_dist_ckpt, caplog): - Utils.initialize_model_parallel(2, 4) - rank = Utils.rank - world_size = Utils.world_size - - state_dict = { - 'sd_keyA': ShardedTensor.from_rank_offsets( - 'keyA', torch.ones(2, 4), (0, rank, world_size) - ), - 'sd_keyB': ShardedTensor.from_rank_offsets( - 'keyB', torch.ones(3, 5, 7), (2, rank, world_size) - ), - 'rank': rank, - } - - def preprocess_fn(x): - return x - - # sync=True to make sure other ranks wait for rank 0 to finish creating directory. - with TempNamedDir( - tmp_path_dist_ckpt / 'test_multi_process_save_log_difference', sync=True - ) as ckpt_dir: - with caplog.at_level(logging.WARNING): - save( - state_dict, - ckpt_dir, - validate_access_integrity=True, - preprocess_common_before_consistancy_check=preprocess_fn, - ) - - if rank == 0: - # Rank 0 should not log the warning related to common state dict difference - assert not any( - f"Rank {rank} common state dict differs from rank 0 common state dict." - in record.message - for record in caplog.records - ) - else: - found_detailed_match = False - # Construct the expected full message string based on user request - expected_full_message = ( - f"Rank {rank} common state dict differs from rank 0 common state dict. " - f"Keys only on rank 0: [], " - f"Keys only on {rank}: [], " - f"Mismatched keys: [(('rank',), , )]" - ) - - for record in caplog.records: - if record.message == expected_full_message: - found_detailed_match = True - break - - assert ( - found_detailed_match - ), f"Did not find expected log message format for mismatch on rank {rank}. Expected: {expected_full_message}" - - Utils.destroy_model_parallel() - - def test_partition_change_save_load(self, tmp_path_dist_ckpt, strategy=None): - Utils.initialize_model_parallel(2, 4) - - # ten_a: global shape (2, 4): - ten_a_global = torch.tensor([[0, 1, 2, 3], [10, 11, 12, 13]]) - ten_a = ( - torch.zeros(1, 1) - + 10 * parallel_state.get_tensor_model_parallel_rank() - + parallel_state.get_pipeline_model_parallel_rank() - ) - assert ten_a.shape == (1, 1) - - # ten_b: global shape (4, 5, 80), where (x, y, z) is (100x + z) - ten_b = torch.zeros(4, 5, 10) + (torch.arange(10) + 10 * Utils.rank) - ten_b += torch.arange(4).unsqueeze(-1).unsqueeze(-1) * 100 - assert ten_b.shape == (4, 5, 10) - - state_dict = { - 'sd_keyA': ShardedTensor.from_rank_offsets( - 'keyA', - ten_a, - ( - 0, - parallel_state.get_tensor_model_parallel_rank(), - parallel_state.get_tensor_model_parallel_world_size(), - ), - ( - 1, - parallel_state.get_pipeline_model_parallel_rank(), - parallel_state.get_pipeline_model_parallel_world_size(), - ), - replica_id=0, - ), - 'sd_keyB': ShardedTensor.from_rank_offsets( - 'keyB', ten_b, (2, Utils.rank, Utils.world_size) - ), - } - - ten_a_global_shape = ten_a_global.shape - ten_b_global_shape = (4, 5, 10 * 8) - - assert state_dict['sd_keyA'].local_shape == (1, 1) - assert state_dict['sd_keyA'].global_shape == ten_a_global_shape - assert state_dict['sd_keyB'].global_shape == ten_b_global_shape - - # sync=True to make sure other ranks wait for rank 0 to finish creating directory. - with TempNamedDir( - tmp_path_dist_ckpt / 'test_partition_change_save_load', sync=True - ) as ckpt_dir: - save(state_dict, ckpt_dir, strategy) - - del ten_a, ten_b - - # without changing TPxPP, load tensors without any sharding - load_sd = { - 'sd_keyA': ShardedTensor.from_rank_offsets( - 'keyA', torch.empty(ten_a_global_shape), replica_id=Utils.rank - ), - 'sd_keyB': ShardedTensor.from_rank_offsets( - 'keyB', torch.empty(ten_b_global_shape), replica_id=Utils.rank - ), - } - loaded_state_dict = load(load_sd, ckpt_dir) - - ten_a = loaded_state_dict['sd_keyA'] - ten_b = loaded_state_dict['sd_keyB'] - assert isinstance(ten_a, torch.Tensor) - assert ten_a.shape == ten_a_global_shape - assert torch.all(ten_a == ten_a_global) - - assert isinstance(ten_b, torch.Tensor) - assert ten_b.shape == ten_b_global_shape - assert np.all( - [ - val == 100 * x + z - for x, x_row in enumerate(ten_b) - for y, y_row in enumerate(x_row) - for z, val in enumerate(y_row) - ] - ) - - del ten_a, ten_b - - # change TPxPP - Utils.destroy_model_parallel() - Utils.initialize_model_parallel(1, 2) - - load_sd = { - 'sd_keyA': ShardedTensor.from_rank_offsets( - 'keyA', - torch.empty(2, 1), - ( - 1, - parallel_state.get_data_parallel_rank(), - parallel_state.get_data_parallel_world_size(), - ), - replica_id=parallel_state.get_pipeline_model_parallel_rank(), - ), - 'sd_keyB': ShardedTensor.from_rank_offsets( - 'keyB', - torch.empty(5, 80), - (0, Utils.rank // 2, 4), - prepend_axis_num=1, - replica_id=Utils.rank % 2, - ), - } - - loaded_state_dict = load(load_sd, ckpt_dir) - ten_a = loaded_state_dict['sd_keyA'] - ten_b = loaded_state_dict['sd_keyB'] - - assert isinstance(ten_a, torch.Tensor) - assert ten_a.shape == (2, 1) - assert torch.all( - ten_a[:, 0] == ten_a_global[:, parallel_state.get_data_parallel_rank()] - ) - - assert isinstance(ten_b, torch.Tensor) - assert ten_b.shape == (5, 10 * 8) - assert torch.all( - ten_b == torch.arange(80).unsqueeze(0).expand(5, 80) + Utils.rank // 2 * 100 - ) - - def test_load_tensors_metadata(self, tmp_path_dist_ckpt): - Utils.initialize_model_parallel(2, 4) - - state_dict = { - 'sd_keyA': ShardedTensor.from_rank_offsets( - 'keyA', torch.arange(10) + Utils.rank * 10, (0, Utils.rank, Utils.world_size) - ), - 'sd_keyB': ShardedTensor.from_rank_offsets( - 'keyB', torch.ones(3, 5, 7), (2, Utils.rank, Utils.world_size) - ), - } - - # sync=True to make sure other ranks wait for rank 0 to finish creating directory. - with TempNamedDir(tmp_path_dist_ckpt / 'test_load_tensors_metadata', sync=True) as ckpt_dir: - save(state_dict, ckpt_dir) - - del state_dict - sharded_state_dict = load_tensors_metadata(ckpt_dir) - # loaded dict keys are ShardedTensor keys! - assert 'keyA' in sharded_state_dict - assert 'sd_keyA' not in sharded_state_dict - - # Check metadata - assert sharded_state_dict['keyA'].global_shape == (10 * Utils.world_size,) - assert sharded_state_dict['keyB'].global_shape == (3, 5, 7 * Utils.world_size) - assert sharded_state_dict['keyA'].local_shape == sharded_state_dict['keyA'].global_shape - assert sharded_state_dict['keyB'].local_shape == sharded_state_dict['keyB'].global_shape - assert sharded_state_dict['keyA'].global_offset == (0,) - assert sharded_state_dict['keyB'].global_offset == (0, 0, 0) - assert sharded_state_dict['keyA'].axis_fragmentations == (1,) - assert sharded_state_dict['keyB'].axis_fragmentations == (1, 1, 1) - assert sharded_state_dict['keyA'].replica_id == 0 - assert sharded_state_dict['keyB'].replica_id == 0 - - # metadata dict can be loaded. We don't validate access because there are multiple replica_id=0 - state_dict = load(sharded_state_dict, ckpt_dir, validate_access_integrity=False) - assert torch.all(state_dict['keyA'] == torch.arange(10 * Utils.world_size)) - - Utils.destroy_model_parallel() - - def test_can_mix_sharded_tensors_and_factories(self, tmp_path_dist_ckpt): - Utils.initialize_model_parallel(1, 1) - - def _build_fn(key, tensor, replica_id, flattened_range): - assert flattened_range is None - return [ - ShardedTensor.from_rank_offsets(key + 'part1', tensor, replica_id=replica_id), - ShardedTensor.from_rank_offsets(key + 'part2', tensor, replica_id=replica_id), - ShardedTensor.from_rank_offsets(key + 'part3', tensor, replica_id=replica_id), - ] - - # state dict can be modified by dist_checkpointing.save, so two copies - def get_sharded_state_dict(base=0): - return { - 'all': [ - ShardedTensor.from_rank_offsets( - 'A', torch.arange(2) + base, replica_id=Utils.rank - ), - ShardedTensor.from_rank_offsets( - 'B', torch.arange(3) + base, replica_id=Utils.rank - ), - ShardedTensor.from_rank_offsets( - 'C', torch.arange(4) + base, replica_id=Utils.rank - ), - ShardedTensorFactory( - 'D', torch.arange(5) + base, _build_fn, sum, replica_id=Utils.rank - ), - ] - } - - # sync=True to make sure other ranks wait for rank 0 to finish creating directory. - with TempNamedDir( - tmp_path_dist_ckpt / 'test_can_mix_sharded_tensors_and_factories', sync=True - ) as ckpt_dir: - save(get_sharded_state_dict(0), ckpt_dir) - loaded_state_dict = load(get_sharded_state_dict(10), ckpt_dir) - - expected_sd = { - 'all': [ - torch.arange(2), - torch.arange(3), - torch.arange(4), - torch.arange(5) * 3, # sum of three parts, as specified in merge_fn - ] - } - diffs = diff(loaded_state_dict, expected_sd) - assert not any(map(bool, diffs)), diffs - - Utils.destroy_model_parallel() - - def test_load_error_msg(self, tmp_path_dist_ckpt): - ckpt_dir_name = 'test_load_error_msg' - Utils.initialize_model_parallel(1, 1) - sh_ten = ShardedTensor.from_rank_offsets('keyA', torch.rand(10), replica_id=Utils.rank) - state_dict = {'some_key': sh_ten} - - # Non-existent directory - non_ex_path = f'/tmp/non-existent-path/{ckpt_dir_name}' - with pytest.raises(CheckpointingException) as exc_info: - load(state_dict, non_ex_path) - assert f'directory {non_ex_path} does not exist' in str(exc_info.value) - - # sync=True to make sure other ranks wait for rank 0 to finish creating directory. - with TempNamedDir(tmp_path_dist_ckpt / ckpt_dir_name, sync=True) as ckpt_dir: - # Empty directory - not a distributed checkpoint - with pytest.raises(CheckpointingException) as exc_info: - load(state_dict, ckpt_dir) - assert f'is not a distributed checkpoint' in str(exc_info.value) - - # Missing Zarr arrays - torch.distributed.barrier() - save(state_dict, ckpt_dir) - sh_ten.key = 'different_key' - with pytest.raises((CheckpointingException, PyTCheckpointingException)) as exc_info: - load(state_dict, ckpt_dir) - assert "different_key" in str(exc_info.value) - - def test_sharded_object_serialization(self, tmp_path_dist_ckpt): - Utils.initialize_model_parallel(1, 1) - # sync=True to make sure other ranks wait for rank 0 to finish creating directory. - with TempNamedDir(tmp_path_dist_ckpt / 'test_sh_obj', sync=True) as ckpt_dir: - state = {'some': 'dict'} - state_serialized = io.BytesIO() - torch.save(state, state_serialized) - state_dict = { - 'some_key': ShardedObject( - 'sh_obj_A', state_serialized, (1,), (0,), replica_id=Utils.rank - ) - } - - save(state_dict, ckpt_dir) - del state, state_serialized, state_dict - other_state = {'other': 'dictionary'} - other_serialized = io.BytesIO() - torch.save(other_state, other_serialized) - state_dict = { - 'other_key': ShardedObject( - 'sh_obj_A', other_serialized, (1,), (0,), replica_id=Utils.rank - ) - } - load_state_dict = load(state_dict, ckpt_dir) - assert 'other_key' in load_state_dict - load_state_dict['other_key'].seek(0) - loaded_state = torch.load(load_state_dict['other_key']) - - assert loaded_state == {'some': 'dict'} - - Utils.destroy_model_parallel() - - def test_tensor_shape_mismatch(self, tmp_path_dist_ckpt): - Utils.initialize_model_parallel(2, 4) - - # Global tensor is just a range(32) repeated twice over the first dimension - local_tensor = torch.arange(4).unsqueeze(0).expand(2, 4) + Utils.rank * 4 - - state_dict = { - 'rigid': ShardedTensor.from_rank_offsets( - 'keyA', local_tensor, (1, Utils.rank, Utils.world_size) - ), - 'flexible': ShardedTensor.from_rank_offsets( - 'keyB', local_tensor, (1, Utils.rank, Utils.world_size), allow_shape_mismatch=True - ), - } - assert state_dict['rigid'].global_shape == (2, 32) - assert state_dict['flexible'].global_shape == (2, 32) - - # sync=True to make sure other ranks wait for rank 0 to finish creating directory. - with TempNamedDir(tmp_path_dist_ckpt / 'test_tensor_shape_mismatch', sync=True) as ckpt_dir: - save(state_dict, ckpt_dir) - - pp_size = parallel_state.get_pipeline_model_parallel_world_size() - pp_rank = parallel_state.get_pipeline_model_parallel_rank() - tp_rank = parallel_state.get_tensor_model_parallel_rank() - - # Smaller coverage than expected (28 < 32) - state_dict = { - 'rigid': ShardedTensor.from_rank_offsets( - 'keyA', torch.ones(2, 7), (1, pp_rank, pp_size), replica_id=tp_rank - ) - } - with pytest.raises((CheckpointingException, PyTCheckpointingException)): - load(state_dict, ckpt_dir) - - state_dict = { - 'flexible': ShardedTensor.from_rank_offsets( - 'keyB', - torch.ones(2, 7), - (1, pp_rank, pp_size), - replica_id=tp_rank, - allow_shape_mismatch=True, - ) - } - loaded_state_dict = load(state_dict, ckpt_dir) - assert torch.all( - loaded_state_dict['flexible'] - == torch.arange(7).unsqueeze(0).expand(2, 7) + pp_rank * 7 - ) - - # Larger coverage than expected (36 > 32) - state_dict = { - 'rigid': ShardedTensor.from_rank_offsets( - 'keyA', torch.ones(2, 9), (1, pp_rank, pp_size), replica_id=tp_rank - ) - } - with pytest.raises((CheckpointingException, PyTCheckpointingException)): - load(state_dict, ckpt_dir) - - state_dict = { - 'flexible': ShardedTensor.from_rank_offsets( - 'keyB', - torch.ones(2, 9), - (1, pp_rank, pp_size), - replica_id=tp_rank, - allow_shape_mismatch=True, - ) - } - loaded_state_dict = load(state_dict, ckpt_dir) - expected_tensor = torch.arange(9).unsqueeze(0).expand(2, 9) + pp_rank * 9 - - if pp_rank >= (32 // 9): - assert pp_rank == 3, pp_rank - expected_tensor[:, 5:] = 0 # padding with 0s - assert torch.all(loaded_state_dict['flexible'] == expected_tensor) - - Utils.destroy_model_parallel() - - @pytest.mark.skipif( - not is_torch_min_version("2.3.0"), - reason="remove_sharded_tensors relies on Torch APIs introduced in v2.3.0", - ) - @pytest.mark.flaky_in_dev - def test_remove_sharded_tensors(self, tmp_path_dist_ckpt): - Utils.initialize_model_parallel(2, 4) - - # Global tensor is just a range(32) repeated twice over the first dimension - global_tensor = torch.arange(4).unsqueeze(0).expand(2, 4) - state_dict = { - 'sd_keyA': ShardedTensor.from_rank_offsets( - 'keyA', torch.ones(2, 4), (0, Utils.rank, Utils.world_size) - ), - 'sd_prefix_key_to_remove': ShardedTensor.from_rank_offsets( - 'prefix_key_to_remove', torch.ones(3, 5, 7), (2, Utils.rank, Utils.world_size) - ), - } - - prefix_name = "prefix" ## we will drop all tensors whose keys begin with "prefix" - - # sync=True to make sure other ranks wait for rank 0 to finish creating directory. - with TempNamedDir( - tmp_path_dist_ckpt / 'test_remove_sharded_tensor_prefix', sync=True - ) as ckpt_dir: - save_strategy = TorchDistSaveShardedStrategy( - "torch_dist", 1, separation_hint=prefix_name - ) - save(state_dict, ckpt_dir, save_strategy) - - files = os.listdir(ckpt_dir) - prefix_files = [f for f in files if f.startswith(prefix_name)] - assert len(prefix_files) == torch.distributed.get_world_size() - - fs_reader = FileSystemReader(ckpt_dir) - original_metadata = fs_reader.read_metadata() - assert set(original_metadata.state_dict_metadata.keys()) == { - 'keyA', - 'prefix_key_to_remove', - } - - if torch.distributed.get_rank() == 0: - remove_sharded_tensors(ckpt_dir, key_prefix=prefix_name) - torch.distributed.barrier() - - files = os.listdir(ckpt_dir) - prefix_files = [f for f in files if f.startswith(prefix_name)] - assert len(prefix_files) == 0 - - new_metadata = fs_reader.read_metadata() - assert set(new_metadata.state_dict_metadata.keys()) == {'keyA'} - - Utils.destroy_model_parallel() - - def test_empty_load(self, tmp_path_dist_ckpt): - Utils.initialize_model_parallel(2, 4) - - if Utils.rank == 0: - state_dict = {'common': 'common-value'} - elif Utils.rank == 1: - state_dict = {'a': 3} # this is not saved at all (common saved by rank 0 only) - elif Utils.rank == 2: - state_dict = {'b': 3} # this is not saved at all (common saved by rank 0 only) - else: - state_dict = { - 'a': ShardedTensor.from_rank_offsets( - 'x', torch.ones((2,)) * Utils.rank, replica_id=Utils.rank - 3 - ) - } - - with TempNamedDir(tmp_path_dist_ckpt / 'test_empty_load', sync=True) as ckpt_dir: - save(state_dict, ckpt_dir) - torch.distributed.barrier() - loaded_state_dict = load(state_dict, ckpt_dir) - assert loaded_state_dict['common'] == 'common-value' - - if Utils.rank <= 2: - assert loaded_state_dict.keys() == {'common'} - else: - assert loaded_state_dict.keys() == {'common', 'a'} - loaded_state_dict['a'].cpu().numpy().tolist() == [ - 3, - 3, - ] # rank 3 held the main replica so did the saving - - Utils.destroy_model_parallel() - - @pytest.mark.parametrize( - 'content_metadata', [{'a': 3}, {'nested': {'a': 3}, 'flat': (5, {6: None})}, {}] - ) - def test_content_metadata_load_from_checkpoint(self, tmp_path_dist_ckpt, content_metadata): - Utils.initialize_model_parallel(1, 1) - state_dict = {'common': (3, 5, 7)} - - with TempNamedDir( - tmp_path_dist_ckpt / 'test_content_metadata_load_from_checkpoint', sync=True - ) as ckpt_dir: - save(state_dict, ckpt_dir, content_metadata=content_metadata) - torch.distributed.barrier() - loaded_metadata = load_content_metadata(ckpt_dir) - - assert loaded_metadata == content_metadata - - @pytest.mark.parametrize( - 'content_metadata', [{'a': 3}, {'nested': {'a': 3}, 'flat': (5, {6: None})}, {}] - ) - def test_content_metadata_load_from_state_dict(self, tmp_path_dist_ckpt, content_metadata): - Utils.initialize_model_parallel(1, 1) - state_dict = {'common': (3, 5, 7)} - - with TempNamedDir( - tmp_path_dist_ckpt / 'test_content_metadata_load_from_state_dict', sync=True - ) as ckpt_dir: - save(state_dict, ckpt_dir, content_metadata=content_metadata) - torch.distributed.barrier() - loaded_state_dict = load(state_dict, ckpt_dir) - loaded_metadata = load_content_metadata(preloaded_state_dict=loaded_state_dict) - - assert loaded_metadata == content_metadata - - @pytest.mark.parametrize( - ('src_split', 'dest_split'), - [ - # Same src and dest - ([3] * 8, None), - (list(range(1, 9)), None), - ([1, 5, 7, 3, 6, 2, 5, 4], None), - ([2, 2, 2, 2, 2, 2, 2, 1], None), - ([2, 2, 2, 2, 2, 2, 2, 10], None), - # Different src and dest - ([3] * 8, [1] * 6 + [2, 16]), - ([1, 5, 7, 3, 6, 2, 5, 4], [14, 3, 6, 3, 1, 1, 2, 3]), - # Empty shards - ([5] * 6 + [0, 0], [5, 0, 5, 0, 5, 5, 3, 7]), - ([15] + [0] * 7, [0, 0, 0] + [3] * 5), - ], - ) - @pytest.mark.skipif( - not is_torch_min_version("2.6a0"), - reason="CheckpointableShardedTensor requires PyTorch 2.6 or later", - ) - def test_uneven_1d_sharding(self, tmp_path_dist_ckpt, src_split, dest_split): - Utils.initialize_model_parallel(2, 4) - - if dest_split is None: - dest_split = src_split - - assert len(src_split) == Utils.world_size - assert len(dest_split) == len(src_split) - assert sum(src_split) == sum(dest_split) - - def _create_1d_sharded_tensor_based_on_split(split, content_split=None, key='a'): - # Split [a, b, c] means a global tensor of shape (a + b + c,), divided - # into 3 rank, with a, b, c, elements on each rank - global_shape = (sum(split),) # Sum of all splits - local_shape = (split[Utils.rank],) # Split size of this rank - global_offset = (sum(split[: Utils.rank]),) # Sum of all sizes before this rank - - if content_split is None: - data = torch.zeros(local_shape) - else: - data = torch.zeros(global_shape) - assert len(content_split) == len(split) - # Content split determines the data stored in the global tensor. - # Content split [a, b, c] means `a` zeros, `b` ones and `c` twos. - content_split = torch.cumsum(torch.tensor(content_split), 0) - for ( - idx - ) in content_split: # this handles `data[content_split] += 1` with repeating values - if idx < len(data): - data[idx] += 1 - else: - assert idx == len(data) - data = data.cumsum(0) - data = data[global_offset[0] : global_offset[0] + local_shape[0]] - assert data.shape == local_shape - return ShardedTensor( - key, data, data.dtype, data.shape, global_shape, global_offset, None - ) - - state_dict = {'a': _create_1d_sharded_tensor_based_on_split(src_split, dest_split)} - - with TempNamedDir(tmp_path_dist_ckpt / 'test_uneven_sharding', sync=True) as ckpt_dir: - save(state_dict, ckpt_dir) - torch.distributed.barrier() - - state_dict = {'a': _create_1d_sharded_tensor_based_on_split(dest_split)} - loaded_state_dict = load(state_dict, ckpt_dir) - assert torch.all(loaded_state_dict['a'] == Utils.rank) - - @pytest.mark.parametrize( - ('src_split', 'dest_split'), - [ - # Same src and dest - ([[3]] * 8, None), - ([[]] * 7 + [[3, 3]], None), - ([[4], [7, 8], [1], [1], [1], [1], [1], [3, 3]], None), - ([[2]] * 5 + [[10]] * 3, [[10]] * 3 + [[2]] * 5), - ( - [[4], [7, 8], [1], [1], [1], [1], [1], [3, 3]], - [[2, 4], [], [5], [], [5, 9], [], [1, 1, 1, 1, 1], []], - ), - ([[3]] * 8, [[2, 4]] * 4 + [[]] * 4), - ], - ) - @pytest.mark.skipif( - not is_torch_min_version("2.6a0"), - reason="CheckpointableShardedTensor requires PyTorch 2.6 or later", - ) - def test_uneven_1d_sharding_multiple_shards(self, tmp_path_dist_ckpt, src_split, dest_split): - """The same as test_uneven_1d_sharding but with multiple shards per rank. - - src_split and dest_split have now 2 levels. - """ - Utils.initialize_model_parallel(2, 4) - - if dest_split is None: - dest_split = src_split - - def nested_sum(x): - return sum(map(sum, x)) - - assert len(src_split) == Utils.world_size - assert len(dest_split) == len(src_split) - assert nested_sum(src_split) == nested_sum(dest_split) - - def _create_1d_sharded_tensors_based_on_split(split, content_split=None, key='a'): - # Split [a, b, c] means a global tensor of shape (a + b + c,), divided - # into 3 rank, with a, b, c, elements on each rank - global_shape = (nested_sum(split),) # Sum of all splits - global_offset_base = nested_sum( - split[: Utils.rank] - ) # Sum of all sizes before this rank - - local_shards = [] - for local_split in split[Utils.rank]: - local_shape = (local_split,) - global_offset = (global_offset_base,) - global_offset_base += local_split - - if content_split is None: - data = torch.zeros(local_shape) - else: - data = torch.zeros(global_shape) - assert len(content_split) == len(split) - # Content split determines the data stored in the global tensor. - # Content split [a, b, c] means `a` zeros, `b` ones and `c` twos. - cumsum_content_split = torch.cumsum( - torch.tensor(list(map(sum, content_split))), 0 - ) - for ( - idx - ) in ( - cumsum_content_split - ): # this handles `data[cumsum_content_split] += 1` with repeating values - if idx < len(data): - data[idx] += 1 - else: - assert idx == len(data) - data = data.cumsum(0) - data = data[global_offset[0] : global_offset[0] + local_shape[0]] - assert data.shape == local_shape - local_shards.append( - ShardedTensor( - key, data, data.dtype, data.shape, global_shape, global_offset, None - ) - ) - return local_shards - - state_dict = dict( - enumerate(_create_1d_sharded_tensors_based_on_split(src_split, dest_split)) - ) - - with TempNamedDir(tmp_path_dist_ckpt / 'test_uneven_sharding', sync=True) as ckpt_dir: - save(state_dict, ckpt_dir) - torch.distributed.barrier() - - state_dict = dict(enumerate(_create_1d_sharded_tensors_based_on_split(dest_split))) - loaded_state_dict = load(state_dict, ckpt_dir) - for local_shard in loaded_state_dict.values(): - assert torch.all(local_shard == Utils.rank) - - -class TestNonStrictLoad: - def setup_method(self, method): - Utils.initialize_model_parallel(2, 4) # doesn't matter for this test - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def _get_base_state_dict(self): - return { - 'TenA': ShardedTensor.from_rank_offsets('TenA', torch.arange(2), replica_id=Utils.rank), - 'TenB': ShardedTensor.from_rank_offsets( - 'TenB', torch.arange(3), (0, Utils.rank, Utils.world_size), replica_id=0 - ), - 'TenC': ShardedTensor.from_rank_offsets( - 'TenC', torch.arange(3), replica_id=Utils.world_size - Utils.rank - 1 - ), - 'ObjA': ShardedObject('ObjA', list(range(10)), (1,), (0,), replica_id=Utils.rank), - 'ObjB': ShardedObject( - 'ObjB', {Utils.rank + 7}, (1, Utils.world_size), (0, Utils.rank), replica_id=0 - ), - } - - @pytest.mark.parametrize('save_format', ['torch_dist']) - @pytest.mark.parametrize('validate_integrity', [True, False]) - def test_unexpected_keys_handling_during_validation( - self, caplog, tmp_path_dist_ckpt, validate_integrity, save_format - ): - sharded_state_dict = self._get_base_state_dict() - with TempNamedDir( - tmp_path_dist_ckpt / 'test_unexpected_keys_raises_error_during_validation' - ) as ckpt_dir: - save_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, save_format, 1) - save(sharded_state_dict, ckpt_dir, save_strategy) - - def load_with_flag(strict): - sharded_state_dict = self._get_base_state_dict() - sharded_state_dict['TenD'] = ShardedTensor.from_rank_offsets( - 'UnexpectedTenD', torch.arange(3), replica_id=Utils.rank - ) - sharded_state_dict['ObjD'] = ShardedObject( - 'UnexpectedObjD', None, (1,), (0,), replica_id=Utils.rank - ) - return load( - sharded_state_dict, - ckpt_dir, - validate_access_integrity=validate_integrity, - strict=strict, - ) - - def test_error(error_msg): - assert 'Unexpected keys' in error_msg - assert 'UnexpectedTenD' in error_msg - assert 'UnexpectedObjD' in error_msg - assert 'Missing keys' not in error_msg - - # ASSUME_OK_UNEXPECTED results in an exception raised by the underlying strategy - with pytest.raises( - PyTCheckpointingException if save_format == 'torch_dist' else CheckpointingException - ) as exc_info: - load_with_flag(StrictHandling.ASSUME_OK_UNEXPECTED) - # Informative exceptions with `RAISE_*` options: - with pytest.raises(CheckpointingException) as exc_info: - load_with_flag(StrictHandling.RAISE_UNEXPECTED) - test_error(str(exc_info.value)) - with pytest.raises(CheckpointingException) as exc_info: - load_with_flag(StrictHandling.RAISE_ALL) - test_error(str(exc_info.value)) - - # Logged mismatches: - with caplog.at_level(logging.WARNING): - loaded_state_dict = load_with_flag(StrictHandling.LOG_UNEXPECTED) - assert 'TenA' in loaded_state_dict - test_error(caplog.text) - with caplog.at_level(logging.WARNING): - loaded_state_dict = load_with_flag(StrictHandling.LOG_ALL) - assert 'TenA' in loaded_state_dict - test_error(caplog.text) - - # Returned mismatches - loaded_state_dict, missing_keys, unexpected_keys = load_with_flag( - StrictHandling.RETURN_UNEXPECTED - ) - assert 'TenA' in loaded_state_dict - assert unexpected_keys == {'UnexpectedTenD', 'UnexpectedObjD'} - assert missing_keys == set() - loaded_state_dict, missing_keys, unexpected_keys = load_with_flag( - StrictHandling.RETURN_ALL - ) - assert 'TenA' in loaded_state_dict - assert unexpected_keys == {'UnexpectedTenD', 'UnexpectedObjD'} - assert missing_keys == set() - - # Ignore mismatch - loaded_state_dict = load_with_flag(StrictHandling.IGNORE_ALL) - assert 'TenA' in loaded_state_dict - - @pytest.mark.parametrize('save_format', ['torch_dist']) - @pytest.mark.parametrize('validate_integrity', [True, False]) - def test_missing_keys_raises_error_during_validation( - self, caplog, tmp_path_dist_ckpt, validate_integrity, save_format - ): - sharded_state_dict = self._get_base_state_dict() - with TempNamedDir( - tmp_path_dist_ckpt / 'test_missing_keys_raises_error_during_validation' - ) as ckpt_dir: - save_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, save_format, 1) - save(sharded_state_dict, ckpt_dir, save_strategy) - - def load_with_flag(strict): - sharded_state_dict = self._get_base_state_dict() - del sharded_state_dict['TenA'] - del sharded_state_dict['ObjB'] - return load( - sharded_state_dict, - ckpt_dir, - validate_access_integrity=validate_integrity, - strict=strict, - ) - - def test_error(error_msg): - assert 'Unexpected keys' not in error_msg - assert 'TenA' in error_msg - assert 'ObjB' in error_msg - assert 'Missing keys' in error_msg - - # no mismatch for `*_UNEXPECTED` flag - loaded_state_dict = load_with_flag(StrictHandling.ASSUME_OK_UNEXPECTED) - assert 'TenB' in loaded_state_dict - - loaded_state_dict = load_with_flag(StrictHandling.RAISE_UNEXPECTED) - assert 'TenB' in loaded_state_dict - - with caplog.at_level(logging.WARNING): - loaded_state_dict = load_with_flag(StrictHandling.LOG_UNEXPECTED) - assert caplog.text == '' - assert 'TenB' in loaded_state_dict - - loaded_state_dict, missing_keys, unexpected_keys = load_with_flag( - StrictHandling.RETURN_UNEXPECTED - ) - assert 'TenB' in loaded_state_dict - assert missing_keys == set() - assert unexpected_keys == set() - - loaded_state_dict = load_with_flag(StrictHandling.IGNORE_ALL) - assert 'TenB' in loaded_state_dict - - # Informative exceptions with `RAISE_ALL` option: - with pytest.raises(CheckpointingException) as exc_info: - load_with_flag(StrictHandling.RAISE_ALL) - test_error(str(exc_info.value)) - - # Logged mismatches: - with caplog.at_level(logging.WARNING): - loaded_state_dict = load_with_flag(StrictHandling.LOG_ALL) - assert 'TenB' in loaded_state_dict - test_error(caplog.text) - - # Returned mismatches - loaded_state_dict, missing_keys, unexpected_keys = load_with_flag( - StrictHandling.RETURN_ALL - ) - assert 'TenB' in loaded_state_dict - assert unexpected_keys == set() - assert missing_keys == {'TenA', 'ObjB'} - - @pytest.mark.parametrize('save_format', ['torch_dist']) - @pytest.mark.parametrize('validate_integrity', [True, False]) - def test_exact_load_handling(self, caplog, tmp_path_dist_ckpt, validate_integrity, save_format): - sharded_state_dict = self._get_base_state_dict() - with TempNamedDir(tmp_path_dist_ckpt / 'test_exact_load_handling') as ckpt_dir: - save_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, save_format, 1) - save(sharded_state_dict, ckpt_dir, save_strategy) - - def load_with_flag(strict): - sharded_state_dict = self._get_base_state_dict() - return load( - sharded_state_dict, - ckpt_dir, - validate_access_integrity=validate_integrity, - strict=strict, - ) - - for strict in ( - StrictHandling.ASSUME_OK_UNEXPECTED, - StrictHandling.LOG_UNEXPECTED, - StrictHandling.LOG_ALL, - StrictHandling.RAISE_UNEXPECTED, - StrictHandling.RAISE_ALL, - StrictHandling.IGNORE_ALL, - ): - with caplog.at_level(logging.WARNING): - loaded_state_dict = load_with_flag(strict) - assert caplog.text == '' - assert 'TenB' in loaded_state_dict - assert 'ObjB' in loaded_state_dict - - for strict in (StrictHandling.RETURN_UNEXPECTED, StrictHandling.RETURN_ALL): - with caplog.at_level(logging.WARNING): - loaded_state_dict, missing_keys, unexpected_keys = load_with_flag(strict) - assert caplog.text == '' - assert 'TenB' in loaded_state_dict - assert 'ObjB' in loaded_state_dict - assert missing_keys == set() - assert unexpected_keys == set() - - @pytest.mark.parametrize('save_format', ['torch_dist']) - def test_sharded_metadata(self, tmp_path_dist_ckpt, save_format): - - sharded_state_dict = self._get_base_state_dict() - with TempNamedDir(tmp_path_dist_ckpt / 'test_exact_load_handling') as ckpt_dir: - save_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, save_format, 1) - save(sharded_state_dict, ckpt_dir, save_strategy) - torch.distributed.barrier() - sharded_metadata = load_sharded_metadata(ckpt_dir) - assert set(sh_base.key for sh_base in sharded_metadata.values()) == { - 'TenA', - 'TenB', - 'TenC', - 'ObjA', - 'ObjB', - } - assert set(sharded_metadata.keys()) == { - 'TenA', - 'TenB', - 'TenC', - 'ObjA/shard_0_1', - *(f'ObjB/shard_0.{i}_1.8' for i in range(8)), - } - - loaded_state_dict = load(sharded_metadata, ckpt_dir, validate_access_integrity=False) - - assert loaded_state_dict['ObjA/shard_0_1'] == list(range(10)) - for shard_idx in range(8): - assert loaded_state_dict[f'ObjB/shard_0.{shard_idx}_1.8'] == {shard_idx + 7} - assert torch.all(loaded_state_dict['TenA'] == torch.arange(2)) - assert torch.all(loaded_state_dict['TenB'] == torch.arange(3).repeat(8)) - assert torch.all(loaded_state_dict['TenC'] == torch.arange(3)) diff --git a/tests/unit_tests/dist_checkpointing/test_strict.py b/tests/unit_tests/dist_checkpointing/test_strict.py deleted file mode 100644 index 1f0d3d7fdc..0000000000 --- a/tests/unit_tests/dist_checkpointing/test_strict.py +++ /dev/null @@ -1,314 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import logging -import typing - -import pytest -import torch - -try: - from torch.distributed import DeviceMesh - from torch.distributed._tensor import DTensor - - HAVE_DTENSOR = True -except ImportError: - HAVE_DTENSOR = False - -pytest.importorskip( - "nvidia_resiliency_ext", reason="MCoreTensorAwareStateDict requires nvidia-resiliency-ext" -) - -from megatron.core.dist_checkpointing import ShardedTensor -from megatron.core.dist_checkpointing.core import CheckpointingException -from megatron.core.dist_checkpointing.dict_utils import merge -from megatron.core.dist_checkpointing.mapping import ShardedObject -from megatron.core.dist_checkpointing.tensor_aware_state_dict import MCoreTensorAwareStateDict -from megatron.core.dist_checkpointing.validation import StrictHandling -from tests.unit_tests.test_utilities import Utils - - -class TestStrictLocal: - def setup_method(self, method): - Utils.initialize_model_parallel(8, 1) # doesn't matter for this test - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def _get_base_state_dict(self): - return { - 'TenA': ShardedTensor.from_rank_offsets('TenA', torch.arange(2), replica_id=Utils.rank), - 'TenB': ShardedTensor.from_rank_offsets( - 'TenB', torch.arange(3), (0, Utils.rank, Utils.world_size), replica_id=0 - ), - 'TenC': ShardedTensor.from_rank_offsets( - 'TenC', torch.arange(3), replica_id=Utils.world_size - Utils.rank - 1 - ), - 'ObjA': ShardedObject('ObjA', list(range(10)), (1,), (0,), replica_id=Utils.rank), - 'ObjB': ShardedObject( - 'ObjB', {Utils.rank + 7}, (1, Utils.world_size), (0, Utils.rank), replica_id=0 - ), - 'Nested': { - 'TenE': ShardedTensor.from_rank_offsets( - 'Nested.TenE', torch.arange(3), replica_id=Utils.world_size - Utils.rank - 1 - ), - 'ObjE': ShardedObject( - 'Nested.ObjE', list(range(10)), (1,), (0,), replica_id=Utils.rank - ), - 'TenF': ShardedTensor.from_rank_offsets( - 'Nested.TenF', torch.arange(3), replica_id=Utils.rank - ), - 'ObjF': ShardedObject( - 'Nested.ObjF', list(range(10)), (1,), (0,), replica_id=Utils.rank - ), - }, - 'NestedEmpty': {}, - } - - def _get_extra_state_dict(self): - return { - 'UnexpectedTenD': ShardedTensor.from_rank_offsets( - 'UnexpectedTenD', torch.arange(3), replica_id=Utils.rank - ), - 'UnexpectedObjD': ShardedObject( - 'UnexpectedObjD', None, (1,), (0,), replica_id=Utils.rank - ), - 'UnexpectedNested': { - 'UnexpectedTenF': ShardedTensor.from_rank_offsets( - 'UnexpectedNested.UnexpectedTenF', torch.arange(3), replica_id=Utils.rank - ), - 'UnexpectedObjF': ShardedObject( - 'UnexpectedNested.UnexpectedObjF', None, (1,), (0,), replica_id=Utils.rank - ), - }, - 'Nested': { - 'UnexpectedTenG': ShardedTensor.from_rank_offsets( - 'Nested.UnexpectedTenG', torch.arange(3), replica_id=Utils.rank - ), - 'UnexpectedObjG': ShardedObject( - 'Nested.UnexpectedObjG', None, (1,), (0,), replica_id=Utils.rank - ), - }, - 'NestedEmpty': { - 'UnexpectedTenH': ShardedTensor.from_rank_offsets( - 'NestedEmpty.UnexpectedTenH', torch.arange(3), replica_id=Utils.rank - ), - 'UnexpectedObjH': ShardedObject( - 'NestedEmpty.UnexpectedObjH', None, (1,), (0,), replica_id=Utils.rank - ), - }, - } - - def _tasd_to_state_dict(self, *, algo, strict, validate_access_integrity, missing, unexpected): - sharded_state_dict = self._get_base_state_dict() - if missing: - del sharded_state_dict['TenA'] - del sharded_state_dict['ObjB'] - del sharded_state_dict['Nested']['TenE'] - del sharded_state_dict['Nested']['ObjF'] - del sharded_state_dict['NestedEmpty'] - if unexpected: - # Note: merge is in-place - sharded_state_dict = merge(sharded_state_dict, self._get_extra_state_dict()) - tasd, _ = MCoreTensorAwareStateDict.from_state_dict(self._get_base_state_dict(), algo) - tasd = typing.cast(MCoreTensorAwareStateDict, tasd) - return tasd.to_state_dict( - sharded_state_dict=sharded_state_dict, - validate_access_integrity=validate_access_integrity, - strict=strict, - algo=algo, - return_mismatch_keys=True, - ) - - @property - def _missing_keys(self): - return {'TenA', 'ObjB', 'Nested.TenE', 'Nested.ObjF'} - - @property - def _unexpected_keys(self): - return { - 'UnexpectedTenD', - 'UnexpectedObjD', - 'UnexpectedNested.UnexpectedTenF', - 'UnexpectedNested.UnexpectedObjF', - 'NestedEmpty.UnexpectedTenH', - 'NestedEmpty.UnexpectedObjH', - } - - def _check_log_message(self, text, should_contain_missing, should_contain_unexpected): - if not should_contain_missing and not should_contain_unexpected: - assert text == "" - return - if should_contain_missing: - assert 'Missing keys' in text - for key in self._missing_keys: - assert key in text - else: - assert 'Missing keys' not in text - for key in self._missing_keys: - assert key not in text - if should_contain_unexpected: - assert 'Unexpected keys' in text - for key in self._unexpected_keys: - assert key in text - else: - assert 'Unexpected keys' not in text - for key in self._unexpected_keys: - assert key not in text - - def _check_log_message_for_strict_handling(self, text, strict, missing, unexpected): - # Answers the question: - # "I got the log message [text] using strictness [strict]. I [removed/didn't remove] missing and [added/didn't add] unexpected keys." - # Is the log correct? - should_contain_unexpected = ( - strict in {StrictHandling.LOG_UNEXPECTED, StrictHandling.LOG_ALL} - ) and unexpected - should_contain_missing = (strict in {StrictHandling.LOG_ALL}) and missing - return self._check_log_message(text, should_contain_missing, should_contain_unexpected) - - def _check_return_values( - self, missing_keys, unexpected_keys, should_contain_missing, should_contain_unexpected - ): - if should_contain_missing: - assert set(missing_keys) == self._missing_keys - else: - assert set(missing_keys) == set() - if should_contain_unexpected: - assert set(unexpected_keys) == self._unexpected_keys - else: - assert set(unexpected_keys) == set() - - def _check_return_values_for_strict_handling( - self, strict, missing_keys, unexpected_keys, missing, unexpected - ): - should_contain_missing = ( - strict in {StrictHandling.RETURN_ALL, StrictHandling.LOG_ALL} - ) and missing - should_contain_unexpected = ( - strict - in { - StrictHandling.RETURN_ALL, - StrictHandling.RETURN_UNEXPECTED, - StrictHandling.LOG_UNEXPECTED, - StrictHandling.LOG_ALL, - } - ) and unexpected - self._check_return_values( - missing_keys, unexpected_keys, should_contain_missing, should_contain_unexpected - ) - - @pytest.mark.parametrize('algo', ['fully_parallel', 'atomic']) - @pytest.mark.parametrize('validate_integrity', [True, False]) - @pytest.mark.parametrize('strict', list(StrictHandling)) - def test_everything_ok(self, caplog, algo, validate_integrity, strict): - with caplog.at_level(logging.WARNING): - state_dict, missing_keys, unexpected_keys = self._tasd_to_state_dict( - algo=algo, - strict=strict, - validate_access_integrity=validate_integrity, - missing=False, - unexpected=False, - ) - assert state_dict.keys() == self._get_base_state_dict().keys() - assert set(missing_keys) == set() - assert set(unexpected_keys) == set() - assert caplog.text == '' - - @pytest.mark.parametrize('algo', ['atomic']) - @pytest.mark.parametrize('validate_integrity', [True, False]) - @pytest.mark.parametrize(['missing', 'unexpected'], [(True, False)]) - @pytest.mark.parametrize( - 'strict', - [ - StrictHandling.ASSUME_OK_UNEXPECTED, - StrictHandling.LOG_UNEXPECTED, - StrictHandling.LOG_ALL, - StrictHandling.RETURN_UNEXPECTED, - StrictHandling.RETURN_ALL, - StrictHandling.IGNORE_ALL, - ], - ) - def test_passthrough(self, caplog, algo, validate_integrity, missing, unexpected, strict): - # Scenario: strictness check is supposed to pass the errors through, the underlying algorithm is able to handle it. - with caplog.at_level(logging.WARNING): - _, missing_keys, unexpected_keys = self._tasd_to_state_dict( - algo=algo, - strict=strict, - validate_access_integrity=validate_integrity, - missing=missing, - unexpected=unexpected, - ) - self._check_log_message_for_strict_handling(caplog.text, strict, missing, unexpected) - self._check_return_values_for_strict_handling( - strict, missing_keys, unexpected_keys, missing, unexpected - ) - - # NOTE: Fully parallel results in a hard-to-catch error: - # The exchange algorithm is unaware of the missing tensors and will still expect the shards to be received - - # which will cause the process to hang indefinitely. - @pytest.mark.parametrize('algo', ['atomic']) - @pytest.mark.parametrize('validate_integrity', [True, False]) - @pytest.mark.parametrize(['missing', 'unexpected'], [(False, True), (True, True)]) - @pytest.mark.parametrize( - 'strict', - [ - StrictHandling.ASSUME_OK_UNEXPECTED, - StrictHandling.LOG_UNEXPECTED, - StrictHandling.LOG_ALL, - StrictHandling.RETURN_UNEXPECTED, - StrictHandling.RETURN_ALL, - StrictHandling.IGNORE_ALL, - ], - ) - def test_passthrough_errors( - self, caplog, algo, validate_integrity, missing, unexpected, strict - ): - # Scenario: strictness check is supposed to pass the errors through, - # but they result in an error in the underlying algorithm as it's unable to handle it. - # That's why "Fully parallel" is excluded, as instead of raising an error, it will hang indefinitely, which is hard to catch. - with caplog.at_level(logging.WARNING): - with pytest.raises(AssertionError) as exc_info: - self._tasd_to_state_dict( - algo=algo, - strict=strict, - validate_access_integrity=validate_integrity, - missing=missing, - unexpected=unexpected, - ) - # TODO: check exc_info - self._check_log_message_for_strict_handling(caplog.text, strict, missing, unexpected) - - @pytest.mark.parametrize('algo', ['fully_parallel', 'atomic']) - @pytest.mark.parametrize('validate_integrity', [True, False]) - @pytest.mark.parametrize('missing', [True, False]) - def test_raise_unexpected(self, validate_integrity, algo, missing): - with pytest.raises(CheckpointingException) as exc_info: - self._tasd_to_state_dict( - algo=algo, - strict=StrictHandling.RAISE_UNEXPECTED, - validate_access_integrity=validate_integrity, - missing=missing, - unexpected=True, - ) - self._check_log_message( - str(exc_info.value), should_contain_missing=False, should_contain_unexpected=True - ) - - @pytest.mark.parametrize('algo', ['fully_parallel', 'atomic']) - @pytest.mark.parametrize('validate_integrity', [True, False]) - @pytest.mark.parametrize( - ['missing', 'unexpected'], [(True, False), (False, True), (True, True)] - ) - def test_raise_all(self, validate_integrity, algo, missing, unexpected): - with pytest.raises(CheckpointingException) as exc_info: - self._tasd_to_state_dict( - algo=algo, - strict=StrictHandling.RAISE_ALL, - validate_access_integrity=validate_integrity, - missing=missing, - unexpected=unexpected, - ) - self._check_log_message( - str(exc_info.value), - should_contain_missing=missing, - should_contain_unexpected=unexpected, - ) diff --git a/tests/unit_tests/dist_checkpointing/test_torch_dist.py b/tests/unit_tests/dist_checkpointing/test_torch_dist.py deleted file mode 100644 index 4f4df05897..0000000000 --- a/tests/unit_tests/dist_checkpointing/test_torch_dist.py +++ /dev/null @@ -1,124 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -"""Tests for PyTorch DCP based checkpoint format. """ - -import pickle -from copy import deepcopy -from dataclasses import fields - -import torch - -from megatron.core.dist_checkpointing import ShardedTensor, load, save -from megatron.core.dist_checkpointing.dict_utils import diff -from megatron.core.dist_checkpointing.serialization import get_default_save_sharded_strategy -from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue -from tests.unit_tests.dist_checkpointing import TempNamedDir -from tests.unit_tests.test_utilities import Utils - - -class TestCachedMetadata: - def setup_method(self, method): - pass - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_cached_metadata(self, tmp_path_dist_ckpt): - Utils.initialize_model_parallel(2, 4) - - sharded_state_dict_non_cached = { - 'sd_keyA': ShardedTensor.from_rank_offsets( - 'keyA', torch.ones(2, 4), replica_id=Utils.rank - ), - 'sd_keyB': ShardedTensor.from_rank_offsets( - 'keyB', torch.ones(3, 5, 7), replica_id=Utils.world_size - Utils.rank - 1 - ), - } - - sharded_state_dict_cached = { - 'sd_keyA': ShardedTensor.from_rank_offsets( - 'keyA', torch.ones(2, 4), replica_id=Utils.rank - ), - 'sd_keyB': ShardedTensor.from_rank_offsets( - 'keyB', torch.ones(3, 5, 7), replica_id=Utils.world_size - Utils.rank - 1 - ), - } - - loaded_non_cached, loaded_cached = None, None - md_non_cached, md_cached = None, None - with TempNamedDir(tmp_path_dist_ckpt / 'ckpt_dir') as ckpt_dir: - save(sharded_state_dict_non_cached, ckpt_dir, async_sharded_save=False) - loaded_non_cached = load(sharded_state_dict_non_cached, ckpt_dir) - md_path = ckpt_dir / '.metadata' - with md_path.open('rb') as f: - md_non_cached = pickle.load(f) - - save_strategy = deepcopy(get_default_save_sharded_strategy()) - save_strategy.use_cached_ckpt_structure = True - # Run over 3 iterations with cached metadata enabled - # The 3rd iteration will run with cached metadata - # `ckpt_dir` at the 3rd iteration 2 will be maintained for comparison - ckpt_dir = None - for i in range(3): - ckpt_dir = TempNamedDir(tmp_path_dist_ckpt / f'ckpt_dir_${i}_cached') - save( - sharded_state_dict_cached, - ckpt_dir.__enter__(), - save_strategy, - async_sharded_save=False, - ) - if i < 2: - ckpt_dir.cleanup() - loaded_cached = load(sharded_state_dict_cached, ckpt_dir.__enter__()) - md_path = ckpt_dir.__enter__() / '.metadata' - - with md_path.open('rb') as f: - md_cached = pickle.load(f) - - # Check loaded state dict - diffs = diff(loaded_non_cached, loaded_cached) - - assert not any( - len(x) for x in diffs - ), 'Cached metadata doesn\'t produce the same state_dict in loading' - # Check metadata recorded in .metadata, torch.distributed.metadata.Metadata - for field in fields(md_non_cached): - if field.name not in ['storage_data', 'storage_meta']: - diffs = diff(getattr(md_non_cached, field.name), getattr(md_cached, field.name)) - assert not any( - len(x) for x in diffs - ), f'{field.name} is different in metadata from non-cached, cached metadata impls' - ckpt_dir.cleanup() - Utils.destroy_model_parallel() - - -class TestCPUTensors: - def setup_method(self, method): - Utils.initialize_model_parallel() - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_cpu_tensors_dont_take_too_much_space(self, tmp_path_dist_ckpt): - large_cuda_tensor = torch.ones(1_000_000, dtype=torch.float, device='cuda') - large_cpu_tensor = torch.ones(1_000_000, dtype=torch.float) - # Create small tensors which are a view of a large tensor - sharded_state_dict = { - 'sd_keyA': ShardedTensor.from_rank_offsets( - 'keyA', large_cuda_tensor[:10], replica_id=Utils.rank - ), - 'sd_keyB': ShardedTensor.from_rank_offsets( - 'keyB', large_cpu_tensor[:10], replica_id=Utils.rank - ), - } - - with TempNamedDir( - tmp_path_dist_ckpt / 'test_cpu_tensors_dont_take_too_much_space' - ) as ckpt_dir: - save(sharded_state_dict, ckpt_dir) - - distcp_files = [(ckpt_dir / '__0_0.distcp'), (ckpt_dir / '__0_1.distcp')] - for file in distcp_files: - assert file.exists() - file_size = file.stat().st_size - assert file_size < 10_000, file.name diff --git a/tests/unit_tests/distributed/test_distributed_data_parallel.py b/tests/unit_tests/distributed/test_distributed_data_parallel.py deleted file mode 100644 index 0aa764aa4b..0000000000 --- a/tests/unit_tests/distributed/test_distributed_data_parallel.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -import pytest -import torch -from packaging import version -from torch import testing - -from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig -from megatron.core.hyper_comm_grid import HyperCommGrid -from megatron.core.process_groups_config import GradCommProcessGroups, ModelCommProcessGroups -from megatron.core.transformer import TransformerConfig -from tests.unit_tests.test_utilities import Utils - - -# Test model for testing DDP -class TestModel(torch.nn.Module): - def __init__(self, input_dim, output_dim): - super().__init__() - self.linear1 = torch.nn.Linear(input_dim, input_dim * 4) - self.activation = torch.nn.ReLU() - self.linear2 = torch.nn.Linear(input_dim * 4, output_dim) - - def forward(self, x): - x = self.linear1(x) - x = self.activation(x) - x = self.linear2(x) - return x - - -class TestDistributedDataParallel: - @classmethod - def setup_class(cls): - Utils.initialize_model_parallel() - - @classmethod - def teardown_class(cls): - Utils.destroy_model_parallel() - - @pytest.mark.skipif( - version.parse(torch.__version__) < version.parse('2.3.0'), - reason="Device mesh feature requires PyTorch 2.3 or later", - ) - @pytest.mark.parametrize("dp_size", [2, 8]) # Test with 2 or 8 GPUs - def test_ddp_with_dp_process_groups(self, dp_size): - """Test that DDP works correctly with dp pgs from parallel state and user defined pgs.""" - - # Skip test if we don't have enough GPUs - world_size = torch.distributed.get_world_size() - if world_size != dp_size: - pytest.skip(f"This test requires {dp_size} GPUs, but only {world_size} are available") - - # Simple model config - input_dim = 13 - output_dim = 17 - - # Setup DDP config - ddp_config = DistributedDataParallelConfig(overlap_grad_reduce=True, bucket_size=10000) - - # Create two identical models - model1 = TestModel(input_dim=input_dim, output_dim=output_dim).cuda() - model2 = TestModel(input_dim=input_dim, output_dim=output_dim).cuda() - - # Ensure identical weights - for p1, p2 in zip(model1.parameters(), model2.parameters()): - p2.data.copy_(p1.data) - - # Wrap first model with default process groups - transformer_config = TransformerConfig( - num_attention_heads=1, num_layers=1, context_parallel_size=1 - ) - - ddp_model1 = DistributedDataParallel( - transformer_config, ddp_config=ddp_config, module=model1 - ) - - # Initialize torch.distributed if not already initialized - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend='nccl') - - # Create HyperCommGrid with dimension ep, pp, dp (reversed from device mesh order) - grid = HyperCommGrid([1, 1, dp_size], ["ep", "pp", "dp"]) - - # Create process groups config with ONLY dp group - grad_comm_pgs = GradCommProcessGroups() - model_comm_pgs = ModelCommProcessGroups() - - grad_comm_pgs.dp = grid.create_pg("dp") - model_comm_pgs.pp = grid.create_pg("pp") - model_comm_pgs.ep = grid.create_pg("ep") - - # Wrap second model with minimal process groups (only dp) - ddp_model2 = DistributedDataParallel( - transformer_config, - ddp_config=ddp_config, - module=model2, - grad_comm_pgs=grad_comm_pgs, - model_comm_pgs=model_comm_pgs, - ) - - # Create identical inputs with integer values - batch_size = 2 - input_data = torch.randint(0, 10, (batch_size, input_dim), device='cuda', dtype=torch.long) - input_data = input_data.float() # Convert to float for model compatibility - - # Forward pass - out1 = ddp_model1(input_data) - out2 = ddp_model2(input_data) - - testing.assert_close(out1, out2, rtol=0, atol=0) - - # Loss and backward - loss1 = out1.sum() - loss2 = out2.sum() - - loss1.backward() - loss2.backward() - - # Check gradients are identical using torch.testing - for p1, p2 in zip(ddp_model1.parameters(), ddp_model2.parameters()): - if hasattr(p1, 'main_grad') and hasattr(p2, 'main_grad'): - testing.assert_close(p1.main_grad, p2.main_grad, rtol=0, atol=0) diff --git a/tests/unit_tests/distributed/test_finalize_model_grads.py b/tests/unit_tests/distributed/test_finalize_model_grads.py deleted file mode 100644 index e1e2e76069..0000000000 --- a/tests/unit_tests/distributed/test_finalize_model_grads.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import inspect -import os - -import pytest -import torch - -from megatron.core import parallel_state -from megatron.core.distributed import DistributedDataParallelConfig -from megatron.core.distributed.finalize_model_grads import ( - _allreduce_non_tensor_model_parallel_grads, - _allreduce_word_embedding_grads, -) -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.transformer_config import TransformerConfig -from tests.unit_tests.test_utilities import Utils - - -class TestAllReduceLNGrads: - - def init_model(self, share_embeddings_and_output_weights: bool = False): - self.transformer_config = TransformerConfig( - num_layers=2, - hidden_size=12, - num_attention_heads=4, - use_cpu_initialization=True, - tensor_model_parallel_size=self.tp_size, - pipeline_model_parallel_size=self.pp_size, - qk_layernorm=True, - pipeline_dtype=torch.float32, - ) - - self.model = GPTModel( - config=self.transformer_config, - transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec(qk_layernorm=True), - vocab_size=100, - max_sequence_length=4, - share_embeddings_and_output_weights=share_embeddings_and_output_weights, - ) - - def setup_method(self, method): - os.environ.pop('NVTE_FUSED_ATTN', None) - os.environ.pop('NVTE_FLASH_ATTN', None) - os.environ.pop('NVTE_UNFUSED_ATTN', None) - Utils.destroy_model_parallel() - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.parametrize("freeze_model,tp_size", [(True, 2), (False, 2)]) - def test_allreduce_layernorm_grads(self, freeze_model, tp_size): - self.tp_size = tp_size - self.pp_size = 1 - Utils.initialize_model_parallel(tensor_model_parallel_size=self.tp_size) - model_parallel_cuda_manual_seed(123) - - self.init_model() - self.model.cuda() - self.model.ddp_config = DistributedDataParallelConfig() - - for param in self.model.parameters(): - if freeze_model: - param.requires_grad = False - else: - param.grad = torch.ones_like(param) - - _allreduce_non_tensor_model_parallel_grads( - [self.model], self.transformer_config, parallel_state.get_tensor_model_parallel_group() - ) - - @pytest.mark.parametrize( - ("freeze_model", "pp_size", "share_embeddings"), - [(True, 2, True), (False, 2, True), (True, 2, False), (False, 2, False)], - ) - def test_allreduce_word_embedding_grads(self, freeze_model, pp_size, share_embeddings): - self.tp_size = 1 - self.pp_size = pp_size - Utils.initialize_model_parallel(pipeline_model_parallel_size=self.pp_size) - model_parallel_cuda_manual_seed(123) - - self.init_model(share_embeddings) - self.model.cuda() - self.model.ddp_config = DistributedDataParallelConfig() - - for param in self.model.parameters(): - if freeze_model: - param.requires_grad = False - else: - param.grad = torch.ones_like(param) - pp_group = parallel_state.get_pipeline_model_parallel_group() - embd_group = parallel_state.get_embedding_group() - - _allreduce_word_embedding_grads([self.model], self.transformer_config, embd_group, pp_group) diff --git a/tests/unit_tests/distributed/test_grad_reduce_for_replicated_embedder.py b/tests/unit_tests/distributed/test_grad_reduce_for_replicated_embedder.py deleted file mode 100644 index c5acf0b76c..0000000000 --- a/tests/unit_tests/distributed/test_grad_reduce_for_replicated_embedder.py +++ /dev/null @@ -1,49 +0,0 @@ -import pytest -import torch - -from megatron.core import ModelParallelConfig, parallel_state -from megatron.core.distributed.finalize_model_grads import _allreduce_conditional_embedding_grads -from tests.unit_tests.test_utilities import Utils - -rank = Utils.rank - - -def test_allreduce_conditional_embedding_grads(): - - Utils.initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=4) - - # For virtual pipeline parallelism. - model = [torch.nn.Linear(10, 10, bias=True).cuda() for _ in range(2)] - # Here we only reduce weights, not bias to compare the results. - for chunk in model: - setattr(chunk.weight, "pipeline_parallel", True) - - config = ModelParallelConfig( - pipeline_model_parallel_size=4, sequence_parallel=False, pipeline_dtype=torch.float - ) - config.has_cond_embedder = True - - pp_rank = parallel_state.get_pipeline_model_parallel_rank() - pp_world_size = parallel_state.get_pipeline_model_parallel_world_size() - - # Init different grads for each model chunk and rank. - for i, chunk in enumerate(model): - for param in chunk.parameters(): - param.main_grad = torch.ones_like(param) * (pp_rank * 10.0 + i) - - _allreduce_conditional_embedding_grads( - model, config, parallel_state.get_pipeline_model_parallel_group() - ) - - expect_value = 0 - for i in range(len(model)): - for j in range(pp_world_size): - expect_value += j * 10.0 + i - expect_weight_grad = torch.ones([10, 10]).cuda() * expect_value - - for i, chunk in enumerate(model): - expect_bias_grad = torch.ones([10]).cuda() * (pp_rank * 10.0 + i) - assert torch.equal(chunk.weight.main_grad, expect_weight_grad) - assert torch.equal(chunk.bias.main_grad, expect_bias_grad) - - Utils.destroy_model_parallel() diff --git a/tests/unit_tests/distributed/test_grad_sync_with_expert_parallel.py b/tests/unit_tests/distributed/test_grad_sync_with_expert_parallel.py deleted file mode 100644 index 71e45f9d92..0000000000 --- a/tests/unit_tests/distributed/test_grad_sync_with_expert_parallel.py +++ /dev/null @@ -1,252 +0,0 @@ -import contextlib -from typing import Optional - -import pytest -import torch - -from megatron.core import parallel_state -from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig -from megatron.core.distributed.param_and_grad_buffer import partition_buckets -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec -from megatron.core.transformer import TransformerConfig -from megatron.core.transformer.moe.moe_layer import MoELayer -from tests.unit_tests.test_utilities import TestModel, Utils - - -class TestMoEModel(torch.nn.Module): - def __init__( - self, - hidden_size: int, - num_layers: int, - num_moe_experts: int, - moe_grouped_gemm: bool, - ep_size: int, - etp_size: int, - ): - transformer_config = TransformerConfig( - num_layers=num_layers, - hidden_size=hidden_size, - num_attention_heads=1, - num_moe_experts=num_moe_experts, - moe_router_load_balancing_type="aux_loss", - moe_router_topk=2, - moe_aux_loss_coeff=0.01, - moe_grouped_gemm=moe_grouped_gemm, - moe_token_dispatcher_type='alltoall', - expert_model_parallel_size=ep_size, - expert_tensor_parallel_size=etp_size, - bf16=True, - params_dtype=torch.bfloat16, - add_bias_linear=False, - ) - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - num_experts=num_moe_experts, moe_grouped_gemm=moe_grouped_gemm - ) - super().__init__() - self.layers = torch.nn.ModuleList( - [ - MoELayer( - transformer_config, transformer_layer_spec.submodules.mlp.submodules - ).cuda() - for _ in range(num_layers) - ] - ) - - -def get_moe_model_and_buffers( - num_layers: int, - hidden_size: int, - num_moe_experts: int, - moe_grouped_gemm: bool, - ep_size: int, - bucket_size: Optional[int], - etp_size: int, - use_distributed_optimizer: bool, - overlap_grad_reduce: bool, - average_in_collective: bool, - num_distributed_optimizer_instances: int, -): - ddp_config = DistributedDataParallelConfig( - grad_reduce_in_fp32=True, - use_distributed_optimizer=use_distributed_optimizer, - overlap_grad_reduce=overlap_grad_reduce, - bucket_size=bucket_size, - average_in_collective=average_in_collective, - num_distributed_optimizer_instances=num_distributed_optimizer_instances, - ) - model = TestMoEModel( - hidden_size=hidden_size, - num_layers=num_layers, - num_moe_experts=num_moe_experts, - moe_grouped_gemm=moe_grouped_gemm, - ep_size=ep_size, - etp_size=etp_size, - ) - model = DistributedDataParallel( - TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config=ddp_config, module=model - ) - assert len(model.buffers) == 1 - param_and_grad_buffer = model.buffers[0] - ep_param_and_grad_buffer = ( - model.expert_parallel_buffers[0] if len(model.expert_parallel_buffers) else None - ) - non_ep_bucket_groups = model.bucket_groups - ep_bucket_groups = model.expert_parallel_bucket_groups - - return ( - model, - param_and_grad_buffer, - ep_param_and_grad_buffer, - non_ep_bucket_groups, - ep_bucket_groups, - ) - - -@pytest.mark.parametrize("use_distributed_optimizer", [False, True]) -@pytest.mark.parametrize("overlap_grad_reduce", [False, True]) -@pytest.mark.parametrize("average_in_collective", [False, True]) -@pytest.mark.parametrize("ep_size", [1, 2]) -@pytest.mark.parametrize("etp_size", [1, 2]) -@pytest.mark.parametrize("num_distributed_optimizer_instances", [1, 2]) -@pytest.mark.flaky -@pytest.mark.flaky_in_dev -def test_grad_sync( - use_distributed_optimizer: bool, - overlap_grad_reduce: bool, - average_in_collective: bool, - ep_size: int, - etp_size: int, - num_distributed_optimizer_instances: int, -): - Utils.initialize_model_parallel( - expert_model_parallel_size=ep_size, - expert_tensor_parallel_size=etp_size, - num_distributed_optimizer_instances=num_distributed_optimizer_instances, - ) - - if num_distributed_optimizer_instances > 1 and not use_distributed_optimizer: - pytest.skip( - "Multiple distributed optimizer instances requires distributed optimizer to be enabled" - ) - - ( - model, - non_ep_param_and_grad_buffer, - ep_param_and_grad_buffer, - non_ep_bucket_groups, - ep_bucket_groups, - ) = get_moe_model_and_buffers( - num_layers=2, - hidden_size=512, - num_moe_experts=4, - moe_grouped_gemm=True, - ep_size=ep_size, - etp_size=etp_size, - bucket_size=None, - use_distributed_optimizer=use_distributed_optimizer, - overlap_grad_reduce=overlap_grad_reduce, - average_in_collective=average_in_collective, - num_distributed_optimizer_instances=num_distributed_optimizer_instances, - ) - - param_to_bucket_group = {} - for bucket_group in non_ep_bucket_groups: - for param in bucket_group.params: - assert param not in param_to_bucket_group - param_to_bucket_group[param] = bucket_group - for bucket_group in ep_bucket_groups: - for param in bucket_group.params: - assert param not in param_to_bucket_group - param_to_bucket_group[param] = bucket_group - - non_ep_param_and_grad_buffer.grad_data.data.fill_(1.0) - non_ep_expected_grad_data_value_after_collective = 1 - if ( - use_distributed_optimizer - and (not average_in_collective) - and parallel_state.get_data_parallel_rank( - with_context_parallel=True, partial_data_parallel=True - ) - != 0 - ): - # With above conditions, the data in param_and_grad_buffer.grad_data[0] equals to 1/data_parallel_word_size - # When average_in_collective=False, the grad data is always first scaled by 1/data_parallel_word_size and then summed by AR/RS - # when use_distributed_optimizer=True, only for rank=0 param_and_grad_buffer.grad_data[0] is updated, for other ranks - # another shard of grad_data is updated while param_and_grad_buffer.grad_data[0] is unchanged (=1/data_parallel_word_size) - non_ep_expected_grad_data_value_after_collective /= ( - parallel_state.get_data_parallel_world_size() - ) - if ep_size > 1: - # For MoE models with exper parallelism, each expert will receive tokens from EPxETP times batches, such that the expert gradient will be EPxETP times after backward, - # and the expected gradient after collective should be 1.0 as same as dense params. - ep_param_and_grad_buffer.grad_data.data.fill_(float(ep_size * etp_size)) - ep_expected_grad_data_value_after_collective = 1 - if ( - use_distributed_optimizer - and (not average_in_collective) - and parallel_state.get_expert_data_parallel_rank(partial_expert_data_parallel=True) != 0 - ): - # With above conditions, the data in param_and_grad_buffer.grad_data[0] equals to 1/EDP - # When average_in_collective=False, the grad data is always first scaled by expert_data_parallel_size and then summed by AR/RS - # after SUM collective in expert_data_group, the scale will be 1.0. - ep_expected_grad_data_value_after_collective /= ( - parallel_state.get_expert_data_parallel_world_size() - ) - - params = list(model.parameters()) - map_bucket_to_last_param_idx = {} - for i, param in enumerate(params): - if not (param in param_to_bucket_group): - # it means this parameter is not on this device, skip - continue - bucket_group = param_to_bucket_group[param] - if bucket_group in map_bucket_to_last_param_idx: - param_idx = map_bucket_to_last_param_idx[bucket_group] + 1 - else: - param_idx = 0 - map_bucket_to_last_param_idx[bucket_group] = param_idx - - register_grad_sync_context = ( - contextlib.nullcontext() if overlap_grad_reduce else pytest.raises(AssertionError) - ) - finish_grad_sync_context = contextlib.nullcontext() - if ( - param_idx < (len(bucket_group.params) - 1) - and overlap_grad_reduce - and num_distributed_optimizer_instances == 1 - ): - # Can't finish grad sync until all params have been registered ready. - finish_grad_sync_context = pytest.raises(AssertionError) - - with register_grad_sync_context: - bucket_group.register_grad_ready(param) - with finish_grad_sync_context: - # When overlap_grad_reduce is True, this should throw an assertion error until all - # params in the model have registered their grad above. - # When overlap_grad_reduce is False, the collective is forced through. - bucket_group.finish_grad_sync() - - if bucket_group in non_ep_bucket_groups: - expected_grad_data_value = non_ep_expected_grad_data_value_after_collective - else: - expected_grad_data_value = ep_expected_grad_data_value_after_collective - # Before gradient sync, the gradient value should keep original. - if overlap_grad_reduce and param_idx < (len(bucket_group.params) - 1): - if bucket_group in non_ep_bucket_groups: - expected_grad_data_value = 1 - else: - expected_grad_data_value = ep_size * etp_size - - if bucket_group in non_ep_bucket_groups: - assert non_ep_param_and_grad_buffer.grad_data[0] == expected_grad_data_value - else: - assert ep_param_and_grad_buffer.grad_data[0] == expected_grad_data_value - - if not overlap_grad_reduce: - # Reset grad_data for subsequent collectives. - if bucket_group in non_ep_bucket_groups: - non_ep_param_and_grad_buffer.grad_data.data.fill_(1.0) - else: - ep_param_and_grad_buffer.grad_data.data.fill_(float(ep_size * etp_size)) - - Utils.destroy_model_parallel() diff --git a/tests/unit_tests/distributed/test_mcore_fully_sharded_data_parallel.py b/tests/unit_tests/distributed/test_mcore_fully_sharded_data_parallel.py deleted file mode 100644 index 9840612d86..0000000000 --- a/tests/unit_tests/distributed/test_mcore_fully_sharded_data_parallel.py +++ /dev/null @@ -1,480 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -import random - -import numpy as np -import pytest -import torch -from packaging import version -from torch import testing - -import megatron.core.parallel_state as mpu -from megatron.core.distributed import DistributedDataParallelConfig -from megatron.core.distributed.fsdp.mcore_fsdp_adapter import FullyShardedDataParallel -from megatron.core.hyper_comm_grid import HyperCommGrid -from megatron.core.optimizer import OptimizerConfig -from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer -from megatron.core.process_groups_config import GradCommProcessGroups, ModelCommProcessGroups -from megatron.core.transformer import TransformerConfig -from megatron.core.utils import is_torch_min_version -from tests.unit_tests.test_utilities import Utils - - -# Test model for testing FSDP -class TestModel(torch.nn.Module): - def __init__(self, input_dim, output_dim): - super().__init__() - self.linear1 = torch.nn.Linear(input_dim, input_dim * 4) - self.activation = torch.nn.ReLU() - self.linear2 = torch.nn.Linear(input_dim * 4, output_dim) - - def forward(self, x): - x = self.linear1(x) - x = self.activation(x) - x = self.linear2(x) - return x - - -# Test model with uniform shaped weights for testing FSDP -class TestModelUniform(torch.nn.Module): - def __init__(self, hidden_dim): - super().__init__() - self.linear1 = torch.nn.Linear(hidden_dim, hidden_dim) - self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim) - self.linear3 = torch.nn.Linear(hidden_dim, hidden_dim) - self.linear4 = torch.nn.Linear(hidden_dim, hidden_dim) - self.activation = torch.nn.ReLU() - - def forward(self, x): - x = self.linear1(x) - x = self.activation(x) - x = self.linear2(x) - x = self.activation(x) - x = self.linear3(x) - x = self.activation(x) - x = self.linear4(x) - return x - - -def setup_seed(seed): - random.seed(seed) # Set Python's built-in random seed - np.random.seed(seed) # Set NumPy's random seed - torch.manual_seed(seed) # Set PyTorch's CPU seed - torch.cuda.manual_seed(seed) # Set PyTorch's GPU seed (if using CUDA) - torch.cuda.manual_seed_all(seed) # Set seed for all GPUs - torch.backends.cudnn.deterministic = True # Ensure deterministic behavior - torch.backends.cudnn.benchmark = False # Disable auto-tuner for reproducibility - - -class TestFullyShardedDataParallel: - @classmethod - def setup_class(cls): - Utils.initialize_model_parallel() - - @classmethod - def teardown_class(cls): - Utils.destroy_model_parallel() - - @pytest.mark.skipif( - version.parse(torch.__version__) < version.parse('2.3.0'), - reason="Device mesh feature requires PyTorch 2.3 or later", - ) - @pytest.mark.parametrize("dp_size", [2, 8]) # Test with 2 or 8 GPUs - def test_fsdp_with_process_groups(self, dp_size): - """Test that FSDP works correctly with different process group configurations.""" - if not is_torch_min_version("2.4.0"): - pytest.skip("Megatron FSDP requires torch >= 2.4.0") - - # Initialize torch.distributed if not already initialized - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend='nccl') - - # Create HyperCommGrid with dp dimension - grid = HyperCommGrid([dp_size], ["dp"]) - - # Create process groups config with ONLY dp group - grad_comm_pgs = GradCommProcessGroups() - model_comm_pgs = ModelCommProcessGroups() - - grad_comm_pgs.dp = grid.create_pg("dp") - grad_comm_pgs.dp_cp = grad_comm_pgs.dp - - # Skip test if we don't have enough GPUs - world_size = torch.distributed.get_world_size() - if world_size != dp_size: - pytest.skip(f"This test requires {dp_size} GPUs, but only {world_size} are available") - - # Simple model config - input_dim = 13 - output_dim = 17 - - # Setup FSDP config - using optim_grads_params for full sharding test - fsdp_config = DistributedDataParallelConfig( - data_parallel_sharding_strategy="optim_grads_params", - overlap_grad_reduce=True, - overlap_param_gather=True, - bucket_size=10000, - use_megatron_fsdp=True, - ) - - # Create two identical models - model1 = TestModel(input_dim=input_dim, output_dim=output_dim).cuda() - model2 = TestModel(input_dim=input_dim, output_dim=output_dim).cuda() - - # Ensure identical weights - for p1, p2 in zip(model1.parameters(), model2.parameters()): - p2.data.copy_(p1.data) - - transformer_config = TransformerConfig( - num_attention_heads=1, num_layers=1, context_parallel_size=1 # Explicitly set CP=1 - ) - fsdp_model1 = FullyShardedDataParallel( - config=transformer_config, - ddp_config=fsdp_config, - module=model1, - fsdp_unit_modules=[torch.nn.Linear], - ) - - # Wrap second model with explicit process groups - fsdp_model2 = FullyShardedDataParallel( - config=transformer_config, - ddp_config=fsdp_config, - module=model2, - grad_comm_pgs=grad_comm_pgs, - model_comm_pgs=model_comm_pgs, - fsdp_unit_modules=[torch.nn.Linear], - ) - - # Create optimizer config - lr = 3 - optimizer_config = OptimizerConfig(optimizer="adam", lr=lr) - grad_scaler = None - - optimizer1 = DistributedOptimizer( - optimizer=None, - config=optimizer_config, - grad_scaler=grad_scaler, - init_state_fn=None, - model_chunks=[fsdp_model1], - per_model_buffers={0: [fsdp_model1.param_and_grad_buffer]}, - data_parallel_group=fsdp_model1.megatron_fsdp_dist_index.get_dp_group(), - data_parallel_group_gloo=None, - data_parallel_group_idx=0, - distributed_optimizer_instance_id=0, - ) - - optimizer2 = DistributedOptimizer( - optimizer=None, - config=optimizer_config, - grad_scaler=grad_scaler, - init_state_fn=None, - model_chunks=[fsdp_model2], - per_model_buffers={0: [fsdp_model2.param_and_grad_buffer]}, - data_parallel_group=fsdp_model2.megatron_fsdp_dist_index.get_dp_group(), - data_parallel_group_gloo=None, - data_parallel_group_idx=0, - distributed_optimizer_instance_id=1, - ) - - # Create identical inputs - batch_size = 2 - input_data = torch.randint(0, 10, (batch_size, input_dim), device='cuda', dtype=torch.long) - input_data = input_data.float() - input_data.requires_grad = True - - def loss_fn(output, _): - return output.sum() - - def train_step(model, optimizer, inputs): - inputs_clone = inputs.clone().detach().requires_grad_(True) - optimizer.zero_grad() - outputs = model(inputs_clone) - loss = loss_fn(outputs, None) - loss.backward() - optimizer.step() - return outputs, loss - - out1, loss1 = train_step(fsdp_model1, optimizer1, input_data) - out2, loss2 = train_step(fsdp_model2, optimizer2, input_data) - - testing.assert_close(out1, out2, rtol=0, atol=0) - testing.assert_close(loss1, loss2, rtol=0, atol=0) - - fsdp_model1.stop_communication() - fsdp_model2.stop_communication() - - # Check parameters after optimization step - for (name1, param1), (_, param2) in zip( - fsdp_model1.named_parameters(), fsdp_model2.named_parameters() - ): - testing.assert_close( - param1._local_tensor, - param2._local_tensor, - rtol=0, - atol=0, - msg=f"Parameters for {name1} don't match", - ) - - # Testing fsdp_double_buffer with and without nccl_ub - @pytest.mark.parametrize( - ("dp_size", "nccl_ub", "fsdp_double_buffer"), [(8, False, True), (8, True, True)] - ) - def test_fsdp_user_buffer_registration(self, dp_size, nccl_ub, fsdp_double_buffer): - """Test that FSDP works correctly with user buffer registration. - This test compares the training results of the baseline fsdp with the target fsdp config. - Baseline fsdp: nccl_ub=False, fsdp_double_buffer=False - Target fsdp: nccl_ub=[True, False], fsdp_double_buffer=[True, False] - """ - if not is_torch_min_version("2.4.0"): - pytest.skip("Megatron FSDP requires torch >= 2.4.0") - - # Skip nccl_ub=True cases if PyTorch version is less than 2.7.0 - if nccl_ub and version.parse(torch.__version__) < version.parse('2.7.0'): - pytest.skip("nccl_ub requires PyTorch 2.7.0 or later") - - # Initialize torch.distributed if not already initialized - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend='nccl') - - # Create HyperCommGrid with dp dimension - grid = HyperCommGrid([dp_size], ["dp"]) - - # Create process groups config with ONLY dp group - grad_comm_pgs = GradCommProcessGroups() - model_comm_pgs = ModelCommProcessGroups() - - grad_comm_pgs.dp = grid.create_pg("dp") - - # Skip test if we don't have enough GPUs - world_size = torch.distributed.get_world_size() - if world_size != dp_size: - pytest.skip(f"This test requires {dp_size} GPUs, but only {world_size} are available") - - # Model config - hidden_dim = 16 - - # Setup FSDP config - baseline fsdp config - baseline_fsdp_config = DistributedDataParallelConfig( - data_parallel_sharding_strategy="optim_grads_params", - overlap_grad_reduce=True, - overlap_param_gather=True, - bucket_size=10000, - use_megatron_fsdp=True, - nccl_ub=False, - fsdp_double_buffer=False, - ) - - # Setup FSDP config - target fsdp config - target_fsdp_config = DistributedDataParallelConfig( - data_parallel_sharding_strategy="optim_grads_params", - overlap_grad_reduce=True, - overlap_param_gather=True, - bucket_size=10000, - use_megatron_fsdp=True, - nccl_ub=nccl_ub, - fsdp_double_buffer=fsdp_double_buffer, - ) - - # Create two identical models - model1 = TestModelUniform(hidden_dim=hidden_dim).cuda() - model2 = TestModelUniform(hidden_dim=hidden_dim).cuda() - - # Ensure identical weights - for p1, p2 in zip(model1.parameters(), model2.parameters()): - p2.data.copy_(p1.data) - - transformer_config = TransformerConfig( - num_attention_heads=1, num_layers=1, context_parallel_size=1 # Explicitly set CP=1 - ) - baseline_fsdp_model = FullyShardedDataParallel( - config=transformer_config, - ddp_config=baseline_fsdp_config, - module=model1, - fsdp_unit_modules=[torch.nn.Linear], - ) - - target_fsdp_model = FullyShardedDataParallel( - config=transformer_config, - ddp_config=target_fsdp_config, - module=model2, - fsdp_unit_modules=[torch.nn.Linear], - ) - - # Create optimizer config - lr = 3 - optimizer_config = OptimizerConfig(optimizer="adam", lr=lr) - grad_scaler = None - - optimizer1 = DistributedOptimizer( - optimizer=None, - config=optimizer_config, - grad_scaler=grad_scaler, - init_state_fn=None, - model_chunks=[baseline_fsdp_model], - per_model_buffers={0: [baseline_fsdp_model.param_and_grad_buffer]}, - data_parallel_group=baseline_fsdp_model.megatron_fsdp_dist_index.get_dp_group(), - data_parallel_group_gloo=None, - data_parallel_group_idx=0, - distributed_optimizer_instance_id=0, - ) - - optimizer2 = DistributedOptimizer( - optimizer=None, - config=optimizer_config, - grad_scaler=grad_scaler, - init_state_fn=None, - model_chunks=[target_fsdp_model], - per_model_buffers={0: [target_fsdp_model.param_and_grad_buffer]}, - data_parallel_group=target_fsdp_model.megatron_fsdp_dist_index.get_dp_group(), - data_parallel_group_gloo=None, - data_parallel_group_idx=0, - distributed_optimizer_instance_id=1, - ) - - # Create identical inputs - batch_size = 2 - input_data = torch.randint(0, 10, (batch_size, hidden_dim), device='cuda', dtype=torch.long) - input_data = input_data.float() - input_data.requires_grad = True - - def loss_fn(output, _): - return output.sum() - - def train_step(model, optimizer, inputs): - inputs_clone = inputs.clone().detach().requires_grad_(True) - optimizer.zero_grad() - outputs = model(inputs_clone) - loss = loss_fn(outputs, None) - loss.backward() - optimizer.step() - return outputs, loss - - out1, loss1 = train_step(baseline_fsdp_model, optimizer1, input_data) - out2, loss2 = train_step(target_fsdp_model, optimizer2, input_data) - - testing.assert_close(out1, out2, rtol=0, atol=0) - testing.assert_close(loss1, loss2, rtol=0, atol=0) - - # Check parameters after optimization step - baseline_fsdp_model.stop_communication() - target_fsdp_model.stop_communication() - for (name1, param1), (_, param2) in zip( - baseline_fsdp_model.named_parameters(), target_fsdp_model.named_parameters() - ): - testing.assert_close( - param1._local_tensor, - param2._local_tensor, - rtol=0, - atol=0, - msg=f"Parameters for {name1} don't match", - ) - - @classmethod - def hsdp_one_step_test(cls, num_fsdp_group): - if not is_torch_min_version("2.4.0"): - pytest.skip("Megatron FSDP requires torch >= 2.4.0") - - setup_seed(42) # Ensure reproducibility - Utils.initialize_model_parallel(num_distributed_optimizer_instances=num_fsdp_group) - - try: - # Create two identical models - input_dim = 13 - output_dim = 17 - model = TestModel(input_dim=input_dim, output_dim=output_dim).cuda() - - # Setup FSDP config - using optim_grads_params for full sharding test - fsdp_config = DistributedDataParallelConfig( - data_parallel_sharding_strategy="optim_grads_params", - overlap_grad_reduce=True, - overlap_param_gather=True, - bucket_size=10000, - use_megatron_fsdp=True, - num_distributed_optimizer_instances=num_fsdp_group, - ) - - # Wrap first model with default process groups - transformer_config = TransformerConfig( - num_attention_heads=1, num_layers=1, context_parallel_size=1 # Explicitly set CP=1 - ) - fsdp_model = FullyShardedDataParallel( - config=transformer_config, - ddp_config=fsdp_config, - module=model, - fsdp_unit_modules=[torch.nn.Linear], - ) - fsdp_model.is_last_microbatch = True # Set to True for testing - - # Create optimizer config - lr = 3 - optimizer_config = OptimizerConfig(optimizer="adam", lr=lr) - grad_scaler = None - - if num_fsdp_group > 1: - distributed_optimizer_instance_id = torch.distributed.get_rank( - mpu.get_inter_distributed_optimizer_instance_group() - ) - else: - distributed_optimizer_instance_id = 0 - - distopt = DistributedOptimizer( - optimizer=None, - config=optimizer_config, - grad_scaler=grad_scaler, - init_state_fn=None, - model_chunks=[fsdp_model], - per_model_buffers={0: [fsdp_model.param_and_grad_buffer]}, - data_parallel_group=fsdp_model.megatron_fsdp_dist_index.get_dp_group(), - data_parallel_group_gloo=None, - data_parallel_group_idx=0, - distributed_optimizer_instance_id=distributed_optimizer_instance_id, - ) - - # Create identical inputs - batch_size = 2 - input_data = torch.randint( - 0, 10, (batch_size, input_dim), device='cuda', dtype=torch.long - ) - input_data = input_data.float() - input_data.requires_grad = True - - def loss_fn(output, _): - return output.sum() - - def train_step(model, optimizer, inputs): - inputs_clone = inputs.clone().detach().requires_grad_(True) - optimizer.zero_grad() - outputs = model(inputs_clone) - loss = loss_fn(outputs, None) - loss.backward() - optimizer.step() - return outputs, loss - - out, loss = train_step(fsdp_model, distopt, input_data) - fsdp_model.stop_communication() - - return out, loss, fsdp_model.named_parameters() - finally: - Utils.destroy_model_parallel() - - @pytest.mark.parametrize("num_fsdp_group", [2]) - @pytest.mark.skipIf( - torch.cuda.device_count() % 2 == 0, "This test requires an odd number of GPUs" - ) - def test_fsdp_with_hybrid_sharding(self, num_fsdp_group): - """Test that FSDP works correctly with hybrid sharding.""" - out1, loss1, named_params1 = self.hsdp_one_step_test(num_fsdp_group) - out2, loss2, named_params2 = self.hsdp_one_step_test(1) - - testing.assert_close(out1, out2, rtol=0, atol=0) - testing.assert_close(loss1, loss2, rtol=0, atol=0) - - for (name1, param1), (name2, param2) in zip(named_params1, named_params2): - if param1.grad is None: - continue # Skip if no gradient - testing.assert_close( - param1.grad._local_tensor, - param2.grad._local_tensor, - rtol=0, - atol=0, - msg=f"Parameter gradients for {name1} and {name2} don't match", - ) diff --git a/tests/unit_tests/distributed/test_param_and_grad_buffer.py b/tests/unit_tests/distributed/test_param_and_grad_buffer.py deleted file mode 100644 index c09e2313d8..0000000000 --- a/tests/unit_tests/distributed/test_param_and_grad_buffer.py +++ /dev/null @@ -1,251 +0,0 @@ -import contextlib -import math -from typing import Optional - -import pytest -import torch - -from megatron.core import parallel_state -from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig -from megatron.core.distributed.param_and_grad_buffer import partition_buckets -from megatron.core.transformer import TransformerConfig -from tests.unit_tests.test_utilities import TestModel, Utils - - -def get_model_and_buffers( - input_dim: int, - output_dim: int, - num_layers: int, - bias: bool, - shared_embedding: bool, - bucket_size: int, - use_distributed_optimizer: bool, - overlap_grad_reduce: bool, - average_in_collective: bool, - num_distributed_optimizer_instances: int = 1, -): - ddp_config = DistributedDataParallelConfig( - grad_reduce_in_fp32=True, - use_distributed_optimizer=use_distributed_optimizer, - overlap_grad_reduce=overlap_grad_reduce, - bucket_size=bucket_size, - average_in_collective=average_in_collective, - num_distributed_optimizer_instances=num_distributed_optimizer_instances, - ) - model = TestModel( - input_dim=input_dim, - output_dim=output_dim, - num_layers=num_layers, - bias=bias, - shared_embedding=shared_embedding, - ).bfloat16() - - # Wrap with DistributedDataParallel, and get underlying buffer. - # Use dummy TransformerConfig with mostly default values. Avoid divide-by-zero - # errors for num_attention_heads and num_layers. - model = DistributedDataParallel( - TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config=ddp_config, module=model - ) - assert len(model.buffers) == 1 - param_and_grad_buffer = model.buffers[0] - bucket_groups = model.bucket_groups - - return model, param_and_grad_buffer, bucket_groups - - -@pytest.mark.parametrize("bucket_size", [None, 9000, 9025, 9050, 18000, 18050, 20000]) -@pytest.mark.parametrize("use_distributed_optimizer", [False, True]) -@pytest.mark.parametrize("bias", [False, True]) -@pytest.mark.parametrize("shared_embedding", [False, True]) -def test_bucket_sizes( - bucket_size: Optional[int], use_distributed_optimizer: bool, bias: bool, shared_embedding: bool -): - Utils.initialize_model_parallel() - - if shared_embedding and bias: - # Don't bother running shared_embedding + bias since gold values are trickier to compute. - return - - input_dim = 95 - output_dim = 95 - num_layers = 10 - _, param_and_grad_buffer, _ = get_model_and_buffers( - input_dim=input_dim, - output_dim=output_dim, - num_layers=num_layers, - bias=bias, - shared_embedding=shared_embedding, - bucket_size=bucket_size, - use_distributed_optimizer=use_distributed_optimizer, - overlap_grad_reduce=True, - average_in_collective=False, - ) - - actual_numel_in_each_bucket = [ - bucket.numel_unpadded for bucket in param_and_grad_buffer.buckets - ] - actual_numel_padded_in_each_bucket = [ - bucket.grad_data.numel() for bucket in param_and_grad_buffer.buckets - ] - - def _pad_if_needed(numel_unpadded, divisor): - if use_distributed_optimizer: - return math.ceil(numel_unpadded / divisor) * divisor - return numel_unpadded - - def _pad_bucket_if_needed(numel_unpadded): - # Want 128-byte alignment for distributed optimizer. - divisor = math.lcm(parallel_state.get_data_parallel_world_size(), 128) - return _pad_if_needed(numel_unpadded, divisor) - - def _pad_param_if_needed(numel_unpadded): - # Want 64-byte alignment for params. - return _pad_if_needed(numel_unpadded, 64) - - if bucket_size is None: - # If bucket_size is infinite (None), number of buckets should be 1. - if shared_embedding and use_distributed_optimizer: - assert len(param_and_grad_buffer.buckets) == 2 - else: - assert len(param_and_grad_buffer.buckets) == 1 - else: - # Else, compute number of buckets. - numel_in_each_bucket = [] - numel_padded_in_each_bucket = [] - numel_in_last_bucket = 0 - param_sizes = [] - for _ in range(num_layers): - param_sizes.append(input_dim * output_dim) - if bias: # Include bias term. - param_sizes.append(output_dim) - # Create separate bucket for first parameter from reverse direction. - if shared_embedding and use_distributed_optimizer: - numel_in_each_bucket.append(param_sizes[-1]) - numel_padded_in_each_bucket.append(_pad_bucket_if_needed(param_sizes[-1])) - param_sizes = param_sizes[:-1] - # Iterate through params in backward direction. - for param_size in param_sizes[::-1]: - numel_in_last_bucket = _pad_param_if_needed(numel_in_last_bucket) - numel_in_last_bucket += param_size - if numel_in_last_bucket >= bucket_size: - numel_in_each_bucket.append(numel_in_last_bucket) - numel_padded_in_each_bucket.append(_pad_bucket_if_needed(numel_in_last_bucket)) - numel_in_last_bucket = 0 - if numel_in_last_bucket > 0: - numel_in_each_bucket.append(numel_in_last_bucket) - numel_padded_in_each_bucket.append(_pad_bucket_if_needed(numel_in_last_bucket)) - - assert len(param_and_grad_buffer.buckets) == len( - numel_in_each_bucket - ), f"Buckets don't match (got {actual_numel_in_each_bucket} but should be {numel_in_each_bucket})" - assert actual_numel_in_each_bucket == numel_in_each_bucket, ( - f"Number of parameters in each bucket should be {numel_in_each_bucket}, " - f"but is {actual_numel_in_each_bucket}" - ) - if use_distributed_optimizer: - assert all( - [ - x % parallel_state.get_data_parallel_world_size() == 0 - for x in actual_numel_padded_in_each_bucket - ] - ), ( - f"Size of each padded bucket should be divisible by " - f"{parallel_state.get_data_parallel_world_size()}" - ) - assert actual_numel_padded_in_each_bucket == numel_padded_in_each_bucket, ( - f"Number of parameters in each padded bucket should be {numel_padded_in_each_bucket}, " - f"but is {actual_numel_padded_in_each_bucket}" - ) - - Utils.destroy_model_parallel() - - -@pytest.mark.parametrize("use_distributed_optimizer", [False, True]) -@pytest.mark.parametrize("overlap_grad_reduce", [False, True]) -@pytest.mark.parametrize("average_in_collective", [False, True]) -@pytest.mark.parametrize("num_distributed_optimizer_instances", [1, 2]) -# @pytest.mark.flaky -def test_grad_sync( - use_distributed_optimizer: bool, - overlap_grad_reduce: bool, - average_in_collective: bool, - num_distributed_optimizer_instances: int, -): - Utils.initialize_model_parallel( - num_distributed_optimizer_instances=num_distributed_optimizer_instances - ) - # Skip test if num_distributed_optimizer_instances > 1 and not using distributed optimizer - if num_distributed_optimizer_instances > 1 and not use_distributed_optimizer: - pytest.skip("Multiple optimizer instances require distributed optimizer to be enabled") - - input_dim = 100 - output_dim = 100 - num_layers = 10 - model, param_and_grad_buffer, bucket_groups = get_model_and_buffers( - input_dim=input_dim, - output_dim=output_dim, - num_layers=num_layers, - bias=True, - shared_embedding=False, - bucket_size=None, # Group all params into single bucket. - use_distributed_optimizer=use_distributed_optimizer, - overlap_grad_reduce=overlap_grad_reduce, - average_in_collective=average_in_collective, - num_distributed_optimizer_instances=num_distributed_optimizer_instances, - ) - param_to_bucket_group = {} - for bucket_group in bucket_groups: - for param in bucket_group.params: - assert param not in param_to_bucket_group - param_to_bucket_group[param] = bucket_group - - param_and_grad_buffer.grad_data.data.fill_(1.0) - expected_grad_data_value_after_collective = 1 - # under the following conditions, the data in param_and_grad_buffer.grad_data[0] equals to 1/DP - # this is because when average_in_collective=False, the grad data is always first scaled by 1/DP and then summed by AR/RS - # and when use_distributed_optimizer=True, only for rank=0 param_and_grad_buffer.grad_data[0] is updated, for other ranks - # another shard of grad_data is updated while param_and_grad_buffer.grad_data[0] is unchanged (=1/DP) - if ( - use_distributed_optimizer - and (not average_in_collective) - and parallel_state.get_data_parallel_rank( - with_context_parallel=True, partial_data_parallel=True - ) - != 0 - ): - expected_grad_data_value_after_collective /= parallel_state.get_data_parallel_world_size() - - params = list(model.parameters()) - for i, param in enumerate(params): - assert param in param_to_bucket_group - bucket_group = param_to_bucket_group[param] - register_grad_sync_context = ( - contextlib.nullcontext() if overlap_grad_reduce else pytest.raises(AssertionError) - ) - finish_grad_sync_context = contextlib.nullcontext() - if ( - i < (len(params) - 1) - and overlap_grad_reduce - and num_distributed_optimizer_instances == 1 - ): - # Can't finish grad sync until all params have been registered ready. - finish_grad_sync_context = pytest.raises(AssertionError) - - with register_grad_sync_context: - bucket_group.register_grad_ready(param) - with finish_grad_sync_context: - # When overlap_grad_reduce is True, this should throw an assertion error until all - # params in the model have registered their grad above. - # When overlap_grad_reduce is False, the collective is forced through. - bucket_group.finish_grad_sync() - - expected_grad_data_value = expected_grad_data_value_after_collective - if overlap_grad_reduce and i < (len(params) - 1): - expected_grad_data_value = 1 - assert param_and_grad_buffer.grad_data[0] == expected_grad_data_value - - if not overlap_grad_reduce: - # Reset grad_data for subsequent collectives. - param_and_grad_buffer.grad_data.data.fill_(1.0) - - Utils.destroy_model_parallel() diff --git a/tests/unit_tests/distributed/test_torch_fully_sharded_parallel.py b/tests/unit_tests/distributed/test_torch_fully_sharded_parallel.py deleted file mode 100644 index 51dc1d0bff..0000000000 --- a/tests/unit_tests/distributed/test_torch_fully_sharded_parallel.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -import pytest -import torch - -from megatron.core import parallel_state -from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig -from megatron.core.distributed.torch_fully_sharded_data_parallel import ( - TorchFullyShardedDataParallel, -) -from megatron.core.num_microbatches_calculator import ( - init_num_microbatches_calculator, - unset_num_microbatches_calculator, -) -from megatron.core.tensor_parallel import ColumnParallelLinear -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer import MegatronModule -from megatron.core.transformer.module import Float16Module -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import init_method_normal, is_torch_min_version -from tests.unit_tests.test_utilities import Utils - - -class DummyModel(MegatronModule): - """Setup a few modules to test the FSDP2 constructor.""" - - _fsdp_modules = [torch.nn.Linear] - - def __init__(self, config: TransformerConfig): - """Initialize a dummy model with a few modules.""" - super().__init__(config) - self.linear = torch.nn.Linear(2, 2) - self.column_parallel_linear = ColumnParallelLinear( - input_size=2, output_size=2, config=config, init_method=init_method_normal(0.02) - ) - self.conv = torch.nn.Conv2d(2, 2, 1) - - -@pytest.fixture -def init_model_parallel(): - """Init torch distributed.""" - Utils.initialize_model_parallel(1, 1) - init_num_microbatches_calculator(0, None, 1, 1, 1) - model_parallel_cuda_manual_seed(123) - yield # Run the actual test. - Utils.destroy_model_parallel() - unset_num_microbatches_calculator() - - -def test_fsdp2_constructor(init_model_parallel): - """Test the FSDP2 constructor.""" - if not is_torch_min_version("2.4.0"): - pytest.skip("FSDP2 is not supported on this version of PyTorch.") - - # Create a dummy model and configs. - config = TransformerConfig(num_layers=1, kv_channels=1, bf16=True) - ddp_config = DistributedDataParallelConfig() - model = DummyModel(config) - model = Float16Module(config, model) - ddp_config = DistributedDataParallelConfig() - - # Create the sharded model. - fsdp_model = TorchFullyShardedDataParallel(config, ddp_config, model) - - def _is_fsdp_wrapped_module(instance): - # FSDP adds a prefix to the class name. - return instance.__class__.__name__.startswith("FSDP") - - assert isinstance(fsdp_model, TorchFullyShardedDataParallel) - # We manually added Linear to the list of submodules to wrap. - assert _is_fsdp_wrapped_module(fsdp_model.module.module.linear) - # ColumnParallelLinear is in the default list of submodules to wrap. - assert _is_fsdp_wrapped_module(fsdp_model.module.module.column_parallel_linear) - # Conv2d is not in the list of submodules to wrap. - assert not _is_fsdp_wrapped_module(fsdp_model.module.module.conv) - - -def test_fsdp2_constructor_with_process_group(init_model_parallel): - """Test the FSDP2 constructor with explicit process group parameter.""" - if not is_torch_min_version("2.4.0"): - pytest.skip("FSDP2 is not supported on this version of PyTorch.") - - # Create a dummy model and configs. - config = TransformerConfig(num_layers=1, kv_channels=1, bf16=True) - ddp_config = DistributedDataParallelConfig() - model = DummyModel(config) - model = Float16Module(config, model) - - # Create a custom process group (using the default world for testing) - custom_process_group = parallel_state.get_data_parallel_group(with_context_parallel=True) - - # Create the sharded model with explicit process group - fsdp_model = TorchFullyShardedDataParallel( - config, ddp_config, model, process_group=custom_process_group - ) - - # Verify the process group was set correctly - assert fsdp_model.process_group is custom_process_group - - # Check that module wrapping still works correctly - def _is_fsdp_wrapped_module(instance): - return instance.__class__.__name__.startswith("FSDP") - - assert isinstance(fsdp_model, TorchFullyShardedDataParallel) - assert _is_fsdp_wrapped_module(fsdp_model.module.module.linear) - assert _is_fsdp_wrapped_module(fsdp_model.module.module.column_parallel_linear) - assert not _is_fsdp_wrapped_module(fsdp_model.module.module.conv) diff --git a/tests/unit_tests/export/trtllm/test_distributed_fp8.py b/tests/unit_tests/export/trtllm/test_distributed_fp8.py deleted file mode 100644 index cf47a86410..0000000000 --- a/tests/unit_tests/export/trtllm/test_distributed_fp8.py +++ /dev/null @@ -1,272 +0,0 @@ -from functools import partial - -import pytest -import torch -from pytest_mock import mocker -from torch.optim import Adam -from torch.utils.data import DataLoader - -from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder -from megatron.core.datasets.gpt_dataset import GPTDatasetConfig, MockGPTDataset -from megatron.core.datasets.utils import compile_helpers -from megatron.core.export.data_type import DataType -from megatron.core.export.export_config import ExportConfig -from megatron.core.export.model_type import ModelType -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.pipeline_parallel.schedules import get_forward_backward_func -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.training.tokenizer.tokenizer import _NullTokenizer -from tests.unit_tests.test_utilities import Utils - -VOCAB_SIZE = 256 -SEQUENCE_LENGTH = 64 -NUM_LAYERS = 2 -DEVICE = torch.device("cuda") -DTYPE = torch.bfloat16 - - -def _model_provider(): - """Build the model.""" - - transformer_config = TransformerConfig( - num_layers=2, - hidden_size=512, - num_attention_heads=16, - use_cpu_initialization=True, - num_query_groups=2, - fp8='hybrid', - fp8_margin=0, - fp8_interval=1, - fp8_amax_history_len=1024, - fp8_amax_compute_algo="max", - tensor_model_parallel_size=2, - ) - - gpt_model = GPTModel( - config=transformer_config, - transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec(), - vocab_size=VOCAB_SIZE, - max_sequence_length=SEQUENCE_LENGTH, - ) - - return gpt_model - - -def _get_train_data_iterator(): - if torch.distributed.is_available() and torch.distributed.is_initialized(): - if torch.distributed.get_rank() == 0: - compile_helpers() - torch.distributed.barrier() - else: - compile_helpers() - - config = GPTDatasetConfig( - random_seed=0, - sequence_length=SEQUENCE_LENGTH, - reset_position_ids=False, - reset_attention_mask=False, - eod_mask_loss=False, - tokenizer=_NullTokenizer(vocab_size=50), - mid_level_dataset_surplus=0.005, - ) - - datasets = BlendedMegatronDatasetBuilder( - MockGPTDataset, [1000, None, None], lambda: True, config - ).build() - - train_dataloader = DataLoader(datasets[0], batch_size=8, shuffle=True) - - train_iterator = iter(train_dataloader) - - return train_iterator - - -def _forward_step_func(data_iterator, model): - - def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor): - - losses = output_tensor.float() - loss_mask = loss_mask.view(-1).float() - loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() - # If you have data parallel reduce loss across data parallel groups. - # If pipeline parallel, loss computation is done only in last stage. - - return loss, {'lm loss': loss} - - data = next(data_iterator) - tokens = torch.ones_like(data['tokens']).to(DEVICE) - attention_mask = data['attention_mask'].to(DEVICE) - position_ids = data['position_ids'].to(DEVICE) - labels = data['labels'].to(DEVICE) - loss_mask = data['loss_mask'].to(DEVICE) - output_tensor = model(tokens, position_ids, attention_mask, labels=labels) - - return output_tensor, partial(loss_func, loss_mask) - - -class TestTRTLLMSingleDeviceConverterFP8: - QUANTIZED_LAYERS = [ - 'transformer.layers.*.attention.dense.weight', - 'transformer.layers.*.attention.qkv.weight', - 'transformer.layers.*.mlp.fc.weight', - 'transformer.layers.*.mlp.proj.weight', - ] - NON_QUANTIZED_LAYERS = [ - 'transformer.layers.*.attention.dense.bias', - 'transformer.layers.*.input_layernorm.weight', - 'transformer.layers.*.input_layernorm.bias', - 'transformer.layers.*.attention.qkv.bias', - 'transformer.layers.*.post_layernorm.weight', - 'transformer.layers.*.post_layernorm.bias', - 'transformer.layers.*.mlp.fc.bias', - 'transformer.layers.*.mlp.proj.bias', - 'transformer.vocab_embedding.weight', - 'transformer.position_embedding.weight', - 'lm_head.weight', - 'transformer.ln_f.weight', - 'transformer.ln_f.bias', - ] - SCALING_FACTORS = [ - 'transformer.layers.*.attention.dense.activation_scaling_factor', - 'transformer.layers.*.attention.dense.weights_scaling_factor', - 'transformer.layers.*.attention.qkv.activation_scaling_factor', - 'transformer.layers.*.attention.qkv.weights_scaling_factor', - 'transformer.layers.*.mlp.fc.activation_scaling_factor', - 'transformer.layers.*.mlp.fc.weights_scaling_factor', - 'transformer.layers.*.mlp.proj.activation_scaling_factor', - 'transformer.layers.*.mlp.proj.weights_scaling_factor', - ] - KV_SCALING_FACTORS = ['transformer.layers.*.attention.kv_cache_scaling_factor'] - - def _assert_has_scales(self, state_dict, quantized): - for layer in range(NUM_LAYERS): - for key in self.SCALING_FACTORS: - k = key.replace('*', str(layer)) - - if quantized: - assert k in state_dict, f'Expected {k} in the converted model' - assert ( - state_dict[k].dtype == torch.float32 - ), 'Scaling factor dtype is expected to be torch.float32' - else: - assert k not in state_dict, f'Did not expect {k} in the converted model' - - def _assert_has_kv_scales(self, state_dict, kv_quantized): - for layer in range(NUM_LAYERS): - for key in self.KV_SCALING_FACTORS: - k = key.replace('*', str(layer)) - - if kv_quantized: - assert k in state_dict, f'Expected {k} in the converted model' - assert ( - state_dict[k].dtype == torch.float32 - ), 'Scaling factor dtype is expected to be torch.float32' - else: - assert k not in state_dict, f'Did not expect {k} in the converted model' - - def _assert_quantizable_layers(self, state_dict, quantized): - expected_dtype = torch.float8_e4m3fn if quantized else DTYPE - - for layer in range(NUM_LAYERS): - for key in self.QUANTIZED_LAYERS: - k = key.replace('*', str(layer)) - - assert k in state_dict, f'Expected {k} in the converted model' - assert ( - state_dict[k].dtype == expected_dtype - ), f'Expected {k} to have the dtype == {str(expected_dtype)}' - - def _assert_non_quantizable_layers(self, state_dict): - expected_dtype = torch.bfloat16 - - for layer in range(NUM_LAYERS): - for key in self.NON_QUANTIZED_LAYERS: - k = key.replace('*', str(layer)) - - assert k in state_dict, f'Expected {k} in the converted model' - assert ( - state_dict[k].dtype == expected_dtype - ), f'Expected {k} to have the dtype == {str(expected_dtype)}' - - def setup_method(self, method): - Utils.initialize_model_parallel(2, 1) - gpt_model = _model_provider() - gpt_model.to(DEVICE) - optim = Adam(gpt_model.parameters()) - train_iterator = _get_train_data_iterator() - forward_backward_func = get_forward_backward_func() - - # Mock training to initialize constants - for _ in range(2): - optim.zero_grad() - forward_backward_func( - forward_step_func=_forward_step_func, - data_iterator=train_iterator, - model=gpt_model, - num_microbatches=1, - seq_length=SEQUENCE_LENGTH, - micro_batch_size=8, - decoder_seq_length=SEQUENCE_LENGTH, - forward_only=False, - ) - optim.step() - - self.gpt_model = gpt_model - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_get_model_weights_converter(self, mocker): - pytest.importorskip('tensorrt_llm') - mocker.patch( - "megatron.core.export.trtllm.trtllm_weights_converter.distributed_trtllm_model_weights_converter.str_dtype_to_torch", - return_value=DTYPE, - ) - - from megatron.core.export.trtllm.trtllm_helper import TRTLLMHelper - - gpt_model = self.gpt_model - seq_len_interpolation_factor = None - if hasattr(gpt_model, "rotary_pos_emb"): - seq_len_interpolation_factor = gpt_model.rotary_pos_emb.seq_len_interpolation_factor - trtllm_helper = TRTLLMHelper( - transformer_config=gpt_model.config, - model_type=ModelType.gpt, - position_embedding_type=gpt_model.position_embedding_type, - max_position_embeddings=gpt_model.max_position_embeddings, - rotary_percentage=gpt_model.rotary_percent, - rotary_base=gpt_model.rotary_base, - moe_tp_mode=2, - multi_query_mode=False, - activation="gelu", - seq_len_interpolation_factor=seq_len_interpolation_factor, - share_embeddings_and_output_weights=gpt_model.share_embeddings_and_output_weights, - ) - - for fp8_quantized in [True, False]: - for fp8_kvcache in [True, False]: - weight_list, config_list = ( - trtllm_helper.get_trtllm_pretrained_config_and_model_weights( - model_state_dict=gpt_model.state_dict(), - dtype=DataType.bfloat16, - on_device_distributed_conversion=True, - vocab_size=VOCAB_SIZE, - gpus_per_node=2, - fp8_quantized=fp8_quantized, - fp8_kvcache=fp8_kvcache, - ) - ) - - expected_quant = 'FP8' if fp8_quantized else None - expected_kv_quant = 'FP8' if fp8_kvcache else None - assert ( - config_list[0].quantization.quant_algo == expected_quant - ), 'Wrong quantization settings' - assert ( - config_list[0].quantization.kv_cache_quant_algo == expected_kv_quant - ), 'Wrong KV-cache quantization settings' - self._assert_has_scales(weight_list[0], fp8_quantized) - self._assert_has_kv_scales(weight_list[0], fp8_kvcache) - self._assert_quantizable_layers(weight_list[0], fp8_quantized) - self._assert_non_quantizable_layers(weight_list[0]) diff --git a/tests/unit_tests/export/trtllm/test_single_device_fp8.py b/tests/unit_tests/export/trtllm/test_single_device_fp8.py deleted file mode 100644 index 04bbfdb127..0000000000 --- a/tests/unit_tests/export/trtllm/test_single_device_fp8.py +++ /dev/null @@ -1,269 +0,0 @@ -from functools import partial - -import pytest -import torch -from pytest_mock import mocker -from torch.optim import Adam -from torch.utils.data import DataLoader - -from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder -from megatron.core.datasets.gpt_dataset import GPTDatasetConfig, MockGPTDataset -from megatron.core.datasets.utils import compile_helpers -from megatron.core.export.data_type import DataType -from megatron.core.export.export_config import ExportConfig -from megatron.core.export.model_type import ModelType -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.pipeline_parallel.schedules import get_forward_backward_func -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.training.tokenizer.tokenizer import _NullTokenizer -from tests.unit_tests.test_utilities import Utils - -SEQUENCE_LENGTH = 64 -NUM_LAYERS = 2 -DEVICE = torch.device("cuda") - - -def _model_provider(): - """Build the model.""" - - transformer_config = TransformerConfig( - num_layers=NUM_LAYERS, - hidden_size=64, - num_attention_heads=2, - use_cpu_initialization=True, - pipeline_dtype=torch.float32, - fp8='hybrid', - fp8_margin=0, - fp8_interval=1, - fp8_amax_history_len=1024, - fp8_amax_compute_algo="max", - ) - - gpt_model = GPTModel( - config=transformer_config, - transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec(), - vocab_size=100, - max_sequence_length=SEQUENCE_LENGTH, - ) - - return gpt_model - - -def _get_train_data_iterator(): - if torch.distributed.is_available() and torch.distributed.is_initialized(): - if torch.distributed.get_rank() == 0: - compile_helpers() - torch.distributed.barrier() - else: - compile_helpers() - - config = GPTDatasetConfig( - random_seed=0, - sequence_length=SEQUENCE_LENGTH, - reset_position_ids=False, - reset_attention_mask=False, - eod_mask_loss=False, - tokenizer=_NullTokenizer(vocab_size=50), - mid_level_dataset_surplus=0.005, - ) - - datasets = BlendedMegatronDatasetBuilder( - MockGPTDataset, [1000, None, None], lambda: True, config - ).build() - - train_dataloader = DataLoader(datasets[0], batch_size=8, shuffle=True) - - train_iterator = iter(train_dataloader) - - return train_iterator - - -def _forward_step_func(data_iterator, model): - - def _loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor): - - losses = output_tensor.float() - loss_mask = loss_mask.view(-1).float() - loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() - # If you have data parallel reduce loss across data parallel groups. - # If pipeline parallel, loss computation is done only in last stage. - - return loss, {'lm loss': loss} - - data = next(data_iterator) - tokens = torch.ones_like(data['tokens']).to(DEVICE) - attention_mask = data['attention_mask'].to(DEVICE) - position_ids = data['position_ids'].to(DEVICE) - labels = data['labels'].to(DEVICE) - loss_mask = data['loss_mask'].to(DEVICE) - output_tensor = model(tokens, position_ids, attention_mask, labels=labels) - - return output_tensor, partial(_loss_func, loss_mask) - - -class TestTRTLLMSingleDeviceConverterFP8: - QUANTIZED_LAYERS = [ - 'transformer.layers.*.attention.dense.weight', - 'transformer.layers.*.attention.qkv.weight', - 'transformer.layers.*.mlp.fc.weight', - 'transformer.layers.*.mlp.proj.weight', - ] - NON_QUANTIZED_LAYERS = [ - 'transformer.layers.*.attention.dense.bias', - 'transformer.layers.*.input_layernorm.weight', - 'transformer.layers.*.input_layernorm.bias', - 'transformer.layers.*.attention.qkv.bias', - 'transformer.layers.*.post_layernorm.weight', - 'transformer.layers.*.post_layernorm.bias', - 'transformer.layers.*.mlp.fc.bias', - 'transformer.layers.*.mlp.proj.bias', - 'transformer.vocab_embedding.weight', - 'transformer.position_embedding.weight', - 'lm_head.weight', - 'transformer.ln_f.weight', - 'transformer.ln_f.bias', - ] - SCALING_FACTORS = [ - 'transformer.layers.*.attention.dense.activation_scaling_factor', - 'transformer.layers.*.attention.dense.weights_scaling_factor', - 'transformer.layers.*.attention.qkv.activation_scaling_factor', - 'transformer.layers.*.attention.qkv.weights_scaling_factor', - 'transformer.layers.*.mlp.fc.activation_scaling_factor', - 'transformer.layers.*.mlp.fc.weights_scaling_factor', - 'transformer.layers.*.mlp.proj.activation_scaling_factor', - 'transformer.layers.*.mlp.proj.weights_scaling_factor', - ] - KV_SCALING_FACTORS = ['transformer.layers.*.attention.kv_cache_scaling_factor'] - - def _assert_has_scales(self, state_dict, quantized): - for layer in range(NUM_LAYERS): - for key in self.SCALING_FACTORS: - k = key.replace('*', str(layer)) - - if quantized: - assert k in state_dict, f'Expected {k} in the converted model' - assert ( - state_dict[k].dtype == torch.float32 - ), 'Scaling factor dtype is expected to be torch.float32' - else: - assert k not in state_dict, f'Did not expect {k} in the converted model' - - def _assert_has_kv_scales(self, state_dict, kv_quantized): - for layer in range(NUM_LAYERS): - for key in self.KV_SCALING_FACTORS: - k = key.replace('*', str(layer)) - - if kv_quantized: - assert k in state_dict, f'Expected {k} in the converted model' - assert ( - state_dict[k].dtype == torch.float32 - ), 'Scaling factor dtype is expected to be torch.float32' - else: - assert k not in state_dict, f'Did not expect {k} in the converted model' - - def _assert_quantizable_layers(self, state_dict, quantized): - expected_dtype = torch.float8_e4m3fn if quantized else torch.bfloat16 - - for layer in range(NUM_LAYERS): - for key in self.QUANTIZED_LAYERS: - k = key.replace('*', str(layer)) - - assert k in state_dict, f'Expected {k} in the converted model' - assert ( - state_dict[k].dtype == expected_dtype - ), f'Expected {k} to have the dtype == {str(expected_dtype)}' - - def _assert_non_quantizable_layers(self, state_dict): - expected_dtype = torch.bfloat16 - - for layer in range(NUM_LAYERS): - for key in self.NON_QUANTIZED_LAYERS: - k = key.replace('*', str(layer)) - - assert k in state_dict, f'Expected {k} in the converted model' - assert ( - state_dict[k].dtype == expected_dtype - ), f'Expected {k} to have the dtype == {str(expected_dtype)}' - - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - gpt_model = _model_provider() - gpt_model.to(DEVICE) - optim = Adam(gpt_model.parameters()) - train_iterator = _get_train_data_iterator() - forward_backward_func = get_forward_backward_func() - - # Mock training to initialize constants - for _ in range(2): - optim.zero_grad() - forward_backward_func( - forward_step_func=_forward_step_func, - data_iterator=train_iterator, - model=gpt_model, - num_microbatches=1, - seq_length=SEQUENCE_LENGTH, - micro_batch_size=8, - decoder_seq_length=SEQUENCE_LENGTH, - forward_only=False, - ) - optim.step() - - self.gpt_model = gpt_model - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_get_model_weights_converter(self, mocker): - pytest.importorskip('tensorrt_llm') - mocker.patch( - "megatron.core.export.trtllm.trtllm_weights_converter.distributed_trtllm_model_weights_converter.str_dtype_to_torch", - return_value=torch.float32, - ) - - from megatron.core.export.trtllm.trtllm_helper import TRTLLMHelper - - gpt_model = self.gpt_model - export_config = ExportConfig(inference_tp_size=2) - - seq_len_interpolation_factor = None - if hasattr(gpt_model, "rotary_pos_emb"): - seq_len_interpolation_factor = gpt_model.rotary_pos_emb.seq_len_interpolation_factor - trtllm_helper = TRTLLMHelper( - transformer_config=gpt_model.config, - model_type=ModelType.gpt, - position_embedding_type=gpt_model.position_embedding_type, - max_position_embeddings=gpt_model.max_position_embeddings, - rotary_percentage=gpt_model.rotary_percent, - rotary_base=gpt_model.rotary_base, - moe_tp_mode=2, - multi_query_mode=False, - activation="gelu", - seq_len_interpolation_factor=seq_len_interpolation_factor, - share_embeddings_and_output_weights=gpt_model.share_embeddings_and_output_weights, - ) - - for fp8_quantized in [True, False]: - for fp8_kvcache in [True, False]: - weight_list, config_list = ( - trtllm_helper.get_trtllm_pretrained_config_and_model_weights( - model_state_dict=gpt_model.state_dict(), - dtype=DataType.bfloat16, - export_config=export_config, - fp8_quantized=fp8_quantized, - fp8_kvcache=fp8_kvcache, - ) - ) - - expected_quant = 'FP8' if fp8_quantized else None - expected_kv_quant = 'FP8' if fp8_kvcache else None - assert ( - config_list[0].quantization.quant_algo == expected_quant - ), 'Wrong quantization settings' - assert ( - config_list[0].quantization.kv_cache_quant_algo == expected_kv_quant - ), 'Wrong KV-cache quantization settings' - self._assert_has_scales(weight_list[0], fp8_quantized) - self._assert_has_kv_scales(weight_list[0], fp8_kvcache) - self._assert_quantizable_layers(weight_list[0], fp8_quantized) - self._assert_non_quantizable_layers(weight_list[0]) diff --git a/tests/unit_tests/export/trtllm/test_trtllm_distributed_gpu_converter.py b/tests/unit_tests/export/trtllm/test_trtllm_distributed_gpu_converter.py deleted file mode 100644 index 6a5ccb04a2..0000000000 --- a/tests/unit_tests/export/trtllm/test_trtllm_distributed_gpu_converter.py +++ /dev/null @@ -1,115 +0,0 @@ -import torch -from pytest_mock import mocker - -from megatron.core.export.data_type import DataType -from megatron.core.export.trtllm.model_to_trllm_mapping.default_conversion_dict import ( - DEFAULT_CONVERSION_DICT, -) - -# pylint: disable=line-too-long -from megatron.core.export.trtllm.trtllm_weights_converter.distributed_trtllm_model_weights_converter import ( - DistributedTRTLLMModelWeightsConverter, -) -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.transformer_config import TransformerConfig -from tests.unit_tests.test_utilities import Utils - -_SEQUENCE_LENGTH = 64 -_VOCAB_SIZE = 256 - - -class TestTRTLLMDistributedGPUConverter: - """ - Test Distributed converter - """ - - def setup_method(self, method): - """ - Setup method - """ - Utils.initialize_model_parallel(2, 1) - model_parallel_cuda_manual_seed(123) - - transformer_config = TransformerConfig( - num_layers=2, - hidden_size=64, - num_attention_heads=2, - use_cpu_initialization=True, - pipeline_dtype=torch.float32, - add_qkv_bias=False, - add_bias_linear=False, - ) - self.gpt_model = GPTModel( - config=transformer_config, - transformer_layer_spec=get_gpt_layer_local_spec(), - vocab_size=_VOCAB_SIZE, - max_sequence_length=_SEQUENCE_LENGTH, - ) - - def teardown_method(self, method): - """ - teardown method - """ - Utils.destroy_model_parallel() - - def test_get_model_weights_converter(self, mocker): - """ - test model weights onverter - """ - device = torch.device("cuda") - self.gpt_model.to(device) - - transformer_config = self.gpt_model.config - - mocker.patch( - "megatron.core.export.trtllm.trtllm_weights_converter.distributed_trtllm_model_weights_converter.str_dtype_to_torch", - return_value=torch.float32, - ) - - dtype = DataType.bfloat16 - distributed_converter = DistributedTRTLLMModelWeightsConverter( - transformer_config, dtype, activation="gelu" - ) - - model_state_dict = {} - for key, val in self.gpt_model.state_dict().items(): - # val is non for _extra_state layers . We filter it out - if val is not None: - model_state_dict[key] = val - - distributed_converter.convert( - model_state_dict=model_state_dict, - trtllm_conversion_dict=DEFAULT_CONVERSION_DICT, - tokenizer_vocab_size=_VOCAB_SIZE, - ) - - expected_result = { - 'transformer.vocab_embedding.weight': torch.Size([128, 64]), - 'transformer.position_embedding.weight': torch.Size([32, 64]), - 'lm_head.weight': torch.Size([128, 64]), - 'transformer.ln_f.weight': torch.Size([64]), - 'transformer.ln_f.bias': torch.Size([64]), - 'transformer.layers.0.input_layernorm.weight': torch.Size([64]), - 'transformer.layers.0.input_layernorm.bias': torch.Size([64]), - 'transformer.layers.0.attention.dense.weight': torch.Size([64, 32]), - 'transformer.layers.0.attention.qkv.weight': torch.Size([96, 64]), - 'transformer.layers.0.post_layernorm.weight': torch.Size([64]), - 'transformer.layers.0.post_layernorm.bias': torch.Size([64]), - 'transformer.layers.0.mlp.fc.weight': torch.Size([128, 64]), - 'transformer.layers.0.mlp.proj.weight': torch.Size([64, 128]), - 'transformer.layers.1.input_layernorm.weight': torch.Size([64]), - 'transformer.layers.1.input_layernorm.bias': torch.Size([64]), - 'transformer.layers.1.attention.dense.weight': torch.Size([64, 32]), - 'transformer.layers.1.attention.qkv.weight': torch.Size([96, 64]), - 'transformer.layers.1.post_layernorm.weight': torch.Size([64]), - 'transformer.layers.1.post_layernorm.bias': torch.Size([64]), - 'transformer.layers.1.mlp.fc.weight': torch.Size([128, 64]), - 'transformer.layers.1.mlp.proj.weight': torch.Size([64, 128]), - } - - for key, value in distributed_converter.trtllm_model_weights.items(): - assert ( - expected_result[key] == value.shape - ), f"Shape mismatch for {key}. Expected {expected_result[key]} but got {value.shape}" diff --git a/tests/unit_tests/export/trtllm/test_trtllm_helper.py b/tests/unit_tests/export/trtllm/test_trtllm_helper.py deleted file mode 100644 index d9764dc8fd..0000000000 --- a/tests/unit_tests/export/trtllm/test_trtllm_helper.py +++ /dev/null @@ -1,72 +0,0 @@ -import pytest - -from megatron.core.export.export_config import ExportConfig -from megatron.core.export.model_type import ModelType - - -# TODO : Remove importorskip and handle with mocker -class TestTRTLLMHelper: - - def test_exceptions(self, mocker): - pytest.importorskip('tensorrt_llm') - - from megatron.core.export.trtllm.trtllm_helper import TRTLLMHelper - - trtllm_helper = TRTLLMHelper( - transformer_config=None, - model_type=ModelType.gpt, - share_embeddings_and_output_weights=True, - ) - - with pytest.raises(AssertionError): - trtllm_helper.get_trtllm_pretrained_config_and_model_weights( - model_state_dict=None, - dtype=None, - on_device_distributed_conversion=True, - vocab_size=None, - gpus_per_node=2, - ) - - with pytest.raises(AssertionError): - trtllm_helper.get_trtllm_pretrained_config_and_model_weights( - model_state_dict=None, - dtype=None, - on_device_distributed_conversion=True, - vocab_size=100, - gpus_per_node=2, - ) - - with pytest.raises(AssertionError): - trtllm_helper.get_trtllm_pretrained_config_and_model_weights( - model_state_dict=None, - dtype=None, - export_config=ExportConfig(), - on_device_distributed_conversion=True, - vocab_size=100, - gpus_per_node=2, - ) - - with pytest.raises(AssertionError): - trtllm_helper.get_trtllm_pretrained_config_and_model_weights( - model_state_dict=None, - dtype=None, - vocab_size=100, - on_device_distributed_conversion=True, - gpus_per_node=None, - ) - - with pytest.raises(AssertionError): - trtllm_helper.get_trtllm_pretrained_config_and_model_weights( - model_state_dict=None, - dtype=None, - export_config=ExportConfig(use_embedding_sharing=False), - on_device_distributed_conversion=False, - ) - - with pytest.raises(AssertionError): - trtllm_helper.get_trtllm_pretrained_config_and_model_weights( - model_state_dict=None, - dtype=None, - export_config=ExportConfig(use_embedding_sharing=True), - vocab_size=100, - ) diff --git a/tests/unit_tests/export/trtllm/test_trtllm_layers.py b/tests/unit_tests/export/trtllm/test_trtllm_layers.py deleted file mode 100644 index b2e88852e5..0000000000 --- a/tests/unit_tests/export/trtllm/test_trtllm_layers.py +++ /dev/null @@ -1,111 +0,0 @@ -import pytest - -from megatron.core.export.trtllm.trtllm_layers import TRTLLMLayers, get_layer_name_without_prefix - - -class TestTRTLLMLayers: - - def test_rename_input_layer_names_to_trtllm_layer_names_without_layer_numbers(self): - - conversion_dict = { - "transformer.layers.attn.dense.bias": TRTLLMLayers.attention_dense_bias, - "transformer.layers.mlp.fc1.weight": TRTLLMLayers.mlp_fc_weight, - } - sample_dict = { - "transformer.layers.attn.dense.bias": 0, - "transformer.layers.mlp.fc1.weight": 1, - } - - converted_dict = TRTLLMLayers.rename_input_layer_names_to_trtllm_layer_names( - model_state_dict=sample_dict, - trtllm_conversion_dict=conversion_dict, - state_dict_split_by_layer_numbers=False, - ) - assert ( - converted_dict[TRTLLMLayers.attention_dense_bias.value] == 0 - ), "Something wrong with conversion dict" - assert ( - converted_dict[TRTLLMLayers.mlp_fc_weight.value] == 1 - ), "Something wrong with conversion dict" - - def test_rename_input_layer_names_to_trtllm_layer_names_exception(self): - - with pytest.raises(AssertionError): - conversion_dict = { - "transformer.layers.attn.dense.bias": "randomValue", - "transformer.layers.mlp.fc1.weight": TRTLLMLayers.mlp_fc_weight, - } - sample_dict = { - "transformer.layers.attn.dense.bias": 0, - "transformer.layers.mlp.fc1.weight": 1, - } - TRTLLMLayers.rename_input_layer_names_to_trtllm_layer_names( - model_state_dict=sample_dict, - trtllm_conversion_dict=conversion_dict, - state_dict_split_by_layer_numbers=False, - ) - - with pytest.raises(Exception): - sample_dict = { - "transformer.layers.attn.dense.bias": 0, - "transformer.layers.mlp.fc1.weight": 1, - } - del conversion_dict["attn.dense.bias"] - TRTLLMLayers.rename_input_layer_names_to_trtllm_layer_names( - model_state_dict=sample_dict, - trtllm_conversion_dict=conversion_dict, - state_dict_split_by_layer_numbers=False, - ) - - with pytest.raises(Exception): - conversion_dict = { - "transformer.layers.attn.dense.bias": TRTLLMLayers.attention_dense_bias, - "transformer.layers.mlp.fc1.weight": TRTLLMLayers.mlp_fc_weight, - } - sample_dict = { - "transformer.layers.attn.dense.bias": 0, - "transformer.layers.mlp.fc1.weight": 1, - } - - TRTLLMLayers.rename_input_layer_names_to_trtllm_layer_names( - model_state_dict=sample_dict, - trtllm_conversion_dict=conversion_dict, - state_dict_split_by_layer_numbers=True, - ) - - def test_rename_input_layer_names_to_trtllm_layer_names_with_layer_numbers(self): - - conversion_dict = { - "decoder.lm_head.weight": TRTLLMLayers.lm_head, - "decoder.layers.attn.dense.bias": TRTLLMLayers.attention_dense_bias, - "deocder.layers.mlp.fc1.weight": TRTLLMLayers.mlp_fc_weight, - } - sample_dict = { - "decoder.lm_head.weight": 2, - "decoder.layers.0.attn.dense.bias": 0, - "deocder.layers.43.mlp.fc1.weight": 1, - } - - converted_dict = TRTLLMLayers.rename_input_layer_names_to_trtllm_layer_names( - model_state_dict=sample_dict, - trtllm_conversion_dict=conversion_dict, - state_dict_split_by_layer_numbers=False, - ) - - assert ( - converted_dict['transformer.layers.0.attention.dense.bias'] == 0 - ), "Something wrong with conversion of layer names" - assert ( - converted_dict['transformer.layers.43.mlp.fc.weight'] == 1 - ), "Something wrong with conversion of layer names" - assert ( - converted_dict['lm_head.weight'] == 2 - ), "Something wrong with conversion of layer names" - - def test_get_layer_name_without_prefix(self): - layer_name_without_prefix = get_layer_name_without_prefix( - TRTLLMLayers.attention_dense_weight - ) - assert ( - layer_name_without_prefix == "attention.dense.weight" - ), f"get_layer_name_without_prefix returned {layer_name_without_prefix}, expected attention.dense.weight" diff --git a/tests/unit_tests/export/trtllm/test_trtllm_single_device_converter.py b/tests/unit_tests/export/trtllm/test_trtllm_single_device_converter.py deleted file mode 100644 index 733eed1745..0000000000 --- a/tests/unit_tests/export/trtllm/test_trtllm_single_device_converter.py +++ /dev/null @@ -1,612 +0,0 @@ -import pytest -import torch -from pytest_mock import mocker - -from megatron.core.export.data_type import DataType -from megatron.core.export.export_config import ExportConfig -from megatron.core.export.trtllm.trtllm_layers import TRTLLMLayers -from megatron.core.export.trtllm.trtllm_weights_converter.single_device_trtllm_model_weights_converter import ( - SingleDeviceTRTLLMModelWeightsConverter, -) -from megatron.core.transformer.transformer_config import TransformerConfig - - -class TestTRTLLMSingleDeviceConverter: - def test_get_model_weights_converter(self, mocker): - - export_config = ExportConfig(inference_tp_size=2) - - vocab_size = 10 - hidden_dim = 4 - seq_len = 8 - num_layers = 2 - num_attn_heads = 2 - - model_config = TransformerConfig( - num_layers=num_layers, - num_attention_heads=num_attn_heads, - num_query_groups=0, - hidden_size=hidden_dim, - ffn_hidden_size=hidden_dim * 4, - ) - - dtype = DataType.bfloat16 - - model_state_dict = { - "decoder.position_embedding.weight": torch.randn(seq_len, hidden_dim), - "decoder.word_embedding.weight": torch.randn(vocab_size, hidden_dim), - "decoder.lm_head.weight": torch.randn(vocab_size, hidden_dim), - "decoder.final_layernorm.weight": torch.randn(hidden_dim), - "decoder.layers.input_layernorm.weight": torch.randn(num_layers, hidden_dim), - "decoder.layers.attention.qkv.weight": torch.randn( - num_layers, hidden_dim * 3, hidden_dim - ), - "decoder.layers.attention.qkv.bias": torch.randn(num_layers, hidden_dim * 3), - "decoder.layers.attention.dense.weight": torch.randn( - num_layers, hidden_dim, hidden_dim - ), - "deocder.layers.mlp.fc.weight": torch.randn(num_layers, 4 * hidden_dim, hidden_dim), - "decoder.layers.mlp.fc.expert": torch.randn(num_layers, hidden_dim, hidden_dim * 4), - "decoder.layers.mlp.proj.expert": torch.randn(num_layers, hidden_dim * 4, hidden_dim), - } - - trtllm_conversion_dict = { - "decoder.position_embedding.weight": TRTLLMLayers.position_embedding, - "decoder.word_embedding.weight": TRTLLMLayers.vocab_embedding, - "decoder.final_layernorm.weight": TRTLLMLayers.final_layernorm_weight, - "decoder.lm_head.weight": TRTLLMLayers.lm_head, - "decoder.layers.input_layernorm.weight": TRTLLMLayers.input_layernorm_weight, - "decoder.layers.attention.qkv.weight": TRTLLMLayers.attention_qkv_weight, - "decoder.layers.attention.qkv.bias": TRTLLMLayers.attention_qkv_bias, - "decoder.layers.attention.dense.weight": TRTLLMLayers.attention_dense_weight, - "deocder.layers.mlp.fc.weight": TRTLLMLayers.mlp_fc_weight, - "decoder.layers.mlp.fc.expert": TRTLLMLayers.mlp_fc_weight_mixture_of_experts, - "decoder.layers.mlp.proj.expert": TRTLLMLayers.mlp_projection_weight_mixture_of_experts, - } - - mocker.patch( - "megatron.core.export.trtllm.trtllm_weights_converter.single_device_trtllm_model_weights_converter.str_dtype_to_torch", - return_value=torch.float32, - ) - - trtllm_model_weights_converter_cpu = SingleDeviceTRTLLMModelWeightsConverter( - export_config, model_config, dtype, activation="swiglu" - ) - - mocker.patch( - "megatron.core.export.trtllm.trtllm_weights_converter.single_device_trtllm_model_weights_converter.pad_vocab_size", - return_value=10, - ) - - trtllm_model_weights_converter_cpu.convert( - model_state_dict=model_state_dict, - trtllm_conversion_dict=trtllm_conversion_dict, - state_dict_split_by_layer_numbers=False, - ) - - expected_shapes = { - 'transformer.vocab_embedding.weight': (10, 4), - 'transformer.position_embedding.weight': (8, 4), - 'lm_head.weight': (10, 4), - 'transformer.ln_f.weight': (4,), - 'transformer.layers.0.input_layernorm.weight': (4,), - 'transformer.layers.1.input_layernorm.weight': (4,), - 'transformer.layers.0.attention.qkv.weight.0.bin': (6, 4), - 'transformer.layers.0.attention.qkv.weight.1.bin': (6, 4), - 'transformer.layers.1.attention.qkv.weight.0.bin': (6, 4), - 'transformer.layers.1.attention.qkv.weight.1.bin': (6, 4), - 'transformer.layers.0.attention.qkv.bias.0.bin': (6,), - 'transformer.layers.0.attention.qkv.bias.1.bin': (6,), - 'transformer.layers.1.attention.qkv.bias.0.bin': (6,), - 'transformer.layers.1.attention.qkv.bias.1.bin': (6,), - 'transformer.layers.0.attention.dense.weight.0.bin': (4, 2), - 'transformer.layers.0.attention.dense.weight.1.bin': (4, 2), - 'transformer.layers.1.attention.dense.weight.0.bin': (4, 2), - 'transformer.layers.1.attention.dense.weight.1.bin': (4, 2), - 'transformer.layers.0.mlp.gate.weight.0.bin': (4, 4), - 'transformer.layers.0.mlp.gate.weight.1.bin': (4, 4), - 'transformer.layers.0.mlp.fc.weight.0.bin': (16, 2), - 'transformer.layers.0.mlp.fc.weight.1.bin': (16, 2), - 'transformer.layers.1.mlp.gate.weight.0.bin': (4, 4), - 'transformer.layers.1.mlp.gate.weight.1.bin': (4, 4), - 'transformer.layers.1.mlp.fc.weight.0.bin': (16, 2), - 'transformer.layers.1.mlp.fc.weight.1.bin': (16, 2), - 'transformer.layers.0.mlp.proj.weight.0.bin': (4, 8), - 'transformer.layers.0.mlp.proj.weight.1.bin': (4, 8), - 'transformer.layers.1.mlp.proj.weight.0.bin': (4, 8), - 'transformer.layers.1.mlp.proj.weight.1.bin': (4, 8), - } - - for key, value in trtllm_model_weights_converter_cpu.trtllm_model_weights.items(): - assert ( - expected_shapes[key] == value.shape - ), f"Shape mismatch for {key}. Expected {expected_shapes[key]} but got {value.shape}" - - class SampleMapping: - - def __init__(self): - self.tp_size = 2 - self.tp_rank = 1 - - def pp_layers(self, num_layers): - return [0, 1] - - def is_first_pp_rank(self): - return True - - def is_last_pp_rank(self): - return True - - trtllm_model_weights_per_gpu = ( - trtllm_model_weights_converter_cpu.get_local_model_weights_per_gpu( - mapping=SampleMapping(), trtllm_model_config=None - ) - ) - - expected_result_per_gpu = { - 'transformer.layers.0.input_layernorm.weight': (4,), - 'transformer.layers.1.input_layernorm.weight': (4,), - 'transformer.layers.0.attention.qkv.weight': (6, 4), - 'transformer.layers.1.attention.qkv.weight': (6, 4), - 'transformer.layers.0.attention.qkv.bias': (6,), - 'transformer.layers.1.attention.qkv.bias': (6,), - 'transformer.layers.0.attention.dense.weight': (4, 2), - 'transformer.layers.1.attention.dense.weight': (4, 2), - 'transformer.layers.0.mlp.gate.weight': (4, 4), - 'transformer.layers.0.mlp.fc.weight': (16, 2), - 'transformer.layers.1.mlp.gate.weight': (4, 4), - 'transformer.layers.1.mlp.fc.weight': (16, 2), - 'transformer.layers.0.mlp.proj.weight': (4, 8), - 'transformer.layers.1.mlp.proj.weight': (4, 8), - 'transformer.vocab_embedding.weight': (10, 4), - 'transformer.position_embedding.weight': (8, 4), - 'lm_head.weight': (5, 4), - 'transformer.ln_f.weight': (4,), - } - - for key, value in trtllm_model_weights_per_gpu.items(): - assert ( - expected_result_per_gpu[key] == value.shape - ), f"Shape mismatch for {key}. Expected {expected_result_per_gpu[key]} but got {value.shape}" - - def test_num_kv_heads_less_than_tp_size_valid(self, mocker): - """Test the condition where num_kv_heads < inference_tp_size and tp_size % num_kv_heads == 0 (valid case)""" - - # Configure for GQA: 8 attention heads, 2 KV heads, TP size 4 - # This is valid because 4 % 2 == 0 - export_config = ExportConfig(inference_tp_size=4) - - vocab_size = 10 - hidden_dim = 8 - seq_len = 8 - num_layers = 2 - num_attn_heads = 8 - num_kv_heads = 2 # This is less than tp_size (4) and 4 % 2 == 0 - - model_config = TransformerConfig( - num_layers=num_layers, - num_attention_heads=num_attn_heads, - num_query_groups=num_kv_heads, # GQA with 2 KV heads - hidden_size=hidden_dim, - ffn_hidden_size=hidden_dim * 4, - ) - - dtype = DataType.bfloat16 - - # Create model state dict with GQA structure - # For GQA: q_num = num_attn_heads // num_kv_heads = 8 // 2 = 4 - # So each KV head handles 4 query heads - q_num = num_attn_heads // num_kv_heads # 4 - size_per_head = hidden_dim // num_attn_heads # 1 - - # Calculate the correct tensor sizes for QKV - qkv_weight_size = num_kv_heads * (q_num + 2) * size_per_head # 2 * (4 + 2) * 1 = 12 - qkv_bias_size = num_kv_heads * (q_num + 2) * size_per_head # 2 * (4 + 2) * 1 = 12 - - model_state_dict = { - "decoder.position_embedding.weight": torch.randn(seq_len, hidden_dim), - "decoder.word_embedding.weight": torch.randn(vocab_size, hidden_dim), - "decoder.lm_head.weight": torch.randn(vocab_size, hidden_dim), - "decoder.final_layernorm.weight": torch.randn(hidden_dim), - "decoder.layers.input_layernorm.weight": torch.randn(num_layers, hidden_dim), - # QKV weight: [num_layers, qkv_weight_size, hidden_dim] - converter will transpose to [hidden_dim, qkv_weight_size] - "decoder.layers.attention.qkv.weight": torch.randn( - num_layers, qkv_weight_size, hidden_dim - ), - # QKV bias: [num_layers, qkv_bias_size] - "decoder.layers.attention.qkv.bias": torch.randn(num_layers, qkv_bias_size), - "decoder.layers.attention.dense.weight": torch.randn( - num_layers, hidden_dim, hidden_dim - ), - "decoder.layers.mlp.fc.weight": torch.randn(num_layers, 4 * hidden_dim, hidden_dim), - } - - trtllm_conversion_dict = { - "decoder.position_embedding.weight": TRTLLMLayers.position_embedding, - "decoder.word_embedding.weight": TRTLLMLayers.vocab_embedding, - "decoder.final_layernorm.weight": TRTLLMLayers.final_layernorm_weight, - "decoder.lm_head.weight": TRTLLMLayers.lm_head, - "decoder.layers.input_layernorm.weight": TRTLLMLayers.input_layernorm_weight, - "decoder.layers.attention.qkv.weight": TRTLLMLayers.attention_qkv_weight, - "decoder.layers.attention.qkv.bias": TRTLLMLayers.attention_qkv_bias, - "decoder.layers.attention.dense.weight": TRTLLMLayers.attention_dense_weight, - "decoder.layers.mlp.fc.weight": TRTLLMLayers.mlp_fc_weight, - } - - mocker.patch( - "megatron.core.export.trtllm.trtllm_weights_converter.single_device_trtllm_model_weights_converter.str_dtype_to_torch", - return_value=torch.float32, - ) - - trtllm_model_weights_converter_cpu = SingleDeviceTRTLLMModelWeightsConverter( - export_config, model_config, dtype, activation="gelu" - ) - - mocker.patch( - "megatron.core.export.trtllm.trtllm_weights_converter.single_device_trtllm_model_weights_converter.pad_vocab_size", - return_value=10, - ) - - # Verify the conditions are met - assert trtllm_model_weights_converter_cpu.num_kv_heads < export_config.inference_tp_size - assert ( - export_config.inference_tp_size % trtllm_model_weights_converter_cpu.num_kv_heads == 0 - ) - assert trtllm_model_weights_converter_cpu.num_kv_heads == 2 - assert export_config.inference_tp_size == 4 - - trtllm_model_weights_converter_cpu.convert( - model_state_dict=model_state_dict, - trtllm_conversion_dict=trtllm_conversion_dict, - state_dict_split_by_layer_numbers=False, - ) - - # Check that QKV weights and biases are properly split for TP=4 - # Each TP rank should get 1/4 of the weights - for layer_idx in range(num_layers): - qkv_weight_key = f"transformer.layers.{layer_idx}.attention.qkv.weight" - qkv_bias_key = f"transformer.layers.{layer_idx}.attention.qkv.bias" - - # Check that we have 4 splits (one for each TP rank) - for tp_rank in range(4): - weight_split_key = f"{qkv_weight_key}.{tp_rank}.bin" - bias_split_key = f"{qkv_bias_key}.{tp_rank}.bin" - - assert weight_split_key in trtllm_model_weights_converter_cpu.trtllm_model_weights - assert bias_split_key in trtllm_model_weights_converter_cpu.trtllm_model_weights - - # Verify that the splits have the expected dimensions - weight_split = trtllm_model_weights_converter_cpu.trtllm_model_weights[ - weight_split_key - ] - bias_split = trtllm_model_weights_converter_cpu.trtllm_model_weights[bias_split_key] - - # For TP=4, each split should have 1/4 of the original size - # The weight shape depends on the conversion process, not necessarily hidden_dim - assert ( - len(bias_split.shape) == 1 - ), f"Expected bias to be 1D, got shape {bias_split.shape}" - - # Verify that all splits have the same size - if tp_rank == 0: - expected_weight_size = weight_split.shape[1] - expected_bias_size = bias_split.shape[0] - else: - assert ( - weight_split.shape[1] == expected_weight_size - ), f"All weight splits should have same size" - assert ( - bias_split.shape[0] == expected_bias_size - ), f"All bias splits should have same size" - - def test_num_kv_heads_less_than_tp_size_invalid(self, mocker): - """Test that an exception is raised when num_kv_heads < tp_size but tp_size % num_kv_heads != 0""" - - # Configure for invalid case: 3 KV heads, TP size 4 (not divisible) - # This should raise an exception because 4 % 3 != 0 - export_config = ExportConfig(inference_tp_size=4) - - vocab_size = 10 - hidden_dim = 8 - seq_len = 8 - num_layers = 2 - num_attn_heads = 6 - num_kv_heads = 3 # This is less than tp_size (4) but 4 % 3 != 0 - - model_config = TransformerConfig( - num_layers=num_layers, - num_attention_heads=num_attn_heads, - num_query_groups=num_kv_heads, - hidden_size=hidden_dim, - ffn_hidden_size=hidden_dim * 4, - ) - - dtype = DataType.bfloat16 - - q_num = num_attn_heads // num_kv_heads # 2 - size_per_head = hidden_dim // num_attn_heads # 1 - - # Calculate the correct tensor sizes for QKV - qkv_weight_size = num_kv_heads * (q_num + 2) * size_per_head # 3 * (2 + 2) * 1 = 12 - qkv_bias_size = num_kv_heads * (q_num + 2) * size_per_head # 3 * (2 + 2) * 1 = 12 - - model_state_dict = { - "decoder.position_embedding.weight": torch.randn(seq_len, hidden_dim), - "decoder.word_embedding.weight": torch.randn(vocab_size, hidden_dim), - "decoder.lm_head.weight": torch.randn(vocab_size, hidden_dim), - "decoder.final_layernorm.weight": torch.randn(hidden_dim), - "decoder.layers.input_layernorm.weight": torch.randn(num_layers, hidden_dim), - "decoder.layers.attention.qkv.weight": torch.randn( - num_layers, qkv_weight_size, hidden_dim - ), - "decoder.layers.attention.qkv.bias": torch.randn(num_layers, qkv_bias_size), - "decoder.layers.attention.dense.weight": torch.randn( - num_layers, hidden_dim, hidden_dim - ), - "decoder.layers.mlp.fc.weight": torch.randn(num_layers, 4 * hidden_dim, hidden_dim), - } - - trtllm_conversion_dict = { - "decoder.position_embedding.weight": TRTLLMLayers.position_embedding, - "decoder.word_embedding.weight": TRTLLMLayers.vocab_embedding, - "decoder.final_layernorm.weight": TRTLLMLayers.final_layernorm_weight, - "decoder.lm_head.weight": TRTLLMLayers.lm_head, - "decoder.layers.input_layernorm.weight": TRTLLMLayers.input_layernorm_weight, - "decoder.layers.attention.qkv.weight": TRTLLMLayers.attention_qkv_weight, - "decoder.layers.attention.qkv.bias": TRTLLMLayers.attention_qkv_bias, - "decoder.layers.attention.dense.weight": TRTLLMLayers.attention_dense_weight, - "decoder.layers.mlp.fc.weight": TRTLLMLayers.mlp_fc_weight, - } - - mocker.patch( - "megatron.core.export.trtllm.trtllm_weights_converter.single_device_trtllm_model_weights_converter.str_dtype_to_torch", - return_value=torch.float32, - ) - - trtllm_model_weights_converter_cpu = SingleDeviceTRTLLMModelWeightsConverter( - export_config, model_config, dtype, activation="gelu" - ) - - mocker.patch( - "megatron.core.export.trtllm.trtllm_weights_converter.single_device_trtllm_model_weights_converter.pad_vocab_size", - return_value=10, - ) - - # Verify the conditions are met for the invalid case - assert trtllm_model_weights_converter_cpu.num_kv_heads < export_config.inference_tp_size - assert ( - export_config.inference_tp_size % trtllm_model_weights_converter_cpu.num_kv_heads != 0 - ) - assert trtllm_model_weights_converter_cpu.num_kv_heads == 3 - assert export_config.inference_tp_size == 4 - - # This should raise an exception during conversion - with pytest.raises(Exception) as exc_info: - trtllm_model_weights_converter_cpu.convert( - model_state_dict=model_state_dict, - trtllm_conversion_dict=trtllm_conversion_dict, - state_dict_split_by_layer_numbers=False, - ) - - # Verify the exception message - expected_message = "Number of query groups of the models is 3. Please select tensor parallelism size that can duplicate or split the number of query groups to equal number of query matrices in the each GPU." - assert expected_message in str(exc_info.value) - - def test_num_kv_heads_greater_equal_tp_size_invalid(self, mocker): - """Test that an exception is raised when num_kv_heads >= tp_size but num_kv_heads % tp_size != 0""" - - # Configure for invalid case: 5 KV heads, TP size 4 (not divisible) - # This should raise an exception because 5 % 4 != 0 - export_config = ExportConfig(inference_tp_size=4) - - vocab_size = 10 - hidden_dim = 8 - seq_len = 8 - num_layers = 2 - num_attn_heads = 10 - num_kv_heads = 5 # This is greater than tp_size (4) but 5 % 4 != 0 - - model_config = TransformerConfig( - num_layers=num_layers, - num_attention_heads=num_attn_heads, - num_query_groups=num_kv_heads, - hidden_size=hidden_dim, - ffn_hidden_size=hidden_dim * 4, - ) - - dtype = DataType.bfloat16 - - q_num = num_attn_heads // num_kv_heads # 2 - size_per_head = hidden_dim // num_attn_heads # 1 - - # Calculate the correct tensor sizes for QKV - qkv_weight_size = num_kv_heads * (q_num + 2) * size_per_head # 5 * (2 + 2) * 1 = 20 - qkv_bias_size = num_kv_heads * (q_num + 2) * size_per_head # 5 * (2 + 2) * 1 = 20 - - model_state_dict = { - "decoder.position_embedding.weight": torch.randn(seq_len, hidden_dim), - "decoder.word_embedding.weight": torch.randn(vocab_size, hidden_dim), - "decoder.lm_head.weight": torch.randn(vocab_size, hidden_dim), - "decoder.final_layernorm.weight": torch.randn(hidden_dim), - "decoder.layers.input_layernorm.weight": torch.randn(num_layers, hidden_dim), - "decoder.layers.attention.qkv.weight": torch.randn( - num_layers, qkv_weight_size, hidden_dim - ), - "decoder.layers.attention.qkv.bias": torch.randn(num_layers, qkv_bias_size), - "decoder.layers.attention.dense.weight": torch.randn( - num_layers, hidden_dim, hidden_dim - ), - "decoder.layers.mlp.fc.weight": torch.randn(num_layers, 4 * hidden_dim, hidden_dim), - } - - trtllm_conversion_dict = { - "decoder.position_embedding.weight": TRTLLMLayers.position_embedding, - "decoder.word_embedding.weight": TRTLLMLayers.vocab_embedding, - "decoder.final_layernorm.weight": TRTLLMLayers.final_layernorm_weight, - "decoder.lm_head.weight": TRTLLMLayers.lm_head, - "decoder.layers.input_layernorm.weight": TRTLLMLayers.input_layernorm_weight, - "decoder.layers.attention.qkv.weight": TRTLLMLayers.attention_qkv_weight, - "decoder.layers.attention.qkv.bias": TRTLLMLayers.attention_qkv_bias, - "decoder.layers.attention.dense.weight": TRTLLMLayers.attention_dense_weight, - "decoder.layers.mlp.fc.weight": TRTLLMLayers.mlp_fc_weight, - } - - mocker.patch( - "megatron.core.export.trtllm.trtllm_weights_converter.single_device_trtllm_model_weights_converter.str_dtype_to_torch", - return_value=torch.float32, - ) - - trtllm_model_weights_converter_cpu = SingleDeviceTRTLLMModelWeightsConverter( - export_config, model_config, dtype, activation="gelu" - ) - - mocker.patch( - "megatron.core.export.trtllm.trtllm_weights_converter.single_device_trtllm_model_weights_converter.pad_vocab_size", - return_value=10, - ) - - # Verify the conditions are met for the invalid case - assert trtllm_model_weights_converter_cpu.num_kv_heads >= export_config.inference_tp_size - assert ( - trtllm_model_weights_converter_cpu.num_kv_heads % export_config.inference_tp_size != 0 - ) - assert trtllm_model_weights_converter_cpu.num_kv_heads == 5 - assert export_config.inference_tp_size == 4 - - # This should raise an exception during conversion - with pytest.raises(Exception) as exc_info: - trtllm_model_weights_converter_cpu.convert( - model_state_dict=model_state_dict, - trtllm_conversion_dict=trtllm_conversion_dict, - state_dict_split_by_layer_numbers=False, - ) - - # Verify the exception message - expected_message = "Number of query groups of the models is 5. Please select tensor parallelism size that can duplicate or split the number of query groups to equal number of query matrices in the each GPU." - assert expected_message in str(exc_info.value) - - def test_num_kv_heads_greater_equal_tp_size_valid(self, mocker): - """Test the condition where num_kv_heads >= tp_size and num_kv_heads % tp_size == 0 (valid case)""" - - # Configure for valid case: 8 KV heads, TP size 4 (divisible) - # This is valid because 8 % 4 == 0 - export_config = ExportConfig(inference_tp_size=4) - - vocab_size = 10 - hidden_dim = 8 - seq_len = 8 - num_layers = 2 - num_attn_heads = 8 - num_kv_heads = 8 # This is equal to tp_size (4) and 8 % 4 == 0 - - model_config = TransformerConfig( - num_layers=num_layers, - num_attention_heads=num_attn_heads, - num_query_groups=num_kv_heads, - hidden_size=hidden_dim, - ffn_hidden_size=hidden_dim * 4, - ) - - dtype = DataType.bfloat16 - - q_num = num_attn_heads // num_kv_heads # 1 - size_per_head = hidden_dim // num_attn_heads # 1 - - # Calculate the correct tensor sizes for QKV - qkv_weight_size = num_kv_heads * (q_num + 2) * size_per_head # 8 * (1 + 2) * 1 = 24 - qkv_bias_size = num_kv_heads * (q_num + 2) * size_per_head # 8 * (1 + 2) * 1 = 24 - - model_state_dict = { - "decoder.position_embedding.weight": torch.randn(seq_len, hidden_dim), - "decoder.word_embedding.weight": torch.randn(vocab_size, hidden_dim), - "decoder.lm_head.weight": torch.randn(vocab_size, hidden_dim), - "decoder.final_layernorm.weight": torch.randn(hidden_dim), - "decoder.layers.input_layernorm.weight": torch.randn(num_layers, hidden_dim), - "decoder.layers.attention.qkv.weight": torch.randn( - num_layers, qkv_weight_size, hidden_dim - ), - "decoder.layers.attention.qkv.bias": torch.randn(num_layers, qkv_bias_size), - "decoder.layers.attention.dense.weight": torch.randn( - num_layers, hidden_dim, hidden_dim - ), - "decoder.layers.mlp.fc.weight": torch.randn(num_layers, 4 * hidden_dim, hidden_dim), - } - - trtllm_conversion_dict = { - "decoder.position_embedding.weight": TRTLLMLayers.position_embedding, - "decoder.word_embedding.weight": TRTLLMLayers.vocab_embedding, - "decoder.final_layernorm.weight": TRTLLMLayers.final_layernorm_weight, - "decoder.lm_head.weight": TRTLLMLayers.lm_head, - "decoder.layers.input_layernorm.weight": TRTLLMLayers.input_layernorm_weight, - "decoder.layers.attention.qkv.weight": TRTLLMLayers.attention_qkv_weight, - "decoder.layers.attention.qkv.bias": TRTLLMLayers.attention_qkv_bias, - "decoder.layers.attention.dense.weight": TRTLLMLayers.attention_dense_weight, - "decoder.layers.mlp.fc.weight": TRTLLMLayers.mlp_fc_weight, - } - - mocker.patch( - "megatron.core.export.trtllm.trtllm_weights_converter.single_device_trtllm_model_weights_converter.str_dtype_to_torch", - return_value=torch.float32, - ) - - trtllm_model_weights_converter_cpu = SingleDeviceTRTLLMModelWeightsConverter( - export_config, model_config, dtype, activation="gelu" - ) - - mocker.patch( - "megatron.core.export.trtllm.trtllm_weights_converter.single_device_trtllm_model_weights_converter.pad_vocab_size", - return_value=10, - ) - - # Verify the conditions are met for the valid case - assert trtllm_model_weights_converter_cpu.num_kv_heads >= export_config.inference_tp_size - assert ( - trtllm_model_weights_converter_cpu.num_kv_heads % export_config.inference_tp_size == 0 - ) - assert trtllm_model_weights_converter_cpu.num_kv_heads == 8 - assert export_config.inference_tp_size == 4 - - # This should not raise an exception - trtllm_model_weights_converter_cpu.convert( - model_state_dict=model_state_dict, - trtllm_conversion_dict=trtllm_conversion_dict, - state_dict_split_by_layer_numbers=False, - ) - - # Check that QKV weights and biases are properly split for TP=4 - # Each TP rank should get 1/4 of the weights - for layer_idx in range(num_layers): - qkv_weight_key = f"transformer.layers.{layer_idx}.attention.qkv.weight" - qkv_bias_key = f"transformer.layers.{layer_idx}.attention.qkv.bias" - - # Check that we have 4 splits (one for each TP rank) - for tp_rank in range(4): - weight_split_key = f"{qkv_weight_key}.{tp_rank}.bin" - bias_split_key = f"{qkv_bias_key}.{tp_rank}.bin" - - assert weight_split_key in trtllm_model_weights_converter_cpu.trtllm_model_weights - assert bias_split_key in trtllm_model_weights_converter_cpu.trtllm_model_weights - - # Verify that the splits have the expected dimensions - weight_split = trtllm_model_weights_converter_cpu.trtllm_model_weights[ - weight_split_key - ] - bias_split = trtllm_model_weights_converter_cpu.trtllm_model_weights[bias_split_key] - - # For TP=4, each split should have 1/4 of the original size - # The weight shape depends on the conversion process, not necessarily hidden_dim - assert ( - len(bias_split.shape) == 1 - ), f"Expected bias to be 1D, got shape {bias_split.shape}" - - # Verify that all splits have the same size - if tp_rank == 0: - expected_weight_size = weight_split.shape[1] - expected_bias_size = bias_split.shape[0] - else: - assert ( - weight_split.shape[1] == expected_weight_size - ), f"All weight splits should have same size" - assert ( - bias_split.shape[0] == expected_bias_size - ), f"All bias splits should have same size" diff --git a/tests/unit_tests/fusions/test_bias_dropout_fusion.py b/tests/unit_tests/fusions/test_bias_dropout_fusion.py deleted file mode 100644 index 6303a87ba2..0000000000 --- a/tests/unit_tests/fusions/test_bias_dropout_fusion.py +++ /dev/null @@ -1,48 +0,0 @@ -import pytest -import torch - -from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add - - -@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) -@pytest.mark.parametrize("training", [True, False]) -def test_bias_dropout_add(dtype, training): - torch.manual_seed(42) - device = "cuda" - B, H = 16, 64 - - # Initialize inputs - x = torch.randn(B, H, dtype=dtype, device=device, requires_grad=training) - residual = torch.randn(B, H, dtype=dtype, device=device, requires_grad=training) - bias = torch.randn(H, dtype=dtype, device=device) - - # Run un-fused as reference - torch.manual_seed(42) - ref_fn = get_bias_dropout_add(training=training, fused=False) - x_ref = x.clone().detach().requires_grad_(training) - residual_ref = residual.clone().detach().requires_grad_(training) - out_ref = ref_fn((x_ref, bias), residual_ref, prob=0.0) - - # Run fused - torch.manual_seed(42) - fused_fn = get_bias_dropout_add(training=training, fused=True) - x_fused = x.clone().detach().requires_grad_(training) - residual_fused = residual.clone().detach().requires_grad_(training) - out_fused = fused_fn((x_fused, bias), residual_fused, prob=0.0) - - tols = dict(rtol=1e-6, atol=1e-6) if dtype is torch.float32 else dict(rtol=2e-2, atol=1e-2) - - assert out_fused.dtype == out_ref.dtype - assert torch.allclose(out_fused, out_ref, **tols) - - if training: - grad = torch.randn_like(out_ref) - out_ref.backward(grad) - out_fused.backward(grad) - - assert torch.allclose(x_ref.grad, x_fused.grad, **tols) - assert torch.allclose(residual_ref.grad, residual_fused.grad, **tols) - else: - # In‑place check for inference - assert out_fused.data_ptr() == x_fused.data_ptr() - assert torch.allclose(out_fused, x_fused, **tols) diff --git a/tests/unit_tests/fusions/test_mla_yarn_rope_apply.py b/tests/unit_tests/fusions/test_mla_yarn_rope_apply.py deleted file mode 100644 index 1c8976bfcb..0000000000 --- a/tests/unit_tests/fusions/test_mla_yarn_rope_apply.py +++ /dev/null @@ -1,255 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -import pytest -import torch - -from megatron.core.models.common.embeddings import apply_rotary_pos_emb -from megatron.core.models.common.embeddings.yarn_rotary_pos_embedding import YarnRotaryEmbedding -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import is_torch_min_version - -try: - from megatron.core.fusions.fused_mla_yarn_rope_apply import ( - fused_apply_mla_rope_for_kv, - fused_apply_mla_rope_for_q, - ) -except: - fused_apply_mla_rope_for_kv = None - fused_apply_mla_rope_for_q = None - - -def dtype_tols(dtype): - if dtype == torch.float32: - return dict(rtol=1.0e-6, atol=1.0e-6) - elif dtype == torch.float16: - return dict(rtol=3.0e-3, atol=1.0e-5) - elif dtype == torch.bfloat16: - return dict(rtol=2.0e-2, atol=5.0e-2) - else: - raise ValueError(f"Unsuppored dtype ({dtype})") - - -class FakeCPGroup: - def size(self): - return 1 - - def rank(self): - return 0 - - -def _test_fused_apply_mla_rope_for_q(input_format): - assert fused_apply_mla_rope_for_q is not None - num_heads = 32 - q_dim = 128 - emb_dim = 64 - dtype = torch.bfloat16 - transformer_config = TransformerConfig( - num_attention_heads=num_heads, - num_layers=1, - rotary_interleaved=False, - multi_latent_attention=True, - ) - - if input_format == "sbhd": - cu_seqlens = None - seqlen = 1024 - batch_size = 2 - yarn_rope = YarnRotaryEmbedding(emb_dim, original_max_position_embeddings=seqlen) - freqs, mscale = yarn_rope(seqlen, 0) - cos = (torch.cos(freqs) * mscale).to(dtype) - sin = (torch.sin(freqs) * mscale).to(dtype) - - pytorch_fwd_input = torch.randn( - (seqlen, batch_size, num_heads, q_dim + emb_dim), dtype=dtype, device='cuda' - ) - pytorch_bwd_input = torch.randn( - (seqlen, batch_size, num_heads, q_dim + emb_dim), dtype=dtype, device='cuda' - ) - else: - cu_seqlens = [0, 27, 54, 99, 128] - total_seqlen = cu_seqlens[-1] - max_seqlen = 0 - for i in range(len(cu_seqlens) - 1): - max_seqlen = max(max_seqlen, cu_seqlens[i + 1] - cu_seqlens[i]) - cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device='cuda') - yarn_rope = YarnRotaryEmbedding(emb_dim, original_max_position_embeddings=max_seqlen) - freqs, mscale = yarn_rope(max_seqlen, 0) - cos = (torch.cos(freqs) * mscale).to(dtype) - sin = (torch.sin(freqs) * mscale).to(dtype) - - pytorch_fwd_input = torch.randn( - (total_seqlen, num_heads, q_dim + emb_dim), dtype=dtype, device='cuda' - ) - pytorch_bwd_input = torch.randn( - (total_seqlen, num_heads, q_dim + emb_dim), dtype=dtype, device='cuda' - ) - - pytorch_fwd_input.requires_grad_(True) - fused_fwd_input = pytorch_fwd_input.detach() - fused_fwd_input.requires_grad_(True) - fused_bwd_input = pytorch_bwd_input.detach() - - no_pe, pe = torch.split(pytorch_fwd_input, [q_dim, emb_dim], dim=-1) - pe_output = apply_rotary_pos_emb( - pe, freqs, transformer_config, cu_seqlens=cu_seqlens, mscale=mscale, cp_group=FakeCPGroup() - ) - pytorch_output = torch.concat([no_pe, pe_output], dim=-1) - pytorch_output.backward(pytorch_bwd_input, retain_graph=True) - - fused_output = fused_apply_mla_rope_for_q( - fused_fwd_input, cos, sin, q_dim, emb_dim, cu_seqlens_q=cu_seqlens - ) - fused_output.backward(fused_bwd_input, retain_graph=True) - - tols = dtype_tols(dtype) - torch.testing.assert_close( - pytorch_output.float(), - fused_output.float(), - msg=lambda msg: f"Mismatch in fwd: {msg}", - **tols, - ) - torch.testing.assert_close( - pytorch_fwd_input.grad.float(), - fused_fwd_input.grad.float(), - msg=lambda msg: f"Mismatch in bwd: {msg}", - **tols, - ) - - -def _test_fused_apply_mla_rope_for_kv(input_format): - assert fused_apply_mla_rope_for_kv is not None - num_heads = 32 - k_dim = 128 - v_dim = 128 - emb_dim = 64 - dtype = torch.bfloat16 - transformer_config = TransformerConfig( - num_attention_heads=num_heads, - num_layers=1, - rotary_interleaved=False, - multi_latent_attention=True, - ) - - if input_format == "sbhd": - cu_seqlens = None - seqlen = 1024 - batch_size = 2 - yarn_rope = YarnRotaryEmbedding(emb_dim, original_max_position_embeddings=seqlen) - freqs, mscale = yarn_rope(seqlen, 0) - cos = (torch.cos(freqs) * mscale).to(dtype) - sin = (torch.sin(freqs) * mscale).to(dtype) - - pytorch_fwd_kv_input = torch.randn( - (seqlen, batch_size, num_heads, k_dim + v_dim), dtype=dtype, device='cuda' - ) - pytorch_fwd_emb_input = torch.randn( - (seqlen, batch_size, 1, emb_dim), dtype=dtype, device='cuda' - ) - pytorch_bwd_k_input = torch.randn( - (seqlen, batch_size, num_heads, k_dim + emb_dim), dtype=dtype, device='cuda' - ) - pytorch_bwd_v_input = torch.randn( - (seqlen, batch_size, num_heads, v_dim), dtype=dtype, device='cuda' - ) - else: - cu_seqlens = [0, 27, 54, 99, 128] - total_seqlen = cu_seqlens[-1] - max_seqlen = 0 - for i in range(len(cu_seqlens) - 1): - max_seqlen = max(max_seqlen, cu_seqlens[i + 1] - cu_seqlens[i]) - cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device='cuda') - yarn_rope = YarnRotaryEmbedding(emb_dim, original_max_position_embeddings=max_seqlen) - freqs, mscale = yarn_rope(max_seqlen, 0) - cos = (torch.cos(freqs) * mscale).to(dtype) - sin = (torch.sin(freqs) * mscale).to(dtype) - - pytorch_fwd_kv_input = torch.randn( - (total_seqlen, num_heads, k_dim + v_dim), dtype=dtype, device='cuda' - ) - pytorch_fwd_emb_input = torch.randn((total_seqlen, 1, emb_dim), dtype=dtype, device='cuda') - pytorch_bwd_k_input = torch.randn( - (total_seqlen, num_heads, k_dim + emb_dim), dtype=dtype, device='cuda' - ) - pytorch_bwd_v_input = torch.randn( - (total_seqlen, num_heads, v_dim), dtype=dtype, device='cuda' - ) - - pytorch_fwd_kv_input.requires_grad_(True) - pytorch_fwd_emb_input.requires_grad_(True) - fused_fwd_kv_input = pytorch_fwd_kv_input.detach() - fused_fwd_kv_input.requires_grad_(True) - fused_fwd_emb_input = pytorch_fwd_emb_input.detach() - fused_fwd_emb_input.requires_grad_(True) - fused_bwd_k_input = pytorch_bwd_k_input.detach() - fused_bwd_v_input = pytorch_bwd_v_input.detach() - - pe_output = apply_rotary_pos_emb( - pytorch_fwd_emb_input, - freqs, - transformer_config, - cu_seqlens=cu_seqlens, - mscale=mscale, - cp_group=FakeCPGroup(), - ) - if input_format == "sbhd": - pe_output = pe_output.expand(-1, -1, num_heads, -1) - else: - pe_output = pe_output.expand(-1, num_heads, -1) - k, pytorch_v_output = torch.split(pytorch_fwd_kv_input, [k_dim, v_dim], dim=-1) - pytorch_k_output = torch.concat([k, pe_output], dim=-1) - torch.autograd.backward( - (pytorch_k_output, pytorch_v_output), (pytorch_bwd_k_input, pytorch_bwd_v_input) - ) - - fused_k_output, fused_v_output = fused_apply_mla_rope_for_kv( - fused_fwd_kv_input, - fused_fwd_emb_input, - cos, - sin, - emb_dim, - k_dim, - v_dim, - cu_seqlens_kv=cu_seqlens, - ) - torch.autograd.backward( - (fused_k_output, fused_v_output), (fused_bwd_k_input, fused_bwd_v_input) - ) - - tols = dtype_tols(dtype) - torch.testing.assert_close( - pytorch_k_output.float(), - fused_k_output.float(), - msg=lambda msg: f"Mismatch in k fwd: {msg}", - **tols, - ) - torch.testing.assert_close( - pytorch_v_output.float(), - fused_v_output.float(), - msg=lambda msg: f"Mismatch in v fwd: {msg}", - **tols, - ) - torch.testing.assert_close( - pytorch_fwd_kv_input.grad.float(), - fused_fwd_kv_input.grad.float(), - msg=lambda msg: f"Mismatch in kv bwd: {msg}", - **tols, - ) - torch.testing.assert_close( - pytorch_fwd_emb_input.grad.float(), - fused_fwd_emb_input.grad.float(), - msg=lambda msg: f"Mismatch in emb bwd: {msg}", - **tols, - ) - - -@pytest.mark.experimental -@pytest.mark.internal -@pytest.mark.skipif(not is_torch_min_version("2.5.0"), reason="Requires PyTorch >= 2.5.0") -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize("input_format", ["sbhd", "thd"]) -class TestFusedApplyMLARope: - def test_forward_backward_for_q(self, input_format): - _test_fused_apply_mla_rope_for_q(input_format) - - def test_forward_backward_for_kv(self, input_format): - _test_fused_apply_mla_rope_for_kv(input_format) diff --git a/tests/unit_tests/fusions/test_swiglu_fusion.py b/tests/unit_tests/fusions/test_swiglu_fusion.py deleted file mode 100644 index c72679cd04..0000000000 --- a/tests/unit_tests/fusions/test_swiglu_fusion.py +++ /dev/null @@ -1,41 +0,0 @@ -import pytest -import torch - -from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl, weighted_bias_swiglu_impl - - -@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float32]) -def test_weighted_bias_swiglu(input_dtype): - if input_dtype == torch.float32: - tols = dict(rtol=1.0e-6, atol=1.0e-6) - elif input_dtype == torch.bfloat16: - tols = dict(rtol=2.0e-2, atol=1.0e-3) - else: - raise ValueError(f"Invalid input dtype: {input_dtype}") - - x = torch.randn(16, 64, dtype=input_dtype, device="cuda") - x.requires_grad = True - weights = torch.randn(16, 1, dtype=torch.float32, device="cuda") - weights.requires_grad = True - bwd_input = torch.randn(16, 32, dtype=input_dtype, device="cuda") - - y = bias_swiglu_impl(x, None) * weights - y = y.to(input_dtype) - y.backward(bwd_input) - - x_2 = x.detach() - x_2.requires_grad = True - weights_2 = weights.detach() - weights_2.requires_grad = True - bwd_input_2 = bwd_input.detach() - - y_2 = weighted_bias_swiglu_impl(x_2, None, weights_2) - y_2.backward(bwd_input_2) - - assert y_2.dtype == y.dtype - assert torch.allclose(y, y_2, **tols) - assert x_2.grad.dtype == x.grad.dtype - assert torch.allclose(x.grad, x_2.grad, **tols) - assert weights_2.grad.dtype == weights.grad.dtype - if input_dtype == torch.float32: - assert torch.allclose(weights.grad, weights_2.grad, **tols) diff --git a/tests/unit_tests/fusions/test_torch_softmax.py b/tests/unit_tests/fusions/test_torch_softmax.py deleted file mode 100644 index 63b0bc7b5d..0000000000 --- a/tests/unit_tests/fusions/test_torch_softmax.py +++ /dev/null @@ -1,47 +0,0 @@ -import pytest -import torch - -from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax -from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.utils import attention_mask_func, get_default_causal_mask - - -class TestTorchSoftmax: - def setup_method(self, method): - # The important settings tested are forward_torch_softmax path - # with locally generated casual mask for attention_mask_func: - self.softmax = FusedScaleMaskSoftmax( - input_in_fp16=False, - input_in_bf16=False, - attn_mask_type=AttnMaskType.causal, - scaled_masked_softmax_fusion=False, - mask_func=attention_mask_func, - softmax_in_fp32=True, - scale=None, - ) - - def teardown_method(self): - get_default_causal_mask.cache_clear() - - def test_output_shape(self): - x = torch.randn(8, 2, 4, 4, device="cuda") - y = self.softmax(x, None) - assert x.shape == y.shape - - def test_causal_mask_input_shape_assert(self): - x = torch.randn(1, 1, 4, 16, device="cuda") - with pytest.raises(AssertionError): - self.softmax(x, None) - - def test_causal_mask_equal_scores(self): - # For equal input values (e.g. zero) correctly masked softmax should - # produce equal scores among non-masked elements. For example, in case - # sq == sk == 2 the expected output is (ignoring b and np dimensions): - # [[1.0, 0.0], - # [0.5, 0.5]] - b, np, sq, sk = 8, 2, 32, 32 - x = torch.zeros([b, np, sq, sk]).cuda() - y = self.softmax(x, None) - y_expected = torch.tril(torch.ones(b, np, sq, sk, device="cuda")) - y_expected /= torch.arange(1, sq + 1, device="cuda").reshape((-1, 1)) - assert torch.allclose(y, y_expected, rtol=1e-08, atol=1e-08) diff --git a/tests/unit_tests/fusions/test_weighted_squared_relu_fusion.py b/tests/unit_tests/fusions/test_weighted_squared_relu_fusion.py deleted file mode 100644 index 85755ac1de..0000000000 --- a/tests/unit_tests/fusions/test_weighted_squared_relu_fusion.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -import pytest -import torch - -from megatron.core.activations import squared_relu -from megatron.core.fusions.fused_weighted_squared_relu import weighted_squared_relu_impl - - -@pytest.mark.internal -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float32]) -def test_weighted_squared_relu_fusion(input_dtype): - # Tolerances depend on dtype precision - if input_dtype == torch.float32: - tols = dict(rtol=1.0e-6, atol=1.0e-6) - elif input_dtype == torch.bfloat16: - tols = dict(rtol=2.0e-2, atol=1.0e-3) - else: - raise ValueError(f"Unsupported dtype {input_dtype}") - - # Inputs - x = torch.randn(16, 64, dtype=input_dtype, device="cuda", requires_grad=True) - weights = torch.randn(16, 1, dtype=torch.float32, device="cuda", requires_grad=True) - grad_output = torch.randn(16, 64, dtype=input_dtype, device="cuda") - - # Baseline: legacy squared_relu followed by weighting. - y_baseline = squared_relu(x) * weights - y_baseline = y_baseline.to(input_dtype) - y_baseline.backward(grad_output) - - # Fused implementation. - x_fused = x.detach().clone().requires_grad_(True) - weights_fused = weights.detach().clone().requires_grad_(True) - grad_output_fused = grad_output.detach().clone() - - y_fused = weighted_squared_relu_impl(x_fused, weights_fused) - y_fused.backward(grad_output_fused) - - # Forward accuracy - assert y_fused.dtype == y_baseline.dtype - assert torch.allclose(y_fused, y_baseline, **tols) - - # Grad accuracy w.r.t input - assert x_fused.grad.dtype == x.grad.dtype - assert torch.allclose(x_fused.grad, x.grad, **tols) - - # Grad accuracy w.r.t weights - assert weights_fused.grad.dtype == weights.grad.dtype - if input_dtype == torch.float32: - # For bf16 baseline weight grad computed in fp32 then cast may lose precision. - assert torch.allclose(weights_fused.grad, weights.grad, **tols) diff --git a/tests/unit_tests/inference/contexts/test_dynamic_context.py b/tests/unit_tests/inference/contexts/test_dynamic_context.py deleted file mode 100644 index 341ea78dc8..0000000000 --- a/tests/unit_tests/inference/contexts/test_dynamic_context.py +++ /dev/null @@ -1,832 +0,0 @@ -import pytest -import torch - -from megatron.core.inference.contexts.dynamic_context import ( - DynamicInferenceContext, - RequestOverflowError, - TokenOverflowError, -) -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from tests.unit_tests.test_utilities import Utils - - -def set_rounder(value): - """Utility function to set the DynamicInferenceContext rounder.""" - DynamicInferenceContext.ROUNDER = value # For backwards compatibility - DynamicInferenceContext.TOKEN_ROUNDER = value - DynamicInferenceContext.REQUEST_ROUNDER = value - - -class TestDynamicContext: - - def _setup_model_parallel_group(self, tensor_parallel_size, pipeline_parallel_size): - - Utils.initialize_model_parallel( - tensor_model_parallel_size=tensor_parallel_size, - pipeline_model_parallel_size=pipeline_parallel_size, - ) - model_parallel_cuda_manual_seed(123) - - def _get_dynamic_context( - self, - params_dtype, - num_layers, - kv_channels, - num_attention_heads, - max_sequence_length, - buffer_size_gb, - chunk_size_tokens, - buffer_guarenteed_fraction, - buffer_overflow_factor, - max_requests_override, - max_tokens_override, - ): - set_rounder(64) - dynamic_context = DynamicInferenceContext( - params_dtype=params_dtype, - num_layers=num_layers, - kv_channels=kv_channels, - num_attention_heads=num_attention_heads, - max_sequence_length=max_sequence_length, - num_cuda_graphs=None, - buffer_size_gb=buffer_size_gb, - buffer_guaranteed_fraction=buffer_guarenteed_fraction, - chunk_size_tokens=chunk_size_tokens, - buffer_overflow_factor=buffer_overflow_factor, - max_requests_override=max_requests_override, - max_tokens_override=max_tokens_override, - ) - return dynamic_context - - def teardown_method(self, method): - set_rounder(64) - Utils.destroy_model_parallel() - - @pytest.mark.experimental - def test_initialize_dynamic_context(self): - self._setup_model_parallel_group(1, 1) - - dynamic_context = self._get_dynamic_context( - params_dtype=torch.float32, - num_layers=4, - kv_channels=8, - num_attention_heads=2, - max_sequence_length=512, - buffer_size_gb=0.03, - buffer_guarenteed_fraction=0.1, - chunk_size_tokens=128, - max_requests_override=None, - max_tokens_override=None, - buffer_overflow_factor=None, - ) - - assert dynamic_context.gtd_chunk_count == 48 - assert dynamic_context.gtd_request_count == 12 - assert dynamic_context.chunk_allocator.chunk_count_total == 491 - assert dynamic_context.max_requests == 128 - assert dynamic_context.max_tokens == 62848 - - # Check initializations to -1 - assert torch.all(dynamic_context.request_ids == -1) - - @pytest.mark.experimental - def test_is_static_batching(self): - self._setup_model_parallel_group(1, 1) - dynamic_context = DynamicInferenceContext( - params_dtype=torch.float32, - num_layers=2, - kv_channels=64, - num_attention_heads=8, - max_sequence_length=512, - num_cuda_graphs=None, - buffer_size_gb=1.0, - buffer_guaranteed_fraction=0.1, - chunk_size_tokens=128, - ) - assert not dynamic_context.is_static_batching() - - @pytest.mark.experimental - def test_is_memory_available(self): - self._setup_model_parallel_group(1, 1) - dynamic_context = DynamicInferenceContext( - params_dtype=torch.float32, - num_layers=2, - kv_channels=64, - num_attention_heads=8, - max_sequence_length=512, - num_cuda_graphs=None, - buffer_size_gb=1.0, - buffer_guaranteed_fraction=0.1, - chunk_size_tokens=128, - ) - dynamic_context.chunk_allocator.chunk_count_avail = 10 - assert dynamic_context.chunk_allocator.is_memory_available(10) - assert not dynamic_context.chunk_allocator.is_memory_available(11) - - assert dynamic_context.chunk_allocator.is_memory_available(1) - dynamic_context.chunk_allocator.chunk_count_avail = 0 - assert not dynamic_context.chunk_allocator.is_memory_available(1) - - dynamic_context.chunk_allocator.chunk_count_avail = 10 - dynamic_context.gtd_chunk_count = 5 - assert dynamic_context.chunk_allocator.is_memory_available(6) - assert not dynamic_context.chunk_allocator.is_memory_available(6, safe=True) - - @pytest.mark.experimental - def test_request_overflow(self): - self._setup_model_parallel_group(1, 1) - set_rounder(1) - dynamic_context = DynamicInferenceContext( - params_dtype=torch.float32, - num_layers=2, - kv_channels=64, - num_attention_heads=8, - max_sequence_length=128, - num_cuda_graphs=None, - buffer_size_gb=0.01, - buffer_guaranteed_fraction=0.1, - chunk_size_tokens=32, - ) - with pytest.raises(RequestOverflowError): - for i in range(dynamic_context.max_requests + 1): - dynamic_context.add_request( - i, torch.zeros(10, device='cuda') - ) # Adding more than allowed requests - - @pytest.mark.experimental - def test_token_overflow_error(self): - self._setup_model_parallel_group(1, 1) - set_rounder(1) - dynamic_context = DynamicInferenceContext( - params_dtype=torch.float32, - num_layers=2, - kv_channels=64, - num_attention_heads=8, - max_sequence_length=512, - num_cuda_graphs=None, - buffer_size_gb=0.1, - buffer_guaranteed_fraction=0.1, - chunk_size_tokens=128, - buffer_overflow_factor=1.0, - max_requests_override=2, - max_tokens_override=20, # Setting a very low token limit - ) - - with pytest.raises(TokenOverflowError): - dynamic_context.add_request( - 1, torch.arange(0, 25, device='cuda') - ) # Exceeding max token count - - @pytest.mark.experimental - def test_reset(self): - self._setup_model_parallel_group(1, 1) - dynamic_context = DynamicInferenceContext( - params_dtype=torch.float32, - num_layers=2, - kv_channels=64, - num_attention_heads=8, - max_sequence_length=128, - num_cuda_graphs=None, - buffer_size_gb=1.0, - buffer_guaranteed_fraction=0.1, - chunk_size_tokens=128, - ) - - # Initialize all variables - dynamic_context.total_request_count = 10 - dynamic_context.active_token_count = 10 - dynamic_context.paused_request_count = 5 - dynamic_context.padded_active_token_count = 10 - dynamic_context.padded_active_request_count = 5 - dynamic_context.paused_tokens = torch.tensor([1, 2, 3], device='cuda') - dynamic_context.request_ids.fill_(1) - dynamic_context.request_query_lengths.fill_(1) - dynamic_context.request_kv_length_offsets.fill_(1) - dynamic_context.request_kv_chunk_counts.fill_(1) - dynamic_context.request_last_kv_chunk_id.fill_(1) - dynamic_context.request_last_kv_chunk_offset.fill_(1) - dynamic_context.token_to_input_ids.fill_(1) - dynamic_context.token_to_pos_ids.fill_(1) - dynamic_context.token_to_request_idx.fill_(1) - dynamic_context.token_to_position_in_request.fill_(1) - dynamic_context.token_to_chunk_idx.fill_(1) - dynamic_context.token_to_local_position_within_kv_chunk.fill_(1) - dynamic_context.chunk_allocator.chunk_count_avail = 5 - dynamic_context.memory_buffer.fill_(1) - dynamic_context.request_to_kv_chunk_ids.fill_(1) - - # Call reset - dynamic_context.reset() - - # Assert all variables are reset to zero or their default values - assert dynamic_context.total_request_count == 0 - assert dynamic_context.active_token_count == 0 - assert dynamic_context.paused_request_count == 0 - assert dynamic_context.padded_active_token_count == 0 - assert dynamic_context.padded_active_request_count == 0 - assert dynamic_context.paused_tokens is None - assert torch.all(dynamic_context.request_ids == -1) - assert torch.all(dynamic_context.request_query_lengths == 0) - assert torch.all(dynamic_context.request_kv_length_offsets == 0) - assert torch.all(dynamic_context.request_kv_chunk_counts == 0) - assert torch.all(dynamic_context.request_last_kv_chunk_id == -1) - assert torch.all(dynamic_context.request_last_kv_chunk_offset == 0) - assert torch.all(dynamic_context.token_to_input_ids == 0) - assert torch.all(dynamic_context.token_to_pos_ids == 0) - assert torch.all(dynamic_context.token_to_request_idx == -1) - assert torch.all(dynamic_context.token_to_position_in_request == 0) - assert torch.all(dynamic_context.token_to_chunk_idx == -1) - assert torch.all(dynamic_context.token_to_local_position_within_kv_chunk == 0) - assert ( - dynamic_context.chunk_allocator.chunk_count_avail - == dynamic_context.chunk_allocator.chunk_count_total - 1 - ) - assert torch.all(dynamic_context.request_to_kv_chunk_ids == -1) - - @pytest.mark.experimental - def test_allocate_and_release_memory_chunks(self): - self._setup_model_parallel_group(1, 1) - dynamic_context = self._get_dynamic_context( - params_dtype=torch.float32, - num_layers=4, - kv_channels=8, - num_attention_heads=2, - max_sequence_length=512, - buffer_size_gb=0.03, - buffer_guarenteed_fraction=0.1, - chunk_size_tokens=128, - max_requests_override=None, - max_tokens_override=None, - buffer_overflow_factor=None, - ) - - assert dynamic_context.chunk_allocator.allocate_memory_chunks( - 4 - ).cpu().detach().numpy().tolist() == [486, 487, 488, 489] - assert dynamic_context.chunk_allocator.chunk_count_avail == 486 - dynamic_context.chunk_allocator.release_memory_chunks( - torch.tensor([488, 489], device='cuda') - ) - assert dynamic_context.chunk_allocator.chunk_count_avail == 488 - assert dynamic_context.chunk_allocator.allocate_memory_chunks(1).item() == 489 - assert dynamic_context.chunk_allocator.chunk_count_avail == 487 - # Should return None since we allocate more chunks than what we have. - assert ( - dynamic_context.chunk_allocator.allocate_memory_chunks( - dynamic_context.chunk_allocator.chunk_count_avail + 100 - ) - == None - ) - - @pytest.mark.experimental - def test_add_request(self): - self._setup_model_parallel_group(1, 1) - dynamic_context = self._get_dynamic_context( - params_dtype=torch.float32, - num_layers=4, - kv_channels=8, - num_attention_heads=2, - max_sequence_length=512, - buffer_size_gb=0.03, - buffer_guarenteed_fraction=0.1, - chunk_size_tokens=128, - max_requests_override=None, - max_tokens_override=None, - buffer_overflow_factor=None, - ) - assert dynamic_context.chunk_size_tokens == 128 - context_length = 144 - dynamic_context.add_request( - request_id=0, tokens=torch.arange(0, context_length, dtype=torch.long, device='cuda') - ) - assert dynamic_context.total_request_count == 1 - assert dynamic_context.active_token_count == context_length - assert dynamic_context.request_ids[0] == 0 - assert torch.all(dynamic_context.request_ids[1:] == -1) - assert dynamic_context.request_query_lengths[0] == context_length - assert dynamic_context.request_kv_length_offsets[0] == 0 - assert dynamic_context.request_to_kv_chunk_ids[0].cpu().detach().numpy().tolist() == [ - 488, - 489, - -1, - -1, - ] - assert dynamic_context.request_kv_chunk_counts[0] == 2 - assert dynamic_context.request_last_kv_chunk_id[0] == 489 - assert dynamic_context.request_last_kv_chunk_offset[0].item() == 15 - assert torch.all( - dynamic_context.token_to_pos_ids[0:context_length] - == torch.arange(0, context_length, dtype=torch.long, device='cuda') - ) - assert torch.all( - dynamic_context.token_to_input_ids[0:context_length] - == torch.arange(0, context_length, dtype=torch.long, device='cuda') - ) - assert torch.all( - dynamic_context.token_to_position_in_request[0:context_length] - == torch.arange(0, context_length, dtype=torch.long, device='cuda') - ) - assert torch.all( - dynamic_context.token_to_chunk_idx[0:context_length][ - 0 : dynamic_context.chunk_size_tokens - ] - == 488 - ) - assert torch.all( - dynamic_context.token_to_chunk_idx[0:context_length][ - dynamic_context.chunk_size_tokens : context_length - ] - == 489 - ) - assert torch.all( - dynamic_context.token_to_local_position_within_kv_chunk[0:context_length] - == torch.arange(0, context_length, dtype=torch.long, device='cuda') - % dynamic_context.chunk_size_tokens - ) - - @pytest.mark.experimental - def test_update_request(self): - self._setup_model_parallel_group(1, 1) - dynamic_context = self._get_dynamic_context( - params_dtype=torch.float32, - num_layers=4, - kv_channels=8, - num_attention_heads=2, - max_sequence_length=512, - buffer_size_gb=0.03, - buffer_guarenteed_fraction=0.1, - chunk_size_tokens=128, - max_requests_override=None, - max_tokens_override=None, - buffer_overflow_factor=None, - ) - - # This case should just reset and return since all requests are finished - active_requests_mask = torch.Tensor([0, 0, 0]) - dynamic_context.paused_request_count = 0 - dynamic_context.total_request_count = 3 - dynamic_context.request_kv_chunk_counts[0:3] = 1 - new_chunk_ids = dynamic_context.chunk_allocator.allocate_memory_chunks(3, safe=True) - dynamic_context.request_to_kv_chunk_ids[0:3, 0] = new_chunk_ids - dynamic_context.update_requests( - active_requests_mask=active_requests_mask, new_tokens=torch.tensor([0, 1, 2]) - ) - assert dynamic_context.total_request_count == 0 - - # This case would cover all cases - # 1. Already there will be 2 paused requests - # 2. Active request mask will have active and finished requests. - # 3. The active requests will also have some requests that have to be paused because of reaching max token limit within chunk - # 4. Some of these requests will be resumed. - # Setup is as follows : - # Request ids 0, 1 are paused - # Request ids 2 , 4, 9 are active requests - # Request ids 3 7 8 have completed - # Request ids 5 and 6 will require on more chunk later on coz they finished their current chunk - - dynamic_context = self._get_dynamic_context( - params_dtype=torch.float32, - num_layers=4, - kv_channels=8, - num_attention_heads=2, - max_sequence_length=512, - buffer_size_gb=0.03, - buffer_guarenteed_fraction=0.1, - chunk_size_tokens=128, - max_requests_override=None, - max_tokens_override=None, - buffer_overflow_factor=None, - ) - - active_requests_mask = torch.Tensor([1, 0, 1, 1, 1, 0, 0, 1]).cuda().int() - next_tokens = torch.arange(2, 10, device='cuda').int() - dynamic_context.paused_request_count = 2 - dynamic_context.paused_tokens = torch.Tensor([0, 1]).cuda().int() - dynamic_context.total_request_count = 5 - - # Total req count should be equal to paused + num elements in active request mask. - # So here it will raise an assertion error - with pytest.raises(AssertionError) as error: - dynamic_context.update_requests( - active_requests_mask=active_requests_mask, new_tokens=next_tokens - ) - - total_request_count = 10 - dynamic_context.chunk_allocator.chunk_count_avail -= 11 # We align 11 chunks to the 10 requests we have. 3rd request alone we setup like it requires 2 chunks - dynamic_context.total_request_count = total_request_count - - dynamic_context.request_to_kv_chunk_ids[0:total_request_count, 0] = torch.arange( - dynamic_context.chunk_allocator.chunk_count_avail, - dynamic_context.chunk_allocator.chunk_count_avail + 10, - ) - dynamic_context.request_to_kv_chunk_ids[3][ - 1 - ] = ( - dynamic_context.chunk_allocator.chunk_count_avail - ) # Assign one extra chunk to request 3. - dynamic_context.request_kv_length_offsets[0:total_request_count] = 10 - # For 0, 1, 5, 6, the total number of tokens in last chunk is chunk size -1, so that they will all need extra chunks - dynamic_context.request_kv_length_offsets[0:2] = dynamic_context.chunk_size_tokens - 1 - dynamic_context.request_kv_length_offsets[5:7] = dynamic_context.chunk_size_tokens - 1 - # For the 3rd request, its completed and required 2 chunks. So we add more tokens than chunks size - dynamic_context.request_kv_length_offsets[3] = dynamic_context.chunk_size_bytes + 10 - dynamic_context.request_query_lengths[0:total_request_count] = ( - 1 # Everything is in decode phase - ) - - dynamic_context.request_ids[0:total_request_count] = torch.arange(0, total_request_count) - dynamic_context.request_kv_chunk_counts[0:total_request_count] = 1 - dynamic_context.request_kv_chunk_counts[3] = 2 # 3rd chunk alone requies 2 chunks - dynamic_context.request_last_kv_chunk_id[0:total_request_count] = torch.arange( - 0, total_request_count - ) - dynamic_context.request_last_kv_chunk_id[3] = 11 - dynamic_context.request_last_kv_chunk_offset[0:total_request_count] = 10 - # For the 3rd request, its completed and required 2 chunks. So we add more tokens than chunks size - dynamic_context.request_last_kv_chunk_offset[0:2] = dynamic_context.chunk_size_tokens - 1 - dynamic_context.request_last_kv_chunk_offset[5:7] = dynamic_context.chunk_size_tokens - 1 - - dynamic_context.update_requests( - active_requests_mask=active_requests_mask, new_tokens=next_tokens - ) - - # Then set up the test data - dynamic_context.request_ids[0:10] = torch.tensor( - [0, 1, 5, 6, 4, 2, 9, 7, 8, 9], device=torch.cuda.current_device() - ) - - # Now verify the values - assert dynamic_context.request_ids[0:10].cpu().numpy().tolist() == [ - 0, - 1, - 5, - 6, - 4, - 2, - 9, - 7, - 8, - 9, - ] - - assert dynamic_context.paused_request_count == 0 - assert dynamic_context.total_request_count == 7 - assert dynamic_context.active_token_count == 7 - - # The first four are zero because they have all obtained a new chunk - assert dynamic_context.request_last_kv_chunk_offset[0:10].cpu().numpy().tolist() == [ - 0, - 0, - 0, - 0, - 11, - 11, - 11, - 10, - 10, - 10, - ] - assert dynamic_context.token_to_input_ids[ - : dynamic_context.active_token_count - ].cpu().numpy().tolist() == [0, 1, 5, 6, 4, 2, 9] - - assert dynamic_context.token_to_pos_ids[ - : dynamic_context.active_token_count - ].cpu().numpy().tolist() == [128, 128, 128, 128, 11, 11, 11] - - # The first 4 requests will require an extra chunk. - # Since 3 requests have finished, the last 3 rows should be all -1. - assert torch.all( - dynamic_context.request_to_kv_chunk_ids[0:10].cpu() - == torch.tensor( - [ - [479, 482, -1, -1], - [480, 479, -1, -1], - [484, 486, -1, -1], - [485, 487, -1, -1], - [483, -1, -1, -1], - [481, -1, -1, -1], - [488, -1, -1, -1], - [-1, -1, -1, -1], - [-1, -1, -1, -1], - [-1, -1, -1, -1], - ] - ) - ) - - @pytest.mark.experimental - def test_release_memory_chunks_for_finished_requests(self): - """Test that memory chunks are correctly released for finished requests.""" - self._setup_model_parallel_group(1, 1) - dynamic_context = self._get_dynamic_context( - params_dtype=torch.float32, - num_layers=4, - kv_channels=8, - num_attention_heads=2, - max_sequence_length=512, - buffer_size_gb=0.03, - buffer_guarenteed_fraction=0.1, - chunk_size_tokens=128, - max_requests_override=None, - max_tokens_override=None, - buffer_overflow_factor=None, - ) - - # Set up the initial state with 5 requests - # Allocate 5 chunks for 5 requests - initial_chunks = dynamic_context.chunk_allocator.allocate_memory_chunks(5, safe=True) - dynamic_context.total_request_count = 5 - dynamic_context.paused_request_count = 0 - - # Record the available chunks before releasing memory - initial_available_chunks = dynamic_context.chunk_allocator.chunk_count_avail - - # Assign chunks to the requests (one chunk per request) - for i in range(5): - dynamic_context.request_to_kv_chunk_ids[i, 0] = initial_chunks[i] - dynamic_context.request_query_lengths[i] = 1 - dynamic_context.request_ids[i] = i - - # Create an active_requests_mask where requests 0, 2, and 4 are finished (0), - # and requests 1 and 3 are still active (1) - active_requests_mask = torch.tensor([0, 1, 0, 1, 0], device=torch.cuda.current_device()) - - # Call update_requests with these parameters - dynamic_context.update_requests( - active_requests_mask=active_requests_mask, - new_tokens=torch.tensor([10, 11, 12, 13, 14], device=torch.cuda.current_device()), - ) - - # After the update, we should have released 3 chunks (for requests 0, 2, and 4) - # and have 2 active requests (1 and 3) - assert dynamic_context.total_request_count == 2 - assert dynamic_context.active_token_count == 2 - - # Verify that 3 chunks were released by checking the available chunks - assert dynamic_context.chunk_allocator.chunk_count_avail == initial_available_chunks + 3 - - @pytest.mark.experimental - def test_finished_requests_with_multiple_chunks(self): - """Test that all memory chunks are correctly released for finished requests that use multiple chunks.""" - self._setup_model_parallel_group(1, 1) - dynamic_context = self._get_dynamic_context( - params_dtype=torch.float32, - num_layers=4, - kv_channels=8, - num_attention_heads=2, - max_sequence_length=512, - buffer_size_gb=0.03, - buffer_guarenteed_fraction=0.1, - chunk_size_tokens=128, - max_requests_override=None, - max_tokens_override=None, - buffer_overflow_factor=None, - ) - - # Set up the initial state with 3 requests, where some use multiple chunks - # Allocate 6 chunks in total for the requests - initial_chunks = dynamic_context.chunk_allocator.allocate_memory_chunks(6, safe=True) - dynamic_context.total_request_count = 3 - dynamic_context.paused_request_count = 0 - - # Record the available chunks before releasing memory - initial_available_chunks = dynamic_context.chunk_allocator.chunk_count_avail - - # Assign chunks to the requests: - # - Request 0: 1 chunk - # - Request 1: 2 chunks - # - Request 2: 3 chunks - dynamic_context.request_to_kv_chunk_ids[0, 0] = initial_chunks[0] - - dynamic_context.request_to_kv_chunk_ids[1, 0] = initial_chunks[1] - dynamic_context.request_to_kv_chunk_ids[1, 1] = initial_chunks[2] - - dynamic_context.request_to_kv_chunk_ids[2, 0] = initial_chunks[3] - dynamic_context.request_to_kv_chunk_ids[2, 1] = initial_chunks[4] - dynamic_context.request_to_kv_chunk_ids[2, 2] = initial_chunks[5] - - dynamic_context.request_kv_chunk_counts[0] = 1 - dynamic_context.request_kv_chunk_counts[1] = 2 - dynamic_context.request_kv_chunk_counts[2] = 3 - - for i in range(3): - dynamic_context.request_query_lengths[i] = 1 - dynamic_context.request_ids[i] = i - - # Create an active_requests_mask where all requests are finished - active_requests_mask = torch.tensor([0, 0, 0], device=torch.cuda.current_device()) - - # Call update_requests with these parameters - dynamic_context.update_requests( - active_requests_mask=active_requests_mask, - new_tokens=torch.tensor([10, 11, 12], device=torch.cuda.current_device()), - ) - - # After the update, we should have released all 6 chunks and have 0 active requests - assert dynamic_context.total_request_count == 0 - assert dynamic_context.active_token_count == 0 - - # Verify that all 6 chunks were released by checking the available chunks - assert dynamic_context.chunk_allocator.chunk_count_avail == initial_available_chunks + 6 - - @pytest.mark.experimental - def test_calculate_and_store_log_probs(self): - self._setup_model_parallel_group(1, 1) - dynamic_context = self._get_dynamic_context( - params_dtype=torch.float32, - num_layers=4, - kv_channels=8, - num_attention_heads=2, - max_sequence_length=512, - buffer_size_gb=0.03, - buffer_guarenteed_fraction=0.1, - chunk_size_tokens=128, - max_requests_override=None, - max_tokens_override=None, - buffer_overflow_factor=None, - ) - - # Add a few requests to the context - request_data = { - 1001: { - "tokens": torch.randint(0, 100, (10,), device='cuda'), - "prefill_len": 10, - "initial_token_offset": 0, - }, - 1002: { - "tokens": torch.randint(0, 100, (5,), device='cuda'), - "prefill_len": 5, - "initial_token_offset": 10, - }, - 1003: { - "tokens": torch.randint(0, 100, (7,), device='cuda'), - "prefill_len": 7, - "initial_token_offset": 15, - }, - } - - current_token_idx = 0 - for req_id, data in request_data.items(): - dynamic_context.add_request(req_id, data["tokens"]) - # Update the initial_token_offset as requests are added - request_data[req_id]["initial_token_offset"] = current_token_idx - current_token_idx += data["prefill_len"] - - # Simulate prefill step - total_active_tokens = dynamic_context.active_token_count - vocab_size = 50000 - # logits will have shape [1, total_active_tokens, vocab_size] - prefill_logits = torch.randn( - 1, total_active_tokens, vocab_size, device='cuda', dtype=torch.float32 - ) - - # New tokens from prefill (one token per active request) - num_active_requests = ( - dynamic_context.total_request_count - dynamic_context.paused_request_count - ) - prefill_new_tokens = torch.randint(0, 100, (num_active_requests,), device='cuda').long() - - # Call the function for prefill - prefill_log_probs = dynamic_context.calculate_log_probs(prefill_logits, prefill_new_tokens) - - # Calculate expected prefill log probs for the selected tokens - expected_prefill_log_probs = ( - torch.nn.functional.log_softmax(prefill_logits.squeeze(0), dim=-1) - .to(torch.float32) - .cpu() - ) - - for i, (req_id, data) in enumerate(request_data.items()): - req_len = data["tokens"].shape[0] - initial_token_offset = data["initial_token_offset"] - - assert len(prefill_log_probs[i]) == req_len, len(prefill_log_probs[i]) - - # Get the prompt tokens for this request and add the new sampled token - request_tokens = data["tokens"][1:].tolist() - request_tokens.append(prefill_new_tokens[i].item()) - - for j, token in enumerate(request_tokens): - assert ( - prefill_log_probs[i][j] - == expected_prefill_log_probs[initial_token_offset + j, token].item() - ) - - # Simulate decode step - # All requests are active, so the mask will be all ones for the current active requests - active_requests_mask = torch.ones(dynamic_context.total_request_count, device='cuda').int() - - dynamic_context.update_requests( - active_requests_mask=active_requests_mask, new_tokens=prefill_new_tokens - ) - - # Generate new logits for the decode step. Now each request contributes 1 token. - decode_logits = torch.randn( - 1, num_active_requests, vocab_size, device='cuda', dtype=torch.float32 - ) - decode_new_tokens = torch.randint(0, 100, (num_active_requests,), device='cuda').long() - decode_log_probs = dynamic_context.calculate_log_probs(decode_logits, decode_new_tokens) - - # Verify the stored decode log probabilities - expected_decode_log_probs = torch.nn.functional.log_softmax( - decode_logits.squeeze(0), dim=-1 - ).to(torch.float32) - - for i, (req_id, data) in enumerate(request_data.items()): - assert len(decode_log_probs[i]) == 1, len(decode_log_probs[i]) - - token = decode_new_tokens[i].item() - assert decode_log_probs[i][0] == expected_decode_log_probs[i, token].item() - - # Simulate mixed prefill and decode step (adding a new request to existing context) - dynamic_context.update_requests( - active_requests_mask=active_requests_mask, new_tokens=prefill_new_tokens - ) - - # Add a new prefill request to the existing context - new_request_id = 1004 - new_request_tokens = torch.randint(0, 100, (12,), device='cuda').long() - new_request_prefill_len = new_request_tokens.shape[0] - initial_token_offset_new_request = dynamic_context.active_token_count - dynamic_context.add_request(new_request_id, new_request_tokens) - request_data[new_request_id] = { - "tokens": new_request_tokens, - "prefill_len": new_request_prefill_len, - "initial_token_offset": initial_token_offset_new_request, - } - - # Simulate the step after adding the new prefill request. - # This step will involve both prefill (for the new request) and decode (for existing requests). - - dynamic_context.initialize_attention_state() - - total_active_tokens_mixed_step = dynamic_context.active_token_count - mixed_step_logits = torch.randn( - 1, total_active_tokens_mixed_step, vocab_size, device='cuda', dtype=torch.float32 - ) - - num_active_requests_mixed_step = ( - dynamic_context.total_request_count - dynamic_context.paused_request_count - ) - mixed_step_new_tokens = torch.randint( - 0, 100, (num_active_requests_mixed_step,), device='cuda' - ).long() - - mixed_step_log_probs = dynamic_context.calculate_log_probs( - mixed_step_logits, mixed_step_new_tokens - ) - - expected_mixed_step_log_probs = ( - torch.nn.functional.log_softmax(mixed_step_logits.squeeze(0), dim=-1) - .to(torch.float32) - .cpu() - ) - - # Verify log probs for the mixed step - current_global_token_offset = 0 - for i, (req_id, data) in enumerate(request_data.items()): - - # This logic needs to consider if the request was new (prefill) or existing (decode) - if req_id == new_request_id: - # This is the newly added prefill request - expected_len = data["prefill_len"] - assert len(mixed_step_log_probs[i]) == expected_len - - # For prefill, the log probs are for tokens[1:] + new_token - prompt_tokens = data["tokens"][1:].tolist() - new_sampled_token = mixed_step_new_tokens[i].item() - - for j in range(expected_len - 1): - # For prompt tokens - assert ( - mixed_step_log_probs[i][j] - == expected_mixed_step_log_probs[ - current_global_token_offset + j, prompt_tokens[j] - ].item() - ) - - # For the newly sampled token - assert ( - mixed_step_log_probs[i][expected_len - 1] - == expected_mixed_step_log_probs[ - current_global_token_offset + expected_len - 1, new_sampled_token - ].item() - ) - - current_global_token_offset += expected_len - - else: - # These are existing requests, now in decode phase - expected_len = 1 - assert len(mixed_step_log_probs[i]) == expected_len - - # For decode, the log prob is for the single new token - new_sampled_token = mixed_step_new_tokens[i].item() - assert ( - mixed_step_log_probs[i][0] - == expected_mixed_step_log_probs[ - current_global_token_offset, new_sampled_token - ].item() - ) - - current_global_token_offset += expected_len diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py deleted file mode 100644 index 56e638b961..0000000000 --- a/tests/unit_tests/inference/engines/test_dynamic_engine.py +++ /dev/null @@ -1,799 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -import asyncio -import random -import types -from dataclasses import dataclass -from typing import List, Optional - -import pytest -import torch -from tqdm import tqdm - -from megatron.core import parallel_state -from megatron.core.inference.contexts.dynamic_context import ( - ActiveRequestCountOverflowError, - ChunkOverflowError, - DynamicInferenceContext, - RequestOverflowError, - TokenOverflowError, - WarmupEngineMode, -) -from megatron.core.inference.engines import DynamicInferenceEngine -from megatron.core.inference.inference_request import DynamicInferenceRequest, Status -from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import ( - GPTInferenceWrapper, -) -from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( - InferenceWrapperConfig, -) -from megatron.core.inference.sampling_params import SamplingParams -from megatron.core.inference.text_generation_controllers.text_generation_controller import ( - TextGenerationController, -) -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.cuda_graphs import CudaGraphManager, _CudagraphGlobalRecord -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import is_fa_min_version -from tests.unit_tests.test_utilities import Utils - - -def set_rounder(value): - """Utility function to set the DynamicInferenceContext rounder.""" - DynamicInferenceContext.ROUNDER = value # For backwards compatibility - DynamicInferenceContext.TOKEN_ROUNDER = value - DynamicInferenceContext.REQUEST_ROUNDER = value - - -class Request: - """Simple class to hold prompt tokens and output tokens.""" - - def __init__(self, prompt: List[int], num_tokens_to_generate: Optional[int] = None): - self.prompt = prompt - self.num_tokens_to_generate = num_tokens_to_generate - self.output = [] - self.state = "queued" - - def __str__(self) -> str: - return "[%s]; prompt len %d; output len %d" % ( - self.state, - len(self.prompt), - len(self.output), - ) - - -@dataclass -class DynamicEngineTestConfig: - """Test configuration args.""" - - set_rounder(4) - num_requests: int = 2 * DynamicInferenceContext.round_up_requests(1, 1) - min_prompt_length: int = 4 - max_prompt_length: int = 16 - max_output_length: int = 4 - max_sequence_length: Optional[int] = None - - num_gap_steps: int = 2 - - context_buffer_size_gb: float = 0.1 # enough room for all tokens. - context_chunk_size_tokens: int = 256 - context_buffer_guaranteed_fraction: float = 0.01 - context_buffer_overflow_factor: Optional[float] = None - context_max_requests_override: Optional[int] = None - context_max_tokens_override: Optional[int] = None - tensor_model_parallel_size: int = 1 - pipeline_model_parallel_size: int = 1 - expert_model_parallel_size: int = 1 - sequence_parallel: bool = False - - use_fixed_output_lengths: bool = False - num_cuda_graphs: int = None - actually_build_cuda_graphs: bool = ( - False # only test_simple requires us to actually build a cuda-graph - ) - return_log_probs: bool = False - materialize_only_last_token_logits: bool = True - skip_prompt_log_probs_for_dynamic_inference: bool = False - - def __post_init__(self): - - # Compute max_sequence_length. - assert self.max_sequence_length is None - self.max_sequence_length = self.max_prompt_length + self.max_output_length - - # Update overrides if not using overflow factor. - if self.context_buffer_overflow_factor is None: - - # Enough room for all requests. - if self.context_max_requests_override is None: - self.context_max_requests_override = self.num_requests - - # Enough room for all tokens. - if self.context_max_tokens_override is None: - self.context_max_tokens_override = self.num_requests * self.max_sequence_length - - -@dataclass -class DynamicEngineTestEnv: - """Test environment, including requests and engine.""" - - config: DynamicEngineTestConfig - sampling_params: SamplingParams - requests: List[Request] - engine: DynamicInferenceEngine - - -class TestDynamicInferenceEngine: - - @classmethod - def _build_requests( - cls, - num_requests: int, - min_prompt_length: int, - max_prompt_length: int, - max_sequence_length: int, - vocab_size: int, - use_fixed_output_lengths: bool = False, - ) -> List[Request]: - prompt_lengths = torch.randint( - min_prompt_length, max_prompt_length + 1, (num_requests,) - ).tolist() - num_tokens_to_generate: List[Optional[int]] - if use_fixed_output_lengths: - num_tokens_to_generate = [ - random.randint(1, max_sequence_length - p) for p in prompt_lengths - ] - else: - num_tokens_to_generate = [None for _ in range(num_requests)] - prompts = [ - torch.randint(0, vocab_size - 1, (length,)).tolist() for length in prompt_lengths - ] - requests = [ - Request(prompt=p, num_tokens_to_generate=n) - for (p, n) in zip(prompts, num_tokens_to_generate) - ] - return requests - - @classmethod - def _build_inference_context( - cls, - test_config: DynamicEngineTestConfig, - transformer_config: TransformerConfig, - requests: List[Request], - ): - """The inference context manages the KV cache and other inference state.""" - - # Inference context. - context = DynamicInferenceContext( - params_dtype=transformer_config.params_dtype, - num_layers=transformer_config.num_layers, - kv_channels=transformer_config.kv_channels, - num_attention_heads=transformer_config.num_query_groups, - max_sequence_length=test_config.max_sequence_length, - num_cuda_graphs=test_config.num_cuda_graphs, - buffer_size_gb=test_config.context_buffer_size_gb, - buffer_guaranteed_fraction=test_config.context_buffer_guaranteed_fraction, - chunk_size_tokens=test_config.context_chunk_size_tokens, - buffer_overflow_factor=test_config.context_buffer_overflow_factor, - max_requests_override=test_config.context_max_requests_override, - max_tokens_override=test_config.context_max_tokens_override, - tensor_model_parallel_size=transformer_config.tensor_model_parallel_size, - materialize_only_last_token_logits=test_config.materialize_only_last_token_logits, - ) - - return context - - @classmethod - def _build_test_env(cls, test_config): - Utils.initialize_model_parallel( - tensor_model_parallel_size=test_config.tensor_model_parallel_size, - pipeline_model_parallel_size=test_config.pipeline_model_parallel_size, - ) - - set_rounder(4) - - random_seed = 123 - vocab_size = 100 - - # Random state. - random.seed(random_seed) - torch.manual_seed(random_seed) - model_parallel_cuda_manual_seed( - seed=random_seed, - inference_rng_tracker=True, - use_cudagraphable_rng=False, - force_reset_rng=True, - ) - - # Transformer config. - transformer_config = TransformerConfig( - params_dtype=torch.bfloat16, - num_layers=4, - hidden_size=32, - num_attention_heads=4, - use_cpu_initialization=True, - enable_cuda_graph=test_config.num_cuda_graphs is not None, - inference_rng_tracker=True, - tensor_model_parallel_size=test_config.tensor_model_parallel_size, - pipeline_model_parallel_size=test_config.pipeline_model_parallel_size, - expert_model_parallel_size=test_config.expert_model_parallel_size, - num_moe_experts=( - None - if test_config.expert_model_parallel_size == 1 - else test_config.expert_model_parallel_size - ), - sequence_parallel=test_config.sequence_parallel, - pipeline_dtype=torch.bfloat16, - add_bias_linear=test_config.expert_model_parallel_size == 1, - ) - - # Requests. - requests = cls._build_requests( - num_requests=test_config.num_requests, - min_prompt_length=test_config.min_prompt_length, - max_prompt_length=test_config.max_prompt_length, - max_sequence_length=test_config.max_sequence_length, - vocab_size=vocab_size, - use_fixed_output_lengths=test_config.use_fixed_output_lengths, - ) - - # Sampling params. - sampling_params = SamplingParams( - num_tokens_to_generate=test_config.max_output_length, - return_log_probs=test_config.return_log_probs, - ) - sampling_params.add_attributes( - { - "skip_prompt_log_probs_for_dynamic_inference": test_config.skip_prompt_log_probs_for_dynamic_inference - } - ) - - # GPT model. - model = GPTModel( - config=transformer_config, - transformer_layer_spec=get_gpt_layer_local_spec(), - vocab_size=vocab_size, - max_sequence_length=test_config.max_sequence_length, - parallel_output=True, - pre_process=parallel_state.is_pipeline_first_stage(), - post_process=parallel_state.is_pipeline_last_stage(), - ).cuda() - - for param in model.parameters(): - param.data = param.data.to(transformer_config.params_dtype) - - model.eval() - - # Inference config. - inference_config = InferenceWrapperConfig( - hidden_size=transformer_config.hidden_size, - inference_batch_times_seqlen_threshold=400, - fp32_residual_connection=False, - params_dtype=transformer_config.params_dtype, - padded_vocab_size=vocab_size, - ) - - # Inference context. - inference_context = cls._build_inference_context( - test_config=test_config, transformer_config=transformer_config, requests=requests - ) - - # Inference model wrapper. - inference_wrapped_model = GPTInferenceWrapper(model, inference_config, inference_context) - - # Note: the following is taken from AbstractModelInferenceWrapper.prep_model_for_inference(). - inference_wrapped_model.model_is_pipeline_parallel = not ( - parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage() - ) - - # Text generation controller. - text_generation_controller = TextGenerationController( - inference_wrapped_model=inference_wrapped_model, - tokenizer=types.SimpleNamespace(vocab_size=vocab_size), - ) - - # Reset global cuda graph state. - _CudagraphGlobalRecord.cudagraph_created = False - _CudagraphGlobalRecord.cudagraph_record = [] - CudaGraphManager.global_mempool = None - - # Inference engine. - engine = DynamicInferenceEngine( - text_generation_controller, - inference_context, - termination_id=vocab_size - 1, - random_seed=random_seed, - enable_cuda_graph=test_config.num_cuda_graphs is not None - and test_config.actually_build_cuda_graphs, - ) - - # Test env. - env = DynamicEngineTestEnv( - config=test_config, sampling_params=sampling_params, requests=requests, engine=engine - ) - - # Mock the detokenize method to return predictable result - def mock_detokenize_prompt(tokens): - return "tokenized_prompt" - - env.engine.controller.tokenizer.detokenize = mock_detokenize_prompt - - return env - - @classmethod - def _run_step(cls, env): - set_rounder(4) - # Step inference engine (i.e., generate one token per request). - active_requests, finished_requests, step_time = env.engine.step( - env.sampling_params, verbose=False - ) - - # Nothing done? - if len(finished_requests) == 0: - return - - # Append output tokens. - for finished_request in finished_requests: - request = env.requests[finished_request.request_id] - request.output = finished_request.generated_tokens - request.state = "finished" - - @classmethod - def _run_test(cls, **test_config_kwargs): - - # Test environment. - test_config = DynamicEngineTestConfig(**test_config_kwargs) - env = cls._build_test_env(test_config) - - # Add requests to engine. - for request_id in tqdm(range(len(env.requests)), "add requests"): - - # Add request. - num_tokens_to_generate = env.requests[request_id].num_tokens_to_generate - env.engine.add_request( - request_id, - env.requests[request_id].prompt, - num_tokens_to_generate=num_tokens_to_generate, - ) - env.requests[request_id].state = "pending" - - # Insert gap steps between adding requests. - for _ in range(test_config.num_gap_steps): - cls._run_step(env) - - # Step engine until finished. - while True: - cls._run_step(env) - if not env.engine.has_unfinished_requests(): - break - - # Validate all requests finished. - for request_id, request in enumerate(env.requests): - assert request.state == "finished", f"request.state == '{request.state}'." - - num_tokens_to_generate = env.requests[request_id].num_tokens_to_generate - assert ( - num_tokens_to_generate is None or len(request.output) == num_tokens_to_generate - ), ( - f"Request {request_id} expected to generate {num_tokens_to_generate} " - f"tokens but generated {len(request.output)}" - ) - - return env - - def teardown_method(self, method): - set_rounder(64) - Utils.destroy_model_parallel() - - @pytest.mark.experimental - @pytest.mark.skipif( - not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" - ) - @pytest.mark.parametrize( - "num_cuda_graphs", [None, 4] - ) # todo: cannot run test case with multiple num_cuda_graphs like [None, 1, 4] - def test_simple(self, num_cuda_graphs) -> None: - """Simple test that runs without errors, and validates output.""" - - # Run test. - env = self._run_test( - num_cuda_graphs=num_cuda_graphs, - actually_build_cuda_graphs=num_cuda_graphs is not None, - context_max_requests_override=32, - ) - - # Validate max_requests, max_tokens. - assert env.engine.context.max_requests == 32 - assert env.engine.context.max_tokens == 160 - - # Validate output tokens. - expected_outputs = [ - [69, 85, 55, 74, 85, 89, 64, 59, 55, 67], - [29, 54, 33, 30, 45, 76, 41, 56, 28, 25, 94, 2, 61, 6, 98], - [35, 78, 54, 32, 79, 98, 22, 5, 60], - [25, 75, 57, 85, 81], - [32, 5, 15, 58, 6, 37, 54, 47, 22, 1, 87, 42, 36, 26, 27, 56], - [85, 51, 88, 62, 71], - [30, 0, 1, 76, 77, 11, 25], - [23, 15, 70, 76, 97, 36, 37, 99], - ] - - assert len(env.requests) == len(expected_outputs) - for request, expected_output in zip(env.requests, expected_outputs): - assert request.output == expected_output - - @pytest.mark.experimental - @pytest.mark.skipif( - not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" - ) - def test_overflow_factor(self) -> None: - """Test overflow factor arg.""" - # Run test. - env = self._run_test( - context_buffer_overflow_factor=0.1, - context_max_requests_override=None, - context_max_tokens_override=None, - ) - - # Validate max_requests, max_tokens. - assert env.engine.context.max_requests == 420 - assert env.engine.context.max_tokens == 420 - - @pytest.mark.experimental - @pytest.mark.skipif( - not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" - ) - def test_request_overflow(self) -> None: - """Test request overflow.""" - self._run_test(context_max_requests_override=1) - - @pytest.mark.skipif( - not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" - ) - def test_token_overflow_transient(self) -> None: - """Test token overflow (transient).""" - test_config = DynamicEngineTestConfig( - min_prompt_length=6, - max_prompt_length=6, - max_output_length=2, - context_max_tokens_override=8, - ) - env = self._build_test_env(test_config) - for request_id, request in enumerate(env.requests): - env.engine.add_request(request_id, request.prompt, request.num_tokens_to_generate) - assert list(env.engine.waiting_request_ids) == [ - 1, - 2, - 3, - 4, - 5, - 6, - 7, - ], f"waiting_request_ids: {list(env.engine.waiting_request_ids)}." - - @pytest.mark.skipif( - not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" - ) - @pytest.mark.skip( - reason="activate for `megatron-core >= 0.15`, after fixing " - "`raise TokenOverflowError(is_transient=False)` compatibility with " - "legacy tests." - ) - def test_token_overflow_nontransient(self) -> None: - """Test token overflow (non-transient).""" - test_config = DynamicEngineTestConfig(context_max_tokens_override=8) - env = self._build_test_env(test_config) - try: - env.engine.add_request( - 0, env.requests[0].prompt, env.requests[0].num_tokens_to_generate - ) - except TokenOverflowError as e: - assert e.is_transient == False - else: - raise Exception("should have raised TokenOverflowError(is_transient=False).") - - @pytest.mark.skipif( - not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" - ) - def test_chunk_overflow(self) -> None: - """Test token overflow.""" - env = self._build_test_env(DynamicEngineTestConfig()) - context = env.engine.context - chunk_size_bytes = context.chunk_size_bytes - buffer_size_gb = (chunk_size_bytes + 1) / 1024**3 - test_config = DynamicEngineTestConfig(context_buffer_size_gb=buffer_size_gb) - env = self._build_test_env(test_config) - env.engine.add_request(0, env.requests[0].prompt, env.requests[0].num_tokens_to_generate) - assert list(env.engine.waiting_request_ids) == [0] - - @pytest.mark.skipif( - not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" - ) - def test_multi_add(self) -> None: - """Test adding multiple requests simultaneously.""" - self._run_test(num_gap_steps=0) - - @pytest.mark.skipif( - not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" - ) - def test_fixed_output_lengths(self) -> None: - """Test generating a fixed number of output tokens.""" - self._run_test(use_fixed_output_lengths=True) - - def test_cuda_graph_token_counts(self) -> None: - """Test initialization of `cuda_graph_token_counts` in dynamic context.""" - - # Test num_cuda_graphs. - for num_cuda_graphs, expected_cuda_graph_token_counts in [ - (0, [64]), - (1, [64]), - (2, [64, 32]), - (4, [64, 48, 32, 16]), - (8, [64, 56, 48, 40, 32, 24, 16, 8]), - (16, [64, 56, 48, 40, 32, 24, 16, 8]), - (64, [64, 56, 48, 40, 32, 24, 16, 8]), - (1024, [64, 56, 48, 40, 32, 24, 16, 8]), - ]: - - # Build cuda graphs (inside dynamic engine). - env = self._build_test_env( - DynamicEngineTestConfig(num_requests=64, num_cuda_graphs=num_cuda_graphs) - ) - actual_cuda_graph_token_counts = env.engine.context.cuda_graph_token_counts - assert ( - actual_cuda_graph_token_counts == expected_cuda_graph_token_counts - ), "num_cuda_graphs %d ... cuda_graph_token_counts: expected %s, found %s." % ( - num_cuda_graphs, - expected_cuda_graph_token_counts, - actual_cuda_graph_token_counts, - ) - - @pytest.mark.parametrize( - "warmup_engine_mode", [WarmupEngineMode.DECODE, WarmupEngineMode.NON_DECODE] - ) - @pytest.mark.parametrize( - "num_warmup_tokens, expected_cuda_graph_token_count", - [ - (1, 8), - (2, 8), - (4, 8), - (8, 8), - (10, 16), - (12, 16), - (16, 16), - (20, 24), - (24, 24), - (28, 32), - (32, 32), - ], - ) - def test_cuda_graph_warmup( - self, - warmup_engine_mode: WarmupEngineMode, - num_warmup_tokens: int, - expected_cuda_graph_token_count: int, - ) -> None: - """Test initialization during cuda graph warmup.""" - if num_warmup_tokens == 1 and warmup_engine_mode == WarmupEngineMode.NON_DECODE: - pytest.skip("WarmupEngineMode.NON_DECODE with num_warmup_tokens=1 is not supported.") - - # Initialize context. - env = self._build_test_env(DynamicEngineTestConfig(num_requests=32, num_cuda_graphs=8)) - - context = env.engine.context - assert context.is_decode_only() - assert context.cuda_graph_token_counts == [ - 32, - 24, - 16, - 8, - ], "cuda_graph_token_counts: %s." % str(context.cuda_graph_token_counts) - - context.initialize_attention_state( - num_warmup_tokens=num_warmup_tokens, warmup_engine_mode=warmup_engine_mode - ) - - # Validate request & token counts. - - assert ( - expected_cuda_graph_token_count - == context.padded_active_request_count - == context.padded_active_token_count - ), ( - "failed ... num_warmup_tokens (%d) ... expected_cuda_graph_request_count (%d) == context.padded_active_request_count (%d) == context.padded_active_token_count (%d)" - % ( - num_warmup_tokens, - expected_cuda_graph_token_count, - context.padded_active_request_count, - context.padded_active_token_count, - ) - ) - - # Validate input/position dimensions. - input_ids, pos_ids = context.current_input_and_position_ids() - assert input_ids.shape[1] == pos_ids.shape[1] == expected_cuda_graph_token_count - assert context.using_cuda_graph_this_step, ( - "expected `using_cuda_graph_this_step` to be True for decode step with " - "num_warmup_tokens <= max_requests." - ) - context.reset() - - # Test active request count overflow - for num_warmup_tokens in (64, 128, 1024): - try: - context.initialize_attention_state( - num_warmup_tokens=num_warmup_tokens, warmup_engine_mode=warmup_engine_mode - ) - except ActiveRequestCountOverflowError as e: - continue - raise Exception("`ActiveRequestCountOverflowError should have been raised.") - - context.reset() - - # test the case where the active token count exceeds max requests. - # expectation: we should be in non-decode mode and not using cuda graphs - - # add all requests to the context. - for request_id in tqdm(range(len(env.requests)), "add requests"): - env.engine.add_request( - request_id, env.requests[request_id].prompt, num_tokens_to_generate=1 - ) - - # we should now have more active tokens than max requests. - context.initialize_attention_state() - assert not context.is_decode_only() - assert not context.using_cuda_graph_this_step(), ( - "expected `using_cuda_graph_this_step` to be False for non-decode step where " - "the active token count exceeds max requests" - ) - context.reset() - - @pytest.mark.skipif( - not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" - ) - def test_generate_function(self) -> None: - """Test the generate function that processes multiple prompts at once.""" - # Set up test environment - test_config = DynamicEngineTestConfig( - num_requests=4, max_prompt_length=8, max_output_length=4 - ) - env = self._build_test_env(test_config) - - # Create string prompts (just mock strings, since the test environment mocks the tokenizer) - prompts = ["prompt1", "prompt2", "prompt3", "prompt4"] - - # Mock the tokenize_prompt method to return predictable token sequences - def mock_tokenize_prompt(prompt): - # Return a token sequence based on the prompt number - prompt_num = int(prompt[-1]) - return [10 + i for i in range(prompt_num + 2)] - - env.engine.controller.tokenize_prompt = mock_tokenize_prompt - - # Call the generate function - finished_requests = env.engine.generate(prompts, env.sampling_params) - - # Verify results - assert len(finished_requests) == len( - prompts - ), "Should return same number of finished requests as prompts" - print() - # Check each request was processed - for i, request in enumerate(finished_requests): - # Verify each request has generated tokens - assert len(request.generated_tokens) > 0, f"Request {i} should have generated tokens" - assert request.status == Status.COMPLETED, f"Request {i} should be completed" - - @pytest.mark.asyncio - @pytest.mark.skipif( - not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" - ) - async def test_run_engine(self): - """ - Test asynchronously adding and waiting for requests while the engine is - running continuously. - """ - # Test environment. - test_config = DynamicEngineTestConfig(use_fixed_output_lengths=True) - env = self._build_test_env(test_config) - - engine_task = asyncio.create_task( - env.engine.run_engine(sampling_params=env.sampling_params, verbose=False) - ) - - request_completion_futures: Dict[int, asyncio.Future[DynamicInferenceRequest]] = {} - - # Add requests to engine. - for request_id in tqdm(range(len(env.requests)), "add requests"): - - # Add request. - num_tokens_to_generate = env.requests[request_id].num_tokens_to_generate - request_completion_futures[request_id] = env.engine.add_request( - request_id, - env.requests[request_id].prompt, - num_tokens_to_generate=num_tokens_to_generate, - ) - env.requests[request_id].state = "pending" - - # Wait for all requests to complete. - await asyncio.gather(*request_completion_futures.values()) - - # Verify that all request outputs were set. - for request_id, fut in request_completion_futures.items(): - num_tokens_to_generate = env.requests[request_id].num_tokens_to_generate - result = fut.result() - assert result.generated_length == num_tokens_to_generate, ( - f"Request {request_id} expected to generate {num_tokens_to_generate} " - f"tokens but generated {result.generated_length}" - ) - - engine_task.cancel() - - @pytest.mark.skipif( - not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" - ) - def test_return_log_probs(self): - """Verify that returning log probs does not raise any error.""" - # Returning log probs requires materializing the full prompt logits or - # explicitly disabling prompt logits. - with pytest.raises(AssertionError): - env = self._run_test(return_log_probs=True, materialize_only_last_token_logits=True) - env = self._run_test(return_log_probs=True, materialize_only_last_token_logits=False) - env = self._run_test( - return_log_probs=True, - materialize_only_last_token_logits=True, - skip_prompt_log_probs_for_dynamic_inference=True, - ) - - @pytest.mark.skipif( - not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" - ) - @pytest.mark.parametrize("materialize_only_last_token_logits", [False, True]) - @pytest.mark.parametrize("sequence_parallel", [False, True]) - @pytest.mark.parametrize("ep_size", [1, 2]) - @pytest.mark.parametrize("pp_size", [1, 2]) - @pytest.mark.parametrize("tp_size", [1, 2]) - def test_parallel_inference( - self, tp_size, pp_size, ep_size, sequence_parallel, materialize_only_last_token_logits - ): - if tp_size == 1 and pp_size == 1 and ep_size == 1: - pytest.skip(reason="Test requires tp_size > 1 or pp_size > 1 or ep_size > 1") - elif not torch.distributed.is_initialized(): - pytest.skip("Distributed not initialized") - world_size = torch.distributed.get_world_size() - min_world_size = tp_size * pp_size * ep_size - if world_size < min_world_size: - pytest.skip(f"Test requires at least {min_world_size} GPUs") - elif tp_size == 1 and sequence_parallel: - pytest.skip(reason="Sequence parallelism requires tp_size > 1") - elif tp_size > 1 and ep_size > 1 and not sequence_parallel: - pytest.skip(reason="Sequence parallelism must be used with tp_size > 1 and ep_size > 1") - env = self._run_test( - tensor_model_parallel_size=tp_size, - pipeline_model_parallel_size=pp_size, - expert_model_parallel_size=ep_size, - sequence_parallel=sequence_parallel, - materialize_only_last_token_logits=materialize_only_last_token_logits, - ) - - -if __name__ == "__main__": - test = TestDynamicInferenceEngine() - test.test_simple() - test.test_overflow_factor() - test.test_request_overflow() - test.test_token_overflow_transient() - test.test_token_overflow_nontransient() - test.test_chunk_overflow() - test.test_multi_add() - test.test_fixed_output_lengths() - test.test_cuda_graph_request_counts() - test.test_cuda_graph_warmup() - test.test_generate_function() - asyncio.run(test.test_run_engine()) - test.test_return_log_probs() - test.teardown_method(None) - print("~~~") - print("success.") diff --git a/tests/unit_tests/inference/engines/test_static_engine.py b/tests/unit_tests/inference/engines/test_static_engine.py deleted file mode 100644 index 055a59c317..0000000000 --- a/tests/unit_tests/inference/engines/test_static_engine.py +++ /dev/null @@ -1,303 +0,0 @@ -import asyncio -import random -import string -from typing import AsyncGenerator, List, Union -from unittest import mock - -import pytest -import torch - -from megatron.core import parallel_state -from megatron.core.inference.contexts import StaticInferenceContext -from megatron.core.inference.engines import StaticInferenceEngine -from megatron.core.inference.inference_request import InferenceRequest, Status -from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import ( - GPTInferenceWrapper, -) -from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( - InferenceWrapperConfig, -) -from megatron.core.inference.sampling_params import SamplingParams -from megatron.core.inference.text_generation_controllers.text_generation_controller import ( - TextGenerationController, -) -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.transformer_config import TransformerConfig -from tests.unit_tests.test_utilities import Utils - - -class TestStaticInferenceEngine: - def setup_engine( - self, - engine_max_batch_size=None, - vocab_size=100, - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, - expert_model_parallel_size=1, - sequence_parallel=False, - ): - Utils.initialize_model_parallel( - tensor_model_parallel_size=tensor_model_parallel_size, - pipeline_model_parallel_size=pipeline_model_parallel_size, - ) - - model_parallel_cuda_manual_seed(123) - self.batch_size = 4 - self.hidden_size = 12 - self.vocab_size = vocab_size - self.sequence_length = 64 - transformer_config = TransformerConfig( - num_layers=4, - hidden_size=self.hidden_size, - num_attention_heads=4, - use_cpu_initialization=True, - inference_rng_tracker=True, - tensor_model_parallel_size=tensor_model_parallel_size, - pipeline_model_parallel_size=pipeline_model_parallel_size, - expert_model_parallel_size=expert_model_parallel_size, - num_moe_experts=None if expert_model_parallel_size == 1 else expert_model_parallel_size, - sequence_parallel=sequence_parallel, - pipeline_dtype=torch.bfloat16, - add_bias_linear=expert_model_parallel_size == 1, - ) - - gpt_model = GPTModel( - config=transformer_config, - transformer_layer_spec=get_gpt_layer_local_spec(), - vocab_size=self.vocab_size, - max_sequence_length=self.sequence_length, - parallel_output=True, - pre_process=parallel_state.is_pipeline_first_stage(), - post_process=parallel_state.is_pipeline_last_stage(), - ).cuda() - - inference_wrapper_config = InferenceWrapperConfig( - hidden_size=self.hidden_size, - inference_batch_times_seqlen_threshold=400, - inference_max_requests=self.batch_size, - fp32_residual_connection=False, - params_dtype=torch.float, - padded_vocab_size=self.vocab_size, - ) - - inference_context = StaticInferenceContext.from_config(inference_wrapper_config) - - inference_wrapped_model = GPTInferenceWrapper( - gpt_model, inference_wrapper_config, inference_context - ) - self.mock_tokenizer = mock.Mock() - text_generation_controller = TextGenerationController( - inference_wrapped_model=inference_wrapped_model, tokenizer=self.mock_tokenizer - ) - - if engine_max_batch_size is not None and engine_max_batch_size > self.batch_size: - with pytest.warns(UserWarning): - self.static_engine = StaticInferenceEngine( - text_generation_controller=text_generation_controller, - max_batch_size=engine_max_batch_size, - ) - else: - self.static_engine = StaticInferenceEngine( - text_generation_controller=text_generation_controller, - max_batch_size=engine_max_batch_size, - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.parametrize( - "batch_size,num_trials,empty_prompt", - [(4, 1, False), (4, 1, True), (4, 3, False), (2, 1, False), (8, 1, False)], - ) - def test_generate(self, batch_size: int, num_trials: int, empty_prompt: bool): - self.setup_engine(engine_max_batch_size=batch_size) - self.mock_tokenizer.vocab_size = self.vocab_size - self.mock_tokenizer.eod = self.vocab_size - 1 - # Generating random length integer prompts - self.mock_tokenizer.tokenize.return_value = [ - random.randint(0, self.vocab_size - 1) for _ in range(random.randint(5, 10)) - ] - # Generates some random string - self.mock_tokenizer.detokenize.return_value = ''.join( - random.choices(string.ascii_letters, k=random.randint(4, 10)) - ) - - for _ in range(num_trials): - if empty_prompt: - prompts = ["" for i in range(batch_size)] - else: - prompts = ["sample" * (i + 1) for i in range(batch_size)] - results: List[InferenceRequest] = self.static_engine.generate( - prompts, sampling_params=SamplingParams(num_tokens_to_generate=10) - ) - - assert len(results) == batch_size - for result in results: - assert ( - result.status == Status.COMPLETED - ), f"Status should be completed but its {result.status}" - assert result.generated_length > 0, f"Generated length should be greater than zero" - assert result.generated_text is not None, f'Generated text should not be None' - - @pytest.mark.asyncio - async def test_streaming(self): - self.setup_engine() - - async def collect_stream(stream_generator, num_tokens_to_generate): - prev_log_probs = None - prev_text = "" - prev_idx = 0 - prev_length = 0 - num_output_tokens = 0 - async for output in stream_generator: - num_output_tokens += 1 - assert isinstance( - output, InferenceRequest - ), f"Expected InferenceRequest, got {type(output)}" - assert output.generated_log_probs is not None, f"Expected log probs tensor" - assert ( - output.generated_tokens.shape[0] == output.generated_length - ), f"Expected log probs length to match # generated tokens" - assert ( - len(output.generated_log_probs) == output.generated_length - ), f"Expected log probs length to match # generated tokens" - assert output.generated_length > prev_length, f"Expected generated length to grow" - assert ( - output.generated_text[:prev_idx] == prev_text - ), f"Expected generated text to match previous text" - assert ( - prev_log_probs is None or prev_log_probs == output.generated_log_probs[:-1] - ), f"Expected previous log probs to match new log probs" - prev_length = output.generated_length - prev_text = output.generated_text - prev_idx = len(output.generated_text) - prev_log_probs = output.generated_log_probs - - assert ( - num_output_tokens == num_tokens_to_generate - ), f"Should have streamed {num_tokens_to_generate} tokens but actually streamed {num_output_tokens}" - assert ( - len(output.generated_tokens) == num_tokens_to_generate - ), f"Should have included {num_tokens_to_generate} tokens but actually returned {len(output.generated_tokens)}" - assert ( - len(output.generated_log_probs) == num_tokens_to_generate - ), f"Should have included {num_tokens_to_generate} log probs but actually returned {len(output.generated_log_probs)}" - - return output - - self.mock_tokenizer.vocab_size = self.vocab_size - self.mock_tokenizer.eod = self.vocab_size - 1 - self.mock_tokenizer.bos = self.vocab_size - 2 - # Generating random length integer prompts - self.mock_tokenizer.tokenize.return_value = [ - random.randint(0, self.vocab_size - 1) for _ in range(random.randint(5, 10)) - ] - # Generates some random string - self.mock_tokenizer.detokenize.return_value = ''.join( - random.choices(string.ascii_letters, k=random.randint(4, 10)) - ) - - prompts = ["" for i in range(self.batch_size)] - - num_tokens_to_generate = 10 - sampling_params = SamplingParams( - num_tokens_to_generate=num_tokens_to_generate, return_log_probs=True - ) - request_ids: List[str] = [ - self.static_engine.add_request( - prompt, add_BOS=True, sampling_params=sampling_params, streaming=True - ) - for prompt in prompts - ] - stream_generators: List[AsyncGenerator[InferenceRequest, None]] = [ - self.static_engine.get_stream_generator(request_id) for request_id in request_ids - ] - assert all(stream_generator is not None for stream_generator in stream_generators) - - tasks = [ - asyncio.create_task(collect_stream(stream_generator, num_tokens_to_generate)) - for stream_generator in stream_generators - ] - - await self.static_engine.run_engine_async() - final_streamed_tokens: List[InferenceRequest] = await asyncio.gather(*tasks) - results: List[InferenceRequest] = [ - self.static_engine.scheduler.completed_request_pool[request_id] - for request_id in request_ids - ] - assert len(final_streamed_tokens) == len(results) - for result, final_streamed_token in zip(results, final_streamed_tokens): - assert torch.equal( - result.generated_tokens.cpu(), final_streamed_token.generated_tokens.cpu() - ), ( - f"result.generated_tokens={result.generated_tokens.cpu()}," - f"final_streamed_token.generated_tokens={final_streamed_token.generated_tokens}" - ) - assert result.generated_log_probs == final_streamed_token.generated_log_probs, ( - f"result.generated_log_probs={result.generated_log_probs}, " - f"final_streamed_token.generated_log_probs={final_streamed_token.generated_log_probs}" - ) - - @pytest.mark.parametrize("sequence_parallel", [False, True]) - @pytest.mark.parametrize("ep_size", [1, 2]) - @pytest.mark.parametrize("pp_size", [1, 2]) - @pytest.mark.parametrize("tp_size", [1, 2]) - def test_parallel_inference(self, tp_size, pp_size, ep_size, sequence_parallel): - if tp_size == 1 and pp_size == 1 and ep_size == 1: - pytest.skip(reason="Test requires tp_size > 1 or pp_size > 1 or ep_size > 1") - elif not torch.distributed.is_initialized(): - pytest.skip("Distributed not initialized") - world_size = torch.distributed.get_world_size() - min_world_size = tp_size * pp_size * ep_size - if world_size < min_world_size: - pytest.skip(f"Test requires at least {min_world_size} GPUs") - elif tp_size == 1 and sequence_parallel: - pytest.skip(reason="Sequence parallelism requires tp_size > 1") - elif tp_size > 1 and ep_size > 1 and not sequence_parallel: - pytest.skip(reason="Sequence parallelism must be used with tp_size > 1 and ep_size > 1") - - batch_size = 8 - - self.setup_engine( - engine_max_batch_size=batch_size, - tensor_model_parallel_size=tp_size, - pipeline_model_parallel_size=pp_size, - expert_model_parallel_size=ep_size, - sequence_parallel=sequence_parallel, - ) - self.mock_tokenizer.vocab_size = self.vocab_size - self.mock_tokenizer.eod = -1 - - random.seed(42) - - # Generating random length integer prompts, ensuring sequence length is divisible by TP size - self.mock_tokenizer.tokenize.return_value = [ - random.randint(0, self.vocab_size - 1) for _ in range(32) - ] - # Generates some random string - self.mock_tokenizer.detokenize.return_value = ''.join( - random.choices(string.ascii_letters, k=random.randint(4, 10)) - ) - - prompts = ["sample" * (i + 1) for i in range(batch_size)] - - if sequence_parallel and (ep_size == 1 or tp_size == 1): - with pytest.raises(NotImplementedError): - results: List[InferenceRequest] = self.static_engine.generate( - prompts, sampling_params=SamplingParams(num_tokens_to_generate=10) - ) - return - else: - results: List[InferenceRequest] = self.static_engine.generate( - prompts, sampling_params=SamplingParams(num_tokens_to_generate=10) - ) - - assert len(results) == batch_size - for result in results: - assert ( - result.status == Status.COMPLETED - ), f"Status should be completed but its {result.status}" - assert result.generated_length > 0, f"Generated length should be greater than zero" - assert result.generated_text is not None, f'Generated text should not be None' diff --git a/tests/unit_tests/inference/model_inference_wrappers/gpt/test_gpt_inference_wrapper.py b/tests/unit_tests/inference/model_inference_wrappers/gpt/test_gpt_inference_wrapper.py deleted file mode 100644 index 644cb14998..0000000000 --- a/tests/unit_tests/inference/model_inference_wrappers/gpt/test_gpt_inference_wrapper.py +++ /dev/null @@ -1,176 +0,0 @@ -from argparse import Namespace - -import pytest -import torch - -from megatron.core import parallel_state -from megatron.core.inference.contexts import StaticInferenceContext -from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import ( - GPTInferenceWrapper, -) -from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( - InferenceWrapperConfig, -) -from megatron.core.models.gpt.gpt_layer_specs import ( - get_gpt_layer_local_spec, - get_gpt_layer_with_transformer_engine_spec, -) -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.transformer_config import TransformerConfig -from tests.unit_tests.test_utilities import Utils - - -class TestGPTInferenceWrapper: - - def setup_model(self, tensor_parallel_size, pipeline_parallel_size): - Utils.initialize_model_parallel( - tensor_model_parallel_size=tensor_parallel_size, - pipeline_model_parallel_size=pipeline_parallel_size, - ) - model_parallel_cuda_manual_seed(123) - self.vocab_size = 100 - self.batch_size = 4 - self.sequence_length = 32 - hidden_size = 12 - - transformer_config = TransformerConfig( - num_layers=4, - hidden_size=hidden_size, - num_attention_heads=4, - use_cpu_initialization=True, - ) - - gpt_model = GPTModel( - config=transformer_config, - transformer_layer_spec=get_gpt_layer_local_spec(), - vocab_size=self.vocab_size, - max_sequence_length=self.sequence_length, - parallel_output=True, - pre_process=parallel_state.is_pipeline_first_stage(), - post_process=parallel_state.is_pipeline_last_stage(), - ).cuda() - - inference_wrapper_config = InferenceWrapperConfig( - hidden_size=hidden_size, - inference_batch_times_seqlen_threshold=20, - inference_max_requests=self.batch_size, - fp32_residual_connection=False, - params_dtype=torch.float, - padded_vocab_size=self.vocab_size, - ) - - inference_context = StaticInferenceContext.from_config(inference_wrapper_config) - - self.inference_wrapped_model = GPTInferenceWrapper( - gpt_model, inference_wrapper_config, inference_context - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - # This will call the inference_wrapped_model.forward_pass_with_pipeline_parallel_small_input_batch() - @pytest.mark.parametrize("materialize_only_last_token_logits", [True, False]) - def test_inference_pipeline_parallel_small_size(self, materialize_only_last_token_logits): - self.setup_model(tensor_parallel_size=2, pipeline_parallel_size=2) - - batch_prompt_tokens = ( - torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.sequence_length)) - .int() - .cuda() - ) - self.inference_wrapped_model.prep_model_for_inference() - self.inference_wrapped_model.inference_context.materialize_only_last_token_logits = ( - materialize_only_last_token_logits - ) - - inference_input = self.inference_wrapped_model.prep_inference_input( - prompts_tokens=batch_prompt_tokens - ) - - inference_input_for_context_window = ( - self.inference_wrapped_model.get_batch_for_context_window(inference_input, 0, 5) - ) - - logits_seq_len = 1 if materialize_only_last_token_logits else 5 - - logits = self.inference_wrapped_model.run_one_forward_step( - inference_input_for_context_window - ) - # Logits are not returned in all ranks in PP - if parallel_state.is_pipeline_last_stage(): - assert logits.shape == ( - self.batch_size, - logits_seq_len, - self.vocab_size, - ), f"Shape mismatch . Expected {(self.batch_size, logits_seq_len, self.vocab_size)}, but got {logits.shape}" - - # This will call the inference_wrapped_model.forward_pass_with_pipeline_parallel_large_input_batch() - @pytest.mark.parametrize("materialize_only_last_token_logits", [True, False]) - def test_inference_pipeline_parallel_large_size(self, materialize_only_last_token_logits): - self.setup_model(tensor_parallel_size=2, pipeline_parallel_size=2) - - batch_prompt_tokens = ( - torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.sequence_length)) - .int() - .cuda() - ) - self.inference_wrapped_model.prep_model_for_inference() - self.inference_wrapped_model.inference_context.materialize_only_last_token_logits = ( - materialize_only_last_token_logits - ) - - inference_input = self.inference_wrapped_model.prep_inference_input( - prompts_tokens=batch_prompt_tokens - ) - - inference_input_for_context_window = ( - self.inference_wrapped_model.get_batch_for_context_window(inference_input, 0, 10) - ) - - logits_seq_len = 1 if materialize_only_last_token_logits else 10 - - logits = self.inference_wrapped_model.run_one_forward_step( - inference_input_for_context_window - ) - - if parallel_state.is_pipeline_last_stage(): - assert logits.shape == ( - self.batch_size, - logits_seq_len, - self.vocab_size, - ), f"Shape mismatch . Expected {(self.batch_size, logits_seq_len, self.vocab_size)}, but got {logits.shape}" - - @pytest.mark.parametrize("materialize_only_last_token_logits", [True, False]) - def test_inference_only_tensor_parallel(self, materialize_only_last_token_logits): - self.setup_model(tensor_parallel_size=4, pipeline_parallel_size=1) - - batch_prompt_tokens = ( - torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.sequence_length)) - .int() - .cuda() - ) - self.inference_wrapped_model.prep_model_for_inference() - self.inference_wrapped_model.inference_context.materialize_only_last_token_logits = ( - materialize_only_last_token_logits - ) - - inference_input = self.inference_wrapped_model.prep_inference_input( - prompts_tokens=batch_prompt_tokens - ) - - inference_input_for_context_window = ( - self.inference_wrapped_model.get_batch_for_context_window(inference_input, 0, 5) - ) - - logits_seq_len = 1 if materialize_only_last_token_logits else 5 - - logits = self.inference_wrapped_model.run_one_forward_step( - inference_input_for_context_window - ) - - assert logits.shape == ( - self.batch_size, - logits_seq_len, - self.vocab_size, - ), f"Shape mismatch . Expected {(self.batch_size, logits_seq_len, self.vocab_size)}, but got {logits.shape}" diff --git a/tests/unit_tests/inference/model_inference_wrappers/t5/test_t5_inference_wrapper.py b/tests/unit_tests/inference/model_inference_wrappers/t5/test_t5_inference_wrapper.py deleted file mode 100644 index 36d5187b5e..0000000000 --- a/tests/unit_tests/inference/model_inference_wrappers/t5/test_t5_inference_wrapper.py +++ /dev/null @@ -1,137 +0,0 @@ -from argparse import Namespace -from copy import deepcopy -from unittest import mock - -import numpy as np -import torch - -from megatron.core import parallel_state -from megatron.core.inference.contexts import StaticInferenceContext -from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( - InferenceWrapperConfig, -) -from megatron.core.inference.model_inference_wrappers.t5.t5_inference_wrapper import ( - T5InferenceWrapper, -) -from megatron.core.models.T5.t5_model import T5Model -from megatron.core.models.T5.t5_spec import ( - get_t5_decoder_with_transformer_engine_block_spec, - get_t5_encoder_with_transformer_engine_block_spec, -) -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.enums import AttnBackend -from megatron.core.transformer.transformer_config import TransformerConfig -from tests.unit_tests.test_utilities import Utils - - -class TestT5InferenceWrapper: - - def setup_model(self, tensor_parallel_size, pipeline_parallel_size): - Utils.initialize_model_parallel( - tensor_model_parallel_size=tensor_parallel_size, - pipeline_model_parallel_size=pipeline_parallel_size, - ) - model_parallel_cuda_manual_seed(123) - self.vocab_size = 100 - self.batch_size = 8 - self.encoder_sequence_length = 32 - self.decoder_sequence_length = 16 - hidden_size = 768 - - transformer_config = TransformerConfig( - num_layers=12, - hidden_size=hidden_size, - num_attention_heads=12, - tensor_model_parallel_size=tensor_parallel_size, - pipeline_model_parallel_size=pipeline_parallel_size, - attention_backend=AttnBackend.unfused, - ) - - encoder_config = deepcopy(transformer_config) - encoder_config.num_layers = transformer_config.num_layers - - encoder_layers_per_pipeline = ( - encoder_config.num_layers // encoder_config.pipeline_model_parallel_size - ) - decoder_layers_per_pipeline = ( - transformer_config.num_layers // transformer_config.pipeline_model_parallel_size - ) - en_block_spec = get_t5_encoder_with_transformer_engine_block_spec( - encoder_layers_per_pipeline - ) - de_block_spec = get_t5_decoder_with_transformer_engine_block_spec( - decoder_layers_per_pipeline - ) - - t5_model = T5Model( - config=transformer_config, - encoder_config=encoder_config, - transformer_encoder_layer_spec=en_block_spec, - transformer_decoder_layer_spec=de_block_spec, - vocab_size=self.vocab_size, - max_sequence_length=self.encoder_sequence_length, - parallel_output=True, - pre_process=True, - post_process=True, - add_encoder=True, - add_decoder=True, - ).cuda() - - inference_wrapper_config = InferenceWrapperConfig( - hidden_size=hidden_size, - inference_batch_times_seqlen_threshold=-1, - fp32_residual_connection=False, - params_dtype=torch.float, - padded_vocab_size=self.vocab_size, - ) - - inference_context = StaticInferenceContext.from_config(inference_wrapper_config) - - self.inference_wrapped_model = T5InferenceWrapper( - t5_model, inference_wrapper_config, inference_context - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_inference_only_tensor_parallel(self): - self.setup_model(tensor_parallel_size=4, pipeline_parallel_size=1) - - batch_prompt_tokens = ( - torch.randint( - low=0, high=self.vocab_size, size=(self.batch_size, self.decoder_sequence_length) - ) - .int() - .cuda() - ) - batch_encoder_prompts = ["sample prompt encoders"] * self.batch_size - mock_tokenizer = mock.Mock() - mock_tokenizer.pad = self.vocab_size - 1 - mock_tokenizer.additional_special_tokens_ids = list(range(100)) - mock_tokenizer.tokenize.return_value = np.random.randint( - self.vocab_size, size=self.encoder_sequence_length - ).tolist() - - self.inference_wrapped_model.prep_model_for_inference() - - inference_input = self.inference_wrapped_model.prep_inference_input( - prompts_tokens=batch_prompt_tokens, - encoder_prompts=batch_encoder_prompts, - tokenizer=mock_tokenizer, - ) - - inference_input_for_context_window = ( - self.inference_wrapped_model.get_batch_for_context_window( - inference_input, 0, self.decoder_sequence_length - ) - ) - - logits = self.inference_wrapped_model.run_one_forward_step( - inference_input_for_context_window - ) - - assert logits.shape == ( - self.batch_size, - self.decoder_sequence_length, - self.vocab_size, - ), f"Shape mismatch . Expected {(self.batch_size, self.decoder_sequence_length, self.vocab_size)}, but got {logits.shape}" diff --git a/tests/unit_tests/inference/model_inference_wrappers/test_model_inference_wrapper_config.py b/tests/unit_tests/inference/model_inference_wrappers/test_model_inference_wrapper_config.py deleted file mode 100644 index 794634760d..0000000000 --- a/tests/unit_tests/inference/model_inference_wrappers/test_model_inference_wrapper_config.py +++ /dev/null @@ -1,21 +0,0 @@ -import torch - -from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( - InferenceWrapperConfig, -) - - -class TestModelInferenceWrapperConfig: - - def test_inference_config(self): - inference_config = InferenceWrapperConfig( - hidden_size=10, - inference_batch_times_seqlen_threshold=10, - padded_vocab_size=10, - params_dtype=torch.float, - fp32_residual_connection=False, - ) - inference_config.add_attributes({"abc": 45}) - assert ( - inference_config.abc == 45 - ), f"min tokens not set correctly. it is {inference_config.min_tokens}" diff --git a/tests/unit_tests/inference/test_common_inference_params.py b/tests/unit_tests/inference/test_common_inference_params.py deleted file mode 100644 index c80cd2ab29..0000000000 --- a/tests/unit_tests/inference/test_common_inference_params.py +++ /dev/null @@ -1,11 +0,0 @@ -from megatron.core.inference.sampling_params import SamplingParams - - -class TestSamplingParams: - - def test_sampling_params(self): - sampling_params = SamplingParams() - sampling_params.add_attributes({"min_tokens": 45}) - assert ( - sampling_params.min_tokens == 45 - ), f"min tokens not set correctly. it is {sampling_params.min_tokens}" diff --git a/tests/unit_tests/inference/test_communication_utils.py b/tests/unit_tests/inference/test_communication_utils.py deleted file mode 100644 index 95de6c7056..0000000000 --- a/tests/unit_tests/inference/test_communication_utils.py +++ /dev/null @@ -1,129 +0,0 @@ -import pytest -import torch -import torch.distributed as dist - -from megatron.core import parallel_state -from megatron.core.hyper_comm_grid import HyperCommGrid -from megatron.core.inference.communication_utils import ( - broadcast_from_last_pipeline_stage, - recv_from_prev_pipeline_rank_, - send_to_next_pipeline_rank, -) -from megatron.core.utils import is_torch_min_version -from tests.unit_tests.test_utilities import Utils - - -class TestCommunicationWithCustomPPGroup: - """Test suite comparing communication with and without custom pp_group.""" - - @pytest.fixture(autouse=True) - def setup(self): - """Set up test parameters.""" - self.size = [16, 8] - self.dtype = torch.float32 - - @pytest.mark.skipif( - not is_torch_min_version("2.4.0"), - reason="torch.distributed.init_device_mesh requires torch >= 2.4.0", - ) - @pytest.mark.parametrize("tp_size,pp_size", [(1, 8), (2, 4), (4, 2)]) - def test_broadcast_comparison(self, tp_size, pp_size): - """Test broadcast with different parallel configurations.""" - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp_size, pipeline_model_parallel_size=pp_size - ) - - rank = dist.get_rank() - - device = torch.device(f"cuda:{rank}") - - # Set a random seed based on rank for reproducibility but different values - torch.manual_seed(rank) - - local_tensor = torch.randn(self.size, dtype=self.dtype, device=device) - - # Broadcast using global state - tensor_received_global = broadcast_from_last_pipeline_stage( - size=self.size, dtype=self.dtype, tensor=local_tensor - ) - - # Initialize torch.distributed if not already initialized - if not dist.is_initialized(): - dist.init_process_group(backend='nccl') - - # Note: HyperCommGrid uses minor-to-major order (tp, pp), which is reverse of device mesh - grid = HyperCommGrid([tp_size, pp_size], ["tp", "pp"]) - pp_group = grid.create_pg("pp") - - # Broadcast using custom pp_group - tensor_received_custom = broadcast_from_last_pipeline_stage( - size=self.size, dtype=self.dtype, tensor=local_tensor, pp_group=pp_group - ) - - # Synchronize before test - dist.barrier() - assert torch.allclose( - tensor_received_global, tensor_received_custom - ), "broadcast_from_last_pipeline_stage should be the same with or without custom pp_group" - Utils.destroy_model_parallel() - - @pytest.mark.skipif( - not is_torch_min_version("2.4.0"), - reason="torch.distributed.init_device_mesh requires torch >= 2.4.0", - ) - @pytest.mark.parametrize("tp_size,pp_size", [(1, 8), (2, 4), (4, 2)]) - def test_send_recv(self, tp_size, pp_size): - """Test send/recv in a ring pattern with different configs.""" - # Initialize model parallel for this test - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp_size, pipeline_model_parallel_size=pp_size - ) - - # Get rank info - rank = dist.get_rank() - - # Set a random seed based on rank for reproducibility but different values - torch.manual_seed(rank) - - # Create unique random data for this rank - device = torch.device(f"cuda:{rank}") - local_send_data = torch.randn(self.size, dtype=self.dtype, device=device) - - # Synchronize before test - dist.barrier() - - # Send/recv using global state - if not parallel_state.is_pipeline_first_stage(): - local_recv_buffer_global = torch.zeros(self.size, dtype=self.dtype, device=device) - recv_from_prev_pipeline_rank_(recv_buffer=local_recv_buffer_global) - else: - local_recv_buffer_global = torch.zeros(self.size, dtype=self.dtype, device=device) - - if not parallel_state.is_pipeline_last_stage(): - send_to_next_pipeline_rank(tensor=local_send_data) - - dist.barrier() - - # Initialize torch.distributed if not already initialized - if not dist.is_initialized(): - dist.init_process_group(backend='nccl') - - # Note: HyperCommGrid uses minor-to-major order (tp, pp), which is reverse of device mesh - grid = HyperCommGrid([tp_size, pp_size], ["tp", "pp"]) - pp_group = grid.create_pg("pp") - - # Send/recv using custom pp_group - if pp_group.rank() != 0: - local_recv_buffer_custom = torch.zeros(self.size, dtype=self.dtype, device=device) - recv_from_prev_pipeline_rank_(recv_buffer=local_recv_buffer_custom, pp_group=pp_group) - else: - local_recv_buffer_custom = torch.zeros(self.size, dtype=self.dtype, device=device) - - if pp_group.rank() != pp_group.size() - 1: - send_to_next_pipeline_rank(tensor=local_send_data, pp_group=pp_group) - - dist.barrier() - assert torch.allclose( - local_recv_buffer_global, local_recv_buffer_custom - ), "Custom and global recv buffers should be the same." - Utils.destroy_model_parallel() diff --git a/tests/unit_tests/inference/test_flash_decode.py b/tests/unit_tests/inference/test_flash_decode.py deleted file mode 100644 index 77ac08c061..0000000000 --- a/tests/unit_tests/inference/test_flash_decode.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch - -from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb_with_cos_sin -from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding - - -class TestRotaryEmbeddingWithPrecomputedCosSin: - - def setup_method(self): - self.batch_size = 3 - self.seq_len = 4 - self.d_rot = 6 - self.rotary_embedding = RotaryEmbedding(kv_channels=4, rotary_percent=1.0) - - def test_output_shapes_match(self): - - # Create input tensors - t = torch.randn(self.seq_len, self.batch_size, 2, self.d_rot * 2, device="cuda") - rotary_pos_cos, rotary_pos_sin = self.rotary_embedding.get_cos_sin(self.seq_len) - - # Test using Flash Decoding optimized kernel which requires precomputed cos & sin tensors - expected_shape = torch.Size( - [self.seq_len, self.batch_size, self.seq_len // 2, self.seq_len * self.batch_size] - ) - output_flash_rotary = apply_rotary_pos_emb_with_cos_sin( - t, rotary_pos_cos, rotary_pos_sin, rotary_interleaved=True - ) - - assert ( - output_flash_rotary.shape == expected_shape - ), f"Outputs do not match: {output_flash_rotary.shape} != {expected_shape}" diff --git a/tests/unit_tests/inference/test_inference_utils.py b/tests/unit_tests/inference/test_inference_utils.py deleted file mode 100644 index fc4e69018d..0000000000 --- a/tests/unit_tests/inference/test_inference_utils.py +++ /dev/null @@ -1,12 +0,0 @@ -from megatron.core.inference.utils import Counter - - -class TestInferenceUtils: - - def test_counter(self): - counter = Counter() - r = next(counter) - assert r == 0, f'Counter return value should be 0 but it is {r}' - assert counter.counter == 1, f'Counter should be 1 but it is {counter.counter}' - counter.reset() - assert counter.counter == 0, f'Counter should be 0 but it is {counter.counter}' diff --git a/tests/unit_tests/inference/test_scheduler.py b/tests/unit_tests/inference/test_scheduler.py deleted file mode 100644 index 91bb55e3d6..0000000000 --- a/tests/unit_tests/inference/test_scheduler.py +++ /dev/null @@ -1,97 +0,0 @@ -from typing import Dict - -import torch - -from megatron.core.inference.inference_request import InferenceRequest, Status -from megatron.core.inference.sampling_params import SamplingParams -from megatron.core.inference.scheduler import Scheduler - - -class TestScheduler: - - def setup_method(self, method): - self.max_batch_size = 4 - self.scheduler = Scheduler(max_batch_size=self.max_batch_size) - assert ( - len(self.scheduler.active_request_pool) == 0 - ), "Active request pool should be empty on initalization" - assert ( - len(self.scheduler.waiting_request_pool) == 0 - ), "Waiting request pool should be empty on initalization" - assert ( - len(self.scheduler.completed_request_pool) == 0 - ), "Completed request pool should be empty on initalization" - - def test_scheduler(self): - prompt = "sample prompt" - prompt_tokens = torch.randn(5) - sampling_params = SamplingParams() - - active_request_ids = [] - for i in range(self.max_batch_size): - request_id = self.scheduler.add_request(prompt, prompt_tokens, sampling_params) - assert ( - len(self.scheduler.active_request_pool) == i + 1 - ), f"Active request pool should have {i+1} requests, but it has only {len(self.scheduler.active_request_pool)}" - active_request_ids.append(request_id) - - request_id = self.scheduler.add_request(prompt, prompt_tokens, sampling_params) - assert ( - len(self.scheduler.waiting_request_pool) == 1 - ), f"Waiting request pool should have 1 request but it has {len(self.scheduler.waiting_request_pool)} requests" - - waiting_request: InferenceRequest = list(self.scheduler.waiting_request_pool.values())[0] - assert ( - waiting_request.status == Status.WAITING_IN_QUEUE - ), f"Status should be WAITING_IN_QUEUE, but its {waiting_request.status} for the waiting request" - assert ( - request_id == waiting_request.request_id - ), f"Waiting request request ID should match returned request ID" - - assert ( - self.scheduler.have_requests_pending() - ), "Scheduler should have requests pending, but it seems to be having no requests" - - active_request_dict: Dict[str, InferenceRequest] = self.scheduler.active_request_pool - assert set(active_request_dict.keys()) == set( - active_request_ids - ), f"Active request pool IDs should match returned request IDs" - for request_id, request in active_request_dict.items(): - # Mark every even request compelted - if int(request_id) % 2 == 0: - request.status = Status.COMPLETED - - self.scheduler.update_requests_pools(active_request_dict) - assert ( - len(self.scheduler.active_request_pool) == 3 - ), f"Active request pool should have 3 requests, but it has {len(self.scheduler.active_request_pool)}" - - assert ( - len(self.scheduler.waiting_request_pool) == 0 - ), f"Waiting request pool should be empty but it has {len(self.scheduler.waiting_request_pool)} requests" - - assert ( - len(self.scheduler.completed_request_pool) == 2 - ), f"Completed request pool should have 2 requests but it has {len(self.scheduler.completed_request_pool)} requests " - - active_request_dict: Dict[str, InferenceRequest] = self.scheduler.active_request_pool - for request_id, request in active_request_dict.items(): - # Mark all requests compelted - request.status = Status.COMPLETED - - self.scheduler.update_requests_pools(active_request_dict) - assert ( - len(self.scheduler.active_request_pool) == 0 - ), f"Active request pool should be empty, but it has {len(self.scheduler.active_request_pool)}" - - assert ( - len(self.scheduler.waiting_request_pool) == 0 - ), f"Waiting request pool should be empty but it has {len(self.scheduler.waiting_request_pool)} requests" - - assert ( - len(self.scheduler.completed_request_pool) == 5 - ), f"Completed request pool should have 5 requests but it has {len(self.scheduler.completed_request_pool)} requests " - - assert ( - self.scheduler.have_requests_pending() == False - ), "Scheduler should not have any requests pending" diff --git a/tests/unit_tests/inference/text_generation_controllers/test_encoder_decoder_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_encoder_decoder_text_generation_controller.py deleted file mode 100644 index 93a208710f..0000000000 --- a/tests/unit_tests/inference/text_generation_controllers/test_encoder_decoder_text_generation_controller.py +++ /dev/null @@ -1,150 +0,0 @@ -import random -import string -import time -from collections import OrderedDict -from copy import deepcopy -from typing import Dict -from unittest import mock - -import numpy as np -import pytest -import torch - -from megatron.core.inference.contexts import StaticInferenceContext -from megatron.core.inference.inference_request import InferenceRequest, Status -from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( - InferenceWrapperConfig, -) -from megatron.core.inference.model_inference_wrappers.t5.t5_inference_wrapper import ( - T5InferenceWrapper, -) -from megatron.core.inference.sampling_params import SamplingParams -from megatron.core.inference.text_generation_controllers.encoder_decoder_text_generation_controller import ( - EncoderDecoderTextGenerationController, -) -from megatron.core.models.T5.t5_model import T5Model -from megatron.core.models.T5.t5_spec import ( - get_t5_decoder_with_transformer_engine_block_spec, - get_t5_encoder_with_transformer_engine_block_spec, -) -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.enums import AttnBackend -from megatron.core.transformer.transformer_config import TransformerConfig -from tests.unit_tests.test_utilities import Utils - - -class TestEncoderDecoderTextGenerationController: - - def setup_method(self, method): - Utils.initialize_model_parallel( - tensor_model_parallel_size=4, pipeline_model_parallel_size=1 - ) - model_parallel_cuda_manual_seed(123) - self.vocab_size = 100 - self.batch_size = 8 - self.encoder_sequence_length = 32 - self.decoder_sequence_length = 16 - hidden_size = 768 - - transformer_config = TransformerConfig( - num_layers=12, - hidden_size=hidden_size, - num_attention_heads=12, - tensor_model_parallel_size=4, - pipeline_model_parallel_size=1, - attention_backend=AttnBackend.unfused, - ) - - encoder_config = deepcopy(transformer_config) - encoder_config.num_layers = transformer_config.num_layers - - encoder_layers_per_pipeline = ( - encoder_config.num_layers // encoder_config.pipeline_model_parallel_size - ) - decoder_layers_per_pipeline = ( - transformer_config.num_layers // transformer_config.pipeline_model_parallel_size - ) - en_block_spec = get_t5_encoder_with_transformer_engine_block_spec( - encoder_layers_per_pipeline - ) - de_block_spec = get_t5_decoder_with_transformer_engine_block_spec( - decoder_layers_per_pipeline - ) - - t5_model = T5Model( - config=transformer_config, - encoder_config=encoder_config, - transformer_encoder_layer_spec=en_block_spec, - transformer_decoder_layer_spec=de_block_spec, - vocab_size=self.vocab_size, - max_sequence_length=self.encoder_sequence_length, - parallel_output=True, - pre_process=True, - post_process=True, - add_encoder=True, - add_decoder=True, - ).cuda() - - inference_wrapper_config = InferenceWrapperConfig( - hidden_size=hidden_size, - inference_batch_times_seqlen_threshold=-1, - fp32_residual_connection=False, - params_dtype=torch.float, - padded_vocab_size=self.vocab_size, - ) - - inference_context = StaticInferenceContext.from_config(inference_wrapper_config) - - inference_wrapped_model = T5InferenceWrapper( - t5_model, inference_wrapper_config, inference_context - ) - - self.mock_tokenizer = mock.Mock() - - self.text_generation_controller = EncoderDecoderTextGenerationController( - inference_wrapped_model=inference_wrapped_model, tokenizer=self.mock_tokenizer - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_generate_all_output_tokens_static_batch(self): - self.mock_tokenizer.vocab_size = self.vocab_size - self.mock_tokenizer.eod = self.vocab_size - 1 - self.mock_tokenizer.pad = self.vocab_size - 2 - self.mock_tokenizer.additional_special_tokens_ids = list(range(100)) - self.mock_tokenizer.detokenize.return_value = ''.join( - random.choices(string.ascii_letters, k=random.randint(4, 10)) - ) - self.mock_tokenizer.tokenize.return_value = np.random.randint( - self.vocab_size, size=(self.encoder_sequence_length - 5) - ).tolist() - - active_requests: Dict[str, InferenceRequest] = OrderedDict() - for i in range(self.batch_size): - prompt = "decoder_sample" - prompt_tokens = np.random.randint( - self.vocab_size, size=self.decoder_sequence_length - ).tolist() - encoder_prompt = "encoder_sample" - inference_request = InferenceRequest( - request_id=i, - prompt=prompt, - encoder_prompt=encoder_prompt, - sampling_params=SamplingParams(num_tokens_to_generate=10), - arrival_time=time.time(), - prompt_tokens=prompt_tokens, - status=Status.ACTIVE_BUT_NOT_GENERATING_TOKENS, - ) - active_requests[i] = inference_request - - requests = self.text_generation_controller.generate_all_output_tokens_static_batch( - active_requests - ) - - for request_id, request in requests.items(): - assert ( - request.status == Status.COMPLETED - ), f"Status should be completed but its {request.status}" - assert request.generated_length > 0, f"Generated length should be greater than zero" - assert request.generated_text is not None, "Generated text should not be None" diff --git a/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py deleted file mode 100644 index 68f0062c8d..0000000000 --- a/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py +++ /dev/null @@ -1,564 +0,0 @@ -import copy -import os -import random -import string -import time -from collections import OrderedDict, defaultdict -from typing import Dict, List -from unittest import mock - -import pytest -import torch -from transformer_engine.pytorch.fp8 import check_fp8_support - -from megatron.core import parallel_state -from megatron.core.inference.contexts import StaticInferenceContext -from megatron.core.inference.contexts.dynamic_context import MaxSequenceLengthOverflowError -from megatron.core.inference.inference_request import InferenceRequest, Status -from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import ( - GPTInferenceWrapper, -) -from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( - InferenceWrapperConfig, -) -from megatron.core.inference.sampling_params import SamplingParams -from megatron.core.inference.text_generation_controllers.text_generation_controller import ( - TextGenerationController, -) -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.enums import AttnBackend -from megatron.core.transformer.module import Float16Module -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import is_te_min_version -from tests.unit_tests.test_utilities import Utils - - -class TestTextGenerationController: - - def setup_model(self, dtype, symmetric_ar_type=None, fp8: bool = False): - Utils.initialize_model_parallel( - tensor_model_parallel_size=2, pipeline_model_parallel_size=1 - ) - model_parallel_cuda_manual_seed(123) - self.batch_size = 4 - self.hidden_size = 12 - self.vocab_size = 100 - self.sequence_length = 60 if fp8 else 64 # Test padding for fp8 - transformer_config = TransformerConfig( - num_layers=4, - hidden_size=self.hidden_size, - num_attention_heads=4, - use_cpu_initialization=True, - attention_backend=AttnBackend.local, - params_dtype=dtype, - symmetric_ar_type=symmetric_ar_type, - fp8="hybrid" if fp8 else None, - fp8_recipe="tensorwise" if fp8 else None, - fp8_param=fp8, - ) - if dtype == torch.bfloat16: - transformer_config.bf16 = True - - gpt_model = GPTModel( - config=transformer_config, - transformer_layer_spec=get_gpt_layer_local_spec(), - vocab_size=self.vocab_size, - max_sequence_length=self.sequence_length, - parallel_output=True, - pre_process=parallel_state.is_pipeline_first_stage(), - post_process=parallel_state.is_pipeline_last_stage(), - ).cuda() - if dtype == torch.bfloat16: - gpt_model = Float16Module(gpt_model.config, gpt_model) - - inference_wrapper_config = InferenceWrapperConfig( - hidden_size=self.hidden_size, - inference_batch_times_seqlen_threshold=-1, - inference_max_seq_length=2048, - inference_max_requests=16 if fp8 else self.batch_size, - fp32_residual_connection=False, - params_dtype=dtype, - padded_vocab_size=self.vocab_size, - ) - - inference_context = StaticInferenceContext.from_config(inference_wrapper_config) - - inference_wrapped_model = GPTInferenceWrapper( - gpt_model, inference_wrapper_config, inference_context - ) - - self.mock_tokenizer = mock.Mock() - - self.text_generation_controller = TextGenerationController( - inference_wrapped_model=inference_wrapped_model, tokenizer=self.mock_tokenizer - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_sample_from_logits(self): - self.setup_model(torch.float32) - - with pytest.raises(AssertionError) as aerror: - self.text_generation_controller.sample_from_logits( - last_token_logits=None, - sampling_params=SamplingParams(top_k=2, top_p=0.4), - vocab_size=self.vocab_size, - ) - assert str(aerror.value) == 'Cannot have top-p and top-k both greater than zero' - - with pytest.raises(AssertionError) as aerror: - self.text_generation_controller.sample_from_logits( - last_token_logits=None, - sampling_params=SamplingParams(top_p=1.4, top_k=0), - vocab_size=self.vocab_size, - ) - assert str(aerror.value) == 'top-p should be in (0,1]' - - with pytest.raises(AssertionError) as aerror: - self.text_generation_controller.sample_from_logits( - last_token_logits=torch.randn(self.batch_size, 1), - sampling_params=SamplingParams(top_k=self.vocab_size + 10), - vocab_size=self.vocab_size, - ) - assert str(aerror.value) == 'top-k is larger than logit size.' - - last_token_logits = ( - torch.arange(0, self.vocab_size).repeat(self.batch_size, 1).float().cuda() - ) - sampled_logits = self.text_generation_controller.sample_from_logits( - last_token_logits, SamplingParams(top_k=1), self.vocab_size - ) - assert torch.all( - sampled_logits.cpu() == torch.ones(self.batch_size) * self.vocab_size - 1 - ), f"The sampled logits should all be {self.vocab_size} but its {sampled_logits}" - - top_n_logprobs_dict = defaultdict(list) - - class MockTokenizer: - def detokenize(self, inp, skip_special_tokens=False): - return inp[0] - - self.text_generation_controller.tokenizer = MockTokenizer() - last_token_logits_top_n_input = ( - torch.arange(0, self.vocab_size).repeat(self.batch_size, 1).float().cuda() / 10 - ) - sampled_logits = self.text_generation_controller.sample_from_logits( - last_token_logits_top_n_input, - SamplingParams(top_k=1, top_n_logprobs=3), - self.vocab_size, - generation_started=torch.tensor([True] * self.batch_size), - top_n_logprobs_dict=top_n_logprobs_dict, - ) - - assert list(top_n_logprobs_dict[0][0].values()) == pytest.approx( - [-2.3521223068237305, -2.452122688293457, -2.5521230697631836], abs=1e-3 - ) - - sampled_logits = self.text_generation_controller.sample_from_logits( - last_token_logits, SamplingParams(top_k=2), self.vocab_size - ) - assert torch.all( - sampled_logits >= self.vocab_size - 2 - ), f"The sampled logits should all be greater than {self.vocab_size-2} but its {sampled_logits}" - - l = last_token_logits[0] - top_p = 0.3 - expected_min_value = l[l.softmax(dim=-1).cumsum(dim=-1) > top_p][0].item() - sampled_logits = self.text_generation_controller.sample_from_logits( - last_token_logits, SamplingParams(top_p=top_p, top_k=0), self.vocab_size - ) - assert torch.all( - sampled_logits >= expected_min_value - ), f"The sampled logits should all be greater than {expected_min_value} but its {sampled_logits}" - - top_p = 0.95 - temperature = 2 - expected_min_value = l[l.div_(temperature).softmax(dim=-1).cumsum(dim=-1) > top_p][0].item() - sampled_logits = self.text_generation_controller.sample_from_logits( - last_token_logits, - SamplingParams(top_p=top_p, temperature=temperature, top_k=0), - self.vocab_size, - ) - assert torch.all( - sampled_logits >= expected_min_value - ), f"The sampled logits should all be greater than {expected_min_value} but its {sampled_logits}" - - @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) - @pytest.mark.parametrize( - "symmetric_ar_type", - [ - None, - pytest.param( - "multimem_all_reduce", - marks=pytest.mark.skipif( - not is_te_min_version("2.3"), - reason="multimem_all_reduce requires Transformer Engine >= 2.3", - ), - ), - ], - ) - @pytest.mark.parametrize("fp8", [False, True]) - def test_generate_all_output_tokens_static_batch(self, dtype, symmetric_ar_type, fp8): - if fp8: - fp8_available, reason_for_no_fp8 = check_fp8_support() - if not fp8_available: - pytest.skip(reason_for_no_fp8) - elif not is_te_min_version("2.2.0"): - pytest.skip(reason="TE 2.2.0 is required") - elif dtype != torch.bfloat16: - pytest.skip("Only testing fp8 inference with bf16 params") - - self.setup_model(dtype, symmetric_ar_type, fp8) - self.mock_tokenizer.vocab_size = self.vocab_size - self.mock_tokenizer.eod = self.vocab_size - 1 - self.mock_tokenizer.detokenize.side_effect = lambda x, skip_special_tokens=False: ' '.join( - [ - ''.join(random.choices(string.ascii_letters, k=random.randint(4, 10))) - for _ in range(len(x)) - ] - ) - self.mock_tokenizer.offsets.side_effect = lambda _, s: [ - i for i, c in enumerate(s) if c == ' ' - ] + [len(s)] - - active_requests: Dict[str, InferenceRequest] = OrderedDict() - all_prompt_tokens: Dict[str, List[int]] = OrderedDict() - for i in range(self.batch_size): - prompt = "sample" * (i + 1) - self.mock_tokenizer.tokenize.return_value = torch.randn( - self.batch_size, self.vocab_size - ).cuda() - prompt_tokens = torch.randint( - low=0, high=self.vocab_size - 1, size=(len(prompt),) - ).tolist() - - request_id = str(i) - inference_request = InferenceRequest( - request_id=request_id, - prompt=prompt, - sampling_params=SamplingParams( - num_tokens_to_generate=10, return_log_probs=True, return_segments=True - ), - arrival_time=time.time(), - prompt_tokens=prompt_tokens, - status=Status.ACTIVE_BUT_NOT_GENERATING_TOKENS, - ) - active_requests[request_id] = inference_request - all_prompt_tokens[request_id] = copy.deepcopy(prompt_tokens) - - requests = self.text_generation_controller.generate_all_output_tokens_static_batch( - active_requests - ) - - for request_id, request in requests.items(): - assert ( - request.status == Status.COMPLETED - ), f"Status should be completed but its {request.status}" - assert request.generated_length > 0, f"Generated length should be greater than zero" - assert request.generated_text is not None, "Generated text should not be None" - assert ( - all_prompt_tokens[request_id] == request.prompt_tokens - ), "Prompt tokens should not have changed during generation" - # Log probabilities are calculated based on the likelihood of a token given the - # preceding context. The first token lacks this dependency and is excluded from - # the logprobs output, which is why the +1 is necessary - assert ( - len(request.segments) - == len(request.prompt_log_probs) + len(request.generated_log_probs) + 1 - ), "Segments should be returned for both prompt and generated tokens" - assert len(request.prompt) + len(request.generated_text) == len( - request.text - ), "Output text should include prompts and generations" - assert ( - request.tpot is not None - and isinstance(request.tpot, list) - and len(request.tpot) == request.generated_length - ) - - @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) - def test_output_log_probs(self, dtype): - self.setup_model(dtype) - - self.mock_tokenizer.vocab_size = self.vocab_size - self.mock_tokenizer.bos = 0 - self.mock_tokenizer.eod = self.vocab_size - 1 - self.mock_tokenizer.detokenize.side_effect = lambda x, skip_special_tokens=False: ' '.join( - [ - ''.join(random.choices(string.ascii_letters, k=random.randint(4, 10))) - for _ in range(len(x)) - ] - ) - self.mock_tokenizer.offsets.side_effect = lambda _, s: [ - i for i, c in enumerate(s) if c == ' ' - ] + [len(s)] - - prompt = "" - active_requests: Dict[int, InferenceRequest] = OrderedDict() - for i in range(self.batch_size): - self.mock_tokenizer.tokenize.return_value = torch.randn( - self.batch_size, self.vocab_size - ).cuda() - inference_request = InferenceRequest( - request_id=i, - prompt=prompt, - sampling_params=SamplingParams(num_tokens_to_generate=1, return_log_probs=True), - arrival_time=time.time(), - prompt_tokens=[self.mock_tokenizer.bos], - status=Status.ACTIVE_BUT_NOT_GENERATING_TOKENS, - ) - active_requests[i] = inference_request - - requests = self.text_generation_controller.generate_all_output_tokens_static_batch( - active_requests - ) - - for request_id, request in requests.items(): - assert ( - request.status == Status.COMPLETED - ), f"Status should be completed but its {request.status}" - assert request.generated_length > 0, f"Generated length should be greater than zero" - assert request.generated_text is not None, "Generated text should not be None" - assert len(request.generated_log_probs) == request.generated_length - - @pytest.mark.parametrize("num_tokens_to_generate", [0, 4]) - @pytest.mark.parametrize("return_prompt_top_n_logprobs", [True, False]) - @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) - def test_logprobs_and_topn_consistency( - self, num_tokens_to_generate, return_prompt_top_n_logprobs, dtype - ): - """ - 1. Ensures that a batch request containing prompts of - *different* lengths still returns the correct number of log‑probs for - every request. - 2. Verifies that, for every token whose log prob is returned, the value - exactly matches the log prob reported for that same token in the - `top_n_logprobs` payload. - """ - self.setup_model(dtype) - - self.mock_tokenizer.vocab_size = self.vocab_size - self.mock_tokenizer.bos = 0 - self.mock_tokenizer.eod = self.vocab_size - 1 - self.mock_tokenizer.detokenize.side_effect = lambda toks, **_: " ".join( - f"T{t}" for t in toks - ) # unique, deterministic - self.mock_tokenizer.offsets.side_effect = lambda _, s: [ - i for i, c in enumerate(s) if c == " " - ] + [len(s)] - - prompts = ["a", "foo", "foobar", "lorem ipsum"] - active_reqs: Dict[str, InferenceRequest] = OrderedDict() - - for rid, p in enumerate(prompts): - prompt_tokens = torch.randint(1, self.vocab_size - 2, (len(p) + 1,)).tolist() # +bos - prompt_tokens[0] = self.mock_tokenizer.bos # ensure BOS - - self.mock_tokenizer.tokenize.return_value = torch.randn( - self.batch_size, self.vocab_size - ).cuda() - - active_reqs[str(rid)] = InferenceRequest( - request_id=str(rid), - prompt=p, - prompt_tokens=prompt_tokens, - sampling_params=SamplingParams( - num_tokens_to_generate=num_tokens_to_generate, - top_k=1, - top_p=0.0, - temperature=0.0, - return_log_probs=True, - top_n_logprobs=5, - return_prompt_top_n_logprobs=return_prompt_top_n_logprobs, - ), - arrival_time=time.time(), - status=Status.ACTIVE_BUT_NOT_GENERATING_TOKENS, - ) - - completed = self.text_generation_controller.generate_all_output_tokens_static_batch( - active_reqs - ) - - for request_id, request in completed.items(): - prompt_log_probs = request.prompt_log_probs - generated_log_probs = request.generated_log_probs - prompt_top_n_logprobs = request.prompt_top_n_logprobs - generated_top_n_logprobs = request.generated_top_n_logprobs - generated_tokens = request.generated_tokens - - assert len(prompt_log_probs) == len(request.prompt_tokens) - 1, ( - f"{request_id}: Expected {len(request.prompt_tokens)-1} prompt log probs, " - f"got {len(prompt_log_probs)}" - ) - assert len(generated_log_probs) == request.generated_length, ( - f"{request_id}: Expected {request.generated_length} generated log probs, " - f"got {len(generated_log_probs)}" - ) - - assert (not return_prompt_top_n_logprobs and prompt_top_n_logprobs is None) or ( - return_prompt_top_n_logprobs - and prompt_top_n_logprobs is not None - and len(prompt_top_n_logprobs) == len(prompt_log_probs) - ) - assert len(generated_top_n_logprobs) == request.generated_length, ( - f"{request_id}: Expected {request.generated_length} generated log probs, " - f"got {len(generated_top_n_logprobs)}" - ) - assert ( - request.tpot is not None - and isinstance(request.tpot, list) - and len(request.tpot) == request.generated_length - ) - - # Verify that the generated log probs match what is returned - # in the top-N log probs dict - for k, log_probs in enumerate(generated_log_probs): - token_id = generated_tokens[k] - top_n = generated_top_n_logprobs[k] - token = self.mock_tokenizer.detokenize([token_id]) - - assert token in top_n, f"{request_id}: Generated token {token} missing in top‑N" - assert ( - pytest.approx(log_probs, rel=1e-6) == top_n[token] - ), f"{request_id}: mismatch @ generated token {k}: {log_probs} vs {top_n[token]}" - - def test_token_overflow(self): - self.setup_model(torch.float32) - - self.mock_tokenizer.vocab_size = self.vocab_size - self.mock_tokenizer.bos = 0 - self.mock_tokenizer.eod = self.vocab_size - 1 - self.mock_tokenizer.detokenize.side_effect = lambda x: ' '.join( - [ - ''.join(random.choices(string.ascii_letters, k=random.randint(4, 10))) - for _ in range(len(x)) - ] - ) - self.mock_tokenizer.offsets.side_effect = lambda _, s: [ - i for i, c in enumerate(s) if c == ' ' - ] + [len(s)] - - prompt = "" - active_requests: Dict[int, InferenceRequest] = OrderedDict() - for i in range(self.batch_size): - self.mock_tokenizer.tokenize.return_value = torch.randn( - self.batch_size, self.vocab_size - ).cuda() - inference_request = InferenceRequest( - request_id=i, - prompt=prompt, - sampling_params=SamplingParams(num_tokens_to_generate=4096, return_log_probs=True), - arrival_time=time.time(), - prompt_tokens=[self.mock_tokenizer.bos], - status=Status.ACTIVE_BUT_NOT_GENERATING_TOKENS, - ) - active_requests[i] = inference_request - - with pytest.raises(MaxSequenceLengthOverflowError): - requests = self.text_generation_controller.generate_all_output_tokens_static_batch( - active_requests - ) - - def test_zero_tokens_generated_batch_vs_single(self): - """ - Verifies that when `num_tokens_to_generate=0`, the outputs from batched inference - match the outputs from single-request inference for prompt-related fields. - """ - self.setup_model(dtype=torch.bfloat16) - - self.mock_tokenizer.vocab_size = self.vocab_size - self.mock_tokenizer.bos = 0 - self.mock_tokenizer.eod = self.vocab_size - 1 - self.mock_tokenizer.detokenize.side_effect = lambda toks, **_: " ".join( - f"T{t}" for t in toks - ) # unique, deterministic - self.mock_tokenizer.offsets.side_effect = lambda _, s: [ - i for i, c in enumerate(s) if c == " " - ] + [len(s)] - - prompts = [ - "a short prompt", - "a slightly longer prompt that still fits", - "an even longer prompt to test prompt length variability", - ] - batch_size_test = len(prompts) - active_requests_batched: Dict[str, InferenceRequest] = OrderedDict() - expected_single_requests: Dict[str, InferenceRequest] = OrderedDict() - - for rid, p in enumerate(prompts): - prompt_tokens = torch.randint(1, self.vocab_size - 2, (len(p) + 1,)).tolist() - prompt_tokens[0] = self.mock_tokenizer.bos - - # Mock tokenize for consistency across batch and single - self.mock_tokenizer.tokenize.return_value = torch.randn( - batch_size_test, self.vocab_size - ).cuda() - - sampling_params = SamplingParams( - num_tokens_to_generate=0, - temperature=0.0, - top_k=1, - return_log_probs=True, - top_n_logprobs=5, - return_prompt_top_n_logprobs=True, - ) - - inference_request = InferenceRequest( - request_id=str(rid), - prompt=p, - prompt_tokens=prompt_tokens, - sampling_params=copy.deepcopy(sampling_params), - arrival_time=time.time(), - status=Status.ACTIVE_BUT_NOT_GENERATING_TOKENS, - ) - active_requests_batched[str(rid)] = copy.deepcopy(inference_request) - expected_single_requests[str(rid)] = copy.deepcopy(inference_request) - - # Perform batched inference - completed_batched = self.text_generation_controller.generate_all_output_tokens_static_batch( - active_requests_batched - ) - - # Perform single-request inference for comparison - completed_single: Dict[str, InferenceRequest] = OrderedDict() - for request_id, req in expected_single_requests.items(): - single_request_dict = {request_id: req} - result = self.text_generation_controller.generate_all_output_tokens_static_batch( - single_request_dict - ) - completed_single.update(result) - - # Compare results - for request_id in completed_batched.keys(): - request_batched = completed_batched[request_id] - request_single = completed_single[request_id] - - assert request_batched.status == Status.COMPLETED - assert request_single.status == Status.COMPLETED - - assert request_batched.generated_length == 0 - assert request_single.generated_length == 0 - - assert request_batched.prompt_tokens == request_single.prompt_tokens - assert request_batched.prompt_log_probs == pytest.approx( - request_single.prompt_log_probs - ) - - # Assert prompt_top_n_logprobs for consistency - assert request_batched.prompt_top_n_logprobs is not None - assert request_single.prompt_top_n_logprobs is not None - assert len(request_batched.prompt_top_n_logprobs) == len( - request_single.prompt_top_n_logprobs - ) - for i in range(len(request_batched.prompt_top_n_logprobs)): - assert ( - request_batched.prompt_top_n_logprobs[i].keys() - == request_single.prompt_top_n_logprobs[i].keys() - ) - for token_str in request_batched.prompt_top_n_logprobs[i]: - assert ( - pytest.approx(request_batched.prompt_top_n_logprobs[i][token_str], rel=1e-6) - == request_single.prompt_top_n_logprobs[i][token_str] - ) diff --git a/tests/unit_tests/inference/text_generation_controllers/test_vlm_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_vlm_text_generation_controller.py deleted file mode 100644 index 5eb99c933f..0000000000 --- a/tests/unit_tests/inference/text_generation_controllers/test_vlm_text_generation_controller.py +++ /dev/null @@ -1,171 +0,0 @@ -import copy -import os -import random -import string -import time -from argparse import Namespace -from collections import OrderedDict -from typing import Dict -from unittest import mock - -import pytest -import torch - -from megatron.core.inference.contexts import StaticInferenceContext -from megatron.core.inference.inference_request import InferenceRequest, Status, VLMInferenceRequest -from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( - InferenceWrapperConfig, -) -from megatron.core.inference.model_inference_wrappers.multimodal.vlm_inference_wrapper import ( - VLMInferenceWrapper, -) -from megatron.core.inference.sampling_params import SamplingParams -from megatron.core.inference.text_generation_controllers.vlm_text_generation_controller import ( - VLMTextGenerationController, -) -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec -from megatron.core.models.multimodal.llava_model import LLaVAModel -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.enums import AttnBackend -from megatron.core.transformer.module import Float16Module -from megatron.core.transformer.transformer_config import TransformerConfig -from tests.unit_tests.test_utilities import Utils - - -class TestVLMTextGenerationController: - - @pytest.mark.internal # The model is under active development and its methods may change. - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - - self.language_hidden_size = 64 - self.language_num_attention_heads = 4 - self.language_vocab_size = 8192 - self.language_max_sequence_length = 4096 - self.img_h = 336 - self.img_w = 336 - - language_config = TransformerConfig( - num_layers=3, - hidden_size=self.language_hidden_size, - num_attention_heads=self.language_num_attention_heads, - use_cpu_initialization=False, - bf16=True, - ) - vision_config = TransformerConfig( - num_layers=2, - hidden_size=16, - num_attention_heads=2, - use_cpu_initialization=False, - bf16=True, - ) - vision_projection_config = TransformerConfig( - num_layers=2, - hidden_size=self.language_hidden_size, - ffn_hidden_size=32, - num_attention_heads=1, - use_cpu_initialization=False, - bf16=True, - ) - - language_layer_spec = get_gpt_layer_local_spec() - vision_layer_spec = copy.deepcopy(language_layer_spec) - vision_projection_spec = copy.deepcopy(language_layer_spec.submodules.mlp.submodules) - - language_config.language_model_type = "dummy" - vision_config.vision_model_type = "clip" - self.model = LLaVAModel( - language_transformer_config=language_config, - language_transformer_layer_spec=language_layer_spec, - language_vocab_size=self.language_vocab_size, - language_max_sequence_length=self.language_max_sequence_length, - vision_transformer_config=vision_config, - vision_transformer_layer_spec=vision_layer_spec, - drop_vision_class_token=False, - vision_projection_config=vision_projection_config, - vision_projection_layer_spec=vision_projection_spec, - img_h=self.img_h, - img_w=self.img_w, - patch_dim=14, - ).cuda() - self.image_token_index = self.model.image_token_index - self.model = Float16Module(self.model.config, self.model) - - inference_wrapper_config = InferenceWrapperConfig( - hidden_size=self.language_hidden_size, - inference_batch_times_seqlen_threshold=-1, - fp32_residual_connection=False, - params_dtype=torch.float, - padded_vocab_size=self.language_vocab_size, - ) - - inference_context = StaticInferenceContext.from_config(inference_wrapper_config) - - inference_wrapped_model = VLMInferenceWrapper( - self.model, inference_wrapper_config, inference_context - ) - - self.mock_tokenizer = mock.Mock() - - self.text_generation_controller = VLMTextGenerationController( - inference_wrapped_model=inference_wrapped_model, tokenizer=self.mock_tokenizer - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_generate_all_output_tokens_static_batch(self): - self.mock_tokenizer.vocab_size = self.language_vocab_size - self.mock_tokenizer.eod = self.language_vocab_size - 1 - self.mock_tokenizer.detokenize.return_value = ''.join( - random.choices(string.ascii_letters, k=random.randint(4, 10)) - ) - - batch_size: int = 1 - num_img_embeddings_per_tile: int = 576 - imgs: torch.Tensor = torch.randn(1, 3, self.img_h, self.img_w).cuda() - num_tiles: torch.Tensor = torch.Tensor([1]).int() - decoder_seq_length: int = self.language_max_sequence_length - - active_requests: Dict[str, InferenceRequest] = OrderedDict() - all_prompt_tokens: Dict[str, List[int]] = OrderedDict() - for i in range(batch_size): - prompt = "sample" * (i + 1) - self.mock_tokenizer.tokenize.return_value = torch.randn( - batch_size, self.language_vocab_size - ).cuda() - prompt_tokens = torch.randint( - low=0, high=self.language_vocab_size - 1, size=(len(prompt),) - ).tolist() - prompt_tokens[3] = self.image_token_index - - request_id = str(i) - inference_request = VLMInferenceRequest( - request_id=request_id, - prompt=prompt, - sampling_params=SamplingParams(num_tokens_to_generate=10), - arrival_time=time.time(), - prompt_tokens=prompt_tokens, - num_img_embeddings_per_tile=num_img_embeddings_per_tile, - imgs=imgs, - num_tiles=num_tiles, - decoder_seq_length=decoder_seq_length, - status=Status.ACTIVE_BUT_NOT_GENERATING_TOKENS, - ) - active_requests[request_id] = inference_request - all_prompt_tokens[request_id] = copy.deepcopy(prompt_tokens) - - requests = self.text_generation_controller.generate_all_output_tokens_static_batch( - active_requests - ) - - for request_id, request in requests.items(): - assert ( - request.status == Status.COMPLETED - ), f"Status should be completed but its {request.status}" - assert request.generated_length > 0, f"Generated length should be greater than zero" - assert request.generated_text is not None, "Generated text should not be None" - assert ( - all_prompt_tokens[request_id] == request.prompt_tokens - ), "Prompt tokens should not have changed during generation" diff --git a/tests/unit_tests/models/test_base_embedding.py b/tests/unit_tests/models/test_base_embedding.py deleted file mode 100644 index 0ce18b3843..0000000000 --- a/tests/unit_tests/models/test_base_embedding.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import pytest -import torch - -from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding -from megatron.core.transformer.transformer_config import TransformerConfig -from tests.unit_tests.test_utilities import Utils - - -class TestBaseEmbedding: - - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - transformer_config = TransformerConfig( - num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True - ) - self.base_embedding = LanguageModelEmbedding( - config=transformer_config, - vocab_size=100, - max_sequence_length=4, - position_embedding_type='learned_absolute', - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_constructor(self): - assert isinstance(self.base_embedding, LanguageModelEmbedding) - num_weights = sum([p.numel() for p in self.base_embedding.parameters()]) - assert num_weights == 1248 - - def test_zero_parameters(self): - sum_weights = sum([p.sum() for p in self.base_embedding.parameters()]) - assert sum_weights != 0 - self.base_embedding.zero_parameters() - sum_weights = sum([p.sum() for p in self.base_embedding.parameters()]) - assert sum_weights == 0 - - def test_cpu_forward(self): - input_ids = torch.tensor([0, 1, 2, 3], dtype=torch.int64).repeat((2, 1)) - position_ids = torch.tensor([0, 1, 2, 3], dtype=torch.int64).repeat((2, 1)) - embeddings = self.base_embedding(input_ids, position_ids) - assert embeddings.device.type == 'cpu' - assert embeddings.shape[0] == self.base_embedding.max_sequence_length - assert embeddings.shape[1] == input_ids.shape[0] - assert embeddings.shape[2] == self.base_embedding.config.hidden_size - - def test_gpu_forward(self): - self.base_embedding.cuda() - input_ids = torch.tensor([0, 1, 2, 3], dtype=torch.int64).repeat((2, 1)).cuda() - position_ids = torch.tensor([0, 1, 2, 3], dtype=torch.int64).repeat((2, 1)).cuda() - embeddings = self.base_embedding(input_ids, position_ids) - assert embeddings.device.type == 'cuda' - assert embeddings.shape[0] == self.base_embedding.max_sequence_length - assert embeddings.shape[1] == input_ids.shape[0] - assert embeddings.shape[2] == self.base_embedding.config.hidden_size diff --git a/tests/unit_tests/models/test_bert_model.py b/tests/unit_tests/models/test_bert_model.py deleted file mode 100644 index b30d1413cf..0000000000 --- a/tests/unit_tests/models/test_bert_model.py +++ /dev/null @@ -1,228 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import os -from importlib.metadata import version - -import pytest -import torch -from packaging.version import Version as PkgVersion -from pytest_mock import mocker - -from megatron.core.models.bert.bert_layer_specs import ( - bert_layer_local_spec, - bert_layer_with_transformer_engine_spec, -) -from megatron.core.models.bert.bert_model import BertModel -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.enums import AttnBackend, AttnMaskType -from megatron.core.transformer.transformer_config import TransformerConfig -from tests.unit_tests.test_utilities import Utils - - -class TestBertModel: - - def setup_method(self, method): - tp = 1 - pp = 1 - Utils.initialize_model_parallel(tp, pp) - model_parallel_cuda_manual_seed(123) - transformer_config = TransformerConfig( - num_layers=2, - hidden_size=12, - num_attention_heads=4, - use_cpu_initialization=True, - perform_initialization=True, - tensor_model_parallel_size=tp, - pipeline_model_parallel_size=pp, - pipeline_dtype=torch.bfloat16, - attention_backend=AttnBackend.unfused, - ) - self.bert_model = BertModel( - config=transformer_config, - num_tokentypes=0, - transformer_layer_spec=bert_layer_with_transformer_engine_spec, - vocab_size=100, - max_sequence_length=4, - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.internal - def test_constructor(self): - assert isinstance(self.bert_model, BertModel) - - assert self.bert_model.max_sequence_length == 4 - - num_weights = sum([p.numel() for p in self.bert_model.parameters()]) - assert num_weights == 6702 - - @pytest.mark.internal - def test_set_input_tensor(self): - config: TransformerConfig = self.bert_model.config - sequence_length = self.bert_model.max_sequence_length - micro_batch_size = 2 - - # [sequence length, batch size, hidden size] - input_tensor = torch.ones((sequence_length, micro_batch_size, config.hidden_size)) - - self.bert_model.set_input_tensor(input_tensor) - - assert self.bert_model.encoder.input_tensor.shape[0] == sequence_length - assert self.bert_model.encoder.input_tensor.shape[1] == micro_batch_size - assert self.bert_model.encoder.input_tensor.shape[2] == config.hidden_size - - @pytest.mark.internal - def test_post_process_forward(self): - config: TransformerConfig = self.bert_model.config - sequence_length = self.bert_model.max_sequence_length - micro_batch_size = 2 - - self.bert_model.cuda() - - data = list(range(sequence_length)) - input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - position_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - attention_mask = torch.ones((micro_batch_size, sequence_length), dtype=bool).cuda() - - logits = self.bert_model.forward(input_ids=input_ids, attention_mask=attention_mask) - - assert logits[0].shape[0] == micro_batch_size - assert logits[0].shape[1] == sequence_length - assert logits[0].shape[2] == self.bert_model.vocab_size - - -class TestBertModelAttentionDimensions: - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - self.transformer_config = TransformerConfig( - num_layers=2, - hidden_size=12, - num_attention_heads=4, - use_cpu_initialization=True, - pipeline_dtype=torch.bfloat16, - attention_backend=AttnBackend.auto, - ) - # This should convert arbitray mask to padding mask - self.bert_model = BertModel( - config=self.transformer_config, - num_tokentypes=0, - transformer_layer_spec=bert_layer_with_transformer_engine_spec, - vocab_size=100, - max_sequence_length=4, - ) - - @pytest.mark.internal - def test_local_spec(self, mocker): - self.bert_model.config.attention_backend = AttnBackend.local - self.bert_model.transformer_layer_spec = bert_layer_local_spec - attn_mask_dimensions = self.bert_model._sanity_check_attention_and_get_attn_mask_dimension() - assert ( - attn_mask_dimensions == "b1ss" - ), f"Expected b1ss for attn_mask_dimensions but got {attn_mask_dimensions}" - - @pytest.mark.internal - def test_local_spec_exception(self, mocker): - self.bert_model.config.attention_backend = AttnBackend.flash - self.bert_model.transformer_layer_spec = bert_layer_local_spec - with pytest.raises(Exception) as exc_info: - self.bert_model._sanity_check_attention_and_get_attn_mask_dimension() - assert ( - str(exc_info.value) - == 'Expected AttnBackend to be local or auto while using mcore self attention, but found AttnBackend.flash. Set --attn-backend to local or dont use MCore SelfAttention submodule in layer specs' - ) - - @pytest.mark.internal - def test_transformer_engine_version_1_10(self, mocker): - bert_layer_with_transformer_engine_spec.submodules.self_attention.params[ - 'attn_mask_type' - ] == AttnMaskType.arbitrary - - mocker.patch("megatron.core.utils.get_te_version", return_value=PkgVersion("1.10")) - self.bert_model.transformer_layer_spec = bert_layer_with_transformer_engine_spec - attn_mask_dimensions = self.bert_model._sanity_check_attention_and_get_attn_mask_dimension() - attn_mask_type = self.bert_model.transformer_layer_spec.submodules.self_attention.params[ - 'attn_mask_type' - ] - assert ( - attn_mask_type == AttnMaskType.padding - ), f"Exepcted attn mask type to be padding, but got {attn_mask_type}" - assert ( - attn_mask_dimensions == "b11s" - ), f"Expected b11s for attn_mask_dimensions but got {attn_mask_dimensions}" - - @pytest.mark.internal - def test_transformer_engine_version_1_7_to_1_10_flash_attn(self, mocker): - self.bert_model.config.attention_backend = AttnBackend.flash - mocker.patch("megatron.core.utils.get_te_version", return_value=PkgVersion("1.8")) - self.bert_model.transformer_layer_spec = bert_layer_with_transformer_engine_spec - attn_mask_dimensions = self.bert_model._sanity_check_attention_and_get_attn_mask_dimension() - assert ( - attn_mask_dimensions == "b11s" - ), f"Expected b11s for attn_mask_dimensions but got {attn_mask_dimensions}" - - @pytest.mark.internal - @pytest.mark.flaky_in_dev - def test_transformer_engine_version_1_7_to_1_10_rng_error(self, mocker): - bert_layer_with_transformer_engine_spec.submodules.self_attention.params[ - 'attn_mask_type' - ] == AttnMaskType.padding - mocker.patch("megatron.core.utils.get_te_version", return_value=PkgVersion("1.8")) - with pytest.raises(Exception) as exc_info: - self.bert_model = BertModel( - config=self.transformer_config, - num_tokentypes=0, - transformer_layer_spec=bert_layer_with_transformer_engine_spec, - vocab_size=100, - max_sequence_length=4, - ) - assert str(exc_info.value) == ( - "Linear.__init__() got an unexpected keyword argument 'rng_tracker_name' when " - "instantiating TERowParallelLinear when instantiating SelfAttention when " - "instantiating TransformerLayer" - ) - - @pytest.mark.internal - def test_transformer_engine_version_1_7_to_1_10_unfused_attention(self, mocker): - self.bert_model.config.attention_backend = AttnBackend.unfused - bert_layer_with_transformer_engine_spec.submodules.self_attention.params[ - 'attn_mask_type' - ] == AttnMaskType.padding - mocker.patch("megatron.core.utils.get_te_version", return_value=PkgVersion("1.8")) - self.bert_model.transformer_layer_spec = bert_layer_with_transformer_engine_spec - attn_mask_dimensions = self.bert_model._sanity_check_attention_and_get_attn_mask_dimension() - attn_mask_type = self.bert_model.transformer_layer_spec.submodules.self_attention.params[ - 'attn_mask_type' - ] - assert ( - attn_mask_type == AttnMaskType.arbitrary - ), f"Exepcted attn mask type to be arbitrary, but got {attn_mask_type}" - assert ( - attn_mask_dimensions == "b1ss" - ), f"Expected b1ss for attn_mask_dimensions but got {attn_mask_dimensions}" - - @pytest.mark.internal - def test_transformer_engine_version_less_than_1_7(self, mocker): - os.environ.pop('NVTE_FUSED_ATTN', None) - os.environ.pop('NVTE_FLASH_ATTN', None) - os.environ.pop('NVTE_UNFUSED_ATTN', None) - self.bert_model.config.attention_backend = AttnBackend.flash - with pytest.raises(Exception) as exc_info: - mocker.patch("megatron.core.utils.get_te_version", return_value=PkgVersion("1.5")) - self.bert_model = BertModel( - config=self.transformer_config, - num_tokentypes=0, - transformer_layer_spec=bert_layer_with_transformer_engine_spec, - vocab_size=100, - max_sequence_length=4, - ) - - assert str(exc_info.value) == ( - "Flash and fused attention is not supported with transformer engine version " - "< 1.7. Set --attention-backend to unfused or leave it to be default (auto) or upgrade transformer engine >= 1.7" - ) diff --git a/tests/unit_tests/models/test_clip_vit_model.py b/tests/unit_tests/models/test_clip_vit_model.py deleted file mode 100644 index c176c188d1..0000000000 --- a/tests/unit_tests/models/test_clip_vit_model.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -import pytest -import torch - -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec -from megatron.core.models.vision.clip_vit_model import CLIPViTModel, get_num_image_embeddings -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.transformer_config import TransformerConfig -from tests.unit_tests.test_utilities import Utils - - -class TestCLIPViTModel: - """Test CLIP ViT model.""" - - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - transformer_config = TransformerConfig( - num_layers=2, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True - ) - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec() - self.model = CLIPViTModel( - transformer_config, transformer_layer_spec, img_h=336, img_w=336, patch_dim=14 - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_constructor(self): - assert isinstance(self.model, CLIPViTModel) - - num_weights = sum([p.numel() for p in self.model.parameters()]) - assert num_weights == 174720 - - def test_set_input_tensor(self): - # [s, b, h] expected to the transformer. - expected_shape = (577, 2, 64) - input_tensor = torch.zeros(expected_shape) - - self.model.set_input_tensor(input_tensor) - - assert self.model.decoder.input_tensor.shape == torch.Size(expected_shape) - - def test_forward(self): - self.model.cuda() - - img = torch.zeros((2, 3, 336, 336)).cuda() - - out = self.model.forward(img) - assert out.shape == torch.Size([2, 577, 64]) - - def test_save_load(self, tmp_path): - path = tmp_path / "model.pt" - torch.save(self.model.state_dict(), path) - - self.model.load_state_dict(torch.load(path)) - - -@pytest.mark.internal -@pytest.mark.parametrize( - "vision_model,pixel_shuffle,tile_tags,expected", - [ - ("clip", False, False, 1024), - ("internvit300M", False, False, 1024), - ("clip", True, False, 256), - ("internvit300M", True, True, 262), - ], -) -def test_get_num_image_embeddings(vision_model, pixel_shuffle, tile_tags, expected): - assert ( - get_num_image_embeddings( - 448, 448, 14, vision_model, True, 1, pixel_shuffle, tile_tags, 0, "nemotron5" - ) - == expected - ) diff --git a/tests/unit_tests/models/test_gpt_model.py b/tests/unit_tests/models/test_gpt_model.py deleted file mode 100644 index 6cc827c406..0000000000 --- a/tests/unit_tests/models/test_gpt_model.py +++ /dev/null @@ -1,334 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import inspect -import os -from datetime import timedelta - -import pytest -import torch -from packaging import version -from pytest import approx - -from megatron.core import parallel_state -from megatron.core.hyper_comm_grid import HyperCommGrid -from megatron.core.models.gpt.gpt_layer_specs import ( - get_gpt_layer_with_transformer_engine_spec, - get_mlp_module_spec, -) -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.process_groups_config import ModelCommProcessGroups -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import is_te_min_version -from tests.unit_tests.test_utilities import Utils - - -class TestGPTModel: - - def setup_method(self, method): - os.environ.pop('NVTE_FUSED_ATTN', None) - os.environ.pop('NVTE_FLASH_ATTN', None) - os.environ.pop('NVTE_UNFUSED_ATTN', None) - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - transformer_config = TransformerConfig( - num_layers=2, - hidden_size=12, - num_attention_heads=4, - use_cpu_initialization=True, - embedding_init_method_std=1.0, # Test that we can initialize the embedding weights to something else. - ) - self.gpt_model = GPTModel( - config=transformer_config, - transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec(), - vocab_size=100, - max_sequence_length=4, - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.internal - def test_constructor(self): - assert isinstance(self.gpt_model, GPTModel) - - assert self.gpt_model.max_sequence_length == 4 - - num_weights = sum([p.numel() for p in self.gpt_model.parameters()]) - assert num_weights == 6240 - - @pytest.mark.internal - def test_set_input_tensor(self): - config: TransformerConfig = self.gpt_model.config - sequence_length = self.gpt_model.max_sequence_length - micro_batch_size = 2 - - # [sequence length, batch size, hidden size] - input_tensor = torch.ones((sequence_length, micro_batch_size, config.hidden_size)) - - self.gpt_model.set_input_tensor(input_tensor) - - assert self.gpt_model.decoder.input_tensor.shape[0] == sequence_length - assert self.gpt_model.decoder.input_tensor.shape[1] == micro_batch_size - assert self.gpt_model.decoder.input_tensor.shape[2] == config.hidden_size - - def test_embedding_init(self): - """Test that we can initialize the embedding weights to something else. This test could be added to any model.""" - config: TransformerConfig = self.gpt_model.config - assert self.gpt_model.embedding.word_embeddings.weight.std().cpu().item() == approx( - config.embedding_init_method_std, abs=1e-1 - ) - assert self.gpt_model.embedding.word_embeddings.weight.mean().cpu().item() == approx( - 0.0, abs=1e-1 - ) - - @pytest.mark.internal - def test_post_process_forward(self): - _ = self.gpt_model.config - sequence_length = self.gpt_model.max_sequence_length - micro_batch_size = 2 - - self.gpt_model.cuda() - - data = list(range(sequence_length)) - input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - position_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - attention_mask = torch.ones( - (micro_batch_size, 1, sequence_length, sequence_length), dtype=bool - ).cuda() - - logits = self.gpt_model.forward( - input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask - ) - - assert logits.shape[0] == micro_batch_size - assert logits.shape[1] == sequence_length - assert logits.shape[2] == self.gpt_model.vocab_size - - -def test_get_mlp_module_spec_interface(): - # Get the function signature - sig = inspect.signature(get_mlp_module_spec) - - # Define the expected signature - expected_params = { - "use_te": inspect.Parameter.POSITIONAL_OR_KEYWORD, - "num_experts": inspect.Parameter.POSITIONAL_OR_KEYWORD, - "moe_grouped_gemm": inspect.Parameter.POSITIONAL_OR_KEYWORD, - "fp8": inspect.Parameter.POSITIONAL_OR_KEYWORD, - "moe_use_legacy_grouped_gemm": inspect.Parameter.POSITIONAL_OR_KEYWORD, - "use_te_op_fuser": inspect.Parameter.POSITIONAL_OR_KEYWORD, - } - - expected_defaults = { - "use_te": True, - "num_experts": None, - "moe_grouped_gemm": False, - "fp8": None, - "moe_use_legacy_grouped_gemm": False, - "use_te_op_fuser": False, - } - - # Check expected parameters are in function signature - for param_name, param_kind in expected_params.items(): - assert param_name in sig.parameters, f"Unexpected parameter: {param_name}" - assert ( - param_kind is sig.parameters[param_name].kind - ), f"Wrong kind for parameter: {param_name}" - - # Check default values - sig_defaults = { - k: v.default for k, v in sig.parameters.items() if v.default is not inspect.Parameter.empty - } - for k, v in expected_defaults.items(): - assert ( - k in sig_defaults and v == sig_defaults[k] - ), f"Default value of {sig_defaults[k]} does not match the expected value of {v} for parameter {k}." - - -@pytest.mark.skipif( - not is_te_min_version("1.13.0"), reason="TEFusedMLP is only supported with TE 1.13+." -) -class TestGPTWithFusedOps: - """GPT model with Transformer Engine operation-based API""" - - def setup_method(self, method) -> None: - os.environ.pop('NVTE_FUSED_ATTN', None) - os.environ.pop('NVTE_FLASH_ATTN', None) - os.environ.pop('NVTE_UNFUSED_ATTN', None) - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - transformer_config = TransformerConfig( - num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True - ) - self.gpt_model = GPTModel( - config=transformer_config, - transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec(use_te_op_fuser=True), - vocab_size=100, - max_sequence_length=4, - ) - - def teardown_method(self, method) -> None: - Utils.destroy_model_parallel() - - @pytest.mark.internal - def test_forward(self) -> None: - _ = self.gpt_model.config - sequence_length = self.gpt_model.max_sequence_length - micro_batch_size = 2 - - self.gpt_model.cuda() - - data = list(range(sequence_length)) - input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - position_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - attention_mask = torch.ones( - (micro_batch_size, 1, sequence_length, sequence_length), dtype=bool - ).cuda() - - logits = self.gpt_model.forward( - input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask - ) - - assert logits.shape[0] == micro_batch_size - assert logits.shape[1] == sequence_length - assert logits.shape[2] == self.gpt_model.vocab_size - - -@pytest.mark.skipif( - not is_te_min_version("1.13.0"), reason="TEFusedMLP is only supported with TE 1.13+." -) -@pytest.mark.parametrize("num_experts", [None, 4]) -@pytest.mark.parametrize("gated_linear_unit", [True, False]) -def test_gpt_with_te_activation_func(num_experts, gated_linear_unit): - """Test GPT model with Transformer Engine activation function""" - - # setup - os.environ.pop('NVTE_FUSED_ATTN', None) - os.environ.pop('NVTE_FLASH_ATTN', None) - os.environ.pop('NVTE_UNFUSED_ATTN', None) - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - transformer_config = TransformerConfig( - num_layers=2, - hidden_size=512, - num_attention_heads=4, - use_cpu_initialization=True, - add_bias_linear=False, - use_te_activation_func=True, - bias_activation_fusion=False, - gated_linear_unit=gated_linear_unit, - num_moe_experts=num_experts, - moe_grouped_gemm=(num_experts is not None), - ) - gpt_model = GPTModel( - config=transformer_config, - transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec( - num_experts=num_experts, use_te_activation_func=True - ), - vocab_size=128, - max_sequence_length=128, - ) - - # test - sequence_length = gpt_model.max_sequence_length - micro_batch_size = 2 - - gpt_model.cuda() - - data = list(range(sequence_length)) - input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - position_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - attention_mask = torch.ones( - (micro_batch_size, 1, sequence_length, sequence_length), dtype=bool - ).cuda() - - logits = gpt_model.forward( - input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask - ) - - assert logits.shape[0] == micro_batch_size - assert logits.shape[1] == sequence_length - assert logits.shape[2] == gpt_model.vocab_size - - # teardown - Utils.destroy_model_parallel() - - -class TestGPTModelWithCustomPG: - def setup_method(self, method): - Utils.destroy_model_parallel() - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.skipif( - version.parse(torch.__version__) < version.parse('2.3.0'), - reason="Device mesh feature requires PyTorch 2.3 or later", - ) - @pytest.mark.parametrize( - "tp_size, dp_size, cp_size", [(1, 8, 1), (2, 4, 1)] # TP 1, DP 8, CP 1 # TP 2, DP 4, CP 1 - ) - def test_gpt_model_with_custom_pg(self, tp_size, dp_size, cp_size): - Utils.initialize_model_parallel(tensor_model_parallel_size=tp_size) - model_parallel_cuda_manual_seed(123) - - # Create HyperCommGrid with dimensions tp, cp, ep, pp, dp (reversed from device mesh order) - grid = HyperCommGrid([tp_size, cp_size, 1, 1, dp_size], ["tp", "cp", "ep", "pp", "dp"]) - - tp_group = grid.create_pg("tp") - cp_group = grid.create_pg("cp") - pp_group = grid.create_pg("pp") - ep_group = grid.create_pg("ep") - embd_group_ranks = parallel_state.default_embedding_ranks( - torch.distributed.get_process_group_ranks(pp_group) - ) - embd_group = torch.distributed.new_group( - ranks=embd_group_ranks, timeout=timedelta(minutes=30) - ) - model_comm_pgs = ModelCommProcessGroups( - tp=tp_group, cp=cp_group, pp=pp_group, ep=ep_group, embd=embd_group - ) - - transformer_config = TransformerConfig( - num_layers=2, hidden_size=1024, num_attention_heads=16, use_cpu_initialization=True - ) - self.gpt_model = GPTModel( - config=transformer_config, - transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec(), - vocab_size=100, - max_sequence_length=512, - model_comm_pgs=model_comm_pgs, - post_process=False, - ) - - # Check that model weights are distributed as expected when using TP - assert ( - self.gpt_model.decoder.layers[0].self_attention.linear_qkv.weight.shape[0] - == (1024 * 3) / tp_size - ) - assert self.gpt_model.decoder.layers[0].self_attention.linear_qkv.weight.shape[1] == 1024 - assert self.gpt_model.decoder.layers[0].self_attention.linear_proj.weight.shape[0] == 1024 - assert ( - self.gpt_model.decoder.layers[0].self_attention.linear_proj.weight.shape[1] - == 1024 / tp_size - ) - - # Check that the logits output shape is correct - sequence_length = self.gpt_model.max_sequence_length - micro_batch_size = 2 - - self.gpt_model.cuda() - - input_ids = torch.ones(micro_batch_size, sequence_length, dtype=torch.int64, device="cuda") - position_ids = torch.ones( - micro_batch_size, sequence_length, dtype=torch.int64, device="cuda" - ) - - logits = self.gpt_model.forward( - input_ids=input_ids, position_ids=position_ids, attention_mask=None - ) - - assert logits.shape[0] == sequence_length - assert logits.shape[1] == micro_batch_size - assert logits.shape[2] == self.gpt_model.config.hidden_size diff --git a/tests/unit_tests/models/test_gpt_model_quantization.py b/tests/unit_tests/models/test_gpt_model_quantization.py deleted file mode 100644 index 2b7c5cc6ff..0000000000 --- a/tests/unit_tests/models/test_gpt_model_quantization.py +++ /dev/null @@ -1,248 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -import pytest - -from megatron.core.models.gpt import GPTModel -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec -from megatron.core.quantization.quant_config import MatchContext, RecipeConfig -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer import TransformerConfig -from tests.unit_tests.test_utilities import Utils - -try: - HAVE_TE = True - import transformer_engine as te -except ImportError: - HAVE_TE = False - -try: - import nvidia_kitchen - - HAVE_KITCHEN = True -except ImportError: - HAVE_KITCHEN = False - - -@pytest.mark.skipif(not HAVE_KITCHEN, reason="Kitchen required for using kitchen backend.") -@pytest.mark.skipif( - not HAVE_TE, reason="Transformer Engine required for using kitchen backend with TE layers." -) -class TestGPTModelKitchenQuantizationConfig: - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_kitchen_config_resolution_dense(self) -> None: - transformer_config = TransformerConfig( - num_layers=2, - hidden_size=12, - num_attention_heads=4, - use_cpu_initialization=False, - gated_linear_unit=True, - bias_activation_fusion=True, - add_bias_linear=False, - use_kitchen=True, - quant_recipe=RecipeConfig.from_config_dict( - { - "matchers": { - "keep_in_hp": { - "type": "glob", - "enabled": True, - "pattern": "*fc2", - "config": "bf16", - }, - "use_fp8_cs": { - "type": "glob", - "enabled": True, - "pattern": "*", - "config": "fp8_cs", - }, - }, - "configs": { - "bf16": {"kitchen_config_type": "QLinearParams", "recipe_idx": 1}, - "fp8_cs": {"kitchen_config_type": "QLinearParams", "recipe_idx": 2}, - }, - } - ), - ) - transformer_layer_spec = get_gpt_decoder_block_spec( - config=transformer_config, use_transformer_engine=True - ) - padded_vocab_size = 512 - max_position_embeddings = 4096 - model = GPTModel( - config=transformer_config, - transformer_layer_spec=transformer_layer_spec, - vocab_size=padded_vocab_size, - max_sequence_length=max_position_embeddings, - ) - - expected_types = { - "decoder.layers.0.self_attention.linear_proj": KitchenRowParallelLinear, - "decoder.layers.1.self_attention.linear_proj": KitchenRowParallelLinear, - "decoder.layers.0.self_attention.linear_qkv": KitchenLayerNormColumnParallelLinear, - "decoder.layers.1.self_attention.linear_qkv": KitchenLayerNormColumnParallelLinear, - "decoder.layers.0.mlp.linear_fc1": KitchenLayerNormColumnParallelLinear, - "decoder.layers.1.mlp.linear_fc1": KitchenLayerNormColumnParallelLinear, - "decoder.layers.0.mlp.linear_fc2": KitchenRowParallelLinear, - "decoder.layers.1.mlp.linear_fc2": KitchenRowParallelLinear, - } - - expected_match = { - "decoder.layers.0.self_attention.linear_proj": ( - MatchContext("decoder.layers.0.self_attention.linear_proj", layer_number=0), - "fp8_cs", - ), - "decoder.layers.1.self_attention.linear_proj": ( - MatchContext("decoder.layers.1.self_attention.linear_proj", layer_number=1), - "fp8_cs", - ), - "decoder.layers.0.self_attention.linear_qkv": ( - MatchContext("decoder.layers.0.self_attention.linear_qkv", layer_number=0), - "fp8_cs", - ), - "decoder.layers.1.self_attention.linear_qkv": ( - MatchContext("decoder.layers.1.self_attention.linear_qkv", layer_number=1), - "fp8_cs", - ), - "decoder.layers.0.mlp.linear_fc1": ( - MatchContext("decoder.layers.0.mlp.linear_fc1", layer_number=0), - "fp8_cs", - ), - "decoder.layers.1.mlp.linear_fc1": ( - MatchContext("decoder.layers.1.mlp.linear_fc1", layer_number=1), - "fp8_cs", - ), - "decoder.layers.0.mlp.linear_fc2": ( - MatchContext("decoder.layers.0.mlp.linear_fc2", layer_number=0), - "bf16", - ), - "decoder.layers.1.mlp.linear_fc2": ( - MatchContext("decoder.layers.1.mlp.linear_fc2", layer_number=1), - "bf16", - ), - } - - visited_keys = set() - for name, module in model.named_modules(): - if name in expected_types: - assert ( - type(module) == expected_types[name] - ), f"Expected {name} to be {expected_types[name]}, but it is {type(module)}" - visited_keys.add(name) - assert hasattr(module, "kitchen_quant_params") - assert module.kitchen_quant_params.params_config_key == expected_match[name][1] - assert module.kitchen_quant_params.match_input == expected_match[name][0] - assert visited_keys == set(expected_types.keys()) - - def test_kitchen_config_resolution_moe(self) -> None: - transformer_config = TransformerConfig( - moe_layer_freq=1, - num_moe_experts=2, - moe_router_load_balancing_type="sinkhorn", - moe_router_topk=1, - moe_grouped_gemm=True, - moe_use_legacy_grouped_gemm=False, - num_layers=2, - hidden_size=12, - num_attention_heads=4, - use_cpu_initialization=False, - gated_linear_unit=True, - bias_activation_fusion=True, - add_bias_linear=False, - use_kitchen=True, - quant_recipe=RecipeConfig.from_config_dict( - { - "matchers": { - "keep_in_hp": { - "type": "glob", - "enabled": True, - "pattern": "*fc2", - "config": "bf16", - }, - "use_fp8_cs": { - "type": "glob", - "enabled": True, - "pattern": "*", - "config": "fp8_cs", - }, - }, - "configs": { - "bf16": {"kitchen_config_type": "QLinearParams", "recipe_idx": 1}, - "fp8_cs": {"kitchen_config_type": "QLinearParams", "recipe_idx": 2}, - }, - } - ), - ) - transformer_layer_spec = get_gpt_decoder_block_spec( - config=transformer_config, use_transformer_engine=True - ) - padded_vocab_size = 512 - max_position_embeddings = 4096 - model = GPTModel( - config=transformer_config, - transformer_layer_spec=transformer_layer_spec, - vocab_size=padded_vocab_size, - max_sequence_length=max_position_embeddings, - ) - - expected_types = { - "decoder.layers.0.self_attention.linear_proj": KitchenRowParallelLinear, - "decoder.layers.1.self_attention.linear_proj": KitchenRowParallelLinear, - "decoder.layers.0.self_attention.linear_qkv": KitchenLayerNormColumnParallelLinear, - "decoder.layers.1.self_attention.linear_qkv": KitchenLayerNormColumnParallelLinear, - "decoder.layers.0.mlp.experts.linear_fc1": KitchenColumnParallelGroupedLinear, - "decoder.layers.1.mlp.experts.linear_fc1": KitchenColumnParallelGroupedLinear, - "decoder.layers.0.mlp.experts.linear_fc2": KitchenRowParallelGroupedLinear, - "decoder.layers.1.mlp.experts.linear_fc2": KitchenRowParallelGroupedLinear, - } - - expected_match = { - "decoder.layers.0.self_attention.linear_proj": ( - MatchContext("decoder.layers.0.self_attention.linear_proj", layer_number=0), - "fp8_cs", - ), - "decoder.layers.1.self_attention.linear_proj": ( - MatchContext("decoder.layers.1.self_attention.linear_proj", layer_number=1), - "fp8_cs", - ), - "decoder.layers.0.self_attention.linear_qkv": ( - MatchContext("decoder.layers.0.self_attention.linear_qkv", layer_number=0), - "fp8_cs", - ), - "decoder.layers.1.self_attention.linear_qkv": ( - MatchContext("decoder.layers.1.self_attention.linear_qkv", layer_number=1), - "fp8_cs", - ), - "decoder.layers.0.mlp.experts.linear_fc1": ( - MatchContext("decoder.layers.0.mlp.experts.linear_fc1", layer_number=0), - "fp8_cs", - ), - "decoder.layers.1.mlp.experts.linear_fc1": ( - MatchContext("decoder.layers.1.mlp.experts.linear_fc1", layer_number=1), - "fp8_cs", - ), - "decoder.layers.0.mlp.experts.linear_fc2": ( - MatchContext("decoder.layers.0.mlp.experts.linear_fc2", layer_number=0), - "bf16", - ), - "decoder.layers.1.mlp.experts.linear_fc2": ( - MatchContext("decoder.layers.1.mlp.experts.linear_fc2", layer_number=1), - "bf16", - ), - } - - visited_keys = set() - for name, module in model.named_modules(): - if name in expected_types: - assert ( - type(module) == expected_types[name] - ), f"Expected {name} to be {expected_types[name]}, but it is {type(module)}" - visited_keys.add(name) - assert hasattr(module, "kitchen_quant_params") - assert module.kitchen_quant_params.params_config_key == expected_match[name][1] - assert module.kitchen_quant_params.match_input == expected_match[name][0] - assert visited_keys == set(expected_types.keys()) diff --git a/tests/unit_tests/models/test_heterogeneous_gpt_model.py b/tests/unit_tests/models/test_heterogeneous_gpt_model.py deleted file mode 100644 index 56d112021c..0000000000 --- a/tests/unit_tests/models/test_heterogeneous_gpt_model.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import json - -import pytest -import torch - -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import ( - get_gpt_heterogeneous_layer_spec, -) -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.heterogeneous.heterogeneous_config import ( - HeterogeneousTransformerConfig, -) -from megatron.core.utils import is_torch_min_version -from tests.unit_tests.test_utilities import Utils - -TORCH_VERSION_GE_2_4 = is_torch_min_version("2.4.0") - -first_layer = { - "attention": {"no_op": False, "replace_with_linear": False, "num_query_groups": 8}, - "mlp": {"no_op": False, "replace_with_linear": False, "ffn_hidden_size": 14336}, -} - - -@pytest.fixture -def heterogeneous_gpt_model(request, tmp_path): - ( - attention_no_op, - attention_replace_with_linear, - attention_num_query_groups, - mlp_no_op, - mlp_replace_with_linear, - mlp_ffn_hidden_size, - use_transformer_engine, - ) = request.param - - second_layer_config = { - "attention": { - "no_op": attention_no_op, - "replace_with_linear": attention_replace_with_linear, - "num_query_groups": attention_num_query_groups, - }, - "mlp": { - "no_op": mlp_no_op, - "replace_with_linear": mlp_replace_with_linear, - "ffn_hidden_size": mlp_ffn_hidden_size, - }, - } - - block_config_data = {"block_configs": [first_layer, second_layer_config]} - block_config_file = tmp_path / "config.json" - block_config_file.write_text(json.dumps(block_config_data)) - - transformer_config = HeterogeneousTransformerConfig( - num_layers=2, - hidden_size=4096, - add_bias_linear=False, - normalization="RMSNorm", - gated_linear_unit=True, - num_attention_heads=32, - use_cpu_initialization=True, - perform_initialization=False, - heterogeneous_layers_config_path=str(block_config_file), - ) - - return GPTModel( - transformer_config, - transformer_layer_spec=get_gpt_heterogeneous_layer_spec( - transformer_config, use_te=use_transformer_engine - ), - vocab_size=128256, - position_embedding_type="rope", - max_sequence_length=4, - ) - - -@pytest.mark.parametrize( - "heterogeneous_gpt_model, expected_num_parameters", - [ - ((False, False, 8, False, False, 14336, True), 1486901248), # regular TE - pytest.param( - (False, False, 8, False, False, 14336, False), - 1486901248, - marks=pytest.mark.skipif(not TORCH_VERSION_GE_2_4, reason="Requires PyTorch >= 2.4.0"), - ), # regular local - ((True, False, None, False, False, 14336, True), 1444954112), # attn no-op TE - pytest.param( - (True, False, None, False, False, 14336, False), - 1444954112, - marks=pytest.mark.skipif(not TORCH_VERSION_GE_2_4, reason="Requires PyTorch >= 2.4.0"), - ), # attn no-op local - ((False, False, 8, True, False, None, True), 1310736384), # mlp no-op TE - pytest.param( - (False, False, 8, True, False, None, False), - 1310736384, - marks=pytest.mark.skipif(not TORCH_VERSION_GE_2_4, reason="Requires PyTorch >= 2.4.0"), - ), # mlp no-op local - ((False, True, None, False, False, 14336, True), 1461735424), # attn replace with linear TE - pytest.param( - (False, True, None, False, False, 14336, False), - 1461735424, - marks=pytest.mark.skipif(not TORCH_VERSION_GE_2_4, reason="Requires PyTorch >= 2.4.0"), - ), # attn replace with linear local - ((False, False, 8, False, True, None, True), 1327517696), # mlp replace with linear TE - pytest.param( - (False, False, 8, False, True, None, False), - 1327517696, - marks=pytest.mark.skipif(not TORCH_VERSION_GE_2_4, reason="Requires PyTorch >= 2.4.0"), - ), # mlp replace with linear local - ], - indirect=["heterogeneous_gpt_model"], -) -class TestHeterogeneousGPTModel: - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_constructor(self, heterogeneous_gpt_model, expected_num_parameters): - assert isinstance(heterogeneous_gpt_model, GPTModel) - - assert heterogeneous_gpt_model.max_sequence_length == 4 - - num_weights = sum([p.numel() for p in heterogeneous_gpt_model.parameters()]) - assert num_weights == expected_num_parameters - - def test_post_process_forward(self, heterogeneous_gpt_model, expected_num_parameters): - sequence_length = heterogeneous_gpt_model.max_sequence_length - micro_batch_size = 2 - - heterogeneous_gpt_model.cuda() - - data = list(range(sequence_length)) - input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - position_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - attention_mask = torch.ones( - (micro_batch_size, 1, sequence_length, sequence_length), dtype=bool - ).cuda() - - logits = heterogeneous_gpt_model.forward( - input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask - ) - - assert logits.shape[0] == micro_batch_size - assert logits.shape[1] == sequence_length - assert logits.shape[2] == heterogeneous_gpt_model.vocab_size diff --git a/tests/unit_tests/models/test_llava_model.py b/tests/unit_tests/models/test_llava_model.py deleted file mode 100644 index cee6d2b0b2..0000000000 --- a/tests/unit_tests/models/test_llava_model.py +++ /dev/null @@ -1,810 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -from contextlib import nullcontext -from copy import deepcopy -from types import SimpleNamespace - -import pytest -import torch - -from megatron.core import parallel_state as ps -from megatron.core.inference.contexts import StaticInferenceContext -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec -from megatron.core.models.multimodal import context_parallel -from megatron.core.models.multimodal.llava_model import LLaVAModel -from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec -from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import is_te_min_version -from megatron.training.global_vars import set_args -from tests.unit_tests.test_utilities import Utils - - -class TestLLaVAModel: - @pytest.mark.internal # The model is under active development and its methods may change. - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - - self.language_hidden_size = 64 - self.language_num_attention_heads = 4 - - language_config = TransformerConfig( - num_layers=3, - hidden_size=self.language_hidden_size, - num_attention_heads=self.language_num_attention_heads, - use_cpu_initialization=False, - ) - vision_config = TransformerConfig( - num_layers=2, hidden_size=16, num_attention_heads=2, use_cpu_initialization=False - ) - vision_projection_config = TransformerConfig( - num_layers=2, - hidden_size=self.language_hidden_size, - ffn_hidden_size=32, - num_attention_heads=1, - use_cpu_initialization=False, - ) - - language_layer_spec = get_gpt_layer_with_transformer_engine_spec() - vision_layer_spec = deepcopy(language_layer_spec) - vision_projection_spec = deepcopy(language_layer_spec.submodules.mlp.submodules) - - language_config.language_model_type = "dummy" - vision_config.vision_model_type = "clip" - self.model = LLaVAModel( - language_transformer_config=language_config, - language_transformer_layer_spec=language_layer_spec, - language_vocab_size=8192, - language_max_sequence_length=4096, - vision_transformer_config=vision_config, - vision_transformer_layer_spec=vision_layer_spec, - drop_vision_class_token=False, - vision_projection_config=vision_projection_config, - vision_projection_layer_spec=vision_projection_spec, - img_h=336, - img_w=336, - patch_dim=14, - ) - - @pytest.mark.internal - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.internal - def test_constructor(self): - assert isinstance(self.model, LLaVAModel) - - num_weights = sum([p.numel() for p in self.model.parameters()]) - assert num_weights == 1488736 - - @pytest.mark.internal - def test_set_input_tensor(self): - expected_shape = (1, 2, 3, 4) - input_tensor = torch.zeros(expected_shape) - self.model.set_input_tensor(input_tensor) - assert self.model.vision_model.decoder.input_tensor.shape == expected_shape - - @pytest.mark.internal - def test_preprocess_data(self): - self.model.cuda() - - hidden_size = 72 - - # 3 images with 1 tile and 2 image with 2 tiles = 7 tiles. - image_embeddings = ( - torch.arange(577 * 7 * hidden_size, dtype=torch.float) - .reshape(577, 7, hidden_size) - .cuda() - ) - - image_token_index = self.model.image_token_index - input_ids = torch.arange(1024).expand(5, 1024).cuda() - input_ids[0, 0] = image_token_index # image before text - input_ids[1, 100] = image_token_index # image in between - input_ids[2, -1] = image_token_index # image at the end - # input_ids[3] - no image - input_ids[4, 50] = image_token_index # two images in between - input_ids[4, 150] = image_token_index - - # Using negative sign to distinguish from image embeddings. - language_embeddings = ( - -torch.arange(5 * 1024 * hidden_size, dtype=torch.float) - .reshape(5, 1024, hidden_size) - .cuda() - ) - - # Labels are input_ids shifted to left by one. - labels = torch.arange(1, 1025, dtype=torch.int).expand(5, 1024).cuda() - # labels[0] - image token got dropped by shift to left by one. - labels[1, 99] = image_token_index - labels[2, -2] = image_token_index - # labels[3] - no image. - labels[4, 49] = image_token_index - labels[4, 149] = image_token_index - - loss_mask = torch.ones((5, 1024), dtype=torch.float).cuda() - # Mask some text inputs (the text mask should carry over) - loss_mask[:2, :10] = 0.0 - loss_mask[:2, 110:120] = 0.0 - - # Number of tiles for each image in the batch. - num_image_tiles = torch.tensor([1, 2, 1, 2, 1], dtype=torch.int).cuda() - - use_inference_kv_cache = False - inference_context = None - - embeddings, labels, loss_mask = self.model._preprocess_data( - image_embeddings, - language_embeddings, - input_ids, - loss_mask, - labels, - use_inference_kv_cache, - inference_context, - image_token_index, - num_image_tiles, - ) - - img_seq_len = 577 - # The fifth sample has 2 images with 3 tiles and 1024 text tokens. - max_seq_len = 3 * img_seq_len - 2 + 1024 - - assert embeddings.shape == torch.Size((max_seq_len, 5, hidden_size)) - assert labels.shape == torch.Size((5, max_seq_len)) - assert loss_mask.shape == labels.shape - - # First sample where image is before text (index 0). - expected_embeddings = torch.empty(max_seq_len, hidden_size).cuda() - expected_embeddings[:577] = image_embeddings[:, 0] - expected_embeddings[577:1600] = language_embeddings[0, 1:] - expected_embeddings[1600:] = 0 # padding - - expected_labels = torch.empty(max_seq_len, dtype=torch.int).cuda() - expected_labels[:576] = -100 # image - expected_labels[576:1600] = torch.arange(1, 1025, dtype=torch.int) - expected_labels[1600:] = -100 # padding - - expected_loss_mask = torch.empty(max_seq_len, dtype=torch.float).cuda() - expected_loss_mask[:577] = 0 - expected_loss_mask[577:586] = 0 - expected_loss_mask[586:686] = 1 - expected_loss_mask[686:696] = 0 - expected_loss_mask[696:1600] = 1 - expected_loss_mask[1600:] = 0 - - assert torch.allclose(embeddings[:, 0], expected_embeddings) - assert torch.allclose(labels[0], expected_labels) - assert torch.allclose(loss_mask[0], expected_loss_mask) - - # Second sample where image is in between (index 100). The image has 2 tiles. - expected_embeddings = torch.empty(max_seq_len, hidden_size).cuda() - expected_embeddings[:100] = language_embeddings[1, :100] - expected_embeddings[100:677] = image_embeddings[:, 1] - expected_embeddings[677:1254] = image_embeddings[:, 2] - expected_embeddings[1254:2177] = language_embeddings[1, 101:] - expected_embeddings[2177:] = 0 # padding - - expected_labels = torch.empty(max_seq_len, dtype=torch.int).cuda() - expected_labels[:99] = torch.arange(1, 100) - expected_labels[99:1253] = -100 # image - expected_labels[1253:2177] = torch.arange(101, 1025) - expected_labels[2177:] = -100 # padding - - expected_loss_mask = torch.empty(max_seq_len, dtype=torch.float).cuda() - expected_loss_mask[:10] = 0 - expected_loss_mask[10:99] = 1 - # Last text position before the image is not required to predict the first image embedding. - expected_loss_mask[99] = 0 - expected_loss_mask[100:1254] = 0 - expected_loss_mask[1254:1263] = 1 - expected_loss_mask[1263:1273] = 0 - expected_loss_mask[1273:2177] = 1 - expected_loss_mask[2177:] = 0 # padding - - assert torch.allclose(embeddings[:, 1], expected_embeddings) - assert torch.allclose(labels[1], expected_labels) - assert torch.allclose(loss_mask[1], expected_loss_mask) - - # Third sample where image is at the end. - expected_embeddings = torch.empty(max_seq_len, hidden_size).cuda() - expected_embeddings[:1023] = language_embeddings[2, :1023] - expected_embeddings[1023:1600] = image_embeddings[:, 3] - expected_embeddings[1600:] = 0 # padding - - expected_labels = torch.empty(max_seq_len, dtype=torch.int).cuda() - expected_labels[:1022] = torch.arange(1, 1023) - expected_labels[1022:1599] = -100 - expected_labels[1599] = 1024 - expected_labels[1600:] = -100 # padding - - expected_loss_mask = torch.empty(max_seq_len, dtype=torch.float).cuda() - expected_loss_mask[:1022] = 1 - # Last text position before the image is not required to predict the first image embedding. - expected_loss_mask[1022] = 0 - expected_loss_mask[1023:1600] = 0 - expected_loss_mask[1600:] = 0 # padding - - assert torch.allclose(embeddings[:, 2], expected_embeddings) - assert torch.allclose(labels[2], expected_labels) - assert torch.allclose(loss_mask[2], expected_loss_mask) - - # Fourth sample where there is no image. - expected_embeddings = torch.empty(max_seq_len, hidden_size).cuda() - expected_embeddings[:1024] = language_embeddings[3] - expected_embeddings[1024:] = 0 # padding - - expected_labels = torch.empty(max_seq_len, dtype=torch.int).cuda() - expected_labels[:1024] = torch.arange(1, 1025) - expected_labels[1024:] = -100 # padding - - expected_loss_mask = torch.empty(max_seq_len, dtype=torch.float).cuda() - expected_loss_mask[:1024] = 1 - expected_loss_mask[1024:] = 0 # padding - - assert torch.allclose(embeddings[:, 3], expected_embeddings) - assert torch.allclose(labels[3], expected_labels) - assert torch.allclose(loss_mask[3], expected_loss_mask) - - # Fifth sample has two images in between (indices 50 and 150). The first image has two tiles. - expected_embeddings = torch.empty(max_seq_len, hidden_size).cuda() - expected_embeddings[:50] = language_embeddings[4, :50] - expected_embeddings[50:627] = image_embeddings[:, 4] # two tiles - expected_embeddings[627:1204] = image_embeddings[:, 5] - expected_embeddings[1204:1303] = language_embeddings[4, 51:150] - expected_embeddings[1303:1880] = image_embeddings[:, 6] - expected_embeddings[1880:] = language_embeddings[4, 151:] - - expected_labels = torch.empty(max_seq_len, dtype=torch.int).cuda() - expected_labels[:49] = torch.arange(1, 50) - expected_labels[49:1203] = -100 # image - expected_labels[1203:1302] = torch.arange(51, 150) - expected_labels[1302:1879] = -100 # image - expected_labels[1879:] = torch.arange(151, 1025) - - expected_loss_mask = torch.empty(max_seq_len, dtype=torch.float).cuda() - expected_loss_mask[:49] = 1 - expected_loss_mask[49:1204] = 0 - expected_loss_mask[1204:1302] = 1 - expected_loss_mask[1302:1880] = 0 - expected_loss_mask[1880:] = 1 - - assert torch.allclose(embeddings[:, 4], expected_embeddings) - assert torch.allclose(labels[4], expected_labels) - assert torch.allclose(loss_mask[4], expected_loss_mask) - - @pytest.mark.internal - def test_forward(self): - self.model.cuda() - - # 3 images with 1 tile and 2 images with 2 tiles. - img = torch.randn((7, 3, 336, 336)).cuda() - - image_token_index = self.model.image_token_index - input_ids = torch.randint(0, 2048, (5, 1024)).cuda() - input_ids[0, 0] = image_token_index # image before text - input_ids[1, 100] = image_token_index # image in between - input_ids[2, -1] = image_token_index # image at the end - # input_ids[3] - no image - input_ids[4, 50] = image_token_index - input_ids[4, 150] = image_token_index - - position_ids = torch.arange(0, 1024, dtype=torch.int).expand(5, 1024).cuda() - - loss_mask = torch.ones((5, 1024)).cuda() - - attention_mask = None # Causal. - - labels = torch.randint(0, 2048, (5, 1024)).cuda() - labels[1, 99] = image_token_index - labels[2, -2] = image_token_index - - num_image_tiles = torch.tensor([1, 2, 1, 2, 1], dtype=torch.int).cuda() - - # Try with labels. - loss, new_loss_mask = self.model.forward( - img, - input_ids, - position_ids, - attention_mask, - labels, - loss_mask, - num_image_tiles=num_image_tiles, - ) - - # The maximum sequence length is given by the sample with 2 images in 3 tiles, minus two image token indices, plus other text tokens. - img_seq_len = 577 - max_seq_len = img_seq_len * 3 - 2 + 1024 - assert loss.shape == new_loss_mask.shape == torch.Size((5, max_seq_len)) - - # Try with labels and PackedSeqParams. Only micro batch size 1 is supported in this mode. - packed_seq_params = PackedSeqParams( - qkv_format="thd", - cu_seqlens_q=torch.tensor( - [0, 512, 1024, 1600], dtype=torch.int32 - ).cuda(), # Just example values. - cu_seqlens_kv=torch.tensor([0, 512, 1024, 1600], dtype=torch.int32).cuda(), - max_seqlen_q=1600, - max_seqlen_kv=1600, - ) - - # NOTE: Packing is only supported with BF16. Use BF16 here and switch back to default. - self.model.to(torch.bfloat16) - loss, new_loss_mask = self.model.forward( - img[:1].to(torch.bfloat16), - input_ids[:1], - position_ids[:1], - attention_mask, - labels[:1], - loss_mask[:1], - num_image_tiles=num_image_tiles[:1], - packed_seq_params=packed_seq_params, - ) - self.model.to(torch.float32) - - # 1600 = 577 (img_seq_len) + 1024 (text tokens in the first sample) - 1 (image token). - assert loss.shape == new_loss_mask.shape == torch.Size((1, 1600)) - - # Try text-only input. - loss, new_loss_mask = self.model.forward( - torch.tensor([], dtype=torch.float).cuda(), - torch.randint(0, 2048, (5, 1024)).cuda(), - position_ids, - attention_mask, - torch.randint(0, 2048, (5, 1024)).cuda(), - loss_mask, - num_image_tiles=torch.tensor([], dtype=torch.int).cuda(), - ) - - assert loss.shape == new_loss_mask.shape == torch.Size((5, 1024)) - - # Try without labels and without inference params. - logits, _ = self.model.forward( - img, - input_ids, - position_ids, - attention_mask, - labels=None, - loss_mask=None, - num_image_tiles=num_image_tiles, - ) - assert logits.shape == torch.Size((5, max_seq_len, 8192)) - - # Try without labels and with inference params. - inference_context = StaticInferenceContext(5, max_seq_len) - logits, _ = self.model.forward( - img, - input_ids, - position_ids, - attention_mask, - labels=None, - loss_mask=None, - num_image_tiles=num_image_tiles, - inference_context=inference_context, - ) - assert logits.shape == torch.Size((5, max_seq_len, 8192)) - - # Check KV cache got populated correctly. - kv_dict = inference_context.key_value_memory_dict - - assert kv_dict["image_tokens_count"] == 577 * 7 - for layer_no in range(1, 4): # 3 layers in the model. - layer_kv = kv_dict[layer_no] - # Expected shape is [sequence_len, batch_size, num_heads, hidden_size_per_head] - assert ( - layer_kv[0].shape - == layer_kv[1].shape - == torch.Size((max_seq_len, 5, self.language_num_attention_heads, 16)) - ) - - @pytest.mark.internal - def test_forward_fsdp(self): - """Test FSDP workaround for text-only data. - - FSDP can hang with text-only data. As a workaround, we run the vision model with a dummy image, - but then effectively discard the image embeddings. - """ - self.model.cuda() - - # Dummy image for the FSDP workaround but not image tiles. - img = torch.zeros((1, 3, 336, 336)).cuda() - num_image_tiles = torch.tensor([], dtype=torch.int).cuda() - - # No image tag in the input ids (text-only sample). - image_token_index = self.model.image_token_index - input_ids = torch.arange(1024, device="cuda").unsqueeze(0) - assert ( - torch.sum(input_ids == image_token_index) == 0 - ), "expected no image tag in the input ids" - - position_ids = torch.arange(1024, device="cuda").unsqueeze(0) - - loss_mask = torch.ones((1, 1024), device="cuda") - - attention_mask = None # Causal. - - labels = torch.arange(1, 1025, device="cuda").unsqueeze(0) - - # Mock the FSDP attribute. - self.model.vision_model._is_fsdp_managed_module = True - loss, new_loss_mask = self.model.forward( - img, - input_ids, - position_ids, - attention_mask, - labels, - loss_mask, - num_image_tiles=num_image_tiles, - ) - self.model.vision_model._is_fsdp_managed_module = False - - assert loss.shape == new_loss_mask.shape == torch.Size((1, 1024)) - - @pytest.mark.internal - def test_save_load(self, tmp_path): - path = tmp_path / "model.pt" - torch.save(self.model.state_dict(), path) - - self.model.load_state_dict(torch.load(path)) - - @pytest.mark.internal - def test_freeze(self): - self.model.freeze( - freeze_language_model=True, freeze_vision_model=True, freeze_vision_projection=False - ) - - for module in [self.model.language_model, self.model.vision_model]: - for param in module.parameters(): - assert not param.requires_grad - - for param in self.model.vision_projection.parameters(): - assert param.requires_grad - - -@pytest.fixture(scope='class', params=["siglip", "radio-g"]) -def setup_and_teardown_llava_model(request): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - - language_config = TransformerConfig( - num_layers=3, hidden_size=128, num_attention_heads=8, use_cpu_initialization=False - ) - vision_config = TransformerConfig( - num_layers=2, hidden_size=64, num_attention_heads=4, use_cpu_initialization=False - ) - vision_projection_config = TransformerConfig( - num_layers=2, - hidden_size=128, - ffn_hidden_size=72, - num_attention_heads=1, - use_cpu_initialization=False, - ) - - language_layer_spec = get_gpt_layer_with_transformer_engine_spec() - vision_layer_spec = deepcopy(language_layer_spec) - vision_projection_spec = deepcopy(language_layer_spec.submodules.mlp.submodules) - - language_config.language_model_type = "dummy" - vision_model_type = request.param - vision_config.vision_model_type = vision_model_type - model = LLaVAModel( - language_transformer_config=language_config, - language_transformer_layer_spec=language_layer_spec, - language_vocab_size=2048, - language_max_sequence_length=4096, - vision_transformer_config=vision_config, - vision_transformer_layer_spec=vision_layer_spec, - drop_vision_class_token=False, - vision_projection_config=vision_projection_config, - vision_projection_layer_spec=vision_projection_spec, - img_h=336, - img_w=336, - patch_dim=14, - ) - - yield model, vision_model_type - - Utils.destroy_model_parallel() - - -class TestLLaVAModelVisionEncoders: - num_weights_by_encoder = {"siglip": 1832456, "radio-g": 2844552} - - @pytest.mark.internal - def test_constructor(self, setup_and_teardown_llava_model): - model, vision_model_type = setup_and_teardown_llava_model - assert isinstance(model, LLaVAModel) - - num_weights = sum([p.numel() for p in model.parameters()]) - assert num_weights == self.num_weights_by_encoder[vision_model_type] - - @pytest.mark.internal - def test_set_input_tensor(self, setup_and_teardown_llava_model): - model, _ = setup_and_teardown_llava_model - expected_shape = (1, 2, 3, 4) - input_tensor = torch.zeros(expected_shape) - model.set_input_tensor(input_tensor) - assert model.vision_model.decoder.input_tensor.shape == expected_shape - - -def create_test_args(cp_size, sequence_parallel): - # Set dummy values for the args. - args = SimpleNamespace() - args.context_parallel_size = cp_size - args.sequence_parallel = sequence_parallel - - return args - - -class TestLLaVAModelTokenParallel: - - def _init_llava_model(self, cp_size, tp_size, sequence_parallel): - language_hidden_size = 64 - language_num_attention_heads = 16 - - language_config = TransformerConfig( - num_layers=3, - hidden_size=language_hidden_size, - num_attention_heads=language_num_attention_heads, - use_cpu_initialization=False, - tensor_model_parallel_size=tp_size, - sequence_parallel=sequence_parallel, - context_parallel_size=cp_size, - ) - # SP and CP are not yet supported for the Vision Backbone - vision_config = TransformerConfig( - num_layers=2, - hidden_size=16, - num_attention_heads=8, - use_cpu_initialization=False, - tensor_model_parallel_size=tp_size, - sequence_parallel=False, - context_parallel_size=1, - ) - vision_projection_config = TransformerConfig( - num_layers=2, - hidden_size=language_hidden_size, - ffn_hidden_size=128, - num_attention_heads=8, - use_cpu_initialization=False, - tensor_model_parallel_size=tp_size, - sequence_parallel=False, - context_parallel_size=1, - ) - - language_layer_spec = get_gpt_layer_with_transformer_engine_spec() - # SP/CP either requires user to ensure token lengths do not require padding OR change mask type to padding - if ( - language_layer_spec.submodules.self_attention.params.get('attn_mask_type', '') - == AttnMaskType.causal - ): - language_layer_spec.submodules.self_attention.params['attn_mask_type'] = ( - AttnMaskType.padding_causal - ) - elif ( - language_layer_spec.submodules.self_attention.params.get('attn_mask_type', '') - == AttnMaskType.no_mask - ): - language_layer_spec.submodules.self_attention.params['attn_mask_type'] = ( - AttnMaskType.padding - ) - - vision_layer_spec = deepcopy(language_layer_spec) - vision_projection_spec = deepcopy(language_layer_spec.submodules.mlp.submodules) - - language_config.language_model_type = "dummy" - vision_config.vision_model_type = "clip" - model = LLaVAModel( - language_transformer_config=language_config, - language_transformer_layer_spec=language_layer_spec, - language_vocab_size=8192, - language_max_sequence_length=4096, - vision_transformer_config=vision_config, - vision_transformer_layer_spec=vision_layer_spec, - drop_vision_class_token=False, - vision_projection_config=vision_projection_config, - vision_projection_layer_spec=vision_projection_spec, - img_h=336, - img_w=336, - patch_dim=14, - ) - - return model - - def _prepare_inputs(self, cp_size, tp_size, sequence_parallel, padding): - self.batch_size = 2 - if padding: - self.combined_valid_seqlen = 2049 - self.combined_padded_seqlen = 2064 - else: - self.combined_valid_seqlen = 2048 - self.combined_padded_seqlen = 2048 - - if cp_size > 1: - combined_embeddings = torch.ones( - [self.batch_size, self.combined_padded_seqlen, 4096], - device='cuda', - dtype=torch.bfloat16, - ) # [B, S, H] - else: - combined_embeddings = torch.ones( - [self.combined_padded_seqlen, self.batch_size, 4096], - device='cuda', - dtype=torch.bfloat16, - ) # [S, B, H] - new_labels = torch.ones( - [self.batch_size, self.combined_padded_seqlen], device='cuda', dtype=torch.bfloat16 - ) # [B, S] - new_loss_mask = torch.ones( - [self.batch_size, self.combined_padded_seqlen], device='cuda', dtype=torch.bfloat16 - ) # [B, S] - - cu_seqlens = torch.arange( - 0, - (self.batch_size + 1) * (self.combined_valid_seqlen), - step=(self.combined_valid_seqlen), - dtype=torch.int32, - device=combined_embeddings.device, - ) - cu_seqlens_padded = torch.arange( - 0, - (self.batch_size + 1) * (self.combined_padded_seqlen), - step=(self.combined_padded_seqlen), - dtype=torch.int32, - device=combined_embeddings.device, - ) - - qkv_format = 'sbhd' # Default format when not using padding - if cp_size > 1 and padding: - # Reshape from [B,S] to [1,T] - combined_embeddings = ( - combined_embeddings.contiguous() - .view(combined_embeddings.shape[0] * combined_embeddings.shape[1], -1) - .unsqueeze(0) - ) - new_labels = new_labels.view(new_labels.shape[0] * new_labels.shape[1]).unsqueeze(0) - new_loss_mask = new_loss_mask.view( - new_loss_mask.shape[0] * new_loss_mask.shape[1] - ).unsqueeze(0) - qkv_format = 'thd' - - packed_seq_params = PackedSeqParams( - cu_seqlens_q=cu_seqlens, - cu_seqlens_kv=cu_seqlens, - cu_seqlens_q_padded=cu_seqlens_padded, - cu_seqlens_kv_padded=cu_seqlens_padded, - max_seqlen_q=self.combined_padded_seqlen, - max_seqlen_kv=self.combined_padded_seqlen, - qkv_format=qkv_format, - ) - - return combined_embeddings, new_labels, new_loss_mask, packed_seq_params - - @pytest.mark.internal - def setup_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.internal - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.internal - @pytest.mark.parametrize( - "cp_size,tp_size,sequence_parallel,padding", - [(1, 8, True, True), (2, 4, False, True), (2, 4, True, False), (2, 4, True, True)], - ) - def test_process_embedding_token_parallel(self, cp_size, tp_size, sequence_parallel, padding): - """Test _process_embedding_token_parallel. - - Note: This test requires TE version >= 1.10.0 to run properly. - """ - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp_size, context_parallel_size=cp_size - ) - model_parallel_cuda_manual_seed(123) - - # TE version must be at least 1.10.0 if using context parallelism. Exit otherwise. - ctx = ( - nullcontext() - if (is_te_min_version("1.10.0") or cp_size <= 1) - else pytest.raises(AssertionError) - ) - model = None - with ctx: - model = self._init_llava_model(cp_size, tp_size, sequence_parallel) - - if model is None: - return - - model.cuda() - - args = create_test_args(cp_size, sequence_parallel) - set_args(args) - - combined_embeddings, new_labels, new_loss_mask, packed_seq_params = self._prepare_inputs( - cp_size, tp_size, sequence_parallel, padding - ) - - combined_embeddings, new_labels, new_loss_mask, packed_seq_params = ( - model._process_embedding_token_parallel( - combined_embeddings, new_labels, new_loss_mask, packed_seq_params - ) - ) - - # Check if output shape is as expected - if cp_size > 1 and sequence_parallel: - if padding: - # THD format - assert combined_embeddings.shape[0] == self.batch_size * ( - self.combined_padded_seqlen / (tp_size * cp_size) - ) - assert combined_embeddings.shape[1] == 1 - else: - # SBHD format - assert combined_embeddings.shape[0] == ( - self.combined_padded_seqlen / (tp_size * cp_size) - ) - assert combined_embeddings.shape[1] == self.batch_size - elif cp_size > 1: - if padding: - # THD format - assert combined_embeddings.shape[0] == self.batch_size * ( - self.combined_padded_seqlen / cp_size - ) - assert combined_embeddings.shape[1] == 1 - else: - # SBHD format - assert combined_embeddings.shape[0] == (self.combined_padded_seqlen / cp_size) - assert combined_embeddings.shape[1] == self.batch_size - else: - # SBHD format - assert combined_embeddings.shape[0] == self.combined_padded_seqlen / tp_size - assert combined_embeddings.shape[1] == self.batch_size - - -def count_parameters(model): - return sum(p.numel() for p in model.parameters()) - - -@pytest.mark.internal -@pytest.mark.parametrize( - "cp_size, tp_size, has_sp, seq_len, fp8_enabled, expected_padding", - [ - (1, 1, False, 99, False, 0), - (2, 2, True, 99, False, 5), - (2, 2, False, 99, False, 1), - (1, 4, False, 99, True, 13), - ], -) -def test_get_padding(cp_size, tp_size, has_sp, seq_len, fp8_enabled, expected_padding): - """Test calculating padding for context parallel.""" - padding = context_parallel.get_padding( - seq_len, cp_size, tp_size, has_sp, fp8_enabled=fp8_enabled - ) - - assert padding == expected_padding - - -@pytest.mark.internal -@pytest.mark.parametrize( - "tokens, img_seq_len, padding_needed, cp_size, expected_seq_len", - [(torch.ones((1, 100)), 100, 0, 2, 200), (torch.ones((1, 100)), 128, 1, 2, 227)], -) -def test_get_packed_seq_params(tokens, img_seq_len, padding_needed, cp_size, expected_seq_len): - """Test creating PackedSeqParams for context parallel.""" - packed_seq_params = context_parallel.get_packed_seq_params( - tokens, img_seq_len, padding_needed, cp_size - ) - - assert torch.equal( - packed_seq_params.cu_seqlens_q, torch.tensor([0, expected_seq_len], dtype=torch.int32) - ) - - if padding_needed > 0: - padded_seq_len = tokens.shape[1] + img_seq_len - assert torch.equal( - packed_seq_params.cu_seqlens_q_padded, - torch.tensor([0, padded_seq_len], dtype=torch.int32), - ) - assert packed_seq_params.max_seqlen_q == padded_seq_len diff --git a/tests/unit_tests/models/test_mamba_model.py b/tests/unit_tests/models/test_mamba_model.py deleted file mode 100644 index 1fbc05852e..0000000000 --- a/tests/unit_tests/models/test_mamba_model.py +++ /dev/null @@ -1,220 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -from datetime import timedelta - -import pytest -import torch - -from megatron.core import parallel_state -from megatron.core.hyper_comm_grid import HyperCommGrid -from megatron.core.inference.contexts import BaseInferenceContext, StaticInferenceContext -from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec -from megatron.core.models.mamba.mamba_model import MambaModel -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer import TransformerConfig -from megatron.core.utils import divide, is_torch_min_version -from tests.unit_tests.test_utilities import Utils - - -class TestMambaModel: - - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - model_config = TransformerConfig( - num_layers=3, # 1 Mamba layer, 1 attention layer, 1 MLP layer - hidden_size=256, # The Mamba layer places several constraints on this - num_attention_heads=4, - use_cpu_initialization=True, - ) - self.model = MambaModel( - config=model_config, - mamba_stack_spec=mamba_stack_spec, - vocab_size=100, - max_sequence_length=4, - hybrid_attention_ratio=0.3, - hybrid_mlp_ratio=0.3, - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_constructor(self): - assert isinstance(self.model, MambaModel) - - assert self.model.max_sequence_length == 4 - - num_weights = sum([p.numel() for p in self.model.parameters()]) - assert num_weights == 1774872 - - def test_set_input_tensor(self): - config: TransformerConfig = self.model.config - sequence_length = self.model.max_sequence_length - micro_batch_size = 2 - - # [sequence length, batch size, hidden size] - input_tensor = torch.ones((sequence_length, micro_batch_size, config.hidden_size)) - - self.model.set_input_tensor(input_tensor) - - assert self.model.decoder.input_tensor.shape[0] == sequence_length - assert self.model.decoder.input_tensor.shape[1] == micro_batch_size - assert self.model.decoder.input_tensor.shape[2] == config.hidden_size - - def test_forward(self): - config: TransformerConfig = self.model.config - sequence_length = self.model.max_sequence_length - micro_batch_size = 2 - - self.model.cuda() - - data = list(range(sequence_length)) - input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - position_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - attention_mask = torch.ones( - (micro_batch_size, 1, sequence_length, sequence_length), dtype=bool - ).cuda() - - logits = self.model.forward( - input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask - ) - - assert logits.shape[0] == micro_batch_size - assert logits.shape[1] == sequence_length - assert logits.shape[2] == self.model.vocab_size - - def test_inference(self): - config: TransformerConfig = self.model.config - micro_batch_size = 2 - inference_context: BaseInferenceContext = StaticInferenceContext( - max_batch_size=micro_batch_size, max_sequence_length=self.model.max_sequence_length - ) - prompt_length = self.model.max_sequence_length - 1 - - self.model.cuda() - - # load-context/first-output-token, step/generate - for offset in (0, prompt_length): - if offset == 0: - sequence_length = prompt_length - else: - sequence_length = 1 - inference_context.sequence_len_offset = offset - - data = list(range(sequence_length)) - input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - position_ids = ( - torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - ) - attention_mask = torch.ones( - (micro_batch_size, 1, sequence_length, sequence_length), dtype=bool - ).cuda() - - logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - inference_context=inference_context, - ) - - assert logits.shape[0] == micro_batch_size - assert logits.shape[1] == sequence_length - assert logits.shape[2] == self.model.vocab_size - - def test_save_load(self, tmp_path): - path = tmp_path / "model.pt" - torch.save(self.model.state_dict(), path) - - self.model.load_state_dict(torch.load(path)) - - def test_layer_numbers(self): - """ - The layer numbers should start at one (for the embedding # layer) and go up - incrementally from there. This is required for PEFT to work. - """ - model = self.model - for expected, layer in enumerate(model.decoder.layers, start=1): - assert expected == layer.layer_number, "layer numbers are incorrect" - - @pytest.mark.skipif( - not is_torch_min_version("2.4.0"), - reason="torch.distributed.init_device_mesh requires torch >= 2.4.0", - ) - @pytest.mark.parametrize("tp_size,cp_size,pp_size", [(2, 1, 4), (1, 1, 8), (8, 1, 1)]) - def test_with_custom_process_groups(self, tmp_path, tp_size, cp_size, pp_size): - """Test MambaModel with custom process groups.""" - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp_size, - context_parallel_size=cp_size, - pipeline_model_parallel_size=pp_size, - ) - - # Create device mesh for custom process groups - assert torch.distributed.get_world_size() == 8, "Test requires 8 GPUs" - - # Initialize torch.distributed if not already initialized - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend='nccl') - - # Create HyperCommGrid with dimensions tp, cp, pp (reversed from device mesh order) - grid = HyperCommGrid([tp_size, cp_size, pp_size], ["tp", "cp", "pp"]) - - pp_group = grid.create_pg("pp") - cp_group = grid.create_pg("cp") - tp_group = grid.create_pg("tp") - embd_group_ranks = parallel_state.default_embedding_ranks( - torch.distributed.get_process_group_ranks(pp_group) - ) - embd_group = torch.distributed.new_group( - ranks=embd_group_ranks, timeout=timedelta(minutes=30) - ) - - # Create model with custom process groups - from megatron.core.process_groups_config import ModelCommProcessGroups - - model_comm_pgs = ModelCommProcessGroups( - tp=tp_group, cp=cp_group, pp=pp_group, embd=embd_group - ) - - # Configure model with appropriate sizes for parallelism - model_config = TransformerConfig( - num_layers=3 * pp_size, # Scale layers with PP size - hidden_size=256 * tp_size, - num_attention_heads=4 * tp_size, # Scale heads with TP size - use_cpu_initialization=True, - tensor_model_parallel_size=tp_size, - context_parallel_size=cp_size, - pipeline_model_parallel_size=pp_size, - pipeline_dtype=torch.bfloat16, - ) - - model = MambaModel( - config=model_config, - mamba_stack_spec=mamba_stack_spec, - vocab_size=128, - max_sequence_length=4, - hybrid_attention_ratio=0.3, - hybrid_mlp_ratio=0.3, - model_comm_pgs=model_comm_pgs, - ) - - # Basic forward test - micro_batch_size = 2 - sequence_length = model.max_sequence_length - - model.cuda() - - data = list(range(sequence_length)) - input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - position_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - attention_mask = torch.ones( - (micro_batch_size, 1, sequence_length, sequence_length), dtype=bool - ).cuda() - - logits = model.forward( - input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask - ) - - assert logits.shape[0] == micro_batch_size - assert logits.shape[1] == sequence_length - assert logits.shape[2] == divide(model.vocab_size, tp_size) diff --git a/tests/unit_tests/models/test_mimo_audio_submodules.py b/tests/unit_tests/models/test_mimo_audio_submodules.py deleted file mode 100644 index 0f3865d940..0000000000 --- a/tests/unit_tests/models/test_mimo_audio_submodules.py +++ /dev/null @@ -1,396 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -''' -WORLD_SIZE=1 LOCAL_RANK=0 python -m pytest tests/unit_tests/models/test_mimo_audio_submodules.py -''' -import math -import random - -import numpy as np -import pytest -import torch -from transformers import ( - ASTConfig, - ASTFeatureExtractor, - ASTModel, - Wav2Vec2FeatureExtractor, - WavLMConfig, - WavLMModel, - WhisperConfig, - WhisperFeatureExtractor, - WhisperModel, -) - -from megatron.core.models.mimo.submodules.audio import AudioModalitySubmodules -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from tests.unit_tests.test_utilities import Utils - -pytest.importorskip("modelopt", minversion="0.25") -# modelopt version < 0.27 breaks HF AutoModel.from_pretrained API -# so we need to skip the tests unitl versions are bumped in pyt LTS CI container - -# Model-specific audio processing parameters -AUDIO_MODEL_PARAMS = { - "openai/whisper-base": { - "sample_rate": 16000, # 16kHz - "window_stride": 0.01, # 10ms - "encoder_down_sampling": 2, - "d_model": 512, - "max_length_seconds": 30.0, - }, - # WavLM models - "patrickvonplaten/wavlm-libri-clean-100h-base-plus": { - "sample_rate": 16000, # 16kHz - "window_stride": 0.02, # 20ms - # Note: WavLM uses a series of convolutional layers with different kernels and strides - # rather than a single downsampling factor. The overall effect is approximately 320x, - # but we calculate it precisely using the conv_kernel and conv_stride parameters: - # conv_kernel = [10, 3, 3, 3, 3, 2, 2] - # conv_stride = [5, 2, 2, 2, 2, 2, 2] - "encoder_down_sampling": 1, # Placeholder, not used for WavLM - "d_model": 768, - "max_length_seconds": 30.0, - }, - # AST model - "MIT/ast-finetuned-audioset-10-10-0.4593": { - "sample_rate": 16000, # 16kHz - "window_stride": 0.01, # 10ms for spectrogram creation - # AST uses fixed-size mel spectrograms and processes with patches - "max_spectrogram_length": 1024, # Maximum spectrogram length in frames - "num_mel_bins": 128, # Number of mel bins - "patch_size": 16, # Size of each patch - "time_stride": 10, # Stride for time dimension - "frequency_stride": 10, # Stride for frequency dimension - "d_model": 768, # Hidden size - "max_length_seconds": 10.0, # Reasonable maximum for testing - }, -} - - -class AudioEncoderWrapper(torch.nn.Module): - """Generic wrapper for audio encoder models that extracts last_hidden_state.""" - - def __init__(self, encoder, model_type="whisper"): - super().__init__() - self.encoder = encoder - self.model_type = model_type - - def forward(self, input_features, seq_lengths=None): - with torch.no_grad(): - hidden = self.encoder(input_features).last_hidden_state # [b, s, h] - if seq_lengths is not None: - seq_len = hidden.shape[1] - # breakpoint() - mask = torch.arange(seq_len, device=hidden.device)[None, :] < seq_lengths[:, None] - hidden = hidden[mask] - return hidden - - -def calculate_num_mel_frames(audio_length, sample_rate, window_stride, window_length=None): - """ - Calculate the number of mel frames from an audio signal. - - Parameters: - - audio_length (int): Total number of audio samples. - - sample_rate (int or float): Sampling rate of the audio (samples per second). - - window_stride (float): The time (in seconds) between successive frames. - - window_length (float, optional): Window length in seconds. If provided, this function - uses the standard formula: floor((N - window_length_in_samples) / hop_length) + 1. - Otherwise, it uses the simplified calculation based on the window stride only. - - Returns: - - int: The number of mel frames. - """ - hop_length_samples = int(window_stride * sample_rate) - - if window_length is None: - num_frames = math.ceil((audio_length + 1) / hop_length_samples) - else: - window_length_samples = int(window_length * sample_rate) - num_frames = math.floor((audio_length - window_length_samples) / hop_length_samples) + 1 - - return num_frames - - -class TestAudioSubmodule: - """Test the AudioModalitySubmodules class with forward passes.""" - - def setup_method(self, method, model_name="openai/whisper-base"): - '''setup env''' - # Initialize distributed environment - try: - Utils.initialize_model_parallel(1, 1) - except Exception as e: - print(f"Warning: Could not initialize model parallel: {e}") - - model_parallel_cuda_manual_seed(123) - random.seed(123) # For reproducible random test cases - - # Get model-specific parameters - if model_name not in AUDIO_MODEL_PARAMS: - raise ValueError( - f"Model {model_name} not supported. Available models: {list(AUDIO_MODEL_PARAMS.keys())}" - ) - - model_params = AUDIO_MODEL_PARAMS[model_name] - - # Audio processing parameters - self.sample_rate = model_params["sample_rate"] - self.window_stride = model_params.get("window_stride", 0.01) - self.sample_per_mel_frame = int(self.window_stride * self.sample_rate) - self.encoder_down_sampling = model_params.get("encoder_down_sampling", 1) - self.max_length_seconds = model_params["max_length_seconds"] - - # For AST model - self.max_spectrogram_length = model_params.get("max_spectrogram_length", None) - self.num_mel_bins = model_params.get("num_mel_bins", None) - self.patch_size = model_params.get("patch_size", None) - self.time_stride = model_params.get("time_stride", None) - self.frequency_stride = model_params.get("frequency_stride", None) - - self.audio_token_id = 50000 - - # Keep name for logs - self.model_name = model_name - - # Decide model type - if "whisper" in model_name: - self.model_type = "whisper" - config = WhisperConfig() - model = WhisperModel(config) - raw_encoder = model.encoder - self.processor = WhisperFeatureExtractor() - elif "wavlm" in model_name: - self.model_type = "wavlm" - config = WavLMConfig() - model = WavLMModel(config) - raw_encoder = model - self.processor = Wav2Vec2FeatureExtractor() - elif "ast" in model_name.lower(): - self.model_type = "ast" - config = ASTConfig( - num_mel_bins=self.num_mel_bins, - patch_size=self.patch_size, - fstride=self.frequency_stride, - tstride=self.time_stride, - ) - model = ASTModel(config) - raw_encoder = model - self.processor = ASTFeatureExtractor() - else: - raise ValueError(f"Unsupported model type: {model_name}") - - self.encoder = AudioEncoderWrapper(raw_encoder, self.model_type) - if hasattr(model.config, "d_model"): - self.d_model = model.config.d_model - else: - self.d_model = model_params["d_model"] - self.projection = torch.nn.Linear(self.d_model, 768) - self.audio_module = AudioModalitySubmodules( - encoders={"encoder": self.encoder}, input_projections=[self.projection] - ) - - def teardown_method(self, method): - '''teardown env''' - try: - Utils.destroy_model_parallel() - except Exception as e: - print(f"Warning: Could not destroy model parallel: {e}") - - def _create_sample_audio(self, duration_seconds, sample_rate=None): - """Create a sample audio waveform. - - Args: - duration_seconds (float): Duration of audio in seconds - sample_rate (int, optional): Sample rate in Hz. Defaults to self.sample_rate. - - Returns: - torch.Tensor: Audio waveform of shape [1, samples] - """ - sample_rate = sample_rate or self.sample_rate - - # Create a time array - t = np.linspace(0, duration_seconds, int(duration_seconds * sample_rate), endpoint=False) - - # Create a simple sine wave at 440 Hz (A4) - frequency = 440.0 - waveform = 0.5 * np.sin(2 * np.pi * frequency * t) - - # Convert to torch tensor - return torch.tensor(waveform, dtype=torch.float32).unsqueeze(0) - - def _calculate_seq_length(self, audio_tensor): - - # Get audio length in samples - audio_length = audio_tensor.shape[1] - - if self.model_type in ["whisper"]: - - num_mel_frames = calculate_num_mel_frames( - audio_length, self.sample_rate, self.window_stride - ) - encoder_seq_length = math.ceil(num_mel_frames / self.encoder_down_sampling) - - elif self.model_type == "wavlm": - # For WavLM, use the exact convolutional calculation logic - # WavLM uses a series of convolutional layers with different kernels and strides - conv_kernel = [10, 3, 3, 3, 3, 2, 2] - conv_stride = [5, 2, 2, 2, 2, 2, 2] - - # Function to calculate output length of 1D convolution - def _conv_out_length(input_length, kernel_size, stride): - # 1D convolutional layer output length formula taken - # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html - return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1 - - # Start with the original input length - input_length = audio_length - - # Apply each convolutional layer - for kernel_size, stride in zip(conv_kernel, conv_stride): - input_length = _conv_out_length(input_length, kernel_size, stride) - - # The result is the encoder sequence length - encoder_seq_length = input_length - - elif self.model_type == "ast": - # AST uses a fixed-size spectrogram and divides it into patches - # The exact formula is based on how CNN output dimensions are calculated - # See: https://cs231n.github.io/convolutional-networks/#conv - frequency_out_dimension = ( - self.num_mel_bins - self.patch_size - ) // self.frequency_stride + 1 - time_out_dimension = ( - self.max_spectrogram_length - self.patch_size - ) // self.time_stride + 1 - - # Number of patches is the product of these dimensions - num_patches = frequency_out_dimension * time_out_dimension - - # Add 2 for the cls_token and distillation_token - encoder_seq_length = num_patches + 2 - - print( - f"AST patches: freq_dim={frequency_out_dimension}, time_dim={time_out_dimension}, " - f"patches={num_patches}, total={encoder_seq_length}" - ) - - else: - raise ValueError(f"Unsupported model type: {self.model_type}") - - return max(1, int(encoder_seq_length)) - - def _create_batch(self, batch_size=3, min_duration=1.0, max_duration=1.5): - """ - Create a simple batch with mixed text and audio content. - """ - # Use default parameters if not provided - sample_rate = self.sample_rate - - # Randomly choose 1-4 audio segments per sample - num_segments_per_sample = [random.randint(1, 4) for _ in range(batch_size)] - total_segments = sum(num_segments_per_sample) - audio_samples = [ - self._create_sample_audio(random.uniform(min_duration, max_duration), sample_rate) - for _ in range(total_segments) - ] - - processor_kwargs = {"sampling_rate": sample_rate, "return_tensors": "pt"} - # processor for whisper (30 sec) and ast pads (1024 framesto max length - # for wavlm lets pad to longest in the batch - if self.model_type in ["wavlm"]: - processor_kwargs["padding"] = "longest" - processed = self.processor( - [sample.squeeze().numpy() for sample in audio_samples], **processor_kwargs - ) - - if self.model_type == "whisper": - processed_features = processed.input_features - elif self.model_type in ["ast", "wavlm"]: - processed_features = processed.input_values - else: - raise ValueError(f"Unsupported model type: {self.model_type}") - - # Calculate sequence lengths for audio tokens - seq_lengths = torch.tensor( - [self._calculate_seq_length(sample) for sample in audio_samples], dtype=torch.long - ) - - max_seq_len = 4096 # Arbitrary length that's enough for test - input_ids = torch.zeros((batch_size, max_seq_len), dtype=torch.long) - - # Keep track of which audio segment we're using - segment_idx = 0 - - # Fill input_ids with text and audio tokens - for i in range(batch_size): - pos = 0 - num_segments = num_segments_per_sample[i] - - for _ in range(num_segments): - # Random text - text_len = random.randint(3, 8) - input_ids[i, pos : pos + text_len] = torch.randint(1, 30000, (text_len,)) - pos += text_len - - # Audio segment - audio_len = seq_lengths[segment_idx].item() - input_ids[i, pos : pos + audio_len] = self.audio_token_id - pos += audio_len - segment_idx += 1 - - # Final text - text_len = random.randint(3, 8) - if pos + text_len < max_seq_len: - input_ids[i, pos : pos + text_len] = torch.randint(1, 30000, (text_len,)) - pos += text_len - - # Padding - if pos < max_seq_len: - input_ids[i, pos:] = 0 - - return { - 'audio': processed_features, - 'input_ids': input_ids, - 'modality_seq_lengths': {'audio': seq_lengths}, - } - - @pytest.mark.parametrize( - "model_name,batch_size", - [ - # Test with batch_size=1 - pytest.param("openai/whisper-base", 1, id="whisper-base-batch1"), - pytest.param("patrickvonplaten/wavlm-libri-clean-100h-base-plus", 1, id="wavlm-batch1"), - pytest.param("MIT/ast-finetuned-audioset-10-10-0.4593", 1, id="ast-batch1"), - # Test with batch_size=2 - pytest.param("openai/whisper-base", 2, id="whisper-base-batch2"), - pytest.param("patrickvonplaten/wavlm-libri-clean-100h-base-plus", 2, id="wavlm-batch2"), - pytest.param("MIT/ast-finetuned-audioset-10-10-0.4593", 2, id="ast-batch2"), - ], - ) - def test_multiple_audio_encoders(self, model_name, batch_size): - '''Test the forward pass with different audio encoder models and batch sizes''' - self.setup_method(None, model_name=model_name) - - batch = self._create_batch(batch_size=batch_size, min_duration=1.0, max_duration=3.0) - - feature_key = "input_features" - - # Create encoder inputs dictionary with named encoder - seq_lengths = batch['modality_seq_lengths']['audio'] - encoder_inputs = {"encoder": {feature_key: batch['audio'], "seq_lengths": seq_lengths}} - - # Call forward with new interface - embeddings = self.audio_module.forward(encoder_inputs) - - num_audio_tokens = (batch['input_ids'] == self.audio_token_id).sum().item() - - # Verify number of embeddings matches number of audio tokens - assert embeddings.shape[0] == num_audio_tokens - - # Verify embeddings have expected dimension (768 is our target dimension) - assert embeddings.shape[1] == 768 - - print( - f"Model {model_name} (d_model={self.d_model}) successfully processed audio and projected to dimension 768" - ) diff --git a/tests/unit_tests/models/test_mimo_embedding_alignment.py b/tests/unit_tests/models/test_mimo_embedding_alignment.py deleted file mode 100644 index 7459625677..0000000000 --- a/tests/unit_tests/models/test_mimo_embedding_alignment.py +++ /dev/null @@ -1,447 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -''' -WORLD_SIZE=1 LOCAL_RANK=0 python -m pytest tests/unit_tests/models/test_mimo_embedding_alignment.py -''' - -from unittest.mock import MagicMock - -import pytest -import torch - -from megatron.core.models.mimo.config import MimoModelConfig -from megatron.core.models.mimo.model.base import MimoModel -from megatron.core.transformer.spec_utils import ModuleSpec - - -class TestEmbeddingAlignment: - """Test the align_embeddings_by_token_positions method in MimoModel.""" - - def setup_method(self): - """Set up for each test.""" - # Create a minimal MimoModelConfig - language_model_spec = ModuleSpec(module=MagicMock, params={'config': MagicMock()}) - self.mimo_config = MimoModelConfig( - language_model_spec=language_model_spec, - modality_submodules_spec={}, - special_token_ids={}, - ) - - # Create MimoModel instance - self.model = MimoModel(self.mimo_config) - - self.hidden_dim = 64 - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - def create_marker_embeddings(self, num_embeddings, marker_positions=None, marker_values=None): - """Create embeddings with marker values at specific positions. - - Args: - num_embeddings: Number of embeddings to create - marker_positions: List of positions to place markers. If None, uses range(num_embeddings) - marker_values: List of values to use for markers. If None, uses [10.0, 20.0, ...] - - Returns: - Tensor of shape [num_embeddings, hidden_dim] with markers at specified positions - """ - if marker_positions is None: - marker_positions = list(range(num_embeddings)) - - embeddings = torch.zeros((num_embeddings, self.hidden_dim), device=self.device) - - # Set distinctive markers - for i, pos in enumerate(marker_positions): - # Use provided value or default pattern - if marker_values is not None and i < len(marker_values): - marker_value = marker_values[i] - else: - marker_value = float(i + 1) * 10.0 # Values like 10.0, 20.0, 30.0, etc. - - embeddings[i, pos % self.hidden_dim] = marker_value - - return embeddings - - def test_basic_alignment(self): - """Test basic alignment with text and one modality.""" - # Create a simple batch - batch_size = 2 - seq_length = 8 - hidden_dim = self.hidden_dim - - # Create input_ids with special tokens - # Sequence 1: [text, image_token, text, text, text, text, text, text] - # Sequence 2: [text, text, text, image_token, text, text, text, text] - input_ids = torch.full((batch_size, seq_length), 100, dtype=torch.long, device=self.device) - - # Add image special tokens at different positions for each sequence - image_token_id = 50 - input_ids[0, 1] = image_token_id # Batch 0, position 1 - input_ids[1, 3] = image_token_id # Batch 1, position 3 - - # Create text embeddings (14 tokens total - 7 text tokens per sequence) - # Instead of zeros, use a small distinct value for text embeddings - text_embeddings = torch.full((14, hidden_dim), 0.01, device=self.device) - - # Create vision embeddings with distinctive markers - # For batch 0: marker at position 0 with value 10.0 - # For batch 1: marker at position 1 with value 20.0 - vision_embeddings = self.create_marker_embeddings(2, marker_positions=[0, 1]) - - # Define special token IDs - special_token_ids = {"vision": image_token_id} - - # Align embeddings - modality_embeddings = {"text": text_embeddings, "vision": vision_embeddings} - - combined = self.model.align_embeddings_by_token_positions( - modality_embeddings=modality_embeddings, - input_ids=input_ids, - special_token_ids=special_token_ids, - ) - - # Check output shape - assert combined.shape == (seq_length, batch_size, hidden_dim) - - # Check special token positions have the correct embeddings - # First vision token (Batch 0, Seq 1) should have the first vision embedding - assert combined[1, 0, 0] == 10.0 # First marker - assert torch.all(combined[1, 0, 1:] == 0.0), "Non-zero values found after marker" - - # Second vision token (Batch 1, Seq 3) should have the second vision embedding - assert combined[3, 1, 1] == 20.0 # Second marker - assert torch.all(combined[3, 1, :1] == 0.0), "Non-zero values found before marker" - assert torch.all(combined[3, 1, 2:] == 0.0), "Non-zero values found after marker" - - # Verify text positions have only zeros - text_positions = [ - (0, 0), - (2, 0), - (3, 0), - (4, 0), - (5, 0), - (6, 0), - (7, 0), # Batch 0 - (0, 1), - (1, 1), - (2, 1), - (4, 1), - (5, 1), - (6, 1), - (7, 1), # Batch 1 - ] - - for s, b in text_positions: - assert torch.all(combined[s, b] == 0.01) - - def test_multiple_modalities(self): - """Test alignment with multiple modalities with special tokens at different positions.""" - batch_size = 2 - seq_length = 10 - hidden_dim = self.hidden_dim - - # Create input_ids with special tokens for multiple modalities - # Sequence 1: [text, vision, text, text, audio, text, text, text, video, text] - # Sequence 2: [text, text, vision, text, text, audio, text, video, text, text] - input_ids = torch.full((batch_size, seq_length), 100, dtype=torch.long, device=self.device) - - # Define special token IDs - vision_token_id = 50 - audio_token_id = 51 - video_token_id = 52 - - # Add special tokens at different positions in each sequence - # First sequence - input_ids[0, 1] = vision_token_id # Vision at pos 1 in seq 0 - input_ids[0, 4] = audio_token_id # Audio at pos 4 in seq 0 - input_ids[0, 8] = video_token_id # Video at pos 8 in seq 0 - - # Second sequence - input_ids[1, 2] = vision_token_id # Vision at pos 2 in seq 1 - input_ids[1, 5] = audio_token_id # Audio at pos 5 in seq 1 - input_ids[1, 7] = video_token_id # Video at pos 7 in seq 1 - - # Calculate text tokens: 7 tokens in each sequence - # Create non-zero text embeddings for better verification - text_embeddings = torch.full((14, hidden_dim), 0.01, device=self.device) - - # Create marker embeddings for each modality with specific positions and values - # For vision: both embeddings have markers at position 0 - vision_embeddings = self.create_marker_embeddings( - num_embeddings=2, - marker_positions=[0, 0], # Both markers at position 0 - marker_values=[10.0, 20.0], # Batch 0 and Batch 1 markers - ) - - # For audio: both embeddings have markers at position 1 - audio_embeddings = self.create_marker_embeddings( - num_embeddings=2, - marker_positions=[1, 1], # Both markers at position 1 - marker_values=[30.0, 40.0], # Batch 0 and Batch 1 markers - ) - - # For video: both embeddings have markers at position 2 - video_embeddings = self.create_marker_embeddings( - num_embeddings=2, - marker_positions=[2, 2], # Both markers at position 2 - marker_values=[50.0, 60.0], # Batch 0 and Batch 1 markers - ) - - # Define special token mapping - special_token_ids = { - "vision": vision_token_id, - "audio": audio_token_id, - "video": video_token_id, - } - - # Align embeddings - modality_embeddings = { - "text": text_embeddings, - "vision": vision_embeddings, - "audio": audio_embeddings, - "video": video_embeddings, - } - - combined = self.model.align_embeddings_by_token_positions( - modality_embeddings=modality_embeddings, - input_ids=input_ids, - special_token_ids=special_token_ids, - ) - - # Check output shape - assert combined.shape == (seq_length, batch_size, hidden_dim) - - # Check that special token positions have the correct markers and only at correct positions - - # Batch 0 markers - assert torch.isclose(combined[1, 0, 0], torch.tensor(10.0, device=self.device)) # Vision - assert torch.isclose(combined[4, 0, 1], torch.tensor(30.0, device=self.device)) # Audio - assert torch.isclose(combined[8, 0, 2], torch.tensor(50.0, device=self.device)) # Video - - # Batch 1 markers - assert torch.isclose(combined[2, 1, 0], torch.tensor(20.0, device=self.device)) # Vision - assert torch.isclose(combined[5, 1, 1], torch.tensor(40.0, device=self.device)) # Audio - assert torch.isclose(combined[7, 1, 2], torch.tensor(60.0, device=self.device)) # Video - - # Also check that markers are ONLY at their specific positions - # For vision in batch 0 (position 1, value at index 0) - assert torch.all(combined[1, 0, 1:] == 0.0), "Non-zero values found after marker" - - # For audio in batch 1 (position 5, value at index 1) - assert torch.all(combined[5, 1, :1] == 0.0), "Non-zero values found before marker" - assert torch.all(combined[5, 1, 2:] == 0.0), "Non-zero values found after marker" - - def test_multiple_images_with_variable_length(self): - """Test handling multiple images per sample with variable sequence lengths. - - This test verifies that: - 1. Multiple image occurrences per batch sample are handled correctly - 2. Images with different sequence lengths are processed properly - 3. The batch-first ordering is preserved - 4. Embeddings are correctly placed at their corresponding positions - """ - # Create a test case with 2 batches: - # - Batch 0: 2 images with different sequence lengths (3 and 2 patches) - # - Batch 1: 1 image with 4 patches - batch_size = 2 - seq_length = 10 - hidden_dim = self.hidden_dim - - # Create input_ids with vision special tokens - input_ids = torch.full((batch_size, seq_length), 100, dtype=torch.long, device=self.device) - - # Define vision token ID - vision_token_id = 50 - - # Place special tokens: - # Batch 0: positions 1, 2, 3 (first image, 3 patches) and 5, 6 (second image, 2 patches) - # Batch 1: positions 2, 3, 4, 5 (one image, 4 patches) - # Batch 0 - first image (3 patches) - input_ids[0, 1] = vision_token_id - input_ids[0, 2] = vision_token_id - input_ids[0, 3] = vision_token_id - - # Batch 0 - second image (2 patches) - input_ids[0, 5] = vision_token_id - input_ids[0, 6] = vision_token_id - - # Batch 1 - one image (4 patches) - input_ids[1, 2] = vision_token_id - input_ids[1, 3] = vision_token_id - input_ids[1, 4] = vision_token_id - input_ids[1, 5] = vision_token_id - - # Count text tokens (all non-vision tokens) - # Batch 0: 5 text tokens (positions 0, 4, 7, 8, 9) - # Batch 1: 6 text tokens (positions 0, 1, 6, 7, 8, 9) - text_embeddings = torch.full((11, hidden_dim), 0.01, device=self.device) - - # Create the unflattened embeddings that would come from a vision encoder - # First, create 3 tensors with different sequence lengths: - - # Batch 0, Image 1: 3 patches - image_0_1 = self.create_marker_embeddings( - num_embeddings=3, - marker_positions=[0, 1, 2], - marker_values=[101.0, 102.0, 103.0], # Distinct values for each patch - ) - - # Batch 0, Image 2: 2 patches - image_0_2 = self.create_marker_embeddings( - num_embeddings=2, marker_positions=[3, 4], marker_values=[104.0, 105.0] - ) - - # Batch 1, Image 1: 4 patches - image_1_1 = self.create_marker_embeddings( - num_embeddings=4, - marker_positions=[5, 6, 7, 8], - marker_values=[201.0, 202.0, 203.0, 204.0], - ) - - # Flatten the images as the vision submodule would do - # They should be concatenated in batch order - vision_embeddings = torch.cat([image_0_1, image_0_2, image_1_1], dim=0) - - # Define special token IDs - special_token_ids = {"vision": vision_token_id} - - # Create modality embeddings - modality_embeddings = {"text": text_embeddings, "vision": vision_embeddings} - - # Align embeddings - combined = self.model.align_embeddings_by_token_positions( - modality_embeddings=modality_embeddings, - input_ids=input_ids, - special_token_ids=special_token_ids, - ) - - # Check output shape - assert combined.shape == (seq_length, batch_size, hidden_dim) - - # Verify vision token embeddings are placed correctly - - # Batch 0, first image embeddings (3 patches) - assert torch.isclose(combined[1, 0, 0], torch.tensor(101.0, device=self.device)) - assert torch.isclose(combined[2, 0, 1], torch.tensor(102.0, device=self.device)) - assert torch.isclose(combined[3, 0, 2], torch.tensor(103.0, device=self.device)) - - # Batch 0, second image embeddings (2 patches) - assert torch.isclose(combined[5, 0, 3], torch.tensor(104.0, device=self.device)) - assert torch.isclose(combined[6, 0, 4], torch.tensor(105.0, device=self.device)) - - # Batch 1, image embeddings (4 patches) - assert torch.isclose(combined[2, 1, 5], torch.tensor(201.0, device=self.device)) - assert torch.isclose(combined[3, 1, 6], torch.tensor(202.0, device=self.device)) - assert torch.isclose(combined[4, 1, 7], torch.tensor(203.0, device=self.device)) - assert torch.isclose(combined[5, 1, 8], torch.tensor(204.0, device=self.device)) - - # Verify that each embedding only has one non-zero value - for b in range(batch_size): - # Check positions with special tokens - positions = [(1, 2, 3, 5, 6), (2, 3, 4, 5)][b] - for s in positions: - emb = combined[s, b].clone() - # Find the non-zero position - nonzero_indices = torch.nonzero(emb) - # Make sure we actually have non-zero values - assert ( - nonzero_indices.nelement() > 0 - ), f"No non-zero values found at position {s},{b}" - nonzero_pos = nonzero_indices[0].item() - # Check that all other positions are zero - assert torch.all( - emb[:nonzero_pos] == 0.0 - ), f"Non-zero values found before marker at {s},{b}" - assert torch.all( - emb[nonzero_pos + 1 :] == 0.0 - ), f"Non-zero values found after marker at {s},{b}" - - def test_validation_errors(self): - """Test validation errors when token counts don't match embedding counts.""" - batch_size = 2 - seq_length = 5 - hidden_dim = self.hidden_dim - - # Create input_ids with different numbers of tokens - input_ids = torch.full((batch_size, seq_length), 100, dtype=torch.long, device=self.device) - - # Add 3 special tokens for vision - vision_token_id = 50 - input_ids[0, 1] = vision_token_id - input_ids[0, 3] = vision_token_id - input_ids[1, 2] = vision_token_id - - # Create text embeddings (non-zero for better verification) - # We have 3 vision tokens, so we need: - # (batch_size * seq_length) - num_vision_tokens = 2*5 - 3 = 7 text embeddings - text_embeddings = torch.full((7, hidden_dim), 0.01, device=self.device) - - # Create vision embeddings with only 2 embeddings (not enough for 3 tokens) - vision_embeddings = self.create_marker_embeddings(2) - - special_token_ids = {"vision": vision_token_id} - - modality_embeddings = {"text": text_embeddings, "vision": vision_embeddings} - - # Should raise a ValueError because we have 3 special tokens but only 2 embeddings - with pytest.raises(ValueError, match="Number of vision tokens.*does not match"): - self.model.align_embeddings_by_token_positions( - modality_embeddings=modality_embeddings, - input_ids=input_ids, - special_token_ids=special_token_ids, - ) - - # Test with wrong number of text tokens - input_ids = torch.full((batch_size, seq_length), 100, dtype=torch.long, device=self.device) - - # Add 1 special token in each batch - input_ids[0, 1] = vision_token_id - input_ids[1, 2] = vision_token_id - - # This would leave 8 text tokens (4 per batch), but we'll provide only 6 - text_embeddings = torch.full((6, hidden_dim), 0.01, device=self.device) - - # Create matching vision embeddings (correct count this time) - vision_embeddings = self.create_marker_embeddings(2) - - modality_embeddings = {"text": text_embeddings, "vision": vision_embeddings} - - # Should raise a ValueError for mismatched text token count - with pytest.raises(ValueError, match="Number of text tokens.*does not match"): - self.model.align_embeddings_by_token_positions( - modality_embeddings=modality_embeddings, - input_ids=input_ids, - special_token_ids=special_token_ids, - ) - - def test_missing_special_token_id(self): - """Test error when a modality is missing from special_token_ids.""" - batch_size = 2 - seq_length = 5 - hidden_dim = self.hidden_dim - - # Create input_ids - input_ids = torch.full((batch_size, seq_length), 100, dtype=torch.long, device=self.device) - - # Define text embeddings with non-zero value - text_embeddings = torch.full( - (batch_size * seq_length, hidden_dim), 0.01, device=self.device - ) - - # Create vision embeddings (not referenced in special_token_ids) - vision_embeddings = self.create_marker_embeddings(1) - - # Empty special_token_ids - special_token_ids = {} - - modality_embeddings = { - "text": text_embeddings, - "vision": vision_embeddings, # Not in special_token_ids - } - - # Should raise a ValueError because vision modality is missing from special_token_ids - with pytest.raises(ValueError, match="No special token ID defined for modality vision"): - self.model.align_embeddings_by_token_positions( - modality_embeddings=modality_embeddings, - input_ids=input_ids, - special_token_ids=special_token_ids, - ) diff --git a/tests/unit_tests/models/test_mimo_model.py b/tests/unit_tests/models/test_mimo_model.py deleted file mode 100644 index f786f118c6..0000000000 --- a/tests/unit_tests/models/test_mimo_model.py +++ /dev/null @@ -1,457 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -''' -WORLD_SIZE=1 LOCAL_RANK=0 python -m pytest tests/unit_tests/models/test_mimo_model.py -''' - -import math - -import pytest -import torch -import torch.nn as nn -from transformers import WhisperConfig, WhisperModel - -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.models.mimo.config.base_configs import MimoModelConfig -from megatron.core.models.mimo.model.base import MimoModel -from megatron.core.models.mimo.submodules.audio import AudioModalitySubmodules -from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules -from megatron.core.models.vision.clip_vit_model import CLIPViTModel -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_config import TransformerConfig -from tests.unit_tests.test_utilities import Utils - -pytest.importorskip("modelopt", minversion="0.25") -# modelopt version < 0.27 breaks HF AutoModel.from_pretrained API -# so we need to skip the tests unitl versions are bumped in pyt LTS CI container - - -class AudioEncoderWrapper(torch.nn.Module): - """Generic wrapper for audio encoder models that extracts last_hidden_state.""" - - def __init__(self, config): - super().__init__() - # Use a local Whisper model (tiny config) to avoid checkpoint download - self.encoder = WhisperModel(WhisperConfig()).encoder - - def forward(self, input_features): - # Process through encoder and extract last_hidden_state - with torch.no_grad(): - return self.encoder(input_features).last_hidden_state - - -def get_vision_submodules_spec(hidden_size, img_h, img_w, patch_dim): - """Get the submodule spec for the vision modality.""" - vision_layer_spec = get_gpt_layer_with_transformer_engine_spec() - - vision_config = TransformerConfig( - num_layers=1, hidden_size=hidden_size, num_attention_heads=4, use_cpu_initialization=True - ) - vision_encoder_spec = ModuleSpec( - module=CLIPViTModel, - params={ - "transformer_config": vision_config, - "transformer_layer_spec": vision_layer_spec, - "img_h": img_h, - "img_w": img_w, - "patch_dim": patch_dim, - }, - ) - - # Create vision projection spec - vision_projection_spec = ModuleSpec( - module=nn.Linear, - params={ - "in_features": vision_config.hidden_size, - "out_features": vision_config.hidden_size, - }, - ) - - # Create vision modality spec - vision_submodule_spec = ModuleSpec( - module=VisionModalitySubmodules, - submodules={ - "encoders": {"clip_encoder": vision_encoder_spec}, - "input_projections": [vision_projection_spec], - }, - ) - - return vision_submodule_spec - - -def get_audio_submodules_spec(hidden_size): - """Get the submodule spec for the audio modality.""" - - class AudioEncoderWrapper(torch.nn.Module): - """Generic wrapper for audio encoder models that extracts last_hidden_state.""" - - def __init__(self, model_name="openai/whisper-tiny"): - super().__init__() - # Local tiny Whisper model with random weights - self.encoder = WhisperModel(WhisperConfig()).encoder - - def forward(self, input_features): - # Process through encoder and extract last_hidden_state - with torch.no_grad(): - return self.encoder(input_features).last_hidden_state - - # Audio modality configuration - audio_encoder_spec = ModuleSpec( - module=AudioEncoderWrapper, params={"model_name": "openai/whisper-tiny"} - ) - - audio_projection_spec = ModuleSpec( - module=nn.Linear, - params={"in_features": 384, "out_features": hidden_size}, # Whisper tiny hidden size - ) - - audio_submodule_spec = ModuleSpec( - module=AudioModalitySubmodules, - submodules={ - "encoders": {"whisper_encoder": audio_encoder_spec}, - "input_projections": [audio_projection_spec], - }, - ) - - return audio_submodule_spec - - -def get_language_model_spec(hidden_size, vocab_size, seq_len): - """Get the language model spec.""" - lm_config = TransformerConfig( - num_layers=2, hidden_size=hidden_size, num_attention_heads=4, use_cpu_initialization=True - ) - language_layer_spec = get_gpt_layer_with_transformer_engine_spec() - language_model_spec = ModuleSpec( - module=GPTModel, - params={ - "config": lm_config, - "transformer_layer_spec": language_layer_spec, - "vocab_size": vocab_size, - "max_sequence_length": seq_len, - "pre_process": True, - "post_process": True, - }, - ) - return language_model_spec - - -def get_avlm_mimo_model( - hidden_size, vocab_size, seq_len, img_h, img_w, patch_dim, special_token_ids -): - language_model_spec = get_language_model_spec(hidden_size, vocab_size, seq_len) - vision_submodule_spec = get_vision_submodules_spec(hidden_size, img_h, img_w, patch_dim) - audio_submodule_spec = get_audio_submodules_spec(hidden_size) - - mimo_config = MimoModelConfig( - language_model_spec=language_model_spec, - modality_submodules_spec={"images": vision_submodule_spec, "audio": audio_submodule_spec}, - special_token_ids=special_token_ids, - ) - - # Create MIMO model - mimo_model = MimoModel(mimo_config) - return mimo_model - - -def get_vlm_mimo_model( - hidden_size, vocab_size, seq_len, img_h, img_w, patch_dim, special_token_ids -): - language_model_spec = get_language_model_spec(hidden_size, vocab_size, seq_len) - vision_submodule_spec = get_vision_submodules_spec(hidden_size, img_h, img_w, patch_dim) - - mimo_config = MimoModelConfig( - language_model_spec=language_model_spec, - modality_submodules_spec={"images": vision_submodule_spec}, - special_token_ids=special_token_ids, - ) - - # Create MIMO model - mimo_model = MimoModel(mimo_config) - return mimo_model - - -class TestMimoModel: - """Test the MimoModel class.""" - - def setup_method(self, method): - '''setup env and model''' - try: - Utils.initialize_model_parallel(1, 1) - except Exception as e: - print(f"Warning: Could not initialize model parallel: {e}") - - # Set dimensions - self.hidden_size = 64 - self.batch_size = 2 - self.seq_len = 2048 - self.img_h = 224 - self.img_w = 224 - self.patch_dim = 16 - self.vocab_size = 48000 - - # Define special token IDs, not in LLM vocab - self.special_token_ids = {"images": 50257, "audio": 50258} - - def teardown_method(self, method): - '''teardown env''' - try: - Utils.destroy_model_parallel() - except Exception as e: - print(f"Warning: Could not destroy model parallel: {e}") - - def test_constructor(self): - """Test constructor initialization.""" - - mimo_model = get_avlm_mimo_model( - self.hidden_size, - self.vocab_size, - self.seq_len, - self.img_h, - self.img_w, - self.patch_dim, - self.special_token_ids, - ) - - # Move to device - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - mimo_model = mimo_model.to(device) - - # Test that modality submodules were initialized correctly - assert "images" in mimo_model.modality_submodules - assert "audio" in mimo_model.modality_submodules - assert isinstance(mimo_model.modality_submodules["images"], VisionModalitySubmodules) - assert isinstance(mimo_model.modality_submodules["audio"], AudioModalitySubmodules) - # Test that language model was initialized - assert hasattr(mimo_model, "language_model") - assert isinstance(mimo_model.language_model, GPTModel) - - # Test that special token IDs were set correctly - assert mimo_model.special_token_ids == self.special_token_ids - - def test_get_text_embeddings(self): - """Test getting text embeddings.""" - # Create random input and position IDs (within vocab size range) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - input_ids = torch.randint( - 0, self.vocab_size, (self.batch_size, self.seq_len), device=device - ) - position_ids = ( - torch.arange(self.seq_len, device=device).unsqueeze(0).expand(self.batch_size, -1) - ) - mimo_model = get_avlm_mimo_model( - self.hidden_size, - self.vocab_size, - self.seq_len, - self.img_h, - self.img_w, - self.patch_dim, - self.special_token_ids, - ) - mimo_model = mimo_model.to(device) - # Get text embeddings - text_embeddings = mimo_model.get_text_embeddings( - input_ids, position_ids, self.special_token_ids - ) - # Verify shape - # [b*s, h] - assert text_embeddings.shape == (self.batch_size * self.seq_len, self.hidden_size) - - def test_forward_text_only(self): - """Test forward pass with only text input.""" - # Create inputs - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - input_ids = torch.randint( - 0, self.vocab_size, (self.batch_size, self.seq_len), device=device - ) - position_ids = ( - torch.arange(self.seq_len, device=device).unsqueeze(0).expand(self.batch_size, -1) - ) - - mimo_model = get_vlm_mimo_model( - self.hidden_size, - self.vocab_size, - self.seq_len, - self.img_h, - self.img_w, - self.patch_dim, - self.special_token_ids, - ) - mimo_model = mimo_model.to(device) - # Run forward pass with explicit parameters - outputs, _ = mimo_model( - input_ids=input_ids, position_ids=position_ids, modality_inputs=None - ) - assert outputs is not None - - # Verify output shape - assert outputs.shape == (self.batch_size, self.seq_len, self.vocab_size) - - def test_forward_with_image_modality(self): - """Test forward pass with text and image input.""" - # Calculate expected number of image tokens based on image size and patch dimension - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - expected_img_seq_len = (self.img_h // self.patch_dim) * ( - self.img_w // self.patch_dim - ) + 1 # +1 for CLS token - - # Create a fixed distribution of images: 3 in first sample, 2 in second sample - num_images = 5 - images_per_sample = [3, 2] # Must sum to num_images - assert sum(images_per_sample) == num_images - assert len(images_per_sample) == self.batch_size - - # Create images tensor - images = torch.rand( - num_images, 3, self.img_h, self.img_w, device=device - ) # [num_images, 3, h, w] format - - # Create input_ids with text tokens - input_ids = torch.randint( - 0, self.vocab_size, (self.batch_size, self.seq_len), device=device - ) - - # Create position_ids - position_ids = ( - torch.arange(self.seq_len, device=device).unsqueeze(0).expand(self.batch_size, -1) - ) - - # Include image special tokens in input IDs - image_token_id = self.special_token_ids["images"] - start_pos = 5 # Start position for image tokens - - # Make sure there's enough space in the sequence for all image tokens in each sample - for b in range(self.batch_size): - tokens_needed = images_per_sample[b] * expected_img_seq_len - assert ( - start_pos + tokens_needed <= self.seq_len - ), f"Sequence length too short for image tokens in sample {b}" - - # Add image tokens to each batch sample according to its number of images - for b in range(self.batch_size): - tokens_in_this_batch = images_per_sample[b] * expected_img_seq_len - if tokens_in_this_batch > 0: - input_ids[b, start_pos : start_pos + tokens_in_this_batch] = image_token_id - - # Create modality inputs using the new structure - modality_inputs = {"images": {"clip_encoder": {"x": images}}} - - mimo_model = get_vlm_mimo_model( - self.hidden_size, - self.vocab_size, - self.seq_len, - self.img_h, - self.img_w, - self.patch_dim, - self.special_token_ids, - ) - mimo_model = mimo_model.to(device) - - # Run forward pass with new interface - outputs, _ = mimo_model( - input_ids=input_ids, position_ids=position_ids, modality_inputs=modality_inputs - ) - assert outputs is not None - - # Verify output shape - assert outputs.shape == (self.batch_size, self.seq_len, self.vocab_size) - - def test_forward_with_image_and_audio_modality(self): - """Test forward pass with text, image, and audio input.""" - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - mimo_model = get_avlm_mimo_model( - self.hidden_size, - self.vocab_size, - self.seq_len, - self.img_h, - self.img_w, - self.patch_dim, - self.special_token_ids, - ) - mimo_model = mimo_model.to(device) - - # Calculate image sequence length - img_seq_len = (self.img_h // self.patch_dim) * (self.img_w // self.patch_dim) + 1 - - encoder_down_sampling = 2 - - # Create simple audio input (30 sec) - mel_bins = 80 # Whisper uses 80 mel bins - time_bins = 3000 # 30 seconds of audio at 10ms per frame - audio_features = torch.rand(2, mel_bins, time_bins, device=device) - - # Calculate audio sequence length using Whisper's formula - audio_seq_len = math.ceil(time_bins / encoder_down_sampling) # 1500 tokens - - # Create batch data - batch_size = 2 - seq_len = self.seq_len - - # Create input_ids with special tokens - input_ids = torch.randint(0, self.vocab_size, (batch_size, seq_len), device=device) - position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1) - - # Add special tokens at specific positions - start_pos = 5 - image_token_id = self.special_token_ids["images"] - audio_token_id = self.special_token_ids["audio"] - - # Place image tokens followed by audio tokens in each batch item - for i in range(batch_size): - # Add image tokens - input_ids[i, start_pos : start_pos + img_seq_len] = image_token_id - # Add audio tokens after a gap - input_ids[ - i, start_pos + img_seq_len + 10 : start_pos + img_seq_len + 10 + audio_seq_len - ] = audio_token_id - - # Prepare modality inputs - modality_inputs = { - "images": { - "clip_encoder": {"x": torch.rand(2, 3, self.img_h, self.img_w, device=device)} - }, - "audio": {"whisper_encoder": {"input_features": audio_features}}, - } - - # Run forward pass - outputs, _ = mimo_model( - input_ids=input_ids, position_ids=position_ids, modality_inputs=modality_inputs - ) - - # Verify output shape - assert outputs is not None - assert outputs.shape == (batch_size, seq_len, self.vocab_size) - - def test_state_dict(self): - """Test state dict methods.""" - # Get state dict - mimo_model = get_avlm_mimo_model( - self.hidden_size, - self.vocab_size, - self.seq_len, - self.img_h, - self.img_w, - self.patch_dim, - self.special_token_ids, - ) - state_dict = mimo_model.state_dict() - assert len(state_dict) > 0 - - # Make sure we have keys for language model and modality submodules - has_lm_keys = False - has_modality_keys = False - - for key in state_dict.keys(): - if key.startswith("language_model."): - has_lm_keys = True - if key.startswith("modality_submodules."): - has_modality_keys = True - - assert has_lm_keys - assert has_modality_keys - - # Test checkpoint state dict - checkpoint_dict = mimo_model.state_dict_for_save_checkpoint() - assert len(checkpoint_dict) > 0 diff --git a/tests/unit_tests/models/test_mimo_submodules.py b/tests/unit_tests/models/test_mimo_submodules.py deleted file mode 100644 index 6111394cc1..0000000000 --- a/tests/unit_tests/models/test_mimo_submodules.py +++ /dev/null @@ -1,305 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -''' -WORLD_SIZE=1 LOCAL_RANK=0 python -m torch.distributed.run \ - --nproc_per_node=1 -m pytest \ - tests/unit_tests/models/test_mimo_submodules.py -v -''' - -from typing import Any, Dict, List, Optional - -import pytest -import torch -import torch.nn as nn - -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec -from megatron.core.models.mimo.submodules.base import ModalitySubmodules -from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules -from megatron.core.models.vision.clip_vit_model import CLIPViTModel -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_config import TransformerConfig -from tests.unit_tests.test_utilities import Utils - - -class MockModalitySubmodule(ModalitySubmodules): - """Concrete implementation of ModalitySubmodules for testing purposes.""" - - def combine_embeddings(self, embeddings: List[torch.Tensor]) -> torch.Tensor: - return - - def encode(self, data_batch: Dict) -> List[torch.Tensor]: - return [] - - def decode(self, embeddings: torch.Tensor, data_batch: Dict) -> torch.Tensor: - return - - def project_embeddings( - self, embeddings: List[torch.Tensor], is_input: bool = True - ) -> Optional[torch.Tensor]: - return None - - def forward(self, encoder_inputs: Dict[str, Any], seq_lengths: Optional[torch.Tensor] = None): - return None - - -@pytest.mark.experimental -class TestBaseSubmodule: - """Test the base ModalitySubmodules class initialization.""" - - def setup_method(self, method): - '''setup env''' - # Initialize distributed environment - try: - Utils.initialize_model_parallel(1, 1) - except Exception as e: - print(f"Warning: Could not initialize model parallel: {e}") - - # Create transformer config for vision encoder - self.vision_config = TransformerConfig( - num_layers=1, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True - ) - - # Create layer spec for transformer - self.layer_spec = get_gpt_layer_with_transformer_engine_spec() - - # Define vision encoder parameters - self.img_h = 224 - self.img_w = 224 - self.patch_dim = 16 - - # Create encoder spec (using CLIP ViT model) - self.encoder_spec = ModuleSpec( - module=CLIPViTModel, - params={ - "transformer_config": self.vision_config, - "transformer_layer_spec": self.layer_spec, - "img_h": self.img_h, - "img_w": self.img_w, - "patch_dim": self.patch_dim, - }, - ) - - # Create projection spec - self.projection_spec = ModuleSpec( - module=nn.Linear, - params={ - "in_features": self.vision_config.hidden_size, - "out_features": self.vision_config.hidden_size, - }, - ) - - # Create the main module spec - self.module_spec = ModuleSpec( - module=MockModalitySubmodule, - submodules={ - "encoders": {"clip_encoder": self.encoder_spec}, - "input_projections": [self.projection_spec], - }, - ) - - def teardown_method(self, method): - '''teardown env''' - try: - Utils.destroy_model_parallel() - except Exception as e: - print(f"Warning: Could not destroy model parallel: {e}") - - def test_initialize_with_modules(self): - """Test constructor with pre-built modules.""" - # Create actual modules - encoder = CLIPViTModel( - transformer_config=self.vision_config, - transformer_layer_spec=self.layer_spec, - img_h=self.img_h, - img_w=self.img_w, - patch_dim=self.patch_dim, - ) - - projection = nn.Linear( - in_features=self.vision_config.hidden_size, out_features=self.vision_config.hidden_size - ) - - # Create submodule with modules - submodule = MockModalitySubmodule( - encoders={"clip_encoder": encoder}, input_projections=[projection] - ) - - # Check modules are set correctly - assert len(submodule.encoders) == 1 - assert len(submodule.decoders) == 0 - assert len(submodule.input_projections) == 1 - assert len(submodule.output_projections) == 0 - - # Check the encoder module is of the right type - assert isinstance(submodule.encoders['clip_encoder'], CLIPViTModel) - - # Check the projection module is of the right type - assert isinstance(submodule.input_projections[0], nn.Linear) - - def test_initialize_from_spec(self): - """Test creating a submodule from a ModuleSpec with real modules.""" - # Create from spec - submodule_from_spec = MockModalitySubmodule.from_spec(self.module_spec) - - # Verify the submodule was created correctly - assert len(submodule_from_spec.encoders) == 1 - assert len(submodule_from_spec.decoders) == 0 - assert len(submodule_from_spec.input_projections) == 1 - assert len(submodule_from_spec.output_projections) == 0 - - # Check the encoder modules are of the right type - assert isinstance(submodule_from_spec.encoders['clip_encoder'], CLIPViTModel) - - # Check the projection module is of the right type - assert isinstance(submodule_from_spec.input_projections[0], nn.Linear) - - # Check parameters of the encoder - encoder = submodule_from_spec.encoders['clip_encoder'] - assert encoder.img_h == self.img_h - assert encoder.img_w == self.img_w - assert encoder.patch_dim == self.patch_dim - - # Check parameters of the projection - projection = submodule_from_spec.input_projections[0] - assert projection.in_features == self.vision_config.hidden_size - assert projection.out_features == self.vision_config.hidden_size - - -@pytest.mark.experimental -class TestVisionSubmodule: - """Test the VisionModalitySubmodules class with forward passes.""" - - def setup_method(self, method): - '''setup env''' - # Initialize distributed environment - try: - Utils.initialize_model_parallel(1, 1) - except Exception as e: - print(f"Warning: Could not initialize model parallel: {e}") - - model_parallel_cuda_manual_seed(123) - - self.hidden_size = 64 - self.vision_config = TransformerConfig( - num_layers=1, - hidden_size=self.hidden_size, - num_attention_heads=4, - use_cpu_initialization=True, - ) - - # Create layer spec for transformer - self.layer_spec = get_gpt_layer_with_transformer_engine_spec() - - # Define vision parameters - self.img_h = 224 - self.img_w = 224 - self.patch_dim = 16 - - # Create vision encoder - self.vision_encoder = CLIPViTModel( - transformer_config=self.vision_config, - transformer_layer_spec=self.layer_spec, - img_h=self.img_h, - img_w=self.img_w, - patch_dim=self.patch_dim, - ) - - # Create projection layer - self.input_projection = nn.Linear(self.hidden_size, self.hidden_size) - - # Create output projection - self.output_projection = nn.Linear(self.hidden_size, self.hidden_size) - - # Create VisionModalitySubmodules with encoder and projection - self.vision_submodule = VisionModalitySubmodules( - encoders={"clip_encoder": self.vision_encoder}, - input_projections=[self.input_projection], - ) - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.vision_submodule = self.vision_submodule.to(self.device) - - # Set all modules to eval mode to disable dropout and other stochastic layers - # This makes tests more deterministic - self.vision_submodule.eval() - self.vision_encoder.eval() - self.input_projection.eval() - - def teardown_method(self, method): - '''teardown env''' - try: - Utils.destroy_model_parallel() - except Exception as e: - print(f"Warning: Could not destroy model parallel: {e}") - - def test_encode_with_random_data(self): - """Test encoding with random image data.""" - # Create random batch of images - num_images = 2 - images = torch.rand(num_images, 3, self.img_h, self.img_w, device=self.device) - data_batch = {"clip_encoder": {"x": images}} - - # Test encode method - embeddings = self.vision_submodule.encode(data_batch) - - # Verify embeddings shape and content - assert len(embeddings) == 1 # One encoder - embedding = embeddings[0] - - # Number of tokens depends on image size and patch size - expected_seq_len = (self.img_h // self.patch_dim) * ( - self.img_w // self.patch_dim - ) + 1 # +1 for cls token - assert embedding.shape[0] == num_images * expected_seq_len - assert embedding.shape[1] == self.hidden_size - - def test_combine_embeddings(self): - """Test combining embeddings functionality.""" - # Create test embeddings with different sequence lengths - num_images = 2 - seq_len1 = 10 - seq_len2 = 15 - - # Create test embeddings - embedding1 = torch.rand(num_images * seq_len1, self.hidden_size, device=self.device) - embedding2 = torch.rand(num_images * seq_len2, self.hidden_size, device=self.device) - embeddings = [embedding1, embedding2] - - # Test combining embeddings - combined = self.vision_submodule.combine_embeddings(embeddings) - assert combined.shape == (num_images * (seq_len1 + seq_len2), self.hidden_size) - - # Test combining a single embedding - single_combined = self.vision_submodule.combine_embeddings([embedding1]) - assert single_combined.shape == (num_images * seq_len1, self.hidden_size) - assert torch.all(single_combined == embedding1) - - # Test combining empty embeddings raises error - with pytest.raises(ValueError): - self.vision_submodule.combine_embeddings([]) - - def test_forward_pass(self): - """Test the complete forward pass.""" - # Create random batch of images - num_images = 2 - images = torch.rand(num_images, 3, self.img_h, self.img_w, device=self.device) - data_batch = {"clip_encoder": {"x": images}} - - # Test forward pass - output = self.vision_submodule(data_batch) - assert output is not None - - # Check output shape - flattened to [num_image_embeddings, hidden_dim] - expected_seq_len = (self.img_h // self.patch_dim) * (self.img_w // self.patch_dim) + 1 - expected_total_embeddings = num_images * expected_seq_len - assert output.shape == (expected_total_embeddings, self.hidden_size) - - def test_empty_data_batch(self): - """Test forward pass with empty data batch.""" - # Create a data batch without images - data_batch = {} - - # Test forward pass - output = self.vision_submodule(data_batch) - assert output is None diff --git a/tests/unit_tests/models/test_multimodal_projector.py b/tests/unit_tests/models/test_multimodal_projector.py deleted file mode 100644 index 52fda330c2..0000000000 --- a/tests/unit_tests/models/test_multimodal_projector.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - - -import torch - -from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec -from megatron.core.models.vision.multimodal_projector import MultimodalProjector -from megatron.core.tensor_parallel.layers import ColumnParallelLinear -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.mlp import MLPSubmodules -from megatron.core.transformer.transformer_config import TransformerConfig -from tests.unit_tests.test_utilities import Utils - - -class TestMultimodalProjector: - - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - transformer_config = TransformerConfig( - num_layers=1, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True - ) - mlp_layer_spec = get_mlp_module_spec().submodules - - affine_layer_spec = MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=None) - self.mlp = MultimodalProjector( - config=transformer_config, - submodules=mlp_layer_spec, - projector_type="mlp", - input_size=1024, - ) - self.affine = MultimodalProjector( - config=transformer_config, - submodules=affine_layer_spec, - projector_type="affine", - input_size=1024, - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_constructor(self): - assert isinstance(self.mlp, MultimodalProjector) - assert isinstance(self.affine, MultimodalProjector) - - num_weights = sum([p.numel() for p in self.mlp.parameters()]) - assert num_weights == 280896 - - num_weights = sum([p.numel() for p in self.affine.parameters()]) - assert num_weights == 65600 - - def test_forward(self): - self.mlp.cuda() - self.affine.cuda() - - image_projection = torch.zeros((2, 1024)).cuda() - - logits = self.mlp.forward(image_projection) - assert len(logits) == 2 - assert logits.shape == torch.Size([2, 64]) - - logits = self.affine.forward(image_projection) - assert len(logits) == 2 - assert logits.shape == torch.Size([2, 64]) - - def test_save_load(self, tmp_path): - path = tmp_path / "mlp.pt" - torch.save(self.mlp.state_dict(), path) - - self.mlp.load_state_dict(torch.load(path)) - - path = tmp_path / "affine.pt" - torch.save(self.affine.state_dict(), path) - - self.affine.load_state_dict(torch.load(path)) diff --git a/tests/unit_tests/models/test_radio_model.py b/tests/unit_tests/models/test_radio_model.py deleted file mode 100644 index de51d57911..0000000000 --- a/tests/unit_tests/models/test_radio_model.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -import pytest -import torch - -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec -from megatron.core.models.vision.radio import RADIOViTModel -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.transformer_config import TransformerConfig -from tests.unit_tests.test_utilities import Utils - - -class TestRADIOViTModel: - """Test RADIO ViT model.""" - - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - transformer_config = TransformerConfig( - num_layers=2, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True - ) - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec() - self.model = RADIOViTModel( - transformer_config, - transformer_layer_spec, - img_h=224, - img_w=224, - patch_dim=14, - add_class_token=False, - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_constructor(self): - assert isinstance(self.model, RADIOViTModel) - - num_weights = sum([p.numel() for p in self.model.parameters()]) - assert num_weights == 1501824 - - def test_set_input_tensor(self): - # [s, b, h] expected to the transformer. - expected_shape = (256, 2, 64) - input_tensor = torch.zeros(expected_shape) - - self.model.set_input_tensor(input_tensor) - - assert self.model.decoder.input_tensor.shape == torch.Size(expected_shape) - - def test_forward(self): - self.model.cuda() - - img = torch.zeros((2, 3, 224, 224)).cuda() - - out = self.model.forward(img) - assert out.shape == torch.Size([2, 256, 64]) - - def test_save_load(self, tmp_path): - path = tmp_path / "model.pt" - torch.save(self.model.state_dict(), path) - - self.model.load_state_dict(torch.load(path)) diff --git a/tests/unit_tests/models/test_t5_model.py b/tests/unit_tests/models/test_t5_model.py deleted file mode 100644 index cc4e348d01..0000000000 --- a/tests/unit_tests/models/test_t5_model.py +++ /dev/null @@ -1,239 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import os -from copy import deepcopy - -import pytest -import torch -from packaging.version import Version as PkgVersion -from pytest_mock import mocker - -import megatron.core.parallel_state as ps -from megatron.core.datasets.t5_dataset import T5MaskedWordPieceDataset -from megatron.core.models.T5.t5_model import T5Model -from megatron.core.models.T5.t5_spec import ( - get_t5_decoder_with_local_block_spec, - get_t5_decoder_with_transformer_engine_block_spec, - get_t5_encoder_with_local_block_spec, - get_t5_encoder_with_transformer_engine_block_spec, -) -from megatron.core.process_groups_config import ModelCommProcessGroups -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.transformer_config import TransformerConfig -from tests.unit_tests.test_utilities import Utils - - -class TestT5Model: - - def setup_method(self, method): - tp = 4 - pp = 1 - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp - ) - model_parallel_cuda_manual_seed(123) - transformer_config = TransformerConfig( - num_layers=12, - hidden_size=768, - num_attention_heads=12, - kv_channels=64, - ffn_hidden_size=3072, - use_cpu_initialization=True, - pipeline_dtype=torch.bfloat16, - tensor_model_parallel_size=tp, - pipeline_model_parallel_size=pp, - ) - rank = ps.get_pipeline_model_parallel_rank() - world_size = ps.get_pipeline_model_parallel_world_size() - en_block_spec = get_t5_encoder_with_transformer_engine_block_spec(12) - de_block_spec = get_t5_decoder_with_transformer_engine_block_spec(12) - - pre_process = True - post_process = True - add_encoder = True - add_decoder = True - - self.t5_model = T5Model( - encoder_config=transformer_config, - config=transformer_config, - transformer_encoder_layer_spec=en_block_spec, - transformer_decoder_layer_spec=de_block_spec, - vocab_size=29184, - max_sequence_length=4, - pre_process=pre_process, - post_process=post_process, - add_encoder=add_encoder, - add_decoder=add_decoder, - model_comm_pgs=ModelCommProcessGroups.use_mpu_process_groups( - required_pgs=['tp', 'cp', 'pp'] - ), - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_constructor(self): - assert isinstance(self.t5_model, T5Model) - assert Utils.world_size == 8 - - assert self.t5_model.max_sequence_length == 4 - assert self.t5_model.add_decoder - assert self.t5_model.decoder.num_layers_per_pipeline_rank == 12 - assert self.t5_model.decoder.num_layers_per_pipeline_rank == 12 - assert self.t5_model.pre_process - assert self.t5_model.post_process - - def test_set_input_tensor(self): - config: TransformerConfig = self.t5_model.config - sequence_length = self.t5_model.max_sequence_length - micro_batch_size = 2 - - # [sequence length, batch size, hidden size] - input_tensor = torch.ones((sequence_length, micro_batch_size, config.hidden_size)) - - self.t5_model.set_input_tensor(input_tensor) - - if self.t5_model.add_encoder: - assert self.t5_model.encoder.input_tensor.shape[0] == sequence_length - assert self.t5_model.encoder.input_tensor.shape[1] == micro_batch_size - assert self.t5_model.encoder.input_tensor.shape[2] == config.hidden_size - else: - assert self.t5_model.encoder is None - assert self.t5_model.encoder_hidden_state.shape[0] == sequence_length - assert self.t5_model.encoder_hidden_state.shape[1] == micro_batch_size - assert self.t5_model.encoder_hidden_state.shape[2] == config.hidden_size - - @pytest.mark.flaky_in_dev - def test_post_process_forward(self): - pass - - def test_forward_output_encoder_hidden_only(self): - pass - - def test_forward_with_encoder_hidden_states(self): - pass - - def test_no_post_process_forward(self): - pass - - def test_no_preprocess_forward(self): - pass - - def test_state_dict_for_save_checkpoint(self): - pass - - def test_load_state_dict(self): - pass - - -class TestT5ModelAttentionDimensions: - - def teardown_method(self, method): - os.environ.pop('NVTE_FUSED_ATTN', None) - os.environ.pop('NVTE_FLASH_ATTN', None) - os.environ.pop('NVTE_UNFUSED_ATTN', None) - - def setup_method(self, method): - self.bs = 4 - self.seq_len = 512 - self.seq_len_dec = 128 - self.encoder_tokens = torch.ones([self.bs, self.seq_len]) - self.decoder_tokens = torch.ones([self.bs, self.seq_len_dec]) - self.encoder_mask = torch.ones([self.bs, self.seq_len]) < 0.5 - self.decoder_mask = torch.ones([self.bs, self.seq_len_dec]) < 0.5 - - @pytest.mark.internal - def test_local_spec(self): - encoder_mask, decoder_mask, encoder_decoder_mask = ( - T5MaskedWordPieceDataset.config_attention_mask( - self.encoder_tokens, - self.decoder_tokens, - self.encoder_mask, - self.decoder_mask, - use_local=True, - ) - ) - - assert list(encoder_mask.shape) == [self.bs, 1, self.seq_len, self.seq_len] - assert list(decoder_mask.shape) == [self.bs, 1, self.seq_len_dec, self.seq_len_dec] - assert list(encoder_decoder_mask.shape) == [self.bs, 1, self.seq_len_dec, self.seq_len] - - @pytest.mark.internal - def test_transformer_engine_version_1_10(self): - encoder_mask, decoder_mask, encoder_decoder_mask = ( - T5MaskedWordPieceDataset.config_attention_mask( - self.encoder_tokens, - self.decoder_tokens, - self.encoder_mask, - self.decoder_mask, - use_local=False, - test_te_version="1.10", - ) - ) - - assert list(encoder_mask.shape) == [self.bs, 1, 1, self.seq_len] - assert decoder_mask is None - assert list(encoder_decoder_mask[0].shape) == [self.bs, 1, 1, self.seq_len_dec] - assert list(encoder_decoder_mask[1].shape) == [self.bs, 1, 1, self.seq_len] - - @pytest.mark.internal - def test_transformer_engine_version_1_7_to_1_10_flashfused_attn(self): - os.environ['NVTE_FLASH_ATTN'] = '1' - os.environ['NVTE_FUSED_ATTN'] = '1' - - encoder_mask, decoder_mask, encoder_decoder_mask = ( - T5MaskedWordPieceDataset.config_attention_mask( - self.encoder_tokens, - self.decoder_tokens, - self.encoder_mask, - self.decoder_mask, - use_local=False, - test_te_version="1.8", - ) - ) - - assert list(encoder_mask.shape) == [self.bs, 1, 1, self.seq_len] - assert decoder_mask is None - assert list(encoder_decoder_mask[0].shape) == [self.bs, 1, 1, self.seq_len_dec] - assert list(encoder_decoder_mask[1].shape) == [self.bs, 1, 1, self.seq_len] - - @pytest.mark.internal - def test_transformer_engine_version_1_7_to_1_10_unfused_attention(self): - os.environ['NVTE_FLASH_ATTN'] = '0' - os.environ['NVTE_FUSED_ATTN'] = '0' - - encoder_mask, decoder_mask, encoder_decoder_mask = ( - T5MaskedWordPieceDataset.config_attention_mask( - self.encoder_tokens, - self.decoder_tokens, - self.encoder_mask, - self.decoder_mask, - use_local=False, - test_te_version="1.8", - ) - ) - - assert list(encoder_mask.shape) == [self.bs, 1, self.seq_len, self.seq_len] - assert decoder_mask is None - assert list(encoder_decoder_mask.shape) == [self.bs, 1, self.seq_len_dec, self.seq_len] - - @pytest.mark.internal - def test_transformer_engine_version_less_than_1_7(self): - os.environ['NVTE_FLASH_ATTN'] = '1' - with pytest.raises(Exception) as exc_info: - encoder_mask, decoder_mask, encoder_decoder_mask = ( - T5MaskedWordPieceDataset.config_attention_mask( - self.encoder_tokens, - self.decoder_tokens, - self.encoder_mask, - self.decoder_mask, - use_local=False, - test_te_version="1.5", - ) - ) - - assert str(exc_info.value) == ( - "Flash and fused attention is not supported with transformer " - "engine version < 1.7. Set NVTE_FLASH_ATTN=0 and NVTE_FUSED_ATTN=0" - "or upgrade transformer engine >= 1.7" - ) diff --git a/tests/unit_tests/pipeline_parallel/test_helpers.py b/tests/unit_tests/pipeline_parallel/test_helpers.py deleted file mode 100644 index a20c3a5401..0000000000 --- a/tests/unit_tests/pipeline_parallel/test_helpers.py +++ /dev/null @@ -1,124 +0,0 @@ -def compare_helpers(pipeline_parallel_size, num_microbatches, num_model_chunks): - total_num_microbatches = num_microbatches * num_model_chunks - - # Baseline helpers - def baseline_get_model_chunk_id(microbatch_id, forward): - """Helper method to get the model chunk ID given the iteration number.""" - microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks) - model_chunk_id = microbatch_id_in_group // pipeline_parallel_size - if not forward: - model_chunk_id = num_model_chunks - model_chunk_id - 1 - return model_chunk_id - - def baseline_get_microbatch_id_in_model_chunk(iteration_id, forward): - """Helper method to get the microbatch_id within model chunk given the iteration number.""" - assert forward - iteration_group_id = iteration_id // (pipeline_parallel_size * num_model_chunks) - microbatch_id_in_model_chunk = (iteration_group_id * pipeline_parallel_size) + ( - iteration_id % pipeline_parallel_size - ) - return microbatch_id_in_model_chunk - - def baseline_is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool: - """Check if an iteration is the first for a model chunk.""" - microbatch_group_size = pipeline_parallel_size * num_model_chunks - microbatch_group_id = microbatch_id // microbatch_group_size - microbatch_id_in_group = microbatch_id % microbatch_group_size - if microbatch_group_id == 0: - return microbatch_id_in_group % pipeline_parallel_size == 0 - else: - return False - - def baseline_is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool: - """Check if an iteration is the last for a model chunk.""" - microbatch_group_size = pipeline_parallel_size * num_model_chunks - num_microbatch_groups = total_num_microbatches // microbatch_group_size - microbatch_group_id = microbatch_id // microbatch_group_size - microbatch_id_in_group = microbatch_id % microbatch_group_size - if microbatch_group_id == num_microbatch_groups - 1: - return microbatch_id_in_group % pipeline_parallel_size == pipeline_parallel_size - 1 - else: - return False - - # Create schedule table prior to new helper methods - schedule_table = [] - for min_microbatch_id_in_group in range(0, num_microbatches, pipeline_parallel_size): - if min_microbatch_id_in_group + pipeline_parallel_size >= num_microbatches: - # Construct schedule for the last microbatch group - schedule_table.extend( - [ - (microbatch_id, model_chunk_id) - for model_chunk_id in range(num_model_chunks) - for microbatch_id in range(min_microbatch_id_in_group, num_microbatches) - ] - ) - else: - # Construct schedule for other microbatch groups - schedule_table.extend( - [ - (microbatch_id, model_chunk_id) - for model_chunk_id in range(num_model_chunks) - for microbatch_id in range( - min_microbatch_id_in_group, - min_microbatch_id_in_group + pipeline_parallel_size, - ) - ] - ) - - microbatch_id_table, model_chunk_id_table = zip(*schedule_table) - - # New helper methods that indexes schedule table - def new_get_model_chunk_id(virtual_microbatch_id, forward): - """Helper method to get the model chunk ID given the iteration number.""" - model_chunk_id = model_chunk_id_table[virtual_microbatch_id % total_num_microbatches] - if not forward: - model_chunk_id = num_model_chunks - model_chunk_id - 1 - return model_chunk_id - - def new_get_microbatch_id_in_model_chunk(iteration_id, forward): - """Helper method to get the microbatch_id within model chunk given the iteration number.""" - assert forward - microbatch_id_in_model_chunk = microbatch_id_table[iteration_id] - return microbatch_id_in_model_chunk - - def new_is_first_microbatch_for_model_chunk(virtual_microbatch_id: int) -> bool: - """Check if an iteration is the first for a model chunk.""" - if virtual_microbatch_id < total_num_microbatches: - return microbatch_id_table[virtual_microbatch_id] == 0 - else: - return False - - def new_is_last_microbatch_for_model_chunk(virtual_microbatch_id: int) -> bool: - """Check if an iteration is the last for a model chunk.""" - if virtual_microbatch_id < total_num_microbatches: - return microbatch_id_table[virtual_microbatch_id] == num_microbatches - 1 - else: - return False - - for i in range(total_num_microbatches): - # Test both forward and backward - assert baseline_get_model_chunk_id(i, forward=False) == new_get_model_chunk_id( - i, forward=False - ) - assert baseline_get_model_chunk_id(i, forward=True) == new_get_model_chunk_id( - i, forward=True - ) - - # Only used in forward - assert baseline_get_microbatch_id_in_model_chunk( - i, forward=True - ) == new_get_microbatch_id_in_model_chunk(i, forward=True) - - assert baseline_is_first_microbatch_for_model_chunk( - i - ) == new_is_first_microbatch_for_model_chunk(i) - assert baseline_is_last_microbatch_for_model_chunk( - i - ) == new_is_last_microbatch_for_model_chunk(i) - - -def test_helpers(): - for pp in [2, 4, 8]: - for m in [pp, 2 * pp, 4 * pp, 8 * pp]: - for vp in range(2, 13): - compare_helpers(pipeline_parallel_size=pp, num_microbatches=m, num_model_chunks=vp) diff --git a/tests/unit_tests/pipeline_parallel/test_pipeline_layout.py b/tests/unit_tests/pipeline_parallel/test_pipeline_layout.py deleted file mode 100644 index 86b76a526e..0000000000 --- a/tests/unit_tests/pipeline_parallel/test_pipeline_layout.py +++ /dev/null @@ -1,336 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -import os -from pathlib import Path -from types import SimpleNamespace - -import pytest -import torch -import torch.distributed - -from megatron.core import mpu, parallel_state -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec -from megatron.core.models.gpt.gpt_layer_specs import ( - get_gpt_layer_with_transformer_engine_spec as gpt_te_spec, -) -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.num_microbatches_calculator import ( - init_num_microbatches_calculator, - unset_num_microbatches_calculator, -) -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.enums import ModelType -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.training.checkpointing import load_checkpoint, save_checkpoint -from megatron.training.global_vars import set_args -from tests.unit_tests.dist_checkpointing import TempNamedDir -from tests.unit_tests.dist_checkpointing.models.common import ( - common_test_parallel_reconfiguration_e2e, -) -from tests.unit_tests.test_utilities import Utils - - -def initialize_gpt_model( - seed, - layer_spec_fn=gpt_te_spec, - vocab_size=128, - virtual_pipeline_model_parallel_size=None, - is_moe=False, - with_mtp=False, - **config_kwargs, -): - torch.manual_seed(seed) - model_parallel_cuda_manual_seed(seed) - - default_config_kwargs = dict( - num_layers=8, - hidden_size=128, - num_attention_heads=8, - use_cpu_initialization=True, - pipeline_dtype=torch.bfloat16, - bf16=True, - virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size, - hidden_dropout=0.0, - attention_dropout=0.0, - ) - default_config_kwargs.update(**config_kwargs) - transformer_config = TransformerConfig(**default_config_kwargs) - if is_moe: - transformer_config.moe_layer_freq = [0, 1, 1, 1, 1, 0, 1, 0] - transformer_config.moe_ffn_hidden_size = 128 - transformer_config.num_moe_experts = 4 - transformer_config.add_bias_linear = False - if with_mtp: - transformer_config.mtp_num_layers = 1 - transformer_config.mtp_loss_scaling_factor = 1.0 - model = [] - for i in range(virtual_pipeline_model_parallel_size or 1): - if is_moe: - layer_spec = layer_spec_fn(transformer_config, use_transformer_engine=True, vp_stage=i) - else: - layer_spec = layer_spec_fn() - - if is_moe and with_mtp and mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=i): - transformer_layer_spec_for_mtp = gpt_te_spec(transformer_config) - mtp_block_spec = get_gpt_mtp_block_spec( - transformer_config, - transformer_layer_spec_for_mtp, - use_transformer_engine=True, - vp_stage=i, - ) - else: - mtp_block_spec = None - pre_process = mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=i) - post_process = mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=i) - this_model = ( - GPTModel( - config=transformer_config, - transformer_layer_spec=layer_spec, - vocab_size=vocab_size, - max_sequence_length=4, - pre_process=pre_process, - post_process=post_process, - position_embedding_type="rope", - vp_stage=i, - mtp_block_spec=mtp_block_spec, - share_embeddings_and_output_weights=False, - ) - .bfloat16() - .cuda() - ) - this_model.model_type = ModelType.encoder_or_decoder - model.append(this_model) - - if virtual_pipeline_model_parallel_size is None: - model = model[0] - return model - - -@pytest.fixture -def create_args(): - """Setup dummy args.""" - args = SimpleNamespace() - args.finetune = False - args.non_persistent_global_ckpt_dir = None - args.non_persistent_ckpt_type = None - args.non_persistent_save_interval = None - args.exit_on_missing_checkpoint = True - args.async_save = False - args.data_parallel_random_init = False - args.log_progress = False - args.ckpt_fully_parallel_save = False - args.ckpt_fully_parallel_load = False - args.auto_detect_ckpt_format = False - args.retro_add_retriever = False - args.ckpt_convert_update_legacy_dist_opt_format = False - args.ckpt_step = None - args.use_dist_ckpt = True - args.consumed_train_samples = 0 - args.skipped_train_samples = 0 - args.consumed_valid_samples = 0 - args.vocab_file = None - args.add_position_embedding = False - args.ckpt_assume_constant_structure = True - args.dist_ckpt_strictness = "assume_ok_unexpected" - args.fp16 = False - args.bf16 = True - args.no_save_optim = True - args.no_save_rng = True - args.no_load_optim = True - args.no_load_rng = True - args.use_distributed_optimizer = True - args.use_megatron_fsdp = False - - yield args - - -# Dense and MoE Models -@pytest.mark.parametrize( - ('tp_pp_vpp', 'pp_layout', 'is_moe', 'with_mtp'), - [ - ((1, 2, 1), None, True, True), - ( - (1, 4, 2), - [ - ["embedding"], - ["decoder"], - ["decoder"] * 2, - ["decoder"], - [], - ["decoder"], - ["decoder"], - ["decoder"] * 2 + ["loss"], - ], - False, - True, - ), - ((1, 2, None), [["embedding"] + ["decoder"] * 4, ["decoder"] * 4 + ["loss"]], False, False), - ( - (1, 4, 2), - [ - ["embedding"], - ["decoder"], - ["decoder"] * 2, - ["decoder"], - [], - ["decoder"], - ["decoder"], - ["decoder"] * 2 + ["loss"], - ], - True, - False, - ), - ((1, 2, None), [["embedding"] + ["decoder"] * 4, ["decoder"] * 4 + ["loss"]], True, False), - ((1, 4, 2), "E|t*3|(t|)*5L", True, True), - ], -) -def test_forward_vpp(create_args, tmp_path_dist_ckpt, tp_pp_vpp, pp_layout, is_moe, with_mtp): - from megatron.core.pipeline_parallel import get_forward_backward_func - - args = create_args - # Model config - args.num_layers = 8 - args.hidden_size = 128 - args.num_attention_heads = 8 - # Ckpt format - args.ckpt_format = "torch_dist" - set_args(args) - - def set_tp_pp_vpp(tp, pp, vpp=None, pp_layout=None, destroy_first=True): - if destroy_first: - Utils.destroy_model_parallel() - Utils.initialize_model_parallel(tp, pp, vpp) - args.tensor_model_parallel_size = tp - args.pipeline_model_parallel_size = pp - args.virtual_pipeline_model_parallel_size = vpp - args.pipeline_model_parallel_layout = pp_layout - - set_tp_pp_vpp(*tp_pp_vpp, pp_layout=pp_layout, destroy_first=False) - init_num_microbatches_calculator(0, None, 1, 1, 1) - - def forward_step_func(data_iterator, model: GPTModel): - """Forward training step. Copied from `pretrain_gpt.py`""" - tokens = torch.LongTensor([[2, 1, 2, 3, 4, 5, 7, 6]]).cuda() - position_ids = torch.arange(8).view(1, -1).cuda() - labels = torch.ones_like(position_ids) - attention_mask = None - - output_tensor = model(tokens, position_ids, attention_mask, labels=labels) - - def loss_func(output_tensor: torch.Tensor): - loss = output_tensor.sum() - return output_tensor, loss - - return output_tensor, loss_func - - iteration = 123 - layer_spec_fn = get_gpt_decoder_block_spec if is_moe else gpt_te_spec - model = initialize_gpt_model( - 1, - layer_spec_fn=layer_spec_fn, - num_layers=args.num_layers, - hidden_size=args.hidden_size, - num_attention_heads=args.num_attention_heads, - tensor_model_parallel_size=args.tensor_model_parallel_size, - pipeline_model_parallel_size=args.pipeline_model_parallel_size, - virtual_pipeline_model_parallel_size=args.virtual_pipeline_model_parallel_size, - pipeline_model_parallel_layout=args.pipeline_model_parallel_layout, - is_moe=is_moe, - with_mtp=with_mtp, - ) - model = model if isinstance(model, list) else [model] - - forward_backward_func = get_forward_backward_func() - losses_reduced = forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=[get_batch_iterator(seq_length=8, micro_batch_size=1)] * len(model), - model=model, - num_microbatches=4, - seq_length=8, - micro_batch_size=1, - forward_only=True, - ) - - optimizer = None - opt_param_scheduler = None - num_floating_point_operations_so_far = 456 - - with TempNamedDir(tmp_path_dist_ckpt / 'test_gpt_model_reconfiguration_model_A') as ckpt_dir: - args.save = ckpt_dir - args.load = ckpt_dir - save_checkpoint( - iteration, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far - ) - print(f"save checkpoint done") - - set_tp_pp_vpp(1, 1) - model_baseline = initialize_gpt_model( - 123, - layer_spec_fn=layer_spec_fn, - num_layers=args.num_layers, - hidden_size=args.hidden_size, - num_attention_heads=args.num_attention_heads, - tensor_model_parallel_size=args.tensor_model_parallel_size, - pipeline_model_parallel_size=args.pipeline_model_parallel_size, - virtual_pipeline_model_parallel_size=args.virtual_pipeline_model_parallel_size, - pipeline_model_parallel_layout=args.pipeline_model_parallel_layout, - is_moe=is_moe, - with_mtp=with_mtp, - ) - load_checkpoint([model_baseline], optimizer, opt_param_scheduler, strict=False) - - forward_backward_func = get_forward_backward_func() - losses_reduced_baseline = forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=get_batch_iterator(seq_length=8, micro_batch_size=1), - model=[model_baseline], - num_microbatches=4, - seq_length=8, - micro_batch_size=1, - forward_only=True, - ) - - if parallel_state.is_pipeline_last_stage(ignore_virtual=True): - for loss, loss_baseline in zip(losses_reduced, losses_reduced_baseline): - assert torch.equal(loss, loss_baseline) - - Utils.destroy_model_parallel() - unset_num_microbatches_calculator() - - -def get_batch_iterator(seq_length, micro_batch_size, num_batches=None): - """ - Generator function that yields batches indefinitely or for a specified number of batches. - - Args: - seq_length: Length of the sequence - micro_batch_size: Size of each micro batch - num_batches: Optional number of batches to generate. If None, generates indefinitely. - """ - batch_count = 0 - while num_batches is None or batch_count < num_batches: - # Generate different data for each batch by adding batch_count offset - data = list(range(batch_count, batch_count + seq_length)) - input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - labels = 1 + torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - position_ids = ( - torch.tensor(list(range(seq_length)), dtype=torch.int64) - .repeat((micro_batch_size, 1)) - .cuda() - ) - attention_mask = torch.ones( - (micro_batch_size, 1, seq_length, seq_length), dtype=bool - ).cuda() - loss_mask = torch.ones(seq_length).repeat((micro_batch_size, 1)).cuda() - - yield input_ids, labels, position_ids, attention_mask, loss_mask - batch_count += 1 - - -# if __name__ == "__main__": -# import os - -# args = create_args() -# test_forward_vpp(args, Path("./tmp_path_dist_ckpt"), (1, 2, 1), None, True, True) -# print("test done") diff --git a/tests/unit_tests/pipeline_parallel/test_schedules.py b/tests/unit_tests/pipeline_parallel/test_schedules.py deleted file mode 100644 index 6ddec7ae9c..0000000000 --- a/tests/unit_tests/pipeline_parallel/test_schedules.py +++ /dev/null @@ -1,718 +0,0 @@ -import os - -import pytest -import torch -import torch.distributed as dist -from packaging import version -from pytest_mock import mocker - -import megatron.core.pipeline_parallel.schedules as schedule -from megatron.core import ModelParallelConfig -from megatron.core.distributed.finalize_model_grads import finalize_model_grads -from megatron.core.hyper_comm_grid import HyperCommGrid -from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator -from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage -from megatron.core.process_groups_config import GradFinalizeProcessGroups -from tests.unit_tests.test_utilities import Utils - -rank = Utils.rank - - -def _populate_embedding_and_position_groups(pp_group): - """Create *new* embedding-related process groups from *pp_group* ranks.""" - - pp_ranks = sorted(dist.get_process_group_ranks(pp_group)) - - pos_embd_ranks = [pp_ranks[0]] - embd_ranks = [pp_ranks[0]] - if pp_ranks[-1] != pp_ranks[0]: - embd_ranks.append(pp_ranks[-1]) - - pos_embd_pg = dist.new_group(ranks=pos_embd_ranks) - embd_pg = dist.new_group(ranks=embd_ranks) - - return pos_embd_pg, embd_pg - - -def test_get_forward_backward_func(): - Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=1) - assert schedule.get_forward_backward_func() == schedule.forward_backward_no_pipelining - Utils.destroy_model_parallel() - Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4) - assert ( - schedule.get_forward_backward_func() - == schedule.forward_backward_pipelining_without_interleaving - ) - Utils.destroy_model_parallel() - Utils.initialize_model_parallel( - tensor_model_parallel_size=2, - pipeline_model_parallel_size=4, - virtual_pipeline_model_parallel_size=2, - ) - assert ( - schedule.get_forward_backward_func() - == schedule.forward_backward_pipelining_with_interleaving - ) - Utils.destroy_model_parallel() - Utils.initialize_model_parallel( - tensor_model_parallel_size=2, - pipeline_model_parallel_size=2, - virtual_pipeline_model_parallel_size=4, - ) - assert ( - schedule.get_forward_backward_func() - == schedule.forward_backward_pipelining_with_interleaving - ) - Utils.destroy_model_parallel() - - -def test_deallocate_output_tensor(): - out = torch.tensor([[1, 2, 3], [4, 5, 6]]) - schedule.deallocate_output_tensor(out) - assert out.nelement() == 6 - - -@pytest.mark.internal -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize( - "pipeline_model_parallel_size,microbatch_group_size_per_vp_stage", - [(1, 1), (2, 2), (2, 4), (4, 4), (4, 5), (8, 9), (8, 11)], -) -@pytest.mark.parametrize("num_microbatches", [8, 32]) -@pytest.mark.parametrize("virtual_pipeline_model_parallel_size", [None, 2, 4, 8]) -def test_get_pipeline_parallel_order( - pipeline_model_parallel_size, - virtual_pipeline_model_parallel_size, - num_microbatches, - microbatch_group_size_per_vp_stage, -): - if pipeline_model_parallel_size == 1 and virtual_pipeline_model_parallel_size is not None: - return - - Utils.initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=pipeline_model_parallel_size, - virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size, - ) - num_model_chunks = ( - virtual_pipeline_model_parallel_size - if virtual_pipeline_model_parallel_size is not None - else 1 - ) - - _, _, num_warmup_microbatches, _ = schedule.get_pp_rank_microbatches( - num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage, False - ) - schedule_table = schedule.get_schedule_table( - num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage - ) - order = schedule.convert_schedule_table_to_order( - num_warmup_microbatches, num_model_chunks, schedule_table - ) - - assert max(order) == num_model_chunks - assert len(order) == num_microbatches * num_model_chunks * 2 - order_cnt = {} - accumulated_order = 0 - for o in order: - order_cnt[o] = order_cnt.get(o, 0) + 1 - if o < 0: - assert -o in order_cnt and order_cnt[-o] >= order_cnt[o] - elif -o in order_cnt: - assert order_cnt[-o] < order_cnt[o] - accumulated_order += o - assert accumulated_order >= 0 - assert accumulated_order == 0 - assert 0 not in order_cnt - for k, v in order_cnt.items(): - assert -k in order_cnt and order_cnt[-k] == v - - Utils.destroy_model_parallel() - - -def test_forward_backward_func_without_pipeline_parallel(mocker): - from megatron.core.pipeline_parallel import get_forward_backward_func - - Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=1) - - def forward_step_func(data_iterator, model): - import os - - rank = int(os.environ['LOCAL_RANK']) - dummy_data = torch.ones(1, 4) - - def loss_func(output_tensor): - return rank, {'loss_reduced': rank} - - return model(dummy_data), loss_func - - model = torch.nn.Linear(4, 1) - model.model_type = 'unit-test' - - def set_input_tensor(input_tensor): - return None - - model.set_input_tensor = set_input_tensor - - forward_backward_func = get_forward_backward_func() - assert schedule.get_forward_backward_func() == schedule.forward_backward_no_pipelining - - mocker.patch("megatron.core.pipeline_parallel.schedules.custom_backward", return_value=2) - config = ModelParallelConfig(pipeline_model_parallel_size=1) - model.config = config - - losses_reduced = forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=range(0, 100), - model=[model], - num_microbatches=4, - seq_length=None, - micro_batch_size=None, - forward_only=True, - ) - - loss_reduced_expected = [ - {'loss_reduced': rank}, - {'loss_reduced': rank}, - {'loss_reduced': rank}, - {'loss_reduced': rank}, - ] - - for i, j in zip(losses_reduced, loss_reduced_expected): - assert i['loss_reduced'] == j['loss_reduced'] - Utils.destroy_model_parallel() - - -def test_forward_backward_func_with_pipeline_parallel(mocker): - from megatron.core.pipeline_parallel import get_forward_backward_func - - Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4) - - def forward_step_func(data_iterator, model): - import os - - rank = int(os.environ['LOCAL_RANK']) - - def loss_func(output_tensor): - return rank, {'loss_reduced': rank} - - return torch.rand(512, 8, 256).cuda(), loss_func - - model = torch.nn.Linear(4, 1) - model.model_type = 'unit-test' - - def set_input_tensor(input_tensor): - return None - - model.set_input_tensor = set_input_tensor - - forward_backward_func = get_forward_backward_func() - assert ( - schedule.get_forward_backward_func() - == schedule.forward_backward_pipelining_without_interleaving - ) - - sequence_length = 512 - micro_batch_size = 8 - hidden_size = 256 - - config = ModelParallelConfig( - pipeline_model_parallel_size=4, sequence_parallel=False, pipeline_dtype=torch.float - ) - config.hidden_size = hidden_size - model.config = config - - losses_reduced = forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=None, - model=[model], - num_microbatches=micro_batch_size, - seq_length=sequence_length, - micro_batch_size=micro_batch_size, - forward_only=True, - ) - - loss_reduced_expected = [ - {'loss_reduced': rank}, - {'loss_reduced': rank}, - {'loss_reduced': rank}, - {'loss_reduced': rank}, - ] - for i, j in zip(losses_reduced, loss_reduced_expected): - print(losses_reduced) - assert i['loss_reduced'] == j['loss_reduced'] - Utils.destroy_model_parallel() - - -@pytest.mark.internal -def test_forward_backward_func_with_interleaving(mocker): - from megatron.core.enums import ModelType - from megatron.core.pipeline_parallel import get_forward_backward_func - - Utils.initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=4, - virtual_pipeline_model_parallel_size=2, - ) - - def forward_step_func(data_iterator, model): - import os - - rank = int(os.environ['LOCAL_RANK']) - - def loss_func(output_tensor): - return rank, {'loss_reduced': rank} - - return torch.rand(512, 8, 256).cuda(), loss_func - - model = torch.nn.Linear(4, 1) - - def set_input_tensor(input_tensor): - return None - - model.set_input_tensor = set_input_tensor - - forward_backward_func = get_forward_backward_func() - assert ( - schedule.get_forward_backward_func() - == schedule.forward_backward_pipelining_with_interleaving - ) - - sequence_length = 512 - micro_batch_size = 8 - hidden_size = 256 - - config = ModelParallelConfig( - pipeline_model_parallel_size=4, - sequence_parallel=False, - pipeline_dtype=torch.float, - virtual_pipeline_model_parallel_size=2, - ) - config.hidden_size = hidden_size - model.config = config - - mocker.patch("megatron.core.pipeline_parallel.schedules.custom_backward", return_value=2) - - loss_reduced_expected = [ - {'loss_reduced': rank}, - {'loss_reduced': rank}, - {'loss_reduced': rank}, - {'loss_reduced': rank}, - ] - - model.model_type = ModelType.encoder_or_decoder - losses_reduced = forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=[range(0, 100), range(0, 100)], - model=[model, model], - num_microbatches=micro_batch_size, - seq_length=sequence_length, - micro_batch_size=micro_batch_size, - decoder_seq_length=256, - forward_only=True, - ) - - for i, j in zip(losses_reduced, loss_reduced_expected): - print(f"losses_reduced: {i} loss_reduced_expected: {j}") - assert i['loss_reduced'] == j['loss_reduced'] - - with pytest.raises(RuntimeError): - model.model_type = ModelType.encoder_or_decoder - forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=[range(0, 100), range(0, 100)], - model=[model, model], - num_microbatches=7, - seq_length=sequence_length, - micro_batch_size=micro_batch_size, - decoder_seq_length=512, - forward_only=True, - ) - - model.model_type = ModelType.encoder_or_decoder - losses_reduced = forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=[range(0, 100), range(0, 100)], - model=[model, model], - num_microbatches=micro_batch_size, - seq_length=sequence_length, - micro_batch_size=micro_batch_size, - decoder_seq_length=sequence_length, - forward_only=True, - ) - - for i, j in zip(losses_reduced, loss_reduced_expected): - print(f"losses_reduced: {i} loss_reduced_expected: {j}") - assert i['loss_reduced'] == j['loss_reduced'] - - Utils.destroy_model_parallel() - - -@pytest.mark.internal -def test_forward_backward_func_with_uneven_interleaving(mocker): - from megatron.core.enums import ModelType - from megatron.core.pipeline_parallel import get_forward_backward_func - - Utils.initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=4, - virtual_pipeline_model_parallel_size=2, - ) - - def forward_step_func(data_iterator, model): - import os - - rank = int(os.environ['LOCAL_RANK']) - - def loss_func(output_tensor): - return rank, {'loss_reduced': rank} - - return torch.rand(512, 8, 256).cuda(), loss_func - - model_a = torch.nn.Linear(4, 1) - model_b = torch.nn.Linear(8, 1) - model_a.vp_stage = 0 - model_b.vp_stage = 1 - - def set_input_tensor(input_tensor): - return None - - model_a.set_input_tensor = set_input_tensor - model_b.set_input_tensor = set_input_tensor - - forward_backward_func = get_forward_backward_func() - assert ( - schedule.get_forward_backward_func() - == schedule.forward_backward_pipelining_with_interleaving - ) - - sequence_length = 512 - micro_batch_size = 8 - hidden_size = 256 - - config = ModelParallelConfig( - pipeline_model_parallel_size=4, - sequence_parallel=False, - pipeline_dtype=torch.float, - virtual_pipeline_model_parallel_size=2, - ) - config.hidden_size = hidden_size - model_a.config = config - model_b.config = config - - mocker.patch("megatron.core.pipeline_parallel.schedules.custom_backward", return_value=2) - - loss_reduced_expected = [ - {'loss_reduced': rank}, - {'loss_reduced': rank}, - {'loss_reduced': rank}, - {'loss_reduced': rank}, - ] - - model_a.model_type = ModelType.encoder_or_decoder - model_b.model_type = ModelType.encoder_or_decoder - losses_reduced = forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=[range(0, 100), range(0, 100)], - model=[model_a, model_b], - num_microbatches=micro_batch_size, - seq_length=sequence_length, - micro_batch_size=micro_batch_size, - decoder_seq_length=256, - forward_only=True, - ) - - for i, j in zip(losses_reduced, loss_reduced_expected): - print(f"losses_reduced: {i} loss_reduced_expected: {j}") - assert i['loss_reduced'] == j['loss_reduced'] - - with pytest.raises(RuntimeError): - model_a.model_type = ModelType.encoder_or_decoder - model_b.model_type = ModelType.encoder_or_decoder - forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=[range(0, 100)], - model=[model_a, model_b], - num_microbatches=7, - seq_length=sequence_length, - micro_batch_size=micro_batch_size, - decoder_seq_length=512, - forward_only=True, - ) - - model_a.model_type = ModelType.encoder_or_decoder - model_b.model_type = ModelType.encoder_or_decoder - losses_reduced = forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=[range(0, 100), range(0, 100)], - model=[model_a, model_b], - num_microbatches=micro_batch_size, - seq_length=sequence_length, - micro_batch_size=micro_batch_size, - decoder_seq_length=sequence_length, - forward_only=True, - ) - - for i, j in zip(losses_reduced, loss_reduced_expected): - print(f"losses_reduced: {i} loss_reduced_expected: {j}") - assert i['loss_reduced'] == j['loss_reduced'] - - Utils.destroy_model_parallel() - - -@pytest.mark.skipif( - version.parse(torch.__version__) < version.parse('2.3.0'), - reason="Device mesh feature requires PyTorch 2.3 or later", -) -@pytest.mark.internal -def test_forward_backward_pipelining_without_interleaving_with_custom_pgs(mocker): - """Test that forward_backward_pipelining_without_interleaving produces the same output - with and without explicit process group parameters.""" - - # Initialize model parallel with pipeline parallelism (no interleaving) - Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4) - - def dummy_step_func(data_iterator, model): - rank = int(os.environ['LOCAL_RANK']) - - def loss_func(output_tensor): - return rank, {'loss_reduced': rank} - - return torch.rand(512, 8, 256).cuda(), loss_func - - # Create model - model = torch.nn.Linear(4, 1) - model.model_type = 'unit-test' - - def return_none(input_tensor): - return None - - model.set_input_tensor = return_none - - sequence_length = 512 - micro_batch_size = 8 - hidden_size = 256 - - config = ModelParallelConfig( - pipeline_model_parallel_size=4, sequence_parallel=False, pipeline_dtype=torch.float - ) - config.hidden_size = hidden_size - config.finalize_model_grads_func = finalize_model_grads - model.config = config - - # Mock custom_backward to avoid actual computation - mocker.patch("megatron.core.pipeline_parallel.schedules.custom_backward", return_value=2) - - # Common arguments for both calls - common_args = { - 'forward_step_func': dummy_step_func, - 'data_iterator': None, - 'model': [model], - 'num_microbatches': micro_batch_size, - 'seq_length': sequence_length, - 'micro_batch_size': micro_batch_size, - 'forward_only': True, - } - - # First call: without providing process group parameters (they'll be created internally) - losses_reduced_default = schedule.forward_backward_pipelining_without_interleaving( - **common_args - ) - - grid = HyperCommGrid([2, 1, 4, 1], ["tp", "cp", "pp", "dp"]) - - pp_group = grid.create_pg("pp") - p2p_communicator = P2PCommunicator(pp_group=pp_group, config=config) - pos_embd_pg, embd_pg = _populate_embedding_and_position_groups(pp_group) - pos_embd_pg = pos_embd_pg if is_pp_first_stage(pp_group) else None - embd_pg = embd_pg if (is_pp_last_stage(pp_group) or is_pp_first_stage(pp_group)) else None - dp_cp_group = grid.create_pg(["dp", "cp"]) - - grad_finalize_pgs = GradFinalizeProcessGroups() - grad_finalize_pgs.tp = grid.create_pg("tp") - grad_finalize_pgs.pp = pp_group - grad_finalize_pgs.embd = embd_pg - grad_finalize_pgs.pos_embd = pos_embd_pg - grad_finalize_pgs.dp_cp = dp_cp_group - grad_finalize_pgs.cp = grid.create_pg("cp") - - losses_reduced_explicit = schedule.forward_backward_pipelining_without_interleaving( - p2p_communicator=p2p_communicator, grad_finalize_pgs=grad_finalize_pgs, **common_args - ) - - assert len(losses_reduced_default) == len( - losses_reduced_explicit - ), "Output lengths should be identical" - - for i, (default_loss, explicit_loss) in enumerate( - zip(losses_reduced_default, losses_reduced_explicit) - ): - assert ( - default_loss == explicit_loss - ), f"Loss at index {i} should be identical between default and explicit PG calls" - Utils.destroy_model_parallel() - - -@pytest.mark.skipif( - version.parse(torch.__version__) < version.parse('2.3.0'), - reason="Device mesh feature requires PyTorch 2.3 or later", -) -@pytest.mark.internal -def test_forward_backward_pipelining_with_interleaving_with_custom_pgs(mocker): - """Test that forward_backward_pipelining_with_interleaving produces the same output - with and without explicit process group parameters.""" - - from megatron.core.enums import ModelType - from megatron.core.pipeline_parallel import get_forward_backward_func - - Utils.initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=4, - virtual_pipeline_model_parallel_size=2, - ) - - def forward_step_func(data_iterator, model): - import os - - rank = int(os.environ['LOCAL_RANK']) - - def loss_func(output_tensor): - return rank, {'loss_reduced': rank} - - return torch.rand(512, 8, 256).cuda(), loss_func - - model = torch.nn.Linear(4, 1) - - def set_input_tensor(input_tensor): - return None - - model.set_input_tensor = set_input_tensor - - forward_backward_func = get_forward_backward_func() - assert ( - schedule.get_forward_backward_func() - == schedule.forward_backward_pipelining_with_interleaving - ) - - sequence_length = 512 - micro_batch_size = 8 - hidden_size = 256 - - config = ModelParallelConfig( - pipeline_model_parallel_size=4, - sequence_parallel=False, - pipeline_dtype=torch.float, - virtual_pipeline_model_parallel_size=2, - ) - config.hidden_size = hidden_size - model.config = config - - mocker.patch("megatron.core.pipeline_parallel.schedules.custom_backward", return_value=2) - - loss_reduced_expected = [ - {'loss_reduced': rank}, - {'loss_reduced': rank}, - {'loss_reduced': rank}, - {'loss_reduced': rank}, - ] - - grid = HyperCommGrid([1, 1, 4, 2], ["tp", "cp", "pp", "dp"]) - pp_group = grid.create_pg("pp") - p2p_communicator = P2PCommunicator(pp_group=pp_group, config=config) - pos_embd_pg, embd_pg = _populate_embedding_and_position_groups(pp_group) - pos_embd_pg = pos_embd_pg if is_pp_first_stage(pp_group) else None - embd_pg = embd_pg if (is_pp_last_stage(pp_group) or is_pp_first_stage(pp_group)) else None - - grad_finalize_pgs = GradFinalizeProcessGroups() - grad_finalize_pgs.tp = grid.create_pg("tp") - grad_finalize_pgs.cp = grid.create_pg("cp") - grad_finalize_pgs.pp = pp_group - grad_finalize_pgs.embd = embd_pg - grad_finalize_pgs.pos_embd = pos_embd_pg - grad_finalize_pgs.dp_cp = grid.create_pg(["dp", "cp"]) - - model.model_type = ModelType.encoder_or_decoder - losses_reduced = forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=[range(0, 100), range(0, 100)], - model=[model, model], - num_microbatches=micro_batch_size, - seq_length=sequence_length, - micro_batch_size=micro_batch_size, - decoder_seq_length=256, - forward_only=True, - grad_finalize_pgs=grad_finalize_pgs, - p2p_communicator=p2p_communicator, - ) - - for i, j in zip(losses_reduced, loss_reduced_expected): - print(f"losses_reduced: {i} loss_reduced_expected: {j}") - assert i['loss_reduced'] == j['loss_reduced'] - - Utils.destroy_model_parallel() - - -def test_forward_backward_no_pipelining_with_custom_pgs(mocker): - """Validate no-pipeline schedule when explicit custom PGs are provided.""" - - from megatron.core.pipeline_parallel import get_forward_backward_func - - Utils.initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=1) - - def forward_step_func(data_iterator, model): - import os - - rank_local = int(os.environ['LOCAL_RANK']) - - def loss_func(output_tensor): - return rank_local, {'loss_reduced': rank_local} - - dummy_inp = torch.ones(1, 4) - return model(dummy_inp), loss_func - - # Simple model. - model = torch.nn.Linear(4, 1) - model.model_type = 'unit-test' - model.set_input_tensor = lambda _tensor: None # type: ignore[assignment] - - # Minimal config. - config = ModelParallelConfig(pipeline_model_parallel_size=1) - model.config = config - - grid = HyperCommGrid([2, 1, 1, 4], ["tp", "cp", "pp", "dp"]) - - pp_group = grid.create_pg("pp") - tp_group = grid.create_pg("tp") - cp_group = grid.create_pg("cp") - pos_embd_pg, embd_pg = _populate_embedding_and_position_groups(pp_group) - dp_cp_group = grid.create_pg(["dp", "cp"]) - - grad_finalize_pgs = GradFinalizeProcessGroups() - grad_finalize_pgs.tp = tp_group - grad_finalize_pgs.cp = cp_group - grad_finalize_pgs.embd = embd_pg - grad_finalize_pgs.pos_embd = pos_embd_pg - grad_finalize_pgs.pp = pp_group - grad_finalize_pgs.dp_cp = dp_cp_group - - forward_backward_func = get_forward_backward_func() - assert forward_backward_func == schedule.forward_backward_no_pipelining - - mocker.patch("megatron.core.pipeline_parallel.schedules.custom_backward", return_value=2) - - losses_reduced = forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=range(0, 10), - model=[model], - num_microbatches=4, - seq_length=None, - micro_batch_size=None, - forward_only=True, - grad_finalize_pgs=grad_finalize_pgs, - ) - - expected = {'loss_reduced': Utils.rank} - for l in losses_reduced: - assert l['loss_reduced'] == expected['loss_reduced'] - - Utils.destroy_model_parallel() diff --git a/tests/unit_tests/post_training/test_modelopt_module_spec.py b/tests/unit_tests/post_training/test_modelopt_module_spec.py deleted file mode 100644 index f27a22390f..0000000000 --- a/tests/unit_tests/post_training/test_modelopt_module_spec.py +++ /dev/null @@ -1,289 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -import inspect -import tempfile - -import pytest -import torch -from packaging.version import Version - -from megatron.core import dist_checkpointing -from megatron.core.inference.contexts import StaticInferenceContext -from megatron.core.models.gpt.gpt_layer_specs import ( - get_gpt_decoder_block_spec, - get_gpt_layer_with_transformer_engine_spec, -) -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec -from megatron.core.models.mamba.mamba_model import MambaModel -from megatron.core.post_training.modelopt.gpt.model_specs import get_gpt_modelopt_spec -from megatron.core.post_training.modelopt.gpt.state_dict_hooks import ( - mcore_gpt_load_te_state_dict_pre_hook, -) -from megatron.core.post_training.modelopt.mamba.model_specs import get_mamba_stack_modelopt_spec -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer import TransformerConfig -from megatron.core.transformer.transformer_config import MLATransformerConfig -from megatron.core.utils import get_te_version -from tests.unit_tests.dist_checkpointing import TempNamedDir -from tests.unit_tests.test_utilities import Utils - - -def model_forward(model: torch.nn.Module, config: TransformerConfig, micro_batch_size: int = 2): - inference_context: StaticInferenceContext = StaticInferenceContext( - max_batch_size=micro_batch_size, max_sequence_length=model.max_sequence_length - ) - prompt_length = model.max_sequence_length - 1 - - # load-context/first-output-token, step/generate - for offset in (0, prompt_length): - if offset == 0: - sequence_length = prompt_length - else: - sequence_length = 1 - inference_context.sequence_len_offset = offset - - data = list(range(sequence_length)) - input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - position_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - attention_mask = torch.ones( - (micro_batch_size, 1, sequence_length, sequence_length), dtype=bool - ).cuda() - - logits = model.forward( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - inference_context=inference_context, - ) - - assert logits.shape[0] == micro_batch_size - assert logits.shape[1] == sequence_length - assert logits.shape[2] == model.vocab_size - - -class TestModelOptGPTModel: - - _test_inference = True - - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - self._dist_checkpoint_name = "standard_gpt_model" - - transformer_config = TransformerConfig( - num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True - ) - self.default_model = GPTModel( - config=transformer_config, - transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec(), - vocab_size=100, - max_sequence_length=4, - ) - # Ensure that a GPTModel can be built with the modelopt spec. - self.modelopt_model = GPTModel( - config=transformer_config, - transformer_layer_spec=get_gpt_modelopt_spec( - transformer_config, remap_te_layernorm=True - ), - vocab_size=100, - max_sequence_length=4, - ) - - def test_sharded_state_dict_restore(self, tmp_path_dist_ckpt): - """Save with the default TE spec and restore using the ModelOpt spec.""" - _dist_checkpoint_name = "default_model" - te_fused_sharded_state_dict = self.default_model.sharded_state_dict() - modelopt_sharded_state_dict = self.modelopt_model.sharded_state_dict() - - with TempNamedDir(tmp_path_dist_ckpt / _dist_checkpoint_name, sync=True) as tmpdirname: - dist_checkpointing.save(te_fused_sharded_state_dict, tmpdirname) - state_dict = dist_checkpointing.load(modelopt_sharded_state_dict, tmpdirname) - self.modelopt_model.load_state_dict(state_dict) - - def test_inference(self): - if not self._test_inference: - return - config: TransformerConfig = self.modelopt_model.config - model = self.modelopt_model.cuda() - model_forward(model, config) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - -class TestModelOptMLAMoE(TestModelOptGPTModel): - - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - - # Early version of TE DotProductAttention does not support - # q, k, v to have different shapes. - self._test_inference = get_te_version() > Version("1.10") - - transformer_config = MLATransformerConfig( - num_layers=2, - hidden_size=512, - num_attention_heads=8, - add_bias_linear=False, - num_moe_experts=2, - moe_layer_freq=[0, 1], - moe_ffn_hidden_size=128, - moe_shared_expert_intermediate_size=128, - qk_layernorm=True, - use_cpu_initialization=True, - ) - default_spec = get_gpt_decoder_block_spec(transformer_config, use_transformer_engine=True) - self.default_model = GPTModel( - config=transformer_config, - transformer_layer_spec=default_spec, - vocab_size=100, - max_sequence_length=8, - ) - modelopt_spec = get_gpt_modelopt_spec(transformer_config, remap_te_layernorm=True) - # Ensure that a GPTModel can be built with the modelopt spec. - self.modelopt_model = GPTModel( - config=transformer_config, - transformer_layer_spec=modelopt_spec, - vocab_size=100, - max_sequence_length=8, - ) - - -class TestModelOptLlama4MoE(TestModelOptGPTModel): - - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - - # Early version of TE DotProductAttention does not support - # q, k, v to have different shapes. - self._test_inference = get_te_version() > Version("1.10") - - transformer_config = TransformerConfig( - num_layers=2, - hidden_size=512, - num_attention_heads=8, - add_bias_linear=False, - num_moe_experts=2, - moe_layer_freq=[0, 1], - moe_ffn_hidden_size=128, - moe_shared_expert_intermediate_size=128, - qk_layernorm=True, - use_cpu_initialization=True, - ) - default_spec = get_gpt_decoder_block_spec( - transformer_config, use_transformer_engine=True, qk_l2_norm=True - ) - self.default_model = GPTModel( - config=transformer_config, - transformer_layer_spec=default_spec, - vocab_size=100, - max_sequence_length=8, - ) - modelopt_spec = get_gpt_modelopt_spec( - transformer_config, remap_te_layernorm=True, qk_l2_norm=True - ) - # Ensure that a GPTModel can be built with the modelopt spec. - self.modelopt_model = GPTModel( - config=transformer_config, - transformer_layer_spec=modelopt_spec, - vocab_size=100, - max_sequence_length=8, - ) - - -class TestModelOptMambaModel(TestModelOptGPTModel): - - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - transformer_config = TransformerConfig( - num_layers=3, hidden_size=256, num_attention_heads=4, use_cpu_initialization=True - ) - - # A Hybrid MambaModel using fused-TE spec (default) - self.default_model = MambaModel( - config=transformer_config, - mamba_stack_spec=mamba_stack_spec, - vocab_size=100, - max_sequence_length=4, - hybrid_override_pattern="M*-", - ) - - # A Hybrid MambaModel using ModelOpt spec (local + TENorm). - self.modelopt_model = MambaModel( - config=transformer_config, - mamba_stack_spec=get_mamba_stack_modelopt_spec(remap_te_layernorm=True), - vocab_size=100, - max_sequence_length=4, - hybrid_override_pattern="M*-", - ) - - -def test_get_gpt_modelopt_spec_interface(): - # Get the function signature - sig = inspect.signature(get_gpt_modelopt_spec) - - # Define the expected signature - expected_params = { - "config": inspect.Parameter.POSITIONAL_OR_KEYWORD, - "local_core_attention": inspect.Parameter.POSITIONAL_OR_KEYWORD, - "remap_te_layernorm": inspect.Parameter.POSITIONAL_OR_KEYWORD, - "real_quant_cfg": inspect.Parameter.POSITIONAL_OR_KEYWORD, - "qk_l2_norm": inspect.Parameter.POSITIONAL_OR_KEYWORD, - "use_arbitrary_attention_mask": inspect.Parameter.POSITIONAL_OR_KEYWORD, - } - - expected_defaults = { - "local_core_attention": False, - "remap_te_layernorm": False, - "real_quant_cfg": "None", - "qk_l2_norm": False, - "use_arbitrary_attention_mask": False, - } - - # Check expected parameters are in function signature - for param_name, param_kind in expected_params.items(): - assert param_name in sig.parameters, f"Unexpected parameter: {param_name}" - assert ( - param_kind is sig.parameters[param_name].kind - ), f"Wrong kind for parameter: {param_name}" - - # Check default values - sig_defaults = { - k: v.default for k, v in sig.parameters.items() if v.default is not inspect.Parameter.empty - } - for k, v in expected_defaults.items(): - assert ( - k in sig_defaults and v == sig_defaults[k] - ), f"Default value of {sig_defaults[k]} does not match the expected value of {v} for parameter {k}." - - -def test_get_mamba_stack_modelopt_spec_interface(): - # Get the function signature - sig = inspect.signature(get_mamba_stack_modelopt_spec) - - # Define the expected signature - expected_params = { - "local_core_attention": inspect.Parameter.POSITIONAL_OR_KEYWORD, - "remap_te_layernorm": inspect.Parameter.POSITIONAL_OR_KEYWORD, - } - - expected_defaults = {"local_core_attention": False, "remap_te_layernorm": False} - - # Check expected parameters are in function signature - for param_name, param_kind in expected_params.items(): - assert param_name in sig.parameters, f"Unexpected parameter: {param_name}" - assert ( - param_kind is sig.parameters[param_name].kind - ), f"Wrong kind for parameter: {param_name}" - - # Check default values - sig_defaults = { - k: v.default for k, v in sig.parameters.items() if v.default is not inspect.Parameter.empty - } - for k, v in expected_defaults.items(): - assert ( - k in sig_defaults and v == sig_defaults[k] - ), f"Default value of {sig_defaults[k]} does not match the expected value of {v} for parameter {k}." diff --git a/tests/unit_tests/ssm/test_mamba_block.py b/tests/unit_tests/ssm/test_mamba_block.py deleted file mode 100644 index e72d05c6f1..0000000000 --- a/tests/unit_tests/ssm/test_mamba_block.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -import pytest -import torch - -from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec -from megatron.core.process_groups_config import ModelCommProcessGroups -from megatron.core.ssm.mamba_block import MambaStack -from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols -from megatron.core.ssm.mamba_layer import MambaLayer -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer import TransformerConfig -from megatron.core.transformer.attention import SelfAttention -from megatron.core.transformer.mlp import MLP -from megatron.core.transformer.transformer_layer import TransformerLayer -from tests.unit_tests.test_utilities import Utils - - -@pytest.mark.internal -class TestMambaBlock: - - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - - def get_model_comm_pgs(self): - return ModelCommProcessGroups.use_mpu_process_groups(required_pgs=['tp', 'pp', 'cp']) - - def get_mamba_block(self, hybrid_override_pattern): - transformer_config = TransformerConfig( - hidden_size=256, # The Mamba layer places several constraints on this - # Need to specify num_attention_heads and num_layers or TransformerConfig - # will generate errors. - num_layers=len(hybrid_override_pattern), - num_attention_heads=4, - use_cpu_initialization=True, - ) - modules = mamba_stack_spec.submodules - return MambaStack( - transformer_config, - modules, - hybrid_override_pattern=hybrid_override_pattern, - model_comm_pgs=self.get_model_comm_pgs(), - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_gpu_forward(self): - """Test GPU forward pass.""" - hybrid_override_pattern = Symbols.MAMBA + Symbols.ATTENTION + Symbols.MLP - block = self.get_mamba_block(hybrid_override_pattern) - block.cuda() - micro_batch_size = 2 - sequence_length = 32 - hidden_states = torch.ones((sequence_length, micro_batch_size, block.config.hidden_size)) - hidden_states = hidden_states.cuda() - attention_mask = torch.ones( - (micro_batch_size, 1, sequence_length, sequence_length), dtype=bool - ) - attention_mask = attention_mask.cuda() - output = block(hidden_states, attention_mask=attention_mask) - assert output.shape[0] == sequence_length - assert output.shape[1] == micro_batch_size - assert output.shape[2] == block.config.hidden_size - assert output.dtype == torch.float32 - - def test_layer_types(self): - """ - Make sure that the layer types specified with hybrid_override_pattern - were honored. - """ - hybrid_override_pattern = Symbols.MAMBA + Symbols.ATTENTION + Symbols.MLP - block = self.get_mamba_block(hybrid_override_pattern) - layers = block.layers - # Note that this matches the order specified by hybrid_override_pattern in setup_method - assert isinstance(layers[0], MambaLayer) - assert isinstance(layers[1], TransformerLayer) - assert isinstance(layers[1].self_attention, SelfAttention) - assert isinstance(layers[2], TransformerLayer) - assert isinstance(layers[2].mlp, MLP) - - def test_invalid_layer_types_cause_failure(self): - invalid_symbol = '+' - assert invalid_symbol not in Symbols.VALID # sanity check. - hybrid_override_pattern = Symbols.MAMBA + Symbols.ATTENTION + Symbols.MLP + invalid_symbol - # _allocate_override() in mamba_hybrid_layer_allocation.py throws a ValueError. - with pytest.raises(ValueError): - block = self.get_mamba_block(hybrid_override_pattern) diff --git a/tests/unit_tests/ssm/test_mamba_context_parallel.py b/tests/unit_tests/ssm/test_mamba_context_parallel.py deleted file mode 100644 index 59cd1fb061..0000000000 --- a/tests/unit_tests/ssm/test_mamba_context_parallel.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -import math - -import pytest -import torch -import torch.nn as nn - -from megatron.core import parallel_state -from megatron.core.ssm.mamba_context_parallel import MambaContextParallel -from tests.unit_tests.test_utilities import Utils - - -@pytest.mark.internal -class TestMambaContextParallel: - - @pytest.mark.parametrize( - "ngroups_local_tp, cp_size, D_has_hdim", - [ - (16, 4, False), # ngroups_local_tp > cp_size - (8, 8, False), # ngroups_local_tp == cp_size - (4, 8, False), # ngroups_local_tp < cp_size - (1, 4, True), # ngroups_local_tp < cp_size - ], - ) - def test_forward(self, ngroups_local_tp, cp_size, D_has_hdim): - Utils.initialize_model_parallel(context_parallel_size=cp_size) - - dtype = torch.bfloat16 - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - headdim = 64 - d_inner_local_tp = cp_size * headdim - nheads_local_tp = d_inner_local_tp // headdim - d_state = 128 - - conv_dim = d_inner_local_tp + 2 * ngroups_local_tp * d_state - conv_bias = True - d_conv = 4 - # weight shape: [conv_dim, 1, d_conv] - # bias shape: [conv_dim] - conv1d_cp1 = nn.Conv1d( - in_channels=conv_dim, - out_channels=conv_dim, - bias=conv_bias, - kernel_size=d_conv, - groups=conv_dim, - padding=d_conv - 1, - device=device, - dtype=dtype, - ) - - dt_bias_cp1 = torch.rand(nheads_local_tp, device=device, dtype=dtype) - A_log_cp1 = torch.rand(nheads_local_tp, device=device, dtype=dtype) - D_cp1 = torch.rand( - d_inner_local_tp if D_has_hdim else nheads_local_tp, device=device, dtype=dtype - ) - - cp = MambaContextParallel( - cp_group=parallel_state.get_context_parallel_group(), - d_inner_local_tp=d_inner_local_tp, - nheads_local_tp=nheads_local_tp, - ngroups_local_tp=ngroups_local_tp, - d_state=d_state, - conv1d_cp1=conv1d_cp1, - dt_bias_cp1=dt_bias_cp1, - A_log_cp1=A_log_cp1, - D_cp1=D_cp1, - D_has_hdim=D_has_hdim, - ) - - sequence_length = cp_size * 2 - batch_size = 1 - - # pre_conv_ssm - sequence_length_cp = sequence_length // cp_size - in_hidden = 2 * d_inner_local_tp + 2 * ngroups_local_tp * d_state + nheads_local_tp - in_shape = [sequence_length_cp, batch_size, in_hidden] - in_tensor = torch.rand(in_shape, device=device, dtype=dtype) - pre_conv_ssm_tensor = cp.pre_conv_ssm(in_tensor) - if ngroups_local_tp < cp_size: - repeat_groups = cp_size // ngroups_local_tp - else: - repeat_groups = 1 - repeated_groups_size = ngroups_local_tp * d_state * repeat_groups - expected_hidden = ( - 2 * d_inner_local_tp + 2 * repeated_groups_size + nheads_local_tp - ) // cp_size - assert list(pre_conv_ssm_tensor.shape) == [sequence_length, batch_size, expected_hidden] - - d_inner_local_tpcp = d_inner_local_tp // cp_size - - # post_conv_ssm - y_shape = [sequence_length, batch_size, d_inner_local_tpcp] - y_tensor = torch.rand(y_shape, device=device, dtype=dtype) - y_tensor = cp.post_conv_ssm(y_tensor) - assert list(y_tensor.shape) == [sequence_length_cp, batch_size, d_inner_local_tp] - - # conv1d - conv_dim_cp = (d_inner_local_tp + 2 * repeated_groups_size) // cp_size - conv_input_shape = [batch_size, conv_dim_cp, sequence_length] - conv_input = torch.rand(conv_input_shape, device=device, dtype=dtype) - conv_output = cp.conv1d(conv_input) - assert list(conv_output.shape) == [batch_size, conv_dim_cp, sequence_length + d_conv - 1] - - # conv1d_channels - assert cp.conv1d_channels() == conv_dim_cp - - # get_conv1d_weight - assert list(cp.get_conv1d_weight().shape) == [conv_dim_cp, 1, d_conv] - - # get_conv1d_bias - assert list(cp.get_conv1d_bias().shape) == [conv_dim_cp] - - nheads_local_tpcp = nheads_local_tp // cp_size - - # get_dt_bias - assert list(cp.get_dt_bias().shape) == [nheads_local_tpcp] - - # get_A_log - assert list(cp.get_A_log().shape) == [nheads_local_tpcp] - - # get_D - assert list(cp.get_D().shape) == [d_inner_local_tpcp if D_has_hdim else nheads_local_tpcp] - - Utils.destroy_model_parallel() - - @pytest.mark.parametrize( - "nheads_tp, ngroups_tp, cp_size, expected_error_message", - [ - (3, 2, 2, "nheads must be evenly divisible by tp_size \\* cp_size"), - (12, 3, 4, "cp_size must be evenly divisible by ngroups/tp_size"), - (12, 3, 2, "ngroups must be evenly divisible by tp_size \\* cp_size"), - ], - ) - def test_error_check(self, nheads_tp, ngroups_tp, cp_size, expected_error_message): - Utils.initialize_model_parallel(context_parallel_size=cp_size) - with pytest.raises(AssertionError, match=expected_error_message): - cp = MambaContextParallel( - cp_group=parallel_state.get_context_parallel_group(), - d_inner_local_tp=nheads_tp, - nheads_local_tp=nheads_tp, - ngroups_local_tp=ngroups_tp, - d_state=None, - conv1d_cp1=None, - dt_bias_cp1=None, - A_log_cp1=None, - D_cp1=None, - D_has_hdim=False, - ) - Utils.destroy_model_parallel() diff --git a/tests/unit_tests/ssm/test_mamba_hybrid_layer_allocation.py b/tests/unit_tests/ssm/test_mamba_hybrid_layer_allocation.py deleted file mode 100644 index 77d02c6960..0000000000 --- a/tests/unit_tests/ssm/test_mamba_hybrid_layer_allocation.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -import math -import re - -import pytest -import torch - -from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols, allocate_layers - - -@pytest.mark.internal -class TestMambaHybridLayerAllocation: - - def test_hybrid_layer_allocation(self): - # The format for the test cases is: - # (layers_count, attention_ratio, mlp_ratio, override_pattern). - test_cases = [ - (9, 0.0, 0.0, "M*-M*-M*-"), - (9, 0.0, 0.0, "MMMMMMMMM"), - (30, 0.0, 0.0, None), - (8, 0.25, 0.25, "MM*-MM*-"), - (8, 0.5, 0.25, "M**-M**-"), - (48, 0.5, 0.2, None), - ] - for test in test_cases: - (layers_count, attention_ratio, mlp_ratio, override_pattern) = test - - layer_types = allocate_layers(*test) - - # Check that return value is in the right format. - assert isinstance(layer_types, list) - assert layers_count == len(layer_types) - - # Make sure all the layers are valid. - for layer_type in layer_types: - assert layer_type in Symbols.VALID - - # Make sure each layer is as requested by override_pattern. - if override_pattern is not None: - assert len(override_pattern) == len(layer_types) - for index, layer_type in enumerate(layer_types): - assert override_pattern[index] == layer_types[index] - else: - # Make sure the count of each type of layer is correct. - counts = {layer_type: 0 for layer_type in Symbols.VALID} # Initialize all to zero. - for layer_type in layer_types: - assert layer_type in counts - counts[layer_type] += 1 - # Check the ratios. - remainder = 1.0 - attention_ratio - mlp_ratio - assert remainder >= 0 - assert int(attention_ratio * layers_count + 0.5) == counts[Symbols.ATTENTION] - assert int(mlp_ratio * layers_count + 0.5) == counts[Symbols.MLP] - assert int(remainder * layers_count + 0.5) == counts[Symbols.MAMBA] - - # Make sure the ratios are as requested. - # This code is not working yet because capsys seems broken in Megatron. - # captured = capsys.readouterr() # Remove this output from the capture buffer. - # out = captured.out # Get stdout. - # if attention_ratio != 0 or mlp_ratio != 0: - # assert ( - # match := re.search(r'Actual attention ratio: (1\.0|0\.[0-9]+)\.', out) - # ) and math.isclose(match.group(1), attention_ratio) - # assert ( - # match := re.search(r'Actual mlp ratio: (1\.0|0\.[0-9]+)\.', out) - # ) and math.isclose(match.group(1), mlp_ratio) - - @pytest.mark.xfail(raises=ValueError) - def test_wrong_length_override_pattern(self): - # This override_pattern is too short. - layer_types = allocate_layers(9, 0.0, 0.0, "M*-M*-") - - @pytest.mark.xfail(raises=ValueError) - def test_wrong_number_of_layer_types_in_override_pattern(self): - # This override_pattern has too many mlps and not enough attention - layer_types = allocate_layers(8, 0.5, 0.25, "M*--M**-") diff --git a/tests/unit_tests/ssm/test_mamba_layer.py b/tests/unit_tests/ssm/test_mamba_layer.py deleted file mode 100644 index 59f9f832f5..0000000000 --- a/tests/unit_tests/ssm/test_mamba_layer.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -import pytest -import torch - -from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec -from megatron.core.process_groups_config import ModelCommProcessGroups -from megatron.core.ssm.mamba_layer import MambaLayer -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer import TransformerConfig -from tests.unit_tests.test_utilities import Utils - - -@pytest.mark.internal -class TestMambaLayer: - - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - transformer_config = TransformerConfig( - hidden_size=256, # The Mamba layer places several constraints on this - # Need to specify num_attention_heads and num_layers or TransformerConfig - # will generate errors. - num_layers=1, - num_attention_heads=1, - use_cpu_initialization=True, - ) - modules = mamba_stack_spec.submodules.mamba_layer.submodules - model_comm_pgs = ModelCommProcessGroups.use_mpu_process_groups(required_pgs=['tp', 'cp']) - self.layer = MambaLayer(transformer_config, modules, model_comm_pgs=model_comm_pgs) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_gpu_forward(self): - layer = self.layer - layer.cuda() - micro_batch_size = 2 - sequence_length = 32 - hidden_states = torch.ones((sequence_length, micro_batch_size, layer.config.hidden_size)) - hidden_states = hidden_states.cuda() - attention_mask = torch.ones( - (micro_batch_size, 1, sequence_length, sequence_length), dtype=bool - ) - attention_mask = attention_mask.cuda() - output = layer(hidden_states, attention_mask=attention_mask) - assert output.shape[0] == sequence_length - assert output.shape[1] == micro_batch_size - assert output.shape[2] == layer.config.hidden_size - assert output.dtype == torch.float32 diff --git a/tests/unit_tests/ssm/test_mamba_mixer.py b/tests/unit_tests/ssm/test_mamba_mixer.py deleted file mode 100644 index aa7278c704..0000000000 --- a/tests/unit_tests/ssm/test_mamba_mixer.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -import pytest -import torch - -from megatron.core.inference.contexts.static_context import StaticInferenceContext -from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec -from megatron.core.process_groups_config import ModelCommProcessGroups -from megatron.core.ssm.mamba_mixer import MambaMixer -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer import TransformerConfig -from tests.unit_tests.test_utilities import Utils - - -@pytest.mark.internal -class TestMambaMixer: - - def setup_method(self, method): - pass - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def get_mixer(self, tp_size=1, cp_size=1, use_mem_eff_path=True): - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp_size, - pipeline_model_parallel_size=1, - context_parallel_size=cp_size, - ) - model_parallel_cuda_manual_seed(123) - transformer_config = TransformerConfig( - hidden_size=256, # The Mamba layer places several constraints on this - # Need to specify num_attention_heads and num_layers or TransformerConfig - # will generate errors. - num_layers=1, - num_attention_heads=1, - use_cpu_initialization=True, - ) - modules = mamba_stack_spec.submodules.mamba_layer.submodules.mixer.submodules - model_comm_pgs = ModelCommProcessGroups.use_mpu_process_groups(required_pgs=['tp', 'cp']) - mixer = MambaMixer( - transformer_config, - modules, - transformer_config.hidden_size, - layer_number=1, - use_mem_eff_path=use_mem_eff_path, - model_comm_pgs=model_comm_pgs, - ) - mixer.cuda() - return mixer - - @pytest.mark.parametrize( - "tp_size,cp_size,use_mem_eff_path", - [ - (1, 1, True), - (1, 1, False), - (8, 1, True), - (4, 2, True), - (2, 4, True), - (1, 8, True), - (1, 8, False), - ], - ) - def test_gpu_forward(self, tp_size, cp_size, use_mem_eff_path): - mixer = self.get_mixer(1, 1, use_mem_eff_path) - micro_batch_size = 2 - sequence_length = 32 - hidden_states = torch.ones((sequence_length, micro_batch_size, mixer.config.hidden_size)) - hidden_states = hidden_states.cuda() - output, bias = mixer(hidden_states) - assert mixer.config.mamba_num_heads == None - assert output.shape[0] == sequence_length - assert output.shape[1] == micro_batch_size - assert output.shape[2] == mixer.config.hidden_size - assert output.dtype == torch.float32 - - def test_variable_batch_size_inference(self): - mixer = self.get_mixer() - - # Test cases where batch size decreases, remains the same, and increases - micro_batch_sizes = [4, 2, 2, 8] - sequence_length = 32 - inference_context = StaticInferenceContext( - max_batch_size=max(micro_batch_sizes), max_sequence_length=sequence_length - ) - - for micro_batch_size in micro_batch_sizes: - inference_context.max_seqlen = inference_context.max_sequence_length - inference_context.seqlen_offset = inference_context.sequence_len_offset - hidden_states = torch.ones( - (sequence_length, micro_batch_size, mixer.config.hidden_size) - ) - hidden_states = hidden_states.cuda() - output, bias = mixer(hidden_states, inference_context=inference_context) - assert mixer.config.mamba_num_heads == None - assert output.shape[0] == sequence_length - assert output.shape[1] == micro_batch_size - assert output.shape[2] == mixer.config.hidden_size - assert output.dtype == torch.float32 - - -class TestMambaMixerErrorChecks: - - @pytest.mark.parametrize( - "hidden_size, ngroups, tp_size, expected_error_message", - [ - (65, 8, 1, "d_inner must be evenly divisible by headdim"), - (96, 8, 2, "nheads must be evenly divisble by tp_size"), # nheads = 3 - (128, 2, 4, "ngroups must be evenly divisible by tp_size"), - (128, 8, 4, "nheads must be evenly divisible by ngroups"), # nheads = 4 - ], - ) - def test_error_check(self, hidden_size, ngroups, tp_size, expected_error_message): - Utils.initialize_model_parallel(tp_size) - transformer_config = TransformerConfig( - hidden_size=hidden_size, - num_layers=1, - num_attention_heads=1, - use_cpu_initialization=True, - mamba_num_groups=ngroups, - ) - submodules = mamba_stack_spec.submodules.mamba_layer.submodules.mixer.submodules - model_comm_pgs = ModelCommProcessGroups.use_mpu_process_groups(required_pgs=['tp', 'cp']) - with pytest.raises(AssertionError, match=expected_error_message): - MambaMixer( - transformer_config, - submodules, - transformer_config.hidden_size, - model_comm_pgs=model_comm_pgs, - ) - Utils.destroy_model_parallel() diff --git a/tests/unit_tests/tensor_parallel/test_cross_entropy.py b/tests/unit_tests/tensor_parallel/test_cross_entropy.py deleted file mode 100644 index 66982fd234..0000000000 --- a/tests/unit_tests/tensor_parallel/test_cross_entropy.py +++ /dev/null @@ -1,34 +0,0 @@ -import numpy as np -import torch - -from megatron.core.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy -from tests.unit_tests.test_utilities import Utils - - -def test_vocab_parallel_cross_entropy(): - Utils.initialize_model_parallel(4, 2) - vocab_parallel_logits = torch.range(0, 7).repeat(16, 4).cuda() - target = torch.arange(0, 32, 2).cuda() - output = vocab_parallel_cross_entropy(vocab_parallel_logits, target) - expected_output = torch.tensor( - [ - 10.2309, - 8.2309, - 6.2309, - 4.2309, - 10.2309, - 8.2309, - 6.2309, - 4.2309, - 10.2309, - 8.2309, - 6.2309, - 4.2309, - 10.2309, - 8.2309, - 6.2309, - 4.2309, - ] - ).cuda() - assert torch.equal(torch.round(expected_output), torch.round(output)) - Utils.destroy_model_parallel() diff --git a/tests/unit_tests/tensor_parallel/test_data.py b/tests/unit_tests/tensor_parallel/test_data.py deleted file mode 100644 index 211d48b4fd..0000000000 --- a/tests/unit_tests/tensor_parallel/test_data.py +++ /dev/null @@ -1,23 +0,0 @@ -import torch - -from megatron.core.tensor_parallel.data import broadcast_data -from tests.unit_tests.test_utilities import Utils - - -def test_broadcast_data(): - Utils.initialize_model_parallel(2, 4) - input_data = { - 0: torch.ones((8, 8)).cuda() * 0.0, - 1: torch.ones((8, 8)).cuda() * 1.0, - 2: torch.ones((8, 8)).cuda() * 2.0, - 3: torch.ones((8, 8)).cuda() * 3.0, - 4: torch.ones((8, 8)).cuda() * 4.0, - 5: torch.ones((8, 8)).cuda() * 5.0, - 6: torch.ones((8, 8)).cuda() * 6.0, - 7: torch.ones((8, 8)).cuda() * 7.0, - } - dtype = torch.float32 - actual_output = broadcast_data([0, 1], input_data, dtype) - assert torch.equal(actual_output[0], input_data[0]) - assert torch.equal(actual_output[1], input_data[1]) - Utils.destroy_model_parallel() diff --git a/tests/unit_tests/tensor_parallel/test_initialization.py b/tests/unit_tests/tensor_parallel/test_initialization.py deleted file mode 100644 index e0d835f1e7..0000000000 --- a/tests/unit_tests/tensor_parallel/test_initialization.py +++ /dev/null @@ -1,198 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import pytest -import torch - -import megatron.core.parallel_state as ps -from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear -from megatron.core.tensor_parallel.layers import ( - ColumnParallelLinear, - RowParallelLinear, - VocabParallelEmbedding, -) -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.transformer_config import TransformerConfig -from tests.unit_tests.test_utilities import Utils - - -class Test: - - transformer_config = TransformerConfig( - num_layers=1, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_embedding_init(self): - - Utils.initialize_model_parallel(1, 1) - torch.manual_seed(42) - model_parallel_cuda_manual_seed(42) - - tp1 = VocabParallelEmbedding( - num_embeddings=16, - embedding_dim=4, - init_method=self.transformer_config.init_method, - config=self.transformer_config, - ).weight - Utils.destroy_model_parallel() - - Utils.initialize_model_parallel(4, 1) - torch.manual_seed(42) - model_parallel_cuda_manual_seed(41) # intentionally different. - tp4 = VocabParallelEmbedding( - num_embeddings=16, - embedding_dim=4, - init_method=self.transformer_config.init_method, - config=self.transformer_config, - ).weight - - rank = ps.get_tensor_model_parallel_rank() - assert tp4.shape[0] * 4 == tp1.shape[0] - assert torch.equal(tp1[rank * 4 : (rank + 1) * 4], tp4) - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_row_init(self): - - Utils.initialize_model_parallel(1, 1) - torch.manual_seed(42) - model_parallel_cuda_manual_seed(42) - - tp1 = RowParallelLinear( - input_size=16, - output_size=16, - init_method=self.transformer_config.init_method, - bias=True, - input_is_parallel=False, - config=self.transformer_config, - skip_bias_add=False, - ).weight - Utils.destroy_model_parallel() - - Utils.initialize_model_parallel(4, 1) - torch.manual_seed(42) - model_parallel_cuda_manual_seed(41) # intentionally different. - tp4 = RowParallelLinear( - input_size=16, - output_size=16, - init_method=self.transformer_config.init_method, - bias=True, - input_is_parallel=False, - config=self.transformer_config, - skip_bias_add=False, - ).weight - - rank = ps.get_tensor_model_parallel_rank() - assert tp4.shape[1] * 4 == tp1.shape[1] - assert torch.equal(tp1[:, rank * 4 : (rank + 1) * 4], tp4) - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_col_init(self): - - Utils.initialize_model_parallel(1, 1) - torch.manual_seed(42) - model_parallel_cuda_manual_seed(42) - - tp1 = ColumnParallelLinear( - input_size=16, - output_size=16, - init_method=self.transformer_config.init_method, - bias=True, - config=self.transformer_config, - skip_bias_add=False, - ).weight - Utils.destroy_model_parallel() - - Utils.initialize_model_parallel(4, 1) - torch.manual_seed(42) - model_parallel_cuda_manual_seed(41) # intentionally different. - tp4 = ColumnParallelLinear( - input_size=16, - output_size=16, - init_method=self.transformer_config.init_method, - bias=True, - config=self.transformer_config, - skip_bias_add=False, - ).weight - - rank = ps.get_tensor_model_parallel_rank() - assert tp4.shape[0] * 4 == tp1.shape[0] - assert torch.equal(tp1[rank * 4 : (rank + 1) * 4], tp4) - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.timeout(100) - def test_te_col_init(self): - - Utils.initialize_model_parallel(1, 1) - torch.manual_seed(42) - model_parallel_cuda_manual_seed(42) - - tp1 = TEColumnParallelLinear( - input_size=16, - output_size=16, - init_method=self.transformer_config.init_method, - bias=True, - config=self.transformer_config, - skip_bias_add=False, - gather_output=False, - is_expert=False, - ).weight - Utils.destroy_model_parallel() - - Utils.initialize_model_parallel(4, 1) - torch.manual_seed(42) - model_parallel_cuda_manual_seed(41) # intentionally different. - tp4 = TEColumnParallelLinear( - input_size=16, - output_size=16, - init_method=self.transformer_config.init_method, - bias=True, - config=self.transformer_config, - skip_bias_add=False, - gather_output=False, - is_expert=False, - ).weight - - if torch.distributed.get_rank() == 0: - assert tp4.shape[0] * 4 == tp1.shape[0] - assert torch.allclose(tp1[:4], tp4) - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.timeout(100) - def test_te_row_init(self): - - Utils.initialize_model_parallel(1, 1) - torch.manual_seed(42) - model_parallel_cuda_manual_seed(42) - - tp1 = TERowParallelLinear( - input_size=16, - output_size=16, - init_method=self.transformer_config.init_method, - bias=True, - input_is_parallel=True, - config=self.transformer_config, - skip_bias_add=False, - is_expert=False, - ).weight - Utils.destroy_model_parallel() - - Utils.initialize_model_parallel(4, 1) - torch.manual_seed(42) - model_parallel_cuda_manual_seed(41) # intentionally different. - tp4 = TERowParallelLinear( - input_size=16, - output_size=16, - init_method=self.transformer_config.init_method, - bias=True, - input_is_parallel=True, - config=self.transformer_config, - skip_bias_add=False, - is_expert=False, - ).weight - - if torch.distributed.get_rank() == 0: - assert tp4.shape[1] * 4 == tp1.shape[1] - assert torch.allclose(tp1[:, :4], tp4) diff --git a/tests/unit_tests/tensor_parallel/test_layers.py b/tests/unit_tests/tensor_parallel/test_layers.py deleted file mode 100644 index d635e164d1..0000000000 --- a/tests/unit_tests/tensor_parallel/test_layers.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -import pytest -import torch - -from megatron.core.tensor_parallel.layers import linear_with_frozen_weight -from megatron.core.tensor_parallel.mappings import gather_from_tensor_model_parallel_region -from tests.unit_tests.test_utilities import Utils - - -@pytest.mark.parametrize("tensor_parallel,allreduce_dgrad", [(1, False), (8, True)]) -def test_LinearWithFrozenWeight(tensor_parallel, allreduce_dgrad): - Utils.initialize_model_parallel(tensor_parallel, 1) - - size_per_partition = int(8 / tensor_parallel) - - # Input is an 8x8 identity matrix. - input_data = torch.eye(8).cuda() - input_data.requires_grad = True - - # Weight is an 8x8 matrix of all ones. If tensor parallelism > 1, the weight is partitioned evenly across GPUs. - weight = torch.ones((size_per_partition, 8)).cuda() - - # Bias is a vector of length 8 of all zeros. If tensor parallelism > 1, the bias is partitioned evenly across GPUs - bias = torch.zeros((size_per_partition)).cuda() - - gradient_accumulation_fusion = False - sequence_parallel = False - grad_output_buffer = None - wgrad_deferral_limit = None - - output_parallel = linear_with_frozen_weight( - input_data, - weight, - bias, - gradient_accumulation_fusion, - allreduce_dgrad, - sequence_parallel, - grad_output_buffer, - wgrad_deferral_limit, - ) - output = gather_from_tensor_model_parallel_region( - output_parallel - ) # no-op if tensor_parallel == 1. - output.sum().backward() - - expected_output = torch.ones(8).cuda() - expected_grad = 8 * torch.ones(8).cuda() - - assert torch.allclose(output, expected_output) - assert torch.allclose(input_data.grad, expected_grad) - - Utils.destroy_model_parallel() diff --git a/tests/unit_tests/tensor_parallel/test_mappings.py b/tests/unit_tests/tensor_parallel/test_mappings.py deleted file mode 100644 index 16751ab628..0000000000 --- a/tests/unit_tests/tensor_parallel/test_mappings.py +++ /dev/null @@ -1,199 +0,0 @@ -import pytest -import torch - -from megatron.core.tensor_parallel import mappings -from megatron.core.utils import get_tensor_model_parallel_group_if_none -from tests.unit_tests.test_utilities import Utils - - -@pytest.mark.internal -def test_CopyToModelParallelRegion(): - Utils.initialize_model_parallel(4, 2) - input_data = torch.ones((1)).cuda() * Utils.rank - - tp_group = get_tensor_model_parallel_group_if_none(tp_group=None) - - class Ctx: - group = tp_group - - output_data, _ = mappings._CopyToModelParallelRegion.backward(Ctx(), input_data) - result = torch.ones(1).cuda() - result = result * 22 if Utils.rank >= 4 else result * 6 - assert torch.equal(output_data, result) - assert torch.equal(input_data, mappings.copy_to_tensor_model_parallel_region(input_data)) - assert torch.equal( - input_data, mappings._CopyToModelParallelRegion.symbolic(None, input_data, tp_group) - ) - Utils.destroy_model_parallel() - - -@pytest.mark.internal -def test_ReduceFromModelParallelRegion(): - Utils.initialize_model_parallel(4, 2) - input_data = torch.ones((1)).cuda() * Utils.rank - - tp_group = get_tensor_model_parallel_group_if_none(tp_group=None) - output_data = mappings._ReduceFromModelParallelRegion.symbolic(None, input_data, tp_group) - - result = torch.ones(1).cuda() - result = result * 22 if Utils.rank >= 4 else result * 6 - assert torch.equal(output_data, result) - - input_data = torch.ones((1)).cuda() * Utils.rank - assert torch.equal(mappings.reduce_from_tensor_model_parallel_region(input_data), result) - - class Ctx: - group = tp_group - - output_data, _ = mappings._ReduceFromModelParallelRegion.backward(Ctx(), input_data) - assert torch.equal(input_data, output_data) - Utils.destroy_model_parallel() - - -@pytest.mark.internal -def test_ScatterToModelParallelRegion(): - Utils.initialize_model_parallel(4, 2) - input_data = torch.rand((8, 4)).cuda() - - tp_group = get_tensor_model_parallel_group_if_none(tp_group=None) - output_data = mappings.scatter_to_tensor_model_parallel_region(input_data) - - req_dim = int(Utils.rank % (Utils.world_size / 2)) - assert torch.equal(output_data, input_data[:, req_dim].reshape((8, 1))) - output_data = mappings._ScatterToModelParallelRegion.symbolic(None, input_data, tp_group) - assert torch.equal(output_data, input_data[:, req_dim].reshape((8, 1))) - - input_data = torch.ones(8).cuda() * Utils.rank - - class Ctx: - group = tp_group - - actual_output_data, _ = mappings._ScatterToModelParallelRegion.backward(Ctx(), input_data) - expected_output = torch.cat( - (torch.ones(8) * 0, torch.ones(8) * 1, torch.ones(8) * 2, torch.ones(8) * 3) - ).cuda() - if Utils.rank >= 4: - expected_output = expected_output + 4 - assert torch.equal(actual_output_data, expected_output) - Utils.destroy_model_parallel() - - -@pytest.mark.internal -def test_GatherFromModelParallelRegion(): - Utils.initialize_model_parallel(4, 2) - input_data = torch.rand((8, 4)).cuda() - - tp_group = get_tensor_model_parallel_group_if_none(tp_group=None) - req_dim = int(Utils.rank % (Utils.world_size / 2)) - - class Ctx: - group = tp_group - - output_data, _ = mappings._GatherFromModelParallelRegion.backward(Ctx(), input_data) - assert torch.equal(output_data, input_data[:, req_dim].reshape((8, 1))) - - input_data = torch.ones(8).cuda() * Utils.rank - actual_output_data = mappings.gather_from_tensor_model_parallel_region(input_data) - expected_output = torch.cat( - (torch.ones(8) * 0, torch.ones(8) * 1, torch.ones(8) * 2, torch.ones(8) * 3) - ).cuda() - if Utils.rank >= 4: - expected_output = expected_output + 4 - assert torch.equal(actual_output_data, expected_output) - assert torch.equal( - mappings._GatherFromModelParallelRegion.symbolic(None, input_data, tp_group), - expected_output, - ) - Utils.destroy_model_parallel() - - -@pytest.mark.internal -def test_ScatterToSequenceParallelRegion(): - Utils.initialize_model_parallel(4, 2) - input_data = torch.rand((8, 4)).cuda() - - tp_group = get_tensor_model_parallel_group_if_none(tp_group=None) - req_dim = int(Utils.rank % (Utils.world_size / 2)) * 2 - output_data = mappings._ScatterToSequenceParallelRegion.symbolic(None, input_data, tp_group) - assert torch.equal(output_data, input_data[req_dim : req_dim + 2, :]) - output_data = mappings.scatter_to_sequence_parallel_region(input_data) - assert torch.equal(output_data, input_data[req_dim : req_dim + 2, :]) - - input_data = torch.ones(4).cuda() * Utils.rank - - class Ctx: - group = tp_group - - output_data, _ = mappings._ScatterToModelParallelRegion.backward(Ctx(), input_data) - expected_output = torch.concat( - (torch.ones(4) * 0, torch.ones(4) * 1, torch.ones(4) * 2, torch.ones(4) * 3) - ).cuda() - if Utils.rank >= 4: - expected_output = expected_output + 4 - assert torch.equal(output_data, expected_output) - Utils.destroy_model_parallel() - - -@pytest.mark.internal -def test_GatherFromSequenceParallelRegion(): - Utils.initialize_model_parallel(4, 2) - input_data = torch.ones(4).cuda() * Utils.rank - - tp_group = get_tensor_model_parallel_group_if_none(tp_group=None) - output_data = mappings.gather_from_sequence_parallel_region(input_data) - expected_output = torch.concat( - (torch.ones(4) * 0, torch.ones(4) * 1, torch.ones(4) * 2, torch.ones(4) * 3) - ).cuda() - if Utils.rank >= 4: - expected_output = expected_output + 4 - assert torch.equal(output_data, expected_output) - assert torch.equal( - mappings._GatherFromSequenceParallelRegion.symbolic(None, input_data, tp_group), - expected_output, - ) - input_data = torch.vstack( - (torch.ones(4) * 0, torch.ones(4) * 1, torch.ones(4) * 2, torch.ones(4) * 3) - ).cuda() - - class Ctx: - tensor_parallel_output_grad = True - output_split_sizes = None - group = tp_group - use_global_buffer = False - - output_data = mappings._GatherFromSequenceParallelRegion.backward(Ctx(), input_data) - expected_output = torch.ones((1, 4)).cuda() * 4 * int(Utils.rank % 4) - assert torch.equal(output_data[0], expected_output) - Utils.destroy_model_parallel() - - -@pytest.mark.internal -def test_ReduceScatterToSequenceParallelRegion(): - Utils.initialize_model_parallel(4, 2) - input_data = torch.vstack( - (torch.ones(4) * 0, torch.ones(4) * 1, torch.ones(4) * 2, torch.ones(4) * 3) - ).cuda() - - tp_group = get_tensor_model_parallel_group_if_none(tp_group=None) - output_data = mappings.reduce_scatter_to_sequence_parallel_region(input_data) - expected_output = torch.ones(4).cuda() * 4 * int(Utils.rank % 4) - assert torch.equal(output_data[0], expected_output) - assert torch.equal( - mappings._ReduceScatterToSequenceParallelRegion.symbolic(None, input_data, tp_group), - expected_output.reshape((1, 4)), - ) - input_data = torch.ones(4).cuda() * Utils.rank - - class Ctx: - input_split_sizes = None - group = tp_group - use_global_buffer = False - - output_data = mappings._ReduceScatterToSequenceParallelRegion.backward(Ctx(), input_data) - expected_output = torch.concat( - (torch.ones(4) * 0, torch.ones(4) * 1, torch.ones(4) * 2, torch.ones(4) * 3) - ).cuda() - if Utils.rank >= 4: - expected_output = expected_output + 4 - assert torch.equal(output_data[0], expected_output) - Utils.destroy_model_parallel() diff --git a/tests/unit_tests/tensor_parallel/test_random.py b/tests/unit_tests/tensor_parallel/test_random.py deleted file mode 100644 index 47b607b879..0000000000 --- a/tests/unit_tests/tensor_parallel/test_random.py +++ /dev/null @@ -1,85 +0,0 @@ -import pytest -import torch - -from megatron.core.tensor_parallel.random import ( - CheckpointWithoutOutput, - CudaRNGStatesTracker, - checkpoint, - get_cuda_rng_tracker, - model_parallel_cuda_manual_seed, -) -from tests.unit_tests.test_utilities import Utils - - -def test_cuda_rng_states_tracker(): - rng_tracker = CudaRNGStatesTracker() - rng_tracker.set_states({"state1": 1234}) - assert rng_tracker.get_states()["state1"] == 1234 - rng_tracker.reset() - assert rng_tracker.get_states() == {} - seed = 1111 - rng_tracker.add("state2", seed) - with pytest.raises(Exception): - assert rng_tracker.add("state3", seed) - with pytest.raises(Exception): - assert rng_tracker.add("state2", 111) - assert rng_tracker.get_states()['state2'] is not None - with pytest.raises(Exception): - assert () - - rng_tracker.fork("state2") - torch.cuda.manual_seed(seed) - rng_state = torch.cuda.get_rng_state() - assert torch.equal(rng_tracker.get_states()['state2'], rng_state) - - -def test_model_parallel_cuda_manual_seed(): - Utils.initialize_model_parallel(4, 2) - model_parallel_cuda_manual_seed(0, force_reset_rng=True) - rng_tracker = get_cuda_rng_tracker() - assert rng_tracker.get_states()['model-parallel-rng'] is not None - Utils.destroy_model_parallel() - - -def test_checkpoint(): - def test_forward(*input): - return input[0] + input[1] - - assert torch.equal( - torch.ones(16) * 3, checkpoint(test_forward, None, torch.ones(16), torch.ones(16) * 2) - ) - Utils.initialize_model_parallel() - input1 = torch.ones((4, 4)) - checkpoint(test_forward, True, input1, torch.ones((4, 4)) * 2) - assert torch.equal(torch.ones(input1.numel()).cuda(), input1) - Utils.destroy_model_parallel() - - -def test_checkpoint_without_output(): - def normal_forward(input): - x = torch.nn.functional.gelu(input) - y = x * input - return y - - def checkpoint_forward(input): - checkpoint = CheckpointWithoutOutput() - x = checkpoint.checkpoint(torch.nn.functional.gelu, input) - y = x * input - checkpoint.discard_output_and_register_recompute(y) - return y - - Utils.initialize_model_parallel() - - input1 = torch.ones((4, 4)) - input1.requires_grad_(True) - output1 = normal_forward(input1) - input2 = torch.ones((4, 4)) - input2.requires_grad_(True) - output2 = checkpoint_forward(input2) - assert torch.equal(output1, output2) - - output1.backward(torch.ones((4, 4)), retain_graph=True) - output2.backward(torch.ones((4, 4)), retain_graph=True) - assert torch.equal(input1.grad, input2.grad) - - Utils.destroy_model_parallel() diff --git a/tests/unit_tests/tensor_parallel/test_tensor_parallel_utils.py b/tests/unit_tests/tensor_parallel/test_tensor_parallel_utils.py deleted file mode 100644 index 5df774e5ff..0000000000 --- a/tests/unit_tests/tensor_parallel/test_tensor_parallel_utils.py +++ /dev/null @@ -1,55 +0,0 @@ -import torch - -import megatron.core.parallel_state as ps -import megatron.core.tensor_parallel.utils as util -from tests.unit_tests.test_utilities import Utils - -rank = Utils.rank - - -def test_split_tensor_along_last_dim(): - input_tensor = torch.rand((3, 4)) - torch.equal(input_tensor[0:2, 0:2], util.split_tensor_along_last_dim(input_tensor, 2)[0]) - torch.equal(input_tensor[2:, 2:], util.split_tensor_along_last_dim(input_tensor, 2)[1]) - - -def test_split_tensor_into_1d_equal_chunks(): - Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4) - input_tensor = torch.rand((3, 4)) - output_tensor = util.split_tensor_into_1d_equal_chunks(input_tensor) - if rank % 2 == 0: - start = 0 - end = int(input_tensor.numel() / 2) - else: - start = int(input_tensor.numel() / 2) - end = input_tensor.numel() - - assert torch.equal(output_tensor, input_tensor.flatten()[start:end]) - Utils.destroy_model_parallel() - - -def test_gather_split_1d_tensor(): - Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4) - input_tensor = torch.ones((2, 4)).cuda() * rank - actual_output_tensor = util.gather_split_1d_tensor(input_tensor) - if rank % 2 == 0: - expected_output_tensor = torch.concat((input_tensor.flatten(), input_tensor.flatten() + 1)) - else: - expected_output_tensor = torch.concat((input_tensor.flatten() - 1, input_tensor.flatten())) - assert torch.equal(actual_output_tensor, expected_output_tensor) - Utils.destroy_model_parallel() - - -def test_vocab(): - global_vocab_size = 1600 - per_partition_vocab_size = 1600 / Utils.world_size - assert (rank * per_partition_vocab_size, (rank + 1) * per_partition_vocab_size) == ( - util.VocabUtility.vocab_range_from_per_partition_vocab_size( - global_vocab_size // Utils.world_size, rank, Utils.world_size - ) - ) - assert (rank * per_partition_vocab_size, (rank + 1) * per_partition_vocab_size) == ( - util.VocabUtility.vocab_range_from_global_vocab_size( - global_vocab_size, rank, Utils.world_size - ) - ) diff --git a/tests/unit_tests/test_basic.py b/tests/unit_tests/test_basic.py deleted file mode 100644 index d2a60f92c8..0000000000 --- a/tests/unit_tests/test_basic.py +++ /dev/null @@ -1,2 +0,0 @@ -def test_import(): - import megatron diff --git a/tests/unit_tests/test_checkpointing.py b/tests/unit_tests/test_checkpointing.py deleted file mode 100644 index 3af9094c06..0000000000 --- a/tests/unit_tests/test_checkpointing.py +++ /dev/null @@ -1,336 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# Note: --ckpt-format torch_dist has tests in tests/unit_tests/dist_checkpointing. -import os -from types import SimpleNamespace -from typing import Optional -from unittest import mock - -import pytest -import torch -import torch.distributed.checkpoint - -from megatron.core.num_microbatches_calculator import ( - init_num_microbatches_calculator, - unset_num_microbatches_calculator, -) -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer import MegatronModule -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import is_torch_min_version -from megatron.training.checkpointing import ( - CheckpointType, - _build_sharded_state_dict_metadata, - _load_base_checkpoint, - get_checkpoint_tracker_filename, - load_checkpoint, - save_checkpoint, -) -from megatron.training.global_vars import set_args -from tests.unit_tests.dist_checkpointing import TempNamedDir -from tests.unit_tests.test_utilities import Utils - - -class MockModel(MegatronModule): - """Dummy megatron model.""" - - def __init__(self, config): - super().__init__(config=config) - self.l = torch.nn.Linear(1, 2) - torch.nn.init.ones_(self.l.weight) - torch.nn.init.zeros_(self.l.bias) - self._called_metadata = [] - - def sharded_state_dict(self, *args, metadata: Optional[dict] = None, **kwargs): - self._called_metadata.append(metadata) - return self.state_dict() - - -class MockState: - def __init__(self, state_dict): - self._state_dict = state_dict - self.is_stub_optimizer = False - self._called_metadata = [] - - def state_dict(self, is_loading=False): - return self._state_dict - - def load_state_dict(self, state_dict): - self._state_dict = state_dict - - def save_parameter_state(self, *args, **kwargs): - pass - - def load_parameter_state(self, *args, **kwargs): - pass - - def sharded_state_dict(self, *args, metadata: Optional[dict] = None, **kwargs): - self._called_metadata.append(metadata) - return self.state_dict() - - -def create_checkpoint(load_path, ckpt_format): - """Setup a dummy checkpoint directory.""" - iteration = 123 - ckpt_dir = load_path / "iter_{:07d}".format(iteration) - tracker_path = get_checkpoint_tracker_filename(load_path) - with open(tracker_path, "w") as f: - f.write(str(iteration)) - - state_dict = {"args": "dummy", "iteration": iteration} - - if ckpt_format == "torch": - # Torch checkpoints use a specific directory structure. - pt_dir = ckpt_dir / "mp_rank_00" - pt_dir.mkdir(parents=True) - torch.save(state_dict, pt_dir / "model_optim_rng.pt") - elif ckpt_format == "torch_dcp" and is_torch_min_version("2.4.0"): - torch.distributed.checkpoint.save(state_dict, checkpoint_id=ckpt_dir) - - -@pytest.fixture -def create_args(): - """Setup dummy args.""" - args = SimpleNamespace() - args.finetune = False - args.non_persistent_global_ckpt_dir = None - args.non_persistent_ckpt_type = None - args.non_persistent_save_interval = None - args.exit_on_missing_checkpoint = True - args.async_save = False - args.data_parallel_random_init = False - args.no_save_optim = False - args.no_save_rng = False - args.no_load_optim = False - args.no_load_rng = False - args.log_progress = False - args.ckpt_fully_parallel_save = False - args.auto_detect_ckpt_format = False - args.retro_add_retriever = False - args.ckpt_convert_update_legacy_dist_opt_format = False - args.ckpt_step = None - - yield args - - -@pytest.fixture -def create_ckpt_load_args(create_args): - """Setup dummy args allowing checkpoint load.""" - args = create_args - args.auto_detect_ckpt_format = False - args.consumed_train_samples = 0 - args.skipped_train_samples = 0 - args.consumed_valid_samples = 0 - args.num_layers = 1 - args.hidden_size = 2 - args.num_attention_heads = 1 - args.add_position_embedding = False - args.vocab_file = None - args.tensor_model_parallel_size = 1 - args.pipeline_model_parallel_size = 1 - args.ckpt_assume_constant_structure = False - args.ckpt_fully_parallel_save = False - args.ckpt_fully_parallel_load = False - args.dist_ckpt_strictness = 'assume_ok_unexpected' - args.use_megatron_fsdp = False - args.strict_fsdp_dtensor_load = True - - yield args - - -@pytest.fixture -def init_model_parallel(): - """Init torch distributed.""" - Utils.initialize_model_parallel(1, 1) - init_num_microbatches_calculator(0, None, 1, 1, 1) - model_parallel_cuda_manual_seed(123) - yield # Run the actual test. - Utils.destroy_model_parallel() - unset_num_microbatches_calculator() - - -@pytest.mark.parametrize("ckpt_format", ["torch_dcp"]) -def test_load_base_checkpoint( - init_model_parallel, create_ckpt_load_args, ckpt_format, tmp_path_dist_ckpt -): - """Test _load_base_checkpoint.""" - - if ckpt_format == "torch_dcp" and not is_torch_min_version("2.4.0"): - pytest.skip("torch_dcp requires torch >= 2.4.0") - - # TempNamedDir uses the same directory for all ranks in a multi-GPU setup. Cleanup is handled. - with TempNamedDir(tmp_path_dist_ckpt / "test_load_base_checkpoint", sync=True) as load_dir: - create_checkpoint(load_dir, ckpt_format) - args = create_ckpt_load_args - args.ckpt_format = ckpt_format - - state_dict, checkpoint_name, release, ckpt_type = _load_base_checkpoint( - load_dir, args, rank0=True - ) - - assert state_dict["args"] == "dummy" - assert state_dict["iteration"] == 123 - - expected_ckpt_path = None - if ckpt_format == "torch": - expected_ckpt_path = str(load_dir / "iter_0000123" / "mp_rank_00" / "model_optim_rng.pt") - elif ckpt_format == "torch_dcp": - expected_ckpt_path = str(load_dir / "iter_0000123") - - assert checkpoint_name == expected_ckpt_path - assert not release - - expected_ckpt_type = None - if ckpt_format == "torch": - expected_ckpt_type = CheckpointType.LEGACY - elif ckpt_format == "torch_dcp": - expected_ckpt_type = CheckpointType.TORCH_DCP - - assert ckpt_type == expected_ckpt_type - - -@pytest.mark.parametrize("ckpt_format", ["torch", "torch_dcp"]) -def test_save_checkpoint(init_model_parallel, create_args, tmp_path_dist_ckpt, ckpt_format): - """Test save_checkpoint.""" - args = create_args - args.ckpt_format = ckpt_format - - if ckpt_format == "torch_dcp" and not is_torch_min_version("2.4.0"): - pytest.skip("torch_dcp requires torch >= 2.4.0") - - args.use_distributed_optimizer = ckpt_format != "torch_dcp" - args.use_dist_ckpt = ckpt_format != "torch" - - iteration = 123 - config = TransformerConfig(num_layers=1, kv_channels=1) - model = MockModel(config) - optimizer = MockState({"optimizer": "optimizer_state"}) - opt_param_scheduler = MockState({"opt_param_scheduler": "scheduler_state"}) - num_floating_point_operations_so_far = 456 - - with TempNamedDir(tmp_path_dist_ckpt / "test_save_checkpoint", sync=True) as save_dir: - args.save = save_dir - set_args(args) - - save_checkpoint( - iteration, [model], optimizer, opt_param_scheduler, num_floating_point_operations_so_far - ) - - with open(args.save / "latest_checkpointed_iteration.txt", "r") as f: - assert iteration == int(f.read()) - - ckpt_dir = args.save / "iter_0000123" - - expected_ckpt_path = None - if ckpt_format == "torch": - expected_ckpt_path = ckpt_dir / "mp_rank_00" / "model_optim_rng.pt" - elif ckpt_format == "torch_dcp": - expected_ckpt_path = ckpt_dir / ".metadata" - - assert os.path.exists(expected_ckpt_path) - - -@pytest.mark.parametrize("ckpt_format", ["torch"]) -def test_load_checkpoint( - init_model_parallel, create_ckpt_load_args, tmp_path_dist_ckpt, ckpt_format -): - """Test load_checkpoint.""" - args = create_ckpt_load_args - args.ckpt_format = ckpt_format - args.use_distributed_optimizer = ckpt_format != "torch_dcp" - args.use_dist_ckpt = ckpt_format != "torch" - - if ckpt_format == "torch_dcp" and not is_torch_min_version("2.4.0"): - pytest.skip("torch_dcp requires torch >= 2.4.0") - - with TempNamedDir(tmp_path_dist_ckpt / "test_load_checkpoint", sync=True) as ckpt_dir: - args.load = ckpt_dir - args.save = ckpt_dir - set_args(args) - - # Create and save a checkpoint first. - iteration = 123 - config = TransformerConfig(num_layers=1, kv_channels=1) - model = MockModel(config) - - optimizer = MockState({"optimizer": "optimizer_state"}) - opt_param_scheduler = MockState({"opt_param_scheduler": "scheduler_state"}) - num_floating_point_operations_so_far = 456 - - save_checkpoint( - iteration, [model], optimizer, opt_param_scheduler, num_floating_point_operations_so_far - ) - - # Create new model, optimizer, and scheduler instances to load into. - new_model = MockModel(config) - new_optimizer = MockState({"optimizer": "dummy1"}) - new_opt_param_scheduler = MockState({"opt_param_scheduler": "dummy2"}) - - # Load checkpoint - loaded_iter, loaded_flops = load_checkpoint( - [new_model], new_optimizer, new_opt_param_scheduler, strict=True - ) - - assert loaded_iter == iteration - assert loaded_flops == num_floating_point_operations_so_far - - for k in model.state_dict(): - assert torch.equal(model.state_dict()[k], new_model.state_dict()[k]) - - assert new_optimizer.state_dict() == optimizer.state_dict() - assert new_opt_param_scheduler.state_dict() == opt_param_scheduler.state_dict() - - -def test_dist_checkpoint_versioning(init_model_parallel, tmp_path_dist_ckpt, create_ckpt_load_args): - """Test distributed checkpoint versioning.""" - args = create_ckpt_load_args - args.ckpt_format = 'torch_dist' - args.use_distributed_optimizer = True - args.use_dist_ckpt = True - - with TempNamedDir( - tmp_path_dist_ckpt / "test_dist_checkpoint_versioning", sync=True - ) as ckpt_dir: - args.load = ckpt_dir - args.save = ckpt_dir - set_args(args) - - # Create and save a checkpoint first. - iteration = 123 - config = TransformerConfig(num_layers=1, kv_channels=1) - model = MockModel(config) - - optimizer = MockState({"optimizer": "optimizer_state"}) - opt_param_scheduler = MockState({"opt_param_scheduler": "scheduler_state"}) - num_fp_ops = 456 - - base_metadata = _build_sharded_state_dict_metadata(args) - first_job_mock_metadata = {**base_metadata, 'metadata_A': 42, 'metadata_B_soon_removed': 43} - with mock.patch( - 'megatron.training.checkpointing._build_sharded_state_dict_metadata', - return_value=first_job_mock_metadata, - ): - save_checkpoint(iteration, [model], optimizer, opt_param_scheduler, num_fp_ops) - - second_job_mock_metadata = { - **base_metadata, - 'metadata_A': 'changed_default_value', - 'metadata_C_new': {'nested': 'val'}, - } - with mock.patch( - 'megatron.training.checkpointing._build_sharded_state_dict_metadata', - return_value=second_job_mock_metadata, - ): - # Load checkpoint (into the same model, we don't check load correctness here) - load_checkpoint([model], optimizer, opt_param_scheduler, strict=True) - assert optimizer._called_metadata[-1] == first_job_mock_metadata - - # Save the checkpoint again to check if the content metadata for the new checkpoint will be new - save_checkpoint(iteration, [model], optimizer, opt_param_scheduler, num_fp_ops) - assert optimizer._called_metadata[-1] == second_job_mock_metadata - - assert optimizer._called_metadata == model._called_metadata - assert optimizer._called_metadata == [ - first_job_mock_metadata, - first_job_mock_metadata, - second_job_mock_metadata, - ] diff --git a/tests/unit_tests/test_fp8_param.py b/tests/unit_tests/test_fp8_param.py deleted file mode 100644 index bdcf00c89a..0000000000 --- a/tests/unit_tests/test_fp8_param.py +++ /dev/null @@ -1,261 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -import contextlib -import os -import sys - -import pytest -import torch -from transformer_engine.pytorch.fp8 import check_fp8_support - -from megatron.core.enums import ModelType -from megatron.core.fp8_utils import is_float8tensor -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.utils import is_te_min_version -from megatron.training.arguments import core_transformer_config_from_args, parse_args, validate_args -from megatron.training.global_vars import ( - destroy_global_vars, - get_args, - set_args, - set_global_variables, -) -from megatron.training.training import get_model, setup_model_and_optimizer -from megatron.training.utils import get_device_arch_version -from tests.unit_tests.test_utilities import Utils - -_SEED = 1234 -fp8_available, reason_for_no_fp8 = check_fp8_support() - - -class TestFP8Param: - - def setup_method(self, method): - self.seq_length = 512 - self.micro_batch_size = 2 - os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' - - def teardown_method(self, method): - Utils.destroy_model_parallel() - destroy_global_vars() - destroy_num_microbatches_calculator() - - def model_provider( - self, - pre_process=True, - post_process=True, - layer_spec_fn=get_gpt_layer_with_transformer_engine_spec, - **config_kwargs, - ): - model_parallel_cuda_manual_seed(_SEED) - args = get_args() - config = core_transformer_config_from_args(args) - transformer_layer_spec = layer_spec_fn() - return GPTModel( - config=config, - transformer_layer_spec=transformer_layer_spec, - vocab_size=args.vocal_size, - max_sequence_length=args.max_position_embeddings, - pre_process=pre_process, - post_process=post_process, - fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, - parallel_output=True, - share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, - position_embedding_type=args.position_embedding_type, - rotary_percent=args.rotary_percent, - ) - - def create_test_args( - self, tp, recipe, sequence_length, micro_batch_size, inference=False, **kwargs - ): - destroy_global_vars() - destroy_num_microbatches_calculator() - - sys.argv = ['test_fp8_param.py'] - args = parse_args() - args.num_layers = 4 - args.vocal_size = 128800 - args.hidden_size = 128 - args.num_attention_heads = 8 - args.max_position_embeddings = 512 - args.micro_batch_size = micro_batch_size - args.create_attention_mask_in_dataloader = True - args.seq_length = sequence_length - args.tensor_model_parallel_size = tp - args.sequence_parallel = True if tp > 1 else False - args.pipeline_model_parallel_size = 1 - args.context_parallel_size = 1 - args.train_iters = 10 - args.lr = 3e-5 - args.bf16 = True - args.add_bias_linear = False - args.swiglu = True - args.use_distributed_optimizer = not inference - args.fp8 = "e4m3" - args.fp8_recipe = recipe - args.fp8_param_gather = True - - # MXFP8 test settings - if recipe == "mxfp8": - args.reuse_grad_buf_for_mxfp8_param_ag = True - - for key, value in kwargs.items(): - assert hasattr(args, key) - setattr(args, key, value) - - validate_args(args) - set_global_variables(args, False) - return args - - def get_batch(self, seq_length, micro_batch_size): - data = list(range(seq_length)) - input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - labels = 1 + torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - position_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - attention_mask = torch.ones( - (micro_batch_size, 1, seq_length, seq_length), dtype=bool - ).cuda() - loss_mask = torch.ones(seq_length).repeat((micro_batch_size, 1)).cuda() - return input_ids, labels, position_ids, attention_mask, loss_mask - - def _run_test_helper(self, tp_size, recipe, inference: bool = False, **kwargs): - """Test fp8_param with gpt_model.""" - args = self.create_test_args( - tp_size, recipe, self.seq_length, self.micro_batch_size, inference=inference, **kwargs - ) - - if recipe == "blockwise" and args.sequence_parallel: - assert ( - tp_size * 128 <= self.seq_length - ), "Blockwise recipe and sequence parallelism requires tp_size * 128 <= seq_length" - - set_args(args) - torch.manual_seed(_SEED) - Utils.initialize_model_parallel(tensor_model_parallel_size=tp_size) - input_ids, labels, position_ids, attention_mask, loss_mask = self.get_batch( - self.seq_length, self.micro_batch_size - ) - if inference: - gpt_model = get_model( - self.model_provider, ModelType.encoder_or_decoder, wrap_with_ddp=False - ) - gpt_model[0].eval() - optimizer = None - else: - gpt_model, optimizer, _ = setup_model_and_optimizer( - self.model_provider, ModelType.encoder_or_decoder - ) - assert len(gpt_model) == 1 # Assume only one model in the model provider. - - num_fp8_params = 0 - for _, param in gpt_model[0].named_parameters(): - if not inference: - assert param.requires_grad - assert param.main_grad is not None - if is_float8tensor(param): - num_fp8_params += 1 - - # Verify the number of fp8 params. - fp8_layers = args.num_layers - if kwargs.get("first_last_layers_bf16", False): - fp8_layers -= kwargs["num_layers_at_start_in_bf16"] - fp8_layers -= kwargs["num_layers_at_end_in_bf16"] - # Each layer has 4 GEMM weights: qkv, proj, fc1, fc2. - assert num_fp8_params == 4 * fp8_layers - - for i in range(100): - if not inference: - gpt_model[0].zero_grad_buffer() - optimizer.zero_grad() - - gpt_model[0].set_is_first_microbatch() - output = gpt_model[0].forward( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - labels=labels, - loss_mask=loss_mask, - ) - - # Check output shapes - assert output.shape[0] == self.micro_batch_size - assert output.shape[1] == self.seq_length - - if inference: - continue - - # Verify gradients - loss = output.mean() - loss.backward() - for name, param in gpt_model[0].named_parameters(): - assert param.main_grad is not None - - update_successful, _, _ = optimizer.step() - assert update_successful - - def run_test(self, tp_size, recipe, inference: bool = False, **kwargs): - """Test fp8_param with gpt_model.""" - ctx = torch.inference_mode if inference else contextlib.nullcontext - with ctx(): - self._run_test_helper(tp_size, recipe, inference=inference, **kwargs) - - @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) - @pytest.mark.parametrize("tp_size", [4]) - def test_delayed_scaling(self, tp_size): - self.run_test(tp_size=tp_size, recipe="delayed") - - @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) - @pytest.mark.skipif(not is_te_min_version("2.2.0"), reason="TE 2.2.0 is required") - @pytest.mark.parametrize("tp_size", [4]) - def test_tensorwise_scaling(self, tp_size): - self.run_test(tp_size=tp_size, recipe="tensorwise") - - @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) - @pytest.mark.skipif(not is_te_min_version("2.2.0"), reason="TE 2.2.0 is required") - @pytest.mark.parametrize("tp_size", [4]) - def test_tensorwise_scaling_inference(self, tp_size): - self.run_test(tp_size=tp_size, recipe="tensorwise", inference=True) - - @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) - @pytest.mark.skipif(not is_te_min_version("2.2.0"), reason="TE 2.2.0 is required") - @pytest.mark.parametrize("tp_size", [4]) - def test_tensorwise_scaling_with_first_last_layers_bf16(self, tp_size): - kwargs = { - "first_last_layers_bf16": True, - "num_layers_at_start_in_bf16": 1, - "num_layers_at_end_in_bf16": 1, - } - self.run_test(tp_size=tp_size, recipe="tensorwise", **kwargs) - - @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) - @pytest.mark.skipif(not is_te_min_version("2.4.0.dev0"), reason="TE 2.4.0.dev0 is required") - @pytest.mark.parametrize("tp_size", [4]) - def test_blockwise_scaling(self, tp_size): - self.run_test(tp_size=tp_size, recipe="blockwise") - - @pytest.mark.skipif( - get_device_arch_version() < 10, reason="MXFP8 is supported since Blackwell architecture" - ) - @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) - @pytest.mark.skipif(not is_te_min_version("2.3.0.dev0"), reason="TE 2.3.0.dev0 is required") - @pytest.mark.parametrize("tp_size", [2]) - @pytest.mark.parametrize("dp_overlap", [(False, False), (False, True), (True, True)]) - def test_mxfp8(self, tp_size, dp_overlap): - """ - dp_overlap: (overlap_param_gather, overlap_grad_reduce) - """ - kwargs = {"overlap_param_gather": dp_overlap[0], "overlap_grad_reduce": dp_overlap[1]} - self.run_test(tp_size=tp_size, recipe="mxfp8", **kwargs) - - @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) - @pytest.mark.skipif(not is_te_min_version("2.4.0.dev0"), reason="TE 2.4.0.dev0 is required") - @pytest.mark.parametrize("tp_size", [4]) - def test_blockwise_scaling_with_first_last_layers_bf16(self, tp_size): - kwargs = { - "first_last_layers_bf16": True, - "num_layers_at_start_in_bf16": 1, - "num_layers_at_end_in_bf16": 1, - } - self.run_test(tp_size=tp_size, recipe="blockwise", **kwargs) diff --git a/tests/unit_tests/test_fp8_utils.py b/tests/unit_tests/test_fp8_utils.py deleted file mode 100644 index 5be17f03c9..0000000000 --- a/tests/unit_tests/test_fp8_utils.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -from unittest.mock import Mock, patch - -import pytest -import torch -import torch.nn as nn - -from megatron.core import fp8_utils -from tests.unit_tests.test_utilities import Utils - - -class MockTELinear(nn.Module): - """Mock TE Linear module for testing.""" - - def __init__(self, in_features, out_features): - super().__init__() - self.weight = nn.Parameter(torch.randn(out_features, in_features)) - - def forward(self, x): - return x @ self.weight.t() - - -class TestFP8Padding: - - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - # Clear the wrapped modules set before each test - fp8_utils._fp8_inference_wrapped_modules.clear() - - def teardown_method(self, method): - Utils.destroy_model_parallel() - fp8_utils._fp8_inference_wrapped_modules.clear() - - def test_prepare_model_for_fp8_inference_basic(self): - """Test prepare_model_for_fp8_inference wraps TE modules.""" - - class SimpleModel(nn.Module): - def __init__(self): - super().__init__() - self.te_layer = MockTELinear(128, 128) - self.regular_layer = nn.Linear(128, 128) - - with ( - patch.object(fp8_utils, 'HAVE_TE', True), - patch.object(fp8_utils, 'Fp8Padding'), - patch.object(fp8_utils, 'Fp8Unpadding'), - patch.object(fp8_utils, 'TE_LINEAR_TYPES', (MockTELinear,)), - ): - - model = SimpleModel() - original_te_forward = model.te_layer.forward - original_regular_forward = model.regular_layer.forward - - # Prepare model - prepared_model = fp8_utils.prepare_model_for_fp8_inference(model) - - # Check same model returned - assert prepared_model is model - - # Check TE layer was wrapped - assert model.te_layer.forward != original_te_forward - assert model.te_layer in fp8_utils._fp8_inference_wrapped_modules - - # Check regular layer was not wrapped - assert model.regular_layer.forward == original_regular_forward - - def test_padding_mechanism_works(self): - """Test that the padding mechanism actually pads and unpads correctly.""" - - with ( - patch.object(fp8_utils, 'HAVE_TE', True), - patch.object(fp8_utils, 'Fp8Padding') as mock_pad_class, - patch.object(fp8_utils, 'Fp8Unpadding') as mock_unpad_class, - ): - - # Setup padding mock to pad from 6 to 16 - mock_pad_instance = Mock() - mock_pad_instance.return_value = (torch.zeros(16, 8192), [16]) - mock_pad_class.return_value = mock_pad_instance - - # Setup unpadding mock to unpad from 16 to 6 - mock_unpad_instance = Mock() - mock_unpad_instance.return_value = torch.zeros(6, 8192) - mock_unpad_class.return_value = mock_unpad_instance - - # Create module and get access to padded_forward directly - module = MockTELinear(4096, 4096) - module.cuda() - - # Store original forward to track what it receives - original_forward_input = None - - def track_forward(x): - nonlocal original_forward_input - original_forward_input = x - return torch.randn(x.shape[0], x.shape[1], 4096).cuda() - - module.forward = track_forward - - # Manually create the wrapped forward function - fp8_utils._wrap_te_linear_for_padding(module) - padded_forward = module.forward - - # Mock FP8GlobalStateManager.is_fp8_enabled to return True - with patch( - 'transformer_engine.pytorch.fp8.FP8GlobalStateManager.is_fp8_enabled', - return_value=True, - ): - # Create input: (seq_len=6, batch=2, hidden=4096) - input_tensor = torch.randn(6, 2, 4096).cuda() - - # Call padded_forward directly - output = padded_forward(input_tensor) - - # Verify padding was called with correct reshaped input - mock_pad_instance.assert_called_once() - call_args = mock_pad_instance.call_args[0] - assert call_args[0].shape == (6, 8192) # Reshaped to 2D - assert call_args[1] == [6] # Split info - - # Verify the original forward received padded input with correct shape - assert original_forward_input.shape == (16, 2, 4096) # Padded to 16 - - # Verify unpadding was called - mock_unpad_instance.assert_called_once() - unpad_args = mock_unpad_instance.call_args[0] - assert unpad_args[0].shape == (16, 8192) # Padded 2D tensor - assert unpad_args[1] == [6] # Original split - - # Verify output has original shape - assert output.shape == (6, 2, 4096) # Back to original seq_len diff --git a/tests/unit_tests/test_hyper_comm_grid.py b/tests/unit_tests/test_hyper_comm_grid.py deleted file mode 100644 index dd27f84f60..0000000000 --- a/tests/unit_tests/test_hyper_comm_grid.py +++ /dev/null @@ -1,544 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -import os -from unittest.mock import MagicMock, patch - -import pytest -import torch -import torch.distributed as dist - -from megatron.core.hyper_comm_grid import HyperCommGrid - - -class TestHyperCommGrid: - """Comprehensive tests for HyperCommGrid class.""" - - def test_init_basic(self): - """Test basic initialization of HyperCommGrid.""" - shape = [2, 2, 2] - dim_names = ["tp", "cp", "dp"] - - grid = HyperCommGrid(shape, dim_names) - - assert grid.shape == shape - assert grid.dim_names == dim_names - assert grid.rank_offset == 0 - assert grid.backend is None - assert grid.size == 8 # 2 * 2 * 2 - assert grid._pgs == {} - - def test_init_with_optional_params(self): - """Test initialization with optional parameters.""" - shape = [2, 2] # Changed from [2, 4] to fit world size 8 with offset 8 - dim_names = ["tp", "dp"] - rank_offset = 0 # Changed from 8 to 0 to avoid size error - backend = "nccl" - - grid = HyperCommGrid(shape, dim_names, rank_offset, backend) - - assert grid.shape == shape - assert grid.dim_names == dim_names - assert grid.rank_offset == rank_offset - assert grid.backend == backend - assert grid.size == 4 # 2 * 2 - - def test_init_validation_errors(self): - """Test initialization validation errors.""" - # Shape and dim_names length mismatch - with pytest.raises(ValueError, match="len\\(shape\\).*!= len\\(dim_names\\)"): - HyperCommGrid([2, 2], ["tp"]) - - # Grid too large for world size - with pytest.raises(RuntimeError, match="Grid shape.*is over sized"): - HyperCommGrid([4, 4], ["tp", "dp"]) # 16 > 8 world size - - def test_order_dims_single_dim(self): - """Test _order_dims with single dimension.""" - grid = HyperCommGrid( - [2, 2, 2], ["tp", "cp", "dp"] - ) # Changed from [2, 3, 4] to fit world size - - ordered_dims, unique_key = grid._order_dims("cp") - - assert ordered_dims == ["cp"] - assert unique_key == "cp" - - def test_order_dims_multiple_dims(self): - """Test _order_dims with multiple dimensions.""" - grid = HyperCommGrid( - [2, 2, 2], ["tp", "cp", "dp"] - ) # Changed from [2, 3, 4, 5] to fit world size - - # Should order according to reversed dim_names order - ordered_dims, unique_key = grid._order_dims(["dp", "tp"]) - - assert ordered_dims == [ - "dp", - "tp", - ] # Changed: dp comes before tp in reversed order ["dp", "cp", "tp"] - assert unique_key == "dp-tp" - - def test_order_dims_all_dims(self): - """Test _order_dims with all dimensions.""" - grid = HyperCommGrid( - [2, 2, 2], ["tp", "cp", "dp"] - ) # Changed from [2, 3, 4] to fit world size - - ordered_dims, unique_key = grid._order_dims(["dp", "cp", "tp"]) - - assert ordered_dims == ["dp", "cp", "tp"] # Changed: reversed order - assert unique_key == "dp-cp-tp" - - def test_gen_rank_enum_single_dim(self): - """Test _gen_rank_enum for single dimension.""" - grid = HyperCommGrid([2, 4], ["tp", "dp"]) - - rank_enum = grid._gen_rank_enum(["tp"]) - - # Should have 4 groups of 2 ranks each - expected = [[0, 1], [2, 3], [4, 5], [6, 7]] - assert rank_enum == expected - - def test_gen_rank_enum_multiple_dims(self): - """Test _gen_rank_enum for multiple dimensions.""" - grid = HyperCommGrid([2, 2, 2], ["tp", "cp", "dp"]) - - rank_enum = grid._gen_rank_enum(["tp", "cp"]) - - # Should have 2 groups (for dp) with 4 ranks each (tp * cp) - expected = [[0, 2, 1, 3], [4, 6, 5, 7]] # Updated to match actual einops rearrange result - assert rank_enum == expected - - def test_gen_rank_enum_with_offset(self): - """Test _gen_rank_enum with rank offset.""" - grid = HyperCommGrid([2, 2], ["tp", "dp"], rank_offset=4) - - rank_enum = grid._gen_rank_enum(["tp"]) - - # Should start from rank 4 - expected = [[4, 5], [6, 7]] - assert rank_enum == expected - - @patch('torch.distributed.new_subgroups_by_enumeration') - def test_create_pg_single_dim(self, mock_new_subgroups): - """Test create_pg for single dimension.""" - mock_pg = MagicMock(spec=dist.ProcessGroup) - mock_new_subgroups.return_value = (mock_pg, None) - - grid = HyperCommGrid([2, 4], ["tp", "dp"]) - - result = grid.create_pg("tp") - - assert result == mock_pg - assert "tp" in grid._pgs - assert grid._pgs["tp"] == mock_pg - - # Verify the enumeration passed to new_subgroups_by_enumeration - args, kwargs = mock_new_subgroups.call_args - expected_enum = [[0, 1], [2, 3], [4, 5], [6, 7]] - assert args[0] == expected_enum - assert kwargs["backend"] is None - - @patch('torch.distributed.new_subgroups_by_enumeration') - def test_create_pg_multiple_dims(self, mock_new_subgroups): - """Test create_pg for multiple dimensions.""" - mock_pg = MagicMock(spec=dist.ProcessGroup) - mock_new_subgroups.return_value = (mock_pg, None) - - grid = HyperCommGrid([2, 2, 2], ["tp", "cp", "dp"]) - - result = grid.create_pg(["tp", "cp"]) - - assert result == mock_pg - assert "cp-tp" in grid._pgs - - args, kwargs = mock_new_subgroups.call_args - expected_enum = [[0, 1, 2, 3], [4, 5, 6, 7]] - assert args[0] == expected_enum - - @patch('torch.distributed.new_subgroups_by_enumeration') - def test_create_pg_with_options(self, mock_new_subgroups): - """Test create_pg with additional options.""" - mock_pg = MagicMock(spec=dist.ProcessGroup) - mock_new_subgroups.return_value = (mock_pg, None) - - grid = HyperCommGrid([2, 4], ["tp", "dp"], backend="nccl") - - # Mock ProcessGroupNCCL.Options - mock_options = MagicMock() - - result = grid.create_pg("tp", pg_options=mock_options, group_desc="TEST_GROUP") - - assert result == mock_pg - - args, kwargs = mock_new_subgroups.call_args - assert kwargs["backend"] == "nccl" - assert kwargs["pg_options"] == mock_options - - @patch('torch.distributed.new_subgroups_by_enumeration') - def test_create_pg_duplicate_error(self, mock_new_subgroups): - """Test create_pg raises error when trying to recreate existing process group.""" - mock_pg = MagicMock(spec=dist.ProcessGroup) - mock_new_subgroups.return_value = (mock_pg, None) - - grid = HyperCommGrid([2, 4], ["tp", "dp"]) - - # Create process group first time - grid.create_pg("tp") - - # Try to create again should raise KeyError - with pytest.raises(KeyError, match="Process group.*has already been created"): - grid.create_pg("tp") - - @patch('torch.distributed.new_subgroups_by_enumeration') - def test_get_pg_success(self, mock_new_subgroups): - """Test get_pg returns existing process group.""" - mock_pg = MagicMock(spec=dist.ProcessGroup) - mock_new_subgroups.return_value = (mock_pg, None) - - grid = HyperCommGrid([2, 4], ["tp", "dp"]) - - # Create process group first - grid.create_pg("dp") - - # Get should return the same process group - result = grid.get_pg("dp") - assert result == mock_pg - - def test_get_pg_not_created_error(self): - """Test get_pg raises error when process group doesn't exist.""" - grid = HyperCommGrid([2, 4], ["tp", "dp"]) - - with pytest.raises(KeyError, match="Process group for.*hasn't been created"): - grid.get_pg("tp") - - @patch('torch.distributed.new_subgroups_by_enumeration') - def test_get_pg_multiple_dims(self, mock_new_subgroups): - """Test get_pg with multiple dimensions.""" - mock_pg = MagicMock(spec=dist.ProcessGroup) - mock_new_subgroups.return_value = (mock_pg, None) - - grid = HyperCommGrid([2, 2, 2], ["tp", "cp", "dp"]) - - # Create process group with multiple dims - grid.create_pg(["cp", "dp"]) - - # Get should work with different order - result = grid.get_pg(["dp", "cp"]) - assert result == mock_pg - - def test_complex_grid_scenario(self): - """Test a complex scenario similar to the docstring example.""" - os.environ["WORLD_SIZE"] = "120" # Set larger world size for this test - - grid = HyperCommGrid([2, 3, 4, 5], ["tp", "cp", "pp", "dp"]) - - assert grid.size == 120 - assert grid.shape == [2, 3, 4, 5] - assert grid.dim_names == ["tp", "cp", "pp", "dp"] - - # Test ordering of different dimension combinations - ordered_dims, key = grid._order_dims(["dp", "pp"]) - assert ordered_dims == ["dp", "pp"] # Changed: actual order matches reversed dim_names - assert key == "dp-pp" - - # Test rank enumeration for dp (last dimension) - rank_enum = grid._gen_rank_enum(["dp"]) - assert len(rank_enum) == 24 # 2 * 3 * 4 = 24 groups - assert len(rank_enum[0]) == 5 # Each group has 5 ranks - - # Clean up - os.environ["WORLD_SIZE"] = "8" - - @patch('torch.distributed.new_subgroups_by_enumeration') - def test_end_to_end_workflow(self, mock_new_subgroups): - """Test complete workflow: init -> create -> get.""" - mock_pg1 = MagicMock(spec=dist.ProcessGroup) - mock_pg2 = MagicMock(spec=dist.ProcessGroup) - mock_new_subgroups.side_effect = [(mock_pg1, None), (mock_pg2, None)] - - grid = HyperCommGrid([2, 2, 2], ["tp", "cp", "dp"]) - - # Create different process groups - tp_pg = grid.create_pg("tp") - dp_cp_pg = grid.create_pg(["dp", "cp"]) - - # Verify they're created correctly - assert tp_pg == mock_pg1 - assert dp_cp_pg == mock_pg2 - - # Verify we can get them back - assert grid.get_pg("tp") == mock_pg1 - assert grid.get_pg(["cp", "dp"]) == mock_pg2 # Different order should work - - # Verify internal state - assert len(grid._pgs) == 2 - assert "tp" in grid._pgs - assert "dp-cp" in grid._pgs # Changed: actual key order - - def test_edge_case_single_rank_dims(self): - """Test edge case with dimensions of size 1.""" - grid = HyperCommGrid([1, 2, 4], ["tp", "cp", "dp"]) - - # Test with tp dimension (size 1) - rank_enum = grid._gen_rank_enum(["tp"]) - expected = [[0], [1], [2], [3], [4], [5], [6], [7]] # 8 groups of 1 rank each - assert rank_enum == expected - - # Test with multiple dims including size 1 - rank_enum = grid._gen_rank_enum(["tp", "cp"]) - expected = [[0, 1], [2, 3], [4, 5], [6, 7]] # 4 groups of 2 ranks each - assert rank_enum == expected - - def test_rank_enumeration_correctness(self): - """Test that rank enumeration produces correct pattern.""" - grid = HyperCommGrid([2, 2, 2], ["a", "b", "c"]) - - # For dimension "a" (first in original order, last in reversed) - rank_enum_a = grid._gen_rank_enum(["a"]) - expected_a = [[0, 1], [2, 3], [4, 5], [6, 7]] - assert rank_enum_a == expected_a - - # For dimension "c" (last in original order, first in reversed) - rank_enum_c = grid._gen_rank_enum(["c"]) - expected_c = [[0, 4], [1, 5], [2, 6], [3, 7]] - assert rank_enum_c == expected_c - - # For dimensions "a" and "b" - rank_enum_ab = grid._gen_rank_enum(["a", "b"]) - expected_ab = [[0, 2, 1, 3], [4, 6, 5, 7]] - assert rank_enum_ab == expected_ab - - -class TestHyperCommGridIntegration: - """Integration tests for HyperCommGrid with real distributed initialization.""" - - @classmethod - def setup_class(cls): - """Set up distributed environment for the entire test class.""" - if not dist.is_initialized(): - # Initialize PyTorch distributed with NCCL backend - # This assumes proper environment variables are set (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) - try: - dist.init_process_group(backend="nccl") - cls.distributed_initialized = True - except Exception as e: - pytest.skip(f"Cannot initialize distributed: {e}") - else: - cls.distributed_initialized = True - - def test_real_distributed_basic_functionality(self): - """Test basic HyperCommGrid functionality with real distributed backend.""" - if not dist.is_initialized(): - pytest.skip("Distributed not initialized") - - world_size = dist.get_world_size() - if world_size > 8: - pytest.skip("Test requires at most 8 GPUs") - - # Test with world_size that fits our constraint - if world_size == 8: - shape = [2, 2, 2] - dim_names = ["tp", "cp", "dp"] - elif world_size == 4: - shape = [2, 2] - dim_names = ["tp", "dp"] - elif world_size == 2: - shape = [2] - dim_names = ["tp"] - else: - pytest.skip(f"Unsupported world size: {world_size}") - - grid = HyperCommGrid(shape, dim_names, backend="nccl") - - assert grid.size == world_size - assert grid.shape == shape - assert grid.dim_names == dim_names - assert grid.backend == "nccl" - - def test_real_distributed_process_group_creation(self): - """Test process group creation with real distributed backend.""" - if not dist.is_initialized(): - pytest.skip("Distributed not initialized") - - world_size = dist.get_world_size() - if world_size != 8: - pytest.skip("This test specifically requires 8 GPUs") - - grid = HyperCommGrid([2, 2, 2], ["tp", "cp", "dp"], backend="nccl") - - # Create different types of process groups - tp_pg = grid.create_pg("tp") - cp_pg = grid.create_pg("cp") - dp_pg = grid.create_pg("dp") - - # Verify process groups are real PyTorch ProcessGroup objects - assert isinstance(tp_pg, dist.ProcessGroup) - assert isinstance(cp_pg, dist.ProcessGroup) - assert isinstance(dp_pg, dist.ProcessGroup) - - # Verify we can get the process groups back - assert grid.get_pg("tp") == tp_pg - assert grid.get_pg("cp") == cp_pg - assert grid.get_pg("dp") == dp_pg - - # Test process group sizes - tp_ranks = dist.get_process_group_ranks(tp_pg) - cp_ranks = dist.get_process_group_ranks(cp_pg) - dp_ranks = dist.get_process_group_ranks(dp_pg) - - assert len(tp_ranks) == 2 # tp dimension size - assert len(cp_ranks) == 2 # cp dimension size - assert len(dp_ranks) == 2 # dp dimension size - - def test_real_distributed_multi_dimensional_groups(self): - """Test multi-dimensional process group creation with real distributed backend.""" - if not dist.is_initialized(): - pytest.skip("Distributed not initialized") - - world_size = dist.get_world_size() - if world_size != 8: - pytest.skip("This test specifically requires 8 GPUs") - - grid = HyperCommGrid([2, 2, 2], ["tp", "cp", "dp"], backend="nccl") - - # Create multi-dimensional process groups - tp_cp_pg = grid.create_pg(["tp", "cp"]) - cp_dp_pg = grid.create_pg(["cp", "dp"]) - - # Verify process groups are real - assert isinstance(tp_cp_pg, dist.ProcessGroup) - assert isinstance(cp_dp_pg, dist.ProcessGroup) - - # Test process group sizes - tp_cp_ranks = dist.get_process_group_ranks(tp_cp_pg) - cp_dp_ranks = dist.get_process_group_ranks(cp_dp_pg) - - assert len(tp_cp_ranks) == 4 # tp * cp = 2 * 2 - assert len(cp_dp_ranks) == 4 # cp * dp = 2 * 2 - - def test_real_distributed_all_reduce(self): - """Test actual communication using the created process groups.""" - if not dist.is_initialized(): - pytest.skip("Distributed not initialized") - - world_size = dist.get_world_size() - if world_size != 8: - pytest.skip("This test specifically requires 8 GPUs") - - grid = HyperCommGrid([2, 2, 2], ["tp", "cp", "dp"], backend="nccl") - - # Create a process group - tp_pg = grid.create_pg("tp") - - # Create a tensor for communication test - rank = dist.get_rank() - device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") - tensor = torch.ones(1, device=device) * rank - - # Perform all-reduce within the tensor parallel group - dist.all_reduce(tensor, group=tp_pg) - - # Verify the result (sum of ranks in the group) - tp_ranks = dist.get_process_group_ranks(tp_pg) - expected_sum = sum(tp_ranks) - - assert tensor.item() == expected_sum - - def test_real_distributed_different_world_sizes(self): - """Test HyperCommGrid with different valid world sizes.""" - if not dist.is_initialized(): - pytest.skip("Distributed not initialized") - - world_size = dist.get_world_size() - rank = dist.get_rank() - - # Test configurations for different world sizes - configs = { - 1: ([1], ["dp"]), - 2: ([2], ["tp"]), - 4: ([2, 2], ["tp", "dp"]), - 8: ([2, 2, 2], ["tp", "cp", "dp"]), - } - - if world_size not in configs: - pytest.skip(f"No test configuration for world size {world_size}") - - shape, dim_names = configs[world_size] - grid = HyperCommGrid(shape, dim_names, backend="nccl") - - assert grid.size == world_size - - # Create and test first dimension process group - first_dim_pg = grid.create_pg(dim_names[0]) - assert isinstance(first_dim_pg, dist.ProcessGroup) - - # Test communication if world size > 1 - if world_size > 1: - device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") - tensor = torch.tensor([rank], dtype=torch.float, device=device) - - # All-reduce to verify the process group works - dist.all_reduce(tensor, group=first_dim_pg) - - # Verify the result - group_ranks = dist.get_process_group_ranks(first_dim_pg) - expected_sum = sum(group_ranks) - assert tensor.item() == expected_sum - - def test_real_distributed_error_handling(self): - """Test error handling with real distributed backend.""" - if not dist.is_initialized(): - pytest.skip("Distributed not initialized") - - world_size = dist.get_world_size() - if world_size > 8: - pytest.skip("Test requires at most 8 GPUs") - - # Test shape validation with real world size - if world_size == 8: - # This should work - grid = HyperCommGrid([2, 2, 2], ["tp", "cp", "dp"]) - assert grid.size == 8 - - # This should fail - too large for world size - with pytest.raises(RuntimeError, match="Grid shape.*is over sized"): - HyperCommGrid([4, 4], ["tp", "dp"]) # 16 > 8 - - # Test duplicate process group creation - if world_size >= 2: - grid = HyperCommGrid([2, world_size // 2], ["tp", "dp"]) - grid.create_pg("tp") - - with pytest.raises(KeyError, match="Process group.*has already been created"): - grid.create_pg("tp") - - def test_real_distributed_rank_enumeration_verification(self): - """Verify rank enumeration produces correct communication patterns.""" - if not dist.is_initialized(): - pytest.skip("Distributed not initialized") - - world_size = dist.get_world_size() - if world_size != 8: - pytest.skip("This test specifically requires 8 GPUs") - - grid = HyperCommGrid([2, 2, 2], ["tp", "cp", "dp"]) - - # Test that ranks in the same TP group can communicate - tp_pg = grid.create_pg("tp") - tp_ranks = dist.get_process_group_ranks(tp_pg) - - current_rank = dist.get_rank() - if current_rank in tp_ranks: - device = torch.device(f"cuda:{current_rank % torch.cuda.device_count()}") - - # Create a unique tensor based on rank - tensor = torch.tensor([current_rank], dtype=torch.float, device=device) - original_value = tensor.clone() - - # All-reduce within TP group - dist.all_reduce(tensor, group=tp_pg) - - # Verify the sum is correct - expected_sum = sum(tp_ranks) - assert tensor.item() == expected_sum diff --git a/tests/unit_tests/test_imports.py b/tests/unit_tests/test_imports.py deleted file mode 100644 index bad67cd8d5..0000000000 --- a/tests/unit_tests/test_imports.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import argparse -import importlib -import inspect -import os -import traceback - -import torch -import wrapt - -from megatron.core.transformer.module import MegatronModule - - -def import_class_by_path(path: str): - paths = path.split('.') - path = ".".join(paths[:-1]) - class_name = paths[-1] - mod = __import__(path, fromlist=[class_name]) - mod = getattr(mod, class_name) - return mod - - -def _build_import_path(subdomains: list, imp): - import_path = ["megatron", "core"] - import_path.extend(subdomains) - import_path.append(imp) - path = ".".join(import_path) - return path - - -def _get_class_from_path(subdomains, imp): - path = _build_import_path(subdomains, imp) - print(path) - class_ = None - result = None - try: - class_ = import_class_by_path(path) - if inspect.isclass(class_): - if isinstance(class_, wrapt.FunctionWrapper): - class_ = class_.__wrapped__ - if issubclass(class_, (MegatronModule, torch.nn.Module)): - result = class_ - else: - class_ = None - error = None - except Exception: - error = traceback.format_exc() - return class_, result, error - - -def _test_domain_module_imports(module, subdomains: list): - module_list = [] - failed_list = [] - error_list = [] - - error = None - if len(subdomains) > 0: - basepath = module.__path__[0] - megatron_index = basepath.rfind("megatron") - basepath = basepath[megatron_index:].replace(os.path.sep, ".") - new_path = '.'.join([basepath, *subdomains]) - - try: - module = importlib.import_module(new_path) - except Exception: - print(f"Could not import `{new_path}` ; Traceback below :") - error = traceback.format_exc() - error_list.append(error) - - if error is None: - for imp in dir(module): - class_, result, error = _get_class_from_path(subdomains, imp) - - if result is not None: - module_list.append(class_) - - elif class_ is not None: - failed_list.append(class_) - - if error is not None: - error_list.append(error) - - for module in module_list: - print("Module successfully imported :", module) - - print() - for module in failed_list: - print( - "Module did not match a valid signature of Megatron core Model (hence ignored):", module - ) - - print() - if len(error_list) > 0: - print("Imports crashed with following traceback !") - - for error in error_list: - print("*" * 100) - print() - print(error) - print() - print("*" * 100) - print() - - if len(error_list) > 0: - return False - else: - return True - - -############################### - - -def test_domain_mcore(): - import megatron.core as mcore - - all_passed = _test_domain_module_imports(mcore, subdomains=['models']) - - all_passed = _test_domain_module_imports(mcore, subdomains=['pipeline_parallel']) - - all_passed = _test_domain_module_imports(mcore, subdomains=['tensor_parallel']) - - all_passed = _test_domain_module_imports(mcore, subdomains=['transformer']) - - all_passed = _test_domain_module_imports(mcore, subdomains=['fusions']) - - all_passed = _test_domain_module_imports(mcore, subdomains=['distributed']) - - all_passed = _test_domain_module_imports(mcore, subdomains=['datasets']) - - all_passed = _test_domain_module_imports(mcore, subdomains=['dist_checkpointing']) - - if not all_passed: - exit(1) - - -if __name__ == '__main__': - test_domain_mcore() diff --git a/tests/unit_tests/test_inference.py b/tests/unit_tests/test_inference.py deleted file mode 100644 index 752b0fff25..0000000000 --- a/tests/unit_tests/test_inference.py +++ /dev/null @@ -1,94 +0,0 @@ -import argparse -import unittest.mock - -import numpy as np -import pytest -import torch - -from megatron.inference.text_generation_server import MegatronServer -from megatron.training import tokenizer -from tests.unit_tests.inference.engines.test_static_engine import TestStaticInferenceEngine -from tests.unit_tests.test_tokenizer import GPT2_VOCAB_SIZE, gpt2_tiktok_vocab -from tests.unit_tests.test_utilities import Utils - - -@pytest.fixture(scope="module") -def gpt2_tiktoken_tokenizer(gpt2_tiktok_vocab): - return tokenizer.build_tokenizer(gpt2_tiktok_vocab) - - -@pytest.fixture(scope="module") -def static_inference_engine(gpt2_tiktoken_tokenizer): - engine_wrapper = TestStaticInferenceEngine() - engine_wrapper.setup_engine(vocab_size=gpt2_tiktoken_tokenizer.vocab_size) - - controller = engine_wrapper.static_engine.text_generation_controller - controller.tokenizer = gpt2_tiktoken_tokenizer - - def mock_forward(*args, **kwargs): - tokens = args[0] - B, L = tokens.shape - assert B == 1, "Test assumes batch_size == 1" - V = gpt2_tiktoken_tokenizer.vocab_size - next_token_idxs = tokens[0, 1:] - logits = torch.zeros(1, L, V, dtype=torch.float32, device=tokens.device) - logits[0, torch.arange(L - 1), next_token_idxs] = 100 - logits[0, -1, gpt2_tiktoken_tokenizer.eos] = 100 - return logits - - controller.inference_wrapped_model.model.forward = mock_forward - yield engine_wrapper.static_engine - - -@pytest.fixture(scope="module") -def app(static_inference_engine): - return MegatronServer(static_inference_engine).app - - -@pytest.fixture() -def client(app): - return app.test_client() - - -@unittest.mock.patch('megatron.inference.endpoints.completions.send_do_generate') -@unittest.mock.patch("megatron.inference.text_generation.tokenization.get_tokenizer") -@unittest.mock.patch("megatron.inference.endpoints.completions.get_tokenizer") -def test_completions_endpoint( - mock_get_tokenizer1, mock_get_tokenizer2, mock_send_do_generate, client, gpt2_tiktoken_tokenizer -): - Utils.initialize_distributed() - - mock_get_tokenizer1.return_value = gpt2_tiktoken_tokenizer - mock_get_tokenizer2.return_value = gpt2_tiktoken_tokenizer - - twinkle = ("twinkle twinkle little star,", " how I wonder what you are") - request_data = {"prompt": twinkle[0] + twinkle[1], "max_tokens": 0, "logprobs": 5, "echo": True} - - response = client.post('/completions', json=request_data) - - assert response.status_code == 200 - assert response.is_json - - json_data = response.get_json() - assert 'choices' in json_data - assert len(json_data['choices']) > 0 - assert 'text' in json_data['choices'][0] - assert 'logprobs' in json_data['choices'][0] - - # whats up with the reconstruction of the prompt? - # we are replicating what lm-eval-harness::TemplateLM::_encode_pair does - # it encodes prompt, then prompt+suffix, and then infers the suffix tokens - # from the combined encoding. - logprobs = json_data["choices"][0]["logprobs"] - num_reconstructed_prompt_tokens = np.searchsorted(logprobs["text_offset"], len(twinkle[0])) - assert num_reconstructed_prompt_tokens == len(gpt2_tiktoken_tokenizer.tokenize(twinkle[0])) - suffix_logprob = logprobs["token_logprobs"][num_reconstructed_prompt_tokens:] - - # we mock logits to be 0 everywhere, and 100 at gt tokens, so logprob should be 0 for gt tokens - assert sum(suffix_logprob) == 0, f"{suffix_logprob} != [0, .... 0]" - - # Test for unsupported HTTP methods - response = client.put('/completions', json=request_data) - assert response.status_code == 405 # Method Not Allowed - - mock_send_do_generate.assert_called_once() diff --git a/tests/unit_tests/test_local_multi_tensor_fns.py b/tests/unit_tests/test_local_multi_tensor_fns.py deleted file mode 100644 index 9c06cd24af..0000000000 --- a/tests/unit_tests/test_local_multi_tensor_fns.py +++ /dev/null @@ -1,94 +0,0 @@ -import copy - -import pytest -import torch - -from megatron.core.utils import ( - local_multi_tensor_applier, - local_multi_tensor_l2_norm, - local_multi_tensor_scale, -) - - -def test_local_multi_tensor_l2_norm_and_scale(): - amp_C = pytest.importorskip("amp_C") - multi_tensor_apply = pytest.importorskip("apex.multi_tensor_apply") - - torch.manual_seed(42) - - tensor_list = [torch.rand(5, 5).cuda() for _ in range(10)] - tensor_list_hold = copy.copy(tensor_list) - tensor_list_copy = copy.deepcopy(tensor_list) - tensor_list_copy_hold = copy.copy(tensor_list_copy) - - # test multi_tensor_l2norm - norm_apex, _ = multi_tensor_apply.multi_tensor_applier( - amp_C.multi_tensor_l2norm, - torch.tensor([0], dtype=torch.int, device='cuda'), - [tensor_list], - False, - ) - norm_local, _ = multi_tensor_apply.multi_tensor_applier( - local_multi_tensor_l2_norm, - torch.tensor([0], dtype=torch.int, device='cuda'), - [tensor_list_copy], - False, - ) - torch.testing.assert_close(norm_apex, norm_local) - - # test src is dst - clip_coeff = 0.05 - multi_tensor_apply.multi_tensor_applier( - amp_C.multi_tensor_scale, - torch.tensor([0], dtype=torch.int, device='cuda'), - [tensor_list, tensor_list], - clip_coeff, - ) - multi_tensor_apply.multi_tensor_applier( - local_multi_tensor_scale, - torch.tensor([0], dtype=torch.int, device='cuda'), - [tensor_list_copy, tensor_list_copy], - clip_coeff, - ) - torch.testing.assert_close(tensor_list, tensor_list_hold) - torch.testing.assert_close(tensor_list_copy, tensor_list_copy_hold) - torch.testing.assert_close(tensor_list, tensor_list_copy) - - # test src is not dst - clip_coeff = 2.0 - multi_tensor_apply.multi_tensor_applier( - amp_C.multi_tensor_scale, - torch.tensor([0], dtype=torch.int, device='cuda'), - [copy.deepcopy(tensor_list), tensor_list], - clip_coeff, - ) - multi_tensor_apply.multi_tensor_applier( - local_multi_tensor_scale, - torch.tensor([0], dtype=torch.int, device='cuda'), - [copy.deepcopy(tensor_list_copy), tensor_list_copy], - clip_coeff, - ) - torch.testing.assert_close(tensor_list, tensor_list_hold) - torch.testing.assert_close(tensor_list_copy, tensor_list_copy_hold) - torch.testing.assert_close(tensor_list, tensor_list_copy) - - -def test_local_multi_tensor_apply(): - amp_C = pytest.importorskip("amp_C") - multi_tensor_apply = pytest.importorskip("apex.multi_tensor_apply") - - tensor_list = [torch.rand(5, 5).cuda() for _ in range(10)] - - norm_apex, _ = multi_tensor_apply.multi_tensor_applier( - amp_C.multi_tensor_l2norm, - torch.tensor([0], dtype=torch.int, device='cuda'), - [tensor_list], - False, - ) - norm_local, _ = local_multi_tensor_applier( - amp_C.multi_tensor_l2norm, - torch.tensor([0], dtype=torch.int, device='cuda'), - [tensor_list], - False, - ) - torch.testing.assert_close(norm_apex, norm_local) diff --git a/tests/unit_tests/test_model_configs.py b/tests/unit_tests/test_model_configs.py deleted file mode 100644 index 383c78200e..0000000000 --- a/tests/unit_tests/test_model_configs.py +++ /dev/null @@ -1,37 +0,0 @@ -import pathlib - -import pytest -import yaml - -YAML_DIR = pathlib.Path(__file__).parent / ".." / "functional_tests" / "test_cases" - - -def get_yaml_files(directory): - """Retrieve all YAML files from the specified directory.""" - return list([file for file in directory.rglob("model_config.yaml") if file is not None]) - - -def load_yaml(file_path): - """Load a YAML file and return its content as a Python dictionary.""" - with open(file_path, "r") as f: - return yaml.safe_load(f) - - -@pytest.mark.parametrize( - "metric", - ["--log-memory-to-tensorboard", "--log-num-zeros-in-grad", "--log-timers-to-tensorboard"], -) -@pytest.mark.parametrize("yaml_file", get_yaml_files(YAML_DIR)) -def test_model_config_tracks_memory(yaml_file, metric): - """Test if each YAML file contains the required record.""" - print("gpt3-nemo" in str(yaml_file) or "ckpt_converter" in str(yaml_file)) - if any(k in str(yaml_file) for k in ["gpt3-nemo", "ckpt_converter", "gpt-nemo", "inference"]): - pytest.skip("Skipping `test_model_config_tracks_memory`") - - model_config = load_yaml(yaml_file) - - assert ( - "MODEL_ARGS" in model_config - and metric in model_config["MODEL_ARGS"] - and model_config["MODEL_ARGS"][metric] is True - ), f"Please add argument `{metric}` to `{yaml_file.parent.name}/model_config.yaml` that its metric gets tracked." diff --git a/tests/unit_tests/test_nccl_allocator.py b/tests/unit_tests/test_nccl_allocator.py deleted file mode 100644 index d890f6fe6b..0000000000 --- a/tests/unit_tests/test_nccl_allocator.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -import os - -import pytest -import torch -from packaging import version - -import megatron.core.nccl_allocator as nccl_allocator -from tests.unit_tests.test_utilities import Utils - - -class TestNCCLAllocator: - @classmethod - def setup_class(cls): - Utils.initialize_model_parallel() - - @classmethod - def teardown_class(cls): - Utils.destroy_model_parallel() - - @pytest.mark.skipif( - version.parse(torch.__version__) < version.parse('2.7.0'), - reason="Requires PyTorch 2.7.0 or later", - ) - def test_nccl_allocator_init_sets_env_vars(self): - nccl_allocator.init() - assert os.environ.get("NCCL_NVLS_ENABLE") == "1" - assert os.environ.get("TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK") == "0" - - @pytest.mark.skipif( - version.parse(torch.__version__) < version.parse('2.7.0'), - reason="Requires PyTorch 2.7.0 or later", - ) - @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 GPUs") - def test_nccl_nccl_mem_register_and_allreduce(self): - if not torch.cuda.is_available(): - pytest.skip("CUDA is required for NCCL allocator tests") - - world_size = torch.distributed.get_world_size() - - device = torch.device("cuda", torch.cuda.current_device()) - torch.cuda.set_device(device) - - # Default process group and backend - pg = torch.distributed.new_group(ranks=list(range(world_size)), backend="nccl") - - nccl_allocator.init() - - # Create mempool via our allocator and register it around allocation - pool = nccl_allocator.create_nccl_mem_pool() - with nccl_allocator.nccl_mem(pool, group=pg): - tensor = torch.ones([1], device=device) - - # Perform an all-reduce to ensure communication works with the pool registered - torch.distributed.all_reduce(tensor, group=pg) - torch.cuda.synchronize(device=device) - assert tensor == torch.tensor([world_size], device=device) - - @pytest.mark.skipif( - version.parse(torch.__version__) < version.parse('2.7.0'), - reason="Requires PyTorch 2.7.0 or later", - ) - @pytest.mark.skipif(torch.cuda.device_count() != 8, reason="Requires 8 GPUs") - @pytest.mark.skipif( - torch.cuda.nccl.version() < (2, 27, 0), reason="Requires at least NCCL v2.27.0" - ) - def test_ag_with_nccl_cta_policy(self): - if not torch.cuda.is_available(): - pytest.skip("CUDA is required for NCCL allocator tests") - - os.environ["NCCL_CTA_POLICY"] = "1" - - world_size = torch.distributed.get_world_size() - - if world_size != 8: - pytest.skip("Requires 8 ranks") - - device = torch.device("cuda", torch.cuda.current_device()) - torch.cuda.set_device(device) - - pg = torch.distributed.new_group(ranks=list(range(world_size)), backend="nccl") - - nccl_allocator.init() - - pool = nccl_allocator.create_nccl_mem_pool() - target_tensor_numel = 1000000 - with nccl_allocator.nccl_mem(pool, group=pg): - tensor_shard = torch.ones([target_tensor_numel // world_size], device=device) - tensor_unshard = torch.ones([target_tensor_numel], device=device) - - torch.distributed.all_gather_into_tensor(tensor_unshard, tensor_shard, group=pg) - torch.cuda.synchronize(device=device) diff --git a/tests/unit_tests/test_num_microbatches_calculator.py b/tests/unit_tests/test_num_microbatches_calculator.py deleted file mode 100644 index 9b3356b8af..0000000000 --- a/tests/unit_tests/test_num_microbatches_calculator.py +++ /dev/null @@ -1,147 +0,0 @@ -from typing import List, Optional - -import pytest - -import megatron.core.num_microbatches_calculator as mb_calculator - - -def test_init_num_microbatches_calculator(): - mb_calculator._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None - mb_calculator.init_num_microbatches_calculator(0, None, 32, 8, 2, False) - assert mb_calculator.get_num_microbatches() == 2 - assert mb_calculator.get_current_global_batch_size() == 32 - - with pytest.raises(AssertionError): - mb_calculator.init_num_microbatches_calculator(0, None, 32, 8, 2, False) - - mb_calculator._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None - mb_calculator.init_num_microbatches_calculator(0, None, 32, 8, 3, True) - assert mb_calculator.get_num_microbatches() == 1 - assert mb_calculator.get_current_global_batch_size() == 32 - assert mb_calculator.get_current_running_global_batch_size() == 24 - - mb_calculator._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None - mb_calculator.init_num_microbatches_calculator(0, None, 33, 8, 2, True) - assert mb_calculator.get_num_microbatches() == 2 - assert mb_calculator.get_current_global_batch_size() == 33 - assert mb_calculator.get_current_running_global_batch_size() == 32 - - -def test_reconfigure_num_microbatches_calculator(): - mb_calculator._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None - mb_calculator.init_num_microbatches_calculator(0, None, 32, 8, 2, False) - assert mb_calculator.get_num_microbatches() == 2 - assert mb_calculator.get_current_global_batch_size() == 32 - - mb_calculator.reconfigure_num_microbatches_calculator(0, None, 16, 8, 2, False) - assert mb_calculator.get_num_microbatches() == 1 - assert mb_calculator.get_current_global_batch_size() == 16 - - mb_calculator.reconfigure_num_microbatches_calculator(0, [16, 16, 96], 32, 8, 2, False) - assert mb_calculator.get_num_microbatches() == 1 - assert mb_calculator.get_current_global_batch_size() == 16 - - -def test_get_num_microbatches(): - mb_calculator.reconfigure_num_microbatches_calculator(0, None, 16, 8, 2, False) - assert mb_calculator.get_num_microbatches() == 1 - - mb_calculator.reconfigure_num_microbatches_calculator(0, None, 16, 4, 3, True) - assert mb_calculator.get_num_microbatches() == 1 - - -def test_get_current_global_batch_size(): - mb_calculator.reconfigure_num_microbatches_calculator(0, None, 16, 4, 2, False) - assert mb_calculator.get_current_global_batch_size() == 16 - - mb_calculator.reconfigure_num_microbatches_calculator(0, None, 16, 4, 3, True) - assert mb_calculator.get_current_global_batch_size() == 16 - assert mb_calculator.get_current_running_global_batch_size() == 12 - - -def test_get_micro_batch_size(): - mb_calculator.reconfigure_num_microbatches_calculator(0, None, 16, 8, 2, False) - assert mb_calculator.get_micro_batch_size() == 8 - - -def test_update_num_microbatches(): - mb_calculator.reconfigure_num_microbatches_calculator(0, [16, 8, 96], 32, 4, 2, False) - assert mb_calculator.get_num_microbatches() == 2 - mb_calculator.update_num_microbatches(48, False) - assert mb_calculator.get_num_microbatches() == 3 - - mb_calculator.reconfigure_num_microbatches_calculator(0, [16, 8, 96], 32, 8, 2, False) - with pytest.raises(AssertionError): - mb_calculator.update_num_microbatches(49, True) - - mb_calculator.reconfigure_num_microbatches_calculator(0, None, 32, 8, 2, False) - mb_calculator.update_num_microbatches(16) - assert mb_calculator.get_num_microbatches() == 2 - - -def test_build_num_microbatches_calculator(): - temp_calculator = mb_calculator._build_num_microbatches_calculator(0, None, 32, 8, 2, False) - assert temp_calculator.get() == 2 - assert temp_calculator.get_current_global_batch_size() == 32 - assert type(temp_calculator) is mb_calculator.ConstantNumMicroBatchesCalculator - - temp_calculator = mb_calculator._build_num_microbatches_calculator( - 0, [16, 16, 48], 32, 8, 2, False - ) - assert temp_calculator.get() == 1 - assert temp_calculator.get_current_global_batch_size() == 16 - assert type(temp_calculator) is mb_calculator.RampupBatchsizeNumMicroBatchesCalculator - - -class TestConstantNumMicroBatchesCalculator: - def setup_method(self, method): - self.mb_calculator = mb_calculator.ConstantNumMicroBatchesCalculator(32, 8, 2, False, 0) - - def test_constructor(self): - assert type(self.mb_calculator) is mb_calculator.ConstantNumMicroBatchesCalculator - assert self.mb_calculator.num_micro_batches == 2 - assert self.mb_calculator.current_global_batch_size == 32 - assert self.mb_calculator.micro_batch_size == 8 - - def test_get(self): - assert self.mb_calculator.get() == 2 - - def test_get_current_global_batch_size(self): - assert self.mb_calculator.get_current_global_batch_size() == 32 - - -class TestRampupBatchsizeNumMicroBatchesCalculator: - def setup_method(self, method): - self.mb_calculator = mb_calculator.RampupBatchsizeNumMicroBatchesCalculator( - 32, 8, 2, False, 0, 16, 16, 48 - ) - - def test_constructor(self): - assert type(self.mb_calculator) is mb_calculator.RampupBatchsizeNumMicroBatchesCalculator - assert self.mb_calculator.global_batch_size == 32 - assert self.mb_calculator.micro_batch_size == 8 - assert self.mb_calculator.data_parallel_size == 2 - assert self.mb_calculator.start_global_batch_size == 16 - assert self.mb_calculator.batch_size_increment == 16 - assert self.mb_calculator.ramup_samples == 48 - assert self.mb_calculator.micro_batch_times_data_parallel_size == 16 - assert self.mb_calculator.num_micro_batches == 1 - - def test_get(self): - assert self.mb_calculator.get() == 1 - - def test_get_current_global_batch_size(self): - assert self.mb_calculator.get_current_global_batch_size() == 16 - - -def test_ramp_up(): - mb_calculator.reconfigure_num_microbatches_calculator(0, [16, 16, 96], 32, 8, 2, False) - consumed_samples = 0 - count = 0 - expected_consumed_samples = [0, 16, 32, 48, 64, 80, 96, 128, 160, 192, 224, 256] - - while consumed_samples < 256: - consumed_samples += mb_calculator.get_current_global_batch_size() - count += 1 - assert consumed_samples == expected_consumed_samples[count] - mb_calculator.update_num_microbatches(consumed_samples, True) diff --git a/tests/unit_tests/test_optimizer.py b/tests/unit_tests/test_optimizer.py deleted file mode 100644 index 26e6b30eb6..0000000000 --- a/tests/unit_tests/test_optimizer.py +++ /dev/null @@ -1,612 +0,0 @@ -import os - -import pytest -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.optim import SGD, Adam - -# FP8 recipe will be used to test precision-aware-optimizer. -from transformer_engine.pytorch.fp8 import fp8_autocast - -from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig -from megatron.core.optimizer import ChainedOptimizer, OptimizerConfig, get_megatron_optimizer -from megatron.core.process_groups_config import GradCommProcessGroups, ModelCommProcessGroups -from megatron.core.transformer import TransformerConfig -from megatron.core.utils import is_te_min_version, is_torch_min_version -from tests.unit_tests.test_utilities import Utils -from tests.unit_tests.test_utils import _deinit_distributed, _init_distributed - -try: - # Check if FP8 block scaling is available. - from transformer_engine.pytorch.fp8 import check_fp8_block_scaling_support - - fp8_block_scaling_available, reason_for_no_fp8_block_scaling = check_fp8_block_scaling_support() - from transformer_engine.common.recipe import Float8BlockScaling, Format -except: - fp8_block_scaling_available = False - reason_for_no_fp8_block_scaling = "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9." - try: - from transformer_engine.common.recipe import DelayedScaling - except: - delayed_scaling_available = False - - -class Net(nn.Module): - def __init__(self): - super().__init__() - self.conv1 = nn.Conv2d(3, 6, 5) - self.pool = nn.MaxPool2d(2, 2) - self.conv2 = nn.Conv2d(6, 16, 5) - self.fc1 = nn.Linear(16 * 5 * 5, 120) - self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, 10) - - def forward(self, x): - x = self.pool(F.relu(self.conv1(x))) - x = self.pool(F.relu(self.conv2(x))) - x = torch.flatten(x, 1) # flatten all dimensions except batch - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - x = self.fc3(x) - return x - - -def test_chained_optimizer(): - net = Net() - optimizer_1 = Adam(list(net.parameters())[:2], lr=0.01) - optimizer_2 = SGD(list(net.parameters())[2:], lr=0.1, momentum=0.9) - chained_optimizer = ChainedOptimizer([optimizer_1, optimizer_2]) - - # Test the chained optimizer's param groups is a reference of the underlying optimizers' param groups - assert optimizer_1.param_groups[0]["lr"] == 0.01 - chained_optimizer.param_groups[0]["lr"] = 0.02 - assert optimizer_1.param_groups[0]["lr"] == 0.02 - - # Test the chained optimizer's state is a reference of the underlying optimizers' state - # 1. run step on optimizers, make sure there is state - assert len(chained_optimizer.state) == 0 - input = torch.randn(1, 3, 32, 32) - output = net(input) - output.sum().backward() - optimizer_1.step() - optimizer_2.step() - assert len(chained_optimizer.state) != 0 - - # 2. check the state is a reference - assert not list(optimizer_1.state.values())[0]["exp_avg"].is_cuda - assert not list(optimizer_2.state.values())[0]["momentum_buffer"].is_cuda - - def to_cuda(d): - for k, v in d.items(): - if isinstance(v, torch.Tensor): - d[k] = v.to("cuda") - elif isinstance(v, dict): - to_cuda(v) - return d - - for k, v in chained_optimizer.state.items(): - chained_optimizer.state[k] = to_cuda(v) - - assert list(optimizer_1.state.values())[0]["exp_avg"].is_cuda - assert list(optimizer_2.state.values())[0]["momentum_buffer"].is_cuda - - -def test_precision_aware_fused_adam(): - try: - from transformer_engine.pytorch.optimizers import FusedAdam - except ImportError: - # Older versions of TE don't have FusedAdam. - return - - import inspect - - adam_args = inspect.signature(FusedAdam).parameters - arg_names = ["master_weight_dtype", "exp_avg_dtype", "exp_avg_sq_dtype", "use_decoupled_grad"] - for name in arg_names: - if name not in adam_args: - # Skip the test if TE doesn't support precision aware FusedAdam. - return - - tensor = torch.rand(278011, dtype=torch.bfloat16).cuda() - params_1 = [torch.nn.Parameter(tensor.float())] # FP32 reference - params_2 = [torch.nn.Parameter(tensor.clone())] # BF16 - - options = {"lr": 1, "betas": (0.1, 0.25), "eps": 1e-08, "weight_decay": 0, "amsgrad": False} - - optimizer_1 = FusedAdam(params_1, **options) - optimizer_2 = FusedAdam(params_2, master_weights=True, use_decoupled_grad=True, **options) - - for _ in range(1000): - for p_1, p_2 in zip(params_1, params_2): - p_1.grad = torch.rand_like(p_1) - p_2.decoupled_grad = p_1.grad.clone() - - optimizer_1.step() - optimizer_2.step() - - master_params = [optimizer_2.get_unscaled_state(p, "master_param") for p in params_2] - for p_1, p_2 in zip(params_1, master_params): - bytes_1 = p_1.data.view(torch.uint8) - bytes_2 = p_2.data.view(torch.uint8) - # Make sure bit-wise matched - assert torch.all(bytes_1 == bytes_2) - - for p_1, p_2 in zip(params_1, params_2): - bytes_1 = p_1.data.bfloat16().view(torch.uint8) - bytes_2 = p_2.data.view(torch.uint8) - # Make sure bit-wise matched - assert torch.all(bytes_1 == bytes_2) - - -@pytest.mark.skipif( - not is_te_min_version("1.13.0"), reason="TE 1.13.0 is required for precision aware optimizer" -) -@pytest.mark.parametrize("precision", ['bf16', 'fp8']) -@pytest.mark.parametrize("main_params_dtype", [torch.float32, torch.float16]) -@pytest.mark.parametrize("main_grads_dtype", [torch.float32, torch.bfloat16]) -@pytest.mark.parametrize( - # use the same dtype for exp_avg and exp_avg_sq to reduce the number of tests - "moment_dtype", - [torch.float32, torch.float16, torch.bfloat16, torch.uint8], -) -def test_precision_aware_optimizer( - precision: str, - main_params_dtype: torch.dtype, - main_grads_dtype: torch.dtype, - moment_dtype: torch.dtype, -): - # Skip because bf16 optimizer states are not supported before TE 2.3.0 - if (moment_dtype == torch.bfloat16) and not is_te_min_version("2.3.0"): - pytest.skip("bfloat16 for moment_dtype requires TE >= 2.3.0") - - if precision == 'fp8': - if not fp8_block_scaling_available: - fp8_recipe = "delayed" - fp8_recipe_settings = DelayedScaling() - else: - fp8_recipe = "blockwise" - fp8_recipe_settings = Float8BlockScaling(fp8_format=Format.E4M3) - else: - fp8_recipe = None - fp8_recipe_settings = None - - world = int(os.getenv('WORLD_SIZE', '1')) - rank = int(os.getenv('RANK', '0')) - - # Setup: distributed, model, mock_args. - _init_distributed(world, rank) - Utils.initialize_model_parallel() - - # First create baseline model with float32 optimizer states - baseline_model = torch.nn.Linear(100, 100, bias=False, dtype=torch.bfloat16, device='cuda') - baseline_model.requires_grad_(True) - baseline_model.weight.data.fill_(1.0) - baseline_ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=True) - baseline_model = DistributedDataParallel( - TransformerConfig(num_attention_heads=1, num_layers=1), baseline_ddp_config, baseline_model - ) - baseline_optimizer_config = OptimizerConfig( - optimizer='adam', - lr=0.01, - bf16=True, - use_distributed_optimizer=True, - use_precision_aware_optimizer=False, - main_params_dtype=torch.float32, - main_grads_dtype=torch.float32, - exp_avg_dtype=torch.float32, - exp_avg_sq_dtype=torch.float32, - ) - baseline_optim = get_megatron_optimizer(baseline_optimizer_config, [baseline_model]) - - # Create test model with specified dtypes for optimizer states - test_model = torch.nn.Linear(100, 100, bias=False, dtype=torch.bfloat16, device='cuda') - test_model.requires_grad_(True) - test_model.weight.data.fill_(1.0) - ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=True) - test_model = DistributedDataParallel( - TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, test_model - ) - test_optimizer_config = OptimizerConfig( - optimizer='adam', - lr=0.01, - bf16=True, - fp8_recipe=fp8_recipe, - use_distributed_optimizer=True, - use_precision_aware_optimizer=True, - main_params_dtype=main_params_dtype, - main_grads_dtype=main_grads_dtype, - exp_avg_dtype=moment_dtype, - exp_avg_sq_dtype=moment_dtype, - ) - test_optim = get_megatron_optimizer(test_optimizer_config, [test_model]) - - # Use same input for both models - input = torch.randn(8, 100, dtype=torch.bfloat16, device='cuda') - - # Run model - def run_model(model, input, optim, fp8_recipe, fp8_recipe_settings): - if not fp8_recipe: - output = model(input) - else: - with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe_settings): - output = model(input) - loss = output.sum() - loss.backward() - optim.step() - return loss.item(), optim.get_grad_norm() - - # Run baseline model and test model - baseline_loss, baseline_grad_norm = run_model( - baseline_model, input, baseline_optim, fp8_recipe, fp8_recipe_settings - ) - test_loss, test_grad_norm = run_model( - test_model, input, test_optim, fp8_recipe, fp8_recipe_settings - ) - - rtol = 1e-3 # relative tolerance - atol = 1e-5 # absolute tolerance - - # Compare grad norms - allow small difference due to precision - rel_diff = abs(test_grad_norm - baseline_grad_norm) / ( - abs(baseline_grad_norm) + 1e-7 # avoid div by 0 - ) - abs_diff = abs(test_grad_norm - baseline_grad_norm) - assert ( - rel_diff <= rtol or abs_diff <= atol - ), f"Grad norm mismatch: baseline={baseline_grad_norm}, test={test_grad_norm}, rel_diff={rel_diff}, abs_diff={abs_diff}" - - # Compare losses - allow small difference due to precision - loss_rel_diff = abs(test_loss - baseline_loss) / (abs(baseline_loss) + 1e-7) - loss_abs_diff = abs(test_loss - baseline_loss) - assert ( - loss_rel_diff <= rtol or loss_abs_diff <= atol - ), f"Loss mismatch: baseline={baseline_loss}, test={test_loss}, rel_diff={loss_rel_diff}, abs_diff={loss_abs_diff}" - - # Save and reload state dict for the test model - state_dict = test_optim.state_dict() - test_optim.load_state_dict(state_dict) - - -@pytest.mark.parametrize("use_distributed_optimizer", [False, True]) -@pytest.mark.parametrize("precision", ['bf16', 'fp32']) -def test_optim_sharded_state_dict(use_distributed_optimizer: bool, precision: str): - world = int(os.getenv('WORLD_SIZE', '1')) - rank = int(os.getenv('RANK', '0')) - - # Setup: distributed, model, mock_args. - _init_distributed(world, rank) - Utils.initialize_model_parallel() - model = torch.nn.Linear(100, 100, bias=False, dtype=torch.bfloat16, device='cuda') - model.requires_grad_(True) - model.weight.data.fill_(1.0) - ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=use_distributed_optimizer) - model = DistributedDataParallel( - TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, model - ) - for param in model.parameters(): - assert param.requires_grad - - if precision == 'bf16': - optimizer_config = OptimizerConfig( - optimizer='adam', bf16=True, use_distributed_optimizer=use_distributed_optimizer - ) - elif precision == 'fp32': - optimizer_config = OptimizerConfig( - optimizer='adam', - bf16=False, - fp16=False, - use_distributed_optimizer=use_distributed_optimizer, - ) - optim = get_megatron_optimizer(optimizer_config, [model]) - - model_sharded_state_dict = model.sharded_state_dict() - sharded_state_dict = optim.sharded_state_dict(model_sharded_state_dict) - - if 'optimizer' in sharded_state_dict and 'state' in sharded_state_dict['optimizer']: - assert ( - 'common_step' not in sharded_state_dict['optimizer']['state'] - or sharded_state_dict['optimizer']['state']['common_step'] is not None - ), "Found 'optimizer.state.common_step=None' in sharded state dict." - - -def test_optimizer_reload_model_params(): - world = int(os.getenv('WORLD_SIZE', '1')) - rank = int(os.getenv('RANK', '0')) - _init_distributed(world, rank) - Utils.initialize_model_parallel() - - model = Net().bfloat16().cuda() - # Initial values of model params are 1. - for param in model.parameters(): - param.data.fill_(1.0) - ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=True) - model = DistributedDataParallel( - TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, model - ) - optimizer_config = OptimizerConfig(optimizer='adam', bf16=True, use_distributed_optimizer=True) - optim = get_megatron_optimizer(optimizer_config, [model]) - - # Set all model params to 2. - for param in model.parameters(): - param.data.fill_(2.0) - - # Although model params are 2 now, but we haven't called reload_model_params() yet, so - # main_params should be 1. - for group in optim.param_groups: - for main_param in group['params']: - assert main_param.dtype == torch.float32 - torch.testing.assert_close( - main_param, torch.empty_like(main_param).fill_(1.0), atol=0, rtol=0 - ) - - # Copy model params to main_params, so main_params should be 2 now. - optim.reload_model_params() - for group in optim.param_groups: - for main_param in group['params']: - assert main_param.dtype == torch.float32 - torch.testing.assert_close( - main_param, torch.empty_like(main_param).fill_(2.0), atol=0, rtol=0 - ) - - # Create a new state_dict with all params set to 3. - state_dict = model.state_dict() - new_state_dict = {} - for name, param in state_dict.items(): - new_state_dict[name] = torch.empty_like(param).fill_(3.0) - - # Initialize main_params with the new state_dict, so main_params should be 3 now, but model - # params should still be 2. - optim.reload_model_params(new_state_dict) - for param in model.parameters(): - torch.testing.assert_close(param, torch.empty_like(param).fill_(2.0), atol=0, rtol=0) - for group in optim.param_groups: - for main_param in group['params']: - assert main_param.dtype == torch.float32 - torch.testing.assert_close( - main_param, torch.empty_like(main_param).fill_(3.0), atol=0, rtol=0 - ) - - -@pytest.mark.skipif( - not is_torch_min_version("2.4.0"), - reason="torch.distributed.init_device_mesh requires torch >= 2.4.0", -) -@pytest.mark.parametrize( - "world_size, tp_size, cp_size, dp_size", - [ - (1, 1, 1, 1), # Single GPU, no parallelism - (2, 1, 2, 1), # 2 GPUs, 1 TP, 2 CP - (2, 2, 1, 1), # 2 GPUs, 2 TP, 1 CP - (8, 8, 1, 1), # 8 GPUs, 8 TP, 1 CP - (8, 2, 4, 1), # 8 GPUs, 2 TP, 4 CP - (8, 4, 2, 1), # 8 GPUs, 4 TP, 2 CP - (8, 1, 1, 8), # 8 GPUs, 1 TP, 1 CP, 8 DP - (8, 2, 1, 4), # 8 GPUs, 2 TP, 1 CP, 4 DP - (8, 2, 2, 2), # 8 GPUs, 2 TP, 2 CP, 2 DP - ], -) -def test_get_megatron_optimizer_with_custom_process_groups(world_size, tp_size, cp_size, dp_size): - """ - Test that get_megatron_optimizer works correctly with custom process groups - provided via grad_comm_pgs and model_comm_pgs parameters. - """ - # Skip if world size doesn't match available GPUs - actual_world_size = torch.cuda.device_count() - if actual_world_size != world_size: - pytest.skip(f"Test requires world_size={world_size}, but got {actual_world_size}") - - # Initialize model parallel with default settings first - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp_size, context_parallel_size=cp_size - ) - - # Create device mesh for custom process groups - device_mesh = torch.distributed.init_device_mesh( - "cuda", (1, dp_size, 1, cp_size, tp_size), mesh_dim_names=("pp", "dp", "ep", "cp", "tp") - ) - - # Create custom process groups from device mesh - dp_group = device_mesh.get_group(mesh_dim="dp") - cp_group = device_mesh.get_group(mesh_dim="cp") - tp_group = device_mesh.get_group(mesh_dim="tp") - pp_group = device_mesh.get_group(mesh_dim="pp") - - # Create dp_cp group - dp_cp_mesh = device_mesh["dp", "cp"] - dp_cp_group = dp_cp_mesh._flatten().get_group() - - # Create model parallel group (tp + pp) - mp_mesh = device_mesh["pp", "tp"] - mp_group = mp_mesh._flatten().get_group() - - # Create process group configurations - grad_comm_pgs = GradCommProcessGroups() - grad_comm_pgs.dp = dp_group - grad_comm_pgs.dp_cp = dp_cp_group - grad_comm_pgs.expt_dp = None # Not using expert parallelism in this test - - model_comm_pgs = ModelCommProcessGroups() - model_comm_pgs.tp = tp_group - model_comm_pgs.cp = cp_group - model_comm_pgs.pp = pp_group - model_comm_pgs.mp = mp_group - model_comm_pgs.tp_ep_pp = None # Not using expert parallelism in this test - - # Create a simple model for testing - model = torch.nn.Linear(100, 100, bias=False, device='cuda') - model.requires_grad_(True) - model.weight.data.fill_(1.0) - ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=True) - model = DistributedDataParallel( - TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, model - ) - for param in model.parameters(): - assert param.requires_grad - model_chunks = [model] - - # Create optimizer config - optimizer_config = OptimizerConfig( - optimizer='adam', - lr=0.001, - weight_decay=0.01, - adam_beta1=0.9, - adam_beta2=0.999, - adam_eps=1e-8, - ) - - # Test 1: Create optimizer with custom process groups - optimizer = get_megatron_optimizer( - config=optimizer_config, - model_chunks=model_chunks, - use_gloo_process_groups=False, # Required when using custom process groups - grad_comm_pgs=grad_comm_pgs, - model_comm_pgs=model_comm_pgs, - ) - - # Verify optimizer was created successfully - assert optimizer is not None, "Optimizer should not be None" - assert hasattr(optimizer, 'param_groups'), "Optimizer should have param_groups" - assert len(optimizer.param_groups) > 0, "Optimizer should have at least one parameter group" - - # Test 2: Verify optimizer can perform forward and backward pass - input_tensor = torch.randn(32, 100, device='cuda', requires_grad=True) - output = model(input_tensor) - loss = output.sum() - loss.backward() - - # Test 3: Optimizer step should work - optimizer.zero_grad() - output = model(input_tensor) - loss = output.sum() - loss.backward() - - # Store original parameters - original_weight = model.module.weight.data.clone() - original_bias = model.module.bias.data.clone() if model.module.bias is not None else None - - # Perform optimizer step - optimizer.step() - - # Verify parameters were updated - assert not torch.equal( - model.module.weight.data, original_weight - ), "Weight should be updated after optimizer step" - if model.module.bias is not None: - assert not torch.equal( - model.module.bias.data, original_bias - ), "Bias should be updated after optimizer step" - - # Test 4: Compare with default process groups optimizer (if world_size allows) - if world_size == 1: # Only test on single GPU to avoid complex setup - # Create optimizer with default process groups - default_optimizer = get_megatron_optimizer( - config=optimizer_config, model_chunks=model_chunks - ) - - # Both optimizers should have the same structure - assert len(optimizer.param_groups) == len( - default_optimizer.param_groups - ), "Custom and default optimizers should have same number of parameter groups" - - -def test_get_megatron_optimizer_custom_process_groups_validation(): - """ - Test validation logic for custom process groups in get_megatron_optimizer. - """ - Utils.initialize_model_parallel(tensor_model_parallel_size=1) - - # Create a simple model - model = torch.nn.Linear(100, 100, bias=False, device='cuda') - model.requires_grad_(True) - model.weight.data.fill_(1.0) - ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=True) - model = DistributedDataParallel( - TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, model - ) - for param in model.parameters(): - assert param.requires_grad - model_chunks = [model] - optimizer_config = OptimizerConfig(optimizer='adam', lr=0.001) - - # Test 1: Both grad_comm_pgs and model_comm_pgs must be provided together - grad_comm_pgs = GradCommProcessGroups() - grad_comm_pgs.dp = torch.distributed.new_group() - - with pytest.raises( - ValueError, match="Grad and model comm process groups must be provided or both must be None" - ): - get_megatron_optimizer( - config=optimizer_config, - model_chunks=model_chunks, - grad_comm_pgs=grad_comm_pgs, - model_comm_pgs=None, # Missing model_comm_pgs - ) - - # Test 2: Missing dp process group in grad_comm_pgs - grad_comm_pgs_no_dp = GradCommProcessGroups() - # Missing required 'dp' group - model_comm_pgs = ModelCommProcessGroups() - - with pytest.raises(ValueError, match="dp process group is required"): - get_megatron_optimizer( - config=optimizer_config, - model_chunks=model_chunks, - grad_comm_pgs=grad_comm_pgs_no_dp, - model_comm_pgs=model_comm_pgs, - ) - - # Test 3: Missing expt_dp attribute in grad_comm_pgs - grad_comm_pgs_no_expt_dp = GradCommProcessGroups() - grad_comm_pgs_no_expt_dp.dp = torch.distributed.new_group() - # Missing required 'expt_dp' attribute - - with pytest.raises(AssertionError, match="expt_dp process group is required"): - get_megatron_optimizer( - config=optimizer_config, - model_chunks=model_chunks, - grad_comm_pgs=grad_comm_pgs_no_expt_dp, - model_comm_pgs=model_comm_pgs, - ) - - # Test 4: Missing mp attribute in model_comm_pgs - grad_comm_pgs_complete = GradCommProcessGroups() - grad_comm_pgs_complete.dp = torch.distributed.new_group() - grad_comm_pgs_complete.expt_dp = None # Explicitly set to None as allowed - model_comm_pgs_no_mp = ModelCommProcessGroups() - # Missing required 'mp' attribute - - with pytest.raises(AssertionError, match="mp process group is required"): - get_megatron_optimizer( - config=optimizer_config, - model_chunks=model_chunks, - grad_comm_pgs=grad_comm_pgs_complete, - model_comm_pgs=model_comm_pgs_no_mp, - ) - - # Test 5: Missing tp_ep_pp attribute in model_comm_pgs - model_comm_pgs_no_tp_ep_pp = ModelCommProcessGroups() - model_comm_pgs_no_tp_ep_pp.mp = None # Explicitly set to None as allowed - # Missing required 'tp_ep_pp' attribute - - with pytest.raises(AssertionError, match="tp_ep_pp process group is required"): - get_megatron_optimizer( - config=optimizer_config, - model_chunks=model_chunks, - grad_comm_pgs=grad_comm_pgs_complete, - model_comm_pgs=model_comm_pgs_no_tp_ep_pp, - ) - - # Test 6: Gloo process groups should not be used with custom process groups - model_comm_pgs_complete = ModelCommProcessGroups() - model_comm_pgs_complete.mp = None # Explicitly set to None as allowed - model_comm_pgs_complete.tp_ep_pp = None # Explicitly set to None as allowed - - with pytest.raises(AssertionError, match="Gloo process groups are not supported"): - get_megatron_optimizer( - config=optimizer_config, - model_chunks=model_chunks, - use_gloo_process_groups=True, # Should be False when using custom groups - grad_comm_pgs=grad_comm_pgs_complete, - model_comm_pgs=model_comm_pgs_complete, - ) diff --git a/tests/unit_tests/test_optimizer_cpu_offloading.py b/tests/unit_tests/test_optimizer_cpu_offloading.py deleted file mode 100644 index 1c367100da..0000000000 --- a/tests/unit_tests/test_optimizer_cpu_offloading.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -import random - -import numpy as np -import pytest -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.optim import SGD, Adam - -try: - from transformer_engine.pytorch.optimizers import FusedAdam as GPUAdam - from transformer_engine.pytorch.optimizers import FusedSGD as GPUSGD -except: - # Handle environment where transformer_engine is not installed - from torch.optim import SGD as GPUSGD - from torch.optim import Adam as GPUAdam - -from megatron.core.optimizer.cpu_offloading import HybridDeviceOptimizer - - -class Net(nn.Module): - def __init__(self): - super().__init__() - self.conv1 = nn.Conv2d(3, 6, 5) - self.pool = nn.MaxPool2d(2, 2) - self.conv2 = nn.Conv2d(6, 16, 5) - self.fc1 = nn.Linear(16 * 5 * 5, 120) - self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, 10) - - def forward(self, x): - x = self.pool(F.relu(self.conv1(x))) - x = self.pool(F.relu(self.conv2(x))) - x = torch.flatten(x, 1) # flatten all dimensions except batch - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - x = self.fc3(x) - return x - - -def setup_seed(seed): - random.seed(seed) # Set Python's built-in random seed - np.random.seed(seed) # Set NumPy's random seed - torch.manual_seed(seed) # Set PyTorch's CPU seed - torch.cuda.manual_seed(seed) # Set PyTorch's GPU seed (if using CUDA) - torch.cuda.manual_seed_all(seed) # Set seed for all GPUs - torch.backends.cudnn.deterministic = True # Ensure deterministic behavior - torch.backends.cudnn.benchmark = False # Disable auto-tuner for reproducibility - - -@pytest.mark.skipif( - torch.__version__ < '2.3.0', - reason=( - "Requires PyTorch 2.3.0 or higher, lower versions of pytorch have " - "misaligned optimizer accuracy for CPU and GPU." - ), -) -@pytest.mark.parametrize('n_steps', [1, 10]) -@pytest.mark.parametrize('overlap_cpu_optimizer_d2h_h2d', [False, True]) -@pytest.mark.parametrize('offload_fraction', [0, 0.5, 1.0]) -@pytest.mark.parametrize('optimizer', ['sgd', 'adam']) -@pytest.mark.parametrize('with_param_groups', [False, True]) -def test_multi_device_hybrid_optimizer( - with_param_groups, optimizer, offload_fraction, overlap_cpu_optimizer_d2h_h2d, n_steps -): - setup_seed(42) - net1 = Net().cuda() - net2 = Net().cuda() - net2.load_state_dict(net1.state_dict()) - base_lr = 1e-3 - params = list(net1.parameters()) - ref_params = list(net2.parameters()) - if with_param_groups: - param_groups = [ - {"params": params[: len(params) // 2], "wd_mult": 1.0, "lr_mult": 1e-4}, - {"params": params[len(params) // 2 :], "wd_mult": 0.0, "lr_mult": 2e-4}, - ] - params = param_groups - ref_param_groups = [ - {"params": ref_params[: len(ref_params) // 2], "wd_mult": 1.0, "lr_mult": 1e-4}, - {"params": ref_params[len(ref_params) // 2 :], "wd_mult": 0.0, "lr_mult": 2e-4}, - ] - ref_params = ref_param_groups - - if optimizer == 'adam': - cls_kwargs = dict(cpu_optimizer_cls=Adam, gpu_optimizer_cls=GPUAdam) - else: - cls_kwargs = dict(cpu_optimizer_cls=SGD, gpu_optimizer_cls=GPUSGD) - - hdo = HybridDeviceOptimizer( - params, - offload_fraction=offload_fraction, - lr=base_lr, - overlap_cpu_optimizer_d2h_h2d=overlap_cpu_optimizer_d2h_h2d, - **cls_kwargs, - ) - - ref_optimizer = cls_kwargs['gpu_optimizer_cls'](ref_params, lr=base_lr) - - # 1. run step on optimizer, make sure there is state generated - assert len(hdo.state_dict()["state"]) == 0 # state is empty - input = torch.randn(1, 3, 32, 32).cuda() - output = net1(input) - output.sum().backward() - hdo.step() - output = net2(input) - output.sum().backward() - ref_optimizer.step() - # PyTorch SGD will not generate state - if optimizer != 'sgd': - assert len(hdo.state_dict()["state"]) != 0 - - # 2. check the state is on right device - if optimizer == 'adam': - first_param_id = hdo.state_dict()["param_groups"][0]["params"][0] - last_param_id = hdo.state_dict()["param_groups"][-1]["params"][-1] - if offload_fraction > 0: - assert not hdo.state_dict()["state"][first_param_id]["exp_avg"].is_cuda - if offload_fraction < 1: - assert hdo.state_dict()["state"][last_param_id]["exp_avg"].is_cuda - - # 3. check parameters allclose - for _ in range(1, n_steps): - input = torch.randn(1, 3, 32, 32).cuda() - output = net1(input) - output.sum().backward() - hdo.step() - output = net2(input) - output.sum().backward() - ref_optimizer.step() - - params = net1.state_dict() - ref_params = net2.state_dict() - for k, v in params.items(): - assert (v.isnan() == ref_params[k].isnan()).all() - torch.nan_to_num_(v, 0) - torch.nan_to_num_(ref_params[k], 0) - assert torch.allclose( - v, ref_params[k], atol=1e-03 - ), f"Weight {k} value mismatch, max error: {(v - ref_params[k]).abs().max()}" diff --git a/tests/unit_tests/test_optimizer_param_scheduler.py b/tests/unit_tests/test_optimizer_param_scheduler.py deleted file mode 100644 index 9b78169454..0000000000 --- a/tests/unit_tests/test_optimizer_param_scheduler.py +++ /dev/null @@ -1,251 +0,0 @@ -import math -from unittest.mock import MagicMock - -import pytest - -from megatron.core.optimizer_param_scheduler import ( # Adjust import according to your module path - OptimizerParamScheduler, -) - - -@pytest.fixture -def mock_optimizer(): - optimizer = MagicMock() - optimizer.param_groups = [{'lr': 0.0, 'weight_decay': 0.0}] - return optimizer - - -def test_initialization(mock_optimizer): - scheduler = OptimizerParamScheduler( - optimizer=mock_optimizer, - init_lr=0.01, - max_lr=0.1, - min_lr=0.001, - lr_warmup_steps=100, - lr_decay_steps=1000, - lr_decay_style='linear', - start_wd=0.0, - end_wd=0.1, - wd_incr_steps=1000, - wd_incr_style='linear', - ) - - assert scheduler.init_lr == 0.01 - assert scheduler.max_lr == 0.1 - assert scheduler.min_lr == 0.001 - assert scheduler.lr_warmup_steps == 100 - assert scheduler.lr_decay_steps == 1000 - assert scheduler.lr_decay_style == 'linear' - assert scheduler.start_wd == 0.0 - assert scheduler.end_wd == 0.1 - assert scheduler.wd_incr_steps == 1000 - assert scheduler.wd_incr_style == 'linear' - - -def test_get_wd_constant(mock_optimizer): - scheduler = OptimizerParamScheduler( - optimizer=mock_optimizer, - init_lr=0.01, - max_lr=0.1, - min_lr=0.001, - lr_warmup_steps=100, - lr_decay_steps=1000, - lr_decay_style='linear', - start_wd=0.1, - end_wd=0.1, - wd_incr_steps=1000, - wd_incr_style='constant', - ) - - scheduler.step(500) - wd = scheduler.get_wd() - assert wd == 0.1 - - -def test_get_wd_linear(mock_optimizer): - scheduler = OptimizerParamScheduler( - optimizer=mock_optimizer, - init_lr=0.01, - max_lr=0.1, - min_lr=0.001, - lr_warmup_steps=100, - lr_decay_steps=1000, - lr_decay_style='linear', - start_wd=0.0, - end_wd=0.1, - wd_incr_steps=1000, - wd_incr_style='linear', - ) - - scheduler.step(500) - wd = scheduler.get_wd() - assert wd == 0.05 - - -def test_get_wd_cosine(mock_optimizer): - scheduler = OptimizerParamScheduler( - optimizer=mock_optimizer, - init_lr=0.01, - max_lr=0.1, - min_lr=0.001, - lr_warmup_steps=100, - lr_decay_steps=1000, - lr_decay_style='cosine', - start_wd=0.0, - end_wd=0.1, - wd_incr_steps=1000, - wd_incr_style='cosine', - ) - - scheduler.step(500) - wd = scheduler.get_wd() - expected_wd = 0.05 * (math.cos(math.pi * (1 - 0.5)) + 1.0) - assert math.isclose(wd, expected_wd, rel_tol=1e-5) - - -def test_get_lr_linear(mock_optimizer): - scheduler = OptimizerParamScheduler( - optimizer=mock_optimizer, - init_lr=0.01, - max_lr=0.1, - min_lr=0.001, - lr_warmup_steps=100, - lr_decay_steps=1000, - lr_decay_style='linear', - start_wd=0.0, - end_wd=0.1, - wd_incr_steps=1000, - wd_incr_style='linear', - ) - - param_group = {'max_lr': 0.1, 'min_lr': 0.001} - - scheduler.step(50) - lr = scheduler.get_lr(param_group) - expected_lr = 0.01 + (0.1 - 0.01) * (50 / 100) - assert math.isclose(lr, expected_lr, rel_tol=1e-5) - - scheduler.step(450) - lr = scheduler.get_lr(param_group) - expected_lr = 0.1 - ((0.1 - 0.001) * ((500 - 100) / (1000 - 100))) - assert math.isclose(lr, expected_lr, rel_tol=1e-5) - - scheduler.step(501) - lr = scheduler.get_lr(param_group) - expected_lr = 0.001 - assert math.isclose(lr, expected_lr, rel_tol=1e-5) - - -def test_get_lr_cosine(mock_optimizer): - scheduler = OptimizerParamScheduler( - optimizer=mock_optimizer, - init_lr=0.01, - max_lr=0.1, - min_lr=0.001, - lr_warmup_steps=100, - lr_decay_steps=1000, - lr_decay_style='cosine', - start_wd=0.0, - end_wd=0.1, - wd_incr_steps=1000, - wd_incr_style='linear', - ) - - scheduler.step(500) - param_group = {'max_lr': 0.1, 'min_lr': 0.001} - lr = scheduler.get_lr(param_group) - expected_lr = 0.001 + (0.1 - 0.001) * 0.5 * ( - math.cos(math.pi * ((500 - 100) / (1000 - 100))) + 1.0 - ) - assert math.isclose(lr, expected_lr, rel_tol=1e-5) - - -def test_step_function(mock_optimizer): - scheduler = OptimizerParamScheduler( - optimizer=mock_optimizer, - init_lr=0.01, - max_lr=0.1, - min_lr=0.001, - lr_warmup_steps=100, - lr_decay_steps=1000, - lr_decay_style='linear', - start_wd=0.0, - end_wd=0.1, - wd_incr_steps=1000, - wd_incr_style='linear', - ) - - scheduler.step(100) - assert scheduler.num_steps == 100 - param_group = mock_optimizer.param_groups[0] - assert math.isclose(param_group['lr'], 0.01 + (0.1 - 0.01) * (100 / 100), rel_tol=1e-5) - assert math.isclose(param_group['weight_decay'], 0.01, rel_tol=1e-5) - - -def test_state_dict(mock_optimizer): - scheduler = OptimizerParamScheduler( - optimizer=mock_optimizer, - init_lr=0.01, - max_lr=0.1, - min_lr=0.001, - lr_warmup_steps=100, - lr_decay_steps=1000, - lr_decay_style='linear', - start_wd=0.0, - end_wd=0.1, - wd_incr_steps=1000, - wd_incr_style='linear', - ) - - state_dict = scheduler.state_dict() - assert state_dict['max_lr'] == 0.1 - assert state_dict['lr_warmup_steps'] == 100 - assert state_dict['num_steps'] == 0 - assert state_dict['lr_decay_style'] == 'linear' - assert state_dict['lr_decay_steps'] == 1000 - assert state_dict['min_lr'] == 0.001 - assert state_dict['start_wd'] == 0.0 - assert state_dict['end_wd'] == 0.1 - assert state_dict['wd_incr_style'] == 'linear' - assert state_dict['wd_incr_steps'] == 1000 - - -def test_load_state_dict(mock_optimizer): - scheduler = OptimizerParamScheduler( - optimizer=mock_optimizer, - init_lr=0.01, - max_lr=0.1, - min_lr=0.001, - lr_warmup_steps=100, - lr_decay_steps=1000, - lr_decay_style='linear', - start_wd=0.0, - end_wd=0.1, - wd_incr_steps=1000, - wd_incr_style='linear', - ) - - state_dict = { - 'max_lr': 0.2, - 'min_lr': 0.0005, - 'lr_warmup_steps': 200, - 'lr_decay_steps': 2000, - 'lr_decay_style': 'cosine', - 'num_steps': 500, - 'start_wd': 0.01, - 'end_wd': 0.2, - 'wd_incr_steps': 500, - 'wd_incr_style': 'cosine', - } - - scheduler.load_state_dict(state_dict) - assert scheduler.max_lr == 0.2 - assert scheduler.min_lr == 0.0005 - assert scheduler.lr_warmup_steps == 200 - assert scheduler.lr_decay_steps == 2000 - assert scheduler.lr_decay_style == 'cosine' - assert scheduler.num_steps == 500 - assert scheduler.start_wd == 0.01 - assert scheduler.end_wd == 0.2 - assert scheduler.wd_incr_steps == 500 - assert scheduler.wd_incr_style == 'cosine' diff --git a/tests/unit_tests/test_parallel_state.py b/tests/unit_tests/test_parallel_state.py deleted file mode 100644 index ddee638776..0000000000 --- a/tests/unit_tests/test_parallel_state.py +++ /dev/null @@ -1,498 +0,0 @@ -import pytest -import torch - -import megatron.core.parallel_state as ps -from tests.unit_tests.test_utilities import Utils - -rank = Utils.rank -world_size = Utils.world_size -test_parallel_order = ['tp-cp-ep-dp-pp', 'tp-cp-pp-ep-dp'] - - -@pytest.mark.parametrize('order', test_parallel_order) -@pytest.mark.flaky_in_dev -def test_initialize_and_destroy_model_parallel(order): - with pytest.raises(AssertionError): - assert ps.initialize_model_parallel(order=order) - Utils.initialize_distributed() - with pytest.raises(RuntimeError): - assert ps.initialize_model_parallel(tensor_model_parallel_size=2 * world_size, order=order) - with pytest.raises(RuntimeError): - assert ps.initialize_model_parallel( - pipeline_model_parallel_size=2 * world_size, order=order - ) - with pytest.raises(RuntimeError): - assert ps.initialize_model_parallel( - pipeline_model_parallel_size=world_size, - tensor_model_parallel_size=world_size, - order=order, - ) - with pytest.raises(RuntimeError): - assert ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2, order=order) - Utils.initialize_model_parallel( - tensor_model_parallel_size=2, pipeline_model_parallel_size=4, order=order - ) - - assert ps.model_parallel_is_initialized() - assert ps.get_model_parallel_group() is not None - assert ps.get_tensor_model_parallel_group() is not None - assert ps.get_pipeline_model_parallel_group() is not None - assert ps.get_data_parallel_group() is not None - assert ps.get_expert_model_parallel_group() is not None - assert ps.get_expert_tensor_parallel_group() is not None - assert ps.get_expert_data_parallel_group() is not None - assert ps.get_expert_tensor_model_pipeline_parallel_group() is not None - Utils.destroy_model_parallel() - assert ps._MODEL_PARALLEL_GROUP is None - - -@pytest.mark.parametrize('order', test_parallel_order) -def test_pipeline_parallel_initializations(order): - Utils.initialize_model_parallel( - tensor_model_parallel_size=2, pipeline_model_parallel_size=4, order=order - ) - assert ps.get_pipeline_model_parallel_first_rank() == rank % 2 - assert ps.get_data_parallel_src_rank() == rank - assert ps.get_pipeline_model_parallel_next_rank() == ((rank + 2) % world_size) - assert ps.get_pipeline_model_parallel_prev_rank() == ((rank - 2) % world_size) - Utils.destroy_model_parallel() - - -@pytest.mark.parametrize('order', test_parallel_order) -def test_data_parallel_initializations(order): - Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size, order=order) - assert ps.get_data_parallel_src_rank() == rank - assert ps.get_data_parallel_world_size() == 1 - assert ps.get_data_parallel_rank() == 0 - Utils.destroy_model_parallel() - - -@pytest.mark.parametrize('order', test_parallel_order) -def test_tensor_model_parellel_world_size(order): - Utils.initialize_model_parallel(tensor_model_parallel_size=world_size, order=order) - assert ps.get_tensor_model_parallel_world_size() == world_size - ps.set_tensor_model_parallel_world_size(None) - assert ps.get_tensor_model_parallel_world_size() == world_size - Utils.destroy_model_parallel() - - -@pytest.mark.parametrize('order', test_parallel_order) -def test_expert_tensor_parellel_world_size(order): - Utils.initialize_model_parallel(expert_tensor_parallel_size=world_size, order=order) - assert ps.get_expert_tensor_parallel_world_size() == world_size - ps.set_expert_tensor_parallel_world_size(None) - assert ps.get_expert_tensor_parallel_world_size() == world_size - Utils.destroy_model_parallel() - - -@pytest.mark.parametrize('order', test_parallel_order) -def test_pipeline_model_parallel_world_size(order): - Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size, order=order) - assert ps.get_pipeline_model_parallel_world_size() == world_size - ps.set_pipeline_model_parallel_world_size(None) - assert ps.get_pipeline_model_parallel_world_size() == world_size - Utils.destroy_model_parallel() - - -@pytest.mark.parametrize('order', test_parallel_order) -def test_tensor_model_parallel_rank(order): - Utils.initialize_model_parallel(tensor_model_parallel_size=world_size, order=order) - assert ps.get_tensor_model_parallel_rank() == rank - ps.set_tensor_model_parallel_rank(None) - assert ps.get_tensor_model_parallel_rank() == rank - Utils.destroy_model_parallel() - - -@pytest.mark.parametrize('order', test_parallel_order) -def test_moe_tensor_model_parellel_rank(order): - Utils.initialize_model_parallel(expert_tensor_parallel_size=world_size, order=order) - assert ps.get_expert_tensor_parallel_rank() == rank - ps.set_expert_tensor_parallel_rank(None) - assert ps.get_expert_tensor_parallel_rank() == rank - Utils.destroy_model_parallel() - - -@pytest.mark.parametrize('order', test_parallel_order) -def test_pipeline_model_parallel_rank(order): - Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size, order=order) - assert ps.get_pipeline_model_parallel_rank() == rank - ps.set_pipeline_model_parallel_rank(None) - assert ps.get_pipeline_model_parallel_rank() == rank - Utils.destroy_model_parallel() - - -def test_context_parallel_rank(): - Utils.initialize_model_parallel(context_parallel_size=world_size) - assert ps.get_context_parallel_rank() == rank - Utils.destroy_model_parallel() - - -def test_expert_model_parallel_rank(): - Utils.initialize_model_parallel(expert_model_parallel_size=world_size) - assert ps.get_expert_model_parallel_rank() == rank - ps.set_expert_model_parallel_rank(None) - assert ps.get_expert_model_parallel_rank() == rank - Utils.destroy_model_parallel() - - -@pytest.mark.parametrize('order', test_parallel_order) -def test_is_pipeline_first_stage(order): - Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size, order=order) - assert ps.is_pipeline_first_stage(ignore_virtual=False) == (rank == 0) - assert ps.is_pipeline_first_stage() == (rank == 0) - Utils.destroy_model_parallel() - - -@pytest.mark.parametrize('order', test_parallel_order) -def test_is_pipeline_last_stage(order): - Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size, order=order) - assert ps.is_pipeline_last_stage(ignore_virtual=False) == (rank == world_size - 1) - assert ps.is_pipeline_last_stage() == (rank == world_size - 1) - Utils.destroy_model_parallel() - - -@pytest.mark.parametrize('order', test_parallel_order) -def test_virtual_pipeline_model_parallel_rank(order): - Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size, order=order) - ps.set_virtual_pipeline_model_parallel_rank(rank) - assert ps.get_virtual_pipeline_model_parallel_rank() == rank - Utils.destroy_model_parallel() - - -@pytest.mark.parametrize('order', test_parallel_order) -def test_get_tensor_model_parallel_src_rank(order): - Utils.initialize_model_parallel(tensor_model_parallel_size=world_size, order=order) - assert ps.get_tensor_model_parallel_src_rank() == ((rank // world_size) * world_size) - Utils.destroy_model_parallel() - - -@pytest.mark.internal -@pytest.mark.parametrize( - 'src_tp_pp, ep_size', - [ - ((1, 8), 1), - ((2, 4), 1), - ((4, 2), 1), - ((8, 1), 1), - ((4, 1), 2), - ((1, 1), 8), - ((1, 1), 2), - ((2, 1), 4), - ], -) -def test_different_initialize_order_consistency(src_tp_pp, ep_size): - Utils.initialize_model_parallel( - *src_tp_pp, expert_model_parallel_size=ep_size, order='tp-ep-dp-pp' - ) - tp_rank = ps.get_tensor_model_parallel_rank() - dp_rank = ps.get_data_parallel_rank() - pp_rank = ps.get_pipeline_model_parallel_rank() - ep_rank = ps.get_expert_model_parallel_rank() - - tp_g = torch.distributed.get_process_group_ranks(ps.get_tensor_model_parallel_group()) - dp_g = torch.distributed.get_process_group_ranks(ps.get_data_parallel_group(False)) - pp_g = torch.distributed.get_process_group_ranks(ps.get_pipeline_model_parallel_group()) - dp_no_ep_g = torch.distributed.get_process_group_ranks(ps.get_expert_data_parallel_group()) - cp_g = torch.distributed.get_process_group_ranks(ps.get_context_parallel_group()) - mp_g = torch.distributed.get_process_group_ranks(ps.get_model_parallel_group()) - tp_ep_g = torch.distributed.get_process_group_ranks( - ps.get_expert_tensor_and_model_parallel_group() - ) - tp_dp_g = torch.distributed.get_process_group_ranks( - ps.get_tensor_and_data_parallel_group(False) - ) - - Utils.destroy_model_parallel() - - Utils.initialize_model_parallel( - *src_tp_pp, expert_model_parallel_size=ep_size, order='tp-pp-ep-dp' - ) - assert tp_rank == ps.get_tensor_model_parallel_rank() - assert dp_rank == ps.get_data_parallel_rank() - assert pp_rank == ps.get_pipeline_model_parallel_rank() - assert ep_rank == ps.get_expert_model_parallel_rank() - - assert tp_g == torch.distributed.get_process_group_ranks(ps.get_tensor_model_parallel_group()) - assert dp_g == torch.distributed.get_process_group_ranks(ps.get_data_parallel_group(False)) - assert pp_g == torch.distributed.get_process_group_ranks(ps.get_pipeline_model_parallel_group()) - assert dp_no_ep_g == torch.distributed.get_process_group_ranks( - ps.get_expert_data_parallel_group() - ) - assert cp_g == torch.distributed.get_process_group_ranks(ps.get_context_parallel_group()) - assert mp_g == torch.distributed.get_process_group_ranks(ps.get_model_parallel_group()) - assert tp_ep_g == torch.distributed.get_process_group_ranks( - ps.get_expert_tensor_and_model_parallel_group() - ) - assert tp_dp_g == torch.distributed.get_process_group_ranks( - ps.get_tensor_and_data_parallel_group(False) - ) - - Utils.destroy_model_parallel() - - -@pytest.mark.parametrize( - 'src_tp_pp, ep_size', - [((1, 2), 1), ((1, 4), 1), ((2, 2), 1), ((1, 2), 2), ((1, 4), 2), ((2, 2), 2)], -) -@pytest.mark.flaky -@pytest.mark.flaky_in_dev -def test_different_initialize_order_unconsistency(src_tp_pp, ep_size): - Utils.initialize_model_parallel( - *src_tp_pp, expert_model_parallel_size=ep_size, order='tp-ep-dp-pp' - ) - - tp_g = torch.distributed.get_process_group_ranks(ps.get_tensor_model_parallel_group()) - dp_g = torch.distributed.get_process_group_ranks(ps.get_data_parallel_group(False)) - pp_g = torch.distributed.get_process_group_ranks(ps.get_pipeline_model_parallel_group()) - cp_g = torch.distributed.get_process_group_ranks(ps.get_context_parallel_group()) - amax_g = torch.distributed.get_process_group_ranks(ps.get_amax_reduction_group(False)) - mp_g = torch.distributed.get_process_group_ranks(ps.get_model_parallel_group()) - - Utils.destroy_model_parallel() - - Utils.initialize_model_parallel( - *src_tp_pp, expert_model_parallel_size=ep_size, order='tp-pp-ep-dp' - ) - assert tp_g == torch.distributed.get_process_group_ranks(ps.get_tensor_model_parallel_group()) - assert dp_g != torch.distributed.get_process_group_ranks(ps.get_data_parallel_group(False)) - assert pp_g != torch.distributed.get_process_group_ranks(ps.get_pipeline_model_parallel_group()) - assert cp_g == torch.distributed.get_process_group_ranks(ps.get_context_parallel_group()) - assert amax_g != torch.distributed.get_process_group_ranks(ps.get_amax_reduction_group(False)) - assert mp_g != torch.distributed.get_process_group_ranks(ps.get_model_parallel_group()) - - Utils.destroy_model_parallel() - - -@pytest.mark.internal -@pytest.mark.parametrize( - 'nodes, num_gpu, tp, pp, cp, ep', - [ - (1, 1, 1, 1, 1, 1), - (1, 8, 8, 1, 1, 1), - (1, 8, 2, 2, 1, 1), - (1, 8, 2, 4, 1, 1), - (3, 8, 8, 3, 1, 1), - (4, 8, 2, 4, 1, 1), - (8, 8, 8, 8, 1, 1), - (8, 8, 2, 1, 1, 4), - (8, 8, 2, 2, 2, 4), - (8, 8, 2, 1, 4, 8), - (8, 8, 2, 2, 2, 8), - (16, 8, 4, 8, 1, 1), - (16, 8, 4, 8, 1, 4), - (16, 8, 4, 8, 4, 1), - (16, 8, 8, 8, 1, 1), - (16, 8, 4, 8, 1, 1), - (16, 8, 8, 8, 1, 1), - (32, 8, 4, 8, 1, 1), - (32, 8, 8, 8, 1, 1), - (32, 8, 4, 8, 1, 4), - (32, 8, 8, 8, 4, 1), - (64, 8, 4, 2, 8, 8), - (64, 8, 4, 8, 1, 1), - (64, 8, 8, 8, 1, 1), - (96, 8, 4, 8, 1, 1), - (128, 8, 4, 2, 8, 8), - (128, 8, 4, 8, 1, 1), - (256, 8, 4, 8, 1, 1), - (316, 8, 4, 8, 1, 1), - (384, 8, 4, 8, 1, 1), - (512, 8, 4, 8, 1, 1), - (768, 8, 4, 8, 1, 1), - (1024, 8, 4, 8, 1, 1), - (1280, 8, 4, 8, 1, 1), - (1344, 8, 4, 8, 1, 1), - ], -) -def test_rank_generator_for_tp_dp_pp(nodes, num_gpu, tp, pp, cp, ep): - def golden_rank_result_from_past_code( - world_size: int, - tensor_model_parallel_size: int = 1, - pipeline_model_parallel_size: int = 1, - context_parallel_size: int = 1, - expert_model_parallel_size: int = 1, - ): - data_parallel_size: int = world_size // ( - tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size - ) - num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size - num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size - - dp_groups = [] - dp_groups_with_cp = [] - - all_data_parallel_group_ranks_with_cp = [] - for i in range(pipeline_model_parallel_size): - start_rank = i * num_pipeline_model_parallel_groups - end_rank = (i + 1) * num_pipeline_model_parallel_groups - for j in range(context_parallel_size * tensor_model_parallel_size): - ranks = range( - start_rank + j, end_rank, context_parallel_size * tensor_model_parallel_size - ) - dp_groups.append(list(ranks)) - for j in range(tensor_model_parallel_size): - ranks_with_cp = range(start_rank + j, end_rank, tensor_model_parallel_size) - all_data_parallel_group_ranks_with_cp.append(list(ranks_with_cp)) - dp_groups_with_cp.append(list(ranks_with_cp)) - - cp_group = [] - for i in range(pipeline_model_parallel_size): - for j in range(data_parallel_size): - start_rank = ( - i * num_pipeline_model_parallel_groups - + j * tensor_model_parallel_size * context_parallel_size - ) - end_rank = ( - i * num_pipeline_model_parallel_groups - + (j + 1) * tensor_model_parallel_size * context_parallel_size - ) - for k in range(tensor_model_parallel_size): - ranks = range(start_rank + k, end_rank, tensor_model_parallel_size) - cp_group.append(list(ranks)) - - mp_group = [] - for i in range(data_parallel_size * context_parallel_size): - ranks = [ - data_parallel_group_ranks_with_cp[i] - for data_parallel_group_ranks_with_cp in all_data_parallel_group_ranks_with_cp - ] - mp_group.append(list(ranks)) - - tp_group = [] - for i in range(num_tensor_model_parallel_groups): - ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) - tp_group.append(list(ranks)) - - pp_group = [] - for i in range(num_pipeline_model_parallel_groups): - ranks = range(i, world_size, num_pipeline_model_parallel_groups) - pp_group.append(list(ranks)) - - tp_dp_group = [] - tp_dp_cp_group = [] - tensor_and_data_group_size_with_cp: int = ( - tensor_model_parallel_size * data_parallel_size * context_parallel_size - ) - num_tensor_and_data_groups_with_cp: int = world_size // tensor_and_data_group_size_with_cp - for i in range(num_tensor_and_data_groups_with_cp): - start_rank = i * tensor_and_data_group_size_with_cp - end_rank = start_rank + tensor_and_data_group_size_with_cp - ranks = range(start_rank, end_rank) - tp_dp_cp_group.append(list(ranks)) - - for j in range(context_parallel_size): - ranks = [] - for k in range(data_parallel_size): - start_rank = ( - i * tensor_and_data_group_size_with_cp - + j * tensor_model_parallel_size - + k * tensor_model_parallel_size * context_parallel_size - ) - end_rank = start_rank + tensor_model_parallel_size - ranks = ranks + list(range(start_rank, end_rank)) - tp_dp_group.append(list(ranks)) - - expert_tp_ep_group = [] - expert_dp_group = [] - - expert_data_parallel_size = world_size // ( - tensor_model_parallel_size * pipeline_model_parallel_size * expert_model_parallel_size - ) - all_ranks = torch.arange(world_size).reshape( - ( - pipeline_model_parallel_size, - expert_data_parallel_size, - expert_model_parallel_size, - tensor_model_parallel_size, - ) - ) - # (pp, dp, ep, tp) -> (pp*dp, ep*tp) - tp_ep_rearrange = torch.reshape( - all_ranks, (-1, expert_model_parallel_size * tensor_model_parallel_size) - ) - num_tp_ep_groups = tp_ep_rearrange.shape[0] - for i in range(num_tp_ep_groups): - expert_tensor_and_model_parallel_ranks = tp_ep_rearrange[i].tolist() - expert_tp_ep_group.append(expert_tensor_and_model_parallel_ranks) - - # (pp, dp, ep, tp) -> (pp*ep*tp, dp) - expert_dp_rearrange = torch.permute(all_ranks, (0, 2, 3, 1)).reshape( - -1, expert_data_parallel_size - ) - num_expert_dp_groups = world_size // expert_data_parallel_size - for i in range(num_expert_dp_groups): - expert_dp_ranks = expert_dp_rearrange[i].tolist() - expert_dp_group.append(expert_dp_ranks) - - return ( - dp_groups, - dp_groups_with_cp, - cp_group, - mp_group, - tp_group, - pp_group, - tp_dp_group, - tp_dp_cp_group, - expert_tp_ep_group, - expert_dp_group, - ) - - world_size = nodes * num_gpu - dp = world_size // (tp * pp * cp) - expert_dp = world_size // (tp * ep * pp) - assert dp % ep == 0, f"dp size ({dp}) is not divisible by ep {ep} ." - assert ( - world_size % (tp * pp * cp) == 0 - ), f"world_size ({world_size}) is not divisible by tp {tp} x pp {pp} x cp {cp}." - ( - dp_groups, - dp_groups_with_cp, - cp_group, - mp_group, - tp_group, - pp_group, - tp_dp_group, - tp_dp_cp_group, - expert_tp_ep_group, - expert_dp_group, - ) = golden_rank_result_from_past_code( - world_size=world_size, - tensor_model_parallel_size=tp, - pipeline_model_parallel_size=pp, - context_parallel_size=cp, - expert_model_parallel_size=ep, - ) - rank_generator = ps.RankGenerator(tp=tp, ep=1, dp=dp, pp=pp, cp=cp, order="tp-cp-dp-pp") - expert_rank_generator = ps.RankGenerator( - tp=tp, ep=ep, dp=expert_dp, pp=pp, cp=1, order="tp-ep-dp-pp" - ) - assert dp_groups == rank_generator.get_ranks( - "dp" - ), f"{dp_groups} != {rank_generator.get_ranks('dp')}" - assert dp_groups_with_cp == rank_generator.get_ranks( - 'dp-cp' - ), f"{dp_groups_with_cp} != {rank_generator.get_ranks('dp-cp')}" - assert cp_group == rank_generator.get_ranks( - "cp" - ), f"{cp_group} != {rank_generator.get_ranks('cp')}." - assert mp_group == rank_generator.get_ranks( - "tp-pp" - ), f"{mp_group} != {rank_generator.get_ranks('tp-pp')}" - assert tp_group == rank_generator.get_ranks( - "tp" - ), f"{tp_group} != {rank_generator.get_ranks('tp')}" - assert pp_group == rank_generator.get_ranks( - "pp" - ), f"{pp_group} != {rank_generator.get_ranks('pp')}" - assert tp_dp_group == rank_generator.get_ranks( - "tp-dp" - ), f"{tp_dp_group} != {rank_generator.get_ranks('tp-dp')}" - assert tp_dp_cp_group == rank_generator.get_ranks( - "tp-dp-cp" - ), f"{tp_dp_cp_group} != {rank_generator.get_ranks('tp-dp-cp')}" - assert expert_tp_ep_group == expert_rank_generator.get_ranks( - "tp-ep" - ), f"{expert_tp_ep_group} != {expert_rank_generator.get_ranks('tp-ep')}." - assert expert_dp_group == expert_rank_generator.get_ranks( - "dp" - ), f"{expert_dp_group} != {expert_rank_generator.get_ranks('dp')}." diff --git a/tests/unit_tests/test_process_groups_config.py b/tests/unit_tests/test_process_groups_config.py deleted file mode 100644 index 7a2dc30160..0000000000 --- a/tests/unit_tests/test_process_groups_config.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -import torch.distributed as dist - -from megatron.core.process_groups_config import GradCommProcessGroups, ModelCommProcessGroups - - -class TestProcessGroupsConfig: - """Simple tests for process group dataclasses.""" - - def test_transformer_process_groups(self, mocker): - """Test basic functionality of TransformerProcessGroups.""" - mock_pg1 = mocker.Mock(spec=dist.ProcessGroup) - mock_pg2 = mocker.Mock(spec=dist.ProcessGroup) - - # Create instance - model_pgs = ModelCommProcessGroups() - - # Test setting attributes after creation - model_pgs.tp = mock_pg1 - model_pgs.pp = mock_pg2 - - # Test accessing attributes - assert model_pgs.tp == mock_pg1 - assert model_pgs.pp == mock_pg2 - - # Test attribute existence - assert hasattr(model_pgs, 'tp') - assert hasattr(model_pgs, 'pp') - assert not hasattr(model_pgs, 'cp') # Not set yet - - def test_grad_comm_process_groups(self, mocker): - """Test basic functionality of GradCommProcessGroups.""" - # Create mock process groups - mock_pg = mocker.Mock(spec=dist.ProcessGroup) - - # Create instance - grad_pgs = GradCommProcessGroups() - - # Test setting attributes after creation - grad_pgs.dp = mock_pg - - # Test accessing attributes - assert grad_pgs.dp == mock_pg - - # Test attribute existence - assert hasattr(grad_pgs, 'dp') - assert not hasattr(grad_pgs, 'dp_cp') # Not set yet - - def test_hierarchical_context_parallel_groups(self, mocker): - """Test setting and accessing the hierarchical context parallel list.""" - # Create mock process groups - mock_pg1 = mocker.Mock(spec=dist.ProcessGroup) - mock_pg2 = mocker.Mock(spec=dist.ProcessGroup) - - # Create instance - model_pgs = ModelCommProcessGroups() - - # Set the hierarchical context parallel groups - model_pgs.hcp = [mock_pg1, mock_pg2] - - # Test list access - assert isinstance(model_pgs.hcp, list) - assert len(model_pgs.hcp) == 2 - assert model_pgs.hcp[0] == mock_pg1 - assert model_pgs.hcp[1] == mock_pg2 diff --git a/tests/unit_tests/test_tokenizer.py b/tests/unit_tests/test_tokenizer.py deleted file mode 100644 index 3d8f5d9c33..0000000000 --- a/tests/unit_tests/test_tokenizer.py +++ /dev/null @@ -1,276 +0,0 @@ -import base64 -import json -from argparse import Namespace -from pathlib import Path - -import numpy as np -import pytest -import requests - -from megatron.training import tokenizer -from megatron.training.tokenizer.gpt2_tokenization import PRETRAINED_VOCAB_ARCHIVE_MAP -from megatron.training.tokenizer.multimodal_tokenizer import MultimodalTokenizer - -TOKENIZER_DIR = Path("~/data/tokenizers").expanduser() - -# Copied over from test_preprocess_data.py -from tests.unit_tests.data.test_preprocess_data import __LOCAL_GPT2_VOCAB - -GPT2_VOCAB_SIZE = 32768 - - -def offsets_to_substrs(offsets, string): - return [string[start:end] for start, end in zip([0] + offsets, offsets + [len(string)])] - - -def local_test_specs(): - return [ - Namespace( - rank=0, - tensor_model_parallel_size=8, - make_vocab_size_divisible_by=128, - tokenizer_type="GPTSentencePieceTokenizer", - tokenizer_model=f"{TOKENIZER_DIR}/nemotron_2_256k.model", - ), - Namespace( - rank=0, - vocab_size=131072, - make_vocab_size_divisible_by=128, - tensor_model_parallel_size=8, - tokenizer_type="TikTokenizer", - tokenizer_model=f"{TOKENIZER_DIR}/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json", - tiktoken_pattern="v2", - tiktoken_num_special_tokens=1000, - tiktoken_special_tokens=["", "", ""], - ), - Namespace( - rank=0, - vocab_size=131072, - make_vocab_size_divisible_by=128, - tensor_model_parallel_size=8, - tokenizer_type="TikTokenizer", - tokenizer_model=f"{TOKENIZER_DIR}/multiMixV5_fix_default_500000_128k.vocab.json", - tiktoken_pattern="v1", - tiktoken_num_special_tokens=1000, - tiktoken_special_tokens=["", "", ""], - ), - Namespace( - rank=0, - vocab_size=128000, - make_vocab_size_divisible_by=128, - tensor_model_parallel_size=8, - tokenizer_type="HuggingFaceTokenizer", - tokenizer_model="meta-llama/Llama-2-7b-hf", - ), - Namespace( - rank=0, - vocab_size=128000, - make_vocab_size_divisible_by=128, - tensor_model_parallel_size=8, - tokenizer_type="HuggingFaceTokenizer", - tokenizer_model="meta-llama/Meta-Llama-3.1-8B", - ), - ] - - -@pytest.fixture(scope="session") -def gpt2_tiktok_vocab(tmp_path_factory): - - if Path(__LOCAL_GPT2_VOCAB).exists(): - with open(__LOCAL_GPT2_VOCAB, "r", encoding="utf-8") as reader: - gpt2_vocab = json.load(reader) - else: - gpt2_vocab = json.loads(requests.get(PRETRAINED_VOCAB_ARCHIVE_MAP["gpt2"]).content) - - N = 256 - tiktok_vocab = [ - {"token_bytes": base64.b64encode(bytes([i])).decode("utf-8"), "token_str": str(i)} - for i in range(N) - ] - tiktok_vocab_bytes = {x["token_bytes"] for x in tiktok_vocab} - - tiktok_vocab += [ - {"token_bytes": base64.b64encode(token.encode('utf-8')).decode("utf-8"), "token_str": token} - for token in gpt2_vocab - if base64.b64encode(token.encode('utf-8')).decode("utf-8") not in tiktok_vocab_bytes - ] - - for i, entry in enumerate(tiktok_vocab): - entry["rank"] = i - - for i, x in enumerate(tiktok_vocab): - assert x.keys() == {"rank", "token_bytes", "token_str"} - assert x["rank"] == i - merge = base64.b64decode(x["token_bytes"]) - assert i >= 256 or merge == bytes([i]), f"{i} {merge} {bytes([i])}" - - file_name = tmp_path_factory.mktemp("data") / "gpt2_vocab.json" - with open(file_name, "w") as f: - json.dump(tiktok_vocab, f) - - return Namespace( - rank=0, - vocab_size=32768, - make_vocab_size_divisible_by=128, - tensor_model_parallel_size=8, - tokenizer_type="TikTokenizer", - tokenizer_model=str(file_name), - tiktoken_pattern="v1", - tiktoken_num_special_tokens=1000, - tiktoken_special_tokens=["", "", ""], - ) - - -@pytest.mark.parametrize("args", local_test_specs()) -def test_tokenizer(args): - if not TOKENIZER_DIR.exists(): - pytest.skip("Skipping tokenizer tests because the tokenizer directory does not exist") - - tok = tokenizer.build_tokenizer(args) - run_tokenizer_tests(tok) - - -def test_gpt2_tiktok_tokenizer(gpt2_tiktok_vocab): - tok = tokenizer.build_tokenizer(gpt2_tiktok_vocab) - run_tokenizer_tests(tok) - - -def run_tokenizer_tests(tok): - string1 = ( - "The following are multiple choice questions (with answers) about college biology.\n" - "Monoclonal antisera are distinguished from polyclonal antisera in which of the " - "following ways?\n" - "A. Each type of antibody in a monoclonal antiserum reacts against a single region of " - "a single antigen; each type of antibody in a polyclonal antiserum reacts against " - "multiple regions of different antigens.\n" - "B. A monoclonal antibody reacts against multiple regions of a single antigen; a " - "polyclonal antibody reacts against a single region of related antigens.\n" - "C. A monoclonal antiserum contains antibodies secreted from the descendants of a " - "single B lymphocyte; a polyclonal antiserum contains antibodies secreted from the " - "descendants of different B lymphocytes.\n" - "D. A monoclonal antiserum contains antibodies secreted from the descendants of a " - "single B lymphocyte; a polyclonal antiserum contains antibodies secreted from the " - "descendants of both B and T lymphocytes.\n" - "Answer: C" - ) - string2 = "Жизнь прекрасна и удивительна" - string3 = "お誕生日おめでとう" - strings = [string1, string2, string3] - - for test_string in strings: - toks = tok.tokenize(test_string) - offsets = tok.offsets(toks, test_string) - dec = offsets_to_substrs(offsets, test_string) - detok_str = ''.join(dec) - # the following is not necessarily true by construction above, - # since the many tokenizers may operate at the byte level and not - # only at the character level. - assert ( - detok_str == test_string - ), f"Detokenized string {detok_str} does not match original {test_string}" - assert len(toks) == len( - offsets - ), f"Tokenized string {toks} does not match original {offsets}" - - -def test_null_tokenizer(): - args = Namespace( - tokenizer_type="NullTokenizer", - rank=0, - vocab_size=128000, - make_vocab_size_divisible_by=128, - tensor_model_parallel_size=8, - ) - tok = tokenizer.build_tokenizer(args) - test_string = "1 23 456 789" - toks = tok.tokenize(test_string) - offsets = tok.offsets(toks, test_string) - dec = offsets_to_substrs(offsets, test_string) - detok_str = ''.join(dec) - - assert ( - detok_str == test_string - ), f"Detokenized string {detok_str} does not match original {test_string}" - assert len(toks) == len(offsets), f"Tokenized string {toks} does not match original {offsets}" - - -class MockUnderlyingTokenizer: - """Mock tokenizer for testing purposes.""" - - def __init__(self): - self.pad_token_id = 256 - - def __len__(self): - return 256 - - def encode(self, text: str) -> list[int]: - """Convert text to a list of token IDs.""" - return [ord(c) for c in text] - - def decode(self, tokens: list[int]) -> str: - """Convert list of token IDs to plaintext.""" - return "".join([chr(t) for t in tokens]) - - def apply_chat_template(self, conversation: list[dict], *args, **kwargs) -> list[int]: - """Convert a conversation to token IDs.""" - out = [] - for turn in conversation: - turn_tokens = self.encode(f"{turn['role']}:{turn['content']}") - out.extend(turn_tokens) - - if kwargs.get("return_tensors", None) == "np": - return [np.array(out)] - - return out - - def convert_tokens_to_ids(self, text: str) -> list[int]: - """Convert plaintext to token IDs.""" - return self.encode(text) - - def add_tokens(self, extra_tokens: list[str], *args, **kwargs) -> int: - """Add tokens to the tokenizer. No-op for this mock tokenizer.""" - return len(extra_tokens) - - -def test_multimodal_tokenizer(): - """Test MultimodalTokenizer.""" - underlying = MockUnderlyingTokenizer() - prompt_format = "chatml" - special_tokens = [""] - image_tag_type = "" - tokenizer = MultimodalTokenizer(underlying, prompt_format, special_tokens, image_tag_type) - - # Simple encode - decode roundtrip. - assert ( - tokenizer.detokenize(tokenizer.tokenize("abc")) == "abc" - ), "encode-decode roundtrip failed" - - # Apply chat template. - conversation = [ - {"role": "system", "content": "abc"}, - {"role": "user", "content": "123"}, - {"role": "assistant", "content": "xyz"}, - ] - conv_tokens = tokenizer.tokenize_conversation( - conversation, return_target=False, add_generation_prompt=False - ) - assert len(conv_tokens) > 0, "failed to tokenize conversation" - - conv_tokens, target_tokens = tokenizer.tokenize_conversation( - conversation, return_target=True, add_generation_prompt=True - ) - assert len(conv_tokens) > 0 and len(conv_tokens) == len( - target_tokens - ), "failed to tokenize conversation and return target tokens" - - # Try converting tokens to ids. - assert tokenizer.convert_tokens_to_ids("a"), "failed to convert tokens to ids." - - # Try image tags. - image_tag_type = "nvlm" - tokenizer = MultimodalTokenizer(underlying, prompt_format, special_tokens, image_tag_type) - - assert tokenizer._apply_image_tag("hello") == "hello" - assert tokenizer._apply_image_tag([{"role": "user", "content": "hello"}]) == [ - {"role": "user", "content": "hello"} - ] diff --git a/tests/unit_tests/test_training.py b/tests/unit_tests/test_training.py deleted file mode 100644 index 953a80e094..0000000000 --- a/tests/unit_tests/test_training.py +++ /dev/null @@ -1,73 +0,0 @@ -from types import SimpleNamespace - -from megatron.training.global_vars import set_args -from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding -from megatron.training.training import build_train_valid_test_data_iterators -from tests.unit_tests.test_utilities import Utils - - -def mock_train_valid_test_datasets_provider(train_val_test_num_samples): - return iter([1]), iter([2]), iter([3]) - - -def create_test_args(): - # Set dummy values for the args. - args = SimpleNamespace() - args.iteration = 0 - args.train_samples = 1 - args.train_iters = 1 - args.eval_interval = 1 - args.eval_iters = 1 - args.global_batch_size = 1 - args.consumed_train_samples = 1 - args.consumed_valid_samples = 1 - args.dataloader_type = "external" - args.skip_train = False - args.full_validation = False - args.multiple_validation_sets = False - args.perform_rl_step = False - - return args - - -class TestTraining: - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - args = create_test_args() - set_args(args) - - def test_build_train_valid_test_data_iterators(self): - train_iter, valid_iter, test_iter = build_train_valid_test_data_iterators( - mock_train_valid_test_datasets_provider - ) - train_data = next(train_iter) - valid_data = next(valid_iter) - test_data = next(test_iter) - assert (train_data, valid_data, test_data) == (1, 2, 3) - - def test_closed_formula_vocab_size_with_padding(self): - def old_round_impl(after, multiple): - while (after % multiple) != 0: - after += 1 - return after - - args = SimpleNamespace() - args.rank = 0 - args.tensor_model_parallel_size = 1 - - for vocab in range(1, 600000, 1000): - for mult in [1, 17, 32, 64, 128]: - args.make_vocab_size_divisible_by = mult - assert old_round_impl(vocab, mult) == _vocab_size_with_padding( - vocab, args, False - ), (vocab, mult) - - for vocab in range(1, 10_000, 500): - for mult in range(1, 1024 + 1): - args.make_vocab_size_divisible_by = mult - assert old_round_impl(vocab, mult) == _vocab_size_with_padding( - vocab, args, False - ), (vocab, mult) - - def teardown_method(self, method): - Utils.destroy_model_parallel() diff --git a/tests/unit_tests/test_utilities.py b/tests/unit_tests/test_utilities.py deleted file mode 100644 index f16f88f786..0000000000 --- a/tests/unit_tests/test_utilities.py +++ /dev/null @@ -1,137 +0,0 @@ -import os -from datetime import timedelta - -import torch -from torch._C._distributed_c10d import PrefixStore -from torch.distributed import rendezvous - -import megatron.core.parallel_state as ps - - -class TestModel(torch.nn.Module): - def __init__( - self, - input_dim: int, - output_dim: int, - num_layers: int, - bias: bool, - shared_embedding: bool = False, - ): - super().__init__() - self.layers = torch.nn.ModuleList( - [torch.nn.Linear(input_dim, output_dim, bias) for _ in range(num_layers)] - ) - if shared_embedding: - self.layers[-1].weight.shared_embedding = True - - -class Utils: - - world_size = int(os.environ['WORLD_SIZE']) - rank = int(os.environ['LOCAL_RANK']) - inited = False - store = None - - @staticmethod - def initialize_distributed(): - - os.environ.pop('NVTE_FLASH_ATTN', None) - os.environ.pop('NVTE_FUSED_ATTN', None) - os.environ.pop('NVTE_UNFUSED_ATTN', None) - - if not torch.distributed.is_initialized() and Utils.rank >= 0: - print( - f'Initializing torch.distributed with rank: {Utils.rank}, ' - f'world_size: {Utils.world_size}' - ) - torch.cuda.set_device(Utils.rank % torch.cuda.device_count()) - init_method = 'tcp://' - master_ip = os.getenv('MASTER_ADDR', 'localhost') - master_port = os.getenv('MASTER_PORT', '6000') - init_method += master_ip + ':' + master_port - rendezvous_iterator = rendezvous( - init_method, Utils.rank, Utils.world_size, timeout=timedelta(minutes=1) - ) - store, rank, world_size = next(rendezvous_iterator) - store.set_timeout(timedelta(minutes=1)) - - # Use a PrefixStore to avoid accidental overrides of keys used by - # different systems (e.g. RPC) in case the store is multi-tenant. - store = PrefixStore("default_pg", store) - Utils.store = store - - torch.distributed.init_process_group( - backend='nccl', world_size=Utils.world_size, rank=Utils.rank, store=store - ) - - torch.distributed.barrier() - Utils.inited = True - - @staticmethod - def set_world_size(world_size=None, rank=None): - Utils.world_size = torch.cuda.device_count() if world_size is None else world_size - if ( - torch.distributed.is_initialized() - and Utils.world_size != torch.distributed.get_world_size() - ): - torch.distributed.destroy_process_group() - - if rank is None: - Utils.rank = int(os.environ['LOCAL_RANK']) - if Utils.rank >= Utils.world_size: - Utils.rank = -1 - else: - Utils.rank = rank - - @staticmethod - def destroy_model_parallel(): - os.environ.pop('NVTE_FLASH_ATTN', None) - os.environ.pop('NVTE_FUSED_ATTN', None) - os.environ.pop('NVTE_UNFUSED_ATTN', None) - if not Utils.inited: - return - torch.distributed.barrier() - ps.destroy_model_parallel() - Utils.inited = False - - @staticmethod - def initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, - virtual_pipeline_model_parallel_size=None, - **kwargs, - ): - # Need to unset these variables to make sure previous - # tests setting them doesn't interfere current test. - os.environ.pop('NVTE_FLASH_ATTN', None) - os.environ.pop('NVTE_FUSED_ATTN', None) - os.environ.pop('NVTE_UNFUSED_ATTN', None) - - ps.destroy_model_parallel() - Utils.initialize_distributed() - ps.initialize_model_parallel( - tensor_model_parallel_size, - pipeline_model_parallel_size, - virtual_pipeline_model_parallel_size, - **kwargs, - ) - Utils.inited = True - - @staticmethod - def fake_initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, - virtual_pipeline_model_parallel_size=None, - expert_model_parallel_size=1, - ): - """Used for layer-wise UT as a proxy for NeMo-style intialization.""" - ps.set_tensor_model_parallel_world_size(tensor_model_parallel_size) - ps.set_tensor_model_parallel_rank(0) - - ps.set_expert_model_parallel_world_size(expert_model_parallel_size) - ps.set_expert_model_parallel_rank(0) - if virtual_pipeline_model_parallel_size is not None: - ps.set_virtual_pipeline_model_parallel_world_size(virtual_pipeline_model_parallel_size) - ps.set_virtual_pipeline_model_parallel_rank(0) - - ps.set_pipeline_model_parallel_world_size(pipeline_model_parallel_size) diff --git a/tests/unit_tests/test_utils.py b/tests/unit_tests/test_utils.py deleted file mode 100644 index 18e3787c24..0000000000 --- a/tests/unit_tests/test_utils.py +++ /dev/null @@ -1,383 +0,0 @@ -import os -import time -import urllib.request as req -from types import SimpleNamespace -from unittest.mock import patch - -import mock -import numpy as np -import pytest -import torch - -import megatron.core.utils as util -import megatron.training.utils as training_util -from megatron.core import config -from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig -from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer -from megatron.core.transformer import TransformerConfig -from tests.unit_tests.test_utilities import Utils - -success_string = "hello,world" - - -@util.experimental_cls(introduced_with_version="0.1.0") -class A: - - def __init__(self): - pass - - def some_method(self): - return success_string - - @classmethod - def some_static_method(cls): - return success_string - - -def test_divide_properly(): - assert util.divide(4, 2) == 2 - - -def test_divide_improperly(): - with pytest.raises(AssertionError): - util.divide(4, 5) - - -def test_experimental_cls_init(): - with patch.object(config, 'ENABLE_EXPERIMENTAL', True): - # Check that initialization works - a = A() - assert a.__class__.__qualname__ == "A" - assert a.some_method() == success_string - assert a.is_experimental is True - - -def test_experimental_cls_static(): - with patch.object(config, 'ENABLE_EXPERIMENTAL', True): - # Check that static methods work - assert A.__class__.__qualname__ == "A" - assert A.some_static_method() == success_string - assert A.is_experimental is True - - -def test_experimental_cls_exception_init(): - with ( - patch.object(config, 'ENABLE_EXPERIMENTAL', False), - pytest.raises(util.ExperimentalNotEnabledError), - ): - a = A() - assert a.some_method() == success_string - assert a.is_experimental is False - - -def test_experimental_cls_exception_static(): - with ( - patch.object(config, 'ENABLE_EXPERIMENTAL', False), - pytest.raises(util.ExperimentalNotEnabledError), - ): - assert A.some_static_method() == success_string - - assert A.is_experimental is False - - -def test_global_memory_buffer(): - global_memory_buffer = util.GlobalMemoryBuffer() - obtained_tensor = global_memory_buffer.get_tensor((3, 2), torch.float32, "test_tensor") - expected_tensor = torch.empty((3, 2), dtype=torch.float32, device=torch.cuda.current_device()) - assert obtained_tensor.shape == expected_tensor.shape - - -def test_make_viewless_tensor(): - inp = torch.rand((3, 4)) - assert torch.equal(inp, util.make_viewless_tensor(inp, True, True)) - assert torch.equal(inp, util.make_viewless_tensor(inp, True, False)) - - -def test_safely_set_viewless_tensor_data(): - tensor = torch.zeros((3, 4)) - new_data_tensor = torch.tensor(np.random.rand(3, 4)) - util.safely_set_viewless_tensor_data(tensor, new_data_tensor) - assert torch.equal(tensor, new_data_tensor) - - -def test_assert_viewless_tensor(): - tensor = torch.rand((3, 4)) - assert torch.equal(util.assert_viewless_tensor(tensor), tensor) - input_tensor_list = [tensor, tensor, tensor] - output_tensor_list = util.assert_viewless_tensor(input_tensor_list) - for inp, out in zip(input_tensor_list, output_tensor_list): - assert torch.equal(inp, out) - - -# Initialize torch.distributed; do not call init_process_group here, call -# Utils.initialize_distributed() instead. -def _init_distributed(world, rank): - Utils.initialize_distributed() - assert torch.distributed.is_initialized() == True - assert torch.distributed.get_rank() == rank - assert torch.cuda.device_count() == world - torch.distributed.barrier() - - -# Deinitialization and cleanup. -# Do not call torch.distributed.destroy_process_group, may be needed by other tests. -def _deinit_distributed(): - assert torch.distributed.is_initialized() == True - torch.distributed.barrier() - - -@pytest.mark.parametrize( - "msg,suffix", - [(None, None), ("test_message", None), (None, "test_suffix"), ("test_message", "test_suffix")], -) -def test_nvtx_range(msg, suffix): - # Track function execution - execution_tracker = {'ranges': False} - - def _call_nvtx_range(): - util.nvtx_range_push(msg, suffix) - execution_tracker['ranges'] = True - util.nvtx_range_pop(msg, suffix) - - # Test with NVTX disabled - util.configure_nvtx_profiling(False) - _call_nvtx_range() - assert execution_tracker['ranges'] - - # Reset tracker - execution_tracker['ranges'] = False - - # Test with NVTX enabled - util.configure_nvtx_profiling(True) - _call_nvtx_range() - assert execution_tracker['ranges'] - - -def test_nvtx_decorator(): - # Track function execution - execution_tracker = {'decorated': False, 'decorated_with_message': False} - - # Create decorated functions - @util.nvtx_decorator() - def nvtx_decorated_function(): - execution_tracker['decorated'] = True - - @util.nvtx_decorator(message="test_nvtx_decorator", color="red") - def nvtx_decorated_function_with_message(): - execution_tracker['decorated_with_message'] = True - - # Test with NVTX disabled - util.configure_nvtx_profiling(False) - nvtx_decorated_function() - nvtx_decorated_function_with_message() - assert all(execution_tracker.values()) - - # Reset tracker - execution_tracker = {'decorated': False, 'decorated_with_message': False} - - # Test with NVTX enabled - util.configure_nvtx_profiling(True) - nvtx_decorated_function() - nvtx_decorated_function_with_message() - assert all(execution_tracker.values()) - - -@pytest.mark.flaky_in_dev -def test_check_param_hashes_across_dp_replicas(): - world = int(os.getenv('WORLD_SIZE', '1')) - rank = int(os.getenv('RANK', '0')) - - # Setup. - _init_distributed(world, rank) - Utils.initialize_model_parallel() - model = torch.nn.Linear(100, 100, bias=False, device='cuda') - - # First check case where all replicas agree. - model.weight.data.fill_(1.0) - assert util.check_param_hashes_across_dp_replicas([model]) - - # Now check case where replica 0 disagrees with all other replicas. - if rank == 0: - model.weight.data.fill_(0.0) - param_hashes_match = util.check_param_hashes_across_dp_replicas([model]) - expected_param_hashes_match = rank == 0 - assert param_hashes_match == expected_param_hashes_match - - # Teardown. - _deinit_distributed() - - -@pytest.mark.flaky_in_dev -def test_cross_check_param_hashes_across_dp_replicas(): - world = int(os.getenv('WORLD_SIZE', '1')) - rank = int(os.getenv('RANK', '0')) - - # Setup. - _init_distributed(world, rank) - Utils.initialize_model_parallel() - model = torch.nn.Linear(100, 100, bias=False, device='cuda') - - # First check case where all replicas agree. - model.weight.data.fill_(1.0) - assert util.check_param_hashes_across_dp_replicas([model], True) - - # Now check case where replica 0 disagrees with all other replicas. - if rank == 0: - model.weight.data.fill_(0.0) - assert not util.check_param_hashes_across_dp_replicas([model], True) - - # Teardown. - _deinit_distributed() - - -@pytest.mark.parametrize("use_distributed_optimizer", [False, True]) -@pytest.mark.flaky_in_dev -def test_param_norm(use_distributed_optimizer: bool): - world = int(os.getenv('WORLD_SIZE', '1')) - rank = int(os.getenv('RANK', '0')) - - # Setup: distributed, model, mock_args. - _init_distributed(world, rank) - Utils.initialize_model_parallel() - model = torch.nn.Linear(100, 100, bias=False, dtype=torch.bfloat16, device='cuda') - model.requires_grad_(True) - model.weight.data.fill_(1.0) - ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=use_distributed_optimizer) - # Use dummy TransformerConfig which doesn't trigger __post_init__ assertions. - model = DistributedDataParallel( - TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, model - ) - for param in model.parameters(): - assert param.requires_grad - mock_args = SimpleNamespace(bf16=True) - - with mock.patch('megatron.training.utils.get_args', new=lambda: mock_args): - # Make sure norm is correct when `main_param` attribute is not available. - assert training_util.calc_params_l2_norm( - model, force_create_fp32_copy=False - ) == pytest.approx(100.0) - assert training_util.calc_params_l2_norm( - model, force_create_fp32_copy=True - ) == pytest.approx(100.0) - - # Make sure norm is correct when `main_param` attribute is available. - optimizer_config = OptimizerConfig( - bf16=True, use_distributed_optimizer=use_distributed_optimizer - ) - _ = get_megatron_optimizer(optimizer_config, [model]) - for param in model.parameters(): - assert hasattr(param, 'main_param') - if use_distributed_optimizer: - assert getattr(param, 'main_param_sharded', False) - assert training_util.calc_params_l2_norm( - model, force_create_fp32_copy=False - ) == pytest.approx(100.0) - assert training_util.calc_params_l2_norm( - model, force_create_fp32_copy=True - ) == pytest.approx(100.0) - - # Teardown. - _deinit_distributed() - - -@pytest.mark.flaky_in_dev -def test_straggler_detector(): - world = int(os.getenv('WORLD_SIZE', '1')) - rank = int(os.getenv('RANK', '0')) - master = os.getenv('MASTER_ADDR', 'localhost') - port = 65535 - - # Checks if the instance is disabled. - def straggler_detector_disabled(): - assert stimer.enabled == False - - # Checks if the instance is enabled. - def straggler_detector_enabled(): - assert stimer.enabled == True - - # Enable. - def straggler_detector_enable(): - if rank == 0: - resp = req.urlopen(f"http://{master}:{port}").read().decode().split() - assert resp[3] == "ON" - # Call the report function, this will propagate the change. - stimer.report() - - # Time an operation. - def straggler_detector_timeit(): - s = 2 # Sleep for 2 seconds. - M = 20 - K = 30 - N = 40 - mat1 = torch.randn(M, K, device='cuda') - mat2 = torch.randn(K, N, device='cuda') - # batch_data. - with stimer(bdata=True): - time.sleep(s) - # GEMM. - with stimer: - res = torch.matmul(mat1, mat2) - delta, batch_delta, _, _, _, _ = stimer.elapsed() - assert delta > 0.0 - assert batch_delta >= s - - # Test function to raise ValueError - def straggler_value_error(): - raise ValueError("Exception value raised") - - # Check that exception is not suppressed. - def straggler_detector_exception_propagate(): - # batch_data - with pytest.raises(ZeroDivisionError): - with stimer(bdata=True): - x = 1 / 0 - # non-batch-data - with pytest.raises(ValueError, match=r".* value .*"): - with stimer(): - straggler_value_error() - - # Reporting. - def straggler_detector_report(): - s = 2 # Sleep for 2 seconds. - N = 20 - P = 30 - M = 40 - mat1 = torch.randn(N, P, device='cuda') - mat2 = torch.randn(P, M, device='cuda') - tfp = (N * M) * (2 * P - 1) # Theoretical. - iter = 10 # Mock. - # batch_data. - with stimer(bdata=True): - time.sleep(s) - # GEMM. - with stimer: - res = torch.matmul(mat1, mat2) - r = stimer.report(total_flops=tfp, log_interval=iter) - rb = True if rank == 0 else False - assert r == rb - - # Start test. - # Setup. - _init_distributed(world, rank) - - # Create a straggler_detector with enabled set to false. - stimer = util.StragglerDetector() - stimer.configure(world, rank, enabled=False, port=port) - # Check if configuration was success. - assert stimer.configured == True - - # Check if the instance is in disabled state. - straggler_detector_disabled() - # Enable it now, must call report. - straggler_detector_enable() - # Check if all ranks have straggler detector enabled. - straggler_detector_enabled() - # Time some operation. - straggler_detector_timeit() - # Report only from rank 0. - straggler_detector_report() - # Check that exception is not suppressed. - straggler_detector_exception_propagate() - util.StragglerDetector._configured = False - # Teardown. - _deinit_distributed() diff --git a/tests/unit_tests/tokenizers/test_tokenizer.py b/tests/unit_tests/tokenizers/test_tokenizer.py deleted file mode 100755 index bed9f5fef5..0000000000 --- a/tests/unit_tests/tokenizers/test_tokenizer.py +++ /dev/null @@ -1,257 +0,0 @@ -import json -import sys - -import pytest -import torch -from packaging import version - -from megatron.core.tokenizers import MegatronTokenizer - - -def get_conversation(): - return [ - {"role": "system", "content": "You are a helpful AI assistant."}, - { - "role": "user", - "content": "Hi, can you help me understand how transformers work in machine learning?", - }, - { - "role": "assistant", - "content": "Sure! Transformers are a type of deep learning model introduced in the paper \"Attention Is All You Need\". They rely heavily on self-attention mechanisms to process sequences of data in parallel, unlike RNNs which process data sequentially.", - }, - {"role": "user", "content": "What is self-attention?"}, - { - "role": "assistant", - "content": "Self-attention is a mechanism that allows the model to weigh the importance of different words in a sentence when encoding each word. It helps the model capture relationships between words regardless of their distance in the sequence.", - }, - {"role": "user", "content": "Thanks, that's really helpful!"}, - {"role": "assistant", "content": "You're welcome! Let me know if you have more questions."}, - ] - - -def get_chat_template(): - return """{% for message in messages %} - {% if message['role'] == 'system' %} - <|system|> - {{ message['content'].strip() }} - {% elif message['role'] == 'user' %} - <|user|> - {{ message['content'].strip() }} - {% elif message['role'] == 'assistant' %} - <|assistant|> - {{ message['content'].strip() }} - {% endif %} - {% endfor %} - {% if add_generation_prompt %} - <|assistant|> - {% endif %}""" - - -def test_sp_tokenizer(): - # Load SP tokenizer - tokenizer = MegatronTokenizer.from_pretrained( - "/opt/data/tokenizers/sentencepiece/tokenizer.model" - ) - - # Load SP tokenizer with custom metadata - metadata = {"library": "sentencepiece", "model_type": "gpt"} - - chat_template = get_chat_template() - tokenizer = MegatronTokenizer.from_pretrained( - tokenizer_path="/opt/data/tokenizers/sentencepiece/tokenizer.model", - metadata_path=metadata, - chat_template=chat_template, - ) - - # Test chat template - tokenizer.apply_chat_template(conversation=get_conversation(), chat_template=chat_template) - - # Test tokenization - ids = tokenizer.tokenize("hi how are you?") - assert ids == [ - 7251, - 920, - 526, - 366, - 29973, - ], f"[7251, 920, 526, 366, 29973] are expeted ids but got {ids}." - - # Test detokenization - text = tokenizer.detokenize([306, 29915, 29885, 2691, 3969, 29889]) - assert text == "I'm fine thanks.", f"'I'm fine thanks.' is expeted output but got {text}." - - assert tokenizer.vocab_size == 32000 - assert tokenizer.eos_id == 2 - assert tokenizer.eod == 2 - assert tokenizer.pad == -1 - assert tokenizer.bos == 1 - - -def test_hf_tokenizer(): - # Load HF tokenizer with custom metadata - metadata = {"library": "huggingface"} - chat_template = "test chat template" - - tokenizer = MegatronTokenizer.from_pretrained( - "/opt/data/tokenizers/huggingface", metadata_path=metadata - ) - - # Load HF tokenizer with adding special tokens - special_tokens = {"bos_token": "", "eos_token": ""} - - tokenizer = MegatronTokenizer.from_pretrained( - "/opt/data/tokenizers/huggingface", - metadata_path=metadata, - chat_template=chat_template, - **special_tokens, - ) - - assert tokenizer.chat_template == chat_template - assert tokenizer.tokenize("") == [128257, 128256] - assert tokenizer.detokenize([3, 4, 5]) == "$%&" - assert tokenizer.vocab_size == 128258 - - -def test_megatron_tokenizer(): - # Load tokenizer with additional special tokens - special_tokens = {} - special_tokens['additional_special_tokens'] = [f'' for i in range(100)] - - metadata = {"library": "megatron", "model_type": "gpt"} - vocab_file = "/opt/data/tokenizers/megatron/gpt2-vocab.json" - merges_file = "/opt/data/tokenizers/megatron/gpt2-vocab.json" - tokenizer = MegatronTokenizer.from_pretrained( - tokenizer_path="GPT2BPETokenizer", - metadata_path=metadata, - vocab_file=vocab_file, - merges_file=merges_file, - **special_tokens, - ) - - # Test tokenization - ids = tokenizer.tokenize("hi how are you?") - assert ids == [ - 5303, - 703, - 389, - 345, - 30, - ], f"[5303, 703, 389, 345, 30] are expeted ids but got {ids}." - - # Test detokenization - text = tokenizer.detokenize([40, 1101, 3734, 5176, 13]) - assert text == "I'm fine thanks.", f"'I'm fine thanks.' is expeted output but got {text}." - - assert tokenizer.vocab_size == 50357 - assert tokenizer.eos_id == 50256 - assert tokenizer.eod == 50256 - assert tokenizer.model_type == "gpt" - - assert tokenizer.vocab_file == vocab_file - assert tokenizer.merges_file == merges_file - - -@pytest.mark.skipif( - version.parse(torch.__version__) < version.parse('2.3.0'), reason="Not supported for LTS" -) -def test_tiktoken_tokenizer(): - # Load tiktoken tokenizer - chat_template = get_chat_template() - tokenizer = MegatronTokenizer.from_pretrained( - tokenizer_path="/opt/data/tokenizers/tiktoken/tiktoken.vocab.json", - chat_template=chat_template, - vocab_size=131072, - ) - - # Test tokenization - ids = tokenizer.tokenize("hi how are you?") - assert ids == [ - 8101, - 2606, - 1584, - 1636, - 1063, - ], f"[8101, 2606, 1584, 1636, 1063] are expeted ids but got {ids}." - - # Test detokenization - text = tokenizer.detokenize([1073, 4525, 7771, 14899, 1046]) - assert text == "I'm fine thanks.", f"'I'm fine thanks.' is expeted output but got {text}." - - text = tokenizer.detokenize([0, 1073, 2, 5]) - assert text == "I" - - ids = tokenizer.tokenize("I") - assert ids == [0, 1073, 2, 3] - - # Test methods - assert tokenizer.vocab_size == 131072 - assert tokenizer.eos_id == 2 - assert tokenizer.eod == 2 - assert tokenizer.unk == 0 - assert tokenizer.mask == 3 - assert tokenizer.cls == 5 - - # Test chat template - tokenizer.apply_chat_template(conversation=get_conversation(), chat_template=chat_template) - - -def test_null_tokenizer(): - metadata = {"library": "null"} - tokenizer = MegatronTokenizer.from_pretrained(metadata_path=metadata, vocab_size=131072) - - ids = tokenizer.tokenize("11 325 97") - - assert ids == [11, 325, 97] - assert tokenizer.vocab_size == 131073 - - -def test_bytelevel_tokenizer(): - metadata = {"library": "byte-level"} - vocab_size = 1024 - special_tokens = ["", ""] - tokenizer = MegatronTokenizer.from_pretrained( - metadata_path=metadata, vocab_size=vocab_size, _bos_id=3, special_tokens=special_tokens - ) - - assert tokenizer.vocab_size == (vocab_size + len(special_tokens)) - assert tokenizer.tokenize("Hello") == [72, 101, 108, 108, 111] - assert tokenizer.detokenize([72, 101, 108, 108, 111]) == "Hello" - - -def test_write_metadata(): - tokenizer_path = "/opt/data/tokenizers/huggingface" - chat_template = "test chat template" - tokenizer_library = "huggingface" - MegatronTokenizer.write_metadata( - tokenizer_path=tokenizer_path, - tokenizer_library=tokenizer_library, - chat_template=chat_template, - overwrite=True, - ) - - # When metadata already exists - with pytest.raises(ValueError): - MegatronTokenizer.write_metadata( - tokenizer_path=tokenizer_path, tokenizer_library=tokenizer_library - ) - - # Overwrite metadata - class CustomTokenizerClass: - pass - - MegatronTokenizer.write_metadata( - tokenizer_path=tokenizer_path, - tokenizer_library=tokenizer_library, - tokenizer_class=CustomTokenizerClass, - overwrite=True, - ) - - # Save metadata to specific path - metadata_path = f"{tokenizer_path}/test_metadata.json" - MegatronTokenizer.write_metadata( - tokenizer_path=tokenizer_path, - metadata_path=metadata_path, - tokenizer_library=tokenizer_library, - model_type="gpt", - overwrite=True, - ) diff --git a/tests/unit_tests/transformer/moe/test_a2a_token_dispatcher.py b/tests/unit_tests/transformer/moe/test_a2a_token_dispatcher.py deleted file mode 100644 index 0c9af06156..0000000000 --- a/tests/unit_tests/transformer/moe/test_a2a_token_dispatcher.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import pytest -import torch - -from megatron.core import config -from megatron.core.utils import is_te_min_version -from tests.unit_tests.test_utilities import Utils -from tests.unit_tests.transformer.moe.test_token_dispatcher import ( - MoEModelTestContainer, - permute_fusion_params, -) - - -def test_placeholder(): - """This is here because otherwise there's no other test in this module (all disabled) - and pytest would fail.""" - pass - - -class TestAlltoAllDispatcher: - def setup_method(self, method): - pass - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.internal - @pytest.mark.timeout(120) - @pytest.mark.parametrize("tp_size,ep_size", [(1, 8), (8, 1), (4, 2), (1, 1)]) - @pytest.mark.parametrize("permute_fusion", permute_fusion_params) - def test_forward_backward(self, tp_size, ep_size, permute_fusion): - container = MoEModelTestContainer( - tp_size=tp_size, - ep_size=ep_size, - pp_size=1, - num_moe_experts=8, - moe_router_topk=2, - moe_router_load_balancing_type="aux_loss", - moe_token_dispatcher_type="alltoall", - moe_permute_fusion=permute_fusion, - ) - container.dispatcher_dropless_test() - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.internal - @pytest.mark.timeout(120) - @pytest.mark.parametrize("tp_size,ep_size", [(1, 8), (8, 1), (4, 2), (1, 1)]) - @pytest.mark.parametrize("permute_fusion", permute_fusion_params) - def test_capacity_forward_backward(self, tp_size, ep_size, permute_fusion): - container = MoEModelTestContainer( - tp_size=tp_size, - ep_size=ep_size, - pp_size=1, - num_moe_experts=8, - moe_router_topk=2, - moe_router_load_balancing_type="aux_loss", - moe_token_dispatcher_type="alltoall", - moe_token_drop_policy="probs", - moe_expert_capacity_factor=0.5, - moe_pad_expert_input_to_capacity=False, - moe_permute_fusion=permute_fusion, - ) - container.dispatcher_capacity_test() - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.internal - @pytest.mark.timeout(120) - @pytest.mark.parametrize("tp_size,ep_size", [(1, 8), (8, 1), (4, 2), (1, 1)]) - @pytest.mark.parametrize("permute_fusion", permute_fusion_params) - def test_capacity_padding_forward_backward(self, tp_size, ep_size, permute_fusion): - container = MoEModelTestContainer( - tp_size=tp_size, - ep_size=ep_size, - pp_size=1, - num_moe_experts=8, - moe_router_topk=2, - moe_router_load_balancing_type="aux_loss", - moe_token_dispatcher_type="alltoall", - moe_token_drop_policy="probs", - moe_expert_capacity_factor=0.6, - moe_pad_expert_input_to_capacity=True, - moe_permute_fusion=permute_fusion, - ) - container.dispatcher_drop_and_pad_test() - - @pytest.mark.skipif( - not is_te_min_version("1.7.0"), reason="TE 1.7.0 is required for MoE with FP8." - ) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.internal - @pytest.mark.timeout(120) - @pytest.mark.parametrize("tp_size,ep_size", [(1, 8), (8, 1), (4, 2)]) - @pytest.mark.parametrize("permute_fusion", permute_fusion_params) - @pytest.mark.parametrize("experimental_fusion", [True, False]) - def test_router_padding_for_fp8_forward_backward( - self, tp_size, ep_size, permute_fusion, experimental_fusion - ): - if experimental_fusion: - config.ENABLE_EXPERIMENTAL = True - container = MoEModelTestContainer( - tp_size=tp_size, - ep_size=ep_size, - pp_size=1, - num_moe_experts=8, - moe_router_topk=2, - moe_router_load_balancing_type="aux_loss", - moe_token_dispatcher_type="alltoall", - moe_pad_expert_input_to_capacity=False, - moe_permute_fusion=permute_fusion, - hidden_size=4, - ) - container.dispatcher_router_padding_for_fp8_test() - config.ENABLE_EXPERIMENTAL = False diff --git a/tests/unit_tests/transformer/moe/test_aux_loss.py b/tests/unit_tests/transformer/moe/test_aux_loss.py deleted file mode 100644 index 61bbfc5dfd..0000000000 --- a/tests/unit_tests/transformer/moe/test_aux_loss.py +++ /dev/null @@ -1,577 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import dataclasses - -import pytest -import torch - -from megatron.core import parallel_state -from megatron.core.tensor_parallel.mappings import reduce_from_tensor_model_parallel_region -from megatron.core.tensor_parallel.random import ( - get_cuda_rng_tracker, - model_parallel_cuda_manual_seed, -) -from megatron.core.transformer.moe.moe_utils import ( - clear_aux_losses_tracker, - get_default_model_comm_pgs, - get_moe_layer_wise_logging_tracker, -) -from megatron.core.transformer.moe.router import TopKRouter -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.training.initialize import _set_random_seed -from tests.unit_tests.test_utilities import Utils -from tests.unit_tests.transformer.moe.test_token_dispatcher import MoEModelTestContainer - -try: - # Check availability of TE fused router aux ops - from megatron.core.extensions.transformer_engine import ( - fused_compute_score_for_moe_aux_loss as _fused_compute_score_for_moe_aux_loss, - ) - from megatron.core.extensions.transformer_engine import ( - fused_moe_aux_loss as _fused_moe_aux_loss, - ) - - HAVE_ROUTER_FUSION = ( - _fused_compute_score_for_moe_aux_loss is not None and _fused_moe_aux_loss is not None - ) -except Exception: # pragma: no cover - defensive - HAVE_ROUTER_FUSION = False - - -class AuxlossTestContainer(MoEModelTestContainer): - def partition_input(self, input): - partitioned_input = input.chunk( - parallel_state.get_tensor_and_context_parallel_world_size(), dim=0 - )[parallel_state.get_tensor_and_context_parallel_rank()] - output = partitioned_input.clone().detach() - output.requires_grad = True - return output - - @pytest.mark.internal - def aux_loss_test(self, input, baseline_grad, loss_name): - partitioned_input = self.partition_input(input) - moe_layer = self.moe_layer - probs, indices = moe_layer.router(partitioned_input) - probs.sum().mul_(0).backward() - aux_loss_grad = partitioned_input.grad - torch.distributed.barrier() - ans = self.partition_input(baseline_grad) - assert torch.allclose(aux_loss_grad, ans), f"Diff: {(aux_loss_grad/ans).mean()}" - loss = get_moe_layer_wise_logging_tracker()[loss_name]['values'] - assert loss > 0, "Loss should be greater than 0" - clear_aux_losses_tracker() - - with torch.no_grad(): - probs, indices = moe_layer.router(partitioned_input) - loss = get_moe_layer_wise_logging_tracker()[loss_name]['values'] - assert loss == 0, "Loss should be 0" - clear_aux_losses_tracker() - - -class TestAuxLoss: - def setup_method(self, method): - baseline_container = AuxlossTestContainer( - tp_size=1, - ep_size=1, - pp_size=1, - cp_size=1, - num_moe_experts=8, - moe_router_topk=2, - moe_router_load_balancing_type="aux_loss", - moe_token_dispatcher_type="alltoall", - moe_aux_loss_coeff=0.1, - ) - moe_layer = baseline_container.moe_layer - self.input = torch.randn((32, 8, moe_layer.config.hidden_size)).cuda() - self.input.requires_grad = True - probs, indices = moe_layer.router(self.input) - probs.sum().mul_(0).backward() # zero out the main gradients - self.baseline_grad = self.input.grad - self.input.grad = None - clear_aux_losses_tracker() - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.internal - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.parametrize( - "tp_size,ep_size,cp_size", [(8, 1, 1), (4, 2, 1), (1, 1, 8), (2, 1, 4), (2, 2, 2)] - ) - def test_allgather_dispatcher(self, tp_size, ep_size, cp_size): - container = AuxlossTestContainer( - tp_size=tp_size, - ep_size=ep_size, - pp_size=1, - cp_size=cp_size, - num_moe_experts=8, - moe_router_topk=2, - moe_router_load_balancing_type="aux_loss", - moe_token_dispatcher_type="allgather", - moe_aux_loss_coeff=0.1, - ) - container.aux_loss_test(self.input, self.baseline_grad, "load_balancing_loss") - - @pytest.mark.internal - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.parametrize( - "tp_size,ep_size,cp_size", [(8, 1, 1), (4, 2, 1), (1, 1, 8), (2, 1, 4), (2, 2, 2)] - ) - def test_a2a_dispatcher(self, tp_size, ep_size, cp_size): - container = AuxlossTestContainer( - tp_size=tp_size, - ep_size=ep_size, - pp_size=1, - cp_size=cp_size, - num_moe_experts=8, - moe_router_topk=2, - moe_router_load_balancing_type="aux_loss", - moe_token_dispatcher_type="alltoall", - moe_aux_loss_coeff=0.1, - ) - container.aux_loss_test(self.input, self.baseline_grad, "load_balancing_loss") - - -class TestSeqAuxLoss: - def setup_method(self, method): - baseline_container = AuxlossTestContainer( - tp_size=1, - ep_size=1, - pp_size=1, - cp_size=1, - num_moe_experts=8, - moe_router_topk=2, - moe_router_load_balancing_type="seq_aux_loss", - moe_token_dispatcher_type="alltoall", - moe_aux_loss_coeff=0.1, - ) - moe_layer = baseline_container.moe_layer - self.input = torch.randn((32, 8, moe_layer.config.hidden_size)).cuda() - self.input.requires_grad = True - probs, indices = moe_layer.router(self.input) - probs.sum().mul_(0).backward() # zero out the main gradients - self.baseline_grad = self.input.grad - self.input.grad = None - clear_aux_losses_tracker() - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.internal - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.parametrize( - "tp_size,ep_size,cp_size", [(8, 1, 1), (4, 2, 1), (1, 1, 8), (2, 1, 4), (2, 2, 2)] - ) - def test_a2a_dispatcher(self, tp_size, ep_size, cp_size): - container = AuxlossTestContainer( - tp_size=tp_size, - ep_size=ep_size, - pp_size=1, - cp_size=cp_size, - num_moe_experts=8, - moe_router_topk=2, - moe_router_load_balancing_type="seq_aux_loss", - moe_token_dispatcher_type="alltoall", - moe_aux_loss_coeff=0.1, - ) - container.aux_loss_test(self.input, self.baseline_grad, "seq_load_balancing_loss") - - -class TestRouterAuxLoss: - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - _set_random_seed(seed_=123, data_parallel_random_init=False) - - # Default configuration - self.default_transformer_config = TransformerConfig( - num_layers=1, - hidden_size=12, - num_attention_heads=8, - num_moe_experts=32, - use_cpu_initialization=True, - moe_router_load_balancing_type="aux_loss", - moe_router_topk=8, - moe_aux_loss_coeff=0, - bf16=True, - params_dtype=torch.bfloat16, - add_bias_linear=False, - ) - - def new_router(self, **kwargs): - """Create a new router with updated configuration. - - Args: - **kwargs: Configuration parameters to update in the default config. - - Returns: - Router: A new router instance with the specified configuration. - """ - model_comm_pgs = get_default_model_comm_pgs() - # Create a new config with updated parameters - new_transformer_config = dataclasses.replace(self.default_transformer_config, **kwargs) - - # Create the router with the updated config - router = TopKRouter(config=new_transformer_config, model_comm_pgs=model_comm_pgs) - router.set_layer_number(0) - return router - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.internal - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.parametrize( - "tp_size,ep_size,cp_size", [(8, 1, 1), (4, 2, 1), (1, 1, 8), (2, 1, 4), (2, 2, 2)] - ) - def test_seq_aux_loss(self, tp_size, ep_size, cp_size): - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp_size, - expert_tensor_parallel_size=ep_size, - context_parallel_size=cp_size, - ) - model_parallel_cuda_manual_seed(42) - - # Test that with batch_size=1, aux_loss and seq_aux_loss should be the same - aux_loss_router = self.new_router( - moe_router_load_balancing_type="aux_loss", - moe_aux_loss_coeff=1.0, - moe_router_dtype="fp64", - tensor_model_parallel_size=tp_size, - expert_tensor_parallel_size=ep_size, - context_parallel_size=cp_size, - ).cuda() - seq_aux_loss_router = self.new_router( - moe_router_load_balancing_type="seq_aux_loss", - moe_aux_loss_coeff=1.0, - moe_router_dtype="fp64", - tensor_model_parallel_size=tp_size, - expert_tensor_parallel_size=ep_size, - context_parallel_size=cp_size, - ).cuda() - - # Set identical weights for fair comparison - with torch.no_grad(): - seq_aux_loss_router.weight.copy_(aux_loss_router.weight) - - ### MBS=1 case: results should be identical ### - clear_aux_losses_tracker() - seq_len = 32 - batch_size = 1 - with get_cuda_rng_tracker().fork(): - hidden_states = torch.randn( - (seq_len, batch_size, aux_loss_router.config.hidden_size), - device=torch.device("cuda"), - dtype=torch.bfloat16, - ) - - # Forward pass for aux_loss router - aux_loss_router.weight.grad = None - scores1, routing_map1 = aux_loss_router(hidden_states) - loss1 = scores1.sum() - loss1.backward() - grad1 = aux_loss_router.weight.grad.clone() - - # Forward pass for seq_aux_loss router - seq_aux_loss_router.weight.grad = None - scores2, routing_map2 = seq_aux_loss_router(hidden_states) - loss2 = scores2.sum() - loss2.backward() - grad2 = seq_aux_loss_router.weight.grad.clone() - - # For batch_size=1, they should produce the same results - tracker = get_moe_layer_wise_logging_tracker() - aux_loss = tracker["load_balancing_loss"]["values"][0] - seq_aux_loss = tracker["seq_load_balancing_loss"]["values"][0] - - reduce_from_tensor_model_parallel_region(aux_loss, aux_loss_router.tp_cp_group) - reduce_from_tensor_model_parallel_region(seq_aux_loss, aux_loss_router.tp_cp_group) - - assert torch.equal(routing_map1, routing_map2) - assert torch.equal(grad1, grad2) - assert torch.equal(scores1, scores2) - assert aux_loss == seq_aux_loss, f"aux_loss: {aux_loss}, seq_aux_loss: {seq_aux_loss}" - - ### MBS=2 case ### - clear_aux_losses_tracker() - batch_size = 2 - with get_cuda_rng_tracker().fork(): - hidden_states = torch.randn( - (seq_len, batch_size, aux_loss_router.config.hidden_size), - device=torch.device("cuda"), - dtype=torch.bfloat16, - ) - - # Forward pass for aux_loss router - aux_loss_router.weight.grad = None - scores_first_batch, _ = aux_loss_router(hidden_states[:, 0:1, :]) - scores_second_batch, _ = aux_loss_router(hidden_states[:, 1:, :]) - - # setting grad to 0 to only backward aux loss - (scores_first_batch + scores_second_batch).backward(torch.zeros_like(scores_first_batch)) - - grad1 = aux_loss_router.weight.grad.clone() - - # Forward pass for seq_aux_loss router - seq_aux_loss_router.weight.grad = None - scores2, routing_map2 = seq_aux_loss_router(hidden_states) - # setting grad to 0 to only backward aux loss - scores2.backward(torch.zeros_like(scores2)) - grad2 = seq_aux_loss_router.weight.grad.clone() * 2 - - aux_loss = tracker["load_balancing_loss"]["values"][0] / 2 - seq_aux_loss = tracker["seq_load_balancing_loss"]["values"][0] - reduce_from_tensor_model_parallel_region(aux_loss, aux_loss_router.tp_cp_group) - reduce_from_tensor_model_parallel_region(seq_aux_loss, aux_loss_router.tp_cp_group) - - torch.testing.assert_close(aux_loss, seq_aux_loss) - torch.testing.assert_close(grad1, grad2) - - @pytest.mark.internal - @pytest.mark.skipif( - not torch.cuda.is_available() or not HAVE_ROUTER_FUSION, - reason="CUDA or TE fused router ops not available", - ) - @pytest.mark.parametrize("aux_type", ["aux_loss", "seq_aux_loss"]) - def test_aux_loss_fusion_equivalence(self, aux_type): - # Compare fused vs unfused aux loss path to ensure numerical equivalence - router_ref = self.new_router( - moe_router_load_balancing_type=aux_type, moe_aux_loss_coeff=1.0, moe_router_dtype="fp32" - ).cuda() - router_fused = self.new_router( - moe_router_load_balancing_type=aux_type, moe_aux_loss_coeff=1.0, moe_router_dtype="fp32" - ).cuda() - - with torch.no_grad(): - router_fused.weight.copy_(router_ref.weight) - - hidden_states = torch.randn((32, 2, router_ref.config.hidden_size)).cuda().bfloat16() - - # Map aux type to its tracker key - loss_name_map = { - "aux_loss": "load_balancing_loss", - "seq_aux_loss": "seq_load_balancing_loss", - } - loss_name = loss_name_map[aux_type] - - # Unfused - router_ref.config.moe_router_fusion = False - clear_aux_losses_tracker() - router_ref.weight.grad = None - scores_ref, routing_ref = router_ref(hidden_states) - # Backward zeros to isolate aux-loss-only gradient contribution - scores_ref.backward(torch.zeros_like(scores_ref)) - grad_ref = router_ref.weight.grad.clone() - tracker = get_moe_layer_wise_logging_tracker() - aux_loss_ref = tracker[loss_name]["values"][0] - reduce_from_tensor_model_parallel_region(aux_loss_ref, router_ref.tp_cp_group) - - # Fused - router_fused.config.moe_router_fusion = True - clear_aux_losses_tracker() - router_fused.weight.grad = None - scores_fused, routing_fused = router_fused(hidden_states) - scores_fused.backward(torch.zeros_like(scores_fused)) - grad_fused = router_fused.weight.grad.clone() - tracker = get_moe_layer_wise_logging_tracker() - aux_loss_fused = tracker[loss_name]["values"][0] - reduce_from_tensor_model_parallel_region(aux_loss_fused, router_fused.tp_cp_group) - - # Checks - assert torch.equal(routing_ref, routing_fused) - torch.testing.assert_close(scores_ref, scores_fused, rtol=2.0e-2, atol=1.0e-3) - torch.testing.assert_close(aux_loss_ref, aux_loss_fused) - torch.testing.assert_close(grad_ref, grad_fused) - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.parametrize( - "tp_size,ep_size,cp_size", [(8, 1, 1), (4, 2, 1), (1, 1, 8), (2, 1, 4), (2, 2, 2)] - ) - def test_global_aux_loss(self, tp_size, ep_size, cp_size): - clear_aux_losses_tracker() - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp_size, - expert_tensor_parallel_size=ep_size, - context_parallel_size=cp_size, - ) - - router = self.new_router( - moe_router_load_balancing_type="global_aux_loss", - moe_aux_loss_coeff=1.0, - tensor_model_parallel_size=tp_size, - expert_tensor_parallel_size=ep_size, - context_parallel_size=cp_size, - ).cuda() - - seq_len = 32 - # Verify global tokens tracker initialized - assert router.global_tokens_per_expert is not None - assert router.ga_steps == 0 - - # First microbatch - with get_cuda_rng_tracker().fork(): - hidden_states = torch.randn((seq_len, 2, router.config.hidden_size)).cuda().bfloat16() - num_local_tokens = seq_len * 2 - scores, routing_map = router(hidden_states) - # Check that global tokens were counted - assert torch.all(router.global_tokens_per_expert >= 0) - assert ( - router.global_tokens_per_expert.sum() - == num_local_tokens * router.tp_dp_cp_group.size() * router.ga_steps * router.topk - ) - global_aux_loss_1 = get_moe_layer_wise_logging_tracker()["global_load_balancing_loss"][ - "values" - ][0] - reduce_from_tensor_model_parallel_region(global_aux_loss_1, router.tp_dp_cp_group) - assert global_aux_loss_1 >= 1 - - # When DP size is 1, the global aux loss should match the aux loss - # for the first microbatch - if get_default_model_comm_pgs().tp_dp_cp.size() == tp_size: - ref_router = self.new_router( - moe_router_load_balancing_type="aux_loss", moe_aux_loss_coeff=1.0 - ).cuda() - with torch.no_grad(): - ref_router.weight.copy_(router.weight) - ref_scores, ref_routing_map = ref_router(hidden_states) - aux_loss = get_moe_layer_wise_logging_tracker()["load_balancing_loss"]["values"][0] - reduce_from_tensor_model_parallel_region(aux_loss, router.tp_cp_group) - - assert torch.equal( - aux_loss, global_aux_loss_1 - ), f"aux_loss: {aux_loss}, global_aux_loss_1: {global_aux_loss_1}" - - clear_aux_losses_tracker() - - # Get current tokens count to verify accumulation - current_per_expert = router.global_tokens_per_expert.clone() - - # Second microbatch - should accumulate - hidden_states = torch.randn((seq_len, 2, router.config.hidden_size)).cuda().bfloat16() - scores, routing_map = router(hidden_states) - global_aux_loss_2 = get_moe_layer_wise_logging_tracker()["global_load_balancing_loss"][ - "values" - ][0] - reduce_from_tensor_model_parallel_region(global_aux_loss_2, router.tp_dp_cp_group) - assert torch.all(global_aux_loss_2 >= 1), f"global_aux_loss_2: {global_aux_loss_2}" - - # Verify tokens were accumulated - assert router.ga_steps == 2 - assert torch.any(router.global_tokens_per_expert > current_per_expert) - clear_aux_losses_tracker() - - # Reset global tracker - router.reset_global_aux_loss_tracker() - assert router.ga_steps == 0 - assert torch.all(router.global_tokens_per_expert == 0) - - @pytest.mark.internal - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.parametrize( - "tp_size,ep_size,cp_size", [(8, 1, 1), (4, 2, 1), (1, 1, 8), (2, 1, 4), (2, 2, 2)] - ) - def test_combined_aux_loss(self, tp_size, ep_size, cp_size): - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp_size, - expert_tensor_parallel_size=ep_size, - context_parallel_size=cp_size, - ) - clear_aux_losses_tracker() - - # Test combined aux loss types - router = self.new_router( - moe_router_load_balancing_type=["aux_loss", "seq_aux_loss", "global_aux_loss"], - moe_aux_loss_coeff=[0.5, 1.0, 2.0], - tensor_model_parallel_size=tp_size, - expert_tensor_parallel_size=ep_size, - context_parallel_size=cp_size, - ).cuda() - - # Verify all aux loss trackers initialized - assert router.global_tokens_per_expert is not None - assert router.ga_steps == 0 - - # Execute forward pass - hidden_states = torch.randn((32, 2, router.config.hidden_size)).cuda().bfloat16() - router.weight.grad = None - scores, routing_map = router(hidden_states) - loss = scores.sum() - loss.backward() - - aux_loss = get_moe_layer_wise_logging_tracker()["load_balancing_loss"]["values"][0] - seq_aux_loss = get_moe_layer_wise_logging_tracker()["seq_load_balancing_loss"]["values"][0] - global_aux_loss = get_moe_layer_wise_logging_tracker()["global_load_balancing_loss"][ - "values" - ][0] - - reduce_from_tensor_model_parallel_region(aux_loss, router.tp_cp_group) - reduce_from_tensor_model_parallel_region(seq_aux_loss, router.tp_cp_group) - reduce_from_tensor_model_parallel_region(global_aux_loss, router.tp_dp_cp_group) - - assert aux_loss >= 1 - assert seq_aux_loss >= 1 - assert global_aux_loss >= 1 - - # Verify gradient is non-zero (aux losses are being applied) - assert router.weight.grad.abs().sum() > 0 - - # Verify method to get aux loss coeffs works properly - assert router.get_aux_loss_coeff("aux_loss") == 0.5 - assert router.get_aux_loss_coeff("seq_aux_loss") == 1.0 - assert router.get_aux_loss_coeff("global_aux_loss") == 2.0 - assert router.get_aux_loss_coeff("non_existent_type") == 0.0 - - @pytest.mark.internal - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.parametrize( - "tp_size,ep_size,cp_size", [(8, 1, 1), (4, 2, 1), (1, 1, 8), (2, 1, 4), (2, 2, 2)] - ) - def test_force_balanced_aux_loss(self, tp_size, ep_size, cp_size): - """Test if aux loss is 1.0 when using uniform routing""" - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp_size, - expert_tensor_parallel_size=ep_size, - context_parallel_size=cp_size, - ) - clear_aux_losses_tracker() - seq_len = 32 - batch_size = 2 - - # Create router with each aux loss type - for aux_loss_type in ["aux_loss", "seq_aux_loss", "global_aux_loss"]: - router = self.new_router( - moe_router_load_balancing_type=aux_loss_type, - moe_aux_loss_coeff=1.0, - moe_router_dtype="fp32", - tensor_model_parallel_size=tp_size, - expert_tensor_parallel_size=ep_size, - context_parallel_size=cp_size, - ).cuda() - # create uniform weights - with torch.no_grad(): - router.weight.copy_(torch.ones_like(router.weight) / router.weight.numel()) - - # Create uniform logits (all experts equally likely) - hidden_size = router.config.hidden_size - num_experts = router.config.num_moe_experts - - loss_name = { - "aux_loss": "load_balancing_loss", - "seq_aux_loss": "seq_load_balancing_loss", - "global_aux_loss": "global_load_balancing_loss", - }[aux_loss_type] - - hidden_states = torch.randn( - (seq_len, batch_size, hidden_size), - device=torch.device("cuda"), - dtype=torch.bfloat16, - ) - - # Get routing scores and map - scores, routing_map = router(hidden_states) - aux_loss = get_moe_layer_wise_logging_tracker()[loss_name]["values"][0] - if aux_loss_type == "global_aux_loss": - reduce_from_tensor_model_parallel_region(aux_loss, router.tp_dp_cp_group) - else: - reduce_from_tensor_model_parallel_region(aux_loss, router.tp_cp_group) - assert aux_loss.item() == 1, f"{aux_loss_type}: {aux_loss.item()}" - clear_aux_losses_tracker() diff --git a/tests/unit_tests/transformer/moe/test_grouped_mlp.py b/tests/unit_tests/transformer/moe/test_grouped_mlp.py deleted file mode 100644 index f215f9008b..0000000000 --- a/tests/unit_tests/transformer/moe/test_grouped_mlp.py +++ /dev/null @@ -1,396 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import pytest -import torch -import torch.nn.functional as F - -from megatron.core.models.gpt.gpt_layer_specs import ( - get_gpt_layer_local_spec, - get_gpt_layer_with_transformer_engine_spec, -) -from megatron.core.transformer.module import Float16Module -from megatron.core.transformer.moe import grouped_gemm_util as gg -from megatron.core.transformer.moe.experts import TEGroupedMLP -from megatron.core.transformer.moe.moe_layer import MoELayer -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import is_te_min_version -from megatron.training.arguments import parse_args -from megatron.training.initialize import _set_random_seed -from tests.unit_tests.test_utilities import Utils - -DEVICE_CAPABILITY = None -if torch.cuda.is_available(): - DEVICE_CAPABILITY = torch.cuda.get_device_capability() - - -@pytest.mark.skipif(is_te_min_version("1.9.0.dev0"), reason="Switch to TEGroupedMLP when TE>1.9.") -class TestParallelGroupedMLP: - - def setup_method(self, method, use_cpu_initialization=False, swiglu=True): - print("============") - print( - "Test for use_cpu_initilization={} and swiglu={}.".format( - use_cpu_initialization, swiglu - ) - ) - print("============") - Utils.initialize_model_parallel(1, 1) - num_layers = 1 # 2 - self.hidden_size = ( - 16 # must be an multiple of 16, otherwise trigger CUTLASS misaligned issue - ) - self.num_experts = 2 - self.gated_linear_unit = swiglu - self.activation_func = F.silu if swiglu else F.gelu - self.use_cpu_initialization = use_cpu_initialization - - tf_config = TransformerConfig( - num_layers=num_layers, - hidden_size=self.hidden_size, - num_attention_heads=4, - num_moe_experts=self.num_experts, - use_cpu_initialization=self.use_cpu_initialization, - add_bias_linear=False, - gated_linear_unit=self.gated_linear_unit, - activation_func=self.activation_func, - bias_activation_fusion=False, - bf16=True, - params_dtype=torch.bfloat16, - moe_router_load_balancing_type="sinkhorn", - moe_router_topk=1, - ) - - self.fc1_ffn_hidden_size = tf_config.ffn_hidden_size - self.fc2_ffn_hidden_size = tf_config.ffn_hidden_size - # If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf - if self.gated_linear_unit: - self.fc1_ffn_hidden_size *= 2 - - ## Vanilla sequential GEMM - # Set random seed for reproducability - _set_random_seed(seed_=123, data_parallel_random_init=False) - transformer_layer_spec = get_gpt_layer_local_spec(self.num_experts, moe_grouped_gemm=False) - self.sequential_mlp = MoELayer(tf_config, transformer_layer_spec.submodules.mlp.submodules) - - self.args = parse_args(ignore_unknown_args=True) - self.args.bf16 = True - # Bias is not supported in grouped gemm currently, thus we disable the - # bias in the linear layer. - self.args.add_bias_linear = False - self.sequential_mlp = Float16Module(self.sequential_mlp.config, self.sequential_mlp).module - print("done intializing for sequential gemm") - - ## Grouped GEMM - _set_random_seed(seed_=123, data_parallel_random_init=False) - tf_config.moe_grouped_gemm = True - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - self.num_experts, moe_grouped_gemm=True - ) - self.grouped_mlp = MoELayer(tf_config, transformer_layer_spec.submodules.mlp.submodules) - self.grouped_mlp = Float16Module(self.grouped_mlp.config, self.grouped_mlp).module - print("done intializing for grouped gemm") - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.internal - def test_constructor(self): - assert isinstance(self.sequential_mlp, MoELayer) - assert isinstance(self.grouped_mlp, MoELayer) - - num_weights_smm = sum([p.numel() for p in self.sequential_mlp.parameters()]) - num_weights_gmm = sum([p.numel() for p in self.grouped_mlp.parameters()]) - - # For the same hyper-parm model configs except the `moe_grouped_gemm`, - # GroupedGEMM and sequential GEMMs should hold the same number of parms. - assert num_weights_smm == num_weights_gmm - # expected num weights: router linear weights+bias + MLP weights(no bias) of all experts - expected_num_weights = ( - self.hidden_size * self.num_experts - + self.hidden_size - * (self.fc1_ffn_hidden_size + self.fc2_ffn_hidden_size) - * self.num_experts - ) - assert num_weights_smm == expected_num_weights - - assert torch.equal(self.sequential_mlp.router.weight, self.grouped_mlp.router.weight) - - # weight1: [h, num_experts*4h] - # weight2: [num_experts*4h, h] - assert self.grouped_mlp.experts.weight1.shape[0] == self.hidden_size - assert ( - self.grouped_mlp.experts.weight1.shape[1] == self.num_experts * self.fc1_ffn_hidden_size - ) - if self.gated_linear_unit: - assert ( - self.grouped_mlp.experts.weight2.shape[0] - == self.num_experts * self.fc2_ffn_hidden_size - ) - assert self.grouped_mlp.experts.weight2.shape[1] == self.hidden_size - else: - assert ( - self.grouped_mlp.experts.weight1.shape == self.grouped_mlp.experts.weight2.t().shape - ) - - @pytest.mark.internal - def test_weight_init_value_the_same(self): - gmm_w1 = self.grouped_mlp.experts.weight1.view(self.num_experts, -1, self.hidden_size) - gmm_w2 = self.grouped_mlp.experts.weight2.view(self.num_experts, self.hidden_size, -1) - gmm_expert1_fc1 = gmm_w1[0] - gmm_expert1_fc2 = gmm_w2[0] - gmm_expert2_fc1 = gmm_w1[1] - gmm_expert2_fc2 = gmm_w2[1] - - smm_expert1_fc1 = self.sequential_mlp.experts.local_experts[0].linear_fc1.weight - smm_expert1_fc2 = self.sequential_mlp.experts.local_experts[0].linear_fc2.weight - smm_expert2_fc1 = self.sequential_mlp.experts.local_experts[1].linear_fc1.weight - smm_expert2_fc2 = self.sequential_mlp.experts.local_experts[1].linear_fc2.weight - - assert torch.equal(gmm_expert1_fc1, smm_expert1_fc1) - if not self.use_cpu_initialization: - assert torch.equal(gmm_expert1_fc2, smm_expert1_fc2) - # the param init value is not exactly the same between gmm and smm (refer to test_weight_init_value_the_same.) - # TODO: is it necessary to keep smm and gmm share exactly the same init params? - # assert torch.equal(gmm_expert2_fc1, smm_expert2_fc1) - if self.use_cpu_initialization: - assert torch.equal(gmm_expert2_fc2, smm_expert2_fc2) - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.internal - @pytest.mark.skipif( - not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8, - reason='GroupedGEMM kernels are not supported on this device.', - ) - def test_gpu_forward(self): - self.sequential_mlp.cuda() - self.grouped_mlp.cuda() - # [sequence length, batch size, hidden size] - seq_len = 3 # 32 - batch_size = 2 - hidden_states = torch.rand( - (seq_len, batch_size, self.sequential_mlp.config.hidden_size), dtype=torch.bfloat16 - ) - hidden_states = hidden_states.cuda() - output_smm, _ = self.sequential_mlp(hidden_states) - output_gmm, _ = self.grouped_mlp(hidden_states) - - # The following assert fails due to the param init value is not exactly - # the same between gmm and smm (refer to test_weight_init_value_the_same.) - # assert torch.equal(output_smm, output_gmm) - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.internal - @pytest.mark.skipif( - not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8, - reason='GroupedGEMM kernels are not supported on this device.', - ) - def test_gpu_forward_with_no_tokens_allocated(self): - """Test the case when no token is allocated for groupedGEMM kernels.""" - w1 = self.grouped_mlp.experts.weight1.view(self.num_experts, -1, self.hidden_size) - num_allocated_tokens = 0 - tokens_per_expert = torch.zeros(self.num_experts) - hidden_states = torch.rand((num_allocated_tokens, self.hidden_size), dtype=torch.bfloat16) - hidden_states = hidden_states.cuda() - try: - gg.ops.gmm(hidden_states, w1, tokens_per_expert, trans_b=False) - except Exception as e: - print("Expected error message from groupedGEMM:", e) - assert str(e) == "Input batch_sizes should not be all zeros!" - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.internal - @pytest.mark.skipif( - not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8, - reason='GroupedGEMM kernels are not supported on this device.', - ) - def test_gradient_with_no_tokens_allocated(self): - """Test that when no token is passed in, the parameters of the grouped MLP will also have gradients.""" - self.grouped_mlp.cuda() - num_allocated_tokens = 0 - tokens_per_expert = torch.zeros(self.num_experts) - hidden_states = torch.rand((num_allocated_tokens, self.hidden_size), dtype=torch.bfloat16) - hidden_states = hidden_states.cuda() - probs = torch.rand((num_allocated_tokens,), dtype=torch.float32) - probs = probs.cuda() - output_gmm, _ = self.grouped_mlp.experts( - hidden_states, tokens_per_expert=tokens_per_expert, permuted_probs=probs - ) - output_gmm.mean().backward() - assert self.grouped_mlp.experts.weight1.grad is not None - - -@pytest.mark.skipif( - not is_te_min_version("1.9.0.dev0"), - reason="TE Grouped MLP is only supported in TE 1.9.0.dev0 and later.", -) -class TestTEGroupedMLP: - - def setup_method(self, method, use_cpu_initialization=False, swiglu=True): - Utils.initialize_model_parallel(1, 1) - num_layers = 1 - self.hidden_size = 16 - self.num_experts = 2 - self.gated_linear_unit = swiglu - self.activation_func = F.silu if swiglu else F.gelu - self.use_cpu_initialization = use_cpu_initialization - - tf_config = TransformerConfig( - num_layers=num_layers, - hidden_size=self.hidden_size, - num_attention_heads=4, - num_moe_experts=self.num_experts, - use_cpu_initialization=self.use_cpu_initialization, - add_bias_linear=False, - gated_linear_unit=self.gated_linear_unit, - activation_func=self.activation_func, - bias_activation_fusion=False, - bf16=True, - params_dtype=torch.bfloat16, - moe_router_load_balancing_type="sinkhorn", - moe_router_topk=1, - ) - - self.fc1_ffn_hidden_size = tf_config.ffn_hidden_size - self.fc2_ffn_hidden_size = tf_config.ffn_hidden_size - # If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf - if self.gated_linear_unit: - self.fc1_ffn_hidden_size *= 2 - - ## Vanilla sequential GEMM - # Set random seed for reproducability - _set_random_seed(seed_=123, data_parallel_random_init=False) - transformer_layer_spec = get_gpt_layer_local_spec(self.num_experts, moe_grouped_gemm=False) - self.sequential_mlp = MoELayer(tf_config, transformer_layer_spec.submodules.mlp.submodules) - - self.args = parse_args(ignore_unknown_args=True) - self.args.bf16 = True - # Bias is not supported in grouped gemm currently, thus we disable the - # bias in the linear layer. - self.args.add_bias_linear = False - self.sequential_mlp = Float16Module(self.sequential_mlp.config, self.sequential_mlp).module - - ## Grouped GEMM - _set_random_seed(seed_=123, data_parallel_random_init=False) - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - self.num_experts, moe_grouped_gemm=True - ) - tf_config.moe_grouped_gemm = True - self.grouped_mlp = MoELayer(tf_config, transformer_layer_spec.submodules.mlp.submodules) - assert isinstance(self.grouped_mlp.experts, TEGroupedMLP) - self.grouped_mlp = Float16Module(self.grouped_mlp.config, self.grouped_mlp).module - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.internal - def test_constructor(self): - assert isinstance(self.sequential_mlp, MoELayer) - assert isinstance(self.grouped_mlp, MoELayer) - - num_weights_smm = sum([p.numel() for p in self.sequential_mlp.parameters()]) - num_weights_gmm = sum([p.numel() for p in self.grouped_mlp.parameters()]) - - # For the same hyper-parm model configs except the `moe_grouped_gemm`, - # GroupedGEMM and sequential GEMMs should hold the same number of parms. - assert num_weights_smm == num_weights_gmm - # expected num weights: router linear weights+bias + MLP weights(no bias) of all experts - expected_num_weights = ( - self.hidden_size * self.num_experts - + self.hidden_size - * (self.fc1_ffn_hidden_size + self.fc2_ffn_hidden_size) - * self.num_experts - ) - assert num_weights_smm == expected_num_weights - - assert torch.equal(self.sequential_mlp.router.weight, self.grouped_mlp.router.weight) - - # weights of linear_fc1: [fc1_ffn_hidden_size, hidden_size] - # weights of linear_fc2: [hidden_size, fc2_ffn_hidden_size] - for i in range(self.num_experts): - assert getattr(self.grouped_mlp.experts.linear_fc1, f"weight{i}").shape == ( - self.fc1_ffn_hidden_size, - self.hidden_size, - ) - assert getattr(self.grouped_mlp.experts.linear_fc2, f"weight{i}").shape == ( - self.hidden_size, - self.fc2_ffn_hidden_size, - ) - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.internal - def test_gpu_forward_backward(self): - self.sequential_mlp.cuda() - self.grouped_mlp.cuda() - # Copy the weights to ensure the same init value - with torch.no_grad(): - for i in range(self.num_experts): - self.sequential_mlp.experts.local_experts[i].linear_fc1.weight.copy_( - getattr(self.grouped_mlp.experts.linear_fc1, f"weight{i}") - ) - self.sequential_mlp.experts.local_experts[i].linear_fc2.weight.copy_( - getattr(self.grouped_mlp.experts.linear_fc2, f"weight{i}") - ) - # [sequence length, batch size, hidden size] - seq_len = 32 - batch_size = 2 - hidden_states = torch.rand( - (seq_len, batch_size, self.hidden_size), - dtype=torch.bfloat16, - device="cuda", - requires_grad=True, - ) - hidden_states.retain_grad() - - output_smm, _ = self.sequential_mlp(hidden_states) - output_smm.mean().backward() - smm_results = [output_smm, hidden_states.grad] - for i in range(self.num_experts): - smm_results.append(self.sequential_mlp.experts.local_experts[i].linear_fc1.weight.grad) - smm_results.append(self.sequential_mlp.experts.local_experts[i].linear_fc2.weight.grad) - - hidden_states.grad = None - output_gmm, _ = self.grouped_mlp(hidden_states) - output_gmm.mean().backward() - gmm_results = [output_gmm, hidden_states.grad] - for i in range(self.num_experts): - gmm_results.append(getattr(self.grouped_mlp.experts.linear_fc1, f"weight{i}").grad) - gmm_results.append(getattr(self.grouped_mlp.experts.linear_fc2, f"weight{i}").grad) - - for smm_result, gmm_result in zip(smm_results, gmm_results): - torch.testing.assert_close(smm_result, gmm_result) - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.internal - def test_gpu_forward_backward_with_no_tokens_allocated(self): - """Test the case when no token is allocated for groupedGEMM kernels.""" - self.grouped_mlp.cuda() - num_allocated_tokens = 0 - tokens_per_expert = torch.zeros(self.num_experts, dtype=torch.int32) - hidden_states = torch.rand((num_allocated_tokens, self.hidden_size), dtype=torch.bfloat16) - hidden_states = hidden_states.cuda() - probs = torch.rand((num_allocated_tokens,), dtype=torch.float32) - probs = probs.cuda() - output, _ = self.grouped_mlp.experts( - hidden_states, tokens_per_expert=tokens_per_expert, permuted_probs=probs - ) - assert torch.equal(output, torch.zeros_like(output)) - assert output.shape == (num_allocated_tokens, self.hidden_size) - - output.mean().backward() - for i in range(self.num_experts): - assert getattr(self.grouped_mlp.experts.linear_fc1, f"weight{i}").grad is not None - assert getattr(self.grouped_mlp.experts.linear_fc2, f"weight{i}").grad is not None - - -if __name__ == "__main__": - for use_cpu_unitilization in [True, False]: - for swiglu in [True, False]: - GMLP_test = TestParallelGroupedMLP() - GMLP_test.setup_method( - method=None, use_cpu_initialization=use_cpu_unitilization, swiglu=swiglu - ) - GMLP_test.test_constructor() - GMLP_test.test_weight_init_value_the_same() - GMLP_test.test_gpu_forward() - GMLP_test.test_gpu_forward_with_no_tokens_allocated() - GMLP_test.teardown_method(method=None) diff --git a/tests/unit_tests/transformer/moe/test_moe_layer.py b/tests/unit_tests/transformer/moe/test_moe_layer.py deleted file mode 100644 index 59385f757b..0000000000 --- a/tests/unit_tests/transformer/moe/test_moe_layer.py +++ /dev/null @@ -1,194 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import pytest -import torch - -from megatron.core.models.gpt.gpt_layer_specs import ( - get_gpt_decoder_block_spec, - get_gpt_layer_local_spec, - get_gpt_layer_with_transformer_engine_spec, -) -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.moe.moe_layer import MoELayer -from megatron.core.transformer.moe.router import Router -from megatron.core.transformer.transformer_block import TransformerBlock -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import is_te_min_version -from megatron.training.initialize import _set_random_seed -from tests.unit_tests.test_utilities import Utils - - -class TestMoELayerInit: - def setup_method(self, method): - pass - - @pytest.mark.skipif( - not is_te_min_version("1.7.0.dev0"), - reason="Expert with TE Linear is only supported in TE 1.7.0 and later.", - ) - @pytest.mark.parametrize("moe_token_dispatcher_type", ["allgather", "alltoall"]) - @pytest.mark.parametrize("num_moe_experts", [1, 2]) - @pytest.mark.parametrize("grouped_gemm", [True, False]) - def test_te_moe_layer(self, num_moe_experts, moe_token_dispatcher_type, grouped_gemm): - Utils.initialize_model_parallel(1, 1) - _set_random_seed(seed_=123, data_parallel_random_init=False) - self.transformer_config = TransformerConfig( - num_layers=1, - hidden_size=12, - num_attention_heads=4, - num_moe_experts=num_moe_experts, - use_cpu_initialization=True, - moe_token_dispatcher_type=moe_token_dispatcher_type, - moe_router_topk=2, - moe_aux_loss_coeff=0.01, - moe_grouped_gemm=grouped_gemm, - moe_ffn_hidden_size=128, - add_bias_linear=False, - ) - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - num_experts=num_moe_experts, moe_grouped_gemm=grouped_gemm - ) - moe_layer = MoELayer( - self.transformer_config, transformer_layer_spec.submodules.mlp.submodules - ) - Utils.destroy_model_parallel() - - @pytest.mark.parametrize("moe_token_dispatcher_type", ["allgather", "alltoall"]) - @pytest.mark.parametrize("num_moe_experts", [1, 2]) - @pytest.mark.parametrize("grouped_gemm", [True, False]) - def test_legacy_moe_layer(self, num_moe_experts, moe_token_dispatcher_type, grouped_gemm): - Utils.initialize_model_parallel(1, 1) - _set_random_seed(seed_=123, data_parallel_random_init=False) - num_moe_experts = 4 - self.transformer_config = TransformerConfig( - num_layers=1, - hidden_size=12, - num_attention_heads=4, - num_moe_experts=num_moe_experts, - use_cpu_initialization=True, - moe_token_dispatcher_type=moe_token_dispatcher_type, - moe_router_load_balancing_type="aux_loss", - moe_router_topk=2, - moe_aux_loss_coeff=0.01, - moe_grouped_gemm=grouped_gemm, - add_bias_linear=False, - ) - transformer_layer_spec = get_gpt_layer_local_spec( - num_experts=num_moe_experts, moe_grouped_gemm=grouped_gemm - ) - moe_layer = MoELayer( - self.transformer_config, transformer_layer_spec.submodules.mlp.submodules - ) - Utils.destroy_model_parallel() - - @pytest.mark.skip( - "Late init of parallel_state was broken after parallel states refactor MR2988." - ) - @pytest.mark.parametrize("moe_token_dispatcher_type", ["alltoall", "allgather"]) - @pytest.mark.parametrize("grouped_gemm", [True, False]) - @pytest.mark.parametrize("tp_size,ep_size", [(1, 1), (2, 2)]) - def test_moe_with_late_initialize( - self, moe_token_dispatcher_type, grouped_gemm, tp_size, ep_size - ): - num_moe_experts = 4 - hidden_size = 12 - transformer_config = TransformerConfig( - num_layers=1, - hidden_size=hidden_size, - num_attention_heads=4, - num_moe_experts=num_moe_experts, - use_cpu_initialization=True, - moe_router_load_balancing_type="aux_loss", - moe_router_topk=2, - moe_aux_loss_coeff=0.01, - add_bias_linear=False, - moe_grouped_gemm=grouped_gemm, - moe_token_dispatcher_type=moe_token_dispatcher_type, - tensor_model_parallel_size=tp_size, - expert_model_parallel_size=ep_size, - sequence_parallel=tp_size > 1, - bf16=True, - params_dtype=torch.bfloat16, - ) - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - num_experts=num_moe_experts, moe_grouped_gemm=grouped_gemm - ) - - # Fake initialization as NeMo does - Utils.fake_initialize_model_parallel( - tensor_model_parallel_size=tp_size, expert_model_parallel_size=ep_size - ) - moe_layer = MoELayer( - transformer_config, transformer_layer_spec.submodules.mlp.submodules - ).cuda() - - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp_size, expert_model_parallel_size=ep_size - ) - _set_random_seed(seed_=123, data_parallel_random_init=False) - - input_data = torch.randn( - 16, 4, hidden_size, device=torch.cuda.current_device(), dtype=torch.bfloat16 - ) - output = moe_layer(input_data) - - Utils.destroy_model_parallel() - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - -class TestInterleaveTransformerBlock: - - @pytest.mark.parametrize("moe_layer_freq", [2, eval("[0,1,1,1]"), eval("[0]*2+[1]*2")]) - def test_interleave_transformer_block(self, moe_layer_freq): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - self.transformer_config = TransformerConfig( - num_layers=4, - hidden_size=64, - num_attention_heads=4, - moe_layer_freq=moe_layer_freq, - moe_ffn_hidden_size=256, - use_cpu_initialization=True, - num_moe_experts=2, - add_bias_linear=False, - ) - self.parallel_transformer_block = TransformerBlock( - self.transformer_config, get_gpt_decoder_block_spec(self.transformer_config, False) - ) - - # Check if the moe layer is interleaved correctly - if isinstance(self.transformer_config.moe_layer_freq, int): - moe_layer_pattern = [ - 1 if (i % self.transformer_config.moe_layer_freq == 0) else 0 - for i in range(self.transformer_config.num_layers) - ] - else: - moe_layer_pattern = self.transformer_config.moe_layer_freq - - for i, layer in enumerate(self.parallel_transformer_block.layers): - is_moe_layer = isinstance(layer.mlp, MoELayer) - assert is_moe_layer == moe_layer_pattern[i] - - # Test forward pass - parallel_transformer_block = self.parallel_transformer_block - config: TransformerConfig = parallel_transformer_block.config - sequence_length = 32 - micro_batch_size = 2 - parallel_transformer_block.cuda() - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones((sequence_length, micro_batch_size, config.hidden_size)) - hidden_states = hidden_states.cuda() - - attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() - hidden_states = parallel_transformer_block( - hidden_states=hidden_states, attention_mask=attention_mask - ) - assert hidden_states.shape[0] == sequence_length - assert hidden_states.shape[1] == micro_batch_size - assert hidden_states.shape[2] == config.hidden_size - - def teardown_method(self, method): - Utils.destroy_model_parallel() diff --git a/tests/unit_tests/transformer/moe/test_moe_layer_discrepancy.py b/tests/unit_tests/transformer/moe/test_moe_layer_discrepancy.py deleted file mode 100644 index 4386a844dd..0000000000 --- a/tests/unit_tests/transformer/moe/test_moe_layer_discrepancy.py +++ /dev/null @@ -1,252 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import time - -import pytest -import torch - -from megatron.core import parallel_state -from megatron.core.models.gpt.gpt_layer_specs import ( - get_gpt_layer_local_spec, - get_gpt_layer_with_transformer_engine_spec, -) -from megatron.core.transformer.moe.moe_layer import MoELayer -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.transformer.transformer_layer import TransformerLayer -from megatron.training.initialize import _set_random_seed -from tests.unit_tests.test_utilities import Utils - - -class TestMoELayerDispatcherDiscrepancy: - def setup_method(self, method): - pass - - @pytest.mark.parametrize("num_moe_experts", [8]) - @pytest.mark.parametrize("grouped_gemm", [False]) - @pytest.mark.parametrize( - "tp_size,ep_size", [(1, 1), (1, 2), (1, 8), (2, 1), (8, 1), (2, 2), (2, 4), (4, 2)] - ) - @pytest.mark.internal - def test_moe_layer_dispatcher_discrepancy( - self, num_moe_experts, grouped_gemm, tp_size, ep_size - ): - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp_size, expert_model_parallel_size=ep_size - ) - self.transformer_config = TransformerConfig( - num_layers=1, - hidden_size=4096, - num_attention_heads=32, - num_moe_experts=num_moe_experts, - use_cpu_initialization=False, - moe_token_dispatcher_type="allgather", - moe_router_topk=2, - moe_aux_loss_coeff=0.01, - moe_grouped_gemm=grouped_gemm, - moe_router_dtype="fp64", - add_bias_linear=False, - tensor_model_parallel_size=tp_size, - expert_model_parallel_size=ep_size, - sequence_parallel=True if (tp_size > 1) else False, - ) - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - num_experts=num_moe_experts, moe_grouped_gemm=grouped_gemm - ) - # Init input and layer - _set_random_seed(seed_=123, data_parallel_random_init=False) - input = torch.randn(1, 4096, 4096).cuda().float() - - # Init allgather moe layer - _set_random_seed(seed_=123, data_parallel_random_init=False) - layer = ( - TransformerLayer(self.transformer_config, transformer_layer_spec.submodules) - .cuda() - .float() - ) - ag_moe_layer = layer.mlp - ag_moe_layer.eval() - # Init a2a moe layer - self.transformer_config.moe_token_dispatcher_type = "alltoall" - _set_random_seed(seed_=123, data_parallel_random_init=False) - layer = ( - TransformerLayer(self.transformer_config, transformer_layer_spec.submodules) - .cuda() - .float() - ) - a2a_moe_layer = layer.mlp - a2a_moe_layer.eval() - - # Check if parameters are the same - for ag_param, a2a_param in zip(ag_moe_layer.parameters(), a2a_moe_layer.parameters()): - assert torch.equal(ag_param, a2a_param) - torch.distributed.barrier() - - # Allgather the input to check if the input is the same in all the ranks - # Check if input is the same across all ranks - input_ag_shape = (torch.distributed.get_world_size(), *(input.shape)) - input_ag = torch.zeros(input_ag_shape, device=input.device, dtype=input.dtype) - torch.distributed.all_gather_into_tensor( - input_ag, input, group=torch.distributed.group.WORLD - ) - if torch.distributed.get_rank() == 0: - for i in range(1, torch.distributed.get_world_size()): - assert torch.equal(input_ag[0], input_ag[i]), f"Input differs at rank {i}" - # print(f"Input is the same across all ranks") - - # Test allgather dispatcher - with torch.no_grad(): - ag_output = ag_moe_layer(input)[0] - a2a_output = a2a_moe_layer(input)[0] - - assert torch.allclose( - ag_output, a2a_output, atol=1e-6 - ), f"Ag output: {ag_output.min()}, {ag_output.max()}, {ag_output.sum()}, a2a output: {a2a_output.min()}, {a2a_output.max()}, {a2a_output.sum()}, diff: {torch.abs(ag_output - a2a_output).max()}" - # print(f"Allgather and A2A output is the same", flush=True) - - Utils.destroy_model_parallel() - - @pytest.mark.parametrize("num_moe_experts", [8]) - @pytest.mark.parametrize("grouped_gemm", [False, True]) - @pytest.mark.parametrize( - "tp_size,ep_size", [(1, 1), (1, 2), (1, 8), (2, 1), (8, 1), (2, 2), (2, 4), (4, 2)] - ) - @pytest.mark.internal - def test_moe_layer_ag_dispatcher_discrepancy( - self, num_moe_experts, grouped_gemm, tp_size, ep_size - ): - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp_size, expert_model_parallel_size=ep_size - ) - self.transformer_config = TransformerConfig( - num_layers=1, - hidden_size=4096, - num_attention_heads=32, - num_moe_experts=num_moe_experts, - use_cpu_initialization=False, - moe_token_dispatcher_type="allgather", - moe_router_topk=2, - moe_aux_loss_coeff=0.01, - moe_grouped_gemm=grouped_gemm, - moe_router_dtype="fp64", - add_bias_linear=False, - tensor_model_parallel_size=tp_size, - expert_model_parallel_size=ep_size, - sequence_parallel=True if (tp_size > 1 and ep_size > 1) else False, - bf16=True, - ) - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - num_experts=num_moe_experts, moe_grouped_gemm=grouped_gemm - ) - # Init input and layer - _set_random_seed(seed_=123, data_parallel_random_init=False) - input = torch.randn(1, 4096, 4096).cuda().bfloat16() - - # Init allgather moe layer - _set_random_seed(seed_=123, data_parallel_random_init=False) - layer = ( - TransformerLayer(self.transformer_config, transformer_layer_spec.submodules) - .cuda() - .bfloat16() - ) - ag_moe_layer = layer.mlp - ag_moe_layer.eval() - - # Test allgather dispatcher - ag_output = ag_moe_layer(input)[0] - # Allgather the output to check if it's the same in all ranks - ag_output_ag_shape = (torch.distributed.get_world_size(), *(ag_output.shape)) - ag_output_ag = torch.zeros( - ag_output_ag_shape, device=ag_output.device, dtype=ag_output.dtype - ) - torch.distributed.all_gather_into_tensor( - ag_output_ag, ag_output, group=torch.distributed.group.WORLD - ) - # Check if output is the same across all ranks - if parallel_state.get_data_parallel_rank() == 0: - for i in range(1, parallel_state.get_tensor_model_parallel_world_size()): - if not torch.allclose(ag_output_ag[0], ag_output_ag[i]): - print(f"Allgather output differs at rank {torch.distributed.get_rank()}") - print( - f"ag_output_ag[0]: min {ag_output_ag[0].double().min()}, max {ag_output_ag[0].double().max()}, std {ag_output_ag[0].double().std()}" - ) - print( - f"ag_output_ag[{i}]: min {ag_output_ag[i].double().min()}, max {ag_output_ag[i].double().max()}, std {ag_output_ag[i].double().std()}" - ) - raise ValueError("Allgather output differs at rank {i}") - torch.cuda.synchronize() - Utils.destroy_model_parallel() - - @pytest.mark.parametrize("num_moe_experts", [8]) - @pytest.mark.parametrize("grouped_gemm", [False]) - @pytest.mark.parametrize( - "tp_size,ep_size", [(1, 1), (1, 2), (1, 8), (2, 1), (4, 1), (8, 1), (2, 4), (4, 2)] - ) - @pytest.mark.internal - def test_moe_layer_a2a_dispatcher_discrepancy( - self, num_moe_experts, grouped_gemm, tp_size, ep_size - ): - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp_size, expert_model_parallel_size=ep_size - ) - self.transformer_config = TransformerConfig( - num_layers=1, - hidden_size=4096, - num_attention_heads=32, - num_moe_experts=num_moe_experts, - use_cpu_initialization=False, - moe_token_dispatcher_type="alltoall", - moe_router_topk=2, - moe_aux_loss_coeff=0.01, - moe_grouped_gemm=grouped_gemm, - moe_router_dtype="fp64", - add_bias_linear=False, - tensor_model_parallel_size=tp_size, - expert_model_parallel_size=ep_size, - sequence_parallel=True if (tp_size > 1 and ep_size > 1) else False, - bf16=True, - ) - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - num_experts=num_moe_experts, moe_grouped_gemm=grouped_gemm - ) - # Init input and layer - _set_random_seed(seed_=123, data_parallel_random_init=False) - input = torch.randn(1, 4096, 4096).cuda().bfloat16() - - # Init a2a moe layer - layer = ( - TransformerLayer(self.transformer_config, transformer_layer_spec.submodules) - .cuda() - .bfloat16() - ) - a2a_moe_layer = layer.mlp - a2a_moe_layer.eval() - - # Test alltoall dispatcher - a2a_output = a2a_moe_layer(input)[0] - # Allgather the output to check if it's the same in all ranks - at_output_ag_shape = (torch.distributed.get_world_size(), *(a2a_output.shape)) - at_output_ag = torch.zeros( - at_output_ag_shape, device=a2a_output.device, dtype=a2a_output.dtype - ) - torch.distributed.all_gather_into_tensor( - at_output_ag, a2a_output, group=torch.distributed.group.WORLD - ) - # Check if output is the same across all ranks - if parallel_state.get_data_parallel_rank() == 0: - for i in range(1, parallel_state.get_tensor_model_parallel_world_size()): - if not torch.equal(at_output_ag[0], at_output_ag[i]): - print( - f"at_output_ag[0]: min {at_output_ag[0].double().min()}, max {at_output_ag[0].double().max()}, sum {at_output_ag[0].double().sum()}" - ) - print( - f"at_output_ag[{i}]: min {at_output_ag[i].double().min()}, max {at_output_ag[i].double().max()}, sum {at_output_ag[i].double().sum()}" - ) - print(f"diff {torch.abs(at_output_ag[0] - at_output_ag[i]).max()}") - print(f"A2A output differs at rank {torch.distributed.get_rank()}") - raise ValueError(f"A2A output differs at rank {i}") - torch.cuda.synchronize() - - Utils.destroy_model_parallel() - - def teardown_method(self, method): - Utils.destroy_model_parallel() diff --git a/tests/unit_tests/transformer/moe/test_multihot_indices_converter.py b/tests/unit_tests/transformer/moe/test_multihot_indices_converter.py deleted file mode 100644 index a5775f1e12..0000000000 --- a/tests/unit_tests/transformer/moe/test_multihot_indices_converter.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -import copy -import random - -import pytest -import torch - -from megatron.core import config -from megatron.core.fusions.fused_indices_converter import fused_indices_to_multihot - - -class PytorchIndicesToMultihot: - def __init__(self, num_instances): - self.num_instances = num_instances - - def _indices_to_multihot(self, indices, probs): - batch_size = indices.shape[0] - multihot_routing_map = torch.zeros( - (batch_size, self.num_instances), dtype=torch.long, device=indices.device - ) - multihot_probs = torch.zeros( - (batch_size, self.num_instances), dtype=torch.float, device=indices.device - ) - mask = indices != -1 - valid_indices = indices[mask] - row_indices = torch.arange(batch_size, device=indices.device).repeat_interleave( - mask.sum(dim=1) - ) - multihot_routing_map[row_indices, valid_indices] = 1 - multihot_probs[row_indices, valid_indices] = probs[mask] - return multihot_routing_map.bool(), multihot_probs - - -class TestIndicesToMultihot: - - def setup_method(self, method): - # enable experimental feature - if config.ENABLE_EXPERIMENTAL is False: - config.ENABLE_EXPERIMENTAL = True - - def teardown_method(self, method): - # disable experimental feature - if config.ENABLE_EXPERIMENTAL is True: - config.ENABLE_EXPERIMENTAL = False - - @pytest.mark.experimental - @pytest.mark.parametrize("num_of_token", [3, 5, 8, 128, 512]) - @pytest.mark.parametrize("topk", [2, 4, 6, 7, 8]) - @pytest.mark.parametrize("num_of_local_experts", [4, 7, 8, 12, 20, 30, 31, 32]) - def test_indices_to_multihot(self, num_of_token, topk, num_of_local_experts): - # construct the indices and probs_indices - indices = torch.full((num_of_token, topk), -1, dtype=torch.int32, device='cuda') - probs_indices = torch.full((num_of_token, topk), 0, dtype=torch.float32, device='cuda') - # Fill the indices with random values - # There are 2 non-ordinary values in each row - for i in range(num_of_token): - positions = random.sample(range(indices.shape[1]), 2) - values = random.sample(range(num_of_local_experts), 2) - indices[i, positions[0]] = values[0] - indices[i, positions[1]] = values[1] - mask = indices != -1 - probs_indices[mask] = torch.rand(mask.sum(), device=indices.device) - probs_indices.requires_grad = True - - indices_pytorch = copy.deepcopy(indices) - probs_indices_pytorch = copy.deepcopy(probs_indices) - - # test forward - multihot_indices, probs_in_multihot = fused_indices_to_multihot( - indices, probs_indices, num_of_local_experts - ) - pytorch_class = PytorchIndicesToMultihot(num_of_local_experts) - multihot_indices_pytorch, probs_in_multihot_pytorch = pytorch_class._indices_to_multihot( - indices_pytorch, probs_indices_pytorch - ) - assert torch.allclose(multihot_indices, multihot_indices_pytorch) - assert torch.allclose(probs_in_multihot, probs_in_multihot_pytorch) - - # test backward - loss = (probs_in_multihot @ torch.transpose(probs_in_multihot, 0, 1)).sum() / 2 - loss.backward() - loss_pytorch = ( - probs_in_multihot_pytorch @ torch.transpose(probs_in_multihot_pytorch, 0, 1) - ).sum() / 2 - loss_pytorch.backward() - assert torch.allclose(probs_indices.grad, probs_indices_pytorch.grad) diff --git a/tests/unit_tests/transformer/moe/test_routers.py b/tests/unit_tests/transformer/moe/test_routers.py deleted file mode 100644 index 172eb0ae29..0000000000 --- a/tests/unit_tests/transformer/moe/test_routers.py +++ /dev/null @@ -1,486 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import pytest -import torch - -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec -from megatron.core.transformer.moe.moe_layer import MoELayer -from megatron.core.transformer.moe.moe_utils import get_updated_expert_bias, router_gating_linear -from megatron.core.transformer.moe.router import Router -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.training.initialize import _set_random_seed -from tests.unit_tests.test_utilities import Utils - -try: - # Check availability of TE fused router ops - from megatron.core.extensions.transformer_engine import ( - fused_topk_with_score_function as _fused_topk_with_score_function, - ) - - HAVE_ROUTER_FUSION = _fused_topk_with_score_function is not None -except Exception: # pragma: no cover - defensive - HAVE_ROUTER_FUSION = False - - -class TestTop2Router: - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - _set_random_seed(seed_=123, data_parallel_random_init=False) - print("done intializing") - num_moe_experts = 4 - self.transformer_config = TransformerConfig( - num_layers=2, - hidden_size=12, - num_attention_heads=4, - num_moe_experts=num_moe_experts, - use_cpu_initialization=True, - moe_router_load_balancing_type="aux_loss", - moe_router_topk=2, - moe_aux_loss_coeff=0, - bf16=True, - params_dtype=torch.bfloat16, - add_bias_linear=False, - ) - transformer_layer_spec = get_gpt_layer_local_spec( - num_experts=num_moe_experts, moe_grouped_gemm=False - ) - self.sequential_mlp = MoELayer( - self.transformer_config, transformer_layer_spec.submodules.mlp.submodules - ) - self.router = self.sequential_mlp.router - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.internal - def test_constructor(self): - assert isinstance(self.router, Router) - - num_weights = sum([p.numel() for p in self.router.parameters()]) - assert num_weights == 12 * 4, num_weights - - @pytest.mark.internal - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.parametrize("moe_router_pre_softmax", [(True), (False)]) - @pytest.mark.parametrize("score_function", ["sigmoid", "softmax"]) - def test_router_forward(self, moe_router_pre_softmax, score_function): - with torch.no_grad(): - self.router = self.router.cuda() - self.router.config.moe_router_pre_softmax = moe_router_pre_softmax - self.router.config.moe_router_score_function = score_function - # [num tokens, hidden size] - hidden_states = torch.randn((32, 2, self.router.config.hidden_size)) - hidden_states = hidden_states.cuda().bfloat16() - scores, indices = self.router(hidden_states) - - @pytest.mark.internal - @pytest.mark.skipif( - not torch.cuda.is_available() or not HAVE_ROUTER_FUSION, - reason="TE fused router ops not available", - ) - @pytest.mark.parametrize("score_function", ["sigmoid", "softmax"]) - def test_router_forward_fusion_equivalence(self, score_function): - with torch.no_grad(): - self.router = self.router.cuda() - self.router.config.moe_router_score_function = score_function - hidden_states = torch.randn((32, 2, self.router.config.hidden_size)) - hidden_states = hidden_states.cuda().bfloat16() - - # Unfused - self.router.config.moe_router_fusion = False - scores_ref, routing_ref = self.router(hidden_states) - - # Fused - self.router.config.moe_router_fusion = True - scores_fused, routing_fused = self.router(hidden_states) - - assert torch.equal(routing_ref, routing_fused), "Routing map mismatch" - torch.testing.assert_close(scores_ref, scores_fused) - # restore the config - self.router.config.moe_router_fusion = False - - @pytest.mark.internal - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_aux_loss(self): - self.sequential_mlp = self.sequential_mlp.cuda() - - # Without aux loss - hidden_states = torch.randn((32, 2, self.router.config.hidden_size)) - hidden_states = hidden_states.cuda().bfloat16() - out = self.sequential_mlp(hidden_states)[0] - out.sum().mul_(0).backward() - assert self.sequential_mlp.router.weight.grad.abs().sum() == 0 - - # With aux loss - self.transformer_config.moe_aux_loss_coeff = 1 - out = self.sequential_mlp(hidden_states)[0] - out.sum().mul_(0).backward() - assert self.sequential_mlp.router.weight.grad.abs().sum() > 0 - - # With Z loss - self.transformer_config.moe_aux_loss_coeff = 0 - self.transformer_config.moe_z_loss_coeff = 1 - self.sequential_mlp.router.weight.grad.fill_(0) - out = self.sequential_mlp(hidden_states)[0] - out.sum().mul_(0).backward() - assert self.sequential_mlp.router.weight.grad.abs().sum() > 0 - - @pytest.mark.internal - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_router_dtype(self): - self.router = self.router.cuda() - self.sequential_mlp = self.sequential_mlp.cuda() - hidden_states = torch.randn((32, 2, self.router.config.hidden_size), dtype=torch.bfloat16) - hidden_states = hidden_states.cuda() - - # Test with default setting (bf16) - self.router.config.moe_router_dtype = None - with torch.no_grad(): - scores, routing_map = self.router(hidden_states) - out = self.sequential_mlp(hidden_states) - assert scores.dtype == torch.bfloat16, "Router output should be bf16 by default" - assert out[0].dtype == torch.bfloat16 - - # Test with fp32 enabled - self.router.config.moe_router_dtype = 'fp32' - with torch.no_grad(): - scores, routing_map = self.router(hidden_states) - out = self.sequential_mlp(hidden_states) - assert scores.dtype == torch.float32, "Router output should be fp32 when enabled" - assert out[0].dtype == torch.bfloat16 - self.sequential_mlp.config.moe_token_dispatcher_type = "alltoall" - out = self.sequential_mlp(hidden_states) - assert out[0].dtype == torch.bfloat16 - self.sequential_mlp.config.moe_token_dispatcher_type = "allgather" - - # Test with fp64 enabled - self.router.config.moe_router_dtype = 'fp64' - with torch.no_grad(): - scores, routing_map = self.router(hidden_states) - out = self.sequential_mlp(hidden_states) - assert scores.dtype == torch.float64, "Router output should be fp64 when enabled" - assert out[0].dtype == torch.bfloat16 - - @pytest.mark.internal - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_force_load_balancing(self): - hidden_states = torch.randn( - (32, 2, self.router.config.hidden_size), device="cuda", dtype=torch.bfloat16 - ) - hidden_states.requires_grad = True - - # First forward pass with normal routing - normal_scores, normal_routing_map = self.router(hidden_states) - - # Second forward pass with force load balancing - self.router.config.moe_router_force_load_balancing = True - force_scores, force_routing_map = self.router(hidden_states) - - assert normal_scores.shape == force_scores.shape - assert normal_routing_map.shape == force_routing_map.shape - assert torch.equal(normal_scores, force_scores) == False - - # Backward pass for force load balancing - self.router.zero_grad() - force_scores.sum().backward() - assert hidden_states.grad is not None - assert self.router.weight.grad.norm() > 0 - - self.router.config.moe_router_force_load_balancing = False - - @pytest.mark.internal - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.parametrize("capacity_factor", [None, 1.0, 2.0]) - @pytest.mark.parametrize("drop_policy", ["probs", "position"]) - @pytest.mark.parametrize("pad_to_capacity", [True, False]) - def test_token_dropping(self, capacity_factor, drop_policy, pad_to_capacity): - if capacity_factor is None and pad_to_capacity: - pytest.skip("Capacity factor is None, so no token dropping should be applied") - - num_tokens = 32 - self.router = self.router.cuda() - self.router.config.moe_expert_capacity_factor = capacity_factor - self.router.config.moe_token_drop_policy = drop_policy - self.router.config.moe_pad_expert_input_to_capacity = pad_to_capacity - - hidden_states = torch.randn( - (num_tokens, self.router.config.hidden_size), dtype=torch.bfloat16, device="cuda" - ) - hidden_states.requires_grad = True - probs, routing_map = self.router(hidden_states) - - if capacity_factor is not None: - if pad_to_capacity: - assert ( - routing_map.sum().item() - == num_tokens * self.router.config.moe_router_topk * capacity_factor - ) - else: - assert ( - routing_map.sum().item() - <= num_tokens * self.router.config.moe_router_topk * capacity_factor - ) - else: - assert routing_map.sum().item() == num_tokens * self.router.config.moe_router_topk - - # restore the config - self.router.config.moe_expert_capacity_factor = None - self.router.config.moe_token_drop_policy = "probs" - self.router.config.moe_pad_expert_input_to_capacity = False - - -class TestGroupLimitedRouter: - def setup_method(self, method): - Utils.initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, - expert_model_parallel_size=8, - context_parallel_size=1, - ) - _set_random_seed(seed_=123, data_parallel_random_init=False) - print("done intializing") - - num_moe_experts = 16 - self.transformer_config = TransformerConfig( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, - expert_model_parallel_size=8, - context_parallel_size=1, - num_moe_experts=num_moe_experts, - moe_router_topk=4, - moe_router_group_topk=2, - moe_router_num_groups=8, - moe_router_pre_softmax=True, - moe_router_load_balancing_type="aux_loss", - moe_aux_loss_coeff=0, - moe_router_dtype='fp32', - moe_token_dispatcher_type="alltoall", - num_layers=2, - hidden_size=12, - num_attention_heads=4, - use_cpu_initialization=True, - bf16=True, - params_dtype=torch.bfloat16, - add_bias_linear=False, - ) - - # init MoE layer - transformer_layer_spec = get_gpt_layer_local_spec( - num_experts=num_moe_experts, moe_grouped_gemm=False - ) - self.moe_layer = MoELayer( - self.transformer_config, transformer_layer_spec.submodules.mlp.submodules - ).cuda() - self.router = self.moe_layer.router - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.internal - def test_constructor(self): - assert isinstance(self.router, Router) - - num_weights = sum([p.numel() for p in self.router.parameters()]) - assert ( - num_weights - == self.transformer_config.hidden_size * self.transformer_config.num_moe_experts - ), num_weights - - @pytest.mark.internal - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.parametrize("moe_router_group_topk,moe_router_num_groups", [(3, 8), (2, 4)]) - @pytest.mark.parametrize("moe_router_pre_softmax", [(True), (False)]) - @pytest.mark.parametrize("score_function", ["sigmoid", "softmax"]) - def test_router_forward( - self, moe_router_group_topk, moe_router_num_groups, moe_router_pre_softmax, score_function - ): - with torch.no_grad(): - self.router.config.moe_router_group_topk = moe_router_group_topk - self.router.config.moe_router_num_groups = moe_router_num_groups - self.router.config.moe_router_pre_softmax = moe_router_pre_softmax - self.router.config.moe_router_score_function = score_function - if moe_router_pre_softmax: - self.router.config.moe_router_topk_scaling_factor = 16.0 - - seq_len = 128 - batch_size = 4 - num_tokens = seq_len * batch_size - # hidden_states shape: [seq_len, batch_size, hidden_size] - hidden_states = ( - torch.randn((seq_len, batch_size, self.router.config.hidden_size)).cuda().bfloat16() - ) - scores, routing_map = self.router(hidden_states) - assert scores.shape == (num_tokens, self.router.config.num_moe_experts), scores.shape - assert routing_map.shape == ( - num_tokens, - self.router.config.num_moe_experts, - ), routing_map.shape - - group_routing_map = ( - routing_map.reshape(num_tokens, moe_router_num_groups, -1).max(dim=-1).values - ) - assert torch.all(group_routing_map.sum(dim=-1) <= moe_router_group_topk) - - @pytest.mark.internal - @pytest.mark.skipif( - not torch.cuda.is_available() or not HAVE_ROUTER_FUSION, - reason="TE fused router ops not available", - ) - @pytest.mark.parametrize("score_function", ["sigmoid", "softmax"]) - def test_router_forward_fusion_equivalence(self, score_function): - with torch.no_grad(): - self.router = self.router.cuda() - self.router.score_function = score_function - seq_len = 32 - batch_size = 4 - hidden_states = torch.randn((seq_len, batch_size, self.router.config.hidden_size)) - hidden_states = hidden_states.cuda().bfloat16() - - # Unfused - self.router.config.moe_router_fusion = False - scores_ref, routing_ref = self.router(hidden_states) - - # Fused - self.router.config.moe_router_fusion = True - scores_fused, routing_fused = self.router(hidden_states) - - assert torch.equal(routing_ref, routing_fused), "Routing map mismatch" - torch.testing.assert_close(scores_ref, scores_fused) - # restore the config - self.router.config.moe_router_fusion = False - - -class TestAuxLossFreeTop2Router: - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1, expert_model_parallel_size=8) - _set_random_seed(seed_=123, data_parallel_random_init=False) - print("done intializing") - num_moe_experts = 8 - self.transformer_config = TransformerConfig( - num_layers=2, - hidden_size=12, - num_attention_heads=4, - num_moe_experts=num_moe_experts, - use_cpu_initialization=True, - expert_model_parallel_size=8, - moe_router_load_balancing_type="none", # No aux loss - moe_router_score_function="sigmoid", # Using sigmoid scoring - moe_router_enable_expert_bias=True, # Enable expert bias - moe_router_bias_update_rate=0.1, # Set bias update rate - moe_router_topk=2, - bf16=True, - params_dtype=torch.bfloat16, - add_bias_linear=False, - ) - transformer_layer_spec = get_gpt_layer_local_spec( - num_experts=num_moe_experts, moe_grouped_gemm=False - ) - self.moe_layer = MoELayer( - self.transformer_config, transformer_layer_spec.submodules.mlp.submodules - ) - self.router = self.moe_layer.router - assert self.router.expert_bias is not None - assert self.router.local_tokens_per_expert is not None - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_router_forward_aux_free(self): - hidden_states = torch.randn((32, 2, self.router.config.hidden_size)) - hidden_states = hidden_states.cuda().bfloat16() - self.router = self.router.cuda() - - # First forward pass - initial_bias = self.router.expert_bias.clone() - scores1, indices1 = self.router(hidden_states) - initial_tokens = self.router.local_tokens_per_expert.clone() - updated_bias = get_updated_expert_bias( - self.router.local_tokens_per_expert, - self.router.expert_bias, - self.router.config.moe_router_bias_update_rate, - ) - - # Verify expert bias was updated - assert not torch.equal(initial_bias, updated_bias), "Expert bias should be updated" - - # Basic output checks - assert scores1.shape == (64, 8), "Router scores shape mismatch" - assert indices1.shape == (64, 8), "Router indices shape mismatch" - - # Print some debug info - print("Updated bias after first forward pass:", updated_bias) - - @pytest.mark.internal - @pytest.mark.skipif( - not torch.cuda.is_available() or not HAVE_ROUTER_FUSION, - reason="TE fused router ops not available", - ) - @pytest.mark.parametrize("score_function", ["sigmoid", "softmax"]) - def test_router_forward_fusion_equivalence(self, score_function): - with torch.no_grad(): - # Build two fresh routers to avoid bias update interference - transformer_layer_spec = get_gpt_layer_local_spec( - num_experts=self.transformer_config.num_moe_experts, moe_grouped_gemm=False - ) - moe_layer_ref = MoELayer( - self.transformer_config, transformer_layer_spec.submodules.mlp.submodules - ) - moe_layer_fused = MoELayer( - self.transformer_config, transformer_layer_spec.submodules.mlp.submodules - ) - router_ref = moe_layer_ref.router.cuda() - router_fused = moe_layer_fused.router.cuda() - - # Ensure identical initial parameters/state - router_fused.weight.copy_(router_ref.weight) - expert_bias_sample = torch.randn_like(router_ref.expert_bias) - router_ref.expert_bias.copy_(expert_bias_sample) - router_fused.expert_bias.copy_(expert_bias_sample) - - router_ref.config.moe_router_score_function = score_function - router_fused.config.moe_router_score_function = score_function - - hidden_states = torch.randn((32, 2, router_ref.config.hidden_size)) - hidden_states = hidden_states.cuda().bfloat16() - - # Unfused - router_ref.config.moe_router_fusion = False - scores_ref, routing_ref = router_ref(hidden_states) - - # Fused - router_fused.config.moe_router_fusion = True - scores_fused, routing_fused = router_fused(hidden_states) - - assert torch.equal(routing_ref, routing_fused) - torch.testing.assert_close(scores_ref, scores_fused) - - -@pytest.mark.internal -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize("router_dtype", [torch.bfloat16, torch.float32, torch.float64]) -def test_router_gating_linear(router_dtype): - tols = dict(rtol=2.0e-2, atol=1.0e-3) - - ref_inp = torch.randn((4096, 7168), dtype=torch.bfloat16, device="cuda") - ref_weight = torch.randn((256, 7168), dtype=torch.bfloat16, device="cuda") - ref_inp.requires_grad = True - ref_weight.requires_grad = True - bwd_input = torch.randn((4096, 256), dtype=router_dtype, device="cuda") - - ref_output = torch.nn.functional.linear(ref_inp.to(router_dtype), ref_weight.to(router_dtype)) - ref_output.backward(bwd_input) - - inp = ref_inp.detach() - weight = ref_weight.detach() - inp.requires_grad = True - weight.requires_grad = True - output = router_gating_linear(inp, weight, router_dtype) - output.backward(bwd_input) - - assert output.dtype == router_dtype - assert ref_inp.grad.dtype == ref_inp.dtype - assert ref_weight.grad.dtype == ref_weight.dtype - assert torch.allclose(output, ref_output, **tols) - assert torch.allclose(inp.grad, ref_inp.grad, **tols) - assert torch.allclose(weight.grad, ref_weight.grad, **tols) diff --git a/tests/unit_tests/transformer/moe/test_sequential_mlp.py b/tests/unit_tests/transformer/moe/test_sequential_mlp.py deleted file mode 100644 index 132cce57fa..0000000000 --- a/tests/unit_tests/transformer/moe/test_sequential_mlp.py +++ /dev/null @@ -1,229 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -from importlib.metadata import version - -import pytest -import torch - -from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec -from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.mlp import MLPSubmodules -from megatron.core.transformer.moe.experts import SequentialMLP -from megatron.core.transformer.moe.moe_layer import MoELayer -from megatron.core.transformer.moe.moe_utils import get_default_model_comm_pgs -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import is_te_min_version -from tests.unit_tests.test_utilities import Utils - - -class TestParallelSequentialMLP: - - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - print("done intializing") - num_moe_experts = 2 - transformer_config = TransformerConfig( - num_layers=2, - hidden_size=12, - num_attention_heads=4, - num_moe_experts=num_moe_experts, - use_cpu_initialization=True, - activation_func=torch.nn.functional.silu, - gated_linear_unit=True, - bias_activation_fusion=True, - moe_router_load_balancing_type="sinkhorn", - moe_router_topk=1, - add_bias_linear=False, - ) - transformer_layer_spec = get_gpt_layer_local_spec( - num_experts=num_moe_experts, moe_grouped_gemm=False - ) - self.sequential_mlp = MoELayer( - transformer_config, transformer_layer_spec.submodules.mlp.submodules - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.internal - def test_constructor(self): - assert isinstance(self.sequential_mlp, MoELayer) - - num_weights = sum([p.numel() for p in self.sequential_mlp.parameters()]) - assert num_weights == 3480 - - @pytest.mark.internal - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_gpu_forward(self): - sequential_mlp = self.sequential_mlp - sequential_mlp.cuda() - # [sequence length, batch size, hidden size] - hidden_states = torch.ones((32, 2, sequential_mlp.config.hidden_size)) - hidden_states = hidden_states.cuda() - output, output_bias = sequential_mlp(hidden_states) - assert output.shape[0] == 32 - assert output.shape[1] == 2 - assert output.shape[2] == sequential_mlp.config.hidden_size - assert output.dtype == torch.float32 - assert output.device.type == 'cuda' - - -class TestTEParallelSequentialMLP: - def setup_method(self, method): - Utils.initialize_model_parallel(tensor_model_parallel_size=2, expert_model_parallel_size=2) - model_parallel_cuda_manual_seed(123) - num_moe_experts = 4 - model_comm_pgs = get_default_model_comm_pgs() - self.transformer_config = TransformerConfig( - num_layers=2, - hidden_size=12, - num_attention_heads=4, - num_moe_experts=num_moe_experts, - use_cpu_initialization=False, - activation_func=torch.nn.functional.silu, - gated_linear_unit=True, - bias_activation_fusion=False, - moe_router_load_balancing_type="sinkhorn", - moe_router_topk=1, - params_dtype=torch.bfloat16, - expert_model_parallel_size=2, - tensor_model_parallel_size=2, - sequence_parallel=True, - add_bias_linear=False, - ) - - self.local_mlp_spec = MLPSubmodules( - linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear - ) - self.te_mlp_spec = MLPSubmodules( - linear_fc1=TEColumnParallelLinear, linear_fc2=TERowParallelLinear - ) - print("Done intializing") - - self.num_local_experts = 2 - model_parallel_cuda_manual_seed(123) - self.local_sequential_mlp = SequentialMLP( - self.num_local_experts, - self.transformer_config, - self.local_mlp_spec, - model_comm_pgs=model_comm_pgs, - ) - - model_parallel_cuda_manual_seed(123) - self.te_sequential_mlp = SequentialMLP( - self.num_local_experts, - self.transformer_config, - self.te_mlp_spec, - model_comm_pgs=model_comm_pgs, - ) - - @pytest.mark.internal - @pytest.mark.skipif( - not is_te_min_version("1.7.0"), - reason="Transformer Engine under v1.7.0 doesn't support MoE training.", - ) - @pytest.mark.internal - def test_constructor(self): - for i in range(self.num_local_experts): - assert torch.equal( - self.local_sequential_mlp.local_experts[i].linear_fc1.weight, - self.te_sequential_mlp.local_experts[i].linear_fc1.weight, - ) - assert torch.equal( - self.local_sequential_mlp.local_experts[i].linear_fc2.weight, - self.te_sequential_mlp.local_experts[i].linear_fc2.weight, - ) - - @pytest.mark.internal - @pytest.mark.skipif( - not is_te_min_version("1.7.0"), - reason="Transformer Engine under v1.7.0 doesn't support MoE training.", - ) - @pytest.mark.internal - def test_gpu_forward(self): - self.local_sequential_mlp.cuda() - self.te_sequential_mlp.cuda() - seq_len = 4 - batch_size = 2 - - tokens_per_expert = torch.tensor([2, 2], device="cuda") - hidden_states = torch.rand( - (seq_len, batch_size, self.local_sequential_mlp.config.hidden_size), - dtype=torch.bfloat16, - device="cuda", - ) - probs = torch.rand((seq_len, batch_size), dtype=torch.float32, device="cuda") - - output_local, _ = self.local_sequential_mlp(hidden_states, tokens_per_expert, probs) - output_te, _ = self.te_sequential_mlp(hidden_states, tokens_per_expert, probs) - assert torch.equal(output_local, output_te) - - @pytest.mark.internal - @pytest.mark.skipif( - not is_te_min_version("1.7.0"), - reason="Transformer Engine under v1.7.0 doesn't support MoE training.", - ) - @pytest.mark.internal - def test_gpu_forward_with_one_local_expert(self): - model_parallel_cuda_manual_seed(123) - model_comm_pgs = get_default_model_comm_pgs() - local_sequential_mlp = SequentialMLP( - 1, self.transformer_config, self.local_mlp_spec, model_comm_pgs=model_comm_pgs - ) - model_parallel_cuda_manual_seed(123) - te_sequential_mlp = SequentialMLP( - 1, self.transformer_config, self.te_mlp_spec, model_comm_pgs=model_comm_pgs - ) - seq_len = 4 - batch_size = 2 - - tokens_per_expert = torch.tensor([4], device="cuda") - hidden_states = torch.rand( - (seq_len, batch_size, self.local_sequential_mlp.config.hidden_size), - dtype=torch.bfloat16, - device="cuda", - ) - probs = torch.rand((seq_len, batch_size), dtype=torch.float32, device="cuda") - - output_local, _ = local_sequential_mlp(hidden_states, tokens_per_expert, probs) - output_te, _ = te_sequential_mlp(hidden_states, tokens_per_expert, probs) - assert torch.equal(output_local, output_te) - - @pytest.mark.internal - @pytest.mark.skipif( - not is_te_min_version("1.7.0"), - reason="Transformer Engine under v1.7.0 doesn't support MoE training.", - ) - @pytest.mark.internal - def test_gpu_forward_with_no_tokens_allocated(self): - self.local_sequential_mlp.cuda() - self.te_sequential_mlp.cuda() - seq_len = 4 - batch_size = 2 - - tokens_per_expert = torch.tensor([0, 4], device="cuda") - hidden_states = torch.rand( - (seq_len, batch_size, self.local_sequential_mlp.config.hidden_size), - dtype=torch.bfloat16, - device="cuda", - ) - probs = torch.rand((seq_len, batch_size), dtype=torch.float32, device="cuda") - - output_local, _ = self.local_sequential_mlp(hidden_states, tokens_per_expert, probs) - output_te, _ = self.te_sequential_mlp(hidden_states, tokens_per_expert, probs) - assert torch.equal(output_local, output_te) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - -if __name__ == "__main__": - MLP_test = TestTEParallelSequentialMLP() - MLP_test.setup_method(method=None) - MLP_test.test_constructor() - MLP_test.test_gpu_forward() - MLP_test.test_gpu_forward_with_one_local_expert() - MLP_test.test_gpu_forward_with_no_tokens_allocated() - MLP_test.teardown_method(method=None) diff --git a/tests/unit_tests/transformer/moe/test_shared_experts.py b/tests/unit_tests/transformer/moe/test_shared_experts.py deleted file mode 100644 index f721c48293..0000000000 --- a/tests/unit_tests/transformer/moe/test_shared_experts.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import pytest -import torch - -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.moe.moe_layer import MoELayer -from megatron.core.transformer.transformer_config import TransformerConfig -from tests.unit_tests.test_utilities import Utils - - -class TestSharedExperts: - - def setup_method(self, method): - pass - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.internal - def test_gpu_forward(self): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - print("done intializing") - num_moe_experts = 2 - transformer_config = TransformerConfig( - num_layers=1, - hidden_size=12, - num_attention_heads=4, - num_moe_experts=num_moe_experts, - moe_shared_expert_intermediate_size=32, - use_cpu_initialization=True, - activation_func=torch.nn.functional.silu, - gated_linear_unit=True, - bias_activation_fusion=True, - moe_router_load_balancing_type="sinkhorn", - moe_router_topk=1, - add_bias_linear=False, - ) - transformer_layer_spec = get_gpt_layer_local_spec( - num_experts=num_moe_experts, moe_grouped_gemm=False - ) - self.moe_layer = MoELayer( - transformer_config, transformer_layer_spec.submodules.mlp.submodules - ) - - assert isinstance(self.moe_layer, MoELayer) - - num_weights = sum([p.numel() for p in self.moe_layer.parameters()]) - assert num_weights == 3480 + 1152 - assert self.moe_layer.shared_experts is not None - assert self.moe_layer.shared_experts.stream is None - assert self.moe_layer.token_dispatcher.shared_experts is None - - moe_layer = self.moe_layer - moe_layer.cuda() - # [sequence length, batch size, hidden size] - hidden_states = torch.ones((32, 2, moe_layer.config.hidden_size)) - hidden_states = hidden_states.cuda() - output, _ = moe_layer(hidden_states) - assert output.shape[0] == 32 - assert output.shape[1] == 2 - assert output.shape[2] == moe_layer.config.hidden_size - assert output.dtype == torch.float32 - assert output.device.type == 'cuda' - - -class TestSharedExpertsOverlap: - - def setup_method(self, method): - pass - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.internal - def test_gpu_forward(self): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - print("done intializing") - num_moe_experts = 2 - transformer_config = TransformerConfig( - num_layers=1, - hidden_size=12, - num_attention_heads=4, - num_moe_experts=num_moe_experts, - moe_shared_expert_intermediate_size=32, - moe_shared_expert_overlap=True, - moe_token_dispatcher_type="alltoall", - use_cpu_initialization=True, - activation_func=torch.nn.functional.silu, - gated_linear_unit=True, - bias_activation_fusion=True, - moe_router_load_balancing_type="sinkhorn", - moe_router_topk=1, - add_bias_linear=False, - ) - transformer_layer_spec = get_gpt_layer_local_spec( - num_experts=num_moe_experts, moe_grouped_gemm=False - ) - self.moe_layer = MoELayer( - transformer_config, transformer_layer_spec.submodules.mlp.submodules - ) - - assert isinstance(self.moe_layer, MoELayer) - - num_weights = sum([p.numel() for p in self.moe_layer.parameters()]) - assert num_weights == 3480 + 1152 - assert self.moe_layer.shared_experts is not None - assert self.moe_layer.shared_experts.stream is not None - assert self.moe_layer.token_dispatcher.shared_experts is not None - - moe_layer = self.moe_layer - moe_layer.cuda() - # [sequence length, batch size, hidden size] - hidden_states = torch.ones((32, 2, moe_layer.config.hidden_size)) - hidden_states = hidden_states.cuda() - output, _ = moe_layer(hidden_states) - assert output.shape[0] == 32 - assert output.shape[1] == 2 - assert output.shape[2] == moe_layer.config.hidden_size - assert output.dtype == torch.float32 - assert output.device.type == 'cuda' diff --git a/tests/unit_tests/transformer/moe/test_token_dispatcher.py b/tests/unit_tests/transformer/moe/test_token_dispatcher.py deleted file mode 100644 index ac9fd8e81f..0000000000 --- a/tests/unit_tests/transformer/moe/test_token_dispatcher.py +++ /dev/null @@ -1,481 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import copy -import dataclasses - -import pytest -import torch - -from megatron.core import config, parallel_state -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec -from megatron.core.transformer.moe.moe_layer import MoELayer -from megatron.core.transformer.moe.moe_utils import get_capacity -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import is_te_min_version -from megatron.training.initialize import _set_random_seed -from tests.unit_tests.test_utilities import Utils - - -def token_permutation(token_dispatcher, hidden_states, probs, indices): - hidden_states, probs = token_dispatcher.dispatch_preprocess(hidden_states, indices, probs) - hidden_states, probs = token_dispatcher.token_dispatch(hidden_states, probs) - hidden_states, tokens_per_expert, permuted_probs = token_dispatcher.dispatch_postprocess( - hidden_states, probs - ) - return hidden_states, tokens_per_expert, permuted_probs - - -def token_unpermutation(token_dispatcher, hidden_states): - hidden_states = token_dispatcher.combine_preprocess(hidden_states) - hidden_states = token_dispatcher.token_combine(hidden_states) - hidden_states = token_dispatcher.combine_postprocess(hidden_states) - return hidden_states, None - - -class MoEModelTestContainer: - def __init__( - self, - tp_size, - ep_size, - pp_size, - cp_size=1, - moe_tp_size=None, - data_parallel_random_init=False, - num_moe_experts=8, - moe_router_topk=2, - moe_router_load_balancing_type="aux_loss", - moe_token_dispatcher_type="alltoall", - moe_expert_capacity_factor=None, - moe_pad_expert_input_to_capacity=False, - moe_aux_loss_coeff=0.1, - **kwargs, - ): - self.num_local_experts = num_moe_experts // ep_size - if moe_tp_size is None: - moe_tp_size = tp_size - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp_size, - pipeline_model_parallel_size=pp_size, - expert_model_parallel_size=ep_size, - context_parallel_size=cp_size, - expert_tensor_parallel_size=moe_tp_size, - ) - _set_random_seed(seed_=123, data_parallel_random_init=data_parallel_random_init) - local_expert_indices_offset = ( - parallel_state.get_expert_model_parallel_rank() * self.num_local_experts - ) - self.local_expert_indices = [ - local_expert_indices_offset + i for i in range(self.num_local_experts) - ] - self.config = TransformerConfig( - tensor_model_parallel_size=tp_size, - expert_model_parallel_size=ep_size, - pipeline_model_parallel_size=pp_size, - context_parallel_size=cp_size, - expert_tensor_parallel_size=moe_tp_size, - moe_router_topk=moe_router_topk, - num_moe_experts=num_moe_experts, - moe_router_load_balancing_type=moe_router_load_balancing_type, - moe_token_dispatcher_type=moe_token_dispatcher_type, - moe_expert_capacity_factor=moe_expert_capacity_factor, - moe_pad_expert_input_to_capacity=moe_pad_expert_input_to_capacity, - moe_aux_loss_coeff=moe_aux_loss_coeff, - num_layers=1, - moe_grouped_gemm=kwargs.get("moe_grouped_gemm", False), - hidden_size=kwargs.get("hidden_size", 16), - num_attention_heads=kwargs.get("num_attention_heads", 8), - use_cpu_initialization=kwargs.get("use_cpu_initialization", True), - sequence_parallel=tp_size > 1, - add_bias_linear=kwargs.get("add_bias_linear", False), - moe_permute_fusion=kwargs.get("moe_permute_fusion", False), - moe_enable_deepep=kwargs.get("moe_enable_deepep", False), - ) - - # init moe layer - self.moe_layer = self.new_moe_layer() - - def new_moe_layer(self, **kargs): - transformer_layer_spec = get_gpt_layer_local_spec( - num_experts=self.config.num_moe_experts, moe_grouped_gemm=self.config.moe_grouped_gemm - ) - new_config = dataclasses.replace(self.config, **kargs) - moe_layer = MoELayer(new_config, transformer_layer_spec.submodules.mlp.submodules).cuda() - moe_layer.set_layer_number(0) - return moe_layer - - def __del__(self): - torch.distributed.barrier() - torch.cuda.synchronize() - Utils.destroy_model_parallel() - - @pytest.mark.internal - def dispatcher_dropless_test(self): - moe_layer = self.moe_layer - bs = 32 - seql = 8 - # TODO: Find why setting manual seed can cause the test to fail - # Manual seed to differentiate input data for each rank - # rank = torch.distributed.get_rank() - # torch.manual_seed(1000 + rank) - hidden_states = torch.randn((bs, seql, moe_layer.config.hidden_size)) - hidden_states = hidden_states.cuda() - # Permute and then unpermute data are supposed to restore original data - ans = hidden_states - hidden_states.requires_grad = True - probs, indices = moe_layer.router(hidden_states) - probs = torch.ones_like(probs) / moe_layer.router.topk - - (permuted_local_hidden_states, tokens_per_expert, permuted_probs) = token_permutation( - moe_layer.token_dispatcher, hidden_states, probs, indices - ) - - permuted_local_hidden_states = permuted_local_hidden_states * permuted_probs.unsqueeze(-1) - - restored_hidden_states, restored_bias = token_unpermutation( - moe_layer.token_dispatcher, permuted_local_hidden_states - ) - - # reduce across TP rank equals to multiply data by a scale of ETP - scale = moe_layer.config.expert_tensor_parallel_size - restored_hidden_states = restored_hidden_states / scale - - assert torch.allclose( - restored_hidden_states, ans - ), "Restored hidden states do not match original hidden states" - - # check if the grad of the hidden states is same as the hidden states - torch.autograd.backward(restored_hidden_states, hidden_states) - assert torch.allclose( - hidden_states.grad, ans - ), "Restored hidden states do not match original hidden states" - - @pytest.mark.internal - def dispatcher_capacity_test(self): - moe_layer = self.moe_layer - num_tokens = 16 - hidden_states = torch.randn((num_tokens, moe_layer.config.hidden_size)) - hidden_states = hidden_states.cuda() - hidden_states.requires_grad = True - probs, indices = moe_layer.router(hidden_states) - - # Create the answer. - prob_mask = probs != 0 - probs = torch.ones_like(probs) * prob_mask / moe_layer.router.topk - local_probss = probs - restored_hidden_states_answer = hidden_states * local_probss.sum(dim=1).unsqueeze(1) - - (permuted_local_hidden_states, tokens_per_expert, permuted_probs) = token_permutation( - moe_layer.token_dispatcher, hidden_states, probs, indices - ) - - # Check tokens per expert not exceed the capacity. - capacity = get_capacity( - num_tokens * self.config.moe_router_topk, - self.config.num_moe_experts, - self.config.moe_expert_capacity_factor, - ) - assert torch.all( - tokens_per_expert - <= capacity - * self.config.expert_model_parallel_size - * self.config.tensor_model_parallel_size - ), "Tokens per expert exceed the capacity" - - permuted_local_hidden_states = permuted_local_hidden_states * permuted_probs.unsqueeze(-1) - - permuted_local_hidden_states /= moe_layer.config.tensor_model_parallel_size - - restored_hidden_states, restored_bias = token_unpermutation( - moe_layer.token_dispatcher, permuted_local_hidden_states - ) - assert torch.allclose( - restored_hidden_states, restored_hidden_states_answer - ), "Restored hidden states does not match" - - # check if the grad of the hidden states is same as the hidden states - torch.autograd.backward(restored_hidden_states, hidden_states) - assert torch.allclose( - hidden_states.grad, restored_hidden_states_answer - ), "Gradient of hidden states should be same as hidden states" - - @pytest.mark.internal - def dispatcher_drop_and_pad_test(self): - """Test if the tokens are dropped and padded correctly. - - Since the probs of padded tokens are 0, the combined results for - dispatching with or without padding should be the same. - """ - moe_layer = self.new_moe_layer(moe_pad_expert_input_to_capacity=False) - - num_tokens = 16 - hidden_states = torch.randn((num_tokens, moe_layer.config.hidden_size)).cuda() - hidden_states.requires_grad = True - - probs_1, indices_1 = moe_layer.router(hidden_states) - (permuted_input_1, tokens_per_expert, permuted_probs_1) = token_permutation( - moe_layer.token_dispatcher, hidden_states, probs_1, indices_1 - ) - permuted_input_1 = permuted_input_1 * permuted_probs_1.unsqueeze(-1) - forward_answer, restored_bias = token_unpermutation( - moe_layer.token_dispatcher, permuted_input_1 - ) - torch.autograd.backward(forward_answer, forward_answer) - backward_answer = hidden_states.grad.clone() - hidden_states.grad = None - torch.cuda.synchronize() - # End - - moe_layer_2 = self.new_moe_layer(moe_pad_expert_input_to_capacity=True) - moe_layer_2.load_state_dict(moe_layer.state_dict()) - - probs_2, indices_2 = moe_layer_2.router(hidden_states) - (permuted_input_2, tokens_per_expert, permuted_probs_2) = token_permutation( - moe_layer_2.token_dispatcher, hidden_states, probs_2, indices_2 - ) - permuted_input_2 = permuted_input_2 * permuted_probs_2.unsqueeze(-1) - restored_hidden_states, restored_bias = token_unpermutation( - moe_layer_2.token_dispatcher, permuted_input_2 - ) - - # # Check tokens per expert equals to the capacity. - capacity = get_capacity( - num_tokens * self.config.moe_router_topk, - self.config.num_moe_experts, - self.config.moe_expert_capacity_factor, - ) - assert torch.all( - tokens_per_expert - == capacity - * self.config.expert_model_parallel_size - * self.config.tensor_model_parallel_size - ), "Tokens per expert should be the same as the capacity" - assert torch.allclose( - restored_hidden_states, forward_answer - ), "Restored hidden states does not match" - - # check if the grad of the hidden states is same as the hidden states - torch.autograd.backward(restored_hidden_states, restored_hidden_states) - assert torch.allclose( - hidden_states.grad, backward_answer - ), "Gradient of hidden states should be same as hidden states" - - @pytest.mark.internal - def dispatcher_router_padding_for_fp8_test(self): - """Test if the routing map is padded correctly for FP8 training. - - The test runs the forward flow twice: - 1. First with moe_router_padding_for_fp8=False - 2. Then with moe_router_padding_for_fp8=True - - We verify that: - 1. The results are the same in both cases - 2. The number of tokens received by each expert is padded to a multiple of 16 - """ - # First run with moe_router_padding_for_fp8 = False - moe_layer = self.new_moe_layer(moe_router_padding_for_fp8=False) - - num_tokens = 32 - hidden_states = torch.randn((num_tokens, moe_layer.config.hidden_size)).cuda() - hidden_states.requires_grad = True - - probs_1, indices_1 = moe_layer.router(hidden_states) - (permuted_input_1, tokens_per_expert_1, permuted_probs_1) = token_permutation( - moe_layer.token_dispatcher, hidden_states, probs_1, indices_1 - ) - permuted_input_1 = permuted_input_1 * permuted_probs_1.unsqueeze(-1) - restored_hidden_states_1, _ = token_unpermutation( - moe_layer.token_dispatcher, permuted_input_1 - ) - torch.autograd.backward(restored_hidden_states_1, restored_hidden_states_1) - grad_1 = hidden_states.grad.clone() - hidden_states.grad = None - - # Run with moe_router_padding_for_fp8 = True - moe_layer_2 = self.new_moe_layer(moe_router_padding_for_fp8=True, fp8="hybrid") - moe_layer_2.load_state_dict(moe_layer.state_dict()) - - probs_2, indices_2 = moe_layer_2.router(hidden_states) - (permuted_input_2, tokens_per_expert_2, permuted_probs_2) = token_permutation( - moe_layer_2.token_dispatcher, hidden_states, probs_2, indices_2 - ) - assert ( - sum(tokens_per_expert_2) == permuted_input_2.shape[0] - ), f"number of tokens is not the same, {sum(tokens_per_expert_2)} != {permuted_input_2.shape[0]}" - # when there is only one expert, the tokens is not enough for router padding - if moe_layer_2.num_local_experts > 1: - assert torch.all( - tokens_per_expert_2 % 16 == 0 - ), "number of tokens for expert is not a multiple of 16" - - permuted_input_2 = permuted_input_2 * permuted_probs_2.unsqueeze(-1) - restored_hidden_states_2, _ = token_unpermutation( - moe_layer_2.token_dispatcher, permuted_input_2 - ) - - # Check that the results are the same - assert torch.allclose( - restored_hidden_states_1, restored_hidden_states_2 - ), "Restored hidden states do not match between padded and non-padded versions" - - # Check gradients - torch.autograd.backward(restored_hidden_states_2, restored_hidden_states_2) - assert torch.allclose( - grad_1, hidden_states.grad - ), "Gradients do not match between padded and non-padded versions" - - def set_params(self): - # TODO: Set consistent parameters for various parallelisms. - raise NotImplementedError - - def destroy(self): - Utils.destroy_model_parallel() - - -permute_fusion_params = [False] -if is_te_min_version("2.1.0"): - permute_fusion_params.append(True) - - -class TestAllgatherDispatcher: - def setup_method(self, method): - pass - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.internal - @pytest.mark.parametrize("tp_size,ep_size", [(8, 1), (1, 8), (2, 4), (1, 1)]) - @pytest.mark.parametrize("permute_fusion", permute_fusion_params) - def test_forward_backward(self, tp_size, ep_size, permute_fusion): - container = MoEModelTestContainer( - tp_size=tp_size, - ep_size=ep_size, - pp_size=1, - num_moe_experts=8, - moe_router_topk=2, - moe_router_load_balancing_type="aux_loss", - moe_token_dispatcher_type="allgather", - moe_permute_fusion=permute_fusion, - ) - - container.dispatcher_dropless_test() - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.internal - @pytest.mark.parametrize("permute_fusion", permute_fusion_params) - @pytest.mark.parametrize( - "tp_size,ep_size,moe_tp_size", [(1, 1, 8), (1, 2, 4), (1, 4, 2), (2, 2, 4)] - ) - def test_moe_tp_forward_backward(self, tp_size, ep_size, moe_tp_size, permute_fusion): - container = MoEModelTestContainer( - tp_size=tp_size, - ep_size=ep_size, - pp_size=1, - moe_tp_size=moe_tp_size, - num_moe_experts=8, - moe_router_topk=2, - moe_router_load_balancing_type="aux_loss", - moe_token_dispatcher_type="allgather", - sequence_parallel=True, - moe_permute_fusion=permute_fusion, - use_cpu_initialization=False, - ) - - container.dispatcher_dropless_test() - - -def is_deep_ep_available(): - from megatron.core.transformer.moe.fused_a2a import HAVE_DEEP_EP - - return HAVE_DEEP_EP - - -@pytest.mark.skipif(not is_deep_ep_available(), reason="Deep EP is not available") -class TestFlexDispatcher: - def setup_method(self, method): - pass - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.internal - @pytest.mark.parametrize("tp_size,ep_size", [(8, 1), (1, 8), (2, 4)]) - @pytest.mark.parametrize("permute_fusion", permute_fusion_params) - @pytest.mark.parametrize("experimental_fusion", [True, False]) - def test_forward_backward(self, tp_size, ep_size, permute_fusion, experimental_fusion): - if experimental_fusion: - config.ENABLE_EXPERIMENTAL = True - container = MoEModelTestContainer( - tp_size=tp_size, - ep_size=ep_size, - pp_size=1, - num_moe_experts=8, - moe_router_topk=2, - moe_router_load_balancing_type="aux_loss", - moe_token_dispatcher_type="flex", - moe_permute_fusion=permute_fusion, - hidden_size=4, - moe_enable_deepep=True, - ) - container.dispatcher_dropless_test() - # reset experimental flag to False - config.ENABLE_EXPERIMENTAL = False - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.internal - @pytest.mark.timeout(120) - @pytest.mark.parametrize("tp_size,ep_size", [(1, 8), (8, 1), (4, 2)]) - @pytest.mark.parametrize("permute_fusion", permute_fusion_params) - @pytest.mark.parametrize("experimental_fusion", [True, False]) - def test_capacity_forward_backward(self, tp_size, ep_size, permute_fusion, experimental_fusion): - if experimental_fusion: - config.ENABLE_EXPERIMENTAL = True - container = MoEModelTestContainer( - tp_size=tp_size, - ep_size=ep_size, - pp_size=1, - num_moe_experts=8, - moe_router_topk=2, - moe_router_load_balancing_type="aux_loss", - moe_token_dispatcher_type="flex", - moe_token_drop_policy="probs", - moe_expert_capacity_factor=0.5, - moe_pad_expert_input_to_capacity=False, - moe_permute_fusion=permute_fusion, - hidden_size=4, - moe_enable_deepep=True, - ) - container.dispatcher_capacity_test() - config.ENABLE_EXPERIMENTAL = False - - @pytest.mark.skipif( - not is_te_min_version("1.7.0"), reason="TE 1.7.0 is required for MoE with FP8." - ) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.internal - @pytest.mark.timeout(180) - @pytest.mark.parametrize("tp_size,ep_size", [(1, 8), (8, 1), (4, 2)]) - @pytest.mark.parametrize("permute_fusion", permute_fusion_params) - @pytest.mark.parametrize("experimental_fusion", [True, False]) - def test_router_padding_for_fp8_forward_backward( - self, tp_size, ep_size, permute_fusion, experimental_fusion - ): - if experimental_fusion: - config.ENABLE_EXPERIMENTAL = True - container = MoEModelTestContainer( - tp_size=tp_size, - ep_size=ep_size, - pp_size=1, - num_moe_experts=32, - moe_router_topk=4, - moe_router_load_balancing_type="aux_loss", - moe_token_dispatcher_type="flex", - moe_pad_expert_input_to_capacity=False, - moe_permute_fusion=permute_fusion, - hidden_size=4, - moe_enable_deepep=True, - ) - container.dispatcher_router_padding_for_fp8_test() - config.ENABLE_EXPERIMENTAL = False diff --git a/tests/unit_tests/transformer/moe/test_upcycling.py b/tests/unit_tests/transformer/moe/test_upcycling.py deleted file mode 100644 index 9cb7aa327b..0000000000 --- a/tests/unit_tests/transformer/moe/test_upcycling.py +++ /dev/null @@ -1,311 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -import os -import sys - -import pytest -import torch -import torch.distributed - -from megatron.core import mpu -from megatron.core.enums import ModelType -from megatron.core.models.gpt.gpt_layer_specs import ( - get_gpt_layer_local_spec, - get_gpt_layer_with_transformer_engine_spec, -) -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.moe import upcycling_utils -from megatron.core.transformer.moe.experts import SequentialMLP, TEGroupedMLP -from megatron.core.utils import get_te_version, is_te_min_version -from megatron.training.arguments import core_transformer_config_from_args, parse_args, validate_args -from megatron.training.global_vars import ( - destroy_global_vars, - get_args, - set_args, - set_global_variables, -) -from megatron.training.training import get_model, setup_model_and_optimizer -from megatron.training.utils import ( - get_batch_on_this_cp_rank, - get_batch_on_this_tp_rank, - unwrap_model, -) -from tests.unit_tests.test_utilities import Utils - -try: - from megatron.core.extensions.transformer_engine import TEColumnParallelGroupedLinear - - HAVE_TE = True -except ImportError: - HAVE_TE = False - -_SEED = 42 - - -def _find_submodule(model, submodule_name): - """ - Find sub-module in model - """ - for name, submodule in model.named_modules(): - if name.endswith("." + submodule_name) or name == submodule_name: - return submodule - return None - - -def model_provider( - pre_process=True, - post_process=True, - layer_spec_fn=get_gpt_layer_with_transformer_engine_spec, - **config_kwargs, -): - model_parallel_cuda_manual_seed(_SEED) - args = get_args() - - config = core_transformer_config_from_args(args) - use_te = args.transformer_impl == "transformer_engine" - if use_te: - layer_spec_fn = get_gpt_layer_with_transformer_engine_spec - else: - layer_spec_fn = get_gpt_layer_local_spec - - model = GPTModel( - config=config, - transformer_layer_spec=layer_spec_fn( - args.num_experts, args.moe_grouped_gemm, args.qk_layernorm - ), - vocab_size=args.vocal_size, - max_sequence_length=args.max_position_embeddings, - pre_process=pre_process, - post_process=post_process, - fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, - parallel_output=True, - share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, - position_embedding_type=args.position_embedding_type, - rotary_percent=args.rotary_percent, - ) - return model - - -def create_test_args(tp, grouped_gemm, swiglu, squared_relu, use_te): - destroy_global_vars() - destroy_num_microbatches_calculator() - - sys.argv = ['test_upcycling.py'] - args = parse_args() - args.num_layers = 2 - args.vocal_size = 256 - args.hidden_size = 128 - args.num_attention_heads = 8 - args.max_position_embeddings = 256 - args.micro_batch_size = 1 - args.create_attention_mask_in_dataloader = True - args.seq_length = 256 - args.tensor_model_parallel_size = tp - if tp > 1: - # During training, performance may degrade if MoE and tensor - # parallelismare enabled without also enabling sequence parallelism. - args.sequence_parallel = True - args.context_parallel_size = 1 - args.num_experts = None - args.train_iters = 1 - args.ckpt_format = 'torch_dist' - args.moe_router_topk = 2 - args.moe_router_pre_softmax = False - args.lr = 3e-5 - args.attention_dropout = 0.0 - args.hidden_dropout = 0.0 - args.async_tensor_model_parallel_allreduce = False - args.no_save_optim = True - args.no_load_optim = True - args.no_load_rng = True - args.moe_grouped_gemm = grouped_gemm - args.transformer_impl = "transformer_engine" if use_te else "local" - args.bf16 = True - args.add_bias_linear = False - args.moe_token_dispatcher_type = "alltoall" - - args.swiglu = swiglu - args.squared_relu = squared_relu - if args.squared_relu == True: - assert args.swiglu == False, 'must set swiglu=False while squared_relu==True' - args.bias_gelu_fusion = False - args.bias_swiglu_fusion = False - - validate_args(args) - set_global_variables(args, False) - return args - - -def set_upcycling_args(ep, granularity, num_experts=8): - args = get_args() - args.moe_use_upcycling = True - args.num_experts = num_experts - args.expert_model_parallel_size = ep - args.moe_upcycling_granularity = granularity - dense_ffn_hidden_size = args.ffn_hidden_size - args.ffn_hidden_size = dense_ffn_hidden_size // args.moe_upcycling_granularity - args.moe_ffn_hidden_size = dense_ffn_hidden_size // args.moe_upcycling_granularity - set_args(args) - - -def set_bias_value(dense_model): - # change the bias value, make sure they are not zero - state_dict = dense_model[0].state_dict() - for name in state_dict: - if name.endswith("bias"): - value = state_dict[name] - value = torch.randn(value.shape) - state_dict[name] = value - dense_model[0].load_state_dict(state_dict, strict=True) - - -def get_batch(data_iterator): - if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()): - return None, None, None, None, None - - batch = get_batch_on_this_tp_rank(data_iterator) - batch = get_batch_on_this_cp_rank(batch) - - return batch.values() - - -class TestGPTModel: - def setup_method(self, method): - Utils.destroy_model_parallel() - os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' - - def teardown_method(self, method): - Utils.destroy_model_parallel() - destroy_global_vars() - destroy_num_microbatches_calculator() - - @pytest.mark.parametrize( - ('tp_ep', 'granularity', 'grouped_gemm', 'swiglu', 'squared_relu'), - [pytest.param((1, 1), 1, False, False, False)], - ) - def test_upcycling_Local(self, tp_ep, granularity, grouped_gemm, swiglu, squared_relu): - tp = tp_ep[0] - ep = tp_ep[1] - args = create_test_args(tp, grouped_gemm, swiglu, squared_relu, use_te=False) - - torch.manual_seed(_SEED) - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp, - virtual_pipeline_model_parallel_size=args.virtual_pipeline_model_parallel_size, - ) - - dense_model, optimizer, opt_param_scheduler = setup_model_and_optimizer( - model_provider, ModelType.encoder_or_decoder - ) - data = list(range(args.seq_length)) - input_ids = torch.tensor(data, dtype=torch.int64).repeat((args.micro_batch_size, 1)).cuda() - position_ids = ( - torch.tensor(data, dtype=torch.int64).repeat((args.micro_batch_size, 1)).cuda() - ) - attention_mask = torch.ones( - (args.micro_batch_size, 1, args.seq_length, args.seq_length), dtype=bool - ).cuda() - dense_model = unwrap_model(dense_model) - set_bias_value(dense_model) - dense_logits = dense_model[0].forward( - input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask - ) - - Utils.destroy_model_parallel() - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp, expert_model_parallel_size=ep - ) - set_upcycling_args(ep, granularity, num_experts=2) - # model_parallel_cuda_manual_seed(_SEED+1) - moe_model = get_model(model_provider, ModelType.encoder_or_decoder) - - # Upcycle the dense model to the MoE model - moe_model = unwrap_model(moe_model) - - state_dict = upcycling_utils.upcycle_state_dict(moe_model, dense_model) - if len(moe_model) == 1: - moe_model[0].load_state_dict(state_dict['model'], strict=True) - else: - for i in range(len(moe_model)): - moe_model[i].load_state_dict(state_dict['model%d' % i], strict=True) - - moe_logits = moe_model[0].forward( - input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask - ) - # Compare the outputs of the MoE model and the dense model. - assert torch.allclose( - moe_logits, dense_logits, rtol=1e-01, atol=1e-01 - ), "The output of moe model do not match the output of dense model." - - @pytest.mark.skipif( - not HAVE_TE or not is_te_min_version("2.1.0"), - reason="grouped_gemm requires TransformerEngine >= 2.1.0", - ) - @pytest.mark.parametrize( - ('tp_ep', 'granularity', 'grouped_gemm', 'swiglu', 'squared_relu'), - [ - pytest.param((1, 2), 1, False, False, False), - pytest.param((1, 2), 2, False, False, False), - pytest.param((1, 2), 1, True, False, False), - pytest.param((2, 1), 1, True, False, False), - pytest.param((1, 2), 2, True, False, False), - pytest.param((1, 2), 2, True, False, True), - pytest.param((1, 2), 2, True, True, False), - ], - ) - def test_upcycling_TE(self, tp_ep, granularity, grouped_gemm, swiglu, squared_relu): - tp = tp_ep[0] - ep = tp_ep[1] - args = create_test_args(tp, grouped_gemm, swiglu, squared_relu, use_te=True) - set_args(args) - - torch.manual_seed(_SEED) - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp, - virtual_pipeline_model_parallel_size=args.virtual_pipeline_model_parallel_size, - ) - - dense_model, optimizer, opt_param_scheduler = setup_model_and_optimizer( - model_provider, ModelType.encoder_or_decoder - ) - data = list(range(args.seq_length)) - input_ids = torch.tensor(data, dtype=torch.int64).repeat((args.micro_batch_size, 1)).cuda() - position_ids = ( - torch.tensor(data, dtype=torch.int64).repeat((args.micro_batch_size, 1)).cuda() - ) - attention_mask = torch.ones( - (args.micro_batch_size, 1, args.seq_length, args.seq_length), dtype=bool - ).cuda() - dense_model = unwrap_model(dense_model) - set_bias_value(dense_model) - dense_logits = dense_model[0].forward( - input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask - ) - - Utils.destroy_model_parallel() - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp, expert_model_parallel_size=ep - ) - set_upcycling_args(ep, granularity) - # model_parallel_cuda_manual_seed(_SEED+1) - moe_model = get_model(model_provider, ModelType.encoder_or_decoder) - - # Upcycle the dense model to the MoE model - moe_model = unwrap_model(moe_model) - - state_dict = upcycling_utils.upcycle_state_dict(moe_model, dense_model) - if len(moe_model) == 1: - moe_model[0].load_state_dict(state_dict['model'], strict=True) - else: - for i in range(len(moe_model)): - mpu.set_virtual_pipeline_model_parallel_rank(i) - moe_model[i].load_state_dict(state_dict['model%d' % i], strict=True) - - moe_logits = moe_model[0].forward( - input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask - ) - # Compare the outputs of the MoE model and the dense model. - assert torch.allclose( - moe_logits, dense_logits, rtol=1e-01, atol=1e-01 - ), "The output of moe model do not match the output of dense model." diff --git a/tests/unit_tests/transformer/test_attention.py b/tests/unit_tests/transformer/test_attention.py deleted file mode 100644 index a43251084f..0000000000 --- a/tests/unit_tests/transformer/test_attention.py +++ /dev/null @@ -1,240 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import copy - -import pytest -import torch -from packaging import version - -import megatron.core.parallel_state as parallel_state -from megatron.core.hyper_comm_grid import HyperCommGrid -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec -from megatron.core.process_groups_config import ModelCommProcessGroups -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer import TransformerConfig -from megatron.core.transformer.attention import SelfAttention -from megatron.core.transformer.enums import AttnMaskType -from megatron.core.utils import is_te_min_version -from tests.unit_tests.test_utilities import Utils - - -class TestParallelAttention: - - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - self.transformer_config = TransformerConfig( - num_layers=2, - hidden_size=128, - num_attention_heads=4, - use_cpu_initialization=True, - bf16=True, - params_dtype=torch.bfloat16, - ) - self.parallel_attention = SelfAttention( - self.transformer_config, - get_gpt_layer_with_transformer_engine_spec().submodules.self_attention.submodules, - layer_number=1, - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_constructor(self): - assert isinstance(self.parallel_attention, SelfAttention) - assert self.parallel_attention.layer_number == 1 - - num_weights = sum([p.numel() for p in self.parallel_attention.parameters()]) - assert num_weights == 66304 - - def test_cpu_forward(self): - # we can't currently do this because the global memory buffer is on GPU - pass - - def test_gpu_forward(self): - - config = self.parallel_attention.config - sequence_length = 32 - micro_batch_size = 2 - - self.parallel_attention.cuda() - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones( - (sequence_length, micro_batch_size, self.parallel_attention.config.hidden_size), - dtype=torch.bfloat16, - ) - hidden_states = hidden_states.cuda() - - attention_mask = torch.ones((micro_batch_size, 1, 1, sequence_length), dtype=bool).cuda() - - output, bias = self.parallel_attention(hidden_states, attention_mask) - - assert config.recompute_granularity is None - assert output.shape[0] == sequence_length - assert output.shape[1] == micro_batch_size - assert output.shape[2] == config.hidden_size - assert bias.shape[0] == config.hidden_size - - @pytest.mark.skipif(not is_te_min_version("1.4.0"), reason="Fused RoPE requires TE >= 1.4.0") - @pytest.mark.parametrize("rotary_interleaved", [True, False]) - def test_fused_rope_gpu_forward(self, rotary_interleaved): - self.parallel_attention.config.apply_rope_fusion = True - if rotary_interleaved and not is_te_min_version("2.3.0"): - pytest.skip("Only TE >= 2.3.0 supports interleaved fused RoPE.") - self.parallel_attention.config.rotary_interleaved = rotary_interleaved - config = self.parallel_attention.config - sequence_length = 32 - micro_batch_size = 2 - - self.parallel_attention.cuda() - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones( - (sequence_length, micro_batch_size, self.parallel_attention.config.hidden_size), - dtype=torch.bfloat16, - ) - hidden_states = hidden_states.cuda() - - attention_mask = torch.ones((micro_batch_size, 1, 1, sequence_length), dtype=bool).cuda() - rotary_pos_emb = torch.ones( - sequence_length, 1, 1, self.parallel_attention.config.kv_channels - ).cuda() - output, bias = self.parallel_attention( - hidden_states, attention_mask, rotary_pos_emb=rotary_pos_emb - ) - - assert config.recompute_granularity is None - assert output.shape[0] == sequence_length - assert output.shape[1] == micro_batch_size - assert output.shape[2] == config.hidden_size - assert bias.shape[0] == config.hidden_size - self.parallel_attention.config.apply_rope_fusion = False - self.parallel_attention.config.rotary_interleaved = False - - def test_checkpointed_gpu_forward(self): - transformer_config = self.transformer_config - transformer_config.recompute_granularity = 'selective' - checkpointed_parallel_attention = SelfAttention( - transformer_config, - get_gpt_layer_with_transformer_engine_spec().submodules.self_attention.submodules, - layer_number=1, - ) - config = checkpointed_parallel_attention.config - - sequence_length = 32 - micro_batch_size = 2 - - checkpointed_parallel_attention.cuda() - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones( - (sequence_length, micro_batch_size, checkpointed_parallel_attention.config.hidden_size), - dtype=torch.bfloat16, - ) - hidden_states = hidden_states.cuda() - - attention_mask = torch.ones((micro_batch_size, 1, 1, sequence_length), dtype=bool).cuda() - - output, bias = checkpointed_parallel_attention(hidden_states, attention_mask) - - assert config.recompute_granularity == 'selective' - assert "core_attn" in config.recompute_modules - assert output.shape[0] == sequence_length - assert output.shape[1] == micro_batch_size - assert output.shape[2] == config.hidden_size - assert bias.shape[0] == config.hidden_size - - -class TestSelfAttention: - - def setup_method(self, method): - Utils.destroy_model_parallel() - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def run_self_attention(self, model_comm_pgs): - tensor_model_parallel_size = torch.distributed.get_world_size(model_comm_pgs.tp) - self.transformer_config = TransformerConfig( - num_layers=2, - hidden_size=128, - num_attention_heads=4, - tensor_model_parallel_size=tensor_model_parallel_size, - use_cpu_initialization=False, - ) - self.self_attention = SelfAttention( - self.transformer_config, - get_gpt_layer_with_transformer_engine_spec().submodules.self_attention.submodules, - layer_number=1, - attn_mask_type=AttnMaskType.causal, - model_comm_pgs=model_comm_pgs, - ) - - config = self.self_attention.config - sequence_length = 127 - micro_batch_size = 2 - - self.self_attention.cuda() - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones( - (sequence_length, micro_batch_size, self.self_attention.config.hidden_size), - device='cuda', - ) - hidden_states_ref = copy.deepcopy(hidden_states) - - output, bias = self.self_attention(hidden_states, None) - assert config.recompute_granularity is None - # Check if output and bias have the correct shape - assert output.shape[0] == sequence_length - assert output.shape[1] == micro_batch_size - assert output.shape[2] == config.hidden_size - assert bias.shape[0] == config.hidden_size - - @pytest.mark.internal - def test_self_attention_mpu(self): - - tp_size = 4 - cp_size = 2 - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp_size, context_parallel_size=cp_size - ) - model_parallel_cuda_manual_seed(123) - - # Get TP and CP process groups from device mesh - tp_group = parallel_state.get_tensor_model_parallel_group() - cp_group = parallel_state.get_context_parallel_group() - - model_comm_pgs = ModelCommProcessGroups(tp=tp_group, cp=cp_group) - - self.run_self_attention(model_comm_pgs) - - @pytest.mark.skipif( - version.parse(torch.__version__) < version.parse('2.3.0'), - reason="Device mesh feature requires PyTorch 2.3 or later", - ) - @pytest.mark.internal - def test_self_attention_independent_pg_smoke(self): - - tp_size = 4 - cp_size = 2 - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp_size, context_parallel_size=cp_size - ) - model_parallel_cuda_manual_seed(123) - - # Initialize torch.distributed if not already initialized - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend='nccl') - - # Create HyperCommGrid with dimensions cp, tp (reversed from device mesh order) - grid = HyperCommGrid([cp_size, tp_size], ["cp", "tp"]) - - # Get TP and CP process groups from HyperCommGrid - tp_group = grid.create_pg("tp") - cp_group = grid.create_pg("cp") - - model_comm_pgs = ModelCommProcessGroups(tp=tp_group, cp=cp_group) - - self.run_self_attention(model_comm_pgs) diff --git a/tests/unit_tests/transformer/test_attention_no_rope.py b/tests/unit_tests/transformer/test_attention_no_rope.py deleted file mode 100644 index 30e11609e5..0000000000 --- a/tests/unit_tests/transformer/test_attention_no_rope.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import pytest -import torch - -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.attention import SelfAttention -from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.transformer_config import TransformerConfig -from tests.unit_tests.test_utilities import Utils - - -class TestParallelAttentionWithNoRope: - - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - # use BF16 and a large enough hidden size to enable FlashAttention - self.transformer_config = TransformerConfig( - num_layers=8, # Using 8 layers to test patterns like [0,0,0,1,0,0,0,1] - hidden_size=64, - num_attention_heads=4, - use_cpu_initialization=True, - bf16=True, - params_dtype=torch.bfloat16, - pipeline_dtype=torch.bfloat16, - autocast_dtype=torch.bfloat16, - flash_decode=False, # Ensure flash_decode is off to test RoPE skipping - ) - self.parallel_attention = SelfAttention( - self.transformer_config, - get_gpt_layer_with_transformer_engine_spec().submodules.self_attention.submodules, - layer_number=1, - attn_mask_type=AttnMaskType.causal, - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_integer_no_rope_freq_pattern(self): - """Test that integer no_rope value is correctly converted to pattern.""" - config = self.transformer_config - config.no_rope_freq = 4 # Should convert to [0,0,0,1,0,0,0,1] - config.__post_init__() - - # Verify the pattern conversion - assert isinstance(config.no_rope_freq, list) - assert len(config.no_rope_freq) == config.num_layers - assert config.no_rope_freq == [0, 0, 0, 1, 0, 0, 0, 1] - - def test_custom_no_rope_pattern(self): - """Test custom no_rope pattern.""" - config = self.transformer_config - config.no_rope_freq = [0, 1, 0, 1, 0, 1, 0, 1] # Custom pattern - config.__post_init__() - - # Verify the pattern is preserved - assert isinstance(config.no_rope_freq, list) - assert len(config.no_rope_freq) == config.num_layers - assert config.no_rope_freq == [0, 1, 0, 1, 0, 1, 0, 1] - - def test_gpu_forward_with_no_rope(self): - """Test forward pass with no_rope pattern.""" - config = self.parallel_attention.config - config.no_rope_freq = 4 # Use pattern [0,0,0,1,0,0,0,1] - config.__post_init__() # Ensure pattern is converted - - sequence_length = 32 - micro_batch_size = 1 - - self.parallel_attention.cuda() - - # [sequence length, batch size, hidden size] - hidden_states = torch.randn( - (sequence_length, micro_batch_size, self.parallel_attention.config.hidden_size) - ) - hidden_states = hidden_states.cuda().to(torch.bfloat16) - - attention_mask = None - - # Create rotary position embeddings - # Shape: [seq_len, 1, 1, kv_channels] - rotary_pos_emb = torch.randn( - sequence_length, 1, 1, self.parallel_attention.config.kv_channels - ).cuda() - - # For self-attention, rotary_pos_emb needs to be a tuple of (q_pos_emb, k_pos_emb) - rotary_pos_emb = (rotary_pos_emb, rotary_pos_emb) - - # Test with layer 3 which should skip RoPE - self.parallel_attention.layer_number = 3 - # Run forward pass without RoPE - output_without_rope, _ = self.parallel_attention( - hidden_states, attention_mask, rotary_pos_emb=rotary_pos_emb - ) - - # Test with layer 0 which should NOT skip RoPE - self.parallel_attention.layer_number = 0 - # Run forward pass with RoPE (but should be skipped for this layer) - output_with_rope, bias = self.parallel_attention( - hidden_states, attention_mask, rotary_pos_emb=rotary_pos_emb - ) - - # Verify RoPE was skipped for this layer - # If RoPE was skipped, outputs should be the same - assert not torch.allclose( - output_without_rope, output_with_rope - ), "Outputs are expected to be different." - - # Verify output shapes - assert config.recompute_granularity is None - assert output_with_rope.shape[0] == sequence_length - assert output_with_rope.shape[1] == micro_batch_size - assert output_with_rope.shape[2] == config.hidden_size - assert bias.shape[0] == config.hidden_size - - def test_invalid_no_rope_freq_pattern(self): - """Test invalid no_rope patterns raise appropriate errors.""" - config = self.transformer_config - - # Test invalid integer pattern - with pytest.raises(AssertionError): - config.no_rope_freq = 3 # Not divisible by num_layers=8 - config.__post_init__() - - # Test invalid list pattern - with pytest.raises(AssertionError): - config.no_rope_freq = [0, 1, 0, 1] # Wrong length - config.__post_init__() - - def test_gpu_forward_no_rope_freq_not_specified(self): - """Test forward pass with no_rope pattern not provided.""" - config = self.parallel_attention.config - config.no_rope_freq = None - config.__post_init__() # Ensure pattern is converted - - sequence_length = 32 - micro_batch_size = 1 - - self.parallel_attention.cuda() - - # [sequence length, batch size, hidden size] - hidden_states = torch.randn( - (sequence_length, micro_batch_size, self.parallel_attention.config.hidden_size) - ) - hidden_states = hidden_states.cuda().to(torch.bfloat16) - - attention_mask = None - - # Create rotary position embeddings - # Shape: [seq_len, 1, 1, kv_channels] - rotary_pos_emb = torch.randn( - sequence_length, 1, 1, self.parallel_attention.config.kv_channels - ).cuda() - - # For self-attention, rotary_pos_emb needs to be a tuple of (q_pos_emb, k_pos_emb) - rotary_pos_emb = (rotary_pos_emb, rotary_pos_emb) - # Run forward pass - output, bias = self.parallel_attention( - hidden_states, attention_mask, rotary_pos_emb=rotary_pos_emb - ) - # Verify output shapes - assert config.recompute_granularity is None - assert output.shape[0] == sequence_length - assert output.shape[1] == micro_batch_size - assert output.shape[2] == config.hidden_size - assert bias.shape[0] == config.hidden_size - - def test_checkpointed_gpu_forward(self): - """Test checkpointed forward pass with no_rope pattern.""" - transformer_config = self.transformer_config - transformer_config.recompute_granularity = 'selective' - transformer_config.no_rope_freq = 4 # Use pattern [0,0,0,1,0,0,0,1] - transformer_config.__post_init__() - - checkpointed_parallel_attention = SelfAttention( - transformer_config, - get_gpt_layer_with_transformer_engine_spec().submodules.self_attention.submodules, - layer_number=1, - attn_mask_type=AttnMaskType.causal, - ) - config = checkpointed_parallel_attention.config - - sequence_length = 32 - micro_batch_size = 1 - - checkpointed_parallel_attention.cuda() - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones( - (sequence_length, micro_batch_size, checkpointed_parallel_attention.config.hidden_size) - ) - hidden_states = hidden_states.cuda().to(torch.bfloat16) - - attention_mask = None - rotary_pos_emb = torch.ones( - sequence_length, 1, 1, checkpointed_parallel_attention.config.kv_channels - ).cuda() - - output, bias = checkpointed_parallel_attention( - hidden_states, attention_mask, rotary_pos_emb=rotary_pos_emb - ) - - assert config.recompute_granularity == 'selective' - assert output.shape[0] == sequence_length - assert output.shape[1] == micro_batch_size - assert output.shape[2] == config.hidden_size - assert bias.shape[0] == config.hidden_size - - def test_flash_decode_with_no_rope_freq(self): - """Test that flash_decode cannot be used with no_rope.""" - config = self.transformer_config - config.flash_decode = True - config.no_rope_freq = 4 # Use pattern [0,0,0,1,0,0,0,1] - - # Verify that setting both flash_decode and no_rope raises an assertion error - with pytest.raises(AssertionError, match="flash_decode cannot be used with no_rope"): - config.__post_init__() diff --git a/tests/unit_tests/transformer/test_attention_packed_seq.py b/tests/unit_tests/transformer/test_attention_packed_seq.py deleted file mode 100644 index e6e2c84739..0000000000 --- a/tests/unit_tests/transformer/test_attention_packed_seq.py +++ /dev/null @@ -1,199 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import pytest -import torch - -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec -from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.attention import SelfAttention -from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import is_te_min_version -from tests.unit_tests.test_utilities import Utils - - -def make_test_packed_seq_params(sequence_length): - cu_seqlens = torch.IntTensor([0, 6, 19, 22, sequence_length]).cuda() - seqlens = cu_seqlens[1:] - cu_seqlens[:-1] - max_seqlen = seqlens.max().item() - packed_seq_params = PackedSeqParams( - cu_seqlens_q=cu_seqlens, - cu_seqlens_kv=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_kv=max_seqlen, - qkv_format='thd', - ) - return packed_seq_params - - -def make_test_packed_padded_seq_params(sequence_length): - cu_seqlens = torch.IntTensor([0, 18, 44, 52, 96, 118]).cuda() - cu_seqlens_padded = torch.IntTensor([0, 20, 48, 56, 100, sequence_length]).cuda() - seqlens = cu_seqlens_padded[1:] - cu_seqlens_padded[:-1] - max_seqlen = seqlens.max().item() - packed_seq_params = PackedSeqParams( - cu_seqlens_q=cu_seqlens, - cu_seqlens_kv=cu_seqlens, - cu_seqlens_q_padded=cu_seqlens_padded, - cu_seqlens_kv_padded=cu_seqlens_padded, - max_seqlen_q=max_seqlen, - max_seqlen_kv=max_seqlen, - qkv_format='thd', - ) - return packed_seq_params - - -class TestParallelAttentionWithPackedSequence: - - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - # use BF16 and a large enough hidden size to enable FlashAttention for thd format. - self.transformer_config = TransformerConfig( - num_layers=2, - hidden_size=64, - num_attention_heads=4, - use_cpu_initialization=True, - bf16=True, - params_dtype=torch.bfloat16, - pipeline_dtype=torch.bfloat16, - autocast_dtype=torch.bfloat16, - ) - self.parallel_attention = SelfAttention( - self.transformer_config, - get_gpt_layer_with_transformer_engine_spec().submodules.self_attention.submodules, - layer_number=1, - attn_mask_type=AttnMaskType.causal, - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_cpu_forward(self): - # we can't currently do this because the global memory buffer is on GPU - pass - - def test_gpu_forward(self): - - config = self.parallel_attention.config - sequence_length = 32 - micro_batch_size = 1 - - self.parallel_attention.cuda() - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones( - (sequence_length, micro_batch_size, self.parallel_attention.config.hidden_size) - ) - hidden_states = hidden_states.cuda().to(torch.bfloat16) - - attention_mask = None - - packed_seq_params = make_test_packed_seq_params(sequence_length) - output, bias = self.parallel_attention( - hidden_states, attention_mask, packed_seq_params=packed_seq_params - ) - - assert config.recompute_granularity is None - assert output.shape[0] == sequence_length - assert output.shape[1] == micro_batch_size - assert output.shape[2] == config.hidden_size - assert bias.shape[0] == config.hidden_size - - @pytest.mark.skipif(not is_te_min_version("1.4.0"), reason="Fused RoPE requires TE >= 1.4.0") - def test_fused_rope_gpu_forward(self): - self.parallel_attention.config.apply_rope_fusion = True - config = self.parallel_attention.config - sequence_length = 32 - micro_batch_size = 1 - - self.parallel_attention.cuda() - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones( - (sequence_length, micro_batch_size, self.parallel_attention.config.hidden_size) - ) - hidden_states = hidden_states.cuda().to(torch.bfloat16) - - attention_mask = None - rotary_pos_emb = torch.ones( - sequence_length, 1, 1, self.parallel_attention.config.kv_channels - ).cuda() - - packed_seq_params = make_test_packed_seq_params(sequence_length) - output, bias = self.parallel_attention( - hidden_states, attention_mask, packed_seq_params=packed_seq_params - ) - - assert config.recompute_granularity is None - assert output.shape[0] == sequence_length - assert output.shape[1] == micro_batch_size - assert output.shape[2] == config.hidden_size - assert bias.shape[0] == config.hidden_size - self.parallel_attention.config.apply_rope_fusion = False - - def test_checkpointed_gpu_forward(self): - transformer_config = self.transformer_config - transformer_config.recompute_granularity = 'selective' - checkpointed_parallel_attention = SelfAttention( - transformer_config, - get_gpt_layer_with_transformer_engine_spec().submodules.self_attention.submodules, - layer_number=1, - attn_mask_type=AttnMaskType.causal, - ) - config = checkpointed_parallel_attention.config - - sequence_length = 32 - micro_batch_size = 1 - - checkpointed_parallel_attention.cuda() - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones( - (sequence_length, micro_batch_size, checkpointed_parallel_attention.config.hidden_size) - ) - hidden_states = hidden_states.cuda().to(torch.bfloat16) - - attention_mask = None - - packed_seq_params = make_test_packed_seq_params(sequence_length) - output, bias = checkpointed_parallel_attention( - hidden_states, attention_mask, packed_seq_params=packed_seq_params - ) - - assert config.recompute_granularity == 'selective' - assert output.shape[0] == sequence_length - assert output.shape[1] == micro_batch_size - assert output.shape[2] == config.hidden_size - assert bias.shape[0] == config.hidden_size - - -# Note: this test requires TE >= 1.8 as well as cuDNN FusedAttention to run -class TestParallelAttentionWithPackedPaddedSequence(TestParallelAttentionWithPackedSequence): - - def test_gpu_forward(self): - - config = self.parallel_attention.config - sequence_length = 128 - micro_batch_size = 1 - - self.parallel_attention.cuda() - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones( - (sequence_length, micro_batch_size, self.parallel_attention.config.hidden_size) - ) - hidden_states = hidden_states.cuda().to(torch.bfloat16) - - attention_mask = None - - packed_seq_params = make_test_packed_padded_seq_params(sequence_length) - output, bias = self.parallel_attention( - hidden_states, attention_mask, packed_seq_params=packed_seq_params - ) - - assert config.recompute_granularity is None - assert output.shape[0] == sequence_length - assert output.shape[1] == micro_batch_size - assert output.shape[2] == config.hidden_size diff --git a/tests/unit_tests/transformer/test_core_attention.py b/tests/unit_tests/transformer/test_core_attention.py deleted file mode 100644 index d8710e2242..0000000000 --- a/tests/unit_tests/transformer/test_core_attention.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - - -import pytest -import torch - -from megatron.core.transformer.attention import CrossAttention - -""" - -@pytest.fixture -def core_attention(transformer_config): - return CrossAttention(transformer_config) - - -class TestCoreAttention: - def test_constructor(self, core_attention): - assert isinstance(core_attention, CrossAttention) - assert core_attention.layer_number == 1 - - num_weights = sum([p.numel() for p in core_attention.parameters()]) - assert num_weights == 0 - - def test_cpu_forward(self, core_attention): - # we can't currently do this because the global memory buffer is on GPU - pass - - def test_gpu_forward(self, core_attention): - - # destroy_global_memory_buffer() - # _set_global_memory_buffer() - # model_parallel_cuda_manual_seed(123) - - core_attention.cuda() - config = core_attention.config - sequence_length = 32 - micro_batch_size = 2 - # query_layer (float): [sequence_length, micro_batch_size, num_attention_heads, hidden_size / num_attention_heads] - query_layer = torch.ones( - ( - sequence_length, - micro_batch_size, - config.num_attention_heads, - config.hidden_size // config.num_attention_heads, - ) - ).cuda() - - key_layer = torch.ones_like(query_layer).cuda() - - value_layer = torch.ones_like(query_layer).cuda() - - attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() - - context_layer = core_attention( - query_layer=query_layer, key_layer=key_layer, value_layer=value_layer, attention_mask=attention_mask - ) - - assert context_layer.shape[0] == sequence_length - assert context_layer.shape[1] == micro_batch_size - assert context_layer.shape[2] == config.hidden_size - assert context_layer.device.type == 'cuda' - assert context_layer.dtype == torch.float32 - -""" diff --git a/tests/unit_tests/transformer/test_cuda_graphs.py b/tests/unit_tests/transformer/test_cuda_graphs.py deleted file mode 100644 index b764efc371..0000000000 --- a/tests/unit_tests/transformer/test_cuda_graphs.py +++ /dev/null @@ -1,717 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -import os -import random -import time -import types - -import pytest -import torch - -from megatron.core import parallel_state -from megatron.core.inference.contexts import DynamicInferenceContext -from megatron.core.inference.engines import DynamicInferenceEngine -from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import ( - GPTInferenceWrapper, -) -from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( - InferenceWrapperConfig, -) -from megatron.core.inference.sampling_params import SamplingParams -from megatron.core.inference.text_generation_controllers.text_generation_controller import ( - TextGenerationController, -) -from megatron.core.models.gpt.gpt_layer_specs import ( - get_gpt_layer_local_spec, - get_gpt_layer_with_transformer_engine_spec, -) -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec -from megatron.core.pipeline_parallel.schedules import set_current_microbatch -from megatron.core.process_groups_config import ModelCommProcessGroups -from megatron.core.ssm.mamba_block import MambaStack -from megatron.core.tensor_parallel.random import ( - HAVE_TE, - initialize_rng_tracker, - model_parallel_cuda_manual_seed, -) -from megatron.core.transformer.cuda_graphs import CudaGraphManager, _CudagraphGlobalRecord -from megatron.core.transformer.transformer_block import TransformerBlock -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import is_fa_min_version, is_te_min_version -from tests.unit_tests.test_utilities import Utils - - -class TestParallelTransformerBlockCudagraphs: - def setup_method(self, method): - # initialize parallel state - initialize_rng_tracker(use_te_rng_tracker=True, force_reset=True) - Utils.initialize_model_parallel( - tensor_model_parallel_size=2, pipeline_model_parallel_size=2 - ) - model_parallel_cuda_manual_seed(123) - - # initialize transformer model - num_layers = 8 - hidden_size = 64 - self.transformer_config = TransformerConfig( - num_layers=num_layers, - hidden_size=hidden_size, - num_attention_heads=4, - use_cpu_initialization=True, - enable_cuda_graph=True, - ) - self.parallel_transformer_block = TransformerBlock( - self.transformer_config, get_gpt_layer_with_transformer_engine_spec() - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - _CudagraphGlobalRecord.cudagraph_record = [] - CudaGraphManager.global_mempool = None - - @pytest.mark.skipif( - not (HAVE_TE and is_te_min_version("1.5.0")), - reason="use_te_rng_tracker requires TransformerEngine version >= 1.5", - ) - def test_gpu_cudagraph(self): - parallel_transformer_block = self.parallel_transformer_block - parallel_transformer_block.cuda() - - # [sequence length, batch size, hidden size] - sequence_length = 32 - micro_batch_size = 2 - transformer_config: TransformerConfig = parallel_transformer_block.config - num_layers = transformer_config.num_layers - hidden_size = transformer_config.hidden_size - hidden_states = torch.ones((sequence_length, micro_batch_size, hidden_size)) - hidden_states = hidden_states.cuda() - attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() - - hidden_states = parallel_transformer_block( - hidden_states=hidden_states, attention_mask=attention_mask - ) - - for _ in range(num_layers): - assert hasattr(parallel_transformer_block.layers[0], "cudagraph_manager") - assert ( - len(parallel_transformer_block.layers[0].cudagraph_manager.cudagraph_runners) == 1 - ) - del ( - parallel_transformer_block.layers[_] - .cudagraph_manager.cudagraph_runners[0] - .fwd_graph - ) - - -@pytest.mark.skipif( - not (HAVE_TE and is_te_min_version("1.5.0")), - reason="use_te_rng_tracker requires TransformerEngine version >= 1.5", -) -@pytest.mark.parametrize( - "total_num_layers, pp, vpp, account_for_embedding_in_pipeline_split, account_for_loss_in_pipeline_split, num_layers_in_first_pipeline_stage, num_layers_in_last_pipeline_stage, pp_layout, first_layer_numbers_golden, last_layer_numbers_golden", - [ - (4, 1, None, False, False, None, None, None, [1], [4]), - (8, 2, None, False, False, None, None, None, [1, 5], [4, 8]), - (8, 2, 2, False, False, None, None, None, [1, 3, 5, 7], [2, 4, 6, 8]), - (14, 4, None, True, True, None, None, None, [1, 4, 8, 12], [3, 7, 11, 14]), - ( - 14, - 4, - 2, - True, - True, - None, - None, - None, - [1, 2, 4, 6, 8, 10, 12, 14], - [1, 3, 5, 7, 9, 11, 13, 14], - ), - (12, 4, None, False, False, 2, 2, None, [1, 3, 7, 11], [2, 6, 10, 12]), - ( - 12, - 4, - 2, - False, - False, - 2, - 2, - None, - [1, 2, 4, 6, 7, 8, 10, 12], - [1, 3, 5, 6, 7, 9, 11, 12], - ), - ( - 14, - 4, - 2, - False, - False, - None, - None, - [ - ["embedding", "decoder"], - ["decoder", "decoder"], - ["decoder", "decoder"], - ["decoder", "decoder"], - ["decoder", "decoder"], - ["decoder", "decoder"], - ["decoder", "decoder"], - ["decoder", "loss"], - ], - [1, 2, 4, 6, 8, 10, 12, 14], - [1, 3, 5, 7, 9, 11, 13, 14], - ), - ], -) -def test_cuda_graph_determine_first_last_layer_logic( - total_num_layers, - pp, - vpp, - account_for_embedding_in_pipeline_split, - account_for_loss_in_pipeline_split, - num_layers_in_first_pipeline_stage, - num_layers_in_last_pipeline_stage, - pp_layout, - first_layer_numbers_golden, - last_layer_numbers_golden, -): - # Initialize RNG tracker - initialize_rng_tracker(use_te_rng_tracker=True, force_reset=True) - - # Initialize parallel state - Utils.initialize_model_parallel( - pipeline_model_parallel_size=pp, virtual_pipeline_model_parallel_size=vpp - ) - - # initialize model - torch.manual_seed(123) - model_parallel_cuda_manual_seed(123) - hidden_size = 128 - transformer_config = TransformerConfig( - num_layers=total_num_layers, - hidden_size=hidden_size, - num_attention_heads=1, - use_cpu_initialization=True, - pipeline_dtype=torch.bfloat16, - bf16=True, - virtual_pipeline_model_parallel_size=vpp, - pipeline_model_parallel_size=pp, - deallocate_pipeline_outputs=True, - enable_cuda_graph=True, - use_te_rng_tracker=True, - account_for_embedding_in_pipeline_split=account_for_embedding_in_pipeline_split, - account_for_loss_in_pipeline_split=account_for_loss_in_pipeline_split, - num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage, - num_layers_in_last_pipeline_stage=num_layers_in_last_pipeline_stage, - pipeline_model_parallel_layout=pp_layout, - ) - model = [] - for i in range(vpp or 1): - this_model = GPTModel( - config=transformer_config, - transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec(), - vocab_size=128, - max_sequence_length=1024, - position_embedding_type="rope", - vp_stage=i, - ).cuda() - model.append(this_model) - - # create runner by running a fake forward pass - sequence_length, micro_batch_size = 32, 1 - hidden_states = torch.ones((sequence_length, micro_batch_size, hidden_size)).cuda() - attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() - for m in model: - _ = m( - input_ids=None, - position_ids=None, - attention_mask=attention_mask, - decoder_input=hidden_states, - ) - - # Check if cuda graph is correctly setting is first/last layer - for m in model: - for l in m.decoder.layers: - assert hasattr(l, "cudagraph_manager") - assert ( - len(l.cudagraph_manager.cudagraph_runners) == 1 - ), "Cuda graph runner should be created" - runner = l.cudagraph_manager.cudagraph_runners[0] - assert runner.is_first_layer is not None and runner.is_last_layer is not None - assert runner.is_first_layer == (l.layer_number in first_layer_numbers_golden) - assert runner.is_last_layer == (l.layer_number in last_layer_numbers_golden) - - del l.cudagraph_manager.cudagraph_runners[0].fwd_graph - - # Destroy all captured graphs deterministically - for m in model: - for l in m.decoder.layers: - for runner in getattr(l.cudagraph_manager, "cudagraph_runners", []): - # Safely delete both graphs if present - if hasattr(runner, "fwd_graph"): - del runner.fwd_graph - if hasattr(runner, "bwd_graph"): - del runner.bwd_graph - - # Ensure all pending work is complete and graph destruction runs now - torch.cuda.synchronize() - - # Teardown - Utils.destroy_model_parallel() - _CudagraphGlobalRecord.cudagraph_record = [] - CudaGraphManager.global_mempool = None - CudaGraphManager.fwd_mempools = None - CudaGraphManager.bwd_mempools = None - - -class TestLLaVACudaGraph: - """Test CUDA graphs with LLaVA model focusing on is_last_layer logic for encoder/decoder transitions.""" - - def setup_method(self, method): - # Initialize parallel state - initialize_rng_tracker(use_te_rng_tracker=True, force_reset=True) - Utils.initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, - virtual_pipeline_model_parallel_size=None, - ) - model_parallel_cuda_manual_seed(123) - - from copy import deepcopy - - from megatron.core.models.multimodal.llava_model import LLaVAModel - from megatron.core.models.vision.vit_layer_specs import ( - get_vit_layer_with_transformer_engine_spec, - ) - - # Create language transformer config with CUDA graphs enabled - self.language_hidden_size = 64 - self.language_num_attention_heads = 4 - language_config = TransformerConfig( - num_layers=2, - hidden_size=self.language_hidden_size, - num_attention_heads=self.language_num_attention_heads, - use_cpu_initialization=True, - enable_cuda_graph=True, # Enable CUDA graphs - ) - - # Create vision transformer config - vision_config = TransformerConfig( - num_layers=2, - hidden_size=16, - num_attention_heads=2, - use_cpu_initialization=True, - enable_cuda_graph=True, # Enable CUDA graphs for vision model too - ) - - # Create vision projection config - vision_projection_config = TransformerConfig( - num_layers=1, - hidden_size=self.language_hidden_size, - ffn_hidden_size=32, - num_attention_heads=1, - use_cpu_initialization=True, - ) - - # Get layer specs - language_layer_spec = get_gpt_layer_with_transformer_engine_spec() - vision_layer_spec = get_vit_layer_with_transformer_engine_spec() - vision_projection_spec = deepcopy(language_layer_spec.submodules.mlp.submodules) - - # Set vision model type - vision_config.vision_model_type = "clip" - language_config.language_model_type = "dummy" - - # Create LLaVA model with both encoder and decoder - self.llava_model = LLaVAModel( - language_transformer_config=language_config, - language_transformer_layer_spec=language_layer_spec, - language_vocab_size=8192, - language_max_sequence_length=4096, - vision_transformer_config=vision_config, - vision_transformer_layer_spec=vision_layer_spec, - drop_vision_class_token=False, - vision_projection_config=vision_projection_config, - vision_projection_layer_spec=vision_projection_spec, - img_h=336, - img_w=336, - patch_dim=14, - pre_process=True, - post_process=True, - add_encoder=True, - add_decoder=True, - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - _CudagraphGlobalRecord.cudagraph_record = [] - - @pytest.mark.skipif( - not (HAVE_TE and is_te_min_version("1.5.0")), - reason="use_te_rng_tracker requires TransformerEngine version >= 1.5", - ) - def test_llava_cudagraph_is_last_layer_logic(self): - """Test that is_last_layer logic correctly resets prev_bwd_hidden_state_inputgrad for LLaVA models.""" - - # Move model to CUDA - self.llava_model.cuda() - - set_current_microbatch(self.llava_model.vision_model, 1) - set_current_microbatch(self.llava_model.language_model, 1) - - # Create test inputs - batch_size = 2 - seq_length = 1024 - num_images = 1 - - images = torch.ones((num_images, 3, 336, 336), dtype=torch.float32).cuda() - - # Create text input with image tokens - input_ids = torch.randint(0, 1000, (batch_size, seq_length), dtype=torch.long).cuda() - # Insert image token (using default image token index) - input_ids[0, 5] = self.llava_model.image_token_index - - position_ids = torch.arange(seq_length).unsqueeze(0).expand(batch_size, -1).cuda() - attention_mask = None - - # Create labels and loss mask for training - labels = torch.randint(0, 1000, (batch_size, seq_length), dtype=torch.long).cuda() - loss_mask = torch.ones((batch_size, seq_length), dtype=torch.float32).cuda() - - # Create num_image_tiles - num_image_tiles = torch.ones(num_images, dtype=torch.int).cuda() - - # First forward pass - this should record the CUDA graphs - output1, loss_mask1 = self.llava_model( - images=images, - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - labels=labels, - loss_mask=loss_mask, - num_image_tiles=num_image_tiles, - ) - - # Verify that CUDA graph managers were created - if hasattr(self.llava_model.vision_model, 'decoder') and hasattr( - self.llava_model.vision_model.decoder, 'layers' - ): - for layer in self.llava_model.vision_model.decoder.layers: - if hasattr(layer, 'cudagraph_manager'): - assert ( - layer.cudagraph_manager is not None - ), "Vision model layers should have CUDA graph managers" - - if hasattr(self.llava_model.language_model, 'decoder') and hasattr( - self.llava_model.language_model.decoder, 'layers' - ): - for layer in self.llava_model.language_model.decoder.layers: - if hasattr(layer, 'cudagraph_manager'): - assert ( - layer.cudagraph_manager is not None - ), "Language model layers should have CUDA graph managers" - - # Verify that CUDA graphs were created successfully - for runner in layer.cudagraph_manager.cudagraph_runners: - assert hasattr(runner, 'fwd_graph') - assert hasattr(runner, 'bwd_graph') - - # Perform backward pass to trigger backward graph recording - if isinstance(output1, tuple): - loss = output1[0].sum() - else: - loss = output1.sum() - loss.backward() - - if hasattr(self.llava_model.vision_model, 'decoder') and hasattr( - self.llava_model.vision_model.decoder, 'layers' - ): - for layer in self.llava_model.vision_model.decoder.layers: - del layer.cudagraph_manager.cudagraph_runners[0].fwd_graph - del layer.cudagraph_manager.cudagraph_runners[0].bwd_graph - - if hasattr(self.llava_model.language_model, 'decoder') and hasattr( - self.llava_model.language_model.decoder, 'layers' - ): - for layer in self.llava_model.language_model.decoder.layers: - del layer.cudagraph_manager.cudagraph_runners[0].fwd_graph - del layer.cudagraph_manager.cudagraph_runners[0].bwd_graph - - -class TestParallelMambaBlockCudagraphs: - def setup_method(self, method): - # initialize parallel state - initialize_rng_tracker(use_te_rng_tracker=True, force_reset=True) - Utils.initialize_model_parallel(tensor_model_parallel_size=2) - model_parallel_cuda_manual_seed(123) - - # Ensure that this test is capturing to a fresh memory pool. - CudaGraphManager.global_mempool = None - - def get_model_comm_pgs(): - return ModelCommProcessGroups.use_mpu_process_groups(required_pgs=['tp', 'pp', 'cp']) - - def get_mamba_block(hybrid_override_pattern): - transformer_config = TransformerConfig( - hidden_size=256, # The Mamba layer places several constraints on this - # Need to specify num_attention_heads and num_layers or TransformerConfig - # will generate errors. - num_layers=len(hybrid_override_pattern), - num_attention_heads=4, - use_cpu_initialization=True, - enable_cuda_graph=True, - ) - modules = mamba_stack_spec.submodules - return MambaStack( - transformer_config, - modules, - hybrid_override_pattern=hybrid_override_pattern, - model_comm_pgs=get_model_comm_pgs(), - ) - - self.mamba_block = get_mamba_block(hybrid_override_pattern="M-M*-") - self.transformer_config = self.mamba_block.config - - def teardown_method(self, method): - Utils.destroy_model_parallel() - _CudagraphGlobalRecord.cudagraph_record = [] - - @pytest.mark.skipif( - not (HAVE_TE and is_te_min_version("1.5.0")), - reason="use_te_rng_tracker requires TransformerEngine version >= 1.5", - ) - def test_gpu_cudagraph(self): - parallel_mamba_block = self.mamba_block - parallel_mamba_block.cuda() - - # [sequence length, batch size, hidden size] - sequence_length = 32 - micro_batch_size = 2 - transformer_config: TransformerConfig = parallel_mamba_block.config - num_layers = transformer_config.num_layers - hidden_size = transformer_config.hidden_size - hidden_states = torch.ones((sequence_length, micro_batch_size, hidden_size)) - hidden_states = hidden_states.cuda() - attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() - - hidden_states = parallel_mamba_block( - hidden_states=hidden_states, attention_mask=attention_mask - ) - - for _ in range(num_layers): - assert hasattr(parallel_mamba_block.layers[0], "cudagraph_manager") - assert len(parallel_mamba_block.layers[0].cudagraph_manager.cudagraph_runners) == 1 - - del parallel_mamba_block.layers[_].cudagraph_manager.cudagraph_runners[0].fwd_graph - - -class TestCaptureFreezeGC: - - def capture_cuda_graphs(self, cuda_graph_capture_freeze_gc: bool) -> None: - """Capture multiple cuda graphs by initializing the `DynamicInferenceEngine`. - - The `DynamicInferenceEngine` is used here because it is currently (as of - August 2025) one of the heaviest users of multiple cuda graphs, and so - its setup tests a realistic use-case of multi-batch size cuda graphs. - - Args: - cuda_graph_capture_freeze_gc (bool): Flag that determines whether to - freeze garbage collection. - """ - - # Set freeze-gc environment variable. - os.environ["CUDA_GRAPH_CAPTURE_FREEZE_GC"] = str(int(cuda_graph_capture_freeze_gc)) - - # Configuration. - random_seed = 123 - vocab_size = 100 - num_tokens_to_prompt = 128 - num_tokens_to_generate = 32 - max_sequence_length = num_tokens_to_prompt + num_tokens_to_generate - num_cuda_graphs = 4 - - # Rounder values. - rounder = 4 - DynamicInferenceContext.ROUNDER = rounder # For backwards compatibility - DynamicInferenceContext.TOKEN_ROUNDER = rounder - DynamicInferenceContext.REQUEST_ROUNDER = rounder - - # Random state. - random.seed(random_seed) - torch.manual_seed(random_seed) - model_parallel_cuda_manual_seed( - seed=random_seed, - inference_rng_tracker=True, - use_cudagraphable_rng=False, - force_reset_rng=True, - ) - - # Transformer config. - transformer_config = TransformerConfig( - params_dtype=torch.bfloat16, - num_layers=4, - hidden_size=32, - num_attention_heads=4, - use_cpu_initialization=True, - enable_cuda_graph=True, - inference_rng_tracker=True, - tensor_model_parallel_size=1, # needed? - ) - - # Sampling params. - sampling_params = SamplingParams(num_tokens_to_generate=num_tokens_to_generate) - - # GPT model. - model = GPTModel( - config=transformer_config, - transformer_layer_spec=get_gpt_layer_local_spec(), - vocab_size=vocab_size, - max_sequence_length=max_sequence_length, - parallel_output=True, - ).cuda() - - for param in model.parameters(): - param.data = param.data.to(transformer_config.params_dtype) - - model.eval() - - # Inference config. - inference_config = InferenceWrapperConfig( - hidden_size=transformer_config.hidden_size, - inference_batch_times_seqlen_threshold=400, - fp32_residual_connection=False, - params_dtype=transformer_config.params_dtype, - padded_vocab_size=vocab_size, - ) - - # Inference context. - context = DynamicInferenceContext( - params_dtype=transformer_config.params_dtype, - num_layers=transformer_config.num_layers, - kv_channels=transformer_config.kv_channels, - num_attention_heads=transformer_config.num_query_groups, - max_sequence_length=max_sequence_length, - num_cuda_graphs=num_cuda_graphs, - buffer_size_gb=20, - buffer_guaranteed_fraction=0.05, - chunk_size_tokens=256, - buffer_overflow_factor=1.1, - max_requests_override=512, - max_tokens_override=8196, - tensor_model_parallel_size=transformer_config.tensor_model_parallel_size, - ) - - # Inference model wrapper. - inference_wrapped_model = GPTInferenceWrapper(model, inference_config, context) - - # Note: the following is taken from AbstractModelInferenceWrapper.prep_model_for_inference(). - inference_wrapped_model.model_is_pipeline_parallel = not ( - parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage() - ) - - # Text generation controller. - text_generation_controller = TextGenerationController( - inference_wrapped_model=inference_wrapped_model, - tokenizer=types.SimpleNamespace(vocab_size=vocab_size), - ) - - # Inference engine. - engine = DynamicInferenceEngine( - text_generation_controller, - context, - termination_id=vocab_size - 1, - random_seed=random_seed, - ) - - return engine.capture_stats - - @pytest.mark.experimental - @pytest.mark.skipif( - not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" - ) - def test_capture_freeze_gc(self): - """Test cuda graph capture while freezing the GC.""" - - Utils.initialize_model_parallel( - tensor_model_parallel_size=1, pipeline_model_parallel_size=1 - ) - - # Run tests with GC freeze off/on. - result_map = {} - for freeze_gc in (False, True): - - # Reset global cuda graph state. - _CudagraphGlobalRecord.cudagraph_created = False - _CudagraphGlobalRecord.cudagraph_record = [] - CudaGraphManager.global_mempool = None - - # Capture multiple cuda graphs by initializing DynamicInferenceEngine. - mem_stats_start = torch.cuda.memory_stats() - time_start = time.time() - internal_stats = self.capture_cuda_graphs(freeze_gc) - time_end = time.time() - mem_stats_end = torch.cuda.memory_stats() - - # Track local (external) stats, in addition to internal stats. - external_stats = { - "time": time_end - time_start, - "allocated_bytes": ( - mem_stats_end["allocated_bytes.all.current"] - - mem_stats_start["allocated_bytes.all.current"] - ), - "reserved_bytes": ( - mem_stats_end["reserved_bytes.all.current"] - - mem_stats_start["reserved_bytes.all.current"] - ), - } - - # Record results. - result_map[freeze_gc] = {"internal": internal_stats, "external": external_stats} - - # Extract results. - freeze_off_results = result_map[False] - freeze_on_results = result_map[True] - print( - "test capture | freeze off: internal %.3f, external %.3f." - % (freeze_off_results["internal"]["time"], freeze_off_results["external"]["time"]) - ) - print( - "test capture | freeze on: internal %.3f, external %.3f." - % (freeze_on_results["internal"]["time"], freeze_on_results["external"]["time"]) - ) - - # Validate time and memory usage. - assert freeze_on_results["internal"]["time"] < 0.2 * freeze_off_results["internal"]["time"] - assert freeze_on_results["external"]["time"] < 0.2 * freeze_off_results["external"]["time"] - assert ( - freeze_on_results["internal"]["allocated_bytes"] - <= freeze_off_results["internal"]["allocated_bytes"] - ) - assert ( - freeze_on_results["external"]["allocated_bytes"] - <= freeze_off_results["external"]["allocated_bytes"] - ) - assert ( - freeze_on_results["internal"]["reserved_bytes"] - <= freeze_off_results["internal"]["reserved_bytes"] - ) - assert ( - freeze_on_results["external"]["reserved_bytes"] - <= freeze_off_results["external"]["reserved_bytes"] - ) - - -if __name__ == "__main__": - - test = TestParallelTransformerBlockCudagraphs() - test.setup_method(method=None) - test.test_gpu_cudagraph() - test.teardown_method(method=None) - - llava_test = TestLLaVACudaGraph() - llava_test.setup_method(method=None) - llava_test.test_llava_cudagraph_is_last_layer_logic() - llava_test.teardown_method(method=None) - - test = TestCaptureFreezeGC() - test.test_capture_freeze_gc() diff --git a/tests/unit_tests/transformer/test_full_cuda_graph.py b/tests/unit_tests/transformer/test_full_cuda_graph.py deleted file mode 100644 index 312ae46730..0000000000 --- a/tests/unit_tests/transformer/test_full_cuda_graph.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -import pytest -import torch -from pytest_mock import mocker - -import megatron.core.pipeline_parallel.schedules as schedule -from megatron.core import ModelParallelConfig -from megatron.core.full_cuda_graph import FullCudaGraphWrapper -from megatron.core.tensor_parallel.random import ( - HAVE_TE, - initialize_rng_tracker, - model_parallel_cuda_manual_seed, -) -from megatron.core.utils import is_te_min_version -from tests.unit_tests.test_utilities import Utils - -rank = Utils.rank - - -@pytest.mark.skipif( - not (HAVE_TE and is_te_min_version("1.5.0")), - reason="use_te_rng_tracker requires TransformerEngine version >= 1.5", -) -def test_forward_backward_func_with_full_cuda_graph(mocker): - from megatron.core.pipeline_parallel import get_forward_backward_func - - initialize_rng_tracker(use_te_rng_tracker=True, force_reset=True) - Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=1) - - def forward_step_func(data_iterator, model): - import os - - rank = int(os.environ['LOCAL_RANK']) - dummy_data = torch.ones(1, 4) - - def loss_func(output_tensor): - return rank, {'loss_reduced': rank} - - return model(dummy_data), loss_func - - model = torch.nn.Linear(4, 1) - - model.model_type = 'unit-test' - - def set_input_tensor(input_tensor): - return None - - model.set_input_tensor = set_input_tensor - - forward_backward_func = get_forward_backward_func() - assert schedule.get_forward_backward_func() == schedule.forward_backward_no_pipelining - - # Wrapping the forward_backward_func with FullCudaGraphWrapper enables full iteration CUDA graphs. - forward_backward_func = FullCudaGraphWrapper(forward_backward_func) - mocker.patch("megatron.core.pipeline_parallel.schedules.custom_backward", return_value=2) - config = ModelParallelConfig(pipeline_model_parallel_size=1) - model.config = config - - num_microbatches = 4 - - # CUDA graph warmup - losses_reduced = forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=[iter([{'input': torch.ones(1, 4)}] * num_microbatches)], - model=[model], - num_microbatches=num_microbatches, - seq_length=None, - micro_batch_size=None, - forward_only=True, - ) - # CUDA graph capture and replay - losses_reduced = forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=[iter([{'input': torch.ones(1, 4)}] * num_microbatches)], - model=[model], - num_microbatches=num_microbatches, - seq_length=None, - micro_batch_size=None, - forward_only=True, - ) - loss_reduced_expected = [ - {'loss_reduced': rank}, - {'loss_reduced': rank}, - {'loss_reduced': rank}, - {'loss_reduced': rank}, - ] - - for i, j in zip(losses_reduced, loss_reduced_expected): - print(losses_reduced) - assert i['loss_reduced'] == j['loss_reduced'] - Utils.destroy_model_parallel() diff --git a/tests/unit_tests/transformer/test_mlp.py b/tests/unit_tests/transformer/test_mlp.py deleted file mode 100644 index d2c25e0cc5..0000000000 --- a/tests/unit_tests/transformer/test_mlp.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import pytest -import torch - -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.mlp import MLP -from megatron.core.transformer.transformer_config import TransformerConfig -from tests.unit_tests.test_utilities import Utils - - -class TestParallelMLP: - - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - transformer_config = TransformerConfig( - num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True - ) - self.mlp = MLP(transformer_config, get_gpt_layer_local_spec().submodules.mlp.submodules) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_constructor(self): - assert isinstance(self.mlp, MLP) - - num_weights = sum([p.numel() for p in self.mlp.parameters()]) - assert num_weights == 1212 - - """ - def test_cpu_forward(self, mlp): - # [sequence length, micro batch size, hidden size] - hidden_states = torch.ones((32, 2, mlp.config.hidden_size)) - output, output_bias = mlp(hidden_states) - assert output.shape[0] == 32 - assert output.shape[1] == 2 - assert output.shape[2] == mlp.config.hidden_size - assert output_bias.shape[0] == mlp.config.hidden_size - assert output.dtype == torch.float32 - """ - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_gpu_forward(self): - mlp = self.mlp - mlp.cuda() - # [sequence length, batch size, hidden size] - hidden_states = torch.ones((32, 2, mlp.config.hidden_size)) - hidden_states = hidden_states.cuda() - output, output_bias = mlp(hidden_states) - assert output.shape[0] == 32 - assert output.shape[1] == 2 - assert output.shape[2] == mlp.config.hidden_size - assert output_bias.shape[0] == mlp.config.hidden_size - assert output.dtype == torch.float32 - assert output.device.type == 'cuda' - assert output_bias.device.type == 'cuda' diff --git a/tests/unit_tests/transformer/test_module.py b/tests/unit_tests/transformer/test_module.py deleted file mode 100644 index 64826a0ee5..0000000000 --- a/tests/unit_tests/transformer/test_module.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import pytest -import torch - -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.module import Float16Module, MegatronModule -from megatron.core.transformer.transformer_config import TransformerConfig -from tests.unit_tests.test_utilities import Utils - -DEVICE_CAPABILITY = None -if torch.cuda.is_available(): - DEVICE_CAPABILITY = torch.cuda.get_device_capability() - - -class DummyModule(MegatronModule): - # def __init__(self, config: TransformerConfig, share_embeddings_and_output_weights=True): - def __init__(self, config: TransformerConfig): - super().__init__(config) - - self.linear = torch.nn.modules.Linear(in_features=2, out_features=1) - - def forward(self, x): - return self.linear(x) - - -class TestMegatronModule: - - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - transformer_config = TransformerConfig( - num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True - ) - self.megatron_module = DummyModule(config=transformer_config).cuda() - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_megatron_module(self): - megatron_module = self.megatron_module - assert megatron_module - assert megatron_module.config.hidden_size == 12 - assert megatron_module.config.ffn_hidden_size == 48 - assert megatron_module.linear.weight.dtype == torch.float32 - - x = torch.ones((2, 2)).cuda() - assert megatron_module(x).dtype == torch.float32 - - # TODO: test bad configs actually fail - # failed_module = megatron_module - # failed_module.fp16 = True - # failed_module.bf16 = True - - -class TestFloat16Module: - - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - self.transformer_config = TransformerConfig( - num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True - ) - self.megatron_module = DummyModule(config=self.transformer_config).cuda() - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_fp16_module(self): - transformer_config = self.transformer_config - megatron_module = self.megatron_module - transformer_config.fp16 = True - fp16_module = Float16Module(config=transformer_config, module=megatron_module) - - assert fp16_module - assert fp16_module.config.hidden_size == 12 - assert fp16_module.config.ffn_hidden_size == 48 - assert fp16_module.module.linear.weight.dtype == torch.float16 - - x = torch.ones((2, 2)).cuda() - # inputs are converted to fp16 then outputs are converted to fp32 - assert fp16_module(x).dtype == torch.float32 - - pytest.mark.skipif( - not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8, - reason='bfloat16 is not supported on this device', - ) - - def test_bf16_module(self): - transformer_config = self.transformer_config - megatron_module = self.megatron_module - transformer_config.bf16 = True - bf16_module = Float16Module(config=transformer_config, module=megatron_module) - - assert bf16_module - assert bf16_module.config.hidden_size == 12 - assert bf16_module.config.ffn_hidden_size == 48 - assert bf16_module.module.linear.weight.dtype == torch.bfloat16 - - x = torch.ones((2, 2)).cuda() - # inputs are converted to bf16 then outputs are converted to fp32 - assert bf16_module(x).dtype == torch.float32 diff --git a/tests/unit_tests/transformer/test_multi_latent_attention.py b/tests/unit_tests/transformer/test_multi_latent_attention.py deleted file mode 100644 index dd57adf378..0000000000 --- a/tests/unit_tests/transformer/test_multi_latent_attention.py +++ /dev/null @@ -1,1272 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import os -from functools import partial -from importlib.metadata import version -from inspect import signature -from unittest import mock - -import pytest -import torch -import transformer_engine as te - -from megatron.core import parallel_state -from megatron.core.extensions.transformer_engine_spec_provider import TESpecProvider -from megatron.core.models.common.embeddings.rope_utils import ( - get_pos_emb_on_this_cp_rank as get_tensor_on_this_cp_rank, -) -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.attention import Attention -from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.multi_latent_attention import MLASelfAttention, MultiLatentAttention -from megatron.core.transformer.transformer_config import MLATransformerConfig -from megatron.core.utils import is_te_min_version, is_torch_min_version -from megatron.training.arguments import parse_args -from megatron.training.checkpointing import load_checkpoint, save_checkpoint -from megatron.training.global_vars import set_args -from megatron.training.training import get_model -from megatron.training.utils import unwrap_model -from tests.unit_tests.dist_checkpointing import ( - TempNamedDir, - init_basic_mock_args, - init_checkpointing_mock_args, -) -from tests.unit_tests.test_utilities import Utils - - -def make_test_packed_seq_params(sequence_length=None, cu_seqlens=None): - if cu_seqlens is None: - assert sequence_length is not None - cu_seqlens = [0, 6, 19, 22, sequence_length] - cu_seqlens = torch.IntTensor(cu_seqlens).cuda() - seqlens = cu_seqlens[1:] - cu_seqlens[:-1] - max_seqlen = seqlens.max().item() - packed_seq_params = PackedSeqParams( - cu_seqlens_q=cu_seqlens, - cu_seqlens_kv=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_kv=max_seqlen, - qkv_format='thd', - ) - return packed_seq_params - - -def make_test_packed_seq_params_with_padding( - sequence_length=None, cu_seqlens=None, cu_seqlens_padded=None -): - """Create PackedSeqParams with both regular and padded cu_seqlens for testing padded sequences.""" - if cu_seqlens is None: - assert sequence_length is not None - cu_seqlens = [ - 0, - 6, - 19, - 22, - sequence_length - 8, - ] # Actual sequence lengths (with some padding removed) - if cu_seqlens_padded is None: - assert sequence_length is not None - cu_seqlens_padded = [0, 8, 22, 28, sequence_length] # Padded sequence lengths - - cu_seqlens = torch.IntTensor(cu_seqlens).cuda() - cu_seqlens_padded = torch.IntTensor(cu_seqlens_padded).cuda() - - # Use padded lengths for max_seqlen calculation - seqlens_padded = cu_seqlens_padded[1:] - cu_seqlens_padded[:-1] - max_seqlen, _ = seqlens_padded.max(dim=0, keepdim=True) - max_seqlen = max_seqlen.tolist()[0] - - packed_seq_params = PackedSeqParams( - cu_seqlens_q=cu_seqlens, - cu_seqlens_kv=cu_seqlens, - cu_seqlens_q_padded=cu_seqlens_padded, - cu_seqlens_kv_padded=cu_seqlens_padded, - max_seqlen_q=max_seqlen, - max_seqlen_kv=max_seqlen, - qkv_format='thd', - ) - return packed_seq_params - - -def get_mla_self_attn_submodules(linear_qkv_down_proj=None): - submodules = get_gpt_layer_with_transformer_engine_spec( - multi_latent_attention=True - ).submodules.self_attention.submodules - if linear_qkv_down_proj is not None: - submodules.linear_q_down_proj = linear_qkv_down_proj - submodules.linear_kv_down_proj = linear_qkv_down_proj - return submodules - - -backend = TESpecProvider() -linear_qkv_down_proj_options = [backend.linear(), backend.column_parallel_linear()] - - -@pytest.mark.parametrize("rope_type", ('yarn', 'rope')) -class TestParallelMLAAttention: - - @pytest.fixture(scope='function', autouse=True) - def setup_and_teardown(self, rope_type): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - self.transformer_config = MLATransformerConfig( - num_layers=2, - hidden_size=12, - num_attention_heads=4, - use_cpu_initialization=True, - q_lora_rank=32, - kv_lora_rank=32, - qk_head_dim=128, - v_head_dim=128, - qk_pos_emb_head_dim=64, - rope_type=rope_type, - rotary_base=10000, - original_max_position_embeddings=32, - ) - self.parallel_attention = MLASelfAttention( - self.transformer_config, - get_mla_self_attn_submodules(), - layer_number=1, - attn_mask_type=AttnMaskType.causal, - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_input_params_forward(self): - """ - Test to ensure that MultiLatentAttention has all parameters - required by the Attention class's forward method. - """ - # Extract parameters from the forward methods of both Attention and MultiLatentAttention - attn_params = set(signature(Attention.forward).parameters.keys()) - mla_params = set(signature(MultiLatentAttention.forward).parameters.keys()) - - # Identify parameters that are in Attention but missing in MultiLatentAttention - missing_params = attn_params - mla_params - assert not missing_params, f"Missing parameters in MultiLatentAttention: {missing_params}" - - def test_constructor(self): - assert isinstance(self.parallel_attention, MLASelfAttention) - assert self.parallel_attention.layer_number == 1 - - num_weights = sum([p.numel() for p in self.parallel_attention.parameters()]) - assert num_weights == 65036 - - def test_cpu_forward(self): - # we can't currently do this because the global memory buffer is on GPU - pass - - def test_gpu_forward(self): - if is_te_min_version("1.10.0"): - config = self.parallel_attention.config - sequence_length = 32 - micro_batch_size = 2 - - self.parallel_attention.cuda() - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones( - (sequence_length, micro_batch_size, self.parallel_attention.config.hidden_size) - ) - hidden_states = hidden_states.cuda() - - attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() - - output, bias = self.parallel_attention(hidden_states, attention_mask) - - assert config.recompute_granularity is None - assert output.shape[0] == sequence_length - assert output.shape[1] == micro_batch_size - assert output.shape[2] == config.hidden_size - assert bias.shape[0] == config.hidden_size - - @pytest.mark.experimental - def test_gpu_forward_with_yarn_rope_fusion(self): - if self.transformer_config.rope_type == "rope": - pytest.skip("Rope is not supported for this test") - if is_te_min_version("1.10.0"): - transformer_config = self.transformer_config - transformer_config.apply_rope_fusion = True - checkpointed_parallel_attention = MLASelfAttention( - transformer_config, - get_mla_self_attn_submodules(), - layer_number=1, - attn_mask_type=AttnMaskType.causal, - ) - config = checkpointed_parallel_attention.config - - sequence_length = 32 - micro_batch_size = 2 - - checkpointed_parallel_attention.cuda() - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones( - ( - sequence_length, - micro_batch_size, - checkpointed_parallel_attention.config.hidden_size, - ) - ) - hidden_states = hidden_states.cuda() - - attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() - - output, bias = checkpointed_parallel_attention(hidden_states, attention_mask) - - assert config.apply_rope_fusion == True - - def test_gpu_forward_thd(self): - if is_te_min_version("1.10.0"): - # use flash attention for hopper, future may support fused attention for ampere - _environ = os.environ.copy() - os.environ['NVTE_FUSED_ATTN'] = "1" - os.environ['NVTE_FLASH_ATTN'] = "0" - - config = self.parallel_attention.config - sequence_length = 32 - micro_batch_size = 1 - - self.parallel_attention.cuda().bfloat16() - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones( - (sequence_length, micro_batch_size, self.parallel_attention.config.hidden_size) - ) - hidden_states = hidden_states.cuda().bfloat16() - - attention_mask = None - packed_seq_params = make_test_packed_seq_params(sequence_length=sequence_length) - output, bias = self.parallel_attention( - hidden_states, attention_mask, packed_seq_params=packed_seq_params - ) - - assert config.recompute_granularity is None - assert output.shape[0] == sequence_length - assert output.shape[1] == micro_batch_size - assert output.shape[2] == config.hidden_size - assert bias.shape[0] == config.hidden_size - os.environ.clear() - os.environ.update(_environ) - - def test_gpu_forward_thd_padded(self): - """Test MLA forward pass with cu_seqlens_q_padded and cu_seqlens_kv_padded.""" - if is_te_min_version("1.10.0"): - config = self.parallel_attention.config - sequence_length = 32 - micro_batch_size = 1 - - self.parallel_attention.cuda().bfloat16() - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones( - (sequence_length, micro_batch_size, self.parallel_attention.config.hidden_size) - ) - hidden_states = hidden_states.cuda().bfloat16() - - attention_mask = None - - # Create packed seq params with both regular and padded cu_seqlens - packed_seq_params = make_test_packed_seq_params_with_padding( - sequence_length=sequence_length - ) - - # Verify that the PackedSeqParams has both regular and padded cu_seqlens - assert packed_seq_params.cu_seqlens_q is not None - assert packed_seq_params.cu_seqlens_kv is not None - assert packed_seq_params.cu_seqlens_q_padded is not None - assert packed_seq_params.cu_seqlens_kv_padded is not None - - # Test the forward pass with padded cu_seqlens - output, bias = self.parallel_attention( - hidden_states, attention_mask, packed_seq_params=packed_seq_params - ) - - assert config.recompute_granularity is None - assert output.shape[0] == sequence_length - assert output.shape[1] == micro_batch_size - assert output.shape[2] == config.hidden_size - assert bias.shape[0] == config.hidden_size - - # Test that the get_query_key_value_tensors function properly handles padded cu_seqlens - query, key, value = self.parallel_attention.get_query_key_value_tensors( - hidden_states, None, None, packed_seq_params, None - ) - - assert query is not None - assert key is not None - assert value is not None - assert query.is_contiguous() - assert key.is_contiguous() - assert value.is_contiguous() - - def test_checkpointed_gpu_forward(self): - if is_te_min_version("1.10.0"): - transformer_config = self.transformer_config - transformer_config.recompute_granularity = 'selective' - checkpointed_parallel_attention = MLASelfAttention( - transformer_config, - get_mla_self_attn_submodules(), - layer_number=1, - attn_mask_type=AttnMaskType.causal, - ) - config = checkpointed_parallel_attention.config - - sequence_length = 32 - micro_batch_size = 2 - - checkpointed_parallel_attention.cuda() - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones( - ( - sequence_length, - micro_batch_size, - checkpointed_parallel_attention.config.hidden_size, - ) - ) - hidden_states = hidden_states.cuda() - - attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() - - output, bias = checkpointed_parallel_attention(hidden_states, attention_mask) - - assert config.recompute_granularity == 'selective' - assert "core_attn" in config.recompute_modules - assert output.shape[0] == sequence_length - assert output.shape[1] == micro_batch_size - assert output.shape[2] == config.hidden_size - assert bias.shape[0] == config.hidden_size - - def test_up_proj_recomputed_gpu_forward(self): - if is_te_min_version("1.10.0"): - transformer_config = self.transformer_config - transformer_config.recompute_granularity = 'selective' - transformer_config.recompute_modules = ["mla_up_proj"] - checkpointed_parallel_attention = MLASelfAttention( - transformer_config, - get_mla_self_attn_submodules(), - layer_number=1, - attn_mask_type=AttnMaskType.causal, - ) - config = checkpointed_parallel_attention.config - - sequence_length = 32 - micro_batch_size = 2 - - checkpointed_parallel_attention.cuda() - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones( - ( - sequence_length, - micro_batch_size, - checkpointed_parallel_attention.config.hidden_size, - ) - ) - hidden_states = hidden_states.cuda() - - q, k, v = checkpointed_parallel_attention.get_query_key_value_tensors(hidden_states) - assert q.is_contiguous() - assert k.is_contiguous() - assert v.is_contiguous() - - attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() - - output, bias = checkpointed_parallel_attention(hidden_states, attention_mask) - - assert checkpointed_parallel_attention.recompute_up_proj == True - assert output.shape[0] == sequence_length - assert output.shape[1] == micro_batch_size - assert output.shape[2] == config.hidden_size - assert bias.shape[0] == config.hidden_size - - -@pytest.mark.parametrize("linear_qkv_down_proj", linear_qkv_down_proj_options) -class TestSequenceParallelMLAAttention: - @pytest.fixture(scope='function', autouse=True) - def setup_method(self, linear_qkv_down_proj): - self.tensor_parallel_size = 2 - Utils.initialize_model_parallel(self.tensor_parallel_size, 1) - model_parallel_cuda_manual_seed(123) - self.transformer_config = MLATransformerConfig( - num_layers=2, - hidden_size=12, - num_attention_heads=4, - q_lora_rank=32, - kv_lora_rank=32, - qk_head_dim=128, - v_head_dim=128, - qk_pos_emb_head_dim=64, - rotary_base=10000, - original_max_position_embeddings=64, - tensor_model_parallel_size=self.tensor_parallel_size, - sequence_parallel=True, - ) - self.parallel_attention = MLASelfAttention( - self.transformer_config, - get_mla_self_attn_submodules(linear_qkv_down_proj=linear_qkv_down_proj), - layer_number=1, - attn_mask_type=AttnMaskType.causal, - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_gpu_forward(self): - if is_te_min_version("1.10.0"): - config = self.parallel_attention.config - sequence_length = 64 - sub_sequence_length = sequence_length // self.tensor_parallel_size - micro_batch_size = 2 - - self.parallel_attention.cuda() - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones( - (sub_sequence_length, micro_batch_size, self.parallel_attention.config.hidden_size) - ) - hidden_states = hidden_states.cuda() - - attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() - - output, bias = self.parallel_attention(hidden_states, attention_mask) - - assert config.recompute_granularity is None - assert output.shape[0] == sub_sequence_length - assert output.shape[1] == micro_batch_size - assert output.shape[2] == config.hidden_size - assert bias.shape[0] == config.hidden_size - - -@pytest.mark.parametrize("linear_qkv_down_proj", linear_qkv_down_proj_options) -class TestTensorParallelMLAAttention: - @pytest.fixture(scope='function', autouse=True) - def setup_method(self, linear_qkv_down_proj): - self.tensor_parallel_size = 2 - Utils.initialize_model_parallel(self.tensor_parallel_size, 1) - model_parallel_cuda_manual_seed(123) - self.transformer_config = MLATransformerConfig( - num_layers=2, - hidden_size=12, - num_attention_heads=4, - q_lora_rank=32, - kv_lora_rank=32, - qk_head_dim=128, - v_head_dim=128, - qk_pos_emb_head_dim=64, - rotary_base=10000, - original_max_position_embeddings=64, - tensor_model_parallel_size=self.tensor_parallel_size, - sequence_parallel=False, - ) - self.parallel_attention = MLASelfAttention( - self.transformer_config, - get_mla_self_attn_submodules(linear_qkv_down_proj=linear_qkv_down_proj), - layer_number=1, - attn_mask_type=AttnMaskType.causal, - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_gpu_forward(self): - if is_te_min_version("1.10.0"): - config = self.parallel_attention.config - sequence_length = 64 - micro_batch_size = 2 - - self.parallel_attention.cuda() - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones( - (sequence_length, micro_batch_size, self.parallel_attention.config.hidden_size) - ) - hidden_states = hidden_states.cuda() - - attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() - - output, bias = self.parallel_attention(hidden_states, attention_mask) - - assert config.recompute_granularity is None - assert output.shape[0] == sequence_length - assert output.shape[1] == micro_batch_size - assert output.shape[2] == config.hidden_size - assert bias.shape[0] == config.hidden_size - - -@pytest.mark.experimental -@pytest.mark.skipif( - not is_te_min_version("2.5.0", check_equality=True), - reason="Requires TransformerEngine >= 2.5.0", -) -@pytest.mark.parametrize( - ("rope_type", "apply_rope_fusion"), - ( - ('rope', False), - ('yarn', False), - ('yarn', True), # apply_rope_fusion for MLA only works with YARN RoPE. - ), -) -class TestContextParallelMLAAttention: - - @pytest.fixture(scope='function', autouse=True) - def setup_method(self, rope_type, apply_rope_fusion): - self.context_parallel_size = 4 - Utils.initialize_model_parallel(1, 1, context_parallel_size=self.context_parallel_size) - model_parallel_cuda_manual_seed(123) - self.transformer_config = MLATransformerConfig( - num_layers=2, - hidden_size=12, - num_attention_heads=4, - q_lora_rank=32, - kv_lora_rank=32, - qk_head_dim=128, - v_head_dim=128, - qk_pos_emb_head_dim=64, - rotary_base=10000, - max_position_embeddings=64, - context_parallel_size=self.context_parallel_size, - bf16=True, - rope_type=rope_type, - apply_rope_fusion=apply_rope_fusion, - ) - self.parallel_attention = MLASelfAttention( - self.transformer_config, - get_mla_self_attn_submodules(), - layer_number=1, - attn_mask_type=AttnMaskType.causal, - ).bfloat16() - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_gpu_forward(self): - if is_te_min_version("2.5.0", check_equality=True): - config = self.parallel_attention.config - sequence_length = 64 - micro_batch_size = 2 - - self.parallel_attention.cuda() - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones( - ( - sequence_length // self.context_parallel_size, - micro_batch_size, - self.parallel_attention.config.hidden_size, - ) - ).bfloat16() - hidden_states = hidden_states.cuda() - - attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() - - output, bias = self.parallel_attention(hidden_states, attention_mask) - - assert config.recompute_granularity is None - assert output.shape[0] == sequence_length // self.context_parallel_size - assert output.shape[1] == micro_batch_size - assert output.shape[2] == config.hidden_size - assert bias.shape[0] == config.hidden_size - - def test_gpu_forward_thd(self): - if is_te_min_version("2.5.0", check_equality=True): - config = self.parallel_attention.config - sequence_length = 128 - micro_batch_size = 1 - cu_seqlens = [0, 16, 48, 64, 128] - self.parallel_attention.cuda() - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones( - ( - sequence_length // self.context_parallel_size, - micro_batch_size, - self.parallel_attention.config.hidden_size, - ) - ).bfloat16() - hidden_states = hidden_states.cuda() - - attention_mask = None - packed_seq_params = make_test_packed_seq_params(cu_seqlens=cu_seqlens) - - output, bias = self.parallel_attention( - hidden_states, attention_mask, packed_seq_params=packed_seq_params - ) - - assert config.recompute_granularity is None - assert output.shape[0] == sequence_length // self.context_parallel_size - assert output.shape[1] == micro_batch_size - assert output.shape[2] == config.hidden_size - assert bias.shape[0] == config.hidden_size - - -@pytest.mark.parametrize("rope_type", ('yarn', 'rope')) -class TestParallelMLAAttentionPrecision: - - @pytest.fixture(scope='function', autouse=True) - def setup_and_teardown(self, rope_type): - self._environ_backup = os.environ.copy() - os.environ['NVTE_ALLOW_NONDETERMINISTIC_ALGO'] = "0" - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - self.transformer_config = MLATransformerConfig( - num_layers=2, - hidden_size=12, - num_attention_heads=4, - use_cpu_initialization=True, - q_lora_rank=32, - kv_lora_rank=32, - qk_head_dim=128, - v_head_dim=128, - qk_pos_emb_head_dim=64, - rope_type=rope_type, - rotary_base=10000, - original_max_position_embeddings=32, - deterministic_mode=True, - hidden_dropout=0.0, - attention_dropout=0.0, - ) - self.parallel_attention = MLASelfAttention( - self.transformer_config, - get_mla_self_attn_submodules(), - layer_number=1, - attn_mask_type=AttnMaskType.causal, - ) - - def teardown_method(self, method): - os.environ.clear() - os.environ.update(self._environ_backup) - Utils.destroy_model_parallel() - - def test_gpu_forward_thd_precision(self): - if is_te_min_version("1.10.0"): - # use flash attention for hopper, future may support fused attention for ampere - _environ = os.environ.copy() - os.environ['NVTE_FUSED_ATTN'] = "1" - os.environ['NVTE_FLASH_ATTN'] = "0" - - config = self.parallel_attention.config - - self.parallel_attention.cuda().bfloat16() - - # Input shape - sequence_length = 32 - micro_batch_size = 4 - cu_seqlens = [0, 32, 64, 96, 128] - # sbhd input shape: [sequence length, batch size, hidden size] - hidden_states_sbhd = torch.rand( - (sequence_length, micro_batch_size, self.parallel_attention.config.hidden_size) - ) - attention_mask_sbhd = torch.ones( - (1, 1, sequence_length, sequence_length), dtype=bool - ).cuda() - # thd input shape: [sequence length * batch size, 1, hidden size] - hidden_states_sbhd = hidden_states_sbhd.cuda().bfloat16() - hidden_states_thd = hidden_states_sbhd.transpose(0, 1).contiguous() - hidden_states_thd = hidden_states_thd.view( - -1, 1, self.parallel_attention.config.hidden_size - ) - attention_mask_thd = None - packed_seq_params = make_test_packed_seq_params(cu_seqlens=cu_seqlens) - - # fine-grained check - query_sbhd, key_sbhd, value_sbhd = self.parallel_attention.get_query_key_value_tensors( - hidden_states_sbhd, None, None, None, None - ) - query_thd, key_thd, value_thd = self.parallel_attention.get_query_key_value_tensors( - hidden_states_thd, None, None, packed_seq_params, None - ) - _query_sbhd = query_sbhd.transpose(0, 1).contiguous().view(*query_thd.shape) - _key_sbhd = key_sbhd.transpose(0, 1).contiguous().view(*key_thd.shape) - _value_sbhd = value_sbhd.transpose(0, 1).contiguous().view(*value_thd.shape) - assert torch.equal(_query_sbhd, query_thd) - assert torch.equal(_key_sbhd, key_thd) - assert torch.equal(_value_sbhd, value_thd) - - core_attn_out_sbhd = self.parallel_attention.core_attention( - query_sbhd, - key_sbhd, - value_sbhd, - attention_mask_sbhd, - packed_seq_params=None, - attn_mask_type=self.parallel_attention.attn_mask_type, - ) - query_thd = query_thd.squeeze(1) - key_thd = key_thd.squeeze(1) - value_thd = value_thd.squeeze(1) - core_attn_out_thd = self.parallel_attention.core_attention( - query_thd, - key_thd, - value_thd, - attention_mask_thd, - packed_seq_params=packed_seq_params, - attn_mask_type=self.parallel_attention.attn_mask_type, - ) - core_attn_out_thd = core_attn_out_thd.reshape(core_attn_out_thd.size(0), 1, -1) - _core_attn_out_sbhd = ( - core_attn_out_sbhd.transpose(0, 1).contiguous().view(*core_attn_out_thd.shape) - ) - assert torch.equal(_core_attn_out_sbhd, core_attn_out_thd) - - output_sbhd, bias_sbhd = self.parallel_attention.linear_proj(core_attn_out_sbhd) - output_thd, bias_thd = self.parallel_attention.linear_proj(core_attn_out_thd) - _output_sbhd = output_sbhd.transpose(0, 1).contiguous().view(*output_thd.shape) - assert torch.equal(_output_sbhd, output_thd) - - output_thd_fine_grained = output_thd - bias_thd_fine_grained = bias_thd - - # E2E check - # sbhd - output_sbhd, bias_sbhd = self.parallel_attention( - hidden_states_sbhd, attention_mask_sbhd - ) - # thd - output_thd, bias_thd = self.parallel_attention( - hidden_states_thd, attention_mask_thd, packed_seq_params=packed_seq_params - ) - _output_sbhd = output_sbhd.transpose(0, 1).contiguous().view(*output_thd.shape) - assert torch.equal(_output_sbhd, output_thd) - assert bias_thd.shape == bias_sbhd.shape - assert torch.equal(bias_sbhd, bias_thd) - - assert torch.equal(output_thd, output_thd_fine_grained) - assert torch.equal(bias_thd, bias_thd_fine_grained) - - os.environ.clear() - os.environ.update(_environ) - - -@pytest.mark.experimental -@pytest.mark.skipif( - not is_te_min_version("2.5.0", check_equality=True), - reason="Requires TransformerEngine >= 2.5.0", -) -@pytest.mark.parametrize( - ("rope_type", "apply_rope_fusion"), - ( - ('rope', False), - ('yarn', False), - ('yarn', True), # apply_rope_fusion for MLA only works with YARN RoPE. - ), -) -class TestContextParallelMLAAttentionPrecision: - - @pytest.fixture(scope='function', autouse=True) - def setup_and_teardown(self, rope_type, apply_rope_fusion): - self._environ_backup = os.environ.copy() - os.environ['NVTE_ALLOW_NONDETERMINISTIC_ALGO'] = "0" - self.context_parallel_size = 4 - Utils.initialize_model_parallel(1, 1, context_parallel_size=self.context_parallel_size) - model_parallel_cuda_manual_seed(123) - self.transformer_config = MLATransformerConfig( - num_layers=2, - hidden_size=12, - num_attention_heads=4, - q_lora_rank=32, - kv_lora_rank=32, - qk_head_dim=128, - v_head_dim=128, - qk_pos_emb_head_dim=64, - rotary_base=10000, - max_position_embeddings=64, - context_parallel_size=self.context_parallel_size, - bf16=True, - rope_type=rope_type, - apply_rope_fusion=apply_rope_fusion, - deterministic_mode=True, - hidden_dropout=0.0, - attention_dropout=0.0, - ) - self.parallel_attention = MLASelfAttention( - self.transformer_config, - get_mla_self_attn_submodules(), - layer_number=1, - attn_mask_type=AttnMaskType.causal, - ).bfloat16() - - def teardown_method(self, method): - os.environ.clear() - os.environ.update(self._environ_backup) - Utils.destroy_model_parallel() - - def test_gpu_forward_thd_precision(self): - if is_te_min_version("2.5.0", check_equality=True): - # use flash attention for hopper, future may support fused attention for ampere - _environ = os.environ.copy() - os.environ['NVTE_FUSED_ATTN'] = "1" - os.environ['NVTE_FLASH_ATTN'] = "0" - atol, rtol = 3e-4, 3e-4 - - self.parallel_attention.cuda().bfloat16() - - # Input shape - sequence_length = 32 - micro_batch_size = 4 - cu_seqlens = [0, 32, 64, 96, 128] - # sbhd input shape: [sequence length, batch size, hidden size] - hidden_states_sbhd = torch.rand( - ( - sequence_length // self.context_parallel_size, - micro_batch_size, - self.parallel_attention.config.hidden_size, - ) - ) - attention_mask_sbhd = None - # thd input shape: [sequence length * batch size, 1, hidden size] - hidden_states_sbhd = hidden_states_sbhd.cuda().bfloat16() - hidden_states_thd = hidden_states_sbhd.transpose(0, 1).contiguous() - hidden_states_thd = hidden_states_thd.view( - -1, 1, self.parallel_attention.config.hidden_size - ) - attention_mask_thd = None - packed_seq_params = make_test_packed_seq_params(cu_seqlens=cu_seqlens) - - # fine-grained check - query_sbhd, key_sbhd, value_sbhd = self.parallel_attention.get_query_key_value_tensors( - hidden_states_sbhd, None, None, None, None - ) - query_thd, key_thd, value_thd = self.parallel_attention.get_query_key_value_tensors( - hidden_states_thd, None, None, packed_seq_params, None - ) - _query_sbhd = query_sbhd.transpose(0, 1).contiguous().view(*query_thd.shape) - _key_sbhd = key_sbhd.transpose(0, 1).contiguous().view(*key_thd.shape) - _value_sbhd = value_sbhd.transpose(0, 1).contiguous().view(*value_thd.shape) - torch.testing.assert_close(_query_sbhd, query_thd, atol=1e-6, rtol=1e-6) - torch.testing.assert_close(_key_sbhd, key_thd, atol=1e-6, rtol=1e-6) - torch.testing.assert_close(_value_sbhd, value_thd, atol=1e-6, rtol=1e-6) - - core_attn_out_sbhd = self.parallel_attention.core_attention( - query_sbhd, - key_sbhd, - value_sbhd, - attention_mask_sbhd, - packed_seq_params=None, - attn_mask_type=self.parallel_attention.attn_mask_type, - ) - query_thd = query_thd.squeeze(1) - key_thd = key_thd.squeeze(1) - value_thd = value_thd.squeeze(1) - core_attn_out_thd = self.parallel_attention.core_attention( - query_thd, - key_thd, - value_thd, - attention_mask_thd, - packed_seq_params=packed_seq_params, - attn_mask_type=self.parallel_attention.attn_mask_type, - ) - core_attn_out_thd = core_attn_out_thd.reshape(core_attn_out_thd.size(0), 1, -1) - _core_attn_out_sbhd = ( - core_attn_out_sbhd.transpose(0, 1).contiguous().view(*core_attn_out_thd.shape) - ) - torch.testing.assert_close(_core_attn_out_sbhd, core_attn_out_thd, atol=atol, rtol=rtol) - - output_sbhd, bias_sbhd = self.parallel_attention.linear_proj(core_attn_out_sbhd) - output_thd, bias_thd = self.parallel_attention.linear_proj(core_attn_out_thd) - _output_sbhd = output_sbhd.transpose(0, 1).contiguous().view(*output_thd.shape) - torch.testing.assert_close(_output_sbhd, output_thd, atol=atol, rtol=rtol) - - output_thd_fine_grained = output_thd - bias_thd_fine_grained = bias_thd - - # E2E check - # sbhd - output_sbhd, bias_sbhd = self.parallel_attention( - hidden_states_sbhd, attention_mask_sbhd - ) - # thd - output_thd, bias_thd = self.parallel_attention( - hidden_states_thd, attention_mask_thd, packed_seq_params=packed_seq_params - ) - _output_sbhd = output_sbhd.transpose(0, 1).contiguous().view(*output_thd.shape) - torch.testing.assert_close(_output_sbhd, output_thd, atol=atol, rtol=rtol) - assert bias_thd.shape == bias_sbhd.shape - torch.testing.assert_close(bias_sbhd, bias_thd, atol=atol, rtol=rtol) - - assert torch.equal(output_thd, output_thd_fine_grained) - assert torch.equal(bias_thd, bias_thd_fine_grained) - - os.environ.clear() - os.environ.update(_environ) - - -@pytest.mark.experimental -@pytest.mark.skipif(not is_torch_min_version("2.5.0"), reason="Requires PyTorch >= 2.5.0") -class TestParallelMLAAttentionPrecisionWithRopeFusion: - - @pytest.fixture(scope='function', autouse=True) - def setup_and_teardown(self): - self._environ_backup = os.environ.copy() - os.environ['NVTE_ALLOW_NONDETERMINISTIC_ALGO'] = "0" - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - self.transformer_config = MLATransformerConfig( - num_layers=2, - hidden_size=12, - num_attention_heads=4, - use_cpu_initialization=True, - q_lora_rank=32, - kv_lora_rank=32, - qk_head_dim=128, - v_head_dim=128, - qk_pos_emb_head_dim=64, - rope_type="yarn", - rotary_base=10000, - max_position_embeddings=32, - deterministic_mode=True, - hidden_dropout=0.0, - attention_dropout=0.0, - apply_rope_fusion=True, - ) - self.parallel_attention = MLASelfAttention( - self.transformer_config, - get_mla_self_attn_submodules(), - layer_number=1, - attn_mask_type=AttnMaskType.causal, - ) - - def teardown_method(self, method): - os.environ.clear() - os.environ.update(self._environ_backup) - Utils.destroy_model_parallel() - - def test_gpu_forward_thd_precision(self): - if is_te_min_version("1.10.0"): - # use flash attention for hopper, future may support fused attention for ampere - _environ = os.environ.copy() - os.environ['NVTE_FUSED_ATTN'] = "1" - os.environ['NVTE_FLASH_ATTN'] = "0" - - config = self.parallel_attention.config - - self.parallel_attention.cuda().bfloat16() - - # Input shape - sequence_length = 32 - micro_batch_size = 4 - cu_seqlens = [0, 32, 64, 96, 128] - # sbhd input shape: [sequence length, batch size, hidden size] - hidden_states_sbhd = torch.rand( - (sequence_length, micro_batch_size, self.parallel_attention.config.hidden_size) - ) - attention_mask_sbhd = torch.ones( - (1, 1, sequence_length, sequence_length), dtype=bool - ).cuda() - # thd input shape: [sequence length * batch size, 1, hidden size] - hidden_states_sbhd = hidden_states_sbhd.cuda().bfloat16() - hidden_states_thd = hidden_states_sbhd.transpose(0, 1).contiguous() - hidden_states_thd = hidden_states_thd.view( - -1, 1, self.parallel_attention.config.hidden_size - ) - attention_mask_thd = None - packed_seq_params = make_test_packed_seq_params(cu_seqlens=cu_seqlens) - - # fine-grained check - query_sbhd, key_sbhd, value_sbhd = self.parallel_attention.get_query_key_value_tensors( - hidden_states_sbhd, None, None, None, None - ) - query_thd, key_thd, value_thd = self.parallel_attention.get_query_key_value_tensors( - hidden_states_thd, None, None, packed_seq_params, None - ) - _query_sbhd = query_sbhd.transpose(0, 1).contiguous().view(*query_thd.shape) - _key_sbhd = key_sbhd.transpose(0, 1).contiguous().view(*key_thd.shape) - _value_sbhd = value_sbhd.transpose(0, 1).contiguous().view(*value_thd.shape) - assert torch.equal(_query_sbhd, query_thd) - assert torch.equal(_key_sbhd, key_thd) - assert torch.equal(_value_sbhd, value_thd) - - core_attn_out_sbhd = self.parallel_attention.core_attention( - query_sbhd, - key_sbhd, - value_sbhd, - attention_mask_sbhd, - packed_seq_params=None, - attn_mask_type=self.parallel_attention.attn_mask_type, - ) - query_thd = query_thd.squeeze(1) - key_thd = key_thd.squeeze(1) - value_thd = value_thd.squeeze(1) - core_attn_out_thd = self.parallel_attention.core_attention( - query_thd, - key_thd, - value_thd, - attention_mask_thd, - packed_seq_params=packed_seq_params, - attn_mask_type=self.parallel_attention.attn_mask_type, - ) - core_attn_out_thd = core_attn_out_thd.reshape(core_attn_out_thd.size(0), 1, -1) - _core_attn_out_sbhd = ( - core_attn_out_sbhd.transpose(0, 1).contiguous().view(*core_attn_out_thd.shape) - ) - assert torch.equal(_core_attn_out_sbhd, core_attn_out_thd) - - output_sbhd, bias_sbhd = self.parallel_attention.linear_proj(core_attn_out_sbhd) - output_thd, bias_thd = self.parallel_attention.linear_proj(core_attn_out_thd) - _output_sbhd = output_sbhd.transpose(0, 1).contiguous().view(*output_thd.shape) - assert torch.equal(_output_sbhd, output_thd) - - output_thd_fine_grained = output_thd - bias_thd_fine_grained = bias_thd - - # E2E check - # sbhd - output_sbhd, bias_sbhd = self.parallel_attention( - hidden_states_sbhd, attention_mask_sbhd - ) - # thd - output_thd, bias_thd = self.parallel_attention( - hidden_states_thd, attention_mask_thd, packed_seq_params=packed_seq_params - ) - _output_sbhd = output_sbhd.transpose(0, 1).contiguous().view(*output_thd.shape) - assert torch.equal(_output_sbhd, output_thd) - assert bias_thd.shape == bias_sbhd.shape - assert torch.equal(bias_sbhd, bias_thd) - - assert torch.equal(output_thd, output_thd_fine_grained) - assert torch.equal(bias_thd, bias_thd_fine_grained) - - os.environ.clear() - os.environ.update(_environ) - - -@pytest.mark.experimental -@pytest.mark.parametrize( - ("rope_type", "apply_rope_fusion"), - [ - ("rope", False), - ("yarn", False), - ("yarn", True), # apply_rope_fusion for MLA only works with YARN RoPE. - ], -) -@pytest.mark.parametrize( - ("tp", "sp", "cp"), - [ - (4, False, 1), # TP w/o SP - (4, True, 1), # TP w/ SP - (1, False, 4), # CP - (2, False, 2), # CP + TP w/o SP - (2, True, 2), # CP + TP w/ SP - ], -) -@pytest.mark.skipif(not is_te_min_version("1.10.0"), reason="Requires TransformerEngine >= 1.10.0") -def test_parallel_multi_latent_attention_correctness( - tmp_path_dist_ckpt, rope_type, apply_rope_fusion, tp, sp, cp -): - if cp > 1 and not is_te_min_version("2.5.0", check_equality=True): - pytest.skip("MLA CP requires TransformerEngine >= 2.5.0") - if rope_type == "yarn" and apply_rope_fusion and not is_torch_min_version("2.5.0"): - pytest.skip("MLA yarn rope fusion requires PyTorch >= 2.5.0") - if ( - cp > 1 - and rope_type == "yarn" - and apply_rope_fusion - and not is_te_min_version("2.6.0", check_equality=True) - ): - pytest.skip("MLA CP + yarn rope fusion requires PyTorch >= 2.6.0") - - # Non-deterministic mode has bug to be fixed with MLA - _environ = os.environ.copy() - os.environ['NVTE_ALLOW_NONDETERMINISTIC_ALGO'] = "1" - os.environ['NVTE_FUSED_ATTN'] = "1" - os.environ['NVTE_FLASH_ATTN'] = "0" - - # Constants - seed = 123 - sequence_length = 256 - micro_batch_size = 4 - hidden_size = 128 - - # Model initialization function - def initialize_gpt_model(config, pre_process=True, post_process=True, vp_stage=None): - layer_spec = get_gpt_layer_with_transformer_engine_spec(multi_latent_attention=True) - gpt_model = GPTModel( - config=config, - transformer_layer_spec=layer_spec, - vocab_size=128, - max_sequence_length=sequence_length, - pre_process=pre_process, - post_process=post_process, - vp_stage=vp_stage, - ) - return gpt_model - - # Initialize baseline parallel state - Utils.initialize_model_parallel( - tensor_model_parallel_size=1, pipeline_model_parallel_size=1, context_parallel_size=1 - ) - - # Initialize input hidden states - torch.manual_seed(seed) - model_parallel_cuda_manual_seed(seed) - input_hidden_states = ( - torch.rand((sequence_length, micro_batch_size, hidden_size)) - .cuda() - .bfloat16() - .requires_grad_(True) - ) - - # Initialize transformer config - transformer_config = MLATransformerConfig( - num_layers=1, - hidden_size=hidden_size, - num_attention_heads=4, - q_lora_rank=32, - kv_lora_rank=32, - qk_head_dim=128, - v_head_dim=128, - qk_pos_emb_head_dim=64, - rotary_base=10000, - max_position_embeddings=64, - context_parallel_size=1, - tensor_model_parallel_size=1, - sequence_parallel=False, - bf16=True, - rope_type=rope_type, - apply_rope_fusion=apply_rope_fusion, - hidden_dropout=0.0, - attention_dropout=0.0, - ) - - with TempNamedDir(tmp_path_dist_ckpt / 'test_parallel_mla', sync=True) as ckpt_dir: - # Set argument - mock_args = parse_args(ignore_unknown_args=True) - set_args(mock_args) - - # Initialize baseline model - init_basic_mock_args(mock_args, 1, 1, bf16=True) - mock_args.context_parallel_size = 1 - mock_args.sequence_parallel = 1 - gpt_model = unwrap_model( - get_model(partial(initialize_gpt_model, config=transformer_config)) - ) - - # Initialize args and save checkpoint - init_checkpointing_mock_args(mock_args, ckpt_dir, False) - mock_args.no_save_optim = True - mock_args.no_save_rng = True - mock_args.no_load_optim = True - mock_args.no_load_rng = True - save_checkpoint(10, gpt_model, None, None, 0) - - # Calculate baseline output - attention = gpt_model[0].decoder.layers[0].self_attention - output_hidden_states_baseline, bias_hidden_states_baseline = attention( - input_hidden_states, attention_mask=None - ) - output_hidden_states_baseline.sum().backward() - - # Save baseline output - input_grad_baseline = input_hidden_states.grad.detach() - output_hidden_states_baseline = output_hidden_states_baseline.detach() - bias_hidden_states_baseline = bias_hidden_states_baseline.detach() - - # Initialize parallel model - Utils.destroy_model_parallel() - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp, pipeline_model_parallel_size=1, context_parallel_size=cp - ) - torch.manual_seed(seed) - model_parallel_cuda_manual_seed(seed) - transformer_config.context_parallel_size = cp - transformer_config.tensor_model_parallel_size = tp - transformer_config.sequence_parallel = sp - init_basic_mock_args(mock_args, tp, 1, bf16=True) - mock_args.context_parallel_size = cp - mock_args.sequence_parallel = sp - gpt_model = unwrap_model( - get_model(partial(initialize_gpt_model, config=transformer_config)) - ) - with mock.patch('megatron.training.checkpointing.check_checkpoint_args'): - with mock.patch('megatron.training.checkpointing.update_num_microbatches'): - load_checkpoint(gpt_model, None, None) - - # Function to get tensor on this tp and cp rank - cp_group = parallel_state.get_context_parallel_group() - tp_rank = parallel_state.get_tensor_model_parallel_rank() - - def get_tensor_on_this_rank(tensor): - if cp > 1: - tensor = get_tensor_on_this_cp_rank(tensor, 0, cp_group) - if tp > 1 and sp: - sp_seg = sequence_length // tp // cp - tensor = tensor[tp_rank * sp_seg : (tp_rank + 1) * sp_seg] - return tensor - - # Calculate parallel model output - input_hidden_states = get_tensor_on_this_rank(input_hidden_states) - input_hidden_states = input_hidden_states.detach().requires_grad_(True) - parallel_attention = gpt_model[0].decoder.layers[0].self_attention - output_hidden_states_parallel, bias_hidden_states_parallel = parallel_attention( - input_hidden_states, attention_mask=None - ) - output_hidden_states_parallel.sum().backward() - input_grad_parallel = input_hidden_states.grad.detach() - - # Check if the output is the same - if cp: - atol, rtol = 5e-3, 5e-3 - else: - atol, rtol = 5e-4, 5e-4 - output_hidden_states_baseline = get_tensor_on_this_rank(output_hidden_states_baseline) - input_grad_baseline = get_tensor_on_this_rank(input_grad_baseline) - - assert torch.all( - ~torch.isnan(output_hidden_states_baseline) - ), "output_hidden_states_baseline contains nan" - assert torch.all( - ~torch.isinf(output_hidden_states_baseline) - ), "output_hidden_states_baseline contains inf" - assert torch.all( - ~torch.isnan(bias_hidden_states_baseline) - ), "bias_hidden_states_baseline contains nan" - assert torch.all( - ~torch.isinf(bias_hidden_states_baseline) - ), "bias_hidden_states_baseline contains inf" - assert torch.all(~torch.isnan(input_grad_baseline)), "input_grad_baseline contains nan" - assert torch.all(~torch.isinf(input_grad_baseline)), "input_grad_baseline contains inf" - assert torch.all( - ~torch.isnan(output_hidden_states_parallel) - ), "output_hidden_states_parallel contains nan" - assert torch.all( - ~torch.isinf(output_hidden_states_parallel) - ), "output_hidden_states_parallel contains inf" - assert torch.all( - ~torch.isnan(bias_hidden_states_parallel) - ), "bias_hidden_states_parallel contains nan" - assert torch.all( - ~torch.isinf(bias_hidden_states_parallel) - ), "bias_hidden_states_parallel contains inf" - assert torch.all(~torch.isnan(input_grad_parallel)), "input_grad_parallel contains nan" - assert torch.all(~torch.isinf(input_grad_parallel)), "input_grad_parallel contains inf" - - torch.testing.assert_close( - output_hidden_states_baseline, - output_hidden_states_parallel, - atol=atol, - rtol=rtol, - msg=lambda msg: f"Mismatch in output_hidden_states: {msg}", - ) - torch.testing.assert_close( - bias_hidden_states_baseline, - bias_hidden_states_parallel, - atol=atol, - rtol=rtol, - msg=lambda msg: f"Mismatch in bias_hidden_states: {msg}", - ) - torch.testing.assert_close( - input_grad_baseline, - input_grad_parallel, - atol=atol, - rtol=rtol, - msg=lambda msg: f"Mismatch in input_grad: {msg}", - ) - - Utils.destroy_model_parallel() - - os.environ.clear() - os.environ.update(_environ) diff --git a/tests/unit_tests/transformer/test_multi_token_prediction.py b/tests/unit_tests/transformer/test_multi_token_prediction.py deleted file mode 100644 index a3c456a2cd..0000000000 --- a/tests/unit_tests/transformer/test_multi_token_prediction.py +++ /dev/null @@ -1,480 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -import os -import sys - -import pytest -import torch - -from megatron.core.enums import ModelType -from megatron.core.models.gpt.gpt_layer_specs import ( - get_gpt_layer_local_spec, - get_gpt_layer_with_transformer_engine_spec, - get_gpt_mtp_block_spec, -) -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator -from megatron.core.process_groups_config import ModelCommProcessGroups -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.multi_token_prediction import ( - MTPLossLoggingHelper, - MultiTokenPredictionBlock, -) -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import is_te_min_version -from megatron.training.arguments import core_transformer_config_from_args, parse_args, validate_args -from megatron.training.checkpointing import load_checkpoint, save_checkpoint -from megatron.training.global_vars import ( - destroy_global_vars, - get_args, - set_args, - set_global_variables, -) -from megatron.training.training import get_model, setup_model_and_optimizer -from megatron.training.utils import get_batch_on_this_cp_rank, unwrap_model -from tests.unit_tests.dist_checkpointing import TempNamedDir -from tests.unit_tests.test_utilities import Utils - -try: - from megatron.core.extensions.transformer_engine import TEColumnParallelGroupedLinear - - HAVE_TE = True -except ImportError: - HAVE_TE = False - -_SEED = 42 - - -class TestMultiTokenPredictionLayer: - def setup_method(self, method): - os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' - - def teardown_method(self, method): - Utils.destroy_model_parallel() - destroy_global_vars() - destroy_num_microbatches_calculator() - - def _create_config_and_mtp_block_spec(self, tp, cp, use_te=False): - Utils.initialize_model_parallel(tensor_model_parallel_size=tp, context_parallel_size=cp) - config = TransformerConfig( - mtp_num_layers=2, - num_layers=4, - hidden_size=64, - num_attention_heads=8, - use_cpu_initialization=True, - tensor_model_parallel_size=tp, - sequence_parallel=True if tp > 1 else False, - context_parallel_size=cp, # Enable CP for MTP testing - ) - if use_te: - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec() - else: - transformer_layer_spec = get_gpt_layer_local_spec() - mtp_block_spec = get_gpt_mtp_block_spec( - config=config, spec=transformer_layer_spec, use_transformer_engine=use_te - ) - return config, mtp_block_spec - - @pytest.mark.parametrize(('tp'), [(1), (2), (4)]) - def test_constructor_local(self, tp): - """Test basic construction of MTP module.""" - - torch.manual_seed(_SEED) - config, mtp_block_spec = self._create_config_and_mtp_block_spec(tp, cp=1) - mtp = MultiTokenPredictionBlock(config=config, spec=mtp_block_spec) - - assert isinstance(mtp, MultiTokenPredictionBlock) - assert len(mtp.layers) == config.mtp_num_layers - for i in range(config.mtp_num_layers): - assert mtp.layers[i].layer_number == i + 1 - assert mtp.layers[i].enorm.weight.shape[0] == config.hidden_size - assert mtp.layers[i].hnorm.weight.shape[0] == config.hidden_size - assert mtp.layers[i].eh_proj.weight.shape[0] == config.hidden_size / tp - assert mtp.layers[i].eh_proj.weight.shape[1] == config.hidden_size * 2 - assert mtp.layers[i].transformer_layer is not None - num_weights = sum([p.numel() for p in mtp.parameters()]) - if tp == 1: - assert num_weights == 58560 * config.mtp_num_layers - elif tp == 2: - assert num_weights == 29664 * config.mtp_num_layers - elif tp == 4: - assert num_weights == 15216 * config.mtp_num_layers - - @pytest.mark.skipif(not HAVE_TE, reason="transformer_engine not available") - @pytest.mark.parametrize(('tp', 'cp'), [(1, 1), (1, 2), (2, 1), (2, 2)]) - def test_constructor_ues_te(self, tp, cp): - """Test basic construction of MTP module.""" - torch.manual_seed(_SEED) - Utils.initialize_model_parallel(tensor_model_parallel_size=tp, context_parallel_size=cp) - config, mtp_block_spec = self._create_config_and_mtp_block_spec(tp, cp, use_te=True) - mtp = MultiTokenPredictionBlock(config=config, spec=mtp_block_spec) - - assert isinstance(mtp, MultiTokenPredictionBlock) - assert len(mtp.layers) == config.mtp_num_layers - for i in range(config.mtp_num_layers): - assert mtp.layers[i].layer_number == i + 1 - assert mtp.layers[i].enorm.weight.shape[0] == config.hidden_size - assert mtp.layers[i].hnorm.weight.shape[0] == config.hidden_size - assert mtp.layers[i].eh_proj.weight.shape[0] == config.hidden_size / tp - assert mtp.layers[i].eh_proj.weight.shape[1] == config.hidden_size * 2 - assert mtp.layers[i].transformer_layer is not None - num_weights = sum([p.numel() for p in mtp.parameters()]) - if tp == 1: - assert num_weights == 58560 * config.mtp_num_layers - elif tp == 2: - assert num_weights == 29664 * config.mtp_num_layers - elif tp == 4: - assert num_weights == 15216 * config.mtp_num_layers - - -class TestMultiTokenPrediction: - def setup_method(self, method): - self.seq_length = 32 - self.micro_batch_size = 2 - os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' - - def teardown_method(self, method): - Utils.destroy_model_parallel() - destroy_global_vars() - destroy_num_microbatches_calculator() - MTPLossLoggingHelper.tracker = {} - - def model_provider( - self, - pre_process=True, - post_process=True, - layer_spec_fn=get_gpt_layer_with_transformer_engine_spec, - **config_kwargs, - ): - model_parallel_cuda_manual_seed(_SEED) - args = get_args() - config = core_transformer_config_from_args(args) - transformer_layer_spec = layer_spec_fn( - args.num_experts, args.moe_grouped_gemm, args.qk_layernorm - ) - mtp_block_spec = get_gpt_mtp_block_spec( - config=config, spec=transformer_layer_spec, use_transformer_engine=True - ) - model = GPTModel( - config=config, - transformer_layer_spec=transformer_layer_spec, - mtp_block_spec=mtp_block_spec, - vocab_size=args.vocal_size, - max_sequence_length=args.max_position_embeddings, - pre_process=pre_process, - post_process=post_process, - fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, - parallel_output=True, - share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, - position_embedding_type=args.position_embedding_type, - rotary_percent=args.rotary_percent, - ) - - return model - - def create_test_args( - self, tp, cp, sequence_length, micro_batch_size, fp8=None, full_recompute=False - ): - destroy_global_vars() - destroy_num_microbatches_calculator() - - sys.argv = ['test_multi_token_predictioin.py'] - args = parse_args() - args.num_layers = 2 - args.mtp_num_layers = 2 - args.mtp_loss_scaling_factor = 0.1 - args.vocal_size = 128800 - args.hidden_size = 128 - args.num_attention_heads = 8 - args.max_position_embeddings = 256 - args.micro_batch_size = micro_batch_size - args.create_attention_mask_in_dataloader = True - args.seq_length = sequence_length - args.tensor_model_parallel_size = tp - args.sequence_parallel = True if tp > 1 else False - args.context_parallel_size = cp - args.position_embedding_type = 'rope' - args.num_experts = 8 - args.train_iters = 1 - args.ckpt_format = 'torch_dist' - args.moe_router_topk = 2 - args.moe_router_pre_softmax = False - args.lr = 3e-5 - args.attention_dropout = 0.0 - args.hidden_dropout = 0.0 - args.async_tensor_model_parallel_allreduce = False - args.no_save_optim = True - args.no_load_optim = True - args.no_load_rng = True - if HAVE_TE: - # only use grouped gemm if there is TE - args.moe_grouped_gemm = True - else: - args.moe_grouped_gemm = False - args.bf16 = True - if fp8 is not None: - args.fp8 = 'e4m3' - if full_recompute: - args.recompute_granularity = 'full' - args.recompute_method = 'uniform' - args.recompute_num_layers = 1 - else: - args.recompute_granularity = None - args.add_bias_linear = False - args.swiglu = True - - validate_args(args) - set_global_variables(args, False) - return args - - def get_batch(self, seq_length, micro_batch_size): - data = list(range(seq_length)) - input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - labels = 1 + torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - position_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - attention_mask = torch.ones( - (micro_batch_size, 1, seq_length, seq_length), dtype=bool - ).cuda() - loss_mask = torch.ones(seq_length).repeat((micro_batch_size, 1)).cuda() - batch = { - 'tokens': input_ids, - 'labels': labels, - 'loss_mask': loss_mask, - 'attention_mask': attention_mask, - 'position_ids': position_ids, - } - return batch - - @pytest.mark.skipif( - not HAVE_TE or not is_te_min_version("2.1.0"), - reason="grouped_gemm requires TransformerEngine >= 2.1.0", - ) - @pytest.mark.parametrize(("tp", "cp"), [(1, 1), (1, 2), (2, 1), (2, 2)]) - def test_sharded_state_dict(self, tp, cp): - """Test MTP with different tensor parallel sizes.""" - args = self.create_test_args(tp, cp, self.seq_length, self.micro_batch_size) - set_args(args) - torch.manual_seed(_SEED) - Utils.initialize_model_parallel(tensor_model_parallel_size=tp, context_parallel_size=cp) - gpt_model = get_model(self.model_provider, ModelType.encoder_or_decoder) - gpt_model = unwrap_model(gpt_model) - sharded_state_dict = gpt_model[0].sharded_state_dict() - for i in range(args.mtp_num_layers): - assert f"mtp.layers.{i}.enorm.weight" in sharded_state_dict.keys() - assert f"mtp.layers.{i}.hnorm.weight" in sharded_state_dict.keys() - assert f"mtp.layers.{i}.eh_proj.weight" in sharded_state_dict.keys() - - @pytest.mark.skipif( - not HAVE_TE or not is_te_min_version("2.1.0"), - reason="grouped_gemm requires TransformerEngine >= 2.1.0", - ) - @pytest.mark.parametrize("full_recompute", [False, True]) - @pytest.mark.parametrize( - ("tp", "cp"), [(1, 1), (1, 2), (1, 4), (2, 1), (2, 2), (2, 4), (4, 1), (4, 2)] - ) - def test_forward_backward(self, tmp_path_dist_ckpt, tp, cp, full_recompute): - """Test MTP forward and backward with gptmodel.""" - tp_ref = 1 - cp_ref = 1 - args = self.create_test_args(tp_ref, cp_ref, self.seq_length, self.micro_batch_size) - set_args(args) - torch.manual_seed(_SEED) - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp_ref, context_parallel_size=cp_ref - ) - batch = self.get_batch(self.seq_length, self.micro_batch_size) - tokens, labels, loss_mask, attention_mask, position_ids = batch.values() - gpt_model_ref, optimizer, opt_param_scheduler = setup_model_and_optimizer( - self.model_provider, ModelType.encoder_or_decoder - ) - output_ref = gpt_model_ref[0].forward( - input_ids=tokens, - position_ids=position_ids, - attention_mask=attention_mask, - labels=labels, - loss_mask=loss_mask, - ) - tracker = MTPLossLoggingHelper.tracker - mtp_loss_ref = None - assert "values" in tracker - mtp_loss_ref = tracker['values'].clone() - MTPLossLoggingHelper.clean_loss_in_tracker() - - iteration = 123 - num_floating_point_operations_so_far = 456 - - def set_ckpt_path(ckpt_path): - args.save = ckpt_path - args.load = ckpt_path - - with TempNamedDir( - tmp_path_dist_ckpt / 'test_mtp_model_reconfiguration_model_A' - ) as ckpt_dir_A: - set_ckpt_path(ckpt_dir_A) - save_checkpoint( - iteration, - gpt_model_ref, - optimizer, - opt_param_scheduler, - num_floating_point_operations_so_far, - ) - - expected_ckpt_path = args.save / "iter_0000123" / ".metadata" - assert os.path.exists(expected_ckpt_path) - - # Test with different TP/CP configuration - Utils.destroy_model_parallel() - args = self.create_test_args( - tp, cp, self.seq_length, self.micro_batch_size, full_recompute=full_recompute - ) - set_args(args) - set_ckpt_path(ckpt_dir_A) - torch.manual_seed(_SEED) - Utils.initialize_model_parallel(tensor_model_parallel_size=tp, context_parallel_size=cp) - gpt_model, optimizer, opt_param_scheduler = setup_model_and_optimizer( - self.model_provider, ModelType.encoder_or_decoder - ) - load_checkpoint(gpt_model, optimizer, opt_param_scheduler, strict=False) - batch["output_ref"] = output_ref - # Get batch for current CP rank (handles CP tensor splitting) - batch = get_batch_on_this_cp_rank(batch) - tokens, labels, loss_mask, attention_mask, position_ids, output_ref = batch.values() - output = gpt_model[0].forward( - input_ids=tokens, - position_ids=position_ids, - attention_mask=attention_mask, - labels=labels, - loss_mask=loss_mask, - ) - tracker = MTPLossLoggingHelper.tracker - assert "values" in tracker - mtp_loss = tracker['values'].clone() - # Average MTP loss across CP ranks for comparison with reference - model_comm_pgs = ModelCommProcessGroups.use_mpu_process_groups(required_pgs=['cp']) - torch.distributed.all_reduce( - mtp_loss, group=model_comm_pgs.cp, op=torch.distributed.ReduceOp.AVG - ) - MTPLossLoggingHelper.clean_loss_in_tracker() - assert torch.allclose(output_ref, output, rtol=1e-03, atol=1e-03) - assert torch.allclose(mtp_loss, mtp_loss_ref, rtol=1e-02, atol=1e-02) - - # Check output shapes - sequence length is divided by CP size - assert output.shape[0] == self.micro_batch_size - assert output.shape[1] == self.seq_length / cp - - # Verify gradients - loss = output.mean() - loss.backward() - # for param in gpt_model[0].parameters(): - for name, param in gpt_model[0].named_parameters(): - assert param.main_grad is not None - - @pytest.mark.skipif( - not HAVE_TE or not is_te_min_version("1.7.0"), - reason="Only transformer-engine>=1.7.0 supports MoE FP8 training", - ) - @pytest.mark.parametrize("full_recompute", [False, True]) - def test_fp8_support(self, full_recompute): - """Test MTP with FP8 training enabled.""" - tp = 1 - cp = 1 - fp8 = 'e4m3' - args = self.create_test_args( - tp, cp, self.seq_length, self.micro_batch_size, fp8, full_recompute=full_recompute - ) - set_args(args) - - torch.manual_seed(_SEED) - Utils.initialize_model_parallel(tensor_model_parallel_size=tp, context_parallel_size=cp) - batch = self.get_batch(self.seq_length, self.micro_batch_size) - tokens, labels, loss_mask, attention_mask, position_ids = batch.values() - gpt_model, optimizer, opt_param_scheduler = setup_model_and_optimizer( - self.model_provider, ModelType.encoder_or_decoder - ) - - output = gpt_model[0].forward( - input_ids=tokens, - position_ids=position_ids, - attention_mask=attention_mask, - labels=labels, - loss_mask=loss_mask, - ) - - assert output.dtype == torch.float32 # Output should be converted back to float32 - - loss = output.mean() - loss.backward() - - -class TestMTPLossLoggingHelper: - def setup_method(self, method): - self.num_layers = 4 - # Reset the tracker before each test - MTPLossLoggingHelper.tracker = {} - - def teardown_method(self, method): - # Clean up the tracker after each test - MTPLossLoggingHelper.tracker = {} - - def test_save_loss_to_tracker(self): - """Test saving loss to tracker.""" - # Create a dummy loss tensor - loss = torch.tensor(1.3) - layer_number = 2 - num_layers = self.num_layers - - # Test saving loss - MTPLossLoggingHelper.save_loss_to_tracker( - loss=loss, layer_number=layer_number, num_layers=num_layers - ) - - # Verify tracker state - assert "values" in MTPLossLoggingHelper.tracker - assert MTPLossLoggingHelper.tracker["values"].shape == (num_layers,) - assert MTPLossLoggingHelper.tracker["values"][layer_number] == loss - assert MTPLossLoggingHelper.tracker["reduce_group"] is None - assert MTPLossLoggingHelper.tracker["avg_group"] is None - - def test_track_mtp_metrics(self): - """Test tracking MTP metrics.""" - # First save some losses - loss = torch.tensor(2.3) - num_layers = self.num_layers - for i in range(num_layers): - MTPLossLoggingHelper.save_loss_to_tracker( - loss=loss, layer_number=i, num_layers=num_layers - ) - - # Create dummy writer and loss dict - class DummyWriter: - def add_scalar(self, name, value, iteration): - pass - - class DummyWandBWriter: - def log(self, metrics, iteration): - pass - - loss_scale = 1.5 - iteration = 2 - writer = DummyWriter() - wandb_writer = DummyWandBWriter() - total_loss_dict = {} - - # Test tracking metrics - MTPLossLoggingHelper.track_mtp_metrics( - loss_scale=loss_scale, - iteration=iteration, - writer=writer, - wandb_writer=wandb_writer, - total_loss_dict=total_loss_dict, - ) - - # Verify total_loss_dict is populated - for i in range(num_layers): - assert f"mtp_{i+1} loss" in total_loss_dict - assert total_loss_dict[f"mtp_{i+1} loss"] == loss * loss_scale - - # Verify tracker is cleaned - assert torch.all(MTPLossLoggingHelper.tracker["values"] == 0) - assert MTPLossLoggingHelper.tracker["reduce_group"] is None - assert MTPLossLoggingHelper.tracker["avg_group"] is None diff --git a/tests/unit_tests/transformer/test_quantization_config.py b/tests/unit_tests/transformer/test_quantization_config.py deleted file mode 100644 index fe57934bde..0000000000 --- a/tests/unit_tests/transformer/test_quantization_config.py +++ /dev/null @@ -1,86 +0,0 @@ -from typing import Any, Dict - -import pytest - -from megatron.core.quantization.quant_config import GlobMatcher, MatchContext, RecipeConfig - -try: - import nvidia_kitchen - from nvidia_kitchen.config import ( - AutogradFunctionImplementation, - QuantizeRecipe, - get_qlinear_params_from_predefined, - ) - - from megatron.core.extensions.kitchen import QLinearParamsConfigSchema - - HAVE_KITCHEN = True -except ImportError: - HAVE_KITCHEN = False - - -def test_recipe_config_matching() -> None: - - recipe_config = RecipeConfig( - [ - GlobMatcher("*fc2", "fc2_cfg"), - GlobMatcher("*fc*", "fc_cfg"), - GlobMatcher("*", "default"), - ], - {"fc2_cfg": {"fc2": "foo"}, "fc_cfg": {"fc1": "bar"}, "default": {"default": "baz"}}, - ) - - assert ( - recipe_config.match_to_config_key(MatchContext("decoder.1.linear_fc2", layer_number=1)) - == "fc2_cfg" - ) - assert ( - recipe_config.match_to_config_key(MatchContext("decoder.1.linear_fc1", layer_number=1)) - == "fc_cfg" - ) - assert ( - recipe_config.match_to_config_key(MatchContext("decoder.1.linear_qkv", layer_number=1)) - == "default" - ) - - -@pytest.mark.skipif(not HAVE_KITCHEN, reason="Kitchen required for using kitchen backend.") -def test_parse_qlinear_params_example() -> None: - qat_params = 2 - config = {"kitchen_config_type": "QLinearParams", "recipe_idx": qat_params} - qlinear_params_actual = QLinearParamsConfigSchema.parse_config_dict(config).to_kitchen_qlinear() - qlinear_params_expected = get_qlinear_params_from_predefined(QuantizeRecipe.FP8_CS) - assert qlinear_params_actual.x_params == qlinear_params_expected.x_params - assert qlinear_params_actual.w_params == qlinear_params_expected.w_params - assert qlinear_params_actual.g_params == qlinear_params_expected.g_params - assert qlinear_params_actual.mm_fprop == qlinear_params_expected.mm_fprop - assert qlinear_params_actual.mm_dgrad == qlinear_params_expected.mm_dgrad - assert qlinear_params_actual.mm_wgrad == qlinear_params_expected.mm_wgrad - assert ( - qlinear_params_actual.autograd_function_implementation - == AutogradFunctionImplementation.QUANTIZED - ) - - -@pytest.mark.skipif(not HAVE_KITCHEN, reason="Kitchen required for using kitchen backend.") -def test_error_from_malformed() -> None: - qat_params = 2 - config: Dict[Any, Any] = {"recipe_idx": qat_params} - with pytest.raises(KeyError, match="Missing required keys"): - qlinear_params_actual = QLinearParamsConfigSchema.parse_config_dict(config) - config = {"kitchen_config_type": "QLinearParams"} - with pytest.raises(KeyError, match="Missing required keys"): - qlinear_params_actual = QLinearParamsConfigSchema.parse_config_dict(config) - config = {"kitchen_config_type": "QUnknownParams", "recipe_idx": qat_params} - with pytest.raises(ValueError, match="Unsupported config type"): - qlinear_params_actual = QLinearParamsConfigSchema.parse_config_dict(config) - config = {"kitchen_config_type": "QLinearParams", "recipe_idx": "MyRecipe"} - with pytest.raises(ValueError, match="recipe_idx must be a positive integer"): - qlinear_params_actual = QLinearParamsConfigSchema.parse_config_dict(config) - config = { - "kitchen_config_type": "QLinearParams", - "recipe_idx": qat_params, - "extra_key": "extra_value", - } - with pytest.raises(KeyError, match="Unexpected keys in config"): - qlinear_params_actual = QLinearParamsConfigSchema.parse_config_dict(config) diff --git a/tests/unit_tests/transformer/test_relative_attention.py b/tests/unit_tests/transformer/test_relative_attention.py deleted file mode 100644 index dd1d4c02ab..0000000000 --- a/tests/unit_tests/transformer/test_relative_attention.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -import pytest -import torch -import torch.nn.init as init - -from megatron.core.models.common.embeddings.relative_pos_embedding import RelativePositionEmbedding -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from tests.unit_tests.test_utilities import Utils - - -class TestRelativePositionEmbedding: - def setup_method(self): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - self.num_heads = 12 - self.relative_pos_emb = RelativePositionEmbedding( - bidirectional=True, - init_method=init.normal_, - num_attention_heads=self.num_heads, - relative_attention_num_buckets=32, - relative_attention_max_distance=128, - ) - - def teardown_method(self, method): - del self.relative_pos_emb - Utils.destroy_model_parallel() - - def test_constructor(self): - assert isinstance(self.relative_pos_emb, RelativePositionEmbedding) - - def test_forward(self): - self.query_seq_length = 512 - output = self.relative_pos_emb(self.query_seq_length, self.query_seq_length) - assert output.shape[0] == 1 - assert output.shape[1] == self.num_heads - assert output.shape[2] == self.query_seq_length - assert output.shape[3] == self.query_seq_length diff --git a/tests/unit_tests/transformer/test_retro_attention.py b/tests/unit_tests/transformer/test_retro_attention.py deleted file mode 100644 index e735105111..0000000000 --- a/tests/unit_tests/transformer/test_retro_attention.py +++ /dev/null @@ -1,203 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import os -import types - -import pytest -import torch - -from megatron.core.models.retro import RetroConfig, get_retro_decoder_block_spec -from megatron.core.models.retro.decoder_attention import ( - RetroDecoderBiasDropoutAdd, - RetroDecoderCrossAttention, -) -from megatron.core.models.retro.encoder_attention import ( - RetroEncoderBiasDropoutAdd, - RetroEncoderCrossAttention, - RetroEncoderLayerNorm, -) -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.transformer_block import TransformerBlock -from tests.unit_tests.test_utilities import Utils - - -class TestRetroAttention: - - @classmethod - def get_config(cls): - return RetroConfig( - num_layers=12, - hidden_size=16, - num_attention_heads=4, - use_cpu_initialization=True, - retro_num_neighbors=2, - retro_chunk_length=4, - retro_retrieved_length=8, - retro_split_preprocessing="98,2,0", - ) - - @classmethod - def get_modules(cls, config, use_transformer_engine, use_gpu): - - # Retro decoder layer. - decoder_block_spec = get_retro_decoder_block_spec( - config, use_transformer_engine=use_transformer_engine - ) - decoder_block = TransformerBlock(config=config, spec=decoder_block_spec) - decoder_layers = [ - layer - for layer in decoder_block.layers - if isinstance(layer.cross_attention, RetroDecoderCrossAttention) - ] - decoder_layer = decoder_layers[0] - - # Retro encoder layer. - encoder_block = decoder_layer.cross_attention.encoder - encoder_layers = [ - layer - for layer in encoder_block.layers - if isinstance(layer.cross_attention, RetroEncoderCrossAttention) - ] - encoder_layer = encoder_layers[0] - - # Modules. - modules = types.SimpleNamespace( - decoder_attn=decoder_layer.cross_attention, - decoder_bda=decoder_layer.cross_attn_bda, - encoder_attn=encoder_layer.cross_attention, - encoder_bda=encoder_layer.cross_attn_bda, - encoder_norm=encoder_layer.pre_mlp_layernorm, - ) - - # GPU. - if use_gpu: - [m.cuda() for m in vars(modules).values()] - - return modules - - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - os.environ['NVTE_FLASH_ATTN'] = "0" - os.environ['NVTE_FUSED_ATTN'] = "0" - - model_parallel_cuda_manual_seed(123) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_constructor(self): - - config = self.get_config() - modules = self.get_modules(config, use_transformer_engine=True, use_gpu=False) - - assert isinstance(modules.decoder_attn, RetroDecoderCrossAttention) - assert isinstance(modules.decoder_bda, RetroDecoderBiasDropoutAdd) - assert isinstance(modules.encoder_attn, RetroEncoderCrossAttention) - assert isinstance(modules.encoder_bda, RetroEncoderBiasDropoutAdd) - assert isinstance(modules.encoder_norm, RetroEncoderLayerNorm) - - assert modules.decoder_attn.attn.layer_number == 6 - assert modules.encoder_attn.attn.layer_number == 1 - - get_nparams = lambda m: sum(p.numel() for p in m.parameters()) - assert get_nparams(modules.decoder_attn) == 8768 - assert get_nparams(modules.decoder_bda) == 0 - assert get_nparams(modules.encoder_attn) == 1088 - assert get_nparams(modules.encoder_bda) == 0 - assert get_nparams(modules.encoder_norm) == 32 - - def test_cpu_forward(self): - # we can't currently do this because the global memory buffer is on GPU - pass - - def run_gpu_forward(self, recompute_granularity, use_transformer_engine): - - config = self.get_config() - config.recompute_granularity = recompute_granularity - modules = self.get_modules(config, use_transformer_engine, use_gpu=True) - - seq_length = 32 - micro_batch_size = 2 - n_chunks_per_sample = seq_length // config.retro_chunk_length - - # Init tensors. - hidden_states = torch.ones((seq_length, micro_batch_size, config.hidden_size)).cuda() - attention_mask = None - decoder_context = torch.ones( - ( - config.retro_retrieved_length, - config.retro_num_neighbors * micro_batch_size * n_chunks_per_sample, - config.hidden_size, - ) - ).cuda() - encoder_context = torch.ones( - (config.retro_chunk_length, micro_batch_size * n_chunks_per_sample, config.hidden_size) - ).cuda() - - # Forward decoder. - decoder_attn_output = modules.decoder_attn(hidden_states, attention_mask, decoder_context) - with torch.enable_grad(): - decoder_bda_output = modules.decoder_bda(True, True)( - decoder_attn_output, hidden_states, config.hidden_dropout - ) - - # Forward encoder. - encoder_attn_output_tuples = modules.encoder_attn(decoder_context, None, encoder_context) - with torch.enable_grad(): - encoder_bda_output = modules.encoder_bda(True, True)( - encoder_attn_output_tuples, decoder_context, config.retro_encoder_hidden_dropout - ) - encoder_norm_output = modules.encoder_norm(encoder_bda_output) - - # Verify decoder. - assert set(decoder_attn_output.keys()) == set( - ["ns", "bs", "d", "l", "pad", "attention_output", "attention_bias", "context"] - ) - assert decoder_attn_output["ns"] == seq_length - assert decoder_attn_output["bs"] == micro_batch_size - assert decoder_attn_output["d"] == config.hidden_size - assert decoder_attn_output["l"] == n_chunks_per_sample - assert decoder_attn_output["pad"] == 3 - assert tuple(decoder_attn_output["attention_output"].shape) == ( - config.retro_chunk_length, - micro_batch_size * n_chunks_per_sample, - config.hidden_size, - ) - assert tuple(decoder_attn_output["attention_bias"].shape) == (config.hidden_size,) - assert decoder_attn_output["context"].shape == ( - config.retro_retrieved_length * config.retro_num_neighbors, - micro_batch_size * n_chunks_per_sample, - config.hidden_size, - ) - assert decoder_bda_output.shape == hidden_states.shape - - # Verify encoder. - assert len(encoder_attn_output_tuples) == config.retro_num_neighbors - for output, bias, residual in encoder_attn_output_tuples: - assert tuple(output.shape) == ( - config.retro_retrieved_length, - micro_batch_size * n_chunks_per_sample, - config.hidden_size, - ) - assert tuple(bias.shape) == (config.hidden_size,) - assert tuple(residual.shape) == ( - config.retro_retrieved_length, - micro_batch_size * n_chunks_per_sample, - config.hidden_size, - ) - assert encoder_bda_output.shape == ( - config.retro_retrieved_length, - config.retro_num_neighbors * micro_batch_size * n_chunks_per_sample, - config.hidden_size, - ) - assert encoder_norm_output.shape == ( - config.retro_retrieved_length, - config.retro_num_neighbors * micro_batch_size * n_chunks_per_sample, - config.hidden_size, - ) - - @pytest.mark.flaky_in_dev - def test_gpu_forward(self): - for recompute_granularity in (None, 'selective'): - for use_transformer_engine in (True, False): - self.run_gpu_forward(recompute_granularity, use_transformer_engine) diff --git a/tests/unit_tests/transformer/test_rope.py b/tests/unit_tests/transformer/test_rope.py deleted file mode 100644 index a088427ad8..0000000000 --- a/tests/unit_tests/transformer/test_rope.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -import pytest -import torch - -from megatron.core.models.common.embeddings.rotary_pos_embedding import ( - MultimodalRotaryEmbedding, - RotaryEmbedding, -) -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from tests.unit_tests.test_utilities import Utils - - -class TestMultimodalRotaryEmbedding: - def setup_method(self): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - self.kv_channels = 128 - self.rotary_percent = 1.0 - self.rope_gpu_init = MultimodalRotaryEmbedding(self.kv_channels, self.rotary_percent) - - def teardown_method(self, method): - del self.rope_gpu_init - Utils.destroy_model_parallel() - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_constructor(self): - assert isinstance(self.rope_gpu_init, MultimodalRotaryEmbedding) - assert self.rope_gpu_init.inv_freq.device.type == 'cuda' - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_gpu_forward(self): - output = self.rope_gpu_init(torch.Tensor(3, 1, 64), mrope_section=[16, 24, 24]) - assert output.shape[0] == 64 - assert output.shape[1] == 1 - assert output.shape[2] == 1 - assert output.shape[3] == self.kv_channels - assert output.dtype == torch.float32 - assert output.device.type == 'cuda' - - -class TestRotaryEmbedding: - def setup_method(self): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - self.kv_channels = 8 - self.rotary_percent = 1.0 - self.rope_cpu_init = RotaryEmbedding( - self.kv_channels, self.rotary_percent, use_cpu_initialization=True - ) - self.rope_gpu_init = RotaryEmbedding( - self.kv_channels, self.rotary_percent, use_cpu_initialization=False - ) - - def teardown_method(self, method): - del self.rope_gpu_init - del self.rope_cpu_init - Utils.destroy_model_parallel() - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_constructor(self): - assert isinstance(self.rope_cpu_init, RotaryEmbedding) - assert self.rope_cpu_init.inv_freq.device.type == 'cpu' - assert isinstance(self.rope_gpu_init, RotaryEmbedding) - assert self.rope_gpu_init.inv_freq.device.type == 'cuda' - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_gpu_forward(self): - output = self.rope_gpu_init(64) - assert output.shape[0] == 64 - assert output.shape[1] == 1 - assert output.shape[2] == 1 - assert output.shape[3] == self.kv_channels - assert output.dtype == torch.float32 - assert output.device.type == 'cuda' - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_cpu_forward(self): - output = self.rope_cpu_init(64) - assert output.shape[0] == 64 - assert output.shape[1] == 1 - assert output.shape[2] == 1 - assert output.shape[3] == self.kv_channels - assert output.dtype == torch.float32 - assert output.device.type == 'cuda' diff --git a/tests/unit_tests/transformer/test_spec_customization.py b/tests/unit_tests/transformer/test_spec_customization.py deleted file mode 100755 index e2c5f47d6c..0000000000 --- a/tests/unit_tests/transformer/test_spec_customization.py +++ /dev/null @@ -1,296 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import sys -from dataclasses import dataclass, fields - -import pytest -import torch -import transformer_engine as te - -from megatron.core.extensions.transformer_engine import ( - TEDotProductAttention, - TELayerNormColumnParallelLinear, - TENorm, - TERowParallelLinear, -) -from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec -from megatron.core.parallel_state import get_context_parallel_group, get_tensor_model_parallel_group -from megatron.core.process_groups_config import ModelCommProcessGroups -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules -from megatron.core.transformer.dot_product_attention import DotProductAttention -from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.identity_op import IdentityFuncOp, IdentityOp -from megatron.core.transformer.spec_utils import ModuleSpec, build_module, import_module -from megatron.core.transformer.torch_norm import L2Norm -from megatron.core.transformer.transformer_block import TransformerBlock, TransformerBlockSubmodules -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules -from megatron.core.utils import is_te_min_version -from tests.unit_tests.test_utilities import Utils - - -class TestSpecCustomization: - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - self.config = TransformerConfig( - num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True - ) - - # specify Transformer Layer spec with all identity ops - self.transformer_layer_spec = TransformerLayerSubmodules() - - # specify attention spec using already imported class - self.attention_spec = ModuleSpec( - module=SelfAttention, - params={"attn_mask_type": AttnMaskType.causal}, - submodules=SelfAttentionSubmodules( - linear_qkv=TELayerNormColumnParallelLinear, - core_attention=TEDotProductAttention, - linear_proj=TERowParallelLinear, - q_layernorm=IdentityOp, - k_layernorm=IdentityOp, - ), - ) - - # specify layernorm spec with module path to test dynamic importing - self.layernorm_spec = ModuleSpec( - module=("megatron.core.extensions.transformer_engine", "TENorm") - ) - - # specify bias dropout add with module path - self.bda_spec = ModuleSpec( - module=("megatron.core.fusions.fused_bias_dropout", "get_bias_dropout_add") - ) - - # Create model process groups for test. - self.model_comm_pgs = ModelCommProcessGroups( - tp=get_tensor_model_parallel_group(), cp=get_context_parallel_group() - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_import_module(self): - self_attention_cls = import_module( - module_path=('megatron.core.transformer.attention', 'SelfAttention') - ) - assert id(self_attention_cls) == id(SelfAttention) - - layernorm_cls = import_module(module_path=self.layernorm_spec.module) - assert id(layernorm_cls) == id(TENorm) - - def test_build_module(self): - # Check NoOp TransformerLayer - random_input = 12 - noop_transformer_layer = [ - build_module(getattr(self.transformer_layer_spec, field.name)) - for field in fields(self.transformer_layer_spec) - if field.name != 'sharded_state_dict_keys_map' - ] - - x = random_input - for mod in noop_transformer_layer: - # checking for `IdentityFuncOp` before `IdentityOp` because former - # is derived from the latter and so the second if statement will - # always be `True`. - if isinstance(mod, IdentityFuncOp): - x = mod()(x) - elif isinstance(mod, IdentityOp): - x = mod(x) - - assert x == random_input - - # Check SelfAttention - self_attention = build_module(self.attention_spec, config=self.config, layer_number=1) - assert isinstance(self_attention, SelfAttention) - assert self_attention.layer_number == 1 - assert self_attention.attn_mask_type == self.attention_spec.params['attn_mask_type'] - - num_weights = sum([p.numel() for p in self_attention.parameters()]) - assert num_weights == 648 - - # Check SelfAttention but with already initialized module - # `self_attention`. In this test, `build_module` acts as a no op as it - # simply returns the initialized module. - # NOTE: (sudhakars) Uncomment this test once this feature gets added - # back. - # self_attention2 = build_module( - # self_attention, config=self.config, spec=self.attention_spec, - # ) - # assert isinstance(self_attention2, SelfAttention) - # assert self_attention2.layer_number == 1 - # assert self_attention2.attn_mask_type == self.attention_spec.params['attn_mask_type'] - - # num_weights = sum([p.numel() for p in self_attention2.parameters()]) - # assert num_weights == 648 - - # Check LayerNorm - layernorm = build_module( - self.layernorm_spec, - config=self.config, - hidden_size=self.config.hidden_size, - eps=self.config.layernorm_epsilon, - ) - assert isinstance(layernorm, te.pytorch.LayerNorm) - - # Check BiasDropoutAdd - bda_op = build_module(self.bda_spec) - assert id(bda_op) == id(get_bias_dropout_add) - - def test_sliding_window_attention(self): - if not is_te_min_version("1.2.0"): - print("SWA not tested because TE version is not >= 1.2.0", file=sys.stderr) - return - - config = TransformerConfig( - num_layers=2, - hidden_size=12, - num_attention_heads=4, - use_cpu_initialization=True, - window_size=[10, 0], - ) - # Make sure DotProductAttention throws (swa unsupported). - threw = False - try: - attn = DotProductAttention( - config, - layer_number=1, - attn_mask_type=AttnMaskType.causal, - attention_type='self', - model_comm_pgs=self.model_comm_pgs, - ) - except: - threw = True - finally: - assert threw, 'Expected DotProductAttention to throw exception for SWA' - - # Test TEDotProductAttention - attn = TEDotProductAttention( - config, - layer_number=1, - attn_mask_type=AttnMaskType.causal, - attention_type='self', - model_comm_pgs=self.model_comm_pgs, - ) - # Make sure window-size is what we expect. - assert attn.window_size == config.window_size - - # Single integer window-size unsupported, make sure it throws - threw = False - try: - config.window_size = 11 - attn = TEDotProductAttention( - config, - layer_number=1, - attn_mask_type=AttnMaskType.causal, - attention_type='self', - model_comm_pgs=self.model_comm_pgs, - ) - except: - threw = True - finally: - assert threw, "Expected TEDotProductAttention to throw for integer window-size" - - # `None` makes this causal. - config.window_size = None - attn = TEDotProductAttention( - config, - layer_number=1, - attn_mask_type=AttnMaskType.causal, - attention_type='self', - model_comm_pgs=self.model_comm_pgs, - ) - # Make sure it's causal. - assert attn.window_size == (-1, 0) - - def test_transformer_block_custom(self): - """ - This test checks that the two ways of passing `layer_spec` to a - `TransformerBlock` result in an identical model: - 1. ModuleSpec(module=..., submodules=...) - 2. TransformerBlockSubmodules(layer_specs=...) - """ - - transformer_config = TransformerConfig( - num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True - ) - layer_local_spec = get_gpt_layer_local_spec() - - # The following way can be used to pass a different `TransformerLayer` - # and internally the `TransformerBlock` would fan out the single - # `ModuleSpec` layer spec provided to all the layers of the block. - layer_spec1 = ModuleSpec(module=TransformerLayer, submodules=layer_local_spec.submodules) - model_parallel_cuda_manual_seed(123) - torch.manual_seed(0) - parallel_transformer_block1 = TransformerBlock(transformer_config, layer_spec1) - - layer_spec2 = TransformerBlockSubmodules( - layer_specs=[ - ModuleSpec(module=TransformerLayer, submodules=layer_local_spec.submodules) - ] - * transformer_config.num_layers, - layer_norm=TENorm, - ) - # make sure the model init conditions are identical - model_parallel_cuda_manual_seed(123) - torch.manual_seed(0) - parallel_transformer_block2 = TransformerBlock(transformer_config, layer_spec2) - - sequence_length = 32 - micro_batch_size = 2 - parallel_transformer_block1.cuda() - parallel_transformer_block2.cuda() - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones( - (sequence_length, micro_batch_size, transformer_config.hidden_size) - ) - hidden_states = hidden_states.cuda() - - attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() - - out1 = parallel_transformer_block1( - hidden_states=hidden_states, attention_mask=attention_mask - ) - out2 = parallel_transformer_block2( - hidden_states=hidden_states, attention_mask=attention_mask - ) - - assert torch.all(torch.eq(out1, out2)) - assert out1.shape[0] == sequence_length == out2.shape[0] - assert out1.shape[1] == micro_batch_size == out2.shape[1] - assert out1.shape[2] == transformer_config.hidden_size == out2.shape[2] - - def test_l2_qk_norm(self): - """Test L2 normalization for QK vectors using local spec.""" - layer_spec = get_gpt_layer_local_spec(qk_l2_norm=True) - - # Build the self-attention module from the spec - self_attention = build_module( - layer_spec.submodules.self_attention, config=self.config, layer_number=1 - ) - - assert isinstance(self_attention, SelfAttention) - # Verify that q_layernorm and k_layernorm are L2Norm instances - assert isinstance(self_attention.q_layernorm, L2Norm) - assert isinstance(self_attention.k_layernorm, L2Norm) - - # Test forward pass - sequence_length = 32 - micro_batch_size = 2 - self_attention.cuda() - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones( - (sequence_length, micro_batch_size, self.config.hidden_size) - ).cuda() - - attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() - - output, bias = self_attention(hidden_states=hidden_states, attention_mask=attention_mask) - - # Assert output shape is same as input shape - assert output.shape == hidden_states.shape diff --git a/tests/unit_tests/transformer/test_submodule_callables.py b/tests/unit_tests/transformer/test_submodule_callables.py deleted file mode 100644 index d0f5ad12d3..0000000000 --- a/tests/unit_tests/transformer/test_submodule_callables.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -import pytest -import torch - -from megatron.core.models.gpt.fine_grained_callables import build_layer_callables -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec -from megatron.core.transformer.transformer_layer import TransformerLayer -from megatron.core.utils import is_te_min_version -from tests.unit_tests.a2a_overlap.utils import ( - DummyNode, - build_data, - compare_captures, - deterministic_mode, - get_test_config, - get_valid_token_dispatcher_types, - reset_model, -) -from tests.unit_tests.test_utilities import Utils - - -def run_model_ref_with_capture(model, input_tensors, iterations): - """ - Runs the model in reference mode and captures outputs and gradients. - - Args: - model: The transformer model to run. - input_tensors: List of input tensors for each iteration. - iterations: Number of iterations to run the model. - - Returns: - dict: A dictionary containing model outputs and parameter gradients. - """ - - output_tensors = [] - for i in range(iterations): - output = model(input_tensors[i].clone())[0] - output_tensors.append(output) - output.backward(torch.ones_like(output)) - - capture = {"outputs": output_tensors} - for name, param in model.named_parameters(): - capture[name] = param.grad - - return capture - - -def run_model_submodules_with_capture(model, input_tensors, microbatches): - """ - Runs the model with all-to-all overlap optimization and captures outputs and gradients. - - Args: - model: The transformer model to run. - input_tensors: List of input tensors for each microbatch. - microbatches: Number of microbatches to process. - - Returns: - dict: A dictionary containing model outputs and parameter gradients. - """ - - for i in range(len(input_tensors)): - input_tensors[i] = input_tensors[i].clone() - - output_tensors = [] - # get callables - callables, dw = build_layer_callables(model) - attn, post_attn, dispatch, moe, combine, post_process = callables - assert post_process is None - for i in range(microbatches): - # build mock func/state - node = DummyNode() - - # attn fwd - hidden_states = attn(node, input_tensors[i]) - - # post attn fwd - local_tokens, probs = post_attn(node, hidden_states) - - # dispatch fwd - dispatched_tokens = dispatch(node, local_tokens, probs) - - # moe fwd - expert_outputs = moe(node, dispatched_tokens) - if model.mlp.use_shared_expert: - expert_output, shared_expert_output = expert_outputs - else: - expert_output = expert_outputs - shared_expert_output = None - - # combine fwd - hidden_states = combine(node, expert_output, shared_expert_output) - - # loss - output_tensors.append(hidden_states) - hidden_states.backward(torch.ones_like(hidden_states)) - - capture = {"outputs": output_tensors} - for name, param in model.named_parameters(): - capture[name] = param.grad - - return capture - - -class TestTransformerLayerSubmoduleCallables: - """ - Test class for transformer layer submodule callables. - - This class contains tests to verify that the transformer layer submodule callables - provide the same results as the reference implementation. - """ - - def setup_method(self, method): - pass - - def teardown_method(self, method): - pass - - @pytest.mark.skipif(not is_te_min_version("1.9.0.dev0"), reason="Requires TE >= 1.9.0.dev0") - @pytest.mark.parametrize("dispatcher_type", get_valid_token_dispatcher_types()) - @pytest.mark.parametrize("grouped_gemm", [True, False]) - @pytest.mark.parametrize("permute_fusion", [True, False]) - def test_1f1b_overlap(self, dispatcher_type, grouped_gemm, permute_fusion): - """ - Tests the 1-forward-1-backward overlap optimization. - - This test verifies that the all-to-all overlap optimization produces - the same results as the reference implementation. - """ - - Utils.initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=4, - expert_model_parallel_size=2, - virtual_pipeline_model_parallel_size=2, - ) - extra_kwargs = { - "moe_token_dispatcher_type": dispatcher_type, - "moe_permute_fusion": permute_fusion, - } - if dispatcher_type == "flex": - extra_kwargs["moe_enable_deepep"] = True - extra_kwargs["moe_router_dtype"] = "fp32" - config = get_test_config(extra_kwargs=extra_kwargs, moe_grouped_gemm=grouped_gemm) - microbatches = 4 - with deterministic_mode(): - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - num_experts=8, - moe_grouped_gemm=grouped_gemm, - qk_layernorm=True, - multi_latent_attention=True, - ) - model = TransformerLayer(config, transformer_layer_spec.submodules) - - params = reset_model(model) - input_tensors = [build_data() for _ in range(microbatches)] - - capture_ref = run_model_ref_with_capture(model, input_tensors, microbatches) - reset_model(model, params) - capture_callables = run_model_submodules_with_capture( - model, input_tensors, microbatches - ) - comp_res = compare_captures(capture_ref, capture_callables, True) - assert comp_res[0], f"[rank {torch.distributed.get_rank()}] {comp_res[1]}" - Utils.destroy_model_parallel() diff --git a/tests/unit_tests/transformer/test_transformer_block.py b/tests/unit_tests/transformer/test_transformer_block.py deleted file mode 100644 index 48b678c5fd..0000000000 --- a/tests/unit_tests/transformer/test_transformer_block.py +++ /dev/null @@ -1,613 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import copy -from contextlib import nullcontext - -import pytest -import torch -from packaging import version - -from megatron.core import mpu, parallel_state -from megatron.core.fp8_utils import get_fp8_context -from megatron.core.hyper_comm_grid import HyperCommGrid -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.process_groups_config import ModelCommProcessGroups -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.enums import ModelType -from megatron.core.transformer.pipeline_parallel_layer_layout import PipelineParallelLayerLayout -from megatron.core.transformer.spec_utils import build_module -from megatron.core.transformer.transformer_block import TransformerBlock, get_num_layers_to_build -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.transformer.transformer_layer import TransformerLayer -from tests.unit_tests.test_utilities import Utils - - -class TestParallelTransformerBlock: - - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - self.transformer_config = TransformerConfig( - num_layers=2, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True - ) - self.parallel_transformer_block = TransformerBlock( - self.transformer_config, get_gpt_layer_with_transformer_engine_spec() - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_constructor(self): - parallel_transformer_block = self.parallel_transformer_block - assert isinstance(parallel_transformer_block, TransformerBlock) - num_weights = sum([p.numel() for p in parallel_transformer_block.parameters()]) - assert num_weights == 100096 - assert parallel_transformer_block.num_layers_per_pipeline_rank == 2 - assert len(parallel_transformer_block.layers) == 2 - layer_0: TransformerLayer = parallel_transformer_block._get_layer(0) - assert layer_0.layer_number == 1 - layer_1: TransformerLayer = parallel_transformer_block._get_layer(1) - assert layer_1.layer_number == 2 - - def test_gpu_forward(self): - parallel_transformer_block = self.parallel_transformer_block - config: TransformerConfig = parallel_transformer_block.config - - sequence_length = 32 - micro_batch_size = 2 - parallel_transformer_block.cuda() - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones((sequence_length, micro_batch_size, config.hidden_size)) - hidden_states = hidden_states.cuda() - - attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() - - hidden_states = parallel_transformer_block( - hidden_states=hidden_states, attention_mask=attention_mask - ) - assert hidden_states.shape[0] == sequence_length - assert hidden_states.shape[1] == micro_batch_size - assert hidden_states.shape[2] == config.hidden_size - - def test_gpu_forward_full_checkpoint(self): - self._run_full_checkpoint_test(fp8=None) - - def test_gpu_forward_full_checkpoint_fp8(self): - self._run_full_checkpoint_test(fp8="e4m3") - - def test_gpu_forward_selective_checkpoint(self): - self._run_selective_checkpoint_test(fp8=None) - - def test_gpu_forward_selective_checkpoint_fp8(self): - self._run_selective_checkpoint_test(fp8="e4m3") - - def _run_full_checkpoint_test(self, fp8): - transformer_config = self.transformer_config - config = transformer_config - config.recompute_granularity = 'full' - config.recompute_method = 'block' - config.fp8 = fp8 - config.recompute_num_layers = config.num_layers - full_transformer_block = TransformerBlock( - config, get_gpt_layer_with_transformer_engine_spec() - ) - assert full_transformer_block.config.recompute_granularity == 'full' - assert full_transformer_block.config.recompute_method == 'block' - assert full_transformer_block.config.fp8 == fp8 - - sequence_length = 32 - micro_batch_size = 2 - full_transformer_block.cuda() - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones((sequence_length, micro_batch_size, config.hidden_size)) - hidden_states = hidden_states.cuda() - - attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() - - hidden_states = full_transformer_block( - hidden_states=hidden_states, attention_mask=attention_mask - ) - assert hidden_states.shape[0] == sequence_length - assert hidden_states.shape[1] == micro_batch_size - assert hidden_states.shape[2] == config.hidden_size - - def _run_selective_checkpoint_test(self, fp8): - transformer_config = self.transformer_config - config = transformer_config - config.recompute_granularity = 'selective' - config.fp8 = fp8 - selective_transformer_block = TransformerBlock( - config, get_gpt_layer_with_transformer_engine_spec() - ) - assert selective_transformer_block.config.recompute_granularity == 'selective' - assert "core_attn" in selective_transformer_block.config.recompute_modules - assert selective_transformer_block.checkpoint_core_attention - assert selective_transformer_block.config.fp8 == fp8 - - sequence_length = 32 - micro_batch_size = 2 - selective_transformer_block.cuda() - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones((sequence_length, micro_batch_size, config.hidden_size)) - hidden_states = hidden_states.cuda() - - attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() - - hidden_states = selective_transformer_block( - hidden_states=hidden_states, attention_mask=attention_mask - ) - assert hidden_states.shape[0] == sequence_length - assert hidden_states.shape[1] == micro_batch_size - assert hidden_states.shape[2] == config.hidden_size - - -class TestPipelineParallelTransformerBlock: - @pytest.mark.parametrize( - "num_layers, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size, " - "account_for_embedding_in_pipeline_split, account_for_loss_in_pipeline_split, " - "first_pipeline_num_layers, last_pipeline_num_layers, should_assert_error", - [ - # Last pipeline stage has specified layers - (60, 5, None, False, False, None, 4, False), - # Uneven PP 6*[8]+[6]+[6]=60 - (60, 8, None, False, False, 6, 6, False), - # Even PP - (64, 4, None, False, False, None, None, False), - # Even VPP - (64, 4, 8, False, False, None, None, False), - # First pipeline stage has specified layers - # Should distribute remaining layers evenly among other stages - (60, 6, None, False, False, 5, None, False), - # Uneven distribution leading to assertion error - (101, 8, None, False, False, 13, 13, True), - # Include embedding in pipeline split without virtual PP - (63, 4, None, True, False, None, None, False), - # Include loss in pipeline split without virtual PP - (63, 4, None, False, True, None, None, False), - # Include embedding and loss in pipeline split without virtual PP - (62, 4, None, True, True, None, None, False), - # Include embedding and loss with virtual PP - (62, 4, 2, True, True, None, None, False), - # num_layers not divisible by pipeline size without embedding/loss - (65, 4, None, False, False, None, None, True), - # num_layers not divisible by pipeline size with embedding/loss - (65, 4, None, True, True, None, None, True), - # Uneven distribution with specified first pipeline layers causing error - (61, 4, None, False, False, 12, None, True), - # Too few layers for the number of pipeline stages - (2, 4, None, False, False, None, None, True), - # Uneven PP with embedding included (should assert per code) - (60, 6, None, True, False, 5, 5, True), - # Virtual PP where num_layers not divisible by total virtual stages - (50, 2, 7, False, False, None, None, True), - # Edge case where num_layers per virtual rank is zero - (4, 4, 4, False, False, None, None, True), - ], - ) - def test_layer_builder( - self, - num_layers, - pipeline_model_parallel_size, - virtual_pipeline_model_parallel_size, - account_for_embedding_in_pipeline_split, - account_for_loss_in_pipeline_split, - first_pipeline_num_layers, - last_pipeline_num_layers, - should_assert_error, - ): - Utils.fake_initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=pipeline_model_parallel_size, - virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size, - ) - context = ( - pytest.raises((AssertionError, ValueError)) if should_assert_error else nullcontext() - ) - with context: - transformer_config = TransformerConfig( - num_layers=num_layers, - pipeline_model_parallel_size=pipeline_model_parallel_size, - virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size, - account_for_embedding_in_pipeline_split=account_for_embedding_in_pipeline_split, - account_for_loss_in_pipeline_split=account_for_loss_in_pipeline_split, - num_layers_in_first_pipeline_stage=first_pipeline_num_layers, - num_layers_in_last_pipeline_stage=last_pipeline_num_layers, - pipeline_dtype=torch.bfloat16, - hidden_size=128, - num_attention_heads=16, - ) - total_build_layers = 0 - for i in range(pipeline_model_parallel_size): - parallel_state.set_pipeline_model_parallel_rank(i) - if virtual_pipeline_model_parallel_size is not None: - for j in range(virtual_pipeline_model_parallel_size): - num_layers_to_build = get_num_layers_to_build(transformer_config, j) - total_build_layers += num_layers_to_build - else: - num_layers_to_build = get_num_layers_to_build(transformer_config) - total_build_layers += num_layers_to_build - if not should_assert_error: - assert ( - total_build_layers == num_layers - ), f"total build layers {total_build_layers} should be equal to num_layers {num_layers}" - parallel_state.set_pipeline_model_parallel_world_size(None) - parallel_state.set_virtual_pipeline_model_parallel_world_size(None) - - -class TestProcessGroupTransformerBlock: - def setup_method(self, method): - Utils.destroy_model_parallel() - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.skipif( - version.parse(torch.__version__) < version.parse('2.3.0'), - reason="Device mesh feature requires PyTorch 2.3 or later", - ) - @pytest.mark.parametrize( - "tp_size,cp_size,dp_size,use_custom_pg", - [(2, 2, 2, True), (2, 4, 1, True), (2, 2, 2, False), (2, 4, 1, False)], - ) - def test_pg_input_args(self, tp_size, cp_size, dp_size, use_custom_pg): - """ - Test TransformerBlock with custom process groups. - """ - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp_size, context_parallel_size=cp_size - ) - model_parallel_cuda_manual_seed(123) - if use_custom_pg: - # Initialize torch.distributed if not already initialized - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend='nccl') - - # Create HyperCommGrid with dimensions cp, tp, dp (reversed from device mesh order) - grid = HyperCommGrid([cp_size, tp_size, dp_size], ["cp", "tp", "dp"]) - - # Get process groups from HyperCommGrid - tp_group = grid.create_pg("tp") - cp_group = grid.create_pg("cp") - - # Create ModelCommProcessGroups with custom process groups - model_comm_pgs = ModelCommProcessGroups(tp=tp_group, cp=cp_group) - else: - # Rely on TransformerBlock to create default process groups - model_comm_pgs = None - - self.transformer_config = TransformerConfig( - num_layers=2, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True - ) - self.transformer_block = TransformerBlock( - self.transformer_config, - get_gpt_layer_with_transformer_engine_spec(), - model_comm_pgs=model_comm_pgs, - ) - self.transformer_block.cuda() - - sequence_length = 128 - micro_batch_size = 1 - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones( - (sequence_length, micro_batch_size, self.transformer_block.config.hidden_size), - device="cuda", - ) - - hidden_states = self.transformer_block(hidden_states=hidden_states, attention_mask=None) - - assert hidden_states.shape[0] == sequence_length - assert hidden_states.shape[1] == micro_batch_size - assert hidden_states.shape[2] == self.transformer_block.config.hidden_size - - -class TestMixedProcessGroups: - def setup_method(self, method): - Utils.destroy_model_parallel() - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.skipif( - version.parse(torch.__version__) < version.parse('2.3.0'), - reason="Device mesh feature requires PyTorch 2.3 or later", - ) - @pytest.mark.parametrize("tp_size,cp_size", [(2, 4)]) - def test_mixed_pg_transformer_block(self, tp_size, cp_size, monkeypatch): - """ - Test TransformerBlock with custom process groups. - """ - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp_size, context_parallel_size=cp_size - ) - model_parallel_cuda_manual_seed(123) - - # Create a new build_layers method that uses interleaved attention - def _build_layers_with_interleaved_attention(self): - def build_layer(layer_spec, layer_number): - fp8_init_context = get_fp8_context(self.config, layer_number - 1, is_init=True) - if layer_number % 4 == 0: - config = self.local_attn_config - model_comm_pgs = self.local_pgs - else: - config = self.config - model_comm_pgs = self.model_comm_pgs - with fp8_init_context: - module = build_module( - layer_spec, - config=config, - layer_number=layer_number, - model_comm_pgs=model_comm_pgs, - ) - return module - - # Modify TransformerConfig and ModelCommProcessGroups for local attention - self.local_attn_config = copy.deepcopy(self.config) - self.local_pgs = ModelCommProcessGroups.use_mpu_process_groups() - self.local_attn_config.context_parallel_size = 1 - self.local_pgs.cp = torch.distributed.new_group(ranks=[torch.distributed.get_rank()]) - - # offset is implicit in TransformerLayer - self.layers = torch.nn.ModuleList( - [ - build_layer(layer_spec, i + 1) - for i, layer_spec in enumerate(self.submodules.layer_specs) - ] - ) - - # Copied from TransformerBlock.build_layers - if self.submodules.layer_norm and self.post_process and self.post_layer_norm: - self.final_layernorm = build_module( - self.submodules.layer_norm, - config=self.config, - hidden_size=self.config.hidden_size, - eps=self.config.layernorm_epsilon, - ) - else: - self.final_layernorm = None # Either this or nn.Identity - - # Replace the default build_layers method - monkeypatch.setattr( - TransformerBlock, "_build_layers", _build_layers_with_interleaved_attention - ) - - self.transformer_config = TransformerConfig( - num_layers=4, - hidden_size=64, - num_attention_heads=4, - use_cpu_initialization=True, - context_parallel_size=cp_size, - bf16=True, - ) - self.transformer_block = TransformerBlock( - self.transformer_config, get_gpt_layer_with_transformer_engine_spec() - ) - self.transformer_block.cuda().bfloat16() - - sequence_length = 128 - micro_batch_size = 1 - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones( - (sequence_length, micro_batch_size, self.transformer_block.config.hidden_size), - dtype=torch.bfloat16, - device="cuda", - ) - - hidden_states = self.transformer_block(hidden_states=hidden_states, attention_mask=None) - - assert hidden_states.shape[0] == sequence_length - assert hidden_states.shape[1] == micro_batch_size - assert hidden_states.shape[2] == self.transformer_block.config.hidden_size - - -class TestPipelineParallelLayoutTransformerBlock: - @pytest.mark.parametrize( - "num_layers, pp_size, vpp_size, pipeline_model_parallel_layout, should_assert_error", - [ - # No embedding layer provided - (7, 2, 1, [["decoder"] * 6, ["decoder", "loss"]], True), - # No loss layer provided - (7, 2, 1, [["embedding"] + ["decoder"] * 6, ["decoder"]], True), - # Invalid layer type - (7, 2, 1, [["embedding"], ["invalid_type"] * 7 + ["loss"]], True), - # Invalid pp size - (7, 2, 2, [["embedding"], ["decoder"] * 7, ["loss"]], True), - # Invalid layout - ( - 7, - 2, - 2, - [[["embedding", "decoder"], ["decoder"] * 4], ["decoder"], ["decoder", "loss"]], - True, - ), - # Invalid layout - ( - 7, - 2, - 1, - [[["embedding", "decoder"], ["decoder"] * 4], ["decoder"] * 2 + ["loss"]], - True, - ), - # Invalid layout - (7, 2, 1, [[["embedding"] + ["decoder"] * 5], ["decoder"] * 2 + ["loss"]], True), - # Usual pp case - ( - 7, - 2, - 2, - [ - [["embedding", "decoder"], ["decoder"] * 3], - [["decoder"] * 2, ["decoder", "loss"]], - ], - True, - ), - # Usual pp case - ( - 7, - 2, - 2, - [["embedding", "decoder"], ["decoder"] * 4, ["decoder"], ["decoder", "loss"]], - False, - ), - # Empty stage - (7, 2, 2, [["embedding"], ["decoder"] * 7, [], ["loss"]], False), - # Usual uneven vpp case with standalone embedding and loss layer - (7, 2, 2, [["embedding"], ["decoder"] * 6, ["decoder"], ["loss"]], False), - ], - ) - def test_layer_builder( - self, num_layers, pp_size, vpp_size, pipeline_model_parallel_layout, should_assert_error - ): - Utils.fake_initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=pp_size, - virtual_pipeline_model_parallel_size=vpp_size, - ) - context = ( - pytest.raises((AssertionError, ValueError)) if should_assert_error else nullcontext() - ) - with context: - transformer_config = TransformerConfig( - num_layers=num_layers, - pipeline_model_parallel_layout=pipeline_model_parallel_layout, - pipeline_model_parallel_size=pp_size, - pipeline_dtype=torch.bfloat16, - hidden_size=128, - num_attention_heads=16, - ) - total_build_layers = 0 - for i in range(pp_size): - parallel_state.set_pipeline_model_parallel_rank(i) - for j in range(vpp_size): - total_build_layers += get_num_layers_to_build(transformer_config, vp_stage=j) - if not should_assert_error: - assert ( - total_build_layers == num_layers - ), f"total build layers {total_build_layers} should be equal to num_layers {num_layers}" - parallel_state.set_pipeline_model_parallel_world_size(None) - parallel_state.set_virtual_pipeline_model_parallel_world_size(None) - - @pytest.mark.parametrize( - ('pipeline_model_parallel_layout', 'layer_number_golden_answer'), - [ - ( - [ - ["embedding"], - ["decoder"], - ["decoder"] * 2, - ["decoder"], - [], - ["decoder"], - ["decoder"], - ["decoder"] * 2 + ["loss"], - ], - [[[], []], [[1], [5]], [[2, 3], [6]], [[4], [7, 8]]], - ) - ], - ) - def test_layout_layer_number(self, pipeline_model_parallel_layout, layer_number_golden_answer): - tp_size = 1 - pp_size = 4 - vpp_size = 2 - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp_size, - pipeline_model_parallel_size=pp_size, - virtual_pipeline_model_parallel_size=vpp_size, - ) - model_parallel_cuda_manual_seed(123) - torch.manual_seed(123) - - # Initialize GPT model - default_config_kwargs = dict( - num_layers=8, - hidden_size=8, - num_attention_heads=8, - use_cpu_initialization=True, - pipeline_dtype=torch.bfloat16, - bf16=True, - tensor_model_parallel_size=tp_size, - pipeline_model_parallel_size=pp_size, - virtual_pipeline_model_parallel_size=vpp_size, - pipeline_model_parallel_layout=pipeline_model_parallel_layout, - ) - transformer_config = TransformerConfig(**default_config_kwargs) - gpt_model = [] - for i in range(vpp_size): - pre_process = mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=i) - post_process = mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=i) - this_model = GPTModel( - config=transformer_config, - transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec(), - vocab_size=128, - max_sequence_length=4, - pre_process=pre_process, - post_process=post_process, - vp_stage=i, - ) - this_model.model_type = ModelType.encoder_or_decoder - gpt_model.append(this_model) - - pp_rank = parallel_state.get_pipeline_model_parallel_rank() - for vpp_rank in range(vpp_size): - layers = gpt_model[vpp_rank].decoder.layers - layer_numbers = [l.layer_number for l in layers] - golden_answer_curr_stage = layer_number_golden_answer[pp_rank][vpp_rank] - assert len(layers) == len( - golden_answer_curr_stage - ), f"{pp_rank=}, {vpp_rank=}, {len(layers)=}, {len(golden_answer_curr_stage)=}" - assert ( - layer_numbers == golden_answer_curr_stage - ), f"{pp_rank=}, {vpp_rank=}, {layer_numbers=}, {golden_answer_curr_stage=}" - Utils.destroy_model_parallel() - - @pytest.mark.parametrize( - "pp_size, input_layout_str, input_layout_list", - [ - ( - 2, - "Et|t*4|t|tL", - [["embedding", "decoder"], ["decoder"] * 4, ["decoder"], ["decoder", "loss"]], - ), - (2, "E|t*6|t|L", [["embedding"], ["decoder"] * 6, ["decoder"], ["loss"]]), - ( - 4, - "E|t|t*2|t||(t|)*2,t*2,L", - [ - ["embedding"], - ["decoder"], - ["decoder"] * 2, - ["decoder"], - [], - ["decoder"], - ["decoder"], - ["decoder"] * 2 + ["loss"], - ], - ), - ( - 8, - "Et*3|(tt|)*29,m|L", - [["embedding"] + ["decoder"] * 3] + [["decoder"] * 2] * 29 + [["mtp"], ["loss"]], - ), - ( - 16, - "Et*2|(tt|)*29,t|mL", - [["embedding"] + ["decoder"] * 2] - + [["decoder"] * 2] * 29 - + [["decoder"]] - + [["mtp", "loss"]], - ), - ], - ) - def test_parsing_layout_from_str(self, pp_size, input_layout_str, input_layout_list): - parsed_layout_from_str = PipelineParallelLayerLayout.from_str(input_layout_str, pp_size) - parsed_layout_baseline = PipelineParallelLayerLayout(input_layout_list, pp_size) - assert parsed_layout_from_str.layout == parsed_layout_baseline.layout - assert ( - parsed_layout_from_str.virtual_pipeline_model_parallel_size - == parsed_layout_baseline.virtual_pipeline_model_parallel_size - ) diff --git a/tests/unit_tests/transformer/test_transformer_block_custom_pgs.py b/tests/unit_tests/transformer/test_transformer_block_custom_pgs.py deleted file mode 100644 index 1a639ba085..0000000000 --- a/tests/unit_tests/transformer/test_transformer_block_custom_pgs.py +++ /dev/null @@ -1,773 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - - -import copy -import os -from typing import Optional - -import pytest -import torch -from packaging import version - -from megatron.core import parallel_state -from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig -from megatron.core.extensions.transformer_engine import ( - TEDotProductAttention, - TELayerNormColumnParallelLinear, - TERowParallelLinear, -) -from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add -from megatron.core.hyper_comm_grid import HyperCommGrid -from megatron.core.models.gpt.gpt_layer_specs import ( - get_gpt_layer_local_spec, - get_gpt_layer_with_transformer_engine_spec, -) -from megatron.core.process_groups_config import GradCommProcessGroups, ModelCommProcessGroups -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules -from megatron.core.transformer.enums import AttnBackend, AttnMaskType -from megatron.core.transformer.identity_op import IdentityOp -from megatron.core.transformer.mlp import MLP, MLPSubmodules -from megatron.core.transformer.spec_utils import ModuleSpec, build_module -from megatron.core.transformer.transformer_block import TransformerBlock -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules -from tests.unit_tests.test_utilities import Utils - - -class HeterogenousTransformerLayer(TransformerLayer): - """A transformer layer that supports different process groups for attention and MLP. - - This specialized transformer layer implementation allows independent parallelism - strategies for the self-attention and MLP components - - Implementation details: - - Uses identity operations as placeholders during initialization - - Replaces the placeholder modules with properly configured attention and MLP - using their respective process groups - - Requires process groups to be specified in the submodule parameters - - Args: - config (TransformerConfig): Configuration for the transformer layer - submodules (TransformerLayerSubmodules): Submodule specifications with process group params - layer_number (int, optional): Index of this layer. Defaults to 1. - hidden_dropout (float, optional): Override dropout rate. Defaults to None. - model_comm_pgs (ModelCommProcessGroups, optional): Default process groups. Defaults to None. - vp_stage (int, optional): Virtual pipeline stage. Defaults to None. - """ - - def __init__( - self, - config: TransformerConfig, - submodules: TransformerLayerSubmodules, - layer_number: int = 1, - hidden_dropout: Optional[float] = None, - model_comm_pgs: ModelCommProcessGroups = None, - vp_stage: Optional[int] = None, - ): - # Temporarily replace attention and MLP with IdentityOp, - # This is a temporary workaround for the test until we have a better interface - # will rebuild them with custom process groups after super init - def _modify_submodules(submodules): - submodules.self_attention = IdentityOp - submodules.mlp = IdentityOp - return submodules - - original_attention = submodules.self_attention - original_mlp = submodules.mlp - new_submodules = _modify_submodules(copy.copy(submodules)) - - super().__init__( - config=config, - submodules=new_submodules, - layer_number=layer_number, - hidden_dropout=hidden_dropout, - model_comm_pgs=model_comm_pgs, - vp_stage=vp_stage, - ) - - assert ( - 'model_comm_pgs' in submodules.self_attention.params - ), "model_comm_pgs should be in the params of the submodules" - self.self_attention = build_module( - original_attention, config=self.config, layer_number=layer_number - ) - assert ( - 'tp_group' in submodules.mlp.params - ), "tp_group should be in the params of the submodules" - self.mlp = build_module(original_mlp, config=self.config) - - -def create_reference_mlp(hidden_size, ffn_hidden_size, seed=12345): - """Create a reference MLP with full unsharded weights. - - Args: - hidden_size: Input/output dimension - ffn_hidden_size: Hidden dimension - seed: Random seed for weight initialization - - Returns: - Reference MLP with unsharded weights (nn.Sequential) - """ - # Set seed for reproducible initialization - torch.manual_seed(seed) - - # Create standard PyTorch Linear layers (unsharded) - ref_fc1 = torch.nn.Linear(hidden_size, ffn_hidden_size, bias=True) - ref_fc2 = torch.nn.Linear(ffn_hidden_size, hidden_size, bias=True) - - # Return as a simple sequential model - return torch.nn.Sequential(ref_fc1, ref_fc2).cpu() - - -def copy_weights_to_tp_mlp(ref_mlp, tp_mlp, tp_group): - """Copy weights from reference MLP to tensor-parallel MLP. - - Args: - ref_mlp: Reference MLP with full weights (nn.Sequential) - tp_mlp: Tensor-parallel MLP (megatron MLP instance) - tp_group: Tensor parallel process group - - Returns: - None (modifies tp_mlp in-place) - """ - # Get tensor parallel rank and world size - tp_rank = tp_group.rank() - tp_world_size = tp_group.size() - - # Reference components - ref_fc1 = ref_mlp[0] # First linear layer - ref_fc2 = ref_mlp[1] # Second linear layer - - # Manually copy and shard weights based on TP rank - with torch.no_grad(): - # FC1 (Column Parallel) - split along output dimension - out_size = ref_fc1.weight.size(0) - per_partition_size = out_size // tp_world_size - start_idx = tp_rank * per_partition_size - end_idx = (tp_rank + 1) * per_partition_size - - tp_mlp.linear_fc1.weight.copy_( - ref_fc1.weight[start_idx:end_idx].to(tp_mlp.linear_fc1.weight.device) - ) - if hasattr(tp_mlp.linear_fc1, 'bias') and tp_mlp.linear_fc1.bias is not None: - tp_mlp.linear_fc1.bias.copy_( - ref_fc1.bias[start_idx:end_idx].to(tp_mlp.linear_fc1.bias.device) - ) - - # FC2 (Row Parallel) - split along input dimension - in_size = ref_fc2.weight.size(1) - per_partition_size = in_size // tp_world_size - start_idx = tp_rank * per_partition_size - end_idx = (tp_rank + 1) * per_partition_size - - tp_mlp.linear_fc2.weight.copy_( - ref_fc2.weight[:, start_idx:end_idx].to(tp_mlp.linear_fc2.weight.device) - ) - if hasattr(tp_mlp.linear_fc2, 'bias') and tp_mlp.linear_fc2.bias is not None: - tp_mlp.linear_fc2.bias.copy_(ref_fc2.bias.to(tp_mlp.linear_fc2.bias.device)) - - -def _gpt_te_layer_spec_with_hetro_pgs(attn_model_comm_pgs, mlp_model_comm_pgs): - return ModuleSpec( - module=HeterogenousTransformerLayer, - submodules=TransformerLayerSubmodules( - self_attention=ModuleSpec( - module=SelfAttention, - params={ - "attn_mask_type": AttnMaskType.causal, - "model_comm_pgs": attn_model_comm_pgs, - }, - submodules=SelfAttentionSubmodules( - linear_qkv=TELayerNormColumnParallelLinear, - core_attention=TEDotProductAttention, - linear_proj=TERowParallelLinear, - ), - ), - self_attn_bda=get_bias_dropout_add, - pre_mlp_layernorm=IdentityOp, - mlp=ModuleSpec( - module=MLP, - params={'tp_group': mlp_model_comm_pgs.tp}, - submodules=MLPSubmodules( - linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear - ), - ), - mlp_bda=get_bias_dropout_add, - ), - ) - - -class TestTransformerBlockWithProcessGroups: - def setup_method(self, method): - Utils.destroy_model_parallel() - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - - def teardown_method(self, method): - Utils.destroy_model_parallel() - torch.backends.cudnn.deterministic = False - torch.backends.cudnn.benchmark = True - - @pytest.mark.skipif( - version.parse(torch.__version__) < version.parse('2.3.0'), - reason="Device mesh feature requires PyTorch 2.3 or later", - ) - @pytest.mark.parametrize( - "world_size, tp_size, cp_size, dp_size", - [ - (1, 1, 1, 1), # Single GPU, no parallelism - (2, 1, 2, 1), # 2 GPUs, 1 TP, 2 CP - (2, 2, 1, 1), # 2 GPUs, 2 TP, 1 CP - (8, 8, 1, 1), # 8 GPUs, 8 TP, 1 CP - (8, 2, 4, 1), # 8 GPUs, 2 TP, 4 CP - (8, 4, 2, 1), # 8 GPUs, 4 TP, 2 CP - (8, 1, 1, 8), # 8 GPUs, 1 TP, 1 CP, 8 DP - (8, 2, 1, 4), # 8 GPUs, 2 TP, 1 CP, 4 DP - (8, 2, 2, 2), # 8 GPUs, 2 TP, 2 CP, 2 DP - ], - ) - def test_params_and_grads_match_transformer_block(self, world_size, tp_size, cp_size, dp_size): - """ - Test that parameters and gradients match after one forward and backward pass - between transformer blocks using default and custom process groups. - """ - # Skip if world size doesn't match - actual_world_size = torch.cuda.device_count() - if actual_world_size != world_size: - pytest.skip(f"Test requires world_size={world_size}, but got {actual_world_size}") - Utils.initialize_model_parallel( - tensor_model_parallel_size=tp_size, context_parallel_size=cp_size - ) - os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" - - torch.manual_seed(12345) - model_parallel_cuda_manual_seed(123) - - # Create transformer configuration - transformer_config = TransformerConfig( - num_layers=3, - hidden_size=4096, - num_attention_heads=32, - use_cpu_initialization=True, - attention_dropout=0.0, - hidden_dropout=0.0, - bf16=True, - context_parallel_size=cp_size, - ) - - # Create a transformer block with default process groups - default_block = ( - TransformerBlock(transformer_config, get_gpt_layer_with_transformer_engine_spec()) - .cuda() - .bfloat16() - ) - - # Create custom process groups - # Initialize torch.distributed if not already initialized - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend='nccl') - - # Create HyperCommGrid with dimensions tp, cp, ep, pp, dp (reversed from device mesh order) - grid = HyperCommGrid([tp_size, cp_size, 1, 1, dp_size], ["tp", "cp", "ep", "pp", "dp"]) - - tp_group = grid.create_pg("tp") - cp_group = grid.create_pg("cp") - pp_group = grid.create_pg("pp") - ep_group = grid.create_pg("ep") - model_comm_pgs = ModelCommProcessGroups(tp=tp_group, cp=cp_group, pp=pp_group, ep=ep_group) - - dp_group = grid.create_pg("dp") - dp_cp_group = grid.create_pg(["dp", "cp"]) - - grad_comm_pgs = GradCommProcessGroups() - grad_comm_pgs.dp = dp_group - grad_comm_pgs.dp_cp = dp_cp_group - - # Create a transformer block with custom process groups - custom_block = ( - TransformerBlock( - transformer_config, - get_gpt_layer_with_transformer_engine_spec(), - model_comm_pgs=model_comm_pgs, - ) - .cuda() - .bfloat16() - ) - - # Initialize with same parameters - for default_param, custom_param in zip( - default_block.parameters(), custom_block.parameters() - ): - custom_param.data.copy_(default_param.data) - - # copy buffers - for default_buffer, custom_buffer in zip(default_block.buffers(), custom_block.buffers()): - custom_buffer.data.copy_(default_buffer.data) - - # wrap with DDP - ddp_config = DistributedDataParallelConfig(overlap_grad_reduce=True, bucket_size=10000) - default_block = DistributedDataParallel( - config=transformer_config, ddp_config=ddp_config, module=default_block - ) - - custom_block = DistributedDataParallel( - config=transformer_config, - ddp_config=ddp_config, - module=custom_block, - grad_comm_pgs=grad_comm_pgs, - model_comm_pgs=model_comm_pgs, - ) - - # Create test input - sequence_length = 4096 - micro_batch_size = 4 - hidden_states = ( - torch.randn( - (sequence_length, micro_batch_size, transformer_config.hidden_size), - device="cuda", - requires_grad=True, - ) - .bfloat16() - .requires_grad_(True) - ) - hidden_states.retain_grad() - - torch.distributed.all_reduce(hidden_states, op=torch.distributed.ReduceOp.SUM) - - hidden_states_default = hidden_states.clone().detach().requires_grad_(True) - hidden_states_custom = hidden_states.clone().detach().requires_grad_(True) - - # Forward passes - output_default = default_block(hidden_states=hidden_states_default, attention_mask=None) - output_custom = custom_block(hidden_states=hidden_states_custom, attention_mask=None) - # Verify outputs match - torch.testing.assert_close( - output_default, - output_custom, - rtol=1e-8, - atol=0, - msg="Forward outputs don't match between default and custom process groups", - ) - - output_default.backward(torch.ones_like(output_default) * 1e3) - output_custom.backward(torch.ones_like(output_custom) * 1e3) - # Verify gradients match for parameters - # with DDP grad attribute is None, only main_grad is available - for i, (default_param, custom_param) in enumerate( - zip(default_block.parameters(), custom_block.parameters()) - ): - if default_param.main_grad is not None and custom_param.main_grad is not None: - param_name = [name for name, param in default_block.named_parameters()][i] - - # ideally we want to grads and assert they are close - # but the grads are too small to use rtol - # for now just a smoke test to ensure grads are available - # TODO: ykarnati - improve this test to ensure we have large grads for comparision - assert ( - default_param.main_grad is not None and custom_param.main_grad is not None - ), f"Gradient is None for parameter '{param_name}' at index {i}" - - @pytest.mark.skipif( - version.parse(torch.__version__) < version.parse('2.3.0'), - reason="Device mesh feature requires PyTorch 2.3 or later", - ) - @pytest.mark.parametrize( - "world_size, attn_tp_size, attn_cp_size, mlp_tp_size", - [ - (1, 1, 1, 1), # Single GPU, no parallelism - (2, 1, 1, 2), # 2 GPUs, attn: 1 TP, 1 CP; mlp: 2 TP - (2, 2, 1, 2), # 2 GPUs, attn: 2 TP, 1 CP; mlp: 2 TP - (8, 1, 1, 8), # 8 GPUs, attn: 1 TP, 1 CP; mlp: 8 TP - (8, 8, 1, 1), # 8 GPUs, attn: 8 TP, 1 CP; mlp: 1 TP - (8, 2, 1, 4), # 8 GPUs, attn: 2 TP, 1 CP; mlp: 4 TP - (8, 4, 1, 2), # 8 GPUs, attn: 4 TP, 1 CP; mlp: 2 TP - (8, 2, 2, 2), # 8 GPUs, attn: 2 TP, 2 CP; mlp: 2 TP - ], - ) - def test_fwd_bwd_pass_non_uniform_transformer_block( - self, world_size, attn_tp_size, attn_cp_size, mlp_tp_size - ): - """ - Test that parameters and gradients after one forward and backward pass - with different process groups for attention and mlp. - """ - - actual_world_size = torch.cuda.device_count() - if actual_world_size != world_size: - pytest.skip(f"Test requires world_size={world_size}, but got {actual_world_size}") - Utils.initialize_model_parallel() - torch.manual_seed(12345) - model_parallel_cuda_manual_seed(123) - - # Create transformer configuration - transformer_config = TransformerConfig( - num_layers=3, - hidden_size=4096, - num_attention_heads=32, - use_cpu_initialization=True, - attention_dropout=0.0, - hidden_dropout=0.0, - bf16=True, - context_parallel_size=attn_cp_size, - ) - - # Create custom process groups - # Initialize torch.distributed if not already initialized - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend='nccl') - - # Create HyperCommGrid with dimensions mlp_tp, attn_cp, attn_tp (reversed from device mesh order) - grid = HyperCommGrid( - [mlp_tp_size, attn_cp_size, attn_tp_size], ["mlp_tp", "attn_cp", "attn_tp"] - ) - - attn_tp_group = grid.create_pg("attn_tp") - attn_cp_group = grid.create_pg("attn_cp") - mlp_tp_group = grid.create_pg("mlp_tp") - - attn_model_comm_pgs = ModelCommProcessGroups(tp=attn_tp_group, cp=attn_cp_group) - mlp_model_comm_pgs = ModelCommProcessGroups(tp=mlp_tp_group) - - # Get the layer spec with different process groups for attention and mlp - hetro_layer_spec = _gpt_te_layer_spec_with_hetro_pgs( - attn_model_comm_pgs, mlp_model_comm_pgs - ) - custom_block = TransformerBlock(transformer_config, hetro_layer_spec).cuda().bfloat16() - - sequence_length = 4096 - micro_batch_size = 2 - - # [sequence length, batch size, hidden size] - hidden_states = ( - torch.randn( - (sequence_length, micro_batch_size, transformer_config.hidden_size), - device="cuda", - requires_grad=True, - ) - .bfloat16() - .requires_grad_(True) - ) - hidden_states.retain_grad() - - output_custom = custom_block(hidden_states=hidden_states, attention_mask=None) - - assert ( - output_custom.shape[0] == sequence_length - ), f"Output shape is {output_custom.shape} dont match sequence length {sequence_length}" - assert ( - output_custom.shape[1] == micro_batch_size - ), f"Output shape is {output_custom.shape} dont match micro batch size {micro_batch_size}" - assert ( - output_custom.shape[2] == transformer_config.hidden_size - ), f"Output shape is {output_custom.shape} dont match hidden size {transformer_config.hidden_size}" - - loss = output_custom.sum() - loss.backward() - - assert hidden_states.grad is not None, "Hidden states gradient is None" - - @pytest.mark.skipif( - version.parse(torch.__version__) < version.parse('2.3.0'), - reason="Device mesh feature requires PyTorch 2.3 or later", - ) - def test_fwd_bwd_pass_mix_and_match_transformer_blocks(self): - world_size = 8 - actual_world_size = torch.cuda.device_count() - if actual_world_size != world_size: - pytest.skip(f"Test requires world_size={world_size}, but got {actual_world_size}") - - Utils.initialize_model_parallel() - torch.manual_seed(12345) - model_parallel_cuda_manual_seed(123) - - # Initialize torch.distributed if not already initialized - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend='nccl') - - # Create HyperCommGrid with dimensions tp, cp (reversed from device mesh order) - grid_cp_2_tp_4 = HyperCommGrid([4, 2], ["tp", "cp"]) - - tp_group = grid_cp_2_tp_4.create_pg("tp") - cp_group = grid_cp_2_tp_4.create_pg("cp") - model_comm_pgs = ModelCommProcessGroups(tp=tp_group, cp=cp_group) - - transformer_config = TransformerConfig( - num_layers=3, - hidden_size=4096, - num_attention_heads=32, - use_cpu_initialization=True, - attention_dropout=0.0, - hidden_dropout=0.0, - context_parallel_size=2, - ) - transformer_block_cp2_tp4 = ( - TransformerBlock( - transformer_config, - get_gpt_layer_with_transformer_engine_spec(), - model_comm_pgs=model_comm_pgs, - ) - .cuda() - .bfloat16() - ) - - sequence_length = 4096 - micro_batch_size = 4 - hidden_states = ( - torch.randn( - (sequence_length, micro_batch_size, transformer_config.hidden_size), device="cuda" - ) - .bfloat16() - .requires_grad_(True) - ) - hidden_states.retain_grad() - - # Create HyperCommGrid with dimensions ep, pp, dp, cp, tp (reversed from device mesh order) - grid_cp_2_tp_2_dp_2 = HyperCommGrid([2, 2, 2, 1, 1], ["tp", "cp", "dp", "pp", "ep"]) - tp_group = grid_cp_2_tp_2_dp_2.create_pg("tp") - cp_group = grid_cp_2_tp_2_dp_2.create_pg("cp") - dp_group = grid_cp_2_tp_2_dp_2.create_pg("dp") - pp_group = grid_cp_2_tp_2_dp_2.create_pg("pp") - ep_group = grid_cp_2_tp_2_dp_2.create_pg("ep") - model_comm_pgs = ModelCommProcessGroups(tp=tp_group, cp=cp_group, pp=pp_group, ep=ep_group) - grad_comm_pgs = GradCommProcessGroups() - - dp_cp_group = grid_cp_2_tp_2_dp_2.create_pg(["dp", "cp"]) - grad_comm_pgs.dp = dp_group - grad_comm_pgs.dp_cp = dp_cp_group - - transformer_block_cp2_tp2 = ( - TransformerBlock( - transformer_config, - get_gpt_layer_with_transformer_engine_spec(), - model_comm_pgs=model_comm_pgs, - ) - .cuda() - .bfloat16() - ) - - ddp_config = DistributedDataParallelConfig(overlap_grad_reduce=True, bucket_size=10000) - transformer_block_cp2_tp2_dp_2 = DistributedDataParallel( - config=transformer_config, - ddp_config=ddp_config, - module=transformer_block_cp2_tp2, - grad_comm_pgs=grad_comm_pgs, - model_comm_pgs=model_comm_pgs, - ) - - output_cp2_tp_2_dp_2 = transformer_block_cp2_tp2_dp_2( - hidden_states=hidden_states, attention_mask=None - ) - - assert output_cp2_tp_2_dp_2.shape == ( - sequence_length, - micro_batch_size, - transformer_config.hidden_size, - ), ( - f"Output shape is {output_cp2_tp_2_dp_2.shape} dont match sequence length {sequence_length}, " - f"micro batch size {micro_batch_size}, hidden size {transformer_config.hidden_size}" - ) - - # pass as input to transformer_block_cp2_tp4 - output_cp2_tp4 = transformer_block_cp2_tp4( - hidden_states=output_cp2_tp_2_dp_2, attention_mask=None - ) - - assert output_cp2_tp4.shape == ( - sequence_length, - micro_batch_size, - transformer_config.hidden_size, - ), ( - f"Output shape is {output_cp2_tp4.shape} dont match sequence length {sequence_length}, " - f"micro batch size {micro_batch_size}, hidden size {transformer_config.hidden_size}" - ) - - # verify backward pass - loss = output_cp2_tp4.sum() - loss.backward() - - assert hidden_states.grad is not None, "Hidden states gradient is None" - - @pytest.mark.skipif( - version.parse(torch.__version__) < version.parse('2.3.0'), - reason="Device mesh feature requires PyTorch 2.3 or later", - ) - @pytest.mark.parametrize( - "world_size, tp_size, dp_size, reverse_tp_dp_order", - [ - (1, 1, 1, False), # Single GPU, no parallelism - (2, 1, 2, False), # 2 GPUs, 1 TP, 2 DP - (2, 2, 1, False), # 2 GPUs, 2 TP, 1 DP - (8, 8, 1, False), # 8 GPUs, 8 TP, 1 DP - (8, 1, 8, False), # 8 GPUs, 1 TP, 8 DP - (8, 2, 4, False), # 8 GPUs, 2 TP, 4 DP - (8, 4, 2, False), # 8 GPUs, 4 TP, 2 DP - (8, 2, 4, True), # 8 GPUs, 2 TP, 4 DP # reverse TP and DP order in device mesh - (8, 4, 2, True), # 8 GPUs, 4 TP, 2 DP # reverse TP and DP order in device mesh - ], - ) - def test_mlp_with_custom_pgs(self, world_size, tp_size, dp_size, reverse_tp_dp_order): - - actual_world_size = torch.cuda.device_count() - if actual_world_size != world_size: - pytest.skip(f"Test requires world_size={world_size}, but got {actual_world_size}") - - Utils.initialize_model_parallel(tensor_model_parallel_size=tp_size) - - # Set PyTorch random seed explicitly for reproducible input - torch.manual_seed(12345) - model_parallel_cuda_manual_seed(123) - - # Initialize torch.distributed if not already initialized - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend='nccl') - - if reverse_tp_dp_order: - # Create HyperCommGrid with dimensions ep, pp, tp, dp (reversed from device mesh order) - grid = HyperCommGrid([dp_size, tp_size, 1, 1], ["dp", "tp", "pp", "ep"]) - else: - # Create HyperCommGrid with dimensions ep, pp, dp, tp (reversed from device mesh order) - grid = HyperCommGrid([tp_size, dp_size, 1, 1], ["tp", "dp", "pp", "ep"]) - - pp_group = grid.create_pg("pp") - ep_group = grid.create_pg("ep") - dp_group = grid.create_pg("dp") - tp_group = grid.create_pg("tp") - mlp_model_comm_pgs = ModelCommProcessGroups(tp=tp_group, pp=pp_group, ep=ep_group) - - grad_comm_pgs = GradCommProcessGroups() - grad_comm_pgs.dp = dp_group - grad_comm_pgs.dp_cp = dp_group - - transformer_config = TransformerConfig( - num_layers=1, - hidden_size=4096, - num_attention_heads=32, - use_cpu_initialization=True, - attention_dropout=0.0, - hidden_dropout=0.0, - context_parallel_size=1, - ffn_hidden_size=4 * 4096, - ) - - default_mlp_spec = ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear - ), - ) - - custom_mlp_spec = ModuleSpec( - module=MLP, - params={'tp_group': mlp_model_comm_pgs.tp}, - submodules=MLPSubmodules( - linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear - ), - ) - - reference_mlp = create_reference_mlp( - transformer_config.hidden_size, transformer_config.ffn_hidden_size - ) - default_mlp = build_module(default_mlp_spec, config=transformer_config).cuda() - custom_mlp = build_module(custom_mlp_spec, config=transformer_config).cuda() - - copy_weights_to_tp_mlp( - reference_mlp, default_mlp, parallel_state.get_tensor_model_parallel_group() - ) - copy_weights_to_tp_mlp(reference_mlp, custom_mlp, tp_group) - - ddp_config = DistributedDataParallelConfig(overlap_grad_reduce=True, bucket_size=10000) - - default_mlp = DistributedDataParallel( - config=transformer_config, ddp_config=ddp_config, module=default_mlp - ) - - custom_mlp = DistributedDataParallel( - config=transformer_config, - ddp_config=ddp_config, - module=custom_mlp, - model_comm_pgs=mlp_model_comm_pgs, - grad_comm_pgs=grad_comm_pgs, - ) - - sequence_length = 4096 - micro_batch_size = 1 - hidden_states = torch.randn( - (sequence_length, micro_batch_size, transformer_config.hidden_size), device="cuda" - ).requires_grad_(True) - - torch.distributed.all_reduce(hidden_states, op=torch.distributed.ReduceOp.SUM) - - output_default, _ = default_mlp(hidden_states) - output_custom, _ = custom_mlp(hidden_states) - - torch.testing.assert_close(output_default, output_custom, rtol=1e-8, atol=0) - - def test_deterministic_output_from_single_block(self): - """ - Test that a single transformer block produces identical outputs - when run twice with the same input. - """ - # Initialize model parallel with no parallelism - Utils.initialize_model_parallel(tensor_model_parallel_size=1, context_parallel_size=1) - - # Set PyTorch random seed explicitly for reproducible inputs - torch.manual_seed(12345) - model_parallel_cuda_manual_seed(123) - - # Create transformer configuration - transformer_config = TransformerConfig( - num_layers=1, - hidden_size=64, - num_attention_heads=4, - use_cpu_initialization=True, - deterministic_mode=True, - attention_dropout=0.0, - hidden_dropout=0.0, - attention_backend=AttnBackend.unfused, - ) - - transformer_config_2 = copy.deepcopy(transformer_config) - - # Create a single transformer block - block = TransformerBlock(transformer_config, get_gpt_layer_local_spec()) - block_2 = TransformerBlock(transformer_config_2, get_gpt_layer_local_spec()) - # Move block to GPU - block.cuda() - block_2.cuda() - - # Create test input - sequence_length = 37 - micro_batch_size = 7 - - # copy weights from block_2 to block - for default_param, custom_param in zip(block.parameters(), block_2.parameters()): - custom_param.data.copy_(default_param.data) - - for name, buffer in block.named_buffers(): - if name in dict(block_2.named_buffers()): - dict(block_2.named_buffers())[name].copy_(buffer) - - hidden_states_int = torch.randint( - -10, - 10, - (sequence_length, micro_batch_size, transformer_config.hidden_size), - device="cuda", - ) - hidden_states = hidden_states_int.float() - - # Create two identical copies of the input - hidden_states_1 = hidden_states.clone() - hidden_states_2 = hidden_states.clone() - - # Forward passes through the same block - output_1 = block(hidden_states=hidden_states_1, attention_mask=None) - output_block_2 = block_2(hidden_states=hidden_states_2, attention_mask=None) - - torch.testing.assert_close( - output_1, - output_block_2, - rtol=0, - atol=0, - msg="Outputs don't match for identical inputs through the same block", - ) diff --git a/tests/unit_tests/transformer/test_transformer_layer.py b/tests/unit_tests/transformer/test_transformer_layer.py deleted file mode 100644 index 7db9aa30fe..0000000000 --- a/tests/unit_tests/transformer/test_transformer_layer.py +++ /dev/null @@ -1,313 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - - -import pytest -import torch - -from megatron.core import parallel_state -from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedTensor -from megatron.core.inference.contexts import StaticInferenceContext -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.transformer.transformer_layer import ( - TransformerLayer, - get_transformer_layer_offset, -) -from tests.unit_tests.test_utilities import Utils - - -class TestParallelTransformerLayer: - - def setup_method(self, method): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - transformer_config = TransformerConfig( - num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True - ) - self.parallel_transformer_layer = TransformerLayer( - transformer_config, get_gpt_layer_with_transformer_engine_spec().submodules - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_constructor(self): - parallel_transformer_layer = self.parallel_transformer_layer - assert isinstance(parallel_transformer_layer, TransformerLayer) - assert parallel_transformer_layer.layer_number == 1 - - num_weights = sum([p.numel() for p in parallel_transformer_layer.parameters()]) - assert num_weights == 1884 - - def test_gpu_forward(self): - parallel_transformer_layer = self.parallel_transformer_layer - config: TransformerConfig = parallel_transformer_layer.config - sequence_length = 32 - micro_batch_size = 2 - parallel_transformer_layer.cuda() - - # [sequence length, batch size, hidden size] - hidden_states = torch.ones((sequence_length, micro_batch_size, config.hidden_size)) - hidden_states = hidden_states.cuda() - - attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() - - hidden_states, context = parallel_transformer_layer( - hidden_states=hidden_states, attention_mask=attention_mask - ) - assert hidden_states.shape[0] == sequence_length - assert hidden_states.shape[1] == micro_batch_size - assert hidden_states.shape[2] == config.hidden_size - - def test_chunked_mlp(self): - with torch.no_grad(): - - def test( - num_layers, - hidden_size, - num_attention_heads, - mlp_chunks_for_prefill, - hidden_states, - inference_context, - ): - - transformer_config = TransformerConfig( - num_layers=2, - hidden_size=12, - num_attention_heads=4, - mlp_chunks_for_prefill=4, - add_bias_linear=True, - use_cpu_initialization=True, - ) - parallel_transformer_layer = TransformerLayer( - transformer_config, get_gpt_layer_with_transformer_engine_spec().submodules - ) - - parallel_transformer_layer.cuda() - - hidden_states, context = parallel_transformer_layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - inference_context=inference_context, - ) - - return hidden_states, context - - num_layers = 2 - hidden_size = 12 - num_attention_heads = 4 - - sequence_length = 32 - micro_batch_size = 2 - - # [sequence length, batch size, hidden size] - input_hidden_states = torch.ones((sequence_length, micro_batch_size, hidden_size)) - input_hidden_states = input_hidden_states.cuda() - - attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() - - inference_context = StaticInferenceContext( - max_batch_size=micro_batch_size, max_sequence_length=sequence_length - ) - - outputs = {} - - for mlp_chunks_for_prefill in [1, 4]: - hidden_states, context = test( - num_layers, - hidden_size, - num_attention_heads, - mlp_chunks_for_prefill, - input_hidden_states, - inference_context, - ) - assert hidden_states.shape[0] == sequence_length - assert hidden_states.shape[1] == micro_batch_size - assert hidden_states.shape[2] == hidden_size - - outputs[mlp_chunks_for_prefill] = (hidden_states, context) - - assert torch.equal(outputs[1][0], outputs[4][0]) - - def test_get_layer_offset(self): - config = self.parallel_transformer_layer.config - assert get_transformer_layer_offset(config) == 0 - - @pytest.mark.parametrize( - "config_params,expected_offsets", - [ - # Test case 1: Both first and last stages set (30 layers: 8+6+6+10) - ( - { - "num_layers": 30, - "pipeline_model_parallel_size": 4, - "virtual_pipeline_model_parallel_size": 2, - "num_layers_in_first_pipeline_stage": 8, - "num_layers_in_last_pipeline_stage": 10, - "pipeline_dtype": torch.bfloat16, - }, - { - (0, 0): 0, # Stage 0, VP 0: layers 0-3 - (0, 1): 15, # Stage 0, VP 1: layers 15-18 - (1, 0): 4, # Stage 1, VP 0: layers 4-6 - (1, 1): 19, # Stage 1, VP 1: layers 19-21 - (2, 0): 7, # Stage 2, VP 0: layers 7-9 - (2, 1): 22, # Stage 2, VP 1: layers 22-24 - (3, 0): 10, # Stage 3, VP 0: layers 10-14 - (3, 1): 25, # Stage 3, VP 1: layers 25-29 - }, - ), - # Test case 2: Only first stage set (26 layers: 8+6+6+6) - ( - { - "num_layers": 26, - "pipeline_model_parallel_size": 4, - "virtual_pipeline_model_parallel_size": 2, - "num_layers_in_first_pipeline_stage": 8, - "num_layers_in_last_pipeline_stage": None, - "pipeline_dtype": torch.bfloat16, - }, - { - (0, 0): 0, # Stage 0, VP 0: layers 0-3 - (0, 1): 13, # Stage 0, VP 1: layers 13-16 - (1, 0): 4, # Stage 1, VP 0: layers 4-6 - (1, 1): 17, # Stage 1, VP 1: layers 17-19 - (2, 0): 7, # Stage 2, VP 0: layers 7-9 - (2, 1): 20, # Stage 2, VP 1: layers 20-22 - (3, 0): 10, # Stage 3, VP 0: layers 10-12 - (3, 1): 23, # Stage 3, VP 1: layers 23-25 - }, - ), - # Test case 3: Only last stage set (26 layers: 6+6+6+8) - ( - { - "num_layers": 26, - "pipeline_model_parallel_size": 4, - "virtual_pipeline_model_parallel_size": 2, - "num_layers_in_first_pipeline_stage": None, - "num_layers_in_last_pipeline_stage": 8, - "pipeline_dtype": torch.bfloat16, - }, - { - (0, 0): 0, # Stage 0, VP 0: layers 0-2 - (0, 1): 13, # Stage 0, VP 1: layers 13-15 - (1, 0): 3, # Stage 1, VP 0: layers 3-5 - (1, 1): 16, # Stage 1, VP 1: layers 16-18 - (2, 0): 6, # Stage 2, VP 0: layers 6-8 - (2, 1): 19, # Stage 2, VP 1: layers 19-21 - (3, 0): 9, # Stage 3, VP 0: layers 9-12 - (3, 1): 22, # Stage 3, VP 1: layers 22-25 - }, - ), - # Test case 4: Even distribution (24 layers: 6+6+6+6) - ( - { - "num_layers": 24, - "pipeline_model_parallel_size": 4, - "virtual_pipeline_model_parallel_size": 2, - "num_layers_in_first_pipeline_stage": None, - "num_layers_in_last_pipeline_stage": None, - "pipeline_dtype": torch.bfloat16, - }, - { - (0, 0): 0, # Stage 0, VP 0: layers 0-2 - (0, 1): 12, # Stage 0, VP 1: layers 12-14 - (1, 0): 3, # Stage 1, VP 0: layers 3-5 - (1, 1): 15, # Stage 1, VP 1: layers 15-17 - (2, 0): 6, # Stage 2, VP 0: layers 6-8 - (2, 1): 18, # Stage 2, VP 1: layers 18-20 - (3, 0): 9, # Stage 3, VP 0: layers 9-11 - (3, 1): 21, # Stage 3, VP 1: layers 21-23 - }, - ), - ], - ) - def test_get_layer_offset_parametrized(self, config_params, expected_offsets): - """ - Parametrized test for get_transformer_layer_offset with different configurations. - Tests various combinations of first/last stage settings and virtual pipeline sizes. - - This test verifies that the layer offset calculation correctly handles: - - Asymmetric pipeline stages (different layer counts per stage) - - Virtual pipeline parallelism (splitting physical stages into virtual stages) - - Various combinations of first/last stage configurations - - The expected_offsets dictionary maps (pipeline_rank, vp_stage) tuples to - the expected starting layer index for that stage combination. - """ - - config = TransformerConfig( - hidden_size=512, num_attention_heads=8, use_cpu_initialization=True, **config_params - ) - - for (pipeline_rank, vp_stage), expected_offset in expected_offsets.items(): - original_get_pipeline_rank = parallel_state.get_pipeline_model_parallel_rank - parallel_state.set_pipeline_model_parallel_rank(pipeline_rank) - - try: - actual_offset = get_transformer_layer_offset(config, vp_stage) - assert actual_offset == expected_offset, ( - f"Expected offset {expected_offset} for pipeline rank {pipeline_rank}, " - f"VP stage {vp_stage}, but got {actual_offset}" - ) - finally: - parallel_state.set_pipeline_model_parallel_rank(original_get_pipeline_rank) - - @pytest.mark.parametrize('order', ['tp-pp-dp', 'tp-dp-pp']) - @pytest.mark.parametrize('tp_pp', [(4, 2), (1, 1), (8, 1), (2, 2)]) - def test_sharded_state_dict(self, tp_pp, order): - Utils.destroy_model_parallel() - Utils.initialize_model_parallel(*tp_pp, order=order) - - model_parallel_cuda_manual_seed(123) - transformer_config = TransformerConfig( - num_layers=2, hidden_size=128, num_attention_heads=8, use_cpu_initialization=True - ) - parallel_transformer_layer = TransformerLayer( - transformer_config, get_gpt_layer_with_transformer_engine_spec().submodules - ) - - sharded_state_dict = parallel_transformer_layer.sharded_state_dict() - - extra_states = {k: v for k, v in sharded_state_dict.items() if k.endswith('extra_state')} - sharded_tensors = { - k: v for k, v in sharded_state_dict.items() if not k.endswith('extra_state') - } - assert all(isinstance(t, ShardedObject) for t in extra_states.values()) - assert all(isinstance(t, ShardedTensor) for t in sharded_tensors.values()) - - # Test all local shapes - tensor_local_shapes = {k: v.local_shape for k, v in sharded_tensors.items()} - tp_size = parallel_state.get_tensor_model_parallel_world_size() - assert tensor_local_shapes == get_tensor_shapes_for_tp(transformer_config, tp_size) - - # Test all global shapes. Prepend num layers in front of expected shapes - tensor_global_shapes = {k: v.global_shape for k, v in sharded_tensors.items()} - expected_global_shapes = get_tensor_shapes_for_tp(transformer_config, 1) - assert tensor_global_shapes == expected_global_shapes - - # Test ShardedTensor keys - for state_dict_key, sh_ten in sharded_tensors.items(): - assert state_dict_key == sh_ten.key - - Utils.destroy_model_parallel() - Utils.initialize_model_parallel(1, 1) - - -def get_tensor_shapes_for_tp(transformer_config, tp_size): - hs = transformer_config.hidden_size - return { - 'mlp.linear_fc1.layer_norm_weight': (hs,), - 'mlp.linear_fc1.layer_norm_bias': (hs,), - 'mlp.linear_fc1.weight': (hs * 4 // tp_size, hs), - 'mlp.linear_fc1.bias': (hs * 4 // tp_size,), - 'mlp.linear_fc2.weight': (hs, hs * 4 // tp_size), - 'mlp.linear_fc2.bias': (hs,), - 'self_attention.linear_proj.weight': (hs, hs // tp_size), - 'self_attention.linear_proj.bias': (hs,), - 'self_attention.linear_qkv.layer_norm_weight': (hs,), - 'self_attention.linear_qkv.layer_norm_bias': (hs,), - 'self_attention.linear_qkv.weight': (hs * 3 // tp_size, hs), - 'self_attention.linear_qkv.bias': (hs * 3 // tp_size,), - } diff --git a/tests/unit_tests/transformer/test_utils.py b/tests/unit_tests/transformer/test_utils.py deleted file mode 100644 index ebbb13b089..0000000000 --- a/tests/unit_tests/transformer/test_utils.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import inspect -import os - -import pytest -import torch - -import megatron.core.transformer.utils as transformer_utils -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.transformer.utils import set_model_to_sequence_parallel -from tests.unit_tests.test_utilities import Utils - - -class TestGPTModel: - - def setup_method(self, method): - os.environ.pop('NVTE_FUSED_ATTN', None) - os.environ.pop('NVTE_FLASH_ATTN', None) - os.environ.pop('NVTE_UNFUSED_ATTN', None) - os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' - - self.tensor_model_parallel_size = 2 - Utils.initialize_model_parallel(self.tensor_model_parallel_size, 1) - model_parallel_cuda_manual_seed(123) - transformer_config = TransformerConfig( - num_layers=2, - hidden_size=48, - num_attention_heads=4, - use_cpu_initialization=True, - tensor_model_parallel_size=self.tensor_model_parallel_size, - sequence_parallel=False, - ) - self.gpt_model = GPTModel( - config=transformer_config, - transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec(), - vocab_size=100, - max_sequence_length=8, - position_embedding_type="rope", - parallel_output=False, - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.internal - def test_post_process_forward(self): - _ = self.gpt_model.config - sequence_length = self.gpt_model.max_sequence_length - micro_batch_size = 2 - - self.gpt_model.cuda() - - data = list(range(sequence_length)) - input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - position_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - attention_mask = torch.ones( - (micro_batch_size, 1, sequence_length, sequence_length), dtype=bool - ).cuda() - - logits = self.gpt_model.forward( - input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask - ) - assert logits.shape[0] == micro_batch_size - assert logits.shape[1] == sequence_length - assert logits.shape[2] == self.gpt_model.vocab_size - - set_model_to_sequence_parallel(self.gpt_model, set_to=True) - logits = self.gpt_model.forward( - input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask - ) - # Test cache has been built - assert transformer_utils._sequence_parallel_attr_cache is not None - - # Check the modules have been flipped - for attribute, modules in transformer_utils._sequence_parallel_attr_cache[ - id(self.gpt_model) - ].items(): - for module in modules: - assert getattr(module, attribute) == True - - assert logits.shape[0] == micro_batch_size - assert logits.shape[1] == sequence_length - assert logits.shape[2] == self.gpt_model.vocab_size - - set_model_to_sequence_parallel(self.gpt_model, set_to=False) - logits = self.gpt_model.forward( - input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask - ) - - assert logits.shape[0] == micro_batch_size - assert logits.shape[1] == sequence_length - assert logits.shape[2] == self.gpt_model.vocab_size - - # Check the modules have been flipped - for attribute, modules in transformer_utils._sequence_parallel_attr_cache[ - id(self.gpt_model) - ].items(): - for module in modules: - assert getattr(module, attribute) == False diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py index 7d382a0d13..99284cbcc0 100644 --- a/tools/preprocess_data.py +++ b/tools/preprocess_data.py @@ -27,7 +27,7 @@ from megatron.training.tokenizer import build_tokenizer from megatron.training.arguments import _add_tokenizer_args from megatron.core.datasets import indexed_dataset - +from tqdm import tqdm # https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer class CustomLanguageVars(PunktLanguageVars): @@ -131,7 +131,7 @@ def print_processing_stats(self, count, proc_start, total_bytes_processed): def split_sentences(self, file_name): input_file_name, output_file_name = file_name print("Opening", input_file_name) - fin = open(input_file_name, 'r', encoding='utf-8') + fin = open(input_file_name, 'r', encoding='utf-8', errors='replace') fout = open(output_file_name, 'w') encoder = Encoder(self.args) @@ -152,7 +152,7 @@ def split_sentences(self, file_name): def process_json_file(self, file_name): input_file_name, output_prefix = file_name print("Opening", input_file_name) - fin = open(input_file_name, 'r', encoding='utf-8') + fin = open(input_file_name, 'r', encoding='utf-8', errors='replace') startup_start = time.time() encoder = Encoder(self.args) @@ -192,7 +192,9 @@ def process_json_file(self, file_name): self.print_processing_stats(i, proc_start, total_bytes_processed) fin.close() - builders[key].finalize(output_idx_files[key]) + # builders[key].finalize(output_idx_files[key]) + for k in builders: + builders[k].finalize(output_idx_files[k]) def get_args(): @@ -244,17 +246,26 @@ def get_args(): return args -def get_file_name(args, file_id): - file_name, extension = os.path.splitext(args.input) - input_file_name = file_name + "_" + str(file_id) + extension - sentence_split_file = file_name + "_ss_" + str(file_id) + extension - output_prefix = args.output_prefix + "_" + str(file_id) - file_names = { - 'partition': input_file_name, - 'sentence_split': sentence_split_file, - 'output_prefix': output_prefix} - return file_names +# def get_file_name(args, file_id): + # file_name, extension = os.path.splitext(args.input) + # input_file_name = file_name + "_" + str(file_id) + extension + # sentence_split_file = file_name + "_ss_" + str(file_id) + extension + # output_prefix = args.output_prefix + "_" + str(file_id) + # file_names = { + # 'partition': input_file_name, + # 'sentence_split': sentence_split_file, + # 'output_prefix': output_prefix} + # return file_names +def get_file_name(args, file_id): + prefix = f"{args.output_prefix}_{file_id}" + partition_path = f"{args.output_prefix}_part_{file_id}.jsonl" + sentence_split_path = f"{args.output_prefix}_part_ss_{file_id}.jsonl" + return { + 'partition': partition_path, + 'sentence_split': sentence_split_path, + 'output_prefix': prefix, + } def check_files_exist(in_ss_out_names, key, num_partitions): for i in range(num_partitions): @@ -274,7 +285,9 @@ def main(): "nltk library required for sentence splitting is not available.") in_ss_out_names = [] + # import pdb;pdb.set_trace() if args.partitions == 1: + # if False: file_name, extension = os.path.splitext(args.input) sentence_split_file = file_name + "_ss" + extension file_names = { @@ -283,7 +296,10 @@ def main(): 'output_prefix': args.output_prefix} in_ss_out_names.append(file_names) else: - in_file_names = glob.glob(args.input) + in_file_names = glob.glob(args.input,recursive=True) + if not in_file_names: + raise RuntimeError(f"No files matched input pattern: {args.input}. Did you mean to use recursive=True?") + print(f"Found {len(in_file_names)} input files") # Count total number of lines across .jsonl files if args.keep_sequential_samples: @@ -296,7 +312,7 @@ def main(): partition_size = math.ceil(total_sample_count / args.partitions) # create .jsonl parition files - for idx in range(args.partitions): + for idx in tqdm(range(args.partitions)): in_ss_out_name = get_file_name(args, idx) in_ss_out_names.append(in_ss_out_name) @@ -309,16 +325,16 @@ def main(): if not partitions_present and not split_sentences_present: # populate .jsonl partition files from parent files partitioned_input_files = [] - for idx in range(args.partitions): + for idx in tqdm(range(args.partitions)): partitioned_input_file = open(in_ss_out_names[idx]['partition'], 'w') partitioned_input_files.append(partitioned_input_file) index = 0 if args.keep_sequential_samples: line_count = 0 - for in_file_name in in_file_names: + for in_file_name in tqdm(in_file_names): # support for gzip files if in_file_name.endswith(".gz"): - fin = gzip.open(in_file_name, 'rt') + fin = gzip.open(in_file_name, 'rt', encoding='utf-8', errors='replace') else: fin = open(in_file_name, 'r', encoding='utf-8') diff --git a/tools/process_wikipedia.py b/tools/process_wikipedia.py new file mode 100644 index 0000000000..89f658fa37 --- /dev/null +++ b/tools/process_wikipedia.py @@ -0,0 +1,28 @@ +import os +from datasets import load_dataset,concatenate_datasets +from tqdm import tqdm +dataset_root = "dataset/wikipedia" + +configs = [d for d in os.listdir(dataset_root) if os.path.isdir(os.path.join(dataset_root, d))] + +all_datasets = [] +for config in tqdm(configs): + try: + ds = load_dataset(dataset_root, config, split="train", keep_in_memory=False) + all_datasets.append(ds) + except Exception as e: + print(f"Failed to load dataset for config {config}: {e}") + +if all_datasets: + # merged_dataset = all_datasets[0] + merged_dataset = concatenate_datasets(all_datasets) + # for ds in tqdm(all_datasets[1:]): + # merged_dataset = merged_dataset.concatenate(ds) + print("All datasets merged successfully.") +else: + print("No datasets were loaded. Please check the dataset root directory and configurations.") + +output_file = "dataset/wikipedia.json" +merged_dataset.to_json(output_file, orient="records", lines=True, force_ascii=False) + +print(f"Merged dataset saved to {output_file}") diff --git a/transformer_engine b/transformer_engine new file mode 120000 index 0000000000..bdafc97101 --- /dev/null +++ b/transformer_engine @@ -0,0 +1 @@ +/usr/local/lib/python3.12/dist-packages/transformer_engine/ \ No newline at end of file diff --git a/visualization/README.md b/visualization/README.md new file mode 100644 index 0000000000..58a9e1a9fc --- /dev/null +++ b/visualization/README.md @@ -0,0 +1,251 @@ +# Tensor Visualization and Analysis Tools + +This directory contains tools for analyzing and visualizing tensor data across different numerical formats (bf16, hifp8, mxfp8, mxfp4). + +## Files Overview + +- `overflow.py` - **Value-level analysis**: Shows percentage of tensor values that overflow/underflow within files +- `overflow_summary.py` - **File-level analysis**: Shows percentage of tensor files that have overflow/underflow issues +- `layer_analysis.py` - Layer-specific analysis and visualization +- `distribution.py` - **Single tensor distribution**: Visualizes tensor value distribution against format's representable values +- `README.md` - This documentation file + +### Key Difference +- **overflow.py**: "X% of values in this tensor file overflow/underflow" +- **overflow_summary.py**: "X% of tensor files contain overflow/underflow issues" + +## Prerequisites + +Make sure you have the required Python packages installed: + +```bash +pip install torch numpy matplotlib +``` + +## Usage Examples + +### 1. File-Level Analysis (Which files have issues?) + +Analyze what percentage of tensor files contain overflow/underflow problems: + +```python +# Run comprehensive file-level analysis +python overflow_summary.py --base-dir enhanced_tensor_logs --output-dir visualization + +# Or with custom paths +python overflow_summary.py --base-dir /path/to/tensor/logs --output-dir /path/to/output +``` + +This will generate: +- `overflow_comprehensive_report.txt` - File-level analysis report (e.g., "15% of bf16 files have overflow") +- `overflow_detailed_results.json` - Complete results in JSON format +- `overflow_summary.csv` - Summary statistics in CSV format + +### 2. Value-Level Analysis (How many values overflow?) + +Analyze what percentage of tensor values overflow/underflow within specific files: + +```python +# Analyze a single tensor file (shows % of values that overflow) +python overflow.py enhanced_tensor_logs/bf16/tensor_file.pt + +# Analyze multiple tensor files at once +python overflow.py file1.pt file2.pt file3.pt + +# Analyze all files in a directory (shows % of values in each file) +python overflow.py enhanced_tensor_logs/bf16/ --recursive + +# Mix files and directories +python overflow.py tensor1.pt enhanced_tensor_logs/bf16/ tensor2.pt --recursive + +# Generate CSV output +python overflow.py enhanced_tensor_logs/mxfp4/ --recursive --format csv --output mxfp4_value_results.csv +``` + +### 3. Single Tensor Distribution Analysis + +Visualize how well a specific tensor fits within a data format's representable values: + +```python +# Basic distribution analysis +python distribution.py enhanced_tensor_logs/bf16/tensor_file.pt + +# With detailed statistics +python distribution.py enhanced_tensor_logs/mxfp4/tensor_file.pt --show-stats + +# Custom output directory +python distribution.py tensor_file.pt --output-dir /path/to/output/ +``` + +This will generate: +- `tensor_name.png` - Distribution plot with representable values overlay +- Console output with usability assessment (EXCELLENT/GOOD/CAUTION/POOR) + +### 4. Layer-Specific Analysis + +Analyze specific layers, ranks, and tensor types (supports multi-parameter combinations): + +```python +# Single combination analysis +python layer_analysis.py --layer 0 --rank 0 --type linear + +# Multi-parameter analysis (automatic combinations) +python layer_analysis.py --layer 1,8,15,16 --rank 0,1 --type linear,attention + +# Mixed single and multiple parameters +python layer_analysis.py --layer 0 --rank 0,1,2 --type attention + +# Analyze specific data format only +python layer_analysis.py --layer 0 --rank 0 --type linear --format bf16 + +# Custom input/output directories +python layer_analysis.py --layer 1,2 --rank 0,1 --type linear --base-dir /path/to/logs --output-dir /path/to/output +``` + +**Multi-parameter examples:** +- `--layer 1,8,15,16 --rank 0,1 --type linear,attention` generates 16 combinations (4×2×2) +- Each combination produces separate plots and reports +- Progress bar shows real-time processing status + +This will generate: +- `layer_X_rank_Y_TYPE_analysis.png` - Distribution plots for all data formats +- `layer_X_rank_Y_TYPE_report.txt` - Detailed analysis report + +## Output Files Description + +### Overflow Analysis Reports + +#### File-Level Reports (from overflow_summary.py) +1. **overflow_comprehensive_report.txt** + - **Primary focus**: Percentage of files with overflow/underflow issues + - Format-specific file statistics (e.g., "25% of bf16 files have overflow") + - Secondary reference: Overall value percentages + - Risk assessment based on file percentages + +2. **overflow_detailed_results.json** + - Complete analysis results with both file and value statistics + - Suitable for programmatic processing + - Contains metadata for each analyzed file + +3. **overflow_summary.csv** + - Tabular summary focused on file-level statistics + - Easy to import into spreadsheet applications + +#### Value-Level Reports (from overflow.py) +1. **Individual file analysis** + - **Primary focus**: Percentage of values that overflow/underflow within each file + - Detailed statistics for each tensor file + - File-by-file breakdown of problematic values + +### Distribution Analysis Outputs (from distribution.py) + +1. **Distribution Plots (PNG files)** + - Histogram of tensor value distribution + - **Red vertical lines**: All representable values in the data format + - **Dark red dashed lines**: Overflow boundaries (±max_normal) + - **Orange dotted lines**: Underflow boundaries (±min_denormal) + - **Statistics box**: Complete analysis summary + - **Usability assessment**: Color-coded recommendation (🟢🟡🟠🔴) + +### Layer Analysis Outputs + +1. **Distribution Plots (PNG files)** + - Multi-panel plots showing tensor value distributions + - Overlay of representable values for each format + - Format range indicators and overflow/underflow warnings + - Statistical information for each component + +2. **Analysis Reports (TXT files)** + - Detailed statistics for each tensor component + - Overflow/underflow detection results + - Recommendations based on findings + +## Data Format Information + +The tools understand the following numerical formats with their accurate ranges: + +| Format | Description | Max Normal | Min Denormal | Supports Inf/NaN | +|--------|-------------|------------|--------------|------------------| +| bf16 | Brain Float 16-bit | ±6.55×10⁴ | 5.96×10⁻⁸ | Yes/Yes | +| hifp8 | Huawei HiFP8 E4M3 | ±3.28×10⁴ | 2.38×10⁻⁷ | Yes/Yes | +| mxfp8 | Microsoft MX FP8 E4M3 | ±4.48×10² (min normal: ±1.56×10⁻²) | 1.95×10⁻³ | No/Yes | +| mxfp4 | Microsoft MX FP4 E2M1 | ±6.00×10⁰ (min normal: ±1.0) | 5.00×10⁻¹ | No/No | + +### Overflow and Underflow Definition + +- **Overflow**: Values with magnitude exceeding the maximum representable value (|value| > max_normal) +- **Underflow**: Non-zero values closer to zero than the smallest representable non-zero value (0 < |value| < min_denormal) + +Note: Underflow represents values that would be rounded to zero in the target format, causing precision loss. + +## File Naming Convention + +The tools expect tensor files to follow this naming pattern: +``` +YYYYMMDD_HHMMSS_XXXX_iterNNN_TYPE_LXX_forward/backward_FORMAT_rankXX_sampleX_groupXXX_COMPONENT.pt +``` + +Where: +- `TYPE` = linear, attention +- `LXX` = layer number (e.g., L0, L16) +- `FORMAT` = bf16, hifp8, mxfp8, mxfp4 +- `rankXX` = rank/card number +- `COMPONENT` = input, output, weight, bias, query, key, value, attention_weights, etc. + +## Advanced Usage + +### Programmatic Access + +You can also use the modules programmatically: + +```python +from overflow import analyze_file, DATA_TYPE_RANGES +from layer_analysis import find_matching_files, analyze_tensor + +# Analyze a single file +result = analyze_file('/path/to/tensor.pt') +if result: + print(f"Overflow: {result['overflow_percent']:.2f}%") + +# Find files matching criteria +files = find_matching_files('enhanced_tensor_logs', layer=0, rank=0, tensor_type='linear') +``` + +### Batch Processing + +For processing multiple configurations: + +```bash +# Analyze multiple layers and ranks +for layer in {0..16}; do + for rank in {0..7}; do + python layer_analysis.py --layer $layer --rank $rank --type linear + python layer_analysis.py --layer $layer --rank $rank --type attention + done +done +``` + +## Troubleshooting + +### Common Issues + +1. **Import errors**: Ensure PyTorch and other dependencies are installed +2. **File not found**: Check that tensor files exist in expected locations +3. **Memory issues**: For large tensors, the tools load data to CPU to conserve memory +4. **Permission errors**: Ensure write permissions for output directory + +### Performance Tips + +- Use `--format` option to analyze specific formats when you don't need all +- The tools automatically handle CUDA tensors by moving them to CPU +- Large batch analyses may take significant time - consider running in background + +## Support + +If you encounter issues or need additional features, please check: +1. File naming conventions match expected patterns +2. Required Python packages are installed +3. Input directories and files exist and are readable +4. Output directories are writable + +For custom analysis requirements, the modular design allows easy extension of the existing tools. diff --git a/visualization/distribution.py b/visualization/distribution.py new file mode 100755 index 0000000000..017e55e122 --- /dev/null +++ b/visualization/distribution.py @@ -0,0 +1,350 @@ +#!/usr/bin/env python3 +""" +Tensor distribution visualization tool. +Generates distribution plots showing tensor values against representable values for the data format. +""" + +import os +import torch +import numpy as np +import matplotlib.pyplot as plt +import argparse +from pathlib import Path + +# Import data format information from layer_analysis +import sys +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from layer_analysis import DATA_TYPE_INFO + +def load_and_process_tensor(filepath): + """ + Load and process a tensor file. + + Args: + filepath (str): Path to tensor file + + Returns: + tuple: (tensor_data, data_format, filename) or (None, None, None) if failed + """ + try: + filename = os.path.basename(filepath) + + # Detect data format from filename + data_format = detect_data_format(filename) + if data_format is None: + print(f"Warning: Could not detect data format from filename: {filename}") + return None, None, None + + # Load tensor + tensor = torch.load(filepath, map_location='cpu', weights_only=False) + + # Handle case where loaded object is not a tensor + if not isinstance(tensor, torch.Tensor): + if isinstance(tensor, dict) and 'tensor' in tensor: + tensor = tensor['tensor'] + elif isinstance(tensor, (list, tuple)) and len(tensor) > 0: + tensor = tensor[0] + else: + print(f"Warning: Loaded object is not a tensor: {filename}") + return None, None, None + + # Convert BFloat16 and other unsupported types to Float32 for CPU processing + if tensor.dtype == torch.bfloat16: + tensor = tensor.float() + elif tensor.dtype in [torch.float16, torch.half]: + tensor = tensor.float() + elif tensor.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: + tensor = tensor.float() + elif tensor.dtype in [torch.uint8]: + tensor = tensor.float() + + # Convert to numpy and flatten + if tensor.is_cuda: + tensor_np = tensor.cpu().numpy() + else: + tensor_np = tensor.numpy() + + # Handle empty tensors + if tensor_np.size == 0: + print(f"Warning: Empty tensor: {filename}") + return None, None, None + + # Handle complex tensors + if tensor_np.dtype in [np.complex64, np.complex128]: + tensor_np = np.abs(tensor_np) + + # Flatten for distribution analysis + flat_tensor = tensor_np.flatten() + + return flat_tensor, data_format, filename + + except Exception as e: + print(f"Error processing file {filepath}: {str(e)}") + return None, None, None + +def detect_data_format(filename): + """Extract data format from filename.""" + for fmt in DATA_TYPE_INFO.keys(): + if fmt in filename: + return fmt + return None + +def create_distribution_plot(tensor_data, data_format, filename, output_path): + """ + Create distribution plot with representable values overlay. + + Args: + tensor_data (np.array): Flattened tensor data + data_format (str): Data format identifier + filename (str): Original filename + output_path (Path): Output file path + """ + if data_format not in DATA_TYPE_INFO: + raise ValueError(f"Unknown data format: {data_format}") + + format_info = DATA_TYPE_INFO[data_format] + + # Calculate data statistics for dynamic range adjustment + data_min, data_max = np.min(tensor_data), np.max(tensor_data) + data_range = data_max - data_min + data_abs_max = max(abs(data_min), abs(data_max)) + + # Check for actual overflow/underflow + min_denormal = format_info['min_denormal'] + max_normal = format_info['max'] + + has_overflow = np.any(np.abs(tensor_data) > max_normal) + has_underflow = np.any((tensor_data != 0.0) & (np.abs(tensor_data) < min_denormal)) + + # Dynamic range calculation - focus on data distribution + if data_range > 0: + margin = data_range * 0.15 # 15% margin + plot_min = data_min - margin + plot_max = data_max + margin + + # Ensure underflow boundary is visible if relevant + if min_denormal > 0 and data_abs_max > min_denormal * 0.1: + plot_min = min(plot_min, -min_denormal * 1.5) + plot_max = max(plot_max, min_denormal * 1.5) + else: + plot_min, plot_max = data_min - 1, data_max + 1 + + # Create figure + plt.figure(figsize=(14, 10)) + + # Calculate histogram + n_bins = min(200, max(50, int(np.sqrt(len(tensor_data))))) + counts, bins, patches = plt.hist(tensor_data, bins=n_bins, alpha=0.7, + color=format_info['color'], density=True, + label=f'Tensor Values (n={len(tensor_data):,})') + + # Set dynamic x-axis range + plt.xlim(plot_min, plot_max) + + # Add representable values as vertical red lines (filtered to plot range) + if format_info['representable_values'] is not None: + rep_values = np.array(format_info['representable_values']) + visible_rep_values = rep_values[(rep_values >= plot_min) & (rep_values <= plot_max)] + + print(f"Showing {len(visible_rep_values)} representable values in range [{plot_min:.3f}, {plot_max:.3f}]") + + # Add vertical lines for representable values + for val in visible_rep_values: + plt.axvline(val, color='red', alpha=0.6, linewidth=1.0, zorder=3) + + # Add overflow boundaries only if relevant (actual overflow or values close to boundary) + if has_overflow or data_abs_max > max_normal * 0.5: + if plot_max >= max_normal: + plt.axvline(max_normal, color='darkred', linestyle='--', linewidth=2, + alpha=0.9, label=f'Overflow Boundary (+{max_normal:.1e})', zorder=4) + if plot_min <= -max_normal: + plt.axvline(-max_normal, color='darkred', linestyle='--', linewidth=2, + alpha=0.9, label=f'Overflow Boundary (-{max_normal:.1e})', zorder=4) + + # Always show underflow boundaries (critical for precision analysis) + if plot_max >= min_denormal: + plt.axvline(min_denormal, color='orange', linestyle=':', linewidth=2, + alpha=0.9, label=f'Underflow Boundary (+{min_denormal:.1e})', zorder=4) + if plot_min <= -min_denormal: + plt.axvline(-min_denormal, color='orange', linestyle=':', linewidth=2, + alpha=0.9, label=f'Underflow Boundary (-{min_denormal:.1e})', zorder=4) + + # Add zero reference line + if plot_min < 0 < plot_max: + plt.axvline(0, color='gray', linestyle='-', linewidth=1, alpha=0.5, label='Zero') + + # Calculate and display statistics + tensor_min = np.min(tensor_data) + tensor_max = np.max(tensor_data) + tensor_mean = np.mean(tensor_data) + tensor_std = np.std(tensor_data) + + # Check for overflow/underflow + overflow_count = np.sum(np.abs(tensor_data) > format_info['max']) + non_zero_mask = tensor_data != 0.0 + abs_tensor = np.abs(tensor_data) + underflow_count = np.sum(non_zero_mask & (abs_tensor < min_denormal)) + + overflow_percent = (overflow_count / len(tensor_data)) * 100 + underflow_percent = (underflow_count / len(tensor_data)) * 100 + + # Add statistics text box + stats_text = ( + f'Data Format: {data_format.upper()} ({format_info["description"]})\n' + f'Tensor Shape: {tensor_data.shape if hasattr(tensor_data, "shape") else "Flattened"}\n' + f'Total Elements: {len(tensor_data):,}\n\n' + f'Value Statistics:\n' + f' Min: {tensor_min:.6f}\n' + f' Max: {tensor_max:.6f}\n' + f' Mean: {tensor_mean:.6f}\n' + f' Std: {tensor_std:.6f}\n\n' + f'Format Boundaries:\n' + f' Max Normal: ±{format_info["max"]:.1e}\n' + f' Min Denormal: {min_denormal:.1e}\n\n' + f'Overflow/Underflow:\n' + f' Overflow: {overflow_count:,} ({overflow_percent:.4f}%)\n' + f' Underflow: {underflow_count:,} ({underflow_percent:.4f}%)' + ) + + # Add warning for issues + if overflow_count > 0 or underflow_count > 0: + stats_text += '\n\n⚠️ NUMERICAL ISSUES DETECTED!' + else: + stats_text += '\n\n✅ No overflow/underflow detected' + + # Position stats box + plt.text(0.02, 0.98, stats_text, transform=plt.gca().transAxes, + verticalalignment='top', fontsize=9, + bbox=dict(boxstyle='round,pad=0.5', facecolor='white', alpha=0.9, edgecolor='gray')) + + # Set labels and title + plt.xlabel('Value', fontsize=12) + plt.ylabel('Density', fontsize=12) + plt.title(f'Tensor Value Distribution vs {data_format.upper()} Representable Values\n' + f'File: {filename}', fontsize=14, fontweight='bold', pad=20) + + # Add legend + legend_elements = [ + plt.Line2D([0], [0], color=format_info['color'], alpha=0.7, linewidth=5, label='Tensor Values'), + plt.Line2D([0], [0], color='red', alpha=0.6, linewidth=1, label='Representable Values'), + plt.Line2D([0], [0], color='darkred', linestyle='--', linewidth=2, label='Overflow Boundary'), + plt.Line2D([0], [0], color='orange', linestyle=':', linewidth=2, label='Underflow Boundary') + ] + + plt.legend(handles=legend_elements, loc='upper right', fontsize=10) + + # Improve layout + plt.grid(True, alpha=0.3) + plt.tight_layout() + + # Save plot + plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white') + plt.close() + + print(f"Distribution plot saved to: {output_path}") + + # Return analysis summary + return { + 'filename': filename, + 'data_format': data_format, + 'total_elements': len(tensor_data), + 'value_range': [tensor_min, tensor_max], + 'mean_std': [tensor_mean, tensor_std], + 'overflow_count': overflow_count, + 'underflow_count': underflow_count, + 'overflow_percent': overflow_percent, + 'underflow_percent': underflow_percent, + 'format_max': format_info['max'], + 'format_min_denormal': min_denormal, + 'representable_values_shown': len(visible_rep_values) if format_info['representable_values'] is not None else 0 + } + +def main(): + """Main function for distribution analysis.""" + parser = argparse.ArgumentParser(description='Generate tensor value distribution plots with representable values overlay') + parser.add_argument('input_file', help='Path to tensor file (.pt)') + parser.add_argument('--output-dir', default='./draw/distribution_tensor/', + help='Output directory for plots (default: ./draw/distribution_tensor/)') + parser.add_argument('--show-stats', action='store_true', + help='Print detailed statistics to console') + + args = parser.parse_args() + + # Validate input file + input_path = Path(args.input_file) + if not input_path.exists(): + print(f"Error: Input file does not exist: {input_path}") + return 1 + + if not input_path.is_file(): + print(f"Error: Input path is not a file: {input_path}") + return 1 + + # Setup output directory + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Generate output filename: same as input but with .png extension + output_filename = input_path.stem + '.png' + output_path = output_dir / output_filename + + print(f"Analyzing tensor distribution: {input_path.name}") + print("=" * 60) + + # Load and process tensor + tensor_data, data_format, filename = load_and_process_tensor(str(input_path)) + + if tensor_data is None: + print("Failed to load or process tensor file.") + return 1 + + print(f"Data format detected: {data_format.upper()}") + print(f"Tensor elements: {len(tensor_data):,}") + print(f"Value range: [{np.min(tensor_data):.6f}, {np.max(tensor_data):.6f}]") + + # Create distribution plot + analysis_summary = create_distribution_plot(tensor_data, data_format, filename, output_path) + + # Print summary statistics if requested + if args.show_stats: + print("\nDetailed Analysis Summary:") + print("-" * 40) + print(f"Data Format: {analysis_summary['data_format'].upper()}") + print(f"Total Elements: {analysis_summary['total_elements']:,}") + print(f"Value Range: [{analysis_summary['value_range'][0]:.6f}, {analysis_summary['value_range'][1]:.6f}]") + print(f"Mean ± Std: {analysis_summary['mean_std'][0]:.6f} ± {analysis_summary['mean_std'][1]:.6f}") + print(f"Format Max Normal: ±{analysis_summary['format_max']:.1e}") + print(f"Format Min Denormal: {analysis_summary['format_min_denormal']:.1e}") + print(f"Representable Values Shown: {analysis_summary['representable_values_shown']}") + print(f"Overflow: {analysis_summary['overflow_count']:,} ({analysis_summary['overflow_percent']:.4f}%)") + print(f"Underflow: {analysis_summary['underflow_count']:,} ({analysis_summary['underflow_percent']:.4f}%)") + + if analysis_summary['overflow_count'] > 0 or analysis_summary['underflow_count'] > 0: + print("⚠️ Numerical issues detected!") + else: + print("✅ No numerical issues detected") + + print(f"\nVisualization complete!") + print(f"Plot saved to: {output_path}") + + # Provide usage assessment + print("\nUsability Assessment:") + print("-" * 30) + + overflow_pct = analysis_summary['overflow_percent'] + underflow_pct = analysis_summary['underflow_percent'] + + if overflow_pct == 0 and underflow_pct == 0: + print("🟢 EXCELLENT: All values are representable in this format") + elif overflow_pct < 0.1 and underflow_pct < 0.1: + print("🟡 GOOD: Less than 0.1% of values have representation issues") + elif overflow_pct < 1.0 and underflow_pct < 1.0: + print("🟠 CAUTION: Less than 1% of values have representation issues") + else: + print("🔴 POOR: Significant representation issues detected") + print(" Consider using a higher precision format for this tensor") + + return 0 + +if __name__ == "__main__": + exit(main()) diff --git a/visualization/layer_analysis.py b/visualization/layer_analysis.py new file mode 100755 index 0000000000..2416956f66 --- /dev/null +++ b/visualization/layer_analysis.py @@ -0,0 +1,682 @@ +#!/usr/bin/env python3 +""" +Layer-specific tensor analysis and visualization. +Generates distribution plots and analysis reports for specific layers and ranks. +""" + +import os +import re +import torch +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +from matplotlib.gridspec import GridSpec +import argparse +from pathlib import Path +from datetime import datetime +from itertools import product +from tqdm import tqdm + +# Define numerical format ranges and representable values +DATA_TYPE_INFO = { + 'bf16': { + 'min_normal': 6.103516e-05, # BFloat16 minimum normal value + 'max_normal': 6.550400e+04, # BFloat16 maximum normal value + 'min_denormal': 5.960464e-08, # BFloat16 minimum denormal value + 'max_denormal': 6.097555e-05, # BFloat16 maximum denormal value + 'min': -6.550400e+04, # Effective minimum (negative max normal) + 'max': 6.550400e+04, # Effective maximum (positive max normal) + 'supports_infinity': True, + 'supports_nan': True, + 'description': 'Brain Float 16-bit', + 'color': '#1f77b4', # Blue + 'representable_values': None # Too many to enumerate + }, + 'hifp8': { + 'min_normal': 3.051758e-05, # HiFP8 minimum normal value (2^-15) + 'max_normal': 3.276800e+04, # HiFP8 maximum normal value (2^15) + 'min_denormal': 2.384186e-07, # HiFP8 minimum denormal value (2^-22) + 'max_denormal': 1.525879e-05, # HiFP8 maximum denormal value (approx 2^-16) + 'min': -3.276800e+04, # Effective minimum (negative max normal) + 'max': 3.276800e+04, # Effective maximum (positive max normal) + 'supports_infinity': True, + 'supports_nan': True, + 'description': 'Huawei HiFP8 E4M3', + 'color': '#ff7f0e', # Orange + 'representable_values': None # Generate programmatically + }, + 'mxfp8': { + 'min_normal': 1.562500e-02, # MX FP8 minimum normal value (0.015625) + 'max_normal': 4.480000e+02, # MX FP8 maximum normal value (448.0) + 'min_denormal': 1.953125e-03, # MX FP8 minimum denormal value (2^-9) + 'max_denormal': 1.367188e-02, # MX FP8 maximum denormal value (7*2^-9) + 'min': -4.480000e+02, # Effective minimum (negative max normal) + 'max': 4.480000e+02, # Effective maximum (positive max normal) + 'supports_infinity': False, + 'supports_nan': True, + 'description': 'Microsoft MX FP8 E4M3', + 'color': '#2ca02c', # Green + 'representable_values': None # Generate programmatically + }, + 'mxfp4': { + 'min_normal': 1.000000e+00, # MX FP4 minimum normal value + 'max_normal': 6.000000e+00, # MX FP4 maximum normal value + 'min_denormal': 5.000000e-01, # MX FP4 minimum denormal value (2^-1) + 'max_denormal': 5.000000e-01, # MX FP4 maximum denormal value (only one denormal: 0.5) + 'min': -6.000000e+00, # Effective minimum (negative max normal) + 'max': 6.000000e+00, # Effective maximum (positive max normal) + 'supports_infinity': False, + 'supports_nan': False, + 'description': 'Microsoft MX FP4 E2M1', + 'color': '#d62728', # Red + 'representable_values': [ + # Negative values (symmetric to positive) + -6.000000e+00, -5.000000e+00, -4.000000e+00, -3.500000e+00, -3.000000e+00, -2.500000e+00, + -2.000000e+00, -1.750000e+00, -1.500000e+00, -1.250000e+00, -1.000000e+00, -5.000000e-01, + # Zero + 0.0, + # Positive denormal value (only one: 0.5) + 5.000000e-01, + # Positive normal values (1.0 and above) + 1.000000e+00, 1.250000e+00, 1.500000e+00, 1.750000e+00, 2.000000e+00, + 2.500000e+00, 3.000000e+00, 3.500000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00 + ] + } +} + +# Define tensor component types for different operations +LINEAR_COMPONENTS = ['input', 'weight', 'output', 'bias', 'hidden'] +ATTENTION_COMPONENTS = ['query', 'key', 'value', 'attention_weights', 'output'] + +def generate_fp8_representable_values(format_name): + """Generate representable values for FP8 formats based on mxfp.py implementation.""" + values = set([0.0]) # Include zero + + if format_name == 'hifp8': + # HiFP8 format with dynamic mantissa bits based on hifp.py implementation + # Mantissa bits depend on |exponent|: + # |e| <= 3: 3 mantissa bits + # |e| <= 7: 2 mantissa bits + # |e| <= 15: 1 mantissa bit + # Range: exp ∈ [-22, 15], but normal starts from exp=-15 + + # Generate denormal values (exp < -15) + for exp in range(-22, -15): # -22 to -16 + # For very small exponents, use minimal mantissa representation + value = 2**exp + values.add(value) + values.add(-value) + + # Generate normal values (exp >= -15) + for exp in range(-15, 16): # -15 to 15 + # Determine mantissa bits based on |exp| + abs_exp = abs(exp) + if abs_exp <= 3: + mant_bits = 3 + elif abs_exp <= 7: + mant_bits = 2 + else: # abs_exp <= 15 + mant_bits = 1 + + # Generate all possible mantissa values for this exponent + for mantissa in range(2**mant_bits): + value = (1 + mantissa / (2**mant_bits)) * (2**exp) + + # Apply the max clamp from hifp.py: max = 2^15 = 32768 + if abs(value) <= 32768.0: + values.add(value) + values.add(-value) + + elif format_name == 'mxfp8': + # MX FP8 E4M3 format: 4 exponent bits, 3 mantissa bits + # Based on mxfp.py: ebits=4, mbits=5 (total), emax=8, bias=7 + # Special handling: max_norm = 2^emax * 1.75 = 448.0 + ebits, mbits_frac = 4, 3 # 3 fractional mantissa bits + bias = 2**(ebits-1) - 1 # 7 + emin = 1 - bias # -6 + emax = 8 # From mxfp.py code + + # Generate denormal values (exponent = emin, mantissa != 0) + for mantissa in range(1, 2**mbits_frac): # 1 to 7 + value = mantissa * (2**(emin - mbits_frac)) # mantissa * 2^(-9) + values.add(value) + values.add(-value) + + # Generate normal values + for exp in range(emin, emax + 1): # -6 to 8 + for mantissa in range(2**mbits_frac): # 0 to 7 + value = (1 + mantissa / (2**mbits_frac)) * (2**exp) + + # Special handling for MXFP8: max representable is 448.0 + # This corresponds to exp=8, mantissa=6: (1 + 6/8) * 2^8 = 448.0 + # Skip exp=8, mantissa=7 which would give 480.0 + if exp == emax and mantissa >= 7: # Skip only mantissa=7 at exp=8 + continue + + values.add(value) + values.add(-value) + + return sorted(list(values)) + +# Generate representable values for FP8 formats +DATA_TYPE_INFO['hifp8']['representable_values'] = generate_fp8_representable_values('hifp8') +DATA_TYPE_INFO['mxfp8']['representable_values'] = generate_fp8_representable_values('mxfp8') + +def find_matching_files(base_dir, layer, rank, tensor_type, data_format=None): + """ + Find tensor files matching the specified criteria. + + Args: + base_dir (str): Base directory containing tensor files + layer (int): Layer number + rank (int): Rank number + tensor_type (str): Type of tensor ('linear' or 'attention') + data_format (str, optional): Specific data format to search + + Returns: + dict: Dictionary mapping data formats to lists of matching files + """ + base_path = Path(base_dir) + matching_files = {} + + # Define search patterns + layer_pattern = f"L{layer}_" + rank_pattern = f"rank{rank:02d}_" + type_pattern = f"{tensor_type}_" + + # Search in each data format directory + search_formats = [data_format] if data_format else DATA_TYPE_INFO.keys() + + for fmt in search_formats: + format_dir = base_path / fmt + if not format_dir.exists(): + continue + + format_files = [] + for file_path in format_dir.glob("*.pt"): + filename = file_path.name + + # Check if file matches criteria + if (layer_pattern in filename and + rank_pattern in filename and + type_pattern in filename and + fmt in filename): + format_files.append(file_path) + + if format_files: + matching_files[fmt] = format_files + + return matching_files + +def extract_component_type(filename): + """Extract the component type from filename.""" + filename_lower = filename.lower() + + # Check for linear components + for component in LINEAR_COMPONENTS: + if component in filename_lower: + return component + + # Check for attention components + for component in ATTENTION_COMPONENTS: + if component in filename_lower: + return component + + return 'unknown' + +def analyze_tensor(tensor_path, data_format): + """ + Analyze a single tensor file. + + Args: + tensor_path (Path): Path to tensor file + data_format (str): Data format identifier + + Returns: + dict: Analysis results + """ + try: + tensor = torch.load(tensor_path, map_location='cpu', weights_only=False) + + if not isinstance(tensor, torch.Tensor): + if isinstance(tensor, dict) and 'tensor' in tensor: + tensor = tensor['tensor'] + elif isinstance(tensor, (list, tuple)) and len(tensor) > 0: + tensor = tensor[0] + else: + return None + + # Convert BFloat16 and other unsupported types to Float32 for CPU processing + if tensor.dtype == torch.bfloat16: + tensor = tensor.float() + elif tensor.dtype in [torch.float16, torch.half]: + tensor = tensor.float() + elif tensor.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: + tensor = tensor.float() + elif tensor.dtype in [torch.uint8]: + tensor = tensor.float() + + # Convert to numpy for analysis + if tensor.is_cuda: + tensor_np = tensor.cpu().numpy() + else: + tensor_np = tensor.numpy() + + # Handle empty tensors + if tensor_np.size == 0: + return None + + # Handle complex tensors + if tensor_np.dtype in [np.complex64, np.complex128]: + tensor_np = np.abs(tensor_np) + + # Flatten tensor for distribution analysis + flat_tensor = tensor_np.flatten() + + # Calculate statistics + stats = { + 'filename': tensor_path.name, + 'data_format': data_format, + 'shape': list(tensor.shape), + 'total_elements': tensor_np.size, + 'min_val': float(np.min(flat_tensor)), + 'max_val': float(np.max(flat_tensor)), + 'mean_val': float(np.mean(flat_tensor)), + 'std_val': float(np.std(flat_tensor)), + 'median_val': float(np.median(flat_tensor)), + 'component_type': extract_component_type(tensor_path.name), + 'tensor_data': flat_tensor, + 'format_info': DATA_TYPE_INFO[data_format] + } + + # Calculate overflow/underflow + format_max = DATA_TYPE_INFO[data_format]['max'] + min_denormal = DATA_TYPE_INFO[data_format]['min_denormal'] # Smallest representable non-zero value + + # Overflow: values exceeding maximum representable value + overflow_count = np.sum(np.abs(flat_tensor) > format_max) + + # Underflow: non-zero values closer to zero than smallest representable non-zero value + non_zero_mask = flat_tensor != 0.0 + abs_tensor = np.abs(flat_tensor) + underflow_count = np.sum(non_zero_mask & (abs_tensor < min_denormal)) + + stats.update({ + 'overflow_count': int(overflow_count), + 'underflow_count': int(underflow_count), + 'overflow_percent': float(overflow_count / tensor_np.size * 100), + 'underflow_percent': float(underflow_count / tensor_np.size * 100) + }) + + return stats + + except Exception as e: + print(f"Error analyzing {tensor_path}: {str(e)}") + return None + +def create_distribution_plot(tensor_stats_list, layer, rank, tensor_type, output_dir): + """ + Create comprehensive distribution plots for all data formats. + + Args: + tensor_stats_list (list): List of tensor analysis results + layer (int): Layer number + rank (int): Rank number + tensor_type (str): Type of tensor + output_dir (Path): Output directory + """ + if not tensor_stats_list: + print("No tensor data to plot") + return + + # Group by component type + components = {} + for stats in tensor_stats_list: + comp_type = stats['component_type'] + if comp_type not in components: + components[comp_type] = {} + components[comp_type][stats['data_format']] = stats + + # Determine subplot layout + n_components = len(components) + n_formats = len(DATA_TYPE_INFO) + + # Create figure with subplots + fig_height = max(12, 4 * n_components) + fig_width = max(16, 4 * n_formats) + + fig = plt.figure(figsize=(fig_width, fig_height)) + gs = GridSpec(n_components, n_formats, figure=fig, hspace=0.3, wspace=0.3) + + # Set overall title + fig.suptitle(f'Layer {layer} Rank {rank} - {tensor_type.capitalize()} Analysis\n' + f'Generated on {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}', + fontsize=16, fontweight='bold') + + # Create subplots for each component and format + row_idx = 0 + for comp_type, comp_data in components.items(): + col_idx = 0 + + for fmt, fmt_info in DATA_TYPE_INFO.items(): + ax = fig.add_subplot(gs[row_idx, col_idx]) + + if fmt in comp_data: + stats = comp_data[fmt] + create_single_distribution_subplot(ax, stats, fmt_info) + else: + # No data for this format + ax.text(0.5, 0.5, f'No {fmt.upper()}\ndata available', + ha='center', va='center', transform=ax.transAxes, + fontsize=12, style='italic', color='gray') + ax.set_title(f'{comp_type.capitalize()} - {fmt.upper()}') + ax.set_xticks([]) + ax.set_yticks([]) + + col_idx += 1 + + row_idx += 1 + + # Save the plot + output_file = output_dir / f'layer_{layer}_rank_{rank}_{tensor_type}_analysis.png' + plt.savefig(output_file, dpi=300, bbox_inches='tight', facecolor='white') + plt.close() + + print(f"Distribution plot saved to: {output_file}") + +def create_single_distribution_subplot(ax, stats, fmt_info): + """Create a single distribution subplot with dynamic range adjustment.""" + tensor_data = stats['tensor_data'] + data_format = stats['data_format'] + + # Calculate data statistics for dynamic range adjustment + data_min, data_max = np.min(tensor_data), np.max(tensor_data) + data_range = data_max - data_min + data_abs_max = max(abs(data_min), abs(data_max)) + + # Check for actual overflow/underflow + min_denormal = fmt_info['min_denormal'] + max_normal = fmt_info['max'] + + has_overflow = np.any(np.abs(tensor_data) > max_normal) + has_underflow = np.any((tensor_data != 0.0) & (np.abs(tensor_data) < min_denormal)) + + # Dynamic range calculation + if data_range > 0: + # Use data range with some margin + margin = data_range * 0.15 # 15% margin + plot_min = data_min - margin + plot_max = data_max + margin + + # Ensure we show underflow boundary if there's potential underflow + if min_denormal > 0 and data_abs_max > min_denormal * 0.1: + # Extend range to show underflow boundaries clearly + plot_min = min(plot_min, -min_denormal * 2) + plot_max = max(plot_max, min_denormal * 2) + else: + # Fallback for edge cases + plot_min, plot_max = data_min - 1, data_max + 1 + + # Create histogram + n_bins = min(100, max(20, int(np.sqrt(len(tensor_data))))) + counts, bins, patches = ax.hist(tensor_data, bins=n_bins, alpha=0.7, + color=fmt_info['color'], density=True, + label='Tensor Values') + + # Set x-axis range based on data + ax.set_xlim(plot_min, plot_max) + + # Add representable values overlay for formats that have them + if fmt_info['representable_values'] and len(fmt_info['representable_values']) < 200: + rep_values = np.array(fmt_info['representable_values']) + # Filter to plot range + rep_in_range = rep_values[(rep_values >= plot_min) & (rep_values <= plot_max)] + + if len(rep_in_range) > 0: + # Add vertical lines for representable values + for val in rep_in_range[::max(1, len(rep_in_range)//50)]: # Limit to 50 lines + ax.axvline(val, color='red', alpha=0.3, linewidth=0.5) + + # Add overflow boundaries only if there's actual overflow or values are close + if has_overflow or data_abs_max > max_normal * 0.5: + if plot_max >= max_normal: + ax.axvline(max_normal, color='darkred', linestyle='--', linewidth=2, + alpha=0.8, label=f'Overflow (+{max_normal:.1e})') + if plot_min <= -max_normal: + ax.axvline(-max_normal, color='darkred', linestyle='--', linewidth=2, + alpha=0.8, label=f'Overflow (-{max_normal:.1e})') + + # Always show underflow boundaries (more important for analysis) + if plot_max >= min_denormal: + ax.axvline(min_denormal, color='orange', linestyle=':', linewidth=2, + alpha=0.9, label=f'Underflow Boundary (+{min_denormal:.1e})') + if plot_min <= -min_denormal: + ax.axvline(-min_denormal, color='orange', linestyle=':', linewidth=2, + alpha=0.9, label=f'Underflow Boundary (-{min_denormal:.1e})') + + # Add zero line for reference + if plot_min < 0 < plot_max: + ax.axvline(0, color='gray', linestyle='-', linewidth=1, alpha=0.5) + + # Add statistics text + stats_text = (f'Min: {stats["min_val"]:.4f}\n' + f'Max: {stats["max_val"]:.4f}\n' + f'Mean: {stats["mean_val"]:.4f}\n' + f'Std: {stats["std_val"]:.4f}\n' + f'Shape: {stats["shape"]}') + + # Add overflow/underflow warning with clear explanation + if stats['overflow_count'] > 0 or stats['underflow_count'] > 0: + warning_text = f'\n⚠️ Overflow: {stats["overflow_count"]} (|val| > max)\n⚠️ Underflow: {stats["underflow_count"]} (0 < |val| < min_denormal)' + stats_text += warning_text + + ax.text(0.02, 0.98, stats_text, transform=ax.transAxes, + verticalalignment='top', fontsize=8, + bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)) + + # Set title and labels + ax.set_title(f'{stats["component_type"].capitalize()} - {data_format.upper()}\n' + f'{fmt_info["description"]}', fontsize=10, fontweight='bold') + ax.set_xlabel('Value') + ax.set_ylabel('Density') + + # Add legend + if fmt_info['representable_values'] and len(fmt_info['representable_values']) < 200: + rep_patch = mpatches.Patch(color='red', alpha=0.3, label='Representable Values') + ax.legend(handles=[rep_patch], loc='upper right', fontsize=8) + +def generate_analysis_report(tensor_stats_list, layer, rank, tensor_type, output_dir): + """Generate detailed analysis report.""" + if not tensor_stats_list: + return + + report_file = output_dir / f'layer_{layer}_rank_{rank}_{tensor_type}_report.txt' + + with open(report_file, 'w') as f: + f.write("=" * 80 + "\n") + f.write(f"LAYER {layer} RANK {rank} - {tensor_type.upper()} ANALYSIS REPORT\n") + f.write("=" * 80 + "\n") + f.write(f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n") + + # Summary + f.write("SUMMARY\n") + f.write("-" * 40 + "\n") + f.write(f"Layer: {layer}\n") + f.write(f"Rank: {rank}\n") + f.write(f"Tensor Type: {tensor_type}\n") + f.write(f"Files Analyzed: {len(tensor_stats_list)}\n\n") + + # Group by component type + components = {} + for stats in tensor_stats_list: + comp_type = stats['component_type'] + if comp_type not in components: + components[comp_type] = [] + components[comp_type].append(stats) + + # Component analysis + f.write("COMPONENT ANALYSIS\n") + f.write("-" * 40 + "\n") + + for comp_type, comp_stats in components.items(): + f.write(f"\n{comp_type.upper()} Component:\n") + f.write("─" * 30 + "\n") + + for stats in comp_stats: + f.write(f"\nData Format: {stats['data_format'].upper()}\n") + f.write(f"File: {stats['filename']}\n") + f.write(f"Shape: {stats['shape']}\n") + f.write(f"Total Elements: {stats['total_elements']:,}\n") + f.write(f"Value Range: [{stats['min_val']:.6f}, {stats['max_val']:.6f}]\n") + f.write(f"Mean ± Std: {stats['mean_val']:.6f} ± {stats['std_val']:.6f}\n") + f.write(f"Median: {stats['median_val']:.6f}\n") + f.write(f"Format Range: [{stats['format_info']['min']}, {stats['format_info']['max']}]\n") + + if stats['overflow_count'] > 0: + f.write(f"⚠️ OVERFLOW: {stats['overflow_count']:,} elements ({stats['overflow_percent']:.4f}%)\n") + + if stats['underflow_count'] > 0: + f.write(f"⚠️ UNDERFLOW: {stats['underflow_count']:,} elements ({stats['underflow_percent']:.4f}%)\n") + + if stats['overflow_count'] == 0 and stats['underflow_count'] == 0: + f.write("✅ No overflow/underflow detected\n") + + # Recommendations + f.write("\n\nRECOMMENDATIONS\n") + f.write("-" * 40 + "\n") + + has_issues = any(s['overflow_count'] > 0 or s['underflow_count'] > 0 for s in tensor_stats_list) + + if has_issues: + f.write("Issues detected in this layer/rank combination:\n") + f.write("• Consider gradient clipping or scaling\n") + f.write("• Review initialization schemes\n") + f.write("• Monitor numerical stability during training\n") + f.write("• Consider mixed precision training strategies\n") + else: + f.write("No numerical issues detected in this layer/rank combination.\n") + f.write("Values are within representable ranges for all formats.\n") + + print(f"Analysis report saved to: {report_file}") + +def parse_multi_values(value_str, value_type=str): + """Parse comma-separated values into a list with type conversion.""" + if ',' in value_str: + values = [item.strip() for item in value_str.split(',')] + else: + values = [value_str.strip()] + + if value_type == int: + return [int(v) for v in values] + return values + +def main(): + """Main function for layer analysis.""" + parser = argparse.ArgumentParser(description='Analyze and visualize layer-specific tensor distributions') + parser.add_argument('--layer', required=True, + help='Layer number(s) to analyze (single: 1, multiple: 1,8,15,16)') + parser.add_argument('--rank', required=True, + help='Rank number(s) to analyze (single: 0, multiple: 0,1)') + parser.add_argument('--type', required=True, + help='Type(s) of tensor operation (single: linear, multiple: linear,attention)') + parser.add_argument('--base-dir', default='enhanced_tensor_logs', + help='Base directory containing tensor files (default: enhanced_tensor_logs)') + parser.add_argument('--output-dir', default='./draw/layer_analysis/', + help='Output directory for plots and reports (default: ./draw/layer_analysis/)') + parser.add_argument('--format', choices=['bf16', 'hifp8', 'mxfp8', 'mxfp4'], + help='Specific data format to analyze (default: all formats)') + + args = parser.parse_args() + + # Parse multi-value inputs + layers = parse_multi_values(args.layer, int) + ranks = parse_multi_values(args.rank, int) + types = parse_multi_values(args.type, str) + + # Validate tensor types + valid_types = ['linear', 'attention'] + for tensor_type in types: + if tensor_type not in valid_types: + print(f"Error: Invalid tensor type '{tensor_type}'. Must be one of: {valid_types}") + return 1 + + # Setup paths + base_dir = Path(args.base_dir) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + if not base_dir.exists(): + print(f"Error: Base directory not found: {base_dir}") + return 1 + + # Generate all combinations + combinations = list(product(layers, ranks, types)) + + print(f"Multi-Parameter Analysis") + print("=" * 60) + print(f"Layers: {layers}") + print(f"Ranks: {ranks}") + print(f"Types: {types}") + print(f"Total combinations: {len(combinations)}") + print("=" * 60) + + successful_analyses = 0 + failed_analyses = 0 + + # Process each combination with progress bar + for layer, rank, tensor_type in tqdm(combinations, desc="Processing combinations", unit="combo"): + try: + # Find matching files + matching_files = find_matching_files( + base_dir=base_dir, + layer=layer, + rank=rank, + tensor_type=tensor_type, + data_format=args.format + ) + + if not matching_files: + tqdm.write(f"No matching files found for Layer {layer}, Rank {rank}, Type: {tensor_type}") + failed_analyses += 1 + continue + + # Analyze tensors + tensor_stats_list = [] + + for data_format, file_list in matching_files.items(): + for file_path in file_list: + stats = analyze_tensor(file_path, data_format) + if stats: + tensor_stats_list.append(stats) + + if not tensor_stats_list: + tqdm.write(f"No valid tensor data found for Layer {layer}, Rank {rank}, Type: {tensor_type}") + failed_analyses += 1 + continue + + # Generate visualizations and reports + create_distribution_plot(tensor_stats_list, layer, rank, tensor_type, output_dir) + generate_analysis_report(tensor_stats_list, layer, rank, tensor_type, output_dir) + + successful_analyses += 1 + tqdm.write(f"✅ Completed Layer {layer}, Rank {rank}, Type: {tensor_type} ({len(tensor_stats_list)} files)") + + except Exception as e: + tqdm.write(f"❌ Error processing Layer {layer}, Rank {rank}, Type: {tensor_type}: {str(e)}") + failed_analyses += 1 + continue + + # Final summary + print("\n" + "=" * 60) + print("ANALYSIS SUMMARY") + print("=" * 60) + print(f"Total combinations: {len(combinations)}") + print(f"Successful analyses: {successful_analyses}") + print(f"Failed analyses: {failed_analyses}") + print(f"Success rate: {(successful_analyses/len(combinations)*100):.1f}%") + print(f"Output files saved in: {output_dir}") + + return 0 if successful_analyses > 0 else 1 + +if __name__ == "__main__": + exit(main()) diff --git a/visualization/overflow.py b/visualization/overflow.py new file mode 100755 index 0000000000..2e1bf0017c --- /dev/null +++ b/visualization/overflow.py @@ -0,0 +1,381 @@ +#!/usr/bin/env python3 +""" +Overflow detection script for tensor files based on numerical format. +Analyzes tensors for overflow and underflow conditions based on their data format. +""" + +import os +import torch +import numpy as np +import argparse +from pathlib import Path + +# Define numerical format ranges based on research and specifications +DATA_TYPE_RANGES = { + 'bf16': { + 'min_normal': 6.103516e-05, # BFloat16 minimum normal value + 'max_normal': 6.550400e+04, # BFloat16 maximum normal value + 'min_denormal': 5.960464e-08, # BFloat16 minimum denormal value + 'max_denormal': 6.097555e-05, # BFloat16 maximum denormal value + 'min': -6.550400e+04, # Effective minimum (negative max normal) + 'max': 6.550400e+04, # Effective maximum (positive max normal) + 'supports_infinity': True, + 'supports_nan': True, + 'description': 'Brain Float 16-bit' + }, + 'hifp8': { + 'min_normal': 3.051758e-05, # HiFP8 minimum normal value (2^-15) + 'max_normal': 3.276800e+04, # HiFP8 maximum normal value (2^15) + 'min_denormal': 2.384186e-07, # HiFP8 minimum denormal value (2^-22) + 'max_denormal': 1.525879e-05, # HiFP8 maximum denormal value (approx 2^-16) + 'min': -3.276800e+04, # Effective minimum (negative max normal) + 'max': 3.276800e+04, # Effective maximum (positive max normal) + 'supports_infinity': True, + 'supports_nan': True, + 'description': 'Huawei HiFP8 E4M3 format' + }, + 'mxfp8': { + 'min_normal': 1.562500e-02, # MX FP8 minimum normal value (0.015625) + 'max_normal': 4.480000e+02, # MX FP8 maximum normal value (448.0) + 'min_denormal': 1.953125e-03, # MX FP8 minimum denormal value (2^-9) + 'max_denormal': 1.367188e-02, # MX FP8 maximum denormal value (7*2^-9) + 'min': -4.480000e+02, # Effective minimum (negative max normal) + 'max': 4.480000e+02, # Effective maximum (positive max normal) + 'supports_infinity': False, + 'supports_nan': True, + 'description': 'Microsoft MX FP8 E4M3 format' + }, + 'mxfp4': { + 'min_normal': 1.000000e+00, # MX FP4 minimum normal value + 'max_normal': 6.000000e+00, # MX FP4 maximum normal value + 'min_denormal': 5.000000e-01, # MX FP4 minimum denormal value (2^-1) + 'max_denormal': 5.000000e-01, # MX FP4 maximum denormal value (only one denormal: 0.5) + 'min': -6.000000e+00, # Effective minimum (negative max normal) + 'max': 6.000000e+00, # Effective maximum (positive max normal) + 'supports_infinity': False, + 'supports_nan': False, + 'description': 'Microsoft MX FP4 E2M1 format' + } +} + +def detect_data_format(filename): + """ + Extract data format from filename. + + Args: + filename (str): Tensor file name + + Returns: + str: Data format (bf16, hifp8, mxfp8, mxfp4) or None if not found + """ + for fmt in DATA_TYPE_RANGES.keys(): + if fmt in filename: + return fmt + return None + +def analyze_tensor_overflow(tensor, data_format): + """ + Analyze tensor for overflow and underflow conditions. + + Args: + tensor (torch.Tensor): Input tensor + data_format (str): Data format identifier + + Returns: + dict: Analysis results containing overflow/underflow statistics + """ + if data_format not in DATA_TYPE_RANGES: + raise ValueError(f"Unknown data format: {data_format}") + + format_info = DATA_TYPE_RANGES[data_format] + max_val = format_info['max'] + min_denormal = format_info['min_denormal'] # Smallest representable non-zero value + + # Convert tensor to numpy for analysis + if tensor.is_cuda: + tensor_np = tensor.cpu().numpy() + else: + tensor_np = tensor.numpy() + + # Handle empty tensors + if tensor_np.size == 0: + return { + 'filename': 'empty_tensor', + 'data_format': data_format, + 'total_elements': 0, + 'overflow_count': 0, + 'underflow_count': 0, + 'overflow_percent': 0.0, + 'underflow_percent': 0.0, + 'has_overflow': False, + 'has_underflow': False, + 'has_issues': False, + 'tensor_min': 0.0, + 'tensor_max': 0.0, + 'tensor_mean': 0.0, + 'tensor_std': 0.0, + 'format_min_denormal': min_denormal, + 'format_max': max_val, + 'shape': list(tensor.shape) + } + + # Handle different tensor types + if tensor_np.dtype == np.complex64 or tensor_np.dtype == np.complex128: + # For complex tensors, analyze magnitude + tensor_np = np.abs(tensor_np) + + # Count overflow and underflow + total_elements = tensor_np.size + + # Overflow: values exceeding maximum representable value + overflow_count = np.sum(np.abs(tensor_np) > max_val) + + # Underflow: non-zero values closer to zero than smallest representable non-zero value + # This means |value| > 0 and |value| < min_denormal + non_zero_mask = tensor_np != 0.0 + abs_tensor = np.abs(tensor_np) + underflow_count = np.sum(non_zero_mask & (abs_tensor < min_denormal)) + + # Calculate percentages + overflow_percent = (overflow_count / total_elements) * 100 + underflow_percent = (underflow_count / total_elements) * 100 + + # Additional statistics + tensor_min = np.min(tensor_np) + tensor_max = np.max(tensor_np) + tensor_mean = np.mean(tensor_np) + tensor_std = np.std(tensor_np) + + return { + 'filename': os.path.basename(tensor.filename) if hasattr(tensor, 'filename') else 'unknown', + 'data_format': data_format, + 'total_elements': total_elements, + 'overflow_count': overflow_count, + 'underflow_count': underflow_count, + 'overflow_percent': overflow_percent, + 'underflow_percent': underflow_percent, + 'has_overflow': overflow_count > 0, + 'has_underflow': underflow_count > 0, + 'has_issues': (overflow_count > 0) or (underflow_count > 0), + 'tensor_min': tensor_min, + 'tensor_max': tensor_max, + 'tensor_mean': tensor_mean, + 'tensor_std': tensor_std, + 'format_min_denormal': min_denormal, + 'format_max': max_val, + 'shape': list(tensor.shape) + } + +def analyze_file(filepath): + """ + Analyze a single tensor file for overflow/underflow. + + Args: + filepath (str): Path to tensor file + + Returns: + dict or None: Analysis results or None if file cannot be processed + """ + try: + # Extract data format from filename + filename = os.path.basename(filepath) + data_format = detect_data_format(filename) + + if data_format is None: + print(f"Warning: Could not detect data format from filename: {filename}") + return None + + # Load tensor + tensor = torch.load(filepath, map_location='cpu', weights_only=False) + + # Handle case where loaded object is not a tensor + if not isinstance(tensor, torch.Tensor): + if isinstance(tensor, dict) and 'tensor' in tensor: + tensor = tensor['tensor'] + elif isinstance(tensor, (list, tuple)) and len(tensor) > 0: + tensor = tensor[0] + else: + print(f"Warning: Loaded object is not a tensor: {filename}") + return None + + # Convert BFloat16 and other unsupported types to Float32 for CPU processing + if tensor.dtype == torch.bfloat16: + tensor = tensor.float() + elif tensor.dtype in [torch.float16, torch.half]: + tensor = tensor.float() + elif tensor.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: + tensor = tensor.float() + elif tensor.dtype in [torch.uint8]: + tensor = tensor.float() + + # Analyze overflow + result = analyze_tensor_overflow(tensor, data_format) + result['filename'] = filename + result['filepath'] = filepath + + return result + + except Exception as e: + print(f"Error processing file {filepath}: {str(e)}") + return None + +def main(): + parser = argparse.ArgumentParser(description='Analyze tensor files for overflow/underflow conditions') + parser.add_argument('input_paths', nargs='+', help='Path(s) to tensor file(s) or directory') + parser.add_argument('--output', '-o', default='./draw/tensor_overflow/', help='Output file for results (default: ./draw/tensor_overflow/)') + parser.add_argument('--format', '-f', choices=['txt', 'csv', 'json'], default='txt', + help='Output format (default: txt)') + parser.add_argument('--recursive', '-r', action='store_true', + help='Recursively search directories for tensor files') + + args = parser.parse_args() + + results = [] + + # Process each input path + for input_path_str in args.input_paths: + input_path = Path(input_path_str) + + if input_path.is_file(): + # Single file analysis + print(f"Processing file: {input_path.name}") + result = analyze_file(str(input_path)) + if result: + results.append(result) + elif input_path.is_dir(): + # Directory analysis + print(f"Processing directory: {input_path.name}") + pattern = "**/*.pt" if args.recursive else "*.pt" + for filepath in input_path.glob(pattern): + result = analyze_file(str(filepath)) + if result: + results.append(result) + else: + print(f"Warning: Path does not exist: {input_path}") + continue + + if not results: + print("No valid tensor files found or processed.") + return 1 + + # Always output results to console first + print("Analysis Results:") + print("=" * 50) + write_text_report(None, results) + + # Then save to file + output_path = Path(args.output) + + # If output is a directory (like the default), generate filename based on input + if str(output_path).endswith('/') or output_path.is_dir() or (not output_path.suffix and not output_path.exists()): + output_path.mkdir(parents=True, exist_ok=True) + + # Generate filename based on input paths + if len(args.input_paths) == 1: + # Single input: same as input but with .log extension + input_path_obj = Path(args.input_paths[0]) + if input_path_obj.is_file(): + filename = input_path_obj.stem + '.log' + else: + filename = input_path_obj.name + '.log' + else: + # Multiple inputs: use timestamp-based filename + from datetime import datetime + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"multi_tensor_analysis_{timestamp}.log" + + output_path = output_path / filename + else: + # If specific filename provided, create parent directories + output_path.parent.mkdir(parents=True, exist_ok=True) + + if args.format == 'json': + import json + with open(output_path, 'w') as f: + json.dump(results, f, indent=2, default=str) + elif args.format == 'csv': + import csv + with open(output_path, 'w', newline='') as f: + if results: + writer = csv.DictWriter(f, fieldnames=results[0].keys()) + writer.writeheader() + writer.writerows(results) + else: # txt format + with open(output_path, 'w') as f: + write_text_report(f, results) + + print(f"\nResults saved to: {output_path}") + + return 0 + +def write_text_report(file_handle, results): + """Write analysis results in text format.""" + def write_line(line=""): + if file_handle: + file_handle.write(line + "\n") + else: + print(line) + + write_line("=" * 80) + write_line("TENSOR VALUE OVERFLOW/UNDERFLOW ANALYSIS REPORT") + write_line("=" * 80) + write_line("This report shows the PERCENTAGE of tensor values that overflow/underflow") + write_line("Overflow: Values with |value| > max_representable") + write_line("Underflow: Non-zero values with 0 < |value| < min_denormal") + write_line("=" * 80) + write_line() + + # Group results by data format + format_groups = {} + for result in results: + fmt = result['data_format'] + if fmt not in format_groups: + format_groups[fmt] = [] + format_groups[fmt].append(result) + + # Summary statistics + write_line("SUMMARY STATISTICS") + write_line("-" * 40) + for fmt, fmt_results in format_groups.items(): + total_files = len(fmt_results) + total_elements = sum(r['total_elements'] for r in fmt_results) + total_overflow = sum(r['overflow_count'] for r in fmt_results) + total_underflow = sum(r['underflow_count'] for r in fmt_results) + + overflow_percent = (total_overflow / total_elements) * 100 if total_elements > 0 else 0 + underflow_percent = (total_underflow / total_elements) * 100 if total_elements > 0 else 0 + + format_info = DATA_TYPE_RANGES[fmt] + write_line(f"{fmt.upper()} ({format_info['description']}):") + write_line(f" Files analyzed: {total_files}") + write_line(f" Total elements: {total_elements:,}") + write_line(f" Overflow: {total_overflow:,} ({overflow_percent:.4f}%)") + write_line(f" Underflow: {total_underflow:,} ({underflow_percent:.4f}%)") + write_line(f" Max Normal: ±{format_info['max_normal']:.2e}") + write_line(f" Min Denormal: {format_info['min_denormal']:.2e}") + write_line(f" Supports Inf/NaN: {format_info['supports_infinity']}/{format_info['supports_nan']}") + write_line() + + # Detailed results + write_line("DETAILED RESULTS") + write_line("-" * 40) + for fmt, fmt_results in format_groups.items(): + write_line(f"\n{fmt.upper()} FILES:") + write_line("-" * 20) + + for result in fmt_results: + write_line(f"File: {result['filename']}") + write_line(f" Shape: {result['shape']}") + write_line(f" Elements: {result['total_elements']:,}") + write_line(f" Value range: [{result['tensor_min']:.6f}, {result['tensor_max']:.6f}]") + write_line(f" Mean ± Std: {result['tensor_mean']:.6f} ± {result['tensor_std']:.6f}") + write_line(f" Overflow: {result['overflow_count']:,} ({result['overflow_percent']:.4f}%)") + write_line(f" Underflow: {result['underflow_count']:,} ({result['underflow_percent']:.4f}%)") + + if result['overflow_count'] > 0 or result['underflow_count'] > 0: + write_line(" ⚠️ OVERFLOW/UNDERFLOW DETECTED!") + + write_line() + +if __name__ == "__main__": + exit(main()) diff --git a/visualization/overflow/README_scaling_analysis.md b/visualization/overflow/README_scaling_analysis.md new file mode 100644 index 0000000000..15ae4cd6b3 --- /dev/null +++ b/visualization/overflow/README_scaling_analysis.md @@ -0,0 +1,84 @@ +# Scaling Factor Analysis Tool + +这个Python程序用于分析所有存储的log文件,检查是否所有tensor的最推荐Scaling都是其最大值。 + +## 功能特点 + +- 🔍 自动扫描所有scaling_analysis目录下的log文件 +- 📊 解析每个tensor的对齐范围(max_align, min_align)和推荐scaling factor +- ✅ 检查推荐的scale exponent是否等于max_align(最大值) +- 📈 生成详细的统计报告,包括按层和按tensor类型的分析 +- 💾 将结果保存为JSON格式便于进一步分析 + +## 使用方法 + +### 方法1:使用简单脚本 +```bash +python3 run_analysis.py +``` + +### 方法2:使用完整程序 +```bash +python3 analyze_scaling_factors.py +``` + +### 方法3:指定目录 +```bash +python3 run_analysis.py /path/to/your/data/directory +``` + +## 输出说明 + +程序会输出以下信息: + +1. **总体摘要**: 显示分析的tensor总数和在最大值的tensor数量 +2. **按类型分类**: 显示不同tensor类型(input, weight, output等)的统计 +3. **按层分类**: 显示不同层的统计信息 +4. **性能统计**: 显示平均composite score和MSE +5. **详细列表**: 如果有tensor不在最大值,会列出详细信息 + +## 输出文件 + +- `scaling_analysis_results.json`: 包含完整分析结果的JSON文件 + +## 分析结果 + +根据当前数据的分析结果: + +🎉 **所有140个tensor的推荐scaling factor都是其最大值!** + +这意味着: +- 所有tensor都在使用最优的scaling factor +- 没有overflow风险 +- 精度损失最小 + +## 文件结构 + +``` +draw/ +├── analyze_scaling_factors.py # 主分析程序 +├── run_analysis.py # 简单包装脚本 +├── scaling_analysis_results.json # 分析结果 +└── scaling_analysis/ # 包含所有log文件的目录 + ├── 20250915_040631_0001_.../ + ├── 20250915_040632_0002_.../ + └── ... +``` + +## 技术细节 + +程序通过正则表达式解析log文件中的关键信息: +- `Calculated alignment: max_align=X, min_align=Y` +- `⭐ RECOMMENDED Scaling Factor: X` +- `Scale Exponent: Y` + +然后检查推荐的scale exponent是否等于max_align值(容差1e-6)。 + +## 依赖 + +- Python 3.6+ +- 标准库模块:os, re, glob, json, pathlib, typing, dataclasses, collections + + + + diff --git a/visualization/overflow/enhanced_overflow_analyzer.py b/visualization/overflow/enhanced_overflow_analyzer.py new file mode 100644 index 0000000000..d6bf4f7ffb --- /dev/null +++ b/visualization/overflow/enhanced_overflow_analyzer.py @@ -0,0 +1,634 @@ +#!/usr/bin/env python3 +""" +Enhanced Overflow Analysis Tool + +This program analyzes overflow/underflow data from the enhanced tensor structure +with forward/backward passes and detailed tensor type classification. + +Author: AI Assistant +Created: 2025-09-23 +""" + +import os +import re +import glob +import json +import matplotlib.pyplot as plt +import seaborn as sns +import numpy as np +from pathlib import Path +from typing import Dict, List, Tuple, Optional +from dataclasses import dataclass +from collections import defaultdict + + +@dataclass +class EnhancedQuantizationMetrics: + """Enhanced data class for quantization metrics""" + tensor_name: str + layer: int + pass_type: str # forward, backward + operation_type: str # linear, attention + tensor_type: str # input_A, input_B, output, weight, query, key, value, etc. + rank: int + group: int + total_elements: int + underflow_percentage: float + flush_to_zero_percentage: float + overflow_percentage: float + underflow_significant: bool + overflow_significant: bool + + +class EnhancedOverflowAnalyzer: + """Enhanced analyzer for overflow/underflow analysis""" + + def __init__(self, base_directory: str = "/Users/charles/Downloads/draw"): + self.base_directory = Path(base_directory) + self.scaling_dir = self.base_directory / "scaling_analysis" + self.results: List[EnhancedQuantizationMetrics] = [] + self.fp8_max_norm = 448.0 + + # Enhanced color scheme for different categories + self.colors = { + 'forward': '#2E86AB', # Blue + 'backward': '#A23B72', # Purple + 'linear': '#F18F01', # Orange + 'attention': '#C73E1D', # Red + 'input_A': '#2E8B57', # Sea Green + 'input_B': '#4682B4', # Steel Blue + 'output': '#DC143C', # Crimson + 'weight': '#8B4513', # Saddle Brown + 'query': '#9370DB', # Medium Purple + 'key': '#20B2AA', # Light Sea Green + 'value': '#FF6347', # Tomato + 'probs': '#FFD700', # Gold + 'buffer': '#32CD32', # Lime Green + 'grad_output': '#FF1493', # Deep Pink + 'grad_attention_probs': '#00CED1', # Dark Turquoise + 'grad_value': '#FF8C00', # Dark Orange + 'grad_query': '#8A2BE2', # Blue Violet + 'grad_key': '#00FF7F', # Spring Green + 'mm_output': '#FF69B4', # Hot Pink + } + + def find_result_files(self) -> List[Path]: + """Find all result files in the scaling analysis directory""" + pattern = str(self.scaling_dir / "**" / "*_results_fp8_e4m3.txt") + result_files = glob.glob(pattern, recursive=True) + return [Path(f) for f in result_files] + + def extract_enhanced_metadata_from_path(self, result_file: Path) -> Optional[Dict]: + """Extract enhanced metadata from file path""" + try: + # Extract from the directory name + parent_dir = result_file.parent.name + + # Parse the enhanced naming format: + # 20250923_100142_0001_iter000_linear_L1_forward_pre_linear_bf16_rank00_group000_input_A + + # Extract layer + layer_match = re.search(r'_L(\d+)_', parent_dir) + if not layer_match: + return None + layer = int(layer_match.group(1)) + + # Extract pass type (forward/backward) + if '_forward_' in parent_dir: + pass_type = 'forward' + elif '_backward_' in parent_dir: + pass_type = 'backward' + else: + pass_type = 'unknown' + + # Extract operation type + if '_linear_' in parent_dir: + operation_type = 'linear' + elif '_attention_' in parent_dir: + operation_type = 'attention' + else: + operation_type = 'unknown' + + # Extract rank + rank_match = re.search(r'_rank(\d+)_', parent_dir) + rank = int(rank_match.group(1)) if rank_match else 0 + + # Extract group + group_match = re.search(r'_group(\d+)_', parent_dir) + group = int(group_match.group(1)) if group_match else 0 + + # Extract tensor type (the last part after the last underscore) + parts = parent_dir.split('_') + tensor_type = parts[-1] # input_A, input_B, output, weight, query, key, value, etc. + + # Create a more readable tensor name + tensor_name = f"L{layer}_{pass_type}_{operation_type}_{tensor_type}" + + return { + 'tensor_name': tensor_name, + 'layer': layer, + 'pass_type': pass_type, + 'operation_type': operation_type, + 'tensor_type': tensor_type, + 'rank': rank, + 'group': group + } + + except Exception as e: + print(f"Error extracting metadata from {result_file}: {e}") + return None + + def parse_recommended_metrics(self, result_file: Path) -> Optional[EnhancedQuantizationMetrics]: + """Parse metrics at recommended scaling factor""" + try: + with open(result_file, 'r', encoding='utf-8') as f: + content = f.read() + + metadata = self.extract_enhanced_metadata_from_path(result_file) + if not metadata: + return None + + # First, get the recommended scale exponent from the corresponding log file + log_file = result_file.parent / f"mxfp_scaling_test_{result_file.parent.name}_fp8_e4m3.log" + recommended_scale_exp = None + + if log_file.exists(): + try: + with open(log_file, 'r', encoding='utf-8') as f: + log_content = f.read() + scale_match = re.search(r'⭐ RECOMMENDED Scaling Factor: [\d\.e\-\+]+.*?Scale Exponent: (-?\d+\.?\d*)', log_content, re.DOTALL) + if scale_match: + recommended_scale_exp = float(scale_match.group(1)) + except: + pass + + if recommended_scale_exp is None: + return None + + # Now find the specific section for the recommended scale exponent + scale_exp_str = f"{recommended_scale_exp:.2f}" + recommended_section_pattern = rf'Scale Exponent {re.escape(scale_exp_str)} \(Factor: [\d\.e\-\+]+\):\s*\n.*?Overflow/Underflow Analysis:\s*\n\s*Total Elements: ([\d,]+)\s*\n\s*Underflow Count: [\d,]+ \(([\d\.]+)%\)\s*\n\s*Flush to Zero Count: [\d,]+ \(([\d\.]+)%\)\s*\n\s*Overflow Count: [\d,]+ \(([\d\.]+)%\)' + + analysis_match = re.search(recommended_section_pattern, content, re.DOTALL) + + if not analysis_match: + return None + + total_elements = int(analysis_match.group(1).replace(',', '')) + underflow_percentage = float(analysis_match.group(2)) + flush_to_zero_percentage = float(analysis_match.group(3)) + overflow_percentage = float(analysis_match.group(4)) + + # Extract significance flags from the same section + section_content = analysis_match.group(0) + underflow_significant = 'Has Significant Underflow: Yes' in section_content + overflow_significant = 'Has Significant Overflow: Yes' in section_content + + return EnhancedQuantizationMetrics( + tensor_name=metadata['tensor_name'], + layer=metadata['layer'], + pass_type=metadata['pass_type'], + operation_type=metadata['operation_type'], + tensor_type=metadata['tensor_type'], + rank=metadata['rank'], + group=metadata['group'], + total_elements=total_elements, + underflow_percentage=underflow_percentage, + flush_to_zero_percentage=flush_to_zero_percentage, + overflow_percentage=overflow_percentage, + underflow_significant=underflow_significant, + overflow_significant=overflow_significant + ) + except Exception as e: + print(f"Error parsing {result_file}: {e}") + return None + + def analyze_all_files(self) -> None: + """Analyze all result files and store metrics""" + result_files = self.find_result_files() + print(f"Found {len(result_files)} result files to analyze...") + + successful_parses = 0 + for result_file in result_files: + metrics = self.parse_recommended_metrics(result_file) + if metrics: + self.results.append(metrics) + successful_parses += 1 + + print(f"Successfully parsed {successful_parses} out of {len(result_files)} result files") + + def create_enhanced_overflow_analysis_plot(self) -> None: + """Create enhanced overflow analysis plot by layer and pass type""" + if not self.results: + print("No data to plot!") + return + + # Group data by layer and pass type + layer_pass_data = defaultdict(lambda: {'overflow': [], 'underflow': [], 'flush_to_zero': []}) + + for result in self.results: + key = f"L{result.layer}_{result.pass_type}" + layer_pass_data[key]['overflow'].append(result.overflow_percentage) + layer_pass_data[key]['underflow'].append(result.underflow_percentage) + layer_pass_data[key]['flush_to_zero'].append(result.flush_to_zero_percentage) + + # Create the plot + fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(14, 16)) + + # Prepare data for plotting + layers = sorted(set(result.layer for result in self.results)) + pass_types = ['forward', 'backward'] + + x_pos = np.arange(len(layers)) + width = 0.35 + + # Overflow plot + forward_overflow = [] + backward_overflow = [] + for layer in layers: + forward_key = f"L{layer}_forward" + backward_key = f"L{layer}_backward" + + forward_avg = np.mean(layer_pass_data[forward_key]['overflow']) if forward_key in layer_pass_data else 0 + backward_avg = np.mean(layer_pass_data[backward_key]['overflow']) if backward_key in layer_pass_data else 0 + + forward_overflow.append(forward_avg) + backward_overflow.append(backward_avg) + + bars1 = ax1.bar(x_pos - width/2, forward_overflow, width, label='Forward Pass', + color=self.colors['forward'], alpha=0.8) + bars2 = ax1.bar(x_pos + width/2, backward_overflow, width, label='Backward Pass', + color=self.colors['backward'], alpha=0.8) + + ax1.set_title('Overflow Analysis by Layer and Pass Type', fontsize=16, fontweight='bold') + ax1.set_xlabel('Layer', fontsize=12) + ax1.set_ylabel('Overflow Percentage (%)', fontsize=12) + ax1.set_xticks(x_pos) + ax1.set_xticklabels([f'Layer {l}' for l in layers]) + ax1.legend() + ax1.grid(True, alpha=0.3) + + # Underflow plot + forward_underflow = [] + backward_underflow = [] + for layer in layers: + forward_key = f"L{layer}_forward" + backward_key = f"L{layer}_backward" + + forward_avg = np.mean(layer_pass_data[forward_key]['underflow']) if forward_key in layer_pass_data else 0 + backward_avg = np.mean(layer_pass_data[backward_key]['underflow']) if backward_key in layer_pass_data else 0 + + forward_underflow.append(forward_avg) + backward_underflow.append(backward_avg) + + bars3 = ax2.bar(x_pos - width/2, forward_underflow, width, label='Forward Pass', + color=self.colors['forward'], alpha=0.8) + bars4 = ax2.bar(x_pos + width/2, backward_underflow, width, label='Backward Pass', + color=self.colors['backward'], alpha=0.8) + + ax2.set_title('Underflow Analysis by Layer and Pass Type', fontsize=16, fontweight='bold') + ax2.set_xlabel('Layer', fontsize=12) + ax2.set_ylabel('Underflow Percentage (%)', fontsize=12) + ax2.set_xticks(x_pos) + ax2.set_xticklabels([f'Layer {l}' for l in layers]) + ax2.legend() + ax2.grid(True, alpha=0.3) + + # Flush-to-Zero plot + forward_flush = [] + backward_flush = [] + for layer in layers: + forward_key = f"L{layer}_forward" + backward_key = f"L{layer}_backward" + + forward_avg = np.mean(layer_pass_data[forward_key]['flush_to_zero']) if forward_key in layer_pass_data else 0 + backward_avg = np.mean(layer_pass_data[backward_key]['flush_to_zero']) if backward_key in layer_pass_data else 0 + + forward_flush.append(forward_avg) + backward_flush.append(backward_avg) + + bars5 = ax3.bar(x_pos - width/2, forward_flush, width, label='Forward Pass', + color=self.colors['forward'], alpha=0.8) + bars6 = ax3.bar(x_pos + width/2, backward_flush, width, label='Backward Pass', + color=self.colors['backward'], alpha=0.8) + + ax3.set_title('Flush-to-Zero Analysis by Layer and Pass Type', fontsize=16, fontweight='bold') + ax3.set_xlabel('Layer', fontsize=12) + ax3.set_ylabel('Flush-to-Zero Percentage (%)', fontsize=12) + ax3.set_xticks(x_pos) + ax3.set_xticklabels([f'Layer {l}' for l in layers]) + ax3.legend() + ax3.grid(True, alpha=0.3) + + plt.tight_layout() + + # Save the plot + output_dir = self.base_directory / "enhanced_overflow_plots" + output_dir.mkdir(exist_ok=True) + output_file = output_dir / "enhanced_layer_pass_analysis.png" + plt.savefig(output_file, dpi=300, bbox_inches='tight') + print(f"Enhanced layer-pass analysis plot saved to: {output_file}") + + plt.show() + + def create_tensor_type_analysis_plot(self) -> None: + """Create tensor type analysis plot""" + if not self.results: + print("No data to plot!") + return + + # Group data by tensor type + tensor_type_data = defaultdict(lambda: {'overflow': [], 'underflow': [], 'flush_to_zero': []}) + + for result in self.results: + tensor_type_data[result.tensor_type]['overflow'].append(result.overflow_percentage) + tensor_type_data[result.tensor_type]['underflow'].append(result.underflow_percentage) + tensor_type_data[result.tensor_type]['flush_to_zero'].append(result.flush_to_zero_percentage) + + # Create the plot + fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(16, 18)) + + # Prepare data for plotting + tensor_types = sorted(tensor_type_data.keys()) + x_pos = np.arange(len(tensor_types)) + + # Overflow plot + overflow_means = [np.mean(tensor_type_data[t]['overflow']) for t in tensor_types] + overflow_stds = [np.std(tensor_type_data[t]['overflow']) for t in tensor_types] + + bars1 = ax1.bar(x_pos, overflow_means, yerr=overflow_stds, capsize=5, + color=[self.colors.get(t, '#808080') for t in tensor_types], alpha=0.8) + + ax1.set_title('Overflow Analysis by Tensor Type', fontsize=16, fontweight='bold') + ax1.set_xlabel('Tensor Type', fontsize=12) + ax1.set_ylabel('Overflow Percentage (%)', fontsize=12) + ax1.set_xticks(x_pos) + ax1.set_xticklabels(tensor_types, rotation=45, ha='right') + ax1.grid(True, alpha=0.3) + + # Underflow plot + underflow_means = [np.mean(tensor_type_data[t]['underflow']) for t in tensor_types] + underflow_stds = [np.std(tensor_type_data[t]['underflow']) for t in tensor_types] + + bars2 = ax2.bar(x_pos, underflow_means, yerr=underflow_stds, capsize=5, + color=[self.colors.get(t, '#808080') for t in tensor_types], alpha=0.8) + + ax2.set_title('Underflow Analysis by Tensor Type', fontsize=16, fontweight='bold') + ax2.set_xlabel('Tensor Type', fontsize=12) + ax2.set_ylabel('Underflow Percentage (%)', fontsize=12) + ax2.set_xticks(x_pos) + ax2.set_xticklabels(tensor_types, rotation=45, ha='right') + ax2.grid(True, alpha=0.3) + + # Flush-to-Zero plot + flush_means = [np.mean(tensor_type_data[t]['flush_to_zero']) for t in tensor_types] + flush_stds = [np.std(tensor_type_data[t]['flush_to_zero']) for t in tensor_types] + + bars3 = ax3.bar(x_pos, flush_means, yerr=flush_stds, capsize=5, + color=[self.colors.get(t, '#808080') for t in tensor_types], alpha=0.8) + + ax3.set_title('Flush-to-Zero Analysis by Tensor Type', fontsize=16, fontweight='bold') + ax3.set_xlabel('Tensor Type', fontsize=12) + ax3.set_ylabel('Flush-to-Zero Percentage (%)', fontsize=12) + ax3.set_xticks(x_pos) + ax3.set_xticklabels(tensor_types, rotation=45, ha='right') + ax3.grid(True, alpha=0.3) + + plt.tight_layout() + + # Save the plot + output_dir = self.base_directory / "enhanced_overflow_plots" + output_dir.mkdir(exist_ok=True) + output_file = output_dir / "enhanced_tensor_type_analysis.png" + plt.savefig(output_file, dpi=300, bbox_inches='tight') + print(f"Enhanced tensor type analysis plot saved to: {output_file}") + + plt.show() + + def create_operation_type_analysis_plot(self) -> None: + """Create operation type analysis plot""" + if not self.results: + print("No data to plot!") + return + + # Group data by operation type and pass type + operation_data = defaultdict(lambda: {'overflow': [], 'underflow': [], 'flush_to_zero': []}) + + for result in self.results: + key = f"{result.operation_type}_{result.pass_type}" + operation_data[key]['overflow'].append(result.overflow_percentage) + operation_data[key]['underflow'].append(result.underflow_percentage) + operation_data[key]['flush_to_zero'].append(result.flush_to_zero_percentage) + + # Create the plot + fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(12, 16)) + + # Prepare data for plotting + operation_types = sorted(operation_data.keys()) + x_pos = np.arange(len(operation_types)) + + # Overflow plot + overflow_means = [np.mean(operation_data[t]['overflow']) for t in operation_types] + overflow_stds = [np.std(operation_data[t]['overflow']) for t in operation_types] + + colors = [] + for op_type in operation_types: + if 'linear' in op_type: + colors.append(self.colors['linear']) + elif 'attention' in op_type: + colors.append(self.colors['attention']) + else: + colors.append('#808080') + + bars1 = ax1.bar(x_pos, overflow_means, yerr=overflow_stds, capsize=5, + color=colors, alpha=0.8) + + ax1.set_title('Overflow Analysis by Operation Type and Pass', fontsize=16, fontweight='bold') + ax1.set_xlabel('Operation Type', fontsize=12) + ax1.set_ylabel('Overflow Percentage (%)', fontsize=12) + ax1.set_xticks(x_pos) + ax1.set_xticklabels(operation_types, rotation=45, ha='right') + ax1.grid(True, alpha=0.3) + + # Underflow plot + underflow_means = [np.mean(operation_data[t]['underflow']) for t in operation_types] + underflow_stds = [np.std(operation_data[t]['underflow']) for t in operation_types] + + bars2 = ax2.bar(x_pos, underflow_means, yerr=underflow_stds, capsize=5, + color=colors, alpha=0.8) + + ax2.set_title('Underflow Analysis by Operation Type and Pass', fontsize=16, fontweight='bold') + ax2.set_xlabel('Operation Type', fontsize=12) + ax2.set_ylabel('Underflow Percentage (%)', fontsize=12) + ax2.set_xticks(x_pos) + ax2.set_xticklabels(operation_types, rotation=45, ha='right') + ax2.grid(True, alpha=0.3) + + # Flush-to-Zero plot + flush_means = [np.mean(operation_data[t]['flush_to_zero']) for t in operation_types] + flush_stds = [np.std(operation_data[t]['flush_to_zero']) for t in operation_types] + + bars3 = ax3.bar(x_pos, flush_means, yerr=flush_stds, capsize=5, + color=colors, alpha=0.8) + + ax3.set_title('Flush-to-Zero Analysis by Operation Type and Pass', fontsize=16, fontweight='bold') + ax3.set_xlabel('Operation Type', fontsize=12) + ax3.set_ylabel('Flush-to-Zero Percentage (%)', fontsize=12) + ax3.set_xticks(x_pos) + ax3.set_xticklabels(operation_types, rotation=45, ha='right') + ax3.grid(True, alpha=0.3) + + plt.tight_layout() + + # Save the plot + output_dir = self.base_directory / "enhanced_overflow_plots" + output_dir.mkdir(exist_ok=True) + output_file = output_dir / "enhanced_operation_type_analysis.png" + plt.savefig(output_file, dpi=300, bbox_inches='tight') + print(f"Enhanced operation type analysis plot saved to: {output_file}") + + plt.show() + + def generate_enhanced_summary(self) -> Dict: + """Generate enhanced summary statistics""" + if not self.results: + return {} + + # Group by various dimensions + layer_stats = defaultdict(lambda: {'total': 0, 'overflow': [], 'underflow': [], 'flush_to_zero': []}) + pass_type_stats = defaultdict(lambda: {'total': 0, 'overflow': [], 'underflow': [], 'flush_to_zero': []}) + operation_type_stats = defaultdict(lambda: {'total': 0, 'overflow': [], 'underflow': [], 'flush_to_zero': []}) + tensor_type_stats = defaultdict(lambda: {'total': 0, 'overflow': [], 'underflow': [], 'flush_to_zero': []}) + + for result in self.results: + # Layer statistics + layer_key = f"Layer_{result.layer}" + layer_stats[layer_key]['total'] += 1 + layer_stats[layer_key]['overflow'].append(result.overflow_percentage) + layer_stats[layer_key]['underflow'].append(result.underflow_percentage) + layer_stats[layer_key]['flush_to_zero'].append(result.flush_to_zero_percentage) + + # Pass type statistics + pass_type_stats[result.pass_type]['total'] += 1 + pass_type_stats[result.pass_type]['overflow'].append(result.overflow_percentage) + pass_type_stats[result.pass_type]['underflow'].append(result.underflow_percentage) + pass_type_stats[result.pass_type]['flush_to_zero'].append(result.flush_to_zero_percentage) + + # Operation type statistics + operation_type_stats[result.operation_type]['total'] += 1 + operation_type_stats[result.operation_type]['overflow'].append(result.overflow_percentage) + operation_type_stats[result.operation_type]['underflow'].append(result.underflow_percentage) + operation_type_stats[result.operation_type]['flush_to_zero'].append(result.flush_to_zero_percentage) + + # Tensor type statistics + tensor_type_stats[result.tensor_type]['total'] += 1 + tensor_type_stats[result.tensor_type]['overflow'].append(result.overflow_percentage) + tensor_type_stats[result.tensor_type]['underflow'].append(result.underflow_percentage) + tensor_type_stats[result.tensor_type]['flush_to_zero'].append(result.flush_to_zero_percentage) + + # Calculate averages + def calc_averages(stats_dict): + result = {} + for key, data in stats_dict.items(): + result[key] = { + 'total': data['total'], + 'avg_overflow': np.mean(data['overflow']) if data['overflow'] else 0, + 'avg_underflow': np.mean(data['underflow']) if data['underflow'] else 0, + 'avg_flush_to_zero': np.mean(data['flush_to_zero']) if data['flush_to_zero'] else 0, + 'max_overflow': np.max(data['overflow']) if data['overflow'] else 0, + 'max_underflow': np.max(data['underflow']) if data['underflow'] else 0, + 'max_flush_to_zero': np.max(data['flush_to_zero']) if data['flush_to_zero'] else 0 + } + return result + + return { + 'total_tensors': len(self.results), + 'layer_stats': calc_averages(layer_stats), + 'pass_type_stats': calc_averages(pass_type_stats), + 'operation_type_stats': calc_averages(operation_type_stats), + 'tensor_type_stats': calc_averages(tensor_type_stats) + } + + def print_enhanced_summary(self) -> None: + """Print enhanced summary report""" + if not self.results: + print("No results to report!") + return + + summary = self.generate_enhanced_summary() + + print("\n" + "="*100) + print("ENHANCED OVERFLOW/UNDERFLOW ANALYSIS SUMMARY") + print("="*100) + + print(f"\n📊 OVERALL SUMMARY:") + print(f" Total Tensors Analyzed: {summary['total_tensors']}") + + # Pass type breakdown + print(f"\n🔄 BREAKDOWN BY PASS TYPE:") + for pass_type, stats in summary['pass_type_stats'].items(): + print(f" {pass_type.upper():10} | Total: {stats['total']:3d} | " + f"Avg Overflow: {stats['avg_overflow']:6.3f}% | " + f"Avg Underflow: {stats['avg_underflow']:6.3f}% | " + f"Avg Flush-to-Zero: {stats['avg_flush_to_zero']:6.3f}%") + + # Operation type breakdown + print(f"\n⚙️ BREAKDOWN BY OPERATION TYPE:") + for op_type, stats in summary['operation_type_stats'].items(): + print(f" {op_type.upper():10} | Total: {stats['total']:3d} | " + f"Avg Overflow: {stats['avg_overflow']:6.3f}% | " + f"Avg Underflow: {stats['avg_underflow']:6.3f}% | " + f"Avg Flush-to-Zero: {stats['avg_flush_to_zero']:6.3f}%") + + # Layer breakdown + print(f"\n🏗️ BREAKDOWN BY LAYER:") + for layer, stats in sorted(summary['layer_stats'].items()): + print(f" {layer:10} | Total: {stats['total']:3d} | " + f"Avg Overflow: {stats['avg_overflow']:6.3f}% | " + f"Avg Underflow: {stats['avg_underflow']:6.3f}% | " + f"Avg Flush-to-Zero: {stats['avg_flush_to_zero']:6.3f}%") + + # Tensor type breakdown + print(f"\n📋 BREAKDOWN BY TENSOR TYPE:") + for tensor_type, stats in summary['tensor_type_stats'].items(): + print(f" {tensor_type.upper():15} | Total: {stats['total']:3d} | " + f"Avg Overflow: {stats['avg_overflow']:6.3f}% | " + f"Avg Underflow: {stats['avg_underflow']:6.3f}% | " + f"Avg Flush-to-Zero: {stats['avg_flush_to_zero']:6.3f}%") + + print("\n" + "="*100) + + +def main(): + """Main function""" + print("🔍 Starting Enhanced Overflow Analysis...") + + # Initialize analyzer + analyzer = EnhancedOverflowAnalyzer() + + # Check if scaling directory exists + if not analyzer.scaling_dir.exists(): + print(f"❌ Error: Scaling analysis directory not found: {analyzer.scaling_dir}") + return + + # Analyze all files + analyzer.analyze_all_files() + + if not analyzer.results: + print("❌ No valid results found!") + return + + # Print enhanced summary + analyzer.print_enhanced_summary() + + # Create enhanced plots + print("\n📊 Creating enhanced analysis plots...") + analyzer.create_enhanced_overflow_analysis_plot() + analyzer.create_tensor_type_analysis_plot() + analyzer.create_operation_type_analysis_plot() + + print("\n✅ Enhanced overflow analysis completed successfully!") + + +if __name__ == "__main__": + main() diff --git a/visualization/overflow/enhanced_scaling_analysis_results.json b/visualization/overflow/enhanced_scaling_analysis_results.json new file mode 100644 index 0000000000..372f9fc128 --- /dev/null +++ b/visualization/overflow/enhanced_scaling_analysis_results.json @@ -0,0 +1,3020 @@ +{ + "analysis_summary": { + "total_tensors": 136, + "at_max_count": 19, + "not_at_max_count": 117, + "at_max_percentage": 13.970588235294118, + "layer_stats": { + "Layer_8": { + "total": 34, + "at_max": 7 + }, + "Layer_1": { + "total": 34, + "at_max": 4 + }, + "Layer_15": { + "total": 34, + "at_max": 2 + }, + "Layer_16": { + "total": 34, + "at_max": 6 + } + }, + "pass_type_stats": { + "forward": { + "total": 80, + "at_max": 9 + }, + "backward": { + "total": 56, + "at_max": 10 + } + }, + "operation_type_stats": { + "linear": { + "total": 80, + "at_max": 13 + }, + "attention": { + "total": 56, + "at_max": 6 + } + }, + "tensor_type_stats": { + "output": { + "total": 36, + "at_max": 9 + }, + "B": { + "total": 16, + "at_max": 1 + }, + "weight": { + "total": 16, + "at_max": 3 + }, + "query": { + "total": 8, + "at_max": 0 + }, + "input": { + "total": 16, + "at_max": 4 + }, + "probs": { + "total": 8, + "at_max": 0 + }, + "A": { + "total": 16, + "at_max": 1 + }, + "value": { + "total": 8, + "at_max": 0 + }, + "buffer": { + "total": 4, + "at_max": 1 + }, + "key": { + "total": 8, + "at_max": 0 + } + }, + "layer_pass_stats": { + "L8_forward": { + "total": 20, + "at_max": 3 + }, + "L1_forward": { + "total": 20, + "at_max": 2 + }, + "L15_forward": { + "total": 20, + "at_max": 1 + }, + "L15_backward": { + "total": 14, + "at_max": 1 + }, + "L8_backward": { + "total": 14, + "at_max": 4 + }, + "L16_forward": { + "total": 20, + "at_max": 3 + }, + "L1_backward": { + "total": 14, + "at_max": 2 + }, + "L16_backward": { + "total": 14, + "at_max": 3 + } + }, + "layer_operation_stats": { + "L8_linear": { + "total": 20, + "at_max": 5 + }, + "L1_linear": { + "total": 20, + "at_max": 3 + }, + "L15_attention": { + "total": 14, + "at_max": 1 + }, + "L15_linear": { + "total": 20, + "at_max": 1 + }, + "L16_linear": { + "total": 20, + "at_max": 4 + }, + "L16_attention": { + "total": 14, + "at_max": 2 + }, + "L1_attention": { + "total": 14, + "at_max": 1 + }, + "L8_attention": { + "total": 14, + "at_max": 2 + } + } + }, + "detailed_results": [ + { + "tensor_name": "L8_forward_linear_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100238_0157_iter000_linear_L8_forward_pre_linear_bf16_rank00_group000_output/mxfp_scaling_test_20250923_100238_0157_iter000_linear_L8_forward_pre_linear_bf16_rank00_group000_output_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -21.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0008, + "is_at_max": false, + "mse": 0.0002629248, + "cosine_similarity": 1.002706, + "psnr": 46.33, + "mae": 0.01092842, + "relative_error": 2.25, + "layer": 8, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_forward_linear_B", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100149_0019_iter000_linear_L1_forward_pre_linear_bf16_rank00_group000_input_B/mxfp_scaling_test_20250923_100149_0019_iter000_linear_L1_forward_pre_linear_bf16_rank00_group000_input_B_fp8_e4m3.log", + "max_align": -14.0, + "min_align": -24.0, + "recommended_scale_exp": -15.0, + "recommended_scale_factor": 3.1e-05, + "composite_score": 1.0, + "is_at_max": false, + "mse": 3.971756e-09, + "cosine_similarity": 1.000053, + "psnr": 45.48, + "mae": 4.254774e-05, + "relative_error": 2.25, + "layer": 1, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "B", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_forward_attention_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100326_0291_iter000_attention_L15_forward_post_FA_bf16_rank00_group000_output/mxfp_scaling_test_20250923_100326_0291_iter000_attention_L15_forward_post_FA_bf16_rank00_group000_output_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -19.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0002, + "is_at_max": false, + "mse": 0.0001800647, + "cosine_similarity": 1.000674, + "psnr": 43.99, + "mae": 0.008812865, + "relative_error": 2.22, + "layer": 15, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_forward_linear_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100149_0017_iter000_linear_L1_forward_pre_linear_bf16_rank00_group000_output/mxfp_scaling_test_20250923_100149_0017_iter000_linear_L1_forward_pre_linear_bf16_rank00_group000_output_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -19.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0008, + "is_at_max": false, + "mse": 0.0002506918, + "cosine_similarity": 1.002639, + "psnr": 46.69, + "mae": 0.01068572, + "relative_error": 2.25, + "layer": 1, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_backward_linear_weight", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100340_0345_iter000_linear_L15_backward_pre_linear_bf16_rank00_group000_weight/mxfp_scaling_test_20250923_100340_0345_iter000_linear_L15_backward_pre_linear_bf16_rank00_group000_weight_fp8_e4m3.log", + "max_align": -14.0, + "min_align": -23.0, + "recommended_scale_exp": -15.0, + "recommended_scale_factor": 3.1e-05, + "composite_score": 1.0, + "is_at_max": false, + "mse": 3.969729e-09, + "cosine_similarity": 0.999886, + "psnr": 45.3, + "mae": 4.255912e-05, + "relative_error": 2.25, + "layer": 15, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "weight", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_forward_linear_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100326_0297_iter000_linear_L15_forward_pre_linear_bf16_rank00_group000_output/mxfp_scaling_test_20250923_100326_0297_iter000_linear_L15_forward_pre_linear_bf16_rank00_group000_output_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -19.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0008, + "is_at_max": false, + "mse": 0.0002603733, + "cosine_similarity": 1.00263, + "psnr": 45.96, + "mae": 0.01090183, + "relative_error": 2.25, + "layer": 15, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_forward_attention_query", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100321_0285_iter000_attention_L15_forward_pre_attention_bf16_rank00_group000_query/mxfp_scaling_test_20250923_100321_0285_iter000_attention_L15_forward_pre_attention_bf16_rank00_group000_query_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -18.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0002, + "is_at_max": false, + "mse": 0.0002472007, + "cosine_similarity": 1.00051, + "psnr": 45.24, + "mae": 0.0106003, + "relative_error": 2.25, + "layer": 15, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "query", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_backward_linear_input", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100405_0442_iter000_linear_L8_backward_pre_linear_bf16_rank00_group000_input/mxfp_scaling_test_20250923_100405_0442_iter000_linear_L8_backward_pre_linear_bf16_rank00_group000_input_fp8_e4m3.log", + "max_align": -8.0, + "min_align": -15.0, + "recommended_scale_exp": -8.0, + "recommended_scale_factor": 0.003906, + "composite_score": 1.0004, + "is_at_max": true, + "mse": 2.371866e-06, + "cosine_similarity": 1.001184, + "psnr": 52.0, + "mae": 0.001035215, + "relative_error": 2.25, + "layer": 8, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "input", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_forward_linear_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100333_0314_iter000_linear_L16_forward_pre_linear_bf16_rank00_group000_output/mxfp_scaling_test_20250923_100333_0314_iter000_linear_L16_forward_pre_linear_bf16_rank00_group000_output_fp8_e4m3.log", + "max_align": -10.0, + "min_align": -26.0, + "recommended_scale_exp": -11.0, + "recommended_scale_factor": 0.000488, + "composite_score": 1.0003, + "is_at_max": false, + "mse": 1.056957e-06, + "cosine_similarity": 1.000867, + "psnr": 44.06, + "mae": 0.0006982776, + "relative_error": 2.26, + "layer": 16, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_backward_linear_weight", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100344_0353_iter000_linear_L15_backward_pre_linear_bf16_rank00_group000_weight/mxfp_scaling_test_20250923_100344_0353_iter000_linear_L15_backward_pre_linear_bf16_rank00_group000_weight_fp8_e4m3.log", + "max_align": -11.0, + "min_align": -22.0, + "recommended_scale_exp": -12.0, + "recommended_scale_factor": 0.000244, + "composite_score": 1.0, + "is_at_max": false, + "mse": 1.266206e-07, + "cosine_similarity": 0.999925, + "psnr": 45.79, + "mae": 0.0002403457, + "relative_error": 2.25, + "layer": 15, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "weight", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_backward_linear_weight", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100431_0541_iter000_linear_L1_backward_pre_linear_bf16_rank00_group000_weight/mxfp_scaling_test_20250923_100431_0541_iter000_linear_L1_backward_pre_linear_bf16_rank00_group000_weight_fp8_e4m3.log", + "max_align": -14.0, + "min_align": -23.0, + "recommended_scale_exp": -14.0, + "recommended_scale_factor": 6.1e-05, + "composite_score": 1.0, + "is_at_max": true, + "mse": 3.975844e-09, + "cosine_similarity": 0.999891, + "psnr": 45.47, + "mae": 4.258759e-05, + "relative_error": 2.25, + "layer": 1, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "weight", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_forward_attention_query", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100328_0305_iter000_attention_L16_forward_pre_attention_bf16_rank00_group000_query/mxfp_scaling_test_20250923_100328_0305_iter000_attention_L16_forward_pre_attention_bf16_rank00_group000_query_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -22.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0002, + "is_at_max": false, + "mse": 0.0002610916, + "cosine_similarity": 1.000522, + "psnr": 45.51, + "mae": 0.01092781, + "relative_error": 2.26, + "layer": 16, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "query", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_backward_attention_probs", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100431_0543_iter000_attention_L1_backward_post_FA_bf16_rank00_group000_grad_attention_probs/mxfp_scaling_test_20250923_100431_0543_iter000_attention_L1_backward_post_FA_bf16_rank00_group000_grad_attention_probs_fp8_e4m3.log", + "max_align": -2.0, + "min_align": -22.0, + "recommended_scale_exp": -3.0, + "recommended_scale_factor": 0.125, + "composite_score": 1.0645, + "is_at_max": false, + "mse": 0.0007662339, + "cosine_similarity": 1.215072, + "psnr": 63.62, + "mae": 0.01345307, + "relative_error": 2.25, + "layer": 1, + "pass_type": "backward", + "operation_type": "attention", + "tensor_type": "probs", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_forward_linear_A", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100333_0318_iter000_linear_L16_forward_pre_linear_bf16_rank00_group000_input_A/mxfp_scaling_test_20250923_100333_0318_iter000_linear_L16_forward_pre_linear_bf16_rank00_group000_input_A_fp8_e4m3.log", + "max_align": -5.0, + "min_align": -25.0, + "recommended_scale_exp": -6.0, + "recommended_scale_factor": 0.015625, + "composite_score": 1.0007, + "is_at_max": false, + "mse": 2.885989e-05, + "cosine_similarity": 1.002257, + "psnr": 58.03, + "mae": 0.002630964, + "relative_error": 2.25, + "layer": 16, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "A", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_backward_linear_input", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100430_0536_iter000_linear_L1_backward_pre_linear_bf16_rank00_group000_input/mxfp_scaling_test_20250923_100430_0536_iter000_linear_L1_backward_pre_linear_bf16_rank00_group000_input_fp8_e4m3.log", + "max_align": -2.0, + "min_align": -17.0, + "recommended_scale_exp": -3.0, + "recommended_scale_factor": 0.125, + "composite_score": 1.0011, + "is_at_max": false, + "mse": 0.000163752, + "cosine_similarity": 1.003596, + "psnr": 69.62, + "mae": 0.004889333, + "relative_error": 2.25, + "layer": 1, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "input", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_backward_attention_query", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100409_0448_iter000_attention_L8_backward_post_attention_bf16_rank00_group000_grad_query/mxfp_scaling_test_20250923_100409_0448_iter000_attention_L8_backward_post_attention_bf16_rank00_group000_grad_query_fp8_e4m3.log", + "max_align": -16.0, + "min_align": -31.0, + "recommended_scale_exp": -17.0, + "recommended_scale_factor": 8e-06, + "composite_score": 1.0003, + "is_at_max": false, + "mse": 1.706296e-11, + "cosine_similarity": 1.001083, + "psnr": 58.66, + "mae": 2.033102e-06, + "relative_error": 2.25, + "layer": 8, + "pass_type": "backward", + "operation_type": "attention", + "tensor_type": "query", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_backward_linear_weight", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100336_0327_iter000_linear_L16_backward_pre_linear_bf16_rank00_group000_weight/mxfp_scaling_test_20250923_100336_0327_iter000_linear_L16_backward_pre_linear_bf16_rank00_group000_weight_fp8_e4m3.log", + "max_align": -14.0, + "min_align": -25.0, + "recommended_scale_exp": -15.0, + "recommended_scale_factor": 3.1e-05, + "composite_score": 1.0, + "is_at_max": false, + "mse": 3.97255e-09, + "cosine_similarity": 1.000055, + "psnr": 46.37, + "mae": 4.257269e-05, + "relative_error": 2.25, + "layer": 16, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "weight", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_backward_linear_weight", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100340_0341_iter000_linear_L15_backward_pre_linear_bf16_rank00_group000_weight/mxfp_scaling_test_20250923_100340_0341_iter000_linear_L15_backward_pre_linear_bf16_rank00_group000_weight_fp8_e4m3.log", + "max_align": -14.0, + "min_align": -27.0, + "recommended_scale_exp": -15.0, + "recommended_scale_factor": 3.1e-05, + "composite_score": 1.0, + "is_at_max": false, + "mse": 3.981258e-09, + "cosine_similarity": 1.000049, + "psnr": 45.82, + "mae": 4.259716e-05, + "relative_error": 2.25, + "layer": 15, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "weight", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_forward_linear_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100333_0320_iter000_linear_L16_forward_pre_linear_bf16_rank00_group000_output/mxfp_scaling_test_20250923_100333_0320_iter000_linear_L16_forward_pre_linear_bf16_rank00_group000_output_fp8_e4m3.log", + "max_align": -11.0, + "min_align": -23.0, + "recommended_scale_exp": -11.0, + "recommended_scale_factor": 0.000488, + "composite_score": 1.0002, + "is_at_max": true, + "mse": 3.247117e-07, + "cosine_similarity": 1.000625, + "psnr": 46.01, + "mae": 0.0003858607, + "relative_error": 2.25, + "layer": 16, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_forward_linear_A", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100231_0141_iter000_linear_L8_forward_pre_linear_bf16_rank00_group000_input_A/mxfp_scaling_test_20250923_100231_0141_iter000_linear_L8_forward_pre_linear_bf16_rank00_group000_input_A_fp8_e4m3.log", + "max_align": -5.0, + "min_align": -18.0, + "recommended_scale_exp": -6.0, + "recommended_scale_factor": 0.015625, + "composite_score": 1.0004, + "is_at_max": false, + "mse": 0.0007119142, + "cosine_similarity": 1.001191, + "psnr": 44.89, + "mae": 0.01799846, + "relative_error": 2.25, + "layer": 8, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "A", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_forward_linear_A", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100237_0152_iter000_linear_L8_forward_pre_linear_bf16_rank00_group000_input_A/mxfp_scaling_test_20250923_100237_0152_iter000_linear_L8_forward_pre_linear_bf16_rank00_group000_input_A_fp8_e4m3.log", + "max_align": -7.0, + "min_align": -21.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0001, + "is_at_max": true, + "mse": 0.0001634078, + "cosine_similarity": 1.000405, + "psnr": 43.18, + "mae": 0.008672393, + "relative_error": 2.28, + "layer": 8, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "A", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_forward_attention_probs", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100236_0149_iter000_attention_L8_forward_post_FA_bf16_rank00_group000_attention_probs/mxfp_scaling_test_20250923_100236_0149_iter000_attention_L8_forward_post_FA_bf16_rank00_group000_attention_probs_fp8_e4m3.log", + "max_align": -7.0, + "min_align": -10.0, + "recommended_scale_exp": -8.0, + "recommended_scale_factor": 0.003906, + "composite_score": 1.0341, + "is_at_max": false, + "mse": 1.011233e-10, + "cosine_similarity": 1.116258, + "psnr": 99.95, + "mae": 2.751836e-06, + "relative_error": 2.25, + "layer": 8, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "probs", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_forward_linear_A", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100326_0298_iter000_linear_L15_forward_pre_linear_bf16_rank00_group000_input_A/mxfp_scaling_test_20250923_100326_0298_iter000_linear_L15_forward_pre_linear_bf16_rank00_group000_input_A_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -24.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0007, + "is_at_max": false, + "mse": 2.818011e-05, + "cosine_similarity": 1.002216, + "psnr": 56.3, + "mae": 0.002643707, + "relative_error": 2.25, + "layer": 15, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "A", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_forward_attention_value", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100333_0310_iter000_attention_L16_forward_post_FA_bf16_rank00_group000_value/mxfp_scaling_test_20250923_100333_0310_iter000_attention_L16_forward_post_FA_bf16_rank00_group000_value_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -20.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0001, + "is_at_max": false, + "mse": 0.000272202, + "cosine_similarity": 1.000485, + "psnr": 44.34, + "mae": 0.01123203, + "relative_error": 2.25, + "layer": 16, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "value", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_backward_linear_input", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100340_0340_iter000_linear_L15_backward_pre_linear_bf16_rank00_group000_input/mxfp_scaling_test_20250923_100340_0340_iter000_linear_L15_backward_pre_linear_bf16_rank00_group000_input_fp8_e4m3.log", + "max_align": -9.0, + "min_align": -16.0, + "recommended_scale_exp": -10.0, + "recommended_scale_factor": 0.000977, + "composite_score": 1.0002, + "is_at_max": false, + "mse": 1.351196e-06, + "cosine_similarity": 1.000569, + "psnr": 46.79, + "mae": 0.0007860017, + "relative_error": 2.25, + "layer": 15, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "input", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_forward_attention_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100333_0311_iter000_attention_L16_forward_post_FA_bf16_rank00_group000_output/mxfp_scaling_test_20250923_100333_0311_iter000_attention_L16_forward_post_FA_bf16_rank00_group000_output_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -18.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0002, + "is_at_max": false, + "mse": 0.0001926621, + "cosine_similarity": 1.000504, + "psnr": 43.17, + "mae": 0.009442221, + "relative_error": 2.25, + "layer": 16, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_backward_linear_input", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100336_0326_iter000_linear_L16_backward_pre_linear_bf16_rank00_group000_input/mxfp_scaling_test_20250923_100336_0326_iter000_linear_L16_backward_pre_linear_bf16_rank00_group000_input_fp8_e4m3.log", + "max_align": -10.0, + "min_align": -23.0, + "recommended_scale_exp": -11.0, + "recommended_scale_factor": 0.000488, + "composite_score": 1.0002, + "is_at_max": false, + "mse": 1.304203e-06, + "cosine_similarity": 1.000611, + "psnr": 45.91, + "mae": 0.0007721899, + "relative_error": 2.25, + "layer": 16, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "input", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_forward_linear_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100231_0143_iter000_linear_L8_forward_pre_linear_bf16_rank00_group000_output/mxfp_scaling_test_20250923_100231_0143_iter000_linear_L8_forward_pre_linear_bf16_rank00_group000_output_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -19.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0002, + "is_at_max": false, + "mse": 0.0002611995, + "cosine_similarity": 1.00079, + "psnr": 45.55, + "mae": 0.01091542, + "relative_error": 2.26, + "layer": 8, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_forward_attention_buffer", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100142_0004_iter000_attention_L1_forward_pre_attention_bf16_rank00_group000_matmul_input_buffer/mxfp_scaling_test_20250923_100142_0004_iter000_attention_L1_forward_pre_attention_bf16_rank00_group000_matmul_input_buffer_fp8_e4m3.log", + "max_align": -7.0, + "min_align": -7.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 0.0, + "is_at_max": true, + "mse": 0.0, + "cosine_similarity": 1.0, + "psnr": 0.0, + "mae": 0.0, + "relative_error": 0.0, + "layer": 1, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "buffer", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_forward_attention_key", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100233_0146_iter000_attention_L8_forward_pre_attention_bf16_rank00_group000_key/mxfp_scaling_test_20250923_100233_0146_iter000_attention_L8_forward_pre_attention_bf16_rank00_group000_key_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -17.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0002, + "is_at_max": false, + "mse": 0.0002584177, + "cosine_similarity": 1.000617, + "psnr": 45.82, + "mae": 0.01082279, + "relative_error": 2.25, + "layer": 8, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "key", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_backward_attention_probs", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100405_0445_iter000_attention_L8_backward_post_FA_bf16_rank00_group000_grad_attention_probs/mxfp_scaling_test_20250923_100405_0445_iter000_attention_L8_backward_post_FA_bf16_rank00_group000_grad_attention_probs_fp8_e4m3.log", + "max_align": -9.0, + "min_align": -25.0, + "recommended_scale_exp": -10.0, + "recommended_scale_factor": 0.000977, + "composite_score": 1.0666, + "is_at_max": false, + "mse": 1.300845e-06, + "cosine_similarity": 1.222151, + "psnr": 51.75, + "mae": 0.0007722682, + "relative_error": 2.25, + "layer": 8, + "pass_type": "backward", + "operation_type": "attention", + "tensor_type": "probs", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_backward_linear_weight", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100409_0451_iter000_linear_L8_backward_pre_linear_bf16_rank00_group000_weight/mxfp_scaling_test_20250923_100409_0451_iter000_linear_L8_backward_pre_linear_bf16_rank00_group000_weight_fp8_e4m3.log", + "max_align": -11.0, + "min_align": -23.0, + "recommended_scale_exp": -12.0, + "recommended_scale_factor": 0.000244, + "composite_score": 1.0, + "is_at_max": false, + "mse": 1.267064e-07, + "cosine_similarity": 0.999933, + "psnr": 45.79, + "mae": 0.0002404514, + "relative_error": 2.25, + "layer": 8, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "weight", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_forward_linear_B", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100333_0316_iter000_linear_L16_forward_pre_linear_bf16_rank00_group000_input_B/mxfp_scaling_test_20250923_100333_0316_iter000_linear_L16_forward_pre_linear_bf16_rank00_group000_input_B_fp8_e4m3.log", + "max_align": -11.0, + "min_align": -23.0, + "recommended_scale_exp": -12.0, + "recommended_scale_factor": 0.000244, + "composite_score": 1.0001, + "is_at_max": false, + "mse": 1.265832e-07, + "cosine_similarity": 1.000359, + "psnr": 45.98, + "mae": 0.0002403485, + "relative_error": 2.25, + "layer": 16, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "B", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_backward_linear_weight", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100340_0343_iter000_linear_L15_backward_pre_linear_bf16_rank00_group000_weight/mxfp_scaling_test_20250923_100340_0343_iter000_linear_L15_backward_pre_linear_bf16_rank00_group000_weight_fp8_e4m3.log", + "max_align": -11.0, + "min_align": -22.0, + "recommended_scale_exp": -12.0, + "recommended_scale_factor": 0.000244, + "composite_score": 1.0001, + "is_at_max": false, + "mse": 1.26704e-07, + "cosine_similarity": 1.000363, + "psnr": 46.15, + "mae": 0.0002404279, + "relative_error": 2.25, + "layer": 15, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "weight", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_forward_linear_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100142_0003_iter000_linear_L1_forward_pre_linear_bf16_rank00_group000_output/mxfp_scaling_test_20250923_100142_0003_iter000_linear_L1_forward_pre_linear_bf16_rank00_group000_output_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -16.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0002, + "is_at_max": false, + "mse": 0.0002472399, + "cosine_similarity": 1.000787, + "psnr": 45.52, + "mae": 0.01060983, + "relative_error": 2.25, + "layer": 1, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_forward_linear_A", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100326_0295_iter000_linear_L15_forward_pre_linear_bf16_rank00_group000_input_A/mxfp_scaling_test_20250923_100326_0295_iter000_linear_L15_forward_pre_linear_bf16_rank00_group000_input_A_fp8_e4m3.log", + "max_align": -5.0, + "min_align": -18.0, + "recommended_scale_exp": -6.0, + "recommended_scale_factor": 0.015625, + "composite_score": 1.0004, + "is_at_max": false, + "mse": 0.0007145355, + "cosine_similarity": 1.001221, + "psnr": 44.34, + "mae": 0.01792298, + "relative_error": 2.26, + "layer": 15, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "A", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_forward_linear_B", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100326_0299_iter000_linear_L15_forward_pre_linear_bf16_rank00_group000_input_B/mxfp_scaling_test_20250923_100326_0299_iter000_linear_L15_forward_pre_linear_bf16_rank00_group000_input_B_fp8_e4m3.log", + "max_align": -14.0, + "min_align": -27.0, + "recommended_scale_exp": -15.0, + "recommended_scale_factor": 3.1e-05, + "composite_score": 1.0, + "is_at_max": false, + "mse": 3.981258e-09, + "cosine_similarity": 1.00005, + "psnr": 45.82, + "mae": 4.259716e-05, + "relative_error": 2.25, + "layer": 15, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "B", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_forward_attention_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100328_0307_iter000_attention_L16_forward_pre_attention_bf16_rank00_group000_mm_output/mxfp_scaling_test_20250923_100328_0307_iter000_attention_L16_forward_pre_attention_bf16_rank00_group000_mm_output_fp8_e4m3.log", + "max_align": -3.0, + "min_align": -19.0, + "recommended_scale_exp": -3.0, + "recommended_scale_factor": 0.125, + "composite_score": 1.078, + "is_at_max": true, + "mse": 0.01537745, + "cosine_similarity": 1.260036, + "psnr": 47.6, + "mae": 0.08310976, + "relative_error": 2.25, + "layer": 16, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_forward_linear_B", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100148_0013_iter000_linear_L1_forward_pre_linear_bf16_rank00_group000_input_B/mxfp_scaling_test_20250923_100148_0013_iter000_linear_L1_forward_pre_linear_bf16_rank00_group000_input_B_fp8_e4m3.log", + "max_align": -14.0, + "min_align": -23.0, + "recommended_scale_exp": -15.0, + "recommended_scale_factor": 3.1e-05, + "composite_score": 1.0, + "is_at_max": false, + "mse": 3.975844e-09, + "cosine_similarity": 0.999892, + "psnr": 45.47, + "mae": 4.258759e-05, + "relative_error": 2.25, + "layer": 1, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "B", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_forward_linear_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100320_0283_iter000_linear_L15_forward_pre_linear_bf16_rank00_group000_output/mxfp_scaling_test_20250923_100320_0283_iter000_linear_L15_forward_pre_linear_bf16_rank00_group000_output_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -20.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0002, + "is_at_max": false, + "mse": 0.0002547138, + "cosine_similarity": 1.000808, + "psnr": 45.75, + "mae": 0.0107118, + "relative_error": 2.25, + "layer": 15, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_forward_attention_key", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100321_0286_iter000_attention_L15_forward_pre_attention_bf16_rank00_group000_key/mxfp_scaling_test_20250923_100321_0286_iter000_attention_L15_forward_pre_attention_bf16_rank00_group000_key_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -17.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0002, + "is_at_max": false, + "mse": 0.0002796364, + "cosine_similarity": 1.00073, + "psnr": 44.52, + "mae": 0.01117912, + "relative_error": 2.25, + "layer": 15, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "key", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_backward_attention_probs", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100340_0347_iter000_attention_L15_backward_post_FA_bf16_rank00_group000_grad_attention_probs/mxfp_scaling_test_20250923_100340_0347_iter000_attention_L15_backward_post_FA_bf16_rank00_group000_grad_attention_probs_fp8_e4m3.log", + "max_align": -10.0, + "min_align": -25.0, + "recommended_scale_exp": -11.0, + "recommended_scale_factor": 0.000488, + "composite_score": 1.0665, + "is_at_max": false, + "mse": 7.543104e-07, + "cosine_similarity": 1.221606, + "psnr": 46.12, + "mae": 0.0005845426, + "relative_error": 2.25, + "layer": 15, + "pass_type": "backward", + "operation_type": "attention", + "tensor_type": "probs", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_forward_attention_buffer", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100320_0284_iter000_attention_L15_forward_pre_attention_bf16_rank00_group000_matmul_input_buffer/mxfp_scaling_test_20250923_100320_0284_iter000_attention_L15_forward_pre_attention_bf16_rank00_group000_matmul_input_buffer_fp8_e4m3.log", + "max_align": -5.0, + "min_align": -19.0, + "recommended_scale_exp": -6.0, + "recommended_scale_factor": 0.015625, + "composite_score": 1.0004, + "is_at_max": false, + "mse": 2.230996e-05, + "cosine_similarity": 1.001183, + "psnr": 59.88, + "mae": 0.0005599458, + "relative_error": 2.25, + "layer": 15, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "buffer", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_forward_linear_A", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100333_0315_iter000_linear_L16_forward_pre_linear_bf16_rank00_group000_input_A/mxfp_scaling_test_20250923_100333_0315_iter000_linear_L16_forward_pre_linear_bf16_rank00_group000_input_A_fp8_e4m3.log", + "max_align": -5.0, + "min_align": -16.0, + "recommended_scale_exp": -6.0, + "recommended_scale_factor": 0.015625, + "composite_score": 1.0004, + "is_at_max": false, + "mse": 0.0007171242, + "cosine_similarity": 1.001191, + "psnr": 44.01, + "mae": 0.01795938, + "relative_error": 2.26, + "layer": 16, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "A", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_backward_attention_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100337_0332_iter000_attention_L16_backward_post_FA_bf16_rank00_group000_grad_output/mxfp_scaling_test_20250923_100337_0332_iter000_attention_L16_backward_post_FA_bf16_rank00_group000_grad_output_fp8_e4m3.log", + "max_align": -13.0, + "min_align": -24.0, + "recommended_scale_exp": -13.0, + "recommended_scale_factor": 0.000122, + "composite_score": 1.0002, + "is_at_max": true, + "mse": 1.55476e-08, + "cosine_similarity": 1.000528, + "psnr": 47.38, + "mae": 8.418535e-05, + "relative_error": 2.25, + "layer": 16, + "pass_type": "backward", + "operation_type": "attention", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_forward_linear_B", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100326_0296_iter000_linear_L15_forward_pre_linear_bf16_rank00_group000_input_B/mxfp_scaling_test_20250923_100326_0296_iter000_linear_L15_forward_pre_linear_bf16_rank00_group000_input_B_fp8_e4m3.log", + "max_align": -11.0, + "min_align": -22.0, + "recommended_scale_exp": -12.0, + "recommended_scale_factor": 0.000244, + "composite_score": 1.0001, + "is_at_max": false, + "mse": 1.26704e-07, + "cosine_similarity": 1.00036, + "psnr": 46.15, + "mae": 0.0002404279, + "relative_error": 2.25, + "layer": 15, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "B", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_backward_linear_weight", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100405_0443_iter000_linear_L8_backward_pre_linear_bf16_rank00_group000_weight/mxfp_scaling_test_20250923_100405_0443_iter000_linear_L8_backward_pre_linear_bf16_rank00_group000_weight_fp8_e4m3.log", + "max_align": -14.0, + "min_align": -25.0, + "recommended_scale_exp": -14.0, + "recommended_scale_factor": 6.1e-05, + "composite_score": 1.0, + "is_at_max": true, + "mse": 3.972378e-09, + "cosine_similarity": 0.999889, + "psnr": 45.83, + "mae": 4.255031e-05, + "relative_error": 2.25, + "layer": 8, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "weight", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_backward_linear_weight", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100434_0549_iter000_linear_L1_backward_pre_linear_bf16_rank00_group000_weight/mxfp_scaling_test_20250923_100434_0549_iter000_linear_L1_backward_pre_linear_bf16_rank00_group000_weight_fp8_e4m3.log", + "max_align": -11.0, + "min_align": -22.0, + "recommended_scale_exp": -12.0, + "recommended_scale_factor": 0.000244, + "composite_score": 1.0, + "is_at_max": false, + "mse": 1.267756e-07, + "cosine_similarity": 0.999925, + "psnr": 45.97, + "mae": 0.0002404027, + "relative_error": 2.25, + "layer": 1, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "weight", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_forward_linear_B", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100238_0159_iter000_linear_L8_forward_pre_linear_bf16_rank00_group000_input_B/mxfp_scaling_test_20250923_100238_0159_iter000_linear_L8_forward_pre_linear_bf16_rank00_group000_input_B_fp8_e4m3.log", + "max_align": -14.0, + "min_align": -25.0, + "recommended_scale_exp": -15.0, + "recommended_scale_factor": 3.1e-05, + "composite_score": 1.0, + "is_at_max": false, + "mse": 3.980728e-09, + "cosine_similarity": 1.000056, + "psnr": 45.78, + "mae": 4.259601e-05, + "relative_error": 2.25, + "layer": 8, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "B", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_forward_linear_B", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100333_0319_iter000_linear_L16_forward_pre_linear_bf16_rank00_group000_input_B/mxfp_scaling_test_20250923_100333_0319_iter000_linear_L16_forward_pre_linear_bf16_rank00_group000_input_B_fp8_e4m3.log", + "max_align": -14.0, + "min_align": -25.0, + "recommended_scale_exp": -15.0, + "recommended_scale_factor": 3.1e-05, + "composite_score": 1.0, + "is_at_max": false, + "mse": 3.972549e-09, + "cosine_similarity": 1.000053, + "psnr": 46.37, + "mae": 4.257269e-05, + "relative_error": 2.25, + "layer": 16, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "B", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_forward_attention_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100321_0287_iter000_attention_L15_forward_pre_attention_bf16_rank00_group000_mm_output/mxfp_scaling_test_20250923_100321_0287_iter000_attention_L15_forward_pre_attention_bf16_rank00_group000_mm_output_fp8_e4m3.log", + "max_align": -3.0, + "min_align": -19.0, + "recommended_scale_exp": -4.0, + "recommended_scale_factor": 0.0625, + "composite_score": 1.085, + "is_at_max": false, + "mse": 0.0111057, + "cosine_similarity": 1.28336, + "psnr": 47.89, + "mae": 0.07067227, + "relative_error": 2.25, + "layer": 15, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_forward_linear_B", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100327_0302_iter000_linear_L16_forward_pre_linear_bf16_rank00_group000_input_B/mxfp_scaling_test_20250923_100327_0302_iter000_linear_L16_forward_pre_linear_bf16_rank00_group000_input_B_fp8_e4m3.log", + "max_align": -11.0, + "min_align": -23.0, + "recommended_scale_exp": -12.0, + "recommended_scale_factor": 0.000244, + "composite_score": 1.0, + "is_at_max": false, + "mse": 1.266533e-07, + "cosine_similarity": 0.999925, + "psnr": 46.03, + "mae": 0.0002404344, + "relative_error": 2.25, + "layer": 16, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "B", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_forward_linear_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100237_0154_iter000_linear_L8_forward_pre_linear_bf16_rank00_group000_output/mxfp_scaling_test_20250923_100237_0154_iter000_linear_L8_forward_pre_linear_bf16_rank00_group000_output_fp8_e4m3.log", + "max_align": -10.0, + "min_align": -23.0, + "recommended_scale_exp": -11.0, + "recommended_scale_factor": 0.000488, + "composite_score": 1.0003, + "is_at_max": false, + "mse": 8.626195e-07, + "cosine_similarity": 1.000915, + "psnr": 46.28, + "mae": 0.0006306471, + "relative_error": 2.24, + "layer": 8, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_backward_attention_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100338_0335_iter000_attention_L16_backward_post_attention_bf16_rank00_group000_grad_output/mxfp_scaling_test_20250923_100338_0335_iter000_attention_L16_backward_post_attention_bf16_rank00_group000_grad_output_fp8_e4m3.log", + "max_align": -14.0, + "min_align": -42.0, + "recommended_scale_exp": -15.0, + "recommended_scale_factor": 3.1e-05, + "composite_score": 1.027, + "is_at_max": false, + "mse": 1.801698e-14, + "cosine_similarity": 1.090126, + "psnr": 95.37, + "mae": 3.677253e-08, + "relative_error": 2.3, + "layer": 16, + "pass_type": "backward", + "operation_type": "attention", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_backward_linear_weight", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100337_0329_iter000_linear_L16_backward_pre_linear_bf16_rank00_group000_weight/mxfp_scaling_test_20250923_100337_0329_iter000_linear_L16_backward_pre_linear_bf16_rank00_group000_weight_fp8_e4m3.log", + "max_align": -11.0, + "min_align": -23.0, + "recommended_scale_exp": -12.0, + "recommended_scale_factor": 0.000244, + "composite_score": 1.0001, + "is_at_max": false, + "mse": 1.265832e-07, + "cosine_similarity": 1.000356, + "psnr": 45.98, + "mae": 0.0002403485, + "relative_error": 2.25, + "layer": 16, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "weight", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_backward_attention_value", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100432_0544_iter000_attention_L1_backward_post_FA_bf16_rank00_group000_grad_value/mxfp_scaling_test_20250923_100432_0544_iter000_attention_L1_backward_post_FA_bf16_rank00_group000_grad_value_fp8_e4m3.log", + "max_align": -3.0, + "min_align": -28.0, + "recommended_scale_exp": -4.0, + "recommended_scale_factor": 0.0625, + "composite_score": 1.0002, + "is_at_max": false, + "mse": 1.263131e-05, + "cosine_similarity": 1.000792, + "psnr": 73.66, + "mae": 0.0004534796, + "relative_error": 2.29, + "layer": 1, + "pass_type": "backward", + "operation_type": "attention", + "tensor_type": "value", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_backward_attention_query", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100343_0350_iter000_attention_L15_backward_post_attention_bf16_rank00_group000_grad_query/mxfp_scaling_test_20250923_100343_0350_iter000_attention_L15_backward_post_attention_bf16_rank00_group000_grad_query_fp8_e4m3.log", + "max_align": -16.0, + "min_align": -31.0, + "recommended_scale_exp": -17.0, + "recommended_scale_factor": 8e-06, + "composite_score": 1.0004, + "is_at_max": false, + "mse": 1.634347e-11, + "cosine_similarity": 1.001167, + "psnr": 57.34, + "mae": 1.905427e-06, + "relative_error": 2.25, + "layer": 15, + "pass_type": "backward", + "operation_type": "attention", + "tensor_type": "query", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_forward_linear_B", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100231_0142_iter000_linear_L8_forward_pre_linear_bf16_rank00_group000_input_B/mxfp_scaling_test_20250923_100231_0142_iter000_linear_L8_forward_pre_linear_bf16_rank00_group000_input_B_fp8_e4m3.log", + "max_align": -11.0, + "min_align": -23.0, + "recommended_scale_exp": -12.0, + "recommended_scale_factor": 0.000244, + "composite_score": 1.0, + "is_at_max": false, + "mse": 1.267064e-07, + "cosine_similarity": 0.999929, + "psnr": 45.79, + "mae": 0.0002404514, + "relative_error": 2.25, + "layer": 8, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "B", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_backward_attention_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100340_0346_iter000_attention_L15_backward_post_FA_bf16_rank00_group000_grad_output/mxfp_scaling_test_20250923_100340_0346_iter000_attention_L15_backward_post_FA_bf16_rank00_group000_grad_output_fp8_e4m3.log", + "max_align": -13.0, + "min_align": -28.0, + "recommended_scale_exp": -13.0, + "recommended_scale_factor": 0.000122, + "composite_score": 1.0002, + "is_at_max": true, + "mse": 1.631616e-08, + "cosine_similarity": 1.000534, + "psnr": 45.76, + "mae": 8.630973e-05, + "relative_error": 2.25, + "layer": 15, + "pass_type": "backward", + "operation_type": "attention", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_backward_attention_key", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100340_0337_iter000_attention_L16_backward_post_attention_bf16_rank00_group000_grad_key/mxfp_scaling_test_20250923_100340_0337_iter000_attention_L16_backward_post_attention_bf16_rank00_group000_grad_key_fp8_e4m3.log", + "max_align": -16.0, + "min_align": -33.0, + "recommended_scale_exp": -17.0, + "recommended_scale_factor": 8e-06, + "composite_score": 1.0003, + "is_at_max": false, + "mse": 2.401485e-12, + "cosine_similarity": 1.001032, + "psnr": 66.95, + "mae": 5.728989e-07, + "relative_error": 2.25, + "layer": 16, + "pass_type": "backward", + "operation_type": "attention", + "tensor_type": "key", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_backward_linear_input", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100409_0450_iter000_linear_L8_backward_pre_linear_bf16_rank00_group000_input/mxfp_scaling_test_20250923_100409_0450_iter000_linear_L8_backward_pre_linear_bf16_rank00_group000_input_fp8_e4m3.log", + "max_align": -10.0, + "min_align": -31.0, + "recommended_scale_exp": -10.0, + "recommended_scale_factor": 0.000977, + "composite_score": 1.001, + "is_at_max": true, + "mse": 1.496797e-09, + "cosine_similarity": 1.003516, + "psnr": 76.07, + "mae": 6.937392e-06, + "relative_error": 2.28, + "layer": 8, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "input", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_backward_linear_weight", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100337_0331_iter000_linear_L16_backward_pre_linear_bf16_rank00_group000_weight/mxfp_scaling_test_20250923_100337_0331_iter000_linear_L16_backward_pre_linear_bf16_rank00_group000_weight_fp8_e4m3.log", + "max_align": -14.0, + "min_align": -26.0, + "recommended_scale_exp": -14.0, + "recommended_scale_factor": 6.1e-05, + "composite_score": 1.0, + "is_at_max": true, + "mse": 3.976778e-09, + "cosine_similarity": 0.999889, + "psnr": 46.37, + "mae": 4.256255e-05, + "relative_error": 2.25, + "layer": 16, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "weight", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_backward_linear_input", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100431_0540_iter000_linear_L1_backward_pre_linear_bf16_rank00_group000_input/mxfp_scaling_test_20250923_100431_0540_iter000_linear_L1_backward_pre_linear_bf16_rank00_group000_input_fp8_e4m3.log", + "max_align": -2.0, + "min_align": -10.0, + "recommended_scale_exp": -3.0, + "recommended_scale_factor": 0.125, + "composite_score": 1.0006, + "is_at_max": false, + "mse": 0.001502305, + "cosine_similarity": 1.002165, + "psnr": 63.2, + "mae": 0.01883791, + "relative_error": 2.25, + "layer": 1, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "input", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_backward_linear_weight", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100430_0539_iter000_linear_L1_backward_pre_linear_bf16_rank00_group000_weight/mxfp_scaling_test_20250923_100430_0539_iter000_linear_L1_backward_pre_linear_bf16_rank00_group000_weight_fp8_e4m3.log", + "max_align": -11.0, + "min_align": -22.0, + "recommended_scale_exp": -12.0, + "recommended_scale_factor": 0.000244, + "composite_score": 1.0001, + "is_at_max": false, + "mse": 1.26664e-07, + "cosine_similarity": 1.000361, + "psnr": 46.5, + "mae": 0.0002405087, + "relative_error": 2.25, + "layer": 1, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "weight", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_forward_linear_A", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100149_0015_iter000_linear_L1_forward_pre_linear_bf16_rank00_group000_input_A/mxfp_scaling_test_20250923_100149_0015_iter000_linear_L1_forward_pre_linear_bf16_rank00_group000_input_A_fp8_e4m3.log", + "max_align": -5.0, + "min_align": -18.0, + "recommended_scale_exp": -6.0, + "recommended_scale_factor": 0.015625, + "composite_score": 1.0004, + "is_at_max": false, + "mse": 0.0006795852, + "cosine_similarity": 1.001239, + "psnr": 46.13, + "mae": 0.01757503, + "relative_error": 2.25, + "layer": 1, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "A", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_backward_attention_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100431_0542_iter000_attention_L1_backward_post_FA_bf16_rank00_group000_grad_output/mxfp_scaling_test_20250923_100431_0542_iter000_attention_L1_backward_post_FA_bf16_rank00_group000_grad_output_fp8_e4m3.log", + "max_align": -5.0, + "min_align": -20.0, + "recommended_scale_exp": -6.0, + "recommended_scale_factor": 0.015625, + "composite_score": 1.0001, + "is_at_max": false, + "mse": 1.7316e-05, + "cosine_similarity": 1.000465, + "psnr": 62.52, + "mae": 0.002024189, + "relative_error": 2.25, + "layer": 1, + "pass_type": "backward", + "operation_type": "attention", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_forward_linear_B", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100326_0293_iter000_linear_L15_forward_pre_linear_bf16_rank00_group000_input_B/mxfp_scaling_test_20250923_100326_0293_iter000_linear_L15_forward_pre_linear_bf16_rank00_group000_input_B_fp8_e4m3.log", + "max_align": -14.0, + "min_align": -23.0, + "recommended_scale_exp": -15.0, + "recommended_scale_factor": 3.1e-05, + "composite_score": 1.0, + "is_at_max": false, + "mse": 3.969729e-09, + "cosine_similarity": 0.999887, + "psnr": 45.3, + "mae": 4.255912e-05, + "relative_error": 2.25, + "layer": 15, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "B", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_forward_linear_A", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100148_0012_iter000_linear_L1_forward_pre_linear_bf16_rank00_group000_input_A/mxfp_scaling_test_20250923_100148_0012_iter000_linear_L1_forward_pre_linear_bf16_rank00_group000_input_A_fp8_e4m3.log", + "max_align": -7.0, + "min_align": -22.0, + "recommended_scale_exp": -8.0, + "recommended_scale_factor": 0.003906, + "composite_score": 1.0002, + "is_at_max": false, + "mse": 1.576756e-06, + "cosine_similarity": 1.000552, + "psnr": 63.0, + "mae": 0.0008171461, + "relative_error": 2.25, + "layer": 1, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "A", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_forward_attention_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100148_0011_iter000_attention_L1_forward_post_FA_bf16_rank00_group000_output/mxfp_scaling_test_20250923_100148_0011_iter000_attention_L1_forward_post_FA_bf16_rank00_group000_output_fp8_e4m3.log", + "max_align": -7.0, + "min_align": -22.0, + "recommended_scale_exp": -8.0, + "recommended_scale_factor": 0.003906, + "composite_score": 1.0001, + "is_at_max": false, + "mse": 1.576756e-06, + "cosine_similarity": 1.000471, + "psnr": 63.0, + "mae": 0.0008171461, + "relative_error": 2.25, + "layer": 1, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_backward_linear_input", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100405_0440_iter000_linear_L8_backward_pre_linear_bf16_rank00_group000_input/mxfp_scaling_test_20250923_100405_0440_iter000_linear_L8_backward_pre_linear_bf16_rank00_group000_input_fp8_e4m3.log", + "max_align": -12.0, + "min_align": -34.0, + "recommended_scale_exp": -13.0, + "recommended_scale_factor": 0.000122, + "composite_score": 1.0016, + "is_at_max": false, + "mse": 2.917878e-09, + "cosine_similarity": 1.00538, + "psnr": 59.91, + "mae": 2.58771e-05, + "relative_error": 2.25, + "layer": 8, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "input", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_forward_attention_probs", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100331_0309_iter000_attention_L16_forward_post_FA_bf16_rank00_group000_attention_probs/mxfp_scaling_test_20250923_100331_0309_iter000_attention_L16_forward_post_FA_bf16_rank00_group000_attention_probs_fp8_e4m3.log", + "max_align": -7.0, + "min_align": -10.0, + "recommended_scale_exp": -8.0, + "recommended_scale_factor": 0.003906, + "composite_score": 1.0328, + "is_at_max": false, + "mse": 9.803741e-11, + "cosine_similarity": 1.115022, + "psnr": 100.09, + "mae": 2.752137e-06, + "relative_error": 2.25, + "layer": 16, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "probs", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_forward_attention_probs", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100324_0289_iter000_attention_L15_forward_post_FA_bf16_rank00_group000_attention_probs/mxfp_scaling_test_20250923_100324_0289_iter000_attention_L15_forward_post_FA_bf16_rank00_group000_attention_probs_fp8_e4m3.log", + "max_align": -7.0, + "min_align": -10.0, + "recommended_scale_exp": -8.0, + "recommended_scale_factor": 0.003906, + "composite_score": 1.0342, + "is_at_max": false, + "mse": 9.700087e-11, + "cosine_similarity": 1.115252, + "psnr": 100.13, + "mae": 2.75157e-06, + "relative_error": 2.25, + "layer": 15, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "probs", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_forward_linear_B", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100142_0002_iter000_linear_L1_forward_pre_linear_bf16_rank00_group000_input_B/mxfp_scaling_test_20250923_100142_0002_iter000_linear_L1_forward_pre_linear_bf16_rank00_group000_input_B_fp8_e4m3.log", + "max_align": -11.0, + "min_align": -22.0, + "recommended_scale_exp": -12.0, + "recommended_scale_factor": 0.000244, + "composite_score": 1.0, + "is_at_max": false, + "mse": 1.267756e-07, + "cosine_similarity": 0.999925, + "psnr": 45.97, + "mae": 0.0002404027, + "relative_error": 2.25, + "layer": 1, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "B", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_backward_linear_input", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100340_0338_iter000_linear_L16_backward_pre_linear_bf16_rank00_group000_input/mxfp_scaling_test_20250923_100340_0338_iter000_linear_L16_backward_pre_linear_bf16_rank00_group000_input_fp8_e4m3.log", + "max_align": -12.0, + "min_align": -33.0, + "recommended_scale_exp": -13.0, + "recommended_scale_factor": 0.000122, + "composite_score": 1.0009, + "is_at_max": false, + "mse": 1.207305e-10, + "cosine_similarity": 1.002961, + "psnr": 69.98, + "mae": 3.711559e-06, + "relative_error": 2.25, + "layer": 16, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "input", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_backward_linear_input", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100337_0328_iter000_linear_L16_backward_pre_linear_bf16_rank00_group000_input/mxfp_scaling_test_20250923_100337_0328_iter000_linear_L16_backward_pre_linear_bf16_rank00_group000_input_fp8_e4m3.log", + "max_align": -12.0, + "min_align": -33.0, + "recommended_scale_exp": -13.0, + "recommended_scale_factor": 0.000122, + "composite_score": 1.0017, + "is_at_max": false, + "mse": 1.696953e-09, + "cosine_similarity": 1.005805, + "psnr": 58.68, + "mae": 1.987462e-05, + "relative_error": 2.25, + "layer": 16, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "input", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_forward_attention_query", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100144_0005_iter000_attention_L1_forward_pre_attention_bf16_rank00_group000_query/mxfp_scaling_test_20250923_100144_0005_iter000_attention_L1_forward_pre_attention_bf16_rank00_group000_query_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -19.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0002, + "is_at_max": false, + "mse": 0.0002470705, + "cosine_similarity": 1.000512, + "psnr": 45.48, + "mae": 0.0106188, + "relative_error": 2.25, + "layer": 1, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "query", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_backward_attention_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100432_0545_iter000_attention_L1_backward_post_attention_bf16_rank00_group000_grad_output/mxfp_scaling_test_20250923_100432_0545_iter000_attention_L1_backward_post_attention_bf16_rank00_group000_grad_output_fp8_e4m3.log", + "max_align": -5.0, + "min_align": -35.0, + "recommended_scale_exp": -6.0, + "recommended_scale_factor": 0.015625, + "composite_score": 0.9949, + "is_at_max": false, + "mse": 2.212613e-09, + "cosine_similarity": 1.005057, + "psnr": 102.51, + "mae": 4.367415e-06, + "relative_error": 6.03, + "layer": 1, + "pass_type": "backward", + "operation_type": "attention", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_backward_attention_key", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100409_0449_iter000_attention_L8_backward_post_attention_bf16_rank00_group000_grad_key/mxfp_scaling_test_20250923_100409_0449_iter000_attention_L8_backward_post_attention_bf16_rank00_group000_grad_key_fp8_e4m3.log", + "max_align": -12.0, + "min_align": -33.0, + "recommended_scale_exp": -13.0, + "recommended_scale_factor": 0.000122, + "composite_score": 1.0004, + "is_at_max": false, + "mse": 3.603861e-11, + "cosine_similarity": 1.001385, + "psnr": 75.29, + "mae": 1.381485e-06, + "relative_error": 2.26, + "layer": 8, + "pass_type": "backward", + "operation_type": "attention", + "tensor_type": "key", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_backward_linear_input", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100434_0548_iter000_linear_L1_backward_pre_linear_bf16_rank00_group000_input/mxfp_scaling_test_20250923_100434_0548_iter000_linear_L1_backward_pre_linear_bf16_rank00_group000_input_fp8_e4m3.log", + "max_align": -3.0, + "min_align": -26.0, + "recommended_scale_exp": -4.0, + "recommended_scale_factor": 0.0625, + "composite_score": 1.0003, + "is_at_max": false, + "mse": 8.11067e-06, + "cosine_similarity": 1.000861, + "psnr": 80.12, + "mae": 0.000197908, + "relative_error": 2.38, + "layer": 1, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "input", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_backward_attention_value", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100342_0348_iter000_attention_L15_backward_post_FA_bf16_rank00_group000_grad_value/mxfp_scaling_test_20250923_100342_0348_iter000_attention_L15_backward_post_FA_bf16_rank00_group000_grad_value_fp8_e4m3.log", + "max_align": -13.0, + "min_align": -36.0, + "recommended_scale_exp": -14.0, + "recommended_scale_factor": 6.1e-05, + "composite_score": 1.0003, + "is_at_max": false, + "mse": 2.15195e-10, + "cosine_similarity": 1.001137, + "psnr": 65.09, + "mae": 7.233585e-06, + "relative_error": 2.25, + "layer": 15, + "pass_type": "backward", + "operation_type": "attention", + "tensor_type": "value", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_backward_linear_input", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100405_0438_iter000_linear_L8_backward_pre_linear_bf16_rank00_group000_input/mxfp_scaling_test_20250923_100405_0438_iter000_linear_L8_backward_pre_linear_bf16_rank00_group000_input_fp8_e4m3.log", + "max_align": -8.0, + "min_align": -21.0, + "recommended_scale_exp": -9.0, + "recommended_scale_factor": 0.001953, + "composite_score": 1.0004, + "is_at_max": false, + "mse": 2.126439e-06, + "cosine_similarity": 1.00124, + "psnr": 52.42, + "mae": 0.000982874, + "relative_error": 2.25, + "layer": 8, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "input", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_forward_linear_A", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100238_0158_iter000_linear_L8_forward_pre_linear_bf16_rank00_group000_input_A/mxfp_scaling_test_20250923_100238_0158_iter000_linear_L8_forward_pre_linear_bf16_rank00_group000_input_A_fp8_e4m3.log", + "max_align": -5.0, + "min_align": -24.0, + "recommended_scale_exp": -6.0, + "recommended_scale_factor": 0.015625, + "composite_score": 1.0007, + "is_at_max": false, + "mse": 3.281131e-05, + "cosine_similarity": 1.002278, + "psnr": 57.15, + "mae": 0.002702262, + "relative_error": 2.25, + "layer": 8, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "A", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_forward_attention_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100323_0288_iter000_attention_L15_forward_pre_attention_bf16_rank00_group000_output/mxfp_scaling_test_20250923_100323_0288_iter000_attention_L15_forward_pre_attention_bf16_rank00_group000_output_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -23.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0858, + "is_at_max": false, + "mse": 8.65289e-05, + "cosine_similarity": 1.285876, + "psnr": 47.91, + "mae": 0.006238645, + "relative_error": 2.25, + "layer": 15, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_forward_attention_buffer", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100327_0304_iter000_attention_L16_forward_pre_attention_bf16_rank00_group000_matmul_input_buffer/mxfp_scaling_test_20250923_100327_0304_iter000_attention_L16_forward_pre_attention_bf16_rank00_group000_matmul_input_buffer_fp8_e4m3.log", + "max_align": -5.0, + "min_align": -18.0, + "recommended_scale_exp": -6.0, + "recommended_scale_factor": 0.015625, + "composite_score": 1.0004, + "is_at_max": false, + "mse": 2.230255e-05, + "cosine_similarity": 1.001171, + "psnr": 59.46, + "mae": 0.0005601679, + "relative_error": 2.25, + "layer": 16, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "buffer", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_forward_attention_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100233_0147_iter000_attention_L8_forward_pre_attention_bf16_rank00_group000_mm_output/mxfp_scaling_test_20250923_100233_0147_iter000_attention_L8_forward_pre_attention_bf16_rank00_group000_mm_output_fp8_e4m3.log", + "max_align": -3.0, + "min_align": -18.0, + "recommended_scale_exp": -4.0, + "recommended_scale_factor": 0.0625, + "composite_score": 1.0663, + "is_at_max": false, + "mse": 0.01231307, + "cosine_similarity": 1.220964, + "psnr": 47.56, + "mae": 0.07484189, + "relative_error": 2.25, + "layer": 8, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_backward_attention_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100342_0349_iter000_attention_L15_backward_post_attention_bf16_rank00_group000_grad_output/mxfp_scaling_test_20250923_100342_0349_iter000_attention_L15_backward_post_attention_bf16_rank00_group000_grad_output_fp8_e4m3.log", + "max_align": -14.0, + "min_align": -42.0, + "recommended_scale_exp": -15.0, + "recommended_scale_factor": 3.1e-05, + "composite_score": 1.0281, + "is_at_max": false, + "mse": 1.852631e-14, + "cosine_similarity": 1.094028, + "psnr": 95.25, + "mae": 3.76381e-08, + "relative_error": 2.29, + "layer": 15, + "pass_type": "backward", + "operation_type": "attention", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_backward_linear_input", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100340_0344_iter000_linear_L15_backward_pre_linear_bf16_rank00_group000_input/mxfp_scaling_test_20250923_100340_0344_iter000_linear_L15_backward_pre_linear_bf16_rank00_group000_input_fp8_e4m3.log", + "max_align": -9.0, + "min_align": -15.0, + "recommended_scale_exp": -10.0, + "recommended_scale_factor": 0.000977, + "composite_score": 1.0002, + "is_at_max": false, + "mse": 1.417114e-06, + "cosine_similarity": 1.00061, + "psnr": 46.58, + "mae": 0.0008036151, + "relative_error": 2.25, + "layer": 15, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "input", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_backward_linear_input", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100340_0342_iter000_linear_L15_backward_pre_linear_bf16_rank00_group000_input/mxfp_scaling_test_20250923_100340_0342_iter000_linear_L15_backward_pre_linear_bf16_rank00_group000_input_fp8_e4m3.log", + "max_align": -12.0, + "min_align": -33.0, + "recommended_scale_exp": -13.0, + "recommended_scale_factor": 0.000122, + "composite_score": 1.0017, + "is_at_max": false, + "mse": 1.770593e-09, + "cosine_similarity": 1.005757, + "psnr": 57.68, + "mae": 2.044449e-05, + "relative_error": 2.25, + "layer": 15, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "input", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_forward_attention_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100237_0151_iter000_attention_L8_forward_post_FA_bf16_rank00_group000_output/mxfp_scaling_test_20250923_100237_0151_iter000_attention_L8_forward_post_FA_bf16_rank00_group000_output_fp8_e4m3.log", + "max_align": -7.0, + "min_align": -21.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0001, + "is_at_max": true, + "mse": 0.0001634078, + "cosine_similarity": 1.000342, + "psnr": 43.18, + "mae": 0.008672393, + "relative_error": 2.28, + "layer": 8, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_forward_linear_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100327_0303_iter000_linear_L16_forward_pre_linear_bf16_rank00_group000_output/mxfp_scaling_test_20250923_100327_0303_iter000_linear_L16_forward_pre_linear_bf16_rank00_group000_output_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -20.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0002, + "is_at_max": false, + "mse": 0.0002679498, + "cosine_similarity": 1.000758, + "psnr": 45.4, + "mae": 0.01106775, + "relative_error": 2.25, + "layer": 16, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_forward_linear_B", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100333_0313_iter000_linear_L16_forward_pre_linear_bf16_rank00_group000_input_B/mxfp_scaling_test_20250923_100333_0313_iter000_linear_L16_forward_pre_linear_bf16_rank00_group000_input_B_fp8_e4m3.log", + "max_align": -14.0, + "min_align": -26.0, + "recommended_scale_exp": -14.0, + "recommended_scale_factor": 6.1e-05, + "composite_score": 1.0, + "is_at_max": true, + "mse": 3.976778e-09, + "cosine_similarity": 0.99989, + "psnr": 46.37, + "mae": 4.256255e-05, + "relative_error": 2.25, + "layer": 16, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "B", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_forward_attention_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100234_0148_iter000_attention_L8_forward_pre_attention_bf16_rank00_group000_output/mxfp_scaling_test_20250923_100234_0148_iter000_attention_L8_forward_pre_attention_bf16_rank00_group000_output_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -22.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0667, + "is_at_max": false, + "mse": 9.549908e-05, + "cosine_similarity": 1.22217, + "psnr": 47.6, + "mae": 0.006601206, + "relative_error": 2.25, + "layer": 8, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_forward_linear_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100149_0020_iter000_linear_L1_forward_pre_linear_bf16_rank00_group000_output/mxfp_scaling_test_20250923_100149_0020_iter000_linear_L1_forward_pre_linear_bf16_rank00_group000_output_fp8_e4m3.log", + "max_align": -11.0, + "min_align": -24.0, + "recommended_scale_exp": -11.0, + "recommended_scale_factor": 0.000488, + "composite_score": 1.0002, + "is_at_max": true, + "mse": 3.055489e-07, + "cosine_similarity": 1.000677, + "psnr": 46.67, + "mae": 0.000373342, + "relative_error": 2.25, + "layer": 1, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_backward_attention_value", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100338_0334_iter000_attention_L16_backward_post_FA_bf16_rank00_group000_grad_value/mxfp_scaling_test_20250923_100338_0334_iter000_attention_L16_backward_post_FA_bf16_rank00_group000_grad_value_fp8_e4m3.log", + "max_align": -13.0, + "min_align": -30.0, + "recommended_scale_exp": -14.0, + "recommended_scale_factor": 6.1e-05, + "composite_score": 1.0003, + "is_at_max": false, + "mse": 1.67845e-10, + "cosine_similarity": 1.000859, + "psnr": 61.89, + "mae": 6.785997e-06, + "relative_error": 2.25, + "layer": 16, + "pass_type": "backward", + "operation_type": "attention", + "tensor_type": "value", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_forward_linear_B", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100238_0156_iter000_linear_L8_forward_pre_linear_bf16_rank00_group000_input_B/mxfp_scaling_test_20250923_100238_0156_iter000_linear_L8_forward_pre_linear_bf16_rank00_group000_input_B_fp8_e4m3.log", + "max_align": -11.0, + "min_align": -26.0, + "recommended_scale_exp": -12.0, + "recommended_scale_factor": 0.000244, + "composite_score": 1.0001, + "is_at_max": false, + "mse": 1.268228e-07, + "cosine_similarity": 1.000362, + "psnr": 46.03, + "mae": 0.0002405346, + "relative_error": 2.25, + "layer": 8, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "B", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_forward_linear_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100238_0160_iter000_linear_L8_forward_pre_linear_bf16_rank00_group000_output/mxfp_scaling_test_20250923_100238_0160_iter000_linear_L8_forward_pre_linear_bf16_rank00_group000_output_fp8_e4m3.log", + "max_align": -11.0, + "min_align": -26.0, + "recommended_scale_exp": -11.0, + "recommended_scale_factor": 0.000488, + "composite_score": 1.0003, + "is_at_max": true, + "mse": 3.797914e-07, + "cosine_similarity": 1.000852, + "psnr": 45.8, + "mae": 0.0004171906, + "relative_error": 2.25, + "layer": 8, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_backward_attention_value", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100407_0446_iter000_attention_L8_backward_post_FA_bf16_rank00_group000_grad_value/mxfp_scaling_test_20250923_100407_0446_iter000_attention_L8_backward_post_FA_bf16_rank00_group000_grad_value_fp8_e4m3.log", + "max_align": -10.0, + "min_align": -30.0, + "recommended_scale_exp": -11.0, + "recommended_scale_factor": 0.000488, + "composite_score": 1.0004, + "is_at_max": false, + "mse": 2.194934e-09, + "cosine_similarity": 1.001281, + "psnr": 70.68, + "mae": 1.493602e-05, + "relative_error": 2.26, + "layer": 8, + "pass_type": "backward", + "operation_type": "attention", + "tensor_type": "value", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_backward_attention_query", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100340_0336_iter000_attention_L16_backward_post_attention_bf16_rank00_group000_grad_query/mxfp_scaling_test_20250923_100340_0336_iter000_attention_L16_backward_post_attention_bf16_rank00_group000_grad_query_fp8_e4m3.log", + "max_align": -16.0, + "min_align": -31.0, + "recommended_scale_exp": -17.0, + "recommended_scale_factor": 8e-06, + "composite_score": 1.0004, + "is_at_max": false, + "mse": 1.621011e-11, + "cosine_similarity": 1.001402, + "psnr": 56.96, + "mae": 1.854218e-06, + "relative_error": 2.25, + "layer": 16, + "pass_type": "backward", + "operation_type": "attention", + "tensor_type": "query", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_forward_linear_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100326_0300_iter000_linear_L15_forward_pre_linear_bf16_rank00_group000_output/mxfp_scaling_test_20250923_100326_0300_iter000_linear_L15_forward_pre_linear_bf16_rank00_group000_output_fp8_e4m3.log", + "max_align": -11.0, + "min_align": -26.0, + "recommended_scale_exp": -11.0, + "recommended_scale_factor": 0.000488, + "composite_score": 1.0002, + "is_at_max": true, + "mse": 3.226245e-07, + "cosine_similarity": 1.000642, + "psnr": 46.36, + "mae": 0.0003839541, + "relative_error": 2.25, + "layer": 15, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_forward_linear_B", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100320_0282_iter000_linear_L15_forward_pre_linear_bf16_rank00_group000_input_B/mxfp_scaling_test_20250923_100320_0282_iter000_linear_L15_forward_pre_linear_bf16_rank00_group000_input_B_fp8_e4m3.log", + "max_align": -11.0, + "min_align": -22.0, + "recommended_scale_exp": -12.0, + "recommended_scale_factor": 0.000244, + "composite_score": 1.0, + "is_at_max": false, + "mse": 1.266206e-07, + "cosine_similarity": 0.999928, + "psnr": 45.79, + "mae": 0.0002403457, + "relative_error": 2.25, + "layer": 15, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "B", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_forward_linear_A", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100142_0001_iter000_linear_L1_forward_pre_linear_bf16_rank00_group000_input_A/mxfp_scaling_test_20250923_100142_0001_iter000_linear_L1_forward_pre_linear_bf16_rank00_group000_input_A_fp8_e4m3.log", + "max_align": -5.0, + "min_align": -16.0, + "recommended_scale_exp": -6.0, + "recommended_scale_factor": 0.015625, + "composite_score": 1.0004, + "is_at_max": false, + "mse": 0.000670323, + "cosine_similarity": 1.001259, + "psnr": 45.66, + "mae": 0.01748489, + "relative_error": 2.25, + "layer": 1, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "A", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_forward_attention_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100144_0007_iter000_attention_L1_forward_pre_attention_bf16_rank00_group000_mm_output/mxfp_scaling_test_20250923_100144_0007_iter000_attention_L1_forward_pre_attention_bf16_rank00_group000_mm_output_fp8_e4m3.log", + "max_align": -3.0, + "min_align": -18.0, + "recommended_scale_exp": -4.0, + "recommended_scale_factor": 0.0625, + "composite_score": 1.0679, + "is_at_max": false, + "mse": 0.01108313, + "cosine_similarity": 1.226379, + "psnr": 47.73, + "mae": 0.07087165, + "relative_error": 2.25, + "layer": 1, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_backward_attention_query", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100434_0546_iter000_attention_L1_backward_post_attention_bf16_rank00_group000_grad_query/mxfp_scaling_test_20250923_100434_0546_iter000_attention_L1_backward_post_attention_bf16_rank00_group000_grad_query_fp8_e4m3.log", + "max_align": -7.0, + "min_align": -30.0, + "recommended_scale_exp": -8.0, + "recommended_scale_factor": 0.003906, + "composite_score": 1.0006, + "is_at_max": false, + "mse": 5.314578e-08, + "cosine_similarity": 1.002031, + "psnr": 76.31, + "mae": 4.745786e-05, + "relative_error": 2.26, + "layer": 1, + "pass_type": "backward", + "operation_type": "attention", + "tensor_type": "query", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_forward_linear_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100326_0294_iter000_linear_L15_forward_pre_linear_bf16_rank00_group000_output/mxfp_scaling_test_20250923_100326_0294_iter000_linear_L15_forward_pre_linear_bf16_rank00_group000_output_fp8_e4m3.log", + "max_align": -10.0, + "min_align": -22.0, + "recommended_scale_exp": -11.0, + "recommended_scale_factor": 0.000488, + "composite_score": 1.0002, + "is_at_max": false, + "mse": 1.074224e-06, + "cosine_similarity": 1.000727, + "psnr": 43.29, + "mae": 0.0007004272, + "relative_error": 2.25, + "layer": 15, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_forward_linear_A", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100149_0018_iter000_linear_L1_forward_pre_linear_bf16_rank00_group000_input_A/mxfp_scaling_test_20250923_100149_0018_iter000_linear_L1_forward_pre_linear_bf16_rank00_group000_input_A_fp8_e4m3.log", + "max_align": -5.0, + "min_align": -24.0, + "recommended_scale_exp": -6.0, + "recommended_scale_factor": 0.015625, + "composite_score": 1.0007, + "is_at_max": false, + "mse": 2.671043e-05, + "cosine_similarity": 1.002226, + "psnr": 58.92, + "mae": 0.002535962, + "relative_error": 2.25, + "layer": 1, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "A", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_backward_attention_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100405_0444_iter000_attention_L8_backward_post_FA_bf16_rank00_group000_grad_output/mxfp_scaling_test_20250923_100405_0444_iter000_attention_L8_backward_post_FA_bf16_rank00_group000_grad_output_fp8_e4m3.log", + "max_align": -12.0, + "min_align": -26.0, + "recommended_scale_exp": -12.0, + "recommended_scale_factor": 0.000244, + "composite_score": 1.0001, + "is_at_max": true, + "mse": 2.69724e-08, + "cosine_similarity": 1.000409, + "psnr": 50.75, + "mae": 0.0001110152, + "relative_error": 2.25, + "layer": 8, + "pass_type": "backward", + "operation_type": "attention", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_forward_attention_key", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100144_0006_iter000_attention_L1_forward_pre_attention_bf16_rank00_group000_key/mxfp_scaling_test_20250923_100144_0006_iter000_attention_L1_forward_pre_attention_bf16_rank00_group000_key_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -18.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0002, + "is_at_max": false, + "mse": 0.0002491283, + "cosine_similarity": 1.000544, + "psnr": 45.02, + "mae": 0.01063821, + "relative_error": 2.26, + "layer": 1, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "key", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_forward_linear_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100333_0317_iter000_linear_L16_forward_pre_linear_bf16_rank00_group000_output/mxfp_scaling_test_20250923_100333_0317_iter000_linear_L16_forward_pre_linear_bf16_rank00_group000_output_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -21.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0008, + "is_at_max": false, + "mse": 0.0002601993, + "cosine_similarity": 1.002532, + "psnr": 45.25, + "mae": 0.01090546, + "relative_error": 2.25, + "layer": 16, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_backward_attention_key", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100344_0351_iter000_attention_L15_backward_post_attention_bf16_rank00_group000_grad_key/mxfp_scaling_test_20250923_100344_0351_iter000_attention_L15_backward_post_attention_bf16_rank00_group000_grad_key_fp8_e4m3.log", + "max_align": -15.0, + "min_align": -34.0, + "recommended_scale_exp": -16.0, + "recommended_scale_factor": 1.5e-05, + "composite_score": 1.0004, + "is_at_max": false, + "mse": 3.033322e-12, + "cosine_similarity": 1.001223, + "psnr": 67.28, + "mae": 6.217894e-07, + "relative_error": 2.25, + "layer": 15, + "pass_type": "backward", + "operation_type": "attention", + "tensor_type": "key", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_forward_attention_value", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100237_0150_iter000_attention_L8_forward_post_FA_bf16_rank00_group000_value/mxfp_scaling_test_20250923_100237_0150_iter000_attention_L8_forward_post_FA_bf16_rank00_group000_value_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -15.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0002, + "is_at_max": false, + "mse": 0.0002660364, + "cosine_similarity": 1.000514, + "psnr": 44.54, + "mae": 0.01100971, + "relative_error": 2.26, + "layer": 8, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "value", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_forward_linear_B", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100149_0016_iter000_linear_L1_forward_pre_linear_bf16_rank00_group000_input_B/mxfp_scaling_test_20250923_100149_0016_iter000_linear_L1_forward_pre_linear_bf16_rank00_group000_input_B_fp8_e4m3.log", + "max_align": -11.0, + "min_align": -22.0, + "recommended_scale_exp": -12.0, + "recommended_scale_factor": 0.000244, + "composite_score": 1.0001, + "is_at_max": false, + "mse": 1.26664e-07, + "cosine_similarity": 1.000357, + "psnr": 46.5, + "mae": 0.0002405087, + "relative_error": 2.25, + "layer": 1, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "B", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_forward_attention_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100330_0308_iter000_attention_L16_forward_pre_attention_bf16_rank00_group000_output/mxfp_scaling_test_20250923_100330_0308_iter000_attention_L16_forward_pre_attention_bf16_rank00_group000_output_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -23.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0775, + "is_at_max": false, + "mse": 0.0001199241, + "cosine_similarity": 1.258218, + "psnr": 47.59, + "mae": 0.007329357, + "relative_error": 2.25, + "layer": 16, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_forward_linear_A", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100333_0312_iter000_linear_L16_forward_pre_linear_bf16_rank00_group000_input_A/mxfp_scaling_test_20250923_100333_0312_iter000_linear_L16_forward_pre_linear_bf16_rank00_group000_input_A_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -18.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0002, + "is_at_max": false, + "mse": 0.0001926621, + "cosine_similarity": 1.000517, + "psnr": 43.17, + "mae": 0.009442221, + "relative_error": 2.25, + "layer": 16, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "A", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_backward_linear_weight", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100405_0439_iter000_linear_L8_backward_pre_linear_bf16_rank00_group000_weight/mxfp_scaling_test_20250923_100405_0439_iter000_linear_L8_backward_pre_linear_bf16_rank00_group000_weight_fp8_e4m3.log", + "max_align": -14.0, + "min_align": -25.0, + "recommended_scale_exp": -15.0, + "recommended_scale_factor": 3.1e-05, + "composite_score": 1.0, + "is_at_max": false, + "mse": 3.980728e-09, + "cosine_similarity": 1.000054, + "psnr": 45.78, + "mae": 4.259601e-05, + "relative_error": 2.25, + "layer": 8, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "weight", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_backward_linear_input", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100430_0538_iter000_linear_L1_backward_pre_linear_bf16_rank00_group000_input/mxfp_scaling_test_20250923_100430_0538_iter000_linear_L1_backward_pre_linear_bf16_rank00_group000_input_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -33.0, + "recommended_scale_exp": -6.0, + "recommended_scale_factor": 0.015625, + "composite_score": 1.005, + "is_at_max": true, + "mse": 2.026327e-07, + "cosine_similarity": 1.016778, + "psnr": 78.91, + "mae": 0.000125176, + "relative_error": 2.26, + "layer": 1, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "input", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_forward_attention_query", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100233_0145_iter000_attention_L8_forward_pre_attention_bf16_rank00_group000_query/mxfp_scaling_test_20250923_100233_0145_iter000_attention_L8_forward_pre_attention_bf16_rank00_group000_query_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -17.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0002, + "is_at_max": false, + "mse": 0.0002602611, + "cosine_similarity": 1.000569, + "psnr": 45.16, + "mae": 0.01090646, + "relative_error": 2.25, + "layer": 8, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "query", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_forward_attention_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100145_0008_iter000_attention_L1_forward_pre_attention_bf16_rank00_group000_output/mxfp_scaling_test_20250923_100145_0008_iter000_attention_L1_forward_pre_attention_bf16_rank00_group000_output_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -22.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0682, + "is_at_max": false, + "mse": 8.624354e-05, + "cosine_similarity": 1.227334, + "psnr": 47.75, + "mae": 0.006256695, + "relative_error": 2.25, + "layer": 1, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_backward_attention_probs", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100337_0333_iter000_attention_L16_backward_post_FA_bf16_rank00_group000_grad_attention_probs/mxfp_scaling_test_20250923_100337_0333_iter000_attention_L16_backward_post_FA_bf16_rank00_group000_grad_attention_probs_fp8_e4m3.log", + "max_align": -10.0, + "min_align": -25.0, + "recommended_scale_exp": -11.0, + "recommended_scale_factor": 0.000488, + "composite_score": 1.0601, + "is_at_max": false, + "mse": 7.459086e-07, + "cosine_similarity": 1.200445, + "psnr": 46.32, + "mae": 0.00057851, + "relative_error": 2.25, + "layer": 16, + "pass_type": "backward", + "operation_type": "attention", + "tensor_type": "probs", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_forward_linear_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100148_0014_iter000_linear_L1_forward_pre_linear_bf16_rank00_group000_output/mxfp_scaling_test_20250923_100148_0014_iter000_linear_L1_forward_pre_linear_bf16_rank00_group000_output_fp8_e4m3.log", + "max_align": -10.0, + "min_align": -25.0, + "recommended_scale_exp": -11.0, + "recommended_scale_factor": 0.000488, + "composite_score": 1.0004, + "is_at_max": false, + "mse": 8.713858e-09, + "cosine_similarity": 1.001378, + "psnr": 65.88, + "mae": 6.087456e-05, + "relative_error": 2.25, + "layer": 1, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_forward_attention_value", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100326_0290_iter000_attention_L15_forward_post_FA_bf16_rank00_group000_value/mxfp_scaling_test_20250923_100326_0290_iter000_attention_L15_forward_post_FA_bf16_rank00_group000_value_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -16.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0002, + "is_at_max": false, + "mse": 0.0002577995, + "cosine_similarity": 1.000508, + "psnr": 45.7, + "mae": 0.01069805, + "relative_error": 2.25, + "layer": 15, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "value", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_forward_attention_buffer", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100231_0144_iter000_attention_L8_forward_pre_attention_bf16_rank00_group000_matmul_input_buffer/mxfp_scaling_test_20250923_100231_0144_iter000_attention_L8_forward_pre_attention_bf16_rank00_group000_matmul_input_buffer_fp8_e4m3.log", + "max_align": -5.0, + "min_align": -18.0, + "recommended_scale_exp": -6.0, + "recommended_scale_factor": 0.015625, + "composite_score": 1.0003, + "is_at_max": false, + "mse": 2.224732e-05, + "cosine_similarity": 1.001163, + "psnr": 59.95, + "mae": 0.0005624519, + "relative_error": 2.25, + "layer": 8, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "buffer", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_forward_linear_A", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100326_0301_iter000_linear_L16_forward_pre_linear_bf16_rank00_group000_input_A/mxfp_scaling_test_20250923_100326_0301_iter000_linear_L16_forward_pre_linear_bf16_rank00_group000_input_A_fp8_e4m3.log", + "max_align": -5.0, + "min_align": -18.0, + "recommended_scale_exp": -6.0, + "recommended_scale_factor": 0.015625, + "composite_score": 1.0004, + "is_at_max": false, + "mse": 0.0007136815, + "cosine_similarity": 1.001201, + "psnr": 44.41, + "mae": 0.01792537, + "relative_error": 2.25, + "layer": 16, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "A", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_forward_linear_A", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100238_0155_iter000_linear_L8_forward_pre_linear_bf16_rank00_group000_input_A/mxfp_scaling_test_20250923_100238_0155_iter000_linear_L8_forward_pre_linear_bf16_rank00_group000_input_A_fp8_e4m3.log", + "max_align": -5.0, + "min_align": -19.0, + "recommended_scale_exp": -6.0, + "recommended_scale_factor": 0.015625, + "composite_score": 1.0004, + "is_at_max": false, + "mse": 0.0007105522, + "cosine_similarity": 1.001197, + "psnr": 44.96, + "mae": 0.01805477, + "relative_error": 2.25, + "layer": 8, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "A", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_forward_attention_key", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100328_0306_iter000_attention_L16_forward_pre_attention_bf16_rank00_group000_key/mxfp_scaling_test_20250923_100328_0306_iter000_attention_L16_forward_pre_attention_bf16_rank00_group000_key_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -17.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0002, + "is_at_max": false, + "mse": 0.000291721, + "cosine_similarity": 1.000731, + "psnr": 43.89, + "mae": 0.01151947, + "relative_error": 2.25, + "layer": 16, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "key", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_backward_linear_weight", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100430_0537_iter000_linear_L1_backward_pre_linear_bf16_rank00_group000_weight/mxfp_scaling_test_20250923_100430_0537_iter000_linear_L1_backward_pre_linear_bf16_rank00_group000_weight_fp8_e4m3.log", + "max_align": -14.0, + "min_align": -24.0, + "recommended_scale_exp": -15.0, + "recommended_scale_factor": 3.1e-05, + "composite_score": 1.0, + "is_at_max": false, + "mse": 3.971756e-09, + "cosine_similarity": 1.000057, + "psnr": 45.48, + "mae": 4.254775e-05, + "relative_error": 2.25, + "layer": 1, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "weight", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_forward_linear_A", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100320_0281_iter000_linear_L15_forward_pre_linear_bf16_rank00_group000_input_A/mxfp_scaling_test_20250923_100320_0281_iter000_linear_L15_forward_pre_linear_bf16_rank00_group000_input_A_fp8_e4m3.log", + "max_align": -5.0, + "min_align": -19.0, + "recommended_scale_exp": -6.0, + "recommended_scale_factor": 0.015625, + "composite_score": 1.0004, + "is_at_max": false, + "mse": 0.0007139186, + "cosine_similarity": 1.001212, + "psnr": 44.82, + "mae": 0.01791827, + "relative_error": 2.25, + "layer": 15, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "A", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_backward_attention_key", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100434_0547_iter000_attention_L1_backward_post_attention_bf16_rank00_group000_grad_key/mxfp_scaling_test_20250923_100434_0547_iter000_attention_L1_backward_post_attention_bf16_rank00_group000_grad_key_fp8_e4m3.log", + "max_align": -7.0, + "min_align": -30.0, + "recommended_scale_exp": -8.0, + "recommended_scale_factor": 0.003906, + "composite_score": 1.0004, + "is_at_max": false, + "mse": 5.886837e-08, + "cosine_similarity": 1.001333, + "psnr": 74.29, + "mae": 3.130567e-05, + "relative_error": 2.26, + "layer": 1, + "pass_type": "backward", + "operation_type": "attention", + "tensor_type": "key", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_forward_linear_B", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100237_0153_iter000_linear_L8_forward_pre_linear_bf16_rank00_group000_input_B/mxfp_scaling_test_20250923_100237_0153_iter000_linear_L8_forward_pre_linear_bf16_rank00_group000_input_B_fp8_e4m3.log", + "max_align": -14.0, + "min_align": -25.0, + "recommended_scale_exp": -15.0, + "recommended_scale_factor": 3.1e-05, + "composite_score": 1.0, + "is_at_max": false, + "mse": 3.972378e-09, + "cosine_similarity": 0.999888, + "psnr": 45.83, + "mae": 4.255031e-05, + "relative_error": 2.25, + "layer": 8, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "B", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_backward_linear_weight", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100340_0339_iter000_linear_L16_backward_pre_linear_bf16_rank00_group000_weight/mxfp_scaling_test_20250923_100340_0339_iter000_linear_L16_backward_pre_linear_bf16_rank00_group000_weight_fp8_e4m3.log", + "max_align": -11.0, + "min_align": -23.0, + "recommended_scale_exp": -12.0, + "recommended_scale_factor": 0.000244, + "composite_score": 1.0, + "is_at_max": false, + "mse": 1.266533e-07, + "cosine_similarity": 0.999924, + "psnr": 46.03, + "mae": 0.0002404344, + "relative_error": 2.25, + "layer": 16, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "weight", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_backward_linear_weight", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100405_0441_iter000_linear_L8_backward_pre_linear_bf16_rank00_group000_weight/mxfp_scaling_test_20250923_100405_0441_iter000_linear_L8_backward_pre_linear_bf16_rank00_group000_weight_fp8_e4m3.log", + "max_align": -11.0, + "min_align": -26.0, + "recommended_scale_exp": -12.0, + "recommended_scale_factor": 0.000244, + "composite_score": 1.0001, + "is_at_max": false, + "mse": 1.268228e-07, + "cosine_similarity": 1.000367, + "psnr": 46.03, + "mae": 0.0002405346, + "relative_error": 2.25, + "layer": 8, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "weight", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_backward_linear_input", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100344_0352_iter000_linear_L15_backward_pre_linear_bf16_rank00_group000_input/mxfp_scaling_test_20250923_100344_0352_iter000_linear_L15_backward_pre_linear_bf16_rank00_group000_input_fp8_e4m3.log", + "max_align": -12.0, + "min_align": -35.0, + "recommended_scale_exp": -13.0, + "recommended_scale_factor": 0.000122, + "composite_score": 1.0011, + "is_at_max": false, + "mse": 1.610013e-10, + "cosine_similarity": 1.003673, + "psnr": 68.54, + "mae": 3.992613e-06, + "relative_error": 2.25, + "layer": 15, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "input", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L16_backward_linear_input", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100337_0330_iter000_linear_L16_backward_pre_linear_bf16_rank00_group000_input/mxfp_scaling_test_20250923_100337_0330_iter000_linear_L16_backward_pre_linear_bf16_rank00_group000_input_fp8_e4m3.log", + "max_align": -9.0, + "min_align": -16.0, + "recommended_scale_exp": -9.0, + "recommended_scale_factor": 0.001953, + "composite_score": 1.0002, + "is_at_max": true, + "mse": 1.353944e-06, + "cosine_similarity": 1.000537, + "psnr": 46.84, + "mae": 0.0007857367, + "relative_error": 2.25, + "layer": 16, + "pass_type": "backward", + "operation_type": "linear", + "tensor_type": "input", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L8_backward_attention_output", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100407_0447_iter000_attention_L8_backward_post_attention_bf16_rank00_group000_grad_output/mxfp_scaling_test_20250923_100407_0447_iter000_attention_L8_backward_post_attention_bf16_rank00_group000_grad_output_fp8_e4m3.log", + "max_align": -13.0, + "min_align": -39.0, + "recommended_scale_exp": -14.0, + "recommended_scale_factor": 6.1e-05, + "composite_score": 1.0232, + "is_at_max": false, + "mse": 6.131819e-14, + "cosine_similarity": 1.077592, + "psnr": 101.17, + "mae": 5.819336e-08, + "relative_error": 2.32, + "layer": 8, + "pass_type": "backward", + "operation_type": "attention", + "tensor_type": "output", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_forward_attention_value", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100148_0010_iter000_attention_L1_forward_post_FA_bf16_rank00_group000_value/mxfp_scaling_test_20250923_100148_0010_iter000_attention_L1_forward_post_FA_bf16_rank00_group000_value_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -13.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0001, + "is_at_max": false, + "mse": 0.0002459944, + "cosine_similarity": 1.00047, + "psnr": 44.68, + "mae": 0.01056745, + "relative_error": 2.25, + "layer": 1, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "value", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L15_forward_linear_A", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100326_0292_iter000_linear_L15_forward_pre_linear_bf16_rank00_group000_input_A/mxfp_scaling_test_20250923_100326_0292_iter000_linear_L15_forward_pre_linear_bf16_rank00_group000_input_A_fp8_e4m3.log", + "max_align": -6.0, + "min_align": -19.0, + "recommended_scale_exp": -7.0, + "recommended_scale_factor": 0.007812, + "composite_score": 1.0002, + "is_at_max": false, + "mse": 0.0001800647, + "cosine_similarity": 1.000563, + "psnr": 43.99, + "mae": 0.008812865, + "relative_error": 2.22, + "layer": 15, + "pass_type": "forward", + "operation_type": "linear", + "tensor_type": "A", + "rank": 0, + "group": 0 + }, + { + "tensor_name": "L1_forward_attention_probs", + "file_path": "/Users/charles/Downloads/draw/scaling_analysis/20250923_100147_0009_iter000_attention_L1_forward_post_FA_bf16_rank00_group000_attention_probs/mxfp_scaling_test_20250923_100147_0009_iter000_attention_L1_forward_post_FA_bf16_rank00_group000_attention_probs_fp8_e4m3.log", + "max_align": -7.0, + "min_align": -11.0, + "recommended_scale_exp": -8.0, + "recommended_scale_factor": 0.003906, + "composite_score": 1.0315, + "is_at_max": false, + "mse": 1.00665e-10, + "cosine_similarity": 1.110716, + "psnr": 99.97, + "mae": 2.752513e-06, + "relative_error": 2.25, + "layer": 1, + "pass_type": "forward", + "operation_type": "attention", + "tensor_type": "probs", + "rank": 0, + "group": 0 + } + ], + "metadata": { + "total_files_analyzed": 136, + "analysis_date": "2025-09-23", + "base_directory": "/Users/charles/Downloads/draw", + "enhanced_analysis": true + } +} \ No newline at end of file diff --git a/visualization/overflow/enhanced_scaling_analyzer.py b/visualization/overflow/enhanced_scaling_analyzer.py new file mode 100644 index 0000000000..e2375809b6 --- /dev/null +++ b/visualization/overflow/enhanced_scaling_analyzer.py @@ -0,0 +1,409 @@ +#!/usr/bin/env python3 +""" +Enhanced Scaling Factor Analysis Tool + +This program analyzes all log files in the scaling_analysis directory with enhanced +tensor type classification for forward/backward passes and detailed tensor naming. + +Author: AI Assistant +Created: 2025-09-23 +""" + +import os +import re +import glob +import json +from pathlib import Path +from typing import Dict, List, Tuple, Optional +from dataclasses import dataclass +from collections import defaultdict + + +@dataclass +class EnhancedTensorAnalysis: + """Enhanced data class to store tensor analysis results""" + tensor_name: str + file_path: str + max_align: float + min_align: float + recommended_scale_exp: float + recommended_scale_factor: float + composite_score: float + is_at_max: bool + mse: float + cosine_similarity: float + psnr: float + mae: float + relative_error: float + + # Enhanced classification + layer: int + pass_type: str # forward, backward + operation_type: str # linear, attention + tensor_type: str # input_A, input_B, output, weight, query, key, value, etc. + rank: int + group: int + + +class EnhancedScalingAnalyzer: + """Enhanced analyzer class for scaling factor analysis""" + + def __init__(self, base_directory: str = "/Users/charles/Downloads/draw"): + self.base_directory = Path(base_directory) + self.scaling_dir = self.base_directory / "scaling_analysis" + self.results: List[EnhancedTensorAnalysis] = [] + + def find_log_files(self) -> List[Path]: + """Find all log files in the scaling analysis directory""" + pattern = str(self.scaling_dir / "**" / "*.log") + log_files = glob.glob(pattern, recursive=True) + return [Path(f) for f in log_files] + + def parse_log_file(self, log_file: Path) -> Optional[EnhancedTensorAnalysis]: + """Parse a single log file and extract scaling information""" + try: + with open(log_file, 'r', encoding='utf-8') as f: + content = f.read() + + # Extract enhanced tensor information from file path + tensor_info = self._extract_enhanced_tensor_info(log_file) + if not tensor_info: + return None + + # Extract alignment information + alignment_match = re.search(r'Calculated alignment \(reference\): max_align=(-?\d+\.?\d*), min_align=(-?\d+\.?\d*)', content) + if not alignment_match: + print(f"Warning: Could not find alignment info in {log_file}") + return None + + max_align = float(alignment_match.group(1)) + min_align = float(alignment_match.group(2)) + + # Extract recommended scaling information + scale_exp_match = re.search(r'⭐ RECOMMENDED Scaling Factor: ([\d\.e\-\+]+)\s+.*?Scale Exponent: (-?\d+\.?\d*)', content, re.DOTALL) + if not scale_exp_match: + print(f"Warning: Could not find recommended scaling info in {log_file}") + return None + + recommended_scale_factor = float(scale_exp_match.group(1)) + recommended_scale_exp = float(scale_exp_match.group(2)) + + # Extract composite score + composite_match = re.search(r'Composite Score: ([\d\.e\-\+]+)', content) + composite_score = float(composite_match.group(1)) if composite_match else 0.0 + + # Extract performance metrics + mse_match = re.search(r'- MSE: ([\d\.e\-\+]+)', content) + cosine_match = re.search(r'- Cosine Similarity: ([\d\.e\-\+]+)', content) + psnr_match = re.search(r'- PSNR: ([\d\.e\-\+]+) dB', content) + mae_match = re.search(r'- MAE: ([\d\.e\-\+]+)', content) + rel_error_match = re.search(r'- Relative Error: ([\d\.]+)%', content) + + mse = float(mse_match.group(1)) if mse_match else 0.0 + cosine_similarity = float(cosine_match.group(1)) if cosine_match else 0.0 + psnr = float(psnr_match.group(1)) if psnr_match else 0.0 + mae = float(mae_match.group(1)) if mae_match else 0.0 + relative_error = float(rel_error_match.group(1)) if rel_error_match else 0.0 + + # Check if recommended scaling is at maximum value + is_at_max = abs(recommended_scale_exp - max_align) < 1e-6 + + return EnhancedTensorAnalysis( + tensor_name=tensor_info['tensor_name'], + file_path=str(log_file), + max_align=max_align, + min_align=min_align, + recommended_scale_exp=recommended_scale_exp, + recommended_scale_factor=recommended_scale_factor, + composite_score=composite_score, + is_at_max=is_at_max, + mse=mse, + cosine_similarity=cosine_similarity, + psnr=psnr, + mae=mae, + relative_error=relative_error, + layer=tensor_info['layer'], + pass_type=tensor_info['pass_type'], + operation_type=tensor_info['operation_type'], + tensor_type=tensor_info['tensor_type'], + rank=tensor_info['rank'], + group=tensor_info['group'] + ) + + except Exception as e: + print(f"Error parsing {log_file}: {e}") + return None + + def _extract_enhanced_tensor_info(self, log_file: Path) -> Optional[Dict]: + """Extract enhanced tensor information from file path""" + try: + # Extract from the directory name + parent_dir = log_file.parent.name + + # Parse the enhanced naming format: + # 20250923_100142_0001_iter000_linear_L1_forward_pre_linear_bf16_rank00_group000_input_A + + # Extract layer + layer_match = re.search(r'_L(\d+)_', parent_dir) + if not layer_match: + return None + layer = int(layer_match.group(1)) + + # Extract pass type (forward/backward) + if '_forward_' in parent_dir: + pass_type = 'forward' + elif '_backward_' in parent_dir: + pass_type = 'backward' + else: + pass_type = 'unknown' + + # Extract operation type + if '_linear_' in parent_dir: + operation_type = 'linear' + elif '_attention_' in parent_dir: + operation_type = 'attention' + else: + operation_type = 'unknown' + + # Extract rank + rank_match = re.search(r'_rank(\d+)_', parent_dir) + rank = int(rank_match.group(1)) if rank_match else 0 + + # Extract group + group_match = re.search(r'_group(\d+)_', parent_dir) + group = int(group_match.group(1)) if group_match else 0 + + # Extract tensor type (the last part after the last underscore) + parts = parent_dir.split('_') + tensor_type = parts[-1] # input_A, input_B, output, weight, query, key, value, etc. + + # Create a more readable tensor name + tensor_name = f"L{layer}_{pass_type}_{operation_type}_{tensor_type}" + + return { + 'tensor_name': tensor_name, + 'layer': layer, + 'pass_type': pass_type, + 'operation_type': operation_type, + 'tensor_type': tensor_type, + 'rank': rank, + 'group': group + } + + except Exception as e: + print(f"Error extracting tensor info from {log_file}: {e}") + return None + + def analyze_all_files(self) -> None: + """Analyze all log files and store results""" + log_files = self.find_log_files() + print(f"Found {len(log_files)} log files to analyze...") + + successful_parses = 0 + for log_file in log_files: + result = self.parse_log_file(log_file) + if result: + self.results.append(result) + successful_parses += 1 + + print(f"Successfully parsed {successful_parses} out of {len(log_files)} log files") + + def generate_enhanced_summary(self) -> Dict: + """Generate enhanced summary statistics""" + if not self.results: + return {} + + total_tensors = len(self.results) + at_max_count = sum(1 for r in self.results if r.is_at_max) + not_at_max_count = total_tensors - at_max_count + + # Group by various dimensions + layer_stats = defaultdict(lambda: {'total': 0, 'at_max': 0}) + pass_type_stats = defaultdict(lambda: {'total': 0, 'at_max': 0}) + operation_type_stats = defaultdict(lambda: {'total': 0, 'at_max': 0}) + tensor_type_stats = defaultdict(lambda: {'total': 0, 'at_max': 0}) + + # Combined statistics + layer_pass_stats = defaultdict(lambda: {'total': 0, 'at_max': 0}) + layer_operation_stats = defaultdict(lambda: {'total': 0, 'at_max': 0}) + + for result in self.results: + # Layer statistics + layer_key = f"Layer_{result.layer}" + layer_stats[layer_key]['total'] += 1 + if result.is_at_max: + layer_stats[layer_key]['at_max'] += 1 + + # Pass type statistics + pass_type_stats[result.pass_type]['total'] += 1 + if result.is_at_max: + pass_type_stats[result.pass_type]['at_max'] += 1 + + # Operation type statistics + operation_type_stats[result.operation_type]['total'] += 1 + if result.is_at_max: + operation_type_stats[result.operation_type]['at_max'] += 1 + + # Tensor type statistics + tensor_type_stats[result.tensor_type]['total'] += 1 + if result.is_at_max: + tensor_type_stats[result.tensor_type]['at_max'] += 1 + + # Combined statistics + layer_pass_key = f"L{result.layer}_{result.pass_type}" + layer_pass_stats[layer_pass_key]['total'] += 1 + if result.is_at_max: + layer_pass_stats[layer_pass_key]['at_max'] += 1 + + layer_operation_key = f"L{result.layer}_{result.operation_type}" + layer_operation_stats[layer_operation_key]['total'] += 1 + if result.is_at_max: + layer_operation_stats[layer_operation_key]['at_max'] += 1 + + return { + 'total_tensors': total_tensors, + 'at_max_count': at_max_count, + 'not_at_max_count': not_at_max_count, + 'at_max_percentage': (at_max_count / total_tensors) * 100 if total_tensors > 0 else 0, + 'layer_stats': dict(layer_stats), + 'pass_type_stats': dict(pass_type_stats), + 'operation_type_stats': dict(operation_type_stats), + 'tensor_type_stats': dict(tensor_type_stats), + 'layer_pass_stats': dict(layer_pass_stats), + 'layer_operation_stats': dict(layer_operation_stats) + } + + def print_enhanced_report(self) -> None: + """Print enhanced analysis report""" + if not self.results: + print("No results to report!") + return + + summary = self.generate_enhanced_summary() + + print("\n" + "="*100) + print("ENHANCED SCALING FACTOR ANALYSIS REPORT") + print("="*100) + + print(f"\n📊 OVERALL SUMMARY:") + print(f" Total Tensors Analyzed: {summary['total_tensors']}") + print(f" Tensors at Maximum Scaling: {summary['at_max_count']} ({summary['at_max_percentage']:.1f}%)") + print(f" Tensors NOT at Maximum: {summary['not_at_max_count']} ({100-summary['at_max_percentage']:.1f}%)") + + # Pass type breakdown + print(f"\n🔄 BREAKDOWN BY PASS TYPE:") + for pass_type, stats in summary['pass_type_stats'].items(): + percentage = (stats['at_max'] / stats['total']) * 100 if stats['total'] > 0 else 0 + print(f" {pass_type.upper():10} | Total: {stats['total']:3d} | At Max: {stats['at_max']:3d} ({percentage:5.1f}%)") + + # Operation type breakdown + print(f"\n⚙️ BREAKDOWN BY OPERATION TYPE:") + for op_type, stats in summary['operation_type_stats'].items(): + percentage = (stats['at_max'] / stats['total']) * 100 if stats['total'] > 0 else 0 + print(f" {op_type.upper():10} | Total: {stats['total']:3d} | At Max: {stats['at_max']:3d} ({percentage:5.1f}%)") + + # Tensor type breakdown + print(f"\n📋 BREAKDOWN BY TENSOR TYPE:") + for tensor_type, stats in summary['tensor_type_stats'].items(): + percentage = (stats['at_max'] / stats['total']) * 100 if stats['total'] > 0 else 0 + print(f" {tensor_type.upper():15} | Total: {stats['total']:3d} | At Max: {stats['at_max']:3d} ({percentage:5.1f}%)") + + # Layer breakdown + if summary['layer_stats']: + print(f"\n🏗️ BREAKDOWN BY LAYER:") + for layer, stats in sorted(summary['layer_stats'].items()): + percentage = (stats['at_max'] / stats['total']) * 100 if stats['total'] > 0 else 0 + print(f" {layer:10} | Total: {stats['total']:3d} | At Max: {stats['at_max']:3d} ({percentage:5.1f}%)") + + # Layer-Pass combination + print(f"\n🔄🏗️ BREAKDOWN BY LAYER-PASS COMBINATION:") + for layer_pass, stats in sorted(summary['layer_pass_stats'].items()): + percentage = (stats['at_max'] / stats['total']) * 100 if stats['total'] > 0 else 0 + print(f" {layer_pass:15} | Total: {stats['total']:3d} | At Max: {stats['at_max']:3d} ({percentage:5.1f}%)") + + # Layer-Operation combination + print(f"\n⚙️🏗️ BREAKDOWN BY LAYER-OPERATION COMBINATION:") + for layer_op, stats in sorted(summary['layer_operation_stats'].items()): + percentage = (stats['at_max'] / stats['total']) * 100 if stats['total'] > 0 else 0 + print(f" {layer_op:15} | Total: {stats['total']:3d} | At Max: {stats['at_max']:3d} ({percentage:5.1f}%)") + + print("\n" + "="*100) + + def save_enhanced_results_to_json(self, output_file: str = "enhanced_scaling_analysis_results.json") -> None: + """Save enhanced results to JSON file""" + output_path = self.base_directory / output_file + + # Convert dataclass objects to dictionaries + results_dict = [] + for result in self.results: + results_dict.append({ + 'tensor_name': result.tensor_name, + 'file_path': result.file_path, + 'max_align': result.max_align, + 'min_align': result.min_align, + 'recommended_scale_exp': result.recommended_scale_exp, + 'recommended_scale_factor': result.recommended_scale_factor, + 'composite_score': result.composite_score, + 'is_at_max': result.is_at_max, + 'mse': result.mse, + 'cosine_similarity': result.cosine_similarity, + 'psnr': result.psnr, + 'mae': result.mae, + 'relative_error': result.relative_error, + 'layer': result.layer, + 'pass_type': result.pass_type, + 'operation_type': result.operation_type, + 'tensor_type': result.tensor_type, + 'rank': result.rank, + 'group': result.group + }) + + summary = self.generate_enhanced_summary() + + output_data = { + 'analysis_summary': summary, + 'detailed_results': results_dict, + 'metadata': { + 'total_files_analyzed': len(self.results), + 'analysis_date': '2025-09-23', + 'base_directory': str(self.base_directory), + 'enhanced_analysis': True + } + } + + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(output_data, f, indent=2, ensure_ascii=False) + + print(f"\n💾 Enhanced results saved to: {output_path}") + + +def main(): + """Main function""" + print("🔍 Starting Enhanced Scaling Factor Analysis...") + + # Initialize analyzer + analyzer = EnhancedScalingAnalyzer() + + # Check if scaling directory exists + if not analyzer.scaling_dir.exists(): + print(f"❌ Error: Scaling analysis directory not found: {analyzer.scaling_dir}") + return + + # Analyze all files + analyzer.analyze_all_files() + + if not analyzer.results: + print("❌ No valid results found!") + return + + # Print enhanced report + analyzer.print_enhanced_report() + + # Save results to JSON + analyzer.save_enhanced_results_to_json() + + print("\n✅ Enhanced analysis completed successfully!") + + +if __name__ == "__main__": + main() diff --git a/visualization/overflow/mxfp_scaling_test.py b/visualization/overflow/mxfp_scaling_test.py new file mode 100644 index 0000000000..f016322a56 --- /dev/null +++ b/visualization/overflow/mxfp_scaling_test.py @@ -0,0 +1,1075 @@ +#!/usr/bin/env python3 +""" +MXFP Scaling Test Tool +Tests different scaling strategies for MXFP quantization and evaluates their impact on accuracy. +""" + +import torch +import numpy as np +import matplotlib.pyplot as plt +import argparse +from pathlib import Path +import sys +import os +import logging +from datetime import datetime + +# Add the parent directory to path to import mxfp module +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from quant.mxfp import _quantize_mx, _get_format_params, ElemFormat + +def setup_logging(output_dir, tensor_name, elem_format): + """ + Setup logging to both console and file. + + Args: + output_dir (Path): Output directory for log file + tensor_name (str): Name of the input tensor + elem_format (str): Element format being tested + + Returns: + logging.Logger: Configured logger + """ + # Create logger + logger = logging.getLogger('mxfp_scaling_test') + logger.setLevel(logging.INFO) + + # Clear any existing handlers + logger.handlers.clear() + + # Create formatter + formatter = logging.Formatter( + '%(asctime)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + + # Console handler + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(logging.INFO) + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + + # File handler + log_filename = f"mxfp_scaling_test_{tensor_name}_{elem_format}.log" + log_path = output_dir / log_filename + + file_handler = logging.FileHandler(log_path, mode='w', encoding='utf-8') + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + # Log initial information + logger.info("=" * 80) + logger.info("MXFP SCALING TEST LOG") + logger.info("=" * 80) + logger.info(f"Test started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + logger.info(f"Input tensor: {tensor_name}") + logger.info(f"Element format: {elem_format}") + logger.info(f"Output directory: {output_dir}") + logger.info("=" * 80) + + return logger + +def calculate_metrics(original_tensor, quantized_tensor): + """ + Calculate various metrics between original and quantized tensors. + + Args: + original_tensor (torch.Tensor): Original BF16 tensor + quantized_tensor (torch.Tensor): Quantized tensor + + Returns: + dict: Dictionary containing all calculated metrics + """ + # Convert to float32 for calculation + orig_f32 = original_tensor.float() + quant_f32 = quantized_tensor.float() + + # MSE (Mean Squared Error) + mse = torch.mean((orig_f32 - quant_f32) ** 2).item() + + # RMSE (Root Mean Squared Error) + rmse = torch.sqrt(torch.mean((orig_f32 - quant_f32) ** 2)).item() + + # Cosine Similarity + orig_flat = orig_f32.flatten() + quant_flat = quant_f32.flatten() + + # Avoid division by zero + orig_norm = torch.norm(orig_flat) + quant_norm = torch.norm(quant_flat) + + if orig_norm > 0 and quant_norm > 0: + cosine_sim = torch.dot(orig_flat, quant_flat) / (orig_norm * quant_norm) + cosine_sim = cosine_sim.item() + else: + cosine_sim = 1.0 if orig_norm == 0 and quant_norm == 0 else 0.0 + + # PSNR (Peak Signal-to-Noise Ratio) + if mse > 0: + # Use the maximum value in original tensor as peak signal + max_val = torch.max(torch.abs(orig_f32)).item() + psnr = 20 * np.log10(max_val / np.sqrt(mse)) if max_val > 0 else float('inf') + else: + psnr = float('inf') + + # MAE (Mean Absolute Error) + mae = torch.mean(torch.abs(orig_f32 - quant_f32)).item() + + # Maximum Absolute Error + max_abs_error = torch.max(torch.abs(orig_f32 - quant_f32)).item() + + # Relative Error (percentage) + orig_mean_abs = torch.mean(torch.abs(orig_f32)).item() + relative_error = (mae / orig_mean_abs * 100) if orig_mean_abs > 0 else 0.0 + + return { + 'mse': mse, + 'rmse': rmse, + 'cosine_similarity': cosine_sim, + 'psnr': psnr, + 'mae': mae, + 'max_abs_error': max_abs_error, + 'relative_error': relative_error + } + +def test_scaling_levels(input_tensor, elem_format='fp8_e4m3', scale_bits=8, + max_scale_exp=10, min_scale_exp=-10, logger=None): + """ + Test different scaling levels for MXFP quantization. + + Args: + input_tensor (torch.Tensor): Input BF16 tensor + elem_format (str): Element format for quantization + scale_bits (int): Number of scale bits + max_scale_exp (int): Maximum scale exponent (aligned with max value) + min_scale_exp (int): Minimum scale exponent (aligned with min value) + logger: Logger instance for output + + Returns: + dict: Results for each scaling level (all integers in range) + """ + # Get format parameters + ebits, mbits, emax, max_norm, min_norm = _get_format_params(elem_format) + + # Calculate tensor statistics for alignment + tensor_abs_max = torch.max(torch.abs(input_tensor)).item() + tensor_abs_min = torch.min(torch.abs(input_tensor[input_tensor != 0])).item() if torch.any(input_tensor != 0) else tensor_abs_max + + # Calculate emax for the format (following mxfp.py logic) + emax = 2**(ebits - 1) - 1 if ebits > 0 else 0 + + # Calculate scale exponents following mxfp.py _quantize_mx logic: + # In mxfp.py: + # 1. shared_exp = floor(log2(max_abs_value)) (from _shared_exponents with method="max") + # 2. shared_exp = shared_exp - emax (offset by emax) + # 3. A = A / (2^shared_exp) (apply scaling) + # + # So the actual scaling factor used by mxfp.py is: 2^(floor(log2(max)) - emax) + # + # For alignment calculations: + # - Max alignment: Use the same logic as mxfp.py (global max alignment) + # This gives: scale_exp = floor(log2(tensor_abs_max)) - emax + # - Min alignment: Find scale_exp such that tensor_abs_min / (2^scale_exp) >= min_norm + # So scale_exp <= log2(tensor_abs_min / min_norm) + + # Calculate the scale exponent that mxfp.py would use (for reference) + tensor_shared_exp = np.floor(np.log2(tensor_abs_max)) if tensor_abs_max > 0 else 0 + max_align_exp = tensor_shared_exp - emax # This is what mxfp.py actually uses + + # Calculate min alignment: find scale_exp such that scaled min >= min_norm + min_align_exp = np.floor(np.log2(tensor_abs_min / min_norm)) if tensor_abs_min > 0 and min_norm > 0 else max_align_exp + + # Use user-specified parameters directly, with calculated values as fallback for default parameters + if max_scale_exp == 10: # Default value, use calculated + max_scale_exp = max_align_exp + if min_scale_exp == -10: # Default value, use calculated + min_scale_exp = min_align_exp + + # Ensure max_scale_exp >= min_scale_exp + if max_scale_exp < min_scale_exp: + max_scale_exp, min_scale_exp = min_scale_exp, max_scale_exp + + # Generate integer scale exponents from max to min (inclusive) + max_exp_int = int(max_scale_exp) + min_exp_int = int(min_scale_exp) + + if max_exp_int == min_exp_int: + # Single point range - use the same integer value + scale_exponents = np.array([max_exp_int]) + else: + # Create integer range from max to min (inclusive) + scale_exponents = np.arange(max_exp_int, min_exp_int - 1, -1, dtype=int) + + results = { + 'scale_exponents': scale_exponents.tolist(), + 'metrics': {}, + 'elem_format': elem_format, + 'scale_bits': scale_bits, + 'format_params': { + 'ebits': ebits, + 'mbits': mbits, + 'emax': emax, + 'max_norm': max_norm, + 'min_norm': min_norm + } + } + + log_func = logger.info if logger else print + log_func(f"Tensor absolute value range: [{tensor_abs_min:.6e}, {tensor_abs_max:.6e}]") + log_func(f"Format range: max_norm={max_norm:.6e}, min_norm={min_norm:.6e}") + log_func(f"Calculated alignment (reference): max_align={max_align_exp:.2f}, min_align={min_align_exp:.2f}") + log_func(f"Testing integer scaling levels from {max_scale_exp:.2f} to {min_scale_exp:.2f}") + log_func(f"Element format: {elem_format} (e{ebits}m{mbits})") + log_func(f"Scale bits: {scale_bits}") + log_func("-" * 60) + + for i, scale_exp in enumerate(scale_exponents): + log_func(f"Testing scale exponent {scale_exp} ({i+1}/{len(scale_exponents)})...") + + # Create a custom quantize function with fixed scale exponent + quantized_tensor, overflow_underflow_analysis = quantize_with_fixed_scale( + input_tensor, elem_format, scale_bits, scale_exp, + ebits, mbits, max_norm + ) + + # Calculate metrics + metrics = calculate_metrics(input_tensor, quantized_tensor) + + # Store results + results['metrics'][f'scale_{i}'] = { + 'scale_exponent': float(scale_exp), + 'metrics': metrics, + 'overflow_underflow_analysis': overflow_underflow_analysis + } + + # Print current metrics + log_func(f" MSE: {metrics['mse']:.6e}, " + f"Cosine Sim: {metrics['cosine_similarity']:.6f}, " + f"PSNR: {metrics['psnr']:.2f} dB") + + return results + +def analyze_scaling_results(results, logger=None): + """ + Analyze scaling test results and recommend optimal scaling factors. + + Args: + results (dict): Results from test_scaling_levels + logger: Logger instance for output + + Returns: + dict: Analysis results with recommendations + """ + log_func = logger.info if logger else print + + scale_exponents = results['scale_exponents'] + elem_format = results['elem_format'] + format_params = results['format_params'] + + # Extract metrics for analysis + metrics_data = {} + for metric_name in ['mse', 'cosine_similarity', 'psnr', 'mae', 'relative_error']: + metrics_data[metric_name] = [] + for i in range(len(scale_exponents)): + scale_key = f'scale_{i}' + if scale_key in results['metrics']: + metrics_data[metric_name].append(results['metrics'][scale_key]['metrics'][metric_name]) + + # Find best indices for different metrics + # Use tolerance to handle numerical precision issues + tolerance = 1e-10 + + def find_best_indices(values, is_better_func): + """Find all indices with the best value, return the one with the largest scale exponent when tied""" + best_value = is_better_func(values) + if is_better_func == min: + best_indices = [i for i, v in enumerate(values) if abs(v - best_value) < tolerance] + else: # max + best_indices = [i for i, v in enumerate(values) if abs(v - best_value) < tolerance] + + if best_indices: + # When there are ties, choose the one with the largest scale exponent (closest to 0) + # Since scale_exponents are in descending order, the first index has the largest value + return best_indices[0] + else: + return 0 + + best_mse_idx = find_best_indices(metrics_data['mse'], min) + best_cosine_idx = find_best_indices(metrics_data['cosine_similarity'], max) + best_psnr_idx = find_best_indices(metrics_data['psnr'], max) + best_mae_idx = find_best_indices(metrics_data['mae'], min) + best_relative_error_idx = find_best_indices(metrics_data['relative_error'], min) + + # Calculate composite scores + # Normalize metrics to [0, 1] range for comparison + mse_normalized = 1 - (np.array(metrics_data['mse']) - np.min(metrics_data['mse'])) / (np.max(metrics_data['mse']) - np.min(metrics_data['mse']) + 1e-10) + cosine_normalized = np.array(metrics_data['cosine_similarity']) + psnr_normalized = (np.array(metrics_data['psnr']) - np.min(metrics_data['psnr'])) / (np.max(metrics_data['psnr']) - np.min(metrics_data['psnr']) + 1e-10) + mae_normalized = 1 - (np.array(metrics_data['mae']) - np.min(metrics_data['mae'])) / (np.max(metrics_data['mae']) - np.min(metrics_data['mae']) + 1e-10) + relative_error_normalized = 1 - (np.array(metrics_data['relative_error']) - np.min(metrics_data['relative_error'])) / (np.max(metrics_data['relative_error']) - np.min(metrics_data['relative_error']) + 1e-10) + + # Weighted composite score (can be adjusted based on priorities) + composite_scores = ( + 0.3 * mse_normalized + # Lower MSE is better + 0.3 * cosine_normalized + # Higher cosine similarity is better + 0.2 * psnr_normalized + # Higher PSNR is better + 0.1 * mae_normalized + # Lower MAE is better + 0.1 * relative_error_normalized # Lower relative error is better + ) + + # Find best composite index, handling ties by choosing larger scale exponent + best_composite_score = np.max(composite_scores) + best_composite_indices = [i for i, score in enumerate(composite_scores) if abs(score - best_composite_score) < tolerance] + # When there are ties, choose the one with the largest scale exponent (first index) + best_composite_idx = best_composite_indices[0] if best_composite_indices else 0 + + # Calculate scaling factor from scale exponent + def exp_to_factor(exp): + return 2 ** exp + + # Analysis results + analysis = { + 'best_mse': { + 'index': best_mse_idx, + 'scale_exp': scale_exponents[best_mse_idx], + 'scale_factor': exp_to_factor(scale_exponents[best_mse_idx]), + 'metrics': { + 'mse': metrics_data['mse'][best_mse_idx], + 'cosine_similarity': metrics_data['cosine_similarity'][best_mse_idx], + 'psnr': metrics_data['psnr'][best_mse_idx], + 'mae': metrics_data['mae'][best_mse_idx], + 'relative_error': metrics_data['relative_error'][best_mse_idx] + } + }, + 'best_cosine': { + 'index': best_cosine_idx, + 'scale_exp': scale_exponents[best_cosine_idx], + 'scale_factor': exp_to_factor(scale_exponents[best_cosine_idx]), + 'metrics': { + 'mse': metrics_data['mse'][best_cosine_idx], + 'cosine_similarity': metrics_data['cosine_similarity'][best_cosine_idx], + 'psnr': metrics_data['psnr'][best_cosine_idx], + 'mae': metrics_data['mae'][best_cosine_idx], + 'relative_error': metrics_data['relative_error'][best_cosine_idx] + } + }, + 'best_psnr': { + 'index': best_psnr_idx, + 'scale_exp': scale_exponents[best_psnr_idx], + 'scale_factor': exp_to_factor(scale_exponents[best_psnr_idx]), + 'metrics': { + 'mse': metrics_data['mse'][best_psnr_idx], + 'cosine_similarity': metrics_data['cosine_similarity'][best_psnr_idx], + 'psnr': metrics_data['psnr'][best_psnr_idx], + 'mae': metrics_data['mae'][best_psnr_idx], + 'relative_error': metrics_data['relative_error'][best_psnr_idx] + } + }, + 'best_composite': { + 'index': best_composite_idx, + 'scale_exp': scale_exponents[best_composite_idx], + 'scale_factor': exp_to_factor(scale_exponents[best_composite_idx]), + 'composite_score': composite_scores[best_composite_idx], + 'metrics': { + 'mse': metrics_data['mse'][best_composite_idx], + 'cosine_similarity': metrics_data['cosine_similarity'][best_composite_idx], + 'psnr': metrics_data['psnr'][best_composite_idx], + 'mae': metrics_data['mae'][best_composite_idx], + 'relative_error': metrics_data['relative_error'][best_composite_idx] + } + } + } + + # Log detailed analysis + log_func("\n" + "=" * 80) + log_func("SCALING FACTOR ANALYSIS & RECOMMENDATIONS") + log_func("=" * 80) + + log_func(f"Format: {elem_format} (e{format_params['ebits']}m{format_params['mbits']})") + log_func(f"Tested {len(scale_exponents)} scaling levels from {scale_exponents[0]:.2f} to {scale_exponents[-1]:.2f}") + log_func("-" * 80) + + # Check for ties in individual metrics + individual_indices = [best_mse_idx, best_cosine_idx, best_psnr_idx, best_mae_idx, best_relative_error_idx] + individual_names = ['MSE', 'Cosine Similarity', 'PSNR', 'MAE', 'Relative Error'] + + # Find if all individual metrics point to the same scale exponent + if len(set(individual_indices)) == 1: + log_func("🎯 ALL INDIVIDUAL METRICS AGREE:") + log_func("-" * 40) + log_func(f" All metrics recommend Scale Exp = {scale_exponents[individual_indices[0]]:.2f}") + log_func(f" Scale Factor = {analysis['best_mse']['scale_factor']:.6f}") + + # Check if there were ties and we chose the larger scale exponent + scale_exp = scale_exponents[individual_indices[0]] + all_same_values = [] + for i, (name, idx) in enumerate(zip(['MSE', 'Cosine', 'PSNR', 'MAE', 'Relative'], individual_indices)): + metric_values = [metrics_data[metric_name][idx] for metric_name in ['mse', 'cosine_similarity', 'psnr', 'mae', 'relative_error']] + all_same_values.extend([(name, metrics_data['mse'][idx]), (name, metrics_data['cosine_similarity'][idx])]) + + # Check for ties in the range + tied_indices = [] + for i in range(len(scale_exponents)): + if abs(scale_exponents[i] - scale_exp) < 0.1: # Check for nearby scale exponents + tied_indices.append(i) + + if len(tied_indices) > 1: + log_func(f" Note: Multiple scale exponents ({', '.join([f'{scale_exponents[i]:.2f}' for i in tied_indices])})") + log_func(f" produced identical performance. Selected largest: {scale_exp:.2f}") + else: + # Best results for individual metrics + log_func("INDIVIDUAL METRIC OPTIMA:") + log_func("-" * 40) + + log_func(f"🏆 Best MSE: Scale Exp = {analysis['best_mse']['scale_exp']:.2f}, Factor = {analysis['best_mse']['scale_factor']:.6f}") + log_func(f" MSE: {analysis['best_mse']['metrics']['mse']:.6e}, Cosine: {analysis['best_mse']['metrics']['cosine_similarity']:.6f}, PSNR: {analysis['best_mse']['metrics']['psnr']:.2f} dB") + + log_func(f"🎯 Best Cosine Similarity: Scale Exp = {analysis['best_cosine']['scale_exp']:.2f}, Factor = {analysis['best_cosine']['scale_factor']:.6f}") + log_func(f" MSE: {analysis['best_cosine']['metrics']['mse']:.6e}, Cosine: {analysis['best_cosine']['metrics']['cosine_similarity']:.6f}, PSNR: {analysis['best_cosine']['metrics']['psnr']:.2f} dB") + + log_func(f"📊 Best PSNR: Scale Exp = {analysis['best_psnr']['scale_exp']:.2f}, Factor = {analysis['best_psnr']['scale_factor']:.6f}") + log_func(f" MSE: {analysis['best_psnr']['metrics']['mse']:.6e}, Cosine: {analysis['best_psnr']['metrics']['cosine_similarity']:.6f}, PSNR: {analysis['best_psnr']['metrics']['psnr']:.2f} dB") + + # Composite recommendation + log_func("-" * 80) + log_func("COMPOSITE RECOMMENDATION:") + log_func("-" * 40) + + # Check if composite recommendation agrees with individual metrics + if len(set(individual_indices)) == 1 and individual_indices[0] == best_composite_idx: + log_func("🎯 UNANIMOUS RECOMMENDATION:") + log_func("-" * 40) + log_func(f" All individual metrics AND composite score agree!") + elif best_composite_idx in individual_indices: + log_func("📊 BALANCED RECOMMENDATION:") + log_func("-" * 40) + log_func(f" Composite score matches some individual metrics") + else: + log_func("⚖️ COMPOSITE RECOMMENDATION:") + log_func("-" * 40) + log_func(f" Composite score provides balanced recommendation") + + log_func(f"⭐ RECOMMENDED Scaling Factor: {analysis['best_composite']['scale_factor']:.6f}") + log_func(f" Scale Exponent: {analysis['best_composite']['scale_exp']:.2f}") + log_func(f" Composite Score: {analysis['best_composite']['composite_score']:.4f}") + log_func(f" Balanced Performance:") + log_func(f" - MSE: {analysis['best_composite']['metrics']['mse']:.6e}") + log_func(f" - Cosine Similarity: {analysis['best_composite']['metrics']['cosine_similarity']:.6f}") + log_func(f" - PSNR: {analysis['best_composite']['metrics']['psnr']:.2f} dB") + log_func(f" - MAE: {analysis['best_composite']['metrics']['mae']:.6e}") + log_func(f" - Relative Error: {analysis['best_composite']['metrics']['relative_error']:.2f}%") + + # Performance analysis + log_func("-" * 80) + log_func("PERFORMANCE ANALYSIS:") + log_func("-" * 40) + + # Calculate performance ranges + mse_range = np.max(metrics_data['mse']) - np.min(metrics_data['mse']) + cosine_range = np.max(metrics_data['cosine_similarity']) - np.min(metrics_data['cosine_similarity']) + psnr_range = np.max(metrics_data['psnr']) - np.min(metrics_data['psnr']) + + log_func(f"MSE Range: {np.min(metrics_data['mse']):.6e} to {np.max(metrics_data['mse']):.6e} (Δ: {mse_range:.6e})") + log_func(f"Cosine Range: {np.min(metrics_data['cosine_similarity']):.6f} to {np.max(metrics_data['cosine_similarity']):.6f} (Δ: {cosine_range:.6f})") + log_func(f"PSNR Range: {np.min(metrics_data['psnr']):.2f} to {np.max(metrics_data['psnr']):.2f} dB (Δ: {psnr_range:.2f} dB)") + + # Stability analysis + mse_std = np.std(metrics_data['mse']) + cosine_std = np.std(metrics_data['cosine_similarity']) + + log_func(f"MSE Stability (std): {mse_std:.6e}") + log_func(f"Cosine Stability (std): {cosine_std:.6f}") + + # Recommendations based on analysis + log_func("-" * 80) + log_func("RECOMMENDATIONS:") + log_func("-" * 40) + + if mse_range / np.min(metrics_data['mse']) < 0.1: + log_func("✅ MSE is relatively stable across scaling factors - any factor in the tested range should work well") + else: + log_func("⚠️ MSE varies significantly with scaling - choose the recommended factor carefully") + + if cosine_range < 0.01: + log_func("✅ Cosine similarity is very stable - scaling factor has minimal impact on direction preservation") + else: + log_func("⚠️ Cosine similarity varies with scaling - consider the impact on vector direction") + + if psnr_range > 20: + log_func("📈 Large PSNR range indicates significant quality differences - scaling factor choice is critical") + elif psnr_range > 10: + log_func("📊 Moderate PSNR range - scaling factor has noticeable impact on quality") + else: + log_func("✅ Small PSNR range - scaling factor has limited impact on quality") + + # Final recommendation + log_func("-" * 80) + log_func("FINAL RECOMMENDATION:") + log_func("-" * 40) + log_func(f"🎯 Use scaling factor: {analysis['best_composite']['scale_factor']:.6f}") + log_func(f" This provides the best balance of accuracy and stability for {elem_format} quantization") + log_func(f" Scale exponent: {analysis['best_composite']['scale_exp']:.2f}") + + if analysis['best_composite']['index'] == 0: + log_func(" 📍 This is at the maximum alignment end (minimal overflow risk)") + elif analysis['best_composite']['index'] == len(scale_exponents) - 1: + log_func(" 📍 This is at the minimum alignment end (minimal underflow risk)") + else: + log_func(" 📍 This is a balanced middle ground between overflow and underflow") + + log_func("=" * 80) + + return analysis + +def analyze_overflow_underflow_results(results, logger=None): + """ + Analyze and display overflow and underflow results from scaling tests. + + Args: + results (dict): Results from test_scaling_levels + logger: Logger instance for output + """ + log_func = logger.info if logger else print + + scale_exponents = results['scale_exponents'] + elem_format = results['elem_format'] + + # Collect all overflow/underflow analyses + overflow_underflow_results = [] + significant_issues = [] + + for i in range(len(scale_exponents)): + scale_key = f'scale_{i}' + if scale_key in results['metrics']: + analysis = results['metrics'][scale_key]['overflow_underflow_analysis'] + analysis['scale_exp'] = scale_exponents[i] + analysis['scale_factor'] = 2 ** scale_exponents[i] + overflow_underflow_results.append(analysis) + + if analysis['has_significant_underflow'] or analysis['has_significant_overflow']: + significant_issues.append(analysis) + + # Only display analysis if there are significant issues + if not significant_issues: + log_func("\n✅ No significant overflow or underflow issues detected across all scaling levels") + return + + # Display comprehensive overflow/underflow analysis + log_func("\n" + "=" * 80) + log_func("OVERFLOW/UNDERFLOW ANALYSIS SUMMARY") + log_func("=" * 80) + + log_func(f"Format: {elem_format}") + log_func(f"Analyzed {len(scale_exponents)} scaling levels") + log_func(f"Significant overflow/underflow detected in {len(significant_issues)} levels") + log_func("-" * 80) + + # Group by severity + high_severity = [u for u in significant_issues if u['severity'] == 'high'] + moderate_severity = [u for u in significant_issues if u['severity'] == 'moderate'] + + # Separate overflow and underflow issues + overflow_issues = [u for u in significant_issues if u['has_significant_overflow']] + underflow_issues = [u for u in significant_issues if u['has_significant_underflow']] + + # Display overflow issues + if overflow_issues: + log_func("🔴 OVERFLOW ISSUES:") + log_func("-" * 40) + for uf in overflow_issues: + log_func(f" Scale Exp: {uf['scale_exp']:.2f} (Factor: {uf['scale_factor']:.6f})") + log_func(f" Overflow: {uf['overflow_count']:,} ({uf['overflow_percent']:.2f}%)") + log_func(f" Max Normal: {uf['max_norm']:.2e}") + log_func(f" Tensor Range: [{uf['tensor_range'][0]:.2e}, {uf['tensor_range'][1]:.2e}]") + log_func(f" Severity: {uf['severity'].upper()}") + log_func("") + + # Display underflow issues + if underflow_issues: + log_func("🟡 UNDERFLOW ISSUES:") + log_func("-" * 40) + for uf in underflow_issues: + log_func(f" Scale Exp: {uf['scale_exp']:.2f} (Factor: {uf['scale_factor']:.6f})") + log_func(f" Underflow: {uf['underflow_count']:,} ({uf['underflow_percent']:.2f}%)") + log_func(f" Flush to Zero: {uf['flush_count']:,} ({uf['flush_percent']:.2f}%)") + log_func(f" Min Normal: {uf['min_norm']:.2e}") + log_func(f" Tensor Range: [{uf['tensor_range'][0]:.2e}, {uf['tensor_range'][1]:.2e}]") + log_func(f" Severity: {uf['severity'].upper()}") + log_func("") + + # Find best and worst cases + if overflow_issues: + worst_overflow = max(overflow_issues, key=lambda x: x['overflow_percent']) + log_func("OVERFLOW EXTREMES:") + log_func("-" * 40) + log_func(f"Worst Overflow: Scale Exp {worst_overflow['scale_exp']:.2f}") + log_func(f" {worst_overflow['overflow_percent']:.2f}% overflow") + + if underflow_issues: + worst_underflow = max(underflow_issues, key=lambda x: x['underflow_percent']) + best_underflow = min(underflow_issues, key=lambda x: x['underflow_percent']) + log_func("UNDERFLOW EXTREMES:") + log_func("-" * 40) + log_func(f"Worst Underflow: Scale Exp {worst_underflow['scale_exp']:.2f}") + log_func(f" {worst_underflow['underflow_percent']:.2f}% underflow, {worst_underflow['flush_percent']:.2f}% flushed to zero") + log_func(f"Best Underflow: Scale Exp {best_underflow['scale_exp']:.2f}") + log_func(f" {best_underflow['underflow_percent']:.2f}% underflow, {best_underflow['flush_percent']:.2f}% flushed to zero") + + # Recommendations + log_func("-" * 80) + log_func("OVERFLOW/UNDERFLOW RECOMMENDATIONS:") + log_func("-" * 40) + + if high_severity: + log_func("⚠️ AVOID scaling factors with HIGH overflow/underflow severity") + log_func(" These factors cause significant precision loss") + + if overflow_issues: + log_func("🔴 OVERFLOW WARNING:") + log_func(" Avoid scaling factors that cause overflow") + log_func(" These values will be saturated to max representable value") + + if underflow_issues: + log_func("🟡 UNDERFLOW CONSIDERATIONS:") + log_func(" Moderate underflow may be acceptable depending on use case") + log_func(" Balance between underflow and overflow risks") + + # Find optimal range + no_issue_levels = [u for u in overflow_underflow_results if not u['has_significant_underflow'] and not u['has_significant_overflow']] + if no_issue_levels: + optimal_range = [min(u['scale_exp'] for u in no_issue_levels), + max(u['scale_exp'] for u in no_issue_levels)] + log_func(f"✅ RECOMMENDED scaling range: {optimal_range[0]:.2f} to {optimal_range[1]:.2f}") + log_func(" This range minimizes both overflow and underflow issues") + else: + log_func("⚠️ All scaling levels have some overflow/underflow - choose least problematic") + # Find least problematic range + least_problematic = min(overflow_underflow_results, key=lambda x: max(x['overflow_percent'], x['underflow_percent'])) + log_func(f"💡 Least problematic scaling: {least_problematic['scale_exp']:.2f}") + log_func(f" Overflow: {least_problematic['overflow_percent']:.2f}%, Underflow: {least_problematic['underflow_percent']:.2f}%") + + log_func("=" * 80) + +def quantize_with_fixed_scale(input_tensor, elem_format, scale_bits, scale_exp, + ebits, mbits, max_norm, axes=None, block_size=0): + """ + Custom quantization function with fixed scale exponent. + This function simulates the exact behavior of mxfp.py _quantize_mx function. + + Args: + input_tensor (torch.Tensor): Input tensor + elem_format (str): Element format + scale_bits (int): Number of scale bits + scale_exp (float): Fixed scale exponent (log2 of scaling factor) + ebits (int): Exponent bits + mbits (int): Mantissa bits + max_norm (float): Maximum normal value + axes (list): Axes for shared exponent calculation + block_size (int): Block size for tiling + + Returns: + tuple: (quantized_tensor, overflow_underflow_analysis) + """ + A = input_tensor.clone() + + # Apply scaling directly (this simulates the A = A / (2**shared_exp) step in mxfp.py) + scale_factor = 2.0 ** scale_exp # Use float to handle negative exponents + A = A / scale_factor + + # Quantize element-wise + from quant.mxfp import _quantize_elemwise_core,_analyze_overflow_underflow_before_quantization + + # Analyze overflow/underflow without printing (collect results) + overflow_underflow_analysis = _analyze_overflow_underflow_before_quantization( + A, elem_format, mbits, ebits, max_norm, verbose=False + ) + + A = _quantize_elemwise_core( + A, mbits, ebits, max_norm, round='nearest', + allow_denorm=True, saturate_normals=True + ) + + # Undo scaling + A = A * scale_factor + + return A, overflow_underflow_analysis + +def plot_scaling_results(results, output_path): + """ + Create comprehensive plots showing scaling test results. + + Args: + results (dict): Results from test_scaling_levels + output_path (Path): Output directory for plots + """ + scale_exponents = results['scale_exponents'] + elem_format = results['elem_format'] + + # Extract metrics for plotting + metrics_data = {} + for metric_name in ['mse', 'rmse', 'cosine_similarity', 'psnr', 'mae', 'max_abs_error', 'relative_error']: + metrics_data[metric_name] = [] + for i in range(len(scale_exponents)): + scale_key = f'scale_{i}' + if scale_key in results['metrics']: + metrics_data[metric_name].append(results['metrics'][scale_key]['metrics'][metric_name]) + + # Create figure with subplots + fig, axes = plt.subplots(3, 2, figsize=(15, 18)) + fig.suptitle(f'MXFP Scaling Test Results - {elem_format.upper()}', fontsize=16, fontweight='bold') + + # Plot 1: MSE + axes[0, 0].semilogy(scale_exponents, metrics_data['mse'], 'b-o', linewidth=2, markersize=4) + axes[0, 0].set_xlabel('Scale Exponent') + axes[0, 0].set_ylabel('MSE (log scale)') + axes[0, 0].set_title('Mean Squared Error vs Scale Exponent') + axes[0, 0].grid(True, alpha=0.3) + + # Plot 2: Cosine Similarity + axes[0, 1].plot(scale_exponents, metrics_data['cosine_similarity'], 'g-o', linewidth=2, markersize=4) + axes[0, 1].set_xlabel('Scale Exponent') + axes[0, 1].set_ylabel('Cosine Similarity') + axes[0, 1].set_title('Cosine Similarity vs Scale Exponent') + axes[0, 1].grid(True, alpha=0.3) + axes[0, 1].set_ylim([0, 1]) + + # Plot 3: PSNR + # Handle infinite PSNR values + psnr_values = metrics_data['psnr'] + psnr_finite = [p if p != float('inf') else 1000 for p in psnr_values] # Cap at 1000 for plotting + + axes[1, 0].plot(scale_exponents, psnr_finite, 'r-o', linewidth=2, markersize=4) + axes[1, 0].set_xlabel('Scale Exponent') + axes[1, 0].set_ylabel('PSNR (dB)') + axes[1, 0].set_title('Peak Signal-to-Noise Ratio vs Scale Exponent') + axes[1, 0].grid(True, alpha=0.3) + + # Plot 4: MAE + axes[1, 1].semilogy(scale_exponents, metrics_data['mae'], 'm-o', linewidth=2, markersize=4) + axes[1, 1].set_xlabel('Scale Exponent') + axes[1, 1].set_ylabel('MAE (log scale)') + axes[1, 1].set_title('Mean Absolute Error vs Scale Exponent') + axes[1, 1].grid(True, alpha=0.3) + + # Plot 5: Maximum Absolute Error + axes[2, 0].semilogy(scale_exponents, metrics_data['max_abs_error'], 'c-o', linewidth=2, markersize=4) + axes[2, 0].set_xlabel('Scale Exponent') + axes[2, 0].set_ylabel('Max Absolute Error (log scale)') + axes[2, 0].set_title('Maximum Absolute Error vs Scale Exponent') + axes[2, 0].grid(True, alpha=0.3) + + # Plot 6: Relative Error + axes[2, 1].plot(scale_exponents, metrics_data['relative_error'], 'orange', marker='o', linewidth=2, markersize=4) + axes[2, 1].set_xlabel('Scale Exponent') + axes[2, 1].set_ylabel('Relative Error (%)') + axes[2, 1].set_title('Relative Error vs Scale Exponent') + axes[2, 1].grid(True, alpha=0.3) + + # Add format information + format_params = results['format_params'] + info_text = f"Format: {elem_format}\nE-bits: {format_params['ebits']}, M-bits: {format_params['mbits']}\n" + info_text += f"Max Normal: ±{format_params['max_norm']:.1e}\nMin Normal: {format_params['min_norm']:.1e}" + + fig.text(0.02, 0.02, info_text, fontsize=10, verticalalignment='bottom', + bbox=dict(boxstyle='round,pad=0.5', facecolor='lightgray', alpha=0.8)) + + plt.tight_layout() + plt.subplots_adjust(top=0.93, bottom=0.15) + + # Save plot + plot_path = output_path / f'mxfp_scaling_test_{elem_format}.png' + plt.savefig(plot_path, dpi=300, bbox_inches='tight', facecolor='white') + plt.close() + + # This will be logged by the caller + pass + + # Create summary plot with key metrics + create_summary_plot(results, output_path) + +def create_summary_plot(results, output_path): + """Create a summary plot with the most important metrics.""" + scale_exponents = results['scale_exponents'] + elem_format = results['elem_format'] + + # Extract key metrics + mse_values = [] + cosine_sim_values = [] + psnr_values = [] + + for i in range(len(scale_exponents)): + scale_key = f'scale_{i}' + if scale_key in results['metrics']: + metrics = results['metrics'][scale_key]['metrics'] + mse_values.append(metrics['mse']) + cosine_sim_values.append(metrics['cosine_similarity']) + psnr_values.append(metrics['psnr']) + + # Handle infinite PSNR values + psnr_finite = [p if p != float('inf') else 1000 for p in psnr_values] + + # Create summary plot + fig, ax1 = plt.subplots(figsize=(12, 8)) + + # Plot MSE and PSNR on left y-axis + color1 = 'tab:blue' + ax1.set_xlabel('Scale Exponent', fontsize=12) + ax1.set_ylabel('MSE (log scale)', color=color1, fontsize=12) + line1 = ax1.semilogy(scale_exponents, mse_values, 'o-', color=color1, linewidth=2, markersize=6, label='MSE') + ax1.tick_params(axis='y', labelcolor=color1) + ax1.grid(True, alpha=0.3) + + # Create second y-axis for cosine similarity + ax2 = ax1.twinx() + color2 = 'tab:green' + ax2.set_ylabel('Cosine Similarity', color=color2, fontsize=12) + line2 = ax2.plot(scale_exponents, cosine_sim_values, 's-', color=color2, linewidth=2, markersize=6, label='Cosine Similarity') + ax2.tick_params(axis='y', labelcolor=color2) + ax2.set_ylim([0, 1]) + + # Add PSNR as dashed line on ax1 + ax1_2 = ax1.twinx() + ax1_2.spines['right'].set_position(('outward', 60)) + color3 = 'tab:red' + ax1_2.set_ylabel('PSNR (dB)', color=color3, fontsize=12) + line3 = ax1_2.plot(scale_exponents, psnr_finite, '^-', color=color3, linewidth=2, markersize=6, linestyle='--', label='PSNR') + ax1_2.tick_params(axis='y', labelcolor=color3) + + # Add title and legend + plt.title(f'MXFP Scaling Test Summary - {elem_format.upper()}\nKey Metrics vs Scale Exponent', + fontsize=14, fontweight='bold', pad=20) + + # Combine legends + lines = line1 + line2 + line3 + labels = [l.get_label() for l in lines] + ax1.legend(lines, labels, loc='upper right', fontsize=10) + + plt.tight_layout() + + # Save summary plot + summary_path = output_path / f'mxfp_scaling_summary_{elem_format}.png' + plt.savefig(summary_path, dpi=300, bbox_inches='tight', facecolor='white') + plt.close() + + # This will be logged by the caller + pass + +def save_results_to_file(results, output_path): + """Save detailed results to a text file.""" + results_path = output_path / f'mxfp_scaling_results_{results["elem_format"]}.txt' + + with open(results_path, 'w') as f: + f.write("MXFP Scaling Test Results\n") + f.write("=" * 50 + "\n\n") + + f.write(f"Element Format: {results['elem_format']}\n") + f.write(f"Scale Bits: {results['scale_bits']}\n") + f.write(f"Format Parameters: {results['format_params']}\n\n") + + f.write("Detailed Results:\n") + f.write("-" * 30 + "\n") + + for i, scale_exp in enumerate(results['scale_exponents']): + scale_key = f'scale_{i}' + if scale_key in results['metrics']: + metrics = results['metrics'][scale_key]['metrics'] + overflow_underflow_analysis = results['metrics'][scale_key]['overflow_underflow_analysis'] + + f.write(f"Scale Exponent {scale_exp:.2f} (Factor: {2**scale_exp:.6f}):\n") + f.write(" Performance Metrics:\n") + f.write(f" MSE: {metrics['mse']:.6e}\n") + f.write(f" RMSE: {metrics['rmse']:.6e}\n") + f.write(f" Cosine Similarity: {metrics['cosine_similarity']:.6f}\n") + f.write(f" PSNR: {metrics['psnr']:.2f} dB\n") + f.write(f" MAE: {metrics['mae']:.6e}\n") + f.write(f" Max Absolute Error: {metrics['max_abs_error']:.6e}\n") + f.write(f" Relative Error: {metrics['relative_error']:.2f}%\n") + + f.write(" Overflow/Underflow Analysis:\n") + f.write(f" Total Elements: {overflow_underflow_analysis['total_elements']:,}\n") + f.write(f" Underflow Count: {overflow_underflow_analysis['underflow_count']:,} ({overflow_underflow_analysis['underflow_percent']:.2f}%)\n") + f.write(f" Flush to Zero Count: {overflow_underflow_analysis['flush_count']:,} ({overflow_underflow_analysis['flush_percent']:.2f}%)\n") + f.write(f" Overflow Count: {overflow_underflow_analysis['overflow_count']:,} ({overflow_underflow_analysis['overflow_percent']:.2f}%)\n") + f.write(f" Min Denormal: {overflow_underflow_analysis['min_denormal']:.2e}\n") + f.write(f" Min Normal: {overflow_underflow_analysis['min_norm']:.2e}\n") + f.write(f" Max Normal: {overflow_underflow_analysis['max_norm']:.2e}\n") + f.write(f" Tensor Range: [{overflow_underflow_analysis['tensor_range'][0]:.2e}, {overflow_underflow_analysis['tensor_range'][1]:.2e}]\n") + f.write(f" Severity: {overflow_underflow_analysis['severity'].upper()}\n") + f.write(f" Has Significant Underflow: {'Yes' if overflow_underflow_analysis['has_significant_underflow'] else 'No'}\n") + f.write(f" Has Significant Overflow: {'Yes' if overflow_underflow_analysis['has_significant_overflow'] else 'No'}\n") + if overflow_underflow_analysis['error']: + f.write(f" Analysis Error: {overflow_underflow_analysis['error']}\n") + f.write("\n") + + # This will be logged by the caller + pass + +def process_single_tensor(input_path, args, logger=None): + """Process a single tensor file.""" + + # Validate input file + if not input_path.exists(): + print(f"Error: Input file does not exist: {input_path}") + return 1 + + if not input_path.is_file(): + print(f"Error: Input path is not a file: {input_path}") + return 1 + + # Setup output directory + if args.output_dir is None: + # Generate output directory based on tensor name + tensor_name = input_path.stem # Get filename without extension + output_dir = Path(f"./draw/scaling_analysis/{args.elem-format}{args.elem-format}/{tensor_name}") + else: + output_dir = Path(args.output_dir) + + output_dir.mkdir(parents=True, exist_ok=True) + + # Setup logging for this tensor + tensor_name = input_path.stem + tensor_logger = setup_logging(output_dir, tensor_name, args.elem_format) + + tensor_logger.info(f"Loading input tensor: {input_path.name}") + tensor_logger.info("=" * 60) + + # Load input tensor + try: + input_tensor = torch.load(str(input_path), map_location='cpu', weights_only=False) + + # Handle case where loaded object is not a tensor + if not isinstance(input_tensor, torch.Tensor): + if isinstance(input_tensor, dict) and 'tensor' in input_tensor: + input_tensor = input_tensor['tensor'] + elif isinstance(input_tensor, (list, tuple)) and len(input_tensor) > 0: + input_tensor = input_tensor[0] + else: + tensor_logger.error(f"Error: Loaded object is not a tensor: {input_path.name}") + return 1 + + # Convert to BF16 if needed + if input_tensor.dtype != torch.bfloat16: + tensor_logger.info(f"Converting tensor from {input_tensor.dtype} to bfloat16") + input_tensor = input_tensor.bfloat16() + + tensor_logger.info(f"Tensor shape: {input_tensor.shape}") + tensor_logger.info(f"Tensor dtype: {input_tensor.dtype}") + tensor_logger.info(f"Value range: [{torch.min(input_tensor):.6f}, {torch.max(input_tensor):.6f}]") + tensor_logger.info(f"Mean ± Std: {torch.mean(input_tensor):.6f} ± {torch.std(input_tensor):.6f}") + + except Exception as e: + tensor_logger.error(f"Error loading tensor {input_path.name}: {str(e)}") + return 1 + + # Run scaling test + results = test_scaling_levels( + input_tensor, + args.elem_format, + args.scale_bits, + max_scale_exp=args.max_scale_exp, + min_scale_exp=args.min_scale_exp, + logger=tensor_logger + ) + + # Save results to file + save_results_to_file(results, output_dir) + tensor_logger.info(f"Detailed results saved to: {output_dir}") + + # Generate plots unless disabled + if not args.no_plots: + plot_scaling_results(results, output_dir) + tensor_logger.info(f"Plots saved to: {output_dir}") + + # Perform detailed analysis + analysis_results = analyze_scaling_results(results, tensor_logger) + + # Analyze overflow/underflow results + analyze_overflow_underflow_results(results, tensor_logger) + + # Print summary + tensor_logger.info("\n" + "=" * 60) + tensor_logger.info("SCALING TEST SUMMARY") + tensor_logger.info("=" * 60) + + # Use analysis results for summary + best_composite = analysis_results['best_composite'] + best_mse = analysis_results['best_mse'] + best_cosine = analysis_results['best_cosine'] + + tensor_logger.info(f"Best Cosine Similarity: {best_cosine['metrics']['cosine_similarity']:.6f} at scale {best_cosine['scale_exp']:.2f}") + tensor_logger.info(f"Best MSE: {best_mse['metrics']['mse']:.6e} at scale {best_mse['scale_exp']:.2f}") + tensor_logger.info(f"Best PSNR: {best_mse['metrics']['psnr']:.2f} dB at scale {best_mse['scale_exp']:.2f}") + + tensor_logger.info(f"\n🎯 RECOMMENDED Scaling Factor: {best_composite['scale_factor']:.6f}") + tensor_logger.info(f" Scale Exponent: {best_composite['scale_exp']:.2f}") + tensor_logger.info(f" Composite Score: {best_composite['composite_score']:.4f}") + + tensor_logger.info(f"\nResults saved to: {output_dir}") + tensor_logger.info("Test completed successfully!") + + # Log completion time + tensor_logger.info("=" * 80) + tensor_logger.info(f"Test completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + tensor_logger.info("=" * 80) + + return 0 + +def main(): + """Main function for MXFP scaling test.""" + parser = argparse.ArgumentParser(description='Test different scaling strategies for MXFP quantization') + parser.add_argument('input_tensors', nargs='+', help='Path(s) to input BF16 tensor file(s) (.pt)') + parser.add_argument('--output-dir', default=None, + help='Output directory for results (default: ./draw/scaling_analysis/{args.elem-format}{args.elem-format}/{tensor_name}/)') + parser.add_argument('--elem-format', default='fp8_e4m3', + choices=['fp8_e4m3', 'fp8_e5m2', 'fp4_e2m1', 'fp6_e3m2', 'fp6_e2m3'], + help='Element format for quantization (default: fp8_e4m3)') + parser.add_argument('--scale-bits', type=int, default=8, + help='Number of scale bits (default: 8)') + parser.add_argument('--max-scale-exp', type=int, default=10, + help='Maximum scale exponent (default: auto-calculated from tensor max if using default value)') + parser.add_argument('--min-scale-exp', type=int, default=-10, + help='Minimum scale exponent (default: auto-calculated from tensor min if using default value)') + parser.add_argument('--no-plots', action='store_true', + help='Skip generating plots') + + args = parser.parse_args() + + # Process multiple tensors + total_tensors = len(args.input_tensors) + successful_tests = 0 + + print(f"Processing {total_tensors} tensor(s)...") + print("=" * 80) + + for i, tensor_path in enumerate(args.input_tensors, 1): + print(f"\n[{i}/{total_tensors}] Processing: {tensor_path}") + print("-" * 60) + + input_path = Path(tensor_path) + result = process_single_tensor(input_path, args) + + if result == 0: + successful_tests += 1 + print(f"✅ Successfully processed: {tensor_path}") + else: + print(f"❌ Failed to process: {tensor_path}") + + # Final summary + print("\n" + "=" * 80) + print("FINAL SUMMARY") + print("=" * 80) + print(f"Total tensors: {total_tensors}") + print(f"Successful: {successful_tests}") + print(f"Failed: {total_tensors - successful_tests}") + + if successful_tests == total_tensors: + print("🎉 All tests completed successfully!") + return 0 + else: + print("⚠️ Some tests failed. Check individual logs for details.") + return 1 + +if __name__ == '__main__': + exit(main()) diff --git a/visualization/overflow/run_analysis.py b/visualization/overflow/run_analysis.py new file mode 100644 index 0000000000..bf404bc049 --- /dev/null +++ b/visualization/overflow/run_analysis.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +""" +Simple wrapper script to run the scaling factor analysis. + +Usage: + python3 run_analysis.py [base_directory] + +If no base_directory is provided, it will use the current directory. +""" + +import sys +from pathlib import Path +from analyze_scaling_factors import ScalingAnalyzer + + +def main(): + # Get base directory from command line or use current directory + if len(sys.argv) > 1: + base_dir = sys.argv[1] + else: + base_dir = str(Path.cwd()) + + print(f"🎯 Running scaling factor analysis on: {base_dir}") + + try: + analyzer = ScalingAnalyzer(base_dir) + analyzer.analyze_all_files() + + if analyzer.results: + analyzer.print_detailed_report() + analyzer.save_results_to_json() + + # Quick summary + summary = analyzer.generate_summary() + if summary['at_max_percentage'] == 100.0: + print("\n🎉 RESULT: ALL tensors are using their maximum scaling factors!") + else: + print(f"\n⚠️ RESULT: {summary['not_at_max_count']} out of {summary['total_tensors']} tensors are NOT at maximum scaling.") + else: + print("❌ No valid log files found or parsed.") + + except Exception as e: + print(f"❌ Error during analysis: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() + + + + diff --git a/visualization/overflow_summary.py b/visualization/overflow_summary.py new file mode 100755 index 0000000000..80a07f7ddb --- /dev/null +++ b/visualization/overflow_summary.py @@ -0,0 +1,319 @@ +#!/usr/bin/env python3 +""" +Comprehensive overflow/underflow analysis for all tensor files. +Generates detailed analysis reports for all data formats. +""" + +import os +import sys +import json +from pathlib import Path +from datetime import datetime +import subprocess +from tqdm import tqdm + +# Add current directory to path to import overflow module +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from overflow import analyze_file, DATA_TYPE_RANGES + +def analyze_all_tensors(base_dir="enhanced_tensor_logs", output_dir="visualization"): + """ + Analyze all tensor files in the enhanced_tensor_logs directory structure. + + Args: + base_dir (str): Base directory containing tensor files + output_dir (str): Output directory for reports + + Returns: + dict: Complete analysis results + """ + base_path = Path(base_dir) + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + if not base_path.exists(): + raise FileNotFoundError(f"Base directory not found: {base_dir}") + + all_results = {} + summary_stats = {} + + print("Starting comprehensive tensor overflow/underflow analysis...") + print("=" * 60) + + # Analyze each data format directory + for data_format in DATA_TYPE_RANGES.keys(): + format_dir = base_path / data_format + + if not format_dir.exists(): + print(f"Warning: Directory not found for {data_format}: {format_dir}") + continue + + print(f"Analyzing {data_format.upper()} tensors...") + + format_results = [] + tensor_files = list(format_dir.glob("*.pt")) + + if not tensor_files: + print(f" No tensor files found in {format_dir}") + continue + + print(f" Found {len(tensor_files)} tensor files") + + # Analyze each tensor file with progress bar + for tensor_file in tqdm(tensor_files, desc=f" Processing {data_format.upper()}", unit="files"): + result = analyze_file(str(tensor_file)) + if result: + format_results.append(result) + + all_results[data_format] = format_results + + # Calculate summary statistics for this format + if format_results: + total_files = len(format_results) + total_elements = sum(r['total_elements'] for r in format_results) + total_overflow = sum(r['overflow_count'] for r in format_results) + total_underflow = sum(r['underflow_count'] for r in format_results) + + files_with_overflow = sum(1 for r in format_results if r['overflow_count'] > 0) + files_with_underflow = sum(1 for r in format_results if r['underflow_count'] > 0) + + overflow_percent = (total_overflow / total_elements) * 100 if total_elements > 0 else 0 + underflow_percent = (total_underflow / total_elements) * 100 if total_elements > 0 else 0 + + # Value statistics + all_mins = [r['tensor_min'] for r in format_results] + all_maxs = [r['tensor_max'] for r in format_results] + all_means = [r['tensor_mean'] for r in format_results] + + summary_stats[data_format] = { + 'total_files': total_files, + 'total_elements': total_elements, + 'total_overflow': total_overflow, + 'total_underflow': total_underflow, + 'files_with_overflow': files_with_overflow, + 'files_with_underflow': files_with_underflow, + 'overflow_percent': overflow_percent, + 'underflow_percent': underflow_percent, + 'global_min': min(all_mins), + 'global_max': max(all_maxs), + 'avg_mean': sum(all_means) / len(all_means), + 'format_range': [DATA_TYPE_RANGES[data_format]['min'], DATA_TYPE_RANGES[data_format]['max']], + 'description': DATA_TYPE_RANGES[data_format]['description'] + } + + print(f" Completed {data_format.upper()}: {len(format_results)} files processed") + + # Generate comprehensive report + generate_comprehensive_report(all_results, summary_stats, output_path) + + # Generate detailed JSON report + generate_json_report(all_results, summary_stats, output_path) + + # Generate CSV summary + generate_csv_summary(summary_stats, output_path) + + print("\nAnalysis complete! Generated reports:") + print(f" - {output_path}/overflow_comprehensive_report.txt") + print(f" - {output_path}/overflow_detailed_results.json") + print(f" - {output_path}/overflow_summary.csv") + + return all_results, summary_stats + +def generate_comprehensive_report(all_results, summary_stats, output_path): + """Generate a comprehensive text report.""" + report_file = output_path / "overflow_comprehensive_report.txt" + + with open(report_file, 'w') as f: + f.write("=" * 80 + "\n") + f.write("TENSOR FILE OVERFLOW/UNDERFLOW ANALYSIS REPORT\n") + f.write("=" * 80 + "\n") + f.write("This report shows the PERCENTAGE of tensor files that have overflow/underflow issues\n") + f.write("(vs. the percentage of values within each file)\n") + f.write("=" * 80 + "\n") + f.write(f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n") + + # Executive Summary + f.write("EXECUTIVE SUMMARY\n") + f.write("-" * 40 + "\n") + + total_files_all = sum(stats['total_files'] for stats in summary_stats.values()) + total_files_with_overflow = sum(stats['files_with_overflow'] for stats in summary_stats.values()) + total_files_with_underflow = sum(stats['files_with_underflow'] for stats in summary_stats.values()) + total_files_with_issues = len(set().union(*[ + [r['filename'] for r in results if r['has_issues']] + for results in all_results.values() + ])) + + f.write(f"Total tensor files analyzed: {total_files_all:,}\n") + f.write(f"Files with overflow issues: {total_files_with_overflow:,} ({(total_files_with_overflow/total_files_all)*100:.2f}%)\n") + f.write(f"Files with underflow issues: {total_files_with_underflow:,} ({(total_files_with_underflow/total_files_all)*100:.2f}%)\n") + f.write(f"Files with any issues: {total_files_with_issues:,} ({(total_files_with_issues/total_files_all)*100:.2f}%)\n\n") + + # Format-specific summaries + f.write("FORMAT-SPECIFIC ANALYSIS\n") + f.write("-" * 40 + "\n") + + for data_format, stats in summary_stats.items(): + f.write(f"\n{data_format.upper()} ({stats['description']})\n") + f.write("─" * 50 + "\n") + f.write(f"Representable range: [{stats['format_range'][0]}, {stats['format_range'][1]}]\n") + f.write(f"Files analyzed: {stats['total_files']:,}\n") + f.write(f"Value range observed: [{stats['global_min']:.6f}, {stats['global_max']:.6f}]\n") + f.write(f"Average mean value: {stats['avg_mean']:.6f}\n\n") + + f.write("File-Level Analysis (Primary Focus):\n") + f.write(f" Files with overflow: {stats['files_with_overflow']}/{stats['total_files']} ({(stats['files_with_overflow']/stats['total_files'])*100:.1f}%)\n") + f.write(f" Files with underflow: {stats['files_with_underflow']}/{stats['total_files']} ({(stats['files_with_underflow']/stats['total_files'])*100:.1f}%)\n") + + f.write("Value-Level Statistics (Reference):\n") + f.write(f" Overflow value percentage: {stats['overflow_percent']:.4f}%\n") + f.write(f" Underflow value percentage: {stats['underflow_percent']:.4f}%\n") + f.write(f" Total elements: {stats['total_elements']:,}\n") + + # Risk assessment based on file percentage + file_overflow_pct = (stats['files_with_overflow'] / stats['total_files']) * 100 + file_underflow_pct = (stats['files_with_underflow'] / stats['total_files']) * 100 + + risk_level = "LOW" + if file_overflow_pct > 10.0 or file_underflow_pct > 10.0: + risk_level = "HIGH" + elif file_overflow_pct > 1.0 or file_underflow_pct > 1.0: + risk_level = "MEDIUM" + + f.write(f"Risk Level: {risk_level} (based on file percentage)\n") + + if stats['files_with_overflow'] > 0 or stats['files_with_underflow'] > 0: + f.write("⚠️ ATTENTION: Some files have overflow/underflow issues!\n") + + f.write("\n") + + # Detailed file listings for problematic cases + f.write("DETAILED PROBLEMATIC FILES\n") + f.write("-" * 40 + "\n") + + for data_format, results in all_results.items(): + problematic_files = [r for r in results if r['overflow_count'] > 0 or r['underflow_count'] > 0] + + if problematic_files: + f.write(f"\n{data_format.upper()} - Files with overflow/underflow:\n") + f.write("─" * 30 + "\n") + + for result in problematic_files[:20]: # Limit to first 20 problematic files + f.write(f"File: {result['filename']}\n") + f.write(f" Shape: {result['shape']}\n") + f.write(f" Range: [{result['tensor_min']:.6f}, {result['tensor_max']:.6f}]\n") + f.write(f" Overflow: {result['overflow_count']:,} ({result['overflow_percent']:.4f}%)\n") + f.write(f" Underflow: {result['underflow_count']:,} ({result['underflow_percent']:.4f}%)\n") + f.write("\n") + + if len(problematic_files) > 20: + f.write(f"... and {len(problematic_files) - 20} more problematic files\n\n") + + # Recommendations + f.write("RECOMMENDATIONS\n") + f.write("-" * 40 + "\n") + + for data_format, stats in summary_stats.items(): + if stats['total_overflow'] > 0 or stats['total_underflow'] > 0: + f.write(f"{data_format.upper()}:\n") + + if stats['total_overflow'] > 0: + f.write(f" - Consider clipping values above {stats['format_range'][1]} to prevent overflow\n") + f.write(f" - Review model architecture or training parameters causing extreme values\n") + + if stats['total_underflow'] > 0: + f.write(f" - Consider clipping values below {stats['format_range'][0]} to prevent underflow\n") + f.write(f" - Review initialization or gradient scaling settings\n") + + f.write(f" - Monitor numerical stability during training\n") + f.write(f" - Consider mixed precision training strategies\n\n") + +def generate_json_report(all_results, summary_stats, output_path): + """Generate detailed JSON report.""" + json_file = output_path / "overflow_detailed_results.json" + + report_data = { + 'metadata': { + 'generated_on': datetime.now().isoformat(), + 'total_formats': len(summary_stats), + 'total_files': sum(stats['total_files'] for stats in summary_stats.values()) + }, + 'summary_statistics': summary_stats, + 'detailed_results': all_results + } + + with open(json_file, 'w') as f: + json.dump(report_data, f, indent=2, default=str) + +def generate_csv_summary(summary_stats, output_path): + """Generate CSV summary report.""" + import csv + + csv_file = output_path / "overflow_summary.csv" + + with open(csv_file, 'w', newline='') as f: + fieldnames = [ + 'data_format', 'description', 'total_files', 'total_elements', + 'total_overflow', 'total_underflow', 'overflow_percent', 'underflow_percent', + 'files_with_overflow', 'files_with_underflow', 'global_min', 'global_max', + 'avg_mean', 'format_min', 'format_max' + ] + + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + + for data_format, stats in summary_stats.items(): + row = { + 'data_format': data_format, + 'description': stats['description'], + 'total_files': stats['total_files'], + 'total_elements': stats['total_elements'], + 'total_overflow': stats['total_overflow'], + 'total_underflow': stats['total_underflow'], + 'overflow_percent': stats['overflow_percent'], + 'underflow_percent': stats['underflow_percent'], + 'files_with_overflow': stats['files_with_overflow'], + 'files_with_underflow': stats['files_with_underflow'], + 'global_min': stats['global_min'], + 'global_max': stats['global_max'], + 'avg_mean': stats['avg_mean'], + 'format_min': stats['format_range'][0], + 'format_max': stats['format_range'][1] + } + writer.writerow(row) + +def main(): + """Main function to run comprehensive analysis.""" + import argparse + + parser = argparse.ArgumentParser(description='Comprehensive tensor overflow/underflow analysis') + parser.add_argument('--base-dir', default='enhanced_tensor_logs', + help='Base directory containing tensor files (default: enhanced_tensor_logs)') + parser.add_argument('--output-dir', default='./draw/overflow_summary/', + help='Output directory for reports (default: ./draw/overflow_summary/)') + + args = parser.parse_args() + + try: + all_results, summary_stats = analyze_all_tensors(args.base_dir, args.output_dir) + + # Print summary to console + print("\n" + "=" * 60) + print("ANALYSIS SUMMARY") + print("=" * 60) + + for data_format, stats in summary_stats.items(): + file_overflow_pct = (stats['files_with_overflow'] / stats['total_files']) * 100 + file_underflow_pct = (stats['files_with_underflow'] / stats['total_files']) * 100 + print(f"{data_format.upper()}: {stats['total_files']} files, " + f"{file_overflow_pct:.1f}% files with overflow, " + f"{file_underflow_pct:.1f}% files with underflow") + + return 0 + + except Exception as e: + print(f"Error during analysis: {str(e)}") + return 1 + +if __name__ == "__main__": + exit(main())