11 changes: 6 additions & 5 deletions benchmark/benchmark_serving.py
@@ -13,14 +13,15 @@ def get_launching_server_cmd(model_path, backend, server_config):
elif backend == 'sglang':
cmd = ['python3', '-m', 'sglang.launch_server', '--model-path', model_path]
elif backend == 'vllm':
cmd = ['vllm', 'serve', '--model', model_path]
cmd = ['vllm', 'serve', model_path]
else:
raise ValueError(f'unknown backend: {backend}')
for key, value in server_config.items():
# Convert snake_case to kebab-case for command line args
key = key.replace('_', '-')
cmd.append(f'--{key}')
cmd.append(str(value))
if str(value):
cmd.append(str(value))
# Special handling for proxy server case
if server_config.get('proxy_url') and server_config.get('dp'):
cmd.append('--allow-terminate-by-client')
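
The loop above turns each `server_config` entry into a command-line flag, converting snake_case keys to kebab-case; with this change, an empty value now produces a bare flag instead of a flag followed by an empty argument. A small sketch of the resulting behavior, using hypothetical config values and a hypothetical command prefix:

```python
# Hypothetical illustration of the flag-building loop after this change.
server_config = {'tp': 2, 'enable_metrics': '', 'cache_max_entry_count': 0.8}

cmd = ['some-server', 'serve', 'model-path']
for key, value in server_config.items():
    key = key.replace('_', '-')   # snake_case -> kebab-case
    cmd.append(f'--{key}')
    if str(value):                # skip empty values so booleans become bare flags
        cmd.append(str(value))

print(cmd)
# ['some-server', 'serve', 'model-path',
#  '--tp', '2', '--enable-metrics', '--cache-max-entry-count', '0.8']
```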
@@ -66,9 +67,9 @@ def get_server_ip_port(backend: str, server_config: Dict) -> Tuple[str, int]:
server_ip = server_config.get('server_ip', '0.0.0.0')
server_port = server_config.get('server_port', 23333)
elif backend == 'sglang':
return (server_config.get('server_ip', '0.0.0.0'), server_config.get('server_port', 30000))
return (server_config.get('server_ip', '0.0.0.0'), server_config.get('port', 30000))
elif backend == 'vllm':
return (server_config.get('server_ip', '0.0.0.0'), server_config.get('server_port', 8000))
return (server_config.get('server_ip', '0.0.0.0'), server_config.get('port', 8000))
else:
raise ValueError(f'unknown backend: {backend}')
return server_ip, server_port
@@ -131,7 +132,7 @@ def benchmark(model_path: str, backend: str, server_config: Dict, data_config: D

try:

print(f"Starting api_server: {' '.join(server_cmd)}")
print(f"Starting api_server: {' '.join(server_cmd)}", flush=True)
proc = subprocess.Popen(server_cmd)
# Wait for the server to be ready
wait_server_ready(server_ip, server_port)
104 changes: 104 additions & 0 deletions docs/en/advance/spec_decoding.md
@@ -0,0 +1,104 @@
# Speculative Decoding

Speculative decoding is an optimization technique that introduces a lightweight draft model to propose several candidate next tokens; the main model then verifies them in a single forward pass and accepts the longest matching prefix. Compared with standard auto-regressive decoding, this method lets the system generate multiple tokens per decoding step.

> \[!NOTE\]
> This is an experimental feature in lmdeploy.
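
A rough sketch of a single decoding step is shown below. It is an illustrative simplification rather than lmdeploy's actual implementation; `draft_model` and `target_model` are hypothetical callables standing in for the two models.

```python
# Sketch of one speculative-decoding step with greedy verification.
# `draft_model(tokens)` returns the draft model's next token for a sequence;
# `target_model(tokens, draft)` returns the main model's prediction at every
# draft position plus one bonus position, computed in a single forward pass.
def speculative_step(tokens, draft_model, target_model, num_speculative_tokens=3):
    # 1. The lightweight draft model proposes tokens auto-regressively.
    draft, context = [], list(tokens)
    for _ in range(num_speculative_tokens):
        token = draft_model(context)
        draft.append(token)
        context.append(token)

    # 2. The main model verifies the whole draft in one forward pass.
    target = target_model(tokens, draft)  # len(draft) + 1 predictions

    # 3. Accept the longest matching prefix, then take the main model's token
    #    at the first mismatch (or the bonus token if everything matched),
    #    so every step yields at least one new token.
    accepted = []
    for i, token in enumerate(draft):
        if token != target[i]:
            break
        accepted.append(token)
    accepted.append(target[len(accepted)])
    return accepted
```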

## Examples

The examples below show how to enable it with Eagle 3 and DeepSeek MTP.

### Eagle 3

#### Prepare

Install [FlashAttention-3](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#flashattention-3-beta-release)

```shell
git clone --depth=1 https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```

#### Pipeline

```python
from lmdeploy import pipeline, PytorchEngineConfig
from lmdeploy.messages import SpeculativeConfig


if __name__ == '__main__':

    model_path = 'meta-llama/Llama-3.1-8B-Instruct'
    spec_cfg = SpeculativeConfig(method='eagle3',
                                 num_speculative_tokens=3,
                                 model='yuhuili/EAGLE3-LLaMA3.1-Instruct-8B')
    pipe = pipeline(model_path,
                    backend_config=PytorchEngineConfig(max_batch_size=128),
                    speculative_config=spec_cfg)
    response = pipe(['Hi, pls intro yourself', 'Shanghai is'])
    print(response)
```

#### Serving

```shell
lmdeploy serve api_server \
meta-llama/Llama-3.1-8B-Instruct \
--backend pytorch \
--server-port 24545 \
--speculative-draft-model yuhuili/EAGLE3-LLaMA3.1-Instruct-8B \
--speculative-algorithm eagle3 \
--speculative-num-draft-tokens 3 \
--max-batch-size 128 \
--enable-metrics
```
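
Once the server is running, speculative decoding is transparent to clients: requests go through the usual OpenAI-compatible endpoint. A minimal client sketch, assuming the `openai` Python package is installed and the port matches `--server-port` above:

```python
from openai import OpenAI

# Query the api_server through its OpenAI-compatible endpoint.
client = OpenAI(api_key='none', base_url='http://0.0.0.0:24545/v1')
model_name = client.models.list().data[0].id
response = client.chat.completions.create(
    model=model_name,
    messages=[{'role': 'user', 'content': 'Hi, pls intro yourself'}],
    temperature=0.8,
)
print(response.choices[0].message.content)
```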

### DeepSeek MTP

#### Prepare

Install [FlashMLA](https://github.com/deepseek-ai/FlashMLA?tab=readme-ov-file#installation)

```shell
git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla
cd flash-mla
git submodule update --init --recursive
pip install -v .
```

#### Pipeline

```python
from lmdeploy import pipeline, PytorchEngineConfig
from lmdeploy.messages import SpeculativeConfig


if __name__ == '__main__':

    model_path = 'deepseek-ai/DeepSeek-V3'
    spec_cfg = SpeculativeConfig(method='deepseek_mtp',
                                 num_speculative_tokens=3)
    pipe = pipeline(model_path,
                    backend_config=PytorchEngineConfig(tp=16, max_batch_size=128),
                    speculative_config=spec_cfg)
    response = pipe(['Hi, pls intro yourself', 'Shanghai is'])
    print(response)
```

#### Serving

```shell
lmdeploy serve api_server \
deepseek-ai/DeepSeek-V3 \
--backend pytorch \
--server-port 24545 \
--tp 16 \
--speculative-algorithm deepseek_mtp \
--speculative-num-draft-tokens 3 \
--max-batch-size 128 \
--enable-metrics
```
1 change: 1 addition & 0 deletions docs/en/index.rst
@@ -103,6 +103,7 @@ Documentation
advance/pytorch_multinodes.md
advance/pytorch_profiling.md
advance/metrics.md
advance/spec_decoding.md

.. toctree::
:maxdepth: 1
104 changes: 104 additions & 0 deletions docs/zh_cn/advance/spec_decoding.md
@@ -0,0 +1,104 @@
# Speculative Decoding

Speculative decoding is an optimization technique that introduces a lightweight draft model to predict several subsequent tokens; the main model then verifies them during its forward pass and selects the longest matching token sequence. Compared with standard auto-regressive decoding, this method lets the system generate multiple tokens at once.

> \[!NOTE\]
> Note that this is an experimental feature in lmdeploy.

## Examples

Refer to the following usage examples.

### Eagle 3

#### Install dependencies

Install [FlashAttention-3](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#flashattention-3-beta-release)

```shell
git clone --depth=1 https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```

#### Pipeline

```python
from lmdeploy import pipeline, PytorchEngineConfig
from lmdeploy.messages import SpeculativeConfig


if __name__ == '__main__':

    model_path = 'meta-llama/Llama-3.1-8B-Instruct'
    spec_cfg = SpeculativeConfig(method='eagle3',
                                 num_speculative_tokens=3,
                                 model='yuhuili/EAGLE3-LLaMA3.1-Instruct-8B')
    pipe = pipeline(model_path,
                    backend_config=PytorchEngineConfig(max_batch_size=128),
                    speculative_config=spec_cfg)
    response = pipe(['Hi, pls intro yourself', 'Shanghai is'])
    print(response)
```

#### Serving

```shell
lmdeploy serve api_server \
meta-llama/Llama-3.1-8B-Instruct \
--backend pytorch \
--server-port 24545 \
--speculative-draft-model yuhuili/EAGLE3-LLaMA3.1-Instruct-8B \
--speculative-algorithm eagle3 \
--speculative-num-draft-tokens 3 \
--max-batch-size 128 \
--enable-metrics
```

### DeepSeek MTP

#### Install dependencies

Install [FlashMLA](https://github.com/deepseek-ai/FlashMLA?tab=readme-ov-file#installation)

```shell
git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla
cd flash-mla
git submodule update --init --recursive
pip install -v .
```

#### Pipeline

```python
from lmdeploy import pipeline, PytorchEngineConfig
from lmdeploy.messages import SpeculativeConfig


if __name__ == '__main__':

    model_path = 'deepseek-ai/DeepSeek-V3'
    spec_cfg = SpeculativeConfig(method='deepseek_mtp',
                                 num_speculative_tokens=3)
    pipe = pipeline(model_path,
                    backend_config=PytorchEngineConfig(tp=16, max_batch_size=128),
                    speculative_config=spec_cfg)
    response = pipe(['Hi, pls intro yourself', 'Shanghai is'])
    print(response)
```

#### Serving

```shell
lmdeploy serve api_server \
deepseek-ai/DeepSeek-V3 \
--backend pytorch \
--server-port 24545 \
--tp 16 \
--speculative-algorithm deepseek_mtp \
--speculative-num-draft-tokens 3 \
--max-batch-size 128 \
--enable-metrics
```
1 change: 1 addition & 0 deletions docs/zh_cn/index.rst
@@ -104,6 +104,7 @@ LMDeploy 工具箱提供以下核心功能:
advance/pytorch_multinodes.md
advance/pytorch_profiling.md
advance/metrics.md
advance/spec_decoding.md

.. toctree::
:maxdepth: 1
10 changes: 9 additions & 1 deletion lmdeploy/api.py
@@ -3,7 +3,7 @@
from typing import List, Literal, Optional, Union

from .archs import autoget_backend_config, get_task
from .messages import PytorchEngineConfig, TurbomindEngineConfig
from .messages import PytorchEngineConfig, SpeculativeConfig, TurbomindEngineConfig
from .model import ChatTemplateConfig


@@ -12,6 +12,7 @@ def pipeline(model_path: str,
chat_template_config: Optional[ChatTemplateConfig] = None,
log_level: str = 'WARNING',
max_log_len: int = None,
speculative_config: SpeculativeConfig = None,
**kwargs):
"""
Args:
@@ -68,6 +69,12 @@ def pipeline(model_path: str,
if backend_config is not None else None
model_path = get_model(model_path, download_dir, revision)

# spec model
if speculative_config is not None and speculative_config.model and not os.path.exists(speculative_config.model):
download_dir = backend_config.download_dir \
if backend_config is not None else None
speculative_config.model = get_model(speculative_config.model, download_dir)

_, pipeline_class = get_task(model_path)
if not isinstance(backend_config, PytorchEngineConfig):
# set auto backend mode
@@ -80,6 +87,7 @@
backend_config=backend_config,
chat_template_config=chat_template_config,
max_log_len=max_log_len,
speculative_config=speculative_config,
**kwargs)


16 changes: 13 additions & 3 deletions lmdeploy/cli/cli.py
@@ -3,7 +3,8 @@
import os

from ..version import __version__
from .utils import ArgumentHelper, DefaultsAndTypesHelpFormatter, FlexibleArgumentParser, convert_args
from .utils import (ArgumentHelper, DefaultsAndTypesHelpFormatter, FlexibleArgumentParser, convert_args,
get_speculative_config)


class CLI(object):
@@ -44,12 +45,12 @@ def add_parser_chat():
', "baichuan-inc/baichuan2-7b-chat" and so on')
# common args
ArgumentHelper.backend(parser)
# # chat template args
# chat template args
ArgumentHelper.chat_template(parser)
# model args
ArgumentHelper.revision(parser)
ArgumentHelper.download_dir(parser)
#

# pytorch engine args
pt_group = parser.add_argument_group('PyTorch engine arguments')
ArgumentHelper.adapters(pt_group)
@@ -77,6 +78,9 @@ def add_parser_chat():
ArgumentHelper.rope_scaling_factor(tb_group)
ArgumentHelper.communicator(tb_group)

# speculative decoding
ArgumentHelper.add_spec_group(parser)

@staticmethod
def add_parser_checkenv():
"""Add parser for check_env command."""
@@ -168,7 +172,13 @@ def get_gpu_topo():
@staticmethod
def chat(args):
from .chat import main

kwargs = convert_args(args)
speculative_config = get_speculative_config(args)
to_remove = ['speculative_algorithm', 'speculative_draft_model', 'speculative_num_draft_tokens']
for key in to_remove:
kwargs.pop(key)
kwargs['speculative_config'] = speculative_config
main(**kwargs)

@staticmethod