Better DASM handling of ambiguous amino acids (#68)
Previously we effectively ignored ambiguous amino acids by adding an extra fake amino acid, but this caused inconsistency downstream. Here we drop that in favor of more precise zapping.
matsen authored Oct 22, 2024
1 parent f204f6b commit cb17df7
Showing 2 changed files with 31 additions and 29 deletions.
netam/dasm.py: 29 additions & 28 deletions
@@ -103,18 +103,22 @@ def to(self, device):
 
 
 def zap_predictions_along_diagonal(predictions, aa_parents_idxs):
-    """Set the diagonal (i.e. no amino acid change) of the predictions tensor to
-    -BIG."""
-    # This is effectively
-    # predictions[torch.arange(len(aa_parents_idxs)), aa_parents_idxs] = -BIG
-    # but we have a batch dimension. Thus the following.
+    """Set the diagonal (i.e. no amino acid change) of the predictions tensor to -BIG,
+    except where aa_parents_idxs >= 20, which indicates no update should be done."""
 
+    device = predictions.device
     batch_size, L, _ = predictions.shape
-    batch_indices = torch.arange(batch_size, device=predictions.device)
+    batch_indices = torch.arange(batch_size, device=device)[:, None].expand(-1, L)
+    sequence_indices = torch.arange(L, device=device)[None, :].expand(batch_size, -1)
+
+    # Create a mask for valid positions (where aa_parents_idxs is less than 20)
+    valid_mask = aa_parents_idxs < 20
+
+    # Only update the predictions for valid positions
     predictions[
-        batch_indices[:, None],
-        torch.arange(L, device=predictions.device),
-        aa_parents_idxs,
+        batch_indices[valid_mask],
+        sequence_indices[valid_mask],
+        aa_parents_idxs[valid_mask],
     ] = -BIG
 
     return predictions
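To make the new masked indexing concrete, here is a minimal runnable sketch of the updated function, assuming a stand-in value for netam's BIG constant and a single-sequence batch in which one site carries the ambiguity token (index 20). The key point is that batch_indices[valid_mask], sequence_indices[valid_mask], and aa_parents_idxs[valid_mask] are parallel 1-D index tensors, so the advanced indexing writes -BIG only at unambiguous sites:

    import torch

    BIG = 1e6  # stand-in for netam's BIG constant, assumed here for illustration

    def zap_predictions_along_diagonal(predictions, aa_parents_idxs):
        """Set the diagonal of predictions to -BIG, skipping ambiguous sites (>= 20)."""
        device = predictions.device
        batch_size, L, _ = predictions.shape
        batch_indices = torch.arange(batch_size, device=device)[:, None].expand(-1, L)
        sequence_indices = torch.arange(L, device=device)[None, :].expand(batch_size, -1)
        valid_mask = aa_parents_idxs < 20
        predictions[
            batch_indices[valid_mask],
            sequence_indices[valid_mask],
            aa_parents_idxs[valid_mask],
        ] = -BIG
        return predictions

    # One sequence of length 3; the third site holds the ambiguity token (20).
    predictions = torch.zeros(1, 3, 20)
    aa_parents_idxs = torch.tensor([[0, 5, 20]])
    zapped = zap_predictions_along_diagonal(predictions, aa_parents_idxs)
    assert zapped[0, 0, 0] == -BIG and zapped[0, 1, 5] == -BIG
    assert (zapped[0, 2] == 0).all()  # ambiguous site left untouched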
@@ -162,33 +166,29 @@ def predictions_of_batch(self, batch):
 
     def loss_of_batch(self, batch):
         aa_subs_indicator = batch["subs_indicator"].to(self.device)
+        # Netam issue #16: child mask would be preferable here.
         mask = batch["mask"].to(self.device)
         aa_parents_idxs = batch["aa_parents_idxs"].to(self.device)
         aa_children_idxs = batch["aa_children_idxs"].to(self.device)
         masked_aa_subs_indicator = aa_subs_indicator.masked_select(mask)
         predictions = self.predictions_of_batch(batch)
-        # Add one entry, zero, to the last dimension of the predictions tensor
-        # to handle the ambiguous amino acids. This is the conservative choice.
-        # It might be faster to reassign all the 20s to 0s if we are confident
-        # in our masking. Perhaps we should always output a 21st dimension
-        # for the ambiguous amino acids (see issue #16).
-        # Note that we're going to want to have a symbol for the junction
-        # between the heavy and light chains.
-        # If we change something here we should also change the test code
-        # in test_dasm.py::test_zero_diagonal.
-        predictions = torch.cat(
-            [predictions, torch.full_like(predictions[:, :, :1], -BIG)], dim=-1
-        )
 
+        # "Zapping" out the diagonal means setting it to zero in log space by
+        # setting it to -BIG. This is a no-op for sites that have an X
+        # (ambiguous AA) in the parent. This could cause problems in principle,
+        # but in practice we mask out sites with Xs in the parent for the
+        # mut_pos_loss, and we mask out sites with no substitution for the CSP
+        # loss. The latter class of sites also eliminates sites that have Xs in
+        # the parent or child (see sequences.aa_subs_indicator_tensor_of).
         predictions = zap_predictions_along_diagonal(predictions, aa_parents_idxs)
 
         # After zapping out the diagonal, we can effectively sum over the
         # off-diagonal elements to get the probability of a nonsynonymous
-        # mutation.
-        mut_pos_pred = torch.sum(torch.exp(predictions), dim=-1)
-        mut_pos_pred = mut_pos_pred.masked_select(mask)
-        mut_pos_pred = clamp_probability(mut_pos_pred)
-        mut_pos_loss = self.bce_loss(mut_pos_pred, masked_aa_subs_indicator)
+        # substitution.
+        subs_pos_pred = torch.sum(torch.exp(predictions), dim=-1)
+        subs_pos_pred = subs_pos_pred.masked_select(mask)
+        subs_pos_pred = clamp_probability(subs_pos_pred)
+        subs_pos_loss = self.bce_loss(subs_pos_pred, masked_aa_subs_indicator)
 
         # We now need to calculate the conditional substitution probability
         # (CSP) loss. We have already zapped out the diagonal, and we're in
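As a sanity check on the summation step, here is a toy example (not from the repository) of why summing the exponentiated, zapped predictions over the last dimension yields the probability of a nonsynonymous substitution:

    import torch

    # Toy log-probabilities for one site over the 20 amino acids: a uniform
    # distribution, with the parent's entry (index 0) zapped to a huge
    # negative number so that exp() sends it to zero.
    log_probs = torch.full((20,), 0.05).log()
    log_probs[0] = -1e6  # zapped diagonal entry
    subs_prob = torch.exp(log_probs).sum()
    print(subs_prob)  # ~0.95, i.e. everything except staying at the parent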
@@ -200,11 +200,12 @@ def loss_of_batch(self, batch):
         csp_targets = aa_children_idxs[subs_mask]
         csp_loss = self.xent_loss(csp_pred, csp_targets)
 
-        return torch.stack([mut_pos_loss, csp_loss])
+        return torch.stack([subs_pos_loss, csp_loss])
 
     def build_selection_matrix_from_parent(self, parent: str):
         # This is simpler than the equivalent in dnsm.py because we get the selection
-        # matrix directly.
+        # matrix directly. Note that selection_factors_of_aa_str does the exponentiation
+        # so this indeed gives us the selection factors, not the log selection factors.
         parent = translate_sequence(parent)
         selection_factors = self.model.selection_factors_of_aa_str(parent)
         parent_idxs = sequences.aa_idx_array_of_str(parent)
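For readers following the CSP computation, a self-contained toy version of the cross-entropy step looks like the following; the shapes and indices are made up, and F.cross_entropy stands in for the model's xent_loss:

    import torch
    import torch.nn.functional as F

    # Five sites that actually substituted; predictions are in log space with
    # the parent entry already zapped, as in loss_of_batch above.
    csp_pred = torch.randn(5, 20)
    aa_parents_idxs = torch.tensor([0, 3, 7, 7, 19])
    csp_pred[torch.arange(5), aa_parents_idxs] = -1e6  # zapped diagonal
    csp_targets = torch.tensor([2, 4, 1, 8, 0])  # the child amino acids
    csp_loss = F.cross_entropy(csp_pred, csp_targets)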
netam/models.py: 2 additions & 1 deletion
@@ -537,7 +537,8 @@ def __init__(self):
         super().__init__()
 
     def selection_factors_of_aa_str(self, aa_str: str) -> Tensor:
-        """Do the forward method without gradients from an amino acid string.
+        """Do the forward method then exponentiation without gradients from an amino
+        acid string.
 
         Args:
             aa_str: A string of amino acids.
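A minimal sketch of the documented behavior, assuming a stand-in encoder and model (netam's real versions live in sequences.py and the model classes, and AA_STR_SORTED is an assumed ordering of the 20 amino acids):

    import torch

    AA_STR_SORTED = "ACDEFGHIKLMNPQRSTVWY"

    def aa_idx_tensor_of_str(aa_str: str) -> torch.Tensor:
        # Hypothetical minimal encoder, for illustration only.
        return torch.tensor([AA_STR_SORTED.index(aa) for aa in aa_str])

    def selection_factors_of_aa_str(model, aa_str: str) -> torch.Tensor:
        # Forward pass, then exponentiation, with gradients disabled, matching
        # the updated docstring: the result is in linear (not log) space.
        with torch.no_grad():
            log_factors = model(aa_idx_tensor_of_str(aa_str))
        return torch.exp(log_factors)

    # Stand-in "model" mapping each site to zero log selection factors:
    model = lambda idxs: torch.zeros(len(idxs), 20)
    factors = selection_factors_of_aa_str(model, "ACDY")
    assert torch.allclose(factors, torch.ones(4, 20))  # exp(0) == 1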
