From 193822b0fedec7929a8d51ae46d059ca45ccdcd7 Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Mon, 30 Sep 2024 06:01:10 -0700 Subject: [PATCH] make format --- netam/dasm.py | 19 +++++++++---------- tests/test_dasm.py | 14 ++++++++++---- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/netam/dasm.py b/netam/dasm.py index 96566ea7..72f61489 100644 --- a/netam/dasm.py +++ b/netam/dasm.py @@ -33,6 +33,7 @@ translate_sequences, ) + class DASMDataset(dnsm.DNSMDataset): # TODO should we rename this? @@ -74,9 +75,7 @@ def update_neutral_aa_mut_probs(self): print(f"rates: {rates}") print(f"subs_probs: {subs_probs}") print(f"branch_length: {branch_length}") - raise ValueError( - f"neutral_aa_probs is not finite: {neutral_aa_probs}" - ) + raise ValueError(f"neutral_aa_probs is not finite: {neutral_aa_probs}") # Ensure that all values are positive before taking the log later neutral_aa_probs = clamp_probability(neutral_aa_probs) @@ -134,13 +133,11 @@ def train_val_datasets_of_pcp_df(pcp_df, branch_length_multiplier=5.0): return train_dataset, val_dataset - - class DASMBurrito(dnsm.DNSMBurrito): def prediction_pair_of_batch(self, batch): - """Get log neutral AA probabilities and log selection factors for a batch - of data.""" + """Get log neutral AA probabilities and log selection factors for a batch of + data.""" aa_parents_idxs = batch["aa_parents_idxs"].to(self.device) mask = batch["mask"].to(self.device) log_neutral_aa_probs = batch["log_neutral_aa_probs"].to(self.device) @@ -182,15 +179,17 @@ def loss_of_batch(self, batch): [predictions, torch.zeros_like(predictions[:, :, :1])], dim=-1 ) # Now we make predictions of mutation by taking everything off the diagonal. - #predictions[torch.arange(len(aa_parents_idxs)), aa_parents_idxs] = 0.0 + # predictions[torch.arange(len(aa_parents_idxs)), aa_parents_idxs] = 0.0 # Get batch size and sequence length batch_size, L, _ = predictions.shape # Create indices for each batch batch_indices = torch.arange(batch_size, device=self.device) # Zero out the diagonal by setting predictions[batch_idx, site_idx, aa_idx] to 0 - predictions[batch_indices[:, None], torch.arange(L, device=self.device), aa_parents_idxs] = 0.0 - + predictions[ + batch_indices[:, None], torch.arange(L, device=self.device), aa_parents_idxs + ] = 0.0 + predictions_of_mut = torch.sum(predictions, dim=-1) predictions_of_mut = predictions_of_mut.masked_select(mask) return self.bce_loss(predictions_of_mut, aa_subs_indicator) diff --git a/tests/test_dasm.py b/tests/test_dasm.py index 6b563e69..ece02a2f 100644 --- a/tests/test_dasm.py +++ b/tests/test_dasm.py @@ -11,10 +11,10 @@ ) from netam.common import aa_idx_tensor_of_str_ambig, MAX_AMBIG_AA_IDX from netam.models import TransformerBinarySelectionModelWiggleAct -from netam.dasm import DASMBurrito,train_val_datasets_of_pcp_df +from netam.dasm import DASMBurrito, train_val_datasets_of_pcp_df -#TODO code dup +# TODO code dup @pytest.fixture def pcp_df(): df = load_pcp_df( @@ -35,7 +35,11 @@ def dasm_burrito(pcp_df): train_dataset, val_dataset = train_val_datasets_of_pcp_df(pcp_df) model = TransformerBinarySelectionModelWiggleAct( - nhead=2, d_model_per_head=4, dim_feedforward=256, layer_count=2, output_dim=20, + nhead=2, + d_model_per_head=4, + dim_feedforward=256, + layer_count=2, + output_dim=20, ) burrito = DASMBurrito( @@ -46,7 +50,9 @@ def dasm_burrito(pcp_df): learning_rate=0.001, min_learning_rate=0.0001, ) - burrito.joint_train(epochs=1, cycle_count=2, training_method="full", optimize_bl_first_cycle=False) + burrito.joint_train( + epochs=1, cycle_count=2, training_method="full", optimize_bl_first_cycle=False + ) return burrito