From 0735d8692f2cb164d8a9d49aa5f109104c38d01c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tarek=20Ziad=C3=A9?=
Date: Wed, 11 Dec 2024 15:00:41 +0100
Subject: [PATCH] cleanups

---
 .../quantization/firefox_matmul_integer.cc | 141 +++++-------------
 1 file changed, 39 insertions(+), 102 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc
index 9877191a37106..94e2eb374c9ed 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc
@@ -59,7 +59,6 @@ Batch size: 1
 */
 
 Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
-  std::cout << "FirefoxMatMulInteger8::Compute started" << std::endl;
   const auto* a = ctx->Input<Tensor>(IN_A);
   const auto* b = packed_b_ ? nullptr : ctx->Input<Tensor>(IN_B);
 
@@ -72,28 +71,35 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
     a_offset = *(static_cast<const uint8_t*>(a_zero_point->DataRaw()));
   }
 
+  const auto* b_zero_point = ctx->Input<Tensor>(IN_B_ZERO_POINT);
+
+  #ifndef __EMSCRIPTEN__
+  bool b_is_signed;
   bool is_b_zp_per_column = false;
   uint8_t b_default_offset = 0;
-  const uint8_t* b_offset_ptr = &b_default_offset;
-  const auto* b_zero_point = ctx->Input<Tensor>(IN_B_ZERO_POINT);
+  const uint8_t* b_offset_ptr = &b_default_offset;
   if (b_zero_point != nullptr) {
     ORT_ENFORCE(IsBQuantParamSupported(b_zero_point->Shape(), b ? b->Shape() : b_shape_),
                 "MatmulInteger : B zero point is not valid");
     is_b_zp_per_column = !IsScalarOr1ElementVector(b_zero_point);
     b_offset_ptr = static_cast<const uint8_t*>(b_zero_point->DataRaw());
   }
+  #endif
 
   MatMulComputeHelper helper;
   const uint8_t* b_data;
-  bool b_is_signed;
   if (nullptr != b) {
     ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b->Shape(), nullptr, b_zero_point ? &b_zero_point->Shape() : nullptr));
     b_data = static_cast<const uint8_t*>(b->DataRaw());
+  #ifndef __EMSCRIPTEN__
     b_is_signed = b->IsDataType<int8_t>();
+  #endif
   } else {
     ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape_, nullptr, b_zero_point ? &b_zero_point->Shape() : nullptr));
     b_data = static_cast<const uint8_t*>(packed_b_.get());
+  #ifndef __EMSCRIPTEN__
     b_is_signed = b_is_signed_;
+  #endif
   }
 
   Tensor* y = ctx->Output(OUT_Y, helper.OutputShape());
@@ -103,34 +109,7 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
   const uint8_t* a_data = static_cast<const uint8_t*>(a->DataRaw());
   auto* y_data = y->MutableData<int32_t>();
 
-  MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape;
-  gemm_shape.M = static_cast<size_t>(helper.M());
-  gemm_shape.N = static_cast<size_t>(helper.N());
-  gemm_shape.K = static_cast<size_t>(helper.K());
-  gemm_shape.AIsSigned = a->IsDataType<int8_t>();
-  gemm_shape.BIsSigned = b_is_signed;
-
-  const size_t batch_size = helper.OutputOffsets().size();
-
-  std::vector<MLAS_GEMM_QUANT_DATA_PARAMS> gemm_data_vec(batch_size);
-
-  for (size_t batch = 0; batch < batch_size; batch++) {
-    auto& gemm_params = gemm_data_vec[batch];
-    gemm_params.lda = gemm_shape.K;
-    gemm_params.ZeroPointA = a_offset;
-    gemm_params.ldb = gemm_shape.N;
-    gemm_params.ZeroPointB = b_offset_ptr + helper.RightZeroPointOffsets()[batch];
-    gemm_params.PerColumnZeroPoints = is_b_zp_per_column;
-    gemm_params.ldc = gemm_shape.N;
-    gemm_params.BIsPacked = bool(packed_b_);
-    gemm_params.A = a_data + helper.LeftOffsets()[batch];
-    gemm_params.B = b_data + helper.RightOffsets()[batch];
-    gemm_params.C = y_data + helper.OutputOffsets()[batch];
-  }
-
   #ifdef __EMSCRIPTEN__
-  //MlasGemmBatch(gemm_shape, gemm_data_vec.data(), batch_size, ctx->GetOperatorThreadPool());
-  // Prepare output buffer
   std::vector<float> float_output(helper.M() * helper.N(), 0.0f);
@@ -152,25 +131,10 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
       rows_a,  // rows A
       width,   // width
       cols_b,  // col B
-      float_output.data());
-
-  // temporarily convert to int32
-  size_t num_elements = rows_a * cols_b;
-
-  for (size_t i = 0; i < num_elements; ++i) {
-    // Convert and assign: round and cast the float to int32_t
-    y_data[i] = static_cast<int32_t>(std::round(float_output[i]));
-
-    // Optional: Clamp to int32 range (unlikely needed if input floats are reasonable)
-    y_data[i] = std::clamp(
-        y_data[i],
-        std::numeric_limits<int32_t>::min(),
-        std::numeric_limits<int32_t>::max()
-    );
-  }
+      reinterpret_cast<float*>(y_data));
 
   // Print the output
-
+  #if 0
   std::cout << "Output matrix:\n";
   for (Index i = 0; i < rows_a; ++i) {
     for (Index j = 0; j < cols_b; ++j) {
@@ -178,64 +142,37 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
     }
     std::cout << "\n";
   }
+  #endif
 
   #else
-
-  std::cout << "Calling J'aime l'euneology" << std::endl;
-  std::cout << "A shape: " << a->Shape() << std::endl;
-  std::cout << "B shape: " << b->Shape() << std::endl;
-  size_t a_1 = static_cast<size_t>(a->Shape()[0]);
-  size_t a_2 = static_cast<size_t>(a->Shape()[1]);  // <----
-  size_t b_1 = static_cast<size_t>(b->Shape()[1]);
-
-  const int8_t* casted_b_data = reinterpret_cast<const int8_t*>(b_data);
-  const int8_t* casted_a_data = static_cast<const int8_t*>(a->DataRaw());
-
-  // Print A data
-  std::cout << "Input Tensor A (casted_a_data):" << std::endl;
-  for (size_t i = 0; i < a_1; ++i) {
-    for (size_t j = 0; j < a_2; ++j) {
-      std::cout << static_cast<int>(casted_a_data[i * a_2 + j]) << " ";
-    }
-    std::cout << std::endl;  // Move to the next row
-  }
-
-  // Print casted B data
-  std::cout << "Input Tensor B (casted_b_data):" << std::endl;
-  for (size_t i = 0; i < a_2; ++i) {  // Rows of B
-    for (size_t j = 0; j < b_1; ++j) {
-      std::cout << static_cast<int>(casted_b_data[i * b_1 + j]) << " ";
+  // XXX original call
+  MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape;
+  gemm_shape.M = static_cast<size_t>(helper.M());
+  gemm_shape.N = static_cast<size_t>(helper.N());
+  gemm_shape.K = static_cast<size_t>(helper.K());
+  gemm_shape.AIsSigned = a->IsDataType<int8_t>();
+  gemm_shape.BIsSigned = b_is_signed;
+
+  const size_t batch_size = helper.OutputOffsets().size();
+
+  std::vector<MLAS_GEMM_QUANT_DATA_PARAMS> gemm_data_vec(batch_size);
+
+  for (size_t batch = 0; batch < batch_size; batch++) {
+    auto& gemm_params = gemm_data_vec[batch];
+    gemm_params.lda = gemm_shape.K;
+    gemm_params.ZeroPointA = a_offset;
+    gemm_params.ldb = gemm_shape.N;
+    gemm_params.ZeroPointB = b_offset_ptr + helper.RightZeroPointOffsets()[batch];
+    gemm_params.PerColumnZeroPoints = is_b_zp_per_column;
+    gemm_params.ldc = gemm_shape.N;
+    gemm_params.BIsPacked = bool(packed_b_);
+    gemm_params.A = a_data + helper.LeftOffsets()[batch];
+    gemm_params.B = b_data + helper.RightOffsets()[batch];
+    gemm_params.C = y_data + helper.OutputOffsets()[batch];
   }
-    std::cout << std::endl;  // Move to the next row
-  }
-
-  // the result is dequantized to float...
-  gemmology::Shift::Multiply(
-      reinterpret_cast<const uint8_t*>(casted_a_data),
-      casted_b_data,
-      a_1,
-      a_2,
-      b_1,
-      gemmology::callbacks::Write(reinterpret_cast<float*>(y_data))
-  );
-
-  // and we want int32..
-  //
-  // Get the shape of the tensor
-  std::cout << "y data result:" << std::endl;
-
-  size_t M = helper.M();
-  size_t N = helper.N();
-  for (size_t i = 0; i < M; ++i) {
-    for (size_t j = 0; j < N; ++j) {
-      // Access the element at row i and column j
-      std::cout << y_data[i * N + j] << " ";
-    }
-    std::cout << std::endl;  // Move to the next row
-  }
-  #endif
 
-  std::cout << "Exiting FirefoxMatMulInteger8::Compute" << std::endl;
+  MlasGemmBatch(gemm_shape, gemm_data_vec.data(), batch_size, ctx->GetOperatorThreadPool());
+  #endif
 
   return Status::OK();
 }