[Rocm][torch.compile] Adding layernorm + fp8 block quant and silu + fp8 block quant for Aiter #25693
Open

charlifu wants to merge 41 commits into vllm-project:main from ROCm:amd/aiter_fusion_pass
+362 −22
Commits (41):
2f538fa  add aiter silu fused kernel (charlifu)
9d6507b  add silu fusion pass (charlifu)
b901f27  fix pass (charlifu)
1d11425  workable silu_mul fusion pass (charlifu)
b48f84d  fix aiter fp8 linear support (charlifu)
41e7e2f  add is rocm aiter linear enabled (charlifu)
9940a40  fusion for AITER group quant RMSNorm and AITER w8a8 gemm (micah-wil)
cd059b9  fix undefined symbol conditional on is_rocm_aiter_linear_enabled (micah-wil)
6cf02a9  only add aiter rmsnorm fusion patterns if aiter is enabled (micah-wil)
f2cd510  Merge branch 'main' into amd/aiter_fusion_pass (charlifu)
7298b55  fix silu + fp8 block quant pass (charlifu)
19348df  Merge branch 'main' into amd/aiter_fusion_pass (charlifu)
254a34b  fix layernorm pass (charlifu)
b835c6a  Merge branch 'main' into amd/aiter_fusion_pass (charlifu)
736019a  Merge branch 'main' into amd/aiter_fusion_pass (charlifu)
4793f39  fix import (charlifu)
af6bc06  use fusion pass to enable aiter fused_mul_add (micah-wil)
824c419  Merge branch 'main' into amd/aiter_fusion_pass (charlifu)
5754863  Merge branch 'main' into micah/aiter_fused_muladd (charlifu)
f600359  Merge branch 'micah/aiter_fused_muladd' into amd/aiter_fusion_pass (charlifu)
3f2ee60  Fix conditions for applying fused_mul_add (micah-wil)
f69d89e  Merge branch 'main' into amd/aiter_fusion_pass (charlifu)
9153e6b  Merge branch 'main' into amd/aiter_fusion_pass (wuhuikx)
3f250be  Merge branch 'main' into amd/aiter_fusion_pass (charlifu)
06516eb  Merge branch 'main' into amd/aiter_fusion_pass (charlifu)
f450db9  fix rms+quant pass (charlifu)
1d2803b  Merge branch 'main' into amd/aiter_fusion_pass (charlifu)
afdd2f5  remove mul+add pass (charlifu)
c3cf015  Apply suggestions from code review (charlifu)
1001ee0  fix type hint (charlifu)
8b1792a  fix silu pass (charlifu)
881ba7e  fix silu pass (charlifu)
9054adf  Merge branch 'main' into amd/aiter_fusion_pass (charlifu)
ee41e28  fix merging (charlifu)
4710e07  fix fusion pass (charlifu)
9712dcb  remove print (charlifu)
8f539db  Merge branch 'main' into amd/aiter_fusion_pass (charlifu)
57eebbc  add ops in aiter_ops.py (charlifu)
62aef2e  fix silu pass (charlifu)
e63f937  fix rms passes (charlifu)
fa0d969  fix rms passes (charlifu)
Diff (excerpts):

@@ -12,6 +12,7 @@
 )
 from torch._ops import OpOverload

+from vllm import envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -75,6 +76,93 @@ def register(self, pm_pass: PatternMatcherPass):
         raise NotImplementedError


+def is_rocm_aiter_linear_enabled() -> bool:
+    return (
+        current_platform.is_rocm()
+        and envs.VLLM_ROCM_USE_AITER
+        and envs.VLLM_ROCM_USE_AITER_LINEAR
+    )
+
+
+if is_rocm_aiter_linear_enabled():
+    import aiter as rocm_aiter
+    from aiter.ops.triton.activation import act_mul_and_fp8_group_quant
+
+    from vllm.utils import direct_register_custom_op
+
+    rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8
+    rocm_aiter_fp8_quant_group_size = 128
+
+    def _rocm_aiter_act_mul_and_fp8_group_quant_impl(
+        x: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return act_mul_and_fp8_group_quant(
+            x,
+            activation="silu",
+            group_size=rocm_aiter_fp8_quant_group_size,
+            dtype_quant=rocm_aiter_fp8_dtype,
+        )
+
+    def _rocm_aiter_act_mul_and_fp8_group_quant_fake(
+        x: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        M, N = x.shape
+        assert N % 2 == 0
+        N_half = N // 2
+        x_fp8 = torch.empty((M, N_half), dtype=rocm_aiter_fp8_dtype, device=x.device)
+        out_bs = torch.empty(
+            (
+                M,
+                (N_half + rocm_aiter_fp8_quant_group_size - 1)
+                // rocm_aiter_fp8_quant_group_size,
+            ),
+            dtype=torch.float32,
+            device=x.device,
+        )
+        return x_fp8, out_bs
+
+    direct_register_custom_op(
+        op_name="rocm_aiter_act_mul_and_fp8_group_quant",
+        op_func=_rocm_aiter_act_mul_and_fp8_group_quant_impl,
+        mutates_args=[],
+        fake_impl=_rocm_aiter_act_mul_and_fp8_group_quant_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+    AITER_BLOCK_QUANT_OP = torch.ops.vllm.rocm_aiter_per1x128_quant.default
+    FUSED_SILU_MUL_QUANT_OP = (
+        torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
+    )
+
+
+class AiterSiluMulFp8BlockQuantPattern(ActivationQuantPattern):
+    def __init__(self):
+        pass
+
+    def register(self, pm_pass: PatternMatcherPass):
+        def pattern(
+            input: torch.Tensor,
+            result_silu_mul: torch.Tensor,
+        ):
+            at1 = auto_functionalized(
+                SILU_MUL_OP, result=result_silu_mul, input=input
+            )
+            at2 = AITER_BLOCK_QUANT_OP(x=at1[1])
+            return at2[0], at2[1]
+
+        def replacement(
+            input: torch.Tensor,
+            result_silu_mul: torch.Tensor,
+        ):
+            at = FUSED_SILU_MUL_QUANT_OP(x=input)
+            return at[0], at[1]
+
+        inputs = [
+            empty_bf16(5, 4),  # input
+            empty_bf16(5, 4),  # result_silu_mul
+        ]
+
+        register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
+
+
 class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
     """
     Fusion for SiluMul+Fp8StaticQuant Pattern
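For intuition on what the fused op produces, the fake (meta) implementation above pins down the shapes the compiler sees: an (M, N) gate/up activation yields fp8 data of shape (M, N/2) plus one float32 scale per 128-wide group. The standalone snippet below only reproduces that shape arithmetic; it is not part of the diff and the helper name is made up for illustration.

    # Standalone sketch: reproduces the output-shape arithmetic of the fake impl above.
    def fused_silu_mul_quant_shapes(M: int, N: int, group_size: int = 128):
        """Return (fp8 data shape, per-group scale shape) for an (M, N) gate/up input."""
        assert N % 2 == 0, "input concatenates gate and up halves, so N must be even"
        N_half = N // 2                                      # silu(gate) * up halves the width
        n_groups = (N_half + group_size - 1) // group_size   # ceil-div: one scale per group
        return (M, N_half), (M, n_groups)

    # Example: a (32, 2816) activation quantizes to (32, 1408) fp8 values
    # with (32, 11) float32 group scales.
    print(fused_silu_mul_quant_shapes(32, 2816))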
@@ -198,6 +286,10 @@ def __init__(self, config: VllmConfig):
         pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
         pattern_silu_mul_nvfp4.register(self.patterns)

+        if is_rocm_aiter_linear_enabled():
+            pattern_silu_mul_aiter_block_fp8 = AiterSiluMulFp8BlockQuantPattern()
+            pattern_silu_mul_aiter_block_fp8.register(self.patterns)
+
         self.dump_patterns(config, self.patterns)

     @VllmInductorPass.time_and_log
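For context, the new pattern only fires when the ROCm AITER environment variables checked by is_rocm_aiter_linear_enabled() are set and vLLM's fusion passes are enabled in the compilation config. A hypothetical enablement sketch follows; the config field names are recalled from memory rather than taken from this PR, so verify them against your vLLM version.

    # Hypothetical sketch, not from this PR: enabling the fusion passes on an
    # AITER-enabled ROCm build. The model name is a placeholder and the
    # compilation_config fields may differ between vLLM versions.
    import os

    os.environ["VLLM_ROCM_USE_AITER"] = "1"         # checked by is_rocm_aiter_linear_enabled()
    os.environ["VLLM_ROCM_USE_AITER_LINEAR"] = "1"  # checked by is_rocm_aiter_linear_enabled()

    from vllm import LLM

    llm = LLM(
        model="some-org/some-fp8-block-quantized-model",  # placeholder
        compilation_config={"pass_config": {"enable_fusion": True}},
    )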
@@ -206,9 +298,11 @@ def __call__(self, graph: torch.fx.Graph):
         logger.debug("Replaced %s patterns", self.matched_count)

     def uuid(self):
-        return VllmInductorPass.hash_source(
-            self,
+        fusion_patterns = [
+            ActivationQuantPattern,
             SiluMulFp8StaticQuantPattern,
             SiluMulNvfp4QuantPattern,
-        )
+        ]
+        if is_rocm_aiter_linear_enabled():
+            fusion_patterns.append(AiterSiluMulFp8BlockQuantPattern)
+        return VllmInductorPass.hash_source(self, *fusion_patterns)
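The uuid matters because inductor uses it to key its cache of compiled graphs for custom passes; deriving it from the source of every registered pattern class, including the aiter one when enabled, means editing any pattern invalidates stale cached graphs. A rough, illustrative sketch of what such a source hash can look like (not vLLM's actual implementation):

    # Illustrative only: a uuid derived from the source code of the pass and its
    # pattern classes changes whenever any of them is edited.
    import hashlib
    import inspect

    def hash_source(*objs) -> str:
        hasher = hashlib.sha256()
        for obj in objs:
            src = inspect.getsource(obj if inspect.isclass(obj) else type(obj))
            hasher.update(src.encode())
        return hasher.hexdigest()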
Review comment:
Can you check if the latest aiter allows you to skip direct_register_custom_op? I remember most ops should now work without calling direct_register_custom_op on the vLLM side, since the registration is done in the AITER repository. Moreover, removing the direct_register_custom_op wrappers can reduce additional CPU overhead; direct_register_custom_op can be costly in terms of overhead. Please take a look at the benchmarking results in ROCm#717 (the second and third case), which show that removing direct_register_custom_op on the vLLM side improves the perf.
Reply:
Hey, thanks for the feedback. Is there a version of aiter which has aiter.ops.triton.fused_fp8_quant and also has these direct_register_custom_ops that you mentioned? I wasn't able to figure out how to call act_mul_and_fp8_group_quant without calling direct_register_custom_op first. I'd be happy to investigate further if you can point me in the right direction; otherwise, I think we can always come back and get rid of these direct_register_custom_op calls if needed.
Reply:
We can come back to this in a later PR, as the 355_wip aiter commit does not have that feature.
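For completeness, here is a rough way to quantify the wrapper overhead being discussed. This is not part of the PR; it assumes a ROCm GPU with aiter installed and the rocm_aiter_act_mul_and_fp8_group_quant custom op from this PR already registered (for example by importing the fusion pass module), and the tensor shape and iteration count are arbitrary.

    # Rough micro-benchmark sketch for the overhead question above; not part of the PR.
    import time

    import aiter
    import torch
    from aiter.ops.triton.activation import act_mul_and_fp8_group_quant

    x = torch.randn(1024, 4096, dtype=torch.bfloat16, device="cuda")

    def bench(fn, iters=200):
        fn(x)                               # warm-up / triton compile
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            fn(x)
        torch.cuda.synchronize()
        return (time.perf_counter() - start) / iters

    direct = bench(lambda t: act_mul_and_fp8_group_quant(
        t, activation="silu", group_size=128, dtype_quant=aiter.dtypes.fp8))
    wrapped = bench(torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant)
    print(f"direct call: {direct * 1e6:.1f} us, via torch.ops wrapper: {wrapped * 1e6:.1f} us")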