Skip to content
5 changes: 5 additions & 0 deletions src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ TIR_DEFINE_TL_BUILTIN(warpgroup_wait)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(warpgroup_fence_operand)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(wait_wgmma)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Expand Down
18 changes: 13 additions & 5 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,11 +238,11 @@ TVM_DLL const Op &ptx_wgmma_ss();
/*!
* \brief tvm intrinsics for ptx tensor core wgmma instructions.
*
* void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix, bool
* a_is_k_major, bool b_is_k_major, StringImm a_dtype_abbrv, StringImm
* b_dtype_abbrv, StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr
* A_offset, Var B_descriptor, Var B_offset, Var C_data, Var C_offset, bool
* scale_out, bool scale_in_a, bool scale_in_b);
 * void ptx_wgmma_rs(StringImm shape, bool b_is_k_major,
 * StringImm a_dtype_abbrv, StringImm b_dtype_abbrv, StringImm
 * accum_dtype_abbrv, Var A_data, PrimExpr A_offset, Var B_descriptor,
 * Var B_offset, Var C_data, Var C_offset, bool scale_out,
 * bool scale_in_a, bool scale_in_b);
*/
TVM_DLL const Op &ptx_wgmma_rs();

Expand Down Expand Up @@ -358,6 +358,14 @@ TVM_DLL const Op &warpgroup_commit_batch();
*/
TVM_DLL const Op &warpgroup_wait();

/*!
* \brief Fence accumulator operand registers for upcoming WGMMA operations
*
* warpgroup_fence_operand(dtype, ptr, offset, num_regs)
*
*/
TVM_DLL const Op &warpgroup_fence_operand();

/*!
* \brief Wait the previous wgmma to finish
*
Expand Down
2 changes: 2 additions & 0 deletions src/op/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,8 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {

if (A.scope() == "local.fragment") {
ICHECK(B.scope() != "local.fragment");
ICHECK(!trans_A)
<< "gemm_rs requires the A operand to be in non-transposed layout.";
op_name = "tl::gemm_rs";
} else if (B.scope() == "local.fragment") {
op_name = "tl::gemm_sr";
Expand Down
161 changes: 123 additions & 38 deletions src/target/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,12 @@ std::string CodeGenTileLangCUDA::Finish() {
if (need_mma_h_) {
decl_stream << "#include <mma.h>\n";
}
if (need_mma_instruction_h_) {
decl_stream << "#include <tl_templates/cuda/instruction/mma.h>\n";
}
if (need_wgmma_instruction_h_) {
decl_stream << "#include <tl_templates/cuda/instruction/wgmma.h>\n";
}
if (enable_fp8_) {
decl_stream << "#include <tl_templates/cuda/cuda_fp8.h>\n";
}
Expand Down Expand Up @@ -1383,6 +1389,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
int num_mma = Downcast<IntImm>(op->args[0])->value;
this->stream << "tl::warpgroup_wait<" << std::to_string(num_mma)
<< ">();\n";
} else if (op->op.same_as(tl::warpgroup_fence_operand())) {
ICHECK_EQ(op->args.size(), 4U);
std::string dtype = Downcast<StringImm>(op->args[0])->value;
std::string data_ptr = this->PrintExpr(op->args[1]);
std::string offset = this->PrintExpr(op->args[2]);
std::string num_regs = this->PrintExpr(op->args[3]);
auto dtype_enum = tl::codegen::ptx::DTypeFromString(dtype);
std::string cast_type = "uint32_t";
if (dtype_enum == tl::codegen::ptx::DataType::kFloat32 ||
dtype_enum == tl::codegen::ptx::DataType::kTensorFloat32) {
cast_type = "float";
}
this->PrintIndent();
this->stream << "tl::warpgroup_fence_operand(reinterpret_cast<" << cast_type
<< "*>(" << data_ptr << " + " << offset << "), " << num_regs
<< ");\n";
} else if (op->op.same_as(tl::set_max_nreg())) {
this->PrintIndent();
int nreg = Downcast<IntImm>(op->args[0])->value;
Expand Down Expand Up @@ -1494,14 +1516,41 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
std::string b_bias = this->PrintExpr(op->args[9]);
std::string c_ref = this->PrintExpr(op->args[10]);
std::string c_bias = this->PrintExpr(op->args[11]);
bool saturate = Downcast<Bool>(op->args[12])->value;
std::string bit_op =
op->args.size() > 13 ? Downcast<StringImm>(op->args[13])->value : "";
std::string asm_code = PrintMMAAssembly(
shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias,
b_ref, b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate);
auto dtype_a_enum = tl::codegen::ptx::DTypeFromString(A_dtype);
auto dtype_b_enum = tl::codegen::ptx::DTypeFromString(B_dtype);
auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype);
auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);

need_mma_instruction_h_ = true;
this->PrintIndent();
this->stream << asm_code;
std::string mma_call =
"tl::mma_sync<(AType), (BType), (CType), (M), (N), (K), (TransA), "
"(TransB)>(reinterpret_cast<(CRegType)*>((C_ptr) + (C_offset)), "
"reinterpret_cast<const (ARegType)*>((A_ptr) + (A_offset)), "
"reinterpret_cast<const (BRegType)*>((B_ptr) + (B_offset)));\n";
tl::codegen::Replacer replacer;
replacer.register_rule("(AType)",
tl::codegen::ptx::DTypeEnumToString(dtype_a_enum));
replacer.register_rule("(BType)",
tl::codegen::ptx::DTypeEnumToString(dtype_b_enum));
replacer.register_rule("(CType)",
tl::codegen::ptx::DTypeEnumToString(dtype_c_enum));
replacer.register_rule("(M)", std::to_string(m));
replacer.register_rule("(N)", std::to_string(n));
replacer.register_rule("(K)", std::to_string(k));
replacer.register_rule("(TransA)", A_layout == "row" ? "false" : "true");
replacer.register_rule("(TransB)", B_layout == "row" ? "false" : "true");
replacer.register_rule("(ARegType)",
tl::codegen::GetMMARegisterType(dtype_a_enum));
replacer.register_rule("(BRegType)",
tl::codegen::GetMMARegisterType(dtype_b_enum));
replacer.register_rule("(A_ptr)", a_ref);
replacer.register_rule("(A_offset)", a_bias);
replacer.register_rule("(B_ptr)", b_ref);
replacer.register_rule("(B_offset)", b_bias);
replacer.register_rule("(C_ptr)", c_ref);
replacer.register_rule("(C_offset)", c_bias);
this->stream << replacer.rewrite(mma_call);
} else if (op->op.same_as(builtin::ptx_mma_sp())) {
// arg 0: shape: mXnXkX
// arg 1: A layout: row/col
Expand Down Expand Up @@ -1578,6 +1627,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
A_offset, b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a,
scale_in_b, a_is_shared, "", "", "", false);
auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
need_wgmma_instruction_h_ = true;
std::string wgmma_asm_code =
"tl::wgmma_ss<(AType), (BType), (CType), (M), (N), (K), (tnspA), "
"(tnspB), (scaleA), (scaleB)>(uint64_t((desc_a) + (A_offset)), "
Expand Down Expand Up @@ -1606,41 +1656,76 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
wgmma_asm_code = replacer.rewrite(wgmma_asm_code);
this->stream << wgmma_asm_code;
} else if (op->op.same_as(tl::ptx_wgmma_rs())) {
// arg 0: dtype
// arg 1: shape
// arg 2: A_layout
// arg 3: B_layout
// arg 4: A_dtype
// arg 5: B_dtype
// arg 6: C_dtype
// arg 7: multiplicand_a
// arg 8: multiplicand_b
// arg 0: shape
// arg 1: B_layout
// arg 2: A_dtype
// arg 3: B_dtype
// arg 4: C_dtype
// arg 5: multiplicand_a
// arg 6: multiplicand_a offset
// arg 7: multiplicand_b descriptor
// arg 8: multiplicand_b offset
// arg 9: accumulator
// arg 10: saturate
ICHECK_EQ(op->args.size(), 15U) << "ptx_wgmma_rs args is " << op->args;
// arg 10: accumulator offset
// arg 11: scale_out
// arg 12: scale_in_a
// arg 13: scale_in_b
ICHECK_EQ(op->args.size(), 14U) << "ptx_wgmma_rs args is " << op->args;
std::string shape = Downcast<StringImm>(op->args[0])->value;
bool A_layout = Downcast<Bool>(op->args[1])->value;
bool B_layout = Downcast<Bool>(op->args[2])->value;
std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
std::string a_ref = this->PrintExpr(op->args[6]);
std::string A_offset = this->PrintExpr(op->args[7]);
std::string b_desc = this->PrintExpr(op->args[8]);
std::string B_offset = this->PrintExpr(op->args[9]);
std::string c_ref = this->PrintExpr(op->args[10]);
std::string c_offset = this->PrintExpr(op->args[11]);
bool scale_out = Downcast<Bool>(op->args[12])->value;
bool scale_in_a = Downcast<Bool>(op->args[13])->value;
bool scale_in_b = Downcast<Bool>(op->args[14])->value;
bool b_is_k_major = Downcast<Bool>(op->args[1])->value;
std::string A_dtype = Downcast<StringImm>(op->args[2])->value;
std::string B_dtype = Downcast<StringImm>(op->args[3])->value;
std::string C_dtype = Downcast<StringImm>(op->args[4])->value;
std::string a_ref = this->PrintExpr(op->args[5]);
std::string A_offset = this->PrintExpr(op->args[6]);
std::string b_desc = this->PrintExpr(op->args[7]);
std::string B_offset = this->PrintExpr(op->args[8]);
std::string c_ref = this->PrintExpr(op->args[9]);
std::string c_offset = this->PrintExpr(op->args[10]);
bool scale_out = Downcast<Bool>(op->args[11])->value;
bool scale_in_a = Downcast<Bool>(op->args[12])->value;
bool scale_in_b = Downcast<Bool>(op->args[13])->value;

auto dtype_a_enum = tl::codegen::ptx::DTypeFromString(A_dtype);
auto dtype_b_enum = tl::codegen::ptx::DTypeFromString(B_dtype);
auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype);
auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);

const bool a_is_shared = false;
need_wgmma_instruction_h_ = true;
this->PrintIndent();
std::string asm_code = PrintWGMMAAssembly(
shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, A_offset,
b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a, scale_in_b,
a_is_shared, "", "", "", false);
this->stream << asm_code;
std::string wgmma_call =
"tl::wgmma_rs<(AType), (BType), (CType), (M), (N), (K), (tnspA), "
"(tnspB), (scaleA), (scaleB)>(reinterpret_cast<const "
"uint32_t*>((A_ptr) + (A_offset)), "
"uint64_t((desc_b) + (B_offset)), "
"reinterpret_cast<(CRegType)*>((C_ptr) + (C_offset)), "
"(scale_out));\n";

tl::codegen::Replacer replacer;
replacer.register_rule("(AType)",
tl::codegen::ptx::DTypeEnumToString(dtype_a_enum));
replacer.register_rule("(BType)",
tl::codegen::ptx::DTypeEnumToString(dtype_b_enum));
replacer.register_rule("(CType)",
tl::codegen::ptx::DTypeEnumToString(dtype_c_enum));
replacer.register_rule("(M)", std::to_string(m));
replacer.register_rule("(N)", std::to_string(n));
replacer.register_rule("(K)", std::to_string(k));
replacer.register_rule("(tnspA)", "false");
replacer.register_rule("(tnspB)", b_is_k_major ? "false" : "true");
replacer.register_rule("(scaleA)", scale_in_a ? "1" : "-1");
replacer.register_rule("(scaleB)", scale_in_b ? "1" : "-1");
replacer.register_rule("(CRegType)",
tl::codegen::GetMMARegisterType(dtype_c_enum));
replacer.register_rule("(A_ptr)", a_ref);
replacer.register_rule("(A_offset)", A_offset);
replacer.register_rule("(desc_b)", b_desc);
replacer.register_rule("(B_offset)", B_offset);
replacer.register_rule("(C_ptr)", c_ref);
replacer.register_rule("(C_offset)", c_offset);
replacer.register_rule("(scale_out)", scale_out ? "true" : "false");
wgmma_call = replacer.rewrite(wgmma_call);
this->stream << wgmma_call;
} else if (op->op.same_as(builtin::ptx_ldmatrix())) {
// arg 0: whether the matrix is loaded in column major format or not.
// arg 1: number of matrices to load.
Expand Down
4 changes: 4 additions & 0 deletions src/target/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ class CodeGenTileLangCUDA final : public CodeGenC {
bool need_math_constants_h_{false};
// whether need mma.h
bool need_mma_h_{false};
// whether need tl mma instruction header
bool need_mma_instruction_h_{false};
// whether need tl wgmma instruction header
bool need_wgmma_instruction_h_{false};
// whether need cast_smem_ptr_to_int helper function
bool need_cast_smem_ptr_to_int_{false};
// whether need cooperative_groups.h
Expand Down
15 changes: 15 additions & 0 deletions src/target/ptx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1529,5 +1529,20 @@ std::string PrintWaitBarrierAsm(const std::string &barrier) {
return predicated_asm_code;
}

// Map a PTX fragment dtype to the C++ register type used when casting
// MMA operand pointers in generated code.
// NOTE(review): every dtype other than fp32/fp64 — including signed
// kInt32 and, via the fallback, all packed sub-32-bit fragment types —
// is addressed through "unsigned" 32-bit registers; presumably this is
// intentional for the PTX register constraints, but worth confirming.
std::string GetMMARegisterType(const ptx::DataType &dtype) {
  if (dtype == ptx::DataType::kFloat32) {
    return "float";
  }
  if (dtype == ptx::DataType::kFloat64) {
    return "double";
  }
  // kInt32, kUInt32, and any other fragment type fall back to a
  // 32-bit unsigned register view.
  return "unsigned";
}
Comment on lines +1532 to +1545
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

🧩 Analysis chain

Verify limited type coverage and consider error handling.

The function handles only 4 data types (kInt32, kUInt32, kFloat32, kFloat64) but MMA/WGMMA operations support many more: kFloat16, kBFloat16, kInt8, kUInt8, kInt4, kUInt4, kFloat8_e4m3, kFloat8_e5m2, etc. (see valid_mma_configs/valid_wgmma_configs above).

Concerns:

  • kInt32 mapped to "unsigned": Signed int32 is mapped to "unsigned", which may be incorrect depending on usage context.
  • Silent fallback: The default case returns "unsigned" for all unsupported types instead of failing, which could mask errors if the function is called with types not yet supported by the new dispatcher.

Recommendations:

  • Verify whether the limited type coverage is intentional for the initial implementation.
  • Consider using LOG(FATAL) or CHECK(false) in the default case to catch unsupported types explicitly during development.
  • If the mapping of kInt32 to "unsigned" is intentional for PTX register constraints, add a clarifying comment.

Run the following script to verify if GetMMARegisterType is called with types beyond the four it handles:


🏁 Script executed:

#!/bin/bash
# Search for calls to GetMMARegisterType and analyze argument types
rg -n -C5 'GetMMARegisterType\(' --type=cpp

Length of output: 2440


🏁 Script executed:

#!/bin/bash
rg -nP "enum\s+DataType" -C5 --type=cpp

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Search for DataType enum definition in src/target
rg -nP "enum\s+(class\s+)?DataType" -C5 src/target
rg -n "DataType" -C2 src/target/ptx.h

Length of output: 1150


Improve type coverage and enforce unsupported types in GetMMARegisterType.

  • Extend the switch to cover all ptx::DataType enum values (e.g. kInt4, kUInt4, kInt8, kUInt8, kInt16, kUInt16, kFloat16, kBFloat16, kFloat8_e4m3, kFloat8_e5m2, etc.).
  • Replace the silent default with a runtime failure (LOG(FATAL) or CHECK(false)) to catch unsupported types.
  • If mapping signed kInt32 to "unsigned" is intentional, document the rationale with an explanatory comment.
🤖 Prompt for AI Agents
In src/target/ptx.cc around lines 1532 to 1545, the GetMMARegisterType switch
only covers a few ptx::DataType values, returns "unsigned" for signed kInt32,
and falls back to a silent default; update the switch to explicitly handle every
ptx::DataType enum value (kInt4, kUInt4, kInt8, kUInt8, kInt16, kUInt16, kInt32,
kUInt32, kFloat16, kBFloat16, kFloat32, kFloat64, kFloat8_e4m3, kFloat8_e5m2,
etc.) with the correct string mapping, replace the default branch with a runtime
failure (e.g., LOG(FATAL) or CHECK(false)) to surface unsupported types, and add
a short comment explaining why kInt32 maps to "unsigned" if that mapping is
intentional.


} // namespace codegen
} // namespace tvm::tl
5 changes: 5 additions & 0 deletions src/target/ptx.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,11 @@ std::string PrintArriveBarrierExpectTxAsm(const std::string &barrier,
*/
std::string PrintWaitBarrierAsm(const std::string &barrier);

/*!
* \brief Return the register-level C++ type used by MMA fragments.
*/
std::string GetMMARegisterType(const ptx::DataType &dtype);

} // namespace codegen
} // namespace tvm::tl

Expand Down
Loading
Loading