Commit

make format
matsen committed Jun 2, 2024
1 parent 063c5ec commit c3bef03
Showing 3 changed files with 9 additions and 12 deletions.
netam/dnsm.py: 10 changes (5 additions, 5 deletions)
@@ -290,9 +290,7 @@ def load_branch_lengths(self, in_csv_prefix):
         self.train_dataset.load_branch_lengths(
             in_csv_prefix + ".train_branch_lengths.csv"
         )
-        self.val_dataset.load_branch_lengths(
-            in_csv_prefix + ".val_branch_lengths.csv"
-        )
+        self.val_dataset.load_branch_lengths(in_csv_prefix + ".val_branch_lengths.csv")
 
     def predictions_of_batch(self, batch):
         """
@@ -375,15 +373,14 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
         return torch.tensor(optimal_lengths)
 
     def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
-        worker_count = min(mp.cpu_count()//2, 10)
+        worker_count = min(mp.cpu_count() // 2, 10)
         with mp.Pool(worker_count) as pool:
             splits = split_dataset(dataset, worker_count)
             results = pool.starmap(
                 worker_optimize_branch_length,
                 [(self.model, split, optimization_kwargs) for split in splits],
             )
         return torch.cat(results)
 
-
     def to_crepe(self):
         training_hyperparameters = {
@@ -398,6 +395,7 @@ def to_crepe(self):
 
 ## Begin functions used for parallel branch length optimization.
 
+
 def worker_optimize_branch_length(model, dataset, optimization_kwargs):
     burrito = DNSMBurrito(None, dataset, model)
     return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
@@ -413,8 +411,10 @@ def split_dataset(dataset, into_count):
     subsets = [dataset.clone_with_indices(split_indices[i]) for i in range(into_count)]
     return subsets
 
+
 ### End functions used for parallel branch length optimization.
 
+
 class DNSMHyperBurrito(HyperBurrito):
     # Note that we have to write the args out explicitly because we use some magic to filter kwargs in the optuna_objective method.
     def burrito_of_model(
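
For context on the code being reformatted above: find_optimal_branch_lengths splits the dataset into one piece per worker, hands each piece to the module-level worker_optimize_branch_length via Pool.starmap, and concatenates the per-worker results. Below is a minimal, self-contained sketch of that split / starmap / concatenate pattern; the toy data and the stand-in worker function are placeholders for illustration, not netam's actual API.

# Sketch of the parallel pattern used by find_optimal_branch_lengths above.
# The toy data and the stand-in "optimization" are assumptions for illustration.
import multiprocessing as mp

import torch


def worker_halve(values, scale):
    # Stand-in for per-split branch length optimization; it must live at module
    # level so it can be pickled and sent to pool workers.
    return torch.tensor([v * scale for v in values])


def split_list(values, into_count):
    # Round-robin split, analogous in spirit to split_dataset(dataset, worker_count).
    return [values[i::into_count] for i in range(into_count)]


if __name__ == "__main__":
    data = [float(i) for i in range(20)]
    # Guard against cpu_count() // 2 being zero on single-core machines.
    worker_count = max(1, min(mp.cpu_count() // 2, 10))
    with mp.Pool(worker_count) as pool:
        splits = split_list(data, worker_count)
        # starmap preserves input order, so the concatenated result stays
        # aligned with the ordering of the splits.
        results = pool.starmap(worker_halve, [(split, 0.5) for split in splits])
    print(torch.cat(results))
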
netam/framework.py: 10 changes (3 additions, 7 deletions)
@@ -404,7 +404,7 @@ def build_train_loader(self):
         return DataLoader(
             self.train_dataset, batch_size=self.batch_size, shuffle=True
         )
-        
+
     def build_val_loader(self):
         return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)
 
@@ -562,9 +562,7 @@ def record_losses(train_loss, val_loss):
             val_losses.append(val_loss)
 
         # Record the initial loss before training.
-        train_loss = self.process_data_loader(
-            train_loader, train_mode=False
-        ).item()
+        train_loss = self.process_data_loader(train_loader, train_mode=False).item()
         val_loss = self.process_data_loader(val_loader, train_mode=False).item()
         record_losses(train_loss, val_loss)
 
@@ -581,9 +579,7 @@ def record_losses(train_loss, val_loss):
                 train_loss = self.process_data_loader(
                     train_loader, train_mode=True
                 ).item()
-                val_loss = self.process_data_loader(
-                    val_loader, train_mode=False
-                ).item()
+                val_loss = self.process_data_loader(val_loader, train_mode=False).item()
                 self.scheduler.step(val_loss)
                 record_losses(train_loss, val_loss)
                 self.global_epoch += 1
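
The framework.py hunks above are pure reformatting, but they expose the shape of the training loop: losses are recorded once before any training, then each epoch runs a training pass and a validation pass, and the scheduler steps on the validation loss. Below is a rough, self-contained sketch of that loop shape; the model, loss, optimizer, and the ReduceLROnPlateau choice are assumptions for illustration (the diff only shows scheduler.step(val_loss)), not netam's actual configuration.

# Sketch of the train/validate loop shape visible in framework.py above.
# Everything concrete here (model, loss, optimizer, scheduler choice) is an
# assumption for illustration rather than netam's actual setup.
import torch
from torch import nn


def run_epochs(model, train_loader, val_loader, epochs=3):
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    # scheduler.step(val_loss) in the diff suggests a metric-driven scheduler
    # such as ReduceLROnPlateau.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)

    def process_data_loader(loader, train_mode):
        # One full pass over the loader; gradients only when training.
        model.train(train_mode)
        total, count = 0.0, 0
        with torch.set_grad_enabled(train_mode):
            for x, y in loader:
                loss = loss_fn(model(x), y)
                if train_mode:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                total += loss.item() * len(x)
                count += len(x)
        return total / count

    train_losses, val_losses = [], []
    # Record the initial loss before training, as in the diff.
    train_losses.append(process_data_loader(train_loader, train_mode=False))
    val_losses.append(process_data_loader(val_loader, train_mode=False))
    for _ in range(epochs):
        train_losses.append(process_data_loader(train_loader, train_mode=True))
        val_loss = process_data_loader(val_loader, train_mode=False)
        scheduler.step(val_loss)
        val_losses.append(val_loss)
    return train_losses, val_losses
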
tests/test_dnsm.py: 1 change (1 addition, 0 deletions)
@@ -15,6 +15,7 @@
 
 from multiprocessing import Pool
 
+
 def test_aa_idx_tensor_of_str_ambig():
     input_seq = "ACX"
     expected_output = torch.tensor([0, 1, MAX_AMBIG_AA_IDX], dtype=torch.int)
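
The test visible above expects "ACX" to encode to [0, 1, MAX_AMBIG_AA_IDX]. As a hedged illustration of what such an encoder does, here is a hypothetical re-implementation, assuming an alphabetically sorted 20-letter amino acid alphabet with the ambiguity index equal to its length; this is not netam's actual amino-acid indexing code.

# Hypothetical re-implementation matching the expectation in the test above
# ("ACX" -> [0, 1, MAX_AMBIG_AA_IDX]). The alphabet ordering and ambiguity
# handling are assumptions, not netam's actual implementation.
import torch

AA_STR_SORTED = "ACDEFGHIKLMNPQRSTVWY"
MAX_AMBIG_AA_IDX = len(AA_STR_SORTED)  # index reserved for ambiguous residues like "X"


def aa_idx_tensor_of_str_ambig_sketch(seq):
    # Map each residue to its alphabet index, sending unknown characters to
    # the ambiguity index.
    indices = [
        AA_STR_SORTED.index(aa) if aa in AA_STR_SORTED else MAX_AMBIG_AA_IDX
        for aa in seq
    ]
    return torch.tensor(indices, dtype=torch.int)


assert torch.equal(
    aa_idx_tensor_of_str_ambig_sketch("ACX"),
    torch.tensor([0, 1, MAX_AMBIG_AA_IDX], dtype=torch.int),
)
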
