31 commits
c1dae3a
use driver flag
irexyc Sep 16, 2025
bb27b62
update
irexyc Sep 19, 2025
0fe88bc
accurate mask iter
irexyc Sep 22, 2025
5c02779
use fast divmod
irexyc Sep 22, 2025
53654ad
remove cp_O
irexyc Sep 22, 2025
e3dd4f7
remove unused
irexyc Sep 22, 2025
1f75dd6
return the last token's logprobs if include_stop_str_in_output is req…
lvhan028 Sep 22, 2025
be504d3
[Fix] device args in chat cli when using pytorch engine (#3999)
CyCle1024 Sep 22, 2025
25a8fb8
Merge remote-tracking branch 'origin/main' into cp2
irexyc Sep 23, 2025
77ef52a
fix NULL raw data
irexyc Sep 23, 2025
29cf813
add attn_cp_size to cli
irexyc Sep 24, 2025
0044d4f
build cutlass::FastDivmod on host
irexyc Sep 24, 2025
e4050a4
use single buffer
irexyc Sep 25, 2025
f44ef96
udpate comm
irexyc Sep 26, 2025
a329b29
use two stage reduce
irexyc Oct 24, 2025
dafcd64
Merge remote-tracking branch 'github/main' into cp2
irexyc Oct 24, 2025
c9649c0
remove unused
irexyc Oct 24, 2025
52766d2
better AllreduceResidualRMSnorm
irexyc Oct 28, 2025
b783d5c
fix max_session_len
irexyc Oct 29, 2025
c39373a
Merge remote-tracking branch 'github/main' into cp
irexyc Oct 29, 2025
47a349b
update docs
irexyc Oct 30, 2025
d83a2c7
fix embedding/lm_head split
irexyc Nov 3, 2025
c7e1e23
use same split_k on different cp_rank
irexyc Nov 4, 2025
8c5b289
always use seperate reduce for cp
irexyc Nov 5, 2025
4005547
add cp configuration parameter
irexyc Nov 5, 2025
1d2b098
remove redundant parameters
irexyc Nov 5, 2025
77920f8
remove redundant parameters
irexyc Nov 5, 2025
f54ca43
fix build
irexyc Nov 5, 2025
1ac3080
fix xgrammar build
irexyc Nov 5, 2025
7872225
update docs
irexyc Nov 5, 2025
0f82ef1
remove unused
irexyc Nov 5, 2025
5 changes: 5 additions & 0 deletions CMakeLists.txt
@@ -19,6 +19,7 @@ project(TurboMind LANGUAGES CXX CUDA)
if (MSVC)
# use standard conformant preprocessor
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/Zc:preprocessor>)
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/Zc:__cplusplus>)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=/Zc:preprocessor -Xcompiler=/Zc:__cplusplus")
endif ()

@@ -101,6 +102,10 @@ if(NOT xgrammar_POPULATED)

# Bring the populated content into the build
add_subdirectory(${xgrammar_SOURCE_DIR} ${xgrammar_BINARY_DIR})
if(TARGET xgrammar)
target_compile_options(xgrammar PRIVATE $<$<CXX_COMPILER_ID:MSVC>:/utf-8>)
target_compile_options(xgrammar PRIVATE $<$<C_COMPILER_ID:MSVC>:/utf-8>)
endif()
endif()

# the environment variable
2 changes: 2 additions & 0 deletions benchmark/profile_throughput.py
@@ -327,6 +327,7 @@ def parse_args():
tb_group._group_actions.append(dtype_act)

ArgumentHelper.dp(tb_group)
ArgumentHelper.cp(tb_group)
ArgumentHelper.model_format(tb_group, default='hf')
ArgumentHelper.num_tokens_per_iter(tb_group)
ArgumentHelper.max_prefill_iters(tb_group)
@@ -344,6 +345,7 @@ def main():
max_batch_size=args.concurrency // args.dp,
tp=args.tp,
dp=args.dp,
cp=args.cp,
cache_max_entry_count=args.cache_max_entry_count,
cache_block_seq_len=args.cache_block_seq_len,
model_format=args.model_format,
2 changes: 1 addition & 1 deletion builder/windows/generate.ps1
@@ -1,4 +1,4 @@
cmake .. -A x64 -T "v142,cuda=$env:CUDA_PATH" `
cmake .. -A x64 -T "v143,cuda=$env:CUDA_PATH" `
-DCMAKE_BUILD_TYPE=Release `
-DCMAKE_INSTALL_PREFIX=install `
-DBUILD_PY_FFI=ON `
24 changes: 24 additions & 0 deletions docs/en/advance/context_parallel.md
@@ -0,0 +1,24 @@
# Context Parallel

When the memory on a single GPU is insufficient to deploy a model, the model is often deployed with tensor parallelism (TP), which generally requires `num_key_value_heads` to be divisible by `TP`. To deploy with `TP > num_key_value_heads`, the kv-heads have to be duplicated to meet the divisibility requirement. However, this has two disadvantages:

1. The amount of available kv_cache is halved, which reduces the maximum supported session length.
2. The maximum inference batch size is reduced, leading to lower throughput.

To address this issue, the TurboMind inference backend supports setting `attn_dp_size`, which avoids duplicating the kv-heads but introduces data imbalance across ranks. To eliminate this imbalance, TurboMind also supports context parallelism, which allows the kv_cache to be stored interleaved on different cp_ranks. See the example below:

```
cp_size=2, prompt_len=5, generation_len=4
kv_cache stored on cp_rank0: 0, 2, 4, 6, 8
kv_cache stored on cp_rank1: 1, 3, 5, 7
```
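
The interleaving rule can be illustrated with a short sketch (an illustration only, assuming plain round-robin placement as in the example above, not the actual TurboMind kernels):

```python
# Sketch: with interleaved storage, the token at position i is cached on
# cp_rank = i % cp_size, so each rank holds at most ceil(seq_len / cp_size)
# kv_cache entries and the load stays balanced.
def tokens_on_rank(seq_len: int, cp_size: int, rank: int) -> list[int]:
    return [i for i in range(seq_len) if i % cp_size == rank]

seq_len = 5 + 4  # prompt_len + generation_len from the example above
for rank in range(2):
    print(f'kv_cache stored on cp_rank{rank}:', tokens_on_rank(seq_len, 2, rank))
# cp_rank0: [0, 2, 4, 6, 8]
# cp_rank1: [1, 3, 5, 7]
```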

## Usage

Taking Intern-S1 / Qwen3-235B-A22B as examples, their `num_key_value_heads` is 4. To deploy with `TP=8` while avoiding kv-head duplication, you can deploy as follows:

```
lmdeploy serve api_server internlm/Intern-S1 --tp 8 --cp 2

lmdeploy serve api_server Qwen/Qwen3-235B-A22B --tp 8 --cp 2
```
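
The same setting is available through the Python API; a minimal sketch, assuming the `cp` field this PR adds to `TurbomindEngineConfig` (adjust the model and sizes to your deployment):

```python
from lmdeploy import pipeline, TurbomindEngineConfig

# tp must be divisible by cp; with cp > 1 the kv_cache is stored interleaved
# across the cp ranks instead of duplicating kv-heads.
backend_config = TurbomindEngineConfig(tp=8, cp=2)
pipe = pipeline('internlm/Intern-S1', backend_config=backend_config)
print(pipe(['Introduce deep learning in one sentence.']))
```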
1 change: 1 addition & 0 deletions docs/en/index.rst
@@ -103,6 +103,7 @@ Documentation
advance/pytorch_multinodes.md
advance/pytorch_profiling.md
advance/metrics.md
advance/context_parallel.md

.. toctree::
:maxdepth: 1
24 changes: 24 additions & 0 deletions docs/zh_cn/advance/context_parallel.md
@@ -0,0 +1,24 @@
# 序列并行

在单卡显存不足以部署模型的时候,通常会以 `TP` 的方式进行部署,而这一般要求 `num_key_value_heads` 被 `TP` 整除。如果要以 `TP > num_key_value_heads` 的方式进行部署,需要创建 kv-heads 的副本,以满足整除需求。但是这样会有两个缺点:

1. 可用的 kv_cache 数量减半,进而减少请求的最大推理长度。
2. 降低推理的最大 batch 数量,减少吞吐量。

为了解决这个问题,TurboMind 推理后端支持设置 `attn_dp_size`,避免了创建 kv-heads 的副本,但这会引入数据的不均衡。为了消除这种不均衡,TurboMind 支持了序列并行,可以将 kv_cache 交错存储到不同的 cp_rank 上,例如:

```
cp_size=2, prompt_len=5, generation_len=4
kv_cache stored on cp_rank0: 0, 2, 4, 6, 8
kv_cache stored on cp_rank1: 1, 3, 5, 7
```

## 使用说明

以 `Intern-S1` / `Qwen3-235B-A22B` 为例,它们的 `num_key_value_heads` 为 4。若要以 `TP=8` 的方式部署,并避免 kv-heads 的拷贝,可以用如下的方式部署:

```
lmdeploy serve api_server internlm/Intern-S1 --tp 8 --cp 2

lmdeploy serve api_server Qwen/Qwen3-235B-A22B --tp 8 --cp 2
```
1 change: 1 addition & 0 deletions docs/zh_cn/index.rst
@@ -104,6 +104,7 @@ LMDeploy 工具箱提供以下核心功能:
advance/pytorch_multinodes.md
advance/pytorch_profiling.md
advance/metrics.md
advance/context_parallel.md

.. toctree::
:maxdepth: 1
1 change: 1 addition & 0 deletions lmdeploy/cli/cli.py
@@ -76,6 +76,7 @@ def add_parser_chat():
ArgumentHelper.model_format(tb_group)
ArgumentHelper.rope_scaling_factor(tb_group)
ArgumentHelper.communicator(tb_group)
ArgumentHelper.cp(tb_group)

@staticmethod
def add_parser_checkenv():
8 changes: 6 additions & 2 deletions lmdeploy/cli/serve.py
@@ -110,7 +110,7 @@ def add_parser_api_server():
model_format = ArgumentHelper.model_format(pt_group)
hf_overrides = ArgumentHelper.hf_overrides(pt_group)
enable_metrics = ArgumentHelper.enable_metrics(pt_group)
ArgumentHelper.dp(pt_group)
dp = ArgumentHelper.dp(pt_group)
ArgumentHelper.ep(pt_group)
ArgumentHelper.enable_microbatch(pt_group)
ArgumentHelper.enable_eplb(pt_group)
@@ -135,6 +135,8 @@ def add_parser_api_server():
tb_group._group_actions.append(model_format)
tb_group._group_actions.append(hf_overrides)
tb_group._group_actions.append(enable_metrics)
tb_group._group_actions.append(dp)
ArgumentHelper.cp(tb_group)
ArgumentHelper.rope_scaling_factor(tb_group)
ArgumentHelper.num_tokens_per_iter(tb_group)
ArgumentHelper.max_prefill_iters(tb_group)
@@ -232,6 +234,8 @@ def api_server(args):
from lmdeploy.messages import TurbomindEngineConfig
backend_config = TurbomindEngineConfig(dtype=args.dtype,
tp=args.tp,
dp=args.dp,
cp=args.cp,
max_batch_size=max_batch_size,
session_len=args.session_len,
model_format=args.model_format,
@@ -250,7 +254,7 @@ def api_server(args):

from lmdeploy.messages import VisionConfig
vision_config = VisionConfig(args.vision_max_batch_size)
if args.dp == 1:
if args.dp == 1 or backend == 'turbomind':
from lmdeploy.serve.openai.api_server import serve as run_api_server

run_api_server(args.model_path,
10 changes: 10 additions & 0 deletions lmdeploy/cli/utils.py
@@ -188,6 +188,16 @@ def ep(parser):
default=1,
help='expert parallelism. dp is required when pytorch engine is used.')

@staticmethod
def cp(parser):
"""Add argument cp to parser."""

return parser.add_argument(
'--cp',
type=int,
default=1,
help='context parallelism size in attention for turbomind backend. Should divide tp.')

@staticmethod
def dp_rank(parser):
"""Add argument dp_rank to parser."""
2 changes: 2 additions & 0 deletions lmdeploy/messages.py
@@ -235,8 +235,10 @@ class TurbomindEngineConfig:
model_format: Optional[str] = None
tp: int = 1
dp: int = 1
cp: int = 1
device_num: int = None
attn_tp_size: int = None
attn_cp_size: int = None
attn_dp_size: int = None
mlp_tp_size: int = None
mlp_dp_size: int = None
1 change: 1 addition & 0 deletions lmdeploy/turbomind/deploy/config.py
@@ -72,6 +72,7 @@ class ModelConfig:
expert_weight_type: str = None
session_len: int = None
attn_tp_size: int = 1
attn_cp_size: int = 1
mlp_tp_size: int = 1
model_format: str = 'hf'
expert_num: List[int] = ()
1 change: 1 addition & 0 deletions lmdeploy/turbomind/deploy/converter.py
@@ -179,6 +179,7 @@ def get_tm_model(model_path,
tm_cfg.model_config.model_name = model_name

tm_cfg.model_config.attn_tp_size = engine_config.attn_tp_size
tm_cfg.model_config.attn_cp_size = engine_config.attn_cp_size
tm_cfg.model_config.mlp_tp_size = engine_config.mlp_tp_size

output_model = OUTPUT_MODELS.get(output_model_name)(input_model=input_model,
3 changes: 1 addition & 2 deletions lmdeploy/turbomind/deploy/module.py
@@ -335,14 +335,13 @@ def pad_weight(tensor: torch.Tensor, tp: int):
return tensor
return torch.nn.functional.pad(tensor, (0, 0, 0, pad_size), 'constant', 0)

tp = self.model.attn_tp_size * self.model.attn_cp_size
if emb is not None:
tp = self.model.attn_tp_size
emb = pad_weight(emb, tp=tp)
self.model.save_split(emb, 'tok_embeddings.weight', split_dim=1, split_num=tp)
if norm_weight is not None:
self.model.export_weight(norm_weight, 'norm.weight')
if output_weight is not None:
tp = self.model.attn_tp_size
output_weight = pad_weight(output_weight, tp=tp)
# transpose
self.model.save_split(output_weight.t(), 'output.weight', split_dim=1, split_num=tp)
1 change: 1 addition & 0 deletions lmdeploy/turbomind/deploy/target_model/base.py
@@ -51,6 +51,7 @@ def __init__(self, input_model: BaseInputModel, cfg: TurbomindModelConfig, model
self.attention_config = cfg.attention_config
self.lora_config = cfg.lora_config
self.attn_tp_size = self.model_config.attn_tp_size
self.attn_cp_size = self.model_config.attn_cp_size
self.mlp_tp_size = self.model_config.mlp_tp_size
self.out_dir = out_dir
self.to_file = True if out_dir else False
8 changes: 5 additions & 3 deletions lmdeploy/turbomind/turbomind.py
@@ -100,11 +100,12 @@ def update_parallel_config(cfg: TurbomindEngineConfig):
inner_tp_size = cfg.tp // mlp_tp_size
cfg.outer_dp_size = cfg.dp // attn_dp_size
cfg.attn_dp_size = attn_dp_size
cfg.attn_tp_size = inner_tp_size
cfg.attn_tp_size = inner_tp_size // cfg.cp
cfg.attn_cp_size = cfg.cp
cfg.mlp_dp_size = 1
cfg.mlp_tp_size = mlp_tp_size * inner_tp_size
assert cfg.attn_dp_size * cfg.attn_tp_size == cfg.mlp_dp_size * cfg.mlp_tp_size
assert cfg.attn_dp_size * cfg.attn_tp_size * cfg.outer_dp_size == cfg.device_num
assert cfg.attn_dp_size * cfg.attn_tp_size * cfg.attn_cp_size == cfg.mlp_dp_size * cfg.mlp_tp_size
assert cfg.attn_dp_size * cfg.attn_tp_size * cfg.attn_cp_size * cfg.outer_dp_size == cfg.device_num
cfg.devices = cfg.devices or list(range(cfg.device_num))


@@ -272,6 +273,7 @@ def _from_hf(self, model_path: str, engine_config: TurbomindEngineConfig):

self._postprocess_config(tm_model.tm_config, engine_config)

print(yaml.safe_dump(self.config_dict))
model_comm = _tm.AbstractTransformerModel.create_llama_model(model_dir='',
config=yaml.safe_dump(self.config_dict),
weight_type=self.config.model_config.weight_type)
5 changes: 3 additions & 2 deletions src/turbomind/comm/nccl/nccl.cu
@@ -222,10 +222,11 @@ public:

int Split(int color, int key, int group) override
{
auto split_fn = TM_CHECK_NOTNULL(nccl_apis().ncclCommSplit);
// auto split_fn = TM_CHECK_NOTNULL(nccl_apis().ncclCommSplit);

ncclComm_t comm{};
NCCLCHECK(split_fn(groups_.at(group), color, key, &comm, nullptr));
// NCCLCHECK(split_fn(groups_.at(group), color, key, &comm, nullptr));
NCCLCHECK(ncclCommSplit(groups_.at(group), color, key, &comm, nullptr));

int index = groups_.size();
groups_.push_back(comm);
2 changes: 1 addition & 1 deletion src/turbomind/kernels/attention/CMakeLists.txt
@@ -45,7 +45,7 @@ set_property(TARGET attention PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_compile_options(attention PRIVATE -O3
$<$<COMPILE_LANGUAGE:CUDA>:-use_fast_math --expt-relaxed-constexpr>)

target_link_libraries(attention PRIVATE nvidia::cutlass::cutlass)

if (BUILD_TEST)
target_compile_options(attention PRIVATE
11 changes: 11 additions & 0 deletions src/turbomind/kernels/attention/attention_params.h
@@ -2,6 +2,7 @@

#pragma once

#include "cutlass/fast_math.h"
#include <cstdint>
#include <cuda_runtime.h>

@@ -23,6 +24,8 @@ struct BlockIteratorParams {
int block_len;
};

typedef void (*cp_post_fn)(void* context, int split_cnt);

/// TODO: Rename to attention::Param
template<typename T>
struct AttentionParams {
@@ -79,6 +82,14 @@ struct AttentionParams {
float* partial_L;
int* locks;

// context parallel
int cp_rank{0};
cutlass::FastDivmod cp_size{1};
int cp_q_offset{0}; // decode offset
float* cp_ML{nullptr}; // cp, q, h, k, 2
cp_post_fn cp_fn{nullptr};
void* cp_fn_ctx{nullptr};

int arch;
cudaStream_t stream;

8 changes: 6 additions & 2 deletions src/turbomind/kernels/attention/attention_template.h
@@ -45,7 +45,8 @@ void invokeAttention(const typename Kernel::ParamType& params)
return int2{sm_count, max_active_ctas};
}();

const int tile_count = cdiv(std::min(params.max_k_len, params.window_size), Kernel::CTA_S);
const int max_cp_k_len = (params.max_k_len + params.cp_size - 1) / params.cp_size;
const int tile_count = cdiv(std::min(max_cp_k_len, params.window_size), Kernel::CTA_S);
const int max_split_count = std::min(params.max_split_k, tile_count);

typename Kernel::CtaMap cta_map{
@@ -80,7 +81,10 @@ void invokeAttention(const typename Kernel::ParamType& params)
std::abort();
}

if (split_cnt > 1 && Kernel::need_separate_reduce(split_cnt)) {
if (params.cp_fn) {
params.cp_fn(params.cp_fn_ctx, split_cnt);
}
else if (split_cnt > 1 && Kernel::need_separate_reduce(split_cnt)) {
attention::invokeReduce<Kernel::kHeadDim>(params.out,
params.partial_M,
params.partial_L,