DASM code runs, who knows if it does something useful
matsen committed Sep 30, 2024
1 parent 39e52c2 commit e42529e
Showing 3 changed files with 300 additions and 0 deletions.
226 changes: 226 additions & 0 deletions netam/dasm.py
@@ -0,0 +1,226 @@
"""Here we define a mutation-selection model that is per-amino-acid."""

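# A note on the model implemented below (an interpretive summary; see
# DASMBurrito.predictions_of_pair): the predicted probability of a given
# amino acid at a given site is the product of a neutral mutation probability
# and a learned per-site, per-amino-acid selection factor, computed in log
# space as exp(log_neutral_aa_probs + log_selection_factors).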
import copy
import multiprocessing as mp

import torch
import torch.nn.functional as F

# Amazingly, using one thread makes things 50x faster for branch length
# optimization on our server.
torch.set_num_threads(1)

import numpy as np
import pandas as pd

from tqdm import tqdm

from netam.common import (
    MAX_AMBIG_AA_IDX,
    aa_idx_tensor_of_str_ambig,
    clamp_probability,
    aa_mask_tensor_of,
    stack_heterogeneous,
)
import netam.dnsm as dnsm
import netam.framework as framework
from netam.hyper_burrito import HyperBurrito
import netam.molevol as molevol
import netam.sequences as sequences
from netam.sequences import (
    aa_subs_indicator_tensor_of,
    translate_sequence,
    translate_sequences,
)

class DASMDataset(dnsm.DNSMDataset):

    # TODO should we rename this?
    def update_neutral_aa_mut_probs(self):
        neutral_aa_probs_l = []

        for nt_parent, mask, rates, branch_length, subs_probs in zip(
            self.nt_parents,
            self.mask,
            self.all_rates,
            self._branch_lengths,
            self.all_subs_probs,
        ):
            mask = mask.to("cpu")
            rates = rates.to("cpu")
            subs_probs = subs_probs.to("cpu")
            # Note that we are replacing all Ns with As, which means that we
            # need to be careful to mask out these positions later. We do
            # this below.
            parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
            parent_len = len(nt_parent)

            # Probability of at least one mutation at each site along the
            # branch: 1 - exp(-branch_length * rate).
            mut_probs = 1.0 - torch.exp(-branch_length * rates[:parent_len])
            normed_subs_probs = molevol.normalize_sub_probs(
                parent_idxs, subs_probs[:parent_len, :]
            )

            # BIG TODO here: this is where we need to change to the new
            # per-aa probabilities. Then we need to think carefully about
            # masking, padding, etc.
            neutral_aa_probs = molevol.neutral_aa_probs(
                parent_idxs.reshape(-1, 3),
                mut_probs.reshape(-1, 3),
                normed_subs_probs.reshape(-1, 3, 4),
            )

            if not torch.isfinite(neutral_aa_probs).all():
                print("Found a non-finite neutral_aa_probs")
                print(f"nt_parent: {nt_parent}")
                print(f"mask: {mask}")
                print(f"rates: {rates}")
                print(f"subs_probs: {subs_probs}")
                print(f"branch_length: {branch_length}")
                raise ValueError(
                    f"neutral_aa_probs is not finite: {neutral_aa_probs}"
                )

            # Ensure that all values are positive before taking the log later.
            neutral_aa_probs = clamp_probability(neutral_aa_probs)

            pad_len = self.max_aa_seq_len - neutral_aa_probs.shape[0]
            if pad_len > 0:
                neutral_aa_probs = F.pad(
                    neutral_aa_probs, (0, 0, 0, pad_len), value=1e-8
                )
            # Here we zero out masked positions.
            neutral_aa_probs *= mask[:, None]

            neutral_aa_probs_l.append(neutral_aa_probs)

        # Note that our masked-out positions will have a non-finite (-inf) log
        # probability, which will require us to handle them correctly
        # downstream.
        self.log_neutral_aa_probs = torch.log(torch.stack(neutral_aa_probs_l))
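        # A note on shapes (inferred from the padding and stacking above):
        # self.log_neutral_aa_probs has shape (n_sequences, max_aa_seq_len, 20).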

    def __getitem__(self, idx):
        return {
            "aa_parents_idxs": self.aa_parents_idxs[idx],
            "subs_indicator": self.aa_subs_indicator_tensor[idx],
            "mask": self.mask[idx],
            "log_neutral_aa_probs": self.log_neutral_aa_probs[idx],
            "rates": self.all_rates[idx],
            "subs_probs": self.all_subs_probs[idx],
        }

    def to(self, device):
        self.aa_parents_idxs = self.aa_parents_idxs.to(device)
        self.aa_subs_indicator_tensor = self.aa_subs_indicator_tensor.to(device)
        self.mask = self.mask.to(device)
        self.log_neutral_aa_probs = self.log_neutral_aa_probs.to(device)
        self.all_rates = self.all_rates.to(device)
        self.all_subs_probs = self.all_subs_probs.to(device)


# TODO code dup
def train_val_datasets_of_pcp_df(pcp_df, branch_length_multiplier=5.0):
    """Perform a train-val split based on an "in_train" column.

    Stays here so it can be used in tests.
    """
    train_df = pcp_df[pcp_df["in_train"]].reset_index(drop=True)
    val_df = pcp_df[~pcp_df["in_train"]].reset_index(drop=True)
    val_dataset = DASMDataset.of_pcp_df(
        val_df, branch_length_multiplier=branch_length_multiplier
    )
    if len(train_df) == 0:
        return None, val_dataset
    train_dataset = DASMDataset.of_pcp_df(
        train_df, branch_length_multiplier=branch_length_multiplier
    )
    return train_dataset, val_dataset


class DASMBurrito(dnsm.DNSMBurrito):

    def prediction_pair_of_batch(self, batch):
        """Get log neutral AA probabilities and log selection factors for a
        batch of data."""
        aa_parents_idxs = batch["aa_parents_idxs"].to(self.device)
        mask = batch["mask"].to(self.device)
        log_neutral_aa_probs = batch["log_neutral_aa_probs"].to(self.device)
        if not torch.isfinite(log_neutral_aa_probs[mask]).all():
            raise ValueError(
                "log_neutral_aa_probs has non-finite values at relevant positions: "
                f"{log_neutral_aa_probs[mask]}"
            )
        log_selection_factors = self.model(aa_parents_idxs, mask)
        return log_neutral_aa_probs, log_selection_factors

    def predictions_of_pair(self, log_neutral_aa_probs, log_selection_factors):
        # Take the product of the neutral mutation probabilities and the
        # selection factors.
        predictions = torch.exp(log_neutral_aa_probs + log_selection_factors)
        assert torch.isfinite(predictions).all()
        predictions = clamp_probability(predictions)
        return predictions

    def predictions_of_batch(self, batch):
        """Make predictions for a batch of data.

        Note that the mask is used as part of the input to the transformer,
        though we don't mask the predictions themselves.
        """
        log_neutral_aa_probs, log_selection_factors = self.prediction_pair_of_batch(
            batch
        )
        return self.predictions_of_pair(log_neutral_aa_probs, log_selection_factors)

    def loss_of_batch(self, batch):
        aa_subs_indicator = batch["subs_indicator"].to(self.device)
        mask = batch["mask"].to(self.device)
        aa_parents_idxs = batch["aa_parents_idxs"].to(self.device)
        aa_subs_indicator = aa_subs_indicator.masked_select(mask)
        predictions = self.predictions_of_batch(batch)
        # Add one zero entry to the last dimension of the predictions tensor
        # to handle the ambiguous amino acids.
        # TODO perhaps we can do better.
        predictions = torch.cat(
            [predictions, torch.zeros_like(predictions[:, :, :1])], dim=-1
        )
        # Now we make predictions of mutation by summing over all amino acids
        # other than the parent one.

        # Get batch size and sequence length.
        batch_size, L, _ = predictions.shape
        # Create row indices for each sequence in the batch.
        batch_indices = torch.arange(batch_size, device=self.device)
        # Zero out the "diagonal": set predictions[batch_idx, site_idx,
        # parent_aa_idx] to 0 so that the sum below covers only mutations.
        predictions[
            batch_indices[:, None],
            torch.arange(L, device=self.device),
            aa_parents_idxs,
        ] = 0.0

        predictions_of_mut = torch.sum(predictions, dim=-1)
        predictions_of_mut = predictions_of_mut.masked_select(mask)
        return self.bce_loss(predictions_of_mut, aa_subs_indicator)

    def build_selection_matrix_from_parent(self, parent: str):
        # This is simpler than the equivalent in dnsm.py because we get the
        # selection matrix directly. The entry for the parent amino acid at
        # each site is set to 1.0, i.e. a neutral selection factor.
        parent = translate_sequence(parent)
        selection_factors = self.model.selection_factors_of_aa_str(parent)
        parent_idxs = sequences.aa_idx_array_of_str(parent)
        selection_factors[torch.arange(len(parent_idxs)), parent_idxs] = 1.0

        return selection_factors

    # We need to repeat this method so that it uses the
    # worker_optimize_branch_length defined below.
    def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
        worker_count = min(mp.cpu_count() // 2, 10)
        # The following can be used when one wants a better traceback.
        # burrito = DASMBurrito(None, dataset, copy.deepcopy(self.model))
        # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
        with mp.Pool(worker_count) as pool:
            splits = dataset.split(worker_count)
            results = pool.starmap(
                worker_optimize_branch_length,
                [(self.model, split, optimization_kwargs) for split in splits],
            )
        return torch.cat(results)


def worker_optimize_branch_length(model, dataset, optimization_kwargs):
    """The worker used for parallel branch length optimization."""
    burrito = DASMBurrito(None, dataset, copy.deepcopy(model))
    return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
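
For orientation, here is a minimal sketch of how this new module might be driven end to end. It simply mirrors the dasm_burrito fixture in tests/test_dasm.py below; the data paths and hyperparameters are the ones used there, not recommendations.

from netam.framework import load_pcp_df, add_shm_model_outputs_to_pcp_df
from netam.models import TransformerBinarySelectionModelWiggleAct
from netam.dasm import DASMBurrito, train_val_datasets_of_pcp_df

# Load parent-child pairs and attach SHM model outputs (paths from the tests).
pcp_df = load_pcp_df("data/wyatt-10x-1p5m_pcp_2023-11-30_NI.first100.csv.gz")
pcp_df = add_shm_model_outputs_to_pcp_df(pcp_df, "data/cnn_joi_sml-shmoof_small")

# Hold out the last 15 rows for validation, as in the test fixture.
pcp_df["in_train"] = True
pcp_df.loc[pcp_df.index[-15:], "in_train"] = False
train_dataset, val_dataset = train_val_datasets_of_pcp_df(pcp_df)

# A small transformer selection model; output_dim=20 gives one selection
# factor per amino acid at each site.
model = TransformerBinarySelectionModelWiggleAct(
    nhead=2, d_model_per_head=4, dim_feedforward=256, layer_count=2, output_dim=20
)

burrito = DASMBurrito(
    train_dataset,
    val_dataset,
    model,
    batch_size=32,
    learning_rate=0.001,
    min_learning_rate=0.0001,
)
burrito.joint_train(
    epochs=1, cycle_count=2, training_method="full", optimize_bl_first_cycle=False
)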
1 change: 1 addition & 0 deletions netam/molevol.py
@@ -335,6 +335,7 @@ def build_codon_mutsel(
    return codon_mutsel, sums_too_big


# TODO consider a nice name
def neutral_aa_probs(
    parent_codon_idxs: Tensor,
    codon_mut_probs: Tensor,
73 changes: 73 additions & 0 deletions tests/test_dasm.py
@@ -0,0 +1,73 @@
import os

import torch
import pytest

from netam.framework import (
    crepe_exists,
    load_crepe,
    load_pcp_df,
    add_shm_model_outputs_to_pcp_df,
)
from netam.common import aa_idx_tensor_of_str_ambig, MAX_AMBIG_AA_IDX
from netam.models import TransformerBinarySelectionModelWiggleAct
from netam.dasm import DASMBurrito, train_val_datasets_of_pcp_df


# TODO code dup
@pytest.fixture
def pcp_df():
    df = load_pcp_df(
        "data/wyatt-10x-1p5m_pcp_2023-11-30_NI.first100.csv.gz",
    )
    df = add_shm_model_outputs_to_pcp_df(
        df,
        "data/cnn_joi_sml-shmoof_small",
    )
    return df


@pytest.fixture
def dasm_burrito(pcp_df):
    """Fixture that returns the DASM burrito object."""
    pcp_df["in_train"] = True
    pcp_df.loc[pcp_df.index[-15:], "in_train"] = False
    train_dataset, val_dataset = train_val_datasets_of_pcp_df(pcp_df)

    model = TransformerBinarySelectionModelWiggleAct(
        nhead=2,
        d_model_per_head=4,
        dim_feedforward=256,
        layer_count=2,
        output_dim=20,
    )

    burrito = DASMBurrito(
        train_dataset,
        val_dataset,
        model,
        batch_size=32,
        learning_rate=0.001,
        min_learning_rate=0.0001,
    )
    burrito.joint_train(
        epochs=1, cycle_count=2, training_method="full", optimize_bl_first_cycle=False
    )
    return burrito


def test_parallel_branch_length_optimization(dasm_burrito):
    dataset = dasm_burrito.val_dataset
    parallel_branch_lengths = dasm_burrito.find_optimal_branch_lengths(dataset)
    branch_lengths = dasm_burrito.serial_find_optimal_branch_lengths(dataset)
    assert torch.allclose(branch_lengths, parallel_branch_lengths)


def test_crepe_roundtrip(dasm_burrito):
    os.makedirs("_ignore", exist_ok=True)
    crepe_path = "_ignore/dasm"
    dasm_burrito.save_crepe(crepe_path)
    assert crepe_exists(crepe_path)
    crepe = load_crepe(crepe_path)
    model = crepe.model
    assert isinstance(model, TransformerBinarySelectionModelWiggleAct)
    assert dasm_burrito.model.hyperparameters == model.hyperparameters
    model.to(dasm_burrito.device)
    for t1, t2 in zip(
        dasm_burrito.model.state_dict().values(), model.state_dict().values()
    ):
        assert torch.equal(t1, t2)
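
Assuming the data files referenced in the pcp_df fixture are present locally, these tests can be run with a standard pytest invocation such as "pytest tests/test_dasm.py".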
