From 41df23944ffb5eb5bd6919fd515ef3f3ff3a4d65 Mon Sep 17 00:00:00 2001
From: Mustafa Eyceoz
Date: Thu, 25 Jan 2024 14:23:41 -0500
Subject: [PATCH] Streaming first pass

---
 language/llama2-70b/SUT.py                    | 22 +++++++++++++++++--
 .../api-endpoint-artifacts/benchmark.yaml     |  2 +-
 .../api-endpoint-artifacts/serving-tgis.yaml  | 10 ++++-----
 language/llama2-70b/inference.py              | 15 +++++++++++++
 4 files changed, 41 insertions(+), 8 deletions(-)

diff --git a/language/llama2-70b/SUT.py b/language/llama2-70b/SUT.py
index b7793563b8..bf7c195e37 100644
--- a/language/llama2-70b/SUT.py
+++ b/language/llama2-70b/SUT.py
@@ -112,7 +112,7 @@ def __init__(self,
 
         if not batch_size:
             if device == "cpu":
-                batch_size = 300
+                batch_size = 512
             else:
                 batch_size = 32  # Reduce to 8 if using 4 GPUs, 16 for 8.
         self.batch_size = batch_size
@@ -456,10 +456,28 @@ def stream_api(self, input, response_ids):
                 token_cache.append(token)
         return token_cache
 
+    def stream_api_grpc(self, input, response_ids):
+        token_cache = []
+        first = True
+        resps = self.grpc_client.make_request_stream(input, model_id=self.api_model_name)
+        for resp in resps:
+            if resp.text:
+                tokens = self.tokenizer(resp.text)["input_ids"][1:]
+                if first:
+                    self.first_token_queue.put((tokens[0], response_ids[0]))
+                    token_cache.extend(tokens[1:])
+                    first = False
+                else:
+                    token_cache.extend(tokens)
+        return token_cache
+
     def async_process_query(self, input_ids_tensor, qitem_id):
         decoded = self.tokenizer.decode(input_ids_tensor[0])
         response_ids = [qitem_id]
-        output_tokens = self.stream_api(decoded, response_ids)
+        if self.grpc:
+            output_tokens = self.stream_api_grpc(decoded, response_ids)
+        else:
+            output_tokens = self.stream_api(decoded, response_ids)
 
         n_tokens = len(output_tokens)
         response_array = array.array("B", np.array(output_tokens, np.int32).tobytes())
diff --git a/language/llama2-70b/api-endpoint-artifacts/benchmark.yaml b/language/llama2-70b/api-endpoint-artifacts/benchmark.yaml
index 79001774d1..0fe72ded5c 100644
--- a/language/llama2-70b/api-endpoint-artifacts/benchmark.yaml
+++ b/language/llama2-70b/api-endpoint-artifacts/benchmark.yaml
@@ -6,7 +6,7 @@ spec:
   restartPolicy: Never
   containers:
     - name: mlperf-env
-      image: quay.io/meyceoz/mlperf-inference:grpc-batch
+      image: quay.io/meyceoz/mlperf-inference:grpc-stream
       resources:
         requests:
           memory: 20000Mi
diff --git a/language/llama2-70b/api-endpoint-artifacts/serving-tgis.yaml b/language/llama2-70b/api-endpoint-artifacts/serving-tgis.yaml
index b461481338..fa0644af57 100644
--- a/language/llama2-70b/api-endpoint-artifacts/serving-tgis.yaml
+++ b/language/llama2-70b/api-endpoint-artifacts/serving-tgis.yaml
@@ -24,11 +24,11 @@ spec:
 #          value: float16
         # Dynamic batch size changes
         - name: MAX_BATCH_SIZE
-          value: "256"
+          value: "512"
         - name: MAX_CONCURRENT_REQUESTS
-          value: "300"
+          value: "5000"
         - name: MAX_BATCH_WEIGHT
-          value: "550000"
+          value: "1000000"
         - name: MAX_SEQUENCE_LENGTH
           value: "2048"
         - name: MAX_PREFILL_WEIGHT
@@ -43,8 +43,8 @@
           value: "5"
       resources: # configure as required
        requests:
-          cpu: 64
-          memory: 900Gi
+          cpu: 96
+          memory: 1000Gi
           nvidia.com/gpu: 8
         limits:
           nvidia.com/gpu: 8
diff --git a/language/llama2-70b/inference.py b/language/llama2-70b/inference.py
index e62a513312..c7893c2ab7 100644
--- a/language/llama2-70b/inference.py
+++ b/language/llama2-70b/inference.py
@@ -71,6 +71,21 @@ def make_request(self, texts: str, model_id: str = "flan-t5-small"):
         )
         result = self.generation_service_stub.Generate(request=request)
         return result
+
+    def make_request_stream(self, text: str, model_id: str = "flan-t5-small"):
+        request = generation_pb2_grpc.generation__pb2.SingleGenerationRequest(
+            model_id=model_id,
+            request=generation_pb2_grpc.generation__pb2.GenerationRequest(text=text),
+            params=generation_pb2_grpc.generation__pb2.Parameters(
+                method=generation_pb2_grpc.generation__pb2.GREEDY,
+                stopping=generation_pb2_grpc.generation__pb2.StoppingCriteria(
+                    max_new_tokens=1024,
+                    min_new_tokens=1
+                )
+            )
+        )
+        result = self.generation_service_stub.GenerateStream(request=request)
+        return result
 
     def __enter__(self):
         return self