Skip to content

Commit

Permalink
ci: Fix integration test.
Browse files Browse the repository at this point in the history
Signed-off-by: pitt-liang <[email protected]>
  • Loading branch information
pitt-liang committed Jan 1, 2024
1 parent 343694a commit c62c07c
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 20 deletions.
7 changes: 3 additions & 4 deletions .github/workflows/integration.yaml
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
name: Integration test

on:
pull_request_review:
types: [submitted, edited]
push:
branches:
- master

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
integration-test:
# Integration test only works while the Pull Request is approved or the source repository is trusted repository.
if: github.event.review.state == 'APPROVED' || github.event.pull_request.head.repo.full_name == 'aliyun/pai-python-sdk'
runs-on: ubuntu-latest
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
Expand Down
8 changes: 7 additions & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,20 @@ def integration(session: Session):
if os.environ.get(key, value) is not None
}

# set worker to 2 * cpu_count (physical cores) if not specified
if "-n" not in session.posargs and "--numprocesses" not in session.posargs:
pos_args = session.posargs + ["-n", str(os.cpu_count() * 2)]
else:
pos_args = session.posargs

session.run(
"pytest",
"--cov-config=.coveragerc",
"--cov-append",
"--cov-report=html",
"--cov=pai",
os.path.join("tests", "integration"),
*session.posargs,
*pos_args,
env=env,
)
session.run(
Expand Down
33 changes: 23 additions & 10 deletions pai/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

from Tea.exceptions import TeaException

from .api.base import PaginatedResult
from .api.entity_base import EntityBaseMixin
from .common import ProviderAlibabaPAI, git_utils
from .common.consts import INSTANCE_TYPE_LOCAL_GPU, FileSystemInputScheme, JobType
Expand Down Expand Up @@ -1521,7 +1524,9 @@ def _normalize_name(name: str) -> str:
for name, value in self.estimator.hyperparameters.items():
env[_TrainingEnv.ENV_PAI_HPS_PREFIX + _normalize_name(name)] = str(value)
user_args.extend(["--" + name, shlex.quote(str(value))])
env[_TrainingEnv.ENV_PAI_USER_ARGS] = shlex.join(user_args)
env[_TrainingEnv.ENV_PAI_USER_ARGS] = " ".join(
[shlex.quote(v) for v in user_args]
)
env[_TrainingEnv.ENV_PAI_HPS] = json.dumps(
{name: str(value) for name, value in self.estimator.hyperparameters.items()}
)
Expand Down Expand Up @@ -1878,15 +1883,27 @@ def __init__(
self._future = None
self._stop = False

def _list_logs(self):
page_number, page_offset = 1, 0
# print training job logs.
while not self._stop:
def _list_logs_api(self, page_number=1):
try:
res = self.session.training_job_api.list_logs(
self.training_job_id,
page_number=page_number,
page_size=self.page_size,
)
return res
except TeaException as e:
# hack: Backend service may raise an exception when the training job
# instance is not found.
if e.code == "TRAINING_JOB_INSTANCE_NOT_FOUND":
return PaginatedResult(items=[], total_count=0)
else:
raise e

def _list_logs(self):
page_number, page_offset = 1, 0
# print training job logs.
while not self._stop:
res = self._list_logs_api(page_number=page_number)
# 1. move to next page
if len(res.items) == self.page_size:
# print new logs starting from page_offset
Expand All @@ -1904,11 +1921,7 @@ def _list_logs(self):
# When _stop is True, wait and print remaining logs.
time.sleep(10)
while True:
res = self.session.training_job_api.list_logs(
self.training_job_id,
page_number=page_number,
page_size=self.page_size,
)
res = self._list_logs_api(page_number=page_number)
# There maybe more logs in the next page
if len(res.items) == self.page_size:
self._print_logs(logs=res.items[page_offset:])
Expand Down
2 changes: 1 addition & 1 deletion pai/huggingface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def _get_supported_tf_versions_for_inference(self) -> List[str]:
return res

def _get_latest_tf_version_for_inference(self) -> str:
"""Return the latest Transformers version for inference."""
"""Return the latest transformers version for inference."""
res = self._get_supported_tf_versions_for_inference()
return max(
res,
Expand Down
14 changes: 10 additions & 4 deletions pai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,12 +987,18 @@ def _deploy_local(

# build command to install requirements
if requirements_list:
install_requirements = shlex.join(
["python", "-m", "pip", "install"] + requirements_list
install_requirements = " ".join(
[
shlex.quote(s)
for s in ["python", "-m", "pip", "install"] + requirements_list
]
)
elif requirements_path:
install_requirements = shlex.join(
["python", "-m", "pip", "install", "-r", requirements_path]
install_requirements = " ".join(
[
shlex.quote(s)
for s in ["python", "-m", "pip", "install", "-r", requirements_path]
]
)
else:
install_requirements = ""
Expand Down

0 comments on commit c62c07c

Please sign in to comment.