diff --git a/mmvec/q2/_method.py b/mmvec/q2/_method.py index d71ee31..515f4ca 100644 --- a/mmvec/q2/_method.py +++ b/mmvec/q2/_method.py @@ -75,7 +75,9 @@ def paired_omics(microbes: biom.Table, ranks = ranks - ranks.mean(axis=1).values.reshape(-1, 1) ranks = ranks - ranks.mean(axis=0) u, s, v = svds(ranks, k=latent_dim) - + s = s[::-1] + u = u[:, ::-1] + v = v[::-1, :] microbe_embed = u @ np.diag(s) metabolite_embed = v.T diff --git a/mmvec/q2/tests/test_method.py b/mmvec/q2/tests/test_method.py index 6a4ef9f..0ce833b 100644 --- a/mmvec/q2/tests/test_method.py +++ b/mmvec/q2/tests/test_method.py @@ -59,6 +59,12 @@ def test_fit(self): res_biplot.features.shape, np.array([self.metabolites.shape[0], latent_dim])) + # make sure that the biplot has the correct ordering + self.assertGreater(res_biplot.proportion_explained[0], + res_biplot.proportion_explained[1]) + self.assertGreater(res_biplot.eigvals[0], + res_biplot.eigvals[1]) + if __name__ == "__main__": unittest.main() diff --git a/scripts/mmvec b/scripts/mmvec index 80b744c..1164c6d 100644 --- a/scripts/mmvec +++ b/scripts/mmvec @@ -198,6 +198,9 @@ def paired_omics(microbe_file, metabolite_file, # Save to an ordination file ranks = ranks - ranks.mean(axis=0) u, s, v = svds(ranks, k=latent_dim) + s = s[::-1] + u = u[:, ::-1] + v = v[::-1, :] microbe_embed = u @ np.diag(s) metabolite_embed = v.T pc_ids = ['PC%d' % i for i in range(microbe_embed.shape[1])]