Merged
Changes from 14 commits
24 changes: 24 additions & 0 deletions src/tl_templates/cuda/common.h
@@ -10,6 +10,9 @@
 #include <cutlass/numeric_types.h>
 #include <math_constants.h>
 
+#include <cutlass/bfloat16.h>
+#include <cutlass/float8.h>
+
Comment on lines +13 to +15

⚠️ Potential issue | 🟡 Minor

Explicitly include cuda_bf16.h to guarantee __nv_bfloat16 availability.

Avoid relying on transitive includes; add the CUDA header so host/RTC builds consistently see __nv_bfloat16.

Apply this diff near the existing cuda_runtime include:

 #ifndef __CUDACC_RTC__
 #include <cuda_runtime.h>
+#include <cuda_bf16.h>
 #endif

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/tl_templates/cuda/common.h around lines 13 to 15, the file relies on
transitive includes for the CUDA bfloat16 type (__nv_bfloat16); explicitly add
the CUDA header cuda_bf16.h (near the existing cuda_runtime include) to
guarantee __nv_bfloat16 is available for host and RTC builds, avoiding
transitive-include fragility.

 using cutlass::bfloat16_t;
 using cutlass::half_t;
 using cutlass::tfloat32_t;
@@ -318,6 +321,27 @@ TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor,
   descriptor.reg32_[0] += (offset >> 4);
 }
 
+// Wrap the cute FP8 types and add the desired conversion from bfloat16_t.
+struct float_e4m3_t : public cute::float_e4m3_t {
+  using cute::float_e4m3_t::float_e4m3_t;
+  CUTLASS_HOST_DEVICE
+  float_e4m3_t() = default;
+
+  CUTLASS_HOST_DEVICE
+  explicit float_e4m3_t(__nv_bfloat16 x)
+      : float_e4m3_t(static_cast<float>(x)) {}
+};
+
+struct float_e5m2_t : public cute::float_e5m2_t {
+  using cute::float_e5m2_t::float_e5m2_t;
+  CUTLASS_HOST_DEVICE
+  float_e5m2_t() = default;
+
+  CUTLASS_HOST_DEVICE
+  explicit float_e5m2_t(__nv_bfloat16 x)
+      : float_e5m2_t(static_cast<float>(x)) {}
+};
+
 } // namespace tl
 
 namespace cutlass {
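
To make the intent of these wrappers concrete, here is a minimal usage sketch (not part of the PR; the kernel name, launch shape, and include path are illustrative assumptions): a bfloat16 value loaded inside a kernel can be converted straight to the FP8 wrapper type, with the float round-trip happening inside the new explicit constructor.

// Hypothetical sketch -- only the tl::float_e4m3_t(__nv_bfloat16) constructor
// comes from the diff above; everything else here is illustrative.
#include <cuda_bf16.h>
#include "tl_templates/cuda/common.h" // assumed include path

__global__ void cast_bf16_to_fp8(const __nv_bfloat16 *in,
                                 tl::float_e4m3_t *out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // bf16 -> float -> e4m3, via the explicit constructor added in common.h.
    out[i] = tl::float_e4m3_t(in[i]);
  }
}
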
5 changes: 3 additions & 2 deletions src/tl_templates/cuda/cuda_fp8.h
@@ -1,10 +1,11 @@
 #pragma once
 
+#include "common.h"
 #include <cuda_fp8.h>
 #include <cute/numeric/numeric_types.hpp>
 
-using fp8_e4_t = cute::float_e4m3_t;
-using fp8_e5_t = cute::float_e5m2_t;
+using fp8_e4_t = tl::float_e4m3_t;
+using fp8_e5_t = tl::float_e5m2_t;
 
 struct __CUDA_ALIGN__(2) fp8_e4_2_t {
   fp8_e4_t x;
16 changes: 14 additions & 2 deletions src/tl_templates/cuda/gemm_mma.h
@@ -257,18 +257,30 @@ struct OperandTraits<64, N, K, false, num_warp_n, leading_dim,
   using Copy = DefaultCopy;
 };
 
+template <typename T> struct to_cute_type {
+  using type = T;
+};
+template <> struct to_cute_type<tl::float_e4m3_t> {
+  using type = cute::float_e4m3_t;
+};
+template <> struct to_cute_type<tl::float_e5m2_t> {
+  using type = cute::float_e5m2_t;
+};
+
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
           int offset_b, typename A_type_raw, typename B_type_raw,
           typename C_type_raw>
 class GemmTensorOp {
 public:
+  using A_type_cute = typename to_cute_type<A_type_raw>::type;
+  using B_type_cute = typename to_cute_type<B_type_raw>::type;
   using A_type =
-      typename std::conditional<std::is_same<A_type_raw, float>::value,
+      typename std::conditional<std::is_same<A_type_cute, float>::value,
                                 tfloat32_t, A_type_raw>::type;
   using B_type =
       typename std::conditional<std::is_same<B_type_raw, float>::value,
-                                tfloat32_t, A_type_raw>::type;
+                                tfloat32_t, B_type_cute>::type;
   using C_type = C_type_raw;
 
   using Instruction =
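
The to_cute_type trait exists only to strip the tl:: wrapper before the operand types reach the MMA selection logic; any other type passes through untouched. A compile-time sketch of that behavior (not in the PR; it assumes the trait is visible from the enclosing namespace):

#include <type_traits>

// The tl:: FP8 wrappers map back to their cute base types ...
static_assert(std::is_same<typename to_cute_type<tl::float_e4m3_t>::type,
                           cute::float_e4m3_t>::value,
              "tl::float_e4m3_t resolves to cute::float_e4m3_t");
// ... while every other operand type is returned unchanged.
static_assert(std::is_same<typename to_cute_type<cutlass::half_t>::type,
                           cutlass::half_t>::value,
              "non-FP8 types pass through as-is");
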
19 changes: 15 additions & 4 deletions src/tl_templates/cuda/gemm_sm90.h
@@ -15,16 +15,27 @@ using namespace SM90;
 namespace tl_wgmma {
 
 using namespace cutlass::gemm::collective::detail; // ss_smem_selector
+template <typename T> struct to_cute_type {
+  using type = T;
+};
+template <> struct to_cute_type<tl::float_e4m3_t> {
+  using type = cute::float_e4m3_t;
+};
+template <> struct to_cute_type<tl::float_e5m2_t> {
+  using type = cute::float_e5m2_t;
+};
 
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           bool trans_B, bool clear_accum, typename A_type_raw,
           typename B_type_raw, typename C_type_raw>
 class GemmTensorOp {
 public:
-  using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
-                               tfloat32_t, A_type_raw>;
-  using B_type = conditional_t<std::is_same<B_type_raw, float>::value,
-                               tfloat32_t, B_type_raw>;
+  using A_type_cute = typename to_cute_type<A_type_raw>::type;
+  using B_type_cute = typename to_cute_type<B_type_raw>::type;
+  using A_type = conditional_t<std::is_same<A_type_cute, float>::value,
+                               tfloat32_t, A_type_cute>;
+  using B_type = conditional_t<std::is_same<B_type_cute, float>::value,
+                               tfloat32_t, B_type_cute>;
   using C_type = C_type_raw;
 
   static constexpr GMMA::Major GmmaMajorA =
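
Putting the two pieces together, the operand aliases now unwrap the tl:: FP8 types and still promote float operands to tfloat32_t for WGMMA. The helper below is a hypothetical standalone restatement of that selection logic (promoted_t is not a name from the PR), shown only to illustrate what A_type and B_type end up being:

#include <type_traits>

// Hypothetical restatement of the alias logic in tl_wgmma::GemmTensorOp above.
template <typename Raw>
using promoted_t = std::conditional_t<
    std::is_same<typename tl_wgmma::to_cute_type<Raw>::type, float>::value,
    cutlass::tfloat32_t,
    typename tl_wgmma::to_cute_type<Raw>::type>;

// float operands are still fed to WGMMA as tfloat32_t ...
static_assert(std::is_same<promoted_t<float>, cutlass::tfloat32_t>::value, "");
// ... and the FP8 wrappers are unwrapped to the cute types the builder expects.
static_assert(std::is_same<promoted_t<tl::float_e4m3_t>,
                           cute::float_e4m3_t>::value, "");
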
3 changes: 2 additions & 1 deletion tilelang/language/allocate.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
 """Memory allocation utilities for Tile-AI programs.
 
 This module provides a set of functions for allocating different types of memory buffers
@@ -67,7 +68,7 @@ def alloc_fragment(shape, dtype, scope="local.fragment"):
     return T.alloc_buffer(shape, dtype, scope=scope)
 
 
-def alloc_var(dtype, *args, scope="local.var", init: Union[PrimExpr] = None):
+def alloc_var(dtype, *args, scope="local.var", init: Union[PrimExpr] = None):  # noqa: UP007
     """Allocate a single-element variable buffer.
 
     Args: