Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add codon_prob.py with a model to adjust codon probs by hit class #50

Merged
merged 10 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions netam/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
VRC01_NT_SEQ,
)
from netam import models
from netam import molevol
import netam.molevol as molevol


def encode_mut_pos_and_base(parent, child, site_count=None):
Expand Down Expand Up @@ -551,7 +551,7 @@ def process_data_loader(self, data_loader, train_mode=False, loss_reduction=None
first_value_of_batch = (
list(batch.values())[0] if isinstance(batch, dict) else batch[0]
)
batch_size = first_value_of_batch.shape[0]
batch_size = len(first_value_of_batch)
# If we multiply the loss by the batch size, then the loss will be the sum of the
# losses for each example in the batch. Then, when we divide by the number of
# examples in the dataset below, we will get the average loss per example.
Expand Down
104 changes: 104 additions & 0 deletions netam/hit_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import torch
import numpy as np

from netam.common import BASES


# Define the number of bases (e.g., 4 for DNA/RNA)
_num_bases = 4

# Generate all possible codons using broadcasting
_i, _j, _k = np.indices((_num_bases, _num_bases, _num_bases)) # Create index grids
_codon1 = np.stack((_i, _j, _k), axis=-1) # Shape: (4, 4, 4, 3)

# Expand dimensions to compare all codon pairs using broadcasting
_codon1_expanded = _codon1[
:, :, :, np.newaxis, np.newaxis, np.newaxis, :
] # Shape: (4, 4, 4, 1, 1, 1, 3)
_codon2_expanded = _codon1[
np.newaxis, np.newaxis, np.newaxis, :, :, :, :
] # Shape: (1, 1, 1, 4, 4, 4, 3)

# Count the number of differing positions between each pair of codons
"""hit_class_tensor is a tensor of shape (4, 4, 4, 4, 4, 4) recording the hit class (number of nucleotide differences) between all possible parent and child codons. The first three dimensions represent the parent codon, and the last three represent the child codon. Codons are identified by triples of nucleotide indices from `netam.common.BASES`."""
hit_class_tensor = torch.tensor(
np.sum(_codon1_expanded != _codon2_expanded, axis=-1)
).int()


def parent_specific_hit_classes(parent_codon_idxs: torch.Tensor) -> torch.Tensor:
"""Produce a tensor containing the hit classes of all possible child codons, for each passed parent codon.

Args:
parent_codon_idxs (torch.Tensor): A (codon_count, 3) shaped tensor containing for each codon, the
indices of the parent codon's nucleotides.
Returns:
torch.Tensor: A (codon_count, 4, 4, 4) shaped tensor containing the hit classes of each possible child codon for each parent codon.
"""
return hit_class_tensor[
parent_codon_idxs[:, 0], parent_codon_idxs[:, 1], parent_codon_idxs[:, 2]
]


def apply_multihit_correction(
parent_codon_idxs: torch.Tensor,
codon_logprobs: torch.Tensor,
hit_class_factors: torch.Tensor,
) -> torch.Tensor:
"""Multiply codon probabilities by their hit class factors, and renormalize.

Suppose there are N codons, then the parameters are as follows:

Args:
parent_codon_idxs (torch.Tensor): A (N, 3) shaped tensor containing for each codon, the
indices of the parent codon's nucleotides.
codon_logprobs (torch.Tensor): A (N, 4, 4, 4) shaped tensor containing the log probabilities
of mutating to each possible target codon, for each of the N parent codons.
hit_class_factors (torch.Tensor): A tensor containing the log hit class factors for hit classes 1, 2, and 3. The
factor for hit class 0 is assumed to be 1 (that is, 0 in log-space).

Returns:
torch.Tensor: A (N, 4, 4, 4) shaped tensor containing the log probabilities of mutating to each possible
target codon, for each of the N parent codons, after applying the hit class factors.
"""
per_parent_hit_class = parent_specific_hit_classes(parent_codon_idxs)
corrections = torch.cat([torch.tensor([0.0]), hit_class_factors])
reshaped_corrections = corrections[per_parent_hit_class]
unnormalized_corrected_logprobs = codon_logprobs + reshaped_corrections
normalizations = torch.logsumexp(
unnormalized_corrected_logprobs, dim=[1, 2, 3], keepdim=True
)
return unnormalized_corrected_logprobs - normalizations


def hit_class_probs_tensor(
parent_codon_idxs: torch.Tensor, codon_probs: torch.Tensor
) -> torch.Tensor:
"""
Calculate probabilities of hit classes between parent codons and all other codons for all the sites of a sequence.

Args:
parent_codon_idxs (torch.Tensor): The parent nucleotide sequence encoded as a tensor of shape (codon_count, 3),
containing the nt indices of each codon.
codon_probs (torch.Tensor): A (codon_count, 4, 4, 4) shaped tensor containing the probabilities of various
codons, for each codon in parent seq.

Returns:
probs (torch.Tensor): A tensor containing the probabilities of different
counts of hit classes between parent codons and
all other codons, with shape (codon_count, 4).

Notes:
Uses hit_class_tensor (torch.Tensor): A 4x4x4x4x4x4 tensor which when indexed with a parent codon produces
the hit classes to all possible child codons.
"""

# Get a codon_countx4x4x4 tensor describing for each parent codon the hit classes of all child codons
per_parent_hit_class = parent_specific_hit_classes(parent_codon_idxs)
codon_count = per_parent_hit_class.size(0)
hc_prob_tensor = torch.zeros(codon_count, 4)
for k in range(4):
mask = per_parent_hit_class == k
hc_prob_tensor[:, k] = (codon_probs * mask).sum(dim=(1, 2, 3))

return hc_prob_tensor
23 changes: 23 additions & 0 deletions netam/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import torch.nn.functional as F
from torch import Tensor

from netam.hit_class import apply_multihit_correction
from netam.common import (
MAX_AMBIG_AA_IDX,
aa_idx_tensor_of_str_ambig,
Expand Down Expand Up @@ -708,3 +709,25 @@ def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
"""Build a binary log selection matrix from a one-hot encoded parent sequence."""
replicated_value = self.single_value.expand_as(amino_acid_indices)
return replicated_value


class HitClassModel(nn.Module):
def __init__(self):
super().__init__()
self.reinitialize_weights()

@property
def hyperparameters(self):
return {}

def forward(
self, parent_codon_idxs: torch.Tensor, uncorrected_log_codon_probs: torch.Tensor
):
"""Forward function takes a tensor of target codon distributions, for each observed parent codon,
and adjusts the distributions according to the hit class corrections."""
return apply_multihit_correction(
parent_codon_idxs, uncorrected_log_codon_probs, self.values
)

def reinitialize_weights(self):
self.values = nn.Parameter(torch.tensor([0.0, 0.0, 0.0]))
47 changes: 37 additions & 10 deletions netam/molevol.py
matsen marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,38 @@ def reshape_for_codons(array: Tensor) -> Tensor:
return array.reshape(codon_count, 3, *array.shape[1:])


def codon_probs_of_parent_scaled_rates_and_sub_probs(
parent_idxs: torch.Tensor, scaled_rates: torch.Tensor, sub_probs: torch.Tensor
):
"""
Compute the probabilities of mutating to various codons for a parent sequence.

This uses the same machinery as we use for fitting the DNSM, but we stay on
the codon level rather than moving to syn/nonsyn changes.

Args:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs reformatting with indentation like so:

"""
Compute the probabilities of mutating to various codons for a parent sequence.

This uses the same machinery as we use for fitting the DNSM, but we stay on
the codon level rather than moving to syn/nonsyn changes.

Args:
    parent_idxs (torch.Tensor): The parent nucleotide sequence encoded as a
        tensor of length Cx3, where C is the number of codons, containing the nt indices of each site.
    scaled_rates (torch.Tensor): Poisson rates of mutation per site, scaled by branch length.
    sub_probs (torch.Tensor): Substitution probabilities per site: a 2D tensor with shape (site_count, 4).

Returns:
    torch.Tensor: A 4D tensor with shape (codon_count, 4, 4, 4) where the cijk-th entry is the probability
        of the c'th codon mutating to the codon ijk.
"""

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used the format from the docstrings in molevol.py. Now I see that the last half use the format you want, and the first half use the format I imitated.

I've fixed the docstrings for my functions and opened a new issue to use docformatter, and to establish a consistent docstring format (#56 )

parent_idxs (torch.Tensor): The parent nucleotide sequence encoded as a
tensor of length Cx3, where C is the number of codons, containing the nt indices of each site.
scaled_rates (torch.Tensor): Poisson rates of mutation per site, scaled by branch length.
sub_probs (torch.Tensor): Substitution probabilities per site: a 2D tensor with shape (site_count, 4).

Returns:
torch.Tensor: A 4D tensor with shape (codon_count, 4, 4, 4) where the cijk-th entry is the probability
of the c'th codon mutating to the codon ijk.
"""
mut_probs = 1.0 - torch.exp(-scaled_rates)
parent_codon_idxs = reshape_for_codons(parent_idxs)
codon_mut_probs = reshape_for_codons(mut_probs)
codon_sub_probs = reshape_for_codons(sub_probs)

mut_matrices = build_mutation_matrices(
parent_codon_idxs, codon_mut_probs, codon_sub_probs
)
codon_probs = codon_probs_of_mutation_matrices(mut_matrices)

return codon_probs


def aaprobs_of_parent_scaled_rates_and_sub_probs(
parent_idxs: Tensor, scaled_rates: Tensor, sub_probs: Tensor
) -> Tensor:
Expand All @@ -243,16 +275,11 @@ def aaprobs_of_parent_scaled_rates_and_sub_probs(
torch.Tensor: A 2D tensor with rows corresponding to sites and columns
corresponding to amino acids.
"""
# Calculate the probability of at least one mutation at each site.
mut_probs = 1.0 - torch.exp(-scaled_rates)

# Reshape the inputs to include a codon dimension.
parent_codon_idxs = reshape_for_codons(parent_idxs)
codon_mut_probs = reshape_for_codons(mut_probs)
codon_sub_probs = reshape_for_codons(sub_probs)

# Vectorized calculation of amino acid probabilities.
return aaprob_of_mut_and_sub(parent_codon_idxs, codon_mut_probs, codon_sub_probs)
return aaprobs_of_codon_probs(
codon_probs_of_parent_scaled_rates_and_sub_probs(
parent_idxs, scaled_rates, sub_probs
)
)


def build_codon_mutsel(
Expand Down
Loading