-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update HSTU and use the OSS wrapper for non-persisent kernels (#53)
Summary: Add the following features to ragged_hstu: 1. Add tflops metric 2. Use _RaggedAttentionRelativeBiasFunction to wrap the Triton kernel 3. Add backward Pull Request resolved: #53 Test Plan: ``` $ python run.py --op ragged_attention --metrics latency,tflops --mode bwd x_val hstu_triton_ragged_attention-tflops hstu_triton_ragged_attention-latency ----------------- ------------------------------------- -------------------------------------- (8, 4, 512, 2048) 0.306747 2.81939 (8, 4, 512, 2048) 1.65614 0.867936 (8, 4, 512, 2048) 2.00125 0.84768 (8, 4, 512, 2048) 2.13756 0.991968 (8, 4, 512, 2048) 1.96315 0.902976 (8, 4, 512, 2048) 1.50214 0.836192 (8, 4, 512, 2048) 1.34825 0.859936 (8, 4, 512, 2048) 1.90546 0.97408 (8, 4, 512, 2048) 1.72114 0.902368 (8, 4, 512, 2048) 2.30999 1.01107 ``` Reviewed By: manman-ren Differential Revision: D66021701 Pulled By: xuzhao9 fbshipit-source-id: b0d9f32d49e02c113e4aafa597be68c17d952283
- Loading branch information
1 parent
c08a2a8
commit c2ef66a
Showing
6 changed files
with
166 additions
and
51 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Submodule generative-recommenders
updated
23 files
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
diff --git a/generative_recommenders/ops/triton/triton_ragged_hstu_attention.py b/generative_recommenders/ops/triton/triton_ragged_hstu_attention.py | ||
index b4e318b..d6bc894 100644 | ||
--- a/generative_recommenders/ops/triton/triton_ragged_hstu_attention.py | ||
+++ b/generative_recommenders/ops/triton/triton_ragged_hstu_attention.py | ||
@@ -36,7 +36,7 @@ try: | ||
VersionedSpec, | ||
) | ||
except ImportError: | ||
- from hammer.oss.generative_recommenders.ops.triton.utils import ( | ||
+ from generative_recommenders.ops.triton.utils import ( | ||
_switch_to_contiguous_if_needed, | ||
autotune_max_seq_len, | ||
NamedSpecType, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import os | ||
import subprocess | ||
import sys | ||
from pathlib import Path | ||
|
||
PATCH_DIR = str( | ||
Path(__file__) | ||
.parent.parent.parent.joinpath("submodules", "generative-recommenders") | ||
.absolute() | ||
) | ||
PATCH_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "hstu.patch") | ||
|
||
|
||
def install_hstu(): | ||
try: | ||
subprocess.check_output( | ||
[ | ||
"patch", | ||
"-p1", | ||
"--forward", | ||
"-i", | ||
PATCH_FILE, | ||
"-r", | ||
"/tmp/rej", | ||
], | ||
cwd=PATCH_DIR, | ||
) | ||
except subprocess.SubprocessError as e: | ||
output_str = str(e.output) | ||
if "previously applied" in output_str: | ||
return | ||
else: | ||
print(str(output_str)) | ||
sys.exit(1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters