
Commit 281e101: Qwen3 next (#4039)

1 parent: 29db947

Commit message:

* WIP
* wip
* WIP
* first
* fix chat
* add env check
* add comment
* fix pad
* cudagraph
* init cache
* fix
* mem pool
* fix state allocate
* add skip warmup flag
* update conv state
File tree: 26 files changed, +1517 −27 lines

README.md (1 addition, 0 deletions)

@@ -130,6 +130,7 @@ LMDeploy is a toolkit for compressing, deploying, and serving LLM, developed by
   <li>Qwen2-MoE (57BA14B)</li>
   <li>Qwen2.5 (0.5B - 32B)</li>
   <li>Qwen3, Qwen3-MoE</li>
+  <li>Qwen3-Next(80B)</li>
   <li>Baichuan (7B)</li>
   <li>Baichuan2 (7B-13B)</li>
   <li>Code Llama (7B - 34B)</li>

README_ja.md (1 addition, 0 deletions)

@@ -117,6 +117,7 @@ LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざ
   <li>Qwen2-MoE (57BA14B)</li>
   <li>Qwen2.5 (0.5B - 32B)</li>
   <li>Qwen3, Qwen3-MoE</li>
+  <li>Qwen3-Next(80B)</li>
   <li>Baichuan (7B)</li>
   <li>Baichuan2 (7B-13B)</li>
   <li>Code Llama (7B - 34B)</li>

README_zh-CN.md (1 addition, 0 deletions)

@@ -131,6 +131,7 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力，在各种规模的模型
   <li>Qwen2-MoE (57BA14B)</li>
   <li>Qwen2.5 (0.5B - 32B)</li>
   <li>Qwen3, Qwen3-MoE</li>
+  <li>Qwen3-Next(80B)</li>
   <li>Baichuan (7B)</li>
   <li>Baichuan2 (7B-13B)</li>
   <li>Code Llama (7B - 34B)</li>

docs/en/supported_models/supported_models.md (1 addition, 0 deletions)

@@ -85,6 +85,7 @@ The following tables detail the models supported by LMDeploy's TurboMind engine
 | QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
 | Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
 | Qwen3 | 0.6B - 235B | LLM | Yes | Yes | Yes\* | - | Yes\* |
+| QWen3-Next | 80B | LLM | Yes | No | No | No | No |
 | QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | Yes |
 | QWen2.5-VL | 3B - 72B | MLLM | Yes | No | No | No | No |
 | DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No |

docs/zh_cn/supported_models/supported_models.md (1 addition, 0 deletions)

@@ -85,6 +85,7 @@
 | QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
 | Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
 | Qwen3 | 0.6B - 235B | LLM | Yes | Yes | Yes\* | - | Yes |
+| QWen3-Next | 80B | LLM | Yes | No | No | No | No |
 | QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | Yes |
 | QWen2.5-VL | 3B - 72B | MLLM | Yes | No | No | No | No |
 | DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No |

lmdeploy/pytorch/backends/cuda/graph_runner.py (21 additions, 7 deletions)

@@ -91,6 +91,16 @@ def __init__(
         self.pool = pool
         self._graph: torch.cuda.CUDAGraph = None
 
+    def make_output_buffers(self, output):
+        """Make output buffers."""
+        output_buffers = dict(logits=output)
+        return output_buffers
+
+    def slice_output(self, output_buffers: Dict[str, Any], inputs: Dict[str, Any]):
+        """Slice output."""
+        num_tokens = inputs['input_ids'].size(-1)
+        return output_buffers['logits'][:, :num_tokens]
+
     @record_function('capture_cudagraph')
     def capture(self, **kwargs):
         """Capture graph."""
@@ -102,29 +112,31 @@ def capture(self, **kwargs):
         current_stream = torch.cuda.current_stream()
 
         # warmup
-        self.model(**padded_kwargs)
+        warmup_output = self.model(**padded_kwargs)
+        warmup_buffers = self.make_output_buffers(warmup_output)
 
         self._graph = torch.cuda.CUDAGraph()
         # unsafe kernel call in another thread might invalidate the capture,
         # so we set thread_safe capture mode here.
         with torch.cuda.graph(self._graph, pool=self.pool, stream=current_stream, capture_error_mode='thread_local'):
             output = self.model(**padded_kwargs)
 
-        output_buffers = dict(logits=output)
+        output_buffers = self.make_output_buffers(output)
         self.meta.output_buffers = output_buffers
+        output = self.slice_output(warmup_buffers, kwargs)
         return output
 
     @record_function('forward_cudagraph')
     def forward(self, **kwargs):
         """forward."""
-        num_tokens = kwargs['input_ids'].size(-1)
         assert self._graph is not None
         self.model.fill_buffers_cudagraph(self.meta, **kwargs)
         context = self.ctx_mgr.current_context()
         self.model.update_context_cudagraph(self.meta, context)
         self._graph.replay()
 
-        output = self.meta.output_buffers['logits'][:, :num_tokens]
+        output_buffers = self.meta.output_buffers
+        output = self.slice_output(output_buffers, kwargs)
         return output
 
     def __del__(self):
@@ -223,12 +235,14 @@ def __call__(self, **kwargs):
                                            pool=self.graph_pool_handle,
                                            model_config=self.model_config,
                                            device=self.device)
-            runner.capture(**kwargs)
+            output = runner.capture(**kwargs)
             self._runner_map[graph_key] = runner
+            # SSM layers update their state during the capture warmup; replaying the graph here would lead to an unexpected second state update.
+            return output
         else:
             runner = self._runner_map[graph_key]
-        output = runner.forward(**kwargs)
-        return output
+            output = runner.forward(**kwargs)
+            return output
 
     @record_function('prepare_inputs_for_generation')
     def prepare_inputs_for_generation(
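The reason capture() now returns the warmup output is that stateful (SSM) layers already advance their recurrent state during the warmup pass, so replaying the freshly captured graph for the same request would apply that update twice. Below is a minimal sketch of this capture/replay pattern with static output buffers, using a hypothetical TinyGraphRunner and a stand-in model rather than the LMDeploy classes; it assumes a CUDA device and a model whose forward takes one padded input_ids tensor.

```python
# Minimal sketch of the capture/replay pattern above; TinyGraphRunner and the
# model are stand-ins, not LMDeploy code. Assumes a CUDA device.
import torch


class TinyGraphRunner:

    def __init__(self, model, max_tokens: int):
        self.model = model
        self.max_tokens = max_tokens
        self._graph = None
        self.input_buffer = torch.zeros(1, max_tokens, dtype=torch.long, device='cuda')
        self.output_buffers = None  # plays the role of make_output_buffers above

    def capture(self, input_ids: torch.Tensor):
        num_tokens = input_ids.size(-1)
        self.input_buffer.zero_()
        self.input_buffer[:, :num_tokens].copy_(input_ids)

        # warmup runs outside the graph; a stateful (SSM) layer already advances
        # its state here, so the caller reuses this output instead of replaying
        warmup_output = self.model(self.input_buffer)

        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph):
            output = self.model(self.input_buffer)
        self.output_buffers = {'logits': output}  # static buffer reused by every replay

        return warmup_output[:, :num_tokens]  # slice_output equivalent: drop the padded tail

    def forward(self, input_ids: torch.Tensor):
        num_tokens = input_ids.size(-1)
        self.input_buffer.zero_()
        self.input_buffer[:, :num_tokens].copy_(input_ids)
        self._graph.replay()  # writes into the captured static output tensor
        return self.output_buffers['logits'][:, :num_tokens]
```

The static output_buffers dict mirrors make_output_buffers above, and the [:, :num_tokens] slice mirrors slice_output, which trims the padding added for the fixed graph shape.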

lmdeploy/pytorch/check_env/model.py (7 additions, 1 deletion)

@@ -57,7 +57,13 @@ def check_dtype(self, config):
             if not is_bf16_supported(device_type):
                 logger.warning('Device does not support bfloat16.')
         except Exception as e:
-            message = (f'Checking failed with error {e}', 'Please send issue to LMDeploy with error logs.')
+            message = (f'Checking failed with error {e}. Please send issue to LMDeploy with error logs.')
+            self.log_and_exit(e, 'Model', message=message)
+
+        try:
+            model_config.check_env_func(device_type)
+        except Exception as e:
+            message = (f'Checking failed with error {e}.')
             self.log_and_exit(e, 'Model', message=message)
 
     def check(self):
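Besides wiring model-specific checks through model_config.check_env_func, this hunk fixes a subtle bug in the existing error path: the old message was a 2-tuple of strings (a comma inside parentheses), not a single string, so the log carried a tuple repr. A quick plain-Python illustration of the difference, not LMDeploy code:

```python
# Illustration of the message bug fixed above (plain Python, not LMDeploy code).
err = RuntimeError('boom')

# old: a comma inside parentheses builds a 2-tuple of strings, not one string
old_message = (f'Checking failed with error {err}', 'Please send issue to LMDeploy with error logs.')
# new: a single formatted string
new_message = f'Checking failed with error {err}. Please send issue to LMDeploy with error logs.'

print(type(old_message))  # <class 'tuple'>
print(new_message)        # Checking failed with error boom. Please send issue to LMDeploy with error logs.
```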

lmdeploy/pytorch/config.py (15 additions, 2 deletions)

@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import enum
-from dataclasses import dataclass
-from typing import Any, Dict, List, Literal
+from dataclasses import dataclass, field
+from typing import Any, Callable, Dict, List, Literal, Tuple
 
 import torch
 
@@ -86,6 +86,8 @@ class CacheConfig:
     enable_prefix_caching: bool = False
     quant_policy: Literal[0, 4, 8] = 0
     device_type: str = 'cuda'
+    num_state_caches: int = None
+    states_shapes: List[Tuple] = field(default_factory=list)
 
     # For PD Disaggregation
     role: EngineRole = EngineRole.Hybrid
@@ -183,6 +185,10 @@ def override_hf_config(hf_config: Any, hf_overrides: Dict[str, Any]):
         _override_hf_config(hf_config, k, v)
 
 
+def _default_check_env(device: str):
+    pass
+
+
 @dataclass
 class ModelConfig:
     """Config of model."""
@@ -208,6 +214,13 @@ class ModelConfig:
     dllm_mask_token: int = 0
     dllm_block_length: int = None
 
+    # added for qwen3_next
+    # could be used for any SSM model.
+    states_shapes: List[Tuple[Tuple[int], torch.dtype]] = field(default_factory=list)
+
+    # check env for the model-device combination
+    check_env_func: Callable = _default_check_env
+
     def get_head_size(self):
         """Get head size."""
         return self.head_dim
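ModelConfig.states_shapes records one (shape, dtype) pair per extra state tensor an SSM-style model needs, and CacheConfig mirrors it alongside num_state_caches so the engine knows how many state slots to reserve. As a rough sketch of how such fields could be consumed, here is a hypothetical allocator (not the engine's actual cache manager) that creates one buffer per registered state with a leading cache-slot dimension:

```python
# Hypothetical allocator sketch (not LMDeploy's cache manager) showing how
# states_shapes / num_state_caches could be consumed.
from typing import List, Tuple

import torch


def allocate_state_caches(states_shapes: List[Tuple[Tuple[int, ...], torch.dtype]],
                          num_state_caches: int,
                          device: str = 'cuda') -> List[torch.Tensor]:
    """Allocate one buffer per registered state, with a leading cache-slot dim."""
    caches = []
    for shape, dtype in states_shapes:
        caches.append(torch.zeros((num_state_caches, *shape), dtype=dtype, device=device))
    return caches
```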

lmdeploy/pytorch/configurations/default.py (2 additions, 0 deletions)

@@ -37,6 +37,8 @@ def build(cls, hf_config, model_path: str = None, **kwargs):
             eos_token_id=hf_config.eos_token_id,
             sliding_window=sliding_window,
             head_dim=head_dim,
+            k_head_dim=head_dim,
+            v_head_dim=head_dim,
             vocab_size=hf_config.vocab_size,
             llm_config=hf_config,
         )
New file: Qwen3-Next model config builder (58 additions, 0 deletions)

@@ -0,0 +1,58 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from .builder import AutoModelConfigBuilder
+from .default import DefaultModelConfigBuilder
+
+
+def _check_env_qwen3_next(device: str):
+    """Check env for qwen3 next."""
+    if device != 'cuda':
+        return
+
+    # check cuda
+    try:
+        import causal_conv1d  # noqa: F401
+    except ImportError:
+        raise ImportError('Qwen3-Next cuda support requires https://github.com/Dao-AILab/causal-conv1d.')
+
+    try:
+        import fla  # noqa: F401
+    except ImportError:
+        raise ImportError('Qwen3-Next cuda support requires https://github.com/fla-org/flash-linear-attention.')
+
+
+class Qwen3NextModelConfigBuilder(AutoModelConfigBuilder):
+
+    @classmethod
+    def condition(cls, hf_config):
+        """config."""
+        return hf_config.model_type == 'qwen3_next'
+
+    @classmethod
+    def build(cls, hf_config, model_path: str = None, tp: int = 1, **kwargs):
+        """build."""
+        cfg = DefaultModelConfigBuilder.build(hf_config, model_path, tp=tp, **kwargs)
+
+        # update num layers
+        num_layers = cfg.num_layers
+        num_full_layers = num_layers // hf_config.full_attention_interval
+        num_delta_layers = num_full_layers * (hf_config.full_attention_interval - 1)
+        cfg.num_layers = num_full_layers
+
+        # set state shapes
+        head_k_dim = hf_config.linear_key_head_dim
+        head_v_dim = hf_config.linear_value_head_dim
+        num_v_heads = hf_config.linear_num_value_heads // tp
+        num_k_heads = hf_config.linear_num_key_heads // tp
+        key_dim = head_k_dim * num_k_heads
+        value_dim = head_v_dim * num_v_heads
+        conv_dim = key_dim * 2 + value_dim
+        conv_kernel_size = hf_config.linear_conv_kernel_dim
+
+        conv_state_shape = (num_delta_layers, conv_dim, conv_kernel_size)
+        recurrent_state_shape = (num_delta_layers, num_v_heads, head_k_dim, head_v_dim)
+        dtype = torch.bfloat16
+        cfg.states_shapes = [(conv_state_shape, dtype), (recurrent_state_shape, dtype)]
+        cfg.check_env_func = _check_env_qwen3_next
+        return cfg
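To make the shape math concrete, here is a worked example under assumed Qwen3-Next-80B-A3B hyperparameters (48 layers, full_attention_interval=4, 16 linear key heads, 32 linear value heads, 128-dim linear heads, conv kernel 4, tp=1); the real values should be read from the model's config.json, so treat the numbers as illustrative only:

```python
# Worked example of the shape math above, under assumed Qwen3-Next-80B-A3B
# hyperparameters (read the real values from the model's config.json).
num_layers, full_attention_interval = 48, 4
num_full_layers = num_layers // full_attention_interval              # 12 full-attention layers
num_delta_layers = num_full_layers * (full_attention_interval - 1)   # 36 linear-attention layers

head_k_dim = head_v_dim = 128
num_k_heads, num_v_heads = 16, 32
key_dim = head_k_dim * num_k_heads    # 2048
value_dim = head_v_dim * num_v_heads  # 4096
conv_dim = key_dim * 2 + value_dim    # 8192

conv_state_shape = (num_delta_layers, conv_dim, 4)                               # (36, 8192, 4)
recurrent_state_shape = (num_delta_layers, num_v_heads, head_k_dim, head_v_dim)  # (36, 32, 128, 128)
```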
