-
Notifications
You must be signed in to change notification settings - Fork 832
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #863 from OptimalScale/yizhenjia-vllm-inferencer
[Feature] Add vllm inference example
- Loading branch information
Showing
10 changed files
with
211 additions
and
47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
#!/usr/bin/env python | ||
# coding=utf-8 | ||
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved. | ||
import logging | ||
import os | ||
import sys | ||
|
||
from transformers import ( | ||
HfArgumentParser | ||
) | ||
|
||
from lmflow.datasets import Dataset | ||
from lmflow.models.hf_decoder_model import HFDecoderModel | ||
from lmflow.pipeline.auto_pipeline import AutoPipeline | ||
from lmflow.args import ( | ||
ModelArguments, | ||
DatasetArguments, | ||
AutoArguments, | ||
) | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def main():
    """Run vLLM-based batch inference with a HuggingFace decoder model.

    Parses ``ModelArguments``, ``DatasetArguments`` and the
    vllm-inferencer ``PipelineArguments`` — either from a single JSON
    config file passed as ``argv[1]`` or from the command line — then
    builds the dataset, model and pipeline and runs inference.

    Returns:
        The result of ``inferencer.inference`` so programmatic callers
        can consume it; the CLI entry point ignores the return value.
        (The original bound this to an unused local ``res``.)
    """
    # The pipeline name selects both the argument dataclass below and
    # the concrete pipeline resolved by AutoPipeline.get_pipeline.
    pipeline_name = "vllm_inferencer"
    PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)

    parser = HfArgumentParser((
        ModelArguments,
        DatasetArguments,
        PipelineArguments,
    ))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # A single .json argument is treated as a config file that holds
        # all three argument groups.
        model_args, data_args, pipeline_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        model_args, data_args, pipeline_args = parser.parse_args_into_dataclasses()

    dataset = Dataset(data_args)
    model = HFDecoderModel(model_args)
    inferencer = AutoPipeline.get_pipeline(
        pipeline_name=pipeline_name,
        model_args=model_args,
        data_args=data_args,
        pipeline_args=pipeline_args,
    )

    # release_gpu=False keeps the model resident after inference;
    # decoding of raw inference results is controlled by the pipeline args.
    return inferencer.inference(
        model,
        dataset,
        release_gpu=False,
        enable_decode_inference_result=pipeline_args.enable_decode_inference_result,
    )


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
#!/bin/bash
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
#
# Launcher for examples/vllm_inference.py (vLLM-backed batch inference).
#
# Usage:
#   vllm_inference.sh [-r|--run_name NAME] [-m|--model_name_or_path PATH]
#                     [-d|--dataset_path PATH] [--output_dir DIR]
#                     [--output_file_name FILE] [--apply_chat_template True|False]
#                     [--trust_remote_code 0|1]
#
# All expansions are quoted so paths containing spaces work correctly.

# Default arguments
run_name=vllm_inference
model_name_or_path='Qwen/Qwen2-0.5B'
dataset_path=data/alpaca/test_conversation
output_dir=data/inference_results
output_file_name=results.json
apply_chat_template=True

# Safety related arguments
trust_remote_code=0

# Parse command-line options; each option consumes its value ("$2").
while [[ $# -ge 1 ]]; do
  key="$1"
  case "${key}" in
    -r|--run_name)
      run_name="$2"
      shift
      ;;
    -m|--model_name_or_path)
      model_name_or_path="$2"
      shift
      ;;
    -d|--dataset_path)
      dataset_path="$2"
      shift
      ;;
    --output_dir)
      output_dir="$2"
      shift
      ;;
    --output_file_name)
      output_file_name="$2"
      shift
      ;;
    --apply_chat_template)
      apply_chat_template="$2"
      shift
      ;;
    --trust_remote_code)
      trust_remote_code="$2"
      shift
      ;;
    *)
      echo "error: unknown option \"${key}\"" 1>&2
      exit 1
  esac
  shift
done

# Resolve the project root relative to this script, then prepare the
# run-specific output and log directories.
project_dir=$(cd "$(dirname "$0")"/..; pwd)
log_dir="${project_dir}/log/${run_name}"
output_file_path="${output_dir}/${run_name}/${output_file_name}"
mkdir -p "${output_dir}/${run_name}" "${log_dir}"

# Run inference; stdout+stderr are mirrored to the run's log file.
python examples/vllm_inference.py \
  --use_vllm True \
  --trust_remote_code "${trust_remote_code}" \
  --model_name_or_path "${model_name_or_path}" \
  --dataset_path "${dataset_path}" \
  --preprocessing_num_workers 16 \
  --random_seed 42 \
  --apply_chat_template "${apply_chat_template}" \
  --num_output_sequences 2 \
  --use_beam_search False \
  --temperature 1.0 \
  --top_p 0.9 \
  --max_new_tokens 1024 \
  --save_results True \
  --results_path "${output_file_path}" \
  --enable_decode_inference_result False \
  --vllm_gpu_memory_utilization 0.95 \
  --vllm_tensor_parallel_size 2 \
  2>&1 | tee "${log_dir}/vllm_inference.log"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.