Skip to content

Commit

Permalink
mc
Browse files Browse the repository at this point in the history
  • Loading branch information
jovoni committed Aug 22, 2024
1 parent e5e199d commit 3f55637
Showing 1 changed file with 45 additions and 28 deletions.
73 changes: 45 additions & 28 deletions R/fit_task0.R
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,7 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t
max_breakpoints = min(max_breakpoints, as.integer(length(x) / avg_points_per_window))
available_breakpoints <- 1:max_breakpoints
message("Initial proposals")
n_breakpoints <- 4
proposed_breakpoints <- lapply(available_breakpoints, function(n_breakpoints) {
proposed_breakpoints <- parallel::mclapply(available_breakpoints, FUN = function(n_breakpoints) {
print(n_breakpoints)
convergence <<- FALSE
iter <<- 1
Expand All @@ -185,7 +184,7 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t
)

dplyr::tibble(n_breakpoints = n_breakpoints, best_bp = list(bb), convergence = convergence)
}) %>% do.call("bind_rows", .) %>% dplyr::distinct()
}, mc.cores = parallel::detectCores()) %>% do.call("bind_rows", .) %>% dplyr::distinct()

for (j in 1:nrow(proposed_breakpoints)) {
if (!all((biPOD:::bp_to_groups(dplyr::tibble(time=x, count=y), unlist(proposed_breakpoints[j,]$best_bp)) %>% table()) >= avg_points_per_window)) {
Expand All @@ -205,9 +204,9 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t
fits <- list()
plots <- list()
proposed_breakpoints$idx <- c(1:nrow(proposed_breakpoints)) + 1
j <- 2

loos <- lapply(0:nrow(proposed_breakpoints), function(j) {
print(j)
#print(j)
if (j == 0) {
bp = array(0, dim = c(0))
} else {
Expand Down Expand Up @@ -238,42 +237,60 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t
means <- lapply(1:ncol(repetitions), function(j) {mean(repetitions[,j])}) %>% unlist()
sds <- lapply(1:ncol(repetitions), function(j) {sd(repetitions[,j])}) %>% unlist()

plots[[j+1]] <<- ggplot2::ggplot() +
ggplot2::geom_pointrange(dplyr::tibble(x=x, means=means, sds = sds), mapping = ggplot2::aes(x=.data$x, y=.data$means, ymin=.data$means-.data$sds, ymax=.data$means+.data$sds)) +
ggplot2::geom_point(dplyr::tibble(x=x, y=y), mapping=ggplot2::aes(x=.data$x, y=.data$y), col="red") +
ggplot2::ggtitle(max(f$lp()))
# plots[[j+1]] <<- ggplot2::ggplot() +
# ggplot2::geom_pointrange(dplyr::tibble(x=x, means=means, sds = sds), mapping = ggplot2::aes(x=.data$x, y=.data$means, ymin=.data$means-.data$sds, ymax=.data$means+.data$sds)) +
# ggplot2::geom_point(dplyr::tibble(x=x, y=y), mapping=ggplot2::aes(x=.data$x, y=.data$y), col="red") +
# ggplot2::geom_vline(xintercept = bp) +
# ggplot2::ggtitle(max(f$lp()))


k = 1 + (1 + length(bp)) #+ length(bp)
n = length(x)
bic = k * log(n) - 2 * max(f$lp())
#print(bic)

return(loo)
return(bic)

#return(loo)

# k = 1 + (1 + length(bp)) #+ length(bp)
# n = length(x)
# bic = k * log(n) - 2 * max(f$lp())
# #median(f$lp())
# bic
})
}) %>% unlist()

#proposed_breakpoints
#ggpubr::ggarrange(plotlist = plots)

if (length(loos) == 1) {
message("Zero models with breakpoints has been found")
return(list(best_bp=NULL, best_fit=NULL))
}

suppressWarnings(loo_comp <- loo::loo_compare(loos))
loo_comp[,1] <- round(loo_comp[,1],1)

loo_comp <- loo_comp %>% as_tibble() %>%
dplyr::mutate(model = rownames(loo_comp)) %>%
dplyr::mutate(j = as.numeric(stringr::str_replace(rownames(loo_comp), pattern = "model", replacement = "")) - 1) %>%
dplyr::mutate(convergence = TRUE)

loo_comp$convergence <- lapply(loo_comp$j, function(j) {
if (j == 0) return(TRUE)
all(loos[[j]]$diagnostics$pareto_k <= .7)
}) %>% unlist()

loo_comp <- loo_comp %>% dplyr::filter(convergence) %>% dplyr::filter(elpd_diff == max(elpd_diff))
best_j <- min(loo_comp$j)

best_fit <- fits[[best_j]]
bic_comp <- dplyr::tibble(bic = loos, n_breakpoints = 1:length(loos) - 1)
best_j <- bic_comp %>%
dplyr::filter(bic == min(bic)) %>%
dplyr::pull(n_breakpoints) %>%
min()

# suppressWarnings(loo_comp <- loo::loo_compare(loos))
# loo_comp[,1] <- round(loo_comp[,1],1)
#
# loo_comp <- loo_comp %>% as_tibble() %>%
# dplyr::mutate(model = rownames(loo_comp)) %>%
# dplyr::mutate(j = as.numeric(stringr::str_replace(rownames(loo_comp), pattern = "model", replacement = "")) - 1) %>%
# dplyr::mutate(convergence = TRUE)
#
# loo_comp$convergence <- lapply(loo_comp$j, function(j) {
# if (j == 0) return(TRUE)
# all(loos[[j]]$diagnostics$pareto_k <= 1)
# }) %>% unlist()
#
# loo_comp <- loo_comp %>% dplyr::filter(convergence) %>% dplyr::filter(elpd_diff == max(elpd_diff))
# best_j <- min(loo_comp$j)

best_fit <- fits[[best_j + 1]]
best_fit <- biPOD:::convert_mcmc_fit_to_biPOD(best_fit)

best_bp <- proposed_breakpoints %>%
Expand Down

0 comments on commit 3f55637

Please sign in to comment.