diff --git a/bin/check b/bin/check index ba11459..991d8f4 100755 --- a/bin/check +++ b/bin/check @@ -4,19 +4,32 @@ from __future__ import annotations import gzip import logging +import multiprocessing as mp import re from abc import ABC, abstractmethod from argparse import ArgumentParser, Namespace from collections import defaultdict +from contextlib import nullcontext from dataclasses import dataclass, field +from enum import Enum, auto +from functools import partial from glob import glob from math import isclose - -# from multiprocessing.dummy import Pool from pathlib import Path -from subprocess import call +from subprocess import STDOUT, check_output from tempfile import TemporaryDirectory -from typing import Literal +from typing import ( + Callable, + ContextManager, + Iterable, + Iterator, + Literal, + Sized, + Type, + TypeVar, +) + +from tqdm.auto import tqdm logger = logging.getLogger("check") @@ -28,7 +41,7 @@ chain_file_regex = re.compile(r"(?Phg\d+)ToHg19.over.chain.gz") def verbose_call(command: list[str]) -> None: logger.info(f"Running command line {command}") - call(command) + check_output(command, stderr=STDOUT) def parse_chromosome(chromosome_str: str) -> Chromosome | None: @@ -145,8 +158,10 @@ class Variant: return self.alleles == other.alleles def match_alleles(self, other: Variant) -> MatchResult: - assert self.chromosome == other.chromosome - assert self.position == other.position + if self.chromosome != other.chromosome: + raise ValueError("Chromosomes do not match") + if self.position != other.position: + raise ValueError("Positions do not match") if self.is_ambiguous: return self.match_alleles_ambiguous(other) @@ -229,8 +244,9 @@ class Variant: ambiguous_matched = self else: ambiguous_matched = self.flip() + if ambiguous_matched is None: + raise ValueError("Could not flip ambiguous variant") - assert ambiguous_matched is not None return MatchResult(variant=ambiguous_matched) def match_alleles_incomplete(self, other: Variant) -> MatchResult: @@ -291,10 +307,8 @@ def read_legend_file( header = legend_file_handle.readline() columns = header.split() - assert columns[0] == "id" - assert columns[1] == "position" - assert columns[2] == "a0" - assert columns[3] == "a1" + if columns[:4] != ["id", "position", "a0", "a1"]: + raise ValueError(f'Invalid header in legend file "{legend_file}"') for line in legend_file_handle: values = line.split() @@ -307,7 +321,11 @@ def read_legend_file( allele_a=allele_a, allele_b=allele_b, ) - assert variant.position is not None + if variant.position is None: + raise ValueError( + f'Invalid position "{position_str}" for variant "{id}" ' + f'from legend file "{legend_file}"' + ) legend_file_variants[variant.position].append(variant) legend_file_variant_count += 1 @@ -404,11 +422,12 @@ class FlipCommand(CommandBase): self.counters: dict[str, int] = defaultdict(int) - self.exclude: list[str] = list() + self.extract: list[str] = list() self.flip: list[str] = list() self.update_alleles: list[tuple[str, str, str, str, str]] = list() def _on_match(self, sample_variant: Variant, matched_variant: Variant): + self.extract.append(sample_variant.id) for action in matched_variant.actions: if action == "turn": pass # we leave this up to the imputation step @@ -449,7 +468,6 @@ class FlipCommand(CommandBase): position not in reference_variants or len(reference_variants[position]) == 0 ): - self.exclude.append(sample_variant.id) logger.debug( f"Excluding variant {sample_variant.id} " "because it could not be found in the reference" @@ -464,14 +482,14 @@ class FlipCommand(CommandBase): if match_result.variant is not None: break - assert match_result is not None + if match_result is None: + raise ValueError("Could not find matching variant in reference") matched_variant = match_result.variant if matched_variant is None or len(match_result.messages) > 0: print_variant = matched_variant if print_variant is None: print_variant = sample_variant - self.exclude.append(sample_variant.id) logger.debug( f"Excluding variant \n{print_variant} because it " "could not be matched to the reference variants \n" @@ -494,9 +512,9 @@ class FlipCommand(CommandBase): "\n".join(f"{key}\t{value:d}" for key, value in counters_lines) + "\n" ) - exclude_file = f"{prefix}.snp.exclude.txt" - with open(exclude_file, "wt") as exclude_file_handle: - exclude_file_handle.write("\n".join(self.exclude) + "\n") + extract_file = f"{prefix}.snp.txt" + with open(extract_file, "wt") as extract_file_handle: + extract_file_handle.write("\n".join(self.extract) + "\n") flip_file = f"{prefix}.snp.flip.txt" with open(flip_file, "wt") as flip_file_handle: @@ -514,8 +532,8 @@ class FlipCommand(CommandBase): update_alleles_file, "--bfile", bfile, - "--exclude", - exclude_file, + "--extract", + extract_file, "--flip", flip_file, "--out", @@ -641,7 +659,8 @@ class Hg19Command(CommandBase): if match_result.variant is not None: break - assert match_result is not None + if match_result is None: + raise ValueError("Could not find matching variant in reference") matched_variant = match_result.variant if matched_variant is None or len(match_result.messages) > 0: @@ -649,7 +668,7 @@ class Hg19Command(CommandBase): self.overlap_counts[version] += 1 - def reduce(self, bfile: str): + def reduce(self, bfile: str) -> None: overlap: dict[str, float] = { version: float(self.overlap_counts[version]) / float(len(variants)) for version, variants in self.lifted_variants.items() @@ -658,9 +677,11 @@ class Hg19Command(CommandBase): version, max_overlap = max(overlap.items(), key=lambda t: t[1]) if isclose(max_overlap, 0): - raise ValueError( - "The data has no overlap with any of the available reference genomes" + logger.error( + "No overlap with any of the available reference genomes" + f'for file "{bfile}"' ) + return logger.info( f'Sample is in "{version}" considering the overlap with the ' @@ -683,11 +704,11 @@ class Hg19Command(CommandBase): return lifted_variant_ids = frozenset(v.id for v in self.lifted_variants[version]) - exclude = [v.id for v in self.sample_variants if v.id not in lifted_variant_ids] + extract = [v.id for v in self.sample_variants if v.id in lifted_variant_ids] - exclude_file = f"{prefix}.snp.exclude.txt" - with open(exclude_file, "wt") as exclude_file_handle: - exclude_file_handle.write("\n".join(exclude) + "\n") + extract_file = f"{prefix}.snp.txt" + with open(extract_file, "wt") as extract_file_handle: + extract_file_handle.write("\n".join(extract) + "\n") new_variants = self.lifted_variants[version] @@ -700,8 +721,8 @@ class Hg19Command(CommandBase): "plink", "--bfile", bfile, - "--exclude", - exclude_file, + "--extract", + extract_file, "--chr", ",".join(chromosomes), "--out", @@ -710,11 +731,13 @@ class Hg19Command(CommandBase): ] ) - old_variants = read_bim_file(prefix) + extracted_variants = read_bim_file(prefix) - assert len(old_variants) == len(new_variants) - for old, new in zip(old_variants, new_variants): - assert old.id == new.id + if len(extracted_variants) > len(new_variants): + raise ValueError( + f"Number of variants after liftOver is larger " + f"the expected number: {len(extracted_variants)} != {len(new_variants)}" + ) bim_file = f"{prefix}.bim" write_bim_file(bim_file, new_variants) @@ -725,18 +748,23 @@ Command = Hg19Command | FlipCommand def parse_arguments(argv: list[str]) -> Namespace: argument_parser = ArgumentParser() - subparsers = argument_parser.add_subparsers() + + subparsers = argument_parser.add_subparsers(required=True) flip_subparser = subparsers.add_parser("flip") flip_subparser.set_defaults(command=FlipCommand) - flip_subparser.add_argument("--bfile", type=str, required=True) + flip_subparser.add_argument( + "--bfile", action="extend", nargs="+", type=str, required=True + ) flip_subparser.add_argument( "--legend-files", action="extend", nargs="+", type=str, required=False ) hg19_subparser = subparsers.add_parser("hg19") hg19_subparser.set_defaults(command=Hg19Command) - hg19_subparser.add_argument("--bfile", type=str, required=True) + hg19_subparser.add_argument( + "--bfile", action="extend", nargs="+", type=str, required=True + ) hg19_subparser.add_argument( "--legend-files", action="extend", nargs="+", type=str, required=False ) @@ -744,12 +772,66 @@ def parse_arguments(argv: list[str]) -> Namespace: "--chain-files", action="extend", nargs="+", type=str, required=False ) + argument_parser.add_argument("--num-threads", type=int, default=1) argument_parser.add_argument("--log-level", type=str, default="INFO") argument_parser.add_argument("--debug", action="store_true", default=False) return argument_parser.parse_args(argv) +def make_command( + constructor: Type[Command], + chain_files: list[str], + bfile: str, +) -> Command: + # Read sample + sample_variants = read_bim_file(bfile) + # Make command object + command = constructor(sample_variants, chain_files) + return command + + +class IterationOrder(Enum): + ORDERED = auto() + UNORDERED = auto() + + +T = TypeVar("T") +S = TypeVar("S") + + +def make_pool_or_null_context( + iterable: Iterable[T], + callable: Callable[[T], S], + num_threads: int = 1, + chunksize: int | None = 1, + iteration_order: IterationOrder = IterationOrder.UNORDERED, +) -> tuple[ContextManager, Iterator[S]]: + if num_threads < 2: + return nullcontext(), map(callable, iterable) + + if isinstance(iterable, Sized): + num_threads = min(len(iterable), num_threads) + # Apply logic from pool.map (multiprocessing/pool.py#L481) here as well + if chunksize is None: + chunksize, extra = divmod(len(iterable), num_threads * 4) + if extra: + chunksize += 1 + if chunksize is None: + chunksize = 1 + + pool = mp.Pool(processes=num_threads) + if iteration_order is IterationOrder.ORDERED: + map_function = pool.imap + elif iteration_order is IterationOrder.UNORDERED: + map_function = pool.imap_unordered + else: + raise ValueError(f"Unknown iteration order {iteration_order}") + output_iterator: Iterator = map_function(callable, iterable, chunksize) + cm: ContextManager = pool + return cm, output_iterator + + def run(arguments: Namespace) -> None: # Normalize args legend_files = arguments.legend_files @@ -777,32 +859,41 @@ def run(arguments: Namespace) -> None: '"--chain-files" argument' ) - # Read sample - sample_variants = read_bim_file(arguments.bfile) - sample_chromosomes = frozenset(variant.chromosome for variant in sample_variants) + bfiles: list[str] = arguments.bfile - # Instantiate commands - command = arguments.command(sample_variants, chain_files) + pool, commands_iterator = make_pool_or_null_context( + bfiles, + partial(make_command, arguments.command, chain_files), + num_threads=arguments.num_threads, + ) + with pool: + commands: list[Command] = [] + for command in tqdm(commands_iterator, total=len(bfiles), leave=False): + commands.append(command) # Read reference reference_chromosomes: set[Chromosome] = set() - # with Pool() as pool:pool.imap_unordered - for result in map(read_legend_file, legend_files): + for legend_file in tqdm(legend_files, leave=False): + result = read_legend_file(legend_file) if result is None: continue - chromosome, reference_variants = result reference_chromosomes.add(chromosome) + for command in commands: + command.map(chromosome, reference_variants) - command.map(chromosome, reference_variants) - - if not sample_chromosomes.issubset(reference_chromosomes): - missing_chromosomes = sample_chromosomes - reference_chromosomes - logger.warning( - "Missing reference data for sample chromosomes " f"{missing_chromosomes}. " + for bfile, command in zip(bfiles, commands): + sample_chromosomes = frozenset( + variant.chromosome for variant in command.sample_variants ) + if not sample_chromosomes.issubset(reference_chromosomes): + missing_chromosomes = sample_chromosomes - reference_chromosomes + logger.warning( + "Missing reference data for chromosomes " + f'{missing_chromosomes} for input file "{bfile}". ' + ) - command.reduce(arguments.bfile) + command.reduce(bfile) def main(argv: list[str]) -> None: