Skip to content

Commit

Permalink
remove prints and add dasm test
Browse files Browse the repository at this point in the history
  • Loading branch information
willdumm committed Nov 19, 2024
1 parent 152367c commit 776ab82
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
1 change: 0 additions & 1 deletion netam/dnsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def update_neutral_probs(self):
"""
neutral_aa_mut_prob_l = []

print("starting update_neutral_probs loop")
for nt_parent, mask, nt_rates, nt_csps, branch_length in zip(
self.nt_parents,
self.masks,
Expand Down
31 changes: 31 additions & 0 deletions tests/test_ambiguous.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,34 @@ def test_dnsm_burrito(ambig_pcp_df, dnsm_model):
)
burrito.joint_train(epochs=1, cycle_count=2, training_method="full")
return burrito

@pytest.fixture
def dasm_model():
return TransformerBinarySelectionModelWiggleAct(
nhead=2,
d_model_per_head=4,
dim_feedforward=256,
layer_count=2,
output_dim=20,
)


def test_dasm_burrito(ambig_pcp_df, dasm_model):
force_spawn()
"""Fixture that returns the DNSM Burrito object."""
ambig_pcp_df["in_train"] = True
ambig_pcp_df.loc[ambig_pcp_df.index[-15:], "in_train"] = False
train_dataset, val_dataset = DASMDataset.train_val_datasets_of_pcp_df(ambig_pcp_df)

burrito = DASMBurrito(
train_dataset,
val_dataset,
dasm_model,
batch_size=32,
learning_rate=0.001,
min_learning_rate=0.0001,
)
burrito.joint_train(
epochs=1, cycle_count=2, training_method="full", optimize_bl_first_cycle=False
)
return burrito

0 comments on commit 776ab82

Please sign in to comment.