
Commit a2b50ac

improve lapalce api
1 parent b4a25af commit a2b50ac

File tree

3 files changed: +90 -90 lines changed


bvas/laplace.py

+15 -8
@@ -1,7 +1,5 @@
-"""
-Unlike most of the code in this repository, This code requires pyro.
-See https://github.com/pyro-ppl/pyro#installing for installation instructions.
-"""
+import numpy as np
+import pandas as pd
 import pyro
 import pyro.distributions as dist
 import torch
@@ -10,29 +8,33 @@
 from bvas.util import safe_cholesky


-def laplace_inference(Y, Gamma,
+def laplace_inference(Y, Gamma, mutations,
                       coef_scale=1.0e-2, seed=0, num_steps=10 ** 4,
                       log_every=500, init_lr=0.01):
     r"""
     Use Maximum A Posteriori (MAP) inference and a diffusion-based likelihood in conjunction
     with a sparsity-inducing Laplace prior on selection coefficients to infer
     selection effects from genomic surveillance data.

+    Unlike most of the code in this repository, `laplace_inference` depends on Pyro.
+
     :param torch.Tensor Y: A torch.Tensor of shape (A,) that encodes integrated alelle frequency
         increments for each allele and where A is the number of alleles.
     :param torch.Tensor Gamma: A torch.Tensor of shape (A, A) that encodes information about
         second moments of allele frequencies.
+    :param list mutations: A list of strings of length `A` that encodes the names of the `A` alleles in `Y`.
     :param float coef_scale: The regularization scale of the Laplace prior. Defaults to 0.01.
     :param int seed: Random number seed for reproducibility.
     :param int num_steps: The number of optimization steps to do. Defaults to ten thousand.
     :param int log_every: Controls logging frequency. Defaults to 500.
     :param float init_lr: The initial learning rate. Defaults to 0.01.

-    :returns dict: Returns a dictionary of containing the inferred selection coefficients beta.
+    :returns pandas.DataFrame: Returns a `pd.DataFrame` containing results of inference.
     """
     pyro.clear_param_store()

     A = Gamma.size(-1)
+    assert len(mutations) == A == Gamma.size(-2) == Y.size(0)

     L = safe_cholesky(Gamma, num_tries=10)
     L_Y = trisolve(L, Y.unsqueeze(-1), upper=False).squeeze(-1)
@@ -45,7 +47,8 @@ def fit_svi():
         pyro.set_rng_seed(seed)

         guide = pyro.infer.autoguide.AutoDelta(model)
-        optim = pyro.optim.ClippedAdam({"lr": init_lr, "lrd": 0.01 ** (1 / num_steps), "betas": (0.5, 0.99)})
+        optim = pyro.optim.ClippedAdam({"lr": init_lr, "lrd": 0.01 ** (1 / num_steps),
+                                        "betas": (0.5, 0.99)})
         svi = pyro.infer.SVI(model, guide, optim, pyro.infer.Trace_ELBO())

         for step in range(num_steps):
@@ -56,5 +59,9 @@ def fit_svi():
         return guide

     beta = fit_svi().median()['beta'].data.cpu().numpy()
+    beta = pd.DataFrame(beta, index=mutations, columns=['Beta'])
+    beta['BetaAbs'] = np.fabs(beta.Beta.values)
+    beta = beta.sort_values(by='BetaAbs', ascending=False)
+    beta['Rank'] = 1 + np.arange(beta.shape[0])

-    return {'beta': beta}
+    return beta[['Beta', 'Rank']]
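
For readers tracking the API change: `laplace_inference` now takes the allele names as a third positional argument and returns a ranked `pandas.DataFrame` instead of a dict. A minimal usage sketch with assumed toy inputs (the mutation names and tensor values below are made up for illustration; real `Y` and `Gamma` come from the usual BVAS preprocessing):

import torch

from bvas.laplace import laplace_inference

# Made-up toy problem with A = 3 alleles (names are illustrative only).
mutations = ["S:D614G", "S:N501Y", "ORF1a:T265I"]
Y = torch.tensor([0.8, 1.3, -0.2])   # integrated allele frequency increments, shape (A,)
Gamma = 10.0 * torch.eye(3)          # second-moment information, shape (A, A)

# New signature: allele names go in, a DataFrame indexed by those names comes out.
result = laplace_inference(Y, Gamma, mutations, num_steps=1000, log_every=200)
print(result)                        # columns ['Beta', 'Rank'], sorted by |Beta|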

docs/source/conf.py

+1
@@ -65,5 +65,6 @@
     "python": ("https://docs.python.org/3/", None),
     "torch": ("https://pytorch.org/docs/master/", None),
     "pyro": ("http://docs.pyro.ai/en/stable/", None),
+    "pandas": ("https://pandas.pydata.org/docs/", None),
     "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None),
 }
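
The conf.py hunk above registers pandas with Sphinx's intersphinx extension, so the new `:returns pandas.DataFrame:` annotation in `bvas/laplace.py` can cross-link to the pandas documentation. A sketch of the resulting dictionary, assuming this is the standard intersphinx_mapping setting and that only the entries visible in the hunk surround the new one (the real file may contain more projects outside the diff context):

# docs/source/conf.py (sketch; variable name and neighbouring entries assumed from the hunk)
intersphinx_mapping = {
    "python": ("https://docs.python.org/3/", None),
    "torch": ("https://pytorch.org/docs/master/", None),
    "pyro": ("http://docs.pyro.ai/en/stable/", None),
    "pandas": ("https://pandas.pydata.org/docs/", None),
    "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None),
}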
