New Dataset superclass (#89)
* Addressing #55, I consolidated some branch-length-related code into a new superclass, `BranchLengthDataset`.
* Addressed #22 by moving some methods and functions around. All pre-existing functions and methods should work as they did before. Now a DXSM crepe can be called directly on sequences to get selection factors; previously this required calling the method `selection_factors_of_aa_str` on a DXSM model (see the sketch below).
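
As a rough usage sketch (the `dxsm_crepe` object and the example amino-acid sequences below are hypothetical placeholders, not part of this commit):

```python
aa_seqs = ["QVQLVQSGAEVKKPGA", "EVQLVESGGGLVQPGG"]

# Before: call the method on the underlying model, one sequence at a time.
old_result = [
    dxsm_crepe.model.selection_factors_of_aa_str(seq) for seq in aa_seqs
]

# After: call the crepe directly; Crepe.__call__ forwards to
# model.evaluate_sequences, which maps selection_factors_of_aa_str
# over the input sequences.
new_result = dxsm_crepe(aa_seqs)
```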
willdumm authored Nov 26, 2024
1 parent 24a19bd commit 469420e
Showing 5 changed files with 61 additions and 50 deletions.
12 changes: 12 additions & 0 deletions netam/common.py
@@ -368,3 +368,15 @@ def linear_bump_lr(epoch, warmup_epochs, total_epochs, max_lr, min_lr):
epoch - warmup_epochs
)
return lr


def encode_sequences(sequences, encoder):
encoded_parents, wt_base_modifiers = zip(
*[encoder.encode_sequence(sequence) for sequence in sequences]
)
masks = [nt_mask_tensor_of(sequence, encoder.site_count) for sequence in sequences]
return (
torch.stack(encoded_parents),
torch.stack(masks),
torch.stack(wt_base_modifiers),
)
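
A hedged usage sketch of the new helper (the `encoder` here is a placeholder for any object exposing `encode_sequence` and `site_count`, such as the `KmerSequenceEncoder` that `SHMoofDataset` constructs; only `encode_sequences` itself is added in this commit):

```python
from netam.common import encode_sequences

# `encoder` stands in for an encoder built elsewhere, e.g. the
# KmerSequenceEncoder used by SHMoofDataset in framework.py.
nt_seqs = ["ACGTACGTAC", "ACGGACGTAC"]
encoded_parents, masks, wt_base_modifiers = encode_sequences(nt_seqs, encoder)
# Each return value is a stacked tensor with one row per input sequence.
```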
1 change: 1 addition & 0 deletions netam/dasm.py
@@ -11,6 +11,7 @@
import netam.framework as framework
import netam.molevol as molevol
import netam.sequences as sequences
import copy


class DASMDataset(DXSMDataset):
16 changes: 1 addition & 15 deletions netam/dxsm.py
@@ -4,7 +4,6 @@
from functools import partial

import torch
from torch.utils.data import Dataset

# Amazingly, using one thread makes things 50x faster for branch length
# optimization on our server.
@@ -32,7 +31,7 @@
)


class DXSMDataset(Dataset, ABC):
class DXSMDataset(framework.BranchLengthDataset, ABC):
prefix = "dxsm"

def __init__(
@@ -222,19 +221,6 @@ def branch_lengths(self, new_branch_lengths):
self._branch_lengths = new_branch_lengths
self.update_neutral_probs()

def __len__(self):
return len(self.aa_parents_idxss)

def export_branch_lengths(self, out_csv_path):
pd.DataFrame({"branch_length": self.branch_lengths}).to_csv(
out_csv_path, index=False
)

def load_branch_lengths(self, in_csv_path):
self.branch_lengths = torch.Tensor(
pd.read_csv(in_csv_path)["branch_length"].values
)

@abstractmethod
def update_neutral_probs(self):
pass
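
For reference, the branch-length helpers removed here are unchanged in behavior for callers, since they are now inherited from `framework.BranchLengthDataset` (the `dataset` below stands in for any concrete `DXSMDataset` subclass instance):

```python
# Inherited from BranchLengthDataset rather than defined on DXSMDataset:
dataset.export_branch_lengths("branch_lengths.csv")
dataset.load_branch_lengths("branch_lengths.csv")
len(dataset)  # __len__ also comes from the superclass now
```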
53 changes: 22 additions & 31 deletions netam/framework.py
@@ -25,6 +25,7 @@
BASES_AND_N_TO_INDEX,
BIG,
VRC01_NT_SEQ,
encode_sequences,
)
from netam import models
import netam.molevol as molevol
@@ -132,7 +133,23 @@ def parameters(self):
return {}


class SHMoofDataset(Dataset):
class BranchLengthDataset(Dataset):
def __len__(self):
return len(self.branch_lengths)

def export_branch_lengths(self, out_csv_path):
pd.DataFrame({"branch_length": self.branch_lengths}).to_csv(
out_csv_path, index=False
)

def load_branch_lengths(self, in_csv_path):
self.branch_lengths = pd.read_csv(in_csv_path)["branch_length"].values

def __repr__(self):
return f"{self.__class__.__name__}(Size: {len(self)}) on {self.branch_lengths.device}"


class SHMoofDataset(BranchLengthDataset):
def __init__(self, dataframe, kmer_length, site_count):
super().__init__()
self.encoder = KmerSequenceEncoder(kmer_length, site_count)
@@ -146,9 +163,6 @@ def __init__(self, dataframe, kmer_length, site_count):
) = self.encode_pcps(dataframe)
assert self.encoded_parents.shape[0] == self.branch_lengths.shape[0]

def __len__(self):
return len(self.encoded_parents)

def __getitem__(self, idx):
return (
self.encoded_parents[idx],
@@ -159,9 +173,6 @@ def __getitem__(self, idx):
self.branch_lengths[idx],
)

def __repr__(self):
return f"{self.__class__.__name__}(Size: {len(self)}) on {self.encoded_parents.device}"

def to(self, device):
self.encoded_parents = self.encoded_parents.to(device)
self.masks = self.masks.to(device)
@@ -224,9 +235,6 @@ def export_branch_lengths(self, out_csv_path):
}
).to_csv(out_csv_path, index=False)

def load_branch_lengths(self, in_csv_path):
self.branch_lengths = pd.read_csv(in_csv_path)["branch_length"].values


class Crepe:
"""A lightweight wrapper around a model that can be used for prediction but not
@@ -243,6 +251,9 @@ def __init__(self, encoder, model, training_hyperparameters={}):
self.model = model
self.training_hyperparameters = training_hyperparameters

def __call__(self, sequences):
return self.model.evaluate_sequences(sequences, encoder=self.encoder)

@property
def device(self):
return next(self.model.parameters()).device
@@ -251,27 +262,7 @@ def to(self, device):
self.model.to(device)

def encode_sequences(self, sequences):
encoded_parents, wt_base_modifiers = zip(
*[self.encoder.encode_sequence(sequence) for sequence in sequences]
)
masks = [
nt_mask_tensor_of(sequence, self.encoder.site_count)
for sequence in sequences
]
return (
torch.stack(encoded_parents),
torch.stack(masks),
torch.stack(wt_base_modifiers),
)

def __call__(self, sequences):
encoded_parents, masks, wt_base_modifiers = self.encode_sequences(sequences)
encoded_parents = encoded_parents.to(self.device)
masks = masks.to(self.device)
wt_base_modifiers = wt_base_modifiers.to(self.device)
with torch.no_grad():
outputs = self.model(encoded_parents, masks, wt_base_modifiers)
return tuple(t.detach().cpu() for t in outputs)
return encode_sequences(sequences, self.encoder)

def save(self, prefix):
torch.save(self.model.state_dict(), f"{prefix}.pth")
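
A sketch of the refactored `Crepe` surface (`encoder`, `model`, and `sequences` are placeholders constructed elsewhere; only the delegation described in the comments comes from this diff):

```python
crepe = Crepe(encoder, model)

# __call__ now delegates to the model:
outputs = crepe(sequences)  # model.evaluate_sequences(sequences, encoder=crepe.encoder)

# encode_sequences now delegates to the shared helper in netam.common:
encoded_parents, masks, wt_base_modifiers = crepe.encode_sequences(sequences)
```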
29 changes: 25 additions & 4 deletions netam/models.py
@@ -16,6 +16,7 @@
PositionalEncoding,
generate_kmers,
aa_mask_tensor_of,
encode_sequences,
)

warnings.filterwarnings(
@@ -49,6 +50,10 @@ def reinitialize_weights(self):
else:
raise ValueError(f"Unrecognized layer type: {type(layer)}")

@property
def device(self):
return next(self.parameters()).device

def freeze(self):
"""Freeze all parameters in the model, disabling gradient computations."""
for param in self.parameters():
@@ -59,6 +64,17 @@ def unfreeze(self):
for param in self.parameters():
param.requires_grad = True

def evaluate_sequences(self, sequences, encoder=None):
if encoder is None:
raise ValueError("An encoder must be provided.")
encoded_parents, masks, wt_base_modifiers = encode_sequences(sequences, encoder)
encoded_parents = encoded_parents.to(self.device)
masks = masks.to(self.device)
wt_base_modifiers = wt_base_modifiers.to(self.device)
with torch.no_grad():
outputs = self(encoded_parents, masks, wt_base_modifiers)
return tuple(t.detach().cpu() for t in outputs)


class KmerModel(ModelBase):
def __init__(self, kmer_length):
@@ -536,6 +552,13 @@ class AbstractBinarySelectionModel(ABC, nn.Module):
def __init__(self):
super().__init__()

@property
def device(self):
return next(self.parameters()).device

def evaluate_sequences(self, sequences: list[str], **kwargs) -> Tensor:
return tuple(self.selection_factors_of_aa_str(seq) for seq in sequences)

def selection_factors_of_aa_str(self, aa_str: str) -> Tensor:
"""Do the forward method then exponentiation without gradients from an amino
acid string.
@@ -548,12 +571,10 @@ def selection_factors_of_aa_str(self, aa_str: str) -> Tensor:
the level of selection for each amino acid at each site.
"""

model_device = next(self.parameters()).device

aa_idxs = aa_idx_tensor_of_str_ambig(aa_str)
aa_idxs = aa_idxs.to(model_device)
aa_idxs = aa_idxs.to(self.device)
mask = aa_mask_tensor_of(aa_str)
mask = mask.to(model_device)
mask = mask.to(self.device)

with torch.no_grad():
model_out = self(aa_idxs.unsqueeze(0), mask.unsqueeze(0)).squeeze(0)
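
Taken together with the framework changes, the two `evaluate_sequences` implementations give crepes a uniform call syntax. A hedged sketch, assuming `shm_crepe` wraps a nucleotide model (a `ModelBase` subclass) and `dxsm_crepe` wraps an amino-acid selection model; both crepes are hypothetical objects loaded elsewhere:

```python
# ModelBase path: encode with the crepe's encoder, run the forward pass
# under torch.no_grad(), and return detached CPU tensors.
nt_outputs = shm_crepe(["ACGTACGTAC"])

# AbstractBinarySelectionModel path: no encoder needed; maps
# selection_factors_of_aa_str over the amino-acid strings.
selection_factors = dxsm_crepe(["QVQLVQSGAEVKKPGA"])
```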
