Skip to content

Commit b3c9354

Browse files
mb706 berndbischl
mb706
authored and committed
Fix #201: focussearch now handles discrete vector parameters
also general focussearch bugfixes
1 parent 7de7273 commit b3c9354

File tree

2 files changed

+63
-19
lines changed

2 files changed

+63
-19
lines changed

R/infillOptFocus.R

+58-19
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,14 @@ infillOptFocus = function(infill.crit, models, control, par.set, opt.path, desig
1313
for (restart.iter in seq_len(control$infill.opt.restarts)) {
1414
# copy parset so we can shrink it
1515
ps.local = par.set
16+
17+
# handle discrete vectors:
18+
# The problem is that for discrete vectors, we can't adjust the values dimension-wise.
19+
# Therefore, for discrete vectors, we always drop the last level and instead have a
20+
# mapping that maps, for each discrete vector param and for each dimension, from
21+
# the sampled value (levels 1 to n - #(dropped levels)) to levels with random dropouts.
22+
discreteVectorMapping = lapply(filterParams(par.set, type = c("discretevector", "logicalvector"))$pars,
23+
function(param) rep(list(setNames(names(param$values), names(param$values))), param$len))
1624

1725
# do iterations where we focus the region-of-interest around the current best point
1826
for (local.iter in seq_len(control$infill.opt.focussearch.maxit)) {
@@ -21,13 +29,22 @@ infillOptFocus = function(infill.crit, models, control, par.set, opt.path, desig
2129

2230
# convert to param encoding our model was trained on and can use
2331
newdesign = convertDataFrameCols(newdesign, ints.as.num = TRUE, logicals.as.factor = TRUE)
24-
y = infill.crit(newdesign, models, control, ps.local, design, iter, ...)
32+
33+
# handle discrete vectors
34+
for (dvparam in filterParams(par.set, type = c("discretevector", "logicalvector"))$pars) {
35+
for (dimnum in seq_len(dvparam$len)) {
36+
dfindex = paste0(dvparam$id, dimnum)
37+
mapping = discreteVectorMapping[[dvparam$id]][[dimnum]]
38+
levels(newdesign[[dfindex]]) = mapping[levels(newdesign[[dfindex]])]
39+
}
40+
}
41+
y = infill.crit(newdesign, models, control, par.set, design, iter, ...)
2542

2643
# get current best value
2744
local.index = getMinIndex(y, ties.method = "random")
2845
local.y = y[local.index]
2946
local.x.df = newdesign[local.index, , drop = FALSE]
30-
local.x.list = dfRowToList(recodeTypes(local.x.df, ps.local), ps.local, 1)
47+
local.x.list = dfRowToList(recodeTypes(local.x.df, par.set), par.set, 1)
3148

3249
# if we found a new best value, store it
3350
if (local.y < global.y) {
@@ -39,24 +56,46 @@ infillOptFocus = function(infill.crit, models, control, par.set, opt.path, desig
3956
ps.local$pars = lapply(ps.local$pars, function(par) {
4057
# only shrink when there is a value
4158
val = local.x.list[[par$id]]
42-
if (!isScalarNA(val)) {
43-
if (isNumeric(par)) {
44-
# shrink to range / 2, centered at val
45-
range = par$upper - par$lower
46-
par$lower = pmax(par$lower, val - (range / 4))
47-
par$upper = pmin(par$upper, val + (range / 4))
48-
if (isInteger(par)) {
49-
par$lower = floor(par$lower)
50-
par$upper = ceiling(par$upper)
59+
if (isScalarNA(val)) {
60+
return(par)
61+
}
62+
if (isNumeric(par)) {
63+
# shrink to range / 2, centered at val
64+
range = par$upper - par$lower
65+
par$lower = pmax(par$lower, val - (range / 4))
66+
par$upper = pmin(par$upper, val + (range / 4))
67+
if (isInteger(par)) {
68+
par$lower = floor(par$lower)
69+
par$upper = ceiling(par$upper)
70+
}
71+
} else if (isDiscrete(par)) {
72+
# randomly drop a level, which is not val
73+
if (length(par$values) <= 1L) {
74+
return(par)
75+
}
76+
# need to do some magic to handle discrete vectors
77+
if (par$type %nin% c("discretevector", "logicalvector")) {
78+
val.names = names(par$values)
79+
# remove current val from delete options, should work also for NA
80+
val.names = val.names[!sapply(par$values, identical, y=val)] # remember, 'val' may not even be a character
81+
to.del = sample(val.names, 1)
82+
par$values[to.del] = NULL
83+
} else {
84+
# we remove the last element of par$values and a random element for
85+
# each dimension in discreteVectorMapping.
86+
par$values = par$values[-length(par$values)]
87+
if (par$type != "logicalvector") {
88+
# for discretevectorparam val would be a list; convert to character vector
89+
val = names(val)
5190
}
52-
} else if (isDiscrete(par)) {
53-
# randomly drop a level, which is not val
54-
if (length(par$values) > 1L) {
55-
val.names = names(par$values)
56-
# remove current val from delete options, should work also for NA
57-
val.names = setdiff(val.names, val)
58-
to.del = sample(seq_along(val.names), 1)
59-
par$values = par$values[-to.del]
91+
for (dimnum in seq_len(par$len)) {
92+
val.names = discreteVectorMapping[[par$id]][[dimnum]]
93+
newmap = val.names
94+
val.names = val.names[val.names != val[dimnum]]
95+
to.del = sample(val.names, 1)
96+
newmap = newmap[newmap != to.del]
97+
names(newmap) = names(par$values)
98+
discreteVectorMapping[[par$id]][[dimnum]] <<- newmap
6099
}
61100
}
62101
}

tests/testthat/test_infill_opt_focus.R

+5
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,18 @@ test_that("complex param space, dependencies, focusing, restarts", {
6868
if(x$disc2 == 'a') tmp3 = log(x$realA) + x$intA^4 + ifelse(x$discA == 'm', 5, 0)
6969
if(x$disc2 == 'b') tmp3 = exp(x$realB) + ifelse(x$discB == 'R', sin(x$realBR), sin(x$realBNR))
7070
if(x$disc2 == "c") tmp3 = 500
71+
assert(is.list(x$discVec))
72+
assert(x$discVec[[1]] %in% c("a", "b", "c"))
73+
assert(x$discScal %in% c("x", "y", "z"))
7174
tmp1 + tmp2 + tmp3
7275
},
7376
par.set = makeParamSet(
7477
makeNumericParam("real1", lower = 0, upper = 1000),
7578
makeIntegerParam("int1", lower = -100, upper = 100),
7679
makeNumericVectorParam("realVec", len = 10, lower = -50, upper = 50),
7780
makeIntegerVectorParam("intVec", len = 3, lower = 0, upper = 100),
81+
makeDiscreteVectorParam("discVec", len = 3, c(x = "a", y = "b", z = "c")),
82+
makeDiscreteParam("discScal", c(a = "x", b = "y", c = "z")),
7883
makeNumericParam("real2", lower = -1, upper = 1),
7984
makeDiscreteParam("disc1", values = c("foo", "bar"), requires = quote(real2 < 0)),
8085
makeNumericParam("real3", lower = -100, upper = 100, requires = quote(real2 > 0)),

0 commit comments

Comments
 (0)