diff --git a/netam/framework.py b/netam/framework.py index 63f89be1..97b6726f 100644 --- a/netam/framework.py +++ b/netam/framework.py @@ -360,6 +360,10 @@ def load_pcp_df(pcp_df_path_gz, sample_count=None, chosen_v_families=None): def add_shm_model_outputs_to_pcp_df(pcp_df, crepe_prefix, device=None): crepe = load_crepe(crepe_prefix, device=device) + print(f"supposed to be using {device}") + # get device of crepe.model + actual_device = next(crepe.model.parameters()).device + print(f"actual device is {actual_device}") rates, csps = trimmed_shm_model_outputs_of_crepe(crepe, pcp_df["parent"]) pcp_df["rates"] = rates pcp_df["subs_probs"] = csps