MMvec refactor #166
Hi @Keegan-Evans, this is a great first pass. The basic architecture is there, and it looks like the gradient descent is working. There are a couple of things that we'll want to try out:
We may want to revisit the decoder architecture -- the softmax identifiability may end up biting us when interpreting the decoder parameters (probably also why you see word2vec applications interpreting only the encoder parameters, rather than the decoder parameters). In MMvec, we used the inverse ALR transform, which would look something like:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearALR(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.W = nn.Parameter(torch.randn(output_dim - 1, input_dim))
        self.b = nn.Parameter(torch.randn(output_dim - 1))

    def forward(self, x):
        y = x @ self.W.t() + self.b               # (batch, output_dim - 1) ALR coordinates
        z = torch.zeros((y.shape[0], 1), device=y.device)
        y = torch.cat((z, y), dim=1)              # prepend the zero reference component
        return F.softmax(y, dim=1)                # inverse ALR via softmax
```
ALR does have some issues (the factorization isn't going to be super accurate). We had to do some redundant computation in the original MMvec code to do SVD afterwards -- but it's most definitely going to lose some information (i.e. if we have k principal components, we'll recover fewer than k PCs due to the ALR). Using the ILR transform can help with this (see our preprint here). That code would look something like the following:

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from gneiss.cluster import random_linkage
from gneiss.balances import sparse_balance_basis

class LinearILR(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        tree = random_linkage(output_dim)  # pick a random tree, it doesn't really matter tbh
        # sparse_balance_basis is expected to return (scipy COO basis, node labels);
        # the basis has shape (output_dim - 1, output_dim)
        basis, _ = sparse_balance_basis(tree)
        indices = np.vstack((basis.row, basis.col))
        Psi = torch.sparse_coo_tensor(
            indices.copy(), basis.data.astype(np.float32).copy(),
            requires_grad=False).coalesce()
        self.linear = nn.Linear(input_dim, output_dim - 1)
        self.register_buffer('Psi', Psi)

    def forward(self, x):
        y = self.linear(x)                        # (batch, output_dim - 1) balances
        logy = (self.Psi.t() @ y.t()).t()         # map balances to (batch, output_dim) logits
        return F.softmax(logy, dim=1)
```

We may want to have some small benchmarks in the unit tests covering all three of these approaches (plain softmax, ALR, ILR) -- based on what I've seen since the MMvec paper, there are going to be differences. My hunch is that ILR is going to be more convenient.
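Not committing to an implementation here, but a minimal sketch of what such a benchmark inside the unit tests might look like, assuming all three decoders share the `(input_dim, output_dim)` interface from the snippets above; `fit_decoder`, the fake data, and the tolerance are placeholders:

```python
import unittest
import torch
import torch.nn as nn
import torch.nn.functional as F

def fit_decoder(decoder, x, target, epochs=500, lr=0.1):
    """Fit a decoder to fixed target compositions; return the final KL divergence."""
    opt = torch.optim.Adam(decoder.parameters(), lr=lr)
    for _ in range(epochs):
        opt.zero_grad()
        probs = decoder(x)
        loss = F.kl_div(torch.log(probs + 1e-10), target, reduction='batchmean')
        loss.backward()
        opt.step()
    return loss.item()

class TestDecoders(unittest.TestCase):
    def test_decoders_recover_composition(self):
        torch.manual_seed(0)
        x = torch.randn(32, 3)                          # fake latent embeddings
        target = F.softmax(torch.randn(32, 10), dim=1)  # fake target compositions
        decoders = [
            nn.Sequential(nn.Linear(3, 10), nn.Softmax(dim=1)),  # plain softmax decoder
            LinearALR(3, 10),
            LinearILR(3, 10),
        ]
        for decoder in decoders:
            loss = fit_decoder(decoder, x, target)
            self.assertLess(loss, 0.5)                  # loose smoke-test tolerance
```

A timing comparison could live in the same place, but even a fit-quality smoke test like this should surface the identifiability differences between the three decoders.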
Here is a statistical significance test for differential abundance https://github.com/flatironinstitute/q2-matchmaker/blob/main/q2_matchmaker/_stats.py#L46
Hi @Keegan-Evans, this is necessary if the V matrix is in ALR coordinates. If we're using ILR, then it would be
@mortonjt, I was working on it this morning and realized that might be the case, thanks for the reply!
We're going to go with PyTorch OR NumPyro. The framework will have the following skeleton (a rough sketch follows below):
model.py (mmvec.py)
train.py (could use PyTorch Lightning)
The wishlist:
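As referenced above, here is a rough, hypothetical sketch of how the model.py / train.py split could look if we go the PyTorch + PyTorch Lightning route. The encoder is reduced to a bare embedding lookup, the decoder reuses the LinearILR sketch from earlier in the thread, and all names, dimensions, and the loss are placeholders rather than a settled design:

```python
import torch
import torch.nn as nn
import pytorch_lightning as pl

# --- model.py (mmvec.py) ---
class MMvec(nn.Module):
    def __init__(self, num_microbes, num_metabolites, latent_dim):
        super().__init__()
        self.encoder = nn.Embedding(num_microbes, latent_dim)
        self.decoder = LinearILR(latent_dim, num_metabolites)  # or LinearALR / plain softmax

    def forward(self, microbe_idx):
        z = self.encoder(microbe_idx)      # (batch, latent_dim) microbe embeddings
        return self.decoder(z)             # (batch, num_metabolites) probabilities

# --- train.py ---
class MMvecLightning(pl.LightningModule):
    def __init__(self, num_microbes, num_metabolites, latent_dim=3, lr=1e-3):
        super().__init__()
        self.model = MMvec(num_microbes, num_metabolites, latent_dim)
        self.lr = lr

    def training_step(self, batch, batch_idx):
        microbe_idx, metabolite_counts = batch
        probs = self.model(microbe_idx)
        # multinomial log-likelihood of the observed metabolite counts (up to a constant)
        loss = -(metabolite_counts * torch.log(probs + 1e-10)).sum()
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
```

Training would then be roughly `pl.Trainer(max_epochs=...).fit(module, loader)`, which keeps train.py down to data loading plus a Trainer call.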