Commit
some cleanup
willdumm committed Aug 13, 2024
1 parent d200bb4 commit bc522bf
Showing 1 changed file with 39 additions and 76 deletions.
115 changes: 39 additions & 76 deletions netam/codon_prob.py
@@ -1,5 +1,5 @@
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Dataset
import torch.nn as nn
from tqdm import tqdm
import numpy as np
@@ -15,7 +15,6 @@
from netam.common import BASES, stack_heterogeneous, clamp_probability
import netam.framework as framework
from netam.framework import Burrito
from netam.models import AbstractBinarySelectionModel

def hit_class(codon1, codon2):
return sum(c1 != c2 for c1, c2 in zip(codon1, codon2))
@@ -41,32 +40,32 @@ def hit_class(codon1, codon2):
# make a dict mapping from codon to triple integer index
codon_to_idxs = {base_1+base_2+base_3: (i, j, k) for i, base_1 in enumerate(BASES) for j, base_2 in enumerate(BASES) for k, base_3 in enumerate(BASES)}
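For orientation, here is a minimal self-contained sketch (editor illustration, not part of the commit) of how these two helpers behave, assuming BASES is the nucleotide alphabet in the order "ACGT":

# Editor illustration only; BASES == "ACGT" is an assumption matching netam.common.
BASES = "ACGT"

def hit_class(codon1, codon2):
    return sum(c1 != c2 for c1, c2 in zip(codon1, codon2))

codon_to_idxs = {
    b1 + b2 + b3: (i, j, k)
    for i, b1 in enumerate(BASES)
    for j, b2 in enumerate(BASES)
    for k, b3 in enumerate(BASES)
}

assert hit_class("ACG", "ACG") == 0  # identical codons
assert hit_class("ACG", "ATG") == 1  # one mismatched position
assert hit_class("ACG", "TGA") == 3  # all three positions differ
assert codon_to_idxs["AAA"] == (0, 0, 0)
assert codon_to_idxs["ACG"] == (0, 1, 2)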

def hit_class_probs(hit_class_tensor, codon_probs):
"""
Calculate total probabilities for each number of differences between codons.
# def hit_class_probs(hit_class_tensor, codon_probs):
# """
# Calculate total probabilities for each number of differences between codons.

Args:
- hit_class_tensor (torch.Tensor): A 4x4x4 integer tensor containing the number of differences
between each codon and a reference codon.
- codon_probs (torch.Tensor): A 4x4x4 tensor containing the probabilities of various codons.
# Args:
# - hit_class_tensor (torch.Tensor): A 4x4x4 integer tensor containing the number of differences
# between each codon and a reference codon.
# - codon_probs (torch.Tensor): A 4x4x4 tensor containing the probabilities of various codons.

Returns:
- total_probs (torch.Tensor): A 1D tensor containing the total probabilities for each number
of differences (0 to 3).
"""
total_probs = []
# Returns:
# - total_probs (torch.Tensor): A 1D tensor containing the total probabilities for each number
# of differences (0 to 3).
# """
# total_probs = []

for hit_class in range(4):
# Create a mask of codons with the desired number of differences
mask = hit_class_tensor == hit_class
# for hit_class in range(4):
# # Create a mask of codons with the desired number of differences
# mask = hit_class_tensor == hit_class

# Multiply componentwise with the codon_probs tensor and sum
total_prob = (codon_probs * mask.float()).sum()
# # Multiply componentwise with the codon_probs tensor and sum
# total_prob = (codon_probs * mask.float()).sum()

# Append the total probability to the list
total_probs.append(total_prob.item())
# # Append the total probability to the list
# total_probs.append(total_prob.item())

return torch.tensor(total_probs)
# return torch.tensor(total_probs)
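The commented-out helper above aggregated a single 4x4x4 codon-probability tensor by hit class; hit_class_probs_tensor below does the same over a batch of parent codons. A rough sketch of that aggregation (editor illustration with assumed shapes, not the package's implementation):

import torch

def hit_class_probs_tensor_sketch(parent_codon_idxs, codon_probs):
    # parent_codon_idxs: (C, 3) integer nucleotide indices of each parent codon
    # codon_probs: (C, 4, 4, 4) probabilities of every possible child codon
    # Returns a (C, 4) tensor: total probability per hit class (0-3) for each codon.
    C = parent_codon_idxs.shape[0]
    grid = torch.stack(torch.meshgrid(*[torch.arange(4)] * 3, indexing="ij"), dim=-1)
    hit_classes = (grid.unsqueeze(0) != parent_codon_idxs.view(C, 1, 1, 1, 3)).sum(-1)
    return torch.stack(
        [(codon_probs * (hit_classes == hc)).sum(dim=(1, 2, 3)) for hc in range(4)],
        dim=1,
    )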

def hit_class_probs_tensor(parent_codon_idxs, codon_probs):
"""
@@ -172,15 +171,11 @@ def __init__(
trimmed_children = [child[: len(child) - len(child) % 3] for child in nt_children]
self.nt_parents = stack_heterogeneous(pd.Series(sequences.nt_idx_tensor_of_str(parent.replace("N", "A")) for parent in trimmed_parents))
self.nt_children = stack_heterogeneous(pd.Series(sequences.nt_idx_tensor_of_str(child.replace("N", "A")) for child in trimmed_children))
max_len = len(self.nt_parents[0])
self.nt_parents_strs = [parent + ("N" * (max_len - len(parent))) for parent in trimmed_parents]
self.nt_children_strs = [child + ("N" * (max_len - len(child))) for child in trimmed_children]
self.all_rates = stack_heterogeneous(pd.Series(rates[: len(rates) - len(rates) % 3] for rates in all_rates).reset_index(drop=True))
self.all_subs_probs = stack_heterogeneous(pd.Series(subs_probs[: len(subs_probs) - len(subs_probs) % 3] for subs_probs in all_subs_probs).reset_index(drop=True))

assert len(self.nt_parents) == len(self.nt_children)

# TODO get hit classes and do checks directly from tensor encoding of sequences
for parent, child in zip(trimmed_parents, trimmed_children):
if parent == child:
raise ValueError(
@@ -230,10 +225,6 @@ def update_hit_class_probs(self):
self.all_subs_probs,
self.branch_lengths,
):
# This encodes bases as indices in a sorted nucleotide list. Codons containing
# N's should already be masked in self.codon_mask, so treating them as A's here shouldn't matter...
# TODO Check that assertion ^^

scaled_rates = branch_length * rates

codon_probs = codon_probs_of_parent_scaled_rates_and_sub_probs(
@@ -242,9 +233,9 @@

new_hc_probs.append(hit_class_probs_tensor(reshape_for_codons(encoded_parent), codon_probs))
# We must store probability of all hit classes for arguments to cce_loss in loss_of_batch.
self.hit_class_probs = stack_heterogeneous(new_hc_probs, pad_value=-100)
self.hit_class_probs = torch.stack(new_hc_probs)

# A couple of these methods could be moved to a super class, which itself subclasses Dataset
# A couple of these methods could maybe be moved to a super class, which itself subclasses Dataset
def export_branch_lengths(self, out_csv_path):
pd.DataFrame({"branch_length": self.branch_lengths}).to_csv(
out_csv_path, index=False
@@ -268,12 +259,14 @@ def __getitem__(self, idx):
}

def to(self, device):
# TODO update this (and might have to encode sequences as Tensors), if used!
raise NotImplementedError
self.codon_mask = self.mask.to(device)
self.nt_parents = self.nt_parents.to(device)
self.nt_children = self.nt_children.to(device)
self.observed_hcs = self.observed_hcs.to(device)
self.all_rates = self.all_rates.to(device)
self.all_subs_probs = self.all_subs_probs.to(device)
self.hit_class_probs = self.hit_class_probs.to(device)
self.codon_mask = self.codon_mask.to(device)
self.branch_lengths = self.branch_lengths.to(device)

def flatten_and_mask_sequence_codons(input_tensor, codon_mask=None):
"""Flatten first dimension, that is over sequences, to return tensor
@@ -308,6 +301,8 @@ def hyperparameters(self):
return {}

def forward(self, parent_codon_idxs: torch.Tensor, uncorrected_log_codon_probs: torch.Tensor):
"""Forward function takes a tensor of target codon distributions, for each observed parent codon,
and adjusts the distributions according to the hit class adjustments."""
hit_class_tensor_t = hit_class_tensor_full[parent_codon_idxs[:, 0],
parent_codon_idxs[:, 1],
parent_codon_idxs[:, 2]].int()
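A compact sketch of the adjustment this forward pass describes, applied to a single parent codon (editor illustration; the function name and arguments here are assumptions, not the model's attributes):

import torch

def adjust_log_codon_probs(log_codon_probs, hit_classes, log_corrections):
    # log_codon_probs: (4, 4, 4) uncorrected log-probabilities of child codons
    # hit_classes: (4, 4, 4) long tensor, hit class of each codon relative to the parent
    # log_corrections: (4,) per-hit-class adjustment, with 0.0 fixed for hit class 0
    adjusted = log_codon_probs + log_corrections[hit_classes]
    # Renormalize so the adjusted distribution still sums to one.
    return adjusted - torch.logsumexp(adjusted.flatten(), dim=0)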
@@ -337,8 +332,6 @@ def __init__(
self.cce_loss = torch.nn.CrossEntropyLoss(reduction='mean')


# For loss want categorical cross-entropy, appears in framework.py for another model
# When computing overall log-likelihood will need to account for the different sequence lengths
def load_branch_lengths(self, in_csv_prefix):
if self.train_loader is not None:
self.train_loader.dataset.load_branch_lengths(
@@ -348,8 +341,6 @@ def load_branch_lengths(self, in_csv_prefix):
in_csv_prefix + ".val_branch_lengths.csv"
)

# Once optimized branch lengths, store the baseline codon-level predictions somewhere. See DNSMBurrito::predictions_of_batch
# Rates stay same, and are used to re-compute branch lengths whenever codon probs are adjusted.
def loss_of_batch(self, batch):
# Differing sequence lengths, and codons containing N's, are accounted for by the mask.
observed_hcs = batch["observed_hcs"]
@@ -359,43 +350,26 @@ def loss_of_batch(self, batch):
flat_masked_hit_class_probs = flatten_and_mask_sequence_codons(hit_class_probs, codon_mask=codon_mask)
flat_masked_observed_hcs = flatten_and_mask_sequence_codons(observed_hcs, codon_mask=codon_mask).long()
corrections = torch.cat([torch.tensor([0.0]), self.model.values])
corrected_probs = flat_masked_hit_class_probs.log() + corrections
corrected_probs = (corrected_probs - torch.logsumexp(corrected_probs, dim=1, keepdim=True)).exp()
scaled_log_probs = flat_masked_hit_class_probs.log() + corrections
corrected_probs = (scaled_log_probs - torch.logsumexp(scaled_log_probs, dim=1, keepdim=True)).exp()
assert torch.isfinite(corrected_probs).all()
adjusted_probs = clamp_probability(corrected_probs)
logits = torch.log(adjusted_probs / (1 - adjusted_probs))

# Just need to adjust hit class probs by model coefficients, and re-normalize.

return self.cce_loss(logits, flat_masked_observed_hcs)
# nt_parents = batch["nt_parents"]
# nt_children = batch["nt_children"]
# brlens = batch["branch_lengths"]
# codon_mask = batch["codon_mask"]
# rates = batch["rates"]
# subs_probs = batch["subs_probs"]
# scaled_rates = rates * brlens
# codon_probs = torch.tensor([codon_probs_of_parent_scaled_rates_and_sub_probs(parent_idxs, scaled_rates_it, subs_probs_it)
# for parent_idxs, scaled_rates_it, subs_probs_it in zip(nt_parents, scaled_rates, subs_probs)])
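To make the correction-and-renormalization step in loss_of_batch concrete, a toy example with made-up numbers (editor illustration only):

import torch

hc_probs = torch.tensor([[0.90, 0.08, 0.015, 0.005],
                         [0.70, 0.20, 0.080, 0.020]])  # (num_codons, 4), rows sum to 1
corrections = torch.cat([torch.tensor([0.0]), torch.tensor([0.3, -0.1, 0.2])])
scaled_log_probs = hc_probs.log() + corrections  # broadcast over codons
corrected_probs = (scaled_log_probs - torch.logsumexp(scaled_log_probs, dim=1, keepdim=True)).exp()
assert torch.allclose(corrected_probs.sum(dim=1), torch.ones(2))  # still a distribution per codon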



# These are from DNSMBurrito, as a start
def _find_optimal_branch_length(
self,
parent_idxs,
child_idxs,
observed_hcs,
rates,
subs_probs,
codon_mask,
starting_branch_length,
**optimization_kwargs,
):

# # A stand-in for the adjustment model we're fitting:
# codon_adjustment = self.model.values

def log_pcp_probability(log_branch_length):
# First compute the log-probability of the observed branch, using codon probs.
# Then adjust the codon probs using our hit class probabilities.
@@ -411,21 +385,12 @@ def log_pcp_probability(log_branch_length):

child_codon_idxs = reshape_for_codons(child_idxs)[codon_mask]
parent_codon_idxs = reshape_for_codons(parent_idxs)[codon_mask]
corrected_codon_probs = self.model(parent_codon_idxs, codon_probs.log())
child_codon_probs = corrected_codon_probs[torch.arange(child_codon_idxs.size(0)), child_codon_idxs[:, 0], child_codon_idxs[:, 1], child_codon_idxs[:, 2]]
return child_codon_probs.sum()

# # hc_probs is a Cx4 tensor containing codon probs aggregated by hit class
# hc_probs = hit_class_probs_tensor(parent_codon_idxs, codon_probs)

# # Add fixed 1 adjustment for hit class 0:
# _adjust = torch.cat([torch.tensor([1]), codon_adjustment])
# # Get adjustments for each site's observed hit class
# observed_hc_adjustments = _adjust.gather(0, observed_hcs[codon_mask])
# numerators = (child_codon_probs * observed_hc_adjustments).log()
# # This is a dot product of the distribution and the adjustments at each site
# denominators = (torch.matmul(hc_probs, _adjust)).log()
# return (numerators - denominators).sum()
corrected_codon_log_probs = self.model(parent_codon_idxs, codon_probs.log())
child_codon_log_probs = corrected_codon_log_probs[torch.arange(child_codon_idxs.size(0)),
child_codon_idxs[:, 0],
child_codon_idxs[:, 1],
child_codon_idxs[:, 2]]
return child_codon_log_probs.sum()


return optimize_branch_length(
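The per-site probabilities of the observed child codons are read out with advanced indexing; a standalone sketch of that pattern with arbitrary shapes (editor illustration, not the commit's code):

import torch

num_codons = 5
codon_log_probs = torch.log_softmax(torch.randn(num_codons, 64), dim=1).reshape(num_codons, 4, 4, 4)
child_codon_idxs = torch.randint(0, 4, (num_codons, 3))
per_site_log_probs = codon_log_probs[torch.arange(num_codons),
                                     child_codon_idxs[:, 0],
                                     child_codon_idxs[:, 1],
                                     child_codon_idxs[:, 2]]
branch_log_prob = per_site_log_probs.sum()  # log-probability of the observed child codons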
@@ -437,11 +402,10 @@ def log_pcp_probability(log_branch_length):
def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
optimal_lengths = []

for parent_idxs, child_idxs, observed_hcs, rates, subs_probs, codon_mask, starting_length in tqdm(
for parent_idxs, child_idxs, rates, subs_probs, codon_mask, starting_length in tqdm(
zip(
dataset.nt_parents,
dataset.nt_children,
dataset.observed_hcs,
dataset.all_rates,
dataset.all_subs_probs,
dataset.codon_mask,
@@ -454,7 +418,6 @@ def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
self._find_optimal_branch_length(
parent_idxs,
child_idxs,
observed_hcs,
rates[: len(parent_idxs)],
subs_probs[: len(parent_idxs), :],
codon_mask,
