From 3014abe7236a99257085ef8037e81e3d01538c55 Mon Sep 17 00:00:00 2001 From: zhangyunze Date: Fri, 6 Feb 2026 03:50:30 +0000 Subject: [PATCH 1/2] fix: fix runtime to support multi thread --- include/core/blob.h | 3 +- include/core/runtime.h | 15 +- src/core/runtime.cc | 115 +++++-- src/core/tensor.cc | 2 +- test/kernels/test_elementwise_kernel.cc | 416 +++++++++++++++++++++--- 5 files changed, 471 insertions(+), 80 deletions(-) diff --git a/include/core/blob.h b/include/core/blob.h index bb5dd44..2c2ec2d 100644 --- a/include/core/blob.h +++ b/include/core/blob.h @@ -13,9 +13,10 @@ class BlobObj { BlobObj(void *ptr) : ptr(ptr) {} BlobObj(BlobObj &other) = delete; BlobObj &operator=(BlobObj const &) = delete; - ~BlobObj() {}; + ~BlobObj(){}; template T getPtr() const { return reinterpret_cast(ptr); } + void *getRawDataPtr() const { return ptr; } }; } // namespace infini diff --git a/include/core/runtime.h b/include/core/runtime.h index 529db35..6cb6c57 100644 --- a/include/core/runtime.h +++ b/include/core/runtime.h @@ -15,6 +15,8 @@ struct ContextObj { infiniDevice_t device = INFINI_DEVICE_CPU; int deviceId = 0; infinirtStream_t stream = nullptr; + void *workspace = nullptr; + size_t workspaceSize = 0; }; using Context = Ref; @@ -24,13 +26,13 @@ class RuntimeObj : public std::enable_shared_from_this { mutable std::unordered_map threadContexts; mutable std::shared_mutex ctx_mutex; static thread_local Context tls_context_cache; - size_t workspaceSize; - void *workspace; + static thread_local std::thread::id tls_thread_id; public: - RuntimeObj() { allocworkspace(); } + RuntimeObj() = default; RuntimeObj(const RuntimeObj &) = delete; RuntimeObj &operator=(const RuntimeObj &) = delete; + ~RuntimeObj(); // 每个线程唯一的 Runtime static Runtime &getInstance(); @@ -40,6 +42,7 @@ class RuntimeObj : public std::enable_shared_from_this { // 获取活跃 Context Context getCurrentThreadContext() const; + // 切换当前线程的设备 void setCurrentDevice(infiniDevice_t device, int deviceId = 0); static void init(); @@ -56,15 +59,15 @@ class RuntimeObj : public std::enable_shared_from_this { infinirtMemcpyKind_t kind, infinirtStream_t stream); void *mallocAsync(size_t size, infinirtStream_t stream); void freeAsync(void *ptr, infinirtStream_t stream); + // 同步当前线程的设备 void synchronize() const; + // 获取当前 Context 的 workspace size_t getWorkspaceSize() const; void *getWorkspace(size_t size) const; bool isCpu() const; - // string toString() const; - private: - void allocworkspace(); + // void initWorkspace(size_t size = 7ll << 30); }; } // namespace infini #endif // RUNTIME_H diff --git a/src/core/runtime.cc b/src/core/runtime.cc index 3ffd0ab..0022432 100644 --- a/src/core/runtime.cc +++ b/src/core/runtime.cc @@ -2,49 +2,88 @@ namespace infini { thread_local Context RuntimeObj::tls_context_cache = nullptr; +thread_local std::thread::id RuntimeObj::tls_thread_id; Runtime &RuntimeObj::getInstance() { static Runtime instance = make_ref(); return instance; } +RuntimeObj::~RuntimeObj() { + // 清理所有线程的 Context + std::unique_lock lock(ctx_mutex); + // 只清空map,不手动释放CUDA资源 + // CUDA运行时会在程序退出时自动清理所有资源 + threadContexts.clear(); +} + void RuntimeObj::initThreadContext(infiniDevice_t device, int deviceId) { - // thread_local Context currentCtx; - if (tls_context_cache) { - return; + auto current_tid = std::this_thread::get_id(); + // 检测线程复用 + if (tls_context_cache && tls_thread_id == current_tid && + tls_context_cache->device == device && + tls_context_cache->deviceId == deviceId) { + return; // 已初始化且设备相同,无需重新初始化 } - infinirtStream_t stream = nullptr; + CHECK_INFINI_ERROR(infinirtSetDevice(device, deviceId)); + + // 创建新的 stream + infinirtStream_t stream = nullptr; CHECK_INFINI_ERROR(infinirtStreamCreate(&stream)); + + // 创建新的 Context Context ctx = std::make_shared(); ctx->device = device; ctx->deviceId = deviceId; ctx->stream = stream; + ctx->workspaceSize = 7ll << 30; // 7GB + ctx->workspace = nullptr; + + // 更新缓存和全局 map tls_context_cache = ctx; + tls_thread_id = current_tid; + { std::unique_lock lock(ctx_mutex); - threadContexts[std::this_thread::get_id()] = ctx; + threadContexts[current_tid] = ctx; } + CHECK_INFINI_ERROR(infinirtMalloc(&ctx->workspace, ctx->workspaceSize)); } Context RuntimeObj::getCurrentThreadContext() const { - // thread_local Context currentCtx; - if (tls_context_cache) { + auto current_tid = std::this_thread::get_id(); + + // 检查缓存有效性 + if (tls_context_cache && tls_thread_id == current_tid) { return tls_context_cache; } + + // 从全局 map 查找 { std::shared_lock lock(ctx_mutex); - auto it = threadContexts.find(std::this_thread::get_id()); + auto it = threadContexts.find(current_tid); if (it != threadContexts.end()) { tls_context_cache = it->second; + tls_thread_id = current_tid; return it->second; } } - throw std::runtime_error("Thread context not initialized!"); + + throw std::runtime_error( + "Thread context not initialized! Call initThreadContext() first."); } void RuntimeObj::setCurrentDevice(infiniDevice_t device, int deviceId) { - CHECK_INFINI_ERROR(infinirtSetDevice(device, deviceId)); + auto ctx = getCurrentThreadContext(); + + // 如果设备相同,直接返回 + if (ctx->device == device && ctx->deviceId == deviceId) { + return; + } + + // 重新初始化 Context(force=true) + initThreadContext(device, deviceId); } void RuntimeObj::init() { CHECK_INFINI_ERROR(infinirtInit()); } @@ -54,13 +93,14 @@ void RuntimeObj::getAllDeviceCount(int *count_array) { } void RuntimeObj::run(const Graph &graph) const { + auto ctx = getCurrentThreadContext(); + IT_ASSERT(graph->checkBeforRun()); // TODO: 目前仅支持单卡,后续支持多卡 const auto &kernelRegistry = KernelRegistry::getInstance(); for (auto &op : graph->getOperators()) { - auto context = getCurrentThreadContext(); - auto device = context->device; - auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()}; + auto kernelAttrs = + KernelAttrs{ctx->device, op->getOpType().underlying()}; Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); kernel->compute(op, this); } @@ -80,7 +120,7 @@ void *RuntimeObj::allocHost(size_t size) { } void *RuntimeObj::allocDevice(size_t size) { - void *ptr; + void *ptr = nullptr; CHECK_INFINI_ERROR(infinirtMalloc(&ptr, size)); return ptr; } @@ -95,6 +135,13 @@ void RuntimeObj::deallocDevice(void *ptr) { void RuntimeObj::memcpy(void *dst, const void *src, size_t size, infinirtMemcpyKind_t kind) { + // 基本指针有效性检查 + if (dst == nullptr || src == nullptr) { + std::cerr << "[ERROR] memcpy called with null pointer!" << std::endl; + // 这里应该抛出异常或返回错误,而不是继续 + throw std::runtime_error("Null pointer in memcpy"); + } + CHECK_INFINI_ERROR(infinirtMemcpy(dst, src, size, kind)); } @@ -119,19 +166,45 @@ void RuntimeObj::synchronize() const { } void *RuntimeObj::getWorkspace(size_t size) const { - IT_ASSERT(size < getWorkspaceSize(), "Workspace size is too small"); - return workspace; + auto ctx = getCurrentThreadContext(); + if (!ctx->workspace) { + throw std::runtime_error( + "Workspace not initialized! Call initWorkspace() first."); + } + return ctx->workspace; } -size_t RuntimeObj::getWorkspaceSize() const { return workspaceSize; } +size_t RuntimeObj::getWorkspaceSize() const { + auto ctx = getCurrentThreadContext(); + return ctx->workspaceSize; +} bool RuntimeObj::isCpu() const { auto context = getCurrentThreadContext(); return context->device == INFINI_DEVICE_CPU; } -void RuntimeObj::allocworkspace() { - workspaceSize = 7ll << 30; - workspace = allocDevice(workspaceSize); -} +// void RuntimeObj::initWorkspace(size_t size) { +// auto ctx = getCurrentThreadContext(); + +// // 如果已分配且大小足够,直接返回 +// if (ctx->workspace && ctx->workspaceSize >= size) { +// return; +// } + +// // 释放旧的 workspace +// if (ctx->workspace) { +// infinirtFree(ctx->workspace); +// } + +// // CPU设备不需要调用setDevice,避免与GPU线程冲突 +// if (ctx->device != INFINI_DEVICE_CPU) { +// CHECK_INFINI_ERROR(infinirtSetDevice(ctx->device, ctx->deviceId)); +// } + +// // 分配新的 workspace +// ctx->workspaceSize = size; +// ctx->workspace = nullptr; +// CHECK_INFINI_ERROR(infinirtMalloc(&ctx->workspace, size)); +// } } // namespace infini diff --git a/src/core/tensor.cc b/src/core/tensor.cc index c91b1d7..018b1d4 100644 --- a/src/core/tensor.cc +++ b/src/core/tensor.cc @@ -77,7 +77,6 @@ void TensorObj::setData(void *data_) { void TensorObj::dataMalloc(const Runtime &runtime) { if (data == nullptr) { data = make_ref(runtime->allocDevice(getTotalBytes())); - device = runtime->getCurrentThreadContext()->device; } else { if (runtime->getCurrentThreadContext()->device != device && device == INFINI_DEVICE_CPU) { @@ -87,6 +86,7 @@ void TensorObj::dataMalloc(const Runtime &runtime) { setData(data_ptr); } } + device = runtime->getCurrentThreadContext()->device; } ElementType TensorObj::getElement() const { diff --git a/test/kernels/test_elementwise_kernel.cc b/test/kernels/test_elementwise_kernel.cc index c62dd05..af2b04c 100644 --- a/test/kernels/test_elementwise_kernel.cc +++ b/test/kernels/test_elementwise_kernel.cc @@ -1,109 +1,423 @@ #include "core/runtime.h" #include "operators/ElementWise.h" #include "gtest/gtest.h" +#include +#include +#include +#include namespace infini { -template -void runElementWiseTest(const std::string &deviceName, infiniDevice_t DeviceT, - OpType opType, const Shape &shapeA, const Shape &shapeB, - const DataType &dataType, bool print = false) { - Runtime &runtime = RuntimeObj::getInstance(); + +// 线程测试参数 - 模板类支持任意数据类型 +template struct ThreadTestParams { + infiniDevice_t device = INFINI_DEVICE_CPU; + int deviceId = 0; + OpType opType = OpType::Unknown; + Shape shapeA; + Shape shapeB; + DataType dataType = DataType(INFINI_DTYPE_F32); + std::vector inputAData; + std::vector inputBData; + std::vector outputData; + bool completed = false; + std::string deviceName; +}; + +// 设备线程函数 - 模板函数 +template void deviceThreadFunc(ThreadTestParams ¶ms) { RuntimeObj::init(); - runtime->initThreadContext(DeviceT, 0); + Runtime &runtime = RuntimeObj::getInstance(); + + // 初始化设备 Context + runtime->initThreadContext(params.device, params.deviceId); + + // 创建 Graph Graph g = make_ref(runtime); + auto A = g->addTensor(params.shapeA, params.dataType); + auto B = g->addTensor(params.shapeB, params.dataType); + auto op = g->addOp(params.opType, A, B, nullptr); - // 创建输入张量 - auto A = g->addTensor(shapeA, dataType); - auto B = g->addTensor(shapeB, dataType); + // 先设置数据(设置CPU指针),再分配内存(触发H2D拷贝) + A->setData(params.inputAData.data()); + B->setData(params.inputBData.data()); + runtime->dataMalloc(g); // 会检测到data是CPU,执行H2D拷贝 - // 创建ElementWise算子 - auto op = g->addOp(opType, A, B, nullptr); + // 运行计算 + runtime->run(g); - // 分配内存 - runtime->dataMalloc(g); + // 获取输出并复制到 host + auto output = op->getOutput(0); + size_t numElements = output->getElement(); + params.outputData.resize(numElements); - // 设置输入数据 - size_t elementA = A->getElement(); - size_t elementB = B->getElement(); + // 检查 output 的数据是否存在 + auto dataBlob = output->getData(); + if (!dataBlob) { + throw std::runtime_error("Output data blob is null!"); + } + void *devicePtr = dataBlob->getRawDataPtr(); + if (!devicePtr && !runtime->isCpu()) { + throw std::runtime_error( + "Output device pointer is null on GPU device!"); + } + + // 复制结果数据 + void *hostPtr = runtime->allocHost(output->getTotalBytes()); + runtime->memcpy(hostPtr, devicePtr, output->getTotalBytes(), + INFINIRT_MEMCPY_D2H); + + // 根据数据类型复制 + if constexpr (std::is_same_v) { + if (params.dataType.getType() == INFINI_DTYPE_F32) { + std::memcpy(params.outputData.data(), hostPtr, + numElements * sizeof(float)); + } else if (params.dataType.getType() == INFINI_DTYPE_F16) { + // FP16 转换为 FP32 + uint16_t *fp16Data = static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) { + params.outputData[i] = fp16_to_fp32(fp16Data[i]); + } + } + } else if constexpr (std::is_same_v) { + if (params.dataType.getType() == INFINI_DTYPE_F16) { + std::memcpy(params.outputData.data(), hostPtr, + numElements * sizeof(uint16_t)); + } else if (params.dataType.getType() == INFINI_DTYPE_F32) { + // FP32 转换为 FP16 + float *fp32Data = static_cast(hostPtr); + uint16_t *fp16Data = params.outputData.data(); + for (size_t i = 0; i < numElements; ++i) { + // 这里需要 FP32 转 FP16 的函数,暂时用简单的方式 + fp16Data[i] = static_cast(fp32Data[i]); + } + } + } + + runtime->deallocHost(hostPtr); + params.completed = true; +} + +// 运行多线程测试 - 模板函数 +template +void runMultiThreadTest(OpType opType, const Shape &shapeA, const Shape &shapeB, + const DataType &dataType, bool print = false) { + + // 准备输入数据 + size_t elementA = 1, elementB = 1; + for (auto dim : shapeA) + elementA *= dim; + for (auto dim : shapeB) + elementB *= dim; - // 为A和B设置不同的数据模式,方便验证结果 std::vector inputAData(elementA); std::vector inputBData(elementB); // 使用简单的递增序列和递减序列,便于计算和验证 for (size_t i = 0; i < elementA; ++i) { - inputAData[i] = static_cast(i + 1); // 1, 2, 3, ... + if constexpr (std::is_same_v) { + inputAData[i] = static_cast(i + 1); // 1, 2, 3, ... + } else if constexpr (std::is_same_v) { + inputAData[i] = static_cast(i + 1); + } } - for (size_t i = 0; i < elementB; ++i) { - inputBData[i] = static_cast(elementB - i); // n, n-1, n-2, ... + if constexpr (std::is_same_v) { + inputBData[i] = static_cast(elementB - i); // n, n-1, n-2, ... + } else if constexpr (std::is_same_v) { + inputBData[i] = static_cast(elementB - i); + } } - A->setData(inputAData.data()); - B->setData(inputBData.data()); + // 创建线程参数 + ThreadTestParams cpuParams, gpuParams; + + // CPU 线程参数 + cpuParams.device = INFINI_DEVICE_CPU; + cpuParams.deviceId = 0; + cpuParams.opType = opType; + cpuParams.shapeA = shapeA; + cpuParams.shapeB = shapeB; + cpuParams.dataType = dataType; + cpuParams.inputAData = inputAData; + cpuParams.inputBData = inputBData; + cpuParams.deviceName = "CPU"; + + // GPU 线程参数 + gpuParams.device = INFINI_DEVICE_NVIDIA; + gpuParams.deviceId = 0; + gpuParams.opType = opType; + gpuParams.shapeA = shapeA; + gpuParams.shapeB = shapeB; + gpuParams.dataType = dataType; + gpuParams.inputAData = inputAData; + gpuParams.inputBData = inputBData; + gpuParams.deviceName = "NVIDIA"; if (print) { - std::cout << "Running ElementWise Test on " << deviceName << std::endl; + std::cout << "========================================" << std::endl; + std::cout << "Running Multi-Thread ElementWise Test" << std::endl; std::cout << "OpType: " << opType.toString() << std::endl; + std::cout << "DataType: " << dataType.toString() << std::endl; std::cout << "Shape A: " << vecToString(shapeA) << std::endl; std::cout << "Shape B: " << vecToString(shapeB) << std::endl; - std::cout << "Graph: " << g->toString() << std::endl; + std::cout << "Thread 1: CPU (" << dataType.toString() << ")" + << std::endl; + std::cout << "Thread 2: NVIDIA (" << dataType.toString() << ")" + << std::endl; + std::cout << "========================================" << std::endl; } - // 执行计算 - runtime->run(g); + // 启动两个线程并行执行 + std::thread cpuThread(deviceThreadFunc, std::ref(cpuParams)); + std::thread gpuThread(deviceThreadFunc, std::ref(gpuParams)); - // 获取输出 - auto output = op->getOutput(0); + // 等待两个线程完成 + cpuThread.join(); + gpuThread.join(); + + // 验证结果 + ASSERT_TRUE(cpuParams.completed) << "CPU thread failed"; + ASSERT_TRUE(gpuParams.completed) << "NVIDIA thread failed"; + + ASSERT_EQ(cpuParams.outputData.size(), gpuParams.outputData.size()) + << "Output size mismatch"; + + // 对比结果 + size_t numErrors = 0; + float maxError = 0.0f; + const float epsilon = 1e-3f; + + for (size_t i = 0; i < cpuParams.outputData.size(); ++i) { + float cpuVal, gpuVal; + + // 转换为 float 进行比较 + if constexpr (std::is_same_v) { + cpuVal = cpuParams.outputData[i]; + gpuVal = gpuParams.outputData[i]; + } else if constexpr (std::is_same_v) { + // FP16 转 FP32 比较 + cpuVal = fp16_to_fp32(cpuParams.outputData[i]); + gpuVal = fp16_to_fp32(gpuParams.outputData[i]); + } + + float error = std::abs(cpuVal - gpuVal); + maxError = std::max(maxError, error); + + if (error > epsilon) { + numErrors++; + if (numErrors <= 5) { // 只打印前5个错误 + std::cout << "Mismatch at index " << i << ": CPU=" << cpuVal + << ", NVIDIA=" << gpuVal << ", error=" << error + << std::endl; + } + } + } if (print) { - std::cout << "Output Data: " << std::endl; - output->printData(runtime); + std::cout << "Result Comparison:" << std::endl; + std::cout << " Total elements: " << cpuParams.outputData.size() + << std::endl; + std::cout << " Errors: " << numErrors << std::endl; + std::cout << " Max error: " << maxError << std::endl; + + if (numErrors == 0) { + std::cout << " ✓ Test PASSED" << std::endl; + } else { + std::cout << " ✗ Test FAILED" << std::endl; + } + std::cout << "========================================" << std::endl; } + + EXPECT_EQ(numErrors, 0) + << "Results mismatch between CPU and NVIDIA (max error: " << maxError + << ")"; } -// 基本Add操作测试 -TEST(ElementWise, Add_Basic) { +// 基本Add操作测试 - F32 +TEST(ElementWise, Add_MultiThread_F32) { Shape shapeA = {3, 1}; Shape shapeB = {2, 3, 4}; - runElementWiseTest("CPU", INFINI_DEVICE_CPU, OpType::Add, shapeA, - shapeB, DataType(INFINI_DTYPE_F32), true); +#ifdef USE_CUDA + runMultiThreadTest(OpType::Add, shapeA, shapeB, + DataType(INFINI_DTYPE_F32), true); +#else + std::cout << "CUDA not enabled, skipping multi-thread test" << std::endl; +#endif +} + +// 基本Add操作测试 - F16 +TEST(ElementWise, Add_MultiThread_F16) { + Shape shapeA = {3, 1}; + Shape shapeB = {2, 3, 4}; #ifdef USE_CUDA - runElementWiseTest("NVIDIA", INFINI_DEVICE_NVIDIA, OpType::Add, - shapeA, shapeB, DataType(INFINI_DTYPE_F32), true); - runElementWiseTest("NVIDIA", INFINI_DEVICE_NVIDIA, OpType::Add, - shapeA, shapeB, DataType(INFINI_DTYPE_F16), - true); + runMultiThreadTest(OpType::Add, shapeA, shapeB, + DataType(INFINI_DTYPE_F16), true); +#else + std::cout << "CUDA not enabled, skipping multi-thread test" << std::endl; #endif } -// 基本Mul操作测试 -TEST(ElementWise, Mul_Basic) { +// 基本Mul操作测试 - F32 +TEST(ElementWise, Mul_MultiThread_F32) { Shape shapeA = {3, 4}; Shape shapeB = {3, 4}; - runElementWiseTest("CPU", INFINI_DEVICE_CPU, OpType::Mul, shapeA, - shapeB, DataType(INFINI_DTYPE_F32), true); +#ifdef USE_CUDA + runMultiThreadTest(OpType::Mul, shapeA, shapeB, + DataType(INFINI_DTYPE_F32), false); +#endif +} + +// 基本Mul操作测试 - F16 +TEST(ElementWise, Mul_MultiThread_F16) { + Shape shapeA = {3, 4}; + Shape shapeB = {3, 4}; #ifdef USE_CUDA - runElementWiseTest("NVIDIA", INFINI_DEVICE_NVIDIA, OpType::Mul, - shapeA, shapeB, DataType(INFINI_DTYPE_F32), true); + runMultiThreadTest(OpType::Mul, shapeA, shapeB, + DataType(INFINI_DTYPE_F16), false); #endif } -// 基本Sub操作测试 -TEST(ElementWise, Sub_Basic) { +// 基本Sub操作测试 - F32 +TEST(ElementWise, Sub_MultiThread_F32) { Shape shapeA = {1, 5, 6}; Shape shapeB = {1, 5, 6}; - runElementWiseTest("CPU", INFINI_DEVICE_CPU, OpType::Sub, shapeA, - shapeB, DataType(INFINI_DTYPE_F32), true); +#ifdef USE_CUDA + runMultiThreadTest(OpType::Sub, shapeA, shapeB, + DataType(INFINI_DTYPE_F32), false); +#endif +} + +// 基本Sub操作测试 - F16 +TEST(ElementWise, Sub_MultiThread_F16) { + Shape shapeA = {1, 5, 6}; + Shape shapeB = {1, 5, 6}; #ifdef USE_CUDA - runElementWiseTest("NVIDIA", INFINI_DEVICE_NVIDIA, OpType::Sub, - shapeA, shapeB, DataType(INFINI_DTYPE_F32), true); + runMultiThreadTest(OpType::Sub, shapeA, shapeB, + DataType(INFINI_DTYPE_F16), false); #endif } + +// 单设备测试(用于调试)- CPU +TEST(ElementWise, Add_SingleDevice_CPU) { + RuntimeObj::init(); + Runtime &runtime = RuntimeObj::getInstance(); + runtime->initThreadContext(INFINI_DEVICE_CPU, 0); + + Shape shapeA = {3, 1}; + Shape shapeB = {2, 3, 4}; + + Graph g = make_ref(runtime); + auto A = g->addTensor(shapeA, DataType(INFINI_DTYPE_F32)); + auto B = g->addTensor(shapeB, DataType(INFINI_DTYPE_F32)); + auto op = g->addOp(OpType::Add, A, B, nullptr); + + runtime->dataMalloc(g); + + // 设置输入数据 + std::vector inputAData(A->getElement()); + std::vector inputBData(B->getElement()); + + for (size_t i = 0; i < inputAData.size(); ++i) { + inputAData[i] = static_cast(i + 1); + } + for (size_t i = 0; i < inputBData.size(); ++i) { + inputBData[i] = static_cast(inputBData.size() - i); + } + + A->setData(inputAData.data()); + B->setData(inputBData.data()); + + // 执行计算 + runtime->run(g); + + // 获取输出并打印 + auto output = op->getOutput(0); + std::cout << "CPU Output Data: " << std::endl; + output->printData(runtime); +} + +#ifdef USE_CUDA +// 单设备测试(用于调试)- NVIDIA F32 +TEST(ElementWise, Add_SingleDevice_NVIDIA_F32) { + RuntimeObj::init(); + Runtime &runtime = RuntimeObj::getInstance(); + runtime->initThreadContext(INFINI_DEVICE_NVIDIA, 0); + + Shape shapeA = {3, 1}; + Shape shapeB = {2, 3, 4}; + + Graph g = make_ref(runtime); + auto A = g->addTensor(shapeA, DataType(INFINI_DTYPE_F32)); + auto B = g->addTensor(shapeB, DataType(INFINI_DTYPE_F32)); + auto op = g->addOp(OpType::Add, A, B, nullptr); + + // 设置输入数据 + std::vector inputAData(A->getElement()); + std::vector inputBData(B->getElement()); + + for (size_t i = 0; i < inputAData.size(); ++i) { + inputAData[i] = static_cast(i + 1); + } + for (size_t i = 0; i < inputBData.size(); ++i) { + inputBData[i] = static_cast(inputBData.size() - i); + } + + A->setData(inputAData.data()); + B->setData(inputBData.data()); + runtime->dataMalloc(g); + + // 执行计算 + runtime->run(g); + + // 获取输出并打印 + auto output = op->getOutput(0); + std::cout << "NVIDIA F32 Output Data: " << std::endl; + output->printData(runtime); +} + +// 单设备测试(用于调试)- NVIDIA F16 +TEST(ElementWise, Add_SingleDevice_NVIDIA_F16) { + RuntimeObj::init(); + Runtime &runtime = RuntimeObj::getInstance(); + runtime->initThreadContext(INFINI_DEVICE_NVIDIA, 0); + + Shape shapeA = {3, 1}; + Shape shapeB = {2, 3, 4}; + + Graph g = make_ref(runtime); + auto A = g->addTensor(shapeA, DataType(INFINI_DTYPE_F16)); + auto B = g->addTensor(shapeB, DataType(INFINI_DTYPE_F16)); + auto op = g->addOp(OpType::Add, A, B, nullptr); + + // 设置输入数据 + std::vector inputAData(A->getElement()); + std::vector inputBData(B->getElement()); + + for (size_t i = 0; i < inputAData.size(); ++i) { + inputAData[i] = static_cast(i + 1); + } + for (size_t i = 0; i < inputBData.size(); ++i) { + inputBData[i] = static_cast(inputBData.size() - i); + } + + A->setData(inputAData.data()); + B->setData(inputBData.data()); + runtime->dataMalloc(g); + + // 执行计算 + runtime->run(g); + + // 获取输出并打印 + auto output = op->getOutput(0); + std::cout << "NVIDIA F16 Output Data: " << std::endl; + output->printData(runtime); +} +#endif + } // namespace infini From ce39c4eb3d3c7bc431d60e012ab26c5d35db222e Mon Sep 17 00:00:00 2001 From: zhangyunze Date: Mon, 9 Feb 2026 06:08:53 +0000 Subject: [PATCH 2/2] fix: fix runtime to apply multi-thread run && fix test file --- .github/workflows/build.yml | 3 +- format.py | 24 +- include/core/blob.h | 2 +- include/core/dtype.h | 2 +- include/core/expr.h | 8 + include/core/runtime.h | 13 +- include/utils/test_utils.h | 326 +++++++++++++++ include/utils/utils.h | 214 +++++++++- python/bindings/dtype.hpp | 32 +- python/bindings/tensor.hpp | 8 +- python/src/infinitensor/converter/registry.py | 17 +- .../src/infinitensor/torch_fx_translator.py | 136 ++++--- python/tests/conftest.py | 14 +- python/tests/test_torch_fx_translator.py | 46 +-- src/core/expr.cc | 24 +- src/core/graph.cc | 24 +- src/core/runtime.cc | 54 +-- src/core/tensor.cc | 2 +- src/operators/Gemm.cc | 2 +- src/utils/utils.cc | 131 ------ test/core/test_expr.cc | 40 +- test/core/test_graph.cc | 22 +- test/core/test_shape_expr.cc | 10 +- test/core/test_stride_expr.cc | 13 +- test/core/test_tensor_basic.cc | 51 +-- test/kernels/test_elementwise_kernel.cc | 171 ++++---- test/kernels/test_gemm_kenel.cc | 42 -- test/kernels/test_gemm_kernel.cc | 373 ++++++++++++++++++ test/operators/test_elementwise_op.cc | 21 +- test/operators/test_gemm_op.cc | 18 +- 30 files changed, 1267 insertions(+), 576 deletions(-) create mode 100644 include/utils/test_utils.h delete mode 100644 src/utils/utils.cc delete mode 100644 test/kernels/test_gemm_kenel.cc create mode 100644 test/kernels/test_gemm_kernel.cc diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 869f181..d69b1d9 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -28,7 +28,8 @@ jobs: if [ ! -d InfiniCore ]; then git clone git@github.com:InfiniTensor/InfiniCore.git && cd InfiniCore && - git checkout f53154df00dc7005cebc49fe9080f1ea21ee1dfa + git submodule update --init && + git checkout 3c8fb3c05036c95faa2eee42bf8fcc42775edd43 else echo "InfiniCore already exists" fi diff --git a/format.py b/format.py index 344d893..9e2c15c 100644 --- a/format.py +++ b/format.py @@ -11,14 +11,26 @@ def format_file(file): file = Path(proj_path.joinpath(file)) print(file) + + # Skip if file doesn't exist + if not file.exists(): + print(f"Skipping: file does not exist - {file}") + return + if file.suffix in c_style_file: - run( - f"clang-format -style=file -i {file}", cwd=proj_path, shell=True, check=True - ) - run(f"git add {file}", cwd=proj_path, shell=True) + try: + run( + f"clang-format -style=file -i {file}", cwd=proj_path, shell=True, check=True + ) + run(f"git add {file}", cwd=proj_path, shell=True) + except Exception as e: + print(f"Error formatting file {file}: {e}") elif file.suffix == py_file: - run(f"black {file}", cwd=proj_path, shell=True, check=True) - run(f"git add {file}", cwd=proj_path, shell=True) + try: + run(f"black {file}", cwd=proj_path, shell=True, check=True) + run(f"git add {file}", cwd=proj_path, shell=True) + except Exception as e: + print(f"Error formatting file {file}: {e}") if len(sys.argv) == 1: diff --git a/include/core/blob.h b/include/core/blob.h index 2c2ec2d..58fe549 100644 --- a/include/core/blob.h +++ b/include/core/blob.h @@ -13,7 +13,7 @@ class BlobObj { BlobObj(void *ptr) : ptr(ptr) {} BlobObj(BlobObj &other) = delete; BlobObj &operator=(BlobObj const &) = delete; - ~BlobObj(){}; + ~BlobObj() {}; template T getPtr() const { return reinterpret_cast(ptr); } void *getRawDataPtr() const { return ptr; } diff --git a/include/core/dtype.h b/include/core/dtype.h index fb1cf28..a7a591d 100644 --- a/include/core/dtype.h +++ b/include/core/dtype.h @@ -32,7 +32,7 @@ class DataType { case INFINI_DTYPE_U64: return 8; case INFINI_DTYPE_F8: - return 1; // 自定义 8-bit float + return 1; case INFINI_DTYPE_F16: return 2; case INFINI_DTYPE_F32: diff --git a/include/core/expr.h b/include/core/expr.h index 771a02d..b85923e 100644 --- a/include/core/expr.h +++ b/include/core/expr.h @@ -208,5 +208,13 @@ class StrideExprObj : public BaseExprObj { Stride getConstantValue() const; }; +inline std::string vecToString(const ShapeExpr &shape) { + return shape->toString(); +} + +inline std::string vecToString(const StrideExpr &stride) { + return stride->toString(); +} + } // namespace infini #endif // EXPR_H diff --git a/include/core/runtime.h b/include/core/runtime.h index 6cb6c57..0274c4d 100644 --- a/include/core/runtime.h +++ b/include/core/runtime.h @@ -22,7 +22,6 @@ using Context = Ref; class RuntimeObj : public std::enable_shared_from_this { private: - // 全局 map: thread_id -> Context mutable std::unordered_map threadContexts; mutable std::shared_mutex ctx_mutex; static thread_local Context tls_context_cache; @@ -34,15 +33,15 @@ class RuntimeObj : public std::enable_shared_from_this { RuntimeObj &operator=(const RuntimeObj &) = delete; ~RuntimeObj(); - // 每个线程唯一的 Runtime + // Unique Runtime per thread static Runtime &getInstance(); - // 每个线程初始化自己的 Context + // Initialize each thread's own Context void initThreadContext(infiniDevice_t device, int deviceId = 0); - // 获取活跃 Context + // Get active Context Context getCurrentThreadContext() const; - // 切换当前线程的设备 + // Switch device for current thread void setCurrentDevice(infiniDevice_t device, int deviceId = 0); static void init(); @@ -59,9 +58,9 @@ class RuntimeObj : public std::enable_shared_from_this { infinirtMemcpyKind_t kind, infinirtStream_t stream); void *mallocAsync(size_t size, infinirtStream_t stream); void freeAsync(void *ptr, infinirtStream_t stream); - // 同步当前线程的设备 + // Synchronize device for current thread void synchronize() const; - // 获取当前 Context 的 workspace + // Get workspace of current Context size_t getWorkspaceSize() const; void *getWorkspace(size_t size) const; diff --git a/include/utils/test_utils.h b/include/utils/test_utils.h new file mode 100644 index 0000000..7713350 --- /dev/null +++ b/include/utils/test_utils.h @@ -0,0 +1,326 @@ +#pragma once +#ifndef TEST_UTILS_H +#define TEST_UTILS_H + +#include "core/dtype.h" +#include "utils/utils.h" + +namespace infini { +// Generate random data for testing +// Supports generating random vectors of floating-point numbers or integers +// Usage: auto data = generateRandomData(1000, -10.0f, 10.0f); +// Parameters: +// size: Data size +// min: Minimum value (default 0) +// max: Maximum value (default 100) +template +std::vector generateRandomData(size_t size, T min = static_cast(0), + T max = static_cast(100)) { + std::vector data(size); + std::random_device rd; + std::mt19937 gen(rd()); + + if constexpr (std::is_same_v || std::is_same_v) { + // Floating-point type: use uniform distribution + std::uniform_real_distribution dis(static_cast(min), + static_cast(max)); + for (size_t i = 0; i < size; ++i) { + data[i] = static_cast(dis(gen)); + } + } else { + // Integer type: use uniform distribution + std::uniform_int_distribution dis(static_cast(min), + static_cast(max)); + for (size_t i = 0; i < size; ++i) { + data[i] = static_cast(dis(gen)); + } + } + return data; +} + +// Generate sequential data (ascending or descending) for deterministic testing +// Usage: auto data = generateSequentialData(100, 1.0f, 0.5f); +// // 1.0, 1.5, 2.0, ... Parameters: +// size: Data size +// start: Starting value (default 1) +// step: Step size, supports negative values for descending (default 1) +template +std::vector generateSequentialData(size_t size, T start = static_cast(1), + T step = static_cast(1)) { + std::vector data(size); + T current = start; + for (size_t i = 0; i < size; ++i) { + data[i] = current; + current += step; + } + return data; +} +// Generic data copy and conversion function +// Copy data from device memory (hostPtr) to target vector (output), with type +// conversion as needed Parameters: +// output: Target output vector +// hostPtr: Source data pointer (pointer after device memory is copied to +// host) numElements: Number of elements dataType: Actual data type of source +// data +template +void copyAndConvertData(std::vector &output, const void *hostPtr, + size_t numElements, const DataType &dataType) { + constexpr size_t typeSize = sizeof(T); + const size_t dtypeSize = dataType.getSize(); + + if (typeSize == dtypeSize) { + // Type size matches, direct copy + std::memcpy(output.data(), hostPtr, numElements * typeSize); + } else { + // Type size mismatch, conversion needed + if constexpr (std::is_same_v) { + if (dataType.getType() == INFINI_DTYPE_F16) { + // FP16 (16-bit) -> FP32 (32-bit) + const uint16_t *srcData = + static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) { + output[i] = fp16_to_fp32(srcData[i]); + } + } else if (dataType.getType() == INFINI_DTYPE_F64) { + // F64 (64-bit) -> F32 (32-bit) + const double *srcData = static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) { + output[i] = static_cast(srcData[i]); + } + } else { + // Other types -> float (integer to float) + switch (dataType.getType()) { + case INFINI_DTYPE_I8: { + const int8_t *src = static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(src[i]); + break; + } + case INFINI_DTYPE_I16: { + const int16_t *src = static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(src[i]); + break; + } + case INFINI_DTYPE_I32: { + const int32_t *src = static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(src[i]); + break; + } + case INFINI_DTYPE_I64: { + const int64_t *src = static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(src[i]); + break; + } + case INFINI_DTYPE_U8: { + const uint8_t *src = static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(src[i]); + break; + } + case INFINI_DTYPE_U16: { + const uint16_t *src = + static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(src[i]); + break; + } + case INFINI_DTYPE_U32: { + const uint32_t *src = + static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(src[i]); + break; + } + case INFINI_DTYPE_U64: { + const uint64_t *src = + static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(src[i]); + break; + } + default: + throw std::runtime_error( + "Unsupported type conversion to float: " + + dataType.toString()); + } + } + } else if constexpr (std::is_same_v) { + if (dataType.getType() == INFINI_DTYPE_F32) { + // FP32 (32-bit) -> FP16 (16-bit) + const float *srcData = static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) { + output[i] = fp32_to_fp16(srcData[i]); + } + } else if (dataType.getType() == INFINI_DTYPE_F64) { + // F64 (64-bit) -> FP16 (16-bit) + const double *srcData = static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) { + output[i] = fp32_to_fp16(static_cast(srcData[i])); + } + } else { + throw std::runtime_error( + "Unsupported type conversion to uint16_t: " + + dataType.toString()); + } + } else if constexpr (std::is_same_v) { + if (dataType.getType() == INFINI_DTYPE_F32) { + // F32 (32-bit) -> F64 (64-bit) + const float *srcData = static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) { + output[i] = static_cast(srcData[i]); + } + } else if (dataType.getType() == INFINI_DTYPE_F16) { + // FP16 (16-bit) -> F64 (64-bit) + const uint16_t *srcData = + static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) { + output[i] = static_cast(fp16_to_fp32(srcData[i])); + } + } else { + // Other types -> double + switch (dataType.getType()) { + case INFINI_DTYPE_I8: { + const int8_t *src = static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(src[i]); + break; + } + case INFINI_DTYPE_I16: { + const int16_t *src = static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(src[i]); + break; + } + case INFINI_DTYPE_I32: { + const int32_t *src = static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(src[i]); + break; + } + case INFINI_DTYPE_I64: { + const int64_t *src = static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(src[i]); + break; + } + case INFINI_DTYPE_U8: { + const uint8_t *src = static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(src[i]); + break; + } + case INFINI_DTYPE_U16: { + const uint16_t *src = + static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(src[i]); + break; + } + case INFINI_DTYPE_U32: { + const uint32_t *src = + static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(src[i]); + break; + } + case INFINI_DTYPE_U64: { + const uint64_t *src = + static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(src[i]); + break; + } + default: + throw std::runtime_error( + "Unsupported type conversion to double: " + + dataType.toString()); + } + } + } else if constexpr (std::is_integral_v) { + // Generic integer type handling + switch (dataType.getType()) { + case INFINI_DTYPE_I8: { + const int8_t *src = static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(src[i]); + break; + } + case INFINI_DTYPE_I16: { + const int16_t *src = static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(src[i]); + break; + } + case INFINI_DTYPE_I32: { + const int32_t *src = static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(src[i]); + break; + } + case INFINI_DTYPE_I64: { + const int64_t *src = static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(src[i]); + break; + } + case INFINI_DTYPE_U8: { + const uint8_t *src = static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(src[i]); + break; + } + case INFINI_DTYPE_U16: { + const uint16_t *src = static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(src[i]); + break; + } + case INFINI_DTYPE_U32: { + const uint32_t *src = static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(src[i]); + break; + } + case INFINI_DTYPE_U64: { + const uint64_t *src = static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(src[i]); + break; + } + case INFINI_DTYPE_F16: { + // FP16 -> integer (needs conversion to FP32 first) + const uint16_t *src = static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(fp16_to_fp32(src[i])); + break; + } + case INFINI_DTYPE_F32: { + const float *src = static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(src[i]); + break; + } + case INFINI_DTYPE_F64: { + const double *src = static_cast(hostPtr); + for (size_t i = 0; i < numElements; ++i) + output[i] = static_cast(src[i]); + break; + } + default: + throw std::runtime_error( + "Unsupported type conversion to integer: " + + dataType.toString()); + } + } else { + throw std::runtime_error("Unsupported type conversion: T size=" + + std::to_string(typeSize) + + ", dataType=" + dataType.toString()); + } + } +} + +} // namespace infini +#endif // TEST_UTILS_H \ No newline at end of file diff --git a/include/utils/utils.h b/include/utils/utils.h index 3b47271..1461f5e 100644 --- a/include/utils/utils.h +++ b/include/utils/utils.h @@ -5,21 +5,211 @@ #include "core/common.h" #include "core/expr.h" #include +#include + namespace infini { -ShapeExpr infer_broadcast(const ShapeExpr &A, const ShapeExpr &B); -size_t calculateLinearOffset(size_t index, Shape shape, Stride stride); - -// 计算广播后的stride -// inputShape: 输入张量的形状 -// inputStride: 输入张量的原始stride -// outputShape: 输出张量的形状(广播后的形状) -// 返回: 广播后的stride,对于广播的维度,stride设置为0 -StrideExpr broadcastStride(const ShapeExpr &inputShape, - const StrideExpr &inputStride, - const ShapeExpr &outputShape); + +// Calculate stride after broadcasting +// inputShape: Shape of the input tensor +// inputStride: Original stride of the input tensor +// outputShape: Shape of the output tensor (after broadcasting) +// Returns: Stride after broadcasting, stride is set to 0 for broadcasted +// dimensions +inline StrideExpr broadcastStride(const ShapeExpr &inputShape, + const StrideExpr &inputStride, + const ShapeExpr &outputShape) { + IT_ASSERT(inputShape->size() == inputStride->size(), + "Input shape and stride must have the same rank"); + + size_t inputRank = inputShape->size(); + size_t outputRank = outputShape->size(); + + // Align dimensions from right to left (NumPy broadcasting rules) + std::vector broadcastedStride(outputRank); + + for (size_t outIdx = 0; outIdx < outputRank; ++outIdx) { + // Calculate the corresponding input dimension index (aligned from right + // to left) + int inIdx = static_cast(outIdx) - static_cast(outputRank) + + static_cast(inputRank); + + if (inIdx < 0) { + // Input dimension does not exist, equivalent to dimension being 1, + // stride is 0 + broadcastedStride[outIdx] = ExprObj::constant(0); + } else { + // Check if input dimension is 1 (broadcast dimension) or equal to + // output dimension + const Expr &inputDim = (*inputShape)[inIdx]; + const Expr &outputDim = (*outputShape)[outIdx]; + + // If input dimension is constant 1, it is a broadcast dimension, + // stride is 0 + auto inputDimConst = inputDim->asConstant(); + if (inputDimConst.has_value() && *inputDimConst == 1) { + broadcastedStride[outIdx] = ExprObj::constant(0); + } else { + // Dimensions are equal or unequal but valid (guaranteed by + // infer_broadcast), use original stride + broadcastedStride[outIdx] = (*inputStride)[inIdx]; + } + } + } + + return make_ref(broadcastedStride); +} + +inline ShapeExpr infer_broadcast(const ShapeExpr &A, const ShapeExpr &B) { + size_t rankA = A->size(); + size_t rankB = B->size(); + size_t rank = std::max(rankA, rankB); + + std::vector resultDims; + for (size_t i = 0; i < rank; ++i) { + // Align from right to left (NumPy broadcasting rules) + int idxA = static_cast(i) - static_cast(rank) + + static_cast(rankA); + int idxB = static_cast(i) - static_cast(rank) + + static_cast(rankB); + + // If index is negative, the dimension does not exist, equivalent to + // dimension being 1 + Expr aDim = (idxA < 0) ? ExprObj::constant(1) : (*A)[idxA]; + Expr bDim = (idxB < 0) ? ExprObj::constant(1) : (*B)[idxB]; + + // Validate broadcasting rules: dimensions must be equal, or one of them + // is 1 + IT_ASSERT(aDim == bDim || aDim == ExprObj::constant(1) || + bDim == ExprObj::constant(1)); + + // Broadcasting result: if a dimension is 1, take b dimension, otherwise + // take a dimension + auto shapeEle = aDim == ExprObj::constant(1) ? bDim : aDim; + resultDims.emplace_back(shapeEle); + } + + return make_ref(resultDims); +} + +inline size_t calculateLinearOffset(size_t index, Shape shape, Stride stride) { + size_t rank = shape.size(); + std::vector indices(rank); + size_t remaining = index; + for (size_t i = 0; i < rank; ++i) { + size_t dim = rank - 1 - i; + indices[dim] = remaining % shape.at(dim); + remaining /= shape.at(dim); + } + size_t offset = 0; + for (size_t i = 0; i < rank; ++i) { + offset += indices[i] * stride.at(i); + } + return offset; +} // FP16 to FP32 conversion utility -float fp16_to_fp32(uint16_t fp16); +inline float fp16_to_fp32(uint16_t fp16) { + // Union for safe type punning + union { + uint32_t u; + float f; + } converter; + + // Extract components from FP16 + uint32_t sign = (fp16 >> 15) & 0x1; + uint32_t exponent = (fp16 >> 10) & 0x1F; + uint32_t mantissa = fp16 & 0x3FF; + + // Handle special cases + if (exponent == 0) { + if (mantissa == 0) { + // Zero + converter.u = sign << 31; + return converter.f; + } else { + // Subnormal number: normalize it + while (!(mantissa & 0x400)) { + mantissa <<= 1; + exponent--; + } + exponent++; + mantissa &= 0x3FF; + } + } else if (exponent == 31) { + // Infinity or NaN + converter.u = (sign << 31) | 0x7F800000; + if (mantissa) { + converter.u |= mantissa; // NaN + } + return converter.f; + } + + // Convert to FP32 + // FP32: 1 sign bit, 8 exponent bits (bias 127), 23 mantissa bits + // FP16: 1 sign bit, 5 exponent bits (bias 15), 10 mantissa bits + converter.u = (sign << 31) | ((exponent + 112) << 23) | (mantissa << 13); + return converter.f; +} + +// FP32 to FP16 conversion utility +inline uint16_t fp32_to_fp16(float fp32) { + // Union for safe type punning + union { + uint32_t u; + float f; + } converter; + converter.f = fp32; + + // Extract components from FP32 + uint32_t sign = (converter.u >> 31) & 0x1; + uint32_t exponent = (converter.u >> 23) & 0xFF; + uint32_t mantissa = converter.u & 0x7FFFFF; + + // Handle special cases + if (exponent == 0) { + // Zero or subnormal FP32 (very small, treat as zero in FP16) + return static_cast(sign << 15); + } else if (exponent == 255) { + // Infinity or NaN + uint16_t result = static_cast((sign << 15) | 0x7C00); + if (mantissa) { + // NaN - preserve some mantissa bits + result |= static_cast(mantissa >> 13); + } + return result; + } + + // Convert to FP16 + // FP32 exponent bias: 127, FP16 exponent bias: 15 + // New exponent = exponent - 127 + 15 = exponent - 112 + int32_t newExponent = static_cast(exponent) - 112; + + // Handle overflow/underflow + if (newExponent >= 31) { + // Overflow: return infinity + return static_cast((sign << 15) | 0x7C00); + } else if (newExponent <= 0) { + // Underflow: return zero (could also handle subnormals) + return static_cast(sign << 15); + } + + // Normal number: round to nearest even + uint32_t roundedMantissa = mantissa + 0x1000; // Add rounding bias + if (roundedMantissa & 0x800000) { + // Rounding caused overflow + roundedMantissa = 0; + newExponent++; + if (newExponent >= 31) { + // Overflow after rounding + return static_cast((sign << 15) | 0x7C00); + } + } + + // Assemble FP16 + return static_cast((sign << 15) | (newExponent << 10) | + (roundedMantissa >> 13)); +} + } // namespace infini #endif diff --git a/python/bindings/dtype.hpp b/python/bindings/dtype.hpp index 81e321a..4bd821a 100644 --- a/python/bindings/dtype.hpp +++ b/python/bindings/dtype.hpp @@ -13,25 +13,25 @@ namespace py = pybind11; namespace infini { void bind_data_type(py::module &m) { py::class_(m, "DType") - // 构造函数 + // Constructor .def(py::init()) - // 方法 + // Methods .def("get_size", &DataType::getSize, "Get the size in bytes") .def("get_type", &DataType::getType, "Get the underlying enum type") .def("to_string", &DataType::toString, "Get string representation") - // 属性(通过方法暴露) + // Properties (exposed through methods) .def_property_readonly("size", &DataType::getSize) .def_property_readonly("type", &DataType::getType) .def_property_readonly("name", &DataType::toString) - // 运算符重载 + // Operator overloading .def(py::self == py::self) .def(py::self != py::self); } inline DataType dtype_from_string(const std::string &dtype_str) { - // 从字符串创建DataType + // Create DataType from string static std::unordered_map str_to_dtype = { {"byte", INFINI_DTYPE_BYTE}, {"bool", INFINI_DTYPE_BOOL}, @@ -85,7 +85,7 @@ inline DataType dtype_from_string(const std::string &dtype_str) { } inline std::string dtype_to_string(const DataType &dtype) { - // 从DataType创建字符串 + // Create string from DataType static std::unordered_map dtype_to_str = { {INFINI_DTYPE_BYTE, "byte"}, {INFINI_DTYPE_BOOL, "bool"}, {INFINI_DTYPE_I8, "int8"}, {INFINI_DTYPE_I16, "int16"}, @@ -105,31 +105,11 @@ inline std::string dtype_to_string(const DataType &dtype) { std::to_string(dtype.getType())); } -// inline torch::ScalarType dtype_to_torch_scalar_type(const DataType &dtype) { -// static std::unordered_map dtype_to_torch{ -// {INFINI_DTYPE_BOOL, torch::kBool}, {INFINI_DTYPE_I8, torch::kInt8}, -// {INFINI_DTYPE_I16, torch::kInt16}, {INFINI_DTYPE_I32, -// torch::kInt32}, {INFINI_DTYPE_I64, torch::kInt64}, {INFINI_DTYPE_U8, -// torch::kUInt8}, {INFINI_DTYPE_U16, torch::kUInt16}, {INFINI_DTYPE_U32, -// torch::kUInt32}, {INFINI_DTYPE_U64, torch::kUInt64}, {INFINI_DTYPE_F16, -// torch::kFloat16}, {INFINI_DTYPE_F32, torch::kFloat32}, -// {INFINI_DTYPE_F64, torch::kFloat64}, {INFINI_DTYPE_BF16, -// torch::kBFloat16}}; -// auto it = dtype_to_torch.find(dtype.getType()); -// if (it != dtype_to_torch.end()) { -// return it->second; -// } -// throw std::runtime_error("Unsupported Convert DataType to torch: " + -// dtype.toString()); -// } - void bind_dtype_functions(py::module &m) { m.def("dtype_from_string", &dtype_from_string, "Create DataType from string", py::arg("dtype_str")) .def("dtype_to_string", &dtype_to_string, "Convert DataType to string", py::arg("dtype")); - // m.def("dtype_to_torch_scalar_type", &dtype_to_torch_scalar_type, - // "Convert DataType to torch::ScalarType", py::arg("dtype")); } } // namespace infini diff --git a/python/bindings/tensor.hpp b/python/bindings/tensor.hpp index 764e950..820c3d5 100644 --- a/python/bindings/tensor.hpp +++ b/python/bindings/tensor.hpp @@ -88,12 +88,8 @@ void bind_tensor(py::module &m) { auto stride_vec = py::cast(stride); auto dtype_str = dtype_to_string(data_type); auto data_ptr_int = reinterpret_cast(data_ptr); - return py::make_tuple(data_ptr_int, // 数据指针 - shape_vec, // 形状 - stride_vec, // 步长 - dtype_str, // 数据类型字符串 - self.getTotalBytes() // 存储大小 - ); + return py::make_tuple(data_ptr_int, shape_vec, stride_vec, + dtype_str, self.getTotalBytes()); }) .def("set_data", [](TensorObj &self, uintptr_t ptr, Runtime &runtime) { diff --git a/python/src/infinitensor/converter/registry.py b/python/src/infinitensor/converter/registry.py index 1e190c1..e6c2a51 100644 --- a/python/src/infinitensor/converter/registry.py +++ b/python/src/infinitensor/converter/registry.py @@ -12,7 +12,7 @@ def __init__(self): self._method_converters: Dict[str, Dict[Optional[str], Callable]] = {} def register(self, op_name: str, overload: Optional[str] = None): - """装饰器:注册方法和函数转换器""" + """Decorator: register method and function converters""" def decorator(func): self._method_converters.setdefault(op_name, {})[overload] = func @@ -20,8 +20,10 @@ def decorator(func): return decorator - def get_method_converter(self, op_name: str, overload: Optional[str] = None) -> Optional[Callable]: - """获取方法和函数转换器""" + def get_method_converter( + self, op_name: str, overload: Optional[str] = None + ) -> Optional[Callable]: + """Get method and function converter""" if op_name in self._method_converters: table = self._method_converters[op_name] if overload: @@ -37,9 +39,8 @@ def get_method_converter(self, op_name: str, overload: Optional[str] = None) -> else: raise ValueError(f"Unsupported op : {op_name}") - def update(self, custom_converters: Dict): - """更新转换器 + """Update converters Args: custom_converters: { @@ -56,15 +57,15 @@ def update(self, custom_converters: Dict): raise TypeError(f"Invalid key type: {type(key)}") def clear(self): - """清空所有转换器""" + """Clear all converters""" self._method_converters.clear() def list_all_converters(self): - """列出所有转换器""" + """List all converters""" return { "methods": list(self._method_converters.keys()), } -# 全局注册器实例 +# Global registry instance registry = ConverterRegistry() diff --git a/python/src/infinitensor/torch_fx_translator.py b/python/src/infinitensor/torch_fx_translator.py index 41a7dff..e4320ed 100644 --- a/python/src/infinitensor/torch_fx_translator.py +++ b/python/src/infinitensor/torch_fx_translator.py @@ -1,6 +1,13 @@ import ctypes import pyinfinitensor -from pyinfinitensor import GraphBuilder, Tensor, dtype_from_string, Runtime, ShapeExpr, StrideExpr +from pyinfinitensor import ( + GraphBuilder, + Tensor, + dtype_from_string, + Runtime, + ShapeExpr, + StrideExpr, +) import torch from torch import fx from torch.export import export, Dim @@ -16,18 +23,22 @@ def __init__(self, runtime: Runtime, custom_converters: Optional[Dict] = None): self.builder = None self.nodes_map: Dict[fx.Node, Any] = ( {} - ) # 存储fx.Node映射关系,不论是Tensor还是Callable - self.tensors: Dict[fx.Node, Tensor] = {} # 存储所有张量 - self.params: Dict[torch.Tensor, Tensor] = {} # 存储所有参数 - self.outputs: List[Tensor] = [] # 存储输出张量 + ) # Store fx.Node mapping relationship, whether Tensor or Callable + self.tensors: Dict[fx.Node, Tensor] = {} # Store all tensors + self.params: Dict[torch.Tensor, Tensor] = {} # Store all parameters + self.outputs: List[Tensor] = [] # Store output tensors self.input_vars: Dict[str, Tensor] = {} - self.symbols = {} # 符号 -> {'var': 变量名, 'value': 具体值, 'info': 详细信息} - self.dynamic_input_infos: List[Tuple[Tuple, Tuple, str]] = [] # 动态输入信息(shape, stride, dtype) + self.symbols = ( + {} + ) # Symbol -> {'var': variable name, 'value': concrete value, 'info': detailed info} + self.dynamic_input_infos: List[Tuple[Tuple, Tuple, str]] = ( + [] + ) # Dynamic input information (shape, stride, dtype) if custom_converters: registry.update(custom_converters) def _add_symbol(self, symbol_str, input_idx, dim_idx): - """添加符号信息""" + """Add symbol information""" if symbol_str in self.symbols: self.symbols[symbol_str]["info"]["input_idx"].append(input_idx) self.symbols[symbol_str]["info"]["dim_idx"].append(dim_idx) @@ -35,7 +46,7 @@ def _add_symbol(self, symbol_str, input_idx, dim_idx): var_name = f"symbolic_{symbol_str}" self.symbols[symbol_str] = { "var": var_name, - "value": None, # 初始化为None,表示未绑定 + "value": None, # Initialize to None, indicating unbound "info": { "input_idx": [input_idx], "dim_idx": [dim_idx], @@ -43,32 +54,30 @@ def _add_symbol(self, symbol_str, input_idx, dim_idx): } def _clear_symbols(self): - """清空符号信息""" + """Clear symbol information""" for symbol_str in self.symbols: self.symbols[symbol_str]["value"] = None def _add_dynamic_shapes(self, model, input_list): """ - 为每个 Tensor 生成: + Generate for each Tensor: arg_{i}: {0: Dim.AUTO, 1: Dim.AUTO, ...} """ sig = inspect.signature(model.forward) param_names = [p.name for p in sig.parameters.values() if p.name != "self"] - assert(len(param_names) == len(input_list)) + assert len(param_names) == len(input_list) dynamic_shapes = {} for idx, (p, t) in enumerate(zip(param_names, input_list)): if not isinstance(t, torch.Tensor): raise ValueError("input is not torch Tensor") - dynamic_shapes[p] = { - dim: Dim.AUTO for dim in range(t.dim()) - } + dynamic_shapes[p] = {dim: Dim.AUTO for dim in range(t.dim())} return dynamic_shapes def _create_input_tensors( self, input_list: List[torch.Tensor], is_real_tensor: bool ) -> List: - """创建输入张量""" - # dynamic_input_infos是通过从图文件中提取的动态形状信息,input_info是用户提供的静态形状信息 + """Create input tensors""" + # dynamic_input_infos is dynamic shape information extracted from graph files, input_info is static shape information provided by user input_tensors = [] if len(self.dynamic_input_infos) != 0 and len(input_list) != len( self.dynamic_input_infos @@ -88,55 +97,57 @@ def _create_input_tensors( self.input_vars[f"inp_{i}"] = tensor else: for i, (shape, stride, dtype) in enumerate(self.dynamic_input_infos): - tensor = self.builder.tensor(ShapeExpr(shape), dtype, StrideExpr(stride)) + tensor = self.builder.tensor( + ShapeExpr(shape), dtype, StrideExpr(stride) + ) input_tensors.append(tensor) self.input_vars[f"inp_{i}"] = tensor return input_tensors def _process_dynamic_shapes(self, fake_inputs): - """处理动态形状""" + """Handle dynamic shapes""" for i, tensor in enumerate(fake_inputs.values()): shape = tensor.shape stride = tensor.stride - assert(len(shape) == len(stride)) + assert len(shape) == len(stride) tensor_shape = [] tensor_stride = [] dtype = dtype_from_string(str(tensor.dtype)) for j, (dim, st) in enumerate(zip(shape, stride[::-1])): - # 处理shape信息 + # Handle shape information if ( hasattr(torch, "SymInt") and isinstance(dim, torch.SymInt) and not str(dim).isdigit() ): - # 处理符号维度 + # Handle symbolic dimension sym_str = str(dim) self._add_symbol(sym_str, i, j) tensor_shape.append(self.symbols[sym_str]["var"]) else: - # 具体维度 + # Concrete dimension tensor_shape.append(int(dim)) - # 处理stride信息 + # Handle stride information if ( hasattr(torch, "SymInt") and isinstance(st, torch.SymInt) and not str(st).isdigit() ): - # 处理符号维度 + # Handle symbolic dimension sym_str = str(st) - assert(self.symbols.get(sym_str)) + assert self.symbols.get(sym_str) tensor_stride.insert(0, self.symbols[sym_str]["var"]) else: - # 具体维度 + # Concrete dimension tensor_stride.insert(0, int(st)) self.dynamic_input_infos.append((tensor_shape, tensor_stride, dtype)) def _process_call_function(self, node): - """处理函数调用节点""" + """Handle function call nodes""" target = node.target - if hasattr(target,"_overloadpacket"): - op_name = str(target._overloadpacket).split('.')[-1] + if hasattr(target, "_overloadpacket"): + op_name = str(target._overloadpacket).split(".")[-1] overload = target._overloadname function = registry.get_method_converter(op_name, overload) else: @@ -155,7 +166,7 @@ def _process_call_function(self, node): raise ValueError(f"Unsupported function: {func_name}") def _process_output(self, node): - """处理输出节点""" + """Handle output nodes""" args = self._retrieve_args(node.args) assert len(args) == 1 if isinstance(args[0], (tuple, list)): @@ -181,7 +192,7 @@ def _retrieve_args(self, node): return node def _tensor_from_torch_info(self, torch_info): - """从Torch信息创建张量""" + """Create tensor from Torch information""" data_ptr_int, shape, stride, dtype_str, storage_size = torch_info dtype = getattr(torch, dtype_str) buf_type = ctypes.c_char * storage_size @@ -192,80 +203,91 @@ def _tensor_from_torch_info(self, torch_info): def _extract_graph_signature(self): fake_inputs = {} + def transform_parameter_string(s): - return re.sub(r'_', '.', re.sub(r'^p_', '', s)) + return re.sub(r"_", ".", re.sub(r"^p_", "", s)) + def transform_buffer_string(s): - return re.sub(r'_', '.', re.sub(r'^b_', '', s)) + return re.sub(r"_", ".", re.sub(r"^b_", "", s)) + nodes = list(self.module.graph_module.graph.nodes) for i, spec in enumerate(self.module.graph_signature.input_specs): kind = spec.kind.name node = nodes[i] name = spec.arg.name if kind == "PARAMETER": - assert(node.op == "placeholder" and isinstance( - node.meta["val"], torch._subclasses.fake_tensor.FakeTensor - )) + assert node.op == "placeholder" and isinstance( + node.meta["val"], torch._subclasses.fake_tensor.FakeTensor + ) shape_expr = ShapeExpr(node.meta["tensor_meta"].shape) stride_expr = StrideExpr(node.meta["tensor_meta"].stride) dtype = dtype_from_string(str(node.meta["tensor_meta"].dtype)) self.params[name] = self.builder.tensor(shape_expr, dtype, stride_expr) - self.params[name].set_data(self.module.state_dict[transform_parameter_string(name)].data_ptr(),self.runtime) + self.params[name].set_data( + self.module.state_dict[transform_parameter_string(name)].data_ptr(), + self.runtime, + ) self.nodes_map[node] = self.params[name] self.tensors[node] = self.params[name] elif kind == "BUFFER": if len(node.users) == 0: continue - assert(node.op == "placeholder" and isinstance( - node.meta["val"], torch._subclasses.fake_tensor.FakeTensor - )) + assert node.op == "placeholder" and isinstance( + node.meta["val"], torch._subclasses.fake_tensor.FakeTensor + ) shape_expr = ShapeExpr(node.meta["tensor_meta"].shape) stride_expr = StrideExpr(node.meta["tensor_meta"].stride) dtype = dtype_from_string(str(node.meta["tensor_meta"].dtype)) self.params[name] = self.builder.tensor(shape_expr, dtype, stride_expr) - self.params[name].set_data(self.module.state_dict[transform_buffer_string(name)].data_ptr(),self.runtime) + self.params[name].set_data( + self.module.state_dict[transform_buffer_string(name)].data_ptr(), + self.runtime, + ) self.nodes_map[node] = self.params[name] self.tensors[node] = self.params[name] elif kind == "USER_INPUT": if "val" in node.meta and isinstance( - node.meta["val"], torch._subclasses.fake_tensor.FakeTensor): + node.meta["val"], torch._subclasses.fake_tensor.FakeTensor + ): fake_tensor = node.meta["tensor_meta"] fake_inputs[node] = fake_tensor else: raise ValueError(f"Unsupported input kind: {kind}") - - return fake_inputs + return fake_inputs def import_from_fx( self, model, input_list: List[torch.Tensor], is_real_tensor: bool = False ): """ - 导入FX图到计算图框架 + Import FX graph to computation graph framework Args: model: PyTorch Model - input_list: 输入张量列表 + input_list: Input tensor list """ self.builder = GraphBuilder(self.runtime) dynamic_shapes = self._add_dynamic_shapes(model, input_list) try: - self.module = export(model, tuple(input_list), dynamic_shapes=dynamic_shapes) + self.module = export( + model, tuple(input_list), dynamic_shapes=dynamic_shapes + ) except: raise RuntimeError("Failed to export the PyTorch model to FX.") - # 解析graph_signature,提取params、buffers、inputs、outputs + # Parse graph_signature, extract params, buffers, inputs, outputs fake_inputs = self._extract_graph_signature() - # 提取符号形状信息 + # Extract symbolic shape information self._process_dynamic_shapes(fake_inputs) - # 创建输入张量 + # Create input tensors inputs = self._create_input_tensors(input_list, is_real_tensor) - for (node, tensor) in zip(fake_inputs.keys(), inputs): + for node, tensor in zip(fake_inputs.keys(), inputs): self.nodes_map[node] = tensor self.tensors[node] = tensor - # 处理FX图节点 + # Process FX graph nodes for node in self.module.graph_module.graph.nodes: if node.op == "placeholder": continue @@ -281,10 +303,10 @@ def import_from_fx( def run(self, input_list: List[torch.Tensor]): """ - 运行计算图 + Run computation graph Args: - input_list: 输入张量列表 + input_list: Input tensor list """ self._clear_symbols() if len(input_list) != len(self.dynamic_input_infos): @@ -318,10 +340,10 @@ def run(self, input_list: List[torch.Tensor]): def get_outputs(self) -> List[torch.Tensor]: """ - 获取输出Torch张量 + Get output Torch tensors Returns: - outputs: 输出Torch张量列表 + outputs: Output Torch tensor list """ outputs = [] for output in self.outputs: diff --git a/python/tests/conftest.py b/python/tests/conftest.py index b3a141e..0e37574 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -5,7 +5,7 @@ def pytest_addoption(parser): - """添加自定义命令行选项""" + """Add custom command line options""" parser.addoption( "--device", action="store", @@ -30,7 +30,7 @@ def pytest_addoption(parser): @pytest.fixture(scope="session") def device_type(request): - """从命令行参数获取设备类型""" + """Get device type from command line arguments""" device_name = request.config.getoption("--device").upper() device_map = { @@ -50,13 +50,13 @@ def device_type(request): @pytest.fixture(scope="session") def device_id(request): - """从命令行参数获取设备ID""" + """Get device ID from command line arguments""" return int(request.config.getoption("--device-id")) @pytest.fixture(scope="session") def runtime(device_type, device_id): - """根据命令行参数创建runtime""" + """Create runtime based on command line arguments""" print(f"\n{'='*60}") print(f"Creating runtime with:") print(f" Device Type: {device_type}") @@ -64,14 +64,14 @@ def runtime(device_type, device_id): print(f"{'='*60}\n") try: - # 如果有设置设备ID的API + # If there is an API to set device ID rt = Runtime.setup(device_type, device_id=device_id) print(f"✅ Runtime created successfully") return rt except Exception as e: print(f"❌ Failed to create runtime: {e}") - # 回退到CPU + # Fall back to CPU if device_type != DeviceType.CPU: print("🔄 Falling back to CPU...") return Runtime.setup(DeviceType.CPU) @@ -81,6 +81,6 @@ def runtime(device_type, device_id): @pytest.fixture def torch_rng_seed(): - """固定随机种子,确保测试可重现""" + """Fix random seed to ensure test reproducibility""" torch.manual_seed(42) yield 42 diff --git a/python/tests/test_torch_fx_translator.py b/python/tests/test_torch_fx_translator.py index 66abb0f..356b826 100644 --- a/python/tests/test_torch_fx_translator.py +++ b/python/tests/test_torch_fx_translator.py @@ -7,57 +7,57 @@ def test_basic_matmul(runtime, torch_rng_seed): - """直接使用conftest.py中定义的fixtures""" + """Use fixtures defined in conftest.py directly""" print(f"Testing with runtime on device: {runtime}") print(f"Random seed: {torch_rng_seed}") - # 创建简单模型 + # Create simple model class MatmulModel(torch.nn.Module): def forward(self, x, y): return torch.matmul(x, y) model = MatmulModel() - # 随机初始化输入,传入形状可以与真实传入值不一样,但是数据类型需要一致 + # Randomly initialize inputs, passed shapes can differ from actual values, but data types must match input_info = [((5, 4), "float32"), ((4, 3), "float32")] input_tensors = [ torch.as_tensor(np.random.randn(*shape).astype(dtype)) for shape, dtype in input_info ] - # 创建转换器 + # Create translator translator = TorchFXTranslator(runtime) translator.import_from_fx(model, input_tensors) - # 运行 + # Run translator.run(input_tensors) - # 获取输出 + # Get outputs outputs = translator.get_outputs() - # 验证 + # Verify assert len(outputs) == 1 assert outputs[0].shape == (1, 5, 3) print("✅ Test passed!") def test_dynamic_matmul(runtime, torch_rng_seed): - """直接使用conftest.py中定义的fixtures""" + """Use fixtures defined in conftest.py directly""" print(f"Testing with runtime on device: {runtime}") print(f"Random seed: {torch_rng_seed}") - # 创建简单模型 + # Create simple model class MatmulModel(torch.nn.Module): def forward(self, x, y): return torch.matmul(x, y) model = MatmulModel() - # 随机初始化输入,传入形状可以与真实传入值不一样,但是数据类型需要一致 + # Randomly initialize inputs, passed shapes can differ from actual values, but data types must match input_info = [((5, 4), "float32"), ((4, 7), "float32")] input_tensors = [ torch.as_tensor(np.random.randn(*shape).astype(dtype)) for shape, dtype in input_info ] - # 创建转换器 + # Create translator translator = TorchFXTranslator(runtime) translator.import_from_fx(model, input_tensors) @@ -80,48 +80,50 @@ def forward(self, x, y): assert outputs[0].shape == (1, 3, 10) print("✅ Test passed!") + def test_basic_elementwise(runtime, torch_rng_seed): - """直接使用conftest.py中定义的fixtures""" + """Use fixtures defined in conftest.py directly""" print(f"Testing with runtime on device: {runtime}") print(f"Random seed: {torch_rng_seed}") - # 创建简单模型 + # Create simple model class AddModel(torch.nn.Module): def forward(self, x, y): return x + y model = AddModel() - # 随机初始化输入,传入形状可以与真实传入值不一样,但是数据类型需要一致 + # Randomly initialize inputs, passed shapes can differ from actual values, but data types must match input_info = [((5, 4), "float32"), ((3, 5, 1), "float32")] input_tensors = [ torch.as_tensor(np.random.randn(*shape).astype(dtype)) for shape, dtype in input_info ] - # 创建转换器 + # Create translator translator = TorchFXTranslator(runtime) translator.import_from_fx(model, input_tensors) translator.run(input_tensors) - # 获取输出 + # Get outputs outputs = translator.get_outputs() - # 验证 + # Verify assert len(outputs) == 1 assert outputs[0].shape == (3, 5, 4) print("✅ Test passed!") + if __name__ == "__main__": - # 可以直接运行这个文件 + # Can run this file directly import sys - # 使用pytest运行所有测试 + # Run all tests using pytest exit_code = pytest.main( [ __file__, - "-v", # 详细输出 - "-s", # 显示print输出 - "--tb=short", # 简化的错误回溯 + "-v", # Verbose output + "-s", # Show print output + "--tb=short", # Simplified error traceback ] ) diff --git a/src/core/expr.cc b/src/core/expr.cc index 597b6b3..ae1418f 100644 --- a/src/core/expr.cc +++ b/src/core/expr.cc @@ -2,7 +2,7 @@ namespace infini { //================================== -// Static Factory Functions 实现 +// Static Factory Functions Implementation //================================== Expr ExprObj::constant(ElementType value) { return Expr(new ConstantExprObj(value)); @@ -41,7 +41,7 @@ Expr ExprObj::createMax(const Expr &lhs, const Expr &rhs) { } //================================== -// 运算符重载实现 +// Operator Overload Implementation //================================== Expr operator+(const Expr &lhs, const Expr &rhs) { return ExprObj::createAdd(lhs, rhs); @@ -68,7 +68,7 @@ bool operator==(const Expr &lhs, const Expr &rhs) { return lhs->equals(rhs); } bool operator!=(const Expr &lhs, const Expr &rhs) { return !(lhs == rhs); } //================================== -// ConstantExprObj 实现 +// ConstantExprObj Implementation //================================== ConstantExprObj::ConstantExprObj(ElementType v) : value(v) {} @@ -97,7 +97,7 @@ bool ConstantExprObj::equals(const Expr &other) const { std::optional ConstantExprObj::asConstant() const { return value; } //================================== -// VariableExprObj 实现 +// VariableExprObj Implementation //================================== VariableExprObj::VariableExprObj(std::string n) : name(std::move(n)) {} @@ -126,7 +126,7 @@ bool VariableExprObj::equals(const Expr &other) const { } //================================== -// BinaryExprObj 实现 +// BinaryExprObj Implementation //================================== BinaryExprObj::BinaryExprObj(Expr l, Expr r) : lhs(std::move(l)), rhs(std::move(r)) {} @@ -147,7 +147,7 @@ bool BinaryExprObj::equals(const Expr &other) const { } //================================== -// Achieve for simple binary expr +// Implementation for simple binary expressions //================================== #define IMPLEMENT_BINARY_EXPR(CLASS, TYPE_ENUM, OP, STR) \ ExprObj::Type CLASS::getType() const { return Type::TYPE_ENUM; } \ @@ -185,7 +185,7 @@ IMPLEMENT_BINARY_EXPR(DivExprObj, DIV, /, " / ") IMPLEMENT_BINARY_EXPR(ModExprObj, MOD, %, " % ") //================================== -// MinExprObj 实现 +// MinExprObj Implementation //================================== ExprObj::Type MinExprObj::getType() const { return Type::MIN; } @@ -213,7 +213,7 @@ Expr MinExprObj::simplify() const { } //================================== -// MaxExprObj 实现 +// MaxExprObj Implementation //================================== ExprObj::Type MaxExprObj::getType() const { return Type::MAX; } @@ -241,7 +241,7 @@ Expr MaxExprObj::simplify() const { } //================================== -// BaseExprObj 实现 +// BaseExprObj Implementation //================================== BaseExprObj::BaseExprObj() = default; @@ -304,7 +304,7 @@ void BaseExprObj::insert(size_t pos, const Expr &value) { } //================================== -// ShapeExprObj 实现 +// ShapeExprObj Implementation //================================== std::optional> ShapeExprObj::evaluate( const std::unordered_map &values) const { @@ -337,7 +337,7 @@ Shape ShapeExprObj::getConstantValue() const { } //================================== -// StrideExprObj 实现 +// StrideExprObj Implementation //================================== std::optional> StrideExprObj::evaluate( const std::unordered_map &values) const { @@ -370,7 +370,7 @@ Stride StrideExprObj::getConstantValue() const { } //================================== -// ShapeExpr比较运算符实现 +// ShapeExpr Comparison Operator Implementation //================================== bool operator==(const ShapeExpr &lhs, const ShapeExpr &rhs) { return lhs->equals(rhs); diff --git a/src/core/graph.cc b/src/core/graph.cc index e52830a..f578231 100644 --- a/src/core/graph.cc +++ b/src/core/graph.cc @@ -127,7 +127,7 @@ bool GraphObj::topo_sort() { } if (sorted.size() != ops.size()) { - // 有环,拓扑失败 + // Cycle detected, topological sort failed return false; } @@ -190,26 +190,26 @@ void GraphObj::addOperatorAndConnect(const Operator &op) { } bool GraphObj::checkValid() const { - // 构建快速查找集合 + // Build fast lookup sets std::unordered_set tensorSet(tensors.begin(), tensors.end()); std::unordered_set opSet(ops.begin(), ops.end()); - // 1. 检查所有 Tensor + // 1. Check all Tensors for (auto tensor : tensors) { - // 必须有 source 或 targets + // Must have source or targets IT_ASSERT( !(tensor->getTargets().empty() && tensor->getSource() == nullptr), "Invalid tensor: " + tensor->toString() + " has no source and no targets"); - // 检查 target ops 是否都在 Graph + // Check if all target ops are in Graph for (auto op : tensor->getTargets()) { IT_ASSERT(opSet.count(op), "Tensor " + tensor->toString() + " has target op not in graph: " + op->toString()); } - // 检查 source op 是否在 Graph + // Check if source op is in Graph if (auto src = tensor->getSource()) { IT_ASSERT(opSet.count(src), "Tensor " + tensor->toString() + @@ -217,23 +217,23 @@ bool GraphObj::checkValid() const { } } - // 2. 检查所有 Operator + // 2. Check all Operators for (auto op : ops) { - // 输入 tensor 必须存在 + // Input tensor must exist for (auto tensor : op->getInputs()) { IT_ASSERT( tensorSet.count(tensor), "Op " + op->toString() + " has input tensor not in graph: " + tensor->toString()); } - // 输出 tensor 必须存在 + // Output tensor must exist for (auto tensor : op->getOutputs()) { IT_ASSERT( tensorSet.count(tensor), "Op " + op->toString() + " has output tensor not in graph: " + tensor->toString()); } - // 前驱/后继必须存在 + // Predecessors/Successors must exist for (auto pre : op->getPredecessors()) { IT_ASSERT(opSet.count(pre), "Op " + op->toString() + @@ -246,7 +246,7 @@ bool GraphObj::checkValid() const { } } - // 3. 检查 Tensor 的 FUID 唯一性 + // 3. Check uniqueness of Tensor FUIDs std::unordered_set fuids; for (auto tensor : tensors) { IT_ASSERT(fuids.insert(tensor->getFuid()).second, @@ -254,7 +254,7 @@ bool GraphObj::checkValid() const { std::to_string(tensor->getFuid())); } - // 4. 检查双向一致性 + // 4. Check bidirectional consistency for (auto tensor : tensors) { for (auto targetOp : tensor->getTargets()) { auto &inputs = targetOp->getInputs(); diff --git a/src/core/runtime.cc b/src/core/runtime.cc index 0022432..c49321e 100644 --- a/src/core/runtime.cc +++ b/src/core/runtime.cc @@ -10,29 +10,29 @@ Runtime &RuntimeObj::getInstance() { } RuntimeObj::~RuntimeObj() { - // 清理所有线程的 Context + // Clean up all thread Contexts std::unique_lock lock(ctx_mutex); - // 只清空map,不手动释放CUDA资源 - // CUDA运行时会在程序退出时自动清理所有资源 + // Only clear the map, do not manually release CUDA resources + // CUDA runtime will automatically clean up all resources on program exit threadContexts.clear(); } void RuntimeObj::initThreadContext(infiniDevice_t device, int deviceId) { auto current_tid = std::this_thread::get_id(); - // 检测线程复用 + // Check for thread reuse if (tls_context_cache && tls_thread_id == current_tid && tls_context_cache->device == device && tls_context_cache->deviceId == deviceId) { - return; // 已初始化且设备相同,无需重新初始化 + return; } CHECK_INFINI_ERROR(infinirtSetDevice(device, deviceId)); - // 创建新的 stream + // Create new stream infinirtStream_t stream = nullptr; CHECK_INFINI_ERROR(infinirtStreamCreate(&stream)); - // 创建新的 Context + // Create new Context Context ctx = std::make_shared(); ctx->device = device; ctx->deviceId = deviceId; @@ -40,7 +40,7 @@ void RuntimeObj::initThreadContext(infiniDevice_t device, int deviceId) { ctx->workspaceSize = 7ll << 30; // 7GB ctx->workspace = nullptr; - // 更新缓存和全局 map + // Update cache and global map tls_context_cache = ctx; tls_thread_id = current_tid; @@ -54,12 +54,12 @@ void RuntimeObj::initThreadContext(infiniDevice_t device, int deviceId) { Context RuntimeObj::getCurrentThreadContext() const { auto current_tid = std::this_thread::get_id(); - // 检查缓存有效性 + // Check cache validity if (tls_context_cache && tls_thread_id == current_tid) { return tls_context_cache; } - // 从全局 map 查找 + // Search in global map { std::shared_lock lock(ctx_mutex); auto it = threadContexts.find(current_tid); @@ -77,12 +77,12 @@ Context RuntimeObj::getCurrentThreadContext() const { void RuntimeObj::setCurrentDevice(infiniDevice_t device, int deviceId) { auto ctx = getCurrentThreadContext(); - // 如果设备相同,直接返回 + // If device is the same, return directly if (ctx->device == device && ctx->deviceId == deviceId) { return; } - // 重新初始化 Context(force=true) + // Re-initialize Context (force=true) initThreadContext(device, deviceId); } @@ -96,7 +96,8 @@ void RuntimeObj::run(const Graph &graph) const { auto ctx = getCurrentThreadContext(); IT_ASSERT(graph->checkBeforRun()); - // TODO: 目前仅支持单卡,后续支持多卡 + // TODO: Currently only supports single device, multi-device support coming + // later const auto &kernelRegistry = KernelRegistry::getInstance(); for (auto &op : graph->getOperators()) { auto kernelAttrs = @@ -135,10 +136,10 @@ void RuntimeObj::deallocDevice(void *ptr) { void RuntimeObj::memcpy(void *dst, const void *src, size_t size, infinirtMemcpyKind_t kind) { - // 基本指针有效性检查 + // Basic pointer validity check if (dst == nullptr || src == nullptr) { std::cerr << "[ERROR] memcpy called with null pointer!" << std::endl; - // 这里应该抛出异常或返回错误,而不是继续 + // Should throw exception or return error here, not continue throw std::runtime_error("Null pointer in memcpy"); } @@ -184,27 +185,4 @@ bool RuntimeObj::isCpu() const { return context->device == INFINI_DEVICE_CPU; } -// void RuntimeObj::initWorkspace(size_t size) { -// auto ctx = getCurrentThreadContext(); - -// // 如果已分配且大小足够,直接返回 -// if (ctx->workspace && ctx->workspaceSize >= size) { -// return; -// } - -// // 释放旧的 workspace -// if (ctx->workspace) { -// infinirtFree(ctx->workspace); -// } - -// // CPU设备不需要调用setDevice,避免与GPU线程冲突 -// if (ctx->device != INFINI_DEVICE_CPU) { -// CHECK_INFINI_ERROR(infinirtSetDevice(ctx->device, ctx->deviceId)); -// } - -// // 分配新的 workspace -// ctx->workspaceSize = size; -// ctx->workspace = nullptr; -// CHECK_INFINI_ERROR(infinirtMalloc(&ctx->workspace, size)); -// } } // namespace infini diff --git a/src/core/tensor.cc b/src/core/tensor.cc index 018b1d4..c30ad27 100644 --- a/src/core/tensor.cc +++ b/src/core/tensor.cc @@ -102,7 +102,7 @@ ElementType TensorObj::getStorageSize() const { size_t min_offset = 0; size_t storageSize = 1; if (constant_shape.empty()) { - return storageSize; // 标量 Tensor + return storageSize; // Scalar Tensor } for (auto i = 0; i < getRank(); ++i) { if (constant_stride[i] >= 0) { diff --git a/src/operators/Gemm.cc b/src/operators/Gemm.cc index 613dd5d..6f388ef 100644 --- a/src/operators/Gemm.cc +++ b/src/operators/Gemm.cc @@ -44,7 +44,7 @@ optional> GemmObj::inferShape() { IT_ASSERT(shapeA->size() >= 2 && shapeB->size() >= 2); Expr batchA = (shapeA->size() == 3) ? (*shapeA)[0] : ExprObj::constant(1); Expr batchB = (shapeB->size() == 3) ? (*shapeB)[0] : ExprObj::constant(1); - // 广播 batch 维度 + // broadcast batch dims Expr batch; if (batchA == batchB) batch = batchA; diff --git a/src/utils/utils.cc b/src/utils/utils.cc deleted file mode 100644 index 3ee4ce0..0000000 --- a/src/utils/utils.cc +++ /dev/null @@ -1,131 +0,0 @@ -#include "utils/utils.h" - -namespace infini { - -StrideExpr broadcastStride(const ShapeExpr &inputShape, - const StrideExpr &inputStride, - const ShapeExpr &outputShape) { - IT_ASSERT(inputShape->size() == inputStride->size(), - "Input shape and stride must have the same rank"); - - size_t inputRank = inputShape->size(); - size_t outputRank = outputShape->size(); - - // 从右向左对齐维度(NumPy广播规则) - std::vector broadcastedStride(outputRank); - - for (size_t outIdx = 0; outIdx < outputRank; ++outIdx) { - // 计算对应的输入维度索引(从右向左对齐) - int inIdx = static_cast(outIdx) - static_cast(outputRank) + - static_cast(inputRank); - - if (inIdx < 0) { - // 输入维度不存在,相当于维度为1,stride为0 - broadcastedStride[outIdx] = ExprObj::constant(0); - } else { - // 检查输入维度是否为1(广播维度)或与输出维度相等 - const Expr &inputDim = (*inputShape)[inIdx]; - const Expr &outputDim = (*outputShape)[outIdx]; - - // 如果输入维度是常量1,则是广播维度,stride为0 - auto inputDimConst = inputDim->asConstant(); - if (inputDimConst.has_value() && *inputDimConst == 1) { - broadcastedStride[outIdx] = ExprObj::constant(0); - } else { - // 维度相等或不相等但有效(由infer_broadcast保证),使用原始stride - broadcastedStride[outIdx] = (*inputStride)[inIdx]; - } - } - } - - return make_ref(broadcastedStride); -} - -ShapeExpr infer_broadcast(const ShapeExpr &A, const ShapeExpr &B) { - size_t rankA = A->size(); - size_t rankB = B->size(); - size_t rank = std::max(rankA, rankB); - - std::vector resultDims; - for (size_t i = 0; i < rank; ++i) { - // 从右向左对齐(NumPy广播规则) - int idxA = static_cast(i) - static_cast(rank) + - static_cast(rankA); - int idxB = static_cast(i) - static_cast(rank) + - static_cast(rankB); - - // 如果索引为负,表示该维度不存在,相当于维度为1 - Expr aDim = (idxA < 0) ? ExprObj::constant(1) : (*A)[idxA]; - Expr bDim = (idxB < 0) ? ExprObj::constant(1) : (*B)[idxB]; - - // 验证广播规则:维度必须相等,或者其中一个为1 - IT_ASSERT(aDim == bDim || aDim == ExprObj::constant(1) || - bDim == ExprObj::constant(1)); - - // 广播结果:如果a维度为1则取b维度,否则取a维度 - auto shapeEle = aDim == ExprObj::constant(1) ? bDim : aDim; - resultDims.emplace_back(shapeEle); - } - - return make_ref(resultDims); -} - -size_t calculateLinearOffset(size_t index, Shape shape, Stride stride) { - size_t rank = shape.size(); - std::vector indices(rank); - size_t remaining = index; - for (size_t i = 0; i < rank; ++i) { - size_t dim = rank - 1 - i; - indices[dim] = remaining % shape.at(dim); - remaining /= shape.at(dim); - } - size_t offset = 0; - for (size_t i = 0; i < rank; ++i) { - offset += indices[i] * stride.at(i); - } - return offset; -} - -float fp16_to_fp32(uint16_t fp16) { - // Union for safe type punning - union { - uint32_t u; - float f; - } converter; - - // Extract components from FP16 - uint32_t sign = (fp16 >> 15) & 0x1; - uint32_t exponent = (fp16 >> 10) & 0x1F; - uint32_t mantissa = fp16 & 0x3FF; - - // Handle special cases - if (exponent == 0) { - if (mantissa == 0) { - // Zero - converter.u = sign << 31; - return converter.f; - } else { - // Subnormal number: normalize it - while (!(mantissa & 0x400)) { - mantissa <<= 1; - exponent--; - } - exponent++; - mantissa &= 0x3FF; - } - } else if (exponent == 31) { - // Infinity or NaN - converter.u = (sign << 31) | 0x7F800000; - if (mantissa) { - converter.u |= mantissa; // NaN - } - return converter.f; - } - - // Convert to FP32 - // FP32: 1 sign bit, 8 exponent bits (bias 127), 23 mantissa bits - // FP16: 1 sign bit, 5 exponent bits (bias 15), 10 mantissa bits - converter.u = (sign << 31) | ((exponent + 112) << 23) | (mantissa << 13); - return converter.f; -} -} // namespace infini diff --git a/test/core/test_expr.cc b/test/core/test_expr.cc index e53cfdc..2747752 100644 --- a/test/core/test_expr.cc +++ b/test/core/test_expr.cc @@ -7,12 +7,12 @@ class ExprTest : public testing::Test { protected: void SetUp() override {} - // 辅助函数:创建测试环境 + // Helper function: create test environment std::unordered_map createTestEnv() { return {{"a", 5}, {"b", 3}, {"c", 2}}; } }; -// 测试常量表达式 +// Test constant expression TEST_F(ExprTest, ConstantExpr) { auto expr = ExprObj::constant(42); @@ -32,7 +32,7 @@ TEST_F(ExprTest, ConstantExpr) { EXPECT_FALSE(expr->equals(ExprObj::constant(43))); } -// 测试变量表达式 +// Test variable expression TEST_F(ExprTest, VariableExpr) { auto expr = ExprObj::variable("x"); @@ -43,10 +43,10 @@ TEST_F(ExprTest, VariableExpr) { EXPECT_EQ(variables.size(), 1); EXPECT_TRUE(variables.count("x") > 0); - // 未提供变量值 + // Variable value not provided EXPECT_FALSE(expr->evaluate({}).has_value()); - // 提供变量值 + // Variable value provided auto env = createTestEnv(); EXPECT_TRUE(expr->evaluate({{"x", 10}}).has_value()); EXPECT_EQ(*expr->evaluate({{"x", 10}}), 10); @@ -55,7 +55,7 @@ TEST_F(ExprTest, VariableExpr) { EXPECT_FALSE(expr->equals(ExprObj::variable("y"))); } -// 测试加法表达式 +// Test addition expression TEST_F(ExprTest, AddExpr) { auto a = ExprObj::variable("a"); auto b = ExprObj::variable("b"); @@ -69,14 +69,14 @@ TEST_F(ExprTest, AddExpr) { EXPECT_TRUE(result.has_value()); EXPECT_EQ(*result, 8); // 5 + 3 = 8 - // 测试常量折叠 + // Test constant folding auto constExpr = ExprObj::constant(2) + ExprObj::constant(3); auto constResult = constExpr->simplify(); EXPECT_EQ(constResult->getType(), ExprObj::Type::CONSTANT); EXPECT_EQ(*constResult->asConstant(), 5); } -// 测试减法表达式 +// Test subtraction expression TEST_F(ExprTest, SubExpr) { auto a = ExprObj::variable("a"); auto b = ExprObj::variable("b"); @@ -91,7 +91,7 @@ TEST_F(ExprTest, SubExpr) { EXPECT_EQ(*result, 2); // 5 - 3 = 2 } -// 测试乘法表达式 +// Test multiplication expression TEST_F(ExprTest, MulExpr) { auto a = ExprObj::variable("a"); auto b = ExprObj::variable("b"); @@ -106,7 +106,7 @@ TEST_F(ExprTest, MulExpr) { EXPECT_EQ(*result, 15); // 5 * 3 = 15 } -// 测试除法表达式 +// Test division expression TEST_F(ExprTest, DivExpr) { auto a = ExprObj::variable("a"); auto b = ExprObj::variable("b"); @@ -118,15 +118,15 @@ TEST_F(ExprTest, DivExpr) { auto env = createTestEnv(); auto result = expr->evaluate(env); EXPECT_TRUE(result.has_value()); - EXPECT_EQ(*result, 1); // 5 / 3 = 1 (整数除法) + EXPECT_EQ(*result, 1); // 5 / 3 = 1 (integer division) - // 测试除以0的情况 + // Test division by zero auto zeroExpr = a / ExprObj::constant(0); auto zeroResult = zeroExpr->evaluate(env); EXPECT_FALSE(zeroResult.has_value()); } -// 测试取模表达式 +// Test modulo expression TEST_F(ExprTest, ModExpr) { auto a = ExprObj::variable("a"); auto b = ExprObj::variable("b"); @@ -141,7 +141,7 @@ TEST_F(ExprTest, ModExpr) { EXPECT_EQ(*result, 2); // 5 % 3 = 2 } -// 测试min表达式 +// Test min expression TEST_F(ExprTest, MinExpr) { auto a = ExprObj::variable("a"); auto b = ExprObj::variable("b"); @@ -156,7 +156,7 @@ TEST_F(ExprTest, MinExpr) { EXPECT_EQ(*result, 3); // min(5, 3) = 3 } -// 测试max表达式 +// Test max expression TEST_F(ExprTest, MaxExpr) { auto a = ExprObj::variable("a"); auto b = ExprObj::variable("b"); @@ -186,7 +186,7 @@ TEST_F(ExprTest, ComplexExpression) { EXPECT_TRUE(result.has_value()); EXPECT_EQ(*result, 16); // (5 + 3) * 2 = 16 - // 测试嵌套表达式 + // Test nested expressions auto complexExpr = (a * b) + (b / c) - (a % c); EXPECT_EQ(complexExpr->toString(), "(((a * b) + (b / c)) - (a % c))"); @@ -195,20 +195,20 @@ TEST_F(ExprTest, ComplexExpression) { EXPECT_EQ(*complexResult, 15); // (5*3) + (3/2) - (5%2) = 15 + 1 - 1 = 15 } -// 测试表达式相等性 +// Test expression equality TEST_F(ExprTest, ExpressionEquality) { auto a = ExprObj::variable("a"); auto b = ExprObj::variable("b"); auto expr1 = a + b; auto expr2 = a + b; - auto expr3 = - b + a; // 交换律,但表达式结构不同,目前还不支持交换律、结合律等等 + auto expr3 = b + a; // Commutative, but expression structures differ; + // currently commutative, associative laws not supported EXPECT_TRUE(expr1->equals(expr2)); EXPECT_FALSE(expr1->equals(expr3)); - // 常量折叠后的相等性 + // Equality after constant folding auto const1 = (ExprObj::constant(2) + ExprObj::constant(3))->simplify(); auto const2 = ExprObj::constant(5); diff --git a/test/core/test_graph.cc b/test/core/test_graph.cc index 091f9bc..04acf93 100644 --- a/test/core/test_graph.cc +++ b/test/core/test_graph.cc @@ -11,7 +11,7 @@ class GraphBasicTest : public testing::Test { void SetUp() override { runtime = make_ref(); } }; -// 测试Graph构造和运行时获取 +// Test Graph construction and runtime retrieval TEST_F(GraphBasicTest, GraphConstruction) { auto graph = make_ref(runtime); @@ -20,7 +20,7 @@ TEST_F(GraphBasicTest, GraphConstruction) { EXPECT_TRUE(graph->getOperators().empty()); } -// 测试添加常量形状Tensor +// Test adding constant shape Tensor TEST_F(GraphBasicTest, AddConcreteTensor) { auto graph = make_ref(runtime); @@ -34,7 +34,7 @@ TEST_F(GraphBasicTest, AddConcreteTensor) { EXPECT_EQ(tensor2->getDataType(), DataType(INFINI_DTYPE_F16)); } -// 测试添加符号形状Tensor +// Test adding symbolic shape Tensor TEST_F(GraphBasicTest, AddSymbolicTensor) { auto graph = make_ref(runtime); @@ -51,16 +51,16 @@ TEST_F(GraphBasicTest, AddSymbolicTensor) { EXPECT_TRUE(tensor->getShape()->isDynamic()); } -// 测试添加带步长的Tensor +// Test adding Tensor with stride TEST_F(GraphBasicTest, AddTensorWithStride) { auto graph = make_ref(runtime); - // 常量步长 + // Constant stride auto tensor1 = graph->addTensor({2, 3, 4}, {12, 4, 1}, DataType(INFINI_DTYPE_F32)); EXPECT_TRUE(tensor1->getStride()->isConcrete()); - // 步长表达式 + // Stride expression auto strideExpr = StrideExpr(new StrideExprObj( {ExprObj::constant(12), ExprObj::constant(4), ExprObj::constant(1)})); auto tensor2 = @@ -68,7 +68,7 @@ TEST_F(GraphBasicTest, AddTensorWithStride) { EXPECT_TRUE(tensor2->getStride()->isConcrete()); } -// 测试批量添加Tensor +// Test batch adding Tensors TEST_F(GraphBasicTest, AddTensorVector) { auto graph = make_ref(runtime); @@ -86,11 +86,11 @@ TEST_F(GraphBasicTest, AddTensorVector) { EXPECT_EQ(added[2], tensors[2]); } -// 测试移除Tensor和Operator +// Test removing Tensor and Operator TEST_F(GraphBasicTest, RemoveTensorAndOperator) { auto graph = make_ref(runtime); - // 创建Tensor和Operator + // Create Tensor and Operator auto A = graph->addTensor({2, 3}, DataType(INFINI_DTYPE_F32)); auto B = graph->addTensor({3, 4}, DataType(INFINI_DTYPE_F32)); auto Y = graph->addTensor({1, 2, 4}, DataType(INFINI_DTYPE_F32)); @@ -99,11 +99,11 @@ TEST_F(GraphBasicTest, RemoveTensorAndOperator) { EXPECT_EQ(graph->getTensors().size(), 3); EXPECT_EQ(graph->getOperators().size(), 1); - // 移除Tensor + // Remove Tensor graph->removeTensor(A); EXPECT_EQ(graph->getTensors().size(), 2); - // 移除Operator + // Remove Operator graph->removeOperator(gemm); EXPECT_EQ(graph->getOperators().size(), 0); } diff --git a/test/core/test_shape_expr.cc b/test/core/test_shape_expr.cc index 9700a5f..2d66ac6 100644 --- a/test/core/test_shape_expr.cc +++ b/test/core/test_shape_expr.cc @@ -12,7 +12,7 @@ class ShapeExprTest : public testing::Test { } }; -// 测试常量形状 +// Test constant shape TEST_F(ShapeExprTest, ConcreteShape) { std::vector dims = {ExprObj::constant(32), ExprObj::constant(224), ExprObj::constant(224), ExprObj::constant(3)}; @@ -36,7 +36,7 @@ TEST_F(ShapeExprTest, ConcreteShape) { EXPECT_EQ(evalResult->size(), 4); } -// 测试符号形状 +// Test symbolic shape TEST_F(ShapeExprTest, SymbolicShape) { std::vector dims = {ExprObj::variable("batch"), ExprObj::constant(224), ExprObj::constant(224), @@ -64,14 +64,14 @@ TEST_F(ShapeExprTest, SymbolicShape) { EXPECT_EQ((*evalResult)[3], 3); // channels = 3 } -// 测试混合表达式形状 +// Test mixed expression shape TEST_F(ShapeExprTest, MixedExpressionShape) { auto batch = ExprObj::variable("batch"); auto height = ExprObj::constant(224); auto width = ExprObj::constant(224); auto channels = ExprObj::constant(3); - // 使用表达式计算维度 + // Use expressions to compute dimensions auto paddedHeight = height + ExprObj::constant(2); auto paddedWidth = width + ExprObj::constant(2); @@ -89,7 +89,7 @@ TEST_F(ShapeExprTest, MixedExpressionShape) { EXPECT_EQ((*evalResult)[2], 226); // 224 + 2 EXPECT_EQ((*evalResult)[3], 3); // 3 - // 测试简化 + // Test simplification auto simplified = shape->simplify(); EXPECT_EQ(simplified->toString(), "[batch, 226, 226, 3]"); } diff --git a/test/core/test_stride_expr.cc b/test/core/test_stride_expr.cc index 3f88cc7..5ff7a0a 100644 --- a/test/core/test_stride_expr.cc +++ b/test/core/test_stride_expr.cc @@ -7,7 +7,7 @@ class StrideExprTest : public testing::Test { void SetUp() override {} }; -// 测试步长计算 +// Test stride computation TEST_F(StrideExprTest, StrideComputation) { auto shape = ShapeExpr( new ShapeExprObj({ExprObj::constant(32), ExprObj::constant(224), @@ -31,14 +31,14 @@ TEST_F(StrideExprTest, StrideComputation) { EXPECT_EQ(constantValue[3], 1); } -// 测试符号步长 +// Test symbolic stride TEST_F(StrideExprTest, SymbolicStride) { auto batch = ExprObj::variable("batch"); auto height = ExprObj::variable("height"); auto width = ExprObj::constant(224); auto channels = ExprObj::constant(3); - // 符号步长计算 + // Symbolic stride computation std::vector strides = {height * width * channels, width * channels, channels, ExprObj::constant(1)}; @@ -47,12 +47,13 @@ TEST_F(StrideExprTest, SymbolicStride) { EXPECT_EQ(stride->toString(), "[((height * 224) * 3), (224 * 3), 3, 1]"); EXPECT_FALSE(stride->isConcrete()); - // 简化后的表达式 - // TODO:目前不支持乘法交换律,因此无法简化 + // Simplified expression + // TODO: Currently does not support multiplication commutative law, cannot + // simplify auto simplified = stride->simplify(); EXPECT_EQ(simplified->toString(), "[((height * 224) * 3), 672, 3, 1]"); - // 求值 + // Evaluate auto env = std::unordered_map{{"height", 256}}; auto result = stride->evaluate(env); EXPECT_TRUE(result.has_value()); diff --git a/test/core/test_tensor_basic.cc b/test/core/test_tensor_basic.cc index dae23d1..7d1320b 100644 --- a/test/core/test_tensor_basic.cc +++ b/test/core/test_tensor_basic.cc @@ -7,7 +7,7 @@ class TensorBasicTest : public testing::Test { void SetUp() override {} }; -// 测试从常量形状向量构造Tensor +// Test constructing Tensor from constant shape vector TEST_F(TensorBasicTest, ConstructFromConcreteShapeVector) { Shape concreteDims = {2, 3, 4}; auto tensor = make_ref(concreteDims, DataType(INFINI_DTYPE_F32)); @@ -23,7 +23,7 @@ TEST_F(TensorBasicTest, ConstructFromConcreteShapeVector) { EXPECT_EQ(constantShape, concreteDims); } -// 测试从符号形状构造Tensor +// Test constructing Tensor from symbolic shape TEST_F(TensorBasicTest, ConstructFromSymbolicShape) { auto batch = ExprObj::variable("batch"); auto height = ExprObj::constant(224); @@ -42,11 +42,11 @@ TEST_F(TensorBasicTest, ConstructFromSymbolicShape) { EXPECT_EQ(tensor->getShape()->toString(), "[batch, 224, 224, 3]"); } -// 测试从混合形状构造Tensor +// Test constructing Tensor from mixed shape TEST_F(TensorBasicTest, ConstructFromMixedShape) { auto shapeExpr = ShapeExpr(new ShapeExprObj( {ExprObj::variable("batch"), - ExprObj::constant(256) + ExprObj::constant(2), // 计算表达式 + ExprObj::constant(256) + ExprObj::constant(2), // Compute expression ExprObj::constant(128), ExprObj::variable("channels")})); auto tensor = make_ref(shapeExpr, DataType(INFINI_DTYPE_I32)); @@ -57,17 +57,17 @@ TEST_F(TensorBasicTest, ConstructFromMixedShape) { "[batch, (256 + 2), 128, channels]"); } -// 测试带步长的Tensor构造 +// Test constructing Tensor with stride TEST_F(TensorBasicTest, ConstructWithStride) { Shape concreteDims = {2, 3, 4}; - Stride strideDims = {12, 4, 1}; // 连续步长 + Stride strideDims = {12, 4, 1}; // Contiguous stride - // 从常量步长构造 + // Construct from constant stride auto tensor1 = make_ref(concreteDims, strideDims, DataType(INFINI_DTYPE_U32)); EXPECT_EQ(tensor1->getRank(), 3); - // 从步长表达式构造 + // Construct from stride expression auto strideExpr = StrideExpr(new StrideExprObj( {ExprObj::constant(12), ExprObj::constant(4), ExprObj::constant(1)})); @@ -76,22 +76,22 @@ TEST_F(TensorBasicTest, ConstructWithStride) { EXPECT_EQ(tensor2->getStride()->toString(), "[12, 4, 1]"); } -// 测试形状获取和设置 +// Test shape getter and setter TEST_F(TensorBasicTest, ShapeGetSet) { auto tensor = make_ref(Shape{2, 3, 4}, DataType(INFINI_DTYPE_F32)); - // 获取形状 + // Get shape auto shape = tensor->getShape(); EXPECT_TRUE(shape->isConcrete()); EXPECT_EQ(shape->getConstantValue(), Shape({2, 3, 4})); - // 设置为新的常量形状 + // Set to new constant shape tensor->setShape(Shape{5, 6, 7}); auto newShape = tensor->getShape(); EXPECT_EQ(newShape->getConstantValue(), Shape({5, 6, 7})); - // 设置为符号形状 + // Set to symbolic shape auto symShape = ShapeExpr( new ShapeExprObj({ExprObj::variable("batch"), ExprObj::constant(224), ExprObj::constant(224)})); @@ -100,22 +100,22 @@ TEST_F(TensorBasicTest, ShapeGetSet) { EXPECT_EQ(tensor->getShape()->toString(), "[batch, 224, 224]"); } -// 测试步长获取和设置 +// Test stride getter and setter TEST_F(TensorBasicTest, StrideGetSet) { auto tensor = make_ref(Shape{2, 3, 4}, DataType(INFINI_DTYPE_F32)); - // 初始应该是连续步长 + // Initial should be contiguous stride auto stride = tensor->getStride(); EXPECT_TRUE(stride->isConcrete()); - // 连续步长应该是 [3*4, 4, 1] = [12, 4, 1] + // Contiguous stride should be [3*4, 4, 1] = [12, 4, 1] - // 设置新的常量步长 + // Set new constant stride tensor->setStride(Stride{6, 2, 1}); auto newStride = tensor->getStride(); EXPECT_EQ(newStride->getConstantValue(), Stride({6, 2, 1})); - // 设置为步长表达式 + // Set to stride expression auto strideExpr = StrideExpr( new StrideExprObj({ExprObj::variable("stride0"), ExprObj::constant(2), ExprObj::constant(1)})); @@ -123,37 +123,38 @@ TEST_F(TensorBasicTest, StrideGetSet) { EXPECT_EQ(tensor->getStride()->toString(), "[stride0, 2, 1]"); } -// 测试元素总数计算 +// Test total element count calculation TEST_F(TensorBasicTest, ElementCount) { - // 常量形状 + // Constant shape auto tensor1 = make_ref(Shape{2, 3, 4}, DataType(INFINI_DTYPE_F32)); EXPECT_EQ(tensor1->getElement(), 2 * 3 * 4); // 24 - // 符号形状 + // Symbolic shape auto symShape = ShapeExpr( new ShapeExprObj({ExprObj::variable("batch"), ExprObj::constant(224), ExprObj::constant(224)})); auto tensor2 = make_ref(symShape, DataType(INFINI_DTYPE_F32)); - EXPECT_THROW(tensor2->getElement(), Exception); // 无法计算元素总数 + EXPECT_THROW(tensor2->getElement(), + Exception); // Cannot calculate element count tensor2->setShape(symShape->evaluate({{"batch", 3}}).value()); EXPECT_EQ(tensor2->getElement(), 3 * 224 * 224); // 3*224*224 } -// 测试存储大小和总字节数 +// Test storage size and total bytes TEST_F(TensorBasicTest, StorageSizeAndBytes) { auto tensor = make_ref(Shape{2, 3, 4}, DataType(INFINI_DTYPE_F32)); EXPECT_EQ(tensor->getElement(), 24); // 2*3*4 - EXPECT_EQ(tensor->getStorageSize(), 24); // 假设连续存储 + EXPECT_EQ(tensor->getStorageSize(), 24); // Assume contiguous storage - // 测试总字节数 + // Test total bytes auto totalBytes = tensor->getTotalBytes(); EXPECT_EQ(totalBytes, 24 * sizeof(float)); - // 测试不同数据类型 + // Test different data types auto tensorInt8 = make_ref(Shape{10, 20}, DataType(INFINI_DTYPE_I8)); EXPECT_EQ(tensorInt8->getTotalBytes(), 10 * 20 * sizeof(int8_t)); diff --git a/test/kernels/test_elementwise_kernel.cc b/test/kernels/test_elementwise_kernel.cc index af2b04c..133777e 100644 --- a/test/kernels/test_elementwise_kernel.cc +++ b/test/kernels/test_elementwise_kernel.cc @@ -1,14 +1,11 @@ #include "core/runtime.h" #include "operators/ElementWise.h" +#include "utils/test_utils.h" #include "gtest/gtest.h" -#include -#include -#include -#include namespace infini { -// 线程测试参数 - 模板类支持任意数据类型 +// Thread test parameters template struct ThreadTestParams { infiniDevice_t device = INFINI_DEVICE_CPU; int deviceId = 0; @@ -23,34 +20,35 @@ template struct ThreadTestParams { std::string deviceName; }; -// 设备线程函数 - 模板函数 +// Device thread function template void deviceThreadFunc(ThreadTestParams ¶ms) { RuntimeObj::init(); Runtime &runtime = RuntimeObj::getInstance(); - // 初始化设备 Context + // Initialize device Context runtime->initThreadContext(params.device, params.deviceId); - // 创建 Graph + // Create Graph Graph g = make_ref(runtime); auto A = g->addTensor(params.shapeA, params.dataType); auto B = g->addTensor(params.shapeB, params.dataType); auto op = g->addOp(params.opType, A, B, nullptr); - // 先设置数据(设置CPU指针),再分配内存(触发H2D拷贝) + // Set data first (set CPU pointer), then allocate memory (triggers H2D + // copy) A->setData(params.inputAData.data()); B->setData(params.inputBData.data()); - runtime->dataMalloc(g); // 会检测到data是CPU,执行H2D拷贝 + runtime->dataMalloc(g); - // 运行计算 + // Run computation runtime->run(g); - // 获取输出并复制到 host + // Get output and copy to host auto output = op->getOutput(0); size_t numElements = output->getElement(); params.outputData.resize(numElements); - // 检查 output 的数据是否存在 + // Check if output data exists auto dataBlob = output->getData(); if (!dataBlob) { throw std::runtime_error("Output data blob is null!"); @@ -61,77 +59,49 @@ template void deviceThreadFunc(ThreadTestParams ¶ms) { "Output device pointer is null on GPU device!"); } - // 复制结果数据 + // Copy result data void *hostPtr = runtime->allocHost(output->getTotalBytes()); runtime->memcpy(hostPtr, devicePtr, output->getTotalBytes(), INFINIRT_MEMCPY_D2H); - // 根据数据类型复制 - if constexpr (std::is_same_v) { - if (params.dataType.getType() == INFINI_DTYPE_F32) { - std::memcpy(params.outputData.data(), hostPtr, - numElements * sizeof(float)); - } else if (params.dataType.getType() == INFINI_DTYPE_F16) { - // FP16 转换为 FP32 - uint16_t *fp16Data = static_cast(hostPtr); - for (size_t i = 0; i < numElements; ++i) { - params.outputData[i] = fp16_to_fp32(fp16Data[i]); - } - } - } else if constexpr (std::is_same_v) { - if (params.dataType.getType() == INFINI_DTYPE_F16) { - std::memcpy(params.outputData.data(), hostPtr, - numElements * sizeof(uint16_t)); - } else if (params.dataType.getType() == INFINI_DTYPE_F32) { - // FP32 转换为 FP16 - float *fp32Data = static_cast(hostPtr); - uint16_t *fp16Data = params.outputData.data(); - for (size_t i = 0; i < numElements; ++i) { - // 这里需要 FP32 转 FP16 的函数,暂时用简单的方式 - fp16Data[i] = static_cast(fp32Data[i]); - } - } - } + // Use generic function for data copy and conversion + copyAndConvertData(params.outputData, hostPtr, numElements, + params.dataType); runtime->deallocHost(hostPtr); params.completed = true; } -// 运行多线程测试 - 模板函数 +// Data generator function type +template +using DataGeneratorFunc = std::function(size_t, T, T)>; + +// Run multi-thread test template -void runMultiThreadTest(OpType opType, const Shape &shapeA, const Shape &shapeB, - const DataType &dataType, bool print = false) { +void runMultiThreadTest( + OpType opType, const Shape &shapeA, const Shape &shapeB, + const DataType &dataType, + DataGeneratorFunc dataGenerator = generateRandomData, + bool print = false) { - // 准备输入数据 + // Prepare input data - use utility function to simplify size_t elementA = 1, elementB = 1; for (auto dim : shapeA) elementA *= dim; for (auto dim : shapeB) elementB *= dim; - std::vector inputAData(elementA); - std::vector inputBData(elementB); - - // 使用简单的递增序列和递减序列,便于计算和验证 - for (size_t i = 0; i < elementA; ++i) { - if constexpr (std::is_same_v) { - inputAData[i] = static_cast(i + 1); // 1, 2, 3, ... - } else if constexpr (std::is_same_v) { - inputAData[i] = static_cast(i + 1); - } - } - for (size_t i = 0; i < elementB; ++i) { - if constexpr (std::is_same_v) { - inputBData[i] = static_cast(elementB - i); // n, n-1, n-2, ... - } else if constexpr (std::is_same_v) { - inputBData[i] = static_cast(elementB - i); - } - } + // Use the passed data generator function (default uses + // random data) + auto inputAData = + dataGenerator(elementA, static_cast(-10), static_cast(10)); + auto inputBData = + dataGenerator(elementB, static_cast(-10), static_cast(10)); - // 创建线程参数 + // Create thread parameters ThreadTestParams cpuParams, gpuParams; - // CPU 线程参数 + // CPU thread parameters cpuParams.device = INFINI_DEVICE_CPU; cpuParams.deviceId = 0; cpuParams.opType = opType; @@ -142,7 +112,7 @@ void runMultiThreadTest(OpType opType, const Shape &shapeA, const Shape &shapeB, cpuParams.inputBData = inputBData; cpuParams.deviceName = "CPU"; - // GPU 线程参数 + // GPU thread parameters gpuParams.device = INFINI_DEVICE_NVIDIA; gpuParams.deviceId = 0; gpuParams.opType = opType; @@ -167,22 +137,22 @@ void runMultiThreadTest(OpType opType, const Shape &shapeA, const Shape &shapeB, std::cout << "========================================" << std::endl; } - // 启动两个线程并行执行 + // Launch two threads for parallel execution std::thread cpuThread(deviceThreadFunc, std::ref(cpuParams)); std::thread gpuThread(deviceThreadFunc, std::ref(gpuParams)); - // 等待两个线程完成 + // Wait for both threads to complete cpuThread.join(); gpuThread.join(); - // 验证结果 + // Verify results ASSERT_TRUE(cpuParams.completed) << "CPU thread failed"; ASSERT_TRUE(gpuParams.completed) << "NVIDIA thread failed"; ASSERT_EQ(cpuParams.outputData.size(), gpuParams.outputData.size()) << "Output size mismatch"; - // 对比结果 + // Compare results size_t numErrors = 0; float maxError = 0.0f; const float epsilon = 1e-3f; @@ -190,12 +160,12 @@ void runMultiThreadTest(OpType opType, const Shape &shapeA, const Shape &shapeB, for (size_t i = 0; i < cpuParams.outputData.size(); ++i) { float cpuVal, gpuVal; - // 转换为 float 进行比较 + // Convert to float for comparison if constexpr (std::is_same_v) { cpuVal = cpuParams.outputData[i]; gpuVal = gpuParams.outputData[i]; } else if constexpr (std::is_same_v) { - // FP16 转 FP32 比较 + // FP16 to FP32 comparison cpuVal = fp16_to_fp32(cpuParams.outputData[i]); gpuVal = fp16_to_fp32(gpuParams.outputData[i]); } @@ -205,7 +175,7 @@ void runMultiThreadTest(OpType opType, const Shape &shapeA, const Shape &shapeB, if (error > epsilon) { numErrors++; - if (numErrors <= 5) { // 只打印前5个错误 + if (numErrors <= 5) { std::cout << "Mismatch at index " << i << ": CPU=" << cpuVal << ", NVIDIA=" << gpuVal << ", error=" << error << std::endl; @@ -228,82 +198,85 @@ void runMultiThreadTest(OpType opType, const Shape &shapeA, const Shape &shapeB, std::cout << "========================================" << std::endl; } - EXPECT_EQ(numErrors, 0) - << "Results mismatch between CPU and NVIDIA (max error: " << maxError - << ")"; + EXPECT_EQ(numErrors, 0) << "Results mismatch between " + "CPU and NVIDIA (max error: " + << maxError << ")"; } -// 基本Add操作测试 - F32 +// Basic Add operation test - F32 TEST(ElementWise, Add_MultiThread_F32) { Shape shapeA = {3, 1}; Shape shapeB = {2, 3, 4}; #ifdef USE_CUDA runMultiThreadTest(OpType::Add, shapeA, shapeB, - DataType(INFINI_DTYPE_F32), true); + DataType(INFINI_DTYPE_F32)); #else std::cout << "CUDA not enabled, skipping multi-thread test" << std::endl; #endif } -// 基本Add操作测试 - F16 +// Basic Add operation test - F16 TEST(ElementWise, Add_MultiThread_F16) { Shape shapeA = {3, 1}; Shape shapeB = {2, 3, 4}; #ifdef USE_CUDA runMultiThreadTest(OpType::Add, shapeA, shapeB, - DataType(INFINI_DTYPE_F16), true); + DataType(INFINI_DTYPE_F16), + generateSequentialData, true); #else std::cout << "CUDA not enabled, skipping multi-thread test" << std::endl; #endif } -// 基本Mul操作测试 - F32 +// Basic Mul operation test - F32 TEST(ElementWise, Mul_MultiThread_F32) { Shape shapeA = {3, 4}; Shape shapeB = {3, 4}; #ifdef USE_CUDA runMultiThreadTest(OpType::Mul, shapeA, shapeB, - DataType(INFINI_DTYPE_F32), false); + DataType(INFINI_DTYPE_F32)); #endif } -// 基本Mul操作测试 - F16 +// Basic Mul operation test - F16 TEST(ElementWise, Mul_MultiThread_F16) { Shape shapeA = {3, 4}; Shape shapeB = {3, 4}; #ifdef USE_CUDA runMultiThreadTest(OpType::Mul, shapeA, shapeB, - DataType(INFINI_DTYPE_F16), false); + DataType(INFINI_DTYPE_F16), + generateSequentialData, false); #endif } -// 基本Sub操作测试 - F32 +// Basic Sub operation test - F32 TEST(ElementWise, Sub_MultiThread_F32) { Shape shapeA = {1, 5, 6}; Shape shapeB = {1, 5, 6}; #ifdef USE_CUDA runMultiThreadTest(OpType::Sub, shapeA, shapeB, - DataType(INFINI_DTYPE_F32), false); + DataType(INFINI_DTYPE_F32)); #endif } -// 基本Sub操作测试 - F16 +// Basic Sub operation test - F16 TEST(ElementWise, Sub_MultiThread_F16) { Shape shapeA = {1, 5, 6}; Shape shapeB = {1, 5, 6}; #ifdef USE_CUDA runMultiThreadTest(OpType::Sub, shapeA, shapeB, - DataType(INFINI_DTYPE_F16), false); + DataType(INFINI_DTYPE_F16), + generateSequentialData, false); #endif } -// 单设备测试(用于调试)- CPU +// Single device test - CPU TEST(ElementWise, Add_SingleDevice_CPU) { RuntimeObj::init(); Runtime &runtime = RuntimeObj::getInstance(); @@ -319,7 +292,7 @@ TEST(ElementWise, Add_SingleDevice_CPU) { runtime->dataMalloc(g); - // 设置输入数据 + // Set input data std::vector inputAData(A->getElement()); std::vector inputBData(B->getElement()); @@ -333,17 +306,17 @@ TEST(ElementWise, Add_SingleDevice_CPU) { A->setData(inputAData.data()); B->setData(inputBData.data()); - // 执行计算 + // Execute computation runtime->run(g); - // 获取输出并打印 + // Get output and print auto output = op->getOutput(0); std::cout << "CPU Output Data: " << std::endl; output->printData(runtime); } #ifdef USE_CUDA -// 单设备测试(用于调试)- NVIDIA F32 +// Single device test - NVIDIA F32 TEST(ElementWise, Add_SingleDevice_NVIDIA_F32) { RuntimeObj::init(); Runtime &runtime = RuntimeObj::getInstance(); @@ -357,7 +330,7 @@ TEST(ElementWise, Add_SingleDevice_NVIDIA_F32) { auto B = g->addTensor(shapeB, DataType(INFINI_DTYPE_F32)); auto op = g->addOp(OpType::Add, A, B, nullptr); - // 设置输入数据 + // Set input data std::vector inputAData(A->getElement()); std::vector inputBData(B->getElement()); @@ -372,16 +345,16 @@ TEST(ElementWise, Add_SingleDevice_NVIDIA_F32) { B->setData(inputBData.data()); runtime->dataMalloc(g); - // 执行计算 + // Execute computation runtime->run(g); - // 获取输出并打印 + // Get output and print auto output = op->getOutput(0); std::cout << "NVIDIA F32 Output Data: " << std::endl; output->printData(runtime); } -// 单设备测试(用于调试)- NVIDIA F16 +// Single device test - NVIDIA F16 TEST(ElementWise, Add_SingleDevice_NVIDIA_F16) { RuntimeObj::init(); Runtime &runtime = RuntimeObj::getInstance(); @@ -395,7 +368,7 @@ TEST(ElementWise, Add_SingleDevice_NVIDIA_F16) { auto B = g->addTensor(shapeB, DataType(INFINI_DTYPE_F16)); auto op = g->addOp(OpType::Add, A, B, nullptr); - // 设置输入数据 + // Set input data std::vector inputAData(A->getElement()); std::vector inputBData(B->getElement()); @@ -410,10 +383,10 @@ TEST(ElementWise, Add_SingleDevice_NVIDIA_F16) { B->setData(inputBData.data()); runtime->dataMalloc(g); - // 执行计算 + // Execute computation runtime->run(g); - // 获取输出并打印 + // Get output and print auto output = op->getOutput(0); std::cout << "NVIDIA F16 Output Data: " << std::endl; output->printData(runtime); diff --git a/test/kernels/test_gemm_kenel.cc b/test/kernels/test_gemm_kenel.cc deleted file mode 100644 index c40dd68..0000000 --- a/test/kernels/test_gemm_kenel.cc +++ /dev/null @@ -1,42 +0,0 @@ -#include "core/runtime.h" -#include "operators/Gemm.h" -#include "gtest/gtest.h" - -namespace infini { -void runGemmTest(const std::string &deviceName, infiniDevice_t DeviceT, - const Shape &shapeA, const Shape &shapeB, float alpha, - float beta, bool transA, bool transB, const DataType &dataType, - bool print = false) { - Runtime &runtime = RuntimeObj::getInstance(); - RuntimeObj::init(); - runtime->initThreadContext(DeviceT, 0); - Graph g = make_ref(runtime); - auto A = g->addTensor(shapeA, dataType); - auto B = g->addTensor(shapeB, dataType); - auto op = - g->addOp(A, B, nullptr, nullptr, alpha, beta, transA, transB); - runtime->dataMalloc(g); - auto res = g->toString(); - // Only when the data is contiguous, will this assignment be successful. - std::vector inputAData(A->getElement()); - std::iota(inputAData.begin(), inputAData.end(), 1); - std::vector inputBData(B->getElement()); - std::iota(inputBData.begin(), inputBData.end(), 1); - A->setData(inputAData.data()); - B->setData(inputBData.data()); - std::cout << res << std::endl; - runtime->run(g); - auto output = op->getOutput(0); - output->printData(runtime); -} - -TEST(Gemm, Kernel) { - runGemmTest("CPU", INFINI_DEVICE_CPU, Shape{3, 5}, Shape{5, 2}, 1.0, 0.0, - false, false, DataType(INFINI_DTYPE_F32)); - -#ifdef USE_CUDA - runGemmTest("NVIDIA", INFINI_DEVICE_NVIDIA, Shape{3, 5}, Shape{5, 2}, 1.0, - 0.0, false, false, DataType(INFINI_DTYPE_F32)); -#endif -} -} // namespace infini diff --git a/test/kernels/test_gemm_kernel.cc b/test/kernels/test_gemm_kernel.cc new file mode 100644 index 0000000..ab2c983 --- /dev/null +++ b/test/kernels/test_gemm_kernel.cc @@ -0,0 +1,373 @@ +#include "core/runtime.h" +#include "operators/Gemm.h" +#include "utils/test_utils.h" +#include "gtest/gtest.h" + +namespace infini { + +// Thread test parameters +template struct GemmThreadTestParams { + infiniDevice_t device = INFINI_DEVICE_CPU; + int deviceId = 0; + Shape shapeA; + Shape shapeB; + DataType dataType = DataType(INFINI_DTYPE_F32); + float alpha = 1.0f; + float beta = 0.0f; + bool transA = false; + bool transB = false; + std::vector inputAData; + std::vector inputBData; + std::vector outputData; + bool completed = false; + std::string deviceName; +}; + +// Device thread function +template +void gemmDeviceThreadFunc(GemmThreadTestParams ¶ms) { + RuntimeObj::init(); + Runtime &runtime = RuntimeObj::getInstance(); + + // Initialize device Context + runtime->initThreadContext(params.device, params.deviceId); + + // Create Graph + Graph g = make_ref(runtime); + auto A = g->addTensor(params.shapeA, params.dataType); + auto B = g->addTensor(params.shapeB, params.dataType); + auto op = g->addOp(A, B, nullptr, nullptr, params.alpha, + params.beta, params.transA, params.transB); + + A->setData(params.inputAData.data()); + B->setData(params.inputBData.data()); + runtime->dataMalloc(g); + + // Run computation + runtime->run(g); + + // Get output and copy to host + auto output = op->getOutput(0); + size_t numElements = output->getElement(); + params.outputData.resize(numElements); + + // Check if output data exists + auto dataBlob = output->getData(); + if (!dataBlob) { + throw std::runtime_error("Output data blob is null!"); + } + void *devicePtr = dataBlob->getRawDataPtr(); + if (!devicePtr && !runtime->isCpu()) { + throw std::runtime_error( + "Output device pointer is null on GPU device!"); + } + + // Copy result data + void *hostPtr = runtime->allocHost(output->getTotalBytes()); + runtime->memcpy(hostPtr, devicePtr, output->getTotalBytes(), + INFINIRT_MEMCPY_D2H); + + // Use generic function for data copy and conversion + copyAndConvertData(params.outputData, hostPtr, numElements, + params.dataType); + + runtime->deallocHost(hostPtr); + params.completed = true; +} + +// Data generator function type +template +using GemmDataGeneratorFunc = std::function(size_t, T, T)>; + +// Run multi-thread test +template +void runGemmMultiThreadTest( + const Shape &shapeA, const Shape &shapeB, float alpha, float beta, + bool transA, bool transB, const DataType &dataType, + GemmDataGeneratorFunc dataGenerator = generateSequentialData, + bool print = false) { + + // Prepare input data + size_t elementA = 1, elementB = 1; + for (auto dim : shapeA) + elementA *= dim; + for (auto dim : shapeB) + elementB *= dim; + + // Use the passed data generator function (default uses sequential data) + auto inputAData = + dataGenerator(elementA, static_cast(1), static_cast(1)); + auto inputBData = + dataGenerator(elementB, static_cast(1), static_cast(1)); + + // Create thread parameters + GemmThreadTestParams cpuParams, gpuParams; + + // CPU thread parameters + cpuParams.device = INFINI_DEVICE_CPU; + cpuParams.deviceId = 0; + cpuParams.shapeA = shapeA; + cpuParams.shapeB = shapeB; + cpuParams.dataType = dataType; + cpuParams.alpha = alpha; + cpuParams.beta = beta; + cpuParams.transA = transA; + cpuParams.transB = transB; + cpuParams.inputAData = inputAData; + cpuParams.inputBData = inputBData; + cpuParams.deviceName = "CPU"; + + // GPU thread parameters + gpuParams.device = INFINI_DEVICE_NVIDIA; + gpuParams.deviceId = 0; + gpuParams.shapeA = shapeA; + gpuParams.shapeB = shapeB; + gpuParams.dataType = dataType; + gpuParams.alpha = alpha; + gpuParams.beta = beta; + gpuParams.transA = transA; + gpuParams.transB = transB; + gpuParams.inputAData = inputAData; + gpuParams.inputBData = inputBData; + gpuParams.deviceName = "NVIDIA"; + + if (print) { + std::cout << "========================================" << std::endl; + std::cout << "Running Multi-Thread Gemm Test" << std::endl; + std::cout << "DataType: " << dataType.toString() << std::endl; + std::cout << "Shape A: " << vecToString(shapeA) << std::endl; + std::cout << "Shape B: " << vecToString(shapeB) << std::endl; + std::cout << "Alpha: " << alpha << ", Beta: " << beta << std::endl; + std::cout << "TransA: " << (transA ? "Yes" : "No") + << ", TransB: " << (transB ? "Yes" : "No") << std::endl; + std::cout << "Thread 1: CPU (" << dataType.toString() << ")" + << std::endl; + std::cout << "Thread 2: NVIDIA (" << dataType.toString() << ")" + << std::endl; + std::cout << "========================================" << std::endl; + } + + // Launch two threads for parallel execution + std::thread cpuThread(gemmDeviceThreadFunc, std::ref(cpuParams)); + std::thread gpuThread(gemmDeviceThreadFunc, std::ref(gpuParams)); + + // Wait for both threads to complete + cpuThread.join(); + gpuThread.join(); + + // Verify results + ASSERT_TRUE(cpuParams.completed) << "CPU thread failed"; + ASSERT_TRUE(gpuParams.completed) << "NVIDIA thread failed"; + + ASSERT_EQ(cpuParams.outputData.size(), gpuParams.outputData.size()) + << "Output size mismatch"; + + // Compare results + size_t numErrors = 0; + float maxError = 0.0f; + const float epsilon = 1e-2f; + + for (size_t i = 0; i < cpuParams.outputData.size(); ++i) { + float cpuVal, gpuVal; + + // Convert to float for comparison + if constexpr (std::is_same_v) { + cpuVal = cpuParams.outputData[i]; + gpuVal = gpuParams.outputData[i]; + } else if constexpr (std::is_same_v) { + // FP16 to FP32 comparison + cpuVal = fp16_to_fp32(cpuParams.outputData[i]); + gpuVal = fp16_to_fp32(gpuParams.outputData[i]); + } + + float error = std::abs(cpuVal - gpuVal); + maxError = std::max(maxError, error); + + if (error > epsilon) { + numErrors++; + if (numErrors <= 5) { // Only print first 5 errors + std::cout << "Mismatch at index " << i << ": CPU=" << cpuVal + << ", NVIDIA=" << gpuVal << ", error=" << error + << std::endl; + } + } + } + + if (print) { + std::cout << "Result Comparison:" << std::endl; + std::cout << " Total elements: " << cpuParams.outputData.size() + << std::endl; + std::cout << " Errors: " << numErrors << std::endl; + std::cout << " Max error: " << maxError << std::endl; + + if (numErrors == 0) { + std::cout << " ✓ Test PASSED" << std::endl; + } else { + std::cout << " ✗ Test FAILED" << std::endl; + } + std::cout << "========================================" << std::endl; + } + + EXPECT_EQ(numErrors, 0) + << "Results mismatch between CPU and NVIDIA (max error: " << maxError + << ")"; +} + +// Basic Gemm operation test - F32 +TEST(Gemm, Basic_MultiThread_F32) { + Shape shapeA = {3, 5}; + Shape shapeB = {5, 2}; + +#ifdef USE_CUDA + runGemmMultiThreadTest(shapeA, shapeB, 1.0f, 0.0f, false, false, + DataType(INFINI_DTYPE_F32)); +#else + std::cout << "CUDA not enabled, skipping multi-thread test" << std::endl; +#endif +} + +// Basic Gemm operation test - F16 +TEST(Gemm, Basic_MultiThread_F16) { + Shape shapeA = {3, 5}; + Shape shapeB = {5, 2}; + +#ifdef USE_CUDA + runGemmMultiThreadTest(shapeA, shapeB, 1.0f, 0.0f, false, false, + DataType(INFINI_DTYPE_F16), + generateSequentialData, true); +#else + std::cout << "CUDA not enabled, skipping multi-thread test" << std::endl; +#endif +} + +// Test with alpha and beta - F32 +TEST(Gemm, AlphaBeta_MultiThread_F32) { + Shape shapeA = {4, 6}; + Shape shapeB = {6, 3}; + +#ifdef USE_CUDA + runGemmMultiThreadTest(shapeA, shapeB, 2.0f, 0.0f, false, false, + DataType(INFINI_DTYPE_F32)); +#endif +} + +// Large matrix test - F32 +TEST(Gemm, LargeMatrix_MultiThread_F32) { + Shape shapeA = {128, 256}; + Shape shapeB = {256, 128}; + +#ifdef USE_CUDA + runGemmMultiThreadTest(shapeA, shapeB, 1.0f, 0.0f, false, false, + DataType(INFINI_DTYPE_F32), + generateRandomData); +#endif +} + +// Single device test - CPU +TEST(Gemm, SingleDevice_CPU) { + RuntimeObj::init(); + Runtime &runtime = RuntimeObj::getInstance(); + runtime->initThreadContext(INFINI_DEVICE_CPU, 0); + + Shape shapeA = {3, 5}; + Shape shapeB = {5, 2}; + + Graph g = make_ref(runtime); + auto A = g->addTensor(shapeA, DataType(INFINI_DTYPE_F32)); + auto B = g->addTensor(shapeB, DataType(INFINI_DTYPE_F32)); + auto op = + g->addOp(A, B, nullptr, nullptr, 1.0f, 0.0f, false, false); + + // Set input data + std::vector inputAData(A->getElement()); + std::vector inputBData(B->getElement()); + + std::iota(inputAData.begin(), inputAData.end(), 1); + std::iota(inputBData.begin(), inputBData.end(), 1); + + A->setData(inputAData.data()); + B->setData(inputBData.data()); + runtime->dataMalloc(g); + + // Execute computation + runtime->run(g); + + // Get output and print + auto output = op->getOutput(0); + std::cout << "CPU Output Data: " << std::endl; + output->printData(runtime); +} + +#ifdef USE_CUDA +// Single device test - NVIDIA F32 +TEST(Gemm, SingleDevice_NVIDIA_F32) { + RuntimeObj::init(); + Runtime &runtime = RuntimeObj::getInstance(); + runtime->initThreadContext(INFINI_DEVICE_NVIDIA, 0); + + Shape shapeA = {3, 5}; + Shape shapeB = {5, 2}; + + Graph g = make_ref(runtime); + auto A = g->addTensor(shapeA, DataType(INFINI_DTYPE_F32)); + auto B = g->addTensor(shapeB, DataType(INFINI_DTYPE_F32)); + auto op = + g->addOp(A, B, nullptr, nullptr, 1.0f, 0.0f, false, false); + + // Set input data + std::vector inputAData(A->getElement()); + std::vector inputBData(B->getElement()); + + std::iota(inputAData.begin(), inputAData.end(), 1); + std::iota(inputBData.begin(), inputBData.end(), 1); + + A->setData(inputAData.data()); + B->setData(inputBData.data()); + runtime->dataMalloc(g); + + // Execute computation + runtime->run(g); + + // Get output and print + auto output = op->getOutput(0); + std::cout << "NVIDIA F32 Output Data: " << std::endl; + output->printData(runtime); +} + +// Single device test - NVIDIA F16 +TEST(Gemm, SingleDevice_NVIDIA_F16) { + RuntimeObj::init(); + Runtime &runtime = RuntimeObj::getInstance(); + runtime->initThreadContext(INFINI_DEVICE_NVIDIA, 0); + + Shape shapeA = {3, 5}; + Shape shapeB = {5, 2}; + + Graph g = make_ref(runtime); + auto A = g->addTensor(shapeA, DataType(INFINI_DTYPE_F16)); + auto B = g->addTensor(shapeB, DataType(INFINI_DTYPE_F16)); + auto op = + g->addOp(A, B, nullptr, nullptr, 1.0f, 0.0f, false, false); + + // Set input data + std::vector inputAData(A->getElement()); + std::vector inputBData(B->getElement()); + + std::iota(inputAData.begin(), inputAData.end(), 1); + std::iota(inputBData.begin(), inputBData.end(), 1); + + A->setData(inputAData.data()); + B->setData(inputBData.data()); + runtime->dataMalloc(g); + + // Execute computation + runtime->run(g); + + // Get output and print + auto output = op->getOutput(0); + std::cout << "NVIDIA F16 Output Data: " << std::endl; + output->printData(runtime); +} +#endif + +} // namespace infini diff --git a/test/operators/test_elementwise_op.cc b/test/operators/test_elementwise_op.cc index 1e6fe0c..f06b58b 100644 --- a/test/operators/test_elementwise_op.cc +++ b/test/operators/test_elementwise_op.cc @@ -15,7 +15,7 @@ class ElementWiseBasicTest : public testing::Test { } }; -// 测试ElementWise的基本构造 +// Test basic construction of ElementWise TEST_F(ElementWiseBasicTest, BasicConstruction) { auto A = graph->addTensor({2, 3, 4}, DataType(INFINI_DTYPE_F32)); auto B = graph->addTensor({2, 3, 4}, DataType(INFINI_DTYPE_F32)); @@ -26,7 +26,7 @@ TEST_F(ElementWiseBasicTest, BasicConstruction) { EXPECT_EQ(elementwise->getElemenwiseOpType(), OpType::Add); } -// 测试ElementWise形状推导 - 相同形状 +// Test ElementWise shape inference - same shape TEST_F(ElementWiseBasicTest, ShapeInferenceSameShape) { auto A = graph->addTensor({2, 3, 4}, DataType(INFINI_DTYPE_F32)); auto B = graph->addTensor({2, 3, 4}, DataType(INFINI_DTYPE_F32)); @@ -47,9 +47,9 @@ TEST_F(ElementWiseBasicTest, ShapeInferenceSameShape) { EXPECT_EQ(shapeValues[2], 4); } -// 测试ElementWise形状推导 - 广播(标量广播到张量) +// Test ElementWise shape inference - broadcast (scalar to tensor) TEST_F(ElementWiseBasicTest, ShapeInferenceScalarBroadcast) { - // 标量广播到 [2, 3, 4] + // Scalar broadcasts to [2, 3, 4] auto scalar = graph->addTensor({1}, DataType(INFINI_DTYPE_F32)); auto tensor = graph->addTensor({2, 3, 4}, DataType(INFINI_DTYPE_F32)); @@ -62,15 +62,16 @@ TEST_F(ElementWiseBasicTest, ShapeInferenceScalarBroadcast) { auto outputShape = (*inferredShapes)[0]; auto shapeValues = outputShape->getConstantValue(); - EXPECT_EQ(shapeValues.size(), 3); // 标量应该广播到 tensor 的维度 + EXPECT_EQ(shapeValues.size(), + 3); // Scalar should broadcast to tensor's dimension EXPECT_EQ(shapeValues[0], 2); EXPECT_EQ(shapeValues[1], 3); EXPECT_EQ(shapeValues[2], 4); } -// 测试ElementWise形状推导 - 广播(两个操作数都需要广播) +// Test ElementWise shape inference - broadcast (both operands need broadcast) TEST_F(ElementWiseBasicTest, ShapeInferenceBothBroadcast) { - // [1, 3, 1] 和 [2, 1, 4] 广播到 [2, 3, 4] + // [1, 3, 1] and [2, 1, 4] broadcast to [2, 3, 4] auto A = graph->addTensor({1, 3, 1}, DataType(INFINI_DTYPE_F32)); auto B = graph->addTensor({2, 1, 4}, DataType(INFINI_DTYPE_F32)); @@ -88,7 +89,7 @@ TEST_F(ElementWiseBasicTest, ShapeInferenceBothBroadcast) { EXPECT_EQ(shapeValues[2], 4); } -// 测试ElementWise数据类型推断 +// Test ElementWise data type inference TEST_F(ElementWiseBasicTest, DataTypeInference) { auto A = graph->addTensor({2, 3}, DataType(INFINI_DTYPE_F32)); auto B = graph->addTensor({2, 3}, DataType(INFINI_DTYPE_F32)); @@ -100,7 +101,7 @@ TEST_F(ElementWiseBasicTest, DataTypeInference) { EXPECT_EQ(inferredTypes[0], DataType(INFINI_DTYPE_F32)); } -// 测试符号形状推导 +// Test symbolic shape inference TEST_F(ElementWiseBasicTest, SymbolicShapeInference) { auto batch = ExprObj::variable("batch"); auto height = ExprObj::variable("h"); @@ -121,7 +122,7 @@ TEST_F(ElementWiseBasicTest, SymbolicShapeInference) { EXPECT_FALSE(outputShape->isConcrete()); EXPECT_EQ(outputShape->size(), 3); - // 检查符号表达式 + // Check symbolic expression EXPECT_EQ(outputShape->toString(), "[batch, h, 256]"); } diff --git a/test/operators/test_gemm_op.cc b/test/operators/test_gemm_op.cc index f738d57..99097d1 100644 --- a/test/operators/test_gemm_op.cc +++ b/test/operators/test_gemm_op.cc @@ -14,7 +14,7 @@ class GemmBasicTest : public testing::Test { } }; -// 测试Gemm的基本构造 +// Test basic construction of Gemm TEST_F(GemmBasicTest, BasicConstruction) { auto A = graph->addTensor({2, 3}, DataType(INFINI_DTYPE_F32)); auto B = graph->addTensor({3, 4}, DataType(INFINI_DTYPE_F32)); @@ -31,7 +31,7 @@ TEST_F(GemmBasicTest, BasicConstruction) { EXPECT_FALSE(gemm->getTransB()); } -// 测试Gemm形状推导(不转置) +// Test Gemm shape inference (no transpose) TEST_F(GemmBasicTest, ShapeInferenceNoTranspose) { auto A = graph->addTensor({2, 3}, DataType(INFINI_DTYPE_F32)); auto B = graph->addTensor({3, 4}, DataType(INFINI_DTYPE_F32)); @@ -52,7 +52,7 @@ TEST_F(GemmBasicTest, ShapeInferenceNoTranspose) { EXPECT_EQ(shapeValues[2], 4); // N } -// 测试Gemm形状推导(双转置) +// Test Gemm shape inference (both transpose) TEST_F(GemmBasicTest, ShapeInferenceBothTranspose) { auto A = graph->addTensor({3, 2}, DataType(INFINI_DTYPE_F32)); // A^T will be 2x3 @@ -72,7 +72,7 @@ TEST_F(GemmBasicTest, ShapeInferenceBothTranspose) { EXPECT_EQ(shapeValues[2], 4); // N from B^T } -// 测试batch维度的广播 +// Test batch dimension broadcasting TEST_F(GemmBasicTest, ShapeInferenceBatchBroadcast) { // A: [1, M, K], B: [batch, K, N] -> [batch, M, N] auto A = graph->addTensor({1, 2, 3}, DataType(INFINI_DTYPE_F32)); @@ -88,17 +88,17 @@ TEST_F(GemmBasicTest, ShapeInferenceBatchBroadcast) { EXPECT_EQ(shapeValues[0], 5); // broadcast batch } -// 测试K维度匹配检查 +// Test K dimension matching check TEST_F(GemmBasicTest, KDimensionMismatch) { auto A = graph->addTensor({2, 3}, DataType(INFINI_DTYPE_F32)); - auto B = - graph->addTensor({5, 4}, - DataType(INFINI_DTYPE_F32)); // K维度不匹配:3 != 5 + auto B = graph->addTensor( + {5, 4}, + DataType(INFINI_DTYPE_F32)); // K dimension mismatch: 3 != 5 EXPECT_THROW(graph->addOp(A, B, nullptr, nullptr), Exception); } -// 测试数据类型推断 +// Test data type inference TEST_F(GemmBasicTest, DataTypeInference) { auto A = graph->addTensor({2, 3}, DataType(INFINI_DTYPE_F32)); auto B = graph->addTensor({3, 4}, DataType(INFINI_DTYPE_F32));