
Commit 313a930

Add speculative decoding part (#3711)
* add speculative decoding part
* black formatting
* fix missing updates
* add dataset preparation for sharegpt
1 parent fef5a52 commit 313a930

File tree

6 files changed: +360 −47 lines changed
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+FROM lmsysorg/sglang:v0.5.2rc2-cu126
+ENV BASE_MODEL nvidia/Llama-3.1-8B-Instruct-FP8
+ENV DRAFT_MODEL lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B
+ENV SGLANG_ARGS "--tp-size 1 --max-running-requests 32 --mem-fraction-static 0.8 --enable-torch-compile --speculative-algorithm EAGLE3 --speculative-num-steps 3 --speculative-eagle-topk 2 --speculative-num-draft-tokens 4 --dtype float16 --attention-backend fa3 --host 0.0.0.0 --port 30000"
+ENV SGL_HOST 0.0.0.0
+ENV SGL_PORT 30000
+ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN 1
+
+EXPOSE 30000
+ENTRYPOINT python3 -m sglang.launch_server --model-path $BASE_MODEL --speculative-draft-model-path $DRAFT_MODEL $SGLANG_ARGS
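
The image above launches an SGLang server exposing an OpenAI-compatible API on port 30000. As a minimal sketch of a smoke test against a locally running container (the local URL, prompt, and payload here are assumptions for illustration, not part of the commit):

import requests

# Hypothetical smoke test: assumes the container built from the Dockerfile
# above is running locally with port 30000 published.
resp = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "nvidia/Llama-3.1-8B-Instruct-FP8",
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "max_tokens": 64,
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])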
(binary file, 55 KB — preview not shown)

sdk/python/foundation-models/system/reinforcement-learning/reinforcement-learning.ipynb

Lines changed: 23 additions & 12 deletions
@@ -122,16 +122,18 @@
    "source": [
    "import matplotlib.pyplot as plt\n",
    "from scripts.utils import setup_workspace\n",
-    "from scripts.dataset import prepare_finqa_dataset\n",
+    "from scripts.dataset import prepare_finqa_dataset, prepare_sharegpt_dataset\n",
    "from scripts.run import get_run_metrics\n",
    "from scripts.reinforcement_learning import run_rl_training_pipeline\n",
    "from scripts.evaluation import run_evaluation_pipeline\n",
    "from scripts.speculative_decoding import (\n",
    "    run_draft_model_pipeline,\n",
    "    prepare_combined_model_for_deployment,\n",
    "    deploy_speculative_decoding_endpoint,\n",
+    "    deploy_base_model_endpoint,\n",
+    "    run_evaluation_speculative_decoding,\n",
    ")\n",
-    "from scripts.deployment import create_managed_deployment, test_deployment"
+    "from scripts.deployment import test_deployment"
   ]
  },
  {
@@ -150,7 +152,7 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-    "<p>Prepare dataset for Finetuning. This would save train, test and valid dataset under data folder</p>"
+    "<p>Prepare dataset for Fine-tuning. This would save train, test and valid dataset under data folder</p>"
   ]
  },
  {
@@ -484,6 +486,15 @@
    "<p><strong>Reference:</strong> <a href=\"https://arxiv.org/abs/2503.01840\">https://arxiv.org/abs/2503.01840</a></p>\n"
   ]
  },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "draft_train_data_path = prepare_sharegpt_dataset()"
+  ]
+ },
  {
   "cell_type": "code",
   "execution_count": null,
@@ -498,7 +509,7 @@
    "    num_epochs=1,  # Number of train epochs to be run by draft trainer.\n",
    "    monitor=False,  # Set to True to wait for completion.\n",
    "    base_model_mlflow_path=\"azureml://registries/azureml-meta/models/Meta-Llama-3-8B-Instruct/versions/9\",\n",
-    "    draft_train_data_path=\"./data_for_draft_model/train/sharegpt_train_small.jsonl\",\n",
+    "    draft_train_data_path=draft_train_data_path,\n",
    ")"
   ]
  },
@@ -591,8 +602,7 @@
    "endpoint_name = deploy_speculative_decoding_endpoint(\n",
    "    ml_client=ml_client,  # ML Client which specifies the workspace where endpoint gets deployed.\n",
    "    combined_model=combined_model,  # Reference from previous steps where combined model is created.\n",
-    "    instance_type=\"octagepu\",  # Instance type Kubernetes Cluster\n",
-    "    compute_name=\"k8s-a100-compute\",\n",
+    "    instance_type=\"Standard_NC40ads_H100_v5\",  # Instance type\n",
    ")"
   ]
  },
@@ -631,10 +641,9 @@
   "outputs": [],
   "source": [
    "# Deploy managed online endpoint with base model\n",
-    "base_endpoint_name = create_managed_deployment(  # Function to create endpoint for base model.\n",
+    "base_endpoint_name = deploy_base_model_endpoint(  # Function to create endpoint for base model.\n",
    "    ml_client=ml_client,  # ML Client which specifies the workspace where endpoint gets deployed.\n",
-    "    model_asset_id=\"meta-llama/Meta-Llama-3-8B-Instruct\",  # Huggingface ID of the base model.\n",
-    "    instance_type=\"Standard_ND96amsr_A100_v4\",  # Compute SKU on which base model will be deployed.\n",
+    "    instance_type=\"Standard_NC40ads_H100_v5\",  # Compute SKU on which base model will be deployed.\n",
    ")"
   ]
  },
@@ -711,10 +720,12 @@
    "# Run evaluation job to compare base model and speculative decoding endpoints' performance\n",
    "evaluation_job = run_evaluation_speculative_decoding(\n",
    "    ml_client=ml_client,\n",
+    "    registry_ml_client=registry_ml_client,\n",
    "    base_endpoint_name=base_endpoint_name,  # Base model endpoint from previous step.\n",
    "    speculative_endpoint_name=endpoint_name,  # Speculative endpoint from previous step.\n",
-    "    base_model=\"meta-llama/Meta-Llama-3-8B-Instruct\",  # HuggingFace repo ID of the model used in base endpoint, used for tokenization.\n",
-    "    speculative_model=\"meta-llama/Meta-Llama-3-8B-Instruct\",  # HuggingFace repo ID of the model used in speculative decoding endpoint, used for tokenization.\n",
+    "    base_model_hf_id=\"meta-llama/Meta-Llama-3-8B-Instruct\",  # HuggingFace repo ID of the model used in base endpoint, used for tokenization.\n",
+    "    speculative_model_hf_id=\"meta-llama/Meta-Llama-3-8B-Instruct\",  # HuggingFace repo ID of the model used in speculative decoding endpoint, used for tokenization.\n",
+    "    compute_cluster=\"d13-v2\",\n",
    ")"
   ]
  },
@@ -735,7 +746,7 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-    "<img src=\"metrics-base-target-spec-dec.png\" alt=\"Performance Metrics: Base Model vs Speculative Decoding\" style=\"max-width: 100%; height: auto;\">"
+    "<img src=\"./images/metrics-base-target-spec-dec.png\" alt=\"Performance Metrics: Base Model vs Speculative Decoding\" style=\"max-width: 100%; height: auto;\">"
   ]
  }
 ],
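
The evaluation step above compares the base and speculative endpoints on serving performance. As a rough sketch of the kind of measurement such a comparison involves (endpoint URIs, keys, and the request payload here are assumptions; the actual job lives in scripts/speculative_decoding.py):

import time
import requests

def time_one_request(uri: str, api_key: str, prompt: str) -> float:
    """Wall-clock seconds for a single chat-completion call."""
    start = time.perf_counter()
    resp = requests.post(
        uri,
        headers={"Authorization": f"Bearer {api_key}"},
        json={"messages": [{"role": "user", "content": prompt}], "max_tokens": 256},
        timeout=300,
    )
    resp.raise_for_status()
    return time.perf_counter() - start

# base_uri/spec_uri and keys would come from the two endpoints deployed above, e.g.:
# print(time_one_request(base_uri, base_key, "Explain speculative decoding briefly."))
# print(time_one_request(spec_uri, spec_key, "Explain speculative decoding briefly."))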

sdk/python/foundation-models/system/reinforcement-learning/scripts/dataset.py

Lines changed: 104 additions & 0 deletions
@@ -6,6 +6,13 @@
 from azure.ai.ml import MLClient
 from azure.ai.ml.entities import Data
 from azure.ai.ml.constants import AssetTypes
+from typing import Optional
+from json import JSONDecodeError
+import requests
+from tqdm import tqdm
+
+
+SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
 
 
 def register_dataset(ml_client: MLClient, dataset_name: str, file_path: str):
@@ -164,3 +171,100 @@ def map_fn(example: pd.Series, idx: int, split: str):
         return train_data.id, test_data.id, valid_data.id
 
     return train_dataset_path, test_dataset_path, valid_dataset_path
+
+
+def _is_file_valid_json(path):
+    if not os.path.isfile(path):
+        return False
+
+    try:
+        with open(path) as f:
+            json.load(f)
+        return True
+    except JSONDecodeError as e:
+        print(
+            f"{path} exists but json loading fails ({e=}), thus treat as invalid file"
+        )
+        return False
+
+
+def _download_and_cache_file(url: str, filename: Optional[str] = None):
+    """Read and cache a file from a url."""
+    if filename is None:
+        filename = os.path.join("/tmp", url.split("/")[-1])
+
+    # Check if the cache file already exists
+    if _is_file_valid_json(filename):
+        return filename
+
+    print(f"Downloading from {url} to {filename}")
+
+    # Stream the response to show the progress bar
+    response = requests.get(url, stream=True)
+    response.raise_for_status()  # Check for request errors
+
+    # Total size of the file in bytes
+    total_size = int(response.headers.get("content-length", 0))
+    chunk_size = 1024  # Download in chunks of 1KB
+
+    # Use tqdm to display the progress bar
+    with open(filename, "wb") as f, tqdm(
+        desc=filename,
+        total=total_size,
+        unit="B",
+        unit_scale=True,
+        unit_divisor=1024,
+    ) as bar:
+        for chunk in response.iter_content(chunk_size=chunk_size):
+            f.write(chunk)
+            bar.update(len(chunk))
+
+    return filename
+
+
+def prepare_sharegpt_dataset(dataset_path="./data/draft_model/sharegpt_train_processed.jsonl") -> str:
+    """Prepare the ShareGPT dataset for training the draft model."""
+    # Download sharegpt if necessary
+    if not os.path.isfile(dataset_path):
+        temp_dataset_path = _download_and_cache_file(SHAREGPT_URL)
+
+        # Load the dataset.
+        with open(temp_dataset_path) as f:
+            temp_dataset = json.load(f)
+        # Filter out the conversations with less than 2 turns.
+        temp_dataset = [data for data in temp_dataset if len(data["conversations"]) >= 2]
+
+        # Keep one conversation in one list
+        new_dataset = []
+        for temp_data in temp_dataset:
+            if len(temp_data["conversations"]) % 2 != 0:
+                continue
+            if temp_data["conversations"][0]["from"] != "human":
+                continue
+
+            new_conversations = []
+
+            for i in range(0, len(temp_data["conversations"]), 2):
+                new_conversations.extend([
+                    {
+                        "role": "user",
+                        "content": temp_data["conversations"][i]["value"],
+                    },
+                    {
+                        "role": "assistant",
+                        "content": temp_data["conversations"][i + 1]["value"],
+                    }
+                ])
+
+            new_data = {}
+            new_data["id"] = temp_data.get("id", "")
+            new_data["conversations"] = new_conversations
+
+            new_dataset.append(new_data)
+
+        os.makedirs(os.path.dirname(dataset_path), exist_ok=True)
+        with open(dataset_path, "w") as f:
+            for item in new_dataset:
+                f.write(json.dumps(item) + "\n")
+
+    return dataset_path
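
For a quick sanity check of prepare_sharegpt_dataset's output, a sketch along these lines could inspect the first JSONL record (the assertions simply mirror the conversion logic above; none of this is part of the commit):

import json

path = prepare_sharegpt_dataset()  # ./data/draft_model/sharegpt_train_processed.jsonl by default
with open(path) as f:
    record = json.loads(f.readline())

# Each record carries an "id" and a "conversations" list of alternating
# user/assistant turns, as built by the conversion loop above.
assert record["conversations"][0]["role"] == "user"
assert len(record["conversations"]) % 2 == 0
print(record["id"], len(record["conversations"]))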

sdk/python/foundation-models/system/reinforcement-learning/scripts/deployment.py

Lines changed: 6 additions & 1 deletion
@@ -36,6 +36,7 @@ def create_managed_deployment(
     ml_client: MLClient,
     model_asset_id: str,  # Asset ID of the model to deploy
     instance_type: str,  # Supported instance type for managed deployment
+    model_mount_path: Optional[str] = None,
     environment_asset_id: Optional[str] = None,  # Asset ID of the serving engine to use
     endpoint_name: Optional[str] = None,
     endpoint_description: str = "Sample endpoint",
@@ -65,6 +66,7 @@
         name=deployment_name,
         endpoint_name=endpoint_name,
         model=model_asset_id,
+        model_mount_path=model_mount_path,
         instance_type=instance_type,
         instance_count=1,
         environment=environment_asset_id,
@@ -151,7 +153,10 @@ def test_deployment(ml_client, endpoint_name):
     """Run a test request against a deployed endpoint and print the result."""
     print("Testing endpoint...")
     # Retrieve endpoint URI and API key to authenticate test request
-    scoring_uri = ml_client.online_endpoints.get(endpoint_name).scoring_uri
+    scoring_uri = (
+        ml_client.online_endpoints.get(endpoint_name).scoring_uri.replace("/score", "/")
+        + "v1/chat/completions"
+    )
     if not scoring_uri:
         raise ValueError("Scoring URI not found for endpoint.")
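
After the rewrite above, scoring_uri points at an OpenAI-style /v1/chat/completions route instead of the managed endpoint's default /score route. A minimal sketch of the request test_deployment might then send (the payload shape is an assumption; get_keys is the standard azure-ai-ml call for key-based auth):

import requests

# Fetch the endpoint's auth key via the azure-ai-ml SDK.
keys = ml_client.online_endpoints.get_keys(name=endpoint_name)
response = requests.post(
    scoring_uri,  # ends in .../v1/chat/completions after the rewrite
    headers={
        "Authorization": f"Bearer {keys.primary_key}",
        "Content-Type": "application/json",
    },
    json={"messages": [{"role": "user", "content": "ping"}], "max_tokens": 16},
    timeout=120,
)
print(response.status_code, response.json())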
