@@ -45,52 +45,61 @@ struct unary_tensor_expression;
45
45
46
46
namespace boost ::numeric::ublas::detail {
47
47
48
- template <class T , class E >
48
+ template <typename T>
49
+ struct is_tensor_type
50
+ : std::false_type
51
+ {};
52
+
53
+ template <typename E>
54
+ struct is_tensor_type < tensor_core<E> >
55
+ : std::true_type
56
+ {};
57
+
58
+ template <class T >
59
+ static constexpr bool is_tensor_type_v = is_tensor_type< std::decay_t <T> >::value;
60
+
61
+ template <typename T>
49
62
struct has_tensor_types
50
- : std::integral_constant< bool , same_exp<T,E> >
63
+ : is_tensor_type<T >
51
64
{};
52
65
53
- template <class T , class E >
54
- static constexpr bool has_tensor_types_v = has_tensor_types< std::decay_t <T>, std:: decay_t <E> >::value;
66
+ template <class T >
67
+ static constexpr bool has_tensor_types_v = has_tensor_types< std::decay_t <T> >::value;
55
68
56
69
template <class T , class D >
57
- struct has_tensor_types <T, tensor_expression<T,D>>
58
- {
59
- static constexpr bool value =
60
- same_exp<T,D> ||
61
- has_tensor_types<T, std::decay_t <D> >::value;
62
- };
70
+ struct has_tensor_types < tensor_expression<T,D> >
71
+ : has_tensor_types< std::decay_t <D> >
72
+ {};
63
73
64
74
template <class T , class EL , class ER , class OP >
65
- struct has_tensor_types <T, binary_tensor_expression<T,EL,ER,OP>>
66
- {
67
- static constexpr bool value =
68
- same_exp<T,EL> ||
69
- same_exp<T,ER> ||
70
- has_tensor_types<T, std::decay_t <EL> >::value ||
71
- has_tensor_types<T, std::decay_t <ER> >::value;
72
- };
75
+ struct has_tensor_types < binary_tensor_expression<T,EL,ER,OP> >
76
+ : std::integral_constant< bool , has_tensor_types_v<EL> || has_tensor_types_v<ER> >
77
+ {};
73
78
74
79
template <class T , class E , class OP >
75
- struct has_tensor_types <T, unary_tensor_expression<T,E,OP>>
76
- {
77
- static constexpr bool value =
78
- same_exp<T,E> ||
79
- has_tensor_types<T, std::decay_t <E> >::value;
80
- };
80
+ struct has_tensor_types < unary_tensor_expression<T,E,OP> >
81
+ : has_tensor_types< std::decay_t <E> >
82
+ {};
81
83
82
84
} // namespace boost::numeric::ublas::detail
83
85
84
86
85
87
namespace boost ::numeric::ublas::detail
86
88
{
87
89
90
+
91
+ // TODO: remove this place holder for the old ublas expression after we remove the
92
+ // support for them.
93
+ template <class E >
94
+ [[nodiscard]]
95
+ constexpr auto & retrieve_extents ([[maybe_unused]] ublas_expression<E> const & /* unused*/ ) noexcept ;
96
+
88
97
/* * @brief Retrieves extents of the tensor_core
89
98
*
90
99
*/
91
100
template <class TensorEngine >
92
101
[[nodiscard]]
93
- constexpr auto & retrieve_extents (tensor_core<TensorEngine> const & t)
102
+ constexpr auto & retrieve_extents (tensor_core<TensorEngine> const & t) noexcept
94
103
{
95
104
return t.extents ();
96
105
}
@@ -103,17 +112,14 @@ constexpr auto& retrieve_extents(tensor_core<TensorEngine> const& t)
103
112
*/
104
113
template <class T , class D >
105
114
[[nodiscard]]
106
- constexpr auto & retrieve_extents (tensor_expression<T,D> const & expr)
115
+ constexpr auto & retrieve_extents (tensor_expression<T,D> const & expr) noexcept
107
116
{
108
- static_assert (has_tensor_types_v<T, tensor_expression<T,D>>,
117
+ static_assert (has_tensor_types_v<tensor_expression<T,D>>,
109
118
" Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors." );
110
119
111
120
auto const & cast_expr = expr ();
112
-
113
- if constexpr ( same_exp<T,D> )
114
- return cast_expr.extents ();
115
- else
116
- return retrieve_extents (cast_expr);
121
+
122
+ return retrieve_extents (cast_expr);
117
123
}
118
124
119
125
// Disable warning for unreachable code for MSVC compiler
@@ -129,24 +135,24 @@ constexpr auto& retrieve_extents(tensor_expression<T,D> const& expr)
129
135
*/
130
136
template <class T , class EL , class ER , class OP >
131
137
[[nodiscard]]
132
- constexpr auto & retrieve_extents (binary_tensor_expression<T,EL,ER,OP> const & expr)
138
+ constexpr auto & retrieve_extents (binary_tensor_expression<T,EL,ER,OP> const & expr) noexcept
133
139
{
134
- static_assert (has_tensor_types_v<T, binary_tensor_expression<T,EL,ER,OP>>,
140
+ static_assert (has_tensor_types_v<binary_tensor_expression<T,EL,ER,OP>>,
135
141
" Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors." );
136
142
137
143
auto const & lexpr = expr.left_expr ();
138
144
auto const & rexpr = expr.right_expr ();
139
145
140
- if constexpr ( same_exp<T, EL> )
146
+ if constexpr ( is_tensor_type_v< EL> )
141
147
return lexpr.extents ();
142
148
143
- else if constexpr ( same_exp<T, ER> )
149
+ else if constexpr ( is_tensor_type_v< ER> )
144
150
return rexpr.extents ();
145
151
146
- else if constexpr ( has_tensor_types_v<T, EL> )
152
+ else if constexpr ( has_tensor_types_v<EL>)
147
153
return retrieve_extents (lexpr);
148
-
149
- else if constexpr ( has_tensor_types_v<T, ER> )
154
+
155
+ else if constexpr ( has_tensor_types_v<ER>)
150
156
return retrieve_extents (rexpr);
151
157
}
152
158
@@ -162,19 +168,15 @@ constexpr auto& retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const& exp
162
168
*/
163
169
template <class T , class E , class OP >
164
170
[[nodiscard]]
165
- constexpr auto & retrieve_extents (unary_tensor_expression<T,E,OP> const & expr)
171
+ constexpr auto & retrieve_extents (unary_tensor_expression<T,E,OP> const & expr) noexcept
166
172
{
167
173
168
- static_assert (has_tensor_types_v<T, unary_tensor_expression<T,E,OP>>,
174
+ static_assert (has_tensor_types_v<unary_tensor_expression<T,E,OP>>,
169
175
" Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors." );
170
176
171
177
auto const & uexpr = expr.expr ();
172
178
173
- if constexpr ( same_exp<T,E> )
174
- return uexpr.extents ();
175
-
176
- else if constexpr ( has_tensor_types_v<T,E> )
177
- return retrieve_extents (uexpr);
179
+ return retrieve_extents (uexpr);
178
180
}
179
181
180
182
} // namespace boost::numeric::ublas::detail
@@ -184,91 +186,60 @@ constexpr auto& retrieve_extents(unary_tensor_expression<T,E,OP> const& expr)
184
186
185
187
namespace boost ::numeric::ublas::detail {
186
188
189
+ // TODO: remove this place holder for the old ublas expression after we remove the
190
+ // support for them.
191
+ template <class E , std::size_t ... es>
192
+ [[nodiscard]] inline
193
+ constexpr auto all_extents_equal ([[maybe_unused]] ublas_expression<E> const & /* unused*/ , [[maybe_unused]] extents<es...> const & /* unused*/ ) noexcept
194
+ {
195
+ return true ;
196
+ }
197
+
187
198
template <class EN , std::size_t ... es>
188
199
[[nodiscard]] inline
189
- constexpr auto all_extents_equal (tensor_core<EN> const & t, extents<es...> const & e)
200
+ constexpr auto all_extents_equal (tensor_core<EN> const & t, extents<es...> const & e) noexcept
190
201
{
191
202
return ::operator ==(e,t.extents ());
192
203
}
193
204
194
205
template <class T , class D , std::size_t ... es>
195
206
[[nodiscard]]
196
- constexpr auto all_extents_equal (tensor_expression<T,D> const & expr, extents<es...> const & e)
207
+ constexpr auto all_extents_equal (tensor_expression<T,D> const & expr, extents<es...> const & e) noexcept
197
208
{
198
209
199
- static_assert (has_tensor_types_v<T, tensor_expression<T,D>>,
210
+ static_assert (has_tensor_types_v<tensor_expression<T,D>>,
200
211
" Error in boost::numeric::ublas::all_extents_equal: Expression to evaluate should contain tensors." );
201
212
202
213
auto const & cast_expr = expr ();
203
214
204
- using ::operator ==;
205
- using ::operator !=;
206
-
207
- if constexpr ( same_exp<T,D> )
208
- if ( e != cast_expr.extents () )
209
- return false ;
210
-
211
- if constexpr ( has_tensor_types_v<T,D> )
212
- if ( !all_extents_equal (cast_expr, e))
213
- return false ;
214
-
215
- return true ;
216
-
215
+ return all_extents_equal (cast_expr, e);
217
216
}
218
217
219
218
template <class T , class EL , class ER , class OP , std::size_t ... es>
220
219
[[nodiscard]]
221
- constexpr auto all_extents_equal (binary_tensor_expression<T,EL,ER,OP> const & expr, extents<es...> const & e)
220
+ constexpr auto all_extents_equal (binary_tensor_expression<T,EL,ER,OP> const & expr, extents<es...> const & e) noexcept
222
221
{
223
- static_assert (has_tensor_types_v<T, binary_tensor_expression<T,EL,ER,OP>>,
222
+ static_assert (has_tensor_types_v<binary_tensor_expression<T,EL,ER,OP>>,
224
223
" Error in boost::numeric::ublas::all_extents_equal: Expression to evaluate should contain tensors." );
225
224
226
- using ::operator ==;
227
- using ::operator !=;
228
-
229
225
auto const & lexpr = expr.left_expr ();
230
226
auto const & rexpr = expr.right_expr ();
231
227
232
- if constexpr ( same_exp<T,EL> )
233
- if (e != lexpr.extents ())
234
- return false ;
235
-
236
- if constexpr ( same_exp<T,ER> )
237
- if (e != rexpr.extents ())
238
- return false ;
239
-
240
- if constexpr ( has_tensor_types_v<T,EL> )
241
- if (!all_extents_equal (lexpr, e))
242
- return false ;
243
-
244
- if constexpr ( has_tensor_types_v<T,ER> )
245
- if (!all_extents_equal (rexpr, e))
246
- return false ;
247
-
248
- return true ;
228
+ return all_extents_equal (lexpr, e) &&
229
+ all_extents_equal (rexpr, e) ;
249
230
}
250
231
251
232
252
233
template <class T , class E , class OP , std::size_t ... es>
253
234
[[nodiscard]]
254
- constexpr auto all_extents_equal (unary_tensor_expression<T,E,OP> const & expr, extents<es...> const & e)
235
+ constexpr auto all_extents_equal (unary_tensor_expression<T,E,OP> const & expr, extents<es...> const & e) noexcept
255
236
{
256
- static_assert (has_tensor_types_v<T, unary_tensor_expression<T,E,OP>>,
237
+ static_assert (has_tensor_types_v<unary_tensor_expression<T,E,OP>>,
257
238
" Error in boost::numeric::ublas::all_extents_equal: Expression to evaluate should contain tensors." );
258
239
259
- using ::operator ==;
260
-
261
240
auto const & uexpr = expr.expr ();
262
241
263
- if constexpr ( same_exp<T,E> )
264
- if (e != uexpr.extents ())
265
- return false ;
266
-
267
- if constexpr ( has_tensor_types_v<T,E> )
268
- if (!all_extents_equal (uexpr, e))
269
- return false ;
270
-
271
- return true ;
242
+ return all_extents_equal (uexpr, e);
272
243
}
273
244
274
245
} // namespace boost::numeric::ublas::detail
0 commit comments