@@ -186,6 +186,7 @@ def __init__(
186186 enable_remote_debug : Optional [Union [bool , PipelineVariable ]] = None ,
187187 enable_session_tag_chaining : Optional [Union [bool , PipelineVariable ]] = None ,
188188 training_plan : Optional [Union [str , PipelineVariable ]] = None ,
189+ instance_placement_config : Optional [Dict ] = None ,
189190 ** kwargs ,
190191 ):
191192 """Initialize an ``EstimatorBase`` instance.
@@ -560,6 +561,21 @@ def __init__(
560561 Specifies whether SessionTagChaining is enabled for the training job.
561562 training_plan (str or PipelineVariable): Optional.
562563 Specifies which training plan arn to use for the training job
564+ instance_placement_config (dict): Optional.
565+ Specifies UltraServer placement configuration for the training job
566+
567+ .. code:: python
568+
569+ instance_placement_config={
570+ "EnableMultipleJobs": True,
571+ "PlacementSpecifications":[
572+ {
573+ "UltraServerId": "ultraserver-1",
574+ "InstanceCount": "2"
575+ }
576+ ]
577+ }
578+
563579 """
564580 instance_count = renamed_kwargs (
565581 "train_instance_count" , "instance_count" , instance_count , kwargs
@@ -813,6 +829,8 @@ def __init__(
813829
814830 self .training_plan = training_plan
815831
832+ self .instance_placement_config = instance_placement_config
833+
816834 # Internal flag
817835 self ._is_output_path_set_from_default_bucket_and_prefix = False
818836
@@ -1997,6 +2015,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
19972015 if "TrainingPlanArn" in job_details ["ResourceConfig" ]:
19982016 init_params ["training_plan" ] = job_details ["ResourceConfig" ]["TrainingPlanArn" ]
19992017
2018+ if "InstancePlacementConfig" in job_details ["ResourceConfig" ]:
2019+ init_params ["instance_placement_config" ] = job_details ["ResourceConfig" ][
2020+ "InstancePlacementConfig"
2021+ ]
2022+
20002023 has_hps = "HyperParameters" in job_details
20012024 init_params ["hyperparameters" ] = job_details ["HyperParameters" ] if has_hps else {}
20022025
@@ -2882,6 +2905,7 @@ def __init__(
28822905 enable_remote_debug : Optional [Union [bool , PipelineVariable ]] = None ,
28832906 enable_session_tag_chaining : Optional [Union [bool , PipelineVariable ]] = None ,
28842907 training_plan : Optional [Union [str , PipelineVariable ]] = None ,
2908+ instance_placement_config : Optional [Dict ] = None ,
28852909 ** kwargs ,
28862910 ):
28872911 """Initialize an ``Estimator`` instance.
@@ -3249,6 +3273,20 @@ def __init__(
32493273 Specifies whether SessionTagChaining is enabled for the training job
32503274 training_plan (str or PipelineVariable): Optional.
32513275 Specifies which training plan arn to use for the training job
3276+ instance_placement_config (dict): Optional.
3277+ Specifies UltraServer placement configuration for the training job
3278+
3279+ .. code:: python
3280+
3281+ instance_placement_config={
3282+ "EnableMultipleJobs": True,
3283+ "PlacementSpecifications":[
3284+ {
3285+ "UltraServerId": "ultraserver-1",
3286+ "InstanceCount": "2"
3287+ }
3288+ ]
3289+ }
32523290 """
32533291 self .image_uri = image_uri
32543292 self ._hyperparameters = hyperparameters .copy () if hyperparameters else {}
@@ -3303,6 +3341,7 @@ def __init__(
33033341 enable_remote_debug = enable_remote_debug ,
33043342 enable_session_tag_chaining = enable_session_tag_chaining ,
33053343 training_plan = training_plan ,
3344+ instance_placement_config = instance_placement_config ,
33063345 ** kwargs ,
33073346 )
33083347
0 commit comments