From 9c8627db9031c0b3e2de56c89f5d5cd72975e91e Mon Sep 17 00:00:00 2001 From: DavAug Date: Tue, 6 Dec 2022 15:45:35 +0100 Subject: [PATCH] fix sample init params when params are fixed --- chi/_log_pdfs.py | 2 ++ chi/_population_models.py | 3 ++- chi/tests/test_log_pdfs.py | 53 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 1 deletion(-) diff --git a/chi/_log_pdfs.py b/chi/_log_pdfs.py index 535fc8c9..24492dcb 100644 --- a/chi/_log_pdfs.py +++ b/chi/_log_pdfs.py @@ -759,6 +759,8 @@ def sample_initial_parameters(self, n_samples=1, seed=None): # (Pooled and heterogen. dimensions do not count as bottom parameters) dims = [] current_dim = 0 + if isinstance(population_model, chi.ReducedPopulationModel): + population_model = population_model.get_population_model() try: pop_models = population_model.get_population_models() except AttributeError: diff --git a/chi/_population_models.py b/chi/_population_models.py index 194394b1..f694fb6a 100644 --- a/chi/_population_models.py +++ b/chi/_population_models.py @@ -2966,7 +2966,8 @@ def n_hierarchical_parameters(self, n_ids): # If parameters have been fixed, updated number of population # parameters if self._fixed_params_mask is not None: - n_pop = int(np.sum(self._fixed_params_mask)) + n_fixed = int(np.sum(self._fixed_params_mask)) + n_pop = self._n_parameters - n_fixed return (n_indiv, n_pop) diff --git a/chi/tests/test_log_pdfs.py b/chi/tests/test_log_pdfs.py index d1d22338..d420c86f 100644 --- a/chi/tests/test_log_pdfs.py +++ b/chi/tests/test_log_pdfs.py @@ -1501,6 +1501,59 @@ def test_sample_initial_parameters(self): n_samples=n_samples) self.assertEqual(samples.shape, (10, 3)) + # Test fixed population parameters + # Fix population parameters + population_model = chi.ComposedPopulationModel([ + chi.CovariatePopulationModel( + chi.GaussianModel(), chi.LinearCovariateModel()), + chi.PooledModel(n_dim=2)]) + population_model = chi.ReducedPopulationModel(population_model) + population_model.fix_parameters({'Std. Dim. 1': 1}) + + # Create hierarchical log-likelihood + log_likelihood = chi.HierarchicalLogLikelihood( + log_likelihoods, population_model, covariates=covariates) + + # Define log-prior + log_prior = pints.ComposedLogPrior( + pints.LogNormalLogPrior(1, 1), + pints.LogNormalLogPrior(1, 1), + pints.LogNormalLogPrior(1, 1), + pints.LogNormalLogPrior(1, 1), + pints.LogNormalLogPrior(1, 1)) + + # Create log-posterior + log_posterior = chi.HierarchicalLogPosterior( + log_likelihood, log_prior) + + n_samples = 10 + samples = log_posterior.sample_initial_parameters( + n_samples=n_samples) + self.assertEqual(samples.shape, (10, 7)) + + # Fix all parameters but one + population_model.fix_parameters({ + 'Mean Dim. 1': 1, + 'Mean Dim. 1 Cov. 1': 1, + 'Std. Dim. 1 Cov. 1': 1, + 'Pooled Dim. 1': 1}) + + # Create hierarchical log-likelihood + log_likelihood = chi.HierarchicalLogLikelihood( + log_likelihoods, population_model, covariates=covariates) + + # Define log-prior + log_prior = pints.LogNormalLogPrior(1, 1) + + # Create log-posterior + log_posterior = chi.HierarchicalLogPosterior( + log_likelihood, log_prior) + + n_samples = 10 + samples = log_posterior.sample_initial_parameters( + n_samples=n_samples) + self.assertEqual(samples.shape, (10, 3)) + class TestLogLikelihood(unittest.TestCase): """