# runtime script names
BOOTSTRAP_SCRIPT_NAME = "bootstrap_runtime_environment.py"
+ MPI_UTILS_SCRIPT_NAME = "mpi_utils_remote.py"
ENTRYPOINT_SCRIPT_NAME = "job_driver.sh"
PRE_EXECUTION_SCRIPT_NAME = "pre_exec.sh"
RUNTIME_MANAGER_SCRIPT_NAME = "runtime_environment_manager.py"
fi
"""

+ ENTRYPOINT_MPIRUN_SCRIPT = f"""
+ #!/bin/bash
+
+ # Entry point for bootstrapping the runtime environment and invoking the remote function with mpirun
+
+ set -eu
+
+ PERSISTENT_CACHE_DIR=${{SAGEMAKER_MANAGED_WARMPOOL_CACHE_DIRECTORY:-/opt/ml/cache}}
+ export CONDA_PKGS_DIRS=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/conda/pkgs
+ printf "INFO: CONDA_PKGS_DIRS is set to '$CONDA_PKGS_DIRS'\\n"
+ export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
+ printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"
+
+ printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n"
+ cat /opt/ml/input/config/resourceconfig.json
+
+ printf "INFO: Bootstrapping runtime environment.\\n"
+ python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
+ source /opt/ml/input/sm_training.env
+
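+ # Assumption: sm_training.env exports the SM_* variables used below
+ # (SM_CURRENT_HOST, SM_MASTER_ADDR, SM_HOSTS_LIST, SM_NPROC_PER_NODE, ...).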
+ if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
+ then
+     if [ -f "remote_function_conda_env.txt" ]
+     then
+         cp remote_function_conda_env.txt {JOB_REMOTE_FUNCTION_WORKSPACE}/remote_function_conda_env.txt
+     fi
+     printf "INFO: Changing workspace to {JOB_REMOTE_FUNCTION_WORKSPACE}.\\n"
+     cd {JOB_REMOTE_FUNCTION_WORKSPACE}
+ fi
+
+ if [ -f "remote_function_conda_env.txt" ]
+ then
+     conda_env=$(cat remote_function_conda_env.txt)
+
+     if which mamba >/dev/null; then
+         conda_exe="mamba"
+     else
+         conda_exe="conda"
+     fi
+
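+     # Only one node may drive mpirun: the host whose name matches $SM_MASTER_ADDR
+     # (both values are assumed to come from the sourced sm_training.env above).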
+ if [ "$SM_CURRENT_HOST" = "$SM_MASTER_ADDR" ]; then
212
+ python /opt/ml/input/data/{ RUNTIME_SCRIPTS_CHANNEL_NAME } /{ MPI_UTILS_SCRIPT_NAME }
213
+
214
+ printf "INFO: Invoking remote function with mpirun inside conda environment: $conda_env.\\ n"
215
+ printf "INFO: $conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \
216
+ --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \
217
+ -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \
218
+ -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
219
+ -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
220
+
221
+ python -m mpi4py -m sagemaker.remote_function.invoke_function \\ n"
222
+ $conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \
223
+ --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \
224
+ -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \
225
+ -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
226
+ -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
227
+ $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \
228
+ python -m mpi4py -m sagemaker.remote_function.invoke_function "$@"
229
+
230
+ python /opt/ml/input/data/{ RUNTIME_SCRIPTS_CHANNEL_NAME } /{ MPI_UTILS_SCRIPT_NAME } --job_ended 1
231
+ else
232
+ printf "INFO: This is the instance $SM_CURRENT_HOST. mpirun command terminated\\ n"
233
+ python /opt/ml/input/data/{ RUNTIME_SCRIPTS_CHANNEL_NAME } /{ MPI_UTILS_SCRIPT_NAME }
234
+ fi
235
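+     # {MPI_UTILS_SCRIPT_NAME} is not shown in this diff; from its usage here, the
+     # master runs it before mpirun and again with --job_ended 1 afterwards, while
+     # worker nodes appear to run it to wait until the master signals completion.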
+ else
+     if [ "$SM_CURRENT_HOST" = "$SM_MASTER_ADDR" ]; then
+         python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME}
+
+         printf "INFO: No conda env provided. Invoking remote function with mpirun.\\n"
+         printf "INFO: mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \
+             --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \
+             -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \
+             -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
+             -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
+             $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \
+             python -m mpi4py -m sagemaker.remote_function.invoke_function \\n"
+
+         mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \
+             --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \
+             -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \
+             -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
+             -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
+             $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \
+             python -m mpi4py -m sagemaker.remote_function.invoke_function "$@"
+
+         python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1
+     else
+         printf "INFO: This is the instance $SM_CURRENT_HOST.\\n"
+         python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME}
+     fi
+ fi
+ """
+
ENTRYPOINT_TORCHRUN_SCRIPT = f"""
#!/bin/bash

@@ -211,13 +305,15 @@
        printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
            --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
            -m sagemaker.remote_function.invoke_function \\n"
+
        $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
            --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
            -m sagemaker.remote_function.invoke_function "$@"
    else
        printf "INFO: No conda env provided. Invoking remote function with torchrun.\\n"
        printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
            --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function \\n"
+
        torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
            --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function "$@"
    fi
@@ -278,6 +374,7 @@ def __init__(
        use_spot_instances=False,
        max_wait_time_in_seconds=None,
        use_torchrun: bool = False,
+       use_mpirun: bool = False,
        nproc_per_node: Optional[int] = None,
    ):
        """Initialize a _JobSettings instance which configures the remote job.
@@ -464,6 +561,9 @@ def __init__(
            use_torchrun (bool): Specifies whether to use torchrun for distributed training.
                Defaults to ``False``.

+           use_mpirun (bool): Specifies whether to use mpirun for distributed training.
+               Defaults to ``False``.
+
            nproc_per_node (Optional[int]): Specifies the number of processes per node for
                distributed training. Defaults to ``None``.
                If not specified, it is configured automatically based on the instance type.
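+               For illustration only (an assumption about the public API, which this
+               diff does not show): these settings would typically be supplied through
+               the ``@remote`` decorator, e.g.
+               ``@remote(instance_count=2, use_mpirun=True, nproc_per_node=8)``.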
@@ -626,6 +726,7 @@ def __init__(
        self.tags = self.sagemaker_session._append_sagemaker_config_tags(tags, REMOTE_FUNCTION_TAGS)

        self.use_torchrun = use_torchrun
+       self.use_mpirun = use_mpirun
        self.nproc_per_node = nproc_per_node

    @staticmethod
@@ -874,6 +975,12 @@ def compile(
                ).to_string(),
            ]
        )
+       if job_settings.use_torchrun:
+           container_args.extend(["--distribution", "torchrun"])
+       elif job_settings.use_mpirun:
+           container_args.extend(["--distribution", "mpirun"])
+       if job_settings.nproc_per_node is not None and int(job_settings.nproc_per_node) > 0:
+           container_args.extend(["--user_nproc_per_node", str(job_settings.nproc_per_node)])
        if job_settings.s3_kms_key:
            container_args.extend(["--s3_kms_key", job_settings.s3_kms_key])
@@ -950,6 +1057,7 @@ def compile(
        request_dict["Environment"].update({"REMOTE_FUNCTION_SECRET_KEY": hmac_key})

        extended_request = _extend_spark_config_to_request(request_dict, job_settings, s3_base_uri)
+       extended_request = _extend_mpirun_to_request(extended_request, job_settings)
        extended_request = _extend_torchrun_to_request(extended_request, job_settings)

        return extended_request
@@ -1031,7 +1139,7 @@ def _prepare_and_upload_runtime_scripts(
    s3_kms_key: str,
    sagemaker_session: Session,
    use_torchrun: bool = False,
-   nproc_per_node: Optional[int] = None,
+   use_mpirun: bool = False,
):
    """Copy runtime scripts to a folder and upload to S3.

@@ -1050,6 +1158,8 @@ def _prepare_and_upload_runtime_scripts(

        use_torchrun (bool): Whether to use torchrun or not.

+       use_mpirun (bool): Whether to use mpirun or not.
+
        nproc_per_node (Optional[int]): Number of processes per node
    """
@@ -1075,23 +1185,25 @@ def _prepare_and_upload_runtime_scripts(
        if use_torchrun:
            entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT

-           if nproc_per_node is not None and nproc_per_node > 0:
-               entry_point_script = entry_point_script.replace(
-                   "$SM_NPROC_PER_NODE", str(nproc_per_node)
-               )
+       if use_mpirun:
+           entry_point_script = ENTRYPOINT_MPIRUN_SCRIPT
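+       # nproc_per_node is no longer substituted into the script here; it now reaches
+       # the job via the --user_nproc_per_node container argument added in compile().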

        with open(entrypoint_script_path, "w", newline="\n") as file:
            file.writelines(entry_point_script)

        bootstrap_script_path = os.path.join(
            os.path.dirname(__file__), "runtime_environment", BOOTSTRAP_SCRIPT_NAME
        )
+       mpi_utils_path = os.path.join(
+           os.path.dirname(__file__), "runtime_environment", MPI_UTILS_SCRIPT_NAME
+       )
        runtime_manager_script_path = os.path.join(
            os.path.dirname(__file__), "runtime_environment", RUNTIME_MANAGER_SCRIPT_NAME
        )

        # copy runtime scripts to tmpdir
        shutil.copy2(bootstrap_script_path, bootstrap_scripts)
+       shutil.copy2(mpi_utils_path, bootstrap_scripts)
        shutil.copy2(runtime_manager_script_path, bootstrap_scripts)

        upload_path = S3Uploader.upload(
@@ -1118,7 +1230,7 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
        s3_kms_key=job_settings.s3_kms_key,
        sagemaker_session=job_settings.sagemaker_session,
        use_torchrun=job_settings.use_torchrun,
-       nproc_per_node=job_settings.nproc_per_node,
+       use_mpirun=job_settings.use_mpirun,
    )

    input_data_config = [
@@ -1459,6 +1571,35 @@ def _upload_serialized_spark_configuration(
    return config_file_s3_uri


+ def _extend_mpirun_to_request(
+     request_dict: Dict,
+     job_settings: _JobSettings,
+ ) -> Dict:
+     """Extend the create training job request with mpirun configuration.
+
+     Args:
+         request_dict (Dict): create training job request dict.
+         job_settings (_JobSettings): the job settings.
+     """
+     use_mpirun = job_settings.use_mpirun
+     instance_count = job_settings.instance_count
+
+     if not use_mpirun:
+         return request_dict
+
+     if instance_count == 1:
+         return request_dict
+
+     extended_request = request_dict.copy()
+
+ for input_channel in extended_request ["InputDataConfig" ]:
1596
+ s3_data_source = input_channel ["DataSource" ].get ("S3DataSource" , None )
1597
+ if s3_data_source :
1598
+ s3_data_source ["S3DataDistributionType" ] = "FullyReplicated"
1599
+
1600
+ return extended_request
1601
+
1602
+
1462
1603
def _extend_torchrun_to_request (
1463
1604
request_dict : Dict ,
1464
1605
job_settings : _JobSettings ,