Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/op/copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,13 @@ For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);

PrimExpr value = BufferLoad(src, src_indices);
if (src->dtype != dst->dtype)
if (src->dtype != dst->dtype) {
// If dst is fp8 (e4m3) and src is bf16, first cast the loaded src value to fp32.
if (src->dtype.is_bfloat16() && dst->dtype.is_float8_e4m3()) {
value = Cast(DataType::Float(32), value);
}
value = Cast(dst->dtype, value);
}
if (src_predicate.defined())
value = if_then_else(src_predicate, value, make_zero(dst->dtype));

Expand Down
28 changes: 28 additions & 0 deletions testing/python/issue/test_tilelang_issue_1046.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import tilelang
import tilelang.language as T

# Turn off tilelang's caching (per the API name) before any kernel is built.
tilelang.disable_cache()

# Dtype name strings used as the default element types of the kernel's
# input (BF16) and output (FP8) tensors.
FP8 = "float8_e4m3"
BF16 = "bfloat16"


@tilelang.jit
def test_kernel(N, in_dtype=BF16, out_dtype=FP8):
    """Build a tiled copy kernel from an ``in_dtype`` tensor to an ``out_dtype`` tensor.

    The returned prim_func stages each tile of ``X`` in a buffer of
    ``in_dtype`` and then copies that buffer into ``Y``, whose element type is
    ``out_dtype``. With the defaults the second copy therefore goes from
    bfloat16 to float8_e4m3.

    Args:
        N: Size of the second dimension of both tensors.
        in_dtype: Element type of ``X`` and of the staging buffer.
        out_dtype: Element type of ``Y``.

    Returns:
        The ``T.prim_func`` defined below (wrapped by ``tilelang.jit``).
    """
    # First dimension is left dynamic; presumably resolved at call time.
    M = T.dynamic("M")
    # Tile extents along the first and second dimension respectively.
    blk_m = 128
    group_size = 128

    @T.prim_func
    def test_kernel_(X: T.Tensor[(M, N), in_dtype], Y: T.Tensor[(M, N), out_dtype]):
        # One (pid_m, pid_n) pair per (blk_m x group_size) tile, 128 threads each.
        with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (pid_m, pid_n):
            # Staging buffer keeps the input dtype, so the first copy is same-dtype.
            x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
            T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
            # This copy crosses dtypes (in_dtype -> out_dtype) when they differ.
            T.copy(x_shared, Y[pid_m * blk_m, pid_n * group_size])

    return test_kernel_


# Module-level driver (runs on import): instantiate the kernel with N=128 and
# the default dtypes (bfloat16 in, float8_e4m3 out), then print the source
# text reported by get_kernel_source().
kernel = test_kernel(128)

print(kernel.get_kernel_source())
Loading