Having implemented non-atomic SNPE-C, I am writing this issue to keep track of features that would be desirable to have in `mdn`.
Get mixture components
The only non-protected methods of `mdn` are `log_prob()` and `sample()`. It would be great to have a non-protected `get_mixture_components()`. Unlike the already existing, protected `_get_mixture_components()`, it should also call the `embedding_net`.
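A minimal sketch of what this could look like, using a hypothetical stand-in class (`MDNSketch`, the placeholder shapes, and the dummy `_get_mixture_components()` body are all assumptions for illustration, not the real `mdn` internals):

```python
import torch
from torch import nn


class MDNSketch(nn.Module):
    """Hypothetical minimal MDN illustrating a public get_mixture_components()
    that runs the embedding_net before the protected computation."""

    def __init__(self, embedding_net: nn.Module):
        super().__init__()
        self.embedding_net = embedding_net

    def _get_mixture_components(self, context: torch.Tensor):
        # Placeholder for the real computation of logits, means, precisions,
        # and sumlogdiag from the (already embedded) context.
        batch = context.shape[0]
        logits = torch.zeros(batch, 3)       # (batch, n_mixtures)
        means = torch.zeros(batch, 3, 2)     # (batch, n_mixtures, output_dim)
        return logits, means

    def get_mixture_components(self, context: torch.Tensor):
        # Public entry point: embed the raw context first, then delegate.
        embedded = self.embedding_net(context)
        return self._get_mixture_components(embedded)


net = MDNSketch(embedding_net=nn.Linear(5, 4))
logits, means = net.get_mixture_components(torch.randn(10, 5))
```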
Evaluating the log_prob
The following code should be put in a separate static method called `evaluate_mixture_log_prob()`:
```python
@staticmethod
def evaluate_mixture_log_prob(inputs, logits, means, precisions, sumlogdiag):
    batch_size, n_mixtures, output_dim = means.size()
    inputs = inputs.view(-1, 1, output_dim)

    # Split up the evaluation into parts.
    # a: log mixture weights (normalized logits).
    a = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
    # b: Gaussian normalizing constant.
    b = -(output_dim / 2.0) * np.log(2 * np.pi)
    # c: log-determinant term from the Cholesky factor of the precision.
    c = sumlogdiag
    # d: quadratic form -0.5 * (x - mu)^T P (x - mu).
    d1 = (inputs.expand_as(means) - means).view(
        batch_size, n_mixtures, output_dim, 1
    )
    d2 = torch.matmul(precisions, d1)
    d = -0.5 * torch.matmul(torch.transpose(d1, 2, 3), d2).view(
        batch_size, n_mixtures
    )

    return torch.logsumexp(a + b + c + d, dim=-1)
```
This would allow evaluating the `log_prob` of a MoG without instantiating an `mdn`. Since snpe_c has to do this for every training data point at every iteration, it would be computationally cheaper.
Along with the above refactoring of `get_mixture_components()`, this fully separates the two main steps of calling `log_prob()` in an `mdn`.
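As a sanity check (my own addition, not part of the issue), the decomposition above can be validated against `torch.distributions`. Variable names and shapes follow the snippet (`logits`, `means`, `precisions`, `sumlogdiag`); the random test tensors are arbitrary:

```python
import numpy as np
import torch

torch.manual_seed(0)
batch_size, n_mixtures, output_dim = 4, 3, 2
inputs = torch.randn(batch_size, output_dim)
logits = torch.randn(batch_size, n_mixtures)
means = torch.randn(batch_size, n_mixtures, output_dim)
# Random symmetric positive-definite precision matrices via A @ A^T + eps * I.
A = torch.randn(batch_size, n_mixtures, output_dim, output_dim)
precisions = A @ A.transpose(-1, -2) + 0.1 * torch.eye(output_dim)
# sumlogdiag: sum of the log-diagonal of the Cholesky factor of the
# precision, i.e. 0.5 * log det(precision).
sumlogdiag = (
    torch.linalg.cholesky(precisions).diagonal(dim1=-2, dim2=-1).log().sum(-1)
)

# The decomposition from the snippet.
x = inputs.view(-1, 1, output_dim)
a = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
b = -(output_dim / 2.0) * np.log(2 * np.pi)
c = sumlogdiag
d1 = (x.expand_as(means) - means).view(batch_size, n_mixtures, output_dim, 1)
d2 = torch.matmul(precisions, d1)
d = -0.5 * torch.matmul(torch.transpose(d1, 2, 3), d2).view(
    batch_size, n_mixtures
)
log_prob = torch.logsumexp(a + b + c + d, dim=-1)

# Reference implementation via torch.distributions.
mixture = torch.distributions.MixtureSameFamily(
    torch.distributions.Categorical(logits=logits),
    torch.distributions.MultivariateNormal(means, precision_matrix=precisions),
)
print(torch.allclose(log_prob, mixture.log_prob(inputs), atol=1e-4))
```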
Log-prob based on cov
Right now, we use `sumlogdiag` for `log_prob`. If one does not yet have the Cholesky factorization, it would be better to use log(det(cov)).
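A short sketch of the equivalence, assuming a single covariance matrix (the example matrix is arbitrary): since log det(precision) = -log det(cov), the `sumlogdiag` term equals -0.5 * log(det(cov)), so `torch.logdet` on the covariance can replace the Cholesky-based computation.

```python
import torch

cov = torch.tensor([[2.0, 0.3], [0.3, 1.0]])  # arbitrary SPD covariance
# Normalizing term computed directly from the covariance.
c_from_cov = -0.5 * torch.logdet(cov)
# Equivalent route via the Cholesky factor of the precision matrix,
# as currently done with sumlogdiag.
precision = torch.linalg.inv(cov)
sumlogdiag = torch.linalg.cholesky(precision).diagonal().log().sum()
print(torch.allclose(c_from_cov, sumlogdiag, atol=1e-5))
```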
Different init strategy for the means
Means are initialized close to 0. Maybe initializing them at more spread-out, random locations would work better.
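One possible sketch of such an init (the layer shapes and the `std=2.0` scale are arbitrary assumptions, not a tuned choice): widen the bias of the mean-predicting layer so the initial means are scattered rather than clustered at 0.

```python
import torch
from torch import nn

n_mixtures, output_dim, hidden = 3, 2, 16
# Hypothetical final layer that maps hidden features to all mixture means.
means_layer = nn.Linear(hidden, n_mixtures * output_dim)
with torch.no_grad():
    # Draw the bias from a wider distribution so the initial means
    # start at more random locations instead of near 0.
    means_layer.bias.normal_(mean=0.0, std=2.0)

means = means_layer(torch.zeros(1, hidden)).view(1, n_mixtures, output_dim)
```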
Variable number of layers
This requires writing a forward function that loops over a `torch.nn.ModuleList`.
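Such a forward pass could be sketched as follows (class name, dimensions, and the ReLU nonlinearity are illustrative assumptions):

```python
import torch
from torch import nn


class VariableDepthNet(nn.Module):
    """Sketch: the number of hidden layers becomes a constructor argument,
    and forward() loops over a torch.nn.ModuleList."""

    def __init__(self, in_dim: int, hidden_dim: int, n_layers: int):
        super().__init__()
        dims = [in_dim] + [hidden_dim] * n_layers
        self.layers = nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Loop over the ModuleList instead of hard-coding a fixed depth.
        for layer in self.layers:
            x = torch.relu(layer(x))
        return x


net = VariableDepthNet(in_dim=5, hidden_dim=8, n_layers=4)
out = net(torch.randn(10, 5))
```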