Skip to content

Commit 26dd1ee

Browse files
committed
moved to FF
1 parent 63d7cf6 commit 26dd1ee

File tree

2 files changed

+39
-10
lines changed

2 files changed

+39
-10
lines changed

onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc

+33 -1
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,22 @@ extern "C" void
3131
Index cols_B,
3232
float* output);
3333

34+
extern "C" void
35+
__attribute__((import_module("wasm_gemm"), import_name("f32_multiply")))
36+
f32Multiply(
37+
const uint8_t* a_data,
38+
float zero_point_A,
39+
const int8_t* input_B,
40+
const uint8_t* zero_point_B,
41+
Index rows_A,
42+
Index width,
43+
Index cols_B,
44+
const float* b_scale_data,
45+
float is_b_scale_per_column,
46+
float* output
47+
);
48+
49+
3450
void ScaleOutput(const Tensor& scale, Tensor& output) {
3551
ProcessBroadcastSpanFuncs funcs{
3652
[](BroadcastHelper& per_iter_bh) {
@@ -262,7 +278,23 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
262278
std::cout << "b_scale_data[0]: " << b_scale_data[0] << std::endl;
263279
#endif
264280

265-
MatMulFull(a_data, b_data, y_data_2.data(), rowsA, width, colsB, a_zp, b_zp_ptr, b_scale_data, is_b_scale_per_column);
281+
//MatMulFull(a_data, b_data, y_data, rowsA, width, colsB, a_zp, b_zp_ptr, b_scale_data, is_b_scale_per_column);
282+
283+
std::cout << "Calling f32Multiply\n";
284+
285+
f32Multiply(a_data,
286+
a_zp,
287+
b_data,
288+
b_zp_ptr,
289+
rowsA,
290+
width,
291+
colsB,
292+
b_scale_data,
293+
is_b_scale_per_column,
294+
y_data);
295+
296+
std::cout << "Done calling f32Multiply\n";
297+
266298

267299
#if 0
268300
MlasGemmBatch(gemm_shape, gemm_data_vec.data(), num_gemms, ctx->GetOperatorThreadPool());

onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc

+6 -9
Original file line number | Diff line number | Diff line change
@@ -103,6 +103,7 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
103103
gemm_shape.BIsSigned = b_is_signed;
104104

105105
const size_t batch_size = helper.OutputOffsets().size();
106+
106107
std::vector<MLAS_GEMM_QUANT_DATA_PARAMS> gemm_data_vec(batch_size);
107108

108109
for (size_t batch = 0; batch < batch_size; batch++) {
@@ -118,7 +119,6 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
118119
gemm_params.B = b_data + helper.RightOffsets()[batch];
119120
gemm_params.C = y_data + helper.OutputOffsets()[batch];
120121
}
121-
122122
#if 0
123123
std::cout << "Matrix A (sample):\n";
124124
for (size_t i = 0; i < 5; ++i) {
@@ -147,7 +147,6 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
147147
} else {
148148
std::cout << "b_zero_point is null\n";
149149
}
150-
151150
#endif
152151
//auto start_matmul = Clock::now();
153152
int8Multiply(
@@ -161,7 +160,7 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
161160
reinterpret_cast<float*>(y_data)
162161
);
163162
//auto end_matmul = Clock::now();
164-
// auto matmul_time = std::chrono::duration_cast<Microseconds>(end_matmul - start_matmul).count();
163+
//auto matmul_time = std::chrono::duration_cast<Microseconds>(end_matmul - start_matmul).count();
165164

166165
// rowsA = M
167166
// width = K
@@ -191,20 +190,18 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
191190
outputPtr[rowIndex * colsB + colIndex] = tempResult;
192191
}
193192
}
194-
#endif
195193

196-
/*
197194
// Mlas (will fallback if we don't meet requirements)
198195
auto start_mblas = Clock::now();
199196
MlasGemmBatch(gemm_shape, gemm_data_vec.data(), batch_size, ctx->GetOperatorThreadPool());
200197
auto end_mblas = Clock::now();
201198
auto mblas_time = std::chrono::duration_cast<Microseconds>(end_mblas - start_mblas).count();
202-
*/
203199
// Output timing results
204-
//std::cout << "Timing (microseconds):\n";
205-
//std::cout << "MatMulFull: " << matmul_time << "\n";
206-
//std::cout << "MlasGemmBatch: " << mblas_time << "\n";
200+
std::cout << "Timing (microseconds):\n";
201+
std::cout << "MatMulFull: " << matmul_time << "\n";
202+
std::cout << "MlasGemmBatch: " << mblas_time << "\n";
207203

204+
#endif
208205
return Status::OK();
209206
}
210207

0 commit comments

Comments
 (0)