Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/lmms_engine/models/qwen3_moe/qwen3_moe_liger.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,10 @@ def lce_forward(
loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

aux_loss = None
if output_router_logits:
router_logits = getattr(outputs, "router_logits", None)
if output_router_logits and router_logits is not None:
aux_loss = load_balancing_loss_func(
outputs.router_logits,
router_logits,
self.num_experts,
self.num_experts_per_tok,
attention_mask,
Expand All @@ -161,5 +162,5 @@ def lce_forward(
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=None, # Current always None
router_logits=router_logits,
)
34 changes: 29 additions & 5 deletions src/lmms_engine/models/qwen3_moe/qwen3_moe_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def model_forward(
cache_position: Optional[torch.LongTensor] = None,
cu_seq_lens: Optional[torch.IntTensor] = None,
indices: Optional[torch.IntTensor] = None,
output_router_logits: Optional[bool] = None,
**kwargs,
) -> MoeModelOutputWithPast:
if (input_ids is None) ^ (inputs_embeds is not None):
Expand Down Expand Up @@ -87,8 +88,15 @@ def model_forward(
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)

output_router_logits = (
output_router_logits
if output_router_logits is not None
else getattr(self.config, "output_router_logits", False)
)
all_router_logits = () if output_router_logits else None

Comment on lines 91 to 95
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By default the output router logits is always False?

https://huggingface.co/Qwen/Qwen3-30B-A3B-Instruct-2507/blob/main/config.json

Do you think we should change it to True for training, since this patch will only use in training mode.

for decoder_layer in self.layers[: self.config.num_hidden_layers]:
hidden_states = decoder_layer(
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
Expand All @@ -98,16 +106,25 @@ def model_forward(
cache_position=cache_position,
cu_seq_lens=cu_seq_lens,
indices=indices,
output_router_logits=output_router_logits,
**kwargs,
)

if isinstance(layer_outputs, tuple):
hidden_states, router_logits = layer_outputs
if output_router_logits and router_logits is not None:
all_router_logits += (router_logits,)
else:
hidden_states = layer_outputs

hidden_states = self.norm(hidden_states)

return BaseModelOutputWithPastAndRmpad(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
seq_lens=cu_seq_lens,
word_idx=indices,
router_logits=all_router_logits if output_router_logits else None,
)


Expand All @@ -121,6 +138,7 @@ def decoder_layer_forward(
cache_position: Optional[torch.LongTensor] = None,
cu_seq_lens: Optional[torch.IntTensor] = None,
indices: Optional[torch.IntTensor] = None,
output_router_logits: bool = False,
**kwargs,
) -> torch.FloatTensor:
"""
Expand Down Expand Up @@ -170,14 +188,20 @@ def decoder_layer_forward(
# Unsqueeze to unpack shape for the MoE sparse layer
hidden_states = hidden_states.unsqueeze(0)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
# For the MoE layers, we need to unpack
if isinstance(hidden_states, tuple):
hidden_states, _ = hidden_states
mlp_output = self.mlp(hidden_states)

router_logits = None
if isinstance(mlp_output, tuple):
hidden_states, router_logits = mlp_output
else:
hidden_states = mlp_output

# Squeeze to pack shape for later
hidden_states = hidden_states.squeeze(0)
hidden_states = residual + hidden_states

if output_router_logits and router_logits is not None:
return hidden_states, router_logits
return hidden_states


Expand Down
5 changes: 3 additions & 2 deletions src/lmms_engine/models/qwen3_omni_moe/qwen3_omni_moe_liger.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,10 @@ def lce_forward(

# MoE auxiliary loss handling
aux_loss = None
if output_router_logits and hasattr(outputs, "router_logits"):
router_logits = getattr(outputs, "router_logits", None)
if output_router_logits and router_logits is not None:
aux_loss = load_balancing_loss_func(
outputs.router_logits,
router_logits,
self.num_experts,
self.num_experts_per_tok,
attention_mask,
Expand Down
36 changes: 30 additions & 6 deletions src/lmms_engine/models/qwen3_omni_moe/qwen3_omni_moe_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def text_model_forward(
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
cu_seq_lens: Optional[torch.IntTensor] = None,
Expand All @@ -108,6 +109,11 @@ def text_model_forward(
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
output_router_logits = (
output_router_logits
if output_router_logits is not None
else getattr(self.config, "output_router_logits", False)
)
Comment on lines 112 to 114
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same problem here. Also I checked the config for Qwen3VLMoe. Seems like this problem also exist in qwen3 vl moe? Do you think we need to simply set this to true in training mode.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, will fix this

use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

Expand Down Expand Up @@ -153,13 +159,14 @@ def text_model_forward(
position_embeddings = self.rotary_emb(hidden_states, position_ids)
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
all_router_logits = () if output_router_logits else None

for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)

if self.gradient_checkpointing and self.training:
hidden_states = torch.utils.checkpoint.checkpoint(
layer_outputs = torch.utils.checkpoint.checkpoint(
decoder_layer.__call__,
hidden_states,
position_embeddings,
Expand All @@ -169,10 +176,11 @@ def text_model_forward(
cache_position,
cu_seq_lens,
indices,
output_router_logits,
use_reentrant=False,
)
else:
hidden_states = decoder_layer(
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
Expand All @@ -181,9 +189,17 @@ def text_model_forward(
cache_position=cache_position,
cu_seq_lens=cu_seq_lens,
indices=indices,
output_router_logits=output_router_logits,
**kwargs,
)

if isinstance(layer_outputs, tuple):
hidden_states, router_logits = layer_outputs
if output_router_logits and router_logits is not None:
all_router_logits += (router_logits,)
else:
hidden_states = layer_outputs

hidden_states = self.norm(hidden_states)

if output_hidden_states:
Expand All @@ -199,6 +215,7 @@ def text_model_forward(
attentions=all_attentions,
seq_lens=cu_seq_lens,
word_idx=indices,
router_logits=all_router_logits if output_router_logits else None,
)


Expand All @@ -212,6 +229,7 @@ def decoder_layer_forward(
cache_position: Optional[torch.LongTensor] = None,
cu_seq_lens: Optional[torch.IntTensor] = None,
indices: Optional[torch.IntTensor] = None,
output_router_logits: bool = False,
**kwargs,
) -> torch.FloatTensor:
residual = hidden_states
Expand All @@ -233,14 +251,20 @@ def decoder_layer_forward(
residual = hidden_states
hidden_states = hidden_states.unsqueeze(0)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
# For the MoE layers, we need to unpack
if isinstance(hidden_states, tuple):
hidden_states, _ = hidden_states
mlp_output = self.mlp(hidden_states)

router_logits = None
if isinstance(mlp_output, tuple):
hidden_states, router_logits = mlp_output
else:
hidden_states = mlp_output

# Squeeze to pack shape for later
hidden_states = hidden_states.squeeze(0)
hidden_states = residual + hidden_states

if output_router_logits and router_logits is not None:
return hidden_states, router_logits
return hidden_states


Expand Down