5
5
#ifndef GKO_DPCPP_BASE_COMPLEX_HPP_
6
6
#define GKO_DPCPP_BASE_COMPLEX_HPP_
7
7
8
+ #include < complex>
9
+
8
10
#include < sycl/half_type.hpp>
9
11
10
12
#include < ginkgo/config.hpp>
11
13
12
- // this file is to workaround for the intel sycl complex different loading.
13
- // intel sycl provides complex and the corresponding searching path. When users
14
- // load complex with -fsycl, the compiler will load intel's <complex> header
15
- // first and then load usual <complex> header. However, it implicitly
16
- // instantiates and uses std::complex<sycl::half>, so we need to provide the
17
- // implementation before that. In ginkgo, we will definitely load <complex> in
18
- // the public interface, which is before sycl backend, so we have no normal way
19
- // to provide the std::complex<sycl::half> implementation in sycl.
20
- // We apply the same trick to load this file first and then load their header
21
- // later. We will also configure this file as <complex> and provide the search
22
- // path in sycl module.
23
- // They start to do this from LIBSYCL 7.1.0.
24
-
25
- namespace std {
14
+
15
+ namespace gko {
26
16
27
17
template <typename >
28
18
class complex ;
29
19
30
- // implement std::complex<sycl::half> before knowing std::complex<float>
20
+
31
21
template <>
32
22
class complex <sycl::half> {
33
23
public:
@@ -53,7 +43,7 @@ class complex<sycl::half> {
53
43
{}
54
44
55
45
template <typename T, typename = std::enable_if_t <std::is_scalar<T>::value>>
56
- complex (const complex<T>& other)
46
+ complex (const std:: complex<T>& other)
57
47
: real_(static_cast <value_type>(other.real())),
58
48
imag_(static_cast <value_type>(other.imag()))
59
49
{}
@@ -62,7 +52,18 @@ class complex<sycl::half> {
62
52
63
53
value_type imag () const noexcept { return imag_; }
64
54
65
- inline operator std::complex<float >() const noexcept ;
55
+ operator std::complex<float >() const noexcept
56
+ {
57
+ return std::complex<float >(static_cast <float >(real_),
58
+ static_cast <float >(imag_));
59
+ }
60
+
61
+ bool operator !=(const complex& r) const { return !this ->operator ==(r); }
62
+
63
+ bool operator ==(const complex& r) const
64
+ {
65
+ return real_ == r.real () && imag_ == r.imag ();
66
+ }
66
67
67
68
template <typename V>
68
69
complex& operator =(const V& val)
@@ -107,37 +108,83 @@ class complex<sycl::half> {
107
108
}
108
109
109
110
template <typename T>
110
- complex& operator +=(const complex<T>& val)
111
+ complex& operator +=(const std:: complex<T>& val)
111
112
{
112
113
real_ += val.real ();
113
114
imag_ += val.imag ();
114
115
return *this ;
115
116
}
116
117
117
118
template <typename T>
118
- complex& operator -=(const complex<T>& val)
119
+ complex& operator -=(const std:: complex<T>& val)
119
120
{
120
121
real_ -= val.real ();
121
122
imag_ -= val.imag ();
122
123
return *this ;
123
124
}
124
125
125
126
template <typename T>
126
- inline complex& operator *=(const complex<T>& val);
127
+ complex& operator *=(const std::complex<T>& val)
128
+ {
129
+ auto val_f = static_cast <std::complex<float >>(val);
130
+ auto result_f = static_cast <std::complex<float >>(*this );
131
+ result_f *= val_f;
132
+ real_ = result_f.real ();
133
+ imag_ = result_f.imag ();
134
+ return *this ;
135
+ }
127
136
128
137
template <typename T>
129
- inline complex& operator /=(const complex<T>& val);
138
+ complex& operator /=(const std::complex<T>& val)
139
+ {
140
+ auto val_f = static_cast <std::complex<float >>(val);
141
+ auto result_f = static_cast <std::complex<float >>(*this );
142
+ result_f /= val_f;
143
+ real_ = result_f.real ();
144
+ imag_ = result_f.imag ();
145
+ return *this ;
146
+ }
147
+
148
+ complex& operator +=(const complex& val)
149
+ {
150
+ real_ += val.real ();
151
+ imag_ += val.imag ();
152
+ return *this ;
153
+ }
154
+
155
+ complex& operator -=(const complex& val)
156
+ {
157
+ real_ -= val.real ();
158
+ imag_ -= val.imag ();
159
+ return *this ;
160
+ }
130
161
131
- // It's for MacOS.
132
- // TODO: check whether mac compiler always use complex version even when real
133
- // half
134
- #define COMPLEX_HALF_OPERATOR (_op, _opeq ) \
135
- friend complex<sycl::half> operator _op (const complex<sycl::half> lhf, \
136
- const complex<sycl::half> rhf) \
137
- { \
138
- auto a = lhf; \
139
- a _opeq rhf; \
140
- return a; \
162
+ complex& operator *=(const complex& val)
163
+ {
164
+ auto val_f = static_cast <std::complex<float >>(val);
165
+ auto result_f = static_cast <std::complex<float >>(*this );
166
+ result_f *= val_f;
167
+ real_ = result_f.real ();
168
+ imag_ = result_f.imag ();
169
+ return *this ;
170
+ }
171
+
172
+ complex& operator /=(const complex& val)
173
+ {
174
+ auto val_f = static_cast <std::complex<float >>(val);
175
+ auto result_f = static_cast <std::complex<float >>(*this );
176
+ result_f /= val_f;
177
+ real_ = result_f.real ();
178
+ imag_ = result_f.imag ();
179
+ return *this ;
180
+ }
181
+
182
+ #define COMPLEX_HALF_OPERATOR (_op, _opeq ) \
183
+ friend complex operator _op (const complex& lhf, const complex& rhf) \
184
+ { \
185
+ auto a = lhf; \
186
+ a _opeq rhf; \
187
+ return a; \
141
188
}
142
189
143
190
COMPLEX_HALF_OPERATOR (+, +=)
@@ -147,77 +194,15 @@ class complex<sycl::half> {
147
194
148
195
#undef COMPLEX_HALF_OPERATOR
149
196
197
+ complex operator -() const { return complex (-real_, -imag_); }
198
+
150
199
private:
151
200
value_type real_;
152
201
value_type imag_;
153
202
};
154
203
155
- } // namespace std
156
-
157
-
158
- // after providing std::complex<sycl::half>, we can load their <complex> to
159
- // complete the header chain.
160
-
161
- #if GINKGO_DPCPP_MAJOR_VERSION > 7 || \
162
- (GINKGO_DPCPP_MAJOR_VERSION == 7 && GINKGO_DPCPP_MINOR_VERSION >= 1 )
163
-
164
- #if defined(__has_include_next)
165
- // GCC/clang support go through this path.
166
- #include_next <complex>
167
- #else
168
- // MSVC doesn't support "#include_next", so we take the same workaround in
169
- // stl_wrappers/complex.
170
- #include < ../stl_wrappers/complex>
171
- #endif
172
-
173
- #else
174
-
175
-
176
- #include < complex>
177
-
178
-
179
- #endif
180
-
181
-
182
- // we know the complex<float> now, so we implement those functions requiring
183
- // complex<float>
184
- namespace std {
185
-
186
-
187
- inline complex<sycl::half>::operator complex<float >() const noexcept
188
- {
189
- return std::complex<float >(static_cast <float >(real_),
190
- static_cast <float >(imag_));
191
- }
192
-
193
-
194
- template <typename T>
195
- inline complex<sycl::half>& complex<sycl::half>::operator *=(
196
- const complex<T>& val)
197
- {
198
- auto val_f = static_cast <std::complex<float >>(val);
199
- auto result_f = static_cast <std::complex<float >>(*this );
200
- result_f *= val_f;
201
- real_ = result_f.real ();
202
- imag_ = result_f.imag ();
203
- return *this ;
204
- }
205
-
206
-
207
- template <typename T>
208
- inline complex<sycl::half>& complex<sycl::half>::operator /=(
209
- const complex<T>& val)
210
- {
211
- auto val_f = static_cast <std::complex<float >>(val);
212
- auto result_f = static_cast <std::complex<float >>(*this );
213
- result_f /= val_f;
214
- real_ = result_f.real ();
215
- imag_ = result_f.imag ();
216
- return *this ;
217
- }
218
-
219
204
220
- } // namespace std
205
+ } // namespace gko
221
206
222
207
223
208
#endif // GKO_DPCPP_BASE_COMPLEX_HPP_
0 commit comments