Skip to content

Commit

Permalink
AR params & main heatmap
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael-Geuenich committed Oct 23, 2023
1 parent a1ac35a commit 6431195
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 22 deletions.
37 changes: 26 additions & 11 deletions pipeline/figures/AR-params.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,15 @@ scrna_acc <- read_tsv(snakemake@input$scrna_acc) |>
mutate(cohort = "scRNASeq")
snrna_acc <- read_tsv(snakemake@input$snrna_acc) |>
mutate(cohort = "snRNASeq")
# Accuracy tables for the three additional cohorts. Every row is tagged with
# a `cohort` label naming the table it was read from.
scrna_lung_acc <- mutate(
  read_tsv(snakemake@input$scrna_lung),
  cohort = "scRNALung"
)
liverAtlas_acc <- mutate(
  read_tsv(snakemake@input$liver),
  cohort = "liverAtlas"
)
tabulaVasc_acc <- mutate(
  read_tsv(snakemake@input$vasc),
  cohort = "tabulaVasc"
)

acc <- bind_rows(cytof_acc, scrna_acc, snrna_acc)
# Stack all six per-cohort accuracy tables into one data frame.
acc <- bind_rows(
  cytof_acc,
  scrna_acc,
  snrna_acc,
  scrna_lung_acc,
  liverAtlas_acc,
  tabulaVasc_acc
)



Expand Down Expand Up @@ -43,19 +50,27 @@ plot_knn_res <- function(acc, fill, cohort){
}


pdf(snakemake@output$res, height = 14, width = 9)
(plot_knn_res(cytof_acc, "res", "CyTOF") /
plot_knn_res(scrna_acc, "res", "scRNASeq") /
plot_knn_res(snrna_acc, "res", "snRNASeq")) +
plot_layout(guides = "collect", heights = c(1, 2, 2))
# Write the `res` figure: six panels (one per cohort, "res" passed as the
# `fill` argument of plot_knn_res()), laid out as two side-by-side columns of
# three stacked panels.
pdf(snakemake@output$res,
    height = 23, width = 15)
# In R `+` binds tighter than `|`, so the whole two-column composition has to
# be parenthesised. Without the outer parentheses plot_layout() is added to
# the right-hand column only, and the legends of the left-hand column are not
# collected.
((plot_knn_res(cytof_acc, "res", "CyTOF - Bone marrow") /
    plot_knn_res(scrna_acc, "res", "scRNASeq - Breast cancer cell lines") /
    plot_knn_res(snrna_acc, "res", "snRNASeq - Pancreas cancer")) |
   (plot_knn_res(scrna_lung_acc, "res", "scRNASeq - Lung cancer cell lines") /
      plot_knn_res(liverAtlas_acc, "res", "scRNASeq - Liver") /
      plot_knn_res(tabulaVasc_acc, "res", "scRNASeq - Vasculature"))) +
  plot_layout(guides = "collect")
dev.off()


pdf(snakemake@output$knn, height = 14, width = 9)
(plot_knn_res(cytof_acc, "knn", "CyTOF") /
plot_knn_res(scrna_acc, "knn", "scRNASeq") /
plot_knn_res(snrna_acc, "knn", "snRNASeq")) +
plot_layout(guides = "collect", heights = c(1, 2, 2))
# Write the `knn` figure: six panels (one per cohort, "knn" passed as the
# `fill` argument of plot_knn_res()), laid out as two side-by-side columns of
# three stacked panels.
pdf(snakemake@output$knn,
    height = 23, width = 15)
# In R `+` binds tighter than `|`, so the whole two-column composition has to
# be parenthesised. Without the outer parentheses plot_layout() is added to
# the right-hand column only, and the legends of the left-hand column are not
# collected.
((plot_knn_res(cytof_acc, "knn", "CyTOF - Bone marrow") /
    plot_knn_res(scrna_acc, "knn", "scRNASeq - Breast cancer cell lines") /
    plot_knn_res(snrna_acc, "knn", "snRNASeq - Pancreas cancer")) |
   (plot_knn_res(scrna_lung_acc, "knn", "scRNASeq - Lung cancer cell lines") /
      plot_knn_res(liverAtlas_acc, "knn", "scRNASeq - Liver") /
      plot_knn_res(tabulaVasc_acc, "knn", "scRNASeq - Vasculature"))) +
  plot_layout(guides = "collect")
dev.off()


Expand Down
91 changes: 80 additions & 11 deletions pipeline/figures/figure3.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,34 @@ suppressPackageStartupMessages({
library(dplyr)
library(scales)
library(magick)
library(tidyr)
library(tibble)
library(ComplexHeatmap)
library(circlize)
library(colorspace)
library(ggplot2)
})
devtools::load_all("/ggplot2")
source("pipeline/whatsthatcell-helpers.R")

### [ ACCURACIES ] ####
acc <- lapply(snakemake@input$accs, function(x){
acc <- lapply(snakemake@input$accs,
function(x){
df <- read_tsv(x) |>
mutate(cohort = case_when(grepl("CyTOF", basename(x)) ~ "CyTOF",
grepl("snRNASeq", basename(x)) ~ "snRNASeq",
grepl("scRNASeq", basename(x)) ~ "scRNASeq"))
grepl("scRNASeq", basename(x)) ~ "scRNASeq",
grepl("scRNALung", basename(x)) ~ "scRNALung",
grepl("tabulaLiver", basename(x)) ~ "tabulaLiver",
grepl("tabulaVasc", basename(x)) ~ "tabulaVasc",
grepl("liverAtlas-", basename(x)) ~ "liverAtlas"))
df
}) |>
bind_rows() |>
mutate(cohort = factor(cohort, levels = c("scRNASeq", "snRNASeq", "CyTOF")),
selection_procedure = case_when(selection_procedure == "random" ~ "Random",
mutate(selection_procedure = case_when(selection_procedure == "random" ~ "Random",
selection_procedure == "NoMarkerSeurat_clustering" ~ "AR No Marker",
selection_procedure == "MarkerSeurat_clustering" ~ "AR Marker",
selection_procedure == "highest-entropy-AL" ~ "AL Highest-entropy",
selection_procedure == "lowest-maxp-AL" ~ "AL Lowest-maxp",
selection_procedure == "highest-entropy-AL" ~ "AL Highest entropy",
selection_procedure == "lowest-maxp-AL" ~ "AL Lowest maxp",
selection_procedure == "0.95-entropy-AL" ~ "AL 0.95-entropy",
selection_procedure == "0.05-maxp-AL" ~ "AL 0.05-maxp",
selection_procedure == "0.25-maxp-AL" ~ "AL 0.25-maxp",
Expand All @@ -33,10 +42,70 @@ acc <- lapply(snakemake@input$accs, function(x){

# Alias `sel_met_cols` under the name `sel_meth_cols`.
# NOTE(review): `sel_met_cols` is not defined in the visible code; presumably
# it comes from the sourced helpers file — confirm.
sel_meth_cols <- sel_met_cols

eval <- full_acc_plot_wrapper(acc, "rf", "ranking", "") &
labs(fill = "Selection method")
# Diverging colour mapping for the correlation heatmap:
# -1 -> skyblue, 0 -> white, 1 -> brown1.
col_fun <- colorRamp2(
  breaks = c(-1, 0, 1),
  colors = c("skyblue", "white", "brown1")
)

pdf(snakemake@output$overall_fig, height = 8.4, width = 9)
eval
# Correlation matrix between evaluation metrics.
# Estimates are averaged (ignoring NAs) within every combination of the
# grouping columns below, spread to one column per metric, and the metric
# columns are then correlated pairwise.
metric_wide <- acc |>
  group_by(
    method, knn, res, cell_num, initial,
    selection_procedure, AL_alg, .metric, cohort
  ) |>
  summarize(mean.estimate = mean(.estimate, na.rm = TRUE)) |>
  ungroup() |>
  pivot_wider(names_from = .metric, values_from = mean.estimate) |>
  # Positional range: keeps every column from bal_accuracy through sensitivity.
  select(bal_accuracy:sensitivity) |>
  dplyr::rename(
    "Balanced accuracy" = "bal_accuracy",
    "F1-score" = "f_meas",
    "Kappa" = "kap",
    "MCC" = "mcc",
    "Sensitivity" = "sensitivity"
  )

hm_mat <- cor(as.matrix(metric_wide), use = "pairwise.complete.obs")

# Cell annotator for the heatmap: prints the correlation coefficient of cell
# (i, j), rounded to one decimal, at the cell centre.
label_cell <- function(j, i, x, y, width, height, fill) {
  grid.text(sprintf("%.1f", hm_mat[i, j]), x, y, gp = gpar(fontsize = 10))
}

# Write the metric-correlation heatmap to the `hm` output PDF.
pdf(snakemake@output$hm, height = 6, width = 7)
Heatmap(
  hm_mat,
  col = col_fun,
  name = "Correlation\ncoefficient",
  cell_fun = label_cell
)
dev.off()


# Baseline table: rows of the "Random" selection procedure, with the
# parameter columns dropped and `.estimate` renamed to `rand_estimate`.
random <- acc |>
  filter(selection_procedure == "Random") |>
  select(
    -c(knn, res, rand, corrupted, initial,
       selection_procedure, AL_alg, .estimator)
  ) |>
  dplyr::rename("rand_estimate" = ".estimate")

# Relative improvement of every estimate over the matching random baseline.
# Rows are kept when `initial` is "ranking" or NA and `AL_alg` is "rf" or NA;
# the baseline is matched on method, cell_num, seed, .metric and cohort.
rand_improvement <- acc |>
  filter(
    initial == "ranking" | is.na(initial),
    AL_alg == "rf" | is.na(AL_alg)
  ) |>
  select(-c(.estimator, rand, corrupted, initial)) |>
  left_join(
    random,
    by = c("method", "cell_num", "seed", ".metric", "cohort")
  ) |>
  mutate(rand_improvement = (.estimate - rand_estimate) / rand_estimate)


# Rank heatmap: for each metric and cohort, selection procedures are ranked
# by their median relative improvement over random (rank 1 = largest median),
# and the ranks are drawn as tiles, one facet per metric.
hm <- rand_improvement |>
  group_by(selection_procedure, .metric, cohort) |>
  summarize(median_improvement = median(rand_improvement, na.rm = TRUE)) |>
  group_by(.metric, cohort) |>
  arrange(median_improvement) |>
  # Rows are in ascending order within each group, so the last row gets
  # rank 1. The rank vector is sized from the group itself (n()) rather than
  # from the global number of selection procedures, so a cohort/metric that
  # lacks a procedure no longer causes a length-mismatch error.
  mutate(rank = rev(seq_len(n()))) |>
  ungroup() |>
  filter(.metric %in% c("bal_accuracy", "sensitivity")) |>
  mutate(
    .metric = case_when(
      .metric == "bal_accuracy" ~ "Balanced accuracy",
      .metric == "sensitivity" ~ "Sensitivity"
    ),
    cohort = case_when(
      cohort == "tabulaVasc" ~ "scRNASeq - Vasculature",
      cohort == "snRNASeq" ~ "snRNASeq - Pancreas cancer",
      cohort == "scRNASeq" ~ "scRNASeq - Breast cancer cell lines",
      cohort == "scRNALung" ~ "scRNASeq - Lung cancer cell lines",
      cohort == "liverAtlas" ~ "scRNASeq - Liver",
      cohort == "CyTOF" ~ "CyTOF - Bone marrow",
      # Keep the raw cohort name for anything not listed above instead of
      # silently turning it into NA on the y axis.
      TRUE ~ as.character(cohort)
    )
  ) |>
  ggplot(aes(x = selection_procedure, y = cohort, fill = as.factor(rank))) +
  geom_tile() +
  scale_fill_viridis_d(direction = -1) +
  labs(x = "Selection procedure", y = "", fill = "Rank") +
  facet_wrap(~.metric, nrow = 1) +
  whatsthatcell_theme() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1))


# Render the rank heatmap `hm` into the overall-figure PDF.
pdf(snakemake@output$overall_fig, height = 4, width = 9)
print(hm)
dev.off()

0 comments on commit 6431195

Please sign in to comment.