Skip to content

Commit

Permalink
Support for dnsm-experiments-1 PR#35 (#86)
Browse files Browse the repository at this point in the history
* add prefix class attribute to DXSM* inheritors
* move branch length loading (`load_branch_lengths`) from `DNSMBurrito` up to the shared `DXSMBurrito` base class
  • Loading branch information
willdumm authored Nov 14, 2024
1 parent 34d26cf commit f418d62
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 7 deletions.
3 changes: 3 additions & 0 deletions netam/dasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


class DASMDataset(DXSMDataset):
prefix = "dasm"

def update_neutral_probs(self):
neutral_aa_probs_l = []
Expand Down Expand Up @@ -115,6 +116,8 @@ def zap_predictions_along_diagonal(predictions, aa_parents_idxs):


class DASMBurrito(framework.TwoLossMixin, DXSMBurrito):
prefix = "dasm"

def __init__(self, *args, loss_weights: list = [1.0, 0.01], **kwargs):
super().__init__(*args, **kwargs)
self.xent_loss = torch.nn.CrossEntropyLoss()
Expand Down
10 changes: 3 additions & 7 deletions netam/dnsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@


class DNSMDataset(DXSMDataset):
prefix = "dnsm"

def update_neutral_probs(self):
"""Update the neutral mutation probabilities for the dataset.
Expand Down Expand Up @@ -108,16 +109,11 @@ def to(self, device):


class DNSMBurrito(DXSMBurrito):
prefix = "dnsm"

def __init__(self, *args, **kwargs):
    """Initialize the burrito.

    All positional and keyword arguments are forwarded unchanged to the
    parent class initializer; no additional state is set up here.
    """
    super().__init__(*args, **kwargs)

def load_branch_lengths(self, in_csv_prefix):
    """Load saved branch lengths into the training and validation datasets.

    Each dataset's own ``load_branch_lengths`` method is called with a path
    built from ``in_csv_prefix``:

    * ``<in_csv_prefix>.train_branch_lengths.csv`` for ``self.train_dataset``
    * ``<in_csv_prefix>.val_branch_lengths.csv`` for ``self.val_dataset``

    Args:
        in_csv_prefix: Common path prefix of the two CSV files. It is
            interpolated with ``str()`` semantics, so a ``pathlib.Path`` is
            accepted as well as a plain string.
    """
    # The training dataset is optional; only load it when one is attached.
    if self.train_dataset is not None:
        self.train_dataset.load_branch_lengths(
            f"{in_csv_prefix}.train_branch_lengths.csv"
        )
    # The validation dataset is loaded unconditionally.
    self.val_dataset.load_branch_lengths(f"{in_csv_prefix}.val_branch_lengths.csv")

def prediction_pair_of_batch(self, batch):
"""Get log neutral amino acid substitution probabilities and log selection
factors for a batch of data."""
Expand Down
11 changes: 11 additions & 0 deletions netam/dxsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@


class DXSMDataset(Dataset, ABC):
prefix = "dxsm"

def __init__(
self,
nt_parents: pd.Series,
Expand Down Expand Up @@ -238,6 +240,8 @@ def update_neutral_probs(self):


class DXSMBurrito(framework.Burrito, ABC):
prefix = "dxsm"

def _find_optimal_branch_length(
self,
parent,
Expand Down Expand Up @@ -312,6 +316,13 @@ def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
)
return torch.cat(results)

def load_branch_lengths(self, in_csv_prefix):
    """Read branch lengths from CSV files into this burrito's datasets.

    The file names are derived from a shared prefix. The training dataset,
    when present, is given ``<in_csv_prefix>.train_branch_lengths.csv``; the
    validation dataset is always given
    ``<in_csv_prefix>.val_branch_lengths.csv``. The actual reading is done by
    each dataset's own ``load_branch_lengths`` method.

    Args:
        in_csv_prefix: Path prefix common to both CSV files. Formatted via
            an f-string, so both ``str`` and ``pathlib.Path`` values work.
    """
    # A burrito may have no training dataset; skip that file in that case.
    if self.train_dataset is not None:
        self.train_dataset.load_branch_lengths(
            f"{in_csv_prefix}.train_branch_lengths.csv"
        )
    # No guard here: the validation dataset is accessed unconditionally.
    self.val_dataset.load_branch_lengths(f"{in_csv_prefix}.val_branch_lengths.csv")

def to_crepe(self):
training_hyperparameters = {
key: self.__dict__[key]
Expand Down

0 comments on commit f418d62

Please sign in to comment.