Commit 6d2dfa0

mpirun protocol - distributed training with @Remote decorator (#4998)
* implemented multi-node distribution with @Remote function
* completed unit tests
* added distributed training with CPU and torchrun
* backwards compatibility nproc_per_node
* fixing code: permissions for non-root users, integration tests
* fixed docstyle
* refactor nproc_per_node for backwards compatibility
* refactor nproc_per_node for backwards compatibility
* pylint fix, newlines
* added unit tests for bootstrap_environment remote
* added mpirun protocol for distributed training with @Remote decorator
* aligned mpi_utils_remote.py to mpi_utils.py for estimator
* updated docstring for sagemaker sdk doc

Co-authored-by: Erick Benitez-Ramos <[email protected]>
1 parent 90e9c9f commit 6d2dfa0

File tree

10 files changed (+1168 / -49 lines)


src/sagemaker/remote_function/client.py (+24, -10)

@@ -90,7 +90,8 @@ def remote(
     spark_config: SparkConfig = None,
     use_spot_instances=False,
     max_wait_time_in_seconds=None,
-    use_torchrun=False,
+    use_torchrun: bool = False,
+    use_mpirun: bool = False,
     nproc_per_node: Optional[int] = None,
 ):
     """Decorator for running the annotated function as a SageMaker training job.
@@ -207,7 +208,8 @@ def remote(
             files are accepted and uploaded to S3.
 
         instance_count (int): The number of instances to use. Defaults to 1.
-            NOTE: Remote function does not support instance_count > 1 for non Spark jobs.
+            NOTE: Remote function supports instance_count > 1 for Spark jobs, torchrun and
+            mpirun utilities
 
         instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to run
             the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown.
@@ -284,6 +286,9 @@ def remote(
         use_torchrun (bool): Specifies whether to use torchrun for distributed training.
             Defaults to ``False``.
 
+        use_mpirun (bool): Specifies whether to use mpirun for distributed training.
+            Defaults to ``False``.
+
         nproc_per_node (Optional int): Specifies the number of processes per node for
             distributed training. Defaults to ``None``.
             This is defined automatically configured on the instance type.
@@ -320,19 +325,21 @@ def _remote(func):
             use_spot_instances=use_spot_instances,
             max_wait_time_in_seconds=max_wait_time_in_seconds,
             use_torchrun=use_torchrun,
+            use_mpirun=use_mpirun,
             nproc_per_node=nproc_per_node,
         )
 
         @functools.wraps(func)
         def wrapper(*args, **kwargs):
 
             if instance_count > 1 and not (
-                (spark_config is not None and not use_torchrun)
-                or (spark_config is None and use_torchrun)
+                (spark_config is not None and not use_torchrun and not use_mpirun)
+                or (spark_config is None and use_torchrun and not use_mpirun)
+                or (spark_config is None and not use_torchrun and use_mpirun)
             ):
                 raise ValueError(
                     "Remote function do not support training on multi instances "
-                    + "without spark_config or use_torchrun. "
+                    + "without spark_config or use_torchrun or use_mpirun. "
                     + "Please provide instance_count = 1"
                 )
 
@@ -536,7 +543,8 @@ def __init__(
         spark_config: SparkConfig = None,
         use_spot_instances=False,
         max_wait_time_in_seconds=None,
-        use_torchrun=False,
+        use_torchrun: bool = False,
+        use_mpirun: bool = False,
         nproc_per_node: Optional[int] = None,
     ):
         """Constructor for RemoteExecutor
@@ -650,7 +658,8 @@ def __init__(
                 files are accepted and uploaded to S3.
 
             instance_count (int): The number of instances to use. Defaults to 1.
-                NOTE: Remote function does not support instance_count > 1 for non Spark jobs.
+                NOTE: Remote function supports instance_count > 1 for Spark jobs, torchrun and
+                mpirun utilities
 
             instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to run
                 the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown.
@@ -730,6 +739,9 @@ def __init__(
             use_torchrun (bool): Specifies whether to use torchrun for distributed training.
                 Defaults to ``False``.
 
+            use_mpirun (bool): Specifies whether to use mpirun for distributed training.
+                Defaults to ``False``.
+
             nproc_per_node (Optional int): Specifies the number of processes per node for
                 distributed training. Defaults to ``None``.
                 This is defined automatically configured on the instance type.
@@ -740,12 +752,13 @@ def __init__(
             raise ValueError("max_parallel_jobs must be greater than 0.")
 
         if instance_count > 1 and not (
-            (spark_config is not None and not use_torchrun)
-            or (spark_config is None and use_torchrun)
+            (spark_config is not None and not use_torchrun and not use_mpirun)
+            or (spark_config is None and use_torchrun and not use_mpirun)
+            or (spark_config is None and not use_torchrun and use_mpirun)
         ):
             raise ValueError(
                 "Remote function do not support training on multi instances "
-                + "without spark_config or use_torchrun. "
+                + "without spark_config or use_torchrun or use_mpirun. "
                 + "Please provide instance_count = 1"
             )
 
@@ -778,6 +791,7 @@ def __init__(
             use_spot_instances=use_spot_instances,
             max_wait_time_in_seconds=max_wait_time_in_seconds,
             use_torchrun=use_torchrun,
+            use_mpirun=use_mpirun,
             nproc_per_node=nproc_per_node,
         )
 
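For context, a minimal usage sketch of the new flag: with use_mpirun=True, instance_count > 1 now passes the validation above as long as neither spark_config nor use_torchrun is also set. The instance type, instance count, and function body below are illustrative, not part of this commit; only the use_mpirun / instance_count semantics come from this diff.

# Hypothetical usage sketch of the @remote decorator with the new mpirun option.
from sagemaker.remote_function import remote

@remote(
    instance_type="ml.p4d.24xlarge",  # placeholder instance type
    instance_count=2,                 # multi-instance is now accepted with use_mpirun=True
    use_mpirun=True,                  # launch the function under mpirun
    # nproc_per_node=8,               # optional; otherwise derived from the instance type
)
def train():
    ...  # executed by every MPI process across the two instances

train()  # submits the SageMaker training job
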

src/sagemaker/remote_function/job.py (+147, -6)

@@ -81,6 +81,7 @@
 
 # runtime script names
 BOOTSTRAP_SCRIPT_NAME = "bootstrap_runtime_environment.py"
+MPI_UTILS_SCRIPT_NAME = "mpi_utils_remote.py"
 ENTRYPOINT_SCRIPT_NAME = "job_driver.sh"
 PRE_EXECUTION_SCRIPT_NAME = "pre_exec.sh"
 RUNTIME_MANAGER_SCRIPT_NAME = "runtime_environment_manager.py"
@@ -167,6 +168,99 @@
 fi
 """
 
+ENTRYPOINT_MPIRUN_SCRIPT = f"""
+#!/bin/bash
+
+# Entry point for bootstrapping runtime environment and invoking remote function with mpirun
+
+set -eu
+
+PERSISTENT_CACHE_DIR=${{SAGEMAKER_MANAGED_WARMPOOL_CACHE_DIRECTORY:-/opt/ml/cache}}
+export CONDA_PKGS_DIRS=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/conda/pkgs
+printf "INFO: CONDA_PKGS_DIRS is set to '$CONDA_PKGS_DIRS'\\n"
+export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
+printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"
+
+printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n"
+cat /opt/ml/input/config/resourceconfig.json
+
+printf "INFO: Bootstraping runtime environment.\\n"
+python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
+source /opt/ml/input/sm_training.env
+
+if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
+then
+    if [ -f "remote_function_conda_env.txt" ]
+    then
+        cp remote_function_conda_env.txt {JOB_REMOTE_FUNCTION_WORKSPACE}/remote_function_conda_env.txt
+    fi
+    printf "INFO: Changing workspace to {JOB_REMOTE_FUNCTION_WORKSPACE}.\\n"
+    cd {JOB_REMOTE_FUNCTION_WORKSPACE}
+fi
+
+if [ -f "remote_function_conda_env.txt" ]
+then
+    conda_env=$(cat remote_function_conda_env.txt)
+
+    if which mamba >/dev/null; then
+        conda_exe="mamba"
+    else
+        conda_exe="conda"
+    fi
+
+    if [ "$SM_CURRENT_HOST" = "$SM_MASTER_ADDR" ]; then
+        python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME}
+
+        printf "INFO: Invoking remote function with mpirun inside conda environment: $conda_env.\\n"
+        printf "INFO: $conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \
+        --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \
+        -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \
+        -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
+        -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
+
+        python -m mpi4py -m sagemaker.remote_function.invoke_function \\n"
+        $conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \
+        --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \
+        -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \
+        -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
+        -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
+        $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \
+        python -m mpi4py -m sagemaker.remote_function.invoke_function "$@"
+
+        python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1
+    else
+        printf "INFO: This is the instance $SM_CURRENT_HOST. mpirun command terminated\\n"
+        python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME}
+    fi
+else
+    if [ "$SM_CURRENT_HOST" = "$SM_MASTER_ADDR" ]; then
+        python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME}
+
+        printf "INFO: No conda env provided. Invoking remote function with mpirun\\n"
+        printf "INFO: mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \
+        --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \
+        -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \
+        -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
+        -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
+        $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \
+        python -m mpi4py -m sagemaker.remote_function.invoke_function \\n"
+
+        mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \
+        --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \
+        -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \
+        -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
+        -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
+        $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \
+        python -m mpi4py -m sagemaker.remote_function.invoke_function "$@"
+
+        python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1
+    else
+        printf "INFO: This is the instance $SM_CURRENT_HOST.\\n"
+        python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME}
+    fi
+fi
+"""
+
 ENTRYPOINT_TORCHRUN_SCRIPT = f"""
 #!/bin/bash
 
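The entrypoint above ultimately launches the serialized function via python -m mpi4py -m sagemaker.remote_function.invoke_function, once for each MPI process that mpirun starts across the hosts in $SM_HOSTS_LIST. As a rough sketch of what that means for user code, assuming mpi4py is available in the job image or conda environment (the function name below is illustrative):

# Hypothetical remote function body; relies only on standard mpi4py APIs.
from mpi4py import MPI

def describe_world():
    comm = MPI.COMM_WORLD
    # Each mpirun-launched worker sees its own rank; all ranks run the same function
    # unless the code branches on the rank itself.
    print(f"running as rank {comm.Get_rank()} of {comm.Get_size()}")
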
@@ -211,13 +305,15 @@
     printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
     --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
     -m sagemaker.remote_function.invoke_function \\n"
+
     $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
     --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
     -m sagemaker.remote_function.invoke_function "$@"
 else
     printf "INFO: No conda env provided. Invoking remote function with torchrun\\n"
     printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
     --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function \\n"
+
     torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
     --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function "$@"
 fi
@@ -278,6 +374,7 @@ def __init__(
         use_spot_instances=False,
         max_wait_time_in_seconds=None,
         use_torchrun: bool = False,
+        use_mpirun: bool = False,
         nproc_per_node: Optional[int] = None,
     ):
         """Initialize a _JobSettings instance which configures the remote job.
@@ -464,6 +561,9 @@ def __init__(
             use_torchrun (bool): Specifies whether to use torchrun for distributed training.
                 Defaults to ``False``.
 
+            use_mpirun (bool): Specifies whether to use mpirun for distributed training.
+                Defaults to ``False``.
+
             nproc_per_node (Optional int): Specifies the number of processes per node for
                 distributed training. Defaults to ``None``.
                 This is defined automatically configured on the instance type.
@@ -626,6 +726,7 @@ def __init__(
         self.tags = self.sagemaker_session._append_sagemaker_config_tags(tags, REMOTE_FUNCTION_TAGS)
 
         self.use_torchrun = use_torchrun
+        self.use_mpirun = use_mpirun
         self.nproc_per_node = nproc_per_node
 
     @staticmethod
@@ -874,6 +975,12 @@ def compile(
                 ).to_string(),
             ]
         )
+        if job_settings.use_torchrun:
+            container_args.extend(["--distribution", "torchrun"])
+        elif job_settings.use_mpirun:
+            container_args.extend(["--distribution", "mpirun"])
+        if job_settings.nproc_per_node is not None and int(job_settings.nproc_per_node) > 0:
+            container_args.extend(["--user_nproc_per_node", str(job_settings.nproc_per_node)])
        if job_settings.s3_kms_key:
            container_args.extend(["--s3_kms_key", job_settings.s3_kms_key])
 
@@ -950,6 +1057,7 @@ def compile(
             request_dict["Environment"].update({"REMOTE_FUNCTION_SECRET_KEY": hmac_key})
 
         extended_request = _extend_spark_config_to_request(request_dict, job_settings, s3_base_uri)
+        extended_request = _extend_mpirun_to_request(extended_request, job_settings)
         extended_request = _extend_torchrun_to_request(extended_request, job_settings)
 
         return extended_request
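Taken together, the two compile() hunks above add the distribution flags to the bootstrap container arguments and run the new mpirun request pass. An illustration of the extra container arguments for a hypothetical job_settings with use_mpirun=True and nproc_per_node=8 (values are stand-ins, not SDK defaults):

# Mirrors the container_args logic added above, with hard-coded example values.
container_args = []
use_torchrun, use_mpirun, nproc_per_node = False, True, 8

if use_torchrun:
    container_args.extend(["--distribution", "torchrun"])
elif use_mpirun:
    container_args.extend(["--distribution", "mpirun"])
if nproc_per_node is not None and int(nproc_per_node) > 0:
    container_args.extend(["--user_nproc_per_node", str(nproc_per_node)])

print(container_args)  # ['--distribution', 'mpirun', '--user_nproc_per_node', '8']
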
@@ -1031,7 +1139,7 @@ def _prepare_and_upload_runtime_scripts(
     s3_kms_key: str,
     sagemaker_session: Session,
     use_torchrun: bool = False,
-    nproc_per_node: Optional[int] = None,
+    use_mpirun: bool = False,
 ):
     """Copy runtime scripts to a folder and upload to S3.
 
@@ -1050,6 +1158,8 @@ def _prepare_and_upload_runtime_scripts(
 
         use_torchrun (bool): Whether to use torchrun or not.
 
+        use_mpirun (bool): Whether to use mpirun or not.
+
         nproc_per_node (Optional[int]): Number of processes per node
     """
 
@@ -1075,23 +1185,25 @@ def _prepare_and_upload_runtime_scripts(
         if use_torchrun:
             entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT
 
-            if nproc_per_node is not None and nproc_per_node > 0:
-                entry_point_script = entry_point_script.replace(
-                    "$SM_NPROC_PER_NODE", str(nproc_per_node)
-                )
+        if use_mpirun:
+            entry_point_script = ENTRYPOINT_MPIRUN_SCRIPT
 
         with open(entrypoint_script_path, "w", newline="\n") as file:
             file.writelines(entry_point_script)
 
         bootstrap_script_path = os.path.join(
             os.path.dirname(__file__), "runtime_environment", BOOTSTRAP_SCRIPT_NAME
         )
+        mpi_utils_path = os.path.join(
+            os.path.dirname(__file__), "runtime_environment", MPI_UTILS_SCRIPT_NAME
+        )
         runtime_manager_script_path = os.path.join(
             os.path.dirname(__file__), "runtime_environment", RUNTIME_MANAGER_SCRIPT_NAME
         )
 
         # copy runtime scripts to tmpdir
         shutil.copy2(bootstrap_script_path, bootstrap_scripts)
+        shutil.copy2(mpi_utils_path, bootstrap_scripts)
         shutil.copy2(runtime_manager_script_path, bootstrap_scripts)
 
         upload_path = S3Uploader.upload(
@@ -1118,7 +1230,7 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
         s3_kms_key=job_settings.s3_kms_key,
         sagemaker_session=job_settings.sagemaker_session,
         use_torchrun=job_settings.use_torchrun,
-        nproc_per_node=job_settings.nproc_per_node,
+        use_mpirun=job_settings.use_mpirun,
     )
 
     input_data_config = [
@@ -1459,6 +1571,35 @@ def _upload_serialized_spark_configuration(
     return config_file_s3_uri
 
 
+def _extend_mpirun_to_request(
+    request_dict: Dict,
+    job_settings: _JobSettings,
+) -> Dict:
+    """Extend the create training job request with mpirun configuration.
+
+    Args:
+        request_dict (Dict): create training job request dict.
+        job_settings (_JobSettings): the job settings.
+    """
+    use_mpirun = job_settings.use_mpirun
+    instance_count = job_settings.instance_count
+
+    if not use_mpirun:
+        return request_dict
+
+    if instance_count == 1:
+        return request_dict
+
+    extended_request = request_dict.copy()
+
+    for input_channel in extended_request["InputDataConfig"]:
+        s3_data_source = input_channel["DataSource"].get("S3DataSource", None)
+        if s3_data_source:
+            s3_data_source["S3DataDistributionType"] = "FullyReplicated"
+
+    return extended_request
+
+
 def _extend_torchrun_to_request(
     request_dict: Dict,
     job_settings: _JobSettings,
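Per the early returns above, _extend_mpirun_to_request only rewrites the request when use_mpirun is set and instance_count > 1. A standalone illustration of its effect on a request dict (the channel name and S3 URI are placeholders, not the SDK's real values):

# Toy request dict; the real one is produced by _Job.compile.
request = {
    "InputDataConfig": [
        {
            "ChannelName": "workspace",
            "DataSource": {
                "S3DataSource": {
                    "S3Uri": "s3://example-bucket/workspace",
                    "S3DataType": "S3Prefix",
                }
            },
        }
    ]
}

for input_channel in request["InputDataConfig"]:
    s3_data_source = input_channel["DataSource"].get("S3DataSource", None)
    if s3_data_source:
        # every host needs the full input for mpirun, hence FullyReplicated
        s3_data_source["S3DataDistributionType"] = "FullyReplicated"

print(request["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3DataDistributionType"])
# prints: FullyReplicated
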
@@ -0,0 +1,14 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Sagemaker modules container_drivers directory."""
+from __future__ import absolute_import
