Add CategoricalMADE #1269

Open
jnsbck wants to merge 15 commits into main

Conversation

jnsbck
Contributor

@jnsbck jnsbck commented Sep 5, 2024

What does this implement/fix? Explain your changes

This implements a CategoricalMADE to generalize MNLE to multiple discrete dimensions, addressing #1112.
It essentially adapts nflows's MixtureOfGaussiansMADE to autoregressively model categorical distributions.

Does this close any currently open issues?

Fixes #1112

Comments

I have already discussed this with @michaeldeistler.

Checklist

Put an x in the boxes that apply. You can also fill these out after creating
the PR. If you're unsure about any of them, don't hesitate to ask. We're here to
help! This is simply a reminder of what we are going to look for before merging
your code.

  • I have read and understood the contribution
    guidelines
  • I agree with re-licensing my contribution from AGPLv3 to Apache-2.0.
  • I have commented my code, particularly in hard-to-understand areas
  • I have added tests that prove my fix is effective or that my feature works
  • I have reported how long the new tests run and potentially marked them
    with pytest.mark.slow.
  • New and existing unit tests pass locally with my changes
  • I performed linting and formatting as described in the contribution
    guidelines
  • I rebased on main (or there are no conflicts with main)
  • For reviewer: The continuous deployment (CD) workflows are passing.


codecov bot commented Sep 5, 2024

Codecov Report

Attention: Patch coverage is 37.89474% with 59 lines in your changes missing coverage. Please review.

Project coverage is 78.02%. Comparing base (8afd985) to head (0188cea).
Report is 32 commits behind head on main.

Files with missing lines                          Patch %   Lines
sbi/neural_nets/estimators/categorical_net.py     22.22%    42 Missing ⚠️
sbi/neural_nets/net_builders/categorial.py        33.33%    14 Missing ⚠️
sbi/neural_nets/net_builders/mnle.py              76.92%    3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1269      +/-   ##
==========================================
- Coverage   86.05%   78.02%   -8.04%     
==========================================
  Files         118      118              
  Lines        8672     8791     +119     
==========================================
- Hits         7463     6859     -604     
- Misses       1209     1932     +723     
Flag Coverage Δ
unittests 78.02% <37.89%> (-8.04%) ⬇️

Files with missing lines                                Coverage Δ
sbi/neural_nets/estimators/__init__.py                  100.00% <100.00%> (ø)
.../neural_nets/estimators/mixed_density_estimator.py   98.24% <100.00%> (+0.13%) ⬆️
sbi/neural_nets/net_builders/mnle.py                    91.17% <76.92%> (-8.83%) ⬇️
sbi/neural_nets/net_builders/categorial.py              58.33% <33.33%> (-36.41%) ⬇️
sbi/neural_nets/estimators/categorical_net.py           56.56% <22.22%> (-41.27%) ⬇️

... and 36 files with indirect coverage changes

@jnsbck
Contributor Author

jnsbck commented Sep 16, 2024

Hey @janfb,
would very much appreciate your input at this stage:

Currently, the PR adds the CategoricalMADE and the builder build_autoregressive_categoricalestimator, plus some minor modifications to build_mnle and MixedDensityEstimator. This enables multiple discrete variables with different numbers of classes via

trainer = MNLE(density_estimator=lambda x, y: build_mnle(y, x, categorical_model="made"))

Note that, for some reason, x and y have to be flipped for MNLE.

As far as I can tell, all functionalities of CategoricalMADE work for both 1D and ND inputs, and running Example_01_DecisionMakingModel.ipynb with the CategoricalMADE matches the ground truth:
[image]

The question now is: How should I verify this works? / Which tests should I add/modify? Do you have an idea for a good toy example with several discrete variables that I could use?

I have cooked up a toy simulator for which I am getting good posteriors using SNPE, but for some reason MNLE raises a RuntimeError: probability tensor contains either 'inf', 'nan' or element < 0, even for the unmodified MNLE. Any ideas why this could be?

This is the simulator

import torch


def toy_simulator(theta: torch.Tensor, centers: list[torch.Tensor]) -> torch.Tensor:
    batch_size, n_dimensions = theta.shape
    assert len(centers) == n_dimensions, "Number of center sets must match theta dimensions"

    # Calculate discrete classes by assigning to the closest center
    x_disc = torch.stack([
        torch.argmin(torch.abs(centers[i].unsqueeze(1) - theta[:, i].unsqueeze(0)), dim=0)
        for i in range(n_dimensions)
    ], dim=1)

    closest_centers = torch.stack([centers[i][x_disc[:, i]] for i in range(n_dimensions)], dim=1)
    # Add Gaussian noise to assigned class centers
    std = 0.4
    x_cont = closest_centers + std * torch.randn_like(closest_centers)
       
    return torch.cat([x_cont, x_disc], dim=1)

The setup:

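from sbi.analysis import pairplot
from sbi.inference import MNLE, SNPE
from sbi.utils import BoxUniform

torch.random.manual_seed(0)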
centers = [
    torch.tensor([-0.5, 0.5]),
    # torch.tensor([-1.0, 0.0, 1.0]),
]

prior = BoxUniform(low=torch.tensor([-2.0]*len(centers)), high=torch.tensor([2.0]*len(centers)))
theta = prior.sample((20000,))
x = toy_simulator(theta, centers)

theta_o = prior.sample((1,))
x_o = toy_simulator(theta_o, centers)

NPE:

trainer = SNPE()
estimator = trainer.append_simulations(theta=theta, x=x).train(training_batch_size=1000)

snpe_posterior = trainer.build_posterior(prior=prior)
posterior_samples = snpe_posterior.sample((2000,), x=x_o)
pairplot(posterior_samples, limits=[[-2, 2], [-2, 2]], figsize=(5, 5), points=theta_o)

and the equivalent MNLE:

trainer = MNLE()
estimator = trainer.append_simulations(theta=theta, x=x).train(training_batch_size=1000)

mnle_posterior = trainer.build_posterior(prior=prior)
mnle_samples = mnle_posterior.sample((10000,), x=x_o)
pairplot(mnle_samples, limits=[[-2, 2], [-2, 2]], figsize=(5, 5), points=theta_o)

Hoping this makes sense. Lemme know if you need clarifications anywhere. Thanks for your feedback.
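
A possible way to narrow down the RuntimeError mentioned above: that message is the one torch.multinomial raises when it is handed probabilities that are non-finite or negative, so it may help to inspect the categorical probabilities right before sampling. The sketch below is only a generic helper; `find_invalid_prob_rows` is a hypothetical name and not part of sbi, and where to call it inside the MNLE code is left open.

```python
import torch


def find_invalid_prob_rows(probs: torch.Tensor) -> torch.Tensor:
    """Return indices of rows containing inf, nan, or negative entries.

    probs: (num_rows, num_categories). Hypothetical debugging helper,
    not part of sbi.
    """
    bad = ~torch.isfinite(probs) | (probs < 0)
    return bad.any(dim=-1).nonzero(as_tuple=True)[0]


probs = torch.tensor([[0.5, 0.5], [float("nan"), 1.0], [0.2, 0.8]])
bad_rows = find_invalid_prob_rows(probs)
print(bad_rows)  # indices of rows that torch.multinomial would reject
```

If such rows show up, looking at the corresponding theta and x values could indicate whether the network output itself diverged or whether the inputs were already invalid.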

@jnsbck
Contributor Author

jnsbck commented Oct 22, 2024

Hey @janfb,
you might have missed this, but I would be happy about feedback :)

Contributor

@janfb janfb left a comment

thanks a lot for tackling this @jnsbck! 👏

Please find below some comments and questions.
There might be some misunderstanding about variables and categories on my side. We can have a call if that's more efficient than commenting here.

sbi/neural_nets/net_builders/mnle.py (outdated)
Comment on lines 164 to 174
elif categorical_model == "mlp":
    assert num_disc == 1, "MLP only supports 1D input."
    discrete_net = build_categoricalmassestimator(
        disc_x,
        batch_y,
        z_score_x="none",  # discrete data should not be z-scored.
        z_score_y="none",  # y-embedding net already z-scores.
        num_hidden=hidden_features,
        num_layers=hidden_layers,
        embedding_net=embedding_net,
    )
Contributor

more generally, isn't the MLP a special case of the MADE? can't we absorb them into one class?

Contributor Author

Comment: Check if testcase is identical, if yes -> rm MLP

Contributor Author

This seems to be the case for the examples in the tutorial notebook
image

image

Contributor Author

c2st between true and MADE MNLE posterior: 0.538
c2st between true and MLP MNLE posterior: 0.5730000000000001
c2st between MADE MNLE and MLP MNLE posterior: 0.5734999999999999

Contributor

that's great! 👍

Contributor

this should eventually be integrated with the MNLE tutorial in 12_iid_data_and_permutation_invariant_embeddings.ipynb

Contributor Author

change "mlp" to "made" and comment that several variables with different num_categories are supported.

@jnsbck
Contributor Author

jnsbck commented Nov 4, 2024

Cool, thanks for all the feedback! A quick call would be great, also to discuss suitable tests for this. I will reach out via email and tackle the straightforward things in the meantime.


# outputs (batch_size, num_variables, num_categories)
def log_prob(self, inputs, context=None):
    outputs = self.forward(inputs, context=context)
Contributor Author

are these shapes correct?

@jnsbck
Contributor Author

jnsbck commented Nov 14, 2024

After discussion with @janfb I will:

  1. adapt the simulator of Example_01_DecisionMakingModel.ipynb to multiple discrete variables.
  2. Get this to run for 1D and ND
  3. fix remaining comments/issues
  4. possibly refactor (fold 1D into ND code).

@janfb could you still check what is up with the simulator above, though? Do you have a hunch why the SNPE and MNLE posteriors differ?

EDIT:

  1. wip
  2. add new tests / update old ones with a multi-dim example.

@jnsbck jnsbck force-pushed the jnsbck-categorical_made branch from 971201b to 8407911 Compare January 9, 2025 16:30
@jnsbck jnsbck force-pushed the jnsbck-categorical_made branch from 8407911 to 2e5898b Compare January 9, 2025 16:39
@jnsbck
Contributor Author

jnsbck commented Jan 9, 2025

I did a bit more work on this PR, current tests should be passing and I have swapped out all the legacy CategoricalNet code for the CategoricalMADE. See changes and comments above (please close if no longer relevant).
A few things remain:

  • Add a bit more docs and comments
  • add test cases
  • adapt the tutorial to 2d? (Can just add another beta distribution to the prior)
  • make sure it runs for ND.

This last thing has been haunting me in my sleep, as I cannot figure out what is wrong. Maybe you have an idea of what could be causing this.
For 1D it works, but for ND it always gets the first discrete dim wrong, i.e. yields the prior (all other dims are correct, see image). I am not sure if the conditioning for the first dimension is broken somehow, but I am not able to pin down where this would be happening in my code.

image

Help would be much appreciated. @janfb
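
One generic sanity check that might help localize a problem like the one described above (a sketch, not a diagnosis of this PR): perturb input variable j and verify that the output blocks of variables i <= j do not change. In an autoregressive model, the outputs for variable i should depend only on variables before i and on the context. The helper names, the context-free `forward` signature, and the toy masked linear layer are all assumptions made for illustration; the real CategoricalMADE also takes a context.

```python
import torch


def autoregressive_violations(forward, num_variables, num_categories, n=8):
    """Return (i, j) pairs where output block i changes when input j >= i changes."""
    x = torch.randint(0, num_categories, (n, num_variables)).float()
    base = forward(x).reshape(n, num_variables, num_categories)
    violations = []
    for j in range(num_variables):
        x_pert = x.clone()
        # Move every entry of variable j to a different category.
        x_pert[:, j] = (x_pert[:, j] + 1) % num_categories
        out = forward(x_pert).reshape(n, num_variables, num_categories)
        for i in range(j + 1):
            if not torch.allclose(out[:, i], base[:, i]):
                violations.append((i, j))
    return violations


def make_masked_linear(num_variables, num_categories, strict=True):
    """Toy stand-in for a MADE: one masked linear layer (hypothetical)."""
    torch.manual_seed(0)
    weight = torch.randn(num_variables * num_categories, num_variables)
    # Variable index that each output row belongs to.
    out_var = torch.arange(num_variables).repeat_interleave(num_categories)
    in_var = torch.arange(num_variables)
    if strict:
        mask = out_var[:, None] > in_var[None, :]  # depends on earlier variables only
    else:
        mask = out_var[:, None] >= in_var[None, :]  # leaks the variable itself
    masked_weight = weight * mask
    return lambda x: x @ masked_weight.T


ok = autoregressive_violations(make_masked_linear(3, 4, strict=True), 3, 4)
leaky = autoregressive_violations(make_masked_linear(3, 4, strict=False), 3, 4)
```

With the strict mask the list of violations is empty, while the leaky mask is flagged. Running the same check on the actual network (with a fixed context) would show whether the first discrete dimension is conditioned on something it should not see.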

Contributor

@janfb janfb left a comment

thanks for the update!
Made another round of comments. Happy to have another call to sort them out.


Args:
condition: batch of context parameters for the net.
input: Original data, x0. (batch_size, *input_shape)
Contributor

Better keep this general: "input variable", because it could be x for MNLE, or theta when doing mixed NPE (see Daniel's recent PR).

condition = self.activation(layer(condition))
def compute_probs(self, outputs):
    ps = F.softmax(outputs, dim=-1) * self.mask
    ps = ps / ps.sum(dim=-1, keepdim=True)
Contributor

is this numerically stable? Better use logsumexp?
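
To make the logsumexp suggestion concrete, here is a minimal sketch that stays in log space. It assumes the mask has shape (num_variables, max_categories) with 1 for valid categories and 0 for padding, and that every variable has at least one valid category; the function name `masked_log_probs` is hypothetical and not taken from the PR.

```python
import torch


def masked_log_probs(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Log-probabilities over the valid categories of each variable.

    logits: (..., num_variables, max_categories)
    mask:   (num_variables, max_categories), 1 = valid, 0 = padding (assumed layout)
    """
    # Padded categories get -inf logits, i.e. exactly zero probability.
    masked_logits = logits.masked_fill(mask == 0, float("-inf"))
    # log_softmax written out: logits - logsumexp(logits), evaluated stably.
    return masked_logits - torch.logsumexp(masked_logits, dim=-1, keepdim=True)


logits = torch.tensor([[[1000.0, -1000.0, 0.0]]])
mask = torch.tensor([[1.0, 1.0, 0.0]])
log_ps = masked_log_probs(logits, mask)
```

For moderate logits this agrees with the softmax-then-renormalize version. For extreme logits like the ones above, the softmax of a valid category can underflow to zero, so taking the log afterwards gives -inf, whereas the log-space version keeps a finite value.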

Comment on lines +111 to +116
outputs = outputs.reshape(*inputs.shape, self.num_categories)
ps = self.compute_probs(outputs)

# categorical log prob
log_prob = torch.log(ps.gather(-1, inputs.unsqueeze(-1).long()))
log_prob = log_prob.squeeze(-1).sum(dim=-1)
Contributor

very naive question here: the outputs are coming from the MADE, i.e., the conditional dependencies are already taken care of internally right?

I am just wondering because for the 1-D case, we used the network-predicted ps to construct a Categorical distribution and then evaluated the inputs under that distribution. This is not needed here because the underlying MADE takes both the inputs and the context and outputs unnormalized conditional probabilities already?
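
For reference, a small sketch of why the gather-based expression and the Categorical-based evaluation from the 1-D case should give the same number, assuming `ps` already holds the per-variable conditional probabilities p(x_i | x_<i, context) with shape (batch, num_variables, num_categories). Both function names are hypothetical; only torch.distributions.Categorical is an existing API.

```python
import torch
from torch.distributions import Categorical


def joint_log_prob_gather(ps: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
    """Sum over variables of log p(x_i | x_<i), picked out with gather."""
    picked = ps.gather(-1, inputs.unsqueeze(-1).long()).squeeze(-1)
    return torch.log(picked).sum(dim=-1)


def joint_log_prob_categorical(ps: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
    """Same quantity, evaluated through a batch of Categorical distributions."""
    return Categorical(probs=ps).log_prob(inputs.long()).sum(dim=-1)


ps = torch.tensor([[[0.2, 0.8], [0.5, 0.5]]])  # one sample, two binary variables
inputs = torch.tensor([[1, 0]])
lp_gather = joint_log_prob_gather(ps, inputs)
lp_cat = joint_log_prob_categorical(ps, inputs)
```

Both evaluate the chain-rule sum over variables, so neither adds conditioning of its own; the autoregressive structure has to come from how `ps` was produced.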

Comment on lines +113 to +153
def _initialize(self):
    pass
Contributor

Unless I am missing something, _initialize() is needed only in MixtureOfGaussiansMADE(MADE), not in MADE, so it's not needed here?

Comment on lines +144 to +148
for i in range(self.num_variables):
    outputs = self.forward(samples, context)
    outputs = outputs.reshape(*samples.shape, self.num_categories)
    ps = self.compute_probs(outputs)
    samples[:, :, i] = Categorical(probs=ps[:, :, i]).sample()
Contributor

same question as above: these samples are internally autoregressive, right? So each discrete variable is sampled given the upstream discrete variables?

Contributor

I am just confused because I would have expected that, in each iteration, we need to pass the discrete samples drawn so far as context; but this seems to be happening implicitly in the MADE?

Contributor

Now I see it: in line 148 you are updating samples with the new samples for the current i. It probably boils down to the same thing, but you could also update all samples drawn so far, i.e.,

samples[:, :, :(i+1)] = Categorical(probs=ps[:, :, :(i+1)]).sample()

?
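
A toy comparison of the two update rules discussed above, as a sketch under simplifying assumptions: two binary variables, no context, 2-D samples of shape (n, num_variables), and a hand-written `toy_conditional_probs` standing in for the MADE forward pass plus compute_probs (all names here are hypothetical). In this toy, x0 is a fair coin and x1 copies x0, so correct autoregressive samples must satisfy x1 == x0.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical


def toy_conditional_probs(samples: torch.Tensor) -> torch.Tensor:
    """Return probs of shape (n, 2, 2): x0 is a fair coin, x1 copies x0."""
    n = samples.shape[0]
    ps = torch.empty(n, 2, 2)
    ps[:, 0] = 0.5
    ps[:, 1] = F.one_hot(samples[:, 0].long(), num_classes=2).float()
    return ps


def sample_current_only(n: int) -> torch.Tensor:
    """Update only variable i in iteration i (as in the loop under review)."""
    samples = torch.zeros(n, 2)
    for i in range(2):
        ps = toy_conditional_probs(samples)
        samples[:, i] = Categorical(probs=ps[:, i]).sample().float()
    return samples


def sample_resample_prefix(n: int) -> torch.Tensor:
    """Resample all variables up to and including i in iteration i."""
    samples = torch.zeros(n, 2)
    for i in range(2):
        ps = toy_conditional_probs(samples)
        samples[:, : i + 1] = Categorical(probs=ps[:, : i + 1]).sample().float()
    return samples


torch.manual_seed(0)
current_only = sample_current_only(1000)
prefix = sample_resample_prefix(1000)
```

In this toy the two rules do not coincide: updating only the current variable keeps x1 == x0, while resampling the prefix redraws x0 after x1 was conditioned on its old value, which breaks the dependence. Whether the same reasoning carries over one-to-one to the CategoricalMADE depends on its masking, so this is only meant as an illustration.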

Comment on lines +108 to +109
categorical_model: type of categorical net to use for the discrete part of
the data. Can be "made" or "mlp".
Contributor

add a note here that MLP works only for 1D / one column.

Contributor

@janfb janfb left a comment

had another look and made two suggestions, either of which could be a reason for the missing first-dim fit.

condition: batch of parameters for prediction.
with torch.no_grad():
    samples = torch.zeros(num_samples, batch_dim, self.num_variables)
    print(samples.shape, context.shape)
Contributor

debugging leftover?

sample_shape: number of samples to obtain.
condition: batch of parameters for prediction.
with torch.no_grad():
    samples = torch.zeros(num_samples, batch_dim, self.num_variables)
Contributor

should this be initialized with uniform torch.rand?

Successfully merging this pull request may close these issues.

Change MixedDensityEstimator to AutoregressiveMixedDensityEstimator
2 participants