diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index c14e4091c22..5d55331a1c3 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,7 +1,7 @@ # #################################################################################### # The MIT License (MIT) # -# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -96,6 +96,10 @@ if(MIGRAPHX_ENABLE_PYTHON) add_subdirectory(py) endif() +# Op builder test +set(TEST_OP_BUILDER_DIR ${CMAKE_CURRENT_SOURCE_DIR}/op) +add_subdirectory(op) + # multitarget test if(MIGRAPHX_ENABLE_GPU AND MIGRAPHX_ENABLE_CPU AND MIGRAPHX_ENABLE_FPGA) set(TEST_MULTI_TARGET_DIR ${CMAKE_CURRENT_SOURCE_DIR}/multi_target) diff --git a/test/op/CMakeLists.txt b/test/op/CMakeLists.txt new file mode 100644 index 00000000000..e850f58b94f --- /dev/null +++ b/test/op/CMakeLists.txt @@ -0,0 +1,30 @@ +##################################################################################### +# The MIT License (MIT) +# +# Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +##################################################################################### + +file(GLOB OP_BUILDER_TESTS CONFIGURE_DEPENDS builder/*.cpp) + +rocm_add_test_executable(test_op_builder_test ${OP_BUILDER_TESTS}) +target_include_directories(test_op_builder_test PUBLIC ../include include) +target_link_libraries(test_op_builder_test migraphx migraphx_ref) +rocm_clang_tidy_check(test_op_builder_test) diff --git a/test/op/builder/batchnorm_test.cpp b/test/op/builder/batchnorm_test.cpp new file mode 100644 index 00000000000..06bbab96bac --- /dev/null +++ b/test/op/builder/batchnorm_test.cpp @@ -0,0 +1,103 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +TEST_CASE(batchnorm_rank_0_op_builder_test) +{ + migraphx::module mm; + + mm.add_parameter("x", {migraphx::shape::half_type, {}}); + mm.add_parameter("scale", {migraphx::shape::float_type, {3}}); + mm.add_parameter("bias", {migraphx::shape::float_type, {3}}); + mm.add_parameter("mean", {migraphx::shape::float_type, {3}}); + mm.add_parameter("variance", {migraphx::shape::float_type, {3}}); + + EXPECT(test::throws( + [&] { make_op_module("batchnorm", {}, mm.get_parameters()); }, + "rank 0 input tensor, unhandled data format")); +} + +TEST_CASE(batchnorm_rank_1_op_builder_test) +{ + migraphx::module mm; + + const float epsilon = 1e-6f; + + auto x = mm.add_parameter("x", {migraphx::shape::float_type, {10}}); + auto scale = mm.add_parameter("scale", {migraphx::shape::float_type, {1}}); + auto bias = mm.add_parameter("bias", {migraphx::shape::float_type, {1}}); + auto mean = mm.add_parameter("mean", {migraphx::shape::float_type, {1}}); + auto var = mm.add_parameter("variance", {migraphx::shape::float_type, {1}}); + + auto eps = mm.add_literal(migraphx::literal{migraphx::shape::float_type, {epsilon}}); + + auto x_sub_mean = add_common_op(mm, migraphx::make_op("sub"), {x, mean}); + auto var_eps = add_common_op(mm, migraphx::make_op("add"), {var, eps}); + auto rsqrt = mm.add_instruction(migraphx::make_op("rsqrt"), {var_eps}); + auto mul0 = add_common_op(mm, migraphx::make_op("mul"), {scale, rsqrt}); + auto r0 = add_common_op(mm, migraphx::make_op("mul"), {x_sub_mean, mul0}); + add_common_op(mm, migraphx::make_op("add"), {r0, bias}); + + EXPECT(mm == make_op_module("batchnorm", {{"epsilon", epsilon}}, mm.get_parameters())); +} + +TEST_CASE(batchnorm_rank_larger_than_2_op_builder_test) +{ + migraphx::module mm; + + auto x = mm.add_parameter("x", {migraphx::shape::half_type, {2, 3, 4}}); + auto scale = mm.add_parameter("scale", {migraphx::shape::float_type, {3}}); + auto bias = mm.add_parameter("bias", {migraphx::shape::float_type, {3}}); + auto mean = mm.add_parameter("mean", {migraphx::shape::float_type, {3}}); + auto var = mm.add_parameter("variance", {migraphx::shape::float_type, {3}}); + + auto eps = mm.add_literal(migraphx::literal{migraphx::shape::half_type, {1e-5f}}); + + auto usq_scale = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), scale); + auto usq_bias = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bias); + auto usq_mean = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), mean); + auto usq_var = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), var); + + auto x_sub_mean = add_common_op(mm, migraphx::make_op("sub"), {x, usq_mean}); + auto var_eps = add_common_op(mm, migraphx::make_op("add"), {usq_var, eps}); + auto rsqrt = mm.add_instruction(migraphx::make_op("rsqrt"), var_eps); + auto mul0 = add_common_op(mm, migraphx::make_op("mul"), {usq_scale, rsqrt}); + auto r0 = add_common_op(mm, migraphx::make_op("mul"), {x_sub_mean, mul0}); + add_common_op(mm, migraphx::make_op("add"), {r0, usq_bias}); + + EXPECT(mm == make_op_module("batchnorm", {}, mm.get_parameters())); +} + +TEST_CASE(batchnorm_invalid_arguments_op_builder_test) +{ + migraphx::module mm; + + mm.add_parameter("x", {migraphx::shape::half_type, {2}}); + mm.add_parameter("scale", {migraphx::shape::float_type, {3, 2}}); + + EXPECT(test::throws( + [&] { make_op_module("batchnorm", {}, mm.get_parameters()); }, + "argument scale, bias, mean, or var rank != 1")); +} diff --git a/test/op/builder/celu_test.cpp b/test/op/builder/celu_test.cpp new file mode 100644 index 00000000000..e2f80c895c5 --- /dev/null +++ b/test/op/builder/celu_test.cpp @@ -0,0 +1,80 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +TEST_CASE(celu_happy_path_op_builder_test) +{ + migraphx::module mm; + + const float alpha = 0.8; + const migraphx::shape s = {migraphx::shape::float_type, {3}}; + + auto x = mm.add_parameter("x", s); + + const auto& input_lens = s.lens(); + const auto& input_type = s.type(); + auto zero_lit = + mm.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm.add_literal(migraphx::literal{migraphx::shape{input_type}, {0.}})); + auto one_lit = + mm.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm.add_literal(migraphx::literal{migraphx::shape{input_type}, {1.}})); + auto alpha_lit = + mm.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm.add_literal(migraphx::literal{migraphx::shape{input_type}, {alpha}})); + auto linear_part = mm.add_instruction(migraphx::make_op("max"), zero_lit, x); + auto divi = mm.add_instruction(migraphx::make_op("div"), x, alpha_lit); + auto expo = mm.add_instruction(migraphx::make_op("exp"), divi); + auto sub = mm.add_instruction(migraphx::make_op("sub"), expo, one_lit); + auto mul = mm.add_instruction(migraphx::make_op("mul"), alpha_lit, sub); + auto exp_part = mm.add_instruction(migraphx::make_op("min"), zero_lit, mul); + mm.add_instruction(migraphx::make_op("add"), linear_part, exp_part); + + EXPECT(mm == make_op_module("celu", {{"alpha", alpha}}, mm.get_parameters())); +} + +TEST_CASE(celu_zero_alpha_op_builder_test) +{ + migraphx::module mm; + + const float alpha = 0.0f; + + EXPECT( + test::throws([&] { make_op_module("celu", {{"alpha", alpha}}, {}); }, + "alpha is zero (division by zero)")); +} + +TEST_CASE(celu_wrong_shape_type_op_builder_test) +{ + migraphx::module mm; + const float alpha = 0.8; + const migraphx::shape s = {migraphx::shape::int8_type, {3}}; + + mm.add_parameter("x", s); + + EXPECT(test::throws( + [&] { make_op_module("celu", {{"alpha", alpha}}, mm.get_parameters()); }, + "input tensor not float type")); +} diff --git a/test/op/builder/einsum_test.cpp b/test/op/builder/einsum_test.cpp new file mode 100644 index 00000000000..92bc96fa091 --- /dev/null +++ b/test/op/builder/einsum_test.cpp @@ -0,0 +1,1094 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +static bool test_invalid_input(const std::string& equation = "", + const std::string& expected_msg = "", + const std::vector& x1_dims = {3, 3}, + const std::vector& x2_dims = {3, 3}) +{ + migraphx::module mm; + + if(not x1_dims.empty()) + mm.add_parameter("x1", {migraphx::shape::float_type, x1_dims}); + + if(not x2_dims.empty()) + mm.add_parameter("x2", {migraphx::shape::float_type, x2_dims}); + + migraphx::value options{}; + if(not equation.empty()) + { + options.insert({"equation", equation}); + } + + return test::throws( + [&]() { make_op_module("einsum", options, mm.get_parameters()); }, expected_msg); +} + +TEST_CASE(einsum_multiple_arrows_negative_op_builder_test) +{ + EXPECT(test_invalid_input("ii,jj->->ij", + "einsum op_builder: Einsum equation has multiple '->' symbols")); +} + +TEST_CASE(einsum_empty_term_before_arrow_negative_op_builder_test) +{ + EXPECT( + test_invalid_input("ii,->ij", "einsum op_builder: No term specified before '->' symbol")); +} + +TEST_CASE(einsum_multiple_ellipses_negative_op_builder_test) +{ + EXPECT(test_invalid_input( + "......ii,...jj->...ij", + "einsum op_builder: Ellipsis can only appear once per einsum equation term")); +} + +TEST_CASE(einsum_comma_in_output_negative_op_builder_test) +{ + EXPECT(test_invalid_input( + "ii,jj->i,j", "einsum op_builder: Einsum equation can't have a ',' symbol in the output")); +} + +TEST_CASE(einsum_empty_term_before_comma_negative_op_builder_test) +{ + EXPECT( + test_invalid_input("ii,,jj->ij", "einsum op_builder: No term specified before ',' symbol")); +} + +TEST_CASE(einsum_last_input_missing_negative_op_builder_test) +{ + EXPECT(test_invalid_input("ii,jj,", "einsum op_builder: Last input term is missing")); +} + +TEST_CASE(einsum_term_input_mismatch_negative_op_builder_test) +{ + EXPECT(test_invalid_input( + "ii,jj,kk->ijk", + "Number of terms in the input equation - 3 does not match the number of inputs 2")); +} + +TEST_CASE(einsum_ellipsis_mismatch_negative_op_builder_test) +{ + EXPECT(test_invalid_input("...ii,...jj->...ij", + "einsum op_builder: Every occurrence of ellipsis in the equation " + "must represent the same number of dimensions", + {3, 3, 3}, + {3, 3, 3, 3})); +} + +TEST_CASE(einsum_rank_mismatch_negative_op_builder_test) +{ + EXPECT(test_invalid_input("iik,jj->ij", + "einsum op_builder: Number of labels in 1. input_term (iik) does " + "not match the rank (2) of corresponding input")); +} + +TEST_CASE(einsum_output_surplus_label_negative_op_builder_test) +{ + EXPECT(test_invalid_input("ii,jj->ijk", + "einsum op_builder: Output term contains label 107, which is not " + "present in any of the input terms")); +} + +TEST_CASE(einsum_output_missing_ellipsis_negative_op_builder_test) +{ + EXPECT(test_invalid_input("...ii,...jj->ij", + "einsum op_builder: Output term does not contain ellipsis (...) " + "even though an input term does", + {3, 3, 3}, + {3, 3, 3})); +} + +TEST_CASE(einsum_multiple_diagonals_negative_op_builder_test) +{ + EXPECT(test_invalid_input("iijj->ij", + "einsum op_builder: Parsing of equations with more than one " + "duplicated labels per input term is not implemented", + {3, 3, 3, 3}, + {})); +} + +TEST_CASE(einsum_diagonal_dim_mismatch_negative_op_builder_test) +{ + EXPECT( + test_invalid_input("ii->i", + "einsum op_builder: All duplicate labels have to be the same dimension", + {3, 4}, + {})); +} + +TEST_CASE(einsum_right_batch_diagonal_negative_op_builder_test) +{ + EXPECT(test_invalid_input("ii...->i...", + "einsum op_builder: Parsing of equations with duplicated labels and " + "batch axes that are not the outer-most axes, is not implemented", + {3, 3, 3}, + {})); +} + +TEST_CASE(einsum_permute_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {2, 3}}); + auto op = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), a); + op = mm.add_instruction(migraphx::make_op("unsqueeze", {}), op); + op = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), op); + + EXPECT(mm == make_op_module("einsum", {{"equation", "ij->ji"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_summation_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {2, 3}}); + auto op = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), a); + op = mm.add_instruction(migraphx::make_op("unsqueeze", {}), op); + op = mm.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 1}}}), op); + mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 1}}}), op); + + EXPECT(mm == make_op_module("einsum", {{"equation", "ij->"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_column_sum_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {2, 3}}); + auto op = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), a); + op = mm.add_instruction(migraphx::make_op("unsqueeze", {}), op); + op = mm.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0}}}), op); + op = mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), op); + op = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0}}}), op); + + EXPECT(mm == make_op_module("einsum", {{"equation", "ij->j"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_row_sum_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {2, 3}}); + auto op = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), a); + op = mm.add_instruction(migraphx::make_op("unsqueeze", {}), op); + op = mm.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), op); + op = mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), op); + op = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0}}}), op); + + EXPECT(mm == make_op_module("einsum", {{"equation", "ij->i"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_matrix_vector_multiplication_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {2, 3}}); + auto v = mm.add_parameter("v", {migraphx::shape::float_type, {3}}); + auto tr1 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), a); + auto unsq1 = mm.add_instruction(migraphx::make_op("unsqueeze", {}), tr1); + auto tr2 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0}}}), v); + auto unsq2 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), tr2); + auto tr3 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), unsq1); + auto tr4 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), unsq2); + auto resh1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1, 3}}}), tr3); + auto resh2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1, 3}}}), tr4); + auto tr5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh2); + auto dot = mm.add_instruction(migraphx::make_op("dot"), resh1, tr5); + auto resh3 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1}}}), dot); + auto tr6 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), resh3); + auto sq = mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), tr6); + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0}}}), sq); + + EXPECT(mm == make_op_module("einsum", {{"equation", "ij,j->i"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_matrix_matrix_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {2, 3}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {2, 3}}); + auto tr1 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), a); + auto unsq1 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), tr1); + auto tr2 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), b); + auto unsq2 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), tr2); + auto tr3 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), unsq1); + auto tr4 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), unsq2); + auto resh1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1, 3}}}), tr3); + auto resh2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1, 3}}}), tr4); + auto tr5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh2); + auto dot = mm.add_instruction(migraphx::make_op("dot"), resh1, tr5); + auto resh3 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 1}}}), dot); + auto tr6 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh3); + auto sq = mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), tr6); + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), sq); + + EXPECT(mm == make_op_module("einsum", {{"equation", "ij,kj->ik"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_vector_dot_product_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {3}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {3}}); + auto tr1 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0}}}), a); + auto unsq1 = mm.add_instruction(migraphx::make_op("unsqueeze", {}), tr1); + auto tr2 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0}}}), b); + auto unsq2 = mm.add_instruction(migraphx::make_op("unsqueeze", {}), tr2); + auto tr3 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0}}}), unsq1); + auto tr4 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0}}}), unsq2); + auto resh1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1, 3}}}), tr3); + auto resh2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1, 3}}}), tr4); + auto tr5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh2); + auto dot = mm.add_instruction(migraphx::make_op("dot"), resh1, tr5); + auto resh3 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1}}}), dot); + auto tr6 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0}}}), resh3); + mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), tr6); + + EXPECT(mm == make_op_module("einsum", {{"equation", "i,i->"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_matrix_dot_product_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {2, 3}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {2, 3}}); + auto tr1 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), a); + auto unsq1 = mm.add_instruction(migraphx::make_op("unsqueeze", {}), tr1); + auto tr2 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), b); + auto unsq2 = mm.add_instruction(migraphx::make_op("unsqueeze", {}), tr2); + auto tr3 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), unsq1); + auto tr4 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), unsq2); + auto resh1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1, 6}}}), tr3); + auto resh2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1, 6}}}), tr4); + auto tr5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh2); + auto dot = mm.add_instruction(migraphx::make_op("dot"), resh1, tr5); + auto resh3 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 1}}}), dot); + auto tr6 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), resh3); + mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 1}}}), tr6); + + EXPECT(mm == make_op_module("einsum", {{"equation", "ij,ij->"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_hadamard_product_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {2, 3}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {2, 3}}); + auto tr1 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), a); + auto unsq1 = mm.add_instruction(migraphx::make_op("unsqueeze", {}), tr1); + auto tr2 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), b); + auto unsq2 = mm.add_instruction(migraphx::make_op("unsqueeze", {}), tr2); + auto tr3 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), unsq1); + auto tr4 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), unsq2); + auto resh1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {6, -1, 1}}}), tr3); + auto resh2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {6, -1, 1}}}), tr4); + auto tr5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh2); + auto dot = mm.add_instruction(migraphx::make_op("dot"), resh1, tr5); + auto resh3 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 3}}}), dot); + auto tr6 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), resh3); + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), tr6); + + EXPECT(mm == make_op_module("einsum", {{"equation", "ij,ij->ij"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_vector_outer_product_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {3}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {5}}); + auto tr1 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0}}}), a); + auto unsq1 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), tr1); + auto tr2 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0}}}), b); + auto unsq2 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), tr2); + auto tr3 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), unsq1); + auto tr4 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), unsq2); + auto resh1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1, 1}}}), tr3); + auto resh2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1, 1}}}), tr4); + auto tr5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh2); + auto dot = mm.add_instruction(migraphx::make_op("dot"), resh1, tr5); + auto resh3 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 5}}}), dot); + auto tr6 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), resh3); + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), tr6); + + EXPECT(mm == make_op_module("einsum", {{"equation", "i,j->ij"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_matrix_outer_product_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {2, 3}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {2, 5}}); + auto tr1 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), a); + auto unsq1 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2, 3}}}), tr1); + auto tr2 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), b); + auto unsq2 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), tr2); + auto tr3 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2, 3}}}), unsq1); + auto tr4 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2, 3}}}), unsq2); + auto resh1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1, 1}}}), tr3); + auto resh2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1, 1}}}), tr4); + auto tr5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh2); + auto dot = mm.add_instruction(migraphx::make_op("dot"), resh1, tr5); + auto resh3 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 3, 2, 5}}}), dot); + auto tr6 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2, 3}}}), resh3); + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2, 3}}}), tr6); + + EXPECT(mm == make_op_module("einsum", {{"equation", "ij,kl->ijkl"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_batch_matrix_multiplication_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {3, 2, 5}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {3, 5, 3}}); + auto tr1 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), a); + auto unsq1 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {3}}}), tr1); + auto tr2 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), b); + auto unsq2 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), tr2); + auto tr3 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), unsq1); + auto tr4 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), unsq2); + auto resh1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {3, -1, 5}}}), tr3); + auto resh2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {3, -1, 5}}}), tr4); + auto tr5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh2); + auto dot = mm.add_instruction(migraphx::make_op("dot"), resh1, tr5); + auto resh3 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 2, 3, 1}}}), dot); + auto tr6 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), resh3); + auto sq = mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), tr6); + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), sq); + + EXPECT(mm == make_op_module("einsum", {{"equation", "ijk,ikl->ijl"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_tensor_contraction_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {2, 3, 5, 7}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {1, 3, 3, 7, 5}}); + auto tr1 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2, 3}}}), a); + auto unsq1 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {4, 5, 6}}}), tr1); + auto tr2 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 4, 0, 1, 3}}}), b); + auto unsq2 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 3}}}), tr2); + auto tr3 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 5, 6, 1, 2}}}), unsq1); + auto tr4 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 5, 6, 1, 2}}}), unsq2); + auto resh1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1, 15}}}), tr3); + auto resh2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1, 15}}}), tr4); + auto tr5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh2); + auto dot = mm.add_instruction(migraphx::make_op("dot"), resh1, tr5); + auto resh3 = + mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 7, 1, 3, 7, 1, 1}}}), dot); + auto tr6 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 5, 6, 1, 2, 3, 4}}}), resh3); + auto sq = mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 2}}}), tr6); + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2, 3, 4}}}), sq); + + EXPECT(mm == + make_op_module("einsum", {{"equation", "pqrs,tuqvr->pstuv"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_matrix_diagonal_op_builder_test) +{ + migraphx::module mm; + auto indices_arg = mm.add_literal(migraphx::literal{ + migraphx::shape{migraphx::shape::int64_type, {3, 2}}, {0, 0, 1, 1, 2, 2}}); + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {3, 3}}); + auto op = + mm.add_instruction(migraphx::make_op("gathernd", {{"batch_dims", 0}}), a, indices_arg); + op = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0}}}), op); + op = mm.add_instruction(migraphx::make_op("unsqueeze", {}), op); + op = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0}}}), op); + + EXPECT(mm == make_op_module("einsum", {{"equation", "ii->i"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_batch_matrix_diagonal_op_builder_test) +{ + migraphx::module mm; + auto indices_arg = + mm.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {3, 3, 2}}, + {0, 0, 1, 1, 2, 2, 0, 0, 1, 1, 2, 2, 0, 0, 1, 1, 2, 2}}); + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {3, 3, 3}}); + auto op = + mm.add_instruction(migraphx::make_op("gathernd", {{"batch_dims", 1}}), a, indices_arg); + op = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), op); + op = mm.add_instruction(migraphx::make_op("unsqueeze", {}), op); + op = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), op); + + EXPECT(mm == make_op_module("einsum", {{"equation", "...ii->...i"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_3d_diagonal_op_builder_test) +{ + migraphx::module mm; + auto indices_arg = mm.add_literal(migraphx::literal{ + migraphx::shape{migraphx::shape::int64_type, {3, 3}}, {0, 0, 0, 1, 1, 1, 2, 2, 2}}); + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {3, 3, 3}}); + auto op = + mm.add_instruction(migraphx::make_op("gathernd", {{"batch_dims", 0}}), a, indices_arg); + op = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0}}}), op); + op = mm.add_instruction(migraphx::make_op("unsqueeze", {}), op); + op = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0}}}), op); + + EXPECT(mm == make_op_module("einsum", {{"equation", "iii->i"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_diag_vector_multiply_op_builder_test) +{ + migraphx::module mm; + + auto lit = mm.add_literal(migraphx::literal{ + migraphx::shape{migraphx::shape::int64_type, {3, 2}}, {0, 0, 1, 1, 2, 2}}); + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {3, 3}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {3}}); + auto gath = mm.add_instruction(migraphx::make_op("gathernd", {{"batch_dims", 0}}), a, lit); + auto tr1 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0}}}), gath); + auto unsq1 = mm.add_instruction(migraphx::make_op("unsqueeze", {}), tr1); + auto tr2 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0}}}), b); + auto unsq2 = mm.add_instruction(migraphx::make_op("unsqueeze", {}), tr2); + auto tr3 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0}}}), unsq1); + auto tr4 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0}}}), unsq2); + auto resh1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {3, -1, 1}}}), tr3); + auto resh2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {3, -1, 1}}}), tr4); + auto tr5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh2); + auto dot = mm.add_instruction(migraphx::make_op("dot"), resh1, tr5); + auto resh3 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {3}}}), dot); + auto tr6 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0}}}), resh3); + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0}}}), tr6); + + EXPECT(mm == make_op_module("einsum", {{"equation", "ii,i->i"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_matrix_trace_op_builder_test) +{ + migraphx::module mm; + auto indices_arg = mm.add_literal(migraphx::literal{ + migraphx::shape{migraphx::shape::int64_type, {3, 2}}, {0, 0, 1, 1, 2, 2}}); + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {3, 3}}); + auto op = + mm.add_instruction(migraphx::make_op("gathernd", {{"batch_dims", 0}}), a, indices_arg); + op = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0}}}), op); + op = mm.add_instruction(migraphx::make_op("unsqueeze", {}), op); + op = mm.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0}}}), op); + mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), op); + + EXPECT(mm == make_op_module("einsum", {{"equation", "ii->"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_matrix_trace_implicit_op_builder_test) +{ + migraphx::module mm; + auto indices_arg = mm.add_literal(migraphx::literal{ + migraphx::shape{migraphx::shape::int64_type, {3, 2}}, {0, 0, 1, 1, 2, 2}}); + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {3, 3}}); + auto op = + mm.add_instruction(migraphx::make_op("gathernd", {{"batch_dims", 0}}), a, indices_arg); + op = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0}}}), op); + op = mm.add_instruction(migraphx::make_op("unsqueeze", {}), op); + op = mm.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0}}}), op); + mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), op); + + EXPECT(mm == make_op_module("einsum", {{"equation", "ii"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_2d_3d_multiplication_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {3, 3}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {3, 4, 5}}); + auto tr1 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), a); + auto unsq1 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2, 3}}}), tr1); + auto tr2 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), b); + auto unsq2 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), tr2); + auto tr3 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), unsq1); + auto tr4 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), unsq2); + auto resh1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1, 3}}}), tr3); + auto resh2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1, 3}}}), tr4); + auto tr5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh2); + auto dot = mm.add_instruction(migraphx::make_op("dot"), resh1, tr5); + auto resh3 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 5, 1}}}), dot); + auto tr6 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), resh3); + auto sq = mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), tr6); + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), sq); + + EXPECT(mm == make_op_module("einsum", {{"equation", "ij,jkl"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_element_wise_multiplication_and_row_sum_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {3}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {3, 4}}); + auto tr1 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0}}}), a); + auto unsq1 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), tr1); + auto tr2 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), b); + auto unsq2 = mm.add_instruction(migraphx::make_op("unsqueeze", {}), tr2); + auto red_sum = mm.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), unsq2); + auto tr3 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), unsq1); + auto tr4 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), red_sum); + auto resh1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {3, -1, 1}}}), tr3); + auto resh2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {3, -1, 1}}}), tr4); + auto tr5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh2); + auto dot = mm.add_instruction(migraphx::make_op("dot"), resh1, tr5); + auto resh3 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 1}}}), dot); + auto tr6 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), resh3); + auto sq = mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), tr6); + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0}}}), sq); + + EXPECT(mm == make_op_module("einsum", {{"equation", "i,ij->i"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_broadcast_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {3, 1}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {2, 2}}); + auto tr1 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), a); + auto unsq1 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), tr1); + auto tr2 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), b); + auto unsq2 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), tr2); + auto tr3 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), unsq1); + auto tr4 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), unsq2); + auto mbrc = + mm.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 1, 2}}}), tr3); + auto resh1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1, 2}}}), mbrc); + auto resh2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1, 2}}}), tr4); + auto tr5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh2); + auto dot = mm.add_instruction(migraphx::make_op("dot"), resh1, tr5); + auto resh3 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 2, 1}}}), dot); + auto tr6 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh3); + auto sq = mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), tr6); + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), sq); + + EXPECT(mm == make_op_module("einsum", {{"equation", "ij, jk -> ik"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_3d_broadcast_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {1, 3, 1}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {2, 2, 4}}); + auto tr1 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), a); + auto unsq1 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), tr1); + auto tr2 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b); + auto unsq2 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), tr2); + auto tr3 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2, 3}}}), unsq1); + auto tr4 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2, 3}}}), unsq2); + auto mbrc = + mm.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 1, 2}}}), tr3); + auto resh1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1, 2}}}), mbrc); + auto resh2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1, 2}}}), tr4); + auto tr5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh2); + auto dot = mm.add_instruction(migraphx::make_op("dot"), resh1, tr5); + auto resh3 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 3, 4, 1}}}), dot); + auto tr6 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2, 3}}}), resh3); + auto sq = mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {3}}}), tr6); + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), sq); + + EXPECT(mm == make_op_module("einsum", {{"equation", "bik,bkj->bij"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_3d_opposite_broadcast_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {1, 3, 2}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {2, 1, 4}}); + + auto tr1 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), a); + auto unsq1 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), tr1); + auto tr2 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b); + auto unsq2 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), tr2); + auto tr3 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2, 3}}}), unsq1); + auto tr4 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2, 3}}}), unsq2); + auto mbrc1 = + mm.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 1, 2}}}), tr3); + auto mbrc2 = + mm.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 1, 4, 2}}}), tr4); + auto resh1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1, 2}}}), mbrc1); + auto resh2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1, 2}}}), mbrc2); + auto tr5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh2); + auto dot = mm.add_instruction(migraphx::make_op("dot"), resh1, tr5); + auto resh3 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 3, 4, 1}}}), dot); + auto tr6 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2, 3}}}), resh3); + auto sq = mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {3}}}), tr6); + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), sq); + + EXPECT(mm == make_op_module("einsum", {{"equation", "bik,bkj->bij"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_3_inputs_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {2, 2, 2}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {2, 2}}); + auto c = mm.add_parameter("c", {migraphx::shape::float_type, {2, 2, 2}}); + auto tr_1 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), a); + auto unsq_1 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {3, 4, 5}}}), tr_1); + auto reds_1 = mm.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0}}}), unsq_1); + auto tr_2 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), b); + auto unsq_2 = + mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1, 4, 5}}}), tr_2); + auto tr_3 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 4, 5, 1, 3}}}), reds_1); + auto tr_4 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 4, 5, 1, 3}}}), unsq_2); + auto resh_1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1, 1}}}), tr_3); + auto resh_2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1, 1}}}), tr_4); + auto tr_5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh_2); + auto dot_1 = mm.add_instruction(migraphx::make_op("dot"), resh_1, tr_5); + auto resh_3 = + mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 1, 1, 2, 2}}}), dot_1); + auto tr_6 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 4, 1, 5, 2, 3}}}), resh_3); + auto tr_7 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), c); + auto unsq_3 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1, 2}}}), tr_7); + auto reds_2 = mm.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {5}}}), unsq_3); + auto tr_8 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 5, 1, 2, 4, 3}}}), tr_6); + auto tr_9 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 5, 1, 2, 4, 3}}}), reds_2); + auto resh_4 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1, 2}}}), tr_8); + auto resh_5 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1, 2}}}), tr_9); + auto tr_10 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh_5); + auto dot_2 = mm.add_instruction(migraphx::make_op("dot"), resh_4, tr_10); + auto resh_6 = + mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 1, 2, 2, 2, 1}}}), dot_2); + auto tr_11 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 5, 4, 1}}}), resh_6); + auto sq_1 = mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 3, 5}}}), tr_11); + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}), sq_1); + + EXPECT(mm == make_op_module("einsum", {{"equation", "bac,cd,def->ebc"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_bilinear_transformation_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {2, 3}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {5, 3, 7}}); + auto c = mm.add_parameter("c", {migraphx::shape::float_type, {2, 7}}); + auto tr_1 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), a); + auto unsq_1 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 3}}}), tr_1); + auto tr_2 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), b); + auto unsq_2 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), tr_2); + auto tr_3 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), unsq_1); + auto tr_4 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), unsq_2); + auto resh_1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1, 3}}}), tr_3); + auto resh_2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1, 3}}}), tr_4); + auto tr_5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh_2); + auto dot_1 = mm.add_instruction(migraphx::make_op("dot"), resh_1, tr_5); + auto resh_3 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 5, 7, 1}}}), dot_1); + auto tr_6 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), resh_3); + auto tr_7 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), c); + auto unsq_3 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), tr_7); + auto tr_8 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), tr_6); + auto tr_9 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), unsq_3); + auto resh_4 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1, 7}}}), tr_8); + auto resh_5 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1, 7}}}), tr_9); + auto tr_10 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh_5); + auto dot_2 = mm.add_instruction(migraphx::make_op("dot"), resh_4, tr_10); + auto resh_6 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 5, 1}}}), dot_2); + auto tr_11 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), resh_6); + auto sq_1 = mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {2, 3}}}), tr_11); + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), sq_1); + + EXPECT(mm == make_op_module("einsum", {{"equation", "ik,jkl,il->ij"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_ellipsis_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {2, 3, 2}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {2, 4, 2}}); + auto tr1 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), a); + auto unsq1 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), tr1); + auto tr2 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), b); + auto unsq2 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), tr2); + auto tr3 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {3, 0, 1, 2}}}), unsq1); + auto tr4 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {3, 0, 1, 2}}}), unsq2); + auto resh1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1, 2}}}), tr3); + auto resh2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1, 2}}}), tr4); + auto tr5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh2); + auto dot = mm.add_instruction(migraphx::make_op("dot"), resh1, tr5); + auto resh3 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 3, 4, 1}}}), dot); + auto tr6 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 3, 0}}}), resh3); + auto sq = mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), tr6); + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), sq); + + EXPECT(mm == + make_op_module("einsum", {{"equation", "...ik,kj...->ij..."}}, mm.get_parameters())); +} + +TEST_CASE(einsum_ellipsis_multidim_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {3, 2, 3, 2}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {2, 4, 3, 2}}); + auto tr1 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 3, 0, 1}}}), a); + auto unsq1 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), tr1); + auto tr2 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2, 3}}}), b); + auto unsq2 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), tr2); + auto tr3 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {3, 4, 0, 1, 2}}}), unsq1); + auto tr4 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {3, 4, 0, 1, 2}}}), unsq2); + auto resh1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {6, -1, 2}}}), tr3); + auto resh2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {6, -1, 2}}}), tr4); + auto tr5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh2); + auto dot = mm.add_instruction(migraphx::make_op("dot"), resh1, tr5); + auto resh3 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 2, 3, 4, 1}}}), dot); + auto tr6 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {2, 3, 4, 0, 1}}}), resh3); + auto sq = mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), tr6); + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2, 3}}}), sq); + + EXPECT(mm == + make_op_module("einsum", {{"equation", "...ik,kj...->ij..."}}, mm.get_parameters())); +} + +TEST_CASE(einsum_ellipsis_zero_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {2, 3, 2}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {4, 3, 2}}); + auto tr1 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 1, 0}}}), a); + auto unsq1 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), tr1); + auto tr2 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 1, 0}}}), b); + auto unsq2 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {3}}}), tr2); + auto tr3 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 3, 2, 0}}}), unsq1); + auto tr4 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 3, 2, 0}}}), unsq2); + auto resh1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {3, -1, 2}}}), tr3); + auto resh2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {3, -1, 2}}}), tr4); + auto tr5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh2); + auto dot = mm.add_instruction(migraphx::make_op("dot"), resh1, tr5); + auto resh3 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 2, 4, 1}}}), dot); + auto tr6 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {3, 0, 2, 1}}}), resh3); + auto sq = mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), tr6); + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), sq); + + EXPECT(mm == + make_op_module("einsum", {{"equation", "...qhd,...khd->...hqk"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_ellipsis_implicit_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {3, 2, 3, 2}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {3, 4, 3, 2}}); + auto tr1 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {3, 2, 1, 0}}}), a); + auto unsq1 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), tr1); + auto tr2 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {3, 2, 1, 0}}}), b); + auto unsq2 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {3}}}), tr2); + auto tr3 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {4, 3, 2, 0, 1}}}), unsq1); + auto tr4 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {4, 3, 2, 0, 1}}}), unsq2); + auto resh1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {3, -1, 6}}}), tr3); + auto resh2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {3, -1, 6}}}), tr4); + auto tr5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh2); + auto dot = mm.add_instruction(migraphx::make_op("dot"), resh1, tr5); + auto resh3 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 2, 4, 1, 1}}}), dot); + auto tr6 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {3, 4, 2, 1, 0}}}), resh3); + auto sq = mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 1}}}), tr6); + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}), sq); + + EXPECT(mm == make_op_module("einsum", {{"equation", "...qhd,...khd"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_ellipsis_scalar_multiplication_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {2, 3}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {2, 3}}); + auto tr1 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), a); + auto unsq1 = mm.add_instruction(migraphx::make_op("unsqueeze", {}), tr1); + auto tr2 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), b); + auto unsq2 = mm.add_instruction(migraphx::make_op("unsqueeze", {}), tr2); + auto tr3 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), unsq1); + auto tr4 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), unsq2); + auto resh1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {6, -1, 1}}}), tr3); + auto resh2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {6, -1, 1}}}), tr4); + auto tr5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh2); + auto dot = mm.add_instruction(migraphx::make_op("dot"), resh1, tr5); + auto resh3 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 3}}}), dot); + auto tr6 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), resh3); + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1}}}), tr6); + + EXPECT(mm == make_op_module("einsum", {{"equation", "..., ..."}}, mm.get_parameters())); +} + +TEST_CASE(einsum_common_1_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {2, 2, 2, 2}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {2, 2, 2, 2}}); + auto tr1 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 2, 1}}}), a); + auto unsq1 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {4}}}), tr1); + auto tr2 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 2, 1}}}), b); + auto unsq2 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {3}}}), tr2); + auto tr3 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 4, 1}}}), unsq1); + auto tr4 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 4, 1}}}), unsq2); + auto resh1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1, 2}}}), tr3); + auto resh2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1, 2}}}), tr4); + auto tr5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh2); + auto dot = mm.add_instruction(migraphx::make_op("dot"), resh1, tr5); + auto resh3 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 1}}}), dot); + auto tr6 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 4, 1, 2, 3}}}), resh3); + auto sq = mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), tr6); + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), sq); + + EXPECT(mm == make_op_module("einsum", {{"equation", "bsnh,btnh->bnts"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_common_2_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {2, 2, 2, 2}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {2, 2, 2, 2}}); + auto tr1 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 2, 1}}}), a); + auto unsq1 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 5}}}), tr1); + auto red_sum1 = mm.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0}}}), unsq1); + auto tr2 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 2, 1}}}), b); + auto unsq2 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 4}}}), tr2); + auto red_sum2 = mm.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), unsq2); + auto tr3 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 4, 5, 2}}}), red_sum1); + auto tr4 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 4, 5, 2}}}), red_sum2); + auto resh1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1, 2}}}), tr3); + auto resh2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1, 2}}}), tr4); + auto tr5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh2); + auto dot = mm.add_instruction(migraphx::make_op("dot"), resh1, tr5); + auto resh3 = + mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 1, 2, 2, 2, 1}}}), dot); + auto tr6 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 5, 2, 3, 4}}}), resh3); + auto sq = mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 1, 2}}}), tr6); + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), sq); + + EXPECT(mm == make_op_module("einsum", {{"equation", "bsnh,ctnh->nts"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_common_3_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {2, 2, 2, 2}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {2, 2, 2, 2}}); + auto tr1 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2, 3}}}), a); + auto unsq1 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), tr1); + auto red_sum1 = mm.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0}}}), unsq1); + auto tr2 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2, 3}}}), b); + auto unsq2 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 3}}}), tr2); + auto red_sum2 = mm.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), unsq2); + auto tr3 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 4, 3, 2, 5}}}), red_sum1); + auto tr4 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 4, 3, 2, 5}}}), red_sum2); + auto resh1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1, 2}}}), tr3); + auto resh2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1, 2}}}), tr4); + auto tr5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh2); + auto dot = mm.add_instruction(migraphx::make_op("dot"), resh1, tr5); + auto resh3 = + mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 1, 2, 2, 2, 1}}}), dot); + auto tr6 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 4, 3, 2, 5}}}), resh3); + auto sq = mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 1, 5}}}), tr6); + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}), sq); + + EXPECT(mm == make_op_module("einsum", {{"equation", "bnst,chst->shn"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_common_4_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {2, 2, 3, 2}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {2, 2, 4, 2}}); + auto tr1 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), a); + auto unsq1 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {4}}}), tr1); + auto tr2 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); + auto unsq2 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {3}}}), tr2); + auto tr3 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 4, 2}}}), unsq1); + auto tr4 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 4, 2}}}), unsq2); + auto resh1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1, 2}}}), tr3); + auto resh2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1, 2}}}), tr4); + auto tr5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh2); + auto dot = mm.add_instruction(migraphx::make_op("dot"), resh1, tr5); + auto resh3 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 3, 4, 1}}}), dot); + auto tr6 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 4, 2, 3}}}), resh3); + auto sq = mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), tr6); + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2, 3}}}), sq); + + EXPECT(mm == make_op_module("einsum", {{"equation", "bcxd,bcyd->bcxy"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_common_5_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {3, 2, 3, 2}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {3, 4, 3, 2}}); + auto tr1 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {3, 2, 1, 0}}}), a); + auto unsq1 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), tr1); + auto tr2 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {3, 2, 1, 0}}}), b); + auto unsq2 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {3}}}), tr2); + auto tr3 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {1, 4, 3, 2, 0}}}), unsq1); + auto tr4 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {1, 4, 3, 2, 0}}}), unsq2); + auto resh1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {9, -1, 2}}}), tr3); + auto resh2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {9, -1, 2}}}), tr4); + auto tr5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh2); + auto dot = mm.add_instruction(migraphx::make_op("dot"), resh1, tr5); + auto resh3 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 3, 2, 4, 1}}}), dot); + auto tr6 = mm.add_instruction( + migraphx::make_op("transpose", {{"permutation", {4, 0, 3, 2, 1}}}), resh3); + auto sq = mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), tr6); + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {3, 0, 2, 1}}}), sq); + + EXPECT(mm == + make_op_module("einsum", {{"equation", "...qhd,...khd->...hqk"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_common_6_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {3, 2, 2}}); + auto b = mm.add_parameter("b", {migraphx::shape::float_type, {2, 2, 3}}); + auto tr1 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), a); + auto unsq1 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), tr1); + auto tr2 = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}), b); + auto unsq2 = mm.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), tr2); + auto tr3 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {3, 0, 1, 2}}}), unsq1); + auto tr4 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {3, 0, 1, 2}}}), unsq2); + auto resh1 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1, 2}}}), tr3); + auto resh2 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1, 2}}}), tr4); + auto tr5 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), resh2); + auto dot = mm.add_instruction(migraphx::make_op("dot"), resh1, tr5); + auto resh3 = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 3, 3, 1}}}), dot); + auto tr6 = + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 3, 0}}}), resh3); + auto sq = mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), tr6); + mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), sq); + + EXPECT(mm == + make_op_module("einsum", {{"equation", "i...k,k...j->i...j"}}, mm.get_parameters())); +} + +TEST_CASE(einsum_common_7_op_builder_test) +{ + migraphx::module mm; + auto a = mm.add_parameter("a", {migraphx::shape::float_type, {5, 5}}); + auto op = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), a); + op = mm.add_instruction(migraphx::make_op("unsqueeze", {}), op); + op = mm.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0}}}), op); + op = mm.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), op); + op = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {0}}}), op); + + EXPECT(mm == make_op_module("einsum", {{"equation", "...j->..."}}, mm.get_parameters())); +} diff --git a/test/op/builder/gelu_test.cpp b/test/op/builder/gelu_test.cpp new file mode 100644 index 00000000000..0cf51e71c74 --- /dev/null +++ b/test/op/builder/gelu_test.cpp @@ -0,0 +1,165 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include + +TEST_CASE(gelu_quick_happy_path_op_builder_test) +{ + migraphx::module mm; + const float alpha_val = 0.5f; + + auto x = mm.add_parameter("x", {migraphx::shape::float_type, {3, 3}}); + auto alpha = mm.add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {alpha_val}}); + auto mul_alpha = add_common_op(mm, migraphx::make_op("mul"), {alpha, x}); + auto sigmoid = mm.add_instruction(migraphx::make_op("sigmoid"), {mul_alpha}); + add_common_op(mm, migraphx::make_op("mul"), {x, sigmoid}); + + EXPECT(mm == make_op_module("gelu_quick", {{"alpha", alpha_val}}, mm.get_parameters())); +} + +TEST_CASE(gelu_erf_happy_path_op_builder_test) +{ + migraphx::module mm; + + auto x = mm.add_parameter("x", {migraphx::shape::float_type, {3, 3}}); + auto half = mm.add_literal({migraphx::shape::float_type, {0.5f}}); + auto one = mm.add_literal({migraphx::shape::float_type, {1.0f}}); + auto sqrt2 = mm.add_literal({migraphx::shape::float_type, {static_cast(M_SQRT2)}}); + auto mul_half = add_common_op(mm, migraphx::make_op("mul"), {x, half}); + auto div = add_common_op(mm, migraphx::make_op("div"), {x, sqrt2}); + auto erf = mm.add_instruction(migraphx::make_op("erf"), div); + auto add_one = add_common_op(mm, migraphx::make_op("add"), {erf, one}); + add_common_op(mm, migraphx::make_op("mul"), {mul_half, add_one}); + + EXPECT(mm == make_op_module("gelu_erf", {}, mm.get_parameters())); +} + +TEST_CASE(gelu_tanh_fast_happy_path_op_builder_test) +{ + migraphx::module mm; + + auto x = mm.add_parameter("x", {migraphx::shape::float_type, {3, 3}}); + + const auto fit_const_val = 0.035677; + auto fit_const = mm.add_literal({migraphx::shape::float_type, {fit_const_val}}); + const auto sqrt_2_rpi_val = 0.797885; + auto sqrt_2_rpi = mm.add_literal({migraphx::shape::float_type, {sqrt_2_rpi_val}}); + auto one = mm.add_literal({migraphx::shape::float_type, {1.0f}}); + auto half = mm.add_literal({migraphx::shape::float_type, {0.5f}}); + auto three = mm.add_literal({migraphx::shape::float_type, {3.0f}}); + + // [0.044715|0.035677] * x^3 + auto pow0 = add_common_op(mm, migraphx::make_op("pow"), {x, three}); + auto mul0 = add_common_op(mm, migraphx::make_op("mul"), {pow0, fit_const}); + migraphx::instruction_ref tanh_in; + + // approx = 0.797885 * x + 0.035677 * x^3 + auto mul1 = add_common_op(mm, migraphx::make_op("mul"), {sqrt_2_rpi, x}); + tanh_in = add_common_op(mm, migraphx::make_op("add"), {mul0, mul1}); + + // 0.5 * x * (1 + Tanh(approx)) + auto tanh0 = add_common_op(mm, migraphx::make_op("tanh"), {tanh_in}); + auto add1 = add_common_op(mm, migraphx::make_op("add"), {tanh0, one}); + auto mul2 = add_common_op(mm, migraphx::make_op("mul"), {x, half}); + add_common_op(mm, migraphx::make_op("mul"), {add1, mul2}); + + EXPECT(mm == make_op_module("gelu_tanh", {{"fast", true}}, mm.get_parameters())); +} + +TEST_CASE(gelu_tanh_not_fast_happy_path_op_builder_test) +{ + migraphx::module mm; + + auto x = mm.add_parameter("x", {migraphx::shape::float_type, {3, 3}}); + + const auto fit_const_val = 0.044715; + auto fit_const = mm.add_literal({migraphx::shape::float_type, {fit_const_val}}); + const auto sqrt_2_rpi_val = sqrt(M_2_PI); + auto sqrt_2_rpi = mm.add_literal({migraphx::shape::float_type, {sqrt_2_rpi_val}}); + auto one = mm.add_literal({migraphx::shape::float_type, {1.0f}}); + auto half = mm.add_literal({migraphx::shape::float_type, {0.5f}}); + auto three = mm.add_literal({migraphx::shape::float_type, {3.0f}}); + + // [0.044715|0.035677] * x^3 + auto pow0 = add_common_op(mm, migraphx::make_op("pow"), {x, three}); + auto mul0 = add_common_op(mm, migraphx::make_op("mul"), {pow0, fit_const}); + migraphx::instruction_ref tanh_in; + + // approx = sqrt(2/pi) * (x + 0.044715 * x^3 + auto add0 = add_common_op(mm, migraphx::make_op("add"), {mul0, x}); + tanh_in = add_common_op(mm, migraphx::make_op("mul"), {add0, sqrt_2_rpi}); + + // 0.5 * x * (1 + Tanh(approx)) + auto tanh0 = add_common_op(mm, migraphx::make_op("tanh"), {tanh_in}); + auto add1 = add_common_op(mm, migraphx::make_op("add"), {tanh0, one}); + auto mul2 = add_common_op(mm, migraphx::make_op("mul"), {x, half}); + add_common_op(mm, migraphx::make_op("mul"), {add1, mul2}); + + EXPECT(mm == make_op_module("gelu_tanh", {{"fast", false}}, mm.get_parameters())); +} + +TEST_CASE(gelu_split_happy_path_op_builder_path) +{ + migraphx::module mm; + + auto x = mm.add_parameter("x", {migraphx::shape::float_type, {2, 4, 6}}); + const size_t last_dim_size = x->get_shape().lens().back(); + auto split_left = mm.add_instruction( + migraphx::make_op("slice", + {{"axes", {-1}}, {"starts", {0}}, {"ends", {last_dim_size / 2}}}), + x); + auto split_right = mm.add_instruction( + migraphx::make_op( + "slice", {{"axes", {-1}}, {"starts", {last_dim_size / 2}}, {"ends", {last_dim_size}}}), + x); + + // building up gelu_erf + migraphx::instruction_ref gelu_erf; + { + auto x2 = split_right; + auto half = mm.add_literal({migraphx::shape::float_type, {0.5f}}); + auto one = mm.add_literal({migraphx::shape::float_type, {1.0f}}); + auto sqrt2 = mm.add_literal({migraphx::shape::float_type, {static_cast(M_SQRT2)}}); + auto mul_half = add_common_op(mm, migraphx::make_op("mul"), {x2, half}); + auto div = add_common_op(mm, migraphx::make_op("div"), {x2, sqrt2}); + auto erf = mm.add_instruction(migraphx::make_op("erf"), div); + auto add_one = add_common_op(mm, migraphx::make_op("add"), {erf, one}); + gelu_erf = add_common_op(mm, migraphx::make_op("mul"), {mul_half, add_one}); + } + + add_common_op(mm, migraphx::make_op("mul"), {split_left, gelu_erf}); + + EXPECT(mm == make_op_module("gelu_split", {}, mm.get_parameters())); +} + +TEST_CASE(gelu_split_invalid_dimension_op_builder_path) +{ + migraphx::module mm; + mm.add_parameter("x", {migraphx::shape::float_type, {3, 3}}); + EXPECT(test::throws( + [&] { make_op_module("gelu_split", {}, mm.get_parameters()); }, + "gelu_split op_builder: BiasSplitGelu must have even last dimension which is >= 2")); +} diff --git a/test/op/builder/gemm_test.cpp b/test/op/builder/gemm_test.cpp new file mode 100644 index 00000000000..1ad24a073f8 --- /dev/null +++ b/test/op/builder/gemm_test.cpp @@ -0,0 +1,177 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include + +TEST_CASE(gemm_invalid_input_dim_op_builder_test) +{ + migraphx::module mm; + mm.add_parameter("a", {migraphx::shape::float_type, {3}}); + mm.add_parameter("b", {migraphx::shape::float_type, {3, 3, 3}}); + + EXPECT(test::throws( + [&] { make_op_module("gemm", {}, mm.get_parameters()); }, + "gemm op_builder: A and B should be rank 2, A is rank 1, B is rank 3")); +} + +TEST_CASE(gemm_normal_path_op_builder_test) +{ + migraphx::module mm; + auto a_arg = mm.add_parameter("a", {migraphx::shape::float_type, {3, 3}}); + auto b_arg = mm.add_parameter("b", {migraphx::shape::float_type, {3, 3}}); + + a_arg = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), a_arg); + b_arg = mm.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), b_arg); + mm.add_instruction(migraphx::make_op("dot"), a_arg, b_arg); + + EXPECT(mm == make_op_module("gemm", + {{"alpha", 1.0f}, {"transA", true}, {"transB", true}}, + mm.get_parameters())); +} + +TEST_CASE(gemm_alpha_not_one_op_builder_test) +{ + migraphx::module mm; + auto a_arg = mm.add_parameter("a", {migraphx::shape::float_type, {3, 3}}); + auto b_arg = mm.add_parameter("b", {migraphx::shape::float_type, {3, 3}}); + + const float alpha = 1.1f; + + auto alpha_literal = mm.add_literal(alpha); + a_arg = add_common_op(mm, migraphx::make_op("mul"), {alpha_literal, a_arg}); + + mm.add_instruction(migraphx::make_op("dot"), a_arg, b_arg); + + EXPECT(mm == make_op_module("gemm", + {{"alpha", alpha}, {"transA", false}, {"transB", false}}, + mm.get_parameters())); +} + +TEST_CASE(gemm_alpha_not_one_type_mismatch_op_builder_test) +{ + migraphx::module mm; + auto a_arg = mm.add_parameter("a", {migraphx::shape::fp8e4m3fnuz_type, {3, 3}}); + auto b_arg = mm.add_parameter("b", {migraphx::shape::fp8e4m3fnuz_type, {3, 3}}); + + const float alpha = 1.1f; + const auto dot_type = a_arg->get_shape().type(); + + auto alpha_literal = mm.add_literal(alpha); + a_arg = add_common_op(mm, migraphx::make_op("mul"), {alpha_literal, a_arg}); + a_arg = mm.add_instruction(migraphx::make_op("convert", {{"target_type", dot_type}}), a_arg); + + mm.add_instruction(migraphx::make_op("dot"), a_arg, b_arg); + + EXPECT(mm == make_op_module("gemm", + {{"alpha", alpha}, {"transA", false}, {"transB", false}}, + mm.get_parameters())); +} + +TEST_CASE(gemm_3_params_not_dynamic_op_builder_test) +{ + migraphx::module mm; + + auto a_arg = mm.add_parameter("a", {migraphx::shape::float_type, {1, 3}}); + auto b_arg = mm.add_parameter("b", {migraphx::shape::float_type, {3, 4}}); + auto c_arg = mm.add_parameter("c", {migraphx::shape::float_type, {1, 1}}); + + auto dot_ins = mm.add_instruction(migraphx::make_op("dot"), a_arg, b_arg); + + c_arg = mm.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), c_arg); + mm.add_instruction(migraphx::make_op("add"), dot_ins, c_arg); + + EXPECT(mm == + make_op_module("gemm", + {{"alpha", 1.0f}, {"transA", false}, {"transB", false}, {"beta", 1.0f}}, + mm.get_parameters())); +} + +TEST_CASE(gemm_3_params_dynamic_op_builder_test) +{ + migraphx::module mm; + + migraphx::shape::dynamic_dimension dd{1, 4}; + std::vector dyn_dims{dd, dd}; + + auto a_arg = mm.add_parameter("a", {migraphx::shape::float_type, dyn_dims}); + auto b_arg = mm.add_parameter("b", {migraphx::shape::float_type, dyn_dims}); + auto c_arg = mm.add_parameter("c", {migraphx::shape::float_type, {1, 1}}); + + auto dot_ins = mm.add_instruction(migraphx::make_op("dot"), a_arg, b_arg); + + c_arg = mm.add_instruction(migraphx::make_op("multibroadcast"), c_arg, dot_ins); + mm.add_instruction(migraphx::make_op("add"), dot_ins, c_arg); + + EXPECT(mm == + make_op_module("gemm", + {{"alpha", 1.0f}, {"transA", false}, {"transB", false}, {"beta", 1.0f}}, + mm.get_parameters())); +} + +TEST_CASE(gemm_3_params_beta_not_one_op_builder_test) +{ + migraphx::module mm; + + const float beta = 1.1f; + + auto a_arg = mm.add_parameter("a", {migraphx::shape::float_type, {1, 3}}); + auto b_arg = mm.add_parameter("b", {migraphx::shape::float_type, {3, 4}}); + auto c_arg = mm.add_parameter("c", {migraphx::shape::float_type, {1, 4}}); + + auto dot_ins = mm.add_instruction(migraphx::make_op("dot"), a_arg, b_arg); + + auto beta_literal = mm.add_literal(beta); + c_arg = add_common_op(mm, migraphx::make_op("mul"), {c_arg, beta_literal}); + mm.add_instruction(migraphx::make_op("add"), dot_ins, c_arg); + + EXPECT(mm == + make_op_module("gemm", + {{"alpha", 1.0f}, {"transA", false}, {"transB", false}, {"beta", beta}}, + mm.get_parameters())); +} + +TEST_CASE(gemm_3_params_beta_not_one_type_mismatch_op_builder_test) +{ + migraphx::module mm; + + const float beta = 0.8f; + + auto a_arg = mm.add_parameter("a", {migraphx::shape::bf16_type, {1, 3}}); + auto b_arg = mm.add_parameter("b", {migraphx::shape::bf16_type, {3, 4}}); + auto c_arg = mm.add_parameter("c", {migraphx::shape::bf16_type, {1, 4}}); + + auto dot_ins = mm.add_instruction(migraphx::make_op("dot"), a_arg, b_arg); + + auto beta_literal = mm.add_literal(beta); + c_arg = add_common_op(mm, migraphx::make_op("mul"), {c_arg, beta_literal}); + c_arg = mm.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::bf16_type}}), c_arg); + mm.add_instruction(migraphx::make_op("add"), dot_ins, c_arg); + + EXPECT(mm == + make_op_module("gemm", + {{"alpha", 1.0f}, {"transA", false}, {"transB", false}, {"beta", beta}}, + mm.get_parameters())); +} diff --git a/test/op/builder/main.cpp b/test/op/builder/main.cpp new file mode 100644 index 00000000000..1c2129e3161 --- /dev/null +++ b/test/op/builder/main.cpp @@ -0,0 +1,26 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include "test.hpp" + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/op/builder/mean_variance_normalization_test.cpp b/test/op/builder/mean_variance_normalization_test.cpp new file mode 100644 index 00000000000..89b9348d6de --- /dev/null +++ b/test/op/builder/mean_variance_normalization_test.cpp @@ -0,0 +1,58 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +TEST_CASE(mean_variance_normalization_invalid_input_dim_op_builder_test) +{ + migraphx::module mm; + mm.add_parameter("x", {migraphx::shape::float_type, {3}}); + + EXPECT(test::throws( + [&] { make_op_module("mean_variance_normalization", {}, mm.get_parameters()); }, + "mvn op_builder: Length of axes attribute needs to be equal to input tensor rank - 1")); +} + +TEST_CASE(mean_variance_normalization_happy_path_op_builder_test) +{ + migraphx::module mm; + + const auto axes = {2, 2, 2}; + auto x = mm.add_parameter("x", {migraphx::shape::float_type, {2, 2, 2, 2}}); + + auto x_mean = mm.add_instruction(migraphx::make_op("reduce_mean", {{"axes", axes}}), x); + auto x_mean_squared = add_common_op(mm, migraphx::make_op("mul"), {x_mean, x_mean}); + auto x_squared = add_common_op(mm, migraphx::make_op("mul"), {x, x}); + auto x_squared_mean = + mm.add_instruction(migraphx::make_op("reduce_mean", {{"axes", axes}}), x_squared); + auto mean_sub = add_common_op(mm, migraphx::make_op("sub"), {x_squared_mean, x_mean_squared}); + auto std = add_common_op(mm, migraphx::make_op("sqrt"), {mean_sub}); + auto dividend = add_common_op(mm, migraphx::make_op("sub"), {x, x_mean}); + auto epsilon = mm.add_literal(1e-9f); + auto divisor = add_common_op(mm, migraphx::make_op("add"), {std, epsilon}); + add_common_op(mm, migraphx::make_op("div"), {dividend, divisor}); + + EXPECT(mm == + make_op_module("mean_variance_normalization", {{"axes", axes}}, mm.get_parameters())); +} diff --git a/test/op/include/op_builder_test_utils.hpp b/test/op/include/op_builder_test_utils.hpp new file mode 100644 index 00000000000..879a8a1b7b4 --- /dev/null +++ b/test/op/include/op_builder_test_utils.hpp @@ -0,0 +1,49 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#ifndef MIGRAPHX_GUARD_TEST_OPBUILDER_TEST_UTILS_HPP +#define MIGRAPHX_GUARD_TEST_OPBUILDER_TEST_UTILS_HPP + +#include +#include +#include + +inline migraphx::module make_op_module(const std::string& op_builder_name, + const migraphx::value& options, + const std::vector& params) +{ + migraphx::module mm_op_built; + + const std::vector& args{params.rbegin(), params.rend()}; + mm_op_built.add_instructions(args); + + const auto& params2 = mm_op_built.get_parameters(); + const std::vector& args2{params2.rbegin(), params2.rend()}; + + migraphx::op::builder::add(op_builder_name, mm_op_built, args2, options); + + return mm_op_built; +} + +#endif