From cdd6ee0d47b46ee05c2071f230cf29bd4c36bcf8 Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Tue, 14 May 2024 05:43:40 -0700 Subject: [PATCH 01/13] COMMENTING OUT MODEL RATE STDIZATION --- netam/framework.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netam/framework.py b/netam/framework.py index ec13b9d0..a3c330cc 100644 --- a/netam/framework.py +++ b/netam/framework.py @@ -582,7 +582,7 @@ def standardize_model_rates(self): pass def standardize_and_optimize_branch_lengths(self, **optimization_kwargs): - self.standardize_model_rates() + #self.standardize_model_rates() if "learning_rate" not in optimization_kwargs: optimization_kwargs["learning_rate"] = 0.01 if "optimization_tol" not in optimization_kwargs: From 9a49ae8eb21ab1d690605e58eedddaf99fba445b Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Tue, 14 May 2024 08:58:05 -0700 Subject: [PATCH 02/13] mini_rsburrito --- tests/test_netam.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/tests/test_netam.py b/tests/test_netam.py index 3e773526..a7981ec1 100644 --- a/tests/test_netam.py +++ b/tests/test_netam.py @@ -99,12 +99,26 @@ def tiny_rsburrito(tiny_dataset, tiny_val_dataset, tiny_rsshmoofmodel): return burrito -def test_standardize_model_rates(tiny_rsburrito): - tiny_rsburrito.standardize_model_rates() - vrc01_rate_1 = tiny_rsburrito.vrc01_site_1_model_rate() - assert np.isclose(vrc01_rate_1, 1.0) - - def test_write_output(tiny_rsburrito): os.makedirs("_ignore", exist_ok=True) tiny_rsburrito.save_crepe("_ignore/tiny_rscrepe") + + +@pytest.fixture +def mini_dataset(): + df = pd.read_csv("data/wyatt-10x-1p5m_pcp_2023-11-30_NI.first100.csv.gz") + return SHMoofDataset(df, site_count=500, kmer_length=3) + + +@pytest.fixture +def mini_rsburrito(mini_dataset, tiny_rsscnnmodel): + burrito = RSSHMBurrito(mini_dataset, mini_dataset, tiny_rsscnnmodel) + burrito.joint_train(epochs=5, training_method="yun") + return burrito + + +def test_standardize_model_rates(mini_rsburrito): + mini_rsburrito.standardize_model_rates() + vrc01_rate_1 = mini_rsburrito.vrc01_site_1_model_rate() + assert np.isclose(vrc01_rate_1, 1.0) + From 6d23e6530051ca8b47b41994cf5aad6d79749a58 Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Tue, 14 May 2024 09:04:43 -0700 Subject: [PATCH 03/13] normalizing using site 14 (further from boundary) --- netam/framework.py | 16 ++++++++-------- tests/test_netam.py | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/netam/framework.py b/netam/framework.py index a3c330cc..83f31404 100644 --- a/netam/framework.py +++ b/netam/framework.py @@ -718,12 +718,12 @@ def loss_of_batch(self, batch): loss = self.bce_loss(mut_prob_masked, mutation_indicator_masked) return loss - def vrc01_site_1_model_rate(self): + def vrc01_site_14_model_rate(self): """ - Calculate rate on site 1 (zero-indexed) of VRC01_NT_SEQ. + Calculate rate on site 14 (zero-indexed) of VRC01_NT_SEQ. 
""" encoder = self.val_loader.dataset.encoder - assert encoder.site_count >= 2 + assert encoder.site_count >= 15, "Encoder site count too small vrc01_site_14_model_rate" encoded_parent, wt_base_modifier = encoder.encode_sequence(VRC01_NT_SEQ) mask = nt_mask_tensor_of(VRC01_NT_SEQ, encoder.site_count) vrc01_rates, _ = self.model( @@ -731,16 +731,16 @@ def vrc01_site_1_model_rate(self): mask.unsqueeze(0), wt_base_modifier.unsqueeze(0), ) - vrc01_rate_1 = vrc01_rates.squeeze()[1].item() - return vrc01_rate_1 + vrc01_rate_14 = vrc01_rates.squeeze()[14].item() + return vrc01_rate_14 def standardize_model_rates(self): """ - Normalize the rates output by the model so that it predicts rate 1 on site 1 + Normalize the rates output by the model so that it predicts rate 1 on site 14 (zero-indexed) of VRC01_NT_SEQ. """ - vrc01_rate_1 = self.vrc01_site_1_model_rate() - self.model.adjust_rate_bias_by(-np.log(vrc01_rate_1)) + vrc01_rate_14 = self.vrc01_site_14_model_rate() + self.model.adjust_rate_bias_by(-np.log(vrc01_rate_14)) def to_crepe(self): training_hyperparameters = { diff --git a/tests/test_netam.py b/tests/test_netam.py index a7981ec1..aa8c1212 100644 --- a/tests/test_netam.py +++ b/tests/test_netam.py @@ -119,6 +119,6 @@ def mini_rsburrito(mini_dataset, tiny_rsscnnmodel): def test_standardize_model_rates(mini_rsburrito): mini_rsburrito.standardize_model_rates() - vrc01_rate_1 = mini_rsburrito.vrc01_site_1_model_rate() - assert np.isclose(vrc01_rate_1, 1.0) + vrc01_rate_14 = mini_rsburrito.vrc01_site_14_model_rate() + assert np.isclose(vrc01_rate_14, 1.0) From 12db8026efd3e33170df1b92c1cbbb8406d82ad8 Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Tue, 14 May 2024 09:17:24 -0700 Subject: [PATCH 04/13] everything gets standardized during training --- netam/framework.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/netam/framework.py b/netam/framework.py index 83f31404..9a24b428 100644 --- a/netam/framework.py +++ b/netam/framework.py @@ -582,7 +582,7 @@ def standardize_model_rates(self): pass def standardize_and_optimize_branch_lengths(self, **optimization_kwargs): - #self.standardize_model_rates() + self.standardize_model_rates() if "learning_rate" not in optimization_kwargs: optimization_kwargs["learning_rate"] = 0.01 if "optimization_tol" not in optimization_kwargs: @@ -652,7 +652,7 @@ def joint_train(self, epochs=100, cycle_count=2, training_method="full"): elif training_method == "yun": optimize_branch_lengths = self.standardize_and_use_yun_approx_branch_lengths elif training_method == "fixed": - optimize_branch_lengths = lambda: None + optimize_branch_lengths = self.standardize_model_rates else: raise ValueError(f"Unknown training method {training_method}") loss_history_l = [] From 03cc5ad6fce495e2a860011967955ec19c3a6f1e Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Fri, 17 May 2024 07:55:59 -0700 Subject: [PATCH 05/13] respecting device for VRC01 --- netam/framework.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/netam/framework.py b/netam/framework.py index 9a24b428..5e4a3d92 100644 --- a/netam/framework.py +++ b/netam/framework.py @@ -726,6 +726,9 @@ def vrc01_site_14_model_rate(self): assert encoder.site_count >= 15, "Encoder site count too small vrc01_site_14_model_rate" encoded_parent, wt_base_modifier = encoder.encode_sequence(VRC01_NT_SEQ) mask = nt_mask_tensor_of(VRC01_NT_SEQ, encoder.site_count) + encoded_parent = encoded_parent.to(self.device) + mask = mask.to(self.device) + wt_base_modifier = wt_base_modifier.to(self.device) vrc01_rates, 
_ = self.model( encoded_parent.unsqueeze(0), mask.unsqueeze(0), From 6972daa6bfeb32a5445e2955f0617adf42b17b93 Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Sat, 18 May 2024 05:00:23 -0700 Subject: [PATCH 06/13] no standardization without branch length optimization --- netam/framework.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netam/framework.py b/netam/framework.py index 5e4a3d92..35e7caf9 100644 --- a/netam/framework.py +++ b/netam/framework.py @@ -652,7 +652,7 @@ def joint_train(self, epochs=100, cycle_count=2, training_method="full"): elif training_method == "yun": optimize_branch_lengths = self.standardize_and_use_yun_approx_branch_lengths elif training_method == "fixed": - optimize_branch_lengths = self.standardize_model_rates + optimize_branch_lengths = lambda: None else: raise ValueError(f"Unknown training method {training_method}") loss_history_l = [] From d8bbcc6da5be370efd7b442cb106678338b31b38 Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Mon, 20 May 2024 02:10:11 -0700 Subject: [PATCH 07/13] less radical resetting of LR after cycles --- netam/framework.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/netam/framework.py b/netam/framework.py index 35e7caf9..90d1e1f9 100644 --- a/netam/framework.py +++ b/netam/framework.py @@ -396,11 +396,13 @@ def __init__( def device(self): return next(self.model.parameters()).device - def reset_optimization(self): + def reset_optimization(self, learning_rate=None): """Reset the optimizer and scheduler.""" + if learning_rate is None: + learning_rate = self.learning_rate self.optimizer = torch.optim.Adam( self.model.parameters(), - lr=self.learning_rate, + lr=learning_rate, weight_decay=self.l2_regularization_coeff, ) self.scheduler = ReduceLROnPlateau( @@ -646,6 +648,9 @@ def joint_train(self, epochs=100, cycle_count=2, training_method="full"): If training_method is "full", then we optimize the branch lengths using full ML optimization. If training_method is "yun", then we use Yun's approximation to the branch lengths. If training_method is "fixed", then we fix the branch lengths and only optimize the model. + + We reset the optimization after each cycle, and we use a learning rate schedule that + uses a weighted geometric mean of the current learning rate and the initial learning rate that progressively moves towards keeping the current learning rate as the cycles progress. 
""" if training_method == "full": optimize_branch_lengths = self.standardize_and_optimize_branch_lengths @@ -660,7 +665,11 @@ def joint_train(self, epochs=100, cycle_count=2, training_method="full"): self.mark_branch_lengths_optimized(0) for cycle in range(cycle_count): self.mark_branch_lengths_optimized(cycle + 1) - self.reset_optimization() + current_lr = self.optimizer.param_groups[0]["lr"] + # set new_lr to be the geometric mean of current_lr and the learning rate + weight = 0.5 + cycle / (2 * cycle_count) + new_lr = np.exp(weight * np.log(current_lr) + (1 - weight) * np.log(self.learning_rate)) + self.reset_optimization(new_lr) loss_history_l.append(self.train(epochs)) if cycle < cycle_count - 1: optimize_branch_lengths() From 7dc0cda04dda149f70369b6cc02888c07a444b02 Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Mon, 20 May 2024 02:18:25 -0700 Subject: [PATCH 08/13] test fix --- tests/test_netam.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/tests/test_netam.py b/tests/test_netam.py index aa8c1212..ffb525ae 100644 --- a/tests/test_netam.py +++ b/tests/test_netam.py @@ -92,18 +92,6 @@ def tiny_rsscnnmodel(): ) -@pytest.fixture -def tiny_rsburrito(tiny_dataset, tiny_val_dataset, tiny_rsshmoofmodel): - burrito = RSSHMBurrito(tiny_dataset, tiny_val_dataset, tiny_rsshmoofmodel) - burrito.joint_train(epochs=5, training_method="yun") - return burrito - - -def test_write_output(tiny_rsburrito): - os.makedirs("_ignore", exist_ok=True) - tiny_rsburrito.save_crepe("_ignore/tiny_rscrepe") - - @pytest.fixture def mini_dataset(): df = pd.read_csv("data/wyatt-10x-1p5m_pcp_2023-11-30_NI.first100.csv.gz") @@ -117,8 +105,12 @@ def mini_rsburrito(mini_dataset, tiny_rsscnnmodel): return burrito +def test_write_output(mini_rsburrito): + os.makedirs("_ignore", exist_ok=True) + mini_rsburrito.save_crepe("_ignore/mini_rscrepe") + + def test_standardize_model_rates(mini_rsburrito): mini_rsburrito.standardize_model_rates() vrc01_rate_14 = mini_rsburrito.vrc01_site_14_model_rate() assert np.isclose(vrc01_rate_14, 1.0) - From 0db8ea0b5c165cf0db8474fff6faba524531524f Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Mon, 20 May 2024 02:18:40 -0700 Subject: [PATCH 09/13] make forma --- netam/framework.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/netam/framework.py b/netam/framework.py index 90d1e1f9..28efd967 100644 --- a/netam/framework.py +++ b/netam/framework.py @@ -668,7 +668,9 @@ def joint_train(self, epochs=100, cycle_count=2, training_method="full"): current_lr = self.optimizer.param_groups[0]["lr"] # set new_lr to be the geometric mean of current_lr and the learning rate weight = 0.5 + cycle / (2 * cycle_count) - new_lr = np.exp(weight * np.log(current_lr) + (1 - weight) * np.log(self.learning_rate)) + new_lr = np.exp( + weight * np.log(current_lr) + (1 - weight) * np.log(self.learning_rate) + ) self.reset_optimization(new_lr) loss_history_l.append(self.train(epochs)) if cycle < cycle_count - 1: @@ -732,7 +734,9 @@ def vrc01_site_14_model_rate(self): Calculate rate on site 14 (zero-indexed) of VRC01_NT_SEQ. 
""" encoder = self.val_loader.dataset.encoder - assert encoder.site_count >= 15, "Encoder site count too small vrc01_site_14_model_rate" + assert ( + encoder.site_count >= 15 + ), "Encoder site count too small vrc01_site_14_model_rate" encoded_parent, wt_base_modifier = encoder.encode_sequence(VRC01_NT_SEQ) mask = nt_mask_tensor_of(VRC01_NT_SEQ, encoder.site_count) encoded_parent = encoded_parent.to(self.device) From 9d3b814b37138d5c6d3608506797440518337f95 Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Mon, 20 May 2024 02:20:41 -0700 Subject: [PATCH 10/13] format --- netam/framework.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/netam/framework.py b/netam/framework.py index 28efd967..e4ceb0dd 100644 --- a/netam/framework.py +++ b/netam/framework.py @@ -649,8 +649,11 @@ def joint_train(self, epochs=100, cycle_count=2, training_method="full"): If training_method is "yun", then we use Yun's approximation to the branch lengths. If training_method is "fixed", then we fix the branch lengths and only optimize the model. - We reset the optimization after each cycle, and we use a learning rate schedule that - uses a weighted geometric mean of the current learning rate and the initial learning rate that progressively moves towards keeping the current learning rate as the cycles progress. + We reset the optimization after each cycle, and we use a learning rate + schedule that uses a weighted geometric mean of the current learning + rate and the initial learning rate that progressively moves towards + keeping the current learning rate as the cycles progress. + """ if training_method == "full": optimize_branch_lengths = self.standardize_and_optimize_branch_lengths From 294975b397a2730c1979ee65f4c178c1a1f62ad3 Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Tue, 21 May 2024 13:33:13 -0700 Subject: [PATCH 11/13] adding black check --- .github/workflows/build-and-test.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index 33657d9b..9a58a733 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -53,6 +53,10 @@ jobs: cd ../main pip install . + - name: Run Black + run: | + black --check . + - name: Test shell: bash -l {0} run: | From c0ee4f12f461aed5da17bed79c385450d6ac2399 Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Tue, 21 May 2024 13:51:41 -0700 Subject: [PATCH 12/13] add formatting and linting before tests --- .github/workflows/build-and-test.yml | 13 +++++++++---- Makefile | 3 +++ netam/toy_simulation.py | 1 + 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index 9a58a733..8da54c9d 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -40,10 +40,19 @@ jobs: create-args: >- python=${{ matrix.python-version }} datrie + black + flake8 init-shell: bash cache-environment: false post-cleanup: 'none' + - name: Lint and check format + shell: bash -l {0} + run: | + cd main + make lint + make checkformat + - name: Install shell: bash -l {0} run: | @@ -53,10 +62,6 @@ jobs: cd ../main pip install . - - name: Run Black - run: | - black --check . 
- - name: Test shell: bash -l {0} run: | diff --git a/Makefile b/Makefile index ccb03b06..cce8f43f 100644 --- a/Makefile +++ b/Makefile @@ -10,6 +10,9 @@ test: format: black netam tests +checkformat: + black --check netam tests + lint: # stop the build if there are Python syntax errors or undefined names flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=_ignore diff --git a/netam/toy_simulation.py b/netam/toy_simulation.py index ad0a5a10..13d9a45d 100644 --- a/netam/toy_simulation.py +++ b/netam/toy_simulation.py @@ -5,6 +5,7 @@ It corresponds to the inference happning in toy_dnsm.py. """ + import random import pandas as pd From e3014d25b9f6e35e1476b563ea434418b2ed3727 Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Tue, 21 May 2024 13:54:38 -0700 Subject: [PATCH 13/13] remove linting --- .github/workflows/build-and-test.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index 8da54c9d..b561159c 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -46,11 +46,10 @@ jobs: cache-environment: false post-cleanup: 'none' - - name: Lint and check format + - name: Check format shell: bash -l {0} run: | cd main - make lint make checkformat - name: Install
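
The standardization introduced in PATCH 03/13 and toggled in PATCHES 04-06/13 amounts to shifting the model's rate bias so that the rate predicted at site 14 (zero-indexed) of VRC01_NT_SEQ becomes exactly 1, rescaling every other site by the same factor. A minimal self-contained sketch of that idea, using a hypothetical ToyRateModel in place of the real burrito/model pair (only the adjust_rate_bias_by name and the -np.log adjustment are taken from the diffs above; the exp(log_rates + rate_bias) parameterization is an assumption for illustration):

    import numpy as np

    class ToyRateModel:
        """Hypothetical stand-in: per-site rates are exp(log_rates + rate_bias)."""

        def __init__(self, log_rates):
            self.log_rates = np.asarray(log_rates, dtype=float)
            self.rate_bias = 0.0

        def rates(self):
            return np.exp(self.log_rates + self.rate_bias)

        def adjust_rate_bias_by(self, delta):
            self.rate_bias += delta

    model = ToyRateModel(np.log([0.5, 2.0, 1.5]))
    reference_rate = model.rates()[1]           # plays the role of the VRC01 site-14 rate
    model.adjust_rate_bias_by(-np.log(reference_rate))
    assert np.isclose(model.rates()[1], 1.0)    # reference site now has rate 1; other sites rescaled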
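
PATCH 07/13 softens the per-cycle optimizer reset: rather than restarting each cycle at the initial learning rate, it restarts at a weighted geometric mean of the current and initial learning rates, with the weight on the current rate growing from 0.5 toward 1 as the cycles progress. A standalone sketch of that interpolation (the function name and arguments are illustrative; the formula follows the diff above):

    import numpy as np

    def cycle_restart_lr(current_lr, initial_lr, cycle, cycle_count):
        # Weight on the current learning rate: 0.5 on the first cycle,
        # approaching 1.0 as cycle approaches cycle_count.
        weight = 0.5 + cycle / (2 * cycle_count)
        # Weighted geometric mean, computed on the log scale.
        return float(np.exp(weight * np.log(current_lr) + (1 - weight) * np.log(initial_lr)))

    # Example: if the scheduler has decayed the rate to 1e-4 from an initial 1e-2,
    # the first restart (cycle=0 of 2) lands at the geometric midpoint, 1e-3.
    print(cycle_restart_lr(1e-4, 1e-2, cycle=0, cycle_count=2))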