Skip to content

Commit f595ceb

Browse files
Replace observed collections (#266)
* add groups field to SemEnsemble, streamline replace_observed code, add replace_observed for ensemble models, add bootstrap se for ensemble models * fix mg tests
1 parent 56f0567 commit f595ceb

File tree

4 files changed

+133
-52
lines changed

4 files changed

+133
-52
lines changed

src/additional_functions/simulation.jl

Lines changed: 54 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,23 @@
33
44
(2) replace_observed(model::AbstractSemSingle, observed; kwargs...)
55
6+
(3) replace_observed(model::SemEnsemble; column = :group, weights = nothing, kwargs...)
7+
68
Return a new model with a swapped observed part.
79
810
# Arguments
911
- `model::AbstractSemSingle`: model to swap the observed part of.
1012
- `kwargs`: additional keyword arguments; typically includes `data` and `specification`
1113
- `observed`: Either an object of subtype of `SemObserved` or a subtype of `SemObserved`
1214
15+
# For SemEnsemble models:
16+
- `column`: if a DataFrame is passed as `data = ...`, which column signifies the group?
17+
- `weights`: how to weight the different sub-models,
18+
defaults to number of samples per group in the new data
19+
- `kwargs`: each keyword argument has to be a dict with keys equal to the group names.
20+
`data` can also be a DataFrame with `column` containing the group information,
21+
and `specification` can also be an `EnsembleParameterTable`.
22+
1323
# Examples
1424
See the online documentation on [Replace observed data](@ref).
1525
"""
@@ -37,51 +47,28 @@ function update_observed end
3747
replace_observed(model::AbstractSemSingle; kwargs...) =
3848
replace_observed(model, typeof(observed(model)).name.wrapper; kwargs...)
3949

40-
# construct a new observed type
41-
replace_observed(model::AbstractSemSingle, observed_type; kwargs...) =
42-
replace_observed(model, observed_type(; kwargs...); kwargs...)
43-
44-
replace_observed(model::AbstractSemSingle, new_observed::SemObserved; kwargs...) =
45-
replace_observed(
46-
model,
47-
observed(model),
48-
implied(model),
49-
loss(model),
50-
new_observed;
51-
kwargs...,
52-
)
53-
54-
function replace_observed(
55-
model::AbstractSemSingle,
56-
old_observed,
57-
implied,
58-
loss,
59-
new_observed::SemObserved;
60-
kwargs...,
61-
)
50+
function replace_observed(model::AbstractSemSingle, observed_type; kwargs...)
51+
new_observed = observed_type(;kwargs...)
6252
kwargs = Dict{Symbol, Any}(kwargs...)
6353

6454
# get field types
6555
kwargs[:observed_type] = typeof(new_observed)
66-
kwargs[:old_observed_type] = typeof(old_observed)
67-
kwargs[:implied_type] = typeof(implied)
68-
kwargs[:loss_types] = [typeof(lossfun) for lossfun in loss.functions]
56+
kwargs[:old_observed_type] = typeof(model.observed)
57+
kwargs[:implied_type] = typeof(model.implied)
58+
kwargs[:loss_types] = [typeof(lossfun) for lossfun in model.loss.functions]
6959

7060
# update implied
71-
implied = update_observed(implied, new_observed; kwargs...)
72-
kwargs[:implied] = implied
73-
kwargs[:nparams] = nparams(implied)
61+
new_implied = update_observed(model.implied, new_observed; kwargs...)
62+
kwargs[:implied] = new_implied
63+
kwargs[:nparams] = nparams(new_implied)
7464

7565
# update loss
76-
loss = update_observed(loss, new_observed; kwargs...)
77-
kwargs[:loss] = loss
78-
79-
#new_implied = update_observed(model.implied, new_observed; kwargs...)
66+
new_loss = update_observed(model.loss, new_observed; kwargs...)
8067

8168
return Sem(
8269
new_observed,
83-
update_observed(model.implied, new_observed; kwargs...),
84-
update_observed(model.loss, new_observed; kwargs...),
70+
new_implied,
71+
new_loss
8572
)
8673
end
8774

@@ -92,6 +79,39 @@ function update_observed(loss::SemLoss, new_observed; kwargs...)
9279
return SemLoss(new_functions, loss.weights)
9380
end
9481

82+
83+
function replace_observed(
84+
emodel::SemEnsemble;
85+
column = :group,
86+
weights = nothing,
87+
kwargs...,
88+
)
89+
kwargs = Dict{Symbol, Any}(kwargs...)
90+
# allow for EnsembleParameterTable to be passed as specification
91+
if haskey(kwargs, :specification) && isa(kwargs[:specification], EnsembleParameterTable)
92+
kwargs[:specification] = convert(Dict{Symbol, RAMMatrices}, kwargs[:specification])
93+
end
94+
# allow for DataFrame with group variable "column" to be passed as new data
95+
if haskey(kwargs, :data) && isa(kwargs[:data], DataFrame)
96+
kwargs[:data] = Dict(
97+
group => select(
98+
filter(
99+
r -> r[column] == group,
100+
kwargs[:data]),
101+
Not(column)) for group in emodel.groups)
102+
end
103+
# update each model for new data
104+
models = emodel.sems
105+
new_models = Tuple(
106+
replace_observed(m; group_kwargs(g, kwargs)...) for (m, g) in zip(models, emodel.groups)
107+
)
108+
return SemEnsemble(new_models...; weights = weights, groups = emodel.groups)
109+
end
110+
111+
function group_kwargs(g, kwargs)
112+
return Dict(k => kwargs[k][g] for k in keys(kwargs))
113+
end
114+
95115
############################################################################################
96116
# simulate data
97117
############################################################################################

src/frontend/fit/standard_errors/bootstrap.jl

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,19 @@
22
se_bootstrap(sem_fit::SemFit; n_boot = 3000, data = nothing, kwargs...)
33
44
Return bootstrap standard errors.
5-
Only works for single models.
65
76
# Arguments
87
- `n_boot`: number of bootstrap samples
98
- `data`: data to sample from. Only needed if different than the data from `sem_fit`
109
- `kwargs...`: passed down to `replace_observed`
1110
"""
1211
function se_bootstrap(
13-
semfit::SemFit;
12+
semfit::SemFit{Mi, So, St, Mo, O};
1413
n_boot = 3000,
1514
data = nothing,
1615
specification = nothing,
1716
kwargs...,
18-
)
19-
if model(semfit) isa AbstractSemCollection
20-
throw(
21-
ArgumentError(
22-
"bootstrap standard errors for ensemble models are not available yet",
23-
),
24-
)
25-
end
17+
) where {Mi, So, St, Mo <: AbstractSemSingle, O}
2618

2719
if isnothing(data)
2820
data = samples(observed(model(semfit)))
@@ -69,6 +61,62 @@ function se_bootstrap(
6961
return sd
7062
end
7163

64+
function se_bootstrap(
65+
semfit::SemFit{Mi, So, St, Mo, O};
66+
n_boot = 3000,
67+
data = nothing,
68+
specification = nothing,
69+
kwargs...,
70+
) where {Mi, So, St, Mo <: SemEnsemble, O}
71+
72+
models = semfit.model.sems
73+
groups = semfit.model.groups
74+
75+
if isnothing(data)
76+
data = Dict(g => samples(observed(m)) for (g, m) in zip(groups, models))
77+
end
78+
79+
data = Dict(k => prepare_data_bootstrap(data[k]) for k in keys(data))
80+
81+
start = solution(semfit)
82+
83+
new_solution = zero(start)
84+
sum = zero(start)
85+
squared_sum = zero(start)
86+
87+
n_failed = 0.0
88+
89+
converged = true
90+
91+
for _ in 1:n_boot
92+
sample_data = Dict(k => bootstrap_sample(data[k]) for k in keys(data))
93+
new_model = replace_observed(
94+
semfit.model;
95+
data = sample_data,
96+
specification = specification,
97+
kwargs...,
98+
)
99+
100+
new_solution .= 0.0
101+
102+
try
103+
new_solution = solution(fit(new_model; start_val = start))
104+
catch
105+
n_failed += 1
106+
end
107+
108+
@. sum += new_solution
109+
@. squared_sum += new_solution^2
110+
111+
converged = true
112+
end
113+
114+
n_conv = n_boot - n_failed
115+
sd = sqrt.(squared_sum / n_conv - (sum / n_conv) .^ 2)
116+
print("Number of nonconverged models: ", n_failed, "\n")
117+
return sd
118+
end
119+
72120
function prepare_data_bootstrap(data)
73121
return Matrix(data)
74122
end

src/types.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ end
168168
# ensemble models
169169
############################################################################################
170170
"""
171-
(1) SemEnsemble(models...; weights = nothing, kwargs...)
171+
(1) SemEnsemble(models...; weights = nothing, groups = nothing, kwargs...)
172172
173173
(2) SemEnsemble(;specification, data, groups, column = :group, kwargs...)
174174
@@ -192,24 +192,24 @@ Returns a SemEnsemble with fields
192192
193193
For instructions on multigroup models, see the online documentation.
194194
"""
195-
struct SemEnsemble{N, T <: Tuple, V <: AbstractVector, I} <: AbstractSemCollection
195+
struct SemEnsemble{N, T <: Tuple, V <: AbstractVector, I, G <: Vector{Symbol}} <: AbstractSemCollection
196196
n::N
197197
sems::T
198198
weights::V
199199
param_labels::I
200+
groups::G
200201
end
201202

202203
# constructor from multiple models
203-
function SemEnsemble(models...; weights = nothing, kwargs...)
204+
function SemEnsemble(models...; weights = nothing, groups = nothing, kwargs...)
204205
n = length(models)
205-
206206
# default weights
207-
208207
if isnothing(weights)
209208
nsamples_total = sum(nsamples, models)
210209
weights = [nsamples(model) / nsamples_total for model in models]
211210
end
212-
211+
# default group labels
212+
groups = isnothing(groups) ? Symbol.(:g, 1:n) : groups
213213
# check parameters equality
214214
param_labels = SEM.param_labels(models[1])
215215
for model in models
@@ -220,7 +220,7 @@ function SemEnsemble(models...; weights = nothing, kwargs...)
220220
end
221221
end
222222

223-
return SemEnsemble(n, models, weights, param_labels)
223+
return SemEnsemble(n, models, weights, param_labels, groups)
224224
end
225225

226226
# constructor from EnsembleParameterTable and data set
@@ -238,7 +238,7 @@ function SemEnsemble(; specification, data, groups, column = :group, kwargs...)
238238
model = Sem(; specification = ram_matrices, data = data_group, kwargs...)
239239
push!(models, model)
240240
end
241-
return SemEnsemble(models...; weights = nothing, kwargs...)
241+
return SemEnsemble(models...; weights = nothing, groups = groups, kwargs...)
242242
end
243243

244244
param_labels(ensemble::SemEnsemble) = ensemble.param_labels

test/examples/multigroup/build_models.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@ model_ml_multigroup2 = SemEnsemble(
2020
loss = SemML,
2121
)
2222

23+
model_ml_multigroup3 = replace_observed(
24+
model_ml_multigroup2,
25+
column = :school,
26+
specification = partable,
27+
data = dat,
28+
)
29+
2330
# gradients
2431
@testset "ml_gradients_multigroup" begin
2532
test_gradient(model_ml_multigroup, start_test; atol = 1e-9)
@@ -46,6 +53,12 @@ end
4653
)
4754
end
4855

56+
@testset "replace_observed_multigroup" begin
57+
sem_fit_1 = fit(semoptimizer, model_ml_multigroup)
58+
sem_fit_2 = fit(semoptimizer, model_ml_multigroup3)
59+
@test sem_fit_1.solution ≈ sem_fit_2.solution
60+
end
61+
4962
@testset "fitmeasures/se_ml" begin
5063
solution_ml = fit(model_ml_multigroup)
5164
test_fitmeasures(

0 commit comments

Comments
 (0)