From 24a19bdae86f2e6a3c61f70feeb351f5c4dc0006 Mon Sep 17 00:00:00 2001
From: Will Dumm
Date: Thu, 21 Nov 2024 09:35:11 -0800
Subject: [PATCH] DXSM ambiguous sequences (#88)

This PR addresses https://github.com/matsengrp/dnsm-experiments-1/issues/36.

* Masks all codons containing "N" in parent or child sequences, and asserts
  that parent and child are not identical after the mask is applied
* Applies the mask in branch length optimization computations
* Replaces X's with A's in some calls to functions that can't take ambiguous
  AA sequences. In all cases, a note is made in the docstring that the
  function returns nonsense data on ambiguous sites, and those sites are
  later masked out of the function outputs.
---
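A note on the masking convention used throughout this patch: masks are indexed
by amino acid (codon) position, and a site is kept only when neither parent nor
child has an "N" anywhere in the corresponding codon. A minimal sketch of the
semantics (the helper below is illustrative only; the real implementation is
codon_mask_tensor_of in netam/common.py):

    import torch

    def toy_codon_mask(parent, child):
        # True where neither sequence has an N in the codon; trailing
        # nucleotides that don't form a complete codon are dropped.
        return torch.tensor(
            [
                "N" not in parent[i : i + 3] + child[i : i + 3]
                for i in range(0, len(parent) // 3 * 3, 3)
            ],
            dtype=torch.bool,
        )

    # Codon 0 is ambiguous in the parent; codon 1 is clean in both.
    assert toy_codon_mask("ANAAAA", "AAAAAA").tolist() == [False, True]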
 netam/attention_map.py  |   6 +-
 netam/common.py         |  51 ++++++++++
 netam/dasm.py           |  14 ++-
 netam/dnsm.py           |  10 ++-
 netam/dxsm.py           |  38 ++++----
 netam/sequences.py      |  13 +++
 tests/test_ambiguous.py | 190 ++++++++++++++++++++++++++++++++++++++++
 tests/test_common.py    |  14 ++-
 tests/test_molevol.py   |   1 -
 tests/test_netam.py     |   4 +-
 10 files changed, 317 insertions(+), 24 deletions(-)
 create mode 100644 tests/test_ambiguous.py

diff --git a/netam/attention_map.py b/netam/attention_map.py
index 55426c25..7e05580b 100644
--- a/netam/attention_map.py
+++ b/netam/attention_map.py
@@ -1,4 +1,4 @@
-"""
+r"""
 We are going to get the attention weights using the
 [MultiHeadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html)
 module in PyTorch. These weights are
 $$
@@ -9,10 +9,10 @@
 In our terminology, an attention map is the attention map for a single head. An
 "attention maps" object is a collection of attention maps: a tensor where the
-first dimension is the number of heads. This assumes all layers have the same 
+first dimension is the number of heads. This assumes all layers have the same
 number of heads. An "attention mapss" is a list of attention maps objects, one
 for each sequence in the batch. An "attention profile" is some 1-D summary of an
-attention map, such as the maximum attention score for each key position. 
+attention map, such as the maximum attention score for each key position.
 
 # Adapted from https://gist.github.com/airalcorn2/50ec06517ce96ecc143503e21fa6cb91
 """
diff --git a/netam/common.py b/netam/common.py
index ca6fc501..992eb2e0 100644
--- a/netam/common.py
+++ b/netam/common.py
@@ -10,6 +10,8 @@
 from torch import nn, Tensor
 import multiprocessing as mp
 
+from netam.sequences import iter_codons, apply_aa_mask_to_nt_sequence
+
 BIG = 1e9
 SMALL_PROB = 1e-6
 BASES = ["A", "C", "G", "T"]
@@ -83,6 +85,55 @@ def generic_mask_tensor_of(ambig_symb, seq_str, length=None):
     return mask
 
 
+def codon_mask_tensor_of(nt_parent, *other_nt_seqs, aa_length=None):
+    """Return a mask tensor indicating codons which contain at least one N.
+
+    Codons beyond the length of the sequence are masked. If other_nt_seqs are
+    provided, the mask is the elementwise AND of the masks for all sequences.
+    """
+    if aa_length is None:
+        aa_length = len(nt_parent) // 3
+    sequences = (nt_parent,) + other_nt_seqs
+    mask = [
+        all("N" not in codon for codon in codons)
+        for codons in zip(*(iter_codons(sequence) for sequence in sequences))
+    ]
+    if len(mask) < aa_length:
+        mask += [False] * (aa_length - len(mask))
+    else:
+        mask = mask[:aa_length]
+    assert len(mask) == aa_length
+    return torch.tensor(mask, dtype=torch.bool)
+
+
+def assert_pcp_valid(parent, child, aa_mask=None):
+    """Check that the parent-child pair is valid.
+
+    * The parent and child sequences must be the same length.
+    * There must be at least one unmasked codon.
+    * The parent and child sequences must not match after masking codons
+      containing ambiguities.
+
+    Args:
+        parent: The parent sequence.
+        child: The child sequence.
+        aa_mask: The mask tensor for the amino acid sequence. If None, it will
+            be computed from the parent and child sequences.
+    """
+    if aa_mask is None:
+        aa_mask = codon_mask_tensor_of(parent, child)
+    if len(parent) != len(child):
+        raise ValueError("Parent and child sequences are not the same length.")
+    if not aa_mask.any():
+        raise ValueError("Parent-child pair is masked in all codons.")
+    if apply_aa_mask_to_nt_sequence(parent, aa_mask) == apply_aa_mask_to_nt_sequence(
+        child, aa_mask
+    ):
+        raise ValueError(
+            "Parent-child pair matches after masking codons containing ambiguities"
+        )
+
+
 def nt_mask_tensor_of(*args, **kwargs):
     return generic_mask_tensor_of("N", *args, **kwargs)
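The model-specific changes below (dasm.py, dnsm.py, and dxsm.py) repeatedly
expand this per-codon mask to a per-nucleotide mask with repeat_interleave,
since each codon flag covers three nucleotides. A minimal sketch of the
expansion (values illustrative):

    import torch

    aa_mask = torch.tensor([True, False, True])
    nt_mask = aa_mask.repeat_interleave(3)  # one flag per nucleotide
    assert nt_mask.tolist() == [True] * 3 + [False] * 3 + [True] * 3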
diff --git a/netam/dasm.py b/netam/dasm.py
index 55dbcfff..8a2b16e8 100644
--- a/netam/dasm.py
+++ b/netam/dasm.py
@@ -29,6 +29,10 @@ def update_neutral_probs(self):
             mask = mask.to("cpu")
             nt_rates = nt_rates.to("cpu")
             nt_csps = nt_csps.to("cpu")
+            if self.multihit_model is not None:
+                multihit_model = copy.deepcopy(self.multihit_model).to("cpu")
+            else:
+                multihit_model = None
             # Note we are replacing all Ns with As, which means that we need to be
             # careful with masking out these positions later. We do this below.
             parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
@@ -36,12 +40,14 @@ def update_neutral_probs(self):
             mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len])
             nt_csps = nt_csps[:parent_len, :]
-            molevol.check_csps(parent_idxs, nt_csps)
+            nt_mask = mask.repeat_interleave(3)[: len(nt_parent)]
+            molevol.check_csps(parent_idxs[nt_mask], nt_csps[: len(nt_parent)][nt_mask])
 
             neutral_aa_probs = molevol.neutral_aa_probs(
                 parent_idxs.reshape(-1, 3),
                 mut_probs.reshape(-1, 3),
                 nt_csps.reshape(-1, 3, 4),
+                multihit_model=multihit_model,
             )
 
             if not torch.isfinite(neutral_aa_probs).all():
@@ -196,11 +202,17 @@ def loss_of_batch(self, batch):
         return torch.stack([subs_pos_loss, csp_loss])
 
     def build_selection_matrix_from_parent(self, parent: str):
+        """Build a selection matrix from a parent nucleotide sequence.
+
+        Values at ambiguous sites are meaningless.
+        """
         # This is simpler than the equivalent in dnsm.py because we get the selection
         # matrix directly. Note that selection_factors_of_aa_str does the
         # exponentiation so this indeed gives us the selection factors, not the log
         # selection factors.
         parent = sequences.translate_sequence(parent)
         per_aa_selection_factors = self.model.selection_factors_of_aa_str(parent)
+
+        parent = parent.replace("X", "A")
         parent_idxs = sequences.aa_idx_array_of_str(parent)
         per_aa_selection_factors[torch.arange(len(parent_idxs)), parent_idxs] = 1.0
diff --git a/netam/dnsm.py b/netam/dnsm.py
index cfa943b7..a63e7000 100644
--- a/netam/dnsm.py
+++ b/netam/dnsm.py
@@ -47,7 +47,10 @@ def update_neutral_probs(self):
             # with masking out these positions later. We do this below.
             parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
             parent_len = len(nt_parent)
-            molevol.check_csps(parent_idxs, nt_csps)
+            # Cannot assume that nt_csps and mask are the same length, because
+            # when datasets are split, masks are recomputed.
+            nt_mask = mask.repeat_interleave(3)[:parent_len]
+            molevol.check_csps(parent_idxs[nt_mask], nt_csps[:parent_len][nt_mask])
 
             mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len])
             nt_csps = nt_csps[:parent_len, :]
@@ -154,12 +157,17 @@ def loss_of_batch(self, batch):
         return self.bce_loss(predictions, aa_subs_indicator)
 
     def build_selection_matrix_from_parent(self, parent: str):
+        """Build a selection matrix from a parent nucleotide sequence.
+
+        Values at ambiguous sites are meaningless.
+        """
         parent = sequences.translate_sequence(parent)
         selection_factors = self.model.selection_factors_of_aa_str(parent)
         selection_matrix = torch.zeros((len(selection_factors), 20), dtype=torch.float)
         # Every "off-diagonal" entry of the selection matrix is set to the selection
         # factor, where "diagonal" means keeping the same amino acid.
         selection_matrix[:, :] = selection_factors[:, None]
+        parent = parent.replace("X", "A")
         # Set "diagonal" elements to one.
         parent_idxs = sequences.aa_idx_array_of_str(parent)
         selection_matrix[torch.arange(len(parent_idxs)), parent_idxs] = 1.0
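Both build_selection_matrix_from_parent implementations above use the same
trick: substitute "A" for "X" so that the amino acid index lookup succeeds,
accept that the computed values are nonsense at the substituted sites, and rely
on the codon mask to keep those values from ever being read. A sketch of the
idea (the alphabet and lookup here are illustrative, not netam's actual
encoding):

    AA_ALPHABET = "ACDEFGHIKLMNPQRSTVWY"  # 20 canonical amino acids; no "X"

    parent = "MXV".replace("X", "A")  # an "X" would make .index() raise ValueError
    parent_idxs = [AA_ALPHABET.index(aa) for aa in parent]
    # parent_idxs[1] is meaningless for the original "X", but the mask is
    # False at that site, so downstream code never reads it.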
diff --git a/netam/dxsm.py b/netam/dxsm.py
index 007e0240..c86733bf 100644
--- a/netam/dxsm.py
+++ b/netam/dxsm.py
@@ -18,15 +18,17 @@
 from netam.common import (
     MAX_AMBIG_AA_IDX,
     aa_idx_tensor_of_str_ambig,
-    aa_mask_tensor_of,
     stack_heterogeneous,
+    codon_mask_tensor_of,
+    assert_pcp_valid,
 )
 import netam.framework as framework
 import netam.molevol as molevol
-import netam.sequences as sequences
 from netam.sequences import (
     aa_subs_indicator_tensor_of,
     translate_sequences,
+    apply_aa_mask_to_nt_sequence,
+    nt_mutation_frequency,
 )
@@ -55,12 +57,6 @@ def __init__(
         assert len(self.nt_parents) == len(self.nt_children)
         pcp_count = len(self.nt_parents)
 
-        for parent, child in zip(self.nt_parents, self.nt_children):
-            if parent == child:
-                raise ValueError(
-                    f"Found an identical parent and child sequence: {parent}"
-                )
-
         aa_parents = translate_sequences(self.nt_parents)
         aa_children = translate_sequences(self.nt_children)
         self.max_aa_seq_len = max(len(seq) for seq in aa_parents)
@@ -76,8 +72,14 @@ def __init__(
         self.masks = torch.ones((pcp_count, self.max_aa_seq_len), dtype=torch.bool)
 
         for i, (aa_parent, aa_child) in enumerate(zip(aa_parents, aa_children)):
-            self.masks[i, :] = aa_mask_tensor_of(aa_parent, self.max_aa_seq_len)
+            self.masks[i, :] = codon_mask_tensor_of(
+                nt_parents[i], nt_children[i], aa_length=self.max_aa_seq_len
+            )
             aa_seq_len = len(aa_parent)
+            assert_pcp_valid(
+                nt_parents[i], nt_children[i], aa_mask=self.masks[i][:aa_seq_len]
+            )
+
             self.aa_parents_idxss[i, :aa_seq_len] = aa_idx_tensor_of_str_ambig(
                 aa_parent
             )
@@ -112,8 +114,7 @@ def of_seriess(
         """
         initial_branch_lengths = np.array(
             [
-                sequences.nt_mutation_frequency(parent, child)
-                * branch_length_multiplier
+                nt_mutation_frequency(parent, child) * branch_length_multiplier
                 for parent, child in zip(nt_parents, nt_children)
             ]
         )
@@ -248,15 +249,20 @@ def _find_optimal_branch_length(
         self,
         parent,
         child,
         nt_rates,
         nt_csps,
+        aa_mask,
         starting_branch_length,
         multihit_model,
         **optimization_kwargs,
     ):
-        if parent == child:
-            return 0.0
         sel_matrix = self.build_selection_matrix_from_parent(parent)
+        trimmed_aa_mask = aa_mask[: len(sel_matrix)]
         log_pcp_probability = molevol.mutsel_log_pcp_probability_of(
-            sel_matrix, parent, child, nt_rates, nt_csps, multihit_model
+            sel_matrix[trimmed_aa_mask],
+            apply_aa_mask_to_nt_sequence(parent, trimmed_aa_mask),
+            apply_aa_mask_to_nt_sequence(child, trimmed_aa_mask),
+            nt_rates[trimmed_aa_mask.repeat_interleave(3)],
+            nt_csps[trimmed_aa_mask.repeat_interleave(3)],
+            multihit_model,
         )
         if isinstance(starting_branch_length, torch.Tensor):
             starting_branch_length = starting_branch_length.detach().item()
@@ -268,12 +274,13 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
         optimal_lengths = []
         failed_count = 0
 
-        for parent, child, nt_rates, nt_csps, starting_length in tqdm(
+        for parent, child, nt_rates, nt_csps, aa_mask, starting_length in tqdm(
             zip(
                 dataset.nt_parents,
                 dataset.nt_children,
                 dataset.nt_ratess,
                 dataset.nt_cspss,
+                dataset.masks,
                 dataset.branch_lengths,
             ),
             total=len(dataset.nt_parents),
@@ -284,6 +291,7 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
                 child,
                 nt_rates[: len(parent)],
                 nt_csps[: len(parent), :],
+                aa_mask,
                 starting_length,
                 dataset.multihit_model,
                 **optimization_kwargs,
diff --git a/netam/sequences.py b/netam/sequences.py
index 9eac0ad6..6a3c6916 100644
--- a/netam/sequences.py
+++ b/netam/sequences.py
@@ -183,3 +183,16 @@ def assert_full_sequences(parent, child):
 
     if "N" in parent or "N" in child:
         raise ValueError("Found ambiguous bases in the parent or child sequence.")
+
+
+def apply_aa_mask_to_nt_sequence(nt_seq, aa_mask):
+    """Apply an amino acid mask to a nucleotide sequence."""
+    return "".join(
+        nt for nt, mask_val in zip(nt_seq, aa_mask.repeat_interleave(3)) if mask_val
+    )
+
+
+def iter_codons(nt_seq):
+    """Iterate over the codons in a nucleotide sequence."""
+    for i in range(0, (len(nt_seq) // 3) * 3, 3):
+        yield nt_seq[i : i + 3]
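A quick worked example of the two helpers added above; the behavior follows
directly from their definitions:

    import torch
    from netam.sequences import apply_aa_mask_to_nt_sequence, iter_codons

    assert list(iter_codons("AAACCCGG")) == ["AAA", "CCC"]  # partial codon dropped
    mask = torch.tensor([True, False])
    # Each mask entry keeps or drops a whole codon (three nucleotides).
    assert apply_aa_mask_to_nt_sequence("AAACCC", mask) == "AAA"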
diff --git a/tests/test_ambiguous.py b/tests/test_ambiguous.py
new file mode 100644
index 00000000..86c55e23
--- /dev/null
+++ b/tests/test_ambiguous.py
@@ -0,0 +1,190 @@
+import random
+
+import pytest
+
+from netam.common import force_spawn
+from netam.models import TransformerBinarySelectionModelWiggleAct
+from netam.dasm import (
+    DASMBurrito,
+    DASMDataset,
+)
+from netam.dnsm import DNSMBurrito, DNSMDataset
+from netam.framework import (
+    load_pcp_df,
+    add_shm_model_outputs_to_pcp_df,
+)
+from netam import pretrained
+
+
+# Function to randomly insert 'N' in sequences
+def randomize_with_ns(parent_seq, child_seq, avoid_masked_equality=True):
+    old_parent = parent_seq
+    old_child = child_seq
+    seq_length = len(parent_seq)
+    try:
+        first_mut = next(
+            (idx, p, c)
+            for idx, (p, c) in enumerate(zip(parent_seq, child_seq))
+            if p != c
+        )
+    except StopIteration:
+        return parent_seq, child_seq
+
+    # Decide which type of modification to apply
+    modification_type = random.choice(["same_site", "different_site", "chunk", "none"])
+
+    if modification_type == "same_site":
+        # Place 'N's at the same randomly chosen positions in both parent and child
+        num_ns = random.randint(
+            1, seq_length // 3
+        )  # Random number of Ns, at most a third of the sequence length
+        positions = random.sample(range(seq_length), num_ns)
+        parent_seq = "".join(
+            ["N" if i in positions else base for i, base in enumerate(parent_seq)]
+        )
+        child_seq = "".join(
+            ["N" if i in positions else base for i, base in enumerate(child_seq)]
+        )
+
+    elif modification_type == "different_site":
+        # Insert 'N's at random positions in parent and child, but not the same positions
+        num_ns_parent = random.randint(1, seq_length // 3)
+        num_ns_child = random.randint(1, seq_length // 3)
+        positions_parent = random.sample(range(seq_length), num_ns_parent)
+        positions_child = random.sample(range(seq_length), num_ns_child)
+
+        parent_seq = "".join(
+            [
+                "N" if i in positions_parent else base
+                for i, base in enumerate(parent_seq)
+            ]
+        )
+        child_seq = "".join(
+            ["N" if i in positions_child else base for i, base in enumerate(child_seq)]
+        )
+
+    elif modification_type == "chunk":
+        # Replace a chunk of bases with 'N's in both parent and child
+        chunk_size = random.randint(1, seq_length // 3)
+        start_pos = random.randint(0, seq_length - chunk_size)
+        parent_seq = (
+            parent_seq[:start_pos]
+            + "N" * chunk_size
+            + parent_seq[start_pos + chunk_size :]
+        )
+        child_seq = (
+            child_seq[:start_pos]
+            + "N" * chunk_size
+            + child_seq[start_pos + chunk_size :]
+        )
+
+    if parent_seq == child_seq:
+        # If the sequences are now identical, put one mutated site back in:
+        idx, p, c = first_mut
+        parent_seq = parent_seq[:idx] + p + parent_seq[idx + 1 :]
+        child_seq = child_seq[:idx] + c + child_seq[idx + 1 :]
+    if avoid_masked_equality:
+        codon_pairs = [
+            (parent_seq[i * 3 : (i + 1) * 3], child_seq[i * 3 : (i + 1) * 3])
+            for i in range(seq_length // 3)
+        ]
+        if all(
+            p == c
+            for p, c in filter(
+                lambda pair: "N" not in pair[0] and "N" not in pair[1], codon_pairs
+            )
+        ):
+            # Put the original codon containing a mutation back in.
+            idx, p, c = first_mut
+            codon_start = (idx // 3) * 3
+            codon_end = codon_start + 3
+            parent_seq = (
+                parent_seq[:codon_start]
+                + old_parent[codon_start:codon_end]
+                + parent_seq[codon_end:]
+            )
+            child_seq = (
+                child_seq[:codon_start]
+                + old_child[codon_start:codon_end]
+                + child_seq[codon_end:]
+            )
+
+    assert len(parent_seq) == len(child_seq)
+    assert len(parent_seq) == seq_length
+
+    return parent_seq, child_seq
+
+
+@pytest.fixture
+def ambig_pcp_df():
+    random.seed(1)
+    df = load_pcp_df(
+        "data/wyatt-10x-1p5m_pcp_2023-11-30_NI.first100.csv.gz",
+    )
+    # Apply the random N-adding function to each row
+    df[["parent", "child"]] = df.apply(
+        lambda row: randomize_with_ns(row["parent"], row["child"]),
+        axis=1,
+        result_type="expand",
+    )
+
+    df = add_shm_model_outputs_to_pcp_df(
+        df,
+        pretrained.load("ThriftyHumV0.2-45"),
+    )
+    return df
+
+
+@pytest.fixture
+def dnsm_model():
+    return TransformerBinarySelectionModelWiggleAct(
+        nhead=2, d_model_per_head=4, dim_feedforward=256, layer_count=2
+    )
+
+
+def test_dnsm_burrito(ambig_pcp_df, dnsm_model):
+    """Train a DNSM burrito on a dataset containing ambiguous sequences."""
+    force_spawn()
+    ambig_pcp_df["in_train"] = True
+    ambig_pcp_df.loc[ambig_pcp_df.index[-15:], "in_train"] = False
+    train_dataset, val_dataset = DNSMDataset.train_val_datasets_of_pcp_df(ambig_pcp_df)
+
+    burrito = DNSMBurrito(
+        train_dataset,
+        val_dataset,
+        dnsm_model,
+        batch_size=32,
+        learning_rate=0.001,
+        min_learning_rate=0.0001,
+    )
+    burrito.joint_train(epochs=1, cycle_count=2, training_method="full")
+
+
+@pytest.fixture
+def dasm_model():
+    return TransformerBinarySelectionModelWiggleAct(
+        nhead=2,
+        d_model_per_head=4,
+        dim_feedforward=256,
+        layer_count=2,
+        output_dim=20,
+    )
+
+
+def test_dasm_burrito(ambig_pcp_df, dasm_model):
+    """Train a DASM burrito on a dataset containing ambiguous sequences."""
+    force_spawn()
+    ambig_pcp_df["in_train"] = True
+    ambig_pcp_df.loc[ambig_pcp_df.index[-15:], "in_train"] = False
+    train_dataset, val_dataset = DASMDataset.train_val_datasets_of_pcp_df(ambig_pcp_df)
+
+    burrito = DASMBurrito(
+        train_dataset,
+        val_dataset,
+        dasm_model,
+        batch_size=32,
+        learning_rate=0.001,
+        min_learning_rate=0.0001,
+    )
+    burrito.joint_train(
+        epochs=1, cycle_count=2, training_method="full", optimize_bl_first_cycle=False
+    )
diff --git a/tests/test_common.py b/tests/test_common.py
index f7c63165..e7f2f67b 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -1,6 +1,6 @@
 import torch
 
-from netam.common import nt_mask_tensor_of, aa_mask_tensor_of
+from netam.common import nt_mask_tensor_of, aa_mask_tensor_of, codon_mask_tensor_of
 
 
 def test_mask_tensor_of():
@@ -13,3 +13,15 @@ def test_mask_tensor_of():
     expected_output = torch.tensor([1, 1, 1, 1, 0], dtype=torch.bool)
     output = aa_mask_tensor_of(input_seq, length=5)
     assert torch.equal(output, expected_output)
+
+
+def test_codon_mask_tensor_of():
+    input_seq = "NAAAAAAAAAA"
+    # First test with a single sequence.
+    expected_output = torch.tensor([0, 1, 1, 0, 0], dtype=torch.bool)
+    output = codon_mask_tensor_of(input_seq, aa_length=5)
+    assert torch.equal(output, expected_output)
+    # Then check that Ns in a second sequence are ANDed into the mask.
+    input_seq2 = "AAAANAAAAAA"
+    expected_output = torch.tensor([0, 0, 1, 0, 0], dtype=torch.bool)
+    output = codon_mask_tensor_of(input_seq, input_seq2, aa_length=5)
+    assert torch.equal(output, expected_output)
diff --git a/tests/test_molevol.py b/tests/test_molevol.py
index 26a2146b..0b313fa1 100644
--- a/tests/test_molevol.py
+++ b/tests/test_molevol.py
@@ -2,7 +2,6 @@
 import pytest
 
 import netam.molevol as molevol
-from netam import framework
 from netam import pretrained
 
 from netam.sequences import (
diff --git a/tests/test_netam.py b/tests/test_netam.py
index 78d49804..c2fe9088 100644
--- a/tests/test_netam.py
+++ b/tests/test_netam.py
@@ -63,7 +63,7 @@ def test_make_dataset(tiny_dataset):
     assert branch_length == 1 / 5
 
 
-def test_write_output(tiny_burrito):
+def test_write_tinyburrito_output(tiny_burrito):
     os.makedirs("_ignore", exist_ok=True)
     tiny_burrito.model.write_shmoof_output("_ignore")
 
@@ -105,7 +105,7 @@ def mini_rsburrito(mini_dataset, tiny_rsscnnmodel):
     return burrito
 
 
-def test_write_output(mini_rsburrito):
+def test_write_mini_rsburrito_output(mini_rsburrito):
     os.makedirs("_ignore", exist_ok=True)
     mini_rsburrito.save_crepe("_ignore/mini_rscrepe")
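For reference, the end-to-end flow exercised by the new tests looks like the
following sketch (the data path and pretrained model name are the ones used in
tests/test_ambiguous.py):

    from netam import pretrained
    from netam.dnsm import DNSMDataset
    from netam.framework import add_shm_model_outputs_to_pcp_df, load_pcp_df

    df = load_pcp_df("data/wyatt-10x-1p5m_pcp_2023-11-30_NI.first100.csv.gz")
    # Parent/child sequences may now contain Ns; codons with any N are masked.
    df = add_shm_model_outputs_to_pcp_df(df, pretrained.load("ThriftyHumV0.2-45"))
    df["in_train"] = True
    df.loc[df.index[-15:], "in_train"] = False
    train_dataset, val_dataset = DNSMDataset.train_val_datasets_of_pcp_df(df)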