Skip to content

Commit

Permalink
cleanup and format
Browse files Browse the repository at this point in the history
  • Loading branch information
willdumm committed Nov 20, 2024
1 parent 23bc490 commit b2601a4
Show file tree
Hide file tree
Showing 11 changed files with 84 additions and 57 deletions.
6 changes: 3 additions & 3 deletions netam/attention_map.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""
r"""
We are going to get the attention weights using the [MultiHeadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) module in PyTorch. These weights are
$$
Expand All @@ -9,10 +9,10 @@
In our terminology, an attention map is the attention map for a single head. An
"attention maps" object is a collection of attention maps: a tensor where the
first dimension is the number of heads. This assumes all layers have the same
first dimension is the number of heads. This assumes all layers have the same
number of heads. An "attention mapss" is a list of attention maps objects, one
for each sequence in the batch. An "attention profile" is some 1-D summary of an
attention map, such as the maximum attention score for each key position.
attention map, such as the maximum attention score for each key position.
# Adapted from https://gist.github.com/airalcorn2/50ec06517ce96ecc143503e21fa6cb91
"""
Expand Down
23 changes: 14 additions & 9 deletions netam/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import multiprocessing as mp

from netam.sequences import iter_codons, apply_aa_mask_to_nt_sequence

BIG = 1e9
SMALL_PROB = 1e-6
BASES = ["A", "C", "G", "T"]
Expand Down Expand Up @@ -84,16 +85,18 @@ def generic_mask_tensor_of(ambig_symb, seq_str, length=None):
return mask


def codon_mask_tensor_of(nt_parent, nt_child, aa_length=None):
def codon_mask_tensor_of(nt_parent, *other_nt_seqs, aa_length=None):
"""Return a mask tensor indicating codons which contain at least one N.
Codons beyond the length of the sequence are masked.
Codons beyond the length of the sequence are masked. If other_nt_seqs are provided,
a codon position is unmasked only if no sequence contains an N in that codon (the "and" of the per-sequence masks).
"""
if aa_length is None:
aa_length = len(nt_parent) // 3
sequences = (nt_parent,) + other_nt_seqs
mask = [
("N" not in parent_codon and "N" not in child_codon)
for parent_codon, child_codon in zip(iter_codons(nt_parent), iter_codons(nt_child))
all("N" not in codon for codon in codons)
for codons in zip(*(iter_codons(sequence) for sequence in sequences))
]
if len(mask) < aa_length:
mask += [False] * (aa_length - len(mask))
Expand All @@ -102,7 +105,8 @@ def codon_mask_tensor_of(nt_parent, nt_child, aa_length=None):
assert len(mask) == aa_length
return torch.tensor(mask, dtype=torch.bool)

def check_pcp_valid(parent, child, aa_mask=None):

def assert_pcp_valid(parent, child, aa_mask=None):
"""Check that the parent-child pairs are valid.
* The parent and child sequences must be the same length
Expand All @@ -122,11 +126,12 @@ def check_pcp_valid(parent, child, aa_mask=None):
raise ValueError("Parent and child sequences are not the same length.")
if not aa_mask.any():
raise ValueError("Parent-child pair is masked in all codons.")
if (
apply_aa_mask_to_nt_sequence(parent, aa_mask)
== apply_aa_mask_to_nt_sequence(child, aa_mask)
if apply_aa_mask_to_nt_sequence(parent, aa_mask) == apply_aa_mask_to_nt_sequence(
child, aa_mask
):
raise ValueError("Parent-child pair matches after masking codons containing ambiguities")
raise ValueError(
"Parent-child pair matches after masking codons containing ambiguities"
)


def nt_mask_tensor_of(*args, **kwargs):
Expand Down
7 changes: 5 additions & 2 deletions netam/dasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def update_neutral_probs(self):
mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len])
nt_csps = nt_csps[:parent_len, :]
nt_mask = mask.repeat_interleave(3)[: len(nt_parent)]
molevol.check_csps(parent_idxs[nt_mask], nt_csps[:len(nt_parent)][nt_mask])
molevol.check_csps(parent_idxs[nt_mask], nt_csps[: len(nt_parent)][nt_mask])

# TODO don't we need to pass multihit model in here?
neutral_aa_probs = molevol.neutral_aa_probs(
Expand Down Expand Up @@ -198,13 +198,16 @@ def loss_of_batch(self, batch):
return torch.stack([subs_pos_loss, csp_loss])

def build_selection_matrix_from_parent(self, parent: str):
"""Build a selection matrix from a parent amino acid sequence.
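The amino acid sequence is obtained by translating ``parent``, which is passed in as a nucleotide sequence.
Values at ambiguous sites are meaningless.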
"""
# This is simpler than the equivalent in dnsm.py because we get the selection
# matrix directly. Note that selection_factors_of_aa_str does the exponentiation
# so this indeed gives us the selection factors, not the log selection factors.
parent = sequences.translate_sequence(parent)
per_aa_selection_factors = self.model.selection_factors_of_aa_str(parent)

# TODO this nonsense output will need to get masked
parent = parent.replace("X", "A")
parent_idxs = sequences.aa_idx_array_of_str(parent)
per_aa_selection_factors[torch.arange(len(parent_idxs)), parent_idxs] = 1.0
Expand Down
12 changes: 6 additions & 6 deletions netam/dnsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,8 @@ def update_neutral_probs(self):
parent_len = len(nt_parent)
# Cannot assume that nt_csps and mask are same length, because when
# datasets are split, masks are recomputed.
# TODO Figure out how we're really going to handle masking, because
# old method allowed some nt N's to be unmasked.
nt_mask = mask.repeat_interleave(3)[: len(nt_parent)]
molevol.check_csps(parent_idxs[nt_mask], nt_csps[:len(nt_parent)][nt_mask])
# molevol.check_csps(parent_idxs[nt_mask], nt_csps[:len(parent_idxs)][nt_mask])
nt_mask = mask.repeat_interleave(3)[:parent_len]
molevol.check_csps(parent_idxs[nt_mask], nt_csps[:parent_len][nt_mask])

mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len])
nt_csps = nt_csps[:parent_len, :]
Expand Down Expand Up @@ -161,13 +158,16 @@ def loss_of_batch(self, batch):
return self.bce_loss(predictions, aa_subs_indicator)

def build_selection_matrix_from_parent(self, parent: str):
"""Build a selection matrix from a parent amino acid sequence.
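The amino acid sequence is obtained by translating ``parent``, which is passed in as a nucleotide sequence.
Values at ambiguous sites are meaningless.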
"""
parent = sequences.translate_sequence(parent)
selection_factors = self.model.selection_factors_of_aa_str(parent)
selection_matrix = torch.zeros((len(selection_factors), 20), dtype=torch.float)
# Every "off-diagonal" entry of the selection matrix is set to the selection
# factor, where "diagonal" means keeping the same amino acid.
selection_matrix[:, :] = selection_factors[:, None]
# TODO this nonsense output will need to get masked
parent = parent.replace("X", "A")
# Set "diagonal" elements to one.
parent_idxs = sequences.aa_idx_array_of_str(parent)
Expand Down
30 changes: 14 additions & 16 deletions netam/dxsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,17 @@
from netam.common import (
MAX_AMBIG_AA_IDX,
aa_idx_tensor_of_str_ambig,
aa_mask_tensor_of,
stack_heterogeneous,
codon_mask_tensor_of,
check_pcp_valid,
assert_pcp_valid,
)
import netam.framework as framework
import netam.molevol as molevol
import netam.sequences as sequences
from netam.sequences import (
aa_subs_indicator_tensor_of,
translate_sequences,
apply_aa_mask_to_nt_sequence,
nt_mutation_frequency,
)


Expand Down Expand Up @@ -73,9 +72,13 @@ def __init__(
self.masks = torch.ones((pcp_count, self.max_aa_seq_len), dtype=torch.bool)

for i, (aa_parent, aa_child) in enumerate(zip(aa_parents, aa_children)):
self.masks[i, :] = codon_mask_tensor_of(nt_parents[i], nt_children[i], self.max_aa_seq_len)
# aa_length is keyword-only in codon_mask_tensor_of (it follows *other_nt_seqs),
# so it must be passed by name; passed positionally it would be collected as an
# extra nucleotide sequence instead of setting the mask length.
self.masks[i, :] = codon_mask_tensor_of(
    nt_parents[i], nt_children[i], aa_length=self.max_aa_seq_len
)
aa_seq_len = len(aa_parent)
check_pcp_valid(nt_parents[i], nt_children[i], aa_mask=self.masks[i][: aa_seq_len])
assert_pcp_valid(
nt_parents[i], nt_children[i], aa_mask=self.masks[i][:aa_seq_len]
)

self.aa_parents_idxss[i, :aa_seq_len] = aa_idx_tensor_of_str_ambig(
aa_parent
Expand Down Expand Up @@ -111,8 +114,7 @@ def of_seriess(
"""
initial_branch_lengths = np.array(
[
sequences.nt_mutation_frequency(parent, child)
* branch_length_multiplier
nt_mutation_frequency(parent, child) * branch_length_multiplier
for parent, child in zip(nt_parents, nt_children)
]
)
Expand Down Expand Up @@ -252,14 +254,12 @@ def _find_optimal_branch_length(
multihit_model,
**optimization_kwargs,
):
# TODO this doesn't use any mask, couldn't we use already-computed
# aa_parent and its mask?
sel_matrix = self.build_selection_matrix_from_parent(parent)
trimmed_aa_mask = aa_mask[: len(sel_matrix)]
log_pcp_probability = molevol.mutsel_log_pcp_probability_of(
sel_matrix[trimmed_aa_mask],
sequences.apply_aa_mask_to_nt_sequence(parent, trimmed_aa_mask[: len(parent)]),
sequences.apply_aa_mask_to_nt_sequence(child, trimmed_aa_mask[: len(child)]),
apply_aa_mask_to_nt_sequence(parent, trimmed_aa_mask),
apply_aa_mask_to_nt_sequence(child, trimmed_aa_mask),
nt_rates[trimmed_aa_mask.repeat_interleave(3)],
nt_csps[trimmed_aa_mask.repeat_interleave(3)],
multihit_model,
Expand Down Expand Up @@ -316,11 +316,9 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):

def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
worker_count = min(mp.cpu_count() // 2, 10)
# The following can be used when one wants a better traceback.
# TODO disable later...
burrito = self.__class__(None, dataset, copy.deepcopy(self.model))
return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)

# # The following can be used when one wants a better traceback.
# burrito = self.__class__(None, dataset, copy.deepcopy(self.model))
# return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
our_optimize_branch_length = partial(
worker_optimize_branch_length,
self.__class__,
Expand Down
2 changes: 0 additions & 2 deletions netam/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,8 +552,6 @@ def selection_factors_of_aa_str(self, aa_str: str) -> Tensor:

aa_idxs = aa_idx_tensor_of_str_ambig(aa_str)
aa_idxs = aa_idxs.to(model_device)
# TODO: Shouldn't we be using the new codon mask here, and allowing
# a pre-computed mask to be passed in?
mask = aa_mask_tensor_of(aa_str)
mask = mask.to(model_device)

Expand Down
4 changes: 3 additions & 1 deletion netam/sequences.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,15 @@ def assert_full_sequences(parent, child):
if "N" in parent or "N" in child:
raise ValueError("Found ambiguous bases in the parent or child sequence.")


def apply_aa_mask_to_nt_sequence(nt_seq, aa_mask):
    """Keep only the nucleotides of ``nt_seq`` whose codon is unmasked.

    Each entry of ``aa_mask`` is repeated three times (once per codon
    position) and paired with the nucleotides of ``nt_seq`` in order.
    Nucleotides paired with a truthy mask value are kept and joined into
    the returned string; pairing stops at the shorter of the two.
    """
    nt_mask = aa_mask.repeat_interleave(3)
    kept = [nt for nt, keep in zip(nt_seq, nt_mask) if keep]
    return "".join(kept)


def iter_codons(nt_seq):
"""Iterate over the codons in a nucleotide sequence."""
for i in range(0, len(nt_seq), 3):
for i in range(0, (len(nt_seq) // 3) * 3, 3):
yield nt_seq[i : i + 3]
38 changes: 24 additions & 14 deletions tests/test_ambiguous.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,12 @@
import pytest

from netam.common import force_spawn
from netam.framework import (
crepe_exists,
load_crepe,
)
from netam.models import TransformerBinarySelectionModelWiggleAct
from netam.dasm import (
DASMBurrito,
DASMDataset,
)
from netam.framework import SHMoofDataset, SHMBurrito, RSSHMBurrito
from netam.models import SHMoofModel, RSSHMoofModel, IndepRSCNNModel
from netam.dnsm import DNSMBurrito, DNSMDataset
import pytest
from netam.framework import (
load_pcp_df,
add_shm_model_outputs_to_pcp_df,
Expand All @@ -28,11 +21,14 @@ def randomize_with_ns(parent_seq, child_seq, avoid_masked_equality=True):
old_child = child_seq
seq_length = len(parent_seq)
try:
first_mut = next((idx, p, c) for idx, (p, c) in enumerate(zip(parent_seq, child_seq)) if p != c)
first_mut = next(
(idx, p, c)
for idx, (p, c) in enumerate(zip(parent_seq, child_seq))
if p != c
)
except:
return parent_seq, child_seq


# Decide which type of modification to apply
modification_type = random.choice(["same_site", "different_site", "chunk", "none"])

Expand Down Expand Up @@ -88,16 +84,29 @@ def randomize_with_ns(parent_seq, child_seq, avoid_masked_equality=True):
child_seq = child_seq[:idx] + c + child_seq[idx + 1 :]
if avoid_masked_equality:
codon_pairs = [
(parent_seq[i*3: (i+1)*3], child_seq[i*3: (i+1)*3])
(parent_seq[i * 3 : (i + 1) * 3], child_seq[i * 3 : (i + 1) * 3])
for i in range(seq_length // 3)
]
if all(p == c for p, c in filter(lambda pair: "N" not in pair[0] and "N" not in pair[1], codon_pairs)):
if all(
p == c
for p, c in filter(
lambda pair: "N" not in pair[0] and "N" not in pair[1], codon_pairs
)
):
# put original codon containing a mutation back in.
idx, p, c = first_mut
codon_start = (idx // 3) * 3
codon_end = codon_start + 3
parent_seq = parent_seq[:codon_start] + old_parent[codon_start:codon_end] + parent_seq[codon_end:]
child_seq = child_seq[:codon_start] + old_child[codon_start:codon_end] + child_seq[codon_end:]
parent_seq = (
parent_seq[:codon_start]
+ old_parent[codon_start:codon_end]
+ parent_seq[codon_end:]
)
child_seq = (
child_seq[:codon_start]
+ old_child[codon_start:codon_end]
+ child_seq[codon_end:]
)

assert len(parent_seq) == len(child_seq)
assert len(parent_seq) == seq_length
Expand All @@ -124,15 +133,16 @@ def ambig_pcp_df():
)
return df


@pytest.fixture
def dnsm_model():
    """Provide a small TransformerBinarySelectionModelWiggleAct for the tests."""
    model_kwargs = dict(
        nhead=2, d_model_per_head=4, dim_feedforward=256, layer_count=2
    )
    return TransformerBinarySelectionModelWiggleAct(**model_kwargs)


def test_dnsm_burrito(ambig_pcp_df, dnsm_model):
"""Fixture that returns the DNSM Burrito object."""
# TODO fix and make also work with randomize_with_ns avoid_masked_equality=False
force_spawn()
ambig_pcp_df["in_train"] = True
ambig_pcp_df.loc[ambig_pcp_df.index[-15:], "in_train"] = False
Expand Down
14 changes: 13 additions & 1 deletion tests/test_common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from netam.common import nt_mask_tensor_of, aa_mask_tensor_of
from netam.common import nt_mask_tensor_of, aa_mask_tensor_of, codon_mask_tensor_of


def test_mask_tensor_of():
Expand All @@ -13,3 +13,15 @@ def test_mask_tensor_of():
expected_output = torch.tensor([1, 1, 1, 1, 0], dtype=torch.bool)
output = aa_mask_tensor_of(input_seq, length=5)
assert torch.equal(output, expected_output)


def test_codon_mask_tensor_of():
    """Check codon masks for one sequence and for the "and" of two sequences."""
    first_seq = "NAAAAAAAAAA"
    # Single sequence: the first codon contains an N, and the last two
    # positions lie beyond the complete codons of the 11-nt input.
    single_mask = codon_mask_tensor_of(first_seq, aa_length=5)
    assert torch.equal(
        single_mask, torch.tensor([0, 1, 1, 0, 0], dtype=torch.bool)
    )
    # Two sequences: the second has an N in its second codon, so that
    # position is expected to be masked as well.
    second_seq = "AAAANAAAAAA"
    joint_mask = codon_mask_tensor_of(first_seq, second_seq, aa_length=5)
    assert torch.equal(
        joint_mask, torch.tensor([0, 0, 1, 0, 0], dtype=torch.bool)
    )
1 change: 0 additions & 1 deletion tests/test_molevol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import pytest

import netam.molevol as molevol
from netam import framework
from netam import pretrained

from netam.sequences import (
Expand Down
4 changes: 2 additions & 2 deletions tests/test_netam.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_make_dataset(tiny_dataset):
assert branch_length == 1 / 5


def test_write_output(tiny_burrito):
def test_write_tinyburrito_output(tiny_burrito):
os.makedirs("_ignore", exist_ok=True)
tiny_burrito.model.write_shmoof_output("_ignore")

Expand Down Expand Up @@ -105,7 +105,7 @@ def mini_rsburrito(mini_dataset, tiny_rsscnnmodel):
return burrito


def test_write_output(mini_rsburrito):
def test_write_mini_rsburrito_output(mini_rsburrito):
os.makedirs("_ignore", exist_ok=True)
mini_rsburrito.save_crepe("_ignore/mini_rscrepe")

Expand Down

0 comments on commit b2601a4

Please sign in to comment.