1414from __future__ import absolute_import
1515
1616from typing import Optional
17- from sagemaker import image_uris
1817from sagemaker .jumpstart .constants import (
1918 DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
2019)
2120from sagemaker .jumpstart .enums import (
2221 JumpStartModelType ,
2322 JumpStartScriptScope ,
24- ModelFramework ,
2523)
2624from sagemaker .jumpstart .utils import (
2725 get_region_fallback ,
@@ -35,16 +33,8 @@ def _retrieve_image_uri(
3533 model_version : str ,
3634 image_scope : str ,
3735 hub_arn : Optional [str ] = None ,
38- framework : Optional [str ] = None ,
3936 region : Optional [str ] = None ,
40- version : Optional [str ] = None ,
41- py_version : Optional [str ] = None ,
4237 instance_type : Optional [str ] = None ,
43- accelerator_type : Optional [str ] = None ,
44- container_version : Optional [str ] = None ,
45- distribution : Optional [str ] = None ,
46- base_framework_version : Optional [str ] = None ,
47- training_compiler_config : Optional [str ] = None ,
4838 tolerate_vulnerable_model : bool = False ,
4939 tolerate_deprecated_model : bool = False ,
5040 sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
@@ -66,30 +56,11 @@ def _retrieve_image_uri(
6656 image_scope (str): The image type, i.e. what it is used for.
6757 Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
6858 ``image_scope`` is ignored.
69- framework (str): The name of the framework or algorithm.
7059 region (str): The AWS region. (Default: None).
71- version (str): The framework or algorithm version. This is required if there is
72- more than one supported version for the given framework or algorithm.
73- (Default: None).
74- py_version (str): The Python version. This is required if there is
75- more than one supported Python version for the given framework version.
7660 instance_type (str): The SageMaker instance type. For supported types, see
7761 https://aws.amazon.com/sagemaker/pricing/instance-types. This is required if
7862 there are different images for different processor types.
7963 (Default: None).
80- accelerator_type (str): Elastic Inference accelerator type. For more, see
81- https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
82- (Default: None).
83- container_version (str): the version of docker image.
84- Ideally the value of parameter should be created inside the framework.
85- For custom use, see the list of supported container versions:
86- https://github.com/aws/deep-learning-containers/blob/master/available_images.md.
87- (Default: None).
88- distribution (dict): A dictionary with information on how to run distributed training.
89- (Default: None).
90- training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
91- A configuration class for the SageMaker Training Compiler.
92- (Default: None).
9364 tolerate_vulnerable_model (bool): True if vulnerable versions of model
9465 specifications should be tolerated (exception not raised). If False, raises an
9566 exception if the script used by this version of the model has dependencies with known
@@ -142,14 +113,12 @@ def _retrieve_image_uri(
142113 ecr_uri = model_specs .hosting_ecr_uri
143114 return ecr_uri
144115
145- ecr_specs = model_specs .hosting_ecr_specs
146- if ecr_specs is None :
147- raise ValueError (
148- f"No inference ECR configuration found for JumpStart model ID '{ model_id } ' "
149- f"with { instance_type } instance type in { region } . "
150- "Please try another instance type or region."
151- )
152- elif image_scope == JumpStartScriptScope .TRAINING :
116+ raise ValueError (
117+ f"No inference ECR configuration found for JumpStart model ID '{ model_id } ' "
118+ f"with { instance_type } instance type in { region } . "
119+ "Please try another instance type or region."
120+ )
121+ if image_scope == JumpStartScriptScope .TRAINING :
153122 training_instance_type_variants = model_specs .training_instance_type_variants
154123 if training_instance_type_variants :
155124 image_uri = training_instance_type_variants .get_image_uri (
@@ -161,65 +130,10 @@ def _retrieve_image_uri(
161130 ecr_uri = model_specs .training_ecr_uri
162131 return ecr_uri
163132
164- ecr_specs = model_specs .training_ecr_specs
165- if ecr_specs is None :
166- raise ValueError (
167- f"No training ECR configuration found for JumpStart model ID '{ model_id } ' "
168- f"with { instance_type } instance type in { region } . "
169- "Please try another instance type or region."
170- )
171- if framework is not None and framework != ecr_specs .framework :
172- raise ValueError (
173- f"Incorrect container framework '{ framework } ' for JumpStart model ID '{ model_id } ' "
174- f"and version '{ model_version } '."
175- )
176-
177- if version is not None and version != ecr_specs .framework_version :
178- raise ValueError (
179- f"Incorrect container framework version '{ version } ' for JumpStart model ID "
180- f"'{ model_id } ' and version '{ model_version } '."
181- )
182-
183- if py_version is not None and py_version != ecr_specs .py_version :
184133 raise ValueError (
185- f"Incorrect python version '{ py_version } ' for JumpStart model ID '{ model_id } ' "
186- f"and version '{ model_version } '."
187- )
188-
189- base_framework_version_override : Optional [str ] = None
190- version_override : Optional [str ] = None
191- if ecr_specs .framework == ModelFramework .HUGGINGFACE :
192- base_framework_version_override = ecr_specs .framework_version
193- version_override = ecr_specs .huggingface_transformers_version
194-
195- if image_scope == JumpStartScriptScope .TRAINING :
196- return image_uris .get_training_image_uri (
197- region = region ,
198- framework = ecr_specs .framework ,
199- framework_version = version_override or ecr_specs .framework_version ,
200- py_version = ecr_specs .py_version ,
201- image_uri = None ,
202- distribution = None ,
203- compiler_config = None ,
204- tensorflow_version = None ,
205- pytorch_version = base_framework_version_override or base_framework_version ,
206- instance_type = instance_type ,
134+ f"No training ECR configuration found for JumpStart model ID '{ model_id } ' "
135+ f"with { instance_type } instance type in { region } . "
136+ "Please try another instance type or region."
207137 )
208- if base_framework_version_override is not None :
209- base_framework_version_override = f"pytorch{ base_framework_version_override } "
210138
211- return image_uris .retrieve (
212- framework = ecr_specs .framework ,
213- region = region ,
214- version = version_override or ecr_specs .framework_version ,
215- py_version = ecr_specs .py_version ,
216- instance_type = instance_type ,
217- hub_arn = hub_arn ,
218- accelerator_type = accelerator_type ,
219- image_scope = image_scope ,
220- container_version = container_version ,
221- distribution = distribution ,
222- base_framework_version = base_framework_version_override or base_framework_version ,
223- training_compiler_config = training_compiler_config ,
224- config_name = config_name ,
225- )
139+ raise ValueError (f"Invalid scope: { image_scope } " )
0 commit comments