Add CategoricalMADE #1269
base: main
Conversation
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

```
@@ Coverage Diff @@
##             main    #1269      +/-   ##
==========================================
- Coverage   86.05%   78.02%    -8.04%
==========================================
  Files         118      118
  Lines        8672     8791     +119
==========================================
- Hits         7463     6859     -604
- Misses       1209     1932     +723
```
Hey @janfb,

Currently the PR adds the CategoricalMADE. As far as I can tell all functionalities of

The question now is: How should I verify this works? / Which tests should I add/modify? Do you have an idea for a good toy example with several discrete variables that I could use?

I have cooked up a toy simulator, for which I am getting good posteriors using SNPE, but for some reason MNLE raises a

This is the simulator:

```python
def toy_simulator(theta: torch.Tensor, centers: list[torch.Tensor]) -> torch.Tensor:
    batch_size, n_dimensions = theta.shape
    assert len(centers) == n_dimensions, "Number of center sets must match theta dimensions"
    # Calculate discrete classes by assigning to the closest center
    x_disc = torch.stack([
        torch.argmin(torch.abs(centers[i].unsqueeze(1) - theta[:, i].unsqueeze(0)), dim=0)
        for i in range(n_dimensions)
    ], dim=1)
    closest_centers = torch.stack([centers[i][x_disc[:, i]] for i in range(n_dimensions)], dim=1)
    # Add Gaussian noise to assigned class centers
    std = 0.4
    x_cont = closest_centers + std * torch.randn_like(closest_centers)
    return torch.cat([x_cont, x_disc], dim=1)
```

The setup:

```python
torch.random.manual_seed(0)
centers = [
    torch.tensor([-0.5, 0.5]),
    # torch.tensor([-1.0, 0.0, 1.0]),
]
prior = BoxUniform(low=torch.tensor([-2.0] * len(centers)), high=torch.tensor([2.0] * len(centers)))
theta = prior.sample((20000,))
x = toy_simulator(theta, centers)
theta_o = prior.sample((1,))
x_o = toy_simulator(theta_o, centers)
```

NPE:

```python
trainer = SNPE()
estimator = trainer.append_simulations(theta=theta, x=x).train(training_batch_size=1000)
snpe_posterior = trainer.build_posterior(prior=prior)
posterior_samples = snpe_posterior.sample((2000,), x=x_o)
pairplot(posterior_samples, limits=[[-2, 2], [-2, 2]], figsize=(5, 5), points=theta_o)
```

and the equivalent MNLE:

```python
trainer = MNLE()
estimator = trainer.append_simulations(theta=theta, x=x).train(training_batch_size=1000)
mnle_posterior = trainer.build_posterior(prior=prior)
mnle_samples = mnle_posterior.sample((10000,), x=x_o)
pairplot(mnle_samples, limits=[[-2, 2], [-2, 2]], figsize=(5, 5), points=theta_o)
```

Hoping this makes sense. Lemme know if you need clarifications anywhere. Thanks for your feedback.
thanks a lot for tackling this @jnsbck! 👏
Please find below some comments and questions. There might be some misunderstanding about variables and categories on my side. We can have a call if that's more efficient than commenting here.
sbi/neural_nets/net_builders/mnle.py
Outdated

```python
elif categorical_model == "mlp":
    assert num_disc == 1, "MLP only supports 1D input."
    discrete_net = build_categoricalmassestimator(
        disc_x,
        batch_y,
        z_score_x="none",  # discrete data should not be z-scored.
        z_score_y="none",  # y-embedding net already z-scores.
        num_hidden=hidden_features,
        num_layers=hidden_layers,
        embedding_net=embedding_net,
    )
```
more generally, isn't the MLP a special case of the MADE? can't we absorb them into one class?
Comment: Check if testcase is identical, if yes -> rm MLP
c2st between true and MADE MNLE posterior: 0.538
c2st between true and MLP MNLE posterior: 0.573
c2st between MADE MNLE and MLP MNLE posterior: 0.5735
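For context on the metric: c2st trains a classifier to distinguish two sets of posterior samples and reports its held-out accuracy, so values near 0.5 mean the two posteriors are essentially indistinguishable. sbi ships a c2st utility for this; below is only a rough, self-contained stand-in using a leave-one-out k-NN classifier (the function name `knn_c2st` and the choice of `k` are mine, not sbi's implementation):

```python
import numpy as np

def knn_c2st(x, y, k=5):
    """Leave-one-out k-NN accuracy at telling samples x from samples y.
    ~0.5 when the two sample sets are indistinguishable, ->1.0 when separable."""
    data = np.concatenate([x, y])
    labels = np.concatenate([np.zeros(len(x)), np.ones(len(y))])
    dists = np.linalg.norm(data[:, None, :] - data[None, :, :], axis=-1)
    np.fill_diagonal(dists, np.inf)  # a point may not vote for itself
    neighbors = np.argsort(dists, axis=1)[:, :k]
    preds = labels[neighbors].mean(axis=1) > 0.5
    return (preds == labels).mean()

rng = np.random.default_rng(0)
# same distribution -> accuracy near chance; shifted -> near perfect
same = knn_c2st(rng.normal(size=(200, 2)), rng.normal(size=(200, 2)))
shifted = knn_c2st(rng.normal(size=(200, 2)), rng.normal(loc=5.0, size=(200, 2)))
```

So the 0.538–0.573 values above say the MADE- and MLP-based posteriors are both close to the true posterior and to each other.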
that's great! 👍
sbi/made_mnle.ipynb
Outdated
this should eventually be integrated with the MNLE tutorial in 12_iid_data_and_permutation_invariant_embeddings.ipynb
change "mlp" to "made" and comment that several variables with different num_categories are supported.
Cool, thanks for all the feedback! A quick call would be great, also to discuss suitable tests for this. Will reach out via email and tackle the straightforward things in the meantime.
```python
# outputs (batch_size, num_variables, num_categories)
def log_prob(self, inputs, context=None):
    outputs = self.forward(inputs, context=context)
```
are these shapes correct?
After discussion with @janfb I will:

@janfb could you still check though what is up with the simulator above? Do you have a hunch why the SNPE and MNLE posteriors are different? EDIT:
…too. log_prob has shape issues tho
…ting mixed_density estimator log_probs and sample to work as well
…rg to categorical_model
Force-pushed from 971201b to 8407911
Force-pushed from 8407911 to 2e5898b
I did a bit more work on this PR, current tests should be passing and I have swapped out all the legacy
This last thing has been haunting me in my sleep, as I cannot figure out what is wrong. Maybe you have an idea of what could be causing this. Help would be much appreciated. @janfb
thanks for the update!
Made another round of comments. Happy to have another call to sort them out.
```python
Args:
    condition: batch of context parameters for the net.
    input: Original data, x0. (batch_size, *input_shape)
```
Better keep this general: "input variable", because it could be x for MNLE, or theta when doing mixed NPE (see Daniel's recent PR).
```python
condition = self.activation(layer(condition))

def compute_probs(self, outputs):
    ps = F.softmax(outputs, dim=-1) * self.mask
    ps = ps / ps.sum(dim=-1, keepdim=True)
```
is this numerically stable? Better use logsumexp?
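For reference, the mask-then-renormalize pattern in compute_probs can be done entirely in log space, which avoids overflow/underflow for extreme logits. A minimal numpy sketch of the idea (the name `masked_log_softmax` and the toy mask are hypothetical, not code from the PR):

```python
import numpy as np

def masked_log_softmax(logits, mask):
    # send masked-out categories to -inf, then normalize with the
    # log-sum-exp trick instead of renormalizing probabilities
    masked = np.where(mask.astype(bool), logits, -np.inf)
    m = masked.max(axis=-1, keepdims=True)
    lse = m + np.log(np.sum(np.exp(masked - m), axis=-1, keepdims=True))
    return masked - lse

# logits this large overflow a plain softmax (exp(1000) == inf),
# but the log-space version stays finite
logits = np.array([1000.0, 1001.0, 0.0])
mask = np.array([1, 1, 0])
log_ps = masked_log_softmax(logits, mask)
```

The returned log-probabilities can be summed directly in log_prob, with no intermediate division by a possibly tiny normalizer.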
```python
outputs = outputs.reshape(*inputs.shape, self.num_categories)
ps = self.compute_probs(outputs)

# categorical log prob
log_prob = torch.log(ps.gather(-1, inputs.unsqueeze(-1).long()))
log_prob = log_prob.squeeze(-1).sum(dim=-1)
```
very naive question here: the outputs are coming from the MADE, i.e., the conditional dependencies are already taken care of internally, right?
I am just wondering because for the 1-D case, we used the network-predicted ps to construct a Categorical distribution and then evaluated the inputs under that distribution. This is not needed here because the underlying MADE takes both the inputs and the context and outputs unnormalized conditional probabilities already?
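If it helps to see why no explicit Categorical object is needed for evaluation: gathering the observed category's probability per variable and summing the logs is exactly the categorical log-likelihood, assuming ps already holds the (autoregressively conditioned) normalized probabilities. A small numpy sketch with made-up numbers:

```python
import numpy as np

# toy normalized probabilities, shape (batch, num_variables, num_categories)
ps = np.array([[[0.2, 0.8],
                [0.5, 0.5]]])   # batch of 1, 2 variables, 2 categories each
inputs = np.array([[1, 0]])     # observed category index per variable

# numpy equivalent of ps.gather(-1, inputs.unsqueeze(-1)).squeeze(-1):
per_var = np.take_along_axis(ps, inputs[..., None], axis=-1).squeeze(-1)
# joint log prob = sum of per-variable conditional log probs
log_prob = np.log(per_var).sum(axis=-1)
```

Constructing a Categorical and calling its log_prob would compute the same gather internally, so the direct indexing is equivalent.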
```python
def _initialize(self):
    pass
```
Unless I am missing something, the _initialize() is needed only in MixtureOfGaussiansMADE(MADE), not in MADE, so it's not needed here?
```python
for i in range(self.num_variables):
    outputs = self.forward(samples, context)
    outputs = outputs.reshape(*samples.shape, self.num_categories)
    ps = self.compute_probs(outputs)
    samples[:, :, i] = Categorical(probs=ps[:, :, i]).sample()
```
same question as above: these samples are internally autoregressive, right? So each discrete variable is sampled given the upstream discrete variables?
I am just confused because I would have expected that for each iteration we need to pass the so-far sampled discrete values as context; but this seems to be happening implicitly in the MADE?
Now I see it: in line 148 you are updating samples with the new samples of the current i. It probably boils down to the same thing, but you could also update all so-far sampled variables, i.e., samples[:, :, :(i+1)] = Categorical(probs=ps[:, :, :(i+1)]).sample()?
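On the autoregressive-sampling question above: updating samples in place before the next forward pass is what makes the loop ancestral sampling, so no extra context is needed. A self-contained numpy sketch of the same loop structure, with hypothetical fixed conditional tables standing in for the MADE's masked forward pass:

```python
import numpy as np

rng = np.random.default_rng(0)
num_samples = 20000

# hypothetical conditionals standing in for the MADE outputs:
# p(x0) and p(x1 | x0); the MADE masks ensure variable i only sees variables < i
p_x0 = np.array([0.3, 0.7])
p_x1_given_x0 = np.array([[0.9, 0.1],   # conditionals when x0 == 0
                          [0.2, 0.8]])  # conditionals when x0 == 1

samples = np.zeros((num_samples, 2), dtype=int)
# variable 0: its distribution ignores the (zero-initialized) inputs entirely
samples[:, 0] = rng.random(num_samples) < p_x0[1]
# variable 1: re-query the conditionals *after* x0 has been filled in
samples[:, 1] = rng.random(num_samples) < p_x1_given_x0[samples[:, 0], 1]
```

Because each iteration re-runs the forward pass on the partially filled samples, variable i is sampled from its conditional given the already-sampled upstream variables, exactly as in the PR's loop.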
```python
categorical_model: type of categorical net to use for the discrete part of
    the data. Can be "made" or "mlp".
```
add a note here that MLP works only for 1D / one column.
had another look and made two suggestions which could be a reason for the missing first-dim fit.
```python
condition: batch of parameters for prediction.

with torch.no_grad():
    samples = torch.zeros(num_samples, batch_dim, self.num_variables)
    print(samples.shape, context.shape)
```
debugging leftover?
```python
sample_shape: number of samples to obtain.
condition: batch of parameters for prediction.

with torch.no_grad():
    samples = torch.zeros(num_samples, batch_dim, self.num_variables)
```
should this be initialized with uniform torch.rand?
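On the torch.zeros vs. torch.rand question: if the MADE masks are strictly autoregressive, the entries of samples at positions >= i can never reach the logits for variable i, so the initial fill value should not matter; a uniform init would only change anything if a mask leaked. A tiny numpy check of that property, with a hypothetical strictly lower-triangular masked linear layer standing in for the MADE:

```python
import numpy as np

rng = np.random.default_rng(0)
num_vars = 3

# hypothetical strictly autoregressive masked layer: variable i sees only j < i
W = rng.normal(size=(num_vars, num_vars))
mask = np.tril(np.ones((num_vars, num_vars)), k=-1)

def toy_logits(samples):
    # one scalar logit per variable, computed through the masked weights
    return (mask * W) @ samples

zeros_init = np.zeros(num_vars)
rand_init = rng.random(num_vars)

# variable 0's logit ignores the inputs entirely, so the init cannot matter
assert toy_logits(zeros_init)[0] == toy_logits(rand_init)[0] == 0.0

# once position 0 holds an actual sample, variable 1's logit also agrees,
# regardless of what positions 1 and 2 were initialized to
zeros_init[0] = rand_init[0] = 1.0
assert np.isclose(toy_logits(zeros_init)[1], toy_logits(rand_init)[1])
```

So zeros are fine as long as the masks are correct; the sampling loop overwrites each position before anything downstream reads it.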
What does this implement/fix? Explain your changes

This implements a CategoricalMADE to generalize MNLE to multiple discrete dimensions, addressing #1112. Essentially adapts nflows's MixtureOfGaussiansMADE to autoregressively model categorical distributions.

Does this close any currently open issues?
Fixes #1112
Comments
I have already discussed this with @michaeldeistler.
Checklist

Put an x in the boxes that apply. You can also fill these out after creating the PR. If you're unsure about any of them, don't hesitate to ask. We're here to help! This is simply a reminder of what we are going to look for before merging your code.

guidelines
with pytest.mark.slow.
guidelines
main (or there are no conflicts with main)