@@ -1,11 +1,18 @@
+import copy
 import logging
+import math
 from abc import ABC, abstractmethod
 from functools import partial
 from typing import Any

 import torch
 from torch.nn.functional import softmax
-from transformers import AutoModelForSequenceClassification, AutoTokenizer, Pipeline
+from transformers import (
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    Pipeline,
+    TranslationPipeline,
+)

 logger = logging.Logger("RTC_variational_pipeline")
@@ -117,6 +124,7 @@ def __init__(self, n_variational_runs=5, translation_batch_size=8):
                 "full_output",
                 "outputs",
                 "probs",
+                "mean_entropy",
             ],
             "classification": [
                 "scores",
@@ -264,6 +272,9 @@ def translate(self, text: str) -> dict[str, torch.Tensor | str]:
             {
                 "outputs": translator_output["translation_text"],
                 "probs": translator_output["raw_outputs"]["scores"],
+                "mean_entropy": torch.mean(translator_output["raw_outputs"]["entropy"])
+                .detach()
+                .tolist(),
             }
             for translator_output in translator_outputs
         ]
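For reference, the `"entropy"` tensor read here is the per-token normalised entropy produced by `CustomTranslationPipeline._forward` (added at the end of this diff), so the mean collapses it to a single scalar per translation; calling `.tolist()` on a 0-dim tensor yields a plain Python float. A minimal sketch with made-up values:

```python
import torch

# hypothetical per-token normalised entropies for one generated sequence
entropy = torch.tensor([0.12, 0.05, 0.31, 0.08, 0.22, 0.02, 0.17])

# mirrors the expression in translate(): mean -> detach -> plain float
mean_entropy = torch.mean(entropy).detach().tolist()
print(type(mean_entropy).__name__, round(mean_entropy, 4))  # float 0.1386
```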
@@ -430,6 +441,7 @@ def translation_semantic_density(
             {
                 "semantic_densities": densities,
                 "weighted_semantic_density": weighted_average.item(),
+                "sequence_length": sequence_lengths,
             }
         )
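The surrounding implementation is not shown in this diff, but the names suggest `weighted_semantic_density` weights each sentence's density by its sequence length, which is why exposing `sequence_length` alongside it is useful. A hedged sketch of that kind of length-weighted mean, with illustrative numbers (not from the PR):

```python
import torch

# illustrative per-sentence densities and token counts
densities = torch.tensor([0.81, 0.64, 0.92])
sequence_lengths = torch.tensor([12.0, 30.0, 7.0])

# length-weighted mean: longer sentences contribute more to the average
weighted_average = (densities * sequence_lengths).sum() / sequence_lengths.sum()
print(round(weighted_average.item(), 4))  # 0.7216
```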
@@ -480,3 +492,63 @@ def get_classification_confidence(
             }
         )
         return var_output
+
+
+# Translation pipeline with additional functionality to save logits from fwd pass
+class CustomTranslationPipeline(TranslationPipeline):
+    """
+    Custom translation pipeline that returns the logits with the generated text.
+    Largely the same as the PyTorch version, with some additional arguments passed
+    to the `generate` method.
+    """
+
+    def postprocess(
+        self,
+        model_outputs: dict,
+        **postprocess_params,
+    ):
+        # model_outputs gets overwritten in the super().postprocess call;
+        # make a copy here so we retain the information we want
+        raw_out = copy.deepcopy(model_outputs)
+        processed = super().postprocess(model_outputs, **postprocess_params)
+
+        return {
+            "translation_text": processed[0]["translation_text"],
+            "raw_outputs": raw_out,
+        }
+
+    def _forward(self, model_inputs, **generate_kwargs):
+        if self.framework == "pt":
+            in_b, input_length = model_inputs["input_ids"].shape
+        elif self.framework == "tf":
+            raise NotImplementedError
+
+        self.check_inputs(
+            input_length,
+            generate_kwargs.get("min_length", self.model.config.min_length),
+            generate_kwargs.get("max_length", self.model.config.max_length),
+        )
+        out = self.model.generate(**model_inputs, **generate_kwargs)
+        output_ids = out["sequences"]
+        out_b = output_ids.shape[0]
+        if self.framework == "pt":
+            output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
+        elif self.framework == "tf":
+            raise NotImplementedError
+
+        # out["logits"] is a tuple with one element per generated token;
+        # each element is a tensor of shape (batch_size, vocab_size)
+        logits = torch.stack(out["logits"], dim=1)
+        # softmax over the vocab dimension gives per-token probabilities
+        softmax_logits = softmax(logits, dim=-1)
+        vocab_size = softmax_logits.shape[-1]
+        normalised_entropy = torch.distributions.Categorical(
+            probs=softmax_logits
+        ).entropy() / math.log(vocab_size)
+        max_token_scores = torch.max(softmax_logits, dim=-1).values
+
+        return {
+            "output_ids": output_ids,
+            "scores": max_token_scores,
+            "entropy": normalised_entropy,
+        }
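On the entropy calculation: dividing the Shannon entropy of each token distribution by `log(vocab_size)`, the entropy of a uniform distribution over the vocabulary, rescales it into [0, 1], which makes scores comparable across models with different vocabulary sizes. A small self-contained check of that property (toy shapes, not from the PR):

```python
import math
import torch
from torch.nn.functional import softmax

vocab_size = 8  # toy vocabulary for illustration
logits = torch.randn(2, 5, vocab_size)  # (batch, generated tokens, vocab)

probs = softmax(logits, dim=-1)
normalised = torch.distributions.Categorical(probs=probs).entropy() / math.log(vocab_size)

assert normalised.shape == (2, 5)  # one value per token position
assert bool((normalised >= 0).all()) and bool((normalised <= 1).all())

# a uniform distribution attains the maximum, so normalised entropy -> 1.0
uniform = torch.full((vocab_size,), 1.0 / vocab_size)
print(torch.distributions.Categorical(probs=uniform).entropy() / math.log(vocab_size))
```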
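Finally, a hedged sketch of how the new class might be exercised on its own. The model name and call-time kwargs are illustrative assumptions (the PR wires this class into the variational pipeline instead); `pipeline_class` and `output_logits` are standard Hugging Face parameters, but the exact invocation here is untested:

```python
from transformers import pipeline

# pipeline_class makes the HF factory build our subclass instead of the default
translator = pipeline(
    task="translation_fr_to_en",
    model="Helsinki-NLP/opus-mt-fr-en",  # illustrative model choice
    pipeline_class=CustomTranslationPipeline,
)

# generate() must be asked for a dict with per-token logits; otherwise
# _forward's out["sequences"] / out["logits"] lookups would fail
result = translator(
    "Le chat est assis sur le tapis.",
    return_dict_in_generate=True,
    output_logits=True,
)

print(result["translation_text"])
print(result["raw_outputs"]["entropy"])  # normalised per-token entropy
print(result["raw_outputs"]["scores"])   # max token probabilities
```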