Skip to content

Commit

Permalink
YES
Browse files Browse the repository at this point in the history
  • Loading branch information
tarekziade committed Dec 17, 2024
1 parent f138d24 commit 28b488f
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 39 deletions.
107 changes: 71 additions & 36 deletions onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,39 +44,24 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
FirefoxMatMulInteger8);


#include <vector>
#include <cstdint>
#include <iostream>


Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
printf("FirefoxMatMulInteger8::Compute\n");

const auto* a = ctx->Input<Tensor>(IN_A);
const auto* b = packed_b_ ? nullptr : ctx->Input<Tensor>(IN_B);
uint8_t a_offset = 0;
const auto* a_zero_point = ctx->Input<Tensor>(IN_A_ZERO_POINT);
printf("FirefoxMatMulInteger8::Compute a\n");


if (a_zero_point != nullptr) {
ORT_ENFORCE(IsScalarOr1ElementVector(a_zero_point),
"MatmulInteger : input1 zero point must be a scalar or 1D tensor of size 1");
a_offset = *(static_cast<const uint8_t*>(a_zero_point->DataRaw()));
}
printf("FirefoxMatMulInteger8::Compute b\n");



uint8_t b_default_offset = 0;
const auto* b_zero_point = ctx->Input<Tensor>(IN_B_ZERO_POINT);
bool b_is_signed;
const uint8_t* b_offset_ptr = &b_default_offset;
bool is_b_zp_per_column = false;

printf("FirefoxMatMulInteger8::Compute c\n");


if (b_zero_point != nullptr) {
ORT_ENFORCE(IsBQuantParamSupported(b_zero_point->Shape(), b ? b->Shape() : b_shape_),
"MatmulInteger : B zero point is not valid");
Expand All @@ -85,8 +70,6 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
}

MatMulComputeHelper helper;
printf("FirefoxMatMulInteger8::Compute d\n");


const uint8_t* b_data;
if (nullptr != b) {
Expand All @@ -98,7 +81,6 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
b_data = static_cast<const uint8_t*>(packed_b_.get());
b_is_signed = b_is_signed_;
}
printf("FirefoxMatMulInteger8::Compute 2\n");


size_t M = static_cast<size_t>(helper.M());
Expand All @@ -113,9 +95,6 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
const uint8_t* a_data = static_cast<const uint8_t*>(a->DataRaw());
auto* y_data = y->MutableData<int32_t>();

printf("FirefoxMatMulInteger8::Compute 3\n");


MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape;
gemm_shape.M = M;
gemm_shape.N = N;
Expand All @@ -124,7 +103,6 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
gemm_shape.BIsSigned = b_is_signed;

const size_t batch_size = helper.OutputOffsets().size();

std::vector<MLAS_GEMM_QUANT_DATA_PARAMS> gemm_data_vec(batch_size);

for (size_t batch = 0; batch < batch_size; batch++) {
Expand All @@ -141,34 +119,91 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
gemm_params.C = y_data + helper.OutputOffsets()[batch];
}

printf("FirefoxMatMulInteger8::Compute 4\n");
#if 0
std::cout << "Matrix A (sample):\n";
for (size_t i = 0; i < 5; ++i) {
for (size_t j = 0; j < 5; ++j) {
std::cout << static_cast<unsigned int>(a_data[i * helper.K() + j]) << " ";
}
std::cout << "\n";
}
std::cout << "\n";
std::cout << "Matrix B (sample):\n";
for (size_t i = 0; i < 5; ++i) {
for (size_t j = 0; j < 5; ++j) {
std::cout << static_cast<unsigned int>(b_data[i * helper.N() + j]) << " ";
}
std::cout << "\n";
}

std::cout << "b_zero_point content: \n";
if (b_zero_point != nullptr) {
size_t b_zero_point_size = static_cast<size_t>(b_zero_point->Shape()[0]);
const uint8_t* b_zp_data = static_cast<const uint8_t*>(b_zero_point->DataRaw());
for (size_t i = 0; i < b_zero_point_size; ++i) {
std::cout << static_cast<unsigned int>(b_zp_data[i]) << " ";
}
std::cout << "\n";
} else {
std::cout << "b_zero_point is null\n";
}

auto start_matmul = Clock::now();
#endif
//auto start_matmul = Clock::now();
int8Multiply(

Check failure on line 153 in onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc

View workflow job for this annotation

GitHub Actions / Vcpkg

use of undeclared identifier 'int8Multiply'
reinterpret_cast<const uint8_t*>(a_data),
reinterpret_cast<const uint8_t*>(a->DataRaw()),
a_offset,
reinterpret_cast<const int8_t*>(b_data),
reinterpret_cast<const uint8_t*>(b_offset_ptr),
reinterpret_cast<const int8_t*>(b->DataRaw()),
//reinterpret_cast<const uint8_t*>(b_zero_point->DataRaw()),
M,
N,
K,
reinterpret_cast<int32_t*>(y_data)
N,
reinterpret_cast<float*>(y_data)
);
//auto end_matmul = Clock::now();
// auto matmul_time = std::chrono::duration_cast<Microseconds>(end_matmul - start_matmul).count();

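// Name mapping for the disabled reference loop below:
//   rowsA = M (rows of A), width = K (shared inner dimension),
//   colsB = N (columns of B)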
#if 0
for (size_t rowIndex = 0; rowIndex < rowsA; ++rowIndex) {
const uint8_t* aRow = inputMatrixAPtr + rowIndex * width; // Start of row in A
for (size_t colIndex = 0; colIndex < colsB; ++colIndex) {
int32_t tempResult = 0;

for (size_t k = 0; k < width; ++k) {
// Row-major access
uint8_t aValue = aRow[k];

// Walk down column colIndex of B; B is indexed row-major (k * colsB + colIndex)
int8_t bValue = inputMatrixBPtr[k * colsB + colIndex];

// Adjust for zero-point offsets
int32_t adjustedA = static_cast<int32_t>(aValue) - static_cast<int32_t>(a_offset);
int32_t adjustedB = static_cast<int32_t>(bValue); // - static_cast<int32_t>(b_offset_ptr[colIndex]);

// Accumulate product
tempResult += adjustedA * adjustedB;
}

// Write result to the output array
outputPtr[rowIndex * colsB + colIndex] = tempResult;
}
}
#endif

auto end_matmul = Clock::now();
auto matmul_time = std::chrono::duration_cast<Microseconds>(end_matmul - start_matmul).count();

// Mlas (will fallback if we don't meet requirements)
/*
// Mlas (will fallback if we don't meet requirements)
auto start_mblas = Clock::now();
MlasGemmBatch(gemm_shape, gemm_data_vec.data(), batch_size, ctx->GetOperatorThreadPool());
auto end_mblas = Clock::now();
auto mblas_time = std::chrono::duration_cast<Microseconds>(end_mblas - start_mblas).count();
*/

// Output timing results
std::cout << "Timing (microseconds):\n";
std::cout << "MatMulFull: " << matmul_time << "\n";
//std::cout << "Timing (microseconds):\n";
//std::cout << "MatMulFull: " << matmul_time << "\n";
//std::cout << "MlasGemmBatch: " << mblas_time << "\n";

return Status::OK();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ using Index = uint32_t;
extern "C" void
__attribute__((import_module("wasm_gemm"), import_name("int8_multiply")))
int8Multiply(const uint8_t* input_A,
const uint8_t zero_point_A,
float zero_point_A,
const int8_t* input_B,
const uint8_t* zero_point_B,
//const uint8_t* zero_point_B,
Index rows_A,
Index width,
Index cols_B,
int32_t* output);
float* output);


#endif
Expand Down

0 comments on commit 28b488f

Please sign in to comment.