 from functools import partial
 from typing import Any
 
+import numpy as np
 import torch
 import transformers
+from torch.distributions import Categorical
 from torch.nn.functional import softmax
 from transformers import (
     AutoModelForSequenceClassification,
     AutoTokenizer,
+    ImageToTextPipeline,
     Pipeline,
     TranslationPipeline,
     pipeline,
@@ -149,9 +152,9 @@ def __init__(self, zero_shot: bool, n_variational_runs=5, translation_batch_size
         self.func_map = {
             "recognition": self.recognise,
             "translation": self.translate,
-            "classification": self.classify_topic_zero_shot
-            if zero_shot
-            else self.classify_topic,
+            "classification": (
+                self.classify_topic_zero_shot if zero_shot else self.classify_topic
+            ),
         }
         # the naive outputs of the pipeline stages calculated in self.clean_inference
         self.naive_outputs = {
@@ -266,21 +269,44 @@ def check_dropout(pipeline_map: transformers.Pipeline):
         set_dropout(model=pl.model, dropout_flag=False)
     logger.debug("-------------------------------------------------------\n\n")
 
-    def recognise(self, inp) -> dict[str, str]:
+    def recognise(self, inp) -> dict[str, str | list[dict[str, str | torch.Tensor]]]:
         """
-        Function to perform OCR
+        Function to perform OCR.
 
         Args:
-            inp: input
+            inp: input dict with key 'ocr_data', containing dict,
+                {
+                    'ocr_images': list[ocr images],
+                    'ocr_targets': list[ocr target words]
+                }
 
         Returns:
-            dictionary of outputs
+            dictionary of outputs:
+                {
+                    'full_output': [
+                        {
+                            'generated_text': generated text from ocr model (str),
+                            'target': original target text (str),
+                            'entropies': per-token entropies (torch.Tensor)
+                        }
+                    ],
+                    'output': pieced back together string (str)
+                }
         """
-        # Until the OCR data is available
-        # This will need the below comment:
-        # type: ignore[misc]
-        # TODO https://github.com/alan-turing-institute/ARC-SPICE/issues/14
-        return {"outputs": inp["source_text"]}
+        out = self.ocr(inp["ocr_data"]["ocr_images"])  # type: ignore[misc]
+        # each element of `out` is the dict built by CustomOCRPipeline.postprocess
+        text = " ".join([itm["generated_text"] for itm in out])
+        return {
+            "full_output": [
+                {
+                    "target": target,
+                    "generated_text": gen_text["generated_text"],
+                    "entropies": gen_text["entropies"],
+                }
+                for target, gen_text in zip(
+                    inp["ocr_data"]["ocr_targets"], out, strict=True
+                )
+            ],
+            "output": text,
+        }
 
     def translate(self, text: str) -> dict[str, torch.Tensor | str]:
         """
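A quick usage sketch of the new `recognise` path. The pipeline instance, file names, and images below are illustrative placeholders, not part of this change:

```python
from PIL import Image

# hypothetical input matching the documented 'ocr_data' structure
inp = {
    "ocr_data": {
        "ocr_images": [Image.open("word_0.png"), Image.open("word_1.png")],
        "ocr_targets": ["hello", "world"],
    }
}

result = rtc_pipeline.recognise(inp)  # rtc_pipeline: an instance of this class
print(result["output"])  # the generated words joined with spaces
for record in result["full_output"]:
    # one record per image: target, generated text, and per-token entropies
    print(record["target"], record["generated_text"], record["entropies"].shape)
```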
@@ -352,9 +378,7 @@ def classify_topic_zero_shot(self, text: str) -> dict[str, list[float] | dict]:
             descriptors["en"]
             for descriptors in self.dataset_meta_data["class_descriptors"]  # type: ignore[index]
         ]
-        forward = self.classifier(  # type: ignore[misc]
-            text, labels
-        )
+        forward = self.classifier(text, labels)  # type: ignore[misc]
         return collate_scores(
             [
                 {"label": label, "score": score}
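For context, `self.classifier` here is the stock transformers zero-shot-classification pipeline, which returns candidate labels paired with scores. A minimal sketch of the output shape the `collate_scores` comprehension relies on (the model checkpoint is illustrative):

```python
from transformers import pipeline

classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
forward = classifier("Shipment delayed at customs", ["logistics", "sports"])
# forward is a dict with labels sorted by descending score, e.g.
# {"sequence": "...", "labels": ["logistics", "sports"], "scores": [0.97, 0.03]}
rows = [
    {"label": label, "score": score}
    for label, score in zip(forward["labels"], forward["scores"])
]
```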
@@ -560,6 +584,28 @@ def get_classification_confidence(
         )
         return var_output
 
+    def get_ocr_confidence(self, var_output: dict) -> dict[str, float]:
+        """Generate the OCR confidence score.
+
+        Args:
+            var_output: variational run outputs
+
+        Returns:
+            dictionary with metrics
+        """
+        # Adapted for variational methods from: https://arxiv.org/pdf/2412.01221
+        stacked_entropies = torch.stack(
+            [   # torch.stack cannot consume nested lists: stack each run first
+                torch.stack([data["entropies"] for data in run_output])
+                for run_output in var_output["recognition"]["full_output"]
+            ],
+            dim=1,
+        )
+        # mean entropy across runs, images, and tokens
+        mean = torch.mean(stacked_entropies)
+        var_output["recognition"].update({"mean_entropy": mean})
+        return var_output
+
 
 # Translation pipeline with additional functionality to save logits from fwd pass
 class CustomTranslationPipeline(TranslationPipeline):
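To make the stacking in `get_ocr_confidence` concrete, here is a toy shape check with dummy tensors, under the same equal-length assumption as the inner `torch.stack` (no model required):

```python
import torch

n_runs, n_images, n_tokens = 3, 2, 4
# one (n_tokens,) entropy tensor per image, per variational run
runs = [[torch.rand(n_tokens) for _ in range(n_images)] for _ in range(n_runs)]

stacked = torch.stack(
    [torch.stack(run) for run in runs],  # each run -> (n_images, n_tokens)
    dim=1,
)
print(stacked.shape)  # torch.Size([2, 3, 4]): (images, runs, tokens)
print(torch.mean(stacked).item())  # scalar mean entropy, the confidence signal
```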
@@ -619,3 +665,38 @@ def _forward(self, model_inputs, **generate_kwargs):
             "scores": max_token_scores,
             "entropy": normalised_entropy,
         }
+
+
+class CustomOCRPipeline(ImageToTextPipeline):
+    """
+    Custom OCR pipeline that returns token-level entropies alongside the
+    generated text.
+    """
+
+    def postprocess(self, model_outputs: dict, **postprocess_params):
+        raw_out = copy.deepcopy(model_outputs)
+        processed = super().postprocess(
+            model_outputs["model_output"], **postprocess_params
+        )
+
+        return {
+            "generated_text": processed[0]["generated_text"],
+            # surface the entropies directly so recognise() can read them
+            "entropies": raw_out["entropies"],
+            "raw_output": raw_out,
+        }
+
+    def _forward(self, model_inputs, **generate_kwargs):
+        # some processors hand over a list of None input_ids; collapse it so
+        # generate() receives input_ids=None rather than [None, ...]
+        if (
+            "input_ids" in model_inputs
+            and isinstance(model_inputs["input_ids"], list)
+            and all(x is None for x in model_inputs["input_ids"])
+        ):
+            model_inputs["input_ids"] = None
+
+        inputs = model_inputs.pop(self.model.main_input_name)
+        out = self.model.generate(
+            inputs,
+            **model_inputs,
+            **generate_kwargs,
+            output_logits=True,
+            return_dict_in_generate=True,
+        )
+
+        # logits: (batch, generated_tokens, vocab); normalise each token's
+        # entropy by log(vocab_size) so the scores lie in [0, 1]
+        logits = torch.stack(out.logits, dim=1)
+        entropy = Categorical(logits=logits).entropy() / np.log(logits.size(-1))
+        return {"model_output": out.sequences, "entropies": entropy}
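A hedged end-to-end sketch of wiring the custom class into the pipeline factory; the checkpoint name is an example (any image-to-text model should work), and `pipeline_class` is the standard transformers hook for substituting a custom pipeline:

```python
import torch
from transformers import pipeline

ocr = pipeline(
    "image-to-text",
    model="microsoft/trocr-base-printed",  # example checkpoint
    pipeline_class=CustomOCRPipeline,
)

result = ocr("word.png")
print(result["generated_text"])
# entropies are normalised by log(vocab_size), so every value lies in [0, 1]
assert torch.all(result["entropies"] >= 0)
assert torch.all(result["entropies"] <= 1)
```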