Skip to content

Commit

Permalink
Merge pull request #64 from JaGarRod/add_custom_family_parameter_to_r…
Browse files Browse the repository at this point in the history
…egr_glmboost

add custom.family parameter to learner_mboost_regr_glmboost.R
  • Loading branch information
Raphael Sonabend authored Mar 13, 2021
2 parents e8c4f02 + 92f5257 commit 05ac990
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 15 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mlr3extralearners
Title: Extra Learners For mlr3
Version: 0.3.3
Version: 0.3.4
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# mlr3extralearners 0.3.4

* Add support for custom families to `regr.glmboost`

# mlr3extralearners 0.3.1

* `surv.svm` now supports all feature types
Expand Down
34 changes: 20 additions & 14 deletions R/learner_mboost_regr_glmboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ LearnerRegrGLMBoost = R6Class("LearnerRegrGLMBoost",
id = "family", default = c("Gaussian"),
levels = c(
"Gaussian", "Laplace", "Huber", "Poisson",
"GammaReg", "NBinomial", "Hurdle"), tags = "train"),
"GammaReg", "NBinomial", "Hurdle", "custom"),
tags = "train"),
ParamUty$new(id = "custom.family", tags = "train"),
ParamUty$new(id = "nuirange", default = c(0, 100), tags = "train"),
ParamDbl$new(
id = "d", default = NULL, special_vals = list(NULL),
Expand Down Expand Up @@ -79,10 +81,13 @@ LearnerRegrGLMBoost = R6Class("LearnerRegrGLMBoost",
methods::formalArgs(mboost::boost_control))]
pars_glmboost = pars[which(names(pars) %in%
methods::formalArgs(mboost::gamboost))]
pars_family = pars[which(names(pars) %in%
methods::formalArgs(utils::getFromNamespace(
pars_glmboost$family,
asNamespace("mboost"))))]

if (self$param_set$values$family != "custom") {
pars_family = pars[which(names(pars) %in%
methods::formalArgs(utils::getFromNamespace(
pars_glmboost$family,
asNamespace("mboost"))))]
}

f = task$formula()
data = task$data()
Expand All @@ -94,19 +99,20 @@ LearnerRegrGLMBoost = R6Class("LearnerRegrGLMBoost",
}

pars_glmboost$family = switch(pars$family,
Gaussian = mboost::Gaussian(),
Laplace = mboost::Laplace(),
Huber = invoke(mboost::Huber, .args = pars_family),
Poisson = mboost::Poisson(),
GammaReg = invoke(mboost::GammaReg, .args = pars_family),
NBinomial = invoke(mboost::NBinomial, .args = pars_family),
Hurdle = invoke(mboost::Hurdle, .args = pars_family)
Gaussian = mboost::Gaussian(),
Laplace = mboost::Laplace(),
Huber = invoke(mboost::Huber, .args = pars_family),
Poisson = mboost::Poisson(),
GammaReg = invoke(mboost::GammaReg, .args = pars_family),
NBinomial = invoke(mboost::NBinomial, .args = pars_family),
Hurdle = invoke(mboost::Hurdle, .args = pars_family),
custom = pars$custom.family
)

ctrl = invoke(mboost::boost_control, .args = pars_boost)
invoke(mboost::glmboost, f,
data = data, control = ctrl,
.args = pars_glmboost)
data = data, control = ctrl,
.args = pars_glmboost)
},

.predict = function(task) {
Expand Down

0 comments on commit 05ac990

Please sign in to comment.