We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 2372487 commit e00a9edCopy full SHA for e00a9ed
csrc/device_lower/analysis/tma.cpp
@@ -1180,8 +1180,13 @@ TMAInfo getTMAInfo(LoadStoreOp* ldst) {
1180
1181
// Validate broadcast usage: TMA auto-fills out-of-bounds with zeros,
1182
// breaking broadcast semantics when broadcast dims participate in tile
1183
- // shape.
1184
- validateTMAConsumerBroadcasts(smem_tv);
+ // shape. Skip if smem_tv is a producer of mma op
+ const auto& consumers = ir_utils::consumerTvsOf(smem_tv);
1185
+ if (std::none_of(consumers.begin(), consumers.end(), [](TensorView* tv) {
1186
+ return tv->definition()->isA<MmaOp>();
1187
+ })) {
1188
+ validateTMAConsumerBroadcasts(smem_tv);
1189
+ }
1190
1191
MmaInputSmemSwizzle swizzle = getSwizzle(smem_tv);
1192
0 commit comments