Skip to content

Commit

Permalink
pipeline now calculates a sentence embedding for each input
Browse files Browse the repository at this point in the history
  • Loading branch information
J-Dymond committed Oct 4, 2024
1 parent efac13e commit 8a41bb6
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 28 deletions.
12 changes: 6 additions & 6 deletions scripts/variational_TTS_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@ def main(TTS_params):

clean_output = var_pipe.clean_inference(input_speech["array"])
# logit shapes
print("Logit shapes:")
print("\nLogit shapes:")
for step in var_pipe.pipeline_map.keys():
print(f"{step.capitalize()}: {clean_output[step]["logits"].shape}")

# entropy
print("Mean entropy:")
print("\nMean entropy:")
for step in var_pipe.pipeline_map.keys():
print(f"{step.capitalize()}: {torch.mean(clean_output[step]["entropy"])}")

# normalised entropy
print("Normalised mean entropy:")
print("\nNormalised mean entropy:")
cumulative = 1
for step in var_pipe.pipeline_map.keys():
step_entropy = torch.mean(clean_output[step]["normalised_entropy"])
Expand All @@ -39,7 +39,7 @@ def main(TTS_params):
print(f"Cumulative confidence (1 - entropy): {cumulative}")

# probabilities
print("Mean top probabilities:")
print("\nMean top probabilities:")
cumulative = 1
for step in var_pipe.pipeline_map.keys():
step_prob = torch.mean(clean_output[step]["probs"])
Expand All @@ -51,10 +51,10 @@ def main(TTS_params):


for step in var_pipe.pipeline_map.keys():
print(f'{step}:')
print(f'\n{step}:')
step_output = variational_output['variational'][step]
for run in step_output:
print(run['logits'])
print(run['semantic_embedding'])

if __name__ == "__main__":
TTS_pars = {
Expand Down
86 changes: 64 additions & 22 deletions src/arc_spice/dropout_utils/variational_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,35 @@

import numpy as np
import torch
import torch.nn.functional as F
from torch.distributions import Categorical
from torch.nn.functional import softmax
from transformers import (
AutomaticSpeechRecognitionPipeline,
AutoModel,
AutoTokenizer,
SummarizationPipeline,
TranslationPipeline,
pipeline,
)

from arc_spice.dropout_utils.dropout_pipeline import set_dropout

# From huggingface page with model:
# - https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2

def get_confidence_metrics(logits: torch.Tensor) -> dict[str, torch.Tensor]:
    """
    Calculate confidence metrics for a tensor of logits.

    Args:
        logits: unnormalised scores with the vocabulary on the last
            dimension (e.g. shape (seq_len, vocab_size)).

    Returns:
        dictionary containing:
        - "entropy": token-wise entropy of the predictive distribution
        - "normalised_entropy": entropy divided by log(vocab_size), so it
          lies in [0, 1] regardless of vocabulary size
        - "probs": probability of the most likely token at each position
          (max of the softmax — not log-probabilities)
    """
    # log(vocab_size) is the maximum attainable entropy; used to normalise
    vocab = torch.tensor(logits.shape[-1])
    entropy = Categorical(logits=logits).entropy()
    normalised_entropy = entropy / torch.log(vocab)
    softmax_logits = softmax(logits, dim=-1)
    max_probs = torch.max(softmax_logits, dim=-1).values
    return {
        "entropy": entropy,
        "normalised_entropy": normalised_entropy,
        "probs": max_probs,
    }
def mean_pooling(model_output, attention_mask):
    """
    Mean-pool token embeddings into a single sentence embedding.

    Only positions where the attention mask is set contribute to the
    average, so padding tokens are excluded.

    Args:
        model_output: transformer output; element 0 holds the token
            embeddings, shape (batch, seq_len, hidden).
        attention_mask: mask of shape (batch, seq_len); 1 for real
            tokens, 0 for padding.

    Returns:
        tensor of shape (batch, hidden) with the masked mean embedding.
    """
    token_embeddings = model_output[0]
    # broadcast the mask across the hidden dimension
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = torch.sum(token_embeddings * mask, 1)
    # clamp avoids division by zero for fully-masked rows
    counts = torch.clamp(mask.sum(1), min=1e-9)
    return summed / counts


class TTSVariationalPipeline:
Expand All @@ -60,6 +57,13 @@ def __init__(self, pars: dict[str : dict[str:str]]):
pipeline_class=CustomSummarizationPipeline,
)

self.semantic_tokenizer = AutoTokenizer.from_pretrained(
"sentence-transformers/all-MiniLM-L6-v2"
)
self.semantic_model = AutoModel.from_pretrained(
"sentence-transformers/all-MiniLM-L6-v2"
)

self.pipeline_map = {
"transcription": self.transcriber,
"translation": self.translator,
Expand All @@ -72,12 +76,50 @@ def __init__(self, pars: dict[str : dict[str:str]]):
"summarisation": self.summarise,
}

def get_confidence_metrics(
self, output_dict: dict[str : str | torch.Tensor]
) -> dict[str : torch.Tensor]:
"""
calculates confidence metrics for a tensor of logits:
- entropy : token-wise entropy
- normalised entropy : token-wise entropy normalised by vocab size
- probs : log-probabilities of the each generated token
Returns:
dictionary containing the calculated confidence metrics
"""
logits = output_dict["logits"]
text = output_dict["outputs"]
vocab = torch.tensor(logits.shape[-1])
entropy = Categorical(logits=logits).entropy()
normalised_entropy = entropy / torch.log(vocab)
softmax_logits = softmax(logits, dim=-1)
max_probs = torch.max(softmax_logits, dim=-1).values
tokenized_text = self.semantic_tokenizer(
text, padding=True, truncation=True, return_tensors="pt"
)
with torch.no_grad():
model_embeddings = self.semantic_model(**tokenized_text)
# Perform pooling
sentence_embeddings = mean_pooling(
model_embeddings, tokenized_text["attention_mask"]
)

# Normalize embeddings
sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
return {
"entropy": entropy,
"normalised_entropy": normalised_entropy,
"probs": max_probs,
"semantic_embedding": sentence_embeddings,
}

def transcribe(self, x: Union[np.ndarray, bytes, str]):
transcription = self.transcriber(x, generate_kwargs=self.generate_kwargs)
output_text = transcription["text"]
output_logits = transcription["raw_outputs"][0]["logits"].squeeze().T
output_dict = {"outputs": output_text, "logits": output_logits}
confidence_metrics = get_confidence_metrics(output_logits)
confidence_metrics = self.get_confidence_metrics(output_dict)
output_dict.update(confidence_metrics)
return output_dict

Expand All @@ -90,7 +132,7 @@ def translate(self, source_text: str):
output_text = translation["translation_text"]
output_logits = torch.cat(translation["raw_outputs"]["logits"])
output_dict = {"outputs": output_text, "logits": output_logits}
confidence_metrics = get_confidence_metrics(output_logits)
confidence_metrics = self.get_confidence_metrics(output_dict)
output_dict.update(confidence_metrics)
return output_dict

Expand All @@ -103,7 +145,7 @@ def summarise(self, source_text: str):
output_text = summarisation["summary_text"]
output_logits = torch.cat(summarisation["raw_outputs"]["logits"])
output_dict = {"outputs": output_text, "logits": output_logits}
confidence_metrics = get_confidence_metrics(output_logits)
confidence_metrics = self.get_confidence_metrics(output_dict)
output_dict.update(confidence_metrics)
return output_dict

Expand Down

0 comments on commit 8a41bb6

Please sign in to comment.