From a1f9e799fa415332a21ccaebff5f482b5b71ad14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tarek=20Ziad=C3=A9?= Date: Sat, 14 Dec 2024 22:36:35 +0100 Subject: [PATCH] MFBT --- .../quantization/firefox_matmul_integer.cc | 221 ++++++++++++++---- .../cpu/quantization/firefox_matmul_integer.h | 2 +- onnxruntime/wasm/pre.js | 75 +----- 3 files changed, 189 insertions(+), 109 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc index ffe92feb840e7..1a2e9b48e8ef5 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc @@ -7,7 +7,11 @@ #include "core/providers/common.h" #include "core/util/math_cpuonly.h" #include "core/util/qmath.h" +#include // For time measurement +// Define aliases for convenience +using Clock = std::chrono::high_resolution_clock; +using Microseconds = std::chrono::microseconds; using Index = std::size_t; namespace onnxruntime { @@ -105,75 +109,210 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const { gemm_params.C = y_data + helper.OutputOffsets()[batch]; } - std::vector int32_output(helper.M() * helper.N(), 0); + std::vector gemmology_output(helper.M() * helper.N(), 0); #ifdef __EMSCRIPTEN__ uint8_t zero_point_b = *(b_offset_ptr + helper.RightZeroPointOffsets()[0]); - // Output all inputs before the call - std::cout << "Matrix A:\n"; - for (size_t i = 0; i < static_cast(helper.M()); ++i) { - for (size_t j = 0; j < static_cast(helper.K()); ++j) { - std::cout << static_cast(a_data[i * helper.K() + j]) << " "; - } - std::cout << "\n"; - } - - std::cout << "Matrix B:\n"; - for (size_t i = 0; i < static_cast(helper.K()); ++i) { - for (size_t j = 0; j < static_cast(helper.N()); ++j) { - std::cout << static_cast(b_data[i * helper.N() + j]) << " "; - } - std::cout << "\n"; - } - std::cout << "A Zero point: " << static_cast(a_offset) << "\n"; std::cout << "B zero_point: " << static_cast(zero_point_b) << "\n"; std::cout << "rows A: " << helper.M() << ", width: " << helper.K() << ", Cols B: " << helper.N() << "\n"; std::cout << "B is packed: " << (packed_b_ ? "true" : "false") << "\n"; std::cout << "B is signed: " << (b_is_signed ? "true" : "false") << "\n"; - // Gemmology call - int8Multiply(reinterpret_cast(a_data), - a_offset, - reinterpret_cast(b_data), - zero_point_b, - static_cast(helper.M()), // rows A - static_cast(helper.K()), // width - static_cast(helper.N()), // col B - reinterpret_cast(int32_output.data())); + +std::cout << "Zero Points Debug:\n"; +std::cout << "A Zero Point: " << static_cast(a_offset) << "\n"; +std::cout << "B Zero Points (all columns): "; +for (size_t i = 0; i < static_cast(helper.N()); ++i) { + std::cout << static_cast(b_offset_ptr[i]) << " "; +} +std::cout << "\n"; + +std::cout << "Matrix Dimensions:\n"; +std::cout << "M (rows A): " << gemm_shape.M << ", K (width): " << gemm_shape.K + << ", N (cols B): " << gemm_shape.N << "\n"; + +std::cout << "Signedness:\n"; +std::cout << "AIsSigned: " << (gemm_shape.AIsSigned ? "true" : "false") << "\n"; +std::cout << "BIsSigned: " << (gemm_shape.BIsSigned ? "true" : "false") << "\n"; + + +std::cout << "Matrix A (sample):\n"; +for (size_t i = 0; i < 5; ++i) { + for (size_t j = 0; j < 5; ++j) { + std::cout << static_cast(a_data[i * helper.K() + j]) << " "; + } + std::cout << "\n"; +} + +std::cout << "Matrix B (sample):\n"; +for (size_t i = 0; i < 5; ++i) { + for (size_t j = 0; j < 5; ++j) { + std::cout << static_cast(b_data[i * helper.N() + j]) << " "; + } + std::cout << "\n"; +} +std::cout << "Offsets Debug:\n"; +std::cout << "Left Offsets (A): "; +for (size_t i = 0; i < batch_size; ++i) { + std::cout << helper.LeftOffsets()[i] << " "; +} +std::cout << "\n"; + +std::cout << "Right Offsets (B): "; +for (size_t i = 0; i < batch_size; ++i) { + std::cout << helper.RightOffsets()[i] << " "; +} +std::cout << "\n"; + +std::cout << "B is packed: " << (packed_b_ ? "true" : "false") << "\n"; + + +// Manually compute the first value of the first row of the output +uint32_t manual_result = 0; + +std::cout << "Dimensions: M = " << helper.M() << ", K = " << helper.K() << ", N = " << helper.N() << "\n"; + + std::cout << "Manually computing first value of the output matrix (Row 0, Col 0):\n"; + + int64_t temp_result = 0; // Use a signed type for accumulation to handle potential negatives + for (size_t k = 0; k < static_cast(helper.K()); ++k) { + uint8_t a_value = static_cast(a_data[k]); // First row of A (unsigned) + int8_t b_value = static_cast(b_data[k * helper.N()]); // First column of B (signed) + + // Adjust for zero points + int32_t adjusted_a = static_cast(a_value) - static_cast(a_offset); // A is unsigned + int32_t adjusted_b = static_cast(b_value) - static_cast(b_offset_ptr[0]); // B is signed + + // Accumulate the signed result + temp_result += static_cast(adjusted_a) * static_cast(adjusted_b); + + // Debugging individual terms + std::cout << "k = " << k + << ", A[k] = " << static_cast(a_value) + << ", B[k, 0] = " << static_cast(b_value) + << ", Adjusted A[k] = " << adjusted_a + << ", Adjusted B[k, 0] = " << adjusted_b + << ", Partial Sum (signed) = " << temp_result << "\n"; + } + + // Ensure the result fits in uint32_t, saturating if necessary + manual_result = static_cast(std::max(0, temp_result)); // Clamp to 0 for unsigned range + + std::cout << "Manual computation result (Row 0, Col 0): " << manual_result << "\n"; + + + // Gemmology call + std::cout << "Calling gemmology from onnx:\n"; + auto start_gemmology = Clock::now(); + + int8Multiply( + reinterpret_cast(a_data), + 0, // a_offset, + reinterpret_cast(b_data), + b_offset_ptr[0], + static_cast(helper.M()), // rows A + static_cast(helper.K()), // width + static_cast(helper.N()), // col B + reinterpret_cast(gemmology_output.data())); + + auto end_gemmology = Clock::now(); + auto gemmology_time = std::chrono::duration_cast(end_gemmology - start_gemmology).count(); + std::cout << "gemmology call complete.\n"; + + std::cout << "Call done\n"; + + std::cout << "Manually Clamping\n"; + + for (size_t i = 0; i < static_cast(helper.M()); ++i) { + for (size_t j = 0; j < static_cast(helper.N()); ++j) { + size_t index = i * static_cast(helper.N()) + j; + + // Interpret unsigned value as signed + uint32_t raw_value = gemmology_output[index]; + //std::cout << "Index (" << i << ", " << j << "), Original Value (unsigned): " << raw_value << "\n"; + + int32_t signed_value = static_cast(raw_value); + //std::cout << "Index (" << i << ", " << j << "), Interpreted as Signed: " << signed_value << "\n"; + + + // Clamp to non-negative + uint32_t clamped_value = static_cast(std::max(0, signed_value)); + + // Write clamped value back to output + gemmology_output[index] = clamped_value; + + // Log for debugging + if (i == 0 && j == 0) { // Only log the first value + std::cout << "Post-process Clamping for Index (0, 0):\n"; + std::cout << "Raw Value (unsigned): " << raw_value << "\n"; + std::cout << "Interpreted as Signed: " << signed_value << "\n"; + std::cout << "Clamped Value: " << clamped_value << "\n"; + } + } + + +} + + #endif +std::cout << "Calling MlasGemmBatch\n"; + +auto start_mblas = Clock::now(); // Original MatmulInteger call - MlasGemmBatch(gemm_shape, gemm_data_vec.data(), batch_size, ctx->GetOperatorThreadPool()); +MlasGemmBatch(gemm_shape, gemm_data_vec.data(), batch_size, ctx->GetOperatorThreadPool()); +auto end_mblas = Clock::now(); +auto mblas_time = std::chrono::duration_cast(end_mblas - start_mblas).count(); + + +std::cout << "Calling MlasGemmBatch done\n"; +// Compute percentage difference +double percentage_diff = (static_cast(gemmology_time - mblas_time) / mblas_time) * 100.0; + +// Display the results +std::cout << "Execution Times (Microseconds): MBlas = " << mblas_time + << ", Gemmology = " << gemmology_time + << ", Difference = " << percentage_diff << "%\n"; + + + // Compare the outputs std::cout << "Comparing Outputs:\n"; - std::cout << "Gemmology:\n"; - for (size_t i = 0; i < static_cast(helper.M()); ++i) { - for (size_t j = 0; j < static_cast(helper.N()); ++j) { - std::cout << static_cast(int32_output[i * helper.N() + j]) << " "; - } - std::cout << "\n"; - } - std::cout << "MBLas:\n"; - for (size_t i = 0; i < static_cast(helper.M()); ++i) { - for (size_t j = 0; j < static_cast(helper.N()); ++j) { - std::cout << static_cast(y_data[i * helper.N() + j]) << " "; + //for (size_t i = 0; i < static_cast(helper.M()); ++i) { + for (size_t i = 0; i < 2; ++i) { + //for (size_t j = 0; j < static_cast(helper.N()); ++j) { + for (size_t j = 0; j < 2; ++j) { + std::cout << "Gemmology:"; + std::cout << static_cast(gemmology_output[i * helper.N() + j]) << "\n"; + std::cout << "MBLas:"; + std::cout << static_cast(y_data[i * helper.N() + j]) << "\n"; } std::cout << "\n"; } +std::cout << "Comparing\n"; + for (size_t i = 0; i < static_cast(helper.M()); ++i) { for (size_t j = 0; j < static_cast(helper.N()); ++j) { + std::cout << "Mismatch lookup\n"; + + size_t index = i * helper.N() + j; - if (int32_output[index] != static_cast(y_data[index])) { - ORT_ENFORCE(false, "Mismatch at Row ", i, ", Col ", j, ": int8Multiply = ", int32_output[index], + std::cout << "Lookup at Row " << i << ", Col " << j << ": " << index << "\n"; + + if (gemmology_output[index] != static_cast(y_data[index])) { + std::cout << "Mismatch"; + + ORT_ENFORCE(false, "Mismatch at Row ", i, ", Col ", j, ": int8Multiply = ", gemmology_output[index], ", MlasGemmBatch = ", static_cast(y_data[index])); } } } + return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.h b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.h index 3b8066a405f7a..bdc543fd7df83 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.h +++ b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.h @@ -42,7 +42,7 @@ class FirefoxMatMulInteger8 final : public MatMulIntegerBase { using Index = uint32_t; extern "C" void __attribute__((import_module("wasm_gemm"), import_name("int8_multiply"))) - int8Multiply(const int8_t* input_A, + int8Multiply(const uint8_t* input_A, float zero_point_A, const int8_t* input_B, float zero_point_B, diff --git a/onnxruntime/wasm/pre.js b/onnxruntime/wasm/pre.js index 0261bf8a40a8e..47854471783ac 100644 --- a/onnxruntime/wasm/pre.js +++ b/onnxruntime/wasm/pre.js @@ -52,36 +52,6 @@ var SharedArrayBuffer = globalThis.SharedArrayBuffer ?? new WebAssembly.Memory({ 'initial': 0, 'maximum': 0, 'shared': true }).buffer.constructor; - -function asmjsMangle(x) { - var unmangledSymbols = ["stackAlloc", "stackSave", "stackRestore"]; - return x.indexOf("dynCall_") == 0 || unmangledSymbols.includes(x) ? x : "_" + x; -} - -function exportAsmFunctions(asm) { - var global_object = this; - for (var __exportedFunc in asm) { - var jsname = asmjsMangle(__exportedFunc); - Module[jsname] = asm[__exportedFunc]; - if (global_object) { - global_object[__exportedFunc] = asm[__exportedFunc]; - } - } -} - - -function fallbackGemm(gemmToFallbackFunctionsMap) { - // The fallback gemm implementation - const FALLBACK_GEMM = "asm"; - - let fallbackGemmModuleExports = {}; - for (let key in gemmToFallbackFunctionsMap) { - fallbackGemmModuleExports[key] = (...a) => - Module[FALLBACK_GEMM][gemmToFallbackFunctionsMap[key]](...a); - } - return fallbackGemmModuleExports; -} - /** * Custom call to instantiate WebAssembly module. so we can use custom imports */ Module["instantiateWasm"] = async (info, receiveInstance) => { @@ -90,48 +60,19 @@ function fallbackGemm(gemmToFallbackFunctionsMap) { const module = await WebAssembly.compile(bytes); let imports = getWasmImports(); - // XXX mozIntGemm can't be used from web pages - we use a fallback if we are not privileged const OPTIMIZED_GEMM = "mozIntGemm"; - const optimizedGemmModule = WebAssembly[OPTIMIZED_GEMM]; - if (!optimizedGemmModule) { - const GEMM_TO_FALLBACK_FUNCTIONS_MAP = { - int8_prepare_a: "int8PrepareAFallback", - int8_prepare_b: "int8PrepareBFallback", - int8_prepare_b_from_transposed: "int8PrepareBFromTransposedFallback", - int8_prepare_b_from_quantized_transposed: - "int8PrepareBFromQuantizedTransposedFallback", - int8_prepare_bias: "int8PrepareBiasFallback", - int8_multiply_and_add_bias: "int8MultiplyAndAddBiasFallback", - int8_select_columns_of_b: "int8SelectColumnsOfBFallback", - }; - imports.wasm_gemm = fallbackGemm(GEMM_TO_FALLBACK_FUNCTIONS_MAP); - } + const optimizedGemmModuleExports = new WebAssembly.Instance(optimizedGemmModule(), { + "": { + memory: wasmMemory + } + }).exports; + + imports.wasm_gemm = optimizedGemmModuleExports; - else { - var gemmWasmMemory = new WebAssembly.Memory({ - "initial": 32768, - "maximum": 32768, - "shared": true - }); - const optimizedGemmModuleExports = new WebAssembly.Instance(optimizedGemmModule(), { - "": { - memory: gemmWasmMemory - } - }).exports; - imports.wasm_gemm = optimizedGemmModuleExports; - } - function mozReceiveInstance(instance) { - // XXX do we need a moz specific stuff here? - //var exports = instance.exports; - //Module.asm = exports; - // wasmTable = Module.asm.__indirect_function_table; ??? - //exportAsmFunctions(exports); - return receiveInstance(instance); - } try { var instance = new WebAssembly.Instance(module, imports); - mozReceiveInstance(instance); + receiveInstance(instance); } catch (error) { console.error("Error creating WebAssembly instance:", error); throw error;