diff --git a/netam/common.py b/netam/common.py
index 31279e66..0250aa83 100644
--- a/netam/common.py
+++ b/netam/common.py
@@ -369,14 +369,12 @@ def linear_bump_lr(epoch, warmup_epochs, total_epochs, max_lr, min_lr):
     )
     return lr
 
+
 def encode_sequences(sequences, encoder):
     encoded_parents, wt_base_modifiers = zip(
         *[encoder.encode_sequence(sequence) for sequence in sequences]
     )
-    masks = [
-        nt_mask_tensor_of(sequence, encoder.site_count)
-        for sequence in sequences
-    ]
+    masks = [nt_mask_tensor_of(sequence, encoder.site_count) for sequence in sequences]
     return (
         torch.stack(encoded_parents),
         torch.stack(masks),
diff --git a/netam/framework.py b/netam/framework.py
index 73ca5831..2a661760 100644
--- a/netam/framework.py
+++ b/netam/framework.py
@@ -252,7 +252,9 @@ def __init__(self, encoder, model, training_hyperparameters={}):
         self.training_hyperparameters = training_hyperparameters
 
     def __call__(self, sequences):
-        return self.model.selection_factors_of_sequences(sequences, encoder=self.encoder)
+        return self.model.selection_factors_of_sequences(
+            sequences, encoder=self.encoder
+        )
 
     @property
     def device(self):
diff --git a/netam/models.py b/netam/models.py
index 72cee9f7..40e8e576 100644
--- a/netam/models.py
+++ b/netam/models.py
@@ -73,7 +73,6 @@ def selection_factors_of_sequences(self, sequences, encoder=None):
         return tuple(t.detach().cpu() for t in outputs)
 
 
-
 class KmerModel(ModelBase):
     def __init__(self, kmer_length):
         super().__init__()