Skip to content

Commit

Permalink
Add labels for training_job from QuickStart model
Browse files Browse the repository at this point in the history
  • Loading branch information
pitt-liang committed Apr 25, 2024
1 parent 02811e7 commit fa670c1
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 3 deletions.
26 changes: 23 additions & 3 deletions pai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@
import tempfile
import textwrap
import time
import typing
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

import requests
from addict import Dict as AttrDict
from oss2 import ObjectIterator

from .common import git_utils
from .common import ProviderAlibabaPAI, git_utils
from .common.configs import UserVpcConfig
from .common.consts import INSTANCE_TYPE_LOCAL_GPU, ModelFormat
from .common.docker_utils import ContainerRun, run_container
Expand All @@ -40,12 +41,15 @@
random_str,
to_plain_text,
)
from .exception import DuplicatedMountException, MountPathIsOccupiedException
from .exception import DuplicatedMountException
from .image import ImageInfo
from .predictor import AsyncPredictor, LocalPredictor, Predictor, ServiceType
from .serializers import SerializerBase
from .session import Session, get_default_session

if typing.TYPE_CHECKING:
from pai.estimator import AlgorithmEstimator

logger = logging.getLogger(__name__)

# Reserved ports for internal use, do not use them for service
Expand Down Expand Up @@ -1438,6 +1442,7 @@ def __init__(
self.model_name = self._model_info.get("ModelName")
self.model_provider = self._model_info.get("Provider")
self.task = self._model_info.get("Task")
self.domain = self._model_info.get("Domain")
self.framework_type = self._model_version_info.get("FrameworkType")
self.source_type = self._model_version_info.get("SourceType")
self.source_id = self._model_version_info.get("SourceId")
Expand Down Expand Up @@ -1814,7 +1819,7 @@ def get_estimator(
output_path: Optional[str] = None,
max_run_time: Optional[int] = None,
**kwargs,
):
) -> "AlgorithmEstimator":
"""Generate an AlgorithmEstimator.
Generate an AlgorithmEstimator object from RegisteredModel's training_spec.
Expand Down Expand Up @@ -1894,6 +1899,20 @@ def get_estimator(
)
instance_spec = instance_spec or train_compute_resource.get("InstanceSpec")

labels = kwargs.pop("labels", dict())
if self.model_provider == ProviderAlibabaPAI:
default_labels = {
"BaseModelUri": self.uri,
"CreatedBy": "QuickStart",
"Domain": self.domain,
"RootModelID": self.model_id,
"RootModelName": self.model_name,
"RootModelVersion": self.model_version,
"Task": self.task,
}
default_labels.update(labels)
labels = default_labels

return AlgorithmEstimator(
algorithm_name=algorithm_name,
algorithm_version=algorithm_version,
Expand All @@ -1906,6 +1925,7 @@ def get_estimator(
instance_count=instance_count,
instance_spec=instance_spec,
output_path=output_path,
labels=labels,
**kwargs,
)

Expand Down
15 changes: 15 additions & 0 deletions tests/integration/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,21 @@ def test_builtin_algo_rm_train(self):
)

est = m.get_estimator()

self.assertEqual(
est.labels.get("BaseModelUri"),
m.uri,
)

self.assertEqual(
est.labels.get("RootModelName"),
m.model_name,
)
self.assertEqual(
est.labels.get("RootModelID"),
m.model_id,
)

inputs = m.get_estimator_inputs()
est.hyperparameters["max_epochs"] = 5
est.hyperparameters["warmup_epochs"] = 2
Expand Down

0 comments on commit fa670c1

Please sign in to comment.