diff --git a/i6_models/parts/factored_hybrid/__init__.py b/i6_models/parts/factored_hybrid/__init__.py
index 7184ae66..7132c445 100644
--- a/i6_models/parts/factored_hybrid/__init__.py
+++ b/i6_models/parts/factored_hybrid/__init__.py
@@ -3,8 +3,11 @@
     "FactoredDiphoneBlockV1",
     "FactoredDiphoneBlockV2Config",
     "FactoredDiphoneBlockV2",
+    "FactoredTriphoneBlockV1Config",
+    "FactoredTriphoneBlockV1",
     "BoundaryClassV1",
 ]
 
 from .diphone import *
+from .triphone import *
 from .util import BoundaryClassV1
diff --git a/i6_models/parts/factored_hybrid/diphone.py b/i6_models/parts/factored_hybrid/diphone.py
index 18895748..e3f869c8 100644
--- a/i6_models/parts/factored_hybrid/diphone.py
+++ b/i6_models/parts/factored_hybrid/diphone.py
@@ -123,7 +123,13 @@
         return logits_center, logits_left, contexts_embedded_left
 
     def forward_joint(self, features: Tensor) -> Tensor:
+        """See `forward_joint_diphone`."""
+        return self.forward_joint_diphone(features)
+
+    def forward_joint_diphone(self, features: Tensor) -> Tensor:
         """
+        Computes log p(c,l|h(x)), i.e. forwards the network for the full diphone joint.
+
         :param features: Main encoder output. shape B, T, F. F=num_inputs
         :return: log probabilities for p(c,l|x).
         """
diff --git a/i6_models/parts/factored_hybrid/triphone.py b/i6_models/parts/factored_hybrid/triphone.py
new file mode 100644
index 00000000..a44c1cc2
--- /dev/null
+++ b/i6_models/parts/factored_hybrid/triphone.py
@@ -0,0 +1,79 @@
+__all__ = [
+    "FactoredTriphoneBlockV1Config",
+    "FactoredTriphoneBlockV1",
+]
+
+from dataclasses import dataclass
+from typing import Tuple
+
+import torch
+from torch import nn, Tensor
+
+from .diphone import FactoredDiphoneBlockV1, FactoredDiphoneBlockV2Config
+from .util import get_mlp
+
+
+@dataclass
+class FactoredTriphoneBlockV1Config(FactoredDiphoneBlockV2Config):
+    """
+    Attributes:
+        Same as the FactoredDiphoneBlockV2Config.
+    """
+
+
+class FactoredTriphoneBlockV1(FactoredDiphoneBlockV1):
+    """
+    Triphone FH model output block.
+
+    Consumes the output h(x) of a main encoder model and computes factored logits/probabilities
+    for p(c|l,h(x)), p(l|h(x)) and p(r|c,l,h(x)).
+    """
+
+    def __init__(self, cfg: FactoredTriphoneBlockV1Config):
+        super().__init__(cfg)
+
+        self.center_state_embedding = nn.Embedding(self.num_center, cfg.center_state_embedding_dim)
+        self.right_context_encoder = get_mlp(
+            num_input=cfg.num_inputs + cfg.center_state_embedding_dim + cfg.left_context_embedding_dim,
+            num_output=self.num_contexts,
+            hidden_dim=cfg.context_mix_mlp_dim,
+            num_layers=cfg.context_mix_mlp_num_layers,
+            dropout=cfg.dropout,
+            activation=cfg.activation,
+        )
+
+    # update type definitions
+    def forward(self, *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+        return super().forward(*args, **kwargs)
+
+    def forward_factored(
+        self,
+        features: Tensor,  # B, T, F
+        contexts_left: Tensor,  # B, T
+        contexts_center: Tensor,  # B, T
+    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+        """
+        :param features: Main encoder output. shape B, T, F. F=num_inputs
+        :param contexts_left: The left contexts used to compute p(c|l,x), shape B, T.
+            Valid values range from [0, num_contexts).
+        :param contexts_center: The center states used to compute p(r|c,l,x), shape B, T.
+            Given that the center state also contains the word-end class and HMM state ID, the valid values
+            range from [0, num_center_states), where num_center_states >= num_contexts.
+        :return: tuple of logits for p(c|l,x), p(l|x), p(r|c,l,x) and the embedded left context and center state values.
+        """
+
+        logits_center, logits_left, contexts_left_embedded = super().forward_factored(features, contexts_left)
+
+        # This logic is very similar to FactoredDiphoneBlockV2.forward, but not the same.
+        # This class computes `p(r|c,l,h(x))` while FactoredDiphoneBlockV2 computes `p(r|c,h(x))`.
+        center_states_embedded = self.center_state_embedding(contexts_center)  # B, T, E'
+        features_right = torch.cat((features, center_states_embedded, contexts_left_embedded), -1)  # B, T, F+E'+E
+        logits_right = self.right_context_encoder(features_right)  # B, T, C
+
+        return logits_center, logits_left, logits_right, contexts_left_embedded, center_states_embedded
+
+    def forward_joint(self, features: Tensor) -> Tensor:
+        raise NotImplementedError(
+            "It is computationally infeasible to forward the full triphone joint, "
+            "only the diphone joint can be computed via forward_joint_diphone."
+        )
diff --git a/tests/test_fh.py b/tests/test_fh.py
index 2f034267..87109ca7 100644
--- a/tests/test_fh.py
+++ b/tests/test_fh.py
@@ -9,6 +9,8 @@
     FactoredDiphoneBlockV1Config,
     FactoredDiphoneBlockV2,
     FactoredDiphoneBlockV2Config,
+    FactoredTriphoneBlockV1,
+    FactoredTriphoneBlockV1Config,
 )
 from i6_models.parts.factored_hybrid.util import get_center_dim
 
@@ -96,3 +98,55 @@
         assert output_right.shape == (b, t, n_ctx)
         cdim = get_center_dim(n_ctx, states_per_ph, we_class)
         assert output_center.shape == (b, t, cdim)
+
+
+def test_tri_output_shape_and_norm():
+    n_ctx = 42
+    n_in = 32
+
+    for we_class, states_per_ph in product(
+        [BoundaryClassV1.none, BoundaryClassV1.word_end, BoundaryClassV1.boundary],
+        [1, 3],
+    ):
+        tri_block = FactoredTriphoneBlockV1(
+            FactoredTriphoneBlockV1Config(
+                activation=nn.ReLU,
+                context_mix_mlp_dim=64,
+                context_mix_mlp_num_layers=2,
+                dropout=0.1,
+                left_context_embedding_dim=32,
+                center_state_embedding_dim=128,
+                num_contexts=n_ctx,
+                num_hmm_states_per_phone=states_per_ph,
+                num_inputs=n_in,
+                boundary_class=we_class,
+            )
+        )
+
+        for b, t in product([10, 50, 100], [10, 50, 100]):
+            cdim = get_center_dim(n_ctx, states_per_ph, we_class)
+            contexts_left = torch.randint(0, n_ctx, (b, t))
+            contexts_center = torch.randint(0, tri_block.num_center, (b, t))
+            encoder_output = torch.rand((b, t, n_in))
+            output_center, output_left, output_right, _, _ = tri_block(
+                features=encoder_output, contexts_left=contexts_left, contexts_center=contexts_center
+            )
+            assert output_left.shape == (b, t, n_ctx)
+            assert output_center.shape == (b, t, cdim)
+            assert output_right.shape == (b, t, n_ctx)
+
+            try:
+                tri_block.forward_joint(encoder_output)
+            except NotImplementedError:
+                pass
+            else:
+                assert False, "expected Error, did not get any"
+
+            encoder_output = torch.rand((b, t, n_in))
+            output = tri_block.forward_joint_diphone(features=encoder_output)
+            cdim = get_center_dim(n_ctx, states_per_ph, we_class)
+            assert output.shape == (b, t, cdim * n_ctx)
+            output_p = torch.exp(output)
+            ones_hopefully = torch.sum(output_p, dim=-1)
+            close_to_one = torch.abs(1 - ones_hopefully).flatten() < 1e-3
+            assert all(close_to_one)
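
For reference, a minimal usage sketch of the new FactoredTriphoneBlockV1 follows. It mirrors the values used in test_tri_output_shape_and_norm; the concrete hyperparameters, boundary class choice and tensor sizes are illustrative assumptions, not recommended settings.

import torch
from torch import nn

from i6_models.parts.factored_hybrid import (
    BoundaryClassV1,
    FactoredTriphoneBlockV1,
    FactoredTriphoneBlockV1Config,
)

# Configure the triphone output block; the config fields are the same as FactoredDiphoneBlockV2Config.
# All values below are placeholders taken from the test, not tuned settings.
cfg = FactoredTriphoneBlockV1Config(
    activation=nn.ReLU,
    context_mix_mlp_dim=64,
    context_mix_mlp_num_layers=2,
    dropout=0.1,
    left_context_embedding_dim=32,
    center_state_embedding_dim=128,
    num_contexts=42,
    num_hmm_states_per_phone=3,
    num_inputs=32,
    boundary_class=BoundaryClassV1.word_end,
)
tri_block = FactoredTriphoneBlockV1(cfg)

b, t = 10, 50
features = torch.rand((b, t, cfg.num_inputs))  # h(x) from the main encoder, shape B, T, F
contexts_left = torch.randint(0, cfg.num_contexts, (b, t))  # left context labels, shape B, T
contexts_center = torch.randint(0, tri_block.num_center, (b, t))  # center state labels, shape B, T

# Factored forward pass: logits for p(c|l,x), p(l|x), p(r|c,l,x) plus the two embeddings.
logits_center, logits_left, logits_right, _, _ = tri_block(
    features=features, contexts_left=contexts_left, contexts_center=contexts_center
)

# Only the diphone joint log p(c,l|h(x)) can be scored directly; forward_joint raises NotImplementedError.
joint_diphone = tri_block.forward_joint_diphone(features=features)  # B, T, num_center * num_contexts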