From 4055cda9ae89ede82193a04d69f7005e350773e1 Mon Sep 17 00:00:00 2001
From: Zhen Yang
Date: Wed, 13 Aug 2025 17:40:53 +0800
Subject: [PATCH] fix several typos in
 megatron/core/transformer/multi_token_prediction.py

---
 megatron/core/transformer/multi_token_prediction.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py
index 6eb276ea36..8999b8e1b5 100755
--- a/megatron/core/transformer/multi_token_prediction.py
+++ b/megatron/core/transformer/multi_token_prediction.py
@@ -448,7 +448,7 @@ def __init__(
         )
 
         # For the linear projection at the (k - 1)-th MTP layer, the input is the concatenation
-        # of the i-th tocken's hidden states and the (i + K)-th tocken's decoder input,
+        # of the i-th token's hidden states and the (i + K)-th token's decoder input,
         # so the input's shape is [s, b, 2*h].
         # The output will be send to the following transformer layer,
         # so the output's shape should be [s, b, h].
@@ -499,7 +499,7 @@ def forward(
             decoder_input (Tensor): Input tensor of shape [s, b, h] where s is the
                 sequence length, b is the batch size, and h is the hidden size.
                 At the (k - 1)-th MTP module, the i-th element of decoder input is
-                the embedding of (i + K)-th tocken.
+                the embedding of (i + K)-th token.
             attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking
                 self-attention.
             context (Tensor, optional): Context tensor for cross-attention.
@@ -545,8 +545,8 @@ def forward(
         hidden_states = make_viewless_tensor(
             inp=hidden_states, requires_grad=True, keep_graph=True
         )
-        # At the (k - 1)-th MTP module, concatenates the i-th tocken's hidden_states
-        # and the (i + K)-th tocken's embedding, and combine them with linear projection.
+        # At the (k - 1)-th MTP module, concatenates the i-th token's hidden_states
+        # and the (i + K)-th token's embedding, and combine them with linear projection.
         hidden_states = torch.cat((decoder_input, hidden_states), -1)
         hidden_states, _ = self.eh_proj(hidden_states)
         # For tensor parallel we need to gather the tensor across the model-parallel
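
A minimal, self-contained sketch of the shape transformation the touched comments describe,
for reviewers: the i-th token's hidden states are concatenated with the (i + K)-th token's
embedding along the hidden dimension and projected from [s, b, 2*h] back to [s, b, h]. The
toy sizes and the plain nn.Linear standing in for Megatron's tensor-parallel eh_proj are
assumptions for illustration only.

    import torch
    import torch.nn as nn

    s, b, h = 4, 2, 8  # toy sequence length, batch size, hidden size

    # hidden_states: hidden states of the i-th tokens, shape [s, b, h]
    # decoder_input: embeddings of the (i + K)-th tokens, shape [s, b, h]
    hidden_states = torch.randn(s, b, h)
    decoder_input = torch.randn(s, b, h)

    # Concatenate along the hidden dimension -> [s, b, 2*h]
    combined = torch.cat((decoder_input, hidden_states), dim=-1)

    # Project back to [s, b, h]; a plain nn.Linear stands in here for the
    # tensor-parallel linear layer that eh_proj is in the actual module.
    eh_proj = nn.Linear(2 * h, h, bias=False)
    out = eh_proj(combined)
    assert out.shape == (s, b, h)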