Skip to content

Commit

Permalink
[ENH] adapter to lifelines, most distributional survival regressors…
Browse files Browse the repository at this point in the history
… interfaced (#247)

Adds an adapter to `lifelines`, exposing array-like survival functions
as `Empirical` distributions in `predict_proba`.

Adds all models from `lifelines` which are capable of full
distributional predictions:

* `AalenAdditiveFitter`
* `CoxPHFitter`
* `WeibullAFTFitter`

Remaining AFT fitters require distributions not yet merged.
  • Loading branch information
fkiraly authored Apr 17, 2024
1 parent ea7ce6d commit e2862a5
Show file tree
Hide file tree
Showing 11 changed files with 897 additions and 48 deletions.
23 changes: 23 additions & 0 deletions docs/source/api_reference/survival.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,32 @@ Proportional hazards models
:template: class.rst

CoxPH
CoxPHlifelines
CoxPHSkSurv
CoxNet

Accelerated failure time models
-------------------------------

.. currentmodule:: skpro.survival.aft

.. autosummary::
:toctree: auto_generated/
:template: class.rst

AFTWeibull

Generalized additive survival models
------------------------------------

.. currentmodule:: skpro.survival.additive

.. autosummary::
:toctree: auto_generated/
:template: class.rst

AalenAdditiveLifelines

Tree models
-----------

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ dependencies = [
all_extras = [
"attrs",
"distfit",
"lifelines<0.29.0",
"mapie",
"matplotlib>=3.3.2",
"ngboost",
Expand Down
120 changes: 120 additions & 0 deletions skpro/survival/adapters/_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Common utilities for adapters."""

import numpy as np


def _clip_surv(surv_arr):
"""Clips improper survival function values to proper range.
Enforces: values are in [0, 1] and are monotonically decreasing.
First clips to [0, 1], then enforces monotonicity, by replacing
any value with minimum of itself and any previous values.
Parameters
----------
surv_arr : 2D np.ndarray
Survival function values.
index 0 is instance index.
index 1 is time index, increasing.
Returns
-------
surv_arr_clipped : 2D np.ndarray
Clipped survival function values.
surv_arr_diff : 2D np.ndarray, same shape as surv_arr_clipped.
Difference of clipped survival function values.
Same as np.diff(surv_arr, axis=1, prepend=1, append=0),
then summing the last two columns to become one column.
Returned to avoid recomputation, if needed later in context.
clipped : boolean
Whether clipping was needed.
"""
too_large = surv_arr > 1
too_small = surv_arr < 0

surv_arr[too_large] = 1
surv_arr[too_small] = 0

surv_arr_diff = _surv_diff(surv_arr)

# avoid iterative minimization if no further clipping is needed
if not (surv_arr_diff > 0).any():
clipped = too_large.any() or too_small.any()
return surv_arr, surv_arr_diff, clipped

# enforce monotonicity
# iterating from left to right ensures values are replaced
# with minimum of itself and all values to the left
for i in range(1, surv_arr.shape[1]):
surv_arr[:, i] = np.minimum(surv_arr[:, i], surv_arr[:, i - 1])

surv_arr_diff = _surv_diff(surv_arr)

return surv_arr, surv_arr_diff, True


def _surv_diff(surv_arr):
"""Compute difference of survival function values.
Parameters
----------
surv_arr : 2D np.ndarray
Survival function values.
index 0 is instance index.
index 1 is time index, increasing.
Returns
-------
surv_arr_diff : 2D np.ndarray, same shape as surv_arr
Difference of survival function values.
Same as np.diff(surv_arr, axis=1, prepend=1, append=0),
then summing the last two columns to become one column
"""
surv_arr_diff = np.diff(surv_arr, axis=1, prepend=1, append=0)

surv_arr_diff[:, -2] = surv_arr_diff[:, -2] + surv_arr_diff[:, -1]
surv_arr_diff = surv_arr_diff[:, :-1]

return surv_arr_diff


def _get_fitted_params_default_safe(obj=None):
"""Obtain fitted params of object, per sklearn convention.
Same as _get_fitted_params_default, but with exception handling.
This is since in sksurv, feature_importances_ is a property
and may raise an exception if the estimator does not have it.
Parameters
----------
obj : any object
Returns
-------
fitted_params : dict with str keys
fitted parameters, keyed by names of fitted parameter
"""
# default retrieves all self attributes ending in "_"
# and returns them with keys that have the "_" removed
#
# get all attributes ending in "_", exclude any that start with "_" (private)
fitted_params = [
attr for attr in dir(obj) if attr.endswith("_") and not attr.startswith("_")
]

def hasattr_safe(obj, attr):
try:
if hasattr(obj, attr):
getattr(obj, attr)
return True
except Exception:
return False

# remove the "_" at the end
fitted_param_dict = {
p[:-1]: getattr(obj, p) for p in fitted_params if hasattr_safe(obj, p)
}

return fitted_param_dict
196 changes: 196 additions & 0 deletions skpro/survival/adapters/lifelines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
# copyright: sktime developers, BSD-3-Clause License (see LICENSE file)
"""Implements adapter for lifelines models."""

__all__ = ["_LifelinesAdapter"]
__author__ = ["fkiraly"]

from warnings import warn

import numpy as np
import pandas as pd

from skpro.distributions.empirical import Empirical
from skpro.survival.adapters._common import _clip_surv, _get_fitted_params_default_safe
from skpro.utils.sklearn import prep_skl_df


class _LifelinesAdapter:
"""Mixin adapter class for lifelines models."""

_tags = {
# packaging info
# --------------
"authors": ["fkiraly"],
"python_dependencies": ["lifelines"],
"license_type": "permissive",
# capability tags
# ---------------
"X_inner_mtype": "pd_DataFrame_Table",
"y_inner_mtype": "pd_DataFrame_Table",
"C_inner_mtype": "pd_DataFrame_Table",
"capability:multioutput": False,
}

# defines the name of the attribute containing the lifelines estimator
_estimator_attr = "_estimator"

def _get_lifelines_class(self):
"""Abstract method to get lifelines class.
should import and return lifelines class
"""
# from lifelines import LifelinesClass
#
# return LifelinesClass
raise NotImplementedError("abstract method")

def _get_lifelines_object(self):
"""Abstract method to initialize lifelines object.
The default initializes result of _get_lifelines_class
with self.get_params.
"""
cls = self._get_lifelines_class()
return cls(**self.get_params())

def _init_lifelines_object(self):
"""Abstract method to initialize lifelines object and set to _estimator_attr.
The default writes the return of _get_lifelines_object to
the attribute of self with name _estimator_attr
"""
cls = self._get_lifelines_object()
setattr(self, self._estimator_attr, cls)
return getattr(self, self._estimator_attr)

def _get_extra_fit_args(self, X, y, C=None):
"""Get extra arguments for the fit method.
Parameters
----------
X : pd.DataFrame
Training features
y: pd.DataFrame
Training labels
C: pd.DataFrame, optional (default=None)
Censoring information for survival analysis.
Returns
-------
dict
Extra arguments for the fit method.
"""
return {}

def _fit(self, X, y, C=None):
"""Fit estimator training data.
Parameters
----------
X : pd.DataFrame
Training features
y: pd.DataFrame
Training labels
C: pd.DataFrame, optional (default=None)
Censoring information for survival analysis.
Returns
-------
self: reference to self
Fitted estimator.
"""
lifelines_est = self._init_lifelines_object()

# input conversion
X = X.astype("float") # lifelines insists on float dtype
X = prep_skl_df(X)

if hasattr(self, "X_col_subset"):
X = X[self.X_col_subset]

to_concat = [X, y]

if C is not None:
C_col = 1 - C.copy() # lifelines uses 1 for uncensored, 0 for censored
C_col.columns = ["__C"]
to_concat.append(C_col)

df = pd.concat(to_concat, axis=1)

self._y_cols = y.columns # remember column names for later
y_name = y.columns[0]

fit_args = {
"df": df,
"duration_col": y_name,
}
if C is not None:
fit_args["event_col"] = "__C"

fit_args.update(self._get_extra_fit_args(X, y, C))

# fit lifelines estimator
lifelines_est.fit(**fit_args)

# write fitted params to self
# some fitted parameters are properties and may raise exceptions
# for example, AIC_ and AIC_partial_ of CoxPHFitter
# to avoid this, we use a safe getter
lifelines_fitted_params = _get_fitted_params_default_safe(lifelines_est)
for k, v in lifelines_fitted_params.items():
setattr(self, f"{k}_", v)

return self

def _predict_proba(self, X):
"""Predict_proba method adapter.
Parameters
----------
X : pd.DataFrame
Features to predict on.
Returns
-------
skpro Empirical distribution
"""
lifelines_est = getattr(self, self._estimator_attr)

# input conversion
X = X.astype("float") # lifelines insists on float dtype
X = prep_skl_df(X)

# predict on X
lifelines_survf = lifelines_est.predict_survival_function(X)

times = lifelines_survf.index

nt = len(times)
mi = pd.MultiIndex.from_product([X.index, range(nt)]).swaplevel()

times_val = np.repeat(times, repeats=len(X))
times_df = pd.DataFrame(times_val, index=mi, columns=self._y_cols)

lifelines_survf_t = np.transpose(lifelines_survf.values)
_, lifelines_survf_t_diff, clipped = _clip_surv(lifelines_survf_t)

if clipped:
warn(
f"Warning from {self.__class__.__name__}: "
f"Interfaced lifelines class {lifelines_est.__class__.__name__} "
"produced improper survival function predictions, i.e., "
"not monotonically decreasing or not in [0, 1]. "
"skpro has clipped the predictions to enforce proper range and "
"valid predictive distributions. "
"However, predictions may still be unreliable.",
stacklevel=2,
)

weights = -lifelines_survf_t_diff.flatten()
weights_df = pd.Series(weights, index=mi)

dist = Empirical(
spl=times_df, weights=weights_df, index=X.index, columns=self._y_cols
)

return dist
Loading

0 comments on commit e2862a5

Please sign in to comment.