Integrate pytorch poc python api #490
Open
HIT-cwh wants to merge 11 commits into open-compass:main from HIT-cwh:pytorch-poc
Changes from all commits (11, all by HIT-cwh):
abfbeff  WIP: add pytorch-poc
b3c1f2a  add cfgs
68574af  fix meta_template
eaa6326  fix meta_template
4138ad1  imporve
31fa5ef  fix path
004ed42  update
8b54e45  set stop_words to list
f984e00  fix according to comments
dff4ca5  add placeholder for convenient import
72cc8d9  delete useless files
New file (@@ -0,0 +1,41 @@): an evaluation config that runs internlm-chat-7b through the PyTorch PoC backend.
from mmengine.config import read_base
from opencompass.models import PytorchModel

with read_base():
    # choose a list of datasets
    from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
    from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
    from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
    from .datasets.SuperGLUE_WSC.SuperGLUE_WSC_gen_6dc406 import WSC_datasets
    from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
    from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
    from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
    from .datasets.race.race_gen_69ee4f import race_datasets
    from .datasets.crowspairs.crowspairs_gen_381af0 import crowspairs_datasets
    # and output the results in a chosen format
    from .summarizers.medium import summarizer

# collect every imported *_datasets list into one flat list
datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])

meta_template = dict(
    round=[
        dict(role='HUMAN', begin='<|User|>:', end='<eoh>\n'),
        dict(role='BOT', begin='<|Bot|>:', end='<eoa>\n', generate=True),
    ],
    eos_token_id=103028)

models = [
    dict(
        type=PytorchModel,
        abbr='internlm-chat-7b-pytorch-poc',
        path='internlm/internlm-chat-7b',
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        concurrency=8,
        meta_template=meta_template,
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]
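As an aside, the meta_template above is what OpenCompass uses to frame each conversation turn before the prompt reaches the model. A minimal sketch of the resulting prompt string, derived only from the begin/end markers in the config (the render_turn helper is illustrative, not an OpenCompass API):

# Illustrative only: how the HUMAN/BOT markers from meta_template
# frame a single-turn prompt. OpenCompass does this wrapping internally.
def render_turn(user_msg: str) -> str:
    human_begin, human_end = '<|User|>:', '<eoh>\n'
    bot_begin = '<|Bot|>:'  # the BOT role has generate=True, so it is left open
    return f'{human_begin}{user_msg}{human_end}{bot_begin}'

print(render_turn('What is 2 + 2?'))
# <|User|>:What is 2 + 2?<eoh>
# <|Bot|>: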
New file (@@ -0,0 +1,41 @@): the same evaluation config, pointed at a w8a8-quantized work dir.
from mmengine.config import read_base
from opencompass.models import PytorchModel

with read_base():
    # choose a list of datasets
    from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
    from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
    from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
    from .datasets.SuperGLUE_WSC.SuperGLUE_WSC_gen_6dc406 import WSC_datasets
    from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
    from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
    from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
    from .datasets.race.race_gen_69ee4f import race_datasets
    from .datasets.crowspairs.crowspairs_gen_381af0 import crowspairs_datasets
    # and output the results in a chosen format
    from .summarizers.medium import summarizer

# collect every imported *_datasets list into one flat list
datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])

meta_template = dict(
    round=[
        dict(role='HUMAN', begin='<|User|>:', end='<eoh>\n'),
        dict(role='BOT', begin='<|Bot|>:', end='<eoa>\n', generate=True),
    ],
    eos_token_id=103028)

models = [
    dict(
        type=PytorchModel,
        abbr='internlm-chat-7b-pytorch-poc-w8a8',
        path='/nvme/caoweihan/projects/lmdeploy/work_dir',  # coming soon
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        concurrency=8,
        meta_template=meta_template,
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]
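Either config should launch the same way as any other OpenCompass config; the filename below is hypothetical, since the diff does not show where these files are placed:

# Hypothetical filename; adjust to wherever the config actually lives.
python run.py configs/eval_internlm_chat_pytorch_poc.py -w outputs/pytorch_poc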
New file (@@ -0,0 +1,158 @@): the PytorchModel wrapper around the lmdeploy PyTorch PoC engine.
import random
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union

from transformers import AutoTokenizer

from opencompass.models.base import BaseModel
from opencompass.utils.logging import get_logger
from opencompass.utils.prompt import PromptList

try:
    from lmdeploy.pytorch_poc import engine as tm
    from lmdeploy.pytorch_poc.messages import SamplingParam
except ImportError:
    # defer the ImportError until lmdeploy is actually used
    from opencompass.utils import get_package_placeholder, get_placeholder
    tm = get_package_placeholder('lmdeploy')
    SamplingParam = get_placeholder('lmdeploy')

PromptType = Union[PromptList, str]


def valid_str(string, coding='utf-8'):
    """Strip invalid (replacement) characters from a decoded string."""
    invalid_chars = [b'\xef\xbf\xbd']  # UTF-8 encoding of U+FFFD
    bstr = bytes(string, coding)
    for invalid_char in invalid_chars:
        bstr = bstr.replace(invalid_char, b'')
    ret = bstr.decode(encoding=coding, errors='ignore')
    return ret


class PytorchModel(BaseModel):
    """Model wrapper for the LMDeploy PyTorch PoC Python API.

    Args:
        path (str): path of the model
        max_seq_len (int): The maximum allowed sequence length of a model.
            Note that the length of prompt + generated tokens shall not exceed
            this value. Defaults to 2048.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case meta instructions have to be
            injected or wrapped.
    """

    def __init__(
        self,
        path: str,
        concurrency: int = 8,
        max_seq_len: int = 2048,
        meta_template: Optional[Dict] = None,
    ):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         meta_template=meta_template)
        self.logger = get_logger()
        self.tokenizer = AutoTokenizer.from_pretrained(path,
                                                       trust_remote_code=True)
        tm_model = tm.Engine(path)
        # one engine instance per concurrent request slot
        self.generators = [
            tm_model.create_instance() for _ in range(concurrency)
        ]
        self.generator_ids = [i + 1 for i in range(concurrency)]

    def generate(
        self,
        inputs: List[str],
        max_out_len: int = 512,
        temperature: float = 1.0,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str]): A list of prompts.
            max_out_len (int): The maximum length of the output.
            temperature (float): What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic. Defaults to 1.0.

        Returns:
            List[str]: A list of generated strings.
        """
        assert isinstance(
            inputs, List), f'List[str] is expected, but got {type(inputs)}'

        # split inputs into batches no larger than the number of generators
        batch_size = len(self.generators)
        batch_inputs = [
            inputs[i:i + batch_size] for i in range(0, len(inputs), batch_size)
        ]

        results = []
        for batch_input in batch_inputs:
            with ThreadPoolExecutor() as executor:
                # run one prompt per generator instance, in parallel
                _results = list(
                    executor.map(self._generate,
                                 self.generators[:len(batch_input)],
                                 self.generator_ids[:len(batch_input)],
                                 batch_input, [max_out_len] * len(batch_input),
                                 [temperature] * len(batch_input)))
            results += _results
        return results

    def get_token_len(self, prompt: str) -> int:
        input_ids = self.tokenizer.encode(prompt)
        return len(input_ids)

    def wait(self):
        """Wait till the next query can be sent.

        Applicable in both single-thread and multi-thread environments.
        """
        return self.token_bucket.get_token()

    def _generate(self, generator, session_id, prompt: PromptType,
                  max_out_len: int, temperature: float) -> str:
        """Generate a completion for a single prompt.

        Args:
            prompt (PromptType): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
            temperature (float): What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic.

        Returns:
            str: The generated string.
        """
        assert type(
            prompt) is str, 'We only support string for the PyTorch PoC API'
        input_ids = self.tokenizer.encode(prompt)
        sampling_param = SamplingParam(top_k=40,
                                       top_p=0.8,
                                       temperature=temperature,
                                       repetition_penalty=1.0,
                                       ignore_eos=False,
                                       random_seed=random.getrandbits(64),
                                       stop_words=[self.eos_token_id])
        response_size = 0
        response_all = ''  # guard against an empty stream

        for outputs in generator.stream_infer(
                session_id=session_id,
                # input_ids=input_ids,
                prompt_token_ids=input_ids,
                request_output_len=max_out_len,
                step=0,
                sampling_param=sampling_param):
            status, res, tokens = outputs
            # decode everything generated so far and track the new suffix
            response_all = self.tokenizer.decode(res)
            response_cur = response_all[response_size:]
            response_all = valid_str(response_all)
            response_size += len(response_cur)
        if hasattr(generator, 'end'):
            generator.end(session_id)
        return response_all
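For orientation, a hedged usage sketch of the wrapper (the model path and prompts are placeholders, and lmdeploy's pytorch_poc module must be importable for tm.Engine to succeed):

# Sketch only: exercising PytorchModel outside of an OpenCompass run.
model = PytorchModel(
    path='internlm/internlm-chat-7b',  # placeholder model path
    concurrency=2,                     # two engine instances -> batches of two
    max_seq_len=2048,
)
outputs = model.generate(
    ['Hello, who are you?', 'Name three prime numbers.'],
    max_out_len=64,
    temperature=0.7,
)
for text in outputs:
    print(text)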
New file (@@ -0,0 +1,57 @@): placeholder helpers that back the deferred lmdeploy import in the wrapper above.
# Copyright (c) OpenMMLab. All rights reserved.
def get_placeholder(string: str) -> object:
    """Get a placeholder class that avoids raising errors at import time
    when a downstream dependency is not installed properly.

    Args:
        string (str): the dependency's name, e.g. `mmcls`

    Raises:
        ImportError: raised when the placeholder is instantiated, i.e. when
            the missing dependency is actually used.

    Returns:
        object: the PlaceHolder class.
    """

    def raise_import_error(package_name):
        raise ImportError(
            f'`{package_name}` is not installed properly, please check.')

    class PlaceHolder():

        def __init__(self) -> None:
            raise_import_error(string)

    return PlaceHolder


def get_package_placeholder(string: str) -> object:
    """Like get_placeholder, but the returned class also intercepts
    attribute access, so module-style usage such as `PlaceHolder.xxxx`
    raises a helpful ImportError instead of an AttributeError.

    Args:
        string (str): the dependency's name, e.g. `mmcls`

    Raises:
        ImportError: raised on instantiation or attribute access.

    Returns:
        object: the PlaceHolder class.
    """

    def raise_import_error(package_name):
        raise ImportError(
            f'`{package_name}` is not installed properly, please check.')

    class PlaceHolderMetaclass(type):
        """Used to support usage of the form PlaceHolder.xxxx."""

        def __getattr__(self, name):
            raise_import_error(string)

    class PlaceHolder(metaclass=PlaceHolderMetaclass):

        def __init__(self) -> None:
            raise_import_error(string)

    return PlaceHolder
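The effect of these helpers on the wrapper's deferred import is easy to see; the following sketch reuses only the names defined above:

# With lmdeploy absent, the wrapper falls back to placeholders:
tm = get_package_placeholder('lmdeploy')

# Importing the module succeeded, so class definitions that reference
# `tm` still load; the error surfaces only on first real use, via the
# metaclass __getattr__ hook:
try:
    tm.Engine('some/path')
except ImportError as err:
    print(err)  # `lmdeploy` is not installed properly, please check.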
Review discussion

Comment: How about PytorchTurbomindModel?

Reply: Sorry for the ambiguity. With Turbomind, lmdeploy integrates C++ with Python to carry out inference. Our PyTorch proof-of-concept, by contrast, takes a more streamlined approach and performs inference in Python alone, so the two backends are kept under separate names. To use Turbomind with OpenCompass, please refer to pr484.