-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add constant and observed data to nutpie idata #6943
Conversation
|
@Y0dler this is probably better done on the nutpie side |
I disagree, because we'd introduce several dependencies onto internal PyMC functions to the nutpie codebase. |
nutpie wants to return a full InferenceData regardless of pymc so this would be rendered useless pretty soon |
At that point nutpie will probably require |
Calling |
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #6943 +/- ##
==========================================
- Coverage 92.18% 92.14% -0.05%
==========================================
Files 100 100
Lines 16847 16858 +11
==========================================
+ Hits 15531 15533 +2
- Misses 1316 1325 +9
|
Thank you @Y0dler, I think this looks good. The nutpie API itself also directly takes a model as input as it is: import nutpie
compiled = nutpie.compiled_pymc_model(model)
nutpie.sample(compiled) And if used like this, it also should store the constant data (even if it was changed using the |
I don't understand what you're suggesting. I checked The solution proposed in this PR fixes the bug for the vast majority of PyMC users, without introducing complexity of coupling between the packages. And it can easily be adapted/removed once the |
Why would one have to do that? All the information about constant/observed data is contained in the pymc model passed to Also, nutpie already handles deterministics just fine, why couldn't it also handle observed and constant data in the returned InferenceData? |
Oh nutpie could totally do it! But the current control flow, and its internal design makes is very non-obvious how. At least I don't see how it could/should be done. From the PyMC perspective this PR is a totally fine fix until it's taken care of on the nupie side. And to be honest, the time @Y0dler and me spent on this PR cost our employer about 400 bucks, so there won't be any more heavy lifting on this issue from our side. |
That doesn't sound like a fair argument to motivate design choices. The issue was moved to the nutpie repo and the PR was opened without any public discussion. Anyway I am fine with the PR as a quick patch, with a comment to revert once the original issue is addressed. |
8d60b7d
to
7d10905
Compare
@@ -293,8 +299,24 @@ def _sample_external_nuts( | |||
"`idata_kwargs` are currently ignored by the nutpie sampler", | |||
UserWarning, | |||
) | |||
|
|||
# gather observed and constant data as nutpie.sample() has no access to the PyMC model | |||
coords, dims = coords_and_dims_for_inferencedata(model) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can be done after sampling (if sampling fails, there's no point in doing this work in advance).
Please add a comment along the lines of # Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a very cheap function call:
Lines 103 to 112 in 6f4a040
def coords_and_dims_for_inferencedata(model: Model) -> Tuple[Dict[str, Any], Dict[str, Any]]: | |
"""Parse PyMC model coords and dims format to one accepted by InferenceData.""" | |
coords = { | |
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals | |
for cname, cvals in model.coords.items() | |
if cvals is not None | |
} | |
dims = {dname: list(dvals) for dname, dvals in model.named_vars_to_dims.items()} | |
return coords, dims |
And sometimes this step catches issues with coordinate value types, that's why I recommended to put it before sampling (just like gathering constant and observed data)
Temporary work-around for pymc-devs/nutpie#74.
With gracious assistance from Michael Osthege.
What is this PR about?
When sampling with the external NUTS sampler "nutpie", idata was not updated with neither attributes nor constant or observed data. Using pre-existing methods, this was fixed.
Checklist
Bugfixes
📚 Documentation preview 📚: https://pymc--6943.org.readthedocs.build/en/6943/