From f204f6b392450533bb9e9bd0fffb571acbcc1152 Mon Sep 17 00:00:00 2001
From: Erick Matsen
Date: Wed, 16 Oct 2024 15:43:00 -0700
Subject: [PATCH] Adding a per-AA loss to the DASM (#66)

* moved DASM predictions to be output in log space
* added a CSP component to the weighted loss
* refactored the two-loss functionality into a mixin class
---
 netam/common.py    |  4 +++
 netam/dasm.py      | 72 ++++++++++++++++++++++++++++++----------------
 netam/dnsm.py      | 11 ++-----
 netam/framework.py | 40 +++++++++++++++-----------
 tests/test_dasm.py |  9 +++---
 5 files changed, 83 insertions(+), 53 deletions(-)

diff --git a/netam/common.py b/netam/common.py
index b6788b6f..8f2f0c88 100644
--- a/netam/common.py
+++ b/netam/common.py
@@ -89,6 +89,10 @@ def clamp_probability(x: Tensor) -> Tensor:
     return torch.clamp(x, min=SMALL_PROB, max=(1.0 - SMALL_PROB))
 
 
+def clamp_log_probability(x: Tensor) -> Tensor:
+    return torch.clamp(x, max=np.log(1.0 - SMALL_PROB))
+
+
 def print_parameter_count(model):
     total = 0
     for name, module in model.named_modules():
diff --git a/netam/dasm.py b/netam/dasm.py
index 90603a09..bac48135 100644
--- a/netam/dasm.py
+++ b/netam/dasm.py
@@ -1,4 +1,4 @@
-"""Here we define a mutation-selection model that is per-amino-acid."""
+"""Here we define a model that outputs a vector of 20 amino acid preferences."""
 
 import torch
 import torch.nn.functional as F
@@ -11,9 +11,12 @@
 import pandas as pd
 
 from netam.common import (
+    clamp_log_probability,
     clamp_probability,
+    BIG,
 )
 import netam.dnsm as dnsm
+import netam.framework as framework
 import netam.molevol as molevol
 import netam.sequences as sequences
 from netam.sequences import (
@@ -81,6 +84,7 @@ def update_neutral_probs(self):
     def __getitem__(self, idx):
         return {
             "aa_parents_idxs": self.aa_parents_idxs[idx],
+            "aa_children_idxs": self.aa_children_idxs[idx],
             "subs_indicator": self.aa_subs_indicator_tensor[idx],
             "mask": self.mask[idx],
             "log_neutral_aa_probs": self.log_neutral_aa_probs[idx],
@@ -90,6 +94,7 @@ def __getitem__(self, idx):
 
     def to(self, device):
         self.aa_parents_idxs = self.aa_parents_idxs.to(device)
+        self.aa_children_idxs = self.aa_children_idxs.to(device)
         self.aa_subs_indicator_tensor = self.aa_subs_indicator_tensor.to(device)
         self.mask = self.mask.to(device)
         self.log_neutral_aa_probs = self.log_neutral_aa_probs.to(device)
@@ -97,14 +102,11 @@ def to(self, device):
         self.all_subs_probs = self.all_subs_probs.to(device)
 
 
-def zero_predictions_along_diagonal(predictions, aa_parents_idxs):
-    """Zero out the diagonal of a batch of predictions.
-
-    We do this so that we can sum then have the same type of predictions as for the
-    DNSM.
-    """
-    # We would like to do
-    # predictions[torch.arange(len(aa_parents_idxs)), aa_parents_idxs] = 0.0
+def zap_predictions_along_diagonal(predictions, aa_parents_idxs):
+    """Set the diagonal (i.e. no amino acid change) of the predictions tensor to
+    -BIG."""
+    # This is effectively
+    # predictions[torch.arange(len(aa_parents_idxs)), aa_parents_idxs] = -BIG
     # but we have a batch dimension. Thus the following.
     batch_size, L, _ = predictions.shape
@@ -113,12 +115,16 @@ def zero_predictions_along_diagonal(predictions, aa_parents_idxs):
         batch_indices[:, None],
         torch.arange(L, device=predictions.device),
         aa_parents_idxs,
-    ] = 0.0
+    ] = -BIG
     return predictions
 
 
-class DASMBurrito(dnsm.DNSMBurrito):
+class DASMBurrito(framework.TwoLossMixin, dnsm.DNSMBurrito):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.xent_loss = torch.nn.CrossEntropyLoss()
+        self.loss_weights = torch.tensor([1.0, 0.01]).to(self.device)
 
     def prediction_pair_of_batch(self, batch):
         """Get log neutral AA probabilities and log selection factors for a batch of
@@ -134,12 +140,13 @@
         return log_neutral_aa_probs, log_selection_factors
 
     def predictions_of_pair(self, log_neutral_aa_probs, log_selection_factors):
-        # Take the product of the neutral mutation probabilities and the
-        # selection factors, namely p_{j, a} * f_{j, a}.
-        # In contrast to a DNSM, each of these now have last dimension of 20.
-        predictions = torch.exp(log_neutral_aa_probs + log_selection_factors)
-        assert torch.isfinite(predictions).all()
-        predictions = clamp_probability(predictions)
+        """Take the sum of the neutral mutation log probabilities and the log
+        selection factors.
+
+        In contrast to a DNSM, each of these now has a last dimension of 20.
+        """
+        predictions = log_neutral_aa_probs + log_selection_factors
+        assert torch.isnan(predictions).sum() == 0
         return predictions
 
     def predictions_of_batch(self, batch):
@@ -157,28 +164,43 @@ def loss_of_batch(self, batch):
         aa_subs_indicator = batch["subs_indicator"].to(self.device)
         mask = batch["mask"].to(self.device)
         aa_parents_idxs = batch["aa_parents_idxs"].to(self.device)
-        aa_subs_indicator = aa_subs_indicator.masked_select(mask)
+        aa_children_idxs = batch["aa_children_idxs"].to(self.device)
+        masked_aa_subs_indicator = aa_subs_indicator.masked_select(mask)
         predictions = self.predictions_of_batch(batch)
 
-        # Add one entry, zero, to the last dimension of the predictions tensor
+        # Add one entry, -BIG, to the last dimension of the predictions tensor
         # to handle the ambiguous amino acids. This is the conservative choice.
         # It might be faster to reassign all the 20s to 0s if we are confident
         # in our masking. Perhaps we should always output a 21st dimension
         # for the ambiguous amino acids (see issue #16).
+        # Note that we're going to want to have a symbol for the junction
+        # between the heavy and light chains.
         # If we change something here we should also change the test code
-        # in test_dasm.py::test_zero_diagonal.
+        # in test_dasm.py::test_zap_diagonal.
         predictions = torch.cat(
-            [predictions, torch.zeros_like(predictions[:, :, :1])], dim=-1
+            [predictions, torch.full_like(predictions[:, :, :1], -BIG)], dim=-1
         )
 
-        predictions = zero_predictions_along_diagonal(predictions, aa_parents_idxs)
+        predictions = zap_predictions_along_diagonal(predictions, aa_parents_idxs)
 
-        # After zeroing out the diagonal, we are effectively summing over the
+        # After zapping out the diagonal, we can effectively sum over the
         # off-diagonal elements to get the probability of a nonsynonymous
         # mutation.
-        predictions_of_mut = torch.sum(predictions, dim=-1)
-        predictions_of_mut = predictions_of_mut.masked_select(mask)
-        predictions_of_mut = clamp_probability(predictions_of_mut)
-        return self.bce_loss(predictions_of_mut, aa_subs_indicator)
+        mut_pos_pred = torch.sum(torch.exp(predictions), dim=-1)
+        mut_pos_pred = mut_pos_pred.masked_select(mask)
+        mut_pos_pred = clamp_probability(mut_pos_pred)
+        mut_pos_loss = self.bce_loss(mut_pos_pred, masked_aa_subs_indicator)
+
+        # We now need to calculate the conditional substitution probability
+        # (CSP) loss. We have already zapped out the diagonal, and we're in
+        # logit space, so we are set up to use the cross entropy loss.
+        # However, we have to mask out the sites that are not substituted, i.e.
+        # the sites for which aa_subs_indicator is 0.
+        subs_mask = aa_subs_indicator == 1
+        csp_pred = predictions[subs_mask]
+        csp_targets = aa_children_idxs[subs_mask]
+        csp_loss = self.xent_loss(csp_pred, csp_targets)
+
+        return torch.stack([mut_pos_loss, csp_loss])
 
     def build_selection_matrix_from_parent(self, parent: str):
         # This is simpler than the equivalent in dnsm.py because we get the selection
diff --git a/netam/dnsm.py b/netam/dnsm.py
index e71e33d4..3cd42be3 100644
--- a/netam/dnsm.py
+++ b/netam/dnsm.py
@@ -1,11 +1,4 @@
-"""Here we define a mutation-selection model that is just about mutation vs no mutation,
-and is trainable.
-
-We'll use these conventions:
-
-* B is the batch size
-* L is the max sequence length
-"""
+"""Defining the deep natural selection model (DNSM)."""
 
 import copy
 import multiprocessing as mp
@@ -74,6 +67,7 @@ def __init__(
         self.aa_parents_idxs = torch.full(
             (pcp_count, self.max_aa_seq_len), MAX_AMBIG_AA_IDX
         )
+        self.aa_children_idxs = self.aa_parents_idxs.clone()
         self.aa_subs_indicator_tensor = torch.zeros((pcp_count, self.max_aa_seq_len))
 
         self.mask = torch.ones((pcp_count, self.max_aa_seq_len), dtype=torch.bool)
@@ -82,6 +76,7 @@
             self.mask[i, :] = aa_mask_tensor_of(aa_parent, self.max_aa_seq_len)
             aa_seq_len = len(aa_parent)
             self.aa_parents_idxs[i, :aa_seq_len] = aa_idx_tensor_of_str_ambig(aa_parent)
+            self.aa_children_idxs[i, :aa_seq_len] = aa_idx_tensor_of_str_ambig(aa_child)
             self.aa_subs_indicator_tensor[i, :aa_seq_len] = aa_subs_indicator_tensor_of(
                 aa_parent, aa_child
             )
diff --git a/netam/framework.py b/netam/framework.py
index d44ab1d0..37fa6785 100644
--- a/netam/framework.py
+++ b/netam/framework.py
@@ -858,11 +858,9 @@ def to_crepe(self):
         return Crepe(encoder, self.model, training_hyperparameters)
 
 
-class RSSHMBurrito(SHMBurrito):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.xent_loss = nn.CrossEntropyLoss()
-        self.loss_weights = torch.tensor([1.0, 0.01]).to(self.device)
+class TwoLossMixin:
+    """A mixin for models that have two losses, one for mutation position and one for
+    conditional substitution probability (CSP)."""
 
     def process_data_loader(self, data_loader, train_mode=False, loss_reduction=None):
         if loss_reduction is None:
@@ -870,6 +868,25 @@ def process_data_loader(self, data_loader, train_mode=False, loss_reduction=None
 
         return super().process_data_loader(data_loader, train_mode, loss_reduction)
 
+    def write_loss(self, loss_name, loss, step):
+        rate_loss, csp_loss = loss.unbind()
+        self.writer.add_scalar(
+            "Mut pos " + loss_name,
+            rate_loss.item(),
+            step,
+            walltime=self.execution_time(),
+        )
+        self.writer.add_scalar(
+            "CSP " + loss_name, csp_loss.item(), step, walltime=self.execution_time()
+        )
+
+
+class RSSHMBurrito(TwoLossMixin, SHMBurrito):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.xent_loss = nn.CrossEntropyLoss()
+        self.loss_weights = torch.tensor([1.0, 0.01]).to(self.device)
+
     def evaluate(self):
         val_loader = self.build_val_loader()
         return super().process_data_loader(
@@ -890,7 +907,7 @@ def loss_of_batch(self, batch):
         mut_prob = 1 - torch.exp(-rates * branch_lengths.unsqueeze(-1))
         mut_prob_masked = mut_prob[masks]
         mutation_indicator_masked = mutation_indicators[masks].float()
-        rate_loss = self.bce_loss(mut_prob_masked, mutation_indicator_masked)
+        mut_pos_loss = self.bce_loss(mut_prob_masked, mutation_indicator_masked)
 
         # Conditional substitution probability (CSP) loss calculation
         # Mask the new_base_idxs to focus only on positions with mutations
@@ -902,7 +919,7 @@
         assert (new_base_idxs_masked >= 0).all()
 
         csp_loss = self.xent_loss(csp_logits_masked, new_base_idxs_masked)
-        return torch.stack([rate_loss, csp_loss])
+        return torch.stack([mut_pos_loss, csp_loss])
 
     def _find_optimal_branch_length(
         self,
@@ -983,15 +1000,6 @@ def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
         return torch.tensor(optimal_lengths)
 
-    def write_loss(self, loss_name, loss, step):
-        rate_loss, csp_loss = loss.unbind()
-        self.writer.add_scalar(
-            "Rate " + loss_name, rate_loss.item(), step, walltime=self.execution_time()
-        )
-        self.writer.add_scalar(
-            "CSP " + loss_name, csp_loss.item(), step, walltime=self.execution_time()
-        )
-
 
 def burrito_class_of_model(model):
     if isinstance(model, models.RSCNNModel):
diff --git a/tests/test_dasm.py b/tests/test_dasm.py
index b0ea83e3..bdc827b8 100644
--- a/tests/test_dasm.py
+++ b/tests/test_dasm.py
@@ -3,6 +3,7 @@
 import torch
 import pytest
 
+from netam.common import BIG
 from netam.framework import (
     crepe_exists,
     load_crepe,
@@ -11,7 +12,7 @@
 from netam.dasm import (
     DASMBurrito,
     DASMDataset,
-    zero_predictions_along_diagonal,
+    zap_predictions_along_diagonal,
 )
 
 
@@ -67,7 +68,7 @@ def test_crepe_roundtrip(dasm_burrito):
     assert torch.equal(t1, t2)
 
 
-def test_zero_diagonal(dasm_burrito):
+def test_zap_diagonal(dasm_burrito):
    batch = dasm_burrito.val_dataset[0:2]
    predictions = dasm_burrito.predictions_of_batch(batch)
    predictions = torch.cat(
@@ -75,7 +76,7 @@ def test_zero_diagonal(dasm_burrito):
     )
     aa_parents_idxs = batch["aa_parents_idxs"].to(dasm_burrito.device)
     zeroed_predictions = predictions.clone()
-    zeroed_predictions = zero_predictions_along_diagonal(
+    zeroed_predictions = zap_predictions_along_diagonal(
         zeroed_predictions, aa_parents_idxs
     )
     L = predictions.shape[1]
@@ -83,7 +84,7 @@ def test_zero_diagonal(dasm_burrito):
         for i in range(L):
             for j in range(20):
                 if j == aa_parents_idxs[batch_idx, i]:
-                    assert zeroed_predictions[batch_idx, i, j] == 0.0
+                    assert zeroed_predictions[batch_idx, i, j] == -BIG
                 else:
                     assert (
                         zeroed_predictions[batch_idx, i, j]
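
A few standalone sketches of the pieces above follow. First, the new clamp_log_probability helper: log probabilities are unbounded below, so unlike clamp_probability it only needs to pin the upper end, at log(1 - SMALL_PROB). A minimal sketch, with a stand-in value for SMALL_PROB (the real constant lives in netam.common):

import numpy as np
import torch

SMALL_PROB = 1e-6  # stand-in value; the real constant is defined in netam.common

def clamp_log_probability(x: torch.Tensor) -> torch.Tensor:
    # Log probabilities have no lower bound, so only clamp the top end,
    # keeping exp(x) strictly below 1.
    return torch.clamp(x, max=np.log(1.0 - SMALL_PROB))

logp = torch.tensor([-50.0, -0.5, 0.0, 0.2])
assert (torch.exp(clamp_log_probability(logp)) < 1.0).all()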
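
Second, the batched indexing in zap_predictions_along_diagonal. The function body below matches the patch; the BIG stand-in value and the toy shapes are invented for the example:

import torch

BIG = 1e9  # stand-in for netam.common.BIG; only its large magnitude matters here

def zap_predictions_along_diagonal(predictions, aa_parents_idxs):
    # For each sequence b and site i, set predictions[b, i, parent_aa] to -BIG,
    # so that exp() of that entry is 0 and "no change" drops out of any sum.
    batch_size, L, _ = predictions.shape
    batch_indices = torch.arange(batch_size, device=predictions.device)
    predictions[
        batch_indices[:, None],                      # (B, 1), broadcasts over sites
        torch.arange(L, device=predictions.device),  # (L,), one index per site
        aa_parents_idxs,                             # (B, L), parent AA at each site
    ] = -BIG
    return predictions

log_preds = torch.zeros(2, 5, 20)
parents = torch.randint(0, 20, (2, 5))
zapped = zap_predictions_along_diagonal(log_preds, parents)
# Each site keeps 19 entries of exp(0) = 1; the parent entry contributes ~0.
assert torch.allclose(torch.exp(zapped).sum(dim=-1), torch.full((2, 5), 19.0))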
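
Third, a toy walk-through of the two losses in DASMBurrito.loss_of_batch. All shapes and values here are invented, and log_softmax stands in for the model's actual log-space predictions; the structure (BCE on the summed off-diagonal probabilities, cross entropy at substituted sites) mirrors the patch:

import torch

B, L, A = 2, 5, 21  # batch, sites, 20 amino acids + one ambiguous slot
predictions = torch.log_softmax(torch.randn(B, L, A), dim=-1)  # log space
aa_parents_idxs = torch.randint(0, 20, (B, L))
aa_children_idxs = torch.randint(0, 20, (B, L))
mask = torch.ones(B, L, dtype=torch.bool)
aa_subs_indicator = torch.zeros(B, L)
aa_subs_indicator[:, 0] = 1.0  # pretend the first site of each sequence mutated

# Zap the parent ("diagonal") entries so they drop out of the sum below.
predictions[torch.arange(B)[:, None], torch.arange(L), aa_parents_idxs] = -1e9

# Loss 1: BCE on the per-site probability of a nonsynonymous change, i.e. the
# summed off-diagonal probabilities (clamped, as clamp_probability does).
mut_pos_pred = torch.exp(predictions).sum(dim=-1).clamp(1e-6, 1 - 1e-6)
mut_pos_loss = torch.nn.BCELoss()(
    mut_pos_pred.masked_select(mask), aa_subs_indicator.masked_select(mask)
)

# Loss 2 (CSP): cross entropy at substituted sites only, between the per-AA
# log values and the observed child amino acid.
subs_mask = aa_subs_indicator == 1
csp_loss = torch.nn.CrossEntropyLoss()(
    predictions[subs_mask], aa_children_idxs[subs_mask]
)

loss = torch.stack([mut_pos_loss, csp_loss])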
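
Finally, the reduction of the stacked [mut_pos_loss, csp_loss] pair to a scalar is not visible in these hunks (it happens inside process_data_loader, in a line that falls between the framework.py hunks shown), but given the loss_weights tensor of [1.0, 0.01], a weighted sum of roughly this shape is presumably what is applied:

import torch

# The mutation-position loss carries full weight; CSP is down-weighted 100x.
loss_weights = torch.tensor([1.0, 0.01])

def reduce_two_losses(stacked_loss: torch.Tensor) -> torch.Tensor:
    # stacked_loss is the torch.stack([mut_pos_loss, csp_loss]) pair returned
    # by loss_of_batch; training needs a single scalar for backprop.
    return torch.sum(stacked_loss * loss_weights)

scalar = reduce_two_losses(torch.tensor([0.7, 2.3]))  # 0.7 * 1.0 + 2.3 * 0.01 = 0.723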