Desirable MDN features for SNPE-C / SNPE-A #16

Open
michaeldeistler opened this issue Aug 18, 2020 · 1 comment
Labels
enhancement New feature or request


@michaeldeistler
Contributor

michaeldeistler commented Aug 18, 2020

After implementing non-atomic SNPE-C, I am opening this issue to keep track of features that would have been useful to have in mdn.

Get mixture components

The only public (non-protected) methods of mdn are log_prob() and sample(). It would be useful to have a public get_mixture_components(). Unlike the existing protected _get_mixture_components(), it should also apply the embedding_net.
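A minimal sketch of the proposed public method. It assumes a toy MDN whose protected `_get_mixture_components()` expects an already-embedded context and returns the logits and means (precision factors are omitted for brevity). The class `ToyMDN`, its layers and its return values are hypothetical illustrations, not the actual mdn code.

```python
import torch
from torch import nn


class ToyMDN(nn.Module):
    """Hypothetical, simplified MDN used only to illustrate the proposed API."""

    def __init__(self, context_dim, hidden_dim, output_dim, n_mixtures, embedding_net=None):
        super().__init__()
        self.output_dim = output_dim
        self.n_mixtures = n_mixtures
        # Fall back to the identity if no embedding net is given.
        self.embedding_net = embedding_net if embedding_net is not None else nn.Identity()
        self.hidden = nn.Sequential(nn.Linear(context_dim, hidden_dim), nn.Tanh())
        self.logits_layer = nn.Linear(hidden_dim, n_mixtures)
        self.means_layer = nn.Linear(hidden_dim, n_mixtures * output_dim)

    def _get_mixture_components(self, embedded_context):
        # Protected helper: expects a context that has ALREADY been embedded.
        h = self.hidden(embedded_context)
        logits = self.logits_layer(h)
        means = self.means_layer(h).view(-1, self.n_mixtures, self.output_dim)
        return logits, means

    def get_mixture_components(self, context):
        # Proposed public method: apply the embedding net first,
        # then delegate to the protected helper.
        return self._get_mixture_components(self.embedding_net(context))


embedding = nn.Linear(10, 4)  # maps a raw 10-d context to a 4-d embedding
mdn = ToyMDN(context_dim=4, hidden_dim=16, output_dim=2, n_mixtures=3, embedding_net=embedding)
logits, means = mdn.get_mixture_components(torch.randn(5, 10))
print(logits.shape, means.shape)  # torch.Size([5, 3]) torch.Size([5, 3, 2])
```

Callers then pass the raw context and never need to know whether an embedding net exists.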

Evaluating the log_prob

The following code should be put in a separate static method called evaluate_mixture_log_prob():

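```python
import numpy as np
import torch

batch_size, n_mixtures, output_dim = means.size()
inputs = inputs.view(-1, 1, output_dim)

# Split up evaluation into parts.
a = logits - torch.logsumexp(logits, dim=-1, keepdim=True)  # log mixture weights
b = -(output_dim / 2.0) * np.log(2 * np.pi)  # Gaussian normalizing constant
c = sumlogdiag  # sum of log-diagonal of the precision's Cholesky factor
d1 = (inputs.expand_as(means) - means).view(
    batch_size, n_mixtures, output_dim, 1
)
d2 = torch.matmul(precisions, d1)
d = -0.5 * torch.matmul(torch.transpose(d1, 2, 3), d2).view(
    batch_size, n_mixtures
)  # quadratic (Mahalanobis) term

# Combine the per-component terms into the MoG log-density.
log_prob = torch.logsumexp(a + b + c + d, dim=-1)
```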

This would allow evaluating the log_prob of a mixture of Gaussians (MoG) without instantiating an mdn. Since snpe_c has to do this for every training data point at every iteration, this would be computationally cheaper.

Together with the public get_mixture_components() proposed above, this would fully separate the two main steps of log_prob() in an mdn: computing the mixture components and evaluating the MoG density.
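The proposed static method could be sketched roughly as follows. The sketch assumes each component's precision is given as an upper-triangular factor `U` with `precision = U^T U`, so that `sumlogdiag = sum(log(diag(U)))` equals half the log-determinant of the precision. The holder class `MoGSketch` and the argument order are hypothetical; the check compares against `torch.distributions.MixtureSameFamily`.

```python
import math

import torch
from torch import distributions as D


class MoGSketch:
    """Hypothetical holder for the proposed static method."""

    @staticmethod
    def evaluate_mixture_log_prob(inputs, logits, means, precision_factors):
        # inputs: (batch, dim); logits: (batch, K); means: (batch, K, dim);
        # precision_factors: upper-triangular U with precision = U^T U, (batch, K, dim, dim).
        batch_size, n_mixtures, output_dim = means.shape
        precisions = precision_factors.transpose(-2, -1) @ precision_factors
        # sum(log(diag(U))) equals 0.5 * log det(precision) for positive diag(U).
        sumlogdiag = torch.log(torch.diagonal(precision_factors, dim1=-2, dim2=-1)).sum(-1)

        a = logits - torch.logsumexp(logits, dim=-1, keepdim=True)  # log weights
        b = -(output_dim / 2.0) * math.log(2 * math.pi)  # normalizing constant
        diff = (inputs.view(-1, 1, output_dim) - means).unsqueeze(-1)  # (batch, K, dim, 1)
        d = -0.5 * (diff.transpose(-2, -1) @ precisions @ diff).view(batch_size, n_mixtures)
        return torch.logsumexp(a + b + sumlogdiag + d, dim=-1)


# Check against torch.distributions on random float64 inputs.
torch.manual_seed(0)
batch, K, dim = 4, 3, 2
opts = dict(dtype=torch.float64)
logits = torch.randn(batch, K, **opts)
means = torch.randn(batch, K, dim, **opts)
# Random strictly-upper part plus a diagonal in [0.5, 1.5), so U is invertible.
U = torch.triu(torch.randn(batch, K, dim, dim, **opts), diagonal=1) + torch.diag_embed(
    torch.rand(batch, K, dim, **opts) + 0.5
)
x = torch.randn(batch, dim, **opts)

ours = MoGSketch.evaluate_mixture_log_prob(x, logits, means, U)
reference = D.MixtureSameFamily(
    D.Categorical(logits=logits),
    D.MultivariateNormal(means, precision_matrix=U.transpose(-2, -1) @ U),
).log_prob(x)
print(torch.allclose(ours, reference))  # True
```

Because the method only takes tensors, snpe_c could call it directly on the components it already has, without constructing an mdn object.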

Log-prob based on cov

Right now, log_prob uses sumlogdiag (the sum of the log-diagonal of the Cholesky factor). If the Cholesky transform is not yet available, it would be better to use log(det(cov)) directly.
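In terms of the covariance, the `sumlogdiag` term corresponds to `-0.5 * log(det(cov))`. A minimal sketch, assuming only a (positive-definite) covariance is available, computes a Gaussian log-density with `torch.logdet` and `torch.linalg.solve` instead of a Cholesky factor. The helper name `gaussian_log_prob_from_cov` is hypothetical.

```python
import math

import torch
from torch import distributions as D


def gaussian_log_prob_from_cov(x, mean, cov):
    """Hypothetical helper: log N(x; mean, cov) using log(det(cov)) directly."""
    dim = mean.shape[-1]
    diff = (x - mean).unsqueeze(-1)  # (..., dim, 1)
    # (x - mean)^T cov^{-1} (x - mean), via a linear solve instead of an explicit inverse.
    mahalanobis = (diff.transpose(-2, -1) @ torch.linalg.solve(cov, diff)).squeeze(-1).squeeze(-1)
    return -0.5 * (dim * math.log(2 * math.pi) + torch.logdet(cov) + mahalanobis)


torch.manual_seed(0)
A = torch.randn(3, 3, dtype=torch.float64)
cov = A @ A.T + 3 * torch.eye(3, dtype=torch.float64)  # symmetric positive definite
mean = torch.zeros(3, dtype=torch.float64)
x = torch.randn(3, dtype=torch.float64)

ours = gaussian_log_prob_from_cov(x, mean, cov)
reference = D.MultivariateNormal(mean, covariance_matrix=cov).log_prob(x)
print(torch.allclose(ours, reference))  # True
```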

Different init strategy for the means

The means are currently initialized close to 0. Initializing them at more spread-out, random locations might work better.
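One possible alternative, sketched under the assumption that all component means come from a single final linear layer (hypothetically named `means_layer`): keep its weights small and draw its bias from a wider normal distribution, so the initial means are spread out across components instead of clustering at 0. The value `std=1.0` is an arbitrary illustrative choice.

```python
import torch
from torch import nn

torch.manual_seed(0)
n_mixtures, output_dim, hidden_dim = 5, 2, 16
# Hypothetical final layer that outputs all component means at once.
means_layer = nn.Linear(hidden_dim, n_mixtures * output_dim)

# Small weights: the initial means depend only weakly on the context ...
nn.init.normal_(means_layer.weight, mean=0.0, std=1e-3)
# ... while the bias places each component at a different random location.
nn.init.normal_(means_layer.bias, mean=0.0, std=1.0)

# With a zero hidden state, the layer output is exactly the bias.
h = torch.zeros(1, hidden_dim)
initial_means = means_layer(h).view(n_mixtures, output_dim)
print(initial_means)  # one random 2-D location per component
```

Scaling the bias spread to the typical range of the (standardized) parameters would be a natural refinement.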

Variable number of layers

This requires a forward function that loops over a torch.nn.ModuleList.
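A minimal sketch of such a forward function. The module name `HiddenStack` and the tanh activation are hypothetical choices; the point is that `nn.ModuleList` registers every layer, so its parameters are trained and moved by `.to()`, which a plain Python list would not do.

```python
import torch
from torch import nn


class HiddenStack(nn.Module):
    """Hypothetical MDN trunk with a configurable number of hidden layers."""

    def __init__(self, input_dim, hidden_dim, num_layers):
        super().__init__()
        dims = [input_dim] + [hidden_dim] * num_layers
        # One Linear layer per consecutive pair of dimensions.
        self.layers = nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )

    def forward(self, x):
        # Loop over however many layers were requested.
        for layer in self.layers:
            x = torch.tanh(layer(x))
        return x


net = HiddenStack(input_dim=3, hidden_dim=8, num_layers=4)
print(len(net.layers), net(torch.randn(10, 3)).shape)  # 4 torch.Size([10, 8])
```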

michaeldeistler self-assigned this Aug 18, 2020
michaeldeistler added the enhancement New feature or request label Aug 18, 2020
@alvorithm

I am adding the following item to the list:

Update MDNs to run on GPUs
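A common reason why module code runs only on the CPU is that tensors created inside `forward` (for example with `torch.eye` or `torch.zeros`) default to the CPU even when the parameters live on a GPU. A minimal sketch of the device-agnostic pattern, using a hypothetical `ToyPrecisionHead`: allocate such tensors on the input's device and dtype, and move the whole module with `.to(device)`. Whether this is the specific change the mdn needs is an assumption.

```python
import torch
from torch import nn


class ToyPrecisionHead(nn.Module):
    """Hypothetical head that builds a per-sample diagonal precision matrix."""

    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        self.output_dim = output_dim
        self.log_diag = nn.Linear(hidden_dim, output_dim)

    def forward(self, h):
        # Allocate the identity on the same device and dtype as the input,
        # instead of torch.eye(self.output_dim), which would live on the CPU.
        eye = torch.eye(self.output_dim, device=h.device, dtype=h.dtype)
        diag = torch.exp(self.log_diag(h))  # (batch, output_dim), positive
        return eye * diag.unsqueeze(-1)  # (batch, output_dim, output_dim)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
head = ToyPrecisionHead(hidden_dim=8, output_dim=3).to(device)
precisions = head(torch.randn(5, 8, device=device))
print(precisions.shape, precisions.device.type == device.type)  # torch.Size([5, 3, 3]) True
```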
