Add wrapper to handle 1-D PyTorch distributions as priors #1283
Comments
Hi there, thanks a lot for reporting this! The following will fix it:

```python
import torch
from torch.distributions import Exponential


class WrappedExponential(Exponential):
    def log_prob(self, value):
        log_probs = super().log_prob(value)
        # Drop the trailing data dimension so that there is one log-probability
        # per sample: (N, 1) -> (N,). squeeze(-1) avoids collapsing N == 1.
        return log_probs.squeeze(-1)


exp_prior = WrappedExponential(torch.tensor([2.0]))
```

The reason that the issue was happening is that … I will leave this issue open, though, because we should deal with this in … I hope that the above fixes the issue! Let me know if you have any more questions!

All the best
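To see which shape convention a given prior follows, a quick check can help. `check_prior_shapes` below is a hypothetical helper, not part of `sbi`; it only encodes the assumption, based on the shapes discussed in this thread, that a prior should produce samples of shape `(num_samples, dim)` and exactly one log-probability per sample.

```python
import torch
from torch.distributions import Distribution, Exponential, MultivariateNormal


def check_prior_shapes(prior: Distribution, num_samples: int = 10) -> bool:
    """Hypothetical helper (not part of sbi): return True if samples have shape
    (num_samples, dim) and log_prob returns shape (num_samples,)."""
    samples = prior.sample((num_samples,))
    log_probs = prior.log_prob(samples)
    has_data_dim = samples.dim() == 2 and samples.shape[0] == num_samples
    one_log_prob_per_sample = log_probs.shape == (num_samples,)
    return has_data_dim and one_log_prob_per_sample


# Batch of one Exponential: samples (10, 1), but log_probs (10, 1).
print(check_prior_shapes(Exponential(torch.tensor([2.0]))))  # False
# 1-D MultivariateNormal: samples (10, 1), log_probs (10,).
print(check_prior_shapes(MultivariateNormal(torch.tensor([2.0]), torch.tensor([[2.0]]))))  # True
```

Under the same assumption, `WrappedExponential(torch.tensor([2.0]))` from the fix above should pass the check, since its `log_prob` drops the trailing dimension.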
More notes for future fixing on our side: this is only an issue for 1-D PyTorch distributions. The issue is that, e.g.:

```python
import torch
from torch.distributions import Exponential, MultivariateNormal

prior = Exponential(torch.tensor(2.0))
samples = prior.sample((10,))  # (10,)
log_probs = prior.log_prob(samples)  # (10,)
# `sbi` raises an error because samples must have a data dimension.

prior = Exponential(torch.tensor([2.0]))
samples = prior.sample((10,))  # (10, 1)
log_probs = prior.log_prob(samples)  # (10, 1)
# `sbi` fails because the log_prob output still contains the data dim.

# The covariance matrix of a MultivariateNormal must be 2-D, here (1, 1).
prior = MultivariateNormal(torch.tensor([2.0]), torch.tensor([[2.0]]))
samples = prior.sample((10,))  # (10, 1)
log_probs = prior.log_prob(samples)  # (10,)
# `sbi` works.
```

IMO, the easiest fix would be to introduce the following:

```python
from typing import Dict

import torch
from torch import Tensor
from torch.distributions import Distribution, constraints


class OneDimDistributionWrapper(Distribution):
    def __init__(self, dist: Distribution):
        self.dist = dist
        # Treat the wrapped distribution's batch dimension of size 1 as the
        # event (data) dimension. Argument validation is left to `dist`, since
        # the wrapper has no parameters of its own.
        super().__init__(
            batch_shape=torch.Size(),
            event_shape=dist.batch_shape,
            validate_args=False,
        )

    def sample(self, sample_shape=torch.Size()) -> Tensor:
        return self.dist.sample(sample_shape)

    def log_prob(self, value: Tensor) -> Tensor:
        # Remove the additional dimension: (N, 1) -> (N,).
        return self.dist.log_prob(value)[..., 0]

    @property
    def arg_constraints(self) -> Dict[str, constraints.Constraint]:
        return self.dist.arg_constraints

    @property
    def support(self):
        return self.dist.support

    @property
    def mean(self) -> Tensor:
        return self.dist.mean

    @property
    def variance(self) -> Tensor:
        return self.dist.variance
```

We could then use this to wrap 1-D distributions whose samples carry an extra data dimension, such as `Exponential(torch.tensor([2.0]))`.
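PyTorch already ships a reshaping wrapper that may serve the same purpose: `torch.distributions.Independent` reinterprets batch dimensions as event dimensions, so `log_prob` sums over them. For a batch of size one, that sum equals dropping the trailing dimension with `[..., 0]`. This is a sketch only; whether `sbi` 0.23.1 accepts an `Independent` prior without further processing is an assumption worth verifying.

```python
import torch
from torch.distributions import Exponential, Independent

# Exponential with a batch of one: batch_shape (1,), event_shape ().
base = Exponential(torch.tensor([2.0]))

# Reinterpret the last batch dimension as an event dimension, so that
# log_prob sums over it and returns one value per sample.
prior = Independent(base, reinterpreted_batch_ndims=1)

samples = prior.sample((10,))
log_probs = prior.log_prob(samples)
print(samples.shape)      # torch.Size([10, 1])
print(log_probs.shape)    # torch.Size([10])
print(prior.event_shape)  # torch.Size([1])
```

A side benefit of this design choice is that `sample`, `support`, `mean` and `variance` come from PyTorch instead of being forwarded by hand.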
I don't understand what is being shown in the plot.
Thank you so much for your suggestions! I did try them out. However, the estimates I get from each of the three algorithms (SNPE_C, SNLE, SNRE_B) are somewhat unreliable: some runs give bad posterior estimates, and the good estimates are not reproducible. Below are some examples of pathological posterior estimates:
It would be great if you could give some general suggestions, as my ultimate aim is to apply these algorithms to a 21-parameter ODE model.
I am sorry for not attaching the code for this issue earlier, so here it is:
I had another question: they all give pathological outputs, as shown in the previous comment. I would appreciate any suggestions!
Hi @paarth-dudani, how many simulations are you using here? (I couldn't find the number in the code snippet above.) Given that you seem to have a fast simulator, I would start with at least 1k to 10k simulations to debug this case. Above, it also seems that the … My recommendations would be:
Does this help?
Describe the bug
I am implementing SNRE and SNLE (from the implemented algorithms) on a simple exponential simulator model with noise and a prior. The algorithms work fine with a uniform prior, but with the exponential prior they fail with the following error: 'number of categories cannot exceed 2^24'.
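The `2^24` in this error message matches the limit that `torch.multinomial` places on the number of categories. One plausible but unverified route from the extra `(N, 1)` log-probability dimension to that error is broadcasting: combining an `(N, 1)` tensor elementwise with an `(N,)` tensor yields an `(N, N)` tensor, whose size grows quadratically in `N`. The sketch below only demonstrates the shape arithmetic; where, or whether, `sbi` performs such an operation is an assumption, and `num_samples = 5_000` is an arbitrary illustrative value.

```python
import math

import torch

# Limit on the number of categories enforced by torch.multinomial.
MAX_CATEGORIES = 2**24  # 16,777,216

# Small case: (3, 1) combined with (3,) broadcasts to (3, 3).
small = torch.zeros(3, 1) + torch.zeros(3)
print(small.shape)  # torch.Size([3, 3])

# Squeezing the trailing dimension first keeps one value per sample.
fixed = torch.zeros(3, 1).squeeze(-1) + torch.zeros(3)
print(fixed.shape)  # torch.Size([3])

# Large case, computed from shapes alone (no memory is allocated).
num_samples = 5_000  # illustrative value, not taken from the report
broadcast_shape = torch.broadcast_shapes((num_samples, 1), (num_samples,))
num_entries = math.prod(broadcast_shape)
print(broadcast_shape)               # torch.Size([5000, 5000])
print(num_entries > MAX_CATEGORIES)  # True
```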
To Reproduce
Versions
Python version: 3.9.13
SBI version: 0.23.1
The code is for the SNLE implementation, but I get the same error with SNRE as well (inference object: `NRE`).
Expected behavior
I expect the network to undergo multiple rounds of training (2 in this example), or to give me a pairplot after one round of training (not shown above).