# parallel branch length optimization #27
This failed because `log_pcp_probability` is a locally defined closure, which the standard pickler cannot serialize, so it can't be shipped to worker processes:

```python
from multiprocessing import Pool

from tqdm import tqdm


def build_input_for_parallel_branch_length_optimization(self, dataset, **optimization_kwargs):
    input = []
    for parent, child, rates, subs_probs, starting_length in tqdm(
        zip(
            dataset.nt_parents,
            dataset.nt_children,
            dataset.all_rates,
            dataset.all_subs_probs,
            dataset.branch_lengths,
        ),
        total=len(dataset.nt_parents),
        desc="Finding optimal branch lengths",
    ):
        # Each of these is a local closure over the model, and thus unpicklable.
        log_pcp_probability = self.wrapped_model._build_log_pcp_probability(
            parent, child, rates[: len(parent)], subs_probs[: len(parent), :]
        )
        input.append((log_pcp_probability, starting_length))
    return input


def test_parallel_branch_length_optimization(dnsm_burrito):
    dataset = dnsm_burrito.val_dataset
    input = dnsm_burrito.build_input_for_parallel_branch_length_optimization(dataset)
    pool = Pool(1)
    branch_lengths = pool.starmap(optimize_branch_length, input)
```
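For the record, this is a general CPython limitation rather than anything specific to the model: locally defined functions can't be pickled at all. A minimal repro, independent of the code above:

```python
import pickle


def make_closure(x):
    def log_prob(y):
        return x + y

    return log_prob


try:
    pickle.dumps(make_closure(1.0))
except Exception as e:
    # "Can't pickle local object 'make_closure.<locals>.log_prob'"
    print(f"pickling failed: {e}")
```

Serializers like `dill` or `cloudpickle` (and the `multiprocess` package, which swaps `dill` in for `pickle`) can handle closures, which might be the path of least resistance here.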
A `clone` method for the burrito met a similar fate:

```python
import copy


def clone(self):
    return self.__class__(
        self.train_dataset.clone(),
        self.val_dataset.clone(),
        copy.deepcopy(self.model),
        batch_size=self.batch_size,
        learning_rate=self.learning_rate,
        min_learning_rate=self.min_learning_rate,
        l2_regularization_coeff=self.l2_regularization_coeff,
        verbose=self.verbose,
        name=self.name,
    )
```
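A small diagnostic in the spirit of the `find_queues` calls below might help pin down exactly which members make the burrito unpicklable. `find_unpicklable_attrs` is a hypothetical helper, not something in the repo:

```python
import pickle


def find_unpicklable_attrs(obj):
    """Report which attributes of obj the standard pickler rejects."""
    bad = []
    for name, value in vars(obj).items():
        try:
            pickle.dumps(value)
        except Exception as e:
            bad.append((name, type(value).__name__, str(e)))
    return bad
```

Run against a burrito copy, this would list the offending members (e.g. any queues or writers) by name.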
And all of this failed because my burrito object is far from picklable, even with the worker reduced to a stub:

```python
import multiprocessing as mp

import numpy as np
import torch


def worker_find_optimal_branch_lengths(burrito, dataset, optimization_kwargs):
    # return burrito.find_optimal_branch_lengths(dataset, **optimization_kwargs)
    return torch.tensor([1.0])


def split_dataset(dataset, into_count):
    """Split a Dataset into into_count subsets."""
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split_indices = np.array_split(indices, into_count)
    return [dataset.clone_with_indices(split_indices[i]) for i in range(into_count)]


def parallel_find_optimal_branch_lengths(burrito, dataset, optimization_kwargs, job_count):
    split_datasets = split_dataset(dataset, job_count)
    burrito_copies = [burrito.clone() for _ in range(job_count)]
    # find_queues is a local debugging helper (not shown).
    print(find_queues(burrito_copies[0]))
    print(find_queues(burrito_copies[1]))
    # TODO we could only make a writer upon request;
    # we could check whether two independent writers can be used with pool.map.
    # tasks = [(burrito_copies[i], split_datasets[i], optimization_kwargs) for i in range(job_count)]
    # results = [worker_find_optimal_branch_lengths(*task) for task in tasks]
    # with mp.Pool(processes=job_count) as pool:
    #     tasks = [(burrito_copies[i], None, {}) for i in range(job_count)]
    #     results = pool.starmap(worker_find_optimal_branch_lengths, tasks)
    # merged_result = torch.cat(results)
    # return merged_result
```
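For what it's worth, the merge step sketched in the comments should be sound once the per-subset work is picklable: `pool.map` preserves input order, so concatenating per-subset results reassembles branch lengths in the original dataset order. A rough sketch, reusing `split_dataset` from above and assuming the subsets themselves pickle cleanly:

```python
import multiprocessing as mp

import torch


def worker(dataset):
    # Stand-in for the real per-subset optimization; here it just echoes
    # the subset's current branch lengths.
    return dataset._branch_lengths


def parallel_branch_lengths(dataset, job_count):
    subsets = split_dataset(dataset, job_count)
    with mp.Pool(processes=job_count) as pool:
        # map preserves input order, so torch.cat restores the original ordering.
        results = pool.map(worker, subsets)
    return torch.cat(results)
```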
Etc. Here are the supporting `DNSMDataset` clone methods:

```python
def clone(self):
    """Make a copy of the dataset."""
    return DNSMDataset(
        self.nt_parents,
        self.nt_children,
        self.all_rates,
        self.all_subs_probs,
        self._branch_lengths,
    )


def clone_with_indices(self, indices):
    """Make a copy of the dataset restricted to the given indices."""
    return DNSMDataset(
        self.nt_parents[indices],
        self.nt_children[indices],
        self.all_rates[indices],
        self.all_subs_probs[indices],
        self._branch_lengths[indices],
    )
```
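One small sanity check on the splitting side: `np.array_split` hands back nearly equal chunks whose concatenation covers every index exactly once, so splitting via `clone_with_indices` and re-merging loses nothing. For example:

```python
import numpy as np

chunks = np.array_split(np.arange(10), 3)
print([len(c) for c in chunks])  # [4, 3, 3]
print(np.concatenate(chunks))    # [0 1 2 3 4 5 6 7 8 9]
```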
Possible plan:
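Hedged as a sketch rather than a worked-out design: never pickle a closure or a burrito at all. Pass only plain per-PCP data to the pool, and have each worker rebuild its own model (and hence its own closures) via a `Pool` initializer. `load_model` and the availability of `optimize_branch_length` at module level are assumptions here:

```python
from multiprocessing import Pool

_worker_model = None


def _init_worker(model_path):
    # Rebuild the model inside each worker from a checkpoint path (a plain
    # string, trivially picklable), so no live model crosses processes.
    global _worker_model
    _worker_model = load_model(model_path)  # hypothetical loader


def _optimize_one(parent, child, rates, subs_probs, starting_length):
    # The closure is built worker-side, so it never needs to be pickled.
    log_pcp_probability = _worker_model._build_log_pcp_probability(
        parent, child, rates[: len(parent)], subs_probs[: len(parent), :]
    )
    return optimize_branch_length(log_pcp_probability, starting_length)


def parallel_optimize(model_path, tasks, job_count):
    # tasks: picklable (parent, child, rates, subs_probs, starting_length) tuples.
    with Pool(processes=job_count, initializer=_init_worker, initargs=(model_path,)) as pool:
        return pool.starmap(_optimize_one, tasks)
```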