Skip to content

Commit

Permalink
Add docstring for ModelTrainingRecipe
Browse files Browse the repository at this point in the history
  • Loading branch information
pitt-liang committed Jun 20, 2024
1 parent 8c410cc commit 0ce3d0d
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 61 deletions.
170 changes: 110 additions & 60 deletions pai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2199,35 +2199,6 @@ def training_recipe(
class _ModelRecipe(_TrainingJobSubmitter):
MODEL_CHANNEL_NAME = "model"

@property
def algorithm_spec(self) -> AlgorithmSpec:
if self._algorithm_spec:
return self._algorithm_spec
if self.spec.algorithm_spec:
self._algorithm_spec = self.spec.algorithm_spec
else:
session = get_default_session()
algo = session.algorithm_api.get_by_name(
algorithm_name=self.spec.algorithm_name,
algorithm_provider=self.spec.algorithm_provider,
)
raw_algo_version_spec = session.algorithm_api.get_version(
algorithm_id=algo["AlgorithmId"],
algorithm_version=self.spec.algorithm_version,
)
self._algorithm_spec = AlgorithmSpec.model_validate(
raw_algo_version_spec["AlgorithmSpec"]
)
return self._algorithm_spec

@property
def input_channels(self) -> List[Channel]:
return self.algorithm_spec.input_channels

@property
def output_channels(self) -> List[Channel]:
return self.algorithm_spec.output_channels

@staticmethod
def _get_compute_resource_config(
instance_type: str,
Expand Down Expand Up @@ -2319,26 +2290,29 @@ class RecipeInitKwargs(object):
algorithm_spec: Optional[AlgorithmSpec]
input_channels: Optional[List[Channel]]
output_channels: Optional[List[Channel]]
example_train_inputs: Optional[Union[UriInput, DatasetConfig]]
default_training_inputs: Optional[Union[UriInput, DatasetConfig]]


class ModelTrainingRecipe(_ModelRecipe):
"""A recipe used to train a model."""

def __init__(
self,
model_name: Optional[str] = None,
model_version: Optional[str] = None,
model_provider: Optional[str] = None,
model_uri: Optional[str] = None,
training_method: Optional[str] = None,
source_dir: Optional[str] = None,
model_channel_name: Optional[str] = "model",
model_uri: Optional[str] = None,
hyperparameters: Optional[Dict[str, Any]] = None,
job_type: Optional[str] = None,
image_uri: Optional[str] = None,
command: Union[str, List[str]] = None,
instance_count: Optional[int] = None,
instance_type: Optional[str] = None,
instance_spec: Optional[InstanceSpec] = None,
resource_id: Optional[str] = None,
user_vpc_config: Optional[UserVpcConfig] = None,
labels: Optional[Dict[str, str]] = None,
requirements: Optional[List[str]] = None,
Expand All @@ -2350,6 +2324,57 @@ def __init__(
example_train_inputs: Optional[Union[UriInput, DatasetConfig]] = None,
base_job_name: Optional[str] = None,
):
"""Initialize a ModelTrainingRecipe object.
Args:
model_name (str, optional): The name of the registered model. Default to
None.
model_version (str, optional): The version of the registered model. Default
to None.
model_provider (str, optional): The provider of the registered model.
Optional values are "pai", "huggingface" or None. If None, list
registered models in the workspace of the current session. Default to
None.
training_method (str, optional): The training method used to select the
specific training recipe while the registered model contains multiple
model training specs. Default to None.
model_channel_name (str, optional): The name of the model channel. Default to
"model".
model_uri (str, optional): The URI of the input pretrained model. If the URI
is not provided, the model from the registered model will be used.
Default to None.
hyperparameters (dict, optional): A dictionary of hyperparameters used in
the training job. Default to None.
job_type (str, optional): The type of the job, supported values are "PyTorch",
"TfJob", "XGBoostJob" etc.
image_uri (str, optional): The URI of the Docker image. Default to None.
source_dir (str, optional): The source code using in the training job, which
is a directory containing the training script or an OSS URI. Default to
None.
command (str or list, optional): The command to execute in the training job.
Default to None.
requirements (list, optional): A list of Python requirements used to install
the dependencies in the training job. Default to None.
instance_count (int, optional): The number of instances to use for training.
Default to None.
instance_type (str, optional): The instance type to use for training. Default
to None.
instance_spec (:class:`pai.model.InstanceSpec`, optional): The resource config
for each instance of the training job. The dedicated resource group must
be provided when the instance spec is set. Default to None.
resource_id (str, optional): The ID of the resource group used to run the
training job. Default to None.
user_vpc_config (:class:`pai.model.UserVpcConfig`, optional): The VPC
configuration used to enable the job instance to connect to the
specified user VPC. Default to None.
environments (dict, optional): A dictionary of environment variables used in
the training job. Default to None.
experiment_config (:class:`pai.model.ExperimentConfig`, optional): The
experiment
labels (dict, optional): A dictionary of labels used to tag the training job.
Default to None.
"""
init_kwargs = self._init_kwargs(
model_name=model_name,
model_version=model_version,
Expand All @@ -2373,8 +2398,9 @@ def __init__(
experiment_config=experiment_config,
input_channels=input_channels,
output_channels=output_channels,
example_train_inputs=example_train_inputs,
default_training_inputs=example_train_inputs,
max_run_time=max_run_time,
resource_id=resource_id,
)
self.model_name = init_kwargs.model_name
self.model_version = init_kwargs.model_version
Expand All @@ -2389,33 +2415,18 @@ def __init__(
self.instance_count = init_kwargs.instance_count
self.instance_type = init_kwargs.instance_type
self.instance_spec = init_kwargs.instance_spec
self.resource_id = init_kwargs.resource_id
self.user_vpc_config = init_kwargs.user_vpc_config
self.labels = init_kwargs.labels
self.requirements = init_kwargs.requirements
self.environments = init_kwargs.environments
self.experiment_config = init_kwargs.experiment_config
self.base_job_name = init_kwargs.base_job_name
self.source_dir = init_kwargs.source_dir
self.resource_id = init_kwargs.resource_id
self.max_run_time = init_kwargs.max_run_time
self.default_inputs = init_kwargs.example_train_inputs
self.default_training_inputs = init_kwargs.default_training_inputs
super().__init__()

@staticmethod
def _model_info(
model_name: Optional[str] = None,
model_version: Optional[str] = None,
model_provider: Optional[str] = None,
) -> Optional[RegisteredModel]:
if not model_name:
return
m = RegisteredModel(
model_name=model_name,
model_version=model_version,
model_provider=model_provider,
)
return m

@classmethod
def _init_kwargs(
cls,
Expand All @@ -2442,7 +2453,7 @@ def _init_kwargs(
experiment_config: Optional[ExperimentConfig] = None,
input_channels: List[Channel] = None,
output_channels: List[Channel] = None,
example_train_inputs: Optional[Union[UriInput, DatasetConfig]] = None,
default_training_inputs: Optional[Union[UriInput, DatasetConfig]] = None,
base_job_name: Optional[str] = None,
) -> RecipeInitKwargs:
model = (
Expand Down Expand Up @@ -2485,7 +2496,7 @@ def _init_kwargs(
output_channels=output_channels,
resource_id=resource_id,
max_run_time=max_run_time,
example_train_inputs=example_train_inputs,
default_training_inputs=default_training_inputs,
)
if not model_uri:
input_ = next(
Expand All @@ -2507,8 +2518,8 @@ def _init_kwargs(
type(input_),
)

if not example_train_inputs:
example_train_inputs = [
if not default_training_inputs:
default_training_inputs = [
item
for item in model_training_spec.inputs
if item.name != model_channel_name
Expand Down Expand Up @@ -2587,7 +2598,7 @@ def _init_kwargs(
input_channels=input_channels,
output_channels=output_channels,
resource_id=resource_id,
example_train_inputs=example_train_inputs,
default_training_inputs=default_training_inputs,
)

def _build_algorithm_spec(
Expand All @@ -2612,12 +2623,28 @@ def _build_algorithm_spec(

def train(
self,
inputs: Optional[Dict[str, Any]] = None,
inputs: Optional[Dict[str, Union[str, DatasetConfig]]] = None,
wait: bool = True,
job_name: Optional[str] = None,
show_logs: bool = True,
) -> TrainingJob:
"""Start a training job with the given inputs."""
"""Start a training job with the given inputs.
Args:
inputs (Dict[str, Union[str, DatasetConfig]], optional): A dictionary of inputs
used in the training job. The keys are the channel name and the values are
the URIs of the input data. If not specified, the default inputs will be
used.
wait (bool): Whether to wait for the job to complete before returning. Default
to True.
job_name (str, optional): The name of the training job. If not provided, a default
job name will be generated.
show_logs (bool): Whether to show the logs of the training job. Default to True.
Returns:
:class:`pai.training.TrainingJob`: A submitted training job.
"""
job_name = self.job_name(job_name)

inputs = inputs or dict()
Expand All @@ -2631,7 +2658,7 @@ def train(
inputs[self.model_channel_name] = self.model_uri

if len(inputs.keys()) == 1 and self.model_channel_name in inputs:
default_inputs = self.default_inputs
default_inputs = self.default_training_inputs
else:
default_inputs = None

Expand Down Expand Up @@ -2701,15 +2728,38 @@ def model_data(self):
def deploy(
self,
service_name: str,
inference_spec: Optional[InferenceSpec] = None,
instance_type: Optional[str] = None,
instance_count: int = 1,
resource_config: Optional[Union[ResourceConfig, Dict[str, int]]] = None,
resource_id: str = None,
options: Optional[Dict[str, Any]] = None,
serializer: SerializerBase = None,
wait=True,
inference_spec: Optional[InferenceSpec] = None,
**kwargs,
) -> Predictor:
"""Deploy the training job output model as a online prediction service.
Args:
service_name (str): The name of the online prediction service.
instance_type (str, optional): The instance type used to run the service.
instance_count (int, optional): The number of instances used to run the
service. Default to 1.
resource_config (Union[ResourceConfig, Dict[str, int]], optional): The resource
config for the service. Default to None.
resource_id (str, optional): The ID of the resource group used to run the
service. Default to None.
options (Dict[str, Any], optional): The options used to deploy the service.
Default to None.
wait (bool, optional): Whether to wait for the service endpoint to be ready.
inference_spec (:class:`pai.model.InferenceSpec`, optional): The inference
spec used to deploy the service. If not provided, the `inference_spec` of
the model will be used. Default to None.
kwargs: Additional keyword arguments used to deploy the service.
Returns:
:class:`pai.predictor.Predictor`: A predictor object refers to the created
service.
"""
if not inference_spec and self.model_name:
model = RegisteredModel(
model_name=self.model_name,
Expand All @@ -2731,8 +2781,8 @@ def deploy(
instance_count=instance_count,
resource_config=resource_config,
resource_id=resource_id,
serializer=serializer,
options=options,
wait=wait,
**kwargs,
)
return p
2 changes: 1 addition & 1 deletion tests/integration/test_model/test_model_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_custom_inputs_train(self):
model = RegisteredModel(model_name="qwen1.5-0.5b-chat", model_provider="pai")
training_recipe = model.training_recipe(training_method="QLoRA_LLM")
self.assertTrue(
bool(training_recipe.default_inputs),
bool(training_recipe.default_training_inputs),
"Default inputs is empty for ModelTrainingRecipe.",
)

Expand Down

0 comments on commit 0ce3d0d

Please sign in to comment.