Skip to content

Commit efb2d90

Browse files
committed
Fix faulty codegen for non-vectorized binary fp8 operations
Non-vectorized FP8 values are stored as __nv_fp8_[e5m2/e4m3] types. These types do not support binary operations because FP8 values are internally stored in 16-bit registers. This commit adds binary operator support by performing the operations in __half instead of FP8 (i.e., cast up to 16-bit, then cast down to 8-bit).
1 parent 5a9f86a commit efb2d90

File tree

3 files changed

+161
-19
lines changed

3 files changed

+161
-19
lines changed

src/target/source/codegen_c.cc

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ std::string CodeGenC::GetBufferRef(DataType t, const BufferNode* buffer, PrimExp
214214
if (alloc_storage_scope_.count(buffer_var)) {
215215
scope = alloc_storage_scope_.at(buffer_var);
216216
}
217-
bool is_vol = IsVolatile(buffer_var);
217+
bool is_vol = IsVolatile(buffer_var) && !t.is_float8();
218218

219219
auto ptr_cast = [this, is_vol, scope](DataType pointed_to) {
220220
std::ostringstream ptr_os;
@@ -840,7 +840,8 @@ void CodeGenC::VisitStmt_(const BufferStoreNode* op) {
840840
std::string value = this->PrintExpr(op->value);
841841
std::string ref = this->GetBufferRef(value_dtype, op->buffer.get(), index_expr);
842842
this->PrintIndent();
843-
stream << ref << " = " << value << ";\n";
843+
stream << ref << " = ";
844+
stream << value << ";\n";
844845
} else {
845846
arith::PVar<PrimExpr> base;
846847

@@ -876,7 +877,16 @@ void CodeGenC::VisitStmt_(const BufferStoreNode* op) {
876877
stream << '[';
877878
PrintVecElemLoad(index, index_expr.dtype(), i, stream);
878879
stream << "] = ";
879-
PrintVecElemLoad(value, op->value.dtype(), i, stream);
880+
if (op->value.dtype().is_float8()) {
881+
ICHECK(value_dtype.lanes() == 2);
882+
std::string fp8_type = op->value.dtype().is_e5m2_float8() ? "e5m2" : "e4m3";
883+
static const char access[] = {'x', 'y'};
884+
stream << "__nv_fp8_" << fp8_type << "(__half2(";
885+
PrintVecElemLoad(value, op->value.dtype(), i, stream);
886+
stream << ")." << access[i % 2] << ")";
887+
} else {
888+
PrintVecElemLoad(value, op->value.dtype(), i, stream);
889+
}
880890
stream << ";\n";
881891
}
882892
EndScope(vec_scope);

src/target/source/codegen_cuda.cc

Lines changed: 132 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -536,26 +536,32 @@ void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr l
536536
std::string fp8_lanes = (t.lanes() == 4) ? "x4" : ((t.lanes() == 2) ? "x2" : "");
537537
ICHECK(t.is_e4m3_float8() || t.is_e5m2_float8());
538538
if (t.lanes() == 2) {
539-
value_temp << "__nv_cvt_halfraw2_to_fp8x2(";
539+
value_temp << "__nv_fp8x2_" << (t.is_e5m2_float8() ? "e5m2" : "e4m3") << "(";
540540
} else {
541-
value_temp << "__nv_fp8x4(";
541+
value_temp << "__nv_fp8x4_" << (t.is_e5m2_float8() ? "e5m2" : "e4m3") << "(";
542542
}
543543
for (int i = 0, lanes = t.lanes() / 2; i < lanes; ++i) {
544-
if (i == 0) {
545-
value_temp << "make_half2(";
544+
if (isalpha(op[0]) || op[0] == '_') {
545+
value_temp << op << "2"
546+
<< "(__half2(";
547+
PrintVecElemLoad(vlhs, lhs.dtype(), i * lanes, value_temp);
548+
value_temp << "), __half2(";
549+
PrintVecElemLoad(vrhs, rhs.dtype(), i * lanes, value_temp);
550+
value_temp << "))";
551+
} else {
552+
value_temp << "__half2(";
553+
PrintVecElemLoad(vlhs, lhs.dtype(), i * lanes, value_temp);
554+
value_temp << ") " << op << " __half2(";
555+
PrintVecElemLoad(vrhs, rhs.dtype(), i * lanes, value_temp);
556+
value_temp << ")";
557+
}
558+
559+
if (i != lanes - 1) {
560+
value_temp << ", ";
546561
}
547-
PrintVecElemLoad(vlhs, lhs.dtype(), i * lanes, value_temp);
548-
value_temp << op;
549-
PrintVecElemLoad(vrhs, rhs.dtype(), i * lanes, value_temp);
550-
value_temp << ",";
551-
PrintVecElemLoad(vlhs, lhs.dtype(), i * lanes + 1, value_temp);
552-
value_temp << op;
553-
PrintVecElemLoad(vrhs, rhs.dtype(), i * lanes + 1, value_temp);
554-
value_temp << ")";
555562
if (i == lanes - 1) {
556563
if (t.lanes() == 2) {
557-
value_temp << ", __NV_SATFINITE, " << (t.is_e5m2_float8() ? "__NV_E5M2" : "__NV_E4M3")
558-
<< ")";
564+
value_temp << ")";
559565
}
560566
PrintVecElemStore(sret, t, i, value_temp.str());
561567
}
@@ -612,7 +618,7 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
612618
os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
613619
} else if (t.is_float8()) {
614620
os << "__nv_cvt_fp8x2_to_halfraw2(" << vec << ".__x,"
615-
<< (t.is_e5m2_float8() ? "__NV_E5M2" : "__NV_E4M3") << ")." << access[i % 2];
621+
<< (t.is_e5m2_float8() ? "__NV_E5M2" : "__NV_E4M3") << ")";
616622
} else if (t.lanes() > 4 && t.lanes() <= 8) {
617623
std::string type_name;
618624
if (t.bits() == 16) {
@@ -672,7 +678,7 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
672678
} else if (t.is_float8()) {
673679
// Since fp8 is a packed type (2 or 4 lanes), we only want call at end.
674680
ICHECK(i == (t.lanes() / 2) - 1);
675-
stream << vec << ".__x = " << value << ";\n";
681+
stream << vec << " = " << value << ";\n";
676682
} else if (t.lanes() > 4 && t.lanes() <= 8) {
677683
std::string type_name;
678684
if (t.bits() == 16) {
@@ -1740,5 +1746,115 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val
17401746
return;
17411747
}
17421748

1749+
template <typename T>
1750+
inline void PrintBinaryExpr(const T* op, const char* opstr,
1751+
std::ostream& os, // NOLINT(*)
1752+
CodeGenCUDA* p) {
1753+
if (op->dtype.lanes() == 1) {
1754+
if (op->dtype.is_float8()) {
1755+
std::string fp8_type = (op->dtype.is_e5m2_float8() ? "__NV_E5M2" : "__NV_E4M3");
1756+
if (isalpha(opstr[0]) || opstr[0] == '_') {
1757+
os << "__nv_fp8_" << (op->dtype.is_e5m2_float8() ? "e5m2" : "e4m3") << "(";
1758+
os << opstr << "(";
1759+
os << "__half(__nv_cvt_fp8_to_halfraw(";
1760+
p->PrintExpr(op->a, os);
1761+
os << ".__x, " << fp8_type << ")), __half(__nv_cvt_fp8_to_halfraw(";
1762+
p->PrintExpr(op->b, os);
1763+
os << ".__x, " << fp8_type << ")))";
1764+
os << ")";
1765+
} else {
1766+
os << "__nv_fp8_" << (op->dtype.is_e5m2_float8() ? "e5m2" : "e4m3") << "(";
1767+
os << "__half(__nv_cvt_fp8_to_halfraw(";
1768+
p->PrintExpr(op->a, os);
1769+
os << ".__x, " << fp8_type << ")) " << opstr << " __half(__nv_cvt_fp8_to_halfraw(";
1770+
p->PrintExpr(op->b, os);
1771+
os << ".__x, " << fp8_type << ")))";
1772+
}
1773+
} else {
1774+
if (isalpha(opstr[0])) {
1775+
os << opstr << '(';
1776+
p->PrintExpr(op->a, os);
1777+
os << ", ";
1778+
p->PrintExpr(op->b, os);
1779+
os << ')';
1780+
} else {
1781+
os << '(';
1782+
p->PrintExpr(op->a, os);
1783+
os << ' ' << opstr << ' ';
1784+
p->PrintExpr(op->b, os);
1785+
os << ')';
1786+
}
1787+
}
1788+
1789+
} else {
1790+
p->PrintVecBinaryOp(opstr, op->dtype, op->a, op->b, os);
1791+
}
1792+
}
1793+
1794+
void CodeGenCUDA::VisitExpr_(const AddNode* op, std::ostream& os) { // NOLINT(*)
1795+
PrintBinaryExpr(op, "+", os, this);
1796+
}
1797+
void CodeGenCUDA::VisitExpr_(const SubNode* op, std::ostream& os) { // NOLINT(*)
1798+
PrintBinaryExpr(op, "-", os, this);
1799+
}
1800+
void CodeGenCUDA::VisitExpr_(const MulNode* op, std::ostream& os) { // NOLINT(*)
1801+
PrintBinaryExpr(op, "*", os, this);
1802+
}
1803+
void CodeGenCUDA::VisitExpr_(const DivNode* op, std::ostream& os) { // NOLINT(*)
1804+
PrintBinaryExpr(op, "/", os, this);
1805+
}
1806+
void CodeGenCUDA::VisitExpr_(const ModNode* op, std::ostream& os) { // NOLINT(*)
1807+
if (op->dtype.is_int() || op->dtype.is_uint()) {
1808+
PrintBinaryExpr(op, "%", os, this);
1809+
} else {
1810+
ICHECK(op->dtype.is_float()) << "Expected floating point or integer dtype in Mod, but got "
1811+
<< op->dtype;
1812+
if (op->dtype.bits() == 32) {
1813+
PrintBinaryExpr(op, "fmodf", os, this);
1814+
} else if (op->dtype.bits() == 64) {
1815+
PrintBinaryExpr(op, "fmod", os, this);
1816+
} else {
1817+
ICHECK(false)
1818+
<< "Non single or double precision floating point in Mod, expected 32 or 64 bits but got "
1819+
<< op->dtype.bits() << " bits.";
1820+
}
1821+
}
1822+
}
1823+
1824+
void CodeGenCUDA::VisitExpr_(const MinNode* op, std::ostream& os) { // NOLINT(*)
1825+
PrintBinaryExpr(op, op->dtype.is_float8() ? "__hmin" : "min", os, this);
1826+
}
1827+
void CodeGenCUDA::VisitExpr_(const MaxNode* op, std::ostream& os) { // NOLINT(*)
1828+
PrintBinaryExpr(op, op->dtype.is_float8() ? "__hmax" : "max", os, this);
1829+
}
1830+
void CodeGenCUDA::VisitExpr_(const EQNode* op, std::ostream& os) { // NOLINT(*)
1831+
PrintBinaryExpr(op, "==", os, this);
1832+
}
1833+
void CodeGenCUDA::VisitExpr_(const NENode* op, std::ostream& os) { // NOLINT(*)
1834+
PrintBinaryExpr(op, "!=", os, this);
1835+
}
1836+
void CodeGenCUDA::VisitExpr_(const LTNode* op, std::ostream& os) { // NOLINT(*)
1837+
PrintBinaryExpr(op, "<", os, this);
1838+
}
1839+
void CodeGenCUDA::VisitExpr_(const LENode* op, std::ostream& os) { // NOLINT(*)
1840+
PrintBinaryExpr(op, "<=", os, this);
1841+
}
1842+
void CodeGenCUDA::VisitExpr_(const GTNode* op, std::ostream& os) { // NOLINT(*)
1843+
PrintBinaryExpr(op, ">", os, this);
1844+
}
1845+
void CodeGenCUDA::VisitExpr_(const GENode* op, std::ostream& os) { // NOLINT(*)
1846+
PrintBinaryExpr(op, ">=", os, this);
1847+
}
1848+
void CodeGenCUDA::VisitExpr_(const AndNode* op, std::ostream& os) { // NOLINT(*)
1849+
PrintBinaryExpr(op, "&&", os, this);
1850+
}
1851+
void CodeGenCUDA::VisitExpr_(const OrNode* op, std::ostream& os) { // NOLINT(*)
1852+
PrintBinaryExpr(op, "||", os, this);
1853+
}
1854+
void CodeGenCUDA::VisitExpr_(const NotNode* op, std::ostream& os) { // NOLINT(*)
1855+
os << '!';
1856+
PrintExpr(op->a, os);
1857+
}
1858+
17431859
} // namespace codegen
17441860
} // namespace tvm

src/target/source/codegen_cuda.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,22 @@ class CodeGenCUDA final : public CodeGenC {
6868
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;
6969
void VisitExpr_(const CallNode* op, std::ostream& os) final;
7070
void VisitExpr_(const CastNode* op, std::ostream& os) final;
71+
void VisitExpr_(const AddNode* op, std::ostream& os) final;
72+
void VisitExpr_(const SubNode* op, std::ostream& os) final;
73+
void VisitExpr_(const MulNode* op, std::ostream& os) final;
74+
void VisitExpr_(const DivNode* op, std::ostream& os) final;
75+
void VisitExpr_(const ModNode* op, std::ostream& os) final;
76+
void VisitExpr_(const MinNode* op, std::ostream& os) final;
77+
void VisitExpr_(const MaxNode* op, std::ostream& os) final;
78+
void VisitExpr_(const EQNode* op, std::ostream& os) final;
79+
void VisitExpr_(const NENode* op, std::ostream& os) final;
80+
void VisitExpr_(const LTNode* op, std::ostream& os) final;
81+
void VisitExpr_(const LENode* op, std::ostream& os) final;
82+
void VisitExpr_(const GTNode* op, std::ostream& os) final;
83+
void VisitExpr_(const GENode* op, std::ostream& os) final;
84+
void VisitExpr_(const AndNode* op, std::ostream& os) final;
85+
void VisitExpr_(const OrNode* op, std::ostream& os) final;
86+
void VisitExpr_(const NotNode* op, std::ostream& os) final;
7187
void VisitStmt_(const EvaluateNode* op) final;
7288
void VisitStmt_(const AllocateNode* op) final;
7389
void VisitStmt_(const AttrStmtNode* op) final;

0 commit comments

Comments
 (0)