From 8a41bb6fedd9bdf17a4c2e3bc1dadd25d7cc8168 Mon Sep 17 00:00:00 2001
From: J-Dymond
Date: Fri, 4 Oct 2024 11:36:25 +0100
Subject: [PATCH] pipeline now calculates a sentence embedding for each input

---
 scripts/variational_TTS_example.py            | 12 +--
 .../dropout_utils/variational_inference.py    | 86 ++++++++++++++-----
 2 files changed, 70 insertions(+), 28 deletions(-)

diff --git a/scripts/variational_TTS_example.py b/scripts/variational_TTS_example.py
index 2646668..0b24da1 100644
--- a/scripts/variational_TTS_example.py
+++ b/scripts/variational_TTS_example.py
@@ -20,17 +20,17 @@ def main(TTS_params):
     clean_output = var_pipe.clean_inference(input_speech["array"])
 
     # logit shapes
-    print("Logit shapes:")
+    print("\nLogit shapes:")
     for step in var_pipe.pipeline_map.keys():
         print(f"{step.capitalize()}: {clean_output[step]["logits"].shape}")
 
     # entropy
-    print("Mean entropy:")
+    print("\nMean entropy:")
     for step in var_pipe.pipeline_map.keys():
         print(f"{step.capitalize()}: {torch.mean(clean_output[step]["entropy"])}")
 
     # normalised entropy
-    print("Normalised mean entropy:")
+    print("\nNormalised mean entropy:")
     cumulative = 1
     for step in var_pipe.pipeline_map.keys():
         step_entropy = torch.mean(clean_output[step]["normalised_entropy"])
@@ -39,7 +39,7 @@ def main(TTS_params):
     print(f"Cumulative confidence (1 - entropy): {cumulative}")
 
     # probabilities
-    print("Mean top probabilities:")
+    print("\nMean top probabilities:")
     cumulative = 1
     for step in var_pipe.pipeline_map.keys():
         step_prob = torch.mean(clean_output[step]["probs"])
@@ -51,10 +51,10 @@ def main(TTS_params):
 
     for step in var_pipe.pipeline_map.keys():
-        print(f'{step}:')
+        print(f'\n{step}:')
         step_output = variational_output['variational'][step]
         for run in step_output:
-            print(run['logits'])
+            print(run['semantic_embedding'])
 
 
 if __name__ == "__main__":
     TTS_pars = {
diff --git a/src/arc_spice/dropout_utils/variational_inference.py b/src/arc_spice/dropout_utils/variational_inference.py
index dbc1f17..69a85a8 100644
--- a/src/arc_spice/dropout_utils/variational_inference.py
+++ b/src/arc_spice/dropout_utils/variational_inference.py
@@ -3,10 +3,13 @@
 
 import numpy as np
 import torch
+import torch.nn.functional as F
 from torch.distributions import Categorical
 from torch.nn.functional import softmax
 from transformers import (
     AutomaticSpeechRecognitionPipeline,
+    AutoModel,
+    AutoTokenizer,
     SummarizationPipeline,
     TranslationPipeline,
     pipeline,
@@ -14,27 +17,21 @@
 
 from arc_spice.dropout_utils.dropout_pipeline import set_dropout
 
+# From the Hugging Face page for the model:
+# - https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
 
-def get_confidence_metrics(logits: torch.Tensor) -> dict[str : torch.Tensor]:
-    """
-    calculates confidence metrics for a tensor of logits:
-    - entropy : token-wise entropy
-    - normalised entropy : token-wise entropy normalised by vocab size
-    - probs : log-probabilities of the each generated token
-
-    Returns:
-        dictionary containing the calculated confidence metrics
-    """
-    vocab = torch.tensor(logits.shape[-1])
-    entropy = Categorical(logits=logits).entropy()
-    normalised_entropy = entropy / torch.log(vocab)
-    softmax_logits = softmax(logits, dim=-1)
-    max_probs = torch.max(softmax_logits, dim=-1).values
-    return {
-        "entropy": entropy,
-        "normalised_entropy": normalised_entropy,
-        "probs": max_probs,
-    }
+# Mean Pooling - take the attention mask into account for correct averaging
+def mean_pooling(model_output, attention_mask):
+    # First element of model_output contains all token embeddings
+    token_embeddings = model_output[0]
+    input_mask_expanded = (
+        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    )
+    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
+        input_mask_expanded.sum(1), min=1e-9
+    )
 
 
 class TTSVariationalPipeline:
@@ -60,6 +57,13 @@ def __init__(self, pars: dict[str : dict[str:str]]):
             pipeline_class=CustomSummarizationPipeline,
         )
 
+        self.semantic_tokenizer = AutoTokenizer.from_pretrained(
+            "sentence-transformers/all-MiniLM-L6-v2"
+        )
+        self.semantic_model = AutoModel.from_pretrained(
+            "sentence-transformers/all-MiniLM-L6-v2"
+        )
+
         self.pipeline_map = {
             "transcription": self.transcriber,
             "translation": self.translator,
@@ -72,12 +76,50 @@ def __init__(self, pars: dict[str : dict[str:str]]):
             "summarisation": self.summarise,
         }
 
+    def get_confidence_metrics(
+        self, output_dict: dict[str, str | torch.Tensor]
+    ) -> dict[str, torch.Tensor]:
+        """
+        calculates confidence metrics and a sentence embedding for a step's
+        output dictionary:
+        - entropy : token-wise entropy
+        - normalised entropy : token-wise entropy normalised by vocab size
+        - probs : maximum softmax probability of each generated token
+        - semantic_embedding : normalised sentence embedding of the output text
+
+        Returns:
+            dictionary containing the calculated confidence metrics
+        """
+        logits = output_dict["logits"]
+        text = output_dict["outputs"]
+        vocab = torch.tensor(logits.shape[-1])
+        entropy = Categorical(logits=logits).entropy()
+        normalised_entropy = entropy / torch.log(vocab)
+        softmax_logits = softmax(logits, dim=-1)
+        max_probs = torch.max(softmax_logits, dim=-1).values
+        tokenized_text = self.semantic_tokenizer(
+            text, padding=True, truncation=True, return_tensors="pt"
+        )
+        with torch.no_grad():
+            model_embeddings = self.semantic_model(**tokenized_text)
+        # Perform pooling
+        sentence_embeddings = mean_pooling(
+            model_embeddings, tokenized_text["attention_mask"]
+        )
+        # Normalize embeddings
+        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
+        return {
+            "entropy": entropy,
+            "normalised_entropy": normalised_entropy,
+            "probs": max_probs,
+            "semantic_embedding": sentence_embeddings,
+        }
+
     def transcribe(self, x: Union[np.ndarray, bytes, str]):
         transcription = self.transcriber(x, generate_kwargs=self.generate_kwargs)
         output_text = transcription["text"]
         output_logits = transcription["raw_outputs"][0]["logits"].squeeze().T
         output_dict = {"outputs": output_text, "logits": output_logits}
-        confidence_metrics = get_confidence_metrics(output_logits)
+        confidence_metrics = self.get_confidence_metrics(output_dict)
         output_dict.update(confidence_metrics)
         return output_dict
 
@@ -90,7 +132,7 @@ def translate(self, source_text: str):
         output_text = translation["translation_text"]
         output_logits = torch.cat(translation["raw_outputs"]["logits"])
         output_dict = {"outputs": output_text, "logits": output_logits}
-        confidence_metrics = get_confidence_metrics(output_logits)
+        confidence_metrics = self.get_confidence_metrics(output_dict)
         output_dict.update(confidence_metrics)
         return output_dict
 
@@ -103,7 +145,7 @@ def summarise(self, source_text: str):
         output_text = summarisation["summary_text"]
         output_logits = torch.cat(summarisation["raw_outputs"]["logits"])
         output_dict = {"outputs": output_text, "logits": output_logits}
-        confidence_metrics = get_confidence_metrics(output_logits)
+        confidence_metrics = self.get_confidence_metrics(output_dict)
        output_dict.update(confidence_metrics)
         return output_dict
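
Reviewer note, not part of the patch: the sentence-embedding path this commit
adds follows the standard sentence-transformers recipe (tokenize, encode,
mean-pool over the attention mask, L2-normalise). Below is a minimal,
self-contained sketch of that recipe, assuming only the model named in the
patch; the input sentence is illustrative.

    import torch
    import torch.nn.functional as F
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

    def mean_pooling(model_output, attention_mask):
        # Average the token embeddings, masking out padding positions.
        token_embeddings = model_output[0]
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

    # Illustrative input only.
    inputs = tokenizer(["an example sentence"], padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    embedding = F.normalize(mean_pooling(outputs, inputs["attention_mask"]), p=2, dim=1)
    print(embedding.shape)  # torch.Size([1, 384]) for all-MiniLM-L6-v2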
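One plausible use of the per-run embeddings printed by the example script is
scoring semantic consistency across dropout runs; the sketch below assumes
that use and is not something this commit implements. Since each embedding is
already L2-normalised, pairwise dot products are cosine similarities.

    import torch

    def semantic_consistency(run_embeddings: torch.Tensor) -> torch.Tensor:
        # run_embeddings: (n_runs, dim), rows L2-normalised, e.g. stacked from
        # run["semantic_embedding"] across variational runs (hypothetical usage).
        sims = run_embeddings @ run_embeddings.T  # pairwise cosine similarity
        n = sims.shape[0]
        off_diagonal = sims[~torch.eye(n, dtype=torch.bool)]
        return off_diagonal.mean()  # 1.0 = all runs semantically identical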