69 changes: 67 additions & 2 deletions verl/models/transformers/npu_patch.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
import torch.nn.functional as F
@@ -53,6 +54,70 @@ def apply_rotary_pos_emb_npu(q, k, cos, sin, position_ids=None, unsqueeze_dim=1)
return q_embed.to(q.dtype), k_embed.to(k.dtype)


class RouterGatingLinearFunction(torch.autograd.Function):
"""
Copied from Megatron-LM megatron/core/transformer/moe/moe_utils.py
Autograd function for router gating linear.
"""

@staticmethod
def forward(
ctx,
inp: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
router_dtype: torch.dtype,
) -> torch.Tensor:
"""
Forward pass of the RouterGatingLinearFunction function.

Args:
inp (torch.Tensor): The input tensor.
weight (torch.Tensor): The weight tensor.
bias (torch.Tensor): The bias tensor. Could be None.
router_dtype (torch.dtype): The router dtype.

Returns:
torch.Tensor: The output tensor.
"""
ctx.save_for_backward(inp, weight, bias)
ctx.router_dtype = router_dtype
ctx.input_dtype = inp.dtype
ctx.weight_dtype = weight.dtype
inp_shape = inp.shape
inp = inp.view(-1, inp_shape[-1])

        output = torch.mm(inp.to(router_dtype), weight.to(router_dtype).t())
        if bias is not None:
            # Apply the bias in router_dtype; backward already returns grad_bias.
            output = output + bias.to(router_dtype)

output = output.view(*inp_shape[:-1], -1)
return output

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], None]:
"""
Backward pass of the RouterGatingLinearFunction function.

Args:
grad_output (torch.Tensor): The gradient output.

Returns:
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], None]:
The gradient input, gradient weight, gradient bias, and None.
"""
inp, weight, bias = ctx.saved_tensors
inp_shape = inp.shape
grad_shape = grad_output.shape
inp = inp.view(-1, inp_shape[-1])
grad_output = grad_output.view(-1, grad_shape[-1])

        # dL/dx = dL/dy @ W, cast back to the activation dtype.
        grad_input = torch.mm(grad_output, weight.to(ctx.router_dtype)).to(ctx.input_dtype)
        # dL/dW = (dL/dy)^T @ x, cast back to the parameter dtype.
        grad_weight = torch.mm(grad_output.t(), inp.to(ctx.router_dtype)).to(ctx.weight_dtype)

grad_bias = grad_output.sum(dim=0).to(ctx.weight_dtype) if bias is not None else None
grad_input = grad_input.view(*inp_shape)
return grad_input, grad_weight, grad_bias, None

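# Usage sketch (illustrative, not part of the patch): the autograd function runs the
# router matmul in router_dtype but hands gradients back in the original
# activation/parameter dtypes, e.g. with bf16 tensors and float32 routing:
#
#     logits = RouterGatingLinearFunction.apply(hidden, gate_weight, None, torch.float32)
#     logits.sum().backward()
#     # logits.dtype == torch.float32
#     # hidden.grad.dtype == gate_weight.grad.dtype == torch.bfloat16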

class NPUGmmFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, weight, group_list, group_list_type=1):
@@ -111,7 +176,7 @@ def qwen3_moe_sparse_moe_block_forward_npu(self, hidden_states: torch.Tensor) ->
hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
-    router_logits = self.gate(hidden_states)
+    router_logits = RouterGatingLinearFunction.apply(hidden_states, self.gate.weight, None, torch.float32)

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
@@ -220,7 +285,7 @@ def __init__(self, config):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size = hidden_states.shape[0]
hidden_states = hidden_states.reshape(-1, self.hidden_size)
-        router_logits = self.gate(hidden_states)
+        router_logits = RouterGatingLinearFunction.apply(hidden_states, self.gate.weight, None, torch.float32)
routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
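For reference, a minimal sketch of the new routing path outside the model. Assumptions: RouterGatingLinearFunction is importable from verl.models.transformers.npu_patch on your setup, and the tensor shapes and top-k value below are illustrative, not taken from any config.

import torch
import torch.nn.functional as F

from verl.models.transformers.npu_patch import RouterGatingLinearFunction

torch.manual_seed(0)
hidden_states = torch.randn(6, 32, dtype=torch.bfloat16, requires_grad=True)  # (tokens, hidden_dim)
gate_weight = torch.randn(8, 32, dtype=torch.bfloat16, requires_grad=True)  # (n_experts, hidden_dim)

# The patched path: the router matmul runs in float32 via the custom autograd function.
router_logits = RouterGatingLinearFunction.apply(hidden_states, gate_weight, None, torch.float32)

# It matches the plain gate projection upcast to float32.
reference = F.linear(hidden_states.to(torch.float32), gate_weight.to(torch.float32))
assert torch.allclose(router_logits, reference)

# Routing proceeds as in the patched forward: softmax over experts, then top-k.
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, k=2, dim=-1)

# Gradients flow back in the parameters' original dtype (bf16 here).
routing_weights.sum().backward()
assert gate_weight.grad.dtype == torch.bfloat16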