Skip to content

Commit

Permalink
Update profiles section with shap visuals
Browse files Browse the repository at this point in the history
  • Loading branch information
langbart committed Dec 4, 2023
1 parent a2431ae commit 074a773
Show file tree
Hide file tree
Showing 27 changed files with 249 additions and 87 deletions.
69 changes: 42 additions & 27 deletions R/xgb-model.R
Original file line number Diff line number Diff line change
Expand Up @@ -236,15 +236,16 @@ plot_shap <- function(shap_object = NULL, model_type = NULL, alpha = NULL) {
mesh_shap = shap_object$S[, 3]
)
process_shap %>%
ggplot2::ggplot(aes(mesh_fact, mesh_shap, color = habitat_fact)) +
ggplot2::geom_jitter(width = 3, alpha = alpha, size = 1.5) +
ggplot2::ggplot(ggplot2::aes(mesh_fact, mesh_shap, color = habitat_fact)) +
ggplot2::geom_jitter(width = 2, alpha = alpha, size = 1.5, show.legend = FALSE) +
ggplot2::geom_point(size = 0.1) +
ggplot2::theme_minimal() +
ggplot2::scale_x_continuous(n.breaks = 10) +
ggplot2::geom_hline(yintercept = 0, linetype = 2, color = "grey50") +
ggplot2::scale_color_manual(values = c("#f28f3b", "grey50", "#ffd5c2", "#588b8b", "#c8553d", "#2d3047", "#93b7be"))+
ggplot2::coord_cartesian(expand = FALSE)+
ggplot2::labs(color = "Habitat")

ggplot2::scale_color_manual(values = c("#f28f3b", "#c27ba0", "#ffd5c2", "#588b8b", "#c8553d", "#2d3047", "#007ea7")) +
ggplot2::coord_cartesian(expand = FALSE) +
ggplot2::labs(color = "Habitat") +
ggplot2::guides(color = ggplot2::guide_legend(override.aes = list(size = 1.75)))
} else {
process_shap <-
dplyr::tibble(
Expand All @@ -269,14 +270,15 @@ plot_shap <- function(shap_object = NULL, model_type = NULL, alpha = NULL) {

process_shap %>%
ggplot2::ggplot(aes(reorder(habitat_gear_fact, habitat_gear_shap), habitat_gear_shap, color = vessel_fact)) +
ggplot2::geom_jitter(width = 0.5, alpha = alpha, size = 1.5) +
ggplot2::geom_jitter(width = 0.2, alpha = alpha, size = 1.5, show.legend = FALSE) +
ggplot2::geom_point(size = 0.1) +
ggplot2::theme_minimal() +
ggplot2::geom_hline(yintercept = 0, linetype = 2, color = "grey50") +
ggplot2::scale_color_manual(values = c("grey50", "#bc4749"))+
ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 45, hjust = 1)) +
ggplot2::coord_cartesian(expand = FALSE)+
ggplot2::labs(color = "Transport")

ggplot2::scale_color_manual(values = c("grey50", "#bc4749")) +
ggplot2::coord_cartesian(expand = FALSE) +
ggplot2::labs(color = "Transport") +
ggplot2::guides(color = ggplot2::guide_legend(override.aes = list(size = 1.75))) +
ggplot2::coord_flip()
}
}

Expand All @@ -288,6 +290,7 @@ plot_shap <- function(shap_object = NULL, model_type = NULL, alpha = NULL) {
#' @param model_type A character string specifying the model type, passed to the `plot_shap` function.
#' @param alpha The alpha value for geom_jitter in ggplot2, controlling point transparency.
#' @param cols The number of columns in the plot layout.
#' @param drop_legend Wether to return legend. Default is TRUE.
#'
#' @details
#' The function uses the `shapviz` package for initial processing and then applies `plot_shap` to each model.
Expand All @@ -299,7 +302,7 @@ plot_shap <- function(shap_object = NULL, model_type = NULL, alpha = NULL) {
#' \dontrun{
#' plot_model_shaps(data_shaps = my_model_shaps, model_type = "gn", alpha = 0.2, cols = 2)
#' }
plot_model_shaps <- function(data_shaps = NULL, model_type = NULL, alpha = 0.2, cols = 1) {
plot_model_shaps <- function(data_shaps = NULL, model_type = NULL, alpha = 0.2, cols = 1, drop_legend = FALSE) {
sha <- shapviz::shapviz(data_shaps)
shapviz_object <- purrr::map(sha, plot_shap, model_type = model_type, alpha = alpha)

Expand All @@ -318,6 +321,7 @@ plot_model_shaps <- function(data_shaps = NULL, model_type = NULL, alpha = 0.2,
legend.key.size = ggplot2::unit(0.8, "cm"),
legend.title = ggplot2::element_text(size = 12)
))

combined_plots <- cowplot::plot_grid(
plotlist = plots,
ncol = cols,
Expand All @@ -331,23 +335,34 @@ plot_model_shaps <- function(data_shaps = NULL, model_type = NULL, alpha = 0.2,

if (model_type == "gn") {
x_label <- cowplot::draw_label("Mesh size (mm)", x = 0.5, y = 0.05)
y_label <- cowplot::draw_label("SHAP value", x = 0.015, y = 0.5, angle = 90)
} else {
x_label <- cowplot::draw_label("Habitat x Gear type ", x = 0.5, y = 0.05)
x_label <- cowplot::draw_label("Habitat x Gear type ", x = 0.015, y = 0.5, angle = 90)
y_label <- cowplot::draw_label("SHAP value", x = 0.5, y = 0.05)
}

y_label <- cowplot::draw_label("SHAP value (impact on model output)", x = 0.015, y = 0.5, angle = 90)

final_plot <-
cowplot::plot_grid(
combined_plots,
legend_plot,
ncol = 2,
rel_widths = c(1, 0.22),
scale = 0.9,
greedy = TRUE
) +
x_label +
y_label

if (drop_legend == TRUE) {
final_plot <-
cowplot::plot_grid(combined_plots,
scale = 0.9,
greedy = TRUE
) +
x_label +
y_label
} else {
final_plot <-
cowplot::plot_grid(
combined_plots,
legend_plot,
ncol = 2,
rel_widths = c(1, 0.22),
scale = 0.9,
greedy = TRUE
) +
x_label +
y_label
}

final_plot
}
2 changes: 1 addition & 1 deletion docs/404.html
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
<meta name="author" content="Lore" />


<meta name="date" content="2023-12-03" />
<meta name="date" content="2023-12-04" />

<meta name="viewport" content="width=device-width, initial-scale=1" />
<meta name="apple-mobile-web-app-capable" content="yes" />
Expand Down
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion docs/data.html
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
<meta name="author" content="Lore" />


<meta name="date" content="2023-12-03" />
<meta name="date" content="2023-12-04" />

<meta name="viewport" content="width=device-width, initial-scale=1" />
<meta name="apple-mobile-web-app-capable" content="yes" />
Expand Down
2 changes: 1 addition & 1 deletion docs/distribution.html
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
<meta name="author" content="Lore" />


<meta name="date" content="2023-12-03" />
<meta name="date" content="2023-12-04" />

<meta name="viewport" content="width=device-width, initial-scale=1" />
<meta name="apple-mobile-web-app-capable" content="yes" />
Expand Down
Loading

0 comments on commit 074a773

Please sign in to comment.