Skip to content

Commit

Permalink
Merge pull request #265 from mlr-org/keep_model
Browse files Browse the repository at this point in the history
Support hyperparameter 'model' in surv.rpart
  • Loading branch information
RaphaelS1 authored Mar 31, 2022
2 parents e4145cd + 0bd1d3f commit be96ee3
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 3 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mlr3proba
Title: Probabilistic Supervised Learning for 'mlr3'
Version: 0.4.6
Version: 0.4.7
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# mlr3proba 0.4.7

* Fix bug in {rpart} where model was being discarded when set to be kept. Parameter `model` now called `keep_model`.

# mlr3proba 0.4.6

* Patch for upstream breakages
Expand Down
5 changes: 4 additions & 1 deletion R/LearnerSurvRpart.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#'
#' @description
#' Parameter `xval` is set to 0 in order to save some computation time.
#' Parameter `model` has been renamed to `keep_model`.
#'
#' @references
#' `r format_bib("breiman_1984")`
Expand All @@ -28,7 +29,8 @@ LearnerSurvRpart = R6Class("LearnerSurvRpart",
usesurrogate = p_int(default = 2L, lower = 0L, upper = 2L, tags = "train"),
surrogatestyle = p_int(default = 0L, lower = 0L, upper = 1L, tags = "train"),
xval = p_int(default = 10L, lower = 0L, tags = "train"),
cost = p_uty(tags = "train")
cost = p_uty(tags = "train"),
keep_model = p_lgl(default = FALSE, tags = "train")
)
ps$values = list(xval = 0L)

Expand Down Expand Up @@ -69,6 +71,7 @@ LearnerSurvRpart = R6Class("LearnerSurvRpart",
private = list(
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
names(pv) = replace(names(pv), names(pv) == "keep_model", "model")
if ("weights" %in% task$properties) {
pv = insert_named(pv, list(weights = task$weights$weight))
}
Expand Down
2 changes: 1 addition & 1 deletion man/mlr_learners_surv.coxph.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_learners_surv.rpart.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions tests/testthat/test_mlr_learners_surv_rpart.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,9 @@ test_that("importance/selected", {
expect_silent(learner$selected_features())
expect_silent(learner$importance())
})

test_that("keep_model", {
learner = lrn("surv.rpart", keep_model = TRUE)
learner$train(tsk("rats"))
expect_false(is.null(learner$model$model))
})

0 comments on commit be96ee3

Please sign in to comment.