Commit: working test
willdumm committed Nov 19, 2024
1 parent 3fed1d8 commit 152367c
Showing 4 changed files with 97 additions and 17 deletions.
netam/dnsm.py: 12 changes (11 additions, 1 deletion)
@@ -29,6 +29,7 @@ def update_neutral_probs(self):
         """
         neutral_aa_mut_prob_l = []
 
+        print("starting update_neutral_probs loop")
         for nt_parent, mask, nt_rates, nt_csps, branch_length in zip(
             self.nt_parents,
             self.masks,
@@ -45,9 +46,16 @@ def update_neutral_probs(self):
                 multihit_model = None
             # Note we are replacing all Ns with As, which means that we need to be careful
             # with masking out these positions later. We do this below.
+            # TODO Figure out how we're really going to handle masking, because the
+            # old method allowed some nt Ns to be unmasked.
+            nt_mask = mask.repeat_interleave(3)[: len(nt_parent)]
+            # nt_mask = torch.tensor([it != "N" for it in nt_parent], dtype=torch.bool)
             parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
             parent_len = len(nt_parent)
-            molevol.check_csps(parent_idxs, nt_csps)
+            # Cannot assume that nt_csps and mask are the same length, because when
+            # datasets are split, masks are recomputed.
+            molevol.check_csps(parent_idxs[nt_mask], nt_csps[: len(nt_parent)][nt_mask])
+            # molevol.check_csps(parent_idxs[nt_mask], nt_csps[: len(parent_idxs)][nt_mask])
 
             mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len])
             nt_csps = nt_csps[:parent_len, :]
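
For intuition, the repeat_interleave call above expands the codon-level mask to nucleotide level. A minimal sketch with made-up values (not part of the commit):

    import torch

    aa_mask = torch.tensor([True, False, True])  # one entry per codon
    nt_parent = "ACGNNNAC"  # 8 nt, so the final codon is incomplete
    # Repeat each codon entry three times, then trim to the nt sequence length.
    nt_mask = aa_mask.repeat_interleave(3)[: len(nt_parent)]
    # tensor([ True,  True,  True, False, False, False,  True,  True])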
@@ -160,6 +168,8 @@ def build_selection_matrix_from_parent(self, parent: str):
         # Every "off-diagonal" entry of the selection matrix is set to the selection
         # factor, where "diagonal" means keeping the same amino acid.
         selection_matrix[:, :] = selection_factors[:, None]
+        # TODO this nonsense output will need to get masked
+        parent = parent.replace("X", "A")
         # Set "diagonal" elements to one.
         parent_idxs = sequences.aa_idx_array_of_str(parent)
         selection_matrix[torch.arange(len(parent_idxs)), parent_idxs] = 1.0
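
As a worked example of the selection-matrix construction above, with made-up numbers (a sketch, not part of the commit):

    import torch

    selection_factors = torch.tensor([0.5, 2.0])  # one factor per amino acid site
    selection_matrix = torch.empty(2, 20)
    selection_matrix[:, :] = selection_factors[:, None]  # broadcast over the 20 aas
    parent_idxs = torch.tensor([3, 7])  # parent amino acid index at each site
    # "Diagonal" entries, i.e. keeping the parent amino acid, are set to one.
    selection_matrix[torch.arange(2), parent_idxs] = 1.0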
netam/dxsm.py: 56 changes (47 additions, 9 deletions)
@@ -29,6 +29,20 @@
     translate_sequences,
 )
 
+def cautious_mask_tensor_of(nt_str, aa_length):
+    """Return a codon-level mask which is False for codons containing at least one
+    N or lying beyond the end of the sequence, and True otherwise.
+    """
+    if aa_length is None:
+        aa_length = len(nt_str) // 3
+    mask = ["N" not in nt_str[i * 3 : (i + 1) * 3] for i in range(len(nt_str) // 3)]
+    if len(mask) < aa_length:
+        mask += [False] * (aa_length - len(mask))
+    else:
+        mask = mask[:aa_length]
+    assert len(mask) == aa_length
+    return torch.tensor(mask, dtype=torch.bool)
+
 class DXSMDataset(Dataset, ABC):
     prefix = "dxsm"
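
A quick usage sketch of the new helper (illustrative values only):

    # Codon 0 is "ACG" (clean) and codon 1 is "ANA" (contains an N); the
    # trailing "AC" is not a complete codon, and padding out to aa_length is False.
    cautious_mask_tensor_of("ACGANAAC", aa_length=4)
    # tensor([ True, False, False, False])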
@@ -42,6 +56,7 @@ def __init__(
         branch_lengths: torch.Tensor,
         multihit_model=None,
     ):
+        print("starting DXSMDataset init")
         self.nt_parents = nt_parents
         self.nt_children = nt_children
         self.nt_ratess = nt_ratess
@@ -64,6 +79,7 @@ def __init__(
         aa_parents = translate_sequences(self.nt_parents)
         aa_children = translate_sequences(self.nt_children)
         self.max_aa_seq_len = max(len(seq) for seq in aa_parents)
+        print("max_aa_seq_len:", self.max_aa_seq_len)
         # We have sequences of varying length, so we start with all tensors set
         # to the ambiguous amino acid, and then will fill in the actual values
         # below.
@@ -76,7 +92,10 @@ def __init__(
         self.masks = torch.ones((pcp_count, self.max_aa_seq_len), dtype=torch.bool)
 
         for i, (aa_parent, aa_child) in enumerate(zip(aa_parents, aa_children)):
-            self.masks[i, :] = aa_mask_tensor_of(aa_parent, self.max_aa_seq_len)
+            # self.masks[i, :] = aa_mask_tensor_of(aa_parent, self.max_aa_seq_len)
+            # TODO Figure out how we're really going to handle masking
+            self.masks[i, :] = cautious_mask_tensor_of(nt_parents[i], self.max_aa_seq_len)
+
             aa_seq_len = len(aa_parent)
             self.aa_parents_idxss[i, :aa_seq_len] = aa_idx_tensor_of_str_ambig(
                 aa_parent
@@ -248,32 +267,49 @@ def _find_optimal_branch_length(
         child,
         nt_rates,
         nt_csps,
+        aa_mask,
         starting_branch_length,
         multihit_model,
         **optimization_kwargs,
     ):
-        if parent == child:
-            return 0.0
+        # TODO: This doesn't seem quite right, because we'll mask whole codons
+        # if they contain just one ambiguity, even when we know they also
+        # contain a substitution.
+        if all(p_c == c_c for idx, (p_c, c_c) in enumerate(zip(parent, child)) if aa_mask[idx // 3]):
+            print("Parent and child are the same when codons containing N are masked")
+            assert False
+        # if parent == child:
+        #     return 0.0
+        # TODO this doesn't use any mask; couldn't we use the already-computed
+        # aa_parent?
         sel_matrix = self.build_selection_matrix_from_parent(parent)
         log_pcp_probability = molevol.mutsel_log_pcp_probability_of(
-            sel_matrix, parent, child, nt_rates, nt_csps, multihit_model
+            sel_matrix, parent, child, nt_rates, nt_csps, aa_mask[: len(sel_matrix)], multihit_model
         )
         if isinstance(starting_branch_length, torch.Tensor):
             starting_branch_length = starting_branch_length.detach().item()
-        return molevol.optimize_branch_length(
+        res = molevol.optimize_branch_length(
             log_pcp_probability, starting_branch_length, **optimization_kwargs
         )
+        if np.isclose(res[0], 0.0):
+            print("Optimization converged to 0.0")
+            print("parent:", parent)
+            print("child:", child)
+            assert False
+        else:
+            return res
 
     def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
         optimal_lengths = []
         failed_count = 0
 
-        for parent, child, nt_rates, nt_csps, starting_length in tqdm(
+        for parent, child, nt_rates, nt_csps, aa_mask, starting_length in tqdm(
             zip(
                 dataset.nt_parents,
                 dataset.nt_children,
                 dataset.nt_ratess,
                 dataset.nt_cspss,
+                dataset.masks,
                 dataset.branch_lengths,
             ),
             total=len(dataset.nt_parents),
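
To see what the masked-equality guard above catches, consider this sketch (hypothetical sequences, not part of the commit):

    parent = "AAATNT"  # codon 1 contains an N, so it is masked
    child = "AAATTA"
    aa_mask = [True, False]
    all(p == c for idx, (p, c) in enumerate(zip(parent, child)) if aa_mask[idx // 3])
    # -> True: parent != child, but the pair looks identical once the
    # N-containing codon is ignored, which is the case the assert flags.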
@@ -284,6 +320,7 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
                 child,
                 nt_rates[: len(parent)],
                 nt_csps[: len(parent), :],
+                aa_mask,
                 starting_length,
                 dataset.multihit_model,
                 **optimization_kwargs,
@@ -301,9 +338,10 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
 
     def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
         worker_count = min(mp.cpu_count() // 2, 10)
-        # # The following can be used when one wants a better traceback.
-        # burrito = self.__class__(None, dataset, copy.deepcopy(self.model))
-        # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
+        # The following can be used when one wants a better traceback.
+        burrito = self.__class__(None, dataset, copy.deepcopy(self.model))
+        return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
+
         our_optimize_branch_length = partial(
             worker_optimize_branch_length,
             self.__class__,
netam/molevol.py: 15 changes (9 additions, 6 deletions)
@@ -434,7 +434,7 @@ def neutral_aa_mut_probs(
 
 
 def mutsel_log_pcp_probability_of(
-    sel_matrix, parent, child, nt_rates, nt_csps, multihit_model=None
+    sel_matrix, parent, child, nt_rates, nt_csps, aa_mask, multihit_model=None
 ):
     """Constructs the log_pcp_probability function specific to given nt_rates and
     nt_csps.
@@ -446,6 +446,9 @@ def mutsel_log_pcp_probability_of(
     assert len(parent) % 3 == 0
     assert sel_matrix.shape == (len(parent) // 3, 20)
 
+    # This is masked out later.
+    parent = parent.replace("N", "A")
+    child = child.replace("N", "A")
     parent_idxs = sequences.nt_idx_tensor_of_str(parent)
     child_idxs = sequences.nt_idx_tensor_of_str(child)
 
@@ -454,18 +457,18 @@ def log_pcp_probability(log_branch_length: torch.Tensor):
         nt_mut_probs = 1.0 - torch.exp(-branch_length * nt_rates)
 
         codon_mutsel, sums_too_big = build_codon_mutsel(
-            parent_idxs.reshape(-1, 3),
-            nt_mut_probs.reshape(-1, 3),
-            nt_csps.reshape(-1, 3, 4),
-            sel_matrix,
+            parent_idxs.reshape(-1, 3)[aa_mask],
+            nt_mut_probs.reshape(-1, 3)[aa_mask],
+            nt_csps.reshape(-1, 3, 4)[aa_mask],
+            sel_matrix[aa_mask],
             multihit_model=multihit_model,
         )
 
         # This is a diagnostic generating data for netam issue #7.
         # if sums_too_big is not None:
         #     self.csv_file.write(f"{parent},{child},{branch_length},{sums_too_big}\n")
 
-        reshaped_child_idxs = child_idxs.reshape(-1, 3)
+        reshaped_child_idxs = child_idxs.reshape(-1, 3)[aa_mask]
         child_prob_vector = codon_mutsel[
             torch.arange(len(reshaped_child_idxs)),
             reshaped_child_idxs[:, 0],
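
For intuition, indexing the codon-reshaped tensors with aa_mask above simply drops masked codons. A minimal sketch (not part of the commit):

    import torch

    nt_rates = torch.rand(9)  # 9 nucleotides = 3 codons
    aa_mask = torch.tensor([True, False, True])
    per_codon = nt_rates.reshape(-1, 3)  # shape (3, 3): one row per codon
    kept = per_codon[aa_mask]  # shape (2, 3): the masked codon is dropped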
tests/test_ambiguous.py: 31 changes (30 additions, 1 deletion)
@@ -23,8 +23,15 @@
 
 
 # Function to randomly insert 'N' in sequences
-def randomize_with_ns(parent_seq, child_seq):
+def randomize_with_ns(parent_seq, child_seq, avoid_masked_equality=True):
+    old_parent = parent_seq
+    old_child = child_seq
     seq_length = len(parent_seq)
+    try:
+        first_mut = next((idx, p, c) for idx, (p, c) in enumerate(zip(parent_seq, child_seq)) if p != c)
+    except StopIteration:
+        return parent_seq, child_seq
+
 
     # Decide which type of modification to apply
     modification_type = random.choice(["same_site", "different_site", "chunk", "none"])
@@ -74,6 +81,27 @@ def randomize_with_ns(parent_seq, child_seq):
             + child_seq[start_pos + chunk_size :]
         )
 
+    if parent_seq == child_seq:
+        # If the sequences are the same, put one mutated site back in:
+        idx, p, c = first_mut
+        parent_seq = parent_seq[:idx] + p + parent_seq[idx + 1 :]
+        child_seq = child_seq[:idx] + c + child_seq[idx + 1 :]
+    if avoid_masked_equality:
+        codon_pairs = [
+            (parent_seq[i * 3 : (i + 1) * 3], child_seq[i * 3 : (i + 1) * 3])
+            for i in range(seq_length // 3)
+        ]
+        if all(p == c for p, c in filter(lambda pair: "N" not in pair[0] and "N" not in pair[1], codon_pairs)):
+            # Put the original codon containing a mutation back in.
+            idx, p, c = first_mut
+            codon_start = (idx // 3) * 3
+            codon_end = codon_start + 3
+            parent_seq = parent_seq[:codon_start] + old_parent[codon_start:codon_end] + parent_seq[codon_end:]
+            child_seq = child_seq[:codon_start] + old_child[codon_start:codon_end] + child_seq[codon_end:]
+
+    assert len(parent_seq) == len(child_seq)
+    assert len(parent_seq) == seq_length
+
     return parent_seq, child_seq


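Illustrative use of the updated helper (the sequences are made up, and the exact output varies with the random seed):

    import random

    random.seed(0)
    parent, child = randomize_with_ns("ACGTGGACT", "ACGTGGAGT")
    # Lengths are preserved, and with avoid_masked_equality=True the pair is
    # never identical once N-containing codons are masked out.
    assert len(parent) == len(child) == 9
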
@@ -104,6 +132,7 @@ def dnsm_model():
 
 def test_dnsm_burrito(ambig_pcp_df, dnsm_model):
     """Fixture that returns the DNSM Burrito object."""
+    # TODO fix, and also make this work with randomize_with_ns avoid_masked_equality=False
     force_spawn()
     ambig_pcp_df["in_train"] = True
     ambig_pcp_df.loc[ambig_pcp_df.index[-15:], "in_train"] = False
