diff --git a/CMakeLists.txt b/CMakeLists.txt index 699d2862a8..c33f0bf260 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,6 +19,7 @@ project(TurboMind LANGUAGES CXX CUDA) if (MSVC) # use standard conformant preprocessor add_compile_options($<$:/Zc:preprocessor>) + add_compile_options($<$:/Zc:__cplusplus>) set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=/Zc:preprocessor -Xcompiler=/Zc:__cplusplus") endif () @@ -101,6 +102,10 @@ if(NOT xgrammar_POPULATED) # Bring the populated content into the build add_subdirectory(${xgrammar_SOURCE_DIR} ${xgrammar_BINARY_DIR}) + if(TARGET xgrammar) + target_compile_options(xgrammar PRIVATE $<$:/utf-8>) + target_compile_options(xgrammar PRIVATE $<$:/utf-8>) + endif() endif() # the environment variable diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py index 3a5af42e3e..e5f80411b5 100644 --- a/benchmark/profile_throughput.py +++ b/benchmark/profile_throughput.py @@ -327,6 +327,7 @@ def parse_args(): tb_group._group_actions.append(dtype_act) ArgumentHelper.dp(tb_group) + ArgumentHelper.cp(tb_group) ArgumentHelper.model_format(tb_group, default='hf') ArgumentHelper.num_tokens_per_iter(tb_group) ArgumentHelper.max_prefill_iters(tb_group) @@ -344,6 +345,7 @@ def main(): max_batch_size=args.concurrency // args.dp, tp=args.tp, dp=args.dp, + cp=args.cp, cache_max_entry_count=args.cache_max_entry_count, cache_block_seq_len=args.cache_block_seq_len, model_format=args.model_format, diff --git a/builder/windows/generate.ps1 b/builder/windows/generate.ps1 index 0c133b37d0..e54f8fe742 100644 --- a/builder/windows/generate.ps1 +++ b/builder/windows/generate.ps1 @@ -1,4 +1,4 @@ -cmake .. -A x64 -T "v142,cuda=$env:CUDA_PATH" ` +cmake .. -A x64 -T "v143,cuda=$env:CUDA_PATH" ` -DCMAKE_BUILD_TYPE=Release ` -DCMAKE_INSTALL_PREFIX=install ` -DBUILD_PY_FFI=ON ` diff --git a/docs/en/advance/context_parallel.md b/docs/en/advance/context_parallel.md new file mode 100644 index 0000000000..cf0c97f48b --- /dev/null +++ b/docs/en/advance/context_parallel.md @@ -0,0 +1,24 @@ +# Context Parallel + +When the memory on a single GPU is insufficient to deploy a model, it is often deployed using tensor parallelism (TP), which generally requires `num_key_value_heads` to be divisible by `TP`. If you want to deploy with `TP > num_key_value_heads`, the kv-heads must be duplicated to meet the divisibility requirement. However, this has two disadvantages: + +1. The amount of available kv_cache is halved, which reduces the maximum supported session length. +2. The maximum inference batch size is reduced, leading to lower throughput. + +To address this issue, the TurboMind inference backend supports setting `attn_dp_size`, which avoids creating copies of kv-heads but introduces data imbalance. To eliminate the data imbalance, TurboMind supports sequence parallelism, which allows the kv_cache to be stored interleaved across different cp_ranks. See the example below: + +``` +cp_size=2, prompt_len=5, generation_len=4 +kv_cache stored on cp_rank0: 0, 2, 4, 6, 8 +kv_cache stored on cp_rank1: 1, 3, 5, 7 +```
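To make the interleaving concrete, here is a small self-contained Python sketch (illustration only, not TurboMind code; `cp_shard` is a made-up helper) that reproduces the mapping above:

```python
def cp_shard(num_tokens: int, cp_size: int) -> dict:
    """Round-robin assignment of kv_cache positions to cp ranks (illustration only)."""
    shards = {rank: [] for rank in range(cp_size)}
    for pos in range(num_tokens):
        shards[pos % cp_size].append(pos)
    return shards

# prompt_len=5 + generation_len=4 -> 9 kv_cache positions in total
print(cp_shard(9, cp_size=2))
# {0: [0, 2, 4, 6, 8], 1: [1, 3, 5, 7]}
```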
+## Usage + +Taking Intern-S1 / Qwen3-235B-A22B as an example, their `num_key_value_heads` is 4. If you want to deploy with `TP=8` and avoid duplication of kv_cache, you can deploy in the following way: + +``` +lmdeploy serve api_server internlm/Intern-S1 --tp 8 --cp 2 + +lmdeploy serve api_server Qwen/Qwen3-235B-A22B --tp 8 --cp 2 +```
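The same configuration can also be used from the offline `pipeline` API. A minimal sketch, assuming the `cp` field this change adds to `TurbomindEngineConfig` (model path and prompt are placeholders):

```python
from lmdeploy import pipeline, TurbomindEngineConfig

# With tp=8 and cp=2, the attention tp size becomes 8 / 2 = 4, which matches the
# 4 kv-heads of Intern-S1 / Qwen3-235B-A22B, so no kv-head duplication is needed.
backend_config = TurbomindEngineConfig(tp=8, cp=2)
pipe = pipeline('internlm/Intern-S1', backend_config=backend_config)
print(pipe(['Introduce yourself.']))
```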
diff --git a/docs/en/index.rst b/docs/en/index.rst index b64c230cb8..b28042a977 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -103,6 +103,7 @@ Documentation advance/pytorch_multinodes.md advance/pytorch_profiling.md advance/metrics.md + advance/context_parallel.md .. toctree:: :maxdepth: 1 diff --git a/docs/zh_cn/advance/context_parallel.md b/docs/zh_cn/advance/context_parallel.md new file mode 100644 index 0000000000..faea118505 --- /dev/null +++ b/docs/zh_cn/advance/context_parallel.md @@ -0,0 +1,24 @@ +# 序列并行 + +在单卡显存不足以部署模型的时候,通常会以 `TP` 的方式进行部署,而这一般要求 `num_key_value_heads` 被 `TP` 整除。如果要以 `TP > num_key_value_heads` 的方式进行部署,需要创建 kv-heads 的副本,以满足整除需求。但是这样会有两个缺点: + +1. 可用的 kvcache 数量减半,进而减少请求最大推理长度 +2. 降低推理的最大 batch 数量,减少吞吐量。 + +为了解决这个问题,TurboMind 推理后端支持设置 `attn_dp_size`,避免了创建 kv-heads 的副本,但是这会引入数据的不均衡性。为了消除数据的不均衡,TurboMind 支持了序列并行,将 kv_cache 交错存储到不同的 cp_rank 上,例如 + +``` +cp_size=2, prompt_len=5, generation_len=4 +kv_cache stored on cp_rank0: 0, 2, 4, 6, 8 +kv_cache stored on cp_rank1: 1, 3, 5, 7 +``` + +## 使用说明 + +以 `Intern-S1` / `Qwen3-235B-A22B` 为例,它们的 `num_key_value_heads` 为 4,若要用 `TP=8` 的方式部署,并避免 kv_cache 的拷贝,可以用如下的方式部署 + +``` +lmdeploy serve api_server internlm/Intern-S1 --tp 8 --cp 2 + +lmdeploy serve api_server Qwen/Qwen3-235B-A22B --tp 8 --cp 2 +``` diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst index bd946ba96e..733bfc585e 100644 --- a/docs/zh_cn/index.rst +++ b/docs/zh_cn/index.rst @@ -104,6 +104,7 @@ LMDeploy 工具箱提供以下核心功能: advance/pytorch_multinodes.md advance/pytorch_profiling.md advance/metrics.md + advance/context_parallel.md .. toctree:: :maxdepth: 1 diff --git a/lmdeploy/cli/cli.py b/lmdeploy/cli/cli.py index d71198791f..9c0a19138e 100644 --- a/lmdeploy/cli/cli.py +++ b/lmdeploy/cli/cli.py @@ -76,6 +76,7 @@ def add_parser_chat(): ArgumentHelper.model_format(tb_group) ArgumentHelper.rope_scaling_factor(tb_group) ArgumentHelper.communicator(tb_group) + ArgumentHelper.cp(tb_group) @staticmethod def add_parser_checkenv(): diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 8f67743951..4413afddcf 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -111,7 +111,7 @@ def add_parser_api_server(): model_format = ArgumentHelper.model_format(pt_group) hf_overrides = ArgumentHelper.hf_overrides(pt_group) disable_metrics = ArgumentHelper.disable_metrics(pt_group) - ArgumentHelper.dp(pt_group) + dp = ArgumentHelper.dp(pt_group) ArgumentHelper.ep(pt_group) ArgumentHelper.enable_microbatch(pt_group) ArgumentHelper.enable_eplb(pt_group) @@ -136,6 +136,8 @@ def add_parser_api_server(): tb_group._group_actions.append(model_format) tb_group._group_actions.append(hf_overrides) tb_group._group_actions.append(disable_metrics) + tb_group._group_actions.append(dp) + ArgumentHelper.cp(tb_group) ArgumentHelper.rope_scaling_factor(tb_group) ArgumentHelper.num_tokens_per_iter(tb_group) ArgumentHelper.max_prefill_iters(tb_group) @@ -233,6 +235,8 @@ def api_server(args): from lmdeploy.messages import TurbomindEngineConfig backend_config = TurbomindEngineConfig(dtype=args.dtype, tp=args.tp, + dp=args.dp, + cp=args.cp, max_batch_size=max_batch_size, session_len=args.session_len, model_format=args.model_format, @@ -251,7 +255,7 @@ def api_server(args): from lmdeploy.messages import VisionConfig vision_config = VisionConfig(args.vision_max_batch_size) - if args.dp == 1: + if args.dp == 1 or backend == 'turbomind': from lmdeploy.serve.openai.api_server import serve as run_api_server run_api_server(args.model_path, diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index a53a3cdc86..01a7768e97 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -188,6 +188,16 @@ def ep(parser): default=1, help='expert parallelism. dp is required when pytorch engine is used.') + @staticmethod + def cp(parser): + """Add argument cp to parser.""" + + return parser.add_argument( + '--cp', + type=int, + default=1, + help='context parallelism size in attention for turbomind backend. 
Should divide tp.') + @staticmethod def dp_rank(parser): """Add argument dp_rank to parser.""" diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 57725eb23f..fdab5d502e 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -235,8 +235,10 @@ class TurbomindEngineConfig: model_format: Optional[str] = None tp: int = 1 dp: int = 1 + cp: int = 1 device_num: int = None attn_tp_size: int = None + attn_cp_size: int = None attn_dp_size: int = None mlp_tp_size: int = None mlp_dp_size: int = None diff --git a/lmdeploy/turbomind/deploy/config.py b/lmdeploy/turbomind/deploy/config.py index ee7feb166f..5695cc2325 100644 --- a/lmdeploy/turbomind/deploy/config.py +++ b/lmdeploy/turbomind/deploy/config.py @@ -69,6 +69,7 @@ class ModelConfig: expert_weight_type: str = None session_len: int = None attn_tp_size: int = 1 + attn_cp_size: int = 1 mlp_tp_size: int = 1 model_format: str = 'hf' expert_num: List[int] = () diff --git a/lmdeploy/turbomind/deploy/converter.py b/lmdeploy/turbomind/deploy/converter.py index b336dbd5e8..45bbf83dc1 100644 --- a/lmdeploy/turbomind/deploy/converter.py +++ b/lmdeploy/turbomind/deploy/converter.py @@ -179,6 +179,7 @@ def get_tm_model(model_path, tm_cfg.model_config.model_name = model_name tm_cfg.model_config.attn_tp_size = engine_config.attn_tp_size + tm_cfg.model_config.attn_cp_size = engine_config.attn_cp_size tm_cfg.model_config.mlp_tp_size = engine_config.mlp_tp_size output_model = OUTPUT_MODELS.get(output_model_name)(input_model=input_model, diff --git a/lmdeploy/turbomind/deploy/module.py b/lmdeploy/turbomind/deploy/module.py index 96ed4777a8..27f53ca452 100644 --- a/lmdeploy/turbomind/deploy/module.py +++ b/lmdeploy/turbomind/deploy/module.py @@ -335,14 +335,13 @@ def pad_weight(tensor: torch.Tensor, tp: int): return tensor return torch.nn.functional.pad(tensor, (0, 0, 0, pad_size), 'constant', 0) + tp = self.model.attn_tp_size * self.model.attn_cp_size if emb is not None: - tp = self.model.attn_tp_size emb = pad_weight(emb, tp=tp) self.model.save_split(emb, 'tok_embeddings.weight', split_dim=1, split_num=tp) if norm_weight is not None: self.model.export_weight(norm_weight, 'norm.weight') if output_weight is not None: - tp = self.model.attn_tp_size output_weight = pad_weight(output_weight, tp=tp) # transpose self.model.save_split(output_weight.t(), 'output.weight', split_dim=1, split_num=tp) diff --git a/lmdeploy/turbomind/deploy/target_model/base.py b/lmdeploy/turbomind/deploy/target_model/base.py index 408e23e37f..d796848259 100644 --- a/lmdeploy/turbomind/deploy/target_model/base.py +++ b/lmdeploy/turbomind/deploy/target_model/base.py @@ -51,6 +51,7 @@ def __init__(self, input_model: BaseInputModel, cfg: TurbomindModelConfig, model self.attention_config = cfg.attention_config self.lora_config = cfg.lora_config self.attn_tp_size = self.model_config.attn_tp_size + self.attn_cp_size = self.model_config.attn_cp_size self.mlp_tp_size = self.model_config.mlp_tp_size self.out_dir = out_dir self.to_file = True if out_dir else False diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index db1bc67c48..ab9ddfea0a 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -100,11 +100,12 @@ def update_parallel_config(cfg: TurbomindEngineConfig): inner_tp_size = cfg.tp // mlp_tp_size cfg.outer_dp_size = cfg.dp // attn_dp_size cfg.attn_dp_size = attn_dp_size - cfg.attn_tp_size = inner_tp_size + cfg.attn_tp_size = inner_tp_size // cfg.cp + cfg.attn_cp_size = cfg.cp cfg.mlp_dp_size = 1 cfg.mlp_tp_size = 
mlp_tp_size * inner_tp_size - assert cfg.attn_dp_size * cfg.attn_tp_size == cfg.mlp_dp_size * cfg.mlp_tp_size - assert cfg.attn_dp_size * cfg.attn_tp_size * cfg.outer_dp_size == cfg.device_num + assert cfg.attn_dp_size * cfg.attn_tp_size * cfg.attn_cp_size == cfg.mlp_dp_size * cfg.mlp_tp_size + assert cfg.attn_dp_size * cfg.attn_tp_size * cfg.attn_cp_size * cfg.outer_dp_size == cfg.device_num cfg.devices = cfg.devices or list(range(cfg.device_num)) @@ -273,6 +274,7 @@ def _from_hf(self, model_path: str, engine_config: TurbomindEngineConfig): self._postprocess_config(tm_model.tm_config, engine_config) + print(yaml.safe_dump(self.config_dict)) model_comm = _tm.AbstractTransformerModel.create_llama_model(model_dir='', config=yaml.safe_dump(self.config_dict), weight_type=self.config.model_config.weight_type) diff --git a/src/turbomind/comm/nccl/nccl.cu b/src/turbomind/comm/nccl/nccl.cu index 44b6e8d55b..af4faf29aa 100644 --- a/src/turbomind/comm/nccl/nccl.cu +++ b/src/turbomind/comm/nccl/nccl.cu @@ -65,7 +65,7 @@ static NcclApis& nccl_apis() static auto value = [] { int version{}; ncclGetVersion(&version); - auto handle = dlopen(nullptr, RTLD_LAZY); + auto handle = dlopen("libnccl.so.2", RTLD_LAZY); NcclApis apis{}; if (!handle) { return apis; diff --git a/src/turbomind/kernels/attention/CMakeLists.txt b/src/turbomind/kernels/attention/CMakeLists.txt index d9711f112c..5ea2d64e3b 100644 --- a/src/turbomind/kernels/attention/CMakeLists.txt +++ b/src/turbomind/kernels/attention/CMakeLists.txt @@ -45,7 +45,7 @@ set_property(TARGET attention PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_compile_options(attention PRIVATE -O3 $<$:-use_fast_math --expt-relaxed-constexpr>) - +target_link_libraries(attention PRIVATE nvidia::cutlass::cutlass) if (BUILD_TEST) target_compile_options(attention PRIVATE @@ -60,6 +60,7 @@ if (BUILD_TEST) target_link_libraries(test_attention PRIVATE attention # flash_attention + nvidia::cutlass::cutlass Llama unfused_attention_kernels logger @@ -68,4 +69,7 @@ if (BUILD_TEST) add_executable(test_quant test_quant.cu test_utils.cu) target_compile_options(test_quant PRIVATE --generate-line-info -O3 -use_fast_math --expt-relaxed-constexpr) + target_link_libraries(test_quant PRIVATE + nvidia::cutlass::cutlass + ) endif () diff --git a/src/turbomind/kernels/attention/attention_params.h b/src/turbomind/kernels/attention/attention_params.h index 59a04368fa..4ec526d3e2 100644 --- a/src/turbomind/kernels/attention/attention_params.h +++ b/src/turbomind/kernels/attention/attention_params.h @@ -2,6 +2,7 @@ #pragma once +#include "cutlass/fast_math.h" #include #include @@ -23,6 +24,8 @@ struct BlockIteratorParams { int block_len; }; +typedef void (*cp_post_fn)(void* context); + /// TODO: Rename to attention::Param template struct AttentionParams { @@ -75,9 +78,14 @@ struct AttentionParams { int max_split_k; int* split_cnt; float* partial_O; - float* partial_M; - float* partial_L; - int* locks; + float* partial_ML; + + // context parallel + int cp_rank{0}; + cutlass::FastDivmod cp_size{1}; + int offset_q{0}; // decode offset + cp_post_fn cp_fn{nullptr}; + void* cp_fn_ctx{nullptr}; int arch; cudaStream_t stream; diff --git a/src/turbomind/kernels/attention/attention_template.h b/src/turbomind/kernels/attention/attention_template.h index 02dd8d20af..5c8d0ddbb7 100644 --- a/src/turbomind/kernels/attention/attention_template.h +++ b/src/turbomind/kernels/attention/attention_template.h @@ -12,8 +12,7 @@ namespace turbomind { 
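// Descriptive note: the split-k and context-parallel partial results (partial_O, partial_ML) produced by
// the kernel below are combined afterwards by the separate attention::invokeReduceV2 pass at the end of
// this function, so only the kernel's own SharedStorage is needed for the shared-memory size here.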
template void invokeAttention(const typename Kernel::ParamType& params) { - static const size_t kSmemSize = - std::max(sizeof(typename Kernel::SharedStorage), sizeof(typename Kernel::ReduceOp::SharedStorage)); + static const size_t kSmemSize = sizeof(typename Kernel::SharedStorage); if constexpr (1) { @@ -45,7 +44,8 @@ void invokeAttention(const typename Kernel::ParamType& params) return int2{sm_count, max_active_ctas}; }(); - const int tile_count = cdiv(std::min(params.max_k_len, params.window_size), Kernel::CTA_S); + const int max_cp_k_len = (params.max_k_len + params.cp_size - 1) / params.cp_size; + const int tile_count = cdiv(std::min(max_cp_k_len, params.window_size), Kernel::CTA_S); const int max_split_count = std::min(params.max_split_k, tile_count); typename Kernel::CtaMap cta_map{ @@ -80,18 +80,23 @@ void invokeAttention(const typename Kernel::ParamType& params) std::abort(); } - if (split_cnt > 1 && Kernel::need_separate_reduce(split_cnt)) { - attention::invokeReduce(params.out, - params.partial_M, - params.partial_L, - params.partial_O, - params.split_cnt, - params.max_split_k, - split_cnt, - params.token_num, - params.num_heads, - params.inv_sqrt_dh, - params.stream); + if (params.cp_fn) { + params.cp_fn(params.cp_fn_ctx); + } + + if (split_cnt > 1 || params.cp_size > 1) { + attention::invokeReduceV2(params.out + params.offset_q * params.num_heads * Kernel::kHeadDim, + params.partial_ML, + params.partial_O, + split_cnt > 1 ? params.split_cnt : nullptr, + params.max_split_k, + split_cnt, + params.cp_size, + params.cp_rank, + params.token_num, + params.num_heads, + params.inv_sqrt_dh, + params.stream); } } diff --git a/src/turbomind/kernels/attention/attention_universal.h b/src/turbomind/kernels/attention/attention_universal.h index 5a1a9e7605..ce2719aa37 100644 --- a/src/turbomind/kernels/attention/attention_universal.h +++ b/src/turbomind/kernels/attention/attention_universal.h @@ -3,7 +3,6 @@ #pragma once #include "quantization.h" -#include "src/turbomind/kernels/attention/reduce_kernel.h" #include "src/turbomind/kernels/attention/rotary_embedding.h" #include "src/turbomind/kernels/core/array_ops.h" #include "src/turbomind/kernels/core/layout.h" @@ -46,8 +45,6 @@ struct AttentionUniversal { static constexpr int CTA_Q = Impl::CTA_Q; static constexpr int CTA_S = Impl::CTA_S; - using ReduceOp = attention::Reduce; - using SharedStorage = typename Mainloop::SharedStorage; static constexpr bool kProcessKV = CTA_Q == 1; @@ -256,6 +253,9 @@ struct AttentionUniversal { const int qi = offset.y / CTA_H; const int ti = history_len; + int cp_quo, cp_rem; + cp_quo = params.cp_size.divmod(cp_rem, ti); + Array param_K[1]; Array param_V[1]; @@ -276,7 +276,10 @@ struct AttentionUniversal { } iterator.block_head_.with( - iterator.block_ptrs_, ti, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) { + iterator.block_ptrs_, cp_quo, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) { + if (cp_rem != params.cp_rank) { + return; + } PRAGMA_UNROLL for (int c = 0; c < ITER_C; ++c) { const int di = offset.x + c * Map::kDeltaC; @@ -371,11 +374,18 @@ struct AttentionUniversal { const int context_len = params.cu_k_len[batch_idx + 1] - params.cu_k_len[batch_idx]; const int history_len = context_len - input_len; - const int last_K = history_len + min(query_idx + CTA_Q, input_len); - const int last_K_tile = (last_K - 1) / CTA_S + 1; // past-the-end index to past-the-end tile index conversion + auto get_cp_len = [&](int length, int rank) -> int { + int cp_quo, cp_rem; + cp_quo = 
params.cp_size.divmod(cp_rem, length); + return (cp_quo + (cp_rem > rank ? 1 : 0)); + }; + + const int last_K = history_len + min(query_idx + CTA_Q, input_len); + const int last_K_tile = + (get_cp_len(last_K, 0) - 1) / CTA_S + 1; // past-the-end index to past-the-end tile index conversion const int first_K = max(history_len + query_idx - (params.window_size - 1), 0); - const int first_K_tile = first_K / CTA_S; + const int first_K_tile = get_cp_len(first_K, 0) / CTA_S; const int tile_count = last_K_tile - first_K_tile; @@ -417,7 +427,7 @@ struct AttentionUniversal { const int offset_K = (first_K_tile + iter_end - 1) * CTA_S; // This is for avoiding OOB access only - const int max_K = min(context_len, (first_K_tile + iter_end) * CTA_S); + const int max_K = min(get_cp_len(context_len, params.cp_rank), (first_K_tile + iter_end) * CTA_S); int tile_iter = iter_end - iter_begin; @@ -430,6 +440,15 @@ struct AttentionUniversal { // -> x * CTA_S >= offset_Q + CTA_Q - offset_K + tile_iter * CTA_S - w int mask_iter_front = cdiv(max(0, offset_Q + CTA_Q - offset_K + tile_iter * CTA_S - params.window_size), CTA_S); + if (params.cp_size > 1) { + mask_iter_back = + cdiv(max(0, params.cp_size * (offset_K + CTA_S) - offset_Q + params.cp_rank), params.cp_size * CTA_S); + mask_iter_front = cdiv(max(0, + offset_Q + CTA_Q - params.window_size - params.cp_rank + - params.cp_size * (offset_K - tile_iter * CTA_S)), + params.cp_size * CTA_S); + } + #if 0 if (threadIdx.x == 0) { printf( @@ -453,6 +472,7 @@ struct AttentionUniversal { cache_iter.SetTile(first_K_tile + iter_end - 1); Mainloop mainloop; + mainloop.SetCpInfo(params.cp_size, params.cp_rank); mainloop(frag_Q, cache_iter, frag_O, @@ -482,66 +502,20 @@ struct AttentionUniversal { const bool separate_reduce = need_separate_reduce(cta_map.split_count()); - if (separate_reduce && iter_end == tile_count && head_idx == 0) { + if (split_cnt > 1 && iter_end == tile_count && head_idx == 0) { // Store actual split count, only used by separate reduction kernel for (int ti = threadIdx.x; ti < CTA_Q; ti += kWarpCount * WARP_SIZE) { if (qi_begin + ti < qi_end) { - params.split_cnt[qi_begin + ti] = split_idx ? split_idx + 1 : 0; + params.split_cnt[qi_begin + ti] = split_idx ? split_idx + 1 : (params.cp_size > 1 ? 
1 : 0); } } } - if (iter_begin == 0 && iter_end == tile_count) { + if (iter_begin == 0 && iter_end == tile_count && params.cp_size == 1) { StoreO(frag_O, frag_L, qi_begin, qi_end, head_idx, params, storage); } else { - StorePartial(frag_O, frag_M, frag_L, qi_begin, qi_end, head_idx, split_idx, params, storage); - if (!separate_reduce) { - Reduce(qi_begin, head_idx, split_idx, iter_end == tile_count, params, cta_map, smem_buf); - } - } - } - - __device__ void Reduce(int qi_begin, - int head_idx, - int split_idx, - bool is_last, - const ParamType& params, - const CtaMap& cta_map, - char* smem_buf) - { - // Note: `head_idx` is cta_map.head_idx() * CTA_H - const auto index = (cta_map.batch_idx() * params.num_heads + cta_map.head_idx()) * params.max_split_k; - const auto locks = params.locks + index; - - if (!is_last) { // all but last split - sem_post(&locks[split_idx], 1, threadIdx.x == 0); - } - else { // only the last split - const int split_count = split_idx + 1; - - sem_wait_many(&locks[threadIdx.x], split_count - 1, threadIdx.x < split_count - 1); - - ReduceOp reduce_op; - reduce_op(params.out, - params.partial_M, - params.partial_L, - params.partial_O, - qi_begin, - head_idx, - params.num_heads, - hi_end_, - split_idx + 1, - params.max_split_k, - params.inv_sqrt_dh, - 1, - 0, - *(typename ReduceOp::SharedStorage*)smem_buf, - std::true_type{}); - - if (threadIdx.x < split_idx) { - locks[threadIdx.x] = 0; - } + StorePartial(frag_O, frag_M, frag_L, split_cnt, qi_begin, qi_end, head_idx, split_idx, params, storage); } } @@ -583,6 +557,7 @@ struct AttentionUniversal { __device__ void StorePartial(FragO& frag_O, FragM& frag_M, FragL& frag_L, + int split_cnt, int qi_begin, int qi_end, int head_idx, @@ -592,8 +567,8 @@ struct AttentionUniversal { { auto get_index = [&](int hi, int qi) { // [B, H, k, D] - return (qi_begin + qi) * params.num_heads * params.max_split_k + (head_idx + hi) * params.max_split_k - + split_idx; + return (qi_begin + qi - params.offset_q) * params.num_heads * params.max_split_k + + (head_idx + hi) * params.max_split_k + split_idx; }; Impl::StoreO(frag_O, frag_L, storage, [&](int hi, int qi, int di, const auto& vec) { @@ -605,8 +580,8 @@ struct AttentionUniversal { Impl::ForeachML(frag_M, frag_L, [&](int hi, int qi, int ri, float M, float L) { const int index = get_index(hi, qi); if (qi_begin + qi < qi_end && ri == 0 && check_h(hi)) { - params.partial_M[index] = M; - params.partial_L[index] = L; + params.partial_ML[index * 2] = M; + params.partial_ML[index * 2 + 1] = L; } }); } diff --git a/src/turbomind/kernels/attention/decoding_template.h b/src/turbomind/kernels/attention/decoding_template.h index 37f6baebe3..d35c09f8ff 100644 --- a/src/turbomind/kernels/attention/decoding_template.h +++ b/src/turbomind/kernels/attention/decoding_template.h @@ -12,8 +12,7 @@ namespace turbomind { template bool invokeDecoding(const typename Kernel::ParamType& params) { - static const size_t kSmemSize = - std::max(sizeof(typename Kernel::SharedStorage), sizeof(typename Kernel::ReduceOp::SharedStorage)); + static const size_t kSmemSize = sizeof(typename Kernel::SharedStorage); if constexpr (1) { [[maybe_unused]] static const int _ = [&] { @@ -25,7 +24,8 @@ bool invokeDecoding(const typename Kernel::ParamType& params) }(); } - const int tile_count = cdiv(std::min(params.max_k_len, params.window_size), Kernel::CTA_S); + const int max_cp_k_len = (params.max_k_len + params.cp_size - 1) / params.cp_size; + const int tile_count = cdiv(std::min(max_cp_k_len, params.window_size), Kernel::CTA_S); 
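// Illustration: each cp rank iterates only over its own interleaved slice of the kv_cache. With the
// example from the docs (cp_size = 2, prompt_len = 5, generation_len = 4, i.e. max_k_len = 9),
// max_cp_k_len is (9 + 2 - 1) / 2 = 5, so the per-rank tile count is derived from 5 positions instead of 9.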
const int max_split_count = std::min(params.max_split_k, tile_count); using CtaMap = typename Kernel::CtaMap; @@ -79,18 +79,23 @@ bool invokeDecoding(const typename Kernel::ParamType& params) std::abort(); } - if (Kernel::need_separate_reduce(split_cnt)) { - attention::invokeReduce(params.out, - params.partial_M, - params.partial_L, - params.partial_O, - params.split_cnt, - params.max_split_k, - split_cnt, - params.token_num, - params.num_heads, - params.inv_sqrt_dh, - params.stream); + if (params.cp_fn) { + params.cp_fn(params.cp_fn_ctx); + } + + if (split_cnt > 1 || params.cp_size > 1) { + attention::invokeReduceV2(params.out, + params.partial_ML, + params.partial_O, + split_cnt > 1 ? params.split_cnt : nullptr, + params.max_split_k, + split_cnt, + params.cp_size, + params.cp_rank, + params.token_num, + params.num_heads, + params.inv_sqrt_dh, + params.stream); } return true; diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu index adb697e8c4..72395c3808 100644 --- a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu +++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu @@ -13,6 +13,8 @@ namespace turbomind { +using cutlass::FastDivmod; + template __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks, const T* k, @@ -28,6 +30,8 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks, int64_t stride_h, int64_t stride_s, int layer_id, + int cp_rank, + FastDivmod cp_size, BlockLayout block_layout) { @@ -152,6 +156,8 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks, } } + int cp_quo, cp_rem; + blocks += cu_block_num[batch_idx]; block::Head block_head{block_layout, layer_id, head_idx}; @@ -159,9 +165,10 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks, PRAGMA_UNROLL for (int s = 0; s < ITER_S; ++s) { const int qi = offset.y + s * Map::kDeltaS + token_idx; // local offset into `input_length` - if (qi < q_len) { - const int ti = history_len + qi; // timestep - block_head.with((char**)blocks, ti, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) { + const int ti = history_len + qi; // timestep + cp_quo = cp_size.divmod(cp_rem, ti); + if (qi < q_len && cp_rem == cp_rank) { + block_head.with((char**)blocks, cp_quo, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) { PRAGMA_UNROLL for (int c = 0; c < ITER_C; ++c) { int di = offset.x + c * Map::kDeltaC; @@ -198,6 +205,8 @@ void invokeProcessKV_v2(char** blocks, int64_t stride_s, int block_seq_len, int layer_id, + int cp_rank, + FastDivmod cp_size, int max_q_len, int head_num, int head_dim, @@ -233,6 +242,8 @@ void invokeProcessKV_v2(char** blocks, stride_h, stride_s, layer_id, + cp_rank, + cp_size, block_layout); }; @@ -276,6 +287,8 @@ void invokeProcessKV_v2(char** blocks, int64_t stride_s, \ int block_seq_len, \ int layer_id, \ + int cp_rank, \ + FastDivmod cp_size, \ int max_q_len, \ int head_num, \ int head_dim, \ @@ -300,6 +313,8 @@ __global__ void __launch_bounds__(128) flattenKV_v2(T* k, int64_t stride_h, int64_t stride_s, int layer_id, + int cp_rank, + FastDivmod cp_size, BlockLayout block_layout) { constexpr int kVecSize = sizeof(uint4) / sizeof(T); @@ -341,11 +356,14 @@ __global__ void __launch_bounds__(128) flattenKV_v2(T* k, Array param_K[ITER_S]; Array param_V[ITER_S]; + int cp_quo, cp_rem; + PRAGMA_UNROLL for (int s = 0; s < ITER_S; ++s) { const int si = offset.y + s * Map::kDeltaS + token_idx; - if (si < seq_len) { - block_head.with((char**)blocks, si, [&](auto k_cache, auto v_cache, 
T* k_param, T* v_param) { + cp_quo = cp_size.divmod(cp_rem, si); + if (si < seq_len && cp_rem == cp_rank) { + block_head.with((char**)blocks, cp_quo, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) { PRAGMA_UNROLL for (int c = 0; c < ITER_C; ++c) { int di = offset.x + c * Map::kDeltaC; @@ -389,11 +407,12 @@ __global__ void __launch_bounds__(128) flattenKV_v2(T* k, for (int s = 0; s < ITER_S; ++s) { PRAGMA_UNROLL for (int c = 0; c < ITER_C; ++c) { - const int si = offset.y + s * Map::kDeltaS + token_idx; - const int di = offset.x + c * Map::kDeltaC; - const int64_t index = - (batch_idx * stride_b + ti_beg * stride_c + si * stride_s + head_idx * stride_h) * HeadDim + di; - if (si < seq_len) { + const int si = offset.y + s * Map::kDeltaS + token_idx; + const int di = offset.x + c * Map::kDeltaC; + cp_quo = cp_size.divmod(cp_rem, si); + if (si < seq_len && cp_rem == cp_rank) { + const int64_t index = + (batch_idx * stride_b + ti_beg * stride_c + cp_quo * stride_s + head_idx * stride_h) * HeadDim + di; Store(&k[index], out_K[s][c]); Store(&v[index], out_V[s][c]); } @@ -414,6 +433,8 @@ void invokeFlattenKV_v2(T* k, int64_t stride_s, int block_seq_len, int layer_id, + int cp_rank, + FastDivmod cp_size, int max_seq_len, int head_num, int head_dim, @@ -446,6 +467,8 @@ void invokeFlattenKV_v2(T* k, stride_h, stride_s, layer_id, + cp_rank, + cp_size, block_layout); }; @@ -486,6 +509,8 @@ void invokeFlattenKV_v2(T* k, int64_t stride_s, \ int block_seq_len, \ int layer_id, \ + int cp_rank, \ + FastDivmod cp_size, \ int max_seq_len, \ int head_num, \ int head_dim, \ diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.h b/src/turbomind/kernels/attention/kv_cache_utils_v2.h index 01525f5596..e06b329e55 100644 --- a/src/turbomind/kernels/attention/kv_cache_utils_v2.h +++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.h @@ -23,6 +23,8 @@ void invokeProcessKV_v2(char** blocks, int64_t stride_s, int block_seq_len, int layer_id, + int cp_rank, + cutlass::FastDivmod cp_size, int max_q_len, int head_num, int head_dim, @@ -48,6 +50,8 @@ void invokeProcessKV_v2_(const AttentionParams& params) params.stride / params.size_per_head, // stride s params.block_iter_params.block_len, params.block_iter_params.layer_id, + params.cp_rank, + params.cp_size, params.max_q_len, params.num_kv_heads, params.size_per_head, @@ -69,6 +73,8 @@ void invokeFlattenKV_v2(T* k, int64_t stride_s, int block_seq_len, int layer_id, + int cp_rank, + cutlass::FastDivmod cp_size, int max_seq_len, int head_num, int head_dim, @@ -93,6 +99,8 @@ void invokeFlattenKV_v2_(const AttentionParams& params, int sum_k_len) 1, params.block_iter_params.block_len, params.block_iter_params.layer_id, + params.cp_rank, + params.cp_size, params.max_k_len, params.num_kv_heads, params.size_per_head, diff --git a/src/turbomind/kernels/attention/mainloop_sm70.h b/src/turbomind/kernels/attention/mainloop_sm70.h index c4d2e5afeb..a030d372de 100644 --- a/src/turbomind/kernels/attention/mainloop_sm70.h +++ b/src/turbomind/kernels/attention/mainloop_sm70.h @@ -40,6 +40,15 @@ struct Mainloop { static constexpr int CTA_S = Impl::CTA_S; + int cp_size_{1}; + int cp_rank_{0}; + + __device__ void SetCpInfo(int cp_size, int cp_rank) + { + cp_size_ = cp_size; + cp_rank_ = cp_rank; + } + template __device__ void operator()(FragQ& frag_Q, CacheIter& cache_iter, @@ -128,7 +137,7 @@ struct Mainloop { __device__ void ApplyCasualMask(FragS& frag_S, int offset_Q, int offset_K, int window_size) { Impl::ForeachS(frag_S, [&](int hi, int qi, int si, int ri, float& 
score) { - int w = (offset_Q + qi) - (offset_K + si); + int w = (offset_Q + qi) - ((offset_K + si) * cp_size_ + cp_rank_); if (0 <= w && w < window_size) {} else { score -= std::numeric_limits::infinity(); diff --git a/src/turbomind/kernels/attention/mainloop_sm80.h b/src/turbomind/kernels/attention/mainloop_sm80.h index 997a6aa9fc..3b07b717e4 100644 --- a/src/turbomind/kernels/attention/mainloop_sm80.h +++ b/src/turbomind/kernels/attention/mainloop_sm80.h @@ -49,6 +49,15 @@ struct Mainloop, Impl_> { using SharedStorage = typename Impl::SharedStorage; + int cp_size_{1}; + int cp_rank_{0}; + + __device__ void SetCpInfo(int cp_size, int cp_rank) + { + cp_size_ = cp_size; + cp_rank_ = cp_rank; + } + template __device__ void operator()(Args&&... args) { @@ -442,7 +451,7 @@ struct Mainloop, Impl_> { __device__ void ApplyCasualMask(FragS& frag_S, int offset_Q, int offset_K, int window_size) { Impl::ForeachS(frag_S, [&](int hi, int qi, int si, int ri, float& score) { - int w = (offset_Q + qi) - (offset_K + si); + int w = (offset_Q + qi) - ((offset_K + si) * cp_size_ + cp_rank_); if (0 <= w && w < window_size) {} else { score -= std::numeric_limits::infinity(); diff --git a/src/turbomind/kernels/attention/reduce.cu b/src/turbomind/kernels/attention/reduce.cu index c654f40d05..c8e7f8df14 100644 --- a/src/turbomind/kernels/attention/reduce.cu +++ b/src/turbomind/kernels/attention/reduce.cu @@ -1,79 +1,355 @@ // Copyright (c) OpenMMLab. All rights reserved. +#include "cutlass/fast_math.h" #include "src/turbomind/kernels/attention/cta_map.h" -#include "src/turbomind/kernels/attention/reduce_kernel.h" +#include "src/turbomind/kernels/core/array_ops.h" +#include "src/turbomind/kernels/core/thread_map.h" +#include "src/turbomind/utils/cuda_utils.h" #include namespace turbomind::attention { +int next_power_of_two(int v) +{ + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v++; + return v; +} + +template +__global__ void reduce_output(T* out, + const float* partial_ML, + float* partial_O, + const int* split_cnt_, + int max_split_cnt, + int query_num, + int head_num, + float exp_scale, + int stride_k, + int offset_k) +{ + __shared__ float s_out[WarpCnt][HeadDim]; + + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + + const int head_idx = ReduceCtaMap::head_idx(); + const int query_idx = ReduceCtaMap::query_idx(); + const int chunk_idx = ReduceCtaMap::split_idx(); + + offset_k *= chunk_idx; + const int split_cnt = (split_cnt_ != nullptr) ? split_cnt_[query_idx] : 1; + if (offset_k >= split_cnt) { // out of bound + return; + } + + // HeadDim / WARP_SIZE + // 128 -> 4 + // 64, 192 -> 2 + constexpr int kVecSize = HeadDim % 128 == 0 ? 
4 : 2; + + using Map = RakedThreadMap; + static_assert(Map::kIterS == 1); + + constexpr int C = Map::kIterC; + + using Vec = Array; + + Vec accu_O[C]{}; + Vec frag_O[C]; + + const int2 d = Map::get_offset(warp_id, lane_id); + + auto for_each = [&](auto fn) { + const int ki = d.y; + PRAGMA_UNROLL + for (int c = 0; c < C; ++c) { + const int di = d.x + c * Map::kDeltaC; + fn(c, ki, di); + } + }; + + PRAGMA_UNROLL + for (int k = 0; k < CTA_K; k += WarpCnt) { + for_each([&](int c, int ki, int di) { + using namespace ops; + ki += k; + const int split_idx = offset_k + stride_k * ki; + const bool mask = split_idx < split_cnt; + const int index = (query_idx * head_num + head_idx) * max_split_cnt + split_idx; + const int offset = index * HeadDim + di; + if (mask) { + Load(frag_O[c], &partial_O[offset]); + accu_O[c] = accu_O[c] + frag_O[c] * (First ? partial_ML[index * 2] : 1.0f); + } + }); + } + + for_each([&](int c, int ki, int di) { + Store(&s_out[ki][di], accu_O[c]); // + }); + + PRAGMA_UNROLL + for (int w = WarpCnt / 2; w > 0; w /= 2) { + __syncthreads(); + for_each([&](int c, int ki, int di) { + using namespace ops; + if (ki < w) { + (Vec&)s_out[ki][di] = (Vec&)s_out[ki][di] + (Vec&)s_out[w + ki][di]; + } + }); + } + + for_each([&](int c, int ki, int di) { + if (ki == 0) { + if (gridDim.z == 1) { + const int offset = (query_idx * head_num + head_idx) * HeadDim + di; + Store(&out[offset], cast((Vec&)s_out[ki][di])); + } + else { + const int offset = ((query_idx * head_num + head_idx) * max_split_cnt + offset_k) * HeadDim + di; + Store(&partial_O[offset], (Vec&)s_out[ki][di]); + } + } + }); +} + template -void invokeReduce(T* out, - float* partial_M, - float* partial_L, - float* partial_O, - const int* split_cnt, - int partial_len, - int max_split_cnt, - int query_num, - int head_num, - float exp_scale, - cudaStream_t stream) +void invokeReduceOutput(T* out, + const float* partial_ML, // scale + float* partial_O, + const int* split_cnt, + int partial_len, + int max_split_cnt, + int query_num, + int head_num, + float exp_scale, + cudaStream_t stream) { constexpr int CTA_K = 32; // warp size - using Reduce = attention::Reduce; - - static constexpr size_t kSmemSize = sizeof(typename Reduce::SharedStorage); - static_assert(kSmemSize < (48 << 10)); - - auto invoke = [&](auto is_final, int stride_k) { - const dim3 block = Reduce::kWarpCnt * 32; - const dim3 grid = ReduceCtaMap::get_grid_shape(query_num, head_num, max_split_cnt, CTA_K); - reduce_kernel<<>>(out, // - partial_M, - partial_L, - partial_O, - nullptr, - split_cnt, - partial_len, - head_num, - exp_scale, - stride_k); + auto invoke = [&](auto is_first, int stride_k) { + constexpr int kWarpCnt = 4; + const dim3 block = kWarpCnt * WARP_SIZE; + const dim3 grid = ReduceCtaMap::get_grid_shape(query_num, head_num, max_split_cnt, CTA_K); + + static constexpr size_t kSmemSize = sizeof(float) * kWarpCnt * HeadDim; + static_assert(kSmemSize < (48 << 10)); + + reduce_output<<>>( // + out, + partial_ML, + partial_O, + split_cnt, + partial_len, + query_num, + head_num, + exp_scale, + stride_k, + stride_k * CTA_K); + + sync_check_cuda_error(); }; int stride_k = 1; + invoke(std::true_type{}, stride_k); while (max_split_cnt > CTA_K) { - invoke(std::false_type{}, stride_k); max_split_cnt = (max_split_cnt + CTA_K - 1) / CTA_K; stride_k *= CTA_K; + invoke(std::false_type{}, stride_k); } +} - invoke(std::true_type{}, stride_k); +template +__global__ void reduce_ML(float* partial_ML, // cp, q, h, k, 2 + const int* split_cnt_, + int max_split_cnt, + int query_num, 
+ cutlass::FastDivmod head_num, + float exp_scale, + int cp_size, + int dim0) +{ + constexpr int kIterWarp = N / WARP_SIZE; + + float frag_M[kIterWarp]; + float frag_L[kIterWarp]; + + int qh = blockIdx.x * blockDim.y + threadIdx.y; + if (qh >= query_num * head_num) { + return; + } + + const int split_k = split_cnt_ != nullptr ? split_cnt_[head_num.div(qh)] : 1; + const int split_cnt = cp_size * split_k; + + float block_M = -std::numeric_limits::infinity(); + float block_L = 0.f; + + PRAGMA_UNROLL + for (int i = 0; i < kIterWarp; ++i) { + int ki = threadIdx.x + i * WARP_SIZE; + int index = (qh * max_split_cnt + ki) * 2; + bool mask = ki < split_cnt; + + if (mask && dim0 > 0) { // handle cp case + int cp_i = ki / split_k; + ki = ki % split_k; + index = cp_i * dim0 + (qh * max_split_cnt + ki) * 2; + } + + frag_M[i] = mask ? partial_ML[index] : -std::numeric_limits::infinity(); + frag_L[i] = mask ? partial_ML[index + 1] : 0.f; + block_M = max(block_M, frag_M[i]); + } + + PRAGMA_UNROLL + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + block_M = fmaxf(block_M, __shfl_xor_sync(uint32_t(-1), block_M, mask)); + } + + PRAGMA_UNROLL + for (int i = 0; i < kIterWarp; ++i) { + block_L += (frag_M[i] == -std::numeric_limits::infinity()) ? + 0.0f : + exp2f((frag_M[i] - block_M) * exp_scale) * frag_L[i]; + } + + PRAGMA_UNROLL + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + block_L += __shfl_xor_sync(uint32_t(-1), block_L, mask); + } + + PRAGMA_UNROLL + for (int i = 0; i < kIterWarp; ++i) { + int ki = threadIdx.x + i * WARP_SIZE; + int index = (qh * max_split_cnt + ki) * 2; + bool mask = ki < split_cnt; + + if (dim0 > 0) { // handle cp case + int cp_i = ki / split_k; + ki = ki % split_k; + index = cp_i * dim0 + (qh * max_split_cnt + ki) * 2; + } + + float scale = (frag_M[i] == -std::numeric_limits::infinity()) ? + 0.0f : + exp2f((frag_M[i] - block_M) * exp_scale) / block_L; + if (mask) { + partial_ML[index] = scale; // save scale to M + } + } +} + +void invokeReduceML(float* partial_ML, + const int* split_cnt, + int partial_len, + int max_split_cnt, + int cp_size, + int cp_rank, + int query_num, + int head_num, + float exp_scale, + cudaStream_t stream) +{ + max_split_cnt *= cp_size; + TM_CHECK(max_split_cnt > 1); + + const int warp_cnt = 4; + const dim3 block(WARP_SIZE, warp_cnt); + const dim3 grid((query_num * head_num + warp_cnt - 1) / warp_cnt); + + const int dim0 = cp_size > 1 ? 
query_num * head_num * partial_len * 2 : 0; + partial_ML -= cp_rank * dim0; // begin address of cp_rank0 + + int n = max(next_power_of_two(max_split_cnt), WARP_SIZE); + switch (n) { +#define LAUNCH_REDUCE_ML(n) \ + case n: \ + reduce_ML<<>>( \ + partial_ML, split_cnt, partial_len, query_num, cutlass::FastDivmod(head_num), exp_scale, cp_size, dim0); \ + break; + + LAUNCH_REDUCE_ML(32); + LAUNCH_REDUCE_ML(64); + LAUNCH_REDUCE_ML(128); + LAUNCH_REDUCE_ML(256); + LAUNCH_REDUCE_ML(512); + LAUNCH_REDUCE_ML(1024); + default: + TM_CHECK(false) << "reduce_ML does not support max_split_cnt = " << max_split_cnt; +#undef LAUNCH_REDUCE_ML + } + + sync_check_cuda_error(); +} + +template +void invokeReduceV2(T* out, + float* partial_ML, + float* partial_O, + const int* split_cnt, + int partial_len, + int max_split_cnt, + int cp_size, + int cp_rank, + int query_num, + int head_num, + float exp_scale, + cudaStream_t stream) +{ + invokeReduceML(partial_ML, // + split_cnt, + partial_len, + max_split_cnt, + cp_size, + cp_rank, + query_num, + head_num, + exp_scale, + stream); + + invokeReduceOutput(out, // + partial_ML, + partial_O, + split_cnt, + partial_len, + max_split_cnt, + query_num, + head_num, + exp_scale, + stream); } -#define INSTANTIATE_invokeReduce(dim, type) \ - template void invokeReduce(type * out, \ - float* partial_M, \ - float* partial_L, \ - float* partial_O, \ - const int* split_cnt, \ - int partial_len, \ - int max_split_cnt, \ - int query_num, \ - int head_num, \ - float exp_scale, \ - cudaStream_t stream); - -INSTANTIATE_invokeReduce(64, half); -INSTANTIATE_invokeReduce(128, half); -INSTANTIATE_invokeReduce(192, half); +#define INSTANTIATE_invokeReduceV2(dim, type) \ + template void invokeReduceV2(type * out, \ + float* partial_ML, \ + float* partial_O, \ + const int* split_cnt, \ + int partial_len, \ + int max_split_cnt, \ + int cp_size, \ + int cp_rank, \ + int query_num, \ + int head_num, \ + float exp_scale, \ + cudaStream_t stream); + +INSTANTIATE_invokeReduceV2(64, half); +INSTANTIATE_invokeReduceV2(128, half); +INSTANTIATE_invokeReduceV2(192, half); #if ENABLE_BF16 -INSTANTIATE_invokeReduce(64, nv_bfloat16); -INSTANTIATE_invokeReduce(128, nv_bfloat16); -INSTANTIATE_invokeReduce(192, nv_bfloat16); +INSTANTIATE_invokeReduceV2(64, nv_bfloat16); +INSTANTIATE_invokeReduceV2(128, nv_bfloat16); +INSTANTIATE_invokeReduceV2(192, nv_bfloat16); #endif } // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/reduce.h b/src/turbomind/kernels/attention/reduce.h index c078de5958..53f40163e8 100644 --- a/src/turbomind/kernels/attention/reduce.h +++ b/src/turbomind/kernels/attention/reduce.h @@ -12,16 +12,16 @@ namespace turbomind::attention { template -void invokeReduce(T* out, - float* partial_M, - float* partial_L, - float* partial_O, - const int* split_cnt, - int partial_len, - int max_split_cnt, - int query_num, - int head_num, - float exp_scale, - cudaStream_t stream); - +void invokeReduceV2(T* out, + float* partial_ML, + float* partial_O, + const int* split_cnt, + int partial_len, + int max_split_cnt, + int cp_size, + int cp_rank, + int query_num, + int head_num, + float exp_scale, + cudaStream_t stream); } // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/reduce_kernel.h b/src/turbomind/kernels/attention/reduce_kernel.h deleted file mode 100644 index b4c9064cfe..0000000000 --- a/src/turbomind/kernels/attention/reduce_kernel.h +++ /dev/null @@ -1,254 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. 
- -#include "src/turbomind/kernels/attention/cta_map.h" -#include "src/turbomind/kernels/core/array_ops.h" -#include "src/turbomind/kernels/core/thread_map.h" -#include - -namespace turbomind::attention { - -template -struct Reduce { - using T = T_; - - static constexpr int CTA_H = CTA_H_; - static constexpr int CTA_K = CTA_K_; - static constexpr int kWarpCnt = WarpCnt; - - static_assert((CTA_K & (CTA_K - 1)) == 0, "must be pow of 2"); - - struct SharedStorage { - float scale[CTA_H][CTA_K]; - float O[CTA_H][WarpCnt][HeadDim]; - }; - - template - __device__ void operator()(T* out, - float* partial_M, - float* partial_L, - float* partial_O, - int query_idx, - int head_idx, - int head_num, - int hi_end, - int split_cnt, - int max_split_cnt, - float exp_scale, - int stride_k, - int offset_k, - SharedStorage& storage, - std::integral_constant) - { - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - - // iterations per warp, K > 1 when CTA_K is multiple of WARP_SIZE - constexpr int K = (CTA_K + WARP_SIZE - 1) / WARP_SIZE; - // heads per warp iteration, M > 1 when WARP_SIZE is multiple of CTA_K - constexpr int M = (WARP_SIZE + CTA_K - 1) / CTA_K; - // lanes per head, a warp is processing M heads in parallel - constexpr int L = WARP_SIZE / M; - - PRAGMA_UNROLL - for (int h = 0; h < CTA_H; h += WarpCnt * M) { - - const int hi = h + warp_id * M + lane_id / L; - - Array frag_M; - Array frag_L; - - fill(frag_M, -std::numeric_limits::infinity()); - fill(frag_L, 0.f); - - PRAGMA_UNROLL - for (int k = 0; k < K; ++k) { - const int si = (lane_id % L + k * L) * stride_k + offset_k; - const int idx = (query_idx * head_num + head_idx + hi) * max_split_cnt + si; - const bool mask = hi < hi_end && si < split_cnt; - if (mask) { - frag_M[k] = partial_M[idx]; - frag_L[k] = partial_L[idx]; - } - } - - float block_M = frag_M[0]; - PRAGMA_UNROLL - for (int k = 1; k < K; ++k) { - block_M = fmaxf(block_M, frag_M[k]); - } - - PRAGMA_UNROLL - for (int mask = L / 2; mask >= 1; mask /= 2) { - block_M = fmaxf(block_M, __shfl_xor_sync(uint32_t(-1), block_M, mask)); - } - - Array expdiff_M; - PRAGMA_UNROLL - for (int k = 0; k < K; ++k) { - expdiff_M[k] = exp2f((frag_M[k] - block_M) * exp_scale); - } - - float block_L{}; - PRAGMA_UNROLL - for (int k = 0; k < K; ++k) { - block_L += expdiff_M[k] * frag_L[k]; - } - - PRAGMA_UNROLL - for (int mask = L / 2; mask >= 1; mask /= 2) { - block_L += __shfl_xor_sync(uint32_t(-1), block_L, mask); - } - - Array scale; - PRAGMA_UNROLL - for (int k = 0; k < K; ++k) { - scale[k] = IsFinal ? expdiff_M[k] / block_L : expdiff_M[k]; - } - - if (hi < CTA_H) { - PRAGMA_UNROLL - for (int k = 0; k < K; ++k) { - storage.scale[hi][lane_id % L + k * L] = scale[k]; - } - } - - if constexpr (!IsFinal) { - PRAGMA_UNROLL - for (int k = 0; k < K; ++k) { - const int si = (lane_id % L + k * L) * stride_k + offset_k; - const int idx = (query_idx * head_num + head_idx + hi) * max_split_cnt + si; - const bool mask = hi < hi_end && si < split_cnt; - if (mask) { - partial_M[idx] = block_M; - partial_L[idx] = block_L; - } - } - } - } - - __syncthreads(); - - // HeadDim / WARP_SIZE - // 128 -> 4 - // 64, 192 -> 2 - constexpr int kVecSize = HeadDim % 128 == 0 ? 
4 : 2; - - using Map = RakedThreadMap; - - static_assert(Map::kIterS == CTA_H); - - constexpr int S = Map::kIterS; - constexpr int C = Map::kIterC; - - using Vec = Array; - - Vec accu_O[S][C]{}; - Vec frag_O[S][C]; - - const int2 d = Map::get_offset(warp_id, lane_id); - - auto for_each = [&](auto fn) { - PRAGMA_UNROLL - for (int s = 0; s < S; ++s) { - const int si = d.y + s * Map::kDeltaS; - const int hi = si % CTA_H; - const int ki = si / CTA_H; - PRAGMA_UNROLL - for (int c = 0; c < C; ++c) { - const int di = d.x + c * Map::kDeltaC; - fn(s, c, ki, hi, di); - } - } - }; - - PRAGMA_UNROLL - for (int k = 0; k < CTA_K; k += WarpCnt) { - for_each([&](int s, int c, int ki, int hi, int di) { - using namespace ops; - ki += k; - const int split_idx = offset_k + stride_k * ki; - const bool mask = split_idx < split_cnt && hi < hi_end; - const int offset = ((query_idx * head_num + head_idx + hi) * max_split_cnt + split_idx) * HeadDim + di; - if (mask) { - Load(frag_O[s][c], &partial_O[offset]); - accu_O[s][c] = accu_O[s][c] + frag_O[s][c] * storage.scale[hi][ki]; - } - }); - } - - for_each([&](int s, int c, int ki, int hi, int di) { - Store(&storage.O[hi][ki][di], accu_O[s][c]); // - }); - - PRAGMA_UNROLL - for (int w = WarpCnt / 2; w > 0; w /= 2) { - __syncthreads(); - for_each([&](int s, int c, int ki, int hi, int di) { - using namespace ops; - if (ki < w) { - (Vec&)storage.O[hi][ki][di] = (Vec&)storage.O[hi][ki][di] + (Vec&)storage.O[hi][w + ki][di]; - } - }); - } - - for_each([&](int s, int c, int ki, int hi, int di) { - if (ki == 0 && hi < hi_end) { - if constexpr (IsFinal) { - const int offset = (query_idx * head_num + head_idx + hi) * HeadDim + di; - Store(&out[offset], cast((Vec&)storage.O[hi][ki][di])); - } - else { - const int offset = - ((query_idx * head_num + head_idx + hi) * max_split_cnt + offset_k) * HeadDim + di; - Store(&partial_O[offset], (Vec&)storage.O[hi][ki][di]); - } - } - }); - } -}; - -template -__global__ void reduce_kernel(typename Reduce::T* out, - float* partial_M, - float* partial_L, - float* partial_O, - int* signals, - const int* split_cnt_, - int max_split_cnt, - int head_num, - float exp_scale, - int stride_k) -{ - extern __shared__ char smem[]; - - const int head_idx = ReduceCtaMap::head_idx(); - const int query_idx = ReduceCtaMap::query_idx(); - const int chunk_idx = ReduceCtaMap::split_idx(); - - const int split_cnt = split_cnt_[query_idx]; - - const int chunk_offset = chunk_idx * stride_k * Reduce::CTA_K; - - if (chunk_offset >= split_cnt) { // out of bound - return; - } - - Reduce reduce{}; - reduce(out, - partial_M, - partial_L, - partial_O, - query_idx, - head_idx, - head_num, - 1, // hi_end - split_cnt, - max_split_cnt, - exp_scale, - stride_k, - chunk_offset, - *(typename Reduce::SharedStorage*)smem, - std::integral_constant{}); -} - -} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/test_attention.cu b/src/turbomind/kernels/attention/test_attention.cu index 3ab706c2df..f07fe273c5 100644 --- a/src/turbomind/kernels/attention/test_attention.cu +++ b/src/turbomind/kernels/attention/test_attention.cu @@ -152,7 +152,9 @@ void TestBlocks(const thrust::universal_vector& k_cache, // [B, H, S, seq_len, 1, block_seq_len, - 0, + 0, // layer_id + 0, // cp_rank + 1, // cp_size seq_len, head_num, head_dim, @@ -176,7 +178,9 @@ void TestBlocks(const thrust::universal_vector& k_cache, // [B, H, S, seq_len, 1, block_seq_len, - 0, + 0, // layer_id + 0, // cp_rank + 1, // cp_size seq_len, head_num, head_dim, @@ -313,19 +317,15 @@ int 
test_attention() thrust::universal_vector cu_seqlens(kBatchSize + 1); thrust::universal_vector cu_kv_lens(kBatchSize + 1); - thrust::device_vector partial_M(kTokenNum * kHeadNum * kMaxSplitK); - thrust::device_vector partial_L(kTokenNum * kHeadNum * kMaxSplitK); + thrust::device_vector partial_ML(kTokenNum * kHeadNum * kMaxSplitK * 2); thrust::device_vector partial_O(kTokenNum * kHeadNum * kMaxSplitK * kHeadDim); thrust::device_vector split_cnt(kTokenNum); - thrust::device_vector semaphores(kTokenNum * kHeadNum * kMaxSplitK); thrust::universal_vector qk_buf((size_t)kDump * kBatchSize * kHeadNum * kInputLen * kContextLen); thrust::universal_vector pr_buf((size_t)kDump * kBatchSize * kHeadNum * kInputLen * kContextLen); thrust::universal_vector sinks(kHeadNum); - thrust::fill(semaphores.begin(), semaphores.end(), 0); - rng.GenerateNormal(qkv.data().get(), qkv.size(), 1.f, 0.f); rng.GenerateNormal(k_cache.data().get(), kBatchSize * KvHeadNum * kContextLen * kHeadDim); @@ -443,11 +443,9 @@ int test_attention() float scale_factor = -std::log2f(kRoPEBase) / kRoPEDim; params.rope_param = RopeKernelParam{RopeType::kDefault, nullptr, kRoPEDim, scale_factor, 1.f}; - params.split_cnt = split_cnt.data().get(); - params.partial_L = partial_L.data().get(); - params.partial_M = partial_M.data().get(); - params.partial_O = partial_O.data().get(); - params.locks = semaphores.data().get(); + params.split_cnt = split_cnt.data().get(); + params.partial_ML = partial_ML.data().get(); + params.partial_O = partial_O.data().get(); params.max_split_k = kMaxSplitK; params.arch = getSMVersion(); @@ -565,7 +563,9 @@ int test_attention() kContextLen, 1, kBlockSz, - 0, + 0, // layer_id + 0, // cp_rank + 1, // cp_size kContextLen, KvHeadNum, kHeadDim, diff --git a/src/turbomind/models/llama/CMakeLists.txt b/src/turbomind/models/llama/CMakeLists.txt index 1b767d1a13..2186850712 100644 --- a/src/turbomind/models/llama/CMakeLists.txt +++ b/src/turbomind/models/llama/CMakeLists.txt @@ -21,7 +21,9 @@ add_library(Llama STATIC unified_attention_layer.cc llama_kernels.cu llama_utils.cu - mla_utils.cu) + mla_utils.cu + cp_utils.cu +) set_property(TARGET Llama PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(Llama PUBLIC CUDA::cudart @@ -29,6 +31,7 @@ target_link_libraries(Llama PUBLIC CUDA::cudart core gemm2 CUDA::cublas + nvidia::cutlass::cutlass rms_norm DynamicDecodeLayer activation_kernels diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index 3c33ea133d..70673511a5 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -218,7 +218,7 @@ void LlamaBatch::ProcessInferRequests(const Requests& reqs, std::vector& int idx = 0; for (const auto& r : reqs) { - if (tp_rank_ == 0) { + if (is_driver_) { TM_LOG_INFO("[ProcessInferRequests] Request for %llu received.", r->id); } @@ -246,7 +246,7 @@ void LlamaBatch::ProcessInferRequests(const Requests& reqs, std::vector& s = ptr->tokens.size(); } else if (s > ptr->tokens.size()) { - if (tp_rank_ == 0) { + if (is_driver_) { TM_LOG_WARNING("[ProcessInferRequests] Skipping invalid step (%d) setting for ID %lu", s, ptr->id); } s = ptr->tokens.size(); @@ -379,7 +379,7 @@ void LlamaBatch::ProcessInferRequests(const Requests& reqs, std::vector& // the actual sequence length is seq_limit_len + 1, hence seq_limit_len must truncated to session_len - 1 if (state.seq_len_limit[idx] >= session_len_) { state.seq_len_limit[idx] = 
session_len_ - 1; - if (tp_rank_ == 0) { + if (is_driver_) { const int trunc_output_len = state.seq_len_limit[idx] - state.h_context_length[idx]; TM_LOG_WARNING( "[ProcessInferRequests] [%ld] total sequence length (%d + %d) exceeds `session_len` (%d), `max_new_tokens` is truncated to %d", @@ -823,16 +823,27 @@ void LlamaBatch::AllocSymmBuffers() const ssize_t vocab_size_padded = model_->vocab_size_padded_; // Native comm fuses allreduce & rmsnorm in token granularity - TM_CHECK(max_forward_token_num_ % tp_size_ == 0); + TM_CHECK(max_forward_token_num_ % tp_size_ == 0) << max_forward_token_num_ << " vs " << tp_size_; symm_hidden_states_buf_ = {{max_forward_token_num_ * param_.attn_dp_size, hidden_units}, data_type_, symm_alloc_}; symm_logits_buf_ = {{max_batch_size_, vocab_size_padded}, data_type_, symm_alloc_}; + + // for context parallel, we use symm_alloc_ and both prefill and decode stage have reduce process + // w/o context parallel, we use common alloc and only decode stage has reduce process + // perhaps it would be more appropriate to put this buffer in the unified_attention_layer. + Allocator alloc = param_.attn_cp_size > 1 ? symm_alloc_ : core::Context::alloc(kDEVICE); + const ssize_t attn_ws_tokens = param_.attn_cp_size > 1 ? + UnifiedAttentionLayer::kMaxWorkspaceTokens + max_forward_token_num_ : + UnifiedAttentionLayer::kMaxWorkspaceTokens; + symm_partial_ML_ = {{param_.attn_cp_size, attn_ws_tokens, (int)model_->local_head_num_, 2}, alloc}; } void LlamaBatch::FreeSymmBuffers() { symm_hidden_states_buf_ = {}; symm_logits_buf_ = {}; + + symm_partial_ML_ = {}; } LlamaBatch::~LlamaBatch() @@ -870,6 +881,7 @@ LlamaBatch::LlamaBatch(DataType data_type, tp_rank_(model->tp_rank_), data_type_(data_type), debug_(isDebug()), + is_driver_(param.attn_tp_rank == 0 && param.attn_cp_rank == 0), stream_(ctx->stream), context_(std::move(ctx)), model_(std::move(model)), @@ -998,7 +1010,7 @@ void LlamaBatch::ComputeAndOutputLogits(const Tensor& hidden_states, int first, auto logits = model_->postDecodeEmbedding(hidden_states, symm_logits_buf_.buffer()); - if (tp_rank_ == 0) { + if (is_driver_) { OutputLogits(logits, first, last, GenerationConfig::kAll); } } @@ -1159,7 +1171,7 @@ void LlamaBatch::Finish(GenerationState& g, std::vector& signals) } // ! Only rank-0 writes to output - if (tp_rank_ == 0 && output_logprobs) { + if (is_driver_ && output_logprobs) { NvtxScope scope("logprobs"); float* sampled_logprobs_ptr = h_sampled_logprobs_.data(); uint32_t* sampled_indexes_ptr = h_sampled_indexes_.data(); @@ -1186,7 +1198,7 @@ void LlamaBatch::Finish(GenerationState& g, std::vector& signals) } // ! Only rank-0 writes to output - if (tp_rank_ == 0) { + if (is_driver_) { NvtxScope scope("output_ids"); for (int i = 0; i < batch_size - g.partial; ++i) { if (auto& r = state_->requests[i]) { @@ -1202,7 +1214,7 @@ void LlamaBatch::Finish(GenerationState& g, std::vector& signals) // Cache computed blocks to block trie sequence_manager_->CachePrompt(state_->sequences, batch_size); - if (debug_ && tp_rank_ == 0) { + if (debug_ && is_driver_) { for (int i = 0; i < batch_size; ++i) { // ss << (i ? 
", " : "") << "(" << state_->h_context_length[i] << "," << state_->h_finished[i] << ")"; std::vector tokens(state_->h_context_length[i]); @@ -1243,7 +1255,7 @@ void LlamaBatch::Finish(GenerationState& g, std::vector& signals) // Interrupt should reset r FT_CHECK(!r); } - else if (r->stream_output && tp_rank_ == 0) { + else if (r->stream_output && is_driver_) { const auto seq_len = *r->sequence_length.data(); // Create signals by copying the request handles for non-finished streaming requests signals.push_back([this, r, seq_len] { // @@ -1270,11 +1282,11 @@ void LlamaBatch::Finish(GenerationState& g, std::vector& signals) auto LlamaBatch::Interrupt(int index, bool force_stop) -> Signal { - if (tp_rank_ == 0) { + if (is_driver_) { TM_LOG_INFO("[Interrupt] slot %d, request %llu, stop %d", index, state_->requests[index]->id, force_stop); } - if (debug_ && tp_rank_ == 0) { + if (debug_ && is_driver_) { std::vector tokens(state_->h_context_length[index]); core::Copy(state_->output_ids.data() + index * session_len_, tokens.size(), tokens.data()); cudaStreamSynchronize(stream_); @@ -1350,7 +1362,7 @@ void LlamaBatch::InternalThreadEntry() std::shared_ptr req; - if (tp_rank_ == 0) { + if (is_driver_) { req = std::make_shared(); { NvtxScope _("pop"); @@ -1394,7 +1406,7 @@ void LlamaBatch::InternalThreadEntry() ProcessCancelRequests(req->cancel, signals); - if (tp_rank_ == 0) { + if (is_driver_) { gateway_->notify(std::move(signals)); } @@ -1418,7 +1430,7 @@ void LlamaBatch::InternalThreadEntry() comm_.h_tp_group->Sync(); } - if (tp_rank_ == 0) { + if (is_driver_) { gateway_->notify(std::move(signals)); } } @@ -1451,7 +1463,7 @@ bool LlamaBatch::Forward(GenerationState& g) const int active_size = state_->active_size; constexpr int kLogInterval = 10; - if (tp_rank_ == 0 && (g.step - 1) % kLogInterval == 0) { + if (is_driver_ && (g.step - 1) % kLogInterval == 0) { TM_LOG_INFO("------------------------- step = %d -------------------------", g.step - 1); } @@ -1531,7 +1543,7 @@ bool LlamaBatch::Forward(GenerationState& g) const int dc_batch_size = p ? 
@@ -1531,7 +1543,7 @@ bool LlamaBatch::Forward(GenerationState& g)
const int dc_batch_size = p ? 0 : pf_offset;
const int pf_batch_size = mini_batch_size - dc_batch_size;
- if (tp_rank_ == 0) {
+ if (is_driver_) {
if (pf_batch_size) {
const auto max_q = *std::max_element(h_input_length_buf_.data() + first, h_input_length_buf_.data() + last);
@@ -1572,6 +1584,7 @@ bool LlamaBatch::Forward(GenerationState& g)
state_->h_context_length.slice(first, mini_batch_size),
rope_theta_.slice(first, mini_batch_size),
&mrope,
+ symm_partial_ML_,
finished_buf_.slice(first, mini_batch_size),
Buffer(local_token_nums.data(), local_token_nums.size(), kCPU),
lora_mask_buf_,
@@ -1647,7 +1660,7 @@ bool LlamaBatch::Forward(GenerationState& g)
});
AnomalyHandler::instance().Reset();
- if (debug_ && tp_rank_ == 0) {
+ if (debug_ && is_driver_) {
std::vector curr(active_size);
core::Copy(token_ids_buf_.data() + g.step * active_size, active_size, curr.data());
cudaStreamSynchronize(stream_);
@@ -1704,7 +1717,7 @@ void LlamaBatch::Warmup()
if (auto str = std::getenv("TM_GEMM_IMPORT")) {
std::ifstream ifs(str);
const int n_imported = linear.Import(ifs);
- if (tp_rank_ == 0) {
+ if (is_driver_) {
TM_LOG_INFO("[Gemm2] %d records imported", n_imported);
}
return;
@@ -1722,7 +1735,7 @@ void LlamaBatch::Warmup()
bss.push_back(max_forward_token_num_);
}
- if (tp_rank_ == 0) {
+ if (is_driver_) {
auto str = Join(bss.begin(), bss.end(), ", ");
TM_LOG_INFO("[Gemm2] Tuning sequence: %s", str.c_str());
}
@@ -1745,7 +1758,7 @@ void LlamaBatch::Warmup()
/// NOTE: No explicit barrier can be used here as internal threads are waiting on it now
for (auto token_num : bss) {
- if (tp_rank_ == 0) {
+ if (is_driver_) {
TM_LOG_INFO("[Gemm2] %d", token_num);
}
@@ -1764,6 +1777,7 @@ void LlamaBatch::Warmup()
Buffer{&input_length, 1, kCPU},
rope_theta_.slice(0, bsz),
nullptr, // mrope
+ symm_partial_ML_,
finished_buf_.slice(0, bsz),
Buffer{local_token_nums.data(), (int)local_token_nums.size(), kCPU},
Buffer{},
@@ -1774,7 +1788,7 @@ void LlamaBatch::Warmup()
auto tock = std::chrono::steady_clock::now();
- if (tp_rank_ == 0) {
+ if (is_driver_) {
TM_LOG_INFO("[Gemm2] Tuning finished in %.2f seconds.",
std::chrono::duration>(tock - tick).count());
}
@@ -1784,7 +1798,7 @@ void LlamaBatch::Warmup()
check_cuda_error(cudaStreamSynchronize(stream_));
// Only rank-0 exports the dispatch cache
- if (tp_rank_ == 0) {
+ if (is_driver_) {
if (auto path = std::getenv("TM_GEMM_EXPORT")) {
std::ofstream ofs(path);
const auto n_records = context_->linear->Export(ofs);
@@ -1825,12 +1839,13 @@ void LlamaBatch::InitializeBufferAndKVCache()
param_.cache_chunk_size,
param_.enable_prefix_caching,
tp_rank_,
+ param_.attn_cp_size,
core::Context::alloc(kDEVICE),
get_free_size});
- const size_t max_session_len = sequence_manager_->max_block_count() * cache_block_seq_len;
+ const size_t max_session_len = sequence_manager_->max_block_count() * cache_block_seq_len * param_.attn_cp_size;
if (max_session_len < session_len_) {
- if (tp_rank_ == 0) {
+ if (is_driver_) {
TM_LOG_WARNING("No enough blocks for `session_len` (%d), `session_len` truncated to %d.",
session_len_,
max_session_len);
@@ -1915,7 +1930,7 @@ void LlamaBatch::DestroyCommunicators()
void LlamaBatch::UpdateMetrics()
{
- if (tp_rank_ == 0 && param_.enable_metrics) {
+ if (is_driver_ && param_.enable_metrics) {
// update schedule metrics
int total_seqs, active_seqs, cached_seqs;
std::tie(total_seqs, active_seqs, cached_seqs) = sequence_manager_->seq_stats();
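
The `InitializeBufferAndKVCache` hunk above is where context parallel changes the capacity math: the KV cache is interleaved across `attn_cp_size` ranks, so the same block budget per rank now covers `attn_cp_size` times more tokens, while `SequenceManager::CountRequiredBlocks` (further down in this diff) rounds each rank's share up. A rough, self-contained sketch of that arithmetic with made-up numbers, not values taken from the code:

```
// Illustrative only: per-rank KV block math when the cache is interleaved over cp ranks.
#include <cstdio>

int main()
{
    const int block_seq_len   = 64;   // tokens per KV block (hypothetical)
    const int max_block_count = 100;  // blocks available on one rank (hypothetical)
    const int cp_size         = 2;

    // Each rank stores every cp_size-th token, so the supported session length
    // grows by cp_size compared with the non-CP case.
    const int max_session_len = max_block_count * block_seq_len * cp_size;

    // Conversely, a sequence of seq_len tokens only needs ceil(seq_len / cp_size)
    // slots per rank, which is the rounding CountRequiredBlocks performs.
    const int seq_len         = 1000;
    const int per_rank_tokens = (seq_len + cp_size - 1) / cp_size;
    const int per_rank_blocks = (per_rank_tokens + block_seq_len - 1) / block_seq_len;

    std::printf("max_session_len=%d per_rank_blocks=%d\n", max_session_len, per_rank_blocks);
}
```
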
diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h
index cf604a0a4f..55386c9aff 100644
--- a/src/turbomind/models/llama/LlamaBatch.h
+++ b/src/turbomind/models/llama/LlamaBatch.h
@@ -220,6 +220,7 @@ class LlamaBatch {
const int tp_rank_;
const DataType data_type_;
const bool debug_;
+ const bool is_driver_;
// Refs into `Context`
cudaStream_t const stream_{};
@@ -244,6 +245,9 @@ class LlamaBatch {
Tensor symm_hidden_states_buf_;
Tensor symm_logits_buf_;
+ // context parallel
+ Tensor_ symm_partial_ML_;
+
Tensor decoder_output_buf_;
Tensor_ sampling_logits_;
diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc
index 68185eac38..583cdce47b 100644
--- a/src/turbomind/models/llama/LlamaV2.cc
+++ b/src/turbomind/models/llama/LlamaV2.cc
@@ -64,8 +64,8 @@ LlamaV2::LlamaV2(DataType dtype,
attn_param_(attn),
lora_param_(lora),
comm_(&ctx.comm),
- tp_size_(engine.attn_tp_size),
- tp_rank_(engine.attn_tp_rank),
+ tp_size_(engine.attn_tp_size * engine.attn_cp_size),
+ tp_rank_(engine.attn_tp_rank * engine.attn_cp_size + engine.attn_cp_rank),
head_num_(model.head_num),
size_per_head_(model.head_dim),
hidden_units_(model.hidden_units),
@@ -163,6 +163,7 @@ void LlamaV2::Forward(Buffer_ input_ids,
Buffer_ h_context_length,
Buffer rope_base,
MropeRope* mrope,
+ Tensor partial_ML,
Buffer finished,
Buffer local_token_nums,
Buffer lora_mask,
@@ -258,6 +259,7 @@ void LlamaV2::Forward(Buffer_ input_ids,
{"decode_num", Buffer{&decode_num, 1, kCPU}},
{"prefil_num", Buffer{&prefil_num, 1, kCPU}},
{"rope_base", rope_base},
+ {"partial_ML", partial_ML},
{"cu_block_nums", cu_block_nums},
{"kv_block_ptrs", kv_block_ptrs},
{"local_token_nums", local_token_nums}};
diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h
index 304fb97fd3..7d77db0812 100644
--- a/src/turbomind/models/llama/LlamaV2.h
+++ b/src/turbomind/models/llama/LlamaV2.h
@@ -69,6 +69,7 @@ class LlamaV2 {
Buffer_ h_context_length,
Buffer rope_base,
MropeRope* mrope,
+ Tensor partial_ML,
Buffer finished,
Buffer local_token_nums,
Buffer lora_mask,
diff --git a/src/turbomind/models/llama/LlamaWeight.cc b/src/turbomind/models/llama/LlamaWeight.cc
index f46124798a..58c03034ed 100644
--- a/src/turbomind/models/llama/LlamaWeight.cc
+++ b/src/turbomind/models/llama/LlamaWeight.cc
@@ -46,8 +46,8 @@ LlamaWeight::LlamaWeight(DataType data_type,
num_layer_(model.layer_num),
data_type_{data_type},
weight_type_{model.weight_type},
- tp_size_(engine_param.attn_tp_size),
- tp_rank_(engine_param.attn_tp_rank)
+ tp_size_(engine_param.attn_tp_size * engine_param.attn_cp_size),
+ tp_rank_(engine_param.attn_tp_rank * engine_param.attn_cp_size + engine_param.attn_cp_rank)
{
if (vocab_size_padded_ % tp_size_ != 0) {
vocab_size_padded_ = (vocab_size_ + tp_size_ - 1) / tp_size_ * tp_size_;
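
In the `LlamaV2` and `LlamaWeight` constructors above, the CP rank is folded into the tensor-parallel shard index, so weight sharding sees `attn_tp_size * attn_cp_size` shards with CP as the fastest-varying dimension. A standalone sketch of that mapping (the sizes are hypothetical, chosen only for illustration):

```
// Illustrative: folding (attn_tp_rank, attn_cp_rank) into a single shard index.
#include <cstdio>

int main()
{
    const int attn_tp_size = 4, attn_cp_size = 2;

    const int shard_count = attn_tp_size * attn_cp_size;  // plays the role of tp_size_ above
    for (int tp = 0; tp < attn_tp_size; ++tp) {
        for (int cp = 0; cp < attn_cp_size; ++cp) {
            const int shard = tp * attn_cp_size + cp;      // plays the role of tp_rank_ above
            std::printf("tp=%d cp=%d -> shard %d/%d\n", tp, cp, shard, shard_count);
        }
    }
}
```
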
diff --git a/src/turbomind/models/llama/SequenceManager.cc b/src/turbomind/models/llama/SequenceManager.cc
index 1342b003a7..fe74678343 100644
--- a/src/turbomind/models/llama/SequenceManager.cc
+++ b/src/turbomind/models/llama/SequenceManager.cc
@@ -34,9 +34,10 @@ SequenceManager::SequenceManager(size_t layer_num,
int chunk_size,
bool enable_prefix_caching,
int rank,
+ int attn_cp_size,
core::Allocator allocator,
GetFreeMemSize get_free_size):
- block_seq_len_(block_config.block_len_), rank_(rank)
+ block_seq_len_(block_config.block_len_), rank_(rank), attn_cp_size_(attn_cp_size)
{
block::Layout layout{block_config};
// dump(layout);
@@ -385,7 +386,7 @@ std::vector SequenceManager::CountRequiredBlocks(const Sequences& se
{
std::vector required(sequences.size());
for (int i = 0; i < sequences.size(); ++i) {
- int seq_len = context_lengths[i] + step_length;
+ int seq_len = (context_lengths[i] + step_length + attn_cp_size_ - 1) / attn_cp_size_;
int count = (seq_len + block_seq_len_ - 1) / block_seq_len_ - static_cast(sequences[i]->blocks.size());
required[i] = std::max(0, count);
}
diff --git a/src/turbomind/models/llama/SequenceManager.h b/src/turbomind/models/llama/SequenceManager.h
index 5cbdc4a426..a1c4f1615a 100644
--- a/src/turbomind/models/llama/SequenceManager.h
+++ b/src/turbomind/models/llama/SequenceManager.h
@@ -81,6 +81,7 @@ class SequenceManager {
int chunk_size,
bool enable_prefix_caching,
int rank,
+ int attn_cp_size,
core::Allocator allocator,
GetFreeMemSize get_free_size);
@@ -186,6 +187,7 @@ class SequenceManager {
private:
int block_seq_len_;
int rank_;
+ int attn_cp_size_;
// Use `std::map` to avoid reference invalidation
std::map sequences_;
diff --git a/src/turbomind/models/llama/context.h b/src/turbomind/models/llama/context.h
index 33b7be29ac..666803100d 100644
--- a/src/turbomind/models/llama/context.h
+++ b/src/turbomind/models/llama/context.h
@@ -22,6 +22,7 @@ struct Communicators {
comm::DeviceComm d_comm;
int d_tp_group;
+ int d_cp_group;
};
// Execution context for the model
diff --git a/src/turbomind/models/llama/cp_utils.cu b/src/turbomind/models/llama/cp_utils.cu
new file mode 100644
index 0000000000..6b56e7f10f
--- /dev/null
+++ b/src/turbomind/models/llama/cp_utils.cu
@@ -0,0 +1,20 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#include "src/turbomind/models/llama/cp_utils.h"
+
+namespace turbomind {
+
+void CpPost(void* context)
+{
+ auto ctx = reinterpret_cast(context);
+
+ ctx->d_comm->AllGather(ctx->partial_ML + ctx->cp_rank * ctx->count, //
+ ctx->partial_ML,
+ ctx->count,
+ DataType::kFloat,
+ ctx->attn_cp_group,
+ ctx->stream);
+ sync_check_cuda_error();
+}
+
+} // namespace turbomind
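
`CpPost` above only all-gathers the per-rank partial `(M, L)` pairs, i.e. the running maximum of the attention scores and the matching sum of exponentials over each rank's interleaved slice of the KV cache; combining them (and rescaling the partial outputs) happens inside the attention kernels and is not part of this diff. The combine step is the usual log-sum-exp merge, sketched here with plain floats purely for illustration:

```
// Illustrative log-sum-exp merge of per-rank softmax statistics (M = max score, L = sum of exp).
#include <cmath>
#include <cstdio>
#include <vector>

struct ML {
    float M;  // running max of attention scores seen by one cp rank
    float L;  // sum of exp(score - M) over that rank's KV slice
};

ML merge(const std::vector<ML>& parts)
{
    ML out{-INFINITY, 0.f};
    for (const ML& p : parts) {
        out.M = std::fmax(out.M, p.M);
    }
    for (const ML& p : parts) {
        out.L += p.L * std::exp(p.M - out.M);  // rescale each partial sum to the global max
    }
    return out;
}

int main()
{
    // Two cp ranks, each having attended to its interleaved half of the KV cache.
    const ML merged = merge({{2.0f, 3.5f}, {1.0f, 5.0f}});
    std::printf("M=%f L=%f\n", merged.M, merged.L);
}
```

The partial attention outputs gathered alongside would be rescaled by the same per-rank factors before being summed.
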
diff --git a/src/turbomind/models/llama/cp_utils.h b/src/turbomind/models/llama/cp_utils.h
new file mode 100644
index 0000000000..ae94112ada
--- /dev/null
+++ b/src/turbomind/models/llama/cp_utils.h
@@ -0,0 +1,23 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#include "src/turbomind/comm/device_comm.h"
+#include "src/turbomind/utils/cuda_utils.h"
+
+namespace turbomind {
+
+struct CpPostContext {
+
+ CpPostContext(comm::DeviceCommImpl* d_comm, int attn_cp_group): d_comm(d_comm), attn_cp_group(attn_cp_group) {}
+
+ comm::DeviceCommImpl* d_comm;
+ int attn_cp_group;
+
+ int cp_rank;
+ int count;
+ float* partial_ML;
+ cudaStream_t stream;
+};
+
+void CpPost(void* context);
+
+} // namespace turbomind
diff --git a/src/turbomind/models/llama/llama_params.h b/src/turbomind/models/llama/llama_params.h
index 57886e17a4..e3cdd973ea 100644
--- a/src/turbomind/models/llama/llama_params.h
+++ b/src/turbomind/models/llama/llama_params.h
@@ -105,6 +105,8 @@ struct EngineParam {
int attn_dp_rank;
int attn_tp_size;
int attn_tp_rank;
+ int attn_cp_size;
+ int attn_cp_rank;
int mlp_tp_size;
int mlp_tp_rank;
diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc
index 5808541001..c987f242e5 100644
--- a/src/turbomind/models/llama/unified_attention_layer.cc
+++ b/src/turbomind/models/llama/unified_attention_layer.cc
@@ -72,6 +72,8 @@ UnifiedAttentionLayer::UnifiedAttentionLayer(const ModelParam& model,
local_kv_head_num_(model.kv_head_num / tp_size),
param_(attn),
model_param_(model),
+ engine_param_(engine),
+ cp_fn_ctx_(ctx.comm.d_comm, ctx.comm.d_cp_group),
lora_param_(lora),
context_(ctx),
stream_(ctx.stream),
@@ -90,14 +92,17 @@ UnifiedAttentionLayer::UnifiedAttentionLayer(const ModelParam& model,
init_rope_kernel_param(param_.rope, rope_param_);
- partial_M_ = Tensor_({kMaxWorkspaceTokens, local_head_num_}, kDEVICE);
- partial_L_ = Tensor_({kMaxWorkspaceTokens, local_head_num_}, kDEVICE);
- partial_O_ = Tensor_({kMaxWorkspaceTokens, local_head_num_, size_per_head_}, kDEVICE);
+ // partial_O layout:
+ // w/ cp, decode(q, h, k, 2) + prefill(q, h, 1, 2)
+ // w/o cp, decode(q, h, k, 2)
+ const ssize_t attn_ws_tokens = engine_param_.attn_cp_size > 1 ?
+ kMaxWorkspaceTokens + engine_param_.max_forward_token_num : kMaxWorkspaceTokens;
+
+ partial_O_ = Tensor_({attn_ws_tokens, local_head_num_, size_per_head_}, kDEVICE);
split_cnt_ = Tensor_({kMaxWorkspaceTokens}, kDEVICE);
- barriers_ = Tensor_({kMaxWorkspaceTokens, local_head_num_}, kDEVICE);
Clear(split_cnt_.buffer());
- Clear(barriers_.buffer());
const auto max_batch_size = engine.max_batch_size;
@@ -136,6 +141,8 @@ void UnifiedAttentionLayer::Initialize(TensorMap& args)
cu_block_nums_ = args.at("cu_block_nums").buffer();
kv_block_ptrs_ = args.at("kv_block_ptrs").buffer();
+ partial_ML_ = args.at("partial_ML").borrow();
+
// rotary embedding, add offest when forward
if (rope_param_.type == RopeType::kDynamic) {
rope_param_.base = const_cast(rope_base_.data());
@@ -240,7 +247,7 @@ Tensor UnifiedAttentionLayer::core_attention(Tensor& qkv, const ForwardParam& p,
const int local_q_kv_head_num = local_head_num_ + 2 * local_kv_head_num_;
Tensor attn{{q_count, (int)local_head_num_ * (int)size_per_head_}, dtype, device};
- Tensor tmp_kv{{2, (int)local_kv_head_num_, k_count + MAX_CTA_S, (int)size_per_head_}, dtype, device};
+ Tensor tmp_kv{{(int)local_kv_head_num_, 2, k_count + MAX_CTA_S, (int)size_per_head_}, dtype, device};
auto stream_ptr = streams_.data();
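
The constructor hunks above size a single fused `(M, L)` workspace: without CP it only needs `kMaxWorkspaceTokens` rows for split-k decoding, while with CP it also reserves one row per prefill token, because prefill now ends with a cross-rank reduction as well. A back-of-the-envelope sizing sketch; the constants below are hypothetical, not TurboMind's actual values:

```
// Illustrative sizing of the fused (M, L) workspace; all numbers are hypothetical.
#include <cstdio>

int main()
{
    const long kMaxWorkspaceTokens   = 4096;  // decode split-k rows (hypothetical)
    const long max_forward_token_num = 8192;  // prefill tokens per iteration (hypothetical)
    const long local_head_num        = 8;
    const long cp_size               = 2;

    const bool cp             = cp_size > 1;
    const long attn_ws_tokens = cp ? kMaxWorkspaceTokens + max_forward_token_num : kMaxWorkspaceTokens;

    // The (cp_size, attn_ws_tokens, local_head_num, 2) buffer holds one (M, L) pair
    // per token row and head, with one slab per cp rank filled in by the all-gather.
    const long ml_floats = cp_size * attn_ws_tokens * local_head_num * 2;
    std::printf("attn_ws_tokens=%ld partial_ML=%.1f MiB\n", attn_ws_tokens, ml_floats * 4 / 1048576.0);
}
```
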
@@ -321,12 +328,35 @@ Tensor UnifiedAttentionLayer::core_attention(Tensor& qkv, const ForwardParam& p,
// Decoding use only for now
params.split_cnt = split_cnt_.data();
- params.partial_L = partial_L_.data();
- params.partial_M = partial_M_.data();
+ params.partial_ML = partial_ML_.data();
params.partial_O = partial_O_.data();
- params.locks = barriers_.data();
params.max_split_k = std::min(std::max(1, kMaxWorkspaceTokens / params.token_num), max_kv_splits);
+
+ // context parallel
+ params.cp_rank = engine_param_.attn_cp_rank;
+ params.cp_size = engine_param_.attn_cp_size;
+ if (params.cp_size > 1) {
+ params.cp_size = cutlass::FastDivmod(params.cp_size);
+
+ // update ML,O offset if both prefill and decode present
+ const int offset_ML_stage =
+ engine_param_.attn_cp_size * (offset ? kMaxWorkspaceTokens * local_head_num_ * 2 : 0);
+ const int offset_ML_rank = params.cp_rank * params.token_num * local_head_num_ * params.max_split_k * 2;
+ const int offset_O = offset ? kMaxWorkspaceTokens * local_head_num_ * size_per_head_ : 0;
+
+ params.partial_ML = partial_ML_.data() + offset_ML_stage + offset_ML_rank;
+ params.partial_O = partial_O_.data() + offset_O;
+ params.offset_q = offset;
+
+ // postprocess func
+ params.cp_fn = CpPost;
+ params.cp_fn_ctx = (void*)&cp_fn_ctx_;
+ cp_fn_ctx_.cp_rank = params.cp_rank;
+ cp_fn_ctx_.count = params.token_num * local_head_num_ * params.max_split_k * 2;
+ cp_fn_ctx_.partial_ML = partial_ML_.data() + offset_ML_stage;
+ cp_fn_ctx_.stream = stream;
+ }
+
params.arch = arch_;
params.stream = stream;
diff --git a/src/turbomind/models/llama/unified_attention_layer.h b/src/turbomind/models/llama/unified_attention_layer.h
index a498b3b881..06b0c02531 100644
--- a/src/turbomind/models/llama/unified_attention_layer.h
+++ b/src/turbomind/models/llama/unified_attention_layer.h
@@ -30,6 +30,7 @@
#include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/models/llama/LlamaLinear.h"
#include "src/turbomind/models/llama/context.h"
+#include "src/turbomind/models/llama/cp_utils.h"
#include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/utils/cuda_utils.h"
@@ -87,6 +88,7 @@ class UnifiedAttentionLayer {
const int local_kv_head_num_;
const AttentionParam param_;
+ const EngineParam engine_param_;
const ModelParam model_param_;
const LoraParam lora_param_;
const Context& context_;
@@ -110,11 +112,12 @@ class UnifiedAttentionLayer {
int decode_num_;
int prefil_num_;
- Tensor_ partial_M_;
- Tensor_ partial_L_;
+ Tensor_ partial_ML_;
Tensor_ partial_O_;
Tensor_ split_cnt_;
- Tensor_ barriers_; // always zero
+
+ // context parallel
+ CpPostContext cp_fn_ctx_;
Event event_;
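
The offset arithmetic in the hunk above packs decode and prefill partials into one buffer: the decode region occupies the first `kMaxWorkspaceTokens` rows of each rank's slab, the prefill region (selected when `offset` is non-zero) follows it, and within a region each CP rank writes its own contiguous slab that `CpPost` later stitches together. A standalone sketch of the same indexing, with hypothetical sizes:

```
// Illustrative layout math for the partial_ML offsets; all sizes are hypothetical.
#include <cstdio>

int main()
{
    const long kMaxWorkspaceTokens = 4096, local_head_num = 8, cp_size = 2;
    const long token_num = 256, max_split_k = 1;
    const long cp_rank   = 1;
    const long offset    = 16;  // first prefill query index; 0 means a decode-only batch

    // Stage offset: prefill partials live after the decode region of every cp rank.
    const long offset_ML_stage = cp_size * (offset ? kMaxWorkspaceTokens * local_head_num * 2 : 0);
    // Rank offset: each rank fills its own slab, later completed by the all-gather.
    const long offset_ML_rank  = cp_rank * token_num * local_head_num * max_split_k * 2;

    std::printf("stage=%ld rank=%ld total=%ld floats\n",
                offset_ML_stage, offset_ML_rank, offset_ML_stage + offset_ML_rank);
}
```
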
diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
index 853b0a96d8..56f9b2ddb9 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -372,17 +372,19 @@ LlamaTritonModel::LlamaTritonModel(std::string model_
engine_param_.attn_dp_rank = 0;
engine_param_.attn_tp_size = engine_reader["attn_tp_size"].as();
engine_param_.attn_tp_rank = 0;
+ engine_param_.attn_cp_size = engine_reader["attn_cp_size"].as();
+ engine_param_.attn_cp_rank = 0;
engine_param_.mlp_tp_size = engine_reader["mlp_tp_size"].as();
engine_param_.mlp_tp_rank = 0;
engine_param_.devices = engine_reader["devices"].as>();
{
- auto tp = engine_param_.attn_tp_size;
+ auto tp = engine_param_.attn_tp_size * engine_param_.attn_cp_size;
engine_param_.max_forward_token_num = ((size_t)max_forward_token_num + tp - 1) / tp * tp;
}
- comm_size_ = engine_param_.attn_dp_size * engine_param_.attn_tp_size;
+ comm_size_ = engine_param_.attn_dp_size * engine_param_.attn_tp_size * engine_param_.attn_cp_size;
FT_CHECK(engine_param_.mlp_tp_size == comm_size_);
communicator_ = engine_reader["communicator"].as();
@@ -437,13 +439,16 @@ LlamaTritonModel::LlamaTritonModel(std::string model_
}
const int device_num = engine_param_.outer_dp_size * comm_size_;
+ const int tp_cp_size = engine_param_.attn_tp_size * engine_param_.attn_cp_size;
+
+ // comm layout: outer_dp x inner(dp, tp, cp)
engine_params_.resize(device_num, engine_param_);
for (int i = 0; i < device_num; ++i) {
auto& e = engine_params_[i];
e.outer_dp_rank = i / comm_size_;
- e.attn_tp_rank = i % comm_size_ % e.attn_tp_size;
- e.attn_dp_rank = i % comm_size_ / e.attn_tp_size;
+ e.attn_cp_rank = i % comm_size_ % e.attn_cp_size;
+ e.attn_tp_rank = i % tp_cp_size / e.attn_cp_size;
+ e.attn_dp_rank = i % comm_size_ / tp_cp_size;
e.mlp_tp_rank = i % comm_size_;
}
@@ -494,17 +499,26 @@ Communicators LlamaTritonModel::createCommSplits(int rank)
const int outer_rank = rank / comm_size_;
const int inner_rank = rank % comm_size_;
+ const int tp_cp_size = engine_param_.attn_tp_size * engine_param_.attn_cp_size;
+ const int color_tp = inner_rank / tp_cp_size;
+ const int color_cp = inner_rank / engine_param_.attn_cp_size;
+ const int color_dp = inner_rank % tp_cp_size;
+
comm.h_comm = group_ids_[outer_rank]->CreateCommunicator(comm_size_, inner_rank);
- comm.h_tp_group = comm.h_comm->Split(inner_rank / engine_param_.attn_tp_size, 0);
- comm.h_dp_group = comm.h_comm->Split(inner_rank % engine_param_.attn_tp_size, 0);
+ comm.h_tp_group = comm.h_comm->Split(color_tp, 0);
+ comm.h_dp_group = comm.h_comm->Split(color_dp, 0);
if (comm_size_ > 1) {
comm.d_comm = CreateDeviceCommunicator(communicator_, comm_size_, inner_rank, comm.h_comm); //
comm.d_tp_group = 0;
- if (engine_param_.attn_tp_size != comm_size_) {
- comm.d_tp_group = comm.d_comm->Split(inner_rank / engine_param_.attn_tp_size, 0, 0);
+ comm.d_cp_group = 0;
+ if (engine_param_.attn_dp_size > 1) { // has attn_dp
+ comm.d_tp_group = comm.d_comm->Split(color_tp, 0, 0);
+ }
+ if (engine_param_.attn_cp_size > 1) { // has attn_cp
+ comm.d_cp_group = comm.d_comm->Split(color_cp, 0, 0);
+ }
}
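
Taken together, the last two hunks define the device layout as `outer_dp x inner(dp, tp, cp)` with CP fastest, and reuse that decomposition as the split colors for the host and device communicators. A standalone sketch that prints the decomposition and colors for a hypothetical `attn_dp=2, attn_tp=2, attn_cp=2` engine:

```
// Illustrative rank decomposition and communicator-split colors; the config is hypothetical.
#include <cstdio>

int main()
{
    const int attn_dp_size = 2, attn_tp_size = 2, attn_cp_size = 2;
    const int comm_size  = attn_dp_size * attn_tp_size * attn_cp_size;
    const int tp_cp_size = attn_tp_size * attn_cp_size;

    for (int r = 0; r < comm_size; ++r) {
        // Same decomposition as above: inner(dp, tp, cp), cp fastest-varying.
        const int cp_rank = r % attn_cp_size;
        const int tp_rank = r % tp_cp_size / attn_cp_size;
        const int dp_rank = r / tp_cp_size;

        // Split colors: ranks sharing a color end up in the same sub-communicator.
        const int color_tp = r / tp_cp_size;    // one tp(+cp) group per dp replica
        const int color_cp = r / attn_cp_size;  // cp peers: same dp and tp, different cp
        const int color_dp = r % tp_cp_size;    // dp peers: same (tp, cp) position

        std::printf("rank %d -> dp=%d tp=%d cp=%d | colors tp=%d cp=%d dp=%d\n",
                    r, dp_rank, tp_rank, cp_rank, color_tp, color_cp, color_dp);
    }
}
```
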