diff --git a/cmake/cinn.cmake b/cmake/cinn.cmake
index 6cf2ffe32881a1..02daac2ec7c027 100644
--- a/cmake/cinn.cmake
+++ b/cmake/cinn.cmake
@@ -133,6 +133,32 @@ if(WITH_ROCM)
   file(COPY paddle/cinn/common/float16.h DESTINATION $ENV{runtime_include_dir})
 endif()
 
+if(WITH_CUSTOM_DEVICE)
+  message(STATUS "CINN Compile with custom device support")
+  # Remove the two lines below unless you are sure you need SYCL/DPCPP;
+  # otherwise the package lookup may fail
+  # set(DPCPP_DIR ${PROJECT_SOURCE_DIR}/cmake/cinn)
+  # find_package(DPCPP REQUIRED CONFIG)
+
+  add_definitions(-DCINN_WITH_CUSTOM_DEVICE)
+
+  # 2. Set the runtime header path (mirrors the CUDA/ROCm setup)
+  if(NOT DEFINED ENV{runtime_include_dir})
+    # Create a dedicated runtime path for the custom device
+    set(ENV{runtime_include_dir}
+        "${CMAKE_SOURCE_DIR}/paddle/cinn/runtime/custom_device")
+    add_definitions(
+      -DRUNTIME_INCLUDE_DIR="${CMAKE_SOURCE_DIR}/paddle/cinn/runtime/custom_device"
+    )
+  endif()
+
+  # 3. Copy the required headers; otherwise JIT-compiled operators cannot
+  # find definitions such as float16
+  message(STATUS "copy float16 headers for custom device")
+  file(MAKE_DIRECTORY $ENV{runtime_include_dir})
+  file(COPY paddle/cinn/common/float16.h paddle/cinn/common/bfloat16.h
+       paddle/cinn/common/float8e4m3.h
+       DESTINATION $ENV{runtime_include_dir})
+endif()
+
 set(cinnapi_src CACHE INTERNAL "" FORCE)
 set(core_src CACHE INTERNAL "" FORCE)
 set(core_includes CACHE INTERNAL "" FORCE)
diff --git a/paddle/cinn/backends/CMakeLists.txt b/paddle/cinn/backends/CMakeLists.txt
index 869c1f0bb0d694..393fb3b528d085 100755
--- a/paddle/cinn/backends/CMakeLists.txt
+++ b/paddle/cinn/backends/CMakeLists.txt
@@ -30,6 +30,10 @@ if(WITH_SYCL)
   add_subdirectory(sycl)
 endif()
 
+if(WITH_CUSTOM_DEVICE)
+  add_subdirectory(custom_device)
+endif()
+
 if(WITH_OPENMP)
   cinn_cc_library(__x86_source_fake_lib SRCS _x86_builtin_source.cc)
 endif()
diff --git a/paddle/cinn/backends/codegen_cuda_host.h b/paddle/cinn/backends/codegen_cuda_host.h
index 33214a533de3e2..4a59126803d0b1 100644
--- a/paddle/cinn/backends/codegen_cuda_host.h
+++ b/paddle/cinn/backends/codegen_cuda_host.h
@@ -63,6 +63,15 @@ class CodeGenGpuHost : public CodeGenHost {
           } else {
             return CodeGenHost::Visit(op);
           }
+        },
+        [&](common::CustomDeviceArch) {
+          if (op->name == runtime::intrinsic::call_custom_device_kernel ||
+              op->name ==
+                  runtime::intrinsic::call_custom_device_cooperative_kernel) {
+            return LowerGPUKernelCall(op);
+          } else {
+            return CodeGenHost::Visit(op);
+          }
         });
   }
diff --git a/paddle/cinn/backends/codegen_device_util.cc b/paddle/cinn/backends/codegen_device_util.cc
index 2116b575b5796e..325ae149437f56 100644
--- a/paddle/cinn/backends/codegen_device_util.cc
+++ b/paddle/cinn/backends/codegen_device_util.cc
@@ -257,6 +257,11 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
       [&](common::HygonDCUArchSYCL) {
 #ifdef CINN_WITH_SYCL
         shared_mem_bytes = Expr(0);
+#endif
+      },
+      [&](common::CustomDeviceArch) {
+#ifdef CINN_WITH_CUSTOM_DEVICE
+        shared_mem_bytes = CalculateSharedMemory(func);
 #endif
       });
 
@@ -283,6 +288,12 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
       },
       [&](common::HygonDCUArchSYCL) {
         call_kernel = runtime::intrinsic::call_sycl_kernel;
+      },
+      [&](common::CustomDeviceArch) {
+        call_kernel =
+            RequiresCooperativeLaunch(func)
+                ? runtime::intrinsic::call_custom_device_cooperative_kernel
+                : runtime::intrinsic::call_custom_device_kernel;
       });
   // TODO(Dmovic): use new ir when backend update done.
  // Author(liujinnan): Copy args instead of use func args directly in host
diff --git a/paddle/cinn/backends/codegen_device_util.h b/paddle/cinn/backends/codegen_device_util.h
index b5931116aeefe9..7e68f3b9255832 100644
--- a/paddle/cinn/backends/codegen_device_util.h
+++ b/paddle/cinn/backends/codegen_device_util.h
@@ -26,6 +26,9 @@
 #ifdef CINN_WITH_SYCL
 #include "paddle/cinn/backends/sycl/codegen_sycl_dev.h"
 #endif
+#ifdef CINN_WITH_CUSTOM_DEVICE
+#include "paddle/cinn/backends/custom_device/codegen_custom_device_dev.h"
+#endif
 #include "paddle/cinn/cinn.h"
 #include "paddle/cinn/ir/ir.h"
 #include "paddle/cinn/ir/ir_mutator.h"
@@ -127,6 +130,17 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
         [&](std::variant) {
           CINN_NOT_IMPLEMENTED;
         },
+        [&](common::CustomDeviceArch) {
+#ifdef CINN_WITH_CUSTOM_DEVICE
+          // 1. Create the CodeGen object with the default CustomDevice Target
+          custom_device::CodeGenCustomDevice codegen_dev(
+              cinn::common::DefaultCustomDeviceTarget());
+          // 2. Run a dry compile; this walks the AST and computes the size of
+          // the dynamic shared memory
+          codegen_dev.Compile(ir::LoweredFunc(func));
+          // 3. Read back the computed result
+          shared_mem_bytes = codegen_dev.GetDynSharedMemOffset();
+#endif
+        },
         [&](common::NVGPUArch) {
 #ifdef CINN_WITH_CUDA
           CodeGenCudaDev codegen_dev(cinn::common::DefaultNVGPUTarget());
@@ -165,6 +179,9 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
         [&](std::variant) {
           CINN_NOT_IMPLEMENTED;
         },
+        [&](common::CustomDeviceArch) {
+          call_kernel = runtime::intrinsic::call_custom_device_kernel;
+        },
         [&](common::NVGPUArch) {
           call_kernel = runtime::intrinsic::call_cuda_kernel;
         },
diff --git a/paddle/cinn/backends/compiler.cc b/paddle/cinn/backends/compiler.cc
index b844573eca26cb..6391bd862508d7 100644
--- a/paddle/cinn/backends/compiler.cc
+++ b/paddle/cinn/backends/compiler.cc
@@ -37,6 +37,11 @@
 #include "paddle/cinn/runtime/cuda/cuda_util.h"
 #include "paddle/cinn/runtime/flags.h"
 #endif
+#ifdef CINN_WITH_CUSTOM_DEVICE
+#include "paddle/cinn/backends/custom_device/codegen_custom_device_dev.h"
+#include "paddle/cinn/runtime/custom_device/custom_device_backend_api.h"
+#include "paddle/phi/backends/device_manager.h"
+#endif
 #ifdef CINN_WITH_HIP
 #include "paddle/cinn/backends/hip/codegen_hip_dev.h"
 #include "paddle/cinn/backends/hip/compiler_hip.h"
@@ -253,7 +258,10 @@ void Compiler::Build(const Module& module, const std::string& code) {
       [&](common::ARMArch) { CINN_NOT_IMPLEMENTED; },
       [&](common::NVGPUArch) { CompileCudaModule(module, code); },
       [&](common::HygonDCUArchHIP) { CompileHipModule(module, code); },
-      [&](common::HygonDCUArchSYCL) { CompileSyclModule(module, code); });
+      [&](common::HygonDCUArchSYCL) { CompileSyclModule(module, code); },
+      [&](common::CustomDeviceArch) {
+        CompileCustomDeviceModule(module, code);
+      });
 }
 
 void Compiler::AppendCX86(const Module& module) {
@@ -344,6 +352,19 @@ std::string Compiler::GetSourceCode(const ir::Module& module) {
       [&](common::UnknownArch) -> std::string { CINN_NOT_IMPLEMENTED; },
       [&](common::X86Arch) -> std::string { CINN_NOT_IMPLEMENTED; },
       [&](common::ARMArch) -> std::string { CINN_NOT_IMPLEMENTED; },
+      [&](common::CustomDeviceArch) -> std::string {
+#ifdef CINN_WITH_CUSTOM_DEVICE
+        auto _host_module_device_module_ =
+            SplitDeviceAndHostModule(module);  // NOLINT
+        auto& host_module = std::get<0>(_host_module_device_module_);
+        auto& device_module = std::get<1>(_host_module_device_module_);
+        custom_device::CodeGenCustomDevice codegen(target_);
+        auto source_code = codegen.Compile(device_module);
+        return source_code;
+#else
+        CINN_NOT_IMPLEMENTED
+#endif
+      },
      [&](common::NVGPUArch) -> std::string {
 #ifdef CINN_WITH_CUDA
         auto _host_module_device_module_ =
@@ -390,6 +411,7 @@ void Compiler::BuildDefault(const Module& module) {
       [&](common::UnknownArch) { CINN_NOT_IMPLEMENTED; },
       [&](common::X86Arch) { CompileX86Module(module); },
       [&](common::ARMArch) { CINN_NOT_IMPLEMENTED; },
+      [&](common::CustomDeviceArch) { CompileCustomDeviceModule(module); },
       [&](common::NVGPUArch) { CompileCudaModule(module); },
       [&](common::HygonDCUArchHIP) { CompileHipModule(module); },
       [&](common::HygonDCUArchSYCL) { CompileSyclModule(module); });
@@ -418,6 +440,7 @@ void Compiler::RegisterDeviceModuleSymbol() {
       [&](common::UnknownArch) { CINN_NOT_IMPLEMENTED; },
       [&](common::X86Arch) { return; },
       [&](common::ARMArch) { return; },
+      [&](common::CustomDeviceArch) { RegisterCustomDeviceModuleSymbol(); },
       [&](common::NVGPUArch) { RegisterCudaModuleSymbol(); },
       [&](common::HygonDCUArchHIP) { RegisterHipModuleSymbol(); },
       [&](common::HygonDCUArchSYCL) { RegisterSyclModuleSymbol(); });
@@ -526,6 +549,66 @@ void Compiler::RegisterCudaModuleSymbol() {
 #endif
 }
 
+void Compiler::RegisterCustomDeviceModuleSymbol() {
+#ifdef CINN_WITH_CUSTOM_DEVICE
+  // 1. Get the plugin instance
+  auto dev_types = phi::DeviceManager::GetAllCustomDeviceTypes();
+  PADDLE_ENFORCE_EQ(!dev_types.empty(),
+                    true,
+                    ::common::errors::NotFound(
+                        "No custom device registered in DeviceManager."));
+
+  std::string dev_type = dev_types[0];
+  auto place = phi::CustomPlace(dev_type, 0);
+  auto& plugin =
+      cinn::runtime::custom_device::CinnCustomDevicePlugin::GetInstance(place);
+
+  // 2. Prepare the source code.
+  // At this point device_fn_code_ already contains the runtime source
+  // assembled by codegen.
+  std::string source_code = device_fn_code_;
+
+  // 3. Compile with the plugin toolchain (Compile).
+  // The return value is typically the path of the build artifact
+  // (e.g. a .so or .o file path).
+  std::string lib_path = plugin.GetToolchain()->Compile(source_code);
+
+  PADDLE_ENFORCE_EQ(
+      !lib_path.empty(),
+      true,
+      ::common::errors::External("Custom Device Toolchain compile failed."));
+
+  // 4. Load the module through the plugin runtime (LoadModule).
+  // device_module_ is a member of the Compiler class: std::unique_ptr
+  // device_module_;
+  this->device_module_ = plugin.GetRuntime()->LoadModule(lib_path);
+  PADDLE_ENFORCE_NOT_NULL(
+      this->device_module_,
+      ::common::errors::External(
+          "Custom Device Runtime failed to load module from %s",
+          lib_path.c_str()));
+
+  // 5. Register the kernel symbols.
+  // We need the pointer (or handle) of each device kernel and register it
+  // as [kernel_name]_ptr_
+  RuntimeSymbols symbols;
+  for (const auto& kernel_fn_name : device_fn_name_) {
+    void* fn_kernel = this->device_module_->GetFunction(kernel_fn_name);
+
+    PADDLE_ENFORCE_NOT_NULL(fn_kernel,
+                            ::common::errors::NotFound(
+                                "Custom Device Runtime cannot find kernel: %s",
+                                kernel_fn_name.c_str()));
+
+    // Keep the pointer for use by the ExecutionEngine
+    fn_ptr_.push_back(fn_kernel);
+    symbols.RegisterVar(kernel_fn_name + "_ptr_", fn_kernel);
+  }
+
+  engine_->RegisterModuleRuntimeSymbols(std::move(symbols));
+#else
+  CINN_NOT_IMPLEMENTED
+#endif
+}
+
 void Compiler::RegisterHipModuleSymbol() {
 #ifdef CINN_WITH_HIP
   hiprtc::Compiler compiler;
@@ -632,6 +715,46 @@ void Compiler::CompileCudaModule(const Module& module,
 #endif
 }
 
+void Compiler::CompileCustomDeviceModule(const Module& module,
+                                         const std::string& code) {
+#ifdef CINN_WITH_CUSTOM_DEVICE
+  auto _host_module_device_module_ =
+      SplitDeviceAndHostModule(module);  // NOLINT
+  auto& host_module = std::get<0>(_host_module_device_module_);
+  auto& device_module = std::get<1>(_host_module_device_module_);
+  VLOG(3) << "[CustomDevice] host module:\n" << host_module;
+
+  VLOG(3) << "[CustomDevice] device module:\n" << device_module;
+  std::string source_code;
+
+  if (!FLAGS_cinn_debug_custom_code_path.empty()) {
+    std::string file_path = FLAGS_cinn_debug_custom_code_path;
+    source_code = GetFileContent(file_path);
+  } else if (code.empty()) {
+    custom_device::CodeGenCustomDevice codegen(target_);
+    source_code = codegen.Compile(device_module);
+  } else {
+    source_code = code;
+  }
+
+  PADDLE_ENFORCE_EQ(!source_code.empty(),
+                    true,
+                    ::common::errors::InvalidArgument(
+                        "Compile CustomDevice code failed from device module"));
+  VLOG(3) << "[CustomDevice] Source:\n" << source_code;
+  SourceCodePrint::GetInstance()->write(source_code);
+  device_fn_code_ += source_code;
+
+  for (auto& fn : device_module.functions()) {
+    std::string kernel_fn_name = fn->name;
+    device_fn_name_.emplace_back(kernel_fn_name);
+  }
+  engine_->Link(host_module);
+#else
+  CINN_NOT_IMPLEMENTED
+#endif
+}
+
 void Compiler::CompileHipModule(const Module& module, const std::string& code) {
 #ifdef CINN_WITH_HIP
   auto _host_module_device_module_ =
diff --git a/paddle/cinn/backends/compiler.h b/paddle/cinn/backends/compiler.h
index 38545bfeb248fe..37669c53713d4e 100644
--- a/paddle/cinn/backends/compiler.h
+++ b/paddle/cinn/backends/compiler.h
@@ -29,6 +29,9 @@
 #ifdef CINN_WITH_CUDA
 #include "paddle/cinn/runtime/cuda/cuda_module.h"
 #endif
+#ifdef CINN_WITH_CUSTOM_DEVICE
+#include "paddle/cinn/runtime/custom_device/custom_device_backend_api.h"
+#endif
 #ifdef CINN_WITH_HIP
 #include "paddle/cinn/runtime/hip/hip_module.h"
 #endif
@@ -174,6 +177,8 @@ class Compiler final {
 
   void RegisterCudaModuleSymbol();
 
+  void RegisterCustomDeviceModuleSymbol();
+
   void RegisterHipModuleSymbol();
 
   void RegisterSyclModuleSymbol();
@@ -181,6 +186,9 @@ class Compiler final {
   void CompileCudaModule(const ir::Module& module,
                          const std::string& code = "");
 
+  void CompileCustomDeviceModule(const ir::Module& module,
+                                 const std::string& code = "");
+
   void CompileHipModule(const ir::Module& module,
                         const std::string& code = "");
   void CompileSyclModule(const ir::Module& module,
@@ -211,6 +219,11 @@ class Compiler final {
   std::unique_ptr cuda_module_;
   void* cuda_module_handle_{nullptr};
 #endif
+
+#ifdef CINN_WITH_CUSTOM_DEVICE
+  std::unique_ptr device_module_;
+#endif
+
 #ifdef CINN_WITH_HIP
   std::unique_ptr hip_module_;
 #endif
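For orientation, here is a minimal sketch (not part of this diff) of the plugin surface that `CompileCustomDeviceModule` and `RegisterCustomDeviceModuleSymbol` rely on. The real `CinnCustomDevicePlugin` lives in `custom_device_backend_api.h`; the member names below simply mirror the call sites above (`GetToolchain`, `GetRuntime`, `Compile`, `GetRuntimeSource`, `LoadModule`, `GetFunction`), and everything else is illustrative.

```cpp
// Sketch only: an assumed shape of the vendor plugin interface, inferred
// from the call sites in compiler.cc. Not the actual declaration.
#include <memory>
#include <string>

namespace sketch {

struct DeviceModule {
  // Resolve a kernel symbol inside the loaded vendor module.
  virtual void* GetFunction(const std::string& kernel_name) = 0;
  virtual ~DeviceModule() = default;
};

struct Toolchain {
  // Source code in, path of the compiled artifact (.so/.o) out.
  virtual std::string Compile(const std::string& source_code) = 0;
  // Vendor-provided runtime source injected into generated kernels.
  virtual std::string GetRuntimeSource() = 0;
  virtual ~Toolchain() = default;
};

struct Runtime {
  virtual std::unique_ptr<DeviceModule> LoadModule(
      const std::string& lib_path) = 0;
  virtual ~Runtime() = default;
};

}  // namespace sketch
```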
diff --git a/paddle/cinn/backends/custom_device/CMakeLists.txt b/paddle/cinn/backends/custom_device/CMakeLists.txt
new file mode 100644
index 00000000000000..6dbb4416e4dadf
--- /dev/null
+++ b/paddle/cinn/backends/custom_device/CMakeLists.txt
@@ -0,0 +1,4 @@
+core_gather_headers()
+
+gather_srcs(cinnapi_src SRCS codegen_custom_device_dev.cc
+            compiler_custom_device.cc)
diff --git a/paddle/cinn/backends/custom_device/codegen_custom_device_dev.cc b/paddle/cinn/backends/custom_device/codegen_custom_device_dev.cc
new file mode 100644
index 00000000000000..13730333cef7d7
--- /dev/null
+++ b/paddle/cinn/backends/custom_device/codegen_custom_device_dev.cc
@@ -0,0 +1,99 @@
+// Copyright (c) 2024 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/cinn/backends/custom_device/codegen_custom_device_dev.h"
+#include
+#include
+#include "paddle/cinn/runtime/custom_device/custom_device_backend_api.h"
+#include "paddle/phi/backends/device_manager.h"
+#include "paddle/phi/common/place.h"
+
+namespace cinn {
+namespace backends {
+namespace custom_device {
+
+CodeGenCustomDevice::CodeGenCustomDevice(Target target)
+    : CodeGenGpuDev(target) {}
+
+void CodeGenCustomDevice::PrintIncludes() {
+  // 1. Base macro definitions
+  str_ += "#define CINN_WITH_CUSTOM_DEVICE\n";
+  str_ += "#include \"float16.h\"\n";
+  str_ += "using cinn::common::float16;\n";
+
+  // 2. Fetch the vendor's runtime source dynamically.
+  // Flow: find the custom device type on this system -> get the plugin ->
+  // get the source.
+  std::string dev_type = "";
+  auto devs = phi::DeviceManager::GetAllCustomDeviceTypes();
+  if (!devs.empty()) {
+    dev_type = devs[0];
+  } else {
+    LOG(WARNING)
+        << "No custom device found, skipping runtime source injection.";
+    return;
+  }
+
+  // Get the plugin instance
+  auto place = phi::CustomPlace(dev_type, 0);
+  try {
+    auto& plugin =
+        cinn::runtime::custom_device::CinnCustomDevicePlugin::GetInstance(
+            place);
+
+    // 3. Fetch the runtime source from the toolchain and append it to the
+    // generated kernel string
+    std::string runtime_src = plugin.GetToolchain()->GetRuntimeSource();
+    if (runtime_src.empty()) {
+      LOG(WARNING) << "Custom Device [" << dev_type
+                   << "] returned empty runtime source.";
+    }
+    str_ += "\n// ----- Custom Device Runtime Source (Begin) -----\n";
+    str_ += runtime_src;
+    str_ += "\n// ----- Custom Device Runtime Source (End) -----\n";
+  } catch (const std::exception& e) {
+    LOG(ERROR) << "Failed to get CinnCustomDevicePlugin: " << e.what();
+  }
+}
+
+const std::string& CodeGenCustomDevice::GetSourceHeader() {
+  static std::string empty_header = "";
+  return empty_header;
+}
+
+void CodeGenCustomDevice::Visit(const ir::Min* op) {
+  str_ += "std::min(";
+  ir::Expr a = op->a(), b = op->b();
+  auto [unify_bit, both_dyn] =
+      common::UnifiedOperandTypeBits(&this->DynamicShapeMap(), op);
+  this->ProcessMinMaxOperand(&a, &b, unify_bit, both_dyn);
+  IrPrinter::Visit(a);
+  str_ += ", ";
+  IrPrinter::Visit(b);
+  str_ += ")";
+}
+
+void CodeGenCustomDevice::Visit(const ir::Max* op) {
+  str_ += "std::max(";
+  ir::Expr a = op->a(), b = op->b();
+  auto [unify_bit, both_dyn] =
+      common::UnifiedOperandTypeBits(&this->DynamicShapeMap(), op);
+  this->ProcessMinMaxOperand(&a, &b, unify_bit, both_dyn);
+  IrPrinter::Visit(a);
+  str_ += ", ";
+  IrPrinter::Visit(b);
+  str_ += ")";
+}
+
+}  // namespace custom_device
+}  // namespace backends
+}  // namespace cinn
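To make the injection concrete: whatever string the vendor returns from `GetRuntimeSource()` is pasted verbatim between the Begin/End markers, ahead of every generated kernel. A vendor-side shim might look like the following sketch (entirely illustrative; no such implementation exists in this diff, and the qualifier and helper names are made up):

```cpp
// Illustrative only: a hypothetical vendor implementation of
// Toolchain::GetRuntimeSource(). The returned text lands right after the
// "#include \"float16.h\"" preamble in every generated kernel file, so it is
// the natural place for device-side qualifiers and math shims.
#include <string>

std::string GetRuntimeSource() {
  return R"(
// ----- vendor shim (illustrative) -----
#define __cd_kernel__ __attribute__((annotate("my_device_kernel")))
template <typename T>
inline T cinn_custom_device_max(T a, T b) { return a > b ? a : b; }
)";
}
```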
diff --git a/paddle/cinn/backends/custom_device/codegen_custom_device_dev.h b/paddle/cinn/backends/custom_device/codegen_custom_device_dev.h
new file mode 100644
index 00000000000000..d4262f41dcb34c
--- /dev/null
+++ b/paddle/cinn/backends/custom_device/codegen_custom_device_dev.h
@@ -0,0 +1,44 @@
+// Copyright (c) 2024 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/cinn/backends/codegen_gpu_dev.h"
+
+namespace cinn {
+namespace backends {
+namespace custom_device {
+
+/**
+ * CUSTOMDEVICE device code generator.
+ *
+ * It generates the device function, e.g., for a function called "myadd" it
+ * emits a __global__ function called "myadd_kernel". Unlike codegen_c, the
+ * declaration of the "myadd_kernel" function has an expanded argument list,
+ * which is finally similar to `__global__ void myadd(float* __restrict__ A,
+ * float* __restrict__ B, int n);`
+ */
+class CodeGenCustomDevice : public CodeGenGpuDev {
+ public:
+  explicit CodeGenCustomDevice(Target target);
+  static const std::string& GetSourceHeader();
+  void PrintIncludes() override;
+  void Visit(const ir::Min* op) override;
+  void Visit(const ir::Max* op) override;
+};
+
+}  // namespace custom_device
+}  // namespace backends
+}  // namespace cinn
diff --git a/paddle/cinn/backends/custom_device/compiler_custom_device.cc b/paddle/cinn/backends/custom_device/compiler_custom_device.cc
new file mode 100644
index 00000000000000..46dc627e900bef
--- /dev/null
+++ b/paddle/cinn/backends/custom_device/compiler_custom_device.cc
@@ -0,0 +1,57 @@
+// Copyright (c) 2024 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/cinn/backends/custom_device/compiler_custom_device.h"
+#include "paddle/cinn/runtime/custom_device/custom_device_backend_api.h"
+
+#if defined(__linux__)
+#include
+#endif
+#include
+#include
+#include
+
+#include "paddle/cinn/common/common.h"
+#include "paddle/cinn/runtime/custom_device/custom_device_util.h"
+#include "paddle/cinn/runtime/flags.h"
+#include "paddle/cinn/utils/string.h"
+#include "paddle/phi/backends/device_manager.h"
+#include "paddle/phi/common/place.h"
+
+namespace cinn {
+namespace backends {
+namespace cdrtc {
+
+Compiler::Compiler(const cinn::common::Target& target) : target_(target) {}
+
+std::string Compiler::operator()(const std::string& code,
+                                 bool include_headers) {
+  std::string dev_type = "";
+  auto devs = phi::DeviceManager::GetAllCustomDeviceTypes();
+  if (!devs.empty()) {
+    dev_type = devs[0];  // Default to the first registered custom device
+  }
+
+  auto place = phi::CustomPlace(dev_type, 0);
+  // 1. Get the plugin
+  auto& plugin =
+      cinn::runtime::custom_device::CinnCustomDevicePlugin::GetInstance(place);
+
+  // 2. Forward to the plugin's Toolchain.
+  // Whether include_headers is passed through to the plugin, or the code is
+  // assumed to already contain the headers, is a convention left to the
+  // plugin.
+  return plugin.GetToolchain()->Compile(code);
+}
+
+}  // namespace cdrtc
+}  // namespace backends
+}  // namespace cinn
diff --git a/paddle/cinn/backends/custom_device/compiler_custom_device.h b/paddle/cinn/backends/custom_device/compiler_custom_device.h
new file mode 100644
index 00000000000000..5b1cbaee869224
--- /dev/null
+++ b/paddle/cinn/backends/custom_device/compiler_custom_device.h
@@ -0,0 +1,48 @@
+// Copyright (c) 2024 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include
+#include
+#include "paddle/cinn/common/target.h"
+
+namespace cinn {
+namespace backends {
+namespace cdrtc {
+
+/**
+ * A helper class to call Csrtc or Cdcc. Takes CUSTOMDEVICE device source
+ * code and returns the compiled binary string.
+ */
+class Compiler {
+ public:
+  explicit Compiler(const common::Target& target);
+  /**
+   * Compile the \p code and get the compiled binary string.
+   * @param code The CUSTOMDEVICE source code.
+   * @param include_headers Whether to include the headers of CUSTOMDEVICE and
+   * CINN runtime modules.
+   * @return Compiled binary code string.
+   */
+  std::string operator()(const std::string& code, bool include_headers = true);
+
+ private:
+  // Only the target is kept; it determines which Place the plugin is looked
+  // up for.
+  common::Target target_;
+};
+
+}  // namespace cdrtc
+}  // namespace backends
+}  // namespace cinn
diff --git a/paddle/cinn/backends/extern_func_emitter.h b/paddle/cinn/backends/extern_func_emitter.h
index 289ac523043104..1a7f080b8104cb 100644
--- a/paddle/cinn/backends/extern_func_emitter.h
+++ b/paddle/cinn/backends/extern_func_emitter.h
@@ -50,6 +50,7 @@ static const char* backend_llvm_x86 = "llvm_x86";
 static const char* backend_nvgpu = "nvgpu";
 static const char* backend_hygondcu_hip = "hygonDCU_hip";
 static const char* backend_hygondcu_sycl = "hygonDCU_sycl";
+static const char* backend_custom_device = "custom_device";
 
 /**
  * \brief Base class of the emitter of all the extern functions able to trigger
diff --git a/paddle/cinn/backends/extern_func_jit_register.h b/paddle/cinn/backends/extern_func_jit_register.h
index d29c2e2234056c..6bffff8906dc7b 100644
--- a/paddle/cinn/backends/extern_func_jit_register.h
+++ b/paddle/cinn/backends/extern_func_jit_register.h
@@ -51,42 +51,42 @@
  */
 #define REGISTER_EXTERN_FUNC_1_IN_1_OUT(fn__, target__, in_type__, out_type__) \
   REGISTER_EXTERN_FUNC_HELPER(fn__, target__) \
-      .SetRetType<out_type__>() \
-      .AddInputType<in_type__>() \
+      .template SetRetType(cinn::common::type_of<out_type__>()) \
+      .template AddInputType(cinn::common::type_of<in_type__>()) \
       .End()
 
 /**
  * Register an external function with two inputs and one output.
 */
-#define REGISTER_EXTERN_FUNC_2_IN_1_OUT( \
-    fn__, target__, in_type1__, in_type2__, out_type__) \
-  REGISTER_EXTERN_FUNC_HELPER(fn__, target__) \
-      .SetRetType<out_type__>() \
-      .AddInputType<in_type1__>() \
-      .AddInputType<in_type2__>() \
+#define REGISTER_EXTERN_FUNC_2_IN_1_OUT( \
+    fn__, target__, in_type1__, in_type2__, out_type__) \
+  REGISTER_EXTERN_FUNC_HELPER(fn__, target__) \
+      .template SetRetType(cinn::common::type_of<out_type__>()) \
+      .template AddInputType(cinn::common::type_of<in_type1__>()) \
+      .template AddInputType(cinn::common::type_of<in_type2__>()) \
       .End()
 
 /**
  * Register a sourced function (no function address; called in generated
 * source code).
 */
-#define REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT( \
-    fn__, target__, in_type__, out_type__) \
-  REGISTER_FACKED_EXTERN_FUNC_HELPER(fn__, target__) \
-      .SetRetType<out_type__>() \
-      .AddInputType<in_type__>() \
+#define REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT( \
+    fn__, target__, in_type__, out_type__) \
+  REGISTER_FACKED_EXTERN_FUNC_HELPER(fn__, target__) \
+      .template SetRetType(cinn::common::type_of<out_type__>()) \
+      .template AddInputType(cinn::common::type_of<in_type__>()) \
       .End()
 
 /**
 * Register a sourced function (no function address; called in generated
 * source code).
 */
-#define REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT( \
-    fn__, target__, in_type1__, in_type2__, out_type__) \
-  REGISTER_FACKED_EXTERN_FUNC_HELPER(fn__, target__) \
-      .SetRetType<out_type__>() \
-      .AddInputType<in_type1__>() \
-      .AddInputType<in_type2__>() \
+#define REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT( \
+    fn__, target__, in_type1__, in_type2__, out_type__) \
+  REGISTER_FACKED_EXTERN_FUNC_HELPER(fn__, target__) \
+      .template SetRetType(cinn::common::type_of<out_type__>()) \
+      .template AddInputType(cinn::common::type_of<in_type1__>()) \
+      .template AddInputType(cinn::common::type_of<in_type2__>()) \
       .End()
 
 namespace cinn {
@@ -97,6 +97,9 @@ static const char* TargetToBackendRepr(Target target) {
       [&](common::UnknownArch) -> const char* { CINN_NOT_IMPLEMENTED; },
       [&](common::X86Arch) -> const char* { return backend_llvm_host; },
       [&](common::ARMArch) -> const char* { CINN_NOT_IMPLEMENTED; },
+      [&](common::CustomDeviceArch) -> const char* {
+        return backend_custom_device;
+      },
       [&](common::NVGPUArch) -> const char* { return backend_nvgpu; },
       [&](common::HygonDCUArchHIP) -> const char* {
         return backend_hygondcu_hip;
@@ -135,6 +138,11 @@ struct RegisterExternFunction {
     return *this;
   }
 
+  RegisterExternFunction& AddInputType(const cinn::common::Type& t) {
+    fn_proto_builder_.AddInputType(t);
+    return *this;
+  }
+
   /**
    * Add an output type.
    * @tparam T The output type.
@@ -157,6 +165,11 @@ struct RegisterExternFunction {
     return *this;
   }
 
+  RegisterExternFunction& SetRetType(const cinn::common::Type& t) {
+    fn_proto_builder_.SetRetType(t);
+    return *this;
+  }
+
   /**
    * Add a shape inference.
    * @param handle The handle to help infer the shape.
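The new non-template overloads matter when a backend only knows its element types at runtime, e.g. a custom-device plugin reporting which types it supports. A minimal sketch of how they compose, assuming the builder API added above; the function name, the helper, and the constructor arguments shown are illustrative:

```cpp
// Sketch only: registers a made-up extern function "my_custom_tanh" whose
// element type is decided at runtime via the new SetRetType/AddInputType
// overloads taking cinn::common::Type. The template macros above remain the
// path for types known at compile time.
#include "paddle/cinn/backends/extern_func_jit_register.h"

void RegisterCustomTanh(const cinn::common::Target& target,
                        const cinn::common::Type& elem_type) {
  cinn::backends::RegisterExternFunction("my_custom_tanh", target)
      .SetRetType(elem_type)    // runtime Type, no template argument needed
      .AddInputType(elem_type)
      .End();
}
```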
diff --git a/paddle/cinn/backends/function_prototype.h b/paddle/cinn/backends/function_prototype.h
index 0859a441ab5970..4c246291239bb9 100644
--- a/paddle/cinn/backends/function_prototype.h
+++ b/paddle/cinn/backends/function_prototype.h
@@ -84,11 +84,23 @@ struct FunctionProto {
     data_->ret_type = type_of<T>();
     return *this;
   }
+
+  Builder& SetRetType(const common::Type& t) {
+    data_->ret_type = t;
+    return *this;
+  }
+
   template <typename T>
   Builder& AddInputType() {
     data_->readonly_arg_types.push_back(type_of<T>());
     return *this;
   }
+
+  Builder& AddInputType(const common::Type& t) {
+    data_->readonly_arg_types.push_back(t);
+    return *this;
+  }
+
   template <typename T>
   Builder& AddOutputType() {
     data_->mutable_arg_types.push_back(type_of<T>());
diff --git a/paddle/cinn/backends/llvm/codegen_llvm.cc b/paddle/cinn/backends/llvm/codegen_llvm.cc
index 4929a9d4bf7072..9d568c810411b3 100644
--- a/paddle/cinn/backends/llvm/codegen_llvm.cc
+++ b/paddle/cinn/backends/llvm/codegen_llvm.cc
@@ -49,6 +49,7 @@
 #include "paddle/cinn/runtime/cinn_runtime.h"
 #include "paddle/cinn/runtime/intrinsic.h"
 #include "paddle/cinn/utils/string.h"
+#include "paddle/phi/backends/device_manager.h"
 
 namespace cinn {
 namespace backends {
@@ -1512,6 +1513,16 @@ int GetNaiveVecAlignmentImpl(common::HygonDCUArchSYCL, const Target &target) {
   return 128;
 }
 
+int GetNaiveVecAlignmentImpl(common::CustomDeviceArch arch,
+                             const Target &target) {
+#ifdef CINN_WITH_CUSTOM_DEVICE
+  auto place = phi::CustomPlace(arch.device_type, arch.device_id);
+  return phi::DeviceManager::GetPreferredVectorWidth(place);
+#else
+  return 128;
+#endif
+}
+
 int GetNaiveVecAlignment(const Target &target) {
   return std::visit(
       [&](const auto &impl) { return GetNaiveVecAlignmentImpl(impl, target); },
diff --git a/paddle/cinn/backends/llvm/execution_engine.cc b/paddle/cinn/backends/llvm/execution_engine.cc
index 646092649db908..238a9ae5c76845 100644
--- a/paddle/cinn/backends/llvm/execution_engine.cc
+++ b/paddle/cinn/backends/llvm/execution_engine.cc
@@ -312,7 +312,7 @@ bool ExecutionEngine::linkSharedLibrary(
   return true;
 #else
   CINN_NOT_IMPLEMENTED;
-#endif
+#endif  // CINN_WITH_CUDA
 }
 
 bool ExecutionEngine::AddModule(
diff --git a/paddle/cinn/common/arch.h b/paddle/cinn/common/arch.h
index 3e585aa6a878f0..1228da207d935c 100644
--- a/paddle/cinn/common/arch.h
+++ b/paddle/cinn/common/arch.h
@@ -16,6 +16,7 @@
 #include
 #include
+#include
 #include "paddle/common/overloaded.h"
@@ -32,6 +33,10 @@ struct UnknownArch {};
 #define DEFINE_CINN_ARCH(class_name) struct class_name {};
 CINN_ARCH_CLASS_NAMES(DEFINE_CINN_ARCH);
 #undef DEFINE_CINN_ARCH
+struct CustomDeviceArch {
+  std::string device_type{"unknown_custom"};
+  int device_id{0};
+};
 
 /**
  * The architecture used by the target. Determines the instruction set to use.
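Unlike the other arch tags, `CustomDeviceArch` carries state (the vendor type string and device id), which is what lets downstream code reach the right `phi::CustomPlace`. A sketch of how the new alternative is consumed, mirroring the `std::visit`-over-`arch.variant()` pattern used elsewhere in this diff (`GetArchName`, `GetWarpSize`); the `"my_npu"` string mentioned in the comment is made up:

```cpp
// Sketch: dispatching on the new variant alternative.
#include <string>
#include <type_traits>
#include <variant>
#include "paddle/cinn/common/arch.h"

std::string Describe(const cinn::common::Arch& arch) {
  return std::visit(
      [](const auto& impl) -> std::string {
        using ArchT = std::decay_t<decltype(impl)>;
        if constexpr (std::is_same_v<ArchT, cinn::common::CustomDeviceArch>) {
          // e.g. CustomDeviceArch{"my_npu", 0} -> "custom device: my_npu"
          return "custom device: " + impl.device_type;
        } else {
          return "builtin arch";
        }
      },
      arch.variant());
}
```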
@@ -40,7 +45,8 @@ using ArchBase = std::variant<
 #define LIST_CINN_ARCH_ALTERNATIVE(class_name) class_name,
     CINN_ARCH_CLASS_NAMES(LIST_CINN_ARCH_ALTERNATIVE)
 #undef LIST_CINN_ARCH_ALTERNATIVE
-    UnknownArch>;
+    CustomDeviceArch,
+    UnknownArch>;
 
 struct Arch final : public ArchBase {
   using ArchBase::ArchBase;
diff --git a/paddle/cinn/common/arch_util.cc b/paddle/cinn/common/arch_util.cc
index 91ee3b009e0dae..ec9cf33c152c67 100644
--- a/paddle/cinn/common/arch_util.cc
+++ b/paddle/cinn/common/arch_util.cc
@@ -31,6 +31,8 @@ std::string GetArchNameImpl(HygonDCUArchHIP arch) { return "HygonDCU_HIP"; }
 
 std::string GetArchNameImpl(HygonDCUArchSYCL arch) { return "HygonDCU_SYCL"; }
 
+std::string GetArchNameImpl(CustomDeviceArch arch) { return arch.device_type; }
+
 std::string GetArchName(Arch arch) {
   return std::visit([](const auto& impl) { return GetArchNameImpl(impl); },
                     arch.variant());
diff --git a/paddle/cinn/common/target.cc b/paddle/cinn/common/target.cc
index 52678027e6533e..c1e11bf1e69e41 100644
--- a/paddle/cinn/common/target.cc
+++ b/paddle/cinn/common/target.cc
@@ -29,6 +29,10 @@
 #include "paddle/cinn/runtime/backend_api.h"
 #include "paddle/cinn/runtime/cinn_runtime.h"
 #include "paddle/common/enforce.h"
+#ifdef CINN_WITH_CUSTOM_DEVICE
+#include "paddle/phi/backends/device_manager.h"
+#endif
+
 using cinn::runtime::BackendAPI;
 
 namespace cinn {
@@ -60,6 +64,13 @@ Target::Target(OS o,
 #ifndef CINN_WITH_SYCL
           PADDLE_THROW(::common::errors::Unimplemented(
               "Please recompile with flag CINN_WITH_SYCL and WITH_CINN."));
+#endif
+        },
+        [&](CustomDeviceArch) {
+#ifndef CINN_WITH_CUSTOM_DEVICE
+          PADDLE_THROW(::common::errors::Unimplemented(
+              "Please recompile with flag WITH_CUSTOM_DEVICE and "
+              "WITH_CINN."));
 #endif
         });
 }
@@ -85,6 +96,8 @@ int GetRuntimeArchImpl(HygonDCUArchHIP) { CINN_NOT_IMPLEMENTED }
 
 int GetRuntimeArchImpl(HygonDCUArchSYCL) { CINN_NOT_IMPLEMENTED }
 
+int GetRuntimeArchImpl(CustomDeviceArch arch) { return cinn_custom_device; }
+
 int GetRuntimeArch(Arch arch) {
   return std::visit([](const auto &impl) { return GetRuntimeArchImpl(impl); },
                     arch.variant());
@@ -110,6 +123,13 @@ int GetMaxNumThreadsImpl(HygonDCUArchHIP arch) { return 1024; }
 
 int GetMaxNumThreadsImpl(HygonDCUArchSYCL arch) { return 1024; }
 
+int GetMaxNumThreadsImpl(CustomDeviceArch arch) {
+#ifdef CINN_WITH_CUSTOM_DEVICE
+  return phi::DeviceManager::GetMaxThreadsPerBlock(
+      phi::CustomPlace(arch.device_type, arch.device_id));
+#else
+  CINN_NOT_IMPLEMENTED;
+#endif
+}
+
 int GetMaxNumThreads(Arch arch) {
   return std::visit([](const auto &impl) { return GetMaxNumThreadsImpl(impl); },
                     arch.variant());
@@ -148,6 +168,13 @@ int GetMultiProcessCountImpl(HygonDCUArchSYCL arch) {
       BackendAPI::DeviceProperty::MultiProcessorCount);
 }
 
+int GetMultiProcessCountImpl(CustomDeviceArch arch) {
+#ifdef CINN_WITH_CUSTOM_DEVICE
+  return phi::DeviceManager::GetMultiProcessors(
+      phi::CustomPlace(arch.device_type, arch.device_id));
+#else
+  CINN_NOT_IMPLEMENTED;
+#endif
+}
+
 int GetMultiProcessCount(Arch arch) {
   return std::visit(
       [](const auto &impl) { return GetMultiProcessCountImpl(impl); },
@@ -192,6 +219,13 @@ int GetMaxThreadsPerSmImpl(HygonDCUArchSYCL arch) {
       BackendAPI::DeviceProperty::MaxThreadsPerSM);
 }
 
+int GetMaxThreadsPerSmImpl(CustomDeviceArch arch) {
+#ifdef CINN_WITH_CUSTOM_DEVICE
+  return phi::DeviceManager::GetMaxThreadsPerMultiProcessor(
+      phi::CustomPlace(arch.device_type, arch.device_id));
+#else
+  CINN_NOT_IMPLEMENTED;
+#endif
+}
+
 int GetMaxThreadsPerSm(Arch arch) {
   return std::visit(
       [](const auto &impl) { return GetMaxThreadsPerSmImpl(impl); },
@@ -234,6 +268,13 @@ int GetMaxBlocksPerSmImpl(HygonDCUArchSYCL
arch) {
       BackendAPI::DeviceProperty::MaxBlocksPerSM);
 }
 
+int GetMaxBlocksPerSmImpl(CustomDeviceArch arch) {
+#ifdef CINN_WITH_CUSTOM_DEVICE
+  return phi::DeviceManager::GetMaxBlocksPerMultiProcessor(
+      phi::CustomPlace(arch.device_type, arch.device_id));
+#else
+  CINN_NOT_IMPLEMENTED;
+#endif
+}
+
 int GetMaxBlocksPerSm(Arch arch) {
   return std::visit(
       [](const auto &impl) { return GetMaxBlocksPerSmImpl(impl); },
@@ -289,6 +330,12 @@ std::string Target::device_name_str() const {
   std::string device_name = properties.name;
   device_name = std::regex_replace(device_name, std::regex(" "), "_");
   return std::regex_replace(device_name, std::regex("-"), "_");
+#elif defined(CINN_WITH_CUSTOM_DEVICE)
+  return arch.Visit(::common::Overloaded{
+      [](const CustomDeviceArch &arch) -> std::string {
+        return arch.device_type;
+      },
+      [](const auto &) -> std::string { return "unknown_device"; }});
 #else
   CINN_NOT_IMPLEMENTED
 #endif
@@ -356,6 +403,36 @@ const Target &DefaultHygonDcuSyclTarget() {
   return target;
 }
 
+const Target &DefaultCustomDeviceTarget() {
+#ifdef CINN_WITH_CUSTOM_DEVICE
+  // 1. Get all custom device types currently registered
+  auto dev_types = phi::DeviceManager::GetAllCustomDeviceTypes();
+
+  std::string device_type = "unknown_custom_device";
+  int device_id = 0;
+
+  // 2. If a custom device exists, take the first one as the default
+  // (the list is usually non-empty)
+  if (!dev_types.empty()) {
+    device_type = dev_types[0];
+    // Get the current device ID for this type
+    device_id = phi::DeviceManager::GetDevice(device_type);
+  }
+
+  // 3. Construct the CustomDeviceArch from the obtained type and id.
+  // Note: this assumes the CustomDeviceArch fields are ordered
+  // {device_type, device_id}
+  static Target target(Target::OS::Linux,
+                       CustomDeviceArch{device_type, device_id},
+                       Target::Bit::k64,
+                       {},
+                       {});
+#else
+  // Fallback: keep the original behavior when CustomDevice support is not
+  // compiled in
+  static Target target(
+      Target::OS::Linux, CustomDeviceArch{}, Target::Bit::k64, {}, {});
+#endif
+  return target;
+}
+
 const Target &DefaultDeviceTarget() {
 #ifdef CINN_WITH_CUDA
   return DefaultNVGPUTarget();
@@ -363,6 +440,8 @@ const Target &DefaultDeviceTarget() {
   return DefaultHygonDcuSyclTarget();
 #elif defined(CINN_WITH_HIP)
   return DefaultHygonDcuHipTarget();
+#elif defined(CINN_WITH_CUSTOM_DEVICE)
+  return DefaultCustomDeviceTarget();
 #endif
 }
 
@@ -377,6 +456,16 @@ int GetMaxThreads() {
       &max_threads, cudaDeviceAttr::cudaDevAttrMaxThreadsPerMultiProcessor, 0);
   // multiplication num_sm
   max_threads *= (num_sm * 4);
+#elif defined(CINN_WITH_CUSTOM_DEVICE)
+  std::vector<std::string> dev_types =
+      phi::DeviceManager::GetAllCustomDeviceTypes();
+  if (!dev_types.empty()) {
+    int device_id = phi::DeviceManager::GetDevice(dev_types[0]);
+    std::string dev_type = dev_types[0];
+    auto place = phi::CustomPlace(dev_type, device_id);
+    max_threads = phi::DeviceManager::GetMultiProcessors(place) *
+                  phi::DeviceManager::GetMaxThreadsPerMultiProcessor(place);
+  }
 #endif
   return max_threads;
 }
@@ -393,6 +482,16 @@ int GetMaxBlocks() {
 
   // multiplication num_sm
   max_blocks *= num_sm;
+#elif defined(CINN_WITH_CUSTOM_DEVICE)
+  std::vector<std::string> dev_types =
+      phi::DeviceManager::GetAllCustomDeviceTypes();
+  if (!dev_types.empty()) {
+    int device_id = phi::DeviceManager::GetDevice(dev_types[0]);
+    std::string dev_type = dev_types[0];
+    auto place = phi::CustomPlace(dev_type, device_id);
+    max_blocks = phi::DeviceManager::GetMultiProcessors(place) *
+                 phi::DeviceManager::GetMaxBlocksPerMultiProcessor(place);
+  }
 #endif
   return max_blocks;
 }
@@ -404,6 +503,10 @@ const Target &DefaultTarget() {
   return DefaultHygonDcuSyclTarget();
 #elif defined(CINN_WITH_HIP)
   return DefaultHygonDcuHipTarget();
+#elif defined(CINN_WITH_CUSTOM_DEVICE)
+  auto dev_types =
phi::DeviceManager::GetAllCustomDeviceTypes();
+  if (!dev_types.empty()) return DefaultCustomDeviceTarget();
+  return DefaultHostTarget();
 #else
   return DefaultHostTarget();
 #endif
 }
@@ -432,6 +535,8 @@ bool GetSupportsCooperativeLaunchImpl(NVGPUArch) {
   return supportsCoopLaunch != 0;
 }
 
+bool GetSupportsCooperativeLaunchImpl(CustomDeviceArch) { return false; }
+
 bool GetSupportsCooperativeLaunchImpl(HygonDCUArchHIP) { return false; }
 
 bool GetSupportsCooperativeLaunchImpl(HygonDCUArchSYCL) { return false; }
diff --git a/paddle/cinn/common/target.h b/paddle/cinn/common/target.h
index 392af62d24fe18..f00bb77aafa8de 100644
--- a/paddle/cinn/common/target.h
+++ b/paddle/cinn/common/target.h
@@ -110,6 +110,8 @@ const Target& DefaultHygonDcuSyclTarget();
 
 const Target& DefaultDeviceTarget();
 
+const Target& DefaultCustomDeviceTarget();
+
 const Target& DefaultTarget();
 
 int GetMaxThreads();
diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
index 690cbe11d88788..ccc61cd490b308 100644
--- a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
+++ b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
@@ -432,6 +432,19 @@ std::vector OpLowererImpl::PostProcess(
         VLOG(4) << "After OptimizeExprGPU in op_lowering_impl: \n"
                 << func_body_block;
         func_body = ir::ConvertStmtBlockToExprBlock(func_body_block);
+#endif
+      },
+      [&](common::CustomDeviceArch) {
+#ifdef CINN_WITH_CUSTOM_DEVICE
+        // optim::EliminateCommonGlobalMemoryRead(&(func_body));
+        ir::stmt::BlockRef func_body_block =
+            ir::ConvertExprBlockToStmtBlock(func_body);
+        VLOG(4) << "Before OptimizeExprGPU in op_lowering_impl: \n"
+                << func_body_block;
+        optim::OptimizeExprGPU(func_body_block);
+        VLOG(4) << "After OptimizeExprGPU in op_lowering_impl: \n"
+                << func_body_block;
+        func_body = ir::ConvertStmtBlockToExprBlock(func_body_block);
 #endif
       },
       [&](std::variant) {
@@ -633,6 +646,14 @@ std::vector OpLowererImpl::DoOpLower(
           op_func_arg_tensors->push_back(expr.as_tensor_ref());
         }
       },
+      [&](common::CustomDeviceArch) {
+        if (!expr.as_tensor_ref()->buffer.defined()) {
+          op_func_arg_tensors->push_back(expr.as_tensor_ref());
+          expr.as_tensor_ref()->WithBuffer();
+        } else {
+          op_func_arg_tensors->push_back(expr.as_tensor_ref());
+        }
+      },
       [&](std::variant) {
diff --git a/paddle/cinn/hlir/op/contrib/logical_right_shift.cc b/paddle/cinn/hlir/op/contrib/logical_right_shift.cc
index fb868893ae33ac..3a0d519c33842c 100644
--- a/paddle/cinn/hlir/op/contrib/logical_right_shift.cc
+++ b/paddle/cinn/hlir/op/contrib/logical_right_shift.cc
@@ -58,6 +58,7 @@ ir::Tensor LogicalRightShift(const ir::Tensor &A,
       [&](std::variant) {
         CINN_NOT_IMPLEMENTED
       },
+      [&](common::CustomDeviceArch) { extern_func += "customDevice_"; },
       [&](common::HygonDCUArchHIP) { extern_func += "hip_"; },
       [&](common::HygonDCUArchSYCL) { extern_func += "sycl_"; });
diff --git a/paddle/cinn/hlir/op/contrib/sort.cc b/paddle/cinn/hlir/op/contrib/sort.cc
index a77327b71ee90b..09bd50cddf588d 100644
--- a/paddle/cinn/hlir/op/contrib/sort.cc
+++ b/paddle/cinn/hlir/op/contrib/sort.cc
@@ -63,6 +63,9 @@ std::vector ArgSort(const ir::Tensor &A,
       [&](common::NVGPUArch) {
         find_func_name.assign("cinn_nvgpu_next_smallest_int32");
       },
+      [&](common::CustomDeviceArch) {
+        find_func_name.assign("cinn_custom_device_next_smallest_int32");
+      },
       [&](common::HygonDCUArchHIP) {
         find_func_name.assign("cinn_hip_next_smallest_int32");
       },
diff --git a/paddle/cinn/hlir/op/custom_call.cc b/paddle/cinn/hlir/op/custom_call.cc
index 364224e4327110..7c9a2a2753adba 100644
--- a/paddle/cinn/hlir/op/custom_call.cc
+++ b/paddle/cinn/hlir/op/custom_call.cc
@@ -125,6 +125,11 @@ std::shared_ptr StrategyForCustomCall(
         host_args.push_back(kernel_stream);
         arguments.emplace_back(kernel_stream, ir::Argument::IO::kOutput);
       },
+      [&](common::CustomDeviceArch) {
+        ir::Var kernel_stream(KERNEL_STREAM, type_of());
+        host_args.push_back(kernel_stream);
+        arguments.emplace_back(kernel_stream, ir::Argument::IO::kOutput);
+      },
       [&](std::variant) {},
diff --git a/paddle/cinn/hlir/op/nn.cc b/paddle/cinn/hlir/op/nn.cc
index a524d644697278..8314d920354ddc 100644
--- a/paddle/cinn/hlir/op/nn.cc
+++ b/paddle/cinn/hlir/op/nn.cc
@@ -370,6 +370,34 @@ std::shared_ptr StrategyForConv2d(
         }
       },
       [&](common::ARMArch) { CINN_NOT_IMPLEMENTED; },
+      [&](common::CustomDeviceArch) {
+        CINN_NOT_IMPLEMENTED
+        if (conv_type == "forward") {
+          out = pe::Conv2d_NCHW(A.as_tensor_ref(),
+                                B.as_tensor_ref(),
+                                padding[0],
+                                padding[1],
+                                stride[0],
+                                stride[1],
+                                dilation[0],
+                                dilation[1],
+                                tensor_name);
+          out.push_back(B.as_tensor_ref());
+        } else {
+#ifdef CINN_WITH_CUDNN
+          // As backward_data and backward_filter are not supported yet, we
+          // build a fake op instead. Since the runtime uses cudnn to compute
+          // the conv2d, this fake op is never called. When CINN supports
+          // backward_filter/backward_data code generation, this code is to
+          // be removed.
+          out = pe::Identity(A.as_tensor_ref());
+          out.push_back(A.as_tensor_ref());
+          out.push_back(B.as_tensor_ref());
+#endif
+        }
+      },
       [&](common::NVGPUArch) {
         if (conv_type == "forward") {
           out = pe::Conv2d_NCHW(A.as_tensor_ref(),
@@ -518,6 +546,16 @@ std::shared_ptr StrategyForDepthwiseConv2d(
   if (data_format == "NCHW") {
     target.arch.Match(
         [&](common::UnknownArch) { CINN_NOT_IMPLEMENTED; },
+        [&](common::CustomDeviceArch) {
+          CINN_NOT_IMPLEMENTED
+          out = pe::Depthwise_Conv2d_NCHW(A.as_tensor_ref(),
+                                          B.as_tensor_ref(),
+                                          padding[0],
+                                          padding[1],
+                                          stride[0],
+                                          stride[1],
+                                          tensor_name);
+        },
        [&](common::X86Arch) {
          out = pe::Conv2d_NCHW_5D(A.as_tensor_ref(),
                                   B.as_tensor_ref(),
@@ -1005,6 +1043,19 @@ std::shared_ptr StrategyForPool2d(
       [&](common::UnknownArch) { CINN_NOT_IMPLEMENTED; },
       [&](common::X86Arch) { use_warp_reduce = false; },
       [&](common::ARMArch) { CINN_NOT_IMPLEMENTED; },
+      [&](common::CustomDeviceArch) {
+        CINN_NOT_IMPLEMENTED
+        if (global_pooling && data_format == "NCHW") {
+          // TODO(hp03): 32 may not be the exact number; try also 16, 8, or
+          // other numbers. We choose 32 to make sure all the threads in a
+          // warp have work to do.
+          if ((A_tensor->shape[2].as_int32() *
+               A_tensor->shape[3].as_int32()) >= 32) {
+            use_warp_reduce = true;
+          }
+        }
+      },
       [&](common::NVGPUArch) {
         if (global_pooling && data_format == "NCHW") {
           // TODO(hp03): 32 may not be the exact number, try
diff --git a/paddle/cinn/hlir/op/op_util.cc b/paddle/cinn/hlir/op/op_util.cc
index 1e27607c9776ca..cd331917b54407 100644
--- a/paddle/cinn/hlir/op/op_util.cc
+++ b/paddle/cinn/hlir/op/op_util.cc
@@ -57,6 +57,11 @@ std::string GetExternFuncNameArchPrefixImpl(common::HygonDCUArchSYCL,
   return "sycl_";
 }
 
+std::string GetExternFuncNameArchPrefixImpl(common::CustomDeviceArch,
+                                            const std::string& func_name) {
+  return "custom_device_";
+}
+
 std::string GetExternFuncNameArchPrefix(common::Arch arch,
                                         const std::string& func_name) {
   return std::visit(
diff --git a/paddle/cinn/hlir/op/transform.cc b/paddle/cinn/hlir/op/transform.cc
index b402d76972dfc2..292f0f4d8c854c 100644
--- a/paddle/cinn/hlir/op/transform.cc
+++ b/paddle/cinn/hlir/op/transform.cc
@@ -128,6 +128,9 @@
std::shared_ptr StrategyForMatMul(
 #endif
       },
       [&](common::ARMArch) { CINN_NOT_IMPLEMENTED; },
+      [&](common::CustomDeviceArch) {
+        out = pe::Matmul(new_A, new_B, trans_a, trans_b, alpha, tensor_name);
+      },
       [&](common::NVGPUArch) {
         out = pe::Matmul(new_A, new_B, trans_a, trans_b, alpha, tensor_name);
       },
@@ -440,6 +443,9 @@ std::shared_ptr StrategyForMul(
 #endif
       },
       [&](common::ARMArch) { CINN_NOT_IMPLEMENTED; },
+      [&](common::CustomDeviceArch) {
+        out = pe::Matmul(new_A, new_B, false, is_infer, 1.0f, tensor_name);
+      },
       [&](common::NVGPUArch) {
         out = pe::Matmul(new_A, new_B, false, is_infer, 1.0f, tensor_name);
       },
diff --git a/paddle/cinn/hlir/pe/ir_schedule_pe.cc b/paddle/cinn/hlir/pe/ir_schedule_pe.cc
index afe49455cfad9c..5a31d6bb38fb6e 100644
--- a/paddle/cinn/hlir/pe/ir_schedule_pe.cc
+++ b/paddle/cinn/hlir/pe/ir_schedule_pe.cc
@@ -88,6 +88,7 @@ void IRElementwiseSchedule(ir::IRSchedule &ir_sch,  // NOLINT
   };
   target.arch.Match(
       [&](common::NVGPUArch) { schedule_nv_hygon(); },
+      [&](common::CustomDeviceArch) { schedule_nv_hygon(); },
       [&](std::variant) {
         // IRScheduleInjectiveCPU(ir_sch, output_shape, target, false);
         auto blocks = ir_sch.GetAllBlocks();
@@ -122,6 +123,7 @@ void IRInjectiveSchedule(ir::IRSchedule &ir_sch,  // NOLINT
   };
   target.arch.Match(
       [&](common::NVGPUArch) { schedule_nv_hygon(); },
+      [&](common::CustomDeviceArch) { schedule_nv_hygon(); },
       [&](std::variant) {
         // IRScheduleInjectiveCPU(ir_sch,
@@ -209,6 +211,7 @@ std::vector IRGpuScheduleMatMul(
     const cinn::common::Target &target) {
   target.arch.Match(
       [&](common::NVGPUArch) {},
+      [&](common::CustomDeviceArch) {},
       [&](std::variant) {
         CINN_NOT_IMPLEMENTED;
       },
@@ -388,6 +391,7 @@ void IRCudaSplitSchedule(ir::IRSchedule &ir_sch,  // NOLINT
 
   target.arch.Match(
       [&](common::NVGPUArch) { SplitScheduleGpuDcu(); },
+      [&](common::CustomDeviceArch) { SplitScheduleGpuDcu(); },
       [&](std::variant) {
         {
           for (auto &block_name : block_names) {
diff --git a/paddle/cinn/hlir/pe/schedule.cc b/paddle/cinn/hlir/pe/schedule.cc
index 52b22e2d26ce33..cf17dcd5afb0df 100644
--- a/paddle/cinn/hlir/pe/schedule.cc
+++ b/paddle/cinn/hlir/pe/schedule.cc
@@ -52,6 +52,10 @@ ParamsT CreateParamsImpl(common::ARMArch) {
 
 ParamsT CreateParamsImpl(common::NVGPUArch) { return CreateCudaParams(); }
 
+ParamsT CreateParamsImpl(common::CustomDeviceArch) {
+  return CreateCudaParams();
+}
+
 ParamsT CreateParamsImpl(common::HygonDCUArchHIP) { return CreateCudaParams(); }
 
 ParamsT CreateParamsImpl(common::HygonDCUArchSYCL) {
diff --git a/paddle/cinn/hlir/pe/schedule.h b/paddle/cinn/hlir/pe/schedule.h
index 4e1a3af76e5271..2ea43a8faf8f94 100644
--- a/paddle/cinn/hlir/pe/schedule.h
+++ b/paddle/cinn/hlir/pe/schedule.h
@@ -35,6 +35,11 @@ class ScheduleParam {
     static ScheduleParam instance{cinn::common::NVGPUArch{}};
     return instance;
   }
+
+  static ScheduleParam &get_customdevice_instance() {
+    static ScheduleParam instance{cinn::common::CustomDeviceArch{}};
+    return instance;
+  }
   static ScheduleParam &get_hip_instance() {
     static ScheduleParam instance{cinn::common::HygonDCUArchHIP{}};
     return instance;
diff --git a/paddle/cinn/hlir/pe/transform.cc b/paddle/cinn/hlir/pe/transform.cc
index 1acad47ed53045..e56e35f919a484 100644
--- a/paddle/cinn/hlir/pe/transform.cc
+++ b/paddle/cinn/hlir/pe/transform.cc
@@ -872,6 +872,14 @@ std::vector MulBaseCallImpl(common::NVGPUArch,
   MulBaseCallImplNvHygon(A, B, name, target);
 }
 
+std::vector MulBaseCallImpl(common::CustomDeviceArch,
+                            const Tensor& A,
+                            const Tensor& B,
+                            const std::string& name,
+                            const cinn::common::Target& target) {
+  return MulBaseCallImplNvHygon(A, B, name, target);
+}
+
 std::vector MulBaseCallImpl(common::HygonDCUArchHIP,
                             const Tensor& A,
                             const Tensor& B,
@@ -1662,6 +1670,10 @@ ir::Tensor ScatterAssign(const ir::Tensor& input,
         "ScatterAssign only support X86 and NVGPU ! Please Check.\n"));
       },
       [&](common::NVGPUArch) { extern_fun_name.assign("cinn_cuda_find_int"); },
+      [&](common::CustomDeviceArch) {
+        PADDLE_THROW(::common::errors::Fatal(
+            "ScatterAssign only support X86 and NVGPU ! Please Check.\n"));
+      },
       [&](common::HygonDCUArchHIP) {
         extern_fun_name.assign("cinn_hip_find_int");
       },
@@ -1775,6 +1787,7 @@ ir::Tensor ScatterAdd(const ir::Tensor& input,
         "HygonDCU now ! Please Check.\n"));
       },
       [&](common::NVGPUArch) { return ScatterAddNvHygon(); },
+      [&](common::CustomDeviceArch) { return ScatterAddNvHygon(); },
       [&](std::variant) {
         return ScatterAddNvHygon();
       });
diff --git a/paddle/cinn/ir/group_schedule/config/group_tile_config.cc b/paddle/cinn/ir/group_schedule/config/group_tile_config.cc
index 2aa5a8d523b4c1..8a5ab0f260768a 100644
--- a/paddle/cinn/ir/group_schedule/config/group_tile_config.cc
+++ b/paddle/cinn/ir/group_schedule/config/group_tile_config.cc
@@ -13,7 +13,12 @@
 // limitations under the License.
 
 #include "paddle/cinn/ir/group_schedule/config/group_tile_config.h"
+#include
+#include
+#include "paddle/cinn/common/target.h"
 #include "paddle/cinn/hlir/framework/pir/op_lowering_impl.h"
+#include "paddle/phi/backends/device_manager.h"
+#include "paddle/phi/common/place.h"
 
 namespace cinn {
 namespace ir {
@@ -27,10 +32,56 @@ using TileConfigMap =
 namespace {
 const int kMaxNumel = BucketInfo::kMaxNumel;
-constexpr int kWarpSize = 32;
-constexpr int KMaxWarpSizePerSM = 64;
-constexpr int KMaxBlockSizePerSM = 32;
-constexpr int KMaxRegistersPerSM = 65536;
+// constexpr int warp_size = 32;
+// constexpr int KMaxWarpSizePerSM = 64;
+// constexpr int KMaxBlockSizePerSM = 32;
+// constexpr int KMaxRegistersPerSM = 65536;
+
+int GetWarpSize(const common::Target& target) {
+  return std::visit(
+      [&](const auto& impl) -> int {
+        // Get the concrete type currently stored in the variant
+        using ArchT = std::decay_t<decltype(impl)>;
+
+        if constexpr (std::is_same_v<ArchT, common::CustomDeviceArch>) {
+#ifdef CINN_WITH_CUSTOM_DEVICE
+          if (!impl.device_type.empty()) {
+            return phi::DeviceManager::GetWarpSize(
+                phi::CustomPlace(impl.device_type, impl.device_id));
+          }
+#endif
+          return 32;
+        } else {
+          return 32;
+        }
+      },
+      target.arch.variant());  // Note: dispatch over arch.variant()
+}
+
+// Helper: get the maximum number of registers per SM
+int GetMaxRegistersPerSM(const common::Target& target) {
+  return std::visit(
+      [&](const auto& impl) -> int {
+        using ArchT = std::decay_t<decltype(impl)>;
+
+        if constexpr (std::is_same_v<ArchT, common::CustomDeviceArch>) {
+#ifdef CINN_WITH_CUSTOM_DEVICE
+          if (!impl.device_type.empty()) {
+            return phi::DeviceManager::GetMaxRegistersPerMultiProcessor(
+                phi::CustomPlace(impl.device_type, impl.device_id));
+          }
+#endif
+          return 65536;
+        } else {
+          return 65536;
+        }
+      },
+      target.arch.variant());
+}
 
 int64_t CeilPow2(int64_t n) {
   int64_t pow = 1;
@@ -220,10 +271,11 @@ std::pair CalculateBlocksAndSMsNeeded(const SMConfig& sm_config,
 bool ShouldUpdateWarpNums(int diff_to_fill_sm,
                           int min_diff_to_full_sm,
                           int threads_per_block,
-                          int best_warp_nums) {
+                          int best_warp_nums,
+                          int warp_size) {
   return (diff_to_fill_sm < min_diff_to_full_sm) ||
          (diff_to_fill_sm == min_diff_to_full_sm &&
-          threads_per_block > best_warp_nums * kWarpSize);
+          threads_per_block > best_warp_nums * warp_size);
 }
 
 // Only proceed with vectorization if SM utilization exceeds 100%
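A quick illustration of why the constant had to go: once the warp size is pluggable, anything derived from `kWarpSize` has to be recomputed per target. This is a sketch, not code from the patch; the 64-lane device is only an example (64 is, e.g., a typical AMD wavefront width):

```cpp
// Illustrative only: the same thread-count computation, parameterized the way
// the functions above now are.
#include <cstdint>

int64_t ThreadsForWarps(int64_t warp_nums, int warp_size) {
  return warp_nums * warp_size;  // was: warp_nums * kWarpSize (kWarpSize == 32)
}

// A device with warp_size == 64 reaches the same 256-thread block with half
// the warps: ThreadsForWarps(8, 32) == ThreadsForWarps(4, 64) == 256.
```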
@@ -257,13 +309,26 @@ bool CheckSmUtilization(
 // By default, warp_nums can be a maximum of 8 (256 threads)
 // The Grid value should be divisible by the SM number as much as possible to
 // avoid Tail Effect.
-int CalculateWarpNums(const SMConfig& sm_config, int total_threads_needed) {
-  int best_warp_nums = 8;
+int CalculateWarpNums(const SMConfig& sm_config,
+                      int total_threads_needed,
+                      int warp_size,
+                      const common::Target& target) {
+  int max_threads = target.max_num_threads();
+  int max_warp_cnt = max_threads / warp_size;
+  int best_warp_nums = std::min(8, max_warp_cnt);
   int min_diff_to_full_sm = sm_config.sm_count;
-  std::vector<int> thread_configs = {1024, 512, 256};
+  std::vector<int> thread_configs;
+  if (max_threads >= 1024) thread_configs.push_back(1024);
+  if (max_threads >= 512) thread_configs.push_back(512);
+  if (max_threads >= 256) thread_configs.push_back(256);
+  if (max_threads >= 128) thread_configs.push_back(128);
+  if (thread_configs.empty()) thread_configs.push_back(max_threads);
+
   for (int threads_per_block : thread_configs) {
-    int current_warp_count = threads_per_block / kWarpSize;
+    int current_warp_count = threads_per_block / warp_size;
     int blocks_needed = std::ceil(static_cast<double>(total_threads_needed) /
                                   threads_per_block);
     auto [max_effective_blocks_per_sm, sms_needed] =
@@ -279,7 +344,8 @@ int CalculateWarpNums(const SMConfig& sm_config, int total_threads_needed) {
       if (ShouldUpdateWarpNums(diff_to_fill_sm,
                                min_diff_to_full_sm,
                                threads_per_block,
-                               best_warp_nums)) {
+                               best_warp_nums,
+                               warp_size)) {
         min_diff_to_full_sm = diff_to_fill_sm;
         best_warp_nums = current_warp_count;
       }
@@ -291,15 +357,16 @@ int CalculateWarpNums(const SMConfig& sm_config, int total_threads_needed) {
 int UpdateWarpNumsInDifferentCase(
     const std::shared_ptr& base_info,
     const GroupVectorizeInfo& group_vectorize_info,
-    int warp_nums) {
+    int warp_nums,
+    int max_warp_cnt) {
   const auto& last_dim = base_info->iter_space_type.back().first;
   if (group_vectorize_info.has_if_else_op && last_dim == "R") {
-    warp_nums = Trim(warp_nums, 1, 16);
+    warp_nums = Trim(warp_nums, 1, std::min(16, max_warp_cnt));
   } else if (!group_vectorize_info.args_broadcast_axis_info.empty() &&
              last_dim == "S") {
-    warp_nums = Trim(warp_nums, 1, 8);
+    warp_nums = Trim(warp_nums, 1, std::min(8, max_warp_cnt));
   } else {
-    warp_nums = Trim(warp_nums, 1, 32);
+    warp_nums = Trim(warp_nums, 1, max_warp_cnt);
   }
   return warp_nums;
 }
@@ -323,12 +390,13 @@ bool ReduceRegionCanVectorize(
     const std::shared_ptr& base_info,
     const SMConfig& sm_config,
     const int warp_nums,
-    const int factor) {
+    const int factor,
+    const int warp_size) {
   const int64_t spatial_numel = base_info->spatial_numel;
   const int64_t reduce_numel = base_info->reduce_numel;
 
   if (warp_nums < 4 && spatial_numel > 1) return false;
-  int rd_thread_num = warp_nums * kWarpSize;
+  int rd_thread_num = warp_nums * warp_size;
   if ((warp_nums > 1 || spatial_numel < warp_nums * 64) &&
       CheckThreadDimensionCanVectorize(
           rd_thread_num, reduce_numel, factor, true) &&
@@ -344,10 +412,11 @@ bool SpatialRegionCanVectorize(
     const GroupVectorizeInfo& group_vectorize_info,
     const SMConfig& sm_config,
     const int warp_nums,
-    const int factor) {
+    const int factor,
+    const int warp_size) {
   const int64_t spatial_numel = base_info->spatial_numel;
   const int64_t reduce_numel = base_info->reduce_numel;
-  const int sp_thread_num = kWarpSize * warp_nums;
+  const int sp_thread_num = warp_size * warp_nums;
   if (group_vectorize_info.has_select_op) return false;
   if (CheckThreadDimensionCanVectorize(
           sp_thread_num, spatial_numel, factor, false) &&
@@ -394,8 +463,9 @@ int IsScalarTensorPreload(
     const std::vector& loop_ranges,
     const std::vector> broadcast_axis_infos,
-    const int warp_nums,
-    const int vectorize_factor) {
-  const int threads_deal_elements = warp_nums * kWarpSize * vectorize_factor;
+    const int warp_nums,
+    const int vectorize_factor,
+    const int warp_size) {
+  const int threads_deal_elements = warp_nums * warp_size * vectorize_factor;
   bool is_scalar_tensor = true;
   for (int i = 0; i < broadcast_axis_infos.size(); i++) {
     int last_dim = broadcast_axis_infos[i].size() - 1;
@@ -417,7 +487,8 @@ int CalculateBroadcastTensorRegisterNums(
     const std::shared_ptr& base_info,
     const GroupVectorizeInfo& group_vectorize_info,
     const int vectorize_factor,
-    const int warp_nums) {
+    const int warp_nums,
+    const int warp_size) {
   // Currently only the [S, R] and [S] situations are supported.
   // Thread parallelization currently happens only at the last dimension of
   // R or S.
   constexpr int register_bits = 32;
@@ -438,7 +509,8 @@ int CalculateBroadcastTensorRegisterNums(
           base_info->loop_ranges,
           group_vectorize_info.args_broadcast_axis_info.at(tensor_name),
           warp_nums,
-          vectorize_factor)) {
+          vectorize_factor,
+          warp_size)) {
     tensor_buffer_size *= vectorize_factor;
   }
   int vectorize_data_bits = tensor_buffer_size * data_type_bits;
@@ -474,7 +546,9 @@ bool RegisterNumsLimitedCheckInCTACanApplyVectorize(
     const std::shared_ptr& base_info,
     const GroupVectorizeInfo& group_vectorize_info,
     const int vectorize_factor,
-    const int warp_nums) {
+    const int warp_nums,
+    const common::Target& target,
+    int warp_size) {
   int thread_register_occupy_sum = 0;
   int vectorize_tensor_registers = CalculateVectorizeTensorRegisterNums(
      group_vectorize_info, vectorize_factor);
   VLOG(5) << "calculate vectorize tensor registers is : "
           << vectorize_tensor_registers << "\n";
   thread_register_occupy_sum += vectorize_tensor_registers;
   int broadcast_tensor_thread_registers = CalculateBroadcastTensorRegisterNums(
       base_info, group_vectorize_info, vectorize_factor, warp_nums, warp_size);
   thread_register_occupy_sum += broadcast_tensor_thread_registers;
   VLOG(5) << "calculate broadcast tensor registers is : "
           << broadcast_tensor_thread_registers << "\n";
@@ -492,10 +566,14 @@ bool RegisterNumsLimitedCheckInCTACanApplyVectorize(
   thread_register_occupy_sum += other_register_occupy_sum;
   VLOG(5) << "calculate other registers is : " << other_register_occupy_sum
           << "\n";
+  int max_threads_per_sm = target.get_max_threads_per_sm();
+  int max_warps_per_sm = max_threads_per_sm / warp_size;
+  int max_blocks_per_sm_limit = target.get_max_blocks_per_sm();
+  int max_regs_per_sm = GetMaxRegistersPerSM(target);
   int max_blocks_per_sm =
-      Trim(CeilDiv(KMaxWarpSizePerSM, warp_nums), 1, KMaxBlockSizePerSM);
+      Trim(CeilDiv(max_warps_per_sm, warp_nums), 1, max_blocks_per_sm_limit);
   int best_register_nums_per_thread =
-      KMaxRegistersPerSM / max_blocks_per_sm / warp_nums / kWarpSize;
+      max_regs_per_sm / max_blocks_per_sm / warp_nums / warp_size;
   VLOG(5) << "calculate thread register occupy sum is : "
           << thread_register_occupy_sum
           << ", best register nums per thread is : "
@@ -527,9 +605,15 @@ bool CheckPerformanceLimitInVectorize(
     const std::shared_ptr& base_info,
     const GroupVectorizeInfo& group_vectorize_info,
     const int vectorize_factor,
-    const int warp_nums) {
-  if (!RegisterNumsLimitedCheckInCTACanApplyVectorize(
-          base_info,
-          group_vectorize_info, vectorize_factor, warp_nums)) {
+    const int warp_nums,
+    const common::Target& target,
+    const int warp_size) {
+  if (!RegisterNumsLimitedCheckInCTACanApplyVectorize(base_info,
+                                                      group_vectorize_info,
+                                                      vectorize_factor,
+                                                      warp_nums,
+                                                      target,
+                                                      warp_size)) {
     VLOG(5) << "According to the limit of register, current schedule block "
                "can't enable vectorize!";
     return false;
   }
@@ -561,6 +645,7 @@ TileConfigMap BuildVectorizeConfig(
     return {};
   }
 
+  int warp_size = GetWarpSize(target);
   const std::vector<int> vectorize_factors{4, 2};
   int64_t spatial_numel = base_info->spatial_numel;
   int64_t reduce_numel = base_info->reduce_numel;
@@ -574,17 +659,18 @@ TileConfigMap BuildVectorizeConfig(
   SMConfig sm_config(target.get_max_threads_per_sm(),
                      target.get_max_blocks_per_sm(),
                      target.get_multi_processor_count());
-
+  int max_threads_per_block = target.max_num_threads();
+  int max_warp_cnt = max_threads_per_block / warp_size;
   // Reduce Region
   if (last_dim == "R") {
     for (auto factor : vectorize_factors) {
       vectorize_factor = factor;
-      const int elements_in_warp = kWarpSize * vectorize_factor;
+      const int elements_in_warp = warp_size * vectorize_factor;
       warp_nums = CeilDiv(reduce_numel, elements_in_warp);
-      warp_nums = Trim(warp_nums, 1, 32);
-      rd_thread_num = warp_nums * kWarpSize;
+      warp_nums = Trim(warp_nums, 1, max_warp_cnt);
+      rd_thread_num = warp_nums * warp_size;
       if (ReduceRegionCanVectorize(
-              base_info, sm_config, warp_nums, vectorize_factor)) {
+              base_info, sm_config, warp_nums, vectorize_factor, warp_size)) {
         can_vectorize = true;
         reduce_method = BlockReduceMethod();
         break;
@@ -593,29 +679,33 @@ TileConfigMap BuildVectorizeConfig(
   } else if (iters_dim == 1 && last_dim == "S") {  // Spatial Region
     for (auto factor : vectorize_factors) {
       vectorize_factor = factor;
-      const int elements_in_warp = kWarpSize * vectorize_factor;
+      const int elements_in_warp = warp_size * vectorize_factor;
       warp_nums = CeilDiv(spatial_numel, elements_in_warp);
-      int max_warp_nums =
-          CalculateWarpNums(sm_config, spatial_numel / vectorize_factor);
+      int max_warp_nums = CalculateWarpNums(
+          sm_config, spatial_numel / vectorize_factor, warp_size, target);
       warp_nums = Trim(warp_nums, 1, max_warp_nums);
-      sp_thread_num = kWarpSize * warp_nums;
+      sp_thread_num = warp_size * warp_nums;
       if (SpatialRegionCanVectorize(base_info,
                                     group_vectorize_info,
                                     sm_config,
                                     warp_nums,
-                                    vectorize_factor)) {
+                                    vectorize_factor,
+                                    warp_size)) {
         can_vectorize = true;
         break;
       }
     }
   }
 
-  warp_nums =
-      UpdateWarpNumsInDifferentCase(base_info, group_vectorize_info, warp_nums);
+  warp_nums = UpdateWarpNumsInDifferentCase(
+      base_info, group_vectorize_info, warp_nums, max_warp_cnt);
 
-  if (can_vectorize &&
-      !CheckPerformanceLimitInVectorize(
-          base_info, group_vectorize_info, vectorize_factor, warp_nums)) {
+  if (can_vectorize && !CheckPerformanceLimitInVectorize(base_info,
                                                          group_vectorize_info,
                                                          vectorize_factor,
                                                          warp_nums,
                                                          target,
                                                          warp_size)) {
     can_vectorize = false;
   }
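// Annotation (not part of the patch): the register budget computed in
// RegisterNumsLimitedCheckInCTACanApplyVectorize above replaces the old
// constants (KMaxRegistersPerSM / KMaxWarpSizePerSM / kWarpSize) with
// device-queried limits. A worked example, assuming a hypothetical target
// that reports 65536 registers per SM, 2048 threads per SM, a 32-block
// cap per SM, warp_size = 32, and warp_nums = 8:
//
//   max_warps_per_sm  = 2048 / 32 = 64
//   max_blocks_per_sm = Trim(CeilDiv(64, 8), 1, 32) = 8
//   best_register_nums_per_thread = 65536 / 8 / 8 / 32 = 32
//
// Vectorization is then rejected whenever the estimated per-thread register
// occupancy exceeds this budget.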
@@ -709,6 +799,9 @@ TileConfigMap BuildPureStaticShapeConfig(
     const std::shared_ptr<ScheduleConfig::BaseInfo>& base_info,
     const GroupVectorizeInfo& vectorize_info,
     const common::Target& target) {
+  int warp_size = GetWarpSize(target);
+  int max_threads = target.max_num_threads();
+  int max_warp_cnt = max_threads / warp_size;
   const auto& last_dim = base_info->iter_space_type.back().first;
   const int sm_count = target.get_multi_processor_count();
   int64_t spatial_numel = base_info->spatial_numel;
@@ -728,19 +821,19 @@ TileConfigMap BuildPureStaticShapeConfig(
   int64_t sp_thread_num = 1;
   int64_t rd_thread_num = 1;
   if (last_dim == "R") {
-    rd_thread_num = 32;
-    int64_t remain_reduce_numel = CeilDiv(reduce_numel, 32);
+    rd_thread_num = warp_size;
+    int64_t remain_reduce_numel = CeilDiv(reduce_numel, warp_size);
     if ((remain_reduce_numel <= 8 && spatial_numel > 1) ||
         (spatial_numel > remain_reduce_numel * 128)) {
       sp_thread_num = Trim(spatial_numel, 1, 8);
       reduce_method = WarpReduceMethod();
     } else {
-      rd_thread_num *= Trim(remain_reduce_numel, 1, 32);
+      rd_thread_num *= Trim(remain_reduce_numel, 1, max_warp_cnt);
       reduce_method = BlockReduceMethod();
     }
   } else {  // last_dim == "S"
-    sp_thread_num = 32;
-    int64_t remain_spatial_numel = CeilDiv(spatial_numel, 32);
+    sp_thread_num = warp_size;
+    int64_t remain_spatial_numel = CeilDiv(spatial_numel, warp_size);
     if (reduce_numel <= 16) {
       sp_thread_num *= Trim(remain_spatial_numel, 1, 8);
     } else {
@@ -781,7 +874,8 @@ TileConfigMap BuildPureStaticShapeConfig(
   int64_t sp_upper_bound = base_info->spatial_numel > 1 ? kMaxNumel : 1;
   int64_t rd_upper_bound = base_info->reduce_numel > 1 ? kMaxNumel : 1;
-  int64_t warp_num = Trim(sp_thread_num * rd_thread_num / 32, 1, 32);
+  int64_t warp_num =
+      Trim(sp_thread_num * rd_thread_num / warp_size, 1, max_warp_cnt);
   BucketInfo bucket_info{1, sp_upper_bound, 1, rd_upper_bound};
   TileConfig tile_config{warp_num,
                          /* tree_reduce_num = */ rd_thread_num,
@@ -796,6 +890,8 @@ TileConfigMap BuildPureStaticShapeConfig(
 TileConfigMap BuildStaticSpatialConfig(
     const std::shared_ptr<ScheduleConfig::BaseInfo>& base_info,
     const common::Target& target) {
+  int warp_size = GetWarpSize(target);
+  int max_threads = target.max_num_threads();
   const auto& last_dim = base_info->iter_space_type.back().first;
   const int sm_count = target.get_multi_processor_count();
   const int64_t spatial_numel = base_info->spatial_numel;
@@ -813,18 +909,25 @@ TileConfigMap BuildStaticSpatialConfig(
               {8, 256, 1, 1, 1, -1, BlockReduceMethod()});
 
     if (rd_block_num > 1 && base_info->can_apply_grid_reduce) {
-      int64_t rd_threshold = rd_block_num * min_loops * 1024;
+      int64_t rd_threshold = rd_block_num * min_loops * max_threads;
       collector({1, kMaxNumel, 2049, rd_threshold},
-                {32, 1024, 1, 1, 1, -1, BlockReduceMethod()});
+                {warp_size, max_threads, 1, 1, 1, -1, BlockReduceMethod()});
       collector({1, kMaxNumel, rd_threshold + 1, kMaxNumel},
-                {32, 1024, rd_block_num, 1, 1, -1, BlockReduceMethod()});
+                {warp_size,
+                 max_threads,
+                 rd_block_num,
+                 1,
+                 1,
+                 -1,
+                 BlockReduceMethod()});
     } else {
       collector({1, kMaxNumel, 2049, kMaxNumel},
-                {32, 1024, 1, 1, 1, -1, BlockReduceMethod()});
+                {warp_size, max_threads, 1, 1, 1, -1, BlockReduceMethod()});
     }
   } else {  // last_dim == "S"
-    int64_t sp_block_num = std::max(CeilDiv(spatial_numel, 32), int64_t(1));
+    int64_t sp_block_num =
+        std::max(CeilDiv(spatial_numel, warp_size), int64_t(1));
     int64_t rd_block_num = FloorPow2(sm_count / sp_block_num);
 
     if (rd_block_num > 1 && base_info->can_apply_grid_reduce) {
@@ -845,6 +948,8 @@ TileConfigMap BuildStaticSpatialConfig(
 TileConfigMap BuildStaticReduceConfig(
     const std::shared_ptr<ScheduleConfig::BaseInfo>& base_info,
     const common::Target& target) {
+  int warp_size = GetWarpSize(target);
+  int max_threads = target.max_num_threads();
   const auto& last_dim = base_info->iter_space_type.back().first;
   TileConfigCollector collector;
 
@@ -865,13 +970,13 @@ TileConfigMap BuildStaticReduceConfig(
               {warp_num, tree_reduce_num, 1, 1, 1, -1, BlockReduceMethod()});
     } else {
       collector({1, kMaxNumel, 2049, kMaxNumel},
-                {32, 1024, 1, 1, 1, -1, BlockReduceMethod()});
+                {warp_size, max_threads, 1, 1, 1, -1, BlockReduceMethod()});
     }
   } else {  // last_dim == "S"
     if (base_info->reduce_numel == 1) {
       collector({1, 1023, 1, 1}, {-1, 1, 1, 1, 1, -1, NoneReduceMethod()});
       collector({1024, kMaxNumel, 1, 1},
-                {32, 1, 1, 4, 1, -1, NoneReduceMethod()});
+                {warp_size, 1, 1, 4, 1, -1, NoneReduceMethod()});
     } else if (base_info->reduce_numel <= 16) {
       collector({1, kMaxNumel, 1, 1}, {8, 1, 1, 1, 1, -1, NoneReduceMethod()});
     } else {
@@ -886,6 +991,8 @@ TileConfigMap BuildStaticReduceConfig(
 TileConfigMap BuildDynamicShapeConfig(
     const std::shared_ptr<ScheduleConfig::BaseInfo>& base_info,
     const common::Target& target) {
+  int warp_size = GetWarpSize(target);
+  int max_threads = target.max_num_threads();
   const auto& last_dim = base_info->iter_space_type.back().first;
   TileConfigCollector collector;
 
@@ -897,7 +1004,7 @@ TileConfigMap BuildDynamicShapeConfig(
     collector({1, kMaxNumel, 257, 2048},
               {8, 256, 1, 1, 1, 8, BlockReduceMethod()});
     collector({1, kMaxNumel, 2049, kMaxNumel},
-              {32, 1024, 1, 1, 1, -1, BlockReduceMethod()});
+              {warp_size, max_threads, 1, 1, 1, -1, BlockReduceMethod()});
   } else {  // last_dim == "S"
     collector({1, kMaxNumel, 1, kMaxNumel},
               {16, 16, 1, 1, 1, -1, DiscreteReduceMethod()});
diff --git a/paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.cc
index a7ba723d3157e5..351c2f616661ec 100644
--- a/paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.cc
+++ b/paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.cc
@@ -167,6 +167,12 @@ void OptimizeReductionTactic::Apply(ir::IRSchedule* sch,
       sch->Bind(rb_loops.back(), "threadIdx.x");
       sch->SetBuffer(rf_block, "local");
     },
+    [&](common::CustomDeviceArch) {
+      rb_loops = sch->GetLoops(block_id);
+      rf_block = sch->GetBlock(rf_block_id);
+      sch->Bind(rb_loops.back(), "threadIdx.x");
+      sch->SetBuffer(rf_block, "local");
+    },
     [&](std::variant) {
     },
     [&](std::variant) {
diff --git a/paddle/cinn/ir/module.cc b/paddle/cinn/ir/module.cc
index d6c92481f706bc..283489b4808ae7 100644
--- a/paddle/cinn/ir/module.cc
+++ b/paddle/cinn/ir/module.cc
@@ -45,6 +45,9 @@ std::optional GetDataAlignmentImpl(common::ARMArch arch) {
 std::optional GetDataAlignmentImpl(common::NVGPUArch) {
   return std::nullopt;
 }
+std::optional GetDataAlignmentImpl(common::CustomDeviceArch) {
+  return std::nullopt;
+}
 
 std::optional GetDataAlignmentImpl(common::HygonDCUArchHIP arch) {
   return std::nullopt;
diff --git a/paddle/cinn/ir/op/ir_operators.cc b/paddle/cinn/ir/op/ir_operators.cc
index d64b17b6573708..353c8d5e7f32b9 100644
--- a/paddle/cinn/ir/op/ir_operators.cc
+++ b/paddle/cinn/ir/op/ir_operators.cc
@@ -113,6 +113,15 @@ Expr BitwiseOrCallImpl(common::ARMArch, const Target &target, Expr a, Expr b) {
   PADDLE_THROW(::common::errors::InvalidArgument(ss.str()));
 }
 
+Expr BitwiseOrCallImpl(common::CustomDeviceArch,
+                       const Target &target,
+                       Expr a,
+                       Expr b) {
+  Type t_a = a.type();
+  auto func_name = hlir::GetExternFuncName(target, t_a, "bitwise_or");
+  return lang::CallExtern(func_name, {a, b}, {{"vectorizable", false}});
+}
+
 Expr BitwiseOrCallImpl(common::NVGPUArch,
                        const Target &target,
                        Expr a,
@@ -198,6 +207,15 @@ Expr BitwiseAndCallImpl(common::NVGPUArch,
   return lang::CallExtern(func_name, {a, b}, {{"vectorizable", false}});
 }
 
+Expr BitwiseAndCallImpl(common::CustomDeviceArch,
+                        const Target &target,
+                        Expr a,
+                        Expr b) {
+  Type t_a = a.type();
+  auto func_name = hlir::GetExternFuncName(target, t_a, "bitwise_and");
+  return lang::CallExtern(func_name, {a, b}, {{"vectorizable", false}});
+}
+
 Expr BitwiseAndCallImpl(common::HygonDCUArchHIP,
                         const Target &target,
                         Expr a,
@@ -274,6 +292,15 @@ Expr BitwiseXorCallImpl(common::NVGPUArch,
   return lang::CallExtern(func_name, {a, b}, {{"vectorizable", false}});
 }
 
+Expr BitwiseXorCallImpl(common::CustomDeviceArch,
+                        const Target &target,
+                        Expr a,
+                        Expr b) {
+  Type t_a = a.type();
+  auto func_name = hlir::GetExternFuncName(target, t_a, "bitwise_xor");
+  return lang::CallExtern(func_name, {a, b}, {{"vectorizable", false}});
+}
+
 Expr BitwiseXorCallImpl(common::HygonDCUArchHIP,
                         const Target &target,
                         Expr a,
@@ -343,6 +370,13 @@ Expr BitwiseNotCallImpl(common::NVGPUArch, const Target &target, Expr a) {
   return lang::CallExtern(func_name, {a}, {{"vectorizable", false}});
 }
 
+Expr BitwiseNotCallImpl(common::CustomDeviceArch,
+                        const Target &target,
+                        Expr a) {
+  auto func_name = hlir::GetExternFuncName(target, a->type(), "bitwise_not");
+  return lang::CallExtern(func_name, {a}, {{"vectorizable", false}});
+}
+
 Expr BitwiseNotCallImpl(common::HygonDCUArchHIP, const Target &target, Expr a) {
   auto func_name = hlir::GetExternFuncName(target, a->type(), "bitwise_not");
   return lang::CallExtern(func_name, {a}, {{"vectorizable", false}});
diff --git a/paddle/cinn/ir/schedule/impl/for_type.cc b/paddle/cinn/ir/schedule/impl/for_type.cc
index 0bdaad423cf003..cc2f9c0b8a1a21 100644
--- a/paddle/cinn/ir/schedule/impl/for_type.cc
+++ b/paddle/cinn/ir/schedule/impl/for_type.cc
@@ -17,6 +17,9 @@
 #include "paddle/cinn/ir/schedule/impl/ir_schedule.h"
 #include "paddle/cinn/runtime/backend_api.h"
 #include "paddle/common/enforce.h"
+#ifdef CINN_WITH_CUSTOM_DEVICE
+#include "paddle/phi/backends/device_manager.h"
+#endif
 
 namespace cinn {
 namespace ir {
@@ -203,6 +206,23 @@ void DyScheduleImpl::Bind(const Expr& loop, const std::string& thread_axis) {
         const std::array<int, 3> kMaxGridDims = cur_dev_info->GetMaxGridDims();
         bindNvHygon(kMaxBlockDims, kMaxGridDims);
 #endif
+      },
+      [&](const common::CustomDeviceArch& arch) {
+#ifdef CINN_WITH_CUSTOM_DEVICE
+        auto place = phi::CustomPlace(arch.device_type, arch.device_id);
+
+        const std::array<int, 3> kMaxBlockDims = {
+            static_cast<int>(phi::DeviceManager::GetMaxBlockDimSize(place)[0]),
+            static_cast<int>(phi::DeviceManager::GetMaxBlockDimSize(place)[1]),
+            static_cast<int>(phi::DeviceManager::GetMaxBlockDimSize(place)[2])};
+
+        const std::array<int, 3> kMaxGridDims = {
+            static_cast<int>(phi::DeviceManager::GetMaxGridDimSize(place)[0]),
+            static_cast<int>(phi::DeviceManager::GetMaxGridDimSize(place)[1]),
+            static_cast<int>(phi::DeviceManager::GetMaxGridDimSize(place)[2])};
+
+        bindNvHygon(kMaxBlockDims, kMaxGridDims);
+#endif
       },
       [&](common::HygonDCUArchHIP) {
 #ifdef CINN_WITH_HIP
diff --git a/paddle/cinn/lang/lower.cc b/paddle/cinn/lang/lower.cc
index e84e3f3804b5a0..4e006342784c42 100644
--- a/paddle/cinn/lang/lower.cc
+++ b/paddle/cinn/lang/lower.cc
@@ -292,6 +292,7 @@ std::vector LowerToAstVec(
   for (auto& res : result) {
     target.arch.Match(
         [&](common::NVGPUArch) { res->device_api = ir::DeviceAPI::GPU; },
+        [&](common::CustomDeviceArch) { res->device_api = ir::DeviceAPI::GPU; },
         [&](std::variant) {
           res->device_api = ir::DeviceAPI::GPU;
         },
diff --git a/paddle/cinn/lang/lower_tensor_group.cc b/paddle/cinn/lang/lower_tensor_group.cc
index 044b85cbf5a28d..2cadf8bd82f157 100644
--- a/paddle/cinn/lang/lower_tensor_group.cc
+++ b/paddle/cinn/lang/lower_tensor_group.cc
@@ -246,6 +246,12 @@ std::vector LowerTensorGroup::GenerateFunctionBody(
         bodies.clear();
       }
     },
+    [&](common::CustomDeviceArch) {
+      if (!gpu_local) {
+        result.push_back(BlockRef(bodies));
+        bodies.clear();
+      }
+    },
     [&](std::variant) {
       if (!gpu_local) {
         result.push_back(BlockRef(bodies));
diff --git a/paddle/cinn/optim/CMakeLists.txt b/paddle/cinn/optim/CMakeLists.txt
index 646c31a6d1c7da..67c364ebbdc9c1 100755
--- a/paddle/cinn/optim/CMakeLists.txt
+++ b/paddle/cinn/optim/CMakeLists.txt
@@ -43,6 +43,8 @@ gather_srcs(
   if_fold_pass.cc
   simplify_util.cc)
 
-if(WITH_CUDA OR WITH_ROCM)
+if(WITH_CUDA
+   OR WITH_ROCM
+   OR WITH_CUSTOM_DEVICE)
   gather_srcs(cinnapi_src SRCS transform_gpu_forloop.cc)
 endif()
diff --git a/paddle/cinn/optim/cast_bool_to_int8.cc b/paddle/cinn/optim/cast_bool_to_int8.cc
index 861bf8c8bc66f0..8cb7e03d9dcc0c 100644
--- a/paddle/cinn/optim/cast_bool_to_int8.cc
+++ b/paddle/cinn/optim/cast_bool_to_int8.cc
@@ -55,6 +55,10 @@ void CastBoolExprToInt8Impl(common::ARMArch, Expr* e) {
   // Do nothing.
 }
 
+void CastBoolExprToInt8Impl(common::CustomDeviceArch, Expr* e) {
+  // Do nothing.
+}
+
 void CastBoolExprToInt8Impl(common::NVGPUArch, Expr* e) {
   // Do nothing.
 }
diff --git a/paddle/cinn/optim/map_extern_call.cc b/paddle/cinn/optim/map_extern_call.cc
index bdf143a26e5170..4fdedea6a9f6ac 100644
--- a/paddle/cinn/optim/map_extern_call.cc
+++ b/paddle/cinn/optim/map_extern_call.cc
@@ -108,6 +108,12 @@ void DealWithIntrinsicsImpl(common::NVGPUArch, ir::Call *node, Expr *expr) {
   DealWithIntrinsicsNvHygon(node, expr);
 }
 
+void DealWithIntrinsicsImpl(common::CustomDeviceArch,
+                            ir::Call *node,
+                            Expr *expr) {
+  DealWithIntrinsicsNvHygon(node, expr);
+}
+
 void DealWithIntrinsicsImpl(common::HygonDCUArchHIP,
                             ir::Call *node,
                             Expr *expr) {
diff --git a/paddle/cinn/optim/optimize.cc b/paddle/cinn/optim/optimize.cc
index 1c4eabc02bb6e2..679e4ccbcbf5e5 100644
--- a/paddle/cinn/optim/optimize.cc
+++ b/paddle/cinn/optim/optimize.cc
@@ -103,6 +103,28 @@ ir::LoweredFunc Optimize(ir::LoweredFunc fn,
           func_pass_manager.AddPass(CreateTransBufferWithDynamicShapePass());
           func_pass_manager.Run(copied);
           VLOG(10) << "After Optimize TransBufferWithDynamicShape:" << copied;
+#endif
+        },
+        [&](common::CustomDeviceArch) {
+#ifdef CINN_WITH_CUSTOM_DEVICE
+          ir::SetCudaAxisInfo(copied);
+          if (remove_gpu_for_loops) {
+            VLOG(4) << "Before removing GPU for loops:\n" << copied;
+            FuncPassManager func_pass_manager;
+            func_pass_manager.AddPass(CreateRemoveGpuForLoopsPass());
+            func_pass_manager.Run(copied);
+            VLOG(4) << "After removing GPU for loops:\n" << copied;
+          }
+          VLOG(10) << "Before Optimize CudaSyncThreadsDropIfThenElse:" << copied;
+          BlockPassManager blk_pass_manager;
+          blk_pass_manager.AddPass(CreateCudaSyncThreadsDropIfThenElsePass());
+          blk_pass_manager.Run(copied->body_block);
+          VLOG(10) << "After Optimize CudaSyncThreadsDropIfThenElse:" << copied;
+          FuncPassManager func_pass_manager;
+          VLOG(10) << "Before Optimize TransBufferWithDynamicShape:" << copied;
+          func_pass_manager.AddPass(CreateTransBufferWithDynamicShapePass());
+          func_pass_manager.Run(copied);
+          VLOG(10) << "After Optimize TransBufferWithDynamicShape:" << copied;
 #endif
         },
         [&](std::variant) {
@@ -161,6 +183,12 @@ ir::LoweredFunc Optimize(ir::LoweredFunc fn,
           func_pass_manager.Run(copied);
           VLOG(4) << "After Optimize RearrangeLoadInstruction:" << copied;
         },
+        [&](common::CustomDeviceArch) {
+          FuncPassManager func_pass_manager;
+          func_pass_manager.AddPass(CreateRearrangeLoadInstructionPass());
+          func_pass_manager.Run(copied);
+          VLOG(4) << "After Optimize RearrangeLoadInstruction:" << copied;
+        },
        [&](std::variant) {
          FuncPassManager func_pass_manager;
          func_pass_manager.AddPass(CreateRearrangeLoadInstructionPass());
diff --git a/paddle/cinn/optim/realize_composite_reduce_pass.cc b/paddle/cinn/optim/realize_composite_reduce_pass.cc
index e6412651bedfe4..5c168660f2564f 100644
--- a/paddle/cinn/optim/realize_composite_reduce_pass.cc
+++ b/paddle/cinn/optim/realize_composite_reduce_pass.cc
@@ -683,6 +683,7 @@ LogicalResult RealizeCompositeReducePass::Run(ir::LoweredFunc func) {
         ReplaceOutputBufferX86(body, output_buffers, typed_buffers);
       },
       [&](std::variant GetArchDevice(const common::Target& target) {
     return std::optional<int>{device_id};
 #else
     return std::nullopt;
+#endif
+  },
+  [&](common::CustomDeviceArch) -> std::optional<int> {
+#ifdef CINN_WITH_CUSTOM_DEVICE
+    int device_id = phi::DeviceManager::GetDevice(target.device_name_str());
+    return std::optional<int>{device_id};
+#else
+    return std::nullopt;
 #endif
   },
   [&](common::HygonDCUArchHIP) -> std::optional<int> {
@@ -68,6 +78,17 @@ inline void SetArchDevice(const common::Target& target,
                           "Required device_id should have value, but "
                           "received std::nullopt."));
     cudaSetDevice(device_id.value());
+#endif
+  },
+  [&](common::CustomDeviceArch) -> void {
+#ifdef CINN_WITH_CUSTOM_DEVICE
+    PADDLE_ENFORCE_EQ(device_id.has_value(),
+                      true,
+                      ::common::errors::InvalidArgument(
+                          "Required device_id should have value, but "
+                          "received std::nullopt."));
+    phi::DeviceManager::SetDevice(target.device_name_str(),
+                                  device_id.value());
 #endif
   },
   [&](common::HygonDCUArchHIP) -> void {
diff --git a/paddle/cinn/runtime/backend_api.cc b/paddle/cinn/runtime/backend_api.cc
index 5f7374fd48aa76..06854cfef9515c 100644
--- a/paddle/cinn/runtime/backend_api.cc
+++ b/paddle/cinn/runtime/backend_api.cc
@@ -45,6 +45,7 @@ BackendAPI* BackendAPI::get_backend(common::Arch arch) {
       [&](std::variant) {
         std::stringstream ss;
         ss << "Target(" << arch << ") is not support get_backend now.";
diff --git a/paddle/cinn/runtime/cinn_runtime.h b/paddle/cinn/runtime/cinn_runtime.h
index 062e8e86130007..5c45a4b00deed1 100644
--- a/paddle/cinn/runtime/cinn_runtime.h
+++ b/paddle/cinn/runtime/cinn_runtime.h
@@ -137,7 +137,8 @@ typedef enum cinn_device_kind_t {
   cinn_x86_device = 0,     // X86 device
   cinn_opencl_device = 1,  // OpenCL device
   cinn_arm_device = 2,     // ARM device
-  cinn_nvgpu_device = 3    // NVIDIA GPU device
+  cinn_nvgpu_device = 3,   // NVIDIA GPU device
+  cinn_custom_device = 4   // custom device
 } cinn_device_kind_t;
 
 //! Help to tell where the buffer locates.
diff --git a/paddle/cinn/runtime/custom_device/CMakeLists.txt b/paddle/cinn/runtime/custom_device/CMakeLists.txt
new file mode 100755
index 00000000000000..c50b84e89b4423
--- /dev/null
+++ b/paddle/cinn/runtime/custom_device/CMakeLists.txt
@@ -0,0 +1,13 @@
+core_gather_headers()
+
+set(SRCS
+    custom_device_backend_api.cc custom_device_util.cc
+    custom_device_intrinsics.cc custom_device_intrinsics_reduce.cc
+    custom_device_intrinsics_float16.cc)
+
+# Build the CINN custom-device runtime library.
+cc_library(
+  cinn_custom_device_runtime
+  SRCS ${SRCS}
+  DEPS phi_core # required: provides DeviceManager and Place
+)
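# Annotation (not part of the patch): with the cinn.cmake changes earlier in
# this diff, a build that exercises this runtime library would be configured
# roughly as follows. WITH_CINN and WITH_CUSTOM_DEVICE are the standard
# Paddle options; the exact flag combination is an assumption:
#
#   cmake .. -DWITH_CINN=ON -DWITH_CUSTOM_DEVICE=ON -DWITH_GPU=OFF
#   make -j$(nproc) cinn_custom_device_runtime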
diff --git a/paddle/cinn/runtime/custom_device/bfloat16.h b/paddle/cinn/runtime/custom_device/bfloat16.h
new file mode 100644
index 00000000000000..05e27da8fb0931
--- /dev/null
+++ b/paddle/cinn/runtime/custom_device/bfloat16.h
@@ -0,0 +1,441 @@
+// Copyright (c) 2023 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef CINN_COMMON_BFLOAT16_H
+#define CINN_COMMON_BFLOAT16_H
+
+#ifdef __cplusplus
+#pragma once
+#endif  // __cplusplus
+
+#include <stdint.h>
+
+#include <cmath>
+#include <cstring>
+
+#ifdef CINN_WITH_CUDA
+#include <cuda.h>
+
+#if (defined(__CUDACC__) || defined(__CUDACC_RTC__)) && CUDA_VERSION >= 11000
+#define CINN_CUDA_BF16
+#include <cuda_bf16.h>
+
+#endif  // __CUDACC__
+#endif  // CINN_WITH_CUDA
+
+#ifdef __cplusplus
+
+#ifndef _WIN32
+#define CINN_ALIGN(x) __attribute__((aligned(x)))
+#else  // _WIN32
+#define CINN_ALIGN(x) __declspec(align(x))
+#endif  // _WIN32
+
+#else  // __cplusplus
+#define CINN_ALIGN(x)
+#endif  // __cplusplus
+
+// The `HOST` macro definition is not used here, it has a potential
+// conflict with the enumeration `kHOST` representing the backend.
+#ifndef __host__
+#define __host__
+#endif
+#ifndef __device__
+#define __device__
+#endif
+
+#ifdef __cplusplus
+namespace cinn {
+namespace common {
+#endif  // __cplusplus
+
+// Use CINN_ALIGN(2) to ensure that each bfloat16 will be allocated
+// and aligned at least on a 2-byte boundary, which leads to efficient
+// memory access of float16 struct and also makes bfloat16 compatible
+// with CUDA half
+struct CINN_ALIGN(2) bfloat16 {
+  uint16_t x;
+
+#ifdef __cplusplus
+  // Constructors
+  bfloat16() = default;
+  bfloat16(const bfloat16& o) = default;
+  bfloat16& operator=(const bfloat16& o) = default;
+  bfloat16(bfloat16&& o) = default;
+  bfloat16& operator=(bfloat16&& o) = default;
+  ~bfloat16() = default;
+
+  __host__ __device__ inline explicit bfloat16(float val) {
+#if defined(CINN_CUDA_BF16)
+    __nv_bfloat16 tmp = __float2bfloat16(val);
+    x = *reinterpret_cast<uint16_t*>(&tmp);
+#else
+    std::memcpy(&x, reinterpret_cast<char*>(&val) + 2, 2);
+#endif
+  }
+
+#if defined(CINN_CUDA_BF16)
+  __host__ __device__ inline explicit bfloat16(const __nv_bfloat16& val) {
+    x = *reinterpret_cast<const uint16_t*>(&val);  // NOLINT
+  }
+#endif
+
+  template <typename T>
+  __host__ __device__ inline explicit bfloat16(const T& val)
+      : x(bfloat16(static_cast<float>(val)).x) {}
+
+// Assignment operators
+#if defined(CINN_CUDA_BF16)
+  __host__ __device__ inline bfloat16& operator=(const __nv_bfloat16& val) {
+    x = *reinterpret_cast<const uint16_t*>(&val);  // NOLINT
+    return *this;
+  }
+#endif
+
+  __host__ __device__ inline bfloat16& operator=(bool b) {
+    x = b ? 0x3f80 : 0;
+    return *this;
+  }
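// Annotation (not part of the patch): bfloat16 here is the usual "top half
// of an IEEE-754 float" encoding, which is why the non-CUDA constructor can
// simply memcpy the upper two bytes (assuming a little-endian host, as the
// byte-offset arithmetic implies). For example:
//
//   1.0f  = 0x3F800000  ->  bfloat16{.x = 0x3F80}   // the boolean 'true' above
//   -2.0f = 0xC0000000  ->  bfloat16{.x = 0xC000}
//
// The conversion truncates the low 16 mantissa bits, narrowing precision to
// roughly 3 decimal digits while keeping the full fp32 exponent range.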
+  __host__ __device__ inline bfloat16& operator=(int8_t val) {
+    x = bfloat16(val).x;
+    return *this;
+  }
+
+  __host__ __device__ inline bfloat16& operator=(uint8_t val) {
+    x = bfloat16(val).x;
+    return *this;
+  }
+
+  __host__ __device__ inline bfloat16& operator=(int16_t val) {
+    x = bfloat16(val).x;
+    return *this;
+  }
+
+  __host__ __device__ inline bfloat16& operator=(uint16_t val) {
+    x = bfloat16(val).x;
+    return *this;
+  }
+
+  __host__ __device__ inline bfloat16& operator=(int32_t val) {
+    x = bfloat16(val).x;
+    return *this;
+  }
+
+  __host__ __device__ inline bfloat16& operator=(uint32_t val) {
+    x = bfloat16(val).x;
+    return *this;
+  }
+
+  __host__ __device__ inline bfloat16& operator=(int64_t val) {
+    x = bfloat16(val).x;
+    return *this;
+  }
+
+  __host__ __device__ inline bfloat16& operator=(uint64_t val) {
+    x = bfloat16(val).x;
+    return *this;
+  }
+
+  __host__ __device__ inline bfloat16& operator=(float val) {
+    x = bfloat16(val).x;
+    return *this;
+  }
+
+  __host__ __device__ inline bfloat16& operator=(double val) {
+    x = bfloat16(val).x;
+    return *this;
+  }
+
+  // Conversion operators
+  __host__ __device__ inline operator float() const {
+#ifdef CINN_CUDA_BF16
+    return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&x));
+#else
+    float val = 0.f;
+    uint16_t temp = x;
+    std::memcpy(
+        reinterpret_cast<char*>(&val) + 2, reinterpret_cast<char*>(&temp), 2);
+    return val;
+#endif
+  }
+
+#ifdef CINN_CUDA_BF16
+  __host__ __device__ inline __nv_bfloat16 to_nv_bfloat16() const {
+    return *reinterpret_cast<const __nv_bfloat16*>(&x);
+  }
+#endif
+
+  __host__ __device__ inline explicit operator bool() const {
+    return (x & 0x7fff) != 0;
+  }
+
+  __host__ __device__ inline explicit operator int8_t() const {
+    return static_cast<int8_t>(static_cast<float>(*this));
+  }
+
+  __host__ __device__ inline explicit operator uint8_t() const {
+    return static_cast<uint8_t>(static_cast<float>(*this));
+  }
+
+  __host__ __device__ inline explicit operator int16_t() const {
+    return static_cast<int16_t>(static_cast<float>(*this));
+  }
+
+  __host__ __device__ inline explicit operator uint16_t() const {
+    return static_cast<uint16_t>(static_cast<float>(*this));
+  }
+
+  __host__ __device__ inline explicit operator int32_t() const {
+    return static_cast<int32_t>(static_cast<float>(*this));
+  }
+
+  __host__ __device__ inline explicit operator uint32_t() const {
+    return static_cast<uint32_t>(static_cast<float>(*this));
+  }
+
+  __host__ __device__ inline explicit operator int64_t() const {
+    return static_cast<int64_t>(static_cast<float>(*this));
+  }
+
+  __host__ __device__ inline explicit operator uint64_t() const {
+    return static_cast<uint64_t>(static_cast<float>(*this));
+  }
+
+  __host__ __device__ inline operator double() const {
+    return static_cast<double>(static_cast<float>(*this));
+  }
+#endif  // __cplusplus
+};
+
+struct CINN_ALIGN(16) bfloat168 {
+  bfloat16 x, y, z, w, v, u, t, s;
+};
+
+struct CINN_ALIGN(8) bfloat164 {
+  bfloat16 x, y, z, w;
+};
+
+struct CINN_ALIGN(4) bfloat162 {
+  bfloat16 x, y;
+};
+
+__host__ __device__ inline bfloat16 operator+(const bfloat16& a,
+                                              const bfloat16& b) {
+#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  return bfloat16(__hadd(a.to_nv_bfloat16(), b.to_nv_bfloat16()));
+#else
+  return bfloat16(static_cast<float>(a) + static_cast<float>(b));
+#endif
+}
+
+__host__ __device__ inline bfloat16 operator-(const bfloat16& a,
+                                              const bfloat16& b) {
+#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  return bfloat16(__hsub(a.to_nv_bfloat16(), b.to_nv_bfloat16()));
+#else
+  return bfloat16(static_cast<float>(a) - static_cast<float>(b));
+#endif
+}
+__host__ __device__ inline bfloat16 operator*(const bfloat16& a,
+                                              const bfloat16& b) {
+#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  return bfloat16(__hmul(a.to_nv_bfloat16(), b.to_nv_bfloat16()));
+#else
+  return bfloat16(static_cast<float>(a) * static_cast<float>(b));
+#endif
+}
+
+__host__ __device__ inline bfloat16 operator/(const bfloat16& a,
+                                              const bfloat16& b) {
+#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  return bfloat16(__hdiv(a.to_nv_bfloat16(), b.to_nv_bfloat16()));
+#else
+  return bfloat16(static_cast<float>(a) / static_cast<float>(b));
+#endif
+}
+
+__host__ __device__ inline bfloat16 operator-(const bfloat16& a) {
+  bfloat16 res;
+  res.x = a.x ^ 0x8000;
+  return res;
+}
+
+__host__ __device__ inline bfloat16& operator+=(bfloat16& a,  // NOLINT
+                                                const bfloat16& b) {
+  a = a + b;
+  return a;
+}
+
+__host__ __device__ inline bfloat16& operator-=(bfloat16& a,  // NOLINT
+                                                const bfloat16& b) {
+  a = a - b;
+  return a;
+}
+
+__host__ __device__ inline bfloat16& operator*=(bfloat16& a,  // NOLINT
+                                                const bfloat16& b) {
+  a = a * b;
+  return a;
+}
+
+__host__ __device__ inline bfloat16& operator/=(bfloat16& a,  // NOLINT
+                                                const bfloat16& b) {
+  a = a / b;
+  return a;
+}
+
+__host__ __device__ inline bfloat16 raw_uint16_to_bfloat16(uint16_t a) {
+  bfloat16 res;
+  res.x = a;
+  return res;
+}
+
+// Comparison operators
+__host__ __device__ inline bool operator==(const bfloat16& a,
+                                           const bfloat16& b) {
+#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  return __heq(a.to_nv_bfloat16(), b.to_nv_bfloat16());
+#else
+  return static_cast<float>(a) == static_cast<float>(b);
+#endif
+}
+
+__host__ __device__ inline bool operator!=(const bfloat16& a,
+                                           const bfloat16& b) {
+#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  return __hne(a.to_nv_bfloat16(), b.to_nv_bfloat16());
+#else
+  return static_cast<float>(a) != static_cast<float>(b);
+#endif
+}
+
+__host__ __device__ inline bool operator<(const bfloat16& a,
+                                          const bfloat16& b) {
+#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  return __hlt(a.to_nv_bfloat16(), b.to_nv_bfloat16());
+#else
+  return static_cast<float>(a) < static_cast<float>(b);
+#endif
+}
+
+__host__ __device__ inline bool operator<=(const bfloat16& a,
+                                           const bfloat16& b) {
+#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  return __hle(a.to_nv_bfloat16(), b.to_nv_bfloat16());
+#else
+  return static_cast<float>(a) <= static_cast<float>(b);
+#endif
+}
+
+__host__ __device__ inline bool operator>(const bfloat16& a,
+                                          const bfloat16& b) {
+#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  return __hgt(a.to_nv_bfloat16(), b.to_nv_bfloat16());
+#else
+  return static_cast<float>(a) > static_cast<float>(b);
+#endif
+}
+
+__host__ __device__ inline bool operator>=(const bfloat16& a,
+                                           const bfloat16& b) {
+#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  return __hge(a.to_nv_bfloat16(), b.to_nv_bfloat16());
+#else
+  return static_cast<float>(a) >= static_cast<float>(b);
+#endif
+}
+
+__host__ __device__ inline bool(isnan)(const bfloat16& a) {
+#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  return __hisnan(a.to_nv_bfloat16());
+#else
+  return (a.x & 0x7FFF) > 0x7F80;
+#endif
+}
+
+__host__ __device__ inline bool(isinf)(const bfloat16& a) {
+#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  return __hisinf(a.to_nv_bfloat16());
+#else
+  // Mask with 0x7FFF (not 0x7F80) so NaNs, whose mantissa is non-zero,
+  // are not misclassified as infinite.
+  return (a.x & 0x7FFF) == 0x7F80;
+#endif
+}
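// Annotation (not part of the patch): the fallback bit tests follow from the
// bfloat16 layout (1 sign bit, 8 exponent bits, 7 mantissa bits). Masking
// with 0x7FFF drops the sign; 0x7F80 is "exponent all ones, mantissa zero":
//
//   isnan: exponent all ones, mantissa non-zero  ->  (x & 0x7FFF) >  0x7F80
//   isinf: exponent all ones, mantissa zero      ->  (x & 0x7FFF) == 0x7F80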
+__host__ __device__ inline bool(isfinite)(const bfloat16& a) {
+  return !((isnan)(a)) && !((isinf)(a));
+}
+
+__host__ __device__ inline bfloat16(abs)(const bfloat16& a) {
+#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  return bfloat16(__habs(a.to_nv_bfloat16()));
+#else
+  return bfloat16(std::abs(static_cast<float>(a)));
+#endif
+}
+
+#ifdef __cplusplus
+}  // namespace common
+}  // namespace cinn
+#endif  // __cplusplus
+
+// for runtime calls
+#if defined(__cplusplus) && defined(CINN_CUDA_BF16)
+__device__ inline cinn::common::bfloat16 __shfl_sync(unsigned mask,
+                                                     cinn::common::bfloat16 var,
+                                                     int srcLane,
+                                                     int width = warpSize) {
+  return cinn::common::bfloat16(
+      __shfl_sync(mask, var.to_nv_bfloat16(), srcLane, width));
+}
+
+__device__ inline cinn::common::bfloat16 __shfl_up_sync(
+    unsigned mask,
+    cinn::common::bfloat16 var,
+    unsigned int delta,
+    int width = warpSize) {
+  return cinn::common::bfloat16(
+      __shfl_up_sync(mask, var.to_nv_bfloat16(), delta, width));
+}
+
+__device__ inline cinn::common::bfloat16 __shfl_down_sync(
+    unsigned mask,
+    cinn::common::bfloat16 var,
+    unsigned int delta,
+    int width = warpSize) {
+  return cinn::common::bfloat16(
+      __shfl_down_sync(mask, var.to_nv_bfloat16(), delta, width));
+}
+
+__device__ inline cinn::common::bfloat16 __shfl_xor_sync(
+    unsigned mask,
+    cinn::common::bfloat16 var,
+    int laneMask,
+    int width = warpSize) {
+  return cinn::common::bfloat16(
+      __shfl_xor_sync(mask, var.to_nv_bfloat16(), laneMask, width));
+}
+
+__host__ __device__ inline cinn::common::bfloat16 max(
+    const cinn::common::bfloat16& a, const cinn::common::bfloat16& b) {
+  return a > b ? a : b;
+}
+__host__ __device__ inline cinn::common::bfloat16 min(
+    const cinn::common::bfloat16& a, const cinn::common::bfloat16& b) {
+  return a < b ? a : b;
+}
+#endif  // __cplusplus && CINN_CUDA_BF16
+
+#endif  // CINN_COMMON_BFLOAT16_H
diff --git a/paddle/cinn/runtime/custom_device/custom_device_backend_api.cc b/paddle/cinn/runtime/custom_device/custom_device_backend_api.cc
new file mode 100644
index 00000000000000..5f5c146e74b020
--- /dev/null
+++ b/paddle/cinn/runtime/custom_device/custom_device_backend_api.cc
@@ -0,0 +1,408 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// paddle/cinn/runtime/custom_device/custom_device_backend_api.cc
+
+#include "paddle/cinn/runtime/custom_device/custom_device_backend_api.h"
+#include "glog/logging.h"
+#include "paddle/phi/backends/device_ext.h"
+#include "paddle/phi/backends/device_manager.h"
+
+#ifdef CINN_WITH_CUSTOM_DEVICE
+namespace cinn {
+namespace runtime {
+namespace custom_device {
+
+// ============================================================
+// Anonymous namespace: concrete default implementations,
+// not exposed outside this translation unit.
+// ============================================================
+namespace {
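// Annotation (not part of the patch): C_CinnInterface is the vendor-provided
// C struct consumed by the default implementations below; its authoritative
// definition lives in paddle/phi/backends/device_ext.h. Judging only from
// the calls made in this file, it is assumed to look roughly like:
//
//   typedef struct C_CinnInterface {
//     void* dev_ptr;  // opaque vendor context, passed back on every call
//     C_Status (*compile)(void* dev, const char* src, char* out, size_t n);
//     const char* (*get_runtime_source)(void* dev);
//     C_Status (*module_load)(void* dev, const char* path, void** module);
//     C_Status (*module_unload)(void* dev, void* module);
//     C_Status (*get_kernel_address)(void* dev, void* module,
//                                    const char* name, void** fn);
//     C_Status (*launch_kernel)(void* dev, void* fn, void** args, int num_args,
//                               int gx, int gy, int gz, int bx, int by, int bz,
//                               int shared_mem, void* stream);
//   } C_CinnInterface;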
+// 0. Concrete module implementation: the default CustomDeviceModule,
+//    wiring up module_unload and get_kernel_address.
+class DefaultCustomDeviceModule : public cinn::runtime::CustomModule {
+ public:
+  DefaultCustomDeviceModule(void* handle, C_CinnInterface* cif)
+      : handle_(handle), cif_(cif) {}
+
+  // RAII: module_unload is invoked automatically on destruction.
+  ~DefaultCustomDeviceModule() override {
+    if (handle_ && cif_ && cif_->module_unload) {
+      // Pass the vendor context dev_ptr together with the module handle.
+      cif_->module_unload(cif_->dev_ptr, handle_);
+    }
+  }
+
+  // Implements CustomModule::GetFunction.
+  void* GetFunction(const std::string& func_name) override {
+    if (handle_ && cif_ && cif_->get_kernel_address) {
+      void* func_ptr = nullptr;
+      // Look the symbol up through the C interface.
+      C_Status status = cif_->get_kernel_address(
+          cif_->dev_ptr, handle_, func_name.c_str(), &func_ptr);
+
+      if (status == C_SUCCESS) {
+        return func_ptr;
+      } else {
+        LOG(WARNING) << "Failed to get kernel address for: " << func_name;
+      }
+    }
+    return nullptr;
+  }
+
+ private:
+  void* handle_;          // module handle
+  C_CinnInterface* cif_;  // interface pointer, used for callbacks
+};
+
+// 1. Compiler-toolchain interface: drives the external compiler (e.g. mxcc).
+//    Default toolchain implementation.
+class DefaultCompilerToolchain : public CustomCompilerToolchain {
+ public:
+  explicit DefaultCompilerToolchain(C_CinnInterface* cif) : cif_(cif) {}
+
+  // 1. Compile: forward to the C interface.
+  std::string Compile(const std::string& code) override {
+    if (cif_ && cif_->compile) {
+      // TODO(Plugin): invoke compile according to the concrete C interface
+      // protocol.
+      // void* handle = nullptr;
+      char output_path[1024] = {0};
+      C_Status status = cif_->compile(
+          cif_->dev_ptr, code.c_str(), output_path, sizeof(output_path));
+      if (status == C_SUCCESS) {
+        VLOG(3) << "Calling Custom Device compile_kernel...";
+        return std::string(output_path);
+      }
+    }
+    LOG(ERROR) << "compile_kernel interface not implemented by vendor.";
+    return "";
+  }
+
+  // 2. GetRuntimeSource: forward to the C interface.
+  std::string GetRuntimeSource() override {
+    if (cif_ && cif_->get_runtime_source) {
+      // Fetch the vendor's built-in runtime source string.
+      const char* src = cif_->get_runtime_source(cif_->dev_ptr);
+      return src ? std::string(src) : "";
+    }
+    // If the vendor provides none, return empty (or, eventually, a minimal
+    // set of generic definitions).
+    return "";
+  }
+
+ private:
+  C_CinnInterface* cif_;
+};
+
+// 2. Runtime-strategy interface: loads modules and launches kernels.
+//    Default runtime-strategy implementation.
+class DefaultRuntimeStrategy : public CustomRuntimeStrategy {
+ public:
+  explicit DefaultRuntimeStrategy(C_CinnInterface* cif) : cif_(cif) {}
+
+  std::unique_ptr<CustomModule> LoadModule(const std::string& path) override {
+    if (cif_ && cif_->module_load) {
+      void* handle = nullptr;
+      C_Status status = cif_->module_load(cif_->dev_ptr, path.c_str(), &handle);
+
+      if (status == C_SUCCESS && handle != nullptr) {
+        // Create a DefaultCustomDeviceModule and hand over ownership.
+        return std::make_unique<DefaultCustomDeviceModule>(handle, cif_);
+      }
+    }
+    LOG(ERROR) << "Failed to load custom device module from path: " << path;
+    return nullptr;
+  }
+
+  void LaunchKernel(void* func_ptr,
+                    const std::string& func_name,
+                    void** args,
+                    int num_args,
+                    int grid_x,
+                    int grid_y,
+                    int grid_z,
+                    int block_x,
+                    int block_y,
+                    int block_z,
+                    int shared_mem,
+                    void* stream) override {
+    if (cif_ && cif_->launch_kernel) {
+      // Forward to the C interface.
+      cif_->launch_kernel(cif_->dev_ptr,
+                          func_ptr,
+                          args,
+                          num_args,
+                          grid_x,
+                          grid_y,
+                          grid_z,
+                          block_x,
+                          block_y,
+                          block_z,
+                          shared_mem,
+                          stream);
+      return;
+    }
+    LOG(ERROR) << "launch_kernel interface not implemented by vendor.";
+  }
+
+ private:
+  C_CinnInterface* cif_;
+};
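// Annotation (not part of the patch): a minimal sketch of how these strategy
// objects are expected to be driven end to end, assuming a vendor plugin is
// registered for a hypothetical device type "my_dev":
//
//   auto& plugin =
//       CinnCustomDevicePlugin::GetInstance(phi::CustomPlace("my_dev", 0));
//   std::string bin = plugin.GetToolchain()->Compile(source_code);  // -> path
//   auto module = plugin.GetRuntime()->LoadModule(bin);
//   void* fn = module->GetFunction("generated_kernel");
//   plugin.GetRuntime()->LaunchKernel(fn, "generated_kernel", args, num_args,
//                                     grid_x, 1, 1, block_x, 1, 1,
//                                     /*shared_mem=*/0, stream);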
+// 3. Compile-optimization interface: vendor-defined fusion/schedule/passes.
+//    Default compile strategy.
+class DefaultCompileStrategy : public CustomCompileStrategy {
+  // Currently relies on the base-class defaults (which return false).
+};
+
+}  // namespace
+
+// ============================================================
+// CinnCustomDevicePlugin implementation
+// ============================================================
+
+// 1. InitWrappers: wrap the C interface into C++ strategy objects.
+void CinnCustomDevicePlugin::InitWrappers(C_CinnInterface* cif) {
+  // Use the Default* implementations defined above.
+  toolchain_ = std::make_unique<DefaultCompilerToolchain>(cif);
+  runtime_strategy_ = std::make_unique<DefaultRuntimeStrategy>(cif);
+  compile_strategy_ = std::make_unique<DefaultCompileStrategy>();
+}
+
+// 2. GetInstance.
+CinnCustomDevicePlugin& CinnCustomDevicePlugin::GetInstance(
+    const phi::Place& place) {
+  static std::unordered_map<std::string,
+                            std::unique_ptr<CinnCustomDevicePlugin>>
+      instances;
+  std::string device_type = place.GetDeviceType();
+
+  if (instances.find(device_type) == instances.end()) {
+    // A. Get the base device pointer.
+    auto* device_base = phi::DeviceManager::GetDeviceWithPlace(place);
+    PADDLE_ENFORCE_NOT_NULL(
+        device_base,
+        phi::errors::NotFound("Device for %s not found.", place.DebugString()));
+
+    // B. Cast to CustomDevice and fetch the CINN-specific C interface.
+    C_CinnInterface* cif = device_base->GetCinnInterface();
+
+    // C. Check that the interface exists.
+    if (cif == nullptr) {
+      LOG(FATAL) << "Custom Device [" << device_type
+                 << "] does not support CINN (C_CinnInterface is null).";
+    }
+
+    // D. Create and initialize the plugin.
+    auto plugin_ptr =
+        std::make_unique<CinnCustomDevicePlugin>();  // default-constructed
+    plugin_ptr->InitWrappers(cif);
+
+    instances[device_type] = std::move(plugin_ptr);
+  }
+
+  return *instances[device_type];
+}
+
+// ============================================================
+// CustomBackendAPI Implementation
+// ============================================================
+
+CustomBackendAPI* CustomBackendAPI::Global() {
+  static CustomBackendAPI instance;
+  return &instance;
+}
+
+void CustomBackendAPI::set_device(int device_id) {
+  auto dev_types = phi::DeviceManager::GetAllCustomDeviceTypes();
+  if (dev_types.empty()) {
+    LOG(WARNING) << "No custom device types found when calling set_device.";
+    return;
+  }
+  // Set the device for the first available custom device type.
+  // Note: CINN usually assumes one active backend type at a time.
+  phi::DeviceManager::SetDevice(dev_types[0], static_cast<size_t>(device_id));
+}
+
+int CustomBackendAPI::get_device() {
+  auto dev_types = phi::DeviceManager::GetAllCustomDeviceTypes();
+  if (dev_types.empty()) return 0;
+
+  // Return the device ID of the current active device for this type.
+  return phi::DeviceManager::GetDevice(dev_types[0]);
+}
+
+int CustomBackendAPI::get_device_property(DeviceProperty device_property,
+                                          std::optional<int> device_id) {
+  auto dev_types = phi::DeviceManager::GetAllCustomDeviceTypes();
+  if (dev_types.empty()) return 0;
+
+  // Use the current device ID if none is provided.
+  size_t id = device_id.has_value() ? static_cast<size_t>(device_id.value())
+                                    : static_cast<size_t>(get_device());
+  std::string dev_type = dev_types[0];
+  phi::Place place = phi::CustomPlace(dev_type, id);
+
+  switch (device_property) {
+    case DeviceProperty::MaxSharedMemoryPerBlock:
+      return phi::DeviceManager::GetMaxSharedMemPerBlock(place);
+    case DeviceProperty::MaxThreadsPerBlock:
+      return phi::DeviceManager::GetMaxThreadsPerBlock(place);
+    case DeviceProperty::MaxThreadsPerSM:
+      return phi::DeviceManager::GetMaxThreadsPerMultiProcessor(place);
+    case DeviceProperty::MultiProcessorCount:
+      return phi::DeviceManager::GetMultiProcessors(place);
+    case DeviceProperty::MaxBlocksPerSM:
+      return phi::DeviceManager::GetMaxBlocksPerMultiProcessor(place);
+    case DeviceProperty::MaxGridDimX:
+      return phi::DeviceManager::GetMaxGridDimSize(place)[0];
+    case DeviceProperty::MaxGridDimY:
+      return phi::DeviceManager::GetMaxGridDimSize(place)[1];
+    case DeviceProperty::MaxGridDimZ:
+      return phi::DeviceManager::GetMaxGridDimSize(place)[2];
+    case DeviceProperty::MaxBlockDimX:
+      return phi::DeviceManager::GetMaxBlockDimSize(place)[0];
+    case DeviceProperty::MaxBlockDimY:
+      return phi::DeviceManager::GetMaxBlockDimSize(place)[1];
+    case DeviceProperty::MaxBlockDimZ:
+      return phi::DeviceManager::GetMaxBlockDimSize(place)[2];
+    default:
+      LOG(WARNING) << "Not supported device property: "
+                   << static_cast<int>(device_property);
+      return 0;
+  }
+}
+
+void* CustomBackendAPI::malloc(size_t numBytes) {
+  auto dev_types = phi::DeviceManager::GetAllCustomDeviceTypes();
+  if (dev_types.empty()) return nullptr;
+
+  int device_id = get_device();
+  auto place = phi::CustomPlace(dev_types[0], device_id);
+
+  // Use DeviceManager::GetDeviceWithPlace to access memory allocation.
+  return phi::DeviceManager::GetDeviceWithPlace(place)->MemoryAllocate(
+      numBytes);
+}
+
+void CustomBackendAPI::free(void* data) {
+  auto dev_types = phi::DeviceManager::GetAllCustomDeviceTypes();
+  if (dev_types.empty()) return;
+
+  int device_id = get_device();
+  auto place = phi::CustomPlace(dev_types[0], device_id);
+
+  // Note: the standard Device interface needs the allocation size to
+  // deallocate, but BackendAPI::free only provides the pointer. Possible
+  // workarounds: a CINN-side allocator that tracks sizes, or relying on the
+  // specific device implementation ignoring the size. For now we pass 0 and
+  // assume the underlying implementation tolerates it, until CINN fixes this
+  // API.
+  phi::DeviceManager::GetDeviceWithPlace(place)->MemoryDeallocate(data, 0);
+}
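// Annotation (not part of the patch): one of the workarounds mentioned in
// the comment above, sketched under the assumption that every device
// allocation goes through this API. A side table remembers each
// allocation's size so free() can report it; names here are illustrative:
//
//   namespace {
//   std::mutex alloc_mu;
//   std::unordered_map<void*, size_t> alloc_sizes;
//   }  // namespace
//
//   void* CustomBackendAPI::malloc(size_t n) {
//     void* p = device->MemoryAllocate(n);  // `device` as obtained above
//     std::lock_guard<std::mutex> lk(alloc_mu);
//     alloc_sizes[p] = n;
//     return p;
//   }
//
//   void CustomBackendAPI::free(void* p) {
//     size_t n = 0;
//     {
//       std::lock_guard<std::mutex> lk(alloc_mu);
//       auto it = alloc_sizes.find(p);
//       if (it != alloc_sizes.end()) { n = it->second; alloc_sizes.erase(it); }
//     }
//     device->MemoryDeallocate(p, n);
//   }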
+
+void CustomBackendAPI::memset(void* data, int value, size_t numBytes) {
+  auto dev_types = phi::DeviceManager::GetAllCustomDeviceTypes();
+  if (dev_types.empty()) return;
+
+  int device_id = get_device();
+  auto place = phi::CustomPlace(dev_types[0], device_id);
+
+  // Device::MemorySet takes a uint8_t value.
+  phi::DeviceManager::GetDeviceWithPlace(place)->MemorySet(
+      data, static_cast<uint8_t>(value), numBytes);
+}
+
+void CustomBackendAPI::memcpy(void* dest,
+                              const void* src,
+                              size_t numBytes,
+                              MemcpyType type) {
+  auto dev_types = phi::DeviceManager::GetAllCustomDeviceTypes();
+  if (dev_types.empty()) return;
+
+  int device_id = get_device();
+  auto place = phi::CustomPlace(dev_types[0], device_id);
+  auto* device = phi::DeviceManager::GetDeviceWithPlace(place);
+
+  // Map CINN MemcpyType to Phi Device methods.
+  switch (type) {
+    case MemcpyType::HostToDevice:
+      device->MemoryCopyH2D(dest, src, numBytes, nullptr);
+      break;
+    case MemcpyType::DeviceToHost:
+      device->MemoryCopyD2H(dest, src, numBytes, nullptr);
+      break;
+    case MemcpyType::DeviceToDevice:
+      device->MemoryCopyD2D(dest, src, numBytes, nullptr);
+      break;
+  }
+}
+
+void CustomBackendAPI::device_sync() {
+  auto dev_types = phi::DeviceManager::GetAllCustomDeviceTypes();
+  if (dev_types.empty()) return;
+
+  int device_id = get_device();
+  auto place = phi::CustomPlace(dev_types[0], device_id);
+
+  phi::DeviceManager::SynchronizeDevice(place);
+}
+
+void CustomBackendAPI::stream_sync(void* stream) {
+  auto dev_types = phi::DeviceManager::GetAllCustomDeviceTypes();
+  if (dev_types.empty()) return;
+
+  int device_id = get_device();
+  auto place = phi::CustomPlace(dev_types[0], device_id);
+
+  if (stream) {
+    // Convert void* to phi::stream::stream_t (which is void*) and sync.
+    phi::DeviceManager::GetDeviceWithPlace(place)->SynchronizeStream(
+        static_cast<phi::stream::stream_t>(stream));
+  }
+}
+
+std::array<int, 3> CustomBackendAPI::get_max_grid_dims(
+    std::optional<int> device_id) {
+  auto dev_types = phi::DeviceManager::GetAllCustomDeviceTypes();
+  if (dev_types.empty()) return {0, 0, 0};
+
+  size_t id = device_id.has_value() ? static_cast<size_t>(device_id.value())
+                                    : static_cast<size_t>(get_device());
+  auto place = phi::CustomPlace(dev_types[0], id);
+
+  auto dims = phi::DeviceManager::GetMaxGridDimSize(place);
+  return {static_cast<int>(dims[0]),
+          static_cast<int>(dims[1]),
+          static_cast<int>(dims[2])};
+}
+
+std::array<int, 3> CustomBackendAPI::get_max_block_dims(
+    std::optional<int> device_id) {
+  auto dev_types = phi::DeviceManager::GetAllCustomDeviceTypes();
+  if (dev_types.empty()) return {0, 0, 0};
+
+  size_t id = device_id.has_value() ? static_cast<size_t>(device_id.value())
+                                    : static_cast<size_t>(get_device());
+  auto place = phi::CustomPlace(dev_types[0], id);
+
+  auto dims = phi::DeviceManager::GetMaxBlockDimSize(place);
+  return {static_cast<int>(dims[0]),
+          static_cast<int>(dims[1]),
+          static_cast<int>(dims[2])};
+}
+
+}  // namespace custom_device
+}  // namespace runtime
+}  // namespace cinn
+#endif  // CINN_WITH_CUSTOM_DEVICE
diff --git a/paddle/cinn/runtime/custom_device/custom_device_backend_api.h b/paddle/cinn/runtime/custom_device/custom_device_backend_api.h
new file mode 100644
index 00000000000000..d565c141556dbe
--- /dev/null
+++ b/paddle/cinn/runtime/custom_device/custom_device_backend_api.h
@@ -0,0 +1,157 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <array>
+#include <memory>
+#include <optional>
+#include <string>
+#include <unordered_map>
+
+#include "paddle/cinn/runtime/backend_api.h"
+#include "paddle/phi/backends/device_ext.h"
+#include "paddle/phi/common/place.h"
+
+#ifdef CINN_WITH_CUSTOM_DEVICE
+namespace cinn {
+namespace runtime {
+
+class CustomModule {
+ public:
+  virtual ~CustomModule() = default;
+
+  virtual void* GetFunction(const std::string& func_name) = 0;
+};
+
+namespace custom_device {
+
+// ============================================================
+// Part 1: abstract interfaces for CINN compilation and runtime strategies
+// ============================================================
+
+// 1. Compiler-toolchain interface: drives an external compiler (e.g. mxcc).
+class CustomCompilerToolchain {
+ public:
+  virtual ~CustomCompilerToolchain() = default;
+  virtual std::string Compile(const std::string& code) = 0;
+  // New: fetch the vendor's base device-library source (formerly the
+  // content of cinn_custom_device_runtime_source.h).
+  virtual std::string GetRuntimeSource() = 0;
+};
+
+// 2. Runtime-strategy interface: loads modules and launches kernels.
+class CustomRuntimeStrategy {
+ public:
+  virtual ~CustomRuntimeStrategy() = default;
+  virtual std::unique_ptr<CustomModule> LoadModule(
+      const std::string& path) = 0;
+  virtual void LaunchKernel(void* func_ptr,
+                            const std::string& func_name,
+                            void** args,
+                            int num_args,
+                            int grid_x,
+                            int grid_y,
+                            int grid_z,
+                            int block_x,
+                            int block_y,
+                            int block_z,
+                            int shared_mem,
+                            void* stream) = 0;
+};
+
+// 3. Compile-optimization interface: vendor-defined fusion/schedule/passes.
+class CustomCompileStrategy {
+ public:
+  virtual ~CustomCompileStrategy() = default;
+  virtual bool ApplyCustomPass(void* ir_module) { return false; }
+  // Interfaces such as GetHeaderSource could be added here to expose
+  // hardware-specific header content.
+};
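// Annotation (not part of the patch): what a vendor is expected to plug in,
// as a minimal sketch against the interfaces declared above. The names
// (MyVendorToolchain, my_cc, the returned paths) are illustrative:
//
//   class MyVendorToolchain
//       : public cinn::runtime::custom_device::CustomCompilerToolchain {
//    public:
//     std::string Compile(const std::string& code) override {
//       // Write `code` to a temp file, shell out to the vendor compiler
//       // (e.g. `my_cc -o /tmp/kernel.bin /tmp/kernel.cc`), and return the
//       // path of the produced device binary.
//       return "/tmp/kernel.bin";
//     }
//     std::string GetRuntimeSource() override {
//       return "/* device-side runtime definitions, prepended to kernels */";
//     }
//   };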
+// ============================================================
+// Part 2: plugin management (one singleton per device type)
+// ============================================================
+// 4. Top-level plugin manager.
+class PADDLE_API CinnCustomDevicePlugin {
+ public:
+  // Instances are created internally; always access them via GetInstance.
+  CinnCustomDevicePlugin() = default;
+  ~CinnCustomDevicePlugin() = default;
+
+  // Get the singleton plugin instance for the given Place.
+  static CinnCustomDevicePlugin& GetInstance(const phi::Place& place);
+
+  // Wrapper accessors exposed to the Compiler/Codegen.
+  CustomCompilerToolchain* GetToolchain() { return toolchain_.get(); }
+  CustomRuntimeStrategy* GetRuntime() { return runtime_strategy_.get(); }
+  CustomCompileStrategy* GetCompileStrategy() {
+    return compile_strategy_.get();
+  }
+
+  // Internal initialization, invoked by GetInstance in the .cc file.
+  void InitWrappers(C_CinnInterface* cif);
+
+ private:
+  // Concrete wrapper instances.
+  std::unique_ptr<CustomCompilerToolchain> toolchain_;
+  std::unique_ptr<CustomRuntimeStrategy> runtime_strategy_;
+  std::unique_ptr<CustomCompileStrategy> compile_strategy_;
+
+  // Non-copyable.
+  CinnCustomDevicePlugin(const CinnCustomDevicePlugin&) = delete;
+  CinnCustomDevicePlugin& operator=(const CinnCustomDevicePlugin&) = delete;
+};
+
+// ============================================================
+// Part 3: the BackendAPI implementation (core runtime interface)
+// ============================================================
+class CustomBackendAPI final : public BackendAPI {
+ public:
+  CustomBackendAPI() = default;
+  ~CustomBackendAPI() = default;
+
+  // Global access point.
+  static CustomBackendAPI* Global();
+
+  // --- Required overrides from BackendAPI ---
+  void set_device(int device_id) override;
+  int get_device() override;
+
+  // Memory management.
+  void* malloc(size_t numBytes) override;
+  void free(void* data) override;
+  void memset(void* data, int value, size_t numBytes) override;
+  void memcpy(void* dest,
+              const void* src,
+              size_t numBytes,
+              MemcpyType type) override;
+
+  // Synchronization.
+  void device_sync() override;
+  void stream_sync(void* stream) override;
+
+  // Property queries (usually also available on Target, but the runtime
+  // sometimes needs to ask directly).
+  int get_device_property(DeviceProperty device_property,
+                          std::optional<int> device_id = std::nullopt) override;
+
+  std::array<int, 3> get_max_grid_dims(
+      std::optional<int> device_id = std::nullopt) override;
+  std::array<int, 3> get_max_block_dims(
+      std::optional<int> device_id = std::nullopt) override;
+};
+
+}  // namespace custom_device
+}  // namespace runtime
+}  // namespace cinn
+#endif  // CINN_WITH_CUSTOM_DEVICE
diff --git a/paddle/cinn/runtime/custom_device/custom_device_intrinsics.cc b/paddle/cinn/runtime/custom_device/custom_device_intrinsics.cc
new file mode 100644
index 00000000000000..a6368a4cf5942b
--- /dev/null
+++ b/paddle/cinn/runtime/custom_device/custom_device_intrinsics.cc
@@ -0,0 +1,454 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
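// Annotation (not part of the patch): why the registrations below exist.
// ir_operators.cc (changed earlier in this diff) lowers an op such as
// bitwise_and on a custom-device target via
// hlir::GetExternFuncName(target, type, "bitwise_and") and emits an extern
// call; that name is expected to resolve to the form
// cinn_custom_device_<op>_<type_suffix> (e.g.
// cinn_custom_device_bitwise_and_int32), matching the functions registered
// by the macros in this file, so the JIT can link them.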
+
+#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
+using cinn::backends::GlobalSymbolRegistry;
+#include "paddle/cinn/runtime/custom_device/custom_device_backend_api.h"
+using cinn::runtime::custom_device::CustomBackendAPI;
+#include "paddle/cinn/backends/extern_func_jit_register.h"
+#include "paddle/cinn/runtime/custom_device/custom_device_util.h"
+
+using cinn_buffer_ptr_t = cinn_buffer_t *;
+using cinn_int_ptr_t = int *;
+
+CINN_REGISTER_HELPER(cinn_custom_device_host_api) {
+  GlobalSymbolRegistry::Global().RegisterFn(
+      "backend_api.custom_device",
+      reinterpret_cast<void *>(CustomBackendAPI::Global()));  // TODO(xuyuhan)
+
+  using cinn::runtime::custom_device::cinn_call_custom_device_kernel;
+  REGISTER_EXTERN_FUNC_HELPER(cinn_call_custom_device_kernel,
+                              cinn::common::DefaultHostTarget())
+      .template SetRetType<void>()
+      .template AddInputType<void *>()                        // kernel_fn
+      .template AddInputType<void *>()                        // args
+      .template AddInputType(cinn::common::type_of<int>())    // num_args
+      .template AddInputType(cinn::common::type_of<int>())    // grid_x
+      .template AddInputType(cinn::common::type_of<int>())    // grid_y
+      .template AddInputType(cinn::common::type_of<int>())    // grid_z
+      .template AddInputType(cinn::common::type_of<int>())    // block_x
+      .template AddInputType(cinn::common::type_of<int>())    // block_y
+      .template AddInputType(cinn::common::type_of<int>())    // block_z
+      .template AddInputType(
+          cinn::common::type_of<int>())  // shared_memory_bytes
+      .template AddInputType<void *>()   // stream
+      .End();
+  using cinn::runtime::custom_device::infer_shape_set_value;
+
+  REGISTER_EXTERN_FUNC_HELPER(infer_shape_set_value,
+                              cinn::common::DefaultHostTarget())
+      .template SetRetType<void>()
+      .template AddInputType(cinn::common::type_of<int>())
+      .template AddInputType(cinn::common::type_of<int>())
+      .template AddInputType(cinn::common::type_of<int64_t>())
+      .template AddInputType(cinn::common::type_of<int64_t **>())
+      .End();
+  return true;
+}
+
+CINN_REGISTER_HELPER(custom_device_intrinsics) {
+  auto target = cinn::common::DefaultCustomDeviceTarget();
+
+// bool for 1 input 1 output
+#define REGISTER_EXTERN_FUNC_1_IN_1_OUT_BOOL(func__) \
+  REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT(            \
+      cinn_custom_device_##func__##_bool, target, bool, bool)
+
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_BOOL(bitwise_not);
+
+#undef REGISTER_EXTERN_FUNC_1_IN_1_OUT_BOOL
+
+// bool for 2 input 1 output
+#define REGISTER_EXTERN_FUNC_2_IN_1_OUT_BOOL(func__) \
+  REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT(            \
+      cinn_custom_device_##func__##_bool, target, bool, bool, bool)
+
+  REGISTER_EXTERN_FUNC_2_IN_1_OUT_BOOL(bitwise_and);
+  REGISTER_EXTERN_FUNC_2_IN_1_OUT_BOOL(bitwise_or);
+  REGISTER_EXTERN_FUNC_2_IN_1_OUT_BOOL(bitwise_xor);
+
+#undef REGISTER_EXTERN_FUNC_2_IN_1_OUT_BOOL
+
+// uint8 for 1 input 1 output
+#define REGISTER_EXTERN_FUNC_1_IN_1_OUT_UINT8(func__) \
+  REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT(             \
+      cinn_custom_device_##func__##_uint8, target, uint8_t, uint8_t)
+
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_UINT8(bitwise_not);
+
+#undef REGISTER_EXTERN_FUNC_1_IN_1_OUT_UINT8
+
+// uint8 for 2 input 1 output
+#define REGISTER_EXTERN_FUNC_2_IN_1_OUT_UINT8(func__) \
+  REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT(             \
+      cinn_custom_device_##func__##_uint8, target, uint8_t, uint8_t, uint8_t);
+
+  REGISTER_EXTERN_FUNC_2_IN_1_OUT_UINT8(bitwise_and);
+  REGISTER_EXTERN_FUNC_2_IN_1_OUT_UINT8(bitwise_or);
+  REGISTER_EXTERN_FUNC_2_IN_1_OUT_UINT8(bitwise_xor);
+  REGISTER_EXTERN_FUNC_2_IN_1_OUT_UINT8(logical_right_shift);
+
+#undef REGISTER_EXTERN_FUNC_2_IN_1_OUT_UINT8
+
+// int8 for 1 input 1 output
+#define REGISTER_EXTERN_FUNC_1_IN_1_OUT_INT8(func__) \
+  REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT(            \
+      cinn_custom_device_##func__##_int8, target, int8_t, int8_t)
+
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_INT8(bitwise_not);
+
+#undef REGISTER_EXTERN_FUNC_1_IN_1_OUT_INT8
+
+// int8 for 2 input 1 output
+#define REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT8(func__) \
+  REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT(            \
+      cinn_custom_device_##func__##_int8, target, int8_t, int8_t, int8_t);
+
+  REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT8(bitwise_and);
+  REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT8(bitwise_or);
+  REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT8(bitwise_xor);
+  REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT8(logical_right_shift);
+
+#undef REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT8
+
+// int16 for 1 input 1 output
+#define REGISTER_EXTERN_FUNC_1_IN_1_OUT_INT16(func__) \
+  REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT(             \
+      cinn_custom_device_##func__##_int16, target, int16_t, int16_t)
+
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_INT16(bitwise_not);
+
+#undef REGISTER_EXTERN_FUNC_1_IN_1_OUT_INT16
+
+// int16 for 2 input 1 output
+#define REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT16(func__) \
+  REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT(             \
+      cinn_custom_device_##func__##_int16, target, int16_t, int16_t, int16_t);
+
+  REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT16(bitwise_and);
+  REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT16(bitwise_or);
+  REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT16(bitwise_xor);
+  REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT16(logical_right_shift);
+
+#undef REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT16
+
+// float
+#define REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(func__) \
+  REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT(             \
+      cinn_custom_device_##func__##_fp32, target, float, float);
+
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(abs);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(exp);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(erf);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(sqrt);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(rsqrt);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(log);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(log2);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(log10);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(floor);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(ceil);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(round);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(trunc);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(cos);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(cosh);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(tan);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(sin);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(sinh);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(acos);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(acosh);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(asin);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(asinh);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(atan);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(atanh);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(tanh);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(cbrt);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(sigmoid);
+
+#undef REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT
+
+#define REGISTER_EXTERN_FUNC_1_IN_FLOAT_1_OUT_BOOL(func__) \
+  REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT(                  \
+      cinn_custom_device_##func__##_fp32, target, float, bool);
+
+  REGISTER_EXTERN_FUNC_1_IN_FLOAT_1_OUT_BOOL(isnan);
+  REGISTER_EXTERN_FUNC_1_IN_FLOAT_1_OUT_BOOL(isfinite);
+  REGISTER_EXTERN_FUNC_1_IN_FLOAT_1_OUT_BOOL(isinf);
+
+#undef REGISTER_EXTERN_FUNC_1_IN_FLOAT_1_OUT_BOOL
+
+#define REGISTER_EXTERN_FUNC_2_IN_1_FLOAT(func__) \
+  REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT(         \
+      cinn_custom_device_##func__##_fp32, target, float, float, float);
+
+  REGISTER_EXTERN_FUNC_2_IN_1_FLOAT(pow)
REGISTER_EXTERN_FUNC_2_IN_1_FLOAT(mod) + +#undef REGISTER_EXTERN_FUNC_2_IN_1_FLOAT + + // double + +#define REGISTER_EXTERN_FUNC_1_IN_1_FP64(func__) \ + REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT( \ + cinn_custom_device_##func__##_fp64, target, double, double); + + REGISTER_EXTERN_FUNC_1_IN_1_FP64(abs); + REGISTER_EXTERN_FUNC_1_IN_1_FP64(exp); + REGISTER_EXTERN_FUNC_1_IN_1_FP64(erf); + REGISTER_EXTERN_FUNC_1_IN_1_FP64(sqrt); + REGISTER_EXTERN_FUNC_1_IN_1_FP64(rsqrt); + REGISTER_EXTERN_FUNC_1_IN_1_FP64(log); + REGISTER_EXTERN_FUNC_1_IN_1_FP64(log2); + REGISTER_EXTERN_FUNC_1_IN_1_FP64(log10); + REGISTER_EXTERN_FUNC_1_IN_1_FP64(floor); + REGISTER_EXTERN_FUNC_1_IN_1_FP64(ceil); + REGISTER_EXTERN_FUNC_1_IN_1_FP64(round); + REGISTER_EXTERN_FUNC_1_IN_1_FP64(trunc); + REGISTER_EXTERN_FUNC_1_IN_1_FP64(cos); + REGISTER_EXTERN_FUNC_1_IN_1_FP64(cosh); + REGISTER_EXTERN_FUNC_1_IN_1_FP64(tan); + REGISTER_EXTERN_FUNC_1_IN_1_FP64(sin); + REGISTER_EXTERN_FUNC_1_IN_1_FP64(sinh); + REGISTER_EXTERN_FUNC_1_IN_1_FP64(acos); + REGISTER_EXTERN_FUNC_1_IN_1_FP64(acosh); + REGISTER_EXTERN_FUNC_1_IN_1_FP64(asin); + REGISTER_EXTERN_FUNC_1_IN_1_FP64(asinh); + REGISTER_EXTERN_FUNC_1_IN_1_FP64(atan); + REGISTER_EXTERN_FUNC_1_IN_1_FP64(atanh); + REGISTER_EXTERN_FUNC_1_IN_1_FP64(tanh); + REGISTER_EXTERN_FUNC_1_IN_1_FP64(cbrt); + REGISTER_EXTERN_FUNC_1_IN_1_FP64(sigmoid); + +#undef REGISTER_EXTERN_FUNC_1_IN_1_FP64 + +#define REGISTER_EXTERN_FUNC_1_IN_FP64_1_OUT_BOOL(func__) \ + REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT( \ + cinn_custom_device_##func__##_fp64, target, double, bool); + + REGISTER_EXTERN_FUNC_1_IN_FP64_1_OUT_BOOL(isnan); + REGISTER_EXTERN_FUNC_1_IN_FP64_1_OUT_BOOL(isfinite); + REGISTER_EXTERN_FUNC_1_IN_FP64_1_OUT_BOOL(isinf); + +#undef REGISTER_EXTERN_FUNC_1_IN_FP64_1_OUT_BOOL + +#define REGISTER_EXTERN_FUNC_2_IN_1_FP64(func__) \ + REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT( \ + cinn_custom_device_##func__##_fp64, target, double, double, double); + + REGISTER_EXTERN_FUNC_2_IN_1_FP64(pow) + REGISTER_EXTERN_FUNC_2_IN_1_FP64(mod) + +#undef REGISTER_EXTERN_FUNC_2_IN_1_FP64 + + // int32 + +#define REGISTER_EXTERN_FUNC_1_IN_1_INT32(func__) \ + REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT( \ + cinn_custom_device_##func__##_int32, target, int, int); + + REGISTER_EXTERN_FUNC_1_IN_1_INT32(bitwise_not) + REGISTER_EXTERN_FUNC_1_IN_1_INT32(clz) + REGISTER_EXTERN_FUNC_1_IN_1_INT32(popc) + REGISTER_EXTERN_FUNC_1_IN_1_INT32(trunc) + +#undef REGISTER_EXTERN_FUNC_1_IN_1_INT32 + +#define REGISTER_EXTERN_FUNC_1_IN_1_INT64(func__) \ + REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT( \ + cinn_custom_device_##func__##_int64, target, int64_t, int64_t); + + REGISTER_EXTERN_FUNC_1_IN_1_INT64(bitwise_not) + REGISTER_EXTERN_FUNC_1_IN_1_INT64(clz) + REGISTER_EXTERN_FUNC_1_IN_1_INT64(popc) + REGISTER_EXTERN_FUNC_1_IN_1_INT64(trunc) + +#undef REGISTER_EXTERN_FUNC_1_IN_1_INT64 + +#define REGISTER_EXTERN_FUNC_2_IN_1_INT32(func__) \ + REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT( \ + cinn_custom_device_##func__##_int32, target, int, int, int); + + REGISTER_EXTERN_FUNC_2_IN_1_INT32(pow) + REGISTER_EXTERN_FUNC_2_IN_1_INT32(left_shift) + REGISTER_EXTERN_FUNC_2_IN_1_INT32(right_shift) + REGISTER_EXTERN_FUNC_2_IN_1_INT32(bitwise_and) + REGISTER_EXTERN_FUNC_2_IN_1_INT32(bitwise_or) + REGISTER_EXTERN_FUNC_2_IN_1_INT32(bitwise_xor) + REGISTER_EXTERN_FUNC_2_IN_1_INT32(logical_right_shift) + REGISTER_EXTERN_FUNC_2_IN_1_INT32(mod) + +#undef REGISTER_EXTERN_FUNC_2_IN_1_INT32 + +#define REGISTER_EXTERN_FUNC_2_IN_1_INT64(func__) \ + REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT( \ + 
cinn_custom_device_##func__##_int64, target, int64_t, int64_t, int64_t); + + REGISTER_EXTERN_FUNC_2_IN_1_INT64(pow) + REGISTER_EXTERN_FUNC_2_IN_1_INT64(bitwise_and) + REGISTER_EXTERN_FUNC_2_IN_1_INT64(bitwise_or) + REGISTER_EXTERN_FUNC_2_IN_1_INT64(bitwise_xor) + REGISTER_EXTERN_FUNC_2_IN_1_INT64(mod) + REGISTER_EXTERN_FUNC_2_IN_1_INT64(logical_right_shift) + +#undef REGISTER_EXTERN_FUNC_2_IN_1_INT64 + + REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_custom_device_find_int, target) + .template SetRetType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .End(); + + REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_custom_device_find_float, target) + .template SetRetType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .End(); + + REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_custom_device_find_int_nd, target) + .template SetRetType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .End(); + + REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_custom_device_find_float_nd, target) + .template SetRetType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .End(); + + REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_custom_device_find_int_from, target) + .template SetRetType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .End(); + + REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_custom_device_find_float_from, target) + .template SetRetType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .End(); + + REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_custom_device_next_smallest_int32, + target) + .template SetRetType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .End(); + +#define _REGISTER_CINN_NVGPU_LT_NUM(TYPE_SUFFIX, TYPE) \ + REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_custom_device_lt_num_##TYPE_SUFFIX, \ + target) \ + .template SetRetType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .End(); + + _REGISTER_CINN_NVGPU_LT_NUM(fp32, float); + _REGISTER_CINN_NVGPU_LT_NUM(fp64, double); + _REGISTER_CINN_NVGPU_LT_NUM(uint8, uint8_t); + _REGISTER_CINN_NVGPU_LT_NUM(int16, int16_t); + + _REGISTER_CINN_NVGPU_LT_NUM(int32, int); + 
_REGISTER_CINN_NVGPU_LT_NUM(int64, int64_t); + +#undef _REGISTER_CINN_NVGPU_LT_NUM + +#define _REGISTER_CINN_NVGPU_GT_NUM(TYPE_SUFFIX, TYPE) \ + REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_custom_device_gt_num_##TYPE_SUFFIX, \ + target) \ + .template SetRetType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .End(); + + _REGISTER_CINN_NVGPU_GT_NUM(fp32, float); + _REGISTER_CINN_NVGPU_GT_NUM(fp64, double); + _REGISTER_CINN_NVGPU_GT_NUM(uint8, uint8_t); + _REGISTER_CINN_NVGPU_GT_NUM(int16, int16_t); + _REGISTER_CINN_NVGPU_GT_NUM(int32, int); + _REGISTER_CINN_NVGPU_GT_NUM(int64, int64_t); + +#undef _REGISTER_CINN_NVGPU_GT_NUM + +#define _REGISTER_CINN_NVGPU_INDEX_ADD(TYPE_SUFFIX, TYPE) \ + REGISTER_FACKED_EXTERN_FUNC_HELPER( \ + cinn_custom_device_index_add_##TYPE_SUFFIX, target) \ + .template SetRetType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .End(); + + _REGISTER_CINN_NVGPU_INDEX_ADD(bool, bool); + _REGISTER_CINN_NVGPU_INDEX_ADD(int8, int8_t); + _REGISTER_CINN_NVGPU_INDEX_ADD(int32, int32_t); + _REGISTER_CINN_NVGPU_INDEX_ADD(int64, int64_t); + _REGISTER_CINN_NVGPU_INDEX_ADD(fp32, float); + _REGISTER_CINN_NVGPU_INDEX_ADD(fp64, double); + +#undef _REGISTER_CINN_NVGPU_INDEX_ADD + + REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_custom_device_resize_bilinear, target) + .template SetRetType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .End(); + + REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_custom_device_resize_bicubic, target) + .template SetRetType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .End(); + + return true; +} diff --git a/paddle/cinn/runtime/custom_device/custom_device_intrinsics_float16.cc b/paddle/cinn/runtime/custom_device/custom_device_intrinsics_float16.cc new file mode 100644 index 00000000000000..a05ee1647b107a --- /dev/null +++ b/paddle/cinn/runtime/custom_device/custom_device_intrinsics_float16.cc @@ -0,0 +1,130 @@ +// Copyright (c) 2025 CINN Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/backends/extern_func_jit_register.h" +#include "paddle/cinn/backends/function_prototype.h" +#include "paddle/cinn/common/float16.h" +#include "paddle/cinn/common/type.h" +#include "paddle/cinn/runtime/custom_device/custom_device_util.h" + +using cinn::common::float16; +using cinn_buffer_ptr_t = cinn_buffer_t *; +using cinn_int_ptr_t = int *; + +CINN_REGISTER_HELPER(custom_device_intrinsics_float16) { + auto target = cinn::common::DefaultCustomDeviceTarget(); + using cinn::backends::FunctionProto; + +// float16 +#define REGISTER_EXTERN_FUNC_2_IN_1_FP16(func__) \ + REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT( \ + cinn_custom_device_##func__##_fp16, target, float16, float16, float16); + + REGISTER_EXTERN_FUNC_2_IN_1_FP16(pow) + REGISTER_EXTERN_FUNC_2_IN_1_FP16(mod) + +#undef REGISTER_EXTERN_FUNC_2_IN_1_FP16 + +#define REGISTER_EXTERN_FUNC_1_IN_1_FP16(func__) \ + REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT( \ + cinn_custom_device_##func__##_fp16, target, float16, float16); + + REGISTER_EXTERN_FUNC_1_IN_1_FP16(ceil) + REGISTER_EXTERN_FUNC_1_IN_1_FP16(floor) + REGISTER_EXTERN_FUNC_1_IN_1_FP16(round) + REGISTER_EXTERN_FUNC_1_IN_1_FP16(trunc) + REGISTER_EXTERN_FUNC_1_IN_1_FP16(sin) + REGISTER_EXTERN_FUNC_1_IN_1_FP16(cos) + REGISTER_EXTERN_FUNC_1_IN_1_FP16(tan) + REGISTER_EXTERN_FUNC_1_IN_1_FP16(exp) + REGISTER_EXTERN_FUNC_1_IN_1_FP16(log) + REGISTER_EXTERN_FUNC_1_IN_1_FP16(log2) + REGISTER_EXTERN_FUNC_1_IN_1_FP16(log10) + REGISTER_EXTERN_FUNC_1_IN_1_FP16(sqrt) + REGISTER_EXTERN_FUNC_1_IN_1_FP16(rsqrt) + REGISTER_EXTERN_FUNC_1_IN_1_FP16(cbrt) + REGISTER_EXTERN_FUNC_1_IN_1_FP16(abs) + REGISTER_EXTERN_FUNC_1_IN_1_FP16(erf) + REGISTER_EXTERN_FUNC_1_IN_1_FP16(sinh) + REGISTER_EXTERN_FUNC_1_IN_1_FP16(cosh) + REGISTER_EXTERN_FUNC_1_IN_1_FP16(tanh) + REGISTER_EXTERN_FUNC_1_IN_1_FP16(asin) + REGISTER_EXTERN_FUNC_1_IN_1_FP16(acos) + REGISTER_EXTERN_FUNC_1_IN_1_FP16(atan) + REGISTER_EXTERN_FUNC_1_IN_1_FP16(asinh) + REGISTER_EXTERN_FUNC_1_IN_1_FP16(acosh) + REGISTER_EXTERN_FUNC_1_IN_1_FP16(atanh) + REGISTER_EXTERN_FUNC_1_IN_1_FP16(sigmoid) + +#undef REGISTER_EXTERN_FUNC_1_IN_1_FP16 + +#define REGISTER_EXTERN_FUNC_1_IN_1_FP16_OUT_BOOL(func__) \ + REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT( \ + cinn_custom_device_##func__##_fp16, target, float16, bool); + + REGISTER_EXTERN_FUNC_1_IN_1_FP16_OUT_BOOL(isnan) + REGISTER_EXTERN_FUNC_1_IN_1_FP16_OUT_BOOL(isinf) + REGISTER_EXTERN_FUNC_1_IN_1_FP16_OUT_BOOL(isfinite) + +#undef REGISTER_EXTERN_FUNC_1_IN_1_FP16_OUT_BOOL + +#define REGISTER_CINN_NVGPU_GT_NUM(TYPE_SUFFIX, TYPE) \ + REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_custom_device_gt_num_##TYPE_SUFFIX, \ + target) \ + .template SetRetType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType() \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .End(); + + REGISTER_CINN_NVGPU_GT_NUM(fp16, float16); 
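+
+  // NOTE: the REGISTER_CINN_NVGPU_* macro names are carried over from the
+  // CUDA intrinsics file this one was adapted from; the symbols they
+  // register are still cinn_custom_device_*.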
+ +#undef REGISTER_CINN_NVGPU_GT_NUM + +#define REGISTER_CINN_NVGPU_LT_NUM(TYPE_SUFFIX, TYPE) \ + REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_custom_device_lt_num_##TYPE_SUFFIX, \ + target) \ + .template SetRetType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType() \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .End(); + + REGISTER_CINN_NVGPU_LT_NUM(fp16, float16); + +#undef REGISTER_CINN_NVGPU_LT_NUM + +#define REGISTER_CINN_NVGPU_INDEX_ADD(TYPE_SUFFIX, TYPE) \ + REGISTER_FACKED_EXTERN_FUNC_HELPER( \ + cinn_custom_device_index_add_##TYPE_SUFFIX, target) \ + .template SetRetType() \ + .template AddInputType() \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .End(); + + REGISTER_CINN_NVGPU_INDEX_ADD(fp16, float16); + +#undef REGISTER_CINN_NVGPU_INDEX_ADD + + return true; +} diff --git a/paddle/cinn/runtime/custom_device/custom_device_intrinsics_reduce.cc b/paddle/cinn/runtime/custom_device/custom_device_intrinsics_reduce.cc new file mode 100644 index 00000000000000..e49d97455287d8 --- /dev/null +++ b/paddle/cinn/runtime/custom_device/custom_device_intrinsics_reduce.cc @@ -0,0 +1,180 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/cinn/backends/extern_func_jit_register.h" +#include "paddle/cinn/common/float16.h" +#define CINN_CUSTOM_DEVICE_FP16 + +using cinn::common::float16; + +CINN_REGISTER_HELPER(custom_device_intrinsics_reduce) { + auto target = cinn::common::DefaultCustomDeviceTarget(); + +#define EXPAND_REDUCE_INT32_REGISTER_MARCO(MARCO, ...) \ + MARCO(sum_int32, int, ##__VA_ARGS__) \ + MARCO(prod_int32, int, ##__VA_ARGS__) \ + MARCO(max_int32, int, ##__VA_ARGS__) \ + MARCO(min_int32, int, ##__VA_ARGS__) + +#define EXPAND_REDUCE_INT64_REGISTER_MARCO(MARCO, ...) \ + MARCO(sum_int64, int64_t, ##__VA_ARGS__) \ + MARCO(prod_int64, int64_t, ##__VA_ARGS__) \ + MARCO(max_int64, int64_t, ##__VA_ARGS__) \ + MARCO(min_int64, int64_t, ##__VA_ARGS__) + +#define EXPAND_REDUCE_FP32_REGISTER_MACRO(MACRO, ...) \ + MACRO(sum_fp32, float, ##__VA_ARGS__) \ + MACRO(prod_fp32, float, ##__VA_ARGS__) \ + MACRO(max_fp32, float, ##__VA_ARGS__) \ + MACRO(min_fp32, float, ##__VA_ARGS__) + +#define EXPAND_REDUCE_BOOL_REGISTER_MACRO(MACRO, ...) \ + MACRO(all, bool, ##__VA_ARGS__) \ + MACRO(any, bool, ##__VA_ARGS__) + +#define EXPAND_REDUCE_FP64_REGISTER_MACRO(MACRO, ...) 
\ + MACRO(sum_fp64, double, ##__VA_ARGS__) \ + MACRO(prod_fp64, double, ##__VA_ARGS__) \ + MACRO(max_fp64, double, ##__VA_ARGS__) \ + MACRO(min_fp64, double, ##__VA_ARGS__) + +#ifdef CINN_CUSTOM_DEVICE_BF16 +#define EXPAND_REDUCE_BF16_REGISTER_MACRO(MACRO, ...) \ + MACRO(sum_bf16, bfloat16, ##__VA_ARGS__) \ + MACRO(prod_bf16, bfloat16, ##__VA_ARGS__) \ + MACRO(max_bf16, bfloat16, ##__VA_ARGS__) \ + MACRO(min_bf16, bfloat16, ##__VA_ARGS__) +#endif + +#ifdef CINN_CUSTOM_DEVICE_FP16 +#define EXPAND_REDUCE_FP16_REGISTER_MACRO(MACRO, ...) \ + MACRO(sum_fp16, float16, ##__VA_ARGS__) \ + MACRO(prod_fp16, float16, ##__VA_ARGS__) \ + MACRO(max_fp16, float16, ##__VA_ARGS__) \ + MACRO(min_fp16, float16, ##__VA_ARGS__) +#endif + +#define REGISTER_BLOCK_REDUCE_FUNC_IMPL(REDUCE_TYPE, DTYPE) \ + REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_block_reduce_##REDUCE_TYPE, target) \ + .template SetRetType() \ + .template AddInputType() \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .End(); + + EXPAND_REDUCE_INT32_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_FUNC_IMPL) + EXPAND_REDUCE_INT64_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_FUNC_IMPL) + EXPAND_REDUCE_FP32_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_FUNC_IMPL) + EXPAND_REDUCE_FP64_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_FUNC_IMPL) + EXPAND_REDUCE_BOOL_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_FUNC_IMPL) + +#ifdef CINN_CUSTOM_DEVICE_BF16 + EXPAND_REDUCE_BF16_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_FUNC_IMPL) +#endif + +#ifdef CINN_CUSTOM_DEVICE_FP16 + EXPAND_REDUCE_FP16_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_FUNC_IMPL) +#endif + +#undef REGISTER_BLOCK_REDUCE_FUNC_IMPL + +#define REGISTER_DISCRETE_REDUCE_FUNC_IMPL(REDUCE_TYPE, DTYPE) \ + REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_discrete_reduce_##REDUCE_TYPE, \ + target) \ + .template SetRetType() \ + .template AddInputType() \ + .template AddInputType(cinn::common::type_of()) \ + .End(); + + EXPAND_REDUCE_INT32_REGISTER_MARCO(REGISTER_DISCRETE_REDUCE_FUNC_IMPL) + EXPAND_REDUCE_INT64_REGISTER_MARCO(REGISTER_DISCRETE_REDUCE_FUNC_IMPL) + EXPAND_REDUCE_FP32_REGISTER_MACRO(REGISTER_DISCRETE_REDUCE_FUNC_IMPL) + EXPAND_REDUCE_FP64_REGISTER_MACRO(REGISTER_DISCRETE_REDUCE_FUNC_IMPL) + EXPAND_REDUCE_BOOL_REGISTER_MACRO(REGISTER_DISCRETE_REDUCE_FUNC_IMPL) + +#ifdef CINN_CUSTOM_DEVICE_BF16 + EXPAND_REDUCE_BF16_REGISTER_MACRO(REGISTER_DISCRETE_REDUCE_FUNC_IMPL) +#endif + +#ifdef CINN_CUSTOM_DEVICE_FP16 + EXPAND_REDUCE_FP16_REGISTER_MACRO(REGISTER_DISCRETE_REDUCE_FUNC_IMPL) +#endif + +#undef REGISTER_DISCRETE_REDUCE_FUNC_IMPL + + REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_grid_reduce_update_semaphore, target) + .template SetRetType(cinn::common::type_of()) + .template AddInputType(cinn::common::type_of()) + .End(); + +#define REGISTER_BLOCK_SHUFFLE_FUNC_IMPL(REDUCE_TYPE, DTYPE) \ + REGISTER_FACKED_EXTERN_FUNC_HELPER(block_shuffle_##REDUCE_TYPE, target) \ + .template SetRetType() \ + .template AddInputType(cinn::common::type_of()) \ + .template AddInputType(cinn::common::type_of()) \ + .End(); + + EXPAND_REDUCE_INT32_REGISTER_MARCO(REGISTER_BLOCK_SHUFFLE_FUNC_IMPL) + EXPAND_REDUCE_INT64_REGISTER_MARCO(REGISTER_BLOCK_SHUFFLE_FUNC_IMPL) + EXPAND_REDUCE_FP32_REGISTER_MACRO(REGISTER_BLOCK_SHUFFLE_FUNC_IMPL) + EXPAND_REDUCE_FP64_REGISTER_MACRO(REGISTER_BLOCK_SHUFFLE_FUNC_IMPL) + EXPAND_REDUCE_BOOL_REGISTER_MACRO(REGISTER_BLOCK_SHUFFLE_FUNC_IMPL) + +#ifdef CINN_CUSTOM_DEVICE_BF16 + EXPAND_REDUCE_BF16_REGISTER_MACRO(REGISTER_BLOCK_SHUFFLE_FUNC_IMPL) +#endif + +#ifdef CINN_CUSTOM_DEVICE_FP16 + 
EXPAND_REDUCE_FP16_REGISTER_MACRO(REGISTER_BLOCK_SHUFFLE_FUNC_IMPL)
+#endif
+
+#undef REGISTER_BLOCK_SHUFFLE_FUNC_IMPL
+
+#define REGISTER_GRID_REDUCE_FUNC_IMPL(REDUCE_TYPE, DTYPE) \
+  REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_grid_reduce_##REDUCE_TYPE, target) \
+      .template SetRetType() \
+      .template AddInputType(cinn::common::type_of()) \
+      .template AddInputType(cinn::common::type_of()) \
+      .End();
+  EXPAND_REDUCE_INT32_REGISTER_MARCO(REGISTER_GRID_REDUCE_FUNC_IMPL)
+  EXPAND_REDUCE_INT64_REGISTER_MARCO(REGISTER_GRID_REDUCE_FUNC_IMPL)
+  EXPAND_REDUCE_FP32_REGISTER_MACRO(REGISTER_GRID_REDUCE_FUNC_IMPL)
+  EXPAND_REDUCE_FP64_REGISTER_MACRO(REGISTER_GRID_REDUCE_FUNC_IMPL)
+  EXPAND_REDUCE_BOOL_REGISTER_MACRO(REGISTER_GRID_REDUCE_FUNC_IMPL)
+
+#ifdef CINN_CUSTOM_DEVICE_BF16
+  EXPAND_REDUCE_BF16_REGISTER_MACRO(REGISTER_GRID_REDUCE_FUNC_IMPL)
+#endif
+
+#ifdef CINN_CUSTOM_DEVICE_FP16
+  EXPAND_REDUCE_FP16_REGISTER_MACRO(REGISTER_GRID_REDUCE_FUNC_IMPL)
+#endif
+
+#undef REGISTER_GRID_REDUCE_FUNC_IMPL
+
+#undef EXPAND_REDUCE_INT32_REGISTER_MARCO
+#undef EXPAND_REDUCE_INT64_REGISTER_MARCO
+#undef EXPAND_REDUCE_FP32_REGISTER_MACRO
+#undef EXPAND_REDUCE_FP64_REGISTER_MACRO
+#undef EXPAND_REDUCE_BOOL_REGISTER_MACRO
+
+#ifdef CINN_CUSTOM_DEVICE_BF16
+#undef EXPAND_REDUCE_BF16_REGISTER_MACRO
+#endif
+
+#ifdef CINN_CUSTOM_DEVICE_FP16
+#undef EXPAND_REDUCE_FP16_REGISTER_MACRO
+#endif
+
+  return true;
+}
diff --git a/paddle/cinn/runtime/custom_device/custom_device_util.cc b/paddle/cinn/runtime/custom_device/custom_device_util.cc
new file mode 100644
index 00000000000000..85c0d12ce0bdea
--- /dev/null
+++ b/paddle/cinn/runtime/custom_device/custom_device_util.cc
@@ -0,0 +1,112 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/cinn/runtime/custom_device/custom_device_util.h"
+#include <vector>
+#include "paddle/cinn/runtime/custom_device/custom_device_backend_api.h"
+#include "paddle/cinn/utils/profiler.h"
+#include "paddle/phi/backends/device_manager.h"
+
+namespace cinn {
+namespace runtime {
+namespace custom_device {
+
+void cinn_call_custom_device_kernel(void *kernel_fn,
+                                    void *v_args,
+                                    int num_args,
+                                    int grid_x,
+                                    int grid_y,
+                                    int grid_z,
+                                    int block_x,
+                                    int block_y,
+                                    int block_z,
+                                    int shared_memory_bytes,
+                                    void *stream) {
+  // 1. Get the current device (through the Phi DeviceManager).
+  auto dev_types = phi::DeviceManager::GetAllCustomDeviceTypes();
+  PADDLE_ENFORCE_EQ(dev_types.empty(),
+                    false,
+                    phi::errors::NotFound("No Custom Device type registered."));
+
+  std::string dev_type = dev_types[0];
+  int device_id = phi::DeviceManager::GetDevice(dev_type);
+  auto place = phi::CustomPlace(dev_type, device_id);
+
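+  // (CinnCustomDevicePlugin is assumed to cache one instance per place and
+  // to surface the vendor's C_CinnInterface callbacks, see device_ext.h,
+  // behind the runtime-strategy object used below.)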
+  // 2. Get the plugin instance.
+  auto &plugin = CinnCustomDevicePlugin::GetInstance(place);
+  auto *runtime_strategy = plugin.GetRuntime();
+
+  VLOG(3) << "Launching kernel on " << dev_type << ":" << device_id << " Grid("
+          << grid_x << "," << grid_y << "," << grid_z << ")"
+          << " Block(" << block_x << "," << block_y << "," << block_z << ")";
+
+  // 3. Argument conversion: from cinn_pod_value_t (v_args) to void**
+  //    (kernel_args). CINN's argument protocol:
+  //    - For a buffer, the pod value holds a cinn_buffer_t*; we pass the
+  //      address of its internal `memory` pointer.
+  //    - For a scalar (int/float), we pass the address of the value itself.
+  std::vector<void *> kernel_args;
+  kernel_args.reserve(num_args);
+
+  cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
+
+  {
+    cinn::utils::RecordEvent record_run("prepare_args",
+                                        cinn::utils::EventType::kInstruction);
+    for (int idx = 0; idx < num_args; ++idx) {
+      if (args[idx].type_code() == ::cinn_type_code<cinn_buffer_t *>()) {
+        // For a device buffer, take cinn_buffer_t->memory (which has already
+        // been allocated on the device side).
+        cinn_buffer_t *buffer = static_cast<cinn_buffer_t *>(args[idx]);
+        kernel_args.emplace_back(&(buffer->memory));
+      } else {
+        // For a scalar argument, take the address of its host-side data.
+        // Note: the plugin's LaunchKernel is responsible for copying or
+        // mapping these scalars.
+        kernel_args.emplace_back(const_cast<void *>(args[idx].data_addr()));
+      }
+    }
+  }
+
+  // 4. Call the plugin's LaunchKernel.
+  // At this point kernel_fn is the function handle returned by the vendor
+  // plugin's LoadModule (e.g. a customDeviceFunction_t).
+  {
+    cinn::utils::RecordEvent record_run("plugin_launch_kernel",
+                                        cinn::utils::EventType::kInstruction);
+
+    // Note: we pass the array of argument addresses; vendor implementations
+    // typically end in something like
+    // cuLaunchKernel(..., kernel_args.data(), ...).
+    runtime_strategy->LaunchKernel(
+        kernel_fn,
+        "",  // func_name may be empty: kernel_fn is already a handle
+        kernel_args.data(),
+        num_args,
+        grid_x,
+        grid_y,
+        grid_z,
+        block_x,
+        block_y,
+        block_z,
+        shared_memory_bytes,
+        stream);
+  }
+}
+
+void infer_shape_set_value(int row, int col, int64_t value, int64_t **v) {
+  v[row][col] = value;
+}
+
+}  // namespace custom_device
+}  // namespace runtime
+}  // namespace cinn
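The conversion above fixes the plugin-facing ABI: every element of
kernel_args points either at a device pointer (&buffer->memory) or at a host
scalar. A minimal sketch of the vendor-side launch callback under that
contract; MyVendorContext and vendorLaunchKernel are placeholders for a real
driver API, not part of this patch:

    // Illustrative plugin-side callback (hypothetical vendor API).
    C_Status my_launch_kernel(void* dev_ptr, void* func_ptr, void** args,
                              int num_args, int gx, int gy, int gz,
                              int bx, int by, int bz, int shm, void* stream) {
      auto* ctx = static_cast<MyVendorContext*>(dev_ptr);
      // A CUDA-driver-style runtime can forward `args` unchanged: each entry
      // already points at the storage of one kernel parameter.
      bool ok = vendorLaunchKernel(ctx, func_ptr, gx, gy, gz, bx, by, bz,
                                   shm, stream, args, num_args);
      return ok ? C_SUCCESS : C_FAILED;
    }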
diff --git a/paddle/cinn/runtime/custom_device/custom_device_util.h b/paddle/cinn/runtime/custom_device/custom_device_util.h
new file mode 100644
index 00000000000000..3ed996d843afae
--- /dev/null
+++ b/paddle/cinn/runtime/custom_device/custom_device_util.h
@@ -0,0 +1,48 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/cinn/runtime/cinn_runtime.h"
+#include "paddle/common/enforce.h"
+
+namespace cinn {
+namespace runtime {
+namespace custom_device {
+
+/**
+ * @brief Generic kernel-launch entry for custom devices.
+ *
+ * This function no longer calls a vendor-specific API (such as
+ * hipLaunchKernel) directly; it forwards the launch to the vendor plugin
+ * through CinnCustomDevicePlugin.
+ */
+void cinn_call_custom_device_kernel(void *kernel_fn,
+                                    void *v_args,
+                                    int num_args,
+                                    int grid_x,
+                                    int grid_y,
+                                    int grid_z,
+                                    int block_x,
+                                    int block_y,
+                                    int block_z,
+                                    int shared_memory_bytes,
+                                    void *stream);
+
+/**
+ * @brief Host-side helper used by dynamic shape inference.
+ */
+void infer_shape_set_value(int row, int col, int64_t value, int64_t **v);
+
+}  // namespace custom_device
+}  // namespace runtime
+}  // namespace cinn
diff --git a/paddle/cinn/runtime/custom_device/float16.h b/paddle/cinn/runtime/custom_device/float16.h
new file mode 100644
index 00000000000000..6dc31d14a1846f
--- /dev/null
+++ b/paddle/cinn/runtime/custom_device/float16.h
@@ -0,0 +1,752 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef CINN_COMMON_FLOAT16_H
+#define CINN_COMMON_FLOAT16_H
+
+#ifdef __cplusplus
+#pragma once
+#endif  // __cplusplus
+
+#if defined(_M_X64) || defined(__x86_64__) || defined(_M_IX86) || \
+    defined(__i386__)
+#define __CINN_x86__
+#include <immintrin.h>
+#endif
+
+#include <stdint.h>
+
+#include <cmath>
+
+#ifdef CINN_WITH_CUDA
+#include <cuda.h>
+
+#if (defined(__CUDACC__) || defined(__CUDACC_RTC__))
+#define CINN_CUDA_FP16
+#include <cuda_fp16.h>
+#endif  // __CUDACC__
+#endif  // CINN_WITH_CUDA
+
+#ifdef CINN_WITH_HIP
+#include <hip/hip_runtime.h>
+#if defined(__HIPCC__)
+#define __HIP_PLATFORM_AMD__
+#include <hip/hip_fp16.h>
+#define CINN_HIP_FP16
+#endif
+#endif
+
+#ifdef __cplusplus
+#ifndef _WIN32
+#define CINN_ALIGN(x) __attribute__((aligned(x)))
+#else  // _WIN32
+#define CINN_ALIGN(x) __declspec(align(x))
+#endif  // _WIN32
+
+#else  // __cplusplus
+#define CINN_ALIGN(x)
+#endif  // __cplusplus
+
+// The `HOST` macro is deliberately not used here: it would conflict with the
+// enumerator `kHOST` that identifies the host backend.
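+// (On builds without CUDA or HIP, the qualifiers below expand to nothing, so
+// the same inline definitions compile as plain host C++.)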
+#ifndef __host__ +#define __host__ +#endif +#ifndef __device__ +#define __device__ +#endif + +#ifdef __cplusplus +namespace cinn { +namespace common { +#endif // __cplusplus + +// Use CINN_ALIGNED(2) to ensure that each float16 will be allocated +// and aligned at least on a 2-byte boundary, which leads to efficient +// memory access of float16 struct and also makes float16 compatible +// with CUDA half +struct CINN_ALIGN(2) float16 { + uint16_t x; + +#ifdef __cplusplus + // The following defaulted special class member functions + // are added to make float16 pass the std::is_trivial test + float16() = default; + float16(const float16& o) = default; + float16& operator=(const float16& o) = default; + float16(float16&& o) = default; + float16& operator=(float16&& o) = default; + ~float16() = default; + +// Constructors +#if defined(CINN_CUDA_FP16) || defined(CINN_HIP_FP16) + __host__ __device__ inline explicit float16(const half& h) { +#if defined(CINN_CUDA_FP16) && (CUDA_VERSION >= 9000) || defined(CINN_HIP_FP16) + x = reinterpret_cast<__half_raw*>(const_cast(&h))->x; +#else + x = h.x; +#endif // CUDA_VERSION >= 9000 + } +#endif // CINN_CUDA_FP16 + + __host__ __device__ inline explicit float16(float val) { +#if defined(CINN_CUDA_FP16) && \ + (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300) || \ + defined(CINN_HIP_FP16) + half tmp = __float2half(val); + x = *reinterpret_cast(&tmp); + +#elif defined(__F16C__) && defined(__CINN_x86__) + x = _cvtss_sh(val, 0); + +#else + // Conversion routine adapted from + // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion + Bits v, s; + v.f = val; + uint32_t sign = v.si & sigN; + v.si ^= sign; + sign >>= shiftSign; // logical shift + s.si = mulN; + s.si = s.f * v.f; // correct subnormals + v.si ^= (s.si ^ v.si) & -(minN > v.si); + v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN)); + v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN)); + v.ui >>= shift; // logical shift + v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC); + v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC); + x = v.ui | sign; + +#endif + } + + __host__ __device__ inline explicit float16(bool b) : x(b ? 0x3c00 : 0) {} + + template + __host__ __device__ inline explicit float16(const T& val) + : x(float16(static_cast(val)).x) {} + +// Assignment operators +#if defined(CINN_CUDA_FP16) || defined(CINN_HIP_FP16) + __host__ __device__ inline float16& operator=(const half& rhs) { +#if CUDA_VERSION >= 9000 || defined(CINN_HIP_FP16) + x = reinterpret_cast<__half_raw*>(const_cast(&rhs))->x; +#else + x = rhs.x; +#endif + return *this; + } +#endif + + __host__ __device__ inline float16& operator=(bool b) { + x = b ? 
0x3c00 : 0; + return *this; + } + + __host__ __device__ inline float16& operator=(int8_t val) { + x = float16(val).x; + return *this; + } + + __host__ __device__ inline float16& operator=(uint8_t val) { + x = float16(val).x; + return *this; + } + + __host__ __device__ inline float16& operator=(int16_t val) { + x = float16(val).x; + return *this; + } + + __host__ __device__ inline float16& operator=(uint16_t val) { + x = float16(val).x; + return *this; + } + + __host__ __device__ inline float16& operator=(int32_t val) { + x = float16(val).x; + return *this; + } + + __host__ __device__ inline float16& operator=(uint32_t val) { + x = float16(val).x; + return *this; + } + + __host__ __device__ inline float16& operator=(int64_t val) { + x = float16(val).x; + return *this; + } + + __host__ __device__ inline float16& operator=(uint64_t val) { + x = float16(val).x; + return *this; + } + + __host__ __device__ inline float16& operator=(float val) { + x = float16(val).x; + return *this; + } + + __host__ __device__ inline float16& operator=(double val) { + x = float16(val).x; + return *this; + } + +// Conversion operators +#if defined(CINN_CUDA_FP16) || defined(CINN_HIP_FP16) + __host__ __device__ inline half to_half() const { +#if CUDA_VERSION >= 9000 || defined(CINN_HIP_FP16) + __half_raw h; + h.x = x; + return half(h); +#else + half h; + h.x = x; + return h; +#endif // CUDA_VERSION >= 9000 + } +#endif // CINN_CUDA_FP16 + + __host__ __device__ inline operator float() const { +#if defined(CINN_CUDA_FP16) && \ + (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300) || \ + defined(CINN_HIP_FP16) + half tmp = *reinterpret_cast(this); + return __half2float(tmp); + +#elif defined(__F16C__) + return _cvtsh_ss(this->x); + +#else + // Conversion routine adapted from + // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion + Bits v; + v.ui = this->x; + int32_t sign = v.si & sigC; + v.si ^= sign; + sign <<= shiftSign; + v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC); + v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC); + Bits s; + s.si = mulC; + s.f *= v.si; + int32_t mask = -(norC > v.si); + v.si <<= shift; + v.si ^= (s.si ^ v.si) & mask; + v.si |= sign; + return v.f; + +#endif + } + + __host__ __device__ inline explicit operator bool() const { + return (x & 0x7fff) != 0; + } + + __host__ __device__ inline explicit operator int8_t() const { + return static_cast(static_cast(*this)); + } + + __host__ __device__ inline explicit operator uint8_t() const { + return static_cast(static_cast(*this)); + } + + __host__ __device__ inline explicit operator int16_t() const { + return static_cast(static_cast(*this)); + } + + __host__ __device__ inline explicit operator uint16_t() const { + return static_cast(static_cast(*this)); + } + + __host__ __device__ inline explicit operator int32_t() const { + return static_cast(static_cast(*this)); + } + + __host__ __device__ inline explicit operator uint32_t() const { + return static_cast(static_cast(*this)); + } + + __host__ __device__ inline explicit operator int64_t() const { + return static_cast(static_cast(*this)); + } + + __host__ __device__ inline explicit operator uint64_t() const { + return static_cast(static_cast(*this)); + } + + __host__ __device__ inline operator double() const { + return static_cast(static_cast(*this)); + } + + private: + union Bits { + float f; + int32_t si; + uint32_t ui; + }; + + static const int shift = 13; + static const int shiftSign = 16; + + static const int32_t infN = 0x7F800000; + static const int32_t maxN = 
0x477FE000; // max flt16 as flt32 + static const int32_t minN = 0x38800000; // min flt16 normal as flt32 + static const int32_t sigN = 0x80000000; // sign bit + + static constexpr int32_t infC = infN >> shift; + static constexpr int32_t nanN = (infC + 1) + << shift; // minimum flt16 nan as float32 + static constexpr int32_t maxC = maxN >> shift; + static constexpr int32_t minC = minN >> shift; + static constexpr int32_t sigC = sigN >> shiftSign; + + static const int32_t mulN = 0x52000000; // (1 << 23) / minN + static const int32_t mulC = 0x33800000; // minN / (1 << (23 - shift)) + static const int32_t subC = 0x003FF; // max flt32 subnormal downshifted + static const int32_t norC = 0x00400; // min flt32 normal downshifted + + static constexpr int32_t maxD = infC - maxC - 1; + static constexpr int32_t minD = minC - subC - 1; +#endif // __cplusplus +}; + +struct CINN_ALIGN(32) float8 { + float x, y, z, w, v, u, t, s; +}; + +struct CINN_ALIGN(16) half8 { + float16 x, y, z, w, v, u, t, s; +}; + +struct CINN_ALIGN(8) half4 { + float16 x, y, z, w; +}; + +struct CINN_ALIGN(16) float168 { + float16 x, y, z, w, v, u, t, s; +}; + +struct CINN_ALIGN(8) float164 { + float16 x, y, z, w; +}; + +struct CINN_ALIGN(4) float162 { + float16 x, y; +}; + +#ifdef __cplusplus +// Arithmetic operators on GPU +// CUDA 9.0 provides built-in arithmetic operators for half while +// CUDA 7.5 and 8.0 do not. The arithmetic operators defined here are +// for users to write similar CUDA code in CUDA 7.5 and 8.0 as in +// CUDA 9.0 regarding the half data type. +// ROCM has built-in arithmetic operators as not defined +// __HIP_NO_HALF_OPERATORS__ +#if (defined(CINN_CUDA_FP16) && CUDA_VERSION < 9000) || defined(CINN_HIP_FP16) +__device__ inline half operator+(const half& a, const half& b) { +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16) + return __hadd(a, b); +#else + float res = static_cast(float16(a)) + static_cast(float16(b)); + return float16(res).to_half(); +#endif +} + +__device__ inline half operator-(const half& a, const half& b) { +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16) + return __hsub(a, b); +#else + float res = static_cast(float16(a)) - static_cast(float16(b)); + return float16(res).to_half(); +#endif +} + +__device__ inline half operator*(const half& a, const half& b) { +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16) + return __hmul(a, b); +#else + float res = static_cast(float16(a)) * static_cast(float16(b)); + return float16(res).to_half(); +#endif +} + +__device__ inline half operator/(const half& a, const half& b) { +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16) + float num = __half2float(a); + float denom = __half2float(b); + return __float2half(num / denom); +#else + float res = static_cast(float16(a)) / static_cast(float16(b)); + return float16(res).to_half(); +#endif +} + +__device__ inline half operator-(const half& a) { +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16) + return __hneg(a); +#else + float res = -static_cast(float16(a)); + return float16(res).to_half(); +#endif +} + +#ifndef CINN_WITH_HIP +__device__ inline half& operator+=(half& a, const half& b) { // NOLINT + a = a + b; + return a; +} + +__device__ inline half& operator-=(half& a, const half& b) { // NOLINT + a = a - b; + return a; +} + +__device__ inline half& operator*=(half& a, const half& b) { // NOLINT + a = a * b; + return a; +} + +__device__ inline half& 
operator/=(half& a, const half& b) { // NOLINT + a = a / b; + return a; +} +#endif + +__device__ inline bool operator==(const half& a, const half& b) { +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16) + return __heq(a, b); +#else + return static_cast(float16(a)) == static_cast(float16(b)); +#endif +} + +__device__ inline bool operator!=(const half& a, const half& b) { +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16) + return __hne(a, b); +#else + return static_cast(float16(a)) != static_cast(float16(b)); +#endif +} + +__device__ inline bool operator<(const half& a, const half& b) { +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16) + return __hlt(a, b); +#else + return static_cast(float16(a)) < static_cast(float16(b)); +#endif +} + +__device__ inline bool operator<=(const half& a, const half& b) { +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16) + return __hle(a, b); +#else + return static_cast(float16(a)) <= static_cast(float16(b)); +#endif +} + +__device__ inline bool operator>(const half& a, const half& b) { +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16) + return __hgt(a, b); +#else + return static_cast(float16(a)) > static_cast(float16(b)); +#endif +} + +__device__ inline bool operator>=(const half& a, const half& b) { +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16) + return __hge(a, b); +#else + return static_cast(float16(a)) >= static_cast(float16(b)); +#endif +} + +#endif // CINN_CUDA_FP16 + +// Arithmetic operators for float16 on GPU +__host__ __device__ inline float16 operator+(const float16& a, + const float16& b) { +#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 530) || \ + defined(CINN_HIP_FP16) + return float16(__hadd(a.to_half(), b.to_half())); +#else + return float16(static_cast(a) + static_cast(b)); +#endif +} + +__host__ __device__ inline float16 operator-(const float16& a, + const float16& b) { +#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 530) || \ + defined(CINN_HIP_FP16) + return float16(__hsub(a.to_half(), b.to_half())); +#else + return float16(static_cast(a) - static_cast(b)); +#endif +} + +__host__ __device__ inline float16 operator*(const float16& a, + const float16& b) { +#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 530) || \ + defined(CINN_HIP_FP16) + return float16(__hmul(a.to_half(), b.to_half())); +#else + return float16(static_cast(a) * static_cast(b)); +#endif +} + +__host__ __device__ inline float16 operator/(const float16& a, + const float16& b) { +#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 530) || \ + defined(CINN_HIP_FP16) + // TODO(kexinzhao): check which cuda version starts to support __hdiv + float num = __half2float(a.to_half()); + float denom = __half2float(b.to_half()); + return float16(num / denom); +#else + return float16(static_cast(a) / static_cast(b)); +#endif +} + +__host__ __device__ inline float16 operator-(const float16& a) { +#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 530) || \ + defined(CINN_HIP_FP16) + return float16(__hneg(a.to_half())); +#else + float16 res; + res.x = a.x ^ 0x8000; + return res; +#endif +} + +__host__ __device__ inline float16& operator+=(float16& a, // NOLINT + const float16& b) { // NOLINT + a = a + b; + return a; +} + +__host__ __device__ inline float16& 
operator-=(float16& a, // NOLINT + const float16& b) { // NOLINT + a = a - b; + return a; +} + +__host__ __device__ inline float16& operator*=(float16& a, // NOLINT + const float16& b) { // NOLINT + a = a * b; + return a; +} + +__host__ __device__ inline float16& operator/=(float16& a, // NOLINT + const float16& b) { // NOLINT + a = a / b; + return a; +} + +__host__ __device__ inline bool operator==(const float16& a, const float16& b) { +#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 530) || \ + defined(CINN_HIP_FP16) + return __heq(a.to_half(), b.to_half()); +#else + return static_cast(a) == static_cast(b); +#endif +} + +__host__ __device__ inline bool operator!=(const float16& a, const float16& b) { +#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 530) || \ + defined(CINN_HIP_FP16) + return __hne(a.to_half(), b.to_half()); +#else + return static_cast(a) != static_cast(b); +#endif +} + +__host__ __device__ inline bool operator<(const float16& a, const float16& b) { +#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 530) || \ + defined(CINN_HIP_FP16) + return __hlt(a.to_half(), b.to_half()); +#else + return static_cast(a) < static_cast(b); +#endif +} + +__host__ __device__ inline bool operator<=(const float16& a, const float16& b) { +#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 530) || \ + defined(CINN_HIP_FP16) + return __hle(a.to_half(), b.to_half()); +#else + return static_cast(a) <= static_cast(b); +#endif +} + +__host__ __device__ inline bool operator>(const float16& a, const float16& b) { +#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 530) || \ + defined(CINN_HIP_FP16) + return __hgt(a.to_half(), b.to_half()); +#else + return static_cast(a) > static_cast(b); +#endif +} + +__host__ __device__ inline bool operator>=(const float16& a, const float16& b) { +#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 530) || \ + defined(CINN_HIP_FP16) + return __hge(a.to_half(), b.to_half()); +#else + return static_cast(a) >= static_cast(b); +#endif +} +#endif // __cplusplus + +__host__ __device__ inline float16 raw_uint16_to_float16(uint16_t a) { + float16 res; + res.x = a; + return res; +} + +__host__ __device__ inline bool(isnan)(const float16& a) { +#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 530) || \ + defined(CINN_HIP_FP16) + return __hisnan(a.to_half()); +#else + return (a.x & 0x7fff) > 0x7c00; +#endif +} + +__host__ __device__ inline bool(isinf)(const float16& a) { + return (a.x & 0x7fff) == 0x7c00; +} + +__host__ __device__ inline bool(isfinite)(const float16& a) { + return !((isnan)(a)) && !((isinf)(a)); +} + +__host__ __device__ inline float16(abs)(const float16& a) { +#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 530) || \ + defined(CINN_HIP_FP16) + return static_cast(__habs(a.to_half())); +#else + return static_cast(fabsf(static_cast(a))); +#endif +} + +__host__ __device__ inline float16(log)(const float16& a) { + return float16(std::log(static_cast(a))); +} + +#ifdef __cplusplus +} // namespace common +} // namespace cinn +#endif // __cplusplus + +#if defined(__cplusplus) && defined(CINN_CUDA_FP16) +__device__ inline cinn::common::float16 __shfl_sync(unsigned mask, + cinn::common::float16 var, + int srcLane, + int width = warpSize) { + return cinn::common::float16( + __shfl_sync(mask, var.to_half(), srcLane, width)); +} + +__device__ 
inline cinn::common::float16 __shfl_up_sync( + unsigned mask, + cinn::common::float16 var, + unsigned int delta, + int width = warpSize) { + return cinn::common::float16( + __shfl_up_sync(mask, var.to_half(), delta, width)); +} + +__device__ inline cinn::common::float16 __shfl_down_sync( + unsigned mask, + cinn::common::float16 var, + unsigned int delta, + int width = warpSize) { + return cinn::common::float16( + __shfl_down_sync(mask, var.to_half(), delta, width)); +} + +__device__ inline cinn::common::float16 __shfl_xor_sync( + unsigned mask, + cinn::common::float16 var, + int laneMask, + int width = warpSize) { + return cinn::common::float16( + __shfl_xor_sync(mask, var.to_half(), laneMask, width)); +} + +__host__ __device__ inline cinn::common::float16 max( + const cinn::common::float16& a, const cinn::common::float16& b) { + return a > b ? a : b; +} +__host__ __device__ inline cinn::common::float16 min( + const cinn::common::float16& a, const cinn::common::float16& b) { + return a < b ? a : b; +} +#endif // __cplusplus && CINN_CUDA_FP16 + +// Note: HIP does not support half-float shuffles. +#if defined(CINN_HIP_FP16) +__device__ inline cinn::common::float16 __shfl(cinn::common::float16 var, + int srcLane, + int width = warpSize) { + return cinn::common::float16(__shfl(static_cast(var), srcLane, width)); +} + +__device__ inline cinn::common::float16 __shfl_up(cinn::common::float16 var, + unsigned int delta, + int width = warpSize) { + return cinn::common::float16( + __shfl_up(static_cast(var), delta, width)); +} + +__device__ inline cinn::common::float16 __shfl_down(cinn::common::float16 var, + unsigned int delta, + int width = warpSize) { + return cinn::common::float16( + __shfl_down(static_cast(var), delta, width)); +} + +__device__ inline cinn::common::float16 __shfl_xor(cinn::common::float16 var, + int laneMask, + int width = warpSize) { + return cinn::common::float16( + __shfl_xor(static_cast(var), laneMask, width)); +} + +__host__ __device__ inline cinn::common::float16 max( + const cinn::common::float16& a, const cinn::common::float16& b) { + return a > b ? a : b; +} + +__host__ __device__ inline cinn::common::float16 min( + const cinn::common::float16& a, const cinn::common::float16& b) { + return a < b ? a : b; +} +#endif // CINN_HIP_FP16 + +#endif // CINN_COMMON_FLOAT16_H diff --git a/paddle/cinn/runtime/custom_device/float8e4m3.h b/paddle/cinn/runtime/custom_device/float8e4m3.h new file mode 100644 index 00000000000000..e70dbad626cc77 --- /dev/null +++ b/paddle/cinn/runtime/custom_device/float8e4m3.h @@ -0,0 +1,262 @@ +// Copyright (c) 2025 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef CINN_COMMON_FLOAT8E4M3_H +#define CINN_COMMON_FLOAT8E4M3_H + +#ifdef __cplusplus +#pragma once +#endif // __cplusplus + +#include + +#include +#include + +#ifdef CINN_WITH_CUDA +#include + +#if (defined(__CUDACC__) || defined(__CUDACC_RTC__)) && CUDA_VERSION >= 11080 +#define CINN_CUDA_FP8 +#include +#endif // __CUDACC__ +#endif // CINN_WITH_CUDA + +#ifdef __cplusplus + +#ifndef _WIN32 +#define CINN_ALIGN(x) __attribute__((aligned(x))) +#else // _WIN32 +#define CINN_ALIGN(x) __declspec(align(x)) +#endif // _WIN32 + +#else // __cplusplus +#define CINN_ALIGN(x) +#endif // __cplusplus + +#ifndef __host__ +#define __host__ +#endif +#ifndef __device__ +#define __device__ +#endif + +#ifdef __cplusplus +namespace cinn { +namespace common { +#endif // __cplusplus + +// E4M3 format (4 exponent bits, 3 mantissa bits) +struct CINN_ALIGN(1) float8e4m3 { + uint8_t x; + +#ifdef __cplusplus + // Constructors + float8e4m3() = default; + float8e4m3(const float8e4m3& o) = default; + float8e4m3& operator=(const float8e4m3& o) = default; + float8e4m3(float8e4m3&& o) = default; + float8e4m3& operator=(float8e4m3&& o) = default; + ~float8e4m3() = default; + + union Bits { + float f; + uint32_t ui; + }; + __host__ __device__ inline explicit float8e4m3(float val) { +#if defined(CINN_CUDA_FP8) + __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(val); + x = *reinterpret_cast(&tmp); +#else + // NOTE(YuhanXu): this code is mainly from + // https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/common/float8_e4m3fn.h + // with minor changes. + // CPU implementation. + Bits fb, denorm_mask; + fb.f = val; + constexpr uint32_t fp8_max = UINT32_C(1087) << 20; + denorm_mask.ui = UINT32_C(141) << 23; + uint8_t result = 0u; + const uint32_t sign = fb.ui & UINT32_C(0x80000000); + fb.ui ^= sign; + if (fb.ui >= fp8_max) { + result = 0x7e; + } else { + if (fb.ui < (UINT32_C(121) << 23)) { + fb.f = fb.f + denorm_mask.f; + fb.ui = fb.ui - denorm_mask.ui; + result = static_cast(fb.ui); + } else { + uint8_t mant_odd = (fb.ui >> 20) & 1; + fb.ui += ((uint32_t)(7 - 127) << 23) + 0x7FFFF; + fb.ui += mant_odd; + result = static_cast(fb.ui >> 20); + } + } + + result |= static_cast(sign >> 24); + x = result; +#endif + } + +#if defined(CINN_CUDA_FP8) + __host__ __device__ inline explicit float8e4m3(const __nv_fp8_e4m3& val) { + x = *reinterpret_cast(&val); + } + __host__ __device__ inline explicit float8e4m3(const __nv_bfloat16& val) { + __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(val); + x = *reinterpret_cast(&tmp); + } +#endif + + template + __host__ __device__ inline explicit float8e4m3(const T& val) + : x(float8e4m3(static_cast(val)).x) {} + +// Assignment operators +#if defined(CINN_CUDA_FP8) + __host__ __device__ inline float8e4m3& operator=(const __nv_fp8_e4m3& val) { + x = *reinterpret_cast(&val); // NOLINT + return *this; + } +#endif + + // Conversion operators + __host__ __device__ inline operator float() const { +#ifdef CINN_CUDA_FP8 + return static_cast(*reinterpret_cast(&x)); +#else + // NOTE(YuhanXu): this code is mainly from + // https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/common/float8_e4m3fn.h + // with minor changes. + // CPU implementation. + const uint32_t w = (uint32_t)x << 24; + const uint32_t sign = w & UINT32_C(0x80000000); + const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); + + // get the leading 0-bits in nonsin. 
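+    // (manual count-leading-zeros over `nonsign`: probe the top 16/8/4/2/1
+    // bits and shift left whenever they are all zero)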
+ uint32_t nonsign_tmp = nonsign; + uint32_t renorm_shift = 0; + if (nonsign_tmp == 0) { + renorm_shift = sizeof(uint32_t) * 8; + } else { + if ((nonsign_tmp & 0xFFFF0000) == 0) { + renorm_shift += 16; + nonsign_tmp <<= 16; + } + if ((nonsign_tmp & 0xFF000000) == 0) { + renorm_shift += 8; + nonsign_tmp <<= 8; + } + if ((nonsign_tmp & 0xF0000000) == 0) { + renorm_shift += 4; + nonsign_tmp <<= 4; + } + if ((nonsign_tmp & 0xC0000000) == 0) { + renorm_shift += 2; + nonsign_tmp <<= 2; + } + if ((nonsign_tmp & 0x80000000) == 0) { + renorm_shift += 1; + } + } + + renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0; + const int32_t inf_nan_mask = + ((int32_t)(nonsign + 0x01000000) >> 8) & INT32_C(0x7F800000); + const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31; + Bits result; + result.ui = + sign | + ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) | + inf_nan_mask) & + ~zero_mask); + return result.f; +#endif + } + +#ifdef CINN_CUDA_FP8 + __host__ __device__ inline __nv_fp8_e4m3 to_nv_fp8_e4m3() const { + return *reinterpret_cast(&x); + } +#endif + + __host__ __device__ inline explicit operator bool() const { + return (x & 0x7fff) != 0; + } + + __host__ __device__ inline explicit operator int8_t() const { + return static_cast(static_cast(*this)); + } + + __host__ __device__ inline explicit operator uint8_t() const { + return static_cast(static_cast(*this)); + } + + __host__ __device__ inline explicit operator int16_t() const { + return static_cast(static_cast(*this)); + } + + __host__ __device__ inline explicit operator uint16_t() const { + return static_cast(static_cast(*this)); + } + + __host__ __device__ inline explicit operator int32_t() const { + return static_cast(static_cast(*this)); + } + + __host__ __device__ inline explicit operator uint32_t() const { + return static_cast(static_cast(*this)); + } + + __host__ __device__ inline explicit operator int64_t() const { + return static_cast(static_cast(*this)); + } + + __host__ __device__ inline explicit operator uint64_t() const { + return static_cast(static_cast(*this)); + } + + __host__ __device__ inline operator double() const { + return static_cast(static_cast(*this)); + } +#endif // __cplusplus +}; + +// Vector types +struct CINN_ALIGN(4) float8e4m34 { + float8e4m3 x, y, z, w; +}; + +struct CINN_ALIGN(2) float8e4m32 { + float8e4m3 x, y; +}; + +#ifdef __cplusplus + +/// TODO(Yuhan): Arithmetic operator+ - * / etc. + +__host__ __device__ inline float8e4m3 raw_uint8_to_float8e4m3(uint8_t a) { + float8e4m3 res; + res.x = a; + return res; +} + +/// TODO(Yuhan): Comparison operators operator== != > < <= >= / etc. + +} // namespace common +} // namespace cinn +#endif // __cplusplus + +#endif // CINN_COMMON_FLOAT8E4M3_H diff --git a/paddle/cinn/runtime/custom_device/use_extern_funcs.h b/paddle/cinn/runtime/custom_device/use_extern_funcs.h new file mode 100644 index 00000000000000..b00835e0744f6b --- /dev/null +++ b/paddle/cinn/runtime/custom_device/use_extern_funcs.h @@ -0,0 +1,23 @@ +// Copyright (c) 2024 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/cinn/backends/extern_func_jit_register.h"
+
+#ifdef CINN_WITH_CUSTOM_DEVICE
+CINN_USE_REGISTER(cinn_custom_device_host_api)
+CINN_USE_REGISTER(custom_device_intrinsics)
+CINN_USE_REGISTER(custom_device_intrinsics_reduce)
+CINN_USE_REGISTER(custom_device_intrinsics_float16)
+#endif
diff --git a/paddle/cinn/runtime/flags.cc b/paddle/cinn/runtime/flags.cc
index da07da329f0e17..0995b890b97096 100644
--- a/paddle/cinn/runtime/flags.cc
+++ b/paddle/cinn/runtime/flags.cc
@@ -378,6 +378,16 @@ void CheckCompileOptionImpl(cinn::common::NVGPUArch) {
 #endif
 }
 
+void CheckCompileOptionImpl(cinn::common::CustomDeviceArch) {
+#if defined(CINN_WITH_CUSTOM_DEVICE)
+  // Do nothing;
+#else
+  PADDLE_THROW(::common::errors::Fatal(
+      "Current CINN version does not support CustomDevice, please try to "
+      "recompile with -DWITH_CUSTOM_DEVICE."));
+#endif
+}
+
 void CheckCompileOptionImpl(cinn::common::HygonDCUArchHIP) {
 #ifdef CINN_WITH_HIP
   // Do nothing;
diff --git a/paddle/cinn/runtime/intrinsic.h b/paddle/cinn/runtime/intrinsic.h
index c8aacfc4317005..f4c7fcc5bfe720 100644
--- a/paddle/cinn/runtime/intrinsic.h
+++ b/paddle/cinn/runtime/intrinsic.h
@@ -107,6 +107,10 @@ static const char* call_cuda_kernel = "cinn_call_cuda_kernel";
 static const char* call_cuda_cooperative_kernel =
     "cinn_call_cuda_cooperative_kernel";
 
+static const char* call_custom_device_kernel = "cinn_call_custom_device_kernel";
+static const char* call_custom_device_cooperative_kernel =
+    "cinn_call_custom_device_cooperative_kernel";
+
 static const char* call_hip_kernel = "cinn_call_hip_kernel";
 
 static const char* call_sycl_kernel = "cinn_call_sycl_kernel";
diff --git a/paddle/cinn/runtime/sycl/sycl_backend_api.cc b/paddle/cinn/runtime/sycl/sycl_backend_api.cc
index 3cd6e8f10eaffe..a93cc354eec6e2 100644
--- a/paddle/cinn/runtime/sycl/sycl_backend_api.cc
+++ b/paddle/cinn/runtime/sycl/sycl_backend_api.cc
@@ -41,6 +41,9 @@ void SYCLBackendAPI::Init(Arch arch) {
       },
       [&](common::X86Arch) { CINN_NOT_IMPLEMENTED },
       [&](common::ARMArch) { CINN_NOT_IMPLEMENTED },
+      [&](common::CustomDeviceArch) {
+        backend = ::sycl::backend::ext_oneapi_cuda;
+      },
       [&](common::NVGPUArch) { backend = ::sycl::backend::ext_oneapi_cuda; },
       [&](common::HygonDCUArchHIP) { CINN_NOT_IMPLEMENTED },
       [&](common::HygonDCUArchSYCL) {
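The custom_device.cc hunk below forwards several new device properties to
CINN. On the plugin side each one is a plain C callback wired into
C_DeviceInterface during InitPlugin; a sketch with made-up hardware constants
(a real plugin would query its driver):

    // Hypothetical plugin-side property callbacks.
    C_Status get_warp_size(const C_Device device, size_t* warp_size) {
      *warp_size = 64;  // e.g. a wavefront-oriented architecture
      return C_SUCCESS;
    }

    C_Status get_max_shared_mem_per_block(const C_Device device,
                                          size_t* shared_mem_per_block) {
      *shared_mem_per_block = 64 * 1024;
      return C_SUCCESS;
    }

    // In InitPlugin:
    //   params->interface->get_warp_size = get_warp_size;
    //   params->interface->get_max_shared_mem_per_block =
    //       get_max_shared_mem_per_block;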
diff --git a/paddle/cinn/runtime/flags.cc b/paddle/cinn/runtime/flags.cc
index da07da329f0e17..0995b890b97096 100644
--- a/paddle/cinn/runtime/flags.cc
+++ b/paddle/cinn/runtime/flags.cc
@@ -378,6 +378,16 @@ void CheckCompileOptionImpl(cinn::common::NVGPUArch) {
 #endif
 }
 
+void CheckCompileOptionImpl(cinn::common::CustomDeviceArch) {
+#ifdef CINN_WITH_CUSTOM_DEVICE
+  // Do nothing;
+#else
+  PADDLE_THROW(::common::errors::Fatal(
+      "Current CINN version does not support CustomDevice, please try to "
+      "recompile with -DWITH_CUSTOM_DEVICE."));
+#endif
+}
+
 void CheckCompileOptionImpl(cinn::common::HygonDCUArchHIP) {
 #ifdef CINN_WITH_HIP
   // Do nothing;
diff --git a/paddle/cinn/runtime/intrinsic.h b/paddle/cinn/runtime/intrinsic.h
index c8aacfc4317005..f4c7fcc5bfe720 100644
--- a/paddle/cinn/runtime/intrinsic.h
+++ b/paddle/cinn/runtime/intrinsic.h
@@ -107,6 +107,10 @@ static const char* call_cuda_kernel = "cinn_call_cuda_kernel";
 static const char* call_cuda_cooperative_kernel =
     "cinn_call_cuda_cooperative_kernel";
 
+static const char* call_custom_device_kernel = "cinn_call_custom_device_kernel";
+static const char* call_custom_device_cooperative_kernel =
+    "cinn_call_custom_device_cooperative_kernel";
+
 static const char* call_hip_kernel = "cinn_call_hip_kernel";
 
 static const char* call_sycl_kernel = "cinn_call_sycl_kernel";
diff --git a/paddle/cinn/runtime/sycl/sycl_backend_api.cc b/paddle/cinn/runtime/sycl/sycl_backend_api.cc
index 3cd6e8f10eaffe..a93cc354eec6e2 100644
--- a/paddle/cinn/runtime/sycl/sycl_backend_api.cc
+++ b/paddle/cinn/runtime/sycl/sycl_backend_api.cc
@@ -41,6 +41,9 @@ void SYCLBackendAPI::Init(Arch arch) {
       },
       [&](common::X86Arch) { CINN_NOT_IMPLEMENTED },
       [&](common::ARMArch) { CINN_NOT_IMPLEMENTED },
+      [&](common::CustomDeviceArch) {
+        CINN_NOT_IMPLEMENTED  // the SYCL backend does not drive custom devices
+      },
       [&](common::NVGPUArch) { backend = ::sycl::backend::ext_oneapi_cuda; },
       [&](common::HygonDCUArchHIP) { CINN_NOT_IMPLEMENTED },
       [&](common::HygonDCUArchSYCL) {
diff --git a/paddle/phi/backends/custom/custom_device.cc b/paddle/phi/backends/custom/custom_device.cc
index 94c46eed9ebc8e..799a822d59e78b 100644
--- a/paddle/phi/backends/custom/custom_device.cc
+++ b/paddle/phi/backends/custom/custom_device.cc
@@ -635,6 +635,57 @@ class CustomDevice : public DeviceInterface {
     return threads_per_block;
   }
 
+  size_t GetMaxSharedMemPerBlock(size_t dev_id) override {
+    const auto device = &devices_pool[dev_id];
+    size_t shared_mem_per_block = 0;
+    if (pimpl_->get_max_shared_mem_per_block) {
+      pimpl_->get_max_shared_mem_per_block(device, &shared_mem_per_block);
+    }
+    VLOG(10) << Type() << " get max shared mem per block " << shared_mem_per_block;
+    return shared_mem_per_block;
+  }
+
+  size_t GetMaxBlocksPerMultiProcessor(size_t dev_id) override {
+    const auto device = &devices_pool[dev_id];
+    size_t blocks_per_mp = 0;
+    if (pimpl_->get_max_blocks_per_mp) {
+      pimpl_->get_max_blocks_per_mp(device, &blocks_per_mp);
+    }
+    VLOG(10) << Type() << " get blocks per multiprocessor " << blocks_per_mp;
+    return blocks_per_mp;
+  }
+
+  size_t GetWarpSize(size_t dev_id) override {
+    const auto device = &devices_pool[dev_id];
+    size_t warp_size = 0;
+    if (pimpl_->get_warp_size) {
+      pimpl_->get_warp_size(device, &warp_size);
+    }
+    VLOG(10) << Type() << " get warp size " << warp_size;
+    return warp_size;
+  }
+
+  size_t GetMaxRegistersPerMultiProcessor(size_t dev_id) override {
+    const auto device = &devices_pool[dev_id];
+    size_t registers_per_mp = 0;
+    if (pimpl_->get_max_registers_per_mp) {
+      pimpl_->get_max_registers_per_mp(device, &registers_per_mp);
+    }
+    VLOG(10) << Type() << " get registers per multiprocessor "
+             << registers_per_mp;
+    return registers_per_mp;
+  }
+
+  size_t GetPreferredVectorWidth(size_t dev_id) override {
+    const auto device = &devices_pool[dev_id];
+    size_t vector_width = 0;
+    if (pimpl_->get_vector_width) {
+      pimpl_->get_vector_width(device, &vector_width);
+    }
+    VLOG(10) << Type() << " get preferred vector width " << vector_width;
+    return vector_width;
+  }
+
   std::array<size_t, 3> GetMaxGridDimSize(size_t dev_id) override {
     const auto device = &devices_pool[dev_id];
     std::array<size_t, 3> grid_dim_size = {0, 0, 0};
@@ -646,6 +697,17 @@ class CustomDevice : public DeviceInterface {
     return grid_dim_size;
   }
 
+  std::array<size_t, 3> GetMaxBlockDimSize(size_t dev_id) override {
+    const auto device = &devices_pool[dev_id];
+    std::array<size_t, 3> block_dim_size = {0, 0, 0};
+    if (pimpl_->get_max_block_dim_size) {
+      pimpl_->get_max_block_dim_size(device, &block_dim_size);
+    }
+    VLOG(10) << Type() << " get max block dim size [" << block_dim_size[0]
+             << ", " << block_dim_size[1] << ", " << block_dim_size[2] << "]";
+    return block_dim_size;
+  }
+
   bool IsFloat16Supported(size_t dev_id) {
     const auto device = &devices_pool[dev_id];
     bool supported = false;
@@ -1235,6 +1297,11 @@ class CustomDevice : public DeviceInterface {
     }
   }
 
+  // Accessor for the plugin's optional CINN capability table.
+  C_CinnInterface* GetCinnInterface() override {
+    return pimpl_->cinn_interface;
+  }
+
  private:
   inline int PlaceToIdNoCheck(const Place& place) {
     int dev_id = place.GetDeviceId();  // NOLINT
@@ -1332,7 +1399,13 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) {
   CHECK_INTERFACE(get_multi_process, false);
   CHECK_INTERFACE(get_max_threads_per_mp, false);
   CHECK_INTERFACE(get_max_threads_per_block, false);
+  CHECK_INTERFACE(get_max_shared_mem_per_block, false);
+  CHECK_INTERFACE(get_max_blocks_per_mp, false);
+  CHECK_INTERFACE(get_warp_size, false);
+  CHECK_INTERFACE(get_max_registers_per_mp, false);
+  CHECK_INTERFACE(get_vector_width, false);
   CHECK_INTERFACE(get_max_grid_dim_size, false);
+  CHECK_INTERFACE(get_max_block_dim_size, false);
 
   CHECK_INTERFACE(init_eigen_device, false);
   CHECK_INTERFACE(destroy_eigen_device, false);
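All of the new CHECK_INTERFACE entries pass required=false, so plugins that predate these hooks still load; CustomDevice then reports 0 for the missing properties. A plugin that wants CINN to see real launch limits fills the hooks in its InitPlugin. The vendor-side sketch below is illustrative only; the values and the unset-hook behavior are assumptions about a hypothetical device, not part of the patch.

// Editor's illustrative plugin-side sketch: filling the optional
// device-property hooks added by this patch. Hooks left as nullptr are
// tolerated -- CustomDevice then reports 0 and callers must fall back.
#include "paddle/phi/backends/device_ext.h"

static C_Status get_warp_size(const C_Device device, size_t* warp_size) {
  *warp_size = 64;  // e.g. a wavefront-64 style architecture (hypothetical)
  return C_SUCCESS;
}

static C_Status get_max_block_dim_size(const C_Device device,
                                       std::array<size_t, 3>* block_dim) {
  *block_dim = {1024, 1024, 64};  // hypothetical per-axis limits
  return C_SUCCESS;
}

void InitPlugin(CustomRuntimeParams* params) {
  // ... required hooks (memory, stream, etc.) elided ...
  params->interface->get_warp_size = get_warp_size;
  params->interface->get_max_block_dim_size = get_max_block_dim_size;
  // Remaining optional hooks may stay null; validation only warns.
}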
diff --git a/paddle/phi/backends/device_base.cc b/paddle/phi/backends/device_base.cc
index c365e5eee7a536..067588003a392f 100644
--- a/paddle/phi/backends/device_base.cc
+++ b/paddle/phi/backends/device_base.cc
@@ -67,12 +67,43 @@ size_t DeviceInterface::GetMaxThreadsPerBlock(size_t dev_id) {
   return 0;
 }
 
+size_t DeviceInterface::GetMaxSharedMemPerBlock(size_t dev_id) {
+  VLOG(10) << Type() << " get max shared mem per block " << 0;
+  return 0;
+}
+
+size_t DeviceInterface::GetMaxBlocksPerMultiProcessor(size_t dev_id) {
+  VLOG(10) << Type() << " get max blocks per multiprocessor " << 0;
+  return 0;
+}
+
+size_t DeviceInterface::GetWarpSize(size_t dev_id) {
+  VLOG(10) << Type() << " get warp size " << 0;
+  return 0;
+}
+
+size_t DeviceInterface::GetMaxRegistersPerMultiProcessor(size_t dev_id) {
+  VLOG(10) << Type() << " get max registers per multiprocessor " << 0;
+  return 0;
+}
+
+size_t DeviceInterface::GetPreferredVectorWidth(size_t dev_id) {
+  VLOG(10) << Type() << " get preferred vector width " << 0;
+  return 0;
+}
+
 std::array<size_t, 3> DeviceInterface::GetMaxGridDimSize(size_t dev_id) {
   VLOG(10) << Type() << " get max grid dim size [" << 0 << ", " << 0 << ", "
            << 0 << "]";
   return {0, 0, 0};
 }
 
+std::array<size_t, 3> DeviceInterface::GetMaxBlockDimSize(size_t dev_id) {
+  VLOG(10) << Type() << " get max block dim size [" << 0 << ", " << 0 << ", "
+           << 0 << "]";
+  return {0, 0, 0};
+}
+
 bool DeviceInterface::IsFloat16Supported(size_t dev_id) {
   VLOG(10) << Type() << " is float16 supported: " << false;
   return false;
diff --git a/paddle/phi/backends/device_base.h b/paddle/phi/backends/device_base.h
index 52e3b497cc7707..3b1efd79154077 100644
--- a/paddle/phi/backends/device_base.h
+++ b/paddle/phi/backends/device_base.h
@@ -22,6 +22,7 @@
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/allocator.h"
 
+struct C_CinnInterface;
 namespace phi {
 
 struct DeviceProp {
@@ -63,6 +64,8 @@ class DeviceInterface {  // Driver / Runtime
 
   virtual ~DeviceInterface() {}
 
+  virtual C_CinnInterface* GetCinnInterface() { return nullptr; }
+
   // Info
   virtual size_t GetComputeCapability(size_t dev_id);
 
@@ -78,8 +81,20 @@ class DeviceInterface {  // Driver / Runtime
 
   virtual size_t GetMaxThreadsPerBlock(size_t dev_id);
 
+  virtual size_t GetMaxSharedMemPerBlock(size_t dev_id);
+
+  virtual size_t GetMaxBlocksPerMultiProcessor(size_t dev_id);
+
+  virtual size_t GetWarpSize(size_t dev_id);
+
+  virtual size_t GetMaxRegistersPerMultiProcessor(size_t dev_id);
+
+  virtual size_t GetPreferredVectorWidth(size_t dev_id);
+
   virtual std::array<size_t, 3> GetMaxGridDimSize(size_t dev_id);
 
+  virtual std::array<size_t, 3> GetMaxBlockDimSize(size_t dev_id);
+
   virtual bool IsFloat16Supported(size_t dev_id);
 
   virtual bool IsBFloat16Supported(size_t dev_id);
diff --git a/paddle/phi/backends/device_ext.h b/paddle/phi/backends/device_ext.h
index 1ecbc038dcb206..e3d7f844207bb1 100644
--- a/paddle/phi/backends/device_ext.h
+++ b/paddle/phi/backends/device_ext.h
@@ -134,6 +134,41 @@
 void profiler_add_runtime_trace_event(C_Profiler prof, void* event);
 
 void profiler_add_device_trace_event(C_Profiler prof, void* event);
 
+struct C_CinnInterface {
+  size_t size;
+  void* dev_ptr;  // Vendor-private context, passed as the first argument to
+                  // every hook below.
+
+  // --- Compiler toolchain ---
+  C_Status (*compile)(void* dev_ptr,
+                      const char* code,
+                      char* out_path,
+                      size_t len);
+  const char* (*get_runtime_source)(void* dev_ptr);
+
+  // --- Runtime strategy ---
+  C_Status (*module_load)(void* dev_ptr, const char* path, void** mod_out);
+  C_Status (*module_unload)(void* dev_ptr, void* module_handle);
+  C_Status (*get_kernel_address)(void* dev_ptr,
+                                 void* module_handle,
+                                 const char* func_name,
+                                 void** func_out);
+  C_Status (*launch_kernel)(void* dev_ptr,
+                            void* func_ptr,
+                            void** args,
+                            int num_args,
+                            int gx,
+                            int gy,
+                            int gz,
+                            int bx,
+                            int by,
+                            int bz,
+                            int shm,
+                            void* stream);
+
+  // --- Compile strategy ---
+  C_Status (*apply_custom_pass)(void* dev_ptr, void* ir_module);
+};
+
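Given this table, the host symbol cinn_call_custom_device_kernel declared earlier in runtime/intrinsic.h can be a thin trampoline into the plugin's launch_kernel hook. Its real body lives in custom_device_backend_api.cc, which is outside this diff; the following is a minimal sketch under that assumption, with GetCinnInterfaceForCurrentDevice and UnpackCinnArgs as hypothetical helpers.

// Editor's illustrative sketch only -- not the committed implementation.
extern "C" void cinn_call_custom_device_kernel(void* kernel_fn,
                                               void* v_args,
                                               int num_args,
                                               int grid_x, int grid_y, int grid_z,
                                               int block_x, int block_y, int block_z,
                                               int shared_memory_bytes,
                                               void* stream) {
  // Hypothetical helper that resolves the plugin's capability table.
  C_CinnInterface* cinn = GetCinnInterfaceForCurrentDevice();
  // v_args is CINN's packed argument buffer; unpacking it into the void**
  // array expected by launch_kernel is elided behind a hypothetical helper.
  void** args = UnpackCinnArgs(v_args, num_args);
  cinn->launch_kernel(cinn->dev_ptr, kernel_fn, args, num_args,
                      grid_x, grid_y, grid_z,
                      block_x, block_y, block_z,
                      shared_memory_bytes, stream);
}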
 struct C_DeviceInterface {
   // Core fill it and plugin must to check it
   size_t size;
@@ -616,6 +651,43 @@ struct C_DeviceInterface {
   C_Status (*get_max_threads_per_block)(const C_Device device,
                                         size_t* threads_per_block);
 
+  /**
+   * @brief Get Max Shared Mem Per Block
+   *
+   * @param[size_t*] shared_mem_per_block
+   */
+  C_Status (*get_max_shared_mem_per_block)(const C_Device device,
+                                           size_t* shared_mem_per_block);
+
+  /**
+   * @brief Get Max Blocks Per MultiProcessor
+   *
+   * @param[size_t*] blocks_per_mp
+   */
+  C_Status (*get_max_blocks_per_mp)(const C_Device device,
+                                    size_t* blocks_per_mp);
+
+  /**
+   * @brief Get Warp Size
+   *
+   * @param[size_t*] warp_size
+   */
+  C_Status (*get_warp_size)(const C_Device device, size_t* warp_size);
+
+  /**
+   * @brief Get Max Registers Per MultiProcessor
+   *
+   * @param[size_t*] registers_per_mp
+   */
+  C_Status (*get_max_registers_per_mp)(const C_Device device,
+                                       size_t* registers_per_mp);
+
+  /**
+   * @brief Get Preferred Vector Width
+   *
+   * @param[size_t*] vector_width
+   */
+  C_Status (*get_vector_width)(const C_Device device, size_t* vector_width);
+
   /**
    * @brief Get Max Grid Dim Size
    *
    * @param[std::array<size_t, 3>*] grid_dim_size
    */
   C_Status (*get_max_grid_dim_size)(const C_Device device,
                                     std::array<size_t, 3>* grid_dim_size);
 
+  /**
+   * @brief Get Max Block Dim Size
+   *
+   * @param[std::array<size_t, 3>*] block_dim_size
+   */
+  C_Status (*get_max_block_dim_size)(
+      const C_Device device, std::array<size_t, 3>* block_dim_size);
+
   /**
    * @brief Is float16 supported
    *
@@ -881,7 +960,10 @@ struct C_DeviceInterface {
                  void* x,
                  float beta,
                  void* y);
-  void* reserved_other_api[7];
+
+  struct C_CinnInterface* cinn_interface;
+
+  void* reserved_other_api[6];
 };
 
 struct CustomRuntimeVersion {
diff --git a/paddle/phi/backends/device_manager.cc b/paddle/phi/backends/device_manager.cc
index 5ee42486b101bc..8f11d0ff3cd6aa 100644
--- a/paddle/phi/backends/device_manager.cc
+++ b/paddle/phi/backends/device_manager.cc
@@ -514,6 +514,41 @@ size_t DeviceManager::GetMaxThreadsPerBlock(const Place& place) {
   return dev_impl->GetMaxThreadsPerBlock(device_id);
 }
 
+size_t DeviceManager::GetMaxSharedMemPerBlock(const Place& place) {
+  auto device_type = place.GetDeviceType();
+  auto device_id = place.GetDeviceId();
+  auto dev_impl = GetDeviceInterfaceWithType(device_type);
+  return dev_impl->GetMaxSharedMemPerBlock(device_id);
+}
+
+size_t DeviceManager::GetMaxBlocksPerMultiProcessor(const Place& place) {
+  auto device_type = place.GetDeviceType();
+  auto device_id = place.GetDeviceId();
+  auto dev_impl = GetDeviceInterfaceWithType(device_type);
+  return dev_impl->GetMaxBlocksPerMultiProcessor(device_id);
+}
+
+size_t DeviceManager::GetWarpSize(const Place& place) {
+  auto device_type = place.GetDeviceType();
+  auto device_id = place.GetDeviceId();
+  auto dev_impl = GetDeviceInterfaceWithType(device_type);
+  return dev_impl->GetWarpSize(device_id);
+}
+
+size_t DeviceManager::GetMaxRegistersPerMultiProcessor(const Place& place) {
+  auto device_type = place.GetDeviceType();
+  auto device_id = place.GetDeviceId();
+  auto dev_impl = GetDeviceInterfaceWithType(device_type);
+  return dev_impl->GetMaxRegistersPerMultiProcessor(device_id);
+}
+
+size_t DeviceManager::GetPreferredVectorWidth(const Place& place) {
+  auto device_type = place.GetDeviceType();
+  auto device_id = place.GetDeviceId();
+  auto dev_impl = GetDeviceInterfaceWithType(device_type);
+  return dev_impl->GetPreferredVectorWidth(device_id);
+}
+
 std::array<size_t, 3> DeviceManager::GetMaxGridDimSize(
     const Place& place) {
   auto device_type = place.GetDeviceType();
@@ -522,6 +557,14 @@ std::array<size_t, 3> DeviceManager::GetMaxGridDimSize(
   return dev_impl->GetMaxGridDimSize(device_id);
 }
 
+std::array<size_t, 3> DeviceManager::GetMaxBlockDimSize(
+    const Place& place) {
+  auto device_type = place.GetDeviceType();
+  auto device_id = place.GetDeviceId();
+  auto dev_impl = GetDeviceInterfaceWithType(device_type);
+  return dev_impl->GetMaxBlockDimSize(device_id);
+}
+
 bool DeviceManager::IsFloat16Supported(const Place& place) {
   auto device_type = place.GetDeviceType();
   auto device_id = place.GetDeviceId();
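Callers should remember that the DeviceInterface defaults added above return 0 for every new property, so a scheduler built on these queries needs explicit fallbacks. The sketch below is illustrative; ClampLaunchDims and OrDefault are hypothetical helpers, not part of the patch.

// Editor's illustrative sketch: consuming the new queries with fallbacks.
#include <algorithm>
#include <array>

#include "paddle/phi/backends/device_manager.h"

// 0 means "the plugin did not report this property"; substitute a fallback.
static size_t OrDefault(size_t v, size_t fallback) {
  return v != 0 ? v : fallback;
}

std::array<size_t, 3> ClampLaunchDims(const phi::Place& place,
                                      std::array<size_t, 3> block) {
  const auto limit = phi::DeviceManager::GetMaxBlockDimSize(place);
  for (int i = 0; i < 3; ++i) {
    // If the per-axis limit is unreported (0), keep the requested size.
    block[i] = std::min(block[i], OrDefault(limit[i], block[i]));
  }
  // A full implementation would also bound block[0]*block[1]*block[2] by
  // GetMaxThreadsPerBlock(place); elided for brevity.
  return block;
}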
diff --git a/paddle/phi/backends/device_manager.h b/paddle/phi/backends/device_manager.h
index 1e0221b323f1ae..fd1ff76a74bf46 100644
--- a/paddle/phi/backends/device_manager.h
+++ b/paddle/phi/backends/device_manager.h
@@ -29,6 +29,7 @@
 #include "paddle/phi/backends/stream.h"
 #include "paddle/phi/common/port.h"
 
+struct C_CinnInterface;
 namespace phi {
 
 class PADDLE_API Device final {
  public:
@@ -126,6 +127,10 @@ class PADDLE_API Device final {
 
   std::string Type();
 
+  struct C_CinnInterface* GetCinnInterface() const {
+    return impl_->GetCinnInterface();
+  }
+
  private:
   size_t dev_id_;
   DeviceInterface* impl_;
@@ -185,8 +190,20 @@ class PADDLE_API DeviceManager {
 
   static size_t GetMaxThreadsPerBlock(const Place& place);
 
+  static size_t GetMaxSharedMemPerBlock(const Place& place);
+
+  static size_t GetMaxBlocksPerMultiProcessor(const Place& place);
+
+  static size_t GetWarpSize(const Place& place);
+
+  static size_t GetMaxRegistersPerMultiProcessor(const Place& place);
+
+  static size_t GetPreferredVectorWidth(const Place& place);
+
   static std::array<size_t, 3> GetMaxGridDimSize(const Place& place);
 
+  static std::array<size_t, 3> GetMaxBlockDimSize(const Place& place);
+
   static bool IsFloat16Supported(const Place& place);
 
   static bool IsBFloat16Supported(const Place& place);
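Because cinn_interface takes over one reserved_other_api slot, plugins compiled against the old struct simply leave it null; anything that consumes the new Device::GetCinnInterface accessor must tolerate nullptr. The sketch below is illustrative; SupportsCinnJit is hypothetical, and it assumes the pre-existing DeviceManager::GetDeviceWithPlace accessor.

// Editor's illustrative sketch: probing a device's CINN capabilities.
#include "paddle/phi/backends/device_manager.h"

bool SupportsCinnJit(const phi::Place& place) {
  phi::Device* dev = phi::DeviceManager::GetDeviceWithPlace(place);
  C_CinnInterface* cinn = dev ? dev->GetCinnInterface() : nullptr;
  // Require at least the toolchain and launch hooks before routing CINN
  // kernels through the plugin.
  return cinn && cinn->compile && cinn->module_load &&
         cinn->get_kernel_address && cinn->launch_kernel;
}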