diff --git a/netam/hit_class.py b/netam/hit_class.py index e67c1e54..5e53bb1b 100644 --- a/netam/hit_class.py +++ b/netam/hit_class.py @@ -29,11 +29,11 @@ def parent_specific_hit_classes(parent_codon_idxs: torch.Tensor) -> torch.Tensor: """Produce a tensor containing the hit classes of all possible child codons, for each passed parent codon. - Parameters: - parent_codon_idxs (torch.Tensor): A (codon_count, 3) shaped tensor containing for each codon, the - indices of the parent codon's nucleotides. + Args: + parent_codon_idxs (torch.Tensor): A (codon_count, 3) shaped tensor containing for each codon, the + indices of the parent codon's nucleotides. Returns: - torch.Tensor: A (codon_count, 4, 4, 4) shaped tensor containing the hit classes of each possible child codon for each parent codon. + torch.Tensor: A (codon_count, 4, 4, 4) shaped tensor containing the hit classes of each possible child codon for each parent codon. """ return hit_class_tensor[ parent_codon_idxs[:, 0], parent_codon_idxs[:, 1], parent_codon_idxs[:, 2] @@ -49,17 +49,17 @@ def apply_multihit_correction( Suppose there are N codons, then the parameters are as follows: - Parameters: - parent_codon_idxs (torch.Tensor): A (N, 3) shaped tensor containing for each codon, the - indices of the parent codon's nucleotides. - codon_logprobs (torch.Tensor): A (N, 4, 4, 4) shaped tensor containing the log probabilities - of mutating to each possible target codon, for each of the N parent codons. - hit_class_factors (torch.Tensor): A tensor containing the log hit class factors for hit classes 1, 2, and 3. The - factor for hit class 0 is assumed to be 1 (that is, 0 in log-space). + Args: + parent_codon_idxs (torch.Tensor): A (N, 3) shaped tensor containing for each codon, the + indices of the parent codon's nucleotides. + codon_logprobs (torch.Tensor): A (N, 4, 4, 4) shaped tensor containing the log probabilities + of mutating to each possible target codon, for each of the N parent codons. + hit_class_factors (torch.Tensor): A tensor containing the log hit class factors for hit classes 1, 2, and 3. The + factor for hit class 0 is assumed to be 1 (that is, 0 in log-space). Returns: - torch.Tensor: A (N, 4, 4, 4) shaped tensor containing the log probabilities of mutating to each possible - target codon, for each of the N parent codons, after applying the hit class factors. + torch.Tensor: A (N, 4, 4, 4) shaped tensor containing the log probabilities of mutating to each possible + target codon, for each of the N parent codons, after applying the hit class factors. """ per_parent_hit_class = parent_specific_hit_classes(parent_codon_idxs) corrections = torch.cat([torch.tensor([0.0]), hit_class_factors]) @@ -77,17 +77,18 @@ def hit_class_probs_tensor( """ Calculate probabilities of hit classes between parent codons and all other codons for all the sites of a sequence. - Parameters: - parent_codon_idxs (torch.Tensor): The parent nucleotide sequence encoded as a tensor of shape (codon_count, 3), containing the nt indices of each codon. - codon_probs (torch.Tensor): A (codon_count, 4, 4, 4) shaped tensor containing the probabilities of various codons, for each codon in parent seq. + Args: + parent_codon_idxs (torch.Tensor): The parent nucleotide sequence encoded as a tensor of shape (codon_count, 3), + containing the nt indices of each codon. + codon_probs (torch.Tensor): A (codon_count, 4, 4, 4) shaped tensor containing the probabilities of various + codons, for each codon in parent seq. 
Returns: - probs (torch.Tensor): A tensor containing the probabilities of different - counts of hit classes between parent codons and - all other codons, with shape (codon_count, 4). + probs (torch.Tensor): A tensor containing the probabilities of different + counts of hit classes between parent codons and + all other codons, with shape (codon_count, 4). Notes: - Uses hit_class_tensor (torch.Tensor): A 4x4x4x4x4x4 tensor which when indexed with a parent codon produces the hit classes to all possible child codons. """ diff --git a/netam/molevol.py b/netam/molevol.py index e8e1b494..cee380a3 100644 --- a/netam/molevol.py +++ b/netam/molevol.py @@ -235,17 +235,15 @@ def codon_probs_of_parent_scaled_rates_and_sub_probs( the codon level rather than moving to syn/nonsyn changes. Args: - parent_idxs (torch.Tensor): The parent nucleotide sequence encoded as a - tensor of length Cx3, where C is the number of codons, containing the nt indices of each site. - scaled_rates (torch.Tensor): Poisson rates of mutation per site, scaled by branch length. - sub_probs (torch.Tensor): Substitution probabilities per site: a 2D tensor with shape (site_count, 4). + parent_idxs (torch.Tensor): The parent nucleotide sequence encoded as a + tensor of length Cx3, where C is the number of codons, containing the nt indices of each site. + scaled_rates (torch.Tensor): Poisson rates of mutation per site, scaled by branch length. + sub_probs (torch.Tensor): Substitution probabilities per site: a 2D tensor with shape (site_count, 4). Returns: - torch.Tensor: A 4D tensor with shape (codon_count, 4, 4, 4) where the cijk-th entry is the probability - of the c'th codon mutating to the codon ijk. + torch.Tensor: A 4D tensor with shape (codon_count, 4, 4, 4) where the cijk-th entry is the probability + of the c'th codon mutating to the codon ijk. """ - # The following four lines are duplicated from - # aaprobs_of_parent_scaled_rates_and_sub_probs mut_probs = 1.0 - torch.exp(-scaled_rates) parent_codon_idxs = reshape_for_codons(parent_idxs) codon_mut_probs = reshape_for_codons(mut_probs) diff --git a/netam/multihit.py b/netam/multihit.py index b886166c..2eb521d8 100644 --- a/netam/multihit.py +++ b/netam/multihit.py @@ -1,9 +1,14 @@ -"""Burrito and Dataset classes for training a model to predict simple hit class corrections to codon probabilities. +"""Burrito and Dataset classes for training a model to predict simple hit class +corrections to codon probabilities. -Each codon mutation is hit class 0, 1, 2, or 3, corresponding to 0, 1, 2, or 3 mutations in the codon. +Each codon mutation is hit class 0, 1, 2, or 3, corresponding to 0, 1, +2, or 3 mutations in the codon. -The hit class corrections are three scalar values, one for each nonzero hit class. -To apply the correction to existing codon probability predictions, we multiply the probability of each child codon by the correction factor for its hit class, then renormalize. The correction factor for hit class 0 is fixed at 1. +The hit class corrections are three scalar values, one for each nonzero +hit class. To apply the correction to existing codon probability +predictions, we multiply the probability of each child codon by the +correction factor for its hit class, then renormalize. The correction +factor for hit class 0 is fixed at 1. 
""" import torch @@ -25,19 +30,37 @@ from netam.models import HitClassModel -def _trim_seqs_to_codon_boundary_and_max_len( - seqs: list, site_count: int = None -) -> list: - """Assumes that all sequences have the same length, and trims to codon boundary. - If site_count is None, does not enforce a maximum length.""" - if site_count is None: +def _trim_to_codon_boundary_and_max_len( + seqs: list[Sequence], max_len: int = None +) -> list[Sequence]: + """Trims sequences to codon boundary and maximum length. + + No assumption is made about the data of a sequence, other than that it is + indexable (string or list of nucleotide indices both work). + + `max_len` is the maximum number of nucleotides to be preserved. + If `max_len` is None, does not enforce a maximum length. + """ + if max_len is None: return [seq[: len(seq) - len(seq) % 3] for seq in seqs] else: - max_len = site_count - site_count % 3 - return [seq[: min(len(seq) - len(seq) % 3, max_len)] for seq in seqs] + max_codon_len = max_len - max_len % 3 + return [seq[: min(len(seq) - len(seq) % 3, max_codon_len)] for seq in seqs] + +def _observed_hit_classes(parents: Sequence[str], children: Sequence[str]): + """Compute the observed hit classes between parent and child sequences. -def _observed_hit_classes(parents, children): + Args: + parents (Sequence[str]): A list of parent sequences. + children (Sequence[str]): A list of the corresponding child sequences. + + Returns: + list[torch.Tensor]: A list of tensors, each containing the observed + hit classes for each codon in the parent sequence. At any codon position + where the parent or child sequence contains an N, the corresponding tensor + element will be -100. + """ labels = [] for parent_seq, child_seq in zip(parents, children): @@ -62,7 +85,7 @@ def _observed_hit_classes(parents, children): padded_mutations = num_mutations[:codon_count] # Truncate if necessary padded_mutations += [-100] * ( codon_count - len(padded_mutations) - ) # Pad with -1s + ) # Pad with -100s # Update the labels tensor for this row labels.append(torch.tensor(padded_mutations, dtype=torch.int)) @@ -78,8 +101,8 @@ def __init__( all_subs_probs: Sequence[list[list[float]]], branch_length_multiplier: float = 1.0, ): - trimmed_parents = _trim_seqs_to_codon_boundary_and_max_len(nt_parents) - trimmed_children = _trim_seqs_to_codon_boundary_and_max_len(nt_children) + trimmed_parents = _trim_to_codon_boundary_and_max_len(nt_parents) + trimmed_children = _trim_to_codon_boundary_and_max_len(nt_children) self.nt_parents = stack_heterogeneous( pd.Series( sequences.nt_idx_tensor_of_str(parent.replace("N", "A")) @@ -93,15 +116,14 @@ def __init__( ) ) self.all_rates = stack_heterogeneous( - pd.Series( - rates[: len(rates) - len(rates) % 3] for rates in all_rates - ).reset_index(drop=True) + pd.Series(_trim_to_codon_boundary_and_max_len(all_rates)).reset_index( + drop=True + ) ) self.all_subs_probs = stack_heterogeneous( - pd.Series( - subs_probs[: len(subs_probs) - len(subs_probs) % 3] - for subs_probs in all_subs_probs - ).reset_index(drop=True) + pd.Series(_trim_to_codon_boundary_and_max_len(all_subs_probs)).reset_index( + drop=True + ) ) assert len(self.nt_parents) == len(self.nt_children) @@ -146,7 +168,8 @@ def branch_lengths(self, new_branch_lengths): self.update_hit_class_probs() def update_hit_class_probs(self): - """Compute hit class probabilities for all codons in each sequence based on current branch lengths""" + """Compute hit class probabilities for all codons in each sequence + based on current branch lengths.""" 
new_codon_probs = [] new_hc_probs = [] for ( @@ -215,10 +238,13 @@ def to(self, device): def flatten_and_mask_sequence_codons( input_tensor: torch.Tensor, codon_mask: torch.Tensor = None ): - """Flatten first dimension of input_tensor, applying codon_mask first if provided. + """Flatten first dimension of input_tensor, applying codon_mask first if + provided. - This is useful for input_tensors whose first dimension represents sequences, and whose second dimension represents - codons. The resulting tensor will then aggregate the codons of all sequences into the first dimension. + This is useful for input_tensors whose first dimension represents + sequences, and whose second dimension represents codons. The + resulting tensor will then aggregate the codons of all sequences + into the first dimension. """ flat_input = input_tensor.flatten(start_dim=0, end_dim=1) if codon_mask is not None: @@ -228,15 +254,16 @@ def flatten_and_mask_sequence_codons( def child_codon_probs_from_per_parent_probs(per_parent_probs, child_codon_idxs): - """Calculate the probability of each child codon given the parent codon probabilities. + """Calculate the probability of each child codon given the parent codon + probabilities. - Parameters: - per_parent_probs (torch.Tensor): A (codon_count, 4, 4, 4) shaped tensor containing the probabilities - of each possible target codon, for each parent codon. - child_codon_idxs (torch.Tensor): A (codon_count, 3) shaped tensor containing the nucleotide indices for each child codon. + Args: + per_parent_probs (torch.Tensor): A (codon_count, 4, 4, 4) shaped tensor containing the probabilities + of each possible target codon, for each parent codon. + child_codon_idxs (torch.Tensor): A (codon_count, 3) shaped tensor containing the nucleotide indices for each child codon. Returns: - torch.Tensor: A (codon_count,) shaped tensor containing the probabilities of each child codon. + torch.Tensor: A (codon_count,) shaped tensor containing the probabilities of each child codon. """ return per_parent_probs[ torch.arange(child_codon_idxs.size(0)), @@ -249,18 +276,18 @@ def child_codon_probs_from_per_parent_probs(per_parent_probs, child_codon_idxs): def child_codon_logprobs_corrected( uncorrected_per_parent_logprobs, parent_codon_idxs, child_codon_idxs, model ): - """Calculate the probability of each child codon given the parent codon probabilities, corrected by hit class factors. - - Parameters: + """Calculate the probability of each child codon given the parent codon + probabilities, corrected by hit class factors. - uncorrected_per_parent_logprobs (torch.Tensor): A (codon_count, 4, 4, 4) shaped tensor containing the log probabilities - of each possible target codon, for each parent codon. - parent_codon_idxs (torch.Tensor): A (codon_count, 3) shaped tensor containing the nucleotide indices for each parent codon - child_codon_idxs (torch.Tensor): A (codon_count, 3) shaped tensor containing the nucleotide indices for each child codon - model: A HitClassModel implementing the desired correction. + Args: + uncorrected_per_parent_logprobs (torch.Tensor): A (codon_count, 4, 4, 4) shaped tensor containing the log probabilities + of each possible target codon, for each parent codon. + parent_codon_idxs (torch.Tensor): A (codon_count, 3) shaped tensor containing the nucleotide indices for each parent codon + child_codon_idxs (torch.Tensor): A (codon_count, 3) shaped tensor containing the nucleotide indices for each child codon + model: A HitClassModel implementing the desired correction. 
Returns: - torch.Tensor: A (codon_count,) shaped tensor containing the corrected log probabilities of each child codon. + torch.Tensor: A (codon_count,) shaped tensor containing the corrected log probabilities of each child codon. """ corrected_per_parent_logprobs = model( @@ -427,7 +454,8 @@ def hit_class_dataset_from_pcp_df( def train_test_datasets_of_pcp_df( pcp_df: pd.DataFrame, train_frac: float = 0.8, branch_length_multiplier: float = 1.0 ) -> tuple[HitClassDataset, HitClassDataset]: - """Splits a pcp_df prepared by `prepare_pcp_df` into a training and testing HitClassDataset.""" + """Splits a pcp_df prepared by `prepare_pcp_df` into a training and testing + HitClassDataset.""" nt_parents = pcp_df["parent"].reset_index(drop=True) nt_children = pcp_df["child"].reset_index(drop=True) rates = pcp_df["rates"].reset_index(drop=True) @@ -464,18 +492,14 @@ def train_test_datasets_of_pcp_df( def prepare_pcp_df( pcp_df: pd.DataFrame, crepe: framework.Crepe, site_count: int ) -> pd.DataFrame: - """ - Trim parent and child sequences in pcp_df to codon boundaries - and add the rates and substitution probabilities. + """Trim parent and child sequences in pcp_df to codon boundaries and add + the rates and substitution probabilities. - Returns the modified dataframe, which is the input dataframe modified in-place. + Returns the modified dataframe, which is the input dataframe + modified in-place. """ - pcp_df["parent"] = _trim_seqs_to_codon_boundary_and_max_len( - pcp_df["parent"], site_count - ) - pcp_df["child"] = _trim_seqs_to_codon_boundary_and_max_len( - pcp_df["child"], site_count - ) + pcp_df["parent"] = _trim_to_codon_boundary_and_max_len(pcp_df["parent"], site_count) + pcp_df["child"] = _trim_to_codon_boundary_and_max_len(pcp_df["child"], site_count) pcp_df = pcp_df[pcp_df["parent"] != pcp_df["child"]].reset_index(drop=True) ratess, cspss = framework.trimmed_shm_model_outputs_of_crepe( crepe, pcp_df["parent"]
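
Reviewer note: the docstrings in this patch describe the hit-class correction in words; the sketch below is a minimal, self-contained illustration of the same idea, assuming nucleotide indices 0-3. It works in probability space (the module docstring's "multiply by the correction factor, then renormalize") rather than the log space used by `apply_multihit_correction`, and it counts mismatching codon positions directly instead of indexing netam's precomputed `hit_class_tensor`. The names `hit_classes_of_parent` and `correct_codon_probs` are illustrative only, not part of the netam API.

```python
# Minimal sketch of the hit-class correction, in probability space.
# Illustrative names only; the real implementation indexes netam's
# precomputed `hit_class_tensor` and corrects log probabilities via
# `apply_multihit_correction`.
import torch


def hit_classes_of_parent(parent_codon_idxs: torch.Tensor) -> torch.Tensor:
    """For each (N, 3) parent codon, return an (N, 4, 4, 4) tensor whose
    ijk-th entry is the number of positions at which child codon ijk differs
    from the parent codon, i.e. its hit class (0, 1, 2, or 3)."""
    nt = torch.arange(4)
    diff0 = (parent_codon_idxs[:, 0, None] != nt).long()  # (N, 4)
    diff1 = (parent_codon_idxs[:, 1, None] != nt).long()
    diff2 = (parent_codon_idxs[:, 2, None] != nt).long()
    return (
        diff0[:, :, None, None] + diff1[:, None, :, None] + diff2[:, None, None, :]
    )


def correct_codon_probs(
    parent_codon_idxs: torch.Tensor,
    codon_probs: torch.Tensor,
    hit_class_factors: torch.Tensor,
) -> torch.Tensor:
    """Multiply each child codon probability by the factor for its hit class
    (the factor for hit class 0 is fixed at 1), then renormalize over the
    64 possible child codons of each parent codon."""
    factors = torch.cat([torch.tensor([1.0]), hit_class_factors])  # classes 0-3
    hit_classes = hit_classes_of_parent(parent_codon_idxs)  # (N, 4, 4, 4)
    reweighted = codon_probs * factors[hit_classes]
    return reweighted / reweighted.sum(dim=(1, 2, 3), keepdim=True)


if __name__ == "__main__":
    parent = torch.tensor([[0, 1, 2]])          # one parent codon, e.g. "ACG"
    probs = torch.full((1, 4, 4, 4), 1.0 / 64)  # uniform toy codon distribution
    factors = torch.tensor([0.5, 0.2, 0.1])     # toy factors for hit classes 1-3
    corrected = correct_codon_probs(parent, probs, factors)
    print(corrected.sum())  # ~1.0: still a distribution after reweighting
```

In netam itself the equivalent correction happens in log space: as the hunk in `hit_class.py` shows, `apply_multihit_correction` prepends a 0.0 log factor for hit class 0 (`torch.cat([torch.tensor([0.0]), hit_class_factors])`) and applies the per-hit-class corrections to the codon log probabilities.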
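Similarly, a sketch of the per-pair labels that the new `_observed_hit_classes` docstring documents: one integer per parent codon, counting mismatched positions against the child, with -100 wherever the parent or child codon contains an N. `observed_hit_classes_of_pair` is a hypothetical name for illustration; the real function operates on lists of sequences and additionally truncates or pads each row to a fixed codon count with -100.

```python
# Sketch of the observed hit-class labels described in `_observed_hit_classes`:
# one value per codon, -100 marking codons containing an N (ignored downstream).
# `observed_hit_classes_of_pair` is an illustrative name, not the netam API.
import torch


def observed_hit_classes_of_pair(parent: str, child: str) -> torch.Tensor:
    classes = []
    # Step through codon-aligned triplets, dropping any trailing partial codon.
    for i in range(0, len(parent) - len(parent) % 3, 3):
        p, c = parent[i : i + 3], child[i : i + 3]
        if "N" in p or "N" in c:
            classes.append(-100)  # ambiguous codon
        else:
            classes.append(sum(a != b for a, b in zip(p, c)))
    return torch.tensor(classes, dtype=torch.int)


# observed_hit_classes_of_pair("ACGTTTNAA", "AAGTTTCAA") -> tensor([1, 0, -100])
```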