diff --git a/src/include/migraphx/op/concat.hpp b/src/include/migraphx/op/concat.hpp
index 527ef55bc1c..5337d7fda3a 100644
--- a/src/include/migraphx/op/concat.hpp
+++ b/src/include/migraphx/op/concat.hpp
@@ -1,7 +1,7 @@
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -35,6 +35,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 
@@ -109,38 +110,53 @@ struct concat
             new_lens[axis] = new_dim_axis;
             return shape::from_permutation(type, new_lens, find_permutation(inputs));
         }
-        else if(std::all_of(
-                    inputs.begin(), inputs.end(), [&](const shape& s) { return s.dynamic(); }))
+
+        // Check if we have mixed static and dynamic shapes
+        bool has_static = std::any_of(
+            inputs.begin(), inputs.end(), [](const shape& s) { return not s.dynamic(); });
+        bool has_dynamic =
+            std::any_of(inputs.begin(), inputs.end(), [](const shape& s) { return s.dynamic(); });
+
+        // Convert all static shapes to dynamic shapes
+        if(has_static and has_dynamic)
         {
-            // Dynamic input shapes
-            for(std::size_t index = 0; index < inputs[0].ndim(); index++)
+            for(auto& input : inputs)
             {
-                if(index != axis)
-                {
-                    if(not std::all_of(inputs.begin(), inputs.end(), [&](const shape& s) {
-                           return s.dyn_dims()[index] == inputs[0].dyn_dims()[index];
-                       }))
-                        MIGRAPHX_THROW("CONCAT: all input dimensions should match in axis " +
-                                       std::to_string(index));
-                }
-            }
-            std::size_t new_min = 0;
-            std::size_t new_max = 0;
-            for(const auto& input : inputs)
-            {
-                auto ddim = input.dyn_dims()[axis];
-                new_min += ddim.min;
-                new_max += ddim.max;
+                if(input.dynamic())
+                    continue;
+                input = input.to_dynamic();
             }
+        }
+
+        // Calculate the dynamic input shapes for all axes
+        auto common_dyn_dims = compute_common_dyn_dims(inputs);
 
-            auto new_dims = inputs[0].dyn_dims();
-            new_dims[axis] = migraphx::shape::dynamic_dimension{new_min, new_max};
-            return {inputs[0].type(), new_dims};
+        // Update the dynamic dimensions for the concat axis
+        std::size_t new_min = 0;
+        std::size_t new_max = 0;
+        for(const auto& input : inputs)
+        {
+            auto ddim = input.dyn_dims()[axis];
+            new_min += ddim.min;
+            new_max += ddim.max;
         }
-        else
+
+        common_dyn_dims[axis] = migraphx::shape::dynamic_dimension{new_min, new_max};
+
+        // Check if all dimensions can be made static
+        if(std::all_of(common_dyn_dims.begin(), common_dyn_dims.end(), [&](auto const& ddim) {
+               return ddim.is_fixed();
+           }))
         {
-            MIGRAPHX_THROW("CONCAT: Cannot mix static and dynamic input shapes.");
+            // Return as static
+            std::vector<std::size_t> new_static_dims;
+            std::transform(common_dyn_dims.begin(),
+                           common_dyn_dims.end(),
+                           std::back_inserter(new_static_dims),
+                           [&](auto const& ddim) { return ddim.max; });
+            return {inputs.at(0).type(), new_static_dims};
        }
+        return {inputs[0].type(), common_dyn_dims};
     }
 
     argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
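The reworked compute_shape above accepts a mix of static and dynamic inputs: static shapes are promoted to dynamic, the per-dimension ranges are merged (via the compute_common_dyn_dims helper), the {min, max} range of every input is summed along the concat axis, and the result collapses back to a static shape when every dimension ends up fixed. Below is a minimal standalone sketch of that summing rule only; the dyn_dim struct and concat_dyn_dims function are illustrative stand-ins, not the real migraphx::shape::dynamic_dimension API.

#include <cstddef>
#include <iostream>
#include <vector>

// Simplified stand-in for migraphx::shape::dynamic_dimension (illustration only).
struct dyn_dim
{
    std::size_t min;
    std::size_t max;
    bool is_fixed() const { return min == max; }
};

// Merge rule sketched from the diff above: sum the {min, max} range of each input
// along the concat axis; non-axis dimensions keep the (assumed common) range.
std::vector<dyn_dim> concat_dyn_dims(const std::vector<std::vector<dyn_dim>>& inputs,
                                     std::size_t axis)
{
    std::vector<dyn_dim> result = inputs.front();
    dyn_dim summed{0, 0};
    for(const auto& dims : inputs)
    {
        summed.min += dims[axis].min;
        summed.max += dims[axis].max;
    }
    result[axis] = summed;
    return result;
}

int main()
{
    // One static input {2, 3} (fixed ranges) and one dynamic input {{1, 4}, {3, 3}}.
    std::vector<std::vector<dyn_dim>> inputs = {{{2, 2}, {3, 3}}, {{1, 4}, {3, 3}}};
    auto merged = concat_dyn_dims(inputs, 0);
    for(const auto& d : merged) // prints {3, 6} for axis 0 and {3, 3} for axis 1
        std::cout << "{" << d.min << ", " << d.max << "} fixed: " << d.is_fixed() << "\n";
    return 0;
}

Concatenating a static {2, 3} input with a dynamic {{1, 4}, {3, 3}} input on axis 0 this way yields {{3, 6}, {3, 3}}, which stays dynamic because the concat dimension is not fixed; if every merged dimension were fixed, the new code path would return a plain static shape instead.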
diff --git a/src/tf/parse_addn.cpp b/src/tf/parse_addn.cpp
index 7967d25912e..ac81589daee 100644
--- a/src/tf/parse_addn.cpp
+++ b/src/tf/parse_addn.cpp
@@ -40,12 +40,34 @@ struct parse_addn : op_parser<parse_addn>
                           const tf_parser::node_info& info,
                           std::vector<instruction_ref> args) const
     {
-        instruction_ref sum = args[0];
-        for(auto i = 1; i < args.size(); i++)
+        if(args.size() == 1)
+            return args[0];
+
+        if(args.size() < 5) // heuristic: pairwise adds below 5 inputs, concat + reduce_sum otherwise
+        {
+            instruction_ref sum = args[0];
+            for(auto i = 1; i < args.size(); i++)
+            {
+                sum = info.add_common_op("add", sum, args[i]);
+            }
+            return sum;
+        }
+        else
         {
-            sum = info.add_instruction(make_op("add"), sum, args[i]);
+            std::vector<instruction_ref> unsqueezed_args;
+            std::transform(args.begin(),
+                           args.end(),
+                           std::back_inserter(unsqueezed_args),
+                           [&info](instruction_ref arg) {
+                               return info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}),
+                                                           arg);
+                           });
+            auto concatenated =
+                info.add_instruction(make_op("concat", {{"axis", 0}}), unsqueezed_args);
+            auto reduced =
+                info.add_instruction(make_op("reduce_sum", {{"axes", {0}}}), concatenated);
+            return info.add_instruction(make_op("squeeze", {{"axes", {0}}}), reduced);
         }
-        return sum;
     }
 };
 
diff --git a/test/ref/add.cpp b/test/ref/add.cpp
index 148e9b5b08a..8e107a88639 100644
--- a/test/ref/add.cpp
+++ b/test/ref/add.cpp
@@ -1,7 +1,7 @@
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -31,6 +31,35 @@
 #include
 
+TEST_CASE(addn_with_reducesum_test)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    migraphx::shape a_shape{migraphx::shape::float_type, {2, 3}};
+    std::vector<float> a_data{0, 1, 2, 3, 4, 5};
+    migraphx::shape b_shape{migraphx::shape::float_type, {2, 3}};
+    std::vector<float> b_data{0, 1, 2, 3, 4, 5};
+    migraphx::shape c_shape{migraphx::shape::float_type, {2, 3}};
+    std::vector<float> c_data{0, 1, 2, 3, 4, 5};
+    auto l1 = mm->add_literal(migraphx::literal{a_shape, a_data});
+    auto l2 = mm->add_literal(migraphx::literal{b_shape, b_data});
+    auto l3 = mm->add_literal(migraphx::literal{c_shape, c_data});
+    auto unsqueezedl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l1);
+    auto unsqueezedl2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l2);
+    auto unsqueezedl3 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l3);
+    auto concated = mm->add_instruction(
+        migraphx::make_op("concat", {{"axis", 0}}), unsqueezedl1, unsqueezedl2, unsqueezedl3);
+    auto reduced = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0}}}), concated);
+    mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), reduced);
+    p.compile(migraphx::make_target("ref"));
+    auto result = p.eval({}).back();
+    EXPECT(result.get_shape().packed());
+    std::vector<float> results_vector(6);
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold = {0, 3, 6, 9, 12, 15};
+    EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
+}
+
 TEST_CASE(add_broadcast_test)
 {
     migraphx::program p;
diff --git a/test/tf/gen_tf_pb.py b/test/tf/gen_tf_pb.py
index 11ce411f81e..46256fb8086 100644
--- a/test/tf/gen_tf_pb.py
+++ b/test/tf/gen_tf_pb.py
@@ -88,6 +88,14 @@ def addn_single_test(g1):
         tf.math.add_n([g1_input], name='addn1')
 
 
+@tf_test
+def addn_with_many_elements_test(g1):
+    with g1.as_default():
+        input_list = []
+        for i in range(10):
+            input_list.append(tf.compat.v1.placeholder(tf.float32, shape=(1, 1648), name=str(i)))
+        tf.math.add_n(input_list, name='addn1')
+
 
 @tf_test
 def argmax_test(g1):
diff --git a/test/tf/models/addn_with_many_elements_test.pb b/test/tf/models/addn_with_many_elements_test.pb
new file mode 100644
index 00000000000..120b9b2853d
--- /dev/null
+++ b/test/tf/models/addn_with_many_elements_test.pb
@@ -0,0 +1,46 @@
+
+3
+0 Placeholder*
+dtype0*
+shape : ð
+3
+1 Placeholder*
+dtype0*
+shape : ð
+3
+2 Placeholder*
+dtype0*
+shape : ð
+3
+3 Placeholder*
+dtype0*
+shape : ð
+3
+4 Placeholder*
+dtype0*
+shape : ð
+3
+5 Placeholder*
+dtype0*
+shape : ð
+3
+6 Placeholder*
+dtype0*
+shape : ð
+3
+7 Placeholder*
+dtype0*
+shape : ð
+3
+8 Placeholder*
+dtype0*
+shape : ð
+3
+9 Placeholder*
+dtype0*
+shape : ð
+=
+addn1AddN0123456789*
+N
+*
+T0"ð
\ No newline at end of file
diff --git a/test/tf/tests/addn_test.cpp b/test/tf/tests/addn_test.cpp
index 9d7ddaabb1e..6eacc029e00 100644
--- a/test/tf/tests/addn_test.cpp
+++ b/test/tf/tests/addn_test.cpp
@@ -51,3 +51,47 @@ TEST_CASE(addn_single_test)
 
     EXPECT(p == prog);
 }
+
+TEST_CASE(addn_with_10_elements_test)
+{
+    migraphx::program p;
+
+    auto* mm = p.get_main_module();
+    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 1648}});
+    auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1648}});
+    auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1648}});
+    auto l3 = mm->add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1648}});
+    auto l4 = mm->add_parameter("4", migraphx::shape{migraphx::shape::float_type, {1, 1648}});
+    auto l5 = mm->add_parameter("5", migraphx::shape{migraphx::shape::float_type, {1, 1648}});
+    auto l6 = mm->add_parameter("6", migraphx::shape{migraphx::shape::float_type, {1, 1648}});
+    auto l7 = mm->add_parameter("7", migraphx::shape{migraphx::shape::float_type, {1, 1648}});
+    auto l8 = mm->add_parameter("8", migraphx::shape{migraphx::shape::float_type, {1, 1648}});
+    auto l9 = mm->add_parameter("9", migraphx::shape{migraphx::shape::float_type, {1, 1648}});
+    auto us0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0);
+    auto us1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l1);
+    auto us2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l2);
+    auto us3 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l3);
+    auto us4 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l4);
+    auto us5 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l5);
+    auto us6 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l6);
+    auto us7 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l7);
+    auto us8 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l8);
+    auto us9 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l9);
+    auto concatenated = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}),
+                                            us0,
+                                            us1,
+                                            us2,
+                                            us3,
+                                            us4,
+                                            us5,
+                                            us6,
+                                            us7,
+                                            us8,
+                                            us9);
+    auto reduced =
+        mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0}}}), concatenated);
+    mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), reduced);
+    auto prog = optimize_tf("addn_with_many_elements_test.pb", false);
+
+    EXPECT(p == prog);
+}
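Both new tests exercise the rewrite that parse_addn.cpp now performs: summing N tensors elementwise is equivalent to stacking them along a new leading axis and reducing over that axis, which is what the unsqueeze / concat / reduce_sum / squeeze chain builds (the parser keeps pairwise adds below five inputs and switches to the stacked form from five inputs up). Below is a minimal sketch of that equivalence on plain std::vector data; it is illustrative only and does not use the MIGraphX API.

#include <cassert>
#include <cstddef>
#include <vector>

// Pairwise accumulation: ((t0 + t1) + t2) + ..., the path taken for fewer than 5 inputs.
std::vector<float> addn_pairwise(const std::vector<std::vector<float>>& tensors)
{
    std::vector<float> sum = tensors.front();
    for(std::size_t i = 1; i < tensors.size(); ++i)
        for(std::size_t j = 0; j < sum.size(); ++j)
            sum[j] += tensors[i][j];
    return sum;
}

// Stack-and-reduce: conceptually unsqueeze + concat on axis 0, then reduce_sum over
// that axis and squeeze it away; each tensor is one slice along the stacked axis.
std::vector<float> addn_stack_reduce(const std::vector<std::vector<float>>& tensors)
{
    std::vector<float> out(tensors.front().size(), 0.0f);
    for(const auto& slice : tensors)
        for(std::size_t j = 0; j < out.size(); ++j)
            out[j] += slice[j];
    return out;
}

int main()
{
    // Same data as addn_with_reducesum_test: three copies of {0, 1, 2, 3, 4, 5}.
    std::vector<float> t{0, 1, 2, 3, 4, 5};
    std::vector<std::vector<float>> tensors{t, t, t};
    assert(addn_pairwise(tensors) == addn_stack_reduce(tensors));
    return 0;
}

With the three {0, 1, 2, 3, 4, 5} inputs used in addn_with_reducesum_test, both paths return {0, 3, 6, 9, 12, 15}, matching the gold values in that test.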