From d44a60831c0bb1f52e61d4ae6a7bc16a24597740 Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Wed, 2 Oct 2024 13:12:03 -0700 Subject: [PATCH] multihit works with threading --- netam/dnsm.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/netam/dnsm.py b/netam/dnsm.py index 4d4af44e..bd81cc88 100644 --- a/netam/dnsm.py +++ b/netam/dnsm.py @@ -55,7 +55,10 @@ def __init__( self.nt_children = nt_children self.all_rates = all_rates self.all_subs_probs = all_subs_probs - self.multihit_model = multihit_model + self.multihit_model = copy.deepcopy(multihit_model) + if multihit_model is not None: + # We want these parameters to act like fixed data + self.multihit_model.values.requires_grad_(False) assert len(self.nt_parents) == len(self.nt_children) pcp_count = len(self.nt_parents) @@ -147,7 +150,7 @@ def clone(self): self.all_rates.copy(), self.all_subs_probs.copy(), self._branch_lengths.copy(), - multihit_model=self.multihit_model, + multihit_model=copy.deepcopy(self.multihit_model), ) return new_dataset @@ -164,7 +167,7 @@ def subset_via_indices(self, indices): self.all_rates[indices], self.all_subs_probs[indices], self._branch_lengths[indices], - multihit_model=self.multihit_model, + multihit_model=copy.deepcopy(self.multihit_model), ) return new_dataset @@ -426,16 +429,16 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs): def find_optimal_branch_lengths(self, dataset, **optimization_kwargs): worker_count = min(mp.cpu_count() // 2, 10) - # The following can be used when one wants a better traceback. - burrito = DNSMBurrito(None, dataset, copy.deepcopy(self.model)) - return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs) - # with mp.Pool(worker_count) as pool: - # splits = dataset.split(worker_count) - # results = pool.starmap( - # worker_optimize_branch_length, - # [(self.model, split, optimization_kwargs) for split in splits], - # ) - # return torch.cat(results) + # # The following can be used when one wants a better traceback. + # burrito = DNSMBurrito(None, dataset, copy.deepcopy(self.model)) + # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs) + with mp.Pool(worker_count) as pool: + splits = dataset.split(worker_count) + results = pool.starmap( + worker_optimize_branch_length, + [(self.model, split, optimization_kwargs) for split in splits], + ) + return torch.cat(results) def to_crepe(self): training_hyperparameters = {