From 596a27c04397f8eeb4a7f7500e72811ea1b7ec45 Mon Sep 17 00:00:00 2001 From: Luciano Paz Date: Sat, 21 Dec 2024 11:19:28 +0100 Subject: [PATCH] Precompile fn in ZarrChain and borrow inputs --- pymc/backends/zarr.py | 13 ++++++++++--- pymc/model/core.py | 7 +++++++ pymc/sampling/parallel.py | 6 ++++-- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/pymc/backends/zarr.py b/pymc/backends/zarr.py index 127ce7bdd5e..179c07bcc08 100644 --- a/pymc/backends/zarr.py +++ b/pymc/backends/zarr.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Mapping, MutableMapping, Sequence +from collections.abc import Callable, Mapping, MutableMapping, Sequence from typing import Any import arviz as az @@ -85,8 +85,9 @@ def __init__( vars: Sequence[TensorVariable] | None = None, test_point: dict[str, np.ndarray] | None = None, draws_per_chunk: int = 1, + fn: Callable | None = None, ): - super().__init__(name="zarr", model=model, vars=vars, test_point=test_point) + super().__init__(name="zarr", model=model, vars=vars, test_point=test_point, fn=fn) self._step_method: BlockedStep | CompoundStep | None = None self.unconstrained_variables = { var.name for var in self.vars if is_transformed_name(var.name) @@ -442,7 +443,12 @@ def init_trace( ) self.vars = [var for var in vars if var.name in self.varnames] - self.fn = model.compile_fn(self.vars, inputs=model.value_vars, on_unused_input="ignore") + self.fn = model.compile_fn( + self.vars, + inputs=model.value_vars, + on_unused_input="ignore", + borrow_vars=True, + ) # Get variable shapes. Most backends will need this # information. 
@@ -518,6 +524,7 @@ def init_trace( test_point=test_point, stats_bijection=StatsBijection(step.stats_dtypes), draws_per_chunk=self.draws_per_chunk, + fn=self.fn, ) for _ in range(chains) ] diff --git a/pymc/model/core.py b/pymc/model/core.py index 99711e566ed..aababc89265 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -1611,6 +1611,7 @@ def compile_fn( inputs: Sequence[Variable] | None = None, mode=None, point_fn: bool = True, + borrow_vars: bool = False, **kwargs, ) -> PointFunc | Function: """Compiles a PyTensor function. @@ -1636,6 +1637,10 @@ def compile_fn( if inputs is None: inputs = inputvars(outs) + if borrow_vars: + inputs = [pytensor.In(v, borrow=True) for v in inputs] + outs = [pytensor.Out(v, borrow=True) for v in outs] + with self: fn = compile( inputs, @@ -1645,6 +1650,8 @@ def compile_fn( mode=mode, **kwargs, ) + if borrow_vars: + fn.trust_input = True if point_fn: return PointFunc(fn) diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 794763e6e12..28e74d5e8ae 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -110,8 +110,10 @@ def __init__( zarr_chains: list[ZarrChain] | bytes | None = None, zarr_chains_is_pickled: bool = False, ): - # For some strange reason, spawn multiprocessing doesn't copy the rng - # seed sequence, so we have to rebuild it from scratch + # Because of https://github.com/numpy/numpy/issues/27727, we can't send + # the rng instance to the child process because pickling (copying) loses + # the seed sequence state information. For this reason, we send a + # RandomGeneratorState instead. rng = random_generator_from_state(rng_state) self._msg_pipe = msg_pipe self._step_method = step_method