Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions backends/metax_gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ message(STATUS "CMAKE_MODULE_PATH: ${CMAKE_MODULE_PATH}")
set(WITH_MKLML ON)

include(paddle)

# 【修改点 1】: 添加 CINN 子目录编译
if(WITH_CINN)
message(STATUS "[MetaX] CINN enabled, adding subdirectory: cinn")
add_subdirectory(cinn)
endif()

set(THIRD_PARTY_PATH
"${PADDLE_SOURCE_DIR}/build/third_party"
CACHE PATH "Third party libraries directory.")
Expand Down Expand Up @@ -761,6 +768,14 @@ set(CMAKE_CUCC_FLAGS "-I ${MACA_PATH}/tools/cu-bridge/include/")

add_library(${TARGET_NAME} SHARED ${CUSTOM_DEVICE_SRCS})

# 【修改点 2】: 添加 CINN 接口的头文件搜索路径
# 这样 runtime/runtime.cc 里的 #include "../cinn/cinn_interface.h" 才能生效
if(WITH_CINN)
target_include_directories(${TARGET_NAME} PRIVATE
"${CMAKE_CURRENT_SOURCE_DIR}/cinn"
)
endif()

target_include_directories(
${TARGET_NAME}
PRIVATE ${PADDLE_SOURCE_DIR}
Expand Down Expand Up @@ -790,6 +805,13 @@ target_link_libraries(${TARGET_NAME} ${MACA_PATH}/lib/libmccl.so)
target_link_libraries(${TARGET_NAME} ${MACA_PATH}/lib/libmcFlashAttn.so)
target_link_libraries(${TARGET_NAME} ${MACA_PATH}/lib/libmcpti.so)

# 【修改点 3】: 将 CINN 编译出的对象文件链接进最终的 .so
# 只有这样,Plugin 加载时才能找到 InitCinnInterface 等符号
if(WITH_CINN)
message(STATUS "[MetaX] Linking CINN object library")
target_link_libraries(${TARGET_NAME} $<TARGET_OBJECTS:metax_cinn_obj>)
endif()

include_directories(BEFORE ${PADDLE_SOURCE_DIR})
include_directories(BEFORE ${CMAKE_SOURCE_DIR}/headers)

Expand Down
2 changes: 1 addition & 1 deletion backends/metax_gpu/change_patch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ cp -r patch/eigen3/ ../../Paddle/third_party/eigen3
rm -r patch/eigen3
# cp patch/tmp/mixed_vector* ../../Paddle/paddle/phi/core
cd ../../Paddle/
git apply --verbose ../backends/metax_gpu/patch/paddle.patch
git apply --verbose /home/BD/xuyuhan/PaddleCustomDevice/backends/metax_gpu/patch/paddle.patch
cd -
# cp -r patch/intrinsics.cuh ../../Paddle/third_party/warpctc/include/contrib/moderngpu/include/device/
48 changes: 48 additions & 0 deletions backends/metax_gpu/cinn/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# =============================================================================
# CINN Plugin for MetaX (MACA) Backend
# =============================================================================

# 1. 查找 MACA 路径
# 为了在 runtime/cinn_runtime.cc 或 compiler.cc 中能 #include <maca_runtime.h>
# 我们需要把沐曦 SDK 的头文件路径加进来
set(MACA_PATH $ENV{MACA_PATH})
if(NOT MACA_PATH)
set(MACA_PATH "/opt/maca") # 默认回退路径
message(STATUS "[MetaX CINN] MACA_PATH not set, using default: ${MACA_PATH}")
else()
message(STATUS "[MetaX CINN] Found MACA_PATH: ${MACA_PATH}")
endif()

# 2. 定义源文件列表
# 这里必须包含所有涉及到 CINN 实现的 .cc 文件
set(CINN_SRCS
cinn_interface.cc # 总入口,负责 InitCinnInterface
compiler/compiler.cc # 【关键】负责 MetaxCompile 和 MetaxGetRuntimeSource
runtime/cinn_runtime.cc # 负责 MetaxModuleLoad, MetaxLaunchKernel
passes/pass_manager.cc # 负责 MetaxApplyCustomPass
)

# 3. 创建 OBJECT 库
# 使用 OBJECT 模式,只编译出 .o 文件,不生成 .a 或 .so
# 这样上一级的 CMake 可以直接抓取这些 .o 文件链接进最终的 plugin.so
add_library(metax_cinn_obj OBJECT ${CINN_SRCS})

# 4. 配置头文件搜索路径
target_include_directories(metax_cinn_obj PRIVATE
${CMAKE_CURRENT_SOURCE_DIR} # 允许引用当前目录头文件 (cinn_interface.h)
${CMAKE_CURRENT_SOURCE_DIR}/../ # 允许引用上层头文件 (如 common/)
${MACA_PATH}/include # 【关键】允许引用 <maca_runtime.h>
${PADDLE_SOURCE_DIR} # 【新增】必须加这个!否则找不到 paddle/phi/...
# Paddle 的头文件路径通常由外部环境 (Paddle_DIR) 自动包含
)

# 5. 编译选项设置
# CINN 组件通常依赖 C++17 标准
set_property(TARGET metax_cinn_obj PROPERTY CXX_STANDARD 17)

# 开启 PIC (Position Independent Code)
# 因为这些 .o 文件最终要被链接进动态库,必须开启此选项
set_property(TARGET metax_cinn_obj PROPERTY POSITION_INDEPENDENT_CODE ON)

# 如果 compiler.cc 需要使用 filesystem 等库,可能需要链接 stdc++fs (视 GCC 版本而定)
# 但因为是 OBJECT 库,链接操作推迟到父级进行
114 changes: 114 additions & 0 deletions backends/metax_gpu/cinn/cinn_interface.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn_interface.h"
#include <cstring> // For memset
#include <iostream>

namespace paddle {
namespace custom_device {
namespace metax {

// ============================================================
// 外部函数声明 (External Function Declarations)
// 这些函数需要在对应的子目录文件中实现 (.cc)
// ============================================================

// --- 来自 compiler/compiler.cc ---
// 负责调用 mxcc 将 CINN 生成的源代码编译为二进制
extern C_Status MetaxCompile(void* dev_ptr,
const char* code,
char* out_path,
size_t len);

// 负责提供沐曦 GPU 运行时的基础源码 (类似 cuda_device_runtime.cu)
extern const char* MetaxGetRuntimeSource(void* dev_ptr);


// --- 来自 runtime/cinn_runtime.cc ---
// 负责加载编译好的二进制模块 (.mx / .so)
extern C_Status MetaxModuleLoad(void* dev_ptr,
const char* path,
void** mod_out);

// 负责卸载模块
extern C_Status MetaxModuleUnload(void* dev_ptr,
void* module_handle);

// 负责从模块中查找核函数地址
extern C_Status MetaxGetKernelAddress(void* dev_ptr,
void* module_handle,
const char* func_name,
void** func_out);

// 负责启动核函数 (Launch Kernel)
extern C_Status MetaxLaunchKernel(void* dev_ptr,
void* func_ptr,
void** args,
int num_args,
int gx, int gy, int gz,
int bx, int by, int bz,
int shm,
void* stream);


// --- 来自 passes/pass_manager.cc ---
// 负责应用自定义的图优化 Pass
extern C_Status MetaxApplyCustomPass(void* dev_ptr,
void* ir_module);


// ============================================================
// 接口初始化实现 (Interface Initialization)
// ============================================================

// 静态实例,确保在插件生命周期内有效
static C_CinnInterface metax_cinn_impl;

void InitCinnInterface(C_DeviceInterface* device_interface) {
// 1. 安全起见,先清零
std::memset(&metax_cinn_impl, 0, sizeof(C_CinnInterface));

// 2. 设置结构体大小 (用于版本校验)
metax_cinn_impl.size = sizeof(C_CinnInterface);

// 3. 设置上下文指针 (可选)
// 如果你的实现需要全局状态,可以指向一个结构体;否则设为 nullptr
metax_cinn_impl.dev_ptr = nullptr;

// 4. 挂载 Compiler Toolchain 接口
metax_cinn_impl.compile = MetaxCompile;
metax_cinn_impl.get_runtime_source = MetaxGetRuntimeSource;

// 5. 挂载 Runtime Strategy 接口
metax_cinn_impl.module_load = MetaxModuleLoad;
metax_cinn_impl.module_unload = MetaxModuleUnload;
metax_cinn_impl.get_kernel_address = MetaxGetKernelAddress;
metax_cinn_impl.launch_kernel = MetaxLaunchKernel;

// 6. 挂载 Compile Strategy 接口
metax_cinn_impl.apply_custom_pass = MetaxApplyCustomPass;

// 7. 【关键】将填好的表挂载到 Paddle 主设备接口上
if (device_interface) {
device_interface->cinn_interface = &metax_cinn_impl;
// VLOG(3) << "[MetaX] CINN Interface initialized successfully.";
} else {
std::cerr << "[MetaX] Error: device_interface is null during CINN init." << std::endl;
}
}

} // namespace metax
} // namespace custom_device
} // namespace paddle
35 changes: 35 additions & 0 deletions backends/metax_gpu/cinn/cinn_interface.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

// 引入 Paddle 定义的 C 接口结构体
#include "paddle/phi/backends/device_ext.h"

namespace paddle {
namespace custom_device {
namespace metax {

/**
* @brief 初始化 CINN 接口
* * 这个函数由 runtime.cc 中的 InitPlugin 调用。
* 它负责将 metax_gpu/cinn 下实现的编译器和运行时函数指针,
* 填充到 device_interface->cinn_interface 中。
* * @param device_interface Paddle Host 侧传入的设备接口指针
*/
void InitCinnInterface(C_DeviceInterface* device_interface);

} // namespace metax
} // namespace custom_device
} // namespace paddle
Loading