diff --git a/Project.toml b/Project.toml index 94b95b5..56cfaa5 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,9 @@ version = "0.1.1" [deps] Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d" [compat] diff --git a/src/RegressionFormulae.jl b/src/RegressionFormulae.jl index 03725b2..a3a20e8 100644 --- a/src/RegressionFormulae.jl +++ b/src/RegressionFormulae.jl @@ -1,5 +1,8 @@ module RegressionFormulae +using LinearAlgebra +using Statistics +using StatsAPI using StatsModels using Combinatorics using Base.Iterators @@ -13,4 +16,7 @@ include("fulldummy.jl") include("power.jl") include("nesting.jl") +include("poly.jl") +export poly + end # module diff --git a/src/poly.jl b/src/poly.jl new file mode 100644 index 0000000..3e11c81 --- /dev/null +++ b/src/poly.jl @@ -0,0 +1,78 @@ +function poly(x, deg; raw=true) + raw == true && return [el^d for el in x, d in 1:deg] + length(unique(x)) < deg && + error("Degree of orthogonal polynomial must be lower than the " * + "number of unique points") + x̄ = mean(x) + x = x .- mean(x) + mat = qr!([el^d for el in x, d in 0:deg]) + cols = vcat(Diagonal(mat.R), zeros(length(x) - size(mat.R,1), size(mat.R, 2))) + cols = mat.Q * cols + colnorms = sum.(abs2, eachcol(cols)) + + for (j, cn) in zip(axes(cols, 2), colnorms) + cols[:, j] ./= sqrt(cn) + end + + return cols[:, (begin+1):end] + +end + +# type of model where syntax applies: here this applies to any model type +const POLY_CONTEXT = RegressionModel + +# struct for behavior +struct PolyTerm{T,D} <: AbstractTerm + term::T + deg::D + raw::Bool +end + +Base.show(io::IO, p::PolyTerm) = print(io, "poly($(p.term), $(p.deg))") + +# for `poly` use at run-time (outside @formula), return a schema-less PolyTerm +poly(t::Symbol, d::Int; raw::Bool=true) = PolyTerm(term(t), term(d), raw) + +# for `poly` use inside @formula: create a schemaless PolyTerm and apply_schema +function StatsModels.apply_schema(t::FunctionTerm{typeof(poly)}, + sch::StatsModels.Schema, + Mod::Type{<:POLY_CONTEXT}) + apply_schema(PolyTerm(t.args_parsed...), sch, Mod) +end + +# apply_schema to internal Terms and check for proper types +function StatsModels.apply_schema(t::PolyTerm, + sch::StatsModels.Schema, + Mod::Type{<:POLY_CONTEXT}) + term = apply_schema(t.term, sch, Mod) + isa(term, ContinuousTerm) || + throw(ArgumentError("PolyTerm only works with continuous terms (got $term)")) + isa(t.deg, ConstantTerm) || + throw(ArgumentError("PolyTerm degree must be a number (got $t.deg)")) + PolyTerm(term, t.deg.n) +end + +function StatsModels.modelcols(p::PolyTerm, d::NamedTuple) + col = modelcols(p.term, d) + p.raw && return reduce(hcat, [col.^n for n in 1:p.deg]) + + return mapreduce(hcat, 1:p.deg) do n + + + # INSERT HERE + end +end + +# the basic terms contained within a PolyTerm (for schema extraction) +StatsModels.terms(p::PolyTerm) = terms(p.term) +# names variables from the data that a PolyTerm relies on +StatsModels.termvars(p::PolyTerm) = StatsModels.termvars(p.term) +# number of columns in the matrix this term produces +StatsModels.width(p::PolyTerm) = p.deg + +# highlight the difference between raw and orthogonal in the output +function StatsAPI.coefnames(p::PolyTerm) + p.raw && return coefnames(p.term) .* "^" .* string.(1:p.deg) + + return "poly(" .* coefnames(p.term) .* ", " .* string.(1:p.deg) .* ")" +end diff --git a/test/poly.jl b/test/poly.jl new file mode 100644 index 0000000..afe8282 --- /dev/null +++ b/test/poly.jl @@ -0,0 +1,22 @@ +# > poly(1:10, 3) +# 1 2 3 +# [1,] -0.49543369 0.52223297 -0.4534252 +# [2,] -0.38533732 0.17407766 0.1511417 +# [3,] -0.27524094 -0.08703883 0.3778543 +# [4,] -0.16514456 -0.26111648 0.3346710 +# [5,] -0.05504819 -0.34815531 0.1295501 +# [6,] 0.05504819 -0.34815531 -0.1295501 +# [7,] 0.16514456 -0.26111648 -0.3346710 +# [8,] 0.27524094 -0.08703883 -0.3778543 +# [9,] 0.38533732 0.17407766 -0.1511417 +# [10,] 0.49543369 0.52223297 0.4534252 + +# > poly(-3:3, 5) +# 1 2 3 4 5 +# [1,] -5.669467e-01 5.455447e-01 -4.082483e-01 0.2417469 -1.091089e-01 +# [2,] -3.779645e-01 9.690821e-17 4.082483e-01 -0.5640761 4.364358e-01 +# [3,] -1.889822e-01 -3.273268e-01 4.082483e-01 0.0805823 -5.455447e-01 +# [4,] 2.098124e-17 -4.364358e-01 4.532467e-17 0.4834938 5.342065e-16 +# [5,] 1.889822e-01 -3.273268e-01 -4.082483e-01 0.0805823 5.455447e-01 +# [6,] 3.779645e-01 0.000000e+00 -4.082483e-01 -0.5640761 -4.364358e-01 +# [7,] 5.669467e-01 5.455447e-01 4.082483e-01 0.2417469 1.091089e-01