
MMvec refactor #166

Open · mortonjt opened this issue Mar 22, 2022 · 6 comments

Comments

mortonjt (Collaborator) commented Mar 22, 2022

We're going to go with PyTorch or NumPyro. The framework will have the following skeleton:

model.py (mmvec.py)

import torch
import torch.nn as nn
from torch.distributions import Multinomial

class MMvec(nn.Module):
    def __init__(self, num_microbes, num_metabolites, latent_dim):
        super().__init__()
        self.encoder = nn.Embedding(num_microbes, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, num_metabolites),
            nn.Softmax(dim=-1))
        # TODO : may want to have a better softmax

    def forward(self, X, Y):
        """ X holds the microbe ids for the batch (the one-hot encodings in
        index form, since nn.Embedding takes integer indices).  Y is metabolite
        abundances (B x num_metabolites).  B is the batch size."""
        z = self.encoder(X)
        pred_y = self.decoder(z)
        # NOTE: total counts vary per sample, so we may need explicit
        # total_count handling or validate_args=False here
        lp = Multinomial(probs=pred_y).log_prob(Y).mean()
        return lp

train.py (could use PyTorch Lightning)
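
A minimal sketch of what the training loop could look like in plain PyTorch (the names microbe_idx and metabolites and the hyperparameters are placeholders, not part of the design yet):

import torch
from torch.utils.data import DataLoader, TensorDataset

def fit(model, microbe_idx, metabolites, epochs=100, batch_size=50, lr=1e-3):
    # microbe_idx: (N,) long tensor of microbe ids
    # metabolites: (N, num_metabolites) float tensor of counts
    loader = DataLoader(TensorDataset(microbe_idx, metabolites),
                        batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        for X, Y in loader:
            optimizer.zero_grad()
            loss = -model(X, Y)   # forward() returns the mean multinomial log-likelihood
            loss.backward()
            optimizer.step()
    return model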

The wishlist

  • Early stopping (see video for example; a rough sketch follows this list)
  • ArviZ for diagnostics
  • Typing would be great. See torchtyping
  • Torch tests could be cool also. See torchtest
  • Being Bayesian would be nice. SWAG is the laziest approach
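
For the early stopping item, a rough patience-based sketch in plain PyTorch (a hedged example; val_idx, val_metabolites, and the patience value are placeholders, and PyTorch Lightning's EarlyStopping callback would be the tidier route if we go with Lightning):

def fit_with_early_stopping(model, loader, val_idx, val_metabolites,
                            max_epochs=1000, patience=10, lr=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    best_val, stalled = float('-inf'), 0
    for epoch in range(max_epochs):
        for X, Y in loader:
            optimizer.zero_grad()
            (-model(X, Y)).backward()
            optimizer.step()
        with torch.no_grad():
            val_lp = model(val_idx, val_metabolites).item()  # held-out log-likelihood
        if val_lp > best_val:
            best_val, stalled = val_lp, 0
        else:
            stalled += 1
            if stalled >= patience:
                break  # stop once validation stops improving
    return model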
Keegan-Evans commented Apr 12, 2022

mortonjt (Collaborator, Author) commented Apr 13, 2022

Hi @Keegan-Evans, this is a great first pass. The basic architecture is there, and it looks like the gradient descent is working.

There are a couple of things that we'll want to try out:

  1. Getting the unit tests to pass at https://github.com/biocore/mmvec/blob/master/mmvec/tests/test_multimodal.py#L18
  2. Double-checking the soils experiment at https://github.com/biocore/mmvec/blob/master/mmvec/tests/test_multimodal.py#L76

We may want to revisit the decoder architecture -- the softmax identifiability may end up biting us when interpreting the decoder parameters (probably also why word2vec applications interpret only the encoder parameters rather than the decoder parameters). In MMvec, we used the inverse ALR transform, which would look something like

import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearALR(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.W = nn.Parameter(torch.randn(output_dim - 1, input_dim))
        self.b = nn.Parameter(torch.randn(output_dim - 1))

    def forward(self, x):
        y = x @ self.W.t() + self.b      # (B, output_dim - 1) ALR coordinates
        z = torch.zeros((y.shape[0], 1))
        y = torch.cat((z, y), dim=1)     # prepend the zero reference coordinate
        return F.softmax(y, dim=1)       # inverse ALR
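
As a quick, hypothetical sanity check of the head above (the dimensions are arbitrary), the outputs land on the simplex, with the first output acting as the ALR reference:

alr = LinearALR(input_dim=3, output_dim=10)
probs = alr(torch.randn(5, 3))   # (5, 10)
print(probs.sum(dim=1))          # each row sums to 1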

ALR does have some issues (the factorization isn't going to be super accurate). We had to do some redundant computation in the original MMvec code to do SVD afterwards -- but it's most definitely going to lose some information (i.e. if we have k principal components, we'll recover fewer than k PCs due to the ALR). Using the ILR transform can help with this (see our preprint here). That code would look something like the following

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from gneiss.cluster import random_linkage
from gneiss.balances import sparse_balance_basis

class LinearILR(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        tree = random_linkage(output_dim)  # pick a random tree, it doesn't really matter tbh
        basis = sparse_balance_basis(tree)[0]  # sparse (output_dim - 1) x output_dim basis
        indices = np.vstack((basis.row, basis.col))
        Psi = torch.sparse_coo_tensor(
            indices.copy(), basis.data.astype(np.float32).copy(),
            requires_grad=False).coalesce()
        self.linear = nn.Linear(input_dim, output_dim - 1)
        self.register_buffer('Psi', Psi)

    def forward(self, x):
        y = self.linear(x)                 # ILR coordinates, (B, output_dim - 1)
        logy = (self.Psi.t() @ y.t()).t()  # back to (B, output_dim)
        return F.softmax(logy, dim=1)
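
If this pans out, either head could drop in as the MMvec decoder; a hypothetical swap, reusing the names from the sketch at the top of the thread:

# inside MMvec.__init__, replacing the Linear + Softmax stack:
self.decoder = LinearILR(latent_dim, num_metabolites)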

We may want to run some small benchmarks alongside the unit tests for all three of these approaches -- based on what I've seen since the MMvec paper, there are going to be differences, and my hunch is that ILR is going to be more convenient.
Of course, we can talk offline about this.

mortonjt (Collaborator, Author) commented
Here is a statistical significance test for differential abundance

https://github.com/flatironinstitute/q2-matchmaker/blob/main/q2_matchmaker/_stats.py#L46

Keegan-Evans commented
@mortonjt,

In your original model, when constructing the ranks, you stack zeros onto the product of U and V here. Is that necessary, or just expedient for how you performed the ALR?

mortonjt (Collaborator, Author) commented
Hi @Keegan-Evans, this is necessary if the V matrix is in ALR coordinates. If we're using ILR, then it would be U @ V @ Psi.T (if Psi is D x (D - 1)).
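
A rough sketch of the two rank constructions being discussed (U, V, and Psi here are placeholder dense tensors: U is num_microbes x k, V is k x (D - 1), and Psi is D x (D - 1)):

import torch

def ranks_alr(U, V):
    # ALR coordinates: stack a zero column for the reference metabolite
    UV = U @ V                            # num_microbes x (D - 1)
    zeros = torch.zeros((UV.shape[0], 1))
    return torch.cat((zeros, UV), dim=1)  # num_microbes x D

def ranks_ilr(U, V, Psi):
    # ILR coordinates: map back through the basis, U @ V @ Psi.T
    return U @ V @ Psi.t()                # num_microbes x D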

Keegan-Evans commented Apr 22, 2022

@mortonjt, I was working on it this morning and realized that might be the case, thanks for the reply!
