Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion sycl/include/sycl/detail/spirv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@

#ifdef __SYCL_DEVICE_ONLY__

#include <sycl/ext/oneapi/experimental/non_uniform_groups.hpp> // for IdToMaskPosition
#include <sycl/__spirv/spirv_types.hpp>
#include <sycl/access/access.hpp>
#include <sycl/id.hpp>
#include <sycl/multi_ptr.hpp>
#include <sycl/detail/generic_type_traits.hpp>

#if defined(__NVPTX__)
#include <sycl/ext/oneapi/experimental/cuda/masked_shuffles.hpp>
Expand Down
26 changes: 8 additions & 18 deletions sycl/include/sycl/ext/oneapi/sub_group_mask.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@

#include <sycl/detail/helpers.hpp> // for Builder
#include <sycl/detail/memcpy.hpp> // detail::memcpy
#include <sycl/exception.hpp> // for errc, exception
#include <sycl/feature_test.hpp> // for SYCL_EXT_ONEAPI_SUB_GROUP_MASK
#include <sycl/id.hpp> // for id
#include <sycl/marray.hpp> // for marray
#include <sycl/vector.hpp> // for vec
#include <sycl/detail/spirv.hpp>
#include <sycl/feature_test.hpp> // for SYCL_EXT_ONEAPI_SUB_GROUP_MASK
#include <sycl/id.hpp> // for id
#include <sycl/marray.hpp> // for marray
#include <sycl/vector.hpp> // for vec

#include <assert.h> // for assert
#include <climits> // for CHAR_BIT
Expand Down Expand Up @@ -342,8 +342,7 @@ template <typename Group>
std::enable_if_t<std::is_same_v<std::decay_t<Group>, sub_group> ||
std::is_same_v<std::decay_t<Group>, sycl::sub_group>,
sub_group_mask>
group_ballot(Group g, bool predicate) {
(void)g;
group_ballot([[maybe_unused]] Group g, [[maybe_unused]] bool predicate) {
#ifdef __SYCL_DEVICE_ONLY__
auto res = __spirv_GroupNonUniformBallot(
sycl::detail::spirv::group_scope<Group>::value, predicate);
Expand All @@ -353,20 +352,11 @@ group_ballot(Group g, bool predicate) {
return sycl::detail::Builder::createSubGroupMask<sub_group_mask>(
val, g.get_max_local_range()[0]);
#else
(void)predicate;
throw exception{errc::feature_not_supported,
"Sub-group mask is not supported on host device"};
// Groups are not user-constructible, this call should not be reachable from
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be worth adding llvm_unreachable here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two concerns here:

  1. llvm_unreachable comes from LLVM headers that we don't ship (and probably don't want to)
  2. Any "early exits" in form of exception/unreachable may affect host compilation - I'm afraid of unintended side effects like we saw in [SYCL] Fix SYCL_EXTERNAL device code when linking with a static lib #14256

// host and therefore we do nothing here.
#endif
}

} // namespace ext::oneapi
} // namespace _V1
} // namespace sycl

// We have a cyclic dependency with
// sub_group_mask.hpp
// detail/spirv.hpp
// non_uniform_groups.hpp
// "Break" it by including this at the end (instead of beginning). Ideally, we
// should refactor this somehow...
#include <sycl/detail/spirv.hpp>
Loading