-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Move Predictive handlers to observational module #550
Conversation
@@ -205,7 +162,11 @@ def forward( | |||
|
|||
# move data plate dimension to the left | |||
for name in reversed(plate_name_to_dim.keys()): | |||
log_weights = bind_leftmost_dim(log_weights, name) | |||
log_weights = torch.transpose( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the only non-renaming change in this PR. It is equivalent to the previous version with the definition of bind_leftmost_dim
inlined for this special case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have a slight preference for keeping it as bind_leftmost_dim
, as that seems a bit more transparent about what the code is doing.
If we do decide to remove it here is bind_leftmost_dim
used anywhere else anymore?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I inlined bind_leftmost_dim
here because it is now in observational.internals
and shouldn't be used outside observational
. I could have moved BatchedNMCLogMarginalLikelihood
, but I would like to keep that as an internal implementation detail of chirho.robust
for now.
bind_leftmost_dim
is still used inside chirho.observational.handlers.predictive.PredictiveFunctional
, which was moved in this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As noted I have a slight preference to revert the bind_leftmost_dim
change, but I'll happily defer to your preferences here.
I marked this as "request changes" just because the linter seems to be failing because of unused imports. Once that's resolved I'll approve.
@@ -205,7 +162,11 @@ def forward( | |||
|
|||
# move data plate dimension to the left | |||
for name in reversed(plate_name_to_dim.keys()): | |||
log_weights = bind_leftmost_dim(log_weights, name) | |||
log_weights = torch.transpose( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have a slight preference for keeping it as bind_leftmost_dim
, as that seems a bit more transparent about what the code is doing.
If we do decide to remove it here is bind_leftmost_dim
used anywhere else anymore?
Blocked by #549
This pure refactoring PR moves some of the utilities for making and batching predictive distributions out of
chirho.robust.handlers.predictive
andchirho.robust.internals.nmc
and into a newchirho.observational.handlers.predictive
. There are no changes to the code other than moving definitions to different files (with one small exception, noted below).