address issue #22
willdumm committed Nov 22, 2024
1 parent 8ca6e52 commit ebb30e0
Showing 3 changed files with 36 additions and 21 deletions.
14 changes: 14 additions & 0 deletions netam/common.py
@@ -368,3 +368,17 @@ def linear_bump_lr(epoch, warmup_epochs, total_epochs, max_lr, min_lr):
             epoch - warmup_epochs
         )
     return lr
+
+def encode_sequences(sequences, encoder):
+    encoded_parents, wt_base_modifiers = zip(
+        *[encoder.encode_sequence(sequence) for sequence in sequences]
+    )
+    masks = [
+        nt_mask_tensor_of(sequence, encoder.site_count)
+        for sequence in sequences
+    ]
+    return (
+        torch.stack(encoded_parents),
+        torch.stack(masks),
+        torch.stack(wt_base_modifiers),
+    )
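
For orientation, here is a minimal usage sketch of the new shared helper. It assumes `netam` is importable and that `encoder` is an existing encoder object (for example a Crepe's `encoder` attribute) exposing the `encode_sequence` method and `site_count` attribute the helper requires; the shapes noted in the comments are inferred from the `torch.stack` calls above rather than stated in the commit.

    from netam.common import encode_sequences

    # `encoder` is assumed to already exist, e.g. `crepe.encoder` from a loaded model.
    sequences = ["ACGTACGT", "ACGTTCGT"]
    encoded_parents, masks, wt_base_modifiers = encode_sequences(sequences, encoder)

    # Each returned tensor stacks one row per input sequence; masks mark each
    # sequence's sites within the encoder's fixed site_count.
    assert encoded_parents.shape[0] == len(sequences)
    assert masks.shape[0] == len(sequences)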
26 changes: 5 additions & 21 deletions netam/framework.py
@@ -25,6 +25,7 @@
     BASES_AND_N_TO_INDEX,
     BIG,
     VRC01_NT_SEQ,
+    encode_sequences,
 )
 from netam import models
 import netam.molevol as molevol
@@ -250,6 +251,9 @@ def __init__(self, encoder, model, training_hyperparameters={}):
         self.model = model
         self.training_hyperparameters = training_hyperparameters
 
+    def __call__(self, sequences):
+        return self.model.selection_factors_of_sequences(sequences, encoder=self.encoder)
+
     @property
     def device(self):
         return next(self.model.parameters()).device
@@ -258,27 +262,7 @@ def to(self, device):
         self.model.to(device)
 
     def encode_sequences(self, sequences):
-        encoded_parents, wt_base_modifiers = zip(
-            *[self.encoder.encode_sequence(sequence) for sequence in sequences]
-        )
-        masks = [
-            nt_mask_tensor_of(sequence, self.encoder.site_count)
-            for sequence in sequences
-        ]
-        return (
-            torch.stack(encoded_parents),
-            torch.stack(masks),
-            torch.stack(wt_base_modifiers),
-        )
-
-    def __call__(self, sequences):
-        encoded_parents, masks, wt_base_modifiers = self.encode_sequences(sequences)
-        encoded_parents = encoded_parents.to(self.device)
-        masks = masks.to(self.device)
-        wt_base_modifiers = wt_base_modifiers.to(self.device)
-        with torch.no_grad():
-            outputs = self.model(encoded_parents, masks, wt_base_modifiers)
-        return tuple(t.detach().cpu() for t in outputs)
+        return encode_sequences(sequences, self.encoder)
 
     def save(self, prefix):
         torch.save(self.model.state_dict(), f"{prefix}.pth")
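The caller-facing behavior of `Crepe` is meant to stay the same after this refactor: `__call__` now simply forwards to the model's `selection_factors_of_sequences`, passing along the crepe's encoder. A sketch of that call path, assuming a saved model prefix and the `load_crepe` loader that pairs with `Crepe.save` (both the path and the sequences here are placeholders):

    from netam.framework import load_crepe

    crepe = load_crepe("path/to/model_prefix")  # assumed to restore encoder and model

    # __call__ now dispatches to
    # crepe.model.selection_factors_of_sequences(sequences, encoder=crepe.encoder),
    # so each model class controls its own encoding, batching, and device handling.
    outputs = crepe(["ACGTACGT", "ACGTTCGT"])
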
17 changes: 17 additions & 0 deletions netam/models.py
@@ -16,6 +16,7 @@
     PositionalEncoding,
     generate_kmers,
     aa_mask_tensor_of,
+    encode_sequences,
 )
 
 warnings.filterwarnings(
@@ -59,6 +60,19 @@ def unfreeze(self):
         for param in self.parameters():
             param.requires_grad = True
 
+    def selection_factors_of_sequences(self, sequences, encoder=None):
+        if encoder is None:
+            raise ValueError("An encoder must be provided.")
+        device = next(self.parameters()).device
+        encoded_parents, masks, wt_base_modifiers = encode_sequences(sequences, encoder)
+        encoded_parents = encoded_parents.to(device)
+        masks = masks.to(device)
+        wt_base_modifiers = wt_base_modifiers.to(device)
+        with torch.no_grad():
+            outputs = self(encoded_parents, masks, wt_base_modifiers)
+        return tuple(t.detach().cpu() for t in outputs)
+
+
 
 class KmerModel(ModelBase):
     def __init__(self, kmer_length):
@@ -536,6 +550,9 @@ class AbstractBinarySelectionModel(ABC, nn.Module):
     def __init__(self):
         super().__init__()
 
+    def selection_factors_of_sequences(self, sequences: list[str], **kwargs) -> Tensor:
+        return tuple(self.selection_factors_of_aa_str(seq) for seq in sequences)
+
     def selection_factors_of_aa_str(self, aa_str: str) -> Tensor:
         """Do the forward method then exponentiation without gradients from an amino
         acid string.
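The two new `selection_factors_of_sequences` implementations differ in what they return: the `ModelBase` version encodes and batches nucleotide sequences through the shared helper and returns the model's detached outputs on CPU, while the `AbstractBinarySelectionModel` version ignores the `encoder` keyword and maps `selection_factors_of_aa_str` over the inputs, yielding one tensor per amino-acid sequence. A contrasting sketch, with placeholder model objects and sequences:

    # Nucleotide model (ModelBase subclass): an encoder is required, and the result
    # is the model's batched output tensors, detached and moved back to CPU.
    outputs = nt_model.selection_factors_of_sequences(nt_seqs, encoder=crepe.encoder)

    # Amino-acid selection model (AbstractBinarySelectionModel subclass): sequences
    # are evaluated one at a time, giving a tuple with one tensor per input sequence.
    per_seq_factors = aa_model.selection_factors_of_sequences(aa_seqs)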
