Skip to content

Commit

Permalink
improve map api
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjankowiak committed Apr 13, 2022
1 parent a2b50ac commit e376890
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 105 deletions.
28 changes: 15 additions & 13 deletions bvas/map.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import numpy as np
import pandas as pd
import torch
from torch.linalg import solve_triangular as trisolve

from bvas.util import safe_cholesky


def map_inference(Y, Gamma, taus=[2 ** exponent for exponent in range(4, 16)]):
def map_inference(Y, Gamma, mutations, tau_reg):
r"""
Use Maximum A Posteriori (MAP) inference and a diffusion-based likelihood to infer
selection effects from genomic surveillance data. See reference [1] for details.
Expand All @@ -19,19 +21,19 @@ def map_inference(Y, Gamma, taus=[2 ** exponent for exponent in range(4, 16)]):
increments for each allele and where A is the number of alleles.
:param torch.Tensor Gamma: A torch.Tensor of shape (A, A) that encodes information about
second moments of allele frequencies.
:param list taus: A list of floats encoding regularizers `tau_reg` to use in MAP inference, i.e. we run
MAP once for each value of `tau_reg`. Note that this quantity is called `gamma` in reference [1].
:param list mutations: A list of strings of length `A` that encodes the names of the `A` alleles in `Y`.
:param float tau_reg: A positive float `tau_reg` that serves as the regularizer in MAP inference
along the lines of ridge regression. Note that this quantity is called `gamma` in reference [1].
:returns dict: Returns a dictionary of inferred selection coefficients beta, one for each value
in `taus`.
:returns pandas.DataFrame: Returns a `pd.DataFrame` containing results of inference.
"""
results = {}
L_tau = safe_cholesky(Gamma + tau_reg * torch.eye(Gamma.size(-1)).type_as(Gamma))
Yt = trisolve(L_tau, Y.unsqueeze(-1), upper=False)
beta = trisolve(L_tau.t(), Yt, upper=True).squeeze(-1)

for tau_reg in taus:
L_tau = safe_cholesky(Gamma + tau_reg * torch.eye(Gamma.size(-1)).type_as(Gamma))
Yt = trisolve(L_tau, Y.unsqueeze(-1), upper=False)
beta = trisolve(L_tau.t(), Yt, upper=True).squeeze(-1)
results['map_{}'.format(tau_reg)] = {'beta': beta.data.cpu().numpy(),
'tau_reg': tau_reg}
beta = pd.DataFrame(beta, index=mutations, columns=['Beta'])
beta['BetaAbs'] = np.fabs(beta.Beta.values)
beta = beta.sort_values(by='BetaAbs', ascending=False)
beta['Rank'] = 1 + np.arange(beta.shape[0])

return results
return beta[['Beta', 'Rank']]
Loading

0 comments on commit e376890

Please sign in to comment.