diff --git a/netam/codon_prob.py b/netam/codon_prob.py
index 701f5ef7..d9719b8c 100644
--- a/netam/codon_prob.py
+++ b/netam/codon_prob.py
@@ -1,5 +1,5 @@
 import torch
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import Dataset
 import torch.nn as nn
 from tqdm import tqdm
 import numpy as np
@@ -15,7 +15,6 @@
 from netam.common import BASES, stack_heterogeneous, clamp_probability
 import netam.framework as framework
 from netam.framework import Burrito
-from netam.models import AbstractBinarySelectionModel
 
 def hit_class(codon1, codon2):
     return sum(c1 != c2 for c1, c2 in zip(codon1, codon2))
@@ -41,32 +40,32 @@ def hit_class(codon1, codon2):
 # make a dict mapping from codon to triple integer index
 codon_to_idxs = {base_1+base_2+base_3: (i, j, k) for i, base_1 in enumerate(BASES) for j, base_2 in enumerate(BASES) for k, base_3 in enumerate(BASES)}
 
-def hit_class_probs(hit_class_tensor, codon_probs):
-    """
-    Calculate total probabilities for each number of differences between codons.
+# def hit_class_probs(hit_class_tensor, codon_probs):
+#     """
+#     Calculate total probabilities for each number of differences between codons.
 
-    Args:
-    - hit_class_tensor (torch.Tensor): A 4x4x4 integer tensor containing the number of differences
-      between each codon and a reference codon.
-    - codon_probs (torch.Tensor): A 4x4x4 tensor containing the probabilities of various codons.
+#     Args:
+#     - hit_class_tensor (torch.Tensor): A 4x4x4 integer tensor containing the number of differences
+#       between each codon and a reference codon.
+#     - codon_probs (torch.Tensor): A 4x4x4 tensor containing the probabilities of various codons.
 
-    Returns:
-    - total_probs (torch.Tensor): A 1D tensor containing the total probabilities for each number
-      of differences (0 to 3).
-    """
-    total_probs = []
+#     Returns:
+#     - total_probs (torch.Tensor): A 1D tensor containing the total probabilities for each number
+#       of differences (0 to 3).
+# """ +# total_probs = [] - for hit_class in range(4): - # Create a mask of codons with the desired number of differences - mask = hit_class_tensor == hit_class +# for hit_class in range(4): +# # Create a mask of codons with the desired number of differences +# mask = hit_class_tensor == hit_class - # Multiply componentwise with the codon_probs tensor and sum - total_prob = (codon_probs * mask.float()).sum() +# # Multiply componentwise with the codon_probs tensor and sum +# total_prob = (codon_probs * mask.float()).sum() - # Append the total probability to the list - total_probs.append(total_prob.item()) +# # Append the total probability to the list +# total_probs.append(total_prob.item()) - return torch.tensor(total_probs) +# return torch.tensor(total_probs) def hit_class_probs_tensor(parent_codon_idxs, codon_probs): """ @@ -172,15 +171,11 @@ def __init__( trimmed_children = [child[: len(child) - len(child) % 3] for child in nt_children] self.nt_parents = stack_heterogeneous(pd.Series(sequences.nt_idx_tensor_of_str(parent.replace("N", "A")) for parent in trimmed_parents)) self.nt_children = stack_heterogeneous(pd.Series(sequences.nt_idx_tensor_of_str(child.replace("N", "A")) for child in trimmed_children)) - max_len = len(self.nt_parents[0]) - self.nt_parents_strs = [parent + ("N" * (max_len - len(parent))) for parent in trimmed_parents] - self.nt_children_strs = [child + ("N" * (max_len - len(child))) for child in trimmed_children] self.all_rates = stack_heterogeneous(pd.Series(rates[: len(rates) - len(rates) % 3] for rates in all_rates).reset_index(drop=True)) self.all_subs_probs = stack_heterogeneous(pd.Series(subs_probs[: len(subs_probs) - len(subs_probs) % 3] for subs_probs in all_subs_probs).reset_index(drop=True)) assert len(self.nt_parents) == len(self.nt_children) - # TODO get hit classes and do checks directly from tensor encoding of sequences for parent, child in zip(trimmed_parents, trimmed_children): if parent == child: raise ValueError( @@ -230,10 +225,6 @@ def update_hit_class_probs(self): self.all_subs_probs, self.branch_lengths, ): - # This encodes bases as indices in a sorted nucleotide list. Codons containing - # N's should already be masked in self.codon_mask, so treating them as A's here shouldn't matter... - # TODO Check that assertion ^^ - scaled_rates = branch_length * rates codon_probs = codon_probs_of_parent_scaled_rates_and_sub_probs( @@ -242,9 +233,9 @@ def update_hit_class_probs(self): new_hc_probs.append(hit_class_probs_tensor(reshape_for_codons(encoded_parent), codon_probs)) # We must store probability of all hit classes for arguments to cce_loss in loss_of_batch. - self.hit_class_probs = stack_heterogeneous(new_hc_probs, pad_value=-100) + self.hit_class_probs = torch.stack(new_hc_probs) - # A couple of these methods could be moved to a super class, which itself subclasses Dataset + # A couple of these methods could maybe be moved to a super class, which itself subclasses Dataset def export_branch_lengths(self, out_csv_path): pd.DataFrame({"branch_length": self.branch_lengths}).to_csv( out_csv_path, index=False @@ -268,12 +259,14 @@ def __getitem__(self, idx): } def to(self, device): - # TODO update this (and might have to encode sequences as Tensors), if used! 
-        raise NotImplementedError
-        self.codon_mask = self.mask.to(device)
+        self.nt_parents = self.nt_parents.to(device)
+        self.nt_children = self.nt_children.to(device)
+        self.observed_hcs = self.observed_hcs.to(device)
         self.all_rates = self.all_rates.to(device)
         self.all_subs_probs = self.all_subs_probs.to(device)
         self.hit_class_probs = self.hit_class_probs.to(device)
+        self.codon_mask = self.codon_mask.to(device)
+        self.branch_lengths = self.branch_lengths.to(device)
 
 def flatten_and_mask_sequence_codons(input_tensor, codon_mask=None):
     """Flatten first dimension, that is over sequences, to return tensor
@@ -308,6 +301,8 @@ def hyperparameters(self):
         return {}
 
     def forward(self, parent_codon_idxs: torch.Tensor, uncorrected_log_codon_probs: torch.Tensor):
+        """Take a tensor of uncorrected log codon probabilities, one per observed parent
+        codon, and adjust each distribution according to the hit class corrections."""
         hit_class_tensor_t = hit_class_tensor_full[parent_codon_idxs[:, 0],
                                                    parent_codon_idxs[:, 1],
                                                    parent_codon_idxs[:, 2]].int()
@@ -337,8 +332,6 @@ def __init__(
 
         self.cce_loss = torch.nn.CrossEntropyLoss(reduction='mean')
 
-        # For loss want categorical cross-entropy, appears in framework.py for another model
-        # When computing overall log-likelihood will need to account for the different sequence lengths
     def load_branch_lengths(self, in_csv_prefix):
         if self.train_loader is not None:
             self.train_loader.dataset.load_branch_lengths(
@@ -348,8 +341,6 @@ def load_branch_lengths(self, in_csv_prefix):
                 in_csv_prefix + ".val_branch_lengths.csv"
             )
 
-    # Once optimized branch lengths, store the baseline codon-level predictions somewhere. See DNSMBurrito::predictions_of_batch
-    # Rates stay same, and are used to re-compute branch lengths whenever codon probs are adjusted.
     def loss_of_batch(self, batch):
         # different sequence lengths, and codons containing N's, are marked in the mask.
         observed_hcs = batch["observed_hcs"]
@@ -359,33 +350,19 @@
         flat_masked_hit_class_probs = flatten_and_mask_sequence_codons(hit_class_probs, codon_mask=codon_mask)
         flat_masked_observed_hcs = flatten_and_mask_sequence_codons(observed_hcs, codon_mask=codon_mask).long()
         corrections = torch.cat([torch.tensor([0.0]), self.model.values])
-        corrected_probs = flat_masked_hit_class_probs.log() + corrections
-        corrected_probs = (corrected_probs - torch.logsumexp(corrected_probs, dim=1, keepdim=True)).exp()
+        scaled_log_probs = flat_masked_hit_class_probs.log() + corrections
+        corrected_probs = (scaled_log_probs - torch.logsumexp(scaled_log_probs, dim=1, keepdim=True)).exp()
         assert torch.isfinite(corrected_probs).all()
         adjusted_probs = clamp_probability(corrected_probs)
         logits = torch.log(adjusted_probs / (1 - adjusted_probs))
-        # Just need to adjust hit class probs by model coefficients, and re-normalize.
-
         return self.cce_loss(logits, flat_masked_observed_hcs)
 
-        # nt_parents = batch["nt_parents"]
-        # nt_children = batch["nt_children"]
-        # brlens = batch["branch_lengths"]
-        # codon_mask = batch["codon_mask"]
-        # rates = batch["rates"]
-        # subs_probs = batch["subs_probs"]
-        # scaled_rates = rates * brlens
-        # codon_probs = torch.tensor([codon_probs_of_parent_scaled_rates_and_sub_probs(parent_idxs, scaled_rates_it, subs_probs_it)
-        #                             for parent_idxs, scaled_rates_it, subs_probs_it in zip(nt_parents, scaled_rates, subs_probs)])
-
-    # These are from DNSMBurrito, as a start
     def _find_optimal_branch_length(
         self,
         parent_idxs,
         child_idxs,
-        observed_hcs,
         rates,
         subs_probs,
         codon_mask,
@@ -393,9 +370,6 @@
         **optimization_kwargs,
     ):
 
-        # # A stand-in for the adjustment model we're fitting:
-        # codon_adjustment = self.model.values
-
        def log_pcp_probability(log_branch_length):
             # We want to first return the log-probability of the observed branch, using codon probs.
             # Then we'll want to adjust codon probs using our hit class probabilities
@@ -411,21 +385,12 @@
             child_codon_idxs = reshape_for_codons(child_idxs)[codon_mask]
             parent_codon_idxs = reshape_for_codons(parent_idxs)[codon_mask]
 
-            corrected_codon_probs = self.model(parent_codon_idxs, codon_probs.log())
-            child_codon_probs = corrected_codon_probs[torch.arange(child_codon_idxs.size(0)), child_codon_idxs[:, 0], child_codon_idxs[:, 1], child_codon_idxs[:, 2]]
-            return child_codon_probs.sum()
-
-            # # hc_probs is a Cx4 tensor containing codon probs aggregated by hit class
-            # hc_probs = hit_class_probs_tensor(parent_codon_idxs, codon_probs)
-
-            # # Add fixed 1 adjustment for hit class 0:
-            # _adjust = torch.cat([torch.tensor([1]), codon_adjustment])
-            # # Get adjustments for each site's observed hit class
-            # observed_hc_adjustments = _adjust.gather(0, observed_hcs[codon_mask])
-            # numerators = (child_codon_probs * observed_hc_adjustments).log()
-            # # This is a dot product of the distribution and the adjustments at each site
-            # denominators = (torch.matmul(hc_probs, _adjust)).log()
-            # return (numerators - denominators).sum()
+            corrected_codon_log_probs = self.model(parent_codon_idxs, codon_probs.log())
+            child_codon_log_probs = corrected_codon_log_probs[torch.arange(child_codon_idxs.size(0)),
+                                                              child_codon_idxs[:, 0],
+                                                              child_codon_idxs[:, 1],
+                                                              child_codon_idxs[:, 2]]
+            return child_codon_log_probs.sum()
 
         return optimize_branch_length(
@@ -437,11 +402,10 @@
 
     def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
         optimal_lengths = []
-        for parent_idxs, child_idxs, observed_hcs, rates, subs_probs, codon_mask, starting_length in tqdm(
+        for parent_idxs, child_idxs, rates, subs_probs, codon_mask, starting_length in tqdm(
             zip(
                 dataset.nt_parents,
                 dataset.nt_children,
-                dataset.observed_hcs,
                 dataset.all_rates,
                 dataset.all_subs_probs,
                 dataset.codon_mask,
@@ -454,7 +418,6 @@
                 self._find_optimal_branch_length(
                     parent_idxs,
                     child_idxs,
-                    observed_hcs,
                    rates[: len(parent_idxs)],
                     subs_probs[: len(parent_idxs), :],
                     codon_mask,
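
For concreteness, here is a minimal standalone sketch (not part of the patch) of the hit-class bookkeeping the changed code relies on: aggregate a 4x4x4 codon distribution by hit class, then apply per-hit-class corrections in log space and renormalize with logsumexp, as loss_of_batch does above. The BASES ordering, the toy parent codon, the helper name hit_class_tensor_for, and the correction values are all invented for illustration.

import torch

BASES = "ACGT"  # assumed nucleotide ordering, matching the (i, j, k) indexing above

def hit_class_tensor_for(parent_codon):
    # 4x4x4 integer tensor: number of mismatches between each codon and the parent.
    idxs = [BASES.index(b) for b in parent_codon]
    hc = torch.zeros(4, 4, 4, dtype=torch.int64)
    for i in range(4):
        for j in range(4):
            for k in range(4):
                hc[i, j, k] = int(i != idxs[0]) + int(j != idxs[1]) + int(k != idxs[2])
    return hc

# Toy distribution over the 64 child codons of parent "ATG": uniform.
codon_probs = torch.full((4, 4, 4), 1.0 / 64)
hc = hit_class_tensor_for("ATG")

# Aggregate codon probabilities by hit class (0..3 mismatches).
# For the uniform toy case this gives [1/64, 9/64, 27/64, 27/64].
hc_probs = torch.stack([codon_probs[hc == c].sum() for c in range(4)])

# Apply per-hit-class corrections in log space and renormalize, mirroring
# loss_of_batch: hit class 0 gets a fixed 0.0 correction, the rest play the
# role of self.model.values (invented numbers here).
corrections = torch.cat([torch.tensor([0.0]), torch.tensor([0.5, -0.2, 0.1])])
scaled_log_probs = hc_probs.log() + corrections
corrected_probs = (scaled_log_probs - torch.logsumexp(scaled_log_probs, dim=0)).exp()
assert torch.isclose(corrected_probs.sum(), torch.tensor(1.0))

Fixing the hit-class-0 correction at 0.0, as the patch does, makes the learned values act as log-scale adjustments relative to the no-mutation class.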
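Similarly, the gather that log_pcp_probability now performs reduces to the following shape-level sketch: index the per-site corrected log codon probabilities by each observed child codon and sum, giving the branch log-likelihood handed to optimize_branch_length. The tensors are random stand-ins and C, the number of unmasked codon sites, is arbitrary.

import torch

C = 5  # number of unmasked codon sites; arbitrary for the sketch
corrected_codon_log_probs = torch.log_softmax(torch.randn(C, 64), dim=1).reshape(C, 4, 4, 4)
child_codon_idxs = torch.randint(0, 4, (C, 3))  # stand-in for reshape_for_codons output

# Advanced indexing: one log probability per site, for the observed child codon.
child_codon_log_probs = corrected_codon_log_probs[torch.arange(C),
                                                  child_codon_idxs[:, 0],
                                                  child_codon_idxs[:, 1],
                                                  child_codon_idxs[:, 2]]
log_pcp = child_codon_log_probs.sum()  # the quantity the branch length optimizer maximizes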