Commit af48055

Merge pull request #31 from alan-turing-institute/27-integrate-trained-classifier

added functionality for loading pre-trained classification models

2 parents 0b79ed9 + 0f704b6 · commit af48055

8 files changed: +166 -88 lines
(new file, +14)

@@ -0,0 +1,14 @@
+ocr:
+  specific_task: "image-to-text"
+  model: "microsoft/trocr-base-handwritten"
+
+translator:
+  specific_task: "translation_fr_to_en"
+  model: "ybanas/autotrain-fr-en-translate-51410121895"
+
+classifier:
+  specific_task: "text-classification"
+  model: "../distilbert-topic-classifier/"
+  kwargs:
+    truncation: True
+    padding: True
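For context, a minimal sketch of how a config in this shape could be loaded and turned into Hugging Face pipelines. The load_config helper and the config.yaml file name are assumptions for illustration, not part of this diff; the kwargs-splatting pattern mirrors the set_classifier helper added in utils.py below.

import yaml
from transformers import pipeline

def load_config(path: str) -> dict:
    # hypothetical loader: parse the YAML above into nested dicts
    with open(path) as f:
        return yaml.safe_load(f)

model_pars = load_config("config.yaml")  # assumed file name

# each section maps directly onto transformers.pipeline arguments;
# the classifier section additionally carries tokenizer kwargs
clf_pars = model_pars["classifier"]
classifier = pipeline(
    task=clf_pars["specific_task"],
    model=clf_pars["model"],
    **clf_pars.get("kwargs", {}),
)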

scripts/single_component_inference.py (-13)

@@ -1,16 +1,3 @@
-"""
-Steps:
-- Load data
-- Load pipeline/model
-- Run inference on all test data
-- Save outputs of specified model (on clean data)
-- Calculate error of specified model (on clean data)
-
-- Save results
-- File structure:
-- output/check_callibration/pipeline_name/run_[X]/[OUTPUT FILES HERE]
-"""
-
 import json
 import os
scripts/variational_RTC_example.py (+1 -1)

@@ -88,7 +88,7 @@ def main(rtc_pars):
     rtc_variational_pipeline = RTCVariationalPipeline(rtc_pars, metadata_params)

     # check dropout exists
-    rtc_variational_pipeline.check_dropout()
+    rtc_variational_pipeline.check_dropout(rtc_variational_pipeline.pipeline_map)

     # perform variational inference
     clean_output, var_output = rtc_variational_pipeline.variational_inference(test_row)
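Since check_dropout is now a static method taking the pipeline map explicitly (see the utils.py hunk below), the same check can also be written against the class; a minimal equivalent call:

# equivalent, since check_dropout no longer reads state from self
RTCVariationalPipeline.check_dropout(rtc_variational_pipeline.pipeline_map)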

src/arc_spice/eval/inference_utils.py (+3 -4)

@@ -35,7 +35,7 @@
     [
         "clean_scores",
         "mean_scores",
-        "hamming_accuracy",
+        "hamming_loss",
         "mean_predicted_entropy",
     ],
 )

@@ -126,12 +126,12 @@ def classification_results(
         clean_scores: torch.Tensor = clean_output["classification"]["scores"]
         preds = torch.round(mean_scores).tolist()
         labels = self.multihot(test_row["labels"])
-        hamming_acc = hamming_loss(y_pred=preds, y_true=labels)
+        hmng_loss = hamming_loss(y_pred=preds, y_true=labels)

         return ClassificationResults(
             mean_scores=mean_scores.detach().tolist(),
+            hamming_loss=hmng_loss,
             clean_scores=clean_scores,
-            hamming_accuracy=hamming_acc,
             mean_predicted_entropy=torch.mean(
                 var_output["classification"]["predicted_entropy"]
             ).item(),

@@ -152,5 +152,4 @@ def run_inference(
             test_row=inp,
         )
         results.append({inp["celex_id"]: row_results_dict})
-        break
     return results
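The rename from hamming_accuracy to hamming_loss also fixes the semantics: sklearn's hamming_loss returns the fraction of wrongly predicted labels, not an accuracy. A small worked example with invented multi-hot vectors:

from sklearn.metrics import hamming_loss

labels = [[1, 0, 1, 0]]  # invented multi-hot ground truth
preds = [[1, 1, 1, 0]]   # rounded mean scores, as in classification_results

loss = hamming_loss(y_true=labels, y_pred=preds)
print(loss)        # 0.25: one of the four labels is wrong
print(1.0 - loss)  # 0.75 would be the Hamming *accuracy*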

src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py (+14 -13)

@@ -1,6 +1,7 @@
 from typing import Any

 import torch
+import transformers
 from transformers import pipeline

 from arc_spice.variational_pipelines.RTC_variational_pipeline import (

@@ -10,6 +11,7 @@
     CustomTranslationPipeline,
     dropout_off,
     dropout_on,
+    set_classifier,
     set_dropout,
 )

@@ -88,7 +90,7 @@ def __init__(
         **kwargs,
     ):
         self.set_device()
-        self.ocr = pipeline(
+        self.ocr: transformers.Pipeline = pipeline(
             task=model_pars["ocr"]["specific_task"],
             model=model_pars["ocr"]["model"],
             device=self.device,

@@ -125,7 +127,7 @@ def __init__(
             n_variational_runs=n_variational_runs,
             translation_batch_size=translation_batch_size,
         )
-        self.translator = pipeline(
+        self.translator: transformers.Pipeline = pipeline(
             task=model_pars["translator"]["specific_task"],
             model=model_pars["translator"]["model"],
             max_length=512,

@@ -151,25 +153,24 @@ def __init__(
         n_variational_runs=5,
         **kwargs,
     ):
-        self.set_device()
+        if model_pars["classifier"]["specific_task"] == "zero-shot-classification":
+            zero_shot = True
+        else:
+            zero_shot = False
         super().__init__(
             step_name="classification",
             input_key="target_text",
-            forward_function=self.classify_topic,
+            forward_function=self.classify_topic_zero_shot
+            if zero_shot
+            else self.classify_topic,
             confidence_function=self.get_classification_confidence,
             n_variational_runs=n_variational_runs,
             **kwargs,
         )
-        self.classifier = pipeline(
-            task=model_pars["classifier"]["specific_task"],
-            model=model_pars["classifier"]["model"],
-            multi_label=True,
-            device=self.device,
+        self.classifier: transformers.Pipeline = set_classifier(
+            model_pars["classifier"], self.device
         )
         self.model = self.classifier.model
         # topic description labels for the classifier
-        self.topic_labels = [
-            class_names_dict["en"]
-            for class_names_dict in data_pars["class_descriptors"]
-        ]
+        self.dataset_meta_data = data_pars
         self._init_pipeline_map()
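The branching on specific_task is the crux of this change: a pre-trained classifier routes to classify_topic, while zero-shot-classification routes to classify_topic_zero_shot. A standalone sketch of the dispatch, with both config dicts invented for illustration (the zero-shot model name is an example, not from this repo):

trained_pars = {
    "specific_task": "text-classification",
    "model": "../distilbert-topic-classifier/",
}
zero_shot_pars = {
    "specific_task": "zero-shot-classification",
    "model": "facebook/bart-large-mnli",  # invented example model
}

for pars in (trained_pars, zero_shot_pars):
    zero_shot = pars["specific_task"] == "zero-shot-classification"
    forward_name = "classify_topic_zero_shot" if zero_shot else "classify_topic"
    print(f"{pars['specific_task']} -> {forward_name}")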

src/arc_spice/variational_pipelines/RTC_variational_pipeline.py (+21 -18)

@@ -8,6 +8,7 @@
     RTCVariationalPipelineBase,
     dropout_off,
     dropout_on,
+    set_classifier,
     set_dropout,
 )

@@ -38,7 +39,12 @@ def __init__(
         n_variational_runs=5,
         translation_batch_size=16,
     ) -> None:
-        super().__init__(n_variational_runs, translation_batch_size)
+        # are we doing zero-shot-classification?
+        if model_pars["classifier"]["specific_task"] == "zero-shot-classification":
+            self.zero_shot = True
+        else:
+            self.zero_shot = False
+        super().__init__(self.zero_shot, n_variational_runs, translation_batch_size)
         # defining the pipeline objects
         self.ocr = pipeline(
             task=model_pars["ocr"]["specific_task"],

@@ -52,18 +58,9 @@
             pipeline_class=CustomTranslationPipeline,
             device=self.device,
         )
-        self.classifier = pipeline(
-            task=model_pars["classifier"]["specific_task"],
-            model=model_pars["classifier"]["model"],
-            multi_label=True,
-            device=self.device,
-        )
-        # topic description labels for the classifier
-        self.topic_labels = [
-            class_names_dict["en"]
-            for class_names_dict in data_pars["class_descriptors"]
-        ]
-
+        self.classifier = set_classifier(model_pars["classifier"], self.device)
+        # topic meta_data for the classifier
+        self.dataset_meta_data = data_pars
         self._init_semantic_density()
         self._init_pipeline_map()

@@ -83,9 +80,15 @@ def clean_inference(self, x: torch.Tensor) -> dict[str, dict]:
         clean_output["translation"] = self.translate(
             clean_output["recognition"]["outputs"]
         )
-        clean_output["classification"] = self.classify_topic(
-            clean_output["translation"]["outputs"][0]
-        )
+        # we now need to pass the input correctly to the correct forward method
+        if self.zero_shot:
+            clean_output["classification"] = self.classify_topic_zero_shot(
+                clean_output["translation"]["outputs"][0]
+            )
+        else:
+            clean_output["classification"] = self.classify_topic(
+                clean_output["translation"]["outputs"][0]
+            )
         return clean_output

@@ -110,15 +113,15 @@ def variational_inference(self, x: torch.Tensor) -> tuple[dict, dict]:
         # for each model in pipeline
         for model_key, pl in self.pipeline_map.items():
             # turn on dropout for this model
-            set_dropout(model=pl.model, dropout_flag=True)  # type: ignore[union-attr]
+            set_dropout(model=pl.model, dropout_flag=True)  # type: ignore[union-attr,attr-defined]
             torch.nn.functional.dropout = dropout_on
             # do n runs of the inference
             for run_idx in range(self.n_variational_runs):
                 var_output[model_key][run_idx] = self.func_map[model_key](
                     input_map[model_key]
                 )
             # turn off dropout for this model
-            set_dropout(model=pl.model, dropout_flag=False)  # type: ignore[union-attr]
+            set_dropout(model=pl.model, dropout_flag=False)  # type: ignore[union-attr,attr-defined]
             torch.nn.functional.dropout = dropout_off

         # run metric helper functions
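The set_dropout body itself is not shown in this diff; a plausible minimal version appears below, together with the Monte Carlo dropout pattern the variational loop relies on. All names and values here are illustrative, not taken from the repo:

import torch

def set_dropout(model: torch.nn.Module, dropout_flag: bool) -> None:
    # flip training mode on dropout layers only, leaving the rest in eval
    for layer in model.modules():
        if isinstance(layer, torch.nn.Dropout):
            layer.train(dropout_flag)

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(p=0.5))
model.eval()

set_dropout(model, dropout_flag=True)   # dropout active at inference time
x = torch.ones(1, 4)
samples = torch.stack([model(x) for _ in range(5)])  # n variational runs
set_dropout(model, dropout_flag=False)

print(samples.var(dim=0))  # spread across stochastic passes ~ model uncertainty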

src/arc_spice/variational_pipelines/utils.py (+77 -10)

@@ -6,16 +6,52 @@
 from typing import Any

 import torch
+import transformers
 from torch.nn.functional import softmax
 from transformers import (
     AutoModelForSequenceClassification,
     AutoTokenizer,
     Pipeline,
     TranslationPipeline,
+    pipeline,
 )

 logger = logging.Logger("RTC_variational_pipeline")

+# Some methods for the
+
+
+def collate_scores(
+    scores: list[dict[str, float]], label_order
+) -> dict[str, list | dict]:
+    # these need to be returned in original order
+    # return dict to guarantee class predictions can be recovered
+    collated = {score["label"]: score["score"] for score in scores}
+    return {
+        "scores": [collated[label] for label in label_order],
+        "score_dict": collated,
+    }
+
+
+def set_classifier(classifier_pars: dict, device: str) -> transformers.Pipeline:
+    # new helper function which, given the classifier parameters, sets the correct
+    # pipeline method. This is needed because they take different kwargs
+    # > THIS COULD BE REFACTORED BY PUTTING KWARGS IN THE CONFIG <
+    if classifier_pars["specific_task"] == "zero-shot-classification":
+        return pipeline(
+            task=classifier_pars["specific_task"],
+            model=classifier_pars["model"],
+            multi_label=True,
+            device=device,
+            **classifier_pars.get("kwargs", {}),
+        )
+    return pipeline(
+        task=classifier_pars["specific_task"],
+        model=classifier_pars["model"],
+        device=device,
+        **classifier_pars.get("kwargs", {}),
+    )
+

 def set_dropout(model: torch.nn.Module, dropout_flag: bool) -> None:
     """
@@ -104,7 +140,7 @@ def clean_inference(self, x):
     def variational_inference(self, x):
         pass

-    def __init__(self, n_variational_runs=5, translation_batch_size=8):
+    def __init__(self, zero_shot: bool, n_variational_runs=5, translation_batch_size=8):
         # device for inference
         self.set_device()
         debug_msg_device = f"Loading pipeline on device: {self.device}"
@@ -113,7 +149,9 @@ def __init__(self, n_variational_runs=5, translation_batch_size=8):
         self.func_map = {
             "recognition": self.recognise,
             "translation": self.translate,
-            "classification": self.classify_topic,
+            "classification": self.classify_topic_zero_shot
+            if zero_shot
+            else self.classify_topic,
         }
         # the naive outputs of the pipeline stages calculated in self.clean_inference
         self.naive_outputs = {
@@ -139,8 +177,10 @@ def __init__(self, n_variational_runs=5, translation_batch_size=8):
         self.classifier = None

         # map pipeline names to their pipeline counterparts
-
-        self.topic_labels = None  # This should be defined in subclass if needed
+        # to replace class descriptors, we now want class descriptors and the labels
+        self.dataset_meta_data: dict = {
+            None: None
+        }  # This should be defined in subclass if needed

     def _init_pipeline_map(self):
         """
@@ -193,15 +233,16 @@ def split_translate_inputs(text: str, split_key: str) -> list[str]:
         split_rows = split_rows[:-1]
         return [split + split_key for split in split_rows]

-    def check_dropout(self):
+    @staticmethod
+    def check_dropout(pipeline_map: transformers.Pipeline):
         """
         Checks the existence of dropout layers in the models of the pipeline.

         Raises:
             ValueError: Raised when no dropout layers are found.
         """
         logger.debug("\n\n------------------ Testing Dropout --------------------")
-        for model_key, pl in self.pipeline_map.items():
+        for model_key, pl in pipeline_map.items():
             # only test models that exist
             if pl is None:
                 pipeline_none_msg_key = (
@@ -288,15 +329,41 @@ def translate(self, text: str) -> dict[str, torch.Tensor | str]:
         # {full translation, sentence translations, logits, semantic embeddings}
         return outputs

-    def classify_topic(self, text: str) -> dict[str, str]:
+    def classify_topic(self, text: str) -> dict[str, list[float] | dict]:
         """
         Runs the classification model

         Returns:
-            Dictionary of classification outputs, namely the output scores.
+            Dictionary of classification outputs, namely the output scores and
+            label:score dictionary.
+        """
+        forward = self.classifier(text, top_k=None)  # type: ignore[misc]
+        return collate_scores(forward, self.dataset_meta_data["class_labels"])  # type: ignore[index]
+
+    def classify_topic_zero_shot(self, text: str) -> dict[str, list[float] | dict]:
+        """
+        Runs the zero-shot classification model
+
+        Returns:
+            Dictionary of classification outputs, namely the output scores and
+            label:score dictionary.
         """
-        forward = self.classifier(text, self.topic_labels)  # type: ignore[misc]
-        return {"scores": forward["scores"]}
+        labels = [
+            descriptors["en"]
+            for descriptors in self.dataset_meta_data["class_descriptors"]  # type: ignore[index]
+        ]
+        forward = self.classifier(  # type: ignore[misc]
+            text, labels
+        )
+        return collate_scores(
+            [
+                {"label": label, "score": score}
+                for label, score in zip(
+                    forward["labels"], forward["scores"], strict=True
+                )
+            ],
+            label_order=labels,
+        )

     def stack_translator_sentence_metrics(
         self, all_sentence_metrics: list[dict[str, Any]]
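For reference, the two pipeline types return differently shaped results, which is why classify_topic_zero_shot reshapes its output before calling collate_scores. A sketch with invented values, shapes as documented for the transformers pipelines:

# text-classification with top_k=None: a flat list of {label, score} dicts
trained_out = [{"label": "health", "score": 0.93}, {"label": "trade", "score": 0.12}]

# zero-shot-classification: parallel lists, sorted by descending score
zero_shot_out = {"labels": ["health", "trade"], "scores": [0.81, 0.07]}

# the zip in classify_topic_zero_shot converts the latter into the former
reshaped = [
    {"label": label, "score": score}
    for label, score in zip(zero_shot_out["labels"], zero_shot_out["scores"], strict=True)
]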
