@@ -149,8 +149,7 @@ std::string CodeGenCUDA::Finish() {
149149 if (enable_fp16_) {
150150 decl_stream << " #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)\n " ;
151151 decl_stream << " #include <cuda_fp16.h>\n " ;
152- decl_stream << " __device__ half max"
153- << " (half a, half b)\n "
152+ decl_stream << " __device__ half max" << " (half a, half b)\n "
154153 << " {\n return __hgt(__half(a), __half(b)) ? a : b;\n }\n " ;
155154 decl_stream << " __device__ half min(half a, half b)\n "
156155 << " {\n return __hlt(__half(a), __half(b)) ? a : b;\n }\n " ;
@@ -165,8 +164,7 @@ std::string CodeGenCUDA::Finish() {
165164 if (enable_bf16_) {
166165 decl_stream << " #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)\n " ;
167166 decl_stream << " #include <cuda_bf16.h>\n " ;
168- decl_stream << " __device__ nv_bfloat16 max"
169- << " (nv_bfloat16 a, nv_bfloat16 b)\n "
167+ decl_stream << " __device__ nv_bfloat16 max" << " (nv_bfloat16 a, nv_bfloat16 b)\n "
170168 << " {\n return __hgt(a, b) ? a : b;\n }\n " ;
171169 decl_stream << " __device__ nv_bfloat16 min(nv_bfloat16 a, nv_bfloat16 b)\n "
172170 << " {\n return __hlt(a, b) ? a : b;\n }\n " ;
@@ -542,8 +540,7 @@ void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr l
542540 }
543541 for (int i = 0 , lanes = t.lanes () / 2 ; i < lanes; ++i) {
544542 if (isalpha (op[0 ]) || op[0 ] == ' _' ) {
545- value_temp << op << " 2"
546- << " (__half2(" ;
543+ value_temp << op << " 2" << " (__half2(" ;
547544 PrintVecElemLoad (vlhs, lhs.dtype (), i * lanes, value_temp);
548545 value_temp << " ), __half2(" ;
549546 PrintVecElemLoad (vrhs, rhs.dtype (), i * lanes, value_temp);
@@ -653,8 +650,7 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
653650 ICHECK (i >= 0 && i < (t.bits () == 8 ? 16 : (t.bits () == 16 || t.bits () == 32 ) ? 8 : 4 ));
654651 if (t.bits () == 8 && (t.is_int () || t.is_uint ())) {
655652 if (t.lanes () == 2 || t.lanes () == 3 ) {
656- stream << vec << ' .' << access[i % t.lanes ()] << " ="
657- << " (" << value << " );\n " ;
653+ stream << vec << ' .' << access[i % t.lanes ()] << " =" << " (" << value << " );\n " ;
658654 } else {
659655 std::string ac = t.lanes () == 4 ? vec : (vec + " ." + access[i / 4 ]);
660656 stream << ac << " =" ;
@@ -861,7 +857,23 @@ void CodeGenCUDA::PrintCallExtern(Type ret_type, String global_symbol, const Arr
861857 }
862858 os << sret;
863859 } else {
864- CodeGenC::PrintCallExtern (ret_type, global_symbol, args, skip_first_arg, os);
860+ if (ret_dtype.is_float8 ()) {
861+ std::string fp8_type = (ret_dtype.is_e5m2_float8 () ? " __NV_E5M2" : " __NV_E4M3" );
862+ os << " __nv_fp8_" << (ret_dtype.is_e5m2_float8 () ? " e5m2" : " e4m3" ) << " (" ;
863+
864+ LOG_INFO << global_symbol;
865+ os << global_symbol << " (__half(__nv_cvt_fp8_to_halfraw(" ;
866+ for (size_t i = static_cast <size_t >(skip_first_arg); i < args.size (); ++i) {
867+ this ->PrintExpr (args[i], os);
868+ os << " .__x, " << fp8_type << " ))" ;
869+ if (i < args.size () - 1 ) {
870+ os << " , " << " __half(__nv_cvt_fp8_to_halfraw(" ;
871+ }
872+ }
873+ os << " ))" ;
874+ } else {
875+ CodeGenC::PrintCallExtern (ret_type, global_symbol, args, skip_first_arg, os);
876+ }
865877 }
866878}
867879
@@ -1198,8 +1210,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
11981210 this ->stream << " \" @!p mov.b32 %0, 0;\\ n\"\n " ;
11991211 this ->stream << " \" @p ld.global.nc.f32 %0, [%1];}\\ n\"\n " ;
12001212 // stream << "\" @p ld.global.nc.L2::128B.f32 %0, [%1];}\\n\"\n" ;
1201- stream << " : \" =f\" (" << reg << " [" << local_addr << " ]"
1202- << " )\n " ;
1213+ stream << " : \" =f\" (" << reg << " [" << local_addr << " ]" << " )\n " ;
12031214 stream << " : \" l\" ((void*)(" << global_buffer << " +" << global_addr << " )), \" r\" ((int)"
12041215 << guard << " )\n " ;
12051216 stream << " );\n " ;
@@ -1385,8 +1396,7 @@ void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
13851396 PrintVecConstructor (op->dtype , os);
13861397 os << " (" ;
13871398 for (int i = 0 ; i < lanes; i++) {
1388- os << " (" << PrintExpr (op->base ) << " )"
1389- << " +(" << PrintExpr (op->stride ) << " *" << i << " )" ;
1399+ os << " (" << PrintExpr (op->base ) << " )" << " +(" << PrintExpr (op->stride ) << " *" << i << " )" ;
13901400 if (i != lanes - 1 ) os << " , " ;
13911401 }
13921402 os << " )" ;
0 commit comments