From 562d1f1e07341e54e777c7cd77bd1040f2af2ad2 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 8 Jan 2026 16:39:54 -0800 Subject: [PATCH 1/6] Enable TensorIndexer with all python_direct tests --- csrc/id_model/utils.h | 39 ++++--------------------------- python/nvfuser_direct/__init__.py | 7 +++++- 2 files changed, 10 insertions(+), 36 deletions(-) diff --git a/csrc/id_model/utils.h b/csrc/id_model/utils.h index 9750bbd592d..1d57372f4f3 100644 --- a/csrc/id_model/utils.h +++ b/csrc/id_model/utils.h @@ -30,52 +30,21 @@ enum class IdModelEnableOption { inline std::unordered_set getIdModelEnabledOptions() { std::unordered_set opts; - if (hasEnableOptionArgument(EnableOption::IdModel, "consumer_index") || - hasEnableOptionArgument(EnableOption::IdModel, "index") || - hasEnableOptionArgument(EnableOption::IdModel, "all")) { + if (!hasEnableOptionArgument(EnableOption::IdModel, "predicate_only")) { opts.insert(IdModelEnableOption::ConsumerIndex); - } - - if (hasEnableOptionArgument(EnableOption::IdModel, "producer_index") || - hasEnableOptionArgument(EnableOption::IdModel, "index") || - hasEnableOptionArgument(EnableOption::IdModel, "all")) { opts.insert(IdModelEnableOption::ProducerIndex); } - if (hasEnableOptionArgument(EnableOption::IdModel, "inline_predicate") || - hasEnableOptionArgument(EnableOption::IdModel, "predicate") || - hasEnableOptionArgument(EnableOption::IdModel, "all")) { + if (!hasEnableOptionArgument(EnableOption::IdModel, "index_only")) { opts.insert(IdModelEnableOption::InlinePredicate); - } - - if (hasEnableOptionArgument(EnableOption::IdModel, "unswitch_predicate") || - hasEnableOptionArgument(EnableOption::IdModel, "predicate") || - hasEnableOptionArgument(EnableOption::IdModel, "all")) { opts.insert(IdModelEnableOption::UnswitchPredicate); } - if (hasEnableOptionArgument(EnableOption::IdModel, "loop") || - hasEnableOptionArgument(EnableOption::IdModel, "all")) { + if (!hasEnableOptionArgument(EnableOption::IdModel, "predicate_only") && + !hasEnableOptionArgument(EnableOption::IdModel, "index_only")) { opts.insert(IdModelEnableOption::Loop); } - // Loop requires ConsumerIndex, ProducerIndex, InlinePredicate and - // UnswitchPredicate - if (opts.find(IdModelEnableOption::Loop) != opts.end()) { - NVF_ERROR( - opts.find(IdModelEnableOption::ConsumerIndex) != opts.end(), - "ConsumerIndex required for Loop"); - NVF_ERROR( - opts.find(IdModelEnableOption::ProducerIndex) != opts.end(), - "ProducerIndex required for Loop"); - NVF_ERROR( - opts.find(IdModelEnableOption::InlinePredicate) != opts.end(), - "InlinePredicate required for Loop"); - NVF_ERROR( - opts.find(IdModelEnableOption::UnswitchPredicate) != opts.end(), - "UnswitchPredicate required for Loop"); - } - return opts; } diff --git a/python/nvfuser_direct/__init__.py b/python/nvfuser_direct/__init__.py index 317a780225f..34780436eac 100644 --- a/python/nvfuser_direct/__init__.py +++ b/python/nvfuser_direct/__init__.py @@ -366,10 +366,15 @@ def execute( # A copy of fusion is created after construction FusionExecutorCache # Delete the _fusion and reference the fusion inside FusionExecutorCache del self._fusion + + # Add "id_model" as a default enable option + default_enable_options = ["id_model"] + merged_enable_options = default_enable_options + _enable_options + return self.fec.execute( inputs, device=self._get_device_index(device), - _enable_options=_enable_options, + _enable_options=merged_enable_options, _disable_options=_disable_options, ) From f0c8bbe79de9453b06f3b1a428f14f69c1849a33 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 8 Jan 2026 18:49:32 -0800 Subject: [PATCH 2/6] fix test --- tests/cpp/test_indexing.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/cpp/test_indexing.cpp b/tests/cpp/test_indexing.cpp index 918e0e7059e..85cc3e12739 100644 --- a/tests/cpp/test_indexing.cpp +++ b/tests/cpp/test_indexing.cpp @@ -864,8 +864,9 @@ TEST_F(IndexingTest, Reshape) { // to provide the extent of the group. However, since everything // should be deterministic, string match should also work. return std::string( - "( ( ( ( ( i98 * 20 ) + ( ( i99 * 10 ) + i100 ) ) / 25 ) * 25 ) " - "+ ( ( ( i98 * 20 ) + ( ( i99 * 10 ) + i100 ) ) % 25 ) )"); + "( ( ( ( ( i114 * 20 ) + ( ( i115 * 10 ) + i116 ) ) / 25 ) * 25 " + ") " + "+ ( ( ( i114 * 20 ) + ( ( i115 * 10 ) + i116 ) ) % 25 ) )"); } default: return std::string(); From af1b4b79e20928400ec5717af7493d3f940c3353 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 8 Jan 2026 22:11:15 -0800 Subject: [PATCH 3/6] fix --- csrc/device_lower/id_model_options.h | 31 ++++------------------------ csrc/device_lower/lower2device.cpp | 14 ++++++------- csrc/id_model/utils.h | 4 ++++ csrc/options.cpp | 1 - csrc/options.h | 1 - 5 files changed, 14 insertions(+), 37 deletions(-) diff --git a/csrc/device_lower/id_model_options.h b/csrc/device_lower/id_model_options.h index 180ae91ddb0..594e49cdcea 100644 --- a/csrc/device_lower/id_model_options.h +++ b/csrc/device_lower/id_model_options.h @@ -16,8 +16,7 @@ namespace nvfuser { class IdModelOptions { public: IdModelOptions() - : build_id_model_(!isOptionDisabled(DisableOption::IdModel)), - consumer_index_( + : consumer_index_( isIdModelOptionEnabled(IdModelEnableOption::ConsumerIndex)), producer_index_( isIdModelOptionEnabled(IdModelEnableOption::ProducerIndex)), @@ -29,15 +28,6 @@ class IdModelOptions { ensureConsistency(); } - bool buildIdModel() const { - return build_id_model_; - } - - void setBuildIdModel(bool b) { - build_id_model_ = b; - ensureConsistency(); - } - bool buildTensorIndexer() const { return build_tensor_indexer_; } @@ -106,8 +96,7 @@ class IdModelOptions { auto bool2str = [](bool b) { return b ? "true" : "false"; }; std::stringstream ss; - ss << "build_id_model=" << bool2str(build_id_model_) - << ", build_tensor_indexer=" << bool2str(build_tensor_indexer_) + ss << "build_tensor_indexer=" << bool2str(build_tensor_indexer_) << ", consumer_index=" << bool2str(consumer_index_) << ", producer_index=" << bool2str(producer_index_) << ", inline_predicate=" << bool2str(inline_predicate_) @@ -118,23 +107,11 @@ class IdModelOptions { private: void ensureConsistency() { - if (!build_id_model_) { - build_tensor_indexer_ = false; - consumer_index_ = false; - producer_index_ = false; - inline_predicate_ = false; - unswitch_predicate_ = false; - loop_ = false; - } else { - // TensorIndexer is required if these options are enabled - build_tensor_indexer_ = build_tensor_indexer_ || consumer_index_ || - producer_index_ || inline_predicate_ || unswitch_predicate_ || loop_; - } + build_tensor_indexer_ = build_tensor_indexer_ || consumer_index_ || + producer_index_ || inline_predicate_ || unswitch_predicate_ || loop_; } private: - // Build IdModel - bool build_id_model_ = true; // Build TensorIndexer bool build_tensor_indexer_ = false; // Globally enables consumer indexing. diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index fe244d2d7da..7a5412248da 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -461,14 +461,12 @@ void GpuLower::analysis(Fusion* fusion) { // New IterDomains may be created, so it is expected that generated // code may use diffrent variable names - if (idModelOptions().buildIdModel()) { - info().set(std::make_unique( - fusion_, - /*build_graphs=*/true, - /*allow_self_mapping=*/false, - /*validate=*/false)); - info().idModel().validateAndPropagatePType(); - } + info().set(std::make_unique( + fusion_, + /*build_graphs=*/true, + /*allow_self_mapping=*/false, + /*validate=*/false)); + info().idModel().validateAndPropagatePType(); // Build what's refered to as the compute at map. This map contains the // mappings of all iteration domains across the fusion. There are three types diff --git a/csrc/id_model/utils.h b/csrc/id_model/utils.h index 1d57372f4f3..f45c6f4c851 100644 --- a/csrc/id_model/utils.h +++ b/csrc/id_model/utils.h @@ -28,6 +28,10 @@ enum class IdModelEnableOption { }; inline std::unordered_set getIdModelEnabledOptions() { + if (!isOptionEnabled(EnableOption::IdModel)) { + return {}; + } + std::unordered_set opts; if (!hasEnableOptionArgument(EnableOption::IdModel, "predicate_only")) { diff --git a/csrc/options.cpp b/csrc/options.cpp index 81564414f4c..0cc602bb5c8 100644 --- a/csrc/options.cpp +++ b/csrc/options.cpp @@ -217,7 +217,6 @@ const std::unordered_map& getDisableOptions() { {"greedy_scheduler", DisableOption::GreedyScheduler}, {"grouped_grid_welford_outer_opt", DisableOption::GroupedGridWelfordOuterOpt}, - {"id_model", DisableOption::IdModel}, {"index_hoist", DisableOption::IndexHoist}, {"magic_zero", DisableOption::MagicZero}, {"matmul_expr_eval", DisableOption::MatmulExprEval}, diff --git a/csrc/options.h b/csrc/options.h index 6f722c1227c..3f21c3d9392 100644 --- a/csrc/options.h +++ b/csrc/options.h @@ -147,7 +147,6 @@ enum class DisableOption { GreedyScheduler, //! Disable the greedy scheduler GroupedGridWelfordOuterOpt, //! Disable use of outer-optimized //! grouped grid welford kernel - IdModel, //! Disable IdModel IndexHoist, //! Disable index hoisting MagicZero, //! Disable nvfuser_zero MatmulExprEval, //! Disable ATen evaluation for the entire fusion containing From e4b6f22063696225a96a131fdf88f53a40f82e6f Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 9 Jan 2026 00:31:28 -0800 Subject: [PATCH 4/6] cleanup --- csrc/predicate_compute.cpp | 39 ++++++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/csrc/predicate_compute.cpp b/csrc/predicate_compute.cpp index f352b634d0f..b81ee621c09 100644 --- a/csrc/predicate_compute.cpp +++ b/csrc/predicate_compute.cpp @@ -38,11 +38,21 @@ bool isOutputLocal(const Expr* expr) { }); } +IterDomain* getConcreteMappedId(IterDomain* id) { + NVF_ERROR(GpuLower::hasCurrent()); + return GpuLower::current() + ->info() + .idModel() + .idGraph(IdMappingMode::EXACT) + .toGroup(id) + ->front() + ->as(); +} + } // namespace bool ParallelizedDomainPredicate::PredicateInfo::addDomain(IterDomain* id) { - auto concrete_id = GpuLower::current()->info().caMap().getConcreteMappedID( - id, IdMappingMode::EXACT); + auto concrete_id = getConcreteMappedId(id); if (std::find(ids_.begin(), ids_.end(), concrete_id) == ids_.end()) { ids_.push_back(concrete_id); return true; @@ -59,10 +69,7 @@ Val* ParallelizedDomainPredicate::PredicateInfo::getPredicate() const { for (const auto& pred_id : ids()) { // Just sanity check that pred_id is concrete - NVF_ERROR( - pred_id == - GpuLower::current()->info().caMap().getConcreteMappedID( - pred_id, IdMappingMode::EXACT)); + NVF_ERROR(pred_id == getConcreteMappedId(pred_id)); auto new_pred = SimplifyingIrBuilder::ltExpr(index, pred_id->extent()); pred = SimplifyingIrBuilder::logicalAndExpr(pred, new_pred); } @@ -290,8 +297,11 @@ ParallelizedDomainPredicate::getPredicateMap( tv->getLoopDomain().begin(), tv->getLoopDomain().end(), [&](auto tv_id) { - return gpu_lower->info().caMap().areMapped( - loop_id, tv_id, IdMappingMode::EXACT); + return gpu_lower->info() + .idModel() + .idGraph(IdMappingMode::EXACT) + .disjointValSets() + .strictAreMapped(loop_id, tv_id); }); if (it == tv->getLoopDomain().end()) { continue; @@ -450,12 +460,10 @@ UnswitchPredicateKey::UnswitchPredicateKey( } // Find the corresponding concrete id for each parallel type - for (auto consumer_loop : parallelized_consumer_loop_ids) { - auto pt = consumer_loop->getParallelType(); - auto concrete_loop = - GpuLower::current()->info().caMap().getConcreteMappedID( - consumer_loop, IdMappingMode::EXACT); - parallel_concrete_ids_.at(pt) = concrete_loop; + for (auto consumer_loop_id : parallelized_consumer_loop_ids) { + auto pt = consumer_loop_id->getParallelType(); + auto concrete_loop_id = getConcreteMappedId(consumer_loop_id); + parallel_concrete_ids_.at(pt) = concrete_loop_id; } } @@ -1015,8 +1023,7 @@ void UnswitchPredicate::predicateOn(Expr* tv_expr) { bool first_key_set = false; for (auto root_id : root_ids) { - auto concrete_root_id = gpu_lower->info().caMap().getConcreteMappedID( - root_id, IdMappingMode::EXACT); + auto concrete_root_id = getConcreteMappedId(root_id); if (root_id->isBroadcast()) { continue; From 6fb3cc662ca440d145db1c8089617fbef57d1f8a Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 9 Jan 2026 00:38:03 -0800 Subject: [PATCH 5/6] test fix --- tests/python/direct/test_python_direct.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/python/direct/test_python_direct.py b/tests/python/direct/test_python_direct.py index 57b1c594778..de7ac506ab7 100644 --- a/tests/python/direct/test_python_direct.py +++ b/tests/python/direct/test_python_direct.py @@ -229,22 +229,20 @@ def test_fusion_execution_cache(): i2 = i0 / 8; nvfuser_index_t i3; i3 = i0 % 8; - nvfuser_index_t i4; - i4 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); - if ((i4 < 64)) { + if ((((nvfuser_index_t)threadIdx.x) < 64)) { Array T4; T4[0] = 0; T4[0] - = T1[((((T1.alloc_stride[0LL] * i1) + (T1.alloc_stride[1LL] * i2)) + (T1.alloc_stride[2LL] * i3)) + ((4 * T1.alloc_stride[0LL]) * ((nvfuser_index_t)blockIdx.x)))]; + = T1[(((T1.alloc_stride[0LL] * i1) + (T1.alloc_stride[1LL] * i2)) + (T1.alloc_stride[2LL] * i3))]; Array T3; T3[0] = 0; T3[0] - = T0[((((T0.alloc_stride[0LL] * i1) + (T0.alloc_stride[1LL] * i2)) + (T0.alloc_stride[2LL] * i3)) + ((4 * T0.alloc_stride[0LL]) * ((nvfuser_index_t)blockIdx.x)))]; + = T0[(((T0.alloc_stride[0LL] * i1) + (T0.alloc_stride[1LL] * i2)) + (T0.alloc_stride[2LL] * i3))]; Array T5; T5[0] = T3[0] + T4[0]; - T2[i4] + T2[((nvfuser_index_t)threadIdx.x)] = T5[0]; } }\n""" From 32362d3df2906af3f58a1349d5be4c6c37c1fef7 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 12 Jan 2026 21:50:54 -0800 Subject: [PATCH 6/6] cleanup --- csrc/device_lower/id_model_options.h | 31 +++++++++++++++++--- csrc/device_lower/lower2device.cpp | 14 +++++---- csrc/id_model/utils.h | 43 ++++++++++++++++++++++------ csrc/options.cpp | 1 + csrc/options.h | 1 + csrc/predicate_compute.cpp | 39 +++++++++++-------------- 6 files changed, 88 insertions(+), 41 deletions(-) diff --git a/csrc/device_lower/id_model_options.h b/csrc/device_lower/id_model_options.h index 594e49cdcea..180ae91ddb0 100644 --- a/csrc/device_lower/id_model_options.h +++ b/csrc/device_lower/id_model_options.h @@ -16,7 +16,8 @@ namespace nvfuser { class IdModelOptions { public: IdModelOptions() - : consumer_index_( + : build_id_model_(!isOptionDisabled(DisableOption::IdModel)), + consumer_index_( isIdModelOptionEnabled(IdModelEnableOption::ConsumerIndex)), producer_index_( isIdModelOptionEnabled(IdModelEnableOption::ProducerIndex)), @@ -28,6 +29,15 @@ class IdModelOptions { ensureConsistency(); } + bool buildIdModel() const { + return build_id_model_; + } + + void setBuildIdModel(bool b) { + build_id_model_ = b; + ensureConsistency(); + } + bool buildTensorIndexer() const { return build_tensor_indexer_; } @@ -96,7 +106,8 @@ class IdModelOptions { auto bool2str = [](bool b) { return b ? "true" : "false"; }; std::stringstream ss; - ss << "build_tensor_indexer=" << bool2str(build_tensor_indexer_) + ss << "build_id_model=" << bool2str(build_id_model_) + << ", build_tensor_indexer=" << bool2str(build_tensor_indexer_) << ", consumer_index=" << bool2str(consumer_index_) << ", producer_index=" << bool2str(producer_index_) << ", inline_predicate=" << bool2str(inline_predicate_) @@ -107,11 +118,23 @@ class IdModelOptions { private: void ensureConsistency() { - build_tensor_indexer_ = build_tensor_indexer_ || consumer_index_ || - producer_index_ || inline_predicate_ || unswitch_predicate_ || loop_; + if (!build_id_model_) { + build_tensor_indexer_ = false; + consumer_index_ = false; + producer_index_ = false; + inline_predicate_ = false; + unswitch_predicate_ = false; + loop_ = false; + } else { + // TensorIndexer is required if these options are enabled + build_tensor_indexer_ = build_tensor_indexer_ || consumer_index_ || + producer_index_ || inline_predicate_ || unswitch_predicate_ || loop_; + } } private: + // Build IdModel + bool build_id_model_ = true; // Build TensorIndexer bool build_tensor_indexer_ = false; // Globally enables consumer indexing. diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index 7a5412248da..fe244d2d7da 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -461,12 +461,14 @@ void GpuLower::analysis(Fusion* fusion) { // New IterDomains may be created, so it is expected that generated // code may use diffrent variable names - info().set(std::make_unique( - fusion_, - /*build_graphs=*/true, - /*allow_self_mapping=*/false, - /*validate=*/false)); - info().idModel().validateAndPropagatePType(); + if (idModelOptions().buildIdModel()) { + info().set(std::make_unique( + fusion_, + /*build_graphs=*/true, + /*allow_self_mapping=*/false, + /*validate=*/false)); + info().idModel().validateAndPropagatePType(); + } // Build what's refered to as the compute at map. This map contains the // mappings of all iteration domains across the fusion. There are three types diff --git a/csrc/id_model/utils.h b/csrc/id_model/utils.h index f45c6f4c851..9750bbd592d 100644 --- a/csrc/id_model/utils.h +++ b/csrc/id_model/utils.h @@ -28,27 +28,54 @@ enum class IdModelEnableOption { }; inline std::unordered_set getIdModelEnabledOptions() { - if (!isOptionEnabled(EnableOption::IdModel)) { - return {}; - } - std::unordered_set opts; - if (!hasEnableOptionArgument(EnableOption::IdModel, "predicate_only")) { + if (hasEnableOptionArgument(EnableOption::IdModel, "consumer_index") || + hasEnableOptionArgument(EnableOption::IdModel, "index") || + hasEnableOptionArgument(EnableOption::IdModel, "all")) { opts.insert(IdModelEnableOption::ConsumerIndex); + } + + if (hasEnableOptionArgument(EnableOption::IdModel, "producer_index") || + hasEnableOptionArgument(EnableOption::IdModel, "index") || + hasEnableOptionArgument(EnableOption::IdModel, "all")) { opts.insert(IdModelEnableOption::ProducerIndex); } - if (!hasEnableOptionArgument(EnableOption::IdModel, "index_only")) { + if (hasEnableOptionArgument(EnableOption::IdModel, "inline_predicate") || + hasEnableOptionArgument(EnableOption::IdModel, "predicate") || + hasEnableOptionArgument(EnableOption::IdModel, "all")) { opts.insert(IdModelEnableOption::InlinePredicate); + } + + if (hasEnableOptionArgument(EnableOption::IdModel, "unswitch_predicate") || + hasEnableOptionArgument(EnableOption::IdModel, "predicate") || + hasEnableOptionArgument(EnableOption::IdModel, "all")) { opts.insert(IdModelEnableOption::UnswitchPredicate); } - if (!hasEnableOptionArgument(EnableOption::IdModel, "predicate_only") && - !hasEnableOptionArgument(EnableOption::IdModel, "index_only")) { + if (hasEnableOptionArgument(EnableOption::IdModel, "loop") || + hasEnableOptionArgument(EnableOption::IdModel, "all")) { opts.insert(IdModelEnableOption::Loop); } + // Loop requires ConsumerIndex, ProducerIndex, InlinePredicate and + // UnswitchPredicate + if (opts.find(IdModelEnableOption::Loop) != opts.end()) { + NVF_ERROR( + opts.find(IdModelEnableOption::ConsumerIndex) != opts.end(), + "ConsumerIndex required for Loop"); + NVF_ERROR( + opts.find(IdModelEnableOption::ProducerIndex) != opts.end(), + "ProducerIndex required for Loop"); + NVF_ERROR( + opts.find(IdModelEnableOption::InlinePredicate) != opts.end(), + "InlinePredicate required for Loop"); + NVF_ERROR( + opts.find(IdModelEnableOption::UnswitchPredicate) != opts.end(), + "UnswitchPredicate required for Loop"); + } + return opts; } diff --git a/csrc/options.cpp b/csrc/options.cpp index 0cc602bb5c8..81564414f4c 100644 --- a/csrc/options.cpp +++ b/csrc/options.cpp @@ -217,6 +217,7 @@ const std::unordered_map& getDisableOptions() { {"greedy_scheduler", DisableOption::GreedyScheduler}, {"grouped_grid_welford_outer_opt", DisableOption::GroupedGridWelfordOuterOpt}, + {"id_model", DisableOption::IdModel}, {"index_hoist", DisableOption::IndexHoist}, {"magic_zero", DisableOption::MagicZero}, {"matmul_expr_eval", DisableOption::MatmulExprEval}, diff --git a/csrc/options.h b/csrc/options.h index 3f21c3d9392..6f722c1227c 100644 --- a/csrc/options.h +++ b/csrc/options.h @@ -147,6 +147,7 @@ enum class DisableOption { GreedyScheduler, //! Disable the greedy scheduler GroupedGridWelfordOuterOpt, //! Disable use of outer-optimized //! grouped grid welford kernel + IdModel, //! Disable IdModel IndexHoist, //! Disable index hoisting MagicZero, //! Disable nvfuser_zero MatmulExprEval, //! Disable ATen evaluation for the entire fusion containing diff --git a/csrc/predicate_compute.cpp b/csrc/predicate_compute.cpp index b81ee621c09..f352b634d0f 100644 --- a/csrc/predicate_compute.cpp +++ b/csrc/predicate_compute.cpp @@ -38,21 +38,11 @@ bool isOutputLocal(const Expr* expr) { }); } -IterDomain* getConcreteMappedId(IterDomain* id) { - NVF_ERROR(GpuLower::hasCurrent()); - return GpuLower::current() - ->info() - .idModel() - .idGraph(IdMappingMode::EXACT) - .toGroup(id) - ->front() - ->as(); -} - } // namespace bool ParallelizedDomainPredicate::PredicateInfo::addDomain(IterDomain* id) { - auto concrete_id = getConcreteMappedId(id); + auto concrete_id = GpuLower::current()->info().caMap().getConcreteMappedID( + id, IdMappingMode::EXACT); if (std::find(ids_.begin(), ids_.end(), concrete_id) == ids_.end()) { ids_.push_back(concrete_id); return true; @@ -69,7 +59,10 @@ Val* ParallelizedDomainPredicate::PredicateInfo::getPredicate() const { for (const auto& pred_id : ids()) { // Just sanity check that pred_id is concrete - NVF_ERROR(pred_id == getConcreteMappedId(pred_id)); + NVF_ERROR( + pred_id == + GpuLower::current()->info().caMap().getConcreteMappedID( + pred_id, IdMappingMode::EXACT)); auto new_pred = SimplifyingIrBuilder::ltExpr(index, pred_id->extent()); pred = SimplifyingIrBuilder::logicalAndExpr(pred, new_pred); } @@ -297,11 +290,8 @@ ParallelizedDomainPredicate::getPredicateMap( tv->getLoopDomain().begin(), tv->getLoopDomain().end(), [&](auto tv_id) { - return gpu_lower->info() - .idModel() - .idGraph(IdMappingMode::EXACT) - .disjointValSets() - .strictAreMapped(loop_id, tv_id); + return gpu_lower->info().caMap().areMapped( + loop_id, tv_id, IdMappingMode::EXACT); }); if (it == tv->getLoopDomain().end()) { continue; @@ -460,10 +450,12 @@ UnswitchPredicateKey::UnswitchPredicateKey( } // Find the corresponding concrete id for each parallel type - for (auto consumer_loop_id : parallelized_consumer_loop_ids) { - auto pt = consumer_loop_id->getParallelType(); - auto concrete_loop_id = getConcreteMappedId(consumer_loop_id); - parallel_concrete_ids_.at(pt) = concrete_loop_id; + for (auto consumer_loop : parallelized_consumer_loop_ids) { + auto pt = consumer_loop->getParallelType(); + auto concrete_loop = + GpuLower::current()->info().caMap().getConcreteMappedID( + consumer_loop, IdMappingMode::EXACT); + parallel_concrete_ids_.at(pt) = concrete_loop; } } @@ -1023,7 +1015,8 @@ void UnswitchPredicate::predicateOn(Expr* tv_expr) { bool first_key_set = false; for (auto root_id : root_ids) { - auto concrete_root_id = getConcreteMappedId(root_id); + auto concrete_root_id = gpu_lower->info().caMap().getConcreteMappedID( + root_id, IdMappingMode::EXACT); if (root_id->isBroadcast()) { continue;