Standardization fixes; less radical LR reset #20

Merged · 13 commits · May 21, 2024
netam/framework.py — 41 changes: 30 additions & 11 deletions
@@ -396,11 +396,13 @@ def __init__(
     def device(self):
         return next(self.model.parameters()).device
 
-    def reset_optimization(self):
+    def reset_optimization(self, learning_rate=None):
         """Reset the optimizer and scheduler."""
+        if learning_rate is None:
+            learning_rate = self.learning_rate
         self.optimizer = torch.optim.Adam(
             self.model.parameters(),
-            lr=self.learning_rate,
+            lr=learning_rate,
             weight_decay=self.l2_regularization_coeff,
         )
         self.scheduler = ReduceLROnPlateau(
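
To make the new signature concrete, here is a minimal, self-contained sketch of the pattern. `TinyTrainer` and its default hyperparameters are hypothetical stand-ins; only the `reset_optimization(learning_rate=None)` behavior mirrors the hunk above.

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau


class TinyTrainer:
    """Hypothetical trainer; only reset_optimization mirrors the diff above."""

    def __init__(self, model, learning_rate=1e-3, l2_regularization_coeff=1e-6):
        self.model = model
        self.learning_rate = learning_rate
        self.l2_regularization_coeff = l2_regularization_coeff
        self.reset_optimization()

    def reset_optimization(self, learning_rate=None):
        """Rebuild the optimizer and scheduler, defaulting to the initial LR."""
        if learning_rate is None:
            learning_rate = self.learning_rate
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=self.l2_regularization_coeff,
        )
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode="min")


trainer = TinyTrainer(torch.nn.Linear(4, 1))
trainer.reset_optimization(5e-4)  # restart optimization at an explicit, smaller LR
print(trainer.optimizer.param_groups[0]["lr"])  # 0.0005
```

Calling `reset_optimization()` with no argument keeps the old behavior (reset to the configured initial learning rate), so existing callers are unaffected.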
@@ -646,6 +648,12 @@ def joint_train(self, epochs=100, cycle_count=2, training_method="full"):
         If training_method is "full", then we optimize the branch lengths using full ML optimization.
         If training_method is "yun", then we use Yun's approximation to the branch lengths.
         If training_method is "fixed", then we fix the branch lengths and only optimize the model.
+
+        We reset the optimization after each cycle. The new learning rate is a
+        weighted geometric mean of the current learning rate and the initial
+        learning rate, with the weighting shifting progressively towards
+        keeping the current learning rate as the cycles progress.
+
         """
         if training_method == "full":
             optimize_branch_lengths = self.standardize_and_optimize_branch_lengths
@@ -660,7 +668,13 @@ def joint_train(self, epochs=100, cycle_count=2, training_method="full"):
         self.mark_branch_lengths_optimized(0)
         for cycle in range(cycle_count):
             self.mark_branch_lengths_optimized(cycle + 1)
-            self.reset_optimization()
+            current_lr = self.optimizer.param_groups[0]["lr"]
+            # Set new_lr to a weighted geometric mean of current_lr and the initial learning rate.
+            weight = 0.5 + cycle / (2 * cycle_count)
+            new_lr = np.exp(
+                weight * np.log(current_lr) + (1 - weight) * np.log(self.learning_rate)
+            )
+            self.reset_optimization(new_lr)
             loss_history_l.append(self.train(epochs))
             if cycle < cycle_count - 1:
                 optimize_branch_lengths()
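
As a worked illustration of the schedule described in the docstring above, the sketch below traces the reset learning rate across cycles. The initial LR, cycle count, and the assumed 10x within-cycle decay are made-up numbers; the `weight` and `new_lr` formulas are taken from the diff.

```python
import numpy as np

# Made-up scenario: initial_lr is the configured learning rate, and we pretend
# the plateau scheduler has decayed the LR by 10x by the end of each cycle.
initial_lr = 1e-3
cycle_count = 4
current_lr = initial_lr

for cycle in range(cycle_count):
    current_lr /= 10.0  # assumed within-cycle decay
    weight = 0.5 + cycle / (2 * cycle_count)  # 0.5, 0.625, 0.75, 0.875
    new_lr = np.exp(
        weight * np.log(current_lr) + (1 - weight) * np.log(initial_lr)
    )
    print(f"cycle {cycle}: weight={weight:.3f}, reset LR to {new_lr:.2e}")
    current_lr = new_lr
```

At weight 0.5 the reset is a plain geometric mean, pulling the rate halfway back (in log space) towards the initial value; as the weight grows towards 1 in later cycles, the reset increasingly preserves whatever the scheduler has decayed the rate to — the "less radical LR reset" of the PR title.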
@@ -718,29 +732,34 @@ def loss_of_batch(self, batch):
         loss = self.bce_loss(mut_prob_masked, mutation_indicator_masked)
         return loss
 
-    def vrc01_site_1_model_rate(self):
+    def vrc01_site_14_model_rate(self):
         """
-        Calculate rate on site 1 (zero-indexed) of VRC01_NT_SEQ.
+        Calculate rate on site 14 (zero-indexed) of VRC01_NT_SEQ.
         """
         encoder = self.val_loader.dataset.encoder
-        assert encoder.site_count >= 2
+        assert (
+            encoder.site_count >= 15
+        ), "Encoder site count too small for vrc01_site_14_model_rate"
         encoded_parent, wt_base_modifier = encoder.encode_sequence(VRC01_NT_SEQ)
         mask = nt_mask_tensor_of(VRC01_NT_SEQ, encoder.site_count)
         encoded_parent = encoded_parent.to(self.device)
         mask = mask.to(self.device)
         wt_base_modifier = wt_base_modifier.to(self.device)
         vrc01_rates, _ = self.model(
             encoded_parent.unsqueeze(0),
             mask.unsqueeze(0),
             wt_base_modifier.unsqueeze(0),
         )
-        vrc01_rate_1 = vrc01_rates.squeeze()[1].item()
-        return vrc01_rate_1
+        vrc01_rate_14 = vrc01_rates.squeeze()[14].item()
+        return vrc01_rate_14
 
     def standardize_model_rates(self):
         """
-        Normalize the rates output by the model so that it predicts rate 1 on site 1
+        Normalize the rates output by the model so that it predicts rate 1 on site 14
         (zero-indexed) of VRC01_NT_SEQ.
         """
-        vrc01_rate_1 = self.vrc01_site_1_model_rate()
-        self.model.adjust_rate_bias_by(-np.log(vrc01_rate_1))
+        vrc01_rate_14 = self.vrc01_site_14_model_rate()
+        self.model.adjust_rate_bias_by(-np.log(vrc01_rate_14))
 
     def to_crepe(self):
         training_hyperparameters = {
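
The standardization above works because `adjust_rate_bias_by(-np.log(rate))` shifts a shared bias in log space, which rescales every site's rate by the same factor. The sketch below shows the arithmetic under the assumption (suggested by the method name but not shown in this diff) that rates are exponentials of per-site scores plus that shared bias; the scores, bias value, and target index are invented.

```python
import numpy as np

scores = np.array([0.3, -0.1, 0.7, 0.2])  # invented per-site scores; index 3 stands in for site 14
bias = 0.5                                # invented shared rate bias

rates = np.exp(scores + bias)
r_target = rates[3]                       # analogous to vrc01_site_14_model_rate()

bias += -np.log(r_target)                 # analogous to adjust_rate_bias_by(-np.log(...))
standardized = np.exp(scores + bias)

print(standardized[3])        # 1.0: the reference site now has unit rate
print(standardized / rates)   # every site divided by the same factor, r_target
```

Relative rates between sites are unchanged; only the overall scale is pinned so that the VRC01 reference site has rate exactly 1.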
tests/test_netam.py — 24 changes: 15 additions & 9 deletions
@@ -93,18 +93,24 @@ def tiny_rsscnnmodel():
 
 
 @pytest.fixture
-def tiny_rsburrito(tiny_dataset, tiny_val_dataset, tiny_rsshmoofmodel):
-    burrito = RSSHMBurrito(tiny_dataset, tiny_val_dataset, tiny_rsshmoofmodel)
+def mini_dataset():
+    df = pd.read_csv("data/wyatt-10x-1p5m_pcp_2023-11-30_NI.first100.csv.gz")
+    return SHMoofDataset(df, site_count=500, kmer_length=3)
+
+
+@pytest.fixture
+def mini_rsburrito(mini_dataset, tiny_rsscnnmodel):
+    burrito = RSSHMBurrito(mini_dataset, mini_dataset, tiny_rsscnnmodel)
     burrito.joint_train(epochs=5, training_method="yun")
     return burrito
 
 
-def test_standardize_model_rates(tiny_rsburrito):
-    tiny_rsburrito.standardize_model_rates()
-    vrc01_rate_1 = tiny_rsburrito.vrc01_site_1_model_rate()
-    assert np.isclose(vrc01_rate_1, 1.0)
+def test_write_output(mini_rsburrito):
+    os.makedirs("_ignore", exist_ok=True)
+    mini_rsburrito.save_crepe("_ignore/mini_rscrepe")
 
 
-def test_write_output(tiny_rsburrito):
-    os.makedirs("_ignore", exist_ok=True)
-    tiny_rsburrito.save_crepe("_ignore/tiny_rscrepe")
+def test_standardize_model_rates(mini_rsburrito):
+    mini_rsburrito.standardize_model_rates()
+    vrc01_rate_14 = mini_rsburrito.vrc01_site_14_model_rate()
+    assert np.isclose(vrc01_rate_14, 1.0)