
Score-based iid sampling #1381

Open · wants to merge 26 commits into main

Conversation

@manuelgloeckler (Contributor) commented Jan 30, 2025

Completes the missing features for score estimation from #1226 (a minimal usage sketch follows the list below):

  • IID interface
  • IID util functions
  • FNPE
  • GAUSS
  • JAC
  • test
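A minimal end-to-end sketch of how the iid sampling is meant to be used. The import paths and trainer calls are assumptions based on the existing sbi API, and the iid_method values are the ones listed in the docstring diff reviewed below, so treat this as illustrative rather than definitive:

```python
import torch
from sbi.inference import NPSE
from sbi.utils import BoxUniform

# Toy Gaussian simulator around theta.
prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))
theta = prior.sample((5_000,))
x = theta + 0.5 * torch.randn_like(theta)

inference = NPSE(prior=prior)
score_estimator = inference.append_simulations(theta, x).train()
posterior = inference.build_posterior(score_estimator)

# Ten iid trials of the same underlying parameter; iid_method selects how
# the per-trial scores are composed ("fnpe", "gauss", "auto_gauss", "jac_gauss").
x_o = 0.5 * torch.randn(10, 2)
samples = posterior.sample((1_000,), x=x_o, iid_method="auto_gauss")
```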

@manuelgloeckler (Contributor, Author) commented Feb 17, 2025

Okay, everything should be implemented now. This actually became quite a big PR. A few more points:

  • Check if the batched Jacobian via torch.func.vmap actually works correctly (see the sketch after this list)
  • Check whether the above or other factors cause performance degradation in jac_gauss (this can be sensitive to how the network is preconditioned)
  • Add an API to pass hyperparameters to the IID method (and make iid_methods more customizable, i.e., auto_gauss)
  • Multivariate priors
  • General empirical-prior support for automatic denoising and marginalization (then auto_gauss should become the default)
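For the first checkbox, a quick sanity check of the torch.func.vmap-batched Jacobian against per-sample autograd Jacobians could look like this; the score function is a toy analytic stand-in, not the PR's estimator:

```python
import torch
from torch.func import jacrev, vmap

def score(theta: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Toy analytic score with known Jacobian -exp(-t) * I.
    return -theta * torch.exp(-t)

thetas = torch.randn(16, 3)  # batch of parameters
t = torch.tensor(0.5)

# Batched Jacobian d score / d theta: one (3, 3) matrix per batch element.
batched_jac = vmap(jacrev(score, argnums=0), in_dims=(0, None))(thetas, t)

# Reference: per-sample Jacobians via torch.autograd.
reference = torch.stack([
    torch.autograd.functional.jacobian(lambda th: score(th, t), th) for th in thetas
])
assert torch.allclose(batched_jac, reference, atol=1e-6)
```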


codecov bot commented Feb 17, 2025

Codecov Report

Attention: Patch coverage is 92.88991% with 31 lines in your changes missing coverage. Please review.

Project coverage is 79.07%. Comparing base (18f92b1) to head (fd0f964).
Report is 7 commits behind head on main.

Files with missing lines Patch % Lines
sbi/utils/score_utils.py 91.66% 16 Missing ⚠️
sbi/inference/potentials/score_fn_iid.py 92.53% 15 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1381       +/-   ##
===========================================
- Coverage   89.31%   79.07%   -10.24%     
===========================================
  Files         119      121        +2     
  Lines        8779     9311      +532     
===========================================
- Hits         7841     7363      -478     
- Misses        938     1948     +1010     
Flag Coverage Δ
unittests 79.07% <92.88%> (-10.24%) ⬇️

Flags with carried forward coverage won't be shown.

Files with missing lines Coverage Δ
sbi/inference/posteriors/score_posterior.py 81.18% <100.00%> (-15.83%) ⬇️
sbi/inference/potentials/score_based_potential.py 82.85% <100.00%> (-14.12%) ⬇️
sbi/samplers/score/correctors.py 98.18% <100.00%> (+46.00%) ⬆️
sbi/samplers/score/diffuser.py 90.74% <100.00%> (+5.83%) ⬆️
sbi/inference/potentials/score_fn_iid.py 92.53% <92.53%> (ø)
sbi/utils/score_utils.py 91.66% <91.66%> (ø)

... and 36 files with indirect coverage changes

@manuelgloeckler

This is now basically done. The review should probably wait until the other score branch and the type fixes are merged.

But the major changes are:

  • ScoreFnIID classes, which manage the score composition
  • ScoreUtil, which has a bunch of helpers for "automatic" marginalization and denoising of PyTorch distributions (i.e., what the user can pass as the prior). If there is no analytic solution (or the user does not pass a prior), it falls back to a rather good MoG approximation (see the sketch below for the analytic Gaussian case).
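For the analytic case, the key ingredient behind such denoising helpers is the conjugate-Gaussian posterior p(theta_0 | theta_t). A minimal univariate sketch, illustrative only and not the PR's actual helper:

```python
import torch
from torch.distributions import Normal

def gaussian_denoising_posterior(
    prior: Normal, theta_t: torch.Tensor, m_t: torch.Tensor, s_t: torch.Tensor
) -> Normal:
    """Analytic p(theta_0 | theta_t) when theta_t | theta_0 ~ N(m_t * theta_0, s_t**2)
    and theta_0 ~ prior. Standard conjugate-Gaussian algebra, for illustration only."""
    prior_var = prior.scale**2
    post_var = 1.0 / (1.0 / prior_var + m_t**2 / s_t**2)
    post_mean = post_var * (prior.loc / prior_var + m_t * theta_t / s_t**2)
    return Normal(post_mean, post_var.sqrt())

# Example: standard-normal prior, moderately noised observation.
post = gaussian_denoising_posterior(
    Normal(torch.tensor(0.0), torch.tensor(1.0)),
    theta_t=torch.tensor(0.8),
    m_t=torch.tensor(0.9),
    s_t=torch.tensor(0.5),
)
```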

@janfb linked an issue Feb 20, 2025 that may be closed by this pull request
@janfb (Contributor) commented Feb 20, 2025

@manuelgloeckler #1370 has been merged into main. Please merge main to resolve the conflicts showing up here.

Contributor
@janfb left a comment

Wow, great effort! 👏 Thanks for adding all those methods!

Looks great overall, but I was a bit confused by the class structure in score_fn_iid.py and added a couple of comments.
Also, the tests can be refactored a bit.

Please note that you might have to rebase or merge again with main once #1404 is merged.

@@ -123,6 +125,9 @@ def sample(
steps: Number of steps to take for the Euler-Maruyama method.
ts: Time points at which to evaluate the diffusion process. If None, a
linear grid between t_max and t_min is used.
iid_method: Which method to use for computing the score in the iid setting.
We currently support "fnpe", "gauss", "auto_gauss", "jac_gauss".
Contributor

This is great! Can you please add a bit of detail about these methods, e.g., the full names? We should then also cover these options in the extended NPSE tutorial. Can you please add a note under #1392?

Contributor Author

These are essentially the names of the methods, but I can add a bit of detail on what they do and what to watch out for (from an applied perspective).

@@ -138,7 +143,10 @@ def sample(

x = self._x_else_default_x(x)
x = reshape_to_batch_event(x, self.score_estimator.condition_shape)
self.potential_fn.set_x(x, x_is_iid=True)
is_iid = x.ndim > 1 and x.shape[0] > 1
Contributor

Above, x is reshaped to "batch_event", so it will always be ndim > 1, no?

Contributor Author

Good catch, I think so too.
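To illustrate the point (shapes assumed for illustration, not taken from the sbi internals):

```python
import torch

x_single = torch.randn(3)            # one observation with event_shape (3,)
x_batched = x_single.reshape(1, -1)  # what a reshape to (batch, *event_shape) yields
assert x_batched.ndim > 1            # always true after the reshape
is_iid = x_batched.shape[0] > 1      # False here: only one trial, so not iid
```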

@@ -176,6 +184,7 @@ def sample(
def _sample_via_diffusion(
self,
sample_shape: Shape = torch.Size(),
x: Optional[Tensor] = None,
Contributor

Why do we need x here? I think it is just taken internally from potential_fn.x_o.

Contributor Author

I did not add this; it was part of the #1226 changes. But I can have a look.

@@ -244,6 +253,7 @@ def _sample_via_diffusion(
def sample_via_ode(
self,
sample_shape: Shape = torch.Size(),
x: Optional[Tensor] = None,
Contributor

Why do we need x here? I think it is just taken internally from potential_fn.x_o.

Contributor Author

I did not add this; it was part of the #1226 changes. But I can have a look.

@@ -57,7 +58,8 @@ def __init__(
score_estimator: ConditionalScoreEstimator,
prior: Optional[Distribution],
x_o: Optional[Tensor] = None,
iid_method: str = "iid_bridge",
iid_method: str = "auto_gauss",
iid_params: Optional[dict] = None,
Contributor

dict vs Dict. Also, can we specify the types of the dict, e.g., Dict[str, float], or will it have different types?

Contributor Author

I would go with Dict[str, Any]; the values can range from floats to tensors to strings.
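For reference, a possible annotation with heterogeneous values; the keys below are hypothetical placeholders, not the PR's actual hyperparameters:

```python
from typing import Any, Dict, Optional

import torch

iid_params: Optional[Dict[str, Any]] = {
    "num_precision_samples": 100,  # int: hypothetical budget for precision estimation
    "prior_mean": torch.zeros(2),  # Tensor
    "denoising": "analytic",       # str: hypothetical mode flag
}
```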


theta = prior.sample((num_simulations,))
x = linear_gaussian(theta, likelihood_shift, likelihood_cov)

score_estimator = inference.append_simulations(theta, x).train(
training_batch_size=100,
training_batch_size=200, max_num_epochs=400
Contributor

Do we still need this increased num_epochs? Wasn't this fixed with the recent update on convergence?

Contributor Author

Yeah, I will look into this. The tests did, however, fail on the new setting (although not by much; I can just loosen the tolerance of the checks a bit). In general, these tests are slightly more sensitive: a small change in approximation performance on a single trial will propagate across the iid trials.

check_c2st(
samples,
target_samples,
alg=f"npse-vp-gaussian-2D-{iid_method}-{num_trial}iid-trials",
Contributor

Add num_dim to the alg string instead of the hard-coded 2D?

check_c2st(
samples, target_samples, alg=f"npse-vp-gaussian-1D-{num_trials}iid-trials"

@pytest.mark.slow
Contributor

This is essentially the same test as above but with prior=Uniform, no? So I suggest merging the two tests and just using pytest.mark.parametrize("prior", ("gaussian", "uniform", None)).

Or am I missing something?
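A sketch of the suggested merged parametrization; the helper and test names are illustrative, not the actual test-suite code:

```python
import pytest
import torch
from torch.distributions import Independent, MultivariateNormal, Uniform

def make_prior(prior_type, num_dim=2):
    # Illustrative helper mapping the parametrized string to a prior object.
    if prior_type == "gaussian":
        return MultivariateNormal(torch.zeros(num_dim), torch.eye(num_dim))
    if prior_type == "uniform":
        return Independent(Uniform(-2 * torch.ones(num_dim), 2 * torch.ones(num_dim)), 1)
    return None  # let the inference fall back to its default behavior

@pytest.mark.slow
@pytest.mark.parametrize("prior_type", ("gaussian", "uniform", None))
def test_npse_iid_gaussian_linear(prior_type):
    prior = make_prior(prior_type)
    # ... run NPSE with iid trials as in the existing tests, then check_c2st(...)
    assert prior is None or prior.sample((5,)).shape == (5, 2)
```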

],
)
@pytest.mark.parametrize("d", [1, 2, 3])
def test_score_fn_iid_on_different_priors(sde_type, iid_method, d):
Contributor

Please add a docstring with a basic explanation.

"jac_gauss",
],
)
@pytest.mark.parametrize("d", [1, 2, 3])
Contributor

Use num_dim here as well (instead of d).

Successfully merging this pull request may close these issues:

missing features and todos for score estimation

3 participants