From d434a81164c6dc98c663f97d1a511397ebb24d62 Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Wed, 20 Nov 2024 11:22:03 -0800 Subject: [PATCH] better mask handling --- netam/common.py | 46 +++++++++++++++++++++++++++++++++++++++++ netam/dxsm.py | 39 +++++++++++++--------------------- netam/molevol.py | 15 ++++++-------- netam/sequences.py | 11 ++++++++++ tests/test_ambiguous.py | 2 -- 5 files changed, 77 insertions(+), 36 deletions(-) diff --git a/netam/common.py b/netam/common.py index ca6fc501..f7198dea 100644 --- a/netam/common.py +++ b/netam/common.py @@ -10,6 +10,7 @@ from torch import nn, Tensor import multiprocessing as mp +from netam.sequences import iter_codons, apply_aa_mask_to_nt_sequence BIG = 1e9 SMALL_PROB = 1e-6 BASES = ["A", "C", "G", "T"] @@ -83,6 +84,51 @@ def generic_mask_tensor_of(ambig_symb, seq_str, length=None): return mask +def codon_mask_tensor_of(nt_parent, nt_child, aa_length=None): + """Return a mask tensor indicating codons which contain at least one N. + + Codons beyond the length of the sequence are masked. + """ + if aa_length is None: + aa_length = len(nt_parent) // 3 + mask = [ + ("N" not in parent_codon and "N" not in child_codon) + for parent_codon, child_codon in zip(iter_codons(nt_parent), iter_codons(nt_child)) + ] + if len(mask) < aa_length: + mask += [False] * (aa_length - len(mask)) + else: + mask = mask[:aa_length] + assert len(mask) == aa_length + return torch.tensor(mask, dtype=torch.bool) + +def check_pcp_valid(parent, child, aa_mask=None): + """Check that the parent-child pairs are valid. + + * The parent and child sequences must be the same length + * There must be unmasked codons + * The parent and child sequences must not match after masking codons containing + ambiguities. + + Args: + parent: The parent sequence. + child: The child sequence. + aa_mask: The mask tensor for the amino acid sequence. If None, it will be + computed from the parent and child sequences. + """ + if aa_mask is None: + aa_mask = codon_mask_tensor_of(parent, child) + if len(parent) != len(child): + raise ValueError("Parent and child sequences are not the same length.") + if not aa_mask.any(): + raise ValueError("Parent-child pair is masked in all codons.") + if ( + apply_aa_mask_to_nt_sequence(parent, aa_mask) + == apply_aa_mask_to_nt_sequence(child, aa_mask) + ): + raise ValueError("Parent-child pair matches after masking codons containing ambiguities") + + def nt_mask_tensor_of(*args, **kwargs): return generic_mask_tensor_of("N", *args, **kwargs) diff --git a/netam/dxsm.py b/netam/dxsm.py index 60642df3..9d899fb0 100644 --- a/netam/dxsm.py +++ b/netam/dxsm.py @@ -20,6 +20,8 @@ aa_idx_tensor_of_str_ambig, aa_mask_tensor_of, stack_heterogeneous, + codon_mask_tensor_of, + check_pcp_valid, ) import netam.framework as framework import netam.molevol as molevol @@ -27,22 +29,9 @@ from netam.sequences import ( aa_subs_indicator_tensor_of, translate_sequences, + apply_aa_mask_to_nt_sequence, ) -def cautious_mask_tensor_of(nt_str, aa_length): - """Return a mask tensor indicating codons which contain at least one N. - - Codons beyond the length of the sequence are masked. - """ - if aa_length is None: - aa_length = len(nt_str) // 3 - mask = ["N" not in nt_str[i * 3:(i + 1) * 3] for i in range(len(nt_str) // 3)] - if len(mask) < aa_length: - mask += [False] * (aa_length - len(mask)) - else: - mask = mask[:aa_length] - assert len(mask) == aa_length - return torch.tensor(mask, dtype=torch.bool) class DXSMDataset(Dataset, ABC): prefix = "dxsm" @@ -69,12 +58,6 @@ def __init__( assert len(self.nt_parents) == len(self.nt_children) pcp_count = len(self.nt_parents) - for parent, child in zip(self.nt_parents, self.nt_children): - if parent == child: - raise ValueError( - f"Found an identical parent and child sequence: {parent}" - ) - aa_parents = translate_sequences(self.nt_parents) aa_children = translate_sequences(self.nt_children) self.max_aa_seq_len = max(len(seq) for seq in aa_parents) @@ -90,11 +73,10 @@ def __init__( self.masks = torch.ones((pcp_count, self.max_aa_seq_len), dtype=torch.bool) for i, (aa_parent, aa_child) in enumerate(zip(aa_parents, aa_children)): - # self.masks[i, :] = aa_mask_tensor_of(aa_parent, self.max_aa_seq_len) - # TODO Figure out how we're really going to handle masking - self.masks[i, :] = cautious_mask_tensor_of(nt_parents[i], self.max_aa_seq_len) - + self.masks[i, :] = codon_mask_tensor_of(nt_parents[i], nt_children[i], self.max_aa_seq_len) aa_seq_len = len(aa_parent) + check_pcp_valid(nt_parents[i], nt_children[i], aa_mask=self.masks[i][: aa_seq_len]) + self.aa_parents_idxss[i, :aa_seq_len] = aa_idx_tensor_of_str_ambig( aa_parent ) @@ -281,8 +263,14 @@ def _find_optimal_branch_length( # TODO this doesn't use any mask, couldn't we use already-computed # aa_parent? sel_matrix = self.build_selection_matrix_from_parent(parent) + trimmed_aa_mask = aa_mask[: len(sel_matrix)] log_pcp_probability = molevol.mutsel_log_pcp_probability_of( - sel_matrix, parent, child, nt_rates, nt_csps, aa_mask[:len(sel_matrix)], multihit_model + sel_matrix[trimmed_aa_mask], + sequences.apply_aa_mask_to_nt_sequence(parent, trimmed_aa_mask[: len(parent)]), + sequences.apply_aa_mask_to_nt_sequence(child, trimmed_aa_mask[: len(child)]), + nt_rates[trimmed_aa_mask.repeat_interleave(3)], + nt_csps[trimmed_aa_mask.repeat_interleave(3)], + multihit_model, ) if isinstance(starting_branch_length, torch.Tensor): starting_branch_length = starting_branch_length.detach().item() @@ -337,6 +325,7 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs): def find_optimal_branch_lengths(self, dataset, **optimization_kwargs): worker_count = min(mp.cpu_count() // 2, 10) # The following can be used when one wants a better traceback. + # TODO disable later... burrito = self.__class__(None, dataset, copy.deepcopy(self.model)) return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs) diff --git a/netam/molevol.py b/netam/molevol.py index 0f0e0d32..2aef1c10 100644 --- a/netam/molevol.py +++ b/netam/molevol.py @@ -434,7 +434,7 @@ def neutral_aa_mut_probs( def mutsel_log_pcp_probability_of( - sel_matrix, parent, child, nt_rates, nt_csps, aa_mask, multihit_model=None + sel_matrix, parent, child, nt_rates, nt_csps, multihit_model=None ): """Constructs the log_pcp_probability function specific to given nt_rates and nt_csps. @@ -446,9 +446,6 @@ def mutsel_log_pcp_probability_of( assert len(parent) % 3 == 0 assert sel_matrix.shape == (len(parent) // 3, 20) - # This is masked out later - parent = parent.replace("N", "A") - child = child.replace("N", "A") parent_idxs = sequences.nt_idx_tensor_of_str(parent) child_idxs = sequences.nt_idx_tensor_of_str(child) @@ -457,10 +454,10 @@ def log_pcp_probability(log_branch_length: torch.Tensor): nt_mut_probs = 1.0 - torch.exp(-branch_length * nt_rates) codon_mutsel, sums_too_big = build_codon_mutsel( - parent_idxs.reshape(-1, 3)[aa_mask], - nt_mut_probs.reshape(-1, 3)[aa_mask], - nt_csps.reshape(-1, 3, 4)[aa_mask], - sel_matrix[aa_mask], + parent_idxs.reshape(-1, 3), + nt_mut_probs.reshape(-1, 3), + nt_csps.reshape(-1, 3, 4), + sel_matrix, multihit_model=multihit_model, ) @@ -468,7 +465,7 @@ def log_pcp_probability(log_branch_length: torch.Tensor): # if sums_too_big is not None: # self.csv_file.write(f"{parent},{child},{branch_length},{sums_too_big}\n") - reshaped_child_idxs = child_idxs.reshape(-1, 3)[aa_mask] + reshaped_child_idxs = child_idxs.reshape(-1, 3) child_prob_vector = codon_mutsel[ torch.arange(len(reshaped_child_idxs)), reshaped_child_idxs[:, 0], diff --git a/netam/sequences.py b/netam/sequences.py index 9eac0ad6..d7131ac5 100644 --- a/netam/sequences.py +++ b/netam/sequences.py @@ -183,3 +183,14 @@ def assert_full_sequences(parent, child): if "N" in parent or "N" in child: raise ValueError("Found ambiguous bases in the parent or child sequence.") + +def apply_aa_mask_to_nt_sequence(nt_seq, aa_mask): + """Apply an amino acid mask to a nucleotide sequence.""" + return "".join( + nt for nt, mask_val in zip(nt_seq, aa_mask.repeat_interleave(3)) if mask_val + ) + +def iter_codons(nt_seq): + """Iterate over the codons in a nucleotide sequence.""" + for i in range(0, len(nt_seq), 3): + yield nt_seq[i : i + 3] diff --git a/tests/test_ambiguous.py b/tests/test_ambiguous.py index 845a6c26..36d04d2a 100644 --- a/tests/test_ambiguous.py +++ b/tests/test_ambiguous.py @@ -147,7 +147,6 @@ def test_dnsm_burrito(ambig_pcp_df, dnsm_model): min_learning_rate=0.0001, ) burrito.joint_train(epochs=1, cycle_count=2, training_method="full") - return burrito @pytest.fixture @@ -179,4 +178,3 @@ def test_dasm_burrito(ambig_pcp_df, dasm_model): burrito.joint_train( epochs=1, cycle_count=2, training_method="full", optimize_bl_first_cycle=False ) - return burrito