Skip to content

Commit 5a9f86a

Browse files
committed
Fix faulty codegen for vectorized binary fp8 operations
Vectorized FP8 values are stored as __nv_[fp8x2/fp8x4]_[e5m2/e4m3] (i.e. 16bit registers). These types do not have overloaded binary operators (such as *) defined for them. This commit adds the ability to perform such operations by extracting the high and low bits, statically casting them to floats, performing the operation, then repacking them into the dual-lane type.
1 parent 0632a7c commit 5a9f86a

File tree

1 file changed

+50
-11
lines changed

1 file changed

+50
-11
lines changed

src/target/source/codegen_cuda.cc

Lines changed: 50 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -531,22 +531,53 @@ void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr l
531531
std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype());
532532
std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype());
533533

534-
for (int i = 0, lanes = t.lanes(); i < lanes; ++i) {
534+
if (t.is_float8()) {
535535
std::ostringstream value_temp;
536-
if (isalpha(op[0])) {
537-
value_temp << op << "(";
538-
PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
539-
value_temp << ", ";
540-
PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
541-
value_temp << ")";
536+
std::string fp8_lanes = (t.lanes() == 4) ? "x4" : ((t.lanes() == 2) ? "x2" : "");
537+
ICHECK(t.is_e4m3_float8() || t.is_e5m2_float8());
538+
if (t.lanes() == 2) {
539+
value_temp << "__nv_cvt_halfraw2_to_fp8x2(";
542540
} else {
543-
value_temp << "(";
544-
PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
541+
value_temp << "__nv_fp8x4(";
542+
}
543+
for (int i = 0, lanes = t.lanes() / 2; i < lanes; ++i) {
544+
if (i == 0) {
545+
value_temp << "make_half2(";
546+
}
547+
PrintVecElemLoad(vlhs, lhs.dtype(), i * lanes, value_temp);
545548
value_temp << op;
546-
PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
549+
PrintVecElemLoad(vrhs, rhs.dtype(), i * lanes, value_temp);
550+
value_temp << ",";
551+
PrintVecElemLoad(vlhs, lhs.dtype(), i * lanes + 1, value_temp);
552+
value_temp << op;
553+
PrintVecElemLoad(vrhs, rhs.dtype(), i * lanes + 1, value_temp);
547554
value_temp << ")";
555+
if (i == lanes - 1) {
556+
if (t.lanes() == 2) {
557+
value_temp << ", __NV_SATFINITE, " << (t.is_e5m2_float8() ? "__NV_E5M2" : "__NV_E4M3")
558+
<< ")";
559+
}
560+
PrintVecElemStore(sret, t, i, value_temp.str());
561+
}
562+
}
563+
} else {
564+
for (int i = 0, lanes = t.lanes(); i < lanes; ++i) {
565+
std::ostringstream value_temp;
566+
if (isalpha(op[0])) {
567+
value_temp << op << "(";
568+
PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
569+
value_temp << ", ";
570+
PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
571+
value_temp << ")";
572+
} else {
573+
value_temp << "(";
574+
PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
575+
value_temp << op;
576+
PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
577+
value_temp << ")";
578+
}
579+
PrintVecElemStore(sret, t, i, value_temp.str());
548580
}
549-
PrintVecElemStore(sret, t, i, value_temp.str());
550581
}
551582
}
552583
EndScope(ssa_scope);
@@ -561,6 +592,7 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
561592
}
562593

563594
static const char access[] = {'x', 'y', 'z', 'w'};
595+
std::string fp8_type = (t.is_float8()) ? (t.is_e4m3_float8() ? "e4m3" : "e5m2") : "";
564596
ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4));
565597
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
566598
std::string type_name = t.is_int() ? "char" : "unsigned char";
@@ -578,6 +610,9 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
578610
}
579611
} else if (t.is_bfloat16()) {
580612
os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
613+
} else if (t.is_float8()) {
614+
os << "__nv_cvt_fp8x2_to_halfraw2(" << vec << ".__x,"
615+
<< (t.is_e5m2_float8() ? "__NV_E5M2" : "__NV_E4M3") << ")." << access[i % 2];
581616
} else if (t.lanes() > 4 && t.lanes() <= 8) {
582617
std::string type_name;
583618
if (t.bits() == 16) {
@@ -634,6 +669,10 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
634669
} else if (t.is_bfloat16()) {
635670
stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]
636671
<< " = " << value << ";\n";
672+
} else if (t.is_float8()) {
673+
// Since fp8 is a packed type (2 or 4 lanes), we only want call at end.
674+
ICHECK(i == (t.lanes() / 2) - 1);
675+
stream << vec << ".__x = " << value << ";\n";
637676
} else if (t.lanes() > 4 && t.lanes() <= 8) {
638677
std::string type_name;
639678
if (t.bits() == 16) {

0 commit comments

Comments
 (0)