Skip to content

Commit

Permalink
Precompile fn in ZarrChain and borrow inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianopaz committed Dec 21, 2024
1 parent 3c80cd4 commit 8019cf5
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 5 deletions.
13 changes: 10 additions & 3 deletions pymc/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Mapping, MutableMapping, Sequence
from collections.abc import Callable, Mapping, MutableMapping, Sequence
from typing import Any

import arviz as az
Expand Down Expand Up @@ -91,10 +91,11 @@ def __init__(
vars: Sequence[TensorVariable] | None = None,
test_point: dict[str, np.ndarray] | None = None,
draws_per_chunk: int = 1,
fn: Callable | None = None,
):
if not _zarr_available:
raise RuntimeError("You must install zarr to be able to create ZarrChain instances")
super().__init__(name="zarr", model=model, vars=vars, test_point=test_point)
super().__init__(name="zarr", model=model, vars=vars, test_point=test_point, fn=fn)
self._step_method: BlockedStep | CompoundStep | None = None
self.unconstrained_variables = {
var.name for var in self.vars if is_transformed_name(var.name)
Expand Down Expand Up @@ -452,7 +453,12 @@ def init_trace(
)
self.vars = [var for var in vars if var.name in self.varnames]

self.fn = model.compile_fn(self.vars, inputs=model.value_vars, on_unused_input="ignore")
self.fn = model.compile_fn(
self.vars,
inputs=model.value_vars,
on_unused_input="ignore",
borrow_vars=True,
)

# Get variable shapes. Most backends will need this
# information.
Expand Down Expand Up @@ -528,6 +534,7 @@ def init_trace(
test_point=test_point,
stats_bijection=StatsBijection(step.stats_dtypes),
draws_per_chunk=self.draws_per_chunk,
fn=self.fn,
)
for _ in range(chains)
]
Expand Down
7 changes: 7 additions & 0 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1611,6 +1611,7 @@ def compile_fn(
inputs: Sequence[Variable] | None = None,
mode=None,
point_fn: bool = True,
borrow_vars: bool = False,
**kwargs,
) -> PointFunc | Function:
"""Compiles a PyTensor function.
Expand All @@ -1636,6 +1637,10 @@ def compile_fn(
if inputs is None:
inputs = inputvars(outs)

if borrow_vars:
inputs = [pytensor.In(v, borrow=True) for v in inputs]
outs = [pytensor.Out(v, borrow=True) for v in outs]

with self:
fn = compile(
inputs,
Expand All @@ -1645,6 +1650,8 @@ def compile_fn(
mode=mode,
**kwargs,
)
if borrow_vars:
fn.trust_input = True

if point_fn:
return PointFunc(fn)
Expand Down
6 changes: 4 additions & 2 deletions pymc/sampling/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,10 @@ def __init__(
zarr_chains: list[ZarrChain] | bytes | None = None,
zarr_chains_is_pickled: bool = False,
):
# For some strange reason, spawn multiprocessing doesn't copy the rng
# seed sequence, so we have to rebuild it from scratch
# Because of https://github.com/numpy/numpy/issues/27727, we can't send
# the rng instance to the child process because pickling (copying) looses
# the seed sequence state information. For this reason, we send a
# RandomGeneratorState instead.
rng = random_generator_from_state(rng_state)
self._msg_pipe = msg_pipe
self._step_method = step_method
Expand Down

0 comments on commit 8019cf5

Please sign in to comment.