diff --git a/verl/models/transformers/npu_patch.py b/verl/models/transformers/npu_patch.py
index ba25fe6e6ba..e8f407e3141 100644
--- a/verl/models/transformers/npu_patch.py
+++ b/verl/models/transformers/npu_patch.py
@@ -53,6 +53,67 @@ def apply_rotary_pos_emb_npu(q, k, cos, sin, position_ids=None, unsqueeze_dim=1)
     return q_embed.to(q.dtype), k_embed.to(k.dtype)
 
 
+class RouterGatingLinearFunction(torch.autograd.Function):
+    """
+    Copied from Megatron-LM megatron/core/transformer/moe/moe_utils.py
+    Autograd function for router gating linear.
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        inp: torch.Tensor,
+        weight: torch.Tensor,
+        router_dtype: torch.dtype,
+    ) -> torch.Tensor:
+        """
+        Forward pass of the RouterGatingLinearFunction function.
+
+        Args:
+            inp (torch.Tensor): The input tensor.
+            weight (torch.Tensor): The weight tensor.
+            router_dtype (torch.dtype): The router dtype.
+
+        Returns:
+            torch.Tensor: The output tensor.
+        """
+        ctx.save_for_backward(inp, weight)
+        ctx.router_dtype = router_dtype
+        ctx.input_dtype = inp.dtype
+        ctx.weight_dtype = weight.dtype
+        inp_shape = inp.shape
+        inp = inp.view(-1, inp_shape[-1])
+
+        output = torch.mm(inp.to(router_dtype), weight.to(router_dtype).t())
+
+        output = output.view(*inp_shape[:-1], -1)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, None]:
+        """
+        Backward pass of the RouterGatingLinearFunction function.
+
+        Args:
+            grad_output (torch.Tensor): The gradient output.
+
+        Returns:
+            tuple[torch.Tensor, torch.Tensor, None]:
+                The gradient w.r.t. the input, the gradient w.r.t. the weight, and None for router_dtype.
+        """
+        inp, weight = ctx.saved_tensors
+        inp_shape = inp.shape
+        grad_shape = grad_output.shape
+        inp = inp.view(-1, inp_shape[-1])
+        grad_output = grad_output.view(-1, grad_shape[-1])
+
+        grad_input = torch.mm(grad_output, weight.to(ctx.router_dtype)).to(ctx.input_dtype)
+        grad_weight = torch.mm(grad_output.t(), inp.to(ctx.router_dtype)).to(ctx.weight_dtype)
+
+        grad_input = grad_input.view(*inp_shape)
+        return grad_input, grad_weight, None
+
+
 class NPUGmmFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, weight, group_list, group_list_type=1):
@@ -111,7 +172,7 @@ def qwen3_moe_sparse_moe_block_forward_npu(self, hidden_states: torch.Tensor) ->
     hidden_dim = hidden_states.shape[-1]
     hidden_states = hidden_states.view(-1, hidden_dim)
     # router_logits: (batch * sequence_length, n_experts)
-    router_logits = self.gate(hidden_states)
+    router_logits = RouterGatingLinearFunction.apply(hidden_states, self.gate.weight, torch.float32)
     routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
     routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
@@ -220,7 +281,7 @@ def __init__(self, config):
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size = hidden_states.shape[0]
         hidden_states = hidden_states.reshape(-1, self.hidden_size)
-        router_logits = self.gate(hidden_states)
+        router_logits = RouterGatingLinearFunction.apply(hidden_states, self.gate.weight, torch.float32)
         routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
         routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1)
         routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
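
The substantive change above is that both MoE router forwards now compute gate logits through `RouterGatingLinearFunction` in float32 instead of calling `self.gate` directly, presumably to keep routing numerics in full precision while the rest of the model runs in bf16/fp16. The following is a minimal sanity-check sketch, not part of the patch, comparing the custom autograd function against an explicit float32 `F.linear` on toy bfloat16 tensors. The import path is taken from the diff header; the sketch assumes the module can be imported in your environment (it may require `torch_npu` and an NPU host).

```python
import torch
import torch.nn.functional as F

# Assumption: importing npu_patch works in this environment; it may require
# torch_npu, in which case this check should run on the NPU host instead.
from verl.models.transformers.npu_patch import RouterGatingLinearFunction

tokens, hidden, n_experts = 8, 16, 4
x = torch.randn(tokens, hidden, dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(n_experts, hidden, dtype=torch.bfloat16, requires_grad=True)

# Patched path: router logits computed in float32 regardless of model dtype.
logits = RouterGatingLinearFunction.apply(x, w, torch.float32)
logits.sum().backward()
grad_x, grad_w = x.grad.clone(), w.grad.clone()
x.grad, w.grad = None, None

# Reference path: explicit upcast followed by a plain linear projection.
ref = F.linear(x.to(torch.float32), w.to(torch.float32))
ref.sum().backward()

assert logits.dtype == torch.float32
assert torch.allclose(logits, ref, atol=1e-5)
assert torch.allclose(grad_x, x.grad) and torch.allclose(grad_w, w.grad)
```

The gradient comparison also exercises the dtype round trip: the custom backward returns grads cast back to the input/weight dtypes (bfloat16 here), matching what the reference path accumulates through the upcast.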