
Commit abb7ac6

xuzhao9 authored and facebook-github-bot committed
Patch when installing the xformers (#61)
Summary: This patches xformers, because its FA3 extension build fails due to a missing link against libcuda.so: facebookresearch/xformers#1157

Fixes #20

Pull Request resolved: #61
Reviewed By: FindHao
Differential Revision: D66273474
Pulled By: xuzhao9
fbshipit-source-id: 81898ccd005750937ac3cfd639c2303975ef1abe
1 parent e2bbc48 commit abb7ac6
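
For context, the failure mode is an undefined-symbol error because the FA3 extension is never linked against the CUDA driver library; the patch shipped in this commit adds libraries=["cuda"] to the extension definition in xformers' setup.py. A minimal sketch of the same idea, assuming a torch CUDAExtension build (the extension name and source file here are illustrative, not xformers' real targets):

# Sketch only: link a CUDA extension against the driver library (libcuda.so)
# so driver-API symbols such as cuTensorMapEncodeTiled resolve at link time.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="fa3_sketch",  # illustrative name
    ext_modules=[
        CUDAExtension(
            name="fa3_sketch",
            sources=["fa3_sketch.cu"],  # placeholder source file
            libraries=["cuda"],  # the fix: adds -lcuda to the link line
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)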

3 files changed (+64, -8)


install.py (+2, -8)
@@ -77,14 +77,6 @@ def install_liger():
     subprocess.check_call(cmd)
 
 
-def install_xformers():
-    os_env = os.environ.copy()
-    os_env["TORCH_CUDA_ARCH_LIST"] = "8.0;9.0;9.0a"
-    XFORMERS_PATH = REPO_PATH.joinpath("submodules", "xformers")
-    cmd = ["pip", "install", "-e", XFORMERS_PATH]
-    subprocess.check_call(cmd, env=os_env)
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(allow_abbrev=False)
     parser.add_argument("--fbgemm", action="store_true", help="Install FBGEMM GPU")
@@ -145,6 +137,8 @@ def install_xformers():
         install_liger()
     if args.xformers or args.all:
         logger.info("[tritonbench] installing xformers...")
+        from tools.xformers.install import install_xformers
+
         install_xformers()
     if args.hstu or args.all:
         logger.info("[tritonbench] installing hstu...")
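
The import is deferred into the --xformers branch, presumably so that tools/xformers/install.py (and its patch step) is only loaded when xformers is actually requested. Usage is unchanged; assuming the repo-root install.py entry point:

python install.py --xformers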

tools/xformers/install.py (+40)
@@ -0,0 +1,40 @@
+import os
+import subprocess
+import sys
+from pathlib import Path
+
+REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent
+PATCH_DIR = str(REPO_PATH.joinpath("submodules", "xformers").absolute())
+PATCH_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "xformers.patch")
+
+
+def patch_xformers():
+    try:
+        subprocess.check_output(
+            [
+                "patch",
+                "-p1",
+                "--forward",
+                "-i",
+                PATCH_FILE,
+                "-r",
+                "/tmp/rej",
+            ],
+            cwd=PATCH_DIR,
+        )
+    except subprocess.SubprocessError as e:
+        output_str = str(e.output)
+        if "previously applied" in output_str:
+            return
+        else:
+            print(str(output_str))
+            sys.exit(1)
+
+
+def install_xformers():
+    patch_xformers()
+    os_env = os.environ.copy()
+    os_env["TORCH_CUDA_ARCH_LIST"] = "8.0;9.0;9.0a"
+    XFORMERS_PATH = REPO_PATH.joinpath("submodules", "xformers")
+    cmd = ["pip", "install", "-e", XFORMERS_PATH]
+    subprocess.check_call(cmd, env=os_env)
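
One subtlety in patch_xformers above: GNU patch with --forward exits nonzero both on a genuine failure and when the patch was already applied, and check_output captures only stdout into e.output. A minimal alternative sketch that folds stderr into the captured output (the helper name apply_patch_once is hypothetical, not part of this commit):

import subprocess
import sys

def apply_patch_once(patch_file: str, repo_dir: str) -> None:
    # --forward makes GNU patch skip (and exit nonzero) when the patch
    # was already applied; capture stdout and stderr together.
    result = subprocess.run(
        ["patch", "-p1", "--forward", "-i", patch_file],
        cwd=repo_dir,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )
    if result.returncode != 0 and "previously applied" not in result.stdout:
        print(result.stdout)
        sys.exit(1)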

tools/xformers/xformers.patch (+22)
@@ -0,0 +1,22 @@
+From 1056e56f873fa6a097de3a7c1ceeeed66676ae82 Mon Sep 17 00:00:00 2001
+From: Xu Zhao <[email protected]>
+Date: Wed, 20 Nov 2024 19:19:46 -0500
+Subject: [PATCH] Link to cuda library
+
+---
+ setup.py | 2 ++
+ 1 file changed, 2 insertions(+)
+
+diff --git a/setup.py b/setup.py
+index 6eaa50904..c804b4817 100644
+--- a/setup.py
++++ b/setup.py
+@@ -356,6 +356,8 @@ def get_flash_attention3_extensions(cuda_version: int, extra_compile_args):
+                 Path(flash_root) / "hopper",
+             ]
+         ],
++        # Without this we get an error about cuTensorMapEncodeTiled not defined
++        libraries=["cuda"],
+     )
+ ]
+
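
cuTensorMapEncodeTiled is a CUDA driver-API entry point (used for Hopper's TMA descriptors) exported by libcuda.so, so adding libraries=["cuda"] lets the linker resolve it. A quick way to confirm the symbol exists in the driver library on a given machine (a sketch; the soname can vary by install):

import ctypes

# Assumes an NVIDIA driver is installed; libcuda.so.1 is the usual soname.
libcuda = ctypes.CDLL("libcuda.so.1")
# hasattr triggers a symbol lookup; True on CUDA 12+ capable drivers.
print(hasattr(libcuda, "cuTensorMapEncodeTiled"))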
