[WIP] RNN-T + MBR training. #593

Closed
wants to merge 7 commits into from

Conversation

@pkufool (Collaborator) commented Sep 29, 2022

This PR depends on k2-fsa/k2#1057 in k2.

@pkufool requested a review from yaozengwei on December 8, 2022
@pkufool (Collaborator, Author) commented Dec 8, 2022

The model structure is like the diagram below. It has two joiners: one is the regular RNN-T joiner; the other is a quasi-joiner that produces the expected WER. To make the quasi-joiner work well, we feed it an enhanced embedding instead of the raw encoder output. The embedding enhancer is a model with self-attention over masked_encoder_output and cross-attention over the text_embedding produced by a transformer LM; a sketch of how the pieces fit together is given after the diagram.

[diagram: model structure with the two joiners and the embedding enhancer]
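
A minimal sketch of how these pieces could be wired together. Everything below (class name, signatures, shapes, the mask proportion) is an illustrative assumption, not code from this PR:

import torch
import torch.nn as nn

# Illustrative only: module names, signatures and shapes are assumptions, not the PR's code.
class MbrTransducerSketch(nn.Module):
    def __init__(self, encoder, decoder, joiner, quasi_joiner, enhancer, text_lm,
                 mask_proportion: float = 0.25):
        super().__init__()
        self.encoder = encoder            # acoustic encoder
        self.decoder = decoder            # RNN-T prediction network
        self.joiner = joiner              # regular RNN-T joiner
        self.quasi_joiner = quasi_joiner  # predicts the expected WER
        self.enhancer = enhancer          # embedding enhancer (a TransformerDecoder)
        self.text_lm = text_lm            # transformer LM producing text_embedding
        self.mask_proportion = mask_proportion

    def forward(self, x, x_lens, y):
        encoder_out, encoder_out_lens = self.encoder(x, x_lens)
        decoder_out = self.decoder(y)

        # Regular RNN-T branch.
        logits = self.joiner(encoder_out, decoder_out)

        # Quasi-joiner branch: the enhancer combines a randomly masked encoder output
        # with the LM's text_embedding to form the enhanced embedding.
        noise = torch.randn_like(encoder_out)
        masked_encoder_out = encoder_out.masked_fill(noise <= self.mask_proportion, 0.0)
        text_embedding = self.text_lm(y)
        enhanced = self.enhancer(masked_encoder_out, text_embedding)
        expected_wer = self.quasi_joiner(enhanced, decoder_out)
        return logits, expected_wer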


self.encoder_output_layer = ScaledLinear(
d_model, num_classes, bias=True
)
@pkufool (Collaborator, Author):

The transformer LM is actually an embedding layer plus a TransformerEncoder that encodes the symbols into text_embedding; a minimal sketch follows.
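
A minimal sketch of such a text-embedding LM using standard torch.nn modules; the dimensions, head count and layer count are illustrative assumptions, not the PR's values:

import torch
import torch.nn as nn

# Sketch of the transformer LM described above: an embedding layer followed by a
# TransformerEncoder. All sizes here are illustrative assumptions.
class TextEmbeddingLM(nn.Module):
    def __init__(self, vocab_size: int, d_model: int = 512, num_layers: int = 6):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=8, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (N, U) int64 -> text_embedding: (N, U, d_model)
        return self.encoder(self.embed(tokens))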

dropout=dropout,
layer_dropout=layer_dropout,
)
self.enhancer = TransformerDecoder(decoder_layer, num_layers)
@pkufool (Collaborator, Author):

The EmbeddingEnhancer is a TransformerDecoder whose self-attention runs over masked_encoder_output and whose cross-attention attends to text_embedding; see the call-pattern sketch below.
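
For intuition, here is the same call pattern with torch.nn.TransformerDecoder (the PR uses its own TransformerDecoder class; the sizes and shapes below are assumptions):

import torch
import torch.nn as nn

# Self-attention runs over masked_encoder_output (tgt); cross-attention attends to
# text_embedding (memory). Sizes are illustrative.
d_model, nhead, num_layers = 512, 8, 6
decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True)
enhancer = nn.TransformerDecoder(decoder_layer, num_layers)

masked_encoder_output = torch.randn(2, 100, d_model)  # (N, T, C) masked acoustic frames
text_embedding = torch.randn(2, 20, d_model)          # (N, U, C) from the transformer LM

enhanced_embedding = enhancer(tgt=masked_encoder_output, memory=text_embedding)
print(enhanced_embedding.shape)  # torch.Size([2, 100, 512])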

N, T, C = embedding.shape
mask = torch.randn((N, T, C), device=embedding.device)
mask = mask > mask_proportion
masked_embedding = torch.masked_fill(embedding, ~mask, 0.0)
@pkufool (Collaborator, Author):

I randomly mask the encoder output here.
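
For reference, a self-contained version of this masking step with illustrative shapes. Note that because the noise is standard-normal, the fraction of zeroed elements is the normal CDF at mask_proportion rather than mask_proportion itself:

import torch

mask_proportion = 0.25
embedding = torch.randn(2, 100, 512)        # (N, T, C) encoder output, illustrative

noise = torch.randn_like(embedding)
mask = noise > mask_proportion              # True = keep this element
masked_embedding = embedding.masked_fill(~mask, 0.0)

zeroed_fraction = (~mask).float().mean()    # roughly Phi(0.25) ≈ 0.60 here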

)
return init_context

def delta_wer(
@pkufool (Collaborator, Author):

This function implements the sampling process.
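
The sampling itself relies on the k2 dependency mentioned above. Purely for intuition, here is a toy sketch of a delta-WER quantity: how much worse a sampled hypothesis is than a baseline hypothesis, measured by edit distance to the reference. None of this is the PR's delta_wer code:

def edit_distance(a, b):
    # Standard Levenshtein distance (dynamic programming over a single row).
    dp = list(range(len(b) + 1))
    for i, x in enumerate(a, 1):
        prev, dp[0] = dp[0], i
        for j, y in enumerate(b, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (x != y))
    return dp[-1]

def toy_delta_wer(sampled_hyp, baseline_hyp, reference):
    # Positive value: the sampled hypothesis is worse than the baseline one.
    ref_len = max(len(reference), 1)
    return (edit_distance(sampled_hyp, reference)
            - edit_distance(baseline_hyp, reference)) / ref_len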

+ l2_loss_scale * l2_loss
+ delta_wer_scale * delta_wer_loss
+ predictor_loss_scale * predictor_loss
)
@pkufool (Collaborator, Author):

The losses are combined here.
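
A sketch of how the visible scale * loss terms add up into a single training loss; the regular RNN-T terms and the default scale values below are assumptions based on common icefall conventions, not values from this PR:

def combine_losses(simple_loss, pruned_loss, l2_loss, delta_wer_loss, predictor_loss,
                   simple_loss_scale=0.5, l2_loss_scale=1.0,
                   delta_wer_scale=1.0, predictor_loss_scale=1.0):
    # The last three terms correspond to the lines shown in the hunk above.
    return (
        simple_loss_scale * simple_loss
        + pruned_loss
        + l2_loss_scale * l2_loss
        + delta_wer_scale * delta_wer_loss
        + predictor_loss_scale * predictor_loss
    )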

@pkufool (Collaborator, Author) commented Dec 8, 2022

@danpovey @yaozengwei @glynpu Would you please have a look at this? If anything is unclear, please let me know. Thanks!

@yaozengwei (Collaborator)

Sure. I will have a look.

@pkufool closed this Nov 21, 2024