fix dasm, preliminarily
willdumm committed Nov 20, 2024
1 parent d434a81 · commit 23bc490
Showing 4 changed files with 12 additions and 13 deletions.
netam/dasm.py: 7 changes (6 additions & 1 deletion)
@@ -36,8 +36,10 @@ def update_neutral_probs(self):

mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len])
nt_csps = nt_csps[:parent_len, :]
molevol.check_csps(parent_idxs, nt_csps)
nt_mask = mask.repeat_interleave(3)[: len(nt_parent)]
molevol.check_csps(parent_idxs[nt_mask], nt_csps[:len(nt_parent)][nt_mask])

# TODO don't we need to pass multihit model in here?
neutral_aa_probs = molevol.neutral_aa_probs(
parent_idxs.reshape(-1, 3),
mut_probs.reshape(-1, 3),
@@ -201,6 +203,9 @@ def build_selection_matrix_from_parent(self, parent: str):
# so this indeed gives us the selection factors, not the log selection factors.
parent = sequences.translate_sequence(parent)
per_aa_selection_factors = self.model.selection_factors_of_aa_str(parent)

# TODO this nonsense output will need to get masked
parent = parent.replace("X", "A")
parent_idxs = sequences.aa_idx_array_of_str(parent)
per_aa_selection_factors[torch.arange(len(parent_idxs)), parent_idxs] = 1.0

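For orientation, here is a minimal sketch of the X-to-A placeholder trick and the selection-factor pinning in the hunk above. It uses a toy amino-acid alphabet and random factors instead of netam's `sequences` helpers, and, as the TODO notes, the row at the replaced "X" is nonsense that has to be masked downstream.

import torch

AA_ALPHABET = "ACDEFGHIKLMNPQRSTVWY"  # toy 20-letter amino-acid ordering

parent = "MKXV"                        # "X" has no index in a 20-letter alphabet
parent = parent.replace("X", "A")      # placeholder so indexing works; that row is junk
parent_idxs = torch.tensor([AA_ALPHABET.index(c) for c in parent])

per_aa_selection_factors = torch.rand(len(parent), len(AA_ALPHABET))
# Pin the parent residue's own selection factor to 1.0 at every site.
per_aa_selection_factors[torch.arange(len(parent_idxs)), parent_idxs] = 1.0
print(per_aa_selection_factors[0, AA_ALPHABET.index("M")])  # tensor(1.)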
netam/dnsm.py: 6 changes (3 additions & 3 deletions)
@@ -45,14 +45,14 @@ def update_neutral_probs(self):
multihit_model = None
# Note we are replacing all Ns with As, which means that we need to be careful
# with masking out these positions later. We do this below.
# TODO Figure out how we're really going to handle masking, because
# old method allowed some nt N's to be unmasked.
nt_mask = mask.repeat_interleave(3)[: len(nt_parent)]
# nt_mask = torch.tensor([it != "N" for it in nt_parent], dtype=torch.bool)
parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
parent_len = len(nt_parent)
# Cannot assume that nt_csps and mask are same length, because when
# datasets are split, masks are recomputed.
# TODO Figure out how we're really going to handle masking, because
# old method allowed some nt N's to be unmasked.
nt_mask = mask.repeat_interleave(3)[: len(nt_parent)]
molevol.check_csps(parent_idxs[nt_mask], nt_csps[:len(nt_parent)][nt_mask])
# molevol.check_csps(parent_idxs[nt_mask], nt_csps[:len(parent_idxs)][nt_mask])

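Both dasm.py and dnsm.py now check CSPs only at unmasked nucleotide positions, expanding the amino-acid-level mask with `repeat_interleave(3)`. A toy sketch of that expansion follows; `check_valid_csps` is a hypothetical stand-in for `molevol.check_csps`, assumed here (for illustration only) to verify that each parent base is given zero conditional substitution probability.

import torch

def check_valid_csps(parent_idxs, csps):
    # Hypothetical stand-in: each row must give probability zero to the parent base.
    assert torch.all(csps[torch.arange(len(parent_idxs)), parent_idxs] == 0.0)

nt_parent = "ATGNNNTGA"                      # middle codon is all-ambiguous
aa_mask = torch.tensor([True, False, True])  # amino-acid mask; codon 2 masked out
nt_mask = aa_mask.repeat_interleave(3)[: len(nt_parent)]

# Upstream, N's were replaced with A's, so their CSP rows are junk; restricting
# both arguments to unmasked positions keeps the check meaningful.
parent_idxs = torch.tensor([0, 3, 2, 0, 0, 0, 3, 2, 0])  # ACGT ordering, N -> A
nt_csps = torch.rand(len(nt_parent), 4)
nt_csps[torch.arange(len(nt_parent)), parent_idxs] = 0.0
nt_csps[3:6] = 0.25                          # junk rows at the masked codon
check_valid_csps(parent_idxs[nt_mask], nt_csps[: len(nt_parent)][nt_mask])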
netam/dxsm.py: 10 changes (1 addition & 9 deletions)
@@ -252,16 +252,8 @@ def _find_optimal_branch_length(
multihit_model,
**optimization_kwargs,
):
# TODO: This doesn't seem quite right, because we'll mask whole codons
# if they contain just one ambiguity, even when we know they also
# contain a substitution.
if all(p_c == c_c for idx, (p_c, c_c) in enumerate(zip(parent, child)) if aa_mask[idx // 3]):
print("Parent and child are the same when codons containing N are masked")
assert False
# if parent == child:
# return 0.0
# TODO this doesn't use any mask, couldn't we use already-computed
# aa_parent?
# aa_parent and its mask?
sel_matrix = self.build_selection_matrix_from_parent(parent)
trimmed_aa_mask = aa_mask[: len(sel_matrix)]
log_pcp_probability = molevol.mutsel_log_pcp_probability_of(
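The guard in this hunk compares parent and child only at positions whose codon passes the amino-acid mask; as the TODO notes, a codon containing a single ambiguity is masked out even if it also carries a substitution. A standalone toy version of that comparison (plain strings and a hand-made mask, not netam's dataset objects):

import torch

parent = "ATGNNNTGT"
child = "ATGCCCTGT"                          # differs only inside the masked codon
aa_mask = torch.tensor([True, False, True])

# Positions whose codon is masked are skipped, so parent and child compare as
# identical even though the raw strings differ.
same_when_masked = all(
    p_c == c_c
    for idx, (p_c, c_c) in enumerate(zip(parent, child))
    if aa_mask[idx // 3]
)
print(same_when_masked)                      # True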
netam/models.py: 2 changes (2 additions & 0 deletions)
@@ -552,6 +552,8 @@ def selection_factors_of_aa_str(self, aa_str: str) -> Tensor:

aa_idxs = aa_idx_tensor_of_str_ambig(aa_str)
aa_idxs = aa_idxs.to(model_device)
# TODO: Shouldn't we be using the new codon mask here, and allowing
# a pre-computed mask to be passed in?
mask = aa_mask_tensor_of(aa_str)
mask = mask.to(model_device)

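The TODO above asks whether `aa_mask_tensor_of` should give way to a codon-aware or pre-computed mask. For orientation only, here is a guess at the kind of per-residue mask such a helper might compute; the real netam implementation may differ.

import torch

def toy_aa_mask_tensor_of(aa_str: str) -> torch.Tensor:
    # Mark every concrete residue as valid; ambiguous "X" positions are masked out.
    return torch.tensor([c != "X" for c in aa_str], dtype=torch.bool)

print(toy_aa_mask_tensor_of("MKVX"))  # tensor([ True,  True,  True, False])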
