Commits (41)
2f538fa
add aiter silu fused kernel
charlifu Sep 16, 2025
9d6507b
add silu fusion pass
charlifu Sep 16, 2025
b901f27
fix pass
charlifu Sep 17, 2025
1d11425
workable silu_mul fusion pass
charlifu Sep 17, 2025
b48f84d
fix aiter fp8 linear support
charlifu Sep 18, 2025
41e7e2f
add is rocm aiter linear enabled
charlifu Sep 25, 2025
9940a40
fusion for AITER group quant RMSNorm and AITER w8a8 gemm
micah-wil Sep 25, 2025
cd059b9
fix undefined symbol conditional on is_rocm_aiter_linear_enabled
micah-wil Oct 3, 2025
6cf02a9
only add aiter rmsnorm fusion patterns if aiter is enabled
micah-wil Oct 7, 2025
f2cd510
Merge branch 'main' into amd/aiter_fusion_pass
charlifu Oct 8, 2025
7298b55
fix silu + fp8 block quant pass
charlifu Oct 8, 2025
19348df
Merge branch 'main' into amd/aiter_fusion_pass
charlifu Oct 13, 2025
254a34b
fix layernorm pass
charlifu Oct 13, 2025
b835c6a
Merge branch 'main' into amd/aiter_fusion_pass
charlifu Oct 14, 2025
736019a
Merge branch 'main' into amd/aiter_fusion_pass
charlifu Oct 21, 2025
4793f39
fix import
charlifu Oct 21, 2025
af6bc06
use fusion pass to enable aiter fused_mul_add
micah-wil Oct 21, 2025
824c419
Merge branch 'main' into amd/aiter_fusion_pass
charlifu Oct 22, 2025
5754863
Merge branch 'main' into micah/aiter_fused_muladd
charlifu Oct 22, 2025
f600359
Merge branch 'micah/aiter_fused_muladd' into amd/aiter_fusion_pass
charlifu Oct 22, 2025
3f2ee60
Fix conditions for applying fused_mul_add
micah-wil Oct 22, 2025
f69d89e
Merge branch 'main' into amd/aiter_fusion_pass
charlifu Oct 27, 2025
9153e6b
Merge branch 'main' into amd/aiter_fusion_pass
wuhuikx Oct 28, 2025
3f250be
Merge branch 'main' into amd/aiter_fusion_pass
charlifu Oct 30, 2025
06516eb
Merge branch 'main' into amd/aiter_fusion_pass
charlifu Nov 3, 2025
f450db9
fix rms+quant pass
charlifu Nov 5, 2025
1d2803b
Merge branch 'main' into amd/aiter_fusion_pass
charlifu Nov 5, 2025
afdd2f5
remove mul+add pass
charlifu Nov 6, 2025
c3cf015
Apply suggestions from code review
charlifu Nov 6, 2025
1001ee0
fix type hint
charlifu Nov 6, 2025
8b1792a
fix silu pass
charlifu Nov 6, 2025
881ba7e
fix silu pass
charlifu Nov 6, 2025
9054adf
Merge branch 'main' into amd/aiter_fusion_pass
charlifu Nov 7, 2025
ee41e28
fix merging
charlifu Nov 7, 2025
4710e07
fix fusion pass
charlifu Nov 10, 2025
9712dcb
remove print
charlifu Nov 10, 2025
8f539db
Merge branch 'main' into amd/aiter_fusion_pass
charlifu Nov 11, 2025
57eebbc
add ops in aiter_ops.py
charlifu Nov 12, 2025
62aef2e
fix silu pass
charlifu Nov 12, 2025
e63f937
fix rms passes
charlifu Nov 12, 2025
fa0d969
fix rms passes
charlifu Nov 12, 2025
100 changes: 97 additions & 3 deletions vllm/compilation/activation_quant_fusion.py
@@ -12,6 +12,7 @@
)
from torch._ops import OpOverload

from vllm import envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -75,6 +76,93 @@ def register(self, pm_pass: PatternMatcherPass):
raise NotImplementedError


def is_rocm_aiter_linear_enabled() -> bool:
return (
current_platform.is_rocm()
and envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_AITER_LINEAR
)


if is_rocm_aiter_linear_enabled():
import aiter as rocm_aiter
from aiter.ops.triton.activation import act_mul_and_fp8_group_quant

from vllm.utils import direct_register_custom_op

rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8
rocm_aiter_fp8_quant_group_size = 128

def _rocm_aiter_act_mul_and_fp8_group_quant_impl(
x: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
return act_mul_and_fp8_group_quant(
x,
activation="silu",
group_size=rocm_aiter_fp8_quant_group_size,
dtype_quant=rocm_aiter_fp8_dtype,
)

def _rocm_aiter_act_mul_and_fp8_group_quant_fake(
x: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
M, N = x.shape
assert N % 2 == 0
N_half = N // 2
x_fp8 = torch.empty((M, N_half), dtype=rocm_aiter_fp8_dtype, device=x.device)
out_bs = torch.empty(
(
M,
(N_half + rocm_aiter_fp8_quant_group_size - 1)
// rocm_aiter_fp8_quant_group_size,
),
dtype=torch.float32,
device=x.device,
)
return x_fp8, out_bs

direct_register_custom_op(
op_name="rocm_aiter_act_mul_and_fp8_group_quant",
Review comment — tjtanaa (Collaborator), Sep 28, 2025:
Can you check whether the latest AITER lets you skip direct_register_custom_op? Most ops should now work without calling direct_register_custom_op on the vLLM side, since the registration is done in the AITER repository, and removing these wrappers also avoids the extra CPU overhead they add. Please take a look at the benchmarking results in ROCm#717 (the second and third cases), which show that removing direct_register_custom_op on the vLLM side improves performance.

Reply (Contributor):
Hey, thanks for the feedback. Is there a version of AITER that has aiter.ops.triton.fused_fp8_quant and also registers these custom ops itself, as you mentioned? I wasn't able to figure out how to call act_mul_and_fp8_group_quant without calling direct_register_custom_op first. Happy to investigate further if you can point me in the right direction; otherwise we can always come back and remove these direct_register_custom_op calls later if needed.

Reply (Collaborator):
We can come back to this in a later PR, as the 355_wip AITER commit does not have that feature.

op_func=_rocm_aiter_act_mul_and_fp8_group_quant_impl,
mutates_args=[],
fake_impl=_rocm_aiter_act_mul_and_fp8_group_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
AITER_BLOCK_QUANT_OP = torch.ops.vllm.rocm_aiter_per1x128_quant.default
FUSED_SILU_MUL_QUANT_OP = (
torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
)

class AiterSiluMulFp8BlockQuantPattern(ActivationQuantPattern):
def __init__(self):
pass

def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
result_silu_mul: torch.Tensor,
):
at1 = auto_functionalized(
SILU_MUL_OP, result=result_silu_mul, input=input
)
Review comment (Collaborator):
Please use the MatcherSiluMul.

at2 = AITER_BLOCK_QUANT_OP(x=at1[1])
return at2[0], at2[1]

def replacement(
input: torch.Tensor,
result_silu_mul: torch.Tensor,
):
at = FUSED_SILU_MUL_QUANT_OP(x=input)
return at[0], at[1]

inputs = [
empty_bf16(5, 4), # input
empty_bf16(5, 4), # result_silu_mul
]

register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)


class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
"""
Fusion for SiluMul+Fp8StaticQuant Pattern
@@ -198,6 +286,10 @@ def __init__(self, config: VllmConfig):
pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
pattern_silu_mul_nvfp4.register(self.patterns)

if is_rocm_aiter_linear_enabled():
pattern_silu_mul_aiter_block_fp8 = AiterSiluMulFp8BlockQuantPattern()
pattern_silu_mul_aiter_block_fp8.register(self.patterns)

self.dump_patterns(config, self.patterns)

@VllmInductorPass.time_and_log
@@ -206,9 +298,11 @@ def __call__(self, graph: torch.fx.Graph):
logger.debug("Replaced %s patterns", self.matched_count)

     def uuid(self):
-        return VllmInductorPass.hash_source(
-            self,
+        fusion_patterns = [
             ActivationQuantPattern,
             SiluMulFp8StaticQuantPattern,
             SiluMulNvfp4QuantPattern,
-        )
+        ]
+        if is_rocm_aiter_linear_enabled():
+            fusion_patterns.append(AiterSiluMulFp8BlockQuantPattern)
+        return VllmInductorPass.hash_source(self, *fusion_patterns)
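For readers following the new AiterSiluMulFp8BlockQuantPattern above: it replaces the two-op sequence silu_and_mul followed by rocm_aiter_per1x128_quant with the single fused rocm_aiter_act_mul_and_fp8_group_quant op. Below is a minimal plain-PyTorch sketch of the semantics that fused op is assumed to have (gated SiLU over the two halves of the input, then dynamic per-1x128-group fp8 quantization); the fp8 dtype, rounding mode, scale layout, and handling of ragged groups are assumptions, not the actual AITER kernel.

```python
import torch

FP8_DTYPE = torch.float8_e4m3fn  # assumed fp8 format; AITER may use a different variant
GROUP_SIZE = 128                 # matches rocm_aiter_fp8_quant_group_size in the diff


def silu_mul_fp8_group_quant_ref(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Reference sketch: SiLU(first half) * second half, then per-1x128-group
    dynamic fp8 quantization. Assumes the output width is a multiple of
    GROUP_SIZE (the fake impl in the diff pads via ceil-division instead)."""
    d = x.shape[-1] // 2
    y = torch.nn.functional.silu(x[..., :d]) * x[..., d:]  # SiluMul
    groups = y.view(*y.shape[:-1], d // GROUP_SIZE, GROUP_SIZE)
    fp8_max = torch.finfo(FP8_DTYPE).max
    # one dynamic scale per 1x128 group, guarded against all-zero groups
    scales = (groups.abs().amax(dim=-1, keepdim=True).float() / fp8_max).clamp_min(1e-12)
    y_fp8 = (groups / scales).to(FP8_DTYPE).view_as(y)
    return y_fp8, scales.squeeze(-1)
```

At the FX-graph level, the pass rewrites exactly this pairing: the auto_functionalized silu_and_mul node followed by the per-1x128 quant node is swapped for a single call to the fused op, as encoded by the pattern/replacement functions above.
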
171 changes: 168 additions & 3 deletions vllm/compilation/fusion.py
@@ -9,6 +9,7 @@
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -88,6 +89,28 @@ def __str__(self):
}


def is_rocm_aiter_enabled() -> bool:
return current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER


if is_rocm_aiter_enabled():
AITER_RMS_GROUP_QUANT_OP = torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default
AITER_RMS_ADD_GROUP_QUANT_OP = (
torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default
)

BLOCK_LINEAR_OP = torch.ops.vllm.apply_w8a8_block_fp8_linear.default
AITER_BLOCK_LINEAR_OP = torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale.default

AITER_RMS_OP = torch.ops.vllm.rocm_aiter_rms_norm.default
AITER_RMS_ADD_OP = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default

import aiter as rocm_aiter

rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8
rocm_aiter_fp8_quant_group_size = 128


class RMSNormQuantPattern:
def __init__(self, epsilon: float, key: FusedRMSQuantKey):
self.epsilon = epsilon
@@ -382,6 +405,136 @@ def replacement(
)


if is_rocm_aiter_enabled():

class AiterRMSGroupQuantFP8Pattern:
def __init__(self, epsilon: float, quant_dtype: torch.dtype):
self.epsilon = epsilon
self.quant_dtype = quant_dtype

def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
weight: torch.Tensor, # result_rms: torch.Tensor,
linear_weight: torch.Tensor,
linear_weight_scale: torch.Tensor,
):
at1 = AITER_RMS_OP(
x=input, weight=weight, variance_epsilon=self.epsilon
)

at2 = BLOCK_LINEAR_OP(
input=at1,
weight=linear_weight,
block_size=[128, 128],
weight_scale=linear_weight_scale,
input_scale=None,
bias=None,
cutlass_block_fp8_supported=False,
use_aiter_and_is_supported=True,
)

return at2

def replacement(
input: torch.Tensor,
weight: torch.Tensor,
linear_weight: torch.Tensor,
linear_weight_scale: torch.Tensor,
):
at1 = AITER_RMS_GROUP_QUANT_OP(
x=input, residual=None, weight=weight, variance_epsilon=self.epsilon
)

at2 = AITER_BLOCK_LINEAR_OP(
A=at1[0],
B=linear_weight,
As=at1[1],
Bs=linear_weight_scale,
block_size=[128, 128],
output_dtype=input.dtype,
)

return at2

inputs = [
empty_bf16(5, 4), # input
empty_bf16(1, 5), # weight
torch.empty((2, 5), device="cuda", dtype=FP8_DTYPE), # linear_weight
empty_fp32(1, 1), # linear_weight_scale
]

pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)

class AiterFusedAddRMSGroupQuantPattern:
def __init__(self, epsilon: float, quant_dtype: torch.dtype):
self.epsilon = epsilon
self.quant_dtype = quant_dtype

def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
linear_weight: torch.Tensor,
linear_weight_scale: torch.Tensor,
):
at1 = AITER_RMS_ADD_OP(
x=input,
residual=residual,
weight=weight,
variance_epsilon=self.epsilon,
)

at2 = BLOCK_LINEAR_OP(
input=at1[0],
weight=linear_weight,
block_size=[128, 128],
weight_scale=linear_weight_scale,
input_scale=None,
bias=None,
cutlass_block_fp8_supported=False,
use_aiter_and_is_supported=True,
)
# result, residual
return at2, at1[1]

def replacement(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
linear_weight: torch.Tensor,
linear_weight_scale: torch.Tensor,
):
at1 = AITER_RMS_ADD_GROUP_QUANT_OP(
x=input,
residual=residual,
weight=weight,
variance_epsilon=self.epsilon,
)

at2 = AITER_BLOCK_LINEAR_OP(
A=at1[0],
B=linear_weight,
As=at1[1],
Bs=linear_weight_scale,
block_size=[128, 128],
output_dtype=input.dtype,
)
# result, residual
return at2, at1[2]

inputs = [
empty_bf16(5, 4), # input
empty_bf16(5, 4), # residual
empty_bf16(1, 5), # weight
torch.empty((2, 5), device="cuda", dtype=FP8_DTYPE), # linear_weight
empty_fp32(1, 1), # linear_weight_scale
]

pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)


class RMSNormQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
@@ -413,6 +566,14 @@ def __init__(self, config: VllmConfig):
self.patterns
)

if is_rocm_aiter_enabled():
# Fuse rms_norm + dynamic group fp8 quant
AiterRMSGroupQuantFP8Pattern(epsilon, FP8_DTYPE).register(self.patterns)

AiterFusedAddRMSGroupQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns
)

self.dump_patterns(config, self.patterns)

@VllmInductorPass.time_and_log
@@ -421,11 +582,15 @@ def __call__(self, graph: fx.Graph):
logger.debug("Replaced %s patterns", self.matched_count)

     def uuid(self) -> Any:
-        return self.hash_source(
-            self,
+        fusion_patterns = [
             RMSNormQuantPattern,
             RMSNormStaticQuantPattern,
             RMSNormDynamicQuantPattern,
             FusedAddRMSNormStaticQuantPattern,
             FusedAddRMSNormDynamicQuantPattern,
-        )
+        ]
+        if is_rocm_aiter_enabled():
+            fusion_patterns.extend(
+                [AiterRMSGroupQuantFP8Pattern, AiterFusedAddRMSGroupQuantPattern]
+            )
+        return self.hash_source(self, *fusion_patterns)
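To make the two new RMSNorm patterns above easier to follow: both rewrite an AITER RMSNorm (optionally fused with the residual add) followed by apply_w8a8_block_fp8_linear into an AITER RMSNorm-plus-group-quant op feeding rocm_aiter_gemm_w8a8_blockscale directly. Below is a rough plain-PyTorch sketch of the assumed semantics of the fused add+RMSNorm+quant op; the fp8 dtype, rounding, and exact scale layout are assumptions, and the real AITER kernel may differ.

```python
import torch

FP8_DTYPE = torch.float8_e4m3fn  # assumed fp8 format
GROUP_SIZE = 128                 # matches rocm_aiter_fp8_quant_group_size


def fused_add_rmsnorm_fp8_group_quant_ref(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sketch of the assumed semantics of rocm_aiter_rmsnorm_with_add_fp8_group_quant:
    residual add -> RMSNorm -> dynamic per-1x128-group fp8 quant.
    Returns (quantized hidden states, group scales, updated residual),
    matching how the replacement above consumes at1[0], at1[1], and at1[2]."""
    new_residual = x + residual
    var = new_residual.float().pow(2).mean(dim=-1, keepdim=True)
    normed = (new_residual.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight

    d = normed.shape[-1]  # assumed to be a multiple of GROUP_SIZE
    groups = normed.view(*normed.shape[:-1], d // GROUP_SIZE, GROUP_SIZE)
    fp8_max = torch.finfo(FP8_DTYPE).max
    scales = (groups.abs().amax(dim=-1, keepdim=True).float() / fp8_max).clamp_min(1e-12)
    q = (groups / scales).to(FP8_DTYPE).view_as(normed)
    return q, scales.squeeze(-1), new_residual
```

The quantized activations and their group scales then flow straight into rocm_aiter_gemm_w8a8_blockscale as A/As, alongside the fp8 weight and its [128, 128] block scales as B/Bs, which is why the replacement no longer needs the standalone apply_w8a8_block_fp8_linear path.
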