Skip to content

Commit

Permalink
Fixed test for 1691 multi-chain. Again.
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelClerx committed Dec 18, 2024
1 parent c4c76a6 commit 0a077db
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions pints/tests/test_mcmc_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,12 +759,12 @@ def test_log_pdf_storage_in_memory_multi(self):
priors = [self.log_prior(x) for x in chain]
self.assertTrue(np.all(evals[i, :, 2] == priors))

# Test with a sensitivity-using method
# Test with a sensitivity-using multi-chain method
# We don't have any of these!
mcmc = pints.MCMCController(
self.log_posterior, n_chains, xs,
method=FakeMultiChainSamplerWithSensitivities)
mcmc.set_max_iterations(5)
mcmc.set_max_iterations(n_iterations)
mcmc.set_log_to_screen(False)
mcmc.set_log_pdf_storage(True)
chains = mcmc.run()
Expand All @@ -773,6 +773,13 @@ def test_log_pdf_storage_in_memory_multi(self):
self.assertEqual(evals.shape[0], n_chains)
self.assertEqual(evals.shape[1], n_iterations)
self.assertEqual(evals.shape[2], 3)
for i, chain in enumerate(chains):
posteriors = [self.log_posterior(x) for x in chain]
self.assertTrue(np.all(evals[i, :, 0] == posteriors))
likelihoods = [self.log_likelihood(x) for x in chain]
self.assertTrue(np.all(evals[i, :, 1] == likelihoods))
priors = [self.log_prior(x) for x in chain]
self.assertTrue(np.all(evals[i, :, 2] == priors))

# Test with a loglikelihood
mcmc = pints.MCMCController(
Expand Down Expand Up @@ -1707,6 +1714,9 @@ def tell(self, fxs):
self._fxs = fxs
return self._xs, self._fxs, [True] * self._n_chains

def needs_sensitivities(self):
return True


class TestMCMCControllerSingleChainStorage(unittest.TestCase):
"""
Expand Down

0 comments on commit 0a077db

Please sign in to comment.