diff --git a/csrc/scheduler/greedy.cpp b/csrc/scheduler/greedy.cpp index 898e98d6e1c..7a6c46723f5 100644 --- a/csrc/scheduler/greedy.cpp +++ b/csrc/scheduler/greedy.cpp @@ -1102,14 +1102,6 @@ class ConstrainedOpScheduler : public OptOutDispatch { const auto& constrained_loop_id_offsets = getDependentLoopIds(tv, constrained_logical_id_offsets); - // Don't inline constrained IDs. For example, like reduction IDs, - // argsort'ed IDs should never be inlined into its consumers. - for (const auto constrained_logical_id_offset : - constrained_logical_id_offsets) { - uninlinable_ids_.insert( - tv->getLogicalDomain().at(constrained_logical_id_offset)); - } - // Move the constrained_logical_ids innermost std::unordered_map old2new; for (const auto [i, offset] : enumerate(constrained_loop_id_offsets)) { @@ -1176,6 +1168,25 @@ class ConstrainedOpScheduler : public OptOutDispatch { tv->flatten( 0, std::ssize(tv->getLoopDomain()) - 1 - num_constrained_loop_ids); tv->axis(0)->parallelize(ParallelType::BIDx); + + // Don't inline constrained IDs. For example, like reduction IDs, + // argsort'ed IDs should never be inlined into its consumers. + std::unordered_set constrained_logical; + for (const auto constrained_logical_id_offset : + constrained_logical_id_offsets) { + constrained_logical.insert( + tv->getLogicalDomain().at(constrained_logical_id_offset)); + } + + auto all_constrained_ids = DependencyCheck::getAllValsBetween( + constrained_logical, + {tv->getLoopDomain().begin(), tv->getLoopDomain().end()}); + for (const auto loop_id : tv->getLoopDomain()) { + if (std::ranges::find(all_constrained_ids, loop_id) != + all_constrained_ids.end()) { + uninlinable_ids_.insert(loop_id); + } + } } private: diff --git a/tests/cpp/test_greedy.cpp b/tests/cpp/test_greedy.cpp index 173512fa1b5..cdcbedb582d 100644 --- a/tests/cpp/test_greedy.cpp +++ b/tests/cpp/test_greedy.cpp @@ -735,6 +735,37 @@ TEST_P(GreedySchedulerTestConstraintSize, ArgsortLargeConstrainedIDs) { testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__); EXPECT_FALSE(executor_cache.getMostRecentKernelRuntime()->isSegmented()); + + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + Fusion* scheduled_fusion = runtime->executors() + .at(0) + ->as() + ->compiledKernel() + ->kernel(); + for (const auto tv : scheduled_fusion->allTvs()) { + if (tv->isDefinitionType()) { + // The loop domain should look like: [BIDx, TIDx, Group] if + // grouped, or [BIDx, TIDx] if not. The computeAt and producer + // positions should be 1 in both cases. + EXPECT_TRUE( + std::ssize(tv->getLoopDomain()) == 2 || + std::ssize(tv->getLoopDomain()) == 3); + EXPECT_EQ(tv->axis(0)->getParallelType(), ParallelType::BIDx); + EXPECT_EQ(tv->axis(1)->getParallelType(), ParallelType::TIDx); + if (std::ssize(tv->getLoopDomain()) == 3) { + EXPECT_EQ(tv->axis(-1)->getParallelType(), ParallelType::Group); + } + EXPECT_EQ(tv->getComputeAtPosition(), 1); + // The input should not be inlined into the argsort dim + EXPECT_EQ( + tv->definition() + ->as() + ->in() + ->as() + ->getComputeAtPosition(), + 1); + } + } } TEST_P(GreedySchedulerTestConstraintSize, ScanLargeConstrainedIDs) { @@ -758,6 +789,37 @@ TEST_P(GreedySchedulerTestConstraintSize, ScanLargeConstrainedIDs) { testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__); EXPECT_FALSE(executor_cache.getMostRecentKernelRuntime()->isSegmented()); + + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + Fusion* scheduled_fusion = runtime->executors() + .at(0) + ->as() + ->compiledKernel() + ->kernel(); + for (const auto tv : scheduled_fusion->allTvs()) { + if (tv->isDefinitionType()) { + // The loop domain should look like: [BIDx, TIDx, Group] if + // grouped, or [BIDx, TIDx] if not. The computeAt and producer + // positions should be 1 in both cases. + EXPECT_TRUE( + std::ssize(tv->getLoopDomain()) == 2 || + std::ssize(tv->getLoopDomain()) == 3); + EXPECT_EQ(tv->axis(0)->getParallelType(), ParallelType::BIDx); + EXPECT_EQ(tv->axis(1)->getParallelType(), ParallelType::TIDx); + if (std::ssize(tv->getLoopDomain()) == 3) { + EXPECT_EQ(tv->axis(-1)->getParallelType(), ParallelType::Group); + } + EXPECT_EQ(tv->getComputeAtPosition(), 1); + // The input should not be inlined into the scan dim + EXPECT_EQ( + tv->definition() + ->as() + ->in() + ->as() + ->getComputeAtPosition(), + 1); + } + } } TEST_P(GreedySchedulerTestConstraintSize, ScatterLargeConstrainedIDs) { @@ -786,6 +848,35 @@ TEST_P(GreedySchedulerTestConstraintSize, ScatterLargeConstrainedIDs) { testValidate(executor_cache.fusion(), outputs, {t0, t1}, __LINE__, __FILE__); EXPECT_FALSE(executor_cache.getMostRecentKernelRuntime()->isSegmented()); + + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + Fusion* scheduled_fusion = runtime->executors() + .at(0) + ->as() + ->compiledKernel() + ->kernel(); + for (const auto tv : scheduled_fusion->allTvs()) { + if (tv->isDefinitionType()) { + // The loop domain should look like: [TIDx, Serial] if + // grouped, or [TIDx] if not. + EXPECT_TRUE( + std::ssize(tv->getLoopDomain()) == 1 || + std::ssize(tv->getLoopDomain()) == 2); + EXPECT_EQ(tv->axis(0)->getParallelType(), ParallelType::TIDx); + if (std::ssize(tv->getLoopDomain()) == 2) { + EXPECT_EQ(tv->axis(-1)->getParallelType(), ParallelType::Serial); + } + EXPECT_EQ(tv->getComputeAtPosition(), 0); + // The input should not be inlined into the scatter dim + EXPECT_EQ( + tv->definition() + ->as() + ->in() + ->as() + ->getComputeAtPosition(), + 0); + } + } } // Pattern appearing in test_moe.py