From 3411ee0e3205e4a68914a7282c96cf7f9852e3d4 Mon Sep 17 00:00:00 2001
From: Rashid Al-Abri <9377845+rashidalabri@users.noreply.github.com>
Date: Wed, 7 Jun 2023 05:18:11 -0700
Subject: [PATCH] Add scripts and update README

---
 .gitignore                          |  12 ++
 README.md                           | 149 +++++++++++++-
 scripts/generate_validation_sets.py | 270 +++++++++++++++++++++++++
 scripts/metrics.py                  | 160 +++++++++++++++
 scripts/optimize.py                 | 134 +++++++++++++
 scripts/prioritize.py               | 292 ++++++++++++++++++++++++++++
 6 files changed, 1016 insertions(+), 1 deletion(-)
 create mode 100644 scripts/generate_validation_sets.py
 create mode 100644 scripts/metrics.py
 create mode 100644 scripts/optimize.py
 create mode 100644 scripts/prioritize.py

diff --git a/.gitignore b/.gitignore
index de88b67..a2d100d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,15 @@
 /target
 .cargo-ok
 .DS_Store
+
+# Project-specific files
+datasets/
+output/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# Jupyter Notebook
+.ipynb_checkpoints
diff --git a/README.md b/README.md
index 25b3457..4ef993f 100644
--- a/README.md
+++ b/README.md
@@ -1,17 +1,164 @@
 # strif

 [![Crates.io](https://img.shields.io/crates/v/strif.svg)](https://crates.io/crates/strif)
-[![Docs.rs](https://docs.rs/strif/badge.svg)](https://docs.rs/strif)
 [![CI](https://github.com/rashidalabri/strif/workflows/CI/badge.svg)](https://github.com/rashidalabri/strif/actions)

 ## Installation

+### Download binaries
+
+Binaries for the tool can be found under the "[Releases](https://github.com/rashidalabri/strif/releases)" tab.
+
 ### Cargo

 * Install the rust toolchain in order to have cargo installed by following [this](https://www.rust-lang.org/tools/install) guide.
 * run `cargo install strif`
+
+
+## Usage
+
+### Sequence-graph alignment
+
+To generate a sequence-graph alignment of your sample to STR loci, use [ExpansionHunter](https://github.com/Illumina/ExpansionHunter). The tool produces a `.realigned.bam` file for each sample. Instructions for running ExpansionHunter can be found [here](https://github.com/Illumina/ExpansionHunter/blob/master/docs/03_Usage.md).
+
+### Extracting repeat sequences
+
+To extract repeat sequences from an [ExpansionHunter](https://github.com/Illumina/ExpansionHunter) BAMlet (`.realigned.bam`), run the following command. If the output path is not specified, the output will be saved in the same directory as the BAMlet with a `.repeat_seqs.tsv` suffix.
+
+```
+strif extract <BAMLET> [OUTPUT]
+```
+
+### Profiling STR interruptions
+
+To profile STR interruptions from extracted repeat sequences (the output of `strif extract`), run the following command. The STR catalog needs to be in the same format as [these catalogs](https://github.com/Illumina/RepeatCatalogs). If the output path is not specified, the output will be saved in the same directory as the repeat sequences file with a `.strif_profile.tsv` suffix.
+
+```
+strif profile [OPTIONS] <REPEAT_SEQS> <STR_CATALOG> [OUTPUT] [OUTPUT_ALIGNMENTS]
+```
+
+#### Options
+```
+  -z                       Output visual alignments. Default is false
+  -f, --filter <FILTER>    Filter locus IDs using a regular expression. Defaults to None. This is useful for filtering out loci that are not of interest
+  -A <MATCH_SCORE>         Match score [default: 1]
+  -B <MISMATCH_PENALTY>    Mismatch penalty [default: 8]
+  -O <GAP_OPEN_PENALTY>    Gap open penalty [default: 10]
+  -E <GAP_EXTEND_PENALTY>  Gap extend penalty [default: 1]
+```
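+
+For example, a single-sample extract-and-profile run might look like this (file names are illustrative):
+
+```
+strif extract sample.realigned.bam sample.repeat_seqs.tsv
+strif profile sample.repeat_seqs.tsv str_catalog.json sample.strif_profile.tsv
+```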
+
+### Merging STR interruption profiles
+
+To merge STR interruption profiles from multiple samples, run the following command. If the output path is not specified, the output will be saved in the same directory as the manifest file with a `.merged_profiles.tsv` suffix.
+
+```
+strif merge [OPTIONS] <MANIFEST> <READ_DEPTHS> [OUTPUT]
+```
+
+- Manifest
+  - Tab-separated file with the following columns:
+    - Sample ID, sample status (case or control), path to STRIF profile
+  - Do not include a header
+  - Example
+    - ```
+      DO45195_case	case	output/DO45195_case.strif_profile.tsv
+      DO45195_control	control	output/DO45195_control.strif_profile.tsv
+      DO45231_case	case	output/DO45231_case.strif_profile.tsv
+      DO45231_control	control	output/DO45231_control.strif_profile.tsv
+      ```
+- Read depths
+  - Tab-separated file with the following columns:
+    - Sample ID, read depth
+  - Do not include a header
+  - Example
+    - ```
+      DO219580_case	73.15
+      DO219580_control	34.47
+      DO22836_case	69.76
+      DO22836_control	35.62
+      ```
+
+#### Options
+```
+  -f, --filter <FILTER>
+          Filter locus IDs using a regular expression. Defaults to None. This is useful for filtering out loci that are not of interest
+  -m, --min-read-count <MIN_READ_COUNT>
+          Minimum read count to include in the merged profile. This is useful for filtering out loci with low coverage [default: 1]
+  -l, --read-length <READ_LENGTH>
+          The sequencing read length, used for normalizing the interruption counts [default: 150]
+  -h, --help
+          Print help
+```
+
+### Prioritizing interruptions
+
+To find interruptions that display a significant difference between case and control samples, use `prioritize.py` in the `scripts` directory.
+
+The prioritization script expects sample IDs to be formatted as `<donor ID>_<status>`, where the status is either `case` or `control`. If a paired test is run using the `-t` option, each donor is expected to have exactly one case and one control sample.
+
+```
+python prioritize.py [OPTIONS] <MERGED_PROFILE> <OUTPUT_FILE> <SIG_OUTPUT_FILE>
+```
+
+- Output file
+  - Contains information about all tested interruptions, including p-values and effect sizes
+  - Does not include interruption counts
+- Sig(nificant) output file
+  - Contains information about all interruptions with a p-value below the cut-off
+  - Includes interruption counts (helpful for plotting data)
+
+> Note: Currently, the script does not perform multiple hypothesis test correction. It is strongly recommended to perform this step independently.
+
+#### Options
+```
+  -n MIN_SAMPLES, --min-samples MIN_SAMPLES
+                        Minimum number of samples per group (case or control)
+  -p P_VALUE_CUTOFF, --p-value-cutoff P_VALUE_CUTOFF
+                        P-value cutoff
+  -t, --paired-test     Enable paired test
+  -c CHUNK_SIZE, --chunk-size CHUNK_SIZE
+                        Chunk size for reading merged profile
+  --no-progress         Disable progress bars
+```
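+
+For example, to run a paired analysis (file names are illustrative):
+
+```
+python prioritize.py -t merged_profiles.tsv interruptions.tsv sig_interruptions.tsv
+```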
+
+### Generating validation datasets
+
+You can generate simulated repeat sequences to validate and test STRIF using `generate_validation_sets.py` in the `scripts` directory. The only argument is the path to a directory, such as `datasets/`, where the generated datasets will be created.
+
+```
+python generate_validation_sets.py <DATASETS_DIR>
+```
+
+- Generated datasets
+  - `simple`
+    - Small dataset helpful for debugging
+  - `no_interruption`
+    - Repeat sequences containing no interruptions
+  - `basic_<1-6>`
+    - Small dataset useful for development
+  - `comprehensive_<train|valid|test>`
+    - Comprehensive datasets useful for optimizing parameters, validation, and testing
+  - `disjoint_<1-6>`
+    - Dataset of disjoint interruptions, where the interruption sequence does not include any bases from the repeat sequence
+  - `intersect_<1-6>`
+    - Dataset of intersecting interruptions, where the interruption sequence includes at least one base from the repeat sequence
+  - `insert_<1-6>`
+    - Dataset of interruptions that have been inserted into the repeat sequence
+  - `substitute_<1-6>`
+    - Dataset of interruptions that have substituted one or more repeat sequence bases
+
+### Calculating performance metrics
+
+You can calculate metrics on the generated datasets using `metrics.py` in the `scripts` directory. The only required argument is the path to the directory, such as `datasets/`, where the generated datasets were created. Pass `test` as an optional second argument to also evaluate the held-out `comprehensive_test` set, which is skipped by default.
+
+```
+python metrics.py <DATASETS_DIR> [test]
+```
+
+The script will output a file `overall_stats.tsv` in the dataset directory containing a summary of metrics for each dataset.
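+
+The summary has one row per dataset, with the columns below (values are illustrative):
+
+```
+prefix	exact_precision	exact_recall	inexact_precision	inexact_recall
+simple	1.00	1.00	1.00	1.00
+basic_1	0.95	0.97	0.98	0.99
+```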
+ """ + n_repeat = n // len(motif) + 1 + seq = motif * n_repeat + if rotate: + rotate_n = random.randint(0, len(motif) - 1) + seq = rotate_seq(seq, rotate_n) + return seq[:n] + + +def simulate_repeat_seq( + motif_len_range=DEFAULT_RANGES["motif"], + seq_len_range=DEFAULT_RANGES["seq"], + intrpt_len_range=DEFAULT_RANGES["intrpt"], + insert=None, + intersect_alpha=None, + rotate=True, +): + motif_len = random.randint(*motif_len_range) + seq_len = random.randint(*seq_len_range) + intrpt_len = random.randint(*intrpt_len_range) + + # if true, + if intersect_alpha is None or intersect_alpha: + motif_alpha = ALPHABET + intrpt_alpha = ALPHABET + else: + # the motif and interruption sequences are generated from disjoint alphabets + motif_alpha_len = random.randint(2, len(ALPHABET) - 1) + motif_alpha = random.sample(ALPHABET, motif_alpha_len) + intrpt_alpha = [x for x in ALPHABET if x not in motif_alpha] + + # generate repeat sequence + motif = random_seq(motif_len, motif_alpha, allow_homopolymer=False) + seq = repeat_seq(motif, seq_len, rotate=rotate) + + if intersect_alpha is not None and intersect_alpha: + intrpt_alpha = list(set(motif)) + + # generate interruption sequence and position + intrpt = random_seq(intrpt_len, intrpt_alpha) + intrpt_pos = random.randint(1, len(seq) - len(intrpt) - 1) + + # if insert is not specified, randomly choose whether to insert or substitute + if insert is None: + insert = bool(random.getrandbits(1)) + + if insert: + # insert the interruption sequence into the repeat sequence + intrpt_seq = seq[:intrpt_pos] + intrpt + seq[intrpt_pos:] + intrpt_seq = intrpt_seq[:seq_len] + else: + # substitute the interruption sequence for a portion of the repeat sequence + intrpt_seq = seq[:intrpt_pos] + intrpt + seq[len(intrpt) + intrpt_pos :] + + if not insert and intrpt == seq[intrpt_pos : intrpt_pos + len(intrpt)]: + # if the interruption sequence is the same as sequence + # it is substituting, set the interruption sequence to be empty + # since there will not be an interruption technically + intrpt = "" + + return motif, intrpt, intrpt_seq + + +def generate_files(repeat_seqs, dir_path, prefix): + motifs = [x[0] for x in repeat_seqs] + intrpts = [x[1] for x in repeat_seqs] + seqs = [x[2] for x in repeat_seqs] + + # create output directory + os.makedirs(dir_path / prefix, exist_ok=True) + + # create file paths + truth_path = dir_path / prefix / f"{prefix}.truth.tsv" + repeat_seqs_path = dir_path / prefix / f"{prefix}.repeat_seqs.tsv" + str_catalog_path = dir_path / prefix / f"{prefix}.str_catalog.json" + + # create truth dataframe + df = pd.DataFrame( + { + "locus_id": range(len(repeat_seqs)), + "motif": motifs, + "interruption": intrpts, + "seq": seqs, + } + ) + + # write truth file + df.to_csv(truth_path, index=False, sep="\t") + + # write repeat_seqs file for tool input + df[["locus_id", "seq"]].to_csv(repeat_seqs_path, index=False, sep="\t", header=None) + + # create str_catalog file for tool input + str_catalog = [] + for _, row in df.iterrows(): + motif = row["motif"] + str_catalog.append( + {"LocusId": str(row["locus_id"]), "LocusStructure": f"({motif})*"} + ) + + # write str_catalog file + with open(str_catalog_path, "w") as f: + json.dump(str_catalog, f, indent=4) + + +def generate_dataset(n, prefix, prefixes, dir_path, **kwargs): + repeat_seqs = [simulate_repeat_seq(**kwargs) for _ in range(n)] + generate_files(repeat_seqs, dir_path, prefix) + prefixes.append(prefix) + print(f"Generated {prefix} dataset") + + +def main(): + random.seed(SEED) + dir_path = 
+
+
+def generate_files(repeat_seqs, dir_path, prefix):
+    motifs = [x[0] for x in repeat_seqs]
+    intrpts = [x[1] for x in repeat_seqs]
+    seqs = [x[2] for x in repeat_seqs]
+
+    # create output directory
+    os.makedirs(dir_path / prefix, exist_ok=True)
+
+    # create file paths
+    truth_path = dir_path / prefix / f"{prefix}.truth.tsv"
+    repeat_seqs_path = dir_path / prefix / f"{prefix}.repeat_seqs.tsv"
+    str_catalog_path = dir_path / prefix / f"{prefix}.str_catalog.json"
+
+    # create truth dataframe
+    df = pd.DataFrame(
+        {
+            "locus_id": range(len(repeat_seqs)),
+            "motif": motifs,
+            "interruption": intrpts,
+            "seq": seqs,
+        }
+    )
+
+    # write truth file
+    df.to_csv(truth_path, index=False, sep="\t")
+
+    # write repeat_seqs file for tool input
+    df[["locus_id", "seq"]].to_csv(repeat_seqs_path, index=False, sep="\t", header=None)
+
+    # create str_catalog file for tool input
+    str_catalog = []
+    for _, row in df.iterrows():
+        motif = row["motif"]
+        str_catalog.append(
+            {"LocusId": str(row["locus_id"]), "LocusStructure": f"({motif})*"}
+        )
+
+    # write str_catalog file
+    with open(str_catalog_path, "w") as f:
+        json.dump(str_catalog, f, indent=4)
+
+
+def generate_dataset(n, prefix, prefixes, dir_path, **kwargs):
+    repeat_seqs = [simulate_repeat_seq(**kwargs) for _ in range(n)]
+    generate_files(repeat_seqs, dir_path, prefix)
+    prefixes.append(prefix)
+    print(f"Generated {prefix} dataset")
+
+
+def main():
+    random.seed(SEED)
+    dir_path = Path(sys.argv[1])
+
+    prefixes = []
+
+    # simple dataset
+    generate_dataset(
+        N_SMALL,
+        "simple",
+        prefixes,
+        dir_path,
+        motif_len_range=(2, 2),
+        seq_len_range=(10, 10),
+        intrpt_len_range=(1, 1),
+        intersect_alpha=False,
+        insert=True,
+        rotate=False,
+    )
+
+    # no interruption dataset
+    generate_dataset(
+        N_SMALL,
+        "no_interruption",
+        prefixes,
+        dir_path,
+        intrpt_len_range=(0, 0),
+    )
+
+    # disjoint alphabet validation set (with variable interruption length)
+    for i in range(1, 7):
+        generate_dataset(
+            N_LARGE,
+            f"disjoint_{i}",
+            prefixes,
+            dir_path,
+            intrpt_len_range=(i, i),
+            intersect_alpha=False,
+        )
+
+    # intersecting alphabet validation set (with variable interruption length)
+    for i in range(1, 7):
+        generate_dataset(
+            N_LARGE,
+            f"intersect_{i}",
+            prefixes,
+            dir_path,
+            intrpt_len_range=(i, i),
+            intersect_alpha=True,
+        )
+
+    # insertion validation set
+    for i in range(1, 7):
+        generate_dataset(
+            N_LARGE,
+            f"insert_{i}",
+            prefixes,
+            dir_path,
+            intrpt_len_range=(i, i),
+            insert=True,
+        )
+
+    # substitution validation set
+    for i in range(1, 7):
+        generate_dataset(
+            N_LARGE,
+            f"substitute_{i}",
+            prefixes,
+            dir_path,
+            intrpt_len_range=(i, i),
+            insert=False,
+        )
+
+    # basic dataset
+    for i in range(1, 7):
+        generate_dataset(
+            N_LARGE,
+            f"basic_{i}",
+            prefixes,
+            dir_path,
+            intrpt_len_range=(i, i),
+        )
+
+    # comprehensive set for training
+    generate_dataset(
+        N_LARGE,
+        "comprehensive_train",
+        prefixes,
+        dir_path,
+    )
+
+    # comprehensive set for validation
+    generate_dataset(
+        N_LARGE,
+        "comprehensive_valid",
+        prefixes,
+        dir_path,
+    )
+
+    # comprehensive set for testing
+    generate_dataset(
+        N_LARGE,
+        "comprehensive_test",
+        prefixes,
+        dir_path,
+    )
+
+    # write prefixes file
+    with open(dir_path / "prefixes.txt", "w") as f:
+        f.write("\n".join(prefixes))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/metrics.py b/scripts/metrics.py
new file mode 100644
index 0000000..76d2a13
--- /dev/null
+++ b/scripts/metrics.py
@@ -0,0 +1,160 @@
+from pathlib import Path
+import subprocess
+import pandas as pd
+import sys
+
+
+def run_command(
+    repeat_seqs,
+    str_catalog,
+    output,
+    match_score=None,
+    mismatch_score=None,
+    gap_open_score=None,
+    gap_extend_score=None,
+    visualize=False,
+):
+    """
+    Run `strif profile` (via `cargo run`) on the given inputs and return the
+    command string that was executed.
+    """
+    command = f"cargo run --release -- profile {repeat_seqs} {str_catalog} {output}"
+    if match_score is not None:
+        command += f" -A={match_score}"
+    if mismatch_score is not None:
+        command += f" -B={mismatch_score}"
+    if gap_open_score is not None:
+        command += f" -O={gap_open_score}"
+    if gap_extend_score is not None:
+        command += f" -E={gap_extend_score}"
+    if visualize:
+        command += " -z"
+    subprocess.run(command, shell=True, stderr=subprocess.DEVNULL)
+    return command
+
+
+def create_stat_df(truth_path, profile_path):
+    truth = pd.read_csv(truth_path, sep="\t")
+    profile = pd.read_csv(profile_path, sep="\t")
+
+    truth = truth.rename(columns={"interruption": "true_interruption"})
+    profile = profile.rename(columns={"interruption": "pred_interruption"})
+
+    merged = pd.merge(truth, profile, on=["locus_id", "motif"], how="outer")
+    merged = merged.fillna("")
+    merged = merged.drop(columns=["seq", "count", "total_count"])
+
+    merged["exact_match"] = merged["true_interruption"] == merged["pred_interruption"]
+    merged["exact_match"] = merged["exact_match"].astype(int)
+
+    # inexact match is true if the true interruption is a substring of the
+    # predicted interruption
+    merged["inexact_match"] = merged.apply(
+        lambda row: row["true_interruption"] in row["pred_interruption"], axis=1
+    )
+    merged["inexact_match"] = merged["inexact_match"].astype(int)
+
+    return merged
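+
+
+# Interpretation (assumed) of the aggregates below: within each locus, the
+# mean of a match column approximates precision (the fraction of predicted
+# interruptions at that locus that match the truth), while the sum
+# approximates recall (whether the true interruption was recovered at all);
+# the per-locus values are then averaged across loci.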
+
+
+def compute_mean_precision(stat_df, match_col):
+    mean_precision = (
+        stat_df.groupby("locus_id").agg({match_col: "mean"})[match_col].mean()
+    )
+    return mean_precision
+
+
+def compute_mean_recall(stat_df, match_col):
+    mean_recall = (
+        stat_df.groupby("locus_id").agg({match_col: "sum"})[match_col].mean()
+    )
+    return mean_recall
+
+
+def compute_metrics(stats_df):
+    exact_mean_precision = compute_mean_precision(stats_df, "exact_match")
+    exact_mean_recall = compute_mean_recall(stats_df, "exact_match")
+
+    inexact_mean_precision = compute_mean_precision(stats_df, "inexact_match")
+    inexact_mean_recall = compute_mean_recall(stats_df, "inexact_match")
+
+    return (
+        exact_mean_precision,
+        exact_mean_recall,
+        inexact_mean_precision,
+        inexact_mean_recall,
+    )
+
+
+def main():
+    dir_path = Path(sys.argv[1])
+
+    run_test = False
+    if len(sys.argv) > 2:
+        run_test = sys.argv[2] == "test"
+
+    # load prefixes file
+    prefixes = []
+    with open(dir_path / "prefixes.txt", "r") as f:
+        for line in f:
+            prefixes.append(line.strip())
+
+    overall_stats = []
+
+    # run tool on each prefix
+    for prefix in prefixes:
+        # skip test sets unless explicitly requested
+        if "test" in prefix and not run_test:
+            print(f"Skipping {prefix}")
+            continue
+
+        run_command(
+            dir_path / prefix / f"{prefix}.repeat_seqs.tsv",
+            dir_path / prefix / f"{prefix}.str_catalog.json",
+            dir_path / prefix / f"{prefix}.strif_profile.tsv",
+            visualize=True,
+        )
+
+        stat_df = create_stat_df(
+            dir_path / prefix / f"{prefix}.truth.tsv",
+            dir_path / prefix / f"{prefix}.strif_profile.tsv",
+        )
+        stat_df.to_csv(
+            dir_path / prefix / f"{prefix}.compare.tsv", sep="\t", index=False
+        )
+
+        (
+            exact_mean_prec,
+            exact_mean_rec,
+            inexact_mean_prec,
+            inexact_mean_rec,
+        ) = compute_metrics(stat_df)
+
+        overall_stats.append(
+            (
+                prefix,
+                exact_mean_prec,
+                exact_mean_rec,
+                inexact_mean_prec,
+                inexact_mean_rec,
+            )
+        )
+
+        print(f"Completed {prefix}:")
+        print(
+            f"  - Exact mean precision, recall: {exact_mean_prec:.2f}, {exact_mean_rec:.2f}"
+        )
+        print(
+            f"  - Inexact mean precision, recall: {inexact_mean_prec:.2f}, {inexact_mean_rec:.2f}"
+        )
+
+    overall_stats_df = pd.DataFrame(
+        overall_stats,
+        columns=[
+            "prefix",
+            "exact_precision",
+            "exact_recall",
+            "inexact_precision",
+            "inexact_recall",
+        ],
+    )
+    overall_stats_df.to_csv(dir_path / "overall_stats.tsv", sep="\t", index=False)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/optimize.py b/scripts/optimize.py
new file mode 100644
index 0000000..66c41d0
--- /dev/null
+++ b/scripts/optimize.py
@@ -0,0 +1,134 @@
+import errno
+import os
+import sys
+from itertools import product
+from pathlib import Path
+import time
+
+import pandas as pd
+from metrics import compute_metrics, create_stat_df, run_command
+
+CORES = 10
+
+MATCH_SCORE_RANGE = list(range(1, 2))
+MISMATCH_PENALTY_RANGE = list(range(1, 11))
+GAP_OPEN_PENALTY_RANGE = list(range(1, 11))
+GAP_EXTEND_PENALTY_RANGE = list(range(1, 11))
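+
+
+# Note: the grid search below parallelizes with os.fork(), so this script is
+# POSIX-only. Each child process evaluates its own slice of the parameter
+# grid and appends rows to the shared results file, flushing after each row.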
+
+
+def start_process(id, params, valid_dir_path, prefix, tmp_dir_path, results_file):
+    pid = os.fork()
+    if pid == 0:
+        print(f"[Process {id}] Starting...", flush=True)
+
+        truth_path = valid_dir_path / prefix / f"{prefix}.truth.tsv"
+        repeat_seqs_path = valid_dir_path / prefix / f"{prefix}.repeat_seqs.tsv"
+        str_catalog_path = valid_dir_path / prefix / f"{prefix}.str_catalog.json"
+
+        for i, (m, x, g, e) in enumerate(params):
+            profile_path = tmp_dir_path / prefix / f"{prefix}_proc{id}_{i}.tmp"
+
+            if i % 100 == 0:
+                percent_complete = int(i / len(params) * 100)
+                print(f"[Process {id}] {percent_complete}%", flush=True)
+
+            cmd = run_command(
+                repeat_seqs_path, str_catalog_path, profile_path, m, x, g, e
+            )
+            time.sleep(0.01)
+
+            try:
+                stat_df = create_stat_df(truth_path, profile_path)
+                metrics = "\t".join([str(v) for v in compute_metrics(stat_df)])
+                results_file.write(f"{m}\t{x}\t{g}\t{e}\t{metrics}\n")
+                results_file.flush()
+            except OSError as err:
+                if err.errno != errno.ENOENT:
+                    raise
+                print(f"[Process {id}] Failed to run command: {cmd}", flush=True)
+
+            # delete the temporary profile file
+            try:
+                os.remove(profile_path)
+            except OSError as err:
+                if err.errno != errno.ENOENT:
+                    raise
+
+        # when finished, exit the child process
+        print(f"[Process {id}] Finished", flush=True)
+        results_file.close()
+        exit(0)
+    else:
+        return pid
+
+
+def sort_results(unsorted_file_path, sorted_file_path):
+    # sort the results using pandas
+    print("Sorting results...", flush=True)
+    results = pd.read_csv(unsorted_file_path, sep="\t")
+    results["mean_recall"] = (results["exact_recall"] + results["inexact_recall"]) / 2
+    results.sort_values("mean_recall", ascending=False, inplace=True)
+    results.to_csv(sorted_file_path, sep="\t", index=False)
+
+    results = results.head(10)
+
+    # print average of each column
+    print("----------------------------------------", flush=True)
+    print("Avg. of top 10 parameters:", flush=True)
+    print(results.mean(axis=0), flush=True)
+    print("----------------------------------------", flush=True)
+
+    # delete unsorted file
+    os.remove(unsorted_file_path)
+
+
+def perform_param_grid_search(params, valid_dir_path, prefix, tmp_dir_path, cores):
+    print(f"Testing {len(params)} combinations using {cores} cores...")
+
+    with open(
+        valid_dir_path / prefix / f"{prefix}.param_search.unsorted.tsv", "w"
+    ) as f:
+        f.write(
+            "match_score\tmismatch_penalty\tgap_open_penalty\tgap_extend_penalty\texact_precision\texact_recall\tinexact_precision\tinexact_recall\n"
+        )
+        f.flush()
+
+        batch_size = len(params) // cores + 1
+
+        for i in range(cores):
+            start_idx = i * batch_size
+            batch = params[start_idx : start_idx + batch_size]
+            start_process(i + 1, batch, valid_dir_path, prefix, tmp_dir_path, f)
+
+        for _ in range(cores):
+            os.wait()
+
+    sort_results(
+        valid_dir_path / prefix / f"{prefix}.param_search.unsorted.tsv",
+        valid_dir_path / prefix / f"{prefix}.param_search.tsv",
+    )
+
+    print("Done", flush=True)
+
+
+def main():
+    params = list(
+        product(
+            MATCH_SCORE_RANGE,
+            MISMATCH_PENALTY_RANGE,
+            GAP_OPEN_PENALTY_RANGE,
+            GAP_EXTEND_PENALTY_RANGE,
+        )
+    )
+
+    # keep only combinations where the mismatch penalty exceeds the match
+    # score and the gap-open penalty exceeds the gap-extend penalty
+    params = [(m, x, g, e) for m, x, g, e in params if x > m and g > e]
+
+    valid_dir_path = Path(sys.argv[1])
+    tmp_dir_path = valid_dir_path
+    prefix = sys.argv[2]
+    perform_param_grid_search(params, valid_dir_path, prefix, tmp_dir_path, CORES)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/prioritize.py b/scripts/prioritize.py
new file mode 100644
index 0000000..10a1281
--- /dev/null
+++ b/scripts/prioritize.py
@@ -0,0 +1,292 @@
+import argparse
+from collections import defaultdict
+
+import pandas as pd
+import scipy.stats
+from numpy import mean, sqrt, std
+from tqdm.auto import tqdm
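+
+
+# Cohen's d with pooled standard deviation (what cohen_d below computes):
+#     d = (mean(x) - mean(y)) / s_pooled
+#     s_pooled = sqrt(((nx - 1) * sx**2 + (ny - 1) * sy**2) / (nx + ny - 2))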
+
+
+def cohen_d(cases, controls):
+    """
+    Source: https://stackoverflow.com/questions/21532471/how-to-calculate-cohens-d-in-python
+    """
+    nx = len(cases)
+    ny = len(controls)
+    dof = nx + ny - 2
+    return (mean(cases) - mean(controls)) / sqrt(
+        ((nx - 1) * std(cases, ddof=1) ** 2 + (ny - 1) * std(controls, ddof=1) ** 2)
+        / dof
+    )
+
+
+def file_len(fname):
+    i = -1
+    with open(fname) as f:
+        for i, _ in enumerate(f):
+            pass
+    return i + 1
+
+
+def main():
+    # Create the argument parser
+    parser = argparse.ArgumentParser(description="Prioritize interruptions")
+    parser.add_argument("merged_profile", type=str, help="Path to merged profile")
+    parser.add_argument(
+        "output_file", type=str, help="Path to output file containing all interruptions"
+    )
+    parser.add_argument(
+        "sig_output_file",
+        type=str,
+        help="Path to output file containing interruptions that pass the p-value cutoff with corresponding normalized counts",
+    )
+    parser.add_argument(
+        "-n",
+        "--min-samples",
+        type=int,
+        default=2,
+        help="Minimum number of samples per group (case or control)",
+    )
+    parser.add_argument(
+        "-p", "--p-value-cutoff", type=float, default=0.05, help="P-value cutoff"
+    )
+    parser.add_argument(
+        "-t", "--paired-test", action="store_true", help="Enable paired test"
+    )
+    parser.add_argument(
+        "-c",
+        "--chunk-size",
+        type=int,
+        default=5000,
+        help="Chunk size for reading merged profile",
+    )
+    parser.add_argument(
+        "--no-progress",
+        action="store_true",
+        help="Disable progress bars",
+    )
+
+    # Parse the arguments
+    args = parser.parse_args()
+
+    # Access the parsed arguments
+    merged_profile_path = args.merged_profile
+    output_file_path = args.output_file
+    sig_output_file_path = args.sig_output_file
+    min_samples = args.min_samples
+    p_value_cutoff = args.p_value_cutoff
+    paired_test = args.paired_test
+    chunk_size = args.chunk_size
+    disable_progress = args.no_progress
+
+    # Load merged profile
+    merged_profile = pd.read_csv(merged_profile_path, sep="\t", chunksize=chunk_size)
+    num_chunks = int((file_len(merged_profile_path) - 1) / chunk_size + 1)
+
+    if paired_test:
+        print("Paired test enabled")
+    else:
+        print("Paired test disabled")
+
+    n_skipped_interruptions = 0
+    n_skipped_loci = 0
+
+    output = []
+
+    for chunk in tqdm(
+        merged_profile,
+        total=num_chunks,
+        desc="Chunks",
+        position=0,
+        disable=disable_progress,
+    ):
+        # Loop over the loci in the merged profile chunk
+        for _, row in tqdm(
+            chunk.iterrows(),
+            total=len(chunk),
+            desc="Loci",
+            position=1,
+            leave=False,
+            disable=disable_progress,
+        ):
+            # We exclude loci that do not have enough samples included.
+            # We do an initial check here to avoid unnecessary computation
+            read_counts = row["read_counts"].split(",")
+            if len(read_counts) < min_samples * 2:
+                n_skipped_loci += 1
+                continue
+
+            # We find the donors that are included in the analysis of this locus by
+            # looking at which samples have read counts
+            case_donors = []
+            control_donors = []
+
+            for entry in read_counts:
+                sample_id, count = entry.split(":")
+                donor_id, status = sample_id.split("_")
+
+                if status == "case":
+                    if donor_id in case_donors:
+                        raise Exception(f"Duplicate case sample: {sample_id}")
+                    case_donors.append(donor_id)
+                elif status == "control":
+                    if donor_id in control_donors:
+                        raise Exception(f"Duplicate control sample: {sample_id}")
+                    control_donors.append(donor_id)
+                else:
+                    raise Exception(f"Unknown sample status: {sample_id}")
+
+            # If we're doing a paired test, exclude donors that do not have a paired case/control sample
+            if paired_test:
+                paired_donors = set(case_donors).intersection(set(control_donors))
+                paired_donors = list(paired_donors)
+                case_donors = paired_donors
+                control_donors = paired_donors
+
+            case_donors.sort()
+            control_donors.sort()
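+
+            # After pairing, case_donors and control_donors hold the same
+            # sorted donor list, so the case and control count vectors built
+            # below line up index-by-index, as the paired Wilcoxon test requires.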
+
+            # We skip loci that do not have enough donors included
+            if len(case_donors) < min_samples or len(control_donors) < min_samples:
+                n_skipped_loci += 1
+                continue
+
+            # Create a dict of interruption unit -> status -> donor_id -> count
+            interruptions = defaultdict(
+                lambda: {
+                    "case": {d: 0 for d in case_donors},
+                    "control": {d: 0 for d in control_donors},
+                }
+            )
+
+            intrp_units_to_skip = set()
+
+            # Fill the interruptions dict from the interruption_counts column,
+            # which contains info for all interruptions in the locus
+            for entry in row["interruption_counts"].split(","):
+                sample_id, intrpt_unit, count = entry.split(":")
+                donor_id, status = sample_id.split("_")
+                if status not in ["case", "control"]:
+                    raise Exception(f"Unknown sample status: {sample_id}")
+                elif status == "case" and donor_id not in case_donors:
+                    continue
+                elif status == "control" and donor_id not in control_donors:
+                    continue
+
+                count = float(count)
+                if (
+                    pd.isna(count)
+                    or count == float("inf")
+                    or count == float("-inf")
+                    or count < 0
+                ):
+                    intrp_units_to_skip.add(intrpt_unit)
+
+                interruptions[intrpt_unit][status][donor_id] = count
+
+            # For each interruption unit, we calculate the p-value and Cohen's d
+            for intrpt_unit, counts in tqdm(
+                interruptions.items(),
+                total=len(interruptions),
+                desc="Interruptions",
+                position=2,
+                leave=False,
+                disable=disable_progress,
+            ):
+                # Skip interruptions that have invalid counts
+                if intrpt_unit in intrp_units_to_skip:
+                    n_skipped_interruptions += 1
+                    continue
+
+                case_counts = [counts["case"][d] for d in case_donors]
+                control_counts = [counts["control"][d] for d in control_donors]
+
+                # calculate p-value
+                if paired_test:
+                    _, p_value = scipy.stats.wilcoxon(
+                        case_counts, control_counts, alternative="two-sided"
+                    )
+                else:
+                    _, p_value = scipy.stats.mannwhitneyu(
+                        case_counts, control_counts, alternative="two-sided"
+                    )
+
+                # Calculate Cohen's d
+                cohen_d_value = cohen_d(case_counts, control_counts)
+
+                if pd.isna(p_value):
+                    print(
+                        f"Warning: NaN p-value for {row['locus_id']} and '{intrpt_unit}' interruption. Skipping..."
+                    )
+                    n_skipped_interruptions += 1
+                    continue
+                elif pd.isna(cohen_d_value):
+                    print(
+                        f"Warning: NaN Cohen's d for {row['locus_id']} and '{intrpt_unit}' interruption. Skipping..."
+                    )
+                    n_skipped_interruptions += 1
+                    continue
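+
+                # Note: p-values are reported raw, with no multiple-testing
+                # correction (see the README note); correct them downstream
+                # before interpreting significance.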
+ ) + n_skipped_interruptions += 1 + continue + + # We only output the counts if the p-value is below the cutoff + if p_value < p_value_cutoff: + read_counts_str = row["read_counts"] + interruption_counts_str = [] + for d in case_donors: + interruption_counts_str.append(f"{d}_case:{counts['case'][d]}") + for d in control_donors: + interruption_counts_str.append( + f"{d}_control:{counts['control'][d]}" + ) + interruption_counts_str = ",".join(interruption_counts_str) + else: + read_counts_str = "" + interruption_counts_str = "" + + output.append( + ( + row["locus_id"], + row["reference_region"], + row["motif"], + intrpt_unit, + len(case_donors), + len(control_donors), + p_value, + cohen_d_value, + read_counts_str, + interruption_counts_str, + ) + ) + + # create a dataframe from the output + output_cols = [ + "locus_id", + "reference_region", + "motif", + "interruption", + "n_case", + "n_control", + "p_value", + "cohen_d", + "read_counts", + "interruption_counts", + ] + output_df = pd.DataFrame(output, columns=output_cols) + + # sort by p-value + output_df.sort_values(by=["p_value"], inplace=True, ignore_index=True) + + # write to file + output_df[output_cols[:-2]].to_csv(output_file_path, sep="\t", index=False) + + # only keep interruptions with p-value < cutoff + output_df = output_df[output_df["p_value"] < p_value_cutoff] + + # write to file + output_df.to_csv(sig_output_file_path, sep="\t", index=False) + + print(f"Skipped {n_skipped_loci} loci") + print(f"Skipped {n_skipped_interruptions} interruptions") + print("Done!") + + +main()