Score-based iid sampling #1381
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -105,6 +105,8 @@ def sample( | |
corrector_params: Optional[Dict] = None, | ||
steps: int = 500, | ||
ts: Optional[Tensor] = None, | ||
iid_method: str = "auto_gauss", | ||
iid_params: Optional[Dict] = None, | ||
max_sampling_batch_size: int = 10_000, | ||
sample_with: Optional[str] = None, | ||
show_progress_bars: bool = True, | ||
|
@@ -123,6 +125,9 @@ def sample( | |
steps: Number of steps to take for the Euler-Maruyama method. | ||
ts: Time points at which to evaluate the diffusion process. If None, a | ||
linear grid between t_max and t_min is used. | ||
iid_method: Which method to use for computing the score in the iid setting. | ||
We currently support "fnpe", "gauss", "auto_gauss", "jac_gauss". | ||
iid_params: Additional parameters passed to the iid method. | ||
max_sampling_batch_size: Maximum batch size for sampling. | ||
sample_with: Deprecated - use `.build_posterior(sample_with=...)` prior to | ||
`.sample()`. | ||
|
@@ -138,7 +143,10 @@ def sample( | |
|
||
x = self._x_else_default_x(x) | ||
x = reshape_to_batch_event(x, self.score_estimator.condition_shape) | ||
- self.potential_fn.set_x(x, x_is_iid=True) | ||
is_iid = x.ndim > 1 and x.shape[0] > 1 | ||
Review comment: above, |
Reply: Good catch, I think so, too. |
||
self.potential_fn.set_x( | ||
x, x_is_iid=is_iid, iid_method=iid_method, iid_params=iid_params | ||
) | ||
|
||
num_samples = torch.Size(sample_shape).numel() | ||
|
||
|
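The `is_iid` check in this hunk treats `x` as a set of iid observations only when it has a leading batch dimension of size greater than one. A minimal sketch of that rule on plain shape tuples follows; the helper name `looks_iid` is hypothetical and not part of sbi, and the PR itself applies the check to a tensor that has already gone through `reshape_to_batch_event`.

```python
from typing import Tuple


def looks_iid(shape: Tuple[int, ...]) -> bool:
    """Mirror `x.ndim > 1 and x.shape[0] > 1` on a plain shape tuple.

    Hypothetical helper for illustration only; not an sbi function.
    """
    return len(shape) > 1 and shape[0] > 1


# One observation with a batch dimension of size 1: not iid.
single = looks_iid((1, 3))
# Five observations of a 3-dimensional event: treated as iid.
multiple = looks_iid((5, 3))
# An unbatched 1-D event has no batch dimension to inspect: not iid.
unbatched = looks_iid((3,))
```

Under this rule a single observation keeps the existing non-iid code path, so only genuinely batched `x` is routed to the iid score methods.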
@@ -176,6 +184,7 @@ def sample( | |
def _sample_via_diffusion( | ||
self, | ||
sample_shape: Shape = torch.Size(), | ||
x: Optional[Tensor] = None, | ||
Review comment: Why do we need |
Reply: I did not add that; this was part of the #1226 changes. But I can have a look. |
||
predictor: Union[str, Predictor] = "euler_maruyama", | ||
corrector: Optional[Union[str, Corrector]] = None, | ||
predictor_params: Optional[Dict] = None, | ||
|
@@ -244,6 +253,7 @@ def _sample_via_diffusion( | |
def sample_via_ode( | ||
self, | ||
sample_shape: Shape = torch.Size(), | ||
x: Optional[Tensor] = None, | ||
Review comment: Why do we need |
Reply: I did not add that; this was part of the #1226 changes. But I can have a look. |
||
) -> Tensor: | ||
r"""Return samples from posterior distribution with probability flow ODE. | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ | |
from zuko.transforms import FreeFormJacobianTransform | ||
|
||
from sbi.inference.potentials.base_potential import BasePotential | ||
from sbi.inference.potentials.score_fn_iid import get_iid_method | ||
from sbi.neural_nets.estimators.score_estimator import ConditionalScoreEstimator | ||
from sbi.neural_nets.estimators.shape_handling import ( | ||
reshape_to_batch_event, | ||
|
@@ -57,7 +58,8 @@ def __init__( | |
score_estimator: ConditionalScoreEstimator, | ||
prior: Optional[Distribution], | ||
x_o: Optional[Tensor] = None, | ||
- iid_method: str = "iid_bridge", | ||
iid_method: str = "auto_gauss", | ||
iid_params: Optional[dict] = None, | ||
Review comment: Would do Dict[str, Any]; the values can be anything from floats to tensors to strings. |
||
device: str = "cpu", | ||
): | ||
r"""Returns the score function for score-based methods. | ||
|
@@ -66,30 +68,37 @@ def __init__( | |
score_estimator: The neural network modelling the score. | ||
prior: The prior distribution. | ||
x_o: The observed data at which to evaluate the posterior. | ||
- iid_method: Which method to use for computing the score. Currently, only | ||
- `iid_bridge` as proposed in Geffner et al. is implemented. | ||
iid_method: Which method to use for computing the score in the iid setting. | ||
We currently support "fnpe", "gauss", "auto_gauss", "jac_gauss". | ||
iid_params: Parameters for the iid method, for arguments see ScoreFnIID. | ||
device: The device on which to evaluate the potential. | ||
""" | ||
self.score_estimator = score_estimator | ||
self.score_estimator.eval() | ||
self.iid_method = iid_method | ||
self.iid_params = iid_params | ||
super().__init__(prior, x_o, device=device) | ||
|
||
def set_x( | ||
self, | ||
x_o: Optional[Tensor], | ||
x_is_iid: Optional[bool] = False, | ||
iid_method: str = "auto_gauss", | ||
iid_params: Optional[dict] = None, | ||
|
||
rebuild_flow: Optional[bool] = True, | ||
): | ||
""" | ||
Set the observed data and whether it is IID. | ||
Args: | ||
- x_o: The observed data. | ||
- x_is_iid: Whether the observed data is IID (if batch_dim>1). | ||
- rebuild_flow: Whether to save (overwrrite) a low-tolerance flow model, useful if | ||
- the flow needs to be evaluated many times (e.g. for MAP calculation). | ||
x_o: The observed data. | ||
x_is_iid: Whether the observed data is IID (if batch_dim>1). | ||
rebuild_flow: Whether to save (overwrite) a low-tolerance flow model, | ||
useful if the flow needs to be evaluated many times | ||
(e.g. for MAP calculation). | ||
Review comment on lines +95 to +97: Beware: we had to open another PR to fix some issues with the flow rebuilding, and this signature changed. You will have to rebase or merge again once #1404 is merged. Sorry about that! |
||
""" | ||
super().set_x(x_o, x_is_iid) | ||
self.iid_method = iid_method | ||
self.iid_params = iid_params | ||
if rebuild_flow and self._x_o is not None: | ||
# By default, we want a high-tolerance flow. | ||
# This flow will be used mainly for MAP calculations, hence we want to save | ||
|
@@ -172,10 +181,16 @@ def gradient( | |
input=theta, condition=self.x_o, time=time | ||
) | ||
else: | ||
- raise NotImplementedError( | ||
- "Score accumulation for IID data is not yet implemented." | ||
assert self.prior is not None, "Prior is required for iid methods." | ||
|
||
method_iid = get_iid_method(self.iid_method) | ||
Review comment: I suggest renaming this to |
||
# Always creating a new object every call is not efficient... | ||
Review comment: I don't understand this comment. How are we preventing this here then? |
Reply: I will delete it. I now cache (using functools) the most expensive operations, which only need to be done once for x_o. So I think it's fine now, and I will just remove the comment (the overhead is now the standard object-creation overhead in Python). |
||
score_fn_iid = method_iid( | ||
self.score_estimator, self.prior, **(self.iid_params or {}) | ||
) | ||
|
||
score = score_fn_iid(theta, self.x_o, time) # type: ignore | ||
Review comment: Why is the type ignore needed here? |
||
|
||
return score | ||
|
||
def get_continuous_normalizing_flow( | ||
|
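The hunk above resolves the method name to a class with `get_iid_method`, constructs it with `**(self.iid_params or {})`, and then calls the resulting object. A minimal sketch of that lookup-then-construct pattern follows. Every name in it (`ToyScoreFn`, `ToyLinear`, `register`, `get_method`, `evaluate`) is hypothetical, and the actual registry in `sbi.inference.potentials.score_fn_iid` may be organised differently.

```python
from typing import Any, Callable, Dict, Optional, Type


class ToyScoreFn:
    """Hypothetical base class standing in for an iid score function."""

    def __init__(self, scale: float = 1.0) -> None:
        self.scale = scale

    def __call__(self, theta: float) -> float:
        raise NotImplementedError


# Hypothetical name -> class registry (an assumption, not sbi's code).
_METHODS: Dict[str, Type[ToyScoreFn]] = {}


def register(name: str) -> Callable[[Type[ToyScoreFn]], Type[ToyScoreFn]]:
    """Class decorator that stores the class under `name`."""

    def decorator(cls: Type[ToyScoreFn]) -> Type[ToyScoreFn]:
        _METHODS[name] = cls
        return cls

    return decorator


@register("toy_linear")
class ToyLinear(ToyScoreFn):
    def __call__(self, theta: float) -> float:
        # Score of a zero-mean Gaussian with precision `scale`.
        return -self.scale * theta


def get_method(name: str) -> Type[ToyScoreFn]:
    """Look up a registered class and fail loudly on unknown names."""
    if name not in _METHODS:
        raise NotImplementedError(
            f"Unknown method {name!r}; available: {sorted(_METHODS)}"
        )
    return _METHODS[name]


def evaluate(
    theta: float,
    method: str = "toy_linear",
    params: Optional[Dict[str, Any]] = None,
) -> float:
    method_cls = get_method(method)
    # `params or {}` lets callers pass None without a special case.
    score_fn = method_cls(**(params or {}))
    return score_fn(theta)


default_score = evaluate(2.0)
scaled_score = evaluate(2.0, params={"scale": 3.0})
```

Constructing the object on every call keeps the potential stateless with respect to the method choice; as the thread notes, the per-call cost is acceptable once the expensive per-`x_o` work is cached elsewhere.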
@@ -217,9 +232,6 @@ def rebuild_flow( | |
x_density_estimator = reshape_to_batch_event( | ||
self.x_o, event_shape=self.score_estimator.condition_shape | ||
) | ||
- assert x_density_estimator.shape[0] == 1, ( | ||
Review comment: shouldn't this be |
Reply: I just removed it, as it was blocking any iid methodology. I also think that it is checked in the posterior (and if not, we should do it there). The score_fn_iid should reduce this batch dimension to 1 anyway, and (at least) the gauss methods should be applicable to ODE sampling (although not to flow matching). |
||
- "PosteriorScoreBasedPotential supports only x batchsize of 1`." | ||
- ) | ||
|
||
flow = self.get_continuous_normalizing_flow( | ||
condition=x_density_estimator, atol=atol, rtol=rtol, exact=exact | ||
|
Review comment: This is great! Can you please add a bit of detail about these methods, e.g., the full names? We should then also cover these options in the extended NPSE tutorial. Can you please add a note under #1392?
Reply: These kinda are the names of the methods, but I can add a bit of detail on what they do and what to care about (from an applied perspective).
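As background for the method descriptions requested above, here is a small sketch of the factorisation that iid score accumulation builds on: for n conditionally independent observations, the posterior score equals the sum of the single-observation posterior scores plus (1 - n) times the prior score. The example checks this on a conjugate Gaussian model (an assumed toy setup, not taken from the PR) and only covers the un-noised case at diffusion time zero. How "fnpe", "gauss", "auto_gauss" and "jac_gauss" extend this to noised scores is not shown here.

```python
from typing import Sequence


def gaussian_score(theta: float, mean: float, var: float) -> float:
    """d/dtheta of log N(theta; mean, var)."""
    return -(theta - mean) / var


def single_posterior_score(
    theta: float, x: float, prior_var: float, noise_var: float
) -> float:
    """Score of p(theta | x) for theta ~ N(0, prior_var), x ~ N(theta, noise_var)."""
    post_var = 1.0 / (1.0 / prior_var + 1.0 / noise_var)
    post_mean = post_var * x / noise_var
    return gaussian_score(theta, post_mean, post_var)


def composed_iid_score(
    theta: float, xs: Sequence[float], prior_var: float, noise_var: float
) -> float:
    """Sum of per-observation posterior scores, corrected by (1 - n) prior scores."""
    n = len(xs)
    prior_score = gaussian_score(theta, 0.0, prior_var)
    per_obs = sum(
        single_posterior_score(theta, x, prior_var, noise_var) for x in xs
    )
    return (1 - n) * prior_score + per_obs


def exact_iid_score(
    theta: float, xs: Sequence[float], prior_var: float, noise_var: float
) -> float:
    """Score of the closed-form posterior p(theta | x_1, ..., x_n)."""
    n = len(xs)
    post_var = 1.0 / (1.0 / prior_var + n / noise_var)
    post_mean = post_var * sum(xs) / noise_var
    return gaussian_score(theta, post_mean, post_var)


xs = [0.3, -1.2, 2.5, 0.7]
composed = composed_iid_score(0.4, xs, prior_var=2.0, noise_var=0.5)
exact = exact_iid_score(0.4, xs, prior_var=2.0, noise_var=0.5)
# The two agree up to floating-point error in this conjugate model.
```

The (1 - n) prior term is also why the new `gradient` branch asserts that a prior is available before an iid method is used.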