Skip to content

Commit

Permalink
simplifying device handling of crepes
Browse files Browse the repository at this point in the history
  • Loading branch information
matsen committed Jun 6, 2024
1 parent 6455ab6 commit 647dde3
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions netam/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,12 @@ def __init__(self, encoder, model, training_hyperparameters={}):
self.encoder = encoder
self.model = model
self.training_hyperparameters = training_hyperparameters
self.device = None

@property
def device(self):
return next(self.model.parameters()).device

def to(self, device):
self.device = device
self.model.to(device)

def encode_sequences(self, sequences):
Expand All @@ -262,11 +264,9 @@ def encode_sequences(self, sequences):

def __call__(self, sequences):
encoded_parents, masks, wt_base_modifiers = self.encode_sequences(sequences)
if self.device is not None:
print(f"using device {self.device}")
encoded_parents = encoded_parents.to(self.device)
masks = masks.to(self.device)
wt_base_modifiers = wt_base_modifiers.to(self.device)
encoded_parents = encoded_parents.to(self.device)
masks = masks.to(self.device)
wt_base_modifiers = wt_base_modifiers.to(self.device)
with torch.no_grad():
outputs = self.model(encoded_parents, masks, wt_base_modifiers)
return tuple(t.detach().cpu() for t in outputs)
Expand Down Expand Up @@ -361,10 +361,6 @@ def load_pcp_df(pcp_df_path_gz, sample_count=None, chosen_v_families=None):

def add_shm_model_outputs_to_pcp_df(pcp_df, crepe_prefix, device=None):
crepe = load_crepe(crepe_prefix, device=device)
print(f"supposed to be using {device}")
# get device of crepe.model
actual_device = next(crepe.model.parameters()).device
print(f"actual device is {actual_device}")
rates, csps = trimmed_shm_model_outputs_of_crepe(crepe, pcp_df["parent"])
pcp_df["rates"] = rates
pcp_df["subs_probs"] = csps
Expand Down

0 comments on commit 647dde3

Please sign in to comment.