-
-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add function that caches sampling results
- Loading branch information
1 parent
150fb0f
commit 5711f96
Showing
3 changed files
with
210 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -49,6 +49,7 @@ Utils | |
|
||
spline.bspline_interpolation | ||
prior.prior_from_idata | ||
cache.cache_sampling | ||
|
||
|
||
Statespace Models | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import os | ||
|
||
import pymc as pm | ||
|
||
from pymc_experimental.utils.cache import cache_sampling | ||
|
||
|
||
def test_cache_sampling(tmpdir): | ||
|
||
with pm.Model() as m: | ||
x = pm.Normal("x", 0, 1) | ||
y = pm.Normal("y", mu=x, observed=[0, 1, 2]) | ||
|
||
cache_prior = cache_sampling(pm.sample_prior_predictive, path=tmpdir) | ||
cache_post = cache_sampling(pm.sample, path=tmpdir) | ||
cache_pred = cache_sampling(pm.sample_posterior_predictive, path=tmpdir) | ||
assert len(os.listdir(tmpdir)) == 0 | ||
|
||
prior1, prior2 = (cache_prior(samples=5) for _ in range(2)) | ||
assert len(os.listdir(tmpdir)) == 1 | ||
assert prior1.prior["x"].mean() == prior2.prior["x"].mean() | ||
|
||
post1, post2 = (cache_post(tune=5, draws=5, progressbar=False) for _ in range(2)) | ||
assert len(os.listdir(tmpdir)) == 2 | ||
assert post1.posterior["x"].mean() == post2.posterior["x"].mean() | ||
|
||
# Change model | ||
with pm.Model() as m: | ||
x = pm.Normal("x", 0, 1) | ||
y = pm.Normal("y", mu=x, observed=[0, 1, 2, 3]) | ||
|
||
post3 = cache_post(tune=5, draws=5, progressbar=False) | ||
assert len(os.listdir(tmpdir)) == 3 | ||
assert post3.posterior["x"].mean() != post1.posterior["x"].mean() | ||
|
||
pred1, pred2 = (cache_pred(trace=post3, progressbar=False) for _ in range(2)) | ||
assert len(os.listdir(tmpdir)) == 4 | ||
assert pred1.posterior_predictive["y"].mean() == pred2.posterior_predictive["y"].mean() | ||
assert "x" not in pred1.posterior_predictive | ||
|
||
# Change kwargs | ||
pred3 = cache_pred(trace=post3, progressbar=False, var_names=["x"]) | ||
assert len(os.listdir(tmpdir)) == 5 | ||
assert "x" in pred3.posterior_predictive |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
import hashlib | ||
import os | ||
import sys | ||
from typing import Literal | ||
|
||
import arviz as az | ||
import numpy as np | ||
from pymc import ( | ||
modelcontext, | ||
sample, | ||
sample_posterior_predictive, | ||
sample_prior_predictive, | ||
) | ||
from pymc.model.fgraph import fgraph_from_model | ||
from pytensor.compile import SharedVariable | ||
from pytensor.graph import Constant, FunctionGraph | ||
from pytensor.scalar import ScalarType | ||
from pytensor.tensor import TensorType | ||
from pytensor.tensor.random.type import RandomType | ||
from pytensor.tensor.type_other import NoneTypeT | ||
|
||
|
||
def hash_data(c): | ||
if isinstance(c.type, NoneTypeT): | ||
return "" | ||
if isinstance(c.type, (ScalarType, TensorType)): | ||
if isinstance(c, Constant): | ||
arr = c.data | ||
elif isinstance(c, SharedVariable): | ||
arr = c.get_value(borrow=True) | ||
arr_data = arr.view(np.uint8) if arr.size > 1 else arr.tobytes() | ||
return hashlib.sha1(arr_data).hexdigest() | ||
else: | ||
raise NotImplementedError(f"Hashing not implemented for type {c.type}") | ||
|
||
|
||
def get_name_and_props(obj): | ||
name = str(obj) | ||
props = str(getattr(obj, "_props", lambda: {})()) | ||
return name, props | ||
|
||
|
||
def hash_from_fg(fg: FunctionGraph) -> int: | ||
objects_to_hash = [] | ||
for node in fg.toposort(): | ||
objects_to_hash.append( | ||
( | ||
get_name_and_props(node.op), | ||
tuple(get_name_and_props(inp.type) for inp in node.inputs), | ||
tuple(get_name_and_props(out.type) for out in node.outputs), | ||
# Name is not a symbolic input in the fgraph representation, maybe it should? | ||
tuple(inp.name for inp in node.inputs if inp.name), | ||
tuple(out.name for out in node.outputs if out.name), | ||
) | ||
) | ||
objects_to_hash.append( | ||
tuple( | ||
hash_data(c) | ||
for c in node.inputs | ||
if ( | ||
isinstance(c, (Constant, SharedVariable)) | ||
# Ignore RNG values | ||
and not isinstance(c.type, RandomType) | ||
) | ||
) | ||
) | ||
str_hash = "\n".join(map(str, objects_to_hash)) | ||
return hashlib.sha1(str_hash.encode()).hexdigest() | ||
|
||
|
||
def cache_sampling( | ||
sampling_fn: Literal[sample, sample_prior_predictive, sample_posterior_predictive], | ||
path: str = "", | ||
force_sample: bool = False, | ||
): | ||
"""Cache the result of PyMC sampling. | ||
Parameter | ||
--------- | ||
sampling_fn: Callable | ||
Must be one of `pymc.sample`, `pymc.sample_prior_predictive` or `pymc.sample_posterior_predictive`. | ||
Positional arguments are disallowed. | ||
path: string, Optional | ||
The path where the results should be saved or retrieved from. Defaults to working directory. | ||
force_sample: bool, Optional | ||
Whether to force sampling even if cache is found. Defaults to False. | ||
Returns | ||
------- | ||
cached_sampling_fn: Callable | ||
Function that wraps the sampling_fn. When called, the wrapped function will look for a valid cached result. | ||
A valid cache requires the same: | ||
1. Model and data | ||
2. Sampling function | ||
3. Sampling kwargs, ignoring ``random_seed``, ``trace``, ``progressbar``, ``extend_inferencedata`` and ``compile_kwargs``. | ||
If o valid cache is found, sampling is bypassed altogether, unless ``force_sample=True``. | ||
Otherwise, sampling is performed and the result cached for future reuse. | ||
Caching is done on the basis of SHA-1 hashing, and there could be unlikely false positives. | ||
Examples | ||
-------- | ||
.. code-block:: python | ||
import pymc as pm | ||
from pymc_experimental.utils import cache_sampling | ||
with pm.Model() as m: | ||
x = pm.Normal("x", 0, 1) | ||
y = pm.Normal("y", mu=x, observed=[0, 1, 2]) | ||
idata = cache_sampling(pm.sample)() | ||
with m: | ||
idata = cache_sampling(pm.sample)() # Cache hit! Returning stored result | ||
""" | ||
allowed_fns = (sample, sample_prior_predictive, sample_posterior_predictive) | ||
if sampling_fn not in allowed_fns: | ||
raise ValueError(f"Cache sampling can only be used with {allowed_fns}") | ||
|
||
def wrapped_sampling_fn(*args, model=None, random_seed=None, **kwargs): | ||
if args: | ||
raise ValueError("Non-keyword arguments not allowed in cache_sampling") | ||
|
||
extend_inferencedata = kwargs.pop("extend_inferencedata", False) | ||
|
||
# Model hash | ||
model = modelcontext(model) | ||
fg, _ = fgraph_from_model(model) | ||
model_hash = hash_from_fg(fg) | ||
|
||
# Sampling hash | ||
sampling_hash_kwargs = kwargs.copy() | ||
sampling_hash_kwargs["sampling_fn"] = str(sampling_fn) | ||
sampling_hash_kwargs.pop("trace", None) | ||
sampling_hash_kwargs.pop("random_seed", None) | ||
sampling_hash_kwargs.pop("progressbar", None) | ||
sampling_hash_kwargs.pop("compile_kwargs", None) | ||
sampling_hash = str(sampling_hash_kwargs) | ||
|
||
file_name = hashlib.sha1((model_hash + sampling_hash).encode()).hexdigest() + ".nc" | ||
file_path = os.path.join(path, file_name) | ||
|
||
if not force_sample and os.path.exists(file_path): | ||
print("Cache hit! Returning stored result", file=sys.stdout) | ||
idata_out = az.from_netcdf(file_path) | ||
|
||
else: | ||
idata_out = sampling_fn(*args, **kwargs, model=model, random_seed=random_seed) | ||
|
||
if os.path.exists(file_path): | ||
os.remove(file_path) | ||
az.to_netcdf(idata_out, file_path) | ||
|
||
# We save inferencedata separately and extend if needed | ||
if extend_inferencedata: | ||
trace = kwargs["trace"] | ||
trace.extend(idata_out) | ||
idata_out = trace | ||
|
||
return idata_out | ||
|
||
return wrapped_sampling_fn |