From e2dfcdf20da9b59f04b3dda01aef29c02dce86a9 Mon Sep 17 00:00:00 2001 From: adibender Date: Tue, 23 Jun 2020 19:44:13 +0200 Subject: [PATCH] Update description + minor updates --- DESCRIPTION | 10 ++++------ R/pec-helpers.R | 10 +++++----- R/predict.R | 15 +++------------ 3 files changed, 12 insertions(+), 23 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 1c10163..c4e457c 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,9 +1,9 @@ -Package: pamm.xgb -Title: Advanced survival analysis using xgboost -Version: 0.1.0 +Package: pem.xgb +Title: GBT (PEM) +Version: 0.1.1 Date: 2020-01-07 Authors@R: person("Andreas", "Bender", , "andreas.bender@stat.uni-muenchen.de", role = c("aut", "cre"), comment=c(ORCID = "0000-0001-5628-8611")) -Description: Estimate piece-wise exponential models using gradient based algorithms. +Description: Estimate PEMs using XGBoost. Prototype implementation for Paper accepted at ECML 2020. Depends: R (>= 3.6.0) Imports: caret, @@ -22,8 +22,6 @@ Suggests: knitr, rmarkdown, testthat -Remotes: - adibender/pammtools License: MIT + file LICENSE Encoding: UTF-8 LazyData: true diff --git a/R/pec-helpers.R b/R/pec-helpers.R index ad2ecc7..f0b9158 100644 --- a/R/pec-helpers.R +++ b/R/pec-helpers.R @@ -18,7 +18,7 @@ get_ibs <- function( q_eval = c(.25, .5, .75), ...) { - if (class(data) == "list") { + if (class(data)[1] == "list") { status_values <- sort(unique(data[[1]]$status)) } else { status_values <- sort(unique(data$status)) @@ -32,15 +32,15 @@ get_ibs <- function( times <- times_list(data, q_last = q_last, q_eval = q_eval, status_value = .x) time_seq <- seq(.01, times$t_last, length.out = 500L) - if (class(data) == "list") { + if (class(data)[1] == "list") { if(class(object)[1] == "list") { pred <- predictSurvProb(object[[1]], data, times = time_seq) } else { pred <- predictSurvProb(object, data, times = time_seq) } } - pec_params$object <- if (class(data) == "list") {list("pam_xgb" = pred)} else {object} - pec_params$data <- if (class(data) == "list") { data[[1]]} else{ data } + pec_params$object <- if (class(data)[1] == "list") {list("pam_xgb" = pred)} else {object} + pec_params$data <- if (class(data)[1] == "list") { data[[1]]} else{ data } pec_params$times <- time_seq pec_params$cause <- .x @@ -80,7 +80,7 @@ get_ibs <- function( times_list <- function(data, q_last = .8, q_eval = c(.25, .5, .75), time_var = "time", status_var = "status", status_value = 1) { - if (class(data) == "list") { + if (class(data)[1] == "list") { data <- data[[1]] } diff --git a/R/predict.R b/R/predict.R index 8155249..714699e 100644 --- a/R/predict.R +++ b/R/predict.R @@ -20,12 +20,10 @@ predict.pam_xgb <- function( type = c("hazard", "cumu_hazard", "surv_prob"), ...) { - ## TODO: this function actually doesn't work if data is xgb.DMatrix type <- match.arg(type) brks <- attr(object, "attr_ped")$breaks # attr(object, "attr_ped") <- c(attr(object, "attr_ped"), status_var = "status") - ## TODO: catch case where newdata either xgb.DMatrix or data.frame in ped format ped_newdata <- as_ped(object, newdata) vars <- setdiff( attr(ped_newdata, "names"), @@ -37,7 +35,7 @@ predict.pam_xgb <- function( if (type == "cumu_hazard") { ped_newdata <- ped_newdata %>% group_by(.data$id) %>% - mutate(pred = cumsum(.data$pred * exp(.data$offset)))#TODO: is it correct to use offset here? + mutate(pred = cumsum(.data$pred * exp(.data$offset))) } if (type == "surv_prob") { ped_newdata <- ped_newdata %>% @@ -48,19 +46,16 @@ predict.pam_xgb <- function( ped_newdata %>% group_by(.data[["id"]]) %>% filter(row_number() == n()) %>% - pull(.data[["pred"]]) # TODO: is the hazard/surv prob in the last available interval a useful return? + pull(.data[["pred"]]) } # check for time-dependent covariates. -#TODO seems like this should just be a flag in attr_ped? has_tdc <- function(model) { any(c("ccr", "func") %in% names(attributes(model)[["attr_ped"]])) } get_new_ped <- function(object, newdata, times, attr_ped) { - ## TODO: create ped_info without creating a ped_newdata object first - # extract vars used in model fit covars <- setdiff(attr_ped[["names"]], attr_ped[["intvars"]]) if ("tend" %in% object$feature_names) { vars <- c("tend", covars) @@ -77,7 +72,7 @@ get_new_ped <- function(object, newdata, times, attr_ped) { ped_info[["offset"]] <- c(ped_info[["times"]][1], diff(ped_info[["times"]])) # create data set with interval/time + covariate info - newdata[[id_var]] <- seq_len(nrow(newdata)) #TODO: potentially overwriting it here seems dangerous? + newdata[[id_var]] <- seq_len(nrow(newdata)) new_ped <- pammtools::combine_df(ped_info, newdata[, c(id_var, covars)]) new_ped$ped_status <- 1 new_ped @@ -113,7 +108,6 @@ get_all_intervals <- function( #' @importFrom tidyr fill get_new_ped_tdc <- function(object, newdata, times, attr_ped) { - ## TODO: create ped_info without creating a ped_newdata object first # extract vars used in model fit covars <- setdiff(attr_ped[["names"]], attr_ped[["intvars"]]) id_var <- attr_ped[["id_var"]] @@ -126,9 +120,6 @@ get_new_ped_tdc <- function(object, newdata, times, attr_ped) { # avoid assumption that newdata[[2]] is a single data.frame with TDCs on # the same time grid. could also be a list of data.frames (?) if (!is.data.frame(newdata[[2]])) { - # TODO: rename all ccr_tz_vars to "times", - # do a full join over all TDCs, - # then fill up NAs by LCVF stop("multiple time scales for TDCs not implemented yet.") } #drop ids for which no time constant info is available: