Skip to content

Commit

Permalink
add labels for service deploy via quickstart model (#16)
Browse files Browse the repository at this point in the history
* add labels for service deploy via QuickStart model

* fix base_url for openai client generated via predictor

* fix: fix workflow step name

* remove integration tests from the GitHub workflow

* fix missing labels in AlgorithmEstimator.fit
  • Loading branch information
pitt-liang authored Apr 29, 2024
1 parent b142baa commit f2618d4
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 41 deletions.
29 changes: 0 additions & 29 deletions .github/workflows/integration.yaml

This file was deleted.

2 changes: 1 addition & 1 deletion .github/workflows/unit.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ concurrency:
cancel-in-progress: true

jobs:
lint:
unit-test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def install_test_dependencies(session: Session):
session.install("-r", TEST_REQUIREMENTS)


@nox.session(venv_backend=TEST_VENV_BACKEND, python=INTEGRATION_TEST_PYTHON_VERSIONS)
@nox.session(venv_backend=TEST_VENV_BACKEND)
def integration(session: Session):
"""Run integration test."""
install_test_dependencies(session=session)
Expand Down
2 changes: 2 additions & 0 deletions pai/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,7 @@ def __init__(
instance_count=instance_count,
user_vpc_config=user_vpc_config,
experiment_config=experiment_config,
resource_id=resource_id,
session=session,
)

Expand Down Expand Up @@ -1437,6 +1438,7 @@ def _fit(self, job_name, inputs: Dict[str, Any] = None):
experiment_config=self.experiment_config.to_dict()
if self.experiment_config
else None,
labels=self.labels,
)
training_job = _TrainingJob.get(training_job_id)
print(
Expand Down
21 changes: 20 additions & 1 deletion pai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@
# Reserved ports for internal use, do not use them for service
_RESERVED_PORTS = [8080, 9090]

# Label value marking resources created via the QuickStart model
MODEL_TASK_CREATED_BY_QUICKSTART = "QuickStart"


class DefaultServiceConfig(object):
"""Default configuration used in creating prediction service."""
Expand Down Expand Up @@ -851,6 +854,7 @@ def _deploy(
wait: bool = True,
serializer: "SerializerBase" = None,
labels: Optional[Dict[str, str]] = None,
**kwargs,
):
"""Create a prediction service."""
if not service_name:
Expand Down Expand Up @@ -1723,6 +1727,20 @@ def deploy(
if not self.inference_spec:
raise RuntimeError("No inference_spec for the registered model.")

labels = kwargs.pop("labels", dict())
if self.model_provider == ProviderAlibabaPAI:
default_labels = {
"Task": self.task,
"RootModelName": self.model_name,
"RootModelVersion": self.model_version,
"RootModelID": self.model_id,
"Domain": self.domain,
"CreatedBy": MODEL_TASK_CREATED_BY_QUICKSTART,
"BaseModelUri": self.uri,
}
default_labels.update(labels)
labels = default_labels

if is_local_run_instance_type(instance_type):
return self._deploy_local(
instance_type=instance_type,
Expand All @@ -1740,6 +1758,7 @@ def deploy(
options=options,
wait=wait,
serializer=serializer,
labels=labels,
**kwargs,
)

Expand Down Expand Up @@ -1906,7 +1925,7 @@ def get_estimator(
if self.model_provider == ProviderAlibabaPAI:
default_labels = {
"BaseModelUri": self.uri,
"CreatedBy": "QuickStart",
"CreatedBy": MODEL_TASK_CREATED_BY_QUICKSTART,
"Domain": self.domain,
"RootModelID": self.model_id,
"RootModelName": self.model_name,
Expand Down
30 changes: 23 additions & 7 deletions pai/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,19 @@ def service_status(self):
return self._service_api_object["Status"]

@property
def access_token(self):
def access_token(self) -> str:
"""Access token of the service."""
return self._service_api_object["AccessToken"]

@property
def labels(self) -> Dict[str, str]:
    """Return the service labels as a plain key/value mapping.

    The service API object stores labels as a list of
    ``{"LabelKey": ..., "LabelValue": ...}`` items; this flattens them
    into a ``dict``. Missing/absent labels yield an empty dict.
    """
    result: Dict[str, str] = {}
    for entry in self._service_api_object.get("Labels", []):
        result[entry["LabelKey"]] = entry["LabelValue"]
    return result

@property
def console_uri(self):
"""Returns the console URI of the service."""
Expand Down Expand Up @@ -298,17 +307,14 @@ def delete_service(self):
"""Delete the service."""
self.session.service_api.delete(name=self.service_name)

def wait_for_ready(self, force: bool = False):
def wait_for_ready(self):
"""Wait until the service enter running status.
Args:
force (bool): Whether to force wait for ready.
Raises:
RuntimeError: Raise if the service terminated unexpectedly.
"""
if self.service_status == ServiceStatus.Running and not force:
if self.service_status == ServiceStatus.Running:
return

logger.info(
Expand All @@ -327,6 +333,10 @@ def wait_for_ready(self, force: bool = False):
self._wait_for_gateway_ready()
self.refresh()

def wait(self):
    """Alias for :meth:`wait_for_ready`.

    Blocks until the remote service is running, propagating any
    RuntimeError raised when the service terminates unexpectedly.
    """
    return self.wait_for_ready()

def _wait_for_gateway_ready(self, attempts: int = 60, interval: int = 2):
"""Hacky way to wait for the service gateway to be ready.
Expand All @@ -337,6 +347,8 @@ def _wait_for_gateway_ready(self, attempts: int = 60, interval: int = 2):
"""

def _is_gateway_ready():
# can't use HEAD method to check gateway status because the service will
# block the request until timeout.
resp = self._send_request(method="GET")
res = not (
# following status code and content indicates the gateway is not ready
Expand Down Expand Up @@ -730,7 +742,8 @@ def openai(self, **kwargs) -> "OpenAI":
raise ImportError(
"openai package is not installed, install it with `pip install openai`."
)
base_url = kwargs.pop("base_url", self.endpoint + "/v1/")

base_url = kwargs.pop("base_url", posixpath.join(self.endpoint, "v1/"))
api_key = kwargs.pop("api_key", self.access_token)

return OpenAI(base_url=base_url, api_key=api_key, **kwargs)
Expand Down Expand Up @@ -1378,6 +1391,9 @@ def wait_for_ready(self):
self._wait_local_server_ready()
time.sleep(5)

def wait(self):
    # Alias for wait_for_ready(), kept for API symmetry with the
    # remote predictor's wait() method.
    """Block until the local prediction service is ready."""
    return self.wait_for_ready()

def _wait_local_server_ready(
self,
interval: int = 5,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ src_paths = ["pai", "tests"]
#known_first_party = ["pai", "tests"]

[tool.pytest.ini_options]
timeout = 300
timeout = 600

[doc8]
max-line-length=88
Expand Down
8 changes: 7 additions & 1 deletion tests/integration/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,14 @@ def test_rm_deploy(self):
)

p = m.deploy()

self.predictors.append(p)

self.assertEqual(p.labels.get("RootModelID"), m.model_id)
self.assertEqual(p.labels.get("RootModelName"), m.model_name)
self.assertEqual(p.labels.get("RootModelVersion"), m.model_version)
self.assertEqual(p.labels.get("BaseModelUri"), m.uri)
self.assertEqual(p.labels.get("Task"), m.task)
self.assertEqual(p.labels.get("Domain"), m.domain)
self.assertTrue(p.service_name)
res = p.predict(["开心", "死亡"])
self.assertTrue(isinstance(res, list))
Expand Down

0 comments on commit f2618d4

Please sign in to comment.