diff --git a/R/fit_task0.R b/R/fit_task0.R index d6c60eb..72f1b14 100644 --- a/R/fit_task0.R +++ b/R/fit_task0.R @@ -162,8 +162,7 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t max_breakpoints = min(max_breakpoints, as.integer(length(x) / avg_points_per_window)) available_breakpoints <- 1:max_breakpoints message("Initial proposals") - n_breakpoints <- 4 - proposed_breakpoints <- lapply(available_breakpoints, function(n_breakpoints) { + proposed_breakpoints <- parallel::mclapply(available_breakpoints, FUN = function(n_breakpoints) { print(n_breakpoints) convergence <<- FALSE iter <<- 1 @@ -185,7 +184,7 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t ) dplyr::tibble(n_breakpoints = n_breakpoints, best_bp = list(bb), convergence = convergence) - }) %>% do.call("bind_rows", .) %>% dplyr::distinct() + }, mc.cores = parallel::detectCores()) %>% do.call("bind_rows", .) %>% dplyr::distinct() for (j in 1:nrow(proposed_breakpoints)) { if (!all((biPOD:::bp_to_groups(dplyr::tibble(time=x, count=y), unlist(proposed_breakpoints[j,]$best_bp)) %>% table()) >= avg_points_per_window)) { @@ -205,9 +204,9 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t fits <- list() plots <- list() proposed_breakpoints$idx <- c(1:nrow(proposed_breakpoints)) + 1 - j <- 2 + loos <- lapply(0:nrow(proposed_breakpoints), function(j) { - print(j) + #print(j) if (j == 0) { bp = array(0, dim = c(0)) } else { @@ -238,42 +237,60 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t means <- lapply(1:ncol(repetitions), function(j) {mean(repetitions[,j])}) %>% unlist() sds <- lapply(1:ncol(repetitions), function(j) {sd(repetitions[,j])}) %>% unlist() - plots[[j+1]] <<- ggplot2::ggplot() + - ggplot2::geom_pointrange(dplyr::tibble(x=x, means=means, sds = sds), mapping = ggplot2::aes(x=.data$x, y=.data$means, ymin=.data$means-.data$sds, ymax=.data$means+.data$sds)) + - ggplot2::geom_point(dplyr::tibble(x=x, y=y), mapping=ggplot2::aes(x=.data$x, y=.data$y), col="red") + - ggplot2::ggtitle(max(f$lp())) + # plots[[j+1]] <<- ggplot2::ggplot() + + # ggplot2::geom_pointrange(dplyr::tibble(x=x, means=means, sds = sds), mapping = ggplot2::aes(x=.data$x, y=.data$means, ymin=.data$means-.data$sds, ymax=.data$means+.data$sds)) + + # ggplot2::geom_point(dplyr::tibble(x=x, y=y), mapping=ggplot2::aes(x=.data$x, y=.data$y), col="red") + + # ggplot2::geom_vline(xintercept = bp) + + # ggplot2::ggtitle(max(f$lp())) + + + k = 1 + (1 + length(bp)) #+ length(bp) + n = length(x) + bic = k * log(n) - 2 * max(f$lp()) + #print(bic) - return(loo) + return(bic) + + #return(loo) # k = 1 + (1 + length(bp)) #+ length(bp) # n = length(x) # bic = k * log(n) - 2 * max(f$lp()) # #median(f$lp()) # bic - }) + }) %>% unlist() + + #proposed_breakpoints + #ggpubr::ggarrange(plotlist = plots) if (length(loos) == 1) { message("Zero models with breakpoints has been found") return(list(best_bp=NULL, best_fit=NULL)) } - suppressWarnings(loo_comp <- loo::loo_compare(loos)) - loo_comp[,1] <- round(loo_comp[,1],1) - - loo_comp <- loo_comp %>% as_tibble() %>% - dplyr::mutate(model = rownames(loo_comp)) %>% - dplyr::mutate(j = as.numeric(stringr::str_replace(rownames(loo_comp), pattern = "model", replacement = "")) - 1) %>% - dplyr::mutate(convergence = TRUE) - - loo_comp$convergence <- lapply(loo_comp$j, function(j) { - if (j == 0) return(TRUE) - all(loos[[j]]$diagnostics$pareto_k <= .7) - }) %>% unlist() - - loo_comp <- loo_comp %>% dplyr::filter(convergence) %>% dplyr::filter(elpd_diff == max(elpd_diff)) - best_j <- min(loo_comp$j) - - best_fit <- fits[[best_j]] + bic_comp <- dplyr::tibble(bic = loos, n_breakpoints = 1:length(loos) - 1) + best_j <- bic_comp %>% + dplyr::filter(bic == min(bic)) %>% + dplyr::pull(n_breakpoints) %>% + min() + + # suppressWarnings(loo_comp <- loo::loo_compare(loos)) + # loo_comp[,1] <- round(loo_comp[,1],1) + # + # loo_comp <- loo_comp %>% as_tibble() %>% + # dplyr::mutate(model = rownames(loo_comp)) %>% + # dplyr::mutate(j = as.numeric(stringr::str_replace(rownames(loo_comp), pattern = "model", replacement = "")) - 1) %>% + # dplyr::mutate(convergence = TRUE) + # + # loo_comp$convergence <- lapply(loo_comp$j, function(j) { + # if (j == 0) return(TRUE) + # all(loos[[j]]$diagnostics$pareto_k <= 1) + # }) %>% unlist() + # + # loo_comp <- loo_comp %>% dplyr::filter(convergence) %>% dplyr::filter(elpd_diff == max(elpd_diff)) + # best_j <- min(loo_comp$j) + + best_fit <- fits[[best_j + 1]] best_fit <- biPOD:::convert_mcmc_fit_to_biPOD(best_fit) best_bp <- proposed_breakpoints %>%