Skip to content

Commit

Permalink
Merge pull request #978 from pints-team/i975-multivariate-gaussian-se…
Browse files Browse the repository at this point in the history
…nsitivities

Multivariate Gaussian sensitivities
  • Loading branch information
MichaelClerx authored Oct 1, 2019
2 parents 84d67ae + 94c429c commit d2c6b48
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 9 deletions.
5 changes: 5 additions & 0 deletions pints/_log_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,12 +660,17 @@ def __init__(self, mean, cov):
self._mean = mean
self._cov = cov
self._n_parameters = mean.shape[0]
self._cov_inverse = np.linalg.inv(self._cov)

def __call__(self, x):
return np.log(
scipy.stats.multivariate_normal.pdf(
x, mean=self._mean, cov=self._cov))

def evaluateS1(self, x):
""" See :meth:`LogPDF.evaluateS1()`. """
return self(x), -np.matmul(self._cov_inverse, x - self._mean)

def mean(self):
""" See :meth:`LogPrior.mean()`. """
return self._mean
Expand Down
43 changes: 34 additions & 9 deletions pints/tests/test_log_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,20 +590,15 @@ def test_log_normal_prior_sampling(self):
self.assertTrue(0.9 * analyt_mean < sample_mean < 1.1 * analyt_mean)

def test_multivariate_normal_prior(self):
# 1d test
mean = 0
covariance = 1

# Input must be a matrix
self.assertRaises(
ValueError, pints.MultivariateGaussianLogPrior, mean, covariance)
covariance = [1]
ValueError, pints.MultivariateGaussianLogPrior, 0, 1)
self.assertRaises(
ValueError, pints.MultivariateGaussianLogPrior, mean, covariance)
ValueError, pints.MultivariateGaussianLogPrior, 0, [1])

# Basic test
covariance = [[1]]
p = pints.MultivariateGaussianLogPrior(mean, covariance)
# 1d test
p = pints.MultivariateGaussianLogPrior(0, [[1]])
self.assertEqual(p([0]), -0.5 * np.log(2 * np.pi))

# 5d tests
Expand All @@ -623,6 +618,36 @@ def test_multivariate_normal_prior(self):
ValueError, pints.MultivariateGaussianLogPrior, [1, 2],
[[1, 0, 0], [0, 1, 0], [0, 0, 1]])

# Test sensitivities
mean = [1, 3]
covariance = [[2, 0.5], [0.5, 2]]
p = pints.MultivariateGaussianLogPrior(mean, covariance)
y, dy = p.evaluateS1([4, 5])
self.assertEqual(len(dy), 2)
self.assertAlmostEqual(y, -5.165421653067172, places=6)
dy_test = [-float(4 / 3), -float(2 / 3)]
self.assertAlmostEqual(dy[0], dy_test[0], places=6)
self.assertAlmostEqual(dy[1], dy_test[1], places=6)

mean = [-5.5, 6.7, 3.2]
covariance = [[3.4, -0.5, -0.7], [-0.5, 2.7, 1.4], [-0.7, 1.4, 5]]
p = pints.MultivariateGaussianLogPrior(mean, covariance)
y, dy = p.evaluateS1([4.4, 3.5, -3])
self.assertEqual(len(dy), 3)
self.assertAlmostEqual(y, -20.855279298674258, places=6)
dy_test = [-2.709773397444412, 0.27739553170576203, 0.7829609754801692]
self.assertAlmostEqual(dy[0], dy_test[0], places=6)
self.assertAlmostEqual(dy[1], dy_test[1], places=6)
self.assertAlmostEqual(dy[2], dy_test[2], places=6)

# 1d sensitivity test
p = pints.MultivariateGaussianLogPrior(0, [[1]])
x = [0]
y, dy = p.evaluateS1(x)
self.assertEqual(y, p(x))
self.assertTrue(len(dy), 1)
self.assertEqual(dy[0], 0)

def test_multivariate_normal_sampling(self):
d = 1
mean = 2
Expand Down

0 comments on commit d2c6b48

Please sign in to comment.