Skip to content

Commit 28a0488

Browse files
authored
refactor (#580)
* refactor * remove logging * ref
1 parent 551dfcc commit 28a0488

20 files changed

+382
-420
lines changed

preliz/__init__.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
44
Tools to help you pick a prior
55
"""
6-
import logging
76
from os import path as os_path
87

98
from matplotlib import rcParams
@@ -18,13 +17,6 @@
1817

1918
__version__ = "0.11.0"
2019

21-
_log = logging.getLogger("preliz")
22-
23-
if not logging.root.handlers:
24-
_log.setLevel(logging.INFO)
25-
if len(_log.handlers) == 0:
26-
handler = logging.StreamHandler()
27-
_log.addHandler(handler)
2820

2921
# Allow legend outside plot in maxent to be included when saving a figure
3022
# We may want to make this more explicit by having preliz.rcParams
@@ -37,4 +29,4 @@
3729
style.core.reload_library()
3830

3931
# clean namespace
40-
del logging, os_path, rcParams, _preliz_style_path, _log
32+
del os_path, rcParams, _preliz_style_path

preliz/internal/distribution_helper.py

+3-26
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def from_precision(precision):
1212

1313

1414
def to_precision(sigma):
15-
precision = 1 / sigma**2
15+
precision = 1 / (eps + sigma**2)
1616
return precision
1717

1818

@@ -148,38 +148,15 @@ def num_kurtosis(dist):
148148
}
149149

150150

151-
def get_distributions(dist_names=None, exclude=None):
151+
def get_distributions(dist_names=None):
152152

153153
if dist_names is None:
154154
all_distributions = modules["preliz.distributions"].__all__
155155
else:
156156
all_distributions = dist_names
157157

158-
if exclude is None:
159-
exclude = []
160-
if exclude == "auto":
161-
exclude = [
162-
"Beta",
163-
"BetaScaled",
164-
"Triangular",
165-
"TruncatedNormal",
166-
"Uniform",
167-
"VonMises",
168-
"Categorical",
169-
"DiscreteUniform",
170-
"HyperGeometric",
171-
"zeroInflatedBinomial",
172-
"ZeroInflatedNegativeBinomial",
173-
"ZeroInflatedPoisson",
174-
"MvNormal",
175-
"Mixture",
176-
]
177-
178158
distributions = []
179159
for a_dist in all_distributions:
180160
dist = getattr(modules["preliz.distributions"], a_dist)()
181-
if dist.__class__.__name__ not in exclude:
182-
distributions.append(dist)
183-
if exclude:
184-
return exclude, distributions
161+
distributions.append(dist)
185162
return distributions

preliz/internal/logging.py

-13
This file was deleted.

preliz/internal/parser.py

-208
This file was deleted.

preliz/internal/plot_helper.py

+1-60
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,14 @@
55
try:
66
from IPython import get_ipython
77
from ipywidgets import FloatSlider, IntSlider, FloatText, IntText, Checkbox, ToggleButton
8-
from pymc import sample_prior_predictive
98
except ImportError:
109
pass
1110

1211
import numpy as np
1312
import matplotlib.pyplot as plt
1413
from matplotlib import _pylab_helpers, get_backend
1514
from matplotlib.ticker import MaxNLocator
16-
from .logging import disable_pymc_sampling_logs
17-
from .narviz import hdi, kde
15+
from preliz.internal.narviz import hdi, kde
1816

1917

2018
def plot_pointinterval(distribution, interval="hdi", levels=None, rotated=False, ax=None):
@@ -425,63 +423,6 @@ def looper(*args, **kwargs):
425423
return looper
426424

427425

428-
def bambi_plot_decorator(func, iterations, kind_plot, references, plot_func):
429-
def looper(*args, **kwargs):
430-
kwargs.pop("__resample__")
431-
x_min = kwargs.pop("__x_min__")
432-
x_max = kwargs.pop("__x_max__")
433-
if not kwargs.pop("__set_xlim__"):
434-
x_min = None
435-
x_max = None
436-
auto = True
437-
else:
438-
auto = False
439-
440-
model = func(*args, **kwargs)
441-
model.build()
442-
with disable_pymc_sampling_logs():
443-
idata = model.prior_predictive(iterations)
444-
results = (
445-
idata["prior_predictive"].stack(sample=("chain", "draw"))[model.response_name].values.T
446-
)
447-
448-
_, ax = plt.subplots()
449-
ax.set_xlim(x_min, x_max, auto=auto)
450-
if plot_func is None:
451-
plot_repr(results, kind_plot, references, iterations, ax)
452-
else:
453-
plot_func(results, ax)
454-
455-
return looper
456-
457-
458-
def pymc_plot_decorator(func, iterations, kind_plot, references, plot_func):
459-
def looper(*args, **kwargs):
460-
kwargs.pop("__resample__")
461-
x_min = kwargs.pop("__x_min__")
462-
x_max = kwargs.pop("__x_max__")
463-
if not kwargs.pop("__set_xlim__"):
464-
x_min = None
465-
x_max = None
466-
auto = True
467-
else:
468-
auto = False
469-
with func(*args, **kwargs) as model:
470-
obs_name = model.observed_RVs[0].name
471-
with disable_pymc_sampling_logs():
472-
idata = sample_prior_predictive(samples=iterations)
473-
results = idata["prior_predictive"].stack(sample=("chain", "draw"))[obs_name].values.T
474-
475-
_, ax = plt.subplots()
476-
ax.set_xlim(x_min, x_max, auto=auto)
477-
if plot_func is None:
478-
plot_repr(results, kind_plot, references, iterations, ax)
479-
else:
480-
plot_func(results, ax)
481-
482-
return looper
483-
484-
485426
def plot_repr(results, kind_plot, references, iterations, ax):
486427
alpha = max(0.01, 1 - iterations * 0.009)
487428

0 commit comments

Comments
 (0)