From 23bc49004632533b9a914dd76350f37bfe394e68 Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Wed, 20 Nov 2024 12:12:59 -0800 Subject: [PATCH] fix dasm, preliminarily --- netam/dasm.py | 7 ++++++- netam/dnsm.py | 6 +++--- netam/dxsm.py | 10 +--------- netam/models.py | 2 ++ 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/netam/dasm.py b/netam/dasm.py index 55dbcfff..368dcab3 100644 --- a/netam/dasm.py +++ b/netam/dasm.py @@ -36,8 +36,10 @@ def update_neutral_probs(self): mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len]) nt_csps = nt_csps[:parent_len, :] - molevol.check_csps(parent_idxs, nt_csps) + nt_mask = mask.repeat_interleave(3)[: len(nt_parent)] + molevol.check_csps(parent_idxs[nt_mask], nt_csps[:len(nt_parent)][nt_mask]) + # TODO don't we need to pass multihit model in here? neutral_aa_probs = molevol.neutral_aa_probs( parent_idxs.reshape(-1, 3), mut_probs.reshape(-1, 3), @@ -201,6 +203,9 @@ def build_selection_matrix_from_parent(self, parent: str): # so this indeed gives us the selection factors, not the log selection factors. parent = sequences.translate_sequence(parent) per_aa_selection_factors = self.model.selection_factors_of_aa_str(parent) + + # TODO this nonsense output will need to get masked + parent = parent.replace("X", "A") parent_idxs = sequences.aa_idx_array_of_str(parent) per_aa_selection_factors[torch.arange(len(parent_idxs)), parent_idxs] = 1.0 diff --git a/netam/dnsm.py b/netam/dnsm.py index 8dffe001..00d2b253 100644 --- a/netam/dnsm.py +++ b/netam/dnsm.py @@ -45,14 +45,14 @@ def update_neutral_probs(self): multihit_model = None # Note we are replacing all Ns with As, which means that we need to be careful # with masking out these positions later. We do this below. - # TODO Figure out how we're really going to handle masking, because - # old method allowed some nt N's to be unmasked. - nt_mask = mask.repeat_interleave(3)[: len(nt_parent)] # nt_mask = torch.tensor([it != "N" for it in nt_parent], dtype=torch.bool) parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A")) parent_len = len(nt_parent) # Cannot assume that nt_csps and mask are same length, because when # datasets are split, masks are recomputed. + # TODO Figure out how we're really going to handle masking, because + # old method allowed some nt N's to be unmasked. + nt_mask = mask.repeat_interleave(3)[: len(nt_parent)] molevol.check_csps(parent_idxs[nt_mask], nt_csps[:len(nt_parent)][nt_mask]) # molevol.check_csps(parent_idxs[nt_mask], nt_csps[:len(parent_idxs)][nt_mask]) diff --git a/netam/dxsm.py b/netam/dxsm.py index 9d899fb0..40c927bd 100644 --- a/netam/dxsm.py +++ b/netam/dxsm.py @@ -252,16 +252,8 @@ def _find_optimal_branch_length( multihit_model, **optimization_kwargs, ): - # TODO: This doesn't seem quite right, because we'll mask whole codons - # if they contain just one ambiguity, even when we know they also - # contain a substitution. - if all(p_c == c_c for idx, (p_c, c_c) in enumerate(zip(parent, child)) if aa_mask[idx // 3]): - print("Parent and child are the same when codons containing N are masked") - assert False - # if parent == child: - # return 0.0 # TODO this doesn't use any mask, couldn't we use already-computed - # aa_parent? + # aa_parent and its mask? sel_matrix = self.build_selection_matrix_from_parent(parent) trimmed_aa_mask = aa_mask[: len(sel_matrix)] log_pcp_probability = molevol.mutsel_log_pcp_probability_of( diff --git a/netam/models.py b/netam/models.py index 0f5b2854..c135b404 100644 --- a/netam/models.py +++ b/netam/models.py @@ -552,6 +552,8 @@ def selection_factors_of_aa_str(self, aa_str: str) -> Tensor: aa_idxs = aa_idx_tensor_of_str_ambig(aa_str) aa_idxs = aa_idxs.to(model_device) + # TODO: Shouldn't we be using the new codon mask here, and allowing + # a pre-computed mask to be passed in? mask = aa_mask_tensor_of(aa_str) mask = mask.to(model_device)