Integrate the multihit model into the DNSM framework #71

Merged · 8 commits · Oct 23, 2024
10 changes: 10 additions & 0 deletions netam/common.py
@@ -8,6 +8,7 @@
import torch
import torch.optim as optim
from torch import nn, Tensor
import multiprocessing as mp

BIG = 1e9
SMALL_PROB = 1e-6
@@ -31,6 +32,15 @@
)


def force_spawn():
"""Force the spawn start method for multiprocessing.

This is necessary to avoid conflicts with the internal OpenMP-based thread pool in
PyTorch.
"""
mp.set_start_method("spawn", force=True)


def generate_kmers(kmer_length):
# Our strategy for kmers is to have a single representation for any kmer that isn't in ACGT.
# This is the first one, which is simply "N", and so this placeholder value is 0.
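`force_spawn` is meant to be called once, before any worker pool exists. A minimal sketch of the intended call pattern (the worker function and values here are hypothetical, not from this PR):

```python
import multiprocessing as mp

import torch

from netam.common import force_spawn


def worker(x):
    # Hypothetical worker that touches torch, which is the case force_spawn
    # guards against: forked children inheriting OpenMP thread state.
    return (torch.tensor(x) * 2).item()


if __name__ == "__main__":
    force_spawn()                  # call once, before any pool exists
    print(mp.get_start_method())   # "spawn"
    with mp.Pool(2) as pool:
        print(pool.map(worker, [1.0, 2.0]))
```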
4 changes: 0 additions & 4 deletions netam/dasm.py
@@ -7,11 +7,7 @@
# optimization on our server.
torch.set_num_threads(1)

import numpy as np
import pandas as pd

from netam.common import (
clamp_log_probability,
clamp_probability,
BIG,
)
40 changes: 33 additions & 7 deletions netam/dnsm.py
@@ -43,11 +43,17 @@ def __init__(
all_rates: torch.Tensor,
all_subs_probs: torch.Tensor,
branch_lengths: torch.Tensor,
multihit_model=None,
):
self.nt_parents = nt_parents
self.nt_children = nt_children
self.all_rates = all_rates
self.all_subs_probs = all_subs_probs
self.multihit_model = copy.deepcopy(multihit_model)
if multihit_model is not None:
# We want these parameters to act like fixed data. This is essential
# for multithreaded branch length optimization to work.
self.multihit_model.values.requires_grad_(False)

assert len(self.nt_parents) == len(self.nt_children)
pcp_count = len(self.nt_parents)
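The deepcopy-then-freeze idiom above keeps the dataset's copy of the multihit model out of autograd. A self-contained sketch of the same idiom, using a hypothetical stand-in module whose parameter tensor is named `values` like the multihit model's:

```python
import copy

import torch
from torch import nn


class TinyMultihit(nn.Module):
    # Hypothetical stand-in whose parameter tensor is named `values`,
    # mirroring the attribute the dataset freezes above.
    def __init__(self):
        super().__init__()
        self.values = nn.Parameter(torch.zeros(3))


shared = TinyMultihit()
frozen = copy.deepcopy(shared)        # the dataset owns an independent copy
frozen.values.requires_grad_(False)   # copy now behaves like fixed data

out = (frozen.values + 1.0).sum()
print(out.requires_grad)              # False: no graph built through the copy
print(shared.values.requires_grad)    # True: the trainable original is untouched
```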
@@ -95,6 +101,7 @@ def of_seriess(
all_rates_series: pd.Series,
all_subs_probs_series: pd.Series,
branch_length_multiplier=5.0,
multihit_model=None,
):
"""Alternative constructor that takes the raw data and calculates the initial
branch lengths.
@@ -115,10 +122,11 @@ def of_seriess(
stack_heterogeneous(all_rates_series.reset_index(drop=True)),
stack_heterogeneous(all_subs_probs_series.reset_index(drop=True)),
initial_branch_lengths,
multihit_model=multihit_model,
)

@classmethod
def of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0):
def of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0, multihit_model=None):
"""Alternative constructor that takes in a pcp_df and calculates the initial
branch lengths."""
assert "rates" in pcp_df.columns, "pcp_df must have a neutral rates column"
@@ -128,10 +136,13 @@ def of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0):
pcp_df["rates"],
pcp_df["subs_probs"],
branch_length_multiplier=branch_length_multiplier,
multihit_model=multihit_model,
)

@classmethod
def train_val_datasets_of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0):
def train_val_datasets_of_pcp_df(
cls, pcp_df, branch_length_multiplier=5.0, multihit_model=None
):
"""Perform a train-val split based on the 'in_train' column.

This is a class method so it works for subclasses.
@@ -140,14 +151,18 @@ def train_val_datasets_of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0):
val_df = pcp_df[~pcp_df["in_train"]].reset_index(drop=True)

val_dataset = cls.of_pcp_df(
val_df, branch_length_multiplier=branch_length_multiplier
val_df,
branch_length_multiplier=branch_length_multiplier,
multihit_model=multihit_model,
)

if len(train_df) == 0:
return None, val_dataset
# else:
train_dataset = cls.of_pcp_df(
train_df, branch_length_multiplier=branch_length_multiplier
train_df,
branch_length_multiplier=branch_length_multiplier,
multihit_model=multihit_model,
)

return train_dataset, val_dataset
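The split itself is boolean indexing on the `in_train` column; a toy sketch (the frame below stands in for `pcp_df`):

```python
import pandas as pd

# Toy frame standing in for pcp_df; only the boolean 'in_train' column
# matters for the split.
pcp_df = pd.DataFrame(
    {"parent": list("ABCDE"), "in_train": [True, True, False, True, False]}
)
train_df = pcp_df[pcp_df["in_train"]].reset_index(drop=True)
val_df = pcp_df[~pcp_df["in_train"]].reset_index(drop=True)
print(len(train_df), len(val_df))  # 3 2
```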
@@ -160,6 +175,7 @@ def clone(self):
self.all_rates.copy(),
self.all_subs_probs.copy(),
self._branch_lengths.copy(),
multihit_model=self.multihit_model,
)
return new_dataset

@@ -176,6 +192,7 @@ def subset_via_indices(self, indices):
self.all_rates[indices],
self.all_subs_probs[indices],
self._branch_lengths[indices],
multihit_model=self.multihit_model,
)
return new_dataset

@@ -231,6 +248,10 @@ def update_neutral_probs(self):
mask = mask.to("cpu")
rates = rates.to("cpu")
subs_probs = subs_probs.to("cpu")
if self.multihit_model is not None:
multihit_model = copy.deepcopy(self.multihit_model).to("cpu")
else:
multihit_model = None
# Note we are replacing all Ns with As, which means that we need to be careful
# with masking out these positions later. We do this below.
parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
@@ -245,6 +266,7 @@
parent_idxs.reshape(-1, 3),
mut_probs.reshape(-1, 3),
normed_subs_probs.reshape(-1, 3, 4),
multihit_model=multihit_model,
)

if not torch.isfinite(neutral_aa_mut_prob).all():
@@ -295,6 +317,8 @@ def to(self, device):
self.log_neutral_aa_mut_probs = self.log_neutral_aa_mut_probs.to(device)
self.all_rates = self.all_rates.to(device)
self.all_subs_probs = self.all_subs_probs.to(device)
if self.multihit_model is not None:
self.multihit_model = self.multihit_model.to(device)


class DNSMBurrito(framework.Burrito):
@@ -366,15 +390,16 @@ def _find_optimal_branch_length(
rates,
subs_probs,
starting_branch_length,
multihit_model,
**optimization_kwargs,
):
if parent == child:
return 0.0
sel_matrix = self.build_selection_matrix_from_parent(parent)
log_pcp_probability = molevol.mutsel_log_pcp_probability_of(
sel_matrix, parent, child, rates, subs_probs
sel_matrix, parent, child, rates, subs_probs, multihit_model
)
if type(starting_branch_length) == torch.Tensor:
if isinstance(starting_branch_length, torch.Tensor):
starting_branch_length = starting_branch_length.detach().item()
return molevol.optimize_branch_length(
log_pcp_probability, starting_branch_length, **optimization_kwargs
@@ -401,6 +426,7 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
rates[: len(parent)],
subs_probs[: len(parent), :],
starting_length,
dataset.multihit_model,
**optimization_kwargs,
)

@@ -416,7 +442,7 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):

def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
worker_count = min(mp.cpu_count() // 2, 10)
# The following can be used when one wants a better traceback.
# # The following can be used when one wants a better traceback.
# burrito = DNSMBurrito(None, dataset, copy.deepcopy(self.model))
# return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
our_optimize_branch_length = partial(
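Assuming the elided lines follow the usual pattern, the fixed arguments are bound with `partial` and dataset chunks are mapped over a spawned pool. A hedged sketch with a hypothetical per-chunk job:

```python
from functools import partial
import multiprocessing as mp

from netam.common import force_spawn


def optimize_chunk(scale, chunk):
    # Hypothetical per-worker job; the real code maps dataset subsets to
    # serial branch length optimization instead.
    return [scale * x for x in chunk]


if __name__ == "__main__":
    force_spawn()
    worker_count = max(1, min(mp.cpu_count() // 2, 10))
    job = partial(optimize_chunk, 2.0)           # bind the fixed arguments
    with mp.Pool(worker_count) as pool:
        results = pool.map(job, [[1.0, 2.0], [3.0, 4.0], [5.0]])
    print([x for chunk in results for x in chunk])  # [2.0, 4.0, 6.0, 8.0, 10.0]
```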
26 changes: 14 additions & 12 deletions netam/hit_class.py
@@ -1,8 +1,6 @@
import torch
import numpy as np

from netam.common import BASES


# Define the number of bases (e.g., 4 for DNA/RNA)
_num_bases = 4
@@ -43,8 +41,8 @@ def parent_specific_hit_classes(parent_codon_idxs: torch.Tensor) -> torch.Tensor

def apply_multihit_correction(
parent_codon_idxs: torch.Tensor,
codon_logprobs: torch.Tensor,
hit_class_factors: torch.Tensor,
codon_probs: torch.Tensor,
log_hit_class_factors: torch.Tensor,
) -> torch.Tensor:
"""Multiply codon probabilities by their hit class factors, and renormalize.

@@ -53,23 +51,27 @@
Args:
parent_codon_idxs (torch.Tensor): A (N, 3) shaped tensor containing for each codon, the
indices of the parent codon's nucleotides.
codon_logprobs (torch.Tensor): A (N, 4, 4, 4) shaped tensor containing the log probabilities
codon_probs (torch.Tensor): A (N, 4, 4, 4) shaped tensor containing the probabilities
of mutating to each possible target codon, for each of the N parent codons.
hit_class_factors (torch.Tensor): A tensor containing the log hit class factors for hit classes 1, 2, and 3. The
log_hit_class_factors (torch.Tensor): A tensor containing the log hit class factors for hit classes 1, 2, and 3. The
factor for hit class 0 is assumed to be 1 (that is, 0 in log-space).

Returns:
torch.Tensor: A (N, 4, 4, 4) shaped tensor containing the log probabilities of mutating to each possible
torch.Tensor: A (N, 4, 4, 4) shaped tensor containing the probabilities of mutating to each possible
target codon, for each of the N parent codons, after applying the hit class factors.
"""
per_parent_hit_class = parent_specific_hit_classes(parent_codon_idxs)
corrections = torch.cat([torch.tensor([0.0]), hit_class_factors])
corrections = torch.cat([torch.tensor([0.0]), log_hit_class_factors]).exp()
reshaped_corrections = corrections[per_parent_hit_class]
unnormalized_corrected_logprobs = codon_logprobs + reshaped_corrections
normalizations = torch.logsumexp(
unnormalized_corrected_logprobs, dim=[1, 2, 3], keepdim=True
unnormalized_corrected_probs = codon_probs * reshaped_corrections
normalizations = torch.sum(
unnormalized_corrected_probs, dim=[1, 2, 3], keepdim=True
)
return unnormalized_corrected_logprobs - normalizations
result = unnormalized_corrected_probs / normalizations
if torch.any(torch.isnan(result)):
print("NAN found in multihit correction application")
assert False
return result


def hit_class_probs_tensor(
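The rewritten `apply_multihit_correction` works in probability space: each target codon's probability is scaled by the exponentiated log factor of its hit class (the number of positions at which it differs from the parent), and the 4×4×4 distribution is then renormalized. A toy sketch of that computation (values made up):

```python
import torch

# Codons differing from the parent at k positions form hit class k; each
# class gets a factor exp(log_factor[k]), with class 0 fixed at exp(0) = 1.
parent = torch.tensor([0, 1, 2])                 # one parent codon, e.g. ACG
log_hit_class_factors = torch.tensor([0.1, -0.2, -0.5])  # classes 1..3

idx = torch.arange(4)
hit_class = (                                    # (4, 4, 4) hit class tensor
    (idx[:, None, None] != parent[0]).long()
    + (idx[None, :, None] != parent[1]).long()
    + (idx[None, None, :] != parent[2]).long()
)
corrections = torch.cat([torch.tensor([0.0]), log_hit_class_factors]).exp()

codon_probs = torch.full((4, 4, 4), 1 / 64)      # uniform toy distribution
corrected = codon_probs * corrections[hit_class]
corrected = corrected / corrected.sum()          # renormalize over all 64
print(corrected.sum())                           # tensor(1.0000)
```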
12 changes: 6 additions & 6 deletions netam/models.py
@@ -2,10 +2,6 @@
import math
import warnings

warnings.filterwarnings(
"ignore", category=UserWarning, module="torch.nn.modules.transformer"
)

import pandas as pd

import torch
@@ -22,6 +18,10 @@
aa_mask_tensor_of,
)

warnings.filterwarnings(
"ignore", category=UserWarning, module="torch.nn.modules.transformer"
)


class ModelBase(nn.Module):
def reinitialize_weights(self):
@@ -705,13 +705,13 @@ def hyperparameters(self):
return {}

def forward(
self, parent_codon_idxs: torch.Tensor, uncorrected_log_codon_probs: torch.Tensor
self, parent_codon_idxs: torch.Tensor, uncorrected_codon_probs: torch.Tensor
):
"""Forward function takes a tensor of target codon distributions, for each
observed parent codon, and adjusts the distributions according to the hit class
corrections."""
return apply_multihit_correction(
parent_codon_idxs, uncorrected_log_codon_probs, self.values
parent_codon_idxs, uncorrected_codon_probs, self.values
)

def reinitialize_weights(self):
28 changes: 23 additions & 5 deletions netam/molevol.py
@@ -283,6 +283,7 @@ def build_codon_mutsel(
codon_mut_probs: Tensor,
codon_sub_probs: Tensor,
aa_sel_matrices: Tensor,
multihit_model=None,
) -> Tensor:
"""Build a sequence of codon mutation-selection matrices for codons along a
sequence.
@@ -301,6 +302,9 @@
)
codon_probs = codon_probs_of_mutation_matrices(mut_matrices)

if multihit_model is not None:
codon_probs = multihit_model(parent_codon_idxs, codon_probs)

# Calculate the codon selection matrix for each sequence via Einstein
# summation, in which we sum over the repeated indices.
# So, for each site (s) and codon (c), sum over amino acids (a):
@@ -339,6 +343,7 @@ def neutral_aa_probs(
parent_codon_idxs: Tensor,
codon_mut_probs: Tensor,
codon_sub_probs: Tensor,
multihit_model=None,
) -> Tensor:
"""For every site, what is the probability that the amino acid will mutate to every
amino acid?
@@ -356,10 +361,13 @@
mut_matrices = build_mutation_matrices(
parent_codon_idxs, codon_mut_probs, codon_sub_probs
)
codon_probs = codon_probs_of_mutation_matrices(mut_matrices).view(-1, 64)
codon_probs = codon_probs_of_mutation_matrices(mut_matrices)

if multihit_model is not None:
codon_probs = multihit_model(parent_codon_idxs, codon_probs)

# Get the probability of mutating to each amino acid.
aa_probs = codon_probs @ CODON_AA_INDICATOR_MATRIX
aa_probs = codon_probs.view(-1, 64) @ CODON_AA_INDICATOR_MATRIX

return aa_probs

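The indicator-matrix trick above converts per-codon probabilities to per-amino-acid probabilities with a single matmul. A toy sketch (the fake 0/1 map below stands in for CODON_AA_INDICATOR_MATRIX):

```python
import torch

# Flattening the (4, 4, 4) codon axes to 64 turns the conversion into one matmul.
codon_probs = torch.rand(5, 4, 4, 4)
codon_probs = codon_probs / codon_probs.sum(dim=(1, 2, 3), keepdim=True)

indicator = torch.zeros(64, 20)
indicator[torch.arange(64), torch.arange(64) % 20] = 1.0  # fake codon->aa map

aa_probs = codon_probs.view(-1, 64) @ indicator
print(aa_probs.shape)          # torch.Size([5, 20])
print(aa_probs.sum(dim=1))     # each row sums to 1 with this toy map
```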
@@ -396,6 +404,7 @@ def neutral_aa_mut_probs(
parent_codon_idxs: Tensor,
codon_mut_probs: Tensor,
codon_sub_probs: Tensor,
multihit_model=None,
) -> Tensor:
"""For every site, what is the probability that the amino acid will have a
substitution or mutate to a stop under neutral evolution?
@@ -414,12 +423,19 @@
Shape: (codon_count,)
"""

aa_probs = neutral_aa_probs(parent_codon_idxs, codon_mut_probs, codon_sub_probs)
aa_probs = neutral_aa_probs(
parent_codon_idxs,
codon_mut_probs,
codon_sub_probs,
multihit_model=multihit_model,
)
mut_probs = mut_probs_of_aa_probs(parent_codon_idxs, aa_probs)
return mut_probs


def mutsel_log_pcp_probability_of(sel_matrix, parent, child, rates, sub_probs):
def mutsel_log_pcp_probability_of(
sel_matrix, parent, child, rates, sub_probs, multihit_model=None
):
"""Constructs the log_pcp_probability function specific to given rates and
sub_probs.

@@ -442,6 +458,7 @@ def log_pcp_probability(log_branch_length: torch.Tensor):
mut_probs.reshape(-1, 3),
sub_probs.reshape(-1, 3, 4),
sel_matrix,
multihit_model=multihit_model,
)

# This is a diagnostic generating data for netam issue #7.
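`mutsel_log_pcp_probability_of` returns a closure over the fixed data, so `optimize_branch_length` sees a one-dimensional function of log branch length. A sketch of the pattern with a stand-in objective (not netam's actual likelihood):

```python
import torch


def make_log_pcp_probability(observed):
    # Bind the fixed data once; the returned closure depends only on the
    # log branch length, so the optimizer sees a one-dimensional problem.
    def log_pcp_probability(log_branch_length: torch.Tensor):
        branch_length = torch.exp(log_branch_length)
        return -((observed - branch_length) ** 2).sum()  # stand-in objective

    return log_pcp_probability


f = make_log_pcp_probability(torch.tensor([0.5, 1.5]))
print(f(torch.tensor(0.0)))  # objective at branch length exp(0) = 1
```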
@@ -496,7 +513,8 @@ def optimize_branch_length(
loss.backward()
torch.nn.utils.clip_grad_norm_([log_branch_length], max_norm=5.0)
optimizer.step()
assert not torch.isnan(log_branch_length)
if torch.isnan(log_branch_length):
raise ValueError("branch length optimization resulted in NAN")

change_in_log_branch_length = torch.abs(
log_branch_length - prev_log_branch_length
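The loop above is a clipped-gradient scalar optimization with the new NaN guard replacing the bare assert. A self-contained sketch with a stand-in objective (not netam's likelihood):

```python
import torch

# Optimize a scalar log branch length toward a toy target of 0.7.
log_branch_length = torch.tensor(0.0, requires_grad=True)
optimizer = torch.optim.Adam([log_branch_length], lr=0.05)

for _ in range(300):
    optimizer.zero_grad()
    loss = (log_branch_length.exp() - 0.7) ** 2
    loss.backward()
    torch.nn.utils.clip_grad_norm_([log_branch_length], max_norm=5.0)
    optimizer.step()
    if torch.isnan(log_branch_length):
        raise ValueError("branch length optimization resulted in NAN")

print(round(log_branch_length.exp().item(), 2))  # ~0.7
```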