Skip to content

Commit

Permalink
reduce duplication
Browse files Browse the repository at this point in the history
Co-authored-by: Alex Arslan <[email protected]>
  • Loading branch information
palday and ararslan committed Aug 21, 2023
1 parent 66459e7 commit b6cf8e2
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 37 deletions.
19 changes: 5 additions & 14 deletions ext/EffectsGLMExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,13 @@ module EffectsGLMExt

using Effects

using GLM: AbstractGLM, Link, mueta, linkinv
using StatsAPI: RegressionModel
using GLM: AbstractGLM, Link, Link01, inverselink
using StatsModels: TableRegressionModel

# TODO: upstream a Link(::TableRegressionModel{<:AbstractGLM})
_link(m::TableRegressionModel{<:AbstractGLM}) = Link(m.model)

function Effects._difference_method!(eff::Vector{T}, err::Vector{T},
model::Union{TableRegressionModel{<:AbstractGLM},
AbstractGLM},
::AutoInvLink) where {T<:AbstractFloat}
link = _link(model)
err .*= mueta.(link, eff)
eff .= linkinv.(link, eff)

return err
end
Effects._model_link(m::TableRegressionModel{<:AbstractGLM}, ::AutoInvLink) = Link(m.model)
Effects._model_link(m::AbstractGLM, ::AutoInvLink) = Link(m)
Effects._invlink_and_deriv(link::Link01, η) = inverselink(link, η)[1:2:3] # (µ, 1 - µ, dμdη)
Effects._invlink_and_deriv(link::Link, η) = inverselink(link, η)[1:2] # (µ, dμdη, NaN)

Check warning on line 12 in ext/EffectsGLMExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/EffectsGLMExt.jl#L12

Added line #L12 was not covered by tests

end # module
11 changes: 2 additions & 9 deletions ext/EffectsMixedModelsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,8 @@ module EffectsMixedModelsExt

using Effects
using MixedModels
using GLM: Link, mueta, linkinv
using GLM: Link

function Effects._difference_method!(eff::Vector{T}, err::Vector{T},
model::GeneralizedLinearMixedModel,
::AutoInvLink) where {T<:AbstractFloat}
link = Link(model)
err .*= mueta.(link, eff)
eff .= linkinv.(link, eff)
return err
end
Effects._model_link(m::GeneralizedLinearMixedModel, ::AutoInvLink) = Link(m)

end # module
34 changes: 20 additions & 14 deletions src/regressionmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,7 @@ function effects!(reference_grid::DataFrame, model::RegressionModel;
X = modelcols(form_typical, reference_grid)
eff = X * coef(model)
err = sqrt.(diag(X * vcov(model) * X'))
if invlink !== identity
_difference_method!(eff, err, model, invlink)
end
_difference_method!(eff, err, model, invlink)
reference_grid[!, something(eff_col, _responsename(model))] = eff
reference_grid[!, err_col] = err
return reference_grid
Expand All @@ -135,24 +133,32 @@ end
# in addition to the difference method
# xref https://github.com/JuliaStats/GLM.jl/blob/c13577eaf3f418c58020534dd407532ee57f219b/src/glmfit.jl#L773-L783

function _difference_method!(eff::Vector{T}, err::Vector{T},
::RegressionModel,
invlink) where {T<:AbstractFloat}
err .*= ForwardDiff.derivative.(invlink, eff)
eff .= invlink.(eff)
return eff, err
end

function _difference_method!(::Vector{T}, ::Vector{T},
::RegressionModel,
::AutoInvLink) where {T<:AbstractFloat}
_invlink_and_deriv(invlink, η) = (invlink(η), ForwardDiff.derivative(invlink, η))
_invlink_and_deriv(::typeof(identity), η) = (η, 1)
# this isn't the best name because it sometimes returns the inverse link and sometimes the link (Link())
# for now, this is private API, but we should see how this goes and whether we can make it public API
# so local extensions (instead of Package-Extensions) are better supported
_model_link(::RegressionModel, invlink::Function) = invlink
function _model_link(::RegressionModel, ::AutoInvLink)
@static if VERSION < v"1.9"

Check warning on line 143 in src/regressionmodel.jl

View check run for this annotation

Codecov / codecov/patch

src/regressionmodel.jl#L143

Added line #L143 was not covered by tests
@error "AutoInvLink requires extensions and is thus not available on Julia < 1.9."
end
throw(ArgumentError("No appropriate extension is loaded for automatic " *
"determination of the inverse link for this model type"))
end

function _difference_method!(eff::Vector{T}, err::Vector{T},
m::RegressionModel,
invlink) where {T<:AbstractFloat}
link = _model_link(m, invlink)
@inbounds for i in eachindex(eff, err)
μ, dμdη = _invlink_and_deriv(link, eff[i])
err[i] *= dμdη
eff[i] = μ
end
return eff, err
end

"""
expand_grid(design)
Expand Down
1 change: 1 addition & 0 deletions test/delta_method.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ end
iv = Base.Fix1(GLM.linkinv, Link(model.model))
@static if VERSION >= v"1.9"
invlinks = [iv, AutoInvLink()]
@test Effects._model_link(model, AutoInvLink()) == Effects._model_link(model.model, AutoInvLink())
else
invlinks = [iv]
end
Expand Down

0 comments on commit b6cf8e2

Please sign in to comment.