BUG: blackjax sampler gives Incorrect output dtype for return value #0: Expected: int64, Actual: int32 #7593
Comments
Probably a blackjax or JAX-on-Windows problem. CC @junpenglao
I get the same issue when using numpyro. However, using numpyro natively (i.e., not through the PyMC interface) does work, so my guess is that it has something to do with how PyMC interacts with numpyro/blackjax.
This might have to do with float64, which PyMC uses by default. You can try running this code at the very top of your script/notebook to switch to float32:

import pytensor
pytensor.config.floatX = "float32"
This solves the problem for me! I do get a warning, though, which I can apparently just ignore. Using
Describe the issue:
When I use the blackjax backend, I get dtype errors.
The code runs fine with the pymc and nutpie samplers.
Reproducible code example:
Error message:
PyMC version information:
Platform: Windows 11 (WinPython distribution), Python 3.12.6, PyMC v5.18.2, blackjax 1.2.4
Context for the issue:
No response