First pass: standalone TGIS, grpc, batching
Maxusmusti committed Jan 25, 2024
1 parent bc594fa commit 6a43fa5
Showing 9 changed files with 567 additions and 15 deletions.
62 changes: 49 additions & 13 deletions language/llama2-70b/SUT.py
@@ -1,6 +1,7 @@
import os
import sys
import time
import re
import numpy as np
import array
import torch
@@ -20,6 +21,8 @@
from urllib3.exceptions import InsecureRequestWarning
import json

from inference import GrpcClient

requests.packages.urllib3.disable_warnings(category=InsecureRequestWarning)

import logging
@@ -90,6 +93,8 @@ def __init__(self,
model_path=None,
api_server=None,
api_model_name=None,
grpc=False,
batch_grpc=False,
dtype="bfloat16",
device="cpu",
batch_size=None,
@@ -101,11 +106,13 @@ def __init__(self,
self.model_path = model_path or "meta-llama/Llama-2-70b-chat-hf"
self.api_server = api_server
self.api_model_name = api_model_name
self.grpc = grpc
self.batch_grpc = batch_grpc
self.device = device

if not batch_size:
if device == "cpu":
batch_size = 110
batch_size = 300
else:
batch_size = 32 # Reduce to 8 if using 4 GPUs, 16 for 8.
self.batch_size = batch_size
@@ -173,15 +180,28 @@ def query_api(self, input):
},
}

response = requests.post(
self.api_server,
headers=headers,
json=json_data,
verify=False,
)

response_code = 0
while response_code != 200:
try:
response = requests.post(
self.api_server,
headers=headers,
json=json_data,
verify=False,
)
response_code = response.status_code
except requests.exceptions.RequestException:
print("connection failure")
return json.loads(response.text)["generated_text"]


def query_api_grpc(self, input):
resp = self.grpc_client.make_request([input], model_id=self.api_model_name)
return resp.responses[0].text

def query_api_batch_grpc(self, inputs):
resps = self.grpc_client.make_request(inputs, model_id=self.api_model_name)
return [resp.text for resp in resps.responses]
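The two helpers above exercise only a small surface of the GrpcClient imported from inference: a constructor taking a hostname, a port, and a verify flag, and a make_request(prompts, model_id=...) call whose result exposes .responses, each carrying a .text field. The real client is not part of this diff; the stand-in below is purely illustrative and just mirrors that surface:

from dataclasses import dataclass
from typing import List

@dataclass
class _Generation:
    text: str

@dataclass
class _GenerationResponse:
    responses: List[_Generation]

class MockGrpcClient:
    """Illustrative stand-in; the real GrpcClient in inference.py issues actual TGIS gRPC calls."""

    def __init__(self, host: str, port: int, verify: bool = True):
        self.host, self.port, self.verify = host, port, verify

    def make_request(self, prompts: List[str], model_id: str = None) -> _GenerationResponse:
        # Echo each prompt back; a real call returns the generated continuations.
        return _GenerationResponse(responses=[_Generation(text=p) for p in prompts])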

def process_queries(self):
"""Processor of the queued queries. User may choose to add batching logic """
@@ -236,8 +256,15 @@ def process_queries(self):
tik2 = time.time()

if self.api_server:
with ThreadPoolExecutor(max_workers=bs) as executor:
output = list(executor.map(self.query_api,cleaned))
if self.grpc:
if self.batch_grpc:
output = self.query_api_batch_grpc(cleaned)
else:
with ThreadPoolExecutor(max_workers=bs) as executor:
output = list(executor.map(self.query_api_grpc,cleaned))
else:
with ThreadPoolExecutor(max_workers=bs) as executor:
output = list(executor.map(self.query_api,cleaned))
else:
pred_output_tokens = self.model.generate(
input_ids=input_ids_tensor,
@@ -277,7 +304,16 @@ def process_queries(self):

def load_model(self):
if self.api_server:
if not "http" in self.api_server:
if self.grpc:
hostname = re.sub("https://|http://", "", self.api_server)
if hostname[-1] == "/":
hostname = hostname[:-1]
self.grpc_client = GrpcClient(
hostname,
443,
verify=False,
)
elif not "http" in self.api_server:
self.api_server = "http://" + self.api_server

if not self.api_model_name:
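A quick worked example of the hostname handling just above (the URL is hypothetical; only the scheme stripping and trailing-slash removal are taken from the diff):

import re

api_server = "https://llama-2-70b-chat-isvc.example.com/"  # hypothetical route URL
hostname = re.sub("https://|http://", "", api_server)
if hostname[-1] == "/":
    hostname = hostname[:-1]
print(hostname)  # -> llama-2-70b-chat-isvc.example.com
# load_model() then constructs GrpcClient(hostname, 443, verify=False)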
@@ -340,9 +376,9 @@ def __del__(self):


class SUTServer(SUT):
def __init__(self, model_path=None, api_server=None, api_model_name=None, dtype="bfloat16", device="cpu", total_sample_count=24576, dataset_path=None, workers=1):
def __init__(self, model_path=None, api_server=None, api_model_name=None, grpc=False, dtype="bfloat16", device="cpu", total_sample_count=24576, dataset_path=None, workers=1):

super().__init__(model_path=model_path, api_server=api_server, api_model_name=api_model_name, dtype=dtype, device=device, total_sample_count=total_sample_count, dataset_path=dataset_path, workers=workers)
super().__init__(model_path=model_path, api_server=api_server, api_model_name=api_model_name, grpc=grpc, dtype=dtype, device=device, total_sample_count=total_sample_count, dataset_path=dataset_path, workers=workers)

with open(f"{self.model_path}/tokenizer.json", 'r') as token_file:
llama_tokenizer = json.load(token_file)
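Putting the SUT.py changes together, a minimal sketch of how the new options might be wired up (the endpoint, model name, and dataset path below are hypothetical; the benchmark harness normally constructs the SUT from its command-line flags):

from SUT import SUT

sut = SUT(
    model_path="meta-llama/Llama-2-70b-chat-hf",
    api_server="https://llama-2-70b-chat-isvc.example.com",  # hypothetical TGIS route
    api_model_name="Llama-2-70b-chat-hf",                     # hypothetical TGIS model id
    grpc=True,         # talk to the endpoint over gRPC instead of HTTP
    batch_grpc=True,   # send each batch as a single gRPC request
    dataset_path="processed-data.pkl",                        # hypothetical dataset path
)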
2 changes: 1 addition & 1 deletion language/llama2-70b/api-endpoint-artifacts/Dockerfile-API
@@ -7,7 +7,7 @@ COPY llama-model-info ./llama-model-info
RUN apt-get update && apt install build-essential -y
RUN conda install pybind11==2.10.4 -c conda-forge -y
RUN pip install --upgrade pip
RUN pip install transformers==4.31.0 nltk==3.8.1 evaluate==0.4.0 absl-py==1.4.0 rouge-score==0.1.2 sentencepiece==0.1.99 accelerate==0.21.0
RUN pip install transformers==4.31.0 nltk==3.8.1 evaluate==0.4.0 absl-py==1.4.0 rouge-score==0.1.2 sentencepiece==0.1.99 accelerate==0.21.0 grpcio-tools
RUN cd inference/loadgen && python -m pip install .
RUN cp inference/mlperf.conf inference/language/llama2-70b/mlperf.conf
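The grpcio-tools package added here ships grpc_tools.protoc, which is presumably what produced the checked-in generation_pb2.py from TGIS's generation.proto (the proto file name is assumed). A typical regeneration command would be:

python -m grpc_tools.protoc -I. --python_out=. generation.proto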

2 changes: 1 addition & 1 deletion language/llama2-70b/api-endpoint-artifacts/benchmark.yaml
@@ -6,7 +6,7 @@ spec:
restartPolicy: Never
containers:
- name: mlperf-env
image: quay.io/meyceoz/mlperf-inference:v4
image: quay.io/meyceoz/mlperf-inference:grpc-batch
resources:
requests:
memory: 20000Mi
20 changes: 20 additions & 0 deletions language/llama2-70b/api-endpoint-artifacts/model-tgis.yaml
@@ -0,0 +1,20 @@
apiVersion: serving.kserve.io/v1beta1
kind: InferenceService
metadata:
annotations:
serving.knative.openshift.io/enablePassthrough: "true"
sidecar.istio.io/inject: "true"
sidecar.istio.io/rewriteAppHTTPProbers: "true"
name: llama-2-70b-chat-isvc
spec:
predictor:
minReplicas: 1
maxReplicas: 1
#apiVersion: serving.kserve.io/v1alpha2
serviceAccountName: sa
#timeout: 240
model:
modelFormat:
name: pytorch
runtime: tgis-runtime-grpc
storageUri: s3://mlperf-inference-models/Llama-2-70b-chat-hf/
67 changes: 67 additions & 0 deletions language/llama2-70b/api-endpoint-artifacts/serving-tgis.yaml
@@ -0,0 +1,67 @@
apiVersion: serving.kserve.io/v1alpha1
kind: ServingRuntime
metadata:
name: tgis-runtime-grpc
spec:
multiModel: false
supportedModelFormats:
- autoSelect: true
name: pytorch
containers:
- name: kserve-container
image: quay.io/opendatahub/text-generation-inference:fast
command: ["text-generation-launcher"]
args:
- "--model-name=/mnt/models/"
- "--port=3000"
- "--grpc-port=8033"
env:
- name: TRANSFORMERS_CACHE
value: /tmp/transformers_cache
- name: NUM_GPUS
value: "8"
#- name: DTYPE_STR
# value: float16
# Dynamic batch size changes
- name: MAX_BATCH_SIZE
value: "256"
- name: MAX_CONCURRENT_REQUESTS
value: "300"
- name: MAX_BATCH_WEIGHT
value: "550000"
- name: MAX_SEQUENCE_LENGTH
value: "2048"
- name: MAX_PREFILL_WEIGHT
value: "0"
- name: MAX_NEW_TOKENS
value: "1024"
- name: FLASH_ATTENTION
value: "true"
- name: DEPLOYMENT_FRAMEWORK
value: hf_custom_tp
- name: LOG_GPU_USAGE_INTERVAL
value: "5"
resources: # configure as required
requests:
cpu: 64
memory: 900Gi
nvidia.com/gpu: 8
limits:
nvidia.com/gpu: 8
readinessProbe: # Use exec probes instead of httpGet since the probes' port gets rewritten to the containerPort
exec:
command:
- curl
- localhost:3000/health
initialDelaySeconds: 5
livenessProbe:
exec:
command:
- curl
- localhost:3000/health
initialDelaySeconds: 5
# periodSeconds: 5
ports:
- containerPort: 8033
name: h2c
protocol: TCP
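With the ServingRuntime and InferenceService above deployed, gRPC is served on container port 8033 and reached through the TLS route on port 443, which matches the port hard-coded in SUT.load_model. A plain channel probe is enough for a reachability check; the hostname below is hypothetical:

import grpc

channel = grpc.secure_channel(
    "llama-2-70b-chat-isvc.example.com:443",  # hypothetical KServe route host
    grpc.ssl_channel_credentials(),
)
# Blocks until the channel is ready or raises grpc.FutureTimeoutError.
# If the route uses a self-signed certificate, pass its root cert to ssl_channel_credentials().
grpc.channel_ready_future(channel).result(timeout=30)
print("TGIS gRPC endpoint reachable")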
70 changes: 70 additions & 0 deletions language/llama2-70b/generation_pb2.py

(Generated gRPC stub module; contents not rendered in the diff view.)
