Skip to content

Commit

Permalink
Scale and ZScore (#16)
Browse files Browse the repository at this point in the history
* fix comment

* scaling

* copy-paste errors

* missing bracket

* gotta get that 100%

* update docstrings

* update tests

* zscoring

* avoid type piracy

* patch bump

* warning for R users

* docstring fixes

* zscore a term with kwargs

* Revert "avoid type piracy"

This reverts commit 2a753e6.

* returns and type annotations

* more docstring fixes

* more docstring fixes

* more more more

* Apply suggestions from code review

Co-authored-by: Dave Kleinschmidt <[email protected]>

* remove explicit eltype checks for scale

* round more

* ditch empty test set

Co-authored-by: Dave Kleinschmidt <[email protected]>
  • Loading branch information
palday and kleinschmidt committed Jul 3, 2021
1 parent d4b4ef2 commit 8c6bb43
Show file tree
Hide file tree
Showing 10 changed files with 575 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "StandardizedPredictors"
uuid = "5064a6a7-f8c2-40e2-8bdc-797ec6f1ae18"
authors = "Beacon Biosignals, inc."
version = "0.1.1"
version = "0.1.2"

[deps]
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand Down
13 changes: 12 additions & 1 deletion src/StandardizedPredictors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,23 @@ export
center,
center!,
Center,
CenteredTerm
CenteredTerm,
scale,
scale!,
Scale,
ScaledTerm,
zscore, # from StatsBase
zscore!, # from StatsBase
ZScore,
ZScoredTerm

using StatsModels
using StatsBase
using Statistics

include("utils.jl")
include("centering.jl")
include("scaling.jl")
include("zscoring.jl")

end
20 changes: 10 additions & 10 deletions src/centering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ StatsModels.Schema with 1 entry:
```
"""
struct Center
center
center::Any
end

Center() = Center(nothing)
Expand Down Expand Up @@ -92,7 +92,7 @@ julia> t = concrete_term(term(:x), d)
x(continuous)
julia> tc = CenteredTerm(t, 5)
center(x, 5)
x(centered: 5)
julia> hcat(modelcols(t + tc, d)...)
10×2 Matrix{Int64}:
Expand Down Expand Up @@ -154,23 +154,23 @@ center(t::AbstractTerm) = throw(ArgumentError("can only compute center for Conti

function center(t::AbstractTerm, c::Center)
c.center !== nothing || throw(ArgumentError("can only compute center for ContinuousTerm; must provide center via center(t, c)"))
CenteredTerm(t, c.center)
return CenteredTerm(t, c.center)
end

StatsModels.modelcols(t::CenteredTerm, d::NamedTuple) = modelcols(t.term, d) .- t.center
function StatsBase.coefnames(t::CenteredTerm)
if StatsModels.width(t.term) == 1
return "$(coefnames(t.term))(centered: $(t.center))"
return "$(coefnames(t.term))(centered: $(_round(t.center)))"
elseif length(t.center) > 1
return string.(vec(coefnames(t.term)), "(centered: ", vec(t.center), ")")
return string.(vec(coefnames(t.term)), "(centered: ", _round.(vec(t.center)), ")")
else
return string.(coefnames(t.term), "(centered: ", t.center, ")")
return string.(coefnames(t.term), "(centered: ", _round.(t.center), ")")
end
end
# coef table: "x: centered at 5.5"
Base.show(io::IO, t::CenteredTerm) = show(io, t.term)
# regular show: "x"
Base.show(io::IO, ::MIME"text/plain", t::CenteredTerm) = print(io, "$(t.term)(centered: $(t.center))")
# coef table: "x(centered: 5.5)"
Base.show(io::IO, t::CenteredTerm) = print(io, "$(t.term)(centered: $(_round(t.center)))")
# regular show: "x(centered: 5.5)", used in displaying schema dicts
Base.show(io::IO, ::MIME"text/plain", t::CenteredTerm) = print(io, "$(t.term)(centered: $(_round(t.center)))")
# long show: "x(centered: 5.5)"

# statsmodels glue code:
Expand Down
175 changes: 175 additions & 0 deletions src/scaling.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""
scale(f=std, x, y=f(skipmissing(x)))
Scale an array `x` by a scalar `y`.
!!! warning
This only scales and does not center the values, unlike `scale` in R.
See `StatsBase.zscore` for that functionality.
See also [`scale!`](@ref)
"""
scale(x) = scale(std, x)
scale(f::Function, x) = scale(x, f(skipmissing(x)))
scale(x, y) = x ./ y

"""
scale(f=std, x, y=f(skipmissing(x)))
Scale an array `x` in place by a scalar `y`.
!!! warning
This only scales and does not center the values, unlike `scale` in R.
See `StatsBase.zscore` for that functionality.
See also [`scale`](@ref)
"""
scale!(x) = scale!(std, x)
scale!(f::Function, x) = scale!(x, f(skipmissing(x)))

function scale!(x, y)
x ./= y
return x
end

"""
struct Scale
Represents a scaling scheme, akin to `StatsModels.AbstractContrasts`. Pass as
value in `Dict` as hints to `schema` (or as `contrasts` kwarg for `fit`).
## Examples
Can specify scale value to use:
```
julia> schema((x=collect(1:10), ), Dict(:x => Scale(5)))
StatsModels.Schema with 1 entry:
x => x(scaled: 5))
```
Or scale will be automatically computed if left out:
```
julia> schema((x=collect(1:10), ), Dict(:x => Scale()))
StatsModels.Schema with 1 entry:
x => x(scaled: 3.03)
```
"""
struct Scale
scale::Any
end

Scale() = Scale(nothing)


"""
struct ScaledTerm{T,S} <: AbstractTerm
A lazily scaled term. A wrapper around an `T<:AbstractTerm` which will
produce scaled values with `modelcols` by dividing each element by `scale`.
## Fields
- `term::T`: The wrapped term.
- `scale::S`: The scale value which the resulting `modelcols` are divided by.
## Examples
Directly construct with given scale:
```
julia> d = (x=collect(1:10), );
julia> t = concrete_term(term(:x), d)
x(continuous)
julia> ts = ScaledTerm(t, 5)
x(scaled: 5)
julia> hcat(modelcols(t + ts, d)...)
10×2 Matrix{Float64}:
1.0 0.2
2.0 0.4
3.0 0.6
4.0 0.8
5.0 1.0
6.0 1.2
7.0 1.4
8.0 1.6
9.0 1.8
10.0 2.0
```
Construct with lazy scaling via [`Scale`](@ref)
```
julia> ts = concrete_term(term(:x), d, Scale())
x(scaled: 3.03)
julia> hcat(modelcols(t + ts, d)...)
10×2 Matrix{Float64}:
1.0 0.330289
2.0 0.660578
3.0 0.990867
4.0 1.32116
5.0 1.65145
6.0 1.98173
7.0 2.31202
8.0 2.64231
9.0 2.9726
10.0 3.30289
```
Or similarly via schema hints:
```
julia> sch = schema(d, Dict(:x => Scale()))
StatsModels.Schema with 1 entry:
x => scale(x, 3.03)
```
"""
struct ScaledTerm{T,S} <: AbstractTerm
term::T
scale::S
end

StatsModels.concrete_term(t::Term, xs::AbstractArray, s::Scale) =
scale(StatsModels.concrete_term(t, xs, nothing), s)

# run-time constructors:
scale(t::ContinuousTerm, s::Scale) = ScaledTerm(t, something(s.scale, sqrt(t.var)))
scale(t::ContinuousTerm, s) = ScaledTerm(t, s)
scale(t::ContinuousTerm) = ScaledTerm(t, sqrt(t.var))
scale(t::AbstractTerm) = throw(ArgumentError("can only compute scale for ContinuousTerm; must provide scale value via scale(t, s)"))

function scale(t::AbstractTerm, s::Scale)
s.scale !== nothing || throw(ArgumentError("can only compute scale for ContinuousTerm; must provide scale via scale(t, s)"))
return ScaledTerm(t, s.scale)
end

StatsModels.modelcols(t::ScaledTerm, d::NamedTuple) = modelcols(t.term, d) ./ t.scale

function StatsBase.coefnames(t::ScaledTerm)
if StatsModels.width(t.term) == 1
return "$(coefnames(t.term))(scaled: $(_round(t.scale)))"
elseif length(t.scale) > 1
return string.(vec(coefnames(t.term)), "(scaled: ", _round.(vec(t.scale)), ")")
else
return string.(coefnames(t.term), "(scaled: ", _round(t.scale), ")")
end
end
# coef table: "x(scaled: 5.5)"
Base.show(io::IO, t::ScaledTerm) = print(io, "$(t.term)(scaled: $(_round(t.scale)))")
# regular show: "x(scaled: 5.5)", used in displaying schema dicts
Base.show(io::IO, ::MIME"text/plain", t::ScaledTerm) = print(io, "$(t.term)(scaled: $(_round(t.scale)))")
# long show: "x(scaled: 5.5)"

# statsmodels glue code:
StatsModels.width(t::ScaledTerm) = StatsModels.width(t.term)
# don't generate schema entries for terms which are already scaled
StatsModels.needs_schema(::ScaledTerm) = false
StatsModels.termsyms(t::ScaledTerm) = StatsModels.termsyms(t.term)
4 changes: 4 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

_round(v::AbstractArray) = _round.(v)
_round(x::Integer) = x
_round(x) = round(x; digits=2)
Loading

2 comments on commit 8c6bb43

@palday
Copy link
Member Author

@palday palday commented on 8c6bb43 Jul 3, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/40143

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.2 -m "<description of version>" 8c6bb43e3b874ab06f9b574577c0761601b9c73e
git push origin v0.1.2

Please sign in to comment.