Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 0f1eba to 1815c3
6 changes: 3 additions & 3 deletions docs/compiler_internals/inject_fence_proxy.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ The pass is conservative: unknown extern calls are treated as async so that the
### Timeline View

```
generic initialize_descriptor → generic shared-store → async wgmma
generic initialize_wgmma_descriptor → generic shared-store → async wgmma
│ │ │
└─ generic proxy ┴─ generic proxy ┴─ async proxy
│ fence inserted here ↑
Expand Down Expand Up @@ -53,7 +53,7 @@ def kernel():
with T.Kernel(1):
desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
smem = T.decl_buffer((128,), "float16", scope="shared")
T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32)
smem[0] = T.float16(0)
T.ptx_wgmma_ss(
"float16",
Expand Down Expand Up @@ -83,7 +83,7 @@ def kernel():
with T.Kernel(1):
desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
smem = T.decl_buffer((128,), "float16", scope="shared")
T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32)
smem[0] = T.float16(0)
T.fence_proxy_async()
T.ptx_wgmma_ss(
Expand Down
6 changes: 6 additions & 0 deletions src/layout/layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,12 @@ TVM_FFI_STATIC_INIT_BLOCK() {
return makeGemmABLayoutHopper(stride, mat_continuous, continuity,
element_size, k_inner);
})
.def("tl.make_tcgen05mma_swizzled_layout",
[](int stride, int mat_continuous, int continuity, int element_size,
bool k_inner) {
return makeGemmABLayoutSm100(stride, mat_continuous, continuity,
element_size, k_inner);
})
.def("tl.make_full_bank_swizzled_layout",
[](int stride, int continuous, int element_size) {
return makeFullBankSwizzleLayout(stride, continuous, element_size);
Expand Down
27 changes: 26 additions & 1 deletion src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,16 @@ TIR_DEFINE_TL_BUILTIN(ptx_wgmma_rs)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ss)
.set_num_inputs(14)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ts)
.set_num_inputs(13)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(ptx_init_tensor_memory)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
Expand Down Expand Up @@ -219,6 +229,11 @@ TIR_DEFINE_TL_BUILTIN(warpgroup_wait)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(warpgroup_fence_operand)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(get_lane_idx)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Expand Down Expand Up @@ -286,11 +301,16 @@ TIR_DEFINE_TL_BUILTIN(tl_shuffle_elect)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));

TIR_DEFINE_TL_BUILTIN(initialize_descriptor)
TIR_DEFINE_TL_BUILTIN(initialize_wgmma_descriptor)
.set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(initialize_tcgen05_descriptor)
.set_num_inputs(7)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(increase_descriptor_offset)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
Expand All @@ -311,5 +331,10 @@ TIR_DEFINE_TL_BUILTIN(device_assert_with_msg)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(tcgen05_mma_arrive)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm
44 changes: 38 additions & 6 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,14 +241,24 @@ TVM_DLL const Op &ptx_wgmma_ss();
/*!
* \brief tvm intrinsics for ptx tensor core wgmma instructions.
*
* void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix, bool
* a_is_k_major, bool b_is_k_major, StringImm a_dtype_abbrv, StringImm
* b_dtype_abbrv, StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr
* A_offset, Var B_descriptor, Var B_offset, Var C_data, Var C_offset, bool
* scale_out, bool scale_in_a, bool scale_in_b);
* void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix,
* bool b_is_k_major, StringImm a_dtype_abbrv, StringImm b_dtype_abbrv,
* StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr A_offset, Var
* B_descriptor, Var B_offset, Var C_data, Var C_offset, bool scale_out,
* bool scale_in_a, bool scale_in_b);
*/
TVM_DLL const Op &ptx_wgmma_rs();

/*!
* \brief tvm intrinsic for tcgen05 mma shared-shared instructions.
*/
TVM_DLL const Op &ptx_tcgen05_mma_ss();

/*!
* \brief tvm intrinsic for tcgen05 mma tensor-shared instructions.
*/
TVM_DLL const Op &ptx_tcgen05_mma_ts();

/*!
* \brief tvm intrinsics for initializing tensor memory
*
Expand Down Expand Up @@ -361,6 +371,14 @@ TVM_DLL const Op &warpgroup_commit_batch();
*/
TVM_DLL const Op &warpgroup_wait();

/*!
* \brief Fence accumulator operand registers for upcoming WGMMA operations
*
* warpgroup_fence_operand(dtype, ptr, offset, num_regs)
*
*/
TVM_DLL const Op &warpgroup_fence_operand();

/*!
* \brief Return the canonical lane index for the calling thread.
*
Expand Down Expand Up @@ -494,7 +512,21 @@ TVM_DLL const Op &tl_shuffle_elect();
* This op is used to represent a descriptor initialization operation in
* tilelang.
*/
TVM_DLL const Op &initialize_descriptor();
TVM_DLL const Op &initialize_wgmma_descriptor();

/*!
* \brief tilelang intrinsic for initializing a descriptor buffer for
* tcgen05 mma.
*/
TVM_DLL const Op &initialize_tcgen05_descriptor();

/*!
* \brief tilelang intrinsic for committing UMMA (TCGEN05) barrier arrive.
*
* This op wraps the device-side arrive used to signal completion of MMA work
* to a shared-memory mbarrier. It mirrors CUTLASS's umma_arrive.
*/
TVM_DLL const Op &tcgen05_mma_arrive();

/*!
* \brief tilelang intrinsic for setting the start address of a descriptor
Expand Down
72 changes: 6 additions & 66 deletions src/op/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,77 +12,13 @@
#include <tvm/tir/transform.h>

#include "../target/utils.h"
#include "tcgen5_meta.h"

namespace tvm {
namespace tl {

using namespace tir;

struct TCGEN5MMAMeta {
  int atom_m, atom_n, atom_k;
};

// Return {is_success, meta}: whether (M, N, K, dtypes) can be tiled by a
// TCGEN05 MMA atom, and if so which (atom_m, atom_n, atom_k) to use.
static inline std::pair<bool, TCGEN5MMAMeta>
GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
  // TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
  constexpr TCGEN5MMAMeta kInvalid{0, 0, 0};

  // All currently supported operand dtypes require an fp32 accumulator.
  if (!(c_dtype.is_float() && c_dtype.bits() == 32))
    return {false, kInvalid};

  // atom_k is fixed by the operand dtype: 16 for fp16/bf16, 32 for fp8.
  int atom_k;
  if (ab_dtype.is_bfloat16() || ab_dtype.is_float16()) {
    atom_k = 16;
  } else if (ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) {
    atom_k = 32;
  } else {
    return {false, kInvalid};
  }
  if (K % atom_k != 0)
    return {false, kInvalid};

  // Choose atom_m from the largest M granularity that divides M, then pick
  // the widest atom_n that divides N.  The M % 128 path may use any multiple
  // of 16 in [16, 256]; the 64- and 32-row paths only allow
  // {256, 128, 64} (ws_valid_atom_ns in the original code).
  int atom_m;
  std::vector<int> valid_atom_ns;
  if (M % 128 == 0) {
    atom_m = 128;
    for (int n = 256; n >= 16; n -= 16)
      valid_atom_ns.push_back(n);
  } else if (M % 64 == 0) {
    atom_m = 64;
    valid_atom_ns = {256, 128, 64};
  } else if (M % 32 == 0) {
    atom_m = 32;
    valid_atom_ns = {256, 128, 64};
  } else {
    return {false, kInvalid};
  }
  for (int atom_n : valid_atom_ns)
    if (N % atom_n == 0)
      return {true, TCGEN5MMAMeta{atom_m, atom_n, atom_k}};
  return {false, kInvalid};
}

/**
* @brief Construct a Gemm operator from serialized TL arguments and a buffer
* map.
Expand Down Expand Up @@ -186,6 +122,8 @@ bool GemmNode::AllowWGMMA(int block_size, Target target) const {
GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
bool allow_tcgen5mma = AllowTCGEN5MMA(target);
bool allow_wgmma = AllowWGMMA(block_size, target);
LOG(INFO) << "allow_tcgen5mma: " << allow_tcgen5mma
<< ", allow_wgmma: " << allow_wgmma;
if (allow_tcgen5mma) {
return GemmInst::kTCGEN5MMA;
} else if (allow_wgmma) {
Expand All @@ -195,7 +133,7 @@ GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
} else if (TargetIsCuda(target)) {
return GemmInst::kMMA;
} else {
ICHECK(0) << "Unsupported target for gemm: " << target->str();
ICHECK(0) << "Unsupported target for gemm: " << target;
}
}

Expand Down Expand Up @@ -578,6 +516,8 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {

if (A.scope() == "local.fragment") {
ICHECK(B.scope() != "local.fragment");
ICHECK(!trans_A)
<< "gemm_rs requires the A operand to be in non-transposed layout.";
op_name = "tl::gemm_rs";
} else if (B.scope() == "local.fragment") {
op_name = "tl::gemm_sr";
Expand Down
73 changes: 67 additions & 6 deletions src/op/gemm_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

#include "../support/ffi_aliases.h"
#include "../target/utils.h"
#include "tcgen5_meta.h"
#include "tvm/ffi/string.h"

namespace tvm {
namespace tl {
Expand Down Expand Up @@ -49,7 +51,6 @@ using namespace tir;
*/
GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<GemmPyNode> node = tvm::ffi::make_object<GemmPyNode>();

node->Aptr = args[0];
node->Bptr = args[1];
node->Cptr = args[2];
Expand All @@ -76,6 +77,19 @@ GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
if (args.size() > 15) {
node->wg_wait = args[15].as<IntImm>().value()->value;
}
if (args.size() > 16) {
node->mbarptr = args[16];
} else {
node->mbarptr = IntImm(DataType::UInt(32), 0);
}
if (args.size() > 18) {
node->C_coords = Array<PrimExpr>({args[17], args[18]});
} else if (args.size() > 17) {
node->C_coords = Array<PrimExpr>({args[17], IntImm(DataType::Int(32), 0)});
} else {
node->C_coords = Array<PrimExpr>(
{IntImm(DataType::Int(32), 0), IntImm(DataType::Int(32), 0)});
}
data_ = std::move(node);
}

Expand All @@ -92,16 +106,37 @@ TileOperator GemmPyNode::Clone() const {
return GemmPy(op);
}

GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const {
/**
 * \brief Whether this GEMM may be lowered to a TCGEN05 MMA instruction.
 *
 * Requires an SM100-class target, A in shared memory (shared / shared.dyn /
 * shared.tmem), B in shared memory, C in tensor memory, and an
 * (M, N, K, dtype) combination accepted by GetTCGEN5MMAMeta.
 */
bool GemmPyNode::AllowTCGEN5MMA(Target target) const {
  if (!TargetIsSm100(target))
    return false;
  const auto a_scope = A.scope();
  const auto b_scope = B.scope();
  const bool a_scope_ok = a_scope == "shared.dyn" || a_scope == "shared" ||
                          a_scope == "shared.tmem";
  const bool b_scope_ok = b_scope == "shared.dyn" || b_scope == "shared";
  if (!a_scope_ok || !b_scope_ok || C.scope() != "shared.tmem")
    return false;
  return GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype).first;
}
Comment on lines +109 to +116
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

AllowTCGEN5MMA gate LGTM; verify TargetIsSm100 logic

Condition looks right. However, TargetIsSm100 in src/target/utils.cc appears to use bitwise '&' instead of '&&', which can mis-gate SM100 detection.

If it shows arch >= 100 & arch <= 110, change to arch >= 100 && arch <= 110 in src/target/utils.cc. This affects AllowTCGEN5MMA().


🏁 Script executed:

#!/bin/bash
# Verify TargetIsSm100 implementation uses logical &&, not bitwise &.
rg -nC2 'bool\s+TargetIsSm100' src | sed -n '1,120p'

Length of output: 489


🏁 Script executed:

#!/bin/bash
sed -n '56,90p' src/target/utils.cc

Length of output: 1002


Replace bitwise & with logical && in TargetIsSm100
In src/target/utils.cc:58 change arch >= 100 & arch <= 110 to use && for correct SM100 detection.

🤖 Prompt for AI Agents
In src/op/gemm_py.cc around lines 110 to 117, the review points out that
TargetIsSm100 uses a bitwise & instead of logical && when checking architecture
range; update TargetIsSm100 in src/target/utils.cc (line ~58) to use `&&` (arch
>= 100 && arch <= 110) so the SM100 detection is correct, then rebuild and run
tests that exercise GemmPyNode::AllowTCGEN5MMA to ensure the change fixes
selection logic.


bool GemmPyNode::AllowWGMMA(int block_size, Target target) const {
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();

int warp_size = TargetGetWarpSize(target);
int num_warps = block_size / warp_size;
bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
(num_warps % 4 == 0) && CheckWGMMA();
if (allow_wgmma) {
return !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0) &&
CheckWGMMA();
}

GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const {
bool allow_tcgen5mma = AllowTCGEN5MMA(target);
bool allow_wgmma = AllowWGMMA(block_size, target);
if (allow_tcgen5mma) {
return GemmInst::kTCGEN5MMA;
} else if (allow_wgmma) {
return GemmInst::kWGMMA;
} else if (TargetIsCDNA(target)) {
return GemmInst::kMFMA;
} else if (TargetIsCuda(target)) {
} else if (TargetIsVolta(target) || TargetIsAmpere(target) ||
TargetIsTuring(target) || TargetIsHopper(target) ||
TargetIsSm100(target)) {
return GemmInst::kMMA;
} else {
ICHECK(0) << "Unsupported target for gemm: " << target->str();
Expand Down Expand Up @@ -290,5 +325,31 @@ TVM_FFI_STATIC_INIT_BLOCK() {
});
}

TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def(
"tl.get_tcgen5_mma_meta",
[](int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
auto [success, meta] = GetTCGEN5MMAMeta(M, N, K, ab_dtype, c_dtype);
Array<Integer> result;
if (success) {
result.push_back(Integer(meta.atom_m));
result.push_back(Integer(meta.atom_n));
result.push_back(Integer(meta.atom_k));
}
return result;
});
refl::GlobalDef().def(
"tl.get_tcgen5_instr_desc",
[](int atom_m, int atom_n, int atom_k, DataType ab_dtype,
DataType c_dtype, bool a_is_k_major, bool b_is_k_major, int scale_in_a,
int scale_in_b) {
uint32_t desc = GetTCGEN5InstrDesc(atom_m, atom_n, atom_k, ab_dtype,
c_dtype, a_is_k_major, b_is_k_major,
scale_in_a, scale_in_b);
return Integer(static_cast<int64_t>(desc));
});
}

} // namespace tl
} // namespace tvm
Loading
Loading