From 6a43fa5632a8b51bac972e037700a6dfaf978975 Mon Sep 17 00:00:00 2001
From: Mustafa Eyceoz
Date: Wed, 24 Jan 2024 19:13:41 -0500
Subject: [PATCH] First pass: standalone TGIS, grpc, batching

---
 language/llama2-70b/SUT.py                     |  62 ++++--
 .../api-endpoint-artifacts/Dockerfile-API      |   2 +-
 .../api-endpoint-artifacts/benchmark.yaml      |   2 +-
 .../api-endpoint-artifacts/model-tgis.yaml     |  20 ++
 .../api-endpoint-artifacts/serving-tgis.yaml   |  67 +++++++
 language/llama2-70b/generation_pb2.py          |  70 +++++++
 language/llama2-70b/generation_pb2_grpc.py     | 169 ++++++++++++++++
 language/llama2-70b/inference.py               | 186 ++++++++++++++++++
 language/llama2-70b/main.py                    |   4 +
 9 files changed, 567 insertions(+), 15 deletions(-)
 create mode 100644 language/llama2-70b/api-endpoint-artifacts/model-tgis.yaml
 create mode 100644 language/llama2-70b/api-endpoint-artifacts/serving-tgis.yaml
 create mode 100644 language/llama2-70b/generation_pb2.py
 create mode 100644 language/llama2-70b/generation_pb2_grpc.py
 create mode 100644 language/llama2-70b/inference.py

diff --git a/language/llama2-70b/SUT.py b/language/llama2-70b/SUT.py
index f11fe1e3fe..ba877a0e12 100644
--- a/language/llama2-70b/SUT.py
+++ b/language/llama2-70b/SUT.py
@@ -1,6 +1,7 @@
 import os
 import sys
 import time
+import re
 import numpy as np
 import array
 import torch
@@ -20,6 +21,8 @@ from urllib3.exceptions import InsecureRequestWarning
 import json
 
+from inference import GrpcClient
+
 requests.packages.urllib3.disable_warnings(category=InsecureRequestWarning)
 
 import logging
 
@@ -90,6 +93,8 @@ def __init__(self,
                  model_path=None,
                  api_server=None,
                  api_model_name=None,
+                 grpc=False,
+                 batch_grpc=False,
                  dtype="bfloat16",
                  device="cpu",
                  batch_size=None,
@@ -101,11 +106,13 @@ def __init__(self,
         self.model_path = model_path or "meta-llama/Llama-2-70b-chat-hf"
         self.api_server = api_server
         self.api_model_name = api_model_name
+        self.grpc = grpc
+        self.batch_grpc = batch_grpc
         self.device = device
 
         if not batch_size:
             if device == "cpu":
-                batch_size = 110
+                batch_size = 300
             else:
                 batch_size = 32  # Reduce to 8 if using 4 GPUs, 16 for 8.
         self.batch_size = batch_size
@@ -173,15 +180,28 @@ def query_api(self, input):
             },
         }
 
-        response = requests.post(
-            self.api_server,
-            headers=headers,
-            json=json_data,
-            verify=False,
-        )
-
+        response_code = 0
+        while response_code != 200:
+            try:
+                response = requests.post(
+                    self.api_server,
+                    headers=headers,
+                    json=json_data,
+                    verify=False,
+                )
+                response_code = response.status_code
+            except requests.exceptions.RequestException:
+                print("connection failure")
         return json.loads(response.text)["generated_text"]
+
+
+    def query_api_grpc(self, input):
+        resp = self.grpc_client.make_request([input], model_id=self.api_model_name)
+        return resp.responses[0].text
+    def query_api_batch_grpc(self, inputs):
+        resps = self.grpc_client.make_request(inputs, model_id=self.api_model_name)
+        return [resp.text for resp in resps.responses]
 
     def process_queries(self):
         """Processor of the queued queries.
         User may choose to add batching logic """
@@ -236,8 +256,15 @@ def process_queries(self):
             tik2 = time.time()
 
             if self.api_server:
-                with ThreadPoolExecutor(max_workers=bs) as executor:
-                    output = list(executor.map(self.query_api,cleaned))
+                if self.grpc:
+                    if self.batch_grpc:
+                        output = self.query_api_batch_grpc(cleaned)
+                    else:
+                        with ThreadPoolExecutor(max_workers=bs) as executor:
+                            output = list(executor.map(self.query_api_grpc,cleaned))
+                else:
+                    with ThreadPoolExecutor(max_workers=bs) as executor:
+                        output = list(executor.map(self.query_api,cleaned))
             else:
                 pred_output_tokens = self.model.generate(
                     input_ids=input_ids_tensor,
@@ -277,7 +304,16 @@ def process_queries(self):
 
     def load_model(self):
         if self.api_server:
-            if not "http" in self.api_server:
+            if self.grpc:
+                hostname = re.sub("https://|http://", "", self.api_server)
+                if hostname[-1] == "/":
+                    hostname = hostname[:-1]
+                self.grpc_client = GrpcClient(
+                    hostname,
+                    443,
+                    verify=False,
+                )
+            elif not "http" in self.api_server:
                 self.api_server = "http://" + self.api_server
 
             if not self.api_model_name:
@@ -340,9 +376,9 @@ def __del__(self):
 
 
 class SUTServer(SUT):
-    def __init__(self, model_path=None, api_server=None, api_model_name=None, dtype="bfloat16", device="cpu", total_sample_count=24576, dataset_path=None, workers=1):
+    def __init__(self, model_path=None, api_server=None, api_model_name=None, grpc=False, dtype="bfloat16", device="cpu", total_sample_count=24576, dataset_path=None, workers=1):
 
-        super().__init__(model_path=model_path, api_server=api_server, api_model_name=api_model_name, dtype=dtype, device=device, total_sample_count=total_sample_count, dataset_path=dataset_path, workers=workers)
+        super().__init__(model_path=model_path, api_server=api_server, api_model_name=api_model_name, grpc=grpc, dtype=dtype, device=device, total_sample_count=total_sample_count, dataset_path=dataset_path, workers=workers)
 
         with open(f"{self.model_path}/tokenizer.json", 'r') as token_file:
             llama_tokenizer = json.load(token_file)
diff --git a/language/llama2-70b/api-endpoint-artifacts/Dockerfile-API b/language/llama2-70b/api-endpoint-artifacts/Dockerfile-API
index e7231ba0e2..77385d42ea 100644
--- a/language/llama2-70b/api-endpoint-artifacts/Dockerfile-API
+++ b/language/llama2-70b/api-endpoint-artifacts/Dockerfile-API
@@ -7,7 +7,7 @@ COPY llama-model-info ./llama-model-info
 RUN apt-get update && apt install build-essential -y
 RUN conda install pybind11==2.10.4 -c conda-forge -y
 RUN pip install --upgrade pip
-RUN pip install transformers==4.31.0 nltk==3.8.1 evaluate==0.4.0 absl-py==1.4.0 rouge-score==0.1.2 sentencepiece==0.1.99 accelerate==0.21.0
+RUN pip install transformers==4.31.0 nltk==3.8.1 evaluate==0.4.0 absl-py==1.4.0 rouge-score==0.1.2 sentencepiece==0.1.99 accelerate==0.21.0 grpcio-tools
 RUN cd inference/loadgen && python -m pip install .
 RUN cp inference/mlperf.conf inference/language/llama2-70b/mlperf.conf
 
diff --git a/language/llama2-70b/api-endpoint-artifacts/benchmark.yaml b/language/llama2-70b/api-endpoint-artifacts/benchmark.yaml
index 012e78c332..79001774d1 100644
--- a/language/llama2-70b/api-endpoint-artifacts/benchmark.yaml
+++ b/language/llama2-70b/api-endpoint-artifacts/benchmark.yaml
@@ -6,7 +6,7 @@ spec:
   restartPolicy: Never
   containers:
     - name: mlperf-env
-      image: quay.io/meyceoz/mlperf-inference:v4
+      image: quay.io/meyceoz/mlperf-inference:grpc-batch
       resources:
         requests:
           memory: 20000Mi
diff --git a/language/llama2-70b/api-endpoint-artifacts/model-tgis.yaml b/language/llama2-70b/api-endpoint-artifacts/model-tgis.yaml
new file mode 100644
index 0000000000..6f920658b5
--- /dev/null
+++ b/language/llama2-70b/api-endpoint-artifacts/model-tgis.yaml
@@ -0,0 +1,20 @@
+apiVersion: serving.kserve.io/v1beta1
+kind: InferenceService
+metadata:
+  annotations:
+    serving.knative.openshift.io/enablePassthrough: "true"
+    sidecar.istio.io/inject: "true"
+    sidecar.istio.io/rewriteAppHTTPProbers: "true"
+  name: llama-2-70b-chat-isvc
+spec:
+  predictor:
+    minReplicas: 1
+    maxReplicas: 1
+    #apiVersion: serving.kserve.io/v1alpha2
+    serviceAccountName: sa
+    #timeout: 240
+    model:
+      modelFormat:
+        name: pytorch
+      runtime: tgis-runtime-grpc
+      storageUri: s3://mlperf-inference-models/Llama-2-70b-chat-hf/
\ No newline at end of file
diff --git a/language/llama2-70b/api-endpoint-artifacts/serving-tgis.yaml b/language/llama2-70b/api-endpoint-artifacts/serving-tgis.yaml
new file mode 100644
index 0000000000..b461481338
--- /dev/null
+++ b/language/llama2-70b/api-endpoint-artifacts/serving-tgis.yaml
@@ -0,0 +1,67 @@
+apiVersion: serving.kserve.io/v1alpha1
+kind: ServingRuntime
+metadata:
+  name: tgis-runtime-grpc
+spec:
+  multiModel: false
+  supportedModelFormats:
+    - autoSelect: true
+      name: pytorch
+  containers:
+    - name: kserve-container
+      image: quay.io/opendatahub/text-generation-inference:fast
+      command: ["text-generation-launcher"]
+      args:
+        - "--model-name=/mnt/models/"
+        - "--port=3000"
+        - "--grpc-port=8033"
+      env:
+        - name: TRANSFORMERS_CACHE
+          value: /tmp/transformers_cache
+        - name: NUM_GPUS
+          value: "8"
+        #- name: DTYPE_STR
+        #  value: float16
+        # Dynamic batch size changes
+        - name: MAX_BATCH_SIZE
+          value: "256"
+        - name: MAX_CONCURRENT_REQUESTS
+          value: "300"
+        - name: MAX_BATCH_WEIGHT
+          value: "550000"
+        - name: MAX_SEQUENCE_LENGTH
+          value: "2048"
+        - name: MAX_PREFILL_WEIGHT
+          value: "0"
+        - name: MAX_NEW_TOKENS
+          value: "1024"
+        - name: FLASH_ATTENTION
+          value: "true"
+        - name: DEPLOYMENT_FRAMEWORK
+          value: hf_custom_tp
+        - name: LOG_GPU_USAGE_INTERVAL
+          value: "5"
+      resources: # configure as required
+        requests:
+          cpu: 64
+          memory: 900Gi
+          nvidia.com/gpu: 8
+        limits:
+          nvidia.com/gpu: 8
+      readinessProbe: # Use exec probes instead of httpGet since the probes' port gets rewritten to the containerPort
+        exec:
+          command:
+            - curl
+            - localhost:3000/health
+        initialDelaySeconds: 5
+      livenessProbe:
+        exec:
+          command:
+            - curl
+            - localhost:3000/health
+        initialDelaySeconds: 5
+        # periodSeconds: 5
+      ports:
+        - containerPort: 8033
+          name: h2c
+          protocol: TCP
\ No newline at end of file
diff --git a/language/llama2-70b/generation_pb2.py b/language/llama2-70b/generation_pb2.py
new file mode 100644
index 0000000000..544e1ec7db
--- /dev/null
+++ b/language/llama2-70b/generation_pb2.py
@@ -0,0 +1,70 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: generation.proto +# Protobuf Python Version: 4.25.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10generation.proto\x12\x05\x66maas\"\xa1\x01\n\x18\x42\x61tchedGenerationRequest\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x16\n\tprefix_id\x18\x02 \x01(\tH\x00\x88\x01\x01\x12*\n\x08requests\x18\x03 \x03(\x0b\x32\x18.fmaas.GenerationRequest\x12!\n\x06params\x18\n \x01(\x0b\x32\x11.fmaas.ParametersB\x0c\n\n_prefix_id\"\x9f\x01\n\x17SingleGenerationRequest\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x16\n\tprefix_id\x18\x02 \x01(\tH\x00\x88\x01\x01\x12)\n\x07request\x18\x03 \x01(\x0b\x32\x18.fmaas.GenerationRequest\x12!\n\x06params\x18\n \x01(\x0b\x32\x11.fmaas.ParametersB\x0c\n\n_prefix_id\"I\n\x19\x42\x61tchedGenerationResponse\x12,\n\tresponses\x18\x01 \x03(\x0b\x32\x19.fmaas.GenerationResponse\"!\n\x11GenerationRequest\x12\x0c\n\x04text\x18\x02 \x01(\t\"\xf3\x01\n\x12GenerationResponse\x12\x19\n\x11input_token_count\x18\x06 \x01(\r\x12\x1d\n\x15generated_token_count\x18\x02 \x01(\r\x12\x0c\n\x04text\x18\x04 \x01(\t\x12&\n\x0bstop_reason\x18\x07 \x01(\x0e\x32\x11.fmaas.StopReason\x12\x15\n\rstop_sequence\x18\x0b \x01(\t\x12\x0c\n\x04seed\x18\n \x01(\x04\x12 \n\x06tokens\x18\x08 \x03(\x0b\x32\x10.fmaas.TokenInfo\x12&\n\x0cinput_tokens\x18\t \x03(\x0b\x32\x10.fmaas.TokenInfo\"\x81\x02\n\nParameters\x12%\n\x06method\x18\x01 \x01(\x0e\x32\x15.fmaas.DecodingMethod\x12+\n\x08sampling\x18\x02 \x01(\x0b\x32\x19.fmaas.SamplingParameters\x12)\n\x08stopping\x18\x03 \x01(\x0b\x32\x17.fmaas.StoppingCriteria\x12(\n\x08response\x18\x04 \x01(\x0b\x32\x16.fmaas.ResponseOptions\x12+\n\x08\x64\x65\x63oding\x18\x05 \x01(\x0b\x32\x19.fmaas.DecodingParameters\x12\x1d\n\x15truncate_input_tokens\x18\x06 \x01(\r\"\xc5\x01\n\x12\x44\x65\x63odingParameters\x12\x1a\n\x12repetition_penalty\x18\x01 \x01(\x02\x12\x44\n\x0elength_penalty\x18\x02 \x01(\x0b\x32\'.fmaas.DecodingParameters.LengthPenaltyH\x00\x88\x01\x01\x1a:\n\rLengthPenalty\x12\x13\n\x0bstart_index\x18\x01 \x01(\r\x12\x14\n\x0c\x64\x65\x63\x61y_factor\x18\x02 \x01(\x02\x42\x11\n\x0f_length_penalty\"v\n\x12SamplingParameters\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_k\x18\x02 \x01(\r\x12\r\n\x05top_p\x18\x03 \x01(\x02\x12\x11\n\ttypical_p\x18\x04 \x01(\x02\x12\x11\n\x04seed\x18\x05 \x01(\x04H\x00\x88\x01\x01\x42\x07\n\x05_seed\"\xb3\x01\n\x10StoppingCriteria\x12\x16\n\x0emax_new_tokens\x18\x01 \x01(\r\x12\x16\n\x0emin_new_tokens\x18\x02 \x01(\r\x12\x19\n\x11time_limit_millis\x18\x03 \x01(\r\x12\x16\n\x0estop_sequences\x18\x04 \x03(\t\x12\"\n\x15include_stop_sequence\x18\x05 \x01(\x08H\x00\x88\x01\x01\x42\x18\n\x16_include_stop_sequence\"\x98\x01\n\x0fResponseOptions\x12\x12\n\ninput_text\x18\x01 \x01(\x08\x12\x18\n\x10generated_tokens\x18\x02 \x01(\x08\x12\x14\n\x0cinput_tokens\x18\x03 \x01(\x08\x12\x16\n\x0etoken_logprobs\x18\x04 \x01(\x08\x12\x13\n\x0btoken_ranks\x18\x05 \x01(\x08\x12\x14\n\x0ctop_n_tokens\x18\x06 \x01(\r\"\x92\x01\n\tTokenInfo\x12\x0c\n\x04text\x18\x02 \x01(\t\x12\x0f\n\x07logprob\x18\x03 \x01(\x02\x12\x0c\n\x04rank\x18\x04 \x01(\r\x12-\n\ntop_tokens\x18\x05 
\x03(\x0b\x32\x19.fmaas.TokenInfo.TopToken\x1a)\n\x08TopToken\x12\x0c\n\x04text\x18\x02 \x01(\t\x12\x0f\n\x07logprob\x18\x03 \x01(\x02\"k\n\x16\x42\x61tchedTokenizeRequest\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12(\n\x08requests\x18\x02 \x03(\x0b\x32\x16.fmaas.TokenizeRequest\x12\x15\n\rreturn_tokens\x18\x03 \x01(\x08\"E\n\x17\x42\x61tchedTokenizeResponse\x12*\n\tresponses\x18\x01 \x03(\x0b\x32\x17.fmaas.TokenizeResponse\"\x1f\n\x0fTokenizeRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\"7\n\x10TokenizeResponse\x12\x13\n\x0btoken_count\x18\x01 \x01(\r\x12\x0e\n\x06tokens\x18\x02 \x03(\t\"$\n\x10ModelInfoRequest\x12\x10\n\x08model_id\x18\x01 \x01(\t\"\xb4\x01\n\x11ModelInfoResponse\x12\x36\n\nmodel_kind\x18\x01 \x01(\x0e\x32\".fmaas.ModelInfoResponse.ModelKind\x12\x1b\n\x13max_sequence_length\x18\x02 \x01(\r\x12\x16\n\x0emax_new_tokens\x18\x03 \x01(\r\"2\n\tModelKind\x12\x10\n\x0c\x44\x45\x43ODER_ONLY\x10\x00\x12\x13\n\x0f\x45NCODER_DECODER\x10\x01*(\n\x0e\x44\x65\x63odingMethod\x12\n\n\x06GREEDY\x10\x00\x12\n\n\x06SAMPLE\x10\x01*\x8b\x01\n\nStopReason\x12\x10\n\x0cNOT_FINISHED\x10\x00\x12\x0e\n\nMAX_TOKENS\x10\x01\x12\r\n\tEOS_TOKEN\x10\x02\x12\r\n\tCANCELLED\x10\x03\x12\x0e\n\nTIME_LIMIT\x10\x04\x12\x11\n\rSTOP_SEQUENCE\x10\x05\x12\x0f\n\x0bTOKEN_LIMIT\x10\x06\x12\t\n\x05\x45RROR\x10\x07\x32\xc4\x02\n\x11GenerationService\x12O\n\x08Generate\x12\x1f.fmaas.BatchedGenerationRequest\x1a .fmaas.BatchedGenerationResponse\"\x00\x12O\n\x0eGenerateStream\x12\x1e.fmaas.SingleGenerationRequest\x1a\x19.fmaas.GenerationResponse\"\x00\x30\x01\x12K\n\x08Tokenize\x12\x1d.fmaas.BatchedTokenizeRequest\x1a\x1e.fmaas.BatchedTokenizeResponse\"\x00\x12@\n\tModelInfo\x12\x17.fmaas.ModelInfoRequest\x1a\x18.fmaas.ModelInfoResponse\"\x00\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'generation_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals['_DECODINGMETHOD']._serialized_start=2266 + _globals['_DECODINGMETHOD']._serialized_end=2306 + _globals['_STOPREASON']._serialized_start=2309 + _globals['_STOPREASON']._serialized_end=2448 + _globals['_BATCHEDGENERATIONREQUEST']._serialized_start=28 + _globals['_BATCHEDGENERATIONREQUEST']._serialized_end=189 + _globals['_SINGLEGENERATIONREQUEST']._serialized_start=192 + _globals['_SINGLEGENERATIONREQUEST']._serialized_end=351 + _globals['_BATCHEDGENERATIONRESPONSE']._serialized_start=353 + _globals['_BATCHEDGENERATIONRESPONSE']._serialized_end=426 + _globals['_GENERATIONREQUEST']._serialized_start=428 + _globals['_GENERATIONREQUEST']._serialized_end=461 + _globals['_GENERATIONRESPONSE']._serialized_start=464 + _globals['_GENERATIONRESPONSE']._serialized_end=707 + _globals['_PARAMETERS']._serialized_start=710 + _globals['_PARAMETERS']._serialized_end=967 + _globals['_DECODINGPARAMETERS']._serialized_start=970 + _globals['_DECODINGPARAMETERS']._serialized_end=1167 + _globals['_DECODINGPARAMETERS_LENGTHPENALTY']._serialized_start=1090 + _globals['_DECODINGPARAMETERS_LENGTHPENALTY']._serialized_end=1148 + _globals['_SAMPLINGPARAMETERS']._serialized_start=1169 + _globals['_SAMPLINGPARAMETERS']._serialized_end=1287 + _globals['_STOPPINGCRITERIA']._serialized_start=1290 + _globals['_STOPPINGCRITERIA']._serialized_end=1469 + _globals['_RESPONSEOPTIONS']._serialized_start=1472 + _globals['_RESPONSEOPTIONS']._serialized_end=1624 + _globals['_TOKENINFO']._serialized_start=1627 + _globals['_TOKENINFO']._serialized_end=1773 + 
_globals['_TOKENINFO_TOPTOKEN']._serialized_start=1732 + _globals['_TOKENINFO_TOPTOKEN']._serialized_end=1773 + _globals['_BATCHEDTOKENIZEREQUEST']._serialized_start=1775 + _globals['_BATCHEDTOKENIZEREQUEST']._serialized_end=1882 + _globals['_BATCHEDTOKENIZERESPONSE']._serialized_start=1884 + _globals['_BATCHEDTOKENIZERESPONSE']._serialized_end=1953 + _globals['_TOKENIZEREQUEST']._serialized_start=1955 + _globals['_TOKENIZEREQUEST']._serialized_end=1986 + _globals['_TOKENIZERESPONSE']._serialized_start=1988 + _globals['_TOKENIZERESPONSE']._serialized_end=2043 + _globals['_MODELINFOREQUEST']._serialized_start=2045 + _globals['_MODELINFOREQUEST']._serialized_end=2081 + _globals['_MODELINFORESPONSE']._serialized_start=2084 + _globals['_MODELINFORESPONSE']._serialized_end=2264 + _globals['_MODELINFORESPONSE_MODELKIND']._serialized_start=2214 + _globals['_MODELINFORESPONSE_MODELKIND']._serialized_end=2264 + _globals['_GENERATIONSERVICE']._serialized_start=2451 + _globals['_GENERATIONSERVICE']._serialized_end=2775 +# @@protoc_insertion_point(module_scope) diff --git a/language/llama2-70b/generation_pb2_grpc.py b/language/llama2-70b/generation_pb2_grpc.py new file mode 100644 index 0000000000..9ab1e4eb5d --- /dev/null +++ b/language/llama2-70b/generation_pb2_grpc.py @@ -0,0 +1,169 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import generation_pb2 as generation__pb2 + + +class GenerationServiceStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.Generate = channel.unary_unary( + '/fmaas.GenerationService/Generate', + request_serializer=generation__pb2.BatchedGenerationRequest.SerializeToString, + response_deserializer=generation__pb2.BatchedGenerationResponse.FromString, + ) + self.GenerateStream = channel.unary_stream( + '/fmaas.GenerationService/GenerateStream', + request_serializer=generation__pb2.SingleGenerationRequest.SerializeToString, + response_deserializer=generation__pb2.GenerationResponse.FromString, + ) + self.Tokenize = channel.unary_unary( + '/fmaas.GenerationService/Tokenize', + request_serializer=generation__pb2.BatchedTokenizeRequest.SerializeToString, + response_deserializer=generation__pb2.BatchedTokenizeResponse.FromString, + ) + self.ModelInfo = channel.unary_unary( + '/fmaas.GenerationService/ModelInfo', + request_serializer=generation__pb2.ModelInfoRequest.SerializeToString, + response_deserializer=generation__pb2.ModelInfoResponse.FromString, + ) + + +class GenerationServiceServicer(object): + """Missing associated documentation comment in .proto file.""" + + def Generate(self, request, context): + """Generates text given a text prompt, for one or more inputs + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GenerateStream(self, request, context): + """Generates text given a single input prompt, streaming the response + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Tokenize(self, request, context): + """Tokenize text + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def 
ModelInfo(self, request, context): + """Model info + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_GenerationServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'Generate': grpc.unary_unary_rpc_method_handler( + servicer.Generate, + request_deserializer=generation__pb2.BatchedGenerationRequest.FromString, + response_serializer=generation__pb2.BatchedGenerationResponse.SerializeToString, + ), + 'GenerateStream': grpc.unary_stream_rpc_method_handler( + servicer.GenerateStream, + request_deserializer=generation__pb2.SingleGenerationRequest.FromString, + response_serializer=generation__pb2.GenerationResponse.SerializeToString, + ), + 'Tokenize': grpc.unary_unary_rpc_method_handler( + servicer.Tokenize, + request_deserializer=generation__pb2.BatchedTokenizeRequest.FromString, + response_serializer=generation__pb2.BatchedTokenizeResponse.SerializeToString, + ), + 'ModelInfo': grpc.unary_unary_rpc_method_handler( + servicer.ModelInfo, + request_deserializer=generation__pb2.ModelInfoRequest.FromString, + response_serializer=generation__pb2.ModelInfoResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'fmaas.GenerationService', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class GenerationService(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def Generate(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/fmaas.GenerationService/Generate', + generation__pb2.BatchedGenerationRequest.SerializeToString, + generation__pb2.BatchedGenerationResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GenerateStream(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/fmaas.GenerationService/GenerateStream', + generation__pb2.SingleGenerationRequest.SerializeToString, + generation__pb2.GenerationResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Tokenize(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/fmaas.GenerationService/Tokenize', + generation__pb2.BatchedTokenizeRequest.SerializeToString, + generation__pb2.BatchedTokenizeResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def ModelInfo(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/fmaas.GenerationService/ModelInfo', + generation__pb2.ModelInfoRequest.SerializeToString, + 
generation__pb2.ModelInfoResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/language/llama2-70b/inference.py b/language/llama2-70b/inference.py new file mode 100644 index 0000000000..e62a513312 --- /dev/null +++ b/language/llama2-70b/inference.py @@ -0,0 +1,186 @@ +from typing import Optional, Union + +import grpc +import generation_pb2_grpc +import socket +import ssl +import sys + + +def get_server_certificate(host: str, port: int) -> str: + """connect to host:port and get the certificate it presents + + This is almost the same as `ssl.get_server_certificate`, but + when opening the TLS socket, `server_hostname` is also provided. + + This retrieves the correct certificate for hosts using name-based + virtual hosting. + """ + if sys.version_info >= (3, 10): + # ssl.get_server_certificate supports TLS SNI only above 3.10 + # https://github.com/python/cpython/pull/16820 + return ssl.get_server_certificate((host, port)) + + context = ssl.SSLContext() + + with socket.create_connection((host, port)) as sock, context.wrap_socket( + sock, server_hostname=host + ) as ssock: + cert_der = ssock.getpeercert(binary_form=True) + + assert cert_der + return ssl.DER_cert_to_PEM_cert(cert_der) + + +class GrpcClient: + def __init__( + self, + host: str, + port: int, + *, + insecure: bool = False, + verify: Optional[bool] = None, + ca_cert: Union[None, bytes, str] = None, + client_cert: Union[None, bytes, str] = None, + client_key: Union[None, bytes, str] = None, + ) -> None: + self._channel = self._make_channel( + host, + port, + insecure=insecure, + verify=verify, + client_key=client_key, + client_cert=client_cert, + ca_cert=ca_cert, + ) + self.generation_service_stub = generation_pb2_grpc.GenerationServiceStub( + self._channel + ) + + def make_request(self, texts: str, model_id: str = "flan-t5-small"): + request = generation_pb2_grpc.generation__pb2.BatchedGenerationRequest( + model_id=model_id, + requests=[generation_pb2_grpc.generation__pb2.GenerationRequest(text=text) for text in texts], + params=generation_pb2_grpc.generation__pb2.Parameters( + method=generation_pb2_grpc.generation__pb2.GREEDY, + stopping=generation_pb2_grpc.generation__pb2.StoppingCriteria( + max_new_tokens=1024, + min_new_tokens=1 + ) + ) + ) + result = self.generation_service_stub.Generate(request=request) + return result + + def __enter__(self): + return self + + def __exit__(self, *exc_info): + self._close() + return False + + def _close(self): + try: + if hasattr(self, "_channel") and self._channel: + self._channel.close() + except Exception as exc: + print(f"Unexpected exception while closing client: {exc}") + + def __del__(self): + self._close() + + def _make_channel( + self, + host: str, + port: int, + *, + insecure: bool = False, + verify: Optional[bool] = None, + ca_cert: Union[None, bytes, str] = None, + client_key: Union[None, bytes, str] = None, + client_cert: Union[None, bytes, str] = None, + ) -> grpc.Channel: + """Creates a grpc channel + + Args: + - host: str + - port: str + - (optional) insecure: use a plaintext connection (default=False) + - (optional) verify: set to False to disable remote host certificate(s) + verification. 
Cannot be used with `insecure` or with mTLS
+        - (optional) ca_cert: certificate authority to use
+        - (optional) client_key: client key for mTLS mode
+        - (optional) client_cert: client cert for mTLS mode
+
+        """
+        if not host.strip():
+            raise ValueError("A non-empty host name is required")
+        if int(port) <= 0:
+            raise ValueError("A non-zero port number is required")
+        if insecure and any(
+            (val is not None) for val in (ca_cert, client_key, client_cert)
+        ):
+            raise ValueError("cannot use insecure with TLS/mTLS certificates")
+        if insecure and verify:
+            raise ValueError("insecure cannot be used with verify")
+
+        client_key_bytes = self._try_load_certificate(client_key)
+        client_cert_bytes = self._try_load_certificate(client_cert)
+        ca_cert_bytes = self._try_load_certificate(ca_cert)
+
+        connection = f"{host}:{port}"
+        if insecure:
+            print("Connecting over an insecure plaintext grpc channel")
+            return grpc.insecure_channel(connection)
+
+        credentials_kwargs: dict[str, bytes] = {}
+        if ca_cert_bytes and not (any((client_cert_bytes, client_key_bytes))):
+            print("Connecting using provided CA certificate for secure channel")
+            credentials_kwargs.update(root_certificates=ca_cert_bytes)
+        elif client_cert_bytes and client_key_bytes and ca_cert_bytes:
+            print("Connecting using mTLS for secure channel")
+            credentials_kwargs.update(
+                root_certificates=ca_cert_bytes,
+                private_key=client_key_bytes,
+                certificate_chain=client_cert_bytes,
+            )
+        elif verify is False:
+            cert = get_server_certificate(host, port).encode()
+            credentials_kwargs.update(root_certificates=cert)
+
+        return grpc.secure_channel(
+            connection, grpc.ssl_channel_credentials(**credentials_kwargs)
+        )
+
+    @staticmethod
+    def _try_load_certificate(certificate: Union[None, bytes, str]) -> Optional[bytes]:
+        """If the certificate points to a file, return the contents (plaintext reads).
+        Else return the bytes"""
+        if not certificate:
+            return None
+
+        if isinstance(certificate, bytes):
+            return certificate
+
+        if isinstance(certificate, str):
+            with open(certificate, "rb") as secret_file:
+                return secret_file.read()
+        raise ValueError(
+            f"{certificate=} should be a path to a certificate file or bytes"
+        )
+
+
+if __name__ == "__main__":
+    # host = "flan-t5-small-predictor-caikit-testing.apps.aisrhods-wx.8goc.p1.openshiftapps.com"
+    # port = 443
+    host = "localhost"
+    port = 8033
+    print(f"connecting to {host=}:{port}")
+    client = GrpcClient(
+        host,
+        port,
+        insecure=True,
+        # verify=False,
+    )
+
+    client.make_request(["this is the query text"], model_id="flan-t5-small")
diff --git a/language/llama2-70b/main.py b/language/llama2-70b/main.py
index 424d48df7a..6ccddd0e7a 100644
--- a/language/llama2-70b/main.py
+++ b/language/llama2-70b/main.py
@@ -19,6 +19,8 @@ def get_args():
     parser.add_argument("--api-server", type=str, default=None, help="Specify an api endpoint call to use api mode")
     parser.add_argument("--api-model-name", type=str, default=None, help="Specify a model name to use api mode")
     parser.add_argument("--accuracy", action="store_true", help="Run accuracy mode")
+    parser.add_argument("--grpc", action="store_true", help="Enable grpc for api endpoint")
+    parser.add_argument("--batch-grpc", action="store_true", help="Enable batch requests for grpc")
    parser.add_argument("--dtype", type=str, default="float32", help="data type of the model, choose from float16, bfloat16 and float32")
     parser.add_argument("--device", type=str, choices=["cpu", "cuda:0"], default="cpu", help="device to use")
     parser.add_argument("--audit-conf", type=str, default="audit.conf", help="audit config for LoadGen settings during compliance runs")
@@ -72,6 +74,8 @@ def main():
         model_path=args.model_path,
         api_server=args.api_server,
         api_model_name=args.api_model_name,
+        grpc=args.grpc,
+        batch_grpc=args.batch_grpc,
         dtype=args.dtype,
         dataset_path=args.dataset_path,
         total_sample_count=args.total_sample_count,
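Usage sketch (illustrative only, not part of the diff): the harness is launched as before through main.py with --api-server and --api-model-name, plus the new --grpc flag and, optionally, --batch-grpc; SUT.query_api_batch_grpc then sends one batched request through the GrpcClient added in inference.py. The snippet below shows that same call path in isolation; the endpoint hostname and model name are placeholders for a real KServe/TGIS deployment.

# Assumes a reachable TGIS gRPC endpoint; hostname and model_id are placeholders.
from inference import GrpcClient

# Mirrors SUT.load_model: connect on port 443 and skip certificate verification.
client = GrpcClient("llama-2-70b-chat-isvc.example.com", 443, verify=False)

# Mirrors SUT.query_api_batch_grpc: one BatchedGenerationRequest for a list of prompts.
resp = client.make_request(["prompt one", "prompt two"], model_id="Llama-2-70b-chat-hf")
for r in resp.responses:
    print(r.generated_token_count, r.text)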