diff --git a/src/Unitful.jl b/src/Unitful.jl index 343f49b6..0b9a42ce 100644 --- a/src/Unitful.jl +++ b/src/Unitful.jl @@ -21,8 +21,8 @@ import Base: steprange_last, unsigned end import Dates -import LinearAlgebra: Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal -import LinearAlgebra: istril, istriu, norm +import LinearAlgebra: Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal, Adjoint, Transpose, AdjOrTransAbsMat +import LinearAlgebra: istril, istriu, norm, mul!, dot, /, \, inv, pinv import Random import ConstructionBase: constructorof @@ -69,5 +69,6 @@ include("logarithm.jl") include("complex.jl") include("pkgdefaults.jl") include("dates.jl") +include("linearalgebra.jl") end diff --git a/src/linearalgebra.jl b/src/linearalgebra.jl new file mode 100644 index 00000000..77433605 --- /dev/null +++ b/src/linearalgebra.jl @@ -0,0 +1,84 @@ + +# Multiplication + +function mul!(C::StridedVecOrMat{<:AbstractQuantity{T}}, + A::StridedMatrix{<:AbstractQuantity{T}}, + B::StridedVecOrMat{<:AbstractQuantity{T}}, + alpha::Number, beta::Number) where {T<:Base.HWNumber} + _mul!(C, A, B, alpha, beta) +end + +function mul!(C::StridedVecOrMat{<:AbstractQuantity{T}}, + A::AdjOrTransAbsMat{<:AbstractQuantity{T}, <:StridedMatrix}, + B::StridedVecOrMat{<:AbstractQuantity{T}}, + alpha::Number, beta::Number) where {T<:Base.HWNumber} + _mul!(C, A, B, alpha, beta) +end + +function _mul!(C, A, B, alpha, beta) + if unit(beta) != NoUnits + throw(DimensionError("beta", 1.0)) + elseif unit(eltype(C)) != unit(eltype(A)) * unit(eltype(B)) * unit(alpha) + throw(DimensionError("A * B .* α", "C")) + end + C0 = ustrip(C) + A0 = ustrip(A) + B0 = ustrip(B) + mul!(C0, A0, B0) + _linearalgebra_count() + return C +end + +function dot(A::StridedArray{<:AbstractQuantity{T}}, + B::StridedArray{<:AbstractQuantity{T}}) where {T<:Base.HWNumber} + A0 = ustrip(A) + B0 = ustrip(B) + C0 = dot(A0, B0) + _linearalgebra_count() + C = C0 * unit(eltype(A)) * unit(eltype(B)) + return C +end + +# Division + +function (\)(A::StridedMatrix{<:AbstractQuantity{T}}, + B::StridedVecOrMat{<:AbstractQuantity{T}}) where {T<:Base.HWNumber} + A0 = ustrip(A) + B0 = ustrip(B) + C0 = A0 \ B0 + _linearalgebra_count() + u = unit(eltype(B)) / unit(eltype(A)) + Tu = typeof(one(eltype(C0)) * u) + return reinterpret(Tu, C0) +end + +function (/)(A::StridedVecOrMat{<:AbstractQuantity{T}}, + B::StridedVecOrMat{<:AbstractQuantity{T}}) where {T<:Base.HWNumber} + A0 = ustrip(A) + B0 = ustrip(B) + C0 = A0 / B0 + _linearalgebra_count() + u = unit(eltype(A)) / unit(eltype(B)) + Tu = typeof(one(eltype(C0)) * u) + return reinterpret(Tu, C0) +end + +function inv(A::StridedMatrix{<:AbstractQuantity{T}}) where {T<:Base.HWNumber} + C0 = inv(ustrip(A)) + _linearalgebra_count() + u = inv(unit(eltype(A))) + Tu = typeof(one(eltype(C0)) * u) + return reinterpret(Tu, C0) +end + +function pinv(A::StridedMatrix{<:AbstractQuantity{T}}; kw...) where {T<:Base.HWNumber} + C0 = pinv(ustrip(A); kw...) + _linearalgebra_count() + u = inv(unit(eltype(A))) + Tu = typeof(one(eltype(C0)) * u) + return reinterpret(Tu, C0) +end + +# This function is re-defined during testing, to check we hit the fast path: +_linearalgebra_count() = nothing + diff --git a/src/utils.jl b/src/utils.jl index 53f12b95..85b1bda7 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -28,6 +28,27 @@ true @inline ustrip(u::Units, x) = ustrip(uconvert(u, x)) @inline ustrip(T::Type, u::Units, x) = convert(T, ustrip(u, x)) +""" + ustrip(u::Units, xs::AbstractArray{<:Quantity}) + +This broadcasts `ustrip.(u, xs)`, unless `xs isa StridedArray` whose units match `u`, +in which case it reinterprets, which saves making a copy. + +```jldoctest +julia> ustrip(u"m", [1, 2, 3]u"m") isa Base.ReinterpretArray{Int} # fast path +true + +julia> ustrip(u"m", [1, 2, 3]u"mm") == [1//1000, 2//1000, 3//1000] # mismatch requires slow path +true +``` +""" +ustrip(u::Units, xs::AbstractArray) = ustrip.(u, xs) +function ustrip(u::Units, xs::StridedArray{T}) where {T} + dimension(u) == dimension(T) || return ustrip.(u, xs) + isequal(promote(true * u, oneunit(T))...) || return ustrip.(u, xs) + return reinterpret(numtype(T), xs) +end + """ ustrip(x::Number) ustrip(x::Quantity) @@ -50,7 +71,8 @@ true @inline ustrip(x::Missing) = missing """ - ustrip(x::Array{Q}) where {Q <: Quantity} + ustrip(x::Array{Q}) where {Q <: Quantity{T}}} + Strip units from an `Array` by reinterpreting to type `T`. The resulting `Array` is a not a copy, but rather a unit-stripped view into array `x`. Because the units are removed, information may be lost and this should be used with some care. @@ -75,7 +97,7 @@ julia> a[1] = 3u"m"; b 2 ``` """ -@inline ustrip(A::Array{Q}) where {Q <: Quantity} = reinterpret(numtype(Q), A) +@inline ustrip(A::StridedArray{Q}) where {Q <: Quantity} = reinterpret(numtype(Q), A) @deprecate(ustrip(A::AbstractArray{T}) where {T<:Number}, ustrip.(A)) @@ -90,6 +112,15 @@ ustrip(A::Diagonal) = Diagonal(ustrip(A.diag)) ustrip(A::Bidiagonal) = Bidiagonal(ustrip(A.dv), ustrip(A.ev), ifelse(istriu(A), :U, :L)) ustrip(A::Tridiagonal) = Tridiagonal(ustrip(A.dl), ustrip(A.d), ustrip(A.du)) ustrip(A::SymTridiagonal) = SymTridiagonal(ustrip(A.dv), ustrip(A.ev)) +ustrip(A::Adjoint) = adjoint(ustrip(parent(A))) +ustrip(A::Transpose) = transpose(ustrip(parent(A))) + +ustrip(u::Units, A::Diagonal) = Diagonal(ustrip(u, A.diag)) +ustrip(u::Units, A::Bidiagonal) = Bidiagonal(ustrip(u, A.dv), ustrip(u, A.ev), ifelse(istriu(A), :U, :L)) +ustrip(u::Units, A::Tridiagonal) = Tridiagonal(ustrip(u, A.dl), ustrip(u, A.d), ustrip(u, A.du)) +ustrip(u::Units, A::SymTridiagonal) = SymTridiagonal(ustrip(u, A.dv), ustrip(u, A.ev)) +ustrip(u::Units, A::Adjoint) = adjoint(ustrip(u, parent(A))) +ustrip(u::Units, A::Transpose) = transpose(ustrip(u, parent(A))) """ unit(x::Quantity{T,D,U}) where {T,D,U} diff --git a/test/runtests.jl b/test/runtests.jl index aec5a66f..321e35c0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -84,6 +84,66 @@ const colon = Base.:(:) @test ConstructionBase.constructorof(typeof(1.0m))(2) === 2m end +@testset "LinearAlgebra functions" begin + CNT = Ref(0) + Unitful._linearalgebra_count() = (CNT[] += 1; nothing) + @testset "> Matrix multiplication: *" begin + M = rand(3,3) .* u"m" + M_ = view(M,:,1:3) + v = rand(3) .* u"V" + v_ = view(v, 1:3) + + CNT[] = 0 + + @test unit(first(M * M)) == u"m*m" + @test M * M == M_ * M == M * M_ == M_ * M_ + + @test unit(first(M * v)) == u"m*V" + @test M * v == M_ * v == M * v_ == M_ * v_ + + VERSION >= v"1.3" && @test CNT[] == 10 + + @test unit(first(v' * M)) == u"m*V" + @test v' * M == v_' * M == v_' * M == v_' * M_ + + VERSION >= v"1.3" && @test CNT[] == 15 + + @test unit(v' * v) == u"V*V" + @test v' * v == v_' * v == v_' * v == v_' * v_ + + VERSION >= v"1.3" && @test CNT[] == 20 + + # Mixed with & without units + N = rand(3,3) + w = rand(3) + + CNT[] = 0 + + @test unit(first(M * N)) == u"m" + @test unit(first(N * M)) == u"m" + + @test unit(first(M * w)) == u"m" + @test unit(first(N * v)) == u"V" + + @show CNT[] # not specialised yet + + end + @testset "> Matrix multiplication: mul!" begin + A = rand(3,3) .* u"m" + B = rand(3,3) .* u"m" + C = fill(zero(eltype(A*B)), 3, 3) + CNT[] = 0 + + mul!(C, A, B) + if VERSION >= v"1.3" # the 5-arm mul! exists + mul!(C, A, B, true, true) + mul!(C, A, B, 3, 7) # not specialised yet + + @show CNT[] + end + end +end + @testset "Types" begin @test Base.complex(Quantity{Float64,NoDims,NoUnits}) == Quantity{Complex{Float64},NoDims,NoUnits} @@ -1251,6 +1311,16 @@ end @test_deprecated ustrip([1,2]) @test ustrip.([1,2]) == [1,2] @test typeof(ustrip([1u"m", 2u"m"])) <: Base.ReinterpretArray{Int,1} + + # With target type + @test @inferred(ustrip(u"m", [1, 2]u"m")) == [1,2] + @test @inferred(ustrip(u"km", [1, 2]u"m")) == [1//1000, 2//1000] + @test typeof(ustrip(u"m", [1, 2]u"m")) <: Base.ReinterpretArray{Int,1} + @test typeof(ustrip(u"m/ms", [1, 2]*(u"km/s"))) <: Base.ReinterpretArray{Int,1} + + # Structured matrices + @test typeof(ustrip(adjoint([1,2]u"m"))) <: Adjoint{Int} + @test typeof(ustrip(transpose([1 2; 3 4]u"m"))) <: Transpose{Int} @test typeof(ustrip(Diagonal([1,2]u"m"))) <: Diagonal{Int} @test typeof(ustrip(Bidiagonal([1,2,3]u"m", [1,2]u"m", :U))) <: Bidiagonal{Int} @@ -1258,6 +1328,9 @@ end Tridiagonal{Int} @test typeof(ustrip(SymTridiagonal([1,2,3]u"m", [4,5]u"m"))) <: SymTridiagonal{Int} + + @test typeof(ustrip(u"m", adjoint([1,2]u"m"))) <: Adjoint{Int} + @test typeof(ustrip(u"m", Diagonal([1,2]u"m"))) <: Diagonal{Int} end @testset ">> Linear algebra" begin @test istril([1 1; 0 1]u"m") == false