diff --git a/netam/framework.py b/netam/framework.py index 547c42d3..c8a8afe9 100644 --- a/netam/framework.py +++ b/netam/framework.py @@ -775,8 +775,8 @@ def joint_train( ) self.reset_optimization(new_lr) loss_history_l.append(self.train(epochs, out_prefix=out_prefix)) - if cycle < cycle_count - 1: - optimize_branch_lengths() + # We standardize and optimize the branch lengths after each cycle, even the last one. + optimize_branch_lengths() self.mark_branch_lengths_optimized(cycle + 1) return pd.concat(loss_history_l, ignore_index=True) @@ -932,7 +932,7 @@ def _find_optimal_branch_length( **optimization_kwargs, ): if torch.sum(mutation_indicator) == 0: - return 0.0 + return 0.0, False rates, _ = self.model( encoded_parent.unsqueeze(0),