From 5b32e6ef110184b6854df9ae5e4c7b800106da61 Mon Sep 17 00:00:00 2001 From: mb706 Date: Sat, 5 Mar 2016 23:10:04 +0100 Subject: [PATCH] Fix #201: focussearch now handles discrete vector parameters also general focussearch bugfixes --- R/infillOptFocus.R | 77 +++++++++++++++++++------- tests/testthat/test_infill_opt_focus.R | 5 ++ 2 files changed, 63 insertions(+), 19 deletions(-) diff --git a/R/infillOptFocus.R b/R/infillOptFocus.R index e66485f18..abaf0f6e9 100644 --- a/R/infillOptFocus.R +++ b/R/infillOptFocus.R @@ -13,6 +13,14 @@ infillOptFocus = function(infill.crit, models, control, par.set, opt.path, desig for (restart.iter in seq_len(control$infill.opt.restarts)) { # copy parset so we can shrink it ps.local = par.set + + # handle discrete vectors: + # The problem is that for discrete vectors, we can't adjust the values dimension-wise. + # Therefore, for discrete vectors, we always drop the last level and instead have a + # mapping that maps, for each discrete vector param and for each dimension, from + # the sampled value (levels 1 to n - #(dropped levels)) to levels with random dropouts. + discreteVectorMapping = lapply(filterParams(par.set, type = c("discretevector", "logicalvector"))$pars, + function(param) rep(list(setNames(names(param$values), names(param$values))), param$len)) # do iterations where we focus the region-of-interest around the current best point for (local.iter in seq_len(control$infill.opt.focussearch.maxit)) { @@ -21,13 +29,22 @@ infillOptFocus = function(infill.crit, models, control, par.set, opt.path, desig # convert to param encoding our model was trained on and can use newdesign = convertDataFrameCols(newdesign, ints.as.num = TRUE, logicals.as.factor = TRUE) - y = infill.crit(newdesign, models, control, ps.local, design, iter, ...) + + # handle discrete vectors + for (dvparam in filterParams(par.set, type = c("discretevector", "logicalvector"))$pars) { + for (dimnum in seq_len(dvparam$len)) { + dfindex = paste0(dvparam$id, dimnum) + mapping = discreteVectorMapping[[dvparam$id]][[dimnum]] + levels(newdesign[[dfindex]]) = mapping[levels(newdesign[[dfindex]])] + } + } + y = infill.crit(newdesign, models, control, par.set, design, iter, ...) # get current best value local.index = getMinIndex(y, ties.method = "random") local.y = y[local.index] local.x.df = newdesign[local.index, , drop = FALSE] - local.x.list = dfRowToList(recodeTypes(local.x.df, ps.local), ps.local, 1) + local.x.list = dfRowToList(recodeTypes(local.x.df, par.set), par.set, 1) # if we found a new best value, store it if (local.y < global.y) { @@ -39,24 +56,46 @@ infillOptFocus = function(infill.crit, models, control, par.set, opt.path, desig ps.local$pars = lapply(ps.local$pars, function(par) { # only shrink when there is a value val = local.x.list[[par$id]] - if (!isScalarNA(val)) { - if (isNumeric(par)) { - # shrink to range / 2, centered at val - range = par$upper - par$lower - par$lower = pmax(par$lower, val - (range / 4)) - par$upper = pmin(par$upper, val + (range / 4)) - if (isInteger(par)) { - par$lower = floor(par$lower) - par$upper = ceiling(par$upper) + if (isScalarNA(val)) { + return(par) + } + if (isNumeric(par)) { + # shrink to range / 2, centered at val + range = par$upper - par$lower + par$lower = pmax(par$lower, val - (range / 4)) + par$upper = pmin(par$upper, val + (range / 4)) + if (isInteger(par)) { + par$lower = floor(par$lower) + par$upper = ceiling(par$upper) + } + } else if (isDiscrete(par)) { + # randomly drop a level, which is not val + if (length(par$values) <= 1L) { + return(par) + } + # need to do some magic to handle discrete vectors + if (par$type %nin% c("discretevector", "logicalvector")) { + val.names = names(par$values) + # remove current val from delete options, should work also for NA + val.names = val.names[!sapply(par$values, identical, y=val)] # remember, 'val' may not even be a character + to.del = sample(val.names, 1) + par$values[to.del] = NULL + } else { + # we remove the last element of par$values and a random element for + # each dimension in discreteVectorMapping. + par$values = par$values[-length(par$values)] + if (par$type != "logicalvector") { + # for discretevectorparam val would be a list; convert to character vector + val = names(val) } - } else if (isDiscrete(par)) { - # randomly drop a level, which is not val - if (length(par$values) > 1L) { - val.names = names(par$values) - # remove current val from delete options, should work also for NA - val.names = setdiff(val.names, val) - to.del = sample(seq_along(val.names), 1) - par$values = par$values[-to.del] + for (dimnum in seq_len(par$len)) { + val.names = discreteVectorMapping[[par$id]][[dimnum]] + newmap = val.names + val.names = val.names[val.names != val[dimnum]] + to.del = sample(val.names, 1) + newmap = newmap[newmap != to.del] + names(newmap) = names(par$values) + discreteVectorMapping[[par$id]][[dimnum]] <<- newmap } } } diff --git a/tests/testthat/test_infill_opt_focus.R b/tests/testthat/test_infill_opt_focus.R index 3f9ea5483..b0d5d6990 100644 --- a/tests/testthat/test_infill_opt_focus.R +++ b/tests/testthat/test_infill_opt_focus.R @@ -68,6 +68,9 @@ test_that("complex param space, dependencies, focusing, restarts", { if(x$disc2 == 'a') tmp3 = log(x$realA) + x$intA^4 + ifelse(x$discA == 'm', 5, 0) if(x$disc2 == 'b') tmp3 = exp(x$realB) + ifelse(x$discB == 'R', sin(x$realBR), sin(x$realBNR)) if(x$disc2 == "c") tmp3 = 500 + assert(is.list(x$discVec)) + assert(x$discVec[[1]] %in% c("a", "b", "c")) + assert(x$discScal %in% c("x", "y", "z")) tmp1 + tmp2 + tmp3 }, par.set = makeParamSet( @@ -75,6 +78,8 @@ test_that("complex param space, dependencies, focusing, restarts", { makeIntegerParam("int1", lower = -100, upper = 100), makeNumericVectorParam("realVec", len = 10, lower = -50, upper = 50), makeIntegerVectorParam("intVec", len = 3, lower = 0, upper = 100), + makeDiscreteVectorParam("discVec", len = 3, c(x = "a", y = "b", z = "c")), + makeDiscreteParam("discScal", c(a = "x", b = "y", c = "z")), makeNumericParam("real2", lower = -1, upper = 1), makeDiscreteParam("disc1", values = c("foo", "bar"), requires = quote(real2 < 0)), makeNumericParam("real3", lower = -100, upper = 100, requires = quote(real2 > 0)),