From 6431195d7a2f73bbd56c6479ca749b6daffcbc7d Mon Sep 17 00:00:00 2001
From: Michael Geuenich
Date: Mon, 23 Oct 2023 13:46:13 -0400
Subject: [PATCH] AR params & main heatmap

---
Review notes (free-form text after the "---" separator, not part of any hunk):

- AR-params.R: R parses "+" before "|", so
  "(a / b / c) | (d / e / f) + plot_layout(guides = "collect")" attached the
  layout to the right-hand column only. The two-column composition is now
  wrapped in parentheses in both pdf() bodies so the layout is added to the
  whole figure, as the replaced single-column code did. No lines were added
  or removed, so the hunk counts and the diffstat below still hold.
- This file was restored to one diff line per physical line. Indentation and
  the position of blank context lines were reconstructed to satisfy the hunk
  headers; confirm them against the target files before applying.

 pipeline/figures/AR-params.R | 37 ++++++++++-----
 pipeline/figures/figure3.R   | 91 +++++++++++++++++++++++++++++++-----
 2 files changed, 106 insertions(+), 22 deletions(-)

diff --git a/pipeline/figures/AR-params.R b/pipeline/figures/AR-params.R
index debe3d7..fb760c9 100644
--- a/pipeline/figures/AR-params.R
+++ b/pipeline/figures/AR-params.R
@@ -11,8 +11,15 @@ scrna_acc <- read_tsv(snakemake@input$scrna_acc) |>
   mutate(cohort = "scRNASeq")
 snrna_acc <- read_tsv(snakemake@input$snrna_acc) |>
   mutate(cohort = "snRNASeq")
+scrna_lung_acc <- read_tsv(snakemake@input$scrna_lung) |>
+  mutate(cohort = "scRNALung")
+liverAtlas_acc <- read_tsv(snakemake@input$liver) |>
+  mutate(cohort = "liverAtlas")
+tabulaVasc_acc <- read_tsv(snakemake@input$vasc) |>
+  mutate(cohort = "tabulaVasc")
 
-acc <- bind_rows(cytof_acc, scrna_acc, snrna_acc)
+acc <- bind_rows(cytof_acc, scrna_acc, snrna_acc,
+                 scrna_lung_acc, liverAtlas_acc, tabulaVasc_acc)
 
 
 
@@ -43,19 +50,27 @@ plot_knn_res <- function(acc, fill, cohort){
 }
 
 
-pdf(snakemake@output$res, height = 14, width = 9)
-  (plot_knn_res(cytof_acc, "res", "CyTOF") /
-    plot_knn_res(scrna_acc, "res", "scRNASeq") /
-    plot_knn_res(snrna_acc, "res", "snRNASeq")) +
-    plot_layout(guides = "collect", heights = c(1, 2, 2))
+pdf(snakemake@output$res,
+    height = 23, width = 15)
+  ((plot_knn_res(cytof_acc, "res", "CyTOF - Bone marrow") /
+    plot_knn_res(scrna_acc, "res", "scRNASeq - Breast cancer cell lines") /
+    plot_knn_res(snrna_acc, "res", "snRNASeq - Pancreas cancer")) |
+  (plot_knn_res(scrna_lung_acc, "res", "scRNASeq - Lung cancer cell lines") /
+    plot_knn_res(liverAtlas_acc, "res", "scRNASeq - Liver") /
+    plot_knn_res(tabulaVasc_acc, "res", "scRNASeq - Vasculature"))) +
+    plot_layout(guides = "collect")
 dev.off()
 
 
-pdf(snakemake@output$knn, height = 14, width = 9)
-(plot_knn_res(cytof_acc, "knn", "CyTOF") /
-    plot_knn_res(scrna_acc, "knn", "scRNASeq") /
-    plot_knn_res(snrna_acc, "knn", "snRNASeq")) +
-  plot_layout(guides = "collect", heights = c(1, 2, 2))
+pdf(snakemake@output$knn,
+    height = 23, width = 15)
+  ((plot_knn_res(cytof_acc, "knn", "CyTOF - Bone marrow") /
+    plot_knn_res(scrna_acc, "knn", "scRNASeq - Breast cancer cell lines") /
+    plot_knn_res(snrna_acc, "knn", "snRNASeq - Pancreas cancer")) |
+  (plot_knn_res(scrna_lung_acc, "knn", "scRNASeq - Lung cancer cell lines") /
+    plot_knn_res(liverAtlas_acc, "knn", "scRNASeq - Liver") /
+    plot_knn_res(tabulaVasc_acc, "knn", "scRNASeq - Vasculature"))) +
+    plot_layout(guides = "collect")
 dev.off()
 
 
diff --git a/pipeline/figures/figure3.R b/pipeline/figures/figure3.R
index 63b5f2d..f435019 100644
--- a/pipeline/figures/figure3.R
+++ b/pipeline/figures/figure3.R
@@ -5,25 +5,34 @@ suppressPackageStartupMessages({
   library(dplyr)
   library(scales)
   library(magick)
+  library(tidyr)
+  library(tibble)
+  library(ComplexHeatmap)
+  library(circlize)
+  library(colorspace)
+  library(ggplot2)
 })
-devtools::load_all("/ggplot2")
 source("pipeline/whatsthatcell-helpers.R")
 
 ### [ ACCURACIES ] ####
-acc <- lapply(snakemake@input$accs, function(x){
+acc <- lapply(snakemake@input$accs,
+              function(x){
   df <- read_tsv(x) |>
     mutate(cohort = case_when(grepl("CyTOF", basename(x)) ~ "CyTOF",
                               grepl("snRNASeq", basename(x)) ~ "snRNASeq",
-                              grepl("scRNASeq", basename(x)) ~ "scRNASeq"))
+                              grepl("scRNASeq", basename(x)) ~ "scRNASeq",
+                              grepl("scRNALung", basename(x)) ~ "scRNALung",
+                              grepl("tabulaLiver", basename(x)) ~ "tabulaLiver",
+                              grepl("tabulaVasc", basename(x)) ~ "tabulaVasc",
+                              grepl("liverAtlas-", basename(x)) ~ "liverAtlas"))
   df
 }) |>
   bind_rows() |>
-  mutate(cohort = factor(cohort, levels = c("scRNASeq", "snRNASeq", "CyTOF")),
-         selection_procedure = case_when(selection_procedure == "random" ~ "Random",
+  mutate(selection_procedure = case_when(selection_procedure == "random" ~ "Random",
                                          selection_procedure == "NoMarkerSeurat_clustering" ~ "AR No Marker",
                                          selection_procedure == "MarkerSeurat_clustering" ~ "AR Marker",
-                                         selection_procedure == "highest-entropy-AL" ~ "AL Highest-entropy",
-                                         selection_procedure == "lowest-maxp-AL" ~ "AL Lowest-maxp",
+                                         selection_procedure == "highest-entropy-AL" ~ "AL Highest entropy",
+                                         selection_procedure == "lowest-maxp-AL" ~ "AL Lowest maxp",
                                          selection_procedure == "0.95-entropy-AL" ~ "AL 0.95-entropy",
                                          selection_procedure == "0.05-maxp-AL" ~ "AL 0.05-maxp",
                                          selection_procedure == "0.25-maxp-AL" ~ "AL 0.25-maxp",
@@ -33,10 +42,70 @@ acc <- lapply(snakemake@input$accs, function(x){
 
 sel_meth_cols <- sel_met_cols
 
-eval <- full_acc_plot_wrapper(acc, "rf", "ranking", "") &
-  labs(fill = "Selection method")
+col_fun <- colorRamp2(c(-1,0,1), c("skyblue","white", "brown1"))
 
 
-pdf(snakemake@output$overall_fig, height = 8.4, width = 9)
-  eval
+hm_mat <- acc |>
+  group_by(method, knn, res, cell_num, initial, selection_procedure, AL_alg, .metric, cohort) |>
+  summarize(mean.estimate = mean(na.omit(.estimate))) |>
+  ungroup() |>
+  pivot_wider(names_from = .metric, values_from = mean.estimate) |>
+  select(bal_accuracy:sensitivity) |>
+  dplyr::rename("Balanced accuracy" = "bal_accuracy",
+                "F1-score" = "f_meas",
+                "Kappa" = "kap",
+                "MCC" = "mcc",
+                "Sensitivity" = "sensitivity") |>
+  as.matrix() |>
+  cor(use = "pairwise.complete.obs")
+
+pdf(snakemake@output$hm, height = 6, width = 7)
+  Heatmap(hm_mat,
+          col = col_fun,
+          name = "Correlation\ncoefficient",
+          cell_fun = function(j, i, x, y, width, height, fill) {
+            grid.text(sprintf("%.1f", hm_mat[i, j]), x, y, gp = gpar(fontsize = 10))
+          }
+  )
+dev.off()
+
+
+random <- filter(acc, selection_procedure == "Random") |>
+  select(-c(knn, res, rand, corrupted, initial, selection_procedure, AL_alg, .estimator)) |>
+  dplyr::rename("rand_estimate" = ".estimate")
+
+rand_improvement <- acc |>
+  filter(initial == "ranking" | is.na(initial)) |>
+  filter(AL_alg == "rf" | is.na(AL_alg)) |>
+  select(-c(.estimator, rand, corrupted, initial)) |>
+  left_join(random, by = c("method", "cell_num", "seed", ".metric", "cohort")) |>
+  mutate(rand_improvement = (.estimate - rand_estimate) / rand_estimate)
+
+
+hm <- rand_improvement |>
+  group_by(selection_procedure, .metric, cohort) |>
+  summarize(median_improvement = median(na.omit(rand_improvement))) |>
+  group_by(.metric, cohort) |>
+  arrange(median_improvement) |>
+  mutate(rank = length(unique(rand_improvement$selection_procedure)):1) |>
+  filter(.metric == "bal_accuracy" | .metric == "sensitivity") |>
+  mutate(.metric = case_when(.metric == "bal_accuracy" ~ "Balanced accuracy",
+                             .metric == "sensitivity" ~ "Sensitivity"),
+         cohort = case_when(cohort == "tabulaVasc" ~ "scRNASeq - Vasculature",
+                            cohort == "snRNASeq" ~ "snRNASeq - Pancreas cancer",
+                            cohort == "scRNASeq" ~ "scRNASeq - Breast cancer cell lines",
+                            cohort == "scRNALung" ~ "scRNASeq - Lung cancer cell lines",
+                            cohort == "liverAtlas" ~ "scRNASeq - Liver",
+                            cohort == "CyTOF" ~ "CyTOF - Bone marrow")) |>
+  ggplot(aes(x = selection_procedure, y = cohort, fill = as.factor(rank))) +
+  geom_tile() +
+  scale_fill_viridis_d(direction = -1) +
+  labs(x = "Selection procedure", y = "", fill = "Rank") +
+  facet_wrap(~.metric, nrow = 1) +
+  whatsthatcell_theme() +
+  theme(axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1))
+
+
+pdf(snakemake@output$overall_fig, height = 4, width = 9)
+  hm
 dev.off()