@@ -536,26 +536,32 @@ void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr l
536536 std::string fp8_lanes = (t.lanes () == 4 ) ? " x4" : ((t.lanes () == 2 ) ? " x2" : " " );
537537 ICHECK (t.is_e4m3_float8 () || t.is_e5m2_float8 ());
538538 if (t.lanes () == 2 ) {
539- value_temp << " __nv_cvt_halfraw2_to_fp8x2 (" ;
539+ value_temp << " __nv_fp8x2_ " << (t. is_e5m2_float8 () ? " e5m2 " : " e4m3 " ) << " (" ;
540540 } else {
541- value_temp << " __nv_fp8x4 (" ;
541+ value_temp << " __nv_fp8x4_ " << (t. is_e5m2_float8 () ? " e5m2 " : " e4m3 " ) << " (" ;
542542 }
543543 for (int i = 0 , lanes = t.lanes () / 2 ; i < lanes; ++i) {
544- if (i == 0 ) {
545- value_temp << " make_half2(" ;
544+ if (isalpha (op[0 ]) || op[0 ] == ' _' ) {
545+ value_temp << op << " 2"
546+ << " (__half2(" ;
547+ PrintVecElemLoad (vlhs, lhs.dtype (), i * lanes, value_temp);
548+ value_temp << " ), __half2(" ;
549+ PrintVecElemLoad (vrhs, rhs.dtype (), i * lanes, value_temp);
550+ value_temp << " ))" ;
551+ } else {
552+ value_temp << " __half2(" ;
553+ PrintVecElemLoad (vlhs, lhs.dtype (), i * lanes, value_temp);
554+ value_temp << " ) " << op << " __half2(" ;
555+ PrintVecElemLoad (vrhs, rhs.dtype (), i * lanes, value_temp);
556+ value_temp << " )" ;
557+ }
558+
559+ if (i != lanes - 1 ) {
560+ value_temp << " , " ;
546561 }
547- PrintVecElemLoad (vlhs, lhs.dtype (), i * lanes, value_temp);
548- value_temp << op;
549- PrintVecElemLoad (vrhs, rhs.dtype (), i * lanes, value_temp);
550- value_temp << " ," ;
551- PrintVecElemLoad (vlhs, lhs.dtype (), i * lanes + 1 , value_temp);
552- value_temp << op;
553- PrintVecElemLoad (vrhs, rhs.dtype (), i * lanes + 1 , value_temp);
554- value_temp << " )" ;
555562 if (i == lanes - 1 ) {
556563 if (t.lanes () == 2 ) {
557- value_temp << " , __NV_SATFINITE, " << (t.is_e5m2_float8 () ? " __NV_E5M2" : " __NV_E4M3" )
558- << " )" ;
564+ value_temp << " )" ;
559565 }
560566 PrintVecElemStore (sret, t, i, value_temp.str ());
561567 }
@@ -612,7 +618,7 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
612618 os << " ((nv_bfloat162*)(&(" << vec << " ." << access[i / 2 ] << " )))->" << access[i % 2 ];
613619 } else if (t.is_float8 ()) {
614620 os << " __nv_cvt_fp8x2_to_halfraw2(" << vec << " .__x,"
615- << (t.is_e5m2_float8 () ? " __NV_E5M2" : " __NV_E4M3" ) << " ). " << access[i % 2 ] ;
621+ << (t.is_e5m2_float8 () ? " __NV_E5M2" : " __NV_E4M3" ) << " )" ;
616622 } else if (t.lanes () > 4 && t.lanes () <= 8 ) {
617623 std::string type_name;
618624 if (t.bits () == 16 ) {
@@ -672,7 +678,7 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
672678 } else if (t.is_float8 ()) {
673679 // Since fp8 is a packed type (2 or 4 lanes), we only want call at end.
674680 ICHECK (i == (t.lanes () / 2 ) - 1 );
675- stream << vec << " .__x = " << value << " ;\n " ;
681+ stream << vec << " = " << value << " ;\n " ;
676682 } else if (t.lanes () > 4 && t.lanes () <= 8 ) {
677683 std::string type_name;
678684 if (t.bits () == 16 ) {
@@ -1740,5 +1746,115 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val
17401746 return ;
17411747}
17421748
1749+ template <typename T>
1750+ inline void PrintBinaryExpr (const T* op, const char * opstr,
1751+ std::ostream& os, // NOLINT(*)
1752+ CodeGenCUDA* p) {
1753+ if (op->dtype .lanes () == 1 ) {
1754+ if (op->dtype .is_float8 ()) {
1755+ std::string fp8_type = (op->dtype .is_e5m2_float8 () ? " __NV_E5M2" : " __NV_E4M3" );
1756+ if (isalpha (opstr[0 ]) || opstr[0 ] == ' _' ) {
1757+ os << " __nv_fp8_" << (op->dtype .is_e5m2_float8 () ? " e5m2" : " e4m3" ) << " (" ;
1758+ os << opstr << " (" ;
1759+ os << " __half(__nv_cvt_fp8_to_halfraw(" ;
1760+ p->PrintExpr (op->a , os);
1761+ os << " .__x, " << fp8_type << " )), __half(__nv_cvt_fp8_to_halfraw(" ;
1762+ p->PrintExpr (op->b , os);
1763+ os << " .__x, " << fp8_type << " )))" ;
1764+ os << " )" ;
1765+ } else {
1766+ os << " __nv_fp8_" << (op->dtype .is_e5m2_float8 () ? " e5m2" : " e4m3" ) << " (" ;
1767+ os << " __half(__nv_cvt_fp8_to_halfraw(" ;
1768+ p->PrintExpr (op->a , os);
1769+ os << " .__x, " << fp8_type << " )) " << opstr << " __half(__nv_cvt_fp8_to_halfraw(" ;
1770+ p->PrintExpr (op->b , os);
1771+ os << " .__x, " << fp8_type << " )))" ;
1772+ }
1773+ } else {
1774+ if (isalpha (opstr[0 ])) {
1775+ os << opstr << ' (' ;
1776+ p->PrintExpr (op->a , os);
1777+ os << " , " ;
1778+ p->PrintExpr (op->b , os);
1779+ os << ' )' ;
1780+ } else {
1781+ os << ' (' ;
1782+ p->PrintExpr (op->a , os);
1783+ os << ' ' << opstr << ' ' ;
1784+ p->PrintExpr (op->b , os);
1785+ os << ' )' ;
1786+ }
1787+ }
1788+
1789+ } else {
1790+ p->PrintVecBinaryOp (opstr, op->dtype , op->a , op->b , os);
1791+ }
1792+ }
1793+
1794+ void CodeGenCUDA::VisitExpr_ (const AddNode* op, std::ostream& os) { // NOLINT(*)
1795+ PrintBinaryExpr (op, " +" , os, this );
1796+ }
1797+ void CodeGenCUDA::VisitExpr_ (const SubNode* op, std::ostream& os) { // NOLINT(*)
1798+ PrintBinaryExpr (op, " -" , os, this );
1799+ }
1800+ void CodeGenCUDA::VisitExpr_ (const MulNode* op, std::ostream& os) { // NOLINT(*)
1801+ PrintBinaryExpr (op, " *" , os, this );
1802+ }
1803+ void CodeGenCUDA::VisitExpr_ (const DivNode* op, std::ostream& os) { // NOLINT(*)
1804+ PrintBinaryExpr (op, " /" , os, this );
1805+ }
1806+ void CodeGenCUDA::VisitExpr_ (const ModNode* op, std::ostream& os) { // NOLINT(*)
1807+ if (op->dtype .is_int () || op->dtype .is_uint ()) {
1808+ PrintBinaryExpr (op, " %" , os, this );
1809+ } else {
1810+ ICHECK (op->dtype .is_float ()) << " Expected floating point or integer dtype in Mod, but got "
1811+ << op->dtype ;
1812+ if (op->dtype .bits () == 32 ) {
1813+ PrintBinaryExpr (op, " fmodf" , os, this );
1814+ } else if (op->dtype .bits () == 64 ) {
1815+ PrintBinaryExpr (op, " fmod" , os, this );
1816+ } else {
1817+ ICHECK (false )
1818+ << " Non single or double precision floating point in Mod, expected 32 or 64 bits but got "
1819+ << op->dtype .bits () << " bits." ;
1820+ }
1821+ }
1822+ }
1823+
1824+ void CodeGenCUDA::VisitExpr_ (const MinNode* op, std::ostream& os) { // NOLINT(*)
1825+ PrintBinaryExpr (op, op->dtype .is_float8 () ? " __hmin" : " min" , os, this );
1826+ }
1827+ void CodeGenCUDA::VisitExpr_ (const MaxNode* op, std::ostream& os) { // NOLINT(*)
1828+ PrintBinaryExpr (op, op->dtype .is_float8 () ? " __hmax" : " max" , os, this );
1829+ }
1830+ void CodeGenCUDA::VisitExpr_ (const EQNode* op, std::ostream& os) { // NOLINT(*)
1831+ PrintBinaryExpr (op, " ==" , os, this );
1832+ }
1833+ void CodeGenCUDA::VisitExpr_ (const NENode* op, std::ostream& os) { // NOLINT(*)
1834+ PrintBinaryExpr (op, " !=" , os, this );
1835+ }
1836+ void CodeGenCUDA::VisitExpr_ (const LTNode* op, std::ostream& os) { // NOLINT(*)
1837+ PrintBinaryExpr (op, " <" , os, this );
1838+ }
1839+ void CodeGenCUDA::VisitExpr_ (const LENode* op, std::ostream& os) { // NOLINT(*)
1840+ PrintBinaryExpr (op, " <=" , os, this );
1841+ }
1842+ void CodeGenCUDA::VisitExpr_ (const GTNode* op, std::ostream& os) { // NOLINT(*)
1843+ PrintBinaryExpr (op, " >" , os, this );
1844+ }
1845+ void CodeGenCUDA::VisitExpr_ (const GENode* op, std::ostream& os) { // NOLINT(*)
1846+ PrintBinaryExpr (op, " >=" , os, this );
1847+ }
1848+ void CodeGenCUDA::VisitExpr_ (const AndNode* op, std::ostream& os) { // NOLINT(*)
1849+ PrintBinaryExpr (op, " &&" , os, this );
1850+ }
1851+ void CodeGenCUDA::VisitExpr_ (const OrNode* op, std::ostream& os) { // NOLINT(*)
1852+ PrintBinaryExpr (op, " ||" , os, this );
1853+ }
1854+ void CodeGenCUDA::VisitExpr_ (const NotNode* op, std::ostream& os) { // NOLINT(*)
1855+ os << ' !' ;
1856+ PrintExpr (op->a , os);
1857+ }
1858+
17431859} // namespace codegen
17441860} // namespace tvm
0 commit comments