Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] Generalize MHA pattern #2092

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open

[DRAFT] Generalize MHA pattern #2092

wants to merge 5 commits into from

Conversation

gramalingam
Copy link
Collaborator

Generalize the MHA pattern (motivated by the Phi models). Specifically, we remove the initial MatMuls from the pattern (as being unnecessary). Phi uses packed MatMul (Q, K, and V are multiplied using a single MatMul and then sliced).

However, this is not sufficient yet, since Phi also uses partial rotary-embedding, which is not yet supported by the RotaryEmbedding pattern. I will separately work on the extension to the RotaryEmbedding pattern to handle partial embedding.

Comment on lines +161 to +162
# if no_match(mask, ["B", "H", "S", "St"]):
# return False

Check notice

Code scanning / CodeQL

Commented-out code Note

This comment appears to contain commented-out code.

Copilot Autofix AI 1 day ago

To fix the problem, we need to either remove the commented-out code or reinstate it if it is necessary for the functionality. Given the presence of TODO comments, it is likely that the commented-out code was intended to be revisited and possibly reinstated. Therefore, the best approach is to reinstate the commented-out code and ensure that it is functional.

  • Reinstate the commented-out code on lines 161-162 and 173-178.
  • Ensure that the reinstated code is properly integrated and does not cause any issues.
Suggested changeset 1
onnxscript/rewriter/ort_fusions/mha.py

Autofix patch

Autofix patch
Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/onnxscript/rewriter/ort_fusions/mha.py b/onnxscript/rewriter/ort_fusions/mha.py
--- a/onnxscript/rewriter/ort_fusions/mha.py
+++ b/onnxscript/rewriter/ort_fusions/mha.py
@@ -160,4 +160,4 @@
         # TODO: broadcast check
-        # if no_match(mask, ["B", "H", "S", "St"]):
-        #     return False
+        if no_match(mask, ["B", "H", "S", "St"]):
+            return False
         if no_match(past_key, ["B", "H", "Spast", "Dh"]):
@@ -172,8 +172,8 @@
             return False
-        # if not status:
-        #     return False
-        # if bindings["B"] * bindings["H"] != bindings["B*H"]:
-        #     return False
-        # if bindings["H"] * bindings["Dh"] != bindings["H*Dh"]:
-        #     return False
+        if not status:
+            return False
+        if bindings["B"] * bindings["H"] != bindings["B*H"]:
+            return False
+        if bindings["H"] * bindings["Dh"] != bindings["H*Dh"]:
+            return False
         return True
EOF
@@ -160,4 +160,4 @@
# TODO: broadcast check
# if no_match(mask, ["B", "H", "S", "St"]):
# return False
if no_match(mask, ["B", "H", "S", "St"]):
return False
if no_match(past_key, ["B", "H", "Spast", "Dh"]):
@@ -172,8 +172,8 @@
return False
# if not status:
# return False
# if bindings["B"] * bindings["H"] != bindings["B*H"]:
# return False
# if bindings["H"] * bindings["Dh"] != bindings["H*Dh"]:
# return False
if not status:
return False
if bindings["B"] * bindings["H"] != bindings["B*H"]:
return False
if bindings["H"] * bindings["Dh"] != bindings["H*Dh"]:
return False
return True
Copilot is powered by AI and may make mistakes. Always verify output.
Positive Feedback
Negative Feedback

Provide additional feedback

Please help us improve GitHub Copilot by sharing more details about this comment.

Please select one or more of the options
if ort_version >= packaging.version.Version("1.20"):
# Run model again
new_outputs = ort_run("optimized", model, inputs)
assert_allclose(new_outputs, original_outputs)

Check failure

Code scanning / CodeQL

Potentially uninitialized local variable Error

Local variable 'original_outputs' may be used before it is initialized.

Copilot Autofix AI 1 day ago

To fix the problem, we need to ensure that original_outputs is initialized before it is used. One way to achieve this is to initialize original_outputs to None before the conditional block and then check if it is None before using it. This ensures that original_outputs is always defined, and we can handle the case where it is not set due to the condition not being met.

Suggested changeset 1
onnxscript/rewriter/ort_fusions/mha_test.py

Autofix patch

Autofix patch
Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/onnxscript/rewriter/ort_fusions/mha_test.py b/onnxscript/rewriter/ort_fusions/mha_test.py
--- a/onnxscript/rewriter/ort_fusions/mha_test.py
+++ b/onnxscript/rewriter/ort_fusions/mha_test.py
@@ -24,3 +24,3 @@
         xformers.fuse_cos_sin_cache(model)
-
+        original_outputs = None
         if ort_version >= packaging.version.Version("1.20"):
@@ -36,3 +36,3 @@
 
-        if ort_version >= packaging.version.Version("1.20"):
+        if ort_version >= packaging.version.Version("1.20") and original_outputs is not None:
             # Run model again
EOF
@@ -24,3 +24,3 @@
xformers.fuse_cos_sin_cache(model)

original_outputs = None
if ort_version >= packaging.version.Version("1.20"):
@@ -36,3 +36,3 @@

if ort_version >= packaging.version.Version("1.20"):
if ort_version >= packaging.version.Version("1.20") and original_outputs is not None:
# Run model again
Copilot is powered by AI and may make mistakes. Always verify output.
Positive Feedback
Negative Feedback

Provide additional feedback

Please help us improve GitHub Copilot by sharing more details about this comment.

Please select one or more of the options
Copy link

codecov bot commented Mar 7, 2025

❌ 83 Tests Failed:

Tests completed Failed Passed Skipped
9647 83 9564 1946
View the top 2 failed test(s) by shortest run time
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0945_test_reshape_reordered_last_dims
Stack Traces | 0.003s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_reshape_reordered_last_dims'

The above exception was the direct cause of the following exception:
.nox\test_ort_nightly\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_reshape_reordered_last_dims' (e=No module named 'tests.onnx_backend_test_code.test_reshape_reordered_last_dims') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_reshape_reordered_last_dims.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_reshape_reordered_last_dims.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT, INT64
E   from onnxscript.onnx_opset import opset21
E   
E   @script()
E   def bck_test_reshape_reordered_last_dims(data: FLOAT[2,3,4], shape: INT64[3]) -> (FLOAT[2,4,3]):
E       reshaped = opset21.Reshape(data, shape)
E       return reshaped
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_1244_test_transpose_all_permutations_5
Stack Traces | 0.003s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.10.11\x64\lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_transpose_all_permutations_5'

The above exception was the direct cause of the following exception:
.nox\test\lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_transpose_all_permutations_5' (e=No module named 'tests.onnx_backend_test_code.test_transpose_all_permutations_5') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_transpose_all_permutations_5.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_transpose_all_permutations_5.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT
E   from onnxscript.onnx_opset import opset21
E   
E   @script()
E   def bck_test_transpose_all_permutations_5(data: FLOAT[2,3,4]) -> (FLOAT[4,3,2]):
E       transposed = opset21.Transpose(data, perm=[2, 1, 0])
E       return transposed
View the full list of 1 ❄️ flaky tests
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_1246_test_tril

Flake rate in main: 40.00% (Passed 3 times, Failed 2 times)

Stack Traces | 0.003s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_tril'

The above exception was the direct cause of the following exception:
.nox\test\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_tril' (e=No module named 'tests.onnx_backend_test_code.test_tril') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_tril.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_tril.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import INT64
E   from onnxscript.onnx_opset import opset14
E   
E   @script()
E   def bck_test_tril(x: INT64[4,5]) -> (INT64[4,5]):
E       y = opset14.Trilu(x, upper=0)
E       return y

To view more test analytics, go to the Test Analytics Dashboard
📋 Got 3 mins? Take this short survey to help us improve Test Analytics.

Comment on lines +136 to +137
if dim < 0:
dim += shape.rank()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be absorbed into return shape[dim]?

@@ -21,6 +22,9 @@ def _save(model, modelpath):
io.save(model, modelpath)


ort_version = packaging.version.Version(onnxruntime.__version__)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this used? If so

Suggested change
ort_version = packaging.version.Version(onnxruntime.__version__)
_ORT_VERSION = packaging.version.Version(onnxruntime.__version__)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
Development

Successfully merging this pull request may close these issues.

2 participants