Skip to content

Commit 9bd2825

Browse files
authored
[SYCLomatic] Support migration dereference for std::vector subscript. (#650)
Signed-off-by: Chen, Sheng S <[email protected]>
1 parent 73cb8ff commit 9bd2825

File tree

3 files changed

+23
-0
lines changed

3 files changed

+23
-0
lines changed

clang/lib/DPCT/CallExprRewriterCommon.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,9 @@ inline std::function<std::string(const CallExpr *C)> getDerefedType(size_t Idx)
696696
} else if (BaseType->isPointerType()) {
697697
DerefQT = BaseType->getPointeeType();
698698
}
699+
} else if (auto COCE = dyn_cast<CXXOperatorCallExpr>(TE)) {
700+
// Handle cases like A[3] where A is a vector with sepecfying type.
701+
DerefQT = COCE->getType().getCanonicalType();
699702
}
700703

701704
// All other cases

clang/lib/DPCT/Utility.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2032,6 +2032,13 @@ bool needExtraParens(const Expr *E) {
20322032
else
20332033
return true;
20342034
}
2035+
case Stmt::CXXOperatorCallExprClass: {
2036+
if (auto COCE = dyn_cast<CXXOperatorCallExpr>(E)) {
2037+
if (COCE->getOperator() == clang::OO_Subscript)
2038+
return false;
2039+
}
2040+
return true;
2041+
}
20352042
default:
20362043
return true;
20372044
}

clang/test/dpct/USM-restricted.cu

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <cuda.h>
1010
#include <stdio.h>
1111
#include <memory>
12+
#include <vector>
1213

1314
#define MY_SAFE_CALL(CALL) do { \
1415
int Error = CALL; \
@@ -1070,3 +1071,15 @@ void foo14() {
10701071
cudaMemcpy((void *)&h_selected_num, (void *)d_selected_num, sizeof(int), cudaMemcpyDeviceToHost);
10711072
cudaMemcpy((void *)h_out, (void *)d_out, h_selected_num * sizeof(int), cudaMemcpyDeviceToHost);
10721073
}
1074+
1075+
struct TEST_STR {
1076+
int a[10];
1077+
};
1078+
1079+
void foo15() {
1080+
std::vector<volatile TEST_STR *> buf;
1081+
for (int i = 0; i < 32; i++) {
1082+
//CHECK: buf[i] = (volatile TEST_STR *)sycl::malloc_host(sizeof(TEST_STR), dpct::get_default_queue());
1083+
cudaMallocHost(&buf[i], sizeof(TEST_STR));
1084+
}
1085+
}

0 commit comments

Comments
 (0)