From c78363f173010de9d32d0935353bd56072372434 Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 24 Sep 2025 00:37:54 +0000 Subject: [PATCH 1/3] Refactor find_matches --- src/include/migraphx/matcher.hpp | 124 ++++++++++++++++++++++++++++++- 1 file changed, 122 insertions(+), 2 deletions(-) diff --git a/src/include/migraphx/matcher.hpp b/src/include/migraphx/matcher.hpp index f5cd5682b39..c1e4d92b5bd 100644 --- a/src/include/migraphx/matcher.hpp +++ b/src/include/migraphx/matcher.hpp @@ -406,6 +406,110 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES_FOR) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VALIDATE_MATCHES) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TIME_MATCHERS) +template +auto make_match_runner_with_trace(source_location location, Finder& f) +{ + auto m = f.matcher(); + const int trace = value_of(MIGRAPHX_TRACE_MATCHES{}); + const bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{}); + const auto trace_filter = string_value_of(MIGRAPHX_TRACE_MATCHES_FOR{}); + const auto& finder_name = get_type_name(f); + const bool trace_enabled = trace > 0 and (trace_filter.empty() or + contains(std::string{location.file_name()}, trace_filter) or + contains(std::string{location.function_name()}, trace_filter) or + contains(finder_name, trace_filter)); + return [=, &f](auto& mod, instruction_ref ins) -> bool { + using microseconds = std::chrono::duration; + if(trace > 1 and trace_enabled) + std::cout << "Running matcher: " << finder_name << std::endl; + + + match::matcher_result r; + double match_time = 0.0; + if(trace_enabled) + { + match_time = time([&] { + r = match::match_instruction(get_module(mod), ins, m); + }); + } + else + { + r = match::match_instruction(get_module(mod), ins, m); + } + + if(trace > 1 and trace_enabled) + { + std::cout << "Matcher time for " << finder_name << ": " << match_time << "us" + << std::endl; + } + + // did not match any instruction + if(r.result == get_module(mod).end()) + return false; + + if(trace > 0 or trace_enabled) + { + std::cout << "Matched by: " << finder_name << std::endl; + get_module(mod).debug_print(ins); + } + // If its already invalid dont validate it again + bool invalidated = validate and get_module(mod).validate() != get_module(mod).end(); + if(trace_enabled) + { + if(trace > 1) + std::cout << "Applying matcher: " << finder_name << std::endl; + auto apply_time = time([&] { f.apply(mod, r); }); + std::cout << "Apply time for " << finder_name << ": " << apply_time << "us" + << std::endl; + } + else + { + f.apply(mod, r); + } + + if(validate and not invalidated) + { + auto invalid = get_module(mod).validate(); + if(invalid != get_module(mod).end()) + { + std::cout << "Invalid program from match: " << finder_name << std::endl; + std::cout << "Invalid instructions: " << std::endl; + get_module(mod).debug_print(invalid->inputs()); + get_module(mod).debug_print(invalid); + } + } + return true; + }; +} + +template +auto make_match_runner(Finder& f) +{ + auto m = f.matcher(); + return [=, &f](auto& mod, instruction_ref ins) -> bool { + match::matcher_result r = match::match_instruction(get_module(mod), ins, m); + if(r.result == get_module(mod).end()) + return false; + f.apply(mod, r); + return true; + }; +} + +template +void find_matches_for(Mod& mod, instruction_ref ins, RunnerPack rp) +{ + rp([&](auto&&... rs) { + bool matched = false; + each_args( + [&](auto&& r) { + if(matched) + return; + matched = r(mod, ins); + }, + rs...); + }); +} + /// Find matches for an instruction in the module for per section of matchers template void find_matches_for(source_location location, Mod& mod, instruction_ref ins, Ms&&... ms) @@ -484,9 +588,25 @@ struct find_matches { find_matches(Mod& mod, Ms&&... ms, source_location location = source_location::current()) { - for(auto ins : iterator_for(get_module(mod))) + const int trace = value_of(MIGRAPHX_TRACE_MATCHES{}); + const bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{}); + const bool need_trace = trace > 0 or validate; + + if(need_trace) { - find_matches_for(location, mod, ins, ms...); + auto runners = pack(make_match_runner_with_trace(location, ms)...); + for(auto ins : iterator_for(get_module(mod))) + { + find_matches_for(mod, ins, runners); + } + } + else + { + auto runners = pack(make_match_runner(ms)...); + for(auto ins : iterator_for(get_module(mod))) + { + find_matches_for(mod, ins, runners); + } } } }; From 40eea7b4780bca264ea70e70f8d6fb004bc9619e Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 24 Sep 2025 00:38:16 +0000 Subject: [PATCH 2/3] Format --- src/include/migraphx/matcher.hpp | 35 ++++++++++++++++---------------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/src/include/migraphx/matcher.hpp b/src/include/migraphx/matcher.hpp index c1e4d92b5bd..a7298cb6054 100644 --- a/src/include/migraphx/matcher.hpp +++ b/src/include/migraphx/matcher.hpp @@ -406,31 +406,30 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES_FOR) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VALIDATE_MATCHES) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TIME_MATCHERS) -template +template auto make_match_runner_with_trace(source_location location, Finder& f) { - auto m = f.matcher(); - const int trace = value_of(MIGRAPHX_TRACE_MATCHES{}); - const bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{}); - const auto trace_filter = string_value_of(MIGRAPHX_TRACE_MATCHES_FOR{}); + auto m = f.matcher(); + const int trace = value_of(MIGRAPHX_TRACE_MATCHES{}); + const bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{}); + const auto trace_filter = string_value_of(MIGRAPHX_TRACE_MATCHES_FOR{}); const auto& finder_name = get_type_name(f); - const bool trace_enabled = trace > 0 and (trace_filter.empty() or - contains(std::string{location.file_name()}, trace_filter) or - contains(std::string{location.function_name()}, trace_filter) or - contains(finder_name, trace_filter)); + const bool trace_enabled = + trace > 0 and + (trace_filter.empty() or contains(std::string{location.file_name()}, trace_filter) or + contains(std::string{location.function_name()}, trace_filter) or + contains(finder_name, trace_filter)); return [=, &f](auto& mod, instruction_ref ins) -> bool { using microseconds = std::chrono::duration; if(trace > 1 and trace_enabled) std::cout << "Running matcher: " << finder_name << std::endl; - match::matcher_result r; double match_time = 0.0; if(trace_enabled) { - match_time = time([&] { - r = match::match_instruction(get_module(mod), ins, m); - }); + match_time = + time([&] { r = match::match_instruction(get_module(mod), ins, m); }); } else { @@ -460,7 +459,7 @@ auto make_match_runner_with_trace(source_location location, Finder& f) std::cout << "Applying matcher: " << finder_name << std::endl; auto apply_time = time([&] { f.apply(mod, r); }); std::cout << "Apply time for " << finder_name << ": " << apply_time << "us" - << std::endl; + << std::endl; } else { @@ -482,7 +481,7 @@ auto make_match_runner_with_trace(source_location location, Finder& f) }; } -template +template auto make_match_runner(Finder& f) { auto m = f.matcher(); @@ -498,7 +497,7 @@ auto make_match_runner(Finder& f) template void find_matches_for(Mod& mod, instruction_ref ins, RunnerPack rp) { - rp([&](auto&&... rs) { + rp([&](auto&&... rs) { bool matched = false; each_args( [&](auto&& r) { @@ -588,8 +587,8 @@ struct find_matches { find_matches(Mod& mod, Ms&&... ms, source_location location = source_location::current()) { - const int trace = value_of(MIGRAPHX_TRACE_MATCHES{}); - const bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{}); + const int trace = value_of(MIGRAPHX_TRACE_MATCHES{}); + const bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{}); const bool need_trace = trace > 0 or validate; if(need_trace) From 6000a7991b921b9f9099cb50834d635de43a8594 Mon Sep 17 00:00:00 2001 From: Paul Fultz II Date: Wed, 24 Sep 2025 17:23:11 -0500 Subject: [PATCH 3/3] Update src/include/migraphx/matcher.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/include/migraphx/matcher.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/migraphx/matcher.hpp b/src/include/migraphx/matcher.hpp index a7298cb6054..3961bfdda12 100644 --- a/src/include/migraphx/matcher.hpp +++ b/src/include/migraphx/matcher.hpp @@ -446,7 +446,7 @@ auto make_match_runner_with_trace(source_location location, Finder& f) if(r.result == get_module(mod).end()) return false; - if(trace > 0 or trace_enabled) + if(trace_enabled) { std::cout << "Matched by: " << finder_name << std::endl; get_module(mod).debug_print(ins);