Skip to content

Commit 9427233

Browse files
[SYCLomatic] Add equal_range implementation in helper function(#861)
Signed-off-by: Dan Hoeflinger <[email protected]>
1 parent 73bc4f8 commit 9427233

File tree

2 files changed

+86
-0
lines changed

2 files changed

+86
-0
lines changed

clang/runtime/dpct-rt/include/dpl_extras/algorithm.h.inc

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2020,6 +2020,52 @@ inline void reduce_argmin(_ExecutionPolicy &&policy, Iter1 input, Iter2 output,
20202020
}
20212021
// DPCT_LABEL_END
20222022

2023+
// DPCT_LABEL_BEGIN|equal_range|dpct
2024+
// DPCT_DEPENDENCY_BEGIN
2025+
// DplExtrasIterators|make_constant_iterator
2026+
// DPCT_DEPENDENCY_END
2027+
// DPCT_CODE
2028+
template <typename _ExecutionPolicy, typename Iter1,
2029+
typename ValueLessComparable, typename StrictWeakOrdering>
2030+
inline ::std::pair<Iter1, Iter1>
2031+
equal_range(_ExecutionPolicy &&policy, Iter1 start, Iter1 end,
2032+
const ValueLessComparable &value, StrictWeakOrdering comp) {
2033+
2034+
auto size = ::std::distance(start, end);
2035+
auto zip_start =
2036+
oneapi::dpl::make_zip_iterator(start, dpct::make_counting_iterator(0));
2037+
auto constant_value_iter = dpct::make_constant_iterator(value);
2038+
auto reduction = [=](const auto &a, const auto &b) {
2039+
return oneapi::dpl::__internal::tuple<uint64_t, uint64_t>(
2040+
::std::min(::std::get<0>(a), ::std::get<0>(b)),
2041+
::std::max(::std::get<1>(::std::forward<decltype(a)>(a)),
2042+
::std::get<1>(::std::forward<decltype(b)>(b))));
2043+
};
2044+
auto trans = [=](const auto &a, const auto &b) {
2045+
return ::oneapi::dpl::__internal::tuple<uint64_t, uint64_t>(
2046+
comp(::std::get<0>(a), b) ? size : ::std::get<1>(a),
2047+
comp(b, ::std::get<0>(a)) ? 0 : ::std::get<1>(a) + 1);
2048+
};
2049+
2050+
auto result = oneapi::dpl::transform_reduce(
2051+
::std::forward<_ExecutionPolicy>(policy), zip_start, zip_start + size,
2052+
constant_value_iter,
2053+
oneapi::dpl::__internal::tuple<uint64_t, uint64_t>(size, 0), reduction,
2054+
trans);
2055+
return ::std::make_pair(start + ::std::get<0>(result),
2056+
start + ::std::get<1>(result));
2057+
}
2058+
2059+
template <typename _ExecutionPolicy, typename Iter1,
2060+
typename ValueLessComparable>
2061+
inline ::std::pair<Iter1, Iter1> equal_range(_ExecutionPolicy &&policy,
2062+
Iter1 start, Iter1 end,
2063+
const ValueLessComparable &value) {
2064+
return equal_range(::std::forward<_ExecutionPolicy>(policy), start, end,
2065+
value, internal::__less());
2066+
}
2067+
// DPCT_LABEL_END
2068+
20232069
} // end namespace dpct
20242070

20252071
#endif

clang/test/dpct/helper_files_ref/include/dpl_extras/algorithm.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1675,6 +1675,46 @@ inline void reduce_argmin(_ExecutionPolicy &&policy, Iter1 input, Iter2 output,
16751675
::std::copy(::std::forward<_ExecutionPolicy>(policy), ret, ret + 1, output);
16761676
}
16771677

1678+
template <typename _ExecutionPolicy, typename Iter1,
1679+
typename ValueLessComparable, typename StrictWeakOrdering>
1680+
inline ::std::pair<Iter1, Iter1>
1681+
equal_range(_ExecutionPolicy &&policy, Iter1 start, Iter1 end,
1682+
const ValueLessComparable &value, StrictWeakOrdering comp) {
1683+
1684+
auto size = ::std::distance(start, end);
1685+
auto zip_start =
1686+
oneapi::dpl::make_zip_iterator(start, dpct::make_counting_iterator(0));
1687+
auto constant_value_iter = dpct::make_constant_iterator(value);
1688+
auto reduction = [=](const auto &a, const auto &b) {
1689+
return oneapi::dpl::__internal::tuple<uint64_t, uint64_t>(
1690+
::std::min(::std::get<0>(a), ::std::get<0>(b)),
1691+
::std::max(::std::get<1>(::std::forward<decltype(a)>(a)),
1692+
::std::get<1>(::std::forward<decltype(b)>(b))));
1693+
};
1694+
auto trans = [=](const auto &a, const auto &b) {
1695+
return ::oneapi::dpl::__internal::tuple<uint64_t, uint64_t>(
1696+
comp(::std::get<0>(a), b) ? size : ::std::get<1>(a),
1697+
comp(b, ::std::get<0>(a)) ? 0 : ::std::get<1>(a) + 1);
1698+
};
1699+
1700+
auto result = oneapi::dpl::transform_reduce(
1701+
::std::forward<_ExecutionPolicy>(policy), zip_start, zip_start + size,
1702+
constant_value_iter,
1703+
oneapi::dpl::__internal::tuple<uint64_t, uint64_t>(size, 0), reduction,
1704+
trans);
1705+
return ::std::make_pair(start + ::std::get<0>(result),
1706+
start + ::std::get<1>(result));
1707+
}
1708+
1709+
template <typename _ExecutionPolicy, typename Iter1,
1710+
typename ValueLessComparable>
1711+
inline ::std::pair<Iter1, Iter1> equal_range(_ExecutionPolicy &&policy,
1712+
Iter1 start, Iter1 end,
1713+
const ValueLessComparable &value) {
1714+
return equal_range(::std::forward<_ExecutionPolicy>(policy), start, end,
1715+
value, internal::__less());
1716+
}
1717+
16781718
} // end namespace dpct
16791719

16801720
#endif

0 commit comments

Comments
 (0)