Skip to content

Commit

Permalink
Merge pull request #25 from cmusso86/update_release
Browse files Browse the repository at this point in the history
remove ::
  • Loading branch information
cmusso86 committed Jun 19, 2024
2 parents d117b41 + 04f06d2 commit 94d02e4
Showing 1 changed file with 57 additions and 49 deletions.
106 changes: 57 additions & 49 deletions vignettes/simple_mlp.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ knitr::opts_chunk$set(

```{r setup}
library(recalibratiNN)
library(glue)
library(RANN)
library(dplyr)
library(purrr)
library(tidyr)
library(ggplot2)
```

```{r}
Expand Down Expand Up @@ -139,30 +147,30 @@ gg_PIT_local(pit_local,
```

```{r}
coverage_model <- tidyr::tibble(
coverage_model <- tibble(
x1cal = x_test[,1],
x2cal = x_test[,2],
y_real = y_test,
y_hat = y_hat_test) |>
dplyr::mutate(lwr = qnorm(0.05, y_hat, sqrt(MSE_cal)),
mutate(lwr = qnorm(0.05, y_hat, sqrt(MSE_cal)),
upr = qnorm(0.95, y_hat, sqrt(MSE_cal)),
CI = ifelse(y_real <= upr & y_real >= lwr,
"in", "out" ),
coverage = round(mean(CI == "in")*100,1)
)
coverage_model |>
ggplot2::ggplot() +
ggplot2::geom_point(ggplot2::aes(x1cal,
ggplot() +
geom_point(aes(x1cal,
x2cal,
color = CI),
alpha = 0.8)+
ggplot2::labs(x="x1" , y="x2",
title = glue::glue("Original coverage: {coverage_model$coverage[1]} %"))+
ggplot2::scale_color_manual("Confidence Interval",
labs(x="x1" , y="x2",
title = glue("Original coverage: {coverage_model$coverage[1]} %"))+
scale_color_manual("Confidence Interval",
values = c("in" = "aquamarine3",
"out" = "steelblue4"))+
ggplot2::theme_classic()
theme_classic()
```

```{r}
Expand All @@ -180,10 +188,10 @@ y_hat_rec <- recalibrated$y_samples_calibrated_wt
```

```{r}
coverage_rec <- purrr::map_dfr( 1:nrow(x_test), ~ {
coverage_rec <- map_dfr( 1:nrow(x_test), ~ {
quantile(y_hat_rec[.,]
,c(0.05, 0.95))}) |>
dplyr::mutate(
mutate(
x1 = x_test[,1],
x2 = x_test[,2],
ytest = y_test,
Expand All @@ -192,15 +200,15 @@ y_hat_rec <- recalibrated$y_samples_calibrated_wt
coverage = round(mean(CI == "in")*100,1))
coverage_rec |>
ggplot2::ggplot() +
ggplot2::geom_point(ggplot2::aes(x1, x2, color = CI),
ggplot() +
geom_point(aes(x1, x2, color = CI),
alpha = 0.7)+
ggplot2::labs(x="x1" , y="x2",
title = glue::glue("Recalibrated coverage: {coverage_rec$coverage[1]} %"))+
ggplot2::scale_color_manual("Confidence Interval",
labs(x="x1" , y="x2",
title = glue("Recalibrated coverage: {coverage_rec$coverage[1]} %"))+
scale_color_manual("Confidence Interval",
values = c("in" = "aquamarine3",
"out" = "steelblue4"))+
ggplot2::theme_classic()
theme_classic()
```
```{r}
Expand All @@ -215,15 +223,15 @@ cluster_means_cal <- cluster_means_cal[order(cluster_means_cal[,1]),]
# finding neighbours
knn_cal <- RANN::nn2(x_test,
knn_cal <- nn2(x_test,
cluster_means_cal,
k = n_neighbours)$nn.idx
# geting corresponding ys (real and estimated)
y_real_local <- purrr::map(1:nrow(knn_cal), ~y_test[knn_cal[.,]])
y_real_local <- map(1:nrow(knn_cal), ~y_test[knn_cal[.,]])
y_hat_local <- purrr::map(1:nrow(knn_cal), ~y_hat_rec[knn_cal[.,],])
y_hat_local <- map(1:nrow(knn_cal), ~y_hat_rec[knn_cal[.,],])
# calculate pit_local
Expand All @@ -232,39 +240,39 @@ pits <- matrix(NA,
ncol = n_neighbours)
for (i in 1:n_clusters) {
pits[i,] <- purrr::map_dbl(1:n_neighbours, ~{
pits[i,] <- map_dbl(1:n_neighbours, ~{
mean(y_hat_local[[i]][.,] <= y_hat_local[[i]][.])
})
}
as.data.frame(t(pits)) |>
tidyr::pivot_longer(everything()) |>
dplyr::group_by(name) |>
dplyr::mutate(p_value =ks.test(value,
pivot_longer(everything()) |>
group_by(name) |>
mutate(p_value =ks.test(value,
"punif")$p.value,
name = gsub("V", "part_", name)) |>
ggplot2::ggplot()+
ggplot2::geom_density(ggplot2::aes(value,
ggplot()+
geom_density(aes(value,
color = name,
fill = name),
alpha = 0.5,
bounds = c(0, 1))+
ggplot2::geom_hline(yintercept = 1,
geom_hline(yintercept = 1,
linetype="dashed")+
ggplot2::scale_color_brewer(palette = "Set2")+
ggplot2::scale_fill_brewer(palette = "Set2")+
ggplot2::theme_classic()+
ggplot2::geom_text(ggplot2::aes(x = 0.5,
scale_color_brewer(palette = "Set2")+
scale_fill_brewer(palette = "Set2")+
theme_classic()+
geom_text(aes(x = 0.5,
y = 0.5,
label = glue::glue("p-value: {round(p_value, 3)}")),
label = glue("p-value: {round(p_value, 3)}")),
color = "black",
size = 3)+
ggplot2::theme(legend.position = "none")+
ggplot2::labs(title = "After Local Calibration",
theme(legend.position = "none")+
labs(title = "After Local Calibration",
subtitle = "It looks so much better!!",
x = "PIT-values",
y = "Density")+
ggplot2::facet_wrap(~name, scales = "free_y")
facet_wrap(~name, scales = "free_y")
```

```{r}
Expand All @@ -273,26 +281,26 @@ data.frame(
desc = y_hat_test,
recal = recalibrated$y_hat_calibrated
) |>
tidyr::pivot_longer(-real) |>
ggplot2::ggplot()+
ggplot2::geom_point(ggplot2::aes( x = value,
pivot_longer(-real) |>
ggplot()+
geom_point(aes( x = value,
y = real,
color = name),
alpha = 0.8)+
ggplot2::scale_color_manual("", values = c( "#003366","#80b298"),
scale_color_manual("", values = c( "#003366","#80b298"),
labels = c("Predicted", "Recalibrated"))+
ggplot2::geom_abline(color="red", linetype="dashed")+
ggplot2::labs(x="Estimated Mean", y="True Mean")+
ggplot2::theme_bw(base_size = 14) +
ggplot2::theme(axis.title.y=ggplot2::element_text(colour="black"),
axis.title.x = ggplot2::element_text(colour="black"),
axis.text = ggplot2::element_text(colour = "black"),
geom_abline(color="red", linetype="dashed")+
labs(x="Estimated Mean", y="True Mean")+
theme_bw(base_size = 14) +
theme(axis.title.y=element_text(colour="black"),
axis.title.x = element_text(colour="black"),
axis.text = element_text(colour = "black"),
legend.position = c(0.8, 0.2),
panel.border = ggplot2::element_blank(),
panel.grid.major = ggplot2::element_blank(),
panel.grid.minor = ggplot2::element_blank(),
axis.line = ggplot2::element_line(colour = "black"),
plot.margin = ggplot2::margin(0, 0, 0, 0.2, "cm"))
panel.border = element_blank(),
panel.grid.major = element_blank(),
panel.grid.minor = element_blank(),
axis.line = element_line(colour = "black"),
plot.margin = margin(0, 0, 0, 0.2, "cm"))
```

0 comments on commit 94d02e4

Please sign in to comment.