First approximation to a DASM: 20 output dimensions but same loss #64

Merged 18 commits on Oct 14, 2024
188 changes: 188 additions & 0 deletions netam/dasm.py
@@ -0,0 +1,188 @@
"""Here we define a mutation-selection model that is per-amino-acid."""

import torch
import torch.nn.functional as F

# Amazingly, using one thread makes things 50x faster for branch length
# optimization on our server.
torch.set_num_threads(1)

import numpy as np
import pandas as pd

from netam.common import (
clamp_probability,
)
import netam.dnsm as dnsm
import netam.molevol as molevol
import netam.sequences as sequences
from netam.sequences import (
translate_sequence,
)


class DASMDataset(dnsm.DNSMDataset):

def update_neutral_probs(self):
neutral_aa_probs_l = []

for nt_parent, mask, rates, branch_length, subs_probs in zip(
self.nt_parents,
self.mask,
self.all_rates,
self._branch_lengths,
self.all_subs_probs,
):
mask = mask.to("cpu")
rates = rates.to("cpu")
subs_probs = subs_probs.to("cpu")
# Note we are replacing all Ns with As, which means that we need to be careful
# with masking out these positions later. We do this below.
parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
parent_len = len(nt_parent)

mut_probs = 1.0 - torch.exp(-branch_length * rates[:parent_len])
normed_subs_probs = molevol.normalize_sub_probs(
parent_idxs, subs_probs[:parent_len, :]
)

neutral_aa_probs = molevol.neutral_aa_probs(
parent_idxs.reshape(-1, 3),
mut_probs.reshape(-1, 3),
normed_subs_probs.reshape(-1, 3, 4),
)

if not torch.isfinite(neutral_aa_probs).all():
print(f"Found a non-finite neutral_aa_probs")
print(f"nt_parent: {nt_parent}")
print(f"mask: {mask}")
print(f"rates: {rates}")
print(f"subs_probs: {subs_probs}")
print(f"branch_length: {branch_length}")
raise ValueError(f"neutral_aa_probs is not finite: {neutral_aa_probs}")

# Ensure that all values are positive before taking the log later
neutral_aa_probs = clamp_probability(neutral_aa_probs)

pad_len = self.max_aa_seq_len - neutral_aa_probs.shape[0]
if pad_len > 0:
neutral_aa_probs = F.pad(
neutral_aa_probs, (0, 0, 0, pad_len), value=1e-8
)
# Here we zero out masked positions.
neutral_aa_probs *= mask[:, None]

neutral_aa_probs_l.append(neutral_aa_probs)

# Note that our masked out positions will have a log probability of negative
# infinity, which will require us to handle them correctly downstream.
self.log_neutral_aa_probs = torch.log(torch.stack(neutral_aa_probs_l))

def __getitem__(self, idx):
return {
"aa_parents_idxs": self.aa_parents_idxs[idx],
"subs_indicator": self.aa_subs_indicator_tensor[idx],
"mask": self.mask[idx],
"log_neutral_aa_probs": self.log_neutral_aa_probs[idx],
"rates": self.all_rates[idx],
"subs_probs": self.all_subs_probs[idx],
}

def to(self, device):
self.aa_parents_idxs = self.aa_parents_idxs.to(device)
self.aa_subs_indicator_tensor = self.aa_subs_indicator_tensor.to(device)
self.mask = self.mask.to(device)
self.log_neutral_aa_probs = self.log_neutral_aa_probs.to(device)
self.all_rates = self.all_rates.to(device)
self.all_subs_probs = self.all_subs_probs.to(device)


def zero_predictions_along_diagonal(predictions, aa_parents_idxs):
"""Zero out the diagonal of a batch of predictions.

We do this so that when we sum over the amino-acid dimension we get the same
type of predictions as for the DNSM.
"""
# We would like to do
# predictions[torch.arange(len(aa_parents_idxs)), aa_parents_idxs] = 0.0
# but we have a batch dimension. Thus the following.

batch_size, L, _ = predictions.shape
batch_indices = torch.arange(batch_size, device=predictions.device)
predictions[
batch_indices[:, None],
torch.arange(L, device=predictions.device),
aa_parents_idxs,
] = 0.0

return predictions


class DASMBurrito(dnsm.DNSMBurrito):

def prediction_pair_of_batch(self, batch):
"""Get log neutral AA probabilities and log selection factors for a batch of
data."""
aa_parents_idxs = batch["aa_parents_idxs"].to(self.device)
mask = batch["mask"].to(self.device)
log_neutral_aa_probs = batch["log_neutral_aa_probs"].to(self.device)
if not torch.isfinite(log_neutral_aa_probs[mask]).all():
raise ValueError(
f"log_neutral_aa_probs has non-finite values at relevant positions: {log_neutral_aa_probs[mask]}"
)
log_selection_factors = self.model(aa_parents_idxs, mask)
return log_neutral_aa_probs, log_selection_factors

def predictions_of_pair(self, log_neutral_aa_probs, log_selection_factors):
# Take the product of the neutral mutation probabilities and the selection factors.
# NOTE: each of these now has a last dimension of 20
# this is p_{j, a} * f_{j, a}
predictions = torch.exp(log_neutral_aa_probs + log_selection_factors)
assert torch.isfinite(predictions).all()
predictions = clamp_probability(predictions)
return predictions

def predictions_of_batch(self, batch):
"""Make predictions for a batch of data.

Note that we use the mask for prediction as part of the input for the
transformer, though we don't mask the predictions themselves.
"""
log_neutral_aa_probs, log_selection_factors = self.prediction_pair_of_batch(
batch
)
return self.predictions_of_pair(log_neutral_aa_probs, log_selection_factors)

def loss_of_batch(self, batch):
aa_subs_indicator = batch["subs_indicator"].to(self.device)
mask = batch["mask"].to(self.device)
aa_parents_idxs = batch["aa_parents_idxs"].to(self.device)
aa_subs_indicator = aa_subs_indicator.masked_select(mask)
predictions = self.predictions_of_batch(batch)
# Add one entry, zero, to the last dimension of the predictions tensor
# to handle the ambiguous amino acids. This is the conservative choice.
# It might be faster to reassign all the 20s to 0s if we are confident
# in our masking. Perhaps we should always output a 21st dimension
# for the ambiguous amino acids (see issue #16).
# If we change something here we should also change the test code
# in test_dasm.py::test_zero_diagonal.
predictions = torch.cat(
[predictions, torch.zeros_like(predictions[:, :, :1])], dim=-1
)

predictions = zero_predictions_along_diagonal(predictions, aa_parents_idxs)

predictions_of_mut = torch.sum(predictions, dim=-1)
predictions_of_mut = predictions_of_mut.masked_select(mask)
predictions_of_mut = clamp_probability(predictions_of_mut)
return self.bce_loss(predictions_of_mut, aa_subs_indicator)

def build_selection_matrix_from_parent(self, parent: str):
# This is simpler than the equivalent in dnsm.py because we get the selection
# matrix directly.
parent = translate_sequence(parent)
selection_factors = self.model.selection_factors_of_aa_str(parent)
parent_idxs = sequences.aa_idx_array_of_str(parent)
selection_factors[torch.arange(len(parent_idxs)), parent_idxs] = 1.0

return selection_factors
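
To make the loss path above concrete, here is a hedged, self-contained sketch of what `predictions_of_pair`, `zero_predictions_along_diagonal`, and `loss_of_batch` do together: exponentiate p_{j, a} * f_{j, a}, append a zero column for ambiguous amino acids, zero out the parent amino acid at each site, and sum over amino acids to get a per-site mutation probability compared to the substitution indicator with BCE. Shapes and values are toy illustrations, not netam code.

```python
import torch

B, L, A = 2, 5, 20
log_neutral_aa_probs = torch.log(torch.full((B, L, A), 0.01))
log_selection_factors = torch.zeros(B, L, A)  # selection factor of 1 everywhere
aa_parents_idxs = torch.randint(0, A, (B, L))
aa_subs_indicator = torch.zeros(B, L)
mask = torch.ones(B, L, dtype=torch.bool)

# p_{j, a} * f_{j, a}, as in predictions_of_pair
predictions = torch.exp(log_neutral_aa_probs + log_selection_factors)
# add a zero column for ambiguous amino acids, as in loss_of_batch
predictions = torch.cat([predictions, torch.zeros_like(predictions[:, :, :1])], dim=-1)
# zero out the parent amino acid at each site, as in zero_predictions_along_diagonal
batch_idx = torch.arange(B)[:, None]
site_idx = torch.arange(L)
predictions[batch_idx, site_idx, aa_parents_idxs] = 0.0
# summing over amino acids gives a per-site mutation probability, DNSM-style
predictions_of_mut = predictions.sum(dim=-1).masked_select(mask)
loss = torch.nn.functional.binary_cross_entropy(
    predictions_of_mut, aa_subs_indicator.masked_select(mask)
)
print(loss)
```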
69 changes: 42 additions & 27 deletions netam/dnsm.py
@@ -9,6 +9,7 @@

import copy
import multiprocessing as mp
from functools import partial

import torch
from torch.utils.data import Dataset
@@ -89,7 +90,7 @@ def __init__(
assert torch.max(self.aa_parents_idxs) <= MAX_AMBIG_AA_IDX

self._branch_lengths = branch_lengths
self.update_neutral_aa_mut_probs()
self.update_neutral_probs()

@classmethod
def of_seriess(
@@ -134,9 +135,31 @@ def of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0):
branch_length_multiplier=branch_length_multiplier,
)

@classmethod
def train_val_datasets_of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0):
"""Perform a train-val split based on the 'in_train' column.

This is a class method so it works for subclasses.
"""
train_df = pcp_df[pcp_df["in_train"]].reset_index(drop=True)
val_df = pcp_df[~pcp_df["in_train"]].reset_index(drop=True)

val_dataset = cls.of_pcp_df(
val_df, branch_length_multiplier=branch_length_multiplier
)

if len(train_df) == 0:
return None, val_dataset
# else:
train_dataset = cls.of_pcp_df(
train_df, branch_length_multiplier=branch_length_multiplier
)

return train_dataset, val_dataset

def clone(self):
"""Make a deep copy of the dataset."""
new_dataset = DNSMDataset(
new_dataset = self.__class__(
self.nt_parents,
self.nt_children,
self.all_rates.copy(),
@@ -152,7 +175,7 @@ def subset_via_indices(self, indices):
depends on `indices`: if `indices` is an iterable of integers, then we
make a deep copy, otherwise we use slices to make a shallow copy.
"""
new_dataset = DNSMDataset(
new_dataset = self.__class__(
self.nt_parents[indices].reset_index(drop=True),
self.nt_children[indices].reset_index(drop=True),
self.all_rates[indices],
@@ -181,7 +204,7 @@ def branch_lengths(self, new_branch_lengths):
)
assert torch.all(torch.isfinite(new_branch_lengths) & (new_branch_lengths > 0))
self._branch_lengths = new_branch_lengths
self.update_neutral_aa_mut_probs()
self.update_neutral_probs()

def export_branch_lengths(self, out_csv_path):
pd.DataFrame({"branch_length": self.branch_lengths}).to_csv(
@@ -193,7 +216,14 @@ def load_branch_lengths(self, in_csv_path):
pd.read_csv(in_csv_path)["branch_length"].values
)

def update_neutral_aa_mut_probs(self):
def update_neutral_probs(self):
"""Update the neutral mutation probabilities for the dataset.

This is a deliberately general name: it covers both the DNSM case (the neutral
probability of any nonsynonymous mutation) and the DASM case (the neutral
probabilities of mutation to each amino acid).
"""
neutral_aa_mut_prob_l = []

for nt_parent, mask, rates, branch_length, subs_probs in zip(
@@ -272,25 +302,6 @@ def to(self, device):
self.all_subs_probs = self.all_subs_probs.to(device)


def train_val_datasets_of_pcp_df(pcp_df, branch_length_multiplier=5.0):
"""Perform a train-val split based on a "in_train" column.

Stays here so it can be used in tests.
"""
train_df = pcp_df[pcp_df["in_train"]].reset_index(drop=True)
val_df = pcp_df[~pcp_df["in_train"]].reset_index(drop=True)
val_dataset = DNSMDataset.of_pcp_df(
val_df, branch_length_multiplier=branch_length_multiplier
)
if len(train_df) == 0:
return None, val_dataset
# else:
train_dataset = DNSMDataset.of_pcp_df(
train_df, branch_length_multiplier=branch_length_multiplier
)
return train_dataset, val_dataset


class DNSMBurrito(framework.Burrito):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -413,10 +424,14 @@ def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
# The following can be used when one wants a better traceback.
# burrito = DNSMBurrito(None, dataset, copy.deepcopy(self.model))
# return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
our_optimize_branch_length = partial(
worker_optimize_branch_length,
self.__class__,
)
with mp.Pool(worker_count) as pool:
splits = dataset.split(worker_count)
results = pool.starmap(
worker_optimize_branch_length,
our_optimize_branch_length,
[(self.model, split, optimization_kwargs) for split in splits],
)
return torch.cat(results)
@@ -436,9 +451,9 @@ def to_crepe(self):
return framework.Crepe(encoder, self.model, training_hyperparameters)


def worker_optimize_branch_length(model, dataset, optimization_kwargs):
def worker_optimize_branch_length(burrito_class, model, dataset, optimization_kwargs):
"""The worker used for parallel branch length optimization."""
burrito = DNSMBurrito(None, dataset, copy.deepcopy(model))
burrito = burrito_class(None, dataset, copy.deepcopy(model))
return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)


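The change above threads the concrete Burrito class into the worker via `functools.partial`, so a subclass such as `DASMBurrito` spawns workers of its own type rather than hard-coded `DNSMBurrito` instances. A minimal sketch of the pattern, with made-up names (`fake_worker`, `Widget`) that are not netam identifiers:

```python
# Illustrative sketch of binding a class with functools.partial for pool.starmap.
from functools import partial
from multiprocessing import Pool


class Widget:
    pass


def fake_worker(cls, model, split, kwargs):
    # netam's real worker builds cls(None, split, copy.deepcopy(model)) here.
    return f"{cls.__name__}: model={model}, split={split}, kwargs={kwargs}"


if __name__ == "__main__":
    bound = partial(fake_worker, Widget)  # Widget stands in for self.__class__
    with Pool(2) as pool:
        print(pool.starmap(bound, [("m", i, {}) for i in range(4)]))
```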
4 changes: 3 additions & 1 deletion netam/models.py
@@ -569,6 +569,7 @@ def __init__(
dim_feedforward: int,
layer_count: int,
dropout_prob: float = 0.5,
output_dim: int = 1,
):
super().__init__()
# Note that d_model has to be divisible by nhead, so we make that
@@ -586,7 +587,7 @@ def __init__(
batch_first=True,
)
self.encoder = nn.TransformerEncoder(self.encoder_layer, layer_count)
self.linear = nn.Linear(self.d_model, 1)
self.linear = nn.Linear(self.d_model, output_dim)
self.init_weights()

@property
Expand All @@ -597,6 +598,7 @@ def hyperparameters(self):
"dim_feedforward": self.dim_feedforward,
"layer_count": self.encoder.num_layers,
"dropout_prob": self.pos_encoder.dropout.p,
"output_dim": self.linear.out_features,
}

def init_weights(self) -> None:
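The new `output_dim` hyperparameter only changes the width of the final linear head, so setting it to 20 for the DASM yields one selection factor per amino acid per site. A shape check under that assumption, using generic tensors rather than the netam model:

```python
# Shape sketch: a Linear head with output_dim=20 maps per-site embeddings to
# per-amino-acid outputs. The dimensions below are arbitrary illustrative values.
import torch
import torch.nn as nn

d_model, output_dim = 64, 20
head = nn.Linear(d_model, output_dim)
encoded = torch.randn(8, 120, d_model)  # (batch, sites, d_model) from the encoder
out = head(encoded)
assert out.shape == (8, 120, 20)  # one value per site per amino acid
```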
17 changes: 17 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,17 @@
import pytest
from netam.framework import (
load_pcp_df,
add_shm_model_outputs_to_pcp_df,
)


@pytest.fixture(scope="module")
def pcp_df():
df = load_pcp_df(
"data/wyatt-10x-1p5m_pcp_2023-11-30_NI.first100.csv.gz",
)
df = add_shm_model_outputs_to_pcp_df(
df,
"data/cnn_joi_sml-shmoof_small",
)
return df
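A hedged sketch of how a test might consume this module-scoped fixture together with the new `train_val_datasets_of_pcp_df` class method; the test name and assertions are illustrative and not part of this PR:

```python
# Illustrative test using the pcp_df fixture; assumes the datasets define __len__,
# which is not shown in this diff.
from netam.dasm import DASMDataset


def test_train_val_split(pcp_df):
    train, val = DASMDataset.train_val_datasets_of_pcp_df(
        pcp_df, branch_length_multiplier=5.0
    )
    assert val is not None
    if train is not None:
        assert len(train) + len(val) == len(pcp_df)
```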