Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions include/infinicore/adaptor/aten_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

#include <ATen/ATen.h>

#ifdef ENABLE_NVIDIA_API
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#endif
Expand Down Expand Up @@ -33,16 +33,17 @@ inline at::Device to_at_device(const Device &device) {
return at::Device(at::kCUDA, device.getIndex());
} else if (device.getType() == Device::Type::CPU) {
return at::Device(at::kCPU);
} else if (device.getType() == Device::Type::QY) {
return at::Device(at::kCUDA, device.getIndex());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这代码nv能编译吗

} else {
throw std::runtime_error("Unsupported device type for ATen");
}
}

at::Tensor to_aten_tensor(const infinicore::Tensor &t);

#ifdef ENABLE_NVIDIA_API
c10::cuda::CUDAStream get_cuda_stream();
#endif

} // namespace infinicore::adaptor

#endif // ENABLE_ATEN
2 changes: 0 additions & 2 deletions src/infinicore/adaptor/aten_adaptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,10 @@ at::Tensor to_aten_tensor(const infinicore::Tensor &t) {
options);
}

#ifdef ENABLE_NVIDIA_API
// Wrap infinicore's current raw stream as a non-owning c10::cuda::CUDAStream
// bound to the current device index, so ATen kernels are enqueued on the same
// stream the runtime is using. getStreamFromExternal does not take ownership;
// the underlying stream's lifetime is managed by infinicore::context.
c10::cuda::CUDAStream get_cuda_stream() {
return c10::cuda::getStreamFromExternal(
cudaStream_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex());
}
#endif

} // namespace infinicore::adaptor

Expand Down
2 changes: 1 addition & 1 deletion src/infiniop/ops/avg_pool1d/cuda/kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ __device__ void avgPool1dKernel(
}
}

#if defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API)
#if defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API) || defined(ENABLE_QY_API)
// Iluvatar __half doesn't accept size_t directly.
y[y_offset] = sum / static_cast<T>(static_cast<double>(kernel_size));
#else
Expand Down
3 changes: 0 additions & 3 deletions src/infiniop/ops/fmod/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,6 @@ __INFINI_C infiniStatus_t infiniopDestroyFmodDescriptor(infiniopFmodDescriptor_t
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif
Expand Down
19 changes: 17 additions & 2 deletions xmake.lua
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,6 @@ if has_config("aten") then
end
end


-- cuda graph
option("graph")
set_default(false)
Expand All @@ -259,7 +258,6 @@ if has_config("graph") then
add_defines("USE_INFINIRT_GRAPH")
end


-- InfiniCCL
option("ccl")
set_default(false)
Expand Down Expand Up @@ -467,6 +465,23 @@ target("infinicore_cpp_api")
if has_config("nv-gpu") then
add_deps("flash-attn-nvidia")
end
if has_config("qy-gpu") then
add_deps("flash-attn-qy")
add_files("build/.objs/flash-attn-qy/rules/qy.cuda/__/__/flash-attention-dl-v2.7.4.post1-19/csrc/flash_attn/src/*.cu.o", {public = true})
end
end

if get_config("flash-attn") and get_config("flash-attn") ~= "" and has_config("qy-gpu") then
local flash_so_qy = _qy_flash_attn_cuda_so_path()
local flash_dir_qy = path.directory(flash_so_qy)
local flash_name_qy = path.filename(flash_so_qy)
before_link(function (target)
target:add(
"shflags",
"-Wl,--no-as-needed -L" .. flash_dir_qy .. " -l:" .. flash_name_qy .. " -Wl,-rpath," .. flash_dir_qy,
{force = true}
)
end)
end

before_build(function (target)
Expand Down
78 changes: 76 additions & 2 deletions xmake/qy.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,45 @@ if CUDNN_ROOT ~= nil then
add_includedirs(CUDNN_ROOT .. "/include")
end

local CUTLASS_ROOT = os.getenv("CUTLASS_ROOT") or os.getenv("CUTLASS_HOME") or os.getenv("CUTLASS_PATH")

if CUTLASS_ROOT ~= nil then
add_includedirs(CUTLASS_ROOT)
end

local FLASH_ATTN_ROOT = get_config("flash-attn")

local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini")

local FLASH_ATTN_QY_CUDA_SO_CONTAINER_DEFAULT =
"/home/shangyouren/miniconda3/envs/xiaobase/lib/python3.12/site-packages/flash_attn_2_cuda.cpython-312-x86_64-linux-gnu.so"

--- Resolve the flash-attn CUDA shared library (.so) to link for QY builds.
-- Resolution order:
--   1. FLASH_ATTN_2_CUDA_SO            — exact .so file override (must exist).
--   2. FLASH_ATTN_QY_CUDA_SO_CONTAINER — overrides the expected container path.
--   3. FLASH_ATTN_QY_CUDA_SO_CONTAINER_DEFAULT — baked-in container default.
-- Raises a build error when the resolved path is not an existing file.
function _qy_flash_attn_cuda_so_path()
    -- 1) An explicit file override wins unconditionally.
    local override = os.getenv("FLASH_ATTN_2_CUDA_SO")
    if override and override ~= "" then
        local so_file = override:trim()
        if os.isfile(so_file) then
            return so_file
        end
        raise("qy+flash-attn: FLASH_ATTN_2_CUDA_SO is not a file: %s", so_file)
    end

    -- 2)/3) Otherwise fall back to the container path (env-overridable).
    local candidate = os.getenv("FLASH_ATTN_QY_CUDA_SO_CONTAINER")
    if candidate == nil or candidate == "" then
        candidate = FLASH_ATTN_QY_CUDA_SO_CONTAINER_DEFAULT
    end

    if os.isfile(candidate) then
        return candidate
    end
    raise(
        "qy+flash-attn: expected %s\n Install flash-attn in the conda env, or export FLASH_ATTN_2_CUDA_SO to your .so path.",
        candidate
    )
end

add_includedirs("/usr/local/denglin/sdk/include", "../include")
add_linkdirs("/usr/local/denglin/sdk/lib")
add_links("curt", "cublas", "cudnn")
Expand Down Expand Up @@ -44,8 +83,20 @@ rule("qy.cuda")
local sdk_path = "/usr/local/denglin/sdk"
local arch = "dlgput64"

local relpath = path.relative(sourcefile, project.directory())
local objfile = path.join(config.buildir(), ".objs", target:name(), "rules", "qy.cuda", relpath .. ".o")

local relpath = path.relative(sourcefile, os.projectdir())

-- Strip ".." components (mapped to "__") so the object path stays inside the build tree
relpath = relpath:gsub("%.%.", "__")

local objfile = path.join(
config.buildir(),
".objs",
target:name(),
"rules",
"qy.cuda",
relpath .. ".o"
)

-- 🟢 Force-register the .o file with the target (custom rule outputs are not picked up automatically)
target:add("objectfiles", objfile)
Expand Down Expand Up @@ -153,3 +204,26 @@ target("infiniccl-qy")
set_languages("cxx17")

target_end()

-- Phony helper target: discovers the container's Torch/Python include and
-- library directories at build time so downstream QY targets can compile and
-- link against the prebuilt flash-attn extension. Builds nothing itself.
target("flash-attn-qy")
set_kind("phony")
set_default(false)


-- Only wire up Torch/Python paths when a flash-attn root was configured.
if FLASH_ATTN_ROOT and FLASH_ATTN_ROOT ~= "" then
before_build(function (target)
target:add("includedirs", "/usr/local/denglin/sdk/include", {public = true})
-- Query the active python for Torch and CPython locations at build time.
-- NOTE(review): assumes `python` on PATH is the env with torch installed — confirm.
local TORCH_DIR = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim()
local PYTHON_INCLUDE = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_paths()['include'])"}):trim()
local PYTHON_LIB_DIR = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_config_var('LIBDIR'))"}):trim()

-- Validate build/runtime env in container and keep these paths available for downstream linking.
target:add("includedirs", TORCH_DIR .. "/include", TORCH_DIR .. "/include/torch/csrc/api/include", PYTHON_INCLUDE, {public = false})
target:add("linkdirs", TORCH_DIR .. "/lib", PYTHON_LIB_DIR, {public = false})
end)
else
-- No flash-attn configured: keep the target a no-op so builds still succeed.
before_build(function (target)
print("Flash Attention not available, skipping flash-attn-qy integration")
end)
end
target_end()
Loading