Skip to content

Commit

Permalink
.wip force load
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Dec 5, 2023
1 parent 8c82b48 commit 334a211
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
5 changes: 5 additions & 0 deletions pymc_experimental/tests/utils/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@ def test_cache_sampling(tmpdir):
assert len(os.listdir(tmpdir)) == 0

prior1, prior2 = (cache_prior(samples=5) for _ in range(2))
prior3 = cache_sampling(pm.sample_prior_predictive, dir=tmpdir, force_sample=True)(
samples=5
)
assert len(os.listdir(tmpdir)) == 1
assert prior1.prior["x"].mean() == prior2.prior["x"].mean()
assert prior2.prior["x"].mean() != prior3.prior["x"].mean()
assert prior2.prior_predictive["y"].mean() != prior3.prior_predictive["y"].mean()

post1, post2 = (cache_post(tune=5, draws=5, progressbar=False) for _ in range(2))
assert len(os.listdir(tmpdir)) == 2
Expand Down
6 changes: 5 additions & 1 deletion pymc_experimental/utils/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def cache_sampling(
sampling_fn: Literal[sample, sample_prior_predictive, sample_posterior_predictive],
dir: str = "",
force_sample: bool = False,
force_load: bool = True,
) -> Callable:
"""Cache the result of PyMC sampling.
Expand All @@ -88,6 +89,7 @@ def cache_sampling(
The directory where the results should be saved or retrieved from. Defaults to working directory.
force_sample: bool, Optional
Whether to force sampling even if cache is found. Defaults to False.
force_load:
Returns
-------
Expand Down Expand Up @@ -163,7 +165,9 @@ def wrapped_sampling_fn(*args, model=None, random_seed=None, **kwargs):

if not force_sample and os.path.exists(file_path):
print("Cache hit! Returning stored result", file=sys.stdout)
idata_out = az.from_netcdf(file_path)
idata_out: az.InferenceData = az.from_netcdf(file_path)
if force_load:
idata_out.load()

else:
idata_out = sampling_fn(*args, **kwargs, model=model, random_seed=random_seed)
Expand Down

0 comments on commit 334a211

Please sign in to comment.