Skip to content

Commit 7e5b801

Browse files
committed
issue/1090: success link flash-attention.so
1 parent 49a92dc commit 7e5b801

3 files changed

Lines changed: 55 additions & 78 deletions

File tree

include/infinicore/adaptor/aten_adaptor.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55

66
#include <ATen/ATen.h>
77

8+
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
89
#include <ATen/cuda/CUDAContext.h>
910
#include <c10/cuda/CUDAGuard.h>
11+
#endif
1012

1113
namespace infinicore::adaptor {
1214
inline at::ScalarType to_at_dtype(DataType dtype) {

xmake.lua

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,6 @@ if has_config("aten") then
247247
end
248248
end
249249

250-
251250
-- cuda graph
252251
option("graph")
253252
set_default(false)
@@ -259,7 +258,6 @@ if has_config("graph") then
259258
add_defines("USE_INFINIRT_GRAPH")
260259
end
261260

262-
263261
-- InfiniCCL
264262
option("ccl")
265263
set_default(false)
@@ -473,6 +471,19 @@ target("infinicore_cpp_api")
473471
end
474472
end
475473

474+
-- Link the prebuilt flash_attn_2_cuda shared object into this library when
-- flash-attn is configured and we are targeting the qy GPU.
local flash_attn_cfg = get_config("flash-attn")
if flash_attn_cfg and flash_attn_cfg ~= "" and has_config("qy-gpu") then
    local so_path = _qy_flash_attn_cuda_so_path()
    local so_dir = path.directory(so_path)
    local so_file = path.filename(so_path)
    before_link(function (target)
        -- --no-as-needed keeps the .so on the link line even though no symbol
        -- is referenced at link time; -l: links it by exact file name, and the
        -- rpath entry makes it findable when the library is loaded.
        local link_flags = "-Wl,--no-as-needed -L" .. so_dir .. " -l:" .. so_file .. " -Wl,-rpath," .. so_dir
        target:add("shflags", link_flags, {force = true})
    end)
end
486+
476487
before_build(function (target)
477488
if has_config("aten") then
478489
local outdata = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim()

xmake/qy.lua

Lines changed: 40 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,35 @@ local FLASH_ATTN_ROOT = get_config("flash-attn")
1313

1414
local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini")
1515

16+
-- Default location of the prebuilt flash_attn_2_cuda extension inside the
-- team's build container.
-- NOTE(review): this is a user-specific absolute path baked in as the default;
-- builds outside that exact container must export FLASH_ATTN_2_CUDA_SO or
-- FLASH_ATTN_QY_CUDA_SO_CONTAINER -- confirm whether a relocatable default
-- (e.g. derived from the active Python env) is feasible.
local FLASH_ATTN_QY_CUDA_SO_CONTAINER_DEFAULT =
    "/home/shangyouren/miniconda3/envs/xiaobase/lib/python3.12/site-packages/flash_attn_2_cuda.cpython-312-x86_64-linux-gnu.so"

-- Resolve the flash_attn_2_cuda shared object to link against.
-- Resolution order:
--   1. FLASH_ATTN_2_CUDA_SO              -- exact .so file override; must exist
--   2. FLASH_ATTN_QY_CUDA_SO_CONTAINER   -- alternate container path via env
--   3. FLASH_ATTN_QY_CUDA_SO_CONTAINER_DEFAULT
-- Raises (xmake `raise`) when the resolved path is not a regular file.
-- Deliberately global (no `local`): xmake.lua calls it from another script file.
function _qy_flash_attn_cuda_so_path()
    -- Highest priority: override the exact `.so` file to link.
    local env_path = os.getenv("FLASH_ATTN_2_CUDA_SO")
    if env_path then
        -- Fix: trim BEFORE the emptiness check. The original trimmed after it,
        -- so a whitespace-only value passed the `~= ""` test and then raised a
        -- confusing 'is not a file: ' error with an empty path. A blank-ish
        -- override now falls through to the container-path resolution instead.
        env_path = env_path:trim()
        if env_path ~= "" then
            if not os.isfile(env_path) then
                raise("qy+flash-attn: FLASH_ATTN_2_CUDA_SO is not a file: %s", env_path)
            end
            return env_path
        end
    end

    -- Second priority: allow overriding the "expected" container path via env.
    local container_path = os.getenv("FLASH_ATTN_QY_CUDA_SO_CONTAINER")
    if not container_path or container_path == "" then
        container_path = FLASH_ATTN_QY_CUDA_SO_CONTAINER_DEFAULT
    end

    if not os.isfile(container_path) then
        raise(
            "qy+flash-attn: expected %s\n Install flash-attn in the conda env, or export FLASH_ATTN_2_CUDA_SO to your .so path.",
            container_path
        )
    end
    return container_path
end
44+
1645
add_includedirs("/usr/local/denglin/sdk/include", "../include")
1746
add_linkdirs("/usr/local/denglin/sdk/lib")
1847
add_links("curt", "cublas", "cudnn")
@@ -177,89 +206,24 @@ target("infiniccl-qy")
177206
target_end()
178207

179208
target("flash-attn-qy")
    -- Phony marker target: the prebuilt flash_attn_2_cuda .so is linked
    -- directly by the consuming target (see xmake.lua), so nothing is compiled
    -- here. This target only sanity-checks the container environment and
    -- records torch/python paths for the build.
    set_kind("phony")
    set_default(false)

    if FLASH_ATTN_ROOT and FLASH_ATTN_ROOT ~= "" then
        before_build(function (target)
            target:add("includedirs", "/usr/local/denglin/sdk/include", {public = true})

            -- Query the active Python interpreter for torch and CPython paths.
            local torch_dir = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim()
            local py_include = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_paths()['include'])"}):trim()
            local py_libdir = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_config_var('LIBDIR'))"}):trim()

            -- Validate build/runtime env in container and keep these paths available for downstream linking.
            target:add("includedirs", torch_dir .. "/include", torch_dir .. "/include/torch/csrc/api/include", py_include, {public = false})
            target:add("linkdirs", torch_dir .. "/lib", py_libdir, {public = false})
        end)
    else
        before_build(function (target)
            print("Flash Attention not available, skipping flash-attn-qy integration")
        end)
    end
target_end()

0 commit comments

Comments
 (0)