Skip to content

Commit

Permalink
renaming
Browse files Browse the repository at this point in the history
  • Loading branch information
matsen committed Oct 21, 2024
1 parent bf749ed commit ce47fff
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions netam/dasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def loss_of_batch(self, batch):
subs_pos_pred = torch.sum(torch.exp(predictions), dim=-1)
subs_pos_pred = subs_pos_pred.masked_select(mask)
subs_pos_pred = clamp_probability(subs_pos_pred)
mut_pos_loss = self.bce_loss(subs_pos_pred, masked_aa_subs_indicator)
subs_pos_loss = self.bce_loss(subs_pos_pred, masked_aa_subs_indicator)

# We now need to calculate the conditional substitution probability
# (CSP) loss. We have already zapped out the diagonal, and we're in
Expand All @@ -200,7 +200,7 @@ def loss_of_batch(self, batch):
csp_targets = aa_children_idxs[subs_mask]
csp_loss = self.xent_loss(csp_pred, csp_targets)

return torch.stack([mut_pos_loss, csp_loss])
return torch.stack([subs_pos_loss, csp_loss])

def build_selection_matrix_from_parent(self, parent: str):
# This is simpler than the equivalent in dnsm.py because we get the selection
Expand Down

0 comments on commit ce47fff

Please sign in to comment.