Merged

69 commits
165869d
Adding lower_lrn_to_pooling function
aarushjain29 Sep 10, 2025
65d5b70
Adding lower_lrn_to_pooling function
aarushjain29 Sep 10, 2025
83651ec
Adding invert permutation header
aarushjain29 Sep 10, 2025
75d6370
changes in apply
aarushjain29 Sep 19, 2025
6d94d79
Update src/rewrite_pooling.cpp
aarushjain29 Sep 19, 2025
acafcc0
Merge branch 'develop' into lower-lrn-to-pooling
aarushjain29 Sep 19, 2025
0b717ab
test case added
aarushjain29 Sep 22, 2025
bec9dac
test case added
aarushjain29 Sep 23, 2025
ea39019
more test case
aarushjain29 Sep 23, 2025
8b36261
Merge branch 'develop' into lower-lrn-to-pooling
aarushjain29 Sep 23, 2025
69efb32
remove comment
aarushjain29 Sep 23, 2025
fb8b708
remove spaces
aarushjain29 Sep 23, 2025
6dae737
Update test/rewrite_pooling_test.cpp
aarushjain29 Sep 23, 2025
09ef691
Update test/rewrite_pooling_test.cpp
aarushjain29 Sep 23, 2025
6ac18ee
tidy errors
aarushjain29 Sep 24, 2025
4780e5e
cpp errors
aarushjain29 Sep 24, 2025
3858f86
tidy errors
aarushjain29 Sep 24, 2025
2c1b76f
tidy errors
aarushjain29 Sep 24, 2025
07a2730
test case for comparing two models
aarushjain29 Sep 24, 2025
c2dc92b
test case for comparing two models
aarushjain29 Sep 24, 2025
ea7159d
accepting both even and odd sizes
aarushjain29 Sep 24, 2025
cc383cd
remove tidy errors
aarushjain29 Sep 24, 2025
6d17986
remove tidy errors
aarushjain29 Sep 24, 2025
506e2e6
tidy error
aarushjain29 Sep 24, 2025
6e5597d
formatting
aarushjain29 Sep 24, 2025
393e613
formatting
aarushjain29 Sep 24, 2025
38a39c2
license
aarushjain29 Sep 24, 2025
2b6cc85
formatting
aarushjain29 Sep 24, 2025
d4d5452
reverting back to even size
aarushjain29 Sep 24, 2025
5a783c9
logic for both even and odd sizes
aarushjain29 Sep 26, 2025
02a7ce7
calculate padding
aarushjain29 Sep 30, 2025
84d2105
formatting
aarushjain29 Sep 30, 2025
f4df624
formatting
aarushjain29 Sep 30, 2025
89fc424
formatting
aarushjain29 Sep 30, 2025
e489de0
test case added
aarushjain29 Sep 30, 2025
b0335f2
verify test case
aarushjain29 Sep 30, 2025
77ca91d
formatting
aarushjain29 Sep 30, 2025
2043239
Merge branch 'develop' into lower-lrn-to-pooling
aarushjain29 Sep 30, 2025
e8c2547
Update test/rewrite_pooling_test.cpp
aarushjain29 Sep 30, 2025
32821aa
licensing
aarushjain29 Sep 30, 2025
50b429f
combine line 89 and 90
aarushjain29 Oct 2, 2025
5e51bc1
compiler warning unused param
aarushjain29 Oct 2, 2025
e0355c5
remove transposed lens
aarushjain29 Oct 2, 2025
b8fc4ee
Adding the check for size and combining all the checks in if
aarushjain29 Oct 2, 2025
1b073b9
changing the test to simplify_algebra like test
aarushjain29 Oct 2, 2025
72934da
tidy error
aarushjain29 Oct 3, 2025
ea31d5e
tidy error
aarushjain29 Oct 3, 2025
3d7b480
tidy error
aarushjain29 Oct 3, 2025
080ac81
tidy error
aarushjain29 Oct 3, 2025
85c15cd
Merge branch 'develop' into lower-lrn-to-pooling
aarushjain29 Oct 3, 2025
91eb3c7
remove try catch and add all conditions
aarushjain29 Oct 5, 2025
a351f31
formatting
aarushjain29 Oct 5, 2025
310c8e7
new tests added
aarushjain29 Oct 5, 2025
8745a74
removing test case from test_relu_lrn
aarushjain29 Oct 5, 2025
808a56e
MIGRAPHX_REWRITE_LRN flag
aarushjain29 Oct 5, 2025
355bec1
license
aarushjain29 Oct 5, 2025
6d8929f
formatting
aarushjain29 Oct 5, 2025
c69157a
formatting
aarushjain29 Oct 5, 2025
1dadd05
formatting
aarushjain29 Oct 5, 2025
f5318f2
license
aarushjain29 Oct 5, 2025
a546fd4
enable flag in test case
aarushjain29 Oct 5, 2025
afabc81
test case accepting flag
aarushjain29 Oct 5, 2025
8b17b1d
formatting
aarushjain29 Oct 6, 2025
2e244c5
formatting
aarushjain29 Oct 6, 2025
878b135
simplify code
aarushjain29 Oct 6, 2025
04a7705
simplify code
aarushjain29 Oct 6, 2025
7ef1891
Merge branch 'develop' into lower-lrn-to-pooling
aarushjain29 Oct 6, 2025
526e25e
updated the doc
aarushjain29 Oct 6, 2025
c7cf3ee
remove flag in test and formatting
aarushjain29 Oct 6, 2025
7 changes: 7 additions & 0 deletions docs/reference/MIGraphX-dev-env-vars.rst
@@ -221,6 +221,13 @@ Model performance tunable variables change the compilation behavior of a model.

| Default: No tuning is done for composable kernels.

* - | ``MIGRAPHX_REWRITE_LRN``
| Turns on LRN-to-pooling lowering in the rewrite_pooling pass.

- | ``1``: Turns on LRN-to-pooling lowering.
| ``0``: Returns to default behavior.

| Default: LRN-to-pooling lowering is turned off.

Matching
**********
3 changes: 2 additions & 1 deletion src/include/migraphx/rewrite_pooling.hpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -38,6 +38,7 @@ struct module;
*/
struct MIGRAPHX_EXPORT rewrite_pooling
{
bool rewrite_lrn = false;
std::string name() const { return "rewrite_pooling"; }
void apply(module& m) const;
};
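The new `rewrite_lrn` field is the only switch the pass gains in this PR. For readers who want to run the lowering outside the full GPU pipeline, here is a minimal sketch; the helper name and the dead_code_elimination include are assumptions of mine, while the pass construction and apply calls mirror the new unit test further down.

```cpp
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/dead_code_elimination.hpp> // assumed include path for the DCE pass
#include <migraphx/module.hpp>

// Apply the LRN-to-pooling lowering to a module directly, bypassing the
// MIGRAPHX_REWRITE_LRN environment variable (mirrors the new unit test below).
inline void lower_lrn_in_module(migraphx::module& m) // hypothetical helper name
{
    migraphx::rewrite_pooling rp{.rewrite_lrn = true};
    rp.apply(m);
    migraphx::dead_code_elimination{}.apply(m); // removes the replaced lrn instruction
}
```

In the GPU target the same field is filled from the MIGRAPHX_REWRITE_LRN environment variable instead, as the target.cpp change below shows.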
89 changes: 86 additions & 3 deletions src/rewrite_pooling.cpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -30,6 +30,8 @@
#include <migraphx/op/reduce_max.hpp>
#include <migraphx/make_op.hpp>

#include <migraphx/op/lrn.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/program.hpp>

namespace migraphx {
@@ -55,6 +57,81 @@ static void replace_with_reduce(module& m, instruction_ref ins)
}
}

static void lower_lrn_to_pooling(module& m, instruction_ref ins)
{
auto v = ins->get_operator().to_value();

float alpha = v.at("alpha").to<float>();
float beta = v.at("beta").to<float>();
float k = v.at("bias").to<float>();
int size = v.at("size").to<int>();

auto x = ins->inputs().at(0);
const auto& xshape = x->get_shape();
const auto& lens = xshape.lens();

if(lens.size() != 4 or size <= 0 or size > lens[1])
{
return;
}

auto x2 = m.insert_instruction(ins, make_op("mul"), x, x);

std::vector<int64_t> perm = {0, 2, 3, 1};
auto transpose1 = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), x2);

auto transposed_shape = transpose1->get_shape();
const auto& transposed_lens = transposed_shape.lens();

int64_t channel_dim = lens[1];
std::vector<int64_t> calculated_pads(2);
calculate_padding(0, calculated_pads, channel_dim, 1, 1, size, true);

auto avg = m.insert_instruction(
ins,
make_op("pooling",
{{"mode", op::pooling_mode::average},
{"lengths", std::vector<int64_t>{1, size}},
{"stride", std::vector<int64_t>{1, 1}},
{"padding", std::vector<int64_t>{0, calculated_pads[0], 0, calculated_pads[1]}},
{"dilations", std::vector<int64_t>{1, 1}},
{"count_include_pad", true}}),
transpose1);

auto avg_shape = avg->get_shape();
const auto& avg_lens = avg_shape.lens();

if(avg_lens.size() != 4 or avg_lens[3] != transposed_lens[3])
{
return;
}

std::vector<int64_t> inv_perm = {0, 3, 1, 2};
auto transpose2 =
m.insert_instruction(ins, make_op("transpose", {{"permutation", inv_perm}}), avg);

auto final_shape = transpose2->get_shape();
const auto& final_lens = final_shape.lens();

if(final_lens != lens)
return;

auto k_lit = m.add_literal(k);
auto a_lit = m.add_literal(alpha);
auto b_lit = m.add_literal(beta);

auto k_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), k_lit);
auto a_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), a_lit);
auto b_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), b_lit);

auto alpha_avg = m.insert_instruction(ins, make_op("mul"), a_mb, transpose2);
auto den = m.insert_instruction(ins, make_op("add"), k_mb, alpha_avg);
auto denpow = m.insert_instruction(ins, make_op("pow"), den, b_mb);
auto y = m.insert_instruction(ins, make_op("div"), x, denpow);

m.replace_instruction(ins, y);
}

static void replace_dilations_with_gather_pooling(module& m, instruction_ref ins)
{
// TODO remove this when MIOpen supports dilated pooling
@@ -143,10 +220,16 @@ void rewrite_pooling::apply(module& m) const
{
for(auto ins : iterator_for(m))
{
if(ins->name() != "pooling")
continue;
if(ins->inputs().empty())
continue;
if(rewrite_lrn and ins->name() == "lrn")
{
lower_lrn_to_pooling(m, ins);
continue;
}
if(ins->name() != "pooling")
continue;

auto&& s = ins->inputs().front()->get_shape();
auto&& op = any_cast<op::pooling>(ins->get_operator());
bool same_kernel_as_shape = std::equal(
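A note for reviewers on why the rewrite above is exact, assuming the lrn operator follows the usual ONNX-style cross-channel definition (which the new verify tests exercise): the average pooling with count_include_pad = true already carries the 1/size factor, so multiplying by alpha reproduces the alpha/size scaling in the LRN denominator. With W(c) the window of `size` channels around channel c, and k the operator's bias attribute:

```latex
\[
y_{b,c,h,w}
  = \frac{x_{b,c,h,w}}
         {\left(k + \frac{\alpha}{\mathrm{size}} \sum_{c' \in W(c)} x_{b,c',h,w}^{2}\right)^{\beta}}
  = \frac{x_{b,c,h,w}}
         {\left(k + \alpha \cdot \mathrm{avg}_{c' \in W(c)}\, x_{b,c',h,w}^{2}\right)^{\beta}}
\]
```

This is exactly div(x, pow(add(k, mul(alpha, avgpool(x*x))), beta)); the two transposes only move the channel axis into the last position and back so the existing 2-D pooling operator can slide the window across channels.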
3 changes: 2 additions & 1 deletion src/targets/gpu/target.cpp
@@ -84,6 +84,7 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_REWRITE_DOT)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_REWRITE_LRN)
#ifndef _WIN32
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK)
#endif
@@ -203,7 +204,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
insert_pad{{"convolution"}},
dead_code_elimination{},
inline_module{},
rewrite_pooling{},
rewrite_pooling{.rewrite_lrn = enabled(MIGRAPHX_REWRITE_LRN{})},
dead_code_elimination{},
rewrite_gelu{options.fast_math},
optimize_module{},
74 changes: 74 additions & 0 deletions test/rewrite_pooling_test.cpp
@@ -34,6 +34,9 @@

#include <migraphx/verify.hpp>

#include <migraphx/iterator.hpp>

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_REWRITE_LRN);
static void opt_pooling(migraphx::module& m)
{
migraphx::rewrite_pooling rp;
@@ -309,6 +312,77 @@ TEST_CASE(rewrite_pooling_dialtions_test5)
test_rewrite(migraphx::op::pooling_mode::max);
}

TEST_CASE(test_lower_lrn_to_pooling)
{
migraphx::module m1;
{
migraphx::shape input_shape{migraphx::shape::float_type, {1, 64, 55, 55}};
auto input1 = m1.add_parameter("x", input_shape);
auto lrn1 = m1.add_instruction(
migraphx::make_op("lrn",
{{"alpha", 0.0001f}, {"beta", 0.75f}, {"bias", 1.0f}, {"size", 4}}),
input1);
m1.add_return({lrn1});
}
// Apply the pass directly with the flag enabled
migraphx::rewrite_pooling rp{.rewrite_lrn = true};
migraphx::dead_code_elimination dce;
rp.apply(m1);
dce.apply(m1);

migraphx::module m2;
{
migraphx::shape input_shape{migraphx::shape::float_type, {1, 64, 55, 55}};
auto input2 = m2.add_parameter("x", input_shape);

auto x2 = m2.add_instruction(migraphx::make_op("mul"), input2, input2);

auto transpose1 = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", std::vector<int64_t>{0, 2, 3, 1}}}),
x2);

int64_t lrn_size = 4;
int64_t pad_left = (lrn_size - 1) / 2;
int64_t pad_right = lrn_size - 1 - pad_left;
std::vector<int64_t> expected_pads = {pad_left, pad_right};

auto avg = m2.add_instruction(
migraphx::make_op(
"pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"lengths", std::vector<int64_t>{1, lrn_size}},
{"stride", std::vector<int64_t>{1, 1}},
{"padding", std::vector<int64_t>{0, expected_pads[0], 0, expected_pads[1]}},
{"dilations", std::vector<int64_t>{1, 1}},
{"count_include_pad", true}}),
transpose1);

auto transpose2 = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", std::vector<int64_t>{0, 3, 1, 2}}}),
avg);

auto k_lit = m2.add_literal(1.0f);
auto a_lit = m2.add_literal(0.0001f);
auto b_lit = m2.add_literal(0.75f);

auto k_mb = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), k_lit);
auto a_mb = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), a_lit);
auto b_mb = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), b_lit);

auto alpha_avg = m2.add_instruction(migraphx::make_op("mul"), a_mb, transpose2);
auto den = m2.add_instruction(migraphx::make_op("add"), k_mb, alpha_avg);
auto denpow = m2.add_instruction(migraphx::make_op("pow"), den, b_mb);
auto y = m2.add_instruction(migraphx::make_op("div"), input2, denpow);

m2.add_return({y});
}

EXPECT(m1 == m2);
}

TEST_CASE(rewrite_avgpool_rank3_dil_test)
{
// 1D case 1, input is 3D
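A quick worked check of the padding that test_lower_lrn_to_pooling above expects; the floor/ceil split is the one the test itself computes, and the standalone program below is only an illustrative sketch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main()
{
    // For an LRN window of size = 4 slid along the channel axis, the test splits
    // the (size - 1) total padding into floor and ceil halves:
    const std::int64_t size      = 4;
    const std::int64_t pad_left  = (size - 1) / 2;      // 1
    const std::int64_t pad_right = size - 1 - pad_left; // 2
    // The pooling op then receives a 2-D padding of {0, pad_left, 0, pad_right},
    // so every channel position sees exactly `size` (zero-padded) neighbours.
    const std::vector<std::int64_t> padding{0, pad_left, 0, pad_right};
    assert((padding == std::vector<std::int64_t>{0, 1, 0, 2}));
    return 0;
}
```

With the padding included in the divisor (count_include_pad = true), the output keeps the original channel count and the average is always taken over size elements, which keeps the rewrite consistent with the LRN definition at the channel boundaries.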
50 changes: 50 additions & 0 deletions test/verify/test_lrn.cpp
@@ -0,0 +1,50 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

template <int ChannelSize, int LrnSize>
struct test_lrn : verify_program<test_lrn<ChannelSize, LrnSize>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter(
"x", migraphx::shape{migraphx::shape::float_type, {1, ChannelSize, 28, 28}});
mm->add_instruction(
migraphx::make_op(
"lrn", {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", LrnSize}}),
x);
return p;
}
};

template struct test_lrn<32, 6>;
template struct test_lrn<32, 5>;
template struct test_lrn<31, 8>;
template struct test_lrn<31, 5>;
2 changes: 1 addition & 1 deletion test/verify/test_relu_lrn.cpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal