Skip to content

Commit

Permalink
fix todo
Browse files Browse the repository at this point in the history
  • Loading branch information
matsen committed Oct 16, 2024
1 parent 3183d76 commit 6190196
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
2 changes: 1 addition & 1 deletion netam/dasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def loss_of_batch(self, batch):
# the sites for which aa_subs_indicator is 0.
subs_mask = aa_subs_indicator.bool()
csp_pred = predictions[subs_mask]
csp_targets = aa_children_idxs.masked_select(subs_mask)
csp_targets = aa_children_idxs[subs_mask]
csp_loss = self.xent_loss(csp_pred, csp_targets)

return torch.stack([mut_pos_loss, csp_loss])
Expand Down
5 changes: 2 additions & 3 deletions netam/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,8 +907,7 @@ def loss_of_batch(self, batch):
mut_prob = 1 - torch.exp(-rates * branch_lengths.unsqueeze(-1))
mut_prob_masked = mut_prob[masks]
mutation_indicator_masked = mutation_indicators[masks].float()
# TODO call this mut_pos_loss?
rate_loss = self.bce_loss(mut_prob_masked, mutation_indicator_masked)
mut_pos_loss = self.bce_loss(mut_prob_masked, mutation_indicator_masked)

# Conditional substitution probability (CSP) loss calculation
# Mask the new_base_idxs to focus only on positions with mutations
Expand All @@ -920,7 +919,7 @@ def loss_of_batch(self, batch):
assert (new_base_idxs_masked >= 0).all()
csp_loss = self.xent_loss(csp_logits_masked, new_base_idxs_masked)

return torch.stack([rate_loss, csp_loss])
return torch.stack([mut_pos_loss, csp_loss])

def _find_optimal_branch_length(
self,
Expand Down

0 comments on commit 6190196

Please sign in to comment.