better mask handling
willdumm committed Nov 20, 2024
1 parent 2e62463 commit d434a81
Showing 5 changed files with 77 additions and 36 deletions.
46 changes: 46 additions & 0 deletions netam/common.py
@@ -10,6 +10,7 @@
from torch import nn, Tensor
import multiprocessing as mp

from netam.sequences import iter_codons, apply_aa_mask_to_nt_sequence
BIG = 1e9
SMALL_PROB = 1e-6
BASES = ["A", "C", "G", "T"]
@@ -83,6 +84,51 @@ def generic_mask_tensor_of(ambig_symb, seq_str, length=None):
    return mask


def codon_mask_tensor_of(nt_parent, nt_child, aa_length=None):
    """Return a mask tensor which is True for codons containing no Ns in
    either sequence, and False for codons containing at least one N.

    Codons beyond the length of the sequence are masked.
    """
    if aa_length is None:
        aa_length = len(nt_parent) // 3
    mask = [
        ("N" not in parent_codon and "N" not in child_codon)
        for parent_codon, child_codon in zip(
            iter_codons(nt_parent), iter_codons(nt_child)
        )
    ]
    if len(mask) < aa_length:
        mask += [False] * (aa_length - len(mask))
    else:
        mask = mask[:aa_length]
    assert len(mask) == aa_length
    return torch.tensor(mask, dtype=torch.bool)
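
A quick sketch of how the new helper behaves (toy sequences invented for
illustration; not part of the commit):

    from netam.common import codon_mask_tensor_of

    parent = "ATGAANGGG"  # the second codon contains an N
    child = "ATGAAAGGG"
    mask = codon_mask_tensor_of(parent, child, aa_length=5)
    # The N-containing codon is masked, and the positions beyond the
    # sequence are padded with False:
    assert mask.tolist() == [True, False, True, False, False]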

def check_pcp_valid(parent, child, aa_mask=None):
    """Check that a parent-child pair is valid.

    * The parent and child sequences must be the same length.
    * There must be at least one unmasked codon.
    * The parent and child sequences must not match after masking codons
      containing ambiguities.

    Args:
        parent: The parent sequence.
        child: The child sequence.
        aa_mask: The mask tensor for the amino acid sequence. If None, it will be
            computed from the parent and child sequences.
    """
    if aa_mask is None:
        aa_mask = codon_mask_tensor_of(parent, child)
    if len(parent) != len(child):
        raise ValueError("Parent and child sequences are not the same length.")
    if not aa_mask.any():
        raise ValueError("Parent-child pair is masked in all codons.")
    if apply_aa_mask_to_nt_sequence(parent, aa_mask) == apply_aa_mask_to_nt_sequence(
        child, aa_mask
    ):
        raise ValueError(
            "Parent-child pair matches after masking codons containing ambiguities."
        )
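
For illustration (again with invented sequences, not part of the commit), one
passing and one failing case:

    from netam.common import check_pcp_valid

    # Passes silently: equal lengths, at least one unmasked codon, and the
    # pair still differs after masking.
    check_pcp_valid("ATGAAAGGG", "ATGAAAGGC")

    # Raises ValueError: the only difference sits in the N-containing
    # (masked) codon, so the pair matches after masking.
    try:
        check_pcp_valid("ATGAANGGG", "ATGAAAGGG")
    except ValueError as err:
        print(err)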


def nt_mask_tensor_of(*args, **kwargs):
    return generic_mask_tensor_of("N", *args, **kwargs)

39 changes: 14 additions & 25 deletions netam/dxsm.py
@@ -20,29 +20,18 @@
    aa_idx_tensor_of_str_ambig,
    aa_mask_tensor_of,
    stack_heterogeneous,
    codon_mask_tensor_of,
    check_pcp_valid,
)
import netam.framework as framework
import netam.molevol as molevol
import netam.sequences as sequences
from netam.sequences import (
    aa_subs_indicator_tensor_of,
    translate_sequences,
    apply_aa_mask_to_nt_sequence,
)

def cautious_mask_tensor_of(nt_str, aa_length):
    """Return a mask tensor indicating codons which contain at least one N.

    Codons beyond the length of the sequence are masked.
    """
    if aa_length is None:
        aa_length = len(nt_str) // 3
    mask = ["N" not in nt_str[i * 3 : (i + 1) * 3] for i in range(len(nt_str) // 3)]
    if len(mask) < aa_length:
        mask += [False] * (aa_length - len(mask))
    else:
        mask = mask[:aa_length]
    assert len(mask) == aa_length
    return torch.tensor(mask, dtype=torch.bool)

class DXSMDataset(Dataset, ABC):
    prefix = "dxsm"
@@ -69,12 +58,6 @@ def __init__(
        assert len(self.nt_parents) == len(self.nt_children)
        pcp_count = len(self.nt_parents)

        for parent, child in zip(self.nt_parents, self.nt_children):
            if parent == child:
                raise ValueError(
                    f"Found an identical parent and child sequence: {parent}"
                )

        aa_parents = translate_sequences(self.nt_parents)
        aa_children = translate_sequences(self.nt_children)
        self.max_aa_seq_len = max(len(seq) for seq in aa_parents)
@@ -90,11 +73,10 @@ def __init__(
        self.masks = torch.ones((pcp_count, self.max_aa_seq_len), dtype=torch.bool)

        for i, (aa_parent, aa_child) in enumerate(zip(aa_parents, aa_children)):
            # self.masks[i, :] = aa_mask_tensor_of(aa_parent, self.max_aa_seq_len)
            # TODO Figure out how we're really going to handle masking
            self.masks[i, :] = cautious_mask_tensor_of(nt_parents[i], self.max_aa_seq_len)

            self.masks[i, :] = codon_mask_tensor_of(nt_parents[i], nt_children[i], self.max_aa_seq_len)
            aa_seq_len = len(aa_parent)
            check_pcp_valid(nt_parents[i], nt_children[i], aa_mask=self.masks[i][: aa_seq_len])

            self.aa_parents_idxss[i, :aa_seq_len] = aa_idx_tensor_of_str_ambig(
                aa_parent
            )
@@ -281,8 +263,14 @@ def _find_optimal_branch_length(
        # TODO this doesn't use any mask, couldn't we use already-computed
        # aa_parent?
        sel_matrix = self.build_selection_matrix_from_parent(parent)
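        # Masking is now applied here at the call site: the selection matrix,
        # the nucleotide sequences, and the per-site rate/CSP tensors are all
        # restricted to unmasked codons before being handed to molevol, whose
        # mutsel_log_pcp_probability_of no longer takes an aa_mask argument.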
        trimmed_aa_mask = aa_mask[: len(sel_matrix)]
        log_pcp_probability = molevol.mutsel_log_pcp_probability_of(
            sel_matrix, parent, child, nt_rates, nt_csps, aa_mask[: len(sel_matrix)], multihit_model
            sel_matrix[trimmed_aa_mask],
            sequences.apply_aa_mask_to_nt_sequence(parent, trimmed_aa_mask[: len(parent)]),
            sequences.apply_aa_mask_to_nt_sequence(child, trimmed_aa_mask[: len(child)]),
            nt_rates[trimmed_aa_mask.repeat_interleave(3)],
            nt_csps[trimmed_aa_mask.repeat_interleave(3)],
            multihit_model,
        )
        if isinstance(starting_branch_length, torch.Tensor):
            starting_branch_length = starting_branch_length.detach().item()
@@ -337,6 +325,7 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
    def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
        worker_count = min(mp.cpu_count() // 2, 10)
        # The following can be used when one wants a better traceback.
        # TODO disable later...
        burrito = self.__class__(None, dataset, copy.deepcopy(self.model))
        return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)

15 changes: 6 additions & 9 deletions netam/molevol.py
@@ -434,7 +434,7 @@ def neutral_aa_mut_probs(


def mutsel_log_pcp_probability_of(
    sel_matrix, parent, child, nt_rates, nt_csps, aa_mask, multihit_model=None
    sel_matrix, parent, child, nt_rates, nt_csps, multihit_model=None
):
    """Constructs the log_pcp_probability function specific to given nt_rates and
    nt_csps.
@@ -446,9 +446,6 @@
    assert len(parent) % 3 == 0
    assert sel_matrix.shape == (len(parent) // 3, 20)
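    # Parent and child are now expected to arrive pre-masked: the caller
    # strips codons containing Ns (see dxsm.py), which makes the old
    # replace("N", "A") workaround unnecessary.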

    # This is masked out later
    parent = parent.replace("N", "A")
    child = child.replace("N", "A")
    parent_idxs = sequences.nt_idx_tensor_of_str(parent)
    child_idxs = sequences.nt_idx_tensor_of_str(child)
@@ -457,18 +454,18 @@ def log_pcp_probability(log_branch_length: torch.Tensor):
        nt_mut_probs = 1.0 - torch.exp(-branch_length * nt_rates)

        codon_mutsel, sums_too_big = build_codon_mutsel(
            parent_idxs.reshape(-1, 3)[aa_mask],
            nt_mut_probs.reshape(-1, 3)[aa_mask],
            nt_csps.reshape(-1, 3, 4)[aa_mask],
            sel_matrix[aa_mask],
            parent_idxs.reshape(-1, 3),
            nt_mut_probs.reshape(-1, 3),
            nt_csps.reshape(-1, 3, 4),
            sel_matrix,
            multihit_model=multihit_model,
        )

        # This is a diagnostic generating data for netam issue #7.
        # if sums_too_big is not None:
        #     self.csv_file.write(f"{parent},{child},{branch_length},{sums_too_big}\n")

        reshaped_child_idxs = child_idxs.reshape(-1, 3)[aa_mask]
        reshaped_child_idxs = child_idxs.reshape(-1, 3)
        child_prob_vector = codon_mutsel[
            torch.arange(len(reshaped_child_idxs)),
            reshaped_child_idxs[:, 0],
11 changes: 11 additions & 0 deletions netam/sequences.py
@@ -183,3 +183,14 @@ def assert_full_sequences(parent, child):

    if "N" in parent or "N" in child:
        raise ValueError("Found ambiguous bases in the parent or child sequence.")

def apply_aa_mask_to_nt_sequence(nt_seq, aa_mask):
    """Apply an amino acid mask to a nucleotide sequence."""
    return "".join(
        nt for nt, mask_val in zip(nt_seq, aa_mask.repeat_interleave(3)) if mask_val
    )

def iter_codons(nt_seq):
    """Iterate over the codons in a nucleotide sequence."""
    for i in range(0, len(nt_seq), 3):
        yield nt_seq[i : i + 3]
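
A short sketch of how the two new helpers compose (toy sequence, not from the
repo):

    import torch
    from netam.sequences import apply_aa_mask_to_nt_sequence, iter_codons

    list(iter_codons("ATGAANGGG"))  # ['ATG', 'AAN', 'GGG']

    # Each amino-acid mask entry is expanded to cover its codon's 3 nt,
    # so masking the middle codon drops "AAN" from the sequence.
    aa_mask = torch.tensor([True, False, True])
    apply_aa_mask_to_nt_sequence("ATGAANGGG", aa_mask)  # -> 'ATGGGG'
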
2 changes: 0 additions & 2 deletions tests/test_ambiguous.py
@@ -147,7 +147,6 @@ def test_dnsm_burrito(ambig_pcp_df, dnsm_model):
        min_learning_rate=0.0001,
    )
    burrito.joint_train(epochs=1, cycle_count=2, training_method="full")
    return burrito


@pytest.fixture
@@ -179,4 +178,3 @@ def test_dasm_burrito(ambig_pcp_df, dasm_model):
    burrito.joint_train(
        epochs=1, cycle_count=2, training_method="full", optimize_bl_first_cycle=False
    )
    return burrito
