Skip to content

Commit

Permalink
moving in molevol and sequences.py
Browse files Browse the repository at this point in the history
  • Loading branch information
matsen committed Jun 10, 2024
1 parent 8a3ddc3 commit ea24a89
Show file tree
Hide file tree
Showing 4 changed files with 597 additions and 12 deletions.
16 changes: 7 additions & 9 deletions netam/dnsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,7 @@

from tqdm import tqdm

from epam.torch_common import optimize_branch_length
from epam.models import WrappedBinaryMutSel
import epam.molevol as molevol
import epam.sequences as sequences
from epam.sequences import (
aa_subs_indicator_tensor_of,
translate_sequences,
)

from netam.common import (
MAX_AMBIG_AA_IDX,
Expand All @@ -42,7 +35,12 @@
)
import netam.framework as framework
from netam.hyper_burrito import HyperBurrito

import netam.molevol as molevol
import netam.sequences as sequences
from netam.sequences import (
aa_subs_indicator_tensor_of,
translate_sequences,
)

class DNSMDataset(Dataset):
def __init__(
Expand Down Expand Up @@ -357,7 +355,7 @@ def _find_optimal_branch_length(
)
if type(starting_branch_length) == torch.Tensor:
starting_branch_length = starting_branch_length.detach().item()
return optimize_branch_length(
return molevol.optimize_branch_length(
log_pcp_probability, starting_branch_length, **optimization_kwargs
)

Expand Down
5 changes: 2 additions & 3 deletions netam/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
VRC01_NT_SEQ,
)
from netam import models

from epam.torch_common import optimize_branch_length
from netam import molevol


def encode_mut_pos_and_base(parent, child, site_count=None):
Expand Down Expand Up @@ -905,7 +904,7 @@ def log_pcp_probability(log_branch_length):
rate_loss = self.bce_loss(mut_prob_masked, mutation_indicator_masked)
return -rate_loss

return optimize_branch_length(
return molevol.optimize_branch_length(
log_pcp_probability,
starting_branch_length.double().item(),
**optimization_kwargs,
Expand Down
Loading

0 comments on commit ea24a89

Please sign in to comment.