diff --git a/src/lmms_engine/models/qwen3_moe/qwen3_moe_liger.py b/src/lmms_engine/models/qwen3_moe/qwen3_moe_liger.py
index 34beb04..fba6849 100644
--- a/src/lmms_engine/models/qwen3_moe/qwen3_moe_liger.py
+++ b/src/lmms_engine/models/qwen3_moe/qwen3_moe_liger.py
@@ -69,7 +69,7 @@ def lce_forward(
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_router_logits = (
-        output_router_logits if output_router_logits is not None else self.config.output_router_logits
+        output_router_logits if output_router_logits is not None else getattr(self.config, "output_router_logits", True)
     )
     output_hidden_states = (
@@ -144,12 +144,14 @@ def lce_forward(
         loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

     aux_loss = None
-    if output_router_logits:
+    router_logits = getattr(outputs, "router_logits", None)
+    if output_router_logits and router_logits is not None:
+        aux_loss_mask = None if use_rmpad else attention_mask
         aux_loss = load_balancing_loss_func(
-            outputs.router_logits,
+            router_logits,
             self.num_experts,
             self.num_experts_per_tok,
-            attention_mask,
+            aux_loss_mask,
         )
         if labels is not None:
             loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
@@ -161,5 +163,5 @@ def lce_forward(
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
-        router_logits=None,  # Current always None
+        router_logits=router_logits,
     )
diff --git a/src/lmms_engine/models/qwen3_moe/qwen3_moe_ops.py b/src/lmms_engine/models/qwen3_moe/qwen3_moe_ops.py
index 9c00f8f..d21b543 100644
--- a/src/lmms_engine/models/qwen3_moe/qwen3_moe_ops.py
+++ b/src/lmms_engine/models/qwen3_moe/qwen3_moe_ops.py
@@ -46,6 +46,7 @@ def model_forward(
     cache_position: Optional[torch.LongTensor] = None,
     cu_seq_lens: Optional[torch.IntTensor] = None,
     indices: Optional[torch.IntTensor] = None,
+    output_router_logits: Optional[bool] = None,
     **kwargs,
 ) -> MoeModelOutputWithPast:
     if (input_ids is None) ^ (inputs_embeds is not None):
@@ -87,8 +88,13 @@ def model_forward(
     # create position embeddings to be shared across the decoder layers
     position_embeddings = self.rotary_emb(hidden_states, position_ids)

+    output_router_logits = (
+        output_router_logits if output_router_logits is not None else getattr(self.config, "output_router_logits", True)
+    )
+    all_router_logits = () if output_router_logits else None
+
     for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-        hidden_states = decoder_layer(
+        layer_outputs = decoder_layer(
             hidden_states,
             position_embeddings=position_embeddings,
             attention_mask=attention_mask,
@@ -98,9 +104,17 @@ def model_forward(
             cache_position=cache_position,
             cu_seq_lens=cu_seq_lens,
             indices=indices,
+            output_router_logits=output_router_logits,
             **kwargs,
         )

+        if isinstance(layer_outputs, tuple):
+            hidden_states, router_logits = layer_outputs
+            if output_router_logits and router_logits is not None:
+                all_router_logits += (router_logits,)
+        else:
+            hidden_states = layer_outputs
+
     hidden_states = self.norm(hidden_states)

     return BaseModelOutputWithPastAndRmpad(
@@ -108,6 +122,7 @@ def model_forward(
         past_key_values=past_key_values if use_cache else None,
         seq_lens=cu_seq_lens,
         word_idx=indices,
+        router_logits=all_router_logits if output_router_logits else None,
     )
@@ -121,6 +136,7 @@ def decoder_layer_forward(
     cache_position: Optional[torch.LongTensor] = None,
     cu_seq_lens: Optional[torch.IntTensor] = None,
     indices: Optional[torch.IntTensor] = None,
+    output_router_logits: bool = True,
     **kwargs,
 ) -> torch.FloatTensor:
     """
@@ -170,14 +186,20 @@ def decoder_layer_forward(
     # Unsqueeze to unpack shape for the MoE sparse layer
     hidden_states = hidden_states.unsqueeze(0)
     hidden_states = self.post_attention_layernorm(hidden_states)
-    hidden_states = self.mlp(hidden_states)
-    # For the MoE layers, we need to unpack
-    if isinstance(hidden_states, tuple):
-        hidden_states, _ = hidden_states
+    mlp_output = self.mlp(hidden_states)
+
+    router_logits = None
+    if isinstance(mlp_output, tuple):
+        hidden_states, router_logits = mlp_output
+    else:
+        hidden_states = mlp_output
+
     # Squeeze to pack shape for later
     hidden_states = hidden_states.squeeze(0)
     hidden_states = residual + hidden_states

+    if output_router_logits and router_logits is not None:
+        return hidden_states, router_logits
     return hidden_states
diff --git a/src/lmms_engine/models/qwen3_omni_moe/qwen3_omni_moe_liger.py b/src/lmms_engine/models/qwen3_omni_moe/qwen3_omni_moe_liger.py
index 6af699b..7987ca2 100644
--- a/src/lmms_engine/models/qwen3_omni_moe/qwen3_omni_moe_liger.py
+++ b/src/lmms_engine/models/qwen3_omni_moe/qwen3_omni_moe_liger.py
@@ -66,7 +66,9 @@ def lce_forward(
     )
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     output_router_logits = (
-        output_router_logits if output_router_logits is not None else self.config.text_config.output_router_logits
+        output_router_logits
+        if output_router_logits is not None
+        else getattr(self.config.text_config, "output_router_logits", True)
     )

     tokens_count = attention_mask.sum().item()
@@ -272,12 +274,14 @@ def lce_forward(
     # MoE auxiliary loss handling
     aux_loss = None
-    if output_router_logits and hasattr(outputs, "router_logits"):
+    router_logits = getattr(outputs, "router_logits", None)
+    if output_router_logits and router_logits is not None:
+        aux_loss_mask = None if use_rmpad else attention_mask
         aux_loss = load_balancing_loss_func(
-            outputs.router_logits,
+            router_logits,
             self.num_experts,
             self.num_experts_per_tok,
-            attention_mask,
+            aux_loss_mask,
         )
         if labels is not None and loss is not None:
             # Add auxiliary loss weighted by router_aux_loss_coef
diff --git a/src/lmms_engine/models/qwen3_omni_moe/qwen3_omni_moe_ops.py b/src/lmms_engine/models/qwen3_omni_moe/qwen3_omni_moe_ops.py
index 012b3ef..844ff85 100644
--- a/src/lmms_engine/models/qwen3_omni_moe/qwen3_omni_moe_ops.py
+++ b/src/lmms_engine/models/qwen3_omni_moe/qwen3_omni_moe_ops.py
@@ -98,6 +98,7 @@ def text_model_forward(
     use_cache: Optional[bool] = None,
     output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
+    output_router_logits: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     cu_seq_lens: Optional[torch.IntTensor] = None,
@@ -108,6 +109,9 @@ def text_model_forward(
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    output_router_logits = (
+        output_router_logits if output_router_logits is not None else getattr(self.config, "output_router_logits", True)
+    )
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -153,13 +157,14 @@ def text_model_forward(
     position_embeddings = self.rotary_emb(hidden_states, position_ids)
     all_hidden_states = () if output_hidden_states else None
     all_attentions = () if output_attentions else None
+    all_router_logits = () if output_router_logits else None

     for decoder_layer in self.layers:
         if output_hidden_states:
             all_hidden_states += (hidden_states,)

         if self.gradient_checkpointing and self.training:
-            hidden_states = torch.utils.checkpoint.checkpoint(
+            layer_outputs = torch.utils.checkpoint.checkpoint(
                 decoder_layer.__call__,
                 hidden_states,
                 position_embeddings,
@@ -169,10 +174,11 @@ def text_model_forward(
                 cache_position,
                 cu_seq_lens,
                 indices,
+                output_router_logits,
                 use_reentrant=False,
             )
         else:
-            hidden_states = decoder_layer(
+            layer_outputs = decoder_layer(
                 hidden_states,
                 position_embeddings=position_embeddings,
                 attention_mask=attention_mask,
@@ -181,9 +187,17 @@ def text_model_forward(
                 cache_position=cache_position,
                 cu_seq_lens=cu_seq_lens,
                 indices=indices,
+                output_router_logits=output_router_logits,
                 **kwargs,
             )

+        if isinstance(layer_outputs, tuple):
+            hidden_states, router_logits = layer_outputs
+            if output_router_logits and router_logits is not None:
+                all_router_logits += (router_logits,)
+        else:
+            hidden_states = layer_outputs
+
     hidden_states = self.norm(hidden_states)

     if output_hidden_states:
@@ -199,6 +213,7 @@ def text_model_forward(
         attentions=all_attentions,
         seq_lens=cu_seq_lens,
         word_idx=indices,
+        router_logits=all_router_logits if output_router_logits else None,
     )
@@ -212,6 +227,7 @@ def decoder_layer_forward(
     cache_position: Optional[torch.LongTensor] = None,
     cu_seq_lens: Optional[torch.IntTensor] = None,
     indices: Optional[torch.IntTensor] = None,
+    output_router_logits: bool = True,
     **kwargs,
 ) -> torch.FloatTensor:
     residual = hidden_states
@@ -233,14 +249,20 @@ def decoder_layer_forward(
     residual = hidden_states
     hidden_states = hidden_states.unsqueeze(0)
     hidden_states = self.post_attention_layernorm(hidden_states)
-    hidden_states = self.mlp(hidden_states)
-    # For the MoE layers, we need to unpack
-    if isinstance(hidden_states, tuple):
-        hidden_states, _ = hidden_states
+    mlp_output = self.mlp(hidden_states)
+
+    router_logits = None
+    if isinstance(mlp_output, tuple):
+        hidden_states, router_logits = mlp_output
+    else:
+        hidden_states = mlp_output
+
     # Squeeze to pack shape for later
     hidden_states = hidden_states.squeeze(0)
     hidden_states = residual + hidden_states

+    if output_router_logits and router_logits is not None:
+        return hidden_states, router_logits
     return hidden_states
diff --git a/src/lmms_engine/models/qwen3_vl_moe/qwen3_vl_moe_liger.py b/src/lmms_engine/models/qwen3_vl_moe/qwen3_vl_moe_liger.py
index fca4451..c0222c4 100644
--- a/src/lmms_engine/models/qwen3_vl_moe/qwen3_vl_moe_liger.py
+++ b/src/lmms_engine/models/qwen3_vl_moe/qwen3_vl_moe_liger.py
@@ -43,7 +43,7 @@ def lce_forward(
     output_router_logits = (
         output_router_logits
         if output_router_logits is not None
-        else getattr(self.config.text_config, "output_router_logits", False)
+        else getattr(self.config.text_config, "output_router_logits", True)
     )

     outputs = self.model(
@@ -109,11 +109,12 @@ def lce_forward(
     if output_router_logits and router_logits is not None:
         router_aux_loss_coef = getattr(self.config.text_config, "router_aux_loss_coef", 0.001)
+        aux_loss_mask = None if use_rmpad else attention_mask
         aux_loss = load_balancing_loss_func(
             router_logits,
             config.num_experts,
             config.num_experts_per_tok,
-            attention_mask,
+            aux_loss_mask,
         )
         loss = loss + router_aux_loss_coef * aux_loss.to(loss.device)
diff --git a/src/lmms_engine/models/qwen3_vl_moe/qwen3_vl_moe_ops.py b/src/lmms_engine/models/qwen3_vl_moe/qwen3_vl_moe_ops.py
index f760bf3..1e1953b 100644
--- a/src/lmms_engine/models/qwen3_vl_moe/qwen3_vl_moe_ops.py
+++ b/src/lmms_engine/models/qwen3_vl_moe/qwen3_vl_moe_ops.py
@@ -336,9 +336,7 @@ def text_model_forward(
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
     output_router_logits = (
-        output_router_logits
-        if output_router_logits is not None
-        else getattr(self.config, "output_router_logits", False)
+        output_router_logits if output_router_logits is not None else getattr(self.config, "output_router_logits", True)
     )
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -443,7 +441,7 @@ def decoder_layer_forward(
     cache_position: Optional[torch.LongTensor] = None,
     cu_seq_lens: Optional[torch.IntTensor] = None,
     indices: Optional[torch.IntTensor] = None,
-    output_router_logits: bool = False,
+    output_router_logits: bool = True,
     **kwargs,
 ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:
     residual = hidden_states
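Reviewer note, not part of the patch: a minimal sketch of why every lce_forward above now passes aux_loss_mask = None if use_rmpad else attention_mask into load_balancing_loss_func. Assuming the helper follows the Hugging Face implementation, it reads (batch, seq_len) from the attention mask and expects the stacked router logits to cover batch * seq_len rows; with rmpad the decoder only sees the packed (unpadded) tokens, so that shape assumption breaks and passing None lets the loss average over all packed tokens instead. The shapes and variable names below are illustrative only.

import torch

num_experts = 4
batch, seq_len = 2, 8

# Padded case: router logits hold batch * seq_len rows, so a (batch, seq_len)
# attention_mask can exclude padding tokens inside load_balancing_loss_func.
attention_mask = torch.ones(batch, seq_len)
attention_mask[1, 5:] = 0
padded_router_logits = torch.randn(batch * seq_len, num_experts)

# rmpad case: padding is stripped before the decoder runs, so each layer only
# produces logits for the real tokens and the 2-D mask no longer lines up.
packed_tokens = int(attention_mask.sum())  # 13, not 16
packed_router_logits = torch.randn(packed_tokens, num_experts)

# Hence the gating added in the diff (shown here with use_rmpad=True):
use_rmpad = True
aux_loss_mask = None if use_rmpad else attention_mask  # None -> every packed token counts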