From b2601a412683f6bb1d3d14846882850926ed50fd Mon Sep 17 00:00:00 2001
From: Will Dumm
Date: Wed, 20 Nov 2024 15:25:08 -0800
Subject: [PATCH] cleanup and format

---
 netam/attention_map.py  |  6 +++---
 netam/common.py         | 23 ++++++++++++++---------
 netam/dasm.py           |  7 +++++--
 netam/dnsm.py           | 12 ++++++------
 netam/dxsm.py           | 30 ++++++++++++++----------------
 netam/models.py         |  2 --
 netam/sequences.py      |  4 +++-
 tests/test_ambiguous.py | 38 ++++++++++++++++++++++++--------------
 tests/test_common.py    | 14 +++++++++++++-
 tests/test_molevol.py   |  1 -
 tests/test_netam.py     |  4 ++--
 11 files changed, 84 insertions(+), 57 deletions(-)

diff --git a/netam/attention_map.py b/netam/attention_map.py
index 55426c25..7e05580b 100644
--- a/netam/attention_map.py
+++ b/netam/attention_map.py
@@ -1,4 +1,4 @@
-"""
+r"""
 We are going to get the attention weights using the [MultiHeadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html)
 module in PyTorch. These weights are
 $$
@@ -9,10 +9,10 @@
 
 In our terminology, an attention map is the attention map for a single head. An
 "attention maps" object is a collection of attention maps: a tensor where the
-first dimension is the number of heads. This assumes all layers have the same
+first dimension is the number of heads. This assumes all layers have the same
 number of heads. An "attention mapss" is a list of attention maps objects, one
 for each sequence in the batch. An "attention profile" is some 1-D summary of an
-attention map, such as the maximum attention score for each key position. 
+attention map, such as the maximum attention score for each key position.
 
 # Adapted from https://gist.github.com/airalcorn2/50ec06517ce96ecc143503e21fa6cb91
 """
diff --git a/netam/common.py b/netam/common.py
index f7198dea..992eb2e0 100644
--- a/netam/common.py
+++ b/netam/common.py
@@ -11,6 +11,7 @@
 import multiprocessing as mp
 
 from netam.sequences import iter_codons, apply_aa_mask_to_nt_sequence
+
 BIG = 1e9
 SMALL_PROB = 1e-6
 BASES = ["A", "C", "G", "T"]
@@ -84,16 +85,18 @@ def generic_mask_tensor_of(ambig_symb, seq_str, length=None):
     return mask
 
 
-def codon_mask_tensor_of(nt_parent, nt_child, aa_length=None):
+def codon_mask_tensor_of(nt_parent, *other_nt_seqs, aa_length=None):
     """Return a mask tensor indicating codons which contain at least one N.
 
-    Codons beyond the length of the sequence are masked.
+    Codons beyond the length of the sequence are masked. If other_nt_seqs are
+    provided, the mask is the logical "and" computed over all sequences.
     """
     if aa_length is None:
         aa_length = len(nt_parent) // 3
+    sequences = (nt_parent,) + other_nt_seqs
     mask = [
-        ("N" not in parent_codon and "N" not in child_codon)
-        for parent_codon, child_codon in zip(iter_codons(nt_parent), iter_codons(nt_child))
+        all("N" not in codon for codon in codons)
+        for codons in zip(*(iter_codons(sequence) for sequence in sequences))
     ]
     if len(mask) < aa_length:
         mask += [False] * (aa_length - len(mask))
@@ -102,7 +105,8 @@ def codon_mask_tensor_of(nt_parent, nt_child, aa_length=None):
     assert len(mask) == aa_length
     return torch.tensor(mask, dtype=torch.bool)
 
-def check_pcp_valid(parent, child, aa_mask=None):
+
+def assert_pcp_valid(parent, child, aa_mask=None):
     """Check that the parent-child pairs are valid.
 
     * The parent and child sequences must be the same length
@@ -122,11 +126,12 @@ def check_pcp_valid(parent, child, aa_mask=None):
         raise ValueError("Parent and child sequences are not the same length.")
     if not aa_mask.any():
         raise ValueError("Parent-child pair is masked in all codons.")
-    if (
-        apply_aa_mask_to_nt_sequence(parent, aa_mask)
-        == apply_aa_mask_to_nt_sequence(child, aa_mask)
+    if apply_aa_mask_to_nt_sequence(parent, aa_mask) == apply_aa_mask_to_nt_sequence(
+        child, aa_mask
     ):
-        raise ValueError("Parent-child pair matches after masking codons containing ambiguities")
+        raise ValueError(
+            "Parent-child pair matches after masking codons containing ambiguities"
+        )
 
 
 def nt_mask_tensor_of(*args, **kwargs):
diff --git a/netam/dasm.py b/netam/dasm.py
index 368dcab3..7691cc82 100644
--- a/netam/dasm.py
+++ b/netam/dasm.py
@@ -37,7 +37,7 @@ def update_neutral_probs(self):
             mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len])
             nt_csps = nt_csps[:parent_len, :]
             nt_mask = mask.repeat_interleave(3)[: len(nt_parent)]
-            molevol.check_csps(parent_idxs[nt_mask], nt_csps[:len(nt_parent)][nt_mask])
+            molevol.check_csps(parent_idxs[nt_mask], nt_csps[: len(nt_parent)][nt_mask])
 
             # TODO don't we need to pass multihit model in here?
             neutral_aa_probs = molevol.neutral_aa_probs(
@@ -198,13 +198,16 @@ def loss_of_batch(self, batch):
         return torch.stack([subs_pos_loss, csp_loss])
 
     def build_selection_matrix_from_parent(self, parent: str):
+        """Build a selection matrix from a parent nucleotide sequence.
+
+        Values at ambiguous sites are meaningless.
+        """
         # This is simpler than the equivalent in dnsm.py because we get the selection
         # matrix directly. Note that selection_factors_of_aa_str does the exponentiation
         # so this indeed gives us the selection factors, not the log selection factors.
         parent = sequences.translate_sequence(parent)
         per_aa_selection_factors = self.model.selection_factors_of_aa_str(parent)
 
-        # TODO this nonsense output will need to get masked
         parent = parent.replace("X", "A")
         parent_idxs = sequences.aa_idx_array_of_str(parent)
         per_aa_selection_factors[torch.arange(len(parent_idxs)), parent_idxs] = 1.0
diff --git a/netam/dnsm.py b/netam/dnsm.py
index 00d2b253..d3a66284 100644
--- a/netam/dnsm.py
+++ b/netam/dnsm.py
@@ -50,11 +50,8 @@ def update_neutral_probs(self):
             parent_len = len(nt_parent)
             # Cannot assume that nt_csps and mask are same length, because when
             # datasets are split, masks are recomputed.
-            # TODO Figure out how we're really going to handle masking, because
-            # old method allowed some nt N's to be unmasked.
-            nt_mask = mask.repeat_interleave(3)[: len(nt_parent)]
-            molevol.check_csps(parent_idxs[nt_mask], nt_csps[:len(nt_parent)][nt_mask])
-            # molevol.check_csps(parent_idxs[nt_mask], nt_csps[:len(parent_idxs)][nt_mask])
+            nt_mask = mask.repeat_interleave(3)[:parent_len]
+            molevol.check_csps(parent_idxs[nt_mask], nt_csps[:parent_len][nt_mask])
 
             mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len])
             nt_csps = nt_csps[:parent_len, :]
@@ -161,13 +158,16 @@ def loss_of_batch(self, batch):
         return self.bce_loss(predictions, aa_subs_indicator)
 
     def build_selection_matrix_from_parent(self, parent: str):
+        """Build a selection matrix from a parent nucleotide sequence.
+
+        Values at ambiguous sites are meaningless.
+ """ parent = sequences.translate_sequence(parent) selection_factors = self.model.selection_factors_of_aa_str(parent) selection_matrix = torch.zeros((len(selection_factors), 20), dtype=torch.float) # Every "off-diagonal" entry of the selection matrix is set to the selection # factor, where "diagonal" means keeping the same amino acid. selection_matrix[:, :] = selection_factors[:, None] - # TODO this nonsense output will need to get masked parent = parent.replace("X", "A") # Set "diagonal" elements to one. parent_idxs = sequences.aa_idx_array_of_str(parent) diff --git a/netam/dxsm.py b/netam/dxsm.py index 40c927bd..55961d1b 100644 --- a/netam/dxsm.py +++ b/netam/dxsm.py @@ -18,18 +18,17 @@ from netam.common import ( MAX_AMBIG_AA_IDX, aa_idx_tensor_of_str_ambig, - aa_mask_tensor_of, stack_heterogeneous, codon_mask_tensor_of, - check_pcp_valid, + assert_pcp_valid, ) import netam.framework as framework import netam.molevol as molevol -import netam.sequences as sequences from netam.sequences import ( aa_subs_indicator_tensor_of, translate_sequences, apply_aa_mask_to_nt_sequence, + nt_mutation_frequency, ) @@ -73,9 +72,13 @@ def __init__( self.masks = torch.ones((pcp_count, self.max_aa_seq_len), dtype=torch.bool) for i, (aa_parent, aa_child) in enumerate(zip(aa_parents, aa_children)): - self.masks[i, :] = codon_mask_tensor_of(nt_parents[i], nt_children[i], self.max_aa_seq_len) + self.masks[i, :] = codon_mask_tensor_of( + nt_parents[i], nt_children[i], self.max_aa_seq_len + ) aa_seq_len = len(aa_parent) - check_pcp_valid(nt_parents[i], nt_children[i], aa_mask=self.masks[i][: aa_seq_len]) + assert_pcp_valid( + nt_parents[i], nt_children[i], aa_mask=self.masks[i][:aa_seq_len] + ) self.aa_parents_idxss[i, :aa_seq_len] = aa_idx_tensor_of_str_ambig( aa_parent @@ -111,8 +114,7 @@ def of_seriess( """ initial_branch_lengths = np.array( [ - sequences.nt_mutation_frequency(parent, child) - * branch_length_multiplier + nt_mutation_frequency(parent, child) * branch_length_multiplier for parent, child in zip(nt_parents, nt_children) ] ) @@ -252,14 +254,12 @@ def _find_optimal_branch_length( multihit_model, **optimization_kwargs, ): - # TODO this doesn't use any mask, couldn't we use already-computed - # aa_parent and its mask? sel_matrix = self.build_selection_matrix_from_parent(parent) trimmed_aa_mask = aa_mask[: len(sel_matrix)] log_pcp_probability = molevol.mutsel_log_pcp_probability_of( sel_matrix[trimmed_aa_mask], - sequences.apply_aa_mask_to_nt_sequence(parent, trimmed_aa_mask[: len(parent)]), - sequences.apply_aa_mask_to_nt_sequence(child, trimmed_aa_mask[: len(child)]), + apply_aa_mask_to_nt_sequence(parent, trimmed_aa_mask), + apply_aa_mask_to_nt_sequence(child, trimmed_aa_mask), nt_rates[trimmed_aa_mask.repeat_interleave(3)], nt_csps[trimmed_aa_mask.repeat_interleave(3)], multihit_model, @@ -316,11 +316,9 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs): def find_optimal_branch_lengths(self, dataset, **optimization_kwargs): worker_count = min(mp.cpu_count() // 2, 10) - # The following can be used when one wants a better traceback. - # TODO disable later... - burrito = self.__class__(None, dataset, copy.deepcopy(self.model)) - return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs) - + # # The following can be used when one wants a better traceback. 
+ # burrito = self.__class__(None, dataset, copy.deepcopy(self.model)) + # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs) our_optimize_branch_length = partial( worker_optimize_branch_length, self.__class__, diff --git a/netam/models.py b/netam/models.py index c135b404..0f5b2854 100644 --- a/netam/models.py +++ b/netam/models.py @@ -552,8 +552,6 @@ def selection_factors_of_aa_str(self, aa_str: str) -> Tensor: aa_idxs = aa_idx_tensor_of_str_ambig(aa_str) aa_idxs = aa_idxs.to(model_device) - # TODO: Shouldn't we be using the new codon mask here, and allowing - # a pre-computed mask to be passed in? mask = aa_mask_tensor_of(aa_str) mask = mask.to(model_device) diff --git a/netam/sequences.py b/netam/sequences.py index d7131ac5..6a3c6916 100644 --- a/netam/sequences.py +++ b/netam/sequences.py @@ -184,13 +184,15 @@ def assert_full_sequences(parent, child): if "N" in parent or "N" in child: raise ValueError("Found ambiguous bases in the parent or child sequence.") + def apply_aa_mask_to_nt_sequence(nt_seq, aa_mask): """Apply an amino acid mask to a nucleotide sequence.""" return "".join( nt for nt, mask_val in zip(nt_seq, aa_mask.repeat_interleave(3)) if mask_val ) + def iter_codons(nt_seq): """Iterate over the codons in a nucleotide sequence.""" - for i in range(0, len(nt_seq), 3): + for i in range(0, (len(nt_seq) // 3) * 3, 3): yield nt_seq[i : i + 3] diff --git a/tests/test_ambiguous.py b/tests/test_ambiguous.py index 36d04d2a..86c55e23 100644 --- a/tests/test_ambiguous.py +++ b/tests/test_ambiguous.py @@ -1,19 +1,12 @@ import pytest from netam.common import force_spawn -from netam.framework import ( - crepe_exists, - load_crepe, -) from netam.models import TransformerBinarySelectionModelWiggleAct from netam.dasm import ( DASMBurrito, DASMDataset, ) -from netam.framework import SHMoofDataset, SHMBurrito, RSSHMBurrito -from netam.models import SHMoofModel, RSSHMoofModel, IndepRSCNNModel from netam.dnsm import DNSMBurrito, DNSMDataset -import pytest from netam.framework import ( load_pcp_df, add_shm_model_outputs_to_pcp_df, @@ -28,11 +21,14 @@ def randomize_with_ns(parent_seq, child_seq, avoid_masked_equality=True): old_child = child_seq seq_length = len(parent_seq) try: - first_mut = next((idx, p, c) for idx, (p, c) in enumerate(zip(parent_seq, child_seq)) if p != c) + first_mut = next( + (idx, p, c) + for idx, (p, c) in enumerate(zip(parent_seq, child_seq)) + if p != c + ) except: return parent_seq, child_seq - # Decide which type of modification to apply modification_type = random.choice(["same_site", "different_site", "chunk", "none"]) @@ -88,16 +84,29 @@ def randomize_with_ns(parent_seq, child_seq, avoid_masked_equality=True): child_seq = child_seq[:idx] + c + child_seq[idx + 1 :] if avoid_masked_equality: codon_pairs = [ - (parent_seq[i*3: (i+1)*3], child_seq[i*3: (i+1)*3]) + (parent_seq[i * 3 : (i + 1) * 3], child_seq[i * 3 : (i + 1) * 3]) for i in range(seq_length // 3) ] - if all(p == c for p, c in filter(lambda pair: "N" not in pair[0] and "N" not in pair[1], codon_pairs)): + if all( + p == c + for p, c in filter( + lambda pair: "N" not in pair[0] and "N" not in pair[1], codon_pairs + ) + ): # put original codon containing a mutation back in. 
             idx, p, c = first_mut
             codon_start = (idx // 3) * 3
             codon_end = codon_start + 3
-            parent_seq = parent_seq[:codon_start] + old_parent[codon_start:codon_end] + parent_seq[codon_end:]
-            child_seq = child_seq[:codon_start] + old_child[codon_start:codon_end] + child_seq[codon_end:]
+            parent_seq = (
+                parent_seq[:codon_start]
+                + old_parent[codon_start:codon_end]
+                + parent_seq[codon_end:]
+            )
+            child_seq = (
+                child_seq[:codon_start]
+                + old_child[codon_start:codon_end]
+                + child_seq[codon_end:]
+            )
 
     assert len(parent_seq) == len(child_seq)
     assert len(parent_seq) == seq_length
@@ -124,15 +133,16 @@ def ambig_pcp_df():
     )
     return df
 
+
 @pytest.fixture
 def dnsm_model():
     return TransformerBinarySelectionModelWiggleAct(
         nhead=2, d_model_per_head=4, dim_feedforward=256, layer_count=2
    )
 
+
 def test_dnsm_burrito(ambig_pcp_df, dnsm_model):
     """Fixture that returns the DNSM Burrito object."""
-    # TODO fix and make also work with randomize_with_ns avoid_masked_equality=False
     force_spawn()
     ambig_pcp_df["in_train"] = True
     ambig_pcp_df.loc[ambig_pcp_df.index[-15:], "in_train"] = False
diff --git a/tests/test_common.py b/tests/test_common.py
index f7c63165..e7f2f67b 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -1,6 +1,6 @@
 import torch
 
-from netam.common import nt_mask_tensor_of, aa_mask_tensor_of
+from netam.common import nt_mask_tensor_of, aa_mask_tensor_of, codon_mask_tensor_of
 
 
 def test_mask_tensor_of():
@@ -13,3 +13,15 @@ def test_mask_tensor_of():
     expected_output = torch.tensor([1, 1, 1, 1, 0], dtype=torch.bool)
     output = aa_mask_tensor_of(input_seq, length=5)
     assert torch.equal(output, expected_output)
+
+
+def test_codon_mask_tensor_of():
+    input_seq = "NAAAAAAAAAA"
+    # First test with a single sequence.
+    expected_output = torch.tensor([0, 1, 1, 0, 0], dtype=torch.bool)
+    output = codon_mask_tensor_of(input_seq, aa_length=5)
+    assert torch.equal(output, expected_output)
+    input_seq2 = "AAAANAAAAAA"
+    expected_output = torch.tensor([0, 0, 1, 0, 0], dtype=torch.bool)
+    output = codon_mask_tensor_of(input_seq, input_seq2, aa_length=5)
+    assert torch.equal(output, expected_output)
diff --git a/tests/test_molevol.py b/tests/test_molevol.py
index 26a2146b..0b313fa1 100644
--- a/tests/test_molevol.py
+++ b/tests/test_molevol.py
@@ -2,7 +2,6 @@
 import pytest
 
 import netam.molevol as molevol
-from netam import framework
 from netam import pretrained
 
 from netam.sequences import (
diff --git a/tests/test_netam.py b/tests/test_netam.py
index 78d49804..c2fe9088 100644
--- a/tests/test_netam.py
+++ b/tests/test_netam.py
@@ -63,7 +63,7 @@ def test_make_dataset(tiny_dataset):
     assert branch_length == 1 / 5
 
 
-def test_write_output(tiny_burrito):
+def test_write_tinyburrito_output(tiny_burrito):
     os.makedirs("_ignore", exist_ok=True)
     tiny_burrito.model.write_shmoof_output("_ignore")
 
@@ -105,7 +105,7 @@ def mini_rsburrito(mini_dataset, tiny_rsscnnmodel):
     return burrito
 
 
-def test_write_output(mini_rsburrito):
+def test_write_mini_rsburrito_output(mini_rsburrito):
     os.makedirs("_ignore", exist_ok=True)
     mini_rsburrito.save_crepe("_ignore/mini_rscrepe")
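
Reviewer note (not part of the patch): a minimal sketch exercising the reworked
codon_mask_tensor_of, iter_codons, and assert_pcp_valid behavior from this
change. It assumes the post-patch netam modules are importable; the example
sequences are invented for illustration.

    import torch

    from netam.common import assert_pcp_valid, codon_mask_tensor_of
    from netam.sequences import iter_codons

    # iter_codons now drops a trailing partial codon instead of yielding it.
    assert list(iter_codons("AAACCCGG")) == ["AAA", "CCC"]

    # A codon containing an N is masked out; with several sequences the mask
    # is the "and" over all of them (here, over parent and child).
    parent = "AANCCCGGG"
    child = "AANCCCGGA"
    mask = codon_mask_tensor_of(parent, child)
    assert torch.equal(mask, torch.tensor([False, True, True]))

    # assert_pcp_valid raises ValueError if the pair no longer differs once
    # masked codons are removed; this pair still differs in its last codon.
    assert_pcp_valid(parent, child, aa_mask=mask)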
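A second sketch, for the new build_selection_matrix_from_parent docstrings,
which warn that values at ambiguous sites are meaningless. One way a caller
might zero those sites out via the codon mask; masked_selection_matrix is a
hypothetical helper, not part of this patch:

    from netam.common import codon_mask_tensor_of

    def masked_selection_matrix(burrito, nt_parent):
        # burrito is any DNSM/DASM burrito exposing
        # build_selection_matrix_from_parent; nt_parent is a nucleotide string.
        sel = burrito.build_selection_matrix_from_parent(nt_parent)
        # Zero out rows for codons containing an ambiguity.
        mask = codon_mask_tensor_of(nt_parent, aa_length=sel.shape[0])
        sel[~mask] = 0.0
        return sel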