Skip to content

Commit c1aab43

Browse files
committed
Add install to colfax
1 parent fec9876 commit c1aab43

File tree

1 file changed

+6
-16
lines changed

1 file changed

+6
-16
lines changed

install.py

+6-16
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,6 @@ def test_fbgemm():
5252
print("OK")
5353

5454

55-
def install_cutlass():
56-
from tools.cutlass_kernels.install import install_colfax_cutlass
57-
58-
install_colfax_cutlass()
59-
60-
6155
def install_fa2(compile=False):
6256
if compile:
6357
# compile from source (slow)
@@ -83,12 +77,6 @@ def install_liger():
8377
subprocess.check_call(cmd)
8478

8579

86-
def install_tk():
87-
from tools.tk.install import install_tk
88-
89-
install_tk()
90-
91-
9280
def install_xformers():
9381
os_env = os.environ.copy()
9482
os_env["TORCH_CUDA_ARCH_LIST"] = "8.0;9.0;9.0a"
@@ -101,7 +89,7 @@ def install_xformers():
10189
parser = argparse.ArgumentParser(allow_abbrev=False)
10290
parser.add_argument("--fbgemm", action="store_true", help="Install FBGEMM GPU")
10391
parser.add_argument(
104-
"--cutlass", action="store_true", help="Install optional CUTLASS kernels"
92+
"--colfax", action="store_true", help="Install optional Colfax CUTLASS kernels"
10593
)
10694
parser.add_argument(
10795
"--fa2", action="store_true", help="Install optional flash_attention 2 kernels"
@@ -139,14 +127,16 @@ def install_xformers():
139127
if args.fa3 or args.all:
140128
logger.info("[tritonbench] installing fa3...")
141129
install_fa3()
142-
if args.cutlass or args.all:
143-
logger.info("[tritonbench] installing cutlass-kernels...")
144-
install_cutlass()
130+
if args.colfax or args.all:
131+
logger.info("[tritonbench] installing colfax cutlass-kernels...")
132+
from tools.cutlass_kernels.install import install_colfax_cutlass
133+
install_colfax_cutlass()
145134
if args.jax or args.all:
146135
logger.info("[tritonbench] installing jax...")
147136
install_jax()
148137
if args.tk or args.all:
149138
logger.info("[tritonbench] installing thunderkittens...")
139+
from tools.tk.install import install_tk
150140
install_tk()
151141
if args.liger or args.all:
152142
logger.info("[tritonbench] installing liger-kernels...")

0 commit comments

Comments
 (0)