BUG: blackjax sampler gives Incorrect output dtype for return value #0: Expected: int64, Actual: int32 #7593

Open
mvds314 opened this issue Nov 27, 2024 · 5 comments

mvds314 commented Nov 27, 2024

Describe the issue:

When I use the blackjax backend, I get dtype errors.

The code runs fine with the pymc and nutpie samplers.

Reproducible code example:

import numpy as np
import pymc as pm

from pymc import HalfCauchy, Model, Normal, sample

if __name__ == "__main__":
    print(f"Running on PyMC v{pm.__version__}")
    RANDOM_SEED = 8927
    rng = np.random.default_rng(RANDOM_SEED)
    y = 1 + rng.normal(scale=0.5, size=200)
    with Model() as model:
        sigma = HalfCauchy("sigma", beta=10)
        mu = Normal("mu", mu=0, sigma=1)
        _ = Normal("y", mu=mu, sigma=sigma, observed=y)
        idata = sample(3000, progressbar=True, nuts_sampler="blackjax")

Error message:

XlaRuntimeError: INTERNAL: Compute error: CpuCallback error: Traceback (most recent call last):
  File "C:\.....\Lib\site-packages\jax\_src\interpreters\mlir.py", line 2781, in _wrapped_callback
RuntimeError: Incorrect output dtype for return value #0: Expected: int64, Actual: int32
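The traceback points at JAX's host-callback wrapper (`_wrapped_callback` in `mlir.py`). When compiled code calls back into Python, each value the Python side returns is compared against the shape and dtype declared when the program was traced, and a mismatch raises this `RuntimeError`. The NumPy sketch below mimics that comparison to show the failure mode: a callback declared to return int64 hands back an int32 array. `check_callback_outputs` is a hypothetical helper written for illustration. It is not JAX's implementation, and the thread does not identify which callback fails in this report.

```python
import numpy as np


def check_callback_outputs(outputs, expected_dtypes):
    """Hypothetical re-creation of a host-callback result check.

    Raises RuntimeError if any returned array's dtype differs from the
    dtype declared for that output position.
    """
    for i, (out, expected) in enumerate(zip(outputs, expected_dtypes)):
        actual = np.asarray(out).dtype
        if actual != np.dtype(expected):
            raise RuntimeError(
                f"Incorrect output dtype for return value #{i}: "
                f"Expected: {np.dtype(expected)}, Actual: {actual}"
            )
    return outputs


# The compiled program declared a 64-bit integer result...
declared = [np.int64]

# ...but the Python side returned a 32-bit integer array.
returned = [np.array([0, 1, 2], dtype=np.int32)]

try:
    check_callback_outputs(returned, declared)
except RuntimeError as err:
    print(err)  # Incorrect output dtype for return value #0: Expected: int64, Actual: int32

# Casting explicitly to the declared dtype satisfies the check.
fixed = [np.asarray(returned[0], dtype=np.int64)]
check_callback_outputs(fixed, declared)
```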

PyMC version information:

Platform: Windows 11 (WinPython distribution), Python 3.12.6, PyMC v5.18.2, blackjax 1.2.4

Context for the issue:

No response

mvds314 added the bug label Nov 27, 2024

welcome bot commented Nov 27, 2024

🎉 Welcome to PyMC! 🎉 We're really excited to have your input into the project! 💖

If you haven't done so already, please make sure you check out our Contributing Guidelines and Code of Conduct.

ricardoV94 (Member) commented Nov 27, 2024

Probably a blackjax or jax+windows problem. CC @junpenglao
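One way to probe the Windows angle is to check which integer dtype NumPy picks by default on the affected machine. This is an assumption about the cause, not something established in this thread: with NumPy 1.x on Windows the default integer is 32-bit (it follows the C `long`), while NumPy 2.x uses a 64-bit default on 64-bit Windows. An array built without an explicit dtype on the Python side of a callback could therefore come back as int32. `default_int_dtype` is a hypothetical helper.

```python
import platform

import numpy as np


def default_int_dtype():
    """Return the dtype NumPy picks for integer data when none is given."""
    return np.array([0]).dtype


print("OS:", platform.system())
print("NumPy:", np.__version__)
print("Default integer dtype:", default_int_dtype())

# Being explicit sidesteps the platform default entirely.
explicit = np.arange(3, dtype=np.int64)
print("Explicit dtype:", explicit.dtype)  # int64 on every platform
```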

mvds314 (Author) commented Nov 27, 2024

I get the same issue when using numpyro. However, using numpyro natively, i.e., not through the PyMC interface, does work. So my guess is that it has to do with how PyMC interacts with numpyro/blackjax.

ricardoV94 (Member) commented

Might have to do with float64. PyMC uses it by default. You can try to run this code at the very top of your script/notebook to set it to float32:

import pytensor
pytensor.config.floatX = "float32"
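If editing the script is inconvenient, the same option can usually be supplied through the `PYTENSOR_FLAGS` environment variable (comma-separated `name=value` pairs), which PyTensor reads when it is first imported; it therefore has to be set before `pymc` is imported. The helper below is a hypothetical sketch that adds `floatX=float32` while keeping any flags that are already present.

```python
import os


def set_pytensor_flag(name, value, env=os.environ):
    """Hypothetical helper: set one entry in PYTENSOR_FLAGS, keeping the rest.

    PYTENSOR_FLAGS holds comma-separated name=value pairs; any earlier
    setting of `name` is dropped so the new value is the only one present.
    """
    current = env.get("PYTENSOR_FLAGS", "")
    kept = [
        item
        for item in current.split(",")
        if item.strip() and item.split("=", 1)[0].strip() != name
    ]
    kept.append(f"{name}={value}")
    env["PYTENSOR_FLAGS"] = ",".join(kept)
    return env["PYTENSOR_FLAGS"]


# Must run before pytensor (and therefore pymc) is imported in this process.
set_pytensor_flag("floatX", "float32")

# Only now import the libraries that read the flag, e.g.:
# import pymc as pm
```

Setting the variable in the shell instead (for example `set PYTENSOR_FLAGS=floatX=float32` in a Windows command prompt) should have the same effect without touching the code.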

mvds314 (Author) commented Nov 28, 2024

This solves the problem for me!

I do get a warning, though, which I seem to be able to ignore. Enabling jax_enable_x64, as the warning suggests, doesn't get rid of it.

C:...\Lib\site-packages\jax\_src\numpy\array_methods.py:118: UserWarning: Explicitly requested dtype float64 requested in astype is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  return lax_numpy.astype(self, dtype, copy=copy, device=device)
Running window adaptation
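The warning reflects how JAX treats 64-bit types: unless `jax_enable_x64` is on, 64-bit dtypes are canonicalized to their 32-bit counterparts, so an `astype(float64)` request produces float32 (JAX exposes this rule as `jax.dtypes.canonicalize_dtype`). Read the same way, the original `Expected: int64` suggests 64-bit mode was active in the failing run. The NumPy sketch below reproduces the mapping; `canonicalize` is a hypothetical stand-in, not JAX's implementation. In JAX itself the option is normally set with `jax.config.update("jax_enable_x64", True)` at startup, before any arrays are created.

```python
import numpy as np

# 64-bit dtypes and the 32-bit types JAX falls back to when x64 is disabled.
_X32_FALLBACK = {
    np.dtype(np.int64): np.dtype(np.int32),
    np.dtype(np.uint64): np.dtype(np.uint32),
    np.dtype(np.float64): np.dtype(np.float32),
    np.dtype(np.complex128): np.dtype(np.complex64),
}


def canonicalize(dtype, enable_x64):
    """Hypothetical sketch of JAX-style dtype canonicalization."""
    dtype = np.dtype(dtype)
    if enable_x64:
        return dtype
    return _X32_FALLBACK.get(dtype, dtype)


for requested in (np.float64, np.int64, np.float32):
    print(
        np.dtype(requested),
        "-> x64 off:", canonicalize(requested, enable_x64=False),
        "| x64 on:", canonicalize(requested, enable_x64=True),
    )
```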
