3
3
4
4
(2) replace_observed(model::AbstractSemSingle, observed; kwargs...)
5
5
6
+ (3) replace_observed(model::SemEnsemble; column = :group, weights = nothing, kwargs...)
7
+
6
8
Return a new model with swaped observed part.
7
9
8
10
# Arguments
9
11
- `model::AbstractSemSingle`: model to swap the observed part of.
10
12
- `kwargs`: additional keyword arguments; typically includes `data` and `specification`
11
13
- `observed`: Either an object of subtype of `SemObserved` or a subtype of `SemObserved`
12
14
15
+ # For SemEnsemble models:
16
+ - `column`: if a DataFrame is passed as `data = ...`, which column signifies the group?
17
+ - `weights`: how to weight the different sub-models,
18
+ defaults to number of samples per group in the new data
19
+ - `kwargs`: has to be a dict with keys equal to the group names.
20
+ For `data` can also be a DataFrame with `column` containing the group information,
21
+ and for `specification` can also be an `EnsembleParameterTable`.
22
+
13
23
# Examples
14
24
See the online documentation on [Replace observed data](@ref).
15
25
"""
@@ -37,51 +47,28 @@ function update_observed end
37
47
replace_observed (model:: AbstractSemSingle ; kwargs... ) =
38
48
replace_observed (model, typeof (observed (model)). name. wrapper; kwargs... )
39
49
40
- # construct a new observed type
41
- replace_observed (model:: AbstractSemSingle , observed_type; kwargs... ) =
42
- replace_observed (model, observed_type (; kwargs... ); kwargs... )
43
-
44
- replace_observed (model:: AbstractSemSingle , new_observed:: SemObserved ; kwargs... ) =
45
- replace_observed (
46
- model,
47
- observed (model),
48
- implied (model),
49
- loss (model),
50
- new_observed;
51
- kwargs... ,
52
- )
53
-
54
- function replace_observed (
55
- model:: AbstractSemSingle ,
56
- old_observed,
57
- implied,
58
- loss,
59
- new_observed:: SemObserved ;
60
- kwargs... ,
61
- )
50
+ function replace_observed (model:: AbstractSemSingle , observed_type; kwargs... )
51
+ new_observed = observed_type (;kwargs... )
62
52
kwargs = Dict {Symbol, Any} (kwargs... )
63
53
64
54
# get field types
65
55
kwargs[:observed_type ] = typeof (new_observed)
66
- kwargs[:old_observed_type ] = typeof (old_observed )
67
- kwargs[:implied_type ] = typeof (implied)
68
- kwargs[:loss_types ] = [typeof (lossfun) for lossfun in loss. functions]
56
+ kwargs[:old_observed_type ] = typeof (model . observed )
57
+ kwargs[:implied_type ] = typeof (model . implied)
58
+ kwargs[:loss_types ] = [typeof (lossfun) for lossfun in model . loss. functions]
69
59
70
60
# update implied
71
- implied = update_observed (implied, new_observed; kwargs... )
72
- kwargs[:implied ] = implied
73
- kwargs[:nparams ] = nparams (implied )
61
+ new_implied = update_observed (model . implied, new_observed; kwargs... )
62
+ kwargs[:implied ] = new_implied
63
+ kwargs[:nparams ] = nparams (new_implied )
74
64
75
65
# update loss
76
- loss = update_observed (loss, new_observed; kwargs... )
77
- kwargs[:loss ] = loss
78
-
79
- # new_implied = update_observed(model.implied, new_observed; kwargs...)
66
+ new_loss = update_observed (model. loss, new_observed; kwargs... )
80
67
81
68
return Sem (
82
69
new_observed,
83
- update_observed (model . implied, new_observed; kwargs ... ) ,
84
- update_observed (model . loss, new_observed; kwargs ... ),
70
+ new_implied ,
71
+ new_loss
85
72
)
86
73
end
87
74
@@ -92,6 +79,39 @@ function update_observed(loss::SemLoss, new_observed; kwargs...)
92
79
return SemLoss (new_functions, loss. weights)
93
80
end
94
81
82
+
83
+ function replace_observed (
84
+ emodel:: SemEnsemble ;
85
+ column = :group ,
86
+ weights = nothing ,
87
+ kwargs... ,
88
+ )
89
+ kwargs = Dict {Symbol, Any} (kwargs... )
90
+ # allow for EnsembleParameterTable to be passed as specification
91
+ if haskey (kwargs, :specification ) && isa (kwargs[:specification ], EnsembleParameterTable)
92
+ kwargs[:specification ] = convert (Dict{Symbol, RAMMatrices}, kwargs[:specification ])
93
+ end
94
+ # allow for DataFrame with group variable "column" to be passed as new data
95
+ if haskey (kwargs, :data ) && isa (kwargs[:data ], DataFrame)
96
+ kwargs[:data ] = Dict (
97
+ group => select (
98
+ filter (
99
+ r -> r[column] == group,
100
+ kwargs[:data ]),
101
+ Not (column)) for group in emodel. groups)
102
+ end
103
+ # update each model for new data
104
+ models = emodel. sems
105
+ new_models = Tuple (
106
+ replace_observed (m; group_kwargs (g, kwargs)... ) for (m, g) in zip (models, emodel. groups)
107
+ )
108
+ return SemEnsemble (new_models... ; weights = weights, groups = emodel. groups)
109
+ end
110
+
111
+ function group_kwargs (g, kwargs)
112
+ return Dict (k => kwargs[k][g] for k in keys (kwargs))
113
+ end
114
+
95
115
# ###########################################################################################
96
116
# simulate data
97
117
# ###########################################################################################
0 commit comments