@@ -168,7 +168,69 @@ TEST_F(ScatterTest, GridCounting) {
168168 testValidate (&fusion, outputs, {t0}, __LINE__, __FILE__);
169169}
170170
171- TEST_F (ScatterTest, BlockCountingWithShmem2D) {
171+ TEST_P (ScatterTest, BlockCountingWithShmem2DExact) {
172+ auto fusion_ptr = std::make_unique<Fusion>();
173+ Fusion& fusion = *fusion_ptr.get ();
174+ FusionGuard fg (&fusion);
175+
176+ const std::vector<int64_t > self_shape{4 , 100 };
177+ const std::vector<int64_t > index_shape{4 , 10 };
178+
179+ auto tv0 = makeContigConcreteTensor (index_shape, DataType::Int);
180+ fusion.addInput (tv0);
181+
182+ auto tv1 = set (tv0);
183+ auto tv2 = zeros (
184+ {IrBuilder::create<Val>(self_shape[0 ]),
185+ IrBuilder::create<Val>(self_shape[1 ])},
186+ DataType::Int);
187+ auto tv3 = ones (
188+ {IrBuilder::create<Val>(index_shape[0 ]),
189+ IrBuilder::create<Val>(index_shape[1 ])},
190+ DataType::Int);
191+ auto tv4 = scatter (tv2, 1 , tv1, tv3);
192+ auto tv5 = set (tv4);
193+ fusion.addOutput (tv5);
194+
195+ auto options = at::TensorOptions ().dtype (at::kLong ).device (at::kCUDA , 0 );
196+ auto t0 = at::randperm (self_shape[1 ], options)
197+ .slice (0 , 0 , index_shape[1 ])
198+ .repeat ({index_shape[0 ], 1 });
199+
200+ if (manual_scheduling) {
201+ for (auto tv : fusion.allTvs ()) {
202+ tv->axis (0 )->parallelize (ParallelType::BIDx);
203+ tv->axis (1 )->parallelize (ParallelType::TIDx);
204+ }
205+
206+ // Scatter input must use the same memory as the output
207+ tv2->setMemoryType (MemoryType::Shared);
208+ tv2->setAllocationDomain (tv2->getLogicalDomain (), true );
209+ tv4->setMemoryType (MemoryType::Shared);
210+ tv4->setAllocationDomain (tv4->getLogicalDomain (), true );
211+
212+ KernelExecutor ke;
213+ ke.compile (&fusion, {t0});
214+ auto outputs = ke.run ({t0});
215+
216+ testValidate (&fusion, outputs, {t0}, __LINE__, __FILE__);
217+ } else {
218+ FusionExecutorCache executor_cache (std::move (fusion_ptr));
219+ auto outputs = executor_cache.runFusionWithInputs ({t0});
220+ testValidate (executor_cache.fusion (), outputs, {t0}, __LINE__, __FILE__);
221+ FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime ();
222+ // All ops should be taken care the greedy scheduler, but there's
223+ // an additional segment due to a segmenter_set. Not sure why it
224+ // gets inserted.
225+ EXPECT_THAT (
226+ runtime->fusionSegments ()->groups (),
227+ testing::UnorderedElementsAre (
228+ HeuristicIs (SchedulerType::ExprEval),
229+ HeuristicIs (SchedulerType::Greedy)));
230+ }
231+ }
232+
233+ TEST_F (ScatterTest, BlockCountingWithShmem2DNonExact) {
172234 // Scatter allows the non-indexed domains of the index tensor to
173235 // have smaller extents, which causes indexing error as there's not
174236 // traversal path. It is not currently supported.
@@ -209,7 +271,9 @@ TEST_F(ScatterTest, BlockCountingWithShmem2D) {
209271 tv4->setAllocationDomain (tv4->getLogicalDomain (), true );
210272
211273 auto options = at::TensorOptions ().dtype (at::kLong ).device (at::kCUDA , 0 );
212- auto t0 = at::randperm (self_shape[1 ], options).slice (0 , 0 , index_shape[1 ]);
274+ auto t0 = at::randperm (self_shape[1 ], options)
275+ .slice (0 , 0 , index_shape[1 ])
276+ .repeat ({index_shape[0 ], 1 });
213277
214278 KernelExecutor ke;
215279 ke.compile (&fusion, {t0});
0 commit comments