diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index 6563997..e70a926 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -1,8 +1,9 @@ name: Integration test on: - pull_request_review: - types: [submitted, edited] + push: + branches: + - master concurrency: group: ${{ github.workflow }}-${{ github.ref }} @@ -10,8 +11,6 @@ concurrency: jobs: integration-test: - # Integration test only works while the Pull Request is approved or the source repository is trusted repository. - if: github.event.review.state == 'APPROVED' || github.event.pull_request.head.repo.full_name == 'aliyun/pai-python-sdk' runs-on: ubuntu-latest env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 05c5820..1acf4fe 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -1,6 +1,6 @@ name: Lint test -on: [push, pull_request] +on: [push] concurrency: group: ${{ github.workflow }}-${{ github.ref }} diff --git a/.github/workflows/unit.yaml b/.github/workflows/unit.yaml index aa087d4..de74d41 100644 --- a/.github/workflows/unit.yaml +++ b/.github/workflows/unit.yaml @@ -1,6 +1,6 @@ name: Unit test -on: [push, pull_request] +on: [push] concurrency: group: ${{ github.workflow }}-${{ github.ref }} diff --git a/noxfile.py b/noxfile.py index 244846e..68d4a0d 100644 --- a/noxfile.py +++ b/noxfile.py @@ -50,6 +50,12 @@ def integration(session: Session): if os.environ.get(key, value) is not None } + # set worker to 2 * cpu_count (physical cores) if not specified + if "-n" not in session.posargs and "--numprocesses" not in session.posargs: + pos_args = session.posargs + ["-n", str(os.cpu_count() * 2)] + else: + pos_args = session.posargs + session.run( "pytest", "--cov-config=.coveragerc", @@ -57,7 +63,7 @@ def integration(session: Session): "--cov-report=html", "--cov=pai", os.path.join("tests", "integration"), - *session.posargs, + *pos_args, env=env, ) session.run( diff --git a/pai/common/utils.py b/pai/common/utils.py index 2e3c654..45bb80d 100644 --- a/pai/common/utils.py +++ b/pai/common/utils.py @@ -88,12 +88,18 @@ def make_list_resource_iterator(method: Callable, **kwargs): kwargs.update(page_number=page_number, page_size=page_size) result = method(**kwargs) if isinstance(result, PaginatedResult): + total_count = result.total_count result = result.items + else: + total_count = None for item in result: yield item if len(result) == 0 or len(result) < page_size: return + if total_count and page_number * page_size >= total_count: + return + page_number += 1 diff --git a/pai/estimator.py b/pai/estimator.py index 5441692..da85be6 100644 --- a/pai/estimator.py +++ b/pai/estimator.py @@ -29,6 +29,9 @@ from datetime import datetime from typing import Any, Dict, List, Optional, Union +from Tea.exceptions import TeaException + +from .api.base import PaginatedResult from .api.entity_base import EntityBaseMixin from .common import ProviderAlibabaPAI, git_utils from .common.consts import INSTANCE_TYPE_LOCAL_GPU, FileSystemInputScheme, JobType @@ -1521,7 +1524,9 @@ def _normalize_name(name: str) -> str: for name, value in self.estimator.hyperparameters.items(): env[_TrainingEnv.ENV_PAI_HPS_PREFIX + _normalize_name(name)] = str(value) user_args.extend(["--" + name, shlex.quote(str(value))]) - env[_TrainingEnv.ENV_PAI_USER_ARGS] = shlex.join(user_args) + env[_TrainingEnv.ENV_PAI_USER_ARGS] = " ".join( + [shlex.quote(v) for v in user_args] + ) env[_TrainingEnv.ENV_PAI_HPS] = json.dumps( {name: str(value) for name, value in self.estimator.hyperparameters.items()} ) @@ -1878,15 +1883,27 @@ def __init__( self._future = None self._stop = False - def _list_logs(self): - page_number, page_offset = 1, 0 - # print training job logs. - while not self._stop: + def _list_logs_api(self, page_number: int = 1): + try: res = self.session.training_job_api.list_logs( self.training_job_id, page_number=page_number, page_size=self.page_size, ) + return res + except TeaException as e: + # hack: Backend service may raise an exception when the training job + # instance is not found. + if e.code == "TRAINING_JOB_INSTANCE_NOT_FOUND": + return PaginatedResult(items=[], total_count=0) + else: + raise e + + def _list_logs(self): + page_number, page_offset = 1, 0 + # print training job logs. + while not self._stop: + res = self._list_logs_api(page_number=page_number) # 1. move to next page if len(res.items) == self.page_size: # print new logs starting from page_offset @@ -1904,11 +1921,7 @@ def _list_logs(self): # When _stop is True, wait and print remaining logs. time.sleep(10) while True: - res = self.session.training_job_api.list_logs( - self.training_job_id, - page_number=page_number, - page_size=self.page_size, - ) + res = self._list_logs_api(page_number=page_number) # There maybe more logs in the next page if len(res.items) == self.page_size: self._print_logs(logs=res.items[page_offset:]) diff --git a/pai/huggingface/model.py b/pai/huggingface/model.py index 7b46b7a..1771470 100644 --- a/pai/huggingface/model.py +++ b/pai/huggingface/model.py @@ -259,7 +259,7 @@ def _get_supported_tf_versions_for_inference(self) -> List[str]: return res def _get_latest_tf_version_for_inference(self) -> str: - """Return the latest Transformers version for inference.""" + """Return the latest transformers version for inference.""" res = self._get_supported_tf_versions_for_inference() return max( res, diff --git a/pai/model.py b/pai/model.py index b631231..609c354 100644 --- a/pai/model.py +++ b/pai/model.py @@ -816,7 +816,6 @@ def _deploy( ) if wait: predictor.wait_for_ready() - time.sleep(5) return predictor @@ -987,12 +986,18 @@ def _deploy_local( # build command to install requirements if requirements_list: - install_requirements = shlex.join( - ["python", "-m", "pip", "install"] + requirements_list + install_requirements = " ".join( + [ + shlex.quote(s) + for s in ["python", "-m", "pip", "install"] + requirements_list + ] ) elif requirements_path: - install_requirements = shlex.join( - ["python", "-m", "pip", "install", "-r", requirements_path] + install_requirements = " ".join( + [ + shlex.quote(s) + for s in ["python", "-m", "pip", "install", "-r", requirements_path] + ] ) else: install_requirements = "" diff --git a/pai/predictor.py b/pai/predictor.py index fdd06e5..fc2cb88 100644 --- a/pai/predictor.py +++ b/pai/predictor.py @@ -73,7 +73,6 @@ def completed_status(cls): class EndpointType(object): - # Public Internet Endpoint INTERNET = "INTERNET" @@ -82,7 +81,6 @@ class EndpointType(object): class ServiceType(object): - Standard = "Standard" Async = "Async" @@ -296,22 +294,66 @@ def delete_service(self): """Delete the service.""" self.session.service_api.delete(name=self.service_name) - def wait_for_ready(self): - """Wait until the service enter running status.""" + def wait_for_ready(self, force: bool = False): + """Wait until the service enter running status. + + Args: + force (bool): Whether to force wait for ready. + + Raises: + RuntimeError: Raise if the service terminated unexpectedly. + + """ + if self.service_status == ServiceStatus.Running and not force: + return + logger.info( "Service waiting for ready: service_name={}".format(self.service_name) ) unexpected_status = ServiceStatus.completed_status() unexpected_status.remove(ServiceStatus.Running) - type(self)._wait_for_status( service_name=self.service_name, status=ServiceStatus.Running, unexpected_status=unexpected_status, session=self.session, ) + + # hack: PAI-EAS gateway may not be ready when the service is ready. + self._wait_for_gateway_ready() self.refresh() + def _wait_for_gateway_ready(self, attempts: int = 30, interval: int = 2): + """Hacky way to wait for the service gateway to be ready. + + Args: + attempts (int): Number of attempts to wait for the service gateway to be + ready. + interval (int): Interval between each attempt. + """ + + def _is_gateway_not_ready(resp: requests.Response): + return resp.status_code == 503 and resp.content == b"no healthy upstream" + + err_count_threshold = 3 + err_count = 0 + while attempts > 0: + attempts -= 1 + try: + # Send a probe request to the service. + resp = self._send_request(method="GET") + if not _is_gateway_not_ready(resp): + logger.info("Gateway for the service is ready.") + break + except requests.exceptions.RequestException as e: + err_count += 1 + if err_count >= err_count_threshold: + logger.warning("Failed to check gateway status: %s", e) + break + time.sleep(interval) + else: + logger.warning("Timeout waiting for gateway to be ready.") + @classmethod def _wait_for_status( cls, diff --git a/tests/integration/test_predictor.py b/tests/integration/test_predictor.py index 10278fc..f7b6b5a 100644 --- a/tests/integration/test_predictor.py +++ b/tests/integration/test_predictor.py @@ -14,7 +14,6 @@ import json import os -import time import numpy as np @@ -85,8 +84,6 @@ def _init_predictor(cls): p = Predictor(service_name=service_name) p.wait_for_ready() - # hack: wait for service to be really ready - time.sleep(15) return p @classmethod @@ -235,8 +232,6 @@ def _init_predictor(cls): p = Predictor(service_name=service_name) p.wait_for_ready() - # hack: wait for service to be really ready - time.sleep(15) return p @classmethod diff --git a/tests/integration/utils.py b/tests/integration/utils.py index bc19f51..dbca156 100644 --- a/tests/integration/utils.py +++ b/tests/integration/utils.py @@ -97,7 +97,7 @@ def is_inner(self): @classmethod def _load_test_config(cls): - test_config = os.environ.get("PAI_TEST_CONFIG", "test_public.ini") + test_config = os.environ.get("PAI_TEST_CONFIG", "test.ini") cfg_parser = configparser.ConfigParser() cfg_parser.read(os.path.join(_test_root, test_config))