
Streaming first pass
Maxusmusti committed Jan 25, 2024
1 parent b1173cf commit 15a8805
Showing 4 changed files with 41 additions and 8 deletions.
language/llama2-70b/SUT.py: 20 additions & 2 deletions
@@ -112,7 +112,7 @@ def __init__(self,
 
         if not batch_size:
             if device == "cpu":
-                batch_size = 300
+                batch_size = 512
             else:
                 batch_size = 32 # Reduce to 8 if using 4 GPUs, 16 for 8.
         self.batch_size = batch_size
@@ -455,10 +455,28 @@ def stream_api(self, input, response_ids):
                     token_cache.append(token)
         return token_cache
 
+    def stream_api_grpc(self, input, response_ids):
+        token_cache = []
+        first = True
+        resps = self.grpc_client.make_request_stream(input, model_id=self.api_model_name)
+        for resp in resps:
+            if resp.text:
+                tokens = self.tokenizer(resp.text)["input_ids"][1:]
+                if first:
+                    self.first_token_queue.put((tokens[0], response_ids[0]))
+                    token_cache.extend(tokens[1:])
+                    first = False
+                else:
+                    token_cache.extend(tokens)
+        return token_cache
+
     def async_process_query(self, input_ids_tensor, qitem_id):
         decoded = self.tokenizer.decode(input_ids_tensor[0])
         response_ids = [qitem_id]
-        output_tokens = self.stream_api(decoded, response_ids)
+        if self.grpc:
+            output_tokens = self.stream_api_grpc(decoded, response_ids)
+        else:
+            output_tokens = self.stream_api(decoded, response_ids)
 
         response_array = array.array("B", np.array(output_tokens, np.int32).tobytes())
         bi = response_array.buffer_info()
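The new stream_api_grpc path pulls the very first generated token out of the stream and hands it to first_token_queue (so time-to-first-token can be reported as soon as it arrives), while the remaining tokens are accumulated and returned when generation finishes. Below is a minimal, self-contained sketch of that split; the queue, the stand-in tokenizer, and the fake chunk stream are illustrative substitutes for the SUT's gRPC client and HuggingFace tokenizer, not code from this commit.

import queue

def split_first_token(chunks, tokenize, first_token_queue, response_id):
    # chunks: iterable of text fragments as they stream in from the server
    # tokenize: callable mapping a text fragment to a list of token ids
    token_cache = []
    first = True
    for text in chunks:
        if not text:
            continue  # skip empty chunks
        tokens = tokenize(text)
        if first and tokens:
            first_token_queue.put((tokens[0], response_id))  # publish the first token immediately
            token_cache.extend(tokens[1:])
            first = False
        else:
            token_cache.extend(tokens)
    return token_cache  # the rest, reported once the stream ends

# Illustrative use with a fake stream and a per-character "tokenizer"
q = queue.Queue()
rest = split_first_token(["Hel", "lo", "!"], lambda t: [ord(c) for c in t], q, response_id=0)
print(q.get())    # (72, 0): the first token is available after the first chunk
print(len(rest))  # 5 remaining token ids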
language/llama2-70b/api-endpoint-artifacts/benchmark.yaml: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ spec:
   restartPolicy: Never
   containers:
     - name: mlperf-env
-      image: quay.io/meyceoz/mlperf-inference:grpc-batch
+      image: quay.io/meyceoz/mlperf-inference:grpc-stream
       resources:
         requests:
           memory: 20000Mi
language/llama2-70b/api-endpoint-artifacts/serving-tgis.yaml: 5 additions & 5 deletions
@@ -24,11 +24,11 @@ spec:
         # value: float16
         # Dynamic batch size changes
         - name: MAX_BATCH_SIZE
-          value: "256"
+          value: "512"
         - name: MAX_CONCURRENT_REQUESTS
-          value: "300"
+          value: "5000"
         - name: MAX_BATCH_WEIGHT
-          value: "550000"
+          value: "1000000"
         - name: MAX_SEQUENCE_LENGTH
           value: "2048"
         - name: MAX_PREFILL_WEIGHT
@@ -43,8 +43,8 @@
           value: "5"
       resources: # configure as required
         requests:
-          cpu: 64
-          memory: 900Gi
+          cpu: 96
+          memory: 1000Gi
           nvidia.com/gpu: 8
         limits:
           nvidia.com/gpu: 8
language/llama2-70b/inference.py: 15 additions & 0 deletions
@@ -71,6 +71,21 @@ def make_request(self, texts: str, model_id: str = "flan-t5-small"):
         )
         result = self.generation_service_stub.Generate(request=request)
         return result
 
+    def make_request_stream(self, text: str, model_id: str = "flan-t5-small"):
+        request = generation_pb2_grpc.generation__pb2.SingleGenerationRequest(
+            model_id=model_id,
+            request=generation_pb2_grpc.generation__pb2.GenerationRequest(text=text),
+            params=generation_pb2_grpc.generation__pb2.Parameters(
+                method=generation_pb2_grpc.generation__pb2.GREEDY,
+                stopping=generation_pb2_grpc.generation__pb2.StoppingCriteria(
+                    max_new_tokens=1024,
+                    min_new_tokens=1
+                )
+            )
+        )
+        result = self.generation_service_stub.GenerateStream(request=request)
+        return result
+
     def __enter__(self):
         return self
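GenerateStream returns an iterator of response messages whose text fields arrive incrementally; SUT.stream_api_grpc simply loops over it. A hedged usage sketch follows. The host, port, and model name are placeholders, and the assumption that GrpcClient is constructed with a host/port pair and used as a context manager (per __enter__ above) is taken from the surrounding file, not from this diff.

# Sketch only: assumes a reachable TGIS gRPC endpoint and the GrpcClient
# wrapper defined in inference.py; the constructor arguments are assumptions.
with GrpcClient("localhost", 8033) as client:
    stream = client.make_request_stream(
        "Summarize the plot of Hamlet.", model_id="llama-2-70b-chat"
    )
    pieces = []
    for resp in stream:   # messages arrive as tokens are generated
        if resp.text:     # guard against empty chunks, as SUT.py does
            pieces.append(resp.text)
    print("".join(pieces))  # the full completion once the stream ends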
