diff --git a/lifelines/fitters/cox_time_varying_fitter.py b/lifelines/fitters/cox_time_varying_fitter.py index d58cb4112..db0e1cb0f 100644 --- a/lifelines/fitters/cox_time_varying_fitter.py +++ b/lifelines/fitters/cox_time_varying_fitter.py @@ -295,6 +295,15 @@ def _compute_confidence_intervals(self): index=self.params_.index, ) + def calculate_survival_function(self, X: pd.DataFrame) -> pd.Series: + v = self.predict_partial_hazard(X) + col = utils._get_index(v) + times_ = X.start.values + c_0 = utils.interpolate_at_times(self.baseline_cumulative_hazard_, [times_]).T + cumulative_hazard_ = pd.DataFrame(c_0 * v.values, columns=col, index=times_) + surv_func = pd.Series(np.exp(-cumulative_hazard_).values.diagonal(), index=times_) + return surv_func + @property def summary(self): """Summary statistics describing the fit. diff --git a/lifelines/tests/test_estimation.py b/lifelines/tests/test_estimation.py index abf1d5f97..defbf3b50 100644 --- a/lifelines/tests/test_estimation.py +++ b/lifelines/tests/test_estimation.py @@ -5134,6 +5134,14 @@ def test_inference_against_known_R_output(self, ctv, dfcv): npt.assert_almost_equal(ctv.summary["se(coef)"].values, [1.229, 1.206], decimal=3) npt.assert_almost_equal(ctv.summary["p"].values, [0.14, 0.56], decimal=2) + def test_survival_function(self, ctv, dfcv): + sf_validation = [0.866, 0.866, 0.748, 0.866, 0.866, 0.277, 0.954, + 0.977, 0.902, 0.977, 0.977, 0.902, 0.977, 0.954] + ctv.fit(dfcv, id_col="id", start_col="start", stop_col="stop", event_col="event") + sf = ctv.calculate_survival_function(dfcv) + assert sf.index.values is dfcv.start.values + npt.assert_almost_equal(sf.values, sf_validation, decimal=2) + def test_that_id_col_is_optional(self, dfcv): ctv_with_id = CoxTimeVaryingFitter().fit(dfcv, id_col="id", start_col="start", stop_col="stop", event_col="event")