Skip to content

Commit 92a98f0

Browse files
committed
addressed comments from pull request, also added additional outputs to translation to allow additional confidence measures
1 parent 7ac88b8 commit 92a98f0

File tree

6 files changed

+99
-118
lines changed

6 files changed

+99
-118
lines changed

src/arc_spice/eval/classification_error.py

-7
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
import math
2-
31
import torch
4-
from sklearn.metrics import zero_one_loss
52

63

74
def aggregate_score(probs: torch.Tensor) -> torch.Tensor:
@@ -11,10 +8,6 @@ def aggregate_score(probs: torch.Tensor) -> torch.Tensor:
118
return 1 - torch.mean(distance)
129

1310

14-
def zero_one_loss_ceil(y_target, y_pred):
15-
return math.ceil(zero_one_loss(y_target, y_pred, normalize=True))
16-
17-
1811
def MC_dropout_scores(
1912
variational_probs: list[float], epsilon: float = 1e-14
2013
) -> dict[str, torch.Tensor]:

src/arc_spice/eval/inference_utils.py

+18-9
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,26 @@
1717
)
1818

1919
RecognitionResults = namedtuple("RecognitionResults", ["confidence", "accuracy"])
20-
ClassificationResults = namedtuple(
21-
"ClassificationResults",
22-
[
23-
"clean_scores",
24-
"mean_scores",
25-
"hamming_accuracy",
26-
"mean_predicted_entropy",
27-
],
28-
)
20+
2921
TranslationResults = namedtuple(
3022
"TranslationResults",
3123
[
3224
"full_output",
3325
"clean_conditional_probability",
3426
"comet_score",
3527
"weighted_semantic_density",
28+
"mean_entropy",
29+
"sequence_lengths",
30+
],
31+
)
32+
33+
ClassificationResults = namedtuple(
34+
"ClassificationResults",
35+
[
36+
"clean_scores",
37+
"mean_scores",
38+
"hamming_accuracy",
39+
"mean_predicted_entropy",
3640
],
3741
)
3842

@@ -79,6 +83,8 @@ def translation_results(
7983
source_text = test_row["target_text"]
8084
target_text = test_row["target_text"]
8185
clean_translation = clean_output["translation"]["full_output"]
86+
clean_entropy: torch.Tensor = clean_output["translation"]["mean_entropy"]
87+
seq_lens: torch.Tensor = var_output["translation"]["sequence_length"]
8288
probs: list[torch.Tensor] = clean_output["translation"]["probs"]
8389
clean_cond_prob = [
8490
conditional_probability(prob.squeeze()).detach().tolist() for prob in probs
@@ -102,6 +108,8 @@ def translation_results(
102108
comet_score=comet_output["scores"][0],
103109
full_output=clean_translation,
104110
clean_conditional_probability=clean_cond_prob,
111+
mean_entropy=clean_entropy,
112+
sequence_lengths=seq_lens,
105113
weighted_semantic_density=var_output["translation"][
106114
"weighted_semantic_density"
107115
],
@@ -144,4 +152,5 @@ def run_inference(
144152
test_row=inp,
145153
)
146154
results.append({inp["celex_id"]: row_results_dict})
155+
break
147156
return results

src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py

+6-15
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,14 @@
44
from transformers import pipeline
55

66
from arc_spice.variational_pipelines.RTC_variational_pipeline import (
7-
CustomTranslationPipeline,
87
RTCVariationalPipelineBase,
98
)
10-
from arc_spice.variational_pipelines.utils import dropout_off, dropout_on, set_dropout
9+
from arc_spice.variational_pipelines.utils import (
10+
CustomTranslationPipeline,
11+
dropout_off,
12+
dropout_on,
13+
set_dropout,
14+
)
1115

1216

1317
class RTCSingleComponentPipeline(RTCVariationalPipelineBase):
@@ -34,19 +38,6 @@ def __init__(
3438
# define objects that are needed and nothing else
3539
# naive outputs can remain the same, though only the appropriate outputs will
3640
# be outputted
37-
self.naive_outputs = {
38-
"recognition": [
39-
"outputs",
40-
],
41-
"translation": [
42-
"full_output",
43-
"outputs",
44-
"probs",
45-
],
46-
"classification": [
47-
"scores",
48-
],
49-
}
5041
self.step_name = step_name
5142
self.input_key = input_key
5243
self.forward_function = forward_function

src/arc_spice/variational_pipelines/RTC_variational_pipeline.py

+2-55
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
import copy
21
from typing import Any
32

43
import torch
5-
from torch.nn.functional import softmax
6-
from transformers import TranslationPipeline, pipeline
4+
from transformers import pipeline
75

86
from arc_spice.variational_pipelines.utils import (
7+
CustomTranslationPipeline,
98
RTCVariationalPipelineBase,
109
dropout_off,
1110
dropout_on,
@@ -134,55 +133,3 @@ def variational_inference(self, x: torch.Tensor) -> tuple[dict, dict]:
134133
# on standard call return the clean output
135134
def __call__(self, x):
136135
return self.clean_inference(x)
137-
138-
139-
# Translation pipeline with additional functionality to save logits from fwd pass
140-
class CustomTranslationPipeline(TranslationPipeline):
141-
"""
142-
custom translation pipeline to return the logits with the generated text. Largely
143-
the same as the pytorch version with some additional arguments passed to the
144-
`generate` method.
145-
"""
146-
147-
def postprocess(
148-
self,
149-
model_outputs: dict,
150-
**postprocess_params,
151-
):
152-
# model_outputs gets overwritten in the super().postprocess call
153-
# make a copy here so we retain the information we want
154-
raw_out = copy.deepcopy(model_outputs)
155-
processed = super().postprocess(model_outputs, **postprocess_params)
156-
157-
return {
158-
"translation_text": processed[0]["translation_text"],
159-
"raw_outputs": raw_out,
160-
}
161-
162-
def _forward(self, model_inputs, **generate_kwargs):
163-
if self.framework == "pt":
164-
in_b, input_length = model_inputs["input_ids"].shape
165-
elif self.framework == "tf":
166-
raise NotImplementedError
167-
168-
self.check_inputs(
169-
input_length,
170-
generate_kwargs.get("min_length", self.model.config.min_length),
171-
generate_kwargs.get("max_length", self.model.config.max_length),
172-
)
173-
out = self.model.generate(**model_inputs, **generate_kwargs)
174-
output_ids = out["sequences"]
175-
out_b = output_ids.shape[0]
176-
if self.framework == "pt":
177-
output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
178-
elif self.framework == "tf":
179-
raise NotImplementedError
180-
181-
# logits are a tuple of length output_ids[-1]-1
182-
# each element is a tensor of shape (batch_size, vocab_size)
183-
logits = torch.stack(out["logits"], dim=1)
184-
# get softmax of the logits to get token probabilities
185-
softmax_logits = softmax(logits, dim=-1)
186-
max_token_scores = torch.max(softmax_logits, dim=-1).values
187-
188-
return {"output_ids": output_ids, "scores": max_token_scores}

src/arc_spice/variational_pipelines/utils.py

+73-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
1+
import copy
12
import logging
3+
import math
24
from abc import ABC, abstractmethod
35
from functools import partial
46
from typing import Any
57

68
import torch
79
from torch.nn.functional import softmax
8-
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Pipeline
10+
from transformers import (
11+
AutoModelForSequenceClassification,
12+
AutoTokenizer,
13+
Pipeline,
14+
TranslationPipeline,
15+
)
916

1017
logger = logging.Logger("RTC_variational_pipeline")
1118

@@ -117,6 +124,7 @@ def __init__(self, n_variational_runs=5, translation_batch_size=8):
117124
"full_output",
118125
"outputs",
119126
"probs",
127+
"mean_entropy",
120128
],
121129
"classification": [
122130
"scores",
@@ -264,6 +272,9 @@ def translate(self, text: str) -> dict[str, torch.Tensor | str]:
264272
{
265273
"outputs": translator_output["translation_text"],
266274
"probs": translator_output["raw_outputs"]["scores"],
275+
"mean_entropy": torch.mean(translator_output["raw_outputs"]["entropy"])
276+
.detach()
277+
.tolist(),
267278
}
268279
for translator_output in translator_outputs
269280
]
@@ -430,6 +441,7 @@ def translation_semantic_density(
430441
{
431442
"semantic_densities": densities,
432443
"weighted_semantic_density": weighted_average.item(),
444+
"sequence_length": sequence_lengths,
433445
}
434446
)
435447

@@ -480,3 +492,63 @@ def get_classification_confidence(
480492
}
481493
)
482494
return var_output
495+
496+
497+
# Translation pipeline with additional functionality to save logits from fwd pass
498+
class CustomTranslationPipeline(TranslationPipeline):
499+
"""
500+
custom translation pipeline to return the logits with the generated text. Largely
501+
the same as the pytorch version with some additional arguments passed to the
502+
`generate` method.
503+
"""
504+
505+
def postprocess(
506+
self,
507+
model_outputs: dict,
508+
**postprocess_params,
509+
):
510+
# model_outputs gets overwritten in the super().postprocess call
511+
# make a copy here so we retain the information we want
512+
raw_out = copy.deepcopy(model_outputs)
513+
processed = super().postprocess(model_outputs, **postprocess_params)
514+
515+
return {
516+
"translation_text": processed[0]["translation_text"],
517+
"raw_outputs": raw_out,
518+
}
519+
520+
def _forward(self, model_inputs, **generate_kwargs):
521+
if self.framework == "pt":
522+
in_b, input_length = model_inputs["input_ids"].shape
523+
elif self.framework == "tf":
524+
raise NotImplementedError
525+
526+
self.check_inputs(
527+
input_length,
528+
generate_kwargs.get("min_length", self.model.config.min_length),
529+
generate_kwargs.get("max_length", self.model.config.max_length),
530+
)
531+
out = self.model.generate(**model_inputs, **generate_kwargs)
532+
output_ids = out["sequences"]
533+
out_b = output_ids.shape[0]
534+
if self.framework == "pt":
535+
output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
536+
elif self.framework == "tf":
537+
raise NotImplementedError
538+
539+
# logits are a tuple of length output_ids[-1]-1
540+
# each element is a tensor of shape (batch_size, vocab_size)
541+
logits = torch.stack(out["logits"], dim=1)
542+
# get softmax of the logits to get token probabilities
543+
softmax_logits = softmax(logits, dim=-1)
544+
vocab_size = softmax_logits.shape[-1]
545+
normalised_entropy = torch.distributions.Categorical(
546+
probs=softmax_logits
547+
).entropy() / math.log(vocab_size)
548+
max_token_scores = torch.max(softmax_logits, dim=-1).values
549+
550+
return {
551+
"output_ids": output_ids,
552+
"scores": max_token_scores,
553+
"entropy": normalised_entropy,
554+
}

tests/test_inference.py

-31
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22
from unittest.mock import MagicMock, patch
33

44
import pytest
5-
from sklearn.metrics import hamming_loss
65

7-
from arc_spice.eval.classification_error import zero_one_loss_ceil
86
from arc_spice.utils import open_yaml_path
97
from arc_spice.variational_pipelines.RTC_single_component_pipeline import (
108
ClassificationVariationalPipeline,
@@ -42,35 +40,6 @@ def dummy_metadata():
4240
}
4341

4442

45-
def test_errors():
46-
dummy_target = [0, 1, 0, 1, 0]
47-
dummy_middle_output = [1, 1, 0, 1, 0]
48-
49-
assert hamming_loss(dummy_target, dummy_middle_output) == pytest.approx(
50-
0.2, abs=1e-5
51-
)
52-
assert zero_one_loss_ceil(dummy_target, dummy_middle_output) == pytest.approx(
53-
1.0, abs=1e-5
54-
)
55-
56-
dummy_correct_output = [0, 1, 0, 1, 0]
57-
58-
assert hamming_loss(dummy_target, dummy_correct_output) == pytest.approx(
59-
0.0, abs=1e-5
60-
)
61-
assert zero_one_loss_ceil(dummy_target, dummy_correct_output) == pytest.approx(
62-
0.0, abs=1e-5
63-
)
64-
65-
dummy_incorrect_output = [1, 0, 1, 0, 1]
66-
assert hamming_loss(dummy_target, dummy_incorrect_output) == pytest.approx(
67-
1.0, abs=1e-5
68-
)
69-
assert zero_one_loss_ceil(dummy_target, dummy_incorrect_output) == pytest.approx(
70-
1.0, abs=1e-5
71-
)
72-
73-
7443
def test_pipeline_inputs(dummy_data, dummy_metadata):
7544
pipeline_config = open_yaml_path(PIPELINE_PATH)
7645

0 commit comments

Comments
 (0)