Skip to content

Commit

Permalink
fix: Hacky way to wait for the prediction service to be ready (#8)
Browse files Browse the repository at this point in the history
* fix: Hacky way to wait for the prediction service to be ready

* ci: Fix github workflow

* ci: Fix integration test.

---------

Signed-off-by: pitt-liang <[email protected]>
  • Loading branch information
pitt-liang authored Jan 2, 2024
1 parent 88d5b44 commit a91bc83
Show file tree
Hide file tree
Showing 11 changed files with 100 additions and 34 deletions.
7 changes: 3 additions & 4 deletions .github/workflows/integration.yaml
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
name: Integration test

on:
pull_request_review:
types: [submitted, edited]
push:
branches:
- master

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
integration-test:
# Integration test only runs when the Pull Request is approved or the source repository is a trusted repository.
if: github.event.review.state == 'APPROVED' || github.event.pull_request.head.repo.full_name == 'aliyun/pai-python-sdk'
runs-on: ubuntu-latest
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: Lint test

on: [push, pull_request]
on: [push]

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/unit.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: Unit test

on: [push, pull_request]
on: [push]

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
Expand Down
8 changes: 7 additions & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,20 @@ def integration(session: Session):
if os.environ.get(key, value) is not None
}

# set worker to 2 * cpu_count (physical cores) if not specified
if "-n" not in session.posargs and "--numprocesses" not in session.posargs:
pos_args = session.posargs + ["-n", str(os.cpu_count() * 2)]
else:
pos_args = session.posargs

session.run(
"pytest",
"--cov-config=.coveragerc",
"--cov-append",
"--cov-report=html",
"--cov=pai",
os.path.join("tests", "integration"),
*session.posargs,
*pos_args,
env=env,
)
session.run(
Expand Down
6 changes: 6 additions & 0 deletions pai/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,18 @@ def make_list_resource_iterator(method: Callable, **kwargs):
kwargs.update(page_number=page_number, page_size=page_size)
result = method(**kwargs)
if isinstance(result, PaginatedResult):
total_count = result.total_count
result = result.items
else:
total_count = None
for item in result:
yield item

if len(result) == 0 or len(result) < page_size:
return
if total_count and page_number * page_size >= total_count:
return

page_number += 1


Expand Down
33 changes: 23 additions & 10 deletions pai/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

from Tea.exceptions import TeaException

from .api.base import PaginatedResult
from .api.entity_base import EntityBaseMixin
from .common import ProviderAlibabaPAI, git_utils
from .common.consts import INSTANCE_TYPE_LOCAL_GPU, FileSystemInputScheme, JobType
Expand Down Expand Up @@ -1521,7 +1524,9 @@ def _normalize_name(name: str) -> str:
for name, value in self.estimator.hyperparameters.items():
env[_TrainingEnv.ENV_PAI_HPS_PREFIX + _normalize_name(name)] = str(value)
user_args.extend(["--" + name, shlex.quote(str(value))])
env[_TrainingEnv.ENV_PAI_USER_ARGS] = shlex.join(user_args)
env[_TrainingEnv.ENV_PAI_USER_ARGS] = " ".join(
[shlex.quote(v) for v in user_args]
)
env[_TrainingEnv.ENV_PAI_HPS] = json.dumps(
{name: str(value) for name, value in self.estimator.hyperparameters.items()}
)
Expand Down Expand Up @@ -1878,15 +1883,27 @@ def __init__(
self._future = None
self._stop = False

def _list_logs(self):
page_number, page_offset = 1, 0
# print training job logs.
while not self._stop:
def _list_logs_api(self, page_number: int = 1):
    """Fetch one page of training job logs, tolerating a not-yet-ready backend.

    Args:
        page_number (int): 1-based index of the log page to fetch.

    Returns:
        PaginatedResult: The requested page of logs; an empty result when the
            backend reports that the training job instance does not exist yet.
    """
    try:
        return self.session.training_job_api.list_logs(
            self.training_job_id,
            page_number=page_number,
            page_size=self.page_size,
        )
    except TeaException as e:
        # hack: Backend service may raise an exception when the training job
        # instance is not found; treat it as "no logs yet" so callers can poll.
        if e.code == "TRAINING_JOB_INSTANCE_NOT_FOUND":
            return PaginatedResult(items=[], total_count=0)
        # Bare `raise` re-raises with the original traceback intact
        # (`raise e` would obscure where the error actually occurred).
        raise

def _list_logs(self):
page_number, page_offset = 1, 0
# print training job logs.
while not self._stop:
res = self._list_logs_api(page_number=page_number)
# 1. move to next page
if len(res.items) == self.page_size:
# print new logs starting from page_offset
Expand All @@ -1904,11 +1921,7 @@ def _list_logs(self):
# When _stop is True, wait and print remaining logs.
time.sleep(10)
while True:
res = self.session.training_job_api.list_logs(
self.training_job_id,
page_number=page_number,
page_size=self.page_size,
)
res = self._list_logs_api(page_number=page_number)
# There maybe more logs in the next page
if len(res.items) == self.page_size:
self._print_logs(logs=res.items[page_offset:])
Expand Down
2 changes: 1 addition & 1 deletion pai/huggingface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def _get_supported_tf_versions_for_inference(self) -> List[str]:
return res

def _get_latest_tf_version_for_inference(self) -> str:
"""Return the latest Transformers version for inference."""
"""Return the latest transformers version for inference."""
res = self._get_supported_tf_versions_for_inference()
return max(
res,
Expand Down
15 changes: 10 additions & 5 deletions pai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,6 @@ def _deploy(
)
if wait:
predictor.wait_for_ready()
time.sleep(5)

return predictor

Expand Down Expand Up @@ -987,12 +986,18 @@ def _deploy_local(

# build command to install requirements
if requirements_list:
install_requirements = shlex.join(
["python", "-m", "pip", "install"] + requirements_list
install_requirements = " ".join(
[
shlex.quote(s)
for s in ["python", "-m", "pip", "install"] + requirements_list
]
)
elif requirements_path:
install_requirements = shlex.join(
["python", "-m", "pip", "install", "-r", requirements_path]
install_requirements = " ".join(
[
shlex.quote(s)
for s in ["python", "-m", "pip", "install", "-r", requirements_path]
]
)
else:
install_requirements = ""
Expand Down
52 changes: 47 additions & 5 deletions pai/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def completed_status(cls):


class EndpointType(object):

# Public Internet Endpoint
INTERNET = "INTERNET"

Expand All @@ -82,7 +81,6 @@ class EndpointType(object):


class ServiceType(object):

Standard = "Standard"
Async = "Async"

Expand Down Expand Up @@ -296,22 +294,66 @@ def delete_service(self):
"""Delete the service."""
self.session.service_api.delete(name=self.service_name)

def wait_for_ready(self, force: bool = False):
    """Block until the service reaches the Running status.

    Args:
        force (bool): Wait for readiness even if the service already
            reports Running.

    Raises:
        RuntimeError: Raise if the service terminated unexpectedly.
    """
    # Fast path: already running and the caller did not force a re-check.
    if self.service_status == ServiceStatus.Running and not force:
        return

    logger.info(f"Service waiting for ready: service_name={self.service_name}")

    # Any terminal status other than Running means the wait has failed.
    failure_statuses = ServiceStatus.completed_status()
    failure_statuses.remove(ServiceStatus.Running)

    type(self)._wait_for_status(
        service_name=self.service_name,
        status=ServiceStatus.Running,
        unexpected_status=failure_statuses,
        session=self.session,
    )

    # hack: PAI-EAS gateway may not be ready when the service is ready.
    self._wait_for_gateway_ready()
    self.refresh()

def _wait_for_gateway_ready(self, attempts: int = 30, interval: int = 2):
    """Hacky way to wait for the service gateway to be ready.

    Probes the service with lightweight GET requests until the gateway
    stops answering 503 "no healthy upstream".

    Args:
        attempts (int): Number of attempts to wait for the service gateway
            to be ready.
        interval (int): Interval between each attempt.
    """

    def _gateway_not_ready(response: requests.Response):
        # While warming up, the gateway answers 503 with this exact body.
        return (
            response.status_code == 503
            and response.content == b"no healthy upstream"
        )

    max_request_errors = 3
    request_errors = 0
    for _ in range(attempts):
        try:
            # Send a probe request to the service.
            response = self._send_request(method="GET")
            if not _gateway_not_ready(response):
                logger.info("Gateway for the service is ready.")
                break
        except requests.exceptions.RequestException as e:
            request_errors += 1
            # Give up probing after repeated transport-level failures.
            if request_errors >= max_request_errors:
                logger.warning("Failed to check gateway status: %s", e)
                break
        time.sleep(interval)
    else:
        logger.warning("Timeout waiting for gateway to be ready.")

@classmethod
def _wait_for_status(
cls,
Expand Down
5 changes: 0 additions & 5 deletions tests/integration/test_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import json
import os
import time

import numpy as np

Expand Down Expand Up @@ -85,8 +84,6 @@ def _init_predictor(cls):

p = Predictor(service_name=service_name)
p.wait_for_ready()
# hack: wait for service to be really ready
time.sleep(15)
return p

@classmethod
Expand Down Expand Up @@ -235,8 +232,6 @@ def _init_predictor(cls):

p = Predictor(service_name=service_name)
p.wait_for_ready()
# hack: wait for service to be really ready
time.sleep(15)
return p

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def is_inner(self):

@classmethod
def _load_test_config(cls):
test_config = os.environ.get("PAI_TEST_CONFIG", "test_public.ini")
test_config = os.environ.get("PAI_TEST_CONFIG", "test.ini")
cfg_parser = configparser.ConfigParser()
cfg_parser.read(os.path.join(_test_root, test_config))

Expand Down

0 comments on commit a91bc83

Please sign in to comment.