Skip to content

Commit 6bf60e5

Browse files
authored
Merge pull request #9 from alan-turing-institute/6-mc-dropout-in-hf-pipeline
6 mc dropout in hf pipeline
2 parents 3535452 + ac2f04e commit 6bf60e5

File tree

5 files changed

+483
-0
lines changed

5 files changed

+483
-0
lines changed

.gitignore

+4
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,7 @@ Thumbs.db
156156
# Common editor files
157157
*~
158158
*.swp
159+
160+
# other
161+
temp
162+
.vscode

scripts/variational_TTS_example.py

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""
2+
An example use of the transcription, translation and summarisation pipeline.
3+
"""
4+
import torch
5+
from datasets import Audio, load_dataset
6+
7+
from arc_spice.dropout_utils.variational_inference import TTSVariationalPipeline
8+
9+
10+
def main(TTS_params):
    """Run the transcription -> translation -> summarisation pipeline on one
    streamed example and report per-step uncertainty metrics.

    Args:
        TTS_params: dict of per-step pipeline configs keyed by step name
            ("transcriber", "translator", "summariser"), each containing a
            "specific_task" and a "model" entry.
    """
    var_pipe = TTSVariationalPipeline(TTS_params, n_variational_runs=2)

    # stream one French test utterance rather than downloading the full split
    ds = load_dataset(
        "facebook/multilingual_librispeech", "french", split="test", streaming=True
    )
    ds = ds.cast_column("audio", Audio(sampling_rate=16_000))
    input_speech = next(iter(ds))["audio"]

    var_pipe.clean_inference(input_speech["array"])
    clean_output = var_pipe.clean_output

    # logit shapes
    # NOTE: inner single quotes — nested double quotes inside an f-string are a
    # SyntaxError on Python < 3.12.
    print("\nLogit shapes:")
    for step in var_pipe.pipeline_map.keys():
        print(f"{step.capitalize()}: {clean_output[step]['logits'].shape}")

    # entropy
    print("\nMean entropy:")
    for step in var_pipe.pipeline_map.keys():
        print(f"{step.capitalize()}: {torch.mean(clean_output[step]['entropy'])}")

    # normalised entropy: treat (1 - entropy) per step as a confidence and
    # multiply the steps together
    print("\nNormalised mean entropy:")
    cumulative = 1
    for step in var_pipe.pipeline_map.keys():
        step_entropy = torch.mean(clean_output[step]["normalised_entropy"])
        cumulative *= 1 - step_entropy
        print(f"{step.capitalize()}: {step_entropy}")
    print(f"Cumulative confidence (1 - entropy): {cumulative}")

    # probabilities
    print("\nMean top probabilities:")
    cumulative = 1
    for step in var_pipe.pipeline_map.keys():
        step_prob = torch.mean(clean_output[step]["probs"])
        cumulative *= step_prob
        print(f"{step.capitalize()}: {step_prob}")
    print(f"Cumulative confidence: {cumulative}")

    print("\nConditional probabilities:")
    for step in var_pipe.pipeline_map.keys():
        token_probs = clean_output[step]["probs"]
        # geometric mean of token probabilities
        # assumes token_probs is 1-D so len() == token count — TODO confirm
        cond_prob = torch.pow(torch.prod(token_probs, -1), 1 / len(token_probs))
        print(f"{step.capitalize()}: {cond_prob}")

    var_pipe.variational_inference(x=input_speech["array"])
    variational_output = var_pipe.var_output
    print("\nVariational Inference Semantic Density:")
    for step in variational_output["variational"].keys():
        print(f"{step}: {variational_output['variational'][step]['semantic_density']}")
62+
63+
64+
if __name__ == "__main__":
65+
TTS_pars = {
66+
"transcriber": {
67+
"specific_task": "automatic-speech-recognition",
68+
"model": "jonatasgrosman/wav2vec2-large-xlsr-53-french",
69+
},
70+
"translator": {
71+
"specific_task": "translation_fr_to_en",
72+
"model": "ybanas/autotrain-fr-en-translate-51410121895",
73+
},
74+
"summariser": {
75+
"specific_task": "summarization",
76+
"model": "marianna13/flan-t5-base-summarization",
77+
},
78+
}
79+
main(TTS_params=TTS_pars)

src/arc_spice/dropout_utils/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import torch
2+
from transformers import Pipeline, pipeline
3+
4+
5+
def set_dropout(model: torch.nn.Module, dropout_flag: bool) -> None:
    """
    Switch every dropout layer of a model on or off, in place.

    Args:
        model: pytorch model whose dropout layers should be toggled
        dropout_flag: True -> dropout active (train mode),
            False -> dropout disabled (eval mode)
    """
    dropout_layers = (
        module
        for module in model.modules()
        if isinstance(module, torch.nn.Dropout)
    )
    for layer in dropout_layers:
        # train(True) activates dropout; train(False) matches eval()
        layer.train(dropout_flag)
18+
19+
20+
def MCDropoutPipeline(task: str, model: str):
    """
    Build a Hugging Face pipeline with its dropout layers forced on, so that
    repeated forward passes yield Monte-Carlo dropout samples.

    Args:
        task: pipeline task name, e.g. "automatic-speech-recognition"
        model: model checkpoint identifier

    Returns:
        transformers Pipeline whose model has every dropout layer in train mode
    """
    pl = pipeline(
        task=task,
        model=model,
    )
    # BUG FIX: set_dropout mutates the model in place and returns None, so the
    # previous `pl.model = set_dropout(...)` replaced the model with None.
    set_dropout(model=pl.model, dropout_flag=True)
    return pl
28+
29+
30+
def test_dropout(pipe: Pipeline, dropout_flag: bool):
    """
    Check that every dropout layer in the pipeline's model matches the
    expected mode, and report how many layers were inspected.

    Args:
        pipe: transformers pipeline wrapping the model to inspect
        dropout_flag: expected `.training` state of each dropout layer

    Raises:
        AssertionError: if any dropout layer is in the wrong mode
    """
    dropout_count = 0
    for module in pipe.model.named_modules():
        if not isinstance(module[1], torch.nn.Dropout):
            continue
        dropout_count += 1
        assert module[1].training == dropout_flag

    print(f"{dropout_count} dropout layers found in correct configuration.")

0 commit comments

Comments
 (0)