Skip to content

Commit

Permalink
moving split dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
matsen committed Jun 3, 2024
1 parent 2af11e1 commit 96c2595
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions netam/dnsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,16 @@ def clone_with_indices(self, indices):
)
return new_dataset

def split(self, into_count: int):
"""
Split self into a list of into_count subsets.
"""
dataset_size = len(self)
indices = list(range(dataset_size))
split_indices = np.array_split(indices, into_count)
subsets = [self.clone_with_indices(split_indices[i]) for i in range(into_count)]
return subsets

@property
def branch_lengths(self):
return self._branch_lengths
Expand Down Expand Up @@ -376,7 +386,7 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
worker_count = min(mp.cpu_count() // 2, 10)
with mp.Pool(worker_count) as pool:
splits = split_dataset(dataset, worker_count)
splits = dataset.split(worker_count)
results = pool.starmap(
worker_optimize_branch_length,
[(self.model, split, optimization_kwargs) for split in splits],
Expand All @@ -402,17 +412,6 @@ def worker_optimize_branch_length(model, dataset, optimization_kwargs):
return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)


def split_dataset(dataset, into_count):
"""
Split a Dataset into into_count subsets.
"""
dataset_size = len(dataset)
indices = list(range(dataset_size))
split_indices = np.array_split(indices, into_count)
subsets = [dataset.clone_with_indices(split_indices[i]) for i in range(into_count)]
return subsets


### End functions used for parallel branch length optimization.


Expand Down

0 comments on commit 96c2595

Please sign in to comment.