@@ -13,6 +13,35 @@ local FLASH_ATTN_ROOT = get_config("flash-attn")
1313
-- Root of the InfiniCore installation. Prefers the INFINI_ROOT env var and
-- falls back to "<home>/.infini" (HOMEPATH on Windows, HOME elsewhere).
-- Reconstructed: scrape artifacts had put spaces inside the string literals,
-- which would make os.getenv look up the wrong variable names.
local INFINI_ROOT = os.getenv("INFINI_ROOT")
    or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini")
1515
-- Default location of the prebuilt flash-attn CUDA extension inside the
-- build container (a conda env). Can be overridden at resolution time via
-- FLASH_ATTN_QY_CUDA_SO_CONTAINER, or bypassed entirely with
-- FLASH_ATTN_2_CUDA_SO (see _qy_flash_attn_cuda_so_path).
-- NOTE(review): this is a user-specific path baked into the build script;
-- fine for a fixed container image, brittle anywhere else.
local FLASH_ATTN_QY_CUDA_SO_CONTAINER_DEFAULT =
    "/home/shangyouren/miniconda3/envs/xiaobase/lib/python3.12/site-packages/flash_attn_2_cuda.cpython-312-x86_64-linux-gnu.so"
-- Resolve the flash-attn CUDA extension (.so) to link against.
--
-- Resolution order:
--   1. FLASH_ATTN_2_CUDA_SO            -- exact .so file override
--   2. FLASH_ATTN_QY_CUDA_SO_CONTAINER -- alternate container path
--   3. FLASH_ATTN_QY_CUDA_SO_CONTAINER_DEFAULT
--
-- Raises (via xmake's `raise`) if the resolved path is not an existing file.
function _qy_flash_attn_cuda_so_path()
    -- Highest priority: override the exact `.so` file to link.
    local env_path = os.getenv("FLASH_ATTN_2_CUDA_SO")
    if env_path then
        -- Trim BEFORE the emptiness check: the original checked ~= "" first,
        -- so a whitespace-only value raised 'not a file: <empty>' instead of
        -- falling through to the container path.
        env_path = env_path:trim()
        if env_path ~= "" then
            if not os.isfile(env_path) then
                raise("qy+flash-attn: FLASH_ATTN_2_CUDA_SO is not a file: %s", env_path)
            end
            return env_path
        end
    end

    -- Second priority: allow overriding the "expected" container path via env.
    local container_path = os.getenv("FLASH_ATTN_QY_CUDA_SO_CONTAINER")
    if not container_path or container_path == "" then
        container_path = FLASH_ATTN_QY_CUDA_SO_CONTAINER_DEFAULT
    end

    if not os.isfile(container_path) then
        raise(
            "qy+flash-attn: expected %s\nInstall flash-attn in the conda env, or export FLASH_ATTN_2_CUDA_SO to your .so path.",
            container_path
        )
    end
    return container_path
end
44+
-- Denglin SDK: CUDA-compatible headers and runtime libraries, applied to all
-- targets defined below in this scope.
-- Reconstructed: scrape artifacts had put a leading space inside each string
-- literal, which would produce nonexistent paths and library names.
add_includedirs("/usr/local/denglin/sdk/include", "../include")
add_linkdirs("/usr/local/denglin/sdk/lib")
add_links("curt", "cublas", "cudnn")
@@ -177,89 +206,24 @@ target("infiniccl-qy")
177206target_end ()
178207
179208target (" flash-attn-qy" )
180- set_kind (" shared " )
209+ set_kind (" phony " )
181210 set_default (false )
211+
182212
183- set_languages (" cxx17" )
184- add_cxxflags (" -std=c++17" )
185- add_cuflags (" --std=c++17" , {force = true })
186-
187- -- 🔥 DLCC 规则
188- add_rules (" qy.cuda" , {override = true })
189-
190- if FLASH_ATTN_ROOT and FLASH_ATTN_ROOT ~= false and FLASH_ATTN_ROOT ~= " " then
191-
192- -- ⭐⭐⭐ 关键:用 on_load(不是 before_build)
193- on_load (function (target )
194-
213+ if FLASH_ATTN_ROOT and FLASH_ATTN_ROOT ~= " " then
214+ before_build (function (target )
215+ target :add (" includedirs" , " /usr/local/denglin/sdk/include" , {public = true })
195216 local TORCH_DIR = os .iorunv (" python" , {" -c" , " import torch, os; print(os.path.dirname(torch.__file__))" }):trim ()
196217 local PYTHON_INCLUDE = os .iorunv (" python" , {" -c" , " import sysconfig; print(sysconfig.get_paths()['include'])" }):trim ()
197218 local PYTHON_LIB_DIR = os .iorunv (" python" , {" -c" , " import sysconfig; print(sysconfig.get_config_var('LIBDIR'))" }):trim ()
198- local LIB_PYTHON = os .iorunv (" python" , {" -c" , " import glob,sysconfig,os;print(glob.glob(os.path.join(sysconfig.get_config_var('LIBDIR'),'libpython*.so'))[0])" }):trim ()
199-
200- -- ✅ CUDA(最关键)
201- target :add (" includedirs" , " /usr/local/denglin/sdk/include" , {public = true })
202-
203- -- ✅ flash-attn
204- target :add (" includedirs" , FLASH_ATTN_ROOT .. " /csrc" )
205- target :add (" includedirs" , FLASH_ATTN_ROOT .. " /csrc/flash_attn" )
206- target :add (" includedirs" , FLASH_ATTN_ROOT .. " /csrc/flash_attn/src" )
207- target :add (" includedirs" , FLASH_ATTN_ROOT .. " /csrc/common" )
208-
209- -- ✅ torch
210- target :add (" includedirs" , TORCH_DIR .. " /include" )
211- target :add (" includedirs" , TORCH_DIR .. " /include/torch/csrc/api/include" )
212-
213- -- ⚠️ 很关键:ATen 有些头在这里
214- target :add (" includedirs" , TORCH_DIR .. " /include/TH" )
215- target :add (" includedirs" , TORCH_DIR .. " /include/THC" )
216-
217- -- ✅ python
218- target :add (" includedirs" , PYTHON_INCLUDE )
219-
220- -- ✅ cutlass
221- if CUTLASS_ROOT then
222- target :add (" includedirs" , CUTLASS_ROOT .. " /include" )
223- end
224-
225- -- link dirs
226- target :add (" linkdirs" , TORCH_DIR .. " /lib" )
227- target :add (" linkdirs" , PYTHON_LIB_DIR )
228- target :add (" linkdirs" , " /usr/local/denglin/sdk/lib" )
229-
230- -- links
231- target :add (" links" ,
232- " curt" ,
233- " cublas" ,
234- " cudnn" ,
235- " torch" ,
236- " torch_cpu" ,
237- " torch_cuda" ,
238- " c10" ,
239- " c10_cuda" ,
240- " torch_python" ,
241- LIB_PYTHON
242- )
219+
220+ -- Validate build/runtime env in container and keep these paths available for downstream linking.
221+ target :add (" includedirs" , TORCH_DIR .. " /include" , TORCH_DIR .. " /include/torch/csrc/api/include" , PYTHON_INCLUDE , {public = false })
222+ target :add (" linkdirs" , TORCH_DIR .. " /lib" , PYTHON_LIB_DIR , {public = false })
243223 end )
244-
245- -- ✅ C++ host
246- add_files (FLASH_ATTN_ROOT .. " /csrc/flash_attn/flash_api.cpp" )
247-
248- -- ✅ CUDA kernel
249- add_files (FLASH_ATTN_ROOT .. " /csrc/flash_attn/src/*.cu" )
250-
251- -- flags
252- add_cxflags (" -fPIC" , {force = true })
253- add_cuflags (" -O2" , " -fPIC" , " --expt-relaxed-constexpr" , " --use_fast_math" , {force = true })
254-
255- add_ldflags (" -Wl,--no-undefined" , {force = true })
256-
257224 else
258- on_load (function ()
259- print (" Flash Attention not available, skipping flash-attn-qy build " )
225+ before_build (function (target )
226+ print (" Flash Attention not available, skipping flash-attn-qy integration " )
260227 end )
261228 end
262-
263- on_install (function (target ) end )
264-
265229target_end ()
0 commit comments