31 changes: 4 additions & 27 deletions csrc/device_lower/id_model_options.h
@@ -16,8 +16,7 @@ namespace nvfuser {
class IdModelOptions {
public:
IdModelOptions()
: build_id_model_(!isOptionDisabled(DisableOption::IdModel)),
consumer_index_(
: consumer_index_(
isIdModelOptionEnabled(IdModelEnableOption::ConsumerIndex)),
producer_index_(
isIdModelOptionEnabled(IdModelEnableOption::ProducerIndex)),
@@ -29,15 +28,6 @@ class IdModelOptions {
ensureConsistency();
}

bool buildIdModel() const {
return build_id_model_;
}

void setBuildIdModel(bool b) {
build_id_model_ = b;
ensureConsistency();
}

bool buildTensorIndexer() const {
return build_tensor_indexer_;
}
@@ -106,8 +96,7 @@ class IdModelOptions {
auto bool2str = [](bool b) { return b ? "true" : "false"; };

std::stringstream ss;
ss << "build_id_model=" << bool2str(build_id_model_)
<< ", build_tensor_indexer=" << bool2str(build_tensor_indexer_)
ss << "build_tensor_indexer=" << bool2str(build_tensor_indexer_)
<< ", consumer_index=" << bool2str(consumer_index_)
<< ", producer_index=" << bool2str(producer_index_)
<< ", inline_predicate=" << bool2str(inline_predicate_)
@@ -118,23 +107,11 @@

private:
void ensureConsistency() {
if (!build_id_model_) {
build_tensor_indexer_ = false;
consumer_index_ = false;
producer_index_ = false;
inline_predicate_ = false;
unswitch_predicate_ = false;
loop_ = false;
} else {
// TensorIndexer is required if these options are enabled
build_tensor_indexer_ = build_tensor_indexer_ || consumer_index_ ||
producer_index_ || inline_predicate_ || unswitch_predicate_ || loop_;
}
build_tensor_indexer_ = build_tensor_indexer_ || consumer_index_ ||
producer_index_ || inline_predicate_ || unswitch_predicate_ || loop_;
}

private:
// Build IdModel
bool build_id_model_ = true;
// Build TensorIndexer
bool build_tensor_indexer_ = false;
// Globally enables consumer indexing.
14 changes: 6 additions & 8 deletions csrc/device_lower/lower2device.cpp
@@ -461,14 +461,12 @@ void GpuLower::analysis(Fusion* fusion) {

// New IterDomains may be created, so it is expected that generated
// code may use different variable names
if (idModelOptions().buildIdModel()) {
info().set(std::make_unique<IdModel>(
fusion_,
/*build_graphs=*/true,
/*allow_self_mapping=*/false,
/*validate=*/false));
info().idModel().validateAndPropagatePType();
}
info().set(std::make_unique<IdModel>(
fusion_,
/*build_graphs=*/true,
/*allow_self_mapping=*/false,
/*validate=*/false));
info().idModel().validateAndPropagatePType();

// Build what's referred to as the compute at map. This map contains the
// mappings of all iteration domains across the fusion. There are three types
43 changes: 8 additions & 35 deletions csrc/id_model/utils.h
@@ -28,54 +28,27 @@ enum class IdModelEnableOption {
};

inline std::unordered_set<IdModelEnableOption> getIdModelEnabledOptions() {
if (!isOptionEnabled(EnableOption::IdModel)) {
return {};
}

std::unordered_set<IdModelEnableOption> opts;

if (hasEnableOptionArgument(EnableOption::IdModel, "consumer_index") ||
hasEnableOptionArgument(EnableOption::IdModel, "index") ||
hasEnableOptionArgument(EnableOption::IdModel, "all")) {
if (!hasEnableOptionArgument(EnableOption::IdModel, "predicate_only")) {
opts.insert(IdModelEnableOption::ConsumerIndex);
}

if (hasEnableOptionArgument(EnableOption::IdModel, "producer_index") ||
hasEnableOptionArgument(EnableOption::IdModel, "index") ||
hasEnableOptionArgument(EnableOption::IdModel, "all")) {
opts.insert(IdModelEnableOption::ProducerIndex);
}

if (hasEnableOptionArgument(EnableOption::IdModel, "inline_predicate") ||
hasEnableOptionArgument(EnableOption::IdModel, "predicate") ||
hasEnableOptionArgument(EnableOption::IdModel, "all")) {
if (!hasEnableOptionArgument(EnableOption::IdModel, "index_only")) {
opts.insert(IdModelEnableOption::InlinePredicate);
}

if (hasEnableOptionArgument(EnableOption::IdModel, "unswitch_predicate") ||
hasEnableOptionArgument(EnableOption::IdModel, "predicate") ||
hasEnableOptionArgument(EnableOption::IdModel, "all")) {
opts.insert(IdModelEnableOption::UnswitchPredicate);
}

if (hasEnableOptionArgument(EnableOption::IdModel, "loop") ||
hasEnableOptionArgument(EnableOption::IdModel, "all")) {
if (!hasEnableOptionArgument(EnableOption::IdModel, "predicate_only") &&
!hasEnableOptionArgument(EnableOption::IdModel, "index_only")) {
opts.insert(IdModelEnableOption::Loop);
}

// Loop requires ConsumerIndex, ProducerIndex, InlinePredicate and
// UnswitchPredicate
if (opts.find(IdModelEnableOption::Loop) != opts.end()) {
NVF_ERROR(
opts.find(IdModelEnableOption::ConsumerIndex) != opts.end(),
"ConsumerIndex required for Loop");
NVF_ERROR(
opts.find(IdModelEnableOption::ProducerIndex) != opts.end(),
"ProducerIndex required for Loop");
NVF_ERROR(
opts.find(IdModelEnableOption::InlinePredicate) != opts.end(),
"InlinePredicate required for Loop");
NVF_ERROR(
opts.find(IdModelEnableOption::UnswitchPredicate) != opts.end(),
"UnswitchPredicate required for Loop");
}

return opts;
}
Contributor comment:
The new implementation changes the semantics of EnableOption::IdModel arguments. The old code supported explicit arguments like "all", "index", "predicate", "consumer_index", "producer_index", etc. The new code only recognizes "predicate_only" and "index_only" as restrictive flags.

This breaks existing usage like NVFUSER_ENABLE=id_model(all) found in tests/python/direct/test_with_id_model_indexer.py:183. Consider either:

  1. Updating the test to use NVFUSER_ENABLE=id_model (no arguments, which enables everything by default)
  2. Adding backward compatibility for the "all" argument
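
A minimal sketch of option 1, assuming the test only needs NVFUSER_ENABLE set before the fusion compiles; the fixture and the use of pytest's monkeypatch are illustrative stand-ins, not the repository's existing set_env helper:

import pytest


@pytest.fixture
def enable_id_model(monkeypatch):
    # Under the new semantics the bare flag enables everything by default,
    # so the removed "(all)" argument is no longer needed.
    monkeypatch.setenv("NVFUSER_ENABLE", "id_model")
    yield


def test_with_id_model_indexer(enable_id_model):
    ...  # existing test body unchanged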


1 change: 0 additions & 1 deletion csrc/options.cpp
@@ -217,7 +217,6 @@ const std::unordered_map<std::string, DisableOption>& getDisableOptions() {
{"greedy_scheduler", DisableOption::GreedyScheduler},
{"grouped_grid_welford_outer_opt",
DisableOption::GroupedGridWelfordOuterOpt},
{"id_model", DisableOption::IdModel},
{"index_hoist", DisableOption::IndexHoist},
{"magic_zero", DisableOption::MagicZero},
{"matmul_expr_eval", DisableOption::MatmulExprEval},
1 change: 0 additions & 1 deletion csrc/options.h
@@ -147,7 +147,6 @@ enum class DisableOption {
GreedyScheduler, //! Disable the greedy scheduler
GroupedGridWelfordOuterOpt, //! Disable use of outer-optimized
//! grouped grid welford kernel
IdModel, //! Disable IdModel
IndexHoist, //! Disable index hoisting
MagicZero, //! Disable nvfuser_zero
MatmulExprEval, //! Disable ATen evaluation for the entire fusion containing
39 changes: 23 additions & 16 deletions csrc/predicate_compute.cpp
@@ -38,11 +38,21 @@ bool isOutputLocal(const Expr* expr) {
});
}

IterDomain* getConcreteMappedId(IterDomain* id) {
NVF_ERROR(GpuLower::hasCurrent());
return GpuLower::current()
->info()
.idModel()
.idGraph(IdMappingMode::EXACT)
.toGroup(id)
->front()
->as<IterDomain>();
}

} // namespace

bool ParallelizedDomainPredicate::PredicateInfo::addDomain(IterDomain* id) {
auto concrete_id = GpuLower::current()->info().caMap().getConcreteMappedID(
id, IdMappingMode::EXACT);
auto concrete_id = getConcreteMappedId(id);
if (std::find(ids_.begin(), ids_.end(), concrete_id) == ids_.end()) {
ids_.push_back(concrete_id);
return true;
@@ -59,10 +69,7 @@ Val* ParallelizedDomainPredicate::PredicateInfo::getPredicate() const {

for (const auto& pred_id : ids()) {
// Just sanity check that pred_id is concrete
NVF_ERROR(
pred_id ==
GpuLower::current()->info().caMap().getConcreteMappedID(
pred_id, IdMappingMode::EXACT));
NVF_ERROR(pred_id == getConcreteMappedId(pred_id));
auto new_pred = SimplifyingIrBuilder::ltExpr(index, pred_id->extent());
pred = SimplifyingIrBuilder::logicalAndExpr(pred, new_pred);
}
@@ -290,8 +297,11 @@ ParallelizedDomainPredicate::getPredicateMap(
tv->getLoopDomain().end(),
[&](auto tv_id) {
return gpu_lower->info().caMap().areMapped(
loop_id, tv_id, IdMappingMode::EXACT);
return gpu_lower->info()
.idModel()
.idGraph(IdMappingMode::EXACT)
.disjointValSets()
.strictAreMapped(loop_id, tv_id);
});
if (it == tv->getLoopDomain().end()) {
continue;
@@ -450,12 +460,10 @@ UnswitchPredicateKey::UnswitchPredicateKey(
}

// Find the corresponding concrete id for each parallel type
for (auto consumer_loop : parallelized_consumer_loop_ids) {
auto pt = consumer_loop->getParallelType();
auto concrete_loop =
GpuLower::current()->info().caMap().getConcreteMappedID(
consumer_loop, IdMappingMode::EXACT);
parallel_concrete_ids_.at(pt) = concrete_loop;
for (auto consumer_loop_id : parallelized_consumer_loop_ids) {
auto pt = consumer_loop_id->getParallelType();
auto concrete_loop_id = getConcreteMappedId(consumer_loop_id);
parallel_concrete_ids_.at(pt) = concrete_loop_id;
}
}

@@ -1015,8 +1023,7 @@ void UnswitchPredicate::predicateOn(Expr* tv_expr) {
bool first_key_set = false;

for (auto root_id : root_ids) {
auto concrete_root_id = gpu_lower->info().caMap().getConcreteMappedID(
root_id, IdMappingMode::EXACT);
auto concrete_root_id = getConcreteMappedId(root_id);

if (root_id->isBroadcast()) {
continue;
7 changes: 6 additions & 1 deletion python/nvfuser_direct/__init__.py
@@ -366,10 +366,15 @@ def execute(
# A copy of fusion is created after construction FusionExecutorCache
# Delete the _fusion and reference the fusion inside FusionExecutorCache
del self._fusion

# Add "id_model" as a default enable option
default_enable_options = ["id_model"]
merged_enable_options = default_enable_options + _enable_options
Comment on lines +370 to +372 (Contributor):
Critical Logic Issue: Adding "id_model" to enable options without arguments has no effect.

The IdModel system requires explicit sub-options to be passed as arguments (e.g., "all", "index", "consumer_index", "producer_index", etc.). These are checked via hasEnableOptionArgument(EnableOption::IdModel, "arg") in csrc/id_model/utils.h:30-80.

When you add "id_model" here without arguments, the Python bindings call EnableOptionsGuard::getCurOptions().set(opt.value()) with an empty vector (see python/python_direct/runtime.cpp:284). This means no IdModel features are actually enabled.

Evidence from codebase:

  1. C++ tests use: EnableOption::IdModel, {"all"} (tests/cpp/utils.cpp)
  2. Python tests that need IdModel use: set_env(NVFUSER_ENABLE="id_model(all)") (tests/python/direct/test_with_id_model_indexer.py:183)

Proposed solution:
The Python _enable_options API doesn't support passing arguments. You need to either:

  1. Modify the Python bindings to support syntax like ["id_model:all"] or similar
  2. Use environment variable NVFUSER_ENABLE="id_model(all)" in test setup
  3. Add a new Python API parameter specifically for IdModel options

Without this fix, TensorIndexer will NOT be enabled with python_direct tests as the PR title claims.
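
A purely illustrative sketch of proposed solution 1: a colon-delimited syntax such as "id_model:all" parsed on the Python side into a (name, args) pair. The pair shape is hypothetical, and the C++ bindings would still need a matching change to forward the argument list:

def _parse_enable_option(opt: str) -> tuple[str, list[str]]:
    # "id_model:all" -> ("id_model", ["all"]); "id_model" -> ("id_model", [])
    name, _, args = opt.partition(":")
    return name, args.split(",") if args else []

assert _parse_enable_option("id_model:all") == ("id_model", ["all"])
assert _parse_enable_option("id_model") == ("id_model", [])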

Comment on lines +371 to +372 (Contributor):
[P2] Potential duplicate options if user passes "id_model" in _enable_options. Consider checking if "id_model" is already in _enable_options before adding it, or use a set to avoid duplicates. While this doesn't cause functional issues (the second set() call just overwrites), it's cleaner to avoid duplicates.


Comment on lines +370 to +372 (Contributor):
[P2] No opt-out mechanism provided. Users cannot disable id_model when using execute() even if they want to. Consider either: (1) checking if "id_model" is in _disable_options and skipping the default, or (2) providing a parameter to control default options. This limits flexibility for users who may need to test without id_model or work around potential id_model bugs.


Comment on lines +370 to +372 (Contributor):
Adding "id_model" as a default enable option means users cannot opt-out of this behavior through the Python API. Consider whether this is intentional or if there should be a mechanism to allow users to disable id_model for testing/debugging purposes.

Potential implications:

  • Users performing A/B testing between old and new indexing cannot easily do so
  • Debugging issues specific to id_model becomes harder without a way to disable it
  • This is a behavioral change that affects all python_direct users

If this is intentional (to force id_model adoption), consider documenting this breaking change clearly. If not, consider checking if "id_model" is already in _enable_options or providing an opt-out mechanism.
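
A minimal sketch covering both of the preceding suggestions, skipping the default when the caller already passed or explicitly disabled "id_model"; the helper name is hypothetical and it assumes the option lists are plain lists of strings. Note that this PR also drops DisableOption::IdModel, so the disable-list check would only apply if an equivalent opt-out is kept:

def _merge_enable_options(_enable_options, _disable_options):
    # Respect an explicit opt-out and avoid duplicating a user-supplied entry.
    if "id_model" in _enable_options or "id_model" in _disable_options:
        return list(_enable_options)
    return ["id_model"] + list(_enable_options)

assert _merge_enable_options([], []) == ["id_model"]
assert _merge_enable_options(["id_model"], []) == ["id_model"]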



return self.fec.execute(
inputs,
device=self._get_device_index(device),
_enable_options=_enable_options,
_enable_options=merged_enable_options,
_disable_options=_disable_options,
)

5 changes: 3 additions & 2 deletions tests/cpp/test_indexing.cpp
@@ -864,8 +864,9 @@ TEST_F(IndexingTest, Reshape) {
// to provide the extent of the group. However, since everything
// should be deterministic, string match should also work.
return std::string(
"( ( ( ( ( i98 * 20 ) + ( ( i99 * 10 ) + i100 ) ) / 25 ) * 25 ) "
"+ ( ( ( i98 * 20 ) + ( ( i99 * 10 ) + i100 ) ) % 25 ) )");
"( ( ( ( ( i114 * 20 ) + ( ( i115 * 10 ) + i116 ) ) / 25 ) * 25 "
") "
"+ ( ( ( i114 * 20 ) + ( ( i115 * 10 ) + i116 ) ) % 25 ) )");
}
default:
return std::string();
10 changes: 4 additions & 6 deletions tests/python/direct/test_python_direct.py
@@ -229,22 +229,20 @@ def test_fusion_execution_cache():
i2 = i0 / 8;
nvfuser_index_t i3;
i3 = i0 % 8;
nvfuser_index_t i4;
i4 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x));
if ((i4 < 64)) {
if ((((nvfuser_index_t)threadIdx.x) < 64)) {
Array<float, 1, 1> T4;
T4[0] = 0;
T4[0]
= T1[((((T1.alloc_stride[0LL] * i1) + (T1.alloc_stride[1LL] * i2)) + (T1.alloc_stride[2LL] * i3)) + ((4 * T1.alloc_stride[0LL]) * ((nvfuser_index_t)blockIdx.x)))];
= T1[(((T1.alloc_stride[0LL] * i1) + (T1.alloc_stride[1LL] * i2)) + (T1.alloc_stride[2LL] * i3))];
Array<float, 1, 1> T3;
T3[0] = 0;
T3[0]
= T0[((((T0.alloc_stride[0LL] * i1) + (T0.alloc_stride[1LL] * i2)) + (T0.alloc_stride[2LL] * i3)) + ((4 * T0.alloc_stride[0LL]) * ((nvfuser_index_t)blockIdx.x)))];
= T0[(((T0.alloc_stride[0LL] * i1) + (T0.alloc_stride[1LL] * i2)) + (T0.alloc_stride[2LL] * i3))];
Array<float, 1, 1> T5;
T5[0]
= T3[0]
+ T4[0];
T2[i4]
T2[((nvfuser_index_t)threadIdx.x)]
= T5[0];
}
}\n"""