
Commit

respond to Erick's comments
willdumm committed Sep 12, 2024
1 parent a7a6ee0 commit 49b3594
Showing 3 changed files with 104 additions and 81 deletions.
41 changes: 21 additions & 20 deletions netam/hit_class.py
@@ -29,11 +29,11 @@
def parent_specific_hit_classes(parent_codon_idxs: torch.Tensor) -> torch.Tensor:
"""Produce a tensor containing the hit classes of all possible child codons, for each passed parent codon.
Parameters:
parent_codon_idxs (torch.Tensor): A (codon_count, 3) shaped tensor containing for each codon, the
indices of the parent codon's nucleotides.
Args:
parent_codon_idxs (torch.Tensor): A (codon_count, 3) shaped tensor containing for each codon, the
indices of the parent codon's nucleotides.
Returns:
torch.Tensor: A (codon_count, 4, 4, 4) shaped tensor containing the hit classes of each possible child codon for each parent codon.
torch.Tensor: A (codon_count, 4, 4, 4) shaped tensor containing the hit classes of each possible child codon for each parent codon.
"""
return hit_class_tensor[
parent_codon_idxs[:, 0], parent_codon_idxs[:, 1], parent_codon_idxs[:, 2]
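The return expression above is truncated in this view. As orientation, here is a minimal, self-contained sketch of what the (codon_count, 4, 4, 4) result holds; it computes hit classes directly rather than indexing the precomputed hit_class_tensor the way netam does, and the function name is illustrative.

import torch

def hit_classes_for_parents(parent_codon_idxs: torch.Tensor) -> torch.Tensor:
    # Enumerate all 4 * 4 * 4 candidate child codons, shape (4, 4, 4, 3).
    nts = torch.arange(4)
    children = torch.stack(torch.meshgrid(nts, nts, nts, indexing="ij"), dim=-1)
    # Broadcast each parent codon against every child codon and count mismatches.
    parents = parent_codon_idxs[:, None, None, None, :]  # (codon_count, 1, 1, 1, 3)
    return (children.unsqueeze(0) != parents).sum(dim=-1)  # (codon_count, 4, 4, 4)

# With A, C, G, T indexed 0-3: parent AAA is hit class 0 to itself,
# 1 to GAA, and 3 to GGG.
hc = hit_classes_for_parents(torch.tensor([[0, 0, 0]]))
assert hc[0, 0, 0, 0] == 0 and hc[0, 2, 0, 0] == 1 and hc[0, 2, 2, 2] == 3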
@@ -49,17 +49,17 @@ def apply_multihit_correction(
Suppose there are N codons, then the parameters are as follows:
Parameters:
parent_codon_idxs (torch.Tensor): A (N, 3) shaped tensor containing for each codon, the
indices of the parent codon's nucleotides.
codon_logprobs (torch.Tensor): A (N, 4, 4, 4) shaped tensor containing the log probabilities
of mutating to each possible target codon, for each of the N parent codons.
hit_class_factors (torch.Tensor): A tensor containing the log hit class factors for hit classes 1, 2, and 3. The
factor for hit class 0 is assumed to be 1 (that is, 0 in log-space).
Args:
parent_codon_idxs (torch.Tensor): A (N, 3) shaped tensor containing for each codon, the
indices of the parent codon's nucleotides.
codon_logprobs (torch.Tensor): A (N, 4, 4, 4) shaped tensor containing the log probabilities
of mutating to each possible target codon, for each of the N parent codons.
hit_class_factors (torch.Tensor): A tensor containing the log hit class factors for hit classes 1, 2, and 3. The
factor for hit class 0 is assumed to be 1 (that is, 0 in log-space).
Returns:
torch.Tensor: A (N, 4, 4, 4) shaped tensor containing the log probabilities of mutating to each possible
target codon, for each of the N parent codons, after applying the hit class factors.
torch.Tensor: A (N, 4, 4, 4) shaped tensor containing the log probabilities of mutating to each possible
target codon, for each of the N parent codons, after applying the hit class factors.
"""
per_parent_hit_class = parent_specific_hit_classes(parent_codon_idxs)
corrections = torch.cat([torch.tensor([0.0]), hit_class_factors])
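The rest of the function body is collapsed in this view. As a hedged sketch of the step being described, assuming the log factors are added by hit class and the result is renormalized over the 64 child codons (the module docstring of netam/multihit.py describes the renormalization; whether it happens in this function or in a caller is not visible here), with an illustrative function name:

import torch

def apply_log_hit_class_factors(codon_logprobs, per_parent_hit_class, hit_class_factors):
    # Hit class 0 keeps log factor 0, i.e. a multiplicative factor of 1.
    corrections = torch.cat([torch.zeros(1), hit_class_factors])
    adjusted = codon_logprobs + corrections[per_parent_hit_class.long()]
    # Renormalize each parent codon's 4x4x4 block of child-codon log probabilities.
    log_norm = torch.logsumexp(adjusted.flatten(start_dim=1), dim=1)
    return adjusted - log_norm.view(-1, 1, 1, 1)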
@@ -77,17 +77,18 @@ def hit_class_probs_tensor(
"""
Calculate probabilities of hit classes between parent codons and all other codons for all the sites of a sequence.
Parameters:
parent_codon_idxs (torch.Tensor): The parent nucleotide sequence encoded as a tensor of shape (codon_count, 3), containing the nt indices of each codon.
codon_probs (torch.Tensor): A (codon_count, 4, 4, 4) shaped tensor containing the probabilities of various codons, for each codon in parent seq.
Args:
parent_codon_idxs (torch.Tensor): The parent nucleotide sequence encoded as a tensor of shape (codon_count, 3),
containing the nt indices of each codon.
codon_probs (torch.Tensor): A (codon_count, 4, 4, 4) shaped tensor containing the probabilities of various
codons, for each codon in parent seq.
Returns:
probs (torch.Tensor): A tensor containing the probabilities of different
counts of hit classes between parent codons and
all other codons, with shape (codon_count, 4).
probs (torch.Tensor): A tensor containing the probabilities of different
counts of hit classes between parent codons and
all other codons, with shape (codon_count, 4).
Notes:
Uses hit_class_tensor (torch.Tensor): A 4x4x4x4x4x4 tensor which when indexed with a parent codon produces
the hit classes to all possible child codons.
"""
14 changes: 6 additions & 8 deletions netam/molevol.py
@@ -235,17 +235,15 @@ def codon_probs_of_parent_scaled_rates_and_sub_probs(
the codon level rather than moving to syn/nonsyn changes.
Args:
parent_idxs (torch.Tensor): The parent nucleotide sequence encoded as a
tensor of length Cx3, where C is the number of codons, containing the nt indices of each site.
scaled_rates (torch.Tensor): Poisson rates of mutation per site, scaled by branch length.
sub_probs (torch.Tensor): Substitution probabilities per site: a 2D tensor with shape (site_count, 4).
parent_idxs (torch.Tensor): The parent nucleotide sequence encoded as a
tensor of length Cx3, where C is the number of codons, containing the nt indices of each site.
scaled_rates (torch.Tensor): Poisson rates of mutation per site, scaled by branch length.
sub_probs (torch.Tensor): Substitution probabilities per site: a 2D tensor with shape (site_count, 4).
Returns:
torch.Tensor: A 4D tensor with shape (codon_count, 4, 4, 4) where the cijk-th entry is the probability
of the c'th codon mutating to the codon ijk.
torch.Tensor: A 4D tensor with shape (codon_count, 4, 4, 4) where the cijk-th entry is the probability
of the c'th codon mutating to the codon ijk.
"""
# The following four lines are duplicated from
# aaprobs_of_parent_scaled_rates_and_sub_probs
mut_probs = 1.0 - torch.exp(-scaled_rates)
parent_codon_idxs = reshape_for_codons(parent_idxs)
codon_mut_probs = reshape_for_codons(mut_probs)
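The first visible line of the body converts per-site Poisson rates into per-site mutation probabilities. A small numeric illustration with arbitrarily chosen rates:

import torch

scaled_rates = torch.tensor([0.0, 0.1, 1.0])  # rate * branch length, per site
mut_probs = 1.0 - torch.exp(-scaled_rates)    # P(site mutates at least once)
print(mut_probs)                              # tensor([0.0000, 0.0952, 0.6321])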
130 changes: 77 additions & 53 deletions netam/multihit.py
@@ -1,9 +1,14 @@
"""Burrito and Dataset classes for training a model to predict simple hit class corrections to codon probabilities.
"""Burrito and Dataset classes for training a model to predict simple hit class
corrections to codon probabilities.
Each codon mutation is hit class 0, 1, 2, or 3, corresponding to 0, 1, 2, or 3 mutations in the codon.
Each codon mutation is hit class 0, 1, 2, or 3, corresponding to 0, 1,
2, or 3 mutations in the codon.
The hit class corrections are three scalar values, one for each nonzero hit class.
To apply the correction to existing codon probability predictions, we multiply the probability of each child codon by the correction factor for its hit class, then renormalize. The correction factor for hit class 0 is fixed at 1.
The hit class corrections are three scalar values, one for each nonzero
hit class. To apply the correction to existing codon probability
predictions, we multiply the probability of each child codon by the
correction factor for its hit class, then renormalize. The correction
factor for hit class 0 is fixed at 1.
"""

import torch
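As a toy numeric illustration of the correction this module docstring describes, for a single parent codon with one representative child codon per hit class (the probabilities and factors below are made up):

import torch

child_probs = torch.tensor([0.85, 0.10, 0.04, 0.01])  # hit classes 0, 1, 2, 3
factors = torch.tensor([1.0, 0.5, 2.0, 4.0])          # hit class 0 fixed at 1
corrected = child_probs * factors
corrected = corrected / corrected.sum()               # renormalize
print(corrected)  # tensor([0.8333, 0.0490, 0.0784, 0.0392])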
@@ -25,19 +30,37 @@
from netam.models import HitClassModel


def _trim_seqs_to_codon_boundary_and_max_len(
seqs: list, site_count: int = None
) -> list:
"""Assumes that all sequences have the same length, and trims to codon boundary.
If site_count is None, does not enforce a maximum length."""
if site_count is None:
def _trim_to_codon_boundary_and_max_len(
seqs: list[Sequence], max_len: int = None
) -> list[Sequence]:
"""Trims sequences to codon boundary and maximum length.
No assumption is made about the data of a sequence, other than that it is
indexable (string or list of nucleotide indices both work).
`max_len` is the maximum number of nucleotides to be preserved.
If `max_len` is None, does not enforce a maximum length.
"""
if max_len is None:
return [seq[: len(seq) - len(seq) % 3] for seq in seqs]
else:
max_len = site_count - site_count % 3
return [seq[: min(len(seq) - len(seq) % 3, max_len)] for seq in seqs]
max_codon_len = max_len - max_len % 3
return [seq[: min(len(seq) - len(seq) % 3, max_codon_len)] for seq in seqs]
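For concreteness, a quick usage sketch of the renamed helper with hypothetical inputs, assuming the function defined above is in scope:

seqs = ["ACGTACGTAC", "ACGTAA"]                       # lengths 10 and 6
_trim_to_codon_boundary_and_max_len(seqs)             # ["ACGTACGTA", "ACGTAA"]
_trim_to_codon_boundary_and_max_len(seqs, max_len=7)  # ["ACGTAC", "ACGTAA"]; 7 nt rounds down to 6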


def _observed_hit_classes(parents, children):
def _observed_hit_classes(parents: Sequence[str], children: Sequence[str]):
"""Compute the observed hit classes between parent and child sequences.
Args:
parents (Sequence[str]): A list of parent sequences.
children (Sequence[str]): A list of the corresponding child sequences.
Returns:
list[torch.Tensor]: A list of tensors, each containing the observed
hit classes for each codon in the parent sequence. At any codon position
where the parent or child sequence contains an N, the corresponding tensor
element will be -100.
"""
labels = []
for parent_seq, child_seq in zip(parents, children):

@@ -62,7 +85,7 @@ def _observed_hit_classes(parents, children):
padded_mutations = num_mutations[:codon_count] # Truncate if necessary
padded_mutations += [-100] * (
codon_count - len(padded_mutations)
) # Pad with -1s
) # Pad with -100s

# Update the labels tensor for this row
labels.append(torch.tensor(padded_mutations, dtype=torch.int))
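A simplified, hedged sketch of the per-codon comparison the new docstring describes; unlike the real code shown in part above, it does not pad or truncate to a fixed codon count, and the function name is illustrative:

import torch

def observed_hit_classes_sketch(parent: str, child: str) -> torch.Tensor:
    hit_classes = []
    for i in range(0, len(parent) - len(parent) % 3, 3):
        p_codon, c_codon = parent[i : i + 3], child[i : i + 3]
        if "N" in p_codon or "N" in c_codon:
            hit_classes.append(-100)  # ambiguous codon, masked with -100
        else:
            hit_classes.append(sum(a != b for a, b in zip(p_codon, c_codon)))
    return torch.tensor(hit_classes, dtype=torch.int)

observed_hit_classes_sketch("ACGTTT", "ACCTNT")  # tensor([1, -100], dtype=torch.int32)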
@@ -78,8 +101,8 @@ def __init__(
all_subs_probs: Sequence[list[list[float]]],
branch_length_multiplier: float = 1.0,
):
trimmed_parents = _trim_seqs_to_codon_boundary_and_max_len(nt_parents)
trimmed_children = _trim_seqs_to_codon_boundary_and_max_len(nt_children)
trimmed_parents = _trim_to_codon_boundary_and_max_len(nt_parents)
trimmed_children = _trim_to_codon_boundary_and_max_len(nt_children)
self.nt_parents = stack_heterogeneous(
pd.Series(
sequences.nt_idx_tensor_of_str(parent.replace("N", "A"))
@@ -93,15 +116,14 @@ def __init__(
)
)
self.all_rates = stack_heterogeneous(
pd.Series(
rates[: len(rates) - len(rates) % 3] for rates in all_rates
).reset_index(drop=True)
pd.Series(_trim_to_codon_boundary_and_max_len(all_rates)).reset_index(
drop=True
)
)
self.all_subs_probs = stack_heterogeneous(
pd.Series(
subs_probs[: len(subs_probs) - len(subs_probs) % 3]
for subs_probs in all_subs_probs
).reset_index(drop=True)
pd.Series(_trim_to_codon_boundary_and_max_len(all_subs_probs)).reset_index(
drop=True
)
)

assert len(self.nt_parents) == len(self.nt_children)
@@ -146,7 +168,8 @@ def branch_lengths(self, new_branch_lengths):
self.update_hit_class_probs()

def update_hit_class_probs(self):
"""Compute hit class probabilities for all codons in each sequence based on current branch lengths"""
"""Compute hit class probabilities for all codons in each sequence
based on current branch lengths."""
new_codon_probs = []
new_hc_probs = []
for (
@@ -215,10 +238,13 @@ def to(self, device):
def flatten_and_mask_sequence_codons(
input_tensor: torch.Tensor, codon_mask: torch.Tensor = None
):
"""Flatten first dimension of input_tensor, applying codon_mask first if provided.
"""Flatten first dimension of input_tensor, applying codon_mask first if
provided.
This is useful for input_tensors whose first dimension represents sequences, and whose second dimension represents
codons. The resulting tensor will then aggregate the codons of all sequences into the first dimension.
This is useful for input_tensors whose first dimension represents
sequences, and whose second dimension represents codons. The
resulting tensor will then aggregate the codons of all sequences
into the first dimension.
"""
flat_input = input_tensor.flatten(start_dim=0, end_dim=1)
if codon_mask is not None:
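A small shape-level illustration of the flattening and masking described above, with made-up tensors (the exact order of the two operations in the real function is not fully visible here, but the result is the same either way):

import torch

per_seq = torch.arange(24, dtype=torch.float).reshape(2, 3, 4)  # 2 sequences x 3 codons x 4 values
codon_mask = torch.tensor([[True, True, False], [True, False, False]])
flat = per_seq.flatten(start_dim=0, end_dim=1)  # (6, 4): codons from all sequences
masked = flat[codon_mask.flatten()]             # (3, 4): only the codons kept by the mask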
@@ -228,15 +254,16 @@ def flatten_and_mask_sequence_codons(


def child_codon_probs_from_per_parent_probs(per_parent_probs, child_codon_idxs):
"""Calculate the probability of each child codon given the parent codon probabilities.
"""Calculate the probability of each child codon given the parent codon
probabilities.
Parameters:
per_parent_probs (torch.Tensor): A (codon_count, 4, 4, 4) shaped tensor containing the probabilities
of each possible target codon, for each parent codon.
child_codon_idxs (torch.Tensor): A (codon_count, 3) shaped tensor containing the nucleotide indices for each child codon.
Args:
per_parent_probs (torch.Tensor): A (codon_count, 4, 4, 4) shaped tensor containing the probabilities
of each possible target codon, for each parent codon.
child_codon_idxs (torch.Tensor): A (codon_count, 3) shaped tensor containing the nucleotide indices for each child codon.
Returns:
torch.Tensor: A (codon_count,) shaped tensor containing the probabilities of each child codon.
torch.Tensor: A (codon_count,) shaped tensor containing the probabilities of each child codon.
"""
return per_parent_probs[
torch.arange(child_codon_idxs.size(0)),
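The return expression is truncated in this view; a self-contained sketch of the advanced indexing it performs (illustrative function name, random example tensors):

import torch

def child_codon_probs_sketch(per_parent_probs, child_codon_idxs):
    rows = torch.arange(child_codon_idxs.size(0))
    return per_parent_probs[
        rows, child_codon_idxs[:, 0], child_codon_idxs[:, 1], child_codon_idxs[:, 2]
    ]

probs = torch.rand(5, 4, 4, 4)
idxs = torch.randint(0, 4, (5, 3))
assert child_codon_probs_sketch(probs, idxs).shape == (5,)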
@@ -249,18 +276,18 @@ def child_codon_probs_from_per_parent_probs(per_parent_probs, child_codon_idxs):
def child_codon_logprobs_corrected(
uncorrected_per_parent_logprobs, parent_codon_idxs, child_codon_idxs, model
):
"""Calculate the probability of each child codon given the parent codon probabilities, corrected by hit class factors.
Parameters:
"""Calculate the probability of each child codon given the parent codon
probabilities, corrected by hit class factors.
uncorrected_per_parent_logprobs (torch.Tensor): A (codon_count, 4, 4, 4) shaped tensor containing the log probabilities
of each possible target codon, for each parent codon.
parent_codon_idxs (torch.Tensor): A (codon_count, 3) shaped tensor containing the nucleotide indices for each parent codon
child_codon_idxs (torch.Tensor): A (codon_count, 3) shaped tensor containing the nucleotide indices for each child codon
model: A HitClassModel implementing the desired correction.
Args:
uncorrected_per_parent_logprobs (torch.Tensor): A (codon_count, 4, 4, 4) shaped tensor containing the log probabilities
of each possible target codon, for each parent codon.
parent_codon_idxs (torch.Tensor): A (codon_count, 3) shaped tensor containing the nucleotide indices for each parent codon
child_codon_idxs (torch.Tensor): A (codon_count, 3) shaped tensor containing the nucleotide indices for each child codon
model: A HitClassModel implementing the desired correction.
Returns:
torch.Tensor: A (codon_count,) shaped tensor containing the corrected log probabilities of each child codon.
torch.Tensor: A (codon_count,) shaped tensor containing the corrected log probabilities of each child codon.
"""

corrected_per_parent_logprobs = model(
@@ -427,7 +454,8 @@ def hit_class_dataset_from_pcp_df(
def train_test_datasets_of_pcp_df(
pcp_df: pd.DataFrame, train_frac: float = 0.8, branch_length_multiplier: float = 1.0
) -> tuple[HitClassDataset, HitClassDataset]:
"""Splits a pcp_df prepared by `prepare_pcp_df` into a training and testing HitClassDataset."""
"""Splits a pcp_df prepared by `prepare_pcp_df` into a training and testing
HitClassDataset."""
nt_parents = pcp_df["parent"].reset_index(drop=True)
nt_children = pcp_df["child"].reset_index(drop=True)
rates = pcp_df["rates"].reset_index(drop=True)
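The rest of this function is collapsed in this view. As a generic, hedged illustration of what splitting by train_frac can look like (the real function builds HitClassDataset objects from the extracted columns and may select rows differently):

import pandas as pd

def split_by_fraction(df: pd.DataFrame, train_frac: float = 0.8):
    n_train = int(len(df) * train_frac)
    train_df = df.iloc[:n_train].reset_index(drop=True)
    test_df = df.iloc[n_train:].reset_index(drop=True)
    return train_df, test_df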
@@ -464,18 +492,14 @@ def train_test_datasets_of_pcp_df(
def prepare_pcp_df(
pcp_df: pd.DataFrame, crepe: framework.Crepe, site_count: int
) -> pd.DataFrame:
"""
Trim parent and child sequences in pcp_df to codon boundaries
and add the rates and substitution probabilities.
"""Trim parent and child sequences in pcp_df to codon boundaries and add
the rates and substitution probabilities.
Returns the modified dataframe, which is the input dataframe modified in-place.
Returns the modified dataframe, which is the input dataframe
modified in-place.
"""
pcp_df["parent"] = _trim_seqs_to_codon_boundary_and_max_len(
pcp_df["parent"], site_count
)
pcp_df["child"] = _trim_seqs_to_codon_boundary_and_max_len(
pcp_df["child"], site_count
)
pcp_df["parent"] = _trim_to_codon_boundary_and_max_len(pcp_df["parent"], site_count)
pcp_df["child"] = _trim_to_codon_boundary_and_max_len(pcp_df["child"], site_count)
pcp_df = pcp_df[pcp_df["parent"] != pcp_df["child"]].reset_index(drop=True)
ratess, cspss = framework.trimmed_shm_model_outputs_of_crepe(
crepe, pcp_df["parent"]
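A condensed, hedged walk-through of the dataframe steps visible above, assuming the trimming helper defined earlier in this file is in scope; it skips the crepe-derived rates and substitution probabilities, whose exact form is not shown in this excerpt:

import pandas as pd

pcp_df = pd.DataFrame({"parent": ["ACGTACG", "AAATTT"], "child": ["ACGTACG", "AAATTC"]})
site_count = 6
pcp_df["parent"] = _trim_to_codon_boundary_and_max_len(pcp_df["parent"], site_count)
pcp_df["child"] = _trim_to_codon_boundary_and_max_len(pcp_df["child"], site_count)
# Drop pairs that are identical after trimming (no observed mutations remain).
pcp_df = pcp_df[pcp_df["parent"] != pcp_df["child"]].reset_index(drop=True)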
