Skip to content

Commit a5ad85c

Browse files
committed
move to gko::complex<sycl::half> alone
1 parent 7921c98 commit a5ad85c

File tree

7 files changed

+139
-111
lines changed

7 files changed

+139
-111
lines changed

Diff for: accessor/sycl_helper.hpp

+10
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ namespace gko {
3333
class half;
3434

3535

36+
template <typename V>
37+
class complex;
38+
39+
3640
namespace acc {
3741
namespace detail {
3842

@@ -81,6 +85,12 @@ struct sycl_type<std::complex<T>> {
8185
};
8286

8387

88+
template <>
89+
struct sycl_type<std::complex<gko::half>> {
90+
using type = gko::complex<typename sycl_type<gko::half>::type>;
91+
};
92+
93+
8494
} // namespace detail
8595

8696

Diff for: cmake/sycl.cmake

-4
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,6 @@ function(gko_add_sycl_to_target)
2222
"${one_value_args}"
2323
"${multi_value_args}"
2424
${ARGN})
25-
# trick for complex header chain
26-
if("${GINKGO_DPCPP_MAJOR_VERSION}.${GINKGO_DPCPP_MINOR_VERSION}" VERSION_GREATER_EQUAL 7.1)
27-
target_include_directories(${SYCL_TARGET} PRIVATE "${PROJECT_BINARY_DIR}/dpcpp/base")
28-
endif()
2925
if(COMMAND add_sycl_to_target)
3026
add_sycl_to_target(${ARGN})
3127
return()

Diff for: dpcpp/CMakeLists.txt

-5
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,6 @@ find_package(oneDPL REQUIRED HINTS "$ENV{DPL_ROOT}" "$ENV{DPLROOT}")
44
set(GINKGO_MKL_ROOT "${MKL_DIR}" PARENT_SCOPE)
55
set(GINKGO_DPL_ROOT "${oneDPL_DIR}" PARENT_SCOPE)
66

7-
# trick for complex header chain
8-
if("${GINKGO_DPCPP_MAJOR_VERSION}.${GINKGO_DPCPP_MINOR_VERSION}" VERSION_GREATER_EQUAL 7.1)
9-
configure_file(base/complex.hpp ${CMAKE_CURRENT_BINARY_DIR}/base/complex)
10-
endif()
11-
127
include(${PROJECT_SOURCE_DIR}/cmake/template_instantiation.cmake)
138
add_instantiation_files(${PROJECT_SOURCE_DIR}/common/unified matrix/dense_kernels.instantiate.cpp DENSE_INSTANTIATE)
149
add_instantiation_files(. solver/batch_bicgstab_launch.instantiate.dp.cpp BATCH_BICGSTAB_INSTANTIATE)

Diff for: dpcpp/base/complex.hpp

+81-96
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,19 @@
55
#ifndef GKO_DPCPP_BASE_COMPLEX_HPP_
66
#define GKO_DPCPP_BASE_COMPLEX_HPP_
77

8+
#include <complex>
9+
810
#include <sycl/half_type.hpp>
911

1012
#include <ginkgo/config.hpp>
1113

12-
// This file is a workaround for the different way intel sycl loads <complex>.
13-
// intel sycl provides complex and the corresponding searching path. When users
14-
// load complex with -fsycl, the compiler will load intel's <complex> header
15-
// first and then load usual <complex> header. However, it implicitly
16-
// instantiates and uses std::complex<sycl::half>, so we need to provide the
17-
// implementation before that. In ginkgo, we will definitely load <complex> in
18-
// the public interface, which is before sycl backend, so we have no normal way
19-
// to provide the std::complex<sycl::half> implementation in sycl.
20-
// We apply the same trick to load this file first and then load their header
21-
// later. We will also configure this file as <complex> and provide the search
22-
// path in sycl module.
23-
// They started doing this in LIBSYCL 7.1.0.
24-
25-
namespace std {
14+
15+
namespace gko {
2616

2717
template <typename>
2818
class complex;
2919

30-
// implement std::complex<sycl::half> before knowing std::complex<float>
20+
3121
template <>
3222
class complex<sycl::half> {
3323
public:
@@ -53,7 +43,7 @@ class complex<sycl::half> {
5343
{}
5444

5545
template <typename T, typename = std::enable_if_t<std::is_scalar<T>::value>>
56-
complex(const complex<T>& other)
46+
complex(const std::complex<T>& other)
5747
: real_(static_cast<value_type>(other.real())),
5848
imag_(static_cast<value_type>(other.imag()))
5949
{}
@@ -62,7 +52,18 @@ class complex<sycl::half> {
6252

6353
value_type imag() const noexcept { return imag_; }
6454

65-
inline operator std::complex<float>() const noexcept;
55+
operator std::complex<float>() const noexcept
56+
{
57+
return std::complex<float>(static_cast<float>(real_),
58+
static_cast<float>(imag_));
59+
}
60+
61+
bool operator!=(const complex& r) const { return !this->operator==(r); }
62+
63+
bool operator==(const complex& r) const
64+
{
65+
return real_ == r.real() && imag_ == r.imag();
66+
}
6667

6768
template <typename V>
6869
complex& operator=(const V& val)
@@ -107,37 +108,83 @@ class complex<sycl::half> {
107108
}
108109

109110
template <typename T>
110-
complex& operator+=(const complex<T>& val)
111+
complex& operator+=(const std::complex<T>& val)
111112
{
112113
real_ += val.real();
113114
imag_ += val.imag();
114115
return *this;
115116
}
116117

117118
template <typename T>
118-
complex& operator-=(const complex<T>& val)
119+
complex& operator-=(const std::complex<T>& val)
119120
{
120121
real_ -= val.real();
121122
imag_ -= val.imag();
122123
return *this;
123124
}
124125

125126
template <typename T>
126-
inline complex& operator*=(const complex<T>& val);
127+
complex& operator*=(const std::complex<T>& val)
128+
{
129+
auto val_f = static_cast<std::complex<float>>(val);
130+
auto result_f = static_cast<std::complex<float>>(*this);
131+
result_f *= val_f;
132+
real_ = result_f.real();
133+
imag_ = result_f.imag();
134+
return *this;
135+
}
127136

128137
template <typename T>
129-
inline complex& operator/=(const complex<T>& val);
138+
complex& operator/=(const std::complex<T>& val)
139+
{
140+
auto val_f = static_cast<std::complex<float>>(val);
141+
auto result_f = static_cast<std::complex<float>>(*this);
142+
result_f /= val_f;
143+
real_ = result_f.real();
144+
imag_ = result_f.imag();
145+
return *this;
146+
}
147+
148+
complex& operator+=(const complex& val)
149+
{
150+
real_ += val.real();
151+
imag_ += val.imag();
152+
return *this;
153+
}
154+
155+
complex& operator-=(const complex& val)
156+
{
157+
real_ -= val.real();
158+
imag_ -= val.imag();
159+
return *this;
160+
}
130161

131-
// It's for MacOS.
132-
// TODO: check whether the mac compiler always uses the complex version even when real
133-
// half
134-
#define COMPLEX_HALF_OPERATOR(_op, _opeq) \
135-
friend complex<sycl::half> operator _op(const complex<sycl::half> lhf, \
136-
const complex<sycl::half> rhf) \
137-
{ \
138-
auto a = lhf; \
139-
a _opeq rhf; \
140-
return a; \
162+
complex& operator*=(const complex& val)
163+
{
164+
auto val_f = static_cast<std::complex<float>>(val);
165+
auto result_f = static_cast<std::complex<float>>(*this);
166+
result_f *= val_f;
167+
real_ = result_f.real();
168+
imag_ = result_f.imag();
169+
return *this;
170+
}
171+
172+
complex& operator/=(const complex& val)
173+
{
174+
auto val_f = static_cast<std::complex<float>>(val);
175+
auto result_f = static_cast<std::complex<float>>(*this);
176+
result_f /= val_f;
177+
real_ = result_f.real();
178+
imag_ = result_f.imag();
179+
return *this;
180+
}
181+
182+
#define COMPLEX_HALF_OPERATOR(_op, _opeq) \
183+
friend complex operator _op(const complex& lhf, const complex& rhf) \
184+
{ \
185+
auto a = lhf; \
186+
a _opeq rhf; \
187+
return a; \
141188
}
142189

143190
COMPLEX_HALF_OPERATOR(+, +=)
@@ -147,77 +194,15 @@ class complex<sycl::half> {
147194

148195
#undef COMPLEX_HALF_OPERATOR
149196

197+
complex operator-() const { return complex(-real_, -imag_); }
198+
150199
private:
151200
value_type real_;
152201
value_type imag_;
153202
};
154203

155-
} // namespace std
156-
157-
158-
// after providing std::complex<sycl::half>, we can load their <complex> to
159-
// complete the header chain.
160-
161-
#if GINKGO_DPCPP_MAJOR_VERSION > 7 || \
162-
(GINKGO_DPCPP_MAJOR_VERSION == 7 && GINKGO_DPCPP_MINOR_VERSION >= 1)
163-
164-
#if defined(__has_include_next)
165-
// GCC/clang support go through this path.
166-
#include_next <complex>
167-
#else
168-
// MSVC doesn't support "#include_next", so we take the same workaround in
169-
// stl_wrappers/complex.
170-
#include <../stl_wrappers/complex>
171-
#endif
172-
173-
#else
174-
175-
176-
#include <complex>
177-
178-
179-
#endif
180-
181-
182-
// we know the complex<float> now, so we implement those functions requiring
183-
// complex<float>
184-
namespace std {
185-
186-
187-
inline complex<sycl::half>::operator complex<float>() const noexcept
188-
{
189-
return std::complex<float>(static_cast<float>(real_),
190-
static_cast<float>(imag_));
191-
}
192-
193-
194-
template <typename T>
195-
inline complex<sycl::half>& complex<sycl::half>::operator*=(
196-
const complex<T>& val)
197-
{
198-
auto val_f = static_cast<std::complex<float>>(val);
199-
auto result_f = static_cast<std::complex<float>>(*this);
200-
result_f *= val_f;
201-
real_ = result_f.real();
202-
imag_ = result_f.imag();
203-
return *this;
204-
}
205-
206-
207-
template <typename T>
208-
inline complex<sycl::half>& complex<sycl::half>::operator/=(
209-
const complex<T>& val)
210-
{
211-
auto val_f = static_cast<std::complex<float>>(val);
212-
auto result_f = static_cast<std::complex<float>>(*this);
213-
result_f /= val_f;
214-
real_ = result_f.real();
215-
imag_ = result_f.imag();
216-
return *this;
217-
}
218-
219204

220-
} // namespace std
205+
} // namespace gko
221206

222207

223208
#endif // GKO_DPCPP_BASE_COMPLEX_HPP_

Diff for: dpcpp/base/math.hpp

+40-5
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,41 @@ struct basic_float_traits<sycl::half> {
3232
template <>
3333
struct is_complex_or_scalar_impl<sycl::half> : public std::true_type {};
3434

35+
template <typename ValueType>
36+
struct complex_helper {
37+
using type = std::complex<ValueType>;
38+
};
39+
40+
template <>
41+
struct complex_helper<sycl::half> {
42+
using type = gko::complex<sycl::half>;
43+
};
44+
45+
46+
template <typename T>
47+
struct type_size_impl<gko::complex<T>> {
48+
static constexpr auto value = sizeof(T) * byte_size;
49+
};
50+
51+
52+
template <typename T>
53+
struct remove_complex_impl<gko::complex<T>> {
54+
using type = T;
55+
};
56+
57+
58+
template <typename T>
59+
struct truncate_type_impl<gko::complex<T>> {
60+
using type =
61+
typename complex_helper<typename truncate_type_impl<T>::type>::type;
62+
};
63+
64+
template <typename T>
65+
struct is_complex_impl<gko::complex<T>> : public std::true_type {};
66+
67+
template <typename T>
68+
struct is_complex_or_scalar_impl<gko::complex<T>>
69+
: public is_complex_or_scalar_impl<T> {};
3570

3671
} // namespace detail
3772

@@ -41,7 +76,7 @@ bool __dpct_inline__ is_nan(const sycl::half& val)
4176
return std::isnan(static_cast<float>(val));
4277
}
4378

44-
bool __dpct_inline__ is_nan(const std::complex<sycl::half>& val)
79+
bool __dpct_inline__ is_nan(const gko::complex<sycl::half>& val)
4580
{
4681
return is_nan(val.real()) || is_nan(val.imag());
4782
}
@@ -52,7 +87,7 @@ sycl::half __dpct_inline__ abs(const sycl::half& val)
5287
return abs(static_cast<float>(val));
5388
}
5489

55-
sycl::half __dpct_inline__ abs(const std::complex<sycl::half>& val)
90+
sycl::half __dpct_inline__ abs(const gko::complex<sycl::half>& val)
5691
{
5792
return abs(static_cast<std::complex<float>>(val));
5893
}
@@ -62,8 +97,8 @@ sycl::half __dpct_inline__ sqrt(const sycl::half& val)
6297
return sqrt(static_cast<float>(val));
6398
}
6499

65-
std::complex<sycl::half> __dpct_inline__
66-
sqrt(const std::complex<sycl::half>& val)
100+
gko::complex<sycl::half> __dpct_inline__
101+
sqrt(const gko::complex<sycl::half>& val)
67102
{
68103
return sqrt(static_cast<std::complex<float>>(val));
69104
}
@@ -74,7 +109,7 @@ bool __dpct_inline__ is_finite(const sycl::half& value)
74109
return abs(value) < std::numeric_limits<sycl::half>::infinity();
75110
}
76111

77-
bool __dpct_inline__ is_finite(const std::complex<sycl::half>& value)
112+
bool __dpct_inline__ is_finite(const gko::complex<sycl::half>& value)
78113
{
79114
return is_finite(value.real()) && is_finite(value.imag());
80115
}

Diff for: dpcpp/base/types.hpp

+7
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#include <ginkgo/core/base/matrix_data.hpp>
1515
#include <ginkgo/core/base/types.hpp>
1616

17+
#include "dpcpp/base/complex.hpp"
18+
1719

1820
namespace gko {
1921
namespace kernels {
@@ -56,6 +58,11 @@ struct sycl_type_impl<std::complex<T>> {
5658
using type = std::complex<typename sycl_type_impl<T>::type>;
5759
};
5860

61+
template <>
62+
struct sycl_type_impl<std::complex<gko::half>> {
63+
using type = gko::complex<typename sycl_type_impl<gko::half>::type>;
64+
};
65+
5966
template <typename ValueType, typename IndexType>
6067
struct sycl_type_impl<matrix_data_entry<ValueType, IndexType>> {
6168
using type =

Diff for: dpcpp/preconditioner/batch_block_jacobi.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ class BlockJacobi final {
131131

132132
// reduction (it does not support complex<half>)
133133
if constexpr (std::is_same_v<value_type,
134-
std::complex<sycl::half>>) {
134+
gko::complex<sycl::half>>) {
135135
for (int i = sg_size / 2; i > 0; i /= 2) {
136136
sum += sycl::shift_group_left(sg, sum, i);
137137
}

0 commit comments

Comments
 (0)