From 3f1e4bcd1b6230b7177ba0b7b252b9a83cfa9b94 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Sun, 13 Mar 2022 11:00:42 +0100 Subject: [PATCH] Implement promotion rules for `CategoricalArray` (#384) These are rarely needed, but an exception is when storing `CategoricalArray` objects within arrays, as the `[a1, a2]` syntax uses promotion to choose the element type of the result. This previously failed as it hit a fallback `promote_result` method defined for `AbstractArray` in Base in range.jl, which tried to convert `CategoricalArray`s to `Array` by calling nonexistent `convert` methods for `CategoricalValue`. --- src/array.jl | 13 +++++++++++ test/13_arraycommon.jl | 50 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/src/array.jl b/src/array.jl index a31e4669..c95c97b8 100644 --- a/src/array.jl +++ b/src/array.jl @@ -302,6 +302,19 @@ CategoricalMatrix(A::CategoricalArray{T, 2, R}; ordered::Bool=_isordered(A)) where {T, R} = CategoricalArray{T, 2, R}(A, levels=levels, ordered=ordered) +## Promotion methods + +Base.promote_rule(::Type{<:CategoricalArray{S}}, + ::Type{<:CategoricalArray{T}}) where {S, T} = + CategoricalArray{cat_promote_type(S, T)} +Base.promote_rule(::Type{<:CategoricalArray{S, N}}, + ::Type{<:CategoricalArray{T, N}}) where {S, T, N} = + CategoricalArray{cat_promote_type(S, T), N} +Base.promote_rule(::Type{<:CategoricalArray{S, N, R1}}, + ::Type{<:CategoricalArray{T, N, R2}}) where + {S, T, N, R1<:Integer, R2<:Integer} = + CategoricalArray{cat_promote_type(S, T), N, promote_type(R1, R2)} + ## Conversion methods # From AbstractArray diff --git a/test/13_arraycommon.jl b/test/13_arraycommon.jl index bb6823c0..1a913d65 100644 --- a/test/13_arraycommon.jl +++ b/test/13_arraycommon.jl @@ -2224,4 +2224,54 @@ end end end +@testset "promotion" begin + @test promote_type(CategoricalVector{Int}, + CategoricalVector{String}) == + CategoricalVector{Union{Int, String}} + @test promote_type(CategoricalVector{Int, UInt32}, + CategoricalVector{String, UInt32}) == + CategoricalVector{Union{Int, String}, UInt32} + @test promote_type(CategoricalArray{Int, UInt32}, + CategoricalArray{String, UInt32}) == + CategoricalArray{Union{Int, String}, UInt32} + @test promote_type(CategoricalVector{Int, UInt32}, + CategoricalMatrix{String, UInt32}) == + CategoricalArray{Union{Int, String}} + @test promote_type(CategoricalVector{Int, UInt8}, + CategoricalVector{String, UInt16}) == + CategoricalVector{Union{Int, String}, UInt16} + + @test promote_type(CategoricalVector{Int8}, + CategoricalVector{Float64}) == + CategoricalVector{Float64} + @test promote_type(CategoricalVector{Int8, UInt32}, + CategoricalVector{Float64, UInt32}) == + CategoricalVector{Float64, UInt32} + @test promote_type(CategoricalArray{Int8, UInt32}, + CategoricalArray{Float64, UInt32}) == + CategoricalArray{Float64, UInt32} + @test promote_type(CategoricalVector{Int8, UInt32}, + CategoricalMatrix{Float64, UInt32}) == + CategoricalArray{Float64} + @test promote_type(CategoricalVector{Int8, UInt8}, + CategoricalVector{Float64, UInt16}) == + CategoricalVector{Float64, UInt16} + + @test [CategoricalVector([1, 2]), + CategoricalVector(["a", "b"])] isa + Vector{CategoricalVector{Union{Int, String}, UInt32}} + @test [CategoricalVector([1, missing]), + CategoricalVector(["a", "b"])] isa + Vector{CategoricalVector{Union{Int, String, Missing}, UInt32}} + @test [CategoricalVector([1, missing]), + CategoricalVector(["a", missing])] isa + Vector{CategoricalVector{Union{Int, String, Missing}, UInt32}} + @test [CategoricalVector([Int8(1), missing]), + CategoricalVector([Int16(2)])] isa + Vector{CategoricalVector{Union{Int16, Missing}, UInt32}} + @test [CategoricalVector([1, 2]), + CategoricalMatrix(["a" "b"])] isa + Vector{CategoricalArray{Union{Int, String}}} +end + end