Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
willdumm committed Nov 22, 2024
1 parent ebb30e0 commit b829c01
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 6 deletions.
6 changes: 2 additions & 4 deletions netam/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,14 +369,12 @@ def linear_bump_lr(epoch, warmup_epochs, total_epochs, max_lr, min_lr):
)
return lr


def encode_sequences(sequences, encoder):
encoded_parents, wt_base_modifiers = zip(
*[encoder.encode_sequence(sequence) for sequence in sequences]
)
masks = [
nt_mask_tensor_of(sequence, encoder.site_count)
for sequence in sequences
]
masks = [nt_mask_tensor_of(sequence, encoder.site_count) for sequence in sequences]
return (
torch.stack(encoded_parents),
torch.stack(masks),
Expand Down
4 changes: 3 additions & 1 deletion netam/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,9 @@ def __init__(self, encoder, model, training_hyperparameters={}):
self.training_hyperparameters = training_hyperparameters

def __call__(self, sequences):
return self.model.selection_factors_of_sequences(sequences, encoder=self.encoder)
return self.model.selection_factors_of_sequences(
sequences, encoder=self.encoder
)

@property
def device(self):
Expand Down
1 change: 0 additions & 1 deletion netam/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def selection_factors_of_sequences(self, sequences, encoder=None):
return tuple(t.detach().cpu() for t in outputs)



class KmerModel(ModelBase):
def __init__(self, kmer_length):
super().__init__()
Expand Down

0 comments on commit b829c01

Please sign in to comment.