diff --git a/skpro/regression/_dummy.py b/skpro/regression/_dummy.py index 6e865613..ad87bd4b 100644 --- a/skpro/regression/_dummy.py +++ b/skpro/regression/_dummy.py @@ -105,13 +105,11 @@ def _predict_proba(self, X): """ X_ind = X.index X_n_rows = X.shape[0] - if self.strategy == "normal": # broadcast the mu and sigma from fit to the length of X - mu = np.ones(X_n_rows) * self._mu - sigma = np.ones(X_n_rows) * self._sigma + mu = np.reshape((np.ones(X_n_rows) * self._mu), (-1, 1)) + sigma = np.reshape((np.ones(X_n_rows) * self._sigma), (-1, 1)) pred_dist = Normal(mu=mu, sigma=sigma, index=X_ind, columns=self._y_columns) - return pred_dist if self.strategy == "empirical":