|
| 1 | +library(tidyverse) |
| 2 | +library(httr) |
| 3 | +library(lubridate) |
| 4 | +library(progress) |
| 5 | +library(targets) |
| 6 | +source(here::here("R", "load_all.R")) |
| 7 | + |
# Silence readr's progress bars and column-type chatter for the whole run
options(readr.show_progress = FALSE)
options(readr.show_col_types = FALSE)

# Territories excluded from modeling because they report too little data
insufficient_data_geos <- c("as", "mp", "vi", "gu")

# Pipeline configuration: where forecasts live upstream and on disk
config <- list(
  # Raw-content root of the FluSight forecast hub's model-output folder
  base_url = "https://raw.githubusercontent.com/cdcepi/FluSight-forecast-hub/main/model-output",
  # Forecasters whose submissions we mirror locally
  forecasters = c("CMU-TimeSeries", "FluSight-baseline", "FluSight-ensemble", "FluSight-base_seasonal", "UMass-flusion"),
  # Local mirror of downloaded forecast CSVs
  local_storage = "data/forecasts",
  # CSV ledger of which files have been downloaded, and with what result
  tracking_file = "data/download_tracking.csv"
)
| 20 | + |
# Ensure the local storage tree exists: the base directory plus one
# subdirectory per forecaster. Existing directories are left untouched
# (showWarnings = FALSE suppresses the "already exists" warning).
#
# @param base_dir Root directory for downloaded forecasts.
# @param forecasters Character vector of per-forecaster subdirectories to
#   create; defaults to the configured forecaster list, so existing callers
#   are unaffected.
# @return Invisibly, NULL (called for its side effect).
setup_directories <- function(base_dir, forecasters = config$forecasters) {
  dir.create(base_dir, recursive = TRUE, showWarnings = FALSE)
  for (forecaster in forecasters) {
    dir.create(file.path(base_dir, forecaster), recursive = TRUE, showWarnings = FALSE)
  }
  invisible(NULL)
}
| 28 | + |
# Load the download-tracking ledger, or an empty one if it doesn't exist yet.
#
# The ledger records one row per attempted download: which forecaster, which
# file, when, and whether it succeeded. All four columns are character.
#
# @return A tibble with columns forecaster, filename, download_date, status.
load_tracking_data <- function() {
  if (file.exists(config$tracking_file)) {
    # Pin all four columns to character ("cccc") so the result always has
    # the same schema as the empty tibble below; otherwise readr may guess
    # e.g. a datetime for download_date and later bind_rows would mix types.
    read_csv(config$tracking_file, col_types = "cccc")
  } else {
    tibble(
      forecaster = character(),
      filename = character(),
      download_date = character(),
      status = character()
    )
  }
}
| 42 | + |
# Build the expected weekly submission filenames for one forecaster over a
# date range. The hub names files "<YYYY-MM-DD>-<forecaster>.csv", one per
# week starting at start_date.
#
# @param start_date,end_date Range endpoints (anything lubridate::as_date
#   accepts).
# @param forecaster Forecaster name as it appears in the hub.
# @return Character vector of candidate filenames.
generate_filenames <- function(start_date, end_date, forecaster) {
  week_starts <- seq(as_date(start_date), as_date(end_date), by = "week")
  sprintf("%s-%s.csv", format(week_starts, "%Y-%m-%d"), forecaster)
}
| 54 | + |
# Check whether <base_url>/<forecaster>/<filename> exists upstream (HTTP 200).
#
# A network failure (DNS, timeout, TLS) is reported as FALSE — "not
# available right now" — instead of aborting the whole download run; the
# file will simply be retried on the next run since it never enters the
# tracking ledger.
#
# @param forecaster Forecaster directory name in the hub.
# @param filename Candidate file name (see generate_filenames()).
# @return TRUE if the file is retrievable, FALSE otherwise.
check_github_file <- function(forecaster, filename) {
  url <- paste0(config$base_url, "/", forecaster, "/", filename)
  tryCatch(
    status_code(GET(url)) == 200,
    error = function(e) FALSE
  )
}
| 61 | + |
# Download one forecast file into the local mirror.
#
# @param forecaster Forecaster directory name (used both upstream and locally).
# @param filename File to fetch.
# @return "success" or "failed"; this string is stored in the tracking
#   ledger, and only "success" rows are skipped on later runs.
download_forecast_file <- function(forecaster, filename) {
  url <- paste0(config$base_url, "/", forecaster, "/", filename)
  local_path <- file.path(config$local_storage, forecaster, filename)

  tryCatch(
    {
      download.file(url, local_path, mode = "wb", quiet = TRUE)
      "success"
    },
    error = function(e) {
      # download.file can leave a partial file behind on error; remove it
      # so read_all_forecasts() never parses a truncated CSV.
      if (file.exists(local_path)) {
        unlink(local_path)
      }
      "failed"
    }
  )
}
| 77 | + |
# Download any forecast submissions from the last `days_back` days that are
# not already recorded as successfully downloaded, and append the outcomes
# to the tracking ledger.
#
# @param days_back How far back from today to look for weekly submissions.
# @return The full (updated) tracking tibble.
update_forecast_files <- function(days_back = 30) {
  # Setup: make sure the local directory tree and ledger exist
  setup_directories(config$local_storage)
  tracking_data <- load_tracking_data()

  # Generate date range. get_forecast_reference_date() is a project helper
  # (sourced via load_all.R); presumably it snaps the start date to the
  # hub's weekly reference day -- confirm against its definition.
  end_date <- Sys.Date()
  start_date <- get_forecast_reference_date(end_date - days_back)

  # Process each forecaster, accumulating one tracking row per attempt
  new_tracking_records <- list()

  pb_forecasters <- progress_bar$new(
    format = "Downloading forecasts from :forecaster [:bar] :percent :eta",
    total = length(config$forecasters),
    clear = FALSE,
    width = 60
  )

  for (forecaster in config$forecasters) {
    pb_forecasters$tick(tokens = list(forecaster = forecaster))

    # Get potential filenames (one per week in the range)
    filenames <- generate_filenames(start_date, end_date, forecaster)

    # Filter out already downloaded files. `!!forecaster` unquotes the loop
    # variable so dplyr compares the column against its value rather than
    # the column against itself.
    existing_files <- tracking_data %>%
      filter(forecaster == !!forecaster, status == "success") %>%
      pull(filename)

    new_files <- setdiff(filenames, existing_files)

    if (length(new_files) > 0) {
      # Create nested progress bar for this forecaster's files
      pb_files <- progress_bar$new(
        format = "  Downloading files [:bar] :current/:total :filename",
        total = length(new_files)
      )

      for (filename in new_files) {
        pb_files$tick(tokens = list(filename = filename))

        # Only attempt (and record) files that actually exist upstream;
        # absent weeks are silently skipped and retried on the next run.
        if (check_github_file(forecaster, filename)) {
          status <- download_forecast_file(forecaster, filename)

          new_tracking_records[[length(new_tracking_records) + 1]] <- tibble(
            forecaster = forecaster,
            filename = filename,
            download_date = as.character(Sys.time()),
            status = status
          )
        }
      }
    }
  }

  # Update tracking data: append this run's attempts and persist to disk
  if (length(new_tracking_records) > 0) {
    new_tracking_data <- bind_rows(new_tracking_records)
    tracking_data <- bind_rows(tracking_data, new_tracking_data)
    write_csv(tracking_data, config$tracking_file)
  }

  return(tracking_data)
}
| 144 | + |
# Read every successfully downloaded forecast CSV into a single tibble.
#
# Each file is parsed with an explicit column specification, joined to state
# metadata via the project helper add_state_info(), and tagged with its
# forecaster and the forecast date parsed from the filename.
#
# @return A tibble combining all available forecast files; zero rows when
#   nothing has been downloaded yet.
read_all_forecasts <- function() {
  tracking_data <- read_csv(config$tracking_file)

  successful_downloads <- tracking_data %>%
    filter(status == "success")

  # seq_len() rather than 1:nrow(): with zero successful rows, 1:0 would
  # yield c(1, 0) and index out of bounds; seq_len(0) yields an empty loop.
  forecast_data <- map(seq_len(nrow(successful_downloads)), function(i) {
    row <- successful_downloads[i, ]
    path <- file.path(config$local_storage, row$forecaster, row$filename)
    if (!file.exists(path)) {
      # Ledger says success but the file is gone (e.g. manually deleted);
      # skip it — NULL elements vanish in bind_rows().
      return(NULL)
    }
    read_csv(path, col_types = list(
      reference_date = col_date(format = "%Y-%m-%d"),
      target_end_date = col_date(format = "%Y-%m-%d"),
      target = col_character(),
      location = col_character(),
      horizon = col_integer(),
      output_type = col_character(),
      output_type_id = col_character(),
      value = col_double(),
      forecaster = col_character(),
      forecast_date = col_date(format = "%Y-%m-%d")
    )) %>%
      add_state_info(geo_value_col = "location", old_geo_code = "state_code", new_geo_code = "state_id") %>%
      mutate(
        forecaster = row$forecaster,
        # Parse to Date so the column keeps the type declared in the
        # col_types spec above; a bare str_extract() would silently turn
        # it into character and make later binds/comparisons inconsistent.
        forecast_date = as.Date(str_extract(row$filename, "\\d{4}-\\d{2}-\\d{2}")),
        geo_value = state_id
      )
  })

  bind_rows(forecast_data)
}
| 179 | + |
# Refresh the local forecast mirror (looking back 120 days) and return
# everything currently on disk as one combined tibble.
get_latest_data <- function() {
  # Side effect first: fetch any submissions we don't have yet
  invisible(update_forecast_files(days_back = 120))
  read_all_forecasts()
}
| 184 | + |
# Target definitions returned to the {targets} pipeline.
rlang::list2(
  # Latest NHSN hospital-admissions snapshot, cleaned into an epi_df
  tar_target(
    nhsn_latest_data,
    command = {
      # lubridate::wday(): Sunday == 1, so this condition selects wday 4-5,
      # i.e. Wednesday and Thursday only.
      # NOTE(review): the comment below says "Wednesday to Friday", but
      # Friday (wday 6) is excluded by `< 6` -- confirm which is intended.
      if (wday(Sys.Date()) < 6 & wday(Sys.Date()) > 3) {
        # download from the preliminary data source from Wednesday to Friday
        most_recent_result <- readr::read_csv("https://data.cdc.gov/resource/mpgq-jmmr.csv?$limit=20000&$select=weekendingdate,jurisdiction,totalconfc19newadm,totalconfflunewadm")
      } else {
        most_recent_result <- readr::read_csv("https://data.cdc.gov/resource/ua7e-t2fy.csv?$limit=20000&$select=weekendingdate,jurisdiction,totalconfc19newadm,totalconfflunewadm")
      }
      most_recent_result %>%
        # process_nhsn_data(), data_substitutions(), as_epi_df() and %nin%
        # come from the sourced project code / loaded packages
        process_nhsn_data() %>%
        filter(disease == "nhsn_flu") %>%
        select(-disease) %>%
        # Drop territories with too little data to model
        filter(geo_value %nin% insufficient_data_geos) %>%
        mutate(
          source = "nhsn",
          # Normalize the national geo code
          geo_value = ifelse(geo_value == "usa", "us", geo_value),
          # Shift time_value back 3 days -- presumably aligning week-ending
          # Saturdays to the Wednesday reference day; TODO confirm
          time_value = time_value - 3
        ) %>%
        # Keep only the most recent data version, then drop the version key
        filter(version == max(version)) %>%
        select(-version) %>%
        data_substitutions(disease = "flu") %>%
        as_epi_df(other_keys = "source", as_of = Sys.Date())
    }
  ),
  # Full revision history of NHSN flu data, built by a project helper
  tar_target(
    name = nhsn_archive_data,
    command = {
      create_nhsn_data_archive(disease = "nhsn_flu")
    }
  ),
  # Mirror external forecaster submissions from the FluSight hub
  tar_target(download_forecasts, update_forecast_files(days_back = 120)),
  # Combine all downloaded forecast files into one tibble
  tar_target(all_forecasts, read_all_forecasts())
)
0 commit comments