131 changes: 11 additions & 120 deletions csrc/device_lower/id_model_options.h
@@ -7,146 +7,37 @@
// clang-format on
#pragma once

#include <id_model/utils.h>

#include <sstream>

#include "id_model/utils.h"
#include "options.h"

namespace nvfuser {

class IdModelOptions {
public:
IdModelOptions()
: build_id_model_(!isOptionDisabled(DisableOption::IdModel)),
consumer_index_(
isIdModelOptionEnabled(IdModelEnableOption::ConsumerIndex)),
producer_index_(
isIdModelOptionEnabled(IdModelEnableOption::ProducerIndex)),
inline_predicate_(
isIdModelOptionEnabled(IdModelEnableOption::InlinePredicate)),
unswitch_predicate_(
isIdModelOptionEnabled(IdModelEnableOption::UnswitchPredicate)),
loop_(isIdModelOptionEnabled(IdModelEnableOption::Loop)) {
ensureConsistency();
}

bool buildIdModel() const {
return build_id_model_;
}

void setBuildIdModel(bool b) {
build_id_model_ = b;
ensureConsistency();
}

bool buildTensorIndexer() const {
return build_tensor_indexer_;
}

void setBuildTensorIndexer(bool b) {
build_tensor_indexer_ = b;
ensureConsistency();
}

bool consumerIndex() const {
return consumer_index_;
}

void setConsumerIndex(bool b) {
consumer_index_ = b;
ensureConsistency();
}

bool producerIndex() const {
return producer_index_;
}
: tensor_indexer_enabled_(isOptionEnabled(EnableOption::IdModel)) {}

void setProducerIndex(bool b) {
producer_index_ = b;
ensureConsistency();
void setTensorIndexer(bool b) {
tensor_indexer_enabled_ = b;
}

void setIndex(bool b) {
setConsumerIndex(b);
setProducerIndex(b);
}

bool inlinePredicate() const {
return inline_predicate_;
}

void setInlinePredicate(bool b) {
inline_predicate_ = b;
ensureConsistency();
}

bool unswitchPredicate() const {
return unswitch_predicate_;
}

void setUnswitchPredicate(bool b) {
unswitch_predicate_ = b;
ensureConsistency();
}

void setPredicate(bool b) {
setInlinePredicate(b);
setUnswitchPredicate(b);
}

bool loop() const {
return loop_;
}

void setLoop(bool b) {
loop_ = b;
ensureConsistency();
bool isTensorIndexerEnabled() const {
return tensor_indexer_enabled_;
}

std::string toString() const {
auto bool2str = [](bool b) { return b ? "true" : "false"; };

std::stringstream ss;
ss << "build_id_model=" << bool2str(build_id_model_)
<< ", build_tensor_indexer=" << bool2str(build_tensor_indexer_)
<< ", consumer_index=" << bool2str(consumer_index_)
<< ", producer_index=" << bool2str(producer_index_)
<< ", inline_predicate=" << bool2str(inline_predicate_)
<< ", unswitch_predicate=" << bool2str(unswitch_predicate_)
<< ", loop=" << bool2str(loop_);
ss << "enable_tensor_indexer=" << bool2str(tensor_indexer_enabled_);
return ss.str();
}

private:
void ensureConsistency() {
if (!build_id_model_) {
build_tensor_indexer_ = false;
consumer_index_ = false;
producer_index_ = false;
inline_predicate_ = false;
unswitch_predicate_ = false;
loop_ = false;
} else {
// TensorIndexer is required if these options are enabled
build_tensor_indexer_ = build_tensor_indexer_ || consumer_index_ ||
producer_index_ || inline_predicate_ || unswitch_predicate_ || loop_;
}
}

private:
// Build IdModel
bool build_id_model_ = true;
// Build TensorIndexer
bool build_tensor_indexer_ = false;
// Globally enables consumer indexing.
bool consumer_index_ = false;
// Globally enables producer indexing.
bool producer_index_ = false;
// Globally enables inline predicate
bool inline_predicate_ = false;
// Globally enables unswitch predicate
bool unswitch_predicate_ = false;
// Generate loops using IdModel
bool loop_ = false;
// Enable TensorIndexer
bool tensor_indexer_enabled_ = false;
};

} // namespace nvfuser
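The deleted class tracked six interdependent flags, with every setter calling ensureConsistency() to re-establish their implications (nothing enabled without build_id_model_; build_tensor_indexer_ implied by any feature that needs it). The replacement collapses all of that into the single tensor_indexer_enabled_ flag, which cannot be inconsistent with itself. As a minimal standalone sketch of the setter-plus-invariant pattern being removed (hypothetical names, not nvFuser code):

```cpp
#include <iostream>

// Every mutator re-establishes the class invariant, so dependent flags can
// never be observed in a contradictory state.
class Options {
 public:
  void setBuildModel(bool b) {
    build_model_ = b;
    ensureConsistency();
  }
  void setConsumerIndex(bool b) {
    consumer_index_ = b;
    ensureConsistency();
  }
  bool buildIndexer() const {
    return build_indexer_;
  }

 private:
  void ensureConsistency() {
    if (!build_model_) {
      // Nothing downstream can stay enabled without the model.
      build_indexer_ = false;
      consumer_index_ = false;
    } else {
      // The indexer is implied by any feature that requires it.
      build_indexer_ = build_indexer_ || consumer_index_;
    }
  }

  bool build_model_ = true;
  bool build_indexer_ = false;
  bool consumer_index_ = false;
};

int main() {
  Options opts;
  opts.setConsumerIndex(true);
  std::cout << std::boolalpha << opts.buildIndexer() << '\n';  // true
  opts.setBuildModel(false);
  std::cout << opts.buildIndexer() << '\n';  // false
}
```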
63 changes: 17 additions & 46 deletions csrc/device_lower/lower2device.cpp
@@ -298,24 +298,13 @@ IdModelOptions getIdModelOptions(Fusion* fusion) {

for (auto expr : fusion->exprs()) {
if (auto ldst = dynamic_cast<LoadStoreOp*>(expr)) {
if (ldst->opType() == LoadStoreOpType::CpAsyncBulkTensorTile ||
ldst->opType() == LoadStoreOpType::CpAsyncBulk) {
options.setBuildTensorIndexer(true);
if (ldst->opType() == LoadStoreOpType::CpAsyncBulk) {
options.setInlinePredicate(true);
}
if (ldst->opType() == LoadStoreOpType::CpAsyncBulk) {
options.setTensorIndexer(true);
continue;
}
} else if (expr->isA<MmaOp>()) {
options.setBuildTensorIndexer(true);
continue;
} else if (
expr->isOneOf<ArgsortOp, PadOp, ScanOp, ScatterOp, SliceOp, TopKOp>()) {
options.setProducerIndex(true);
options.setConsumerIndex(true);
options.setInlinePredicate(true);
options.setUnswitchPredicate(true);
options.setLoop(true);
options.setTensorIndexer(true);
continue;
} else if (auto reshape = dynamic_cast<ReshapeOp*>(expr)) {
// The legacy indexer has an issue when an expand broadcast is
@@ -369,31 +358,16 @@ IdModelOptions getIdModelOptions(Fusion* fusion) {
return consumer_expanded_root_ids.count(input);
});
})) {
options.setProducerIndex(true);
options.setConsumerIndex(true);
options.setInlinePredicate(true);
options.setUnswitchPredicate(true);
options.setTensorIndexer(true);
}
}
}

// If a tensor does not have a nice root->logical/allocation->loop
// linear transformation history, use TensorIndexer
for (auto tv : fusion->allTvs()) {
if (tv->getMemoryType() == MemoryType::Tensor ||
!ir_utils::hasRootToLoopLinearTransformations(tv)) {
options.setBuildTensorIndexer(true);
}
}

// If not supported, disable use of TensorIndexer by default. It is
// still used if explicitly opted-in (see, for example,
// Index::getConsumerIndex)
if (!TensorIndexer::isSupported(fusion)) {
// Do not disable building of TensorIndexer as it may still be used
options.setIndex(false);
options.setPredicate(false);
options.setLoop(false);
options.setTensorIndexer(false);
}

return options;
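getIdModelOptions() above decides the flag with a single scan over the fusion's expressions, dispatching on op type via dynamic_cast / isA and short-circuiting with continue once a TensorIndexer-requiring op is found. A self-contained sketch of that scan-and-dispatch shape, using stand-in types rather than nvFuser's:

```cpp
#include <memory>
#include <vector>

struct Expr {
  virtual ~Expr() = default;
};
struct LoadStoreOp : Expr {
  bool is_bulk_copy = false;  // stand-in for LoadStoreOpType::CpAsyncBulk
};
struct MmaOp : Expr {};

// Walk every expression once; flip the flag as soon as an op that needs the
// new indexer is seen.
bool needsTensorIndexer(const std::vector<std::unique_ptr<Expr>>& exprs) {
  for (const auto& expr : exprs) {
    if (auto* ldst = dynamic_cast<const LoadStoreOp*>(expr.get())) {
      if (ldst->is_bulk_copy) {
        return true;
      }
    } else if (dynamic_cast<const MmaOp*>(expr.get()) != nullptr) {
      return true;
    }
  }
  return false;
}

int main() {
  std::vector<std::unique_ptr<Expr>> exprs;
  exprs.push_back(std::make_unique<MmaOp>());
  return needsTensorIndexer(exprs) ? 0 : 1;
}
```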
@@ -461,14 +435,12 @@ void GpuLower::analysis(Fusion* fusion) {

// New IterDomains may be created, so it is expected that generated
// code may use different variable names
if (idModelOptions().buildIdModel()) {
info().set(std::make_unique<IdModel>(
fusion_,
/*build_graphs=*/true,
/*allow_self_mapping=*/false,
/*validate=*/false));
info().idModel().validateAndPropagatePType();
}
info().set(std::make_unique<IdModel>(
fusion_,
/*build_graphs=*/true,
/*allow_self_mapping=*/false,
/*validate=*/false));
info().idModel().validateAndPropagatePType();

// Build what's referred to as the compute at map. This map contains the
// mappings of all iteration domains across the fusion. There are three types
@@ -578,16 +550,15 @@ void GpuLower::analysis(Fusion* fusion) {
info().caMap().allocateIndexVariables();
dumpExprsIfEnabled(fusion_->exprs(), "allocateIndexVariables");

if (idModelOptions().loop()) {
if (idModelOptions().isTensorIndexerEnabled()) {
// Depends on CircularBufferInfo and compute_at_map_->allocateIndexVariables
info().idModel().allocateLoopIndexVariables();
}

if (idModelOptions().buildTensorIndexer()) {
tensor_indexer_ = std::make_unique<TensorIndexer>(info().idModel());
non_divisible_predicate_info_ =
std::make_unique<NonDivisiblePredicateInfo>(fusion_);
}
tensor_indexer_ = std::make_unique<TensorIndexer>(info().idModel());

non_divisible_predicate_info_ =
std::make_unique<NonDivisiblePredicateInfo>(fusion_);

// Detects all expressions that don't need predicates. Depends on
// nonDivisibleSplitInfo.
@@ -663,7 +634,7 @@ bool GpuLower::resolveComputeWith(Fusion* fusion) {
Val* GpuLower::getLoopIndexVariable(
IterDomain* id,
CircularBufferLoopStage stage) const {
if (idModelOptions().loop()) {
if (idModelOptions().isTensorIndexerEnabled()) {
return info().idModel().getLoopIndexVariable(id, stage);
} else {
return info().caMap().getIndexVariable(id, stage);
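The getLoopIndexVariable() change above follows the same pattern as the sites in utils.cpp and id_model.cpp below: one predicate consulted at every call site, so the two indexing paths are never mixed within a single compilation. A generic sketch of that flag-gated provider selection (stand-in types, not nvFuser's):

```cpp
#include <string>

struct LegacyCaMap {
  std::string indexVariable() const { return "ca_map_index"; }
};
struct IdModelIndexing {
  std::string indexVariable() const { return "id_model_index"; }
};

// A single flag routes the query to exactly one of the two providers.
std::string getLoopIndexVariable(
    bool tensor_indexer_enabled,
    const LegacyCaMap& legacy,
    const IdModelIndexing& id_model) {
  return tensor_indexer_enabled ? id_model.indexVariable()
                                : legacy.indexVariable();
}

int main() {
  return getLoopIndexVariable(true, {}, {}) == "id_model_index" ? 0 : 1;
}
```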
2 changes: 1 addition & 1 deletion csrc/device_lower/utils.cpp
@@ -2126,7 +2126,7 @@ IterDomain* getConcreteLoopID(IterDomain* id) {
// Currently, the concrete loop ID uses the IdModel loop
// promotion only when opted in.
if ((GpuLower::hasCurrent() &&
GpuLower::current()->idModelOptions().loop()) ||
GpuLower::current()->idModelOptions().isTensorIndexerEnabled()) ||
(!GpuLower::hasCurrent() && FusionInfoGuard::current()->hasIdModel() &&
FusionInfoGuard::current()->idModel().hasIdGraph(IdMappingMode::LOOP))) {
// If enabled, the concrete ID should be basically just the
2 changes: 1 addition & 1 deletion csrc/id_model/id_model.cpp
@@ -1384,7 +1384,7 @@ void IdModel::allocateLoopIndexVariables() {
// If enabled, allocate own indices. Otherwise, use the one
// generated for ComputeAtMap for compatibility with the legacy
// indexing
if (GpuLower::current()->idModelOptions().loop()) {
if (GpuLower::current()->idModelOptions().isTensorIndexerEnabled()) {
loop_index = IrBuilder::create<Val>(DataType::Index);
} else {
const auto& ca_map = FusionInfoGuard::current()->caMap();
68 changes: 0 additions & 68 deletions csrc/id_model/utils.h
@@ -16,74 +16,6 @@

namespace nvfuser {

// Options to enable the IdModel-based tensor indexer selectively
enum class IdModelEnableOption {
ConsumerIndex,
ProducerIndex,
InlinePredicate,
UnswitchPredicate,
// Uses the loop promotion to generate loops. Indexing and
// predication need to be enabled as well.
Loop,
};

inline std::unordered_set<IdModelEnableOption> getIdModelEnabledOptions() {
std::unordered_set<IdModelEnableOption> opts;

if (hasEnableOptionArgument(EnableOption::IdModel, "consumer_index") ||
hasEnableOptionArgument(EnableOption::IdModel, "index") ||
hasEnableOptionArgument(EnableOption::IdModel, "all")) {
opts.insert(IdModelEnableOption::ConsumerIndex);
}

if (hasEnableOptionArgument(EnableOption::IdModel, "producer_index") ||
hasEnableOptionArgument(EnableOption::IdModel, "index") ||
hasEnableOptionArgument(EnableOption::IdModel, "all")) {
opts.insert(IdModelEnableOption::ProducerIndex);
}

if (hasEnableOptionArgument(EnableOption::IdModel, "inline_predicate") ||
hasEnableOptionArgument(EnableOption::IdModel, "predicate") ||
hasEnableOptionArgument(EnableOption::IdModel, "all")) {
opts.insert(IdModelEnableOption::InlinePredicate);
}

if (hasEnableOptionArgument(EnableOption::IdModel, "unswitch_predicate") ||
hasEnableOptionArgument(EnableOption::IdModel, "predicate") ||
hasEnableOptionArgument(EnableOption::IdModel, "all")) {
opts.insert(IdModelEnableOption::UnswitchPredicate);
}

if (hasEnableOptionArgument(EnableOption::IdModel, "loop") ||
hasEnableOptionArgument(EnableOption::IdModel, "all")) {
opts.insert(IdModelEnableOption::Loop);
}

// Loop requires ConsumerIndex, ProducerIndex, InlinePredicate and
// UnswitchPredicate
if (opts.find(IdModelEnableOption::Loop) != opts.end()) {
NVF_ERROR(
opts.find(IdModelEnableOption::ConsumerIndex) != opts.end(),
"ConsumerIndex required for Loop");
NVF_ERROR(
opts.find(IdModelEnableOption::ProducerIndex) != opts.end(),
"ProducerIndex required for Loop");
NVF_ERROR(
opts.find(IdModelEnableOption::InlinePredicate) != opts.end(),
"InlinePredicate required for Loop");
NVF_ERROR(
opts.find(IdModelEnableOption::UnswitchPredicate) != opts.end(),
"UnswitchPredicate required for Loop");
}

return opts;
}

inline bool isIdModelOptionEnabled(IdModelEnableOption option) {
const auto opts = getIdModelEnabledOptions();
return opts.find(option) != opts.end();
}

// Get the promotion domain of a given loop domain.
inline IterDomain* getLoopPromotion(
IterDomain* loop_id,
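The deleted helpers above parsed per-feature arguments out of EnableOption::IdModel ("consumer_index", the "index"/"predicate" group names, or "all") and asserted that Loop implied the other four features. A minimal stand-in for that argument-group check (hasEnableOptionArgument is the API shown in the deleted code; the set-based parsing here is a simplifying assumption):

```cpp
#include <string>
#include <unordered_set>

// A feature is on if its own name, its group name, or "all" was passed as an
// argument to the enabling option.
bool hasArg(
    const std::unordered_set<std::string>& args,
    const std::string& name) {
  return args.count(name) > 0;
}

bool consumerIndexEnabled(const std::unordered_set<std::string>& args) {
  return hasArg(args, "consumer_index") || hasArg(args, "index") ||
      hasArg(args, "all");
}

int main() {
  const std::unordered_set<std::string> args = {"index"};
  return consumerIndexEnabled(args) ? 0 : 1;  // enabled via the group name
}
```

With the single EnableOption::IdModel flag, all of this per-feature plumbing goes away.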