Skip to content

Commit e00a9ed

Browse files
committed
skip mma
1 parent 2372487 commit e00a9ed

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

csrc/device_lower/analysis/tma.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,8 +1180,13 @@ TMAInfo getTMAInfo(LoadStoreOp* ldst) {
11801180

11811181
// Validate broadcast usage: TMA auto-fills out-of-bounds with zeros,
11821182
// breaking broadcast semantics when broadcast dims participate in tile
1183-
// shape.
1184-
validateTMAConsumerBroadcasts(smem_tv);
1183+
// shape. Skip if smem_tv is a producer of mma op
1184+
const auto& consumers = ir_utils::consumerTvsOf(smem_tv);
1185+
if (std::none_of(consumers.begin(), consumers.end(), [](TensorView* tv) {
1186+
return tv->definition()->isA<MmaOp>();
1187+
})) {
1188+
validateTMAConsumerBroadcasts(smem_tv);
1189+
}
11851190

11861191
MmaInputSmemSwizzle swizzle = getSwizzle(smem_tv);
11871192

0 commit comments

Comments
 (0)