parallel branch length optimization #27

Closed
matsen opened this issue May 31, 2024 · 5 comments · Fixed by #28

matsen commented May 31, 2024

No description provided.

matsen commented Jun 2, 2024

This failed because the per-PCP log-probability closures can't be pickled:

    def build_input_for_parallel_branch_length_optimization(self, dataset, **optimization_kwargs):
        # Build one (log-probability closure, starting length) pair per
        # parent-child pair, for a Pool to map optimize_branch_length over.
        input = []

        for parent, child, rates, subs_probs, starting_length in tqdm(
            zip(
                dataset.nt_parents,
                dataset.nt_children,
                dataset.all_rates,
                dataset.all_subs_probs,
                dataset.branch_lengths,
            ),
            total=len(dataset.nt_parents),
            desc="Finding optimal branch lengths",
        ):
            log_pcp_probability = self.wrapped_model._build_log_pcp_probability(
                parent, child, rates[: len(parent)], subs_probs[: len(parent), :]
            )
            input.append((log_pcp_probability, starting_length))

        return input

    from multiprocessing import Pool


    def test_parallel_branch_length_optimization(dnsm_burrito):
        dataset = dnsm_burrito.val_dataset
        input = dnsm_burrito.build_input_for_parallel_branch_length_optimization(dataset)
        # Pickling fails here: the closures in `input` can't be sent to workers.
        pool = Pool(1)
        branch_lengths = pool.starmap(optimize_branch_length, input)
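
For reference, the failure mode is generic: pickle serializes functions by module-qualified name, so a function object defined inside another function (like the closure returned by _build_log_pcp_probability) can't be serialized at all. A minimal standalone demonstration, with illustrative names, independent of the code above:

    import pickle


    def make_log_prob_fn(scale):
        # A runtime-built closure, analogous to what _build_log_pcp_probability
        # returns: it has no module-level name for pickle to look up.
        def log_prob(x):
            return scale * x

        return log_prob


    try:
        pickle.dumps(make_log_prob_fn(2.0))
    except (AttributeError, pickle.PicklingError) as err:
        # e.g. "Can't pickle local object 'make_log_prob_fn.<locals>.log_prob'"
        print(err)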

matsen commented Jun 2, 2024

Meeting a similar fate:

    def clone(self):
        # Copy the datasets and deep-copy the model into a fresh burrito,
        # so each worker can get an independent copy.
        return self.__class__(
            self.train_dataset.clone(),
            self.val_dataset.clone(),
            copy.deepcopy(self.model),
            batch_size=self.batch_size,
            learning_rate=self.learning_rate,
            min_learning_rate=self.min_learning_rate,
            l2_regularization_coeff=self.l2_regularization_coeff,
            verbose=self.verbose,
            name=self.name,
        )
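
One way to track down failures like this is to probe attributes one at a time instead of pickling the whole object at once; a small diagnostic sketch (not from the codebase, names illustrative):

    import pickle


    def report_unpicklable_attrs(obj):
        """Print each instance attribute that fails to pickle, and why."""
        for name, value in vars(obj).items():
            try:
                pickle.dumps(value)
            except Exception as err:
                print(f"{name}: {type(err).__name__}: {err}")

    # e.g. report_unpicklable_attrs(burrito.clone())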

matsen commented Jun 2, 2024

And all this failed because my burrito object is far from being picklable:

    import multiprocessing as mp

    import numpy as np
    import torch


    def worker_find_optimal_branch_lengths(burrito, dataset, optimization_kwargs):
        # Stubbed to a constant while debugging picklability; the real call would be:
        # return burrito.find_optimal_branch_lengths(dataset, **optimization_kwargs)
        return torch.tensor([1.0])


    def split_dataset(dataset, into_count):
        """Split a Dataset into into_count subsets."""
        dataset_size = len(dataset)
        indices = list(range(dataset_size))
        split_indices = np.array_split(indices, into_count)
        subsets = [dataset.clone_with_indices(split_indices[i]) for i in range(into_count)]
        return subsets


    def parallel_find_optimal_branch_lengths(burrito, dataset, optimization_kwargs, job_count):
        split_datasets = split_dataset(dataset, job_count)

        burrito_copies = [burrito.clone() for _ in range(job_count)]
        # find_queues is a debugging helper (definition not shown).
        print(find_queues(burrito_copies[0]))
        print(find_queues(burrito_copies[1]))

        # TODO we could only make a writer upon request;
        # we could check whether two independent writers can be used with pool.map.

        # tasks = [(burrito_copies[i], split_datasets[i], optimization_kwargs) for i in range(job_count)]
        # results = [worker_find_optimal_branch_lengths(*task) for task in tasks]
        # with mp.Pool(processes=job_count) as pool:
        #     tasks = [(burrito_copies[i], None, {}) for i in range(job_count)]
        #     results = pool.starmap(worker_find_optimal_branch_lengths, tasks)

        # merged_result = torch.cat(results)
        # return merged_result
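
As an aside, the splitting step itself is sound: np.array_split handles counts that don't divide the dataset evenly, so no index is dropped. For example:

    import numpy as np

    # 10 indices split across 3 jobs gives chunk sizes 4, 3, 3.
    print(np.array_split(list(range(10)), 3))
    # [array([0, 1, 2, 3]), array([4, 5, 6]), array([7, 8, 9])]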

matsen commented Jun 2, 2024

Etc...

    def clone(self):
        """Make a new DNSMDataset from this dataset's components."""
        new_dataset = DNSMDataset(
            self.nt_parents,
            self.nt_children,
            self.all_rates,
            self.all_subs_probs,
            self._branch_lengths,
        )
        return new_dataset

    def clone_with_indices(self, indices):
        """Make a new DNSMDataset restricted to the given indices."""
        new_dataset = DNSMDataset(
            self.nt_parents[indices],
            self.nt_children[indices],
            self.all_rates[indices],
            self.all_subs_probs[indices],
            self._branch_lengths[indices],
        )
        return new_dataset

matsen commented Jun 2, 2024

Possible plan (sketched below):

  • make a function that takes in text arguments, builds a burrito, and returns branch lengths
  • pool.starmap over those argument tuples
  • concatenate the resulting branch lengths
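
A minimal sketch of that plan, assuming a hypothetical build_burrito helper and placeholder argument names (none of this is the final API): each worker receives only plain, picklable strings, reconstructs its own burrito, and returns a tensor of branch lengths.

    from multiprocessing import Pool

    import torch


    def branch_lengths_from_args(pcp_csv_path, model_path):
        # Hypothetical worker: the task tuple holds only strings, so it pickles
        # cleanly; the unpicklable burrito is built inside the worker process.
        burrito = build_burrito(pcp_csv_path, model_path)  # placeholder constructor
        return burrito.find_optimal_branch_lengths(burrito.val_dataset)


    def parallel_branch_lengths(arg_tuples, job_count):
        with Pool(processes=job_count) as pool:
            per_job = pool.starmap(branch_lengths_from_args, arg_tuples)
        # Concatenate per-job tensors into one branch-length tensor.
        return torch.cat(per_job)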

@matsen matsen linked a pull request Jun 2, 2024 that will close this issue
@matsen matsen closed this as completed in #28 Jun 5, 2024