Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions cmake/cinn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,32 @@ if(WITH_ROCM)
file(COPY paddle/cinn/common/float16.h DESTINATION $ENV{runtime_include_dir})
endif()

if(WITH_CUSTOM_DEVICE)
message(STATUS "CINN Compile with custom device support")
# 除非你确定需要 SYCL/DPCPP,否则删掉下面两行,避免找不到包
# set(DPCPP_DIR ${PROJECT_SOURCE_DIR}/cmake/cinn)
# find_package(DPCPP REQUIRED CONFIG)

add_definitions(-DCINN_WITH_CUSTOM_DEVICE)

# 2. 设置运行时头文件路径 (参考 CUDA/ROCm)
if(NOT DEFINED ENV{runtime_include_dir})
# 为自定义设备创建一个专门的 runtime 路径
set(ENV{runtime_include_dir}
"${CMAKE_SOURCE_DIR}/paddle/cinn/runtime/custom_device")
add_definitions(
-DRUNTIME_INCLUDE_DIR="${CMAKE_SOURCE_DIR}/paddle/cinn/runtime/custom_device"
)
endif()

# 3. 拷贝必要的头文件,否则 JIT 编译算子时会找不到 float16 等定义
message(STATUS "copy float16 headers for custom device")
file(MAKE_DIRECTORY $ENV{runtime_include_dir})
file(COPY paddle/cinn/common/float16.h paddle/cinn/common/bfloat16.h
paddle/cinn/common/float8e4m3.h
DESTINATION $ENV{runtime_include_dir})
endif()

set(cinnapi_src CACHE INTERNAL "" FORCE)
set(core_src CACHE INTERNAL "" FORCE)
set(core_includes CACHE INTERNAL "" FORCE)
Expand Down
4 changes: 4 additions & 0 deletions paddle/cinn/backends/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ if(WITH_SYCL)
add_subdirectory(sycl)
endif()

if(WITH_CUSTOM_DEVICE)
add_subdirectory(custom_device)
endif()

if(WITH_OPENMP)
cinn_cc_library(__x86_source_fake_lib SRCS _x86_builtin_source.cc)
endif()
Expand Down
9 changes: 9 additions & 0 deletions paddle/cinn/backends/codegen_cuda_host.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,15 @@ class CodeGenGpuHost : public CodeGenHost {
} else {
return CodeGenHost::Visit(op);
}
},
[&](common::CustomDeviceArch) {
if (op->name == runtime::intrinsic::call_custom_device_kernel ||
op->name ==
runtime::intrinsic::call_custom_device_cooperative_kernel) {
return LowerGPUKernelCall(op);
} else {
return CodeGenHost::Visit(op);
}
});
}

Expand Down
11 changes: 11 additions & 0 deletions paddle/cinn/backends/codegen_device_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,11 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
[&](common::HygonDCUArchSYCL) {
#ifdef CINN_WITH_SYCL
shared_mem_bytes = Expr(0);
#endif
},
[&](common::CustomDeviceArch) {
#ifdef CINN_WITH_CUSTOM_DEVICE
shared_mem_bytes = CalculateSharedMemory(func);
#endif
});

Expand All @@ -283,6 +288,12 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
},
[&](common::HygonDCUArchSYCL) {
call_kernel = runtime::intrinsic::call_sycl_kernel;
},
[&](common::CustomDeviceArch) {
call_kernel =
RequiresCooperativeLaunch(func)
? runtime::intrinsic::call_custom_device_cooperative_kernel
: runtime::intrinsic::call_custom_device_kernel;
});
// TODO(Dmovic): use new ir when backend update done.
// Author(liujinnan): Copy args instead of use func args directly in host
Expand Down
17 changes: 17 additions & 0 deletions paddle/cinn/backends/codegen_device_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
#ifdef CINN_WITH_SYCL
#include "paddle/cinn/backends/sycl/codegen_sycl_dev.h"
#endif
#ifdef CINN_WITH_CUSTOM_DEVICE
#include "paddle/cinn/backends/custom_device/codegen_custom_device_dev.h"
#endif
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_mutator.h"
Expand Down Expand Up @@ -127,6 +130,17 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
[&](std::variant<common::UnknownArch,
common::X86Arch,
common::ARMArch>) { CINN_NOT_IMPLEMENTED; },
[&](common::CustomDeviceArch) {
#ifdef CINN_WITH_CUSTOM_DEVICE
// 1. 创建 CodeGen 对象,传入默认的 CustomDevice Target
custom_device::CodeGenCustomDevice codegen_dev(
cinn::common::DefaultCustomDeviceTarget());
// 2. 模拟编译过程,这一步会遍历 AST 并计算动态共享内存的大小
codegen_dev.Compile(ir::LoweredFunc(func));
// 3. 获取计算结果
shared_mem_bytes = codegen_dev.GetDynSharedMemOffset();
#endif
},
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
CodeGenCudaDev codegen_dev(cinn::common::DefaultNVGPUTarget());
Expand Down Expand Up @@ -165,6 +179,9 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
[&](std::variant<common::UnknownArch,
common::X86Arch,
common::ARMArch>) { CINN_NOT_IMPLEMENTED; },
[&](common::CustomDeviceArch) {
call_kernel = runtime::intrinsic::call_custom_device_kernel;
},
[&](common::NVGPUArch) {
call_kernel = runtime::intrinsic::call_cuda_kernel;
},
Expand Down
125 changes: 124 additions & 1 deletion paddle/cinn/backends/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@
#include "paddle/cinn/runtime/cuda/cuda_util.h"
#include "paddle/cinn/runtime/flags.h"
#endif
#ifdef CINN_WITH_CUSTOM_DEVICE
#include "paddle/cinn/backends/custom_device/codegen_custom_device_dev.h"
#include "paddle/cinn/runtime/custom_device/custom_device_backend_api.h"
#include "paddle/phi/backends/device_manager.h"
#endif
#ifdef CINN_WITH_HIP
#include "paddle/cinn/backends/hip/codegen_hip_dev.h"
#include "paddle/cinn/backends/hip/compiler_hip.h"
Expand Down Expand Up @@ -253,7 +258,10 @@ void Compiler::Build(const Module& module, const std::string& code) {
[&](common::ARMArch) { CINN_NOT_IMPLEMENTED; },
[&](common::NVGPUArch) { CompileCudaModule(module, code); },
[&](common::HygonDCUArchHIP) { CompileHipModule(module, code); },
[&](common::HygonDCUArchSYCL) { CompileSyclModule(module, code); });
[&](common::HygonDCUArchSYCL) { CompileSyclModule(module, code); },
[&](common::CustomDeviceArch) {
CompileCustomDeviceModule(module, code);
});
}

void Compiler::AppendCX86(const Module& module) {
Expand Down Expand Up @@ -344,6 +352,19 @@ std::string Compiler::GetSourceCode(const ir::Module& module) {
[&](common::UnknownArch) -> std::string { CINN_NOT_IMPLEMENTED; },
[&](common::X86Arch) -> std::string { CINN_NOT_IMPLEMENTED; },
[&](common::ARMArch) -> std::string { CINN_NOT_IMPLEMENTED; },
[&](common::CustomDeviceArch) -> std::string {
#ifdef CINN_WITH_CUSTOM_DEVICE
auto _host_module_device_module_ =
SplitDeviceAndHostModule(module); // NOLINT
auto& host_module = std::get<0>(_host_module_device_module_);
auto& device_module = std::get<1>(_host_module_device_module_);
custom_device::CodeGenCustomDevice codegen(target_);
auto source_code = codegen.Compile(device_module);
return source_code;
#else
CINN_NOT_IMPLEMENTED
#endif
},
[&](common::NVGPUArch) -> std::string {
#ifdef CINN_WITH_CUDA
auto _host_module_device_module_ =
Expand Down Expand Up @@ -390,6 +411,7 @@ void Compiler::BuildDefault(const Module& module) {
[&](common::UnknownArch) { CINN_NOT_IMPLEMENTED; },
[&](common::X86Arch) { CompileX86Module(module); },
[&](common::ARMArch) { CINN_NOT_IMPLEMENTED; },
[&](common::CustomDeviceArch) { CompileCustomDeviceModule(module); },
[&](common::NVGPUArch) { CompileCudaModule(module); },
[&](common::HygonDCUArchHIP) { CompileHipModule(module); },
[&](common::HygonDCUArchSYCL) { CompileSyclModule(module); });
Expand Down Expand Up @@ -418,6 +440,7 @@ void Compiler::RegisterDeviceModuleSymbol() {
[&](common::UnknownArch) { CINN_NOT_IMPLEMENTED; },
[&](common::X86Arch) { return; },
[&](common::ARMArch) { return; },
[&](common::CustomDeviceArch) { RegisterCustomDeviceModuleSymbol(); },
[&](common::NVGPUArch) { RegisterCudaModuleSymbol(); },
[&](common::HygonDCUArchHIP) { RegisterHipModuleSymbol(); },
[&](common::HygonDCUArchSYCL) { RegisterSyclModuleSymbol(); });
Expand Down Expand Up @@ -526,6 +549,66 @@ void Compiler::RegisterCudaModuleSymbol() {
#endif
}

void Compiler::RegisterCustomDeviceModuleSymbol() {
#ifdef CINN_WITH_CUSTOM_DEVICE
// 1. 获取插件实例
auto dev_types = phi::DeviceManager::GetAllCustomDeviceTypes();
PADDLE_ENFORCE_EQ(!dev_types.empty(),
true,
::common::errors::NotFound(
"No custom device registered in DeviceManager."));

std::string dev_type = dev_types[0];
auto place = phi::CustomPlace(dev_type, 0);
auto& plugin =
cinn::runtime::custom_device::CinnCustomDevicePlugin::GetInstance(place);

// 2. 准备源码
// 此时 device_fn_code_ 已经包含了通过 Codegen 拼接的 Runtime Source
std::string source_code = device_fn_code_;

// 3. 调用插件工具链进行编译 (Compile)
// 返回值通常是编译产物的路径 (如 .so 或 .o 文件路径)
std::string lib_path = plugin.GetToolchain()->Compile(source_code);

PADDLE_ENFORCE_EQ(
!lib_path.empty(),
true,
::common::errors::External("Custom Device Toolchain compile failed."));

// 4. 调用插件运行时加载模块 (LoadModule)
// device_module_ 是 Compiler 类的成员变量: std::unique_ptr<CustomModule>
// device_module_;
this->device_module_ = plugin.GetRuntime()->LoadModule(lib_path);
PADDLE_ENFORCE_NOT_NULL(
this->device_module_,
::common::errors::External(
"Custom Device Runtime failed to load module from %s",
lib_path.c_str()));

// 5. 注册 Kernel 符号
// 我们需要获取设备 Kernel 的指针 (或者 Handle),并将其注册为
// [kernel_name]_ptr_
RuntimeSymbols symbols;
for (const auto& kernel_fn_name : device_fn_name_) {
void* fn_kernel = this->device_module_->GetFunction(kernel_fn_name);

PADDLE_ENFORCE_NOT_NULL(fn_kernel,
::common::errors::NotFound(
"Custom Device Runtime cannot find kernel: %s",
kernel_fn_name.c_str()));

// 保存指针供 ExecutionEngine 使用
fn_ptr_.push_back(fn_kernel);
symbols.RegisterVar(kernel_fn_name + "_ptr_", fn_kernel);
}

engine_->RegisterModuleRuntimeSymbols(std::move(symbols));
#else
CINN_NOT_IMPLEMENTED
#endif
}

void Compiler::RegisterHipModuleSymbol() {
#ifdef CINN_WITH_HIP
hiprtc::Compiler compiler;
Expand Down Expand Up @@ -632,6 +715,46 @@ void Compiler::CompileCudaModule(const Module& module,
#endif
}

void Compiler::CompileCustomDeviceModule(const Module& module,
const std::string& code) {
#ifdef CINN_WITH_CUSTOM_DEVICE
auto _host_module_device_module_ =
SplitDeviceAndHostModule(module); // NOLINT
auto& host_module = std::get<0>(_host_module_device_module_);
auto& device_module = std::get<1>(_host_module_device_module_);
VLOG(3) << "[CustomDevice] host module:\n" << host_module;

VLOG(3) << "[CustomDevice] device module:\n" << device_module;
std::string source_code;

if (!FLAGS_cinn_debug_custom_code_path.empty()) {
std::string file_path = FLAGS_cinn_debug_custom_code_path;
source_code = GetFileContent(file_path);
} else if (code.empty()) {
custom_device::CodeGenCustomDevice codegen(target_);
source_code = codegen.Compile(device_module);
} else {
source_code = code;
}

PADDLE_ENFORCE_EQ(!source_code.empty(),
true,
::common::errors::InvalidArgument(
"Compile CustomDevice code failed from device module"));
VLOG(3) << "[CustomDevice] Source:\n" << source_code;
SourceCodePrint::GetInstance()->write(source_code);
device_fn_code_ += source_code;

for (auto& fn : device_module.functions()) {
std::string kernel_fn_name = fn->name;
device_fn_name_.emplace_back(kernel_fn_name);
}
engine_->Link<CodeGenGpuHost>(host_module);
#else
CINN_NOT_IMPLEMENTED
#endif
}

void Compiler::CompileHipModule(const Module& module, const std::string& code) {
#ifdef CINN_WITH_HIP
auto _host_module_device_module_ =
Expand Down
13 changes: 13 additions & 0 deletions paddle/cinn/backends/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
#ifdef CINN_WITH_CUDA
#include "paddle/cinn/runtime/cuda/cuda_module.h"
#endif
#ifdef CINN_WITH_CUSTOM_DEVICE
#include "paddle/cinn/runtime/custom_device/custom_device_backend_api.h"
#endif
#ifdef CINN_WITH_HIP
#include "paddle/cinn/runtime/hip/hip_module.h"
#endif
Expand Down Expand Up @@ -174,13 +177,18 @@ class Compiler final {

void RegisterCudaModuleSymbol();

void RegisterCustomDeviceModuleSymbol();

void RegisterHipModuleSymbol();

void RegisterSyclModuleSymbol();

void CompileCudaModule(const ir::Module& module,
const std::string& code = "");

void CompileCustomDeviceModule(const ir::Module& module,
const std::string& code = "");

void CompileHipModule(const ir::Module& module, const std::string& code = "");

void CompileSyclModule(const ir::Module& module,
Expand Down Expand Up @@ -211,6 +219,11 @@ class Compiler final {
std::unique_ptr<runtime::cuda::CUDAModule> cuda_module_;
void* cuda_module_handle_{nullptr};
#endif

#ifdef CINN_WITH_CUSTOM_DEVICE
std::unique_ptr<runtime::CustomModule> device_module_;
#endif

#ifdef CINN_WITH_HIP
std::unique_ptr<runtime::hip::HIPModule> hip_module_;
#endif
Expand Down
4 changes: 4 additions & 0 deletions paddle/cinn/backends/custom_device/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
core_gather_headers()

gather_srcs(cinnapi_src SRCS codegen_custom_device_dev.cc
compiler_custom_device.cc)
Loading
Loading