@@ -531,22 +531,53 @@ void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr l
531531 std::string vlhs = SSAGetID (PrintExpr (lhs), lhs.dtype ());
532532 std::string vrhs = SSAGetID (PrintExpr (rhs), rhs.dtype ());
533533
534- for ( int i = 0 , lanes = t. lanes (); i < lanes; ++i ) {
534+ if (t. is_float8 () ) {
535535 std::ostringstream value_temp;
536- if (isalpha (op[0 ])) {
537- value_temp << op << " (" ;
538- PrintVecElemLoad (vlhs, lhs.dtype (), i, value_temp);
539- value_temp << " , " ;
540- PrintVecElemLoad (vrhs, rhs.dtype (), i, value_temp);
541- value_temp << " )" ;
536+ std::string fp8_lanes = (t.lanes () == 4 ) ? " x4" : ((t.lanes () == 2 ) ? " x2" : " " );
537+ ICHECK (t.is_e4m3_float8 () || t.is_e5m2_float8 ());
538+ if (t.lanes () == 2 ) {
539+ value_temp << " __nv_cvt_halfraw2_to_fp8x2(" ;
542540 } else {
543- value_temp << " (" ;
544- PrintVecElemLoad (vlhs, lhs.dtype (), i, value_temp);
541+ value_temp << " __nv_fp8x4(" ;
542+ }
543+ for (int i = 0 , lanes = t.lanes () / 2 ; i < lanes; ++i) {
544+ if (i == 0 ) {
545+ value_temp << " make_half2(" ;
546+ }
547+ PrintVecElemLoad (vlhs, lhs.dtype (), i * lanes, value_temp);
545548 value_temp << op;
546- PrintVecElemLoad (vrhs, rhs.dtype (), i, value_temp);
549+ PrintVecElemLoad (vrhs, rhs.dtype (), i * lanes, value_temp);
550+ value_temp << " ," ;
551+ PrintVecElemLoad (vlhs, lhs.dtype (), i * lanes + 1 , value_temp);
552+ value_temp << op;
553+ PrintVecElemLoad (vrhs, rhs.dtype (), i * lanes + 1 , value_temp);
547554 value_temp << " )" ;
555+ if (i == lanes - 1 ) {
556+ if (t.lanes () == 2 ) {
557+ value_temp << " , __NV_SATFINITE, " << (t.is_e5m2_float8 () ? " __NV_E5M2" : " __NV_E4M3" )
558+ << " )" ;
559+ }
560+ PrintVecElemStore (sret, t, i, value_temp.str ());
561+ }
562+ }
563+ } else {
564+ for (int i = 0 , lanes = t.lanes (); i < lanes; ++i) {
565+ std::ostringstream value_temp;
566+ if (isalpha (op[0 ])) {
567+ value_temp << op << " (" ;
568+ PrintVecElemLoad (vlhs, lhs.dtype (), i, value_temp);
569+ value_temp << " , " ;
570+ PrintVecElemLoad (vrhs, rhs.dtype (), i, value_temp);
571+ value_temp << " )" ;
572+ } else {
573+ value_temp << " (" ;
574+ PrintVecElemLoad (vlhs, lhs.dtype (), i, value_temp);
575+ value_temp << op;
576+ PrintVecElemLoad (vrhs, rhs.dtype (), i, value_temp);
577+ value_temp << " )" ;
578+ }
579+ PrintVecElemStore (sret, t, i, value_temp.str ());
548580 }
549- PrintVecElemStore (sret, t, i, value_temp.str ());
550581 }
551582 }
552583 EndScope (ssa_scope);
@@ -561,6 +592,7 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
561592 }
562593
563594 static const char access[] = {' x' , ' y' , ' z' , ' w' };
595+ std::string fp8_type = (t.is_float8 ()) ? (t.is_e4m3_float8 () ? " e4m3" : " e5m2" ) : " " ;
564596 ICHECK (i >= 0 && i < (t.bits () == 8 ? 16 : (t.bits () == 16 || t.bits () == 32 ) ? 8 : 4 ));
565597 if (t.bits () == 8 && (t.is_int () || t.is_uint ())) {
566598 std::string type_name = t.is_int () ? " char" : " unsigned char" ;
@@ -578,6 +610,9 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
578610 }
579611 } else if (t.is_bfloat16 ()) {
580612 os << " ((nv_bfloat162*)(&(" << vec << " ." << access[i / 2 ] << " )))->" << access[i % 2 ];
613+ } else if (t.is_float8 ()) {
614+ os << " __nv_cvt_fp8x2_to_halfraw2(" << vec << " .__x,"
615+ << (t.is_e5m2_float8 () ? " __NV_E5M2" : " __NV_E4M3" ) << " )." << access[i % 2 ];
581616 } else if (t.lanes () > 4 && t.lanes () <= 8 ) {
582617 std::string type_name;
583618 if (t.bits () == 16 ) {
@@ -634,6 +669,10 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
634669 } else if (t.is_bfloat16 ()) {
635670 stream << " ((nv_bfloat162*)(&(" << vec << " ." << access[i / 2 ] << " )))->" << access[i % 2 ]
636671 << " = " << value << " ;\n " ;
672+ } else if (t.is_float8 ()) {
673+ // Since fp8 is a packed type (2 or 4 lanes), we only want call at end.
674+ ICHECK (i == (t.lanes () / 2 ) - 1 );
675+ stream << vec << " .__x = " << value << " ;\n " ;
637676 } else if (t.lanes () > 4 && t.lanes () <= 8 ) {
638677 std::string type_name;
639678 if (t.bits () == 16 ) {
0 commit comments