diff --git a/scripts/variational_TTS_example.py b/scripts/variational_TTS_example.py index 0b24da1..9c61f96 100644 --- a/scripts/variational_TTS_example.py +++ b/scripts/variational_TTS_example.py @@ -1,8 +1,6 @@ """ An example use of the transcription, translation and summarisation pipeline. """ -import json - import torch from datasets import Audio, load_dataset @@ -11,14 +9,17 @@ def main(TTS_params): """main function""" - var_pipe = TTSVariationalPipeline(TTS_params) + var_pipe = TTSVariationalPipeline(TTS_params,n_variational_runs=2) + ds = load_dataset( "facebook/multilingual_librispeech", "french", split="test", streaming=True ) ds = ds.cast_column("audio", Audio(sampling_rate=16_000)) input_speech = next(iter(ds))["audio"] - clean_output = var_pipe.clean_inference(input_speech["array"]) + var_pipe.clean_inference(input_speech["array"]) + clean_output = var_pipe.clean_output + # logit shapes print("\nLogit shapes:") for step in var_pipe.pipeline_map.keys(): @@ -47,14 +48,18 @@ def main(TTS_params): print(f"{step.capitalize()}: {step_prob}") print(f"Cumulative confidence: {cumulative}") - variational_output = var_pipe.variational_inference(x=input_speech['array'],n_runs=2) + print("\nConditional probabilities:") + for step in var_pipe.pipeline_map.keys(): + token_probs = clean_output[step]["probs"] + cond_prob = torch.pow(torch.prod(token_probs,-1),1/len(token_probs)) + print(f"{step.capitalize()}: {cond_prob}") + var_pipe.variational_inference(x=input_speech['array']) + variational_output = var_pipe.var_output + print("\nVariational Inference Semantic Density:") + for step in variational_output['variational'].keys(): + print(f"{step}: {variational_output['variational'][step]['semantic_density']}") - for step in var_pipe.pipeline_map.keys(): - print(f'\n{step}:') - step_output = variational_output['variational'][step] - for run in step_output: - print(run['semantic_embedding']) if __name__ == "__main__": TTS_pars = { diff --git a/src/arc_spice/dropout_utils/variational_inference.py b/src/arc_spice/dropout_utils/variational_inference.py index 69a85a8..ad73aa1 100644 --- a/src/arc_spice/dropout_utils/variational_inference.py +++ b/src/arc_spice/dropout_utils/variational_inference.py @@ -5,10 +5,11 @@ import torch import torch.nn.functional as F from torch.distributions import Categorical -from torch.nn.functional import softmax +from torch.nn.functional import cosine_similarity, softmax from transformers import ( AutomaticSpeechRecognitionPipeline, AutoModel, + AutoModelForSequenceClassification, AutoTokenizer, SummarizationPipeline, TranslationPipeline, @@ -39,7 +40,7 @@ class TTSVariationalPipeline: variational version of the TTS pipeline """ - def __init__(self, pars: dict[str : dict[str:str]]): + def __init__(self, pars: dict[str : dict[str:str]], n_variational_runs=5): self.transcriber = pipeline( task=pars["transcriber"]["specific_task"], model=pars["transcriber"]["model"], @@ -64,17 +65,35 @@ def __init__(self, pars: dict[str : dict[str:str]]): "sentence-transformers/all-MiniLM-L6-v2" ) + self.nli_tokenizer = AutoTokenizer.from_pretrained( + "microsoft/deberta-large-mnli" + ) + + self.nli_model = AutoModelForSequenceClassification.from_pretrained( + "microsoft/deberta-large-mnli" + ) + self.pipeline_map = { "transcription": self.transcriber, "translation": self.translator, "summarisation": self.summariser, } self.generate_kwargs = {"output_scores": True} + self.func_map = { "transcription": self.transcribe, "translation": self.translate, "summarisation": self.summarise, } + self.naive_outputs = { + "outputs", + "logits", + "entropy", + "normalised_entropy", + "probs", + "semantic_embedding", + } + self.n_variational_runs = n_variational_runs def get_confidence_metrics( self, output_dict: dict[str : str | torch.Tensor] @@ -149,6 +168,50 @@ def summarise(self, source_text: str): output_dict.update(confidence_metrics) return output_dict + def collect_metrics(self): + new_var_dict = {} + for step in self.var_output["variational"].keys(): + new_var_dict[step] = {} + for metric in self.naive_outputs: + new_values = [None] * self.n_variational_runs + for run in range(self.n_variational_runs): + new_values[run] = self.var_output["variational"][step][run][metric] + new_var_dict[step][metric] = new_values + + self.var_output["variational"] = new_var_dict + + def calculate_semantic_density(self): + for step in self.var_output["variational"].keys(): + clean_out = self.var_output["clean"][step]["outputs"] + var_step = self.var_output["variational"][step] + kernel_funcs = torch.zeros(self.n_variational_runs) + cond_probs = torch.zeros(self.n_variational_runs) + sims = [None] * self.n_variational_runs + for run_index, run_out in enumerate(var_step["outputs"]): + run_prob = var_step["probs"][run_index] + nli_inp = clean_out + " [SEP] " + run_out + encoded_nli = self.nli_tokenizer.encode( + nli_inp, padding=True, return_tensors="pt" + ) + sims[run_index] = cosine_similarity( + self.var_output["clean"][step]["semantic_embedding"], + var_step["semantic_embedding"][run_index], + ) + nli_out = softmax(self.nli_model(encoded_nli)["logits"], dim=-1)[0] + kernel_func = 1 - (nli_out[0] + (0.5 * nli_out[1])) + cond_probs[run_index] = torch.pow( + torch.prod(run_prob, -1), 1 / len(run_prob) + ) + kernel_funcs[run_index] = kernel_func + semantic_density = ( + 1 + / (torch.sum(cond_probs)) + * torch.sum(torch.mul(cond_probs, kernel_funcs)) + ) + self.var_output["variational"][step].update( + {"semantic_density": semantic_density.item(), "cosine_similarity": sims} + ) + def clean_inference(self, x: Union[np.ndarray, bytes, str]): """ @@ -161,48 +224,50 @@ def clean_inference(self, x: Union[np.ndarray, bytes, str]): summarised transcript with associated unvertainties at each step """ - output = {step: {} for step in self.pipeline_map.keys()} + self.clean_output = {step: {} for step in self.pipeline_map.keys()} # transcription transcription = self.transcribe(x) - output["transcription"].update(transcription) + self.clean_output["transcription"].update(transcription) # translation translation = self.translate(transcription["outputs"]) - output["translation"].update(translation) + self.clean_output["translation"].update(translation) # summarisation summarisation = self.summarise(translation["outputs"]) - output["summarisation"].update(summarisation) + self.clean_output["summarisation"].update(summarisation) - return output - - def variational_inference(self, x, n_runs=5): + def variational_inference(self, x): # we need clean inputs to pass to each step, we run that first - output = {"clean": {}, "variational": {}} - output["clean"] = self.clean_inference(x) + self.var_output = {"clean": {}, "variational": {}} + self.clean_inference(x) + self.var_output["clean"] = self.clean_output # each step accepts a different input from the clean pipeline input_map = { "transcription": x, - "translation": output["clean"]["transcription"]["outputs"], - "summarisation": output["clean"]["translation"]["outputs"], + "translation": self.var_output["clean"]["transcription"]["outputs"], + "summarisation": self.var_output["clean"]["translation"]["outputs"], } # for each model in pipeline for model_key, pl in self.pipeline_map.items(): # turn on dropout for this model set_dropout(model=pl.model, dropout_flag=True) # create the output list - output["variational"][model_key] = [None] * n_runs + self.var_output["variational"][model_key] = [None] * self.n_variational_runs # do n runs of the inference - for run_idx in range(n_runs): - output["variational"][model_key][run_idx] = self.func_map[model_key]( - input_map[model_key] - ) + for run_idx in range(self.n_variational_runs): + self.var_output["variational"][model_key][run_idx] = self.func_map[ + model_key + ](input_map[model_key]) # turn off dropout for this model set_dropout(model=pl.model, dropout_flag=False) - return output + + self.collect_metrics() + self.calculate_semantic_density() def __call__(self, x): - return self.clean_inference(x) + self.clean_inference(x) + return self.clean_output class CustomSpeechRecognitionPipeline(AutomaticSpeechRecognitionPipeline):