Skip to content

Commit 40e2a30

Browse files
authored
Merge pull request #2636 from lyndond/l1_l2_norms
Feature: L1 and L2 norms
2 parents 83dbd82 + 535f0cb commit 40e2a30

File tree

13 files changed

+428
-0
lines changed

13 files changed

+428
-0
lines changed

stan/math/fwd/fun.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@
8787
#include <stan/math/fwd/fun/multiply_log.hpp>
8888
#include <stan/math/fwd/fun/multiply_lower_tri_self_transpose.hpp>
8989
#include <stan/math/fwd/fun/norm.hpp>
90+
#include <stan/math/fwd/fun/norm1.hpp>
91+
#include <stan/math/fwd/fun/norm2.hpp>
9092
#include <stan/math/fwd/fun/owens_t.hpp>
9193
#include <stan/math/fwd/fun/Phi.hpp>
9294
#include <stan/math/fwd/fun/Phi_approx.hpp>

stan/math/fwd/fun/norm1.hpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#ifndef STAN_MATH_FWD_FUN_NORM1_HPP
2+
#define STAN_MATH_FWD_FUN_NORM1_HPP
3+
4+
#include <stan/math/fwd/meta.hpp>
5+
#include <stan/math/fwd/core.hpp>
6+
#include <stan/math/prim/meta.hpp>
7+
#include <stan/math/prim/fun/Eigen.hpp>
8+
#include <stan/math/prim/fun/constants.hpp>
9+
#include <stan/math/prim/fun/norm1.hpp>
10+
#include <stan/math/prim/fun/sign.hpp>
11+
#include <stan/math/prim/fun/to_ref.hpp>
12+
13+
namespace stan {
14+
namespace math {
15+
16+
/**
17+
* Compute the L1 norm of the specified vector of values.
18+
*
19+
* @tparam T Type of input vector.
20+
* @param[in] x Vector of specified values.
21+
* @return L1 norm of x.
22+
*/
23+
template <typename Container,
24+
require_container_st<is_fvar, Container>* = nullptr>
25+
inline auto norm1(const Container& x) {
26+
return apply_vector_unary<ref_type_t<Container>>::reduce(
27+
to_ref(x), [&](const auto& v) {
28+
using T_fvar_inner = typename value_type_t<decltype(v)>::Scalar;
29+
return fvar<T_fvar_inner>(norm1(v.val()),
30+
v.d().cwiseProduct(sign(v.val())).sum());
31+
});
32+
}
33+
34+
} // namespace math
35+
} // namespace stan
36+
#endif

stan/math/fwd/fun/norm2.hpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#ifndef STAN_MATH_FWD_FUN_NORM2_HPP
2+
#define STAN_MATH_FWD_FUN_NORM2_HPP
3+
4+
#include <stan/math/fwd/meta.hpp>
5+
#include <stan/math/fwd/core.hpp>
6+
#include <stan/math/prim/meta.hpp>
7+
#include <stan/math/prim/fun/Eigen.hpp>
8+
#include <stan/math/prim/fun/norm2.hpp>
9+
#include <stan/math/prim/fun/to_ref.hpp>
10+
11+
namespace stan {
12+
namespace math {
13+
14+
/**
15+
* Compute the L2 norm of the specified vector of values.
16+
*
17+
* @tparam T Type of input vector.
18+
* @param[in] x Vector of specified values.
19+
* @return L2 norm of x.
20+
*/
21+
template <typename Container,
22+
require_container_st<is_fvar, Container>* = nullptr>
23+
inline auto norm2(const Container& x) {
24+
return apply_vector_unary<ref_type_t<Container>>::reduce(
25+
to_ref(x), [&](const auto& v) {
26+
using T_fvar_inner = typename value_type_t<decltype(v)>::Scalar;
27+
T_fvar_inner res = norm2(v.val());
28+
return fvar<T_fvar_inner>(res,
29+
v.d().cwiseProduct((v.val() / res)).sum());
30+
});
31+
}
32+
33+
} // namespace math
34+
} // namespace stan
35+
#endif

stan/math/prim/fun.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,8 @@
227227
#include <stan/math/prim/fun/multiply_log.hpp>
228228
#include <stan/math/prim/fun/multiply_lower_tri_self_transpose.hpp>
229229
#include <stan/math/prim/fun/norm.hpp>
230+
#include <stan/math/prim/fun/norm1.hpp>
231+
#include <stan/math/prim/fun/norm2.hpp>
230232
#include <stan/math/prim/fun/num_elements.hpp>
231233
#include <stan/math/prim/fun/offset_multiplier_constrain.hpp>
232234
#include <stan/math/prim/fun/offset_multiplier_free.hpp>

stan/math/prim/fun/norm1.hpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#ifndef STAN_MATH_PRIM_FUN_NORM1_HPP
2+
#define STAN_MATH_PRIM_FUN_NORM1_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/err.hpp>
6+
#include <stan/math/prim/fun/Eigen.hpp>
7+
8+
namespace stan {
9+
namespace math {
10+
11+
/**
12+
* Returns L1 norm of a vector. For vectors that equals the
13+
* sum of magnitudes of its individual elements.
14+
*
15+
* @tparam T type of the vector (must be derived from \c Eigen::MatrixBase)
16+
* @param v Vector.
17+
* @return L1 norm of v.
18+
*/
19+
template <typename Container, require_st_arithmetic<Container>* = nullptr,
20+
require_container_t<Container>* = nullptr>
21+
inline auto norm1(const Container& x) {
22+
return apply_vector_unary<ref_type_t<Container>>::reduce(
23+
to_ref(x), [](const auto& v) { return v.template lpNorm<1>(); });
24+
}
25+
26+
} // namespace math
27+
} // namespace stan
28+
29+
#endif

stan/math/prim/fun/norm2.hpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#ifndef STAN_MATH_PRIM_FUN_NORM2_HPP
2+
#define STAN_MATH_PRIM_FUN_NORM2_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/err.hpp>
6+
#include <stan/math/prim/fun/Eigen.hpp>
7+
8+
namespace stan {
9+
namespace math {
10+
11+
/**
12+
* Returns L2 norm of a vector. For vectors that equals the square-root of the
13+
* sum of squares of the elements.
14+
*
15+
* @tparam T type of the vector (must be derived from \c Eigen::MatrixBase)
16+
* @param v Vector.
17+
* @return L2 norm of v.
18+
*/
19+
template <typename Container, require_st_arithmetic<Container>* = nullptr,
20+
require_container_t<Container>* = nullptr>
21+
inline auto norm2(const Container& x) {
22+
return apply_vector_unary<ref_type_t<Container>>::reduce(
23+
to_ref(x), [](const auto& v) { return v.template lpNorm<2>(); });
24+
}
25+
26+
} // namespace math
27+
} // namespace stan
28+
29+
#endif

stan/math/rev/fun.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@
130130
#include <stan/math/rev/fun/multiply_log.hpp>
131131
#include <stan/math/rev/fun/multiply_lower_tri_self_transpose.hpp>
132132
#include <stan/math/rev/fun/norm.hpp>
133+
#include <stan/math/rev/fun/norm1.hpp>
134+
#include <stan/math/rev/fun/norm2.hpp>
133135
#include <stan/math/rev/fun/ordered_constrain.hpp>
134136
#include <stan/math/rev/fun/owens_t.hpp>
135137
#include <stan/math/rev/fun/polar.hpp>

stan/math/rev/fun/norm1.hpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#ifndef STAN_MATH_REV_FUN_NORM1_HPP
2+
#define STAN_MATH_REV_FUN_NORM1_HPP
3+
4+
#include <stan/math/rev/meta.hpp>
5+
#include <stan/math/rev/core.hpp>
6+
#include <stan/math/rev/core/typedefs.hpp>
7+
#include <stan/math/prim/err.hpp>
8+
#include <stan/math/prim/fun/Eigen.hpp>
9+
#include <stan/math/prim/fun/sign.hpp>
10+
11+
namespace stan {
12+
namespace math {
13+
14+
/**
15+
* Returns the L1 norm of a vector of var.
16+
*
17+
* @tparam T type of the vector (must have one compile-time dimension equal to
18+
* 1)
19+
* @param[in] v Vector.
20+
* @return L1 norm of v.
21+
*/
22+
template <typename T, require_eigen_vector_vt<is_var, T>* = nullptr>
23+
inline var norm1(const T& v) {
24+
arena_t<T> arena_v = v;
25+
var res = norm1(arena_v.val());
26+
reverse_pass_callback([res, arena_v]() mutable {
27+
arena_v.adj().array() += res.adj() * sign(arena_v.val().array());
28+
});
29+
return res;
30+
}
31+
32+
/**
33+
* Returns the L1 norm of a `var_value<Vector>`.
34+
*
35+
* @tparam A `var_value<>` whose inner type has one compile-time row or column.
36+
* @param[in] v Vector.
37+
* @return L1 norm of v.
38+
*/
39+
//
40+
template <typename T, require_var_matrix_t<T>* = nullptr>
41+
inline var norm1(const T& v) {
42+
return make_callback_vari(norm1(v.val()), [v](const auto& res) mutable {
43+
v.adj().array() += res.adj() * sign(v.val().array());
44+
});
45+
}
46+
47+
} // namespace math
48+
} // namespace stan
49+
#endif

stan/math/rev/fun/norm2.hpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#ifndef STAN_MATH_REV_FUN_NORM2_HPP
2+
#define STAN_MATH_REV_FUN_NORM2_HPP
3+
4+
#include <stan/math/rev/meta.hpp>
5+
#include <stan/math/rev/core.hpp>
6+
#include <stan/math/rev/core/typedefs.hpp>
7+
#include <stan/math/prim/err.hpp>
8+
#include <stan/math/prim/fun/Eigen.hpp>
9+
10+
namespace stan {
11+
namespace math {
12+
13+
/**
14+
* Returns the L2 norm of a vector of var.
15+
*
16+
* @tparam T type of the vector (must have one compile-time dimension equal to
17+
* 1)
18+
* @param[in] v Vector.
19+
* @return L2 norm of v.
20+
*/
21+
template <typename T, require_eigen_vector_vt<is_var, T>* = nullptr>
22+
inline var norm2(const T& v) {
23+
arena_t<T> arena_v = v;
24+
var res = norm2(arena_v.val());
25+
reverse_pass_callback([res, arena_v]() mutable {
26+
arena_v.adj().array() += res.adj() * (arena_v.val().array() / res.val());
27+
});
28+
return res;
29+
}
30+
31+
/**
32+
* Returns the L2 norm of a `var_value<Vector>`.
33+
*
34+
* @tparam A `var_value<>` whose inner type has one compile-time row or column.
35+
* @param[in] v Vector.
36+
* @return L2 norm of v.
37+
*/
38+
template <typename T, require_var_matrix_t<T>* = nullptr>
39+
inline var norm2(const T& v) {
40+
return make_callback_vari(norm2(v.val()), [v](const auto& res) mutable {
41+
v.adj().array() += res.adj() * (v.val().array() / res.val());
42+
});
43+
}
44+
45+
} // namespace math
46+
} // namespace stan
47+
#endif
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#include <test/unit/math/test_ad.hpp>
2+
#include <vector>
3+
4+
TEST(MathMixMatFun, norm1) {
5+
auto f = [](const auto& y) { return stan::math::norm1(y); };
6+
7+
Eigen::VectorXd x0(0);
8+
9+
Eigen::VectorXd x1(1);
10+
x1 << 2;
11+
12+
Eigen::VectorXd x2(2);
13+
x2 << 2, 3;
14+
15+
Eigen::VectorXd x3(3);
16+
x3 << 2, 3, 4;
17+
18+
for (const auto& a : std::vector<Eigen::VectorXd>{x0, x1, x2, x3}) {
19+
stan::test::expect_ad(f, a);
20+
stan::test::expect_ad_matvar(f, a);
21+
}
22+
}

0 commit comments

Comments
 (0)