forked from karpathy/nanoGPT
-
Notifications
You must be signed in to change notification settings - Fork 19
/
run_curriculum_learning.py
61 lines (51 loc) · 1.87 KB
/
run_curriculum_learning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import subprocess
import argparse
import os
import json
# Directory names produced by the most recently built stage command. Both start
# empty (so the first stage does not chain from anything) and are overwritten by
# every call to run_experiments_command().
prev_csv_dir = ""
prev_output_dir = ""


def run_experiments_command(training_stage, config, **kwargs):
    """Return the ``run_experiments.py`` argv list for one curriculum stage.

    The stage's CSV-log directory and output directory are both named
    ``"<training_stage>_<config without extension>"``. If an earlier call has
    already recorded a previous stage, the command also gets
    ``--use-best-val-loss-from csv_logs/<prev csv dir> <prev output dir>`` so the
    stage can chain from that earlier stage (presumably by loading its
    best-validation-loss checkpoint -- confirm against run_experiments.py).

    Side effect: this call records the new stage's directories in the
    module-level ``prev_csv_dir`` / ``prev_output_dir``, even if the returned
    command is never executed.

    Args:
        training_stage: Stage label (main() passes a 1-based index) used as the
            directory-name prefix.
        config: Config file name, resolved under ``explorations/``.
        **kwargs: Each ``key=value`` becomes ``--override_<key> <str(value)>``,
            in insertion order.

    Returns:
        The command as a list of strings, suitable for ``subprocess.run``.
    """
    global prev_csv_dir, prev_output_dir

    # CSV logs and checkpoints share one per-stage directory name.
    stage_dir = f"{training_stage}_{os.path.splitext(config)[0]}"

    argv = [
        "python3", "run_experiments.py",
        "--config", f"explorations/{config}",
        "--csv_ckpt_dir", stage_dir,
        "--output_dir", stage_dir,
    ]

    # Only chain once a previous stage has been recorded.
    if prev_csv_dir and prev_output_dir:
        argv += ["--use-best-val-loss-from", f"csv_logs/{prev_csv_dir}", prev_output_dir]

    # Remember this stage so the next command chains from it.
    prev_csv_dir, prev_output_dir = stage_dir, stage_dir

    for key, value in kwargs.items():
        argv += [f"--override_{key}", str(value)]
    return argv
def main(config_file):
    """Run every stage of the curriculum listed in ``config_file``, in order.

    The list format is chosen by file extension:

    * ``.py``: one config file name per line. Surrounding whitespace is
      stripped and blank lines are skipped, so they are not launched as an
      empty/invalid config.
    * ``.json``: a JSON array. Each entry is either a config file name string,
      or an object with a ``"config"`` key plus optional extra keys. The extra
      keys are forwarded to ``run_experiments_command`` as overrides.

    Stages are numbered from 1 in list order. Each stage's command comes from
    ``run_experiments_command`` and is run with ``subprocess.run``. A stage must
    exit successfully before the next one starts, because later stages chain
    from the previous stage's output directories.

    Args:
        config_file: Path to the curriculum list file.

    Raises:
        ValueError: If ``config_file`` does not end in ``.py`` or ``.json``
            (previously such files were silently ignored).
        subprocess.CalledProcessError: If a stage exits with a non-zero
            status; the remaining stages are not run.
    """
    ext = os.path.splitext(config_file)[1]
    if ext == ".py":
        with open(config_file, "r", encoding="utf-8") as f:
            # Drop blank lines and stray whitespace; either would otherwise
            # produce a bogus "explorations/..." config path.
            stages = [line.strip() for line in f if line.strip()]
    elif ext == ".json":
        with open(config_file, "r", encoding="utf-8") as f:
            stages = json.load(f)
    else:
        raise ValueError(
            f"Unsupported curriculum file extension {ext!r} for {config_file!r}; "
            "expected '.py' or '.json'."
        )

    for stage_number, stage in enumerate(stages, start=1):
        if isinstance(stage, dict):
            # Object entries: "config" plus per-stage override keys.
            command = run_experiments_command(stage_number, **stage)
        else:
            command = run_experiments_command(stage_number, stage)
        # check=True: stop the curriculum if a stage fails instead of letting
        # the next stage chain from a failed stage's (missing) results.
        subprocess.run(command, check=True)
if __name__ == "__main__":
    # Command-line entry point: read the curriculum list path and run it.
    cli = argparse.ArgumentParser(
        description="Runs curriculum learning on the datasets from the provided config files."
    )
    cli.add_argument(
        "-c", "--config_file",
        type=str,
        default="curriculum/curriculum.py",
        help="Path to the config file which stores the list of config files to be run.",
    )
    cli_args = cli.parse_args()
    main(cli_args.config_file)