Skip to content

Commit

Permalink
make format
Browse files Browse the repository at this point in the history
  • Loading branch information
matsen committed Oct 3, 2024
1 parent bf44d5e commit cac0ba8
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
4 changes: 2 additions & 2 deletions netam/dasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def prediction_pair_of_batch(self, batch):
def predictions_of_pair(self, log_neutral_aa_probs, log_selection_factors):
# Take the product of the neutral mutation probabilities and the selection factors.
# NOTE each of these now have last dimension of 20
# this is p_{j, a} * f_{j, a}
# this is p_{j, a} * f_{j, a}
predictions = torch.exp(log_neutral_aa_probs + log_selection_factors)
assert torch.isfinite(predictions).all()
predictions = clamp_probability(predictions)
Expand Down Expand Up @@ -206,4 +206,4 @@ def build_selection_matrix_from_parent(self, parent: str):
parent_idxs = sequences.aa_idx_array_of_str(parent)
selection_factors[torch.arange(len(parent_idxs)), parent_idxs] = 1.0

return selection_factors
return selection_factors
3 changes: 2 additions & 1 deletion netam/dnsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,8 @@ def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
# return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
our_optimize_branch_length = partial(
worker_optimize_branch_length,
self.__class__,)
self.__class__,
)
with mp.Pool(worker_count) as pool:
splits = dataset.split(worker_count)
results = pool.starmap(
Expand Down

0 comments on commit cac0ba8

Please sign in to comment.