Better DASM handling of ambiguous amino acids #68

Merged · 6 commits · Oct 22, 2024
netam/dasm.py: 57 changes (29 additions, 28 deletions)
@@ -103,18 +103,22 @@ def to(self, device):


 def zap_predictions_along_diagonal(predictions, aa_parents_idxs):
-    """Set the diagonal (i.e. no amino acid change) of the predictions tensor to
-    -BIG."""
-    # This is effectively
-    # predictions[torch.arange(len(aa_parents_idxs)), aa_parents_idxs] = -BIG
-    # but we have a batch dimension. Thus the following.
+    """Set the diagonal (i.e. no amino acid change) of the predictions tensor to -BIG,
+    except where aa_parents_idxs >= 20, which indicates no update should be done."""
+    device = predictions.device
     batch_size, L, _ = predictions.shape
-    batch_indices = torch.arange(batch_size, device=predictions.device)
+    batch_indices = torch.arange(batch_size, device=device)[:, None].expand(-1, L)
+    sequence_indices = torch.arange(L, device=device)[None, :].expand(batch_size, -1)
+
+    # Create a mask for valid positions (where aa_parents_idxs is less than 20)
+    valid_mask = aa_parents_idxs < 20
+
+    # Only update the predictions for valid positions
     predictions[
-        batch_indices[:, None],
-        torch.arange(L, device=predictions.device),
-        aa_parents_idxs,
+        batch_indices[valid_mask],
+        sequence_indices[valid_mask],
+        aa_parents_idxs[valid_mask],
     ] = -BIG

     return predictions
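
Note (not part of the diff): a minimal runnable sketch of the masked indexing above, reproducing the new function so it can be exercised standalone. The value of BIG and the toy tensors are assumptions for illustration, not taken from netam.

    import torch

    BIG = 1e9  # assumed stand-in for netam's large positive constant

    def zap_predictions_along_diagonal(predictions, aa_parents_idxs):
        # Same logic as the new code in the hunk above.
        device = predictions.device
        batch_size, L, _ = predictions.shape
        batch_indices = torch.arange(batch_size, device=device)[:, None].expand(-1, L)
        sequence_indices = torch.arange(L, device=device)[None, :].expand(batch_size, -1)
        valid_mask = aa_parents_idxs < 20  # index 20 encodes an ambiguous parent (X)
        predictions[
            batch_indices[valid_mask],
            sequence_indices[valid_mask],
            aa_parents_idxs[valid_mask],
        ] = -BIG
        return predictions

    # One sequence of length 3 over 20 amino acids; the last site has an
    # ambiguous parent. Without valid_mask, index 20 would be out of bounds.
    predictions = torch.zeros(1, 3, 20)
    aa_parents_idxs = torch.tensor([[0, 5, 20]])
    zapped = zap_predictions_along_diagonal(predictions, aa_parents_idxs)
    assert zapped[0, 0, 0] == -BIG and zapped[0, 1, 5] == -BIG  # diagonal zapped
    assert (zapped[0, 2] == 0).all()  # ambiguous site left untouched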
@@ -162,33 +166,29 @@ def predictions_of_batch(self, batch):

     def loss_of_batch(self, batch):
         aa_subs_indicator = batch["subs_indicator"].to(self.device)
+        # Netam issue #16: child mask would be preferable here.
         mask = batch["mask"].to(self.device)
         aa_parents_idxs = batch["aa_parents_idxs"].to(self.device)
         aa_children_idxs = batch["aa_children_idxs"].to(self.device)
         masked_aa_subs_indicator = aa_subs_indicator.masked_select(mask)
         predictions = self.predictions_of_batch(batch)
-        # Add one entry, zero, to the last dimension of the predictions tensor
-        # to handle the ambiguous amino acids. This is the conservative choice.
-        # It might be faster to reassign all the 20s to 0s if we are confident
-        # in our masking. Perhaps we should always output a 21st dimension
-        # for the ambiguous amino acids (see issue #16).
-        # Note that we're going to want to have a symbol for the junction
-        # between the heavy and light chains.
-        # If we change something here we should also change the test code
-        # in test_dasm.py::test_zero_diagonal.
-        predictions = torch.cat(
-            [predictions, torch.full_like(predictions[:, :, :1], -BIG)], dim=-1
-        )

+        # "Zapping" out the diagonal means setting it to zero in log space by
+        # setting it to -BIG. This is a no-op for sites that have an X
+        # (ambiguous AA) in the parent. This could cause problems in principle,
+        # but in practice we mask out sites with Xs in the parent for the
+        # subs_pos_loss, and we mask out sites with no substitution for the CSP
+        # loss. The latter class of sites also eliminates sites that have Xs in
+        # the parent or child (see sequences.aa_subs_indicator_tensor_of).
         predictions = zap_predictions_along_diagonal(predictions, aa_parents_idxs)

         # After zapping out the diagonal, we can effectively sum over the
         # off-diagonal elements to get the probability of a nonsynonymous
-        # mutation.
-        mut_pos_pred = torch.sum(torch.exp(predictions), dim=-1)
-        mut_pos_pred = mut_pos_pred.masked_select(mask)
-        mut_pos_pred = clamp_probability(mut_pos_pred)
-        mut_pos_loss = self.bce_loss(mut_pos_pred, masked_aa_subs_indicator)
+        # substitution.
+        subs_pos_pred = torch.sum(torch.exp(predictions), dim=-1)
+        subs_pos_pred = subs_pos_pred.masked_select(mask)
+        subs_pos_pred = clamp_probability(subs_pos_pred)
+        subs_pos_loss = self.bce_loss(subs_pos_pred, masked_aa_subs_indicator)

         # We now need to calculate the conditional substitution probability
         # (CSP) loss. We have already zapped out the diagonal, and we're in
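
Note (not part of the diff): a numeric sanity check of the comment above, with illustrative values. Once the parent's entry is set to -BIG, exp() of that entry vanishes, so summing exp over the last axis leaves exactly the off-diagonal probability mass.

    import torch

    BIG = 1e9  # assumed stand-in for netam's constant

    # One site with uniform log-probability over the 20 amino acids, parent at
    # index 3. Zapping the diagonal and summing exp gives the probability that
    # the site substitutes to any other amino acid.
    log_probs = torch.full((1, 1, 20), 0.05).log()
    log_probs[0, 0, 3] = -BIG

    subs_pos_pred = torch.sum(torch.exp(log_probs), dim=-1)
    print(subs_pos_pred)  # tensor([[0.9500]]): 1 minus the 0.05 on the diagonal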
@@ -200,11 +200,12 @@ def loss_of_batch(self, batch):
         csp_targets = aa_children_idxs[subs_mask]
         csp_loss = self.xent_loss(csp_pred, csp_targets)

-        return torch.stack([mut_pos_loss, csp_loss])
+        return torch.stack([subs_pos_loss, csp_loss])

     def build_selection_matrix_from_parent(self, parent: str):
         # This is simpler than the equivalent in dnsm.py because we get the selection
-        # matrix directly.
+        # matrix directly. Note that selection_factors_of_aa_str does the exponentiation
+        # so this indeed gives us the selection factors, not the log selection factors.
         parent = translate_sequence(parent)
         selection_factors = self.model.selection_factors_of_aa_str(parent)
         parent_idxs = sequences.aa_idx_array_of_str(parent)
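
Note (not part of the diff): the distinction flagged in the new comment, in miniature, with illustrative numbers rather than netam output. After exponentiation a factor of 1.0 is neutral, so downstream code must not exponentiate again.

    import torch

    log_selection_factors = torch.tensor([0.0, 0.7, -1.2])  # what the network emits
    selection_factors = torch.exp(log_selection_factors)    # what the caller receives
    print(selection_factors)  # tensor([1.0000, 2.0138, 0.3012])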
netam/models.py: 3 changes (2 additions, 1 deletion)

@@ -537,7 +537,8 @@ def __init__(self):
         super().__init__()

     def selection_factors_of_aa_str(self, aa_str: str) -> Tensor:
-        """Do the forward method without gradients from an amino acid string.
+        """Do the forward method then exponentiation without gradients from an amino
+        acid string.

         Args:
             aa_str: A string of amino acids.
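
Note (not part of the diff): a self-contained sketch of the contract this docstring describes. The toy class and the method name selection_factors_of_aa_idxs are invented for the example (it takes indices rather than a string so the sketch runs standalone); only the forward-then-exp-without-gradients pattern comes from the docstring.

    import torch
    import torch.nn as nn

    class ToySelectionModel(nn.Module):
        # Toy stand-in for the real model; forward() emits log selection factors.
        def __init__(self, n_aa: int = 20):
            super().__init__()
            self.embed = nn.Embedding(n_aa, 1)

        def forward(self, aa_idxs: torch.Tensor) -> torch.Tensor:
            return self.embed(aa_idxs).squeeze(-1)

        def selection_factors_of_aa_idxs(self, aa_idxs: torch.Tensor) -> torch.Tensor:
            # Forward, then exponentiation, without gradients.
            with torch.no_grad():
                return torch.exp(self(aa_idxs))

    model = ToySelectionModel()
    factors = model.selection_factors_of_aa_idxs(torch.tensor([0, 4, 7]))
    assert (factors > 0).all()  # exponentiation guarantees positive factors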