Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
matsen committed Jun 6, 2024
1 parent 3f3710c commit db0607f
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 25 deletions.
5 changes: 4 additions & 1 deletion netam/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ def find_least_used_cuda_gpu(default_value):
gpu_to_use = default_value
else:
gpu_to_use = utilization.index(min(utilization))
print(f"Picking GPU {gpu_to_use}; default {default_value}; utilization: {utilization}")
print(
f"Picking GPU {gpu_to_use}; default {default_value}; utilization: {utilization}"
)
return gpu_to_use


Expand All @@ -169,6 +171,7 @@ def pick_device(jobid):
Pick a device for PyTorch to use. If CUDA is available, use the least used
GPU, and if all are idle use a GPU based on the jobid.
"""

# check that CUDA is usable
def check_CUDA():
try:
Expand Down
45 changes: 21 additions & 24 deletions netam/dnsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,35 +264,32 @@ def to(self, device):
self.all_subs_probs = self.all_subs_probs.to(device)


def train_test_datasets_of_pcp_df(pcp_df, train_frac=0.8, branch_length_multiplier=5.0):
nt_parents = pcp_df["parent"].reset_index(drop=True)
nt_children = pcp_df["child"].reset_index(drop=True)
rates = pcp_df["rates"].reset_index(drop=True)
subs_probs = pcp_df["subs_probs"].reset_index(drop=True)

train_len = int(train_frac * len(nt_parents))
train_parents, val_parents = nt_parents[:train_len], nt_parents[train_len:]
train_children, val_children = nt_children[:train_len], nt_children[train_len:]
train_rates, val_rates = rates[:train_len], rates[train_len:]
train_subs_probs, val_subs_probs = (
subs_probs[:train_len],
subs_probs[train_len:],
def dataset_of_pcp_df(pcp_df, branch_length_multiplier=5.0):
return DNSMDataset.from_data(
pcp_df["parents"],
pcp_df["children"],
pcp_df["rates"],
pcp_df["subs_probs"],
branch_length_multiplier=branch_length_multiplier,
)
val_dataset = DNSMDataset.from_data(
val_parents,
val_children,
val_rates,
val_subs_probs,


def train_val_datasets_of_pcp_df(pcp_df, branch_length_multiplier=5.0):
"""
Perform a train-val split based on a "in_train" column.
"""
train_df = pcp_df[pcp_df["in_train"]].reset_index(drop=True)
val_df = pcp_df[~pcp_df["in_train"]].reset_index(drop=True)

val_dataset = dataset_of_pcp_df(
val_df,
branch_length_multiplier=branch_length_multiplier,
)
if train_frac == 0.0:
if len(train_df) == 0:
return None, val_dataset
# else:
train_dataset = DNSMDataset.from_data(
train_parents,
train_children,
train_rates,
train_subs_probs,
train_dataset = dataset_of_pcp_df(
train_df,
branch_length_multiplier=branch_length_multiplier,
)
return train_dataset, val_dataset
Expand Down

0 comments on commit db0607f

Please sign in to comment.