Merged
Changes from 11 commits
24 changes: 24 additions & 0 deletions src/tl_templates/cuda/common.h
@@ -10,6 +10,9 @@
 #include <cutlass/numeric_types.h>
 #include <math_constants.h>
 
+#include <cutlass/bfloat16.h>
+#include <cutlass/float8.h>
+
Comment on lines +13 to +15
Contributor
⚠️ Potential issue | 🟡 Minor

Explicitly include cuda_bf16.h to guarantee __nv_bfloat16 availability.

Avoid relying on transitive includes; add the CUDA header so host/RTC builds consistently see __nv_bfloat16.

Apply this diff near the existing cuda_runtime include:

 #ifndef __CUDACC_RTC__
 #include <cuda_runtime.h>
+#include <cuda_bf16.h>
 #endif

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/tl_templates/cuda/common.h around lines 13 to 15, the file relies on
transitive includes for the CUDA bfloat16 type (__nv_bfloat16); explicitly add
the CUDA header cuda_bf16.h (near the existing cuda_runtime include) to
guarantee __nv_bfloat16 is available for host and RTC builds, avoiding
transitive-include fragility.
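As a quick illustration (a hypothetical repro, not part of the PR): a translation unit that names __nv_bfloat16 directly only compiles when the type's header is in scope, so the explicit include removes the dependence on whatever cutlass/bfloat16.h happens to pull in transitively.

// repro.cu — a minimal sketch, assuming nothing else includes the bf16 header
#include <cuda_bf16.h>

__device__ float widen(__nv_bfloat16 x) {
  return __bfloat162float(x); // intrinsic declared in cuda_bf16.h
}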

 using cutlass::bfloat16_t;
 using cutlass::half_t;
 using cutlass::tfloat32_t;
@@ -318,6 +321,27 @@ TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor,
   descriptor.reg32_[0] += (offset >> 4);
 }
 
+// Wrap the cutlass FP8 types to inherit their constructors and add the desired explicit conversion from __nv_bfloat16.
+struct float_e4m3_t : public cutlass::float_e4m3_t {
+  using cutlass::float_e4m3_t::float_e4m3_t;
+  CUTLASS_HOST_DEVICE
+  float_e4m3_t() = default;
+
+  CUTLASS_HOST_DEVICE
+  explicit float_e4m3_t(__nv_bfloat16 x)
+      : float_e4m3_t(static_cast<float>(x)) {}
+};
+
+struct float_e5m2_t : public cutlass::float_e5m2_t {
+  using cutlass::float_e5m2_t::float_e5m2_t;
+  CUTLASS_HOST_DEVICE
+  float_e5m2_t() = default;
+
+  CUTLASS_HOST_DEVICE
+  explicit float_e5m2_t(__nv_bfloat16 x)
+      : float_e5m2_t(static_cast<float>(x)) {}
+};
+
 } // namespace tl
 
 namespace cutlass {
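A hypothetical usage sketch (not part of the PR; the kernel name and indexing are made up) of what the new wrappers enable — casting device-side bf16 data to FP8 through tl::float_e4m3_t:

// assumes the common.h above is in scope
__global__ void bf16_to_fp8(const __nv_bfloat16 *in, tl::float_e4m3_t *out,
                            int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n)
    out[i] = tl::float_e4m3_t(in[i]); // explicit ctor: bf16 -> float -> fp8
}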
5 changes: 3 additions & 2 deletions src/tl_templates/cuda/cuda_fp8.h
@@ -1,10 +1,11 @@
 #pragma once
 
+#include "common.h"
 #include <cuda_fp8.h>
 #include <cute/numeric/numeric_types.hpp>
 
-using fp8_e4_t = cute::float_e4m3_t;
-using fp8_e5_t = cute::float_e5m2_t;
+using fp8_e4_t = tl::float_e4m3_t;
+using fp8_e5_t = tl::float_e5m2_t;
 
 struct __CUDA_ALIGN__(2) fp8_e4_2_t {
   fp8_e4_t x;
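A brief sketch of why the alias retarget matters (assuming, as the PR implies, that cute::float_e4m3_t has no bf16 constructor): code already written against fp8_e4_t picks up the new conversion with no other changes.

__device__ void consume(__nv_bfloat16 b) {
  fp8_e4_t v(b); // now resolves to tl::float_e4m3_t's explicit bf16 ctor
}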