Skip to content

Commit

Permalink
Merge pull request #121 from r-causal/geom-dag
Browse files Browse the repository at this point in the history
Add `aes_dag()` and `geom_dag()`
  • Loading branch information
malcolmbarrett authored Jan 29, 2024
2 parents c4c7b40 + 55d7049 commit 7e253a9
Show file tree
Hide file tree
Showing 48 changed files with 2,151 additions and 1,209 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Imports:
ggraph (>= 2.0.0),
ggrepel,
igraph,
lifecycle (>= 0.2.0),
magrittr,
pillar,
purrr,
Expand All @@ -44,5 +45,5 @@ VignetteBuilder:
Encoding: UTF-8
Language: en-US
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
Config/testthat/edition: 3
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ export("update_dag<-")
export("update_dag_data<-")
export(activate_collider_paths)
export(adjust_for)
export(aes_dag)
export(as_tidy_dagitty)
export(butterfly_bias)
export(collider_triangle)
Expand All @@ -63,6 +64,7 @@ export(dag_paths)
export(dagify)
export(expand_plot)
export(filter)
export(geom_dag)
export(geom_dag_collider_edges)
export(geom_dag_edges)
export(geom_dag_edges_arc)
Expand Down Expand Up @@ -174,6 +176,10 @@ importFrom(ggplot2,"%+replace%")
importFrom(ggplot2,aes)
importFrom(ggplot2,fortify)
importFrom(ggplot2,ggplot)
importFrom(lifecycle,deprecate_soft)
importFrom(lifecycle,deprecate_warn)
importFrom(lifecycle,deprecated)
importFrom(lifecycle,is_present)
importFrom(magrittr,"%$%")
importFrom(magrittr,"%>%")
importFrom(purrr,"%||%")
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# ggdag (development version)

* Introduced new functions `aes_dag()` and `geom_dag()` to simplify specification of ggplot code for most DAGs. Also refactored most quick plots to use these functions (#121)

# ggdag 0.2.11

* Internal update to address upcoming changes in ggplot2 (#125, thanks @teunbrand)
Expand Down
151 changes: 71 additions & 80 deletions R/adjustment_sets.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,7 @@
#' `NULL`, in which case it will be determined from the DAG.
#' @param ... additional arguments to `adjustmentSets`
#' @param shadow logical. Show paths blocked by adjustment?
#' @param node_size size of DAG node
#' @param text_size size of DAG text
#' @param label_size size of label text
#' @param text_col color of DAG text
#' @param label_col color of label text
#' @param node logical. Should nodes be included in the DAG?
#' @param stylized logical. Should DAG nodes be stylized? If so, use
#' `geom_dag_nodes` and if not use `geom_dag_point`
#' @param text logical. Should text be included in the DAG?
#' @param use_labels a string. Variable to use for `geom_dag_label_repel()`.
#' Default is `NULL`.
#' @inheritParams geom_dag
#' @inheritParams expand_plot
#'
#' @return a `tidy_dagitty` with an `adjusted` column and `set`
Expand Down Expand Up @@ -81,25 +71,29 @@ extract_sets <- function(sets) {
#' @rdname adjustment_sets
#' @export
ggdag_adjustment_set <- function(.tdy_dag, exposure = NULL, outcome = NULL, ..., shadow = FALSE,
node_size = 16, text_size = 3.88, label_size = text_size,
text_col = "white", label_col = text_col,
node = TRUE, stylized = FALSE, text = TRUE, use_labels = NULL,
size = 1, node_size = 16, text_size = 3.88,
label_size = text_size,
text_col = "white", label_col = "black",
edge_width = 0.6, edge_cap = 8, arrow_length = 5,
use_edges = TRUE, use_nodes = TRUE, use_stylized = FALSE,
use_text = TRUE, use_labels = FALSE, label = NULL,
text = NULL, node = deprecated(), stylized = deprecated(),
expand_x = expansion(c(0.25, 0.25)),
expand_y = expansion(c(0.2, 0.2))) {
.tdy_dag <- if_not_tidy_daggity(.tdy_dag) %>%
dag_adjustment_sets(exposure = exposure, outcome = outcome, ...)

p <- ggplot2::ggplot(.tdy_dag, ggplot2::aes(
x = x, y = y, xend = xend,
yend = yend, shape = adjusted,
col = adjusted
)) +
ggplot2::facet_wrap(~set) +
p <- ggplot2::ggplot(
.tdy_dag,
aes_dag(shape = adjusted, color = adjusted)
) +
ggplot2::facet_wrap(~ set) +
scale_adjusted() +
expand_plot(expand_x = expand_x, expand_y = expand_y)

if (shadow) {
p <- p + geom_dag_edges(ggplot2::aes(edge_alpha = adjusted),
p <- p + geom_dag_edges(
ggplot2::aes(edge_alpha = adjusted),
start_cap = ggraph::circle(10, "mm"),
end_cap = ggraph::circle(10, "mm")
)
Expand All @@ -116,27 +110,28 @@ ggdag_adjustment_set <- function(.tdy_dag, exposure = NULL, outcome = NULL, ...,
)
}

if (node) {
if (stylized) {
p <- p + geom_dag_node(size = node_size)
} else {
p <- p + geom_dag_point(size = node_size)
}
}

if (text) p <- p + geom_dag_text(col = text_col, size = text_size)
p <- p +
geom_dag(
size = size,
node_size = node_size,
text_size = text_size,
label_size = label_size,
text_col = text_col,
label_col = label_col,
edge_width = edge_width,
edge_cap = edge_cap,
arrow_length = arrow_length,
use_edges = FALSE,
use_nodes = use_nodes,
use_stylized = use_stylized,
use_text = use_text,
use_labels = use_labels,
text = !!rlang::enquo(text),
label = !!rlang::enquo(label),
node = node,
stylized = stylized
)

if (!is.null(use_labels)) {
p <- p +
geom_dag_label_repel(
ggplot2::aes(
label = !!rlang::sym(use_labels),
fill = adjusted
),
size = text_size,
col = label_col, show.legend = FALSE
)
}
p
}

Expand Down Expand Up @@ -176,17 +171,7 @@ is_confounder <- function(.tdy_dag, z, x, y, direct = FALSE) {
#' `dagitty`
#' @param var a character vector, the variable(s) to adjust for.
#' @param ... additional arguments passed to `tidy_dagitty()`
#' @param node_size size of DAG node
#' @param text_size size of DAG text
#' @param label_size size of label text
#' @param text_col color of DAG text
#' @param label_col color of label text
#' @param node logical. Should nodes be included in the DAG?
#' @param stylized logical. Should DAG nodes be stylized? If so, use
#' `geom_dag_nodes` and if not use `geom_dag_point`
#' @param text logical. Should text be included in the DAG?
#' @param use_labels a string. Variable to use for
#' `geom_dag_label_repel()`. Default is `NULL`.
#' @inheritParams geom_dag
#' @param collider_lines logical. Should the plot show paths activated by
#' adjusting for a collider?
#' @param as_factor logical. Should the `adjusted` column be a factor?
Expand Down Expand Up @@ -222,9 +207,14 @@ adjust_for <- control_for
#' @rdname control_for
#' @export
ggdag_adjust <- function(.tdy_dag, var = NULL, ...,
size = 1, edge_type = c("link_arc", "link", "arc", "diagonal"),
node_size = 16, text_size = 3.88, label_size = text_size,
text_col = "white", label_col = text_col,
node = TRUE, stylized = FALSE, text = TRUE, use_labels = NULL, collider_lines = TRUE) {
text_col = "white", label_col = "black",
edge_width = 0.6, edge_cap = 10, arrow_length = 5,
use_edges = TRUE,
use_nodes = TRUE, use_stylized = FALSE, use_text = TRUE,
use_labels = FALSE, text = NULL, label = NULL,
node = deprecated(), stylized = deprecated(), collider_lines = TRUE) {
.tdy_dag <- if_not_tidy_daggity(.tdy_dag, ...)
if (!is.null(var)) {
.tdy_dag <- .tdy_dag %>% control_for(var)
Expand All @@ -235,38 +225,39 @@ ggdag_adjust <- function(.tdy_dag, var = NULL, ...,
}

p <- .tdy_dag %>%
ggplot2::ggplot(ggplot2::aes(
x = x, y = y, xend = xend, yend = yend,
col = adjusted, shape = adjusted
)) +
geom_dag_edges(ggplot2::aes(edge_alpha = adjusted),
start_cap = ggraph::circle(10, "mm"),
end_cap = ggraph::circle(10, "mm")
ggplot2::ggplot(aes_dag(col = adjusted, shape = adjusted)) +
geom_dag_edges(
ggplot2::aes(edge_alpha = adjusted),
start_cap = ggraph::circle(edge_cap, "mm"),
end_cap = ggraph::circle(edge_cap, "mm")
) +
scale_adjusted() +
expand_plot(expand_y = expansion(c(0.2, 0.2)))

if (collider_lines) p <- p + geom_dag_collider_edges()
if (node) {
if (stylized) {
p <- p + geom_dag_node(size = node_size)
} else {
p <- p + geom_dag_point(size = node_size)
}
}

if (text) p <- p + geom_dag_text(col = text_col, size = text_size)
p <- p +
geom_dag(
size = size,
edge_type = edge_type,
node_size = node_size,
text_size = text_size,
label_size = label_size,
text_col = text_col,
label_col = label_col,
edge_width = edge_width,
edge_cap = edge_cap,
arrow_length = arrow_length,
use_edges = FALSE,
use_nodes = use_nodes,
use_stylized = use_stylized,
use_text = use_text,
use_labels = use_labels,
text = !!rlang::enquo(text),
label = !!rlang::enquo(label),
node = node,
stylized = stylized
)

if (!is.null(use_labels)) {
p <- p +
geom_dag_label_repel(
ggplot2::aes(
label = !!rlang::sym(use_labels),
fill = adjusted
),
size = text_size,
col = label_col, show.legend = FALSE
)
}
p
}
25 changes: 7 additions & 18 deletions R/canonical.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,7 @@
#' @param .dag,.tdy_dag input graph, an object of class `tidy_dagitty` or
#' `dagitty`
#' @param ... additional arguments passed to `tidy_dagitty()`
#' @param edge_type a character vector, the edge geom to use. One of:
#' "link_arc", which accounts for directed and bidirected edges, "link",
#' "arc", or "diagonal"
#' @param node_size size of DAG node
#' @param text_size size of DAG text
#' @param label_size size of label text
#' @param text_col color of DAG text
#' @param label_col color of label text
#' @param node logical. Should nodes be included in the DAG?
#' @param stylized logical. Should DAG nodes be stylized? If so, use
#' `geom_dag_nodes` and if not use `geom_dag_point`
#' @param text logical. Should text be included in the DAG?
#' @param use_labels a string. Variable to use for `geom_dag_label_repel()`.
#' Default is `NULL`.
#' @inheritParams geom_dag
#'
#' @return a `tidy_dagitty` that includes L or a `ggplot`
#' @export
Expand All @@ -45,14 +32,16 @@ node_canonical <- function(.dag, ...) {
#' @export
ggdag_canonical <- function(.tdy_dag, ..., edge_type = "link_arc", node_size = 16, text_size = 3.88,
label_size = text_size,
text_col = "white", label_col = text_col,
node = TRUE, stylized = FALSE, text = TRUE,
use_labels = NULL) {
text_col = "white", label_col = text_col, use_edges = TRUE,
use_nodes = TRUE, use_stylized = FALSE, use_text = TRUE,
use_labels = NULL, label = NULL, text = NULL, node = deprecated(),
stylized = deprecated()) {
if_not_tidy_daggity(.tdy_dag, ...) %>%
node_canonical() %>%
ggdag(
node_size = node_size, text_size = text_size, label_size,
edge_type = edge_type, text_col = text_col, label_col = label_col,
node = node, stylized = stylized, text = text, use_labels = use_labels
use_edges = use_edges, use_nodes = use_nodes, use_stylized = use_stylized,
use_text = use_text, use_labels = use_labels
)
}
Loading

0 comments on commit 7e253a9

Please sign in to comment.