Skip to content

Commit db9721d

Browse files
authored
Allow ScatterOp with multiple dimensions as long as extents are the same (#5175)
Analogous to the exact size attribute of `GatherOp`.
1 parent 9dff34b commit db9721d

File tree

8 files changed

+122
-28
lines changed

8 files changed

+122
-28
lines changed

csrc/device_lower/pass/index.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@ void IndexLowering::handle(const ScatterOp* sop) {
372372
sop->dim(),
373373
lowered_index,
374374
lowered_src,
375+
sop->exactSizes(),
375376
sop->accumulate() ? std::optional(sop->accumulateOp()) : std::nullopt));
376377
GpuLower::current()->propagateExprInfo(sop, back());
377378
}

csrc/device_lower/validation.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1339,12 +1339,10 @@ void validateScatter(Fusion* fusion) {
13391339
auto in_tv = sop->in()->as<TensorView>();
13401340
auto out_tv = sop->out()->as<TensorView>();
13411341

1342-
// TensorIndexer currently only supports scatter with 1D tensors
1343-
// due to the non-exactness of non-indexed IDs.
1344-
NVF_ERROR_EQ(
1345-
out_tv->getLogicalDomain().size(),
1346-
1,
1347-
"Scatter with multi-dimensional tensors is not yet supported: ",
1342+
// TensorIndexer currently only supports exact scatter ops
1343+
NVF_ERROR(
1344+
sop->exactSizes(),
1345+
"Non-exact scatter is not yet supported: ",
13481346
sop->toString());
13491347

13501348
// Scatter is implemented as an in-place op. To lower it safely, it

csrc/ir/internal_nodes.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,13 +252,17 @@ class GatherOp : public Expr {
252252
class ScatterOp : public Expr {
253253
public:
254254
using Expr::Expr;
255+
256+
// exact_sizes: true when non-scatter axes of all inputs are
257+
// guaranteed to have the same extents
255258
ScatterOp(
256259
IrBuilderPasskey,
257260
Val* out,
258261
Val* self,
259262
int64_t dim,
260263
Val* index,
261264
Val* src,
265+
bool exact_sizes,
262266
std::optional<BinaryOpType> accumulate_op = std::nullopt);
263267

264268
NVFUSER_DECLARE_CLONE_AND_CREATE
@@ -295,13 +299,17 @@ class ScatterOp : public Expr {
295299

296300
IterDomain* getIndexedID() const;
297301

298-
bool accumulate() const {
302+
bool exactSizes() const {
299303
return attribute<bool>(1);
300304
}
301305

306+
bool accumulate() const {
307+
return attribute<bool>(2);
308+
}
309+
302310
BinaryOpType accumulateOp() const {
303311
NVF_ERROR(accumulate());
304-
return attribute<BinaryOpType>(2);
312+
return attribute<BinaryOpType>(3);
305313
}
306314
};
307315

csrc/ir/nodes.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,13 +292,15 @@ ScatterOp::ScatterOp(
292292
int64_t dim,
293293
Val* index,
294294
Val* src,
295+
bool exact_sizes,
295296
std::optional<BinaryOpType> accumulate_op)
296297
: Expr(passkey) {
297298
addInput(self);
298299
addInput(index);
299300
addInput(src);
300301
addOutput(out);
301302
addDataAttribute(dim);
303+
addDataAttribute(exact_sizes);
302304
// Record whether this scatter accumulates (i.e. an accumulate op was given).
303305
addDataAttribute(accumulate_op.has_value());
304306
if (accumulate_op.has_value()) {

csrc/logical_domain_map.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,17 @@ std::pair<std::unordered_set<IterDomain*>, bool> getNonMappingDomainInfo(
141141
// we are not mapping anything, `has_consumer_id` doesn't matter.
142142
has_consumer_id = false;
143143
}
144+
} else if (auto sop = dynamic_cast<ScatterOp*>(consumer_tv->definition())) {
145+
if (producer_tv != sop->in()) {
146+
auto producer_logical =
147+
TensorDomain::noReductions(producer_tv->getLogicalDomain());
148+
for (const auto& [i, p_id] : enumerate(producer_logical)) {
149+
if ((int64_t)i == sop->dim() || !sop->exactSizes()) {
150+
non_mapping_ids.insert(p_id);
151+
}
152+
}
153+
has_consumer_id = true;
154+
}
144155
}
145156

146157
return std::make_pair(non_mapping_ids, has_consumer_id);
@@ -153,15 +164,6 @@ std::unordered_map<IterDomain*, IterDomain*> PairwiseLogicalDomainMap::map(
153164
const TensorDomain* consumer,
154165
const std::unordered_set<IterDomain*>& dims_to_map,
155166
bool producer_to_consumer) const {
156-
// In the case of scatter, nothing is guaranteed to map except for
157-
// the self producer. Note that in PyTorch even non-indexed
158-
// dimensions of index and src tensors are not guaranteed to have
159-
// the same extent as the self/out tensors.
160-
if (auto sop = dynamic_cast<ScatterOp*>(consumer_tv_->definition());
161-
sop != nullptr && producer_tv_ != sop->in()) {
162-
return {};
163-
}
164-
165167
std::vector<bool> broadcast_flags;
166168
if (auto* bop = dynamic_cast<BroadcastOp*>(consumer_tv_->definition())) {
167169
broadcast_flags = bop->getBroadcastDimFlags();

csrc/ops/indexing.cpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
*/
77
// clang-format on
88

9+
#include <expr_simplifier.h>
910
#include <ir/all_nodes.h>
1011
#include <ir/builder.h>
1112
#include <ir/iostream.h>
@@ -182,6 +183,22 @@ TensorView* scatter(
182183
"dimensions in scatter like ops.");
183184
dim = wrapDim(dim, (int64_t)self_dom.size());
184185

186+
bool is_exact = true;
187+
for (const auto i : arange(std::ssize(self_dom))) {
188+
if (i == dim) {
189+
continue;
190+
}
191+
Val* self_id_size = self_dom.at(i)->getMaybeExpandedExtent();
192+
Val* idx_id_size = idx_dom.at(i)->getMaybeExpandedExtent();
193+
auto same_size =
194+
simplifyExpr(SimplifyingIrBuilder::eqExpr(self_id_size, idx_id_size));
195+
if (same_size->isTrue()) {
196+
continue;
197+
}
198+
is_exact = false;
199+
break;
200+
}
201+
185202
// The shape of output tensor is same as self tensor.
186203
std::vector<IterDomain*> out_logical;
187204
for (const auto i : arange(self_dom.size())) {
@@ -195,13 +212,16 @@ TensorView* scatter(
195212
}
196213

197214
// Create the loop domain based on the logical domain of the index
198-
// tensor.
215+
// tensor. For non-scatter axes, reuse the logical IDs if exact.
199216
std::vector<IterDomain*> out_loop;
200217
out_loop.reserve(idx_dom.size());
201-
std::ranges::transform(
202-
idx_dom, std::back_inserter(out_loop), [](IterDomain* id) {
203-
return IterDomainBuilder(id).build();
204-
});
218+
for (const auto& [i, idx_id] : enumerate(idx_dom)) {
219+
if ((int64_t)i == dim || !is_exact) {
220+
out_loop.push_back(IterDomainBuilder(idx_id).build());
221+
} else {
222+
out_loop.push_back(out_logical.at(i));
223+
}
224+
}
205225

206226
// Create the output tensor. The validation of the loop domain needs
207227
// to be skipped as it is not guaranteed to be equivalent to the
@@ -226,7 +246,7 @@ TensorView* scatter(
226246
}
227247

228248
IrBuilder::create<ScatterOp>(
229-
out_tensor, self, dim, index, src, accumulate_op);
249+
out_tensor, self, dim, index, src, is_exact, accumulate_op);
230250

231251
return out_tensor->as<TensorView>();
232252
}

csrc/scheduler/greedy.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -239,10 +239,9 @@ class CompileTimeChecker : private IterVisitor {
239239
auto inp = scatter->in()->as<TensorView>();
240240
auto out = scatter->out()->as<TensorView>();
241241

242-
if (out->getLogicalDomain().size() != 1) {
242+
if (!scatter->exactSizes()) {
243243
can_schedule_ = false;
244-
setRejectReason(
245-
"Scatter with multi-dimensional tensors is not yet supported");
244+
setRejectReason("Non-exact scatter is not yet supported");
246245
return;
247246
}
248247

tests/cpp/test_scatter.cpp

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,69 @@ TEST_F(ScatterTest, GridCounting) {
168168
testValidate(&fusion, outputs, {t0}, __LINE__, __FILE__);
169169
}
170170

171-
TEST_F(ScatterTest, BlockCountingWithShmem2D) {
171+
TEST_P(ScatterTest, BlockCountingWithShmem2DExact) {
172+
auto fusion_ptr = std::make_unique<Fusion>();
173+
Fusion& fusion = *fusion_ptr.get();
174+
FusionGuard fg(&fusion);
175+
176+
const std::vector<int64_t> self_shape{4, 100};
177+
const std::vector<int64_t> index_shape{4, 10};
178+
179+
auto tv0 = makeContigConcreteTensor(index_shape, DataType::Int);
180+
fusion.addInput(tv0);
181+
182+
auto tv1 = set(tv0);
183+
auto tv2 = zeros(
184+
{IrBuilder::create<Val>(self_shape[0]),
185+
IrBuilder::create<Val>(self_shape[1])},
186+
DataType::Int);
187+
auto tv3 = ones(
188+
{IrBuilder::create<Val>(index_shape[0]),
189+
IrBuilder::create<Val>(index_shape[1])},
190+
DataType::Int);
191+
auto tv4 = scatter(tv2, 1, tv1, tv3);
192+
auto tv5 = set(tv4);
193+
fusion.addOutput(tv5);
194+
195+
auto options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
196+
auto t0 = at::randperm(self_shape[1], options)
197+
.slice(0, 0, index_shape[1])
198+
.repeat({index_shape[0], 1});
199+
200+
if (manual_scheduling) {
201+
for (auto tv : fusion.allTvs()) {
202+
tv->axis(0)->parallelize(ParallelType::BIDx);
203+
tv->axis(1)->parallelize(ParallelType::TIDx);
204+
}
205+
206+
// Scatter input must use the same memory as the output
207+
tv2->setMemoryType(MemoryType::Shared);
208+
tv2->setAllocationDomain(tv2->getLogicalDomain(), true);
209+
tv4->setMemoryType(MemoryType::Shared);
210+
tv4->setAllocationDomain(tv4->getLogicalDomain(), true);
211+
212+
KernelExecutor ke;
213+
ke.compile(&fusion, {t0});
214+
auto outputs = ke.run({t0});
215+
216+
testValidate(&fusion, outputs, {t0}, __LINE__, __FILE__);
217+
} else {
218+
FusionExecutorCache executor_cache(std::move(fusion_ptr));
219+
auto outputs = executor_cache.runFusionWithInputs({t0});
220+
testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__);
221+
FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
222+
// All ops should be taken care of by the greedy scheduler, but there's
223+
// an additional segment due to a segmenter_set. Not sure why it
224+
// gets inserted.
225+
EXPECT_THAT(
226+
runtime->fusionSegments()->groups(),
227+
testing::UnorderedElementsAre(
228+
HeuristicIs(SchedulerType::ExprEval),
229+
HeuristicIs(SchedulerType::Greedy)));
230+
}
231+
}
232+
233+
TEST_F(ScatterTest, BlockCountingWithShmem2DNonExact) {
172234
// Scatter allows the non-indexed domains of the index tensor to
173235
// have smaller extents, which causes an indexing error as there's no
174236
// traversal path. It is not currently supported.
@@ -209,7 +271,9 @@ TEST_F(ScatterTest, BlockCountingWithShmem2D) {
209271
tv4->setAllocationDomain(tv4->getLogicalDomain(), true);
210272

211273
auto options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
212-
auto t0 = at::randperm(self_shape[1], options).slice(0, 0, index_shape[1]);
274+
auto t0 = at::randperm(self_shape[1], options)
275+
.slice(0, 0, index_shape[1])
276+
.repeat({index_shape[0], 1});
213277

214278
KernelExecutor ke;
215279
ke.compile(&fusion, {t0});

0 commit comments

Comments
 (0)