diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml
index 33657d9b..b561159c 100644
--- a/.github/workflows/build-and-test.yml
+++ b/.github/workflows/build-and-test.yml
@@ -40,10 +40,18 @@ jobs:
         create-args: >-
           python=${{ matrix.python-version }}
           datrie
+          black
+          flake8
         init-shell: bash
         cache-environment: false
         post-cleanup: 'none'
 
+    - name: Check format
+      shell: bash -l {0}
+      run: |
+        cd main
+        make checkformat
+
     - name: Install
       shell: bash -l {0}
       run: |
diff --git a/Makefile b/Makefile
index ccb03b06..cce8f43f 100644
--- a/Makefile
+++ b/Makefile
@@ -10,6 +10,9 @@ test:
 format:
 	black netam tests
 
+checkformat:
+	black --check netam tests
+
 lint:
 	# stop the build if there are Python syntax errors or undefined names
 	flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=_ignore
diff --git a/netam/framework.py b/netam/framework.py
index ec13b9d0..e4ceb0dd 100644
--- a/netam/framework.py
+++ b/netam/framework.py
@@ -396,11 +396,13 @@ def __init__(
     def device(self):
         return next(self.model.parameters()).device
 
-    def reset_optimization(self):
+    def reset_optimization(self, learning_rate=None):
         """Reset the optimizer and scheduler."""
+        if learning_rate is None:
+            learning_rate = self.learning_rate
         self.optimizer = torch.optim.Adam(
             self.model.parameters(),
-            lr=self.learning_rate,
+            lr=learning_rate,
             weight_decay=self.l2_regularization_coeff,
         )
         self.scheduler = ReduceLROnPlateau(
@@ -646,6 +648,12 @@ def joint_train(self, epochs=100, cycle_count=2, training_method="full"):
         If training_method is "full", then we optimize the branch lengths using full ML optimization.
         If training_method is "yun", then we use Yun's approximation to the branch lengths.
         If training_method is "fixed", then we fix the branch lengths and only optimize the model.
+
+        We reset the optimization after each cycle, setting the new learning
+        rate to a weighted geometric mean of the current learning rate and the
+        initial learning rate; the weight shifts towards keeping the current
+        learning rate as the cycles progress.
+
         """
         if training_method == "full":
             optimize_branch_lengths = self.standardize_and_optimize_branch_lengths
@@ -660,7 +668,13 @@ def joint_train(self, epochs=100, cycle_count=2, training_method="full"):
         self.mark_branch_lengths_optimized(0)
         for cycle in range(cycle_count):
             self.mark_branch_lengths_optimized(cycle + 1)
-            self.reset_optimization()
+            current_lr = self.optimizer.param_groups[0]["lr"]
+            # set new_lr to a weighted geometric mean of current_lr and the initial learning rate
+            weight = 0.5 + cycle / (2 * cycle_count)
+            new_lr = np.exp(
+                weight * np.log(current_lr) + (1 - weight) * np.log(self.learning_rate)
+            )
+            self.reset_optimization(new_lr)
             loss_history_l.append(self.train(epochs))
             if cycle < cycle_count - 1:
                 optimize_branch_lengths()
@@ -718,29 +732,34 @@ def loss_of_batch(self, batch):
         loss = self.bce_loss(mut_prob_masked, mutation_indicator_masked)
         return loss
 
-    def vrc01_site_1_model_rate(self):
+    def vrc01_site_14_model_rate(self):
         """
-        Calculate rate on site 1 (zero-indexed) of VRC01_NT_SEQ.
+        Calculate rate on site 14 (zero-indexed) of VRC01_NT_SEQ.
         """
         encoder = self.val_loader.dataset.encoder
-        assert encoder.site_count >= 2
+        assert (
+            encoder.site_count >= 15
+        ), "Encoder site count too small for vrc01_site_14_model_rate"
         encoded_parent, wt_base_modifier = encoder.encode_sequence(VRC01_NT_SEQ)
         mask = nt_mask_tensor_of(VRC01_NT_SEQ, encoder.site_count)
+        encoded_parent = encoded_parent.to(self.device)
+        mask = mask.to(self.device)
+        wt_base_modifier = wt_base_modifier.to(self.device)
         vrc01_rates, _ = self.model(
             encoded_parent.unsqueeze(0),
             mask.unsqueeze(0),
             wt_base_modifier.unsqueeze(0),
         )
-        vrc01_rate_1 = vrc01_rates.squeeze()[1].item()
-        return vrc01_rate_1
+        vrc01_rate_14 = vrc01_rates.squeeze()[14].item()
+        return vrc01_rate_14
 
     def standardize_model_rates(self):
         """
-        Normalize the rates output by the model so that it predicts rate 1 on site 1
+        Normalize the rates output by the model so that it predicts rate 1 on site 14
         (zero-indexed) of VRC01_NT_SEQ.
         """
-        vrc01_rate_1 = self.vrc01_site_1_model_rate()
-        self.model.adjust_rate_bias_by(-np.log(vrc01_rate_1))
+        vrc01_rate_14 = self.vrc01_site_14_model_rate()
+        self.model.adjust_rate_bias_by(-np.log(vrc01_rate_14))
 
     def to_crepe(self):
         training_hyperparameters = {
diff --git a/netam/toy_simulation.py b/netam/toy_simulation.py
index ad0a5a10..13d9a45d 100644
--- a/netam/toy_simulation.py
+++ b/netam/toy_simulation.py
@@ -5,6 +5,7 @@
 It corresponds to the inference happning in toy_dnsm.py.
 """
 
+
 import random
 
 import pandas as pd
diff --git a/tests/test_netam.py b/tests/test_netam.py
index 3e773526..ffb525ae 100644
--- a/tests/test_netam.py
+++ b/tests/test_netam.py
@@ -93,18 +93,24 @@ def tiny_rsscnnmodel():
 
 
 @pytest.fixture
-def tiny_rsburrito(tiny_dataset, tiny_val_dataset, tiny_rsshmoofmodel):
-    burrito = RSSHMBurrito(tiny_dataset, tiny_val_dataset, tiny_rsshmoofmodel)
+def mini_dataset():
+    df = pd.read_csv("data/wyatt-10x-1p5m_pcp_2023-11-30_NI.first100.csv.gz")
+    return SHMoofDataset(df, site_count=500, kmer_length=3)
+
+
+@pytest.fixture
+def mini_rsburrito(mini_dataset, tiny_rsscnnmodel):
+    burrito = RSSHMBurrito(mini_dataset, mini_dataset, tiny_rsscnnmodel)
     burrito.joint_train(epochs=5, training_method="yun")
     return burrito
 
 
-def test_standardize_model_rates(tiny_rsburrito):
-    tiny_rsburrito.standardize_model_rates()
-    vrc01_rate_1 = tiny_rsburrito.vrc01_site_1_model_rate()
-    assert np.isclose(vrc01_rate_1, 1.0)
+def test_write_output(mini_rsburrito):
+    os.makedirs("_ignore", exist_ok=True)
+    mini_rsburrito.save_crepe("_ignore/mini_rscrepe")
 
 
-def test_write_output(tiny_rsburrito):
-    os.makedirs("_ignore", exist_ok=True)
-    tiny_rsburrito.save_crepe("_ignore/tiny_rscrepe")
+def test_standardize_model_rates(mini_rsburrito):
+    mini_rsburrito.standardize_model_rates()
+    vrc01_rate_14 = mini_rsburrito.vrc01_site_14_model_rate()
+    assert np.isclose(vrc01_rate_14, 1.0)