diff --git a/pai/model.py b/pai/model.py index 2eba4ce..2d0435c 100644 --- a/pai/model.py +++ b/pai/model.py @@ -23,13 +23,14 @@ import tempfile import textwrap import time +import typing from typing import Any, Dict, Iterator, List, Optional, Tuple, Union import requests from addict import Dict as AttrDict from oss2 import ObjectIterator -from .common import git_utils +from .common import ProviderAlibabaPAI, git_utils from .common.configs import UserVpcConfig from .common.consts import INSTANCE_TYPE_LOCAL_GPU, ModelFormat from .common.docker_utils import ContainerRun, run_container @@ -40,12 +41,15 @@ random_str, to_plain_text, ) -from .exception import DuplicatedMountException, MountPathIsOccupiedException +from .exception import DuplicatedMountException from .image import ImageInfo from .predictor import AsyncPredictor, LocalPredictor, Predictor, ServiceType from .serializers import SerializerBase from .session import Session, get_default_session +if typing.TYPE_CHECKING: + from pai.estimator import AlgorithmEstimator + logger = logging.getLogger(__name__) # Reserved ports for internal use, do not use them for service @@ -1438,6 +1442,7 @@ def __init__( self.model_name = self._model_info.get("ModelName") self.model_provider = self._model_info.get("Provider") self.task = self._model_info.get("Task") + self.domain = self._model_info.get("Domain") self.framework_type = self._model_version_info.get("FrameworkType") self.source_type = self._model_version_info.get("SourceType") self.source_id = self._model_version_info.get("SourceId") @@ -1814,7 +1819,7 @@ def get_estimator( output_path: Optional[str] = None, max_run_time: Optional[int] = None, **kwargs, - ): + ) -> "AlgorithmEstimator": """Generate an AlgorithmEstimator. Generate an AlgorithmEstimator object from RegisteredModel's training_spec. @@ -1894,6 +1899,20 @@ def get_estimator( ) instance_spec = instance_spec or train_compute_resource.get("InstanceSpec") + labels = kwargs.pop("labels", dict()) + if self.model_provider == ProviderAlibabaPAI: + default_labels = { + "BaseModelUri": self.uri, + "CreatedBy": "QuickStart", + "Domain": self.domain, + "RootModelID": self.model_id, + "RootModelName": self.model_name, + "RootModelVersion": self.model_version, + "Task": self.task, + } + default_labels.update(labels) + labels = default_labels + return AlgorithmEstimator( algorithm_name=algorithm_name, algorithm_version=algorithm_version, @@ -1906,6 +1925,7 @@ def get_estimator( instance_count=instance_count, instance_spec=instance_spec, output_path=output_path, + labels=labels, **kwargs, ) diff --git a/tests/integration/test_model.py b/tests/integration/test_model.py index 188d892..1b038d0 100644 --- a/tests/integration/test_model.py +++ b/tests/integration/test_model.py @@ -266,6 +266,21 @@ def test_builtin_algo_rm_train(self): ) est = m.get_estimator() + + self.assertEqual( + est.labels.get("BaseModelUri"), + m.uri, + ) + + self.assertEqual( + est.labels.get("RootModelName"), + m.model_name, + ) + self.assertEqual( + est.labels.get("RootModelID"), + m.model_id, + ) + inputs = m.get_estimator_inputs() est.hyperparameters["max_epochs"] = 5 est.hyperparameters["warmup_epochs"] = 2