Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] adapter to lifelines, most distributional survival regressors interfaced #247

Merged
merged 34 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
eecccbd
lifelines
fkiraly Apr 16, 2024
429ffc4
lifelines adapter
fkiraly Apr 16, 2024
2923254
docs, export
fkiraly Apr 16, 2024
54a875e
pyproject
fkiraly Apr 16, 2024
ddb8bbe
documentation
fkiraly Apr 16, 2024
5b26c03
Update __init__.py
fkiraly Apr 16, 2024
ed97ac6
Update __init__.py
fkiraly Apr 16, 2024
11fc89c
Update _aalen_lifelines.py
fkiraly Apr 16, 2024
ba78ad6
linting
fkiraly Apr 17, 2024
5335999
fix improper surv functions via clipping
fkiraly Apr 17, 2024
7fcbb3a
Update lifelines.py
fkiraly Apr 17, 2024
4727a67
Update lifelines.py
fkiraly Apr 17, 2024
b9b3647
deal with all zeroes
fkiraly Apr 17, 2024
ef787c3
move utils to common module
fkiraly Apr 17, 2024
d939689
Update _common.py
fkiraly Apr 17, 2024
54f3006
Update _common.py
fkiraly Apr 17, 2024
cfbecdf
coxph
fkiraly Apr 17, 2024
266485b
Update _coxph_lifelines.py
fkiraly Apr 17, 2024
a38f099
safe get param
fkiraly Apr 17, 2024
b3d1e07
comments
fkiraly Apr 17, 2024
1fc4408
fix comment
fkiraly Apr 17, 2024
091acfc
Update _coxph_lifelines.py
fkiraly Apr 17, 2024
2afaa7f
Update _coxph_lifelines.py
fkiraly Apr 17, 2024
e9f6252
weibull partial work
fkiraly Apr 17, 2024
d254591
weibull without weibull
fkiraly Apr 17, 2024
780c451
Update _aft_lifelines_weibull.py
fkiraly Apr 17, 2024
4e8905f
Merge branch 'main' into lifelines-adapt
fkiraly Apr 17, 2024
42d6253
docs
fkiraly Apr 17, 2024
b0fd1c0
Update _aft_lifelines_weibull.py
fkiraly Apr 17, 2024
e5fb0e3
Update _aft_lifelines_weibull.py
fkiraly Apr 17, 2024
2670aac
Update _aft_lifelines_weibull.py
fkiraly Apr 17, 2024
b9ad46c
Update _aft_lifelines_weibull.py
fkiraly Apr 17, 2024
f169b5f
Update _aft_lifelines_weibull.py
fkiraly Apr 17, 2024
c7142c6
fix broadcasting
fkiraly Apr 17, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions docs/source/api_reference/survival.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,32 @@ Proportional hazards models
:template: class.rst

CoxPH
CoxPHlifelines
CoxPHSkSurv
CoxNet

Accelerated failure time models
-------------------------------

.. currentmodule:: skpro.survival.aft

.. autosummary::
:toctree: auto_generated/
:template: class.rst

AFTWeibull

Generalized additive survival models
------------------------------------

.. currentmodule:: skpro.survival.additive

.. autosummary::
:toctree: auto_generated/
:template: class.rst

AalenAdditiveLifelines

Tree models
-----------

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ dependencies = [
all_extras = [
"attrs",
"distfit",
"lifelines<0.29.0",
"mapie",
"matplotlib>=3.3.2",
"ngboost",
Expand Down
120 changes: 120 additions & 0 deletions skpro/survival/adapters/_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Common utilities for adapters."""

import numpy as np


def _clip_surv(surv_arr):
"""Clips improper survival function values to proper range.

Enforces: values are in [0, 1] and are monotonically decreasing.

First clips to [0, 1], then enforces monotonicity, by replacing
any value with minimum of itself and any previous values.

Parameters
----------
surv_arr : 2D np.ndarray
Survival function values.
index 0 is instance index.
index 1 is time index, increasing.

Returns
-------
surv_arr_clipped : 2D np.ndarray
Clipped survival function values.
surv_arr_diff : 2D np.ndarray, same shape as surv_arr_clipped.
Difference of clipped survival function values.
Same as np.diff(surv_arr, axis=1, prepend=1, append=0),
then summing the last two columns to become one column.
Returned to avoid recomputation, if needed later in context.
clipped : boolean
Whether clipping was needed.
"""
too_large = surv_arr > 1
too_small = surv_arr < 0

surv_arr[too_large] = 1
surv_arr[too_small] = 0

surv_arr_diff = _surv_diff(surv_arr)

# avoid iterative minimization if no further clipping is needed
if not (surv_arr_diff > 0).any():
clipped = too_large.any() or too_small.any()
return surv_arr, surv_arr_diff, clipped

# enforce monotonicity
# iterating from left to right ensures values are replaced
# with minimum of itself and all values to the left
for i in range(1, surv_arr.shape[1]):
surv_arr[:, i] = np.minimum(surv_arr[:, i], surv_arr[:, i - 1])

surv_arr_diff = _surv_diff(surv_arr)

return surv_arr, surv_arr_diff, True


def _surv_diff(surv_arr):
"""Compute difference of survival function values.

Parameters
----------
surv_arr : 2D np.ndarray
Survival function values.
index 0 is instance index.
index 1 is time index, increasing.

Returns
-------
surv_arr_diff : 2D np.ndarray, same shape as surv_arr
Difference of survival function values.
Same as np.diff(surv_arr, axis=1, prepend=1, append=0),
then summing the last two columns to become one column
"""
surv_arr_diff = np.diff(surv_arr, axis=1, prepend=1, append=0)

surv_arr_diff[:, -2] = surv_arr_diff[:, -2] + surv_arr_diff[:, -1]
surv_arr_diff = surv_arr_diff[:, :-1]

return surv_arr_diff


def _get_fitted_params_default_safe(obj=None):
"""Obtain fitted params of object, per sklearn convention.

Same as _get_fitted_params_default, but with exception handling.

This is since in sksurv, feature_importances_ is a property
and may raise an exception if the estimator does not have it.

Parameters
----------
obj : any object

Returns
-------
fitted_params : dict with str keys
fitted parameters, keyed by names of fitted parameter
"""
# default retrieves all self attributes ending in "_"
# and returns them with keys that have the "_" removed
#
# get all attributes ending in "_", exclude any that start with "_" (private)
fitted_params = [
attr for attr in dir(obj) if attr.endswith("_") and not attr.startswith("_")
]

def hasattr_safe(obj, attr):
try:
if hasattr(obj, attr):
getattr(obj, attr)
return True
except Exception:
return False

# remove the "_" at the end
fitted_param_dict = {
p[:-1]: getattr(obj, p) for p in fitted_params if hasattr_safe(obj, p)
}

return fitted_param_dict
196 changes: 196 additions & 0 deletions skpro/survival/adapters/lifelines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
# copyright: sktime developers, BSD-3-Clause License (see LICENSE file)
"""Implements adapter for lifelines models."""

__all__ = ["_LifelinesAdapter"]
__author__ = ["fkiraly"]

from warnings import warn

import numpy as np
import pandas as pd

from skpro.distributions.empirical import Empirical
from skpro.survival.adapters._common import _clip_surv, _get_fitted_params_default_safe
from skpro.utils.sklearn import prep_skl_df


class _LifelinesAdapter:
"""Mixin adapter class for lifelines models."""

_tags = {
# packaging info
# --------------
"authors": ["fkiraly"],
"python_dependencies": ["lifelines"],
"license_type": "permissive",
# capability tags
# ---------------
"X_inner_mtype": "pd_DataFrame_Table",
"y_inner_mtype": "pd_DataFrame_Table",
"C_inner_mtype": "pd_DataFrame_Table",
"capability:multioutput": False,
}

# defines the name of the attribute containing the lifelines estimator
_estimator_attr = "_estimator"

def _get_lifelines_class(self):
"""Abstract method to get lifelines class.

should import and return lifelines class
"""
# from lifelines import LifelinesClass
#
# return LifelinesClass
raise NotImplementedError("abstract method")

def _get_lifelines_object(self):
"""Abstract method to initialize lifelines object.

The default initializes result of _get_lifelines_class
with self.get_params.
"""
cls = self._get_lifelines_class()
return cls(**self.get_params())

def _init_lifelines_object(self):
"""Abstract method to initialize lifelines object and set to _estimator_attr.

The default writes the return of _get_lifelines_object to
the attribute of self with name _estimator_attr
"""
cls = self._get_lifelines_object()
setattr(self, self._estimator_attr, cls)
return getattr(self, self._estimator_attr)

def _get_extra_fit_args(self, X, y, C=None):
"""Get extra arguments for the fit method.

Parameters
----------
X : pd.DataFrame
Training features
y: pd.DataFrame
Training labels
C: pd.DataFrame, optional (default=None)
Censoring information for survival analysis.

Returns
-------
dict
Extra arguments for the fit method.
"""
return {}

def _fit(self, X, y, C=None):
"""Fit estimator training data.

Parameters
----------
X : pd.DataFrame
Training features
y: pd.DataFrame
Training labels
C: pd.DataFrame, optional (default=None)
Censoring information for survival analysis.

Returns
-------
self: reference to self
Fitted estimator.
"""
lifelines_est = self._init_lifelines_object()

# input conversion
X = X.astype("float") # lifelines insists on float dtype
X = prep_skl_df(X)

if hasattr(self, "X_col_subset"):
X = X[self.X_col_subset]

to_concat = [X, y]

if C is not None:
C_col = 1 - C.copy() # lifelines uses 1 for uncensored, 0 for censored
C_col.columns = ["__C"]
to_concat.append(C_col)

df = pd.concat(to_concat, axis=1)

self._y_cols = y.columns # remember column names for later
y_name = y.columns[0]

fit_args = {
"df": df,
"duration_col": y_name,
}
if C is not None:
fit_args["event_col"] = "__C"

fit_args.update(self._get_extra_fit_args(X, y, C))

# fit lifelines estimator
lifelines_est.fit(**fit_args)

# write fitted params to self
# some fitted parameters are properties and may raise exceptions
# for example, AIC_ and AIC_partial_ of CoxPHFitter
# to avoid this, we use a safe getter
lifelines_fitted_params = _get_fitted_params_default_safe(lifelines_est)
for k, v in lifelines_fitted_params.items():
setattr(self, f"{k}_", v)

return self

def _predict_proba(self, X):
"""Predict_proba method adapter.

Parameters
----------
X : pd.DataFrame
Features to predict on.

Returns
-------
skpro Empirical distribution
"""
lifelines_est = getattr(self, self._estimator_attr)

# input conversion
X = X.astype("float") # lifelines insists on float dtype
X = prep_skl_df(X)

# predict on X
lifelines_survf = lifelines_est.predict_survival_function(X)

times = lifelines_survf.index

nt = len(times)
mi = pd.MultiIndex.from_product([X.index, range(nt)]).swaplevel()

times_val = np.repeat(times, repeats=len(X))
times_df = pd.DataFrame(times_val, index=mi, columns=self._y_cols)

lifelines_survf_t = np.transpose(lifelines_survf.values)
_, lifelines_survf_t_diff, clipped = _clip_surv(lifelines_survf_t)

if clipped:
warn(
f"Warning from {self.__class__.__name__}: "
f"Interfaced lifelines class {lifelines_est.__class__.__name__} "
"produced improper survival function predictions, i.e., "
"not monotonically decreasing or not in [0, 1]. "
"skpro has clipped the predictions to enforce proper range and "
"valid predictive distributions. "
"However, predictions may still be unreliable.",
stacklevel=2,
)

weights = -lifelines_survf_t_diff.flatten()
weights_df = pd.Series(weights, index=mi)

dist = Empirical(
spl=times_df, weights=weights_df, index=X.index, columns=self._y_cols
)

return dist
Loading
Loading