19 | 19 | from functools import partial
20 | 20 | from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
21 | 21 |
| 22 | +import matplotlib.pyplot as plt
| 23 | +import numpy as np
22 | 24 | import torch
23 | 25 | import torch.nn as nn
| 26 | +from tokenizer import TextTokenCollater
24 | 27 | from torch import Tensor
25 | 28 | from torch.nn import Linear, Module
26 | 29 | from torch.nn import functional as F
@@ -1658,6 +1661,88 @@ def continual(
1658 | 1661 |         assert len(codes) == 8
1659 | 1662 |         return torch.stack(codes, dim=-1)
1660 | 1663 |
| 1664 | +    def visualize(
| 1665 | +        self,
| 1666 | +        predicts: Tuple[torch.Tensor, torch.Tensor],
| 1667 | +        batch: Dict[str, Union[List, torch.Tensor]],
| 1668 | +        tokenizer: TextTokenCollater,
| 1669 | +        output_dir: str,
| 1670 | +        limit: int = 4,
| 1671 | +    ) -> None:
| 1672 | +        audio_features = batch["features"].to("cpu").detach().numpy()
| 1673 | +        audio_features_lens = batch["features_lens"].to("cpu").detach().numpy()
| 1674 | +
| 1675 | +        tokens = batch["tokens"]
| 1676 | +        text_tokens, text_tokens_lens = tokenizer(tokens)
| 1677 | +        assert text_tokens.ndim == 2
| 1678 | +
| 1679 | +        texts = batch["text"]
| 1680 | +        utt_ids = [cut.id for cut in batch["cut"]]
| 1681 | +
| 1682 | +        encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy()
| 1683 | +        decoder_outputs = predicts[1]
| 1684 | +        if isinstance(decoder_outputs, list):
| 1685 | +            decoder_outputs = decoder_outputs[-1]
| 1686 | +        decoder_outputs = decoder_outputs.to("cpu").detach()
| 1687 | +        # Pick the color range from the dtype *before* casting to float32:
| 1688 | +        # integer tensors hold Encodec codes, float tensors hold Fbank features.
| 1689 | +        vmin, vmax = (-6, 0) if decoder_outputs.dtype.is_floating_point else (0, 1024)
| 1690 | +        decoder_outputs = decoder_outputs.type(torch.float32).numpy()
| 1691 | +
| 1692 | +        num_figures = 3
| 1693 | +        for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])):
| 1694 | +            _ = plt.figure(figsize=(14, 8 * num_figures))
| 1695 | +
| 1696 | +            S = text_tokens_lens[b]
| 1697 | +            T = audio_features_lens[b]
| 1698 | +
| 1699 | +            # Panel 1: encoder output; the red line marks the text-token boundary S.
| 1700 | +            plt.subplot(num_figures, 1, 1)
| 1701 | +            plt.title(f"Text: {text}")
| 1702 | +            plt.imshow(
| 1703 | +                X=np.transpose(encoder_outputs[b]),
| 1704 | +                cmap=plt.get_cmap("jet"),
| 1705 | +                aspect="auto",
| 1706 | +                interpolation="nearest",
| 1707 | +            )
| 1708 | +            plt.gca().invert_yaxis()
| 1709 | +            plt.axvline(x=S - 0.4, linewidth=2, color="r")
| 1710 | +            plt.xlabel("Encoder Output")
| 1711 | +            plt.colorbar()
| 1712 | +
| 1713 | +            # Panel 2: decoder prediction; the red line marks the feature length T.
| 1714 | +            plt.subplot(num_figures, 1, 2)
| 1715 | +            plt.imshow(
| 1716 | +                X=np.transpose(decoder_outputs[b]),
| 1717 | +                cmap=plt.get_cmap("jet"),
| 1718 | +                aspect="auto",
| 1719 | +                interpolation="nearest",
| 1720 | +                vmin=vmin,
| 1721 | +                vmax=vmax,
| 1722 | +            )
| 1723 | +            plt.gca().invert_yaxis()
| 1724 | +            plt.axvline(x=T - 0.4, linewidth=2, color="r")
| 1725 | +            plt.xlabel("Decoder Output")
| 1726 | +            plt.colorbar()
| 1727 | +
| 1728 | +            # Panel 3: ground-truth features, on the same color scale as panel 2.
| 1729 | +            plt.subplot(num_figures, 1, 3)
| 1730 | +            plt.imshow(
| 1731 | +                X=np.transpose(audio_features[b]),
| 1732 | +                cmap=plt.get_cmap("jet"),
| 1733 | +                aspect="auto",
| 1734 | +                interpolation="nearest",
| 1735 | +                vmin=vmin,
| 1736 | +                vmax=vmax,
| 1737 | +            )
| 1738 | +            plt.gca().invert_yaxis()
| 1739 | +            plt.axvline(x=T - 0.4, linewidth=2, color="r")
| 1740 | +            plt.xlabel("Decoder Target")
| 1741 | +            plt.colorbar()
| 1742 | +
| 1743 | +            plt.savefig(f"{output_dir}/{utt_id}.png")
| 1744 | +            plt.close()
| 1745 | +
1661 | 1746 |
1662 | 1747 | # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
1663 | 1748 | def top_k_top_p_filtering(
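
For context, here is a minimal sketch of how the new `visualize` method might be driven from a validation step. It assumes `predicts` is the `(encoder_outputs, decoder_outputs)` pair captured from the model's forward pass, and that `batch` carries the `features`, `features_lens`, `tokens`, `text`, and `cut` fields read above; the helper name and directory handling are illustrative, not part of this commit.

    import os

    def dump_visualizations(model, predicts, batch, tokenizer, out_dir, limit=4):
        # Hypothetical driver: `predicts` must already hold the
        # (encoder_outputs, decoder_outputs) tuple from the forward pass.
        os.makedirs(out_dir, exist_ok=True)  # visualize() writes {out_dir}/{utt_id}.png
        model.visualize(predicts, batch, tokenizer, out_dir, limit=limit)

Capping `limit` keeps plotting cost bounded, since every utterance produces a three-panel figure with colorbars.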