34 changes: 34 additions & 0 deletions src/tl_templates/cuda/common.h
@@ -10,6 +10,9 @@
 #include <cutlass/numeric_types.h>
 #include <math_constants.h>
 
+#include <cutlass/bfloat16.h>
+#include <cutlass/float8.h>
+
Comment on lines +13 to +15
⚠️ Potential issue | 🟡 Minor

Explicitly include cuda_bf16.h to guarantee __nv_bfloat16 availability.

Avoid relying on transitive includes; add the CUDA header so host/RTC builds consistently see __nv_bfloat16.

Apply this diff near the existing cuda_runtime include:

 #ifndef __CUDACC_RTC__
 #include <cuda_runtime.h>
+#include <cuda_bf16.h>
 #endif

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/tl_templates/cuda/common.h around lines 13 to 15, the file relies on
transitive includes for the CUDA bfloat16 type (__nv_bfloat16); explicitly add
the CUDA header cuda_bf16.h (near the existing cuda_runtime include) to
guarantee __nv_bfloat16 is available for host and RTC builds, avoiding
transitive-include fragility.

 using cutlass::bfloat16_t;
 using cutlass::half_t;
 using cutlass::tfloat32_t;
@@ -339,6 +342,37 @@ TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor,
   descriptor.reg32_[0] += (offset >> 4);
 }
 
+// Wrap the cute fp8 types and add the desired conversion from __nv_bfloat16.
+struct float_e4m3_t : public cute::float_e4m3_t {
+  using cute::float_e4m3_t::float_e4m3_t;
+  CUTLASS_HOST_DEVICE
+  float_e4m3_t() = default;
+
+  CUTLASS_HOST_DEVICE
+  explicit float_e4m3_t(__nv_bfloat16 x)
+      : float_e4m3_t(static_cast<float>(x)) {}
+};
+
+struct float_e5m2_t : public cute::float_e5m2_t {
+  using cute::float_e5m2_t::float_e5m2_t;
+  CUTLASS_HOST_DEVICE
+  float_e5m2_t() = default;
+
+  CUTLASS_HOST_DEVICE
+  explicit float_e5m2_t(__nv_bfloat16 x)
+      : float_e5m2_t(static_cast<float>(x)) {}
+};
+
+template <typename T> struct to_cute_type {
+  using type = T;
+};
+template <> struct to_cute_type<tl::float_e4m3_t> {
+  using type = cute::float_e4m3_t;
+};
+template <> struct to_cute_type<tl::float_e5m2_t> {
+  using type = cute::float_e5m2_t;
+};
+
 } // namespace tl
 
 namespace cutlass {
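
For context, here is a minimal usage sketch of the new wrappers and trait (not part of the PR; the helper name quantize_bf16 is illustrative), assuming the common.h above is included:

#include <type_traits>

// Quantize a bfloat16 value to fp8 via the tl wrapper; the explicit
// constructor routes the conversion through float.
TL_DEVICE tl::float_e4m3_t quantize_bf16(__nv_bfloat16 x) {
  return tl::float_e4m3_t(x);
}

// The trait maps the tl wrapper back to the cute type the MMA atoms expect.
static_assert(std::is_same<tl::to_cute_type<tl::float_e4m3_t>::type,
                           cute::float_e4m3_t>::value,
              "tl fp8 wrapper unwraps to the cute fp8 type");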
5 changes: 3 additions & 2 deletions src/tl_templates/cuda/cuda_fp8.h
@@ -1,10 +1,11 @@
 #pragma once
 
 #include "common.h"
 #include <cuda_fp8.h>
+#include <cute/numeric/numeric_types.hpp>
 
-using fp8_e4_t = cute::float_e4m3_t;
-using fp8_e5_t = cute::float_e5m2_t;
+using fp8_e4_t = tl::float_e4m3_t;
+using fp8_e5_t = tl::float_e5m2_t;
 
 struct __CUDA_ALIGN__(2) fp8_e4_2_t {
   fp8_e4_t x;
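
With the aliases retargeted to the tl wrappers, fp8 values can now be built directly from bfloat16 inputs. A hedged sketch (the kernel bf16_to_fp8_e4 is illustrative, not from the PR):

__global__ void bf16_to_fp8_e4(const __nv_bfloat16 *in, fp8_e4_t *out,
                               int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n)
    out[i] = fp8_e4_t(in[i]); // explicit __nv_bfloat16 ctor from the wrapper
}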
10 changes: 6 additions & 4 deletions src/tl_templates/cuda/gemm_mma.h
@@ -263,12 +263,14 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           typename C_type_raw>
 class GemmTensorOp {
 public:
+  using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
+  using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
   using A_type =
-      typename std::conditional<std::is_same<A_type_raw, float>::value,
-                                tfloat32_t, A_type_raw>::type;
+      typename std::conditional<std::is_same<A_type_cute, float>::value,
+                                tfloat32_t, A_type_cute>::type;
   using B_type =
-      typename std::conditional<std::is_same<B_type_raw, float>::value,
-                                tfloat32_t, A_type_raw>::type;
+      typename std::conditional<std::is_same<B_type_cute, float>::value,
+                                tfloat32_t, B_type_cute>::type;
   using C_type = C_type_raw;
 
   using Instruction =
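
The two-step type mapping above can be checked in isolation. A hypothetical sanity check, mirroring how GemmTensorOp computes A_type (not code from the PR):

// Unwrap the tl fp8 wrapper first, then substitute tfloat32_t for float.
template <typename Raw> struct map_a_type {
  using cute_t = typename tl::to_cute_type<Raw>::type;
  using type = typename std::conditional<std::is_same<cute_t, float>::value,
                                         tfloat32_t, cute_t>::type;
};

static_assert(std::is_same<map_a_type<float>::type, tfloat32_t>::value, "");
static_assert(std::is_same<map_a_type<tl::float_e5m2_t>::type,
                           cute::float_e5m2_t>::value, "");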
10 changes: 6 additions & 4 deletions src/tl_templates/cuda/gemm_sm100.h
@@ -289,12 +289,14 @@ template <int M, int N, int K, int AtomM, int AtomN, int AtomK, bool trans_A,
           typename C_type_raw>
 class GemmTensorOp {
 public:
+  using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
+  using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
   using A_type =
-      typename std::conditional<std::is_same<A_type_raw, float>::value,
-                                tfloat32_t, A_type_raw>::type;
+      typename std::conditional<std::is_same<A_type_cute, float>::value,
+                                tfloat32_t, A_type_cute>::type;
   using B_type =
-      typename std::conditional<std::is_same<B_type_raw, float>::value,
-                                tfloat32_t, B_type_raw>::type;
+      typename std::conditional<std::is_same<B_type_cute, float>::value,
+                                tfloat32_t, B_type_cute>::type;
   using C_type = C_type_raw;
 
   static_assert(AtomM == 128 || AtomM == 64 || AtomM == 32);
10 changes: 6 additions & 4 deletions src/tl_templates/cuda/gemm_sm90.h
@@ -21,10 +21,12 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           typename B_type_raw, typename C_type_raw>
 class GemmTensorOp {
 public:
-  using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
-                               tfloat32_t, A_type_raw>;
-  using B_type = conditional_t<std::is_same<B_type_raw, float>::value,
-                               tfloat32_t, B_type_raw>;
+  using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
+  using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
+  using A_type = conditional_t<std::is_same<A_type_cute, float>::value,
+                               tfloat32_t, A_type_cute>;
+  using B_type = conditional_t<std::is_same<B_type_cute, float>::value,
+                               tfloat32_t, B_type_cute>;
   using C_type = C_type_raw;
 
   static constexpr GMMA::Major GmmaMajorA =
10 changes: 6 additions & 4 deletions src/tl_templates/cuda/gemm_sp_sm90.h
@@ -13,10 +13,12 @@ class GemmTensorOp {
 public:
   static_assert(num_warp_m % 4 == 0, "num_warp_m must be a multiple of 4");
 
-  using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
-                               tfloat32_t, A_type_raw>;
-  using B_type = conditional_t<std::is_same<B_type_raw, float>::value,
-                               tfloat32_t, B_type_raw>;
+  using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
+  using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
+  using A_type = conditional_t<std::is_same<A_type_cute, float>::value,
+                               tfloat32_t, A_type_cute>;
+  using B_type = conditional_t<std::is_same<B_type_cute, float>::value,
+                               tfloat32_t, B_type_cute>;
   using C_type = C_type_raw;
 
   static constexpr bool need_tfloat32_cast =