Skip to content
This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Commit

Permalink
add set_seed
Browse files Browse the repository at this point in the history
  • Loading branch information
Raphael Sonabend committed Feb 9, 2021
1 parent 0db846e commit f8d29fa
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ src/survivalmodels.so
*.o

*.so
CRAN-RELEASE
2 changes: 0 additions & 2 deletions CRAN-RELEASE

This file was deleted.

2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: survivalmodels
Title: Models for Survival Analysis
Version: 0.1.5
Version: 0.1.6
Authors@R:
person(given = "Raphael",
family = "Sonabend",
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# survivalmodels 0.1.6

* Add `set_seed` for easier setting of seeds within R and Python environments

# survivalmodels 0.1.5

* Fixed bug in `risk` return type when `distr6 = FALSE`
Expand Down
21 changes: 21 additions & 0 deletions R/helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,24 @@ setcollapse <- function(x) {
requireNamespaces <- function(x) {
all(vapply(x, requireNamespace, logical(1), quietly = TRUE))
}

#' @title Set seed in R numpy and torch
#' @description To ensure consistent results, a seed has to be set in R
#' using [set.seed] as usual but also in {numpy} and {torch} via {reticulate}.
#' Therefore this function simplifies the process into one funciton.
#' @param seed_R (`integer(1)`) `seed` passed to [set.seed].
#' @param seed_np (`integer(1)`) `seed` passed to `numpy$random$seed`. Default is same as `seed_R`.
#' @param seed_torch (`integer(1)`) `seed` passed to `numpy$random$seed`.
#' Default is same as `seed_R`.
set_seed <- function(seed_R, seed_np = seed_R, seed_torch = seed_R) {
set.seed(seed_R)
if (reticulate::py_module_available("numpy")) {
np <- reticulate::import("numpy")
np$random$seed(as.integer(seed_np))
}
if (reticulate::py_module_available("torch")) {
torch <- reticulate::import("torch")
torch$manual_seed(as.integer(seed_torch))
}
invisible(NULL)
}
21 changes: 21 additions & 0 deletions man/set_seed.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions tests/testthat/test_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,14 @@ test_that("clean_test_data", {
expect_equal(clean_test_data(fit), fit$x)
expect_error(clean_test_data(fit, rats[, 1:2]), "Names in")
})

test_that("set_seed", {
skip_if_no_pycox()
set_seed(1)
first <- deepsurv(Surv(time, status) ~ ., data = rats[1:50, ], verbose = FALSE,
frac = 0.3)
set_seed(1)
second <- deepsurv(Surv(time, status) ~ ., data = rats[1:50, ], verbose = FALSE,
frac = 0.3)
expect_equal(first, second)
})

0 comments on commit f8d29fa

Please sign in to comment.