diff --git a/netam/common.py b/netam/common.py index 992eb2e0..0250aa83 100644 --- a/netam/common.py +++ b/netam/common.py @@ -368,3 +368,15 @@ def linear_bump_lr(epoch, warmup_epochs, total_epochs, max_lr, min_lr): epoch - warmup_epochs ) return lr + + +def encode_sequences(sequences, encoder): + encoded_parents, wt_base_modifiers = zip( + *[encoder.encode_sequence(sequence) for sequence in sequences] + ) + masks = [nt_mask_tensor_of(sequence, encoder.site_count) for sequence in sequences] + return ( + torch.stack(encoded_parents), + torch.stack(masks), + torch.stack(wt_base_modifiers), + ) diff --git a/netam/dasm.py b/netam/dasm.py index 8a2b16e8..e3712fe2 100644 --- a/netam/dasm.py +++ b/netam/dasm.py @@ -11,6 +11,7 @@ import netam.framework as framework import netam.molevol as molevol import netam.sequences as sequences +import copy class DASMDataset(DXSMDataset): diff --git a/netam/dxsm.py b/netam/dxsm.py index c86733bf..9c7157e3 100644 --- a/netam/dxsm.py +++ b/netam/dxsm.py @@ -4,7 +4,6 @@ from functools import partial import torch -from torch.utils.data import Dataset # Amazingly, using one thread makes things 50x faster for branch length # optimization on our server. @@ -32,7 +31,7 @@ ) -class DXSMDataset(Dataset, ABC): +class DXSMDataset(framework.BranchLengthDataset, ABC): prefix = "dxsm" def __init__( @@ -222,19 +221,6 @@ def branch_lengths(self, new_branch_lengths): self._branch_lengths = new_branch_lengths self.update_neutral_probs() - def __len__(self): - return len(self.aa_parents_idxss) - - def export_branch_lengths(self, out_csv_path): - pd.DataFrame({"branch_length": self.branch_lengths}).to_csv( - out_csv_path, index=False - ) - - def load_branch_lengths(self, in_csv_path): - self.branch_lengths = torch.Tensor( - pd.read_csv(in_csv_path)["branch_length"].values - ) - @abstractmethod def update_neutral_probs(self): pass diff --git a/netam/framework.py b/netam/framework.py index dd04231b..be7cf500 100644 --- a/netam/framework.py +++ b/netam/framework.py @@ -25,6 +25,7 @@ BASES_AND_N_TO_INDEX, BIG, VRC01_NT_SEQ, + encode_sequences, ) from netam import models import netam.molevol as molevol @@ -132,7 +133,23 @@ def parameters(self): return {} -class SHMoofDataset(Dataset): +class BranchLengthDataset(Dataset): + def __len__(self): + return len(self.branch_lengths) + + def export_branch_lengths(self, out_csv_path): + pd.DataFrame({"branch_length": self.branch_lengths}).to_csv( + out_csv_path, index=False + ) + + def load_branch_lengths(self, in_csv_path): + self.branch_lengths = pd.read_csv(in_csv_path)["branch_length"].values + + def __repr__(self): + return f"{self.__class__.__name__}(Size: {len(self)}) on {self.branch_lengths.device}" + + +class SHMoofDataset(BranchLengthDataset): def __init__(self, dataframe, kmer_length, site_count): super().__init__() self.encoder = KmerSequenceEncoder(kmer_length, site_count) @@ -146,9 +163,6 @@ def __init__(self, dataframe, kmer_length, site_count): ) = self.encode_pcps(dataframe) assert self.encoded_parents.shape[0] == self.branch_lengths.shape[0] - def __len__(self): - return len(self.encoded_parents) - def __getitem__(self, idx): return ( self.encoded_parents[idx], @@ -159,9 +173,6 @@ def __getitem__(self, idx): self.branch_lengths[idx], ) - def __repr__(self): - return f"{self.__class__.__name__}(Size: {len(self)}) on {self.encoded_parents.device}" - def to(self, device): self.encoded_parents = self.encoded_parents.to(device) self.masks = self.masks.to(device) @@ -224,9 +235,6 @@ def export_branch_lengths(self, out_csv_path): } ).to_csv(out_csv_path, index=False) - def load_branch_lengths(self, in_csv_path): - self.branch_lengths = pd.read_csv(in_csv_path)["branch_length"].values - class Crepe: """A lightweight wrapper around a model that can be used for prediction but not @@ -243,6 +251,9 @@ def __init__(self, encoder, model, training_hyperparameters={}): self.model = model self.training_hyperparameters = training_hyperparameters + def __call__(self, sequences): + return self.model.evaluate_sequences(sequences, encoder=self.encoder) + @property def device(self): return next(self.model.parameters()).device @@ -251,27 +262,7 @@ def to(self, device): self.model.to(device) def encode_sequences(self, sequences): - encoded_parents, wt_base_modifiers = zip( - *[self.encoder.encode_sequence(sequence) for sequence in sequences] - ) - masks = [ - nt_mask_tensor_of(sequence, self.encoder.site_count) - for sequence in sequences - ] - return ( - torch.stack(encoded_parents), - torch.stack(masks), - torch.stack(wt_base_modifiers), - ) - - def __call__(self, sequences): - encoded_parents, masks, wt_base_modifiers = self.encode_sequences(sequences) - encoded_parents = encoded_parents.to(self.device) - masks = masks.to(self.device) - wt_base_modifiers = wt_base_modifiers.to(self.device) - with torch.no_grad(): - outputs = self.model(encoded_parents, masks, wt_base_modifiers) - return tuple(t.detach().cpu() for t in outputs) + return encode_sequences(sequences, self.encoder) def save(self, prefix): torch.save(self.model.state_dict(), f"{prefix}.pth") diff --git a/netam/models.py b/netam/models.py index 0f5b2854..e5b13673 100644 --- a/netam/models.py +++ b/netam/models.py @@ -16,6 +16,7 @@ PositionalEncoding, generate_kmers, aa_mask_tensor_of, + encode_sequences, ) warnings.filterwarnings( @@ -49,6 +50,10 @@ def reinitialize_weights(self): else: raise ValueError(f"Unrecognized layer type: {type(layer)}") + @property + def device(self): + return next(self.parameters()).device + def freeze(self): """Freeze all parameters in the model, disabling gradient computations.""" for param in self.parameters(): @@ -59,6 +64,17 @@ def unfreeze(self): for param in self.parameters(): param.requires_grad = True + def evaluate_sequences(self, sequences, encoder=None): + if encoder is None: + raise ValueError("An encoder must be provided.") + encoded_parents, masks, wt_base_modifiers = encode_sequences(sequences, encoder) + encoded_parents = encoded_parents.to(self.device) + masks = masks.to(self.device) + wt_base_modifiers = wt_base_modifiers.to(self.device) + with torch.no_grad(): + outputs = self(encoded_parents, masks, wt_base_modifiers) + return tuple(t.detach().cpu() for t in outputs) + class KmerModel(ModelBase): def __init__(self, kmer_length): @@ -536,6 +552,13 @@ class AbstractBinarySelectionModel(ABC, nn.Module): def __init__(self): super().__init__() + @property + def device(self): + return next(self.parameters()).device + + def evaluate_sequences(self, sequences: list[str], **kwargs) -> Tensor: + return tuple(self.selection_factors_of_aa_str(seq) for seq in sequences) + def selection_factors_of_aa_str(self, aa_str: str) -> Tensor: """Do the forward method then exponentiation without gradients from an amino acid string. @@ -548,12 +571,10 @@ def selection_factors_of_aa_str(self, aa_str: str) -> Tensor: the level of selection for each amino acid at each site. """ - model_device = next(self.parameters()).device - aa_idxs = aa_idx_tensor_of_str_ambig(aa_str) - aa_idxs = aa_idxs.to(model_device) + aa_idxs = aa_idxs.to(self.device) mask = aa_mask_tensor_of(aa_str) - mask = mask.to(model_device) + mask = mask.to(self.device) with torch.no_grad(): model_out = self(aa_idxs.unsqueeze(0), mask.unsqueeze(0)).squeeze(0)