Skip to content

Commit

Permalink
make format
Browse files Browse the repository at this point in the history
  • Loading branch information
matsen committed Sep 25, 2024
1 parent 1a67b48 commit cc0922d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
12 changes: 9 additions & 3 deletions netam/dnsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,21 +280,27 @@ def to(self, device):
self.all_subs_probs = self.all_subs_probs.to(device)


def train_val_datasets_of_pcp_df(pcp_df, branch_length_multiplier=5.0, multihit_model=None):
def train_val_datasets_of_pcp_df(
pcp_df, branch_length_multiplier=5.0, multihit_model=None
):
"""Perform a train-val split based on a "in_train" column.
Stays here so it can be used in tests.
"""
train_df = pcp_df[pcp_df["in_train"]].reset_index(drop=True)
val_df = pcp_df[~pcp_df["in_train"]].reset_index(drop=True)
val_dataset = DNSMDataset.of_pcp_df(
val_df, branch_length_multiplier=branch_length_multiplier, multihit_model=multihit_model,
val_df,
branch_length_multiplier=branch_length_multiplier,
multihit_model=multihit_model,
)
if len(train_df) == 0:
return None, val_dataset
# else:
train_dataset = DNSMDataset.of_pcp_df(
train_df, branch_length_multiplier=branch_length_multiplier, multihit_model=multihit_model,
train_df,
branch_length_multiplier=branch_length_multiplier,
multihit_model=multihit_model,
)
return train_dataset, val_dataset

Expand Down
4 changes: 3 additions & 1 deletion netam/molevol.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,9 @@ def neutral_aa_mut_probs(
return 1.0 - p_staying_same


def mutsel_log_pcp_probability_of(sel_matrix, parent, child, rates, sub_probs, multihit_model=None):
def mutsel_log_pcp_probability_of(
sel_matrix, parent, child, rates, sub_probs, multihit_model=None
):
"""Constructs the log_pcp_probability function specific to given rates and
sub_probs.
Expand Down

0 comments on commit cc0922d

Please sign in to comment.