parallel branch length optimization #28

Merged · 18 commits · Jun 5, 2024
netam/dnsm.py — 124 changes: 102 additions & 22 deletions
@@ -8,6 +8,9 @@

"""

import copy
import multiprocessing as mp

import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
@@ -36,7 +39,6 @@
clamp_probability,
aa_mask_tensor_of,
stack_heterogeneous,
pick_device,
)
import netam.framework as framework
from netam.hyper_burrito import HyperBurrito
@@ -45,16 +47,16 @@
class DNSMDataset(Dataset):
def __init__(
self,
nt_parents,
nt_children,
all_rates,
all_subs_probs,
branch_length_multiplier=5.0,
nt_parents: pd.Series,
nt_children: pd.Series,
all_rates: torch.Tensor,
all_subs_probs: torch.Tensor,
branch_lengths: torch.Tensor,
):
self.nt_parents = nt_parents
self.nt_children = nt_children
self.all_rates = stack_heterogeneous(all_rates.reset_index(drop=True))
self.all_subs_probs = stack_heterogeneous(all_subs_probs.reset_index(drop=True))
self.all_rates = all_rates
self.all_subs_probs = all_subs_probs

assert len(self.nt_parents) == len(self.nt_children)
pcp_count = len(self.nt_parents)
@@ -89,15 +91,77 @@ def __init__(
assert torch.all(self.mask.sum(dim=1) > 0)
assert torch.max(self.aa_parents_idxs) <= MAX_AMBIG_AA_IDX

# Make initial branch lengths (will get optimized later).
self._branch_lengths = np.array(
self._branch_lengths = branch_lengths
self.update_neutral_aa_mut_probs()

@classmethod
def from_data(
cls,
nt_parents: pd.Series,
nt_children: pd.Series,
all_rates_series: pd.Series,
all_subs_probs_series: pd.Series,
branch_length_multiplier=5.0,
):
"""
Alternative constructor that takes the raw data and calculates the
initial branch lengths.

The `_series` arguments are series of Tensors which get stacked to
create the full object.
"""
initial_branch_lengths = np.array(
[
sequences.nt_mutation_frequency(parent, child)
* branch_length_multiplier
for parent, child in zip(self.nt_parents, self.nt_children)
for parent, child in zip(nt_parents, nt_children)
]
)
self.update_neutral_aa_mut_probs()
return cls(
nt_parents.reset_index(drop=True),
nt_children.reset_index(drop=True),
stack_heterogeneous(all_rates_series.reset_index(drop=True)),
stack_heterogeneous(all_subs_probs_series.reset_index(drop=True)),
initial_branch_lengths,
)
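# Illustrative usage sketch (not part of this PR): building a dataset with the
# new classmethod. The names pcp_df, rates_series, and subs_probs_series below
# are hypothetical stand-ins for the caller's per-PCP data.
#
#     dataset = DNSMDataset.from_data(
#         pcp_df["parent"],
#         pcp_df["child"],
#         rates_series,
#         subs_probs_series,
#         branch_length_multiplier=5.0,
#     )
#
# from_data stacks the per-sequence tensors and seeds each branch length with
# the observed nucleotide mutation frequency times the multiplier, whereas the
# plain constructor now expects already-stacked tensors and precomputed branch
# lengths.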

def clone(self):
"""Make a deep copy of the dataset."""
new_dataset = DNSMDataset(
self.nt_parents,
self.nt_children,
self.all_rates.copy(),
self.all_subs_probs.copy(),
self._branch_lengths.copy(),
)
return new_dataset

def subset_via_indices(self, indices):
"""
Create a new dataset with a subset of the data, as per `indices`.

Whether the new dataset is a deep copy or a shallow copy using slices
depends on `indices`: if `indices` is an iterable of integers, then we
make a deep copy, otherwise we use slices to make a shallow copy.
"""
new_dataset = DNSMDataset(
self.nt_parents[indices].reset_index(drop=True),
self.nt_children[indices].reset_index(drop=True),
self.all_rates[indices],
self.all_subs_probs[indices],
self._branch_lengths[indices],
)
return new_dataset

def split(self, into_count: int):
"""
Split self into a list of into_count subsets.
"""
dataset_size = len(self)
indices = list(range(dataset_size))
split_indices = np.array_split(indices, into_count)
subsets = [self.subset_via_indices(split_indices[i]) for i in range(into_count)]
return subsets
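# Illustrative sketch (not part of this PR): split() and subset_via_indices()
# working together; worker_count is a hypothetical value.
#
#     worker_count = 4
#     subsets = dataset.split(worker_count)
#
# Each element of split_indices is an array of integer indices, so per the
# docstring above subset_via_indices takes the deep-copy path: every subset is
# an independent DNSMDataset that a worker can mutate (e.g. by writing back
# optimized branch lengths) without touching the parent dataset.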

@property
def branch_lengths(self):
@@ -216,7 +280,7 @@ def train_test_datasets_of_pcp_df(pcp_df, train_frac=0.8, branch_length_multipli
subs_probs[:train_len],
subs_probs[train_len:],
)
val_dataset = DNSMDataset(
val_dataset = DNSMDataset.from_data(
val_parents,
val_children,
val_rates,
@@ -226,7 +290,7 @@
if train_frac == 0.0:
return None, val_dataset
# else:
train_dataset = DNSMDataset(
train_dataset = DNSMDataset.from_data(
train_parents,
train_children,
train_rates,
@@ -242,13 +306,11 @@ def __init__(self, *args, **kwargs):
self.wrapped_model = WrappedBinaryMutSel(self.model, weights_directory=None)

def load_branch_lengths(self, in_csv_prefix):
if self.train_loader is not None:
self.train_loader.dataset.load_branch_lengths(
if self.train_dataset is not None:
self.train_dataset.load_branch_lengths(
in_csv_prefix + ".train_branch_lengths.csv"
)
self.val_loader.dataset.load_branch_lengths(
in_csv_prefix + ".val_branch_lengths.csv"
)
self.val_dataset.load_branch_lengths(in_csv_prefix + ".val_branch_lengths.csv")

def predictions_of_batch(self, batch):
"""
@@ -292,11 +354,13 @@ def _find_optimal_branch_length(
log_pcp_probability = self.wrapped_model._build_log_pcp_probability(
parent, child, rates, subs_probs
)
if type(starting_branch_length) == torch.Tensor:
starting_branch_length = starting_branch_length.detach().item()
return optimize_branch_length(
log_pcp_probability, starting_branch_length, **optimization_kwargs
)

def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
optimal_lengths = []
failed_count = 0

@@ -330,6 +394,16 @@ def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):

return torch.tensor(optimal_lengths)

def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
worker_count = min(mp.cpu_count() // 2, 10)
with mp.Pool(worker_count) as pool:
splits = dataset.split(worker_count)
results = pool.starmap(
worker_optimize_branch_length,
[(self.model, split, optimization_kwargs) for split in splits],
)
return torch.cat(results)
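# Explanatory note (not part of this PR): the method above fans the work out
# over at most min(cpu_count() // 2, 10) processes, each running the serial
# optimizer on its own split, and torch.cat reassembles the per-split results.
# The concatenated order matches the original dataset because np.array_split
# keeps indices contiguous and in order, and pool.starmap returns results in
# submission order. A hypothetical call site:
#
#     lengths = burrito.find_optimal_branch_lengths(burrito.val_dataset)
#     # lengths is a 1-D tensor with one optimized branch length per PCP,
#     # ordered as in the dataset.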

def to_crepe(self):
training_hyperparameters = {
key: self.__dict__[key]
@@ -341,6 +415,14 @@ def to_crepe(self):
return framework.Crepe(encoder, self.model, training_hyperparameters)


def worker_optimize_branch_length(model, dataset, optimization_kwargs):
"""
The worker used for parallel branch length optimization.
"""
burrito = DNSMBurrito(None, dataset, copy.deepcopy(model))
return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
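# Explanatory note (not part of this PR): multiprocessing pickles the starmap
# arguments, so each worker process receives its own copy of the model and of
# its dataset split. The extra copy.deepcopy plus the throwaway DNSMBurrito
# (built with no train dataset) give the worker access to
# serial_find_optimal_branch_lengths while keeping all mutable state isolated
# from the parent process.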


class DNSMHyperBurrito(HyperBurrito):
# Note that we have to write the args out explicitly because we use some magic to filter kwargs in the optuna_objective method.
def burrito_of_model(
@@ -351,7 +433,6 @@ def burrito_of_model(
learning_rate=0.1,
min_learning_rate=1e-4,
l2_regularization_coeff=1e-6,
verbose=False,
):
model.to(device)
burrito = DNSMBurrito(
@@ -362,6 +443,5 @@
learning_rate=learning_rate,
min_learning_rate=min_learning_rate,
l2_regularization_coeff=l2_regularization_coeff,
verbose=verbose,
)
return burrito