-
Notifications
You must be signed in to change notification settings - Fork 66
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
278 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
155 changes: 155 additions & 0 deletions
155
text-generation-inference/integration-tests/conftest.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
import asyncio | ||
import contextlib | ||
import os | ||
import random | ||
import shlex | ||
import subprocess | ||
import sys | ||
import time | ||
from tempfile import TemporaryDirectory | ||
from typing import List | ||
|
||
import docker | ||
import pytest | ||
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError | ||
from docker.errors import NotFound | ||
from text_generation import AsyncClient | ||
from text_generation.types import Response | ||
|
||
|
||
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", "neuronx-tgi:latest") | ||
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None) | ||
DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data") | ||
|
||
|
||
class LauncherHandle: | ||
def __init__(self, port: int): | ||
self.client = AsyncClient(f"http://localhost:{port}") | ||
|
||
def _inner_health(self): | ||
raise NotImplementedError | ||
|
||
async def health(self, timeout: int = 60): | ||
assert timeout > 0 | ||
for _ in range(timeout): | ||
if not self._inner_health(): | ||
raise RuntimeError("Launcher crashed") | ||
|
||
try: | ||
await self.client.generate("test") | ||
return | ||
except (ClientConnectorError, ClientOSError, ServerDisconnectedError): | ||
time.sleep(1) | ||
raise RuntimeError("Health check failed") | ||
|
||
|
||
class ContainerLauncherHandle(LauncherHandle): | ||
def __init__(self, docker_client, container_name, port: int): | ||
super(ContainerLauncherHandle, self).__init__(port) | ||
self.docker_client = docker_client | ||
self.container_name = container_name | ||
|
||
def _inner_health(self) -> bool: | ||
container = self.docker_client.containers.get(self.container_name) | ||
return container.status in ["running", "created"] | ||
|
||
|
||
class ProcessLauncherHandle(LauncherHandle): | ||
def __init__(self, process, port: int): | ||
super(ProcessLauncherHandle, self).__init__(port) | ||
self.process = process | ||
|
||
def _inner_health(self) -> bool: | ||
return self.process.poll() is None | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def event_loop(): | ||
loop = asyncio.get_event_loop() | ||
yield loop | ||
loop.close() | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def data_volume(): | ||
tmpdir = TemporaryDirectory() | ||
yield tmpdir.name | ||
# Cleanup the temporary directory using sudo as it contains root files created by the container | ||
subprocess.run(shlex.split(f"sudo rm -rf {tmpdir.name}")) | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def launcher(event_loop, data_volume): | ||
@contextlib.contextmanager | ||
def docker_launcher( | ||
model_id: str, | ||
trust_remote_code: bool = False, | ||
): | ||
port = random.randint(8000, 10_000) | ||
|
||
args = ["--model-id", model_id, "--env"] | ||
|
||
if trust_remote_code: | ||
args.append("--trust-remote-code") | ||
|
||
client = docker.from_env() | ||
|
||
container_name = f"tgi-tests-{model_id.split('/')[-1]}" | ||
|
||
try: | ||
container = client.containers.get(container_name) | ||
container.stop() | ||
container.wait() | ||
except NotFound: | ||
pass | ||
|
||
env = {"LOG_LEVEL": "info,text_generation_router=debug"} | ||
|
||
if HUGGING_FACE_HUB_TOKEN is not None: | ||
env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN | ||
|
||
for var in ["HF_BATCH_SIZE", "HF_SEQUENCE_LENGTH", "HF_AUTOCAST_TYPE", "HF_NUM_CORES"]: | ||
if var in os.environ: | ||
env[var] = os.environ[var] | ||
|
||
volumes = [f"{data_volume}:/data"] | ||
|
||
container = client.containers.run( | ||
DOCKER_IMAGE, | ||
command=args, | ||
name=container_name, | ||
environment=env, | ||
auto_remove=False, | ||
detach=True, | ||
devices=["/dev/neuron0"], | ||
volumes=volumes, | ||
ports={"80/tcp": port}, | ||
shm_size="1G", | ||
) | ||
|
||
yield ContainerLauncherHandle(client, container.name, port) | ||
|
||
try: | ||
container.stop() | ||
container.wait() | ||
except NotFound: | ||
pass | ||
|
||
container_output = container.logs().decode("utf-8") | ||
print(container_output, file=sys.stderr) | ||
|
||
container.remove() | ||
|
||
return docker_launcher | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def generate_load(): | ||
async def generate_load_inner(client: AsyncClient, prompt: str, max_new_tokens: int, n: int) -> List[Response]: | ||
futures = [ | ||
client.generate(prompt, max_new_tokens=max_new_tokens, decoder_input_details=True) for _ in range(n) | ||
] | ||
|
||
return await asyncio.gather(*futures) | ||
|
||
return generate_load_inner |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
[pytest] | ||
asyncio_mode = auto |
18 changes: 18 additions & 0 deletions
18
text-generation-inference/integration-tests/requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Copyright 2023 The HuggingFace Team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
text-generation >= 0.6.0 | ||
pytest >= 7.4.0 | ||
pytest-asyncio >= 0.21.1 | ||
docker >= 6.1.3 | ||
Levenshtein |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import os | ||
|
||
import huggingface_hub | ||
import Levenshtein | ||
import pytest | ||
|
||
|
||
MODEL_ID = "gpt2" | ||
NEURON_MODEL_ID = "aws-neuron/gpt2-neuronx-bs4-seqlen1024" | ||
BATCH_SIZE = 4 | ||
SEQUENCE_LENGTH = 1024 | ||
NUM_CORES = 2 | ||
|
||
|
||
@pytest.fixture(scope="module", params=["hub-neuron", "hub", "local-neuron"]) | ||
def model_name_or_path(request, data_volume): | ||
if request.param == "hub": | ||
os.environ["HF_BATCH_SIZE"] = str(BATCH_SIZE) | ||
os.environ["HF_SEQUENCE_LENGTH"] = str(SEQUENCE_LENGTH) | ||
os.environ["HF_NUM_CORES"] = str(NUM_CORES) | ||
yield MODEL_ID | ||
elif request.param == "hub-neuron": | ||
yield NEURON_MODEL_ID | ||
else: | ||
model_dir = f"gpt2-neuron-{BATCH_SIZE}x{SEQUENCE_LENGTH}x{NUM_CORES}" | ||
local_path = os.path.join(data_volume, model_dir) | ||
huggingface_hub.snapshot_download(NEURON_MODEL_ID, local_dir=local_path) | ||
# Return the path of the model inside the mounted volume | ||
yield os.path.join("/data", model_dir) | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def tgi_service(launcher, model_name_or_path): | ||
with launcher(model_name_or_path) as tgi_service: | ||
yield tgi_service | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
async def tgi_client(tgi_service): | ||
await tgi_service.health(300) | ||
return tgi_service.client | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_model_single_request(tgi_client): | ||
|
||
# Greedy bounded without input | ||
response = await tgi_client.generate( | ||
"What is Deep Learning?", | ||
max_new_tokens=17, | ||
decoder_input_details=True, | ||
) | ||
assert response.details.generated_tokens == 17 | ||
assert response.generated_text == "\n\nDeep learning is a new field of research that has been around for a while" | ||
|
||
# Greedy bounded with input | ||
response = await tgi_client.generate( | ||
"What is Deep Learning?", | ||
max_new_tokens=17, | ||
return_full_text=True, | ||
decoder_input_details=True, | ||
) | ||
assert response.details.generated_tokens == 17 | ||
assert ( | ||
response.generated_text | ||
== "What is Deep Learning?\n\nDeep learning is a new field of research that has been around for a while" | ||
) | ||
|
||
# Sampling | ||
response = await tgi_client.generate( | ||
"What is Deep Learning?", | ||
do_sample=True, | ||
top_k=50, | ||
top_p=0.9, | ||
repetition_penalty=1.2, | ||
max_new_tokens=1000, | ||
seed=42, | ||
decoder_input_details=True, | ||
) | ||
assert "The purpose of the current post is" in response.generated_text | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_model_multiple_requests(tgi_client, generate_load): | ||
num_requests = 4 | ||
responses = await generate_load( | ||
tgi_client, | ||
"What is Deep Learning?", | ||
max_new_tokens=17, | ||
n=num_requests, | ||
) | ||
|
||
assert len(responses) == 4 | ||
expected = "\n\nDeep learning is a new field of research that has been around for a while" | ||
for r in responses: | ||
assert r.details.generated_tokens == 17 | ||
# Compute the similarity with the expectation using the levenshtein distance | ||
# We should not have more than two substitutions or additions | ||
assert Levenshtein.distance(r.generated_text, expected) < 3 |