From fbe2505e303645721f50d1220b631ad34f065cd9 Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Fri, 18 Oct 2024 02:55:04 -0700 Subject: [PATCH 1/6] better diagonal zapping --- netam/dasm.py | 39 +++++++++++++++++---------------------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/netam/dasm.py b/netam/dasm.py index bac48135..038ffab3 100644 --- a/netam/dasm.py +++ b/netam/dasm.py @@ -103,18 +103,22 @@ def to(self, device): def zap_predictions_along_diagonal(predictions, aa_parents_idxs): - """Set the diagonal (i.e. no amino acid change) of the predictions tensor to - -BIG.""" - # This is effectively - # predictions[torch.arange(len(aa_parents_idxs)), aa_parents_idxs] = -BIG - # but we have a batch dimension. Thus the following. - + """Set the diagonal (i.e. no amino acid change) of the predictions tensor to -BIG, + except where aa_parents_idxs >= 20, which indicates no update should be done.""" + + device = predictions.device batch_size, L, _ = predictions.shape - batch_indices = torch.arange(batch_size, device=predictions.device) + batch_indices = torch.arange(batch_size, device=device)[:, None].expand(-1, L) + sequence_indices = torch.arange(L, device=device)[None, :].expand(batch_size, -1) + + # Create a mask for valid positions (where aa_parents_idxs is less than 20) + valid_mask = aa_parents_idxs < 20 + + # Only update the predictions for valid positions predictions[ - batch_indices[:, None], - torch.arange(L, device=predictions.device), - aa_parents_idxs, + batch_indices[valid_mask], + sequence_indices[valid_mask], + aa_parents_idxs[valid_mask], ] = -BIG return predictions @@ -162,24 +166,15 @@ def predictions_of_batch(self, batch): def loss_of_batch(self, batch): aa_subs_indicator = batch["subs_indicator"].to(self.device) + # TODO this should be a child mask, right? mask = batch["mask"].to(self.device) aa_parents_idxs = batch["aa_parents_idxs"].to(self.device) aa_children_idxs = batch["aa_children_idxs"].to(self.device) masked_aa_subs_indicator = aa_subs_indicator.masked_select(mask) predictions = self.predictions_of_batch(batch) - # Add one entry, zero, to the last dimension of the predictions tensor - # to handle the ambiguous amino acids. This is the conservative choice. - # It might be faster to reassign all the 20s to 0s if we are confident - # in our masking. Perhaps we should always output a 21st dimension - # for the ambiguous amino acids (see issue #16). - # Note that we're going to want to have a symbol for the junction - # between the heavy and light chains. - # If we change something here we should also change the test code - # in test_dasm.py::test_zero_diagonal. - predictions = torch.cat( - [predictions, torch.full_like(predictions[:, :, :1], -BIG)], dim=-1 - ) + # "Zapping" out the diagonal means setting it to zero in log space by + # setting it to -BIG. predictions = zap_predictions_along_diagonal(predictions, aa_parents_idxs) # After zapping out the diagonal, we can effectively sum over the From fbab54dd518c711fe9a3d19a6269b45a06fe58d8 Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Fri, 18 Oct 2024 02:55:30 -0700 Subject: [PATCH 2/6] make format --- netam/dasm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/netam/dasm.py b/netam/dasm.py index 038ffab3..660e4cde 100644 --- a/netam/dasm.py +++ b/netam/dasm.py @@ -105,12 +105,12 @@ def to(self, device): def zap_predictions_along_diagonal(predictions, aa_parents_idxs): """Set the diagonal (i.e. no amino acid change) of the predictions tensor to -BIG, except where aa_parents_idxs >= 20, which indicates no update should be done.""" - + device = predictions.device batch_size, L, _ = predictions.shape batch_indices = torch.arange(batch_size, device=device)[:, None].expand(-1, L) sequence_indices = torch.arange(L, device=device)[None, :].expand(batch_size, -1) - + # Create a mask for valid positions (where aa_parents_idxs is less than 20) valid_mask = aa_parents_idxs < 20 From db2b79ad145ccdb4860a84f18eef50ab31d52de0 Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Fri, 18 Oct 2024 03:31:28 -0700 Subject: [PATCH 3/6] comments and renaming --- netam/dasm.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/netam/dasm.py b/netam/dasm.py index 660e4cde..62b0c6b3 100644 --- a/netam/dasm.py +++ b/netam/dasm.py @@ -174,16 +174,21 @@ def loss_of_batch(self, batch): predictions = self.predictions_of_batch(batch) # "Zapping" out the diagonal means setting it to zero in log space by - # setting it to -BIG. + # setting it to -BIG. This is a no-op for sites that have an X + # (ambiguous AA) in the parent. This could cause problems in principle, + # but in practice we mask out sites with Xs in the parent for the + # mut_pos_loss, and we mask out sites with no substitution for the CSP + # loss. The latter class of sites also eliminates sites that have Xs in + # the parent or child (see sequences.aa_subs_indicator_tensor_of). predictions = zap_predictions_along_diagonal(predictions, aa_parents_idxs) # After zapping out the diagonal, we can effectively sum over the # off-diagonal elements to get the probability of a nonsynonymous - # mutation. - mut_pos_pred = torch.sum(torch.exp(predictions), dim=-1) - mut_pos_pred = mut_pos_pred.masked_select(mask) - mut_pos_pred = clamp_probability(mut_pos_pred) - mut_pos_loss = self.bce_loss(mut_pos_pred, masked_aa_subs_indicator) + # substitution. + subs_pos_pred = torch.sum(torch.exp(predictions), dim=-1) + subs_pos_pred = subs_pos_pred.masked_select(mask) + subs_pos_pred = clamp_probability(subs_pos_pred) + mut_pos_loss = self.bce_loss(subs_pos_pred, masked_aa_subs_indicator) # We now need to calculate the conditional substitution probability # (CSP) loss. We have already zapped out the diagonal, and we're in From bf749edd560f5ee0ae6aabb1913bf5b4dd72f888 Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Fri, 18 Oct 2024 11:27:53 -0700 Subject: [PATCH 4/6] comments --- netam/dasm.py | 3 ++- netam/models.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/netam/dasm.py b/netam/dasm.py index 62b0c6b3..8df3eef0 100644 --- a/netam/dasm.py +++ b/netam/dasm.py @@ -204,7 +204,8 @@ def loss_of_batch(self, batch): def build_selection_matrix_from_parent(self, parent: str): # This is simpler than the equivalent in dnsm.py because we get the selection - # matrix directly. + # matrix directly. Note that selection_factors_of_aa_str does the exponentiation + # so this indeed gives us the selection factors, not the log selection factors. parent = translate_sequence(parent) selection_factors = self.model.selection_factors_of_aa_str(parent) parent_idxs = sequences.aa_idx_array_of_str(parent) diff --git a/netam/models.py b/netam/models.py index 17008795..be31b8a0 100644 --- a/netam/models.py +++ b/netam/models.py @@ -537,7 +537,8 @@ def __init__(self): super().__init__() def selection_factors_of_aa_str(self, aa_str: str) -> Tensor: - """Do the forward method without gradients from an amino acid string. + """Do the forward method then exponentiation without gradients from an amino + acid string. Args: aa_str: A string of amino acids. From ce47fff669df4936c6dd990e468c25064d745b76 Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Mon, 21 Oct 2024 03:25:39 -0700 Subject: [PATCH 5/6] renaming --- netam/dasm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/netam/dasm.py b/netam/dasm.py index 8df3eef0..23c40a44 100644 --- a/netam/dasm.py +++ b/netam/dasm.py @@ -188,7 +188,7 @@ def loss_of_batch(self, batch): subs_pos_pred = torch.sum(torch.exp(predictions), dim=-1) subs_pos_pred = subs_pos_pred.masked_select(mask) subs_pos_pred = clamp_probability(subs_pos_pred) - mut_pos_loss = self.bce_loss(subs_pos_pred, masked_aa_subs_indicator) + subs_pos_loss = self.bce_loss(subs_pos_pred, masked_aa_subs_indicator) # We now need to calculate the conditional substitution probability # (CSP) loss. We have already zapped out the diagonal, and we're in @@ -200,7 +200,7 @@ def loss_of_batch(self, batch): csp_targets = aa_children_idxs[subs_mask] csp_loss = self.xent_loss(csp_pred, csp_targets) - return torch.stack([mut_pos_loss, csp_loss]) + return torch.stack([subs_pos_loss, csp_loss]) def build_selection_matrix_from_parent(self, parent: str): # This is simpler than the equivalent in dnsm.py because we get the selection From 5e1439ebb97dd2a6cddfd40b31fffcf45379ba95 Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Mon, 21 Oct 2024 09:07:10 -0700 Subject: [PATCH 6/6] a note --- netam/dasm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netam/dasm.py b/netam/dasm.py index 23c40a44..7597aa61 100644 --- a/netam/dasm.py +++ b/netam/dasm.py @@ -166,7 +166,7 @@ def predictions_of_batch(self, batch): def loss_of_batch(self, batch): aa_subs_indicator = batch["subs_indicator"].to(self.device) - # TODO this should be a child mask, right? + # Netam issue #16: child mask would be preferable here. mask = batch["mask"].to(self.device) aa_parents_idxs = batch["aa_parents_idxs"].to(self.device) aa_children_idxs = batch["aa_children_idxs"].to(self.device)