Skip to content

Commit 2bb1544

Browse files
authored
[SYCLomatic] Support migration of the __nv_bfloat162 type and two related APIs. (#800)
Signed-off-by: Tang, Jiajun [email protected]
1 parent 5156990 commit 2bb1544

File tree

9 files changed, with 79 additions (+79) and 4 deletions (-4).

9 files changed

+79
-4
lines changed

clang/lib/DPCT/APINames.inc

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1767,7 +1767,7 @@ ENTRY(__assertfail, __assertfail, false, NO_FLAG, P4, "comment")
17671767
ENTRY(__ldcs, __ldcs, false, NO_FLAG, P4, "comment")
17681768
ENTRY(__ldcg, __ldcg, false, NO_FLAG, P4, "comment")
17691769
ENTRY(__ldca, __ldca, false, NO_FLAG, P4, "comment")
1770-
ENTRY(__bfloat1622float2, __bfloat1622float2, false, NO_FLAG, P4, "comment")
1770+
ENTRY(__bfloat1622float2, __bfloat1622float2, true, NO_FLAG, P4, "Successful")
17711771
ENTRY(__bfloat162bfloat162, __bfloat162bfloat162, false, NO_FLAG, P4, "comment")
17721772
ENTRY(__bfloat162int_rd, __bfloat162int_rd, false, NO_FLAG, P4, "comment")
17731773
ENTRY(__bfloat162int_rn, __bfloat162int_rn, false, NO_FLAG, P4, "comment")
@@ -1797,7 +1797,7 @@ ENTRY(__bfloat16_as_short, __bfloat16_as_short, false, NO_FLAG, P4, "comment")
17971797
ENTRY(__bfloat16_as_ushort, __bfloat16_as_ushort, false, NO_FLAG, P4, "comment")
17981798
ENTRY(__double2bfloat16, __double2bfloat16, false, NO_FLAG, P4, "comment")
17991799
ENTRY(__double2half, __double2half, false, NO_FLAG, P4, "comment")
1800-
ENTRY(__float22bfloat162_rn, __float22bfloat162_rn, false, NO_FLAG, P4, "comment")
1800+
ENTRY(__float22bfloat162_rn, __float22bfloat162_rn, true, NO_FLAG, P4, "Successful")
18011801
ENTRY(__float2bfloat162_rn, __float2bfloat162_rn, false, NO_FLAG, P4, "comment")
18021802
ENTRY(__float2bfloat16_rd, __float2bfloat16_rd, false, NO_FLAG, P4, "comment")
18031803
ENTRY(__float2bfloat16_rn, __float2bfloat16_rn, false, NO_FLAG, P4, "comment")

clang/lib/DPCT/APINamesMath.inc

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -428,7 +428,9 @@ ENTRY_TYPECAST("__ushort2half_ru")
428428
ENTRY_TYPECAST("__ushort2half_rz")
429429

430430
// Bfloat16 Precision Conversion and Data Movement
431+
ENTRY_REWRITE("__bfloat1622float2")
431432
ENTRY_REWRITE("__bfloat162float")
433+
ENTRY_REWRITE("__float22bfloat162_rn")
432434
ENTRY_REWRITE("__float2bfloat16")
433435

434436
// Type Casting Intrinsics

clang/lib/DPCT/APINamesMathRewrite.inc

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -943,8 +943,20 @@ MATH_API_REWRITER_DEVICE(
943943
ARG(0), ARG(1), LITERAL("std::not_equal_to<>()"))))))
944944

945945
// Bfloat16 Precision Conversion and Data Movement
946+
CALL_FACTORY_ENTRY("__bfloat1622float2",
947+
CALL(MapNames::getClNamespace() + "float2",
948+
ARRAY_SUBSCRIPT(ARG(0), LITERAL("0")),
949+
ARRAY_SUBSCRIPT(ARG(0), LITERAL("1"))))
950+
946951
CALL_FACTORY_ENTRY("__bfloat162float", CALL("static_cast<float>", ARG(0)))
947952

953+
CALL_FACTORY_ENTRY("__float22bfloat162_rn",
954+
CALL(MapNames::getClNamespace() + "marray<" +
955+
MapNames::getClNamespace() +
956+
"ext::oneapi::bfloat16, 2>",
957+
ARRAY_SUBSCRIPT(ARG(0), LITERAL("0")),
958+
ARRAY_SUBSCRIPT(ARG(0), LITERAL("1"))))
959+
948960
CALL_FACTORY_ENTRY("__float2bfloat16",
949961
CALL(MapNames::getClNamespace() + "ext::oneapi::bfloat16",
950962
ARG(0)))

clang/lib/DPCT/ASTTraversal.cpp

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2017,7 +2017,7 @@ void TypeInDeclRule::registerMatcher(MatchFinder &MF) {
20172017
"cufftType_t", "cufftType", "thrust::pair", "CUdeviceptr",
20182018
"cudaDeviceAttr", "CUmodule", "CUjit_option", "CUfunction",
20192019
"cudaMemcpyKind", "cudaComputeMode", "__nv_bfloat16",
2020-
"cooperative_groups::__v1::thread_block_tile",
2020+
"__nv_bfloat162", "cooperative_groups::__v1::thread_block_tile",
20212021
"cooperative_groups::__v1::thread_block", "libraryPropertyType_t",
20222022
"libraryPropertyType", "cudaDataType_t", "cudaDataType",
20232023
"cublasComputeType_t", "cublasAtomicsMode_t", "CUmem_advise_enum",
@@ -3006,6 +3006,17 @@ void VectorTypeMemberAccessRule::renameMemberField(const MemberExpr *ME) {
30063006
return emplaceTransformation(new ReplaceText(Begin, Length, ""));
30073007
}
30083008
std::string MemberName = ME->getMemberNameInfo().getAsString();
3009+
if (MapNames::VectorTypes2MArray.count(BaseTy) &&
3010+
MapNames::MArrayMemberNamesMap.count(MemberName)) {
3011+
auto Begin = ME->getOperatorLoc();
3012+
auto End =
3013+
Lexer::getLocForEndOfToken(SM.getSpellingLoc(ME->getMemberLoc()), 0, SM,
3014+
DpctGlobalInfo::getContext().getLangOpts());
3015+
auto Length = SM.getFileOffset(End) - SM.getFileOffset(Begin);
3016+
auto MArrayIdx = MapNames::MArrayMemberNamesMap.find(MemberName)->second;
3017+
return emplaceTransformation(
3018+
new ReplaceText(Begin, Length, std::move(MArrayIdx)));
3019+
}
30093020
if (MapNames::replaceName(MapNames::MemberNamesMap, MemberName))
30103021
emplaceTransformation(
30113022
new RenameFieldInMemberExpr(ME, std::move(MemberName)));

clang/lib/DPCT/CallExprRewriterCommon.h

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -499,6 +499,17 @@ makeCallExprCreator(std::string Callee,
499499
Args...);
500500
}
501501

502+
template < class BaseT, class ArgValueT>
503+
inline std::function<
504+
ArraySubscriptExprPrinter<BaseT, ArgValueT>(const CallExpr *)>
505+
makeArraySubscriptExprCreator(std::function<BaseT(const CallExpr *)> E,
506+
std::function<ArgValueT(const CallExpr *)> I) {
507+
return PrinterCreator<ArraySubscriptExprPrinter<BaseT, ArgValueT>,
508+
std::function<BaseT(const CallExpr *)>,
509+
std::function<ArgValueT(const CallExpr *)>>(std::move(E),
510+
std::move(I));
511+
}
512+
502513
inline std::function<std::string(const CallExpr *)>
503514
makeFuncNameFromDevAttrCreator(unsigned idx) {
504515
return [=](const CallExpr *CE) -> std::string {
@@ -1796,6 +1807,7 @@ class IsArgumentIntegerType {
17961807
#define STATIC_MEMBER_EXPR(...) makeStaticMemberExprCreator(__VA_ARGS__)
17971808
#define LAMBDA(...) makeLambdaCreator(__VA_ARGS__)
17981809
#define CALL(...) makeCallExprCreator(__VA_ARGS__)
1810+
#define ARRAY_SUBSCRIPT(e, i) makeArraySubscriptExprCreator(e, i)
17991811
#define CAST(T, S) makeCastExprCreator(T, S)
18001812
#define CAST_IF_NEED(T, S) makeCastIfNeedExprCreator(T, S)
18011813
#define DOUBLE_POINTER_CONST_CAST(BASE_VALUE_TYPE, EXPR, \

clang/lib/DPCT/ExprAnalysis.cpp

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -766,6 +766,17 @@ void ExprAnalysis::analyzeExpr(const MemberExpr *ME) {
766766
addReplacement(ME->getOperatorLoc(), ME->getEndLoc(), "");
767767
} else {
768768
std::string MemberName = ME->getMemberNameInfo().getAsString();
769+
if (MapNames::VectorTypes2MArray.count(BaseType) &&
770+
MapNames::MArrayMemberNamesMap.count(MemberName)) {
771+
auto Begin = ME->getOperatorLoc();
772+
auto End = Lexer::getLocForEndOfToken(
773+
SM.getSpellingLoc(ME->getMemberLoc()), 0, SM,
774+
DpctGlobalInfo::getContext().getLangOpts());
775+
auto Length = SM.getFileOffset(End) - SM.getFileOffset(Begin);
776+
auto MArrayIdx =
777+
MapNames::MArrayMemberNamesMap.find(MemberName)->second;
778+
return addReplacement(Begin, Length, std::move(MArrayIdx));
779+
}
769780
if (MapNames::replaceName(MapNames::MemberNamesMap, MemberName)) {
770781
// Retrieve the correct location before addReplacement
771782
auto Loc =

clang/lib/DPCT/MapNames.cpp

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -419,6 +419,9 @@ void MapNames::setExplicitNamespaceMap() {
419419
{"cudaDeviceAttr", std::make_shared<TypeNameRule>("int")},
420420
{"__nv_bfloat16", std::make_shared<TypeNameRule>(
421421
getClNamespace() + "ext::oneapi::bfloat16")},
422+
{"__nv_bfloat162", std::make_shared<TypeNameRule>(
423+
getClNamespace() + "marray<" + getClNamespace() +
424+
"ext::oneapi::bfloat16, 2>")},
422425
{"libraryPropertyType_t",
423426
std::make_shared<TypeNameRule>(
424427
getDpctNamespace() + "version_field",
@@ -1841,6 +1844,7 @@ void MapNames::setExplicitNamespaceMap() {
18411844

18421845
// Supported vector types
18431846
const MapNames::SetTy MapNames::SupportedVectorTypes{SUPPORTEDVECTORTYPENAMES};
1847+
const MapNames::SetTy MapNames::VectorTypes2MArray{VECTORTYPE2MARRAYNAMES};
18441848

18451849
const std::map<std::string, int> MapNames::VectorTypeMigratedTypeSizeMap{
18461850
{"char1", 1}, {"char2", 2}, {"char3", 4},
@@ -4286,6 +4290,10 @@ const MapNames::MapTy MapNames::MemberNamesMap{
42864290
{"x", "x()"}, {"y", "y()"}, {"z", "z()"}, {"w", "w()"},
42874291
// ...
42884292
};
4293+
const MapNames::MapTy MapNames::MArrayMemberNamesMap{
4294+
{"x", "[0]"},
4295+
{"y", "[1]"},
4296+
};
42894297

42904298
const MapNames::SetTy MapNames::HostAllocSet{
42914299
"cudaHostAllocDefault", "cudaHostAllocMapped",

clang/lib/DPCT/MapNames.h

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,8 @@ const std::string StringLiteralUnsupported{"UNSUPPORTED"};
3434
"long4", "ulong4", "float1", "float2", "float3", "float4", "longlong1", \
3535
"ulonglong1", "longlong2", "ulonglong2", "longlong3", "ulonglong3", \
3636
"longlong4", "ulonglong4", "double1", "double2", "double3", "double4", \
37-
"__half2"
37+
"__half2", "__nv_bfloat162"
38+
#define VECTORTYPE2MARRAYNAMES "__nv_bfloat162"
3839

3940
/// Record mapping between names
4041
class MapNames {
@@ -311,6 +312,7 @@ class MapNames {
311312
using ThrustMapTy = std::map<std::string, ThrustFuncReplInfo>;
312313

313314
static const SetTy SupportedVectorTypes;
315+
static const SetTy VectorTypes2MArray;
314316
static const std::map<std::string, int> VectorTypeMigratedTypeSizeMap;
315317
static const std::map<clang::dpct::KernelArgType, int> KernelArgTypeSizeMap;
316318
static int getArrayTypeSize(const int Dim);
@@ -410,6 +412,7 @@ class MapNames {
410412
}
411413

412414
static const MapNames::MapTy MemberNamesMap;
415+
static const MapNames::MapTy MArrayMemberNamesMap;
413416
static const MapNames::MapTy FunctionAttrMap;
414417
static const MapNames::SetTy HostAllocSet;
415418
static MapNames::MapTy MathFuncNameMap;

clang/test/dpct/bfloat16.cu

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,25 +15,41 @@ void foo(__nv_bfloat16 *a) {
1515

1616
// CHECK: void test_conversions_device() {
1717
// CHECK-NEXT: float f, f_1, f_2;
18+
// CHECK-NEXT: sycl::float2 f2, f2_1, f2_2;
1819
// CHECK-NEXT: sycl::ext::oneapi::bfloat16 bf16, bf16_1, bf16_2;
20+
// CHECK-NEXT: sycl::marray<sycl::ext::oneapi::bfloat16, 2> bf162, bf162_1, bf162_2;
21+
// CHECK-NEXT: f2 = sycl::float2(bf162[0], bf162[1]);
1922
// CHECK-NEXT: f = static_cast<float>(bf16);
23+
// CHECK-NEXT: bf162 = sycl::marray<sycl::ext::oneapi::bfloat16, 2>(f2[0], f2[1]);
2024
// CHECK-NEXT: bf16 = sycl::ext::oneapi::bfloat16(f);
2125
__global__ void test_conversions_device() {
2226
float f, f_1, f_2;
27+
float2 f2, f2_1, f2_2;
2328
__nv_bfloat16 bf16, bf16_1, bf16_2;
29+
__nv_bfloat162 bf162, bf162_1, bf162_2;
30+
f2 = __bfloat1622float2(bf162);
2431
f = __bfloat162float(bf16);
32+
bf162 = __float22bfloat162_rn(f2);
2533
bf16 = __float2bfloat16(f);
2634
}
2735

2836
// CHECK: void test_conversions() {
2937
// CHECK-NEXT: float f, f_1, f_2;
38+
// CHECK-NEXT: sycl::float2 f2, f2_1, f2_2;
3039
// CHECK-NEXT: sycl::ext::oneapi::bfloat16 bf16, bf16_1, bf16_2;
40+
// CHECK-NEXT: sycl::marray<sycl::ext::oneapi::bfloat16, 2> bf162, bf162_1, bf162_2;
41+
// CHECK-NEXT: f2 = sycl::float2(bf162[0], bf162[1]);
3142
// CHECK-NEXT: f = static_cast<float>(bf16);
43+
// CHECK-NEXT: bf162 = sycl::marray<sycl::ext::oneapi::bfloat16, 2>(f2[0], f2[1]);
3244
// CHECK-NEXT: bf16 = sycl::ext::oneapi::bfloat16(f);
3345
void test_conversions() {
3446
float f, f_1, f_2;
47+
float2 f2, f2_1, f2_2;
3548
__nv_bfloat16 bf16, bf16_1, bf16_2;
49+
__nv_bfloat162 bf162, bf162_1, bf162_2;
50+
f2 = __bfloat1622float2(bf162);
3651
f = __bfloat162float(bf16);
52+
bf162 = __float22bfloat162_rn(f2);
3753
bf16 = __float2bfloat16(f);
3854
}
3955

0 commit comments

Comments
 (0)