From 28b488ffb4741c7c2d4cffef4a2f0732fee595a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tarek=20Ziad=C3=A9?= Date: Tue, 17 Dec 2024 22:40:07 +0100 Subject: [PATCH] YES --- .../quantization/firefox_matmul_integer.cc | 107 ++++++++++++------ .../cpu/quantization/firefox_matmul_integer.h | 6 +- 2 files changed, 74 insertions(+), 39 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc index 6df41c895e723..3f6cc2369f353 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc @@ -44,29 +44,17 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( FirefoxMatMulInteger8); -#include -#include -#include - - Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const { - printf("FirefoxMatMulInteger8::Compute\n"); - const auto* a = ctx->Input(IN_A); const auto* b = packed_b_ ? nullptr : ctx->Input(IN_B); uint8_t a_offset = 0; const auto* a_zero_point = ctx->Input(IN_A_ZERO_POINT); - printf("FirefoxMatMulInteger8::Compute a\n"); - if (a_zero_point != nullptr) { ORT_ENFORCE(IsScalarOr1ElementVector(a_zero_point), "MatmulInteger : input1 zero point must be a scalar or 1D tensor of size 1"); a_offset = *(static_cast(a_zero_point->DataRaw())); } - printf("FirefoxMatMulInteger8::Compute b\n"); - - uint8_t b_default_offset = 0; const auto* b_zero_point = ctx->Input(IN_B_ZERO_POINT); @@ -74,9 +62,6 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const { const uint8_t* b_offset_ptr = &b_default_offset; bool is_b_zp_per_column = false; - printf("FirefoxMatMulInteger8::Compute c\n"); - - if (b_zero_point != nullptr) { ORT_ENFORCE(IsBQuantParamSupported(b_zero_point->Shape(), b ? b->Shape() : b_shape_), "MatmulInteger : B zero point is not valid"); @@ -85,8 +70,6 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const { } MatMulComputeHelper helper; - printf("FirefoxMatMulInteger8::Compute d\n"); - const uint8_t* b_data; if (nullptr != b) { @@ -98,7 +81,6 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const { b_data = static_cast(packed_b_.get()); b_is_signed = b_is_signed_; } - printf("FirefoxMatMulInteger8::Compute 2\n"); size_t M = static_cast(helper.M()); @@ -113,9 +95,6 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const { const uint8_t* a_data = static_cast(a->DataRaw()); auto* y_data = y->MutableData(); - printf("FirefoxMatMulInteger8::Compute 3\n"); - - MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape; gemm_shape.M = M; gemm_shape.N = N; @@ -124,7 +103,6 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const { gemm_shape.BIsSigned = b_is_signed; const size_t batch_size = helper.OutputOffsets().size(); - std::vector gemm_data_vec(batch_size); for (size_t batch = 0; batch < batch_size; batch++) { @@ -141,34 +119,91 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const { gemm_params.C = y_data + helper.OutputOffsets()[batch]; } - printf("FirefoxMatMulInteger8::Compute 4\n"); + #if 0 + std::cout << "Matrix A (sample):\n"; + for (size_t i = 0; i < 5; ++i) { + for (size_t j = 0; j < 5; ++j) { + std::cout << static_cast(a_data[i * helper.K() + j]) << " "; + } + std::cout << "\n"; + } + std::cout << "\n"; + std::cout << "Matrix B (sample):\n"; + for (size_t i = 0; i < 5; ++i) { + for (size_t j = 0; j < 5; ++j) { + std::cout << static_cast(b_data[i * helper.N() + j]) << " "; + } + std::cout << "\n"; + } + + std::cout << "b_zero_point content: \n"; + if (b_zero_point != nullptr) { + size_t b_zero_point_size = static_cast(b_zero_point->Shape()[0]); + const uint8_t* b_zp_data = static_cast(b_zero_point->DataRaw()); + for (size_t i = 0; i < b_zero_point_size; ++i) { + std::cout << static_cast(b_zp_data[i]) << " "; + } + std::cout << "\n"; + } else { + std::cout << "b_zero_point is null\n"; + } - auto start_matmul = Clock::now(); + #endif + //auto start_matmul = Clock::now(); int8Multiply( - reinterpret_cast(a_data), + reinterpret_cast(a->DataRaw()), a_offset, - reinterpret_cast(b_data), - reinterpret_cast(b_offset_ptr), + reinterpret_cast(b->DataRaw()), + //reinterpret_cast(b_zero_point->DataRaw()), M, - N, K, - reinterpret_cast(y_data) + N, + reinterpret_cast(y_data) ); + //auto end_matmul = Clock::now(); + // auto matmul_time = std::chrono::duration_cast(end_matmul - start_matmul).count(); + + // rowsA = M + // width = K + // colsB = N + #if 0 + for (size_t rowIndex = 0; rowIndex < rowsA; ++rowIndex) { + const uint8_t* aRow = inputMatrixAPtr + rowIndex * width; // Start of row in A + for (size_t colIndex = 0; colIndex < colsB; ++colIndex) { + int32_t tempResult = 0; + + for (size_t k = 0; k < width; ++k) { + // Row-major access + uint8_t aValue = aRow[k]; + + // Column-major access for B + int8_t bValue = inputMatrixBPtr[k * colsB + colIndex]; + + // Adjust for zero-point offsets + int32_t adjustedA = static_cast(aValue) - static_cast(a_offset); + int32_t adjustedB = static_cast(bValue); // - static_cast(b_offset_ptr[colIndex]); + + // Accumulate product + tempResult += adjustedA * adjustedB; + } + + // Write result to the output array + outputPtr[rowIndex * colsB + colIndex] = tempResult; + } + } +#endif - auto end_matmul = Clock::now(); - auto matmul_time = std::chrono::duration_cast(end_matmul - start_matmul).count(); - - // Mlas (will fallback if we don't meet requirements) /* + // Mlas (will fallback if we don't meet requirements) auto start_mblas = Clock::now(); MlasGemmBatch(gemm_shape, gemm_data_vec.data(), batch_size, ctx->GetOperatorThreadPool()); auto end_mblas = Clock::now(); auto mblas_time = std::chrono::duration_cast(end_mblas - start_mblas).count(); */ - // Output timing results - std::cout << "Timing (microseconds):\n"; - std::cout << "MatMulFull: " << matmul_time << "\n"; + //std::cout << "Timing (microseconds):\n"; + //std::cout << "MatMulFull: " << matmul_time << "\n"; + //std::cout << "MlasGemmBatch: " << mblas_time << "\n"; return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.h b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.h index a845985db8a13..bd24ea6666106 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.h +++ b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.h @@ -43,13 +43,13 @@ using Index = uint32_t; extern "C" void __attribute__((import_module("wasm_gemm"), import_name("int8_multiply"))) int8Multiply(const uint8_t* input_A, - const uint8_t zero_point_A, + float zero_point_A, const int8_t* input_B, - const uint8_t* zero_point_B, + //const uint8_t* zero_point_B, Index rows_A, Index width, Index cols_B, - int32_t* output); + float* output); #endif