Standardization fixes; less radical LR reset (#20)
* Making sure standardization happens on the right device
* Moving to a more central site of VRC01 for standardization
* Unit test improvements
* Less radical LR reset when doing joint_train (now using weighted geometric mean)
* Black check in CI
matsen authored May 21, 2024
1 parent 5bfc067 commit 2b4ddac
Showing 5 changed files with 57 additions and 20 deletions.
.github/workflows/build-and-test.yml (8 additions & 0 deletions)

@@ -40,10 +40,18 @@ jobs:
           create-args: >-
             python=${{ matrix.python-version }}
             datrie
+            black
             flake8
           init-shell: bash
           cache-environment: false
           post-cleanup: 'none'
+
+      - name: Check format
+        shell: bash -l {0}
+        run: |
+          cd main
+          make checkformat
+
       - name: Install
         shell: bash -l {0}
         run: |
Makefile (3 additions & 0 deletions)

@@ -10,6 +10,9 @@ test:
 format:
 	black netam tests
 
+checkformat:
+	black --check netam tests
+
 lint:
 	# stop the build if there are Python syntax errors or undefined names
 	flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=_ignore
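Note: the new checkformat target is the verify-only counterpart of format. Running black --check leaves files untouched and exits with a nonzero status if anything would be reformatted, which is what lets the CI "Check format" step above fail the build on unformatted code.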
netam/framework.py (30 additions & 11 deletions)

@@ -396,11 +396,13 @@ def __init__(
     def device(self):
         return next(self.model.parameters()).device
 
-    def reset_optimization(self):
+    def reset_optimization(self, learning_rate=None):
         """Reset the optimizer and scheduler."""
+        if learning_rate is None:
+            learning_rate = self.learning_rate
         self.optimizer = torch.optim.Adam(
             self.model.parameters(),
-            lr=self.learning_rate,
+            lr=learning_rate,
             weight_decay=self.l2_regularization_coeff,
         )
         self.scheduler = ReduceLROnPlateau(
@@ -646,6 +648,12 @@ def joint_train(self, epochs=100, cycle_count=2, training_method="full"):
If training_method is "full", then we optimize the branch lengths using full ML optimization.
If training_method is "yun", then we use Yun's approximation to the branch lengths.
If training_method is "fixed", then we fix the branch lengths and only optimize the model.
We reset the optimization after each cycle, and we use a learning rate
schedule that uses a weighted geometric mean of the current learning
rate and the initial learning rate that progressively moves towards
keeping the current learning rate as the cycles progress.
"""
if training_method == "full":
optimize_branch_lengths = self.standardize_and_optimize_branch_lengths
@@ -660,7 +668,13 @@ def joint_train(self, epochs=100, cycle_count=2, training_method="full"):
         self.mark_branch_lengths_optimized(0)
         for cycle in range(cycle_count):
             self.mark_branch_lengths_optimized(cycle + 1)
-            self.reset_optimization()
+            current_lr = self.optimizer.param_groups[0]["lr"]
+            # Set new_lr to a weighted geometric mean of current_lr and the initial learning rate.
+            weight = 0.5 + cycle / (2 * cycle_count)
+            new_lr = np.exp(
+                weight * np.log(current_lr) + (1 - weight) * np.log(self.learning_rate)
+            )
+            self.reset_optimization(new_lr)
             loss_history_l.append(self.train(epochs))
             if cycle < cycle_count - 1:
                 optimize_branch_lengths()
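As an aside on this reset: with weight = 0.5 + cycle / (2 * cycle_count), the first cycle resets to the geometric midpoint of the current and initial learning rates, and later cycles lean progressively toward the current (already decayed) rate. A minimal standalone sketch of the schedule, with made-up learning-rate values for illustration (in the real loop, current_lr is whatever ReduceLROnPlateau has decayed it to by the end of the cycle):

import numpy as np

initial_lr = 1e-3  # made-up initial learning rate
current_lr = 1e-5  # made-up decayed rate at the end of a cycle
cycle_count = 4

for cycle in range(cycle_count):
    # weight moves from 0.5 toward 1.0 as the cycles progress
    weight = 0.5 + cycle / (2 * cycle_count)
    new_lr = np.exp(weight * np.log(current_lr) + (1 - weight) * np.log(initial_lr))
    print(f"cycle {cycle}: weight={weight:.3f}, reset LR to {new_lr:.2e}")

With these numbers the first reset lands at 1.0e-04 and the last at about 1.8e-05, which is the "less radical" reset of the commit title, compared with jumping all the way back to the initial learning rate every cycle.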
@@ -718,29 +732,34 @@ def loss_of_batch(self, batch):
         loss = self.bce_loss(mut_prob_masked, mutation_indicator_masked)
         return loss
 
-    def vrc01_site_1_model_rate(self):
+    def vrc01_site_14_model_rate(self):
         """
-        Calculate rate on site 1 (zero-indexed) of VRC01_NT_SEQ.
+        Calculate rate on site 14 (zero-indexed) of VRC01_NT_SEQ.
         """
         encoder = self.val_loader.dataset.encoder
-        assert encoder.site_count >= 2
+        assert (
+            encoder.site_count >= 15
+        ), "Encoder site count too small for vrc01_site_14_model_rate"
         encoded_parent, wt_base_modifier = encoder.encode_sequence(VRC01_NT_SEQ)
         mask = nt_mask_tensor_of(VRC01_NT_SEQ, encoder.site_count)
+        encoded_parent = encoded_parent.to(self.device)
+        mask = mask.to(self.device)
+        wt_base_modifier = wt_base_modifier.to(self.device)
         vrc01_rates, _ = self.model(
             encoded_parent.unsqueeze(0),
             mask.unsqueeze(0),
             wt_base_modifier.unsqueeze(0),
         )
-        vrc01_rate_1 = vrc01_rates.squeeze()[1].item()
-        return vrc01_rate_1
+        vrc01_rate_14 = vrc01_rates.squeeze()[14].item()
+        return vrc01_rate_14
 
     def standardize_model_rates(self):
         """
-        Normalize the rates output by the model so that it predicts rate 1 on site 1
+        Normalize the rates output by the model so that it predicts rate 1 on site 14
         (zero-indexed) of VRC01_NT_SEQ.
         """
-        vrc01_rate_1 = self.vrc01_site_1_model_rate()
-        self.model.adjust_rate_bias_by(-np.log(vrc01_rate_1))
+        vrc01_rate_14 = self.vrc01_site_14_model_rate()
+        self.model.adjust_rate_bias_by(-np.log(vrc01_rate_14))
 
     def to_crepe(self):
         training_hyperparameters = {
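A note on the standardization arithmetic used here: adjust_rate_bias_by(-np.log(r)) drives the site-14 rate to exactly 1 provided the model forms rates by exponentiating per-site log-rates plus a shared bias, so that shifting the bias by -log(r) rescales every rate by 1/r. That reading of adjust_rate_bias_by is an assumption (its implementation is not shown in this diff); a minimal sketch under it:

import numpy as np

# Sketch only: assumes rates = exp(log_rates + bias), and that
# adjust_rate_bias_by shifts the shared bias (not shown in this diff).
bias = 0.3                               # made-up shared bias
log_rates = np.array([0.1, -0.2, 0.5])   # made-up per-site log-rates
r_ref = np.exp(log_rates + bias)[2]      # stand-in for the VRC01 site-14 rate
bias += -np.log(r_ref)                   # the adjustment applied above
standardized = np.exp(log_rates + bias)
assert np.isclose(standardized[2], 1.0)  # the reference site now has rate 1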
netam/toy_simulation.py (1 addition & 0 deletions)

@@ -5,6 +5,7 @@
 It corresponds to the inference happening in toy_dnsm.py.
 """
 
+import random
 
 import pandas as pd
tests/test_netam.py (15 additions & 9 deletions)

@@ -93,18 +93,24 @@ def tiny_rsscnnmodel():


 @pytest.fixture
-def tiny_rsburrito(tiny_dataset, tiny_val_dataset, tiny_rsshmoofmodel):
-    burrito = RSSHMBurrito(tiny_dataset, tiny_val_dataset, tiny_rsshmoofmodel)
+def mini_dataset():
+    df = pd.read_csv("data/wyatt-10x-1p5m_pcp_2023-11-30_NI.first100.csv.gz")
+    return SHMoofDataset(df, site_count=500, kmer_length=3)
+
+
+@pytest.fixture
+def mini_rsburrito(mini_dataset, tiny_rsscnnmodel):
+    burrito = RSSHMBurrito(mini_dataset, mini_dataset, tiny_rsscnnmodel)
     burrito.joint_train(epochs=5, training_method="yun")
     return burrito
 
 
-def test_standardize_model_rates(tiny_rsburrito):
-    tiny_rsburrito.standardize_model_rates()
-    vrc01_rate_1 = tiny_rsburrito.vrc01_site_1_model_rate()
-    assert np.isclose(vrc01_rate_1, 1.0)
+def test_write_output(mini_rsburrito):
+    os.makedirs("_ignore", exist_ok=True)
+    mini_rsburrito.save_crepe("_ignore/mini_rscrepe")
 
 
-def test_write_output(tiny_rsburrito):
-    os.makedirs("_ignore", exist_ok=True)
-    tiny_rsburrito.save_crepe("_ignore/tiny_rscrepe")
+def test_standardize_model_rates(mini_rsburrito):
+    mini_rsburrito.standardize_model_rates()
+    vrc01_rate_14 = mini_rsburrito.vrc01_site_14_model_rate()
+    assert np.isclose(vrc01_rate_14, 1.0)
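Note: the mini_dataset fixture's site_count=500 comfortably clears the new encoder.site_count >= 15 assertion in vrc01_site_14_model_rate, so the standardization test can probe the VRC01 site-14 rate on real data.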
