From 9685343a0f21d87ca1efff486afa307e929df1fd Mon Sep 17 00:00:00 2001
From: Will Dumm
Date: Thu, 6 Feb 2025 15:47:37 -0800
Subject: [PATCH] rename DCSM (#110)

Rename dasm -> ddsm and dcsm -> dasm in file contents and filenames, and
move important tests back to `test_dasm` to avoid dependency on ddsm.

Companion to matsengrp/dnsm-experiments-1#94.
---
 netam/dasm.py                                 | 264 ++++++++++--------
 netam/dcsm.py                                 | 251 -----------------
 netam/ddsm.py                                 | 213 ++++++++++++++
 netam/dnsm.py                                 |   4 +-
 ....pth => ddsm_13k-v1jaffe+v1tang-joint.pth} | Bin
 ....yml => ddsm_13k-v1jaffe+v1tang-joint.yml} |   0
 tests/old_models/{dasm_output => ddsm_output} | Bin
 tests/test_backward_compat.py                 |  32 +--
 tests/test_dasm.py                            |  48 +---
 tests/test_dcsm.py                            |  49 ----
 tests/test_ddsm.py                            |  73 +++++
 11 files changed, 468 insertions(+), 466 deletions(-)
 delete mode 100644 netam/dcsm.py
 create mode 100644 netam/ddsm.py
 rename tests/old_models/{dasm_13k-v1jaffe+v1tang-joint.pth => ddsm_13k-v1jaffe+v1tang-joint.pth} (100%)
 rename tests/old_models/{dasm_13k-v1jaffe+v1tang-joint.yml => ddsm_13k-v1jaffe+v1tang-joint.yml} (100%)
 rename tests/old_models/{dasm_output => ddsm_output} (100%)
 delete mode 100644 tests/test_dcsm.py
 create mode 100644 tests/test_ddsm.py

diff --git a/netam/dasm.py b/netam/dasm.py
index 811e30d6..85f13750 100644
--- a/netam/dasm.py
+++ b/netam/dasm.py
@@ -1,21 +1,67 @@
-"""Here we define a model that outputs a vector of 20 amino acid preferences."""
+"""Defining the DASM, which makes codon-level selection predictions."""
+
+import copy
 
 import torch
 import torch.nn.functional as F
 
-from netam.common import clamp_probability
-from netam.dxsm import DXSMDataset, DXSMBurrito, zap_predictions_along_diagonal
-import netam.framework as framework
+from netam.common import (
+    clamp_probability,
+    BIG,
+)
+from netam.dxsm import DXSMDataset, DXSMBurrito
 import netam.molevol as molevol
-import netam.sequences as sequences
-import copy
-from typing import Tuple
+
+from netam.sequences import (
+    build_stop_codon_indicator_tensor,
+    nt_idx_tensor_of_str,
+    codon_idx_tensor_of_str_ambig,
+    AMBIGUOUS_CODON_IDX,
+    CODON_AA_INDICATOR_MATRIX,
+)
 
 
 class DASMDataset(DXSMDataset):
 
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        assert len(self.nt_parents) == len(self.nt_children)
+        # We need to add codon index tensors to the dataset.
+
+        self.max_codon_seq_len = self.max_aa_seq_len
+        self.codon_parents_idxss = torch.full_like(
+            self.aa_parents_idxss, AMBIGUOUS_CODON_IDX
+        )
+        self.codon_children_idxss = self.codon_parents_idxss.clone()
+
+        # We are using the modified nt_parents and nt_children here because we
+        # don't want any funky symbols in our codon indices.
+        for i, (nt_parent, nt_child) in enumerate(
+            zip(self.nt_parents, self.nt_children)
+        ):
+            assert len(nt_parent) % 3 == 0
+            codon_seq_len = len(nt_parent) // 3
+            self.codon_parents_idxss[i, :codon_seq_len] = codon_idx_tensor_of_str_ambig(
+                nt_parent
+            )
+            self.codon_children_idxss[i, :codon_seq_len] = (
+                codon_idx_tensor_of_str_ambig(nt_child)
+            )
+        assert torch.max(self.codon_parents_idxss) <= AMBIGUOUS_CODON_IDX
+
     def update_neutral_probs(self):
-        neutral_aa_probs_l = []
+        """Update the neutral mutation probabilities for the dataset.
+
+        This is a somewhat vague name, but that's because it includes all of the various
+        types of neutral mutation probabilities that we might want to compute.
+
+        In this case it's the neutral codon probabilities.
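+
+        An illustrative aside (hypothetical shapes, not from the original
+        patch): a parent of 3N nucleotides is reshaped into N codon sites
+        below, so the result is one row of codon probabilities per codon
+        site, derived from the per-nucleotide mutation probabilities and CSPs.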
+ """ + neutral_codon_probs_l = [] for nt_parent, mask, nt_rates, nt_csps, branch_length in zip( self.nt_parents, @@ -33,7 +79,7 @@ def update_neutral_probs(self): multihit_model = None # Note we are replacing all Ns with As, which means that we need to be careful # with masking out these positions later. We do this below. - parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A")) + parent_idxs = nt_idx_tensor_of_str(nt_parent.replace("N", "A")) parent_len = len(nt_parent) mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len]) @@ -41,173 +87,165 @@ def update_neutral_probs(self): nt_mask = mask.repeat_interleave(3)[: len(nt_parent)] molevol.check_csps(parent_idxs[nt_mask], nt_csps[: len(nt_parent)][nt_mask]) - neutral_aa_probs = molevol.neutral_aa_probs( + neutral_codon_probs = molevol.neutral_codon_probs( parent_idxs.reshape(-1, 3), mut_probs.reshape(-1, 3), nt_csps.reshape(-1, 3, 4), multihit_model=multihit_model, ) - if not torch.isfinite(neutral_aa_probs).all(): - print(f"Found a non-finite neutral_aa_probs") + if not torch.isfinite(neutral_codon_probs).all(): + print(f"Found a non-finite neutral_codon_prob") print(f"nt_parent: {nt_parent}") print(f"mask: {mask}") print(f"nt_rates: {nt_rates}") print(f"nt_csps: {nt_csps}") print(f"branch_length: {branch_length}") - raise ValueError(f"neutral_aa_probs is not finite: {neutral_aa_probs}") + raise ValueError( + f"neutral_codon_probs is not finite: {neutral_codon_probs}" + ) # Ensure that all values are positive before taking the log later - neutral_aa_probs = clamp_probability(neutral_aa_probs) + neutral_codon_probs = clamp_probability(neutral_codon_probs) - pad_len = self.max_aa_seq_len - neutral_aa_probs.shape[0] + pad_len = self.max_aa_seq_len - neutral_codon_probs.shape[0] if pad_len > 0: - neutral_aa_probs = F.pad( - neutral_aa_probs, (0, 0, 0, pad_len), value=1e-8 + neutral_codon_probs = F.pad( + neutral_codon_probs, (0, 0, 0, pad_len), value=1e-8 ) # Here we zero out masked positions. - neutral_aa_probs *= mask[:, None] + neutral_codon_probs *= mask[:, None] - neutral_aa_probs_l.append(neutral_aa_probs) + neutral_codon_probs_l.append(neutral_codon_probs) # Note that our masked out positions will have a nan log probability, # which will require us to handle them correctly downstream. 
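+        # Illustrative example (an added note; the values are exact torch
+        # behavior): torch.log(torch.tensor([0.0, 0.9])) gives
+        # tensor([-inf, -0.1054]), and downstream arithmetic with -inf can
+        # produce nan, so these positions must be excluded via the mask
+        # before they reach the loss.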
- self.log_neutral_aa_probss = torch.log(torch.stack(neutral_aa_probs_l)) + self.log_neutral_codon_probss = torch.log(torch.stack(neutral_codon_probs_l)) def __getitem__(self, idx): return { + "codon_parents_idxs": self.codon_parents_idxss[idx], + "codon_children_idxs": self.codon_children_idxss[idx], "aa_parents_idxs": self.aa_parents_idxss[idx], "aa_children_idxs": self.aa_children_idxss[idx], "subs_indicator": self.aa_subs_indicators[idx], "mask": self.masks[idx], - "log_neutral_aa_probs": self.log_neutral_aa_probss[idx], + "log_neutral_codon_probs": self.log_neutral_codon_probss[idx], "nt_rates": self.nt_ratess[idx], "nt_csps": self.nt_cspss[idx], } def to(self, device): + self.codon_parents_idxss = self.codon_parents_idxss.to(device) + self.codon_children_idxss = self.codon_children_idxss.to(device) self.aa_parents_idxss = self.aa_parents_idxss.to(device) self.aa_children_idxss = self.aa_children_idxss.to(device) self.aa_subs_indicators = self.aa_subs_indicators.to(device) self.masks = self.masks.to(device) - self.log_neutral_aa_probss = self.log_neutral_aa_probss.to(device) + self.log_neutral_codon_probss = self.log_neutral_codon_probss.to(device) self.nt_ratess = self.nt_ratess.to(device) self.nt_cspss = self.nt_cspss.to(device) if self.multihit_model is not None: self.multihit_model = self.multihit_model.to(device) -class DASMBurrito(framework.TwoLossMixin, DXSMBurrito): +class DASMBurrito(DXSMBurrito): + model_type = "dasm" - def __init__(self, *args, loss_weights: list = [1.0, 0.01], **kwargs): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.xent_loss = torch.nn.CrossEntropyLoss() - self.loss_weights = torch.tensor(loss_weights).to(self.device) + self.stop_codon_zapper = (build_stop_codon_indicator_tensor() * -BIG).to( + self.device + ) + self.aa_codon_indicator_matrix = CODON_AA_INDICATOR_MATRIX.to(self.device).T def prediction_pair_of_batch(self, batch): - """Get log neutral AA probabilities and log selection factors for a batch of - data.""" + """Get log neutral codon substitution probabilities and log selection factors + for a batch of data. + + We don't mask on the output, which will thus contain junk in all of the masked + sites. + """ aa_parents_idxs = batch["aa_parents_idxs"].to(self.device) mask = batch["mask"].to(self.device) - log_neutral_aa_probs = batch["log_neutral_aa_probs"].to(self.device) - if not torch.isfinite(log_neutral_aa_probs[mask]).all(): + log_neutral_codon_probs = batch["log_neutral_codon_probs"].to(self.device) + if not torch.isfinite(log_neutral_codon_probs[mask]).all(): raise ValueError( - f"log_neutral_aa_probs has non-finite values at relevant positions: {log_neutral_aa_probs[mask]}" + f"log_neutral_codon_probs has non-finite values at relevant positions: {log_neutral_codon_probs[mask]}" ) log_selection_factors = self.selection_factors_of_aa_idxs(aa_parents_idxs, mask) - return log_neutral_aa_probs, log_selection_factors + return log_neutral_codon_probs, log_selection_factors - def predictions_of_pair(self, log_neutral_aa_probs, log_selection_factors): - """Take the sum of the neutral mutation log probabilities and the selection - factors. + def predictions_of_batch(self, batch): + """Make log probability predictions for a batch of data. - In contrast to a DNSM, each of these now have last dimension of 20. 
-        """
-        predictions = log_neutral_aa_probs + log_selection_factors
-        assert torch.isnan(predictions).sum() == 0
-        return predictions
+        In this case they are log probabilities of codons, which are made to be
+        probabilities by setting the parent codon to 1 - sum(children).
 
-    def predictions_of_batch(self, batch):
-        """Make predictions for a batch of data.
+        After all this, we clip the probabilities below to avoid log(0) issues.
+        So, in cases when the sum of the children is > 1, we don't give a
+        normalized probability distribution, but that won't crash the loss
+        calculation because that step uses softmax.
 
-        Note that we use the mask for prediction as part of the input for the
-        transformer, though we don't mask the predictions themselves.
+        Note that we make all ambiguous codons nan in the output, ensuring that
+        they must get properly masked downstream.
         """
-        log_neutral_aa_probs, log_selection_factors = self.prediction_pair_of_batch(
+        log_neutral_codon_probs, log_selection_factors = self.prediction_pair_of_batch(
             batch
         )
-        return self.predictions_of_pair(log_neutral_aa_probs, log_selection_factors)
+
+        # This code block, in other burritos, is done in a separate function,
+        # but we can't do that here because we need to normalize the
+        # probabilities in a way that is not possible without having the index
+        # of the parent codon. Namely, we need to set the parent codon to 1 -
+        # sum(children).
+
+        # The aa_codon_indicator_matrix lifts things up from aa land to codon land.
+        log_preds = (
+            log_neutral_codon_probs
+            + log_selection_factors @ self.aa_codon_indicator_matrix
+            + self.stop_codon_zapper
+        )
+        assert torch.isnan(log_preds).sum() == 0
+
+        parent_indices = batch["codon_parents_idxs"].to(self.device)  # Shape: [B, L]
+        valid_mask = parent_indices != AMBIGUOUS_CODON_IDX  # Shape: [B, L]
+
+        # Convert to linear space so we can add probabilities.
+        preds = torch.exp(log_preds)
+
+        # Zero out the parent indices in preds, while keeping the computation
+        # graph intact.
+        preds_zeroer = torch.ones_like(preds)
+        preds_zeroer[valid_mask, parent_indices[valid_mask]] = 0.0
+        preds = preds * preds_zeroer
+
+        # Calculate the non-parent sum after zeroing out the parent indices.
+        non_parent_sum = preds[valid_mask, :].sum(dim=-1)
+
+        # Add these parent values back in, again keeping the computation graph intact.
+        preds_parent = torch.zeros_like(preds)
+        preds_parent[valid_mask, parent_indices[valid_mask]] = 1.0 - non_parent_sum
+        preds = preds + preds_parent
+
+        # We have to clamp the predictions to avoid log(0) issues.
+        preds = torch.clamp(preds, min=torch.finfo(preds.dtype).eps)
+
+        log_preds = torch.log(preds)
+
+        # Set ambiguous codons to nan to make sure that we handle them correctly downstream.
+        log_preds[~valid_mask, :] = float("nan")
+
+        return log_preds
 
     def loss_of_batch(self, batch):
-        aa_subs_indicator = batch["subs_indicator"].to(self.device)
-        # Netam issue #16: child mask would be preferable here.
+        codon_children_idxs = batch["codon_children_idxs"].to(self.device)
         mask = batch["mask"].to(self.device)
-        aa_parents_idxs = batch["aa_parents_idxs"].to(self.device)
-        aa_children_idxs = batch["aa_children_idxs"].to(self.device)
-        masked_aa_subs_indicator = aa_subs_indicator.masked_select(mask)
-        predictions = self.predictions_of_batch(batch)
-
-        # "Zapping" out the diagonal means setting it to zero in log space by
-        # setting it to -BIG. This is a no-op for sites that have an X
-        # (ambiguous AA) in the parent. 
This could cause problems in principle, - # but in practice we mask out sites with Xs in the parent for the - # mut_pos_loss, and we mask out sites with no substitution for the CSP - # loss. The latter class of sites also eliminates sites that have Xs in - # the parent or child (see sequences.aa_subs_indicator_tensor_of). - - predictions = zap_predictions_along_diagonal(predictions, aa_parents_idxs) - - # After zapping out the diagonal, we can effectively sum over the - # off-diagonal elements to get the probability of a nonsynonymous - # substitution. - subs_pos_pred = torch.sum(torch.exp(predictions), dim=-1) - subs_pos_pred = subs_pos_pred.masked_select(mask) - subs_pos_pred = clamp_probability(subs_pos_pred) - subs_pos_loss = self.bce_loss(subs_pos_pred, masked_aa_subs_indicator) - - # We now need to calculate the conditional substitution probability - # (CSP) loss. We have already zapped out the diagonal, and we're in - # logit space, so we are set up for using the cross entropy loss. - # However we have to mask out the sites that are not substituted, i.e. - # the sites for which aa_subs_indicator is 0. - subs_mask = (aa_subs_indicator == 1) & mask - csp_pred = predictions[subs_mask] - csp_targets = aa_children_idxs[subs_mask] - csp_loss = self.xent_loss(csp_pred, csp_targets) - return torch.stack([subs_pos_loss, csp_loss]) - - # This is not used anywhere, except for in a few tests. Keeping it around - # for that reason. - def _build_selection_matrix_from_parent(self, parent: Tuple[str, str]): - """Build a selection matrix from a parent nucleotide sequence, a heavy-chain, - light-chain pair. - - Values at ambiguous sites are meaningless. Returned value is a tuple of - selection matrix for heavy and light chain sequences. - """ - # This is simpler than the equivalent in dnsm.py because we get the selection - # matrix directly. Note that selection_factors_of_aa_str does the exponentiation - # so this indeed gives us the selection factors, not the log selection factors. 
- aa_parent_pair = tuple(map(sequences.translate_sequence, parent)) - per_aa_selection_factorss = self.model.selection_factors_of_aa_str( - aa_parent_pair - ) - result = [] - for per_aa_selection_factors, aa_parent in zip( - per_aa_selection_factorss, aa_parent_pair - ): - aa_parent_idxs = torch.tensor(sequences.aa_idx_array_of_str(aa_parent)) - if len(per_aa_selection_factors) > 0: - result.append( - zap_predictions_along_diagonal( - per_aa_selection_factors.unsqueeze(0), - aa_parent_idxs.unsqueeze(0), - fill=1.0, - ).squeeze(0) - ) - else: - result.append(per_aa_selection_factors) + predictions = self.predictions_of_batch(batch)[mask] + assert torch.isnan(predictions).sum() == 0 + codon_children_idxs = codon_children_idxs[mask] - return tuple(result) + return self.xent_loss(predictions, codon_children_idxs) diff --git a/netam/dcsm.py b/netam/dcsm.py deleted file mode 100644 index 430d8cac..00000000 --- a/netam/dcsm.py +++ /dev/null @@ -1,251 +0,0 @@ -"""Defining the deep natural selection model (DNSM).""" - -import copy - -import torch -import torch.nn.functional as F - -from netam.common import ( - clamp_probability, - BIG, -) -from netam.dxsm import DXSMDataset, DXSMBurrito -import netam.molevol as molevol - -from netam.sequences import ( - build_stop_codon_indicator_tensor, - nt_idx_tensor_of_str, - codon_idx_tensor_of_str_ambig, - AMBIGUOUS_CODON_IDX, - CODON_AA_INDICATOR_MATRIX, -) - - -class DCSMDataset(DXSMDataset): - - def __init__( - self, - *args, - **kwargs, - ): - super().__init__(*args, **kwargs) - assert len(self.nt_parents) == len(self.nt_children) - # We need to add codon index tensors to the dataset. - - self.max_codon_seq_len = self.max_aa_seq_len - self.codon_parents_idxss = torch.full_like( - self.aa_parents_idxss, AMBIGUOUS_CODON_IDX - ) - self.codon_children_idxss = self.codon_parents_idxss.clone() - - # We are using the modified nt_parents and nt_children here because we - # don't want any funky symbols in our codon indices. - for i, (nt_parent, nt_child) in enumerate( - zip(self.nt_parents, self.nt_children) - ): - assert len(nt_parent) % 3 == 0 - codon_seq_len = len(nt_parent) // 3 - self.codon_parents_idxss[i, :codon_seq_len] = codon_idx_tensor_of_str_ambig( - nt_parent - ) - self.codon_children_idxss[i, :codon_seq_len] = ( - codon_idx_tensor_of_str_ambig(nt_child) - ) - assert torch.max(self.codon_parents_idxss) <= AMBIGUOUS_CODON_IDX - - def update_neutral_probs(self): - """Update the neutral mutation probabilities for the dataset. - - This is a somewhat vague name, but that's because it includes all of the various - types of neutral mutation probabilities that we might want to compute. - - In this case it's the neutral codon probabilities. - """ - neutral_codon_probs_l = [] - - for nt_parent, mask, nt_rates, nt_csps, branch_length in zip( - self.nt_parents, - self.masks, - self.nt_ratess, - self.nt_cspss, - self._branch_lengths, - ): - mask = mask.to("cpu") - nt_rates = nt_rates.to("cpu") - nt_csps = nt_csps.to("cpu") - if self.multihit_model is not None: - multihit_model = copy.deepcopy(self.multihit_model).to("cpu") - else: - multihit_model = None - # Note we are replacing all Ns with As, which means that we need to be careful - # with masking out these positions later. We do this below. 
- parent_idxs = nt_idx_tensor_of_str(nt_parent.replace("N", "A")) - parent_len = len(nt_parent) - - mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len]) - nt_csps = nt_csps[:parent_len, :] - nt_mask = mask.repeat_interleave(3)[: len(nt_parent)] - molevol.check_csps(parent_idxs[nt_mask], nt_csps[: len(nt_parent)][nt_mask]) - - neutral_codon_probs = molevol.neutral_codon_probs( - parent_idxs.reshape(-1, 3), - mut_probs.reshape(-1, 3), - nt_csps.reshape(-1, 3, 4), - multihit_model=multihit_model, - ) - - if not torch.isfinite(neutral_codon_probs).all(): - print(f"Found a non-finite neutral_codon_prob") - print(f"nt_parent: {nt_parent}") - print(f"mask: {mask}") - print(f"nt_rates: {nt_rates}") - print(f"nt_csps: {nt_csps}") - print(f"branch_length: {branch_length}") - raise ValueError( - f"neutral_codon_probs is not finite: {neutral_codon_probs}" - ) - - # Ensure that all values are positive before taking the log later - neutral_codon_probs = clamp_probability(neutral_codon_probs) - - pad_len = self.max_aa_seq_len - neutral_codon_probs.shape[0] - if pad_len > 0: - neutral_codon_probs = F.pad( - neutral_codon_probs, (0, 0, 0, pad_len), value=1e-8 - ) - # Here we zero out masked positions. - neutral_codon_probs *= mask[:, None] - - neutral_codon_probs_l.append(neutral_codon_probs) - - # Note that our masked out positions will have a nan log probability, - # which will require us to handle them correctly downstream. - self.log_neutral_codon_probss = torch.log(torch.stack(neutral_codon_probs_l)) - - def __getitem__(self, idx): - return { - "codon_parents_idxs": self.codon_parents_idxss[idx], - "codon_children_idxs": self.codon_children_idxss[idx], - "aa_parents_idxs": self.aa_parents_idxss[idx], - "aa_children_idxs": self.aa_children_idxss[idx], - "subs_indicator": self.aa_subs_indicators[idx], - "mask": self.masks[idx], - "log_neutral_codon_probs": self.log_neutral_codon_probss[idx], - "nt_rates": self.nt_ratess[idx], - "nt_csps": self.nt_cspss[idx], - } - - def to(self, device): - self.codon_parents_idxss = self.codon_parents_idxss.to(device) - self.codon_children_idxss = self.codon_children_idxss.to(device) - self.aa_parents_idxss = self.aa_parents_idxss.to(device) - self.aa_children_idxss = self.aa_children_idxss.to(device) - self.aa_subs_indicators = self.aa_subs_indicators.to(device) - self.masks = self.masks.to(device) - self.log_neutral_codon_probss = self.log_neutral_codon_probss.to(device) - self.nt_ratess = self.nt_ratess.to(device) - self.nt_cspss = self.nt_cspss.to(device) - if self.multihit_model is not None: - self.multihit_model = self.multihit_model.to(device) - - -class DCSMBurrito(DXSMBurrito): - - model_type = "dcsm" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.xent_loss = torch.nn.CrossEntropyLoss() - self.stop_codon_zapper = (build_stop_codon_indicator_tensor() * -BIG).to( - self.device - ) - self.aa_codon_indicator_matrix = CODON_AA_INDICATOR_MATRIX.to(self.device).T - - def prediction_pair_of_batch(self, batch): - """Get log neutral codon substitution probabilities and log selection factors - for a batch of data. - - We don't mask on the output, which will thus contain junk in all of the masked - sites. 
- """ - aa_parents_idxs = batch["aa_parents_idxs"].to(self.device) - mask = batch["mask"].to(self.device) - log_neutral_codon_probs = batch["log_neutral_codon_probs"].to(self.device) - if not torch.isfinite(log_neutral_codon_probs[mask]).all(): - raise ValueError( - f"log_neutral_codon_probs has non-finite values at relevant positions: {log_neutral_codon_probs[mask]}" - ) - log_selection_factors = self.selection_factors_of_aa_idxs(aa_parents_idxs, mask) - return log_neutral_codon_probs, log_selection_factors - - def predictions_of_batch(self, batch): - """Make log probability predictions for a batch of data. - - In this case they are log probabilities of codons, which are made to be - probabilities by setting the parent codon to 1 - sum(children). - - After all this, we clip the probabilities below to avoid log(0) issues. - So, in cases when the sum of the children is > 1, we don't give a - normalized probability distribution, but that won't crash the loss - calculation because that step uses softmax. - - Note that make all ambiguous codons nan in the output, ensuring that - they must get properly masked downstream. - """ - log_neutral_codon_probs, log_selection_factors = self.prediction_pair_of_batch( - batch - ) - - # This code block, in other burritos, is done in a separate function, - # but we can't do that here because we need to normalize the - # probabilities in a way that is not possible without having the index - # of the parent codon. Namely, we need to set the parent codon to 1 - - # sum(children). - - # The aa_codon_indicator_matrix lifts things up from aa land to codon land. - log_preds = ( - log_neutral_codon_probs - + log_selection_factors @ self.aa_codon_indicator_matrix - + self.stop_codon_zapper - ) - assert torch.isnan(log_preds).sum() == 0 - - parent_indices = batch["codon_parents_idxs"].to(self.device) # Shape: [B, L] - valid_mask = parent_indices != AMBIGUOUS_CODON_IDX # Shape: [B, L] - - # Convert to linear space so we can add probabilities. - preds = torch.exp(log_preds) - - # Zero out the parent indices in preds, while keeping the computation - # graph intact. - preds_zeroer = torch.ones_like(preds) - preds_zeroer[valid_mask, parent_indices[valid_mask]] = 0.0 - preds = preds * preds_zeroer - - # Calculate the non-parent sum after zeroing out the parent indices. - non_parent_sum = preds[valid_mask, :].sum(dim=-1) - - # Add these parent values back in, again keeping the computation graph intact. - preds_parent = torch.zeros_like(preds) - preds_parent[valid_mask, parent_indices[valid_mask]] = 1.0 - non_parent_sum - preds = preds + preds_parent - - # We have to clamp the predictions to avoid log(0) issues. - preds = torch.clamp(preds, min=torch.finfo(preds.dtype).eps) - - log_preds = torch.log(preds) - - # Set ambiguous codons to nan to make sure that we handle them correctly downstream. 
- log_preds[~valid_mask, :] = float("nan") - - return log_preds - - def loss_of_batch(self, batch): - codon_children_idxs = batch["codon_children_idxs"].to(self.device) - mask = batch["mask"].to(self.device) - - predictions = self.predictions_of_batch(batch)[mask] - assert torch.isnan(predictions).sum() == 0 - codon_children_idxs = codon_children_idxs[mask] - - return self.xent_loss(predictions, codon_children_idxs) diff --git a/netam/ddsm.py b/netam/ddsm.py new file mode 100644 index 00000000..fce19470 --- /dev/null +++ b/netam/ddsm.py @@ -0,0 +1,213 @@ +"""Here we define a model that outputs a vector of 20 amino acid preferences.""" + +import torch +import torch.nn.functional as F + +from netam.common import clamp_probability +from netam.dxsm import DXSMDataset, DXSMBurrito, zap_predictions_along_diagonal +import netam.framework as framework +import netam.molevol as molevol +import netam.sequences as sequences +import copy +from typing import Tuple + + +class DDSMDataset(DXSMDataset): + + def update_neutral_probs(self): + neutral_aa_probs_l = [] + + for nt_parent, mask, nt_rates, nt_csps, branch_length in zip( + self.nt_parents, + self.masks, + self.nt_ratess, + self.nt_cspss, + self._branch_lengths, + ): + mask = mask.to("cpu") + nt_rates = nt_rates.to("cpu") + nt_csps = nt_csps.to("cpu") + if self.multihit_model is not None: + multihit_model = copy.deepcopy(self.multihit_model).to("cpu") + else: + multihit_model = None + # Note we are replacing all Ns with As, which means that we need to be careful + # with masking out these positions later. We do this below. + parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A")) + parent_len = len(nt_parent) + + mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len]) + nt_csps = nt_csps[:parent_len, :] + nt_mask = mask.repeat_interleave(3)[: len(nt_parent)] + molevol.check_csps(parent_idxs[nt_mask], nt_csps[: len(nt_parent)][nt_mask]) + + neutral_aa_probs = molevol.neutral_aa_probs( + parent_idxs.reshape(-1, 3), + mut_probs.reshape(-1, 3), + nt_csps.reshape(-1, 3, 4), + multihit_model=multihit_model, + ) + + if not torch.isfinite(neutral_aa_probs).all(): + print(f"Found a non-finite neutral_aa_probs") + print(f"nt_parent: {nt_parent}") + print(f"mask: {mask}") + print(f"nt_rates: {nt_rates}") + print(f"nt_csps: {nt_csps}") + print(f"branch_length: {branch_length}") + raise ValueError(f"neutral_aa_probs is not finite: {neutral_aa_probs}") + + # Ensure that all values are positive before taking the log later + neutral_aa_probs = clamp_probability(neutral_aa_probs) + + pad_len = self.max_aa_seq_len - neutral_aa_probs.shape[0] + if pad_len > 0: + neutral_aa_probs = F.pad( + neutral_aa_probs, (0, 0, 0, pad_len), value=1e-8 + ) + # Here we zero out masked positions. + neutral_aa_probs *= mask[:, None] + + neutral_aa_probs_l.append(neutral_aa_probs) + + # Note that our masked out positions will have a nan log probability, + # which will require us to handle them correctly downstream. 
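+        # Shape sketch (illustrative): each padded per-sequence tensor above
+        # is (max_aa_seq_len, 20), so the stacked result below has shape
+        # (n_sequences, max_aa_seq_len, 20) before the log is taken.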
+        self.log_neutral_aa_probss = torch.log(torch.stack(neutral_aa_probs_l))
+
+    def __getitem__(self, idx):
+        return {
+            "aa_parents_idxs": self.aa_parents_idxss[idx],
+            "aa_children_idxs": self.aa_children_idxss[idx],
+            "subs_indicator": self.aa_subs_indicators[idx],
+            "mask": self.masks[idx],
+            "log_neutral_aa_probs": self.log_neutral_aa_probss[idx],
+            "nt_rates": self.nt_ratess[idx],
+            "nt_csps": self.nt_cspss[idx],
+        }
+
+    def to(self, device):
+        self.aa_parents_idxss = self.aa_parents_idxss.to(device)
+        self.aa_children_idxss = self.aa_children_idxss.to(device)
+        self.aa_subs_indicators = self.aa_subs_indicators.to(device)
+        self.masks = self.masks.to(device)
+        self.log_neutral_aa_probss = self.log_neutral_aa_probss.to(device)
+        self.nt_ratess = self.nt_ratess.to(device)
+        self.nt_cspss = self.nt_cspss.to(device)
+        if self.multihit_model is not None:
+            self.multihit_model = self.multihit_model.to(device)
+
+
+class DDSMBurrito(framework.TwoLossMixin, DXSMBurrito):
+    model_type = "ddsm"
+
+    def __init__(self, *args, loss_weights: list = [1.0, 0.01], **kwargs):
+        super().__init__(*args, **kwargs)
+        self.xent_loss = torch.nn.CrossEntropyLoss()
+        self.loss_weights = torch.tensor(loss_weights).to(self.device)
+
+    def prediction_pair_of_batch(self, batch):
+        """Get log neutral AA probabilities and log selection factors for a batch of
+        data."""
+        aa_parents_idxs = batch["aa_parents_idxs"].to(self.device)
+        mask = batch["mask"].to(self.device)
+        log_neutral_aa_probs = batch["log_neutral_aa_probs"].to(self.device)
+        if not torch.isfinite(log_neutral_aa_probs[mask]).all():
+            raise ValueError(
+                f"log_neutral_aa_probs has non-finite values at relevant positions: {log_neutral_aa_probs[mask]}"
+            )
+        log_selection_factors = self.selection_factors_of_aa_idxs(aa_parents_idxs, mask)
+        return log_neutral_aa_probs, log_selection_factors
+
+    def predictions_of_pair(self, log_neutral_aa_probs, log_selection_factors):
+        """Take the sum of the neutral mutation log probabilities and the selection
+        factors.
+
+        In contrast to a DNSM, each of these now has a last dimension of 20.
+        """
+        predictions = log_neutral_aa_probs + log_selection_factors
+        assert torch.isnan(predictions).sum() == 0
+        return predictions
+
+    def predictions_of_batch(self, batch):
+        """Make predictions for a batch of data.
+
+        Note that we use the mask for prediction as part of the input for the
+        transformer, though we don't mask the predictions themselves.
+        """
+        log_neutral_aa_probs, log_selection_factors = self.prediction_pair_of_batch(
+            batch
+        )
+        return self.predictions_of_pair(log_neutral_aa_probs, log_selection_factors)
+
+    def loss_of_batch(self, batch):
+        aa_subs_indicator = batch["subs_indicator"].to(self.device)
+        # Netam issue #16: child mask would be preferable here.
+        mask = batch["mask"].to(self.device)
+        aa_parents_idxs = batch["aa_parents_idxs"].to(self.device)
+        aa_children_idxs = batch["aa_children_idxs"].to(self.device)
+        masked_aa_subs_indicator = aa_subs_indicator.masked_select(mask)
+        predictions = self.predictions_of_batch(batch)
+
+        # "Zapping" out the diagonal means setting it to zero in log space by
+        # setting it to -BIG. This is a no-op for sites that have an X
+        # (ambiguous AA) in the parent. This could cause problems in principle,
+        # but in practice we mask out sites with Xs in the parent for the
+        # mut_pos_loss, and we mask out sites with no substitution for the CSP
+        # loss. 
The latter class of sites also eliminates sites that have Xs in
+        # the parent or child (see sequences.aa_subs_indicator_tensor_of).
+
+        predictions = zap_predictions_along_diagonal(predictions, aa_parents_idxs)
+
+        # After zapping out the diagonal, we can effectively sum over the
+        # off-diagonal elements to get the probability of a nonsynonymous
+        # substitution.
+        subs_pos_pred = torch.sum(torch.exp(predictions), dim=-1)
+        subs_pos_pred = subs_pos_pred.masked_select(mask)
+        subs_pos_pred = clamp_probability(subs_pos_pred)
+        subs_pos_loss = self.bce_loss(subs_pos_pred, masked_aa_subs_indicator)
+
+        # We now need to calculate the conditional substitution probability
+        # (CSP) loss. We have already zapped out the diagonal, and we're in
+        # logit space, so we are set up for using the cross entropy loss.
+        # However we have to mask out the sites that are not substituted, i.e.
+        # the sites for which aa_subs_indicator is 0.
+        subs_mask = (aa_subs_indicator == 1) & mask
+        csp_pred = predictions[subs_mask]
+        csp_targets = aa_children_idxs[subs_mask]
+        csp_loss = self.xent_loss(csp_pred, csp_targets)
+        return torch.stack([subs_pos_loss, csp_loss])
+
+    # This is not used anywhere, except for in a few tests. Keeping it around
+    # for that reason.
+    def _build_selection_matrix_from_parent(self, parent: Tuple[str, str]):
+        """Build selection matrices from a parent nucleotide sequence given as a
+        heavy-chain, light-chain pair.
+
+        Values at ambiguous sites are meaningless. Returned value is a tuple of
+        selection matrices for heavy and light chain sequences.
+        """
+        # This is simpler than the equivalent in dnsm.py because we get the selection
+        # matrix directly. Note that selection_factors_of_aa_str does the exponentiation
+        # so this indeed gives us the selection factors, not the log selection factors.
+        aa_parent_pair = tuple(map(sequences.translate_sequence, parent))
+        per_aa_selection_factorss = self.model.selection_factors_of_aa_str(
+            aa_parent_pair
+        )
+
+        result = []
+        for per_aa_selection_factors, aa_parent in zip(
+            per_aa_selection_factorss, aa_parent_pair
+        ):
+            aa_parent_idxs = torch.tensor(sequences.aa_idx_array_of_str(aa_parent))
+            if len(per_aa_selection_factors) > 0:
+                result.append(
+                    zap_predictions_along_diagonal(
+                        per_aa_selection_factors.unsqueeze(0),
+                        aa_parent_idxs.unsqueeze(0),
+                        fill=1.0,
+                    ).squeeze(0)
+                )
+            else:
+                result.append(per_aa_selection_factors)
+
+        return tuple(result)
diff --git a/netam/dnsm.py b/netam/dnsm.py
index abce4255..568b3798 100644
--- a/netam/dnsm.py
+++ b/netam/dnsm.py
@@ -23,10 +23,10 @@ def update_neutral_probs(self):
 
         This is a somewhat vague name, but that's because it includes both the cases
         of the DNSM (in which case it's neutral probabilities of any nonsynonymous
-        mutation) and the DASM (in which case it's the neutral probabilities of mutation
+        mutation) and the DDSM (in which case it's the neutral probabilities of mutation
        to the various amino acids).
 
-        This is the case of the DNSM, but the DASM will override this method.
+        This is the case of the DNSM, but the DDSM will override this method.
""" neutral_aa_mut_prob_l = [] diff --git a/tests/old_models/dasm_13k-v1jaffe+v1tang-joint.pth b/tests/old_models/ddsm_13k-v1jaffe+v1tang-joint.pth similarity index 100% rename from tests/old_models/dasm_13k-v1jaffe+v1tang-joint.pth rename to tests/old_models/ddsm_13k-v1jaffe+v1tang-joint.pth diff --git a/tests/old_models/dasm_13k-v1jaffe+v1tang-joint.yml b/tests/old_models/ddsm_13k-v1jaffe+v1tang-joint.yml similarity index 100% rename from tests/old_models/dasm_13k-v1jaffe+v1tang-joint.yml rename to tests/old_models/ddsm_13k-v1jaffe+v1tang-joint.yml diff --git a/tests/old_models/dasm_output b/tests/old_models/ddsm_output similarity index 100% rename from tests/old_models/dasm_output rename to tests/old_models/ddsm_output diff --git a/tests/test_backward_compat.py b/tests/test_backward_compat.py index 2ec32840..f3941b88 100644 --- a/tests/test_backward_compat.py +++ b/tests/test_backward_compat.py @@ -1,7 +1,7 @@ import torch import pandas as pd import pytest -from netam.dasm import zap_predictions_along_diagonal, DASMBurrito, DASMDataset +from netam.ddsm import zap_predictions_along_diagonal, DDSMBurrito, DDSMDataset from netam.common import force_spawn from tqdm import tqdm @@ -10,18 +10,18 @@ @pytest.fixture(scope="module") -def fixed_dasm_val_burrito(pcp_df): +def fixed_ddsm_val_burrito(pcp_df): force_spawn() """Fixture that returns the DNSM Burrito object.""" pcp_df["in_train"] = True pcp_df.loc[pcp_df.index[-15:], "in_train"] = False - dasm_crepe = load_crepe("tests/old_models/dasm_13k-v1jaffe+v1tang-joint") - model = dasm_crepe.model - train_dataset, val_dataset = DASMDataset.train_val_datasets_of_pcp_df( + ddsm_crepe = load_crepe("tests/old_models/ddsm_13k-v1jaffe+v1tang-joint") + model = ddsm_crepe.model + train_dataset, val_dataset = DDSMDataset.train_val_datasets_of_pcp_df( pcp_df, model.known_token_count ) - burrito = DASMBurrito( + burrito = DDSMBurrito( train_dataset, val_dataset, model, @@ -33,21 +33,21 @@ def fixed_dasm_val_burrito(pcp_df): return burrito -def test_predictions_of_batch(fixed_dasm_val_burrito): +def test_predictions_of_batch(fixed_ddsm_val_burrito): # These outputs were produced by the comparison code in this test, but # written to the files referenced here. The code state was netam 3c632fa. # (however, this test did not exist in the codebase at that time) branch_lengths = torch.tensor( pd.read_csv("tests/old_models/val_branch_lengths.csv")["branch_length"] ).double() - these_branch_lengths = fixed_dasm_val_burrito.val_dataset.branch_lengths.double() + these_branch_lengths = fixed_ddsm_val_burrito.val_dataset.branch_lengths.double() assert torch.allclose(branch_lengths, these_branch_lengths) - fixed_dasm_val_burrito.model.eval() - val_loader = fixed_dasm_val_burrito.build_val_loader() + fixed_ddsm_val_burrito.model.eval() + val_loader = fixed_ddsm_val_burrito.build_val_loader() predictions_list = [] for batch in tqdm(val_loader, desc="Calculating model predictions"): predictions = zap_predictions_along_diagonal( - fixed_dasm_val_burrito.predictions_of_batch(batch), batch["aa_parents_idxs"] + fixed_ddsm_val_burrito.predictions_of_batch(batch), batch["aa_parents_idxs"] ) predictions_list.append(predictions.detach().cpu()) these_predictions = torch.cat(predictions_list, axis=0).double() @@ -62,18 +62,18 @@ def test_predictions_of_batch(fixed_dasm_val_burrito): # https://github.com/matsengrp/netam/pull/92. 
def test_old_crepe_outputs(): example_seq = "QVQLVESGGGVVQPGRSLRLSCAASGFTFSSSGMHWVRQAPGKGLEWVAVIWYDGSNKYYADSVKGRFTISRDNSKNTVYLQMNSLRAEDTAVYYCAREGHSNYPYYYYYMDVWGKGTTVTVSS" - dasm_crepe = load_crepe("tests/old_models/dasm_13k-v1jaffe+v1tang-joint") + ddsm_crepe = load_crepe("tests/old_models/ddsm_13k-v1jaffe+v1tang-joint") dnsm_crepe = load_crepe("tests/old_models/dnsm_13k-v1jaffe+v1tang-joint") - dasm_vals = torch.nan_to_num( + ddsm_vals = torch.nan_to_num( set_wt_to_nan( - torch.load("tests/old_models/dasm_output", weights_only=True), example_seq + torch.load("tests/old_models/ddsm_output", weights_only=True), example_seq ), 0.0, ) dnsm_vals = torch.load("tests/old_models/dnsm_output", weights_only=True) - dasm_result = torch.nan_to_num(dasm_crepe([example_seq])[0], 0.0) + ddsm_result = torch.nan_to_num(ddsm_crepe([example_seq])[0], 0.0) dnsm_result = dnsm_crepe([example_seq])[0] - assert torch.allclose(dasm_result, dasm_vals) + assert torch.allclose(ddsm_result, ddsm_vals) assert torch.allclose(dnsm_result, dnsm_vals) diff --git a/tests/test_dasm.py b/tests/test_dasm.py index bc29b033..0ef0ae7c 100644 --- a/tests/test_dasm.py +++ b/tests/test_dasm.py @@ -3,30 +3,30 @@ import torch import pytest -from netam.common import BIG, force_spawn from netam.framework import ( crepe_exists, load_crepe, ) +from netam.common import BIG, force_spawn from netam.models import TransformerBinarySelectionModelWiggleAct from netam.dasm import ( DASMBurrito, DASMDataset, - zap_predictions_along_diagonal, ) +from netam.dxsm import zap_predictions_along_diagonal from netam.sequences import ( MAX_KNOWN_TOKEN_COUNT, + AA_AMBIG_IDX, TOKEN_STR_SORTED, token_mask_of_aa_idxs, + CODON_AA_INDICATOR_MATRIX, ) -torch.set_printoptions(precision=10) - @pytest.fixture(scope="module", params=["pcp_df", "pcp_df_paired"]) def dasm_burrito(pcp_df): force_spawn() - """Fixture that returns the DNSM Burrito object.""" + """Fixture that returns the DASM Burrito object.""" pcp_df["in_train"] = True pcp_df.loc[pcp_df.index[-15:], "in_train"] = False train_dataset, val_dataset = DASMDataset.train_val_datasets_of_pcp_df( @@ -93,15 +93,19 @@ def test_crepe_roundtrip(dasm_burrito): def test_zap_diagonal(dasm_burrito): batch = dasm_burrito.val_dataset[0:2] - predictions = dasm_burrito.predictions_of_batch(batch) - predictions = torch.cat( - [predictions, torch.zeros_like(predictions[:, :, :1])], dim=-1 - ) + codon_predictions = dasm_burrito.predictions_of_batch(batch) + predictions = torch.log((torch.exp(codon_predictions) @ CODON_AA_INDICATOR_MATRIX)) aa_parents_idxs = batch["aa_parents_idxs"].to(dasm_burrito.device) + # These sites are set to NaN, so we need to make them zero for comparison + invalid_mask = aa_parents_idxs >= AA_AMBIG_IDX # Shape: [B, L] + predictions[invalid_mask] = 0.0 zeroed_predictions = predictions.clone() zeroed_predictions = zap_predictions_along_diagonal( zeroed_predictions, aa_parents_idxs ) + print(predictions.shape, aa_parents_idxs.shape) + print(predictions) + print(aa_parents_idxs) L = predictions.shape[1] for batch_idx in range(2): for i in range(L): @@ -126,29 +130,3 @@ def test_selection_factors_of_aa_str(dasm_burrito): assert len(res[1]) == len(aa_parent_pair[1]) assert res[0].shape[1] == 20 assert res[1].shape[1] == 20 - - -def test_build_selection_matrix_from_parent(dasm_burrito): - parent = dasm_burrito.val_dataset.nt_parents[0] - parent_aa_idxs = dasm_burrito.val_dataset.aa_parents_idxss[0] - aa_mask = dasm_burrito.val_dataset.masks[0] - aa_parent = "".join(TOKEN_STR_SORTED[i] for i in 
parent_aa_idxs)
-    # This won't work if we start testing with ambiguous sequences
-    aa_parent = aa_parent.replace("X", "")
-
-    separator_idx = aa_parent.index("^") * 3
-    light_chain_seq = parent[:separator_idx]
-    heavy_chain_seq = parent[separator_idx + 3 :]
-
-    direct_val = dasm_burrito.build_selection_matrix_from_parent_aa(
-        parent_aa_idxs, aa_mask
-    )
-
-    indirect_val = dasm_burrito._build_selection_matrix_from_parent(
-        (light_chain_seq, heavy_chain_seq)
-    )
-
-    assert torch.allclose(direct_val[: len(indirect_val[0])], indirect_val[0])
-    assert torch.allclose(
-        direct_val[len(indirect_val[0]) + 1 :][: len(indirect_val[1])], indirect_val[1]
-    )
diff --git a/tests/test_dcsm.py b/tests/test_dcsm.py
deleted file mode 100644
index 4a3abbec..00000000
--- a/tests/test_dcsm.py
+++ /dev/null
@@ -1,49 +0,0 @@
-import torch
-import pytest
-
-from netam.common import force_spawn
-from netam.sequences import MAX_KNOWN_TOKEN_COUNT
-from netam.models import TransformerBinarySelectionModelWiggleAct
-from netam.dcsm import (
-    DCSMBurrito,
-    DCSMDataset,
-)
-
-
-@pytest.fixture(scope="module")
-def dcsm_burrito(pcp_df):
-    force_spawn()
-    """Fixture that returns the DNSM Burrito object."""
-    pcp_df["in_train"] = True
-    pcp_df.loc[pcp_df.index[-15:], "in_train"] = False
-    train_dataset, val_dataset = DCSMDataset.train_val_datasets_of_pcp_df(
-        pcp_df, MAX_KNOWN_TOKEN_COUNT
-    )
-
-    model = TransformerBinarySelectionModelWiggleAct(
-        nhead=2,
-        d_model_per_head=4,
-        dim_feedforward=256,
-        layer_count=2,
-        output_dim=20,
-    )
-
-    burrito = DCSMBurrito(
-        train_dataset,
-        val_dataset,
-        model,
-        batch_size=32,
-        learning_rate=0.001,
-        min_learning_rate=0.0001,
-    )
-    burrito.joint_train(
-        epochs=1, cycle_count=2, training_method="full", optimize_bl_first_cycle=False
-    )
-    return burrito
-
-
-def test_parallel_branch_length_optimization(dcsm_burrito):
-    dataset = dcsm_burrito.val_dataset
-    parallel_branch_lengths = dcsm_burrito.find_optimal_branch_lengths(dataset)
-    branch_lengths = dcsm_burrito.serial_find_optimal_branch_lengths(dataset)
-    assert torch.allclose(branch_lengths, parallel_branch_lengths)
diff --git a/tests/test_ddsm.py b/tests/test_ddsm.py
new file mode 100644
index 00000000..d514d0ba
--- /dev/null
+++ b/tests/test_ddsm.py
@@ -0,0 +1,73 @@
+import torch
+import pytest
+
+from netam.common import force_spawn
+from netam.models import TransformerBinarySelectionModelWiggleAct
+from netam.ddsm import (
+    DDSMBurrito,
+    DDSMDataset,
+)
+from netam.sequences import (
+    MAX_KNOWN_TOKEN_COUNT,
+    TOKEN_STR_SORTED,
+)
+
+torch.set_printoptions(precision=10)
+
+
+@pytest.fixture(scope="module", params=["pcp_df", "pcp_df_paired"])
+def ddsm_burrito(pcp_df):
+    force_spawn()
+    """Fixture that returns the DDSM Burrito object."""
+    pcp_df["in_train"] = True
+    pcp_df.loc[pcp_df.index[-15:], "in_train"] = False
+    train_dataset, val_dataset = DDSMDataset.train_val_datasets_of_pcp_df(
+        pcp_df, MAX_KNOWN_TOKEN_COUNT
+    )
+
+    model = TransformerBinarySelectionModelWiggleAct(
+        nhead=2,
+        d_model_per_head=4,
+        dim_feedforward=256,
+        layer_count=2,
+        output_dim=20,
+    )
+
+    burrito = DDSMBurrito(
+        train_dataset,
+        val_dataset,
+        model,
+        batch_size=32,
+        learning_rate=0.001,
+        min_learning_rate=0.0001,
+    )
+    burrito.joint_train(
+        epochs=1, cycle_count=2, training_method="full", optimize_bl_first_cycle=False
+    )
+    return burrito
+
+
+def test_build_selection_matrix_from_parent(ddsm_burrito):
+    parent = ddsm_burrito.val_dataset.nt_parents[0]
+    parent_aa_idxs = ddsm_burrito.val_dataset.aa_parents_idxss[0]
+    aa_mask = 
ddsm_burrito.val_dataset.masks[0] + aa_parent = "".join(TOKEN_STR_SORTED[i] for i in parent_aa_idxs) + # This won't work if we start testing with ambiguous sequences + aa_parent = aa_parent.replace("X", "") + + separator_idx = aa_parent.index("^") * 3 + light_chain_seq = parent[:separator_idx] + heavy_chain_seq = parent[separator_idx + 3 :] + + direct_val = ddsm_burrito.build_selection_matrix_from_parent_aa( + parent_aa_idxs, aa_mask + ) + + indirect_val = ddsm_burrito._build_selection_matrix_from_parent( + (light_chain_seq, heavy_chain_seq) + ) + + assert torch.allclose(direct_val[: len(indirect_val[0])], indirect_val[0]) + assert torch.allclose( + direct_val[len(indirect_val[0]) + 1 :][: len(indirect_val[1])], indirect_val[1] + )
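-- 
A minimal usage sketch for the renamed codon-level model, distilled from the
test fixtures in this patch (here `pcp_df` stands in for the pytest fixture of
the same name; this sketch is illustrative and not part of the commit):

    from netam.common import force_spawn
    from netam.models import TransformerBinarySelectionModelWiggleAct
    from netam.sequences import MAX_KNOWN_TOKEN_COUNT
    from netam.dasm import DASMBurrito, DASMDataset  # formerly netam.dcsm

    force_spawn()
    # Split the parent-child-pair dataframe into train and validation sets.
    pcp_df["in_train"] = True
    pcp_df.loc[pcp_df.index[-15:], "in_train"] = False
    train_dataset, val_dataset = DASMDataset.train_val_datasets_of_pcp_df(
        pcp_df, MAX_KNOWN_TOKEN_COUNT
    )
    # A small transformer selection model, as in the tests.
    model = TransformerBinarySelectionModelWiggleAct(
        nhead=2,
        d_model_per_head=4,
        dim_feedforward=256,
        layer_count=2,
        output_dim=20,
    )
    burrito = DASMBurrito(
        train_dataset,
        val_dataset,
        model,
        batch_size=32,
        learning_rate=0.001,
        min_learning_rate=0.0001,
    )
    # Jointly train model parameters and branch lengths.
    burrito.joint_train(
        epochs=1, cycle_count=2, training_method="full", optimize_bl_first_cycle=False
    )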