Skip to content

Commit

Permalink
More conversion in PDMat
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy committed Dec 10, 2024
1 parent 592634b commit de8c335
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 17 deletions.
4 changes: 3 additions & 1 deletion src/chol.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ if HAVE_CHOLMOD
chol_upper(cf::CholTypeSparse) = cf.UP
end

mattype(::Cholesky{T,S}) where {T,S} = S

# Interface for `Cholesky`

dim(A::Cholesky) = LinearAlgebra.checksquare(A)
Expand Down Expand Up @@ -73,7 +75,7 @@ function invquad(A::Cholesky, x::AbstractVector)
@check_argdims size(A, 1) == size(x, 1)
return sum(abs2, chol_lower(A) \ x)
end
function invquad(A::Cholesky, X::AbstractMatrix)
function invquad(A::Cholesky, X::AbstractMatrix)
@check_argdims size(A, 1) == size(X, 1)
Z = chol_lower(A) \ X
return vec(sum(abs2, Z; dims=1))
Expand Down
54 changes: 39 additions & 15 deletions src/pdmat.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,49 @@
"""
Full positive definite matrix together with a Cholesky factorization object.
"""
struct PDMat{T<:Real,S<:AbstractMatrix} <: AbstractPDMat{T}
struct PDMat{T<:Real,S<:AbstractMatrix{T}} <: AbstractPDMat{T}
mat::S
chol::Cholesky{T,S}

PDMat{T,S}(m::AbstractMatrix{T},c::Cholesky{T,S}) where {T,S} = new{T,S}(m,c)
function PDMat{T,S}(m::AbstractMatrix, c::Cholesky) where {T,S}
d = LinearAlgebra.checksquare(m)
if size(c, 1) != d
throw(DimensionMismatch("Dimensions of mat and chol are inconsistent."))
end
# in principle we might want to check that `c` is a Cholesky factorization of `m`,
# but that's slow
return new{T,S}(m,c)
end
end
PDMat{T,S}(pdm::PDMat) where {T,S} = PDMat{T,S}(pdm.mat, pdm.chol)

function PDMat(mat::AbstractMatrix,chol::Cholesky{T,S}) where {T,S}
d = LinearAlgebra.checksquare(mat)
if size(chol, 1) != d
throw(DimensionMismatch("Dimensions of mat and chol are inconsistent."))
end
PDMat{T,S}(convert(S, mat), chol)
function PDMat{T}(m::AbstractMatrix, c::Cholesky) where T
c = Cholesky{T}(c)
return PDMat{T,mattype(c)}(m, c)
end
PDMat{T}(pdm::PDMat) where T = PDMat{T}(pdm.mat, pdm.chol)

PDMat(mat::AbstractMatrix,chol::Cholesky{T,S}) where {T,S} = PDMat{T,S}(mat, chol)

function PDMat{T,S}(mat::AbstractMatrix) where {T,S}
mat = convert(S, mat)
return PDMat{T,S}(mat, cholesky(mat))
end
function PDMat{T}(mat::AbstractMatrix) where T
mat = convert(AbstractMatrix{T}, mat)
return PDMat{T}(mat, cholesky(mat))
end
PDMat(mat::AbstractMatrix) = PDMat(mat, cholesky(mat))
PDMat(fac::Cholesky) = PDMat(AbstractMatrix(fac), fac)

function PDMat{T,S}(c::Cholesky) where {T,S}
c = Cholesky{T,S}(c)
return PDMat{T,S}(AbstractMatrix(c), c)
end
function PDMat{T}(c::Cholesky) where T
c = Cholesky{T}(c)
return PDMat{T}(AbstractMatrix(c), c)
end
PDMat(c::Cholesky) = PDMat(AbstractMatrix(c), c)

function Base.getproperty(a::PDMat, s::Symbol)
if s === :dim
Expand All @@ -30,13 +56,11 @@ Base.propertynames(::PDMat) = (:mat, :chol, :dim)
AbstractPDMat(A::Cholesky) = PDMat(A)

### Conversion
Base.convert(::Type{PDMat{T,S}}, a::PDMat{T,S}) where {T<:Real,S<:AbstractMatrix{T}} = a
Base.convert(::Type{PDMat{T,S}}, a::PDMat) where {T<:Real,S<:AbstractMatrix{T}} = PDMat{T,S}(a)
Base.convert(::Type{PDMat{T}}, a::PDMat{T}) where {T<:Real} = a
function Base.convert(::Type{PDMat{T}}, a::PDMat) where {T<:Real}
chol = convert(Cholesky{T}, a.chol)
S = typeof(chol.factors)
mat = convert(S, a.mat)
return PDMat{T,S}(mat, chol)
end
Base.convert(::Type{PDMat{T}}, a::PDMat) where {T<:Real} = PDMat{T}(a)

Base.convert(::Type{AbstractPDMat{T}}, a::PDMat) where {T<:Real} = convert(PDMat{T}, a)

### Basics
Expand Down
20 changes: 19 additions & 1 deletion test/pdmtypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,31 @@ using Test
M = convert(Array{T,2}, [4. -2. -1.; -2. 5. -1.; -1. -1. 6.])
V = convert(Array{T,1}, [1.5, 2.5, 2.0])
X = convert(T,2.0)
f64M = Float64.(M)

@testset "PDMat from Matrix" begin
pdf64M = PDMat(f64M)
test_pdmat(PDMat(M), M, cmat_eq=true, verbose=1)
test_pdmat(PDMat{Float64}(M), f64M, cmat_eq=true, verbose=1)
test_pdmat(PDMat{Float64,Matrix{Float64}}(M), f64M, cmat_eq=true, verbose=1)
@test_throws TypeError PDMat{Float32,Matrix{Float64}}(M)
end
@testset "PDMat from PDMat" begin
pdM = PDMat(M)
pdf64M = PDMat(f64M)
test_pdmat(PDMat(pdM), M, cmat_eq=true, verbose=1)
test_pdmat(PDMat{Float64}(pdf64M), f64M, cmat_eq=true, verbose=1)
test_pdmat(PDMat{Float64,Matrix{Float64}}(pdf64M), f64M, cmat_eq=true, verbose=1)
@test_throws TypeError PDMat{Float32,Matrix{Float64}}(pdM)
end
@testset "PDMat from Cholesky" begin
cholL = Cholesky(Matrix(transpose(cholesky(M).factors)), 'L', 0)
cholLf64 = Cholesky(Matrix(transpose(cholesky(f64M).factors)), 'L', 0)
test_pdmat(PDMat(cholL), M, cmat_eq=true, verbose=1)
test_pdmat(PDMat{Float64}(cholLf64), f64M, cmat_eq=true, verbose=1)
if Base.VERSION >= v"1.12.0-DEV.1654" # julia #56562
test_pdmat(PDMat{Float64,Matrix{Float64}}(cholLf64), f64M, cmat_eq=true, verbose=1)
end
end
@testset "PDiagMat" begin
test_pdmat(PDiagMat(V), Matrix(Diagonal(V)), cmat_eq=true, verbose=1)
Expand Down Expand Up @@ -145,7 +163,7 @@ using Test
# right division not defined for CHOLMOD:
# `rdiv!(::Matrix{Float64}, ::SuiteSparse.CHOLMOD.Factor{Float64})` not defined
if !HAVE_CHOLMOD
z = x / PDSparseMat(sparse(first(A), 1, 1))
z = x / PDSparseMat(sparse(first(A), 1, 1))
@test typeof(z) === typeof(y)
@test size(z) == size(y)
@test z y
Expand Down
8 changes: 8 additions & 0 deletions test/specialarrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ using StaticArrays
@test PDMat(S, C) === PDS
@test @allocated(PDMat(S)) == @allocated(PDMat(C)) == @allocated(PDMat(S, C))

if Base.VERSION >= v"1.12.0-DEV.1654" # julia #56562
A = PDMat(Matrix{Float64}(I, 2, 2))
B = PDMat(SMatrix{2,2,Float64}(I))
@test !isa(A.mat, typeof(B.mat))
S = convert(typeof(B), A)
@test isa(S.mat, typeof(B.mat))
end

# Diagonal matrix
D = PDiagMat(@SVector(rand(4)))
@test D isa PDiagMat{Float64, <:SVector{4, Float64}}
Expand Down

0 comments on commit de8c335

Please sign in to comment.