Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/target/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -964,6 +964,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
PrintVecElemStore(sret, target_ty, i, val.str());
}


if (used_bf16_op) {
stream << "#endif\n";
}
Expand Down
24 changes: 24 additions & 0 deletions src/tl_templates/cuda/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
#include <cutlass/numeric_types.h>
#include <math_constants.h>

#include <cutlass/float8.h>
#include <cutlass/bfloat16.h>

using cutlass::bfloat16_t;
using cutlass::half_t;
using cutlass::tfloat32_t;
Expand Down Expand Up @@ -318,6 +321,27 @@ TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor,
descriptor.reg32_[0] += (offset >> 4);
}

// Thin wrapper around cutlass::float_e4m3_t that inherits all of its
// constructors and additionally accepts __nv_bfloat16, converting by first
// widening to float (CUTLASS's fp8 types provide a float constructor).
// NOTE(review): the conversion below is marked `explicit`, so despite the
// original comment ("the desired implicit conversion from bfloat16_t") this
// is NOT an implicit conversion — confirm which behavior was intended.
// NOTE(review): the parameter is __nv_bfloat16, not cutlass::bfloat16_t as
// the comment suggests — presumably the codegen emits __nv_bfloat16; verify.
struct float_e4m3_t : public cutlass::float_e4m3_t {
using cutlass::float_e4m3_t::float_e4m3_t;
CUTLASS_HOST_DEVICE
float_e4m3_t() = default;

CUTLASS_HOST_DEVICE
// Route the conversion through float: bf16 -> float -> e4m3.
explicit float_e4m3_t(__nv_bfloat16 x) : float_e4m3_t(static_cast<float>(x)) {
}
};

// Thin wrapper around cutlass::float_e5m2_t, mirroring the float_e4m3_t
// wrapper above: inherits all base constructors and adds an explicit
// conversion from __nv_bfloat16 routed through float.
struct float_e5m2_t : public cutlass::float_e5m2_t {
using cutlass::float_e5m2_t::float_e5m2_t;
CUTLASS_HOST_DEVICE
float_e5m2_t() = default;

CUTLASS_HOST_DEVICE
// Route the conversion through float: bf16 -> float -> e5m2.
explicit float_e5m2_t(__nv_bfloat16 x) : float_e5m2_t(static_cast<float>(x)) {
}
};

} // namespace tl

namespace cutlass {
Expand Down
5 changes: 3 additions & 2 deletions src/tl_templates/cuda/cuda_fp8.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

#include <cuda_fp8.h>
#include <cute/numeric/numeric_types.hpp>
#include "common.h"

using fp8_e4_t = cute::float_e4m3_t;
using fp8_e5_t = cute::float_e5m2_t;
using fp8_e4_t = tl::float_e4m3_t;
using fp8_e5_t = tl::float_e5m2_t;

struct __CUDA_ALIGN__(2) fp8_e4_2_t {
fp8_e4_t x;
Expand Down
Loading