From cb17df767296cbec1dbec72883d84bf743be1955 Mon Sep 17 00:00:00 2001
From: Erick Matsen
Date: Tue, 22 Oct 2024 02:27:56 -0700
Subject: [PATCH] Better DASM handling of ambiguous amino acids (#68)

Previously we effectively ignored ambiguous amino acids by adding an
extra fake amino acid, but this caused inconsistency downstream. Here
we drop that in favor of more precise zapping.
---
 netam/dasm.py   | 57 +++++++++++++++++++++++++------------------
 netam/models.py |  3 ++-
 2 files changed, 31 insertions(+), 29 deletions(-)

diff --git a/netam/dasm.py b/netam/dasm.py
index bac48135..7597aa61 100644
--- a/netam/dasm.py
+++ b/netam/dasm.py
@@ -103,18 +103,22 @@ def to(self, device):
 
 
 def zap_predictions_along_diagonal(predictions, aa_parents_idxs):
-    """Set the diagonal (i.e. no amino acid change) of the predictions tensor to
-    -BIG."""
-    # This is effectively
-    # predictions[torch.arange(len(aa_parents_idxs)), aa_parents_idxs] = -BIG
-    # but we have a batch dimension. Thus the following.
+    """Set the diagonal (i.e. no amino acid change) of the predictions tensor to -BIG,
+    except where aa_parents_idxs >= 20, which indicates no update should be done."""
+    device = predictions.device
     batch_size, L, _ = predictions.shape
-    batch_indices = torch.arange(batch_size, device=predictions.device)
+    batch_indices = torch.arange(batch_size, device=device)[:, None].expand(-1, L)
+    sequence_indices = torch.arange(L, device=device)[None, :].expand(batch_size, -1)
+
+    # Create a mask for valid positions (where aa_parents_idxs is less than 20)
+    valid_mask = aa_parents_idxs < 20
+
+    # Only update the predictions for valid positions
     predictions[
-        batch_indices[:, None],
-        torch.arange(L, device=predictions.device),
-        aa_parents_idxs,
+        batch_indices[valid_mask],
+        sequence_indices[valid_mask],
+        aa_parents_idxs[valid_mask],
     ] = -BIG
     return predictions
 
 
@@ -162,33 +166,29 @@ def predictions_of_batch(self, batch):
 
     def loss_of_batch(self, batch):
         aa_subs_indicator = batch["subs_indicator"].to(self.device)
+        # Netam issue #16: child mask would be preferable here.
         mask = batch["mask"].to(self.device)
         aa_parents_idxs = batch["aa_parents_idxs"].to(self.device)
         aa_children_idxs = batch["aa_children_idxs"].to(self.device)
         masked_aa_subs_indicator = aa_subs_indicator.masked_select(mask)
         predictions = self.predictions_of_batch(batch)
-        # Add one entry, zero, to the last dimension of the predictions tensor
-        # to handle the ambiguous amino acids. This is the conservative choice.
-        # It might be faster to reassign all the 20s to 0s if we are confident
-        # in our masking. Perhaps we should always output a 21st dimension
-        # for the ambiguous amino acids (see issue #16).
-        # Note that we're going to want to have a symbol for the junction
-        # between the heavy and light chains.
-        # If we change something here we should also change the test code
-        # in test_dasm.py::test_zero_diagonal.
-        predictions = torch.cat(
-            [predictions, torch.full_like(predictions[:, :, :1], -BIG)], dim=-1
-        )
+        # "Zapping" out the diagonal means setting it to zero in log space by
+        # setting it to -BIG. This is a no-op for sites that have an X
+        # (ambiguous AA) in the parent. This could cause problems in principle,
+        # but in practice we mask out sites with Xs in the parent for the
+        # mut_pos_loss, and we mask out sites with no substitution for the CSP
+        # loss. The latter class of sites also eliminates sites that have Xs in
+        # the parent or child (see sequences.aa_subs_indicator_tensor_of).
         predictions = zap_predictions_along_diagonal(predictions, aa_parents_idxs)
 
         # After zapping out the diagonal, we can effectively sum over the
         # off-diagonal elements to get the probability of a nonsynonymous
-        # mutation.
-        mut_pos_pred = torch.sum(torch.exp(predictions), dim=-1)
-        mut_pos_pred = mut_pos_pred.masked_select(mask)
-        mut_pos_pred = clamp_probability(mut_pos_pred)
-        mut_pos_loss = self.bce_loss(mut_pos_pred, masked_aa_subs_indicator)
+        # substitution.
+        subs_pos_pred = torch.sum(torch.exp(predictions), dim=-1)
+        subs_pos_pred = subs_pos_pred.masked_select(mask)
+        subs_pos_pred = clamp_probability(subs_pos_pred)
+        subs_pos_loss = self.bce_loss(subs_pos_pred, masked_aa_subs_indicator)
 
         # We now need to calculate the conditional substitution probability
         # (CSP) loss. We have already zapped out the diagonal, and we're in
@@ -200,11 +200,12 @@ def loss_of_batch(self, batch):
         csp_targets = aa_children_idxs[subs_mask]
         csp_loss = self.xent_loss(csp_pred, csp_targets)
 
-        return torch.stack([mut_pos_loss, csp_loss])
+        return torch.stack([subs_pos_loss, csp_loss])
 
     def build_selection_matrix_from_parent(self, parent: str):
         # This is simpler than the equivalent in dnsm.py because we get the selection
-        # matrix directly.
+        # matrix directly. Note that selection_factors_of_aa_str does the exponentiation
+        # so this indeed gives us the selection factors, not the log selection factors.
         parent = translate_sequence(parent)
         selection_factors = self.model.selection_factors_of_aa_str(parent)
         parent_idxs = sequences.aa_idx_array_of_str(parent)

diff --git a/netam/models.py b/netam/models.py
index 17008795..be31b8a0 100644
--- a/netam/models.py
+++ b/netam/models.py
@@ -537,7 +537,8 @@ def __init__(self):
         super().__init__()
 
     def selection_factors_of_aa_str(self, aa_str: str) -> Tensor:
-        """Do the forward method without gradients from an amino acid string.
+        """Do the forward method then exponentiation without gradients from an amino
+        acid string.
 
         Args:
             aa_str: A string of amino acids.
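
Reviewer note (not part of the patch): the new indexing in
zap_predictions_along_diagonal can be sanity-checked in isolation. Below is a
minimal sketch assuming only torch; BIG here is a stand-in for netam's
constant, and index 20 marks an ambiguous (X) parent, since indices 0-19 are
the twenty standard amino acids.

    import torch

    BIG = 1e6

    def zap_predictions_along_diagonal(predictions, aa_parents_idxs):
        # Same logic as the patched function in netam/dasm.py.
        device = predictions.device
        batch_size, L, _ = predictions.shape
        batch_indices = torch.arange(batch_size, device=device)[:, None].expand(-1, L)
        sequence_indices = torch.arange(L, device=device)[None, :].expand(batch_size, -1)
        # Only zap sites whose parent is an unambiguous amino acid.
        valid_mask = aa_parents_idxs < 20
        predictions[
            batch_indices[valid_mask],
            sequence_indices[valid_mask],
            aa_parents_idxs[valid_mask],
        ] = -BIG
        return predictions

    preds = torch.zeros(1, 3, 20)         # one sequence of length 3
    parents = torch.tensor([[0, 20, 5]])  # site 1 is ambiguous (X)
    zap_predictions_along_diagonal(preds, parents)
    print(preds[0, 0, 0], preds[0, 2, 5])  # both -BIG: parent identities zapped
    print(preds[0, 1].abs().max())         # 0.0: the X site is left untouched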
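
Continuing the sketch above, the renamed subs_pos quantities follow directly:
with log-probabilities in, summing exp() after zapping gives the per-site
probability of a nonsynonymous substitution. The X site keeps its diagonal,
but a mask (a stand-in here for the batch's "mask", assumed to exclude X
sites) drops it before the BCE loss, which is why the no-op is safe in
practice.

    log_probs = torch.log_softmax(torch.randn(1, 3, 20), dim=-1)
    zapped = zap_predictions_along_diagonal(log_probs.clone(), parents)
    subs_pos_pred = torch.exp(zapped).sum(dim=-1)  # shape (1, 3)
    mask = parents < 20                            # stand-in mask excluding X sites
    print(subs_pos_pred.masked_select(mask))       # two values, each strictly < 1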
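
The models.py docstring change records a contract worth spelling out: the
wrapper exponentiates after the forward pass, so callers receive selection
factors on the linear scale, not log selection factors. A hedged sketch of
that contract, using a hypothetical toy model rather than netam's actual
class or its real amino-acid encoding:

    class ToySelectionModel(torch.nn.Module):
        # Hypothetical stand-in: forward() emits log selection factors.
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(20, 20)

        def forward(self, x):
            return self.linear(x)

        def selection_factors_of_aa_str(self, aa_str: str) -> torch.Tensor:
            # Stand-in encoding; netam's real encoding of aa_str differs.
            x = torch.zeros(len(aa_str), 20)
            with torch.no_grad():
                # Forward then exponentiate, per the updated docstring.
                return torch.exp(self.forward(x))

    factors = ToySelectionModel().selection_factors_of_aa_str("ACDEF")
    print(factors.shape, bool((factors > 0).all()))  # positive, linear-scale factors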