-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
545aab8
commit 8950545
Showing
1 changed file
with
314 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,314 @@ | ||
suppressPackageStartupMessages({ | ||
library(tidyverse) | ||
library(ggalluvial) | ||
library(patchwork) | ||
}) | ||
source("pipeline/whatsthatcell-helpers.R") | ||
|
||
### PREDICTIVE LABELLING ACCURACY
# Accuracy metrics of the self-training (predictive labelling) classifiers.
pred_lab_acc <- read_tsv("output/v8/new/pred-labeling-accuracy.tsv")

# SUPPLEMENTAL FILE
# F1-score per self-training algorithm, faceted by selection procedure (rows)
# and modality + number of cells (columns).
# The plot is built BEFORE the PDF device is opened so that a failure in the
# pipeline cannot leave a half-open device behind.
supp_pred_lab_acc_plot <- pred_lab_acc |>
  filter(.metric == "f_meas") |>
  mutate(
    # "top10" -> "10%", then fix the display order of the percentages
    pred_sel = paste0(gsub("top", "", pred_sel), "%"),
    pred_sel = factor(pred_sel, c("10%", "50%", "100%")),
    # Short display names; values not listed in a case_when() become NA
    pred = case_when(pred == "multinom" ~ "LR",
                     pred == "rf" ~ "RF"),
    selection_procedure = case_when(
      selection_procedure == "MarkerSeurat-clustering" ~ "AR Marker",
      selection_procedure == "NoMarkerSeurat-clustering" ~ "AR No Marker",
      selection_procedure == "Active-Learning_entropy" ~ "AL Highest entropy",
      selection_procedure == "Active-Learning_maxp" ~ "AL Lowest maxp",
      selection_procedure == "random" ~ "Random"
    )
  ) |>
  ggplot(aes(x = pred_sel, y = .estimate, fill = pred)) +
  geom_boxplot() +
  labs(x = "Percentage of most confidently labelled cells selected",
       fill = "Self-\ntraining\nalgorithm", y = "F-1 score") +
  scale_fill_manual(values = al_colours()) +
  facet_grid(selection_procedure ~ mod + cell_num) +
  whatsthatcell_theme() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1))

pdf("output/v8/paper-figures/Supp-pred-labelling-acc.pdf", height = 10, width = 12)
# Explicit print(): a ggplot is only rendered when printed, and top-level
# autoprint does not happen when this script is run through source().
print(supp_pred_lab_acc_plot)
dev.off()
|
||
# PANEL A, MAIN FIGURE
# F1-scores for the "MarkerSeurat-clustering" selection procedure only, with
# the percentage and algorithm columns recoded for display.
pred_lab_acc_pub <- pred_lab_acc |>
  filter(.metric == "f_meas",
         selection_procedure == "MarkerSeurat-clustering") |>
  mutate(
    # "top10" -> "10%", ordered 10% < 50% < 100%
    pred_sel = factor(paste0(gsub("top", "", pred_sel), "%"),
                      c("10%", "50%", "100%")),
    pred = case_when(pred == "multinom" ~ "LR",
                     pred == "rf" ~ "RF")
  )
|
||
# Boxplot of self-training F1-scores for a single cohort.
#
# acc:    data frame with columns mod, pred_sel, .estimate, pred and cell_num.
# cohort: value of `mod` to keep; also used as the plot title.
# x_lab:  if TRUE the x axis gets its descriptive title, otherwise the title
#         is left empty (useful when several panels share one title).
# Returns a ggplot object with one facet per `cell_num`.
plot_pred_lab_acc <- function(acc, cohort, x_lab = FALSE) {
  axis_title <- if (x_lab) {
    "Percentage of most confidently\nlabelled cells selected"
  } else {
    ""
  }

  cohort_acc <- filter(acc, mod == cohort)

  ggplot(cohort_acc, aes(x = pred_sel, y = .estimate, fill = pred)) +
    geom_boxplot() +
    labs(x = axis_title, fill = "Self-\ntraining\nalgorithm",
         title = cohort, y = "F1-score") +
    scale_fill_manual(values = al_colours()) +
    facet_wrap(~cell_num, nrow = 1) +
    whatsthatcell_theme() +
    theme(axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1))
}
|
||
# Three cohort panels side by side. `&` binds tighter than `|` in R, so each
# theme() call applies only to the panel it follows; the parentheses just make
# that grouping visible. Only the first (CyTOF) panel keeps its y-axis title.
pred_lab_acc_plot <-
  plot_pred_lab_acc(pred_lab_acc_pub, "CyTOF") |
  (plot_pred_lab_acc(pred_lab_acc_pub, "scRNASeq", TRUE) &
    theme(axis.title.y = element_blank())) |
  (plot_pred_lab_acc(pred_lab_acc_pub, "snRNASeq") &
    theme(axis.title.y = element_blank()))
|
||
|
||
|
||
## Benchmarking with predictive labelling data

# Read one cohort's benchmark table and split `selection_procedure` around the
# "-ALAlg-" marker: the text after it goes to `AL_alg`, the text before it
# stays in `selection_procedure`. Values without the marker are left unchanged
# by sub(), so both columns then hold the original value.
read_pred_lab_benchmark <- function(path, cohort_name) {
  read_tsv(path) |>
    mutate(AL_alg = sub(".*-ALAlg-", "", selection_procedure),
           selection_procedure = sub("-ALAlg-.*", "", selection_procedure),
           cohort = cohort_name)
}

scrna <- read_pred_lab_benchmark(
  "output/v8/new/pred2/benchmark-predictive-labeling-scRNASeq.tsv", "scRNASeq"
)
snrna <- read_pred_lab_benchmark(
  "output/v8/new/pred2/benchmark-predictive-labeling-snRNASeq.tsv", "snRNASeq"
)
cytof <- read_pred_lab_benchmark(
  "output/v8/new/pred2/benchmark-predictive-labeling-CyTOF.tsv", "CyTOF"
)

acc <- bind_rows(scrna, snrna, cytof)

# Turn cell_num into a factor whose levels follow the sorted unique values.
cell_nums <- as.character(sort(unique(acc$cell_num)))
acc$cell_num <- factor(acc$cell_num, levels = cell_nums)
|
||
|
||
### Main figure - accuracies
# Columns that are pasted into a single comma-separated `method` key so that
# baseline rows and self-training rows of the same run can be joined, and that
# the key is split back into afterwards. Because the key is a string, every
# column restored by separate() is character (the filters further down compare
# against the literal "NA" for that reason).
run_key_cols <- c("method", "knn", "res", "cell_num", "initial", "seed",
                  "selection_procedure", ".metric", "AL_alg", "cohort")

# Scores obtained without self-training.
baseline_acc <- acc |>
  filter(cell_selection == "baseline") |>
  select(-rand, -corrupted, -.estimator, -pred_lab_alg) |>
  unite("method", all_of(run_key_cols), sep = ",")

# Scores with self-training, one column per self-training algorithm.
# NOTE: this reuses (overwrites) the name `pred_lab_acc` defined earlier.
pred_lab_acc <- acc |>
  filter(cell_selection != "baseline") |>
  select(-rand, -corrupted, -.estimator) |>
  unite("method", all_of(run_key_cols), sep = ",") |>
  pivot_wider(names_from = pred_lab_alg, values_from = .estimate)

# Percentage change of the self-training score relative to the baseline score
# (`.estimate` comes from baseline_acc after the join).
acc_gap <- pred_lab_acc |>
  left_join(select(baseline_acc, -cell_selection), by = "method") |>
  pivot_longer(c(multinom, rf),
               names_to = "pred_labeller",
               values_to = ".estimate_pred_lab") |>
  mutate(gap = 100 * (.estimate_pred_lab / .estimate) - 100) |>
  separate(method, run_key_cols, sep = ",")
|
||
|
||
# Which is better rf or LR?
# % change in F1-score per cell type assignment method, split by self-training
# algorithm, for the "MarkerSeurat-clustering" runs with 100 cells.
lr_vs_rf <- acc_gap |>
  filter(selection_procedure == "MarkerSeurat-clustering",
         cell_num == 100,
         .metric == "f_meas") |>
  mutate(
    pred_labeller = case_when(pred_labeller == "multinom" ~ "LR",
                              pred_labeller == "rf" ~ "RF"),
    # "top10" -> "10%", ordered 10% < 50% < 100%
    cell_selection = factor(paste0(gsub("top", "", cell_selection), "%"),
                            levels = c("10%", "50%", "100%"))
  ) |>
  ggplot(aes(x = method, y = gap, fill = pred_labeller)) +
  geom_boxplot() +
  scale_fill_manual(values = c("#DA94D4", "#7EA3CC")) +
  labs(x = "Cell type assignment method", y = "% change in F1-score",
       fill = "Self-\ntraining\nalgorithm") +
  facet_wrap(~cohort + cell_selection, scales = "free", nrow = 1) +
  whatsthatcell_theme() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1))
|
||
# Compare improvement to original accuracy by selection procedure
# Keeps F1-scores of 100-cell runs self-trained with "multinom" on the "top50"
# selection. `initial` and `AL_alg` may be a real NA or the string "NA"
# (they were round-tripped through a pasted key), so both forms are accepted.
comp_dataset <- acc_gap |>
  filter(.metric == "f_meas",
         cell_num == 100,
         is.na(initial) | initial %in% c("random", "NA"),
         is.na(AL_alg) | AL_alg %in% c("rf", "NA"),
         pred_labeller == "multinom",
         cell_selection == "top50") |>
  mutate(
    # Display names; procedures not listed here become NA
    selection_procedure = case_when(
      selection_procedure == "random" ~ "Random",
      selection_procedure == "MarkerSeurat-clustering" ~ "AR Marker",
      selection_procedure == "NoMarkerSeurat-clustering" ~ "AR No Marker",
      selection_procedure == "0.05-maxp-AL" ~ "AL 0.05 maxp",
      selection_procedure == "0.25-maxp-AL" ~ "AL 0.25 maxp",
      selection_procedure == "0.75-entropy-AL" ~ "AL 0.75 entropy",
      selection_procedure == "0.95-entropy-AL" ~ "AL 0.95 entropy",
      selection_procedure == "highest-entropy-AL" ~ "AL highest entropy",
      selection_procedure == "lowest-maxp-AL" ~ "AL lowest maxp"
    )
  )
|
||
# Scatter plot of baseline F1-score (y) against % self-training improvement
# (x), coloured by selection procedure and faceted by method.
#
# df:   data frame with columns gap, .estimate, selection_procedure, method.
# ncol: number of facet columns.
# ylab, xlab: act as switches rather than labels. Any non-empty string turns
#       on the fixed axis title ("Baseline F1-score" /
#       "% Self-training improvement"); the default "" leaves the axis untitled.
# Returns a ggplot object.
plot_comp <- function(df, ncol, ylab = "", xlab = "") {
  y_title <- if (ylab != "") "Baseline F1-score" else ylab
  x_title <- if (xlab != "") "% Self-training improvement" else xlab

  ggplot(df, aes(x = gap, y = .estimate, colour = selection_procedure)) +
    geom_point() +
    labs(x = x_title,
         y = y_title,
         colour = "Selection procedure") +
    whatsthatcell_theme() +
    facet_wrap(~method, ncol = ncol)
}
|
||
# One scatter panel per cohort. Only the CyTOF panel switches on the y-axis
# title and only the scRNASeq panel switches on the x-axis title (the strings
# passed to ylab/xlab are just non-empty flags for plot_comp()).
cytof_gap <- plot_comp(filter(comp_dataset, cohort == "CyTOF"),
                       1, ylab = "Label") +
  labs(title = "CyTOF")

scrnaseq_gap <- plot_comp(filter(comp_dataset, cohort == "scRNASeq"),
                          2, xlab = "label") +
  labs(title = "scRNASeq")

snrnaseq_gap <- plot_comp(filter(comp_dataset, cohort == "snRNASeq"), 2) +
  labs(title = "snRNASeq")

# Side-by-side layout with a shared legend; the first panel is narrower.
gap_comb <- (cytof_gap | scrnaseq_gap | snrnaseq_gap) +
  plot_layout(guides = "collect", widths = c(0.65, 1, 1))
|
||
# Detecting mislabelled cells
# NOTE: cytof / scrna / snrna are reused here as vectors of file paths.
cytof <- list.files("output/v8/identify_mislabelled/CyTOF/", full.names = TRUE)
scrna <- list.files("output/v8/identify_mislabelled/scRNASeq/", full.names = TRUE)
snrna <- list.files("output/v8/identify_mislabelled/snRNASeq/", full.names = TRUE)

# For every prediction file: compute a per-cell entropy over all columns that
# are not metadata and scale it by log2 of the number of those columns.
# Presumably those columns are per-cell-type probabilities and
# calculate_entropy() (defined outside this file) works in base 2 - TODO confirm.
mislabelled_pred <- c(cytof, scrna, snrna) |>
  lapply(\(path) {
    preds <- read_tsv(path)
    probs <- select(preds, -c(cell_id, pred_type, corr_cell_type, gt_cell_type, params))

    preds$entropy <- apply(probs, 1, calculate_entropy)
    preds$entropy <- preds$entropy / log(ncol(probs), 2)

    select(preds, cell_id, entropy, pred_type, corr_cell_type, gt_cell_type, params)
  }) |>
  bind_rows() |>
  # `params` is split on separate()'s default separator; the pieces named
  # rm_* are discarded, leaving modality, predAlg and seed.
  separate(params, c("rm_mod", "modality", "rm_pred", "predAlg", "rm_seed", "seed")) |>
  select(-starts_with("rm"))
|
||
# Scaled entropy per ground-truth cell type, split by whether the cell's
# training label differed from its ground truth (i.e. was corrupted).
mislabelled_pred_plot <- ggplot(
  mutate(mislabelled_pred,
         cell_is_corrupt = corr_cell_type != gt_cell_type,
         predAlg = case_when(predAlg == "multinom" ~ "LR",
                             predAlg == "rf" ~ "RF")),
  aes(x = gt_cell_type, y = entropy, fill = cell_is_corrupt)
) +
  geom_boxplot() +
  scale_fill_manual(values = c("#8B80F9", "#F03560")) +
  labs(x = "Ground truth cell type", y = "Scaled entropy", fill = "Cell corrupted\nduring training") +
  facet_grid(predAlg ~ modality, scales = "free") +
  whatsthatcell_theme() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1))
|
||
# Main figure: four stacked rows (accuracy panels, LR vs RF, baseline vs
# improvement, mislabelled-cell entropy), tagged A, B, C, ...
# `/` binds tighter than `+`, so plot_layout() and plot_annotation() apply to
# the whole four-row stack. The composition is assembled before the device is
# opened so an error cannot leave the PDF device open.
pred_labelling_fig <-
  wrap_elements(full = pred_lab_acc_plot + plot_layout(guides = "collect")) /
  wrap_elements(full = lr_vs_rf & labs(title = "")) /
  wrap_elements(full = gap_comb) /
  mislabelled_pred_plot +
  plot_layout(heights = c(1, 1.1, 1.3, 1.4)) +
  plot_annotation(tag_levels = "A")

pdf("output/v8/paper-figures/pred-labelling.pdf", height = 14, width = 12)
# Explicit print(): the figure is only rendered when printed, and top-level
# autoprint does not happen when this script is run through source().
print(pred_labelling_fig)
dev.off()
|
||
|
||
# acc_gap with display names for the selection procedure and the self-training
# algorithm; values not listed in a case_when() become NA.
sup_acc_gap <- acc_gap |>
  mutate(
    selection_procedure = case_when(
      selection_procedure == "random" ~ "Random",
      selection_procedure == "MarkerSeurat-clustering" ~ "AR Marker",
      selection_procedure == "NoMarkerSeurat-clustering" ~ "AR No Marker",
      selection_procedure == "0.05-maxp-AL" ~ "AL 0.05 maxp",
      selection_procedure == "0.25-maxp-AL" ~ "AL 0.25 maxp",
      selection_procedure == "0.75-entropy-AL" ~ "AL 0.75 entropy",
      selection_procedure == "0.95-entropy-AL" ~ "AL 0.95 entropy",
      selection_procedure == "highest-entropy-AL" ~ "AL highest entropy",
      selection_procedure == "lowest-maxp-AL" ~ "AL lowest maxp"
    ),
    pred_labeller = case_when(pred_labeller == "multinom" ~ "LR",
                              pred_labeller == "rf" ~ "RF")
  )
|
||
# Supplementary boxplots of % change in F1-score for one cohort.
#
# df:         data frame shaped like sup_acc_gap (needs cohort, .metric,
#             cell_selection, method, gap, pred_labeller, cell_num and
#             selection_procedure).
# sel_cohort: cohort to keep.
# Returns a ggplot object faceted by cell_selection + cell_num (rows) and
# selection procedure (columns).
plot_sup_gap <- function(df, sel_cohort) {
  cohort_f1 <- df |>
    filter(cohort == sel_cohort & .metric == "f_meas") |>
    mutate(
      # "top10" -> "10%", ordered 10% < 50% < 100%
      cell_selection = factor(paste0(gsub("top", "", cell_selection), "%"),
                              levels = c("10%", "50%", "100%"))
    )

  ggplot(cohort_f1, aes(x = method, y = gap, fill = pred_labeller)) +
    geom_boxplot() +
    scale_fill_manual(values = c("#DA94D4", "#7EA3CC")) +
    labs(x = "Cell type assignment method", y = "% change in F1-score",
         fill = "Self-\ntraining\nalgorithm") +
    facet_grid(cell_selection + cell_num ~ selection_procedure, scales = "free") +
    whatsthatcell_theme() +
    theme(axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1))
}
|
||
# Write the supplementary F1-improvement figure of one cohort to `path`.
#
# path:       output PDF file.
# sel_cohort: cohort passed on to plot_sup_gap().
# Returns `path` invisibly.
save_sup_gap_pdf <- function(path, sel_cohort) {
  pdf(path, width = 12, height = 8)
  # Close the device even if building or drawing the plot fails.
  on.exit(dev.off(), add = TRUE)
  # Explicit print(): inside a function (and under source()) a ggplot is not
  # drawn unless it is printed, which would leave the PDF empty.
  print(plot_sup_gap(sup_acc_gap, sel_cohort))
  invisible(path)
}

save_sup_gap_pdf("output/v8/paper-figures/Supp-CyTOF-f1-improvement.pdf", "CyTOF")
save_sup_gap_pdf("output/v8/paper-figures/Supp-scRNASeq-f1-improvement.pdf", "scRNASeq")
save_sup_gap_pdf("output/v8/paper-figures/Supp-snRNASeq-f1-improvement.pdf", "snRNASeq")
|
||
### As function of number of cells predictively labelled
# Exploratory: stack the "multinom" self-training scores on top of the
# baseline scores so both can be drawn along one cell_selection axis, then
# keep the sensitivity metric only.
test <- bind_rows(
  # select() renames multinom to .estimate so the columns line up with baseline_acc
  select(pred_lab_acc, method, cell_selection, .estimate = multinom),
  baseline_acc
) |>
  filter(cell_selection != "top200") |>
  mutate(cell_selection = factor(cell_selection,
                                 c("baseline", "top10", "top50", "top100"))) |>
  # Debugging filter kept from the original, disabled:
  # filter(method == "Random-Forest,10,0.4,100,NA,0,MarkerSeurat-clustering,kap,NA,scRNASeq") |>
  # remove = FALSE keeps the pasted key in `method` (used as the line group below)
  separate(method, c("alg", "knn", "res", "cell_num", "initial", "seed",
                     "selection_procedure", ".metric", "AL_alg", "cohort"),
           sep = ",", remove = FALSE) |>
  filter(.metric == "sensitivity",
         is.na(initial) | initial %in% c("NA", "random"),
         is.na(AL_alg) | AL_alg %in% c("NA", "multinom"))

# One line per run (group = method) for Random-Forest with random selection.
ggplot(filter(test, alg == "Random-Forest", selection_procedure == "random"),
       aes(x = cell_selection, y = .estimate, group = method, colour = cell_num)) +
  geom_point() +
  geom_line() +
  facet_grid(selection_procedure ~ cohort + alg) +
  whatsthatcell_theme()
|
||
# Exploratory: self-training scores joined to their baseline score, in long
# format (one row per run and self-training algorithm).
# The original ended in a bare `pivot_longer()`, which always errors because
# `cols` is a required argument; the columns pivoted here are the same two
# that acc_gap pivots.
# Alternative kept from the original, disabled:
# pivot_wider(names_from = "cell_selection", values_from = "multinom")
left_join(pred_lab_acc,
          select(baseline_acc, -cell_selection),
          by = "method") |>
  pivot_longer(c(multinom, rf),
               names_to = "pred_labeller",
               values_to = ".estimate_pred_lab")
|
||
# Median % change in F1-score per method / cell number / selection procedure /
# cohort / cell selection / self-training algorithm.
# `initial` and `AL_alg` may hold a real NA or the string "NA", so both are kept.
med_gap <- acc_gap |>
  filter(.metric == "f_meas") |>
  filter(initial == "NA" | is.na(initial) | initial == "random") |>
  filter(AL_alg == "NA" | is.na(AL_alg) | AL_alg == "multinom") |>
  group_by(method, cell_num, selection_procedure, cohort, cell_selection, pred_labeller) |>
  # na.rm = TRUE replaces median(na.omit(gap)) with the same result.
  # .groups = "drop" returns an ungrouped tibble: without it the result stays
  # grouped by the first five keys, and code further down re-levels
  # cell_selection, which would then be a grouping column.
  summarize(median_gap = median(gap, na.rm = TRUE), .groups = "drop")
|
||
# Exploratory: median gap across cell_selection for the "multinom" labeller,
# one line per cell_num in each selection_procedure x cohort/method panel.
# Two fixes relative to the original:
#  * the filter was `pred_labeller == "multinom" & method`; `method` is a
#    character column, so the `&` raised an error. The stray operand is dropped.
#  * x is discrete, so without an explicit group every (x, colour) combination
#    is its own one-point group and geom_line() draws nothing; group = cell_num
#    connects the points of each cell number.
med_gap |>
  filter(pred_labeller == "multinom") |>
  ggplot(aes(x = cell_selection, y = median_gap,
             colour = cell_num, group = cell_num)) +
  geom_line() +
  facet_grid(selection_procedure ~ cohort + method) +
  whatsthatcell_theme()
|
||
|
||
# Exploratory: median gap of one configuration (CyTOF-LDA, 100 cells,
# "0.05-maxp-AL", "multinom" labeller) across the top10/top50/top100 selections.
med_gap |>
  filter(cell_selection != "top200") |>
  mutate(cell_selection = factor(cell_selection,
                                 levels = c("top10", "top50", "top100"))) |>
  filter(pred_labeller == "multinom",
         method == "CyTOF-LDA",
         cell_num == 100,
         cohort == "CyTOF",
         selection_procedure == "0.05-maxp-AL") |>
  ggplot(aes(x = cell_selection, y = median_gap)) +
  geom_point()
# Layers kept from the original, disabled:
#   geom_line() +
#   facet_grid(selection_procedure ~ cohort + method) +
#   whatsthatcell_theme()
|