
Commit 788dccc

Merge pull request #32 from alan-turing-institute/14-ocr-reverted
14 ocr reverted
2 parents af48055 + 510d863 commit 788dccc

5 files changed: +156 −24 lines

src/arc_spice/eval/inference_utils.py

+10-3
@@ -8,6 +8,7 @@
 from tqdm import tqdm

 from arc_spice.data.multieurlex_utils import MultiHot
+from arc_spice.eval.ocr_error import ocr_error
 from arc_spice.eval.translation_error import conditional_probability, get_comet_model
 from arc_spice.variational_pipelines.RTC_single_component_pipeline import (
     RTCSingleComponentPipeline,
@@ -68,10 +69,16 @@ def get_results(
         )._asdict()
         return results_dict

-    def recognition_results(self, *args, **kwargs):
+    def recognition_results(
+        self,
+        clean_output: dict[str, str | list[dict[str, str | torch.Tensor]]],
+        var_output: dict[str, dict],
+        **kwargs,
+    ):
         # ### RECOGNITION ###
-        # TODO: add this into results_getter : issue #14
-        return RecognitionResults(confidence=None, accuracy=None)
+        charerror = ocr_error(clean_output)
+        confidence = var_output["recognition"]["mean_entropy"]
+        return RecognitionResults(confidence=confidence, accuracy=charerror)

     def translation_results(
         self,
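
For orientation, a minimal sketch of the inputs the new recognition_results expects; the dict layout follows the docstrings elsewhere in this commit, but the values are invented toy data:

    import torch

    # hypothetical clean (non-variational) output of the recognition step
    clean_output = {
        "full_output": [
            {
                "generated_text": "hello",
                "target": "hello",
                "entropies": torch.tensor([0.10, 0.05, 0.20, 0.15, 0.08]),
            }
        ],
        "output": "hello",
    }
    # hypothetical variational output, after get_ocr_confidence has run
    var_output = {"recognition": {"mean_entropy": torch.tensor(0.116)}}

    # recognition_results(clean_output, var_output) then returns
    # RecognitionResults(confidence=tensor(0.1160), accuracy=0.0)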

src/arc_spice/eval/ocr_error.py

+36
@@ -0,0 +1,36 @@
+"""
+OCR error computation for eval.
+"""
+
+from typing import Any
+
+from torchmetrics.text import CharErrorRate
+
+
+def ocr_error(ocr_output: dict[Any, Any]) -> float:
+    """
+    Compute the character error rate for the predicted OCR text.
+
+    NB: - this puts all strings into lower case for comparisons.
+        - ideal error rate is 0, worst case is 1.
+
+    Args:
+        ocr_output: output from the OCR model, with structure,
+        {
+            'full_output': [
+                {
+                    'generated_text': generated text from the OCR model (str),
+                    'target': target text (str),
+                    'entropies': entropies for UQ (torch.Tensor),
+                }
+            ],
+            'output': pieced back together full string (str)
+        }
+
+    Returns:
+        Character error rate across entire output of OCR (float)
+    """
+    preds = [itm["generated_text"].lower() for itm in ocr_output["full_output"]]
+    targs = [itm["target"].lower() for itm in ocr_output["full_output"]]
+    cer = CharErrorRate()
+    return cer(preds, targs).item()
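
As a quick sanity check on ocr_error, a toy call (values invented; torchmetrics' CharErrorRate is edit distance divided by target character count):

    import torch
    from arc_spice.eval.ocr_error import ocr_error

    toy = {
        "full_output": [
            {"generated_text": "Hallo", "target": "hello", "entropies": torch.zeros(5)},
        ],
        "output": "Hallo",
    }
    # lower-cased comparison: one substitution over 5 characters
    print(ocr_error(toy))  # 0.2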

src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py

+9-5
@@ -8,6 +8,7 @@
     RTCVariationalPipelineBase,
 )
 from arc_spice.variational_pipelines.utils import (
+    CustomOCRPipeline,
     CustomTranslationPipeline,
     dropout_off,
     dropout_on,
@@ -87,21 +88,24 @@ def __init__(
         self,
         model_pars: dict[str, dict[str, str]],
         n_variational_runs=5,
+        ocr_batch_size=64,
         **kwargs,
     ):
         self.set_device()
         self.ocr: transformers.Pipeline = pipeline(
-            task=model_pars["ocr"]["specific_task"],
             model=model_pars["ocr"]["model"],
             device=self.device,
+            pipeline_class=CustomOCRPipeline,
+            max_new_tokens=20,
+            batch_size=ocr_batch_size,
             **kwargs,
         )
         self.model = self.ocr.model
         super().__init__(
             step_name="recognition",
             input_key="ocr_data",
             forward_function=self.recognise,
-            confidence_function=self.recognise,  # THIS WILL NEED UPDATING : #issue 14
+            confidence_function=self.get_ocr_confidence,
             n_variational_runs=n_variational_runs,
             **kwargs,
         )
@@ -160,9 +164,9 @@ def __init__(
         super().__init__(
             step_name="classification",
             input_key="target_text",
-            forward_function=self.classify_topic_zero_shot
-            if zero_shot
-            else self.classify_topic,
+            forward_function=(
+                self.classify_topic_zero_shot if zero_shot else self.classify_topic
+            ),
             confidence_function=self.get_classification_confidence,
             n_variational_runs=n_variational_runs,
             **kwargs,
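
For reference, a hypothetical model_pars dict that would satisfy these constructors; the model names are illustrative placeholders, not taken from the repo (note the "ocr" entry no longer needs "specific_task" after this change):

    model_pars = {
        "ocr": {"model": "microsoft/trocr-base-handwritten"},
        "translator": {
            "specific_task": "translation_fr_to_en",
            "model": "Helsinki-NLP/opus-mt-fr-en",
        },
        "classifier": {
            "specific_task": "zero-shot-classification",
            "model": "facebook/bart-large-mnli",
        },
    }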

src/arc_spice/variational_pipelines/RTC_variational_pipeline.py

+5-1
@@ -4,6 +4,7 @@
 from transformers import pipeline

 from arc_spice.variational_pipelines.utils import (
+    CustomOCRPipeline,
     CustomTranslationPipeline,
     RTCVariationalPipelineBase,
     dropout_off,
@@ -38,6 +39,7 @@ def __init__(
         data_pars: dict[str, Any],
         n_variational_runs=5,
         translation_batch_size=16,
+        ocr_batch_size=64,
     ) -> None:
         # are we doing zero-shot-classification?
         if model_pars["classifier"]["specific_task"] == "zero-shot-classification":
@@ -47,9 +49,11 @@ def __init__(
         super().__init__(self.zero_shot, n_variational_runs, translation_batch_size)
         # defining the pipeline objects
         self.ocr = pipeline(
-            task=model_pars["ocr"]["specific_task"],
             model=model_pars["ocr"]["model"],
             device=self.device,
+            pipeline_class=CustomOCRPipeline,
+            max_new_tokens=20,
+            batch_size=ocr_batch_size,
         )
         self.translator = pipeline(
             task=model_pars["translator"]["specific_task"],
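
And a hedged construction sketch showing how the new ocr_batch_size argument threads through; the class name RTCVariationalPipeline is an assumption (only the file name is visible in this diff):

    # hypothetical usage; class name and config values are assumptions
    rtc_pipeline = RTCVariationalPipeline(
        model_pars=model_pars,  # e.g. the dict sketched above
        data_pars=data_pars,    # dataset metadata; structure not shown in this diff
        n_variational_runs=5,
        translation_batch_size=16,
        ocr_batch_size=64,      # new in this commit; forwarded to the OCR pipeline
    )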

src/arc_spice/variational_pipelines/utils.py

+96-15
@@ -5,12 +5,15 @@
 from functools import partial
 from typing import Any

+import numpy as np
 import torch
 import transformers
+from torch.distributions import Categorical
 from torch.nn.functional import softmax
 from transformers import (
     AutoModelForSequenceClassification,
     AutoTokenizer,
+    ImageToTextPipeline,
     Pipeline,
     TranslationPipeline,
     pipeline,
@@ -149,9 +152,9 @@ def __init__(self, zero_shot: bool, n_variational_runs=5, translation_batch_size
         self.func_map = {
             "recognition": self.recognise,
             "translation": self.translate,
-            "classification": self.classify_topic_zero_shot
-            if zero_shot
-            else self.classify_topic,
+            "classification": (
+                self.classify_topic_zero_shot if zero_shot else self.classify_topic
+            ),
         }
         # the naive outputs of the pipeline stages calculated in self.clean_inference
         self.naive_outputs = {
@@ -266,21 +269,44 @@ def check_dropout(pipeline_map: transformers.Pipeline):
             set_dropout(model=pl.model, dropout_flag=False)
         logger.debug("-------------------------------------------------------\n\n")

-    def recognise(self, inp) -> dict[str, str]:
+    def recognise(self, inp) -> dict[str, str | list[dict[str, str | torch.Tensor]]]:
         """
-        Function to perform OCR
+        Function to perform OCR.

         Args:
-            inp: input
+            inp: input dict with key 'ocr_data', containing dict,
+                {
+                    'ocr_images': list[ocr images],
+                    'ocr_targets': list[ocr target words]
+                }

         Returns:
-            dictionary of outputs
+            dictionary of outputs:
+            {
+                'full_output': [
+                    {
+                        'generated_text': generated text from ocr model (str),
+                        'target': original target text (str)
+                    }
+                ],
+                'output': pieced back together string (str)
+            }
         """
-        # Until the OCR data is available
-        # This will need the below comment:
-        # type: ignore[misc]
-        # TODO https://github.com/alan-turing-institute/ARC-SPICE/issues/14
-        return {"outputs": inp["source_text"]}
+        out = self.ocr(inp["ocr_data"]["ocr_images"])  # type: ignore[misc]
+        text = " ".join([itm[0]["generated_text"] for itm in out])
+        return {
+            "full_output": [
+                {
+                    "target": target,
+                    "generated_text": gen_text["generated_text"],
+                    "entropies": gen_text["entropies"],
+                }
+                for target, gen_text in zip(
+                    inp["ocr_data"]["ocr_targets"], out, strict=True
+                )
+            ],
+            "output": text,
+        }

     def translate(self, text: str) -> dict[str, torch.Tensor | str]:
         """
352378
descriptors["en"]
353379
for descriptors in self.dataset_meta_data["class_descriptors"] # type: ignore[index]
354380
]
355-
forward = self.classifier( # type: ignore[misc]
356-
text, labels
357-
)
381+
forward = self.classifier(text, labels) # type: ignore[misc]
358382
return collate_scores(
359383
[
360384
{"label": label, "score": score}
@@ -560,6 +584,28 @@ def get_classification_confidence(
         )
         return var_output

+    def get_ocr_confidence(self, var_output: dict) -> dict[str, float]:
+        """Generate the OCR confidence score.
+
+        Args:
+            var_output: variational run outputs
+
+        Returns:
+            dictionary with metrics
+        """
+        # Adapted for variational methods from: https://arxiv.org/pdf/2412.01221
+        stacked_entropies = torch.stack(
+            [
+                torch.stack([data["entropies"] for data in output["full_output"]])
+                for output in var_output["recognition"]
+            ],
+            dim=1,
+        )
+        # mean entropy across documents, runs, and tokens
+        mean = torch.mean(stacked_entropies)
+        var_output["recognition"].update({"mean_entropy": mean})
+        return var_output
+

 # Translation pipeline with additional functionality to save logits from fwd pass
 class CustomTranslationPipeline(TranslationPipeline):
@@ -619,3 +665,38 @@ def _forward(self, model_inputs, **generate_kwargs):
             "scores": max_token_scores,
             "entropy": normalised_entropy,
         }
+
+
+class CustomOCRPipeline(ImageToTextPipeline):
+    """
+    Custom OCR pipeline to return logits with the generated text.
+    """
+
+    def postprocess(self, model_outputs: dict, **postprocess_params):
+        raw_out = copy.deepcopy(model_outputs)
+        processed = super().postprocess(
+            model_outputs["model_output"], **postprocess_params
+        )
+
+        return {"generated_text": processed[0]["generated_text"], "raw_output": raw_out}
+
+    def _forward(self, model_inputs, **generate_kwargs):
+        if (
+            "input_ids" in model_inputs
+            and isinstance(model_inputs["input_ids"], list)
+            and all(x is None for x in model_inputs["input_ids"])
+        ):
+            model_inputs["input_ids"] = None
+
+        inputs = model_inputs.pop(self.model.main_input_name)
+        out = self.model.generate(
+            inputs,
+            **model_inputs,
+            **generate_kwargs,
+            output_logits=True,
+            return_dict_in_generate=True,
+        )
+
+        logits = torch.stack(out.logits, dim=1)
+        entropy = Categorical(logits=logits).entropy() / np.log(logits[0].size()[1])
+        return {"model_output": out.sequences, "entropies": entropy}
