156 changes: 88 additions & 68 deletions python/sglang/srt/layers/attention/xpu_backend.py
@@ -20,7 +20,13 @@
from sglang.srt.model_executor.model_runner import ModelRunner

from sgl_kernel import merge_state_v2
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from sgl_kernel.flash_attn import flash_attn_with_kvcache


@torch.compile
def extract_page_table(batch_size, req_to_token, req_pool_indices, seq_lens):
kv_indices = req_to_token[req_pool_indices, : seq_lens.max()]
return kv_indices.view(batch_size, -1)
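
Note: the new `extract_page_table` helper centralizes the page-table slicing that each call site below previously did inline against `req_to_token_pool.req_to_token`, slicing up to `seq_lens.max()` on device rather than the host-side `max_seq_len_k` values this diff removes, and wrapping the result in `torch.compile`. A minimal sketch of what it computes, with made-up shapes:

```python
import torch

# Hypothetical sizes, for illustration only.
req_to_token = torch.arange(4 * 8).view(4, 8)   # [max_bs, max_context_len] KV token indices
req_pool_indices = torch.tensor([2, 0])          # requests active in this batch
seq_lens = torch.tensor([3, 5])                  # per-request KV lengths

# The same slicing the helper performs:
kv_indices = req_to_token[req_pool_indices, : seq_lens.max()]
page_table = kv_indices.view(2, -1)              # [batch_size, max_seq_len_k]
```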


class XPUAttentionBackend(AttentionBackend):
@@ -52,13 +58,19 @@ def __init__(
# extra metadata for handling speculative decoding topk > 1, extended draft decode and verify
self.forward_metadata_spec_decode_expand: FlashAttentionMetadata = None
self.max_context_len = model_runner.model_config.context_len
self.max_bs = model_runner.req_to_token_pool.size
self.device = model_runner.device
self.decode_cuda_graph_metadata = {}
self.target_verify_metadata = {}
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.kv_cache_dtype = model_runner.kv_cache_dtype
self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
self.page_size = model_runner.page_size
self.max_num_pages_per_req = self.max_context_len // self.page_size
self.kv_indptr = (
torch.arange(0, self.max_bs + 1, dtype=torch.int32, device=self.device)
* self.max_num_pages_per_req
)
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
assert (
self.use_mla is False
@@ -100,16 +112,10 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
if forward_batch.forward_mode.is_decode_or_idle():
# Draft Decode
if forward_batch.spec_info is not None:
assert (
False
), "XPUAttentionBackend doesn't support speculative decoding yet, please use --attention-backend triton instead."
if self.topk <= 1:
metadata.cache_seqlens_int32 = (
seqlens_in_batch + (self.speculative_step_id + 1)
).to(torch.int32)
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
self.speculative_step_id + 1
)
metadata.cu_seqlens_q = torch.arange(
0, batch_size + 1, dtype=torch.int32, device=device
)
@@ -119,13 +125,15 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
),
(1, 0),
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
metadata.page_table = extract_page_table(
batch_size,
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
)
else:
metadata.cache_seqlens_int32 = (seqlens_in_batch).to(torch.int32)
metadata.max_seq_len_q = self.topk
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
metadata.cu_seqlens_q = torch.arange(
0,
batch_size * self.topk + 1,
@@ -139,9 +147,12 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
),
(1, 0),
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
metadata.page_table = extract_page_table(
batch_size,
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
)

metadata_expand = FlashAttentionMetadata()
decode_length = self.speculative_step_id + 1
@@ -176,16 +187,18 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
else:
# Normal Decode
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
metadata.cu_seqlens_q = torch.arange(
0, batch_size + 1, dtype=torch.int32, device=device
)
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
metadata.page_table = extract_page_table(
batch_size,
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
)
# TODO: we need to test this part for llama 4 eagle case
self._init_local_attn_metadata(forward_batch, metadata, device)
elif forward_batch.forward_mode.is_target_verify():
@@ -194,10 +207,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
forward_batch.seq_lens + self.speculative_num_draft_tokens
).to(torch.int32)
metadata.max_seq_len_q = self.speculative_num_draft_tokens
metadata.max_seq_len_k = (
forward_batch.seq_lens_cpu.max().item()
+ self.speculative_num_draft_tokens
)
metadata.cu_seqlens_q = torch.arange(
0,
batch_size * self.speculative_num_draft_tokens + 1,
@@ -211,15 +220,17 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
),
(1, 0),
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
metadata.page_table = extract_page_table(
batch_size,
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
)

self._init_local_attn_metadata(forward_batch, metadata, device)
else:
metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
metadata.max_seq_len_q = self.speculative_num_draft_tokens
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
metadata.cu_seqlens_q = torch.arange(
0,
batch_size * self.speculative_num_draft_tokens + 1,
@@ -233,9 +244,12 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
),
(1, 0),
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
metadata.page_table = extract_page_table(
batch_size,
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
)

metadata_expand = FlashAttentionMetadata()

@@ -317,13 +331,15 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):

elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
metadata.page_table = extract_page_table(
batch_size,
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
)

if (
any(forward_batch.extend_prefix_lens_cpu)
Expand All @@ -335,7 +351,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
)
else:
metadata.max_seq_len_q = metadata.max_seq_len_k
metadata.cu_seqlens_q = metadata.cu_seqlens_k

# Setup local attention if enabled
@@ -344,27 +359,17 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):

# Encoder metadata for cross attention
if forward_batch.encoder_lens is not None:
assert (
forward_batch.encoder_lens.numel() == 1
), "Only encoder size 1 is supported for now"

metadata.encoder_lens_int32 = forward_batch.encoder_lens.to(torch.int32)
metadata.encoder_cu_seqlens_k = torch.nn.functional.pad(
metadata.cache_seqlens_int32 = forward_batch.encoder_lens.to(torch.int32)
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
(1, 0),
)
metadata.encoder_max_seq_len_k = metadata.encoder_lens_int32.max().item()
metadata.encoder_page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.encoder_max_seq_len_k
]

# Currently only support forward_batch.encoder_lens.numel() == 1
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
metadata.page_table = extract_page_table(
batch_size,
self.req_to_token,
forward_batch.req_pool_indices,
metadata.encoder_max_seq_len_k : (
metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
),
]
forward_batch.seq_lens,
)

# Convert the page table to a strided format which is needed by FA3 API
if self.page_size > 1:
@@ -492,9 +497,6 @@ def forward_extend(
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
)
if layer.is_cross_attention:
page_table = metadata.encoder_page_table
cache_seqlens = metadata.encoder_lens_int32
cu_seqlens_k = metadata.encoder_cu_seqlens_k
window_size = (-1, -1)

result = flash_attn_with_kvcache(
@@ -513,6 +515,7 @@
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
sinks=sinks,
**kwargs,
)

@@ -534,6 +537,7 @@
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
sinks=sinks,
**kwargs,
)
o, _ = merge_state_v2_wrapper(
@@ -562,31 +566,49 @@ def forward_extend(
assert chunk_idx >= 0

assert forward_batch.mha_return_lse
output = flash_attn_varlen_func(
output = flash_attn_with_kvcache(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
k_cache=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(
q.dtype
),
v_cache=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(
q.dtype
),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
cu_seqlens_k=torch.diff(
forward_batch.prefix_chunk_cu_seq_lens[chunk_idx]
),
max_seqlen_q=metadata.max_seq_len_q,
page_table=torch.arange(
0, metadata.cu_seqlens_q.numel(), device=self.device
),
max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
softmax_scale=layer.scaling,
causal=False,
return_softmax_lse=True,
sinks=sinks,
)
else:
# MHA for extend part of sequence without attending prefix kv cache
output = flash_attn_varlen_func(
output = flash_attn_with_kvcache(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
k_cache=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(
q.dtype
),
v_cache=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(
q.dtype
),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=metadata.cu_seqlens_q,
cu_seqlens_k=torch.diff(metadata.cu_seqlens_q),
max_seqlen_q=metadata.max_seq_len_q,
max_seqlen_k=metadata.max_seq_len_q,
page_table=torch.arange(
0, metadata.cu_seqlens_q.numel(), device=self.device
),
softmax_scale=layer.scaling,
causal=True,
return_softmax_lse=forward_batch.mha_return_lse,
sinks=sinks,
)
if forward_batch.mha_return_lse:
output, lse, *rest = output
@@ -762,10 +784,10 @@ def forward_decode(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
page_table=metadata.encoder_page_table,
cache_seqlens=metadata.encoder_lens_int32,
page_table=metadata.page_table,
cache_seqlens=metadata.cache_seqlens_int32,
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k_new=metadata.encoder_cu_seqlens_k,
cu_seqlens_k_new=metadata.cu_seqlens_k,
max_seqlen_q=1,
softmax_scale=layer.scaling,
causal=False,
@@ -998,9 +1020,7 @@ def _init_sliding_window_attn_spec_metadata(
)
bs = cache_seqlens_int32.shape[0]
page_table = (
metadata.page_table.new_zeros(
(bs, metadata.max_seq_len_k + metadata_expand.page_table.shape[1])
)
metadata.page_table.new_zeros((bs, self.max_num_pages_per_req))
if metadata_swa is None
else metadata_swa.page_table
)
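Note on the `forward_extend` changes above: the two MHA branches that previously called `flash_attn_varlen_func` now route through `flash_attn_with_kvcache`, converting the cumulative K lengths into per-sequence lengths with `torch.diff` and passing an identity `page_table` so each packed sequence maps to its own row. A small illustration of that argument translation (values are made up; the exact keyword set is whatever `sgl_kernel.flash_attn` exposes on XPU):

```python
import torch

cu_seqlens = torch.tensor([0, 3, 8, 10], dtype=torch.int32)  # prefix sums over 3 packed sequences

# The varlen API consumed the prefix sums directly; the kvcache path is fed
# per-sequence K lengths instead:
seqlens_k = torch.diff(cu_seqlens)                # tensor([3, 5, 2], dtype=torch.int32)

# plus an arange page table sized like cu_seqlens_q, as in the diff above:
page_table = torch.arange(0, cu_seqlens.numel())  # tensor([0, 1, 2, 3])
```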
11 changes: 10 additions & 1 deletion python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -21,6 +21,7 @@
is_cpu,
is_cuda,
is_hip,
is_xpu,
)

from .fused_moe_triton_config import get_config_dtype_str, try_get_optimal_moe_config
@@ -39,6 +40,7 @@
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_xpu = is_xpu()

if _is_cuda:
from sgl_kernel import gelu_and_mul, moe_sum_reduce, silu_and_mul
@@ -54,6 +56,8 @@
raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
else:
from vllm import _custom_ops as vllm_ops
elif _is_xpu:
from sgl_kernel import moe_sum, silu_and_mul

padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0

@@ -552,7 +556,7 @@ def fused_experts_impl(
gemm1_alpha,
gemm1_limit,
)
elif _is_cuda or _is_hip:
elif _is_cuda or _is_hip or _is_xpu:
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
else:
vllm_ops.silu_and_mul(
@@ -658,6 +662,11 @@ def fused_experts_impl(
out_hidden_states[begin_chunk_idx:end_chunk_idx],
routed_scaling_factor,
)
elif _is_xpu:
moe_sum(
intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx],
)
else:
vllm_ops.moe_sum(
intermediate_cache3.view(*intermediate_cache3.shape),
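Note on the `fused_moe.py` hunks above: XPU now reuses `silu_and_mul` and `moe_sum` from `sgl_kernel` instead of falling through to `vllm_ops`. As a rough reference for what the two fused ops compute, assuming they match the CUDA/vLLM semantics (a sketch, not the kernels):

```python
import torch

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # x packs gate and up projections along the last dim: [..., 2N] -> [..., N]
    gate, up = x.chunk(2, dim=-1)
    return torch.nn.functional.silu(gate) * up

def moe_sum_ref(intermediate: torch.Tensor, out: torch.Tensor) -> None:
    # intermediate: [num_tokens, topk, hidden]; reduce the per-expert outputs into out.
    out.copy_(intermediate.sum(dim=1))

h = silu_and_mul_ref(torch.randn(4, 2 * 16))      # [4, 16], analogous to intermediate_cache2
acc = torch.empty(2, 16)
moe_sum_ref(torch.randn(2, 3, 16), acc)           # analogous to the intermediate_cache3 reduction
```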
@@ -5,12 +5,13 @@
import torch
import triton

from sglang.srt.utils import is_cuda, is_hip
from sglang.srt.utils import is_cuda, is_hip, is_xpu

_is_cuda = is_cuda()
_is_hip = is_hip()
_is_xpu = is_xpu()

if _is_cuda or _is_hip:
if _is_cuda or _is_hip or _is_xpu:
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size


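Note on the last hunk: the `sgl_kernel` import guard for `moe_align_block_size` is simply extended to XPU. Conceptually, that kernel groups token-to-expert assignments into `block_size`-aligned chunks per expert so the fused MoE GEMM can iterate over fixed-size blocks; a rough pure-PyTorch illustration of the idea (not the kernel's actual signature or padding convention):

```python
import torch

def moe_align_ref(topk_ids: torch.Tensor, num_experts: int, block_size: int) -> torch.Tensor:
    # Group token indices by expert and pad each group to a multiple of block_size.
    flat = topk_ids.flatten()
    groups = []
    for e in range(num_experts):
        idx = (flat == e).nonzero(as_tuple=True)[0]
        pad = (-idx.numel()) % block_size
        # Pad with an out-of-range sentinel so padded slots are easy to mask out later.
        groups.append(torch.cat([idx, idx.new_full((pad,), flat.numel())]))
    return torch.cat(groups)

sorted_ids = moe_align_ref(torch.tensor([[0, 2], [1, 0]]), num_experts=4, block_size=4)
```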