Skip to content

Commit

Permalink
MFBT
Browse files Browse the repository at this point in the history
  • Loading branch information
tarekziade committed Dec 14, 2024
1 parent 7c31b4c commit a1f9e79
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 109 deletions.
221 changes: 180 additions & 41 deletions onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
#include "core/providers/common.h"
#include "core/util/math_cpuonly.h"
#include "core/util/qmath.h"
#include <chrono> // For time measurement

// Define aliases for convenience
using Clock = std::chrono::high_resolution_clock;
using Microseconds = std::chrono::microseconds;
using Index = std::size_t;

namespace onnxruntime {
Expand Down Expand Up @@ -105,75 +109,210 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
gemm_params.C = y_data + helper.OutputOffsets()[batch];
}

std::vector<int32_t> int32_output(helper.M() * helper.N(), 0);
std::vector<uint32_t> gemmology_output(helper.M() * helper.N(), 0);

#ifdef __EMSCRIPTEN__
uint8_t zero_point_b = *(b_offset_ptr + helper.RightZeroPointOffsets()[0]);

// Output all inputs before the call
std::cout << "Matrix A:\n";
for (size_t i = 0; i < static_cast<size_t>(helper.M()); ++i) {
for (size_t j = 0; j < static_cast<size_t>(helper.K()); ++j) {
std::cout << static_cast<int>(a_data[i * helper.K() + j]) << " ";
}
std::cout << "\n";
}

std::cout << "Matrix B:\n";
for (size_t i = 0; i < static_cast<size_t>(helper.K()); ++i) {
for (size_t j = 0; j < static_cast<size_t>(helper.N()); ++j) {
std::cout << static_cast<int>(b_data[i * helper.N() + j]) << " ";
}
std::cout << "\n";
}

std::cout << "A Zero point: " << static_cast<int>(a_offset) << "\n";
std::cout << "B zero_point: " << static_cast<int>(zero_point_b) << "\n";
std::cout << "rows A: " << helper.M() << ", width: " << helper.K() << ", Cols B: " << helper.N() << "\n";
std::cout << "B is packed: " << (packed_b_ ? "true" : "false") << "\n";
std::cout << "B is signed: " << (b_is_signed ? "true" : "false") << "\n";

// Gemmology call
int8Multiply(reinterpret_cast<const int8_t*>(a_data),
a_offset,
reinterpret_cast<const int8_t*>(b_data),
zero_point_b,
static_cast<size_t>(helper.M()), // rows A
static_cast<size_t>(helper.K()), // width
static_cast<size_t>(helper.N()), // col B
reinterpret_cast<float*>(int32_output.data()));

std::cout << "Zero Points Debug:\n";
std::cout << "A Zero Point: " << static_cast<int>(a_offset) << "\n";
std::cout << "B Zero Points (all columns): ";
for (size_t i = 0; i < static_cast<size_t>(helper.N()); ++i) {
std::cout << static_cast<int>(b_offset_ptr[i]) << " ";
}
std::cout << "\n";

std::cout << "Matrix Dimensions:\n";
std::cout << "M (rows A): " << gemm_shape.M << ", K (width): " << gemm_shape.K
<< ", N (cols B): " << gemm_shape.N << "\n";

std::cout << "Signedness:\n";
std::cout << "AIsSigned: " << (gemm_shape.AIsSigned ? "true" : "false") << "\n";
std::cout << "BIsSigned: " << (gemm_shape.BIsSigned ? "true" : "false") << "\n";


std::cout << "Matrix A (sample):\n";
for (size_t i = 0; i < 5; ++i) {
for (size_t j = 0; j < 5; ++j) {
std::cout << static_cast<unsigned int>(a_data[i * helper.K() + j]) << " ";
}
std::cout << "\n";
}

std::cout << "Matrix B (sample):\n";
for (size_t i = 0; i < 5; ++i) {
for (size_t j = 0; j < 5; ++j) {
std::cout << static_cast<unsigned int>(b_data[i * helper.N() + j]) << " ";
}
std::cout << "\n";
}
std::cout << "Offsets Debug:\n";
std::cout << "Left Offsets (A): ";
for (size_t i = 0; i < batch_size; ++i) {
std::cout << helper.LeftOffsets()[i] << " ";
}
std::cout << "\n";

std::cout << "Right Offsets (B): ";
for (size_t i = 0; i < batch_size; ++i) {
std::cout << helper.RightOffsets()[i] << " ";
}
std::cout << "\n";

std::cout << "B is packed: " << (packed_b_ ? "true" : "false") << "\n";


// Manually compute the first value of the first row of the output
uint32_t manual_result = 0;

std::cout << "Dimensions: M = " << helper.M() << ", K = " << helper.K() << ", N = " << helper.N() << "\n";

std::cout << "Manually computing first value of the output matrix (Row 0, Col 0):\n";

int64_t temp_result = 0; // Use a signed type for accumulation to handle potential negatives
for (size_t k = 0; k < static_cast<size_t>(helper.K()); ++k) {
uint8_t a_value = static_cast<uint8_t>(a_data[k]); // First row of A (unsigned)
int8_t b_value = static_cast<int8_t>(b_data[k * helper.N()]); // First column of B (signed)

// Adjust for zero points
int32_t adjusted_a = static_cast<int32_t>(a_value) - static_cast<int32_t>(a_offset); // A is unsigned
int32_t adjusted_b = static_cast<int32_t>(b_value) - static_cast<int32_t>(b_offset_ptr[0]); // B is signed

// Accumulate the signed result
temp_result += static_cast<int64_t>(adjusted_a) * static_cast<int64_t>(adjusted_b);

// Debugging individual terms
std::cout << "k = " << k
<< ", A[k] = " << static_cast<int>(a_value)
<< ", B[k, 0] = " << static_cast<int>(b_value)
<< ", Adjusted A[k] = " << adjusted_a
<< ", Adjusted B[k, 0] = " << adjusted_b
<< ", Partial Sum (signed) = " << temp_result << "\n";
}

// Ensure the result fits in uint32_t, saturating if necessary
manual_result = static_cast<uint32_t>(std::max<int64_t>(0, temp_result)); // Clamp to 0 for unsigned range

std::cout << "Manual computation result (Row 0, Col 0): " << manual_result << "\n";


// Gemmology call
std::cout << "Calling gemmology from onnx:\n";
auto start_gemmology = Clock::now();

int8Multiply(
reinterpret_cast<const uint8_t*>(a_data),
0, // a_offset,
reinterpret_cast<const int8_t*>(b_data),
b_offset_ptr[0],
static_cast<size_t>(helper.M()), // rows A
static_cast<size_t>(helper.K()), // width
static_cast<size_t>(helper.N()), // col B
reinterpret_cast<float*>(gemmology_output.data()));

auto end_gemmology = Clock::now();
auto gemmology_time = std::chrono::duration_cast<Microseconds>(end_gemmology - start_gemmology).count();
std::cout << "gemmology call complete.\n";

std::cout << "Call done\n";

std::cout << "Manually Clamping\n";

for (size_t i = 0; i < static_cast<size_t>(helper.M()); ++i) {
for (size_t j = 0; j < static_cast<size_t>(helper.N()); ++j) {
size_t index = i * static_cast<size_t>(helper.N()) + j;

// Interpret unsigned value as signed
uint32_t raw_value = gemmology_output[index];
//std::cout << "Index (" << i << ", " << j << "), Original Value (unsigned): " << raw_value << "\n";

int32_t signed_value = static_cast<int32_t>(raw_value);
//std::cout << "Index (" << i << ", " << j << "), Interpreted as Signed: " << signed_value << "\n";


// Clamp to non-negative
uint32_t clamped_value = static_cast<uint32_t>(std::max<int32_t>(0, signed_value));

// Write clamped value back to output
gemmology_output[index] = clamped_value;

// Log for debugging
if (i == 0 && j == 0) { // Only log the first value
std::cout << "Post-process Clamping for Index (0, 0):\n";
std::cout << "Raw Value (unsigned): " << raw_value << "\n";
std::cout << "Interpreted as Signed: " << signed_value << "\n";
std::cout << "Clamped Value: " << clamped_value << "\n";
}
}


}


#endif
std::cout << "Calling MlasGemmBatch\n";

auto start_mblas = Clock::now();

// Original MatmulInteger call
MlasGemmBatch(gemm_shape, gemm_data_vec.data(), batch_size, ctx->GetOperatorThreadPool());
MlasGemmBatch(gemm_shape, gemm_data_vec.data(), batch_size, ctx->GetOperatorThreadPool());
auto end_mblas = Clock::now();
auto mblas_time = std::chrono::duration_cast<Microseconds>(end_mblas - start_mblas).count();


std::cout << "Calling MlasGemmBatch done\n";
// Compute percentage difference
double percentage_diff = (static_cast<double>(gemmology_time - mblas_time) / mblas_time) * 100.0;

Check failure on line 272 in onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc

View workflow job for this annotation

GitHub Actions / Vcpkg

use of undeclared identifier 'gemmology_time'

// Display the results
std::cout << "Execution Times (Microseconds): MBlas = " << mblas_time
<< ", Gemmology = " << gemmology_time

Check failure on line 276 in onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc

View workflow job for this annotation

GitHub Actions / Vcpkg

use of undeclared identifier 'gemmology_time'
<< ", Difference = " << percentage_diff << "%\n";




// Compare the outputs
std::cout << "Comparing Outputs:\n";
std::cout << "Gemmology:\n";
for (size_t i = 0; i < static_cast<size_t>(helper.M()); ++i) {
for (size_t j = 0; j < static_cast<size_t>(helper.N()); ++j) {
std::cout << static_cast<int>(int32_output[i * helper.N() + j]) << " ";
}
std::cout << "\n";
}
std::cout << "MBLas:\n";
for (size_t i = 0; i < static_cast<size_t>(helper.M()); ++i) {
for (size_t j = 0; j < static_cast<size_t>(helper.N()); ++j) {
std::cout << static_cast<int>(y_data[i * helper.N() + j]) << " ";
//for (size_t i = 0; i < static_cast<size_t>(helper.M()); ++i) {
for (size_t i = 0; i < 2; ++i) {
//for (size_t j = 0; j < static_cast<size_t>(helper.N()); ++j) {
for (size_t j = 0; j < 2; ++j) {
std::cout << "Gemmology:";
std::cout << static_cast<uint32_t>(gemmology_output[i * helper.N() + j]) << "\n";
std::cout << "MBLas:";
std::cout << static_cast<uint32_t>(y_data[i * helper.N() + j]) << "\n";
}
std::cout << "\n";
}
std::cout << "Comparing\n";


for (size_t i = 0; i < static_cast<size_t>(helper.M()); ++i) {
for (size_t j = 0; j < static_cast<size_t>(helper.N()); ++j) {
std::cout << "Mismatch lookup\n";


size_t index = i * helper.N() + j;
if (int32_output[index] != static_cast<float>(y_data[index])) {
ORT_ENFORCE(false, "Mismatch at Row ", i, ", Col ", j, ": int8Multiply = ", int32_output[index],
std::cout << "Lookup at Row " << i << ", Col " << j << ": " << index << "\n";

if (gemmology_output[index] != static_cast<float>(y_data[index])) {
std::cout << "Mismatch";

ORT_ENFORCE(false, "Mismatch at Row ", i, ", Col ", j, ": int8Multiply = ", gemmology_output[index],
", MlasGemmBatch = ", static_cast<float>(y_data[index]));
}
}
}


return Status::OK();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class FirefoxMatMulInteger8 final : public MatMulIntegerBase {
using Index = uint32_t;
extern "C" void
__attribute__((import_module("wasm_gemm"), import_name("int8_multiply")))
int8Multiply(const int8_t* input_A,
int8Multiply(const uint8_t* input_A,
float zero_point_A,
const int8_t* input_B,
float zero_point_B,
Expand Down
75 changes: 8 additions & 67 deletions onnxruntime/wasm/pre.js
Original file line number Diff line number Diff line change
Expand Up @@ -52,36 +52,6 @@ var SharedArrayBuffer = globalThis.SharedArrayBuffer ??
new WebAssembly.Memory({ 'initial': 0, 'maximum': 0, 'shared': true }).buffer.constructor;



/**
 * Compute the property name under which a wasm export is published on `Module`.
 *
 * Names that start with "dynCall_" and the three stack helpers are returned
 * unchanged; every other name gets a single leading underscore.
 *
 * @param {string} x - Export name as it appears on the wasm instance.
 * @returns {string} The name to use on `Module`.
 */
function asmjsMangle(x) {
  var keptVerbatim = ["stackAlloc", "stackSave", "stackRestore"];
  if (x.indexOf("dynCall_") == 0 || keptVerbatim.includes(x)) {
    return x;
  }
  return "_" + x;
}

/**
 * Publish every enumerable property of `asm` on `Module` under its mangled
 * name (see `asmjsMangle`) and, when the function is invoked with a truthy
 * `this`, also on that `this` object under the original, unmangled name.
 *
 * @param {Object} asm - Object whose enumerable properties are the exports.
 */
function exportAsmFunctions(asm) {
  var receiver = this;
  for (var exportName in asm) {
    var mangledName = asmjsMangle(exportName);
    Module[mangledName] = asm[exportName];
    // No receiver (e.g. strict-mode call without `this`): Module only.
    if (!receiver) {
      continue;
    }
    receiver[exportName] = asm[exportName];
  }
}


/**
 * Build a stand-in for the optimized gemm import object.
 *
 * For each key of `gemmToFallbackFunctionsMap` the result holds a function
 * that forwards all of its arguments to `Module["asm"][<mapped name>]`.
 * Both `Module["asm"]` and the mapped name are looked up at call time, not
 * when this function runs.
 *
 * @param {Object} gemmToFallbackFunctionsMap - Import name -> name of the
 *   fallback function on `Module["asm"]`.
 * @returns {Object} Object with one forwarding function per key.
 */
function fallbackGemm(gemmToFallbackFunctionsMap) {
  // Property of Module that holds the fallback gemm implementation.
  const fallbackHolder = "asm";

  const forwarders = {};
  for (const importName in gemmToFallbackFunctionsMap) {
    forwarders[importName] = (...args) => {
      const fallbackName = gemmToFallbackFunctionsMap[importName];
      return Module[fallbackHolder][fallbackName](...args);
    };
  }
  return forwarders;
}

/**
* Custom call to instantiate the WebAssembly module, so that we can use custom imports.
*/ Module["instantiateWasm"] = async (info, receiveInstance) => {
Expand All @@ -90,48 +60,19 @@ function fallbackGemm(gemmToFallbackFunctionsMap) {
const module = await WebAssembly.compile(bytes);
let imports = getWasmImports();

// XXX mozIntGemm can't be used from web pages - we use a fallback if we are not privileged
const OPTIMIZED_GEMM = "mozIntGemm";

const optimizedGemmModule = WebAssembly[OPTIMIZED_GEMM];
if (!optimizedGemmModule) {
const GEMM_TO_FALLBACK_FUNCTIONS_MAP = {
int8_prepare_a: "int8PrepareAFallback",
int8_prepare_b: "int8PrepareBFallback",
int8_prepare_b_from_transposed: "int8PrepareBFromTransposedFallback",
int8_prepare_b_from_quantized_transposed:
"int8PrepareBFromQuantizedTransposedFallback",
int8_prepare_bias: "int8PrepareBiasFallback",
int8_multiply_and_add_bias: "int8MultiplyAndAddBiasFallback",
int8_select_columns_of_b: "int8SelectColumnsOfBFallback",
};
imports.wasm_gemm = fallbackGemm(GEMM_TO_FALLBACK_FUNCTIONS_MAP);
}
const optimizedGemmModuleExports = new WebAssembly.Instance(optimizedGemmModule(), {
"": {
memory: wasmMemory
}
}).exports;

imports.wasm_gemm = optimizedGemmModuleExports;

else {
var gemmWasmMemory = new WebAssembly.Memory({
"initial": 32768,
"maximum": 32768,
"shared": true
});
const optimizedGemmModuleExports = new WebAssembly.Instance(optimizedGemmModule(), {
"": {
memory: gemmWasmMemory
}
}).exports;
imports.wasm_gemm = optimizedGemmModuleExports;
}
function mozReceiveInstance(instance) {
// XXX do we need a moz specific stuff here?
//var exports = instance.exports;
//Module.asm = exports;
// wasmTable = Module.asm.__indirect_function_table; ???
//exportAsmFunctions(exports);
return receiveInstance(instance);
}
try {
var instance = new WebAssembly.Instance(module, imports);
mozReceiveInstance(instance);
receiveInstance(instance);
} catch (error) {
console.error("Error creating WebAssembly instance:", error);
throw error;
Expand Down

0 comments on commit a1f9e79

Please sign in to comment.