Add constant and observed data to nutpie idata #6943

Merged
merged 2 commits into from
Oct 12, 2023

Conversation

Y0dler
Contributor

@Y0dler Y0dler commented Oct 6, 2023

Temporary work-around for pymc-devs/nutpie#74.
With gracious assistance from Michael Osthege.

What is this PR about?
When sampling with the external NUTS sampler "nutpie", idata was updated with neither attributes nor constant or observed data. This is fixed using pre-existing methods.

Checklist

Bugfixes

  • Observed and constant data are now included in the inference data object resulting from nutpie sampling.

📚 Documentation preview 📚: https://pymc--6943.org.readthedocs.build/en/6943/

@welcome

welcome bot commented Oct 6, 2023

💖 Thanks for opening this pull request! 💖 The PyMC community really appreciates your time and effort to contribute to the project. Please make sure you have read our Contributing Guidelines and filled in our pull request template to the best of your ability.

@ricardoV94
Member

@Y0dler this is probably better done on the nutpie side

@michaelosthege
Member

michaelosthege commented Oct 6, 2023

> @Y0dler this is probably better done on the nutpie side

I disagree, because we'd introduce several dependencies on internal PyMC functions into the nutpie codebase.
Additionally, we'd have to change the signature of the nutpie.sample function to take a pm.Model, where right now it's totally PyMC-independent.
For the other external samplers this is also done in the PyMC codebase.

@ricardoV94
Member

nutpie wants to return a full InferenceData regardless of pymc so this would be rendered useless pretty soon

@michaelosthege
Member

> nutpie wants to return a full InferenceData regardless of pymc so this would be rendered useless pretty soon

At that point nutpie will probably require constant_data and observed_data as additional kwargs, right?
Then most of the code from this PR is still needed at the PyMC side - only the idata.add_groups and attrs would move to nutpie.

@ricardoV94
Member

ricardoV94 commented Oct 6, 2023

Calling nutpie.sample wants to return an InferenceData with constant_data, that's all. Regardless of whether it was given a PyMC or Stan model I guess, although I'm not too worried about the Stan side. CC @aseyboldt

@codecov

codecov bot commented Oct 6, 2023

Codecov Report

Merging #6943 (776ff03) into main (6f4a040) will decrease coverage by 0.05%.
The diff coverage is 25.00%.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #6943      +/-   ##
==========================================
- Coverage   92.18%   92.14%   -0.05%     
==========================================
  Files         100      100              
  Lines       16847    16858      +11     
==========================================
+ Hits        15531    15533       +2     
- Misses       1316     1325       +9     
Files Coverage Δ
pymc/sampling/mcmc.py 87.69% <25.00%> (-1.73%) ⬇️

@aseyboldt
Member

Thank you @Y0dler, I think this looks good.
I also think the proper place would be nutpie, however. It has to access some internal pymc structures anyway, because it needs to build the cfunc to call from rust, and needs to handle shared data etc.

The nutpie API itself also directly takes a model as input as it is:

import nutpie
compiled = nutpie.compile_pymc_model(model)
nutpie.sample(compiled)

And if used like this, it also should store the constant data (even if it was changed using the with_data method on compiled).
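
A rough sketch of the behaviour described here, assuming a compiled-model object that keeps track of its data. `CompiledModelSketch`, `shared_data` and `sample_sketch` are hypothetical names made up for illustration and are not nutpie's actual classes or functions. The idea is only that the data is read at sampling time, so values changed through `with_data` end up in `constant_data`.

```python
from dataclasses import dataclass, field, replace
from typing import Any, Dict


@dataclass(frozen=True)
class CompiledModelSketch:
    """Hypothetical stand-in for a compiled model; not nutpie's real class."""

    shared_data: Dict[str, Any] = field(default_factory=dict)

    def with_data(self, **updates: Any) -> "CompiledModelSketch":
        # Return a new object so the original compiled model stays unchanged.
        return replace(self, shared_data={**self.shared_data, **updates})


def sample_sketch(compiled: CompiledModelSketch) -> Dict[str, Dict[str, Any]]:
    # Hypothetical sampler: reads the data when sampling starts, so any
    # update made through with_data is reflected in constant_data.
    return {"posterior": {}, "constant_data": dict(compiled.shared_data)}


compiled = CompiledModelSketch(shared_data={"x": [1, 2, 3]})
updated = compiled.with_data(x=[4, 5, 6])
idata = sample_sketch(updated)
```

Under these assumptions, `idata["constant_data"]["x"]` holds the updated values while `compiled` itself is left untouched.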

@michaelosthege
Member

> And if used like this, it also should store the constant data (even if it was changed using the with_data method on compiled).

I don't understand what you're suggesting. I checked compile_pymc.py but neither CompiledPyMCModel nor compile_pymc_model make any attempt at taking care of constant_data or observed_data.
One could pass compile_pymc_model(pmodel, constant_data=..., observed_data=...) for the nutpie.sample code to pass them to arviz.from_dict(...), but this would break if someone uses the .with_data() method.

The solution proposed in this PR fixes the bug for the vast majority of PyMC users, without introducing complexity of coupling between the packages. And it can easily be adapted/removed once the nutpie code has been refactored to gather constant_data and observed_data from arbitrarily updated compiled models.
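
The control flow of such a work-around can be sketched roughly as follows. This is a stdlib-only illustration under stated assumptions, not the code merged in this PR: plain dicts stand in for the PyMC model and the InferenceData object, and `sample_with_data_groups` is a hypothetical name.

```python
from typing import Any, Callable, Dict


def sample_with_data_groups(
    model: Dict[str, Dict[str, Any]],
    sampler: Callable[[], Dict[str, Any]],
) -> Dict[str, Dict[str, Any]]:
    """Hypothetical stand-in for the work-around; not PyMC's actual code."""
    # Gather the data groups up front: the external sampler never sees the model.
    constant_data = dict(model.get("constant_data", {}))
    observed_data = dict(model.get("observed_data", {}))

    # Run the external sampler; in this sketch it only returns posterior draws.
    idata = {"posterior": sampler()}

    # Attach the groups afterwards, skipping empty ones.
    if constant_data:
        idata["constant_data"] = constant_data
    if observed_data:
        idata["observed_data"] = observed_data
    return idata


toy_model = {
    "constant_data": {"x": [1.0, 2.0, 3.0]},
    "observed_data": {"y": [1.1, 1.9, 3.2]},
}
result = sample_with_data_groups(toy_model, lambda: {"mu": [0.9, 1.0, 1.1]})
```

The sampler callable stays independent of the model, which mirrors the decoupling argued for above. Only the wrapper knows about both.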

@ricardoV94
Member

ricardoV94 commented Oct 8, 2023

> One could pass compile_pymc_model(pmodel, constant_data=..., observed_data=...)

Why would one have to do that? All the information about constant/observed data is contained in the pymc model passed to compile_pymc_model. By that token one would have to call compile_pymc_model(pmodel, rvs=..., value_vars=..., transforms=...) just for sampling, which is obviously not true.

Also, nutpie already handles deterministics just fine, why couldn't it also handle observed and constant data in the returned InferenceData?

@michaelosthege
Member

Oh nutpie could totally do it! But the current control flow and its internal design make it very non-obvious how. At least I don't see how it could/should be done.

From the PyMC perspective this PR is a totally fine fix until it's taken care of on the nutpie side. And to be honest, the time @Y0dler and me spent on this PR cost our employer about 400 bucks, so there won't be any more heavy lifting on this issue from our side.

@ricardoV94
Member

ricardoV94 commented Oct 8, 2023

> And to be honest, the time @Y0dler and me spent on this PR cost our employer about 400 bucks, so there won't be any more heavy lifting on this issue from our side.

That doesn't sound like a fair argument to motivate design choices. The issue was moved to the nutpie repo and the PR was opened without any public discussion.

Anyway I am fine with the PR as a quick patch, with a comment to revert once the original issue is addressed.

@@ -293,8 +299,24 @@ def _sample_external_nuts(
"`idata_kwargs` are currently ignored by the nutpie sampler",
UserWarning,
)

# gather observed and constant data as nutpie.sample() has no access to the PyMC model
coords, dims = coords_and_dims_for_inferencedata(model)
Member

This can be done after sampling (if sampling fails, there's no point in doing this work in advance).

Please add a comment along the lines of # Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed

Member

It's a very cheap function call:

pymc/pymc/backends/arviz.py

Lines 103 to 112 in 6f4a040

def coords_and_dims_for_inferencedata(model: Model) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Parse PyMC model coords and dims format to one accepted by InferenceData."""
    coords = {
        cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
        for cname, cvals in model.coords.items()
        if cvals is not None
    }
    dims = {dname: list(dvals) for dname, dvals in model.named_vars_to_dims.items()}
    return coords, dims

And sometimes this step catches issues with coordinate value types, that's why I recommended to put it before sampling (just like gathering constant and observed data)
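
To show what the helper above does with a model's coords, here is a stdlib-only imitation of its filtering and normalisation logic. It is a sketch under assumptions: `list` stands in for `np.array` so the snippet runs without NumPy, `coords_and_dims_sketch` is a hypothetical name, and a `SimpleNamespace` replaces a real `pm.Model`.

```python
from types import SimpleNamespace
from typing import Any, Dict, Tuple


def coords_and_dims_sketch(model: Any) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Imitation of the helper quoted above; list() replaces np.array()."""
    coords = {
        # Tuples are converted to a mutable sequence; other values pass through.
        cname: list(cvals) if isinstance(cvals, tuple) else cvals
        for cname, cvals in model.coords.items()
        # Coordinates without values are dropped.
        if cvals is not None
    }
    dims = {dname: list(dvals) for dname, dvals in model.named_vars_to_dims.items()}
    return coords, dims


# Toy object exposing only the two attributes the helper reads.
toy_model = SimpleNamespace(
    coords={"city": ("Berlin", "Porto"), "obs_id": None},
    named_vars_to_dims={"mu": ("city",)},
)
coords, dims = coords_and_dims_sketch(toy_model)
```

Here `coords` keeps only `"city"`, because `"obs_id"` has no values, and `dims` maps `"mu"` to a list of dimension names. Both are plain dictionary comprehensions over attributes the model already holds.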

@michaelosthege michaelosthege merged commit cb64480 into pymc-devs:main Oct 12, 2023
20 of 21 checks passed
@welcome
Copy link

welcome bot commented Oct 12, 2023

Congrats on merging your first pull request! 🎉 We here at PyMC are proud of you! 💖 Thank you so much for your contribution 🎁
