[fix] Handle router logits in Qwen 3 moe and Qwen 3 omni moe for aux loss #98
base: main
Changes to the first patched file (`model_forward` and `decoder_layer_forward`):
@@ -46,6 +46,7 @@ def model_forward(
     cache_position: Optional[torch.LongTensor] = None,
     cu_seq_lens: Optional[torch.IntTensor] = None,
     indices: Optional[torch.IntTensor] = None,
+    output_router_logits: Optional[bool] = None,
     **kwargs,
 ) -> MoeModelOutputWithPast:
     if (input_ids is None) ^ (inputs_embeds is not None):
@@ -87,8 +88,15 @@ def model_forward(
     # create position embeddings to be shared across the decoder layers
     position_embeddings = self.rotary_emb(hidden_states, position_ids)

+    output_router_logits = (
+        output_router_logits
+        if output_router_logits is not None
+        else getattr(self.config, "output_router_logits", False)
+    )
+    all_router_logits = () if output_router_logits else None
+
Comment on lines 91 to 95
Collaborator: By default, output_router_logits is always False? See https://huggingface.co/Qwen/Qwen3-30B-A3B-Instruct-2507/blob/main/config.json. Do you think we should change it to True for training, since this patch will only be used in training mode?
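For illustration, the reviewer's suggestion could look roughly like the sketch below: default the flag from the model's training mode instead of the config value alone. This is only a sketch of the idea, not code from this PR; the helper name and argument layout are made up.

    # Hypothetical helper sketching the reviewer's suggestion: emit router logits
    # whenever the model is training, unless the caller overrides the flag.
    def resolve_output_router_logits(explicit_flag, config, training):
        if explicit_flag is not None:
            return explicit_flag
        return training or getattr(config, "output_router_logits", False)

    # e.g. inside model_forward:
    #   output_router_logits = resolve_output_router_logits(
    #       output_router_logits, self.config, self.training
    #   )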
     for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-        hidden_states = decoder_layer(
+        layer_outputs = decoder_layer(
             hidden_states,
             position_embeddings=position_embeddings,
             attention_mask=attention_mask,
@@ -98,16 +106,25 @@ def model_forward(
             cache_position=cache_position,
             cu_seq_lens=cu_seq_lens,
             indices=indices,
+            output_router_logits=output_router_logits,
             **kwargs,
         )

+        if isinstance(layer_outputs, tuple):
+            hidden_states, router_logits = layer_outputs
+            if output_router_logits and router_logits is not None:
+                all_router_logits += (router_logits,)
+        else:
+            hidden_states = layer_outputs
+
     hidden_states = self.norm(hidden_states)

     return BaseModelOutputWithPastAndRmpad(
         last_hidden_state=hidden_states,
         past_key_values=past_key_values if use_cache else None,
         seq_lens=cu_seq_lens,
         word_idx=indices,
+        router_logits=all_router_logits if output_router_logits else None,
     )
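For context on the "aux loss" in the PR title: the `router_logits` tuple returned here is what an auxiliary load-balancing loss is computed from. The standalone sketch below shows the usual Switch-Transformer-style reduction; the function name and exact formulation are illustrative assumptions, not code from this repository.

    import torch
    import torch.nn.functional as F

    def sketch_load_balancing_loss(router_logits, num_experts, top_k):
        # router_logits: tuple of [num_tokens, num_experts] tensors, one per MoE
        # layer, as collected into `all_router_logits` above.
        logits = torch.cat(router_logits, dim=0)
        probs = F.softmax(logits, dim=-1)
        _, selected_experts = torch.topk(probs, top_k, dim=-1)
        expert_mask = F.one_hot(selected_experts, num_experts).float()
        tokens_per_expert = expert_mask.mean(dim=(0, 1))    # routing fraction per expert
        router_prob_per_expert = probs.mean(dim=0)          # mean gate probability per expert
        return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)

A training loop would typically scale this by a coefficient such as `config.router_aux_loss_coef` and add it to the language-modeling loss.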
@@ -121,6 +138,7 @@ def decoder_layer_forward(
     cache_position: Optional[torch.LongTensor] = None,
     cu_seq_lens: Optional[torch.IntTensor] = None,
     indices: Optional[torch.IntTensor] = None,
+    output_router_logits: bool = False,
     **kwargs,
 ) -> torch.FloatTensor:
     """
@@ -170,14 +188,20 @@ def decoder_layer_forward(
     # Unsqueeze to unpack shape for the MoE sparse layer
     hidden_states = hidden_states.unsqueeze(0)
     hidden_states = self.post_attention_layernorm(hidden_states)
-    hidden_states = self.mlp(hidden_states)
-    # For the MoE layers, we need to unpack
-    if isinstance(hidden_states, tuple):
-        hidden_states, _ = hidden_states
+    mlp_output = self.mlp(hidden_states)
+
+    router_logits = None
+    if isinstance(mlp_output, tuple):
+        hidden_states, router_logits = mlp_output
+    else:
+        hidden_states = mlp_output

     # Squeeze to pack shape for later
     hidden_states = hidden_states.squeeze(0)
     hidden_states = residual + hidden_states

+    if output_router_logits and router_logits is not None:
+        return hidden_states, router_logits
     return hidden_states
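The `isinstance(mlp_output, tuple)` check is needed because a dense MLP returns a plain tensor while a sparse MoE block returns `(hidden_states, router_logits)`. A toy illustration of the two return conventions (both classes below are made up for the example):

    import torch
    import torch.nn as nn

    class ToyDenseMLP(nn.Module):
        def forward(self, x):
            return x  # dense layers hand back a single tensor

    class ToySparseMoE(nn.Module):
        def __init__(self, hidden=8, num_experts=4):
            super().__init__()
            self.gate = nn.Linear(hidden, num_experts)

        def forward(self, x):
            return x, self.gate(x)  # MoE layers also return the router logits

    x = torch.randn(2, 8)
    for mlp in (ToyDenseMLP(), ToySparseMoE()):
        out = mlp(x)
        hidden_states, router_logits = out if isinstance(out, tuple) else (out, None)
        print(type(mlp).__name__, "router logits?", router_logits is not None)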
Changes to the second patched file (`text_model_forward` and its `decoder_layer_forward`):
@@ -98,6 +98,7 @@ def text_model_forward(
     use_cache: Optional[bool] = None,
     output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
+    output_router_logits: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     cu_seq_lens: Optional[torch.IntTensor] = None,
@@ -108,6 +109,11 @@ def text_model_forward(
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    output_router_logits = (
+        output_router_logits
+        if output_router_logits is not None
+        else getattr(self.config, "output_router_logits", False)
+    )
Comment on lines 112 to 114
Collaborator: Same problem here. Also, I checked the config for Qwen3VLMoe; it seems this problem also exists in Qwen3 VL MoE. Do you think we should simply set this to True in training mode?
Collaborator (Author): Sure, will fix this.
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -153,13 +159,14 @@ def text_model_forward(
     position_embeddings = self.rotary_emb(hidden_states, position_ids)
     all_hidden_states = () if output_hidden_states else None
     all_attentions = () if output_attentions else None
+    all_router_logits = () if output_router_logits else None

     for decoder_layer in self.layers:
         if output_hidden_states:
             all_hidden_states += (hidden_states,)

         if self.gradient_checkpointing and self.training:
-            hidden_states = torch.utils.checkpoint.checkpoint(
+            layer_outputs = torch.utils.checkpoint.checkpoint(
                 decoder_layer.__call__,
                 hidden_states,
                 position_embeddings,
@@ -169,10 +176,11 @@ def text_model_forward(
                 cache_position,
                 cu_seq_lens,
                 indices,
+                output_router_logits,
                 use_reentrant=False,
             )
         else:
-            hidden_states = decoder_layer(
+            layer_outputs = decoder_layer(
                 hidden_states,
                 position_embeddings=position_embeddings,
                 attention_mask=attention_mask,
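One subtlety in the checkpointed branch: `torch.utils.checkpoint.checkpoint` forwards extra arguments to the wrapped callable positionally, so `output_router_logits` must be appended after `indices` in exactly the order `decoder_layer_forward` declares its parameters. A minimal standalone illustration (the toy layer below is made up):

    import torch
    from torch.utils.checkpoint import checkpoint

    def toy_layer(hidden_states, scale, output_router_logits=False):
        out = hidden_states * scale
        return (out, out.sum()) if output_router_logits else out

    x = torch.randn(4, 8, requires_grad=True)
    # The third positional argument lands in `output_router_logits`,
    # mirroring how the patch threads the flag through checkpoint().
    hidden_states, router_logits = checkpoint(toy_layer, x, 2.0, True, use_reentrant=False)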
@@ -181,9 +189,17 @@ def text_model_forward(
                 cache_position=cache_position,
                 cu_seq_lens=cu_seq_lens,
                 indices=indices,
+                output_router_logits=output_router_logits,
                 **kwargs,
             )

+        if isinstance(layer_outputs, tuple):
+            hidden_states, router_logits = layer_outputs
+            if output_router_logits and router_logits is not None:
+                all_router_logits += (router_logits,)
+        else:
+            hidden_states = layer_outputs
+
     hidden_states = self.norm(hidden_states)

     if output_hidden_states:
@@ -199,6 +215,7 @@ def text_model_forward(
         attentions=all_attentions,
         seq_lens=cu_seq_lens,
         word_idx=indices,
+        router_logits=all_router_logits if output_router_logits else None,
     )
@@ -212,6 +229,7 @@ def decoder_layer_forward(
     cache_position: Optional[torch.LongTensor] = None,
     cu_seq_lens: Optional[torch.IntTensor] = None,
     indices: Optional[torch.IntTensor] = None,
+    output_router_logits: bool = False,
     **kwargs,
 ) -> torch.FloatTensor:
     residual = hidden_states
@@ -233,14 +251,20 @@ def decoder_layer_forward(
     residual = hidden_states
     hidden_states = hidden_states.unsqueeze(0)
     hidden_states = self.post_attention_layernorm(hidden_states)
-    hidden_states = self.mlp(hidden_states)
-    # For the MoE layers, we need to unpack
-    if isinstance(hidden_states, tuple):
-        hidden_states, _ = hidden_states
+    mlp_output = self.mlp(hidden_states)
+
+    router_logits = None
+    if isinstance(mlp_output, tuple):
+        hidden_states, router_logits = mlp_output
+    else:
+        hidden_states = mlp_output

     # Squeeze to pack shape for later
     hidden_states = hidden_states.squeeze(0)
     hidden_states = residual + hidden_states

+    if output_router_logits and router_logits is not None:
+        return hidden_states, router_logits
     return hidden_states