diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..62f227d3 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,12 @@ +* text=auto +*.c text eol=lf +*.cc text eol=lf +*.cpp text eol=lf +*.cu text eol=lf +*.h text eol=lf +*.hpp text eol=lf +*.py text eol=lf +*.sh text eol=lf +*.bash text eol=lf +CMakeLists.txt text eol=lf +.gitignore text eol=lf \ No newline at end of file diff --git a/.gitignore b/.gitignore index 50b9fa06..266ddd8e 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,15 @@ build/ *.log *.report.rank* *.records.log.rank* + +#------modify-start------------------------------------------ +# Local sanity-check datasets (not part of repo) +tmp_data/ +#---------modify-end----------------------------------------- +tmp/ + +# Generated Flash SDPA benchmark outputs +docs/flash_sdpa/logs/ +docs/flash_sdpa/env/ +docs/flash_sdpa/report_*.md +docs/flash_sdpa/summary_*.csv diff --git a/CMakeLists.txt b/CMakeLists.txt index df636b27..e3245860 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -73,6 +73,18 @@ endif() if(USE_CUDA) add_compile_definitions(USE_CUDA=1) + + #------modify-start------------------------------------------ + # CMake may fail to auto-detect nvcc / default architectures if CUDA is not in PATH. + # Pin nvcc path and a reasonable default arch for A100 (sm_80). 
+ if(NOT DEFINED CMAKE_CUDA_COMPILER AND EXISTS "/usr/local/cuda/bin/nvcc") + set(CMAKE_CUDA_COMPILER "/usr/local/cuda/bin/nvcc") + endif() + if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + set(CMAKE_CUDA_ARCHITECTURES 80) + endif() + #---------modify-end----------------------------------------- + enable_language(CUDA) find_package(CUDAToolkit REQUIRED) include_directories(${CUDAToolkit_INCLUDE_DIRS}) @@ -90,10 +102,34 @@ if(USE_CUDA) PUBLIC glog CUDA::cudart + #------modify-start------------------------------------------ + CUDA::nvrtc + #---------modify-end----------------------------------------- CUDA::cublas CUDA::cuda_driver ) + #------modify-start------------------------------------------ + # cuDNN + cudnn-frontend (header-only) for fused SDPA backend + find_library(CUDNN_LIBRARY cudnn HINTS /lib/x86_64-linux-gnu /usr/lib/x86_64-linux-gnu) + if(NOT CUDNN_LIBRARY) + message(FATAL_ERROR "cuDNN (libcudnn.so) not found") + endif() + find_path( + CUDNN_FRONTEND_INCLUDE_DIR + cudnn_frontend.h + HINTS + ${PROJECT_SOURCE_DIR}/third_party/cudnn_frontend/include + /usr/include + /usr/local/include + ) + if(NOT CUDNN_FRONTEND_INCLUDE_DIR) + message(FATAL_ERROR "cudnn_frontend.h not found. 
Install cudnn-frontend or clone it under third_party/cudnn_frontend/include") + endif() + target_link_libraries(infini_train_cuda_kernels PUBLIC ${CUDNN_LIBRARY}) + target_include_directories(infini_train_cuda_kernels PUBLIC ${CUDNN_FRONTEND_INCLUDE_DIR}) + #---------modify-end----------------------------------------- + if(USE_NCCL) message(STATUS "Add USE_NCCL, use NCCL with CUDA") list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake) @@ -196,11 +232,7 @@ set_target_properties(infini_run PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BIN # Tests add_executable(test_hook test/hook/test_hook.cc) -link_infini_train_exe(test_hook) +target_link_libraries(test_hook infini_train) add_executable(test_precision_check test/hook/test_precision_check.cc) -link_infini_train_exe(test_precision_check) - -add_executable(test_lora test/lora/test_lora.cc) -link_infini_train_exe(test_lora) - +target_link_libraries(test_precision_check infini_train) diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index 8e28af52..a8696c62 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -13,7 +13,6 @@ #include "infini_train/include/core/runtime/device_guard.h" #include "infini_train/include/dataloader.h" #include "infini_train/include/device.h" -#include "infini_train/include/nn/lora/lora_utils.h" #include "infini_train/include/nn/modules/loss.h" #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h" @@ -75,19 +74,14 @@ DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage."); // precision DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)"); +//------modify-start------------------------------------------ +DEFINE_bool(flash, false, "enable fused scaled-dot-product attention (BF16 only)"); +//---------modify-end----------------------------------------- // precision check DEFINE_string( precision_check, "", "precision check config: 
level=N,format=simple|table,output_md5=true|false,output_path=PATH,baseline=PATH"); -// LoRA parameters -DEFINE_int32(lora_rank, 0, "LoRA rank (0 = disabled)"); -DEFINE_double(lora_alpha, 16.0, "LoRA alpha scaling factor"); -DEFINE_string(lora_target_modules, "c_attn,c_proj", - "LoRA target modules (comma-separated: c_attn,c_proj,c_fc,c_fc2,mlp.c_proj)"); -DEFINE_string(lora_save_path, "", "Path to save LoRA weights after training"); -DEFINE_string(lora_load_path, "", "Path to load LoRA weights from"); - using namespace infini_train; namespace { @@ -189,7 +183,6 @@ void Train(const nn::parallel::Rank &rank) { // init the model, either from scratch or from OpenAI pretrained checkpoint GPT2Config model_config; std::shared_ptr model = nullptr; - if (!FLAGS_llmc_filepath.empty()) { model = GPT2::FromLLMC(FLAGS_llmc_filepath); } else if (kModelToConfigs.count(FLAGS_model)) { @@ -203,29 +196,6 @@ void Train(const nn::parallel::Rank &rank) { utils::PrecisionChecker::BuildNameMap(model.get()); - // Get chunk size before wrapping with LoRA (needed for PipelineParallel) - auto gpt2_model = std::dynamic_pointer_cast(model); - CHECK(gpt2_model) << "GPT2 example expects GPT2 model."; - - // Apply LoRA using GetLoRAModel (in-place injection) - bool lora_enabled = FLAGS_lora_rank > 0; - if (lora_enabled) { - nn::lora::LoRAConfig lora_config{FLAGS_lora_rank, static_cast(FLAGS_lora_alpha), 0.0f, - nn::lora::ParseLoRATargetModules(FLAGS_lora_target_modules)}; - - // GetLoRAModel: in-place injection, modifies module tree directly - model = nn::lora::GetLoRAModel(model, lora_config); - - // Load LoRA weights if specified - if (!FLAGS_lora_load_path.empty()) { - LOG(INFO) << "Loading LoRA weights from: " << FLAGS_lora_load_path; - nn::lora::LoadLoRAWeights(model, FLAGS_lora_load_path); - } - - // Print LoRA summary - nn::lora::PrintLoRASummary(model, rank.GlobalRank()); - } - // select the data type // TODO(lzm): change to solely rely on the weight file info for determining the dtype 
when autocast is supported DataType dtype; @@ -239,24 +209,15 @@ void Train(const nn::parallel::Rank &rank) { auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size); - // Create optimizer - use GetLoRAParameters if LoRA is enabled - std::vector> params_to_optimize; - if (lora_enabled) { - params_to_optimize = nn::lora::GetLoRAParameters(model); - LOG(INFO) << "Optimizing " << params_to_optimize.size() << " LoRA parameters"; - } else { - params_to_optimize = model->Parameters(); - LOG(INFO) << "Optimizing " << params_to_optimize.size() << " model parameters"; - } - if (pp_world_size > 1) { // NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct // when sequence parallelism (SP) is enabled, we need to divide by sp_world_size. auto shapes = std::vector>{ {FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}}; - model = std::make_shared(model, pp_world_size, num_micro_batches, shapes, - pp_rank, device, gpt2_model->GetChunkSize()); + model = std::make_shared( + model, pp_world_size, num_micro_batches, shapes, pp_rank, device, + std::dynamic_pointer_cast(model)->GetChunkSize()); if (ddp_world_size > 1) { auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer}; @@ -304,10 +265,10 @@ void Train(const nn::parallel::Rank &rank) { auto model_chunks = (pp_world_size > 1) ? 
*(dynamic_cast(model.get())->mutable_chunks()) : std::vector>{model}; - optimizer = std::make_shared(optimizer_creator, params_to_optimize, + optimizer = std::make_shared(optimizer_creator, model->Parameters(), model_chunks, ddp_world_size, ddp_rank); } else { - optimizer = optimizer_creator(params_to_optimize); + optimizer = optimizer_creator(model->Parameters()); } auto train_iter = train_loader.begin(); @@ -436,13 +397,6 @@ void Train(const nn::parallel::Rank &rank) { } } } - - // Save LoRA weights if enabled and path specified - if (lora_enabled && !FLAGS_lora_save_path.empty()) { - LOG(INFO) << "Saving LoRA weights to: " << FLAGS_lora_save_path; - nn::lora::SaveLoRAWeights(model, FLAGS_lora_save_path); - } - #ifdef PROFILE_MODE Profiler::Instance().Report("gpt2.report", Profiler::SortBy::DeviceTimePercentage); Profiler::Instance().PrintRecords("gpt2.records.log"); diff --git a/example/gpt2/net.cc b/example/gpt2/net.cc index d000d1cf..6886fc65 100644 --- a/example/gpt2/net.cc +++ b/example/gpt2/net.cc @@ -11,7 +11,12 @@ #include #include +#include "gflags/gflags.h" #include "glog/logging.h" +//------modify-start------------------------------------------ +// NOTE: --flash is a global gflags option defined in main.cc. +DECLARE_bool(flash); +//---------modify-end----------------------------------------- #include "example/common/utils.h" #include "infini_train/include/device.h" @@ -105,6 +110,26 @@ CausalSelfAttention::Forward(const std::vectorView({B, T, local_n_head_, head_dim})->Transpose(1, 2); v = v->View({B, T, local_n_head_, head_dim})->Transpose(1, 2); + //------modify-start------------------------------------------ + // FlashAttention path (BF16 on CUDA only). + if (FLAGS_flash && q->GetDevice().type() == Device::DeviceType::kCUDA && q->Dtype() == DataType::kBFLOAT16) { + // cudnn SDPA expects a standard (B, H, T, D) layout; enforce contiguous strides. 
+ q = q->Contiguous(); + k = k->Contiguous(); + v = v->Contiguous(); + + // (B, h_l, T, D) -> (B, h_l, T, D) + auto y = nn::function::ScaledDotProductAttention(q, k, v, /*attn_mask=*/nullptr, /*dropout_p=*/0.0, + /*is_causal=*/true); + // (B, h_l, T, D) -> (B, T, h_l, D) -> (B, T, local_C) + y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C}); + + // output projection + y = (*modules_[kCProjLayerName])({y})[0]; + return {y}; + } + //---------modify-end----------------------------------------- + // (B, h_l, T, T) auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim)); // (1, 1, T, T) @@ -307,11 +332,6 @@ GPT2::GPT2(const GPT2Config &config) modules_[kTransformerLayerName] = std::make_shared(std::move(transformer)); // FIXME(jym): Assigning the parameter values of wte to LMHead, which is not real tying operation - // TODO: Implement real GPT-2 weight tying: make lm_head.weight share the exact same Parameter/Tensor (same - // shared_ptr/storage) as transformer.wte.weight (pointer aliasing, not value copy), and ensure the tie is applied - // after loading weights so it won't be overwritten. Also fix GPT2::FromLLMC() loading logic to respect weight tying - // (do not create/load a separate lm_head.weight tensor; load once into the tied weight) so parameter counting - // matches PyTorch/PEFT. 
if (nn::parallel::global::GetPipelineParallelSize() == 1) { // https://paperswithcode.com/method/weight-tying *mutable_module(kTransformerLayerName) diff --git a/example/llama3/main.cc b/example/llama3/main.cc index acc20ac4..d7c91913 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -11,7 +11,6 @@ #include "infini_train/include/core/runtime/device_guard.h" #include "infini_train/include/dataloader.h" #include "infini_train/include/device.h" -#include "infini_train/include/nn/lora/lora_utils.h" #include "infini_train/include/nn/modules/loss.h" #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h" @@ -73,16 +72,13 @@ DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage."); // precision DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)"); +//------modify-start------------------------------------------ +DEFINE_bool(flash, false, "enable fused scaled-dot-product attention (BF16 only)"); +//---------modify-end----------------------------------------- // precision check DEFINE_string( precision_check, "", "precision check config: level=N,format=simple|table,output_md5=true|false,output_path=PATH,baseline=PATH"); -// LoRA parameters -DEFINE_int32(lora_rank, 0, "LoRA rank (0 = disabled)"); -DEFINE_double(lora_alpha, 16.0, "LoRA alpha scaling factor"); -DEFINE_string(lora_target_modules, "c_attn,c_proj,c_fc,c_fc2", "LoRA target modules (comma-separated)"); -DEFINE_string(lora_save_path, "", "Path to save LoRA weights after training"); -DEFINE_string(lora_load_path, "", "Path to load LoRA weights from"); using namespace infini_train; @@ -168,6 +164,9 @@ void Train(const nn::parallel::Rank &rank) { // ManualSeed(42); LLaMA3Config model_config = LLaMA3Config(); + //------modify-start------------------------------------------ + model_config.flash = 
FLAGS_flash; + //---------modify-end----------------------------------------- std::shared_ptr model = nullptr; if (!FLAGS_llmc_filepath.empty()) { model = LLaMA3::FromLLMC(FLAGS_llmc_filepath); @@ -179,25 +178,6 @@ void Train(const nn::parallel::Rank &rank) { utils::PrecisionChecker::BuildNameMap(model.get()); - // Apply LoRA using GetLoRAModel (in-place injection) - bool lora_enabled = FLAGS_lora_rank > 0; - if (lora_enabled) { - nn::lora::LoRAConfig lora_config{FLAGS_lora_rank, static_cast(FLAGS_lora_alpha), 0.0f, - nn::lora::ParseLoRATargetModules(FLAGS_lora_target_modules)}; - - // GetLoRAModel: in-place injection, modifies module tree directly - model = nn::lora::GetLoRAModel(model, lora_config); - - // Load LoRA weights if specified - if (!FLAGS_lora_load_path.empty()) { - LOG(INFO) << "Loading LoRA weights from: " << FLAGS_lora_load_path; - nn::lora::LoadLoRAWeights(model, FLAGS_lora_load_path); - } - - // Print LoRA summary - nn::lora::PrintLoRASummary(model, rank.GlobalRank()); - } - LOG(INFO) << "Rank " << rank.GlobalRank() << ": Model loaded to device."; DataType dtype; @@ -263,23 +243,14 @@ void Train(const nn::parallel::Rank &rank) { auto optimizer_creator = optimizers::Adam::Create(FLAGS_learning_rate); std::shared_ptr optimizer = nullptr; - std::vector> params_to_optimize; - if (lora_enabled) { - params_to_optimize = nn::lora::GetLoRAParameters(model); - LOG(INFO) << "Optimizing " << params_to_optimize.size() << " LoRA parameters"; - } else { - params_to_optimize = model->Parameters(); - LOG(INFO) << "Optimizing " << params_to_optimize.size() << " model parameters"; - } - if (FLAGS_use_distributed_optimizer) { auto model_chunks = (pp_world_size > 1) ? 
*(dynamic_cast(model.get())->mutable_chunks()) : std::vector>{model}; - optimizer = std::make_shared(optimizer_creator, params_to_optimize, + optimizer = std::make_shared(optimizer_creator, model->Parameters(), model_chunks, ddp_world_size, ddp_rank); } else { - optimizer = optimizer_creator(params_to_optimize); + optimizer = optimizer_creator(model->Parameters()); } auto train_iter = train_loader.begin(); @@ -405,13 +376,6 @@ void Train(const nn::parallel::Rank &rank) { } } } - - // Save LoRA weights if enabled and path specified - if (lora_enabled && !FLAGS_lora_save_path.empty()) { - LOG(INFO) << "Saving LoRA weights to: " << FLAGS_lora_save_path; - nn::lora::SaveLoRAWeights(model, FLAGS_lora_save_path); - } - #ifdef PROFILE_MODE Profiler::Instance().Report("llama3.report", Profiler::SortBy::DeviceTimePercentage); Profiler::Instance().PrintRecords("llama3.records.log"); diff --git a/example/llama3/net.cc b/example/llama3/net.cc index a50fb831..9c4bbc81 100644 --- a/example/llama3/net.cc +++ b/example/llama3/net.cc @@ -216,9 +216,25 @@ std::vector> CausalSelfAttention::Forward(const std::vec q = q->Transpose(1, 2); k = k->Transpose(1, 2); v = v->Transpose(1, 2); - - // TODO(zbl): support flash attention later - // if (flash_) { ... } + //------modify-start------------------------------------------ + // FlashAttention path (BF16 on CUDA only). + if (config_.flash && q->GetDevice().type() == Device::DeviceType::kCUDA && q->Dtype() == DataType::kBFLOAT16) { + // cudnn SDPA expects a standard (B, H, T, D) layout; enforce contiguous strides. 
+ q = q->Contiguous(); + k = k->Contiguous(); + v = v->Contiguous(); + + // (B, H_local, T, D) -> (B, H_local, T, D) + auto y = nn::function::ScaledDotProductAttention(q, k, v, /*attn_mask=*/nullptr, /*dropout_p=*/0.0, + /*is_causal=*/true); + // (B, H_local, T, D) -> (B, T, H_local, D) -> (B, T, C_local) + y = y->Transpose(1, 2)->Contiguous()->View({B, T, C_local}); + + // output projection + y = (*modules_[kCProjLayerName])({y})[0]; + return {y}; + } + //---------modify-end----------------------------------------- // manual implementation of attention // this materializes the large (T,T) matrix for all the queries and keys diff --git a/infini_train/include/autograd/scaled_dot_product_attention.h b/infini_train/include/autograd/scaled_dot_product_attention.h new file mode 100644 index 00000000..025573a5 --- /dev/null +++ b/infini_train/include/autograd/scaled_dot_product_attention.h @@ -0,0 +1,41 @@ +//------modify-start------------------------------------------ +#pragma once + +#include +#include +#include +#include + +#include "infini_train/include/autograd/function.h" + +namespace infini_train { +class Tensor; +} // namespace infini_train + +namespace infini_train::autograd { + +class ScaledDotProductAttention : public Function { +public: + static constexpr char kType[] = "ScaledDotProductAttention"; + + ScaledDotProductAttention(double dropout_p, bool is_causal, std::optional scale) + : Function(kType), dropout_p_(dropout_p), is_causal_(is_causal), scale_(scale) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + double dropout_p_ = 0.0; + bool is_causal_ = false; + std::optional scale_ = std::nullopt; + + // Saved from forward kernel for backward. 
+ std::shared_ptr saved_stats_ = nullptr; +}; + +} // namespace infini_train::autograd +//---------modify-end----------------------------------------- diff --git a/infini_train/include/nn/functional.h b/infini_train/include/nn/functional.h index e4354fd1..45fef5b1 100644 --- a/infini_train/include/nn/functional.h +++ b/infini_train/include/nn/functional.h @@ -2,6 +2,10 @@ #include #include + +//------modify-start------------------------------------------ +#include +//---------modify-end----------------------------------------- #include namespace infini_train { @@ -183,4 +187,15 @@ std::shared_ptr Stack(const std::vector> &inputs // Concatenation of the input tensors. std::shared_ptr Concat(const std::vector> &inputs, int64_t dim = 0); +//------modify-start------------------------------------------ +// PyTorch-aligned scaled_dot_product_attention (FlashAttention/SDPA backend). +// NOTE: Current implementation supports CUDA BF16 causal attention; other modes fall back at model level. +std::shared_ptr ScaledDotProductAttention(const std::shared_ptr &query, + const std::shared_ptr &key, + const std::shared_ptr &value, + const std::shared_ptr &attn_mask = nullptr, + double dropout_p = 0.0, bool is_causal = false, + std::optional scale = std::nullopt, bool enable_gqa = false); +//---------modify-end----------------------------------------- + } // namespace infini_train::nn::function diff --git a/infini_train/src/autograd/scaled_dot_product_attention.cc b/infini_train/src/autograd/scaled_dot_product_attention.cc new file mode 100644 index 00000000..9a2cb52d --- /dev/null +++ b/infini_train/src/autograd/scaled_dot_product_attention.cc @@ -0,0 +1,93 @@ +//------modify-start------------------------------------------ +#include "infini_train/include/autograd/scaled_dot_product_attention.h" + +#include + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::autograd { + +std::vector> 
+ScaledDotProductAttention::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 3); + const auto &query = input_tensors[0]; + const auto &key = input_tensors[1]; + const auto &value = input_tensors[2]; + + CHECK_EQ(dropout_p_, 0.0) << "dropout is not supported in current SDPA backend"; + + CHECK_EQ(query->Dims().size(), 4) << "query must be 4D: (B, H, S, D)"; + const auto head_dim = query->Dims().back(); + CHECK_GT(head_dim, 0); + + const float attn_scale = scale_.has_value() ? static_cast(*scale_) + : static_cast(1.0 / std::sqrt(static_cast(head_dim))); + + auto device = query->GetDevice().type(); + + auto [output, stats] = Dispatcher::Instance().Call, std::shared_ptr>>( + {device, "ScaledDotProductAttentionForward"}, query, key, value, attn_scale, is_causal_); + + saved_stats_ = stats; + return {output}; +} + +void ScaledDotProductAttention::SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) { + CHECK_EQ(input_tensors.size(), 3); + CHECK_EQ(output_tensors.size(), 1); + + const auto &query = input_tensors[0]; + const auto &key = input_tensors[1]; + const auto &value = input_tensors[2]; + const auto &output = output_tensors[0]; + + CHECK(saved_stats_ != nullptr) << "SDPA forward must save stats for backward"; + + saved_tensors_ = {query, key, value, output, saved_stats_}; +} + +std::vector> +ScaledDotProductAttention::Backward(const std::vector> &grad_outputs) { + CHECK_EQ(saved_tensors_.size(), 5); + CHECK_EQ(grad_outputs.size(), 1); + + const auto &query = saved_tensors_[0]; + const auto &key = saved_tensors_[1]; + const auto &value = saved_tensors_[2]; + const auto &output = saved_tensors_[3]; + const auto &stats = saved_tensors_[4]; + //------modify-start------------------------------------------ + // cuDNN-frontend SDPA backward assumes dO has the same dtype/layout as O. + // In practice, upstream may produce non-contiguous and/or FP32 grad tensors. 
+ // Force dtype match + contiguous to avoid incorrect memory interpretation and NaNs/corruption. + auto grad_output = grad_outputs[0]; + if (grad_output->Dtype() != query->Dtype()) { + grad_output = std::make_shared(grad_output->To(query->Dtype())); + } + grad_output = grad_output->Contiguous(); + //---------modify-end----------------------------------------- + + CHECK_EQ(query->Dims().size(), 4); + const auto head_dim = query->Dims().back(); + CHECK_GT(head_dim, 0); + + const float attn_scale = scale_.has_value() ? static_cast(*scale_) + : static_cast(1.0 / std::sqrt(static_cast(head_dim))); + + auto device = query->GetDevice().type(); + + auto [grad_query, grad_key, grad_value] + = Dispatcher::Instance() + .Call, std::shared_ptr, std::shared_ptr>>( + {device, "ScaledDotProductAttentionBackward"}, query, key, value, output, stats, grad_output, + attn_scale, is_causal_); + + return {grad_query, grad_key, grad_value}; +} + +} // namespace infini_train::autograd +//---------modify-end----------------------------------------- diff --git a/infini_train/src/kernels/cuda/scaled_dot_product_attention.cu b/infini_train/src/kernels/cuda/scaled_dot_product_attention.cu new file mode 100644 index 00000000..2a057e04 --- /dev/null +++ b/infini_train/src/kernels/cuda/scaled_dot_product_attention.cu @@ -0,0 +1,417 @@ +//------modify-start------------------------------------------ +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "glog/logging.h" + +//------modify-start------------------------------------------ +// Minimal cuDNN status checker (avoid multi-line macros). 
+inline void CudnnCheck(cudnnStatus_t status, const char *file, int line) { + if (status != CUDNN_STATUS_SUCCESS) { + LOG(FATAL) << "CUDNN Error: " << cudnnGetErrorString(status) << " at " << file << ":" << line; + } +} + +#define CUDNN_CHECK(call) CudnnCheck((call), __FILE__, __LINE__) +//---------modify-end----------------------------------------- + +#include "infini_train/include/core/runtime/device_guard.h" +#include "infini_train/include/datatype.h" +#include "infini_train/include/device.h" +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" + +namespace infini_train::kernels::cuda { +namespace fe = cudnn_frontend; + +namespace { + +struct SdpaKey { + int64_t b = 0; + int64_t h = 0; + int64_t s = 0; + int64_t d = 0; + bool is_causal = false; + uint32_t scale_bits = 0; + + bool operator==(const SdpaKey &other) const { + return b == other.b && h == other.h && s == other.s && d == other.d && is_causal == other.is_causal + && scale_bits == other.scale_bits; + } +}; + +struct SdpaKeyHash { + size_t operator()(const SdpaKey &k) const noexcept { + // A simple 64-bit mix. 
+ size_t h = 1469598103934665603ull; + auto mix = [&](uint64_t v) { + h ^= static_cast(v); + h *= 1099511628211ull; + }; + mix(static_cast(k.b)); + mix(static_cast(k.h)); + mix(static_cast(k.s)); + mix(static_cast(k.d)); + mix(static_cast(k.is_causal)); + mix(static_cast(k.scale_bits)); + return h; + } +}; + +struct CachedGraph { + std::shared_ptr graph; + int64_t workspace_size = 0; +}; + +static std::mutex g_cache_mu; +static std::unordered_map g_fwd_cache; +static std::unordered_map g_bwd_cache; + +static thread_local cudnnHandle_t tls_cudnn_handle = nullptr; + +cudnnHandle_t GetCudnnHandle(cudaStream_t stream) { + if (tls_cudnn_handle == nullptr) { + CUDNN_CHECK(cudnnCreate(&tls_cudnn_handle)); + } + CUDNN_CHECK(cudnnSetStream(tls_cudnn_handle, stream)); + return tls_cudnn_handle; +} + +uint32_t FloatToBits(float x) { + uint32_t u = 0; + static_assert(sizeof(float) == sizeof(uint32_t)); + std::memcpy(&u, &x, sizeof(uint32_t)); + return u; +} + +std::vector ContigStrideBHSD(int64_t b, int64_t h, int64_t s, int64_t d) { + (void)b; + // Layout: (B, H, S, D) contiguous + return {h * s * d, s * d, d, 1}; +} + +CachedGraph BuildFwdGraph(cudnnHandle_t handle, int64_t b, int64_t h, int64_t s, int64_t d, float attn_scale, + bool is_causal) { + auto graph = std::make_shared(); + graph->set_io_data_type(fe::DataType_t::BFLOAT16) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + auto stride = ContigStrideBHSD(b, h, s, d); + + constexpr int Q_UID = 1; + constexpr int K_UID = 2; + constexpr int V_UID = 3; + constexpr int O_UID = 4; + constexpr int STATS_UID = 5; + + auto Q = graph->tensor( + fe::graph::Tensor_attributes().set_name("Q").set_uid(Q_UID).set_dim({b, h, s, d}).set_stride(stride)); + + auto K = graph->tensor( + fe::graph::Tensor_attributes().set_name("K").set_uid(K_UID).set_dim({b, h, s, d}).set_stride(stride)); + + auto V = graph->tensor( + 
fe::graph::Tensor_attributes().set_name("V").set_uid(V_UID).set_dim({b, h, s, d}).set_stride(stride)); + + auto sdpa_options + = fe::graph::SDPA_attributes().set_name("sdpa").set_generate_stats(true).set_attn_scale(attn_scale); + + if (is_causal) { + sdpa_options.set_diagonal_alignment(cudnn_frontend::DiagonalAlignment_t::TOP_LEFT) + .set_diagonal_band_right_bound(0); + } + + auto [O, Stats] = graph->sdpa(Q, K, V, sdpa_options); + + O->set_output(true).set_uid(O_UID).set_dim({b, h, s, d}).set_stride(stride); + Stats->set_output(true) + .set_uid(STATS_UID) + .set_dim({b, h, s, 1}) + .set_stride({h * s, s, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + + auto build_status = graph->build(handle, {fe::HeurMode_t::A}); + CHECK(build_status.is_good()) << "cudnn-frontend SDPA forward graph build failed: " << build_status.get_message(); + + int64_t workspace_size = 0; + auto ws_status = graph->get_workspace_size(workspace_size); + CHECK(ws_status.is_good()) << "cudnn-frontend get_workspace_size failed: " << ws_status.get_message(); + + return {.graph = std::move(graph), .workspace_size = workspace_size}; +} + +CachedGraph BuildBwdGraph(cudnnHandle_t handle, int64_t b, int64_t h, int64_t s, int64_t d, float attn_scale, + bool is_causal) { + auto graph = std::make_shared(); + graph->set_io_data_type(fe::DataType_t::BFLOAT16) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + auto stride = ContigStrideBHSD(b, h, s, d); + + constexpr int Q_UID = 1; + constexpr int K_UID = 2; + constexpr int V_UID = 3; + constexpr int O_UID = 4; + constexpr int STATS_UID = 5; + constexpr int DO_UID = 101; + + constexpr int DQ_UID = 102; + constexpr int DK_UID = 103; + constexpr int DV_UID = 104; + + auto Q = graph->tensor( + fe::graph::Tensor_attributes().set_name("Q").set_uid(Q_UID).set_dim({b, h, s, d}).set_stride(stride)); + + auto K = graph->tensor( + fe::graph::Tensor_attributes().set_name("K").set_uid(K_UID).set_dim({b, h, s, 
d}).set_stride(stride)); + + auto V = graph->tensor( + fe::graph::Tensor_attributes().set_name("V").set_uid(V_UID).set_dim({b, h, s, d}).set_stride(stride)); + + auto O = graph->tensor( + fe::graph::Tensor_attributes().set_name("O").set_uid(O_UID).set_dim({b, h, s, d}).set_stride(stride)); + + auto dO = graph->tensor( + fe::graph::Tensor_attributes().set_name("dO").set_uid(DO_UID).set_dim({b, h, s, d}).set_stride(stride)); + + auto Stats = graph->tensor(fe::graph::Tensor_attributes() + .set_name("Stats") + .set_uid(STATS_UID) + .set_dim({b, h, s, 1}) + .set_stride({h * s, s, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + + auto bwd_options = fe::graph::SDPA_backward_attributes().set_name("sdpa_backward").set_attn_scale(attn_scale); + + if (is_causal) { + bwd_options.set_diagonal_alignment(cudnn_frontend::DiagonalAlignment_t::TOP_LEFT) + .set_diagonal_band_right_bound(0); + } + + auto [dQ, dK, dV] = graph->sdpa_backward(Q, K, V, O, dO, Stats, bwd_options); + + dQ->set_output(true).set_uid(DQ_UID).set_dim({b, h, s, d}).set_stride(stride); + dK->set_output(true).set_uid(DK_UID).set_dim({b, h, s, d}).set_stride(stride); + dV->set_output(true).set_uid(DV_UID).set_dim({b, h, s, d}).set_stride(stride); + + auto build_status = graph->build(handle, {fe::HeurMode_t::A}); + CHECK(build_status.is_good()) << "cudnn-frontend SDPA backward graph build failed: " << build_status.get_message(); + + int64_t workspace_size = 0; + auto ws_status = graph->get_workspace_size(workspace_size); + CHECK(ws_status.is_good()) << "cudnn-frontend get_workspace_size failed: " << ws_status.get_message(); + + return {.graph = std::move(graph), .workspace_size = workspace_size}; +} + +CachedGraph GetOrCreateFwdGraph(cudnnHandle_t handle, int64_t b, int64_t h, int64_t s, int64_t d, float attn_scale, + bool is_causal) { + SdpaKey key{.b = b, .h = h, .s = s, .d = d, .is_causal = is_causal, .scale_bits = FloatToBits(attn_scale)}; + std::lock_guard lock(g_cache_mu); + auto it = g_fwd_cache.find(key); 
+ if (it != g_fwd_cache.end()) { + return it->second; + } + auto cached = BuildFwdGraph(handle, b, h, s, d, attn_scale, is_causal); + g_fwd_cache.emplace(key, cached); + return cached; +} + +CachedGraph GetOrCreateBwdGraph(cudnnHandle_t handle, int64_t b, int64_t h, int64_t s, int64_t d, float attn_scale, + bool is_causal) { + SdpaKey key{.b = b, .h = h, .s = s, .d = d, .is_causal = is_causal, .scale_bits = FloatToBits(attn_scale)}; + std::lock_guard lock(g_cache_mu); + auto it = g_bwd_cache.find(key); + if (it != g_bwd_cache.end()) { + return it->second; + } + auto cached = BuildBwdGraph(handle, b, h, s, d, attn_scale, is_causal); + g_bwd_cache.emplace(key, cached); + return cached; +} + +} // namespace + +std::tuple, std::shared_ptr> +ScaledDotProductAttentionForward(const std::shared_ptr &query, const std::shared_ptr &key, + const std::shared_ptr &value, float attn_scale, bool is_causal) { + CHECK(query->GetDevice().type() == Device::DeviceType::kCUDA); + //------modify-start------------------------------------------ + // Avoid CHECK_EQ on enum class (would require operator<< overload). + CHECK(query->Dtype() == DataType::kBFLOAT16) << "SDPA forward only supports BF16"; + CHECK(key->Dtype() == DataType::kBFLOAT16); + CHECK(value->Dtype() == DataType::kBFLOAT16); + //---------modify-end----------------------------------------- + + const auto &q_dims = query->Dims(); + CHECK_EQ(q_dims.size(), 4); + const int64_t b = q_dims[0]; + const int64_t h = q_dims[1]; + const int64_t s = q_dims[2]; + const int64_t d = q_dims[3]; + + //------modify-start------------------------------------------ + // Avoid CHECK_EQ on std::vector (requires operator<< overload). 
+ CHECK(key->Dims() == q_dims); + CHECK(value->Dims() == q_dims); + //---------modify-end----------------------------------------- + + auto device = query->GetDevice(); + core::DeviceGuard guard(device); + + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + + auto handle = GetCudnnHandle(cuda_stream); + + //------modify-start------------------------------------------ + // Cache cuDNN-frontend graphs by shape/scale for performance. + auto cached = GetOrCreateFwdGraph(handle, b, h, s, d, attn_scale, is_causal); + //---------modify-end----------------------------------------- + + auto output = std::make_shared(q_dims, DataType::kBFLOAT16, device); + auto stats = std::make_shared(std::vector{b, h, s, 1}, DataType::kFLOAT32, device); + + //------modify-start------------------------------------------ + // Defensive initialization: if the selected cuDNN plan does not fully overwrite outputs + // (e.g., due to masked/unused lanes), this prevents stale NaNs from propagating. + CHECK(cudaMemsetAsync(output->DataPtr(), 0, output->SizeInBytes(), cuda_stream) == cudaSuccess); + CHECK(cudaMemsetAsync(stats->DataPtr(), 0, stats->SizeInBytes(), cuda_stream) == cudaSuccess); + //---------modify-end----------------------------------------- + + std::unordered_map variant_pack; + variant_pack.reserve(5); + variant_pack[1] = query->DataPtr(); + variant_pack[2] = key->DataPtr(); + variant_pack[3] = value->DataPtr(); + variant_pack[4] = output->DataPtr(); + variant_pack[5] = stats->DataPtr(); + + void *workspace_ptr = nullptr; + std::shared_ptr workspace = nullptr; + if (cached.workspace_size > 0) { + workspace = std::make_shared(std::vector{cached.workspace_size}, DataType::kUINT8, device); + workspace_ptr = workspace->DataPtr(); + + //------modify-start------------------------------------------ + // Defensive initialization: some cuDNN plans may read workspace without full overwrite. 
+        CHECK(cudaMemsetAsync(workspace_ptr, 0, static_cast<size_t>(cached.workspace_size), cuda_stream)
+              == cudaSuccess);
+        //---------modify-end-----------------------------------------
+    }
+
+    auto exec_status = cached.graph->execute(handle, variant_pack, workspace_ptr);
+    CHECK(exec_status.is_good()) << "cudnn-frontend SDPA forward execute failed: " << exec_status.get_message();
+
+    return {output, stats};
+}
+
+std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>
+ScaledDotProductAttentionBackward(const std::shared_ptr<Tensor> &query, const std::shared_ptr<Tensor> &key,
+                                  const std::shared_ptr<Tensor> &value, const std::shared_ptr<Tensor> &output,
+                                  const std::shared_ptr<Tensor> &stats, const std::shared_ptr<Tensor> &grad_output,
+                                  float attn_scale, bool is_causal) {
+    CHECK(query->GetDevice().type() == Device::DeviceType::kCUDA);
+    //------modify-start------------------------------------------
+    // Avoid CHECK_EQ on enum class (would require operator<< overload).
+    CHECK(query->Dtype() == DataType::kBFLOAT16) << "SDPA backward only supports BF16";
+    //---------modify-end-----------------------------------------
+
+    const auto &q_dims = query->Dims();
+    CHECK_EQ(q_dims.size(), 4);
+    const int64_t b = q_dims[0];
+    const int64_t h = q_dims[1];
+    const int64_t s = q_dims[2];
+    const int64_t d = q_dims[3];
+
+    //------modify-start------------------------------------------
+    // Avoid CHECK_EQ on std::vector (requires operator<< overload).
+ CHECK(key->Dims() == q_dims); + CHECK(value->Dims() == q_dims); + CHECK(output->Dims() == q_dims); + CHECK(grad_output->Dims() == q_dims); + CHECK(stats->Dims() == (std::vector{b, h, s, 1})); + //---------modify-end----------------------------------------- + + auto device = query->GetDevice(); + core::DeviceGuard guard(device); + + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + + auto handle = GetCudnnHandle(cuda_stream); + + //------modify-start------------------------------------------ + // Cache cuDNN-frontend graphs by shape/scale for performance. + auto cached = GetOrCreateBwdGraph(handle, b, h, s, d, attn_scale, is_causal); + //---------modify-end----------------------------------------- + + auto grad_query = std::make_shared(q_dims, DataType::kBFLOAT16, device); + auto grad_key = std::make_shared(q_dims, DataType::kBFLOAT16, device); + auto grad_value = std::make_shared(q_dims, DataType::kBFLOAT16, device); + + //------modify-start------------------------------------------ + // Defensive initialization: ensure gradients are fully overwritten by cuDNN SDPA backward. 
+    CHECK(cudaMemsetAsync(grad_query->DataPtr(), 0, grad_query->SizeInBytes(), cuda_stream) == cudaSuccess);
+    CHECK(cudaMemsetAsync(grad_key->DataPtr(), 0, grad_key->SizeInBytes(), cuda_stream) == cudaSuccess);
+    CHECK(cudaMemsetAsync(grad_value->DataPtr(), 0, grad_value->SizeInBytes(), cuda_stream) == cudaSuccess);
+    //---------modify-end-----------------------------------------
+
+    std::unordered_map<int64_t, void *> variant_pack;
+    variant_pack.reserve(9);
+    // inputs
+    variant_pack[1] = query->DataPtr();
+    variant_pack[2] = key->DataPtr();
+    variant_pack[3] = value->DataPtr();
+    variant_pack[4] = output->DataPtr();
+    variant_pack[5] = stats->DataPtr();
+    variant_pack[101] = grad_output->DataPtr();
+    // outputs
+    variant_pack[102] = grad_query->DataPtr();
+    variant_pack[103] = grad_key->DataPtr();
+    variant_pack[104] = grad_value->DataPtr();
+
+    void *workspace_ptr = nullptr;
+    std::shared_ptr<Tensor> workspace = nullptr;
+    if (cached.workspace_size > 0) {
+        workspace = std::make_shared<Tensor>(std::vector<int64_t>{cached.workspace_size}, DataType::kUINT8, device);
+        workspace_ptr = workspace->DataPtr();
+
+        //------modify-start------------------------------------------
+        // Defensive initialization: some cuDNN plans may read workspace without full overwrite.
+ CHECK(cudaMemsetAsync(workspace_ptr, 0, static_cast(cached.workspace_size), cuda_stream) + == cudaSuccess); + //---------modify-end----------------------------------------- + } + + auto exec_status = cached.graph->execute(handle, variant_pack, workspace_ptr); + CHECK(exec_status.is_good()) << "cudnn-frontend SDPA backward execute failed: " << exec_status.get_message(); + + return {grad_query, grad_key, grad_value}; +} + +} // namespace infini_train::kernels::cuda + +#define REGISTER_CUDA_SDPA_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + +REGISTER_CUDA_SDPA_KERNEL(ScaledDotProductAttentionForward) +REGISTER_CUDA_SDPA_KERNEL(ScaledDotProductAttentionBackward) + +#undef REGISTER_CUDA_SDPA_KERNEL +//---------modify-end----------------------------------------- diff --git a/infini_train/src/nn/functional.cc b/infini_train/src/nn/functional.cc index b02f185a..68a64f9c 100644 --- a/infini_train/src/nn/functional.cc +++ b/infini_train/src/nn/functional.cc @@ -9,6 +9,10 @@ #include "infini_train/include/autograd/misc.h" #include "infini_train/include/autograd/reduction.h" #include "infini_train/include/autograd/softmax.h" + +//------modify-start------------------------------------------ +#include "infini_train/include/autograd/scaled_dot_product_attention.h" +//---------modify-end----------------------------------------- #include "infini_train/include/autograd/transform.h" #include "infini_train/include/nn/init.h" #include "infini_train/include/tensor.h" @@ -79,4 +83,20 @@ std::shared_ptr Softmax(const std::shared_ptr &input, int64_t di std::shared_ptr Sigmoid(const std::shared_ptr &input) { return std::make_shared()->Apply({input})[0]; } + +//------modify-start------------------------------------------ +std::shared_ptr ScaledDotProductAttention(const std::shared_ptr &query, + const std::shared_ptr &key, + const std::shared_ptr &value, + const std::shared_ptr &attn_mask, double 
dropout_p, + bool is_causal, std::optional scale, bool enable_gqa) { + // Match PyTorch semantics on the signature; currently we only support the common training case. + CHECK(attn_mask == nullptr) << "attn_mask is not supported in current SDPA backend"; + CHECK_EQ(dropout_p, 0.0) << "dropout is not supported in current SDPA backend"; + (void)enable_gqa; + return std::make_shared(dropout_p, is_causal, scale) + ->Apply({query, key, value})[0]; +} +//---------modify-end----------------------------------------- + } // namespace infini_train::nn::function diff --git a/scripts/flash_sdpa_benchmark.bash b/scripts/flash_sdpa_benchmark.bash new file mode 100644 index 00000000..c5d65880 --- /dev/null +++ b/scripts/flash_sdpa_benchmark.bash @@ -0,0 +1,137 @@ +#!/usr/bin/env bash +set -euo pipefail + +#------modify-start------------------------------------------ +# Benchmark baseline vs --flash for GPT-2 and LLaMA-3. +# Outputs logs + a parsed markdown report under a local artifact directory. +# +# Usage: +# bash scripts/flash_sdpa_benchmark.bash --gpu 5 --iters 30 --seq_len 256 +# +# Notes: +# - If /tmp is small/full, we redirect TMPDIR to ~/tmp. +# - Uses shared tinyshakespeare bins if present under /data/shared/InfiniTrain-dev. 
+#---------modify-end----------------------------------------- + +GPU=0 +ITERS=20 +SEQ_LEN=256 +OUT_DIR="tmp/flash_sdpa" + +GPT2_BIN_DEFAULT="/data/shared/InfiniTrain-dev/data/llmc/gpt2/tinyshakespeare/tiny_shakespeare_train.bin" +LLAMA3_BIN_DEFAULT="/data/shared/InfiniTrain-dev/data/llmc/llama3/tinyshakespeare/tiny_shakespeare_train.bin" +GPT2_BIN="$GPT2_BIN_DEFAULT" +LLAMA3_BIN="$LLAMA3_BIN_DEFAULT" + +while [[ $# -gt 0 ]]; do + case "$1" in + --gpu) GPU="$2"; shift 2;; + --iters) ITERS="$2"; shift 2;; + --seq_len) SEQ_LEN="$2"; shift 2;; + --out_dir) OUT_DIR="$2"; shift 2;; + --gpt2_bin) GPT2_BIN="$2"; shift 2;; + --llama3_bin) LLAMA3_BIN="$2"; shift 2;; + *) + echo "Unknown arg: $1" >&2 + exit 2 + ;; + esac +done + +REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +BUILD_DIR="$REPO_ROOT/build" + +mkdir -p "$REPO_ROOT/$OUT_DIR/logs" +mkdir -p "$REPO_ROOT/$OUT_DIR/env" +mkdir -p "$REPO_ROOT/tmp" + +export TMPDIR="$REPO_ROOT/tmp" +export TMP="$REPO_ROOT/tmp" +export TEMP="$REPO_ROOT/tmp" +export CUDA_VISIBLE_DEVICES="$GPU" + +ts="$(date +%Y%m%d_%H%M%S)" + +env_file="$REPO_ROOT/$OUT_DIR/env/env_$ts.txt" +{ + echo "# Environment" + echo "date: $(date)" + echo "hostname: $(hostname)" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" + echo + echo "## nvidia-smi" + nvidia-smi || true + echo + echo "## nvcc --version" + (command -v nvcc && nvcc --version) || (test -x /usr/local/cuda/bin/nvcc && /usr/local/cuda/bin/nvcc --version) || true + echo + echo "## cudnn version header" + (ls /usr/include/cudnn_version_v9.h >/dev/null 2>&1 && grep -n "CUDNN_MAJOR\|CUDNN_MINOR\|CUDNN_PATCHLEVEL" /usr/include/cudnn_version_v9.h) || true + echo + echo "## gcc/g++/cmake" + (command -v gcc && gcc --version | head -n 1) || true + (command -v g++ && g++ --version | head -n 1) || true + (command -v cmake && cmake --version | head -n 1) || true +} > "$env_file" + +function run_one() { + local name="$1"; shift + local log="$REPO_ROOT/$OUT_DIR/logs/${name}_$ts.log" + echo 
"[RUN] $name" | tee "$log" + echo "CMD: $*" | tee -a "$log" + "$@" 2>&1 | tee -a "$log" + echo "[DONE] $name" | tee -a "$log" +} + +echo "Using GPT2_BIN=$GPT2_BIN" +echo "Using LLAMA3_BIN=$LLAMA3_BIN" + +cd "$BUILD_DIR" + +# GPT-2 config +GPT2_COMMON=( + "./gpt2" + --model d12 + --input_bin "$GPT2_BIN" + --dtype bfloat16 + --batch_size 1 + --sequence_length "$SEQ_LEN" + --total_batch_size "$SEQ_LEN" + --num_iteration "$ITERS" + --freq_generate_txt 1000000 + --sample_every 0 + --val_loss_every 0 + --learning_rate 0 +) + +run_one "gpt2_baseline" "${GPT2_COMMON[@]}" --flash=false +run_one "gpt2_flash" "${GPT2_COMMON[@]}" --flash=true + +# LLaMA3 config +LLAMA3_COMMON=( + "./llama3" + --model llama3 + --input_bin "$LLAMA3_BIN" + --dtype bfloat16 + --batch_size 1 + --sequence_length "$SEQ_LEN" + --total_batch_size "$SEQ_LEN" + --num_iteration "$ITERS" + --freq_generate_txt 1000000 + --sample_every 0 + --val_loss_every 0 +) + +run_one "llama3_baseline" "${LLAMA3_COMMON[@]}" --flash=false +run_one "llama3_flash" "${LLAMA3_COMMON[@]}" --flash=true + +cd "$REPO_ROOT" +python3 scripts/flash_sdpa_parse.py \ + --out_dir "$OUT_DIR" \ + --timestamp "$ts" \ + --seq_len "$SEQ_LEN" \ + --iters "$ITERS" \ + --env_file "$env_file" + +echo +echo "Report generated: $REPO_ROOT/$OUT_DIR/report_$ts.md" diff --git a/scripts/flash_sdpa_parse.py b/scripts/flash_sdpa_parse.py new file mode 100644 index 00000000..990b9aa4 --- /dev/null +++ b/scripts/flash_sdpa_parse.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python3 + +# ------modify-start------------------------------------------ +# Parse InfiniTrain training logs produced by flash_sdpa_benchmark.bash. +# Extracts: +# - avg latency (ms) excluding step1 warmup +# - avg tokens/s excluding step1 warmup +# - peak used/reserved MB (max over steps) +# - loss (step1) +# Generates a markdown report and a CSV summary. 
+# ---------modify-end-----------------------------------------

+from __future__ import annotations
+
+import argparse
+import csv
+import pathlib
+import re
+from dataclasses import dataclass
+from typing import List, Optional
+
+
+STEP_RE = re.compile(
+    r"step\s+(?P<step>\d+)/(?P<total>\d+)\s+\|\s+train loss\s+(?P<loss>[-+\w\.]+)\s+\|\s+lr\s+(?P<lr>[-+\w\.eE]+)\s+\|\s+\((?P<ms>[0-9\.]+)\s+ms\s+\|\s+(?P<toks>[0-9\.]+)\s+tok/s\s+\|\s+peak used:\s+(?P<used>[0-9\.]+)\s+MB\s+\|\s+peak reserved:\s+(?P<reserved>[0-9\.]+)\s+MB"
+)
+
+
+@dataclass
+class RunMetrics:
+    name: str
+    steps: List[int]
+    losses: List[float]
+    ms: List[float]
+    toks: List[float]
+    peak_used_mb: float
+    peak_reserved_mb: float
+
+    @property
+    def loss_step1(self) -> Optional[float]:
+        if not self.steps:
+            return None
+        try:
+            idx = self.steps.index(1)
+        except ValueError:
+            return None
+        return self.losses[idx]
+
+    def avg_ms_excl_warmup(self) -> Optional[float]:
+        pairs = [(s, v) for s, v in zip(self.steps, self.ms) if s != 1]
+        if not pairs:
+            return None
+        return sum(v for _, v in pairs) / len(pairs)
+
+    def avg_toks_excl_warmup(self) -> Optional[float]:
+        pairs = [(s, v) for s, v in zip(self.steps, self.toks) if s != 1]
+        if not pairs:
+            return None
+        return sum(v for _, v in pairs) / len(pairs)
+
+
+def parse_log(path: pathlib.Path, name: str) -> RunMetrics:
+    steps: List[int] = []
+    losses: List[float] = []
+    ms: List[float] = []
+    toks: List[float] = []
+    peak_used = 0.0
+    peak_reserved = 0.0
+
+    for line in path.read_text(encoding="utf-8", errors="ignore").splitlines():
+        m = STEP_RE.search(line)
+        if not m:
+            continue
+        s = int(m.group("step"))
+        loss_str = m.group("loss")
+        if loss_str.lower() == "nan":
+            loss = float("nan")
+        else:
+            loss = float(loss_str)
+        steps.append(s)
+        losses.append(loss)
+        ms.append(float(m.group("ms")))
+        toks.append(float(m.group("toks")))
+        peak_used = max(peak_used, float(m.group("used")))
+        peak_reserved = max(peak_reserved, float(m.group("reserved")))
+
+    return RunMetrics(
+        name=name,
+        
steps=steps, + losses=losses, + ms=ms, + toks=toks, + peak_used_mb=peak_used, + peak_reserved_mb=peak_reserved, + ) + + +def fmt(x: Optional[float], nd: int = 2) -> str: + if x is None: + return "n/a" + return f"{x:.{nd}f}" + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--out_dir", required=True) + ap.add_argument("--timestamp", required=True) + ap.add_argument("--seq_len", type=int, required=True) + ap.add_argument("--iters", type=int, required=True) + ap.add_argument("--env_file", required=True) + args = ap.parse_args() + + repo_root = pathlib.Path(__file__).resolve().parents[1] + out_dir = repo_root / args.out_dir + logs_dir = out_dir / "logs" + + runs = {} + for key in ["gpt2_baseline", "gpt2_flash", "llama3_baseline", "llama3_flash"]: + log_path = logs_dir / f"{key}_{args.timestamp}.log" + if not log_path.exists(): + raise SystemExit(f"Missing log: {log_path}") + runs[key] = parse_log(log_path, key) + + # CSV summary + csv_path = out_dir / f"summary_{args.timestamp}.csv" + with csv_path.open("w", newline="", encoding="utf-8") as f: + w = csv.writer(f) + w.writerow( + [ + "run", + "avg_ms_excl_step1", + "avg_tok_s_excl_step1", + "peak_used_mb", + "peak_reserved_mb", + "loss_step1", + ] + ) + for k, r in runs.items(): + w.writerow( + [ + k, + r.avg_ms_excl_warmup(), + r.avg_toks_excl_warmup(), + r.peak_used_mb, + r.peak_reserved_mb, + r.loss_step1, + ] + ) + + def speedup(b: RunMetrics, f: RunMetrics) -> Optional[float]: + bm = b.avg_ms_excl_warmup() + fm = f.avg_ms_excl_warmup() + if bm is None or fm is None or fm == 0: + return None + return bm / fm + + def mem_saving_ratio(b: RunMetrics, f: RunMetrics) -> Optional[float]: + if b.peak_used_mb <= 0: + return None + return (b.peak_used_mb - f.peak_used_mb) / b.peak_used_mb + + # Markdown report + report_path = out_dir / f"report_{args.timestamp}.md" + env_text = pathlib.Path(args.env_file).read_text(encoding="utf-8", errors="ignore") + + gpt2_su = speedup(runs["gpt2_baseline"], 
runs["gpt2_flash"]) + llama3_su = speedup(runs["llama3_baseline"], runs["llama3_flash"]) + gpt2_mem = mem_saving_ratio(runs["gpt2_baseline"], runs["gpt2_flash"]) + llama3_mem = mem_saving_ratio(runs["llama3_baseline"], runs["llama3_flash"]) + + def loss_diff(a: RunMetrics, b: RunMetrics) -> Optional[float]: + if a.loss_step1 is None or b.loss_step1 is None: + return None + return abs(a.loss_step1 - b.loss_step1) + + gpt2_ld = loss_diff(runs["gpt2_baseline"], runs["gpt2_flash"]) + llama3_ld = loss_diff(runs["llama3_baseline"], runs["llama3_flash"]) + + lines = [] + lines.append(f"# Flash SDPA 性能与正确性报告 ({args.timestamp})") + lines.append("") + lines.append("## 实验配置") + lines.append(f"- seq_len: {args.seq_len}") + lines.append(f"- iters: {args.iters} (统计时排除 step1 warmup)") + lines.append("") + lines.append("## 环境信息") + lines.append("```text") + lines.append(env_text.strip()) + lines.append("```") + lines.append("") + + lines.append("## 指标汇总") + lines.append("") + lines.append( + "| Model | Variant | Avg latency (ms/step) | Avg tok/s | Peak used (MB) | Peak reserved (MB) | step1 loss |" + ) + lines.append("|---|---:|---:|---:|---:|---:|---:|") + + def row(model: str, variant: str, r: RunMetrics) -> str: + return ( + f"| {model} | {variant} | {fmt(r.avg_ms_excl_warmup())} | {fmt(r.avg_toks_excl_warmup())} | " + f"{fmt(r.peak_used_mb, 0)} | {fmt(r.peak_reserved_mb, 0)} | {fmt(r.loss_step1, 6)} |" + ) + + lines.append(row("GPT-2", "baseline", runs["gpt2_baseline"])) + lines.append(row("GPT-2", "flash", runs["gpt2_flash"])) + lines.append(row("LLaMA-3", "baseline", runs["llama3_baseline"])) + lines.append(row("LLaMA-3", "flash", runs["llama3_flash"])) + + lines.append("") + lines.append("## 对比结论") + lines.append("") + lines.append(f"- GPT-2 speedup = {fmt(gpt2_su, 3)}") + lines.append(f"- GPT-2 memory saving ratio = {fmt((gpt2_mem or 0.0) * 100.0, 2)}%") + lines.append(f"- GPT-2 |step1 loss diff| = {fmt(gpt2_ld, 6)}") + lines.append("") + lines.append(f"- LLaMA-3 speedup = 
{fmt(llama3_su, 3)}") + lines.append( + f"- LLaMA-3 memory saving ratio = {fmt((llama3_mem or 0.0) * 100.0, 2)}%" + ) + lines.append(f"- LLaMA-3 |step1 loss diff| = {fmt(llama3_ld, 6)}") + lines.append("") + lines.append("## 日志文件") + lines.append("") + for k in ["gpt2_baseline", "gpt2_flash", "llama3_baseline", "llama3_flash"]: + lines.append(f"- {args.out_dir}/logs/{k}_{args.timestamp}.log") + + report_path.write_text("\n".join(lines) + "\n", encoding="utf-8") + print(f"Wrote: {report_path}") + print(f"Wrote: {csv_path}") + + +if __name__ == "__main__": + main()