
Commit 0b79ed9

Merge pull request #30 from alan-turing-institute/15-run-inference-on-baskerville
15 run inference on baskerville
2 parents 2f15ada + 92a98f0

16 files changed: +347 −116 lines

.gitignore (+1)

@@ -160,6 +160,7 @@ Thumbs.db
 
 
 slurm_scripts/slurm_logs*
+slurm_scripts/experiments*
 # other
 temp
 .vscode

config/RTC_configs/roberta-mt5-zero-shot.yaml (+1 −1)

@@ -1,4 +1,4 @@
-OCR:
+ocr:
   specific_task: "image-to-text"
   model: "microsoft/trocr-base-handwritten"
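One line, but it matters downstream: `gen_jobscripts.py` (added in this PR) iterates over the pipeline config's top-level keys and passes each one as `model_key` to `single_component_inference.py`, which dispatches on lowercase names (`if model_key == "ocr":`). A minimal sketch of that coupling (illustrative, not repo code):

```python
# Illustrative only: top-level pipeline-config keys become model_key values,
# and the dispatch in single_component_inference.py expects lowercase "ocr".
pipeline_config = {
    "ocr": {
        "specific_task": "image-to-text",
        "model": "microsoft/trocr-base-handwritten",
    }
}
for model_key in pipeline_config:
    if model_key == "ocr":  # the old key "OCR" would never match this branch
        print(f"dispatching recognition pipeline for {model_key}")
```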

config/data_configs/l1_fr_to_en.yaml (+2)

@@ -5,3 +5,5 @@ level: 1
 lang_pair:
   source: "fr"
   target: "en"
+
+drop_length: 1000
(new file — experiment config, 13 lines; path not shown in this extract)

```yaml
data_config: l1_fr_to_en

pipeline_config: roberta-mt5-zero-shot

seed:
  - 42

bask:
  jobname: "shortened_input_test"
  walltime: '0-12:0:0'
  gpu_number: 1
  node_number: 1
  hf_cache_dir: "/bask/projects/v/vjgo8416-spice/hf_cache"
```
(new file — experiment config, 15 lines; path not shown in this extract)

```yaml
data_config: l1_fr_to_en

pipeline_config: roberta-mt5-zero-shot

seed:
  - 42
  - 43
  - 44

bask:
  jobname: "full_experiment_with_zero_shot"
  walltime: '0-24:0:0'
  gpu_number: 1
  node_number: 1
  hf_cache_dir: "/bask/projects/v/vjgo8416-spice/hf_cache"
```

scripts/README.md (+27)

@@ -106,3 +106,30 @@ It's called like so e.g. from project root:
 ```bash
 python scripts/pipeline_inference.py [pipeline_config_path] [data_config_path] translator
 ```
+
+## gen_jobscripts.py
+
+Create jobscript `.sh` files for an experiment, which here means a `data_config` and `pipeline_config` combination.
+It takes a single argument, `experiment_config_path`: the path to a `.yaml` file structured as below.
+
+### e.g. Experiment config:
+
+```yaml
+data_config: l1_fr_to_en
+
+pipeline_config: roberta-mt5-zero-shot
+
+seed:
+  - 42
+  - 43
+  - 44
+
+bask:
+  jobname: "full_experiment_with_zero_shot"
+  walltime: '0-24:0:0'
+  gpu_number: 1
+  node_number: 1
+  hf_cache_dir: "/bask/projects/v/vjgo8416-spice/hf_cache"
+
+```
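For reference, a hypothetical invocation (the commit does not show where the experiment configs live, so the path below is illustrative):

```bash
# Generate per-seed jobscripts for an experiment; adjust the path to
# wherever the experiment .yaml actually lives in the repo.
python scripts/gen_jobscripts.py path/to/full_experiment_with_zero_shot.yaml
```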

scripts/gen_jobscripts.py (+82, new file)

```python
import os
from pathlib import Path

from jinja2 import Environment, FileSystemLoader
from jsonargparse import CLI

from arc_spice.utils import open_yaml_path

PROJECT_DIR = Path(__file__, "..", "..").resolve()


def main(experiment_config_path: str):
    """
    Generate slurm jobscripts for an experiment: one per pipeline component
    and one for the full pipeline, for each seed in the experiment config.

    Args:
        experiment_config_path: path to the experiment config `.yaml` file
    """
    experiment_name = experiment_config_path.split("/")[-1].split(".")[0]
    experiment_config = open_yaml_path(experiment_config_path)
    pipeline_conf_dir = (
        f"{PROJECT_DIR}/config/RTC_configs/{experiment_config['pipeline_config']}.yaml"
    )
    data_conf_dir = (
        f"{PROJECT_DIR}/config/data_configs/{experiment_config['data_config']}.yaml"
    )
    pipeline_config = open_yaml_path(pipeline_conf_dir)
    # Get jinja template
    environment = Environment(
        loader=FileSystemLoader(PROJECT_DIR / "src" / "arc_spice" / "config")
    )
    template = environment.get_template("jobscript_template.sh")
    # We don't want to overwrite results, hence exist_ok=False below

    for index, seed in enumerate(experiment_config["seed"]):
        os.makedirs(
            f"slurm_scripts/experiments/{experiment_name}/run_{index}", exist_ok=False
        )
        # one jobscript per pipeline component (top-level key of the pipeline config)
        for model in pipeline_config:
            # note: this aliases experiment_config["bask"] rather than copying it
            model_script_dict: dict = experiment_config["bask"]
            model_script_dict.update(
                {
                    "script_name": (
                        "scripts/single_component_inference.py "
                        f"{pipeline_conf_dir} {data_conf_dir} {seed}"
                        f" {experiment_name} {model}"
                    ),
                    "job_name": f"{experiment_name}_{model}",
                    "seed": seed,
                }
            )
            model_train_script = template.render(model_script_dict)

            with open(
                f"slurm_scripts/experiments/{experiment_name}/run_{index}/{model}.sh",
                "w",
            ) as f:
                f.write(model_train_script)

        # and one jobscript for the full pipeline
        pipeline_script_dict: dict = experiment_config["bask"]
        pipeline_script_dict.update(
            {
                "script_name": (
                    "scripts/pipeline_inference.py "
                    f"{pipeline_conf_dir} {data_conf_dir} {seed}"
                    f" {experiment_name}"
                ),
                "job_name": f"{experiment_name}_full_pipeline",
                "seed": seed,
            }
        )
        pipeline_train_script = template.render(pipeline_script_dict)

        with open(
            f"slurm_scripts/experiments/{experiment_name}/run_{index}/full_pipeline.sh",
            "w",
        ) as f:
            f.write(pipeline_train_script)


if __name__ == "__main__":
    CLI(main)
```
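`open_yaml_path` comes from `arc_spice.utils` and is not shown in this diff; a minimal sketch of what it is assumed to do (an assumption, not the repo's actual implementation):

```python
# Assumed behaviour of arc_spice.utils.open_yaml_path (not shown in this
# commit): read a yaml file from a path and return the parsed contents.
import yaml


def open_yaml_path(path: str) -> dict:
    with open(path) as f:
        return yaml.safe_load(f)
```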

scripts/pipeline_inference.py (+19 −9)

@@ -3,28 +3,43 @@
 
 from jsonargparse import CLI
 
-from arc_spice.data.multieurlex_utils import load_multieurlex_for_translation
+from arc_spice.data.multieurlex_utils import load_multieurlex_for_pipeline
 from arc_spice.eval.inference_utils import ResultsGetter, run_inference
-from arc_spice.utils import open_yaml_path
+from arc_spice.utils import open_yaml_path, seed_everything
 from arc_spice.variational_pipelines.RTC_variational_pipeline import (
     RTCVariationalPipeline,
 )
 
 OUTPUT_DIR = "outputs"
 
 
-def main(pipeline_config_pth: str, data_config_pth: str):
+def main(
+    pipeline_config_pth: str, data_config_pth: str, seed: int, experiment_name: str
+):
     """
     Run inference on a given pipeline with provided data config
 
     Args:
         pipeline_config_pth: path to pipeline config yaml file
         data_config_pth: path to data config yaml file
+        seed: seed for the inference pass
+        experiment_name: name of experiment for saving purposes
     """
+    # create save directory up front
+    data_name = data_config_pth.split("/")[-1].split(".")[0]
+    pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
+    save_loc = (
+        f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/"
+        f"{experiment_name}/seed_{seed}/"
+    )
+    # This directory needs to exist for all 4 experiments
+    os.makedirs(save_loc, exist_ok=True)
+    # seed experiment
+    seed_everything(seed=seed)
     # initialise pipeline
     data_config = open_yaml_path(data_config_pth)
     pipeline_config = open_yaml_path(pipeline_config_pth)
-    data_sets, meta_data = load_multieurlex_for_translation(**data_config)
+    data_sets, meta_data = load_multieurlex_for_pipeline(**data_config)
     test_loader = data_sets["test"]
     rtc_variational_pipeline = RTCVariationalPipeline(
         model_pars=pipeline_config, data_pars=meta_data
@@ -37,11 +52,6 @@ def main(pipeline_config_pth: str, data_config_pth: str):
         results_getter=results_getter,
     )
 
-    data_name = data_config_pth.split("/")[-1].split(".")[0]
-    pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
-    save_loc = f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}"
-    os.makedirs(save_loc, exist_ok=True)
-
     with open(f"{save_loc}/full_pipeline.json", "w") as save_file:
         json.dump(test_results, save_file)
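`seed_everything` is imported from `arc_spice.utils` but not shown in this diff; a plausible sketch, assuming it seeds the usual RNG sources (this is an assumption, not the repo's actual implementation):

```python
# Assumed sketch of arc_spice.utils.seed_everything: seed the common RNG
# sources so repeated runs with the same seed are reproducible.
import random

import numpy as np
import torch


def seed_everything(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # safe no-op when CUDA is unavailable
```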

scripts/single_component_inference.py (+23 −12)

@@ -16,9 +16,9 @@
 
 from jsonargparse import CLI
 
-from arc_spice.data.multieurlex_utils import load_multieurlex_for_translation
+from arc_spice.data.multieurlex_utils import load_multieurlex_for_pipeline
 from arc_spice.eval.inference_utils import ResultsGetter, run_inference
-from arc_spice.utils import open_yaml_path
+from arc_spice.utils import open_yaml_path, seed_everything
 from arc_spice.variational_pipelines.RTC_single_component_pipeline import (
     ClassificationVariationalPipeline,
     RecognitionVariationalPipeline,
@@ -28,19 +28,38 @@
 OUTPUT_DIR = "outputs"
 
 
-def main(pipeline_config_pth: str, data_config_pth: str, model_key: str):
+def main(
+    pipeline_config_pth: str,
+    data_config_pth: str,
+    seed: int,
+    experiment_name: str,
+    model_key: str,
+):
     """
     Run inference on a given pipeline component with provided data config and model key.
 
     Args:
         pipeline_config_pth: path to pipeline config yaml file
         data_config_pth: path to data config yaml file
+        seed: seed for the inference pass
+        experiment_name: name of experiment for saving purposes
         model_key: name of model on which to run inference
     """
+    # create save directory up front
+    data_name = data_config_pth.split("/")[-1].split(".")[0]
+    pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
+    save_loc = (
+        f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/"
+        f"{experiment_name}/seed_{seed}/"
+    )
+    # This directory needs to exist for all 4 experiments
+    os.makedirs(save_loc, exist_ok=True)
+    # seed experiment
+    seed_everything(seed=seed)
     # initialise pipeline
     data_config = open_yaml_path(data_config_pth)
     pipeline_config = open_yaml_path(pipeline_config_pth)
-    data_sets, meta_data = load_multieurlex_for_translation(**data_config)
+    data_sets, meta_data = load_multieurlex_for_pipeline(**data_config)
     test_loader = data_sets["test"]
     if model_key == "ocr":
         rtc_single_component_pipeline = RecognitionVariationalPipeline(
@@ -69,14 +88,6 @@ def main(pipeline_config_pth: str, data_config_pth: str, model_key: str):
         results_getter=results_getter,
     )
 
-    data_name = data_config_pth.split("/")[-1].split(".")[0]
-    pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
-    save_loc = (
-        f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/"
-        f"single_component"
-    )
-    os.makedirs(save_loc, exist_ok=True)
-
     with open(f"{save_loc}/{model_key}.json", "w") as save_file:
         json.dump(test_results, save_file)
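A hypothetical direct invocation with the new signature (the experiment name below is illustrative; the jobscripts generated by `gen_jobscripts.py` normally supply these arguments):

```bash
# Run the OCR component only, with seed 42; arguments are positional:
# pipeline config, data config, seed, experiment name, model key.
python scripts/single_component_inference.py \
    config/RTC_configs/roberta-mt5-zero-shot.yaml \
    config/data_configs/l1_fr_to_en.yaml \
    42 my_experiment ocr
```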

src/arc_spice/config/jobscript_template.sh (+28, new file; path inferred from the `FileSystemLoader` and `get_template` calls in `gen_jobscripts.py`)

```bash
#!/bin/bash
#SBATCH --account vjgo8416-spice
#SBATCH --qos turing
#SBATCH --job-name {{ job_name }}
#SBATCH --time {{ walltime }}
#SBATCH --nodes {{ node_number }}
#SBATCH --gpus {{ gpu_number }}
#SBATCH --output /bask/projects/v/vjgo8416-spice/ARC-SPICE/slurm_scripts/slurm_logs/{{ job_name }}-%j.out
#SBATCH --cpus-per-gpu 18


# Load required modules here
module purge
module load baskerville
module load bask-apps/live/live
module load Python/3.10.8-GCCcore-12.2.0


# change working directory
cd /bask/projects/v/vjgo8416-spice/ARC-SPICE/

source /bask/projects/v/vjgo8416-spice/ARC-SPICE/env/bin/activate

# change huggingface cache to be in project dir rather than user home
export HF_HOME="{{ hf_cache_dir }}"

# TODO: script uses relative path to project home so must be run from home, fix
python {{ script_name }}
```
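Note that `job_name` is set by `gen_jobscripts.py` (the config's `jobname` field is not referenced by the template). Assuming the experiment config file is named `full_experiment_with_zero_shot.yaml` (an assumption; the path is not shown in this extract), the header of the generated `full_pipeline.sh` would render roughly as:

```bash
#SBATCH --account vjgo8416-spice
#SBATCH --qos turing
#SBATCH --job-name full_experiment_with_zero_shot_full_pipeline
#SBATCH --time 0-24:0:0
#SBATCH --nodes 1
#SBATCH --gpus 1
```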

src/arc_spice/data/multieurlex_utils.py (+12 −1)

@@ -133,6 +133,7 @@ def load_multieurlex(
     level: int,
     languages: list[str],
     drop_empty: bool = True,
+    drop_length: int | None = None,
     split: str | None = None,
 ) -> tuple[datasets.DatasetDict, dict[str, Any]]:
     """
@@ -188,6 +189,11 @@ def load_multieurlex(
         lambda x: all(x is not None for x in x["text"].values())
     )
 
+    if drop_length:
+        dataset_dict = dataset_dict.filter(
+            lambda x: len(x["text"][languages[0]]) <= drop_length
+        )
+
     # return datasets and metadata
     return dataset_dict, metadata
 
@@ -197,11 +203,16 @@ def load_multieurlex_for_pipeline(
     level: int,
     lang_pair: dict[str, str],
     drop_empty: bool = True,
+    drop_length: int | None = None,
     load_ocr_data: bool = False,
 ) -> tuple[datasets.DatasetDict, dict[str, Any]]:
     langs = [lang_pair["source"], lang_pair["target"]]
     dataset_dict, meta_data = load_multieurlex(
-        data_dir=data_dir, level=level, languages=langs, drop_empty=drop_empty
+        data_dir=data_dir,
+        level=level,
+        languages=langs,
+        drop_empty=drop_empty,
+        drop_length=drop_length,
     )
     # instantiate the preprocessor
     preprocesser = TranslationPreProcesser(lang_pair)
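The new `drop_length` filter keeps only documents whose source-language text is at most `drop_length` characters long, which is what the `drop_length: 1000` entry in `l1_fr_to_en.yaml` enables. A self-contained toy illustration of the same filter expression (toy data, not MultiEURLEX):

```python
# Toy illustration of the drop_length filter: keep rows whose source-language
# text has at most drop_length characters (len() of a str counts characters).
from datasets import Dataset

toy = Dataset.from_list(
    [
        {"text": {"fr": "court", "en": "short"}},
        {"text": {"fr": "x" * 2000, "en": "long"}},
    ]
)
drop_length, languages = 1000, ["fr", "en"]
filtered = toy.filter(lambda x: len(x["text"][languages[0]]) <= drop_length)
assert filtered.num_rows == 1  # the 2000-character document is dropped
```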
