diff --git a/evals/benchmark/stresscli/locust/aistress.py b/evals/benchmark/stresscli/locust/aistress.py
index 29e4503f..e1c814f3 100644
--- a/evals/benchmark/stresscli/locust/aistress.py
+++ b/evals/benchmark/stresscli/locust/aistress.py
@@ -10,6 +10,7 @@
 import gevent
 import sseclient
+import transformers
 from locust import HttpUser, between, events, task
 from locust.runners import STATE_CLEANUP, STATE_STOPPED, STATE_STOPPING, MasterRunner, WorkerRunner
 
 
@@ -84,6 +85,8 @@ def _(parser):
 
 bench_package = ""
 console_logger = logging.getLogger("locust.stats_logger")
+LLM_MODEL = os.getenv("MODEL_NAME", "meta-llama/Meta-Llama-3-8B-Instruct")
+tokenizer = transformers.AutoTokenizer.from_pretrained(LLM_MODEL)
 
 
 class AiStressUser(HttpUser):
@@ -92,6 +95,7 @@ class AiStressUser(HttpUser):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.environment.tokenizer = tokenizer
 
     @task
     def bench_main(self):
diff --git a/evals/benchmark/stresscli/locust/tokenresponse.py b/evals/benchmark/stresscli/locust/tokenresponse.py
index a0336dec..5ec8a846 100644
--- a/evals/benchmark/stresscli/locust/tokenresponse.py
+++ b/evals/benchmark/stresscli/locust/tokenresponse.py
@@ -14,7 +14,11 @@ def testFunc():
 
 
 def respStatics(environment, req, resp):
-    tokenizer = transformers.AutoTokenizer.from_pretrained(environment.parsed_options.llm_model)
+    if not hasattr(environment, "tokenizer"):
+        tokenizer = transformers.AutoTokenizer.from_pretrained(environment.parsed_options.llm_model)
+    else:
+        tokenizer = environment.tokenizer
+
     if environment.parsed_options.bench_target in [
         "chatqnafixed",
         "chatqnabench",