Skip to content

Commit 522a76d

Browse files
committed
reworked SVD class, pinv, lstsq. added svdvals()
1 parent 83c0761 commit 522a76d

File tree

11 files changed

+378
-721
lines changed

11 files changed

+378
-721
lines changed

docs/markdown/ReleaseNotes.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
## Version 2.15.1
44

55
* fixed **Issue #144**, `outer` function now operates on arrays of different sizes
6+
* fixed **Issue #215**, `pinv` function has been corrected
67
* added `eig()` for **Issue #143** <https://numpy.org/doc/stable/reference/generated/numpy.linalg.eig.html#numpy.linalg.eig>
78
* unlike NumPy, the NumCpp implementation is only suitable for real symmetric matrices
89
* added `eigvals()` for **Issue #143** <https://numpy.org/doc/stable/reference/generated/numpy.linalg.eigvals.html#numpy.linalg.eigvals>
910
* unlike NumPy, the NumCpp implementation is only suitable for real symmetric matrices
11+
* added `svdvals` <https://numpy.org/doc/stable/reference/generated/numpy.linalg.svd.html>
1012

1113
## Version 2.15.0
1214

include/NumCpp/Linalg.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,4 @@
4242
#include "NumCpp/Linalg/pivotLU_decomposition.hpp"
4343
#include "NumCpp/Linalg/solve.hpp"
4444
#include "NumCpp/Linalg/svd.hpp"
45+
#include "NumCpp/Linalg/svdvals.hpp"

include/NumCpp/Linalg/lstsq.hpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
#pragma once
2929

3030
#include "NumCpp/Core/Internal/StaticAsserts.hpp"
31-
#include "NumCpp/Linalg/svd/SVDClass.hpp"
31+
#include "NumCpp/Linalg/svd/SVD.hpp"
3232
#include "NumCpp/NdArray.hpp"
3333

3434
namespace nc::linalg
@@ -50,12 +50,11 @@ namespace nc::linalg
5050
/// @param inA: coefficient matrix
5151
/// @param inB: Ordinate or "dependent variable" values. If b is two-dimensional, the least-squares solution is
5252
/// calculated for each of the K columns of b.
53-
/// @param inTolerance (default 1e-12)
5453
///
5554
/// @return NdArray
5655
///
5756
template<typename dtype>
58-
NdArray<double> lstsq(const NdArray<dtype>& inA, const NdArray<dtype>& inB, double inTolerance = 1e-12)
57+
NdArray<double> lstsq(const NdArray<dtype>& inA, const NdArray<dtype>& inB)
5958
{
6059
STATIC_ASSERT_ARITHMETIC(dtype);
6160

@@ -72,12 +71,11 @@ namespace nc::linalg
7271
THROW_INVALID_ARGUMENT_ERROR("Invalid matrix dimensions");
7372
}
7473

75-
SVD svdSolver(inA.template astype<double>());
76-
const double threshold = inTolerance * svdSolver.s().front();
74+
SVD svd(inA.template astype<double>());
7775

7876
if (bIsFlat)
7977
{
80-
return svdSolver.solve(inB.template astype<double>(), threshold);
78+
return svd.lstsq(inB.template astype<double>());
8179
}
8280

8381
const auto bCast = inB.template astype<double>();
@@ -88,7 +86,7 @@ namespace nc::linalg
8886

8987
for (uint32 col = 0; col < bShape.cols; ++col)
9088
{
91-
result.put(resultRowSlice, col, svdSolver.solve(bCast(bRowSlice, col), threshold));
89+
result.put(resultRowSlice, col, svd.lstsq(bCast(bRowSlice, col)));
9290
}
9391

9492
return result;

include/NumCpp/Linalg/pinv.hpp

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@
3131

3232
#include "NumCpp/Core/Internal/StaticAsserts.hpp"
3333
#include "NumCpp/Core/Types.hpp"
34-
#include "NumCpp/Functions/zeros.hpp"
35-
#include "NumCpp/Linalg/svd.hpp"
34+
#include "NumCpp/Linalg/svd/SVD.hpp"
3635
#include "NumCpp/NdArray.hpp"
3736

3837
namespace nc::linalg
@@ -51,19 +50,6 @@ namespace nc::linalg
5150
{
5251
STATIC_ASSERT_ARITHMETIC_OR_COMPLEX(dtype);
5352

54-
NdArray<double> u;
55-
NdArray<double> d;
56-
NdArray<double> v;
57-
svd(inArray, u, d, v);
58-
59-
const auto inShape = inArray.shape();
60-
auto dPlus = nc::zeros<double>(inShape.cols, inShape.rows); // transpose
61-
62-
for (uint32 i = 0; i < d.shape().rows; ++i)
63-
{
64-
dPlus(i, i) = 1. / d(i, i);
65-
}
66-
67-
return v.transpose().dot(dPlus).dot(u.transpose());
53+
return SVD{ inArray }.pinv();
6854
}
69-
} // namespace nc::linalg
55+
} // namespace nc::linalg

include/NumCpp/Linalg/pivotLU_decomposition.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ namespace nc::linalg
5858
{
5959
STATIC_ASSERT_ARITHMETIC(dtype);
6060

61-
const auto shape = inMatrix.shape();
61+
const auto& shape = inMatrix.shape();
6262

6363
if (!shape.issquare())
6464
{

include/NumCpp/Linalg/svd.hpp

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
#include "NumCpp/Core/Internal/StaticAsserts.hpp"
3333
#include "NumCpp/Functions/diagflat.hpp"
34-
#include "NumCpp/Linalg/svd/SVDClass.hpp"
34+
#include "NumCpp/Linalg/svd/SVD.hpp"
3535
#include "NumCpp/NdArray.hpp"
3636

3737
namespace nc::linalg
@@ -45,20 +45,17 @@ namespace nc::linalg
4545
/// @param inArray: NdArray to be SVDed
4646
/// @param outU: NdArray output U
4747
/// @param outS: NdArray output S
48-
/// @param outVt: NdArray output V transpose
48+
/// @param outVT: NdArray output V transpose
4949
///
5050
template<typename dtype>
51-
void svd(const NdArray<dtype>& inArray, NdArray<double>& outU, NdArray<double>& outS, NdArray<double>& outVt)
51+
void svd(const NdArray<dtype>& inArray, NdArray<double>& outU, NdArray<double>& outS, NdArray<double>& outVT)
5252
{
5353
STATIC_ASSERT_ARITHMETIC(dtype);
5454

55-
SVD svdSolver(inArray.template astype<double>());
56-
outU = svdSolver.u();
55+
const auto svd = SVD{ inArray };
5756

58-
NdArray<double> vt = svdSolver.v().transpose();
59-
outVt = std::move(vt);
60-
61-
NdArray<double> s = diagflat(svdSolver.s(), 0);
62-
outS = std::move(s);
57+
outU = svd.u();
58+
outS = std::move(svd.s()[Slice(std::min(inArray.numRows(), inArray.numCols()))]);
59+
outVT = std::move(svd.v().transpose());
6360
}
6461
} // namespace nc::linalg

include/NumCpp/Linalg/svd/SVD.hpp

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
/// @file
2+
/// @author David Pilger <[email protected]>
3+
/// [GitHub Repository](https://github.com/dpilger26/NumCpp)
4+
///
5+
/// License
6+
/// Copyright 2020 David Pilger
7+
///
8+
/// Permission is hereby granted, free of charge, to any person obtaining a copy of this
9+
/// software and associated documentation files(the "Software"), to deal in the Software
10+
/// without restriction, including without limitation the rights to use, copy, modify,
11+
/// merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
12+
/// permit persons to whom the Software is furnished to do so, subject to the following
13+
/// conditions :
14+
///
15+
/// The above copyright notice and this permission notice shall be included in all copies
16+
/// or substantial portions of the Software.
17+
///
18+
/// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
19+
/// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
20+
/// PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
21+
/// FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
22+
/// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
23+
/// DEALINGS IN THE SOFTWARE.
24+
///
25+
/// Description
26+
/// Performs the singular value decomposition of a general matrix,
27+
/// taken and adapted from Numerical Recipes Third Edition svd.h
28+
///
29+
#pragma once
30+
31+
#include <cmath>
32+
#include <limits>
33+
#include <string>
34+
35+
#include "NumCpp/Core/Internal/Error.hpp"
36+
#include "NumCpp/Core/Types.hpp"
37+
#include "NumCpp/Functions/dot.hpp"
38+
#include "NumCpp/Functions/norm.hpp"
39+
#include "NumCpp/Functions/zeros.hpp"
40+
#include "NumCpp/Linalg/eig.hpp"
41+
#include "NumCpp/NdArray.hpp"
42+
43+
namespace nc::linalg
44+
{
45+
// =============================================================================
46+
// Class Description:
47+
/// Performs the singular value decomposition of a general matrix
48+
template<typename dtype>
49+
class SVD
50+
{
51+
public:
52+
STATIC_ASSERT_ARITHMETIC(dtype);
53+
54+
static constexpr auto TOLERANCE = 1e-12;
55+
56+
// =============================================================================
57+
// Description:
58+
/// Constructor
59+
///
60+
/// @param inMatrix: matrix to perform SVD on
61+
///
62+
explicit SVD(const NdArray<dtype>& inMatrix) :
63+
m_{ inMatrix.shape().rows },
64+
n_{ inMatrix.shape().cols },
65+
s_(1, m_)
66+
{
67+
compute(inMatrix.template astype<double>());
68+
}
69+
70+
// =============================================================================
71+
// Description:
72+
/// the resultant u matrix
73+
///
74+
/// @return u matrix
75+
///
76+
const NdArray<double>& u() const noexcept
77+
{
78+
return u_;
79+
}
80+
81+
// =============================================================================
82+
// Description:
83+
/// the resultant v transpose matrix
84+
///
85+
/// @return v matrix
86+
///
87+
const NdArray<double>& v() const noexcept
88+
{
89+
return v_;
90+
}
91+
92+
// =============================================================================
93+
// Description:
94+
/// the resultant w matrix
95+
///
96+
/// @return s matrix
97+
///
98+
const NdArray<double>& s() const noexcept
99+
{
100+
return s_;
101+
}
102+
103+
// =============================================================================
104+
// Description:
105+
/// Returns the pseudo-inverse of the input matrix
106+
///
107+
/// @return NdArray
108+
///
109+
NdArray<double> pinv()
110+
{
111+
// lazy evaluation
112+
if (pinv_.isempty())
113+
{
114+
auto sInverse = nc::zeros<double>(n_, m_); // transpose
115+
for (auto i = 0u; i < std::min(m_, n_); ++i)
116+
{
117+
if (s_[i] > TOLERANCE)
118+
{
119+
sInverse(i, i) = 1. / s_[i];
120+
}
121+
}
122+
123+
pinv_ = dot(v_, dot(sInverse, u_.transpose()));
124+
}
125+
126+
return pinv_;
127+
}
128+
129+
// =============================================================================
130+
// Description:
131+
/// solves the linear least squares problem
132+
///
133+
/// @param inInput
134+
///
135+
/// @return NdArray
136+
///
137+
NdArray<double> lstsq(const NdArray<double>& inInput)
138+
{
139+
if (inInput.size() != m_)
140+
{
141+
THROW_INVALID_ARGUMENT_ERROR("Invalid matrix dimensions");
142+
}
143+
144+
if (inInput.numCols() == 1)
145+
{
146+
return dot(pinv(), inInput);
147+
}
148+
else
149+
{
150+
const auto input = inInput.copy().reshape(inInput.size(), 1);
151+
return dot(pinv(), input);
152+
}
153+
}
154+
155+
private:
156+
// =============================================================================
157+
// Description:
158+
/// Computes the SVD of the input matrix
159+
///
160+
/// @param A: matrix to perform SVD on
161+
///
162+
void compute(const NdArray<double>& A)
163+
{
164+
const auto At = A.transpose();
165+
const auto AtA = dot(At, A);
166+
const auto AAt = dot(A, At);
167+
168+
const auto& [sigmaSquaredU, U] = eig(AAt);
169+
const auto& [sigmaSquaredV, V] = eig(AtA);
170+
171+
auto rank = 0u;
172+
for (auto i = 0u; i < std::min(m_, n_); ++i)
173+
{
174+
if (sigmaSquaredV[i] > TOLERANCE)
175+
{
176+
s_[i] = std::sqrt(sigmaSquaredV[i]);
177+
rank++;
178+
}
179+
}
180+
181+
// std::cout << U.front() << ' ' << U.back() << '\n';
182+
// std::cout << V.front() << ' ' << V.back() << '\n';
183+
// std::cout << "hello world\n";
184+
185+
u_ = std::move(U);
186+
v_ = std::move(V);
187+
188+
auto Av = NdArray<double>(m_, 1);
189+
for (auto i = 0u; i < rank; ++i)
190+
{
191+
for (auto j = 0u; j < m_; ++j)
192+
{
193+
auto sum = 0.;
194+
for (auto k = 0u; k < n_; ++k)
195+
{
196+
sum += A(j, k) * v_(k, i);
197+
}
198+
Av[j] = sum;
199+
}
200+
201+
const auto normalization = norm(Av).item();
202+
203+
if (normalization > TOLERANCE)
204+
{
205+
for (auto j = 0u; j < m_; ++j)
206+
{
207+
u_(j, i) = Av[j] / normalization;
208+
}
209+
}
210+
}
211+
}
212+
213+
private:
214+
// ===============================Attributes====================================
215+
const uint32 m_{};
216+
const uint32 n_{};
217+
NdArray<double> u_{};
218+
NdArray<double> v_{};
219+
NdArray<double> s_{};
220+
NdArray<double> pinv_{};
221+
};
222+
} // namespace nc::linalg

0 commit comments

Comments
 (0)