Score-based iid sampling #1381
base: main
Conversation
Okay, everything should be implemented now. This actually became quite a big PR. A few more points:
Codecov Report
Attention: Patch coverage is …
Additional details and impacted files:

@@ Coverage Diff @@
##             main    #1381       +/-   ##
===========================================
- Coverage   89.31%   79.07%    -10.24%
===========================================
  Files         119      121         +2
  Lines        8779     9311       +532
===========================================
- Hits         7841     7363       -478
- Misses        938     1948      +1010

Flags with carried forward coverage won't be shown.
This is now basically done. For the review, one should probably wait until the other score branch and the type fixes are merged. The major changes are:
@manuelgloeckler #1370 has been merged into main.
Wow, great effort! 👏 Thanks for adding all those methods!
Looks great overall, but I was a bit confused by the class structure in score_fn_iid.py and added a couple of comments. Also, the tests can be refactored a bit.
Please note that you might have to rebase or merge with main again once #1404 is merged.
@@ -123,6 +125,9 @@ def sample(
        steps: Number of steps to take for the Euler-Maruyama method.
        ts: Time points at which to evaluate the diffusion process. If None, a
            linear grid between t_max and t_min is used.
+       iid_method: Which method to use for computing the score in the iid setting.
+           We currently support "fnpe", "gauss", "auto_gauss", "jac_gauss".
This is great! Can you please add a bit of detail about these methods, e.g., the full names? We should then also cover these options in the extended NPSE tutorial. Can you please also add a note under #1392?
These are essentially already the names of the methods, but I can add a bit of detail on what they do and what to watch out for (from an applied perspective).
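For context, a minimal sketch of how these options would be selected at sampling time, based on the `sample` signature in the diff below. The toy prior, placeholder simulator, and shapes are assumptions for illustration, not part of this PR:

```python
import torch
from torch.distributions import MultivariateNormal

from sbi.inference import NPSE

# Toy setup with a placeholder simulator: 2-dim Gaussian prior.
prior = MultivariateNormal(torch.zeros(2), torch.eye(2))
theta = prior.sample((1000,))
x = theta + 0.1 * torch.randn_like(theta)

inference = NPSE(prior=prior)
inference.append_simulations(theta, x).train()
posterior = inference.build_posterior()

# Ten iid trials of the same underlying parameter; the new `iid_method`
# argument selects how the per-trial scores are combined.
x_obs = torch.randn(10, 2)
samples = posterior.sample(
    (1000,),
    x=x_obs,
    iid_method="auto_gauss",  # or "fnpe", "gauss", "jac_gauss"
)
```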
@@ -138,7 +143,10 @@ def sample(

        x = self._x_else_default_x(x)
        x = reshape_to_batch_event(x, self.score_estimator.condition_shape)
        self.potential_fn.set_x(x, x_is_iid=True)
        is_iid = x.ndim > 1 and x.shape[0] > 1
Above, x is reshaped to "batch_event", so it will always be ndim > 1, no?
Good catch, I think so too.
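To illustrate the point with a small sketch (assuming, as the name suggests, that reshape_to_batch_event prepends a batch dimension):

```python
import torch

# After reshaping to (batch, *event_shape), ndim > 1 holds even for a
# single observation, so only the batch size distinguishes the iid case.
x_single = torch.randn(5).reshape(1, -1)  # shape (1, 5): one observation
x_multi = torch.randn(3, 5)               # shape (3, 5): three iid trials

for x in (x_single, x_multi):
    is_iid = x.ndim > 1 and x.shape[0] > 1
    print(tuple(x.shape), is_iid)  # (1, 5) False, then (3, 5) True
```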
@@ -176,6 +184,7 @@ def sample(
    def _sample_via_diffusion(
        self,
        sample_shape: Shape = torch.Size(),
+       x: Optional[Tensor] = None,
Why do we need x here? I think it is just taken internally from potential_fn.x_o.
I did not add that; it was part of the #1226 changes. But I can have a look.
@@ -244,6 +253,7 @@ def _sample_via_diffusion(
    def sample_via_ode(
        self,
        sample_shape: Shape = torch.Size(),
+       x: Optional[Tensor] = None,
Why do we need x here? I think it is just taken internally from potential_fn.x_o.
I did not add that; it was part of the #1226 changes. But I can have a look.
@@ -57,7 +58,8 @@ def __init__(
        score_estimator: ConditionalScoreEstimator,
        prior: Optional[Distribution],
        x_o: Optional[Tensor] = None,
-       iid_method: str = "iid_bridge",
+       iid_method: str = "auto_gauss",
+       iid_params: Optional[dict] = None,
dict vs Dict. Also, can we specify the types of the dict, e.g., Dict[str, float], or will it have different types?
Would go with Dict[str, Any]; the values can be anything from floats to tensors to strings.
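A minimal sketch of the suggested annotation; the example keys and values are hypothetical, just to show the heterogeneity:

```python
from typing import Any, Dict, Optional

import torch

# Per-method parameters are heterogeneous, hence Dict[str, Any].
iid_params: Optional[Dict[str, Any]] = None

# Hypothetical example mixing value types:
iid_params = {
    "precision": 2.0,              # float
    "prior_mean": torch.zeros(2),  # tensor
    "mode": "analytic",            # string
}
```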
    theta = prior.sample((num_simulations,))
    x = linear_gaussian(theta, likelihood_shift, likelihood_cov)

    score_estimator = inference.append_simulations(theta, x).train(
-       training_batch_size=100,
+       training_batch_size=200, max_num_epochs=400
Do we still need this increased num_epochs? Wasn't this fixed by the recent update to the convergence check?
Yeah, I will look into this. The tests did, however, fail with the new setting (although not by much; I could just loosen the tolerance of the checks a bit). In general, these tests will be slightly more sensitive: a small drop in approximation quality on a single trial will propagate across the iid trials.
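For reference, a self-contained sketch of the training call under discussion, assuming the sbi NPSE interface used by these tests; the prior and the stand-in simulator are placeholders:

```python
import torch
from torch.distributions import MultivariateNormal

from sbi.inference import NPSE

num_simulations = 1000
prior = MultivariateNormal(torch.zeros(2), torch.eye(2))
theta = prior.sample((num_simulations,))
x = theta + 0.5 * torch.randn_like(theta)  # stand-in for linear_gaussian

inference = NPSE(prior=prior)
# The diff raises the batch size and caps the epochs; whether the cap is
# still needed after the convergence fix is the open question here.
score_estimator = inference.append_simulations(theta, x).train(
    training_batch_size=200, max_num_epochs=400
)
```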
    check_c2st(
        samples,
        target_samples,
        alg=f"npse-vp-gaussian-2D-{iid_method}-{num_trial}iid-trials",
Add num_dim instead of the hard-coded 2D?
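That would make the label consistent across dimensionalities, e.g. (assuming num_dim, iid_method, and num_trial are the test's parameters):

```python
num_dim, iid_method, num_trial = 2, "auto_gauss", 8  # from the parametrization
alg = f"npse-vp-gaussian-{num_dim}D-{iid_method}-{num_trial}iid-trials"
```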
    check_c2st(
        samples, target_samples, alg=f"npse-vp-gaussian-1D-{num_trials}iid-trials"
    )

@pytest.mark.slow
This is essentially the same test as above but with prior=Uniform, no? So I suggest merging the two tests and just using pytest.mark.parametrize("prior", ("gaussian", "uniform", None)). Or am I missing something?
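A sketch of the suggested merge; the test name, num_dim default, and elided body are placeholders for the shared logic of the two existing tests:

```python
import pytest
import torch
from torch.distributions import MultivariateNormal, Uniform


@pytest.mark.slow
@pytest.mark.parametrize("prior_type", ("gaussian", "uniform", None))
def test_npse_iid_inference(prior_type, num_dim=2):
    # Build the prior from the parametrization; None exercises the
    # default-prior code path.
    if prior_type == "gaussian":
        prior = MultivariateNormal(torch.zeros(num_dim), torch.eye(num_dim))
    elif prior_type == "uniform":
        prior = Uniform(-2 * torch.ones(num_dim), 2 * torch.ones(num_dim))
    else:
        prior = None
    # ... shared simulation, training, and check_c2st logic goes here ...
```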
    ],
)
@pytest.mark.parametrize("d", [1, 2, 3])
def test_score_fn_iid_on_different_priors(sde_type, iid_method, d):
Please add a docstring with a basic explanation.
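For example, something along these lines; the exact wording is a suggestion inferred from the test's parametrization:

```python
def test_score_fn_iid_on_different_priors(sde_type, iid_method, d):
    """Test the iid score functions across different prior families.

    Runs each iid method for the given SDE type and dimensionality d and
    checks the aggregated score against a reference for each prior.
    """
```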
"jac_gauss", | ||
], | ||
) | ||
@pytest.mark.parametrize("d", [1, 2, 3]) |
num_dim instead of d?
Completes the missing features based on score estimation #1226.