Add multiprocessing to check
Allow the processing of multiple bfiles in parallel
HippocampusGirl committed Feb 8, 2024
1 parent 18b8b9d commit 669f612
Showing 1 changed file with 145 additions and 54 deletions.
199 changes: 145 additions & 54 deletions bin/check
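With this change, the flip and hg19 subcommands accept several PLINK prefixes via --bfile, and a new global --num-threads option controls how many worker processes are used (see the argument parser changes in the diff below). As a rough illustration only, an invocation of the updated script might look like the sketch that follows; the prefixes, the legend file path, and the thread count are placeholders, and the script is assumed to be executable at bin/check.

import subprocess

# Hypothetical invocation of the updated script. --num-threads is defined on
# the top-level parser, so it goes before the subcommand name, while the flip
# subcommand now accepts several --bfile prefixes at once.
subprocess.run(
    [
        "bin/check",
        "--num-threads", "2",
        "flip",
        "--bfile", "cohort_a", "cohort_b",
        "--legend-files", "reference/chr1.legend.gz",
    ],
    check=True,
)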
@@ -4,19 +4,32 @@ from __future__ import annotations

import gzip
import logging
import multiprocessing as mp
import re
from abc import ABC, abstractmethod
from argparse import ArgumentParser, Namespace
from collections import defaultdict
from contextlib import nullcontext
from dataclasses import dataclass, field
from enum import Enum, auto
from functools import partial
from glob import glob
from math import isclose

# from multiprocessing.dummy import Pool
from pathlib import Path
from subprocess import call
from subprocess import STDOUT, check_output
from tempfile import TemporaryDirectory
from typing import Literal
from typing import (
Callable,
ContextManager,
Iterable,
Iterator,
Literal,
Sized,
Type,
TypeVar,
)

from tqdm.auto import tqdm

logger = logging.getLogger("check")

@@ -28,7 +41,7 @@ chain_file_regex = re.compile(r"(?P<version>hg\d+)ToHg19.over.chain.gz")

def verbose_call(command: list[str]) -> None:
logger.info(f"Running command line {command}")
call(command)
check_output(command, stderr=STDOUT)


def parse_chromosome(chromosome_str: str) -> Chromosome | None:
@@ -145,8 +158,10 @@ class Variant:
return self.alleles == other.alleles

def match_alleles(self, other: Variant) -> MatchResult:
assert self.chromosome == other.chromosome
assert self.position == other.position
if self.chromosome != other.chromosome:
raise ValueError("Chromosomes do not match")
if self.position != other.position:
raise ValueError("Positions do not match")

if self.is_ambiguous:
return self.match_alleles_ambiguous(other)
@@ -229,8 +244,9 @@ class Variant:
ambiguous_matched = self
else:
ambiguous_matched = self.flip()
if ambiguous_matched is None:
raise ValueError("Could not flip ambiguous variant")

assert ambiguous_matched is not None
return MatchResult(variant=ambiguous_matched)

def match_alleles_incomplete(self, other: Variant) -> MatchResult:
@@ -291,10 +307,8 @@ def read_legend_file(
header = legend_file_handle.readline()
columns = header.split()

assert columns[0] == "id"
assert columns[1] == "position"
assert columns[2] == "a0"
assert columns[3] == "a1"
if columns[:4] != ["id", "position", "a0", "a1"]:
raise ValueError(f'Invalid header in legend file "{legend_file}"')

for line in legend_file_handle:
values = line.split()
@@ -307,7 +321,11 @@
allele_a=allele_a,
allele_b=allele_b,
)
assert variant.position is not None
if variant.position is None:
raise ValueError(
f'Invalid position "{position_str}" for variant "{id}" '
f'from legend file "{legend_file}"'
)

legend_file_variants[variant.position].append(variant)
legend_file_variant_count += 1
@@ -404,11 +422,12 @@ class FlipCommand(CommandBase):

self.counters: dict[str, int] = defaultdict(int)

self.exclude: list[str] = list()
self.extract: list[str] = list()
self.flip: list[str] = list()
self.update_alleles: list[tuple[str, str, str, str, str]] = list()

def _on_match(self, sample_variant: Variant, matched_variant: Variant):
self.extract.append(sample_variant.id)
for action in matched_variant.actions:
if action == "turn":
pass # we leave this up to the imputation step
@@ -449,7 +468,6 @@ class FlipCommand(CommandBase):
position not in reference_variants
or len(reference_variants[position]) == 0
):
self.exclude.append(sample_variant.id)
logger.debug(
f"Excluding variant {sample_variant.id} "
"because it could not be found in the reference"
@@ -464,14 +482,14 @@ class FlipCommand(CommandBase):
if match_result.variant is not None:
break

assert match_result is not None
if match_result is None:
raise ValueError("Could not find matching variant in reference")
matched_variant = match_result.variant

if matched_variant is None or len(match_result.messages) > 0:
print_variant = matched_variant
if print_variant is None:
print_variant = sample_variant
self.exclude.append(sample_variant.id)
logger.debug(
f"Excluding variant \n{print_variant} because it "
"could not be matched to the reference variants \n"
@@ -494,9 +512,9 @@ class FlipCommand(CommandBase):
"\n".join(f"{key}\t{value:d}" for key, value in counters_lines) + "\n"
)

exclude_file = f"{prefix}.snp.exclude.txt"
with open(exclude_file, "wt") as exclude_file_handle:
exclude_file_handle.write("\n".join(self.exclude) + "\n")
extract_file = f"{prefix}.snp.txt"
with open(extract_file, "wt") as extract_file_handle:
extract_file_handle.write("\n".join(self.extract) + "\n")

flip_file = f"{prefix}.snp.flip.txt"
with open(flip_file, "wt") as flip_file_handle:
@@ -514,8 +532,8 @@ class FlipCommand(CommandBase):
update_alleles_file,
"--bfile",
bfile,
"--exclude",
exclude_file,
"--extract",
extract_file,
"--flip",
flip_file,
"--out",
@@ -641,15 +659,16 @@ class Hg19Command(CommandBase):
if match_result.variant is not None:
break

assert match_result is not None
if match_result is None:
raise ValueError("Could not find matching variant in reference")
matched_variant = match_result.variant

if matched_variant is None or len(match_result.messages) > 0:
continue

self.overlap_counts[version] += 1

def reduce(self, bfile: str):
def reduce(self, bfile: str) -> None:
overlap: dict[str, float] = {
version: float(self.overlap_counts[version]) / float(len(variants))
for version, variants in self.lifted_variants.items()
@@ -658,9 +677,11 @@ class Hg19Command(CommandBase):
version, max_overlap = max(overlap.items(), key=lambda t: t[1])

if isclose(max_overlap, 0):
raise ValueError(
"The data has no overlap with any of the available reference genomes"
logger.error(
"No overlap with any of the available reference genomes"
f'for file "{bfile}"'
)
return

logger.info(
f'Sample is in "{version}" considering the overlap with the '
@@ -683,11 +704,11 @@ class Hg19Command(CommandBase):
return

lifted_variant_ids = frozenset(v.id for v in self.lifted_variants[version])
exclude = [v.id for v in self.sample_variants if v.id not in lifted_variant_ids]
extract = [v.id for v in self.sample_variants if v.id in lifted_variant_ids]

exclude_file = f"{prefix}.snp.exclude.txt"
with open(exclude_file, "wt") as exclude_file_handle:
exclude_file_handle.write("\n".join(exclude) + "\n")
extract_file = f"{prefix}.snp.txt"
with open(extract_file, "wt") as extract_file_handle:
extract_file_handle.write("\n".join(extract) + "\n")

new_variants = self.lifted_variants[version]

@@ -700,8 +721,8 @@ class Hg19Command(CommandBase):
"plink",
"--bfile",
bfile,
"--exclude",
exclude_file,
"--extract",
extract_file,
"--chr",
",".join(chromosomes),
"--out",
@@ -710,11 +731,13 @@ class Hg19Command(CommandBase):
]
)

old_variants = read_bim_file(prefix)
extracted_variants = read_bim_file(prefix)

assert len(old_variants) == len(new_variants)
for old, new in zip(old_variants, new_variants):
assert old.id == new.id
if len(extracted_variants) > len(new_variants):
raise ValueError(
f"Number of variants after liftOver is larger "
f"the expected number: {len(extracted_variants)} != {len(new_variants)}"
)

bim_file = f"{prefix}.bim"
write_bim_file(bim_file, new_variants)
@@ -725,31 +748,90 @@ Command = Hg19Command | FlipCommand

def parse_arguments(argv: list[str]) -> Namespace:
argument_parser = ArgumentParser()
subparsers = argument_parser.add_subparsers()

subparsers = argument_parser.add_subparsers(required=True)

flip_subparser = subparsers.add_parser("flip")
flip_subparser.set_defaults(command=FlipCommand)
flip_subparser.add_argument("--bfile", type=str, required=True)
flip_subparser.add_argument(
"--bfile", action="extend", nargs="+", type=str, required=True
)
flip_subparser.add_argument(
"--legend-files", action="extend", nargs="+", type=str, required=False
)

hg19_subparser = subparsers.add_parser("hg19")
hg19_subparser.set_defaults(command=Hg19Command)
hg19_subparser.add_argument("--bfile", type=str, required=True)
hg19_subparser.add_argument(
"--bfile", action="extend", nargs="+", type=str, required=True
)
hg19_subparser.add_argument(
"--legend-files", action="extend", nargs="+", type=str, required=False
)
hg19_subparser.add_argument(
"--chain-files", action="extend", nargs="+", type=str, required=False
)

argument_parser.add_argument("--num-threads", type=int, default=1)
argument_parser.add_argument("--log-level", type=str, default="INFO")
argument_parser.add_argument("--debug", action="store_true", default=False)

return argument_parser.parse_args(argv)


def make_command(
constructor: Type[Command],
chain_files: list[str],
bfile: str,
) -> Command:
# Read sample
sample_variants = read_bim_file(bfile)
# Make command object
command = constructor(sample_variants, chain_files)
return command


class IterationOrder(Enum):
ORDERED = auto()
UNORDERED = auto()


T = TypeVar("T")
S = TypeVar("S")


def make_pool_or_null_context(
iterable: Iterable[T],
callable: Callable[[T], S],
num_threads: int = 1,
chunksize: int | None = 1,
iteration_order: IterationOrder = IterationOrder.UNORDERED,
) -> tuple[ContextManager, Iterator[S]]:
if num_threads < 2:
return nullcontext(), map(callable, iterable)

if isinstance(iterable, Sized):
num_threads = min(len(iterable), num_threads)
# Apply logic from pool.map (multiprocessing/pool.py#L481) here as well
if chunksize is None:
chunksize, extra = divmod(len(iterable), num_threads * 4)
if extra:
chunksize += 1
if chunksize is None:
chunksize = 1

pool = mp.Pool(processes=num_threads)
if iteration_order is IterationOrder.ORDERED:
map_function = pool.imap
elif iteration_order is IterationOrder.UNORDERED:
map_function = pool.imap_unordered
else:
raise ValueError(f"Unknown iteration order {iteration_order}")
output_iterator: Iterator = map_function(callable, iterable, chunksize)
cm: ContextManager = pool
return cm, output_iterator


def run(arguments: Namespace) -> None:
# Normalize args
legend_files = arguments.legend_files
@@ -777,32 +859,41 @@ def run(arguments: Namespace) -> None:
'"--chain-files" argument'
)

# Read sample
sample_variants = read_bim_file(arguments.bfile)
sample_chromosomes = frozenset(variant.chromosome for variant in sample_variants)
bfiles: list[str] = arguments.bfile

# Instantiate commands
command = arguments.command(sample_variants, chain_files)
pool, commands_iterator = make_pool_or_null_context(
bfiles,
partial(make_command, arguments.command, chain_files),
num_threads=arguments.num_threads,
)
with pool:
commands: list[Command] = []
for command in tqdm(commands_iterator, total=len(bfiles), leave=False):
commands.append(command)

# Read reference
reference_chromosomes: set[Chromosome] = set()
# with Pool() as pool:pool.imap_unordered
for result in map(read_legend_file, legend_files):
for legend_file in tqdm(legend_files, leave=False):
result = read_legend_file(legend_file)
if result is None:
continue

chromosome, reference_variants = result
reference_chromosomes.add(chromosome)
for command in commands:
command.map(chromosome, reference_variants)

command.map(chromosome, reference_variants)

if not sample_chromosomes.issubset(reference_chromosomes):
missing_chromosomes = sample_chromosomes - reference_chromosomes
logger.warning(
"Missing reference data for sample chromosomes " f"{missing_chromosomes}. "
for bfile, command in zip(bfiles, commands):
sample_chromosomes = frozenset(
variant.chromosome for variant in command.sample_variants
)
if not sample_chromosomes.issubset(reference_chromosomes):
missing_chromosomes = sample_chromosomes - reference_chromosomes
logger.warning(
"Missing reference data for chromosomes "
f'{missing_chromosomes} for input file "{bfile}". '
)

command.reduce(arguments.bfile)
command.reduce(bfile)


def main(argv: list[str]) -> None:
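The parallelism added above is concentrated in make_pool_or_null_context, which falls back to a plain map() when fewer than two threads are requested and otherwise hands the work to a multiprocessing pool whose context the caller manages. A simplified, self-contained sketch of that same pattern is given below; the function and variable names are illustrative and not part of the script.

import multiprocessing as mp
from contextlib import nullcontext
from typing import Callable, ContextManager, Iterable, Iterator, TypeVar

T = TypeVar("T")
S = TypeVar("S")


def work(x: int) -> int:
    # Stand-in for the per-bfile work done by make_command in the script.
    return x * x


def pool_or_null(
    items: Iterable[T], func: Callable[[T], S], num_threads: int = 1
) -> tuple[ContextManager, Iterator[S]]:
    if num_threads < 2:
        # Serial path: nothing to clean up, so a do-nothing context suffices.
        return nullcontext(), map(func, items)
    pool = mp.Pool(processes=num_threads)
    # The caller enters the returned context manager so the pool is terminated.
    return pool, pool.imap_unordered(func, items)


if __name__ == "__main__":
    cm, results = pool_or_null(range(8), work, num_threads=4)
    with cm:
        print(sorted(results))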
