Skip to content

Commit 328bb8f

Browse files
Restrict variadic vec ctor to N > 1
1 parent ad3c303 commit 328bb8f

File tree

3 files changed

+30
-25
lines changed

3 files changed

+30
-25
lines changed

sycl/include/sycl/detail/spirv.hpp

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -871,11 +871,12 @@ EnableIfNativeShuffle<T> Shuffle(GroupT g, T x, id<1> local_id) {
871871
return result;
872872
} else if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
873873
GroupT>) {
874-
return __spirv_GroupNonUniformShuffle(group_scope<GroupT>::value,
875-
convertToOpenCLType(x), LocalId);
874+
return convertFromOpenCLTypeFor<T>(__spirv_GroupNonUniformShuffle(
875+
group_scope<GroupT>::value, convertToOpenCLType(x), LocalId));
876876
} else {
877877
// Subgroup.
878-
return __spirv_SubgroupShuffleINTEL(convertToOpenCLType(x), LocalId);
878+
return convertFromOpenCLTypeFor<T>(
879+
__spirv_SubgroupShuffleINTEL(convertToOpenCLType(x), LocalId));
879880
}
880881
#else
881882
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
@@ -908,12 +909,12 @@ EnableIfNativeShuffle<T> ShuffleXor(GroupT g, T x, id<1> mask) {
908909
// general, and simple so we go with that.
909910
id<1> TargetLocalId = g.get_local_id() ^ mask;
910911
uint32_t TargetId = MapShuffleID(g, TargetLocalId);
911-
return __spirv_GroupNonUniformShuffle(group_scope<GroupT>::value,
912-
convertToOpenCLType(x), TargetId);
912+
return convertFromOpenCLTypeFor<T>(__spirv_GroupNonUniformShuffle(
913+
group_scope<GroupT>::value, convertToOpenCLType(x), TargetId));
913914
} else {
914915
// Subgroup.
915-
return __spirv_SubgroupShuffleXorINTEL(convertToOpenCLType(x),
916-
static_cast<uint32_t>(mask.get(0)));
916+
return convertFromOpenCLTypeFor<T>(__spirv_SubgroupShuffleXorINTEL(
917+
convertToOpenCLType(x), static_cast<uint32_t>(mask.get(0))));
917918
}
918919
#else
919920
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
@@ -956,12 +957,12 @@ EnableIfNativeShuffle<T> ShuffleDown(GroupT g, T x, uint32_t delta) {
956957
if (TargetLocalId[0] + delta < g.get_local_linear_range())
957958
TargetLocalId[0] += delta;
958959
uint32_t TargetId = MapShuffleID(g, TargetLocalId);
959-
return __spirv_GroupNonUniformShuffle(group_scope<GroupT>::value,
960-
convertToOpenCLType(x), TargetId);
960+
return convertFromOpenCLTypeFor<T>(__spirv_GroupNonUniformShuffle(
961+
group_scope<GroupT>::value, convertToOpenCLType(x), TargetId));
961962
} else {
962963
// Subgroup.
963-
return __spirv_SubgroupShuffleDownINTEL(convertToOpenCLType(x),
964-
convertToOpenCLType(x), delta);
964+
return convertFromOpenCLTypeFor<T>(__spirv_SubgroupShuffleDownINTEL(
965+
convertToOpenCLType(x), convertToOpenCLType(x), delta));
965966
}
966967
#else
967968
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
@@ -1000,12 +1001,12 @@ EnableIfNativeShuffle<T> ShuffleUp(GroupT g, T x, uint32_t delta) {
10001001
if (TargetLocalId[0] >= delta)
10011002
TargetLocalId[0] -= delta;
10021003
uint32_t TargetId = MapShuffleID(g, TargetLocalId);
1003-
return __spirv_GroupNonUniformShuffle(group_scope<GroupT>::value,
1004-
convertToOpenCLType(x), TargetId);
1004+
return convertFromOpenCLTypeFor<T>(__spirv_GroupNonUniformShuffle(
1005+
group_scope<GroupT>::value, convertToOpenCLType(x), TargetId));
10051006
} else {
10061007
// Subgroup.
1007-
return __spirv_SubgroupShuffleUpINTEL(convertToOpenCLType(x),
1008-
convertToOpenCLType(x), delta);
1008+
return convertFromOpenCLTypeFor<T>(__spirv_SubgroupShuffleUpINTEL(
1009+
convertToOpenCLType(x), convertToOpenCLType(x), delta));
10091010
}
10101011
#else
10111012
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<

sycl/include/sycl/stream.hpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -579,11 +579,11 @@ typename std::enable_if_t<(VecLength == 2 || VecLength == 4 || VecLength == 8 ||
579579
unsigned>
580580
VecToStr(const vec<T, VecLength> &Vec, char *VecStr, unsigned Flags, int Width,
581581
int Precision) {
582-
unsigned Len =
583-
VecToStr<T, VecLength / 2>(Vec.lo(), VecStr, Flags, Width, Precision);
582+
unsigned Len = VecToStr<T, VecLength / 2>(vec<T, VecLength / 2>{Vec.lo()},
583+
VecStr, Flags, Width, Precision);
584584
Len += append(VecStr + Len, VEC_ELEMENT_DELIMITER);
585-
Len += VecToStr<T, VecLength / 2>(Vec.hi(), VecStr + Len, Flags, Width,
586-
Precision);
585+
Len += VecToStr<T, VecLength / 2>(vec<T, VecLength / 2>{Vec.hi()},
586+
VecStr + Len, Flags, Width, Precision);
587587
return Len;
588588
}
589589

@@ -593,7 +593,8 @@ VecToStr(const vec<T, VecLength> &Vec, char *VecStr, unsigned Flags, int Width,
593593
int Precision) {
594594
unsigned Len = VecToStr<T, 2>(Vec.lo(), VecStr, Flags, Width, Precision);
595595
Len += append(VecStr + Len, VEC_ELEMENT_DELIMITER);
596-
Len += VecToStr<T, 1>(Vec.z(), VecStr + Len, Flags, Width, Precision);
596+
Len +=
597+
VecToStr<T, 1>(vec<T, 1>(Vec.z()), VecStr + Len, Flags, Width, Precision);
597598
return Len;
598599
}
599600

sycl/include/sycl/vector.hpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -835,7 +835,9 @@ class __SYCL_EBO Swizzle
835835

836836
#ifdef __SYCL_DEVICE_ONLY__
837837
operator vector_t() const {
838-
return static_cast<vector_t>(static_cast<ResultVec>(*this));
838+
// operator ResultVec() isn't available for single-element swizzle, create
839+
// sycl::vec explicitly here.
840+
return static_cast<vector_t>(ResultVec{this->Vec[Indexes]...});
839841
}
840842
#endif
841843

@@ -1037,10 +1039,11 @@ class __SYCL_EBO vec :
10371039

10381040
// Constructor from values of base type or vec of base type. Checks that
10391041
// base types are match and that the NumElements == sum of lengths of args.
1040-
template <typename... argTN,
1041-
typename = std::enable_if_t<
1042-
((AllowArgTypeInVariadicCtor<argTN> && ...)) &&
1043-
((num_elements<argTN>() + ...)) == NumElements>>
1042+
template <
1043+
typename... argTN,
1044+
typename = std::enable_if_t<
1045+
(NumElements > 1 && ((AllowArgTypeInVariadicCtor<argTN> && ...)) &&
1046+
((num_elements<argTN>() + ...)) == NumElements)>>
10441047
constexpr vec(const argTN &...args)
10451048
: vec{VecArgArrayCreator<DataT, argTN...>::Create(args...),
10461049
std::make_index_sequence<NumElements>()} {}

0 commit comments

Comments
 (0)