verl/models/transformers/npu_patch.py (63 additions, 2 deletions)

@@ -53,6 +53,67 @@ def apply_rotary_pos_emb_npu(q, k, cos, sin, position_ids=None, unsqueeze_dim=1)
return q_embed.to(q.dtype), k_embed.to(k.dtype)


class RouterGatingLinearFunction(torch.autograd.Function):
"""
Copied from Megatron-LM megatron/core/transformer/moe/moe_utils.py
Autograd function for router gating linear.
"""

@staticmethod
def forward(
ctx,
inp: torch.Tensor,
weight: torch.Tensor,
router_dtype: torch.dtype,
) -> torch.Tensor:
"""
Forward pass of the RouterGatingLinearFunction function.

Args:
inp (torch.Tensor): The input tensor.
weight (torch.Tensor): The weight tensor.
router_dtype (torch.dtype): The router dtype.

Returns:
torch.Tensor: The output tensor.
"""
ctx.save_for_backward(inp, weight)
ctx.router_dtype = router_dtype
ctx.input_dtype = inp.dtype
ctx.weight_dtype = weight.dtype
inp_shape = inp.shape
inp = inp.view(-1, inp_shape[-1])

output = torch.mm(inp.to(router_dtype), weight.to(router_dtype).t())

output = output.view(*inp_shape[:-1], -1)
return output

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, None]:
"""
Backward pass of the RouterGatingLinearFunction function.

Args:
grad_output (torch.Tensor): The gradient output.

Returns:
tuple[torch.Tensor, torch.Tensor, None]:
The gradient of the input, the gradient of the weight, and None (for router_dtype).
"""
inp, weight = ctx.saved_tensors
inp_shape = inp.shape
grad_shape = grad_output.shape
inp = inp.view(-1, inp_shape[-1])
grad_output = grad_output.view(-1, grad_shape[-1])

grad_input = torch.mm(grad_output, weight.to(ctx.router_dtype)).to(ctx.input_dtype)
grad_weight = torch.mm(grad_output.t(), inp.to(ctx.router_dtype)).to(ctx.weight_dtype)

grad_input = grad_input.view(*inp_shape)
return grad_input, grad_weight, None


class NPUGmmFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, weight, group_list, group_list_type=1):
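Outside the diff, a quick way to sanity-check the new autograd function is to compare it against a plain float32 linear. This is only a sketch: it assumes RouterGatingLinearFunction is importable from this module, and since the class itself only uses torch.mm it also works on CPU (provided the module imports cleanly without NPU-specific dependencies, or the class is copied out). Shapes and dtypes below are illustrative, not taken from any model config.

import torch
import torch.nn.functional as F

from verl.models.transformers.npu_patch import RouterGatingLinearFunction

hidden = torch.randn(4, 16, dtype=torch.bfloat16, requires_grad=True)
gate_weight = torch.randn(8, 16, dtype=torch.bfloat16, requires_grad=True)

# Patched path: the matmul runs in float32 regardless of the parameter dtype,
# and the router logits stay in float32.
logits = RouterGatingLinearFunction.apply(hidden, gate_weight, torch.float32)
assert logits.dtype == torch.float32

# Reference path: explicit float32 upcast followed by a plain linear.
ref = F.linear(hidden.float(), gate_weight.float())
assert torch.allclose(logits, ref, atol=1e-3)

# backward() casts the gradients back to the original input/weight dtypes.
logits.sum().backward()
assert hidden.grad.dtype == torch.bfloat16
assert gate_weight.grad.dtype == torch.bfloat16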
@@ -111,7 +172,7 @@ def qwen3_moe_sparse_moe_block_forward_npu(self, hidden_states: torch.Tensor) ->
hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
-router_logits = self.gate(hidden_states)
+router_logits = RouterGatingLinearFunction.apply(hidden_states, self.gate.weight, torch.float32)

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
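For readers skimming the hunk, the routing step above reduces to a softmax over experts followed by a top-k selection. A minimal standalone sketch (4 tokens, 8 experts, top_k=2 are illustrative numbers, not values from the model config):

import torch
import torch.nn.functional as F

num_tokens, num_experts, top_k = 4, 8, 2
# Float32 router logits, as produced by RouterGatingLinearFunction above.
router_logits = torch.randn(num_tokens, num_experts)

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

print(routing_weights.shape)   # torch.Size([4, 2])
print(selected_experts.shape)  # torch.Size([4, 2])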
@@ -220,7 +281,7 @@ def __init__(self, config):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size = hidden_states.shape[0]
hidden_states = hidden_states.reshape(-1, self.hidden_size)
-router_logits = self.gate(hidden_states)
+router_logits = RouterGatingLinearFunction.apply(hidden_states, self.gate.weight, torch.float32)
routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
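The second patched block additionally renormalizes the kept weights so that each token's top_k routing weights sum to 1. A minimal sketch with made-up numbers:

import torch

# Top-k routing weights for two tokens (illustrative values only).
routing_weights = torch.tensor([[0.30, 0.20],
                                [0.50, 0.10]])
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
print(routing_weights)          # approx. [[0.6000, 0.4000], [0.8333, 0.1667]]
print(routing_weights.sum(-1))  # each row sums to 1, up to float rounding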