diff --git a/netam/common.py b/netam/common.py index c5dd2d96..720bc314 100644 --- a/netam/common.py +++ b/netam/common.py @@ -160,7 +160,9 @@ def find_least_used_cuda_gpu(default_value): gpu_to_use = default_value else: gpu_to_use = utilization.index(min(utilization)) - print(f"Picking GPU {gpu_to_use}; default {default_value}; utilization: {utilization}") + print( + f"Picking GPU {gpu_to_use}; default {default_value}; utilization: {utilization}" + ) return gpu_to_use @@ -169,6 +171,7 @@ def pick_device(jobid): Pick a device for PyTorch to use. If CUDA is available, use the least used GPU, and if all are idle use a GPU based on the jobid. """ + # check that CUDA is usable def check_CUDA(): try: diff --git a/netam/dnsm.py b/netam/dnsm.py index 6cff9e7a..6b9c5a4b 100644 --- a/netam/dnsm.py +++ b/netam/dnsm.py @@ -264,35 +264,32 @@ def to(self, device): self.all_subs_probs = self.all_subs_probs.to(device) -def train_test_datasets_of_pcp_df(pcp_df, train_frac=0.8, branch_length_multiplier=5.0): - nt_parents = pcp_df["parent"].reset_index(drop=True) - nt_children = pcp_df["child"].reset_index(drop=True) - rates = pcp_df["rates"].reset_index(drop=True) - subs_probs = pcp_df["subs_probs"].reset_index(drop=True) - - train_len = int(train_frac * len(nt_parents)) - train_parents, val_parents = nt_parents[:train_len], nt_parents[train_len:] - train_children, val_children = nt_children[:train_len], nt_children[train_len:] - train_rates, val_rates = rates[:train_len], rates[train_len:] - train_subs_probs, val_subs_probs = ( - subs_probs[:train_len], - subs_probs[train_len:], +def dataset_of_pcp_df(pcp_df, branch_length_multiplier=5.0): + return DNSMDataset.from_data( + pcp_df["parents"], + pcp_df["children"], + pcp_df["rates"], + pcp_df["subs_probs"], + branch_length_multiplier=branch_length_multiplier, ) - val_dataset = DNSMDataset.from_data( - val_parents, - val_children, - val_rates, - val_subs_probs, + + +def train_val_datasets_of_pcp_df(pcp_df, branch_length_multiplier=5.0): + """ + Perform a train-val split based on a "in_train" column. + """ + train_df = pcp_df[pcp_df["in_train"]].reset_index(drop=True) + val_df = pcp_df[~pcp_df["in_train"]].reset_index(drop=True) + + val_dataset = dataset_of_pcp_df( + val_df, branch_length_multiplier=branch_length_multiplier, ) - if train_frac == 0.0: + if len(train_df) == 0: return None, val_dataset # else: - train_dataset = DNSMDataset.from_data( - train_parents, - train_children, - train_rates, - train_subs_probs, + train_dataset = dataset_of_pcp_df( + train_df, branch_length_multiplier=branch_length_multiplier, ) return train_dataset, val_dataset