6
6
from typing import Any
7
7
8
8
import torch
9
+ import transformers
9
10
from torch .nn .functional import softmax
10
11
from transformers import (
11
12
AutoModelForSequenceClassification ,
12
13
AutoTokenizer ,
13
14
Pipeline ,
14
15
TranslationPipeline ,
16
+ pipeline ,
15
17
)
16
18
17
19
# Fetch the logger from the logging registry rather than instantiating
# logging.Logger directly: a directly constructed Logger has no parent, so its
# records never propagate to handlers configured on the root logger.
logger = logging.getLogger("RTC_variational_pipeline")
18
20
21
+ # Some helper methods for the classification stage of the pipeline
22
+
23
+
24
+ def collate_scores (
25
+ scores : list [dict [str , float ]], label_order
26
+ ) -> dict [str , list | dict ]:
27
+ # these need to be returned in original order
28
+ # return dict for to guarantee class predictions can be recovered
29
+ collated = {score ["label" ]: score ["score" ] for score in scores }
30
+ return {
31
+ "scores" : [collated [label ] for label in label_order ],
32
+ "score_dict" : collated ,
33
+ }
34
+
35
+
36
def set_classifier(classifier_pars: dict, device: str) -> transformers.Pipeline:
    """
    Builds the classification pipeline described by `classifier_pars`.

    Zero-shot classification pipelines are created with `multi_label=True`
    unless the configuration says otherwise; other tasks get no task-specific
    defaults.

    Args:
        classifier_pars: Classifier configuration. Must contain "specific_task"
            (the pipeline task name) and "model". An optional "kwargs"
            dictionary is forwarded to `pipeline` and takes precedence over
            the task-specific defaults set here.
        device: Device handed on to `pipeline`.

    Returns:
        The object returned by `pipeline` for the configured task and model.

    Raises:
        KeyError: If "specific_task" or "model" is missing from
            `classifier_pars`.
    """
    task = classifier_pars["specific_task"]
    model = classifier_pars["model"]

    # Task-specific defaults come first so that "kwargs" from the config can
    # override them instead of clashing as duplicate keyword arguments.
    task_defaults = {"multi_label": True} if task == "zero-shot-classification" else {}
    extra_kwargs = {**task_defaults, **classifier_pars.get("kwargs", {})}

    return pipeline(
        task=task,
        model=model,
        device=device,
        **extra_kwargs,
    )
54
+
19
55
20
56
def set_dropout (model : torch .nn .Module , dropout_flag : bool ) -> None :
21
57
"""
@@ -104,7 +140,7 @@ def clean_inference(self, x):
104
140
def variational_inference(self, x):
    """
    Placeholder for the variational inference step.

    The body is empty: the argument is ignored and None is returned.
    NOTE(review): presumably meant to be overridden in a subclass - confirm.
    """
    return None
106
142
107
- def __init__ (self , n_variational_runs = 5 , translation_batch_size = 8 ):
143
+ def __init__ (self , zero_shot : bool , n_variational_runs = 5 , translation_batch_size = 8 ):
108
144
# device for inference
109
145
self .set_device ()
110
146
debug_msg_device = f"Loading pipeline on device: { self .device } "
@@ -113,7 +149,9 @@ def __init__(self, n_variational_runs=5, translation_batch_size=8):
113
149
self .func_map = {
114
150
"recognition" : self .recognise ,
115
151
"translation" : self .translate ,
116
- "classification" : self .classify_topic ,
152
+ "classification" : self .classify_topic_zero_shot
153
+ if zero_shot
154
+ else self .classify_topic ,
117
155
}
118
156
# the naive outputs of the pipeline stages calculated in self.clean_inference
119
157
self .naive_outputs = {
@@ -139,8 +177,10 @@ def __init__(self, n_variational_runs=5, translation_batch_size=8):
139
177
self .classifier = None
140
178
141
179
# map pipeline names to their pipeline counterparts
142
-
143
- self .topic_labels = None # This should be defined in subclass if needed
180
+ # to replace class descriptors, we now want class descriptors and the labels
181
+ self .dataset_meta_data : dict = {
182
+ None : None
183
+ } # This should be defined in subclass if needed
144
184
145
185
def _init_pipeline_map (self ):
146
186
"""
@@ -193,15 +233,16 @@ def split_translate_inputs(text: str, split_key: str) -> list[str]:
193
233
split_rows = split_rows [:- 1 ]
194
234
return [split + split_key for split in split_rows ]
195
235
196
- def check_dropout (self ):
236
+ @staticmethod
237
+ def check_dropout (pipeline_map : transformers .Pipeline ):
197
238
"""
198
239
Checks the existence of dropout layers in the models of the pipeline.
199
240
200
241
Raises:
201
242
ValueError: Raised when no dropout layers are found.
202
243
"""
203
244
logger .debug ("\n \n ------------------ Testing Dropout --------------------" )
204
- for model_key , pl in self . pipeline_map .items ():
245
+ for model_key , pl in pipeline_map .items ():
205
246
# only test models that exist
206
247
if pl is None :
207
248
pipeline_none_msg_key = (
@@ -288,15 +329,41 @@ def translate(self, text: str) -> dict[str, torch.Tensor | str]:
288
329
# {full translation, sentence translations, logits, semantic embeddings}
289
330
return outputs
290
331
291
def classify_topic(self, text: str) -> dict[str, list[float] | dict]:
    """
    Runs the classification model

    Args:
        text: Input passed to the classifier.

    Returns:
        Dictionary of classification outputs: "scores", ordered like the
        "class_labels" entry of `self.dataset_meta_data`, and "score_dict",
        the label:score dictionary.
    """
    # top_k=None is passed so the classifier is not limited to its top results
    # NOTE(review): assumes this yields one label/score entry per class - confirm
    class_scores = self.classifier(text, top_k=None)  # type: ignore[misc]
    label_order = self.dataset_meta_data["class_labels"]  # type: ignore[index]
    return collate_scores(class_scores, label_order)
342
+
343
def classify_topic_zero_shot(self, text: str) -> dict[str, list[float] | dict]:
    """
    Runs the zero-shot classification model

    The candidate labels are the "en" entries of the "class_descriptors" held
    in `self.dataset_meta_data`.

    Args:
        text: Input passed to the classifier.

    Returns:
        Dictionary of classification outputs: "scores", ordered like the
        candidate labels, and "score_dict", the label:score dictionary.

    Raises:
        ValueError: If the classifier returns a different number of labels
            and scores (strict zip).
    """
    candidate_labels = []
    for descriptors in self.dataset_meta_data["class_descriptors"]:  # type: ignore[index]
        candidate_labels.append(descriptors["en"])

    prediction = self.classifier(text, candidate_labels)  # type: ignore[misc]

    # pair each returned label with its score so the result can be re-keyed
    # and put back into the candidate label order
    label_score_pairs = zip(prediction["labels"], prediction["scores"], strict=True)
    per_label = [
        {"label": label, "score": score} for label, score in label_score_pairs
    ]
    return collate_scores(per_label, label_order=candidate_labels)
300
367
301
368
def stack_translator_sentence_metrics (
302
369
self , all_sentence_metrics : list [dict [str , Any ]]
0 commit comments