From c5cc2cd5cfae31e6971b04dfe11f3ded8d52b404 Mon Sep 17 00:00:00 2001
From: Dmitry Shemetov <dshemetov@ucdavis.edu>
Date: Fri, 23 Aug 2024 14:06:02 -0700
Subject: [PATCH] refactor: simplify ungrouped epix_slide

---
 R/methods-epi_archive.R            | 126 +++++++++++++++++++++++++----
 tests/testthat/test-deprecations.R |  10 ---
 tests/testthat/test-epix_slide.R   |  15 +---
 3 files changed, 114 insertions(+), 37 deletions(-)

diff --git a/R/methods-epi_archive.R b/R/methods-epi_archive.R
index a666e5f3..0f4dc3bd 100644
--- a/R/methods-epi_archive.R
+++ b/R/methods-epi_archive.R
@@ -833,22 +833,116 @@ epix_slide.epi_archive <- function(
     ref_time_values = NULL,
     new_col_name = NULL,
     all_versions = FALSE) {
-  # For an "ungrouped" slide, treat all rows as belonging to one big
-  # group (group by 0 vars), like `dplyr::summarize`, and let the
-  # resulting `grouped_epi_archive` handle the slide:
-  epix_slide(
-    group_by(x),
-    f,
-    ...,
-    before = before, ref_time_values = ref_time_values, new_col_name = new_col_name,
-    all_versions = all_versions
-  ) %>%
-    # We want a slide on ungrouped archives to output something
-    # ungrouped, rather than retaining the trivial (0-variable)
-    # grouping applied above. So we `ungroup()`. However, the current
-    # `dplyr` implementation automatically ignores/drops trivial
-    # groupings, so this is just a no-op for now.
-    ungroup()
+  ### START Copy pasta from grouped_epi_archive ###
+  # Deprecated argument handling
+  provided_args <- rlang::call_args_names(rlang::call_match())
+  if ("all_rows" %in% provided_args) {
+    cli_abort("
+          The `all_rows` argument has been removed from `epix_slide` (but
+          is still supported in `epi_slide`). Add rows for excluded
+          results with a manual join instead.
+        ", class = "epiprocess__epix_slide_all_rows_parameter_deprecated")
+  }
+  if ("as_list_col" %in% provided_args) {
+    cli::cli_abort(
+      "epix_slide: the argument `as_list_col` is deprecated. If FALSE, you can just remove it.
+      If TRUE, have your given computation wrap its result using `list(result)` instead."
+    )
+  }
+  if ("names_sep" %in% provided_args) {
+    cli::cli_abort(
+      "epix_slide: the argument `names_sep` is deprecated. If NULL, you can remove it, it is now default.
+      If a string, please manually prefix your column names instead."
+    )
+  }
+
+  if (is.null(ref_time_values)) {
+    ref_time_values <- epix_slide_ref_time_values_default(x)
+  } else {
+    assert_numeric(ref_time_values, min.len = 1L, null.ok = FALSE, any.missing = FALSE)
+    if (any(ref_time_values > x$versions_end)) {
+      cli_abort("Some `ref_time_values` are greater than the latest version in the archive.")
+    }
+    if (anyDuplicated(ref_time_values) != 0L) {
+      cli_abort("Some `ref_time_values` are duplicated.")
+    }
+    # Sort, for consistency with `epi_slide`, although the current
+    # implementation doesn't take advantage of it.
+    ref_time_values <- sort(ref_time_values)
+  }
+
+  validate_slide_window_arg(before, x$time_type)
+
+  checkmate::assert_string(new_col_name, null.ok = TRUE)
+  if (identical(new_col_name, "time_value")) {
+    cli_abort('`new_col_name` must not be `"time_value"`; `epix_slide()` uses that column name to attach the `ref_time_value` associated with each slide computation') # nolint: line_length_linter
+  }
+
+  # Validate rest of parameters:
+  assert_logical(all_versions, len = 1L)
+
+  # If `f` is missing, interpret ... as an expression for tidy evaluation
+  if (missing(f)) {
+    used_data_masking <- TRUE
+    quosures <- enquos(...)
+    if (length(quosures) == 0) {
+      cli_abort("If `f` is missing then a computation must be specified via `...`.")
+    }
+
+    f <- as_slide_computation(quosures)
+    # Magic value that passes zero args as dots in calls below. Equivalent to
+    # `... <- missing_arg()`, but use `assign` to avoid warning about
+    # improper use of dots.
+    assign("...", missing_arg())
+  } else {
+    used_data_masking <- FALSE
+    f <- as_slide_computation(f, ...)
+  }
+  ### END Copy pasta from grouped_epi_archive ###
+
+  out <- purrr::map(ref_time_values, function(ref_time_value) {
+    epi_df <- x %>%
+      epix_as_of(ref_time_value, min_time_value = ref_time_value - before, all_versions = all_versions)
+    comp_value <- f(epi_df, "fake_gk", ref_time_value, ...)
+    if (!used_data_masking && !(
+      # vctrs considers data.frames to be vectors, but we still check
+      # separately for them because certain base operations output data frames
+      # with rownames, which we will allow (but might drop)
+      is.data.frame(comp_value) ||
+        vctrs::obj_is_vector(comp_value) && is.null(vctrs::vec_names(comp_value))
+    )) {
+      cli_abort("
+      the slide computations must always return data frames or unnamed vectors
+      (as determined by the vctrs package) (and not a mix of these two
+      structures).
+    ", class = "epiprocess__invalid_slide_comp_value")
+    }
+    res <- list(time_value = vctrs::vec_rep(ref_time_value, vctrs::vec_size(comp_value)))
+
+    if (is.null(new_col_name)) {
+      if (inherits(comp_value, "data.frame")) {
+        # unpack into separate columns (without name prefix):
+        res <- c(res, comp_value)
+      } else {
+        # apply default name (to vector or packed data.frame-type column):
+        res[["slide_value"]] <- comp_value
+      }
+    } else {
+      # vector or packed data.frame-type column (note: new_col_name of
+      # "time_value" is disallowed):
+      res[[new_col_name]] <- comp_value
+    }
+
+    # Stop on naming conflicts (names() fine here, non-NULL). Not the
+    # friendliest error messages though.
+    vctrs::vec_as_names(names(res), repair = "check_unique")
+
+    # Fast conversion:
+    return(validate_tibble(new_tibble(res)))
+  })
+  out <- vctrs::vec_rbind(!!!out) %>% decay_epi_df()
+
+  return(out)
 }
 
 
diff --git a/tests/testthat/test-deprecations.R b/tests/testthat/test-deprecations.R
index 7d29149b..1fce6274 100644
--- a/tests/testthat/test-deprecations.R
+++ b/tests/testthat/test-deprecations.R
@@ -1,14 +1,4 @@
 test_that("epix_slide group_by= deprecation works", {
-  expect_error(
-    archive_cases_dv_subset %>%
-      epix_slide(function(...) {}, before = 2L, group_by = c()),
-    class = "epiprocess__epix_slide_group_by_parameter_deprecated"
-  )
-  expect_error(
-    archive_cases_dv_subset %>%
-      epix_slide(function(...) {}, before = 2L, group_by = c()),
-    class = "epiprocess__epix_slide_group_by_parameter_deprecated"
-  )
   expect_error(
     archive_cases_dv_subset %>%
       group_by(geo_value) %>%
diff --git a/tests/testthat/test-epix_slide.R b/tests/testthat/test-epix_slide.R
index 2151a82c..e470bb51 100644
--- a/tests/testthat/test-epix_slide.R
+++ b/tests/testthat/test-epix_slide.R
@@ -1,11 +1,6 @@
 suppressPackageStartupMessages(library(dplyr))
 
 test_date <- as.Date("2020-01-01")
-
-test_that("epix_slide only works on an epi_archive", {
-  expect_error(epix_slide(data.frame(x = 1)))
-})
-
 x <- tibble::tribble(
   ~version, ~time_value, ~binary,
   test_date + 4, test_date + c(1:3), 2^(1:3),
@@ -14,10 +9,13 @@ x <- tibble::tribble(
   test_date + 7, test_date + 2:6, 2^(11:15)
 ) %>%
   tidyr::unnest(c(time_value, binary))
-
 xx <- bind_cols(geo_value = rep("ak", 15), x) %>%
   as_epi_archive()
 
+test_that("epix_slide only works on an epi_archive", {
+  expect_error(epix_slide(data.frame(x = 1)))
+})
+
 test_that("epix_slide works as intended", {
   xx1 <- xx %>%
     group_by(.data$geo_value) %>%
@@ -204,7 +202,6 @@ test_that("quosure passing issue in epix_slide is resolved + other potential iss
       new_col_name = "case_rate_3d_av"
     )
   reference_by_neither <- ea %>%
-    group_by() %>%
     epix_slide(
       f = ~ mean(.x$case_rate_7d_av),
       before = 2,
@@ -340,7 +337,6 @@ test_that("epix_slide with all_versions option has access to all older versions"
   ea_orig_mirror <- ea %>% clone()
 
   result1 <- ea %>%
-    group_by() %>%
     epix_slide(
       f = slide_fn,
       before = 10^3,
@@ -362,7 +358,6 @@ test_that("epix_slide with all_versions option has access to all older versions"
   expect_identical(result1, result2) # *
 
   result3 <- ea %>%
-    group_by() %>%
     epix_slide(
       f = slide_fn,
       before = 10^3,
@@ -373,7 +368,6 @@ test_that("epix_slide with all_versions option has access to all older versions"
 
   # formula interface
   result4 <- ea %>%
-    group_by() %>%
     epix_slide(
       f = ~ slide_fn(.x, .y),
       before = 10^3,
@@ -384,7 +378,6 @@ test_that("epix_slide with all_versions option has access to all older versions"
 
   # tidyeval interface
   result5 <- ea %>%
-    group_by() %>%
     epix_slide(
       # unfortunately, we can't pass this directly as `f` and need an extra comma
       ,