From 77030ca3e974cae0d978fe4631fe5c85785d3f69 Mon Sep 17 00:00:00 2001 From: JYMiracle305 <604951424@qq.com> Date: Fri, 20 Mar 2026 15:54:20 +0800 Subject: [PATCH] feat: extract the common module of Transformer --- CMakeLists.txt | 2 + example/gpt2/main.cc | 10 +- example/gpt2/net.cc | 430 +++------------ example/gpt2/net.h | 150 ----- example/llama3/main.cc | 8 +- example/llama3/net.cc | 511 ++---------------- example/llama3/net.h | 189 ------- .../decode_only_transformer/layer_specs.h | 12 + .../models/decode_only_transformer/model.h | 85 +++ .../include/core/transformer/spec_utils.h | 116 ++++ .../core/transformer/transformer_block.h | 120 ++++ .../core/transformer/transformer_builders.h | 50 ++ .../core/transformer/transformer_config.h | 83 +++ .../core/transformer/transformer_layer.h | 82 +++ .../decode_only_transformer/layer_specs.cc | 64 +++ .../src/core/transformer/spec_utils.cc | 43 ++ .../src/core/transformer/transformer_block.cc | 433 +++++++++++++++ .../core/transformer/transformer_builders.cc | 166 ++++++ .../src/core/transformer/transformer_layer.cc | 250 +++++++++ .../transformer_spec/test_transformer_spec.cc | 329 +++++++++++ 20 files changed, 1962 insertions(+), 1171 deletions(-) delete mode 100644 example/gpt2/net.h delete mode 100644 example/llama3/net.h create mode 100644 infini_train/include/core/models/decode_only_transformer/layer_specs.h create mode 100644 infini_train/include/core/models/decode_only_transformer/model.h create mode 100644 infini_train/include/core/transformer/spec_utils.h create mode 100644 infini_train/include/core/transformer/transformer_block.h create mode 100644 infini_train/include/core/transformer/transformer_builders.h create mode 100644 infini_train/include/core/transformer/transformer_config.h create mode 100644 infini_train/include/core/transformer/transformer_layer.h create mode 100644 infini_train/src/core/models/decode_only_transformer/layer_specs.cc create mode 100644 
infini_train/src/core/transformer/spec_utils.cc create mode 100644 infini_train/src/core/transformer/transformer_block.cc create mode 100644 infini_train/src/core/transformer/transformer_builders.cc create mode 100644 infini_train/src/core/transformer/transformer_layer.cc create mode 100644 test/transformer_spec/test_transformer_spec.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index df636b27..b6e86478 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -204,3 +204,5 @@ link_infini_train_exe(test_precision_check) add_executable(test_lora test/lora/test_lora.cc) link_infini_train_exe(test_lora) +add_executable(test_transformer_spec test/transformer_spec/test_transformer_spec.cc) +link_infini_train_exe(test_transformer_spec) diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index 8e28af52..8a611d39 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -10,7 +10,9 @@ #include "glog/logging.h" #include "infini_train/include/autocast.h" +#include "infini_train/include/core/models/decode_only_transformer/model.h" #include "infini_train/include/core/runtime/device_guard.h" +#include "infini_train/include/core/transformer/transformer_config.h" #include "infini_train/include/dataloader.h" #include "infini_train/include/device.h" #include "infini_train/include/nn/lora/lora_utils.h" @@ -35,7 +37,6 @@ #include "example/common/tiny_shakespeare_dataset.h" #include "example/common/tokenizer.h" -#include "example/gpt2/net.h" // I/O DEFINE_string(input_bin, "", "input .bin to train on"); @@ -100,7 +101,7 @@ constexpr char kDtypeFP32[] = "float32"; constexpr char kDtypeBF16[] = "bfloat16"; // -const std::unordered_map kModelToConfigs = { +const std::unordered_map kModelToConfigs = { {"d12", {.block_size = 1024, .vocab_size = 50257, .n_layer = 12, .n_head = 12, .n_embd = 768}}, {"d24", {.block_size = 1024, .vocab_size = 50257, .n_layer = 24, .n_head = 16, .n_embd = 1024}}, {"d36", {.block_size = 1024, .vocab_size = 50257, .n_layer = 36, .n_head = 20, .n_embd = 
1280}}, @@ -187,11 +188,12 @@ void Train(const nn::parallel::Rank &rank) { // ManualSeed(42); // init the model, either from scratch or from OpenAI pretrained checkpoint - GPT2Config model_config; + nn::TransformerConfig model_config = nn::TransformerConfig::GPT2(); std::shared_ptr model = nullptr; if (!FLAGS_llmc_filepath.empty()) { - model = GPT2::FromLLMC(FLAGS_llmc_filepath); + auto gpt2_model = GPT2::FromLLMC(FLAGS_llmc_filepath); + model = gpt2_model; } else if (kModelToConfigs.count(FLAGS_model)) { model_config = kModelToConfigs.at(FLAGS_model); model = std::make_shared(model_config); diff --git a/example/gpt2/net.cc b/example/gpt2/net.cc index d000d1cf..d4966cb5 100644 --- a/example/gpt2/net.cc +++ b/example/gpt2/net.cc @@ -1,10 +1,7 @@ -#include "example/gpt2/net.h" - #include #include #include #include -#include #include #include #include @@ -14,20 +11,14 @@ #include "glog/logging.h" #include "example/common/utils.h" -#include "infini_train/include/device.h" -#include "infini_train/include/nn/functional.h" -#include "infini_train/include/nn/init.h" -#include "infini_train/include/nn/modules/container.h" -#include "infini_train/include/nn/modules/linear.h" -#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/core/models/decode_only_transformer/model.h" +#include "infini_train/include/core/transformer/transformer_block.h" +#include "infini_train/include/core/transformer/transformer_config.h" #include "infini_train/include/nn/modules/normalization.h" #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/pp/pipeline_parallel.h" -#include "infini_train/include/nn/parallel/process_group.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" -#include "infini_train/include/nn/parallel/utils.h" -#include "infini_train/include/tensor.h" using namespace infini_train; namespace nn = infini_train::nn; @@ -39,297 +30,6 @@ constexpr 
int kRandomSeed = 42; static std::mt19937 gen{kRandomSeed}; } // namespace -std::vector> -NewGELU::Forward(const std::vector> &x) { - auto &input = x[0]; - return {0.5 * input - * (1.0 + nn::function::Tanh(std::sqrt(2.0 / M_PI) * (input + 0.044715 * nn::function::Pow(input, 3.0))))}; -} - -CausalSelfAttention::CausalSelfAttention(const GPT2Config &config) - : CloneableModule(kType), config_(config), n_head_(config.n_head), n_embd_(config.n_embd) { - auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); - CHECK_EQ(config.n_embd % config.n_head, 0); - CHECK_EQ(n_head_ % tp_world_size, 0) << "n_head must be divisible by TP world size"; - local_n_head_ = n_head_ / tp_world_size; - - // qkv: ColumnParallel (do not gather output) -> each tp_rank gets 3 * (n_embd / tp_world) channels - modules_[kCAttnLayerName] = std::make_shared( - /*in_features=*/n_embd_, - /*out_features=*/3 * n_embd_, - /*bias=*/true, - /*gather_output=*/false, - /*input_is_parallel=*/false, - /*skip_bias_add=*/false, - /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); - - // proj: RowParallel (input is parallel and output is full) - modules_[kCProjLayerName] = std::make_shared( - /*in_features=*/n_embd_, - /*out_features=*/n_embd_, - /*bias=*/true, - /*reduce_output=*/true, - /*input_is_parallel=*/true, - /*skip_bias_add=*/false, - /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); - - // causal mask: (1, 1, block_size, block_size) - buffers_[kParamBiasName] = nn::function::Tril(nn::function::Ones({config_.block_size, config_.block_size})) - ->View({1, 1, config_.block_size, config_.block_size}); -} - -std::vector> -CausalSelfAttention::Forward(const std::vector> &x) { - auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); - - const auto B = x[0]->Dims()[0]; // bs - const auto C = x[0]->Dims()[2]; // n_embd - const int64_t head_dim = n_embd_ / n_head_; // per-head dim (global) - const int64_t local_C = n_embd_ / 
tp_world_size; // per-rank hidden - - // (B, T, C) -> ColumnParallelLinear(C, 3*C) -> (B, T, 3 * local_C) - // -> Split -> (3, B, T, local_C) - auto qkv = (*modules_[kCAttnLayerName])(x)[0]->Split(local_C, 2); - - // (B, T, local_C) - auto q = qkv[0]; - auto k = qkv[1]; - auto v = qkv[2]; - - // NOTE(zbl): Acquire full T after AllGather is performed in ColumnParallelLinear - const auto T = q->Dims()[1]; - - // View to multi-head: local_n_head * head_dim == local_C - // (B, T, local_C) -> (B, T, h_l, Dh) -> (B, h_l, T, Dh) - k = k->View({B, T, local_n_head_, head_dim})->Transpose(1, 2); - q = q->View({B, T, local_n_head_, head_dim})->Transpose(1, 2); - v = v->View({B, T, local_n_head_, head_dim})->Transpose(1, 2); - - // (B, h_l, T, T) - auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim)); - // (1, 1, T, T) - auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1}); - // (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T) - att = att->MaskedFill(mask == 0, -std::numeric_limits::infinity()); - // (B, h_l, T, T) - att = nn::function::Softmax(att, -1); - // (B, h_l, T, Dh) - auto y = att->Matmul(v); - // (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C) - y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C}); - - // Get full tensor - // (B, T, local_C) -> RowParallelLinear(n_embd, n_embd) -> (B, T, C) - y = (*modules_[kCProjLayerName])({y})[0]; - // (B, T, C) == (bs, seq_len, n_embd) - return {y}; -} - -MLP::MLP(const GPT2Config &config) : CloneableModule(kType) { - // c_fc: ColumnParallel (input full, output parallel) - modules_[kCFcLayerName] = std::make_shared( - /*in_features=*/config.n_embd, /*out_features=*/4 * config.n_embd, - /*bias=*/true, - /*gather_output=*/false, - /*input_is_parallel=*/false, - /*skip_bias_add=*/false, - /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); - - modules_[kGeluLayerName] = std::make_shared(); - - // c_proj: RowParallel (input 
parallel, output full) - modules_[kCProjLayerName] = std::make_shared( - /*in_features=*/4 * config.n_embd, /*out_features=*/config.n_embd, - /*bias=*/true, - /*reduce_output=*/true, - /*input_is_parallel=*/true, - /*skip_bias_add=*/false, - /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); -} - -std::vector> -MLP::Forward(const std::vector> &x) { - // (B, T, C) -> ColumnParallelLinear(C, 4 * C) -> (B, T, 4 * C_local) - auto x1 = (*modules_[kCFcLayerName])(x); - // (B, T, 4 * C_local) -> GELU -> (B, T, 4 * C_local) - auto x2 = (*modules_[kGeluLayerName])(x1); - // (B, T, 4 * C_local) -> RowParallelLinear(4 * C, C) -> (B, T, C) - auto x3 = (*modules_[kCProjLayerName])(x2); - // (B, T, C) - return x3; -} - -Block::Block(const GPT2Config &config) : CloneableModule(kType) { - modules_[kLn1LayerName] = std::make_shared(std::vector{config.n_embd}); - modules_[kAttnLayerName] = std::make_shared(config); - modules_[kLn2LayerName] = std::make_shared(std::vector{config.n_embd}); - modules_[kMlpLayerName] = std::make_shared(config); -} - -std::vector> -Block::Forward(const std::vector> &x) { - // (bs, seq_len, n_embd) -> Layernorm -> (bs, seq_len, n_embd) -> attention -> (bs, seq_len, n_embd) - // -> Add -> (bs, seq_len, n_embd) - auto x1 = x[0] + (*modules_[kAttnLayerName])((*modules_[kLn1LayerName])(x))[0]; - // (bs, seq_len, n_embd) -> Layernorm -> (bs, seq_len, n_embd) -> MLP -> (bs, seq_len, n_embd) - // -> Add -> (bs, seq_len, n_embd) - auto x2 = x1 + (*modules_[kMlpLayerName])((*modules_[kLn2LayerName])({x1}))[0]; - // (bs, seq_len, n_embd) - return {x2}; -} - -GPT2FirstStage::GPT2FirstStage(const GPT2Config &config) : CloneableModule(kType), config_(config) { - modules_[kWTELayerName] = std::make_shared( - config_.vocab_size, config_.n_embd, nn::parallel::global::GetSequenceParallelEnabled()); - modules_[kWPELayerName] = std::make_shared(config_.block_size, config_.n_embd); -} - -std::vector> -GPT2FirstStage::Forward(const std::vector> &input) 
{ - // (B, T) - auto x1 = input[0]; - CHECK_LE(x1->Dims()[1], config_.block_size) - << "Cannot forward sequence of length " << x1->Dims()[1] << ", block size is only " << config_.block_size; - const auto device = x1->GetDevice(); - - // (T_local) - // NOTE(zbl): Slice pos sequence when SP is enabled - auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); - auto sequence_parallel_enabled = nn::parallel::global::GetSequenceParallelEnabled(); - int tp_rank = 0; - if (tp_world_size > 1) { - auto tp_group = nn::parallel::ProcessGroupFactory::Instance(device.type()) - ->Get(nn::parallel::GetTensorParallelProcessGroupName(device.Rank().GlobalRank())); - tp_rank = tp_group->GetGroupRank(device.Rank().GlobalRank()); - } - int64_t t_local = sequence_parallel_enabled ? x1->Dims()[1] / tp_world_size : x1->Dims()[1]; - int64_t start = sequence_parallel_enabled ? tp_rank * t_local : 0; - auto pos = nn::init::Arange(start, start + t_local, infini_train::DataType::kINT64, device); - - // (B, T) -> Embedding(V_local, C) -> (B, T, C) - auto tok_emb = (*modules_[kWTELayerName])({x1})[0]; - - // (T) -> Embedding(T_max, C) -> (T, C) - auto pos_emb = (*modules_[kWPELayerName])({pos})[0]; - // (B, T, C) - return {tok_emb + pos_emb}; -} - -GPT2Chunk::GPT2Chunk(const GPT2Config &config, int start_layer, int end_layer) - : CloneableModule(kType), config_(config) { - std::vector> h; - for (int64_t i = start_layer; i < end_layer; ++i) { - auto layer = std::make_shared(config); - h.push_back(layer); - } - modules_[kHLayerName] = std::make_shared(std::move(h)); -} - -std::vector> -GPT2Chunk::Forward(const std::vector> &x) { - auto x1 = x[0]; - // (bs, seq_len, n_embd) -> transformer -> (bs, seq_len, n_embd) - for (auto &h : *std::dynamic_pointer_cast(modules_[kHLayerName])) { x1 = (*h)({x1})[0]; } - return {x1}; -} - -GPT2LastStage::GPT2LastStage(const GPT2Config &config) : CloneableModule(kType), config_(config) { - modules_[kLnFLayerName] = 
std::make_shared(std::vector{config_.n_embd}); - // don't init this one, we will tie weights - modules_[kLMHeadLayerName] = std::make_shared( - /*in_features=*/config_.n_embd, /*out_features=*/config_.vocab_size, - /*bias=*/false, - // NOTE(zbl): each tp_rank would get sharded [B, T, V_local] as logits - /*gather_output=*/false, - /*input_is_parallel=*/false, - /*skip_bias_add=*/false, - /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); -} - -std::vector> -GPT2LastStage::Forward(const std::vector> &x) { - // (B, T, C) -> Layernorm -> (B, T, C) - auto x1 = (*modules_[kLnFLayerName])(x); - - // TODO(dcj): add inference-time mini-optimization - // (B, T, C) -> Linear(C, V) -> (B, T, V) - return (*modules_[kLMHeadLayerName])(x1); -} - -GPT2::GPT2(const GPT2Config &config) - : CloneableModule(kType), config_(config), - stage_info_(nn::parallel::PipelineParallel::GetStageInfo( - config_.n_layer, nn::parallel::global::GetPipelineParallelSize(), nn::parallel::pp_rank, - nn::parallel::global::GetVirtualPipelineParallelSize())) { - auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); - - // NOTE(zbl): VocabParallelEmbedding requires vocab_size % tp_size == 0 - // Megatron-LM has an optional argument `--make-vocab-size-divisible-by`, would do padding to vocab - // Here we introduce padding by default, might need modify Tokenizer correspondingly later - CHECK_EQ(config.vocab_size % tp_world_size, 0) << "Vocab size should be divisible by TP world size"; - - std::unordered_map> transformer; - if (stage_info_.is_first_stage) { - modules_[kPPFirstStageName] = std::make_shared(config_); - transformer[GPT2FirstStage::kWTELayerName] - = modules_[kPPFirstStageName]->mutable_module(GPT2FirstStage::kWTELayerName); - transformer[GPT2FirstStage::kWPELayerName] - = modules_[kPPFirstStageName]->mutable_module(GPT2FirstStage::kWPELayerName); - } - - { - std::map>> start_layer_to_layer_size_and_chunk; - for (int chunk_idx = 0; chunk_idx < 
stage_info_.layer_ranges_per_chunk.size(); ++chunk_idx) { - const auto [start_layer, end_layer] = stage_info_.layer_ranges_per_chunk[chunk_idx]; - auto chunk = std::make_shared(config_, start_layer, end_layer); - start_layer_to_layer_size_and_chunk[start_layer] = std::make_pair(end_layer - start_layer, chunk); - } - std::vector> h; - int chunk_idx = 0; - for (auto &[start_layer, layer_size_and_chunk] : start_layer_to_layer_size_and_chunk) { - auto [layer_size, chunk] = layer_size_and_chunk; - for (int idx = 0; idx < layer_size; ++idx) { - h.push_back(chunk->mutable_module(GPT2Chunk::kHLayerName)->mutable_module(std::to_string(idx))); - } - modules_[kPPChunkNamePrefix + std::to_string(chunk_idx)] = std::move(chunk); - ++chunk_idx; - } - transformer[GPT2Chunk::kHLayerName] = std::make_shared(std::move(h)); - } - - if (stage_info_.is_last_stage) { - modules_[kPPLastStageName] = std::make_shared(config_); - transformer[GPT2LastStage::kLnFLayerName] - = modules_[kPPLastStageName]->mutable_module(GPT2LastStage::kLnFLayerName); - modules_[GPT2LastStage::kLMHeadLayerName] - = modules_[kPPLastStageName]->mutable_module(GPT2LastStage::kLMHeadLayerName); - } - modules_[kTransformerLayerName] = std::make_shared(std::move(transformer)); - - // FIXME(jym): Assigning the parameter values of wte to LMHead, which is not real tying operation - // TODO: Implement real GPT-2 weight tying: make lm_head.weight share the exact same Parameter/Tensor (same - // shared_ptr/storage) as transformer.wte.weight (pointer aliasing, not value copy), and ensure the tie is applied - // after loading weights so it won't be overwritten. Also fix GPT2::FromLLMC() loading logic to respect weight tying - // (do not create/load a separate lm_head.weight tensor; load once into the tied weight) so parameter counting - // matches PyTorch/PEFT. 
- if (nn::parallel::global::GetPipelineParallelSize() == 1) { - // https://paperswithcode.com/method/weight-tying - *mutable_module(kTransformerLayerName) - ->mutable_module(GPT2FirstStage::kWTELayerName) - ->mutable_parameter(nn::parallel::VocabParallelEmbedding::kParamWeightName) - = module(GPT2LastStage::kLMHeadLayerName).parameter(nn::parallel::ColumnParallelLinear::kParamWeightName); - } -} - -std::vector> -GPT2::Forward(const std::vector> &x) { - auto x1 = (*modules_[kPPFirstStageName])(x); - for (int chunk_idx = 0; chunk_idx < stage_info_.layer_ranges_per_chunk.size(); ++chunk_idx) { - x1 = (*modules_[kPPChunkNamePrefix + std::to_string(chunk_idx)])(x1); - } - return (*modules_[kPPLastStageName])(x1); -} - std::shared_ptr GPT2::FromPretrained(ModelType model_type) { // TODO(dcj): implement this later LOG(FATAL) << "Not implemented yet"; @@ -379,12 +79,15 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { const auto padded_vocab_size = BytesToType(header, 28); // NOTE(zbl): vocab_size needs to be padded to multiple of TP size const auto model_vocab_size = tp_size > 1 ? 
padded_vocab_size : vocab_size; - auto local_gpt2 = std::make_shared(GPT2Config{.block_size = block_size, - .vocab_size = model_vocab_size, - .original_vocab_size = vocab_size, - .n_layer = n_layer, - .n_head = n_head, - .n_embd = n_embd}); + + nn::TransformerConfig gpt2_config = nn::TransformerConfig::GPT2(); + gpt2_config.block_size = block_size; + gpt2_config.vocab_size = model_vocab_size; + gpt2_config.original_vocab_size = vocab_size; + gpt2_config.n_layer = n_layer; + gpt2_config.n_head = n_head; + gpt2_config.n_embd = n_embd; + auto local_gpt2 = std::make_shared(gpt2_config); LOG(INFO) << "magic: " << magic << " version: " << version << " block_size: " << block_size << " vocab_size: " << vocab_size << " n_layer: " << n_layer << " n_head: " << n_head @@ -430,12 +133,12 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { // local: (vocab_size_per_partition, n_embd) if (is_first_stage) { auto &transformer_wte_weight - = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, GPT2FirstStage::kWTELayerName, + = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, nn::TransformerFirstStage::kWTELayerName, nn::parallel::VocabParallelEmbedding::kParamWeightName)]; ReadMatrixRowShardFloat(ifs, static_cast(transformer_wte_weight->DataPtr()), model_vocab_size, n_embd, v_start, vpp); } else if (pp_size > 1 && is_last_stage) { - auto &lm_head_weight = state_dict[std::format("{}.{}", GPT2LastStage::kLMHeadLayerName, + auto &lm_head_weight = state_dict[std::format("{}.{}", nn::TransformerLastStage::kLMHeadLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; ReadMatrixRowShardFloat(ifs, static_cast(lm_head_weight->DataPtr()), model_vocab_size, n_embd, v_start, vpp); @@ -451,8 +154,9 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { if (is_first_stage) { // transformer.wpe.weight - auto &transformer_wpe_weight = state_dict[std::format( - "{}.{}.{}", GPT2::kTransformerLayerName, GPT2FirstStage::kWPELayerName, 
nn::Embedding::kParamWeightName)]; + auto &transformer_wpe_weight + = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, nn::TransformerFirstStage::kWPELayerName, + nn::Embedding::kParamWeightName)]; ReadMatrixAllFloat(ifs, static_cast(transformer_wpe_weight->DataPtr()), block_size, n_embd); } else { size_t wpe_bytes = block_size * n_embd * sizeof(float); @@ -463,9 +167,10 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { int local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { if (owned_layers[idx]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2Chunk::kHLayerName, - std::to_string(local_layer_index), Block::kLn1LayerName, - nn::LayerNorm::kParamWeightName)]; + auto &tensor + = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, + nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), + nn::TransformerBlock::kLn1LayerName, nn::LayerNorm::kParamWeightName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); ++local_layer_index; } else { @@ -478,9 +183,9 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { if (owned_layers[idx]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2Chunk::kHLayerName, - std::to_string(local_layer_index), Block::kLn1LayerName, - nn::LayerNorm::kParamBiasName)]; + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, + nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), + nn::TransformerBlock::kLn1LayerName, nn::LayerNorm::kParamBiasName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); ++local_layer_index; } else { @@ -493,10 +198,10 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { if (owned_layers[idx]) { - auto &tensor = 
state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, - GPT2Chunk::kHLayerName, std::to_string(local_layer_index), - Block::kAttnLayerName, CausalSelfAttention::kCAttnLayerName, - nn::parallel::ColumnParallelLinear::kParamWeightName)]; + auto &tensor = state_dict[std::format( + "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, nn::TransformerChunk::kHLayerName, + std::to_string(local_layer_index), nn::TransformerBlock::kAttnLayerName, + nn::CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; // NOTE(zbl): In the .bin model file, Q/K/V is concated along last dim, // i.e. [Q|K|V].T = [q1|q2|...|qn|k1|k2|...|kn|v1|v2|...|vn].T // However, each tp_rank needs to get [q_i|k_i|v_i].T, so we need to jump and read them @@ -536,10 +241,10 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { if (owned_layers[idx]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, - GPT2Chunk::kHLayerName, std::to_string(local_layer_index), - Block::kAttnLayerName, CausalSelfAttention::kCAttnLayerName, - nn::parallel::ColumnParallelLinear::kParamBiasName)]; + auto &tensor = state_dict[std::format( + "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, nn::TransformerChunk::kHLayerName, + std::to_string(local_layer_index), nn::TransformerBlock::kAttnLayerName, + nn::CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamBiasName)]; // NOTE(zbl): Same as c_attn.weight, the bias for Q/K/V is concated // i.e. 
[Q|K|V] = [q1|q2|...|qn|k1|k2|...|kn|v1|v2|...|vn] // However, each tp_rank needs to get [q_i|k_i|v_i], so we need to jump and read them @@ -578,10 +283,10 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { if (owned_layers[idx]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, - GPT2Chunk::kHLayerName, std::to_string(local_layer_index), - Block::kAttnLayerName, CausalSelfAttention::kCProjLayerName, - nn::parallel::RowParallelLinear::kParamWeightName)]; + auto &tensor = state_dict[std::format( + "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, nn::TransformerChunk::kHLayerName, + std::to_string(local_layer_index), nn::TransformerBlock::kAttnLayerName, + nn::CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)]; ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), n_embd, n_embd, tp_rank * in_pp, in_pp); ++local_layer_index; @@ -595,10 +300,10 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { if (owned_layers[idx]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, - GPT2Chunk::kHLayerName, std::to_string(local_layer_index), - Block::kAttnLayerName, CausalSelfAttention::kCProjLayerName, - nn::parallel::RowParallelLinear::kParamBiasName)]; + auto &tensor = state_dict[std::format( + "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, nn::TransformerChunk::kHLayerName, + std::to_string(local_layer_index), nn::TransformerBlock::kAttnLayerName, + nn::CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamBiasName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); ++local_layer_index; } else { @@ -611,9 +316,10 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { if 
(owned_layers[idx]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2Chunk::kHLayerName, - std::to_string(local_layer_index), Block::kLn2LayerName, - nn::LayerNorm::kParamWeightName)]; + auto &tensor + = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, + nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), + nn::TransformerBlock::kLn2LayerName, nn::LayerNorm::kParamWeightName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); ++local_layer_index; } else { @@ -626,9 +332,9 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { if (owned_layers[idx]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2Chunk::kHLayerName, - std::to_string(local_layer_index), Block::kLn2LayerName, - nn::LayerNorm::kParamBiasName)]; + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, + nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), + nn::TransformerBlock::kLn2LayerName, nn::LayerNorm::kParamBiasName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); ++local_layer_index; } else { @@ -641,10 +347,10 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { if (owned_layers[idx]) { - auto &tensor - = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2Chunk::kHLayerName, - std::to_string(local_layer_index), Block::kMlpLayerName, MLP::kCFcLayerName, - nn::parallel::ColumnParallelLinear::kParamWeightName)]; + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, + nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), + nn::TransformerBlock::kMlpLayerName, nn::MLP::kCFcLayerName, + nn::parallel::ColumnParallelLinear::kParamWeightName)]; 
ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), fc_out, n_embd, fc_start, fc_pp); ++local_layer_index; } else { @@ -657,10 +363,10 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { if (owned_layers[idx]) { - auto &tensor - = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2Chunk::kHLayerName, - std::to_string(local_layer_index), Block::kMlpLayerName, MLP::kCFcLayerName, - nn::parallel::ColumnParallelLinear::kParamBiasName)]; + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, + nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), + nn::TransformerBlock::kMlpLayerName, nn::MLP::kCFcLayerName, + nn::parallel::ColumnParallelLinear::kParamBiasName)]; ReadVectorShardFloat(ifs, static_cast(tensor->DataPtr()), fc_out, fc_start, fc_pp); ++local_layer_index; } else { @@ -673,10 +379,10 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { if (owned_layers[idx]) { - auto &tensor - = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2Chunk::kHLayerName, - std::to_string(local_layer_index), Block::kMlpLayerName, MLP::kCProjLayerName, - nn::parallel::RowParallelLinear::kParamWeightName)]; + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, + nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), + nn::TransformerBlock::kMlpLayerName, nn::MLP::kCProjLayerName, + nn::parallel::RowParallelLinear::kParamWeightName)]; ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), n_embd, fc_out, tp_rank * in4_pp, in4_pp); ++local_layer_index; @@ -690,10 +396,10 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { if (owned_layers[idx]) { - auto &tensor - = 
state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2Chunk::kHLayerName, - std::to_string(local_layer_index), Block::kMlpLayerName, MLP::kCProjLayerName, - nn::parallel::RowParallelLinear::kParamBiasName)]; + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, + nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), + nn::TransformerBlock::kMlpLayerName, nn::MLP::kCProjLayerName, + nn::parallel::RowParallelLinear::kParamBiasName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); ++local_layer_index; } else { @@ -704,12 +410,14 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { if (is_last_stage) { // transformer.ln_f.weight - auto &transformer_ln_f_weight = state_dict[std::format( - "{}.{}.{}", GPT2::kTransformerLayerName, GPT2LastStage::kLnFLayerName, nn::LayerNorm::kParamWeightName)]; + auto &transformer_ln_f_weight + = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, nn::TransformerLastStage::kLnFLayerName, + nn::LayerNorm::kParamWeightName)]; ReadVectorAllFloat(ifs, static_cast(transformer_ln_f_weight->DataPtr()), n_embd); // transformer.ln_f.bias - auto &transformer_ln_f_bias = state_dict[std::format( - "{}.{}.{}", GPT2::kTransformerLayerName, GPT2LastStage::kLnFLayerName, nn::LayerNorm::kParamBiasName)]; + auto &transformer_ln_f_bias + = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, nn::TransformerLastStage::kLnFLayerName, + nn::LayerNorm::kParamBiasName)]; ReadVectorAllFloat(ifs, static_cast(transformer_ln_f_bias->DataPtr()), n_embd); } else { size_t ln_f_w_bytes = n_embd * sizeof(float); diff --git a/example/gpt2/net.h b/example/gpt2/net.h deleted file mode 100644 index 4faf5451..00000000 --- a/example/gpt2/net.h +++ /dev/null @@ -1,150 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -#include "glog/logging.h" - -#include "infini_train/include/nn/modules/module.h" -#include 
"infini_train/include/nn/parallel/pp/pipeline_parallel.h" -#include "infini_train/include/nn/parallel/pp/pipeline_stage.h" -#include "infini_train/include/tensor.h" - -struct GPT2Config { - int64_t block_size = 1024; - int64_t vocab_size = 50304; - int64_t original_vocab_size = 50257; - int64_t n_layer = 12; - int64_t n_head = 12; - int64_t n_embd = 768; -}; - -class NewGELU : public infini_train::nn::CloneableModule { -public: - static constexpr char kType[] = "NewGELU"; - NewGELU() : CloneableModule(kType) {} - - std::vector> - Forward(const std::vector> &x) override; -}; - -class CausalSelfAttention : public infini_train::nn::CloneableModule { -public: - static constexpr char kType[] = "CausalSelfAttention"; - static constexpr char kCAttnLayerName[] = "c_attn"; - static constexpr char kCProjLayerName[] = "c_proj"; - - static constexpr char kParamBiasName[] = "bias"; - - explicit CausalSelfAttention(const GPT2Config &config); - - std::vector> - Forward(const std::vector> &x) override; - -private: - GPT2Config config_; - int64_t n_head_ = 0; - int64_t n_embd_ = 0; - - int64_t local_n_head_ = 0; -}; - -class MLP : public infini_train::nn::CloneableModule { -public: - static constexpr char kType[] = "MLP"; - static constexpr char kCFcLayerName[] = "c_fc"; - static constexpr char kGeluLayerName[] = "gelu"; - static constexpr char kCProjLayerName[] = "c_proj"; - - explicit MLP(const GPT2Config &config); - - std::vector> - Forward(const std::vector> &x) override; -}; - -class Block : public infini_train::nn::CloneableModule { -public: - static constexpr char kType[] = "Block"; - static constexpr char kLn1LayerName[] = "ln_1"; - static constexpr char kAttnLayerName[] = "attn"; - static constexpr char kLn2LayerName[] = "ln_2"; - static constexpr char kMlpLayerName[] = "mlp"; - - explicit Block(const GPT2Config &config); - - std::vector> - Forward(const std::vector> &x) override; -}; - -class GPT2FirstStage : public infini_train::nn::CloneableModule { -public: - static 
constexpr char kType[] = "GPT2FirstStage"; - static constexpr char kWTELayerName[] = "wte"; - static constexpr char kWPELayerName[] = "wpe"; - - explicit GPT2FirstStage(const GPT2Config &config); - - std::vector> - Forward(const std::vector> &x) override; - -private: - const GPT2Config config_; -}; - -class GPT2Chunk : public infini_train::nn::CloneableModule { -public: - static constexpr char kType[] = "GPT2Chunk"; - static constexpr char kHLayerName[] = "h"; - - GPT2Chunk(const GPT2Config &config, int start_layer, int end_layer); - - std::vector> - Forward(const std::vector> &x) override; - -private: - const GPT2Config config_; -}; - -class GPT2LastStage : public infini_train::nn::CloneableModule { -public: - static constexpr char kType[] = "GPT2LastStage"; - static constexpr char kLnFLayerName[] = "ln_f"; - static constexpr char kLMHeadLayerName[] = "lm_head"; - - explicit GPT2LastStage(const GPT2Config &config); - - std::vector> - Forward(const std::vector> &x) override; - -private: - const GPT2Config config_; -}; - -class GPT2 : public infini_train::nn::CloneableModule { -public: - static constexpr char kType[] = "GPT2"; - static constexpr char kTransformerLayerName[] = "transformer"; - - enum class ModelType : int8_t { - kGPT2, - kGPT2Medium, - kGPT2Large, - kGPT2XL, - }; - - explicit GPT2(const GPT2Config &config); - - std::vector> - Forward(const std::vector> &x) override; - - static std::shared_ptr FromPretrained(ModelType model_type); - static std::shared_ptr FromLLMC(const std::string &filepath); - - int GetChunkSize() const; - -private: - const GPT2Config config_; - const infini_train::nn::parallel::StageInfo stage_info_; -}; diff --git a/example/llama3/main.cc b/example/llama3/main.cc index acc20ac4..a4581b2b 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -8,7 +8,9 @@ #include "glog/logging.h" #include "infini_train/include/autocast.h" +#include "infini_train/include/core/models/decode_only_transformer/model.h" #include 
"infini_train/include/core/runtime/device_guard.h" +#include "infini_train/include/core/transformer/transformer_config.h" #include "infini_train/include/dataloader.h" #include "infini_train/include/device.h" #include "infini_train/include/nn/lora/lora_utils.h" @@ -34,7 +36,6 @@ #include "example/common/tiny_shakespeare_dataset.h" #include "example/common/tokenizer.h" -#include "example/llama3/net.h" // I/O DEFINE_string(input_bin, "", "input .bin to train on"); @@ -167,10 +168,11 @@ void Train(const nn::parallel::Rank &rank) { // rng / reproducibility // ManualSeed(42); - LLaMA3Config model_config = LLaMA3Config(); + nn::TransformerConfig model_config = nn::TransformerConfig::LLaMA3(); std::shared_ptr model = nullptr; if (!FLAGS_llmc_filepath.empty()) { - model = LLaMA3::FromLLMC(FLAGS_llmc_filepath); + auto llama3_model = LLaMA3::FromLLMC(FLAGS_llmc_filepath); + model = llama3_model; } else { model = std::make_shared(model_config); } diff --git a/example/llama3/net.cc b/example/llama3/net.cc index a50fb831..d57646dc 100644 --- a/example/llama3/net.cc +++ b/example/llama3/net.cc @@ -1,5 +1,3 @@ -#include "example/llama3/net.h" - #include #include #include @@ -14,18 +12,12 @@ #include "glog/logging.h" #include "example/common/utils.h" +#include "infini_train/include/core/models/decode_only_transformer/model.h" +#include "infini_train/include/core/transformer/spec_utils.h" +#include "infini_train/include/core/transformer/transformer_block.h" +#include "infini_train/include/core/transformer/transformer_config.h" #include "infini_train/include/device.h" -#include "infini_train/include/nn/functional.h" -#include "infini_train/include/nn/init.h" -#include "infini_train/include/nn/modules/container.h" -#include "infini_train/include/nn/modules/linear.h" -#include "infini_train/include/nn/modules/module.h" -#include "infini_train/include/nn/modules/normalization.h" -#include "infini_train/include/nn/modules/sparse.h" -#include "infini_train/include/nn/parallel/global.h" 
-#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" -#include "infini_train/include/tensor.h" using namespace infini_train; namespace nn = infini_train::nn; @@ -37,415 +29,6 @@ constexpr int kRandomSeed = 42; static std::mt19937 gen{kRandomSeed}; } // namespace -namespace { -// Used in Grouped Query Attention(GQA), broadcasts the key and value tensors -// FIXME(zbl): implement Expand() instead of using RepeatInterleave() -std::shared_ptr RepeatKV(const std::shared_ptr &x, int64_t n_rep) { - const auto &shape = x->Dims(); - const int64_t B = shape[0], T = shape[1], H = shape[2], D = shape[3]; - if (n_rep == 1) { - return x; - } - return x->View({B, T, H, 1, D})->RepeatInterleave(n_rep, 3)->Contiguous()->View({B, T, H * n_rep, D}); -} - -// ----------------------------------------------------------------- -// RoPE related -// NOTE(zbl): this RoPE implementation has no "learnable" params, as is stated in LLaMA paper -std::shared_ptr ReshapeForBroadcast(const std::shared_ptr &freqs_cis, - const std::shared_ptr &x) { - // freqs_cis: (T, D / 2, 2) - CHECK(freqs_cis != nullptr) << "freqs_cis is null."; - const auto &x_shape = x->Dims(); // (B, T, H, D) - CHECK_GE(x_shape.size(), 4); - const int64_t T = x_shape[1]; - const int64_t D = x_shape[3]; - CHECK_EQ(freqs_cis->Dims()[0], x_shape[1]); - CHECK_EQ(freqs_cis->Dims()[1], x_shape[3] / 2); - std::vector target_shape = {1, T, 1, D / 2, 2}; - return freqs_cis->View(target_shape); -} - -// TODO(zbl): ApplyScaling(const std::shared_ptr &) when use_scaled -// std::shared_ptr ApplyScaling(const std::shared_ptr &freqs, float old_context_len = 8192) {} - -std::tuple, std::shared_ptr> -ApplyRotaryEmbedding(const std::shared_ptr &xq, const std::shared_ptr &xk, - const std::shared_ptr &freqs_cis) { - // Shape assumptions: xq: (B, T, H, D) - auto cos_sin = ReshapeForBroadcast(freqs_cis, xq); // -> (1, T, 1, D/2, 2) - std::vector 
target_shape(cos_sin->Dims().begin(), cos_sin->Dims().end() - 1); - auto cos = cos_sin->Slice(-1, 0, 1, 1)->Squeeze(-1); // (1, T, 1, D/2) - auto sin = cos_sin->Slice(-1, 1, 2, 1)->Squeeze(-1); // (1, T, 1, D/2) - - auto slice_pair = [](const std::shared_ptr &x) { - auto even = x->Slice(-1, 0, x->Dims().back(), 2); - auto odd = x->Slice(-1, 1, x->Dims().back(), 2); - return std::make_pair(even, odd); - }; - - auto [q_even, q_odd] = slice_pair(xq); - auto q_rotated_left = q_even * cos - q_odd * sin; - auto q_rotated_right = q_even * sin + q_odd * cos; - auto q_rotated - = nn::function::Stack(std::vector>{q_rotated_left, q_rotated_right}, -1)->Flatten(-2); - - auto [k_even, k_odd] = slice_pair(xk); - auto k_rotated_left = k_even * cos - k_odd * sin; - auto k_rotated_right = k_even * sin + k_odd * cos; - auto k_rotated - = nn::function::Stack(std::vector>{k_rotated_left, k_rotated_right}, -1)->Flatten(-2); - - return {q_rotated, k_rotated}; -} - -std::shared_ptr PrecomputeFreqsCis(int64_t dim, int64_t end, float theta = 10000.0f, bool use_scaled = false, - infini_train::Device device = Device()) { - DataType dtype = DataType::kFLOAT32; - CHECK_GE(dim, 2) << "dim must be >= 2 for slicing"; - auto arange = nn::init::Arange(0, dim, dtype, device)->Slice(0, 0, dim, 2); - auto freqs = 1.0f / nn::function::Pow(theta, arange / float(dim)); - // TODO(zbl): use_scaled - // if (use_scaled) { - // freqs = ApplyScaling(freqs, 8192.0f); - // } - auto t = nn::init::Arange(0, end, dtype, device); - // (end, dim / 2) - auto freqs_outer = t->Outer(freqs); - auto cos = nn::function::Cos(freqs_outer); - auto sin = nn::function::Sin(freqs_outer); - // NOTE(zbl): torch script uses cis expression, here use stack - // (end, dim / 2, 2) - auto freqs_cis = nn::function::Stack(std::vector>{cos, sin}, -1)->Contiguous(); - return freqs_cis; -} - -} // namespace - -std::vector> SwiGLU::Forward(const std::vector> &x) { - return {x[0] * nn::function::Sigmoid(x[0])}; -} - -RMSNorm::RMSNorm(int64_t 
dim, float eps, infini_train::Device device) : CloneableModule(kType), eps_(eps) { - parameters_[kParamWeightName] - = std::make_shared(std::vector{dim}, DataType::kFLOAT32, device)->RequiresGrad(); - nn::init::Ones(parameters_[kParamWeightName]); -} - -std::vector> RMSNorm::Forward(const std::vector> &x) { - // broadcasted Mul([4, 64, 2048] * [4, 64, 1]) - auto norm = x[0] * nn::function::Rsqrt(nn::function::Mean(nn::function::Pow(x[0], 2), -1, true) + eps_); - return {norm * parameters_[kParamWeightName]}; -} - -CausalSelfAttention::CausalSelfAttention(const LLaMA3Config &config) - : CloneableModule(kType), config_(config), n_head_(config.n_head), n_embd_(config.n_embd), - n_kv_head_(config.n_kv_head), n_rep_(config.n_head / config.n_kv_head), head_dim_(config.n_embd / config.n_head) { - CHECK_LE(config.n_kv_head, config.n_head); - CHECK_EQ(config.n_head % config.n_kv_head, 0); - CHECK_EQ(config.n_embd % config.n_head, 0); - - int64_t qkv_dim = (config.n_head + 2 * n_kv_head_) * head_dim_; - // qkv: ColumnParallel (do not gather output) - modules_[kCAttnLayerName] = std::make_shared( - /*in_features=*/n_embd_, - /*out_features=*/qkv_dim, - /*bias=*/false, - /*gather_output=*/false, - /*input_is_parallel=*/false, - /*skip_bias_add=*/false, - /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); - - // proj: RowParallel (input is parallel and output is full) - modules_[kCProjLayerName] = std::make_shared( - /*in_features=*/n_embd_, - /*out_features=*/n_embd_, - /*bias=*/false, - /*reduce_output=*/true, - /*input_is_parallel=*/true, - /*skip_bias_add=*/false, - /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); -} - -std::vector> CausalSelfAttention::Forward(const std::vector> &x) { - const auto B = x[0]->Dims()[0]; // bs - const auto C = x[0]->Dims()[2]; // n_embd - - const auto tp_size = nn::parallel::global::GetTensorParallelSize(); - - const auto C_local = C / tp_size; - const auto H_local = n_head_ / tp_size; - const 
auto KV_local = n_kv_head_ / tp_size; - const auto D = head_dim_; // n_embd / n_head - - const auto freqs_cis = x.size() > 1 ? x[1] : nullptr; - const auto start_pos = x.size() > 2 ? x[2] : nullptr; - const auto mask = x.size() > 3 ? x[3] : nullptr; - CHECK(freqs_cis != nullptr) << "freqs_cis is null."; - - // (B, T, C) -> (B, T, (H + 2 * n_kv_head) * D) - auto qkv = (*modules_[kCAttnLayerName])({x[0]})[0]; - // NOTE(zbl): Acquire full T after AllGather is performed in ColumnParallelLinear - const auto T = qkv->Dims()[1]; - // NOTE(zbl): torch script uses torch.split({...}, dim) to split tensors into sub-tensors in different sizes - // use Slice() to work around here - const int64_t q_size_local = H_local * D; - const int64_t kv_size_local = KV_local * D; - // -> Split into q, k, v - // q: (B, T, H_local, D) - auto q = qkv->Slice(2, 0, q_size_local)->View({B, T, H_local, D}); - // k: (B, T, KV_local, D) - auto k = qkv->Slice(2, q_size_local, q_size_local + kv_size_local)->View({B, T, KV_local, D}); - // v: (B, T, KV_local, D) - auto v = qkv->Slice(2, q_size_local + kv_size_local, q_size_local + 2 * kv_size_local)->View({B, T, KV_local, D}); - - // -> RoPE on q, k - // q: (B, T, H_local, D) - // k: (B, T, KV_local, D) - std::tie(q, k) = ApplyRotaryEmbedding(q, k, freqs_cis); - - // TODO(zbl): use kv cache during inference - // if (use_kv_) { ... } - - // align n_head in GQA - // (B, T, KV_local, D) -> (B, T, H_local, D) via RepeatKV - k = RepeatKV(k, n_rep_); - v = RepeatKV(v, n_rep_); - - // (B, T, H_local, D) -> (B, H_local, T, D) - q = q->Transpose(1, 2); - k = k->Transpose(1, 2); - v = v->Transpose(1, 2); - - // TODO(zbl): support flash attention later - // if (flash_) { ... 
} - - // manual implementation of attention - // this materializes the large (T,T) matrix for all the queries and keys - - // q: (B, H_local, T, D) - // k: (B, H_local, T, D) -> (B, H_local, D, T) - // q @ k.T: (B, H_local, T, T) -> mul 1.0 / sqrt(D) -> (B, H_local, T, T) - auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast(D))); - if (mask) { - // mask: (1, 1, T, T) - att = att->MaskedFill(mask, std::numeric_limits::lowest()); - } - // (B, H_local, T, T) - att = nn::function::Softmax(att, -1); - // att: (B, H_local, T, T) @ v: (B, H_local, T, D) -> y: (B, H_local, T, D) - auto y = att->Matmul(v); - // (B, H_local, T, D) -> Transpose(1, 2) -> (B, T, H_local, D) -> (B, T, C_local) - y = y->Transpose(1, 2)->Contiguous()->View({B, T, C_local}); - // output projection - // (B, T, C_local) -> RowParallelLinear(C, C) -> (B, T, C) - y = (*modules_[kCProjLayerName])({y})[0]; - // (B, H, C) == (bs, seq_len, n_embd) - return {y}; -} - -MLP::MLP(const LLaMA3Config &config) : CloneableModule(kType) { - hidden_dim_ = 4 * config.n_embd; - hidden_dim_ = int(2 * hidden_dim_ / 3); - // use custom dim factor multiplier - if (config.ffn_dim_multiplier.has_value()) { - hidden_dim_ = int(config.ffn_dim_multiplier.value() * hidden_dim_); - } - hidden_dim_ = config.multiple_of * ((hidden_dim_ + config.multiple_of - 1) / config.multiple_of); - - // c_fc: ColumnParallel (input full, output parallel) - modules_[kCFcLayerName] = std::make_shared( - /*in_features=*/config.n_embd, /*out_features=*/hidden_dim_, - /*bias=*/false, - /*gather_output=*/false, - /*input_is_parallel=*/false, - /*skip_bias_add=*/false, - /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); - - // c_fc2: ColumnParallel (input full, output parallel) - modules_[kCFc2LayerName] = std::make_shared( - /*in_features=*/config.n_embd, /*out_features=*/hidden_dim_, - /*bias=*/false, - /*gather_output=*/false, - /*input_is_parallel=*/false, - /*skip_bias_add=*/false, - 
/*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); - - modules_[kSiluLayerName] = std::make_shared(); - - // c_proj: RowParallel (input parallel, output full) - modules_[kCProjLayerName] = std::make_shared( - /*in_features=*/hidden_dim_, /*out_features=*/config.n_embd, - /*bias=*/false, - /*reduce_output=*/true, - /*input_is_parallel=*/true, - /*skip_bias_add=*/false, - /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); -} - -std::vector> MLP::Forward(const std::vector> &x) { - // (bs, seq_len, n_embd) -> Linear(n_embd, hidden_dim) -> (bs, seq_len, hidden_dim) - auto x1 = (*modules_[kCFcLayerName])(x)[0]; - // (bs, seq_len, n_embd) -> Linear(n_embd, hidden_dim) -> (bs, seq_len, hidden_dim) - auto x2 = (*modules_[kCFc2LayerName])(x)[0]; - // (bs, seq_len, hidden_dim) -> SwiGLU -> (bs, seq_len, hidden_dim) - x2 = (*modules_[kSiluLayerName])({x2})[0]; - // (bs, seq_len, hidden_dim) - auto x3 = x1 * x2; - // (bs, seq_len, hidden_dim) -> Linear(hidden_dim, n_embd) -> (bs, seq_len, n_embd) - auto x4 = (*modules_[kCProjLayerName])({x3}); - // (bs, seq_len, n_embd) - return x4; -} - -Block::Block(const LLaMA3Config &config) : CloneableModule(kType) { - modules_[kLn1LayerName] = std::make_shared(config.n_embd, config.norm_eps); - modules_[kAttnLayerName] = std::make_shared(config); - modules_[kLn2LayerName] = std::make_shared(config.n_embd, config.norm_eps); - modules_[kMlpLayerName] = std::make_shared(config); -} - -std::vector> Block::Forward(const std::vector> &x) { - const auto freqs_cis = x.size() > 1 ? x[1] : nullptr; - const auto start_pos = x.size() > 2 ? x[2] : nullptr; - const auto mask = x.size() > 3 ? 
x[3] : nullptr; - - // (bs, seq_len, n_embd) -> RMSNorm -> (bs, seq_len, n_embd) -> attention -> (bs, seq_len, n_embd) - // -> Add -> (bs, seq_len, n_embd) - auto x1 = x[0] - + (*modules_[kAttnLayerName])(std::vector>{(*modules_[kLn1LayerName])({x[0]})[0], - freqs_cis, start_pos, mask})[0]; - // (bs, seq_len, n_embd) -> RMSNorm -> (bs, seq_len, n_embd) -> MLP -> (bs, seq_len, n_embd) - // -> Add -> (bs, seq_len, n_embd) - auto x2 - = x1 + (*modules_[kMlpLayerName])(std::vector>((*modules_[kLn2LayerName])({x1})))[0]; - // (bs, seq_len, n_embd) - return {x2}; -} - -LLaMA3FirstStage::LLaMA3FirstStage(const LLaMA3Config &config) : CloneableModule(kType), config_(config) { - modules_[LLaMA3FirstStage::kWTELayerName] = std::make_shared( - config.vocab_size, config.n_embd, nn::parallel::global::GetSequenceParallelEnabled()); -} - -std::vector> LLaMA3FirstStage::Forward(const std::vector> &x) { - return (*modules_[LLaMA3FirstStage::kWTELayerName])(x); -} - -LLaMA3Chunk::LLaMA3Chunk(const LLaMA3Config &config, int start_layer, int end_layer) - : CloneableModule(kType), config_(config) { - std::vector> h; - for (int64_t i = start_layer; i < end_layer; ++i) { - auto layer = std::make_shared(config); - h.push_back(layer); - } - modules_[LLaMA3Chunk::kHLayerName] = std::make_shared(std::move(h)); -} - -std::vector> LLaMA3Chunk::Forward(const std::vector> &x) { - auto x1 = x[0]; - const auto device = x1->GetDevice(); - // Init freqs_cis on device only once - // TODO(zbl): consider moving this to model construction - if (buffers_[kFreqsCisName] == nullptr) { - buffers_[kFreqsCisName] = PrecomputeFreqsCis(config_.n_embd / config_.n_head, config_.block_size * 2, - config_.rope_theta, config_.use_scaled_rope, device); - } - - // TODO(dcj): check if this shape is correct - const auto t = x1->Dims()[1] * nn::parallel::global::GetSequenceParallelSize(); // full_seq_len - - // TODO(zbl): dynamic start_pos - int64_t start_pos = 0; - auto freqs_view = buffers_[kFreqsCisName]->Slice(0, 
start_pos, start_pos + t, 1); - - // TODO(lzm): add dtype support for nn::function::Ones later - std::shared_ptr ones = std::make_shared(nn::function::Ones({t, t})->To(x1->GetDevice())); - std::shared_ptr mask = nn::function::Triu(ones, 1)->View({1, 1, t, t}); - - std::shared_ptr start_pos_ptr = nullptr; - - // (bs, seq_len, n_embd) -> transformer -> (bs, seq_len, n_embd) - for (auto &h : *std::dynamic_pointer_cast(modules_[LLaMA3Chunk::kHLayerName])) { - x1 = (*h)({x1, freqs_view, start_pos_ptr, mask})[0]; - } - return {x1}; -} - -LLaMA3LastStage::LLaMA3LastStage(const LLaMA3Config &config) : CloneableModule(kType), config_(config) { - modules_[kLnFLayerName] = std::make_shared(config.n_embd, config.norm_eps); - // NOTE(zbl): weight-tying is possible but torch script did not do so - modules_[kLMHeadLayerName] = std::make_shared( - /*in_features=*/config_.n_embd, /*out_features=*/config_.vocab_size, - /*bias=*/false, - // NOTE(zbl): each rank would get sharded [B, T, V_local] as logits - /*gather_output=*/false, - /*input_is_parallel=*/false, - /*skip_bias_add=*/false, - /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); -} - -std::vector> LLaMA3LastStage::Forward(const std::vector> &x) { - // (bs, seq_len, n_embd) -> RMSNorm -> (bs, seq_len, n_embd) - auto x1 = (*modules_[kLnFLayerName])(x); - - // TODO(zbl): add inference-time mini-optimization - // (bs, seq_len, n_embd) -> Linear(n_embd, vocab_size) -> (bs, seq_len, vocab_size) - return (*modules_[kLMHeadLayerName])(x1); -} - -LLaMA3::LLaMA3(const LLaMA3Config &config) - : CloneableModule(kType), config_(config), - stage_info_(nn::parallel::PipelineParallel::GetStageInfo( - config_.n_layer, nn::parallel::global::GetPipelineParallelSize(), nn::parallel::pp_rank, - nn::parallel::global::GetVirtualPipelineParallelSize())) { - std::unordered_map> transformer; - if (stage_info_.is_first_stage) { - modules_[kPPFirstStageName] = std::make_shared(config_); - 
transformer[LLaMA3FirstStage::LLaMA3FirstStage::kWTELayerName] - = modules_[kPPFirstStageName]->mutable_module(LLaMA3FirstStage::LLaMA3FirstStage::kWTELayerName); - } - - { - std::map>> start_layer_to_layer_size_and_chunk; - for (int chunk_idx = 0; chunk_idx < stage_info_.layer_ranges_per_chunk.size(); ++chunk_idx) { - const auto [start_layer, end_layer] = stage_info_.layer_ranges_per_chunk[chunk_idx]; - auto chunk = std::make_shared(config_, start_layer, end_layer); - start_layer_to_layer_size_and_chunk[start_layer] = std::make_pair(end_layer - start_layer, chunk); - } - std::vector> h; - int chunk_idx = 0; - for (auto &[start_layer, layer_size_and_chunk] : start_layer_to_layer_size_and_chunk) { - auto [layer_size, chunk] = layer_size_and_chunk; - for (int idx = 0; idx < layer_size; ++idx) { - h.push_back( - chunk->mutable_module(LLaMA3Chunk::LLaMA3Chunk::kHLayerName)->mutable_module(std::to_string(idx))); - } - modules_[kPPChunkNamePrefix + std::to_string(chunk_idx)] = std::move(chunk); - ++chunk_idx; - } - transformer[LLaMA3Chunk::LLaMA3Chunk::kHLayerName] = std::make_shared(std::move(h)); - } - - if (stage_info_.is_last_stage) { - modules_[kPPLastStageName] = std::make_shared(config_); - transformer[LLaMA3LastStage::kLnFLayerName] - = modules_[kPPLastStageName]->mutable_module(LLaMA3LastStage::kLnFLayerName); - // NOTE(zbl): weight-tying is possible but torch script did not do so - modules_[LLaMA3LastStage::kLMHeadLayerName] - = modules_[kPPLastStageName]->mutable_module(LLaMA3LastStage::kLMHeadLayerName); - } - modules_[kTransformerLayerName] = std::make_shared(std::move(transformer)); -} - -std::vector> LLaMA3::Forward(const std::vector> &x) { - auto x1 = (*modules_[kPPFirstStageName])({x[0]}); - for (int chunk_idx = 0; chunk_idx < stage_info_.layer_ranges_per_chunk.size(); ++chunk_idx) { - x1 = (*modules_[kPPChunkNamePrefix + std::to_string(chunk_idx)])(x1); - } - return (*modules_[kPPLastStageName])(x1); -} - std::shared_ptr LLaMA3::FromPretrained(ModelType 
model_type) { // TODO(zbl): implement this later LOG(FATAL) << "Not implemented yet"; @@ -485,18 +68,14 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { const auto version_major = BytesToType(header, 56); const auto version_minor = BytesToType(header, 60); - auto llama3 = std::make_shared(LLaMA3Config{.block_size = block_size, - .vocab_size = vocab_size, - .n_layer = n_layer, - .n_head = n_head, - .n_kv_head = n_kv_head, - .n_embd = n_embd, - .ffn_dim_multiplier = ffn_dim_multiplier, - .multiple_of = multiple_of, - .rope_theta = rope_theta, - .use_scaled_rope = static_cast(use_scaled_rope), - .norm_eps = norm_eps, - .max_gen_batch_size = max_gen_bs}); + nn::TransformerConfig llama3_config = nn::TransformerConfig::LLaMA3(); + llama3_config.block_size = block_size; + llama3_config.vocab_size = vocab_size; + llama3_config.n_layer = n_layer; + llama3_config.n_head = n_head; + llama3_config.n_kv_head = n_kv_head; + llama3_config.n_embd = n_embd; + auto llama3 = std::make_shared(llama3_config); // ========== pp_size:num_stages; vpp_size: num_chunks_per_stage ========== int pp_size = nn::parallel::global::GetPipelineParallelSize(); @@ -544,7 +123,7 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { const int64_t head_dim = static_cast(n_embd) / static_cast(n_head); - // MLP hidden dim calculation in LLaMA-3 + // nn::MLP hidden dim calculation in LLaMA-3 auto round_up_to = [](int64_t x, int64_t m) { return (x + m - 1) / m * m; }; int64_t hidden_dim = 4LL * static_cast(n_embd); hidden_dim = (2LL * hidden_dim) / 3LL; @@ -574,7 +153,7 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { // RowParallel (proj) const int64_t in_pp = static_cast(n_embd) / tp_size; - // MLP: c_fc/c_fc2(shard along row),c_proj(shard along col) + // nn::MLP: c_fc/c_fc2(shard along row),c_proj(shard along col) const int64_t fc_out = ffn_hidden; const int64_t fc_pp = fc_out / tp_size; const int64_t in_fc_pp = ffn_hidden / tp_size; @@ -584,7 +163,7 @@ 
std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { // ========== Read Sharded Params ========== // transformer.wte.weight : (vocab_size, n_embd) -> local tp_rank: rows of [v_start : v_start+vpp) if (is_first_stage) { - auto &wte = state_dict[std::format("{}.{}.{}", kTransformerLayerName, LLaMA3FirstStage::kWTELayerName, + auto &wte = state_dict[std::format("{}.{}.{}", kTransformerLayerName, nn::TransformerFirstStage::kWTELayerName, nn::parallel::VocabParallelEmbedding::kParamWeightName)]; ReadMatrixRowShardFloat(ifs, static_cast(wte->DataPtr()), /*rows=*/vocab_size, /*cols=*/n_embd, @@ -594,13 +173,13 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { ifs.seekg(wte_bytes, std::ios::cur); } - // transformer.h.{i}.ln_1.weight : Full version RMSNorm + // transformer.h.{i}.ln_1.weight : Full version nn::RMSNorm int local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { if (owned_layers[i]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerLayerName, LLaMA3Chunk::kHLayerName, - std::to_string(local_layer_index), Block::kLn1LayerName, - RMSNorm::kParamWeightName)]; + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerLayerName, + nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), + nn::TransformerBlock::kLn1LayerName, nn::RMSNorm::kParamWeightName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); ++local_layer_index; } else { @@ -614,10 +193,10 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { if (owned_layers[i]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerLayerName, LLaMA3Chunk::kHLayerName, - std::to_string(local_layer_index), Block::kAttnLayerName, - CausalSelfAttention::kCAttnLayerName, - nn::parallel::ColumnParallelLinear::kParamWeightName)]; + auto &tensor = state_dict[std::format( + "{}.{}.{}.{}.{}.{}", kTransformerLayerName, 
nn::TransformerChunk::kHLayerName, + std::to_string(local_layer_index), nn::TransformerBlock::kAttnLayerName, + nn::CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; float *dst = static_cast(tensor->DataPtr()); const std::streampos base_pos = ifs.tellg(); @@ -654,10 +233,10 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { if (owned_layers[i]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerLayerName, LLaMA3Chunk::kHLayerName, - std::to_string(local_layer_index), Block::kAttnLayerName, - CausalSelfAttention::kCProjLayerName, - nn::parallel::RowParallelLinear::kParamWeightName)]; + auto &tensor = state_dict[std::format( + "{}.{}.{}.{}.{}.{}", kTransformerLayerName, nn::TransformerChunk::kHLayerName, + std::to_string(local_layer_index), nn::TransformerBlock::kAttnLayerName, + nn::CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)]; ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), /*rows=*/n_embd, /*cols=*/n_embd, /*col_start=*/tp_rank * in_pp, /*col_cnt=*/in_pp); @@ -672,9 +251,9 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { if (owned_layers[i]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerLayerName, LLaMA3Chunk::kHLayerName, - std::to_string(local_layer_index), Block::kLn2LayerName, - RMSNorm::kParamWeightName)]; + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerLayerName, + nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), + nn::TransformerBlock::kLn2LayerName, nn::RMSNorm::kParamWeightName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); ++local_layer_index; } else { @@ -687,9 +266,10 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int 
i = 0; i < static_cast(n_layer); ++i) { if (owned_layers[i]) { - auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", kTransformerLayerName, LLaMA3Chunk::kHLayerName, std::to_string(local_layer_index), - Block::kMlpLayerName, MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; + auto &tensor + = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerLayerName, nn::TransformerChunk::kHLayerName, + std::to_string(local_layer_index), nn::TransformerBlock::kMlpLayerName, + nn::MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), /*rows=*/fc_out, /*cols=*/n_embd, /*row_start=*/tp_rank * fc_pp, /*row_cnt=*/fc_pp); @@ -704,9 +284,10 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { if (owned_layers[i]) { - auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", kTransformerLayerName, LLaMA3Chunk::kHLayerName, std::to_string(local_layer_index), - Block::kMlpLayerName, MLP::kCFc2LayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerLayerName, + nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), + nn::TransformerBlock::kMlpLayerName, nn::MLP::kCFc2LayerName, + nn::parallel::ColumnParallelLinear::kParamWeightName)]; ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), /*rows=*/fc_out, /*cols=*/n_embd, /*row_start=*/tp_rank * fc_pp, /*row_cnt=*/fc_pp); @@ -721,9 +302,10 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { if (owned_layers[i]) { - auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", kTransformerLayerName, LLaMA3Chunk::kHLayerName, std::to_string(local_layer_index), - Block::kMlpLayerName, MLP::kCProjLayerName, 
nn::parallel::RowParallelLinear::kParamWeightName)]; + auto &tensor + = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerLayerName, nn::TransformerChunk::kHLayerName, + std::to_string(local_layer_index), nn::TransformerBlock::kMlpLayerName, + nn::MLP::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)]; ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), /*rows=*/n_embd, /*cols=*/fc_out, /*col_start=*/tp_rank * in_fc_pp, /*col_cnt=*/in_fc_pp); @@ -734,13 +316,14 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { } } - // transformer.ln_f.weight : Full version RMSNorm + // transformer.ln_f.weight : Full version nn::RMSNorm // lm_head.weight : (vocab_size, n_embd) -> ColumnParallelLinear, but actually applies on "rows" { if (is_last_stage) { - auto &ln_f = state_dict[std::format("{}.{}.{}", kTransformerLayerName, LLaMA3LastStage::kLnFLayerName, - RMSNorm::kParamWeightName)]; - auto &lm_head = state_dict[std::format("{}.{}", LLaMA3LastStage::kLMHeadLayerName, + auto &ln_f + = state_dict[std::format("{}.{}.{}", kTransformerLayerName, nn::TransformerLastStage::kLnFLayerName, + nn::RMSNorm::kParamWeightName)]; + auto &lm_head = state_dict[std::format("{}.{}", nn::TransformerLastStage::kLMHeadLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; ReadVectorAllFloat(ifs, static_cast(ln_f->DataPtr()), n_embd); ReadMatrixRowShardFloat(ifs, static_cast(lm_head->DataPtr()), diff --git a/example/llama3/net.h b/example/llama3/net.h deleted file mode 100644 index 4496a68d..00000000 --- a/example/llama3/net.h +++ /dev/null @@ -1,189 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -#include "glog/logging.h" - -#include "infini_train/include/device.h" -#include "infini_train/include/nn/modules/module.h" -#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h" -#include "infini_train/include/tensor.h" - -struct LLaMA3Config { - // ref: https://huggingface.co/meta-llama/Llama-3.2-1B - 
// Model basic config - int64_t block_size = 8192; // Max seq_len - int64_t vocab_size = 128256; // Vocab size - int64_t n_layer = 16; // Num of transformer layers - int64_t n_head = 32; // Num of heads in MHA - int64_t n_kv_head = 8; // Num of Key/Value heads(< n_head if using GQA) - int64_t n_embd = 2048; // Hidden size - - // FFN config - std::optional ffn_dim_multiplier = 1.5f; // FFN dim multiplier - int64_t multiple_of = 256; // FFN dims must be multiple of this number - - // Pos embedding - float rope_theta = 500000.0f; // theta in RoPE - bool use_scaled_rope = false; // scaled RoPE - - // RMSNorm - float norm_eps = 1e-5f; // epsilon in RMSNorm - - // Inference - bool use_kv = false; // kv cache - bool flash = false; // flash attention - int64_t max_gen_batch_size = 4; // max batch size during inference -}; - -class SwiGLU : public infini_train::nn::CloneableModule { -public: - static constexpr char kType[] = "SwiGLU"; - SwiGLU() : CloneableModule(kType) {} - - std::vector> - Forward(const std::vector> &x) override; -}; - -// TODO(zbl): implement fused kernel -class RMSNorm : public infini_train::nn::CloneableModule { -public: - static constexpr char kType[] = "RMSNorm"; - static constexpr char kParamWeightName[] = "weight"; - - explicit RMSNorm(int64_t dim, float eps = 1e-6f, infini_train::Device device = infini_train::Device()); - - std::vector> - Forward(const std::vector> &x) override; - -private: - float eps_ = 1e-5f; -}; - -class CausalSelfAttention : public infini_train::nn::CloneableModule { -public: - static constexpr char kType[] = "CausalSelfAttention"; - static constexpr char kCAttnLayerName[] = "c_attn"; - static constexpr char kCProjLayerName[] = "c_proj"; - - explicit CausalSelfAttention(const LLaMA3Config &config); - - std::vector> - Forward(const std::vector> &x) override; - -private: - LLaMA3Config config_; - int64_t n_head_ = 0; - int64_t n_embd_ = 0; - int64_t n_kv_head_ = 0; - int64_t n_rep_ = 0; - int64_t head_dim_ = 0; -}; - -class MLP 
: public infini_train::nn::CloneableModule { -public: - static constexpr char kType[] = "MLP"; - static constexpr char kCFcLayerName[] = "c_fc"; - static constexpr char kCFc2LayerName[] = "c_fc2"; - static constexpr char kSiluLayerName[] = "silu"; - static constexpr char kCProjLayerName[] = "c_proj"; - - explicit MLP(const LLaMA3Config &config); - - std::vector> - Forward(const std::vector> &x) override; - -private: - int64_t hidden_dim_ = 0; -}; - -class Block : public infini_train::nn::CloneableModule { -public: - static constexpr char kType[] = "Block"; - static constexpr char kLn1LayerName[] = "ln_1"; - static constexpr char kAttnLayerName[] = "attn"; - static constexpr char kLn2LayerName[] = "ln_2"; - static constexpr char kMlpLayerName[] = "mlp"; - - explicit Block(const LLaMA3Config &config); - - std::vector> - Forward(const std::vector> &x) override; -}; - -class LLaMA3FirstStage : public infini_train::nn::CloneableModule { -public: - static constexpr char kType[] = "LLaMA3FirstStage"; - static constexpr char kWTELayerName[] = "wte"; - - explicit LLaMA3FirstStage(const LLaMA3Config &config); - - std::vector> - Forward(const std::vector> &x) override; - -private: - const LLaMA3Config config_; -}; - -class LLaMA3Chunk : public infini_train::nn::CloneableModule { -public: - static constexpr char kType[] = "LLaMA3Chunk"; - static constexpr char kHLayerName[] = "h"; - static constexpr char kFreqsCisName[] = "freqs_cis"; - - LLaMA3Chunk(const LLaMA3Config &config, int start_layer, int end_layer); - - std::vector> - Forward(const std::vector> &x) override; - -private: - const LLaMA3Config config_; -}; - -class LLaMA3LastStage : public infini_train::nn::CloneableModule { -public: - static constexpr char kType[] = "LLaMA3LastStage"; - static constexpr char kLnFLayerName[] = "ln_f"; - static constexpr char kLMHeadLayerName[] = "lm_head"; - - explicit LLaMA3LastStage(const LLaMA3Config &config); - - std::vector> - Forward(const std::vector> &x) override; - -private: - 
const LLaMA3Config config_; -}; - -class LLaMA3 : public infini_train::nn::CloneableModule { -public: - static constexpr char kType[] = "LLaMA3"; - static constexpr char kTransformerLayerName[] = "transformer"; - - enum class ModelType : int8_t { - // TODO(zbl): more model type from huggingface - kLLaMA3_1_8B, - kLLaMA3_1_70B, - kLLaMA3_2_1B, - kLLaMA3_2_3B, - kLLaMA3_3_70B, - }; - - explicit LLaMA3(const LLaMA3Config &config); - - std::vector> - Forward(const std::vector> &x) override; - - static std::shared_ptr FromPretrained(ModelType model_type); - static std::shared_ptr FromLLMC(const std::string &filepath); - - int GetChunkSize() const { return stage_info_.layer_ranges_per_chunk.size(); } - -private: - const LLaMA3Config config_; - const infini_train::nn::parallel::StageInfo stage_info_; -}; diff --git a/infini_train/include/core/models/decode_only_transformer/layer_specs.h b/infini_train/include/core/models/decode_only_transformer/layer_specs.h new file mode 100644 index 00000000..01cbee26 --- /dev/null +++ b/infini_train/include/core/models/decode_only_transformer/layer_specs.h @@ -0,0 +1,12 @@ +#pragma once + +#include "infini_train/include/core/transformer/spec_utils.h" +#include "infini_train/include/core/transformer/transformer_config.h" + +namespace infini_train::nn { +// Build GPT2 model spec: LayerNorm + GELU + standard attention +ModuleSpec BuildGPT2Spec(const TransformerConfig &config); + +// Build LLaMA3 model spec: RMSNorm + SwiGLU + RoPE + GQA +ModuleSpec BuildLLaMA3Spec(const TransformerConfig &config); +} // namespace infini_train::nn diff --git a/infini_train/include/core/models/decode_only_transformer/model.h b/infini_train/include/core/models/decode_only_transformer/model.h new file mode 100644 index 00000000..0d637016 --- /dev/null +++ b/infini_train/include/core/models/decode_only_transformer/model.h @@ -0,0 +1,85 @@ +#pragma once + +#include +#include +#include +#include + +#include "glog/logging.h" + +#include 
"infini_train/include/core/models/decode_only_transformer/layer_specs.h" +#include "infini_train/include/core/transformer/spec_utils.h" +#include "infini_train/include/core/transformer/transformer_block.h" +#include "infini_train/include/core/transformer/transformer_builders.h" +#include "infini_train/include/core/transformer/transformer_config.h" +#include "infini_train/include/core/transformer/transformer_layer.h" +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/modules/normalization.h" +#include "infini_train/include/nn/modules/sparse.h" +#include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h" +#include "infini_train/include/nn/parallel/tensor_parallel.h" +#include "infini_train/include/tensor.h" + +using namespace infini_train; +namespace nn = infini_train::nn; + +// ========== GPT2 Model Definition ========== +// Uses LayerNorm, GELU activation, standard multi-head attention +class GPT2 : public nn::TransformerLayer { +public: + static constexpr char kType[] = "GPT2"; + static constexpr char kTransformerLayerName[] = "transformer"; + + enum class ModelType : int8_t { + kGPT2, + kGPT2Medium, + kGPT2Large, + kGPT2XL, + }; + + explicit GPT2(const nn::TransformerConfig &config) + : TransformerLayer(config, BuildGPT2Spec(config)), + stage_info_(nn::parallel::PipelineParallel::GetStageInfo( + config_.n_layer, nn::parallel::global::GetPipelineParallelSize(), nn::parallel::pp_rank, + nn::parallel::global::GetVirtualPipelineParallelSize())) {} + + static std::shared_ptr FromPretrained(ModelType model_type); + static std::shared_ptr FromLLMC(const std::string &filepath); + + int GetChunkSize() const; + +private: + const infini_train::nn::parallel::StageInfo stage_info_; +}; + +// ========== LLaMA3 Model Definition ========== +// Uses RMSNorm, SwiGLU activation, GQA attention, RoPE positional encoding +class LLaMA3 : public nn::TransformerLayer { +public: + static 
constexpr char kType[] = "LLaMA3"; + static constexpr char kTransformerLayerName[] = "transformer"; + + enum class ModelType : int8_t { + // TODO(zbl): more model type from huggingface + kLLaMA3_1_8B, + kLLaMA3_1_70B, + kLLaMA3_2_1B, + kLLaMA3_2_3B, + kLLaMA3_3_70B, + }; + + explicit LLaMA3(const nn::TransformerConfig &config) + : TransformerLayer(config, BuildLLaMA3Spec(config)), + stage_info_(nn::parallel::PipelineParallel::GetStageInfo( + config_.n_layer, nn::parallel::global::GetPipelineParallelSize(), nn::parallel::pp_rank, + nn::parallel::global::GetVirtualPipelineParallelSize())) {} + + static std::shared_ptr FromPretrained(ModelType model_type); + static std::shared_ptr FromLLMC(const std::string &filepath); + + int GetChunkSize() const { return stage_info_.layer_ranges_per_chunk.size(); } + +private: + const infini_train::nn::parallel::StageInfo stage_info_; +}; diff --git a/infini_train/include/core/transformer/spec_utils.h b/infini_train/include/core/transformer/spec_utils.h new file mode 100644 index 00000000..569894f3 --- /dev/null +++ b/infini_train/include/core/transformer/spec_utils.h @@ -0,0 +1,116 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/core/transformer/transformer_config.h" + +namespace infini_train::nn { + +class Module; + +struct ModuleSpec { + ModuleSpec() = default; + + explicit ModuleSpec(std::type_index m) : module_(m) {} + + ModuleSpec &with_param(const std::string &key, std::any value) { + params_[key] = std::move(value); + return *this; + } + + ModuleSpec &with_submodule(const std::string &name, ModuleSpec spec) { + submodules_[name] = std::move(spec); + return *this; + } + + ModuleSpec & + with_build(std::function(const TransformerConfig &, const ModuleSpec &)> build_fn) { + build = std::move(build_fn); + return *this; + } + + std::type_index module_{typeid(void)}; + std::unordered_map params_; + std::unordered_map 
submodules_; + std::function(const TransformerConfig &, const ModuleSpec &)> build{nullptr}; +}; + +using ModuleCreator = std::function(const TransformerConfig &, const ModuleSpec &)>; + +class ModuleRegistry { +public: + static ModuleRegistry &Instance() { + static ModuleRegistry inst; + return inst; + } + + void Register(std::type_index type, ModuleCreator creator); + + ModuleCreator Get(std::type_index type) const; + + bool Has(std::type_index type) const { return registry_.contains(type); } + + std::unordered_set RegisteredTypes() const { + std::unordered_set types; + for (const auto &[type, _] : registry_) { types.insert(type); } + return types; + } + +private: + std::unordered_map registry_; +}; + +// Register a module type with automatic creator inference +#define REGISTER_MODULE(ModuleClass) \ + namespace { \ + struct ModuleClass##Registry { \ + ModuleClass##Registry() { \ + ModuleRegistry::Instance().Register(typeid(ModuleClass), \ + [](const TransformerConfig &config, const ModuleSpec &spec) { \ + return std::make_shared(config, spec); \ + }); \ + } \ + }; \ + static ModuleClass##Registry g_##ModuleClass##_registry; \ + } + +// Register a module type with custom creator function +#define REGISTER_MODULE_CUSTOM(ModuleClass, CreatorFunc) \ + namespace { \ + struct ModuleClass##Registry { \ + ModuleClass##Registry() { ModuleRegistry::Instance().Register(typeid(ModuleClass), CreatorFunc); } \ + }; \ + static ModuleClass##Registry g_##ModuleClass##_registry; \ + } + +// Get a required parameter from ModuleSpec +template inline T GetRequiredParam(const ModuleSpec &spec, const std::string &key) { + CHECK(spec.params_.contains(key)) << "Missing required parameter: " << key; + + const T *value = std::any_cast(&spec.params_.at(key)); + CHECK(value) << "Parameter type mismatch for key '" << key << "': expected " << typeid(T).name() << ", got " + << spec.params_.at(key).type().name(); + return *value; +} + +// Get an optional parameter from ModuleSpec with default 
value +template inline T GetOptionalParam(const ModuleSpec &spec, const std::string &key, T default_value) { + if (!spec.params_.contains(key)) { + return default_value; + } + + const T *value = std::any_cast(&spec.params_.at(key)); + return value ? *value : default_value; +} + +std::shared_ptr build_module(const TransformerConfig &config, const ModuleSpec &spec); +} // namespace infini_train::nn diff --git a/infini_train/include/core/transformer/transformer_block.h b/infini_train/include/core/transformer/transformer_block.h new file mode 100644 index 00000000..07c500b0 --- /dev/null +++ b/infini_train/include/core/transformer/transformer_block.h @@ -0,0 +1,120 @@ +#pragma once + +#include + +#include "infini_train/include/core/transformer/spec_utils.h" +#include "infini_train/include/core/transformer/transformer_config.h" +#include "infini_train/include/nn/modules/module.h" + +namespace infini_train::nn { + +class RMSNorm : public infini_train::nn::CloneableModule { +public: + static constexpr char kType[] = "RMSNorm"; + static constexpr char kParamWeightName[] = "weight"; + + explicit RMSNorm(int64_t dim, float eps = 1e-6f, infini_train::Device device = infini_train::Device()); + + std::vector> + Forward(const std::vector> &x) override; + +private: + float eps_ = 1e-5f; +}; + +class NewGELU : public infini_train::nn::CloneableModule { +public: + static constexpr char kType[] = "NewGELU"; + NewGELU() : CloneableModule(kType) {} + + std::vector> + Forward(const std::vector> &x) override; +}; + +class SwiGLU : public infini_train::nn::CloneableModule { +public: + static constexpr char kType[] = "SwiGLU"; + SwiGLU() : CloneableModule(kType) {} + + std::vector> + Forward(const std::vector> &x) override; +}; + +class CausalSelfAttention : public infini_train::nn::CloneableModule { +public: + static constexpr char kType[] = "CausalSelfAttention"; + static constexpr char kCAttnLayerName[] = "c_attn"; + static constexpr char kCProjLayerName[] = "c_proj"; + + static 
constexpr char kParamBiasName[] = "bias"; + + explicit CausalSelfAttention(const TransformerConfig &config, const ModuleSpec &spec = {}); + + std::vector> + Forward(const std::vector> &x) override; + +private: + TransformerConfig config_; + int64_t n_head_ = 0; + int64_t n_embd_ = 0; + int64_t local_n_head_ = 0; + + int64_t n_kv_head_ = 0; + int64_t n_rep_ = 0; + int64_t head_dim_ = 0; + + // Setup method for different attention modes + void SetupAttention(const TransformerConfig &config); + + // Standard attention forward (GPT2 style: no RoPE, no GQA) + std::vector> + ForwardStandard(const std::vector> &x); + + // RoPE-aware attention forward (LLaMA3 style: with RoPE, optional GQA) + std::vector> + ForwardWithRoPE(const std::vector> &x); + + // RoPE helper methods + std::tuple, std::shared_ptr> + ApplyRotaryEmbedding(const std::shared_ptr &xq, + const std::shared_ptr &xk, + const std::shared_ptr &freqs_cis); + + // GQA helper method + std::shared_ptr RepeatKV(const std::shared_ptr &x, int64_t n_rep); +}; + +class MLP : public infini_train::nn::CloneableModule { +public: + static constexpr char kType[] = "MLP"; + static constexpr char kCFcLayerName[] = "c_fc"; + static constexpr char kGeluLayerName[] = "gelu"; + static constexpr char kCProjLayerName[] = "c_proj"; + + static constexpr char kCFc2LayerName[] = "c_fc2"; + static constexpr char kSiluLayerName[] = "silu"; + + explicit MLP(const TransformerConfig &config, const ModuleSpec &spec = {}); + + std::vector> + Forward(const std::vector> &x) override; +}; + +class TransformerBlock : public infini_train::nn::CloneableModule { +public: + static constexpr char kType[] = "Block"; + static constexpr char kLn1LayerName[] = "ln_1"; + static constexpr char kAttnLayerName[] = "attn"; + static constexpr char kLn2LayerName[] = "ln_2"; + static constexpr char kMlpLayerName[] = "mlp"; + + explicit TransformerBlock(const TransformerConfig &config, const ModuleSpec &spec = {}); + + std::vector> + Forward(const std::vector> &x) 
override; + +private: + AttentionType attention_type_ = AttentionType::kStandard; +}; + +} // namespace infini_train::nn diff --git a/infini_train/include/core/transformer/transformer_builders.h b/infini_train/include/core/transformer/transformer_builders.h new file mode 100644 index 00000000..b0746358 --- /dev/null +++ b/infini_train/include/core/transformer/transformer_builders.h @@ -0,0 +1,50 @@ +#pragma once + +#include +#include + +#include "infini_train/include/core/transformer/spec_utils.h" +#include "infini_train/include/core/transformer/transformer_config.h" + +namespace infini_train::nn { + +// Embedding +inline constexpr char kNumEmbeddings[] = "num_embeddings"; +inline constexpr char kEmbeddingDim[] = "embedding_dim"; + +// Normalization +inline constexpr char kNormalizedShape[] = "normalized_shape"; +inline constexpr char kDim[] = "dim"; +inline constexpr char kEps[] = "eps"; + +// Linear +inline constexpr char kInFeatures[] = "in_features"; +inline constexpr char kOutFeatures[] = "out_features"; +inline constexpr char kBias[] = "bias"; + +// Attention +inline constexpr char kNumHeads[] = "num_heads"; +inline constexpr char kNumKVHeads[] = "num_kv_heads"; + +// Build LayerNorm or RMSNorm spec based on config +ModuleSpec BuildNormSpec(const TransformerConfig &config); + +// Build CausalSelfAttention spec +ModuleSpec BuildAttentionSpec(const TransformerConfig &config); + +// Build MLP spec (supports GELU and SwiGLU) +ModuleSpec BuildMLPSpec(const TransformerConfig &config); + +// Build TransformerBlock spec +ModuleSpec BuildTransformerBlockSpec(const TransformerConfig &config); + +// Build VocabParallelEmbedding spec for token embeddings +ModuleSpec BuildVocabEmbeddingSpec(const TransformerConfig &config); + +// Build Embedding spec for position embeddings +ModuleSpec BuildPositionEmbeddingSpec(int64_t num_embeddings, int64_t embedding_dim); + +// Build ColumnParallelLinear spec for output projection (lm_head) +ModuleSpec BuildOutputProjSpec(const 
TransformerConfig &config, int64_t output_size, bool use_bias); + +} // namespace infini_train::nn diff --git a/infini_train/include/core/transformer/transformer_config.h b/infini_train/include/core/transformer/transformer_config.h new file mode 100644 index 00000000..9470cc8a --- /dev/null +++ b/infini_train/include/core/transformer/transformer_config.h @@ -0,0 +1,83 @@ +#pragma once +#include +#include +#include + +namespace infini_train::nn { + +enum class AttentionType { + kStandard, // Standard attention (GPT2 style, no RoPE) + kRoPE // Rotary Position Embedding (LLaMA3 style) +}; + +enum class MLPType { + kGELU, // GELU activation (GPT2 style) + kSwiGLU // SwiGLU activation (LLaMA3 style) +}; + +enum class NormType { + kLayerNorm, // LayerNorm (GPT2 style) + kRMSNorm // RMSNorm (LLaMA3 style) +}; + +class TransformerConfig { +public: + static constexpr char kGPT2Name[] = "GPT2"; + static constexpr char kLLaMA3Name[] = "LLaMA3"; + + std::string model_type = ""; + + int64_t block_size = 1024; // Max seq_len + int64_t vocab_size = 50304; // Vocab size + int64_t original_vocab_size = 50257; // Original vocab size before padding + int64_t n_layer = 12; // Num of transformer layers + int64_t n_head = 12; // Num of heads in MHA + int64_t n_kv_head = 12; // Num of Key/Value heads (<= n_head, < n_head if using GQA) + int64_t n_embd = 768; // Hidden size + + AttentionType attention_type = AttentionType::kStandard; // Attention mechanism type + MLPType activation_type = MLPType::kGELU; // MLP activation type + NormType norm_type = NormType::kLayerNorm; // Normalization type + + bool use_bias = true; // Linear layers bias (GPT2: true, LLaMA3: false) + bool use_gqa = false; // Grouped Query Attention + bool use_rope = false; // Rotary Position Embedding + bool tie_weights = true; // Tie embedding and lm_head weights + + // FFN config + float ffn_expansion_ratio = 4.0f; // MLP output: n_embd * ffn_expansion_ratio + std::optional ffn_dim_multiplier = 1.5f; // FFN dim 
multiplier + int64_t multiple_of = 256; // FFN dims must be multiple of this number + + // RoPE config + float rope_theta = 500000.0f; // theta in RoPE + bool use_scaled_rope = false; // scaled RoPE + + // Normalization + float norm_eps = 1e-5f; // epsilon in RMSNorm + + // Inference + bool use_kv = false; // kv cache + bool flash = false; // flash attention + int64_t max_gen_batch_size = 4; // max batch size during inference + + static TransformerConfig GPT2() { return {}; } + + static TransformerConfig LLaMA3() { + return {.model_type = kLLaMA3Name, + .block_size = 8192, + .vocab_size = 128256, + .n_layer = 16, + .n_head = 32, + .n_kv_head = 8, + .n_embd = 2048, + .attention_type = AttentionType::kRoPE, + .activation_type = MLPType::kSwiGLU, + .norm_type = NormType::kRMSNorm, + .use_bias = false, + .use_gqa = true, + .use_rope = true, + .tie_weights = false}; + } +}; +} // namespace infini_train::nn diff --git a/infini_train/include/core/transformer/transformer_layer.h b/infini_train/include/core/transformer/transformer_layer.h new file mode 100644 index 00000000..5c30e285 --- /dev/null +++ b/infini_train/include/core/transformer/transformer_layer.h @@ -0,0 +1,82 @@ +#pragma once + +#include "infini_train/include/core/transformer/spec_utils.h" +#include "infini_train/include/core/transformer/transformer_block.h" +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h" + +namespace infini_train::nn { +class TransformerConfig; + +class TransformerFirstStage : public infini_train::nn::CloneableModule { +public: + static constexpr char kType[] = "TransformerFirstStage"; + static constexpr char kWTELayerName[] = "wte"; + static constexpr char kWPELayerName[] = "wpe"; + + explicit TransformerFirstStage(const TransformerConfig &config, const ModuleSpec &spec = {}); + + std::vector> + Forward(const std::vector> &x) override; + +private: + TransformerConfig config_; + ModuleSpec spec_; +}; + +class 
TransformerChunk : public infini_train::nn::CloneableModule { +public: + static constexpr char kType[] = "TransformerChunk"; + static constexpr char kHLayerName[] = "h"; + static constexpr char kFreqsCisName[] = "freqs_cis"; + + TransformerChunk(const TransformerConfig &config, int start_layer, int end_layer, const ModuleSpec &spec = {}); + + std::vector> + Forward(const std::vector> &x) override; + +private: + const TransformerConfig config_; + ModuleSpec spec_; + + // RoPE helper method + std::shared_ptr PrecomputeFreqsCis(int64_t dim, int64_t end, float theta = 10000.0f, + bool use_scaled = false, + infini_train::Device device = infini_train::Device()); +}; + +class TransformerLastStage : public infini_train::nn::CloneableModule { +public: + static constexpr char kType[] = "TransformerLastStage"; + static constexpr char kLnFLayerName[] = "ln_f"; + static constexpr char kLMHeadLayerName[] = "lm_head"; + + explicit TransformerLastStage(const TransformerConfig &config, const ModuleSpec &spec = {}); + + std::vector> + Forward(const std::vector> &x) override; + +private: + const TransformerConfig config_; + ModuleSpec spec_; +}; + +class TransformerLayer : public infini_train::nn::CloneableModule { +public: + static constexpr char kType[] = "Transformer"; + static constexpr char kTransformerLayerName[] = "transformer"; + + explicit TransformerLayer(const TransformerConfig config, const ModuleSpec &spec = {}); + + std::vector> + Forward(const std::vector> &x) override; + + const TransformerConfig &GetConfig() const { return config_; } + + const TransformerConfig config_; + +private: + const infini_train::nn::parallel::StageInfo stage_info_; + ModuleSpec spec_; +}; +} // namespace infini_train::nn diff --git a/infini_train/src/core/models/decode_only_transformer/layer_specs.cc b/infini_train/src/core/models/decode_only_transformer/layer_specs.cc new file mode 100644 index 00000000..552363de --- /dev/null +++ 
b/infini_train/src/core/models/decode_only_transformer/layer_specs.cc @@ -0,0 +1,64 @@ +#include "infini_train/include/core/models/decode_only_transformer/layer_specs.h" + +#include +#include + +#include "infini_train/include/core/transformer/spec_utils.h" +#include "infini_train/include/core/transformer/transformer_builders.h" +#include "infini_train/include/core/transformer/transformer_config.h" +#include "infini_train/include/core/transformer/transformer_layer.h" + +namespace infini_train::nn { + +ModuleSpec BuildGPT2Spec(const TransformerConfig &config) { + // Configure for GPT2 architecture + TransformerConfig gpt2_config = config; + ModuleSpec spec; + + // ===== First Stage ===== + ModuleSpec first_stage; + first_stage.with_submodule(TransformerFirstStage::kWTELayerName, BuildVocabEmbeddingSpec(gpt2_config)) + .with_submodule(TransformerFirstStage::kWPELayerName, + BuildPositionEmbeddingSpec(gpt2_config.block_size, gpt2_config.n_embd)); + spec.with_submodule(TransformerFirstStage::kType, first_stage); + + // ===== Transformer Block ===== + ModuleSpec block = BuildTransformerBlockSpec(gpt2_config); + spec.with_submodule(TransformerBlock::kType, block); + + // ===== Last Stage ===== + ModuleSpec last_stage; + last_stage.with_submodule(TransformerLastStage::kLnFLayerName, BuildNormSpec(gpt2_config)) + .with_submodule(TransformerLastStage::kLMHeadLayerName, + BuildOutputProjSpec(gpt2_config, gpt2_config.vocab_size, false)); + spec.with_submodule(TransformerLastStage::kType, last_stage); + + return spec; +} + +ModuleSpec BuildLLaMA3Spec(const TransformerConfig &config) { + // Configure for LLaMA3 architecture + TransformerConfig llama3_config = config; + ModuleSpec spec; + + // ===== First Stage ===== + ModuleSpec first_stage; + // LLaMA3 only has token embedding, no position embedding (uses RoPE) + first_stage.with_submodule(TransformerFirstStage::kWTELayerName, BuildVocabEmbeddingSpec(llama3_config)); + spec.with_submodule(TransformerFirstStage::kType, 
first_stage); + + // ===== Transformer Block ===== + ModuleSpec block = BuildTransformerBlockSpec(llama3_config); + spec.with_submodule(TransformerBlock::kType, block); + + // ===== Last Stage ===== + ModuleSpec last_stage; + last_stage.with_submodule(TransformerLastStage::kLnFLayerName, BuildNormSpec(llama3_config)) + .with_submodule(TransformerLastStage::kLMHeadLayerName, + BuildOutputProjSpec(llama3_config, llama3_config.vocab_size, false)); + spec.with_submodule(TransformerLastStage::kType, last_stage); + + return spec; +} + +} // namespace infini_train::nn diff --git a/infini_train/src/core/transformer/spec_utils.cc b/infini_train/src/core/transformer/spec_utils.cc new file mode 100644 index 00000000..bdc4b4d5 --- /dev/null +++ b/infini_train/src/core/transformer/spec_utils.cc @@ -0,0 +1,43 @@ +#include "infini_train/include/core/transformer/spec_utils.h" + +#include +#include +#include +#include + +#include "infini_train/include/core/transformer/transformer_config.h" +#include "infini_train/include/nn/modules/module.h" + +namespace infini_train::nn { + +void ModuleRegistry::Register(std::type_index type, ModuleCreator creator) { + CHECK(!registry_.contains(type)) << "Module type already registered: " << type.name(); + + registry_[type] = std::move(creator); +} + +ModuleCreator ModuleRegistry::Get(std::type_index type) const { + auto it = registry_.find(type); + if (it == registry_.end()) { + return nullptr; + } + return it->second; +} + +std::shared_ptr build_module(const TransformerConfig &config, const ModuleSpec &spec) { + if (spec.build) { + return spec.build(config, spec); + } + + CHECK(spec.module_ != typeid(void)) << "ModuleSpec.module is not set"; + + auto creator = ModuleRegistry::Instance().Get(spec.module_); + + CHECK(creator) << "Module not registered: " << spec.module_.name(); + + auto module = creator(config, spec); + + return module; +} + +} // namespace infini_train::nn diff --git a/infini_train/src/core/transformer/transformer_block.cc 
b/infini_train/src/core/transformer/transformer_block.cc new file mode 100644 index 00000000..6392f28d --- /dev/null +++ b/infini_train/src/core/transformer/transformer_block.cc @@ -0,0 +1,433 @@ +#include "infini_train/include/core/transformer/transformer_block.h" + +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/core/transformer/spec_utils.h" +#include "infini_train/include/core/transformer/transformer_builders.h" +#include "infini_train/include/core/transformer/transformer_config.h" +#include "infini_train/include/core/transformer/transformer_layer.h" +#include "infini_train/include/nn/functional.h" +#include "infini_train/include/nn/init.h" +#include "infini_train/include/nn/modules/normalization.h" +#include "infini_train/include/nn/modules/sparse.h" +#include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/nn/parallel/tensor_parallel.h" +#include "infini_train/include/tensor.h" +namespace infini_train::nn { + +RMSNorm::RMSNorm(int64_t dim, float eps, infini_train::Device device) : CloneableModule(kType), eps_(eps) { + parameters_[kParamWeightName] + = std::make_shared(std::vector{dim}, DataType::kFLOAT32, device)->RequiresGrad(); + nn::init::Ones(parameters_[kParamWeightName]); +} + +std::vector> RMSNorm::Forward(const std::vector> &x) { + // broadcasted Mul([4, 64, 2048] * [4, 64, 1]) + auto norm = x[0] * nn::function::Rsqrt(nn::function::Mean(nn::function::Pow(x[0], 2), -1, true) + eps_); + return {norm * parameters_[kParamWeightName]}; +} + +std::vector> +NewGELU::Forward(const std::vector> &x) { + auto &input = x[0]; + return {0.5 * input + * (1.0 + nn::function::Tanh(std::sqrt(2.0 / M_PI) * (input + 0.044715 * nn::function::Pow(input, 3.0))))}; +} + +std::vector> +SwiGLU::Forward(const std::vector> &x) { + return {x[0] * nn::function::Sigmoid(x[0])}; +} + +CausalSelfAttention::CausalSelfAttention(const TransformerConfig &config, const ModuleSpec &spec) + : 
CloneableModule(kType), config_(config) { + SetupAttention(config); + + CHECK(spec.submodules_.contains(kCAttnLayerName)) + << "CausalSelfAttention spec missing submodule: " << kCAttnLayerName; + CHECK(spec.submodules_.contains(kCProjLayerName)) + << "CausalSelfAttention spec missing submodule: " << kCProjLayerName; + // Build submodules from spec + modules_[kCAttnLayerName] = build_module(config, spec.submodules_.at(kCAttnLayerName)); + modules_[kCProjLayerName] = build_module(config, spec.submodules_.at(kCProjLayerName)); + + // For standard attention (GPT2 style), precompute causal mask + if (config_.attention_type == AttentionType::kStandard) { + // causal mask: (1, 1, block_size, block_size) + buffers_[kParamBiasName] = function::Tril(nn::function::Ones({config_.block_size, config_.block_size})) + ->View({1, 1, config_.block_size, config_.block_size}); + } +} + +void CausalSelfAttention::SetupAttention(const TransformerConfig &config) { + auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); + + CHECK_EQ(config.n_embd % config.n_head, 0) << "n_embd must be divisible by n_head"; + CHECK_EQ(config.n_head % tp_world_size, 0) << "n_head must be divisible by TP world size"; + + n_head_ = config.n_head; + n_embd_ = config.n_embd; + head_dim_ = config.n_embd / config.n_head; + local_n_head_ = n_head_ / tp_world_size; + + // For GQA, set n_kv_head and n_rep + if (config.use_gqa && config.n_kv_head < config.n_head) { + CHECK_EQ(config.n_head % config.n_kv_head, 0) << "n_head must be divisible by n_kv_head for GQA"; + CHECK_EQ(config.n_kv_head % tp_world_size, 0) << "n_kv_head must be divisible by TP world size for GQA"; + + n_kv_head_ = config.n_kv_head; + n_rep_ = n_head_ / n_kv_head_; + } else { + n_kv_head_ = n_head_; + n_rep_ = 1; + } +} + +std::vector> +CausalSelfAttention::Forward(const std::vector> &x) { + if (config_.attention_type == AttentionType::kRoPE) { + return ForwardWithRoPE(x); + } else { + return ForwardStandard(x); + } +} + 
+std::vector> +CausalSelfAttention::ForwardStandard(const std::vector> &x) { + auto tp_world_size = parallel::global::GetTensorParallelSize(); + + const auto B = x[0]->Dims()[0]; // bs + const auto C = x[0]->Dims()[2]; // n_embd + const int64_t head_dim = n_embd_ / n_head_; // per-head dim (global) + const int64_t local_C = n_embd_ / tp_world_size; // per-rank hidden + + // (B, T, C) -> ColumnParallelLinear(C, 3*C) -> (B, T, 3 * local_C) + // -> Split -> (3, B, T, local_C) + auto qkv = (*modules_[kCAttnLayerName])(x)[0]->Split(local_C, 2); + + // (B, T, local_C) + auto q = qkv[0]; + auto k = qkv[1]; + auto v = qkv[2]; + + // NOTE(zbl): Acquire full T after AllGather is performed in ColumnParallelLinear + const auto T = q->Dims()[1]; + + // View to multi-head: local_n_head * head_dim == local_C + // (B, T, local_C) -> (B, T, h_l, Dh) -> (B, h_l, T, Dh) + k = k->View({B, T, local_n_head_, head_dim})->Transpose(1, 2); + q = q->View({B, T, local_n_head_, head_dim})->Transpose(1, 2); + v = v->View({B, T, local_n_head_, head_dim})->Transpose(1, 2); + + // (B, h_l, T, T) + auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim)); + // (1, 1, T, T) + auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1}); + // (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T) + att = att->MaskedFill(mask == 0, -std::numeric_limits::infinity()); + // (B, h_l, T, T) + att = nn::function::Softmax(att, -1); + // (B, h_l, T, Dh) + auto y = att->Matmul(v); + // (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C) + y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C}); + + // Get full tensor + // (B, T, local_C) -> RowParallelLinear(n_embd, n_embd) -> (B, T, C) + y = (*modules_[kCProjLayerName])({y})[0]; + // (B, T, C) == (bs, seq_len, n_embd) + return {y}; +} + +// RoPE helper methods +std::tuple, std::shared_ptr> +CausalSelfAttention::ApplyRotaryEmbedding(const std::shared_ptr &xq, + const std::shared_ptr &xk, + const 
std::shared_ptr &freqs_cis) { + // Reshape freqs_cis for broadcasting + const auto &x_shape = xq->Dims(); // (B, T, H, D) + const int64_t T = x_shape[1]; + const int64_t D = x_shape[3]; + + std::vector target_shape = {1, T, 1, D / 2, 2}; + auto cos_sin = freqs_cis->View(target_shape); // -> (1, T, 1, D/2, 2) + + auto cos = cos_sin->Slice(-1, 0, 1, 1)->Squeeze(-1); // (1, T, 1, D/2) + auto sin = cos_sin->Slice(-1, 1, 2, 1)->Squeeze(-1); // (1, T, 1, D/2) + + auto slice_pair = [](const std::shared_ptr &x) { + auto even = x->Slice(-1, 0, x->Dims().back(), 2); + auto odd = x->Slice(-1, 1, x->Dims().back(), 2); + return std::make_pair(even, odd); + }; + + auto [q_even, q_odd] = slice_pair(xq); + auto q_rotated_left = q_even * cos - q_odd * sin; + auto q_rotated_right = q_even * sin + q_odd * cos; + auto q_rotated + = nn::function::Stack(std::vector>{q_rotated_left, q_rotated_right}, -1)->Flatten(-2); + + auto [k_even, k_odd] = slice_pair(xk); + auto k_rotated_left = k_even * cos - k_odd * sin; + auto k_rotated_right = k_even * sin + k_odd * cos; + auto k_rotated + = nn::function::Stack(std::vector>{k_rotated_left, k_rotated_right}, -1)->Flatten(-2); + + return {q_rotated, k_rotated}; +} + +std::shared_ptr CausalSelfAttention::RepeatKV(const std::shared_ptr &x, + int64_t n_rep) { + const auto &shape = x->Dims(); + const int64_t B = shape[0], T = shape[1], H = shape[2], D = shape[3]; + + if (n_rep == 1) { + return x; + } + + return x->View({B, T, H, 1, D})->RepeatInterleave(n_rep, 3)->Contiguous()->View({B, T, H * n_rep, D}); +} + +std::vector> +CausalSelfAttention::ForwardWithRoPE(const std::vector> &x) { + const auto B = x[0]->Dims()[0]; // bs + const auto C = x[0]->Dims()[2]; // n_embd + + const auto tp_size = nn::parallel::global::GetTensorParallelSize(); + + const auto C_local = C / tp_size; + const auto H_local = n_head_ / tp_size; + const auto KV_local = n_kv_head_ / tp_size; + const auto D = head_dim_; // n_embd / n_head + + const auto freqs_cis = x.size() > 1 ? 
x[1] : nullptr; + const auto start_pos = x.size() > 2 ? x[2] : nullptr; + const auto mask = x.size() > 3 ? x[3] : nullptr; + CHECK(freqs_cis != nullptr) << "freqs_cis is null."; + + // (B, T, C) -> (B, T, (H + 2 * n_kv_head) * D) + auto qkv = (*modules_[kCAttnLayerName])({x[0]})[0]; + // NOTE(zbl): Acquire full T after AllGather is performed in ColumnParallelLinear + const auto T = qkv->Dims()[1]; + // NOTE(zbl): torch script uses torch.split({...}, dim) to split tensors into sub-tensors in different sizes + // use Slice() to work around here + const int64_t q_size_local = H_local * D; + const int64_t kv_size_local = KV_local * D; + // -> Split into q, k, v + // q: (B, T, H_local, D) + auto q = qkv->Slice(2, 0, q_size_local)->View({B, T, H_local, D}); + // k: (B, T, KV_local, D) + auto k = qkv->Slice(2, q_size_local, q_size_local + kv_size_local)->View({B, T, KV_local, D}); + // v: (B, T, KV_local, D) + auto v = qkv->Slice(2, q_size_local + kv_size_local, q_size_local + 2 * kv_size_local)->View({B, T, KV_local, D}); + + // -> RoPE on q, k + // q: (B, T, H_local, D) + // k: (B, T, KV_local, D) + std::tie(q, k) = ApplyRotaryEmbedding(q, k, freqs_cis); + + // TODO(zbl): use kv cache during inference + // if (use_kv_) { ... } + + // align n_head in GQA + // (B, T, KV_local, D) -> (B, T, H_local, D) via RepeatKV + k = RepeatKV(k, n_rep_); + v = RepeatKV(v, n_rep_); + + // (B, T, H_local, D) -> (B, H_local, T, D) + q = q->Transpose(1, 2); + k = k->Transpose(1, 2); + v = v->Transpose(1, 2); + + // TODO(zbl): support flash attention later + // if (flash_) { ... 
} + + // manual implementation of attention + // this materializes the large (T,T) matrix for all the queries and keys + + // q: (B, H_local, T, D) + // k: (B, H_local, T, D) -> (B, H_local, D, T) + // q @ k.T: (B, H_local, T, T) -> mul 1.0 / sqrt(D) -> (B, H_local, T, T) + auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast(D))); + if (mask) { + // mask: (1, 1, T, T) + att = att->MaskedFill(mask, std::numeric_limits::lowest()); + } + // (B, H_local, T, T) + att = nn::function::Softmax(att, -1); + // att: (B, H_local, T, T) @ v: (B, H_local, T, D) -> y: (B, H_local, T, D) + auto y = att->Matmul(v); + // (B, H_local, T, D) -> Transpose(1, 2) -> (B, T, H_local, D) -> (B, T, C_local) + y = y->Transpose(1, 2)->Contiguous()->View({B, T, C_local}); + // output projection + // (B, T, C_local) -> RowParallelLinear(C, C) -> (B, T, C) + y = (*modules_[kCProjLayerName])({y})[0]; + // (B, H, C) == (bs, seq_len, n_embd) + return {y}; +} + +MLP::MLP(const TransformerConfig &config, const ModuleSpec &spec) : CloneableModule(kType) { + // c_fc: ColumnParallel (input full, output parallel) + modules_[kCFcLayerName] = build_module(config, spec.submodules_.at(kCFcLayerName)); + + // For SwiGLU, add second projection + if (spec.submodules_.contains(kCFc2LayerName)) { + modules_[kCFc2LayerName] = build_module(config, spec.submodules_.at(kCFc2LayerName)); + } + + // Activation: check for GELU or SwiGLU + if (spec.submodules_.contains(kGeluLayerName)) { + modules_[kGeluLayerName] = build_module(config, spec.submodules_.at(kGeluLayerName)); + } else if (spec.submodules_.contains(kSiluLayerName)) { + modules_[kSiluLayerName] = build_module(config, spec.submodules_.at(kSiluLayerName)); + } + + // c_proj: RowParallel (input parallel, output full) + modules_[kCProjLayerName] = build_module(config, spec.submodules_.at(kCProjLayerName)); +} + +std::vector> +MLP::Forward(const std::vector> &x) { + // Check if this is SwiGLU (has second projection and SiLU) + bool is_swiglu 
= modules_.count(kCFc2LayerName) > 0 && modules_.count(kSiluLayerName) > 0; + + if (is_swiglu) { + // SwiGLU forward pass + // (B, T, C) -> ColumnParallelLinear(C, hidden_dim) -> (B, T, hidden_dim) + auto x1 = (*modules_[kCFcLayerName])(x)[0]; + // (B, T, C) -> ColumnParallelLinear(C, hidden_dim) -> (B, T, hidden_dim) + auto x2 = (*modules_[kCFc2LayerName])(x)[0]; + // (B, T, hidden_dim) -> SiLU -> (B, T, hidden_dim) + x2 = (*modules_[kSiluLayerName])({x2})[0]; + // (B, T, hidden_dim) -> element-wise mul -> (B, T, hidden_dim) + auto x3 = x1 * x2; + // (B, T, hidden_dim) -> RowParallelLinear(hidden_dim, C) -> (B, T, C) + auto x4 = (*modules_[kCProjLayerName])({x3}); + return x4; + } else { + // GELU forward pass (standard) + // (B, T, C) -> ColumnParallelLinear(C, 4*C) -> (B, T, 4*C_local) + auto x1 = (*modules_[kCFcLayerName])(x); + // (B, T, 4*C_local) -> GELU -> (B, T, 4*C_local) + auto x2 = (*modules_[kGeluLayerName])(x1); + // (B, T, 4*C_local) -> RowParallelLinear(4*C, C) -> (B, T, C) + auto x3 = (*modules_[kCProjLayerName])(x2); + return x3; + } +} + +TransformerBlock::TransformerBlock(const nn::TransformerConfig &config, const ModuleSpec &spec) + : CloneableModule(kType), attention_type_(config.attention_type) { + modules_[kLn1LayerName] = build_module(config, spec.submodules_.at(kLn1LayerName)); + modules_[kAttnLayerName] = build_module(config, spec.submodules_.at(kAttnLayerName)); + modules_[kLn2LayerName] = build_module(config, spec.submodules_.at(kLn2LayerName)); + modules_[kMlpLayerName] = build_module(config, spec.submodules_.at(kMlpLayerName)); +} + +std::vector> +TransformerBlock::Forward(const std::vector> &x) { + // (bs, seq_len, n_embd) -> Layernorm -> (bs, seq_len, n_embd) + auto ln1_out = (*modules_[kLn1LayerName])({x[0]})[0]; + + std::shared_ptr x1; + // Build attention input + if (attention_type_ == AttentionType::kRoPE) { + // LLaMA3: {ln1_out, freqs_cis, start_pos, mask} + const auto freqs_cis = x.size() > 1 ? 
x[1] : nullptr; + const auto start_pos = x.size() > 2 ? x[2] : nullptr; + const auto mask = x.size() > 3 ? x[3] : nullptr; + auto attn_out = (*modules_[kAttnLayerName])({ln1_out, freqs_cis, start_pos, mask})[0]; + x1 = x[0] + attn_out; + } else { + // GPT2: {ln1_out} + auto attn_out = (*modules_[kAttnLayerName])({ln1_out})[0]; + x1 = x[0] + attn_out; + } + + // (bs, seq_len, n_embd) -> Layernorm -> (bs, seq_len, n_embd) -> MLP -> (bs, seq_len, n_embd) + // -> Add -> (bs, seq_len, n_embd) + auto x2 = x1 + (*modules_[kMlpLayerName])((*modules_[kLn2LayerName])({x1}))[0]; + + // (bs, seq_len, n_embd) + return {x2}; +} + +// ========== Module Registration using REGISTER_MODULE macro ========== +REGISTER_MODULE(CausalSelfAttention); +REGISTER_MODULE(MLP); +REGISTER_MODULE(TransformerBlock); + +// NewGELU +REGISTER_MODULE_CUSTOM(NewGELU, + [](const TransformerConfig &config, const ModuleSpec &) { return std::make_shared(); }); + +// SwiGLU +REGISTER_MODULE_CUSTOM(SwiGLU, + [](const TransformerConfig &config, const ModuleSpec &) { return std::make_shared(); }); + +// LayerNorm registration with custom config +REGISTER_MODULE_CUSTOM(LayerNorm, [](const TransformerConfig &config, const ModuleSpec &spec) { + auto normalized_shape + = GetOptionalParam>(spec, kNormalizedShape, std::vector{config.n_embd}); + return std::make_shared(normalized_shape); +}); + +// RMSNorm registration with custom config +REGISTER_MODULE_CUSTOM(RMSNorm, [](const TransformerConfig &config, const ModuleSpec &spec) { + int64_t dim = GetOptionalParam(spec, kDim, config.n_embd); + float eps = GetOptionalParam(spec, kEps, 1e-5f); + return std::make_shared(dim, eps); +}); + +// Embedding registration with params from spec +REGISTER_MODULE_CUSTOM(Embedding, [](const TransformerConfig &config, const ModuleSpec &spec) { + int num_embeddings = GetRequiredParam(spec, kNumEmbeddings); + int embedding_dim = GetRequiredParam(spec, kEmbeddingDim); + return std::make_shared(num_embeddings, embedding_dim); +}); + 
+namespace parallel { +// ColumnParallelLinear registration with params from spec +REGISTER_MODULE_CUSTOM(ColumnParallelLinear, [](const TransformerConfig &config, const ModuleSpec &spec) { + int in = GetRequiredParam(spec, kInFeatures); + int out = GetRequiredParam(spec, kOutFeatures); + bool bias = GetOptionalParam(spec, kBias, true); + return std::make_shared( + /*in_features=*/in, + /*out_features=*/out, + /*bias=*/bias, + /*gather_output=*/false, + /*input_is_parallel=*/false, + /*skip_bias_add=*/false, + /*sequence_parallel=*/global::GetSequenceParallelEnabled()); +}); + +// RowParallelLinear registration with params from spec +REGISTER_MODULE_CUSTOM(RowParallelLinear, [](const TransformerConfig &config, const ModuleSpec &spec) { + int in = GetRequiredParam(spec, kInFeatures); + int out = GetRequiredParam(spec, kOutFeatures); + bool bias = GetOptionalParam(spec, kBias, true); + return std::make_shared( + /*in_features=*/in, + /*out_features=*/out, + /*bias=*/bias, + /*reduce_output=*/true, + /*input_is_parallel=*/true, + /*skip_bias_add=*/false, + /*sequence_parallel=*/global::GetSequenceParallelEnabled()); +}); + +// VocabParallelEmbedding registration with params from spec +REGISTER_MODULE_CUSTOM(VocabParallelEmbedding, [](const TransformerConfig &config, const ModuleSpec &spec) { + int num_embeddings = GetRequiredParam(spec, kNumEmbeddings); + int embedding_dim = GetRequiredParam(spec, kEmbeddingDim); + return std::make_shared(num_embeddings, embedding_dim, + /*reduce_scatter_embeddings=*/global::GetSequenceParallelEnabled()); +}); +} // namespace parallel +} // namespace infini_train::nn diff --git a/infini_train/src/core/transformer/transformer_builders.cc b/infini_train/src/core/transformer/transformer_builders.cc new file mode 100644 index 00000000..8a9a17c7 --- /dev/null +++ b/infini_train/src/core/transformer/transformer_builders.cc @@ -0,0 +1,166 @@ +#include "infini_train/include/core/transformer/transformer_builders.h" + +#include +#include + 
+#include "glog/logging.h" + +#include "infini_train/include/core/transformer/spec_utils.h" +#include "infini_train/include/core/transformer/transformer_block.h" +#include "infini_train/include/core/transformer/transformer_config.h" +#include "infini_train/include/nn/modules/normalization.h" +#include "infini_train/include/nn/modules/sparse.h" +#include "infini_train/include/nn/parallel/tensor_parallel.h" + +namespace infini_train::nn { + +ModuleSpec BuildNormSpec(const TransformerConfig &config) { + ModuleSpec spec; + switch (config.norm_type) { + case NormType::kLayerNorm: + spec = ModuleSpec(typeid(LayerNorm)); + spec.with_param(kNormalizedShape, std::vector{config.n_embd}); + break; + case NormType::kRMSNorm: + spec = ModuleSpec(typeid(RMSNorm)); + spec.with_param(kDim, static_cast(config.n_embd)).with_param(kEps, config.norm_eps); + break; + default: + LOG(FATAL) << "Unsupported norm type"; + } + return spec; +} + +ModuleSpec BuildAttentionSpec(const TransformerConfig &config) { + ModuleSpec spec(typeid(CausalSelfAttention)); + + // Calculate QKV output dimension based on attention type and GQA + int64_t qkv_out; + if (config.use_gqa && config.n_kv_head < config.n_head) { + // GQA style (LLaMA3 with GQA enabled) + int64_t head_dim = config.n_embd / config.n_head; + // qkv_out = config.n_embd + 2 * config.n_kv_head * head_dim; + qkv_out = (config.n_head + 2 * config.n_kv_head) * head_dim; + } else { + // Standard MHA style (GPT2, or models without GQA) + qkv_out = 3 * config.n_embd; + } + + // Build c_attn (QKV projection) + ModuleSpec c_attn_spec(typeid(parallel::ColumnParallelLinear)); + c_attn_spec.with_param(kInFeatures, static_cast(config.n_embd)) + .with_param(kOutFeatures, static_cast(qkv_out)) + .with_param(kBias, config.use_bias); + spec.with_submodule(CausalSelfAttention::kCAttnLayerName, c_attn_spec); + + // Build c_proj (output projection) + ModuleSpec c_proj_spec(typeid(parallel::RowParallelLinear)); + c_proj_spec.with_param(kInFeatures, 
static_cast(config.n_embd)) + .with_param(kOutFeatures, static_cast(config.n_embd)) + .with_param(kBias, config.use_bias); + spec.with_submodule(CausalSelfAttention::kCProjLayerName, c_proj_spec); + + return spec; +} + +ModuleSpec BuildMLPSpec(const TransformerConfig &config) { + ModuleSpec spec(typeid(MLP)); + + // Compute hidden dimension + // Base dimension: n_embd * ffn_expansion_ratio + int64_t ffn_hidden = static_cast(config.n_embd * config.ffn_expansion_ratio); + + // Apply SwiGLU adjustment + if (config.activation_type == MLPType::kSwiGLU) { + ffn_hidden = int(2 * ffn_hidden) / 3; // SwiGLU intermediate + } + + // Apply multiplier + if (config.ffn_dim_multiplier.has_value()) { + ffn_hidden + = static_cast(std::llround(static_cast(ffn_hidden) * config.ffn_dim_multiplier.value())); + } + + // Round up to multiple_of + int64_t before_round = ffn_hidden; + ffn_hidden = (ffn_hidden + config.multiple_of - 1) / config.multiple_of * config.multiple_of; + + // Build c_fc (input projection) + ModuleSpec c_fc_spec(typeid(parallel::ColumnParallelLinear)); + c_fc_spec.with_param(kInFeatures, static_cast(config.n_embd)) + .with_param(kOutFeatures, static_cast(ffn_hidden)) + .with_param(kBias, config.use_bias); + spec.with_submodule(MLP::kCFcLayerName, c_fc_spec); + + // Build activation based on config + switch (config.activation_type) { + case MLPType::kGELU: { + spec.with_submodule(MLP::kGeluLayerName, ModuleSpec(typeid(NewGELU))); + break; + } + case MLPType::kSwiGLU: { + // Add second projection for SwiGLU + ModuleSpec c_fc2_spec(typeid(parallel::ColumnParallelLinear)); + c_fc2_spec.with_param(kInFeatures, static_cast(config.n_embd)) + .with_param(kOutFeatures, static_cast(ffn_hidden)) + .with_param(kBias, config.use_bias); + spec.with_submodule(MLP::kCFc2LayerName, c_fc2_spec); + + spec.with_submodule(MLP::kSiluLayerName, ModuleSpec(typeid(SwiGLU))); + break; + } + default: + LOG(FATAL) << "Unsupported MLP type"; + } + + // Build c_proj (output projection) + 
ModuleSpec c_proj_spec(typeid(parallel::RowParallelLinear)); + c_proj_spec.with_param(kInFeatures, static_cast(ffn_hidden)) + .with_param(kOutFeatures, static_cast(config.n_embd)) + .with_param(kBias, config.use_bias); + spec.with_submodule(MLP::kCProjLayerName, c_proj_spec); + + return spec; +} + +ModuleSpec BuildTransformerBlockSpec(const TransformerConfig &config) { + ModuleSpec spec(typeid(TransformerBlock)); + + // LayerNorm 1 (before attention) + spec.with_submodule(TransformerBlock::kLn1LayerName, BuildNormSpec(config)); + + // CausalSelfAttention + spec.with_submodule(TransformerBlock::kAttnLayerName, BuildAttentionSpec(config)); + + // LayerNorm 2 (before MLP) + spec.with_submodule(TransformerBlock::kLn2LayerName, BuildNormSpec(config)); + + // MLP + spec.with_submodule(TransformerBlock::kMlpLayerName, BuildMLPSpec(config)); + + return spec; +} + +ModuleSpec BuildVocabEmbeddingSpec(const TransformerConfig &config) { + ModuleSpec spec(typeid(parallel::VocabParallelEmbedding)); + spec.with_param(kNumEmbeddings, static_cast(config.vocab_size)) + .with_param(kEmbeddingDim, static_cast(config.n_embd)); + return spec; +} + +ModuleSpec BuildPositionEmbeddingSpec(int64_t num_embeddings, int64_t embedding_dim) { + ModuleSpec spec(typeid(Embedding)); + spec.with_param(kNumEmbeddings, static_cast(num_embeddings)) + .with_param(kEmbeddingDim, static_cast(embedding_dim)); + return spec; +} + +ModuleSpec BuildOutputProjSpec(const TransformerConfig &config, int64_t output_size, bool use_bias) { + ModuleSpec spec(typeid(parallel::ColumnParallelLinear)); + spec.with_param(kInFeatures, static_cast(config.n_embd)) + .with_param(kOutFeatures, static_cast(output_size)) + .with_param(kBias, use_bias); + return spec; +} + +} // namespace infini_train::nn diff --git a/infini_train/src/core/transformer/transformer_layer.cc b/infini_train/src/core/transformer/transformer_layer.cc new file mode 100644 index 00000000..c5c9f293 --- /dev/null +++ 
b/infini_train/src/core/transformer/transformer_layer.cc @@ -0,0 +1,250 @@ +#include "infini_train/include/core/transformer/transformer_layer.h" + +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/core/transformer/spec_utils.h" +#include "infini_train/include/core/transformer/transformer_block.h" +#include "infini_train/include/nn/functional.h" +#include "infini_train/include/nn/init.h" +#include "infini_train/include/nn/modules/container.h" +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/modules/normalization.h" +#include "infini_train/include/nn/modules/sparse.h" +#include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/nn/parallel/tensor_parallel.h" +#include "infini_train/include/nn/parallel/utils.h" +#include "infini_train/include/tensor.h" +#include "third_party/glog/src/glog/logging.h" + +using namespace infini_train; + +namespace infini_train::nn { + +TransformerFirstStage::TransformerFirstStage(const TransformerConfig &config, const ModuleSpec &spec) + : CloneableModule(kType), config_(config), spec_(spec) { + // Build token embedding (required for all models) + modules_[kWTELayerName] = build_module(config, spec.submodules_.at(kWTELayerName)); + + // Build position embedding only for models that use absolute position encoding + // LLaMA3 use RoPE, so they don't need position embedding + if (config_.attention_type == AttentionType::kStandard) { + modules_[kWPELayerName] = build_module(config, spec.submodules_.at(kWPELayerName)); + } +} + +std::vector> TransformerFirstStage::Forward(const std::vector> &input) { + // (B, T) + auto x1 = input[0]; + CHECK_LE(x1->Dims()[1], config_.block_size) + << "Cannot forward sequence of length " << x1->Dims()[1] << ", block size is only " << config_.block_size; + const auto device = x1->GetDevice(); + + // (B, T) -> Embedding(V_local, C) -> (B, T, C) + auto tok_emb = (*modules_[kWTELayerName])({x1}); + 
+ // Add position embedding only for models that use absolute position encoding + if (config_.attention_type == AttentionType::kStandard) { + // (T_local) + // NOTE(zbl): Slice pos sequence when SP is enabled + auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); + auto sequence_parallel_enabled = nn::parallel::global::GetSequenceParallelEnabled(); + int tp_rank = 0; + if (tp_world_size > 1) { + auto tp_group = nn::parallel::ProcessGroupFactory::Instance()->Get( + nn::parallel::GetTensorParallelProcessGroupName(device.Rank().GlobalRank())); + tp_rank = tp_group->GetGroupRank(device.Rank().GlobalRank()); + } + int64_t t_local = sequence_parallel_enabled ? x1->Dims()[1] / tp_world_size : x1->Dims()[1]; + int64_t start = sequence_parallel_enabled ? tp_rank * t_local : 0; + auto pos = nn::init::Arange(start, start + t_local, infini_train::DataType::kINT64, device); + + // (T) -> Embedding(T_max, C) -> (T, C) + auto pos_emb = (*modules_[kWPELayerName])({pos}); + // (B, T, C) + return {tok_emb[0] + pos_emb[0]}; + } else { + // For RoPE-based models (LLaMA3), no position embedding needed + // (B, T, C) + return tok_emb; + } +} + +TransformerChunk::TransformerChunk(const TransformerConfig &config, int start_layer, int end_layer, + const ModuleSpec &spec) + : CloneableModule(kType), config_(config), spec_(spec) { + std::vector> h; + for (int64_t i = start_layer; i < end_layer; ++i) { + auto layer = std::make_shared(config, spec); + h.push_back(layer); + } + modules_[kHLayerName] = std::make_shared(std::move(h)); +} + +std::vector> TransformerChunk::Forward(const std::vector> &x) { + auto x1 = x[0]; + + // Check if we need to pass RoPE parameters (for LLaMA3 style models) + if (config_.attention_type == AttentionType::kRoPE) { + // For RoPE models, we need to prepare freqs_cis and potentially other parameters + const auto device = x1->GetDevice(); + + // Init freqs_cis on device only once + if (buffers_[kFreqsCisName] == nullptr) { + int64_t head_dim = 
config_.n_embd / config_.n_head; + buffers_[kFreqsCisName] = PrecomputeFreqsCis(head_dim, config_.block_size * 2, config_.rope_theta, + config_.use_scaled_rope, device); + } + + const auto t = x1->Dims()[1] * nn::parallel::global::GetSequenceParallelSize(); // full_seq_len + + // Dynamic start_pos (set to 0 for now) + int64_t start_pos = 0; + auto freqs_view = buffers_[kFreqsCisName]->Slice(0, start_pos, start_pos + t, 1); + + // Create causal mask + std::shared_ptr ones = std::make_shared(nn::function::Ones({t, t})->To(device)); + std::shared_ptr mask = nn::function::Triu(ones, 1)->View({1, 1, t, t}); + + std::shared_ptr start_pos_ptr = nullptr; + + // Pass RoPE parameters to each transformer block + for (auto &h : *std::dynamic_pointer_cast(modules_[kHLayerName])) { + x1 = (*h)({x1, freqs_view, start_pos_ptr, mask})[0]; + } + } else { + // Standard attention (GPT2 style) + for (auto &h : *std::dynamic_pointer_cast(modules_[kHLayerName])) { x1 = (*h)({x1})[0]; } + } + + return {x1}; +} + +// Add RoPE helper method to TransformerChunk +std::shared_ptr TransformerChunk::PrecomputeFreqsCis(int64_t dim, int64_t end, float theta, bool use_scaled, + infini_train::Device device) { + auto dtype = DataType::kFLOAT32; + CHECK_GE(dim, 2) << "dim must be >= 2 for slicing"; + + auto arange = nn::init::Arange(0, dim, dtype, device)->Slice(0, 0, dim, 2); + auto freqs = 1.0f / nn::function::Pow(theta, arange / float(dim)); + // TODO(zbl): use_scaled + // if (use_scaled) { + // freqs = ApplyScaling(freqs, 8192.0f); + // } + auto t = nn::init::Arange(0, end, dtype, device); + // (end, dim / 2) + auto freqs_outer = t->Outer(freqs); + auto cos = nn::function::Cos(freqs_outer); + auto sin = nn::function::Sin(freqs_outer); + // NOTE(zbl): torch script uses cis expression, here use stack + // (end, dim / 2, 2) + auto freqs_cis = nn::function::Stack(std::vector>{cos, sin}, -1)->Contiguous(); + + return freqs_cis; +} + +TransformerLastStage::TransformerLastStage(const TransformerConfig 
// Last pipeline stage: final normalization followed by the LM head.
TransformerLastStage::TransformerLastStage(const TransformerConfig &config, const ModuleSpec &spec)
    : CloneableModule(kType), config_(config), spec_(spec) {
    CHECK(spec.submodules_.contains(kLnFLayerName)) << "TransformerLastStage spec missing submodule: " << kLnFLayerName;
    CHECK(spec.submodules_.contains(kLMHeadLayerName))
        << "TransformerLastStage spec missing submodule: " << kLMHeadLayerName;
    modules_[kLnFLayerName] = build_module(config, spec.submodules_.at(kLnFLayerName));
    modules_[kLMHeadLayerName] = build_module(config, spec.submodules_.at(kLMHeadLayerName));
}

// x[0]: (B, T, C) -> final norm -> LM head -> {(B, T, V)} logits.
std::vector<std::shared_ptr<Tensor>> TransformerLastStage::Forward(const std::vector<std::shared_ptr<Tensor>> &x) {
    // (B, T, C) -> Layernorm -> (B, T, C)
    auto x1 = (*modules_[kLnFLayerName])(x);

    // TODO(dcj): add inference-time mini-optimization
    // (B, T, C) -> Linear(C, V) -> (B, T, V)
    return (*modules_[kLMHeadLayerName])(x1);
}

// Assembles the full decode-only transformer for this pipeline-parallel rank:
// the first stage (embeddings) and last stage (norm + LM head) where this rank
// owns them, plus this rank's chunks of blocks, and a flat `transformer`
// ModuleDict view used for checkpoint-compatible parameter naming.
// NOTE(review): `config` is taken by value here -- a const reference would
// avoid the copy, but the declaration lives in the header; confirm before changing.
TransformerLayer::TransformerLayer(const TransformerConfig config, const ModuleSpec &spec
                                   /*, const std::unordered_map<std::string, std::any> &params*/)
    : CloneableModule(kType), config_(config), spec_(spec),
      stage_info_(nn::parallel::PipelineParallel::GetStageInfo(
          config_.n_layer, nn::parallel::global::GetPipelineParallelSize(), nn::parallel::pp_rank,
          nn::parallel::global::GetVirtualPipelineParallelSize())) {

    auto tp_world_size = nn::parallel::global::GetTensorParallelSize();

    // NOTE(zbl): VocabParallelEmbedding requires vocab_size % tp_size == 0
    // Megatron-LM has an optional argument `--make-vocab-size-divisible-by`, would do padding to vocab
    // Here we introduce padding by default, might need modify Tokenizer correspondingly later
    CHECK_EQ(config.vocab_size % tp_world_size, 0) << "Vocab size should be divisible by TP world size";

    // Flat name->module view mirroring the GPT2 checkpoint layout.
    std::unordered_map<std::string, std::shared_ptr<Module>> transformer;
    if (stage_info_.is_first_stage) {
        modules_[kPPFirstStageName]
            = std::make_shared<TransformerFirstStage>(config_, spec_.submodules_.at(TransformerFirstStage::kType));
        transformer[TransformerFirstStage::kWTELayerName]
            = modules_[kPPFirstStageName]->mutable_module(TransformerFirstStage::kWTELayerName);
        // wpe exists only for absolute-position (standard-attention) models.
        if (config_.attention_type == AttentionType::kStandard) {
            transformer[TransformerFirstStage::kWPELayerName]
                = modules_[kPPFirstStageName]->mutable_module(TransformerFirstStage::kWPELayerName);
        }
    }

    {
        // Build each contiguous [start, end) layer range as a TransformerChunk;
        // the std::map orders chunks by start layer so the flat `h` list below
        // matches global layer order even with virtual pipeline interleaving.
        std::map<int, std::pair<int, std::shared_ptr<TransformerChunk>>> start_layer_to_layer_size_and_chunk;
        for (int chunk_idx = 0; chunk_idx < stage_info_.layer_ranges_per_chunk.size(); ++chunk_idx) {
            const auto [start_layer, end_layer] = stage_info_.layer_ranges_per_chunk[chunk_idx];
            auto chunk = std::make_shared<TransformerChunk>(config_, start_layer, end_layer,
                                                            spec_.submodules_.at(TransformerBlock::kType));
            start_layer_to_layer_size_and_chunk[start_layer] = std::make_pair(end_layer - start_layer, chunk);
        }
        std::vector<std::shared_ptr<Module>> h;
        int chunk_idx = 0;
        for (auto &[start_layer, layer_size_and_chunk] : start_layer_to_layer_size_and_chunk) {
            auto [layer_size, chunk] = layer_size_and_chunk;
            // Alias each block into the flat list (shared ownership, not copies).
            for (int idx = 0; idx < layer_size; ++idx) {
                h.push_back(chunk->mutable_module(TransformerChunk::kHLayerName)->mutable_module(std::to_string(idx)));
            }
            modules_[kPPChunkNamePrefix + std::to_string(chunk_idx)] = std::move(chunk);
            ++chunk_idx;
        }
        transformer[TransformerChunk::kHLayerName] = std::make_shared<ModuleList>(std::move(h));
    }

    if (stage_info_.is_last_stage) {
        modules_[kPPLastStageName]
            = std::make_shared<TransformerLastStage>(config_, spec_.submodules_.at(TransformerLastStage::kType));
        transformer[TransformerLastStage::kLnFLayerName]
            = modules_[kPPLastStageName]->mutable_module(TransformerLastStage::kLnFLayerName);
        modules_[TransformerLastStage::kLMHeadLayerName]
            = modules_[kPPLastStageName]->mutable_module(TransformerLastStage::kLMHeadLayerName);
    }
    modules_[kTransformerLayerName] = std::make_shared<ModuleDict>(std::move(transformer));

    // FIXME(jym): Assigning the parameter values of wte to LMHead, which is not real tying operation
    // TODO: Implement real GPT-2 weight tying: make lm_head.weight share the exact same Parameter/Tensor (same
    // shared_ptr/storage) as transformer.wte.weight (pointer aliasing, not value copy), and ensure the tie is applied
    // after loading weights so it won't be overwritten. Also fix GPT2::FromLLMC() loading logic to respect weight
    // tying (do not create/load a separate lm_head.weight tensor; load once into the tied weight) so parameter
    // counting matches PyTorch/PEFT.
    if (config_.tie_weights && nn::parallel::global::GetPipelineParallelSize() == 1) {
        // https://paperswithcode.com/method/weight-tying
        *mutable_module(kTransformerLayerName)
             ->mutable_module(TransformerFirstStage::kWTELayerName)
             ->mutable_parameter(nn::parallel::VocabParallelEmbedding::kParamWeightName)
            = module(TransformerLastStage::kLMHeadLayerName)
                  .parameter(nn::parallel::ColumnParallelLinear::kParamWeightName);
    }
}

// Runs first stage, every local chunk, then last stage in order.
// NOTE(review): modules_[kPPFirstStageName] / modules_[kPPLastStageName] are
// only created on the respective pipeline stages, yet this Forward dereferences
// both unconditionally -- it looks safe only for PP == 1; confirm how the
// pipeline runtime drives non-boundary stages.
std::vector<std::shared_ptr<Tensor>> TransformerLayer::Forward(const std::vector<std::shared_ptr<Tensor>> &x) {
    auto x1 = (*modules_[kPPFirstStageName])(x);
    for (int chunk_idx = 0; chunk_idx < stage_info_.layer_ranges_per_chunk.size(); ++chunk_idx) {
        x1 = (*modules_[kPPChunkNamePrefix + std::to_string(chunk_idx)])(x1);
    }
    return (*modules_[kPPLastStageName])(x1);
}

} // namespace infini_train::nn
"infini_train/src/core/runtime/cpu/cpu_guard_impl.h" + +using namespace infini_train; +namespace nn = infini_train::nn; + +// ============================================================================ +// Test 1: Basic Module Registration +// ============================================================================ +void test_module_registry() { + std::cout << "\n=== Test 1: Module Registration ===" << std::endl; + + bool all_registered = true; + + // Check all required modules are registeredP + if (!nn::ModuleRegistry::Instance().Has(typeid(nn::CausalSelfAttention))) { + std::cout << "FAIL: CausalSelfAttention not registered" << std::endl; + all_registered = false; + } + + if (!nn::ModuleRegistry::Instance().Has(typeid(nn::MLP))) { + std::cout << "FAIL: MLP not registered" << std::endl; + all_registered = false; + } + + if (!nn::ModuleRegistry::Instance().Has(typeid(nn::TransformerBlock))) { + std::cout << "FAIL: TransformerBlock not registered" << std::endl; + all_registered = false; + } + + if (!nn::ModuleRegistry::Instance().Has(typeid(nn::LayerNorm))) { + std::cout << "FAIL: LayerNorm not registered" << std::endl; + all_registered = false; + } + + if (!nn::ModuleRegistry::Instance().Has(typeid(nn::RMSNorm))) { + std::cout << "FAIL: RMSNorm not registered" << std::endl; + all_registered = false; + } + + if (!nn::ModuleRegistry::Instance().Has(typeid(nn::Embedding))) { + std::cout << "FAIL: Embedding not registered" << std::endl; + all_registered = false; + } + + if (!nn::ModuleRegistry::Instance().Has(typeid(nn::parallel::ColumnParallelLinear))) { + std::cout << "FAIL: ColumnParallelLinear not registered" << std::endl; + all_registered = false; + } + + if (!nn::ModuleRegistry::Instance().Has(typeid(nn::parallel::RowParallelLinear))) { + std::cout << "FAIL: RowParallelLinear not registered" << std::endl; + all_registered = false; + } + + if (!nn::ModuleRegistry::Instance().Has(typeid(nn::parallel::VocabParallelEmbedding))) { + std::cout << "FAIL: 
VocabParallelEmbedding not registered" << std::endl; + all_registered = false; + } + + if (all_registered) { + std::cout << "SUCCESS: All required modules are registered!" << std::endl; + } +} + +// ============================================================================ +// Test 2: GPT2 Spec Building +// ============================================================================ +void test_gpt2_spec() { + std::cout << "\n=== Test 2: GPT2 Spec Building ===" << std::endl; + + // Create GPT2 configuration + nn::TransformerConfig config = nn::TransformerConfig::GPT2(); + config.block_size = 1024; + config.vocab_size = 50257; + config.n_layer = 12; + config.n_head = 12; + config.n_embd = 768; + + // Build GPT2 spec + nn::ModuleSpec spec = nn::BuildGPT2Spec(config); + + // Verify spec structure + bool test_passed = true; + + if (spec.submodules_.empty()) { + std::cout << "FAIL: GPT2 spec has no submodules" << std::endl; + test_passed = false; + } + + if (!spec.submodules_.contains(nn::TransformerFirstStage::kType)) { + std::cout << "FAIL: GPT2 spec missing 'first_stage'" << std::endl; + test_passed = false; + } + + if (!spec.submodules_.contains(nn::TransformerBlock::kType)) { + std::cout << "FAIL: GPT2 spec missing 'block'" << std::endl; + test_passed = false; + } + + if (!spec.submodules_.contains(nn::TransformerLastStage::kType)) { + std::cout << "FAIL: GPT2 spec missing 'last_stage'" << std::endl; + test_passed = false; + } + + // Verify first_stage submodules + auto &first_stage = spec.submodules_[nn::TransformerFirstStage::kType]; + if (!first_stage.submodules_.contains("wte")) { + std::cout << "FAIL: first_stage missing 'wte'" << std::endl; + test_passed = false; + } + + if (!first_stage.submodules_.contains("wpe")) { + std::cout << "FAIL: first_stage missing 'wpe'" << std::endl; + test_passed = false; + } + + // Verify last_stage submodules + auto &last_stage = spec.submodules_[nn::TransformerLastStage::kType]; + if 
(!last_stage.submodules_.contains("ln_f")) { + std::cout << "FAIL: last_stage missing 'ln_f'" << std::endl; + test_passed = false; + } + + if (!last_stage.submodules_.contains("lm_head")) { + std::cout << "FAIL: last_stage missing 'lm_head'" << std::endl; + test_passed = false; + } + + if (test_passed) { + std::cout << "SUCCESS: GPT2 spec structure is correct!" << std::endl; + } +} + +// ============================================================================ +// Test 3: LLaMA3 Spec Building +// ============================================================================ +void test_llama3_spec() { + std::cout << "\n=== Test 3: LLaMA3 Spec Building ===" << std::endl; + + // Create LLaMA3 configuration + nn::TransformerConfig config = nn::TransformerConfig::LLaMA3(); + config.block_size = 8192; + config.vocab_size = 128256; + config.n_layer = 32; + config.n_head = 32; + config.n_kv_head = 8; + config.n_embd = 4096; + config.ffn_dim_multiplier = 1.3f; + config.multiple_of = 256; + + // Build LLaMA3 spec + nn::ModuleSpec spec = nn::BuildLLaMA3Spec(config); + + // Verify spec structure + bool test_passed = true; + + if (spec.submodules_.empty()) { + std::cout << "FAIL: LLaMA3 spec has no submodules" << std::endl; + test_passed = false; + } + + if (!spec.submodules_.contains(nn::TransformerFirstStage::kType)) { + std::cout << "FAIL: LLaMA3 spec missing 'first_stage'" << std::endl; + test_passed = false; + } + + if (!spec.submodules_.contains(nn::TransformerBlock::kType)) { + std::cout << "FAIL: LLaMA3 spec missing 'block'" << std::endl; + test_passed = false; + } + + if (!spec.submodules_.contains(nn::TransformerLastStage::kType)) { + std::cout << "FAIL: LLaMA3 spec missing 'last_stage'" << std::endl; + test_passed = false; + } + + // Verify first_stage has only wte (LLaMA3 uses RoPE) + auto &first_stage = spec.submodules_[nn::TransformerFirstStage::kType]; + if (!first_stage.submodules_.contains("wte")) { + std::cout << "FAIL: first_stage missing 'wte'" << 
std::endl; + test_passed = false; + } + + if (first_stage.submodules_.contains("wpe")) { + std::cout << "FAIL: first_stage should not have 'wpe' (LLaMA3 uses RoPE)" << std::endl; + test_passed = false; + } + + if (test_passed) { + std::cout << "SUCCESS: LLaMA3 spec structure is correct!" << std::endl; + } +} + +// ============================================================================ +// Test 4: GPT2 Model Instantiation +// ============================================================================ +void test_gpt2_instantiation() { + std::cout << "\n=== Test 4: GPT2 Model Instantiation ===" << std::endl; + + nn::TransformerConfig config = nn::TransformerConfig::GPT2(); + config.block_size = 1024; + config.vocab_size = 50257; + config.n_layer = 12; + config.n_head = 12; + config.n_embd = 768; + + try { + auto model = std::make_shared(config); + + if (model == nullptr) { + std::cout << "FAIL: Failed to create GPT2 model" << std::endl; + } else if (model->Parameters().empty()) { + std::cout << "FAIL: GPT2 model has no parameters" << std::endl; + } else { + std::cout << "SUCCESS: GPT2 model created with " << model->Parameters().size() << " parameters!" 
+ << std::endl; + } + } catch (const std::exception &e) { + std::cout << "FAIL: Exception during GPT2 model creation: " << e.what() << std::endl; + } +} + +// ============================================================================ +// Test 5: LLaMA3 Model Instantiation +// ============================================================================ +void test_llama3_instantiation() { + std::cout << "\n=== Test 5: LLaMA3 Model Instantiation ===" << std::endl; + + nn::TransformerConfig config = nn::TransformerConfig::LLaMA3(); + + try { + auto model = std::make_shared(config); + + if (model == nullptr) { + std::cout << "FAIL: Failed to create LLaMA3 model" << std::endl; + } else if (model->Parameters().empty()) { + std::cout << "FAIL: LLaMA3 model has no parameters" << std::endl; + } else { + std::cout << "SUCCESS: LLaMA3 model created with " << model->Parameters().size() << " parameters!" + << std::endl; + } + } catch (const std::exception &e) { + std::cout << "FAIL: Exception during LLaMA3 model creation: " << e.what() << std::endl; + } +} + +// ============================================================================ +// Test 6: Dimension Validation (Simplified) +// ============================================================================ +void test_dimensions() { + std::cout << "\n=== Test 6: Dimension Validation ===" << std::endl; + + nn::TransformerConfig config = nn::TransformerConfig::GPT2(); + config.block_size = 1024; + config.vocab_size = 50257; + config.n_layer = 12; + config.n_head = 12; + config.n_embd = 768; + + try { + auto model = std::make_shared(config); + + // Create input tensor (batch, seq_len) + std::vector input_shape = {2, 64}; + auto input = std::make_shared(input_shape, DataType::kINT64, Device()); + + // Forward pass + auto output = (*model)({input}); + + // Verify output dimensions (batch, seq_len, vocab_size) + if (output.empty()) { + std::cout << "FAIL: Model produced no output" << std::endl; + } else if 
(output[0]->Dims().size() != 3) { + std::cout << "FAIL: Expected 3D output, got " << output[0]->Dims().size() << "D" << std::endl; + } else if (output[0]->Dims()[0] != 2) { + std::cout << "FAIL: Expected batch size 2, got " << output[0]->Dims()[0] << std::endl; + } else if (output[0]->Dims()[1] != 64) { + std::cout << "FAIL: Expected seq length 64, got " << output[0]->Dims()[1] << std::endl; + } else if (output[0]->Dims()[2] != config.vocab_size) { + std::cout << "FAIL: Expected vocab size " << config.vocab_size << ", got " << output[0]->Dims()[2] + << std::endl; + } else { + std::cout << "SUCCESS: Output dimensions are correct! (" << output[0]->Dims()[0] << ", " + << output[0]->Dims()[1] << ", " << output[0]->Dims()[2] << ")" << std::endl; + } + } catch (const std::exception &e) { + std::cout << "FAIL: Exception during dimension test: " << e.what() << std::endl; + } +} + +// ============================================================================ +// Main +// ============================================================================ +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + + nn::parallel::global::GlobalEnv::Instance().Init(1, 1, false, 1, 1); + + std::cout << "========================================" << std::endl; + std::cout << " Transformer Spec Tests" << std::endl; + std::cout << "========================================" << std::endl; + + test_module_registry(); + test_gpt2_spec(); + test_llama3_spec(); + test_gpt2_instantiation(); + test_llama3_instantiation(); + test_dimensions(); + + std::cout << "\n========================================" << std::endl; + std::cout << " All Tests Completed" << std::endl; + std::cout << "========================================" << std::endl; + + return 0; +}