Skip to content

Commit

Permalink
Refactor neutral_aa_mut_probs to return per-AA information (#62)
Browse files Browse the repository at this point in the history
* Refactor neutral_aa_mut_probs to return per-AA information
* making test_multihit output ignored
* Fixing some warnings
  • Loading branch information
matsen authored Sep 27, 2024
1 parent 3f98c6c commit 03a061d
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 16 deletions.
62 changes: 50 additions & 12 deletions netam/molevol.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ def build_mutation_matrices(
probabilities for each parent codon along the sequence, this function
constructs a sequence of 3x4 matrices. Each matrix in the sequence
represents the mutation probabilities for each nucleotide position in a
parent codon. The ijkth entry of the resulting tensor corresponds to the probability of the jth nucleotide
in the ith parent codon mutating to the kth nucleotide (in indices).
parent codon. The ijkth entry of the resulting tensor corresponds to the
probability of the jth nucleotide in the ith parent codon mutating to the
kth nucleotide (in indices).
Args:
parent_codon_idxs (torch.Tensor): 2D tensor with each row containing indices representing
Expand Down Expand Up @@ -334,26 +335,22 @@ def build_codon_mutsel(
return codon_mutsel, sums_too_big


def neutral_aa_mut_probs(
def neutral_aa_probs(
parent_codon_idxs: Tensor,
codon_mut_probs: Tensor,
codon_sub_probs: Tensor,
) -> Tensor:
"""For every site, what is the probability that the amino acid will have a
substitution or mutate to a stop under neutral evolution?
This code computes all the probabilities and then indexes into that tensor
to get the relevant probabilities. This isn't the most efficient way to do
this, but it's the cleanest. We could make it faster as needed.
"""For every site, what is the probability that the amino acid will mutate to each
amino acid?
Args:
parent_codon_idxs (torch.Tensor): The parent codons for each sequence. Shape: (codon_count, 3)
codon_mut_probs (torch.Tensor): The mutation probabilities for each site in each codon. Shape: (codon_count, 3)
codon_sub_probs (torch.Tensor): The substitution probabilities for each site in each codon. Shape: (codon_count, 3, 4)
Returns:
torch.Tensor: The probability that each site will change amino acid.
Shape: (codon_count,)
torch.Tensor: The probability that each site will change to each amino acid.
Shape: (codon_count, 20)
"""

mut_matrices = build_mutation_matrices(
Expand All @@ -364,7 +361,21 @@ def neutral_aa_mut_probs(
# Get the probability of mutating to each amino acid.
aa_probs = codon_probs @ CODON_AA_INDICATOR_MATRIX

# Next we build a table that will allow us to look up the amino acid index
return aa_probs


def mut_probs_of_aa_probs(
parent_codon_idxs: Tensor,
aa_probs: Tensor,
) -> Tensor:
"""For every site, what is the probability that the amino acid will have a
substitution or mutate to a stop under neutral evolution?
Args:
parent_codon_idxs (torch.Tensor): The parent codons for each sequence. Shape: (codon_count, 3)
aa_probs (torch.Tensor): The probability that each site will change to each amino acid. Shape: (codon_count, 20)
"""
# We build a table that will allow us to look up the amino acid index
# from the codon indices. Argmax gets the aa index.
aa_idx_from_codon = CODON_AA_INDICATOR_MATRIX.argmax(dim=1).view(4, 4, 4)

Expand All @@ -381,6 +392,33 @@ def neutral_aa_mut_probs(
return 1.0 - p_staying_same


def neutral_aa_mut_probs(
    parent_codon_idxs: Tensor,
    codon_mut_probs: Tensor,
    codon_sub_probs: Tensor,
) -> Tensor:
    """Return, per codon site, the neutral probability of an amino acid change.

    "Change" here means an amino acid substitution or a mutation to a stop
    codon under neutral evolution.

    This is a thin composition of two steps: ``neutral_aa_probs`` computes the
    full per-amino-acid probability table for every site, and
    ``mut_probs_of_aa_probs`` reduces that table to a single probability of
    change per site. Computing the whole table and then indexing into it is
    not the most efficient approach, but it is the cleanest; it can be made
    faster if needed.

    Args:
        parent_codon_idxs (torch.Tensor): The parent codons for each sequence.
            Shape: (codon_count, 3)
        codon_mut_probs (torch.Tensor): The mutation probabilities for each
            site in each codon. Shape: (codon_count, 3)
        codon_sub_probs (torch.Tensor): The substitution probabilities for each
            site in each codon. Shape: (codon_count, 3, 4)

    Returns:
        torch.Tensor: The probability that each site will change to some other
            amino acid. Shape: (codon_count,)
    """
    return mut_probs_of_aa_probs(
        parent_codon_idxs,
        neutral_aa_probs(parent_codon_idxs, codon_mut_probs, codon_sub_probs),
    )


def mutsel_log_pcp_probability_of(sel_matrix, parent, child, rates, sub_probs):
"""Constructs the log_pcp_probability function specific to given rates and
sub_probs.
Expand Down
4 changes: 2 additions & 2 deletions tests/test_molevol.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,11 @@ def test_aaprob_of_mut_and_sub():
crepe_path = "data/cnn_joi_sml-shmoof_small"
crepe = framework.load_crepe(crepe_path)
[rates], [subs] = crepe([parent_nt_seq])
mut_probs = 1.0 - torch.exp(-torch.tensor(rates.squeeze()))
mut_probs = 1.0 - torch.exp(-rates.squeeze().clone().detach())
parent_codon = parent_nt_seq[0:3]
parent_codon_idxs = nt_idx_tensor_of_str(parent_codon)
codon_mut_probs = mut_probs[0:3]
codon_subs = torch.tensor(subs[0:3])
codon_subs = subs.clone().detach()[0:3]

iterative_result = iterative_aaprob_of_mut_and_sub(
parent_codon, codon_mut_probs, codon_subs
Expand Down
7 changes: 5 additions & 2 deletions tests/test_multihit.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

import netam.multihit as multihit
import netam.framework as framework
import netam.hit_class as hit_class
Expand Down Expand Up @@ -81,8 +83,9 @@ def test_train(hitclass_burrito):


def test_serialize(hitclass_burrito):
hitclass_burrito.save_crepe("test_multihit_crepe")
new_crepe = framework.load_crepe("test_multihit_crepe")
os.makedirs("_ignore", exist_ok=True)
hitclass_burrito.save_crepe("_ignore/test_multihit_crepe")
new_crepe = framework.load_crepe("_ignore/test_multihit_crepe")
assert torch.allclose(new_crepe.model.values, hitclass_burrito.model.values)


Expand Down

0 comments on commit 03a061d

Please sign in to comment.