diff --git a/lifelines/fitters/cox_time_varying_fitter.py b/lifelines/fitters/cox_time_varying_fitter.py index d58cb4112..db0e1cb0f 100644 --- a/lifelines/fitters/cox_time_varying_fitter.py +++ b/lifelines/fitters/cox_time_varying_fitter.py @@ -295,6 +295,15 @@ def _compute_confidence_intervals(self): index=self.params_.index, ) + def calculate_survival_function(self, X: pd.DataFrame) -> pd.Series: + v = self.predict_partial_hazard(X) + col = utils._get_index(v) + times_ = X.start.values + c_0 = utils.interpolate_at_times(self.baseline_cumulative_hazard_, [times_]).T + cumulative_hazard_ = pd.DataFrame(c_0 * v.values, columns=col, index=times_) + surv_func = pd.Series(np.exp(-cumulative_hazard_).values.diagonal(), index=times_) + return surv_func + @property def summary(self): """Summary statistics describing the fit.