68 changes: 42 additions & 26 deletions src/include/migraphx/op/concat.hpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
- * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -35,6 +35,7 @@
#include <migraphx/value.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/op/normalize_attribute.hpp>
+#include <migraphx/common.hpp>
#include <cmath>
#include <utility>

@@ -109,38 +110,53 @@ struct concat
new_lens[axis] = new_dim_axis;
return shape::from_permutation(type, new_lens, find_permutation(inputs));
}
-        else if(std::all_of(
-                    inputs.begin(), inputs.end(), [&](const shape& s) { return s.dynamic(); }))
-        {
-            // Dynamic input shapes
-            for(std::size_t index = 0; index < inputs[0].ndim(); index++)
-            {
-                if(index != axis)
-                {
-                    if(not std::all_of(inputs.begin(), inputs.end(), [&](const shape& s) {
-                           return s.dyn_dims()[index] == inputs[0].dyn_dims()[index];
-                       }))
-                        MIGRAPHX_THROW("CONCAT: all input dimensions should match in axis " +
-                                       std::to_string(index));
-                }
-            }
-            std::size_t new_min = 0;
-            std::size_t new_max = 0;
-            for(const auto& input : inputs)
-            {
-                auto ddim = input.dyn_dims()[axis];
-                new_min += ddim.min;
-                new_max += ddim.max;
-            }
+
+        // Check if we have mixed static and dynamic shapes
+        bool has_static = std::any_of(
+            inputs.begin(), inputs.end(), [](const shape& s) { return not s.dynamic(); });
+        bool has_dynamic =
+            std::any_of(inputs.begin(), inputs.end(), [](const shape& s) { return s.dynamic(); });
+
+        // Convert all static shapes to dynamic shapes
+        if(has_static and has_dynamic)
+        {
+            for(auto& input : inputs)
+            {
+                if(input.dynamic())
+                    continue;
+                input = input.to_dynamic();
+            }
+        }
Comment from Copilot AI (Sep 2, 2025) on lines +121 to +129:

Modifying the inputs parameter directly can lead to unexpected side effects since it's passed by value but contains references to shapes. This modification affects the original shapes, which may be used elsewhere. Consider creating a local copy of inputs or using a different approach to handle mixed shapes.
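For illustration, a minimal sketch of the local-copy approach the reviewer suggests — this is not code from the PR, only an assumption of how the suggestion could look, reusing the to_dynamic() helper already present in the diff:

    // Hypothetical alternative (not in the PR): promote into a local vector
    // instead of rewriting the `inputs` argument in place.
    std::vector<shape> promoted = inputs;
    for(auto& input : promoted)
    {
        if(not input.dynamic())
            input = input.to_dynamic();
    }
    // ...continue the shape computation using `promoted` instead of `inputs`.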

-            auto new_dims = inputs[0].dyn_dims();
-            new_dims[axis] = migraphx::shape::dynamic_dimension{new_min, new_max};
-            return {inputs[0].type(), new_dims};
-        }
-        else
-        {
-            MIGRAPHX_THROW("CONCAT: Cannot mix static and dynamic input shapes.");
-        }
+        // Calculate the dynamic input shapes for all axes
+        auto common_dyn_dims = compute_common_dyn_dims(inputs);
+
+        // Update the dynamic dimensions for the concat axis
+        std::size_t new_min = 0;
+        std::size_t new_max = 0;
+        for(const auto& input : inputs)
+        {
+            auto ddim = input.dyn_dims()[axis];
+            new_min += ddim.min;
+            new_max += ddim.max;
+        }
+
+        common_dyn_dims[axis] = migraphx::shape::dynamic_dimension{new_min, new_max};
+
+        // Check if all dimensions can be made static
+        if(std::all_of(common_dyn_dims.begin(), common_dyn_dims.end(), [&](auto const& ddim) {
+               return ddim.is_fixed();
+           }))
+        {
+            // Return as static
+            std::vector<size_t> new_static_dims;
+            std::transform(common_dyn_dims.begin(),
+                           common_dyn_dims.end(),
+                           std::back_inserter(new_static_dims),
+                           [&](auto const& ddim) { return ddim.max; });
+            return {inputs.at(0).type(), new_static_dims};
+        }
+        return {inputs[0].type(), common_dyn_dims};
}

argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
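To make the new behaviour concrete, here is a hedged walk-through of the mixed static/dynamic case; the shapes, the axis, and the exact merging done by compute_common_dyn_dims are illustrative assumptions, not taken from the PR:

    // Illustrative example only: concat on axis 1 of
    //   s0 = {float_type, {2, 3}}             (static)
    //   s1 = {float_type, {{2, 2}, {3, 5}}}   (dynamic; each dim is {min, max})
    // 1. Mixed inputs are detected, so s0.to_dynamic() becomes {{2, 2}, {3, 3}}.
    // 2. compute_common_dyn_dims merges the per-axis dynamic dimensions.
    // 3. The concat-axis ranges are summed: {3, 3} + {3, 5} = {6, 8}.
    // Result: {float_type, {{2, 2}, {6, 8}}}.  If every merged dimension had been
    // fixed (min == max), the shape would be returned as static instead; previously
    // this mixed case threw "CONCAT: Cannot mix static and dynamic input shapes."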
30 changes: 26 additions & 4 deletions src/tf/parse_addn.cpp
@@ -40,12 +40,34 @@ struct parse_addn : op_parser<parse_addn>
const tf_parser::node_info& info,
std::vector<instruction_ref> args) const
{
-        instruction_ref sum = args[0];
-        for(auto i = 1; i < args.size(); i++)
-        {
-            sum = info.add_instruction(make_op("add"), sum, args[i]);
-        }
-        return sum;
+        if(args.size() == 1)
+            return args[0];
+
+        if(args.size() < 5) // using heuristic when args exceed over 5 elements
Comment from Copilot AI (Sep 2, 2025):

The comment is inconsistent with the condition. The condition checks for 'less than 5' but the comment says 'when args exceed over 5 elements'. The comment should read 'using chain addition when args are less than 5 elements' or similar.

Suggested change:
-        if(args.size() < 5) // using heuristic when args exceed over 5 elements
+        if(args.size() < 5) // using chain addition when args are less than 5 elements
+        {
+            instruction_ref sum = args[0];
+            for(auto i = 1; i < args.size(); i++)
+            {
+                sum = info.add_common_op("add", sum, args[i]);
+            }
+            return sum;
+        }
+        else
+        {
+            std::vector<instruction_ref> unsqueezed_args;
+            std::transform(args.begin(),
+                           args.end(),
+                           std::back_inserter(unsqueezed_args),
+                           [&info](instruction_ref arg) {
+                               return info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}),
+                                                           arg);
+                           });
+            auto concatenated =
+                info.add_instruction(make_op("concat", {{"axis", 0}}), unsqueezed_args);
+            auto reduced =
+                info.add_instruction(make_op("reduce_sum", {{"axes", {0}}}), concatenated);
+            return info.add_instruction(make_op("squeeze", {{"axes", {0}}}), reduced);
+        }
    }
};

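As a quick aid for reading the new lowering and the tests below, here is an assumed shape trace; the (2, 3) input shape is only an example, not taken from the PR:

    // Assumed shape trace for N inputs of shape (2, 3):
    //   unsqueeze axes {0}:   (2, 3)        -> (1, 2, 3)   for each of the N inputs
    //   concat axis 0:        N x (1, 2, 3) -> (N, 2, 3)
    //   reduce_sum axes {0}:  (N, 2, 3)     -> (1, 2, 3)
    //   squeeze axes {0}:     (1, 2, 3)     -> (2, 3)
    // which matches summing the N inputs element-wise with a chain of adds.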
31 changes: 30 additions & 1 deletion test/ref/add.cpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
- * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -31,6 +31,35 @@

#include <test.hpp>

TEST_CASE(addn_with_reducesum_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3}};
std::vector<float> a_data{0, 1, 2, 3, 4, 5};
migraphx::shape b_shape{migraphx::shape::float_type, {2, 3}};
std::vector<float> b_data{0, 1, 2, 3, 4, 5};
migraphx::shape c_shape{migraphx::shape::float_type, {2, 3}};
std::vector<float> c_data{0, 1, 2, 3, 4, 5};
auto l1 = mm->add_literal(migraphx::literal{a_shape, a_data});
auto l2 = mm->add_literal(migraphx::literal{b_shape, b_data});
auto l3 = mm->add_literal(migraphx::literal{c_shape, c_data});
auto unsqueezedl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l1);
auto unsqueezedl2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l2);
auto unsqueezedl3 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l3);
auto concated = mm->add_instruction(
migraphx::make_op("concat", {{"axis", 0}}), unsqueezedl1, unsqueezedl2, unsqueezedl3);
auto reduced = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0}}}), concated);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), reduced);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
EXPECT(result.get_shape().packed());
std::vector<float> results_vector(6);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 3, 6, 9, 12, 15};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}

TEST_CASE(add_broadcast_test)
{
migraphx::program p;
8 changes: 8 additions & 0 deletions test/tf/gen_tf_pb.py
@@ -88,6 +88,14 @@ def addn_single_test(g1):

tf.math.add_n([g1_input], name='addn1')

@tf_test
def addn_with_many_elements_test(g1):
    with g1.as_default():
        input_list = []
        for i in range(10):
            input_list.append(tf.compat.v1.placeholder(tf.float32, shape=(1, 1648), name=str(i)))
        tf.math.add_n(input_list, name='addn1')


@tf_test
def argmax_test(g1):
46 changes: 46 additions & 0 deletions test/tf/models/addn_with_many_elements_test.pb
@@ -0,0 +1,46 @@

(Binary TensorFlow GraphDef, not human-readable as text: ten float32 Placeholder nodes named "0" through "9", each with shape (1, 1648), feeding a single AddN node named "addn1".)
44 changes: 44 additions & 0 deletions test/tf/tests/addn_test.cpp
@@ -51,3 +51,47 @@ TEST_CASE(addn_single_test)

EXPECT(p == prog);
}

TEST_CASE(addn_with_10_elements_test)
{
migraphx::program p;

auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 1648}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1648}});
auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1648}});
auto l3 = mm->add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1648}});
auto l4 = mm->add_parameter("4", migraphx::shape{migraphx::shape::float_type, {1, 1648}});
auto l5 = mm->add_parameter("5", migraphx::shape{migraphx::shape::float_type, {1, 1648}});
auto l6 = mm->add_parameter("6", migraphx::shape{migraphx::shape::float_type, {1, 1648}});
auto l7 = mm->add_parameter("7", migraphx::shape{migraphx::shape::float_type, {1, 1648}});
auto l8 = mm->add_parameter("8", migraphx::shape{migraphx::shape::float_type, {1, 1648}});
auto l9 = mm->add_parameter("9", migraphx::shape{migraphx::shape::float_type, {1, 1648}});
auto us0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0);
auto us1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l1);
auto us2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l2);
auto us3 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l3);
auto us4 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l4);
auto us5 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l5);
auto us6 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l6);
auto us7 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l7);
auto us8 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l8);
auto us9 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l9);
auto concatenated = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}),
us0,
us1,
us2,
us3,
us4,
us5,
us6,
us7,
us8,
us9);
auto reduced =
mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0}}}), concatenated);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), reduced);
auto prog = optimize_tf("addn_with_many_elements_test.pb", false);

EXPECT(p == prog);
}