Moving mutsel stuff over
matsen committed Jun 10, 2024
1 parent ea24a89 commit 1ca1a6b
Showing 2 changed files with 69 additions and 6 deletions.
22 changes: 17 additions & 5 deletions netam/dnsm.py
@@ -24,8 +24,6 @@
 
 from tqdm import tqdm
 
-from epam.models import WrappedBinaryMutSel
-
 from netam.common import (
     MAX_AMBIG_AA_IDX,
     aa_idx_tensor_of_str_ambig,
@@ -39,6 +37,7 @@
 import netam.sequences as sequences
 from netam.sequences import (
     aa_subs_indicator_tensor_of,
+    translate_sequence,
     translate_sequences,
 )
@@ -302,7 +301,6 @@ def train_val_datasets_of_pcp_df(pcp_df, branch_length_multiplier=5.0):
 class DNSMBurrito(framework.Burrito):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.wrapped_model = WrappedBinaryMutSel(self.model, weights_directory=None)
 
     def load_branch_lengths(self, in_csv_prefix):
         if self.train_dataset is not None:
@@ -339,6 +337,19 @@ def loss_of_batch(self, batch):
         predictions = self.predictions_of_batch(batch).masked_select(mask)
         return self.bce_loss(predictions, aa_subs_indicator)
 
+    def build_selection_matrix_from_parent(self, parent: str):
+        parent = translate_sequence(parent)
+        selection_factors = self.model.selection_factors_of_aa_str(parent)
+        selection_matrix = torch.zeros((len(selection_factors), 20), dtype=torch.float)
+        # Every "off-diagonal" entry of the selection matrix is set to the selection
+        # factor, where "diagonal" means keeping the same amino acid.
+        selection_matrix[:, :] = selection_factors[:, None]
+        # Set "diagonal" elements to one.
+        parent_idxs = sequences.aa_idx_array_of_str(parent)
+        selection_matrix[torch.arange(len(parent_idxs)), parent_idxs] = 1.0
+
+        return selection_matrix
+
     def _find_optimal_branch_length(
         self,
         parent,
@@ -350,8 +361,9 @@
     ):
         if parent == child:
             return 0.0
-        log_pcp_probability = self.wrapped_model._build_log_pcp_probability(
-            parent, child, rates, subs_probs
+        sel_matrix = self.build_selection_matrix_from_parent(parent)
+        log_pcp_probability = molevol.mutsel_log_pcp_probability_of(
+            sel_matrix, parent, child, rates, subs_probs
         )
         if type(starting_branch_length) == torch.Tensor:
             starting_branch_length = starting_branch_length.detach().item()
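A note on the new build_selection_matrix_from_parent method above: each row of the matrix broadcasts that site's selection factor across all 20 amino acids, and only the column of the parent amino acid (the "diagonal") is reset to one. The following is a minimal standalone sketch of that indexing pattern, with made-up selection factors and hand-picked amino-acid indices standing in for sequences.aa_idx_array_of_str; it is illustrative only and not part of the commit.

import torch

# Made-up per-site selection factors for three codon sites.
selection_factors = torch.tensor([0.5, 2.0, 0.25])
# Hypothetical indices of the parent amino acid at each site (0-19).
parent_idxs = torch.tensor([4, 0, 19])

# Broadcast each site's selection factor across all 20 amino acids ...
selection_matrix = torch.zeros((len(selection_factors), 20), dtype=torch.float)
selection_matrix[:, :] = selection_factors[:, None]

# ... then set the "diagonal" (keeping the parent amino acid) to one.
selection_matrix[torch.arange(len(parent_idxs)), parent_idxs] = 1.0

assert selection_matrix[0, 4] == 1.0  # parent amino acid: no selection factor applied
assert selection_matrix[0, 3] == 0.5  # any substitution at site 0 gets factor 0.5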
53 changes: 52 additions & 1 deletion netam/molevol.py
@@ -14,7 +14,9 @@
 import torch
 from torch import Tensor, optim
 
-from epam.sequences import CODON_AA_INDICATOR_MATRIX
+from netam.sequences import CODON_AA_INDICATOR_MATRIX
+
+import netam.sequences as sequences
 
 
 def normalize_sub_probs(parent_idxs: Tensor, sub_probs: Tensor) -> Tensor:
@@ -354,6 +356,55 @@ def neutral_aa_mut_probs(
     return 1.0 - p_staying_same
 
 
+def mutsel_log_pcp_probability_of(sel_matrix, parent, child, rates, sub_probs):
+    """
+    Constructs the log_pcp_probability function specific to given rates and sub_probs.
+    This function takes log_branch_length as input and returns the log
+    probability of the child sequence. It uses the log of the branch length to
+    ensure non-negativity.
+    """
+
+    assert len(parent) % 3 == 0
+    assert sel_matrix.shape == (len(parent) // 3, 20)
+
+    parent_idxs = sequences.nt_idx_tensor_of_str(parent)
+    child_idxs = sequences.nt_idx_tensor_of_str(child)
+
+    def log_pcp_probability(log_branch_length: torch.Tensor):
+        branch_length = torch.exp(log_branch_length)
+        mut_probs = 1.0 - torch.exp(-branch_length * rates)
+
+        codon_mutsel, sums_too_big = build_codon_mutsel(
+            parent_idxs.reshape(-1, 3),
+            mut_probs.reshape(-1, 3),
+            sub_probs.reshape(-1, 3, 4),
+            sel_matrix,
+        )
+
+        # This is a diagnostic generating data for netam issue #7.
+        # if sums_too_big is not None:
+        #     self.csv_file.write(f"{parent},{child},{branch_length},{sums_too_big}\n")
+
+        reshaped_child_idxs = child_idxs.reshape(-1, 3)
+        child_prob_vector = codon_mutsel[
+            torch.arange(len(reshaped_child_idxs)),
+            reshaped_child_idxs[:, 0],
+            reshaped_child_idxs[:, 1],
+            reshaped_child_idxs[:, 2],
+        ]
+
+        child_prob_vector = torch.clamp(child_prob_vector, min=1e-10)
+
+        result = torch.sum(torch.log(child_prob_vector))
+
+        assert torch.isfinite(result)
+
+        return result
+
+    return log_pcp_probability
+
+
 def optimize_branch_length(
     log_prob_fn,
     starting_branch_length,
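Inside log_pcp_probability above, the child sequence's probability is extracted from codon_mutsel with a single advanced-indexing operation: assuming codon_mutsel has shape (n_codons, 4, 4, 4) (inferred from the indexing, not stated in the diff), indexing with torch.arange plus the three nucleotide positions of each child codon selects one probability per codon. A toy illustration of just that gather, detached from the commit:

import torch

# Toy stand-in for codon_mutsel: 2 codons, each with a distribution over
# the 4 x 4 x 4 possible child codons (assumed shape, for illustration).
codon_probs = torch.rand(2, 4, 4, 4)
codon_probs = codon_probs / codon_probs.sum(dim=(1, 2, 3), keepdim=True)

# Child codons written as nucleotide indices, e.g. one codon -> (0, 1, 3)
# under some nucleotide ordering (hypothetical values).
child_codon_idxs = torch.tensor([[0, 1, 3], [2, 2, 0]])

# One probability per codon: codon_probs[i, c0, c1, c2] for each row i.
child_prob_vector = codon_probs[
    torch.arange(child_codon_idxs.shape[0]),
    child_codon_idxs[:, 0],
    child_codon_idxs[:, 1],
    child_codon_idxs[:, 2],
]
print(child_prob_vector.shape)  # torch.Size([2])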

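The reason mutsel_log_pcp_probability_of returns a closure is that everything except the branch length is fixed for a given parent-child pair, so downstream branch-length optimization becomes a one-dimensional problem in log branch length, which keeps the branch length positive by construction. The sketch below illustrates that pattern generically with a made-up log-probability and plain torch.optim; it is not netam's optimize_branch_length, whose signature is only partially visible in this diff.

import torch

# A made-up log-probability with the same call shape as log_pcp_probability:
# it takes a log branch length tensor and returns a scalar. Its maximum is at
# branch_length = 0.05.
def toy_log_prob(log_branch_length: torch.Tensor) -> torch.Tensor:
    branch_length = torch.exp(log_branch_length)
    return -((branch_length - 0.05) ** 2)

log_branch_length = torch.tensor(0.0, requires_grad=True)  # start at branch length 1.0
optimizer = torch.optim.Adam([log_branch_length], lr=0.1)

for _ in range(500):
    optimizer.zero_grad()
    loss = -toy_log_prob(log_branch_length)  # maximize log-prob = minimize its negative
    loss.backward()
    optimizer.step()

# The exponentiated optimum should end up near 0.05, and is positive throughout.
print(torch.exp(log_branch_length).item())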