
Commit

cleanups
tarekziade committed Dec 11, 2024
1 parent b405f14 commit 0735d86
Showing 1 changed file with 39 additions and 102 deletions.
141 changes: 39 additions & 102 deletions onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc
@@ -59,7 +59,6 @@ Batch size: 1
*/
Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
std::cout << "FirefoxMatMulInteger8::Compute started" << std::endl;
const auto* a = ctx->Input<Tensor>(IN_A);
const auto* b = packed_b_ ? nullptr : ctx->Input<Tensor>(IN_B);

@@ -72,28 +71,35 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
a_offset = *(static_cast<const uint8_t*>(a_zero_point->DataRaw()));
}

const auto* b_zero_point = ctx->Input<Tensor>(IN_B_ZERO_POINT);

#ifndef __EMSCRIPTEN__
bool b_is_signed;
const uint8_t* b_offset_ptr = &b_default_offset;

Check failure on line 78 in onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc (GitHub Actions / Vcpkg): use of undeclared identifier 'b_default_offset'
bool is_b_zp_per_column = false;
uint8_t b_default_offset = 0;
const uint8_t* b_offset_ptr = &b_default_offset;
const auto* b_zero_point = ctx->Input<Tensor>(IN_B_ZERO_POINT);
if (b_zero_point != nullptr) {
ORT_ENFORCE(IsBQuantParamSupported(b_zero_point->Shape(), b ? b->Shape() : b_shape_),
"MatmulInteger : B zero point is not valid");
is_b_zp_per_column = !IsScalarOr1ElementVector(b_zero_point);
b_offset_ptr = static_cast<const uint8_t*>(b_zero_point->DataRaw());
}
#endif

MatMulComputeHelper helper;
const uint8_t* b_data;
bool b_is_signed;
if (nullptr != b) {
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b->Shape(), nullptr, b_zero_point ? &b_zero_point->Shape() : nullptr));
b_data = static_cast<const uint8_t*>(b->DataRaw());
#ifndef __EMSCRIPTEN__
b_is_signed = b->IsDataType<int8_t>();
#endif
} else {
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape_, nullptr, b_zero_point ? &b_zero_point->Shape() : nullptr));
b_data = static_cast<const uint8_t*>(packed_b_.get());
#ifndef __EMSCRIPTEN__
b_is_signed = b_is_signed_;
#endif
}

Tensor* y = ctx->Output(OUT_Y, helper.OutputShape());
@@ -103,34 +109,7 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
const uint8_t* a_data = static_cast<const uint8_t*>(a->DataRaw());
auto* y_data = y->MutableData<int32_t>();

MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape;
gemm_shape.M = static_cast<size_t>(helper.M());
gemm_shape.N = static_cast<size_t>(helper.N());
gemm_shape.K = static_cast<size_t>(helper.K());
gemm_shape.AIsSigned = a->IsDataType<int8_t>();
gemm_shape.BIsSigned = b_is_signed;

const size_t batch_size = helper.OutputOffsets().size();

std::vector<MLAS_GEMM_QUANT_DATA_PARAMS> gemm_data_vec(batch_size);

for (size_t batch = 0; batch < batch_size; batch++) {
auto& gemm_params = gemm_data_vec[batch];
gemm_params.lda = gemm_shape.K;
gemm_params.ZeroPointA = a_offset;
gemm_params.ldb = gemm_shape.N;
gemm_params.ZeroPointB = b_offset_ptr + helper.RightZeroPointOffsets()[batch];
gemm_params.PerColumnZeroPoints = is_b_zp_per_column;
gemm_params.ldc = gemm_shape.N;
gemm_params.BIsPacked = bool(packed_b_);
gemm_params.A = a_data + helper.LeftOffsets()[batch];
gemm_params.B = b_data + helper.RightOffsets()[batch];
gemm_params.C = y_data + helper.OutputOffsets()[batch];
}

#ifdef __EMSCRIPTEN__
//MlasGemmBatch(gemm_shape, gemm_data_vec.data(), batch_size, ctx->GetOperatorThreadPool());

// Prepare output buffer
std::vector<float> float_output(helper.M() * helper.N(), 0.0f);

@@ -152,90 +131,48 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
rows_a, // rows A
width, // width
cols_b, // col B
float_output.data());

// temporarily convert to int32
size_t num_elements = rows_a * cols_b;

for (size_t i = 0; i < num_elements; ++i) {
// Convert and assign: round and cast the float to int32_t
y_data[i] = static_cast<int32_t>(std::round(float_output[i]));

// Optional: Clamp to int32 range (unlikely needed if input floats are reasonable)
y_data[i] = std::clamp(
y_data[i],
std::numeric_limits<int32_t>::min(),
std::numeric_limits<int32_t>::max()
);
}
reinterpret_cast<float*>(y_data));

// Print the output
#if 0
std::cout << "Output matrix:\n";
for (Index i = 0; i < rows_a; ++i) {
for (Index j = 0; j < cols_b; ++j) {
std::cout << y_data[i * cols_b + j] << " ";
}
std::cout << "\n";
}
#endif

#else

std::cout << "Calling J'aime l'euneology" << std::endl;
std::cout << "A shape: " << a->Shape() << std::endl;
std::cout << "B shape: " << b->Shape() << std::endl;
size_t a_1 = static_cast<size_t>(a->Shape()[0]);
size_t a_2 = static_cast<size_t>(a->Shape()[1]); // <----
size_t b_1 = static_cast<size_t>(b->Shape()[1]);

const int8_t* casted_b_data = reinterpret_cast<const int8_t*>(b_data);
const int8_t* casted_a_data = static_cast<const int8_t*>(a->DataRaw());

// Print A data
std::cout << "Input Tensor A (casted_a_data):" << std::endl;
for (size_t i = 0; i < a_1; ++i) {
for (size_t j = 0; j < a_2; ++j) {
std::cout << static_cast<int>(casted_a_data[i * a_2 + j]) << " ";
}
std::cout << std::endl; // Move to the next row
}

// Print casted B data
std::cout << "Input Tensor B (casted_b_data):" << std::endl;
for (size_t i = 0; i < a_2; ++i) { // Rows of B
for (size_t j = 0; j < b_1; ++j) {
std::cout << static_cast<int>(casted_b_data[i * b_1 + j]) << " ";
// XXX original call
MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape;
gemm_shape.M = static_cast<size_t>(helper.M());
gemm_shape.N = static_cast<size_t>(helper.N());
gemm_shape.K = static_cast<size_t>(helper.K());
gemm_shape.AIsSigned = a->IsDataType<int8_t>();
gemm_shape.BIsSigned = b_is_signed;

const size_t batch_size = helper.OutputOffsets().size();

std::vector<MLAS_GEMM_QUANT_DATA_PARAMS> gemm_data_vec(batch_size);

for (size_t batch = 0; batch < batch_size; batch++) {
auto& gemm_params = gemm_data_vec[batch];
gemm_params.lda = gemm_shape.K;
gemm_params.ZeroPointA = a_offset;
gemm_params.ldb = gemm_shape.N;
gemm_params.ZeroPointB = b_offset_ptr + helper.RightZeroPointOffsets()[batch];
gemm_params.PerColumnZeroPoints = is_b_zp_per_column;
gemm_params.ldc = gemm_shape.N;
gemm_params.BIsPacked = bool(packed_b_);
gemm_params.A = a_data + helper.LeftOffsets()[batch];
gemm_params.B = b_data + helper.RightOffsets()[batch];
gemm_params.C = y_data + helper.OutputOffsets()[batch];
}
std::cout << std::endl; // Move to the next row
}

// the result is dequantized to float...
gemmology::Shift::Multiply(
reinterpret_cast<const uint8_t*>(casted_a_data),
casted_b_data,
a_1,
a_2,
b_1,
gemmology::callbacks::Write(reinterpret_cast<float*>(y_data))
);

// and we want int32..
//
// Get the shape of the tensor
std::cout << "y data result:" << std::endl;

size_t M = helper.M();
size_t N = helper.N();
for (size_t i = 0; i < M; ++i) {
for (size_t j = 0; j < N; ++j) {
// Access the element at row i and column j
std::cout << y_data[i * N + j] << " ";
}
std::cout << std::endl; // Move to the next row
}

#endif
std::cout << "Exiting FirefoxMatMulInteger8::Compute" << std::endl;
MlasGemmBatch(gemm_shape, gemm_data_vec.data(), batch_size, ctx->GetOperatorThreadPool());
#endif
return Status::OK();
}
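
For reference, the float-to-int32 conversion that the pre-change code performed after the gemmology call can be written as a standalone helper. The sketch below is illustrative only and not part of this commit; RoundFloatToInt32, float_output, and y_data are hypothetical names chosen to mirror the buffers used above.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Minimal sketch (not ORT code): round a float GEMM output buffer into the
// int32 output buffer, mirroring the conversion loop removed by this commit.
void RoundFloatToInt32(const std::vector<float>& float_output, int32_t* y_data) {
  for (size_t i = 0; i < float_output.size(); ++i) {
    // Round in double precision, then clamp to the int32 range before the
    // final cast so out-of-range values do not invoke undefined behavior.
    const double rounded = std::round(static_cast<double>(float_output[i]));
    const double clamped = std::clamp(
        rounded,
        static_cast<double>(std::numeric_limits<int32_t>::min()),
        static_cast<double>(std::numeric_limits<int32_t>::max()));
    y_data[i] = static_cast<int32_t>(clamped);
  }
}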
