Skip to content

Commit 0c00d7a

Browse files
authored
refactor: move to mlr3learners style autotest, pass params in lrn call directly (#62)
* refactor: move to mlr3learners style autotest, pass params in lrn call directyl * refactor: use explicit integers
1 parent 136ae95 commit 0c00d7a

29 files changed

+57
-107
lines changed

R/LearnerClust.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
#' ids
2929
#'
3030
#' # get a specific learner from mlr_learners:
31-
#' lrn = mlr_learners$get("clust.kmeans")
32-
#' print(lrn)
31+
#' learner = lrn("clust.kmeans")
32+
#' print(learner)
3333
LearnerClust = R6Class("LearnerClust",
3434
inherit = Learner,
3535
public = list(

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,12 @@ Also, the package is integrated with **[mlr3viz](https://github.com/mlr-org/mlr3
7777

7878
## Example
7979

80-
```{r}
80+
```r
8181
library(mlr3)
8282
library(mlr3cluster)
8383

84-
task = mlr_tasks$get("usarrests")
85-
learner = mlr_learners$get("clust.kmeans")
84+
task = tsk("usarrests")
85+
learner = lrn("clust.kmeans")
8686
learner$train(task)
8787
preds = learner$predict(task = task)
8888
```

man/LearnerClust.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_LearnerClust.R

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
test_that("predict on newdata works / clust", {
22
task = tsk("usarrests")$filter(1:40)
3-
learner = lrn("clust.featureless")
4-
learner$param_set$values = list(num_clusters = 1L)
3+
learner = lrn("clust.featureless", num_clusters = 1L)
54
expect_error(learner$predict(task), "trained")
65
learner$train(task)
76
expect_task(learner$state$train_task)
@@ -15,26 +14,24 @@ test_that("predict on newdata works / clust", {
1514

1615
# rely on internally stored task representation
1716
p = learner$predict_newdata(newdata = newdata, task = NULL)
18-
expect_data_table(as.data.table(p), nrows = 10)
17+
expect_data_table(as.data.table(p), nrows = 10L)
1918
expect_set_equal(as.data.table(p)$row_ids, 1:10)
2019
expect_null(p$truth)
2120
})
2221

2322
test_that("reset()", {
2423
task = tsk("usarrests")
25-
lrn = lrn("clust.featureless")
26-
lrn$param_set$values = list(num_clusters = 2L)
24+
learner = lrn("clust.featureless", num_clusters = 2L)
2725

28-
lrn$train(task)
29-
expect_list(lrn$state, names = "unique")
30-
expect_learner(lrn$reset())
31-
expect_null(lrn$state)
26+
learner$train(task)
27+
expect_list(learner$state, names = "unique")
28+
expect_learner(learner$reset())
29+
expect_null(learner$state)
3230
})
3331

3432
test_that("empty predict set (#421)", {
3533
task = tsk("usarrests")
36-
learner = lrn("clust.featureless")
37-
learner$param_set$values = list(num_clusters = 1L)
34+
learner = lrn("clust.featureless", num_clusters = 1L)
3835
resampling = rsmp("holdout", ratio = 1)
3936
hout = resampling$instantiate(task)
4037
model = learner$train(task, hout$train_set(1))

tests/testthat/test_MeasureClust.R

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
test_that("Cluster measures", {
22
keys = mlr_measures$keys("clust")
33
task = tsk("usarrests")
4-
learner = mlr_learners$get("clust.kmeans")
5-
learner$param_set$values = list(centers = 2)
6-
learner$train(task)
7-
p = learner$predict(task)
4+
learner = lrn("clust.kmeans", centers = 2L)
5+
p = learner$train(task)$predict(task)
86

97
for (key in keys) {
108
m = mlr_measures$get(key)

tests/testthat/test_PredictionClust.R

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,16 @@ test_that("Construction", {
99

1010
test_that("Internally constructed Prediction", {
1111
task = tsk("usarrests")
12-
lrn = mlr_learners$get("clust.featureless")
13-
lrn$param_set$values = list(num_clusters = 1L)
14-
p = lrn$train(task)$predict(task)
12+
learner = lrn("clust.featureless", num_clusters = 1L)
13+
p = learner$train(task)$predict(task)
1514
expect_prediction(p)
1615
expect_prediction_clust(p)
1716
})
1817

1918
test_that("filter works", {
2019
task = tsk("usarrests")
21-
lrn = mlr_learners$get("clust.featureless")
22-
lrn$param_set$values = list(num_clusters = 1L)
23-
p = lrn$train(task)$predict(task)
20+
learner = lrn("clust.featureless", num_clusters = 1L)
21+
p = learner$train(task)$predict(task)
2422
pdata = p$data
2523

2624
pdata = filter_prediction_data(pdata, row_ids = 1:3)

tests/testthat/test_TaskClust.R

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@ test_that("Basic ops on usarrests task", {
22
task = tsk("usarrests")
33
expect_task(task)
44
expect_task_clust(task)
5-
expect_identical(task$target_names, character(0))
5+
expect_identical(task$target_names, character(0L))
66
})
77

88
test_that("Basic ops on ruspini task", {
99
task = tsk("ruspini")
1010
expect_task(task)
1111
expect_task_clust(task)
12-
expect_identical(task$target_names, character(0))
12+
expect_identical(task$target_names, character(0L))
1313
})
1414

1515
test_that("0 feature task", {
@@ -22,8 +22,7 @@ test_that("0 feature task", {
2222
expect_task_clust(task)
2323
expect_data_table(task$data(), ncols = 1L)
2424

25-
lrn = lrn("clust.featureless")
26-
lrn$param_set$values = list(num_clusters = 3L)
27-
p = lrn$train(task)$predict(task)
25+
learner = lrn("clust.featureless", num_clusters = 3L)
26+
p = learner$train(task)$predict(task)
2827
expect_prediction(p)
2928
})

tests/testthat/test_mlr_learners_clust_agnes.R

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,13 @@ skip_if_not_installed("clue")
33
test_that("autotest", {
44
learner = mlr3::lrn("clust.agnes")
55
expect_learner(learner)
6-
76
result = run_autotest(learner)
87
expect_true(result, info = result$error)
98
})
109

11-
1210
test_that("Learner properties are respected", {
1311
task = tsk("usarrests")
14-
learner = mlr_learners$get("clust.agnes")
12+
learner = lrn("clust.agnes")
1513
expect_learner(learner, task)
1614

1715
# test on multiple paramsets

tests/testthat/test_mlr_learners_clust_ap.R

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
11
skip_if_not_installed("apcluster")
22

33
test_that("autotest", {
4-
learner = mlr3::lrn("clust.ap")
5-
learner$param_set$values = list(s = apcluster::negDistMat(r = 2L))
4+
learner = mlr3::lrn("clust.ap", s = apcluster::negDistMat(r = 2L))
65
expect_learner(learner)
7-
86
result = run_autotest(learner)
97
expect_true(result, info = result$error)
108
})
119

12-
1310
test_that("Learner properties are respected", {
1411
task = tsk("usarrests")
15-
learner = mlr_learners$get("clust.ap")
12+
learner = lrn("clust.ap")
1613
expect_learner(learner, task)
1714

1815
# test on multiple paramsets

tests/testthat/test_mlr_learners_clust_cmeans.R

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,17 @@ skip_if_not_installed("e1071")
33
test_that("autotest", {
44
learner = mlr3::lrn("clust.cmeans")
55
expect_learner(learner)
6-
76
result = run_autotest(learner)
87
expect_true(result, info = result$error)
98
})
109

11-
1210
test_that("Learner properties are respected", {
1311
task = tsk("usarrests")
14-
learner = mlr_learners$get("clust.cmeans")
12+
learner = lrn("clust.cmeans")
1513
expect_learner(learner, task)
1614

1715
# test on multiple paramsets
18-
centers = data.frame(matrix(ncol = length(colnames(task$data())), nrow = 4))
16+
centers = data.frame(matrix(ncol = length(colnames(task$data())), nrow = 4L))
1917
colnames(centers) = colnames(task$data())
2018
centers$Assault = c(100, 200, 150, 300)
2119
centers$Murder = c(11, 3, 10, 5)

0 commit comments

Comments
 (0)