Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bringing over tests/test_molevol.py #40

Merged
merged 1 commit into from
Jun 18, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions tests/test_molevol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import torch
import netam.molevol as molevol
from netam import framework

from netam.sequences import (
nt_idx_tensor_of_str,
translate_sequence,
AA_STR_SORTED,
CODONS,
NT_STR_SORTED,
)

# These happen to be the same as some examples in test_models.py but that's fine.
# If it was important that they were shared, we should put them in a conftest.py.
ex_mut_probs = torch.tensor([0.01, 0.02, 0.03])
ex_sub_probs = torch.tensor(
[[0.0, 0.3, 0.5, 0.2], [0.4, 0.0, 0.1, 0.5], [0.2, 0.3, 0.0, 0.5]]
)
ex_parent_codon_idxs = nt_idx_tensor_of_str("ACG")
parent_nt_seq = "CAGGTGCAGCTGGTGGAG" # QVQLVE
weights_path = "data/shmple_weights/my_shmoof"


def test_build_mutation_matrix():
correct_tensor = torch.tensor(
[
# probability of mutation to each nucleotide (first entry in the first row
# is probability of no mutation)
[0.99, 0.003, 0.005, 0.002],
[0.008, 0.98, 0.002, 0.01],
[0.006, 0.009, 0.97, 0.015],
]
)

computed_tensor = molevol.build_mutation_matrices(
ex_parent_codon_idxs.unsqueeze(0),
ex_mut_probs.unsqueeze(0),
ex_sub_probs.unsqueeze(0),
).squeeze()

assert torch.allclose(correct_tensor, computed_tensor)


def test_neutral_aa_mut_probs():
# This is the probability of a mutation to a codon that translates to the
# same. In this case, ACG is the codon, and it's fourfold degenerate. Thus
# we just multiply the probability of A and C staying the same from the
# correct_tensor just above.
correct_tensor = torch.tensor([1 - 0.99 * 0.98])

computed_tensor = molevol.neutral_aa_mut_probs(
ex_parent_codon_idxs.unsqueeze(0),
ex_mut_probs.unsqueeze(0),
ex_sub_probs.unsqueeze(0),
).squeeze()

assert torch.allclose(correct_tensor, computed_tensor)


def test_normalize_sub_probs():
parent_idxs = nt_idx_tensor_of_str("AC")
sub_probs = torch.tensor([[0.2, 0.3, 0.4, 0.1], [0.1, 0.2, 0.3, 0.4]])

expected_normalized = torch.tensor(
[[0.0, 0.375, 0.5, 0.125], [0.125, 0.0, 0.375, 0.5]]
)
normalized_sub_probs = molevol.normalize_sub_probs(parent_idxs, sub_probs)

assert normalized_sub_probs.shape == (2, 4), "Result has incorrect shape"
assert torch.allclose(
normalized_sub_probs, expected_normalized
), "Unexpected normalized values"


def iterative_aaprob_of_mut_and_sub(parent_codon, mut_probs, sub_probs):
"""
Original version of codon_to_aa_probabilities, used for testing.
"""
aa_probs = {}
for aa in AA_STR_SORTED:
aa_probs[aa] = 0.0

# iterate through all possible child codons
for child_codon in CODONS:
try:
aa = translate_sequence(child_codon)
except ValueError: # check for STOP codon
continue

# iterate through codon sites and compute total probability of potential child codon
child_prob = 1.0
for isite in range(3):
if parent_codon[isite] == child_codon[isite]:
child_prob *= 1.0 - mut_probs[isite]
else:
child_prob *= mut_probs[isite]
child_prob *= sub_probs[isite][NT_STR_SORTED.index(child_codon[isite])]

aa_probs[aa] += child_prob

# need renormalization factor so that amino acid probabilities sum to 1,
# since probabilities to STOP codon are dropped
psum = sum(aa_probs.values())

return torch.tensor([aa_probs[aa] / psum for aa in AA_STR_SORTED])


def test_aaprob_of_mut_and_sub():
crepe_path = "data/cnn_joi_sml-shmoof_small"
crepe = framework.load_crepe(crepe_path)
[rates], [subs] = crepe([parent_nt_seq])
mut_probs = 1.0 - torch.exp(-torch.tensor(rates.squeeze()))
parent_codon = parent_nt_seq[0:3]
parent_codon_idxs = nt_idx_tensor_of_str(parent_codon)
codon_mut_probs = mut_probs[0:3]
codon_subs = torch.tensor(subs[0:3])

iterative_result = iterative_aaprob_of_mut_and_sub(
parent_codon, codon_mut_probs, codon_subs
)

parent_codon_idxs = parent_codon_idxs.unsqueeze(0)
codon_mut_probs = codon_mut_probs.unsqueeze(0)
codon_subs = codon_subs.unsqueeze(0)

assert torch.allclose(
iterative_result,
molevol.aaprob_of_mut_and_sub(
parent_codon_idxs,
codon_mut_probs,
codon_subs,
).squeeze(),
)