diff --git a/netam/dnsm.py b/netam/dnsm.py index cc8a93ca..4969631b 100644 --- a/netam/dnsm.py +++ b/netam/dnsm.py @@ -24,8 +24,6 @@ from tqdm import tqdm -from epam.models import WrappedBinaryMutSel - from netam.common import ( MAX_AMBIG_AA_IDX, aa_idx_tensor_of_str_ambig, @@ -39,6 +37,7 @@ import netam.sequences as sequences from netam.sequences import ( aa_subs_indicator_tensor_of, + translate_sequence, translate_sequences, ) @@ -302,7 +301,6 @@ def train_val_datasets_of_pcp_df(pcp_df, branch_length_multiplier=5.0): class DNSMBurrito(framework.Burrito): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.wrapped_model = WrappedBinaryMutSel(self.model, weights_directory=None) def load_branch_lengths(self, in_csv_prefix): if self.train_dataset is not None: @@ -339,6 +337,19 @@ def loss_of_batch(self, batch): predictions = self.predictions_of_batch(batch).masked_select(mask) return self.bce_loss(predictions, aa_subs_indicator) + def build_selection_matrix_from_parent(self, parent: str): + parent = translate_sequence(parent) + selection_factors = self.model.selection_factors_of_aa_str(parent) + selection_matrix = torch.zeros((len(selection_factors), 20), dtype=torch.float) + # Every "off-diagonal" entry of the selection matrix is set to the selection + # factor, where "diagonal" means keeping the same amino acid. + selection_matrix[:, :] = selection_factors[:, None] + # Set "diagonal" elements to one. + parent_idxs = sequences.aa_idx_array_of_str(parent) + selection_matrix[torch.arange(len(parent_idxs)), parent_idxs] = 1.0 + + return selection_matrix + def _find_optimal_branch_length( self, parent, @@ -350,8 +361,9 @@ def _find_optimal_branch_length( ): if parent == child: return 0.0 - log_pcp_probability = self.wrapped_model._build_log_pcp_probability( - parent, child, rates, subs_probs + sel_matrix = self.build_selection_matrix_from_parent(parent) + log_pcp_probability = molevol.mutsel_log_pcp_probability_of( + sel_matrix, parent, child, rates, subs_probs ) if type(starting_branch_length) == torch.Tensor: starting_branch_length = starting_branch_length.detach().item() diff --git a/netam/molevol.py b/netam/molevol.py index c958f62b..ad1b44d7 100644 --- a/netam/molevol.py +++ b/netam/molevol.py @@ -14,7 +14,9 @@ import torch from torch import Tensor, optim -from epam.sequences import CODON_AA_INDICATOR_MATRIX +from netam.sequences import CODON_AA_INDICATOR_MATRIX + +import netam.sequences as sequences def normalize_sub_probs(parent_idxs: Tensor, sub_probs: Tensor) -> Tensor: @@ -354,6 +356,55 @@ def neutral_aa_mut_probs( return 1.0 - p_staying_same +def mutsel_log_pcp_probability_of(sel_matrix, parent, child, rates, sub_probs): + """ + Constructs the log_pcp_probability function specific to given rates and sub_probs. + + This function takes log_branch_length as input and returns the log + probability of the child sequence. It uses log of branch length to + ensure non-negativity. + """ + + assert len(parent) % 3 == 0 + assert sel_matrix.shape == (len(parent) // 3, 20) + + parent_idxs = sequences.nt_idx_tensor_of_str(parent) + child_idxs = sequences.nt_idx_tensor_of_str(child) + + def log_pcp_probability(log_branch_length: torch.Tensor): + branch_length = torch.exp(log_branch_length) + mut_probs = 1.0 - torch.exp(-branch_length * rates) + + codon_mutsel, sums_too_big = build_codon_mutsel( + parent_idxs.reshape(-1, 3), + mut_probs.reshape(-1, 3), + sub_probs.reshape(-1, 3, 4), + sel_matrix, + ) + + # This is a diagnostic generating data for netam issue #7. + # if sums_too_big is not None: + # self.csv_file.write(f"{parent},{child},{branch_length},{sums_too_big}\n") + + reshaped_child_idxs = child_idxs.reshape(-1, 3) + child_prob_vector = codon_mutsel[ + torch.arange(len(reshaped_child_idxs)), + reshaped_child_idxs[:, 0], + reshaped_child_idxs[:, 1], + reshaped_child_idxs[:, 2], + ] + + child_prob_vector = torch.clamp(child_prob_vector, min=1e-10) + + result = torch.sum(torch.log(child_prob_vector)) + + assert torch.isfinite(result) + + return result + + return log_pcp_probability + + def optimize_branch_length( log_prob_fn, starting_branch_length,