diff --git a/README.md b/README.md
index c635c73b0..dde466980 100644
--- a/README.md
+++ b/README.md
@@ -15,7 +15,23 @@ Please see the [MLPerf Inference benchmark paper](https://arxiv.org/abs/1911.025
```
## MLPerf Inference v4.0 (submission deadline February 23, 2024)
-Code freeze coming soon...
+An extra one-week extension is allowed only for llama2-70b submissions. For submissions, please use the master branch and any commit since the [4.0 seed release](https://github.com/mlcommons/inference/commit/8e36925bd36a503e39fcbbc488e9e46126f079ed), although it is best to use the latest commit. The v4.0 tag will be created from the master branch after the results are published.
+
+For power submissions, please use [SPEC PTD 1.10](https://github.com/mlcommons/power/tree/main/inference_v1.0) (needs special access) and any commit of the power-dev repository after the [code-freeze](https://github.com/mlcommons/power-dev/commit/4e026f43481f46ad57d2464d28924018444b0428).
+
+| model | reference app | framework | dataset | category |
+| ---- | ---- | ---- | ---- | ---- |
+| resnet50-v1.5 | [vision/classification_and_detection](https://github.com/mlcommons/inference/tree/master/vision/classification_and_detection) | tensorflow, onnx, tvm, ncnn | imagenet2012 | edge,datacenter |
+| retinanet 800x800 | [vision/classification_and_detection](https://github.com/mlcommons/inference/tree/master/vision/classification_and_detection) | pytorch, onnx | openimages resized to 800x800 | edge,datacenter |
+| bert | [language/bert](https://github.com/mlcommons/inference/tree/master/language/bert) | tensorflow, pytorch, onnx | squad-1.1 | edge,datacenter |
+| dlrm-v2 | [recommendation/dlrm_v2](https://github.com/mlcommons/inference/tree/master/recommendation/dlrm_v2/pytorch) | pytorch | Multihot Criteo Terabyte | datacenter |
+| 3d-unet | [vision/medical_imaging/3d-unet-kits19](https://github.com/mlcommons/inference/tree/master/vision/medical_imaging/3d-unet-kits19) | pytorch, tensorflow, onnx | KiTS19 | edge,datacenter |
+| rnnt | [speech_recognition/rnnt](https://github.com/mlcommons/inference/tree/master/speech_recognition/rnnt) | pytorch | OpenSLR LibriSpeech Corpus | edge,datacenter |
+| gpt-j | [language/gpt-j](https://github.com/mlcommons/inference/tree/master/language/gpt-j)| pytorch | CNN-Daily Mail | edge,datacenter |
+| stable-diffusion-xl | [text_to_image](https://github.com/mlcommons/inference/tree/master/text_to_image) | pytorch | COCO 2014 | edge,datacenter |
+| llama2-70b | [language/llama2-70b](https://github.com/mlcommons/inference/tree/master/language/llama2-70b) | pytorch | OpenOrca | datacenter |
+
+* The framework listed is the one used by the reference implementation. Submitters are free to use their own frameworks to run the benchmark.
## MLPerf Inference v3.1 (submission August 18, 2023)
Please use [v3.1 tag](https://github.com/mlcommons/inference/releases/tag/v3.1) (```git checkout v3.1```) if you would like to reproduce the v3.1 results.
diff --git a/compliance/nvidia/TEST01/stable-diffusion-xl/audit.config b/compliance/nvidia/TEST01/stable-diffusion-xl/audit.config
new file mode 100644
index 000000000..7e7cfdf55
--- /dev/null
+++ b/compliance/nvidia/TEST01/stable-diffusion-xl/audit.config
@@ -0,0 +1,9 @@
+# The format of this config file is 'key = value'.
+# The key has the format 'model.scenario.key'. Value is mostly int64_t.
+# Model may be '*' as a wildcard. In that case the value applies to all models.
+# All times are in milliseconds
+
+# mode dictionary (0 = submission, 1 = accuracy, 2 = performance, 3 = find peak perf)
+*.*.mode = 2
+*.*.accuracy_log_rng_seed = 720381539243781796
+*.*.accuracy_log_sampling_target = 128
diff --git a/compliance/nvidia/TEST06/README.md b/compliance/nvidia/TEST06/README.md
index a93312f94..a867e7527 100644
--- a/compliance/nvidia/TEST06/README.md
+++ b/compliance/nvidia/TEST06/README.md
@@ -8,9 +8,10 @@ This repository provides the config files and scripts to run and verify TEST 06
## Introduction
-The purpose of this test is to ensure the consistency of the output of the Llama2 model and avoid a potential EOS exploit. This test will make a performance run, with a limit of 100 samples and logging them into `mlperf_log_accuracy.json`. To achieve a passing result in this test, two criteria must be met:
+The purpose of this test is to ensure the consistency of the output of the Llama2 model and to avoid a potential EOS exploit. The test performs a performance run limited to 100 samples and logs them into `mlperf_log_accuracy.json`. To achieve a passing result, three criteria must be met:
- In the case the first token is reported independently (not applicable for Offline scenario), it should match for every query with the first token of the model output.
- For each query, the model output should only end with zero or one EOS token
+- For each query, the reported token count should match the number of output tokens in the logged model output (see the sketch below).
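+
+As a minimal illustrative sketch (mirroring the `sample_len_check` helper added to `run_verification.py` in this change, not a replacement for it), the token-count criterion amounts to:
+
+```python
+import numpy as np
+
+def sample_len_check(acc_data, dtype):
+    # acc_data: entries parsed from mlperf_log_accuracy.json; dtype: token dtype (e.g. np.int32)
+    for sample in acc_data:
+        tokens = np.frombuffer(bytes.fromhex(sample["data"]), dtype=dtype)
+        if len(tokens) != int(sample["token_count"]):
+            return False
+    return True
+```
+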
## Requisites
@@ -36,6 +37,7 @@ Expected output
```
First token check pass: True
EOS check pass: True
+Sample length check pass: True
TEST06 verification complete
```
@@ -44,5 +46,6 @@ Or:
```
First token check pass: Skipped
EOS check pass: True
+Sample length check pass: True
TEST06 verification complete
```
\ No newline at end of file
diff --git a/compliance/nvidia/TEST06/audit.config b/compliance/nvidia/TEST06/audit.config
index c8888be50..090ec03b1 100644
--- a/compliance/nvidia/TEST06/audit.config
+++ b/compliance/nvidia/TEST06/audit.config
@@ -8,4 +8,6 @@
*.*.accuracy_log_rng_seed = 720381539243781796
*.*.accuracy_log_sampling_target = 100
*.*.min_query_count = 100
-*.*.min_duration = 0
\ No newline at end of file
+*.*.min_duration = 0
+# Turn off equal issue mode for TEST06
+*.*.sample_concatenate_permutation = 0
\ No newline at end of file
diff --git a/compliance/nvidia/TEST06/run_verification.py b/compliance/nvidia/TEST06/run_verification.py
index 9d4cd1ad7..03c60c2ef 100644
--- a/compliance/nvidia/TEST06/run_verification.py
+++ b/compliance/nvidia/TEST06/run_verification.py
@@ -61,13 +61,21 @@ def first_token_check(acc_data, dtype):
for sample in acc_data:
data = np.frombuffer(bytes.fromhex(sample["data"]), dtype=dtype)
token_data = np.frombuffer(bytes.fromhex(sample["token_data"]), dtype=dtype)
- print(token_data)
for t1, t2 in zip(data, token_data):
if t1 != t2:
return False
return True
+def sample_len_check(acc_data, dtype):
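+    # Check that the number of decoded output tokens matches the reported token_count for every sample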
+ for sample in acc_data:
+ data = np.frombuffer(bytes.fromhex(sample["data"]), dtype=dtype)
+ token_count = int(sample["token_count"])
+ if len(data) != token_count:
+ return False
+ return True
+
+
def main():
args = get_args()
accuracy_file = os.path.join(args.compliance_dir, "mlperf_log_accuracy.json")
@@ -90,6 +98,8 @@ def main():
print("Unexpected error occured while doing the first token check")
first_token_pass = False
+ sample_len_pass = sample_len_check(acc_data, DTYPE_MAP[args.dtype])
+
# Construct output based on the results of checks
output = ""
# Add first token check
@@ -101,14 +111,17 @@ def main():
# Add EOS check
output += f"EOS check pass: {eos_pass}\n"
- if eos_pass and first_token_pass:
+ # Add sample length check
+ output += f"Sample length check pass: {sample_len_pass}\n"
+
+ if eos_pass and first_token_pass and sample_len_pass:
output += "TEST06 verification complete\n"
else:
output += "TEST06 verification failed\n"
# Output test output to console and folder
- output_dir = args.output_dir
- output_accuracy_dir = os.path.join(args.output_dir, "accuracy")
+ output_dir = os.path.join(args.output_dir, "TEST06")
+ output_accuracy_dir = os.path.join(output_dir, "accuracy")
if not os.path.isdir(output_dir):
os.makedirs(output_dir)
diff --git a/language/bert/README.md b/language/bert/README.md
index 1c5fefe3a..144e6d9d0 100644
--- a/language/bert/README.md
+++ b/language/bert/README.md
@@ -51,8 +51,6 @@ The below CM command will launch the SUT server
```
cm run script --tags=generate-run-cmds,inference --model=bert-99 --backend=pytorch \
---rerun --adr.mlperf-implementation.version=custom \
---adr.mlperf-implementation.tags=_repo.https://github.com/GATEOVerflow/inference \
--mode=performance --device=cuda --quiet --test_query_count=1000 --network=sut
```
@@ -61,8 +59,6 @@ Once the SUT server is launched, the below command can be run on the loadgen nod
```
cm run script --tags=generate-run-cmds,inference --model=bert-99 --backend=pytorch --rerun \
---adr.mlperf-implementation.version=custom \
---adr.mlperf-implementation.tags=_repo.https://github.com/GATEOVerflow/inference \
--mode=performance --device=cuda --quiet --test_query_count=1000 \
--sut_servers,=http://localhost:8000 --network=lon
```
diff --git a/language/gpt-j/README.md b/language/gpt-j/README.md
index cc46135f2..061067027 100644
--- a/language/gpt-j/README.md
+++ b/language/gpt-j/README.md
@@ -68,11 +68,37 @@ pip install datasets
python prepare-calibration.py --calibration-list-file calibration-list.txt --output-dir
```
### Download GPT-J model
-Please download the fine-tuned GPT-J checkpoint from [here](https://cloud.mlcommons.org/index.php/s/QAZ2oM94MkFtbQx) and extract it as model/. The download_gptj.py only downloads the default huggingface model which is not fine-tuned on CNN-Daily mail dataset.
+Please download the fine-tuned GPT-J checkpoint using the instructions below. Note that `download_gptj.py` only downloads the default Hugging Face model, which is not fine-tuned on the CNN-DailyMail dataset.
+#### CM method
+
+The following MLCommons CM commands can be used to programmatically download the model checkpoint.
+
+```
+pip install cmind
+cm pull repo mlcommons@ck
+cm run script --tags=get,ml-model,gptj,_pytorch,_rclone -j
+```
+
+#### Manual method
+
+The above command automatically runs a set of Rclone commands to download the data from a Cloudflare R2 bucket. However, if you'd like to run the Rclone commands manually, you can do so as follows:
+
+To run Rclone on Windows, you can download the executable [here](https://rclone.org/install/#windows).
+To install Rclone on Linux/macOS/BSD systems, run:
```
-wget https://cloud.mlcommons.org/index.php/s/QAZ2oM94MkFtbQx/download --output-document checkpoint.zip
+sudo -v ; curl https://rclone.org/install.sh | sudo bash
```
+Once Rclone is installed, run the following command to authenticate with the bucket:
+```
+rclone config create mlc-inference s3 provider=Cloudflare access_key_id=f65ba5eef400db161ea49967de89f47b secret_access_key=fbea333914c292b854f14d3fe232bad6c5407bf0ab1bebf78833c2b359bdfd2b endpoint=https://c2686074cb2caf5cbaf6d134bdba8b47.r2.cloudflarestorage.com
+```
+You can then navigate in the terminal to your desired download directory and run the following command to download the model checkpoint:
+
+```
+rclone copy mlc-inference:mlcommons-inference-wg-public/gpt-j ./model -P
+```
+
### Running the Benchmark
Replace the model and dataset path arguments with your corresponding paths. For evaluating the ROUGE score after the run, include --accuracy as shown below. For user specific target qps, please include user.conf.
diff --git a/language/gpt-j/SUT.py b/language/gpt-j/SUT.py
new file mode 100644
index 000000000..7233676c6
--- /dev/null
+++ b/language/gpt-j/SUT.py
@@ -0,0 +1,616 @@
+import os
+import sys
+import time
+import re
+import numpy as np
+import array
+import torch
+from torch.nn.functional import pad
+from torch.utils.data import DataLoader
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation.streamers import BaseStreamer
+
+import pickle
+import time
+import threading
+import tqdm
+import queue
+
+from concurrent.futures.thread import ThreadPoolExecutor
+import requests
+from urllib3.exceptions import InsecureRequestWarning
+import json
+
+from inference import GrpcClient
+import more_itertools as mit
+from itertools import repeat
+
+requests.packages.urllib3.disable_warnings(category=InsecureRequestWarning)
+
+import logging
+from typing import TYPE_CHECKING, Optional, List
+from pathlib import Path
+
+import mlperf_loadgen as lg
+from dataset import Dataset
+
+logging.basicConfig(level=logging.INFO)
+log = logging.getLogger("GPT-J-SUT")
+
+gen_kwargs = {
+ "early_stopping": True,
+ "max_new_tokens": 128,
+ "min_new_tokens": 30,
+ "num_beams": 4,
+}
+
+
+
+class FirstTokenStreamer(BaseStreamer):
+ """ Streams first tokens to a 'holder' """
+
+ def __init__(self, first_token, tokens_cache=[], is_first_token=True, response_ids=[] ):
+ """ Response ids added to 'sign' the first token"""
+
+ self.first_token = first_token # Queue for first token
+ self.is_first_token = is_first_token
+
+ # Cache for subsequent generated tokens
+ self.tokens_cache = tokens_cache
+
+ self.response_ids = response_ids
+
+ self.is_prompt = True # The first tokens sent to the streamer are actually the input prompts
+
+ def put(self, value):
+ """ Caches the tokens as they're generated. Assumes bs=1 """
+
+ # Prompts are streamed first so we need to skip the first time value that arrives
+ if self.is_prompt:
+ self.is_prompt = False
+ return
+
+ value = value.item()
+ if self.is_first_token:
+
+ # Add generated first token together with its query response_id to first tokens queue
+ self.first_token.put((value, self.response_ids[0]))
+
+ self.is_first_token = False
+ return
+
+ self.tokens_cache.append(value)
+
+
+ def end(self):
+ pass
+
+ def get_out_tokens(self):
+ return self.tokens_cache
+
+
+class SUT():
+ def __init__(self,
+ model_path=None,
+ api_server=None,
+ api_model_name=None,
+ additional_servers=[],
+ grpc=False,
+ batch_grpc=False,
+ vllm=False,
+ dtype="bfloat16",
+ device="cpu",
+ batch_size=None,
+ total_sample_count=13368,
+ dataset_path=None,
+ use_cached_outputs=False, # Set this to True *only for test accuracy runs* in case your prior session was killed partway through
+ workers=1):
+
+ self.model_path = model_path or "EleutherAI/gpt-j-6B"
+ self.api_model_name = api_model_name
+ self.api_servers = []
+ if api_server:
+ self.api_servers.append(api_server)
+ if additional_servers and not api_server:
+ sys.exit("Additional servers cannot be used without primary api server")
+ for server in additional_servers:
+ self.api_servers.append(server)
+ self.grpc = grpc
+ self.batch_grpc = batch_grpc
+ self.vllm = vllm
+ if self.vllm and (self.grpc or self.batch_grpc):
+ sys.exit("vllm does not support grpc")
+ self.device = device
+
+ if not batch_size:
+ if device == "cpu": # Also applies to API server mode
+ batch_size = 31192
+ else:
+ batch_size = 32 # Reduce to 8 if using 4 GPUs, 16 for 8.
+ self.batch_size = batch_size
+
+ # dtype
+ if dtype == 'bfloat16': # Irrelevant for API server mode
+ self.amp_enabled = True
+ self.amp_dtype = torch.bfloat16
+ elif dtype == 'float16':
+ self.amp_enabled = True
+ self.amp_dtype = torch.float16
+ else:
+ self.amp_enabled = False
+ self.amp_dtype = torch.float32
+
+ if 'cuda' in self.device:
+ assert torch.cuda.is_available(), "torch gpu is not available, exiting..."
+
+ self.dataset_path = dataset_path
+ self.data_object = Dataset(dataset_path=self.dataset_path,
+ total_count_override=total_sample_count)
+ self.qsl = lg.ConstructQSL(self.data_object.count, self.data_object.perf_count,
+ self.data_object.LoadSamplesToRam, self.data_object.UnloadSamplesFromRam)
+ self.load_model()
+
+ self.num_workers = workers
+ self.worker_threads = [None] * self.num_workers
+ self.query_queue = queue.Queue()
+
+ self.use_cached_outputs = use_cached_outputs
+ self.sample_counter = 0
+ self.sample_counter_lock = threading.Lock()
+
+
+ def start(self):
+ # Create worker threads
+ for j in range(self.num_workers):
+ worker = threading.Thread(target=self.process_queries)
+ worker.start()
+ self.worker_threads[j] = worker
+
+ def stop(self):
+ for _ in range(self.num_workers):
+ self.query_queue.put(None)
+
+ for worker in self.worker_threads:
+ worker.join()
+
+
+ def query_api(self, input, idx):
+ headers = {
+ 'Content-Type': 'application/json',
+ }
+
+ json_data = {
+ 'model_id': self.api_model_name,
+ 'inputs': input,
+ 'parameters': {
+ 'max_new_tokens': 128,
+ 'min_new_tokens': 30,
+ 'decoding_method': "GREEDY"
+ },
+ }
+
+ response_code = 0
+ while response_code != 200:
+ try:
+ response = requests.post(
+ self.api_servers[idx],
+ headers=headers,
+ json=json_data,
+ verify=False,
+ )
+ response_code = response.status_code
+ except:
+ print("connection failure")
+ return json.loads(response.text)["generated_text"]
+
+ def query_api_vllm(self, inputs, idx):
+ headers = {
+ 'Content-Type': 'application/json',
+ }
+
+ json_data = {
+ 'model': '/mnt/models/',
+ 'prompt': inputs,
+ 'max_tokens': 128,
+ 'temperature': 0,
+ }
+
+ response_code = 0
+ while response_code != 200:
+ try:
+ response = requests.post(f'{self.api_servers[idx]}/v1/completions', headers=headers, json=json_data, verify=False)
+ response_code = response.status_code
+ except:
+ print("connection failure")
+ return [resp["text"] for resp in json.loads(response.text)["choices"]]
+
+ def query_api_grpc(self, input, idx):
+ resp = self.grpc_clients[idx].make_request([input], model_id=self.api_model_name)
+ return resp.responses[0].text
+
+ def query_api_batch_grpc(self, inputs, idx):
+ resps = self.grpc_clients[idx].make_request(inputs, model_id=self.api_model_name)
+ return [resp.text for resp in resps.responses]
+
+ def api_action_handler(self, chunk, server_idx):
+ if self.grpc:
+ if self.batch_grpc:
+ output = self.query_api_batch_grpc(chunk, server_idx)
+ else:
+ with ThreadPoolExecutor(max_workers=len(chunk)) as executor:
+ output = list(executor.map(self.query_api_grpc,chunk, repeat(server_idx)))
+ elif self.vllm:
+ output = self.query_api_vllm(chunk, server_idx)
+ else:
+ with ThreadPoolExecutor(max_workers=len(chunk)) as executor:
+ output = list(executor.map(self.query_api,chunk, repeat(server_idx)))
+ return output
+
+ def process_queries(self):
+ """Processor of the queued queries. User may choose to add batching logic """
+
+ while True:
+ qitem = self.query_queue.get()
+ if qitem is None:
+ break
+
+ query_ids = [q.index for q in qitem]
+
+ fname = "q" + "_".join([str(i) for i in query_ids])
+ fname = f"run_outputs/{fname}.pkl"
+ _p = Path(fname)
+ if self.use_cached_outputs and _p.exists():
+ # Read cache
+ with _p.open(mode="rb") as f:
+ d = pickle.load(f)
+ processed_output = d["outputs"]
+ tik1 = None
+ tik2 = None
+ tik3 = None
+ tok = None
+ else:
+ # Construct / collate batch
+ tik1 = time.time()
+
+ input_ids_tensor = []
+ input_masks_tensor = []
+ input_len = []
+ for q in qitem:
+ input_ids_tensor.append(self.data_object.source_encoded_input_ids[q.index])
+ if self.api_servers:
+ cleaned_chunks = [list(c) for c in mit.divide(len(self.api_servers), input_ids_tensor)]
+
+ tik2 = time.time()
+
+ if self.api_servers:
+ with ThreadPoolExecutor(max_workers=len(self.api_servers)) as executor:
+ output_chunks = list(executor.map(self.api_action_handler,cleaned_chunks,range(len(self.api_servers))))
+ output = []
+ for row in output_chunks:
+ output += row
+ else:
+ pred_output_tokens = self.model.generate(
+ input_ids=input_ids_tensor,
+ attention_mask=input_masks_tensor,
+ pad_token_id=self.tokenizer.pad_token_id,
+ **gen_kwargs
+ )
+
+ tik3 = time.time()
+
+ if self.api_servers:
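+                # Re-tokenize the API text responses into a padded array of token ids so loadgen receives token data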
+ processed_output = np.array(self.tokenizer(output, padding='longest')['input_ids'])
+ else:
+ processed_output = self.data_object.postProcess(pred_output_tokens,
+ input_seq_lens=input_len,
+ query_id_list=query_ids)
+
+ for i in range(len(qitem)):
+ n_tokens = processed_output[i].shape[0]
+ response_array = array.array("B", processed_output[i].tobytes())
+ bi = response_array.buffer_info()
+ response = [lg.QuerySampleResponse(qitem[i].id, bi[0], bi[1], n_tokens)]
+ lg.QuerySamplesComplete(response)
+
+ tok = time.time()
+
+ with self.sample_counter_lock:
+ self.sample_counter += len(qitem)
+ print(f"Samples run: {self.sample_counter}")
+ if tik1:
+ print(f"\tBatchMaker time: {tik2 - tik1}")
+ print(f"\tInference time: {tik3 - tik2}")
+ print(f"\tPostprocess time: {tok - tik3}")
+ print(f"\t==== Total time: {tok - tik1}")
+ else:
+ print(f"\tLoaded from cache: {_p}")
+
+
+ def load_model(self):
+ if self.api_servers:
+ if not self.api_model_name:
+ sys.exit("API Server was specified but no model name was provided")
+ self.grpc_clients = []
+ for server in self.api_servers:
+ if self.grpc:
+ hostname = re.sub("https://|http://", "", server)
+ if hostname[-1] == "/":
+ hostname = hostname[:-1]
+ grpc_client = GrpcClient(
+ hostname,
+ 443,
+ verify=False,
+ )
+ self.grpc_clients.append(grpc_client)
+                elif "http" not in server:
+                    # Persist the scheme back into self.api_servers; reassigning the loop variable alone has no effect
+                    self.api_servers[self.api_servers.index(server)] = "http://" + server
+
+ if not self.api_model_name:
+ sys.exit("API Server was specified but no model name was provided")
+ else:
+ sys.exit("ONLY API SERVER MODE SUPPORTED FOR GPT-J")
+ self.model = AutoModelForCausalLM.from_pretrained(
+ self.model_path,
+ device_map="auto",
+ low_cpu_mem_usage=True,
+ torch_dtype=self.amp_dtype
+ )
+ print("Loaded model")
+
+ self.device = torch.device(self.device)
+ if self.device == "cpu":
+ self.model = self.model.to(self.device) # Force CPU if your system has GPU and you specifically want CPU-only run
+
+ self.model.eval()
+ self.model = self.model.to(memory_format=torch.channels_last)
+
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ self.model_path,
+ model_max_length=1024,
+ padding_side="left",
+ use_fast=True,) #changed from false
+
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+ print("Loaded tokenizer")
+
+ def get_sut(self):
+ self.sut = lg.ConstructSUT(self.issue_queries, self.flush_queries)
+ return self.sut
+
+ def get_qsl(self):
+ return self.qsl
+
+
+ def predict(self,**kwargs):
+ raise NotImplementedError
+
+
+ def issue_queries(self, query_samples):
+ """ Receives samples from loadgen and adds them to queue. Users may choose to batch here"""
+
+ list_prompts_tokens = []
+ list_prompts_attn_masks = []
+
+ print(f"IssueQuery started with {len(query_samples)} samples")
+ while len(query_samples) > 0:
+ self.query_queue.put(query_samples[:self.batch_size])
+ query_samples = query_samples[self.batch_size:]
+ print(f"IssueQuery done")
+
+
+ def flush_queries(self):
+ pass
+
+ def __del__(self):
+ pass
+
+
+class SUTServer(SUT):
+ def __init__(self, model_path=None, api_server=None, additional_servers=[], api_model_name=None, grpc=False, batch_grpc=False, vllm=False, dtype="bfloat16", device="cpu", total_sample_count=24576, dataset_path=None, workers=1):
+
+ super().__init__(model_path=model_path, api_server=api_server, additional_servers=additional_servers, api_model_name=api_model_name, grpc=grpc, vllm=vllm, dtype=dtype, device=device, total_sample_count=total_sample_count, dataset_path=dataset_path, workers=workers)
+
+ with open(f"{self.model_path}/tokenizer.json", 'r') as token_file:
+ gpt_tokenizer = json.load(token_file)
+ self.gpt_vocab = gpt_tokenizer["model"]["vocab"]
+
+ self.first_token_queue = queue.Queue()
+
+
+ def start(self):
+
+ # Create worker threads
+ for j in range(self.num_workers):
+ worker = threading.Thread(target=self.process_queries)
+ worker.start()
+ self.worker_threads[j] = worker
+
+ # Create first token response thread
+ self.ft_response_thread = threading.Thread(target=self.process_first_tokens)
+ self.ft_response_thread.start()
+
+
+ def process_first_tokens(self):
+
+ while True:
+ first_token_item = self.first_token_queue.get()
+
+ if first_token_item is None:
+ log.info("Exiting First token response thread")
+ break
+
+ first_tokens, response_id = first_token_item
+
+ response_data = array.array("B", np.array(first_tokens, np.float32).tobytes())
+ bi = response_data.buffer_info()
+ response = [lg.QuerySampleResponse(response_id, bi[0], bi[1])]
+ lg.FirstTokenComplete(response)
+
+ def stream_api(self, input, response_ids, idx):
+ headers = {
+ 'Content-Type': 'application/json',
+ }
+
+ json_data = {
+ 'model_id': 'GPT-J',
+ 'inputs': input,
+ 'parameters': {
+ 'max_new_tokens': 128,
+ 'min_new_tokens': 30,
+ 'decoding_method': "GREEDY"
+ },
+ }
+
+ token_cache = []
+ s = requests.Session()
+ first = True
+ with s.post(
+ self.api_servers[idx],
+ headers=headers,
+ json=json_data,
+ verify=False,
+ stream=True
+ ) as resp:
+ for line in resp.iter_lines():
+ if line:
+ decoded = line.decode()
+ if decoded.startswith("data"):
+ token_l = json.loads(decoded[6:])["tokens"]
+ if token_l:
+ token = self.gpt_vocab[token_l[0]["text"]]
+ if first:
+ self.first_token_queue.put((token, response_ids[0]))
+ first = False
+ else:
+ token_cache.append(token)
+ return token_cache
+
+ def stream_api_grpc(self, input, response_ids, idx):
+ token_cache = []
+ first = True
+ resps = self.grpc_clients[idx].make_request_stream(input, model_id=self.api_model_name)
+ for resp in resps:
+ if resp.tokens:
+ token = self.gpt_vocab[resp.tokens[0].text]
+ if first:
+ self.first_token_queue.put((token, response_ids[0]))
+ first = False
+ else:
+ token_cache.append(token)
+ return token_cache
+
+ def stream_api_vllm(self, input, response_ids, idx):
+ headers = {
+ 'Content-Type': 'application/json',
+ }
+
+ json_data = {
+ 'model': '/mnt/models/',
+ 'prompt': input,
+ 'max_tokens': 128,
+ 'temperature': 0,
+ 'stream': True,
+ 'logprobs': 1
+ }
+
+ while True:
+ try:
+ token_cache = []
+ s = requests.Session()
+ first = True
+ with s.post(
+ f'{self.api_servers[idx]}/v1/completions',
+ headers=headers,
+ json=json_data,
+ verify=False,
+ stream=True
+ ) as resp:
+ for line in resp.iter_lines():
+ if line:
+ decoded = line.decode()
+ if decoded.startswith("data") and "[DONE]" not in decoded:
+ inter = json.loads(decoded[6:])["choices"][0]["logprobs"]
+ if "top_logprobs" in inter:
+ token_s = list(inter["top_logprobs"][0].keys())[0]
+ token = self.gpt_vocab[token_s]
+ if first:
+ self.first_token_queue.put((token, response_ids[0]))
+ first = False
+ else:
+ token_cache.append(token)
+ s.close()
+ return token_cache
+ except:
+ s.close()
+ print("Connection failure")
+
+
+ def async_process_query(self, input_ids_tensor, qitem_id, idx):
+ decoded = input_ids_tensor
+ response_ids = [qitem_id]
+ if self.grpc:
+ output_tokens = self.stream_api_grpc(decoded, response_ids, idx)
+ elif self.vllm:
+ output_tokens = self.stream_api_vllm(decoded, response_ids, idx)
+ else:
+ output_tokens = self.stream_api(decoded, response_ids, idx)
+
+ n_tokens = len(output_tokens)
+ response_array = array.array("B", np.array(output_tokens, np.int32).tobytes())
+ bi = response_array.buffer_info()
+ response = [lg.QuerySampleResponse(
+ qitem_id, bi[0], bi[1], n_tokens)]
+ lg.QuerySamplesComplete(response)
+ return
+
+ def process_queries(self):
+ """Processor of the queued queries. User may choose to add batching logic """
+ server_idx = 0
+ while True:
+
+ qitem = self.query_queue.get()
+ if qitem is None:
+ break
+
+ input_ids_tensor = self.data_object.source_encoded_input_ids[qitem.index]
+ input_masks_tensor = []#self.data_object.source_encoded_attn_masks[qitem.index]
+
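+            # Hand each query to its own thread and rotate across the configured API endpoints round-robin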
+ if self.api_servers:
+ threading.Thread(target=self.async_process_query, args=(input_ids_tensor, qitem.id, server_idx)).start()
+ server_idx = (server_idx + 1) % len(self.api_servers)
+ else:
+ #TODO: This PoC is super slow with significant overhead. Best to create a patch to `generate`
+ tokens_cache = []
+ tokens_streamer = FirstTokenStreamer(self.first_token_queue, tokens_cache=tokens_cache, is_first_token=True, response_ids=[qitem.id])
+
+ _ = self.model.generate( input_ids=input_ids_tensor,
+ attention_mask=input_masks_tensor,
+ pad_token_id=self.tokenizer.pad_token_id,
+ streamer = tokens_streamer,
+ **gen_kwargs
+ )
+
+ output_tokens = tokens_streamer.get_out_tokens()
+
+ n_tokens = len(output_tokens)
+ response_array = array.array("B", np.array(output_tokens, np.int32).tobytes())
+ bi = response_array.buffer_info()
+ response = [lg.QuerySampleResponse(
+ qitem.id, bi[0], bi[1], n_tokens)]
+ lg.QuerySamplesComplete(response)
+
+
+ def issue_queries(self, query_samples):
+
+ self.query_queue.put(query_samples[0])
+
+
+ def stop(self):
+ for _ in range(self.num_workers):
+ self.query_queue.put(None)
+
+ for worker in self.worker_threads:
+ worker.join()
+
+ self.first_token_queue.put(None)
+ self.ft_response_thread.join()
diff --git a/language/gpt-j/SUT_local.py b/language/gpt-j/SUT_local.py
new file mode 100644
index 000000000..769c2b33a
--- /dev/null
+++ b/language/gpt-j/SUT_local.py
@@ -0,0 +1,549 @@
+import os
+import sys
+import time
+import re
+import numpy as np
+import array
+import torch
+from torch.nn.functional import pad
+from torch.utils.data import DataLoader
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation.streamers import BaseStreamer
+
+import pickle
+import time
+import threading
+import tqdm
+import queue
+
+from concurrent.futures.thread import ThreadPoolExecutor
+import requests
+from urllib3.exceptions import InsecureRequestWarning
+import json
+
+from inference import GrpcClient
+
+from vllm import LLM, SamplingParams
+
+requests.packages.urllib3.disable_warnings(category=InsecureRequestWarning)
+
+import logging
+from typing import TYPE_CHECKING, Optional, List
+from pathlib import Path
+
+import mlperf_loadgen as lg
+from dataset import Dataset
+
+logging.basicConfig(level=logging.INFO)
+log = logging.getLogger("GPT-J-SUT")
+
+gen_kwargs = {
+ "early_stopping": True,
+ "max_new_tokens": 128,
+ "min_new_tokens": 30,
+ "num_beams": 4,
+}
+
+
+
+class FirstTokenStreamer(BaseStreamer):
+ """ Streams first tokens to a 'holder' """
+
+ def __init__(self, first_token, tokens_cache=[], is_first_token=True, response_ids=[] ):
+ """ Response ids added to 'sign' the first token"""
+
+ self.first_token = first_token # Queue for first token
+ self.is_first_token = is_first_token
+
+ # Cache for subsequent generated tokens
+ self.tokens_cache = tokens_cache
+
+ self.response_ids = response_ids
+
+ self.is_prompt = True # The first tokens sent to the streamer are actually the input prompts
+
+ def put(self, value):
+ """ Caches the tokens as they're generated. Assumes bs=1 """
+
+ # Prompts are streamed first so we need to skip the first time value that arrives
+ if self.is_prompt:
+ self.is_prompt = False
+ return
+
+ value = value.item()
+ if self.is_first_token:
+
+ # Add generated first token together with its query response_id to first tokens queue
+ self.first_token.put((value, self.response_ids[0]))
+
+ self.is_first_token = False
+ return
+
+ self.tokens_cache.append(value)
+
+
+ def end(self):
+ pass
+
+ def get_out_tokens(self):
+ return self.tokens_cache
+
+
+class SUT():
+ def __init__(self,
+ model_path=None,
+ api_server=None,
+ api_model_name=None,
+ grpc=False,
+ batch_grpc=False,
+ vllm=False,
+ dtype="bfloat16",
+ device="cpu",
+ batch_size=None,
+ total_sample_count=13368,
+ dataset_path=None,
+ use_cached_outputs=False, # Set this to True *only for test accuracy runs* in case your prior session was killed partway through
+ workers=1):
+
+ self.model_path = model_path or "EleutherAI/gpt-j-6B"
+ self.api_server = api_server
+ self.api_model_name = api_model_name
+ self.grpc = grpc
+ self.batch_grpc = batch_grpc
+ self.vllm = vllm
+ if self.vllm and (self.grpc or self.batch_grpc):
+ sys.exit("vllm does not support grpc")
+ self.device = device
+
+ if not batch_size:
+ if device == "cpu":
+ batch_size = 128
+ else:
+ batch_size = 32 # Reduce to 8 if using 4 GPUs, 16 for 8.
+ self.batch_size = batch_size
+
+ # dtype
+ if dtype == 'bfloat16':
+ self.amp_enabled = True
+ self.amp_dtype = torch.bfloat16
+ elif dtype == 'float16':
+ self.amp_enabled = True
+ self.amp_dtype = torch.float16
+ else:
+ self.amp_enabled = False
+ self.amp_dtype = torch.float32
+
+ if 'cuda' in self.device:
+ assert torch.cuda.is_available(), "torch gpu is not available, exiting..."
+
+ self.dataset_path = dataset_path
+ self.data_object = Dataset(dataset_path=self.dataset_path,
+ total_count_override=total_sample_count)
+ self.qsl = lg.ConstructQSL(self.data_object.count, self.data_object.perf_count,
+ self.data_object.LoadSamplesToRam, self.data_object.UnloadSamplesFromRam)
+ self.load_model()
+
+ self.num_workers = workers
+ self.worker_threads = [None] * self.num_workers
+ self.query_queue = queue.Queue()
+
+ self.use_cached_outputs = use_cached_outputs
+ self.sample_counter = 0
+ self.sample_counter_lock = threading.Lock()
+
+
+ def start(self):
+ # Create worker threads
+ for j in range(self.num_workers):
+ worker = threading.Thread(target=self.process_queries)
+ worker.start()
+ self.worker_threads[j] = worker
+
+ def stop(self):
+ for _ in range(self.num_workers):
+ self.query_queue.put(None)
+
+ for worker in self.worker_threads:
+ worker.join()
+
+
+ def query_api(self, input):
+ headers = {
+ 'Content-Type': 'application/json',
+ }
+
+ json_data = {
+ 'model_id': self.api_model_name,
+ 'inputs': input,
+ 'parameters': {
+ 'max_new_tokens': 1024,
+ 'min_new_tokens': 1,
+ 'decoding_method': "GREEDY"
+ },
+ }
+
+ response_code = 0
+ while response_code != 200:
+ try:
+ response = requests.post(
+ self.api_server,
+ headers=headers,
+ json=json_data,
+ verify=False,
+ )
+ response_code = response.status_code
+ except:
+ print("connection failure")
+ return json.loads(response.text)["generated_text"]
+
+ def query_vllm(self, inputs):
+ sampling_params = SamplingParams(
+ max_tokens=128,
+ use_beam_search=True,
+ best_of=4,
+ temperature=0,
+ early_stopping=True,
+ )
+
+ outputs = self.llm.generate(inputs, sampling_params)
+ return [output.outputs[0].text for output in outputs]
+
+ def query_api_grpc(self, input):
+ resp = self.grpc_client.make_request([input], model_id=self.api_model_name)
+ return resp.responses[0].text
+
+ def query_api_batch_grpc(self, inputs):
+ resps = self.grpc_client.make_request(inputs, model_id=self.api_model_name)
+ return [resp.text for resp in resps.responses]
+
+ def process_queries(self):
+ """Processor of the queued queries. User may choose to add batching logic """
+
+ while True:
+ qitem = self.query_queue.get()
+ if qitem is None:
+ break
+
+ query_ids = [q.index for q in qitem]
+
+ fname = "q" + "_".join([str(i) for i in query_ids])
+ fname = f"run_outputs/{fname}.pkl"
+ _p = Path(fname)
+ if self.use_cached_outputs and _p.exists():
+ # Read cache
+ with _p.open(mode="rb") as f:
+ d = pickle.load(f)
+ processed_output = d["outputs"]
+ tik1 = None
+ tik2 = None
+ tik3 = None
+ tok = None
+ else:
+ # Construct / collate batch
+ tik1 = time.time()
+
+ input_ids_tensor = []
+ input_masks_tensor = []
+ input_len = []
+ for q in qitem:
+ input_ids_tensor.append(self.data_object.source_encoded_input_ids[q.index])
+ #input_masks_tensor.append(self.data_object.source_encoded_attn_masks[q.index])
+ #input_len.append(self.data_object.input_lens[q.index])
+ #input_ids_tensor = torch.cat(input_ids_tensor)
+ #input_masks_tensor = torch.cat(input_masks_tensor)
+
+ #assert input_ids_tensor.shape == input_masks_tensor.shape
+ #assert input_ids_tensor.shape[0] <= self.batch_size
+
+ if self.api_server:
+ #decoded = self.tokenizer.batch_decode(input_ids_tensor)
+ #cleaned = [entry.replace('','').replace('','') for entry in decoded]
+ bs = len(input_ids_tensor)
+
+ tik2 = time.time()
+
+ if self.api_server:
+ if self.grpc:
+ if self.batch_grpc:
+ output = self.query_api_batch_grpc(input_ids_tensor)
+ else:
+ with ThreadPoolExecutor(max_workers=bs) as executor:
+ output = list(executor.map(self.query_api_grpc,input_ids_tensor))
+ elif self.vllm:
+ output = self.query_vllm(input_ids_tensor)
+ else:
+ with ThreadPoolExecutor(max_workers=bs) as executor:
+ output = list(executor.map(self.query_api,input_ids_tensor))
+ else:
+ pred_output_tokens = self.model.generate(
+ input_ids=input_ids_tensor,
+ attention_mask=input_masks_tensor,
+ pad_token_id=self.tokenizer.pad_token_id,
+ **gen_kwargs
+ )
+
+ tik3 = time.time()
+
+ if self.api_server:
+ processed_output = np.array(self.tokenizer(output, padding='longest')['input_ids'])
+ else:
+ processed_output = self.data_object.postProcess(pred_output_tokens,
+ input_seq_lens=input_len,
+ query_id_list=query_ids)
+
+ for i in range(len(qitem)):
+ n_tokens = processed_output[i].shape[0]
+ response_array = array.array("B", processed_output[i].tobytes())
+ bi = response_array.buffer_info()
+ response = [lg.QuerySampleResponse(qitem[i].id, bi[0], bi[1], n_tokens)]
+ lg.QuerySamplesComplete(response)
+
+ tok = time.time()
+
+ with self.sample_counter_lock:
+ self.sample_counter += len(qitem)
+ print(f"Samples run: {self.sample_counter}")
+ if tik1:
+ print(f"\tBatchMaker time: {tik2 - tik1}")
+ print(f"\tInference time: {tik3 - tik2}")
+ print(f"\tPostprocess time: {tok - tik3}")
+ print(f"\t==== Total time: {tok - tik1}")
+ else:
+ print(f"\tLoaded from cache: {_p}")
+
+
+ def load_model(self):
+ if self.api_server:
+ if self.grpc:
+ hostname = re.sub("https://|http://", "", self.api_server)
+ if hostname[-1] == "/":
+ hostname = hostname[:-1]
+ self.grpc_client = GrpcClient(
+ hostname,
+ 443,
+ verify=False,
+ )
+ elif self.vllm:
+ self.llm = LLM(model="/workspace/gpt-model-info/", dtype="float16")
+ elif not "http" in self.api_server:
+ self.api_server = "http://" + self.api_server
+
+ if not self.api_model_name:
+ sys.exit("API Server was specified but no model name was provided")
+ else:
+ sys.exit("ONLY API SERVER MODE SUPPORTED FOR GPT-J")
+ self.model = AutoModelForCausalLM.from_pretrained(
+ self.model_path,
+ device_map="auto",
+ low_cpu_mem_usage=True,
+ torch_dtype=self.amp_dtype
+ )
+ print("Loaded model")
+
+ self.device = torch.device(self.device)
+ if self.device == "cpu":
+ self.model = self.model.to(self.device) # Force CPU if your system has GPU and you specifically want CPU-only run
+
+ self.model.eval()
+ self.model = self.model.to(memory_format=torch.channels_last)
+
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ self.model_path,
+ model_max_length=1024,
+ padding_side="left",
+ use_fast=True,) #changed from false
+
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+ print("Loaded tokenizer")
+
+ def get_sut(self):
+ self.sut = lg.ConstructSUT(self.issue_queries, self.flush_queries)
+ return self.sut
+
+ def get_qsl(self):
+ return self.qsl
+
+
+ def predict(self,**kwargs):
+ raise NotImplementedError
+
+
+ def issue_queries(self, query_samples):
+ """ Receives samples from loadgen and adds them to queue. Users may choose to batch here"""
+
+ list_prompts_tokens = []
+ list_prompts_attn_masks = []
+
+ print(f"IssueQuery started with {len(query_samples)} samples")
+ while len(query_samples) > 0:
+ self.query_queue.put(query_samples[:self.batch_size])
+ query_samples = query_samples[self.batch_size:]
+ print(f"IssueQuery done")
+
+
+ def flush_queries(self):
+ pass
+
+ def __del__(self):
+ pass
+
+
+class SUTServer(SUT):
+ def __init__(self, model_path=None, api_server=None, api_model_name=None, grpc=False, batch_grpc=False, dtype="bfloat16", device="cpu", total_sample_count=24576, dataset_path=None, workers=1):
+
+ super().__init__(model_path=model_path, api_server=api_server, api_model_name=api_model_name, grpc=grpc, dtype=dtype, device=device, total_sample_count=total_sample_count, dataset_path=dataset_path, workers=workers)
+
+ with open(f"{self.model_path}/tokenizer.json", 'r') as token_file:
+ gpt_tokenizer = json.load(token_file)
+ self.gpt_vocab = gpt_tokenizer["model"]["vocab"]
+
+ self.first_token_queue = queue.Queue()
+
+
+ def start(self):
+
+ # Create worker threads
+ for j in range(self.num_workers):
+ worker = threading.Thread(target=self.process_queries)
+ worker.start()
+ self.worker_threads[j] = worker
+
+ # Create first token response thread
+ self.ft_response_thread = threading.Thread(target=self.process_first_tokens)
+ self.ft_response_thread.start()
+
+
+ def process_first_tokens(self):
+
+ while True:
+ first_token_item = self.first_token_queue.get()
+
+ if first_token_item is None:
+ log.info("Exiting First token response thread")
+ break
+
+ first_tokens, response_id = first_token_item
+
+ response_data = array.array("B", np.array(first_tokens, np.float32).tobytes())
+ bi = response_data.buffer_info()
+ response = [lg.QuerySampleResponse(response_id, bi[0], bi[1])]
+ lg.FirstTokenComplete(response)
+
+ def stream_api(self, input, response_ids):
+ headers = {
+ 'Content-Type': 'application/json',
+ }
+
+ json_data = {
+ 'model_id': 'GPT-J',
+ 'inputs': input,
+ 'parameters': {
+ 'max_new_tokens': 1024,
+ 'min_new_tokens': 1,
+ 'decoding_method': "GREEDY"
+ },
+ }
+
+ token_cache = []
+ s = requests.Session()
+ first = True
+ with s.post(
+ self.api_server,
+ headers=headers,
+ json=json_data,
+ verify=False,
+ stream=True
+ ) as resp:
+ for line in resp.iter_lines():
+ if line:
+ decoded = line.decode()
+ if decoded.startswith("data"):
+ token_l = json.loads(decoded[6:])["tokens"]
+ if token_l:
+ token = self.gpt_vocab[token_l[0]["text"]]
+ if first:
+ self.first_token_queue.put((token, response_ids[0]))
+ first = False
+ else:
+ token_cache.append(token)
+ return token_cache
+
+ def stream_api_grpc(self, input, response_ids):
+ token_cache = []
+ first = True
+ resps = self.grpc_client.make_request_stream(input, model_id=self.api_model_name)
+ for resp in resps:
+ if resp.tokens:
+ token = self.gpt_vocab[resp.tokens[0].text]
+ if first:
+ self.first_token_queue.put((token, response_ids[0]))
+ first = False
+ else:
+ token_cache.append(token)
+ return token_cache
+
+ def async_process_query(self, input_ids_tensor, qitem_id):
+ decoded = input_ids_tensor
+ response_ids = [qitem_id]
+ if self.grpc:
+ output_tokens = self.stream_api_grpc(decoded, response_ids)
+ else:
+ output_tokens = self.stream_api(decoded, response_ids)
+
+ n_tokens = len(output_tokens)
+ response_array = array.array("B", np.array(output_tokens, np.int32).tobytes())
+ bi = response_array.buffer_info()
+ response = [lg.QuerySampleResponse(
+ qitem_id, bi[0], bi[1], n_tokens)]
+ lg.QuerySamplesComplete(response)
+ sys.exit()
+
+ def process_queries(self):
+ """Processor of the queued queries. User may choose to add batching logic """
+ while True:
+
+ qitem = self.query_queue.get()
+ if qitem is None:
+ break
+
+ input_ids_tensor = self.data_object.source_encoded_input_ids[qitem.index]
+ input_masks_tensor = []#self.data_object.source_encoded_attn_masks[qitem.index]
+
+ if self.api_server:
+ threading.Thread(target=self.async_process_query, args=(input_ids_tensor, qitem.id)).start()
+ else:
+ #TODO: This PoC is super slow with significant overhead. Best to create a patch to `generate`
+ tokens_cache = []
+ tokens_streamer = FirstTokenStreamer(self.first_token_queue, tokens_cache=tokens_cache, is_first_token=True, response_ids=[qitem.id])
+
+ _ = self.model.generate( input_ids=input_ids_tensor,
+ attention_mask=input_masks_tensor,
+ pad_token_id=self.tokenizer.pad_token_id,
+ streamer = tokens_streamer,
+ **gen_kwargs
+ )
+
+ output_tokens = tokens_streamer.get_out_tokens()
+
+ n_tokens = len(output_tokens)
+ response_array = array.array("B", np.array(output_tokens, np.int32).tobytes())
+ bi = response_array.buffer_info()
+ response = [lg.QuerySampleResponse(
+ qitem.id, bi[0], bi[1], n_tokens)]
+ lg.QuerySamplesComplete(response)
+
+
+ def issue_queries(self, query_samples):
+
+ self.query_queue.put(query_samples[0])
+
+
+ def stop(self):
+ for _ in range(self.num_workers):
+ self.query_queue.put(None)
+
+ for worker in self.worker_threads:
+ worker.join()
+
+ self.first_token_queue.put(None)
+ self.ft_response_thread.join()
diff --git a/language/gpt-j/api-endpoint-artifacts/Dockerfile-API b/language/gpt-j/api-endpoint-artifacts/Dockerfile-API
new file mode 100644
index 000000000..163974c7f
--- /dev/null
+++ b/language/gpt-j/api-endpoint-artifacts/Dockerfile-API
@@ -0,0 +1,27 @@
+FROM pytorch/pytorch:latest
+
+COPY inference ./inference
+COPY cnn_eval.json ./cnn_eval.json
+COPY gpt-model-info ./gpt-model-info
+
+RUN apt-get update && apt-get install -y build-essential
+RUN conda install pybind11==2.10.4 -c conda-forge -y
+#RUN conda install mkl mkl-include -y
+#RUN conda install gperftools jemalloc==5.2.1 -c conda-forge -y
+RUN pip install --upgrade pip
+RUN pip install transformers==4.31.0 nltk==3.8.1 evaluate==0.4.0 absl-py==1.4.0 rouge-score==0.1.2 sentencepiece==0.1.99 accelerate==0.21.0 grpcio-tools datasets simplejson
+RUN cd inference/loadgen && python -m pip install .
+#RUN cd inference/loadgen && CFLAGS="-std=c++14 -O3" python setup.py bdist_wheel && cd .. && pip install --force-reinstall loadgen/dist/`ls -r loadgen/dist/ | head -n1` ;
+RUN cp inference/mlperf.conf inference/language/gpt-j/mlperf.conf
+
+ENV DATASET_PATH=/workspace/cnn_eval.json
+ENV CHECKPOINT_PATH=/workspace/gpt-model-info
+ENV ACCURACY_LOG_FILE=/workspace/inference/language/gpt-j/offline-logs/mlperf_log_accuracy.json
+
+#ENV KMP_BLOCKTIME=1
+#ENV KMP_SETTINGS=1
+#ENV KMP_AFFINITY=granularity=fine,compact,1,0
+# IOMP
+#ENV LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libiomp5.so
+# Tcmalloc is a recommended malloc implementation that emphasizes fragmentation avoidance and scalable concurrency support.
+#ENV LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so
diff --git a/language/gpt-j/api-endpoint-artifacts/README.md b/language/gpt-j/api-endpoint-artifacts/README.md
new file mode 100644
index 000000000..895ee47fa
--- /dev/null
+++ b/language/gpt-j/api-endpoint-artifacts/README.md
@@ -0,0 +1,39 @@
+# Using OpenShift AI Model Serving (TGIS) with MLPerf Inference
+
+Prerequisites:
+ - Install the OpenShift AI model serving stack
+ - Add your AWS credentials to `secret.yaml` to access the model files
+ - Apply `secret.yaml`, `sa.yaml`
+ - FOR TGIS: Apply `serving-tgis.yaml`, then finally `model.yaml`
+ - FOR VLLM: Apply `serving-vllm.yaml`, then finally `model-vllm.yaml`
+ - Create a benchmark pod using `benchmark.yaml`
+
+In the pod, before any benchmark, first run `cd inference/language/gpt-j`
+
+## STANDALONE TGIS INSTRUCTIONS
+For the full accuracy benchmark (offline), run in the pod:
+```
+python3 -u main.py --scenario Offline --model-path ${CHECKPOINT_PATH} --api-server --api-model-name gpt-j-cnn --mlperf-conf mlperf.conf --accuracy --vllm --user-conf user.conf --dataset-path ${DATASET_PATH} --output-log-dir offline-logs --dtype float32 --device cpu 2>&1 | tee offline_performance_log.log
+```
+You can then run the same evaluation/consolidation scripts as for the regular benchmark.
+
+Example API host: `https://gpt-j-isvc-predictor-gpt-service.apps.gdr-perf.perf.eng.bos2.dc.redhat.com`
+
+
+For the performance benchmark (offline), run in the pod:
+```
+python3 -u main.py --scenario Offline --model-path ${CHECKPOINT_PATH} --api-server --api-model-name gpt-j-cnn --mlperf-conf mlperf.conf --vllm --user-conf user.conf --dataset-path ${DATASET_PATH} --output-log-dir offline-logs --dtype float32 --device cpu 2>&1 | tee offline_performance_log.log
+```
+(It is the same, just with `--accuracy` removed)
+
+ - For multiple endpoints, add `--additional-servers ...`
+
+
+For the performance benchmark (server), run in the pod:
+```
+python3 -u main.py --scenario Server --model-path ${CHECKPOINT_PATH} --api-server --api-model-name gpt-j-cnn --mlperf-conf mlperf.conf --vllm --user-conf user.conf --dataset-path ${DATASET_PATH} --output-log-dir server-logs --dtype float32 --device cpu 2>&1 | tee server_performance_log.log
+```
+(Configure target qps in `user.conf`)
+
+
+NOTE: Hyperparameters are currently configured for N instances x H100 80GB GPUs.
diff --git a/language/gpt-j/api-endpoint-artifacts/benchmark.yaml b/language/gpt-j/api-endpoint-artifacts/benchmark.yaml
new file mode 100644
index 000000000..b117ddbbc
--- /dev/null
+++ b/language/gpt-j/api-endpoint-artifacts/benchmark.yaml
@@ -0,0 +1,22 @@
+apiVersion: v1
+kind: Pod
+metadata:
+ name: mlperf-inference-gpt-2
+spec:
+ restartPolicy: Never
+ containers:
+ - name: mlperf-env
+ image: quay.io/meyceoz/mlperf-inference-gpt:v2
+ resources:
+ requests:
+ cpu: 140
+ memory: 20000Mi
+ volumeMounts:
+ - mountPath: /dev/shm
+ name: dshm
+ command: [ "/bin/sh", "-c" ]
+ args: [ "sleep infinity" ]
+ volumes:
+ - name: dshm
+ emptyDir:
+ medium: Memory
\ No newline at end of file
diff --git a/language/gpt-j/api-endpoint-artifacts/model-vllm.yaml b/language/gpt-j/api-endpoint-artifacts/model-vllm.yaml
new file mode 100644
index 000000000..d314cf952
--- /dev/null
+++ b/language/gpt-j/api-endpoint-artifacts/model-vllm.yaml
@@ -0,0 +1,20 @@
+apiVersion: serving.kserve.io/v1beta1
+kind: InferenceService
+metadata:
+ annotations:
+ serving.knative.openshift.io/enablePassthrough: "true"
+ sidecar.istio.io/inject: "true"
+ sidecar.istio.io/rewriteAppHTTPProbers: "true"
+ name: gpt-j-isvc
+spec:
+ predictor:
+ minReplicas: 1
+ maxReplicas: 1
+ #apiVersion: serving.kserve.io/v1alpha2
+ serviceAccountName: sa
+ #timeout: 240
+ model:
+ modelFormat:
+ name: pytorch
+ runtime: vllm
+ storageUri: s3://mlperf-inference-models/gpt-j-cnn/
\ No newline at end of file
diff --git a/language/gpt-j/api-endpoint-artifacts/model.yaml b/language/gpt-j/api-endpoint-artifacts/model.yaml
new file mode 100644
index 000000000..69ce8ae25
--- /dev/null
+++ b/language/gpt-j/api-endpoint-artifacts/model.yaml
@@ -0,0 +1,20 @@
+apiVersion: serving.kserve.io/v1beta1
+kind: InferenceService
+metadata:
+ annotations:
+ serving.knative.openshift.io/enablePassthrough: "true"
+ sidecar.istio.io/inject: "true"
+ sidecar.istio.io/rewriteAppHTTPProbers: "true"
+ name: gpt-j-isvc
+spec:
+ predictor:
+ minReplicas: 1
+ maxReplicas: 1
+ #apiVersion: serving.kserve.io/v1alpha2
+ serviceAccountName: sa
+ #timeout: 240
+ model:
+ modelFormat:
+ name: pytorch
+ runtime: tgis-runtime-grpc
+ storageUri: s3://mlperf-inference-models/gpt-j-cnn/
\ No newline at end of file
diff --git a/language/gpt-j/api-endpoint-artifacts/sa.yaml b/language/gpt-j/api-endpoint-artifacts/sa.yaml
new file mode 100644
index 000000000..6ccfd13b7
--- /dev/null
+++ b/language/gpt-j/api-endpoint-artifacts/sa.yaml
@@ -0,0 +1,6 @@
+apiVersion: v1
+kind: ServiceAccount
+metadata:
+ name: sa
+secrets:
+- name: storage-config
\ No newline at end of file
diff --git a/language/gpt-j/api-endpoint-artifacts/secret.yaml b/language/gpt-j/api-endpoint-artifacts/secret.yaml
new file mode 100644
index 000000000..436eaa057
--- /dev/null
+++ b/language/gpt-j/api-endpoint-artifacts/secret.yaml
@@ -0,0 +1,12 @@
+apiVersion: v1
+kind: Secret
+metadata:
+ annotations:
+ serving.kserve.io/s3-endpoint: "s3.amazonaws.com" # replace with your s3 endpoint e.g minio-service.kubeflow:9000
+ serving.kserve.io/s3-usehttps: "1" # by default 1, if testing with minio you can set to 0
+ serving.kserve.io/s3-region: "us-east-1"
+ serving.kserve.io/s3-useanoncredential: "false" # omitting this is the same as false, if true will ignore provided credential and use anonymous credentials
+ name: storage-config
+stringData:
+ "AWS_ACCESS_KEY_ID": "XXXXXXXXXXXXXXXXXX"
+ "AWS_SECRET_ACCESS_KEY": "XXXXXXXXXXXXXXXXXXXXXX"
\ No newline at end of file
diff --git a/language/gpt-j/api-endpoint-artifacts/service.yaml b/language/gpt-j/api-endpoint-artifacts/service.yaml
new file mode 100644
index 000000000..7d9a2c523
--- /dev/null
+++ b/language/gpt-j/api-endpoint-artifacts/service.yaml
@@ -0,0 +1,20 @@
+kind: Service
+apiVersion: v1
+metadata:
+ name: vllm
+ labels:
+ app: vllm
+spec:
+ clusterIP: None
+ ipFamilies:
+ - IPv4
+ ports:
+ - name: http
+ protocol: TCP
+ port: 8000
+ targetPort: http
+ type: ClusterIP
+ ipFamilyPolicy: SingleStack
+ sessionAffinity: None
+ selector:
+ app: vllm
diff --git a/language/gpt-j/api-endpoint-artifacts/serving-tgis.yaml b/language/gpt-j/api-endpoint-artifacts/serving-tgis.yaml
new file mode 100644
index 000000000..43719baf2
--- /dev/null
+++ b/language/gpt-j/api-endpoint-artifacts/serving-tgis.yaml
@@ -0,0 +1,67 @@
+apiVersion: serving.kserve.io/v1alpha1
+kind: ServingRuntime
+metadata:
+ name: tgis-runtime-grpc
+spec:
+ multiModel: false
+ supportedModelFormats:
+ - autoSelect: true
+ name: pytorch
+ containers:
+ - name: kserve-container
+ image: quay.io/opendatahub/text-generation-inference:fast
+ command: ["text-generation-launcher"]
+ args:
+ - "--model-name=/mnt/models/"
+ - "--port=3000"
+ - "--grpc-port=8033"
+ env:
+ - name: TRANSFORMERS_CACHE
+ value: /tmp/transformers_cache
+ - name: NUM_GPUS
+ value: "1"
+ - name: DTYPE_STR
+ value: float16
+ # Dynamic batch size changes
+ - name: MAX_BATCH_SIZE
+ value: "64"
+ - name: MAX_CONCURRENT_REQUESTS
+ value: "128"
+ - name: MAX_BATCH_WEIGHT
+ value: "10000"
+ - name: MAX_SEQUENCE_LENGTH
+ value: "2048"
+ - name: MAX_PREFILL_WEIGHT
+ value: "0"
+ - name: MAX_NEW_TOKENS
+ value: "128"
+# - name: FLASH_ATTENTION
+# value: "false"
+# - name: DEPLOYMENT_FRAMEWORK
+# value: hf_custom_tp
+ - name: LOG_GPU_USAGE_INTERVAL
+ value: "5"
+ resources: # configure as required
+ requests:
+ cpu: 12
+ memory: 100Gi
+ nvidia.com/gpu: 1
+ limits:
+ nvidia.com/gpu: 1
+      readinessProbe: # Use exec probes instead of httpGet since the probes' port gets rewritten to the containerPort
+ exec:
+ command:
+ - curl
+ - localhost:3000/health
+ initialDelaySeconds: 5
+ livenessProbe:
+ exec:
+ command:
+ - curl
+ - localhost:3000/health
+ initialDelaySeconds: 5
+ # periodSeconds: 5
+ ports:
+ - containerPort: 8033
+ name: h2c
+ protocol: TCP
\ No newline at end of file
diff --git a/language/gpt-j/api-endpoint-artifacts/serving-vllm.yaml b/language/gpt-j/api-endpoint-artifacts/serving-vllm.yaml
new file mode 100644
index 000000000..0f7d51556
--- /dev/null
+++ b/language/gpt-j/api-endpoint-artifacts/serving-vllm.yaml
@@ -0,0 +1,47 @@
+apiVersion: serving.kserve.io/v1alpha1
+kind: ServingRuntime
+labels:
+ opendatahub.io/dashboard: "true"
+metadata:
+ annotations:
+ openshift.io/display-name: vLLM
+ name: vllm
+spec:
+ builtInAdapter:
+ modelLoadingTimeoutMillis: 90000
+ containers:
+ - args:
+ - --model
+ - /mnt/models/
+ - --download-dir
+ - /models-cache
+ - --port
+ - "8080"
+ - --dtype
+ - float16
+ image: quay.io/rh-aiservices-bu/vllm-openai-ubi9:0.3.1-fix-2939
+ name: kserve-container
+ ports:
+ - containerPort: 8080
+ name: http1
+ protocol: TCP
+ resources: # configure as required
+ requests:
+ cpu: 12
+ memory: 128Gi
+ nvidia.com/gpu: 1
+ limits:
+ cpu: 12
+ memory: 128Gi
+ nvidia.com/gpu: 1
+ volumeMounts:
+ - mountPath: /dev/shm
+ name: dshm
+ volumes:
+ - name: dshm
+ emptyDir:
+ medium: Memory
+ multiModel: false
+ supportedModelFormats:
+ - autoSelect: true
+ name: pytorch
\ No newline at end of file
diff --git a/language/gpt-j/api-endpoint-artifacts/standalone.yaml b/language/gpt-j/api-endpoint-artifacts/standalone.yaml
new file mode 100644
index 000000000..3ea95ac55
--- /dev/null
+++ b/language/gpt-j/api-endpoint-artifacts/standalone.yaml
@@ -0,0 +1,111 @@
+kind: Deployment
+apiVersion: apps/v1
+metadata:
+ annotations:
+ deployment.kubernetes.io/revision: '8'
+ resourceVersion: '249593'
+ name: vllm
+ generation: 24
+ namespace: gpt-service
+ labels:
+ app: vllm
+spec:
+ replicas: 1
+ selector:
+ matchLabels:
+ app: vllm
+ template:
+ metadata:
+ creationTimestamp: null
+ labels:
+ app: vllm
+ spec:
+ restartPolicy: Always
+ schedulerName: default-scheduler
+ affinity: {}
+ terminationGracePeriodSeconds: 120
+ securityContext: {}
+ containers:
+ - resources:
+ limits:
+ cpu: '12'
+ memory: 128Gi
+ nvidia.com/gpu: '1'
+ requests:
+ cpu: '12'
+ memory: 128Gi
+ nvidia.com/gpu: '1'
+# readinessProbe:
+# httpGet:
+# path: /health
+# port: http
+# scheme: HTTP
+# timeoutSeconds: 5
+# periodSeconds: 30
+# successThreshold: 1
+# failureThreshold: 3
+ terminationMessagePath: /dev/termination-log
+ name: server
+# livenessProbe:
+# httpGet:
+# path: /health
+# port: http
+# scheme: HTTP
+# timeoutSeconds: 8
+# periodSeconds: 100
+# successThreshold: 1
+# failureThreshold: 3
+ securityContext:
+ capabilities:
+ drop:
+ - ALL
+ runAsNonRoot: true
+ allowPrivilegeEscalation: false
+ seccompProfile:
+ type: RuntimeDefault
+ ports:
+ - name: http
+ containerPort: 8000
+ protocol: TCP
+ imagePullPolicy: IfNotPresent
+# startupProbe:
+# httpGet:
+# path: /health
+# port: http
+# scheme: HTTP
+# timeoutSeconds: 1
+# periodSeconds: 30
+# successThreshold: 1
+# failureThreshold: 24
+ volumeMounts:
+ - name: models-cache
+ mountPath: /models-cache
+ - name: shm
+ mountPath: /dev/shm
+ terminationMessagePolicy: File
+ image: 'quay.io/rh-aiservices-bu/vllm-openai-ubi9:0.3.1-fix-2939'
+ args:
+ - '--model'
+ - /models-cache/gpt-model-info/
+# - EleutherAI/gpt-j-6b
+ - '--download-dir'
+ - /models-cache
+ - '--dtype'
+ - float16
+ volumes:
+ - name: models-cache
+ persistentVolumeClaim:
+ claimName: vllm-model-cache
+ - name: shm
+ emptyDir:
+ medium: Memory
+ sizeLimit: 1Gi
+ dnsPolicy: ClusterFirst
+ tolerations:
+ - key: nvidia.com/gpu
+ operator: Exists
+ effect: NoSchedule
+ strategy:
+ type: Recreate
+ revisionHistoryLimit: 10
+ progressDeadlineSeconds: 600
diff --git a/language/gpt-j/dataset.py b/language/gpt-j/dataset.py
index 37d9cf354..1f0aec786 100644
--- a/language/gpt-j/dataset.py
+++ b/language/gpt-j/dataset.py
@@ -51,7 +51,7 @@ def __init__(self, dataset_path, batch_size=1, pad_val=1, pad_max=196, total_cou
self.targets = [
f"{example['output']}" for example in self.list_data_dict]
- self.source_encoded_input_ids, self.source_encoded_attn_masks = self.encode_samples()
+ self.source_encoded_input_ids = self.encode_samples()#, self.source_encoded_attn_masks = self.encode_samples()
self.count = total_count_override or len(self.sources)
self.perf_count = perf_count_override or self.count
@@ -62,16 +62,20 @@ def encode_samples(self):
total_samples = len(self.sources)
source_encoded_input_ids = []
- source_encoded_attn_masks = []
+ #source_encoded_attn_masks = []
for i in range(total_samples):
- source_encoded = self.tokenizer(self.sources[i], return_tensors="pt",
- padding=True, truncation=True,
- max_length=1919)
- source_encoded_input_ids.append(source_encoded.input_ids)
- source_encoded_attn_masks.append(source_encoded.attention_mask)
-
- return source_encoded_input_ids, source_encoded_attn_masks
+ #source_encoded = self.tokenizer(self.sources[i], return_tensors="pt",
+ # padding=True, truncation=True,
+ # max_length=1919)
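+            # Hack to keep prompts within the model's context window: while the prompt
+            # tokenizes to more than 1920 tokens, drop roughly 4 characters per excess
+            # token from just before the last 16 characters and re-tokenize.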
+ tok = self.tokenizer(self.sources[i])["input_ids"]
+ while len(tok) > 1920:
+ self.sources[i] = self.sources[i][:-16 - (len(tok) - 1920)*4] + self.sources[i][-16:]
+ tok = self.tokenizer(self.sources[i])["input_ids"]
+ source_encoded_input_ids.append(self.sources[i])
+ #source_encoded_attn_masks.append(source_encoded.attention_mask)
+
+ return source_encoded_input_ids#, source_encoded_attn_masks
def LoadSamplesToRam(self, sample_list):
pass
diff --git a/language/gpt-j/generation_pb2.py b/language/gpt-j/generation_pb2.py
new file mode 100644
index 000000000..544e1ec7d
--- /dev/null
+++ b/language/gpt-j/generation_pb2.py
@@ -0,0 +1,70 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: generation.proto
+# Protobuf Python Version: 4.25.0
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10generation.proto\x12\x05\x66maas\"\xa1\x01\n\x18\x42\x61tchedGenerationRequest\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x16\n\tprefix_id\x18\x02 \x01(\tH\x00\x88\x01\x01\x12*\n\x08requests\x18\x03 \x03(\x0b\x32\x18.fmaas.GenerationRequest\x12!\n\x06params\x18\n \x01(\x0b\x32\x11.fmaas.ParametersB\x0c\n\n_prefix_id\"\x9f\x01\n\x17SingleGenerationRequest\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x16\n\tprefix_id\x18\x02 \x01(\tH\x00\x88\x01\x01\x12)\n\x07request\x18\x03 \x01(\x0b\x32\x18.fmaas.GenerationRequest\x12!\n\x06params\x18\n \x01(\x0b\x32\x11.fmaas.ParametersB\x0c\n\n_prefix_id\"I\n\x19\x42\x61tchedGenerationResponse\x12,\n\tresponses\x18\x01 \x03(\x0b\x32\x19.fmaas.GenerationResponse\"!\n\x11GenerationRequest\x12\x0c\n\x04text\x18\x02 \x01(\t\"\xf3\x01\n\x12GenerationResponse\x12\x19\n\x11input_token_count\x18\x06 \x01(\r\x12\x1d\n\x15generated_token_count\x18\x02 \x01(\r\x12\x0c\n\x04text\x18\x04 \x01(\t\x12&\n\x0bstop_reason\x18\x07 \x01(\x0e\x32\x11.fmaas.StopReason\x12\x15\n\rstop_sequence\x18\x0b \x01(\t\x12\x0c\n\x04seed\x18\n \x01(\x04\x12 \n\x06tokens\x18\x08 \x03(\x0b\x32\x10.fmaas.TokenInfo\x12&\n\x0cinput_tokens\x18\t \x03(\x0b\x32\x10.fmaas.TokenInfo\"\x81\x02\n\nParameters\x12%\n\x06method\x18\x01 \x01(\x0e\x32\x15.fmaas.DecodingMethod\x12+\n\x08sampling\x18\x02 \x01(\x0b\x32\x19.fmaas.SamplingParameters\x12)\n\x08stopping\x18\x03 \x01(\x0b\x32\x17.fmaas.StoppingCriteria\x12(\n\x08response\x18\x04 \x01(\x0b\x32\x16.fmaas.ResponseOptions\x12+\n\x08\x64\x65\x63oding\x18\x05 \x01(\x0b\x32\x19.fmaas.DecodingParameters\x12\x1d\n\x15truncate_input_tokens\x18\x06 \x01(\r\"\xc5\x01\n\x12\x44\x65\x63odingParameters\x12\x1a\n\x12repetition_penalty\x18\x01 \x01(\x02\x12\x44\n\x0elength_penalty\x18\x02 \x01(\x0b\x32\'.fmaas.DecodingParameters.LengthPenaltyH\x00\x88\x01\x01\x1a:\n\rLengthPenalty\x12\x13\n\x0bstart_index\x18\x01 \x01(\r\x12\x14\n\x0c\x64\x65\x63\x61y_factor\x18\x02 \x01(\x02\x42\x11\n\x0f_length_penalty\"v\n\x12SamplingParameters\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_k\x18\x02 \x01(\r\x12\r\n\x05top_p\x18\x03 \x01(\x02\x12\x11\n\ttypical_p\x18\x04 \x01(\x02\x12\x11\n\x04seed\x18\x05 \x01(\x04H\x00\x88\x01\x01\x42\x07\n\x05_seed\"\xb3\x01\n\x10StoppingCriteria\x12\x16\n\x0emax_new_tokens\x18\x01 \x01(\r\x12\x16\n\x0emin_new_tokens\x18\x02 \x01(\r\x12\x19\n\x11time_limit_millis\x18\x03 \x01(\r\x12\x16\n\x0estop_sequences\x18\x04 \x03(\t\x12\"\n\x15include_stop_sequence\x18\x05 \x01(\x08H\x00\x88\x01\x01\x42\x18\n\x16_include_stop_sequence\"\x98\x01\n\x0fResponseOptions\x12\x12\n\ninput_text\x18\x01 \x01(\x08\x12\x18\n\x10generated_tokens\x18\x02 \x01(\x08\x12\x14\n\x0cinput_tokens\x18\x03 \x01(\x08\x12\x16\n\x0etoken_logprobs\x18\x04 \x01(\x08\x12\x13\n\x0btoken_ranks\x18\x05 \x01(\x08\x12\x14\n\x0ctop_n_tokens\x18\x06 \x01(\r\"\x92\x01\n\tTokenInfo\x12\x0c\n\x04text\x18\x02 \x01(\t\x12\x0f\n\x07logprob\x18\x03 \x01(\x02\x12\x0c\n\x04rank\x18\x04 \x01(\r\x12-\n\ntop_tokens\x18\x05 \x03(\x0b\x32\x19.fmaas.TokenInfo.TopToken\x1a)\n\x08TopToken\x12\x0c\n\x04text\x18\x02 \x01(\t\x12\x0f\n\x07logprob\x18\x03 \x01(\x02\"k\n\x16\x42\x61tchedTokenizeRequest\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12(\n\x08requests\x18\x02 \x03(\x0b\x32\x16.fmaas.TokenizeRequest\x12\x15\n\rreturn_tokens\x18\x03 \x01(\x08\"E\n\x17\x42\x61tchedTokenizeResponse\x12*\n\tresponses\x18\x01 \x03(\x0b\x32\x17.fmaas.TokenizeResponse\"\x1f\n\x0fTokenizeRequest\x12\x0c\n\x04text\x18\x01 
\x01(\t\"7\n\x10TokenizeResponse\x12\x13\n\x0btoken_count\x18\x01 \x01(\r\x12\x0e\n\x06tokens\x18\x02 \x03(\t\"$\n\x10ModelInfoRequest\x12\x10\n\x08model_id\x18\x01 \x01(\t\"\xb4\x01\n\x11ModelInfoResponse\x12\x36\n\nmodel_kind\x18\x01 \x01(\x0e\x32\".fmaas.ModelInfoResponse.ModelKind\x12\x1b\n\x13max_sequence_length\x18\x02 \x01(\r\x12\x16\n\x0emax_new_tokens\x18\x03 \x01(\r\"2\n\tModelKind\x12\x10\n\x0c\x44\x45\x43ODER_ONLY\x10\x00\x12\x13\n\x0f\x45NCODER_DECODER\x10\x01*(\n\x0e\x44\x65\x63odingMethod\x12\n\n\x06GREEDY\x10\x00\x12\n\n\x06SAMPLE\x10\x01*\x8b\x01\n\nStopReason\x12\x10\n\x0cNOT_FINISHED\x10\x00\x12\x0e\n\nMAX_TOKENS\x10\x01\x12\r\n\tEOS_TOKEN\x10\x02\x12\r\n\tCANCELLED\x10\x03\x12\x0e\n\nTIME_LIMIT\x10\x04\x12\x11\n\rSTOP_SEQUENCE\x10\x05\x12\x0f\n\x0bTOKEN_LIMIT\x10\x06\x12\t\n\x05\x45RROR\x10\x07\x32\xc4\x02\n\x11GenerationService\x12O\n\x08Generate\x12\x1f.fmaas.BatchedGenerationRequest\x1a .fmaas.BatchedGenerationResponse\"\x00\x12O\n\x0eGenerateStream\x12\x1e.fmaas.SingleGenerationRequest\x1a\x19.fmaas.GenerationResponse\"\x00\x30\x01\x12K\n\x08Tokenize\x12\x1d.fmaas.BatchedTokenizeRequest\x1a\x1e.fmaas.BatchedTokenizeResponse\"\x00\x12@\n\tModelInfo\x12\x17.fmaas.ModelInfoRequest\x1a\x18.fmaas.ModelInfoResponse\"\x00\x62\x06proto3')
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'generation_pb2', _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+ DESCRIPTOR._options = None
+ _globals['_DECODINGMETHOD']._serialized_start=2266
+ _globals['_DECODINGMETHOD']._serialized_end=2306
+ _globals['_STOPREASON']._serialized_start=2309
+ _globals['_STOPREASON']._serialized_end=2448
+ _globals['_BATCHEDGENERATIONREQUEST']._serialized_start=28
+ _globals['_BATCHEDGENERATIONREQUEST']._serialized_end=189
+ _globals['_SINGLEGENERATIONREQUEST']._serialized_start=192
+ _globals['_SINGLEGENERATIONREQUEST']._serialized_end=351
+ _globals['_BATCHEDGENERATIONRESPONSE']._serialized_start=353
+ _globals['_BATCHEDGENERATIONRESPONSE']._serialized_end=426
+ _globals['_GENERATIONREQUEST']._serialized_start=428
+ _globals['_GENERATIONREQUEST']._serialized_end=461
+ _globals['_GENERATIONRESPONSE']._serialized_start=464
+ _globals['_GENERATIONRESPONSE']._serialized_end=707
+ _globals['_PARAMETERS']._serialized_start=710
+ _globals['_PARAMETERS']._serialized_end=967
+ _globals['_DECODINGPARAMETERS']._serialized_start=970
+ _globals['_DECODINGPARAMETERS']._serialized_end=1167
+ _globals['_DECODINGPARAMETERS_LENGTHPENALTY']._serialized_start=1090
+ _globals['_DECODINGPARAMETERS_LENGTHPENALTY']._serialized_end=1148
+ _globals['_SAMPLINGPARAMETERS']._serialized_start=1169
+ _globals['_SAMPLINGPARAMETERS']._serialized_end=1287
+ _globals['_STOPPINGCRITERIA']._serialized_start=1290
+ _globals['_STOPPINGCRITERIA']._serialized_end=1469
+ _globals['_RESPONSEOPTIONS']._serialized_start=1472
+ _globals['_RESPONSEOPTIONS']._serialized_end=1624
+ _globals['_TOKENINFO']._serialized_start=1627
+ _globals['_TOKENINFO']._serialized_end=1773
+ _globals['_TOKENINFO_TOPTOKEN']._serialized_start=1732
+ _globals['_TOKENINFO_TOPTOKEN']._serialized_end=1773
+ _globals['_BATCHEDTOKENIZEREQUEST']._serialized_start=1775
+ _globals['_BATCHEDTOKENIZEREQUEST']._serialized_end=1882
+ _globals['_BATCHEDTOKENIZERESPONSE']._serialized_start=1884
+ _globals['_BATCHEDTOKENIZERESPONSE']._serialized_end=1953
+ _globals['_TOKENIZEREQUEST']._serialized_start=1955
+ _globals['_TOKENIZEREQUEST']._serialized_end=1986
+ _globals['_TOKENIZERESPONSE']._serialized_start=1988
+ _globals['_TOKENIZERESPONSE']._serialized_end=2043
+ _globals['_MODELINFOREQUEST']._serialized_start=2045
+ _globals['_MODELINFOREQUEST']._serialized_end=2081
+ _globals['_MODELINFORESPONSE']._serialized_start=2084
+ _globals['_MODELINFORESPONSE']._serialized_end=2264
+ _globals['_MODELINFORESPONSE_MODELKIND']._serialized_start=2214
+ _globals['_MODELINFORESPONSE_MODELKIND']._serialized_end=2264
+ _globals['_GENERATIONSERVICE']._serialized_start=2451
+ _globals['_GENERATIONSERVICE']._serialized_end=2775
+# @@protoc_insertion_point(module_scope)
diff --git a/language/gpt-j/generation_pb2_grpc.py b/language/gpt-j/generation_pb2_grpc.py
new file mode 100644
index 000000000..9ab1e4eb5
--- /dev/null
+++ b/language/gpt-j/generation_pb2_grpc.py
@@ -0,0 +1,169 @@
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+"""Client and server classes corresponding to protobuf-defined services."""
+import grpc
+
+import generation_pb2 as generation__pb2
+
+
+class GenerationServiceStub(object):
+ """Missing associated documentation comment in .proto file."""
+
+ def __init__(self, channel):
+ """Constructor.
+
+ Args:
+ channel: A grpc.Channel.
+ """
+ self.Generate = channel.unary_unary(
+ '/fmaas.GenerationService/Generate',
+ request_serializer=generation__pb2.BatchedGenerationRequest.SerializeToString,
+ response_deserializer=generation__pb2.BatchedGenerationResponse.FromString,
+ )
+ self.GenerateStream = channel.unary_stream(
+ '/fmaas.GenerationService/GenerateStream',
+ request_serializer=generation__pb2.SingleGenerationRequest.SerializeToString,
+ response_deserializer=generation__pb2.GenerationResponse.FromString,
+ )
+ self.Tokenize = channel.unary_unary(
+ '/fmaas.GenerationService/Tokenize',
+ request_serializer=generation__pb2.BatchedTokenizeRequest.SerializeToString,
+ response_deserializer=generation__pb2.BatchedTokenizeResponse.FromString,
+ )
+ self.ModelInfo = channel.unary_unary(
+ '/fmaas.GenerationService/ModelInfo',
+ request_serializer=generation__pb2.ModelInfoRequest.SerializeToString,
+ response_deserializer=generation__pb2.ModelInfoResponse.FromString,
+ )
+
+
+class GenerationServiceServicer(object):
+ """Missing associated documentation comment in .proto file."""
+
+ def Generate(self, request, context):
+ """Generates text given a text prompt, for one or more inputs
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def GenerateStream(self, request, context):
+ """Generates text given a single input prompt, streaming the response
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def Tokenize(self, request, context):
+ """Tokenize text
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def ModelInfo(self, request, context):
+ """Model info
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+
+def add_GenerationServiceServicer_to_server(servicer, server):
+ rpc_method_handlers = {
+ 'Generate': grpc.unary_unary_rpc_method_handler(
+ servicer.Generate,
+ request_deserializer=generation__pb2.BatchedGenerationRequest.FromString,
+ response_serializer=generation__pb2.BatchedGenerationResponse.SerializeToString,
+ ),
+ 'GenerateStream': grpc.unary_stream_rpc_method_handler(
+ servicer.GenerateStream,
+ request_deserializer=generation__pb2.SingleGenerationRequest.FromString,
+ response_serializer=generation__pb2.GenerationResponse.SerializeToString,
+ ),
+ 'Tokenize': grpc.unary_unary_rpc_method_handler(
+ servicer.Tokenize,
+ request_deserializer=generation__pb2.BatchedTokenizeRequest.FromString,
+ response_serializer=generation__pb2.BatchedTokenizeResponse.SerializeToString,
+ ),
+ 'ModelInfo': grpc.unary_unary_rpc_method_handler(
+ servicer.ModelInfo,
+ request_deserializer=generation__pb2.ModelInfoRequest.FromString,
+ response_serializer=generation__pb2.ModelInfoResponse.SerializeToString,
+ ),
+ }
+ generic_handler = grpc.method_handlers_generic_handler(
+ 'fmaas.GenerationService', rpc_method_handlers)
+ server.add_generic_rpc_handlers((generic_handler,))
+
+
+ # This class is part of an EXPERIMENTAL API.
+class GenerationService(object):
+ """Missing associated documentation comment in .proto file."""
+
+ @staticmethod
+ def Generate(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(request, target, '/fmaas.GenerationService/Generate',
+ generation__pb2.BatchedGenerationRequest.SerializeToString,
+ generation__pb2.BatchedGenerationResponse.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+ @staticmethod
+ def GenerateStream(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_stream(request, target, '/fmaas.GenerationService/GenerateStream',
+ generation__pb2.SingleGenerationRequest.SerializeToString,
+ generation__pb2.GenerationResponse.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+ @staticmethod
+ def Tokenize(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(request, target, '/fmaas.GenerationService/Tokenize',
+ generation__pb2.BatchedTokenizeRequest.SerializeToString,
+ generation__pb2.BatchedTokenizeResponse.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+ @staticmethod
+ def ModelInfo(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(request, target, '/fmaas.GenerationService/ModelInfo',
+ generation__pb2.ModelInfoRequest.SerializeToString,
+ generation__pb2.ModelInfoResponse.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
diff --git a/language/gpt-j/inference.py b/language/gpt-j/inference.py
new file mode 100644
index 000000000..a643ea0f2
--- /dev/null
+++ b/language/gpt-j/inference.py
@@ -0,0 +1,210 @@
+from typing import Optional, Union
+
+import grpc
+import generation_pb2_grpc
+import socket
+import ssl
+import sys
+
+
+def get_server_certificate(host: str, port: int) -> str:
+ """connect to host:port and get the certificate it presents
+
+ This is almost the same as `ssl.get_server_certificate`, but
+ when opening the TLS socket, `server_hostname` is also provided.
+
+ This retrieves the correct certificate for hosts using name-based
+ virtual hosting.
+ """
+ if sys.version_info >= (3, 10):
+ # ssl.get_server_certificate supports TLS SNI only above 3.10
+ # https://github.com/python/cpython/pull/16820
+ return ssl.get_server_certificate((host, port))
+
+ context = ssl.SSLContext()
+
+ with socket.create_connection((host, port)) as sock, context.wrap_socket(
+ sock, server_hostname=host
+ ) as ssock:
+ cert_der = ssock.getpeercert(binary_form=True)
+
+ assert cert_der
+ return ssl.DER_cert_to_PEM_cert(cert_der)
+
+
+class GrpcClient:
+ def __init__(
+ self,
+ host: str,
+ port: int,
+ *,
+ insecure: bool = False,
+ verify: Optional[bool] = None,
+ ca_cert: Union[None, bytes, str] = None,
+ client_cert: Union[None, bytes, str] = None,
+ client_key: Union[None, bytes, str] = None,
+ ) -> None:
+ self._channel = self._make_channel(
+ host,
+ port,
+ insecure=insecure,
+ verify=verify,
+ client_key=client_key,
+ client_cert=client_cert,
+ ca_cert=ca_cert,
+ )
+ self.generation_service_stub = generation_pb2_grpc.GenerationServiceStub(
+ self._channel
+ )
+
+ def make_request(self, texts: str, model_id: str = "flan-t5-small"):
+ request = generation_pb2_grpc.generation__pb2.BatchedGenerationRequest(
+ model_id=model_id,
+ requests=[generation_pb2_grpc.generation__pb2.GenerationRequest(text=text) for text in texts],
+ params=generation_pb2_grpc.generation__pb2.Parameters(
+ method=generation_pb2_grpc.generation__pb2.SAMPLE,
+ stopping=generation_pb2_grpc.generation__pb2.StoppingCriteria(
+ max_new_tokens=128,
+ min_new_tokens=30,
+ ),
+ sampling=generation_pb2_grpc.generation__pb2.SamplingParameters(
+ top_k=4
+ )
+ )
+ )
+ result = self.generation_service_stub.Generate(request=request)
+ return result
+
+ def make_request_stream(self, text: str, model_id: str = "flan-t5-small"):
+ request = generation_pb2_grpc.generation__pb2.SingleGenerationRequest(
+ model_id=model_id,
+ request=generation_pb2_grpc.generation__pb2.GenerationRequest(text=text),
+ params=generation_pb2_grpc.generation__pb2.Parameters(
+ method=generation_pb2_grpc.generation__pb2.SAMPLE,
+ stopping=generation_pb2_grpc.generation__pb2.StoppingCriteria(
+ max_new_tokens=128,
+ min_new_tokens=30,
+ ),
+ sampling=generation_pb2_grpc.generation__pb2.SamplingParameters(
+ top_k=4
+ ),
+ response=generation_pb2_grpc.generation__pb2.ResponseOptions(
+ generated_tokens=True
+ )
+ )
+ )
+ result = self.generation_service_stub.GenerateStream(request=request)
+ return result
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *exc_info):
+ self._close()
+ return False
+
+ def _close(self):
+ try:
+ if hasattr(self, "_channel") and self._channel:
+ self._channel.close()
+ except Exception as exc:
+ print(f"Unexpected exception while closing client: {exc}")
+
+ def __del__(self):
+ self._close()
+
+ def _make_channel(
+ self,
+ host: str,
+ port: int,
+ *,
+ insecure: bool = False,
+ verify: Optional[bool] = None,
+ ca_cert: Union[None, bytes, str] = None,
+ client_key: Union[None, bytes, str] = None,
+ client_cert: Union[None, bytes, str] = None,
+ ) -> grpc.Channel:
+ """Creates a grpc channel
+
+ Args:
+ - host: str
+ - port: str
+ - (optional) insecure: use a plaintext connection (default=False)
+ - (optional) verify: set to False to disable remote host certificate(s)
+ verification. Cannot be used with `plaintext` or with MTLS
+ - (optional) ca_cert: certificate authority to use
+ - (optional) client_key: client key for mTLS mode
+ - (optional) client_cert: client cert for mTLS mode
+
+ """
+ if not host.strip():
+ raise ValueError("A non empty host name is required")
+ if int(port) <= 0:
+ raise ValueError("A non zero port number is required")
+ if insecure and any(
+ (val is not None) for val in (ca_cert, client_key, client_cert)
+ ):
+ raise ValueError("cannot use insecure with TLS/mTLS certificates")
+ if insecure and verify:
+ raise ValueError("insecure cannot be used with verify")
+
+ client_key_bytes = self._try_load_certificate(client_key)
+ client_cert_bytes = self._try_load_certificate(client_cert)
+ ca_cert_bytes = self._try_load_certificate(ca_cert)
+
+ connection = f"{host}:{port}"
+ if insecure:
+ print("Connecting over an insecure plaintext grpc channel")
+ return grpc.insecure_channel(connection)
+
+ credentials_kwargs: dict[str, bytes] = {}
+ if ca_cert_bytes and not (any((client_cert_bytes, client_key_bytes))):
+ print("Connecting using provided CA certificate for secure channel")
+ credentials_kwargs.update(root_certificates=ca_cert_bytes)
+ elif client_cert_bytes and client_key_bytes and ca_cert_bytes:
+ print("Connecting using mTLS for secure channel")
+ credentials_kwargs.update(
+ root_certificates=ca_cert_bytes,
+ private_key=client_key_bytes,
+ certificate_chain=client_cert_bytes,
+ )
+ elif verify is False:
+ cert = get_server_certificate(host, port).encode()
+ credentials_kwargs.update(root_certificates=cert)
+
+ return grpc.secure_channel(
+ connection, grpc.ssl_channel_credentials(**credentials_kwargs)
+ )
+
+ @staticmethod
+ def _try_load_certificate(certificate: Union[None, bytes, str]) -> Optional[bytes]:
+ """If the certificate points to a file, return the contents (plaintext reads).
+ Else return the bytes"""
+ if not certificate:
+ return None
+
+ if isinstance(certificate, bytes):
+ return certificate
+
+ if isinstance(certificate, str):
+ with open(certificate, "rb") as secret_file:
+ return secret_file.read()
+ raise ValueError(
+ f"{certificate=} should be a path to a certificate files or bytes"
+ )
+
+
+if __name__ == "__main__":
+ # host = "flan-t5-small-predictor-caikit-testing.apps.aisrhods-wx.8goc.p1.openshiftapps.com"
+ # port = 443
+ host = "localhost"
+ port = 8033
+ print(f"connecting to {host=}:{port}")
+ client = GrpcClient(
+ host,
+ port,
+ insecure=True,
+ # verify=False,
+ )
+
+ client.make_request("this is the query text", model_id="flan-t5-small")
diff --git a/language/gpt-j/main.py b/language/gpt-j/main.py
index 367af5212..b8e99963f 100644
--- a/language/gpt-j/main.py
+++ b/language/gpt-j/main.py
@@ -2,91 +2,105 @@
import mlperf_loadgen as lg
import argparse
import os
-
+import logging
import sys
-from backend import get_SUT
+from SUT import SUT, SUTServer
+
sys.path.insert(0, os.getcwd())
+logging.basicConfig(level=logging.INFO)
+log = logging.getLogger("GPT-J-MAIN")
def get_args():
parser = argparse.ArgumentParser()
- parser.add_argument(
- "--backend", choices=["pytorch"], default="pytorch", help="Backend")
- parser.add_argument("--scenario", choices=["SingleStream", "Offline",
- "Server"], default="Offline", help="Scenario")
- parser.add_argument("--model-path", default="EleutherAI/gpt-j-6B", help="")
- parser.add_argument(
- "--dataset-path", default="./data/cnn_eval.json", help="")
- parser.add_argument("--accuracy", action="store_true",
- help="enable accuracy pass")
- parser.add_argument("--dtype", default="float32", help="data type of the model, choose from float16, bfloat16 and float32")
- parser.add_argument("--quantized", action="store_true",
- help="use quantized model (only valid for onnxruntime backend)")
- parser.add_argument("--profile", action="store_true",
- help="enable profiling (only valid for onnxruntime backend)")
- parser.add_argument("--gpu", action="store_true",
- help="use GPU instead of CPU for the inference")
- parser.add_argument("--audit_conf", default="audit.conf",
- help="audit config for LoadGen settings during compliance runs")
- parser.add_argument(
- "--mlperf_conf", default="mlperf.conf", help="mlperf rules config")
- parser.add_argument("--user_conf", default="user.conf",
- help="user config for user LoadGen settings such as target QPS")
- parser.add_argument("--max_examples", type=int, default=13368,
- help="Maximum number of examples to consider (not limited by default)")
+ parser.add_argument("--scenario", type=str, choices=["Offline", "Server"], default="Offline", help="Scenario")
+ parser.add_argument("--model-path", type=str, default="EleutherAI/gpt-j-6B", help="Model name")
+ parser.add_argument("--dataset-path", type=str, default=None, help="")
+ parser.add_argument("--api-server", type=str, default=None, help="Specify an api endpoint call to use api mode")
+ parser.add_argument("--api-model-name", type=str, default=None, help="Specify a model name to use api mode")
+ parser.add_argument("--additional-servers", nargs='+', default=[], help="Specify additional endpoints for load splitting")
+ parser.add_argument("--accuracy", action="store_true", help="Run accuracy mode")
+ parser.add_argument("--grpc", action="store_true", help="Enable grpc for api endpoint")
+ parser.add_argument("--batch-grpc", action="store_true", help="Enable batch requests for grpc")
+ parser.add_argument("--vllm", action="store_true", help="Switch runtime to vllm for api endpoint")
+ parser.add_argument("--dtype", type=str, default="float32", help="data type of the model, choose from float16, bfloat16 and float32")
+ parser.add_argument("--device", type=str, choices=["cpu", "cuda:0"], default="cpu", help="device to use")
+ parser.add_argument("--audit-conf", type=str, default="audit.conf", help="audit config for LoadGen settings during compliance runs")
+ parser.add_argument("--mlperf-conf", type=str, default="mlperf.conf", help="mlperf rules config")
+ parser.add_argument("--user-conf", type=str, default="user.conf", help="user config for user LoadGen settings such as target QPS")
+ parser.add_argument("--total-sample-count", type=int, default=13368, help="Number of samples to use in benchmark.") # TODO: This interpretation of 'total-sample-count' is a little misleading. Fix it
+ parser.add_argument("--output-log-dir", type=str, default="output-logs", help="Where logs are saved")
+ parser.add_argument("--enable-log-trace", action="store_true", help="Enable log tracing. This file can become quite large")
+ parser.add_argument("--num-workers", type=int, default=1, help="Number of workers to process queries")
+
args = parser.parse_args()
return args
scenario_map = {
- "SingleStream": lg.TestScenario.SingleStream,
- "Offline": lg.TestScenario.Offline,
- "Server": lg.TestScenario.Server,
- "MultiStream": lg.TestScenario.MultiStream
-}
+ "offline": lg.TestScenario.Offline,
+ "server": lg.TestScenario.Server,
+ }
+sut_map = {
+ "offline": SUT,
+ "server": SUTServer
+ }
def main():
args = get_args()
- sut = get_SUT(
- model_path=args.model_path,
- scenario=args.scenario,
- dtype=args.dtype,
- dataset_path=args.dataset_path,
- max_examples=args.max_examples,
- use_gpu=args.gpu,
- )
-
settings = lg.TestSettings()
- settings.scenario = scenario_map[args.scenario]
+ settings.scenario = scenario_map[args.scenario.lower()]
# Need to update the conf
settings.FromConfig(args.mlperf_conf, "gptj", args.scenario)
settings.FromConfig(args.user_conf, "gptj", args.scenario)
if args.accuracy:
settings.mode = lg.TestMode.AccuracyOnly
+ log.warning("Accuracy run will generate the accuracy logs, but the evaluation of the log is not completed yet")
else:
settings.mode = lg.TestMode.PerformanceOnly
- log_path = os.environ.get("LOG_PATH")
- if not log_path:
- log_path = "build/logs"
- if not os.path.exists(log_path):
- os.makedirs(log_path)
+
+ os.makedirs(args.output_log_dir, exist_ok=True)
log_output_settings = lg.LogOutputSettings()
- log_output_settings.outdir = log_path
+ log_output_settings.outdir = args.output_log_dir
log_output_settings.copy_summary_to_stdout = True
log_settings = lg.LogSettings()
log_settings.log_output = log_output_settings
- log_settings.enable_trace = True
+ log_settings.enable_trace = args.enable_log_trace
+
+ sut_cls = sut_map[args.scenario.lower()]
+
+ sut = sut_cls(
+ model_path=args.model_path,
+ api_server=args.api_server,
+ api_model_name=args.api_model_name,
+ additional_servers=args.additional_servers,
+ grpc=args.grpc,
+ batch_grpc=args.batch_grpc,
+ vllm=args.vllm,
+ dtype=args.dtype,
+ dataset_path=args.dataset_path,
+ total_sample_count=args.total_sample_count,
+ device=args.device,
+ )
+
+ # Start sut before loadgen starts
+ sut.start()
+ lgSUT = lg.ConstructSUT(sut.issue_queries, sut.flush_queries)
+ log.info("Starting Benchmark run")
+ lg.StartTestWithLogSettings(lgSUT, sut.qsl, settings, log_settings, args.audit_conf)
+
+ # Stop sut after completion
+ sut.stop()
- lg.StartTestWithLogSettings(sut.sut, sut.qsl, settings, log_settings, args.audit_conf)
- print("Test Done!")
+ log.info("Run Completed!")
- print("Destroying SUT...")
- lg.DestroySUT(sut.sut)
+ log.info("Destroying SUT...")
+ lg.DestroySUT(lgSUT)
- print("Destroying QSL...")
+ log.info("Destroying QSL...")
lg.DestroyQSL(sut.qsl)
diff --git a/language/gpt-j/backend.py b/language/gpt-j/old_backend.py
similarity index 100%
rename from language/gpt-j/backend.py
rename to language/gpt-j/old_backend.py
diff --git a/language/gpt-j/old_main.py b/language/gpt-j/old_main.py
new file mode 100644
index 000000000..367af5212
--- /dev/null
+++ b/language/gpt-j/old_main.py
@@ -0,0 +1,94 @@
+import subprocess
+import mlperf_loadgen as lg
+import argparse
+import os
+
+import sys
+from backend import get_SUT
+sys.path.insert(0, os.getcwd())
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--backend", choices=["pytorch"], default="pytorch", help="Backend")
+ parser.add_argument("--scenario", choices=["SingleStream", "Offline",
+ "Server"], default="Offline", help="Scenario")
+ parser.add_argument("--model-path", default="EleutherAI/gpt-j-6B", help="")
+ parser.add_argument(
+ "--dataset-path", default="./data/cnn_eval.json", help="")
+ parser.add_argument("--accuracy", action="store_true",
+ help="enable accuracy pass")
+ parser.add_argument("--dtype", default="float32", help="data type of the model, choose from float16, bfloat16 and float32")
+ parser.add_argument("--quantized", action="store_true",
+ help="use quantized model (only valid for onnxruntime backend)")
+ parser.add_argument("--profile", action="store_true",
+ help="enable profiling (only valid for onnxruntime backend)")
+ parser.add_argument("--gpu", action="store_true",
+ help="use GPU instead of CPU for the inference")
+ parser.add_argument("--audit_conf", default="audit.conf",
+ help="audit config for LoadGen settings during compliance runs")
+ parser.add_argument(
+ "--mlperf_conf", default="mlperf.conf", help="mlperf rules config")
+ parser.add_argument("--user_conf", default="user.conf",
+ help="user config for user LoadGen settings such as target QPS")
+ parser.add_argument("--max_examples", type=int, default=13368,
+ help="Maximum number of examples to consider (not limited by default)")
+ args = parser.parse_args()
+ return args
+
+
+scenario_map = {
+ "SingleStream": lg.TestScenario.SingleStream,
+ "Offline": lg.TestScenario.Offline,
+ "Server": lg.TestScenario.Server,
+ "MultiStream": lg.TestScenario.MultiStream
+}
+
+
+def main():
+ args = get_args()
+
+ sut = get_SUT(
+ model_path=args.model_path,
+ scenario=args.scenario,
+ dtype=args.dtype,
+ dataset_path=args.dataset_path,
+ max_examples=args.max_examples,
+ use_gpu=args.gpu,
+ )
+
+ settings = lg.TestSettings()
+ settings.scenario = scenario_map[args.scenario]
+ # Need to update the conf
+ settings.FromConfig(args.mlperf_conf, "gptj", args.scenario)
+ settings.FromConfig(args.user_conf, "gptj", args.scenario)
+
+ if args.accuracy:
+ settings.mode = lg.TestMode.AccuracyOnly
+ else:
+ settings.mode = lg.TestMode.PerformanceOnly
+ log_path = os.environ.get("LOG_PATH")
+ if not log_path:
+ log_path = "build/logs"
+ if not os.path.exists(log_path):
+ os.makedirs(log_path)
+ log_output_settings = lg.LogOutputSettings()
+ log_output_settings.outdir = log_path
+ log_output_settings.copy_summary_to_stdout = True
+ log_settings = lg.LogSettings()
+ log_settings.log_output = log_output_settings
+ log_settings.enable_trace = True
+
+ lg.StartTestWithLogSettings(sut.sut, sut.qsl, settings, log_settings, args.audit_conf)
+ print("Test Done!")
+
+ print("Destroying SUT...")
+ lg.DestroySUT(sut.sut)
+
+ print("Destroying QSL...")
+ lg.DestroyQSL(sut.qsl)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/language/gpt-j/user.conf b/language/gpt-j/user.conf
index 07a10bbe2..ded93b822 100644
--- a/language/gpt-j/user.conf
+++ b/language/gpt-j/user.conf
@@ -2,3 +2,5 @@
# The key has the format 'model.scenario.key'. Value is mostly int64_t.
# Model maybe '*' as wildcard. In that case the value applies to all models.
# All times are in milli seconds
+*.Server.target_qps = 132
+*.Offline.min_query_count = 93576
diff --git a/language/llama2-70b/README.md b/language/llama2-70b/README.md
index a9a150b4c..cd3af3184 100644
--- a/language/llama2-70b/README.md
+++ b/language/llama2-70b/README.md
@@ -1,10 +1,18 @@
+# Red Hat OpenShift Model Serving Stack Implementation
+
+Please see the `README.md` located in the `api-endpoint-artifacts` dir for instructions on running the model serving stack implementation.
+
+Below is the original guide for the reference implementation, with which backwards compatibility is retained:
+
# Reference Implementation for llama2-70b
**Basic implementation for llama2-70b. Few noteworthy items:**
+ Processing of Validation dataset is not finalized yet. Decision on input token lengths is pending
+ Streamer for communicating with loadgen has quite some overhead. This is only meant to provide functional implementation
-
++ For custom/optimized implementations of this benchmark, it is important to include the following:
+  - For the Server scenario, it is necessary to call `lg.FirstTokenComplete(response)` for each query. This way the first token will be reported and its latency will be measured.
+  - For all scenarios, when calling `lg.QuerySamplesComplete(response)`, it is necessary that each of the elements in `response` is a `lg.QuerySampleResponse` that contains the number of tokens (it can be created this way: `lg.QuerySampleResponse(qitem.id, bi[0], bi[1], n_tokens)`). The number of tokens reported should match the number of tokens in your answer, and this will be checked in [TEST06](../../compliance/nvidia/TEST06/); a minimal sketch is shown below.
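+
+A minimal sketch (not part of the reference code) of how a completed answer could be reported; `output_tokens` and `qitem` stand in for one sample's generated token ids and its loadgen query item:
+```
+import array
+import numpy as np
+import mlperf_loadgen as lg
+
+token_array = np.array(output_tokens, np.int32)
+
+# Server scenario: report the first token as soon as it is available
+first_token = array.array("B", token_array[:1].tobytes())
+fbi = first_token.buffer_info()
+lg.FirstTokenComplete([lg.QuerySampleResponse(qitem.id, fbi[0], fbi[1])])
+
+# All scenarios: report the full answer; the fourth argument is the token count checked by TEST06
+response_array = array.array("B", token_array.tobytes())
+bi = response_array.buffer_info()
+lg.QuerySamplesComplete([lg.QuerySampleResponse(qitem.id, bi[0], bi[1], len(output_tokens))])
+```
+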
## Prepare environment
@@ -62,7 +70,12 @@ CPU-only setup, as well as any GPU versions for applicable libraries like PyTorc
## Get Model
-+ For now, MLCommons is not hosting the checkpoint, so you must first go to [llama2-request-link](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and make a request, sign in to huggingface (if you don't have account, you'd need to create one). **Please note your authentication credentials** as you may be required to provide them when cloninng below
+### MLCommons Members Download
+MLCommons hosts the model and preprocessed dataset for download exclusively by MLCommons Members. You must first agree to the [confidentiality notice](https://docs.google.com/forms/d/e/1FAIpQLSc_8VIvRmXM3I8KQaYnKf7gy27Z63BBoI_I1u02f4lw6rBp3g/viewform), then follow the link to a directory containing Rclone download instructions.
+
+
+### External Download
++ First go to [llama2-request-link](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and make a request, sign in to HuggingFace (if you don't have an account, you'll need to create one). **Please note your authentication credentials** as you may be required to provide them when cloning below.
+ Requires Git Large Files Storage
```
export CHECKPOINT_PATH=${PWD}/Llama-2-70b-chat-hf
@@ -73,6 +86,29 @@ git clone https://huggingface.co/meta-llama/Llama-2-70b-chat-hf ${CHECKPOINT_PAT
## Get Dataset
+### Preprocessed
+
+You can use Rclone to download the preprocessed dataset from a Cloudflare R2 bucket.
+
+To run Rclone on Windows, you can download the executable [here](https://rclone.org/install/#windows).
+To install Rclone on Linux/macOS/BSD systems, run:
+```
+sudo -v ; curl https://rclone.org/install.sh | sudo bash
+```
+Once Rclone is installed, run the following command to authenticate with the bucket:
+```
+rclone config create mlc-inference s3 provider=Cloudflare access_key_id=f65ba5eef400db161ea49967de89f47b secret_access_key=fbea333914c292b854f14d3fe232bad6c5407bf0ab1bebf78833c2b359bdfd2b endpoint=https://c2686074cb2caf5cbaf6d134bdba8b47.r2.cloudflarestorage.com
+```
+You can then navigate in the terminal to your desired download directory and run the following command to download the dataset:
+
+```
+rclone copy mlc-inference:mlcommons-inference-wg-public/open_orca ./open_orca -P
+```
+
+### Unprocessed
+
+You can also download and process the dataset yourself as follows:
+
```
# First get the `open-orca` parquet from huggingface
export OPENORCA_DATASET=${PWD}/open-orca
@@ -195,15 +231,15 @@ if [ -e ${ACCURACY_LOG_FILE} ]; then
fi
```
-The ServerSUT was not tested for GPU runs. You can try setting `--device cuda:0`, but YMMV.
+The ServerSUT was not tested for GPU runs.
## Accuracy Target
Running the GPU implementation in FP32 precision resulted in the following FP32 accuracy targets (normalized to a 0-100
scale from a 0.0-1.0 scale):
-- Rouge1: 43.88
-- Rouge2: 21.7108
-- RougeL: 28.2502
-- RougeLsum: 41.4821
+- Rouge1: 44.4312
+- Rouge2: 22.0352
+- RougeL: 28.6162
+- Tokens per sample: 294.45
-This was run an 8xH100 node. Total runtime was ~4.5 days.
+This was run on a DGX-H100 node. Total runtime was ~4.5 days.
diff --git a/language/llama2-70b/SUT.py b/language/llama2-70b/SUT.py
index 30fb21f8f..ff263e961 100644
--- a/language/llama2-70b/SUT.py
+++ b/language/llama2-70b/SUT.py
@@ -1,5 +1,7 @@
import os
+import sys
import time
+import re
import numpy as np
import array
import torch
@@ -14,6 +16,17 @@
import tqdm
import queue
+from concurrent.futures.thread import ThreadPoolExecutor
+import requests
+from urllib3.exceptions import InsecureRequestWarning
+import json
+
+from inference import GrpcClient
+import more_itertools as mit
+from itertools import repeat
+
+requests.packages.urllib3.disable_warnings(category=InsecureRequestWarning)
+
import logging
from typing import TYPE_CHECKING, Optional, List
from pathlib import Path
@@ -80,6 +93,12 @@ def get_out_tokens(self):
class SUT():
def __init__(self,
model_path=None,
+ api_server=None,
+ api_model_name=None,
+ additional_servers=[],
+ grpc=False,
+ batch_grpc=False,
+ vllm=False,
dtype="bfloat16",
device="cpu",
batch_size=None,
@@ -89,11 +108,24 @@ def __init__(self,
workers=1):
self.model_path = model_path or "meta-llama/Llama-2-70b-chat-hf"
+ self.api_servers = []
+ if api_server:
+ self.api_servers.append(api_server)
+ if additional_servers and not api_server:
+ sys.exit("Additional servers cannot be used without primary api server")
+ for server in additional_servers:
+ self.api_servers.append(server)
+ self.api_model_name = api_model_name
+ self.grpc = grpc
+ self.batch_grpc = batch_grpc
+ self.vllm = vllm
+ if self.vllm and (self.grpc or self.batch_grpc):
+ sys.exit("vllm does not support grpc")
self.device = device
if not batch_size:
if device == "cpu":
- batch_size = 1
+ batch_size = 2000
else:
batch_size = 32 # Reduce to 8 if using 4 GPUs, 16 for 8.
self.batch_size = batch_size
@@ -146,6 +178,78 @@ def stop(self):
worker.join()
+ def query_api(self, input, idx):
+ headers = {
+ 'Content-Type': 'application/json',
+ }
+
+ json_data = {
+ 'model_id': self.api_model_name,
+ 'inputs': input,
+ 'parameters': {
+ 'max_new_tokens': 1024,
+ 'min_new_tokens': 1,
+ 'decoding_method': "GREEDY"
+ },
+ }
+
+ response_code = 0
+ while response_code != 200:
+ try:
+ response = requests.post(
+ self.api_servers[idx],
+ headers=headers,
+ json=json_data,
+ verify=False,
+ )
+ response_code = response.status_code
+ except:
+ print("connection failure")
+ return json.loads(response.text)["generated_text"]
+
+ def query_api_vllm(self, inputs, idx):
+ headers = {
+ 'Content-Type': 'application/json',
+ }
+
+ json_data = {
+ 'model': '/mnt/models/',
+ 'prompt': inputs,
+ 'max_tokens': 1024,
+ 'temperature': 0,
+ }
+
+ response_code = 0
+ while response_code != 200:
+ try:
+ response = requests.post(f'{self.api_servers[idx]}/v1/completions', headers=headers, json=json_data, verify=False)
+ response_code = response.status_code
+ except:
+ print("connection failure")
+ return [resp["text"] for resp in json.loads(response.text)["choices"]]
+
+ def query_api_grpc(self, input, idx):
+ resp = self.grpc_clients[idx].make_request([input], model_id=self.api_model_name)
+ return resp.responses[0].text
+
+ def query_api_batch_grpc(self, inputs, idx):
+ resps = self.grpc_clients[idx].make_request(inputs, model_id=self.api_model_name)
+ return [resp.text for resp in resps.responses]
+
+ def api_action_handler(self, chunk, server_idx):
+ if self.grpc:
+ if self.batch_grpc:
+ output = self.query_api_batch_grpc(chunk, server_idx)
+ else:
+ with ThreadPoolExecutor(max_workers=len(chunk)) as executor:
+ output = list(executor.map(self.query_api_grpc,chunk, repeat(server_idx)))
+ elif self.vllm:
+ output = self.query_api_vllm(chunk, server_idx)
+ else:
+ with ThreadPoolExecutor(max_workers=len(chunk)) as executor:
+ output = list(executor.map(self.query_api,chunk, repeat(server_idx)))
+ return output
+
def process_queries(self):
"""Processor of the queued queries. User may choose to add batching logic """
@@ -191,25 +295,43 @@ def process_queries(self):
assert input_ids_tensor.shape == input_masks_tensor.shape
assert input_ids_tensor.shape[0] <= self.batch_size
+ if self.api_servers:
+ decoded = self.tokenizer.batch_decode(input_ids_tensor)
+                cleaned = [entry.replace('<s>','').replace('</s>','') for entry in decoded]  # strip the Llama BOS/EOS special tokens left by batch_decode
+ cleaned_chunks = [list(c) for c in mit.divide(len(self.api_servers), cleaned)]
+
tik2 = time.time()
- pred_output_tokens = self.model.generate(
- input_ids=input_ids_tensor,
- attention_mask=input_masks_tensor,
- pad_token_id=self.tokenizer.pad_token_id,
- **gen_kwargs
- )
+ if self.api_servers:
+ with ThreadPoolExecutor(max_workers=len(self.api_servers)) as executor:
+ #needs to be tested
+ output_chunks = list(executor.map(self.api_action_handler,cleaned_chunks,range(len(self.api_servers))))
+ output = []
+ for row in output_chunks:
+ output += row
+ else:
+ pred_output_tokens = self.model.generate(
+ input_ids=input_ids_tensor,
+ attention_mask=input_masks_tensor,
+ pad_token_id=self.tokenizer.pad_token_id,
+ **gen_kwargs
+ )
tik3 = time.time()
- processed_output = self.data_object.postProcess(pred_output_tokens,
- input_seq_lens=input_len,
- query_id_list=query_ids)
+ if self.api_servers:
+ processed_output = np.array(self.tokenizer(output, padding='longest')['input_ids'])
+ else:
+ processed_output = self.data_object.postProcess(pred_output_tokens,
+ input_seq_lens=input_len,
+ query_id_list=query_ids)
for i in range(len(qitem)):
- response_array = array.array("B", processed_output[i].tobytes())
+ unpadded = np.delete(processed_output[i], np.where(processed_output[i] == 2))
+ n_tokens = unpadded.shape[0]
+ response_array = array.array("B", unpadded.tobytes())
bi = response_array.buffer_info()
- response = [lg.QuerySampleResponse(qitem[i].id, bi[0], bi[1])]
+ response = [lg.QuerySampleResponse(qitem[i].id, bi[0], bi[1], n_tokens)]
lg.QuerySamplesComplete(response)
tok = time.time()
@@ -227,26 +349,45 @@ def process_queries(self):
def load_model(self):
- self.model = LlamaForCausalLM.from_pretrained(
- self.model_path,
- device_map="auto",
- low_cpu_mem_usage=True,
- torch_dtype=self.amp_dtype
- )
- print("Loaded model")
+ if self.api_servers:
+ if not self.api_model_name:
+ sys.exit("API Server was specified but no model name was provided")
+ self.grpc_clients = []
+ for server in self.api_servers:
+ if self.grpc:
+ hostname = re.sub("https://|http://", "", server)
+ if hostname[-1] == "/":
+ hostname = hostname[:-1]
+ grpc_client = GrpcClient(
+ hostname,
+ 443,
+ verify=False,
+ )
+ self.grpc_clients.append(grpc_client)
+ elif not "http" in server:
+ server = "http://" + server
+
+ else:
+ self.model = LlamaForCausalLM.from_pretrained(
+ self.model_path,
+ device_map="auto",
+ low_cpu_mem_usage=True,
+ torch_dtype=self.amp_dtype
+ )
+ print("Loaded model")
- self.device = torch.device(self.device)
- if self.device == "cpu":
- self.model = self.model.to(self.device) # Force CPU if your system has GPU and you specifically want CPU-only run
+ self.device = torch.device(self.device)
+ if self.device == "cpu":
+ self.model = self.model.to(self.device) # Force CPU if your system has GPU and you specifically want CPU-only run
- self.model.eval()
- self.model = self.model.to(memory_format=torch.channels_last)
+ self.model.eval()
+ self.model = self.model.to(memory_format=torch.channels_last)
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_path,
model_max_length=1024,
padding_side="left",
- use_fast=False,)
+ use_fast=True,) #changed from false
self.tokenizer.pad_token = self.tokenizer.eos_token
print("Loaded tokenizer")
@@ -284,11 +425,16 @@ def __del__(self):
class SUTServer(SUT):
- def __init__(self, model_path=None, dtype="bfloat16", device="cpu", total_sample_count=24576, dataset_path=None, workers=1):
+ def __init__(self, model_path=None, api_server=None, additional_servers=[], api_model_name=None, grpc=False, batch_grpc=False, vllm=False, dtype="bfloat16", device="cpu", total_sample_count=24576, dataset_path=None, workers=1):
- super().__init__(model_path=model_path, dtype=dtype, device=device, total_sample_count=total_sample_count, dataset_path=dataset_path, workers=workers)
+ super().__init__(model_path=model_path, api_server=api_server, additional_servers=additional_servers, api_model_name=api_model_name, grpc=grpc, vllm=vllm, dtype=dtype, device=device, total_sample_count=total_sample_count, dataset_path=dataset_path, workers=workers)
+
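+        # Load the raw tokenizer vocab so token text streamed back from the API
+        # endpoints can be mapped to token ids when building loadgen responses.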
+ with open(f"{self.model_path}/tokenizer.json", 'r') as token_file:
+ llama_tokenizer = json.load(token_file)
+ self.llama_vocab = llama_tokenizer["model"]["vocab"]
self.first_token_queue = queue.Queue()
+
def start(self):
@@ -314,13 +460,132 @@ def process_first_tokens(self):
first_tokens, response_id = first_token_item
- response_data = array.array("B", np.array(first_tokens, np.float32).tobytes())
+ response_data = array.array("B", np.array(first_tokens, np.int32).tobytes())
bi = response_data.buffer_info()
response = [lg.QuerySampleResponse(response_id, bi[0], bi[1])]
lg.FirstTokenComplete(response)
+ def stream_api(self, input, response_ids, idx):
+ headers = {
+ 'Content-Type': 'application/json',
+ }
+
+ json_data = {
+ 'model_id': 'Llama-2-70b-chat-hf-caikit',
+ 'inputs': input,
+ 'parameters': {
+ 'max_new_tokens': 1024,
+ 'min_new_tokens': 1,
+ 'decoding_method': "GREEDY"
+ },
+ }
+
+ token_cache = []
+ s = requests.Session()
+ first = True
+ with s.post(
+ self.api_servers[idx],
+ headers=headers,
+ json=json_data,
+ verify=False,
+ stream=True
+ ) as resp:
+ for line in resp.iter_lines():
+ if line:
+ decoded = line.decode()
+ if decoded.startswith("data"):
+ token_l = json.loads(decoded[6:])["tokens"]
+ if token_l:
+ token = self.llama_vocab[token_l[0]["text"]]
+ if first:
+ self.first_token_queue.put((token, response_ids[0]))
+ first = False
+ token_cache.append(token)
+ return token_cache
+
+ def stream_api_grpc(self, input, response_ids, idx):
+ token_cache = []
+ first = True
+ resps = self.grpc_clients[idx].make_request_stream(input, model_id=self.api_model_name)
+ for resp in resps:
+ if resp.tokens:
+ token = self.llama_vocab[resp.tokens[0].text]
+ if first:
+ self.first_token_queue.put((token, response_ids[0]))
+ first = False
+ token_cache.append(token)
+ return token_cache
+
+ def stream_api_vllm(self, input, response_ids, idx):
+ headers = {
+ 'Content-Type': 'application/json',
+ }
+
+ json_data = {
+ 'model': '/mnt/models/',
+ 'prompt': input,
+ 'max_tokens': 1024,
+ 'temperature': 0,
+ 'stream': True,
+ 'logprobs': 1
+ }
+
+ while True:
+ try:
+ token_cache = []
+ s = requests.Session()
+ first = True
+ with s.post(
+ f'{self.api_servers[idx]}/v1/completions',
+ headers=headers,
+ json=json_data,
+ verify=False,
+ stream=True
+ ) as resp:
+ for line in resp.iter_lines():
+ if line:
+ decoded = line.decode()
+ if decoded.startswith("data") and "[DONE]" not in decoded:
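+                                # 'logprobs': 1 makes each streamed chunk expose the new token's
+                                # text under top_logprobs; map it through the Llama vocab to get
+                                # the token id for the loadgen response.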
+ inter = json.loads(decoded[6:])["choices"][0]["logprobs"]
+ if "top_logprobs" in inter:
+ token_s = list(inter["top_logprobs"][0].keys())[0]
+ token = self.llama_vocab[token_s]
+ if first:
+ self.first_token_queue.put((token, response_ids[0]))
+ first = False
+ token_cache.append(token)
+ s.close()
+ if token_cache:
+ return token_cache
+ except:
+ s.close()
+ print("Connection failure")
+
+ def async_process_query(self, input_ids_tensor, qitem_id, idx):
+ decoded = self.tokenizer.decode(input_ids_tensor[0])
+ response_ids = [qitem_id]
+ if self.grpc:
+ output_tokens = self.stream_api_grpc(decoded, response_ids, idx)
+ elif self.vllm:
+ output_tokens = self.stream_api_vllm(decoded, response_ids, idx)
+ else:
+ output_tokens = self.stream_api(decoded, response_ids, idx)
+
+ n_tokens = len(output_tokens)
+ if n_tokens <= 1:
+ print("WARNING: caught low token count")
+ print(input_ids_tensor)
+ print(output_tokens)
+ response_array = array.array("B", np.array(output_tokens, np.int32).tobytes())
+ bi = response_array.buffer_info()
+ response = [lg.QuerySampleResponse(
+ qitem_id, bi[0], bi[1], n_tokens)]
+ lg.QuerySamplesComplete(response)
+ sys.exit()
+
def process_queries(self):
"""Processor of the queued queries. User may choose to add batching logic """
+ server_idx = 0
while True:
qitem = self.query_queue.get()
@@ -330,24 +595,29 @@ def process_queries(self):
input_ids_tensor = self.data_object.input_ids[qitem.index]
input_masks_tensor = self.data_object.attention_masks[qitem.index]
- #TODO: This PoC is super slow with significant overhead. Best to create a patch to `generate`
- tokens_cache = []
- tokens_streamer = FirstTokenStreamer(self.first_token_queue, tokens_cache=tokens_cache, is_first_token=True, response_ids=[qitem.id])
+ if self.api_servers:
+ threading.Thread(target=self.async_process_query, args=(input_ids_tensor, qitem.id, server_idx)).start()
+ server_idx = (server_idx + 1) % len(self.api_servers)
+ else:
+ #TODO: This PoC is super slow with significant overhead. Best to create a patch to `generate`
+ tokens_cache = []
+ tokens_streamer = FirstTokenStreamer(self.first_token_queue, tokens_cache=tokens_cache, is_first_token=True, response_ids=[qitem.id])
- _ = self.model.generate( input_ids=input_ids_tensor,
- attention_mask=input_masks_tensor,
- pad_token_id=self.tokenizer.pad_token_id,
- streamer = tokens_streamer,
- **gen_kwargs
- )
+ _ = self.model.generate( input_ids=input_ids_tensor,
+ attention_mask=input_masks_tensor,
+ pad_token_id=self.tokenizer.pad_token_id,
+ streamer = tokens_streamer,
+ **gen_kwargs
+ )
- output_tokens = tokens_streamer.get_out_tokens()
+ output_tokens = tokens_streamer.get_out_tokens()
- response_array = array.array("B", np.array(output_tokens, np.int32).tobytes())
- bi = response_array.buffer_info()
- response = [lg.QuerySampleResponse(
- qitem.id, bi[0], bi[1])]
- lg.QuerySamplesComplete(response)
+ n_tokens = len(output_tokens)
+ response_array = array.array("B", np.array(output_tokens, np.int32).tobytes())
+ bi = response_array.buffer_info()
+ response = [lg.QuerySampleResponse(
+ qitem.id, bi[0], bi[1], n_tokens)]
+ lg.QuerySamplesComplete(response)
def issue_queries(self, query_samples):
diff --git a/language/llama2-70b/api-endpoint-artifacts/Dockerfile-API b/language/llama2-70b/api-endpoint-artifacts/Dockerfile-API
new file mode 100644
index 000000000..77385d42e
--- /dev/null
+++ b/language/llama2-70b/api-endpoint-artifacts/Dockerfile-API
@@ -0,0 +1,16 @@
+FROM pytorch/pytorch:latest
+
+COPY inference ./inference
+COPY processed-openorca/processed-data.pkl ./processed-data.pkl
+COPY llama-model-info ./llama-model-info
+
+RUN apt-get update && apt install build-essential -y
+RUN conda install pybind11==2.10.4 -c conda-forge -y
+RUN pip install --upgrade pip
+RUN pip install transformers==4.31.0 nltk==3.8.1 evaluate==0.4.0 absl-py==1.4.0 rouge-score==0.1.2 sentencepiece==0.1.99 accelerate==0.21.0 grpcio-tools
+RUN cd inference/loadgen && python -m pip install .
+RUN cp inference/mlperf.conf inference/language/llama2-70b/mlperf.conf
+
+ENV DATASET_PATH=/workspace/processed-data.pkl
+ENV CHECKPOINT_PATH=/workspace/llama-model-info
+ENV ACCURACY_LOG_FILE=/workspace/inference/language/llama2-70b/offline-logs/mlperf_log_accuracy.json
diff --git a/language/llama2-70b/api-endpoint-artifacts/README.md b/language/llama2-70b/api-endpoint-artifacts/README.md
new file mode 100644
index 000000000..376601f55
--- /dev/null
+++ b/language/llama2-70b/api-endpoint-artifacts/README.md
@@ -0,0 +1,90 @@
+# Using OpenShift AI Model Serving (Caikit/TGIS) with MLPerf Inference
+
+Prerequisites:
+ - Install the OpenShift AI model serving stack
+ - Add your AWS credentials to `secret.yaml` to access the model files
+ - Apply `secret.yaml`, `sa.yaml`
+ - FOR CAIKIT+TGIS: Apply `serving-runtime.yaml`, then finally `model.yaml`
+ - FOR TGIS: Apply `serving-tgis.yaml`, then finally `model-tgis.yaml`
+ - FOR VLLM (Best Performing): Apply `serving-vllm.yaml`, then finally `model-vllm.yaml`
+ - Create a benchmark pod using `benchmark.yaml`
+
+In the pod, before any benchmark, first run `cd inference/language/llama2-70b`
+
+## VLLM INSTRUCTIONS
+For the full accuracy benchmark (offline), run in the pod:
+```
+python3 -u main.py --scenario Offline --model-path ${CHECKPOINT_PATH} --api-server --api-model-name Llama-2-70b-chat-hf --mlperf-conf mlperf.conf --accuracy --vllm --user-conf user.conf --total-sample-count 24576 --dataset-path ${DATASET_PATH} --output-log-dir offline-logs --dtype float32 --device cpu 2>&1 | tee offline_performance_log.log
+```
+You can then run the same evaluation/consolidation scripts as the regular benchmark
+Example API host: `https://llama-2-70b-chat-isvc-predictor-llama-service.apps.h100serving.perf.lab.eng.bos.redhat.com`
+
+
+For the performance benchmark (offline), run in the pod:
+```
+python3 -u main.py --scenario Offline --model-path ${CHECKPOINT_PATH} --api-server --api-model-name Llama-2-70b-chat-hf --mlperf-conf mlperf.conf --vllm --user-conf user.conf --total-sample-count 24576 --dataset-path ${DATASET_PATH} --output-log-dir offline-logs --dtype float32 --device cpu 2>&1 | tee offline_performance_log.log
+```
+(It is the same, just with `--accuracy` removed)
+
+ - For multiple endpoints, add `--additional-servers ...`
+
+
+For the performance benchmark (server), run in the pod:
+```
+python3 -u main.py --scenario Server --model-path ${CHECKPOINT_PATH} --api-server --api-model-name Llama-2-70b-chat-hf --mlperf-conf mlperf.conf --vllm --user-conf user.conf --total-sample-count 24576 --dataset-path ${DATASET_PATH} --output-log-dir server-logs --dtype float32 --device cpu 2>&1 | tee server_performance_log.log
+```
+(Configure target qps in `user.conf`)
+
+
+NOTE: Hyperparams are currently configured for 8xH100
+
+## STANDALONE TGIS INSTRUCTIONS
+For the full accuracy benchmark (offline), run in the pod:
+```
+python3 -u main.py --scenario Offline --model-path ${CHECKPOINT_PATH} --api-server --api-model-name Llama-2-70b-chat-hf --mlperf-conf mlperf.conf --accuracy --grpc --batch-grpc --user-conf user.conf --total-sample-count 24576 --dataset-path ${DATASET_PATH} --output-log-dir offline-logs --dtype float32 --device cpu 2>&1 | tee offline_performance_log.log
+```
+You can then run the same evaluation/consolidation scripts as the regular benchmark
+Example API host: `https://llama-2-70b-chat-isvc-predictor-llama-service.apps.h100serving.perf.lab.eng.bos.redhat.com`
+
+
+For the performance benchmark (offline), run in the pod:
+```
+python3 -u main.py --scenario Offline --model-path ${CHECKPOINT_PATH} --api-server --api-model-name Llama-2-70b-chat-hf --mlperf-conf mlperf.conf --grpc --batch-grpc --user-conf user.conf --total-sample-count 24576 --dataset-path ${DATASET_PATH} --output-log-dir offline-logs --dtype float32 --device cpu 2>&1 | tee offline_performance_log.log
+```
+(It is the same, just with `--accuracy` removed)
+
+ - For multiple endpoints, add `--additional-servers ...`
+
+
+For the performance benchmark (server), run in the pod:
+```
+python3 -u main.py --scenario Server --model-path ${CHECKPOINT_PATH} --api-server --api-model-name Llama-2-70b-chat-hf --mlperf-conf mlperf.conf --grpc --user-conf user.conf --total-sample-count 24576 --dataset-path ${DATASET_PATH} --output-log-dir server-logs --dtype float32 --device cpu 2>&1 | tee server_performance_log.log
+```
+(Configure the target QPS in `user.conf`, as shown in the VLLM section above.)
+
+
+NOTE: Hyperparameters are currently configured for 8x H100.
+
+## CAIKIT INSTRUCTIONS
+For the full accuracy benchmark (offline), run in the pod:
+```
+python3 -u main.py --scenario Offline --model-path ${CHECKPOINT_PATH} --api-server --api-model-name Llama-2-70b-chat-hf-caikit --accuracy --mlperf-conf mlperf.conf --user-conf user.conf --total-sample-count 24576 --dataset-path ${DATASET_PATH} --output-log-dir offline-logs --dtype float32 --device cpu 2>&1 | tee offline_performance_log.log
+```
+You can then run the same evaluation/consolidation scripts as for the regular benchmark.
+
+
+For the performance benchmark (offline), run in the pod:
+```
+python3 -u main.py --scenario Offline --model-path ${CHECKPOINT_PATH} --api-server --api-model-name Llama-2-70b-chat-hf-caikit --mlperf-conf mlperf.conf --user-conf user.conf --total-sample-count 24576 --dataset-path ${DATASET_PATH} --output-log-dir offline-logs --dtype float32 --device cpu 2>&1 | tee offline_performance_log.log
+```
+(This is the same command as above, with `--accuracy` removed.)
+
+
+For the performance benchmark (server), run in the pod:
+```
+python3 -u main.py --scenario Server --model-path ${CHECKPOINT_PATH} --api-server --api-model-name Llama-2-70b-chat-hf-caikit --mlperf-conf mlperf.conf --user-conf user.conf --total-sample-count 24576 --dataset-path ${DATASET_PATH} --output-log-dir server-logs --dtype float32 --device cpu 2>&1 | tee server_performance_log.log
+```
+(Configure the target QPS in `user.conf`, as shown in the VLLM section above.)
+
+
+NOTE: Hyperparameters are currently configured for 8x H100.
diff --git a/language/llama2-70b/api-endpoint-artifacts/benchmark.yaml b/language/llama2-70b/api-endpoint-artifacts/benchmark.yaml
new file mode 100644
index 000000000..5959ca63d
--- /dev/null
+++ b/language/llama2-70b/api-endpoint-artifacts/benchmark.yaml
@@ -0,0 +1,21 @@
+apiVersion: v1
+kind: Pod
+metadata:
+ name: mlperf-inference
+spec:
+ restartPolicy: Never
+ containers:
+ - name: mlperf-env
+ image: quay.io/meyceoz/mlperf-inference:v13
+ resources:
+ requests:
+ memory: 20000Mi
+ volumeMounts:
+ - mountPath: /dev/shm
+ name: dshm
+ command: [ "/bin/sh", "-c" ]
+ args: [ "sleep infinity" ]
+ volumes:
+ - name: dshm
+ emptyDir:
+ medium: Memory
\ No newline at end of file
diff --git a/language/llama2-70b/api-endpoint-artifacts/model-tgis.yaml b/language/llama2-70b/api-endpoint-artifacts/model-tgis.yaml
new file mode 100644
index 000000000..6f920658b
--- /dev/null
+++ b/language/llama2-70b/api-endpoint-artifacts/model-tgis.yaml
@@ -0,0 +1,20 @@
+apiVersion: serving.kserve.io/v1beta1
+kind: InferenceService
+metadata:
+ annotations:
+ serving.knative.openshift.io/enablePassthrough: "true"
+ sidecar.istio.io/inject: "true"
+ sidecar.istio.io/rewriteAppHTTPProbers: "true"
+ name: llama-2-70b-chat-isvc
+spec:
+ predictor:
+ minReplicas: 1
+ maxReplicas: 1
+ #apiVersion: serving.kserve.io/v1alpha2
+ serviceAccountName: sa
+ #timeout: 240
+ model:
+ modelFormat:
+ name: pytorch
+ runtime: tgis-runtime-grpc
+ storageUri: s3://mlperf-inference-models/Llama-2-70b-chat-hf/
\ No newline at end of file
diff --git a/language/llama2-70b/api-endpoint-artifacts/model-vllm.yaml b/language/llama2-70b/api-endpoint-artifacts/model-vllm.yaml
new file mode 100644
index 000000000..01459b7ce
--- /dev/null
+++ b/language/llama2-70b/api-endpoint-artifacts/model-vllm.yaml
@@ -0,0 +1,20 @@
+apiVersion: serving.kserve.io/v1beta1
+kind: InferenceService
+metadata:
+ annotations:
+ serving.knative.openshift.io/enablePassthrough: "true"
+ sidecar.istio.io/inject: "true"
+ sidecar.istio.io/rewriteAppHTTPProbers: "true"
+ name: llama-2-isvc
+spec:
+ predictor:
+ minReplicas: 1
+ maxReplicas: 1
+ #apiVersion: serving.kserve.io/v1alpha2
+ serviceAccountName: sa
+ #timeout: 240
+ model:
+ modelFormat:
+ name: pytorch
+ runtime: vllm
+ storageUri: s3://mlperf-inference-models/Llama-2-70b-chat-hf/
\ No newline at end of file
diff --git a/language/llama2-70b/api-endpoint-artifacts/model.yaml b/language/llama2-70b/api-endpoint-artifacts/model.yaml
new file mode 100644
index 000000000..5c21a3836
--- /dev/null
+++ b/language/llama2-70b/api-endpoint-artifacts/model.yaml
@@ -0,0 +1,20 @@
+apiVersion: serving.kserve.io/v1beta1
+kind: InferenceService
+metadata:
+ annotations:
+ serving.knative.openshift.io/enablePassthrough: "true"
+ sidecar.istio.io/inject: "true"
+ sidecar.istio.io/rewriteAppHTTPProbers: "true"
+ name: llama-2-70b-chat-isvc
+spec:
+ predictor:
+ minReplicas: 1
+ maxReplicas: 1
+ apiVersion: serving.kserve.io/v1alpha2
+ serviceAccountName: sa
+ timeout: 240
+ model:
+ modelFormat:
+ name: caikit
+ runtime: caikit-runtime
+ storageUri: s3://mlperf-inference-models/Llama-2-70b-chat-hf-caikit
\ No newline at end of file
diff --git a/language/llama2-70b/api-endpoint-artifacts/sa.yaml b/language/llama2-70b/api-endpoint-artifacts/sa.yaml
new file mode 100644
index 000000000..6ccfd13b7
--- /dev/null
+++ b/language/llama2-70b/api-endpoint-artifacts/sa.yaml
@@ -0,0 +1,6 @@
+apiVersion: v1
+kind: ServiceAccount
+metadata:
+ name: sa
+secrets:
+- name: storage-config
\ No newline at end of file
diff --git a/language/llama2-70b/api-endpoint-artifacts/secret.yaml b/language/llama2-70b/api-endpoint-artifacts/secret.yaml
new file mode 100644
index 000000000..1b90db6c0
--- /dev/null
+++ b/language/llama2-70b/api-endpoint-artifacts/secret.yaml
@@ -0,0 +1,12 @@
+apiVersion: v1
+kind: Secret
+metadata:
+ annotations:
+ serving.kserve.io/s3-endpoint: "s3.amazonaws.com" # replace with your s3 endpoint e.g minio-service.kubeflow:9000
+ serving.kserve.io/s3-usehttps: "1" # by default 1, if testing with minio you can set to 0
+ serving.kserve.io/s3-region: "us-east-1"
+ serving.kserve.io/s3-useanoncredential: "false" # omitting this is the same as false, if true will ignore provided credential and use anonymous credentials
+ name: storage-config
+stringData:
+ "AWS_ACCESS_KEY_ID": "XXXXXXXXX"
+ "AWS_SECRET_ACCESS_KEY": "XXXXXXXXX"
\ No newline at end of file
diff --git a/language/llama2-70b/api-endpoint-artifacts/serving-runtime.yaml b/language/llama2-70b/api-endpoint-artifacts/serving-runtime.yaml
new file mode 100644
index 000000000..f8e05fe5d
--- /dev/null
+++ b/language/llama2-70b/api-endpoint-artifacts/serving-runtime.yaml
@@ -0,0 +1,107 @@
+apiVersion: v1
+kind: ConfigMap
+metadata:
+ name: caikit-tgis-config
+data:
+ caikit.yml: |
+ runtime:
+ library: caikit_nlp
+ local_models_dir: /mnt/models/
+ lazy_load_local_models: true
+ model_management:
+ finders:
+ default:
+ type: MULTI
+ config:
+ finder_priority:
+ - tgis-auto
+ tgis-auto:
+ type: TGIS-AUTO
+ config:
+ test_connection: true
+ initializers:
+ default:
+ type: LOCAL
+ config:
+ backend_priority:
+ - type: TGIS
+ config:
+ local:
+ num_gpus: 8
+ load_timeout: 2000
+ connection:
+ hostname: localhost:8033
+---
+apiVersion: serving.kserve.io/v1alpha1
+kind: ServingRuntime
+metadata:
+ name: caikit-runtime
+spec:
+ multiModel: false
+ supportedModelFormats:
+ # Note: this currently *only* supports caikit format models
+ - autoSelect: true
+ name: caikit
+ containers:
+ - name: kserve-container
+ image: quay.io/opendatahub/text-generation-inference:fast
+ command: ["text-generation-launcher"]
+ args: ["--model-name=/mnt/models/artifacts/"]
+ env:
+ - name: TRANSFORMERS_CACHE
+ value: /tmp/transformers_cache
+ #- name: RUNTIME_LOCAL_MODELS_DIR
+ #value: /mnt/models
+ - name: NUM_GPUS
+ value: "8"
+ #- name: TRANSFORMERS_CACHE
+ # value: /shared_model_storage/transformers_cache
+ #- name: HUGGINGFACE_HUB_CACHE
+ # value: /shared_model_storage/transformers_cache
+ #- name: DTYPE_STR
+ # value: float16
+ # Dynamic batch size changes
+ - name: MAX_BATCH_SIZE
+ value: "128"
+ - name: MAX_CONCURRENT_REQUESTS
+ value: "200"
+ - name: MAX_BATCH_WEIGHT
+ value: "550000"
+ - name: MAX_SEQUENCE_LENGTH
+ value: "2048"
+ - name: MAX_PREFILL_WEIGHT
+ value: "0"
+ - name: MAX_NEW_TOKENS
+ value: "1024"
+ - name: FLASH_ATTENTION
+ value: "true"
+ - name: DEPLOYMENT_FRAMEWORK
+ value: hf_custom_tp
+ resources: # configure as required
+ requests:
+ cpu: 64
+ memory: 900Gi
+ nvidia.com/gpu: 8
+ limits:
+ nvidia.com/gpu: 8
+ - name: transformer-container
+ image: quay.io/opendatahub/caikit-tgis-serving:fast
+ env:
+ - name: RUNTIME_GRPC_SERVER_THREAD_POOL_SIZE
+ value: "200"
+ volumeMounts:
+ - name: config-volume
+ mountPath: /caikit/config/
+ readOnly: true
+ ports:
+ - containerPort: 8080
+ #name: h2c
+ protocol: TCP
+ resources: # configure as required
+ requests:
+ cpu: 2
+ memory: 4Gi
+ volumes:
+ - name: config-volume
+ configMap:
+ name: caikit-tgis-config
diff --git a/language/llama2-70b/api-endpoint-artifacts/serving-tgis.yaml b/language/llama2-70b/api-endpoint-artifacts/serving-tgis.yaml
new file mode 100644
index 000000000..fa0644af5
--- /dev/null
+++ b/language/llama2-70b/api-endpoint-artifacts/serving-tgis.yaml
@@ -0,0 +1,67 @@
+apiVersion: serving.kserve.io/v1alpha1
+kind: ServingRuntime
+metadata:
+ name: tgis-runtime-grpc
+spec:
+ multiModel: false
+ supportedModelFormats:
+ - autoSelect: true
+ name: pytorch
+ containers:
+ - name: kserve-container
+ image: quay.io/opendatahub/text-generation-inference:fast
+ command: ["text-generation-launcher"]
+ args:
+ - "--model-name=/mnt/models/"
+ - "--port=3000"
+ - "--grpc-port=8033"
+ env:
+ - name: TRANSFORMERS_CACHE
+ value: /tmp/transformers_cache
+ - name: NUM_GPUS
+ value: "8"
+ #- name: DTYPE_STR
+ # value: float16
+ # Dynamic batch size changes
+ - name: MAX_BATCH_SIZE
+ value: "512"
+ - name: MAX_CONCURRENT_REQUESTS
+ value: "5000"
+ - name: MAX_BATCH_WEIGHT
+ value: "1000000"
+ - name: MAX_SEQUENCE_LENGTH
+ value: "2048"
+ - name: MAX_PREFILL_WEIGHT
+ value: "0"
+ - name: MAX_NEW_TOKENS
+ value: "1024"
+ - name: FLASH_ATTENTION
+ value: "true"
+ - name: DEPLOYMENT_FRAMEWORK
+ value: hf_custom_tp
+ - name: LOG_GPU_USAGE_INTERVAL
+ value: "5"
+ resources: # configure as required
+ requests:
+ cpu: 96
+ memory: 1000Gi
+ nvidia.com/gpu: 8
+ limits:
+ nvidia.com/gpu: 8
+      readinessProbe: # Use exec probes instead of httpGet since the probes' port gets rewritten to the containerPort
+ exec:
+ command:
+ - curl
+ - localhost:3000/health
+ initialDelaySeconds: 5
+ livenessProbe:
+ exec:
+ command:
+ - curl
+ - localhost:3000/health
+ initialDelaySeconds: 5
+ # periodSeconds: 5
+ ports:
+ - containerPort: 8033
+ name: h2c
+ protocol: TCP
\ No newline at end of file
diff --git a/language/llama2-70b/api-endpoint-artifacts/serving-vllm.yaml b/language/llama2-70b/api-endpoint-artifacts/serving-vllm.yaml
new file mode 100644
index 000000000..f14b1f78e
--- /dev/null
+++ b/language/llama2-70b/api-endpoint-artifacts/serving-vllm.yaml
@@ -0,0 +1,49 @@
+apiVersion: serving.kserve.io/v1alpha1
+kind: ServingRuntime
+labels:
+ opendatahub.io/dashboard: "true"
+metadata:
+ annotations:
+ openshift.io/display-name: vLLM
+ name: vllm
+spec:
+ builtInAdapter:
+ modelLoadingTimeoutMillis: 90000
+ containers:
+ - args:
+ - --model
+ - /mnt/models/
+ - --download-dir
+ - /models-cache
+ - --port
+ - "8080"
+ - --dtype
+ - float16
+ - --tensor-parallel-size
+ - "4"
+ image: quay.io/rh-aiservices-bu/vllm-openai-ubi9:0.3.1-fix-2939
+ name: kserve-container
+ ports:
+ - containerPort: 8080
+ name: http1
+ protocol: TCP
+ resources: # configure as required
+ requests:
+ cpu: 16
+ memory: 512Gi
+ nvidia.com/gpu: 4
+ limits:
+ cpu: 16
+ memory: 512Gi
+ nvidia.com/gpu: 4
+ volumeMounts:
+ - mountPath: /dev/shm
+ name: dshm
+ volumes:
+ - name: dshm
+ emptyDir:
+ medium: Memory
+ multiModel: false
+ supportedModelFormats:
+ - autoSelect: true
+ name: pytorch
\ No newline at end of file
diff --git a/language/llama2-70b/dataset.py b/language/llama2-70b/dataset.py
index a59fd7f55..4b1b1bb91 100644
--- a/language/llama2-70b/dataset.py
+++ b/language/llama2-70b/dataset.py
@@ -84,6 +84,8 @@ def postProcess(self, out_tokens, input_seq_lens=None, query_id_list=None, sampl
assert len(query_id_list) == output_seq.shape[0]
# Save outputs
+ if not os.path.exists("run_outputs"):
+ os.makedirs("run_outputs")
fname = "q" + "_".join([str(i) for i in query_id_list])
fname = f"run_outputs/{fname}.pkl"
with open(fname, mode='wb') as f:
diff --git a/language/llama2-70b/generation_pb2.py b/language/llama2-70b/generation_pb2.py
new file mode 100644
index 000000000..544e1ec7d
--- /dev/null
+++ b/language/llama2-70b/generation_pb2.py
@@ -0,0 +1,70 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: generation.proto
+# Protobuf Python Version: 4.25.0
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10generation.proto\x12\x05\x66maas\"\xa1\x01\n\x18\x42\x61tchedGenerationRequest\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x16\n\tprefix_id\x18\x02 \x01(\tH\x00\x88\x01\x01\x12*\n\x08requests\x18\x03 \x03(\x0b\x32\x18.fmaas.GenerationRequest\x12!\n\x06params\x18\n \x01(\x0b\x32\x11.fmaas.ParametersB\x0c\n\n_prefix_id\"\x9f\x01\n\x17SingleGenerationRequest\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x16\n\tprefix_id\x18\x02 \x01(\tH\x00\x88\x01\x01\x12)\n\x07request\x18\x03 \x01(\x0b\x32\x18.fmaas.GenerationRequest\x12!\n\x06params\x18\n \x01(\x0b\x32\x11.fmaas.ParametersB\x0c\n\n_prefix_id\"I\n\x19\x42\x61tchedGenerationResponse\x12,\n\tresponses\x18\x01 \x03(\x0b\x32\x19.fmaas.GenerationResponse\"!\n\x11GenerationRequest\x12\x0c\n\x04text\x18\x02 \x01(\t\"\xf3\x01\n\x12GenerationResponse\x12\x19\n\x11input_token_count\x18\x06 \x01(\r\x12\x1d\n\x15generated_token_count\x18\x02 \x01(\r\x12\x0c\n\x04text\x18\x04 \x01(\t\x12&\n\x0bstop_reason\x18\x07 \x01(\x0e\x32\x11.fmaas.StopReason\x12\x15\n\rstop_sequence\x18\x0b \x01(\t\x12\x0c\n\x04seed\x18\n \x01(\x04\x12 \n\x06tokens\x18\x08 \x03(\x0b\x32\x10.fmaas.TokenInfo\x12&\n\x0cinput_tokens\x18\t \x03(\x0b\x32\x10.fmaas.TokenInfo\"\x81\x02\n\nParameters\x12%\n\x06method\x18\x01 \x01(\x0e\x32\x15.fmaas.DecodingMethod\x12+\n\x08sampling\x18\x02 \x01(\x0b\x32\x19.fmaas.SamplingParameters\x12)\n\x08stopping\x18\x03 \x01(\x0b\x32\x17.fmaas.StoppingCriteria\x12(\n\x08response\x18\x04 \x01(\x0b\x32\x16.fmaas.ResponseOptions\x12+\n\x08\x64\x65\x63oding\x18\x05 \x01(\x0b\x32\x19.fmaas.DecodingParameters\x12\x1d\n\x15truncate_input_tokens\x18\x06 \x01(\r\"\xc5\x01\n\x12\x44\x65\x63odingParameters\x12\x1a\n\x12repetition_penalty\x18\x01 \x01(\x02\x12\x44\n\x0elength_penalty\x18\x02 \x01(\x0b\x32\'.fmaas.DecodingParameters.LengthPenaltyH\x00\x88\x01\x01\x1a:\n\rLengthPenalty\x12\x13\n\x0bstart_index\x18\x01 \x01(\r\x12\x14\n\x0c\x64\x65\x63\x61y_factor\x18\x02 \x01(\x02\x42\x11\n\x0f_length_penalty\"v\n\x12SamplingParameters\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_k\x18\x02 \x01(\r\x12\r\n\x05top_p\x18\x03 \x01(\x02\x12\x11\n\ttypical_p\x18\x04 \x01(\x02\x12\x11\n\x04seed\x18\x05 \x01(\x04H\x00\x88\x01\x01\x42\x07\n\x05_seed\"\xb3\x01\n\x10StoppingCriteria\x12\x16\n\x0emax_new_tokens\x18\x01 \x01(\r\x12\x16\n\x0emin_new_tokens\x18\x02 \x01(\r\x12\x19\n\x11time_limit_millis\x18\x03 \x01(\r\x12\x16\n\x0estop_sequences\x18\x04 \x03(\t\x12\"\n\x15include_stop_sequence\x18\x05 \x01(\x08H\x00\x88\x01\x01\x42\x18\n\x16_include_stop_sequence\"\x98\x01\n\x0fResponseOptions\x12\x12\n\ninput_text\x18\x01 \x01(\x08\x12\x18\n\x10generated_tokens\x18\x02 \x01(\x08\x12\x14\n\x0cinput_tokens\x18\x03 \x01(\x08\x12\x16\n\x0etoken_logprobs\x18\x04 \x01(\x08\x12\x13\n\x0btoken_ranks\x18\x05 \x01(\x08\x12\x14\n\x0ctop_n_tokens\x18\x06 \x01(\r\"\x92\x01\n\tTokenInfo\x12\x0c\n\x04text\x18\x02 \x01(\t\x12\x0f\n\x07logprob\x18\x03 \x01(\x02\x12\x0c\n\x04rank\x18\x04 \x01(\r\x12-\n\ntop_tokens\x18\x05 \x03(\x0b\x32\x19.fmaas.TokenInfo.TopToken\x1a)\n\x08TopToken\x12\x0c\n\x04text\x18\x02 \x01(\t\x12\x0f\n\x07logprob\x18\x03 \x01(\x02\"k\n\x16\x42\x61tchedTokenizeRequest\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12(\n\x08requests\x18\x02 \x03(\x0b\x32\x16.fmaas.TokenizeRequest\x12\x15\n\rreturn_tokens\x18\x03 \x01(\x08\"E\n\x17\x42\x61tchedTokenizeResponse\x12*\n\tresponses\x18\x01 \x03(\x0b\x32\x17.fmaas.TokenizeResponse\"\x1f\n\x0fTokenizeRequest\x12\x0c\n\x04text\x18\x01 
\x01(\t\"7\n\x10TokenizeResponse\x12\x13\n\x0btoken_count\x18\x01 \x01(\r\x12\x0e\n\x06tokens\x18\x02 \x03(\t\"$\n\x10ModelInfoRequest\x12\x10\n\x08model_id\x18\x01 \x01(\t\"\xb4\x01\n\x11ModelInfoResponse\x12\x36\n\nmodel_kind\x18\x01 \x01(\x0e\x32\".fmaas.ModelInfoResponse.ModelKind\x12\x1b\n\x13max_sequence_length\x18\x02 \x01(\r\x12\x16\n\x0emax_new_tokens\x18\x03 \x01(\r\"2\n\tModelKind\x12\x10\n\x0c\x44\x45\x43ODER_ONLY\x10\x00\x12\x13\n\x0f\x45NCODER_DECODER\x10\x01*(\n\x0e\x44\x65\x63odingMethod\x12\n\n\x06GREEDY\x10\x00\x12\n\n\x06SAMPLE\x10\x01*\x8b\x01\n\nStopReason\x12\x10\n\x0cNOT_FINISHED\x10\x00\x12\x0e\n\nMAX_TOKENS\x10\x01\x12\r\n\tEOS_TOKEN\x10\x02\x12\r\n\tCANCELLED\x10\x03\x12\x0e\n\nTIME_LIMIT\x10\x04\x12\x11\n\rSTOP_SEQUENCE\x10\x05\x12\x0f\n\x0bTOKEN_LIMIT\x10\x06\x12\t\n\x05\x45RROR\x10\x07\x32\xc4\x02\n\x11GenerationService\x12O\n\x08Generate\x12\x1f.fmaas.BatchedGenerationRequest\x1a .fmaas.BatchedGenerationResponse\"\x00\x12O\n\x0eGenerateStream\x12\x1e.fmaas.SingleGenerationRequest\x1a\x19.fmaas.GenerationResponse\"\x00\x30\x01\x12K\n\x08Tokenize\x12\x1d.fmaas.BatchedTokenizeRequest\x1a\x1e.fmaas.BatchedTokenizeResponse\"\x00\x12@\n\tModelInfo\x12\x17.fmaas.ModelInfoRequest\x1a\x18.fmaas.ModelInfoResponse\"\x00\x62\x06proto3')
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'generation_pb2', _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+ DESCRIPTOR._options = None
+ _globals['_DECODINGMETHOD']._serialized_start=2266
+ _globals['_DECODINGMETHOD']._serialized_end=2306
+ _globals['_STOPREASON']._serialized_start=2309
+ _globals['_STOPREASON']._serialized_end=2448
+ _globals['_BATCHEDGENERATIONREQUEST']._serialized_start=28
+ _globals['_BATCHEDGENERATIONREQUEST']._serialized_end=189
+ _globals['_SINGLEGENERATIONREQUEST']._serialized_start=192
+ _globals['_SINGLEGENERATIONREQUEST']._serialized_end=351
+ _globals['_BATCHEDGENERATIONRESPONSE']._serialized_start=353
+ _globals['_BATCHEDGENERATIONRESPONSE']._serialized_end=426
+ _globals['_GENERATIONREQUEST']._serialized_start=428
+ _globals['_GENERATIONREQUEST']._serialized_end=461
+ _globals['_GENERATIONRESPONSE']._serialized_start=464
+ _globals['_GENERATIONRESPONSE']._serialized_end=707
+ _globals['_PARAMETERS']._serialized_start=710
+ _globals['_PARAMETERS']._serialized_end=967
+ _globals['_DECODINGPARAMETERS']._serialized_start=970
+ _globals['_DECODINGPARAMETERS']._serialized_end=1167
+ _globals['_DECODINGPARAMETERS_LENGTHPENALTY']._serialized_start=1090
+ _globals['_DECODINGPARAMETERS_LENGTHPENALTY']._serialized_end=1148
+ _globals['_SAMPLINGPARAMETERS']._serialized_start=1169
+ _globals['_SAMPLINGPARAMETERS']._serialized_end=1287
+ _globals['_STOPPINGCRITERIA']._serialized_start=1290
+ _globals['_STOPPINGCRITERIA']._serialized_end=1469
+ _globals['_RESPONSEOPTIONS']._serialized_start=1472
+ _globals['_RESPONSEOPTIONS']._serialized_end=1624
+ _globals['_TOKENINFO']._serialized_start=1627
+ _globals['_TOKENINFO']._serialized_end=1773
+ _globals['_TOKENINFO_TOPTOKEN']._serialized_start=1732
+ _globals['_TOKENINFO_TOPTOKEN']._serialized_end=1773
+ _globals['_BATCHEDTOKENIZEREQUEST']._serialized_start=1775
+ _globals['_BATCHEDTOKENIZEREQUEST']._serialized_end=1882
+ _globals['_BATCHEDTOKENIZERESPONSE']._serialized_start=1884
+ _globals['_BATCHEDTOKENIZERESPONSE']._serialized_end=1953
+ _globals['_TOKENIZEREQUEST']._serialized_start=1955
+ _globals['_TOKENIZEREQUEST']._serialized_end=1986
+ _globals['_TOKENIZERESPONSE']._serialized_start=1988
+ _globals['_TOKENIZERESPONSE']._serialized_end=2043
+ _globals['_MODELINFOREQUEST']._serialized_start=2045
+ _globals['_MODELINFOREQUEST']._serialized_end=2081
+ _globals['_MODELINFORESPONSE']._serialized_start=2084
+ _globals['_MODELINFORESPONSE']._serialized_end=2264
+ _globals['_MODELINFORESPONSE_MODELKIND']._serialized_start=2214
+ _globals['_MODELINFORESPONSE_MODELKIND']._serialized_end=2264
+ _globals['_GENERATIONSERVICE']._serialized_start=2451
+ _globals['_GENERATIONSERVICE']._serialized_end=2775
+# @@protoc_insertion_point(module_scope)
diff --git a/language/llama2-70b/generation_pb2_grpc.py b/language/llama2-70b/generation_pb2_grpc.py
new file mode 100644
index 000000000..9ab1e4eb5
--- /dev/null
+++ b/language/llama2-70b/generation_pb2_grpc.py
@@ -0,0 +1,169 @@
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+"""Client and server classes corresponding to protobuf-defined services."""
+import grpc
+
+import generation_pb2 as generation__pb2
+
+
+class GenerationServiceStub(object):
+ """Missing associated documentation comment in .proto file."""
+
+ def __init__(self, channel):
+ """Constructor.
+
+ Args:
+ channel: A grpc.Channel.
+ """
+ self.Generate = channel.unary_unary(
+ '/fmaas.GenerationService/Generate',
+ request_serializer=generation__pb2.BatchedGenerationRequest.SerializeToString,
+ response_deserializer=generation__pb2.BatchedGenerationResponse.FromString,
+ )
+ self.GenerateStream = channel.unary_stream(
+ '/fmaas.GenerationService/GenerateStream',
+ request_serializer=generation__pb2.SingleGenerationRequest.SerializeToString,
+ response_deserializer=generation__pb2.GenerationResponse.FromString,
+ )
+ self.Tokenize = channel.unary_unary(
+ '/fmaas.GenerationService/Tokenize',
+ request_serializer=generation__pb2.BatchedTokenizeRequest.SerializeToString,
+ response_deserializer=generation__pb2.BatchedTokenizeResponse.FromString,
+ )
+ self.ModelInfo = channel.unary_unary(
+ '/fmaas.GenerationService/ModelInfo',
+ request_serializer=generation__pb2.ModelInfoRequest.SerializeToString,
+ response_deserializer=generation__pb2.ModelInfoResponse.FromString,
+ )
+
+
+class GenerationServiceServicer(object):
+ """Missing associated documentation comment in .proto file."""
+
+ def Generate(self, request, context):
+ """Generates text given a text prompt, for one or more inputs
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def GenerateStream(self, request, context):
+ """Generates text given a single input prompt, streaming the response
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def Tokenize(self, request, context):
+ """Tokenize text
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def ModelInfo(self, request, context):
+ """Model info
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+
+def add_GenerationServiceServicer_to_server(servicer, server):
+ rpc_method_handlers = {
+ 'Generate': grpc.unary_unary_rpc_method_handler(
+ servicer.Generate,
+ request_deserializer=generation__pb2.BatchedGenerationRequest.FromString,
+ response_serializer=generation__pb2.BatchedGenerationResponse.SerializeToString,
+ ),
+ 'GenerateStream': grpc.unary_stream_rpc_method_handler(
+ servicer.GenerateStream,
+ request_deserializer=generation__pb2.SingleGenerationRequest.FromString,
+ response_serializer=generation__pb2.GenerationResponse.SerializeToString,
+ ),
+ 'Tokenize': grpc.unary_unary_rpc_method_handler(
+ servicer.Tokenize,
+ request_deserializer=generation__pb2.BatchedTokenizeRequest.FromString,
+ response_serializer=generation__pb2.BatchedTokenizeResponse.SerializeToString,
+ ),
+ 'ModelInfo': grpc.unary_unary_rpc_method_handler(
+ servicer.ModelInfo,
+ request_deserializer=generation__pb2.ModelInfoRequest.FromString,
+ response_serializer=generation__pb2.ModelInfoResponse.SerializeToString,
+ ),
+ }
+ generic_handler = grpc.method_handlers_generic_handler(
+ 'fmaas.GenerationService', rpc_method_handlers)
+ server.add_generic_rpc_handlers((generic_handler,))
+
+
+ # This class is part of an EXPERIMENTAL API.
+class GenerationService(object):
+ """Missing associated documentation comment in .proto file."""
+
+ @staticmethod
+ def Generate(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(request, target, '/fmaas.GenerationService/Generate',
+ generation__pb2.BatchedGenerationRequest.SerializeToString,
+ generation__pb2.BatchedGenerationResponse.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+ @staticmethod
+ def GenerateStream(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_stream(request, target, '/fmaas.GenerationService/GenerateStream',
+ generation__pb2.SingleGenerationRequest.SerializeToString,
+ generation__pb2.GenerationResponse.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+ @staticmethod
+ def Tokenize(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(request, target, '/fmaas.GenerationService/Tokenize',
+ generation__pb2.BatchedTokenizeRequest.SerializeToString,
+ generation__pb2.BatchedTokenizeResponse.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+ @staticmethod
+ def ModelInfo(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(request, target, '/fmaas.GenerationService/ModelInfo',
+ generation__pb2.ModelInfoRequest.SerializeToString,
+ generation__pb2.ModelInfoResponse.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
diff --git a/language/llama2-70b/inference.py b/language/llama2-70b/inference.py
new file mode 100644
index 000000000..51de405ef
--- /dev/null
+++ b/language/llama2-70b/inference.py
@@ -0,0 +1,204 @@
+from typing import Optional, Union
+
+import grpc
+import generation_pb2_grpc
+import socket
+import ssl
+import sys
+
+
+def get_server_certificate(host: str, port: int) -> str:
+ """connect to host:port and get the certificate it presents
+
+ This is almost the same as `ssl.get_server_certificate`, but
+ when opening the TLS socket, `server_hostname` is also provided.
+
+ This retrieves the correct certificate for hosts using name-based
+ virtual hosting.
+ """
+ if sys.version_info >= (3, 10):
+        # ssl.get_server_certificate supports TLS SNI only on Python 3.10+
+ # https://github.com/python/cpython/pull/16820
+ return ssl.get_server_certificate((host, port))
+
+ context = ssl.SSLContext()
+
+ with socket.create_connection((host, port)) as sock, context.wrap_socket(
+ sock, server_hostname=host
+ ) as ssock:
+ cert_der = ssock.getpeercert(binary_form=True)
+
+ assert cert_der
+ return ssl.DER_cert_to_PEM_cert(cert_der)
+
+
+class GrpcClient:
+ def __init__(
+ self,
+ host: str,
+ port: int,
+ *,
+ insecure: bool = False,
+ verify: Optional[bool] = None,
+ ca_cert: Union[None, bytes, str] = None,
+ client_cert: Union[None, bytes, str] = None,
+ client_key: Union[None, bytes, str] = None,
+ ) -> None:
+ self._channel = self._make_channel(
+ host,
+ port,
+ insecure=insecure,
+ verify=verify,
+ client_key=client_key,
+ client_cert=client_cert,
+ ca_cert=ca_cert,
+ )
+ self.generation_service_stub = generation_pb2_grpc.GenerationServiceStub(
+ self._channel
+ )
+
+    def make_request(self, texts: list, model_id: str = "flan-t5-small"):
+ request = generation_pb2_grpc.generation__pb2.BatchedGenerationRequest(
+ model_id=model_id,
+ requests=[generation_pb2_grpc.generation__pb2.GenerationRequest(text=text) for text in texts],
+ params=generation_pb2_grpc.generation__pb2.Parameters(
+ method=generation_pb2_grpc.generation__pb2.GREEDY,
+ stopping=generation_pb2_grpc.generation__pb2.StoppingCriteria(
+ max_new_tokens=1024,
+ min_new_tokens=1
+ )
+ )
+ )
+ result = self.generation_service_stub.Generate(request=request)
+ return result
+
+ def make_request_stream(self, text: str, model_id: str = "flan-t5-small"):
+ request = generation_pb2_grpc.generation__pb2.SingleGenerationRequest(
+ model_id=model_id,
+ request=generation_pb2_grpc.generation__pb2.GenerationRequest(text=text),
+ params=generation_pb2_grpc.generation__pb2.Parameters(
+ method=generation_pb2_grpc.generation__pb2.GREEDY,
+ stopping=generation_pb2_grpc.generation__pb2.StoppingCriteria(
+ max_new_tokens=1024,
+ min_new_tokens=1
+ ),
+ response=generation_pb2_grpc.generation__pb2.ResponseOptions(
+ generated_tokens=True
+ )
+ )
+ )
+ result = self.generation_service_stub.GenerateStream(request=request)
+ return result
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *exc_info):
+ self._close()
+ return False
+
+ def _close(self):
+ try:
+ if hasattr(self, "_channel") and self._channel:
+ self._channel.close()
+ except Exception as exc:
+ print(f"Unexpected exception while closing client: {exc}")
+
+ def __del__(self):
+ self._close()
+
+ def _make_channel(
+ self,
+ host: str,
+ port: int,
+ *,
+ insecure: bool = False,
+ verify: Optional[bool] = None,
+ ca_cert: Union[None, bytes, str] = None,
+ client_key: Union[None, bytes, str] = None,
+ client_cert: Union[None, bytes, str] = None,
+ ) -> grpc.Channel:
+ """Creates a grpc channel
+
+ Args:
+ - host: str
+ - port: str
+ - (optional) insecure: use a plaintext connection (default=False)
+ - (optional) verify: set to False to disable remote host certificate(s)
+ verification. Cannot be used with `plaintext` or with MTLS
+ - (optional) ca_cert: certificate authority to use
+ - (optional) client_key: client key for mTLS mode
+ - (optional) client_cert: client cert for mTLS mode
+
+ """
+ if not host.strip():
+ raise ValueError("A non empty host name is required")
+ if int(port) <= 0:
+ raise ValueError("A non zero port number is required")
+ if insecure and any(
+ (val is not None) for val in (ca_cert, client_key, client_cert)
+ ):
+ raise ValueError("cannot use insecure with TLS/mTLS certificates")
+ if insecure and verify:
+ raise ValueError("insecure cannot be used with verify")
+
+ client_key_bytes = self._try_load_certificate(client_key)
+ client_cert_bytes = self._try_load_certificate(client_cert)
+ ca_cert_bytes = self._try_load_certificate(ca_cert)
+
+ connection = f"{host}:{port}"
+ if insecure:
+ print("Connecting over an insecure plaintext grpc channel")
+ return grpc.insecure_channel(connection)
+
+ credentials_kwargs: dict[str, bytes] = {}
+ if ca_cert_bytes and not (any((client_cert_bytes, client_key_bytes))):
+ print("Connecting using provided CA certificate for secure channel")
+ credentials_kwargs.update(root_certificates=ca_cert_bytes)
+ elif client_cert_bytes and client_key_bytes and ca_cert_bytes:
+ print("Connecting using mTLS for secure channel")
+ credentials_kwargs.update(
+ root_certificates=ca_cert_bytes,
+ private_key=client_key_bytes,
+ certificate_chain=client_cert_bytes,
+ )
+ elif verify is False:
+ cert = get_server_certificate(host, port).encode()
+ credentials_kwargs.update(root_certificates=cert)
+
+ return grpc.secure_channel(
+ connection, grpc.ssl_channel_credentials(**credentials_kwargs)
+ )
+
+ @staticmethod
+ def _try_load_certificate(certificate: Union[None, bytes, str]) -> Optional[bytes]:
+ """If the certificate points to a file, return the contents (plaintext reads).
+ Else return the bytes"""
+ if not certificate:
+ return None
+
+ if isinstance(certificate, bytes):
+ return certificate
+
+ if isinstance(certificate, str):
+ with open(certificate, "rb") as secret_file:
+ return secret_file.read()
+ raise ValueError(
+            f"{certificate=} should be a path to a certificate file or bytes"
+ )
+
+
+if __name__ == "__main__":
+ # host = "flan-t5-small-predictor-caikit-testing.apps.aisrhods-wx.8goc.p1.openshiftapps.com"
+ # port = 443
+ host = "localhost"
+ port = 8033
+ print(f"connecting to {host=}:{port}")
+ client = GrpcClient(
+ host,
+ port,
+ insecure=True,
+ # verify=False,
+ )
+
+    client.make_request(["this is the query text"], model_id="flan-t5-small")
diff --git a/language/llama2-70b/main.py b/language/llama2-70b/main.py
index bf1def806..5f26e3bcc 100644
--- a/language/llama2-70b/main.py
+++ b/language/llama2-70b/main.py
@@ -16,7 +16,13 @@ def get_args():
parser.add_argument("--scenario", type=str, choices=["Offline", "Server"], default="Offline", help="Scenario")
parser.add_argument("--model-path", type=str, default="meta-llama/Llama-2-70b-chat-hf", help="Model name")
parser.add_argument("--dataset-path", type=str, default=None, help="")
+ parser.add_argument("--api-server", type=str, default=None, help="Specify an api endpoint call to use api mode")
+ parser.add_argument("--api-model-name", type=str, default=None, help="Specify a model name to use api mode")
+ parser.add_argument("--additional-servers", nargs='+', default=[], help="Specify additional endpoints for load splitting")
parser.add_argument("--accuracy", action="store_true", help="Run accuracy mode")
+ parser.add_argument("--grpc", action="store_true", help="Enable grpc for api endpoint")
+ parser.add_argument("--batch-grpc", action="store_true", help="Enable batch requests for grpc")
+ parser.add_argument("--vllm", action="store_true", help="Switch runtime to vllm for api endpoint")
parser.add_argument("--dtype", type=str, default="float32", help="data type of the model, choose from float16, bfloat16 and float32")
parser.add_argument("--device", type=str, choices=["cpu", "cuda:0"], default="cpu", help="device to use")
parser.add_argument("--audit-conf", type=str, default="audit.conf", help="audit config for LoadGen settings during compliance runs")
@@ -68,6 +74,12 @@ def main():
sut = sut_cls(
model_path=args.model_path,
+ api_server=args.api_server,
+ api_model_name=args.api_model_name,
+ additional_servers=args.additional_servers,
+ grpc=args.grpc,
+ batch_grpc=args.batch_grpc,
+ vllm=args.vllm,
dtype=args.dtype,
dataset_path=args.dataset_path,
total_sample_count=args.total_sample_count,
diff --git a/language/llama2-70b/user.conf b/language/llama2-70b/user.conf
index bb97c437a..6a9b4fb6d 100644
--- a/language/llama2-70b/user.conf
+++ b/language/llama2-70b/user.conf
@@ -3,9 +3,6 @@
# Model maybe '*' as wildcard. In that case the value applies to all models.
# All times are in milli seconds
#
-*.Offline.min_duration = 600000
-*.Offline.min_query_count = 2000
+*.Server.target_qps = 20
-*.Server.target_qps = 0.5
-*.Server.min_duration = 120000
-*.Server.min_query_count = 100
+llama2-70b.Server.sample_concatenate_permutation = 1
\ No newline at end of file
diff --git a/loadgen/bindings/python_api.cc b/loadgen/bindings/python_api.cc
index 816f7a4f6..cfe24bd3c 100644
--- a/loadgen/bindings/python_api.cc
+++ b/loadgen/bindings/python_api.cc
@@ -336,6 +336,8 @@ PYBIND11_MODULE(mlperf_loadgen, m) {
&TestSettings::test05_sample_index_rng_seed)
.def_readwrite("test05_schedule_rng_seed", &TestSettings::test05_schedule_rng_seed)
.def_readwrite("use_token_latencies", &TestSettings::use_token_latencies)
+ .def_readwrite("ttft_latency", &TestSettings::server_ttft_latency)
+ .def_readwrite("tpot_latency", &TestSettings::server_tpot_latency)
.def("FromConfig", &TestSettings::FromConfig, "FromConfig.");
pybind11::enum_(m, "LoggingMode")
diff --git a/loadgen/loadgen.cc b/loadgen/loadgen.cc
index 7db118613..76e59151b 100644
--- a/loadgen/loadgen.cc
+++ b/loadgen/loadgen.cc
@@ -120,7 +120,7 @@ struct ResponseDelegateDetailed : public ResponseDelegate {
if (sample_data_copy) {
log.LogAccuracy(sample->sequence_id, sample->sample_index,
- LogBinaryAsHexString{sample_data_copy});
+ LogBinaryAsHexString{sample_data_copy}, n_tokens);
delete sample_data_copy;
}
@@ -140,7 +140,12 @@ struct ResponseDelegateDetailed : public ResponseDelegate {
// For some reason, using std::unique_ptr wasn't moving
// into the lambda; even with C++14.
std::vector* token_data_copy = nullptr;
- if (mode == TestMode::AccuracyOnly) {
+ double accuracy_log_val =
+ sample->accuracy_log_val + accuracy_log_offset < 1.0
+ ? sample->accuracy_log_val + accuracy_log_offset
+ : sample->accuracy_log_val + accuracy_log_offset - 1.0;
+ if (mode == TestMode::AccuracyOnly ||
+ accuracy_log_val <= accuracy_log_prob) {
uint8_t* src_begin = reinterpret_cast(response->data);
uint8_t* src_end = src_begin + response->size;
token_data_copy = new std::vector(src_begin, src_end);
@@ -222,10 +227,8 @@ auto SampleDistribution(size_t sample_count,
auto& gen) mutable { return dist(gen); };
}
-/// \brief SampleDistribution for 3D-UNet SingleStream, for v2.0
-// FIXME: meant for 3D UNet SingleStream only at the moment but the logic should
-// work for others
-// TODO: consolidate the distribution generator after v2.0
+/// \brief Sample across the dataset, and ensure coverage of each of the samples.
+// Useful for non-uniform dataset (e.g. Llama2, GPTJ, 3d-unet)
auto SampleDistributionEqualIssue(size_t sample_count, size_t set_size,
std::mt19937* rng) {
std::vector indices;
@@ -301,8 +304,6 @@ std::vector GenerateQueries(
auto sample_distribution_unique = SampleDistribution(
loaded_sample_set.sample_distribution_end, sample_stride, &sample_rng);
- // FIXME: Only used for v2.0 3D-UNet KiTS19 SingleStream
- // TODO: Need to consolidate the code for any generic usage after v2.0
auto sample_distribution_equal_issue =
SampleDistributionEqualIssue(min_queries,
loaded_samples.size(),
@@ -312,14 +313,24 @@ std::vector GenerateQueries(
ScheduleDistribution(settings.target_qps);
// When sample_concatenate_permutation is turned on, pad to a multiple of the
- // complete dataset to ensure complete fairness.
- // FIXME: Only override this for Offline; fix after v2.0
- if (settings.sample_concatenate_permutation &&
- scenario == TestScenario::Offline &&
- samples_per_query % loaded_samples.size() != 0) {
- size_t pad_size =
+ // complete dataset to ensure fairness.
+ auto enable_equal_issue = settings.sample_concatenate_permutation;
+ if (mode != TestMode::AccuracyOnly && enable_equal_issue)
+ {
+ if (scenario == TestScenario::Offline &&
+ samples_per_query % loaded_samples.size() != 0)
+ {
+ // In offline mode, we pad samples_per_query
+ size_t pad_size =
(loaded_samples.size() - samples_per_query % loaded_samples.size());
- samples_per_query += pad_size;
+ samples_per_query += pad_size;
+ }
+ else if (min_queries % loaded_samples.size() != 0)
+ {
+ // In Server, SingleStream, MultiStream mode, the min_queries should be padded
+ size_t pad_size = (loaded_samples.size() - min_queries % loaded_samples.size());
+ min_queries += pad_size;
+ }
}
std::vector samples(samples_per_query);
@@ -375,16 +386,12 @@ std::vector GenerateQueries(
}
}
} else {
- // FIXME: only used for v2.0 3D-UNet KiTS19 SingleStream
- // TODO: consolidate after v2.0
- auto equal_issue = settings.sample_concatenate_permutation &&
- scenario == TestScenario::SingleStream;
for (auto& s : samples) {
s = loaded_samples[settings.performance_issue_unique
? sample_distribution_unique(sample_rng)
: settings.performance_issue_same
? same_sample
- : equal_issue
+ : enable_equal_issue
? sample_distribution_equal_issue(sample_rng)
: sample_distribution(sample_rng)];
}
@@ -392,6 +399,13 @@ std::vector GenerateQueries(
queries.emplace_back(samples, timestamp, response_delegate, sequence_gen);
prev_timestamp = timestamp;
timestamp += schedule_distribution(schedule_rng);
+ // In equal_issue mode, the min_queries will be bumped up by a multiple of the dataset size
+ // if the test time has not met the threshold.
+ if (enable_equal_issue && (queries.size() >= min_queries) &&
+ (prev_timestamp < gen_duration) && (scenario != TestScenario::Offline))
+ {
+ min_queries += loaded_samples.size();
+ }
}
// See if we need to create a "remainder" query for offline+accuracy to
@@ -524,6 +538,9 @@ PerformanceResult IssueQueries(SystemUnderTest* sut,
std::vector first_token_latencies(
GlobalLogger().GetTokenLatencies(expected_latencies));
+ std::vector time_per_output_token_arr(
+ GlobalLogger().GetTimePerOutputToken(expected_latencies));
+
std::vector tokens_per_sample(
GlobalLogger().GetTokensPerSample(expected_latencies));
@@ -583,6 +600,7 @@ PerformanceResult IssueQueries(SystemUnderTest* sut,
final_query_all_samples_done_time,
TokenPerformanceResults{
first_token_latencies,
+ time_per_output_token_arr,
tokens_per_sample
}
};
diff --git a/loadgen/logging.cc b/loadgen/logging.cc
index d990e7c4d..d33074c01 100644
--- a/loadgen/logging.cc
+++ b/loadgen/logging.cc
@@ -279,19 +279,25 @@ void AsyncLog::StopTrace() {
}
void AsyncLog::LogAccuracy(uint64_t seq_id, const QuerySampleIndex qsl_idx,
- const LogBinaryAsHexString& response) {
+ const LogBinaryAsHexString& response,
+ int64_t n_tokens = 0) {
std::unique_lock lock(log_mutex_);
if (!accuracy_out_) {
return;
}
*accuracy_out_ << (accuracy_needs_comma_ ? ",\n{ " : "\n{ ");
- if (!use_tokens_ || !needs_first_token_){
+ if (!use_tokens_){
LogArgs(accuracy_out_, "seq_id", seq_id, "qsl_idx", qsl_idx, "data",
response);
- } else {
+ } else if (!needs_first_token_)
+ {
+ LogArgs(accuracy_out_, "seq_id", seq_id, "qsl_idx", qsl_idx, "data",
+ response, "token_count", n_tokens);
+ }
+ else {
const size_t i = seq_id - latencies_first_sample_sequence_id_;
LogArgs(accuracy_out_, "seq_id", seq_id, "qsl_idx", qsl_idx, "data",
- response, "token_data", token_records_[i]);
+ response, "token_data", token_records_[i], "token_count", n_tokens);
}
*accuracy_out_ << " }";
@@ -349,6 +355,7 @@ void AsyncLog::RestartLatencyRecording(uint64_t first_sample_sequence_id,
latencies_.reserve(latencies_to_reserve);
token_latencies_.reserve(latencies_to_reserve);
tokens_per_sample_.reserve(latencies_to_reserve);
+ time_per_output_token_.reserve(latencies_to_reserve);
}
void AsyncLog::RecordSampleCompletion(uint64_t sample_sequence_id,
@@ -430,15 +437,40 @@ void AsyncLog::RecordSampleCompletion(uint64_t sample_sequence_id,
// If the SUT recorded the wrong sample, the test will hang and see
// the error above.
return;
- } else if (n_tokens == 0){
+ }
+ if (n_tokens == 0){
MLPERF_LOG_ERROR_SYNC(GlobalLogger(), "error_runtime",
"n_tokens argument missing or attempted to record 0 as number of tokens");
} else if (n_tokens < 0){
MLPERF_LOG_ERROR_SYNC(GlobalLogger(), "error_runtime",
"Attempted to record a negative number of tokens");
n_tokens = 0;
+ } else if (n_tokens == 1){
+ MLPERF_LOG_ERROR_SYNC(GlobalLogger(), "error_runtime",
+        "Number of tokens needs to be greater than 1");
+ n_tokens = 0;
+ }
+ if (time_per_output_token_.size() <= i){
+ time_per_output_token_.resize(i + 1, kInvalidLatency);
+ } else if (time_per_output_token_[i] != kInvalidLatency) {
+ // Call LogErrorSync here since this kind of error could result in a
+ // segfault in the near future.
+ #if USE_NEW_LOGGING_FORMAT
+ MLPERF_LOG_ERROR_SYNC(GlobalLogger(), "error_runtime",
+ "Attempted to complete a sample twice.");
+ #else
+ GlobalLogger().LogErrorSync("Attempted to complete a sample twice.");
+ #endif
+
+ // Return without recording the latency again to avoid potentially
+ // ending the test before the SUT is actually done, which could result
+ // in a segfault.
+ // If the SUT recorded the wrong sample, the test will hang and see
+ // the error above.
+ return;
}
tokens_per_sample_[i] = n_tokens;
+ time_per_output_token_[i] = (latency - token_latencies_[i]) / (n_tokens - 1);
}
latencies_[i] = latency;
latencies_recorded_++;
@@ -572,6 +604,12 @@ std::vector AsyncLog::GetTokenLatencies(size_t expected_coun
return token_latencies;
}
+std::vector AsyncLog::GetTimePerOutputToken(size_t expected_count){
+ std::vector tpot_latencies;
+ tpot_latencies.swap(time_per_output_token_);
+ return tpot_latencies;
+}
+
std::vector AsyncLog::GetTokensPerSample(size_t expected_count) {
std::vector tokens_per_sample;
tokens_per_sample.swap(tokens_per_sample_);
@@ -908,6 +946,10 @@ std::vector Logger::GetTokenLatencies(
size_t expected_count) {
return async_logger_.GetTokenLatencies(expected_count);
}
+std::vector Logger::GetTimePerOutputToken(
+ size_t expected_count) {
+ return async_logger_.GetTimePerOutputToken(expected_count);
+}
std::vector Logger::GetTokensPerSample(
size_t expected_count) {
return async_logger_.GetTokensPerSample(expected_count);
diff --git a/loadgen/logging.h b/loadgen/logging.h
index d10f574d6..e62825859 100644
--- a/loadgen/logging.h
+++ b/loadgen/logging.h
@@ -226,7 +226,7 @@ class AsyncLog {
void SetCurrentPidTid(uint64_t pid, uint64_t tid);
void LogAccuracy(uint64_t seq_id, const QuerySampleIndex qsl_idx,
- const LogBinaryAsHexString& response);
+ const LogBinaryAsHexString& response, int64_t n_tokens);
void CacheToken(uint64_t seq_id, const LogBinaryAsHexString& response);
template
@@ -322,6 +322,7 @@ class AsyncLog {
QuerySampleLatency latency);
std::vector GetLatenciesBlocking(size_t expected_count);
std::vector GetTokenLatencies(size_t expected_count);
+ std::vector GetTimePerOutputToken(size_t expected_count);
std::vector GetTokensPerSample(size_t expected_count);
PerfClock::time_point GetMaxCompletionTime();
QuerySampleLatency GetMaxLatencySoFar();
@@ -386,6 +387,7 @@ class AsyncLog {
uint64_t latencies_first_sample_sequence_id_ = 0;
std::vector latencies_;
std::vector token_latencies_;
+ std::vector time_per_output_token_;
std::vector token_records_;
std::vector tokens_per_sample_;
QuerySampleLatency max_latency_ = 0;
@@ -421,6 +423,7 @@ class Logger {
size_t latencies_to_reserve);
std::vector GetLatenciesBlocking(size_t expected_count);
std::vector GetTokenLatencies(size_t expected_count);
+ std::vector GetTimePerOutputToken(size_t expected_count);
std::vector GetTokensPerSample(size_t expected_count);
PerfClock::time_point GetMaxCompletionTime();
QuerySampleLatency GetMaxLatencySoFar();
diff --git a/loadgen/results.cc b/loadgen/results.cc
index 21fd2c90a..445de8901 100644
--- a/loadgen/results.cc
+++ b/loadgen/results.cc
@@ -97,8 +97,15 @@ void PerformanceSummary::ProcessTokenLatencies() {
accumulated_first_token_latency += latency;
}
first_token_latency_mean = accumulated_first_token_latency / sample_count;
+ QuerySampleLatency accumulated_tpot = 0;
+ for (auto latency : pr.token_results.time_per_output_token_arr) {
+ accumulated_tpot += latency;
+ }
+ time_per_output_token_mean = accumulated_tpot / sample_count;
std::sort(pr.token_results.first_token_latencies.begin(),
pr.token_results.first_token_latencies.end());
+ std::sort(pr.token_results.time_per_output_token_arr.begin(),
+ pr.token_results.time_per_output_token_arr.end());
token_target_latency_percentile.sample_latency =
pr.token_results.first_token_latencies[sample_count * token_target_latency_percentile.percentile];
@@ -110,6 +117,16 @@ void PerformanceSummary::ProcessTokenLatencies() {
lp.sample_latency = pr.token_results.first_token_latencies[sample_count * lp.percentile];
}
+ target_tpot_percentile.sample_latency =
+ pr.token_results.time_per_output_token_arr[sample_count * target_tpot_percentile.percentile];
+ time_per_output_token_min = pr.token_results.time_per_output_token_arr.front();
+ time_per_output_token_max = pr.token_results.time_per_output_token_arr.back();
+ for (auto& lp : tpot_percentiles) {
+ assert(lp.percentile >= 0.0);
+ assert(lp.percentile < 1.0);
+ lp.sample_latency = pr.token_results.time_per_output_token_arr[sample_count * lp.percentile];
+ }
+
if (settings.scenario == TestScenario::Server) {
// TODO: Maybe another target latency needs to be added?
QuerySampleLatency max_latency = settings.target_latency.count() + 1;
@@ -121,10 +138,12 @@ void PerformanceSummary::ProcessTokenLatencies() {
}
-bool PerformanceSummary::EarlyStopping(std::string* recommendation) {
+bool PerformanceSummary::EarlyStopping(std::string* recommendation, int64_t queries_issued,
+ std::vector* sample_latencies,
+ std::vector* query_latencies,
+ std::chrono::nanoseconds target_latency) {
recommendation->clear();
- int64_t queries_issued = pr.queries_issued;
MinPassingQueriesFinder find_min_passing;
double confidence = 0.99;
double tolerance = 0.0;
@@ -155,7 +174,7 @@ bool PerformanceSummary::EarlyStopping(std::string* recommendation) {
}
}
QuerySampleLatency percentile_estimate =
- pr.sample_latencies[queries_issued - t];
+ (*sample_latencies)[queries_issued - t];
*recommendation =
" * Processed at least " + std::to_string(h_min + 1) + " queries (" +
std::to_string(queries_issued) + ").\n" + " * Would discard " +
@@ -187,7 +206,7 @@ bool PerformanceSummary::EarlyStopping(std::string* recommendation) {
break;
}
}
- percentile_estimate = pr.sample_latencies[queries_issued - t];
+ percentile_estimate = (*sample_latencies)[queries_issued - t];
*recommendation +=
"\n * Early stopping " +
DoubleToString(multi_stream_percentile * 100, 0) +
@@ -198,9 +217,9 @@ bool PerformanceSummary::EarlyStopping(std::string* recommendation) {
}
case TestScenario::Server: {
int64_t t =
- std::count_if(pr.sample_latencies.begin(), pr.sample_latencies.end(),
+ std::count_if((*sample_latencies).begin(), (*sample_latencies).end(),
[=](auto const& latency) {
- return latency > settings.target_latency.count();
+ return latency > target_latency.count();
});
int64_t h = find_min_passing(t, target_latency_percentile.percentile,
tolerance, confidence);
@@ -239,7 +258,7 @@ bool PerformanceSummary::EarlyStopping(std::string* recommendation) {
}
}
QuerySampleLatency percentile_estimate =
- pr.query_latencies[queries_issued - t];
+ (*query_latencies)[queries_issued - t];
*recommendation =
" * Processed at least " + std::to_string(h_min + 1) + " queries (" +
std::to_string(queries_issued) + ").\n" + " * Would discard " +
@@ -317,10 +336,28 @@ bool PerformanceSummary::PerfConstraintsMet(std::string* recommendation) {
break;
case TestScenario::Server:
ProcessLatencies();
- if (target_latency_percentile.sample_latency >
- settings.target_latency.count()) {
- *recommendation = "Reduce target QPS to improve latency.";
- perf_constraints_met = false;
+ if (!settings.use_token_latencies){
+ if (target_latency_percentile.sample_latency >
+ settings.target_latency.count()) {
+ *recommendation = "Reduce target QPS to improve latency.";
+ perf_constraints_met = false;
+ }
+ } else {
+ if ( token_target_latency_percentile.sample_latency >
+ settings.server_ttft_latency) {
+          *recommendation = "TTFT constraint not met: Reduce target QPS to improve latency.";
+ perf_constraints_met = false;
+ }
+
+ if ( target_tpot_percentile.sample_latency >
+ settings.server_tpot_latency) {
+ if (recommendation->empty()){
+            *recommendation = "TPOT constraint not met: Reduce target QPS to improve latency.";
+ } else {
+            recommendation->append("\n * TPOT constraint not met: Reduce target QPS to improve latency.");
+ }
+ perf_constraints_met = false;
+ }
}
break;
case TestScenario::Offline:
@@ -362,10 +399,18 @@ void PerformanceSummary::LogSummary(AsyncSummary& summary) {
// the 1 second time point; but that would be the 1001th sample in
// the stream. Given the first 1001 queries, the QPS is
// 1000 queries / 1 second.
+ // TODO: make a more permanent solution
+ if (settings.use_token_latencies){
+ double qps_as_completed =
+ (sample_count - 1) / pr.final_query_all_samples_done_time;
+ summary("Completed samples per second : ",
+ DoubleToString(qps_as_completed));
+ } else {
double qps_as_scheduled =
(sample_count - 1) / pr.final_query_scheduled_time;
summary("Scheduled samples per second : ",
DoubleToString(qps_as_scheduled));
+ }
break;
}
case TestScenario::Offline: {
@@ -394,16 +439,38 @@ void PerformanceSummary::LogSummary(AsyncSummary& summary) {
summary("Tokens per second: ", tokens_per_second);
break;
}
+ case TestScenario::Server:
+ break;
}
}
std::string min_duration_recommendation;
std::string perf_constraints_recommendation;
std::string early_stopping_recommendation;
+ std::string early_stopping_ttft_recommendation;
+ std::string early_stopping_tpot_recommendation;
bool min_duration_met = MinDurationMet(&min_duration_recommendation);
bool min_queries_met = MinQueriesMet() && MinSamplesMet();
- bool early_stopping_met = EarlyStopping(&early_stopping_recommendation);
+ bool early_stopping_met = true;
+ if (!settings.use_token_latencies){
+ early_stopping_met = EarlyStopping(&early_stopping_recommendation,
+ pr.queries_issued,
+ &pr.sample_latencies,
+ &pr.query_latencies,
+ settings.target_latency);
+ } else {
+ early_stopping_met = EarlyStopping(&early_stopping_tpot_recommendation,
+ pr.queries_issued,
+ &pr.token_results.time_per_output_token_arr,
+ &pr.query_latencies,
+ std::chrono::nanoseconds(settings.server_tpot_latency)) &&
+ EarlyStopping(&early_stopping_ttft_recommendation,
+ pr.queries_issued,
+ &pr.token_results.first_token_latencies,
+ &pr.query_latencies,
+ std::chrono::nanoseconds(settings.server_ttft_latency));
+ }
bool perf_constraints_met =
PerfConstraintsMet(&perf_constraints_recommendation);
bool all_constraints_met = min_duration_met && min_queries_met &&
@@ -435,8 +502,15 @@ void PerformanceSummary::LogSummary(AsyncSummary& summary) {
if (settings.scenario == TestScenario::SingleStream ||
settings.scenario == TestScenario::Server ||
settings.scenario == TestScenario::MultiStream) {
- summary("Early Stopping Result:");
- summary(early_stopping_recommendation);
+ if (!settings.use_token_latencies){
+ summary("Early Stopping Result:");
+ summary(early_stopping_recommendation);
+ } else {
+ summary("TTFT Early Stopping Result:");
+ summary(early_stopping_ttft_recommendation);
+ summary("TPOT Early Stopping Result:");
+ summary(early_stopping_tpot_recommendation);
+ }
}
summary(
@@ -452,10 +526,18 @@ void PerformanceSummary::LogSummary(AsyncSummary& summary) {
summary("QPS w/o loadgen overhead : " + DoubleToString(qps_wo_lg));
summary("");
} else if (settings.scenario == TestScenario::Server) {
- double qps_as_completed =
+ // TODO: make a more permanent solution
+ if (!settings.use_token_latencies){
+ double qps_as_completed =
(sample_count - 1) / pr.final_query_all_samples_done_time;
- summary("Completed samples per second : ",
- DoubleToString(qps_as_completed));
+ summary("Completed samples per second : ",
+ DoubleToString(qps_as_completed));
+ } else {
+ double qps_as_scheduled =
+ (sample_count - 1) / pr.final_query_scheduled_time;
+ summary("Scheduled samples per second : ",
+ DoubleToString(qps_as_scheduled));
+ }
summary("");
} else if (settings.scenario == TestScenario::MultiStream) {
summary("Per-query latency: ");
@@ -490,17 +572,26 @@ void PerformanceSummary::LogSummary(AsyncSummary& summary) {
} else if (settings.scenario == TestScenario::Server) {
double tps_as_completed =
token_count / pr.final_query_all_samples_done_time;
- summary("Completed tokens per second : ",
+ summary("Completed tokens per second : ",
DoubleToString(tps_as_completed));
}
if (settings.scenario != TestScenario::Offline) {
- summary("Min First Token latency (ns) : ", first_token_latency_min);
- summary("Max First Token latency (ns) : ", first_token_latency_max);
- summary("Mean First Token latency (ns) : ", first_token_latency_mean);
+ summary("Min First Token latency (ns) : ", first_token_latency_min);
+ summary("Max First Token latency (ns) : ", first_token_latency_max);
+ summary("Mean First Token latency (ns) : ", first_token_latency_mean);
for (auto& lp : token_latency_percentiles) {
summary(
- DoubleToString(lp.percentile * 100) + " percentile latency (ns) : ",
+ DoubleToString(lp.percentile * 100) + " percentile first token latency (ns) : ",
+ lp.sample_latency);
+ }
+ summary("");
+ summary("Min Time to Output Token (ns) : ", time_per_output_token_min);
+ summary("Max Time to Output Token (ns) : ", time_per_output_token_max);
+ summary("Mean Time to Output Token (ns) : ", time_per_output_token_mean);
+ for (auto& lp : tpot_percentiles) {
+ summary(
+ DoubleToString(lp.percentile * 100) + " percentile time to output token (ns) : ",
lp.sample_latency);
}
}
@@ -522,11 +613,31 @@ void PerformanceSummary::LogDetail(AsyncDetail& detail) {
std::string min_duration_recommendation;
std::string perf_constraints_recommendation;
std::string early_stopping_recommendation;
+ std::string early_stopping_ttft_recommendation;
+ std::string early_stopping_tpot_recommendation;
bool min_duration_met = MinDurationMet(&min_duration_recommendation);
bool min_queries_met = MinQueriesMet() && MinSamplesMet();
bool perf_constraints_met =
PerfConstraintsMet(&perf_constraints_recommendation);
- bool early_stopping_met = EarlyStopping(&early_stopping_recommendation);
+ bool early_stopping_met = true;
+ if (!settings.use_token_latencies){
+ early_stopping_met = EarlyStopping(&early_stopping_recommendation,
+ pr.queries_issued,
+ &pr.sample_latencies,
+ &pr.query_latencies,
+ settings.target_latency);
+ } else {
+ early_stopping_met = EarlyStopping(&early_stopping_tpot_recommendation,
+ pr.queries_issued,
+ &pr.token_results.time_per_output_token_arr,
+ &pr.query_latencies,
+ std::chrono::nanoseconds(settings.server_tpot_latency)) &&
+ EarlyStopping(&early_stopping_ttft_recommendation,
+ pr.queries_issued,
+ &pr.token_results.first_token_latencies,
+ &pr.query_latencies,
+ std::chrono::nanoseconds(settings.server_ttft_latency));
+ }
bool all_constraints_met = min_duration_met && min_queries_met &&
perf_constraints_met && early_stopping_met;
@@ -554,8 +665,16 @@ void PerformanceSummary::LogDetail(AsyncDetail& detail) {
}
std::replace(early_stopping_recommendation.begin(),
early_stopping_recommendation.end(), '\n', ' ');
- MLPERF_LOG(detail, "early_stopping_result", early_stopping_recommendation);
-
+ if (!settings.use_token_latencies){
+ MLPERF_LOG(detail, "early_stopping_result", early_stopping_recommendation);
+ } else{
+ std::replace(early_stopping_ttft_recommendation.begin(),
+ early_stopping_ttft_recommendation.end(), '\n', ' ');
+ std::replace(early_stopping_tpot_recommendation.begin(),
+ early_stopping_tpot_recommendation.end(), '\n', ' ');
+ MLPERF_LOG(detail, "early_stopping_ttft_result", early_stopping_ttft_recommendation);
+ MLPERF_LOG(detail, "early_stopping_tpot_result", early_stopping_tpot_recommendation);
+ }
// Report number of queries
MLPERF_LOG(detail, "result_query_count", query_count);
if (settings.scenario == TestScenario::Server) {
@@ -636,7 +755,7 @@ void PerformanceSummary::LogDetail(AsyncDetail& detail) {
MLPERF_LOG(detail, "result_first_token_mean_latency_ns", first_token_latency_mean);
for (auto& lp : token_latency_percentiles) {
MLPERF_LOG(detail,
- "result_" + DoubleToString(lp.percentile * 100) +
+ "result_first_token_" + DoubleToString(lp.percentile * 100) +
"_percentile_latency_ns",
lp.sample_latency);
}
@@ -644,8 +763,15 @@ void PerformanceSummary::LogDetail(AsyncDetail& detail) {
double tps_wo_lg= ((double)token_count) / (sample_latency_mean * sample_count);
MLPERF_LOG(detail, "result_token_throughput_with_loadgen_overhead", tps_w_lg);
MLPERF_LOG(detail, "result_token_throughput", tps_wo_lg);
- double tpot = sample_count * (sample_latency_mean - first_token_latency_mean) / ((double)token_count);
- MLPERF_LOG(detail, "result_time_to_output_token", tpot);
+ for (auto& lp : tpot_percentiles) {
+ MLPERF_LOG(detail,
+ "result_time_per_output_token_" + DoubleToString(lp.percentile * 100) +
+ "_percentile_ns",
+ lp.sample_latency);
+ }
+ MLPERF_LOG(detail, "result_time_to_output_token_min", time_per_output_token_min);
+ MLPERF_LOG(detail, "result_time_to_output_token_max", time_per_output_token_max);
+ MLPERF_LOG(detail, "result_time_to_output_token_mean", time_per_output_token_mean);
} else {
double tokens_per_second = token_count / pr.max_latency;
MLPERF_LOG(detail, "result_tokens_per_second", tokens_per_second);
diff --git a/loadgen/results.h b/loadgen/results.h
index 69825a9d3..38bbe32d4 100644
--- a/loadgen/results.h
+++ b/loadgen/results.h
@@ -29,6 +29,7 @@ namespace loadgen {
/// token based metrics
struct TokenPerformanceResults {
std::vector<QuerySampleLatency> first_token_latencies;
+ std::vector<QuerySampleLatency> time_per_output_token_arr;
std::vector<int64_t> tokens_per_sample;
};
@@ -86,11 +87,17 @@ struct PerformanceSummary {
QuerySampleLatency first_token_latency_min;
QuerySampleLatency first_token_latency_max;
QuerySampleLatency first_token_latency_mean;
+ QuerySampleLatency time_per_output_token_min;
+ QuerySampleLatency time_per_output_token_max;
+ QuerySampleLatency time_per_output_token_mean;
// Latency token target percentile
PercentileEntry token_target_latency_percentile{settings.target_latency_percentile};
PercentileEntry token_latency_percentiles[6] = {{.50}, {.90}, {.95},
{.97}, {.99}, {.999}};
+ PercentileEntry target_tpot_percentile{settings.target_latency_percentile};
+ PercentileEntry tpot_percentiles[6] = {{.50}, {.90}, {.95},
+ {.97}, {.99}, {.999}};
#if defined(_WIN32) || defined(WIN32) || defined(_WIN64) || defined(WIN64)
// MSVC complains if there is no explicit constructor.
@@ -104,7 +111,10 @@ struct PerformanceSummary {
void ProcessTokenLatencies();
bool MinDurationMet(std::string* recommendation);
- bool EarlyStopping(std::string* recommendation);
+ bool EarlyStopping(std::string* recommendation, int64_t queries_issued,
+ std::vector<QuerySampleLatency>* sample_latencies,
+ std::vector<QuerySampleLatency>* query_latencies,
+ std::chrono::nanoseconds target_latency);
bool MinQueriesMet();
bool MinSamplesMet();
bool HasPerfConstraints();
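
The reworked EarlyStopping signature lets the same check run over ordinary sample latencies or, when token latencies are enabled, once over the TTFT values and once over the TPOT values. The minimal sketch below only mirrors that dispatch; the dictionary keys follow the fields touched in this diff, and the `check` lambda is a placeholder stand-in, not the real early-stopping statistic.

```python
# Sketch of the dispatch performed in LogSummary/LogDetail (illustrative only).
def early_stopping_met(settings, pr, early_stopping):
    if not settings["use_token_latencies"]:
        return early_stopping(pr["sample_latencies"], settings["target_latency"])
    return (early_stopping(pr["time_per_output_token_arr"], settings["server_tpot_latency"])
            and early_stopping(pr["first_token_latencies"], settings["server_ttft_latency"]))

# Placeholder check: "met" if the (naive) 99th-percentile latency is under the
# target. Units are arbitrary in this toy; the real check is statistical.
check = lambda latencies, target: sorted(latencies)[int(0.99 * (len(latencies) - 1))] <= target
print(early_stopping_met(
    {"use_token_latencies": True, "server_tpot_latency": 200, "server_ttft_latency": 2000},
    {"time_per_output_token_arr": [150, 180, 190], "first_token_latencies": [900, 1200, 1500]},
    check))
```
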
diff --git a/loadgen/test_settings.h b/loadgen/test_settings.h
index ccae5327a..b0018380d 100644
--- a/loadgen/test_settings.h
+++ b/loadgen/test_settings.h
@@ -264,6 +264,9 @@ struct TestSettings {
uint64_t performance_sample_count_override = 0;
/// \brief Measure token latencies
bool use_token_latencies = false;
+ /// Token latency parameters
+ uint64_t server_ttft_latency = 100000000;
+ uint64_t server_tpot_latency = 100000000;
/**@}*/
};
diff --git a/loadgen/test_settings_internal.cc b/loadgen/test_settings_internal.cc
index 0a4b9af6e..2bcf62c29 100644
--- a/loadgen/test_settings_internal.cc
+++ b/loadgen/test_settings_internal.cc
@@ -49,7 +49,9 @@ TestSettingsInternal::TestSettingsInternal(
performance_issue_same_index(requested.performance_issue_same_index),
performance_sample_count(0),
sample_concatenate_permutation(false),
- use_token_latencies(requested.use_token_latencies){
+ use_token_latencies(requested.use_token_latencies),
+ server_ttft_latency(requested.server_ttft_latency),
+ server_tpot_latency(requested.server_tpot_latency){
// Target QPS, target latency, and max_async_queries.
switch (requested.scenario) {
case TestScenario::SingleStream:
@@ -330,6 +332,14 @@ void LogRequestedTestSettings(const TestSettings &s) {
s.performance_issue_same_index);
MLPERF_LOG(detail, "requested_performance_sample_count_override",
s.performance_sample_count_override);
+ // Token latencies specific values
+ if (s.use_token_latencies){
+ MLPERF_LOG(detail, "requested_use_token_latencies", s.use_token_latencies);
+ if (s.scenario != TestScenario::Offline){
+ MLPERF_LOG(detail, "requested_server_ttft_latency", s.server_ttft_latency);
+ MLPERF_LOG(detail, "requested_server_tpot_latency", s.server_tpot_latency);
+ }
+ }
#else
detail("");
detail("Requested Settings:");
@@ -472,7 +482,12 @@ void TestSettingsInternal::LogAllSettings() const {
void TestSettingsInternal::LogSummary(AsyncSummary &summary) const {
summary("samples_per_query : ", samples_per_query);
summary("target_qps : ", target_qps);
- summary("target_latency (ns): ", target_latency.count());
+ if (!use_token_latencies){
+ summary("target_latency (ns): ", target_latency.count());
+ } else {
+ summary("ttft_latency (ns): ", server_ttft_latency);
+ summary("tpot_latency (ns): ", server_tpot_latency);
+ }
summary("max_async_queries : ", max_async_queries);
summary("min_duration (ms): ", min_duration.count());
summary("max_duration (ms): ", max_duration.count());
@@ -673,8 +688,15 @@ int TestSettings::FromConfig(const std::string &path, const std::string &model,
lookupkv(model, scenario, "test05_sample_index_rng_seed", &test05_sample_index_rng_seed,
nullptr);
lookupkv(model, scenario, "test05_schedule_rng_seed", &test05_schedule_rng_seed, nullptr);
- if (lookupkv(model, scenario, "use_token_latencies", &val, nullptr))
+
+ // keys that apply to token metrics
+ if (lookupkv(model, scenario, "use_token_latencies", &val, nullptr)){
use_token_latencies = (val == 1) ? true : false;
+ if (use_token_latencies){
+ lookupkv(model, "Server", "ttft_latency", &server_ttft_latency, nullptr, 1000 * 1000);
+ lookupkv(model, "Server", "tpot_latency", &server_tpot_latency, nullptr, 1000 * 1000);
+ }
+ }
// keys that apply to SingleStream
lookupkv(model, "SingleStream", "target_latency_percentile", nullptr,
diff --git a/loadgen/test_settings_internal.h b/loadgen/test_settings_internal.h
index 6bc1c3cf3..5222f3156 100644
--- a/loadgen/test_settings_internal.h
+++ b/loadgen/test_settings_internal.h
@@ -83,6 +83,8 @@ struct TestSettingsInternal {
bool sample_concatenate_permutation;
bool use_token_latencies = false;
+ int64_t server_ttft_latency;
+ int64_t server_tpot_latency;
};
/// \brief A namespace of collections of FindPeakPerformance helper functions,
diff --git a/mlperf.conf b/mlperf.conf
index fe743a8dd..7e286b565 100644
--- a/mlperf.conf
+++ b/mlperf.conf
@@ -12,6 +12,8 @@ bert.*.performance_sample_count_override = 10833
dlrm.*.performance_sample_count_override = 204800
dlrm-v2.*.performance_sample_count_override = 204800
rnnt.*.performance_sample_count_override = 2513
+gptj.*.performance_sample_count_override = 13368
+llama2-70b.*.performance_sample_count_override = 24576
stable-diffusion-xl.*.performance_sample_count_override = 5000
# set to 0 to let entire sample set to be performance sample
3d-unet.*.performance_sample_count_override = 0
@@ -28,27 +30,25 @@ stable-diffusion-xl.*.performance_sample_count_override = 5000
*.SingleStream.target_latency_percentile = 90
*.SingleStream.min_duration = 600000
-#*.SingleStream.min_query_count = 1024
*.MultiStream.target_latency_percentile = 99
*.MultiStream.samples_per_query = 8
*.MultiStream.min_duration = 600000
-#*.MultiStream.min_query_count = 270336
*.MultiStream.min_query_count = 662
retinanet.MultiStream.target_latency = 528
-
-# 3D-UNet uses equal issue mode
+# 3D-UNet uses equal issue mode because it has non-uniform inputs
3d-unet.*.sample_concatenate_permutation = 1
-# GPT-J uses equal issue mode for Single-Stream
+# LLM benchmarks have non-uniform inputs and outputs, and use equal issue mode for all latency scenarios
+gptj.Server.sample_concatenate_permutation = 1
gptj.SingleStream.sample_concatenate_permutation = 1
+llama2-70b.Server.sample_concatenate_permutation = 1
*.Server.target_latency = 10
*.Server.target_latency_percentile = 99
*.Server.target_duration = 0
*.Server.min_duration = 600000
-#*.Server.min_query_count = 270336
resnet50.Server.target_latency = 15
retinanet.Server.target_latency = 100
bert.Server.target_latency = 130
@@ -58,7 +58,9 @@ rnnt.Server.target_latency = 1000
gptj.Server.target_latency = 20000
stable-diffusion-xl.Server.target_latency = 20000
# llama2-70b Server scenario requires two latency constraints
-llama2-70b.Server.target_latency = 2000
+llama2-70b.*.use_token_latencies = 1
+# Only ttft and tpot are tracked for the llama2-70b benchmark, so target_latency = 0
+llama2-70b.Server.target_latency = 0
llama2-70b.Server.ttft_latency = 2000
llama2-70b.Server.tpot_latency = 200
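
Like the other target_latency keys, the ttft_latency and tpot_latency values above are expressed in milliseconds; the FromConfig change earlier in this patch multiplies them by 1000 * 1000 when storing them as nanoseconds in TestSettings. A small sketch of that interpretation, using the values from the config lines above (the dictionary is only an illustration of the key/value pairs, not the real parser):

```python
# How the new llama2-70b keys end up in TestSettings (illustrative sketch; the
# real parsing is the lookupkv(..., 1000 * 1000) calls in test_settings_internal.cc).
MS_TO_NS = 1000 * 1000

conf = {
    "llama2-70b.*.use_token_latencies": 1,
    "llama2-70b.Server.ttft_latency": 2000,  # milliseconds
    "llama2-70b.Server.tpot_latency": 200,   # milliseconds
}

use_token_latencies = conf["llama2-70b.*.use_token_latencies"] == 1
server_ttft_latency = conf["llama2-70b.Server.ttft_latency"] * MS_TO_NS  # 2e9 ns
server_tpot_latency = conf["llama2-70b.Server.tpot_latency"] * MS_TO_NS  # 2e8 ns
print(use_token_latencies, server_ttft_latency, server_tpot_latency)
```
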
diff --git a/recommendation/dlrm_v2/pytorch/README.md b/recommendation/dlrm_v2/pytorch/README.md
index 313bf1e1f..12937c0c8 100755
--- a/recommendation/dlrm_v2/pytorch/README.md
+++ b/recommendation/dlrm_v2/pytorch/README.md
@@ -67,36 +67,50 @@ cd $HOME/mlcommons/inference/loadgen
CFLAGS="-std=c++14" python setup.py develop --user
```
+
### Downloading model weights
-File name | framework | Size in bytes (`du *`) | MD5 hash (`md5sum *`)
--|-|-|-
+framework | Size in bytes (`du *`) | MD5 hash (`md5sum *`)
+-|-|-
-N/A | pytorch | <2GB | -
+pytorch | <2GB | -
-[weight_sharded](https://cloud.mlcommons.org/index.php/s/XzfSeLgW8FYfR3S/download) | pytorch | 97.31GB | -
+ pytorch | 97.31GB | -
+
+#### CM method
+
+The following MLCommons CM commands can be used to programmatically download the model checkpoint.
-You can download the weights by running:
```
-wget https://cloud.mlcommons.org/index.php/s/XzfSeLgW8FYfR3S/download -O weights.zip
-unzip weights.zip
+pip install cmind
+cm pull repo mlcommons@ck
+cm run script --tags=get,ml-model,dlrm,_pytorch,_weight_sharded,_rclone -j
```
-(optional) To speed up future downloads, we recommend you save the weights in a bucket (E.g GCP, AWS). For example, after saving the checkpoint in a GCP bucket, you can download the weights faster by running:
+
+#### Manual method
+
+The above command automatically runs a set of Rclone commands to download the data from a Cloudflare R2 bucket. However, if you'd like to run the Rclone commands manually, you can do so as follows:
+
+To run Rclone on Windows, you can download the executable [here](https://rclone.org/install/#windows).
+To install Rclone on Linux/macOS/BSD systems, run:
```
-export BUCKET_NAME=
-cd $HOME/mlcommons/inference/recommendation/dlrm_v2/pytorch/model/
-gsutil -m cp -r "gs://$BUCKET_NAME/model_weights/*" .
+sudo -v ; curl https://rclone.org/install.sh | sudo bash
+```
+Once Rclone is installed, run the following command to authenticate with the bucket:
```
+rclone config create mlc-inference s3 provider=Cloudflare access_key_id=f65ba5eef400db161ea49967de89f47b secret_access_key=fbea333914c292b854f14d3fe232bad6c5407bf0ab1bebf78833c2b359bdfd2b endpoint=https://c2686074cb2caf5cbaf6d134bdba8b47.r2.cloudflarestorage.com
+```
+You can then navigate in the terminal to your desired download directory and run the following command to download the model weights:
-### Downloading dataset
-| Original dataset | download link |
-| ---- | ---- |
-| Criteo Terabyte (day 23) | https://labs.criteo.com/2013/12/download-terabyte-click-logs/ |
+```
+rclone copy mlc-inference:mlcommons-inference-wg-public/model_weights ./model_weights -P
+```
+#### (optional) Saving the weights to a cloud bucket
-1. The Criteo fake dataset can be created in place of the real datasets in order to facilitate debugging and testing. We provide a fake (random) data generator that can be used to quickly generate data samples in a format compatible with the original dataset. Please use the following script in `./tools` to quickly create random samples for the corresponding models, which will be placed into `./fake_criteo` directory
+To speed up future downloads, we recommend you save the weights in a bucket (e.g., GCP or AWS). For example, after saving the checkpoint in a GCP bucket, you can download the weights faster by running:
```
-./make_fake_criteo.sh
-mv ./fake_criteo .. && cd ..
-export DATA_DIR=./fake_criteo
+export BUCKET_NAME=
+cd $HOME/mlcommons/inference/recommendation/dlrm_v2/pytorch/model/
+gsutil -m cp -r "gs://$BUCKET_NAME/model_weights/*" .
```
diff --git a/recommendation/dlrm_v2/pytorch/tools/accuracy-dlrm.py b/recommendation/dlrm_v2/pytorch/tools/accuracy-dlrm.py
index 873f6a0e1..ce662071f 100644
--- a/recommendation/dlrm_v2/pytorch/tools/accuracy-dlrm.py
+++ b/recommendation/dlrm_v2/pytorch/tools/accuracy-dlrm.py
@@ -43,7 +43,9 @@ def get_targets(args, qsl_indices):
with open(args.aggregation_trace_file) as f:
for line in f:
sample_boundaries.append(sample_boundaries[-1] + int(line.split(", ")[2]))
- assert len(sample_boundaries) == len(qsl_indices) + 1, "Number of samples in trace file does not match number of samples in loadgen accuracy log!"
+ if len(sample_boundaries) != len(qsl_indices) + 1:
+ print("Warning: number of samples in trace file ({}) does not match number of samples ({}) in "
+ "loadgen accuracy log!".format(len(sample_boundaries)-1, len(qsl_indices)))
# Get all the ground truth labels in the original order in day_23
print("Parsing ground truth labels from day_23 file...")
ground_truths = []
diff --git a/text_to_image/README.md b/text_to_image/README.md
index 26393f54f..e353bfac0 100644
--- a/text_to_image/README.md
+++ b/text_to_image/README.md
@@ -4,9 +4,10 @@ This is the reference implementation for MLPerf Inference text to image
## Supported Models
-| model | accuracy | dataset | model link | model source | precision | notes |
-| ---- | ---- | ---- | ---- | ---- | ---- | ---- |
-| StableDiffusion | - | Coco2014 | [fp32](https://cloud.mlcommons.org/index.php/s/DjnCSGyNBkWA4Ro) and [f16](https://cloud.mlcommons.org/index.php/s/LCdW5RM6wgGWbxC) | [Hugging Face](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) | fp32 | NCHW |
+| model | accuracy | dataset | model source | precision | notes |
+| ---- | ---- | ---- | ---- | ---- | ---- |
+| StableDiffusion | - | Coco2014 | [Hugging Face](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) | fp32 | NCHW |
+
## Dataset
@@ -47,7 +48,43 @@ CFLAGS="-std=c++14" python setup.py install
### Download model
-We host two checkpoints ([fp32](https://cloud.mlcommons.org/index.php/s/DjnCSGyNBkWA4Ro) and [f16](https://cloud.mlcommons.org/index.php/s/LCdW5RM6wgGWbxC)) that are a snapshot of the [Hugging Face](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) pipeline at the time of the release of the benchmark. Download them and move them to your model path.
+We host two checkpoints (fp32 and fp16) that are a snapshot of the [Hugging Face](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) pipeline at the time of the release of the benchmark. Download them and move them to your model path.
+
+#### CM method
+
+The following MLCommons CM commands can be used to programmatically download the model checkpoints.
+
+```
+pip install cmind
+cm pull repo mlcommons@ck
+cm run script --tags=get,ml-model,sdxl,_fp16,_rclone -j
+cm run script --tags=get,ml-model,sdxl,_fp32,_rclone -j
+```
+#### Manual method
+
+The above commands automatically run a set of Rclone commands to download the data from a Cloudflare R2 bucket. However, if you'd like to run the Rclone commands manually, you can do so as follows:
+
+To run Rclone on Windows, you can download the executable [here](https://rclone.org/install/#windows).
+To install Rclone on Linux/macOS/BSD systems, run:
+```
+sudo -v ; curl https://rclone.org/install.sh | sudo bash
+```
+Once Rclone is installed, run the following command to authenticate with the bucket:
+```
+rclone config create mlc-inference s3 provider=Cloudflare access_key_id=f65ba5eef400db161ea49967de89f47b secret_access_key=fbea333914c292b854f14d3fe232bad6c5407bf0ab1bebf78833c2b359bdfd2b endpoint=https://c2686074cb2caf5cbaf6d134bdba8b47.r2.cloudflarestorage.com
+```
+You can then navigate in the terminal to your desired download directory and run the following commands to download the checkpoints:
+
+**`fp32`**
+```
+rclone copy mlc-inference:mlcommons-inference-wg-public/stable_diffusion_fp32 ./stable_diffusion_fp32 -P
+```
+**`fp16`**
+```
+rclone copy mlc-inference:mlcommons-inference-wg-public/stable_diffusion_fp16 ./stable_diffusion_fp16 -P
+```
+
+#### Move to model path
```bash
mkdir $MODEL_PATH
diff --git a/text_to_image/coco.py b/text_to_image/coco.py
index fa89ba215..b2c9d6dfc 100644
--- a/text_to_image/coco.py
+++ b/text_to_image/coco.py
@@ -174,9 +174,9 @@ def save_images(self, ids, ds):
for id in ids:
caption = ds.get_caption(id)
generated = Image.fromarray(self.results[idx[id]])
- image_path_tmp = f"images/{self.content_ids[id]}.png"
+ image_path_tmp = f"images/{self.content_ids[idx[id]]}.png"
generated.save(image_path_tmp)
- info.append((self.content_ids[id], caption))
+ info.append((self.content_ids[idx[id]], caption))
with open("images/captions.txt", "w+") as f:
for id, caption in info:
f.write(f"{id} {caption}\n")
diff --git a/text_to_image/main.py b/text_to_image/main.py
index 0b93ceb21..32425762f 100644
--- a/text_to_image/main.py
+++ b/text_to_image/main.py
@@ -394,11 +394,18 @@ def main():
count = ds.get_item_count()
# warmup
- ds.load_query_samples([0])
+ synthetic_str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit"
+ latents_pt = torch.rand(ds.latents.shape, dtype=dtype).to(args.device)
+ warmup_samples = [
+ {
+ "input_tokens": ds.preprocess(synthetic_str, model.pipe.tokenizer),
+ "input_tokens_2": ds.preprocess(synthetic_str, model.pipe.tokenizer_2),
+ "latents": latents_pt,
+ }
+ for _ in range(args.max_batchsize)
+ ]
for i in range(5):
- captions, _ = ds.get_samples([0])
- _ = backend.predict(captions)
- ds.unload_query_samples(None)
+ _ = backend.predict(warmup_samples)
scenario = SCENARIO_MAP[args.scenario]
runner_map = {
diff --git a/text_to_image/tools/accuracy_coco.py b/text_to_image/tools/accuracy_coco.py
index ad5c93c0c..f831c4b6e 100644
--- a/text_to_image/tools/accuracy_coco.py
+++ b/text_to_image/tools/accuracy_coco.py
@@ -26,6 +26,7 @@ def get_args():
parser.add_argument("--statistics-path", default=None, help="path to statistics")
parser.add_argument("--verbose", action="store_true", help="verbose messages")
parser.add_argument("--output-file", default="coco-results.json", help="path to output file")
+ parser.add_argument("--compliance-images-path", required=False, help="path to dump 10 stable diffusion xl compliance images")
parser.add_argument("--device", default="cpu", choices=["gpu", "cpu"])
args = parser.parse_args()
return args
@@ -43,7 +44,6 @@ def preprocess_image(img_dir, file_name):
def main():
args = get_args()
-
result_dict = {}
# Load dataset annotations
@@ -63,6 +63,22 @@ def main():
if args.statistics_path is None:
statistics_path = os.path.join(os.path.dirname(__file__), "val2014.npz")
+ # Set compliance images path
+ dump_compliance_images = False
+ if args.compliance_images_path:
+ if not os.path.exists(args.compliance_images_path):
+ os.makedirs(args.compliance_images_path)
+ dump_compliance_images = True
+ compliance_images_idx_list = []
+ with open(os.path.join(os.path.dirname(__file__), "sample_ids.txt"), 'r') as compliance_id_file:
+ for line in compliance_id_file:
+ idx = int(line.strip())
+ compliance_images_idx_list.append(idx)
+ # Dump caption.txt
+ with open(os.path.join(args.compliance_images_path, "captions.txt"), "w+") as caption_file:
+ for idx in compliance_images_idx_list:
+ caption_file.write(f"{idx} {df_captions.iloc[idx]['caption']}\n")
+
# Load torchmetrics modules
clip = CLIPEncoder(device=device)
clip_scores = []
@@ -78,6 +94,11 @@ def main():
generated_img = np.frombuffer(bytes.fromhex(j['data']), np.uint8).reshape(1024, 1024, 3)
result_list.append(generated_img)
generated_img = Image.fromarray(generated_img)
+
+ # Dump compliance images
+ if dump_compliance_images and idx in compliance_images_idx_list:
+ generated_img.save(os.path.join(args.compliance_images_path, f"{idx}.png"))
+
# generated_img = torch.Tensor(generated_img).to(torch.uint8).to(device)
# Load Ground Truth
caption = df_captions.iloc[idx]["caption"]
diff --git a/tools/submission/generate_final_report.py b/tools/submission/generate_final_report.py
index 7d1200a30..13a01fe55 100644
--- a/tools/submission/generate_final_report.py
+++ b/tools/submission/generate_final_report.py
@@ -14,8 +14,8 @@ def get_args():
"""Parse commandline."""
parser = argparse.ArgumentParser()
parser.add_argument('--input', required=True, help='results csv from checker')
- parser.add_argument('--version', default='3.1', help='mlperf version')
- parser.add_argument('--repository', default='submissions_inference_3.1', help='mlperf repository')
+ parser.add_argument('--version', default='4.0', help='mlperf version')
+ parser.add_argument('--repository', default='submissions_inference_4.0', help='mlperf repository')
args = parser.parse_args()
return args
@@ -104,7 +104,7 @@ def main():
[
'resnet', 'retinanet', '3d-unet-99', '3d-unet-99.9',
'rnnt', 'bert-99', 'bert-99.9', 'dlrm-v2-99', 'dlrm-v2-99.9',
- 'gptj-99', 'gptj-99.9'
+ 'gptj-99', 'gptj-99.9', 'stable-diffusion-xl', 'llama2-70b-99', 'llama2-70b-99.9'
], ['SingleStream', 'MultiStream', 'Server', 'Offline'],
[
'Latency (ms)',
@@ -127,6 +127,9 @@ def main():
'3d-unet-99.9': ['Offline'],
'gptj-99': ['Server', 'Offline'],
'gptj-99.9': ['Server', 'Offline'],
+ 'stable-diffusion-xl': ['Server', 'Offline'],
+ 'llama2-70b-99': ['Server', 'Offline'],
+ 'llama2-70b-99.9': ['Server', 'Offline'],
},
'edge': {
'resnet': ['SingleStream', 'MultiStream', 'Offline'],
@@ -140,6 +143,7 @@ def main():
'3d-unet-99.9': ['SingleStream', 'Offline'],
'gptj-99': ['SingleStream', 'Offline'],
'gptj-99.9': ['SingleStream', 'Offline'],
+ 'stable-diffusion-xl': ['SingleStream', 'Offline'],
}
}
diff --git a/tools/submission/power/power_checker.py b/tools/submission/power/power_checker.py
index 5adcd197c..93d9d6fb9 100755
--- a/tools/submission/power/power_checker.py
+++ b/tools/submission/power/power_checker.py
@@ -408,6 +408,8 @@ def get_avg_power(power_path: str, run_path: str) -> Tuple[float, float]:
with open(spl_fname) as f:
for line in f:
+ if not line.startswith("Time"):
+ continue
timestamp = (
datetime.strptime(line.split(",")[1], datetime_format)
).replace(tzinfo=timezone.utc)
diff --git a/tools/submission/power/sources_checksums.json b/tools/submission/power/sources_checksums.json
index 0ae4cc020..78a240f1a 100644
--- a/tools/submission/power/sources_checksums.json
+++ b/tools/submission/power/sources_checksums.json
@@ -1,38 +1,4 @@
[
- {
- "server.py": "c3f90f2f7eeb4db30727556d0c815ebc89b3d28b",
- "__init__.py": "da39a3ee5e6b4b0d3255bfef95601890afd80709",
- "client.py": "33ca4f26368777ac06e01f9567b714a4b8063886",
- "tests/unit/test_source_hashes.py": "00468a2907583c593e6574a1f6b404e4651c221a",
- "tests/unit/__init__.py": "da39a3ee5e6b4b0d3255bfef95601890afd80709",
- "tests/unit/test_server.py": "948c1995d4008bc2aa6c4046a34ffa3858d6d671",
- "lib/time_sync.py": "3210db56eb0ff0df57bf4293dc4d4b03fffd46f1",
- "lib/source_hashes.py": "60a2e02193209e8d392803326208d5466342da18",
- "lib/common.py": "611d8b29633d331eb19c9455ea3b5fa3284ed6df",
- "lib/server.py": "8054263a14dedddcf8e1c01adc19596c21bad591",
- "lib/__init__.py": "da39a3ee5e6b4b0d3255bfef95601890afd80709",
- "lib/summary.py": "aa92f0a3f975eecd44d3c0cd0236342ccc9f941d",
- "lib/client.py": "c146491755e219a28d440b31f83998dbd5532483",
- "lib/external/ntplib.py": "4da8f970656505a40483206ef2b5d3dd5e81711d",
- "lib/external/__init__.py": "da39a3ee5e6b4b0d3255bfef95601890afd80709"
- },
- {
- "__init__.py": "da39a3ee5e6b4b0d3255bfef95601890afd80709",
- "client.py": "33ca4f26368777ac06e01f9567b714a4b8063886",
- "lib/__init__.py": "da39a3ee5e6b4b0d3255bfef95601890afd80709",
- "lib/client.py": "ac2aa093c8e8bbc9569b9e2a3471bc64e58a2258",
- "lib/common.py": "611d8b29633d331eb19c9455ea3b5fa3284ed6df",
- "lib/external/__init__.py": "da39a3ee5e6b4b0d3255bfef95601890afd80709",
- "lib/external/ntplib.py": "4da8f970656505a40483206ef2b5d3dd5e81711d",
- "lib/server.py": "c7af63c31bb2fbedea4345f571f6e3507d268ada",
- "lib/source_hashes.py": "60a2e02193209e8d392803326208d5466342da18",
- "lib/summary.py": "aa92f0a3f975eecd44d3c0cd0236342ccc9f941d",
- "lib/time_sync.py": "122eba67a9abc85635223e054def53be1367ade2",
- "server.py": "c3f90f2f7eeb4db30727556d0c815ebc89b3d28b",
- "tests/unit/__init__.py": "da39a3ee5e6b4b0d3255bfef95601890afd80709",
- "tests/unit/test_server.py": "948c1995d4008bc2aa6c4046a34ffa3858d6d671",
- "tests/unit/test_source_hashes.py": "00468a2907583c593e6574a1f6b404e4651c221a"
- },
{
"__init__.py": "da39a3ee5e6b4b0d3255bfef95601890afd80709",
"client.py": "33ca4f26368777ac06e01f9567b714a4b8063886",
@@ -41,7 +7,7 @@
"lib/common.py": "611d8b29633d331eb19c9455ea3b5fa3284ed6df",
"lib/external/__init__.py": "da39a3ee5e6b4b0d3255bfef95601890afd80709",
"lib/external/ntplib.py": "4da8f970656505a40483206ef2b5d3dd5e81711d",
- "lib/server.py": "c7af63c31bb2fbedea4345f571f6e3507d268ada",
+ "lib/server.py": "99303c836c683aa9017ec565104e636161d02acb",
"lib/source_hashes.py": "60a2e02193209e8d392803326208d5466342da18",
"lib/summary.py": "aa92f0a3f975eecd44d3c0cd0236342ccc9f941d",
"lib/time_sync.py": "80894ef2389e540781ff78de94db16aa4203a14e",
@@ -50,4 +16,4 @@
"tests/unit/test_server.py": "948c1995d4008bc2aa6c4046a34ffa3858d6d671",
"tests/unit/test_source_hashes.py": "00468a2907583c593e6574a1f6b404e4651c221a"
}
-]
+]
\ No newline at end of file
diff --git a/tools/submission/preprocess_submission.py b/tools/submission/preprocess_submission.py
index f81cfae0a..fee2aadde 100644
--- a/tools/submission/preprocess_submission.py
+++ b/tools/submission/preprocess_submission.py
@@ -44,7 +44,7 @@ def get_args():
default=False, action="store_true")
parser.add_argument(
"--version",
- default="v3.1",
+ default="v4.0",
choices=list(checker.MODEL_CONFIG.keys()),
help="mlperf version")
parser.add_argument("--submitter", help="filter to submitter")
@@ -247,6 +247,13 @@ def infer_scenario_results(filter_submitter, noinfer_low_accuracy_results, confi
shutil.copytree(high_accuracy_model_path, \
low_accuracy_model_path)
+ high_accuracy_model_code_path = os.path.join(log_path, "..", \
+ "code", model)
+ low_accuracy_model_code_path = os.path.join(log_path, "..", \
+ "code", low_accuracy_model)
+ if not os.path.exists(low_accuracy_model_code_path):
+ shutil.copytree(high_accuracy_model_code_path, \
+ low_accuracy_model_code_path)
diff --git a/tools/submission/submission_checker.py b/tools/submission/submission_checker.py
index e61590e34..0bae5b0d4 100755
--- a/tools/submission/submission_checker.py
+++ b/tools/submission/submission_checker.py
@@ -1124,12 +1124,14 @@
"3d-unet-99.9": ("DICE", 0.86170 * 0.999),
"gptj-99" : ("ROUGE1", 42.9865 * 0.99, "ROUGE2", 20.1235 * 0.99, "ROUGEL", 29.9881 * 0.99, "GEN_LEN", 4016878*0.9),
"gptj-99.9" : ("ROUGE1", 42.9865 * 0.999, "ROUGE2", 20.1235 * 0.999, "ROUGEL", 29.9881 * 0.999, "GEN_LEN", 4016878*0.9),
- "llama2-70b-99" : ("ROUGE1", 43.88 * 0.99, "ROUGE2", 21.7108 * 0.99, "ROUGEL", 28.2502 * 0.99, "TOKENS_PER_SAMPLE", 293.3*0.9),
- "llama2-70b-99.9" : ("ROUGE1", 43.88 * 0.999, "ROUGE2", 21.7108 * 0.999, "ROUGEL", 28.2502 * 0.999, "TOKENS_PER_SAMPLE", 293.3*0.9),
+ "llama2-70b-99" : ("ROUGE1", 44.4312 * 0.99, "ROUGE2", 22.0352 * 0.99, "ROUGEL", 28.6162 * 0.99, "TOKENS_PER_SAMPLE", 294.45*0.9),
+ "llama2-70b-99.9" : ("ROUGE1", 44.4312 * 0.999, "ROUGE2", 22.0352 * 0.999, "ROUGEL", 28.6162 * 0.999, "TOKENS_PER_SAMPLE", 294.45*0.9),
"stable-diffusion-xl": ("CLIP_SCORE", 31.68631873, "FID_SCORE", 23.01085758)
},
"accuracy-upper-limit": {
- "stable-diffusion-xl": ("CLIP_SCORE", 31.81331801, "FID_SCORE", 23.95007626)
+ "stable-diffusion-xl": ("CLIP_SCORE", 31.81331801, "FID_SCORE", 23.95007626),
+ "llama2-70b-99" : ("TOKENS_PER_SAMPLE", 294.45*1.1),
+ "llama2-70b-99.9" : ("TOKENS_PER_SAMPLE", 294.45*1.1)
},
"performance-sample-count": {
"resnet": 1024,
@@ -1381,11 +1383,19 @@
RESULT_FIELD_BENCHMARK_OVERWRITE = {
"llama2-70b-99": {
"Offline": "result_tokens_per_second",
- "Server": "result_scheduled_samples_per_sec",
+ "Server": "result_completed_samples_per_sec",
},
"llama2-70b-99.9": {
"Offline": "result_tokens_per_second",
- "Server": "result_scheduled_samples_per_sec",
+ "Server": "result_completed_samples_per_sec",
+ }
+}
+
+LLAMA2_LATENCY_LIMITS = {
+ # We might add interactive in the next round. Latency in ns
+ "conversational": {
+ "ttft": 2000 * 1000000,
+ "tpot": 200 * 1000000
}
}
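
These limits feed the extra_check_llama2 helper added further below: a llama2-70b Server run only counts as the "conversational" configuration when the 99th-percentile TTFT is under 2 s and the 99th-percentile TPOT is under 200 ms. A toy reproduction of that gate follows; the log values are invented, while the keys and limits come from this diff.

```python
# Toy reproduction of the llama2-70b Server latency gate (log values invented).
LLAMA2_LATENCY_LIMITS = {"conversational": {"ttft": 2000 * 1000000, "tpot": 200 * 1000000}}

mlperf_log = {
    "result_first_token_99.00_percentile_latency_ns": 1_850_000_000,
    "result_time_per_output_token_99.00_percentile_ns": 190_000_000,
}

for constraint, limits in LLAMA2_LATENCY_LIMITS.items():
    ok = (mlperf_log["result_first_token_99.00_percentile_latency_ns"] < limits["ttft"]
          and mlperf_log["result_time_per_output_token_99.00_percentile_ns"] < limits["tpot"])
    print(constraint, "pass" if ok else "fail")
```
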
@@ -1774,6 +1784,8 @@ def check_extra_files(path, target_files):
if target_file not in files:
check_pass = False
missing_files.append(f"{os.path.join(path, dir, target_file)}.png")
+ if "captions" not in files:
+ missing_files.append(f"{os.path.join(path, dir, 'captions.txt')}")
return check_pass, missing_files
@@ -1834,13 +1846,17 @@ def check_accuracy_dir(config, model, path, verbose):
acc_targets = []
if acc_upper_limit is not None:
acc_limits = []
+ up_patterns = []
acc_limit_check = True
+ for i in range(0, len(acc_upper_limit), 2):
+ acc_type, acc_target = acc_upper_limit[i:i+2]
+ acc_limits.append(acc_target)
+ up_patterns.append(ACC_PATTERN[acc_type])
+
for i in range(0, len(target), 2):
acc_type, acc_target = target[i:i+2]
patterns.append(ACC_PATTERN[acc_type])
acc_targets.append(acc_target)
- if acc_upper_limit is not None:
- acc_limits.append(acc_upper_limit[i+1])
acc_seen = [False for _ in acc_targets]
with open(os.path.join(path, "accuracy.txt"), "r", encoding="utf-8") as f:
for line in f:
@@ -1857,12 +1873,21 @@ def check_accuracy_dir(config, model, path, verbose):
elif acc is not None:
all_accuracy_valid = False
log.warning("%s accuracy not met: expected=%f, found=%s", path, acc_target, acc)
- if acc is not None and acc_upper_limit is not None and float(acc) > acc_limits[i]:
- acc_limit_check = False
- log.warning("%s accuracy not met: upper limit=%f, found=%s", path, acc_limits[i], acc)
if i == 0 and acc:
result_acc = acc
acc = None
+ if acc_upper_limit is not None:
+ for i, (pattern, acc_limit) in enumerate(zip(up_patterns, acc_limits)):
+ m = re.match(pattern, line)
+ if m:
+ acc = m.group(1)
+ m = re.match(r"^hash=([\w\d]+)$", line)
+ if m:
+ hash_val = m.group(1)
+ if acc is not None and acc_upper_limit is not None and float(acc) > acc_limit:
+ acc_limit_check = False
+ log.warning("%s accuracy not met: upper limit=%f, found=%s", path, acc_limit, acc)
+ acc = None
if all(acc_seen) and hash_val:
break;
is_valid = all_accuracy_valid & all(acc_seen)
@@ -1891,6 +1916,23 @@ def check_accuracy_dir(config, model, path, verbose):
return is_valid, result_acc
+def extra_check_llama2(mlperf_log, scenario):
+ if (mlperf_log["requested_use_token_latencies"]):
+ if scenario == "Offline":
+ # For offline no further checks are necessary
+ return None, True
+ else:
+ for constraint, limits in LLAMA2_LATENCY_LIMITS.items():
+ if mlperf_log["result_first_token_99.00_percentile_latency_ns"] < limits["ttft"] and mlperf_log["result_time_per_output_token_99.00_percentile_ns"] < limits["tpot"]:
+ return constraint, True
+ else:
+ log.error(f'use_token_latencies flag needs to be enabled for Llama2 benchmark')
+ return None, False
+
+ log.error(f'Failed Llama2 extra check for TTFT and TPOT. TTFT 99-tile: {mlperf_log["result_first_token_99.00_percentile_latency_ns"]}, TPOT 99-tile: {mlperf_log["result_time_per_output_token_99.00_percentile_ns"]}')
+ return None, False
+
+
def get_performance_metric(
config, model, path, scenario_fixed, division, system_json, has_power=False
):
@@ -1911,6 +1953,8 @@ def get_performance_metric(
)
res = float(mlperf_log[RESULT_FIELD_NEW[config.version][scenario_for_res]])
+ if model in RESULT_FIELD_BENCHMARK_OVERWRITE and scenario in RESULT_FIELD_BENCHMARK_OVERWRITE[model]:
+ res = float(mlperf_log[RESULT_FIELD_BENCHMARK_OVERWRITE[model][scenario_for_res]])
inferred = False
if scenario_fixed != scenario:
@@ -1946,6 +1990,9 @@ def check_performance_dir(
res = float(mlperf_log[RESULT_FIELD_NEW[config.version][scenario_for_res]])
if model in RESULT_FIELD_BENCHMARK_OVERWRITE and scenario in RESULT_FIELD_BENCHMARK_OVERWRITE[model]:
res = float(mlperf_log[RESULT_FIELD_BENCHMARK_OVERWRITE[model][scenario_for_res]])
+
+ if model in ["llama2-70b-99", "llama2-70b-99.9"]:
+ llama_constraint, is_valid = extra_check_llama2(mlperf_log, scenario_fixed)
latency_99_percentile = mlperf_log["result_99.00_percentile_latency_ns"]
latency_mean = mlperf_log["result_mean_latency_ns"]
@@ -2230,6 +2277,8 @@ def get_power_metric(config, scenario_fixed, log_path, is_valid, res):
power_list = []
with open(spl_fname) as f:
for line in f:
+ if not line.startswith("Time"):
+ continue
timestamp = (
datetime.datetime.strptime(line.split(",")[1], datetime_format)
+ server_timezone
@@ -2938,6 +2987,7 @@ def log_result(
n = ["run_1"]
for i in n:
+ is_valid = True
perf_path = os.path.join(name, "performance", i)
if not os.path.exists(perf_path):
log.error("%s is missing", perf_path)
@@ -3071,24 +3121,21 @@ def log_result(
model_name,
scenario,
)
- if not os.path.exists(compliance_dir) and "gptj" not in model_name:
- log.error("no compliance dir for %s", name)
+ if not check_compliance_dir(
+ compliance_dir,
+ mlperf_model,
+ scenario_fixed,
+ config,
+ division,
+ system_json,
+ name
+ ):
+ log.error(
+ "compliance dir %s has issues", compliance_dir
+ )
results[name] = None
else:
- if not check_compliance_dir(
- compliance_dir,
- mlperf_model,
- scenario_fixed,
- config,
- division,
- system_json,
- ):
- log.error(
- "compliance dir %s has issues", compliance_dir
- )
- results[name] = None
- else:
- compliance = 1
+ compliance = 1
if results.get(name):
if accuracy_is_valid:
@@ -3387,96 +3434,112 @@ def check_compliance_acc_dir(test_dir, model, config):
if not os.path.exists(fname):
log.error("%s is missing in %s", fname, test_dir)
else:
- # Accuracy can fail for TEST01
- is_valid = True
- with open(fname, "r") as f:
- for line in f:
- # look for: TEST PASS
- if "TEST PASS" in line:
- acc_passed = True
- break
- if acc_passed == False:
- log.info(
- "Compliance test accuracy check (deterministic mode) in %s failed",
- test_dir,
- )
+ if "TEST01" in test_dir:
+ # Accuracy can fail for TEST01
+ is_valid = True
+ with open(fname, "r") as f:
+ for line in f:
+ # look for: TEST PASS
+ if "TEST PASS" in line:
+ acc_passed = True
+ break
+ if acc_passed == False:
+ log.info(
+ "Compliance test accuracy check (deterministic mode) in %s failed",
+ test_dir,
+ )
- # Check Accuracy dir
- test_acc_path = os.path.join(test_dir, "accuracy")
- if not os.path.exists(test_acc_path):
- log.error("%s has no accuracy directory", test_dir)
- is_valid = False
- else:
- diff = files_diff(
- list_files(test_acc_path),
- REQUIRED_TEST01_ACC_FILES_1
- if acc_passed
- else REQUIRED_TEST01_ACC_FILES,
- )
- if diff:
- log.error("%s has file list mismatch (%s)", test_acc_path, diff)
+ # Check Accuracy dir
+ test_acc_path = os.path.join(test_dir, "accuracy")
+ if not os.path.exists(test_acc_path):
+ log.error("%s has no accuracy directory", test_dir)
is_valid = False
- elif not acc_passed:
- target = config.get_accuracy_target(model)
- patterns = []
- acc_types = []
- for i in range(0, len(target), 2):
- acc_type = target[i:i+2]
- acc_types.append(acc_type)
- patterns.append(ACC_PATTERN[acc_type[0]])
- acc_seen = [False for _ in acc_type]
-
-
-
- more_accurate = model.find("99.9")
- if more_accurate == -1:
- required_delta_perc = 1
- else:
- required_delta_perc = 0.1
-
- acc_baseline = {
- acc_type: 0 for acc_type in acc_types
- }
- acc_compliance = {
- acc_type: 0 for acc_type in acc_types
- }
- with open(
- os.path.join(test_acc_path, "baseline_accuracy.txt"),
- "r",
- encoding="utf-8",
- ) as f:
- for line in f:
- for acc_type, pattern in zip(acc_types, patterns):
- m = re.match(pattern, line)
- if m:
- acc_baseline[acc_type] = float(m.group(1))
- with open(
- os.path.join(test_acc_path, "compliance_accuracy.txt"),
- "r",
- encoding="utf-8",
- ) as f:
- for line in f:
- for acc_type, pattern in zip(acc_types, patterns):
- m = re.match(pattern, line)
- if m:
- acc_compliance[acc_type] = float(m.group(1))
- for acc_type in acc_types:
- if acc_baseline[acc_type] == 0 or acc_compliance[acc_type] == 0:
- is_valid = False
- break
+ else:
+ diff = files_diff(
+ list_files(test_acc_path),
+ REQUIRED_TEST01_ACC_FILES_1
+ if acc_passed
+ else REQUIRED_TEST01_ACC_FILES,
+ )
+ if diff:
+ log.error("%s has file list mismatch (%s)", test_acc_path, diff)
+ is_valid = False
+ elif not acc_passed:
+ target = config.get_accuracy_target(model)
+ patterns = []
+ acc_types = []
+ for i in range(0, len(target), 2):
+ acc_type = target[i:i+2]
+ acc_types.append(acc_type)
+ patterns.append(ACC_PATTERN[acc_type[0]])
+ acc_seen = [False for _ in acc_type]
+
+ more_accurate = model.find("99.9")
+ if more_accurate == -1:
+ required_delta_perc = 1
else:
- delta_perc = abs(1 - acc_baseline[acc_type] / acc_compliance[acc_type]) * 100
- if delta_perc <= required_delta_perc:
- is_valid = True
- else:
+ required_delta_perc = 0.1
+ acc_baseline = {
+ acc_type: 0 for acc_type in acc_types
+ }
+ acc_compliance = {
+ acc_type: 0 for acc_type in acc_types
+ }
+ with open(
+ os.path.join(test_acc_path, "baseline_accuracy.txt"),
+ "r",
+ encoding="utf-8",
+ ) as f:
+ for line in f:
+ for acc_type, pattern in zip(acc_types, patterns):
+ m = re.match(pattern, line)
+ if m:
+ acc_baseline[acc_type] = float(m.group(1))
+ with open(
+ os.path.join(test_acc_path, "compliance_accuracy.txt"),
+ "r",
+ encoding="utf-8",
+ ) as f:
+ for line in f:
+ for acc_type, pattern in zip(acc_types, patterns):
+ m = re.match(pattern, line)
+ if m:
+ acc_compliance[acc_type] = float(m.group(1))
+ for acc_type in acc_types:
+ if acc_baseline[acc_type] == 0 or acc_compliance[acc_type] == 0:
is_valid = False
break
+ else:
+ delta_perc = abs(1 - acc_baseline[acc_type] / acc_compliance[acc_type]) * 100
+ if delta_perc <= required_delta_perc:
+ is_valid = True
+ else:
+ is_valid = False
+ break
+ elif "TEST06" in test_dir:
+ """
+ Expected output
+ First token check pass: True (or First token check pass: Skipped)
+ EOS check pass: True
+ TEST06 verification complete
+ """
+ with open(fname, "r") as f:
+ lines = f.readlines()
+ lines = [line.strip() for line in lines]
+ first_token_pass = "First token check pass: True" in lines or "First token check pass: Skipped" in lines
+ eos_pass = "EOS check pass: True" in lines
+ length_check_pass = "Sample length check pass: True" in lines
+ is_valid = first_token_pass and eos_pass and length_check_pass
+ if not is_valid:
+ log.error(f"TEST06 accuracy check failed. first_token_check: {first_token_pass} eos_check: {eos_pass} length_check: {length_check_pass}.")
+ else:
+ raise NotImplementedError(f"{test_dir} is neither TEST01 nor TEST06, so it does not require an accuracy check")
return is_valid
def check_compliance_dir(
- compliance_dir, model, scenario, config, division, system_json
+ compliance_dir, model, scenario, config, division, system_json, name
):
compliance_perf_pass = True
compliance_perf_dir_pass = True
@@ -3517,13 +3580,20 @@ def check_compliance_dir(
]:
test_list.append("TEST06")
- # Check performance of all Tests
+ if test_list and not os.path.exists(compliance_dir):
+ log.error("no compliance dir for %s: %s", name, compliance_dir)
+ return False
+
+ # Check performance of all Tests (except for TEST06)
for test in test_list:
test_dir = os.path.join(compliance_dir, test)
if not os.path.exists(test_dir):
log.error("Missing %s in compliance dir %s", test, compliance_dir)
compliance_perf_dir_pass = False
else:
+ # TEST06 has no performance test.
+ if "TEST06" in test_list:
+ continue
try:
compliance_perf_dir = os.path.join(
compliance_dir, test, "performance", "run_1"
@@ -3546,13 +3616,14 @@ def check_compliance_dir(
and compliance_perf_valid
)
- if "TEST01" in test_list:
- # Check accuracy for TEST01
- compliance_acc_pass = check_compliance_acc_dir(
- os.path.join(compliance_dir, "TEST01"), model, config
- )
- else:
- compliance_acc_pass= True
+ compliance_acc_pass = True
+ for test in ["TEST01", "TEST06"]:
+ if test in test_list:
+ # Check accuracy for TEST01 and TEST06
+ compliance_acc_pass &= check_compliance_acc_dir(
+ os.path.join(compliance_dir, test), model, config
+ )
+
return compliance_perf_pass and compliance_acc_pass and compliance_perf_dir_pass