Skip to content

Commit abe4fc6

Browse files
cli99, bstee615, and tjruwase
authored
encoded ds config into command line argument when launching child processes in autotuning (#2524)
* rollback ds config changes * fix format * Fix error when output_file is a relative path without a prefix (#2397) Co-authored-by: Benjamin Steenhoek <[email protected]> * fix restuls and exprs path to use absolute path * use base64 encoded ds config as cmd arg * fix format * remove assert * write out optimial config after tuning * fix format * no need to update ds config path when encoding ds config * udpate * do not use abs path for result and expr dir * fix conflicts * fix run mode * fix format * fix format Co-authored-by: Benjamin Steenhoek <[email protected]> Co-authored-by: Olatunji Ruwase <[email protected]>
1 parent 340fc0c commit abe4fc6

File tree

5 files changed

+36
-31
lines changed

5 files changed

+36
-31
lines changed

deepspeed/autotuning/autotuner.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ def __init__(self, args, active_resources):
8787
self.rm.nodes), "num_nodes in the autotuning configuration must not be less than the --num_nodes value in the train script if any"
8888

8989
self.records = {}
90+
self.optimal_cmd = None
91+
self.optmal_ds_config = None
9092

9193
def print_tuning_results(self):
9294
"""Print the autotuning results in tabular format.
@@ -1125,9 +1127,6 @@ def write_optimal_config(self):
11251127
ds_config_path = os.path.join(self.results_dir, "ds_config_optimal.json")
11261128
json.dump(ds_config, open(ds_config_path, "w"))
11271129

1128-
idx = cmd.index(os.path.join(exp_dir, "ds_config.json"))
1129-
cmd[idx] = ds_config_path
1130-
11311130
cmd_path = os.path.join(self.results_dir, "cmd_optimal.txt")
11321131
with open(cmd_path, "w") as fd:
11331132
fd.write(" ".join(cmd))
@@ -1138,9 +1137,6 @@ def write_optimal_config(self):
11381137
logger.info(
11391138
f"Wrote the optimal DeepSpeed configuration found by autotuning to {ds_config_path}, and the corresponding DeepSpeed command to {cmd_path}"
11401139
)
1141-
else:
1142-
self.optimal_cmd = None
1143-
self.optmal_ds_config = None
11441140

11451141
def run_after_tuning(self):
11461142
""" Launches the training with the optimal DeepSpeed configuration found through the autotuning process.

deepspeed/autotuning/config.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,13 @@ def _initialize(self, autotuning_dict):
3838
AUTOTUNING_FAST,
3939
AUTOTUNING_FAST_DEFAULT)
4040

41-
self.results_dir = os.path.abspath(
42-
get_scalar_param(autotuning_dict,
43-
AUTOTUNING_RESULTS_DIR,
44-
AUTOTUNING_RESULTS_DIR_DEFAULT))
41+
self.results_dir = get_scalar_param(autotuning_dict,
42+
AUTOTUNING_RESULTS_DIR,
43+
AUTOTUNING_RESULTS_DIR_DEFAULT)
4544
assert self.results_dir, "results_dir cannot be empty"
46-
self.exps_dir = os.path.abspath(
47-
get_scalar_param(autotuning_dict,
48-
AUTOTUNING_EXPS_DIR,
49-
AUTOTUNING_EXPS_DIR_DEFAULT))
45+
self.exps_dir = get_scalar_param(autotuning_dict,
46+
AUTOTUNING_EXPS_DIR,
47+
AUTOTUNING_EXPS_DIR_DEFAULT)
5048
assert self.exps_dir, "exps_dir cannot be empty"
5149
self.overwrite = get_scalar_param(autotuning_dict,
5250
AUTOTUNING_OVERWRITE,

deepspeed/autotuning/scheduler.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
import sys
77
import threading
88
import time
9+
import base64
910

11+
import os
1012
import hjson
1113
from tqdm import tqdm
1214

1315
from ..utils import logger
14-
from .constants import *
1516
from .constants import AUTOTUNING, AUTOTUNING_METRIC_PATH
1617
from .utils import get_val_by_key, search_error, was_interruptted
1718
"""
@@ -180,7 +181,6 @@ def run(self):
180181
logger.debug(f'Put exp_id = {exp["exp_id"]} back into the queue')
181182
self.experiment_check(pbar)
182183
else:
183-
184184
desc = ""
185185
for reservation in reservations:
186186
reservation.slots.sort()
@@ -336,19 +336,27 @@ def run_experiment(exp: dict, reservations, user_script, user_args):
336336
exp["job_id"] = get_job_id()
337337
exp_dir = exp["result_dir"]
338338
os.makedirs(exp_dir, exist_ok=True)
339-
340-
exp["ds_config_path"] = os.path.join(exp_dir, "ds_config.json")
339+
ds_config_path = os.path.join(exp_dir, "ds_config.json")
340+
exp["ds_config_path"] = ds_config_path
341341

342342
ds_config = copy.deepcopy(exp["ds_config"])
343+
ds_config_json = json.dumps(ds_config).encode('utf-8')
344+
345+
exp["ds_config_base64"] = base64.urlsafe_b64encode(ds_config_json).decode('utf-8')
343346

344347
with open(exp["ds_config_path"], "w", buffering=BUFSIZE) as fd:
345348
json.dump(ds_config, fd)
346349
fd.flush()
347350
os.fsync(fd)
351+
path = exp["ds_config_path"]
352+
logger.info(f"Scheduler wrote ds_config to {path}, {os.path.abspath(path)}")
353+
348354
with open(os.path.join(exp_dir, "exp.json"), "w", buffering=BUFSIZE) as fd:
349355
json.dump(exp, fd)
350356
fd.flush()
351357
os.fsync(fd)
358+
path = os.path.join(exp_dir, "exp.json")
359+
logger.info(f"Scheduler wrote exp to {path}, {os.path.abspath(path)}")
352360

353361
# remove "--deepspeed_config ds_config.json" from user_args
354362
if user_args:
@@ -357,9 +365,10 @@ def run_experiment(exp: dict, reservations, user_script, user_args):
357365
# "--deepspeed_config" is omitted in HF
358366
elif "--deepspeed" in user_args:
359367
idx = user_args.index("--deepspeed")
360-
assert idx < len(user_args) and ".json" in user_args[idx +
361-
1], "there is no ds_config file specified after --deepspeed_config or --deepspeed"
362-
user_args[idx + 1] = exp["ds_config_path"]
368+
assert idx < len(user_args), "there is no ds_config file specified after --deepspeed_config or --deepspeed"
369+
# user_args[idx + 1] = exp["ds_config_path"]
370+
# pass base64 serialized ds_config to launcher
371+
user_args[idx + 1] = exp["ds_config_base64"]
363372

364373
exp["user_script"] = user_script
365374
exp["user_args"] = user_args
@@ -375,7 +384,7 @@ def run_experiment(exp: dict, reservations, user_script, user_args):
375384
os.fsync(fd)
376385

377386
logger.info(
378-
f"Launching exp_id = {exp['exp_id']}, exp_name = {exp['name']}, with resource = {include_str}"
387+
f"Launching exp_id = {exp['exp_id']}, exp_name = {exp['name']}, with resource = {include_str}, and ds_config = {os.path.abspath(ds_config_path)}"
379388
)
380389

381390
with open(os.path.join(exp_dir, "stdout.log"), "wb") as out, open(

deepspeed/runtime/config.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch
99
import json
1010
import copy
11+
import base64
1112

1213
from .constants import *
1314
from .fp16.loss_scaler import (
@@ -724,9 +725,13 @@ def __init__(self, config: Union[str, dict], mpu=None):
724725
"r"),
725726
object_pairs_hook=dict_raise_error_on_duplicate_keys)
726727
else:
727-
raise ValueError(
728-
f"Expected a string path to an existing deepspeed config, or a dictionary. Received: {config}"
729-
)
728+
try:
729+
config_decoded = base64.urlsafe_b64decode(config).decode('utf-8')
730+
self._param_dict = json.loads(config_decoded)
731+
except (UnicodeDecodeError, AttributeError):
732+
raise ValueError(
733+
f"Expected a string path to an existing deepspeed config, or a dictionary or a valid base64. Received: {config}"
734+
)
730735
try:
731736
self.global_rank = dist.get_rank()
732737
if mpu is None:

deepspeed/runtime/engine.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -956,12 +956,6 @@ def _do_args_sanity_check(self, args):
956956
args, "deepspeed_config") and args.deepspeed_config is not None
957957
), "DeepSpeed requires --deepspeed_config to specify configuration file"
958958

959-
assert os.path.isfile(
960-
args.deepspeed_config
961-
), "DeepSpeed configuration file: {} is not an existing file".format(
962-
args.deepspeed_config
963-
)
964-
965959
def _is_supported_optimizer(self, optimizer_name):
966960
return (optimizer_name in DEEPSPEED_OPTIMIZERS
967961
or getattr(torch.optim,
@@ -2162,6 +2156,9 @@ def _autotuning_exit(self):
21622156
msg["throughput"] = self.train_batch_size() * 1000 / \
21632157
msg["latency"]
21642158
print_json_dist(msg, [0], path=self.autotuning_metric_path())
2159+
log_dist(
2160+
f"Wrote metrics to {self.autotuning_metric_path()}, {os.path.abspath(self.autotuning_metric_path())}",
2161+
ranks=[0])
21652162
import atexit
21662163
atexit.register(print, "Autotuning: done with running current ds config.")
21672164
exit()

0 commit comments

Comments (0)