From eecccbd467a71dace7eee8958230b0356f55ddac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 00:25:15 +0100 Subject: [PATCH 01/33] lifelines --- skpro/survival/adapters/lifelines.py | 153 +++++++++++++++++++++++++++ skpro/survival/adapters/sksurv.py | 3 +- 2 files changed, 154 insertions(+), 2 deletions(-) create mode 100644 skpro/survival/adapters/lifelines.py diff --git a/skpro/survival/adapters/lifelines.py b/skpro/survival/adapters/lifelines.py new file mode 100644 index 00000000..751cea18 --- /dev/null +++ b/skpro/survival/adapters/lifelines.py @@ -0,0 +1,153 @@ +# copyright: sktime developers, BSD-3-Clause License (see LICENSE file) +"""Implements adapter for lifelines models.""" + +__all__ = ["_LifelinesAdapter"] +__author__ = ["fkiraly"] + +import numpy as np +import pandas as pd + +from skpro.distributions.empirical import Empirical +from skpro.utils.sklearn import prep_skl_df + + +class _LifelinesAdapter: + """Mixin adapter class for lifelines models.""" + + _tags = { + # packaging info + # -------------- + "authors": ["fkiraly"], + "python_dependencies": ["lifelines"], + "license_type": "permissive", + # capability tags + # --------------- + "X_inner_mtype": "pd_DataFrame_Table", + "y_inner_mtype": "pd_DataFrame_Table", + "C_inner_mtype": "pd_DataFrame_Table", + "capability:multioutput": False, + } + + # defines the name of the attribute containing the lifelines estimator + _estimator_attr = "_estimator" + + def _get_lifelines_class(self): + """Abstract method to get lifelines class. + + should import and return lifelines class + """ + # from lifelines import lifelinesClass + # + # return lifelines + raise NotImplementedError("abstract method") + + def _get_lifelines_object(self): + """Abstract method to initialize lifelines object. + + The default initializes result of _get_lifelines_class + with self.get_params. 
+ """ + cls = self._get_lifelines_class() + return cls(**self.get_params()) + + def _init_lifelines_object(self): + """Abstract method to initialize lifelines object and set to _estimator_attr. + + The default writes the return of _get_lifelines_object to + the attribute of self with name _estimator_attr + """ + cls = self._get_lifelines_object() + setattr(self, self._estimator_attr, cls) + return getattr(self, self._estimator_attr) + + def _fit(self, X, y, C=None): + """Fit estimator training data. + + Parameters + ---------- + X : pd.DataFrame + Training features + y: pd.Series + Training labels + C: pd.Series, optional (default=None) + Censoring information for survival analysis. + + Returns + ------- + self: reference to self + Fitted estimator. + """ + lifelines_est = self._init_lifelines_object() + + # input conversion + X = X.astype("float") # lifelines insists on float dtype + X = prep_skl_df(X) + + to_concat = [X, y] + + if C is not None: + C_col = 1 - C.copy() # lifelines uses 1 for uncensored, 0 for censored + C_col.columns = ["__C"] + to_concat.append(C_col) + + df = pd.concat(to_concat, axis=1) + + self._y_cols = y.columns # remember column names for later + y_name = y.columns[0] + + + fit_args = { + "df": df, + "duration_col": y_name, + } + if C is not None: + fit_args["event_col"] = "__C" + + # fit lifelines estimator + lifelines_est.fit(**fit_args) + + # write fitted params to self + lifelines_fitted_params = self._get_fitted_params_default(lifelines_est) + for k, v in lifelines_fitted_params.items(): + setattr(self, f"{k}_", v) + + return self + + def _predict_proba(self, X): + """Predict_proba method adapter. + + Parameters + ---------- + X : pd.DataFrame + Features to predict on. 
+ + Returns + ------- + skpro Empirical distribution + """ + lifelines_est = getattr(self, self._estimator_attr) + + # input conversion + X = X.astype("float") # lifelines insists on float dtype + X = prep_skl_df(X) + + # predict on X + lifelines_survf = lifelines_est.predict_survival_function(X) + + times = lifelines_survf.index[:-1] + + nt = len(times) + mi = pd.MultiIndex.from_product([X.index, range(nt)]).swaplevel() + + times_val = np.repeat(times, repeats=len(X)) + times_df = pd.DataFrame(times_val, index=mi, columns=self._y_cols) + + lifelines_survf_t = np.transpose(lifelines_survf.values) + weights = -np.diff(lifelines_survf, axis=1).flatten() + weights_df = pd.Series(weights, index=mi) + + dist = Empirical( + spl=times_df, weights=weights_df, index=X.index, columns=self._y_cols + ) + + return dist diff --git a/skpro/survival/adapters/sksurv.py b/skpro/survival/adapters/sksurv.py index 2c26aec5..7f39ae17 100644 --- a/skpro/survival/adapters/sksurv.py +++ b/skpro/survival/adapters/sksurv.py @@ -160,8 +160,7 @@ def _predict_proba(self, X): Returns ------- - np.ndarray (1d array of shape (n_instances,)) - Index of the cluster each time series in X belongs to. 
+ skpro Empirical distribution """ sksurv_est = getattr(self, self._estimator_attr) From 429ffc4187ec161d49ae9f93744f21364414b28e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 00:40:43 +0100 Subject: [PATCH 02/33] lifelines adapter --- skpro/survival/adapters/lifelines.py | 2 +- skpro/survival/additive/__init__.py | 3 + skpro/survival/additive/_aalen_lifelines.py | 103 ++++++++++++++++++++ 3 files changed, 107 insertions(+), 1 deletion(-) create mode 100644 skpro/survival/additive/__init__.py create mode 100644 skpro/survival/additive/_aalen_lifelines.py diff --git a/skpro/survival/adapters/lifelines.py b/skpro/survival/adapters/lifelines.py index 751cea18..06dedb6f 100644 --- a/skpro/survival/adapters/lifelines.py +++ b/skpro/survival/adapters/lifelines.py @@ -143,7 +143,7 @@ def _predict_proba(self, X): times_df = pd.DataFrame(times_val, index=mi, columns=self._y_cols) lifelines_survf_t = np.transpose(lifelines_survf.values) - weights = -np.diff(lifelines_survf, axis=1).flatten() + weights = -np.diff(lifelines_survf_t, axis=1).flatten() weights_df = pd.Series(weights, index=mi) dist = Empirical( diff --git a/skpro/survival/additive/__init__.py b/skpro/survival/additive/__init__.py new file mode 100644 index 00000000..99ef76e7 --- /dev/null +++ b/skpro/survival/additive/__init__.py @@ -0,0 +1,3 @@ +"""Generalized additive survival models.""" + +from skpro.additive._aalen_lifelines import AalenAdditiveLifelines diff --git a/skpro/survival/additive/_aalen_lifelines.py b/skpro/survival/additive/_aalen_lifelines.py new file mode 100644 index 00000000..8c96965f --- /dev/null +++ b/skpro/survival/additive/_aalen_lifelines.py @@ -0,0 +1,103 @@ +"""Interface adapter to lifelines Aalen additive surival model.""" +# copyright: skpro developers, BSD-3-Clause License (see LICENSE file) + +__author__ = ["fkiraly"] + +from skpro.survival.adapters.lifelines import _LifelinesAdapter +from skpro.survival.base import BaseSurvReg + + +class 
AalenAdditive(_LifelinesAdapter, BaseSurvReg): + r"""Aalen additive hazards model, from lifelines. + + Direct interface to ``lifelines.fitters.AalenAdditiveFitter``, + by ``CamDavidsonPilon``. + + This class fits the regression model: + + .. math:: h(t|x) = b_0(t) + b_1(t) x_1 + ... + b_N(t) x_N + + that is, the hazard rate is a linear function of the covariates + with time-varying coefficients. + This implementation assumes non-time-varying covariates. + + Parameters + ----------- + fit_intercept: bool, optional (default: True) + If False, do not attach an intercept (column of ones) to the covariate matrix. + The intercept, :math:`b_0(t)` acts as a baseline hazard. + alpha: float, optional (default=0.05) + the level in the confidence intervals around the estimated survival function, + for computation of ``confidence_intervals_`` fitted parameter. + coef_penalizer: float, optional (default: 0) + Attach a L2 penalizer to the size of the coefficients during regression. + This improves + stability of the estimates and controls for high correlation between covariates. + For example, this shrinks the magnitude of :math:`c_{i,t}`. + smoothing_penalizer: float, optional (default: 0) + Attach a L2 penalizer to difference between adjacent (over time) coefficients. + For example, this shrinks the magnitude of :math:`c_{i,t} - c_{i,t+1}`. 
+ + Attributes + ---------- + cumulative_hazards_ : DataFrame + The estimated cumulative hazard + hazards_ : DataFrame + The estimated hazards + confidence_intervals_ : DataFrame + The lower and upper confidence intervals for the cumulative hazard + durations: array + The durations provided + """ + + _tags = {"authors": ["CamDavidsonPilon", "rocreguant", "fkiraly"]} + # CamDavidsonPilon credit for interfaced estimator + + def __init__( + self, + fit_intercept=True, + alpha=0.05, + coef_penalizer=0.0, + smoothing_penalizer=0.0, + ): + self.fit_intercept = fit_intercept + self.alpha = alpha + self.coef_penalizer = coef_penalizer + self.smoothing_penalizer = smoothing_penalizer + + super().__init__() + + def _get_lifelines_class(self): + """Getter of the lifelines class to be used for the adapter.""" + from lifelines.fitters.aalen_additive_fitter import AalenAdditiveFitter + + return AalenAdditiveFitter + + @classmethod + def get_test_params(cls, parameter_set="default"): + """Return testing parameter settings for the estimator. + + Parameters + ---------- + parameter_set : str, default="default" + Name of the set of test parameters to return, for use in tests. If no + special parameters are defined for a value, will return `"default"` set. + + Returns + ------- + params : dict or list of dict, default = {} + Parameters to create testing instances of the class + Each dict are parameters to construct an "interesting" test instance, i.e., + `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. 
+ `create_test_instance` uses the first (or only) dictionary in `params` + """ + params1 = {} + + params2 = { + "fit_intercept": False, + "alpha": 0.1, + "coef_penalizer": 0.1, + "smoothing_penalizer": 0.1, + } + + return [params1, params2] From 2923254c0f1021a99d44ae08fa894bc1c3366935 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 00:42:12 +0100 Subject: [PATCH 03/33] docs, export --- docs/source/api_reference/survival.rst | 11 +++++++++++ skpro/survival/additive/__init__.py | 2 ++ 2 files changed, 13 insertions(+) diff --git a/docs/source/api_reference/survival.rst b/docs/source/api_reference/survival.rst index d85298cd..afc42869 100644 --- a/docs/source/api_reference/survival.rst +++ b/docs/source/api_reference/survival.rst @@ -89,6 +89,17 @@ Proportional hazards models CoxPHSkSurv CoxNet +Generalized additive survival models +------------------------------------ + +.. currentmodule:: skpro.survival.additive + +.. autosummary:: + :toctree: auto_generated/ + :template: class.rst + + AalenAdditiveLifelines + Tree models ----------- diff --git a/skpro/survival/additive/__init__.py b/skpro/survival/additive/__init__.py index 99ef76e7..988ff8c2 100644 --- a/skpro/survival/additive/__init__.py +++ b/skpro/survival/additive/__init__.py @@ -1,3 +1,5 @@ """Generalized additive survival models.""" +__all__ = ["AalenAdditiveLifelines"] + from skpro.additive._aalen_lifelines import AalenAdditiveLifelines From 54a875e24ed5a0ba7ef811a6f0fcf54f2e8b20ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 00:46:49 +0100 Subject: [PATCH 04/33] pyproject --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index c0dab3e4..bd384ca3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ dependencies = [ all_extras = [ "attrs", "distfit", + "lifelines<0.29.0", "mapie", "matplotlib>=3.3.2", "ngboost", From ddb8bbe1cf1b9fccf1e6a1e87b4c7410785cacae Mon Sep 
17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 00:47:52 +0100 Subject: [PATCH 05/33] documentation --- skpro/survival/adapters/lifelines.py | 4 ++-- skpro/survival/adapters/sksurv.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/skpro/survival/adapters/lifelines.py b/skpro/survival/adapters/lifelines.py index 06dedb6f..49ff9318 100644 --- a/skpro/survival/adapters/lifelines.py +++ b/skpro/survival/adapters/lifelines.py @@ -36,9 +36,9 @@ def _get_lifelines_class(self): should import and return lifelines class """ - # from lifelines import lifelinesClass + # from lifelines import LifelinesClass # - # return lifelines + # return LifelinesClass raise NotImplementedError("abstract method") def _get_lifelines_object(self): diff --git a/skpro/survival/adapters/sksurv.py b/skpro/survival/adapters/sksurv.py index 7f39ae17..a9b73f90 100644 --- a/skpro/survival/adapters/sksurv.py +++ b/skpro/survival/adapters/sksurv.py @@ -37,9 +37,9 @@ def _get_sksurv_class(self): should import and return sksurv class """ - # from sksurv import sksurvClass + # from sksurv import SksurvClass # - # return sksurv + # return SksurvClass raise NotImplementedError("abstract method") def _get_sksurv_object(self): From 5b26c03db3c55a369fb6a1ed50efd1d16c63c154 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 00:48:21 +0100 Subject: [PATCH 06/33] Update __init__.py --- skpro/survival/additive/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skpro/survival/additive/__init__.py b/skpro/survival/additive/__init__.py index 988ff8c2..54587fad 100644 --- a/skpro/survival/additive/__init__.py +++ b/skpro/survival/additive/__init__.py @@ -2,4 +2,4 @@ __all__ = ["AalenAdditiveLifelines"] -from skpro.additive._aalen_lifelines import AalenAdditiveLifelines +from skpro.survival.additive._aalen_lifelines import AalenAdditiveLifelines From ed97ac61bcff5850fc7bfe71c2f542eae67fb5bf Mon Sep 17 
00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 00:53:37 +0100 Subject: [PATCH 07/33] Update __init__.py --- skpro/survival/additive/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skpro/survival/additive/__init__.py b/skpro/survival/additive/__init__.py index 54587fad..bc5f7f52 100644 --- a/skpro/survival/additive/__init__.py +++ b/skpro/survival/additive/__init__.py @@ -1,5 +1,5 @@ """Generalized additive survival models.""" -__all__ = ["AalenAdditiveLifelines"] +__all__ = ["AalenAdditive"] -from skpro.survival.additive._aalen_lifelines import AalenAdditiveLifelines +from skpro.survival.additive._aalen_lifelines import AalenAdditive From 11fc89c9a76d37d68d65537f7544bb56f6f3cb0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 00:54:07 +0100 Subject: [PATCH 08/33] Update _aalen_lifelines.py --- skpro/survival/additive/_aalen_lifelines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skpro/survival/additive/_aalen_lifelines.py b/skpro/survival/additive/_aalen_lifelines.py index 8c96965f..6b9781fc 100644 --- a/skpro/survival/additive/_aalen_lifelines.py +++ b/skpro/survival/additive/_aalen_lifelines.py @@ -51,7 +51,7 @@ class AalenAdditive(_LifelinesAdapter, BaseSurvReg): """ _tags = {"authors": ["CamDavidsonPilon", "rocreguant", "fkiraly"]} - # CamDavidsonPilon credit for interfaced estimator + # CamDavidsonPilon, rocreguant credit for interfaced estimator def __init__( self, From ba78ad6d6ca942e099a35f35f72107f50ccb3002 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 08:24:47 +0100 Subject: [PATCH 09/33] linting --- skpro/survival/adapters/lifelines.py | 1 - skpro/survival/additive/_aalen_lifelines.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/skpro/survival/adapters/lifelines.py b/skpro/survival/adapters/lifelines.py index 49ff9318..744d9c70 100644 --- 
a/skpro/survival/adapters/lifelines.py +++ b/skpro/survival/adapters/lifelines.py @@ -95,7 +95,6 @@ def _fit(self, X, y, C=None): self._y_cols = y.columns # remember column names for later y_name = y.columns[0] - fit_args = { "df": df, "duration_col": y_name, diff --git a/skpro/survival/additive/_aalen_lifelines.py b/skpro/survival/additive/_aalen_lifelines.py index 6b9781fc..210d1712 100644 --- a/skpro/survival/additive/_aalen_lifelines.py +++ b/skpro/survival/additive/_aalen_lifelines.py @@ -12,7 +12,7 @@ class AalenAdditive(_LifelinesAdapter, BaseSurvReg): Direct interface to ``lifelines.fitters.AalenAdditiveFitter``, by ``CamDavidsonPilon``. - + This class fits the regression model: .. math:: h(t|x) = b_0(t) + b_1(t) x_1 + ... + b_N(t) x_N @@ -22,7 +22,7 @@ class AalenAdditive(_LifelinesAdapter, BaseSurvReg): This implementation assumes non-time-varying covariates. Parameters - ----------- + ---------- fit_intercept: bool, optional (default: True) If False, do not attach an intercept (column of ones) to the covariate matrix. The intercept, :math:`b_0(t)` acts as a baseline hazard. 
From 533599989c333a3eaa0d2fbf838d9dd7ba78099f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 12:12:57 +0100 Subject: [PATCH 10/33] fix improper surv functions via clipping --- skpro/survival/adapters/lifelines.py | 67 +++++++++++++++++++++++++++- 1 file changed, 66 insertions(+), 1 deletion(-) diff --git a/skpro/survival/adapters/lifelines.py b/skpro/survival/adapters/lifelines.py index 744d9c70..395d16b4 100644 --- a/skpro/survival/adapters/lifelines.py +++ b/skpro/survival/adapters/lifelines.py @@ -4,6 +4,8 @@ __all__ = ["_LifelinesAdapter"] __author__ = ["fkiraly"] +from warnings import warn + import numpy as np import pandas as pd @@ -142,7 +144,20 @@ def _predict_proba(self, X): times_df = pd.DataFrame(times_val, index=mi, columns=self._y_cols) lifelines_survf_t = np.transpose(lifelines_survf.values) - weights = -np.diff(lifelines_survf_t, axis=1).flatten() + _, lifelines_survf_t_diff, clipped = _clip_surv(lifelines_survf_t) + + if clipped: + warn( + f"Warning from {self.__class__.__name__}: " + f"Interfaced lifelines class {lifelines_est.__class__.__name__} " + "produced improper survival function predictions, i.e., " + "not monotonically decreasing or not in [0, 1]. " + "skpro has clipped the predictions to enforce proper range and " + "valid predictive distributions. " + "However, predictions may still be unreliable." + ) + + weights = -lifelines_survf_t_diff.flatten() weights_df = pd.Series(weights, index=mi) dist = Empirical( @@ -150,3 +165,53 @@ def _predict_proba(self, X): ) return dist + + +def _clip_surv(surv_arr): + """Clips improper survival function values to proper range. + + Enforces: values are in [0, 1] and are monotonically decreasing. + + First clips to [0, 1], then enforces monotonicity, by replacing + any value with minimum of itself and any previous values. + + Parameters + ---------- + surv_arr : 2D np.ndarray + Survival function values. + index 0 is instance index. 
+ index 1 is time index, increasing. + + Returns + ------- + surv_arr_clipped : 2D np.ndarray + Clipped survival function values. + surv_arr_diff : 2D np.ndarray + Difference of clipped survival function values. + Same as np.diff(surv_arr_clipped, axis=1). + Returned to avoid recomputation, if needed later in context. + clipped : boolean + Whether clipping was needed. + """ + too_large = surv_arr > 1 + too_small = surv_arr < 0 + + surv_arr[too_large] = 1 + surv_arr[too_small] = 0 + + surv_arr_diff = np.diff(surv_arr, axis=1) + + # avoid iterative minimization if no further clipping is needed + if np.sum(surv_arr_diff == 0): + clipped = too_large.any() or too_small.any() + return surv_arr, surv_arr_diff, clipped + + # enforce monotonicity + # iterating from left to right ensures values are replaced + # with minimum of itself and all values to the left + for i in range(1, surv_arr.shape[1]): + surv_arr[:, i] = np.minimum(surv_arr[:, i], surv_arr[:, i - 1]) + + surv_arr_diff = np.diff(surv_arr, axis=1) + + return surv_arr, surv_arr_diff, True From 7fcbb3a826d7f01287ceef45e5d2a95b44ef0b69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 12:15:14 +0100 Subject: [PATCH 11/33] Update lifelines.py --- skpro/survival/adapters/lifelines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skpro/survival/adapters/lifelines.py b/skpro/survival/adapters/lifelines.py index 395d16b4..691bd5b5 100644 --- a/skpro/survival/adapters/lifelines.py +++ b/skpro/survival/adapters/lifelines.py @@ -202,7 +202,7 @@ def _clip_surv(surv_arr): surv_arr_diff = np.diff(surv_arr, axis=1) # avoid iterative minimization if no further clipping is needed - if np.sum(surv_arr_diff == 0): + if np.sum(surv_arr_diff > 0) == 0: clipped = too_large.any() or too_small.any() return surv_arr, surv_arr_diff, clipped From 4727a679814014fcd5abbbe410796c623592f65d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 12:15:28 
+0100 Subject: [PATCH 12/33] Update lifelines.py --- skpro/survival/adapters/lifelines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skpro/survival/adapters/lifelines.py b/skpro/survival/adapters/lifelines.py index 691bd5b5..2ab4f466 100644 --- a/skpro/survival/adapters/lifelines.py +++ b/skpro/survival/adapters/lifelines.py @@ -202,7 +202,7 @@ def _clip_surv(surv_arr): surv_arr_diff = np.diff(surv_arr, axis=1) # avoid iterative minimization if no further clipping is needed - if np.sum(surv_arr_diff > 0) == 0: + if not (surv_arr_diff > 0).any(): clipped = too_large.any() or too_small.any() return surv_arr, surv_arr_diff, clipped From b9b364768b58bb3f27e6b3b1aca8adffe5f28810 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 12:31:07 +0100 Subject: [PATCH 13/33] deal with all zeroes --- skpro/survival/adapters/lifelines.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/skpro/survival/adapters/lifelines.py b/skpro/survival/adapters/lifelines.py index 2ab4f466..7ee3a33c 100644 --- a/skpro/survival/adapters/lifelines.py +++ b/skpro/survival/adapters/lifelines.py @@ -135,7 +135,7 @@ def _predict_proba(self, X): # predict on X lifelines_survf = lifelines_est.predict_survival_function(X) - times = lifelines_survf.index[:-1] + times = lifelines_survf.index nt = len(times) mi = pd.MultiIndex.from_product([X.index, range(nt)]).swaplevel() @@ -188,7 +188,7 @@ def _clip_surv(surv_arr): Clipped survival function values. surv_arr_diff : 2D np.ndarray Difference of clipped survival function values. - Same as np.diff(surv_arr_clipped, axis=1). + Same as np.diff(surv_arr_clipped, axis=1, prepend=1). Returned to avoid recomputation, if needed later in context. clipped : boolean Whether clipping was needed. 
@@ -199,7 +199,7 @@ def _clip_surv(surv_arr): surv_arr[too_large] = 1 surv_arr[too_small] = 0 - surv_arr_diff = np.diff(surv_arr, axis=1) + surv_arr_diff = np.diff(surv_arr, axis=1, prepend=1) # avoid iterative minimization if no further clipping is needed if not (surv_arr_diff > 0).any(): @@ -212,6 +212,6 @@ def _clip_surv(surv_arr): for i in range(1, surv_arr.shape[1]): surv_arr[:, i] = np.minimum(surv_arr[:, i], surv_arr[:, i - 1]) - surv_arr_diff = np.diff(surv_arr, axis=1) + surv_arr_diff = np.diff(surv_arr, axis=1, prepend=1) return surv_arr, surv_arr_diff, True From ef787c35941014eff3b369d90f5adbc84577dad5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 12:39:11 +0100 Subject: [PATCH 14/33] move utils to common module --- skpro/survival/adapters/_common.py | 78 ++++++++++++++++++++++++++++ skpro/survival/adapters/lifelines.py | 51 +----------------- 2 files changed, 79 insertions(+), 50 deletions(-) create mode 100644 skpro/survival/adapters/_common.py diff --git a/skpro/survival/adapters/_common.py b/skpro/survival/adapters/_common.py new file mode 100644 index 00000000..ead4cd8b --- /dev/null +++ b/skpro/survival/adapters/_common.py @@ -0,0 +1,78 @@ +"""Common utilities for adapters.""" + +import numpy as np + + +def _clip_surv(surv_arr): + """Clips improper survival function values to proper range. + + Enforces: values are in [0, 1] and are monotonically decreasing. + + First clips to [0, 1], then enforces monotonicity, by replacing + any value with minimum of itself and any previous values. + + Parameters + ---------- + surv_arr : 2D np.ndarray + Survival function values. + index 0 is instance index. + index 1 is time index, increasing. + + Returns + ------- + surv_arr_clipped : 2D np.ndarray + Clipped survival function values. + surv_arr_diff : 2D np.ndarray + Difference of clipped survival function values. + Same as np.diff(surv_arr_clipped, axis=1, prepend=1). 
+ Returned to avoid recomputation, if needed later in context. + clipped : boolean + Whether clipping was needed. + """ + too_large = surv_arr > 1 + too_small = surv_arr < 0 + + surv_arr[too_large] = 1 + surv_arr[too_small] = 0 + + surv_arr_diff = _surv_diff(surv_arr) + + # avoid iterative minimization if no further clipping is needed + if not (surv_arr_diff > 0).any(): + clipped = too_large.any() or too_small.any() + return surv_arr, surv_arr_diff, clipped + + # enforce monotonicity + # iterating from left to right ensures values are replaced + # with minimum of itself and all values to the left + for i in range(1, surv_arr.shape[1]): + surv_arr[:, i] = np.minimum(surv_arr[:, i], surv_arr[:, i - 1]) + + surv_arr_diff = _surv_diff(surv_arr) + + return surv_arr, surv_arr_diff, True + + +def _surv_diff(surv_arr): + """Compute difference of survival function values. + + Parameters + ---------- + surv_arr : 2D np.ndarray + Survival function values. + index 0 is instance index. + index 1 is time index, increasing. + + Returns + ------- + surv_arr_diff : 2D np.ndarray, same shape as surv_arr + Difference of survival function values. 
+ Same as np.diff(surv_arr, axis=1, prepend=1, append=0), + then summing the last two columns to become one column + """ + surv_arr_diff = np.diff(surv_arr, axis=1, prepend=1, append=0) + + surv_arr_diff[:, -2] = surv_arr_diff[:, -2] + surv_arr_diff[:, -1] + surv_arr_diff = surv_arr_diff[:, :-1] + + return surv_arr_diff diff --git a/skpro/survival/adapters/lifelines.py b/skpro/survival/adapters/lifelines.py index 7ee3a33c..f98f2bca 100644 --- a/skpro/survival/adapters/lifelines.py +++ b/skpro/survival/adapters/lifelines.py @@ -10,6 +10,7 @@ import pandas as pd from skpro.distributions.empirical import Empirical +from skpro.survival.adapters._common import _clip_surv from skpro.utils.sklearn import prep_skl_df @@ -165,53 +166,3 @@ def _predict_proba(self, X): ) return dist - - -def _clip_surv(surv_arr): - """Clips improper survival function values to proper range. - - Enforces: values are in [0, 1] and are monotonically decreasing. - - First clips to [0, 1], then enforces monotonicity, by replacing - any value with minimum of itself and any previous values. - - Parameters - ---------- - surv_arr : 2D np.ndarray - Survival function values. - index 0 is instance index. - index 1 is time index, increasing. - - Returns - ------- - surv_arr_clipped : 2D np.ndarray - Clipped survival function values. - surv_arr_diff : 2D np.ndarray - Difference of clipped survival function values. - Same as np.diff(surv_arr_clipped, axis=1, prepend=1). - Returned to avoid recomputation, if needed later in context. - clipped : boolean - Whether clipping was needed. 
- """ - too_large = surv_arr > 1 - too_small = surv_arr < 0 - - surv_arr[too_large] = 1 - surv_arr[too_small] = 0 - - surv_arr_diff = np.diff(surv_arr, axis=1, prepend=1) - - # avoid iterative minimization if no further clipping is needed - if not (surv_arr_diff > 0).any(): - clipped = too_large.any() or too_small.any() - return surv_arr, surv_arr_diff, clipped - - # enforce monotonicity - # iterating from left to right ensures values are replaced - # with minimum of itself and all values to the left - for i in range(1, surv_arr.shape[1]): - surv_arr[:, i] = np.minimum(surv_arr[:, i], surv_arr[:, i - 1]) - - surv_arr_diff = np.diff(surv_arr, axis=1, prepend=1) - - return surv_arr, surv_arr_diff, True From d93968959ea09f25d6e172dcf55b071ec42734f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 12:48:20 +0100 Subject: [PATCH 15/33] Update _common.py --- skpro/survival/adapters/_common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/skpro/survival/adapters/_common.py b/skpro/survival/adapters/_common.py index ead4cd8b..3508df84 100644 --- a/skpro/survival/adapters/_common.py +++ b/skpro/survival/adapters/_common.py @@ -24,7 +24,8 @@ def _clip_surv(surv_arr): Clipped survival function values. surv_arr_diff : 2D np.ndarray Difference of clipped survival function values. - Same as np.diff(surv_arr_clipped, axis=1, prepend=1). + Same as np.diff(surv_arr, axis=1, prepend=1, append=0), + then summing the last two columns to become one column. Returned to avoid recomputation, if needed later in context. clipped : boolean Whether clipping was needed. 
From 54f30060ad8784cd7e3153ac7ff958977bddb54e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 12:48:45 +0100 Subject: [PATCH 16/33] Update _common.py --- skpro/survival/adapters/_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skpro/survival/adapters/_common.py b/skpro/survival/adapters/_common.py index 3508df84..34b51a7b 100644 --- a/skpro/survival/adapters/_common.py +++ b/skpro/survival/adapters/_common.py @@ -22,7 +22,7 @@ def _clip_surv(surv_arr): ------- surv_arr_clipped : 2D np.ndarray Clipped survival function values. - surv_arr_diff : 2D np.ndarray + surv_arr_diff : 2D np.ndarray, same shape as surv_arr_clipped. Difference of clipped survival function values. Same as np.diff(surv_arr, axis=1, prepend=1, append=0), then summing the last two columns to become one column. From cfbecdf413cc08b4520690f5702435ca37aea218 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 14:29:22 +0100 Subject: [PATCH 17/33] coxph --- docs/source/api_reference/survival.rst | 1 + skpro/survival/adapters/lifelines.py | 3 +- skpro/survival/coxph/__init__.py | 3 +- skpro/survival/coxph/_coxph_lifelines.py | 208 +++++++++++++++++++++++ 4 files changed, 213 insertions(+), 2 deletions(-) create mode 100644 skpro/survival/coxph/_coxph_lifelines.py diff --git a/docs/source/api_reference/survival.rst b/docs/source/api_reference/survival.rst index afc42869..de4bb2b1 100644 --- a/docs/source/api_reference/survival.rst +++ b/docs/source/api_reference/survival.rst @@ -86,6 +86,7 @@ Proportional hazards models :template: class.rst CoxPH + CoxPHlifelines CoxPHSkSurv CoxNet diff --git a/skpro/survival/adapters/lifelines.py b/skpro/survival/adapters/lifelines.py index f98f2bca..b5a3b85e 100644 --- a/skpro/survival/adapters/lifelines.py +++ b/skpro/survival/adapters/lifelines.py @@ -155,7 +155,8 @@ def _predict_proba(self, X): "not monotonically decreasing or not in [0, 1]. 
" "skpro has clipped the predictions to enforce proper range and " "valid predictive distributions. " - "However, predictions may still be unreliable." + "However, predictions may still be unreliable.", + stacklevel=2, ) weights = -lifelines_survf_t_diff.flatten() diff --git a/skpro/survival/coxph/__init__.py b/skpro/survival/coxph/__init__.py index 45d88ae4..d0a6dfbd 100644 --- a/skpro/survival/coxph/__init__.py +++ b/skpro/survival/coxph/__init__.py @@ -1,7 +1,8 @@ """Cox proportional hazards models.""" from skpro.survival.coxph._coxnet_sksurv import CoxNet +from skpro.survival.coxph._coxph_lifelines import CoxPHlifelines from skpro.survival.coxph._coxph_sksurv import CoxPHSkSurv from skpro.survival.coxph._coxph_statsmodels import CoxPH -__all__ = ["CoxNet", "CoxPH", "CoxPHSkSurv"] +__all__ = ["CoxNet", "CoxPH", "CoxPHlifelines", "CoxPHSkSurv"] diff --git a/skpro/survival/coxph/_coxph_lifelines.py b/skpro/survival/coxph/_coxph_lifelines.py new file mode 100644 index 00000000..8ae70cd2 --- /dev/null +++ b/skpro/survival/coxph/_coxph_lifelines.py @@ -0,0 +1,208 @@ +"""Interface adapter to lifelines Cox PH surival model.""" +# copyright: skpro developers, BSD-3-Clause License (see LICENSE file) + +__author__ = ["fkiraly"] + +from skpro.survival.adapters.lifelines import _LifelinesAdapter +from skpro.survival.base import BaseSurvReg + + +class CoxPHlifelines(_LifelinesAdapter, BaseSurvReg): + r"""Cox proportional hazards models, from lifelines. + + Direct interface to ``lifelines.fitters.CoxPHFitter``, + by ``CamDavidsonPilon``. + + This class implements Cox proportional hazard model, + + .. math:: h(t|x) = h_0(t) \exp((x - \overline{x})' \beta) + + with different options to fit the baseline hazard, :math:`h_0(t)`. + + The class offers multiple options via the ``baseline_estimation_method`` parameter: + + ``"breslow"`` (default): non-parametric estimate via Breslow's method. + In this case, the entire model is the traditional semi-parametric Cox model. 
+ Ties are handled using Efron's method. + + ``"spline"``: parametric spline fit of baseline hazard, + via Royston-Parmar's method [1]_. The parametric form is + + .. math:: H_0(t) = \exp{\left( \phi_0 + \phi_1\log{t} + \sum_{j=2}^N \phi_j v_j(\log{t})\right)} # noqa E501 + + where :math:`v_j` are our cubic basis functions at predetermined knots, + and :math:`H_0` is the cumulative baseline hazard. See [1]_ for exact definition. + + ``"piecewise"``: non-parametric, piecewise constant empirical baseline hazard. + The explicit form of the baseline hazard is + + .. math:: h_0(t) = \begin{cases} + exp{\beta \cdot \text{center}(x)} & \text{if $t \le \tau_0$} \\ + exp{\beta \cdot \text{center}(x)} \cdot lambda_1 & \text{if $\tau_0 < t \le \tau_1$} \\ # noqa E501 + exp{\beta \cdot \text{center}(x)} \cdot lambda_2 & \text{if $\tau_1 < t \le \tau_2$} \\ # noqa E501 + ... + \end{cases} + + Parameters + ---------- + alpha: float, optional (default=0.05) + the level in the confidence intervals around the estimated survival function, + for computation of ``confidence_intervals_`` fitted parameter. + + baseline_estimation_method: string, default="breslow", + one of: ``"breslow"``, ``"spline"``, or ``"piecewise"``. + Specifies algorithm for estimato of baseline hazard, see above. + + penalizer: float or array, optional (default=0.0) + Penalty to the size of the coefficients during regression. + This improves stability of the estimates and controls for high correlation + between covariates. + For example, this shrinks the magnitude value of :math:`\beta_i`. + See ``l1_ratio`` below. + The penalty term is :math:`\text{penalizer} \left( \frac{1-\text{l1_ratio}}{2} ||\beta||_2^2 + \text{l1_ratio}||\beta||_1\right)`. # noqa E501 + + If an array, must be equal in size to the number of parameters, + with penalty coefficients for specific variables. For + example, ``penalizer=0.01 * np.ones(p)`` is the same as ``penalizer=0.01``. 
+ + l1_ratio: float, optional (default=0.0) + Specify what ratio to assign to a L1 vs L2 penalty. + Same as in scikit-learn. See ``penalizer`` above. + + strata: list, optional + specify a list of columns to use in stratification. This is useful if a + categorical covariate does not obey the proportional hazard assumption. This + is used similar to the ``strata`` expression in R. + See http://courses.washington.edu/b515/l17.pdf. + + n_baseline_knots: int, optional, default=4 + Used only when ``baseline_estimation_method="spline"``. + Set the number of knots (interior & exterior) in the baseline hazard, + which will be placed evenly along the time axis. + Should be at least 2. + Royston et. al, the authors of this model, suggest 4 to start, + but any values between 2 and 8 are reasonable. + If you need to customize the timestamps used to calculate the curve, + use the ``knots`` parameter instead. + + knots: list, optional + Used only when ``baseline_estimation_method="spline"``. + Specifies custom points in the time axis for the baseline hazard curve. + To use evenly-spaced points in time, the ``n_baseline_knots`` + parameter can be employed instead. + + breakpoints: list, optional + Used only when ``baseline_estimation_method="piecewise"``. + Set the positions of the baseline hazard breakpoints. + + Attributes + ---------- + params_ : Series + The estimated coefficients. 
+ hazard_ratios_ : Series + The exp(coefficients) + confidence_intervals_ : DataFrame + The lower and upper confidence intervals for the hazard coefficients + durations: Series + The durations provided + event_observed: Series + The event_observed variable provided + weights: Series + The event_observed variable provided + variance_matrix_ : DataFrame + The variance matrix of the coefficients + strata: list + the strata provided + standard_errors_: Series + the standard errors of the estimates + log_likelihood_: float + the log-likelihood at the fitted coefficients + AIC_: float + the AIC at the fitted coefficients (if using splines for baseline hazard) + partial_AIC_: float + the AIC at the fitted coefficients + (if using non-parametric inference for baseline hazard) + baseline_hazard_: DataFrame + the baseline hazard evaluated at the observed times. + Estimated using Breslow's method. + baseline_cumulative_hazard_: DataFrame + the baseline cumulative hazard evaluated at the observed times. + Estimated using Breslow's method. + baseline_survival_: DataFrame + the baseline survival evaluated at the observed times. + Estimated using Breslow's method. + summary: Dataframe + a Dataframe of the coefficients, p-values, CIs, etc. + + References + ---------- + .. [1] Royston, P., Parmar, M. K. B. (2002). + Flexible parametric proportional-hazards and proportional-odds + models for censored survival data, with application to prognostic + modelling and estimation of treatment effects. + Statistics in Medicine, 21(15), 2175–2197. 
doi:10.1002/sim.1203 + """ + + _tags = {"authors": ["CamDavidsonPilon", "rocreguant", "fkiraly"]} + # CamDavidsonPilon, rocreguant credit for interfaced estimator + + def __init__( + self, + baseline_estimation_method: str = "breslow", + penalizer=0.0, + strata=None, + l1_ratio=0.0, + n_baseline_knots=None, + knots=None, + breakpoints=None, + ): + self.baseline_estimation_method = baseline_estimation_method + self.penalizer = penalizer + self.strata = strata + self.l1_ratio = l1_ratio + self.n_baseline_knots = n_baseline_knots + self.knots = knots + self.breakpoints = breakpoints + + super().__init__() + + def _get_lifelines_class(self): + """Getter of the lifelines class to be used for the adapter.""" + from lifelines.fitters.coxph_fitter import CoxPHFitter + + return CoxPHFitter + + @classmethod + def get_test_params(cls, parameter_set="default"): + """Return testing parameter settings for the estimator. + + Parameters + ---------- + parameter_set : str, default="default" + Name of the set of test parameters to return, for use in tests. If no + special parameters are defined for a value, will return `"default"` set. + + Returns + ------- + params : dict or list of dict, default = {} + Parameters to create testing instances of the class + Each dict are parameters to construct an "interesting" test instance, i.e., + `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. 
+ `create_test_instance` uses the first (or only) dictionary in `params` + """ + params1 = {} + + params2 = { + "baseline_estimation_method": "spline", + "penalizer": 0.1, + "l1_ratio": 0.1, + "n_baseline_knots": 3, + } + + params3 = { + "baseline_estimation_method": "piecewise", + "penalizer": 0.15, + "l1_ratio": 0.05, + } + + return [params1, params2, params3] From 266485b1dd3c531e0aa6a5d4a0dd5e73b7e324ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 14:30:22 +0100 Subject: [PATCH 18/33] Update _coxph_lifelines.py --- skpro/survival/coxph/_coxph_lifelines.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skpro/survival/coxph/_coxph_lifelines.py b/skpro/survival/coxph/_coxph_lifelines.py index 8ae70cd2..7af1506f 100644 --- a/skpro/survival/coxph/_coxph_lifelines.py +++ b/skpro/survival/coxph/_coxph_lifelines.py @@ -143,8 +143,8 @@ class CoxPHlifelines(_LifelinesAdapter, BaseSurvReg): Statistics in Medicine, 21(15), 2175–2197. 
doi:10.1002/sim.1203 """ - _tags = {"authors": ["CamDavidsonPilon", "rocreguant", "fkiraly"]} - # CamDavidsonPilon, rocreguant credit for interfaced estimator + _tags = {"authors": ["CamDavidsonPilon", "JoseLlanes", "mathurinm", "fkiraly"]} + # CamDavidsonPilon, JoseLlanes, mathurinm credit for interfaced estimator def __init__( self, From a38f099e59adb7ebd934e9a338745369c97eebe2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 16:01:22 +0100 Subject: [PATCH 19/33] safe get param --- skpro/survival/adapters/_common.py | 41 +++++++++++++++++++++++++ skpro/survival/adapters/lifelines.py | 4 +-- skpro/survival/adapters/sksurv.py | 45 ++-------------------------- 3 files changed, 45 insertions(+), 45 deletions(-) diff --git a/skpro/survival/adapters/_common.py b/skpro/survival/adapters/_common.py index 34b51a7b..468eecda 100644 --- a/skpro/survival/adapters/_common.py +++ b/skpro/survival/adapters/_common.py @@ -77,3 +77,44 @@ def _surv_diff(surv_arr): surv_arr_diff = surv_arr_diff[:, :-1] return surv_arr_diff + + +def _get_fitted_params_default_safe(obj=None): + """Obtain fitted params of object, per sklearn convention. + + Same as _get_fitted_params_default, but with exception handling. + + This is since in sksurv, feature_importances_ is a property + and may raise an exception if the estimator does not have it. 
+ + Parameters + ---------- + obj : any object + + Returns + ------- + fitted_params : dict with str keys + fitted parameters, keyed by names of fitted parameter + """ + # default retrieves all self attributes ending in "_" + # and returns them with keys that have the "_" removed + # + # get all attributes ending in "_", exclude any that start with "_" (private) + fitted_params = [ + attr for attr in dir(obj) if attr.endswith("_") and not attr.startswith("_") + ] + + def hasattr_safe(obj, attr): + try: + if hasattr(obj, attr): + getattr(obj, attr) + return True + except Exception: + return False + + # remove the "_" at the end + fitted_param_dict = { + p[:-1]: getattr(obj, p) for p in fitted_params if hasattr_safe(obj, p) + } + + return fitted_param_dict diff --git a/skpro/survival/adapters/lifelines.py b/skpro/survival/adapters/lifelines.py index b5a3b85e..35aa1cae 100644 --- a/skpro/survival/adapters/lifelines.py +++ b/skpro/survival/adapters/lifelines.py @@ -10,7 +10,7 @@ import pandas as pd from skpro.distributions.empirical import Empirical -from skpro.survival.adapters._common import _clip_surv +from skpro.survival.adapters._common import _clip_surv, _get_fitted_params_default_safe from skpro.utils.sklearn import prep_skl_df @@ -109,7 +109,7 @@ def _fit(self, X, y, C=None): lifelines_est.fit(**fit_args) # write fitted params to self - lifelines_fitted_params = self._get_fitted_params_default(lifelines_est) + lifelines_fitted_params = _get_fitted_params_default_safe(lifelines_est) for k, v in lifelines_fitted_params.items(): setattr(self, f"{k}_", v) diff --git a/skpro/survival/adapters/sksurv.py b/skpro/survival/adapters/sksurv.py index a9b73f90..3aa48cce 100644 --- a/skpro/survival/adapters/sksurv.py +++ b/skpro/survival/adapters/sksurv.py @@ -8,6 +8,7 @@ import pandas as pd from skpro.distributions.empirical import Empirical +from skpro.survival.adapters._common import _get_fitted_params_default_safe from skpro.utils.sklearn import prep_skl_df @@ -101,55 
+102,13 @@ def _fit(self, X, y, C=None): # write fitted params to self EXCEPTED_FITTED_PARAMS = ["n_features_in", "feature_names_in"] - sksurv_fitted_params = self._get_fitted_params_default_safe(sksurv_est) + sksurv_fitted_params = _get_fitted_params_default_safe(sksurv_est) for k, v in sksurv_fitted_params.items(): if k not in EXCEPTED_FITTED_PARAMS: setattr(self, f"{k}_", v) return self - def _get_fitted_params_default_safe(self, obj=None): - """Obtain fitted params of object, per sklearn convention. - - Same as _get_fitted_params_default, but with exception handling. - - This is since in sksurv, feature_importances_ is a property - and may raise an exception if the estimator does not have it. - - Parameters - ---------- - obj : any object, optional, default=self - - Returns - ------- - fitted_params : dict with str keys - fitted parameters, keyed by names of fitted parameter - """ - obj = obj if obj else self - - # default retrieves all self attributes ending in "_" - # and returns them with keys that have the "_" removed - # - # get all attributes ending in "_", exclude any that start with "_" (private) - fitted_params = [ - attr for attr in dir(obj) if attr.endswith("_") and not attr.startswith("_") - ] - - def hasattr_safe(obj, attr): - try: - if hasattr(obj, attr): - getattr(obj, attr) - return True - except Exception: - return False - - # remove the "_" at the end - fitted_param_dict = { - p[:-1]: getattr(obj, p) for p in fitted_params if hasattr_safe(obj, p) - } - - return fitted_param_dict - def _predict_proba(self, X): """Predict_proba method adapter. 
From b3d1e076e48dae7625a7b8354050af55678bb9e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 16:02:46 +0100 Subject: [PATCH 20/33] comments --- skpro/survival/adapters/lifelines.py | 3 +++ skpro/survival/adapters/sksurv.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/skpro/survival/adapters/lifelines.py b/skpro/survival/adapters/lifelines.py index 35aa1cae..3e0e4755 100644 --- a/skpro/survival/adapters/lifelines.py +++ b/skpro/survival/adapters/lifelines.py @@ -109,6 +109,9 @@ def _fit(self, X, y, C=None): lifelines_est.fit(**fit_args) # write fitted params to self + # some fitted parameters are properties and may raise exceptions + # for example, AIC_ of CoxPHFitter + # to avoid this, we use a safe getter lifelines_fitted_params = _get_fitted_params_default_safe(lifelines_est) for k, v in lifelines_fitted_params.items(): setattr(self, f"{k}_", v) diff --git a/skpro/survival/adapters/sksurv.py b/skpro/survival/adapters/sksurv.py index 3aa48cce..ab6e69c1 100644 --- a/skpro/survival/adapters/sksurv.py +++ b/skpro/survival/adapters/sksurv.py @@ -101,6 +101,9 @@ def _fit(self, X, y, C=None): sksurv_est.fit(X, y_sksurv) # write fitted params to self + # some fitted parameters are properties and may raise exceptions + # for example, AIC_ of CoxPHFitter + # to avoid this, we use a safe getter EXCEPTED_FITTED_PARAMS = ["n_features_in", "feature_names_in"] sksurv_fitted_params = _get_fitted_params_default_safe(sksurv_est) for k, v in sksurv_fitted_params.items(): From 1fc4408dd6fef9be1804b606662fdf10164d5354 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 16:03:21 +0100 Subject: [PATCH 21/33] fix comment --- skpro/survival/adapters/lifelines.py | 2 +- skpro/survival/adapters/sksurv.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/skpro/survival/adapters/lifelines.py b/skpro/survival/adapters/lifelines.py index 3e0e4755..f8648e00 100644 --- 
a/skpro/survival/adapters/lifelines.py +++ b/skpro/survival/adapters/lifelines.py @@ -110,7 +110,7 @@ def _fit(self, X, y, C=None): # write fitted params to self # some fitted parameters are properties and may raise exceptions - # for example, AIC_ of CoxPHFitter + # for example, AIC_ and AIC_partial_ of CoxPHFitter # to avoid this, we use a safe getter lifelines_fitted_params = _get_fitted_params_default_safe(lifelines_est) for k, v in lifelines_fitted_params.items(): diff --git a/skpro/survival/adapters/sksurv.py b/skpro/survival/adapters/sksurv.py index ab6e69c1..2716270c 100644 --- a/skpro/survival/adapters/sksurv.py +++ b/skpro/survival/adapters/sksurv.py @@ -102,7 +102,7 @@ def _fit(self, X, y, C=None): # write fitted params to self # some fitted parameters are properties and may raise exceptions - # for example, AIC_ of CoxPHFitter + # for example, AIC_ and AIC_partial_ of CoxPHFitter # to avoid this, we use a safe getter EXCEPTED_FITTED_PARAMS = ["n_features_in", "feature_names_in"] sksurv_fitted_params = _get_fitted_params_default_safe(sksurv_est) From 091acfca10961b0562d737d8624c3a60c0cd738d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 19:11:47 +0100 Subject: [PATCH 22/33] Update _coxph_lifelines.py --- skpro/survival/coxph/_coxph_lifelines.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/skpro/survival/coxph/_coxph_lifelines.py b/skpro/survival/coxph/_coxph_lifelines.py index 7af1506f..baa94da8 100644 --- a/skpro/survival/coxph/_coxph_lifelines.py +++ b/skpro/survival/coxph/_coxph_lifelines.py @@ -51,7 +51,8 @@ class CoxPHlifelines(_LifelinesAdapter, BaseSurvReg): baseline_estimation_method: string, default="breslow", one of: ``"breslow"``, ``"spline"``, or ``"piecewise"``. - Specifies algorithm for estimato of baseline hazard, see above. + Specifies algorithm for estimation of baseline hazard, see above. + If ``"piecewise"``, the ``breakpoints`` parameter must be set. 
penalizer: float or array, optional (default=0.0) Penalty to the size of the coefficients during regression. @@ -92,7 +93,8 @@ class CoxPHlifelines(_LifelinesAdapter, BaseSurvReg): parameter can be employed instead. breakpoints: list, optional - Used only when ``baseline_estimation_method="piecewise"``. + Used only when ``baseline_estimation_method="piecewise"``, + must be passed in this case. Set the positions of the baseline hazard breakpoints. Attributes @@ -203,6 +205,7 @@ def get_test_params(cls, parameter_set="default"): "baseline_estimation_method": "piecewise", "penalizer": 0.15, "l1_ratio": 0.05, + "breakpoints": [10, 20, 30, 100], } return [params1, params2, params3] From 2afaa7f8a9998120df0e6782b4e10e3553ee0533 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 19:27:52 +0100 Subject: [PATCH 23/33] Update _coxph_lifelines.py --- skpro/survival/coxph/_coxph_lifelines.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/skpro/survival/coxph/_coxph_lifelines.py b/skpro/survival/coxph/_coxph_lifelines.py index baa94da8..3f4ed158 100644 --- a/skpro/survival/coxph/_coxph_lifelines.py +++ b/skpro/survival/coxph/_coxph_lifelines.py @@ -201,11 +201,14 @@ def get_test_params(cls, parameter_set="default"): "n_baseline_knots": 3, } - params3 = { - "baseline_estimation_method": "piecewise", - "penalizer": 0.15, - "l1_ratio": 0.05, - "breakpoints": [10, 20, 30, 100], - } - - return [params1, params2, params3] + # breakpoints are specific to data ranges, + # but tests loop over various data sets, so this would break + # + # params3 = { + # "baseline_estimation_method": "piecewise", + # "penalizer": 0.15, + # "l1_ratio": 0.05, + # "breakpoints": [10, 20, 30, 100], + # } + + return [params1, params2] From e9f62528003dfc91f4e517f3a97817a1d159d032 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 21:27:07 +0100 Subject: [PATCH 24/33] weibull partial work --- 
skpro/survival/aft/__init__.py | 6 + skpro/survival/aft/_aft_lifelines_weibull.py | 126 +++++++++++++++++++ 2 files changed, 132 insertions(+) create mode 100644 skpro/survival/aft/__init__.py create mode 100644 skpro/survival/aft/_aft_lifelines_weibull.py diff --git a/skpro/survival/aft/__init__.py b/skpro/survival/aft/__init__.py new file mode 100644 index 00000000..e57b2790 --- /dev/null +++ b/skpro/survival/aft/__init__.py @@ -0,0 +1,6 @@ +"""Module containing accelerated failure time models.""" +# copyright: skpro developers, BSD-3-Clause License (see LICENSE file) + +__all__ = ["_SksurvAdapter"] + +from skpro.survival.adapters.sksurv import _SksurvAdapter diff --git a/skpro/survival/aft/_aft_lifelines_weibull.py b/skpro/survival/aft/_aft_lifelines_weibull.py new file mode 100644 index 00000000..3a53a0af --- /dev/null +++ b/skpro/survival/aft/_aft_lifelines_weibull.py @@ -0,0 +1,126 @@ +"""Interface adapter to lifelines Weibull AFT model.""" +# copyright: skpro developers, BSD-3-Clause License (see LICENSE file) + +__author__ = ["fkiraly"] + +from skpro.survival.adapters.lifelines import _LifelinesAdapter +from skpro.survival.base import BaseSurvReg + + +class AFTWeibullLifelines(_LifelinesAdapter, BaseSurvReg): + r"""Weibull AFT model, from lifelines. + + Direct interface to ``lifelines.fitters.WeibullAFTFitter``, + by ``CamDavidsonPilon``. + + This class implements a Weibull AFT model. The model has parametric form, with + :math:`\lambda(x) = \exp\left(\beta_0 + \beta_1x_1 + ... + \beta_n x_n \right)`, + and optionally, + :math:`\rho(y) = \exp\left(\alpha_0 + \alpha_1 y_1 + ... + \alpha_m y_m \right)`, + + with predictive distribution being Weibull, with + scale parameter :math:`\lambda(x)` and shape parameter (exponent) :math:`\rho(y)`. + + The :math:`\lambda` (scale) parameter is a decay or half-like like parameter, + more specifically, the time by which the survival probability is 37%. 
+ The :math:`\rho` (shape) parameter controls curvature of the the cumulative hazard, + e.g., whether it is convex or concave, representing accelerating or decelerating + hazards. + + The cumulative hazard rate is + + .. math:: H(t; x, y) = \left(\frac{t}{\lambda(x)} \right)^{\rho(y)}, + + Parameters + ---------- + alpha: float, optional (default=0.05) + the level in the confidence intervals around the estimated survival function, + for computation of ``confidence_intervals_`` fitted parameter. + + fit_intercept: boolean, optional (default=True) + Whether to fit an intercept term in the model. + + penalizer: float or array, optional (default=0.0) + the penalizer coefficient to the size of the coefficients. + See ``l1_ratio``. Must be equal to or greater than 0. + Alternatively, penalizer is an array equal in size to the number of parameters, + with penalty coefficients for specific variables. For + example, ``penalizer=0.01 * np.ones(p)`` is the same as ``penalizer=0.01`` + + l1_ratio: float, optional (default=0.0) + how much of the penalizer should be attributed to an l1 penalty + (otherwise an l2 penalty). The penalty function looks like + ``penalizer * l1_ratio * ||w||_1 + 0.5 * penalizer * (1 - l1_ratio) * ||w||^2_2`` # noqa E501 + + Attributes + ---------- + params_ : DataFrame + The estimated coefficients + confidence_intervals_ : DataFrame + The lower and upper confidence intervals for the coefficients + durations: Series + The event_observed variable provided + event_observed: Series + The event_observed variable provided + weights: Series + The event_observed variable provided + variance_matrix_ : DataFrame + The variance matrix of the coefficients + standard_errors_: Series + the standard errors of the estimates + score_: float + the concordance index of the model. 
+ """ + + _tags = {"authors": ["CamDavidsonPilon", "JoseLlanes", "mathurinm", "fkiraly"]} + # CamDavidsonPilon, JoseLlanes, mathurinm credit for interfaced estimator + + def __init__( + self, + alpha: float = 0.05, + penalizer: float = 0.0, + l1_ratio: float = 0.0, + fit_intercept: bool = True, + ): + self.alpha = alpha + self.penalizer = penalizer + self.l1_ratio = l1_ratio + self.fit_intercept = fit_intercept + + super().__init__() + + def _get_lifelines_class(self): + """Getter of the lifelines class to be used for the adapter.""" + from lifelines.fitters.weibull_aft_fitter import WeibullAFTFitter + + return WeibullAFTFitter + + @classmethod + def get_test_params(cls, parameter_set="default"): + """Return testing parameter settings for the estimator. + + Parameters + ---------- + parameter_set : str, default="default" + Name of the set of test parameters to return, for use in tests. If no + special parameters are defined for a value, will return `"default"` set. + + Returns + ------- + params : dict or list of dict, default = {} + Parameters to create testing instances of the class + Each dict are parameters to construct an "interesting" test instance, i.e., + `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. 
+ `create_test_instance` uses the first (or only) dictionary in `params` + """ + params1 = {} + + params2 = { + "baseline_estimation_method": "spline", + "penalizer": 0.1, + "l1_ratio": 0.1, + "n_baseline_knots": 3, + } + + + return [params1, params2] From d2545915f9f63b505480a607ebcfb9b7105bfa2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 21:51:02 +0100 Subject: [PATCH 25/33] weibull without weibull --- skpro/survival/adapters/lifelines.py | 28 +++++- skpro/survival/aft/_aft_lifelines_weibull.py | 94 ++++++++++++++++++-- 2 files changed, 112 insertions(+), 10 deletions(-) diff --git a/skpro/survival/adapters/lifelines.py b/skpro/survival/adapters/lifelines.py index f8648e00..f3c0b49f 100644 --- a/skpro/survival/adapters/lifelines.py +++ b/skpro/survival/adapters/lifelines.py @@ -63,6 +63,25 @@ def _init_lifelines_object(self): setattr(self, self._estimator_attr, cls) return getattr(self, self._estimator_attr) + def _get_extra_fit_args(self, X, y, C=None): + """Get extra arguments for the fit method. + + Parameters + ---------- + X : pd.DataFrame + Training features + y: pd.DataFrame + Training labels + C: pd.DataFrame, optional (default=None) + Censoring information for survival analysis. + + Returns + ------- + dict + Extra arguments for the fit method. + """ + return {} + def _fit(self, X, y, C=None): """Fit estimator training data. @@ -70,9 +89,9 @@ def _fit(self, X, y, C=None): ---------- X : pd.DataFrame Training features - y: pd.Series + y: pd.DataFrame Training labels - C: pd.Series, optional (default=None) + C: pd.DataFrame, optional (default=None) Censoring information for survival analysis. 
Returns @@ -86,6 +105,9 @@ def _fit(self, X, y, C=None): X = X.astype("float") # lifelines insists on float dtype X = prep_skl_df(X) + if hasattr(self, "X_col_subset"): + X = X[self.X_col_subset] + to_concat = [X, y] if C is not None: @@ -105,6 +127,8 @@ def _fit(self, X, y, C=None): if C is not None: fit_args["event_col"] = "__C" + fit_args.update(self._get_extra_fit_args(X, y, C)) + # fit lifelines estimator lifelines_est.fit(**fit_args) diff --git a/skpro/survival/aft/_aft_lifelines_weibull.py b/skpro/survival/aft/_aft_lifelines_weibull.py index 3a53a0af..cf5443cb 100644 --- a/skpro/survival/aft/_aft_lifelines_weibull.py +++ b/skpro/survival/aft/_aft_lifelines_weibull.py @@ -3,8 +3,10 @@ __author__ = ["fkiraly"] +from skpro.distributions.weibull import Weibull from skpro.survival.adapters.lifelines import _LifelinesAdapter from skpro.survival.base import BaseSurvReg +from skpro.utils.sklearn import prep_skl_df class AFTWeibullLifelines(_LifelinesAdapter, BaseSurvReg): @@ -33,12 +35,23 @@ class AFTWeibullLifelines(_LifelinesAdapter, BaseSurvReg): Parameters ---------- - alpha: float, optional (default=0.05) - the level in the confidence intervals around the estimated survival function, - for computation of ``confidence_intervals_`` fitted parameter. + scale_cols: pd.Index or coercible, optional, default=None + Columns of the input data frame to be used as covariates for + the scale parameter :math:`\lambda`. + If None, all columns are used. + + shape_cols: string "all", pd.Index or coercible, optional, default=None + Columns of the input data frame to be used as covariates for + the shape parameter :math:`\rho`. + If None, no covariates are used, the shape parameter is estimated as a constant. + If "all", all columns are used. fit_intercept: boolean, optional (default=True) Whether to fit an intercept term in the model. 
+ + alpha: float, optional (default=0.05) + the level in the confidence intervals around the estimated survival function, + for computation of ``confidence_intervals_`` fitted parameter. penalizer: float or array, optional (default=0.0) the penalizer coefficient to the size of the coefficients. @@ -77,11 +90,15 @@ class AFTWeibullLifelines(_LifelinesAdapter, BaseSurvReg): def __init__( self, + scale_cols=None, + shape_cols=None, + fit_intercept: bool = True, alpha: float = 0.05, penalizer: float = 0.0, l1_ratio: float = 0.0, - fit_intercept: bool = True, ): + self.scale_cols = scale_cols + self.shape_cols = shape_cols self.alpha = alpha self.penalizer = penalizer self.l1_ratio = l1_ratio @@ -89,12 +106,74 @@ def __init__( super().__init__() + if scale_cols is not None: + self.X_col_subset = scale_cols + def _get_lifelines_class(self): """Getter of the lifelines class to be used for the adapter.""" from lifelines.fitters.weibull_aft_fitter import WeibullAFTFitter return WeibullAFTFitter + def _add_extra_fit_args(self, X, y, C=None): + """Get extra arguments for the fit method. + + Parameters + ---------- + X : pd.DataFrame + Training features + y: pd.DataFrame + Training labels + C: pd.DataFrame, optional (default=None) + Censoring information for survival analysis. + fit_args: dict, optional (default=None) + Existing arguments for the fit method, from the adapter. + + Returns + ------- + dict + Extra arguments for the fit method. + """ + if self.scale_cols is not None: + if self.scale_cols == "all": + return {"ancillary": True} + else: + return {"ancillary": X[self.scale_cols]} + else: + return {} + + def _predict_proba(self, X): + """Predict_proba method adapter. + + Parameters + ---------- + X : pd.DataFrame + Features to predict on. 
+ + Returns + ------- + skpro Empirical distribution + """ + if self.shape_cols == "all": + ancillary = X + elif self.shape_cols is not None: + ancillary = X[self.shape_cols] + else: + ancillary = None + + if self.scale_cols is not None: + df = X[self.scale_cols] + else: + df = X + + lifelines_est = getattr(self, self._estimator_attr) + ll_pred_proba = lifelines_est._prep_inputs_for_prediction_and_return_scores + + scale, shape = ll_pred_proba(df, ancillary) + + dist = Weibull(scale=scale, shape=shape, index=X.index, columns=self._y_cols) + return dist + @classmethod def get_test_params(cls, parameter_set="default"): """Return testing parameter settings for the estimator. @@ -116,11 +195,10 @@ def get_test_params(cls, parameter_set="default"): params1 = {} params2 = { - "baseline_estimation_method": "spline", + "shape_cols": "all", + "fit_intercept": False, + "alpha": 0.1, "penalizer": 0.1, "l1_ratio": 0.1, - "n_baseline_knots": 3, } - - return [params1, params2] From 780c4518bb40c0b3431e68cebf6f5322d145a4ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 21:54:55 +0100 Subject: [PATCH 26/33] Update _aft_lifelines_weibull.py --- skpro/survival/aft/_aft_lifelines_weibull.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skpro/survival/aft/_aft_lifelines_weibull.py b/skpro/survival/aft/_aft_lifelines_weibull.py index cf5443cb..e9ef8d09 100644 --- a/skpro/survival/aft/_aft_lifelines_weibull.py +++ b/skpro/survival/aft/_aft_lifelines_weibull.py @@ -85,8 +85,8 @@ class AFTWeibullLifelines(_LifelinesAdapter, BaseSurvReg): the concordance index of the model. 
""" - _tags = {"authors": ["CamDavidsonPilon", "JoseLlanes", "mathurinm", "fkiraly"]} - # CamDavidsonPilon, JoseLlanes, mathurinm credit for interfaced estimator + _tags = {"authors": ["CamDavidsonPilon", "fkiraly"]} + # CamDavidsonPilon, credit for interfaced estimator def __init__( self, From 42d625360354770e81ffbe8922bf593a91cdd787 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 22:58:56 +0100 Subject: [PATCH 27/33] docs --- docs/source/api_reference/survival.rst | 11 +++++++++++ skpro/survival/aft/__init__.py | 4 ++-- skpro/survival/aft/_aft_lifelines_weibull.py | 2 +- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/docs/source/api_reference/survival.rst b/docs/source/api_reference/survival.rst index de4bb2b1..a03ee9c4 100644 --- a/docs/source/api_reference/survival.rst +++ b/docs/source/api_reference/survival.rst @@ -90,6 +90,17 @@ Proportional hazards models CoxPHSkSurv CoxNet +Accelerated failure time models +------------------------------- + +.. currentmodule:: skpro.survival.aft + +.. 
autosummary:: + :toctree: auto_generated/ + :template: class.rst + + AFTWeibull + Generalized additive survival models ------------------------------------ diff --git a/skpro/survival/aft/__init__.py b/skpro/survival/aft/__init__.py index e57b2790..3b591ab9 100644 --- a/skpro/survival/aft/__init__.py +++ b/skpro/survival/aft/__init__.py @@ -1,6 +1,6 @@ """Module containing accelerated failure time models.""" # copyright: skpro developers, BSD-3-Clause License (see LICENSE file) -__all__ = ["_SksurvAdapter"] +__all__ = ["AFTWeibull"] -from skpro.survival.adapters.sksurv import _SksurvAdapter +from skpro.survival.aft._aft_lifelines_weibull import AFTWeibull diff --git a/skpro/survival/aft/_aft_lifelines_weibull.py b/skpro/survival/aft/_aft_lifelines_weibull.py index e9ef8d09..0d03c6bc 100644 --- a/skpro/survival/aft/_aft_lifelines_weibull.py +++ b/skpro/survival/aft/_aft_lifelines_weibull.py @@ -9,7 +9,7 @@ from skpro.utils.sklearn import prep_skl_df -class AFTWeibullLifelines(_LifelinesAdapter, BaseSurvReg): +class AFTWeibull(_LifelinesAdapter, BaseSurvReg): r"""Weibull AFT model, from lifelines. Direct interface to ``lifelines.fitters.WeibullAFTFitter``, From b0fd1c04156f98a0ef5276d57cba52c5045ff4b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 23:01:06 +0100 Subject: [PATCH 28/33] Update _aft_lifelines_weibull.py --- skpro/survival/aft/_aft_lifelines_weibull.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skpro/survival/aft/_aft_lifelines_weibull.py b/skpro/survival/aft/_aft_lifelines_weibull.py index 0d03c6bc..35e76cfd 100644 --- a/skpro/survival/aft/_aft_lifelines_weibull.py +++ b/skpro/survival/aft/_aft_lifelines_weibull.py @@ -48,7 +48,7 @@ class AFTWeibull(_LifelinesAdapter, BaseSurvReg): fit_intercept: boolean, optional (default=True) Whether to fit an intercept term in the model. 
- + alpha: float, optional (default=0.05) the level in the confidence intervals around the estimated survival function, for computation of ``confidence_intervals_`` fitted parameter. From e5fb0e323e5821f5600de62f3a8f3c7183bcc32d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 23:02:06 +0100 Subject: [PATCH 29/33] Update _aft_lifelines_weibull.py --- skpro/survival/aft/_aft_lifelines_weibull.py | 1 - 1 file changed, 1 deletion(-) diff --git a/skpro/survival/aft/_aft_lifelines_weibull.py b/skpro/survival/aft/_aft_lifelines_weibull.py index 35e76cfd..dbe31370 100644 --- a/skpro/survival/aft/_aft_lifelines_weibull.py +++ b/skpro/survival/aft/_aft_lifelines_weibull.py @@ -6,7 +6,6 @@ from skpro.distributions.weibull import Weibull from skpro.survival.adapters.lifelines import _LifelinesAdapter from skpro.survival.base import BaseSurvReg -from skpro.utils.sklearn import prep_skl_df class AFTWeibull(_LifelinesAdapter, BaseSurvReg): From 2670aac7086956efaf0d5c89d4d1ed8b9a295abc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 23:13:27 +0100 Subject: [PATCH 30/33] Update _aft_lifelines_weibull.py --- skpro/survival/aft/_aft_lifelines_weibull.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/skpro/survival/aft/_aft_lifelines_weibull.py b/skpro/survival/aft/_aft_lifelines_weibull.py index dbe31370..4006a336 100644 --- a/skpro/survival/aft/_aft_lifelines_weibull.py +++ b/skpro/survival/aft/_aft_lifelines_weibull.py @@ -114,6 +114,18 @@ def _get_lifelines_class(self): return WeibullAFTFitter + def _get_lifelines_object(self): + """Initialize lifelines object from parameters of self. + + Initializes result of _get_lifelines_class with self.get_params, + excluding the scale_cols and shape_cols parameters. 
+ """ + cls = self._get_lifelines_class() + params = self.get_params() + params.pop("scale_cols", None) + params.pop("shape_cols", None) + return cls(params) + def _add_extra_fit_args(self, X, y, C=None): """Get extra arguments for the fit method. From b9ad46cc938dadb29c0386291ddf46ebf42687a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 23:24:25 +0100 Subject: [PATCH 31/33] Update _aft_lifelines_weibull.py --- skpro/survival/aft/_aft_lifelines_weibull.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skpro/survival/aft/_aft_lifelines_weibull.py b/skpro/survival/aft/_aft_lifelines_weibull.py index 4006a336..c3c7e075 100644 --- a/skpro/survival/aft/_aft_lifelines_weibull.py +++ b/skpro/survival/aft/_aft_lifelines_weibull.py @@ -124,7 +124,7 @@ def _get_lifelines_object(self): params = self.get_params() params.pop("scale_cols", None) params.pop("shape_cols", None) - return cls(params) + return cls(**params) def _add_extra_fit_args(self, X, y, C=None): """Get extra arguments for the fit method. 
From f169b5f7a54c267096360ef0df606a345ee697c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 23:36:58 +0100 Subject: [PATCH 32/33] Update _aft_lifelines_weibull.py --- skpro/survival/aft/_aft_lifelines_weibull.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skpro/survival/aft/_aft_lifelines_weibull.py b/skpro/survival/aft/_aft_lifelines_weibull.py index c3c7e075..ce85713b 100644 --- a/skpro/survival/aft/_aft_lifelines_weibull.py +++ b/skpro/survival/aft/_aft_lifelines_weibull.py @@ -180,9 +180,9 @@ def _predict_proba(self, X): lifelines_est = getattr(self, self._estimator_attr) ll_pred_proba = lifelines_est._prep_inputs_for_prediction_and_return_scores - scale, shape = ll_pred_proba(df, ancillary) + scale, k = ll_pred_proba(df, ancillary) - dist = Weibull(scale=scale, shape=shape, index=X.index, columns=self._y_cols) + dist = Weibull(scale=scale, k=k, index=X.index, columns=self._y_cols) return dist @classmethod From c7142c67e7842e5532eea3d199f0152a06ec8ff6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 17 Apr 2024 23:50:51 +0100 Subject: [PATCH 33/33] fix broadcasting --- skpro/survival/aft/_aft_lifelines_weibull.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/skpro/survival/aft/_aft_lifelines_weibull.py b/skpro/survival/aft/_aft_lifelines_weibull.py index ce85713b..1b9b5f97 100644 --- a/skpro/survival/aft/_aft_lifelines_weibull.py +++ b/skpro/survival/aft/_aft_lifelines_weibull.py @@ -3,6 +3,8 @@ __author__ = ["fkiraly"] +import numpy as np + from skpro.distributions.weibull import Weibull from skpro.survival.adapters.lifelines import _LifelinesAdapter from skpro.survival.base import BaseSurvReg @@ -181,6 +183,8 @@ def _predict_proba(self, X): ll_pred_proba = lifelines_est._prep_inputs_for_prediction_and_return_scores scale, k = ll_pred_proba(df, ancillary) + scale = np.expand_dims(scale, axis=1) + k = np.expand_dims(k, axis=1) dist = Weibull(scale=scale, k=k, 
index=X.index, columns=self._y_cols) return dist