From 284da9cd8d9373b60f64fab005c462e93d3ed56e Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Fri, 11 Mar 2022 22:22:55 +0100 Subject: [PATCH] Implement promotion rules for `CategoricalArray` These are rarely needed, but an exception is when storing `CategoricalArray` objects within arrays, as the `[a1, a2]` syntax uses promotion to choose the element type of the result. This previously failed as it hit a fallback `promote_result` method defined for `AbstractArray` in Base in range.jl, which tried to convert `CategoricalArray`s to `Array` by calling nonexistent `convert` methods for `CategoricalValue`. --- src/array.jl | 13 +++++++++++ test/13_arraycommon.jl | 50 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/src/array.jl b/src/array.jl index 6c3cd8b2..dec5270e 100644 --- a/src/array.jl +++ b/src/array.jl @@ -302,6 +302,19 @@ CategoricalMatrix(A::CategoricalArray{T, 2, R}; ordered::Bool=_isordered(A)) where {T, R} = CategoricalArray{T, 2, R}(A, levels=levels, ordered=ordered) +## Promotion methods + +Base.promote_rule(::Type{<:CategoricalArray{S}}, + ::Type{<:CategoricalArray{T}}) where {S, T} = + CategoricalArray{cat_promote_type(S, T)} +Base.promote_rule(::Type{<:CategoricalArray{S, N}}, + ::Type{<:CategoricalArray{T, N}}) where {S, T, N} = + CategoricalArray{cat_promote_type(S, T), N} +Base.promote_rule(::Type{<:CategoricalArray{S, N, R1}}, + ::Type{<:CategoricalArray{T, N, R2}}) where + {S, T, N, R1<:Integer, R2<:Integer} = + CategoricalArray{cat_promote_type(S, T), N, promote_type(R1, R2)} + ## Conversion methods # From AbstractArray diff --git a/test/13_arraycommon.jl b/test/13_arraycommon.jl index 2ac04c51..4b86c6d7 100644 --- a/test/13_arraycommon.jl +++ b/test/13_arraycommon.jl @@ -2164,4 +2164,54 @@ end end end +@testset "promotion" begin + @test promote_type(CategoricalVector{Int}, + CategoricalVector{String}) == + CategoricalVector{Union{Int, String}} + @test promote_type(CategoricalVector{Int, UInt32}, + CategoricalVector{String, UInt32}) == + CategoricalVector{Union{Int, String}, UInt32} + @test promote_type(CategoricalArray{Int, UInt32}, + CategoricalArray{String, UInt32}) == + CategoricalArray{Union{Int, String}, UInt32} + @test promote_type(CategoricalVector{Int, UInt32}, + CategoricalMatrix{String, UInt32}) == + CategoricalArray{Union{Int, String}} + @test promote_type(CategoricalVector{Int, UInt8}, + CategoricalVector{String, UInt16}) == + CategoricalVector{Union{Int, String}, UInt16} + + @test promote_type(CategoricalVector{Int8}, + CategoricalVector{Float64}) == + CategoricalVector{Float64} + @test promote_type(CategoricalVector{Int8, UInt32}, + CategoricalVector{Float64, UInt32}) == + CategoricalVector{Float64, UInt32} + @test promote_type(CategoricalArray{Int8, UInt32}, + CategoricalArray{Float64, UInt32}) == + CategoricalArray{Float64, UInt32} + @test promote_type(CategoricalVector{Int8, UInt32}, + CategoricalMatrix{Float64, UInt32}) == + CategoricalArray{Float64} + @test promote_type(CategoricalVector{Int8, UInt8}, + CategoricalVector{Float64, UInt16}) == + CategoricalVector{Float64, UInt16} + + @test [CategoricalVector([1, 2]), + CategoricalVector(["a", "b"])] isa + Vector{CategoricalVector{Union{Int, String}, UInt32}} + @test [CategoricalVector([1, missing]), + CategoricalVector(["a", "b"])] isa + Vector{CategoricalVector{Union{Int, String, Missing}, UInt32}} + @test [CategoricalVector([1, missing]), + CategoricalVector(["a", missing])] isa + Vector{CategoricalVector{Union{Int, String, Missing}, UInt32}} + @test [CategoricalVector([Int8(1), missing]), + CategoricalVector([Int16(2)])] isa + Vector{CategoricalVector{Union{Int16, Missing}, UInt32}} + @test [CategoricalVector([1, 2]), + CategoricalMatrix(["a" "b"])] isa + Vector{CategoricalArray{Union{Int, String}}} +end + end