Commit

format and fix dasm test
willdumm committed Oct 23, 2024
1 parent: cc65c0f · commit: f1684fa
Showing 5 changed files with 27 additions and 9 deletions.
4 changes: 3 additions & 1 deletion netam/dnsm.py
@@ -140,7 +140,9 @@ def of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0, multihit_model=None):
         )
 
     @classmethod
-    def train_val_datasets_of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0, multihit_model=None):
+    def train_val_datasets_of_pcp_df(
+        cls, pcp_df, branch_length_multiplier=5.0, multihit_model=None
+    ):
        """Perform a train-val split based on the 'in_train' column.
 
        This is a class method so it works for subclasses.
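The rewrap above only reflows the signature; the behavior the docstring describes is a split on the 'in_train' flag. A minimal sketch of that split, assuming a pandas DataFrame with a boolean "in_train" column (the dataset-construction step that netam performs afterward is elided, and split_by_in_train is an illustrative name, not the netam API):

import pandas as pd

def split_by_in_train(pcp_df: pd.DataFrame):
    # rows flagged in_train become the training set; the rest are validation
    train_df = pcp_df[pcp_df["in_train"]].reset_index(drop=True)
    val_df = pcp_df[~pcp_df["in_train"]].reset_index(drop=True)
    return train_df, val_df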
1 change: 0 additions & 1 deletion netam/hit_class.py
@@ -41,7 +41,6 @@ def parent_specific_hit_classes(parent_codon_idxs: torch.Tensor) -> torch.Tensor
     ]
 
 
-
 def apply_multihit_correction(
     parent_codon_idxs: torch.Tensor,
     codon_probs: torch.Tensor,
6 changes: 5 additions & 1 deletion netam/molevol.py
@@ -18,6 +18,7 @@
 from netam.sequences import CODON_AA_INDICATOR_MATRIX
 
 import netam.sequences as sequences
+
 # torch.autograd.set_detect_anomaly(True)
 
 
@@ -521,7 +522,10 @@ def optimize_branch_length(
         torch.nn.utils.clip_grad_norm_([log_branch_length], max_norm=5.0)
         optimizer.step()
         if torch.isnan(log_branch_length):
-            print("branch length optimization resulted in NAN, previous log branch length:", prev_log_branch_length)
+            print(
+                "branch length optimization resulted in NAN, previous log branch length:",
+                prev_log_branch_length,
+            )
             if np.isclose(prev_log_branch_length.detach().numpy(), 0):
                 log_branch_length = prev_log_branch_length
                 nan_issue = True
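For context, the print reformat above sits inside optimize_branch_length's NaN-recovery path: the log branch length is optimized by gradient steps on a log likelihood, gradients are clipped, and if a step produces NaN the previous value is restored. A minimal, hypothetical sketch of that pattern (optimize_log_branch_length and its arguments are illustrative, not the netam signature):

import torch

def optimize_log_branch_length(log_prob_fn, init=0.0, steps=100, lr=0.1):
    log_bl = torch.tensor(init, requires_grad=True)
    optimizer = torch.optim.Adam([log_bl], lr=lr)
    for _ in range(steps):
        prev = log_bl.detach().clone()
        optimizer.zero_grad()
        loss = -log_prob_fn(log_bl)  # minimize the negative log probability
        loss.backward()
        torch.nn.utils.clip_grad_norm_([log_bl], max_norm=5.0)
        optimizer.step()
        if torch.isnan(log_bl):
            # a step went bad: restore the last finite value and stop
            with torch.no_grad():
                log_bl.copy_(prev)
            break
    return log_bl.detach()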
14 changes: 8 additions & 6 deletions netam/multihit.py
@@ -287,9 +287,7 @@ def child_codon_probs_corrected(
         torch.Tensor: A (codon_count,) shaped tensor containing the corrected probabilities of each child codon.
     """
 
-    corrected_per_parent_probs = model(
-        parent_codon_idxs, uncorrected_per_parent_probs
-    )
+    corrected_per_parent_probs = model(parent_codon_idxs, uncorrected_per_parent_probs)
    return child_codon_probs_from_per_parent_probs(
        corrected_per_parent_probs, child_codon_idxs
    )
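The one-line rewrite above doesn't change the logic: the multihit model maps per-parent codon distributions to corrected ones, and the corrected probability of each observed child codon is then read out. A hedged sketch of that readout step, assuming the per-parent distributions are flattened to shape (codon_count, 64); the real netam tensors may be shaped differently:

import torch

def child_codon_probs_from_per_parent_probs(per_parent_probs, child_codon_idxs):
    # per_parent_probs: (codon_count, 64); child_codon_idxs: (codon_count,)
    # select, for each site, the probability of the codon actually observed
    return per_parent_probs[torch.arange(per_parent_probs.shape[0]), child_codon_idxs]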
@@ -374,9 +372,13 @@ def log_pcp_probability(log_branch_length):
 
             child_codon_idxs = reshape_for_codons(child_idxs)[codon_mask]
             parent_codon_idxs = reshape_for_codons(parent_idxs)[codon_mask]
-            return child_codon_probs_corrected(
-                codon_probs, parent_codon_idxs, child_codon_idxs, self.model
-            ).log().sum()
+            return (
+                child_codon_probs_corrected(
+                    codon_probs, parent_codon_idxs, child_codon_idxs, self.model
+                )
+                .log()
+                .sum()
+            )
 
         return optimize_branch_length(
             log_pcp_probability,
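The chained .log().sum() turns per-codon probabilities into a total log likelihood, the scalar that branch length optimization maximizes; the branch length enters only through the closure, so the optimizer needs nothing but a callable from log branch length to log probability. A toy usage of the optimize_log_branch_length sketch given after the molevol.py diff above (the quadratic closure is purely illustrative):

log_bl = optimize_log_branch_length(lambda lbl: -((lbl - 1.0) ** 2))
print(log_bl)  # should approach 1.0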
11 changes: 11 additions & 0 deletions tests/test_dasm.py
@@ -14,10 +14,21 @@
     DASMDataset,
     zap_predictions_along_diagonal,
 )
+import multiprocessing as mp
 
 
+def force_spawn():
+    """Force the spawn start method for multiprocessing.
+
+    This is necessary to avoid conflicts with the internal OpenMP-based thread pool in
+    PyTorch.
+    """
+    mp.set_start_method("spawn", force=True)
+
+
 @pytest.fixture(scope="module")
 def dasm_burrito(pcp_df):
     """Fixture that returns the DNSM Burrito object."""
+    force_spawn()
     pcp_df["in_train"] = True
     pcp_df.loc[pcp_df.index[-15:], "in_train"] = False
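A minimal standalone sketch (not from the commit) of why force_spawn matters: with the default "fork" start method on Linux, worker processes inherit the parent's OpenMP thread state after torch has spun up its internal thread pool, which can deadlock; "spawn" starts each worker in a fresh interpreter.

import multiprocessing as mp
import torch

def worker(x):
    return torch.tensor([x]).exp().item()

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)  # avoids fork/OpenMP conflicts
    with mp.Pool(2) as pool:
        print(pool.map(worker, [0.0, 1.0]))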
