27 changes: 19 additions & 8 deletions csrc/scheduler/greedy.cpp
@@ -1102,14 +1102,6 @@ class ConstrainedOpScheduler : public OptOutDispatch {
const auto& constrained_loop_id_offsets =
getDependentLoopIds(tv, constrained_logical_id_offsets);

// Don't inline constrained IDs. For example, like reduction IDs,
// argsort'ed IDs should never be inlined into their consumers.
for (const auto constrained_logical_id_offset :
constrained_logical_id_offsets) {
uninlinable_ids_.insert(
tv->getLogicalDomain().at(constrained_logical_id_offset));
}

// Move the constrained_logical_ids innermost
std::unordered_map<int64_t, int64_t> old2new;
for (const auto [i, offset] : enumerate(constrained_loop_id_offsets)) {
@@ -1176,6 +1168,25 @@ class ConstrainedOpScheduler : public OptOutDispatch {
tv->flatten(
0, std::ssize(tv->getLoopDomain()) - 1 - num_constrained_loop_ids);
tv->axis(0)->parallelize(ParallelType::BIDx);

// Don't inline constrained IDs. For example, like reduction IDs,
// argsort'ed IDs should never be inlined into their consumers.
std::unordered_set<Val*> constrained_logical;
for (const auto constrained_logical_id_offset :
constrained_logical_id_offsets) {
constrained_logical.insert(
tv->getLogicalDomain().at(constrained_logical_id_offset));
}

auto all_constrained_ids = DependencyCheck::getAllValsBetween(
constrained_logical,
{tv->getLoopDomain().begin(), tv->getLoopDomain().end()});
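// Note: getAllValsBetween includes the given dependencies themselves, so
// all_constrained_ids also contains any constrained logical ID that appears
// directly in the loop domain.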
for (const auto loop_id : tv->getLoopDomain()) {
if (std::ranges::find(all_constrained_ids, loop_id) !=
all_constrained_ids.end()) {

Collaborator:
Why do we need to exclude constrained_logical?

For a manual scheduling, we could have the logical domain and loop domain share IDs, and this would artificially exclude those.

Collaborator (Author):

Can you clarify your question? Exclude from inlining? Or exclude from uninlinable_ids?

They should be included in all_constrained_ids, so any constrained loop ID, whether or not it is also a logical ID, is included in uninlinable_ids. In other words, all constrained loop IDs are excluded from inlining, logical or not.

Collaborator:

Ah you are right. For some reason I read getAllValsBetween and figured that the dependencies wouldn't be included.

// Grab all values that exist between and including provided
// vals. Returned values are topologically ordered, and unique.
NVF_API static std::vector<Val*> getAllValsBetween(
    const std::unordered_set<Val*>& dependencies,
    const std::vector<Val*>& of);
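
To make the resolution of this thread concrete, here is a minimal standalone sketch (plain C++, not nvFuser code; the Graph type, the allValsBetweenToy helper, and the ID names are made up for illustration). It mimics the documented behavior that the seed dependencies are part of the returned set, which is why a constrained logical ID that also serves directly as a loop ID still ends up in uninlinable_ids_.

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Toy dependency graph (not nvFuser code): maps each value to its producers.
using Graph = std::unordered_map<std::string, std::vector<std::string>>;

// Simplified stand-in for DependencyCheck::getAllValsBetween: returns every
// value lying on a path from one of the dependencies to one of the "of"
// values, including the dependencies themselves.
std::vector<std::string> allValsBetweenToy(
    const Graph& producers,
    const std::unordered_set<std::string>& dependencies,
    const std::vector<std::string>& of) {
  std::vector<std::string> result;
  std::unordered_set<std::string> in_result;
  // Returns true if `v` is a dependency or is (transitively) derived from one.
  std::function<bool(const std::string&)> visit =
      [&](const std::string& v) -> bool {
    if (in_result.count(v)) {
      return true;
    }
    bool on_path = dependencies.count(v) > 0;
    if (auto it = producers.find(v); it != producers.end()) {
      for (const auto& p : it->second) {
        on_path = visit(p) || on_path;
      }
    }
    if (on_path) {
      in_result.insert(v);
      result.push_back(v);
    }
    return on_path;
  };
  for (const auto& v : of) {
    visit(v);
  }
  return result;
}

int main() {
  // Logical domain {i0, i1}, where i1 is the constrained (e.g. argsort'ed) ID.
  // Loop domain {i0, i2, i3}, where i2 and i3 come from splitting i1.
  Graph producers = {{"i2", {"i1"}}, {"i3", {"i1"}}};
  for (const auto& v :
       allValsBetweenToy(producers, {"i1"}, {"i0", "i2", "i3"})) {
    std::cout << v << " ";
  }
  std::cout << "\n";
  // Prints "i1 i2 i3": the seed i1 is part of the result, while the unrelated
  // i0 is not. So a constrained logical ID that also appears directly in the
  // loop domain would still be found by the std::ranges::find check above and
  // added to uninlinable_ids_.
  return 0;
}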

uninlinable_ids_.insert(loop_id);
}
}
}

private:
91 changes: 91 additions & 0 deletions tests/cpp/test_greedy.cpp
@@ -735,6 +735,37 @@ TEST_P(GreedySchedulerTestConstraintSize, ArgsortLargeConstrainedIDs) {
testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__);

EXPECT_FALSE(executor_cache.getMostRecentKernelRuntime()->isSegmented());

FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
Fusion* scheduled_fusion = runtime->executors()
.at(0)
->as<KernelExecutor>()
->compiledKernel()
->kernel();
for (const auto tv : scheduled_fusion->allTvs()) {
if (tv->isDefinitionType<ArgsortOp>()) {
// The loop domain should look like: [BIDx, TIDx, Group] if
// grouped, or [BIDx, TIDx] if not. The computeAt and producer
// positions should be 1 in both cases.
EXPECT_TRUE(
std::ssize(tv->getLoopDomain()) == 2 ||
std::ssize(tv->getLoopDomain()) == 3);
EXPECT_EQ(tv->axis(0)->getParallelType(), ParallelType::BIDx);
EXPECT_EQ(tv->axis(1)->getParallelType(), ParallelType::TIDx);
if (std::ssize(tv->getLoopDomain()) == 3) {
EXPECT_EQ(tv->axis(-1)->getParallelType(), ParallelType::Group);
}
EXPECT_EQ(tv->getComputeAtPosition(), 1);
// The input should not be inlined into the argsort dim
EXPECT_EQ(
tv->definition()
->as<ArgsortOp>()
->in()
->as<TensorView>()
->getComputeAtPosition(),
1);
}
}
}

TEST_P(GreedySchedulerTestConstraintSize, ScanLargeConstrainedIDs) {
@@ -758,6 +789,37 @@ TEST_P(GreedySchedulerTestConstraintSize, ScanLargeConstrainedIDs) {
testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__);

EXPECT_FALSE(executor_cache.getMostRecentKernelRuntime()->isSegmented());

FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
Fusion* scheduled_fusion = runtime->executors()
.at(0)
->as<KernelExecutor>()
->compiledKernel()
->kernel();
for (const auto tv : scheduled_fusion->allTvs()) {
if (tv->isDefinitionType<ScanOp>()) {
// The loop domain should look like: [BIDx, TIDx, Group] if
// grouped, or [BIDx, TIDx] if not. The computeAt and producer
// positions should be 1 in both cases.
EXPECT_TRUE(
std::ssize(tv->getLoopDomain()) == 2 ||
std::ssize(tv->getLoopDomain()) == 3);
EXPECT_EQ(tv->axis(0)->getParallelType(), ParallelType::BIDx);
EXPECT_EQ(tv->axis(1)->getParallelType(), ParallelType::TIDx);
if (std::ssize(tv->getLoopDomain()) == 3) {
EXPECT_EQ(tv->axis(-1)->getParallelType(), ParallelType::Group);
}
EXPECT_EQ(tv->getComputeAtPosition(), 1);
// The input should not be inlined into the scan dim
EXPECT_EQ(
tv->definition()
->as<ScanOp>()
->in()
->as<TensorView>()
->getComputeAtPosition(),
1);
}
}
}

TEST_P(GreedySchedulerTestConstraintSize, ScatterLargeConstrainedIDs) {
@@ -786,6 +848,35 @@ TEST_P(GreedySchedulerTestConstraintSize, ScatterLargeConstrainedIDs) {
testValidate(executor_cache.fusion(), outputs, {t0, t1}, __LINE__, __FILE__);

EXPECT_FALSE(executor_cache.getMostRecentKernelRuntime()->isSegmented());

FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
Fusion* scheduled_fusion = runtime->executors()
.at(0)
->as<KernelExecutor>()
->compiledKernel()
->kernel();
for (const auto tv : scheduled_fusion->allTvs()) {
if (tv->isDefinitionType<ScatterOp>()) {
// The loop domain should look like: [TIDx, Serial] if
// grouped, or [TIDx] if not.
EXPECT_TRUE(
std::ssize(tv->getLoopDomain()) == 1 ||
std::ssize(tv->getLoopDomain()) == 2);
EXPECT_EQ(tv->axis(0)->getParallelType(), ParallelType::TIDx);
if (std::ssize(tv->getLoopDomain()) == 2) {
EXPECT_EQ(tv->axis(-1)->getParallelType(), ParallelType::Serial);
}
EXPECT_EQ(tv->getComputeAtPosition(), 0);
// The input should not be inlined into the scatter dim
EXPECT_EQ(
tv->definition()
->as<ScatterOp>()
->in()
->as<TensorView>()
->getComputeAtPosition(),
0);
}
}
}

// Pattern appearing in test_moe.py