31 changes: 4 additions & 27 deletions csrc/device_lower/id_model_options.h
@@ -16,8 +16,7 @@ namespace nvfuser {
class IdModelOptions {
public:
IdModelOptions()
: build_id_model_(!isOptionDisabled(DisableOption::IdModel)),
consumer_index_(
: consumer_index_(
isIdModelOptionEnabled(IdModelEnableOption::ConsumerIndex)),
producer_index_(
isIdModelOptionEnabled(IdModelEnableOption::ProducerIndex)),
@@ -29,15 +28,6 @@ class IdModelOptions {
ensureConsistency();
}

bool buildIdModel() const {
return build_id_model_;
}

void setBuildIdModel(bool b) {
build_id_model_ = b;
ensureConsistency();
}

bool buildTensorIndexer() const {
return build_tensor_indexer_;
}
@@ -106,8 +96,7 @@ class IdModelOptions {
auto bool2str = [](bool b) { return b ? "true" : "false"; };

std::stringstream ss;
ss << "build_id_model=" << bool2str(build_id_model_)
<< ", build_tensor_indexer=" << bool2str(build_tensor_indexer_)
ss << "build_tensor_indexer=" << bool2str(build_tensor_indexer_)
<< ", consumer_index=" << bool2str(consumer_index_)
<< ", producer_index=" << bool2str(producer_index_)
<< ", inline_predicate=" << bool2str(inline_predicate_)
@@ -118,23 +107,11 @@

private:
void ensureConsistency() {
if (!build_id_model_) {
build_tensor_indexer_ = false;
consumer_index_ = false;
producer_index_ = false;
inline_predicate_ = false;
unswitch_predicate_ = false;
loop_ = false;
} else {
// TensorIndexer is required if these options are enabled
build_tensor_indexer_ = build_tensor_indexer_ || consumer_index_ ||
producer_index_ || inline_predicate_ || unswitch_predicate_ || loop_;
}
build_tensor_indexer_ = build_tensor_indexer_ || consumer_index_ ||
producer_index_ || inline_predicate_ || unswitch_predicate_ || loop_;
}

private:
// Build IdModel
bool build_id_model_ = true;
// Build TensorIndexer
bool build_tensor_indexer_ = false;
// Globally enables consumer indexing.
14 changes: 6 additions & 8 deletions csrc/device_lower/lower2device.cpp
@@ -461,14 +461,12 @@ void GpuLower::analysis(Fusion* fusion) {

// New IterDomains may be created, so it is expected that generated
// code may use diffrent variable names
if (idModelOptions().buildIdModel()) {
info().set(std::make_unique<IdModel>(
fusion_,
/*build_graphs=*/true,
/*allow_self_mapping=*/false,
/*validate=*/false));
info().idModel().validateAndPropagatePType();
}
info().set(std::make_unique<IdModel>(
fusion_,
/*build_graphs=*/true,
/*allow_self_mapping=*/false,
/*validate=*/false));
info().idModel().validateAndPropagatePType();

// Build what's refered to as the compute at map. This map contains the
// mappings of all iteration domains across the fusion. There are three types
43 changes: 8 additions & 35 deletions csrc/id_model/utils.h
@@ -28,54 +28,27 @@ enum class IdModelEnableOption {
};

inline std::unordered_set<IdModelEnableOption> getIdModelEnabledOptions() {
if (!isOptionEnabled(EnableOption::IdModel)) {
return {};
}

std::unordered_set<IdModelEnableOption> opts;

if (hasEnableOptionArgument(EnableOption::IdModel, "consumer_index") ||
hasEnableOptionArgument(EnableOption::IdModel, "index") ||
hasEnableOptionArgument(EnableOption::IdModel, "all")) {
if (!hasEnableOptionArgument(EnableOption::IdModel, "predicate_only")) {
opts.insert(IdModelEnableOption::ConsumerIndex);
}

if (hasEnableOptionArgument(EnableOption::IdModel, "producer_index") ||
hasEnableOptionArgument(EnableOption::IdModel, "index") ||
hasEnableOptionArgument(EnableOption::IdModel, "all")) {
opts.insert(IdModelEnableOption::ProducerIndex);
}

if (hasEnableOptionArgument(EnableOption::IdModel, "inline_predicate") ||
hasEnableOptionArgument(EnableOption::IdModel, "predicate") ||
hasEnableOptionArgument(EnableOption::IdModel, "all")) {
if (!hasEnableOptionArgument(EnableOption::IdModel, "index_only")) {
opts.insert(IdModelEnableOption::InlinePredicate);
}

if (hasEnableOptionArgument(EnableOption::IdModel, "unswitch_predicate") ||
hasEnableOptionArgument(EnableOption::IdModel, "predicate") ||
hasEnableOptionArgument(EnableOption::IdModel, "all")) {
opts.insert(IdModelEnableOption::UnswitchPredicate);
}

if (hasEnableOptionArgument(EnableOption::IdModel, "loop") ||
hasEnableOptionArgument(EnableOption::IdModel, "all")) {
if (!hasEnableOptionArgument(EnableOption::IdModel, "predicate_only") &&
!hasEnableOptionArgument(EnableOption::IdModel, "index_only")) {
opts.insert(IdModelEnableOption::Loop);
}

// Loop requires ConsumerIndex, ProducerIndex, InlinePredicate and
// UnswitchPredicate
if (opts.find(IdModelEnableOption::Loop) != opts.end()) {
NVF_ERROR(
opts.find(IdModelEnableOption::ConsumerIndex) != opts.end(),
"ConsumerIndex required for Loop");
NVF_ERROR(
opts.find(IdModelEnableOption::ProducerIndex) != opts.end(),
"ProducerIndex required for Loop");
NVF_ERROR(
opts.find(IdModelEnableOption::InlinePredicate) != opts.end(),
"InlinePredicate required for Loop");
NVF_ERROR(
opts.find(IdModelEnableOption::UnswitchPredicate) != opts.end(),
"UnswitchPredicate required for Loop");
}

return opts;
}
Contributor comment on lines 30 to 53:
The new implementation changes the semantics of EnableOption::IdModel arguments. The old code supported explicit arguments like "all", "index", "predicate", "consumer_index", "producer_index", etc. The new code only recognizes "predicate_only" and "index_only" as restrictive flags.

This breaks existing usage like NVFUSER_ENABLE=id_model(all) found in tests/python/direct/test_with_id_model_indexer.py:183. Consider either:

  1. Updating the test to use NVFUSER_ENABLE=id_model (no arguments, which enables everything by default)
  2. Adding backward compatibility for the "all" argument
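A minimal sketch of what option 2 (backward compatibility) could look like, assuming the check sits near the top of the new getIdModelEnabledOptions() and reusing the hasEnableOptionArgument helper from the removed code; this is an illustration, not code from the PR:

```cpp
// Hypothetical shim: treat the legacy "all" argument as an alias for the new
// default, in which every IdModel option is enabled, so
// NVFUSER_ENABLE=id_model(all) keeps working.
if (hasEnableOptionArgument(EnableOption::IdModel, "all")) {
  return {
      IdModelEnableOption::ConsumerIndex,
      IdModelEnableOption::ProducerIndex,
      IdModelEnableOption::InlinePredicate,
      IdModelEnableOption::UnswitchPredicate,
      IdModelEnableOption::Loop};
}
```

Option 1 would instead only require dropping the "(all)" argument in the test, since NVFUSER_ENABLE=id_model with no arguments now enables everything by default.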


1 change: 0 additions & 1 deletion csrc/options.cpp
@@ -217,7 +217,6 @@ const std::unordered_map<std::string, DisableOption>& getDisableOptions() {
{"greedy_scheduler", DisableOption::GreedyScheduler},
{"grouped_grid_welford_outer_opt",
DisableOption::GroupedGridWelfordOuterOpt},
{"id_model", DisableOption::IdModel},
{"index_hoist", DisableOption::IndexHoist},
{"magic_zero", DisableOption::MagicZero},
{"matmul_expr_eval", DisableOption::MatmulExprEval},
1 change: 0 additions & 1 deletion csrc/options.h
@@ -147,7 +147,6 @@ enum class DisableOption {
GreedyScheduler, //! Disable the greedy scheduler
GroupedGridWelfordOuterOpt, //! Disable use of outer-optimized
//! grouped grid welford kernel
IdModel, //! Disable IdModel
IndexHoist, //! Disable index hoisting
MagicZero, //! Disable nvfuser_zero
MatmulExprEval, //! Disable ATen evaluation for the entire fusion containing
39 changes: 23 additions & 16 deletions csrc/predicate_compute.cpp
@@ -38,11 +38,21 @@ bool isOutputLocal(const Expr* expr) {
});
}

IterDomain* getConcreteMappedId(IterDomain* id) {
NVF_ERROR(GpuLower::hasCurrent());
return GpuLower::current()
->info()
.idModel()
.idGraph(IdMappingMode::EXACT)
.toGroup(id)
->front()
->as<IterDomain>();
}

} // namespace

bool ParallelizedDomainPredicate::PredicateInfo::addDomain(IterDomain* id) {
auto concrete_id = GpuLower::current()->info().caMap().getConcreteMappedID(
id, IdMappingMode::EXACT);
auto concrete_id = getConcreteMappedId(id);
if (std::find(ids_.begin(), ids_.end(), concrete_id) == ids_.end()) {
ids_.push_back(concrete_id);
return true;
@@ -59,10 +69,7 @@ Val* ParallelizedDomainPredicate::PredicateInfo::getPredicate() const {

for (const auto& pred_id : ids()) {
// Just sanity check that pred_id is concrete
NVF_ERROR(
pred_id ==
GpuLower::current()->info().caMap().getConcreteMappedID(
pred_id, IdMappingMode::EXACT));
NVF_ERROR(pred_id == getConcreteMappedId(pred_id));
auto new_pred = SimplifyingIrBuilder::ltExpr(index, pred_id->extent());
pred = SimplifyingIrBuilder::logicalAndExpr(pred, new_pred);
}
@@ -290,8 +297,11 @@ ParallelizedDomainPredicate::getPredicateMap(
tv->getLoopDomain().begin(),
tv->getLoopDomain().end(),
[&](auto tv_id) {
return gpu_lower->info().caMap().areMapped(
loop_id, tv_id, IdMappingMode::EXACT);
return gpu_lower->info()
.idModel()
.idGraph(IdMappingMode::EXACT)
.disjointValSets()
.strictAreMapped(loop_id, tv_id);
});
if (it == tv->getLoopDomain().end()) {
continue;
@@ -450,12 +460,10 @@ UnswitchPredicateKey::UnswitchPredicateKey(
}

// Find the corresponding concrete id for each parallel type
for (auto consumer_loop : parallelized_consumer_loop_ids) {
auto pt = consumer_loop->getParallelType();
auto concrete_loop =
GpuLower::current()->info().caMap().getConcreteMappedID(
consumer_loop, IdMappingMode::EXACT);
parallel_concrete_ids_.at(pt) = concrete_loop;
for (auto consumer_loop_id : parallelized_consumer_loop_ids) {
auto pt = consumer_loop_id->getParallelType();
auto concrete_loop_id = getConcreteMappedId(consumer_loop_id);
parallel_concrete_ids_.at(pt) = concrete_loop_id;
}
}

@@ -1015,8 +1023,7 @@ void UnswitchPredicate::predicateOn(Expr* tv_expr) {
bool first_key_set = false;

for (auto root_id : root_ids) {
auto concrete_root_id = gpu_lower->info().caMap().getConcreteMappedID(
root_id, IdMappingMode::EXACT);
auto concrete_root_id = getConcreteMappedId(root_id);

if (root_id->isBroadcast()) {
continue;
7 changes: 6 additions & 1 deletion python/nvfuser_direct/__init__.py
@@ -366,10 +366,15 @@ def execute(
# A copy of fusion is created after construction FusionExecutorCache
# Delete the _fusion and reference the fusion inside FusionExecutorCache
del self._fusion

# Add "id_model" as a default enable option
default_enable_options = ["id_model"]
merged_enable_options = default_enable_options + _enable_options

return self.fec.execute(
inputs,
device=self._get_device_index(device),
_enable_options=_enable_options,
_enable_options=merged_enable_options,
_disable_options=_disable_options,
)

5 changes: 3 additions & 2 deletions tests/cpp/test_indexing.cpp
@@ -864,8 +864,9 @@ TEST_F(IndexingTest, Reshape) {
// to provide the extent of the group. However, since everything
// should be deterministic, string match should also work.
return std::string(
"( ( ( ( ( i98 * 20 ) + ( ( i99 * 10 ) + i100 ) ) / 25 ) * 25 ) "
"+ ( ( ( i98 * 20 ) + ( ( i99 * 10 ) + i100 ) ) % 25 ) )");
"( ( ( ( ( i114 * 20 ) + ( ( i115 * 10 ) + i116 ) ) / 25 ) * 25 "
") "
"+ ( ( ( i114 * 20 ) + ( ( i115 * 10 ) + i116 ) ) % 25 ) )");
}
default:
return std::string();
10 changes: 4 additions & 6 deletions tests/python/direct/test_python_direct.py
@@ -229,22 +229,20 @@ def test_fusion_execution_cache():
i2 = i0 / 8;
nvfuser_index_t i3;
i3 = i0 % 8;
nvfuser_index_t i4;
i4 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x));
if ((i4 < 64)) {
if ((((nvfuser_index_t)threadIdx.x) < 64)) {
Array<float, 1, 1> T4;
T4[0] = 0;
T4[0]
= T1[((((T1.alloc_stride[0LL] * i1) + (T1.alloc_stride[1LL] * i2)) + (T1.alloc_stride[2LL] * i3)) + ((4 * T1.alloc_stride[0LL]) * ((nvfuser_index_t)blockIdx.x)))];
= T1[(((T1.alloc_stride[0LL] * i1) + (T1.alloc_stride[1LL] * i2)) + (T1.alloc_stride[2LL] * i3))];
Array<float, 1, 1> T3;
T3[0] = 0;
T3[0]
= T0[((((T0.alloc_stride[0LL] * i1) + (T0.alloc_stride[1LL] * i2)) + (T0.alloc_stride[2LL] * i3)) + ((4 * T0.alloc_stride[0LL]) * ((nvfuser_index_t)blockIdx.x)))];
= T0[(((T0.alloc_stride[0LL] * i1) + (T0.alloc_stride[1LL] * i2)) + (T0.alloc_stride[2LL] * i3))];
Array<float, 1, 1> T5;
T5[0]
= T3[0]
+ T4[0];
T2[i4]
T2[((nvfuser_index_t)threadIdx.x)]
= T5[0];
}
}\n"""