-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* fix comment * scaling * copy-paste errors * missing bracket * gotta get that 100% * update docstrings * update tests * zscoring * avoid type piracy * patch bump * warning for R users * docstring fixes * zscore a term with kwargs * Revert "avoid type piracy" This reverts commit 2a753e6. * returns and type annotations * more docstring fixes * more docstring fixes * more more more * Apply suggestions from code review Co-authored-by: Dave Kleinschmidt <[email protected]> * remove explicit eltype checks for scale * round more * ditch empty test set Co-authored-by: Dave Kleinschmidt <[email protected]>
- Loading branch information
1 parent
d4b4ef2
commit 8c6bb43
Showing
10 changed files
with
575 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
""" | ||
scale(f=std, x, y=f(skipmissing(x))) | ||
Scale an array `x` by a scalar `y`. | ||
!!! warning | ||
This only scales and does not center the values, unlike `scale` in R. | ||
See `StatsBase.zscore` for that functionality. | ||
See also [`scale!`](@ref) | ||
""" | ||
function scale(x)
    # No scaling function supplied: fall back to the standard deviation.
    return scale(std, x)
end

function scale(f::Function, x)
    # The divisor is computed from the non-missing elements only.
    divisor = f(skipmissing(x))
    return scale(x, divisor)
end

function scale(x, y)
    # Broadcasted division; `x` itself is not modified.
    return x ./ y
end
|
||
""" | ||
scale!(f=std, x, y=f(skipmissing(x))) | ||
Scale an array `x` in place by a scalar `y`. | ||
!!! warning | ||
This only scales and does not center the values, unlike `scale` in R. | ||
See `StatsBase.zscore` for that functionality. | ||
See also [`scale`](@ref) | ||
""" | ||
function scale!(x)
    # No scaling function supplied: fall back to the standard deviation.
    return scale!(std, x)
end

function scale!(f::Function, x)
    # The divisor is computed from the non-missing elements only.
    divisor = f(skipmissing(x))
    return scale!(x, divisor)
end

function scale!(x, y)
    # `.=` writes the quotients back into `x` rather than allocating a new array.
    x .= x ./ y
    return x
end
|
||
""" | ||
struct Scale | ||
Represents a scaling scheme, akin to `StatsModels.AbstractContrasts`. Pass as | ||
value in `Dict` as hints to `schema` (or as `contrasts` kwarg for `fit`). | ||
## Examples | ||
Can specify scale value to use: | ||
``` | ||
julia> schema((x=collect(1:10), ), Dict(:x => Scale(5))) | ||
StatsModels.Schema with 1 entry: | ||
x => x(scaled: 5) | ||
``` | ||
Or scale will be automatically computed if left out: | ||
``` | ||
julia> schema((x=collect(1:10), ), Dict(:x => Scale())) | ||
StatsModels.Schema with 1 entry: | ||
x => x(scaled: 3.03) | ||
``` | ||
""" | ||
struct Scale
    # Value to scale by; `nothing` when no value was supplied (see `Scale()`).
    scale
end

# Zero-argument form: stores `nothing` in place of an explicit scale value.
function Scale()
    return Scale(nothing)
end
|
||
|
||
""" | ||
struct ScaledTerm{T,S} <: AbstractTerm | ||
A lazily scaled term. A wrapper around a `T<:AbstractTerm` which will | ||
produce scaled values with `modelcols` by dividing each element by `scale`. | ||
## Fields | ||
- `term::T`: The wrapped term. | ||
- `scale::S`: The scale value which the resulting `modelcols` are divided by. | ||
## Examples | ||
Directly construct with given scale: | ||
``` | ||
julia> d = (x=collect(1:10), ); | ||
julia> t = concrete_term(term(:x), d) | ||
x(continuous) | ||
julia> ts = ScaledTerm(t, 5) | ||
x(scaled: 5) | ||
julia> hcat(modelcols(t + ts, d)...) | ||
10×2 Matrix{Float64}: | ||
1.0 0.2 | ||
2.0 0.4 | ||
3.0 0.6 | ||
4.0 0.8 | ||
5.0 1.0 | ||
6.0 1.2 | ||
7.0 1.4 | ||
8.0 1.6 | ||
9.0 1.8 | ||
10.0 2.0 | ||
``` | ||
Construct with lazy scaling via [`Scale`](@ref) | ||
``` | ||
julia> ts = concrete_term(term(:x), d, Scale()) | ||
x(scaled: 3.03) | ||
julia> hcat(modelcols(t + ts, d)...) | ||
10×2 Matrix{Float64}: | ||
1.0 0.330289 | ||
2.0 0.660578 | ||
3.0 0.990867 | ||
4.0 1.32116 | ||
5.0 1.65145 | ||
6.0 1.98173 | ||
7.0 2.31202 | ||
8.0 2.64231 | ||
9.0 2.9726 | ||
10.0 3.30289 | ||
``` | ||
Or similarly via schema hints: | ||
``` | ||
julia> sch = schema(d, Dict(:x => Scale())) | ||
StatsModels.Schema with 1 entry: | ||
x => x(scaled: 3.03) | ||
``` | ||
""" | ||
struct ScaledTerm{T,S} <: AbstractTerm
    # The wrapped term whose model columns are scaled.
    term::T
    # Divisor applied to the wrapped term's model columns.
    scale::S
end
|
||
function StatsModels.concrete_term(t::Term, xs::AbstractArray, s::Scale)
    # Build the concrete term without a hint first, then wrap it according to `s`.
    unscaled = StatsModels.concrete_term(t, xs, nothing)
    return scale(unscaled, s)
end
|
||
# Run-time constructors for `ScaledTerm`.

# `Scale` hint on a continuous term: use the stored scale value, falling back to
# `sqrt(t.var)` when the hint holds `nothing`.
function scale(t::ContinuousTerm, s::Scale)
    fallback = sqrt(t.var)
    return ScaledTerm(t, something(s.scale, fallback))
end

# Explicit scale value: wrap as given.
function scale(t::ContinuousTerm, s)
    return ScaledTerm(t, s)
end

# No scale given: use `sqrt(t.var)`.
function scale(t::ContinuousTerm)
    return ScaledTerm(t, sqrt(t.var))
end

# Any other term type: a scale cannot be derived here, so the caller must supply one.
function scale(t::AbstractTerm)
    msg = "can only compute scale for ContinuousTerm; must provide scale value via scale(t, s)"
    throw(ArgumentError(msg))
end
|
||
function scale(t::AbstractTerm, s::Scale)
    # For a generic term the hint must already carry a scale value.
    if s.scale === nothing
        throw(ArgumentError("can only compute scale for ContinuousTerm; must provide scale via scale(t, s)"))
    end
    return ScaledTerm(t, s.scale)
end
|
||
function StatsModels.modelcols(t::ScaledTerm, d::NamedTuple)
    # Generate the wrapped term's columns, then divide them elementwise by the scale.
    unscaled = modelcols(t.term, d)
    return unscaled ./ t.scale
end
|
||
# Coefficient names (as shown in the coef table) carry the rounded scale as a
# suffix, e.g. "x(scaled: 5.5)".
function StatsBase.coefnames(t::ScaledTerm)
    # Single-column term: return one plain string.
    if StatsModels.width(t.term) == 1
        return string(coefnames(t.term), "(scaled: ", _round(t.scale), ")")
    end

    # Scale has more than one element: flatten names and scales and combine them
    # elementwise.
    if length(t.scale) > 1
        return string.(vec(coefnames(t.term)), "(scaled: ", _round.(vec(t.scale)), ")")
    end

    # Otherwise every name gets the same scale suffix.
    return string.(coefnames(t.term), "(scaled: ", _round(t.scale), ")")
end
# Regular show: "x(scaled: 5.5)", used in displaying schema dicts.
function Base.show(io::IO, t::ScaledTerm)
    label = string(t.term, "(scaled: ", _round(t.scale), ")")
    return print(io, label)
end

# Long (`text/plain`) show: prints the same "x(scaled: 5.5)" form.
function Base.show(io::IO, ::MIME"text/plain", t::ScaledTerm)
    label = string(t.term, "(scaled: ", _round(t.scale), ")")
    return print(io, label)
end
|
||
# StatsModels glue code:

# A scaled term is as wide as the term it wraps.
function StatsModels.width(t::ScaledTerm)
    return StatsModels.width(t.term)
end

# Don't generate schema entries for terms which are already scaled.
function StatsModels.needs_schema(::ScaledTerm)
    return false
end

# The term symbols are those of the wrapped term.
function StatsModels.termsyms(t::ScaledTerm)
    return StatsModels.termsyms(t.term)
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
|
||
# Rounding helper used when printing scale values.

# Arrays are rounded element by element.
function _round(v::AbstractArray)
    return _round.(v)
end

# Integers are returned unchanged.
function _round(x::Integer)
    return x
end

# Everything else is rounded to two digits after the decimal point.
function _round(x)
    return round(x, digits=2)
end
Oops, something went wrong.
8c6bb43
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JuliaRegistrator register
8c6bb43
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Registration pull request created: JuliaRegistries/General/40143
After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.
This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via: