From f6054b023cd216101f75d8c65d580e1572649619 Mon Sep 17 00:00:00 2001
From: Shangwei-Li
Date: Mon, 9 Feb 2026 17:31:26 +0800
Subject: [PATCH 1/4] Use FP32 for moe routing.

---
 verl/models/transformers/npu_patch.py | 68 ++++++++++++++++++++++++++-
 1 file changed, 66 insertions(+), 2 deletions(-)

diff --git a/verl/models/transformers/npu_patch.py b/verl/models/transformers/npu_patch.py
index ba25fe6e6ba..a6d50f3da78 100644
--- a/verl/models/transformers/npu_patch.py
+++ b/verl/models/transformers/npu_patch.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional
 import torch
 import torch.nn.functional as F
 
@@ -53,6 +54,69 @@ def apply_rotary_pos_emb_npu(q, k, cos, sin, position_ids=None, unsqueeze_dim=1)
     return q_embed.to(q.dtype), k_embed.to(k.dtype)
 
 
+class RouterGatingLinearFunction(torch.autograd.Function):
+    """
+    Autograd function for router gating linear.
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        inp: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor],
+        router_dtype: torch.dtype,
+    ) -> torch.Tensor:
+        """
+        Forward pass of the RouterGatingLinearFunction function.
+
+        Args:
+            inp (torch.Tensor): The input tensor.
+            weight (torch.Tensor): The weight tensor.
+            bias (torch.Tensor): The bias tensor. Could be None.
+            router_dtype (torch.dtype): The router dtype.
+
+        Returns:
+            torch.Tensor: The output tensor.
+        """
+        ctx.save_for_backward(inp, weight, bias)
+        ctx.router_dtype = router_dtype
+        ctx.input_dtype = inp.dtype
+        ctx.weight_dtype = weight.dtype
+        inp_shape = inp.shape
+        inp = inp.view(-1, inp_shape[-1])
+
+        output = torch.mm(inp.to(router_dtype), weight.to(router_dtype).t())
+
+        output = output.view(*inp_shape[:-1], -1)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], None]:
+        """
+        Backward pass of the RouterGatingLinearFunction function.
+
+        Args:
+            grad_output (torch.Tensor): The gradient output.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], None]:
+                The gradient input, gradient weight, gradient bias, and None.
+        """
+        inp, weight, bias = ctx.saved_tensors
+        inp_shape = inp.shape
+        grad_shape = grad_output.shape
+        inp = inp.view(-1, inp_shape[-1])
+        grad_output = grad_output.view(-1, grad_shape[-1])
+
+        grad_input = torch.mm(grad_output, weight.to(ctx.router_dtype)).to(ctx.input_dtype)
+        grad_weight = torch.mm(grad_output.t(), inp.to(ctx.router_dtype)).to(ctx.weight_dtype)
+
+        grad_bias = grad_output.sum(dim=0).to(ctx.weight_dtype) if bias is not None else None
+        grad_input = grad_input.view(*inp_shape)
+        return grad_input, grad_weight, grad_bias, None
+
+
 class NPUGmmFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, weight, group_list, group_list_type=1):
@@ -111,7 +175,7 @@ def qwen3_moe_sparse_moe_block_forward_npu(self, hidden_states: torch.Tensor) ->
     hidden_dim = hidden_states.shape[-1]
     hidden_states = hidden_states.view(-1, hidden_dim)
     # router_logits: (batch * sequence_length, n_experts)
-    router_logits = self.gate(hidden_states)
+    router_logits = RouterGatingLinearFunction.apply(hidden_states, self.gate.weight, None, torch.float32)
     routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
 
     routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
@@ -220,7 +284,7 @@ def __init__(self, config):
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size = hidden_states.shape[0]
         hidden_states = hidden_states.reshape(-1, self.hidden_size)
-        router_logits = self.gate(hidden_states)
+        router_logits = RouterGatingLinearFunction.apply(hidden_states, self.gate.weight, None, torch.float32)
         routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
         routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1)
         routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

From d80a7df16d01704b8f0684ce4678f57ded72f3f7 Mon Sep 17 00:00:00 2001
From: Shangwei-Li
Date: Mon, 9 Feb 2026 17:34:38 +0800
Subject: [PATCH 2/4] Add comments about RouterGatingLinearFunction.

---
 verl/models/transformers/npu_patch.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/verl/models/transformers/npu_patch.py b/verl/models/transformers/npu_patch.py
index a6d50f3da78..c9adb3b289a 100644
--- a/verl/models/transformers/npu_patch.py
+++ b/verl/models/transformers/npu_patch.py
@@ -56,6 +56,7 @@ def apply_rotary_pos_emb_npu(q, k, cos, sin, position_ids=None, unsqueeze_dim=1)
 
 class RouterGatingLinearFunction(torch.autograd.Function):
     """
+    Copied from Megatron-LM megatron/core/transformer/moe/moe_utils.py
     Autograd function for router gating linear.
     """
 

From 21241c1e2236c99383197b7a8519254577381a60 Mon Sep 17 00:00:00 2001
From: Shangwei-Li
Date: Mon, 9 Feb 2026 18:00:52 +0800
Subject: [PATCH 3/4] Remove bias

---
 verl/models/transformers/npu_patch.py | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/verl/models/transformers/npu_patch.py b/verl/models/transformers/npu_patch.py
index c9adb3b289a..0dd5e9ee9a8 100644
--- a/verl/models/transformers/npu_patch.py
+++ b/verl/models/transformers/npu_patch.py
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
 import torch
 import torch.nn.functional as F
 
@@ -65,7 +64,6 @@ def forward(
         ctx,
         inp: torch.Tensor,
         weight: torch.Tensor,
-        bias: Optional[torch.Tensor],
         router_dtype: torch.dtype,
     ) -> torch.Tensor:
         """
@@ -74,13 +72,12 @@ def forward(
         Args:
             inp (torch.Tensor): The input tensor.
             weight (torch.Tensor): The weight tensor.
-            bias (torch.Tensor): The bias tensor. Could be None.
             router_dtype (torch.dtype): The router dtype.
 
         Returns:
            torch.Tensor: The output tensor.
         """
-        ctx.save_for_backward(inp, weight, bias)
+        ctx.save_for_backward(inp, weight)
         ctx.router_dtype = router_dtype
         ctx.input_dtype = inp.dtype
         ctx.weight_dtype = weight.dtype
@@ -93,7 +90,7 @@ def forward(
         return output
 
     @staticmethod
-    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], None]:
+    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, None]:
         """
         Backward pass of the RouterGatingLinearFunction function.
 
@@ -104,7 +101,7 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor
             Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], None]:
                 The gradient input, gradient weight, gradient bias, and None.
         """
-        inp, weight, bias = ctx.saved_tensors
+        inp, weight = ctx.saved_tensors
         inp_shape = inp.shape
         grad_shape = grad_output.shape
         inp = inp.view(-1, inp_shape[-1])
@@ -113,9 +110,8 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor
         grad_input = torch.mm(grad_output, weight.to(ctx.router_dtype)).to(ctx.input_dtype)
         grad_weight = torch.mm(grad_output.t(), inp.to(ctx.router_dtype)).to(ctx.weight_dtype)
 
-        grad_bias = grad_output.sum(dim=0).to(ctx.weight_dtype) if bias is not None else None
         grad_input = grad_input.view(*inp_shape)
-        return grad_input, grad_weight, grad_bias, None
+        return grad_input, grad_weight, None
 
 
 class NPUGmmFunction(torch.autograd.Function):

From 8cad278b4db768c8de8de9a963bbb88506de8405 Mon Sep 17 00:00:00 2001
From: Shangwei-Li
Date: Mon, 9 Feb 2026 18:55:10 +0800
Subject: [PATCH 4/4] Fix parameter error.

---
 verl/models/transformers/npu_patch.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/verl/models/transformers/npu_patch.py b/verl/models/transformers/npu_patch.py
index 0dd5e9ee9a8..e8f407e3141 100644
--- a/verl/models/transformers/npu_patch.py
+++ b/verl/models/transformers/npu_patch.py
@@ -172,7 +172,7 @@ def qwen3_moe_sparse_moe_block_forward_npu(self, hidden_states: torch.Tensor) ->
     hidden_dim = hidden_states.shape[-1]
     hidden_states = hidden_states.view(-1, hidden_dim)
     # router_logits: (batch * sequence_length, n_experts)
-    router_logits = RouterGatingLinearFunction.apply(hidden_states, self.gate.weight, None, torch.float32)
+    router_logits = RouterGatingLinearFunction.apply(hidden_states, self.gate.weight, torch.float32)
     routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
 
     routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
@@ -281,7 +281,7 @@ def __init__(self, config):
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size = hidden_states.shape[0]
         hidden_states = hidden_states.reshape(-1, self.hidden_size)
-        router_logits = RouterGatingLinearFunction.apply(hidden_states, self.gate.weight, None, torch.float32)
+        router_logits = RouterGatingLinearFunction.apply(hidden_states, self.gate.weight, torch.float32)
         routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
         routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1)
         routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
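A quick standalone sanity check for the FP32 routing path added by this series follows as a minimal sketch: it uses the post-series signature (input, weight, router_dtype), the token/expert shapes are illustrative only, and it assumes RouterGatingLinearFunction can be imported from the patched npu_patch.py (or copied into a local script if the NPU-specific imports are unavailable). It compares the custom autograd function against a plain float32 F.linear and checks that gradients come back in the parameters' original dtype.

import torch
import torch.nn.functional as F

# Assumes the patched module imports cleanly in this environment; otherwise the
# RouterGatingLinearFunction class can be copied into the script verbatim.
from verl.models.transformers.npu_patch import RouterGatingLinearFunction

# Illustrative shapes only: 8 tokens, hidden size 16, 4 routed experts.
tokens, hidden_dim, num_experts = 8, 16, 4
hidden_states = torch.randn(tokens, hidden_dim, dtype=torch.bfloat16, requires_grad=True)
gate_weight = torch.randn(num_experts, hidden_dim, dtype=torch.bfloat16, requires_grad=True)

# Post-series signature: (input, weight, router_dtype). The matmul runs in
# float32 and the logits stay in float32 for the softmax/top-k that follows.
router_logits = RouterGatingLinearFunction.apply(hidden_states, gate_weight, torch.float32)

# Reference: explicit upcast followed by a standard linear.
ref_logits = F.linear(hidden_states.float(), gate_weight.float())
torch.testing.assert_close(router_logits, ref_logits)

# Backward also runs in float32, then casts gradients back to the original
# activation/parameter dtypes.
router_logits.sum().backward()
assert hidden_states.grad.dtype == torch.bfloat16
assert gate_weight.grad.dtype == torch.bfloat16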