@@ -52,12 +52,6 @@ def test_fbgemm():
52
52
print ("OK" )
53
53
54
54
55
- def install_cutlass ():
56
- from tools .cutlass_kernels .install import install_colfax_cutlass
57
-
58
- install_colfax_cutlass ()
59
-
60
-
61
55
def install_fa2 (compile = False ):
62
56
if compile :
63
57
# compile from source (slow)
@@ -83,12 +77,6 @@ def install_liger():
83
77
subprocess .check_call (cmd )
84
78
85
79
86
- def install_tk ():
87
- from tools .tk .install import install_tk
88
-
89
- install_tk ()
90
-
91
-
92
80
def install_xformers ():
93
81
os_env = os .environ .copy ()
94
82
os_env ["TORCH_CUDA_ARCH_LIST" ] = "8.0;9.0;9.0a"
@@ -101,7 +89,7 @@ def install_xformers():
101
89
parser = argparse .ArgumentParser (allow_abbrev = False )
102
90
parser .add_argument ("--fbgemm" , action = "store_true" , help = "Install FBGEMM GPU" )
103
91
parser .add_argument (
104
- "--cutlass " , action = "store_true" , help = "Install optional CUTLASS kernels"
92
+ "--colfax " , action = "store_true" , help = "Install optional Colfax CUTLASS kernels"
105
93
)
106
94
parser .add_argument (
107
95
"--fa2" , action = "store_true" , help = "Install optional flash_attention 2 kernels"
@@ -139,14 +127,16 @@ def install_xformers():
139
127
if args .fa3 or args .all :
140
128
logger .info ("[tritonbench] installing fa3..." )
141
129
install_fa3 ()
142
- if args .cutlass or args .all :
143
- logger .info ("[tritonbench] installing cutlass-kernels..." )
144
- install_cutlass ()
130
+ if args .colfax or args .all :
131
+ logger .info ("[tritonbench] installing colfax cutlass-kernels..." )
132
+ from tools .cutlass_kernels .install import install_colfax_cutlass
133
+ install_colfax_cutlass ()
145
134
if args .jax or args .all :
146
135
logger .info ("[tritonbench] installing jax..." )
147
136
install_jax ()
148
137
if args .tk or args .all :
149
138
logger .info ("[tritonbench] installing thunderkittens..." )
139
+ from tools .tk .install import install_tk
150
140
install_tk ()
151
141
if args .liger or args .all :
152
142
logger .info ("[tritonbench] installing liger-kernels..." )
0 commit comments