Skip to content

Commit

Permalink
Support Center(::Function), Scale(::Function), `ZScore(;center::F…
Browse files Browse the repository at this point in the history
…unction, scale::Function)` (#21)

* Support Center(::Function)

* Scale{T}

* docstring consistency

* tests for Scale{T}

* now nonbreaking

* fixes

* even better way

* ZScore

* patch bump

* JuliaFormatter

* slight organizational change

* make tests wackier
  • Loading branch information
palday authored Jan 12, 2022
1 parent 183fdf7 commit 8f52f69
Show file tree
Hide file tree
Showing 9 changed files with 168 additions and 69 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "StandardizedPredictors"
uuid = "5064a6a7-f8c2-40e2-8bdc-797ec6f1ae18"
authors = "Beacon Biosignals, inc."
version = "0.1.3"
version = "0.1.4"

[deps]
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand Down
9 changes: 5 additions & 4 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
using StandardizedPredictors
using Documenter

DocMeta.setdocmeta!(StandardizedPredictors, :DocTestSetup, :(using StandardizedPredictors); recursive=true)
DocMeta.setdocmeta!(StandardizedPredictors, :DocTestSetup, :(using StandardizedPredictors);
recursive=true)

makedocs(modules=[StandardizedPredictors],
makedocs(; modules=[StandardizedPredictors],
authors="Beacon Biosignals, Inc.",
repo="https://github.com/beacon-biosignals/StandardizedPredictors.jl/blob/{commit}{path}#{line}",
sitename="StandardizedPredictors.jl",
format=Documenter.HTML(prettyurls=get(ENV, "CI", "false") == "true",
format=Documenter.HTML(; prettyurls=get(ENV, "CI", "false") == "true",
canonical="https://beacon-biosignals.github.io/StandardizedPredictors.jl",
assets=String[]),
pages=["Home" => "index.md",
"API" => "api.md"])

deploydocs(repo="github.com/beacon-biosignals/StandardizedPredictors.jl",
deploydocs(; repo="github.com/beacon-biosignals/StandardizedPredictors.jl",
devbranch="main",
push_preview=true)
36 changes: 24 additions & 12 deletions src/StandardizedPredictors.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,35 @@
module StandardizedPredictors

export
center,
center!,
Center,
CenteredTerm,
scale,
scale!,
Scale,
ScaledTerm,
zscore, # from StatsBase
zscore!, # from StatsBase
ZScore,
ZScoredTerm
center,
center!,
Center,
CenteredTerm,
scale,
scale!,
Scale,
ScaledTerm,
zscore, # from StatsBase
zscore!, # from StatsBase
ZScore,
ZScoredTerm

using StatsModels
using StatsBase
using Statistics

"""
_standard(xs::AbstractArray, val)
Translate an abstract standardization value to a concrete one based on `xs`.
`nothing` and already concrete `Number` `val`s are passed through.
Otherwise, `val(xs)` is returned.
"""
_standard(::AbstractArray, t::Number) = t
_standard(::AbstractArray, ::Nothing) = nothing
_standard(xs::AbstractArray, t) = t(xs)

include("utils.jl")
include("centering.jl")
include("scaling.jl")
Expand Down
27 changes: 19 additions & 8 deletions src/centering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,13 @@ StatsModels.Schema with 1 entry:
x => center(x, 5)
```
Or center will be automatically computed if left out:
You can use a function to compute the center value:
julia> schema((x=collect(1:10), ), Dict(:x => Center(median)))
StatsModels.Schema with 1 entry:
x => x(centered: 5.5)
Or [`center`](@ref) will be automatically computed if omitted:
```
julia> schema((x=collect(1:10), ), Dict(:x => Center()))
Expand All @@ -67,7 +73,7 @@ struct Center
end

Center() = Center(nothing)

Center(xs::AbstractArray, c::Center) = Center(_standard(xs, c.center))

"""
struct CenteredTerm{T,C} <: AbstractTerm
Expand Down Expand Up @@ -136,24 +142,27 @@ StatsModels.Schema with 1 entry:
x => center(x, 5.5)
```
"""
struct CenteredTerm{T,C} <: AbstractTerm
term::T
center::C
end

StatsModels.concrete_term(t::Term, xs::AbstractArray, c::Center) =
center(StatsModels.concrete_term(t, xs, nothing), c)
function StatsModels.concrete_term(t::Term, xs::AbstractArray, c::Center)
return center(StatsModels.concrete_term(t, xs, nothing), Center(xs, c))
end

# run-time constructors:
center(t::ContinuousTerm, c::Center) = CenteredTerm(t, something(c.center, t.mean))
center(t::ContinuousTerm, c) = CenteredTerm(t, c)
center(t::ContinuousTerm) = CenteredTerm(t, t.mean)
center(t::AbstractTerm) = throw(ArgumentError("can only compute center for ContinuousTerm; must provide center value via center(t, c)"))
function center(t::AbstractTerm)
throw(ArgumentError("can only compute center for ContinuousTerm; must provide center value via center(t, c)"))
end

function center(t::AbstractTerm, c::Center)
c.center !== nothing || throw(ArgumentError("can only compute center for ContinuousTerm; must provide center via center(t, c)"))
c.center !== nothing ||
throw(ArgumentError("can only compute center for ContinuousTerm; must provide center via center(t, c)"))
return CenteredTerm(t, c.center)
end

Expand All @@ -170,7 +179,9 @@ end
# coef table: "x(centered: 5.5)"
Base.show(io::IO, t::CenteredTerm) = print(io, "$(t.term)(centered: $(_round(t.center)))")
# regular show: "x(centered: 5.5)", used in displaying schema dicts
Base.show(io::IO, ::MIME"text/plain", t::CenteredTerm) = print(io, "$(t.term)(centered: $(_round(t.center)))")
function Base.show(io::IO, ::MIME"text/plain", t::CenteredTerm)
return print(io, "$(t.term)(centered: $(_round(t.center)))")
end
# long show: "x(centered: 5.5)"

# statsmodels glue code:
Expand Down
27 changes: 19 additions & 8 deletions src/scaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,13 @@ StatsModels.Schema with 1 entry:
x => x(scaled: 5))
```
Or scale will be automatically computed if left out:
You can use a function to compute the scale value:
julia> schema((x=collect(1:10), ), Dict(:x => Scale(mad)))
StatsModels.Schema with 1 entry:
x => x(scaled: 3.71)
Or [`scale`](@ref) will be automatically computed if left out:
```
julia> schema((x=collect(1:10), ), Dict(:x => Scale()))
Expand All @@ -62,7 +68,7 @@ struct Scale
end

Scale() = Scale(nothing)

Scale(xs::AbstractArray, s::Scale) = Scale(_standard(xs, s.scale))

"""
struct ScaledTerm{T,S} <: AbstractTerm
Expand Down Expand Up @@ -130,24 +136,27 @@ StatsModels.Schema with 1 entry:
x => scale(x, 3.03)
```
"""
struct ScaledTerm{T,S} <: AbstractTerm
term::T
scale::S
end

StatsModels.concrete_term(t::Term, xs::AbstractArray, s::Scale) =
scale(StatsModels.concrete_term(t, xs, nothing), s)
function StatsModels.concrete_term(t::Term, xs::AbstractArray, s::Scale)
return scale(StatsModels.concrete_term(t, xs, nothing), Scale(xs, s))
end

# run-time constructors:
scale(t::ContinuousTerm, s::Scale) = ScaledTerm(t, something(s.scale, sqrt(t.var)))
scale(t::ContinuousTerm, s) = ScaledTerm(t, s)
scale(t::ContinuousTerm) = ScaledTerm(t, sqrt(t.var))
scale(t::AbstractTerm) = throw(ArgumentError("can only compute scale for ContinuousTerm; must provide scale value via scale(t, s)"))
function scale(t::AbstractTerm)
throw(ArgumentError("can only compute scale for ContinuousTerm; must provide scale value via scale(t, s)"))
end

function scale(t::AbstractTerm, s::Scale)
s.scale !== nothing || throw(ArgumentError("can only compute scale for ContinuousTerm; must provide scale via scale(t, s)"))
s.scale !== nothing ||
throw(ArgumentError("can only compute scale for ContinuousTerm; must provide scale via scale(t, s)"))
return ScaledTerm(t, s.scale)
end

Expand All @@ -165,7 +174,9 @@ end
# coef table: "x(scaled: 5.5)"
Base.show(io::IO, t::ScaledTerm) = print(io, "$(t.term)(scaled: $(_round(t.scale)))")
# regular show: "x(scaled: 5.5)", used in displaying schema dicts
Base.show(io::IO, ::MIME"text/plain", t::ScaledTerm) = print(io, "$(t.term)(scaled: $(_round(t.scale)))")
function Base.show(io::IO, ::MIME"text/plain", t::ScaledTerm)
return print(io, "$(t.term)(scaled: $(_round(t.scale)))")
end
# long show: "x(scaled: 5.5)"

# statsmodels glue code:
Expand Down
44 changes: 33 additions & 11 deletions src/zscoring.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ end

ZScore(; center=nothing, scale=nothing) = ZScore(center, scale)

function ZScore(xs::AbstractArray, zs::ZScore)
center = _standard(xs, zs.center)
scale = _standard(xs, zs.scale)
return ZScore(center, scale)
end

"""
struct ZScoredTerm{T,C,S} <: AbstractTerm
Expand Down Expand Up @@ -105,34 +111,50 @@ struct ZScoredTerm{T,C,S} <: AbstractTerm
scale::S
end

StatsModels.concrete_term(t::Term, xs::AbstractArray, z::ZScore) =
zscore(StatsModels.concrete_term(t, xs, nothing), z)
function StatsModels.concrete_term(t::Term, xs::AbstractArray, z::ZScore)
return zscore(StatsModels.concrete_term(t, xs, nothing), ZScore(xs, z))
end

# run-time constructors:
StatsBase.zscore(t::ContinuousTerm, z::ZScore) = ZScoredTerm(t, something(z.center, t.mean), something(z.scale, sqrt(t.var)))
StatsBase.zscore(t::ContinuousTerm; center=nothing, scale=nothing) = ZScoredTerm(t, center, scale)
StatsBase.zscore(t::AbstractTerm) = throw(ArgumentError("can only compute z-score for ContinuousTerm; must provide scale value via zscore(t; center, scale)"))
function StatsBase.zscore(t::ContinuousTerm, z::ZScore)
return ZScoredTerm(t, something(z.center, t.mean), something(z.scale, sqrt(t.var)))
end
function StatsBase.zscore(t::ContinuousTerm; center=nothing, scale=nothing)
return ZScoredTerm(t, center, scale)
end
function StatsBase.zscore(t::AbstractTerm)
throw(ArgumentError("can only compute z-score for ContinuousTerm; must provide scale value via zscore(t; center, scale)"))
end

function StatsBase.zscore(t::AbstractTerm, z::ZScore)
z.scale !== nothing && z.center !== nothing || throw(ArgumentError("can only compute z-score for ContinuousTerm; must provide scale via zscore(t; center, scale)"))
z.scale !== nothing && z.center !== nothing ||
throw(ArgumentError("can only compute z-score for ContinuousTerm; must provide scale via zscore(t; center, scale)"))
return ZScoredTerm(t, z.center, z.scale)
end

StatsModels.modelcols(t::ZScoredTerm, d::NamedTuple) = zscore(modelcols(t.term, d), t.center, t.scale)
function StatsModels.modelcols(t::ZScoredTerm, d::NamedTuple)
return zscore(modelcols(t.term, d), t.center, t.scale)
end

function StatsBase.coefnames(t::ZScoredTerm)
if StatsModels.width(t.term) == 1
return "$(coefnames(t.term))(centered: $(_round(t.center)) scaled: $(_round(t.scale)))"
elseif length(t.scale) > 1
return string.(vec(coefnames(t.term)), "(centered: ", _round.(vec(t.center)), " scaled: ", _round.(vec(t.scale)), ")")
return string.(vec(coefnames(t.term)), "(centered: ", _round.(vec(t.center)),
" scaled: ", _round.(vec(t.scale)), ")")
else
return string.(coefnames(t.term), "(centered: ", _round(t.center), " scaled: ", _round(t.scale), ")")
return string.(coefnames(t.term), "(centered: ", _round(t.center), " scaled: ",
_round(t.scale), ")")
end
end
# coef table: "x(scaled: 5.5)"
Base.show(io::IO, t::ZScoredTerm) = print(io, "$(t.term)(centered: $(_round(t.center)) scaled: $(_round(t.scale)))")
function Base.show(io::IO, t::ZScoredTerm)
return print(io, "$(t.term)(centered: $(_round(t.center)) scaled: $(_round(t.scale)))")
end
# regular show: "x(scaled: 5.5)", used in displaying schema dicts
Base.show(io::IO, ::MIME"text/plain", t::ZScoredTerm) = print(io, "$(t.term)(centered: $(_round(t.center)) scaled: $(_round(t.scale)))")
function Base.show(io::IO, ::MIME"text/plain", t::ZScoredTerm)
return print(io, "$(t.term)(centered: $(_round(t.center)) scaled: $(_round(t.scale)))")
end
# long show: "x(scaled: 5.5)"

# statsmodels glue code:
Expand Down
22 changes: 17 additions & 5 deletions test/centering.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
@testset "Centering" begin

data = (x=collect(1:10),
y=rand(10) .+ 3,
z=Symbol.(repeat('a':'e', 2)))
Expand All @@ -14,6 +13,18 @@
@test yc isa CenteredTerm
@test yc.center == mean(data.y)
@test modelcols(yc, data) == data.y .- mean(data.y) == data.y .- yc.center

@testset "alternative center function" begin
f = first
xc = concrete_term(term(:x), data, Center(f))
@test xc isa CenteredTerm
@test xc.center == f(data.x)
@test modelcols(xc, data) == data.x .- f(data.x) == data.x .- xc.center
# why test this? well this makes sure that our tests
# wouldn't pass if we were using the default center function
# in other words, this tests we're actually hitting a different codepath
@test !isapprox(mean(data.x), f(data.x))
end
end

@testset "Manual centering" begin
Expand Down Expand Up @@ -76,7 +87,7 @@
center!(mean, copy(x)) == center(mean, x)
@test_throws ArgumentError center!(mean, [1, 2])
@test_throws ArgumentError center!([1, 2])
@test_throws MethodError center!(v -> 1, ["a","b"])
@test_throws MethodError center!(v -> 1, ["a", "b"])
end
end

Expand All @@ -93,7 +104,8 @@
data.y .- 2,
(data.x .- mean(data.x)) .* (data.y .- 2))

@test coefnames(ff_c.rhs) == ["x(centered: 5.5)", "y(centered: 2)", "x(centered: 5.5) & y(centered: 2)"]
@test coefnames(ff_c.rhs) ==
["x(centered: 5.5)", "y(centered: 2)", "x(centered: 5.5) & y(centered: 2)"]

# round-trip schema is empty since needs_schema is false
sch_2 = schema(ff_c, data)
Expand Down Expand Up @@ -126,12 +138,12 @@

zc2 = center(z, Center([1 2 3 4]))
@test modelcols(zc2, data) == modelcols(z, data) .- [1 2 3 4]
@test coefnames(zc2) == coefnames(z) .* "(centered: " .* string.([1, 2, 3, 4]) .* ")"
@test coefnames(zc2) ==
coefnames(z) .* "(centered: " .* string.([1, 2, 3, 4]) .* ")"
end

# @testset "utilities" begin


# end

end
Loading

2 comments on commit 8f52f69

@kleinschmidt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/52268

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.4 -m "<description of version>" 8f52f692cb01856f4c2111f52ff9703b8e934030
git push origin v0.1.4

Please sign in to comment.