|
16 | 16 |
|
17 | 17 | from jsonargparse import CLI
|
18 | 18 |
|
19 |
| -from arc_spice.data.multieurlex_utils import load_multieurlex_for_translation |
| 19 | +from arc_spice.data.multieurlex_utils import load_multieurlex_for_pipeline |
20 | 20 | from arc_spice.eval.inference_utils import ResultsGetter, run_inference
|
21 |
| -from arc_spice.utils import open_yaml_path |
| 21 | +from arc_spice.utils import open_yaml_path, seed_everything |
22 | 22 | from arc_spice.variational_pipelines.RTC_single_component_pipeline import (
|
23 | 23 | ClassificationVariationalPipeline,
|
24 | 24 | RecognitionVariationalPipeline,
|
|
28 | 28 | OUTPUT_DIR = "outputs"
|
29 | 29 |
|
30 | 30 |
|
31 |
| -def main(pipeline_config_pth: str, data_config_pth: str, model_key: str): |
| 31 | +def main( |
| 32 | + pipeline_config_pth: str, |
| 33 | + data_config_pth: str, |
| 34 | + seed: int, |
| 35 | + experiment_name: str, |
| 36 | + model_key: str, |
| 37 | +): |
32 | 38 | """
|
33 | 39 | Run inference on a given pipeline component with provided data config and model key.
|
34 | 40 |
|
35 | 41 | Args:
|
36 | 42 | pipeline_config_pth: path to pipeline config yaml file
|
37 | 43 | data_config_pth: path to data config yaml file
|
| 44 | +        seed: seed for the inference pass |
| 45 | + experiment_name: name of experiment for saving purposes |
38 | 46 | model_key: name of model on which to run inference
|
39 | 47 | """
|
| 48 | + # create save directory -> fail if already exists |
| 49 | + data_name = data_config_pth.split("/")[-1].split(".")[0] |
| 50 | + pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0] |
| 51 | + save_loc = ( |
| 52 | + f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/" |
| 53 | + f"{experiment_name}/seed_{seed}/" |
| 54 | + ) |
| 55 | + # This directory needs to exist for all 4 experiments |
| 56 | + os.makedirs(save_loc, exist_ok=True) |
| 57 | + # seed experiment |
| 58 | + seed_everything(seed=seed) |
40 | 59 | # initialise pipeline
|
41 | 60 | data_config = open_yaml_path(data_config_pth)
|
42 | 61 | pipeline_config = open_yaml_path(pipeline_config_pth)
|
43 |
| - data_sets, meta_data = load_multieurlex_for_translation(**data_config) |
| 62 | + data_sets, meta_data = load_multieurlex_for_pipeline(**data_config) |
44 | 63 | test_loader = data_sets["test"]
|
45 | 64 | if model_key == "ocr":
|
46 | 65 | rtc_single_component_pipeline = RecognitionVariationalPipeline(
|
@@ -69,14 +88,6 @@ def main(pipeline_config_pth: str, data_config_pth: str, model_key: str):
|
69 | 88 | results_getter=results_getter,
|
70 | 89 | )
|
71 | 90 |
|
72 |
| - data_name = data_config_pth.split("/")[-1].split(".")[0] |
73 |
| - pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0] |
74 |
| - save_loc = ( |
75 |
| - f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/" |
76 |
| - f"single_component" |
77 |
| - ) |
78 |
| - os.makedirs(save_loc, exist_ok=True) |
79 |
| - |
80 | 91 | with open(f"{save_loc}/{model_key}.json", "w") as save_file:
|
81 | 92 | json.dump(test_results, save_file)
|
82 | 93 |
|
|
0 commit comments