Skip to content

Commit 6a00b46

Browse files
add map.py and use in basic_demo.ipynb
1 parent fa0e2f7 commit 6a00b46

File tree

3 files changed

+232
-16
lines changed

3 files changed

+232
-16
lines changed

bvas/map.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import torch
2+
from torch import triangular_solve as trisolve
3+
4+
from bvas.util import safe_cholesky
5+
6+
7+
def map_inference(Y, Gamma, taus=[2 ** exponent for exponent in range(4, 16)]):
8+
r"""
9+
Use Maximum A Posteriori (MAP) inference and a diffusion-based likelihood to infer
10+
selection effects from genomic surveillance data. See reference [1] for details.
11+
12+
References:
13+
[1] "Inferring effects of mutations on SARS-CoV-2 transmission from genomic surveillance data,"
14+
Brian Lee, Muhammad Saqib Sohail, Elizabeth Finney, Syed Faraz Ahmed, Ahmed Abdul Quadeer,
15+
Matthew R. McKay, John P. Barton.
16+
17+
:param torch.Tensor Y: A torch.Tensor of shape (A,) that encodes integrated alelle frequency
18+
increments for each allele and where A is the number of alleles.
19+
:param torch.Tensor Gamma: A torch.Tensor of shape (A, A) that encodes information about
20+
second moments of allele frequencies.
21+
:param list taus: A list of floats encoding regularizers `tau_reg` to use in MAP inference, i.e. we run
22+
MAP once for each value of `tau_reg`. Note that this quantity is called `gamma` in reference [1].
23+
24+
:returns dict: Returns a dictionary of inferred selection coefficients beta, one for each value
25+
in `taus`.
26+
"""
27+
results = {}
28+
29+
for tau_reg in taus:
30+
L_tau = safe_cholesky(Gamma + tau_reg * torch.eye(Gamma.size(-1)).type_as(Gamma))
31+
Yt = trisolve(Y.unsqueeze(-1), L_tau, upper=False)[0]
32+
beta = trisolve(Yt, L_tau.t(), upper=True)[0].squeeze(-1)
33+
results['map_{}'.format(tau_reg)] = {'beta': beta.data.cpu().numpy(),
34+
'tau_reg': tau_reg}
35+
36+
return results

data/covid_preprocessing.py

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55
import pandas as pd
66
import torch
7+
78
from bvas.util import get_longest_ones_index
89

910

notebooks/basic_demo.ipynb

+195-16
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,28 @@
22
"cells": [
33
{
44
"cell_type": "markdown",
5-
"id": "3ceb4719",
5+
"id": "acc138b0",
66
"metadata": {},
77
"source": [
88
"# Basic BVAS demo using simulated data"
99
]
1010
},
1111
{
1212
"cell_type": "code",
13-
"execution_count": 1,
14-
"id": "aceebab8",
13+
"execution_count": 25,
14+
"id": "105ca8b5",
1515
"metadata": {},
1616
"outputs": [],
1717
"source": [
18-
"from bvas import simulate_data, BVASSelector"
18+
"from bvas import simulate_data, BVASSelector\n",
19+
"from bvas.map import map_inference\n",
20+
"import pandas as pd\n",
21+
"import numpy as np"
1922
]
2023
},
2124
{
2225
"cell_type": "markdown",
23-
"id": "fef691b9",
26+
"id": "d4ae7daa",
2427
"metadata": {},
2528
"source": [
2629
"### Simulate data"
@@ -29,7 +32,7 @@
2932
{
3033
"cell_type": "code",
3134
"execution_count": 2,
32-
"id": "8f170790",
35+
"id": "67c8b3f2",
3336
"metadata": {},
3437
"outputs": [],
3538
"source": [
@@ -47,7 +50,7 @@
4750
{
4851
"cell_type": "code",
4952
"execution_count": 3,
50-
"id": "f74e7b1b",
53+
"id": "a8592b05",
5154
"metadata": {},
5255
"outputs": [
5356
{
@@ -73,7 +76,7 @@
7376
},
7477
{
7578
"cell_type": "markdown",
76-
"id": "24a7ce7f",
79+
"id": "929783f8",
7780
"metadata": {},
7881
"source": [
7982
"### Instantiate BVASSelector object"
@@ -82,7 +85,7 @@
8285
{
8386
"cell_type": "code",
8487
"execution_count": 4,
85-
"id": "617cb379",
88+
"id": "4a7a3d81",
8689
"metadata": {},
8790
"outputs": [],
8891
"source": [
@@ -99,7 +102,7 @@
99102
},
100103
{
101104
"cell_type": "markdown",
102-
"id": "884fedf0",
105+
"id": "56b0072c",
103106
"metadata": {},
104107
"source": [
105108
"### Run BVAS MCMC-based inference"
@@ -108,13 +111,13 @@
108111
{
109112
"cell_type": "code",
110113
"execution_count": 5,
111-
"id": "77bcb9bd",
114+
"id": "9c285218",
112115
"metadata": {},
113116
"outputs": [
114117
{
115118
"data": {
116119
"application/vnd.jupyter.widget-view+json": {
117-
"model_id": "cc591be1cd164f68be7f385dc2537701",
120+
"model_id": "3ddab09989224a008f51f31c69706a59",
118121
"version_major": 2,
119122
"version_minor": 0
120123
},
@@ -132,7 +135,7 @@
132135
},
133136
{
134137
"cell_type": "markdown",
135-
"id": "86691da7",
138+
"id": "0b3a07cc",
136139
"metadata": {},
137140
"source": [
138141
"### Inspect results\n",
@@ -149,7 +152,7 @@
149152
{
150153
"cell_type": "code",
151154
"execution_count": 6,
152-
"id": "f1ae71fd",
155+
"id": "a13e39fc",
153156
"metadata": {},
154157
"outputs": [
155158
{
@@ -182,7 +185,7 @@
182185
{
183186
"cell_type": "code",
184187
"execution_count": 9,
185-
"id": "6d2a48cd",
188+
"id": "81e46a30",
186189
"metadata": {},
187190
"outputs": [
188191
{
@@ -211,13 +214,189 @@
211214
{
212215
"cell_type": "code",
213216
"execution_count": 10,
214-
"id": "e5b9424c",
217+
"id": "7b64636a",
215218
"metadata": {},
216219
"outputs": [],
217220
"source": [
218221
"# the remaining coefficients are all zero\n",
219222
"assert data['true_betas'][10:].min().item() == data['true_betas'][10:].max().item() == 0.0"
220223
]
224+
},
225+
{
226+
"cell_type": "markdown",
227+
"id": "7149875f",
228+
"metadata": {},
229+
"source": [
230+
"# Compare to MAP inference\n",
231+
"\n",
232+
"Let's compare to Maximum A posteriorir (i.e. MAP) inference as in [Inferring effects of mutations on SARS-CoV-2 transmission from genomic surveillance data](https://www.medrxiv.org/content/10.1101/2021.12.31.21268591v2)."
233+
]
234+
},
235+
{
236+
"cell_type": "code",
237+
"execution_count": 64,
238+
"id": "975411eb",
239+
"metadata": {},
240+
"outputs": [],
241+
"source": [
242+
"map_results = map_inference(data['Y'], data['Gamma'], taus=[2048.0])\n",
243+
"inferred_beta = map_results['map_2048.0']['beta']"
244+
]
245+
},
246+
{
247+
"cell_type": "code",
248+
"execution_count": 65,
249+
"id": "60cffa06",
250+
"metadata": {},
251+
"outputs": [],
252+
"source": [
253+
"# package results as Pandas DataFrame\n",
254+
"inferred_beta = pd.DataFrame(inferred_beta, index=mutations, columns=['Beta'])\n",
255+
"inferred_beta['BetaAbs'] = np.fabs(inferred_beta)\n",
256+
"inferred_beta = inferred_beta.sort_values(by='BetaAbs', ascending=False)\n",
257+
"inferred_beta['Rank'] = 1 + np.arange(inferred_beta.shape[0])\n",
258+
"inferred_beta = inferred_beta[['Beta', 'Rank']]"
259+
]
260+
},
261+
{
262+
"cell_type": "code",
263+
"execution_count": 67,
264+
"id": "45a53f68",
265+
"metadata": {},
266+
"outputs": [
267+
{
268+
"data": {
269+
"text/html": [
270+
"<div>\n",
271+
"<style scoped>\n",
272+
" .dataframe tbody tr th:only-of-type {\n",
273+
" vertical-align: middle;\n",
274+
" }\n",
275+
"\n",
276+
" .dataframe tbody tr th {\n",
277+
" vertical-align: top;\n",
278+
" }\n",
279+
"\n",
280+
" .dataframe thead th {\n",
281+
" text-align: right;\n",
282+
" }\n",
283+
"</style>\n",
284+
"<table border=\"1\" class=\"dataframe\">\n",
285+
" <thead>\n",
286+
" <tr style=\"text-align: right;\">\n",
287+
" <th></th>\n",
288+
" <th>Beta</th>\n",
289+
" <th>Rank</th>\n",
290+
" </tr>\n",
291+
" </thead>\n",
292+
" <tbody>\n",
293+
" <tr>\n",
294+
" <th>Causal9</th>\n",
295+
" <td>-0.053871</td>\n",
296+
" <td>1</td>\n",
297+
" </tr>\n",
298+
" <tr>\n",
299+
" <th>Causal5</th>\n",
300+
" <td>0.049838</td>\n",
301+
" <td>2</td>\n",
302+
" </tr>\n",
303+
" <tr>\n",
304+
" <th>Causal10</th>\n",
305+
" <td>-0.048263</td>\n",
306+
" <td>3</td>\n",
307+
" </tr>\n",
308+
" <tr>\n",
309+
" <th>Causal4</th>\n",
310+
" <td>0.045866</td>\n",
311+
" <td>4</td>\n",
312+
" </tr>\n",
313+
" <tr>\n",
314+
" <th>Causal3</th>\n",
315+
" <td>0.027333</td>\n",
316+
" <td>5</td>\n",
317+
" </tr>\n",
318+
" <tr>\n",
319+
" <th>Causal8</th>\n",
320+
" <td>-0.021542</td>\n",
321+
" <td>6</td>\n",
322+
" </tr>\n",
323+
" <tr>\n",
324+
" <th>Spurious80</th>\n",
325+
" <td>0.020984</td>\n",
326+
" <td>7</td>\n",
327+
" </tr>\n",
328+
" <tr>\n",
329+
" <th>Spurious44</th>\n",
330+
" <td>-0.017381</td>\n",
331+
" <td>8</td>\n",
332+
" </tr>\n",
333+
" <tr>\n",
334+
" <th>Spurious68</th>\n",
335+
" <td>-0.015019</td>\n",
336+
" <td>9</td>\n",
337+
" </tr>\n",
338+
" <tr>\n",
339+
" <th>Spurious61</th>\n",
340+
" <td>0.014249</td>\n",
341+
" <td>10</td>\n",
342+
" </tr>\n",
343+
" <tr>\n",
344+
" <th>Spurious38</th>\n",
345+
" <td>0.014112</td>\n",
346+
" <td>11</td>\n",
347+
" </tr>\n",
348+
" <tr>\n",
349+
" <th>Spurious85</th>\n",
350+
" <td>0.012077</td>\n",
351+
" <td>12</td>\n",
352+
" </tr>\n",
353+
" <tr>\n",
354+
" <th>Spurious90</th>\n",
355+
" <td>0.012060</td>\n",
356+
" <td>13</td>\n",
357+
" </tr>\n",
358+
" <tr>\n",
359+
" <th>Spurious66</th>\n",
360+
" <td>0.011890</td>\n",
361+
" <td>14</td>\n",
362+
" </tr>\n",
363+
" <tr>\n",
364+
" <th>Spurious70</th>\n",
365+
" <td>0.011479</td>\n",
366+
" <td>15</td>\n",
367+
" </tr>\n",
368+
" </tbody>\n",
369+
"</table>\n",
370+
"</div>"
371+
],
372+
"text/plain": [
373+
" Beta Rank\n",
374+
"Causal9 -0.053871 1\n",
375+
"Causal5 0.049838 2\n",
376+
"Causal10 -0.048263 3\n",
377+
"Causal4 0.045866 4\n",
378+
"Causal3 0.027333 5\n",
379+
"Causal8 -0.021542 6\n",
380+
"Spurious80 0.020984 7\n",
381+
"Spurious44 -0.017381 8\n",
382+
"Spurious68 -0.015019 9\n",
383+
"Spurious61 0.014249 10\n",
384+
"Spurious38 0.014112 11\n",
385+
"Spurious85 0.012077 12\n",
386+
"Spurious90 0.012060 13\n",
387+
"Spurious66 0.011890 14\n",
388+
"Spurious70 0.011479 15"
389+
]
390+
},
391+
"execution_count": 67,
392+
"metadata": {},
393+
"output_type": "execute_result"
394+
}
395+
],
396+
"source": [
397+
"# MAP places 6/10 of the causal alleles at the top\n",
398+
"inferred_beta.iloc[:15]"
399+
]
221400
}
222401
],
223402
"metadata": {

0 commit comments

Comments
 (0)