Skip to content

Commit 551dfcc

Browse files
authored
Refactor ppe (#579)
* refactor ppe * lint
1 parent 84d301f commit 551dfcc

File tree

8 files changed

+204
-113
lines changed

8 files changed

+204
-113
lines changed

preliz/internal/optimization.py

+21-10
Original file line numberDiff line numberDiff line change
@@ -253,30 +253,41 @@ def interval_short(params):
253253

254254

255255
def optimize_pymc_model(
256-
fmodel, target, draws, prior, initial_guess, bounds, var_info, p_model, rng
256+
fmodel,
257+
target,
258+
num_draws,
259+
bounds,
260+
initial_guess,
261+
prior,
262+
preliz_model,
263+
transformed_var_info,
264+
rng,
257265
):
258-
for _ in range(400):
266+
for idx in range(401):
259267
# can we sample systematically from these and less random?
260268
# This should be more flexible and allow other targets than just
261-
# a preliz distribution
269+
# a PreliZ distribution
262270
if isinstance(target, list):
263-
obs = get_weighted_rvs(target, draws, rng)
271+
obs = get_weighted_rvs(target, num_draws, rng)
264272
else:
265-
obs = target.rvs(draws, random_state=rng)
273+
obs = target.rvs(num_draws, random_state=rng)
266274
result = minimize(
267275
fmodel,
268276
initial_guess,
269277
tol=0.001,
270278
method="SLSQP",
271-
args=(obs, var_info, p_model),
279+
args=(obs, transformed_var_info, preliz_model),
272280
bounds=bounds,
273281
)
274-
275282
optimal_params = result.x
283+
# To help minimize the effect of priors
284+
# We don't save the first result and instead we use it as the initial guess
285+
# for the next optimization
286+
# Updating the initial guess also helps to provide more spread samples
276287
initial_guess = optimal_params
277-
278-
for key, param in zip(prior.keys(), optimal_params):
279-
prior[key].append(param)
288+
if idx:
289+
for key, param in zip(prior.keys(), optimal_params):
290+
prior[key].append(param)
280291

281292
# convert to numpy arrays
282293
for key, value in prior.items():

preliz/internal/predictive_helper.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from .distribution_helper import get_distributions
1010

1111

12-
def back_fitting(model, subset, new_families=True):
12+
def back_fitting_ppa(model, subset, new_families=True):
1313
"""
1414
Use MLE to fit a subset of the prior samples to the marginal prior distributions
1515
"""

preliz/ppls/agnostic.py

+17-10
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from preliz.distributions import Gamma, Normal, HalfNormal
88
from preliz.unidimensional.mle import mle
99
from preliz.ppls.pymc_io import get_model_information, write_pymc_string
10-
from preliz.ppls.bambi_io import get_bmb_model_information, write_bambi_string
10+
from preliz.ppls.bambi_io import get_pymc_model, write_bambi_string
1111

1212
_log = logging.getLogger("preliz")
1313

@@ -41,10 +41,23 @@ def posterior_to_prior(model, idata, alternative=None, engine="auto"):
4141
"""
4242
_log.info(""""This is an experimental method under development, use with caution.""")
4343
engine = get_engine(model) if engine == "auto" else engine
44+
4445
if engine == "bambi":
45-
_, _, model_info, _, var_info2, *_ = get_bmb_model_information(model)
46+
model = get_pymc_model(model)
47+
48+
_, _, preliz_model, _, untransformed_var_info, *_ = get_model_information(model)
49+
50+
new_priors = back_fitting_idata(idata, preliz_model, alternative)
51+
52+
if engine == "bambi":
53+
new_model = write_bambi_string(new_priors, untransformed_var_info)
4654
elif engine == "pymc":
47-
_, _, model_info, _, var_info2, *_ = get_model_information(model)
55+
new_model = write_pymc_string(new_priors, untransformed_var_info)
56+
57+
return new_model
58+
59+
60+
def back_fitting_idata(idata, model_info, alternative):
4861
new_priors = {}
4962
posterior = idata.posterior.stack(sample=("chain", "draw"))
5063

@@ -66,10 +79,4 @@ def posterior_to_prior(model, idata, alternative=None, engine="auto"):
6679

6780
idx, _ = mle(dists, posterior[var].values, plot=False)
6881
new_priors[var] = dists[idx[0]]
69-
70-
if engine == "bambi":
71-
new_model = write_bambi_string(new_priors, var_info2)
72-
elif engine == "pymc":
73-
new_model = write_pymc_string(new_priors, var_info2)
74-
75-
return new_model
82+
return new_priors

preliz/ppls/bambi_io.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
1-
from preliz.ppls.pymc_io import get_model_information
1+
"""Functions to communicate with Bambi."""
22

33

4-
def get_bmb_model_information(model):
4+
def get_pymc_model(model):
55
if not model.built:
66
model.build()
77
pymc_model = model.backend.model
8-
return get_model_information(pymc_model)
8+
return pymc_model
99

1010

1111
def write_bambi_string(new_priors, var_info):
1212
"""
1313
Return a string with the new priors for the Bambi model.
1414
So the user can copy and paste, ideally with none to minimal changes.
1515
"""
16-
header = "{"
16+
header = "{\n"
1717
for key, value in new_priors.items():
1818
dist_name, dist_params = repr(value).split("(")
1919
dist_params = dist_params.rstrip(")")
2020
size = var_info[key][1]
2121
if size > 1:
22-
header += f'"{key}" : bmb.Prior("{dist_name}", {dist_params}, shape={size}), '
22+
header += f'"{key}" : bmb.Prior("{dist_name}", {dist_params}, shape={size}),\n'
2323
else:
24-
header += f'"{key}" : bmb.Prior("{dist_name}", {dist_params}), '
24+
header += f'"{key}" : bmb.Prior("{dist_name}", {dist_params}),\n'
2525

2626
header = header.rstrip(", ") + "}"
2727
return header

preliz/ppls/pymc_io.py

+45-22
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,29 @@
1717
from preliz.internal.distribution_helper import get_distributions
1818

1919

20-
def backfitting(prior, p_model, var_info2):
20+
def back_fitting_pymc(prior, preliz_model, untransformed_var_info):
2121
"""
2222
Fit the samples from prior into user provided model's prior.
2323
from the perspective of ppe "prior" is actually an approximated posterior
2424
but from the user's perspective it is its prior.
25-
We need to "backfitted" because we can not use arbitrary samples as priors.
25+
We need to "backfit" because we can not use arbitrary samples as priors.
2626
We need probability distributions.
2727
"""
2828
new_priors = {}
29-
for key, size_inf in var_info2.items():
29+
for key, size_inf in untransformed_var_info.items():
3030
if not size_inf[2]:
3131
size = size_inf[1]
3232
if size > 1:
3333
params = []
3434
for i in range(size):
3535
value = prior[f"{key}__{i}"]
36-
dist = p_model[key]
36+
dist = preliz_model[key]
3737
dist._fit_mle(value)
3838
params.append(dist.params)
3939
dist._parametrization(*[np.array(x) for x in zip(*params)])
4040
else:
4141
value = prior[key]
42-
dist = p_model[key]
42+
dist = preliz_model[key]
4343
dist._fit_mle(value)
4444

4545
new_priors[key] = dist
@@ -81,7 +81,7 @@ def get_pymc_to_preliz():
8181
return pymc_to_preliz
8282

8383

84-
def get_guess(model, free_rvs):
84+
def get_initial_guess(model, free_rvs):
8585
"""
8686
Get initial guess for optimization routine.
8787
"""
@@ -104,17 +104,32 @@ def get_guess(model, free_rvs):
104104

105105
def get_model_information(model): # pylint: disable=too-many-locals
106106
"""
107-
Get information from the PyMC model.
108-
109-
This needs some love. We even have a variable named var_info,
110-
and another one var_info2!
107+
Get information from a PyMC model.
108+
109+
Parameters
110+
----------
111+
model : a PyMC model
112+
113+
Returns
114+
-------
115+
bounds : a list of tuples with the support of each marginal distribution in the model
116+
prior : a dictionary with a key for each marginal distribution in the model and an empty
117+
list as value. This will be filled with the samples from a backfitting procedure.
118+
preliz_model : a dictionary with a key for each marginal distribution in the model and the
119+
corresponding PreliZ distribution as value
120+
transformed_var_info : a dictionary with a key for each transformed variable in the model
121+
and a tuple with the shape, size and the indexes of the non-constant parents as value
122+
untransformed_var_info : same as `transformed_var_info` but the keys are untransformed
123+
variable names
124+
num_draws : the number of observed samples
125+
free_rvs : a list with the free random variables in the model
111126
"""
112127

113128
bounds = []
114129
prior = {}
115-
p_model = {}
116-
var_info = {}
117-
var_info2 = {}
130+
preliz_model = {}
131+
transformed_var_info = {}
132+
untransformed_var_info = {}
118133
free_rvs = []
119134
pymc_to_preliz = get_pymc_to_preliz()
120135
rvs_to_values = model.rvs_to_values
@@ -128,13 +143,13 @@ def get_model_information(model): # pylint: disable=too-many-locals
128143
r_v.owner.op.name if r_v.owner.op.name else str(r_v.owner.op).split("RV", 1)[0].lower()
129144
)
130145
dist = copy(pymc_to_preliz[name])
131-
p_model[r_v.name] = dist
146+
preliz_model[r_v.name] = dist
132147
if nc_parents:
133148
idxs = [free_rvs.index(var_) for var_ in nc_parents]
134149
# the keys are the name of the (transformed) variable
135-
var_info[rvs_to_values[r_v].name] = (shape, size, idxs)
150+
transformed_var_info[rvs_to_values[r_v].name] = (shape, size, idxs)
136151
# the keys are the name of the (untransformed) variable
137-
var_info2[r_v.name] = (shape, size, idxs)
152+
untransformed_var_info[r_v.name] = (shape, size, idxs)
138153
else:
139154
free_rvs.append(r_v)
140155

@@ -147,13 +162,21 @@ def get_model_information(model): # pylint: disable=too-many-locals
147162
prior[r_v.name] = []
148163

149164
# the keys are the name of the (transformed) variable
150-
var_info[rvs_to_values[r_v].name] = (shape, size, nc_parents)
165+
transformed_var_info[rvs_to_values[r_v].name] = (shape, size, nc_parents)
151166
# the keys are the name of the (untransformed) variable
152-
var_info2[r_v.name] = (shape, size, nc_parents)
153-
154-
draws = model.observed_RVs[0].eval().size
155-
156-
return bounds, prior, p_model, var_info, var_info2, draws, free_rvs
167+
untransformed_var_info[r_v.name] = (shape, size, nc_parents)
168+
169+
num_draws = model.observed_RVs[0].eval().size
170+
171+
return (
172+
bounds,
173+
prior,
174+
preliz_model,
175+
transformed_var_info,
176+
untransformed_var_info,
177+
num_draws,
178+
free_rvs,
179+
)
157180

158181

159182
def write_pymc_string(new_priors, var_info):

preliz/predictive/ppa.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
plot_pp_mean,
2020
)
2121
from ..internal.parser import get_prior_pp_samples, from_preliz, from_bambi
22-
from ..internal.predictive_helper import back_fitting, select_prior_samples
22+
from ..internal.predictive_helper import back_fitting_ppa, select_prior_samples
2323
from ..distributions import Normal
2424
from ..distributions.distributions import Distribution
2525

@@ -386,7 +386,7 @@ def on_return_prior(self):
386386
if len(selected) > 4:
387387
subsample = select_prior_samples(selected, self.prior_samples, self.model)
388388

389-
string, _ = back_fitting(self.model, subsample, new_families=False)
389+
string, _ = back_fitting_ppa(self.model, subsample, new_families=False)
390390

391391
self.fig.clf()
392392
plt.text(0.05, 0.5, string, fontsize=14)

0 commit comments

Comments
 (0)