From f418d62dcf9b1c939e5b847f727bdd1dd0c1f18d Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Thu, 14 Nov 2024 12:30:24 -0900 Subject: [PATCH] Support for dnsm-experiments-1 PR#35 (#86) * add prefix class attribute to DXSM* inheritors * move branch length loading --- netam/dasm.py | 3 +++ netam/dnsm.py | 10 +++------- netam/dxsm.py | 11 +++++++++++ 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/netam/dasm.py b/netam/dasm.py index 7b30f733..55dbcfff 100644 --- a/netam/dasm.py +++ b/netam/dasm.py @@ -14,6 +14,7 @@ class DASMDataset(DXSMDataset): + prefix = "dasm" def update_neutral_probs(self): neutral_aa_probs_l = [] @@ -115,6 +116,8 @@ def zap_predictions_along_diagonal(predictions, aa_parents_idxs): class DASMBurrito(framework.TwoLossMixin, DXSMBurrito): + prefix = "dasm" + def __init__(self, *args, loss_weights: list = [1.0, 0.01], **kwargs): super().__init__(*args, **kwargs) self.xent_loss = torch.nn.CrossEntropyLoss() diff --git a/netam/dnsm.py b/netam/dnsm.py index f70c128a..cfa943b7 100644 --- a/netam/dnsm.py +++ b/netam/dnsm.py @@ -15,6 +15,7 @@ class DNSMDataset(DXSMDataset): + prefix = "dnsm" def update_neutral_probs(self): """Update the neutral mutation probabilities for the dataset. @@ -108,16 +109,11 @@ def to(self, device): class DNSMBurrito(DXSMBurrito): + prefix = "dnsm" + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def load_branch_lengths(self, in_csv_prefix): - if self.train_dataset is not None: - self.train_dataset.load_branch_lengths( - in_csv_prefix + ".train_branch_lengths.csv" - ) - self.val_dataset.load_branch_lengths(in_csv_prefix + ".val_branch_lengths.csv") - def prediction_pair_of_batch(self, batch): """Get log neutral amino acid substitution probabilities and log selection factors for a batch of data.""" diff --git a/netam/dxsm.py b/netam/dxsm.py index e4876a2f..007e0240 100644 --- a/netam/dxsm.py +++ b/netam/dxsm.py @@ -31,6 +31,8 @@ class DXSMDataset(Dataset, ABC): + prefix = "dxsm" + def __init__( self, nt_parents: pd.Series, @@ -238,6 +240,8 @@ def update_neutral_probs(self): class DXSMBurrito(framework.Burrito, ABC): + prefix = "dxsm" + def _find_optimal_branch_length( self, parent, @@ -312,6 +316,13 @@ def find_optimal_branch_lengths(self, dataset, **optimization_kwargs): ) return torch.cat(results) + def load_branch_lengths(self, in_csv_prefix): + if self.train_dataset is not None: + self.train_dataset.load_branch_lengths( + in_csv_prefix + ".train_branch_lengths.csv" + ) + self.val_dataset.load_branch_lengths(in_csv_prefix + ".val_branch_lengths.csv") + def to_crepe(self): training_hyperparameters = { key: self.__dict__[key]