Commit 91b574e

Merge pull request #34 from alan-turing-institute/33-fix-bugs-for-inference-on-baskerville

33 fix bugs for inference on baskerville

2 parents 3f6c10d + 4b950ff, commit 91b574e

24 files changed: +174 -91 lines

config/RTC_configs/roberta-mt5-trained.yaml (+1 -1)

@@ -1,6 +1,6 @@
 ocr:
   specific_task: "image-to-text"
-  model: "microsoft/trocr-base-handwritten"
+  model: "microsoft/trocr-small-printed"
 
 translator:
   specific_task: "translation_fr_to_en"
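
The checkpoint swap from handwritten to printed TrOCR matches how the OCR inputs are produced: as the multieurlex_utils diff below shows, they are synthetic printed-text images rendered with trdg's GeneratorFromStrings. A minimal context sketch of that generator (the sample string is made up):

    from trdg.generators import GeneratorFromStrings

    # trdg renders printed-style text images, so a printed TrOCR
    # checkpoint is presumably the better (and smaller) fit here.
    generator = GeneratorFromStrings(["exemple de texte"], count=1)
    image, label = next(iter(generator))  # (PIL.Image, "exemple de texte")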

config/RTC_configs/roberta-mt5-zero-shot.yaml (+1 -2)

@@ -1,6 +1,5 @@
 ocr:
-  specific_task: "image-to-text"
-  model: "microsoft/trocr-base-handwritten"
+  model: "microsoft/trocr-small-printed"
 
 translator:
   specific_task: "translation_fr_to_en"
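
Dropping specific_task here leaves only the model id in the zero-shot OCR block. For context (a general transformers behaviour, not something this diff itself confirms the repo relies on), transformers.pipeline can resolve the task from the checkpoint's pipeline tag on the Hugging Face Hub when no task is given:

    from transformers import pipeline

    # With no task argument, transformers looks up the checkpoint's
    # pipeline tag on the Hub ("image-to-text" for TrOCR checkpoints).
    ocr = pipeline(model="microsoft/trocr-small-printed")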

config/data_configs/l1_fr_to_en.yaml (+2 -0)

@@ -7,3 +7,5 @@ lang_pair:
   target: "en"
 
 drop_length: 1000
+
+load_ocr_data: True
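
The new load_ocr_data flag is read into the data-config dict and splatted into the dataset loader (see scripts/single_component_inference.py below), so the YAML controls whether OCR image generation runs at all. A sketch of that flow, with the dict reconstructed from this file (keys not shown in the diff are omitted, so this is not a complete call):

    from arc_spice.data.multieurlex_utils import load_multieurlex_for_pipeline

    data_config = {
        "lang_pair": {"source": "fr", "target": "en"},  # inferred from l1_fr_to_en.yaml
        "drop_length": 1000,
        "load_ocr_data": True,
    }
    data_sets, meta_data = load_multieurlex_for_pipeline(**data_config)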
(new file, +15; filename not shown)

@@ -0,0 +1,15 @@
+data_config: l1_fr_to_en
+
+pipeline_config: roberta-mt5-zero-shot
+
+seed:
+  - 42
+  - 43
+  - 44
+
+bask:
+  jobname: "full_experiment_with_zero_shot"
+  walltime: '0-24:0:0'
+  gpu_number: 1
+  node_number: 1
+  hf_cache_dir: "/bask/projects/v/vjgo8416-spice/hf_cache"
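
A hedged sketch of how an experiment config like this might be consumed (the file path and submission loop are hypothetical; only the keys come from the diff above): one job per seed, with the bask block supplying the Baskerville/Slurm resources.

    import yaml

    # Hypothetical path: the new file's name is not shown above.
    with open("experiment_config.yaml") as f:
        experiment = yaml.safe_load(f)

    for seed in experiment["seed"]:  # 42, 43, 44
        bask = experiment["bask"]
        print(f"submit {bask['jobname']}: seed={seed}, walltime={bask['walltime']}")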

scripts/single_component_inference.py (+6 -2)

@@ -46,15 +46,19 @@ def main(
     # initialise pipeline
     data_config = open_yaml_path(data_config_pth)
     pipeline_config = open_yaml_path(pipeline_config_pth)
+
+    if model_key != "ocr":
+        data_config["load_ocr_data"] = False
+
     data_sets, meta_data = load_multieurlex_for_pipeline(**data_config)
     test_loader = data_sets["test"]
     if model_key == "ocr":
         rtc_single_component_pipeline = RecognitionVariationalPipeline(
-            model_pars=pipeline_config, data_pars=meta_data
+            model_pars=pipeline_config
         )
     elif model_key == "translator":
         rtc_single_component_pipeline = TranslationVariationalPipeline(
-            model_pars=pipeline_config, data_pars=meta_data
+            model_pars=pipeline_config
         )
     elif model_key == "classifier":
         rtc_single_component_pipeline = ClassificationVariationalPipeline(
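
For context, a hypothetical call into this entry point (argument names are taken from the hunk; the script's actual CLI wrapper and import path are not shown here):

    # Illustrative only; assumes main is importable from the script.
    from single_component_inference import main

    main(
        data_config_pth="config/data_configs/l1_fr_to_en.yaml",
        pipeline_config_pth="config/RTC_configs/roberta-mt5-zero-shot.yaml",
        model_key="translator",  # any non-"ocr" key now skips OCR image generation
    )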

src/arc_spice/data/multieurlex_utils.py (+13 -7)

@@ -67,14 +67,17 @@ def extract_articles(
 
 def _make_ocr_data(text: str) -> list[tuple[Image.Image, str]]:
     text_split = text.split()
-    text_split = [text for text in text_split if text not in ("", " ", None)]
+    text_split = [text for text in text_split if text not in ("", " ")]
     generator = GeneratorFromStrings(text_split, count=len(text_split))
     return list(generator)
 
 
-def make_ocr_data(item: LazyRow) -> dict[str, tuple[Image.Image] | tuple[str]]:
-    images, targets = zip(*_make_ocr_data(item["source_text"]), strict=True)
-    return {"ocr_images": images, "ocr_targets": targets}
+def make_ocr_data(item: LazyRow) -> dict:
+    try:
+        images, targets = zip(*_make_ocr_data(item["source_text"]), strict=True)
+    except ValueError:
+        return {"ocr_data": {"ocr_images": None, "ocr_targets": None}}
+    return {"ocr_data": {"ocr_images": images, "ocr_targets": targets}}
 
 
 class TranslationPreProcesser:

@@ -229,11 +232,14 @@ def load_multieurlex_for_pipeline(
             make_ocr_data,
             features=datasets.Features(
                 {
-                    "ocr_images": datasets.Sequence(datasets.Image(decode=True)),
-                    "ocr_targets": datasets.Sequence(datasets.Value("string")),
+                    "ocr_data": {
+                        "ocr_images": datasets.Sequence(
+                            datasets.Image(decode=True)
+                        ),
+                        "ocr_targets": datasets.Sequence(datasets.Value("string")),
+                    },
                     **feats,
                 }
             ),
         )
-
     return dataset_dict, meta_data
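
The new try/except covers the empty-document edge case: when _make_ocr_data returns no (image, target) pairs, unpacking the transposed zip raises ValueError, which previously crashed the dataset map. A standalone illustration:

    pairs = []  # no OCR pairs generated for this row
    try:
        images, targets = zip(*pairs, strict=True)
    except ValueError as err:
        print(err)  # not enough values to unpack (expected 2, got 0)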

src/arc_spice/eval/inference_utils.py (+52 -12)

@@ -17,7 +17,15 @@
     RTCVariationalPipeline,
 )
 
-RecognitionResults = namedtuple("RecognitionResults", ["confidence", "accuracy"])
+RecognitionResults = namedtuple(
+    "RecognitionResults",
+    [
+        "mean_entropy",
+        "character_error_rate",
+        "full_output",
+        "max_scores",
+    ],
+)
 
 TranslationResults = namedtuple(
     "TranslationResults",

@@ -71,14 +79,19 @@ def get_results(
 
     def recognition_results(
         self,
-        clean_output: dict[str, str | list[dict[str, str | torch.Tensor]]],
-        var_output: dict[str, dict],
+        clean_output: dict,
+        var_output: dict,
         **kwargs,
     ):
         # ### RECOGNITION ###
-        charerror = ocr_error(clean_output)
+        charerror = ocr_error(clean_output["recognition"])
         confidence = var_output["recognition"]["mean_entropy"]
-        return RecognitionResults(confidence=confidence, accuracy=charerror)
+        return RecognitionResults(
+            mean_entropy=confidence,
+            character_error_rate=charerror,
+            max_scores=clean_output["recognition"]["outputs"]["max_scores"],
+            full_output=clean_output["recognition"]["full_output"],
+        )
 
     def translation_results(
         self,

@@ -150,13 +163,40 @@ def run_inference(
     pipeline: RTCVariationalPipeline | RTCSingleComponentPipeline,
     results_getter: ResultsGetter,
 ):
+    type_errors = []
+    oom_errors = []
     results = []
     for _, inp in enumerate(tqdm(dataloader)):
-        clean_out, var_out = pipeline.variational_inference(inp)
-        row_results_dict = results_getter.get_results(
-            clean_output=clean_out,
-            var_output=var_out,
-            test_row=inp,
-        )
-        results.append({inp["celex_id"]: row_results_dict})
+        # TEMPORARY FIX
+        try:
+            clean_out, var_out = pipeline.variational_inference(inp)
+            row_results_dict = results_getter.get_results(
+                clean_output=clean_out,
+                var_output=var_out,
+                test_row=inp,
+            )
+            results.append({inp["celex_id"]: row_results_dict})
+        # TEMPORARY FIX ->
+        except TypeError:
+            type_errors.append(inp["celex_id"])
+            continue
+
+        except torch.cuda.OutOfMemoryError:
+            oom_errors.append(inp["celex_id"])
+            continue
+
+        except torch.OutOfMemoryError:
+            oom_errors.append(inp["celex_id"])
+            continue
+
+    print("Skipped following CELEX IDs due to TypeError:")
+    print(
+        '"TypeError: Incorrect format used for image. Should be an url linking to'
+        ' an image, a base64 string, a local path, or a PIL image."'
+    )
+    print(type_errors)
+
+    print("Skipped following CELEX IDs due to torch.cuda.OutOfMemoryError:")
+    print(oom_errors)
+
     return results
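
On the double OOM handler: recent PyTorch releases expose torch.OutOfMemoryError, with torch.cuda.OutOfMemoryError kept as an alias, while older releases define only the cuda-scoped name, so catching both reads like a version guard (a hedged inference; the diff does not say). A version-robust alternative sketch:

    import torch

    # Collect whichever OOM exception classes this torch build defines.
    OOM_ERRORS = tuple(
        exc
        for exc in (
            getattr(torch, "OutOfMemoryError", None),
            getattr(torch.cuda, "OutOfMemoryError", None),
        )
        if exc is not None
    )

    try:
        pass  # pipeline.variational_inference(inp) would go here
    except OOM_ERRORS:
        pass  # skip the row and record its ID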

src/arc_spice/eval/ocr_error.py (+5 -4)

@@ -6,6 +6,8 @@
 
 from torchmetrics.text import CharErrorRate
 
+cer = CharErrorRate()
+
 
 def ocr_error(ocr_output: dict[Any, Any]) -> float:
     """

@@ -30,7 +32,6 @@ def ocr_error(ocr_output: dict[Any, Any]) -> float:
     Returns:
         Character error rate across entire output of OCR (float)
     """
-    preds = [itm["generated_text"].lower() for itm in ocr_output["full_output"]]
-    targs = [itm["target"].lower() for itm in ocr_output["full_output"]]
-    cer = CharErrorRate()
-    return cer(preds, targs).item()
+    preds = [itm["generated_text"].lower() for itm in ocr_output["outputs"]]
+    targs = [itm["target"].lower() for itm in ocr_output["outputs"]]
+    return cer(preds, targs).detach().item()
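
A quick usage check for torchmetrics' CharErrorRate, now hoisted to module level (the strings are made up; CER is edit distance divided by target length):

    from torchmetrics.text import CharErrorRate

    cer = CharErrorRate()
    score = cer(["hello wrld"], ["hello world"])  # 1 edit / 11 target chars
    print(score.detach().item())  # ~0.0909

Note that torchmetrics metrics are stateful: calling the shared instance returns the value for the current batch but also accumulates internal state, which only matters if compute() is later called on it.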

src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py (+7 -11)

@@ -89,26 +89,23 @@ def __init__(
         model_pars: dict[str, dict[str, str]],
         n_variational_runs=5,
         ocr_batch_size=64,
-        **kwargs,
     ):
         self.set_device()
+        super().__init__(
+            step_name="recognition",
+            input_key="ocr_data",
+            forward_function=self.recognise,
+            confidence_function=self.get_ocr_confidence,
+            n_variational_runs=n_variational_runs,
+        )
         self.ocr: transformers.Pipeline = pipeline(
             model=model_pars["ocr"]["model"],
             device=self.device,
             pipeline_class=CustomOCRPipeline,
             max_new_tokens=20,
             batch_size=ocr_batch_size,
-            **kwargs,
         )
         self.model = self.ocr.model
-        super().__init__(
-            step_name="recognition",
-            input_key="ocr_data",
-            forward_function=self.recognise,
-            confidence_function=self.get_ocr_confidence,
-            n_variational_runs=n_variational_runs,
-            **kwargs,
-        )
         self._init_pipeline_map()
 
 
@@ -118,7 +115,6 @@ def __init__(
         model_pars: dict[str, dict[str, str]],
         n_variational_runs=5,
         translation_batch_size=4,
-        **kwargs,
     ):
         self.set_device()
         # need to initialise the NLI models in this case
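
One plausible motivation for dropping **kwargs (an inference from the diff, not stated in it): the constructors forwarded the same keyword arguments to both super().__init__ and transformers.pipeline, so a stray key accepted by neither, such as the data_pars argument removed from single_component_inference.py above, crashed one callee or the other. A toy reproduction with illustrative names:

    class Base:
        def __init__(self, n_variational_runs=5):
            self.n_variational_runs = n_variational_runs

    def build_pipeline(model=None):
        return model

    class Component(Base):
        def __init__(self, n_variational_runs=5, **kwargs):
            super().__init__(n_variational_runs=n_variational_runs, **kwargs)
            build_pipeline(**kwargs)

    # TypeError: __init__() got an unexpected keyword argument 'data_pars'
    Component(data_pars={"labels": []})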

src/arc_spice/variational_pipelines/RTC_variational_pipeline.py (+7 -6)

@@ -79,19 +79,19 @@ def clean_inference(self, x: torch.Tensor) -> dict[str, dict]:
 
         # run the functions
         # UNTIL THE OCR DATA IS AVAILABLE
-        clean_output["recognition"] = self.recognise(x)
+        clean_output["recognition"] = self.recognise(x["ocr_data"])
 
         clean_output["translation"] = self.translate(
-            clean_output["recognition"]["outputs"]
+            clean_output["recognition"]["full_output"]
         )
         # we now need to pass the input correct to the correct forward method
         if self.zero_shot:
             clean_output["classification"] = self.classify_topic_zero_shot(
-                clean_output["translation"]["outputs"][0]
+                clean_output["translation"]["full_output"]
             )
         else:
             clean_output["classification"] = self.classify_topic(
-                clean_output["translation"]["outputs"][0]
+                clean_output["translation"]["full_output"]
             )
         return clean_output
 

@@ -109,8 +109,8 @@ def variational_inference(self, x: torch.Tensor) -> tuple[dict, dict]:
         }
         # define the input map for brevity in forward pass
         input_map = {
-            "recognition": x,
-            "translation": clean_output["recognition"]["outputs"],
+            "recognition": x["ocr_data"],
+            "translation": clean_output["recognition"]["full_output"],
             "classification": clean_output["translation"]["full_output"],
         }
 

@@ -130,6 +130,7 @@ def variational_inference(self, x: torch.Tensor) -> tuple[dict, dict]:
 
         # run metric helper functions
         var_output = self.stack_variational_outputs(var_output)
+        var_output = self.get_ocr_confidence(var_output)
         var_output = self.translation_semantic_density(
             clean_output=clean_output, var_output=var_output
         )
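
After this change each stage consumes the previous stage's "full_output", and recognition reads x["ocr_data"] directly. A toy end-to-end sketch of that contract (stub functions, not the repo's API):

    def recognise(ocr_data):
        return {"full_output": "texte reconnu", "outputs": []}

    def translate(text):
        return {"full_output": "recognised text", "outputs": []}

    def classify(text):
        return {"scores": [0.9, 0.1]}

    x = {"ocr_data": {"ocr_images": [], "ocr_targets": []}}
    recognition = recognise(x["ocr_data"])
    translation = translate(recognition["full_output"])
    classification = classify(translation["full_output"])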
