Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 42 additions & 18 deletions src/matrixlu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,19 @@ function rrLU{T}(A::AbstractMatrix{T}; leftorthogonal::Bool=true) where {T}
end

function swaprow!(lu::rrLU{T}, A::AbstractMatrix{T}, a, b) where {T}
lu.rowpermutation[[a, b]] = lu.rowpermutation[[b, a]]
A[[a, b], :] = A[[b, a], :]
lurp = lu.rowpermutation
lurp[a], lurp[b] = lurp[b], lurp[a]
@inbounds for j in axes(A, 2)
A[a, j], A[b, j] = A[b, j], A[a, j]
end
end

function swapcol!(lu::rrLU{T}, A::AbstractMatrix{T}, a, b) where {T}
lu.colpermutation[[a, b]] = lu.colpermutation[[b, a]]
A[:, [a, b]] = A[:, [b, a]]
lucp = lu.colpermutation
lucp[a], lucp[b] = lucp[b], lucp[a]
@inbounds for i in axes(A, 1)
A[i, a], A[i, b] = A[i, b], A[i, a]
end
end

function addpivot!(lu::rrLU{T}, A::AbstractMatrix{T}, newpivot) where {T}
Expand All @@ -111,9 +117,9 @@ function addpivot!(lu::rrLU{T}, A::AbstractMatrix{T}, newpivot) where {T}
swapcol!(lu, A, k, newpivot[2])

if lu.leftorthogonal
A[k+1:end, k] /= A[k, k]
A[k+1:end, k] ./= A[k, k]
else
A[k, k+1:end] /= A[k, k]
A[k, k+1:end] ./= A[k, k]
end

# perform BLAS subroutine manually: A <- -x * transpose(y) + A
Expand All @@ -139,12 +145,12 @@ function _optimizerrlu!(
reltol::Number=1e-14,
abstol::Number=0.0
) where {T}
maxrank = min(maxrank, size(A)...)
maxrank = min(maxrank, size(A, 1), size(A, 2))
maxerror = 0.0
while lu.npivot < maxrank
k = lu.npivot + 1
newpivot = submatrixargmax(abs2, A, k)
lu.error = abs(A[newpivot...])
lu.error = abs(A[newpivot[1], newpivot[2]])
# Add at least 1 pivot to get a well-defined L * U
if (abs(lu.error) < reltol * maxerror || abs(lu.error) < abstol) && lu.npivot > 0
break
Expand All @@ -153,8 +159,8 @@ function _optimizerrlu!(
addpivot!(lu, A, newpivot)
end

lu.L = tril(A[:, 1:lu.npivot])
lu.U = triu(A[1:lu.npivot, :])
lu.L = tril(@view A[:, 1:lu.npivot])
lu.U = triu(@view A[1:lu.npivot, :])
if any(isnan.(lu.L))
error("lu.L contains NaNs")
end
Expand Down Expand Up @@ -271,16 +277,16 @@ function arrlu(
I2 = setdiff(1:matrixsize[1], I0)
lu.rowpermutation = vcat(I0, I2)
L2 = _batchf(I2, J0)
cols2Lmatrix!(L2, lu.U[1:lu.npivot, 1:lu.npivot], leftorthogonal)
lu.L = vcat(lu.L[1:lu.npivot, 1:lu.npivot], L2)
cols2Lmatrix!(L2, (@view lu.U[1:lu.npivot, 1:lu.npivot]), leftorthogonal)
lu.L = vcat((@view lu.L[1:lu.npivot, 1:lu.npivot]), L2)
end

if size(lu.U, 2) < matrixsize[2]
J2 = setdiff(1:matrixsize[2], J0)
lu.colpermutation = vcat(J0, J2)
U2 = _batchf(I0, J2)
rows2Umatrix!(U2, lu.L[1:lu.npivot, 1:lu.npivot], leftorthogonal)
lu.U = hcat(lu.U[1:lu.npivot, 1:lu.npivot], U2)
rows2Umatrix!(U2, (@view lu.L[1:lu.npivot, 1:lu.npivot]), leftorthogonal)
lu.U = hcat((@view lu.U[1:lu.npivot, 1:lu.npivot]), U2)
end

return lu
Expand Down Expand Up @@ -313,8 +319,17 @@ function cols2Lmatrix!(C::AbstractMatrix, P::AbstractMatrix, leftorthogonal::Boo
end

for k in axes(P, 1)
C[:, k] /= P[k, k]
C[:, k+1:end] -= C[:, k] * transpose(P[k, k+1:end])
C[:, k] ./= P[k, k]
# C[:, k+1:end] .-= C[:, k] * transpose(P[k, k+1:end])
x = @view C[:, k]
y = @view P[k, k+1:end]
C̃ = @view C[:, k+1:end]
@inbounds for j in eachindex(axes(C̃, 2), y)
for i in eachindex(axes(C̃, 1), x)
# update `C[:, k+1:end]`
C̃[i, j] -= x[i] * y[j]
end
end
end
return C
end
Expand All @@ -327,8 +342,17 @@ function rows2Umatrix!(R::AbstractMatrix, P::AbstractMatrix, leftorthogonal::Boo
end

for k in axes(P, 1)
R[k, :] /= P[k, k]
R[k+1:end, :] -= P[k+1:end, k] * transpose(R[k, :])
R[k, :] ./= P[k, k]
# R[k+1:end, :] -= P[k+1:end, k] * transpose(R[k, :])
x = @view P[k+1:end, k]
y = @view R[k, :]
R̃ = @view R[k+1:end, :]
@inbounds for j in eachindex(axes(R̃, 2), y)
for i in eachindex(axes(R̃, 1), x)
# update R[k+1:end, :]
R̃[i, j] -= x[i] * y[j]
end
end
end
return R
end
Expand Down
2 changes: 1 addition & 1 deletion src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ end

function pushrandomsubset!(subset, set, n::Int)
topush = randomsubset(setdiff(set, subset), n)
push!(subset, topush...)
append!(subset, topush)
nothing
end

Expand Down