From 51b6ae1de7cfe0f4cc0cb71ec20c12559d0a164c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tarek=20Ziad=C3=A9?= Date: Mon, 2 Dec 2024 10:09:49 +0100 Subject: [PATCH] added a firefox matmul backend --- build.sh | 15 +++++++++------ .../contrib_ops/cpu/cpu_contrib_kernels.cc | 3 +-- .../cpu/quantization/firefox_matmul_integer.cc | 1 + .../core/graph/contrib_ops/contrib_defs.cc | 2 +- onnxruntime/core/graph/contrib_ops/ms_opset.h | 6 +++++- .../contrib_ops/firefox_matmul_integer_test.cc | 16 ++++++++-------- 6 files changed, 25 insertions(+), 18 deletions(-) diff --git a/build.sh b/build.sh index bf799ac8b7211..0b293effe6330 100755 --- a/build.sh +++ b/build.sh @@ -1,21 +1,24 @@ #!/bin/bash # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +set -ex # Get directory this script is in -DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" OS=$(uname -s) if [ "$OS" = "Darwin" ]; then - DIR_OS="MacOS" + DIR_OS="MacOS" else - DIR_OS="Linux" + DIR_OS="Linux" fi if [[ "$*" == *"--ios"* ]]; then - DIR_OS="iOS" + DIR_OS="iOS" elif [[ "$*" == *"--android"* ]]; then - DIR_OS="Android" + DIR_OS="Android" fi -python3 $DIR/tools/ci_build/build.py --build_dir $DIR/build/$DIR_OS "$@" +PYTHON="${PYTHON:-python3}" + +$PYTHON $DIR/tools/ci_build/build.py --build_dir $DIR/build/$DIR_OS "$@" diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 63d23c1e2549c..144b104aac7df 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -62,8 +62,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, NGram class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BifurcationDetector); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FirefoxMatMulInteger); - +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FirefoxMatMulInteger8); // ******** Start: Quantization ******************* // class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulInteger16); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearGlobalAveragePool); diff --git a/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc index 871cc2aaa7db1..e4fd8714eb11e 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc @@ -42,6 +42,7 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* static_cast(helper.K())); } + printf("I was called\n"); return Status::OK(); } diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 3e7ad56bfa2e8..36bb1fb251dc8 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1987,7 +1987,7 @@ Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy- -ONNX_MS_OPERATOR_SET_SCHEMA(FirefoxMatMulInteger, 1, +ONNX_MS_OPERATOR_SET_SCHEMA(FirefoxMatMulInteger8, 1, OpSchema() .SetDoc(FirefoxMatMulInteger_doc) .Input(0, "A", "N-dimensional matrix A", "T1") diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index 94376075fa985..930b81b1d7354 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -79,7 +79,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Irfft); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, IsAllFinite); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, LongformerAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulInteger16); -class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, FirefoxMatMulInteger); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, FirefoxMatMulInteger8); #ifndef ORT_MINIMAL_BUILD class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulFpQ4); #endif @@ -190,7 +190,11 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); +<<<<<<< HEAD fn(GetOpSchema()); +======= + fn(GetOpSchema()); +>>>>>>> 045a17021c (added a firefox matmul backend) #ifndef ORT_MINIMAL_BUILD fn(GetOpSchema()); #endif diff --git a/onnxruntime/test/contrib_ops/firefox_matmul_integer_test.cc b/onnxruntime/test/contrib_ops/firefox_matmul_integer_test.cc index 55118ed8f7038..3b8e079ec705a 100644 --- a/onnxruntime/test/contrib_ops/firefox_matmul_integer_test.cc +++ b/onnxruntime/test/contrib_ops/firefox_matmul_integer_test.cc @@ -11,32 +11,32 @@ namespace onnxruntime { namespace test { -TEST(FirefoxMatMulIntegerOpTest, FirefoxMatMulInteger_1) { - OpTester test("FirefoxMatMulInteger", 1, onnxruntime::kMSDomain); +TEST(FirefoxMatMulInteger8OpTest, FirefoxMatMulInteger8_1) { + OpTester test("FirefoxMatMulInteger8", 1, onnxruntime::kMSDomain); test.AddInput("T1", {1, 1}, {15}); test.AddInput("T2", {1, 1}, {8}); test.AddOutput("T3", {1, 1}, {120}); // Result is 15 * 8 test.Run(); } -TEST(FirefoxMatMulIntegerOpTest, FirefoxMatMulInteger_2) { - OpTester test("FirefoxMatMulInteger", 1, onnxruntime::kMSDomain); +TEST(FirefoxMatMulInteger8OpTest, FirefoxMatMulInteger8_2) { + OpTester test("FirefoxMatMulInteger8", 1, onnxruntime::kMSDomain); test.AddInput("T1", {1, 2}, {-7, 10}); test.AddInput("T2", {2, 1}, {-8, -11}); test.AddOutput("T3", {1, 1}, {8}); // Result is (-7 * -8) + (10 * -11) test.Run(); } -TEST(FirefoxMatMulIntegerOpTest, FirefoxMatMulInteger_Empty_input) { - OpTester test("FirefoxMatMulInteger", 1, onnxruntime::kMSDomain); +TEST(FirefoxMatMulInteger8OpTest, FirefoxMatMulInteger8_Empty_input) { + OpTester test("FirefoxMatMulInteger8", 1, onnxruntime::kMSDomain); test.AddInput("T1", {0, 2}, {}); test.AddInput("T2", {2, 1}, {-8, -11}); test.AddOutput("T3", {0, 1}, {}); // Empty input produces an empty output test.Run(); } -TEST(FirefoxMatMulIntegerOpTest, FirefoxMatMulInteger_3) { - OpTester test("FirefoxMatMulInteger", 1, onnxruntime::kMSDomain); +TEST(FirefoxMatMulInteger8OpTest, FirefoxMatMulInteger8_3) { + OpTester test("FirefoxMatMulInteger8", 1, onnxruntime::kMSDomain); test.AddInput("T1", {3, 2}, {-7, 10, 10, -113, 22, -36}); test.AddInput("T2", {2, 4}, {-8, -11, 13, 14, -9, 12, 3, -6}); test.AddOutput("T3", {3, 4},