Skip to content

Commit

Permalink
DXSM ambiguous sequences (#88)
Browse files Browse the repository at this point in the history
This PR addresses matsengrp/dnsm-experiments-1#36.
* Masks all codons containing "N" in parent or child sequences, and asserts that unmasked sequences aren't identical
* Applies mask in branch length optimization computations
* Replaces X's with A's in some calls to functions that can't take ambiguous AA sequences. In all cases, a note is made in the docstring that the function returns nonsense data on ambiguous sites, and those sites are later masked out of function outputs.
  • Loading branch information
willdumm authored Nov 21, 2024
1 parent a6f02a2 commit 24a19bd
Show file tree
Hide file tree
Showing 10 changed files with 317 additions and 24 deletions.
6 changes: 3 additions & 3 deletions netam/attention_map.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""
r"""
We are going to get the attention weights using the [MultiHeadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) module in PyTorch. These weights are
$$
Expand All @@ -9,10 +9,10 @@
In our terminology, an attention map is the attention map for a single head. An
"attention maps" object is a collection of attention maps: a tensor where the
first dimension is the number of heads. This assumes all layers have the same
first dimension is the number of heads. This assumes all layers have the same
number of heads. An "attention mapss" is a list of attention maps objects, one
for each sequence in the batch. An "attention profile" is some 1-D summary of an
attention map, such as the maximum attention score for each key position.
attention map, such as the maximum attention score for each key position.
# Adapted from https://gist.github.com/airalcorn2/50ec06517ce96ecc143503e21fa6cb91
"""
Expand Down
51 changes: 51 additions & 0 deletions netam/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from torch import nn, Tensor
import multiprocessing as mp

from netam.sequences import iter_codons, apply_aa_mask_to_nt_sequence

BIG = 1e9
SMALL_PROB = 1e-6
BASES = ["A", "C", "G", "T"]
Expand Down Expand Up @@ -83,6 +85,55 @@ def generic_mask_tensor_of(ambig_symb, seq_str, length=None):
return mask


def codon_mask_tensor_of(nt_parent, *other_nt_seqs, aa_length=None):
    """Build a boolean codon-level mask shared by one or more nucleotide sequences.

    A position is True only when the codon at that position contains no "N" in
    any of the provided sequences (the logical "and" of the per-sequence
    masks). Positions past the last complete codon of the shortest sequence
    are False (masked).

    Args:
        nt_parent: The first nucleotide sequence.
        *other_nt_seqs: Additional nucleotide sequences to fold into the mask.
        aa_length: Length of the returned mask, in codons. Defaults to the
            number of complete codons in ``nt_parent``.

    Returns:
        A 1-D ``torch.bool`` tensor of length ``aa_length``.
    """
    target_len = len(nt_parent) // 3 if aa_length is None else aa_length
    codon_iters = [iter_codons(seq) for seq in (nt_parent, *other_nt_seqs)]
    flags = []
    # zip stops at the shortest sequence, so only codons present in every
    # sequence are examined here.
    for codon_group in zip(*codon_iters):
        flags.append(not any("N" in codon for codon in codon_group))
    # Pad with False (masked) or truncate so the mask is exactly target_len long.
    shortfall = target_len - len(flags)
    if shortfall > 0:
        flags.extend([False] * shortfall)
    else:
        flags = flags[:target_len]
    assert len(flags) == target_len
    return torch.tensor(flags, dtype=torch.bool)


def assert_pcp_valid(parent, child, aa_mask=None):
    """Validate a parent-child pair (PCP) of nucleotide sequences.

    The pair is rejected when any of the following holds:

    * the parent and child differ in length,
    * no codon is left unmasked,
    * the parent and child are identical once masked codons are removed.

    Args:
        parent: The parent sequence.
        child: The child sequence.
        aa_mask: The mask tensor for the amino acid sequence. When None, it is
            computed from the parent and child with ``codon_mask_tensor_of``.

    Raises:
        ValueError: If the pair fails any of the checks above.
    """
    mask = codon_mask_tensor_of(parent, child) if aa_mask is None else aa_mask
    if len(parent) != len(child):
        raise ValueError("Parent and child sequences are not the same length.")
    if not mask.any():
        raise ValueError("Parent-child pair is masked in all codons.")
    # Compare only the nucleotides that survive the codon mask.
    masked_parent = apply_aa_mask_to_nt_sequence(parent, mask)
    masked_child = apply_aa_mask_to_nt_sequence(child, mask)
    if masked_parent == masked_child:
        raise ValueError(
            "Parent-child pair matches after masking codons containing ambiguities"
        )


def nt_mask_tensor_of(*args, **kwargs):
    """Nucleotide variant of ``generic_mask_tensor_of`` with "N" as the ambiguity symbol.

    All positional and keyword arguments are forwarded unchanged after the
    ambiguity symbol; the result of ``generic_mask_tensor_of`` is returned as is.
    """
    ambiguous_nt = "N"
    return generic_mask_tensor_of(ambiguous_nt, *args, **kwargs)

Expand Down
14 changes: 13 additions & 1 deletion netam/dasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,25 @@ def update_neutral_probs(self):
mask = mask.to("cpu")
nt_rates = nt_rates.to("cpu")
nt_csps = nt_csps.to("cpu")
if self.multihit_model is not None:
multihit_model = copy.deepcopy(self.multihit_model).to("cpu")
else:
multihit_model = None
# Note we are replacing all Ns with As, which means that we need to be careful
# with masking out these positions later. We do this below.
parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
parent_len = len(nt_parent)

mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len])
nt_csps = nt_csps[:parent_len, :]
molevol.check_csps(parent_idxs, nt_csps)
nt_mask = mask.repeat_interleave(3)[: len(nt_parent)]
molevol.check_csps(parent_idxs[nt_mask], nt_csps[: len(nt_parent)][nt_mask])

neutral_aa_probs = molevol.neutral_aa_probs(
parent_idxs.reshape(-1, 3),
mut_probs.reshape(-1, 3),
nt_csps.reshape(-1, 3, 4),
multihit_model=multihit_model,
)

if not torch.isfinite(neutral_aa_probs).all():
Expand Down Expand Up @@ -196,11 +202,17 @@ def loss_of_batch(self, batch):
return torch.stack([subs_pos_loss, csp_loss])

def build_selection_matrix_from_parent(self, parent: str):
"""Build a selection matrix from a parent amino acid sequence.
Values at ambiguous sites are meaningless.
"""
# This is simpler than the equivalent in dnsm.py because we get the selection
# matrix directly. Note that selection_factors_of_aa_str does the exponentiation
# so this indeed gives us the selection factors, not the log selection factors.
parent = sequences.translate_sequence(parent)
per_aa_selection_factors = self.model.selection_factors_of_aa_str(parent)

parent = parent.replace("X", "A")
parent_idxs = sequences.aa_idx_array_of_str(parent)
per_aa_selection_factors[torch.arange(len(parent_idxs)), parent_idxs] = 1.0

Expand Down
10 changes: 9 additions & 1 deletion netam/dnsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def update_neutral_probs(self):
# with masking out these positions later. We do this below.
parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
parent_len = len(nt_parent)
molevol.check_csps(parent_idxs, nt_csps)
# Cannot assume that nt_csps and mask are same length, because when
# datasets are split, masks are recomputed.
nt_mask = mask.repeat_interleave(3)[:parent_len]
molevol.check_csps(parent_idxs[nt_mask], nt_csps[:parent_len][nt_mask])

mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len])
nt_csps = nt_csps[:parent_len, :]
Expand Down Expand Up @@ -154,12 +157,17 @@ def loss_of_batch(self, batch):
return self.bce_loss(predictions, aa_subs_indicator)

def build_selection_matrix_from_parent(self, parent: str):
"""Build a selection matrix from a parent amino acid sequence.
Values at ambiguous sites are meaningless.
"""
parent = sequences.translate_sequence(parent)
selection_factors = self.model.selection_factors_of_aa_str(parent)
selection_matrix = torch.zeros((len(selection_factors), 20), dtype=torch.float)
# Every "off-diagonal" entry of the selection matrix is set to the selection
# factor, where "diagonal" means keeping the same amino acid.
selection_matrix[:, :] = selection_factors[:, None]
parent = parent.replace("X", "A")
# Set "diagonal" elements to one.
parent_idxs = sequences.aa_idx_array_of_str(parent)
selection_matrix[torch.arange(len(parent_idxs)), parent_idxs] = 1.0
Expand Down
38 changes: 23 additions & 15 deletions netam/dxsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,17 @@
from netam.common import (
MAX_AMBIG_AA_IDX,
aa_idx_tensor_of_str_ambig,
aa_mask_tensor_of,
stack_heterogeneous,
codon_mask_tensor_of,
assert_pcp_valid,
)
import netam.framework as framework
import netam.molevol as molevol
import netam.sequences as sequences
from netam.sequences import (
aa_subs_indicator_tensor_of,
translate_sequences,
apply_aa_mask_to_nt_sequence,
nt_mutation_frequency,
)


Expand Down Expand Up @@ -55,12 +57,6 @@ def __init__(
assert len(self.nt_parents) == len(self.nt_children)
pcp_count = len(self.nt_parents)

for parent, child in zip(self.nt_parents, self.nt_children):
if parent == child:
raise ValueError(
f"Found an identical parent and child sequence: {parent}"
)

aa_parents = translate_sequences(self.nt_parents)
aa_children = translate_sequences(self.nt_children)
self.max_aa_seq_len = max(len(seq) for seq in aa_parents)
Expand All @@ -76,8 +72,14 @@ def __init__(
self.masks = torch.ones((pcp_count, self.max_aa_seq_len), dtype=torch.bool)

for i, (aa_parent, aa_child) in enumerate(zip(aa_parents, aa_children)):
self.masks[i, :] = aa_mask_tensor_of(aa_parent, self.max_aa_seq_len)
self.masks[i, :] = codon_mask_tensor_of(
nt_parents[i], nt_children[i], aa_length=self.max_aa_seq_len
)
aa_seq_len = len(aa_parent)
assert_pcp_valid(
nt_parents[i], nt_children[i], aa_mask=self.masks[i][:aa_seq_len]
)

self.aa_parents_idxss[i, :aa_seq_len] = aa_idx_tensor_of_str_ambig(
aa_parent
)
Expand Down Expand Up @@ -112,8 +114,7 @@ def of_seriess(
"""
initial_branch_lengths = np.array(
[
sequences.nt_mutation_frequency(parent, child)
* branch_length_multiplier
nt_mutation_frequency(parent, child) * branch_length_multiplier
for parent, child in zip(nt_parents, nt_children)
]
)
Expand Down Expand Up @@ -248,15 +249,20 @@ def _find_optimal_branch_length(
child,
nt_rates,
nt_csps,
aa_mask,
starting_branch_length,
multihit_model,
**optimization_kwargs,
):
if parent == child:
return 0.0
sel_matrix = self.build_selection_matrix_from_parent(parent)
trimmed_aa_mask = aa_mask[: len(sel_matrix)]
log_pcp_probability = molevol.mutsel_log_pcp_probability_of(
sel_matrix, parent, child, nt_rates, nt_csps, multihit_model
sel_matrix[trimmed_aa_mask],
apply_aa_mask_to_nt_sequence(parent, trimmed_aa_mask),
apply_aa_mask_to_nt_sequence(child, trimmed_aa_mask),
nt_rates[trimmed_aa_mask.repeat_interleave(3)],
nt_csps[trimmed_aa_mask.repeat_interleave(3)],
multihit_model,
)
if isinstance(starting_branch_length, torch.Tensor):
starting_branch_length = starting_branch_length.detach().item()
Expand All @@ -268,12 +274,13 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
optimal_lengths = []
failed_count = 0

for parent, child, nt_rates, nt_csps, starting_length in tqdm(
for parent, child, nt_rates, nt_csps, aa_mask, starting_length in tqdm(
zip(
dataset.nt_parents,
dataset.nt_children,
dataset.nt_ratess,
dataset.nt_cspss,
dataset.masks,
dataset.branch_lengths,
),
total=len(dataset.nt_parents),
Expand All @@ -284,6 +291,7 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
child,
nt_rates[: len(parent)],
nt_csps[: len(parent), :],
aa_mask,
starting_length,
dataset.multihit_model,
**optimization_kwargs,
Expand Down
13 changes: 13 additions & 0 deletions netam/sequences.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,16 @@ def assert_full_sequences(parent, child):

if "N" in parent or "N" in child:
raise ValueError("Found ambiguous bases in the parent or child sequence.")


def apply_aa_mask_to_nt_sequence(nt_seq, aa_mask):
    """Apply an amino acid mask to a nucleotide sequence.

    Each entry of ``aa_mask`` governs one codon (three consecutive
    nucleotides). Nucleotides belonging to a codon whose mask entry is falsy
    are dropped; the remaining nucleotides are concatenated in their original
    order. Nucleotides beyond ``3 * len(aa_mask)`` are dropped as well.

    Args:
        nt_seq: The nucleotide sequence string.
        aa_mask: 1-D tensor with one entry per codon; truthy entries are kept.

    Returns:
        The string made of the nucleotides from unmasked codons.
    """
    # Convert to plain Python values once: iterating a tensor directly builds a
    # zero-dimensional tensor (and a tensor-to-bool conversion) per nucleotide.
    nt_mask = aa_mask.repeat_interleave(3).tolist()
    return "".join(nt for nt, keep in zip(nt_seq, nt_mask) if keep)


def iter_codons(nt_seq):
    """Yield the complete codons (consecutive length-3 slices) of a nucleotide sequence.

    Trailing nucleotides that do not fill a whole codon are not yielded.
    """
    codon_count = len(nt_seq) // 3
    for codon_idx in range(codon_count):
        start = 3 * codon_idx
        yield nt_seq[start : start + 3]
Loading

0 comments on commit 24a19bd

Please sign in to comment.