diff --git a/chi/tests/test_predictive_models.py b/chi/tests/test_predictive_models.py index 61826c16..029a9c3c 100644 --- a/chi/tests/test_predictive_models.py +++ b/chi/tests/test_predictive_models.py @@ -1700,6 +1700,23 @@ def test_sample_bad_input(self): with self.assertRaisesRegex(ValueError, 'The length of parameters'): self.model.sample(parameters, times) + # Raises error when number of covariates and does not match model + seed = 100 + times = [1, 2, 3, 4, 5] + parameters = [1, 1, 1, 1, 1, 1, 1, 0.1, 0.1, 2, 3] + covariates = [1.3, 2.4, 1] + with self.assertRaisesRegex(ValueError, 'Provided covariates do not'): + self.model2.sample( + parameters, times, seed=seed, covariates=covariates) + + # Raises error when the covariates per sample do not match n_samples + n_samples = 3 + covariates = np.ones(shape=(5, 2)) + with self.assertRaisesRegex(ValueError, 'Provided covariates cannot'): + self.model2.sample( + parameters, times, seed=seed, covariates=covariates, + n_samples=n_samples) + class TestPriorPredictiveModel(unittest.TestCase): """