|
17 | 17 | RTCVariationalPipeline,
|
18 | 18 | )
|
19 | 19 |
|
# Result bundle for the OCR recognition evaluation: entropy-based confidence,
# character error rate, and the raw model outputs (full text + per-token max scores).
RecognitionResults = namedtuple(
    "RecognitionResults",
    ["mean_entropy", "character_error_rate", "full_output", "max_scores"],
)
21 | 29 |
|
22 | 30 | TranslationResults = namedtuple(
|
23 | 31 | "TranslationResults",
|
@@ -71,14 +79,19 @@ def get_results(
|
71 | 79 |
|
72 | 80 | def recognition_results(
|
73 | 81 | self,
|
74 |
| - clean_output: dict[str, str | list[dict[str, str | torch.Tensor]]], |
75 |
| - var_output: dict[str, dict], |
| 82 | + clean_output: dict, |
| 83 | + var_output: dict, |
76 | 84 | **kwargs,
|
77 | 85 | ):
|
78 | 86 | # ### RECOGNITION ###
|
79 |
| - charerror = ocr_error(clean_output) |
| 87 | + charerror = ocr_error(clean_output["recognition"]) |
80 | 88 | confidence = var_output["recognition"]["mean_entropy"]
|
81 |
| - return RecognitionResults(confidence=confidence, accuracy=charerror) |
| 89 | + return RecognitionResults( |
| 90 | + mean_entropy=confidence, |
| 91 | + character_error_rate=charerror, |
| 92 | + max_scores=clean_output["recognition"]["outputs"]["max_scores"], |
| 93 | + full_output=clean_output["recognition"]["full_output"], |
| 94 | + ) |
82 | 95 |
|
83 | 96 | def translation_results(
|
84 | 97 | self,
|
def run_inference(
    # NOTE(review): first parameter name inferred from the loop below — the
    # body iterates `dataloader`; confirm against the original signature.
    dataloader,
    pipeline: RTCVariationalPipeline | RTCSingleComponentPipeline,
    results_getter: ResultsGetter,
):
    """Run variational inference over every row of ``dataloader``.

    Each row is passed through ``pipeline.variational_inference`` and the
    per-row metrics are collected via ``results_getter.get_results``.

    TEMPORARY FIX: rows that raise ``TypeError`` (malformed image input for
    the HF image pipeline) or a CUDA out-of-memory error are skipped; their
    CELEX IDs are printed at the end instead of aborting the whole run.

    Returns:
        list of one-entry dicts mapping each row's ``celex_id`` to its
        results dict (skipped rows are absent).
    """
    type_errors = []
    oom_errors = []
    results = []
    # Plain iteration: the original enumerate() index was never used.
    for inp in tqdm(dataloader):
        try:
            clean_out, var_out = pipeline.variational_inference(inp)
            row_results_dict = results_getter.get_results(
                clean_output=clean_out,
                var_output=var_out,
                test_row=inp,
            )
            results.append({inp["celex_id"]: row_results_dict})
        except TypeError:
            type_errors.append(inp["celex_id"])
        except torch.cuda.OutOfMemoryError:
            # torch.OutOfMemoryError (torch >= 2.5) is an alias of this
            # class, so one clause covers both spellings; referencing the
            # alias directly breaks on older torch (AttributeError).
            oom_errors.append(inp["celex_id"])

    # Only report skips that actually happened.
    if type_errors:
        print("Skipped following CELEX IDs due to TypeError:")
        print(
            '"TypeError: Incorrect format used for image. Should be an url linking to'
            ' an image, a base64 string, a local path, or a PIL image."'
        )
        print(type_errors)

    if oom_errors:
        print("Skipped following CELEX IDs due to torch.cuda.OutOfMemoryError:")
        print(oom_errors)

    return results
|
0 commit comments