Skip to content

Commit

Permalink
fix sample init params when params are fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
DavAug committed Dec 6, 2022
1 parent ce1983b commit 9c8627d
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 1 deletion.
2 changes: 2 additions & 0 deletions chi/_log_pdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,8 @@ def sample_initial_parameters(self, n_samples=1, seed=None):
# (Pooled and heterogen. dimensions do not count as bottom parameters)
dims = []
current_dim = 0
if isinstance(population_model, chi.ReducedPopulationModel):
population_model = population_model.get_population_model()
try:
pop_models = population_model.get_population_models()
except AttributeError:
Expand Down
3 changes: 2 additions & 1 deletion chi/_population_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2966,7 +2966,8 @@ def n_hierarchical_parameters(self, n_ids):
# If parameters have been fixed, updated number of population
# parameters
if self._fixed_params_mask is not None:
n_pop = int(np.sum(self._fixed_params_mask))
n_fixed = int(np.sum(self._fixed_params_mask))
n_pop = self._n_parameters - n_fixed

return (n_indiv, n_pop)

Expand Down
53 changes: 53 additions & 0 deletions chi/tests/test_log_pdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1501,6 +1501,59 @@ def test_sample_initial_parameters(self):
n_samples=n_samples)
self.assertEqual(samples.shape, (10, 3))

# Test fixed population parameters
# Fix population parameters
population_model = chi.ComposedPopulationModel([
chi.CovariatePopulationModel(
chi.GaussianModel(), chi.LinearCovariateModel()),
chi.PooledModel(n_dim=2)])
population_model = chi.ReducedPopulationModel(population_model)
population_model.fix_parameters({'Std. Dim. 1': 1})

# Create hierarchical log-likelihood
log_likelihood = chi.HierarchicalLogLikelihood(
log_likelihoods, population_model, covariates=covariates)

# Define log-prior
log_prior = pints.ComposedLogPrior(
pints.LogNormalLogPrior(1, 1),
pints.LogNormalLogPrior(1, 1),
pints.LogNormalLogPrior(1, 1),
pints.LogNormalLogPrior(1, 1),
pints.LogNormalLogPrior(1, 1))

# Create log-posterior
log_posterior = chi.HierarchicalLogPosterior(
log_likelihood, log_prior)

n_samples = 10
samples = log_posterior.sample_initial_parameters(
n_samples=n_samples)
self.assertEqual(samples.shape, (10, 7))

# Fix all parameters but one
population_model.fix_parameters({
'Mean Dim. 1': 1,
'Mean Dim. 1 Cov. 1': 1,
'Std. Dim. 1 Cov. 1': 1,
'Pooled Dim. 1': 1})

# Create hierarchical log-likelihood
log_likelihood = chi.HierarchicalLogLikelihood(
log_likelihoods, population_model, covariates=covariates)

# Define log-prior
log_prior = pints.LogNormalLogPrior(1, 1)

# Create log-posterior
log_posterior = chi.HierarchicalLogPosterior(
log_likelihood, log_prior)

n_samples = 10
samples = log_posterior.sample_initial_parameters(
n_samples=n_samples)
self.assertEqual(samples.shape, (10, 3))


class TestLogLikelihood(unittest.TestCase):
"""
Expand Down

0 comments on commit 9c8627d

Please sign in to comment.