Skip to content

Commit

Permalink
Update description + minor updates
Browse files Browse the repository at this point in the history
  • Loading branch information
adibender committed Jun 23, 2020
1 parent 17f0550 commit e2dfcdf
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 23 deletions.
10 changes: 4 additions & 6 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
Package: pamm.xgb
Title: Advanced survival analysis using xgboost
Version: 0.1.0
Package: pem.xgb
Title: GBT (PEM)
Version: 0.1.1
Date: 2020-01-07
Authors@R: person("Andreas", "Bender", , "[email protected]", role = c("aut", "cre"), comment=c(ORCID = "0000-0001-5628-8611"))
Description: Estimate piece-wise exponential models using gradient based algorithms.
Description: Estimate PEMs using XGBoost. Prototype implementation for a paper accepted at ECML 2020.
Depends: R (>= 3.6.0)
Imports:
caret,
Expand All @@ -22,8 +22,6 @@ Suggests:
knitr,
rmarkdown,
testthat
Remotes:
adibender/pammtools
License: MIT + file LICENSE
Encoding: UTF-8
LazyData: true
Expand Down
10 changes: 5 additions & 5 deletions R/pec-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ get_ibs <- function(
q_eval = c(.25, .5, .75),
...) {

if (class(data) == "list") {
if (class(data)[1] == "list") {
status_values <- sort(unique(data[[1]]$status))
} else {
status_values <- sort(unique(data$status))
Expand All @@ -32,15 +32,15 @@ get_ibs <- function(
times <- times_list(data, q_last = q_last, q_eval = q_eval,
status_value = .x)
time_seq <- seq(.01, times$t_last, length.out = 500L)
if (class(data) == "list") {
if (class(data)[1] == "list") {
if(class(object)[1] == "list") {
pred <- predictSurvProb(object[[1]], data, times = time_seq)
} else {
pred <- predictSurvProb(object, data, times = time_seq)
}
}
pec_params$object <- if (class(data) == "list") {list("pam_xgb" = pred)} else {object}
pec_params$data <- if (class(data) == "list") { data[[1]]} else{ data }
pec_params$object <- if (class(data)[1] == "list") {list("pam_xgb" = pred)} else {object}
pec_params$data <- if (class(data)[1] == "list") { data[[1]]} else{ data }
pec_params$times <- time_seq
pec_params$cause <- .x

Expand Down Expand Up @@ -80,7 +80,7 @@ get_ibs <- function(
times_list <- function(data, q_last = .8, q_eval = c(.25, .5, .75),
time_var = "time", status_var = "status", status_value = 1) {

if (class(data) == "list") {
if (class(data)[1] == "list") {
data <- data[[1]]
}

Expand Down
15 changes: 3 additions & 12 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,10 @@ predict.pam_xgb <- function(
type = c("hazard", "cumu_hazard", "surv_prob"),
...) {

## TODO: this function actually doesn't work if data is xgb.DMatrix
type <- match.arg(type)
brks <- attr(object, "attr_ped")$breaks
# attr(object, "attr_ped") <- c(attr(object, "attr_ped"), status_var = "status")

## TODO: catch case where newdata either xgb.DMatrix or data.frame in ped format
ped_newdata <- as_ped(object, newdata)
vars <- setdiff(
attr(ped_newdata, "names"),
Expand All @@ -37,7 +35,7 @@ predict.pam_xgb <- function(
if (type == "cumu_hazard") {
ped_newdata <- ped_newdata %>%
group_by(.data$id) %>%
mutate(pred = cumsum(.data$pred * exp(.data$offset)))#TODO: is it correct to use offset here?
mutate(pred = cumsum(.data$pred * exp(.data$offset)))
}
if (type == "surv_prob") {
ped_newdata <- ped_newdata %>%
Expand All @@ -48,19 +46,16 @@ predict.pam_xgb <- function(
ped_newdata %>%
group_by(.data[["id"]]) %>%
filter(row_number() == n()) %>%
pull(.data[["pred"]]) # TODO: is the hazard/surv prob in the last available interval a useful return?
pull(.data[["pred"]])

}

# check for time-dependent covariates.
#TODO seems like this should just be a flag in attr_ped?
has_tdc <- function(model) {
  # Pull the "attr_ped" attribute stored on the model (NULL when absent).
  ped_attributes <- attributes(model)[["attr_ped"]]
  # Entry names whose presence flags time-dependent covariates.
  tdc_markers <- c("ccr", "func")
  # TRUE as soon as at least one marker name occurs among the entries.
  length(intersect(tdc_markers, names(ped_attributes))) > 0L
}

get_new_ped <- function(object, newdata, times, attr_ped) {
## TODO: create ped_info without creating a ped_newdata object first
# extract vars used in model fit
covars <- setdiff(attr_ped[["names"]], attr_ped[["intvars"]])
if ("tend" %in% object$feature_names) {
vars <- c("tend", covars)
Expand All @@ -77,7 +72,7 @@ get_new_ped <- function(object, newdata, times, attr_ped) {
ped_info[["offset"]] <- c(ped_info[["times"]][1], diff(ped_info[["times"]]))

# create data set with interval/time + covariate info
newdata[[id_var]] <- seq_len(nrow(newdata)) #TODO: potentially overwriting it here seems dangerous?
newdata[[id_var]] <- seq_len(nrow(newdata))
new_ped <- pammtools::combine_df(ped_info, newdata[, c(id_var, covars)])
new_ped$ped_status <- 1
new_ped
Expand Down Expand Up @@ -113,7 +108,6 @@ get_all_intervals <- function(

#' @importFrom tidyr fill
get_new_ped_tdc <- function(object, newdata, times, attr_ped) {
## TODO: create ped_info without creating a ped_newdata object first
# extract vars used in model fit
covars <- setdiff(attr_ped[["names"]], attr_ped[["intvars"]])
id_var <- attr_ped[["id_var"]]
Expand All @@ -126,9 +120,6 @@ get_new_ped_tdc <- function(object, newdata, times, attr_ped) {
# avoid assumption that newdata[[2]] is a single data.frame with TDCs on
# the same time grid. could also be a list of data.frames (?)
if (!is.data.frame(newdata[[2]])) {
# TODO: rename all ccr_tz_vars to "times",
# do a full join over all TDCs,
# then fill up NAs by LCVF
stop("multiple time scales for TDCs not implemented yet.")
}
#drop ids for which no time constant info is available:
Expand Down

0 comments on commit e2dfcdf

Please sign in to comment.