diff --git a/netam/dnsm.py b/netam/dnsm.py
index b527a25a..4b46775f 100644
--- a/netam/dnsm.py
+++ b/netam/dnsm.py
@@ -144,6 +144,22 @@ def clone_with_indices(self, indices):
         )
         return new_dataset
 
+    def split(self, into_count: int):
+        """
+        Split self into a list of into_count subsets.
+
+        Indices 0..len(self)-1 are cut into contiguous runs by np.array_split,
+        so when len(self) is not divisible by into_count the leading subsets
+        hold one more item than the trailing ones, and subsets are empty when
+        into_count exceeds len(self). np.array_split rejects into_count <= 0.
+        """
+        # np.arange keeps an integer dtype even when the dataset is empty.
+        indices = np.arange(len(self))
+        return [
+            self.clone_with_indices(chunk)
+            for chunk in np.array_split(indices, into_count)
+        ]
+
     @property
     def branch_lengths(self):
         return self._branch_lengths
@@ -376,7 +392,7 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
     def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
         worker_count = min(mp.cpu_count() // 2, 10)
         with mp.Pool(worker_count) as pool:
-            splits = split_dataset(dataset, worker_count)
+            splits = dataset.split(worker_count)
             results = pool.starmap(
                 worker_optimize_branch_length,
                 [(self.model, split, optimization_kwargs) for split in splits],
@@ -402,17 +418,6 @@ def worker_optimize_branch_length(model, dataset, optimization_kwargs):
     return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
 
 
-def split_dataset(dataset, into_count):
-    """
-    Split a Dataset into into_count subsets.
-    """
-    dataset_size = len(dataset)
-    indices = list(range(dataset_size))
-    split_indices = np.array_split(indices, into_count)
-    subsets = [dataset.clone_with_indices(split_indices[i]) for i in range(into_count)]
-    return subsets
-
-
 ### End functions used for parallel branch length optimization.
 
 