Score-based iid sampling #1381

Status: Open. Wants to merge 26 commits into base branch `main`.
- `5350e58` npse MAP (gmoss13, Jan 18, 2025)
- `58fd7b8` set default enable_Transform to True (gmoss13, Jan 18, 2025)
- `74cfad4` ruff formatting version change (gmoss13, Jan 18, 2025)
- `857399d` sampling via diffusion twice (gmoss13, Jan 28, 2025)
- `bcea468` batched sampling for score-based posteriors (gmoss13, Jan 29, 2025)
- `843ce7d` iid api integration (manuelgloeckler, Jan 30, 2025)
- `384f36f` new ruff (manuelgloeckler, Jan 30, 2025)
- `016c5a7` adding corrector back (manuelgloeckler, Jan 30, 2025)
- `9152d28` adding untesed GAUSS (manuelgloeckler, Jan 30, 2025)
- `df1f30d` All other methods (manuelgloeckler, Jan 30, 2025)
- `85bf355` reformat (manuelgloeckler, Jan 30, 2025)
- `5195417` messy version of simple gauss, API to sample method (manuelgloeckler, Feb 14, 2025)
- `5bcf427` jac method (still needs feasible Lambda projection to work) (manuelgloeckler, Feb 14, 2025)
- `09f0113` Only jac left (manuelgloeckler, Feb 17, 2025)
- `ad240b7` Formating and so on (manuelgloeckler, Feb 17, 2025)
- `44e08f2` Adding correct tests (manuelgloeckler, Feb 18, 2025)
- `d29ebf8` Add empirical support - but this doesnt work that well (manuelgloeckler, Feb 18, 2025)
- `b0b8b41` A bunch of auto marginalize and denois methods (manuelgloeckler, Feb 19, 2025)
- `0d0991b` general prior with GMM approx (manuelgloeckler, Feb 19, 2025)
- `f22467d` Bunch of reffactorings and customizability (manuelgloeckler, Feb 20, 2025)
- `6cbe5ae` Update API docstirngs (manuelgloeckler, Feb 20, 2025)
- `97a8dda` New tests, passes all now (manuelgloeckler, Feb 20, 2025)
- `5349747` Formating linting, type error form other PR (manuelgloeckler, Feb 20, 2025)
- `0539b62` Merge branch 'main' into 1226-score-based-iid (manuelgloeckler, Feb 21, 2025)
- `bfe6df2` Remove assert that IID data is not supported (manuelgloeckler, Feb 21, 2025)
- `fd0f964` Ruff linting (manuelgloeckler, Feb 21, 2025)
12 changes: 11 additions & 1 deletion sbi/inference/posteriors/score_posterior.py
@@ -105,6 +105,8 @@ def sample(
corrector_params: Optional[Dict] = None,
steps: int = 500,
ts: Optional[Tensor] = None,
iid_method: str = "auto_gauss",
iid_params: Optional[Dict] = None,
max_sampling_batch_size: int = 10_000,
sample_with: Optional[str] = None,
show_progress_bars: bool = True,
@@ -123,6 +125,9 @@ def sample(
steps: Number of steps to take for the Euler-Maruyama method.
ts: Time points at which to evaluate the diffusion process. If None, a
linear grid between t_max and t_min is used.
iid_method: Which method to use for computing the score in the iid setting.
We currently support "fnpe", "gauss", "auto_gauss", "jac_gauss".
Comment (Contributor): This is great! Can you please add a bit of detail about these methods, e.g., the full names? We should then also cover these options in the extended NPSE tutorial. Can you please add a note under #1392?

Reply (Author): These are essentially the names of the methods, but I can add a bit of detail on what they do and what to watch out for (from an applied perspective).
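The string names discussed here ("fnpe", "gauss", "auto_gauss", "jac_gauss") are resolved to implementations by a lookup. A minimal, hypothetical sketch of such a string-to-implementation registry (the method names come from the diff; the dispatch mechanics and function bodies below are assumptions, not sbi's actual code):

```python
from typing import Callable, Dict

# Hypothetical registry; the real implementations live in sbi's
# score_fn_iid module and are classes, not plain functions.
IID_METHODS: Dict[str, Callable] = {}

def register_iid_method(name: str) -> Callable:
    """Decorator that registers a score function under a string key."""
    def wrapper(fn: Callable) -> Callable:
        IID_METHODS[name] = fn
        return fn
    return wrapper

def get_iid_method(name: str) -> Callable:
    """Resolve an iid_method string, failing with a helpful error."""
    if name not in IID_METHODS:
        raise ValueError(
            f"Unknown iid_method '{name}'. Available: {sorted(IID_METHODS)}"
        )
    return IID_METHODS[name]

@register_iid_method("fnpe")
def fnpe_score(theta, x_o, time):  # placeholder body for illustration
    return "fnpe"

@register_iid_method("auto_gauss")
def auto_gauss_score(theta, x_o, time):  # placeholder body for illustration
    return "auto_gauss"
```

The registry style keeps `sample(..., iid_method="auto_gauss")` a plain string API while giving an informative error for unsupported names.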

iid_params: Additional parameters passed to the iid method.
max_sampling_batch_size: Maximum batch size for sampling.
sample_with: Deprecated - use `.build_posterior(sample_with=...)` prior to
`.sample()`.
@@ -138,7 +143,10 @@

x = self._x_else_default_x(x)
x = reshape_to_batch_event(x, self.score_estimator.condition_shape)
self.potential_fn.set_x(x, x_is_iid=True)
is_iid = x.ndim > 1 and x.shape[0] > 1
Comment (Contributor): Above, x is reshaped to "batch_event", so it will always be ndim>1, no?

Reply (Author): Good catch, I think so too.
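The batch-dimension check from the diff (`is_iid = x.ndim > 1 and x.shape[0] > 1`) can be sketched on a plain shape tuple; this illustrative helper is an assumption, not sbi code:

```python
from typing import Tuple

def is_iid_batch(shape: Tuple[int, ...]) -> bool:
    """Mirror of `x.ndim > 1 and x.shape[0] > 1`: data counts as IID
    only when a leading batch dimension holds more than one observation.
    As the reviewer notes, after reshaping to batch_event ndim is always
    > 1, so shape[0] > 1 is the operative condition."""
    return len(shape) > 1 and shape[0] > 1
```

A single observation of event shape `(5,)` or `(1, 5)` is not treated as IID, while three stacked observations `(3, 5)` are.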

self.potential_fn.set_x(
x, x_is_iid=is_iid, iid_method=iid_method, iid_params=iid_params
)

num_samples = torch.Size(sample_shape).numel()

@@ -176,6 +184,7 @@ def sample(
def _sample_via_diffusion(
self,
sample_shape: Shape = torch.Size(),
x: Optional[Tensor] = None,
Comment (Contributor): Why do we need x here? I think it is just taken internally from potential_fn.x_o.

Reply (Author): I did not add that; it was part of the #1226 changes. But I can have a look.

predictor: Union[str, Predictor] = "euler_maruyama",
corrector: Optional[Union[str, Corrector]] = None,
predictor_params: Optional[Dict] = None,
@@ -244,6 +253,7 @@ def sample_via_ode(
def sample_via_ode(
self,
sample_shape: Shape = torch.Size(),
x: Optional[Tensor] = None,
Comment (Contributor): Why do we need x here? I think it is just taken internally from potential_fn.x_o.

Reply (Author): I did not add that; it was part of the #1226 changes. But I can have a look.

) -> Tensor:
r"""Return samples from posterior distribution with probability flow ODE.

36 changes: 24 additions & 12 deletions sbi/inference/potentials/score_based_potential.py
@@ -11,6 +11,7 @@
from zuko.transforms import FreeFormJacobianTransform

from sbi.inference.potentials.base_potential import BasePotential
from sbi.inference.potentials.score_fn_iid import get_iid_method
from sbi.neural_nets.estimators.score_estimator import ConditionalScoreEstimator
from sbi.neural_nets.estimators.shape_handling import (
reshape_to_batch_event,
@@ -57,7 +58,8 @@ def __init__(
score_estimator: ConditionalScoreEstimator,
prior: Optional[Distribution],
x_o: Optional[Tensor] = None,
iid_method: str = "iid_bridge",
iid_method: str = "auto_gauss",
iid_params: Optional[dict] = None,
Comment (Contributor): dict vs Dict. Also, can we specify the types of the dict, e.g., Dict[str, float], or will it have different types?

Reply (Author): I would use Dict[str, Any]; the values can be anything from floats to tensors to strings.
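The typing pattern agreed on here, an optional `Dict[str, Any]` of heterogeneous keyword parameters, is commonly consumed with `**(params or {})` so that `None` is safe to unpack. A hedged sketch (function and key names below are illustrative, not sbi's API):

```python
from typing import Any, Dict, Optional

def configure_iid_method(
    method: str, iid_params: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Illustrative consumer of an optional heterogeneous params dict:
    values may be floats, strings, tensors, etc., hence Dict[str, Any]."""
    defaults: Dict[str, Any] = {"precision": 1e-3, "mode": "full"}
    # `or {}` turns None into an empty dict, so ** unpacking always works
    # and user-supplied keys override the defaults.
    return {**defaults, **(iid_params or {}), "method": method}
```

Calling with `iid_params=None` yields the defaults; passing `{"precision": 0.5}` overrides only that key.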

device: str = "cpu",
):
r"""Returns the score function for score-based methods.
@@ -66,30 +68,37 @@ def __init__(
score_estimator: The neural network modelling the score.
prior: The prior distribution.
x_o: The observed data at which to evaluate the posterior.
iid_method: Which method to use for computing the score. Currently, only
`iid_bridge` as proposed in Geffner et al. is implemented.
iid_method: Which method to use for computing the score in the iid setting.
We currently support "fnpe", "gauss", "auto_gauss", "jac_gauss".
iid_params: Parameters for the iid method, for arguments see ScoreFnIID.
device: The device on which to evaluate the potential.
"""
self.score_estimator = score_estimator
self.score_estimator.eval()
self.iid_method = iid_method
self.iid_params = iid_params
super().__init__(prior, x_o, device=device)

def set_x(
self,
x_o: Optional[Tensor],
x_is_iid: Optional[bool] = False,
iid_method: str = "auto_gauss",
iid_params: Optional[dict] = None,
Comment (Contributor): dict vs Dict, see comment above.

rebuild_flow: Optional[bool] = True,
):
"""
Set the observed data and whether it is IID.
Args:
x_o: The observed data.
x_is_iid: Whether the observed data is IID (if batch_dim>1).
rebuild_flow: Whether to save (overwrite) a low-tolerance flow model,
useful if the flow needs to be evaluated many times
(e.g. for MAP calculation).
Comment (Contributor, on lines +95 to +97): Beware: we had to do another PR to fix some issues with the flow rebuilding, and this signature changed there. You will have to rebase or merge again once #1404 is merged. Sorry about that!

"""
super().set_x(x_o, x_is_iid)
self.iid_method = iid_method
self.iid_params = iid_params
if rebuild_flow and self._x_o is not None:
# By default, we want a high-tolerance flow.
# This flow will be used mainly for MAP calculations, hence we want to save
@@ -172,10 +181,16 @@ def gradient(
input=theta, condition=self.x_o, time=time
)
else:
raise NotImplementedError(
"Score accumulation for IID data is not yet implemented."
assert self.prior is not None, "Prior is required for iid methods."

method_iid = get_iid_method(self.iid_method)
Comment (Contributor): I suggest renaming this to iid_method for name consistency.

# Always creating a new object every call is not efficient...
Comment (Contributor): I don't understand this comment. How are we preventing this here then?

Reply (Author): I will delete it. I now cache (using functools) the most expensive operations, which only need to be done once per x_o. So I think it's fine now, and I will just remove the comment (what remains is the standard Python object-creation overhead).
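The caching strategy the author describes can be sketched with functools: the per-observation precomputation is cached, so constructing a fresh score-function object on every call stays cheap. This is a hypothetical illustration; the quantities sbi actually caches differ:

```python
import functools

CALLS = {"count": 0}  # instrumentation to show the cache working

@functools.lru_cache(maxsize=8)
def expensive_setup(x_o_key: tuple) -> float:
    """Stand-in for an expensive precomputation that depends only on x_o.
    With lru_cache it runs once per distinct observation."""
    CALLS["count"] += 1
    return sum(v * v for v in x_o_key)

class ScoreFnSketch:
    """Cheap to construct; the heavy work is delegated to the cached
    function, so repeated object creation adds only Python-level overhead."""
    def __init__(self, x_o):
        # Tensors are not hashable; a real implementation would key the
        # cache on a stable identifier. A tuple suffices for illustration.
        self.precomputed = expensive_setup(tuple(x_o))
```

Constructing `ScoreFnSketch([1.0, 2.0])` twice triggers `expensive_setup` only once; the second construction hits the cache.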

score_fn_iid = method_iid(
self.score_estimator, self.prior, **(self.iid_params or {})
)

score = score_fn_iid(theta, self.x_o, time) # type: ignore
Comment (Contributor): Why is the type ignore needed here?

return score

def get_continuous_normalizing_flow(
@@ -217,9 +232,6 @@ def rebuild_flow(
x_density_estimator = reshape_to_batch_event(
self.x_o, event_shape=self.score_estimator.condition_shape
)
assert x_density_estimator.shape[0] == 1, (
"PosteriorScoreBasedPotential supports only x batchsize of 1`."
)

Comment (Contributor): Shouldn't this be `assert x_density_estimator.shape[0] == 1 or self.x_is_iid`? Or can we just allow shape 1 or larger and move the responsibility for checking this upstream to the posterior? Not sure. I believe this is already checked in the posterior.

Reply (Author): I just removed it, as it was blocking any iid methodology. I also think it is checked in the posterior (and if not, we should do it there). The score_fn_iid should reduce this batch dimension to 1 anyway, and (at least) the gauss methods should be applicable to ODE sampling (although not to flow matching).

flow = self.get_continuous_normalizing_flow(
condition=x_density_estimator, atol=atol, rtol=rtol, exact=exact
Expand Down