Skip to content

Commit 51674d5

Browse files
refactor(expression): clean unnecessary code and fix the comparision between different extents types
1 parent 2abe215 commit 51674d5

File tree

3 files changed

+87
-164
lines changed

3 files changed

+87
-164
lines changed

include/boost/numeric/ublas/tensor/expression.hpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ protected :
145145
constexpr tensor_expression(tensor_expression&&) noexcept = default;
146146
explicit tensor_expression() = default;
147147

148-
/// @brief This the only way to access the protected move constructor of other expressions.
148+
/// @brief This is the only way to access the protected move constructor of other expressions.
149149
template<class, class> friend struct tensor_expression;
150150
};
151151

@@ -198,7 +198,7 @@ struct binary_tensor_expression
198198
*/
199199
constexpr binary_tensor_expression(binary_tensor_expression&& l) noexcept = default;
200200

201-
/// @brief This the only way to access the protected move constructor of other expressions.
201+
/// @brief This is the only way to access the protected move constructor of other expressions.
202202
template<class, class, class> friend struct unary_tensor_expression;
203203
template<class, class, class, class> friend struct binary_tensor_expression;
204204

@@ -271,7 +271,7 @@ struct unary_tensor_expression
271271
*/
272272
constexpr unary_tensor_expression(unary_tensor_expression&& l) noexcept = default;
273273

274-
/// @brief This the only way to access the protected move constructor of other expressions.
274+
/// @brief This is the only way to access the protected move constructor of other expressions.
275275
template<class, class, class> friend struct unary_tensor_expression;
276276
template<class, class, class, class> friend struct binary_tensor_expression;
277277

include/boost/numeric/ublas/tensor/expression_evaluation.hpp

+68-97
Original file line numberDiff line numberDiff line change
@@ -45,52 +45,61 @@ struct unary_tensor_expression;
4545

4646
namespace boost::numeric::ublas::detail {
4747

48-
template<class T, class E>
48+
template<typename T>
49+
struct is_tensor_type
50+
: std::false_type
51+
{};
52+
53+
template<typename E>
54+
struct is_tensor_type< tensor_core<E> >
55+
: std::true_type
56+
{};
57+
58+
template<class T>
59+
static constexpr bool is_tensor_type_v = is_tensor_type< std::decay_t<T> >::value;
60+
61+
template<typename T>
4962
struct has_tensor_types
50-
: std::integral_constant< bool, same_exp<T,E> >
63+
: is_tensor_type<T>
5164
{};
5265

53-
template<class T, class E>
54-
static constexpr bool has_tensor_types_v = has_tensor_types< std::decay_t<T>, std::decay_t<E> >::value;
66+
template<class T>
67+
static constexpr bool has_tensor_types_v = has_tensor_types< std::decay_t<T> >::value;
5568

5669
template<class T, class D>
57-
struct has_tensor_types<T, tensor_expression<T,D>>
58-
{
59-
static constexpr bool value =
60-
same_exp<T,D> ||
61-
has_tensor_types<T, std::decay_t<D> >::value;
62-
};
70+
struct has_tensor_types< tensor_expression<T,D> >
71+
: has_tensor_types< std::decay_t<D> >
72+
{};
6373

6474
template<class T, class EL, class ER, class OP>
65-
struct has_tensor_types<T, binary_tensor_expression<T,EL,ER,OP>>
66-
{
67-
static constexpr bool value =
68-
same_exp<T,EL> ||
69-
same_exp<T,ER> ||
70-
has_tensor_types<T, std::decay_t<EL> >::value ||
71-
has_tensor_types<T, std::decay_t<ER> >::value;
72-
};
75+
struct has_tensor_types< binary_tensor_expression<T,EL,ER,OP> >
76+
: std::integral_constant< bool, has_tensor_types_v<EL> || has_tensor_types_v<ER> >
77+
{};
7378

7479
template<class T, class E, class OP>
75-
struct has_tensor_types<T, unary_tensor_expression<T,E,OP>>
76-
{
77-
static constexpr bool value =
78-
same_exp<T,E> ||
79-
has_tensor_types<T, std::decay_t<E> >::value;
80-
};
80+
struct has_tensor_types< unary_tensor_expression<T,E,OP> >
81+
: has_tensor_types< std::decay_t<E> >
82+
{};
8183

8284
} // namespace boost::numeric::ublas::detail
8385

8486

8587
namespace boost::numeric::ublas::detail
8688
{
8789

90+
91+
// TODO: remove this place holder for the old ublas expression after we remove the
92+
// support for them.
93+
template<class E>
94+
[[nodiscard]]
95+
constexpr auto& retrieve_extents([[maybe_unused]] ublas_expression<E> const& /*unused*/) noexcept;
96+
8897
/** @brief Retrieves extents of the tensor_core
8998
*
9099
*/
91100
template<class TensorEngine>
92101
[[nodiscard]]
93-
constexpr auto& retrieve_extents(tensor_core<TensorEngine> const& t)
102+
constexpr auto& retrieve_extents(tensor_core<TensorEngine> const& t) noexcept
94103
{
95104
return t.extents();
96105
}
@@ -103,17 +112,14 @@ constexpr auto& retrieve_extents(tensor_core<TensorEngine> const& t)
103112
*/
104113
template<class T, class D>
105114
[[nodiscard]]
106-
constexpr auto& retrieve_extents(tensor_expression<T,D> const& expr)
115+
constexpr auto& retrieve_extents(tensor_expression<T,D> const& expr) noexcept
107116
{
108-
static_assert(has_tensor_types_v<T,tensor_expression<T,D>>,
117+
static_assert(has_tensor_types_v<tensor_expression<T,D>>,
109118
"Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
110119

111120
auto const& cast_expr = expr();
112-
113-
if constexpr ( same_exp<T,D> )
114-
return cast_expr.extents();
115-
else
116-
return retrieve_extents(cast_expr);
121+
122+
return retrieve_extents(cast_expr);
117123
}
118124

119125
// Disable warning for unreachable code for MSVC compiler
@@ -129,24 +135,24 @@ constexpr auto& retrieve_extents(tensor_expression<T,D> const& expr)
129135
*/
130136
template<class T, class EL, class ER, class OP>
131137
[[nodiscard]]
132-
constexpr auto& retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const& expr)
138+
constexpr auto& retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const& expr) noexcept
133139
{
134-
static_assert(has_tensor_types_v<T,binary_tensor_expression<T,EL,ER,OP>>,
140+
static_assert(has_tensor_types_v<binary_tensor_expression<T,EL,ER,OP>>,
135141
"Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
136142

137143
auto const& lexpr = expr.left_expr();
138144
auto const& rexpr = expr.right_expr();
139145

140-
if constexpr ( same_exp<T,EL> )
146+
if constexpr ( is_tensor_type_v<EL> )
141147
return lexpr.extents();
142148

143-
else if constexpr ( same_exp<T,ER> )
149+
else if constexpr ( is_tensor_type_v<ER> )
144150
return rexpr.extents();
145151

146-
else if constexpr ( has_tensor_types_v<T,EL> )
152+
else if constexpr ( has_tensor_types_v<EL>)
147153
return retrieve_extents(lexpr);
148-
149-
else if constexpr ( has_tensor_types_v<T,ER> )
154+
155+
else if constexpr ( has_tensor_types_v<ER>)
150156
return retrieve_extents(rexpr);
151157
}
152158

@@ -162,19 +168,15 @@ constexpr auto& retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const& exp
162168
*/
163169
template<class T, class E, class OP>
164170
[[nodiscard]]
165-
constexpr auto& retrieve_extents(unary_tensor_expression<T,E,OP> const& expr)
171+
constexpr auto& retrieve_extents(unary_tensor_expression<T,E,OP> const& expr) noexcept
166172
{
167173

168-
static_assert(has_tensor_types_v<T,unary_tensor_expression<T,E,OP>>,
174+
static_assert(has_tensor_types_v<unary_tensor_expression<T,E,OP>>,
169175
"Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
170176

171177
auto const& uexpr = expr.expr();
172178

173-
if constexpr ( same_exp<T,E> )
174-
return uexpr.extents();
175-
176-
else if constexpr ( has_tensor_types_v<T,E> )
177-
return retrieve_extents(uexpr);
179+
return retrieve_extents(uexpr);
178180
}
179181

180182
} // namespace boost::numeric::ublas::detail
@@ -184,91 +186,60 @@ constexpr auto& retrieve_extents(unary_tensor_expression<T,E,OP> const& expr)
184186

185187
namespace boost::numeric::ublas::detail {
186188

189+
// TODO: remove this place holder for the old ublas expression after we remove the
190+
// support for them.
191+
template<class E, std::size_t ... es>
192+
[[nodiscard]] inline
193+
constexpr auto all_extents_equal([[maybe_unused]] ublas_expression<E> const& /*unused*/, [[maybe_unused]] extents<es...> const& /*unused*/) noexcept
194+
{
195+
return true;
196+
}
197+
187198
template<class EN, std::size_t ... es>
188199
[[nodiscard]] inline
189-
constexpr auto all_extents_equal(tensor_core<EN> const& t, extents<es...> const& e)
200+
constexpr auto all_extents_equal(tensor_core<EN> const& t, extents<es...> const& e) noexcept
190201
{
191202
return ::operator==(e,t.extents());
192203
}
193204

194205
template<class T, class D, std::size_t ... es>
195206
[[nodiscard]]
196-
constexpr auto all_extents_equal(tensor_expression<T,D> const& expr, extents<es...> const& e)
207+
constexpr auto all_extents_equal(tensor_expression<T,D> const& expr, extents<es...> const& e) noexcept
197208
{
198209

199-
static_assert(has_tensor_types_v<T,tensor_expression<T,D>>,
210+
static_assert(has_tensor_types_v<tensor_expression<T,D>>,
200211
"Error in boost::numeric::ublas::all_extents_equal: Expression to evaluate should contain tensors.");
201212

202213
auto const& cast_expr = expr();
203214

204-
using ::operator==;
205-
using ::operator!=;
206-
207-
if constexpr ( same_exp<T,D> )
208-
if( e != cast_expr.extents() )
209-
return false;
210-
211-
if constexpr ( has_tensor_types_v<T,D> )
212-
if ( !all_extents_equal(cast_expr, e))
213-
return false;
214-
215-
return true;
216-
215+
return all_extents_equal(cast_expr, e);
217216
}
218217

219218
template<class T, class EL, class ER, class OP, std::size_t... es>
220219
[[nodiscard]]
221-
constexpr auto all_extents_equal(binary_tensor_expression<T,EL,ER,OP> const& expr, extents<es...> const& e)
220+
constexpr auto all_extents_equal(binary_tensor_expression<T,EL,ER,OP> const& expr, extents<es...> const& e) noexcept
222221
{
223-
static_assert(has_tensor_types_v<T,binary_tensor_expression<T,EL,ER,OP>>,
222+
static_assert(has_tensor_types_v<binary_tensor_expression<T,EL,ER,OP>>,
224223
"Error in boost::numeric::ublas::all_extents_equal: Expression to evaluate should contain tensors.");
225224

226-
using ::operator==;
227-
using ::operator!=;
228-
229225
auto const& lexpr = expr.left_expr();
230226
auto const& rexpr = expr.right_expr();
231227

232-
if constexpr ( same_exp<T,EL> )
233-
if(e != lexpr.extents())
234-
return false;
235-
236-
if constexpr ( same_exp<T,ER> )
237-
if(e != rexpr.extents())
238-
return false;
239-
240-
if constexpr ( has_tensor_types_v<T,EL> )
241-
if(!all_extents_equal(lexpr, e))
242-
return false;
243-
244-
if constexpr ( has_tensor_types_v<T,ER> )
245-
if(!all_extents_equal(rexpr, e))
246-
return false;
247-
248-
return true;
228+
return all_extents_equal(lexpr, e) &&
229+
all_extents_equal(rexpr, e) ;
249230
}
250231

251232

252233
template<class T, class E, class OP, std::size_t... es>
253234
[[nodiscard]]
254-
constexpr auto all_extents_equal(unary_tensor_expression<T,E,OP> const& expr, extents<es...> const& e)
235+
constexpr auto all_extents_equal(unary_tensor_expression<T,E,OP> const& expr, extents<es...> const& e) noexcept
255236
{
256-
static_assert(has_tensor_types_v<T,unary_tensor_expression<T,E,OP>>,
237+
static_assert(has_tensor_types_v<unary_tensor_expression<T,E,OP>>,
257238
"Error in boost::numeric::ublas::all_extents_equal: Expression to evaluate should contain tensors.");
258239

259-
using ::operator==;
260-
261240
auto const& uexpr = expr.expr();
262241

263-
if constexpr ( same_exp<T,E> )
264-
if(e != uexpr.extents())
265-
return false;
266-
267-
if constexpr ( has_tensor_types_v<T,E> )
268-
if(!all_extents_equal(uexpr, e))
269-
return false;
270-
271-
return true;
242+
return all_extents_equal(uexpr, e);
272243
}
273244

274245
} // namespace boost::numeric::ublas::detail

0 commit comments

Comments
 (0)