diff --git a/netam/dasm.py b/netam/dasm.py index 1e3f3f52..0fefb8d3 100644 --- a/netam/dasm.py +++ b/netam/dasm.py @@ -149,7 +149,7 @@ def prediction_pair_of_batch(self, batch): def predictions_of_pair(self, log_neutral_aa_probs, log_selection_factors): # Take the product of the neutral mutation probabilities and the selection factors. # NOTE each of these now have last dimension of 20 - # this is p_{j, a} * f_{j, a} + # this is p_{j, a} * f_{j, a} predictions = torch.exp(log_neutral_aa_probs + log_selection_factors) assert torch.isfinite(predictions).all() predictions = clamp_probability(predictions) @@ -206,4 +206,4 @@ def build_selection_matrix_from_parent(self, parent: str): parent_idxs = sequences.aa_idx_array_of_str(parent) selection_factors[torch.arange(len(parent_idxs)), parent_idxs] = 1.0 - return selection_factors \ No newline at end of file + return selection_factors diff --git a/netam/dnsm.py b/netam/dnsm.py index 23539bd9..c2cac491 100644 --- a/netam/dnsm.py +++ b/netam/dnsm.py @@ -417,7 +417,8 @@ def find_optimal_branch_lengths(self, dataset, **optimization_kwargs): # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs) our_optimize_branch_length = partial( worker_optimize_branch_length, - self.__class__,) + self.__class__, + ) with mp.Pool(worker_count) as pool: splits = dataset.split(worker_count) results = pool.starmap(