diff --git a/netam/framework.py b/netam/framework.py index 8485d31a..c8a5fa43 100644 --- a/netam/framework.py +++ b/netam/framework.py @@ -240,10 +240,12 @@ def __init__(self, encoder, model, training_hyperparameters={}): self.encoder = encoder self.model = model self.training_hyperparameters = training_hyperparameters - self.device = None + + @property + def device(self): + return next(self.model.parameters()).device def to(self, device): - self.device = device self.model.to(device) def encode_sequences(self, sequences): @@ -262,11 +264,9 @@ def encode_sequences(self, sequences): def __call__(self, sequences): encoded_parents, masks, wt_base_modifiers = self.encode_sequences(sequences) - if self.device is not None: - print(f"using device {self.device}") - encoded_parents = encoded_parents.to(self.device) - masks = masks.to(self.device) - wt_base_modifiers = wt_base_modifiers.to(self.device) + encoded_parents = encoded_parents.to(self.device) + masks = masks.to(self.device) + wt_base_modifiers = wt_base_modifiers.to(self.device) with torch.no_grad(): outputs = self.model(encoded_parents, masks, wt_base_modifiers) return tuple(t.detach().cpu() for t in outputs) @@ -361,10 +361,6 @@ def load_pcp_df(pcp_df_path_gz, sample_count=None, chosen_v_families=None): def add_shm_model_outputs_to_pcp_df(pcp_df, crepe_prefix, device=None): crepe = load_crepe(crepe_prefix, device=device) - print(f"supposed to be using {device}") - # get device of crepe.model - actual_device = next(crepe.model.parameters()).device - print(f"actual device is {actual_device}") rates, csps = trimmed_shm_model_outputs_of_crepe(crepe, pcp_df["parent"]) pcp_df["rates"] = rates pcp_df["subs_probs"] = csps