diff --git a/backends/metax_gpu/CMakeLists.txt b/backends/metax_gpu/CMakeLists.txt
index ddadacdaf7..74267d6a90 100755
--- a/backends/metax_gpu/CMakeLists.txt
+++ b/backends/metax_gpu/CMakeLists.txt
@@ -27,6 +27,13 @@ message(STATUS "CMAKE_MODULE_PATH: ${CMAKE_MODULE_PATH}")
 set(WITH_MKLML ON)
 include(paddle)
+
+# 【修改点 1】: 添加 CINN 子目录编译
+if(WITH_CINN)
+  message(STATUS "[MetaX] CINN enabled, adding subdirectory: cinn")
+  add_subdirectory(cinn)
+endif()
+
 set(THIRD_PARTY_PATH
     "${PADDLE_SOURCE_DIR}/build/third_party"
     CACHE PATH "Third party libraries directory.")
@@ -761,6 +768,18 @@ set(CMAKE_CUCC_FLAGS "-I ${MACA_PATH}/tools/cu-bridge/include/")
 
 add_library(${TARGET_NAME} SHARED ${CUSTOM_DEVICE_SRCS})
 
+# 【修改点 2】: 添加 CINN 接口的头文件搜索路径
+# 这样 runtime/runtime.cc 里的 #include "../cinn/cinn_interface.h" 才能生效
+if(WITH_CINN)
+  # NOTE(review): the CMake option WITH_CINN is not automatically a C++
+  # preprocessor definition; without this line the `#ifdef WITH_CINN` guard
+  # in runtime/runtime.cc never fires and InitCinnInterface is never called.
+  target_compile_definitions(${TARGET_NAME} PRIVATE WITH_CINN)
+  target_include_directories(${TARGET_NAME} PRIVATE
+    "${CMAKE_CURRENT_SOURCE_DIR}/cinn"
+  )
+endif()
+
 target_include_directories(
   ${TARGET_NAME}
   PRIVATE ${PADDLE_SOURCE_DIR}
@@ -790,6 +805,14 @@ target_link_libraries(${TARGET_NAME} ${MACA_PATH}/lib/libmccl.so)
 target_link_libraries(${TARGET_NAME} ${MACA_PATH}/lib/libmcFlashAttn.so)
 target_link_libraries(${TARGET_NAME} ${MACA_PATH}/lib/libmcpti.so)
 
+# 【修改点 3】: 将 CINN 编译出的对象文件链接进最终的 .so
+# 只有这样,Plugin 加载时才能找到 InitCinnInterface 等符号
+if(WITH_CINN)
+  message(STATUS "[MetaX] Linking CINN object library")
+  # FIX(review): the generator expression had been garbled to a bare "$".
+  target_link_libraries(${TARGET_NAME} $<TARGET_OBJECTS:metax_cinn_obj>)
+endif()
+
 include_directories(BEFORE ${PADDLE_SOURCE_DIR})
 include_directories(BEFORE ${CMAKE_SOURCE_DIR}/headers)
diff --git a/backends/metax_gpu/change_patch.sh b/backends/metax_gpu/change_patch.sh
index 3fa9a64761..d043b28a78 100644
--- a/backends/metax_gpu/change_patch.sh
+++ b/backends/metax_gpu/change_patch.sh
@@ -24,6 +24,6 @@ cp -r patch/eigen3/ ../../Paddle/third_party/eigen3
 rm -r patch/eigen3
 # cp patch/tmp/mixed_vector* ../../Paddle/paddle/phi/core
 cd ../../Paddle/
-git apply --verbose ../backends/metax_gpu/patch/paddle.patch
+# FIX(review): do not hardcode a user-specific absolute path; after the
+# `cd ../../Paddle/` above, $OLDPWD is backends/metax_gpu, which yields an
+# absolute path at runtime without being machine-specific.
+git apply --verbose "$OLDPWD/patch/paddle.patch"
 cd -
 # cp -r patch/intrinsics.cuh ../../Paddle/third_party/warpctc/include/contrib/moderngpu/include/device/
diff --git a/backends/metax_gpu/cinn/CMakeLists.txt b/backends/metax_gpu/cinn/CMakeLists.txt
new file mode 100644
index 0000000000..243599d490
--- /dev/null
+++ b/backends/metax_gpu/cinn/CMakeLists.txt
@@ -0,0 +1,48 @@
+# =============================================================================
+# CINN Plugin for MetaX (MACA) Backend
+# =============================================================================
+
+# 1. 查找 MACA 路径
+# 为了在 runtime/cinn_runtime.cc 或 compiler.cc 中能 #include
+# 我们需要把沐曦 SDK 的头文件路径加进来
+set(MACA_PATH $ENV{MACA_PATH})
+if(NOT MACA_PATH)
+  set(MACA_PATH "/opt/maca") # 默认回退路径
+  message(STATUS "[MetaX CINN] MACA_PATH not set, using default: ${MACA_PATH}")
+else()
+  message(STATUS "[MetaX CINN] Found MACA_PATH: ${MACA_PATH}")
+endif()
+
+# 2. 定义源文件列表
+# 这里必须包含所有涉及到 CINN 实现的 .cc 文件
+set(CINN_SRCS
+    cinn_interface.cc # 总入口,负责 InitCinnInterface
+    compiler/compiler.cc # 【关键】负责 MetaxCompile 和 MetaxGetRuntimeSource
+    runtime/cinn_runtime.cc # 负责 MetaxModuleLoad, MetaxLaunchKernel
+    passes/pass_manager.cc # 负责 MetaxApplyCustomPass
+)
+
+# 3. 创建 OBJECT 库
+# 使用 OBJECT 模式,只编译出 .o 文件,不生成 .a 或 .so
+# 这样上一级的 CMake 可以直接抓取这些 .o 文件链接进最终的 plugin.so
+add_library(metax_cinn_obj OBJECT ${CINN_SRCS})
+
+# 4. 配置头文件搜索路径
+target_include_directories(metax_cinn_obj PRIVATE
+  ${CMAKE_CURRENT_SOURCE_DIR} # 允许引用当前目录头文件 (cinn_interface.h)
+  ${CMAKE_CURRENT_SOURCE_DIR}/../ # 允许引用上层头文件 (如 common/)
+  ${MACA_PATH}/include # 【关键】允许引用
+  ${PADDLE_SOURCE_DIR} # 【新增】必须加这个!否则找不到 paddle/phi/...
+  # Paddle 的头文件路径通常由外部环境 (Paddle_DIR) 自动包含
+)
+
+# 5. 
编译选项设置 +# CINN 组件通常依赖 C++17 标准 +set_property(TARGET metax_cinn_obj PROPERTY CXX_STANDARD 17) + +# 开启 PIC (Position Independent Code) +# 因为这些 .o 文件最终要被链接进动态库,必须开启此选项 +set_property(TARGET metax_cinn_obj PROPERTY POSITION_INDEPENDENT_CODE ON) + +# 如果 compiler.cc 需要使用 filesystem 等库,可能需要链接 stdc++fs (视 GCC 版本而定) +# 但因为是 OBJECT 库,链接操作推迟到父级进行 \ No newline at end of file diff --git a/backends/metax_gpu/cinn/cinn_interface.cc b/backends/metax_gpu/cinn/cinn_interface.cc new file mode 100644 index 0000000000..041b2e3b54 --- /dev/null +++ b/backends/metax_gpu/cinn/cinn_interface.cc @@ -0,0 +1,114 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "cinn_interface.h"
+#include <cstring>   // For memset (FIX(review): header name was stripped by tag sanitization)
+#include <iostream>  // For std::cerr
+
+namespace paddle {
+namespace custom_device {
+namespace metax {
+
+// ============================================================
+// 外部函数声明 (External Function Declarations)
+// 这些函数需要在对应的子目录文件中实现 (.cc)
+// ============================================================
+
+// --- 来自 compiler/compiler.cc ---
+// 负责调用 mxcc 将 CINN 生成的源代码编译为二进制
+extern C_Status MetaxCompile(void* dev_ptr,
+                             const char* code,
+                             char* out_path,
+                             size_t len);
+
+// 负责提供沐曦 GPU 运行时的基础源码 (类似 cuda_device_runtime.cu)
+extern const char* MetaxGetRuntimeSource(void* dev_ptr);
+
+
+// --- 来自 runtime/cinn_runtime.cc ---
+// 负责加载编译好的二进制模块 (.mx / .so)
+extern C_Status MetaxModuleLoad(void* dev_ptr,
+                                const char* path,
+                                void** mod_out);
+
+// 负责卸载模块
+extern C_Status MetaxModuleUnload(void* dev_ptr,
+                                  void* module_handle);
+
+// 负责从模块中查找核函数地址
+extern C_Status MetaxGetKernelAddress(void* dev_ptr,
+                                      void* module_handle,
+                                      const char* func_name,
+                                      void** func_out);
+
+// 负责启动核函数 (Launch Kernel)
+extern C_Status MetaxLaunchKernel(void* dev_ptr,
+                                  void* func_ptr,
+                                  void** args,
+                                  int num_args,
+                                  int gx, int gy, int gz,
+                                  int bx, int by, int bz,
+                                  int shm,
+                                  void* stream);
+
+
+// --- 来自 passes/pass_manager.cc ---
+// 负责应用自定义的图优化 Pass
+extern C_Status MetaxApplyCustomPass(void* dev_ptr,
+                                     void* ir_module);
+
+
+// ============================================================
+// 接口初始化实现 (Interface Initialization)
+// ============================================================
+
+// 静态实例,确保在插件生命周期内有效
+static C_CinnInterface metax_cinn_impl;
+
+void InitCinnInterface(C_DeviceInterface* device_interface) {
+  // 1. 安全起见,先清零
+  std::memset(&metax_cinn_impl, 0, sizeof(C_CinnInterface));
+
+  // 2. 设置结构体大小 (用于版本校验)
+  metax_cinn_impl.size = sizeof(C_CinnInterface);
+
+  // 3. 设置上下文指针 (可选)
+  // 如果你的实现需要全局状态,可以指向一个结构体;否则设为 nullptr
+  metax_cinn_impl.dev_ptr = nullptr;
+
+  // 4. 
挂载 Compiler Toolchain 接口 + metax_cinn_impl.compile = MetaxCompile; + metax_cinn_impl.get_runtime_source = MetaxGetRuntimeSource; + + // 5. 挂载 Runtime Strategy 接口 + metax_cinn_impl.module_load = MetaxModuleLoad; + metax_cinn_impl.module_unload = MetaxModuleUnload; + metax_cinn_impl.get_kernel_address = MetaxGetKernelAddress; + metax_cinn_impl.launch_kernel = MetaxLaunchKernel; + + // 6. 挂载 Compile Strategy 接口 + metax_cinn_impl.apply_custom_pass = MetaxApplyCustomPass; + + // 7. 【关键】将填好的表挂载到 Paddle 主设备接口上 + if (device_interface) { + device_interface->cinn_interface = &metax_cinn_impl; + // VLOG(3) << "[MetaX] CINN Interface initialized successfully."; + } else { + std::cerr << "[MetaX] Error: device_interface is null during CINN init." << std::endl; + } +} + +} // namespace metax +} // namespace custom_device +} // namespace paddle \ No newline at end of file diff --git a/backends/metax_gpu/cinn/cinn_interface.h b/backends/metax_gpu/cinn/cinn_interface.h new file mode 100644 index 0000000000..012e02770c --- /dev/null +++ b/backends/metax_gpu/cinn/cinn_interface.h @@ -0,0 +1,35 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+// 引入 Paddle 定义的 C 接口结构体
+#include "paddle/phi/backends/device_ext.h"
+
+namespace paddle {
+namespace custom_device {
+namespace metax {
+
+/**
+ * @brief 初始化 CINN 接口
+ * * 这个函数由 runtime.cc 中的 InitPlugin 调用。
+ * 它负责将 metax_gpu/cinn 下实现的编译器和运行时函数指针,
+ * 填充到 device_interface->cinn_interface 中。
+ * * @param device_interface Paddle Host 侧传入的设备接口指针
+ */
+void InitCinnInterface(C_DeviceInterface* device_interface);
+
+}  // namespace metax
+}  // namespace custom_device
+}  // namespace paddle
\ No newline at end of file
diff --git a/backends/metax_gpu/cinn/compiler/compiler.cc b/backends/metax_gpu/cinn/compiler/compiler.cc
new file mode 100644
index 0000000000..580d8166b7
--- /dev/null
+++ b/backends/metax_gpu/cinn/compiler/compiler.cc
@@ -0,0 +1,493 @@
+// FIX(review): the stripped `#include <...>` arguments below are reconstructed
+// from usage (std::getenv/std::system -> cstdlib, std::ofstream -> fstream,
+// std::cerr/std::cout -> iostream, std::string -> string); "for access"
+// indicates <unistd.h>. TODO: confirm against the original file.
+#include <cstdio>
+#include <cstdlib>
+#include <fstream>
+#include <iostream>
+#include <string>
+#include <unistd.h>  // for access
+
+#include "paddle/phi/backends/device_ext.h"
+
+namespace paddle {
+namespace custom_device {
+namespace metax {
+
+// ============================================================
+// 1. Runtime Source (JIT 源码头文件)
+// ============================================================
+// 这里的代码会被 CINN Codegen 生成的代码 #include 进去。
+// 它的作用是把 CINN 生成的 "cinn_custom_device_xxx" 调用映射到
+// 沐曦 (通过 cu-bridge) 的底层函数上。
+// NOTE(review): the original paste duplicated this entire preamble (a second
+// include block, namespace opening and `kMacaRuntimeSource` definition were
+// fused into the start of the raw string, which would have injected host-side
+// C++ into the JIT device source). The duplicate has been removed.
+static const char* kMacaRuntimeSource = R"MACA_SOURCE(
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+// Modified for MetaX MACA Backend Support via cu-bridge + +#pragma once + +#include +#include +#include + +/** + * \file cinn_custom_device_runtime_source.h + * 包含沐曦 (MetaX) MACA 后端生成代码所需的所有内联函数和算子。 + */ + +extern "C" { + +// 沐曦 MACA 架构参数: C500/N系列 WarpSize 为 64 +#define WARP_SIZE 64 + +#if defined(__MACACC_RTC__) +typedef signed char int8_t; +typedef unsigned char uint8_t; +#endif + +#define CINN_INT32_MAX 2147483647 +#define CINN_INT32_MIN -2147483648 + +// *************************************************************** // +// bool unary and binary operator +#define FN_BOOL(func) cinn_custom_device_##func##_bool +__device__ inline bool FN_BOOL(bitwise_and)(bool a, bool b) { return a & b; } +__device__ inline bool FN_BOOL(bitwise_or)(bool a, bool b) { return a | b; } +__device__ inline bool FN_BOOL(bitwise_xor)(bool a, bool b) { return a ^ b; } +__device__ inline bool FN_BOOL(bitwise_not)(bool a) { return !a; } + +// *************************************************************** // +// uint8 unary and binary operator +#define FN_UINT8(func) cinn_custom_device_##func##_uint8 +__device__ inline uint8_t FN_UINT8(bitwise_and)(uint8_t a, uint8_t b) { + return a & b; +} +__device__ inline uint8_t FN_UINT8(bitwise_or)(uint8_t a, uint8_t b) { + return a | b; +} +__device__ inline uint8_t FN_UINT8(bitwise_xor)(uint8_t a, uint8_t b) { + return a ^ b; +} +__device__ inline uint8_t FN_UINT8(bitwise_not)(uint8_t a) { return ~a; } +__device__ inline uint8_t FN_UINT8(logical_right_shift)(uint8_t a, uint8_t b) { + return ((uint8_t)a >> b); +} + +// *************************************************************** // +// int8 unary and binary operator +#define FN_INT8(func) cinn_custom_device_##func##_int8 +__device__ inline int8_t FN_INT8(bitwise_and)(int8_t a, int8_t b) { + return a & b; +} +__device__ inline int8_t FN_INT8(bitwise_or)(int8_t a, int8_t b) { + return a | b; +} +__device__ inline int8_t FN_INT8(bitwise_xor)(int8_t a, int8_t b) { + return a ^ b; +} +__device__ 
inline int8_t FN_INT8(bitwise_not)(int8_t a) { return ~a; } +__device__ inline int8_t FN_INT8(logical_right_shift)(int8_t a, int8_t b) { + return ((uint8_t)a >> b); +} + +// *************************************************************** // +// int16 (short1) unary and binary operator +#define FN_INT16(func) cinn_custom_device_##func##_int16 +__device__ inline int16_t FN_INT16(bitwise_and)(int16_t a, int16_t b) { + return a & b; +} +__device__ inline int16_t FN_INT16(bitwise_or)(int16_t a, int16_t b) { + return a | b; +} +__device__ inline int16_t FN_INT16(bitwise_xor)(int16_t a, int16_t b) { + return a ^ b; +} +__device__ inline int16_t FN_INT16(bitwise_not)(int16_t a) { return ~a; } +__device__ inline int16_t FN_INT16(logical_right_shift)(int16_t a, int16_t b) { + return ((uint16_t)a >> b); +} + +// *************************************************************** // +// float32 unary and binary operator (严格同步 HIP 版定义) +#define FN_FP32(func) cinn_custom_device_##func##_fp32 + +__device__ inline float FN_FP32(sin)(float x) { return sinf(x); } +__device__ inline float FN_FP32(cos)(float x) { return cosf(x); } +__device__ inline float FN_FP32(tan)(float x) { return tanf(x); } +__device__ inline float FN_FP32(sinh)(float x) { return sinhf(x); } +__device__ inline float FN_FP32(cosh)(float x) { return coshf(x); } +__device__ inline float FN_FP32(tanh)(float x) { return tanhf(x); } +__device__ inline float FN_FP32(asin)(float x) { return asinf(x); } +__device__ inline float FN_FP32(acos)(float x) { return acosf(x); } +__device__ inline float FN_FP32(atan)(float x) { return atanf(x); } +__device__ inline float FN_FP32(asinh)(float x) { return asinhf(x); } +__device__ inline float FN_FP32(acosh)(float x) { return acoshf(x); } +__device__ inline float FN_FP32(atanh)(float x) { return atanhf(x); } +__device__ inline float FN_FP32(ceil)(float x) { return ceilf(x); } +__device__ inline float FN_FP32(round)(float x) { return roundf(x); } +__device__ inline float 
FN_FP32(trunc)(float x) { return truncf(x); } +__device__ inline float FN_FP32(abs)(float x) { return fabsf(x); } +__device__ inline float FN_FP32(floor)(float x) { return floorf(x); } +__device__ inline float FN_FP32(log)(float x) { return logf(x); } +__device__ inline float FN_FP32(log2)(float x) { return log2f(x); } +__device__ inline float FN_FP32(log10)(float x) { return log10f(x); } +__device__ inline float FN_FP32(exp)(float x) { return expf(x); } +__device__ inline float FN_FP32(erf)(float x) { return erff(x); } +__device__ inline float FN_FP32(sigmoid)(float x) { + return 1.0f / (1.0f + expf(-x)); +} +__device__ inline float FN_FP32(sqrt)(float x) { return sqrtf(x); } +__device__ inline float FN_FP32(rsqrt)(float x) { return rsqrtf(x); } +__device__ inline float FN_FP32(cbrt)(float x) { return cbrtf(x); } +__device__ inline bool FN_FP32(isfinite)(float x) { return isfinite(x); } +__device__ inline bool FN_FP32(isinf)(float x) { return isinf(x); } +__device__ inline bool FN_FP32(isnan)(float x) { return isnan(x); } +__device__ inline float FN_FP32(pow)(float a, float b) { return powf(a, b); } +__device__ inline float FN_FP32(mod)(float a, float b) { + float res = fmodf(a, b); + if ((res != 0.0f) && ((res < 0.0f) != (b < 0.0f))) res += b; + return res; +} + +// *************************************************************** // +// float64 unary and binary operator (全量补全) +#define FN_FP64(func) cinn_custom_device_##func##_fp64 + +__device__ inline double FN_FP64(sin)(double x) { return sin(x); } +__device__ inline double FN_FP64(cos)(double x) { return cos(x); } +__device__ inline double FN_FP64(tan)(double x) { return tan(x); } +__device__ inline double FN_FP64(sinh)(double x) { return sinh(x); } +__device__ inline double FN_FP64(cosh)(double x) { return cosh(x); } +__device__ inline double FN_FP64(tanh)(double x) { return tanh(x); } +__device__ inline double FN_FP64(asin)(double x) { return asin(x); } +__device__ inline double FN_FP64(acos)(double x) { 
return acos(x); } +__device__ inline double FN_FP64(atan)(double x) { return atan(x); } +__device__ inline double FN_FP64(asinh)(double x) { return asinh(x); } +__device__ inline double FN_FP64(acosh)(double x) { return acosh(x); } +__device__ inline double FN_FP64(atanh)(double x) { return atanh(x); } +__device__ inline double FN_FP64(ceil)(double x) { return ceil(x); } +__device__ inline double FN_FP64(round)(double x) { return round(x); } +__device__ inline double FN_FP64(trunc)(double x) { return trunc(x); } +__device__ inline double FN_FP64(abs)(double x) { return fabs(x); } +__device__ inline double FN_FP64(floor)(double x) { return floor(x); } +__device__ inline double FN_FP64(log)(double x) { return log(x); } +__device__ inline double FN_FP64(log2)(double x) { return log2(x); } +__device__ inline double FN_FP64(log10)(double x) { return log10(x); } +__device__ inline double FN_FP64(exp)(double x) { return exp(x); } +__device__ inline double FN_FP64(erf)(double x) { return erf(x); } +__device__ inline double FN_FP64(sigmoid)(double x) { + return 1.0 / (1.0 + exp(-x)); +} +__device__ inline double FN_FP64(sqrt)(double x) { return sqrt(x); } +__device__ inline double FN_FP64(rsqrt)(double x) { return rsqrt(x); } +__device__ inline double FN_FP64(cbrt)(double x) { return cbrt(x); } +__device__ inline bool FN_FP64(isfinite)(double x) { return isfinite(x); } +__device__ inline bool FN_FP64(isinf)(double x) { return isinf(x); } +__device__ inline bool FN_FP64(isnan)(double x) { return isnan(x); } +__device__ inline double FN_FP64(pow)(double a, double b) { return pow(a, b); } +__device__ inline double FN_FP64(mod)(double a, double b) { + double res = fmod(a, b); + if ((res != 0.0) && ((res < 0.0) != (b < 0.0))) res += b; + return res; +} + +// *************************************************************** // +// int32 & int64 operator (逐行迁移) +#define FN_INT32(func) cinn_custom_device_##func##_int32 +__device__ inline int FN_INT32(left_shift)(int a, int b) { 
return a << b; }
+__device__ inline int FN_INT32(right_shift)(int a, int b) { return a >> b; }
+__device__ inline int FN_INT32(bitwise_and)(int a, int b) { return a & b; }
+__device__ inline int FN_INT32(bitwise_or)(int a, int b) { return a | b; }
+__device__ inline int FN_INT32(bitwise_xor)(int a, int b) { return a ^ b; }
+__device__ inline int FN_INT32(bitwise_not)(int a) { return ~a; }
+__device__ inline int FN_INT32(clz)(int a) { return __clz(a); }
+__device__ inline int FN_INT32(popc)(int a) { return __popc(a); }
+__device__ inline int FN_INT32(logical_right_shift)(int a, int b) {
+  return ((unsigned int)a >> b);
+}
+__device__ inline int FN_INT32(trunc)(int a) { return a; }
+__device__ inline int FN_INT32(max)(int a, int b) { return max(a, b); }
+__device__ inline int FN_INT32(min)(int a, int b) { return min(a, b); }
+// FIX(review): was "_device__" (missing leading underscore), which would fail
+// to compile in the JIT-generated device source.
+__device__ inline int FN_INT32(mod)(int a, int b) {
+  int res = a % b;
+  if ((res != 0) && ((b ^ res) < 0)) res += b;
+  return res;
+}
+
+#define FN_INT64(func) cinn_custom_device_##func##_int64
+__device__ inline int64_t FN_INT64(bitwise_and)(int64_t a, int64_t b) {
+  return a & b;
+}
+__device__ inline int64_t FN_INT64(bitwise_or)(int64_t a, int64_t b) {
+  return a | b;
+}
+__device__ inline int64_t FN_INT64(bitwise_xor)(int64_t a, int64_t b) {
+  return a ^ b;
+}
+__device__ inline int64_t FN_INT64(bitwise_not)(int64_t a) { return ~a; }
+__device__ inline int64_t FN_INT64(clz)(int64_t a) { return __clzll(a); }
+__device__ inline int64_t FN_INT64(popc)(int64_t a) { return __popcll(a); }
+__device__ inline int64_t FN_INT64(logical_right_shift)(int64_t a, int64_t b) {
+  return ((uint64_t)a >> b);
+}
+__device__ inline int64_t FN_INT64(trunc)(int64_t a) { return a; }
+__device__ inline int64_t FN_INT64(mod)(int64_t a, int64_t b) {
+  int64_t res = a % b;
+  if ((res != 0) && ((b ^ res) < 0)) res += b;
+  return res;
+}
+__device__ inline int64_t FN_INT64(pow)(int64_t a, int64_t b) {
+  double res = pow(__ll2double_rd(a), __ll2double_rd(b));
+  return 
__double2ll_rn(res); +} + +// *************************************************************** // +// bfloat16 unary and binary operator +#ifdef CINN_CONSTOM_DEVICE_BF16 +// todo: custom_device bf16 +#endif + +// *************************************************************** // +// float16 (half) operator +#define FN_FP16(func) cinn_custom_device_##func##_fp16 +__device__ inline half FN_FP16(ceil)(half x) { return hceil(x); } +__device__ inline half FN_FP16(floor)(half x) { return hfloor(x); } +__device__ inline half FN_FP16(round)(half x) { + return half(FN_FP32(round)(static_cast(x))); +} +__device__ inline half FN_FP16(trunc)(half x) { + return half(htrunc(x.to_half())); +} +__device__ inline half FN_FP16(sin)(half x) { return hsin(x); } +__device__ inline half FN_FP16(cos)(half x) { return hcos(x); } +__device__ inline half FN_FP16(exp)(half x) { return hexp(x); } +__device__ inline half FN_FP16(log)(half x) { return hlog(x); } +__device__ inline half FN_FP16(log2)(half x) { + return half(hlog2(x.to_half())); +} +__device__ inline half FN_FP16(log10)(half x) { + return half(hlog10(x.to_half())); +} +__device__ inline half FN_FP16(sqrt)(half x) { return hsqrt(x); } +__device__ inline half FN_FP16(rsqrt)(half x) { return hrsqrt(x); } + +/* TODO(xuyuhan) +__device__ inline float16 FN_FP16(cbrt)(float16 x) { + return float16(FN_FP32(cbrt)(static_cast(x))); +} + +__device__ inline float16 FN_FP16(abs)(float16 x) { + return cinn::common::abs(x); +} + +__device__ inline bool FN_FP16(isnan)(float16 x) { + return cinn::common::isnan(x); +} +__device__ inline bool FN_FP16(isinf)(float16 x) { + return cinn::common::isinf(x); +} +__device__ inline bool FN_FP16(isfinite)(float16 x) { + return cinn::common::isfinite(x); +} + +__device__ inline float16 FN_FP16(erf)(float16 x) { + return float16(FN_FP32(erf)(static_cast(x))); +} + +__device__ inline float16 FN_FP16(tan)(float16 x) { + return float16(FN_FP32(tan)(static_cast(x))); +} +__device__ inline float16 
FN_FP16(sinh)(float16 x) { + return float16(FN_FP32(sinh)(static_cast(x))); +} +__device__ inline float16 FN_FP16(cosh)(float16 x) { + return float16(FN_FP32(cosh)(static_cast(x))); +} +__device__ inline float16 FN_FP16(tanh)(float16 x) { + return float16(FN_FP32(tanh)(static_cast(x))); +} +__device__ inline float16 FN_FP16(asin)(float16 x) { + return float16(FN_FP32(asin)(static_cast(x))); +} +__device__ inline float16 FN_FP16(acos)(float16 x) { + return float16(FN_FP32(acos)(static_cast(x))); +} +__device__ inline float16 FN_FP16(atan)(float16 x) { + return float16(FN_FP32(atan)(static_cast(x))); +} +__device__ inline float16 FN_FP16(asinh)(float16 x) { + return float16(FN_FP32(asinh)(static_cast(x))); +} +__device__ inline float16 FN_FP16(acosh)(float16 x) { + return float16(FN_FP32(acosh)(static_cast(x))); +} +__device__ inline float16 FN_FP16(atanh)(float16 x) { + return float16(FN_FP32(atanh)(static_cast(x))); +} + +__device__ inline float16 FN_FP16(sigmoid)(float16 x) { + return float16(FN_FP32(sigmoid)(static_cast(x))); +} + +__device__ inline float16 FN_FP16(mod)(float16 a, float16 b) { + return float16(FN_FP32(mod)(static_cast(a), static_cast(b))); +} +__device__ inline float16 FN_FP16(pow)(float16 a, float16 b) { + return float16(FN_FP32(pow)(static_cast(a), static_cast(b))); +} + */ +#endif + +// *************************************************************** // +// Reduce Macros & Warp/Block Operations +// (此处省略展开后的 200 行重复归约逻辑,但在最终交付文件中应包含全量宏展开) + +#define CINN_WARP_SHUFFLE_INTERNAL_IMPL(REDUCE_TYPE, INITIAL_VALUE, DTYPE) \ + __device__ inline DTYPE cinn_warp_shuffle_##REDUCE_TYPE##_internal( \ + const DTYPE value) { \ + DTYPE tmp_val = value; \ + unsigned int mask = __activemask(); \ + int lane_count = __popc(mask); \ + if (lane_count < WARP_SIZE) { \ + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) { \ + DTYPE shfl_res = __shfl_down_sync(mask, tmp_val, offset, WARP_SIZE); \ + if ((threadIdx.x & (WARP_SIZE - 1)) + offset >= lane_count) { 
\ + shfl_res = (DTYPE)(INITIAL_VALUE); \ + } \ + tmp_val = cinn_##REDUCE_TYPE(tmp_val, shfl_res); \ + } \ + } else { \ + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) { \ + tmp_val = cinn_##REDUCE_TYPE( \ + tmp_val, __shfl_xor_sync(mask, tmp_val, offset, WARP_SIZE)); \ + } \ + } \ + return tmp_val; \ + } + +// *************************************************************** // +// Find and Index Operations +#define CINN_CUSTOM_DEVICE_FIND_KERNEL(buf, size, num, begin, stride) \ + do { \ + for (int i = (size - 1) * stride + begin; i >= begin; i -= stride) { \ + if (buf[i] == num) return (i - begin) / stride; \ + } \ + return -1; \ + } while (0) + +__device__ inline int cinn_custom_device_find_int(const int *buf, int size, int num) { + CINN_CUSTOM_DEVICE_FIND_KERNEL(buf, size, num, 0, 1); +} + +// ... 按照 cinn_hip_runtime_source.h 的 find_float, find_int_nd 等全量补全 ... + +} // end extern "C" +)MACA_SOURCE"; + +const char* MetaxGetRuntimeSource(void* dev_ptr) { + return kMacaRuntimeSource; +} + +// ============================================================ +// 2. 辅助函数:获取编译器路径和 Include 路径 +// ============================================================ +std::string GetMacaPath() { + const char* maca_path_env = std::getenv("MACA_PATH"); + if (maca_path_env) { + return std::string(maca_path_env); + } + return "/opt/maca"; // 默认路径,参考自 compile.sh +} + +// ============================================================ +// 3. 核心实现:MetaxCompile +// 对应 compiler_custom_device.cc 中的 CompileWithCdcc 逻辑 +// ============================================================ +C_Status MetaxCompile(void* dev_ptr, const char* code, char* out_path, size_t len) { + std::string maca_path = GetMacaPath(); + std::string mxcc_cmd = maca_path + "/bin/mxcc"; + + // 1. 准备源文件路径 + // out_path 是 CINN 传入的期望输出路径 (通常是一个临时文件名,无后缀或 .so) + // 我们需要在其基础上加后缀来保存源码 + std::string src_path = std::string(out_path) + ".cu"; // 沐曦通常识别 .cu + + // 2. 
将源码写入文件 + { + std::ofstream src_file(src_path); + if (!src_file.is_open()) { + std::cerr << "[MetaX] Failed to open temp file: " << src_path << std::endl; + return C_Status::C_FAILED; + } + src_file << code; + src_file.close(); + } + + // 3. 构建编译命令 + // 参考 compiler_custom_device.cc 的逻辑,但是适配 mxcc + std::string cmd = mxcc_cmd; + + // 优化选项 + cmd += " -O3"; + // C++ 标准 (CINN 生成的代码通常依赖 C++14/17) + cmd += " -std=c++17"; + // 忽略部分警告 + cmd += " -w"; + + // 【关键配置】生成 Fatbin 或 Cubin + // 因为 Runtime 中使用的是 cuModuleLoad/macaModuleLoad,它需要 Device Binary + // 如果用 -shared 生成 .so,cuModuleLoad 是加载不了的。 + // mxcc 兼容 nvcc,使用 --fatbin 可以生成包含了 PTX 和 ELF 的混合二进制 + cmd += " --fatbin"; + + // 指定 Include 路径 + // 必须包含 maca_runtime.h 所在的目录 + cmd += " -I" + maca_path + "/include"; + cmd += " -I" + maca_path + "/tools/cu-bridge/include"; + + // 如果需要 CINN 的 runtime header (比如 cinn_cuda_runtime_source.cuh 里依赖的库) + // 通常通过 code 里的 raw string 解决了,或者在这里加 -I + + // 指定 GPU 架构 (可选,但推荐) + // 如果不指定,mxcc 可能会编译为默认架构。建议根据实际机器获取,或者由 cmake 传入 + // 这里先省略,mxcc 通常会自动识别当前架构或生成通用 fatbin + + // 输入输出 + cmd += " -o " + std::string(out_path); + cmd += " " + src_path; + + // 4. 执行编译 + // VLOG(4) << "[MetaX] JIT Compile Command: " << cmd; + std::cout << "[MetaX Debug] Cmd: " << cmd << std::endl; // 调试用 + + int ret = std::system(cmd.c_str()); + + if (ret != 0) { + std::cerr << "[MetaX] JIT Compilation Failed!" 
<< std::endl; + std::cerr << "Command: " << cmd << std::endl; + // 调试时可以把源码打印出来看哪里错了 + // std::cerr << "Source: \n" << code << std::endl; + return C_Status::C_FAILED; + } + + return C_Status::C_SUCCESS; +} + +} // namespace metax +} // namespace custom_device +} // namespace paddle \ No newline at end of file diff --git a/backends/metax_gpu/cinn/passes/pass_manager.cc b/backends/metax_gpu/cinn/passes/pass_manager.cc new file mode 100644 index 0000000000..a2a90a1430 --- /dev/null +++ b/backends/metax_gpu/cinn/passes/pass_manager.cc @@ -0,0 +1,17 @@ +#include "paddle/phi/backends/device_ext.h" +#include + +namespace paddle { +namespace custom_device { +namespace metax { + +// 负责应用自定义的图优化 Pass +// 目前阶段先留空,直接返回成功 +C_Status MetaxApplyCustomPass(void* dev_ptr, void* ir_module) { + // VLOG(3) << "[MetaX] MetaxApplyCustomPass called (No-op)"; + return C_Status::C_SUCCESS; +} + +} // namespace metax +} // namespace custom_device +} // namespace paddle \ No newline at end of file diff --git a/backends/metax_gpu/cinn/runtime/cinn_runtime.cc b/backends/metax_gpu/cinn/runtime/cinn_runtime.cc new file mode 100644 index 0000000000..3b24de402e --- /dev/null +++ b/backends/metax_gpu/cinn/runtime/cinn_runtime.cc @@ -0,0 +1,70 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "paddle/phi/backends/device_ext.h"
+// FIX(review): the stripped `#include <...>` arguments are reconstructed from
+// usage; <cuda.h> is required for CUmodule / cuModuleLoad / cuLaunchKernel
+// (resolved to MACA via cu-bridge). TODO: confirm the remaining headers
+// against the original file.
+#include <cuda.h>
+#include <cstring>
+#include <iostream>
+#include <string>
+
+namespace paddle {
+namespace custom_device {
+namespace metax {
+
+// 【实现1】加载模块:相当于 cudaModuleLoad
+C_Status MetaxModuleLoad(void* dev_ptr, const char* path, void** mod_out) {
+  CUmodule module;
+  CUresult err = cuModuleLoad(&module, path);
+  if (err != CUDA_SUCCESS) return C_Status::C_FAILED;
+
+  *mod_out = (void*)module;
+  return C_Status::C_SUCCESS;
+}
+
+// 【实现2】卸载模块
+C_Status MetaxModuleUnload(void* dev_ptr, void* module_handle) {
+  cuModuleUnload((CUmodule)module_handle);
+  return C_Status::C_SUCCESS;
+}
+
+// 【实现3】获取函数地址:相当于 cudaModuleGetFunction
+C_Status MetaxGetKernelAddress(void* dev_ptr, void* module_handle, const char* func_name, void** func_out) {
+  CUfunction func;
+  CUresult err = cuModuleGetFunction(&func, (CUmodule)module_handle, func_name);
+  if (err != CUDA_SUCCESS) return C_Status::C_FAILED;
+
+  *func_out = (void*)func;
+  return C_Status::C_SUCCESS;
+}
+
+// 【实现4】启动核函数:相当于 cudaLaunchKernel
+C_Status MetaxLaunchKernel(void* dev_ptr, void* func_ptr, void** args, int num_args,
+                           int gx, int gy, int gz,
+                           int bx, int by, int bz,
+                           int shm, void* stream) {
+  // 注意:args 这里通常是 void*[],可能需要处理一下参数封装
+  CUresult err = cuLaunchKernel((CUfunction)func_ptr,
+                                gx, gy, gz,
+                                bx, by, bz,
+                                shm,
+                                (CUstream)stream,
+                                args,
+                                nullptr);
+  if (err != CUDA_SUCCESS) return C_Status::C_FAILED;
+  return C_Status::C_SUCCESS;
+}
+
+}  // namespace metax
+}  // namespace custom_device
+}  // namespace paddle
\ No newline at end of file
diff --git a/backends/metax_gpu/compile.sh b/backends/metax_gpu/compile.sh
index e35e84256d..65d3da3b85 100644
--- a/backends/metax_gpu/compile.sh
+++ b/backends/metax_gpu/compile.sh
@@ -16,7 +16,7 @@
 # limitations under the License. 
 export MACA_PATH=/opt/maca
-export CUDA_PATH=/workspace/cuda-11.7/
+# FIX(review): machine-specific default, now overridable from the environment.
+export CUDA_PATH=${CUDA_PATH:-/data/train_nfs/baidu/m01097/cuda-11.7}
 export PATH=${CUDA_PATH}/bin:${PATH}
 export CUCC_PATH=${MACA_PATH}/tools/cu-bridge
 export PATH=${PATH}:${CUCC_PATH}/tools:${CUCC_PATH}/bin
@@ -31,7 +31,7 @@ fi
 echo "make_maca"
 cd build
-cmake_maca .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DPython3_EXECUTABLE=$(which python3) -DWITH_GPU=ON
+cmake_maca .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DPython3_EXECUTABLE=$(which python3) -DWITH_GPU=ON -DWITH_CINN=ON
 make_maca -j18 VERBOSE=1
diff --git a/backends/metax_gpu/runtime/runtime.cc b/backends/metax_gpu/runtime/runtime.cc
index 54182de526..7f2b1cd28f 100644
--- a/backends/metax_gpu/runtime/runtime.cc
+++ b/backends/metax_gpu/runtime/runtime.cc
@@ -55,6 +55,7 @@
 #include "paddle/phi/core/platform/profiler/utils.h"
 #include "passes/pattern_passes.h"
 #include "runtime/process_cupti_data.cc" //NOLINT
+#include "../cinn/cinn_interface.h"
 #include "unsupported/Eigen/CXX11/Tensor"
 
 #define MEMORY_FRACTION 0.5f
@@ -420,6 +421,57 @@ C_Status GetMaxThreadsPerBlock(const C_Device device,
   return C_SUCCESS;
 }
 
+C_Status GetMaxSharedMemPerBlock(const C_Device device,
+                                 size_t *shared_mem_per_block) {
+  int id = device->id;
+  int count = 0;
+  cudaError_t status =
+      cudaDeviceGetAttribute(&count, cudaDevAttrMaxSharedMemoryPerBlock, id);
+  // FIX(review): removed a stray duplicate "shared_mem_per_block = count;"
+  // which assigned an int to the pointer itself.
+  *shared_mem_per_block = count;
+  return C_SUCCESS;
+}
+
+C_Status GetWarpSize(const C_Device device,
+                     size_t *warp_size) {
+  int id = device->id;
+  int size = 0;
+  cudaError_t status =
+      cudaDeviceGetAttribute(&size, cudaDevAttrWarpSize, id);
+  *warp_size = size;
+  return C_SUCCESS;
+}
+
+C_Status GetMaxRegistersPerMultiProcessor(const C_Device device,
+                                          size_t *registers_per_mp) {
+  int id = device->id;
+  int count = 0;
+  cudaError_t status =
+      cudaDeviceGetAttribute(&count, cudaDevAttrMaxRegistersPerMultiprocessor, id);
+  *registers_per_mp = count;
+  return C_SUCCESS;
+}
+
+C_Status GetPreferredVectorWidth(const C_Device device,
+                                 size_t *vector_alignment) {
+  int id = device->id;
+  // NOTE(review): hardcoded value; confirm 128 matches the MACA preferred
+  // vector width rather than querying a device attribute.
+  *vector_alignment = 128;
+  return C_SUCCESS;
+}
+
+C_Status GetMaxBlocksPerMultiProcessor(const C_Device device,
+                                       size_t *blocks_per_mp) {
+  int id = device->id;
+  int count = 0;
+  cudaError_t status =
+      cudaDeviceGetAttribute(&count, cudaDevAttrMaxBlocksPerMultiprocessor, id);
+  *blocks_per_mp = count;
+  return C_SUCCESS;
+}
+
 C_Status GetMaxGridDimSize(const C_Device device,
                            std::array<size_t, 3> *grid_dim_size) {
   int id = device->id;
@@ -436,6 +489,22 @@ C_Status GetMaxGridDimSize(const C_Device device,
   return C_SUCCESS;
 }
 
+C_Status GetMaxBlockDimSize(const C_Device device,
+                            std::array<size_t, 3> *block_dim_size) {
+  int id = device->id;
+  std::array<size_t, 3> ret = {};
+  int size;
+  auto error_code_x = cudaDeviceGetAttribute(&size, cudaDevAttrMaxBlockDimX, id);
+  ret[0] = size;
+  auto error_code_y = cudaDeviceGetAttribute(&size, cudaDevAttrMaxBlockDimY, id);
+  ret[1] = size;
+  auto error_code_z = cudaDeviceGetAttribute(&size, cudaDevAttrMaxBlockDimZ, id);
+  ret[2] = size;
+
+  *block_dim_size = ret;
+  return C_SUCCESS;
+}
+
 C_Status InitDevice(const C_Device device) {
   if (!device || device->id < 0) {
     return C_ERROR;
@@ -1469,7 +1538,16 @@ void InitPlugin(CustomRuntimeParams *params) {
   params->interface->get_multi_process = GetMultiProcessors;
   params->interface->get_max_threads_per_mp = GetMaxThreadsPerMultiProcessor;
   params->interface->get_max_threads_per_block = GetMaxThreadsPerBlock;
+  // FIX(review): GetMaxSharedMemPerBlock was wired into get_max_registers_per_mp
+  // (and then silently overwritten three lines below); route the shared-memory
+  // query to its own slot instead.
+  params->interface->get_max_shared_mem_per_block = GetMaxSharedMemPerBlock;
+  params->interface->get_max_blocks_per_mp = GetMaxBlocksPerMultiProcessor;
+  params->interface->get_warp_size = GetWarpSize;
+  params->interface->get_max_registers_per_mp = GetMaxRegistersPerMultiProcessor;
+  params->interface->get_vector_width = GetPreferredVectorWidth;
   params->interface->get_max_grid_dim_size = GetMaxGridDimSize;
+  params->interface->get_max_block_dim_size = GetMaxBlockDimSize;
 
   params->interface->init_device = InitDevice;
   params->interface->set_device = SetDevice;
@@ -1551,4 +1626,9 @@ void InitPlugin(CustomRuntimeParams *params) {
   // PIR pass pipeline
   params->pir_default_passes = reinterpret_cast<void **>(
       const_cast<std::vector<std::string> *>(GetPirMetaxGpuPasses()));
+
+  // CINN interface init (WITH_CINN is defined via target_compile_definitions
+  // in backends/metax_gpu/CMakeLists.txt when -DWITH_CINN=ON)
+#ifdef WITH_CINN
+  paddle::custom_device::metax::InitCinnInterface(params->interface);
+#endif
 }