Skip to content

Commit e50c236

Browse files
authored
[SYCLomatic] Implement migrations for StableSortPairs and StableSortPairsDescending (#751)
Signed-off-by: Michael Aziz <[email protected]>
1 parent a7d7bbe commit e50c236

File tree

3 files changed

+459
-7
lines changed

3 files changed

+459
-7
lines changed

clang/lib/DPCT/APINamesCUB.inc

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1761,3 +1761,124 @@ CALL_FACTORY_ENTRY("cub::IADD3",
17611761
CAST_IF_NEED(LITERAL("unsigned int"), ARG(0)),
17621762
CAST_IF_NEED(LITERAL("unsigned int"), ARG(1))),
17631763
CAST_IF_NEED(LITERAL("unsigned int"), ARG(2)))))
1764+
1765+
// cub::DeviceSegmentedSort::StableSortPairs
1766+
CONDITIONAL_FACTORY_ENTRY(
1767+
CheckCubRedundantFunctionCall(),
1768+
REMOVE_API_FACTORY_ENTRY("cub::DeviceSegmentedSort::StableSortPairs"),
1769+
REMOVE_CUB_TEMP_STORAGE_FACTORY(FEATURE_REQUEST_FACTORY(
1770+
HelperFeatureEnum::DplExtrasAlgorithm_segmented_sort_pairs,
1771+
HEADER_INSERT_FACTORY(
1772+
HeaderType::HT_DPCT_DPL_Utils,
1773+
REMOVE_CUB_TEMP_STORAGE_FACTORY(CASE_FACTORY_ENTRY(
1774+
CASE(CheckArgCount(11, std::greater_equal<>(),
1775+
/* IncludeDefaultArg */ false),
1776+
CALL_FACTORY_ENTRY(
1777+
"cub::DeviceSegmentedSort::StableSortPairs",
1778+
CALL(MapNames::getDpctNamespace() +
1779+
"segmented_sort_pairs",
1780+
CALL("oneapi::dpl::execution::device_policy",
1781+
STREAM(10)),
1782+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7),
1783+
ARG(8), ARG(9)))),
1784+
CASE(makeCheckAnd(CheckArgCount(10, std::greater_equal<>(),
1785+
/* IncludeDefaultArg */ false),
1786+
CheckArgType(9, "_Bool")),
1787+
CALL_FACTORY_ENTRY(
1788+
"cub::DeviceSegmentedSort::StableSortPairs",
1789+
CALL(MapNames::getDpctNamespace() +
1790+
"segmented_sort_pairs",
1791+
CALL("oneapi::dpl::execution::device_"
1792+
"policy",
1793+
STREAM(8)),
1794+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7),
1795+
LITERAL("false"), LITERAL("true")))),
1796+
CASE(CheckArgCount(10, std::greater_equal<>(),
1797+
/* IncludeDefaultArg */ false),
1798+
CALL_FACTORY_ENTRY(
1799+
"cub::DeviceSegmentedSort::StableSortPairs",
1800+
CALL(MapNames::getDpctNamespace() +
1801+
"segmented_sort_pairs",
1802+
CALL("oneapi::dpl::execution::device_policy",
1803+
QUEUESTR),
1804+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7),
1805+
ARG(8), ARG(9)))),
1806+
CASE(CheckArgCount(9, std::greater_equal<>(),
1807+
/* IncludeDefaultArg */ false),
1808+
CALL_FACTORY_ENTRY(
1809+
"cub::DeviceSegmentedSort::StableSortPairs",
1810+
CALL(MapNames::getDpctNamespace() +
1811+
"segmented_sort_pairs",
1812+
CALL("oneapi::dpl::execution::device_"
1813+
"policy",
1814+
STREAM(8)),
1815+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7),
1816+
LITERAL("false"), LITERAL("true")))),
1817+
OTHERWISE(CALL_FACTORY_ENTRY(
1818+
"cub::DeviceSegmentedSort::StableSortPairs",
1819+
CALL(MapNames::getDpctNamespace() + "segmented_sort_pairs",
1820+
CALL("oneapi::dpl::execution::device_policy",
1821+
QUEUESTR),
1822+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7),
1823+
LITERAL("false"), LITERAL("true"))))))))))
1824+
1825+
// cub::DeviceSegmentedSort::StableSortPairsDescending
1826+
CONDITIONAL_FACTORY_ENTRY(
1827+
CheckCubRedundantFunctionCall(),
1828+
REMOVE_API_FACTORY_ENTRY(
1829+
"cub::DeviceSegmentedSort::StableSortPairsDescending"),
1830+
REMOVE_CUB_TEMP_STORAGE_FACTORY(FEATURE_REQUEST_FACTORY(
1831+
HelperFeatureEnum::DplExtrasAlgorithm_segmented_sort_pairs,
1832+
HEADER_INSERT_FACTORY(
1833+
HeaderType::HT_DPCT_DPL_Utils,
1834+
REMOVE_CUB_TEMP_STORAGE_FACTORY(CASE_FACTORY_ENTRY(
1835+
CASE(CheckArgCount(11, std::greater_equal<>(),
1836+
/* IncludeDefaultArg */ false),
1837+
CALL_FACTORY_ENTRY(
1838+
"cub::DeviceSegmentedSort::StableSortPairsDescending",
1839+
CALL(MapNames::getDpctNamespace() +
1840+
"segmented_sort_pairs",
1841+
CALL("oneapi::dpl::execution::device_policy",
1842+
STREAM(10)),
1843+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7),
1844+
ARG(8), ARG(9), LITERAL("true")))),
1845+
CASE(makeCheckAnd(CheckArgCount(10, std::greater_equal<>(),
1846+
/* IncludeDefaultArg */ false),
1847+
CheckArgType(9, "_Bool")),
1848+
CALL_FACTORY_ENTRY(
1849+
"cub::DeviceSegmentedSort::StableSortPairsDescending",
1850+
CALL(MapNames::getDpctNamespace() +
1851+
"segmented_sort_pairs",
1852+
CALL("oneapi::dpl::execution::device_"
1853+
"policy",
1854+
STREAM(8)),
1855+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7),
1856+
LITERAL("true"), LITERAL("true")))),
1857+
CASE(CheckArgCount(10, std::greater_equal<>(),
1858+
/* IncludeDefaultArg */ false),
1859+
CALL_FACTORY_ENTRY(
1860+
"cub::DeviceSegmentedSort::StableSortPairsDescending",
1861+
CALL(MapNames::getDpctNamespace() +
1862+
"segmented_sort_pairs",
1863+
CALL("oneapi::dpl::execution::device_policy",
1864+
QUEUESTR),
1865+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7),
1866+
ARG(8), ARG(9), LITERAL("true")))),
1867+
CASE(CheckArgCount(9, std::greater_equal<>(),
1868+
/* IncludeDefaultArg */ false),
1869+
CALL_FACTORY_ENTRY(
1870+
"cub::DeviceSegmentedSort::StableSortPairsDescending",
1871+
CALL(MapNames::getDpctNamespace() +
1872+
"segmented_sort_pairs",
1873+
CALL("oneapi::dpl::execution::device_"
1874+
"policy",
1875+
STREAM(8)),
1876+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7),
1877+
LITERAL("true"), LITERAL("true")))),
1878+
OTHERWISE(CALL_FACTORY_ENTRY(
1879+
"cub::DeviceSegmentedSort::StableSortPairsDescending",
1880+
CALL(MapNames::getDpctNamespace() + "segmented_sort_pairs",
1881+
CALL("oneapi::dpl::execution::device_policy",
1882+
QUEUESTR),
1883+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7),
1884+
LITERAL("true"), LITERAL("true"))))))))))

clang/lib/DPCT/CUBAPIMigration.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,13 @@ auto parentStmt = []() {
4646

4747
auto isDeviceFuncCallExpr = []() {
4848
auto hasDeviceFuncName = []() {
49-
return hasAnyName("Sum", "Min", "Max", "ArgMin", "ArgMax", "Reduce",
50-
"ReduceByKey", "ExclusiveSum", "InclusiveSum",
51-
"InclusiveScan", "ExclusiveScan", "InclusiveScanByKey",
52-
"InclusiveSumByKey", "ExclusiveScanByKey",
53-
"ExclusiveSumByKey", "Flagged", "Unique", "UniqueByKey",
54-
"Encode", "SortKeys", "SortKeysDescending", "SortPairs",
55-
"SortPairsDescending", "If");
49+
return hasAnyName(
50+
"Sum", "Min", "Max", "ArgMin", "ArgMax", "Reduce", "ReduceByKey",
51+
"ExclusiveSum", "InclusiveSum", "InclusiveScan", "ExclusiveScan",
52+
"InclusiveScanByKey", "InclusiveSumByKey", "ExclusiveScanByKey",
53+
"ExclusiveSumByKey", "Flagged", "Unique", "UniqueByKey", "Encode",
54+
"SortKeys", "SortKeysDescending", "SortPairs", "SortPairsDescending",
55+
"If", "StableSortPairs", "StableSortPairsDescending");
5656
};
5757
auto hasDeviceRecordName = []() {
5858
return hasAnyName("DeviceSegmentedReduce", "DeviceReduce", "DeviceScan",

0 commit comments

Comments
 (0)