Skip to content

Commit

Permalink
fix: add lock for pymc init point func
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Dec 16, 2024
1 parent 04af51c commit 9a1da91
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions python/nutpie/compile_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from importlib.util import find_spec
from math import prod
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, cast
import threading

import numpy as np
import pandas as pd
Expand All @@ -30,7 +31,7 @@ def intrinsic(f):
from pytensor.tensor import TensorVariable, Variable


def rv_dict_to_flat_array_wrapper(
def _rv_dict_to_flat_array_wrapper(
fn: Callable[[SeedType], dict[str, np.ndarray]],
names: list[str],
shapes: list[tuple[int]],
Expand Down Expand Up @@ -69,7 +70,7 @@ def seeded_array_fn(seed: SeedType = None):
for name, shape in zip(names, shapes, strict=True):
initial_value = initial_value_dict[name]
n = int(np.prod(initial_value.shape))
if initial_value.shape != shape:
if initial_value.shape != tuple(shape):
raise ValueError(
f"Size of initial value for {name} is {initial_value.shape}, "
f"expected {shape}"
Expand Down Expand Up @@ -498,6 +499,8 @@ def compile_pymc_model(
return_transformed=True,
)

initial_point_fn = _wrap_with_lock(initial_point_fn)

if backend.lower() == "numba":
if gradient_backend == "jax":
raise ValueError("Gradient backend cannot be jax when using numba backend")
Expand All @@ -515,6 +518,17 @@ def compile_pymc_model(
raise ValueError(f"Backend must be one of numba and jax. Got {backend}")


def _wrap_with_lock(func: Callable) -> Callable:
lock = threading.Lock()

@wraps(func)
def wrapper(*args, **kwargs):
with lock:
return func(*args, **kwargs)

return wrapper


def _compute_shapes(model) -> dict[str, tuple[int, ...]]:
import pytensor
from pymc.initial_point import make_initial_point_fn
Expand Down Expand Up @@ -645,7 +659,7 @@ def _make_functions(

num_free_vars = count

initial_point_fn = rv_dict_to_flat_array_wrapper(
initial_point_fn = _rv_dict_to_flat_array_wrapper(
pymc_initial_point_fn, names=joined_names, shapes=joined_shapes
)

Expand Down

0 comments on commit 9a1da91

Please sign in to comment.