From c7142c67e7842e5532eea3d199f0152a06ec8ff6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 23:50:51 +0100 Subject: [PATCH] fix broadcasting --- skpro/survival/aft/_aft_lifelines_weibull.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/skpro/survival/aft/_aft_lifelines_weibull.py b/skpro/survival/aft/_aft_lifelines_weibull.py index ce85713b..1b9b5f97 100644 --- a/skpro/survival/aft/_aft_lifelines_weibull.py +++ b/skpro/survival/aft/_aft_lifelines_weibull.py @@ -3,6 +3,8 @@ __author__ = ["fkiraly"] +import numpy as np + from skpro.distributions.weibull import Weibull from skpro.survival.adapters.lifelines import _LifelinesAdapter from skpro.survival.base import BaseSurvReg @@ -181,6 +183,8 @@ def _predict_proba(self, X): ll_pred_proba = lifelines_est._prep_inputs_for_prediction_and_return_scores scale, k = ll_pred_proba(df, ancillary) + scale = np.expand_dims(scale, axis=1) + k = np.expand_dims(k, axis=1) dist = Weibull(scale=scale, k=k, index=X.index, columns=self._y_cols) return dist